diff --git a/.clang-tidy b/.clang-tidy index 5b36ac93d48..880a8ae9d97 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -1,20 +1,26 @@ Checks: > *, -abseil-*, + -altera-struct-pack-align, + -altera-unroll-loops, -android-*, - -cert-err58-cpp, + -boost-use-ranges, + -bugprone-easily-swappable-parameters, -cert-err58-cpp, -clang-analyzer-osx-*, -cppcoreguidelines-avoid-c-arrays, + -cppcoreguidelines-avoid-const-or-ref-data-members, + -cppcoreguidelines-avoid-do-while, -cppcoreguidelines-avoid-goto, -cppcoreguidelines-avoid-magic-numbers, -cppcoreguidelines-avoid-non-const-global-variables, + -cppcoreguidelines-non-private-member-variables-in-classes, -cppcoreguidelines-owning-memory, -cppcoreguidelines-pro-bounds-array-to-pointer-decay, + -cppcoreguidelines-pro-bounds-constant-array-index, -cppcoreguidelines-pro-bounds-pointer-arithmetic, -cppcoreguidelines-pro-type-reinterpret-cast, -cppcoreguidelines-pro-type-vararg, - -cppcoreguidelines-pro-type-vararg, -cppcoreguidelines-special-member-functions, -fuchsia-*, -google-*, @@ -25,25 +31,32 @@ Checks: > -hicpp-special-member-functions, -hicpp-use-equals-default, -hicpp-vararg, - -hicpp-vararg, -llvm-header-guard, -llvm-include-order, + -llvm-qualified-auto, -llvmlibc-*, - -misc-no-recursion, + -misc-include-cleaner, -misc-no-recursion, -misc-non-private-member-variables-in-classes, -misc-unused-parameters, -modernize-avoid-c-arrays, -modernize-deprecated-headers, + -modernize-use-designated-initializers, -modernize-use-nodiscard, -modernize-use-trailing-return-type, -mpi-*, -objc-*, -openmp-*, + -performance-avoid-endl, + -performance-enum-size, -readability-avoid-const-params-in-decls, -readability-convert-member-functions-to-static, + -readability-function-cognitive-complexity, + -readability-identifier-length, -readability-implicit-bool-conversion, -readability-magic-numbers, + -readability-math-missing-parentheses, + -readability-qualified-auto, -zircon-*, HeaderFilterRegex: '.*' diff --git a/.claude/CLAUDE.md b/.claude/CLAUDE.md new file mode 100644 index 00000000000..be69dad6ea7 --- /dev/null +++ b/.claude/CLAUDE.md @@ -0,0 +1,235 @@ +# CLAUDE.md + +Guidance for Claude Code when working in the Velox repository. + +## Branch Hygiene + +Before creating a new feature branch, always: + +1. `git checkout main` +2. `git fetch upstream && git rebase upstream/main` +3. Delete stale feature branches. +4. Push main to origin if behind upstream. +5. `git checkout -b ` + +## PR Review + +When asked to review a PR (via `/pr-review`), always use the /pr-review skill. + +### Review scripts + +Use `scripts/review/fetch.py` and `scripts/review/post.py` for PR reviews +instead of raw `gh api` calls. + +```bash +# Fetch PR metadata, diff, comments, and reviews in one shot. +python3 scripts/review/fetch.py +python3 scripts/review/fetch.py + +# Post a review from a file. +python3 scripts/review/post.py +python3 scripts/review/post.py +# Events: APPROVE, REQUEST_CHANGES, COMMENT +``` + +Always draft the review body in `/tmp/` and get approval before calling +`post.py`. + +### Review style + +See [scripts/review/REVIEW_GUIDE.md](../scripts/review/REVIEW_GUIDE.md). + +## Queries + +When asked a question about the PR or codebase (via `/query`), use the /query skill. + +## Overview + +Velox is an open source C++ library for composable data processing and +query execution. Licensed under Apache 2.0. Requires C++20, GCC 11+ or +Clang 15+. + +## Build + +```bash +make debug # debug build +make release # optimized build +``` + +## Testing + +```bash +make unittest # run all tests +cd _build/debug && ctest -j 8 # run all tests in parallel +ctest -R ExprTest # run tests matching a pattern +``` + +Test files live in `tests/` subdirectories alongside source. + +### Grouped tests + +Four test suites use `velox_add_grouped_tests` to reduce link times on Linux CI +by batching source files into shared binaries: +- `velox/exec/tests` (`velox_exec_test`, `velox_exec_util_test`) +- `velox/functions/prestosql/aggregates/tests` +- `velox/common/caching/tests` +- `velox/serializers/tests` + +All other test suites use individual binaries on all platforms. + +On macOS, grouping is off by default (`VELOX_ENABLE_GROUPED_TESTS=OFF`) and each +test file gets its own binary (e.g., `ValuesTest.cpp` → `velox_exec_test_ValuesTest`). +On Linux CI, grouping is on (`velox_exec_test_group0` through `_group7`). +Override with `-DVELOX_ENABLE_GROUPED_TESTS=ON/OFF`. + +### Common test workflows + +```bash +# Run all test binaries whose ctest name matches a regex. +# On Linux this matches velox_exec_test_group0 … _group7. +# On macOS this matches velox_exec_test_ValuesTest, +# velox_exec_test_HashJoinTest, etc. +cd _build/debug && ctest -R velox_exec + +# Run a specific test file (macOS — individual binary) +_build/debug/velox/exec/tests/velox_exec_test_ValuesTest --gtest_filter="ValuesTest.*" + +# Run a specific test case (Linux — grouped binary) +_build/debug/velox/exec/tests/velox_exec_test_group3 --gtest_filter="ValuesTest.empty" +``` + +**Re-running a CI failure locally:** CI reports a failure in +`velox_exec_test_group3` with `ValuesTest.empty`. On Linux, run the grouped +binary directly. On macOS, the grouped binary does not exist — use the +per-file binary instead: `velox_exec_test_ValuesTest --gtest_filter="ValuesTest.empty"`. + +**Adding a new test to a grouped suite:** Add the source file to the `SOURCES` +list in the relevant `velox_add_grouped_tests()` call in `CMakeLists.txt`. It +is automatically assigned to a group on Linux and gets its own binary on macOS. + +**Creating a new test suite:** Use `velox_add_grouped_tests` for suites with +many test files (10+) that link against large libraries like velox core — each +individual binary pays the full link cost, so grouping them into shared binaries +significantly reduces total CI build time. For suites with only a few test files +or lightweight dependencies, use `add_executable` / `add_test`. + +## Formatting + +```bash +make format # format all changed files +``` + +## Coding Style + +Read [CODING_STYLE.md](../CODING_STYLE.md) for the complete guide. Key rules +are summarized below. + +### Comments + +- Use `///` for public API documentation (classes, public methods, public members). +- Use `//` for private/protected members and comments inside code blocks. +- Start comments with active verbs, not "This class…" or "This method…". + - ❌ `/// This class builds query plans.` + - ✅ `/// Builds query plans.` +- Comments should be full English sentences starting with a capital letter and ending with a period. +- Comment every class, every non-trivial method, every member variable. +- Do not restate the variable name. Either explain the semantic meaning or omit the comment. + - ❌ `// A simple counter.` above `size_t count_{0};` +- Avoid redundant comments that repeat what the code already says. Comments should explain *why*, not *what*. +- Use `// TODO: Description.` for future work. Do not include author's username. +- Do not duplicate comments between `.h` and `.cpp`. Document the function in the header; the implementation should not repeat the same comment. Duplicated comments diverge over time. + +### Naming Conventions + +- **PascalCase** for types and file names. +- **camelCase** for functions, member and local variables. +- **camelCase_** for private and protected member variables. +- **snake_case** for namespace names and build targets. +- **UPPER_SNAKE_CASE** for macros. +- **kPascalCase** for static constants and enumerators. +- Do not abbreviate. Use full, descriptive names. Well-established abbreviations (`id`, `url`, `sql`, `expr`) are acceptable. +- Prefer `numXxx` over `xxxCount` (e.g. `numRows`, `numKeys`). +- Never name a file or class `*Utils`, `*Helpers`, or `*Common`. These generic + names attract unrelated functions over time and lose cohesion. Name files and + classes after the concept they represent. Use a class with static methods to + group related operations, and shorten method names since the class name + provides context. + +### Asserts and CHECKs + +- Use `VELOX_CHECK_*` for internal errors, `VELOX_USER_CHECK_*` for user errors. +- Prefer two-argument forms: `VELOX_CHECK_LT(idx, size)` over `VELOX_CHECK(idx < size)`. +- Use `VELOX_FAIL()` / `VELOX_USER_FAIL()` to throw unconditionally. +- Use `VELOX_UNREACHABLE()` for impossible branches, `VELOX_NYI()` for unimplemented paths. +- Put runtime information (names, values, types) at the **end** of error messages, after the static description. + - ❌ `VELOX_USER_FAIL("Column '{}' is ambiguous", name);` + - ✅ `VELOX_USER_FAIL("Column is ambiguous: {}", name);` + +### Variables + +- Prefer value types, then `std::optional`, then `std::unique_ptr`. +- Prefer `std::string_view` over `const std::string&` for function parameters. +- Use uniform initialization: `size_t size{0}` over `size_t size = 0`. +- Declare variables in the smallest scope, as close to usage as possible. +- Use digit separators (`'`) for numeric literals with 4 or more digits: `10'000`, not `10000`. +- Use trailing commas in multi-line initializer lists, enum definitions, and + function-call argument lists that span multiple lines. This produces cleaner + diffs when items are added or reordered. + +### API Design + +- Keep the public API surface small. +- Prefer free functions in `.cpp` (anonymous namespace) over private/static class methods. +- Define free functions close to where they are used, not grouped together at the top or bottom of the file. +- Keep method implementations in `.cpp` except for trivial one-liners. +- Avoid default arguments when all callers can pass values explicitly. +- Never use `friend`, `FRIEND_TEST`, or any friend declarations. If a test needs access to private members, redesign the API or test through public methods instead. + +### Tests + +- Place new tests next to related existing tests, not at the end of the file. Group tests by topic (e.g., place `tryCast` next to `types`, `notBetween` next to `ifClause` which uses `between`). + +Use gtest container matchers (`testing::ElementsAre`, etc.) for verifying collections: + +```cpp +// ❌ Avoid - multiple individual assertions +EXPECT_EQ(result.size(), 3); +EXPECT_EQ(result[0], "a"); +EXPECT_EQ(result[1], "b"); +EXPECT_EQ(result[2], "c"); + +// ✅ Prefer - single matcher assertion +EXPECT_THAT(result, testing::ElementsAre("a", "b", "c")); +``` + +Common matchers: +- `ElementsAre(...)` - exact ordered match +- `UnorderedElementsAre(...)` - exact unordered match +- `Contains(...)` - at least one element matches +- `IsEmpty()` - collection is empty +- `SizeIs(n)` - collection has n elements + +Requires `#include `. + +## Common Mistakes + +These are frequently violated rules. Check every new or modified line against +this list before finishing. + +- **Bug fixes without a failing test first.** Write the test first, confirm it fails, then fix. A test that passes with and without the fix proves nothing. +- **`///` vs `//` wrong comment style.** `///` is only for public API in headers. Everything else uses `//`. +- **One-letter and abbreviated variable names.** Use full, descriptive names. Only loop indices (`i`, `j`) are acceptable. +- **Undocumented APIs in headers.** Every class, method, and member variable in a `.h` file must have a comment. +- **Non-trivial implementations in headers.** If a method body has more than one statement, it belongs in the `.cpp` file. +- **`goto` statements.** Never use `goto`. Use early returns, helper functions, or duplicated code paths. +- **Fitting tests to buggy code.** Never update test expectations to match buggy output without verifying correctness first. +- **Generic file and class names.** Never name a file or class `*Utils`, `*Helpers`, or `*Common`. +- **Verify causation before asserting it.** Do not attribute failures to a commit based on its message alone. Verify empirically. +- **Silently simplifying an approved plan.** If a step is harder than expected, say so and get approval before reducing scope. +- **Working around infrastructure bugs.** Do not silently work around bugs in shared infrastructure. Report and discuss. + +## Design Documents + +Design (including proposals) live in `docs/designs/`. When creating new +designs, place them there with a descriptive filename (e.g., +`column-extraction-pushdown.md`). diff --git a/.claude/skills/ci-failure-analysis/SKILL.md b/.claude/skills/ci-failure-analysis/SKILL.md new file mode 100644 index 00000000000..e088f9af705 --- /dev/null +++ b/.claude/skills/ci-failure-analysis/SKILL.md @@ -0,0 +1,137 @@ +You are a CI failure analyst for the Velox C++ project. A CI run has failed +on PR #{{PR_NUMBER}} in the {{REPOSITORY}} repo. +The workflow run ID is {{RUN_ID}}. + +Failure metadata (JSON array of failed jobs): +{{FAILURE_METADATA}} + +Each entry has: "job" (job name), "type" ("build", "test", or "unknown"), and +optionally "failed_tests" (newline-separated test names). A type of "unknown" +means no structured failure metadata was available (e.g., Fuzzer Jobs) — you +must determine the failure type from the job logs. + +Your task: +1. Use `gh api` to download the logs for the failed jobs in this workflow run. + - List jobs: `gh api repos/{{REPOSITORY}}/actions/runs/{{RUN_ID}}/jobs` + - For each job, save its `id` and note the step numbers from the `steps` array. + Find the step that ran the tests or build (usually named "Run Tests", "Build", + or similar — look for the step whose logs contain the failure output, not the + status-reporting step). You need the job `id` and step `number` to build a + direct link: `https://github.com/{{REPOSITORY}}/actions/runs/{{RUN_ID}}/job/{job_id}#step:{step_number}:{ui_line}` + To compute `ui_line`: the raw log numbers lines across all steps, but the + GitHub UI numbers lines per-step starting from 1. To convert, find the line + in the raw log where the test step begins (search for "Test project /__w/") + and call that `start_line`. Then find the `[ FAILED ]` line and call that + `failed_line`. The UI line number is `failed_line - start_line + 1`. + For build failures, use the first `error:` line instead of `[ FAILED ]`. + - Download job logs: `gh api repos/{{REPOSITORY}}/actions/jobs/{job_id}/logs` (returns plain text) + - If job logs API fails, try: `gh run view {{RUN_ID}} --repo {{REPOSITORY}} --log-failed` + +2. For TEST failures: Find the gtest failure output — the lines between `[ RUN ]` and + `[ FAILED ]` for each failing test. Extract the assertion message, expected vs actual + values, file path, and line number. Also find the test binary name from the ctest output + (look for lines like `Start N: ` followed by the binary path, or search for + the binary path in the test output). You will need this for the reproduce command. + +3. For BUILD failures: Find compiler `error:` lines with file paths and error messages. + +4. For FUZZER failures (from the "Fuzzer Jobs" workflow): Fuzzers run via + `run-fuzzer-parallel.sh` which launches multiple instances in parallel. + Look for crash reports, assertion failures (VELOX_CHECK, VELOX_FAIL), + segfaults, or timeouts in the logs. Key information to extract: + - The fuzzer name (e.g., "Presto Fuzzer", "Join Fuzzer", "Spark Expression Fuzzer") + - The error message or assertion failure + - The seed value (from `--seed` flag or log output) for reproduction + - The file and line number from stack traces + - The reproduce command uses the fuzzer binary with `--seed ` flag + +5. Get the PR diff: `gh pr diff {{PR_NUMBER}} --repo {{REPOSITORY}}` + Determine if the failures are likely caused by the PR changes. + +6. Search open issues for known failures: + `gh issue list --repo {{REPOSITORY}} --search "" --state open --limit 5` + Check if any failing test has a known open issue. + +7. Check if the same failures occur on the main branch (pre-existing/flaky): + `gh run list --repo {{REPOSITORY}} --branch main --workflow "" --limit 3 --json conclusion,databaseId` + Use the appropriate workflow name: "Linux Build using GCC" for build/test + failures, "Fuzzer Jobs" for fuzzer failures. + +8. Post a SINGLE comment on the PR with your analysis. Use update-or-create + behavior so re-runs replace the prior analysis instead of stacking new + comments. Look up the prior comment by its heading, then edit that + specific comment ID via the GitHub API. + + Write the body to a file first (e.g., `/tmp/ci-failure-comment.md`) so + multi-line markdown round-trips correctly, then: + - Find any prior comment with this workflow's heading: + `gh api "repos/{{REPOSITORY}}/issues/{{PR_NUMBER}}/comments" --paginate \ + --jq '[.[] | select(.body | contains("## CI Failure Analysis")) | .id] | first'` + - If a comment ID is returned, edit that specific comment in place by + piping a JSON body in via `--input -`: + `jq -Rs '{body: .}' /tmp/ci-failure-comment.md \ + | gh api -X PATCH "repos/{{REPOSITORY}}/issues/comments/" --input -` + - Otherwise, create a new comment: + `gh pr comment {{PR_NUMBER}} --repo {{REPOSITORY}} --body-file /tmp/ci-failure-comment.md` + + The `## CI Failure Analysis` heading must remain in the body so this + lookup keeps working across re-runs. + +Format the comment as follows (use markdown): +``` +## CI Failure Analysis + +> _Auto-generated by the CI Failure Analysis workflow. This comment is updated in place each time CI fails on a new commit, so it always reflects the latest run — re-pushing or re-running CI will refresh the analysis below. Last updated from [workflow run {{RUN_ID}}](https://github.com/{{REPOSITORY}}/actions/runs/{{RUN_ID}})._ + +### Failure [View logs]() + +**Failed tests:** (or **Build errors:** for build failures) + +For each failing test, show: +- Test name +- The assertion error (expected vs actual, or the error message) +- Source file and line number + +For build failures, show: +- The compiler error message +- Source file and line number + +Keep failure details in a code block for readability. + +(Repeat the above section for each failed job, each with its own step-level link) + +--- + +**Correlation with PR changes:** +- State whether the failure appears related to the PR diff or not +- If related, point to the specific file/function in the diff that likely caused it +- If unrelated, explain why (e.g., "This test modifies X but the PR only touches Y") + +**Known issues:** +- If an open issue tracks this failure, link to it +- If the same test fails on main, note it as a pre-existing/flaky failure + +**Reproduce locally:** (for test failures) +- Show the command to reproduce, e.g.: + `./_build/debug/velox/exec/tests/velox_exec_test_group0 --gtest_filter="TestSuite.testCase"` + Use the actual binary path from the ctest log output. + +**Recommended fix:** (if the failure is related to the PR) +- Brief suggestion of what to fix +``` + +The blockquote line directly under the heading MUST be present on every +posted comment. Use the current UTC time in `YYYY-MM-DD HH:MM:SS UTC` +format (run `date -u +'%Y-%m-%d %H:%M:%S UTC'`). Because the comment is +overwritten in place on re-runs, this line is the only visible signal +that the analysis was refreshed — it changes on every update and stays +static if no re-run has happened. It sits directly under the heading +(not as a footer) so reviewers see it without scrolling past the +analysis. + +Important rules: +- Be concise. Show only the relevant failure output, not the entire log. +- If many tests fail (>5), show the first 3-5 in detail and summarize the rest. +- Cap the comment at 60,000 characters (GitHub limit is 65,536). +- Use the `gh pr comment` command to post. Do NOT use any other method. +- If you cannot determine the cause, say so honestly rather than guessing. diff --git a/.claude/skills/pr-review/SKILL.md b/.claude/skills/pr-review/SKILL.md new file mode 100644 index 00000000000..28833194377 --- /dev/null +++ b/.claude/skills/pr-review/SKILL.md @@ -0,0 +1,187 @@ +--- +name: pr-review +description: Review Velox pull requests for code quality, memory safety, performance, and correctness. Use when reviewing PRs, when asked to review code changes, or when the user mentions "/pr-review". +--- + +# Velox PR Review Skill + +Review Velox pull requests focusing on what CI cannot check: code quality, memory +safety, concurrency, performance, and correctness. This is performance-critical C++ +code for a database execution engine where bugs can cause data corruption, crashes, +or security vulnerabilities. + +## Usage Modes + +### GitHub Actions Mode + +When invoked via `/pr-review [additional context]` on a GitHub PR, the action +pre-fetches PR metadata and injects it into the prompt. Detect this mode by the +presence of ``, ``, and `` tags in +the prompt. + +The prompt already contains: +- PR metadata (title, author, branch names, additions/deletions, file count) +- PR body/description +- All comments and review comments (with file/line references) +- List of changed files with paths and change types + +Use git commands to get the diff and commit history. The base branch name is in the prompt context (look for PR Branch: -> or the baseBranch field). + +```bash +# Get the full diff against the base branch +git diff origin/...HEAD + +# Get diff stats +git diff --stat origin/...HEAD + +# Get commit history for this PR +git log origin/..HEAD --oneline + +# If the base branch ref is not available, fetch it first +git fetch origin --depth=1 +``` + +Do NOT use `gh` CLI commands in this mode -- only git commands are available. +All PR metadata, comments, and reviews are already in the prompt context; +only the diff and commit log need to be fetched via git. + +If the reviewer provided additional context or instructions after the `/pr-review` +command, incorporate those into your review focus. + +### Local CLI Mode + +The user provides a PR number or URL: + +``` +/pr-review 12345 +/pr-review https://github.com/facebookincubator/velox/pull/12345 +``` + +Use `gh` CLI to fetch PR data: + +```bash +gh pr view --json title,body,author,baseRefName,headRefName,files,additions,deletions,commits +gh pr diff +gh pr view --json comments,reviews +``` + +## Review Workflow + +### Step 1: Read Project Guidelines + +**Before reviewing, you MUST Read `CODING_STYLE.md` at the repo root in +full.** Do not skim, do not skip — every modified line in the diff must be +checked against it. + +The "Common Mistakes" section is the authoritative checklist for the +highest-volume real review hits (`///` vs `//`, abbreviations, `*Utils`, +undocumented headers, header-body weight, `goto`, test-first for bug +fixes, naming conventions, assert forms, etc.). + +This skill does not maintain a duplicate checklist — `CODING_STYLE.md` +is the single source of truth. If anything in this skill ever appears to +contradict `CODING_STYLE.md`, prefer `CODING_STYLE.md`. + +### Step 2: Analyze Changes and Prior Review + +Read through the diff systematically: +1. Identify the purpose of the change from title/description +2. Group changes by type (new code, tests, config, docs) +3. Note the scope of changes (files affected, lines changed) + +The `` block in the prompt context contains all prior review +comments — including any from earlier `/pr-review` invocations on this PR. +Read them before reviewing: +- Do **not** re-flag issues already raised by a prior reviewer (human or + Claude). Re-flagging trains authors to ignore Claude reviews. +- If a prior comment was addressed by a follow-up commit, verify the fix in + the diff rather than restating the original concern. +- If `/pr-review` was invoked in reply to a specific comment thread, focus + the review on that thread's concerns instead of re-reviewing the whole PR. + +### Step 3: Deep Review + +Trace the logic step by step. For each change, consider boundary conditions +(empty, null, max size, first/last iteration), failure modes (allocation +failures, exceptions, partial state), concurrency (race conditions, lock +ordering), and memory safety (ownership, lifetimes, dangling references). Be +strict — better to flag a potential issue than miss a real bug. The Review +Areas table below enumerates what to check; do not duplicate it as narrative. + +## Review Areas + +Analyze each of these areas thoroughly: + +| Area | Focus | +|------|-------| +| Correctness & Edge Cases | Logic errors, off-by-one, null/empty handling, boundary conditions, integer overflow, floating point edge cases (NaN, Inf, negative zero) | +| Memory Safety | Use-after-free, double-free, leaks, dangling pointers/references, buffer overflows, ownership/lifetime issues, exception safety | +| Concurrency | Race conditions, data races, deadlocks, lock ordering, thread-safety of shared state | +| Performance | Unnecessary copies (move semantics?), inefficient algorithms, cache-unfriendly access, excessive allocations in hot paths | +| Error Handling | All error paths handled? Exceptions caught appropriately? Informative error messages? Correct use of VELOX_CHECK_* vs VELOX_USER_CHECK_*? | +| Code Quality | RAII, const-correctness, smart pointers, naming conventions, clear structure | +| Testing | Sufficient tests? Edge cases covered? Error paths tested? Using gtest matchers? **Bug-fix PRs**: does the diff add a test that would fail without the fix? Flag bug fixes that ship code-only. | + +## Output Format + +The output should be a markdown-formatted summary and should follow the following markdown format exactly: + +```markdown +### Summary +Brief overall assessment (1-2 sentences) + +### Issues Found +List any issue, categorized by severity: + - 🔴 **Critical**: Must fix before merge + - 🟡 **Suggestion**: Should consider + - 🟢 **Nitpick**: Minor style issues + +Each issue should also include: +- File and line reference +- Description of the issue +- Suggested fix if applicable + +### Positive Observations +Note any particularly good patterns or improvements. +``` + +## Inline Comments + +Use the `mcp__github_inline_comment__create_inline_comment` tool to post +comments directly on specific lines in the PR diff. Inline comments should +be used whenever pointing at the exact line adds clarity beyond the summary +comment. + +**Use inline comments for:** +- Concrete bugs or incorrect logic +- Memory safety issues (use-after-free, dangling references, leaks) +- Off-by-one errors or boundary condition mistakes +- Incorrect use of VELOX_CHECK_* vs VELOX_USER_CHECK_* + +**Do NOT use inline comments for:** +- Style nitpicks or naming suggestions +- General architectural feedback +- Positive observations +- Anything that applies broadly rather than to a specific line + +**Always post a summary comment** with the overall review. Inline comments +supplement the summary — they do not replace it. + +## Key Principles + +1. **No repetition** - Each observation appears in exactly one place +2. **Focus on what CI cannot check** - Don't comment on formatting, linting, or type errors +3. **Be specific** - Reference file paths and line numbers +4. **Be actionable** - Provide concrete suggestions, not vague concerns +5. **Be proportionate** - Minor issues shouldn't block, but note them +6. **Assume competence** - The author knows C++; explain only non-obvious context +7. **Permission to be quiet** - If the PR has no meaningful issues, post a + short LGTM (one or two sentences) and stop. Do not manufacture nitpicks + to fill space — padding trains authors to ignore Claude reviews. A clean + PR getting a clean and high-signal review is the correct outcome. + +## Files to Reference + +When reviewing, consult these project files for context: +- `CLAUDE.md` - Coding style and project guidelines +- `CODING_STYLE.md` - Complete coding style guide diff --git a/.claude/skills/query/SKILL.md b/.claude/skills/query/SKILL.md new file mode 100644 index 00000000000..35ed47dff51 --- /dev/null +++ b/.claude/skills/query/SKILL.md @@ -0,0 +1,24 @@ +--- +name: query +description: Answer questions about the Velox codebase or pull requests. Use when asked a question via "/query" or when the user wants to understand code, architecture, or implementation details. +--- + +# Velox Query Skill + +Answer questions about the Velox project codebase or specific pull requests. + +## Key Context + +- Velox is a C++ execution engine library for analytical data processing +- Uses C++20 standard with heavy use of templates and SFINAE +- Custom memory management with MemoryPool +- Vectorized execution with custom Vector types +- Follows Google C++ style with some modifications + +## Guidelines + +- Read CLAUDE.md and CODING_STYLE.md for project-specific conventions +- Answer thoroughly and accurately +- If the question is about PR changes, analyze the diff carefully +- If it's about the codebase, explore relevant files to provide a complete answer +- Be specific and reference exact file paths and line numbers when relevant diff --git a/.github/actions/generate-dependency-graph/action.yml b/.github/actions/generate-dependency-graph/action.yml new file mode 100644 index 00000000000..d3efe53335a --- /dev/null +++ b/.github/actions/generate-dependency-graph/action.yml @@ -0,0 +1,59 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Generate Dependency Graph +description: > + Configure CMake with File API and compile_commands.json, then generate + the dependency graph JSON (file_to_targets, header_to_sources, target_deps). + +runs: + using: composite + steps: + - name: Fix git permissions + shell: bash + run: git config --global --add safe.directory ${GITHUB_WORKSPACE} + + - name: Set up File API query + shell: bash + run: | + mkdir -p _build/release/.cmake/api/v1/query + echo '{}' > _build/release/.cmake/api/v1/query/codemodel-v2 + + - name: CMake configure + shell: bash + env: + MAKEFLAGS: NUM_THREADS=8 + run: | + source /opt/rh/gcc-toolset-14/enable + cmake -B _build/release -GNinja \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ + -DVELOX_MONO_LIBRARY=OFF \ + -DVELOX_ENABLE_BENCHMARKS=ON \ + -DVELOX_ENABLE_EXAMPLES=ON \ + -DVELOX_ENABLE_ARROW=ON \ + -DVELOX_ENABLE_GEO=ON \ + -DVELOX_ENABLE_PARQUET=ON \ + -DVELOX_ENABLE_HDFS=ON \ + -DVELOX_ENABLE_S3=ON \ + -DVELOX_ENABLE_GCS=ON \ + -DVELOX_ENABLE_ABFS=ON \ + -DVELOX_ENABLE_WAVE=ON + + - name: Generate dependency graph + shell: bash + run: | + python3 .github/scripts/generate-dependency-graph.py \ + --build-dir _build/release \ + --source-dir . \ + --output dependency-graph.json diff --git a/.github/copy-pr-bot.yaml b/.github/copy-pr-bot.yaml index 895ba83ee54..b28d1fe0a0c 100644 --- a/.github/copy-pr-bot.yaml +++ b/.github/copy-pr-bot.yaml @@ -1,3 +1,16 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # Configuration file for `copy-pr-bot` GitHub App # https://docs.gha-runners.nvidia.com/apps/copy-pr-bot/ diff --git a/.github/disabled-workflows/breeze.yml b/.github/disabled-workflows/breeze.yml index 66eaaa480c1..bc77bc24811 100644 --- a/.github/disabled-workflows/breeze.yml +++ b/.github/disabled-workflows/breeze.yml @@ -54,13 +54,13 @@ jobs: working-directory: velox steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: path: velox persist-credentials: false - name: Install uv - uses: astral-sh/setup-uv@4959332f0f014c5280e7eac8b70c90cb574c9f9b # v6.6.0 + uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2 - name: Install Dependencies run: | @@ -88,20 +88,20 @@ jobs: if: ${{ github.repository == 'facebookincubator/velox' }} name: Ubuntu GPU debug env: - CUDA_VERSION: '12.8' + CUDA_VERSION: '12.9' defaults: run: shell: bash working-directory: velox steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: path: velox persist-credentials: false - name: Install uv - uses: astral-sh/setup-uv@4959332f0f014c5280e7eac8b70c90cb574c9f9b # v6.6.0 + uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2 - name: Install Dependencies run: | diff --git a/.github/disabled-workflows/build-impact-comment.yml b/.github/disabled-workflows/build-impact-comment.yml new file mode 100644 index 00000000000..a96944ed399 --- /dev/null +++ b/.github/disabled-workflows/build-impact-comment.yml @@ -0,0 +1,97 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Build Impact Comment + +# zizmor:disable:dangerous-triggers -- only reads an artifact and posts a comment, no PR code execution. +on: + workflow_run: + workflows: [Build Impact Analysis] + types: + - completed + +permissions: + pull-requests: write + issues: write + actions: read + +jobs: + post-comment: + if: > + github.event.workflow_run.event == 'pull_request' && + github.event.workflow_run.conclusion == 'success' + runs-on: ubuntu-latest + steps: + - name: Download comment artifact + uses: actions/download-artifact@v4 + with: + name: build-impact-comment + github-token: ${{ github.token }} + run-id: ${{ github.event.workflow_run.id }} + + - name: Get PR number + id: pr + env: + GH_TOKEN: ${{ github.token }} + REPO: ${{ github.repository }} + HEAD_OWNER: ${{ github.event.workflow_run.head_repository.owner.login }} + HEAD_BRANCH: ${{ github.event.workflow_run.head_branch }} + run: | + pr_number=$(gh api \ + "/repos/${REPO}/pulls?head=${HEAD_OWNER}:${HEAD_BRANCH}&state=open" \ + -q '.[0].number // empty') + + if [ -z "$pr_number" ]; then + echo "No open PR found for branch ${HEAD_BRANCH}" + exit 0 + fi + echo "number=$pr_number" >> "$GITHUB_OUTPUT" + + - name: Post or update PR comment + if: steps.pr.outputs.number + uses: actions/github-script@v7 + env: + PR_NUMBER: ${{ steps.pr.outputs.number }} + with: + script: | + const fs = require('fs'); + const comment = fs.readFileSync('comment.md', 'utf8'); + const marker = '## Build Impact Analysis'; + const prNumber = parseInt(process.env.PR_NUMBER, 10); + + const { data: comments } = await github.rest.issues.listComments({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: prNumber, + }); + + const existing = comments.find(c => + c.user.login === 'github-actions[bot]' && c.body.includes(marker) + ); + + if (existing) { + await github.rest.issues.updateComment({ + owner: context.repo.owner, + repo: context.repo.repo, + comment_id: existing.id, + body: comment, + }); + } else { + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: prNumber, + body: comment, + }); + } diff --git a/.github/disabled-workflows/build-impact.yml b/.github/disabled-workflows/build-impact.yml new file mode 100644 index 00000000000..65cb5b08ac6 --- /dev/null +++ b/.github/disabled-workflows/build-impact.yml @@ -0,0 +1,217 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Build Impact Analysis + +on: + push: + branches: + - main + paths: + - velox/** + - CMakeLists.txt + - CMake/** + - .github/workflows/build-impact.yml + - .github/scripts/generate-dependency-graph.py + - .github/scripts/detect-build-impact.py + + pull_request: + paths: + - velox/** + - '!velox/docs/**' + - CMakeLists.txt + - CMake/** + # TODO: Remove after testing comment workflow + - .github/** + +permissions: + contents: read + actions: read + +concurrency: + group: ${{ github.workflow }}-${{ github.repository }}-${{ github.head_ref || github.sha }} + cancel-in-progress: true + +jobs: + # ========================================================================== + # Push to main: generate and upload the dependency graph artifact. + # ========================================================================== + generate-graph: + if: github.event_name == 'push' + runs-on: 8-core-ubuntu + container: ghcr.io/facebookincubator/velox-dev:adapters + env: + VELOX_DEPENDENCY_SOURCE: SYSTEM + steps: + - uses: actions/checkout@v5 + with: + persist-credentials: false + + - uses: ./.github/actions/generate-dependency-graph + + - name: Upload dependency graph + uses: actions/upload-artifact@v4 + with: + name: dependency-graph + path: dependency-graph.json + retention-days: 90 + + # ========================================================================== + # Pull request: check what changed and whether the graph artifact exists. + # ========================================================================== + check-changes: + if: github.event_name == 'pull_request' + runs-on: ubuntu-latest + outputs: + cmake-changed: ${{ steps.check.outputs.cmake_changed }} + graph-available: ${{ steps.artifact.outputs.available }} + graph-run-id: ${{ steps.artifact.outputs.run_id }} + use-slow-path: ${{ steps.decide.outputs.slow }} + steps: + - name: Get changed files + id: check + env: + GH_TOKEN: ${{ github.token }} + run: | + # Use the GitHub API to get the changed files directly. + # This avoids issues with fork PRs where HEAD_SHA may not + # be available in the local clone. + gh api --paginate \ + "/repos/${{ github.repository }}/pulls/${{ github.event.pull_request.number }}/files" \ + -q '.[].filename' > changed_files.txt + cat changed_files.txt + + # Check if any CMake files changed. + if grep -qE '(CMakeLists\.txt|CMake/.+\.cmake)$' changed_files.txt; then + echo "cmake_changed=true" >> "$GITHUB_OUTPUT" + echo "CMake files changed — slow path needed." + else + echo "cmake_changed=false" >> "$GITHUB_OUTPUT" + fi + + - name: Check if graph artifact exists + id: artifact + env: + GH_TOKEN: ${{ github.token }} + run: | + # Find the latest successful generate-graph run on main. + run_id=$(gh api \ + "/repos/${{ github.repository }}/actions/workflows/build-impact.yml/runs?branch=main&status=success&per_page=5" \ + -q '.workflow_runs[0].id // empty') + + if [ -z "$run_id" ]; then + echo "available=false" >> "$GITHUB_OUTPUT" + echo "No graph artifact found — slow path needed." + else + echo "available=true" >> "$GITHUB_OUTPUT" + echo "run_id=$run_id" >> "$GITHUB_OUTPUT" + echo "Graph artifact available from run $run_id." + fi + + - name: Decide fast or slow path + id: decide + env: + CMAKE_CHANGED: ${{ steps.check.outputs.cmake_changed }} + GRAPH_AVAILABLE: ${{ steps.artifact.outputs.available }} + run: | + if [[ "$CMAKE_CHANGED" == "true" ]] || \ + [[ "$GRAPH_AVAILABLE" == "false" ]]; then + echo "slow=true" >> "$GITHUB_OUTPUT" + else + echo "slow=false" >> "$GITHUB_OUTPUT" + fi + + - name: Upload changed files list + uses: actions/upload-artifact@v4 + with: + name: changed-files + path: changed_files.txt + + # ========================================================================== + # Fast path: download pre-computed graph, run impact detection. + # ========================================================================== + fast-detect: + needs: check-changes + if: needs.check-changes.outputs.use-slow-path == 'false' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + with: + persist-credentials: false + + - name: Download changed files + uses: actions/download-artifact@v4 + with: + name: changed-files + + - name: Download dependency graph + uses: actions/download-artifact@v4 + with: + name: dependency-graph + github-token: ${{ github.token }} + repository: ${{ github.repository }} + run-id: ${{ needs.check-changes.outputs.graph-run-id }} + + - name: Run impact detection + run: | + python3 .github/scripts/detect-build-impact.py \ + --graph dependency-graph.json \ + --changed-files changed_files.txt \ + --build-type release \ + --graph-source "Fast path • Graph from main@${{ github.event.pull_request.base.sha }}" \ + --output comment.md + + - name: Upload comment + uses: actions/upload-artifact@v4 + with: + name: build-impact-comment + path: comment.md + + # ========================================================================== + # Slow path: generate graph from PR branch, then run impact detection. + # ========================================================================== + slow-detect: + needs: check-changes + if: needs.check-changes.outputs.use-slow-path == 'true' + runs-on: 8-core-ubuntu + container: ghcr.io/facebookincubator/velox-dev:adapters + env: + VELOX_DEPENDENCY_SOURCE: SYSTEM + steps: + - uses: actions/checkout@v5 + with: + fetch-depth: 0 + persist-credentials: false + + - name: Download changed files + uses: actions/download-artifact@v4 + with: + name: changed-files + + - uses: ./.github/actions/generate-dependency-graph + + - name: Run impact detection + run: | + python3 .github/scripts/detect-build-impact.py \ + --graph dependency-graph.json \ + --changed-files changed_files.txt \ + --build-type release \ + --graph-source "Slow path • Graph generated from PR branch" \ + --output comment.md + + - name: Upload comment + uses: actions/upload-artifact@v4 + with: + name: build-impact-comment + path: comment.md diff --git a/.github/disabled-workflows/build-metrics.yml b/.github/disabled-workflows/build-metrics.yml index 2bf62b4352e..c709b13ae30 100644 --- a/.github/disabled-workflows/build-metrics.yml +++ b/.github/disabled-workflows/build-metrics.yml @@ -36,7 +36,10 @@ permissions: jobs: metrics: name: Linux ${{ matrix.link-type }} - ${{ matrix.type }} with adapters - if: ${{ github.repository == 'facebookincubator/velox' }} + # Disabled: the conbench service at velox-conbench.voltrondata.run is no + # longer available (DNS resolution fails). The scheduled runs have been + # failing since late 2025. Re-enable once a replacement service is set up. + if: false #${{ github.repository == 'facebookincubator/velox' }} runs-on: ${{ matrix.runner }} container: ghcr.io/facebookincubator/velox-dev:adapters strategy: @@ -49,7 +52,7 @@ jobs: run: shell: bash steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: ref: ${{ inputs.ref || github.sha }} persist-credentials: false @@ -101,7 +104,7 @@ jobs: cat $sizes_file echo "::endgroup::" - - uses: actions/upload-artifact@v4 + - uses: actions/upload-artifact@v6 with: path: ${{ env.sizes_file }} name: ${{ matrix.type }}-${{ matrix.link-type }}-sizes @@ -145,7 +148,7 @@ jobs: needs: metrics steps: - name: Checkout - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: fetch-depth: 0 persist-credentials: true @@ -174,7 +177,7 @@ jobs: nix-shell --run "quarto render report.qmd" - name: Upload Report Artifact - uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: report path: scripts/ci/bm-report/report.html diff --git a/.github/disabled-workflows/build_pyvelox.yml b/.github/disabled-workflows/build_pyvelox.yml index 168e017e802..a37ea7b1af3 100644 --- a/.github/disabled-workflows/build_pyvelox.yml +++ b/.github/disabled-workflows/build_pyvelox.yml @@ -50,14 +50,14 @@ jobs: matrix: os: [8-core-ubuntu, macos-13, macos-14] steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: ref: ${{ inputs.ref || github.ref }} fetch-depth: 0 persist-credentials: false - - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: '3.10' @@ -115,7 +115,7 @@ jobs: uv build --sdist --out-dir wheelhouse - name: Build wheels - uses: pypa/cibuildwheel@5f22145df44122af0f5a201f93cf0207171beca7 # v3.0.0 + uses: pypa/cibuildwheel@63fd63b352a9a8bdcc24791c9dbee952ee9a8abc # v3.3.0 env: BUILD_VERSION: ${{ inputs.version }} CIBW_BUILD: cp310-* cp311-* cp312-* cp313-* @@ -140,7 +140,7 @@ jobs: cd wheelhouse rename 's/11_0/10_15/g' *.whl - - uses: actions/upload-artifact@4cec3d8aa04e39d1a68397de0c4cd6fb9dce8ec1 # v4.6.1 + - uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: wheels-${{ matrix.os }} retention-days: 5 @@ -154,7 +154,7 @@ jobs: needs: build_wheels runs-on: ubuntu-22.04 steps: - - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + - uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: pattern: wheels-* merge-multiple: true @@ -162,12 +162,12 @@ jobs: - run: ls wheelhouse - - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: '3.10' - name: Publish a Python distribution to PyPI - uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4 + uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # v1.13.0 with: password: ${{ secrets.PYPI_API_TOKEN }} packages_dir: wheelhouse diff --git a/.github/disabled-workflows/ci-failure-comment.yml b/.github/disabled-workflows/ci-failure-comment.yml new file mode 100644 index 00000000000..3e21114e799 --- /dev/null +++ b/.github/disabled-workflows/ci-failure-comment.yml @@ -0,0 +1,187 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: CI Failure Comment + +# This workflow runs in the context of the base repo (not the fork) and has +# write access to post PR comments. It is triggered when the Linux Build or +# Fuzzer Jobs workflow completes with a failure and uses Claude to analyze +# failure logs, correlate with the PR diff and open issues, and post a rich +# diagnostic comment. +# +# Security: This workflow never checks out or executes PR code. It reads +# failure metadata from artifacts uploaded by our own status jobs, then uses +# Claude with read-only gh CLI access to fetch logs and produce comments. +# +# zizmor:disable:dangerous-triggers -- only reads artifacts and posts comments, no PR code execution. +on: + workflow_dispatch: + inputs: + run_id: + description: Workflow run ID to analyze (from a failed CI run) + required: true + pr_number: + description: PR number to comment on + required: true + workflow_run: + workflows: [Linux Build using GCC, Fuzzer Jobs] + types: + - completed + +permissions: + contents: read + pull-requests: write + issues: write + actions: read + id-token: write + +jobs: + analyze-and-comment: + if: > + github.repository == 'facebookincubator/velox' && + (github.event_name == 'workflow_dispatch' || + (github.event.workflow_run.event == 'pull_request' && + github.event.workflow_run.conclusion == 'failure')) + runs-on: ubuntu-latest + steps: + - name: Get PR number + id: pr + env: + GH_TOKEN: ${{ github.token }} + REPO: ${{ github.repository }} + HEAD_OWNER: ${{ github.event.workflow_run.head_repository.owner.login }} + HEAD_BRANCH: ${{ github.event.workflow_run.head_branch }} + INPUT_PR_NUMBER: ${{ inputs.pr_number }} + run: | + # For workflow_dispatch, use the provided PR number directly. + if [ -n "$INPUT_PR_NUMBER" ]; then + echo "number=$INPUT_PR_NUMBER" >> "$GITHUB_OUTPUT" + exit 0 + fi + + pr_number=$(gh api \ + "/repos/${REPO}/pulls?head=${HEAD_OWNER}:${HEAD_BRANCH}&state=open" \ + -q '.[0].number // empty') + + if [ -z "$pr_number" ]; then + echo "No open PR found for branch ${HEAD_BRANCH}" + exit 0 + fi + echo "number=$pr_number" >> "$GITHUB_OUTPUT" + + - name: Download failure artifacts + if: steps.pr.outputs.number + uses: actions/download-artifact@v4 + with: + github-token: ${{ github.token }} + run-id: ${{ inputs.run_id || github.event.workflow_run.id }} + pattern: ci-failure-* + path: /tmp/ci-failures + merge-multiple: false + + - name: Collect failure metadata + if: steps.pr.outputs.number + id: metadata + env: + GH_TOKEN: ${{ github.token }} + REPO: ${{ github.repository }} + RUN_ID: ${{ inputs.run_id || github.event.workflow_run.id }} + run: | + FAILURES_DIR="/tmp/ci-failures" + if [ -d "$FAILURES_DIR" ] && [ -n "$(ls -A "$FAILURES_DIR" 2>/dev/null)" ]; then + # Collect all failure.json contents into a single JSON array. + METADATA="[" + FIRST=true + for entry in "$FAILURES_DIR"/ci-failure-*/failure.json; do + if [ -f "$entry" ]; then + if [ "$FIRST" = true ]; then + FIRST=false + else + METADATA="$METADATA," + fi + METADATA="$METADATA$(cat "$entry")" + fi + done + METADATA="$METADATA]" + else + # No failure artifacts — build metadata from failed jobs in the run. + # This handles workflows (e.g., Fuzzer Jobs) that don't upload + # ci-failure-* artifacts. + METADATA=$(gh api "repos/${REPO}/actions/runs/${RUN_ID}/jobs" \ + --paginate --jq '[.jobs[] | select(.conclusion == "failure") | {job: .name, type: "unknown"}]') + if [ "$METADATA" = "[]" ] || [ -z "$METADATA" ]; then + echo "No failure artifacts or failed jobs found." + echo "has_failures=false" >> "$GITHUB_OUTPUT" + exit 0 + fi + fi + + echo "has_failures=true" >> "$GITHUB_OUTPUT" + { + echo 'failure_metadata<> "$GITHUB_OUTPUT" + + - name: Checkout for Claude context and prompt + if: steps.metadata.outputs.has_failures == 'true' + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + sparse-checkout: | + CLAUDE.md + .claude/CLAUDE.md + .claude/skills/ci-failure-analysis/SKILL.md + persist-credentials: false + + - name: Prepare analysis prompt + if: steps.metadata.outputs.has_failures == 'true' + id: prompt + env: + PR_NUMBER: ${{ steps.pr.outputs.number }} + REPOSITORY: ${{ github.repository }} + RUN_ID: ${{ inputs.run_id || github.event.workflow_run.id }} + FAILURE_METADATA: ${{ steps.metadata.outputs.failure_metadata }} + run: | + # Read the prompt template and interpolate variables. + PROMPT=$(sed \ + -e "s|{{PR_NUMBER}}|${PR_NUMBER}|g" \ + -e "s|{{REPOSITORY}}|${REPOSITORY}|g" \ + -e "s|{{RUN_ID}}|${RUN_ID}|g" \ + .claude/skills/ci-failure-analysis/SKILL.md) + + # Failure metadata may contain newlines/special chars — use bash + # parameter expansion instead of sed to handle multiline content. + PROMPT="${PROMPT//\{\{FAILURE_METADATA\}\}/${FAILURE_METADATA}}" + + { + echo 'value<> "$GITHUB_OUTPUT" + + - name: Analyze failures with Claude + if: steps.metadata.outputs.has_failures == 'true' + uses: izaitsevfb/claude-code-action@ececd56fb999d06b4dd2477437bc408938295d76 # forked-pr-fix + with: + anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} + github_token: ${{ github.token }} + claude_args: --model claude-opus-4-6 --allowedTools Bash Read Grep Glob + prompt: ${{ steps.prompt.outputs.value }} + # Most Velox PRs are exported from Phabricator and pushed to GitHub + # by the meta-codesync bot, so the upstream workflow_run's `actor` + # is a Bot. Without this allowlist, claude-code-action refuses to + # run for those PRs and the failure analysis is silently dropped. + allowed_bots: meta-codesync + env: + GH_TOKEN: ${{ github.token }} diff --git a/.github/disabled-workflows/claude-review.yml b/.github/disabled-workflows/claude-review.yml new file mode 100644 index 00000000000..b34f637369f --- /dev/null +++ b/.github/disabled-workflows/claude-review.yml @@ -0,0 +1,643 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Claude Assistant for PRs +# +# This workflow enables Claude-powered interactions on pull requests from forks. +# Since secrets are not available to workflows triggered by fork PRs (for security), +# this uses a comment-triggered approach where a maintainer must explicitly request +# assistance by commenting on the PR. +# +# Supported Commands: +# - /claude-review [additional context] : Perform a thorough code review of the PR. +# Optionally provide additional context or instructions after the command. +# Example: /claude-review Please focus on memory safety and thread safety. +# - /claude-query : Ask Claude a question about the PR or codebase +# +# Security Model: +# - Only authorized users can trigger Claude (verified before any code access) +# - Code is fetched as text only (never executed) +# - Claude uses read-only tools (no Bash, no file writes) +# - All operations run in the context of the base repository + +name: Claude Assistant + +on: + issue_comment: + types: [created] + + # Manual trigger for testing - can be removed after validation + workflow_dispatch: + inputs: + pr_number: + description: PR number to review + required: true + type: number + dry_run: + description: Dry run (skip posting comment) + required: false + type: boolean + default: true + model: + description: Claude model to use + required: false + type: choice + options: + - claude-opus-4-6 + - claude-opus-4-1-20250805 + - claude-sonnet-4-20250514 + - claude-4-0-sonnet-20250805 + default: claude-opus-4-6 + +# Restrict default permissions +permissions: + contents: read + +jobs: + claude-assistant: + name: Claude Assistant + runs-on: ubuntu-latest + + # Job-level permissions - only what's needed for this job + permissions: + contents: read + pull-requests: write + issues: write + + # Run if: + # A) workflow_dispatch: Manual trigger for testing + # B) issue_comment: PR comment with /claude-review or /claude-query from authorized user + if: >- + github.event_name == 'workflow_dispatch' || + ( + github.event.issue.pull_request && + (contains(github.event.comment.body, '/claude-review') || contains(github.event.comment.body, '/claude-query')) && + contains(fromJSON('["kgpai", "mbasmanova", "pedroerp", "yuhta", "kagamiori", "bikramSingh91", "kevinwilfong", "xiaoxmeng", "kKPulla", "juwentus1234", "penescu", "srsuryadev", "jainxrohit"]'), github.event.comment.user.login) + ) + + steps: + # Step 1: Acknowledge the request (skip for workflow_dispatch) + - name: Add reaction to comment + if: github.event_name == 'issue_comment' + uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + with: + script: | + await github.rest.reactions.createForIssueComment({ + owner: context.repo.owner, + repo: context.repo.repo, + comment_id: context.payload.comment.id, + content: 'eyes' + }); + + # Step 2: Detect command type and extract query + - name: Detect command type + id: command + if: github.event_name == 'issue_comment' + uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + with: + script: | + const body = context.payload.comment.body; + + if (body.includes('/claude-review')) { + core.setOutput('type', 'review'); + // Extract optional additional context after /claude-review + const reviewMatch = body.match(/\/claude-review\s+([\s\S]*)/); + const additionalContext = reviewMatch ? reviewMatch[1].trim() : ''; + core.setOutput('query', additionalContext); + } else if (body.includes('/claude-query')) { + // Extract everything after /claude-query + const match = body.match(/\/claude-query\s+([\s\S]*)/); + const query = match ? match[1].trim() : ''; + if (!query) { + core.setFailed('No query provided after /claude-query. Usage: /claude-query '); + return; + } + core.setOutput('type', 'query'); + core.setOutput('query', query); + } else { + core.setFailed('Unknown command'); + } + + # Step 3: Get PR information (set command type for workflow_dispatch) + - name: Set command type for workflow_dispatch + id: command_dispatch + if: github.event_name == 'workflow_dispatch' + run: | + echo "type=review" >> "$GITHUB_OUTPUT" + echo "query=" >> "$GITHUB_OUTPUT" + + # Step 4: Get PR information + - name: Get PR details + id: pr_info + uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + env: + # Pass input via env to avoid template injection + INPUT_PR_NUMBER: ${{ inputs.pr_number }} + with: + script: | + // Get PR number from either issue_comment or workflow_dispatch input + const prNumber = context.eventName === 'workflow_dispatch' + ? parseInt(process.env.INPUT_PR_NUMBER, 10) + : context.issue.number; + + const pr = await github.rest.pulls.get({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: prNumber + }); + + core.setOutput('head_repo', pr.data.head.repo.clone_url); + core.setOutput('head_ref', pr.data.head.ref); + core.setOutput('head_sha', pr.data.head.sha); + core.setOutput('base_sha', pr.data.base.sha); + core.setOutput('base_ref', pr.data.base.ref); + core.setOutput('pr_title', pr.data.title); + core.setOutput('pr_body', pr.data.body || ''); + core.setOutput('pr_author', pr.data.user.login); + + # Step 3: Checkout base repository (safe - this is our own code) + - name: Checkout base repository + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + ref: ${{ steps.pr_info.outputs.base_ref }} + fetch-depth: 0 + persist-credentials: false + + # Step 4: Fetch PR branch and generate diff (read-only, no execution) + - name: Generate PR diff + id: generate_diff + env: + HEAD_REPO: ${{ steps.pr_info.outputs.head_repo }} + HEAD_REF: ${{ steps.pr_info.outputs.head_ref }} + HEAD_SHA: ${{ steps.pr_info.outputs.head_sha }} + BASE_SHA: ${{ steps.pr_info.outputs.base_sha }} + run: | + # Fetch the PR branch from the fork (read-only) + git remote add fork "${HEAD_REPO}" || true + git fetch fork "${HEAD_REF}" --depth=100 + + # Generate the diff as text + git diff "${BASE_SHA}...${HEAD_SHA}" > /tmp/pr.diff + + # Get diff stats for context + DIFF_STATS=$(git diff --stat "${BASE_SHA}...${HEAD_SHA}" | tail -1) + echo "diff_stats=${DIFF_STATS}" >> "$GITHUB_OUTPUT" + + # Check diff size (limit to ~100KB to avoid token limits) + DIFF_SIZE=$(wc -c < /tmp/pr.diff) + echo "diff_size=${DIFF_SIZE}" >> "$GITHUB_OUTPUT" + + if [ "$DIFF_SIZE" -gt 100000 ]; then + echo "⚠️ Diff is large (${DIFF_SIZE} bytes). Review may be truncated." + # Truncate to first 100KB + head -c 100000 /tmp/pr.diff > /tmp/pr_truncated.diff + mv /tmp/pr_truncated.diff /tmp/pr.diff + echo "is_truncated=true" >> "$GITHUB_OUTPUT" + else + echo "is_truncated=false" >> "$GITHUB_OUTPUT" + fi + + # Step 7: Create the prompt (review or query) + - name: Create review prompt + if: steps.command.outputs.type == 'review' || github.event_name == 'workflow_dispatch' + env: + PR_TITLE: ${{ steps.pr_info.outputs.pr_title }} + PR_BODY: ${{ steps.pr_info.outputs.pr_body }} + PR_AUTHOR: ${{ steps.pr_info.outputs.pr_author }} + DIFF_STATS: ${{ steps.generate_diff.outputs.diff_stats }} + IS_TRUNCATED: ${{ steps.generate_diff.outputs.is_truncated }} + ADDITIONAL_CONTEXT: ${{ steps.command.outputs.query }} + run: | + # Create system context first (no variable expansion needed) + cat > /tmp/prompt.txt << 'PROMPT_EOF' + You are an expert C++ code reviewer for the Velox project. Your role is to perform + a thorough, rigorous code review that catches issues before they reach production. + + **IMPORTANT**: First, read the CLAUDE.md file in the repository root. It contains: + - Project overview and architecture + - Build commands and environment setup + - Code style and naming conventions + - PR title format requirements + - Testing guidelines + - Function addition requirements + + Use CLAUDE.md as your primary reference for project-specific standards. + + Key things to know about Velox: + - Uses C++20 standard + - Heavy use of templates and SFINAE + - Custom memory management with MemoryPool + - Vectorized execution with custom Vector types + - Follows Google C++ style with some modifications + + ## Review Approach + + **Think deeply and carefully about this code.** Take your time to: + - Trace through the logic step by step + - Consider what happens at boundary conditions (empty inputs, null values, max sizes) + - Think about concurrency issues if multiple threads could access this code + - Consider memory safety: ownership, lifetimes, dangling references + - Look for off-by-one errors, integer overflow, and other subtle bugs + - Examine error handling paths - what happens when things fail? + - Consider how this code interacts with existing code in the codebase + + **Be thorough and strict.** This is a high-performance database engine where bugs can + cause data corruption, crashes, or security vulnerabilities. It's better to flag a + potential issue that turns out to be fine than to miss a real bug. + + **Explore edge cases exhaustively:** + - What if the input is empty? Null? Maximum size? + - What if allocation fails? What if an exception is thrown? + - What happens on the first iteration? The last iteration? + - Are there race conditions if called concurrently? + - What assumptions does this code make? Are they documented and validated? + + Provide actionable, specific feedback. Reference exact file paths and line numbers. + Be constructive and educational in your feedback. + If the diff looks good after thorough analysis, say so - but only after genuinely + examining edge cases and potential issues. + + --- + + PROMPT_EOF + + # Note: Using unquoted PROMPT_EOF to enable variable expansion + cat >> /tmp/prompt.txt << PROMPT_EOF + Please review the following pull request for the Velox project. + + ## Pull Request Information + - **Title:** ${PR_TITLE} + - **Author:** ${PR_AUTHOR} + - **Changes:** ${DIFF_STATS} + PROMPT_EOF + + if [ "${IS_TRUNCATED}" = "true" ]; then + echo "- **Note:** This diff was truncated due to size. Focus on the visible changes." >> /tmp/prompt.txt + fi + + # Using quoted 'PROMPT_EOF' here since this section has no variables to expand + cat >> /tmp/prompt.txt << 'PROMPT_EOF' + + ## PR Description + PROMPT_EOF + + # PR body may contain special characters, write it safely + printf '%s\n' "${PR_BODY}" >> /tmp/prompt.txt + + # Add additional context/instructions if provided via /claude-review + if [ -n "${ADDITIONAL_CONTEXT}" ]; then + cat >> /tmp/prompt.txt << 'PROMPT_EOF' + + ## Additional Instructions from Reviewer + + PROMPT_EOF + printf '%s\n' "${ADDITIONAL_CONTEXT}" >> /tmp/prompt.txt + fi + + cat >> /tmp/prompt.txt << 'PROMPT_EOF' + + ## Review Guidelines + + Velox is a C++ execution engine library for analytical data processing. This is + performance-critical code that must be correct, efficient, and robust. + + **Analyze each of these areas thoroughly:** + + 1. **Correctness & Edge Cases** + - Logic errors, off-by-one bugs, incorrect conditions + - Null/empty input handling + - Boundary conditions (first element, last element, single element, max size) + - Integer overflow/underflow + - Floating point edge cases (NaN, Inf, negative zero) + + 2. **Memory Safety** + - Use-after-free, double-free, memory leaks + - Dangling pointers/references + - Buffer overflows/underflows + - Ownership and lifetime issues + - Exception safety (what happens if an exception is thrown mid-operation?) + + 3. **Concurrency** + - Race conditions, data races + - Deadlocks, lock ordering + - Thread-safety of shared state + + 4. **Performance** + - Unnecessary copies (should use move semantics?) + - Inefficient algorithms (O(n²) when O(n) is possible?) + - Cache-unfriendly access patterns + - Excessive allocations in hot paths + + 5. **Error Handling** + - Are all error paths handled? + - Are exceptions caught appropriately? + - Are error messages informative? + + 6. **Code Quality** + - RAII, const-correctness, proper use of smart pointers + - Following Velox naming conventions (PascalCase for types, camelCase for functions) + - Clear, maintainable code structure + + 7. **Testing** + - Are there sufficient tests for the new/changed code? + - Are edge cases covered in tests? + - Are error paths tested? + + ## Format Your Review As + + ### Summary + Brief overall assessment (1-2 sentences) + + ### Issues Found + List any problems, categorized by severity: + - 🔴 **Critical**: Must fix before merge + - 🟡 **Suggestion**: Should consider + - 🟢 **Nitpick**: Minor style issues + + For each issue, include: + - File and line reference + - Description of the issue + - Suggested fix if applicable + + ### Positive Observations + Note any particularly good patterns or improvements. + + --- + + ## Diff to Review + + PROMPT_EOF + + cat /tmp/pr.diff >> /tmp/prompt.txt + + # Step 8: Create query prompt (for /claude-query) + - name: Create query prompt + if: steps.command.outputs.type == 'query' + env: + PR_TITLE: ${{ steps.pr_info.outputs.pr_title }} + PR_BODY: ${{ steps.pr_info.outputs.pr_body }} + PR_AUTHOR: ${{ steps.pr_info.outputs.pr_author }} + DIFF_STATS: ${{ steps.generate_diff.outputs.diff_stats }} + USER_QUERY: ${{ steps.command.outputs.query }} + run: | + cat > /tmp/prompt.txt << 'PROMPT_EOF' + You are an expert C++ engineer helping with questions about the Velox project. + + **IMPORTANT**: First, read the CLAUDE.md file in the repository root for project context. + + Key things to know about Velox: + - Velox is a C++ execution engine library for analytical data processing + - Uses C++20 standard with heavy use of templates and SFINAE + - Custom memory management with MemoryPool + - Vectorized execution with custom Vector types + - Follows Google C++ style with some modifications + + You have access to: + - The PR diff (changes being proposed) + - The full codebase via View, GlobTool, and GrepTool + - The CLAUDE.md file with project guidelines + + Answer the user's question thoroughly and accurately. If the question is about + the PR changes, analyze them carefully. If it's about the codebase, explore + relevant files to provide a complete answer. + + Be specific and reference exact file paths and line numbers when relevant. + + --- + + PROMPT_EOF + + cat >> /tmp/prompt.txt << PROMPT_EOF + ## Pull Request Context + - **Title:** ${PR_TITLE} + - **Author:** ${PR_AUTHOR} + - **Changes:** ${DIFF_STATS} + + ## PR Description + PROMPT_EOF + + printf '%s\n' "${PR_BODY}" >> /tmp/prompt.txt + + cat >> /tmp/prompt.txt << 'PROMPT_EOF' + + ## User Question + + PROMPT_EOF + + printf '%s\n' "${USER_QUERY}" >> /tmp/prompt.txt + + cat >> /tmp/prompt.txt << 'PROMPT_EOF' + + --- + + ## PR Diff (for reference) + + PROMPT_EOF + + cat /tmp/pr.diff >> /tmp/prompt.txt + + # Step 9: Run Claude + - name: Run Claude + id: claude_review + uses: anthropics/claude-code-base-action@e8132bc5e637a42c27763fc757faa37e1ee43b34 # beta + env: + CLAUDE_MODEL: ${{ github.event_name == 'workflow_dispatch' && inputs.model || 'claude-opus-4-6' }} + with: + prompt_file: /tmp/prompt.txt + # Use configurable model for thorough analysis, with read-only tools + claude_args: >- + --model ${{ github.event_name == 'workflow_dispatch' && inputs.model || 'claude-opus-4-6' }} + --max-turns 25 + --allowedTools View,GlobTool,GrepTool + anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} + + # Step 7: Save execution log as artifact for debugging + - name: Upload execution log + if: always() + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + with: + name: claude-execution-log-${{ github.event_name == 'workflow_dispatch' && inputs.pr_number || github.event.issue.number }} + path: ${{ steps.claude_review.outputs.execution_file }} + retention-days: 7 + if-no-files-found: warn + + # Step 11: Post response as PR comment (skip if dry_run) + - name: Post Claude response + if: ${{ !(github.event_name == 'workflow_dispatch' && inputs.dry_run) }} + uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + env: + EXECUTION_FILE: ${{ steps.claude_review.outputs.execution_file }} + CONCLUSION: ${{ steps.claude_review.outputs.conclusion }} + REVIEWER: ${{ github.event_name == 'workflow_dispatch' && github.actor || github.event.comment.user.login }} + PR_NUMBER: ${{ github.event_name == 'workflow_dispatch' && inputs.pr_number || github.event.issue.number }} + COMMAND_TYPE: ${{ steps.command.outputs.type || 'review' }} + with: + script: | + const fs = require('fs'); + + let responseBody = ''; + const commandType = process.env.COMMAND_TYPE; + const isReview = commandType === 'review'; + + try { + // Read the execution log (array of messages from claude-code-base-action) + const executionLog = JSON.parse(fs.readFileSync(process.env.EXECUTION_FILE, 'utf8')); + + if (!Array.isArray(executionLog)) { + throw new Error('Expected array format from claude-code-base-action'); + } + + // Find the "result" message which contains the final response + const resultMessage = executionLog.find(m => m.type === 'result'); + + if (!resultMessage) { + responseBody = '⚠️ No result message found in execution log.'; + } else if (resultMessage.subtype === 'success' && resultMessage.result) { + responseBody = resultMessage.result; + } else if (resultMessage.is_error || resultMessage.subtype === 'error') { + const errorInfo = resultMessage.result || resultMessage.error || 'Unknown error'; + responseBody = `❌ **Claude encountered an error:**\n\n${errorInfo}`; + } else if (resultMessage.result) { + responseBody = `⚠️ **Completed with status: ${resultMessage.subtype}**\n\n${resultMessage.result}`; + } else { + responseBody = '⚠️ Claude completed but produced no output.'; + } + } catch (error) { + console.error('Error parsing execution log:', error); + responseBody = '❌ Error parsing Claude response. Please check the workflow logs.'; + } + + // Add header and footer based on command type + const conclusion = process.env.CONCLUSION === 'success' ? '✅' : '⚠️'; + const headerTitle = isReview ? 'Claude Code Review' : 'Claude Response'; + const aboutText = isReview + ? 'This review was generated by [Claude Code](https://github.com/anthropics/claude-code-action). It analyzed the PR diff and codebase to provide feedback.' + : 'This response was generated by [Claude Code](https://github.com/anthropics/claude-code-action). It analyzed the PR and codebase to answer your question.'; + + const fullComment = `## ${conclusion} ${headerTitle} + + *Requested by @${process.env.REVIEWER}* + + --- + + ${responseBody} + + --- + +
+ ℹ️ About this response + + ${aboutText} + + **Limitations:** + - Claude may miss context from files not in the diff + - Large PRs may be truncated + - Always apply human judgment to AI suggestions + + **Available commands:** + - \`/claude-review [additional context]\` - Request a code review. Optionally provide additional instructions (e.g., \`/claude-review Please focus on memory safety\`) + - \`/claude-query \` - Ask a question about the PR or codebase +
`; + + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: parseInt(process.env.PR_NUMBER), + body: fullComment + }); + + # Step 8b: Print review to logs if dry_run + - name: Print review to logs (dry run) + if: ${{ github.event_name == 'workflow_dispatch' && inputs.dry_run }} + env: + EXECUTION_FILE: ${{ steps.claude_review.outputs.execution_file }} + run: | + echo "=== DRY RUN - Review would be posted as comment ===" + echo "" + echo "=== Extracting review from result message ===" + # Find the "result" message and extract based on subtype + RESULT_MSG=$(jq '[.[] | select(.type == "result")] | .[0]' "$EXECUTION_FILE" 2>/dev/null) + if [ "$RESULT_MSG" = "null" ] || [ -z "$RESULT_MSG" ]; then + echo "⚠️ No result message found" + else + SUBTYPE=$(echo "$RESULT_MSG" | jq -r '.subtype // "unknown"') + echo "Status: $SUBTYPE" + echo "" + if [ "$SUBTYPE" = "success" ]; then + echo "$RESULT_MSG" | jq -r '.result // "No result field"' + elif [ "$SUBTYPE" = "error" ]; then + echo "❌ Error:" + echo "$RESULT_MSG" | jq -r '.result // .error // "Unknown error"' + else + echo "⚠️ Status: $SUBTYPE" + echo "$RESULT_MSG" | jq -r '.result // "No result"' + fi + fi + echo "" + echo "=== End of review ===" + + # Step 9: Update reaction on completion (only for issue_comment trigger) + - name: Update reaction on success + if: success() && github.event_name == 'issue_comment' + uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + with: + script: | + await github.rest.reactions.createForIssueComment({ + owner: context.repo.owner, + repo: context.repo.repo, + comment_id: context.payload.comment.id, + content: 'rocket' + }); + + - name: Update reaction on failure + if: failure() && github.event_name == 'issue_comment' + uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + with: + script: | + await github.rest.reactions.createForIssueComment({ + owner: context.repo.owner, + repo: context.repo.repo, + comment_id: context.payload.comment.id, + content: 'confused' + }); + + # Job to handle unauthorized users attempting to trigger Claude (issue_comment only) + unauthorized-notice: + name: Unauthorized Notice + runs-on: ubuntu-latest + + # Job-level permissions + permissions: + issues: write + pull-requests: write + + if: >- + github.event_name == 'issue_comment' && + github.event.issue.pull_request && + (contains(github.event.comment.body, '/claude-review') || contains(github.event.comment.body, '/claude-query')) && + !contains(fromJSON('["kgpai", "mbasmanova", "pedroerp", "yuhta", "kagamiori", "bikramSingh91", "kevinwilfong", "xiaoxmeng", "kKPulla", "juwentus1234", "penescu", "srsuryadev", "jainxrohit"]'), github.event.comment.user.login) + + steps: + - name: Post unauthorized notice + uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + with: + script: | + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body: `👋 @${context.payload.comment.user.login}, the \`/claude-review\` and \`/claude-query\` commands are currently restricted to a small group of users during this initial rollout. + + If you'd like a Claude review, please ask a maintainer to run this command.` + }); diff --git a/.github/disabled-workflows/claude.yml b/.github/disabled-workflows/claude.yml new file mode 100644 index 00000000000..2b7b2cb31a5 --- /dev/null +++ b/.github/disabled-workflows/claude.yml @@ -0,0 +1,198 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Claude — skill-based PR assistant (replaces claude-review.yml) +# +# This workflow uses izaitsevfb/claude-code-action (a fork that fixes +# forked-PR handling) with Claude Skills instead of inline prompts. +# Review logic lives in .claude/skills/pr-review/SKILL.md. +# +# Supported commands (comment on a PR): +# @claude /pr-review [context] — thorough code review +# @claude /query — ask about the PR or codebase +# +# Differences from claude-review.yml (the legacy workflow): +# - Uses claude-code-action (higher-level) instead of claude-code-base-action +# - Review/query instructions are in Skills, not inline heredocs +# - Trigger phrase is @claude (not /claude-review or /claude-query) +# - Tag mode handles PR context, diff, and comment posting automatically +# +# Security model: +# - Hardcoded user allowlist gates workflow execution +# - Read-only tools enforced via --allowedTools +# - Fork PRs handled securely by the action fork + +name: Claude + +on: + # Tag mode: action detects @claude in comments and responds inline + issue_comment: + types: [created] + + # Agent mode: manual trigger for testing (log-only when dry_run is true) + workflow_dispatch: + inputs: + pr_number: + description: PR number to review + required: true + type: number + model: + description: Claude model to use + required: false + type: choice + options: + - claude-opus-4-6 + - claude-opus-4-1-20250805 + - claude-sonnet-4-20250514 + - claude-4-0-sonnet-20250805 + default: claude-opus-4-6 + additional_context: + description: Additional instructions for the review (e.g. focus on memory safety) + required: false + type: string + default: '' + dry_run: + description: Log review output only — do not post a comment on the PR + required: false + type: boolean + default: false + +# Restrict default permissions — jobs declare only what they need +permissions: + contents: read + +jobs: + claude-assistant: + name: Claude Assistant + runs-on: ubuntu-latest + permissions: + contents: read + pull-requests: write + issues: write + actions: read + id-token: write + # Gate: workflow_dispatch always runs; issue_comment requires PR context, + # @claude mention, and an authorized user + if: >- + github.event_name == 'workflow_dispatch' || + ( + github.event_name == 'issue_comment' && + github.event.issue.pull_request && + contains(github.event.comment.body, '@claude') && + contains(fromJSON('["kgpai", "mbasmanova", "pedroerp", "yuhta", "kagamiori", "bikramSingh91", "kevinwilfong", "xiaoxmeng", "kKPulla", "juwentus1234", "penescu", "srsuryadev", "jainxrohit", "pratikpugalia"]'), github.event.comment.user.login) + ) + + steps: + # For workflow_dispatch, checkout the PR merge ref so Claude has the + # PR code. For issue_comment, default ref is fine — the action handles + # diff fetching internally. + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + fetch-depth: 1 + ref: ${{ github.event_name == 'workflow_dispatch' && format('refs/pull/{0}/merge', inputs.pr_number) || '' }} + persist-credentials: false + + # Build prompt and claude_args in bash to avoid format() curly-brace + # issues with user-supplied additional_context, and to keep the logic + # readable. Outputs are only set for workflow_dispatch; for + # issue_comment (tag mode), prompt stays empty so the action uses + # tag mode and CLAUDE.md provides skill routing. + - name: Build prompt and args + id: config + if: github.event_name == 'workflow_dispatch' + env: + INPUT_PR_NUMBER: ${{ inputs.pr_number }} + INPUT_MODEL: ${{ inputs.model }} + INPUT_DRY_RUN: ${{ inputs.dry_run }} + INPUT_ADDITIONAL_CONTEXT: ${{ inputs.additional_context }} + run: | + # --- Prompt --- + PROMPT="Review the PR currently checked out in the Velox project current workspace. Use the /pr-review skill." + + if [ "$INPUT_DRY_RUN" = "true" ]; then + PROMPT="$PROMPT Output the review to the workflow log only. Do NOT post any comments on the PR." + else + PROMPT="$PROMPT Post the final output as a comment on the original PR ${INPUT_PR_NUMBER}." + fi + + if [ -n "$INPUT_ADDITIONAL_CONTEXT" ]; then + PROMPT="$PROMPT Additional reviewer instructions: ${INPUT_ADDITIONAL_CONTEXT}" + fi + + # Write multi-line safe output + { + echo "prompt<> "$GITHUB_OUTPUT" + + # --- Allowed tools --- + TOOLS="Skill" + + # Comment-posting tools only when NOT dry_run + if [ "$INPUT_DRY_RUN" != "true" ]; then + TOOLS="$TOOLS,Bash(gh pr comment:*)" + TOOLS="$TOOLS,mcp__github_inline_comment__create_inline_comment" + fi + + # Read-only PR tools (always available for workflow_dispatch) + TOOLS="$TOOLS,Bash(gh pr diff:*),Bash(gh pr view:*)" + + # Codebase exploration tools + TOOLS="$TOOLS,Read,Glob,Grep" + + ARGS="--model ${INPUT_MODEL:-claude-opus-4-6} --allowedTools ${TOOLS}" + echo "claude_args=$ARGS" >> "$GITHUB_OUTPUT" + + - name: Run Claude Code + uses: izaitsevfb/claude-code-action@ececd56fb999d06b4dd2477437bc408938295d76 # forked-pr-fix + with: + trigger_phrase: "@claude" + # prompt is only set for workflow_dispatch (agent mode). + # For issue_comment (tag mode), prompt must be empty so the action + # uses tag mode; CLAUDE.md provides skill routing instead. + prompt: ${{ steps.config.outputs.prompt || '' }} + anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} + # Tool access (workflow_dispatch uses computed args from setup step): + # - Skill: invoke /pr-review and /query skills + # - Read,Glob,Grep: codebase exploration (used by skills) + # - mcp__github_inline_comment: post inline comments (excluded in dry_run) + # - Bash(gh pr comment/diff/view): gh CLI for PR data (comment excluded in dry_run) + claude_args: ${{ steps.config.outputs.claude_args || '--model claude-opus-4-6 --allowedTools Skill,mcp__github_inline_comment__create_inline_comment,Read,Glob,Grep' }} + settings: '{"alwaysThinkingEnabled": true}' + + # Notify unauthorized users who try to invoke @claude + unauthorized-notice: + name: Unauthorized Notice + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: write + if: >- + github.event_name == 'issue_comment' && + github.event.issue.pull_request && + contains(github.event.comment.body, '@claude') && + !contains(fromJSON('["kgpai", "mbasmanova", "pedroerp", "yuhta", "kagamiori", "bikramSingh91", "kevinwilfong", "xiaoxmeng", "kKPulla", "juwentus1234", "penescu", "srsuryadev", "jainxrohit", "pratikpugalia"]'), github.event.comment.user.login) + + steps: + - name: Post unauthorized notice + uses: actions/github-script@v7 + with: + script: | + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body: `👋 @${context.payload.comment.user.login}, \`@claude\` commands are currently restricted to a small group of users during this initial rollout.\n\nIf you'd like a Claude review, please ask a maintainer to run this command.` + }); diff --git a/.github/disabled-workflows/docker.yml b/.github/disabled-workflows/docker.yml index b01751579c0..00a3da83f1f 100644 --- a/.github/disabled-workflows/docker.yml +++ b/.github/disabled-workflows/docker.yml @@ -37,6 +37,7 @@ permissions: env: BASE_NAME: ghcr.io/facebookincubator + DOCKER_UPLOAD_CACHE: ${{ github.repository == 'facebookincubator/velox' && github.event_name != 'pull_request'}} jobs: multi-arch-base: @@ -55,16 +56,56 @@ jobs: steps: - name: Free Disk Space run: | - # 15G - sudo rm -rf /usr/local/lib/android || : - # 5.3GB - sudo rm -rf /opt/hostedtoolcache/CodeQL || : + # Re-used from free-disk-space github action. + getAvailableSpace() { echo $(df -a $1 | awk 'NR > 1 {avail+=$4} END {print avail}'); } + # Show before + echo "Original available disk space: " $(getAvailableSpace) + # Remove DotNet. + sudo rm -rf /usr/share/dotnet || true + # Remove android + sudo rm -rf /usr/local/lib/android || true + # Remove CodeQL + sudo rm -rf /opt/hostedtoolcache/CodeQL || true + # Remove Haskell + sudo rm -rf /opt/ghc || true + sudo rm -rf /usr/local/.ghcup || true + # Remove Powershell + sudo rm -rf /usr/local/share/powershell || true + # Remove Node + sudo rm -rf /usr/local/lib/node_modules || true + # Remove Go + sudo rm -rf /opt/hostedtoolcache/go || true + # Remove PyPy + sudo rm -rf /opt/hostedtoolcache/PyPy || true + # Remove chromium + sudo rm -rf /usr/local/share/chromium || true + # Remove azure + sudo rm -rf /opt/az || true + # Remove miscellaneous + sudo rm -rf \ + /usr/local/bin/aliyun \ + /usr/local/bin/azcopy \ + /usr/local/bin/bicep \ + /usr/local/bin/cmake-gui \ + /usr/local/bin/cpack \ + /usr/local/bin/helm \ + /usr/local/bin/hub \ + /usr/local/bin/kubectl \ + /usr/local/bin/minikube \ + /usr/local/bin/node \ + /usr/local/bin/packer \ + /usr/local/bin/pulumi* \ + /usr/local/bin/sam \ + /usr/local/bin/stack \ + /usr/local/bin/terraform || true + # Show after + echo "New available disk space: " $(getAvailableSpace) - name: Set up Docker Buildx uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1 - name: Login to GitHub Container Registry - uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0 + uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0 with: registry: ghcr.io username: ${{ github.actor }} @@ -73,8 +114,6 @@ jobs: - name: Bake Images id: bake uses: docker/bake-action@3acf805d94d93a86cce4ca44798a76464a75b88c # v6.9.0 - env: - DOCKER_UPLOAD_CACHE: ${{ github.repository == 'facebookincubator/velox' && github.event_name != 'pull_request'}} with: targets: ${{ matrix.target }}-${{ matrix.os.platform }} push: ${{ github.repository == 'facebookincubator/velox' && github.event_name != 'pull_request'}} @@ -92,7 +131,7 @@ jobs: - name: Upload Base Image if: ${{ github.event_name == 'pull_request' && matrix.target == 'ci' && matrix.os.platform == 'amd64' }} - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: path: centos9.tar name: centos9-amd64 @@ -108,7 +147,8 @@ jobs: # Add the target for any images that shouldn't be build as multi-platform # into the skip list echo "$METADATA" | jq -r 'def skip: ["ubuntu", "fedora"]; - . | to_entries[] | .key | split("-")[0] | + . | to_entries[] | select(.value | has("image.name")) | + .key | split("-")[0] | if . as $name | skip | index($name) != null then empty else . end' | \ while read -r image_name; do touch "$image_name" @@ -116,7 +156,7 @@ jobs: ls -la - name: Upload digest - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: digests-${{ matrix.target }}-${{ matrix.os.platform }} path: ${{ runner.temp }}/digests/* @@ -130,14 +170,14 @@ jobs: packages: write steps: - name: Login to GitHub Container Registry - uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0 + uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0 with: registry: ghcr.io username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} - name: Download digests - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: path: ${{ runner.temp }}/digests pattern: digests-* @@ -168,7 +208,7 @@ jobs: target: [java] steps: - name: Login to GitHub Container Registry - uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0 + uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0 with: registry: ghcr.io username: ${{ github.actor }} @@ -182,7 +222,7 @@ jobs: - name: Download Base Image if: ${{ github.event_name == 'pull_request' }} - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: pattern: centos9-amd64 path: /tmp diff --git a/.github/disabled-workflows/docs.yml b/.github/disabled-workflows/docs.yml index 523fda54dcf..dd4156c1d9a 100644 --- a/.github/disabled-workflows/docs.yml +++ b/.github/disabled-workflows/docs.yml @@ -35,7 +35,7 @@ concurrency: jobs: build_docs: name: Build and Push - runs-on: 8-core-ubuntu + runs-on: 16-core-ubuntu container: ghcr.io/facebookincubator/velox-dev:centos9 env: CCACHE_DIR: /tmp/ccache @@ -48,14 +48,17 @@ jobs: id: restore-cache with: path: ${{ env.CCACHE_DIR }} - key: ccache-docs-8-core-ubuntu + key: ccache-docs-16-core-ubuntu - name: Checkout - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: fetch-depth: 0 persist-credentials: true + - name: Configure git safe directory + run: git config --global --add safe.directory ${GITHUB_WORKSPACE} + - name: Install System Dependencies run: | dnf install -y --setopt=install_weak_deps=False pandoc @@ -63,10 +66,53 @@ jobs: - name: CCache Stats Before run: ccache -sz + - name: Install uv + run: | + pip install --upgrade uv + uv --version + + - name: Check for pyvelox changes + id: check-pyvelox + run: | + if [ "${{ github.event_name }}" = "pull_request" ]; then + CHANGED=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..HEAD) + else + CHANGED=$(git diff --name-only HEAD~1..HEAD) + fi + if echo "$CHANGED" | grep -qE '^(python/|velox/python/)'; then + echo "Building pyvelox: python/ files changed." + echo "changed=true" >> $GITHUB_OUTPUT + else + echo "Skipping pyvelox build: no python/ files changed." + echo "changed=false" >> $GITHUB_OUTPUT + fi + - name: Install Python Dependencies + timeout-minutes: 30 + env: + PYVELOX_CHANGED: ${{ steps.check-pyvelox.outputs.changed }} run: | - # Install pyvelox to generate it's docs - uv sync --extra docs + # When python/ files haven't changed, skip building pyvelox + # from C++ source (slow) and only install doc dependencies. + EXTRA_FLAGS="" + if [ "$PYVELOX_CHANGED" != "true" ]; then + echo "Skipping pyvelox C++ build (no changes in python/)." + EXTRA_FLAGS="--no-install-project" + fi + # Retry up to 3 times to handle transient network issues + # (pip/uv hangs, registry timeouts). + for attempt in 1 2 3; do + echo "Attempt $attempt of 3..." + if uv sync --extra docs $EXTRA_FLAGS --verbose; then + echo "Python dependencies installed successfully." + exit 0 + fi + echo "Attempt $attempt failed. Cleaning up and retrying in 10 seconds..." + rm -rf .venv uv.lock + sleep 10 + done + echo "All 3 attempts failed." + exit 1 - name: CCache Stats After run: ccache -s @@ -75,7 +121,7 @@ jobs: uses: apache/infrastructure-actions/stash/save@3354c1565d4b0e335b78a76aedd82153a9e144d4 with: path: ${{ env.CCACHE_DIR }} - key: ccache-docs-8-core-ubuntu + key: ccache-docs-16-core-ubuntu - name: Build Documentation run: | @@ -92,7 +138,6 @@ jobs: run: | git config --global user.email "velox@users.noreply.github.com" git config --global user.name "velox" - git config --global --add safe.directory ${GITHUB_WORKSPACE} - name: Push Documentation if: ${{ github.event_name == 'push' && github.repository == 'facebookincubator/velox'}} @@ -109,7 +154,7 @@ jobs: - name: Upload Documentation if: github.event_name == 'pull_request' - uses: actions/upload-artifact@5d5d22a31266ced268874388b861e4b58bb5c2f3 # v4.3.1 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: path: velox/docs/_build/html retention-days: 3 diff --git a/.github/disabled-workflows/linux-build-base.yml b/.github/disabled-workflows/linux-build-base.yml index 2f08b94319c..34b800dd398 100644 --- a/.github/disabled-workflows/linux-build-base.yml +++ b/.github/disabled-workflows/linux-build-base.yml @@ -24,11 +24,63 @@ on: type: boolean jobs: + + get-changes: + runs-on: ubuntu-latest + outputs: + run-clang-tidy: ${{ steps.changes.outputs.run_clang_tidy }} + changed-files: ${{ steps.changes.outputs.files}} + diff-range: ${{ steps.changes.outputs.range }} + merge-base-commit: ${{ steps.changes.outputs.merge_base_commit }} + steps: + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + persist-credentials: false + fetch-depth: 0 + + - name: Get changed files + id: changes + env: + GH_TOKEN: ${{ github.token }} + BASE_REF: ${{ github.base_ref }} + HEAD_REF: ${{ github.head_ref }} + PR_OWNER: ${{ github.event.pull_request.head.repo.owner.login }} + HEAD_SHA: ${{ github.event.pull_request.head.sha }} + JOB_TRIGGER: ${{ github.event_name }} + run: | + echo "JOB_TRIGGER: $JOB_TRIGGER" + if [[ "$JOB_TRIGGER" == "pull_request" ]]; then + merge_base_commit=$(gh api -q '.merge_base_commit.sha' \ + /repos/facebookincubator/velox/compare/facebookincubator:$BASE_REF...$PR_OWNER:$HEAD_REF \ + ) + + range="$merge_base_commit..$HEAD_SHA" + echo "range=$range" >> "$GITHUB_OUTPUT" + echo "merge_base_commit=$merge_base_commit" >> "$GITHUB_OUTPUT" + + git diff --name-only $range > changed_files.txt + + cpp_files='.+\.(cpp|h|hpp)$' + + { + echo 'files<> "$GITHUB_OUTPUT" + + if grep -qE $cpp_files changed_files.txt; then + echo "run_clang_tidy=true" >> "$GITHUB_OUTPUT" + fi + else + echo "run_clang_tidy=false" >> "$GITHUB_OUTPUT" + fi + adapters: name: Linux release with adapters + needs: get-changes # prevent errors when forks ff their main branch if: ${{ github.repository == 'facebookincubator/velox' }} - runs-on: 8-core-ubuntu-22.04 + runs-on: 32-core-ubuntu container: ghcr.io/facebookincubator/velox-dev:adapters defaults: run: @@ -36,13 +88,21 @@ jobs: env: CCACHE_DIR: ${{ github.workspace }}/ccache VELOX_DEPENDENCY_SOURCE: SYSTEM - GTest_SOURCE: BUNDLED cudf_SOURCE: BUNDLED - CUDA_VERSION: '12.8' + CUDA_VERSION: '12.9' faiss_SOURCE: BUNDLED USE_CLANG: "${{ inputs.use-clang && 'true' || 'false' }}" + outputs: + cudf-changes: ${{ steps.changes.outputs.cudf }} + build-outcome: ${{ steps.build.outcome }} + build-failure-details: ${{ steps.build-errors.outputs.build-failure-details }} + test-outcome: ${{ steps.retry-tests.outcome != 'skipped' && steps.retry-tests.outcome || steps.tests.outcome }} + flaky: ${{ steps.retry-tests.outputs.flaky }} + failed-tests: ${{ steps.retry-tests.outputs.failed-tests }} + failed-cases: ${{ steps.retry-tests.outputs.failed-cases }} + failure-details: ${{ steps.retry-tests.outputs.failure-details }} steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: fetch-depth: 2 persist-credentials: false @@ -58,6 +118,9 @@ jobs: VELOX_ARROW_CMAKE_PATCH: ${{ github.workspace }}/CMake/resolve_dependency_modules/arrow/cmake-compatibility.patch run: | if git diff --name-only HEAD^1 HEAD | grep -q "scripts/setup-"; then + echo "Removing previous AWS SDK and s2n installations to avoid conflicts..." + rm -rf /usr/local/include/s2n /usr/local/lib64/s2n + rm -rf /usr/local/include/aws # Overwrite old setup scripts with changed versions cp scripts/setup-* / @@ -83,12 +146,13 @@ jobs: ccache -sz - name: Make Release Build + id: build env: - MAKEFLAGS: NUM_THREADS=8 MAX_HIGH_MEM_JOBS=4 MAX_LINK_JOBS=4 + MAKEFLAGS: NUM_THREADS=32 MAX_HIGH_MEM_JOBS=12 MAX_LINK_JOBS=12 CUDA_ARCHITECTURES: 70 CUDA_COMPILER: /usr/local/cuda-${CUDA_VERSION}/bin/nvcc - # Set compiler to GCC 12 - CUDA_FLAGS: -ccbin /opt/rh/gcc-toolset-12/root/usr/bin + # Set compiler to GCC 14 + CUDA_FLAGS: -ccbin /opt/rh/gcc-toolset-14/root/usr/bin run: | EXTRA_CMAKE_FLAGS=( "-DVELOX_ENABLE_BENCHMARKS=ON" @@ -102,19 +166,38 @@ jobs: "-DVELOX_ENABLE_ABFS=ON" "-DVELOX_ENABLE_WAVE=ON" "-DVELOX_MONO_LIBRARY=ON" - "-DVELOX_BUILD_SHARED=ON" ) if [[ "${USE_CLANG}" = "true" ]]; then scripts/setup-centos9.sh install_clang15; export CC=/usr/bin/clang-15; export CXX=/usr/bin/clang++-15; CUDA_FLAGS="-ccbin /usr/lib64/llvm15/bin/clang++-15"; else # cuDF (unsupported for Clang) and Faiss (link issue when using Clang) # are excluded for Clang compilation and need to be added back when using GCC. - EXTRA_CMAKE_FLAGS+="-DVELOX_ENABLE_CUDF=ON" - EXTRA_CMAKE_FLAGS+="-DVELOX_ENABLE_FAISS=ON" + EXTRA_CMAKE_FLAGS+=("-DVELOX_ENABLE_CUDF=ON") + EXTRA_CMAKE_FLAGS+=("-DVELOX_ENABLE_FAISS=ON") # Investigate issues with remote function service: Issue #13897 - EXTRA_CMAKE_FLAGS+="-DVELOX_ENABLE_REMOTE_FUNCTIONS=ON" + EXTRA_CMAKE_FLAGS+=("-DVELOX_ENABLE_REMOTE_FUNCTIONS=ON") + fi + source /opt/rh/gcc-toolset-14/enable + g++ --version + # CC/CXX are set to GCC12 in the adapters image build and it needs to be overridden + # after sourcing GCC14. + CC=gcc CXX=g++ make release EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS[*]}" 2>&1 | tee /tmp/build-output.log + + - name: Extract build errors + id: build-errors + if: failure() && steps.build.outcome == 'failure' + run: | + # Strip ANSI escape sequences before grepping — the build uses + # -fdiagnostics-color=always which embeds SGR and erase (ESC[K) + # codes that prevent matching literal ': error:'. + BUILD_ERRORS=$(sed 's/\x1b\[[0-9;]*[a-zA-Z]//g' /tmp/build-output.log | grep -E ' (error|fatal error):' | head -50 || true) + if [[ -n $BUILD_ERRORS ]]; then + { + echo 'build-failure-details<> "$GITHUB_OUTPUT" fi - make release EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS[*]}" - name: Ccache after run: ccache -s @@ -125,32 +208,295 @@ jobs: key: ccache-linux-adapters-${{ inputs.use-clang && 'clang' || 'gcc' }} - name: Run Tests + id: tests + continue-on-error: true + if: steps.build.outcome == 'success' env: LIBHDFS3_CONF: ${{ github.workspace }}/scripts/ci/hdfs-client.xml working-directory: _build/release run: | - # Can be removed after images are rebuild - if [ -f "/opt/miniforge/etc/profile.d/conda.sh" ]; then - source "/opt/miniforge/etc/profile.d/conda.sh" - conda activate adapters + source /setup-classpath.sh + ulimit -n 65536 + ctest -j 24 --timeout 900 --label-exclude cuda_driver \ + --output-on-failure --no-tests=error \ + --output-junit test-results.xml 2>&1 | tee /tmp/ctest-output.log + + - name: Retry Flaky Tests + id: retry-tests + continue-on-error: true + if: steps.tests.outcome == 'failure' + env: + LIBHDFS3_CONF: ${{ github.workspace }}/scripts/ci/hdfs-client.xml + working-directory: _build/release + run: | + source /setup-classpath.sh + echo "::warning::Some tests failed. Retrying failed tests..." + if ctest --rerun-failed -j 4 --timeout 900 \ + --output-on-failure --output-junit retry-results.xml; then + echo "::warning::All failed tests passed on retry — these are flaky tests." + echo "flaky=true" >> "$GITHUB_OUTPUT" + else + echo "::error::Tests failed consistently on retry." + "$GITHUB_WORKSPACE/.github/scripts/extract-test-failures.sh" /tmp/ctest-output.log + exit 1 + fi + + # Clang-tidy needs a complete build because some files are only generated during the build + # that clang tidy will not find and report errors otherwise. + # When clang is used a number of dependencies are excluded from the build so we don't + # need to run clang-tidy for that. + # Let's also run this as last step so that if skipped it doesn't affect subsequent steps. + - name: Install and run clang-tidy + if: ${{ steps.build.outcome == 'success' && ! inputs.use-clang && needs.get-changes.outputs.run-clang-tidy == 'true' }} + env: + FILES: ${{ needs.get-changes.outputs.changed-files }} + RANGE: ${{ needs.get-changes.outputs.diff-range }} + MERGE_BASE_COMMIT: ${{ needs.get-changes.outputs.merge-base-commit }} + run: | + git config --global --add safe.directory /__w/velox/velox + uv tool install clang-tidy==18.1.8 + git fetch origin $MERGE_BASE_COMMIT + # The usage of GCC14 adds compiler warnings not understood by Clang/Clang-tidy. + # We replace them for now but eventually move Clang-tidy to the Clang based build. + BUILD_PATH=$(realpath ./_build/release) + sed -i 's/-Wno-error=template-id-cdtor//g' $BUILD_PATH/compile_commands.json + sed -i 's/-Wno-class-memaccess//g' $BUILD_PATH/compile_commands.json + sed -i 's/-Wno-maybe-uninitialized//g' $BUILD_PATH/compile_commands.json + sed -i 's/-fcoroutines//g' $BUILD_PATH/compile_commands.json + python ./scripts/checks/run-clang-tidy.py -p $BUILD_PATH --commit $RANGE $FILES + + - name: Check cuDF Changes + uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2 + id: changes + if: ${{ ! inputs.use-clang && always() }} + with: + filters: | + cudf: + - 'velox/experimental/cudf/**' + - 'CMake/resolve_dependency_modules/cudf.cmake' + + - name: Copy Shared Libraries for cuDF Test Binaries + if: ${{ ! inputs.use-clang && steps.changes.outputs.cudf && always() }} + env: + CUDF_DIR: _build/release/velox/experimental/cudf + run: | + mkdir -p "$CUDF_DIR/cudf-libs" + deps=$( + find "$CUDF_DIR/tests" -name "velox_cudf_*" -type f -executable | + xargs ldd 2>/dev/null | + sed -n 's/[^\/]*\(\/[^ ]*\) .*/\1/p' | + grep -vE '^/(lib|usr/lib|lib64)/(libcuda|librt\.so|libm\.so|libstdc\+\+\.so|ld-linux-x86-64\.so|libc\.so)' | + sort -u + ) + if [ -z "$deps" ]; then + echo "Error: No dependencies found" + exit 1 + fi + # copy all "real" files and symlink targets to "$CUDF_DIR/cudf-libs" + ( + echo "$deps" | + xargs readlink -f | + xargs -I {} cp {} "$CUDF_DIR/cudf-libs/" + ) + # filter out which libraries are symlinks + lndeps=$(echo "$deps" | xargs -I {} bash -c '[ -L "$1" ] && echo "$1" || true' bash {}) + # write new symlinks to the same folder, with targets in the same folder + ( + if [[ -d "$CUDF_DIR/cudf-libs" ]]; then + cd "$CUDF_DIR/cudf-libs" || exit 1 + while IFS= read -r link; do + [[ -L "$link" ]] || continue + target=$(readlink -f "$link") || continue + target_basename=$(basename "$target") || continue + link_basename=$(basename "$link") || continue + ln -sf "$target_basename" "$link_basename" || echo "Warning: Failed to create symlink $link_basename" + done <<< "$lndeps" + else + echo "Error: Directory $CUDF_DIR/cudf-libs not found" + exit 1 + fi + ) + tar -cvf "$CUDF_DIR/cudf-libs.tar" -C "$CUDF_DIR" cudf-libs + + - name: Upload cuDF Test Binaries + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + if: ${{ ! inputs.use-clang && steps.changes.outputs.cudf && always() }} + with: + name: cudftestbinaries + path: | + velox/dwio/parquet/tests/examples/int.parquet + velox/experimental/cudf/tests/CMakeLists.txt + _build/release/velox/experimental/cudf/tests/velox_cudf_* + _build/release/velox/experimental/cudf/tests/CTestTestfile.cmake + _build/release/velox/experimental/cudf/cudf-libs.tar + retention-days: ${{ env.RETENTION }} + + adapters-build-status: + if: always() + needs: adapters + runs-on: ubuntu-latest + name: "BUILD: Linux release with adapters" + steps: + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + sparse-checkout: .github/scripts + persist-credentials: false + + - name: Report build status + run: .github/scripts/report-build-status.sh "Linux release with adapters" + env: + BUILD_OUTCOME: ${{ needs.adapters.outputs.build-outcome }} + BUILD_FAILURE_DETAILS: ${{ needs.adapters.outputs.build-failure-details }} + + - name: Upload failure artifact + if: failure() + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + with: + name: ci-failure-adapters-build + path: /tmp/ci-failure/failure.json + retention-days: 1 + if-no-files-found: ignore + + adapters-test-status: + if: always() + needs: adapters + runs-on: ubuntu-latest + name: "TEST: Linux release with adapters" + steps: + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + sparse-checkout: .github/scripts + persist-credentials: false + + - name: Report test status + run: .github/scripts/report-test-status.sh "Linux release with adapters" "Linux adapters release" + env: + BUILD_OUTCOME: ${{ needs.adapters.outputs.build-outcome }} + TEST_OUTCOME: ${{ needs.adapters.outputs.test-outcome }} + FLAKY: ${{ needs.adapters.outputs.flaky }} + FAILED_TESTS: ${{ needs.adapters.outputs.failed-tests }} + FAILED_CASES: ${{ needs.adapters.outputs.failed-cases }} + FAILURE_DETAILS: ${{ needs.adapters.outputs.failure-details }} + + - name: Upload failure artifact + if: failure() + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + with: + name: ci-failure-adapters-test + path: /tmp/ci-failure/failure.json + retention-days: 1 + if-no-files-found: ignore + + cudf-tests: + runs-on: 4-core-ubuntu-gpu-t4 + needs: adapters + if: ${{ ! inputs.use-clang && needs.adapters.outputs.cudf-changes == 'true' && needs.adapters.result == 'success' && always() }} + timeout-minutes: 30 + env: + CUDF_DIR: _build/release/velox/experimental/cudf + steps: + - name: Install Packages + run: | + sudo apt-get update && sudo apt-get install -y cmake patchelf cuda-toolkit-12-9 + export MINIO_VERSION="2022-05-26T05-48-41Z" + export MINIO_BINARY_NAME="minio-2022-05-26" + wget https://dl.min.io/server/minio/release/linux-amd64/archive/minio.RELEASE."${MINIO_VERSION}" -O "${MINIO_BINARY_NAME}" + sudo mv ./"${MINIO_BINARY_NAME}" /usr/local/bin/ + sudo chmod +x /usr/local/bin/"${MINIO_BINARY_NAME}" + + - name: Check NVIDIA Driver Version + run: | + nvidia-smi + + - name: Download cuDF Test Binaries + uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + with: + name: cudftestbinaries + + - name: Adapt Downloaded Files + run: | + # Change hardcoded workspace path in CTestTestfile.cmake to the runner workspace + if [[ -f "$CUDF_DIR/tests/CTestTestfile.cmake" ]]; then + workspace_escaped=$(printf '%s\n' "${{ github.workspace }}" | sed 's/[[\.*^$()+?{|]/\\&/g') + sed -i "s|/__w/velox/velox|${workspace_escaped}|g" "$CUDF_DIR/tests/CTestTestfile.cmake" + # Verify the replacement worked + if grep -q "${{ github.workspace }}" "$CUDF_DIR/tests/CTestTestfile.cmake"; then + echo "Successfully updated paths in CTestTestfile.cmake" + else + echo "Warning: Path replacement may have failed" + fi + else + echo "Error: CTestTestfile.cmake not found" + exit 1 fi + sed -i 's|/__w/velox/velox|${{ github.workspace }}|g' $CUDF_DIR/tests/CTestTestfile.cmake + grep "Source directory" $CUDF_DIR/tests/CTestTestfile.cmake + grep "Build directory" $CUDF_DIR/tests/CTestTestfile.cmake + (cd $CUDF_DIR && tar -xf cudf-libs.tar) + # Patch test executables. + # ldfconfig causes system wide changes. Since we are copying some system libraries from + # the build machine, we need to use patchelf to add the copied libraries path only to + # the test executables. + for exe in $CUDF_DIR/tests/velox_cudf_*; do + patchelf --force-rpath --set-rpath '$ORIGIN/../cudf-libs' "$exe" + chmod +x "$exe" + ls -l "$exe" + done - export CLASSPATH=`/usr/local/hadoop/bin/hdfs classpath --glob` - ctest -j 8 --label-exclude cuda_driver --output-on-failure --no-tests=error + - name: Run cuDF Tests + run: | + cd $CUDF_DIR/tests/ + ctest --output-on-failure ubuntu-debug: - runs-on: 8-core-ubuntu-22.04 + runs-on: 32-core-ubuntu + container: ghcr.io/facebookincubator/velox-dev:ubuntu-22.04 # prevent errors when forks ff their main branch if: ${{ github.repository == 'facebookincubator/velox' }} - name: Ubuntu debug with resolve_dependency + name: Ubuntu debug with system dependencies env: CCACHE_DIR: ${{ github.workspace }}/ccache USE_CLANG: ${{ inputs.use-clang && 'true' || 'false' }} + outputs: + build-outcome: ${{ steps.build.outcome }} + build-failure-details: ${{ steps.build-errors.outputs.build-failure-details }} + test-outcome: ${{ steps.retry-tests.outcome != 'skipped' && steps.retry-tests.outcome || steps.tests.outcome }} + flaky: ${{ steps.retry-tests.outputs.flaky }} + failed-tests: ${{ steps.retry-tests.outputs.failed-tests }} + failed-cases: ${{ steps.retry-tests.outputs.failed-cases }} + failure-details: ${{ steps.retry-tests.outputs.failure-details }} defaults: run: shell: bash working-directory: velox steps: + - uses: actions/checkout@v5 + with: + fetch-depth: 2 + path: velox + persist-credentials: false + + - name: Fix git permissions + # Usually actions/checkout does this but as we run in a container + # it doesn't work + run: git config --global --add safe.directory ${GITHUB_WORKSPACE} + + - name: Install Dependencies + env: + VELOX_ARROW_CMAKE_PATCH: ${{ github.workspace }}/velox/CMake/resolve_dependency_modules/arrow/cmake-compatibility.patch + run: | + if git diff --name-only HEAD^1 HEAD | grep -q "scripts/setup-"; then + # Overwrite old setup scripts with changed versions + cp scripts/setup-* / + + mkdir /tmp/build + cd /tmp/build + + USE_CLANG=false bash /setup-ubuntu.sh + + cd / + rm -rf /tmp/build # cleanup to avoid issues with disk space + fi - name: Get Ccache Stash uses: apache/infrastructure-actions/stash/restore@3354c1565d4b0e335b78a76aedd82153a9e144d4 @@ -163,35 +509,229 @@ jobs: run: | mkdir -p "$CCACHE_DIR" - - uses: actions/checkout@v4 + - name: Clear CCache Statistics + run: | + ccache -sz + + - name: Make Debug Build + id: build + env: + VELOX_DEPENDENCY_SOURCE: SYSTEM + ICU_SOURCE: SYSTEM + # Use BUNDLED gflags to provide PIC static gflags for .so plugins. + # The container's folly is built with -DGFLAGS_SHARED=FALSE so its + # exported config references gflags_static which BUNDLED gflags provides. + gflags_SOURCE: BUNDLED + # Keep system glog (container's glog is built against the same gflags + # version). Without this, BUNDLED gflags cascades to BUNDLED glog + # which conflicts with system glog loaded transitively. + glog_SOURCE: SYSTEM + MAKEFLAGS: NUM_THREADS=32 MAX_HIGH_MEM_JOBS=8 MAX_LINK_JOBS=6 + run: | + EXTRA_CMAKE_FLAGS=( + "-DCMAKE_LINK_LIBRARIES_STRATEGY=REORDER_FREELY" + "-DVELOX_BUILD_SHARED=ON" + "-DVELOX_ENABLE_BENCHMARKS=ON" + "-DVELOX_ENABLE_EXAMPLES=ON" + "-DVELOX_ENABLE_ARROW=ON" + "-DVELOX_ENABLE_GEO=ON" + "-DVELOX_ENABLE_PARQUET=ON" + "-DVELOX_MONO_LIBRARY=ON" + ) + if [[ "${USE_CLANG}" = "true" ]]; then + export CC=/usr/bin/clang-15; export CXX=/usr/bin/clang++-15; + else + EXTRA_CMAKE_FLAGS+=("-DVELOX_ENABLE_FAISS=ON") + EXTRA_CMAKE_FLAGS+=("-DVELOX_ENABLE_REMOTE_FUNCTIONS=ON") + fi + export EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS[*]}" + make debug 2>&1 | tee /tmp/build-output.log + + - name: Extract build errors + id: build-errors + if: failure() && steps.build.outcome == 'failure' + run: | + # Strip ANSI escape sequences before grepping — the build uses + # -fdiagnostics-color=always which embeds SGR and erase (ESC[K) + # codes that prevent matching literal ': error:'. + BUILD_ERRORS=$(sed 's/\x1b\[[0-9;]*[a-zA-Z]//g' /tmp/build-output.log | grep -E ' (error|fatal error):' | head -50 || true) + if [[ -n $BUILD_ERRORS ]]; then + { + echo 'build-failure-details<> "$GITHUB_OUTPUT" + fi + + - name: CCache after + run: | + ccache -vs + + - uses: apache/infrastructure-actions/stash/save@3354c1565d4b0e335b78a76aedd82153a9e144d4 + with: + path: ${{ env.CCACHE_DIR }} + key: ccache-ubuntu-debug-default-${{ inputs.use-clang && 'clang' || 'gcc' }} + + - name: Run Tests + id: tests + continue-on-error: true + if: steps.build.outcome == 'success' + run: | + ulimit -n 65536 + cd _build/debug + ctest -j 24 --timeout 1800 --output-on-failure --no-tests=error \ + --output-junit test-results.xml 2>&1 | tee /tmp/ctest-output.log + + - name: Retry Flaky Tests + id: retry-tests + continue-on-error: true + if: steps.tests.outcome == 'failure' + run: | + cd _build/debug + echo "::warning::Some tests failed. Retrying failed tests..." + if ctest --rerun-failed -j 4 --timeout 1800 \ + --output-on-failure --output-junit retry-results.xml; then + echo "::warning::All failed tests passed on retry — these are flaky tests." + echo "flaky=true" >> "$GITHUB_OUTPUT" + else + echo "::error::Tests failed consistently on retry." + "$GITHUB_WORKSPACE/velox/.github/scripts/extract-test-failures.sh" /tmp/ctest-output.log + exit 1 + fi + + ubuntu-debug-build-status: + if: always() + needs: ubuntu-debug + runs-on: ubuntu-latest + name: "BUILD: Ubuntu debug with system dependencies" + steps: + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + sparse-checkout: .github/scripts + persist-credentials: false + + - name: Report build status + run: .github/scripts/report-build-status.sh "Ubuntu debug with system dependencies" + env: + BUILD_OUTCOME: ${{ needs.ubuntu-debug.outputs.build-outcome }} + BUILD_FAILURE_DETAILS: ${{ needs.ubuntu-debug.outputs.build-failure-details }} + + - name: Upload failure artifact + if: failure() + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 with: + name: ci-failure-ubuntu-debug-build + path: /tmp/ci-failure/failure.json + retention-days: 1 + if-no-files-found: ignore + + ubuntu-debug-test-status: + if: always() + needs: ubuntu-debug + runs-on: ubuntu-latest + name: "TEST: Ubuntu debug with system dependencies" + steps: + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + sparse-checkout: .github/scripts + persist-credentials: false + + - name: Report test status + run: .github/scripts/report-test-status.sh "Ubuntu debug with system dependencies" "Ubuntu debug" + env: + BUILD_OUTCOME: ${{ needs.ubuntu-debug.outputs.build-outcome }} + TEST_OUTCOME: ${{ needs.ubuntu-debug.outputs.test-outcome }} + FLAKY: ${{ needs.ubuntu-debug.outputs.flaky }} + FAILED_TESTS: ${{ needs.ubuntu-debug.outputs.failed-tests }} + FAILED_CASES: ${{ needs.ubuntu-debug.outputs.failed-cases }} + FAILURE_DETAILS: ${{ needs.ubuntu-debug.outputs.failure-details }} + + - name: Upload failure artifact + if: failure() + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + with: + name: ci-failure-ubuntu-debug-test + path: /tmp/ci-failure/failure.json + retention-days: 1 + if-no-files-found: ignore + + fedora-debug: + runs-on: 32-core-ubuntu + container: ghcr.io/facebookincubator/velox-dev:fedora + # prevent errors when forks ff their main branch + if: ${{ github.repository == 'facebookincubator/velox' }} + name: Fedora debug + env: + CCACHE_DIR: ${{ github.workspace }}/ccache + outputs: + build-outcome: ${{ steps.build.outcome }} + defaults: + run: + shell: bash + working-directory: velox + steps: + - uses: actions/checkout@v5 + with: + fetch-depth: 2 path: velox persist-credentials: false + - name: Fix git permissions + # Usually actions/checkout does this but as we run in a container + # it doesn't work + run: git config --global --add safe.directory ${GITHUB_WORKSPACE} + - name: Install Dependencies + env: + VELOX_BUILD_SHARED: "ON" + VELOX_ARROW_CMAKE_PATCH: ${{ github.workspace }}/velox/CMake/resolve_dependency_modules/arrow/cmake-compatibility.patch run: | - source scripts/setup-ubuntu.sh && install_apt_deps && install_faiss_deps + if git diff --name-only HEAD^1 HEAD | grep -q "scripts/setup-"; then + # Overwrite old setup scripts with changed versions + cp scripts/setup-* / + + mkdir /tmp/build + cd /tmp/build + + # Install basic deps with GCC. + USE_CLANG=false bash /setup-fedora.sh + + cd / + rm -rf /tmp/build # cleanup to avoid issues with disk space + fi + + - name: Get Ccache Stash + uses: apache/infrastructure-actions/stash/restore@3354c1565d4b0e335b78a76aedd82153a9e144d4 + with: + path: ${{ env.CCACHE_DIR }} + key: ccache-fedora-debug-default-gcc + + - name: Ensure Stash Dirs Exists + working-directory: ${{ github.workspace }} + run: | + mkdir -p "$CCACHE_DIR" - name: Clear CCache Statistics run: | ccache -sz - name: Make Debug Build + id: build env: - VELOX_DEPENDENCY_SOURCE: BUNDLED - ICU_SOURCE: SYSTEM - MAKEFLAGS: NUM_THREADS=8 MAX_HIGH_MEM_JOBS=4 MAX_LINK_JOBS=3 + VELOX_DEPENDENCY_SOURCE: SYSTEM + faiss_SOURCE: BUNDLED + fmt_SOURCE: BUNDLED + simdjson_SOURCE: BUNDLED + gRPC_SOURCE: SYSTEM + MAKEFLAGS: NUM_THREADS=32 MAX_HIGH_MEM_JOBS=8 MAX_LINK_JOBS=6 EXTRA_CMAKE_FLAGS: >- - -DCMAKE_LINK_LIBRARIES_STRATEGY=REORDER_FREELY -DVELOX_ENABLE_PARQUET=ON + -DARROW_THRIFT_USE_SHARED=ON -DVELOX_ENABLE_EXAMPLES=ON run: | - # Faiss (link issue when using Clang) is excluded for Clang compilation and needs to be added back when using GCC. - if [[ "${USE_CLANG}" = "true" ]]; then - export CC=/usr/bin/clang-15; export CXX=/usr/bin/clang++-15; - else - export EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} -DVELOX_ENABLE_FAISS=ON" - fi + uv tool install --force cmake@3.31.1 + dnf install -y -q --setopt=install_weak_deps=False grpc-devel grpc-plugins + export EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} -DVELOX_ENABLE_FAISS=ON" make debug - name: CCache after @@ -201,8 +741,39 @@ jobs: - uses: apache/infrastructure-actions/stash/save@3354c1565d4b0e335b78a76aedd82153a9e144d4 with: path: ${{ env.CCACHE_DIR }} - key: ccache-ubuntu-debug-default-${{ inputs.use-clang && 'clang' || 'gcc' }} + key: ccache-fedora-debug-default-gcc - - name: Run Tests + fedora-debug-build-status: + if: always() + needs: fedora-debug + runs-on: ubuntu-latest + name: "BUILD: Fedora debug" + steps: + - run: | + if [[ -z "$BUILD_OUTCOME" || "$BUILD_OUTCOME" == "cancelled" ]]; then + echo "Build was cancelled (likely superseded by a newer push). Skipping." + exit 0 + fi + if [[ "$BUILD_OUTCOME" != "success" ]]; then + echo "::error::Fedora debug build failed. Do not land this PR until the build is fixed." + exit 1 + fi + echo "Build succeeded." + env: + BUILD_OUTCOME: ${{ needs.fedora-debug.outputs.build-outcome }} + + - name: Upload build failure artifact + if: failure() run: | - cd _build/debug && ctest -j 8 --output-on-failure --no-tests=error + mkdir -p /tmp/ci-failure + jq -n --arg job "Fedora debug" --arg type "build" \ + '{job: $job, type: $type}' \ + > /tmp/ci-failure/failure.json + + - name: Upload failure comment artifact + if: failure() + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + with: + name: ci-failure-fedora-build + path: /tmp/ci-failure/failure.json + retention-days: 1 diff --git a/.github/disabled-workflows/macos.yml b/.github/disabled-workflows/macos.yml index f2db45237cb..2fb94f0f836 100644 --- a/.github/disabled-workflows/macos.yml +++ b/.github/disabled-workflows/macos.yml @@ -22,6 +22,9 @@ on: - CMakeLists.txt - CMake/** - scripts/setup-macos.sh + - scripts/setup-common.sh + - scripts/setup-versions.sh + - scripts/setup-helper-functions.sh - .github/workflows/macos.yml pull_request: @@ -31,6 +34,9 @@ on: - CMakeLists.txt - CMake/** - scripts/setup-macos.sh + - scripts/setup-common.sh + - scripts/setup-versions.sh + - scripts/setup-helper-functions.sh - .github/workflows/macos.yml permissions: @@ -43,22 +49,20 @@ concurrency: jobs: macos-build: if: ${{ github.repository == 'facebookincubator/velox' }} - name: ${{ matrix.os }} + name: macos-15-${{ matrix.type }} strategy: fail-fast: false matrix: - # macos-13 = x86_64 Mac - # macos-15 = arm64 Mac and cmake 4.0 - os: [macos-13, macos-15] - runs-on: ${{ matrix.os }} + # macos-15 = arm64 Mac and cmake 4.0 with 7GB RAM + type: [debug, release] + runs-on: macos-15 env: CCACHE_DIR: ${{ github.workspace }}/ccache - # The arm runners have only 7GB RAM - BUILD_TYPE: ${{ matrix.os == 'macos-15' && 'Release' || 'Debug' }} + BUILD_TYPE: ${{ matrix.type }} INSTALL_PREFIX: /tmp/deps-install steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: persist-credentials: false @@ -70,6 +74,9 @@ jobs: source scripts/setup-macos.sh install_build_prerequisites install_velox_deps_from_brew + install_gflags + install_glog + install_protobuf # Needed for faiss to find BLAS install_faiss_deps install_double_conversion @@ -86,7 +93,7 @@ jobs: uses: apache/infrastructure-actions/stash/restore@3354c1565d4b0e335b78a76aedd82153a9e144d4 with: path: ${{ env.CCACHE_DIR }} - key: ccache-macos-1-${{ matrix.os }} + key: ccache-macos-1-macos-15-${{ matrix.type }} - name: Configure Build env: @@ -114,7 +121,7 @@ jobs: - uses: apache/infrastructure-actions/stash/save@3354c1565d4b0e335b78a76aedd82153a9e144d4 with: path: ${{ env.CCACHE_DIR }} - key: ccache-macos-1-${{ matrix.os }} + key: ccache-macos-1-macos-15-${{ matrix.type }} - name: Run Tests if: false diff --git a/.github/disabled-workflows/scheduled.yml b/.github/disabled-workflows/scheduled.yml index 17eb81ca190..429e174fac4 100644 --- a/.github/disabled-workflows/scheduled.yml +++ b/.github/disabled-workflows/scheduled.yml @@ -82,8 +82,10 @@ concurrency: cancel-in-progress: true env: - # Run for 15 minute on PRs - DURATION: ${{ inputs.duration || ( github.event_name != 'schedule' && 900 || 1800 )}} + # Per-instance fuzzer duration. With 4 parallel instances at 300s each, + # total fuzzer coverage is 1200s (20 min) vs the previous 900s (15 min). + DURATION: ${{ inputs.duration || ( github.event_name != 'schedule' && 300 || 600 )}} + NUM_FUZZER_INSTANCES: 4 # minimize artifact duration for PRs, keep them a bit longer for nightly runs RETENTION: ${{ github.event_name == 'pull_request' && 1 || 3 }} @@ -97,9 +99,9 @@ jobs: timeout-minutes: 120 env: CCACHE_DIR: ${{ github.workspace }}/ccache - NUM_THREADS: ${{ inputs.numThreads || 16 }} + NUM_THREADS: ${{ inputs.numThreads || 32 }} MAX_HIGH_MEM_JOBS: ${{ inputs.maxHighMemJobs || 8 }} - MAX_LINK_JOBS: ${{ inputs.maxLinkJobs || 4 }} + MAX_LINK_JOBS: ${{ inputs.maxLinkJobs || 6 }} SKBUILD_BUILD_DIR: _build/debug PYVELOX_LEGACY_ONLY: 'ON' CMAKE_POLICY_VERSION_MINIMUM: "3.5" @@ -115,6 +117,7 @@ jobs: spark_error: ${{ steps.sig-check.outputs.spark_error }} presto_aggregate_bias: ${{ steps.sig-check.outputs.presto_aggregate_functions }} presto_aggregate_error: ${{ steps.sig-check.outputs.presto_aggregate_error }} + signatures_checked: ${{ steps.sig-paths.outputs.signatures }} steps: @@ -170,22 +173,39 @@ jobs: mkdir -p '$CCACHE_DIR' mkdir -p /tmp/signatures + - name: Check for signature-relevant changes + uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2 + if: github.event_name == 'pull_request' + id: sig-paths + with: + filters: | + signatures: + - 'velox/functions/prestosql/**' + - 'velox/functions/sparksql/**' + - 'velox/functions/lib/**' + - 'velox/expression/**' + - 'velox/exec/Aggregate.h' + - 'velox/exec/Aggregate.cpp' + - 'velox/type/**' + - 'velox/python/**' + - 'pyproject.toml' + - name: Checkout Main - if: ${{ github.event_name != 'schedule' && steps.get-sig.outputs.stash-hit != 'true' }} - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + if: ${{ github.event_name != 'schedule' && steps.get-sig.outputs.stash-hit != 'true' && (github.event_name != 'pull_request' || steps.sig-paths.outputs.signatures == 'true') }} + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: ref: ${{ steps.get-merge-base.outputs.head_main || 'main' }} path: velox_main persist-credentials: false - name: Build PyVelox - if: ${{ github.event_name != 'schedule' && steps.get-sig.outputs.stash-hit != 'true' }} + if: ${{ github.event_name != 'schedule' && steps.get-sig.outputs.stash-hit != 'true' && (github.event_name != 'pull_request' || steps.sig-paths.outputs.signatures == 'true') }} working-directory: velox_main run: | make python-build - name: Create Baseline Signatures - if: ${{ github.event_name != 'schedule' && steps.get-sig.outputs.stash-hit != 'true' }} + if: ${{ github.event_name != 'schedule' && steps.get-sig.outputs.stash-hit != 'true' && (github.event_name != 'pull_request' || steps.sig-paths.outputs.signatures == 'true') }} working-directory: velox_main run: | # TODO convert to uv with script defined deps after 14046 is merged @@ -197,14 +217,14 @@ jobs: python3 scripts/ci/signature.py export_aggregates --presto /tmp/signatures/presto_aggregate_signatures_main.json - name: Save Function Signature Stash - if: ${{ github.event_name == 'pull_request' && steps.get-sig.outputs.stash-hit != 'true' }} + if: ${{ github.event_name == 'pull_request' && steps.get-sig.outputs.stash-hit != 'true' && steps.sig-paths.outputs.signatures == 'true' }} uses: apache/infrastructure-actions/stash/save@3354c1565d4b0e335b78a76aedd82153a9e144d4 with: path: /tmp/signatures key: function-signatures-${{ steps.get-merge-base.outputs.head_main }} - name: Checkout Contender - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: path: velox ref: ${{ inputs.ref }} @@ -214,19 +234,35 @@ jobs: run: | ccache -sz - - name: Build + - name: Build Fuzzer Targets env: EXTRA_CMAKE_FLAGS: > -DVELOX_ENABLE_ARROW=ON -DVELOX_ENABLE_GEO=ON -DVELOX_MONO_LIBRARY=ON + -DVELOX_BUILD_SHARED=ON -DVELOX_BUILD_PYTHON_PACKAGE=ON -DVELOX_PYTHON_LEGACY_ONLY=ON ${{ inputs.extraCMakeFlags }} run: | make python-venv source .venv/bin/activate - make debug + make cmake BUILD_DIR=debug BUILD_TYPE=Debug + cmake --build _build/debug -j ${NUM_THREADS} --target \ + velox_expression_fuzzer_test \ + spark_expression_fuzzer_test \ + spark_aggregation_fuzzer_test \ + velox_aggregation_fuzzer_test \ + velox_join_fuzzer \ + velox_exchange_fuzzer \ + velox_window_fuzzer_test \ + velox_cache_fuzzer \ + velox_table_evolution_fuzzer_test \ + velox_memory_arbitration_fuzzer \ + velox_row_number_fuzzer \ + velox_topn_row_number_fuzzer \ + velox_writer_fuzzer_test \ + velox_spatial_join_fuzzer - name: Ccache after run: ccache -s @@ -241,12 +277,12 @@ jobs: key: ccache-fuzzer-centos - name: Build PyVelox - if: ${{ github.event_name != 'schedule' }} + if: ${{ github.event_name != 'schedule' && (github.event_name != 'pull_request' || steps.sig-paths.outputs.signatures == 'true') }} run: | make python-build - name: Create and test new function signatures - if: ${{ github.event_name != 'schedule' }} + if: ${{ github.event_name != 'schedule' && (github.event_name != 'pull_request' || steps.sig-paths.outputs.signatures == 'true') }} id: sig-check run: | source .venv/bin/activate @@ -259,8 +295,8 @@ jobs: /tmp/signatures/presto_aggregate_errors - name: Upload Signature Artifacts - if: ${{ github.event_name != 'schedule' }} - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + if: ${{ github.event_name != 'schedule' && (github.event_name != 'pull_request' || steps.sig-paths.outputs.signatures == 'true') }} + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: signatures path: /tmp/signatures @@ -285,104 +321,125 @@ jobs: path: /tmp/signatures key: function-signatures-${{ github.sha }} + - name: Upload fuzzer scripts + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 + with: + name: fuzzer-scripts + path: velox/scripts/ci/run-fuzzer-parallel.sh + retention-days: ${{ env.RETENTION }} + - name: Upload presto fuzzer - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: presto path: velox/_build/debug/velox/expression/fuzzer/velox_expression_fuzzer_test retention-days: ${{ env.RETENTION }} - name: Upload spark expression fuzzer - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: spark_expression_fuzzer path: velox/_build/debug/velox/expression/fuzzer/spark_expression_fuzzer_test retention-days: ${{ env.RETENTION }} - name: Upload spark aggregation fuzzer - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: spark_aggregation_fuzzer path: velox/_build/debug/velox/functions/sparksql/fuzzer/spark_aggregation_fuzzer_test retention-days: ${{ env.RETENTION }} - name: Upload aggregation fuzzer - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: aggregation path: velox/_build/debug/velox/functions/prestosql/fuzzer/velox_aggregation_fuzzer_test retention-days: ${{ env.RETENTION }} - name: Upload join fuzzer - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: join path: velox/_build/debug/velox/exec/fuzzer/velox_join_fuzzer retention-days: ${{ env.RETENTION }} - name: Upload exchange fuzzer - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: exchange path: velox/_build/debug//velox/exec/fuzzer/velox_exchange_fuzzer retention-days: ${{ env.RETENTION }} - name: Upload window fuzzer - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: window path: velox/_build/debug/velox/functions/prestosql/fuzzer/velox_window_fuzzer_test retention-days: ${{ env.RETENTION }} - name: Upload cache fuzzer - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: cache_fuzzer path: velox/_build/debug/velox/exec/fuzzer/velox_cache_fuzzer retention-days: ${{ env.RETENTION }} - name: Upload table evolution fuzzer - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: table_evolution_fuzzer path: velox/_build/debug/velox/exec/tests/velox_table_evolution_fuzzer_test retention-days: ${{ env.RETENTION }} - name: Upload memory arbitration fuzzer - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: memory_arbitration_fuzzer path: velox/_build/debug/velox/exec/fuzzer/velox_memory_arbitration_fuzzer retention-days: ${{ env.RETENTION }} - name: Upload row number fuzzer - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: row_number path: velox/_build/debug//velox/exec/fuzzer/velox_row_number_fuzzer retention-days: ${{ env.RETENTION }} - name: Upload topn row number fuzzer - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: topn_row_number path: velox/_build/debug//velox/exec/fuzzer/velox_topn_row_number_fuzzer retention-days: ${{ env.RETENTION }} - name: Upload writer fuzzer - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: writer path: velox/_build/debug/velox/functions/prestosql/fuzzer/velox_writer_fuzzer_test retention-days: ${{ env.RETENTION }} + - name: Upload spatial join fuzzer + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 + with: + name: spatial_join_fuzzer + path: velox/_build/debug/velox/exec/fuzzer/velox_spatial_join_fuzzer + retention-days: ${{ env.RETENTION }} + + - name: Upload shared library + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 + with: + name: libvelox + path: velox/_build/debug/lib/libvelox.so + retention-days: ${{ env.RETENTION }} + presto-fuzzer-run: name: Presto Fuzzer if: ${{ needs.compile.outputs.presto_bias != 'true' }} - runs-on: ubuntu-latest + runs-on: 4-core-ubuntu container: ghcr.io/facebookincubator/velox-dev:centos9 needs: compile - timeout-minutes: 120 + timeout-minutes: 30 steps: - uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2 if: github.event_name == 'pull_request' @@ -398,10 +455,10 @@ jobs: - name: Set presto specific fuzzer duration env: - # Run for 30 minutes instead of 15, when files relevant to presto are touched - pr_duration: ${{ steps.changes.outputs.presto == 'true' && 1800 || 900 }} - # Run for 60 minutes if its a scheduled run - other_duration: ${{ inputs.duration || (github.event_name == 'push' && 1800 || 3600) }} + # Run for 10 minutes per instance instead of 5, when files relevant to presto are touched + pr_duration: ${{ steps.changes.outputs.presto == 'true' && 600 || 300 }} + # Run for 20 minutes per instance if its a scheduled run + other_duration: ${{ inputs.duration || (github.event_name == 'push' && 600 || 1200) }} is_pr: ${{ github.event_name == 'pull_request' }} run: | if [ "$is_pr" == "true" ]; then @@ -413,19 +470,27 @@ jobs: echo "DURATION=$duration" >> $GITHUB_ENV - name: Download presto fuzzer - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: name: presto + - name: Download libvelox + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: libvelox + + - name: Download fuzzer scripts + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: fuzzer-scripts + - name: Run Presto Fuzzer run: | - mkdir -p /tmp/fuzzer_repro/logs/ - chmod -R 777 /tmp/fuzzer_repro - chmod +x velox_expression_fuzzer_test - random_seed=${RANDOM} - echo "Random seed: ${random_seed}" - ./velox_expression_fuzzer_test \ - --seed ${random_seed} \ + export LD_LIBRARY_PATH=$(pwd):$LD_LIBRARY_PATH + chmod +x velox_expression_fuzzer_test run-fuzzer-parallel.sh + ./run-fuzzer-parallel.sh $NUM_FUZZER_INSTANCES /tmp/fuzzer_repro \ + ./velox_expression_fuzzer_test \ + --seed __SEED__ \ --enable_variadic_signatures \ --velox_fuzzer_enable_complex_types \ --velox_fuzzer_enable_decimal_type \ @@ -435,17 +500,17 @@ jobs: --velox_fuzzer_enable_expression_reuse \ --max_expression_trees_per_step 2 \ --retry_with_try \ + --special_forms="and,or,cast,coalesce" \ --enable_dereference \ --duration_sec $DURATION \ --minloglevel=0 \ --stderrthreshold=2 \ - --log_dir=/tmp/fuzzer_repro/logs \ - --repro_persist_path=/tmp/fuzzer_repro \ - && echo -e "\n\nFuzzer run finished successfully." + --log_dir=__LOG_DIR__ \ + --repro_persist_path=__REPRO_DIR__ - name: Archive production artifacts - if: ${{ !cancelled() }} - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + if: ${{ always() }} + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: presto-fuzzer-failure-artifacts path: | @@ -454,21 +519,31 @@ jobs: presto-sot-fuzzer-run: name: Expression Fuzzer with Presto SOT needs: compile - runs-on: ubuntu-latest + runs-on: 4-core-ubuntu container: ghcr.io/facebookincubator/velox-dev:presto-java - timeout-minutes: 120 + timeout-minutes: 30 env: CCACHE_DIR: ${{ github.workspace }}/ccache/ LINUX_DISTRO: centos steps: - name: Download Presto expression fuzzer - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: name: presto + - name: Download libvelox + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: libvelox + + - name: Download fuzzer scripts + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: fuzzer-scripts + - name: Checkout Repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: path: velox ref: ${{ inputs.ref }} @@ -484,18 +559,16 @@ jobs: cd velox cp ./scripts/ci/presto/etc/hive.properties $PRESTO_HOME/etc/catalog ls -lR $PRESTO_HOME/etc - $PRESTO_HOME/bin/launcher run -v > /tmp/server.log 2>&1 & + $PRESTO_HOME/bin/launcher run > /tmp/server.log 2>&1 & # Sleep for 60 seconds to allow Presto server to start. sleep 60 /opt/presto-cli --server 127.0.0.1:8080 --execute 'CREATE SCHEMA IF NOT EXISTS hive.tpch;' cd - - mkdir -p /tmp/fuzzer_repro/logs/ - chmod -R 777 /tmp/fuzzer_repro - chmod +x velox_expression_fuzzer_test - random_seed=${RANDOM} - echo "Random seed: ${random_seed}" - ./velox_expression_fuzzer_test \ - --seed ${random_seed} \ + export LD_LIBRARY_PATH=$(pwd):$LD_LIBRARY_PATH + chmod +x velox_expression_fuzzer_test run-fuzzer-parallel.sh + ./run-fuzzer-parallel.sh $NUM_FUZZER_INSTANCES /tmp/fuzzer_repro \ + ./velox_expression_fuzzer_test \ + --seed __SEED__ \ --enable_variadic_signatures \ --velox_fuzzer_enable_complex_types \ --lazy_vector_generation_ratio 0.2 \ @@ -503,19 +576,19 @@ jobs: --velox_fuzzer_enable_column_reuse \ --velox_fuzzer_enable_expression_reuse \ --enable_dereference \ - --duration_sec 900 \ + --duration_sec $DURATION \ --minloglevel=0 \ --stderrthreshold=2 \ - --log_dir=/tmp/fuzzer_repro/logs \ - --repro_persist_path=/tmp/fuzzer_repro \ + --log_dir=__LOG_DIR__ \ + --repro_persist_path=__REPRO_DIR__ \ --special_forms="cast,coalesce,if" \ --velox_fuzzer_max_level_of_nesting=1 \ --presto_url=http://127.0.0.1:8080 \ - && echo -e "\n\nFuzzer run finished successfully." + --table_name_prefix=i__INSTANCE_ID__ - name: Archive expression production artifacts - if: ${{ !cancelled() }} - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + if: ${{ always() }} + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: presto-sot-fuzzer-failure-artifacts path: | @@ -531,12 +604,21 @@ jobs: steps: - name: Download presto expression fuzzer - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: name: presto + - name: Download libvelox + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: libvelox + - name: Download fuzzer scripts + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: fuzzer-scripts + - name: Download Signatures - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: name: signatures path: /tmp/signatures @@ -546,9 +628,11 @@ jobs: ls /tmp/signatures mkdir -p /tmp/presto_bias_fuzzer_repro/logs/ chmod -R 777 /tmp/presto_bias_fuzzer_repro + export LD_LIBRARY_PATH=$(pwd):$LD_LIBRARY_PATH chmod +x velox_expression_fuzzer_test random_seed=${RANDOM} echo "Random seed: ${random_seed}" + echo "Biased functions: $(< /tmp/signatures/presto_bias_functions)" ./velox_expression_fuzzer_test \ --seed ${random_seed} \ --lazy_vector_generation_ratio 0.2 \ @@ -570,8 +654,8 @@ jobs: && echo -e "\n\nPresto Fuzzer run finished successfully." - name: Archive Presto Bias expression production artifacts - if: ${{ !cancelled() }} - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + if: ${{ always() }} + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: presto-bias-fuzzer-failure-artifacts path: | @@ -579,40 +663,48 @@ jobs: spark-aggregate-fuzzer-run: name: Spark Aggregate Fuzzer - runs-on: ubuntu-latest + runs-on: 4-core-ubuntu container: ghcr.io/facebookincubator/velox-dev:spark-server needs: compile - timeout-minutes: 60 + timeout-minutes: 30 steps: - name: Download spark aggregation fuzzer - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: name: spark_aggregation_fuzzer + - name: Download libvelox + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: libvelox + - name: Download fuzzer scripts + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: fuzzer-scripts + - name: Run Spark Aggregate Fuzzer run: | bash /opt/start-spark.sh - # Sleep for 60 seconds to allow Spark server to start. - sleep 60 - mkdir -p /tmp/spark_aggregate_fuzzer_repro/logs/ - chmod -R 777 /tmp/spark_aggregate_fuzzer_repro - chmod +x spark_aggregation_fuzzer_test - random_seed=${RANDOM} - echo "Random seed: ${random_seed}" - ./spark_aggregation_fuzzer_test \ - --seed ${random_seed} \ + # Sleep for 120 seconds to allow Spark server to start. + sleep 120 + export LD_LIBRARY_PATH=$(pwd):$LD_LIBRARY_PATH + chmod +x spark_aggregation_fuzzer_test run-fuzzer-parallel.sh + ./run-fuzzer-parallel.sh $NUM_FUZZER_INSTANCES /tmp/spark_aggregate_fuzzer_repro \ + ./spark_aggregation_fuzzer_test \ + --seed __SEED__ \ --duration_sec $DURATION \ --enable_sorted_aggregations=false \ + --enable_streaming_aggregations=false \ --minloglevel=0 \ --stderrthreshold=2 \ - --log_dir=/tmp/spark_aggregate_fuzzer_repro/logs \ - --repro_persist_path=/tmp/spark_aggregate_fuzzer_repro \ - && echo -e "\n\nSpark Aggregation Fuzzer run finished successfully." + --log_dir=__LOG_DIR__ \ + --repro_persist_path=__REPRO_DIR__ \ + --table_name_prefix=i__INSTANCE_ID__ - name: Archive Spark aggregate production artifacts - if: ${{ !cancelled() }} - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + if: ${{ always() }} + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: spark-agg-fuzzer-failure-artifacts path: | @@ -628,12 +720,21 @@ jobs: steps: - name: Download spark expression fuzzer - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: name: spark_expression_fuzzer + - name: Download libvelox + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: libvelox + - name: Download fuzzer scripts + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: fuzzer-scripts + - name: Download Signatures - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: name: signatures path: /tmp/signatures @@ -643,9 +744,11 @@ jobs: ls /tmp/signatures mkdir -p /tmp/spark_bias_fuzzer_repro/logs/ chmod -R 777 /tmp/spark_bias_fuzzer_repro + export LD_LIBRARY_PATH=$(pwd):$LD_LIBRARY_PATH chmod +x spark_expression_fuzzer_test random_seed=${RANDOM} echo "Random seed: ${random_seed}" + echo "Biased functions: $(< /tmp/signatures/spark_bias_functions)" ./spark_expression_fuzzer_test \ --seed ${random_seed} \ --duration_sec $DURATION \ @@ -658,8 +761,8 @@ jobs: && echo -e "\n\nSpark Fuzzer run finished successfully." - name: Archive Spark expression production artifacts - if: ${{ !cancelled() }} - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + if: ${{ always() }} + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: spark-fuzzer-failure-artifacts path: | @@ -668,26 +771,33 @@ jobs: spark-fuzzer: name: Spark Fuzzer if: ${{ needs.compile.outputs.spark_bias != 'true' }} - runs-on: ubuntu-latest + runs-on: 4-core-ubuntu container: ghcr.io/facebookincubator/velox-dev:centos9 needs: compile - timeout-minutes: 120 + timeout-minutes: 30 steps: - name: Download spark expression fuzzer - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: name: spark_expression_fuzzer + - name: Download libvelox + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: libvelox + - name: Download fuzzer scripts + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: fuzzer-scripts + - name: Run Spark Expression Fuzzer run: | - mkdir -p /tmp/spark_fuzzer_repro/logs/ - chmod -R 777 /tmp/spark_fuzzer_repro - chmod +x spark_expression_fuzzer_test - random_seed=${RANDOM} - echo "Random seed: ${random_seed}" - ./spark_expression_fuzzer_test \ - --seed ${random_seed} \ + export LD_LIBRARY_PATH=$(pwd):$LD_LIBRARY_PATH + chmod +x spark_expression_fuzzer_test run-fuzzer-parallel.sh + ./run-fuzzer-parallel.sh $NUM_FUZZER_INSTANCES /tmp/spark_fuzzer_repro \ + ./spark_expression_fuzzer_test \ + --seed __SEED__ \ --enable_variadic_signatures \ --lazy_vector_generation_ratio 0.2 \ --velox_fuzzer_enable_column_reuse \ @@ -696,16 +806,16 @@ jobs: --retry_with_try \ --enable_dereference \ --velox_fuzzer_enable_decimal_type \ + --special_forms="and,or,cast,coalesce" \ --duration_sec $DURATION \ --minloglevel=0 \ --stderrthreshold=2 \ - --log_dir=/tmp/spark_fuzzer_repro/logs \ - --repro_persist_path=/tmp/spark_fuzzer_repro \ - && echo -e "\n\nSpark Fuzzer run finished successfully." + --log_dir=__LOG_DIR__ \ + --repro_persist_path=__REPRO_DIR__ - name: Archive Spark expression production artifacts - if: ${{ !cancelled() }} - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + if: ${{ always() }} + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: spark-fuzzer-failure-artifacts path: | @@ -713,22 +823,31 @@ jobs: presto-java-join-fuzzer-run: name: Join Fuzzer - runs-on: ubuntu-latest + runs-on: 4-core-ubuntu container: ghcr.io/facebookincubator/velox-dev:presto-java needs: compile - timeout-minutes: 120 + timeout-minutes: 30 env: CCACHE_DIR: ${{ github.workspace }}/ccache/ LINUX_DISTRO: centos steps: - name: Download join fuzzer - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: name: join + - name: Download libvelox + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: libvelox + - name: Download fuzzer scripts + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: fuzzer-scripts + - name: Checkout Repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: path: velox ref: ${{ inputs.ref }} @@ -744,29 +863,27 @@ jobs: cd velox cp ./scripts/ci/presto/etc/hive.properties $PRESTO_HOME/etc/catalog ls -lR $PRESTO_HOME/etc - $PRESTO_HOME/bin/launcher run -v > /tmp/server.log 2>&1 & + $PRESTO_HOME/bin/launcher run > /tmp/server.log 2>&1 & # Sleep for 60 seconds to allow Presto server to start. sleep 60 /opt/presto-cli --server 127.0.0.1:8080 --execute 'CREATE SCHEMA IF NOT EXISTS hive.tpch;' cd - - mkdir -p /tmp/join_fuzzer_repro/logs/ - chmod -R 777 /tmp/join_fuzzer_repro - chmod +x velox_join_fuzzer - random_seed=${RANDOM} - echo "Random seed: ${random_seed}" - ./velox_join_fuzzer \ - --seed ${random_seed} \ + export LD_LIBRARY_PATH=$(pwd):$LD_LIBRARY_PATH + chmod +x velox_join_fuzzer run-fuzzer-parallel.sh + ./run-fuzzer-parallel.sh $NUM_FUZZER_INSTANCES /tmp/join_fuzzer_repro \ + ./velox_join_fuzzer \ + --seed __SEED__ \ --duration_sec $DURATION \ --minloglevel=0 \ --stderrthreshold=2 \ - --log_dir=/tmp/join_fuzzer_repro/logs \ + --log_dir=__LOG_DIR__ \ --presto_url=http://127.0.0.1:8080 \ - --req_timeout_ms=2000 \ - && echo -e "\n\nJoin fuzzer run finished successfully." + --req_timeout_ms=10000 \ + --table_name_prefix=i__INSTANCE_ID__ - name: Archive join production artifacts - if: ${{ !cancelled() }} - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + if: ${{ always() }} + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: presto-sot-join-fuzzer-failure-artifacts path: | @@ -778,34 +895,42 @@ jobs: runs-on: ubuntu-latest container: ghcr.io/facebookincubator/velox-dev:centos9 needs: compile - timeout-minutes: 120 + timeout-minutes: 30 steps: - name: Download exchange fuzzer - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: name: exchange + - name: Download libvelox + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: libvelox + - name: Download fuzzer scripts + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: fuzzer-scripts + - name: Run exchange Fuzzer run: | cat /proc/sys/vm/max_map_count - mkdir -p /tmp/exchange_fuzzer_repro/logs/ - chmod -R 777 /tmp/exchange_fuzzer_repro - chmod +x velox_exchange_fuzzer - random_seed=${RANDOM} - echo "Random seed: ${random_seed}" - ./velox_exchange_fuzzer \ - --seed ${random_seed} \ - --duration_sec $DURATION \ + export LD_LIBRARY_PATH=$(pwd):$LD_LIBRARY_PATH + chmod +x velox_exchange_fuzzer run-fuzzer-parallel.sh + # Exchange fuzzer is multi-threaded and memory-heavy; run 1 instance + # on ubuntu-latest with an extended duration. + ./run-fuzzer-parallel.sh 1 /tmp/exchange_fuzzer_repro \ + ./velox_exchange_fuzzer \ + --seed __SEED__ \ + --duration_sec 480 \ --minloglevel=0 \ --stderrthreshold=2 \ - --log_dir=/tmp/exchange_fuzzer_repro/logs \ - --repro_path=/tmp/exchange_fuzzer_repro \ - && echo -e "\n\Exchange fuzzer run finished successfully." + --log_dir=__LOG_DIR__ \ + --repro_path=__REPRO_DIR__ - name: Archive Exchange production artifacts - if: ${{ !cancelled() }} - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + if: ${{ always() }} + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: exchange-fuzzer-failure-artifacts path: | @@ -813,10 +938,10 @@ jobs: presto-java-row-number-fuzzer-run: name: RowNumber Fuzzer - runs-on: ubuntu-latest + runs-on: 4-core-ubuntu container: ghcr.io/facebookincubator/velox-dev:presto-java needs: compile - timeout-minutes: 120 + timeout-minutes: 30 env: CCACHE_DIR: ${{ github.workspace }}/ccache/ LINUX_DISTRO: centos @@ -824,12 +949,21 @@ jobs: steps: - name: Download row number fuzzer - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: name: row_number + - name: Download libvelox + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: libvelox + - name: Download fuzzer scripts + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: fuzzer-scripts + - name: Checkout Repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: path: velox ref: ${{ inputs.ref }} @@ -845,29 +979,27 @@ jobs: cd velox cp ./scripts/ci/presto/etc/hive.properties $PRESTO_HOME/etc/catalog ls -lR $PRESTO_HOME/etc - $PRESTO_HOME/bin/launcher run -v > /tmp/server.log 2>&1 & + $PRESTO_HOME/bin/launcher run > /tmp/server.log 2>&1 & # Sleep for 60 seconds to allow Presto server to start. sleep 60 /opt/presto-cli --server 127.0.0.1:8080 --execute 'CREATE SCHEMA IF NOT EXISTS hive.tpch;' cd - - cat /proc/sys/vm/max_map_count - mkdir -p /tmp/row_fuzzer_repro/logs/ - chmod -R 777 /tmp/row_fuzzer_repro - chmod +x velox_row_number_fuzzer - random_seed=${RANDOM} - echo "Random seed: ${random_seed}" - ./velox_row_number_fuzzer \ - --seed ${random_seed} \ + export LD_LIBRARY_PATH=$(pwd):$LD_LIBRARY_PATH + chmod +x velox_row_number_fuzzer run-fuzzer-parallel.sh + ./run-fuzzer-parallel.sh $NUM_FUZZER_INSTANCES /tmp/row_fuzzer_repro \ + ./velox_row_number_fuzzer \ + --seed __SEED__ \ --duration_sec $DURATION \ --minloglevel=0 \ --stderrthreshold=2 \ - --log_dir=/tmp/row_fuzzer_repro/logs \ + --log_dir=__LOG_DIR__ \ --presto_url=http://127.0.0.1:8080 \ - && echo -e "\n\Row number fuzzer run finished successfully." + --req_timeout_ms=10000 \ + --table_name_prefix=i__INSTANCE_ID__ - name: Archive row number production artifacts - if: ${{ !cancelled() }} - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + if: ${{ always() }} + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: row-fuzzer-failure-artifacts path: | @@ -875,10 +1007,10 @@ jobs: presto-java-topn-row-number-fuzzer-run: name: TopNRowNumber Fuzzer - runs-on: ubuntu-latest + runs-on: 4-core-ubuntu container: ghcr.io/facebookincubator/velox-dev:presto-java needs: compile - timeout-minutes: 120 + timeout-minutes: 30 env: CCACHE_DIR: ${{ github.workspace }}/ccache/ LINUX_DISTRO: centos @@ -886,12 +1018,21 @@ jobs: steps: - name: Download topn row number fuzzer - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: name: topn_row_number + - name: Download libvelox + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: libvelox + - name: Download fuzzer scripts + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: fuzzer-scripts + - name: Checkout Repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: path: velox ref: ${{ inputs.ref }} @@ -907,30 +1048,27 @@ jobs: cd velox cp ./scripts/ci/presto/etc/hive.properties $PRESTO_HOME/etc/catalog ls -lR $PRESTO_HOME/etc - $PRESTO_HOME/bin/launcher run -v > /tmp/server.log 2>&1 & + $PRESTO_HOME/bin/launcher run > /tmp/server.log 2>&1 & # Sleep for 60 seconds to allow Presto server to start. sleep 60 /opt/presto-cli --server 127.0.0.1:8080 --execute 'CREATE SCHEMA IF NOT EXISTS hive.tpch;' cd - - cat /proc/sys/vm/max_map_count - mkdir -p /tmp/topn_row_fuzzer_repro/logs/ - chmod -R 777 /tmp/topn_row_fuzzer_repro - chmod +x velox_topn_row_number_fuzzer - random_seed=${RANDOM} - echo "Random seed: ${random_seed}" - ./velox_topn_row_number_fuzzer \ - --seed ${random_seed} \ + export LD_LIBRARY_PATH=$(pwd):$LD_LIBRARY_PATH + chmod +x velox_topn_row_number_fuzzer run-fuzzer-parallel.sh + ./run-fuzzer-parallel.sh $NUM_FUZZER_INSTANCES /tmp/topn_row_fuzzer_repro \ + ./velox_topn_row_number_fuzzer \ + --seed __SEED__ \ --duration_sec $DURATION \ --minloglevel=0 \ --stderrthreshold=2 \ --req_timeout_ms 50000 \ - --log_dir=/tmp/topn_row_fuzzer_repro/logs \ + --log_dir=__LOG_DIR__ \ --presto_url=http://127.0.0.1:8080 \ - && echo -e "\n\TopN Row number fuzzer run finished successfully." + --table_name_prefix=i__INSTANCE_ID__ - name: Archive topn row number production artifacts - if: ${{ !cancelled() }} - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + if: ${{ always() }} + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: topn-row-fuzzer-failure-artifacts path: | @@ -938,35 +1076,43 @@ jobs: cache-fuzzer-run: name: Cache Fuzzer - runs-on: ubuntu-latest + runs-on: 4-core-ubuntu container: ghcr.io/facebookincubator/velox-dev:centos9 needs: compile - timeout-minutes: 120 + timeout-minutes: 30 # Temporarily disable on PRs till flakiness is fixed #12167 if: ${{ github.event_name != 'pull_request' }} steps: - name: Download cache fuzzer - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: name: cache_fuzzer + - name: Download libvelox + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: libvelox + - name: Download fuzzer scripts + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: fuzzer-scripts + - name: Run Cache Fuzzer run: | - mkdir -p /tmp/cache_fuzzer/logs/ - chmod -R 777 /tmp/cache_fuzzer - chmod +x velox_cache_fuzzer - ./velox_cache_fuzzer \ - --seed ${RANDOM} \ + export LD_LIBRARY_PATH=$(pwd):$LD_LIBRARY_PATH + chmod +x velox_cache_fuzzer run-fuzzer-parallel.sh + ./run-fuzzer-parallel.sh $NUM_FUZZER_INSTANCES /tmp/cache_fuzzer \ + ./velox_cache_fuzzer \ + --seed __SEED__ \ --duration_sec $DURATION \ --minloglevel=0 \ --stderrthreshold=2 \ - --log_dir=/tmp/cache_fuzzer/logs \ - && echo -e "\n\Cache fuzzer run finished successfully." + --log_dir=__LOG_DIR__ - name: Archive Cache production artifacts - if: ${{ !cancelled() }} - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + if: ${{ always() }} + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: cache-fuzzer-logs path: | @@ -974,33 +1120,41 @@ jobs: table-evolution-fuzzer-run: name: Table Evolution Fuzzer - runs-on: ubuntu-latest + runs-on: 4-core-ubuntu container: ghcr.io/facebookincubator/velox-dev:centos9 needs: compile - timeout-minutes: 120 + timeout-minutes: 30 steps: - name: Download table evolution fuzzer - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: name: table_evolution_fuzzer + - name: Download libvelox + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: libvelox + - name: Download fuzzer scripts + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: fuzzer-scripts + - name: Run Table Evolution Fuzzer run: | - mkdir -p /tmp/table_evolution_fuzzer_test/logs/ - chmod -R 777 /tmp/table_evolution_fuzzer_test - chmod +x velox_table_evolution_fuzzer_test - ./velox_table_evolution_fuzzer_test \ - --seed ${RANDOM} \ + export LD_LIBRARY_PATH=$(pwd):$LD_LIBRARY_PATH + chmod +x velox_table_evolution_fuzzer_test run-fuzzer-parallel.sh + ./run-fuzzer-parallel.sh $NUM_FUZZER_INSTANCES /tmp/table_evolution_fuzzer_test \ + ./velox_table_evolution_fuzzer_test \ + --seed __SEED__ \ --table_evolution_fuzzer_duration_sec $DURATION \ --minloglevel=0 \ --stderrthreshold=2 \ - --log_dir=/tmp/table_evolution_fuzzer_test/logs \ - && echo -e "\n\Table evolution fuzzer run finished successfully." + --log_dir=__LOG_DIR__ - name: Archive table evolution production artifacts - if: ${{ !cancelled() }} - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + if: ${{ always() }} + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: table-evolution-fuzzer-test-logs path: | @@ -1008,56 +1162,73 @@ jobs: memory-arbitration-fuzzer-run: name: Memory Arbitration Fuzzer - runs-on: ubuntu-latest + runs-on: 4-core-ubuntu container: ghcr.io/facebookincubator/velox-dev:centos9 needs: compile - timeout-minutes: 120 + timeout-minutes: 30 steps: - name: Download memory arbitration fuzzer - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: name: memory_arbitration_fuzzer + - name: Download libvelox + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: libvelox + - name: Download fuzzer scripts + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: fuzzer-scripts + - name: Run Memory Arbitration Fuzzer run: | - mkdir -p /tmp/memory_arbitration_fuzzer/logs/ - chmod -R 777 /tmp/memory_arbitration_fuzzer - chmod +x velox_memory_arbitration_fuzzer - ./velox_memory_arbitration_fuzzer \ - --seed ${RANDOM} \ + export LD_LIBRARY_PATH=$(pwd):$LD_LIBRARY_PATH + chmod +x velox_memory_arbitration_fuzzer run-fuzzer-parallel.sh + ./run-fuzzer-parallel.sh $NUM_FUZZER_INSTANCES /tmp/memory_arbitration_fuzzer \ + ./velox_memory_arbitration_fuzzer \ + --seed __SEED__ \ --duration_sec $DURATION \ --minloglevel=0 \ --stderrthreshold=2 \ - --log_dir=/tmp/memory_arbitration_fuzzer/logs \ - && echo -e "\n\Memory arbitration fuzzer run finished successfully." + --log_dir=__LOG_DIR__ - name: Archive memory arbitration production artifacts - if: ${{ !cancelled() }} - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + if: ${{ always() }} + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: memory-arbitration-fuzzer-test-logs path: | - /tmp/memory_arbitration_fuzzer_test + /tmp/memory_arbitration_fuzzer presto-java-aggregation-fuzzer-run: name: Aggregation Fuzzer with Presto as source of truth needs: compile - runs-on: ubuntu-latest + runs-on: 4-core-ubuntu container: ghcr.io/facebookincubator/velox-dev:presto-java - timeout-minutes: 120 + timeout-minutes: 30 env: CCACHE_DIR: ${{ github.workspace }}/ccache/ LINUX_DISTRO: centos steps: - name: Download aggregation fuzzer - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: name: aggregation + - name: Download libvelox + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: libvelox + - name: Download fuzzer scripts + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: fuzzer-scripts + - name: Checkout Repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: path: velox ref: ${{ inputs.ref }} @@ -1073,30 +1244,29 @@ jobs: cd velox cp ./scripts/ci/presto/etc/hive.properties $PRESTO_HOME/etc/catalog ls -lR $PRESTO_HOME/etc - $PRESTO_HOME/bin/launcher run -v > /tmp/server.log 2>&1 & + $PRESTO_HOME/bin/launcher run > /tmp/server.log 2>&1 & # Sleep for 60 seconds to allow Presto server to start. sleep 60 /opt/presto-cli --server 127.0.0.1:8080 --execute 'CREATE SCHEMA IF NOT EXISTS hive.tpch;' cd - - mkdir -p /tmp/aggregate_fuzzer_repro/logs/ - chmod -R 777 /tmp/aggregate_fuzzer_repro - chmod +x velox_aggregation_fuzzer_test - random_seed=${RANDOM} - echo "Random seed: ${random_seed}" - ./velox_aggregation_fuzzer_test \ - --seed ${random_seed} \ + export LD_LIBRARY_PATH=$(pwd):$LD_LIBRARY_PATH + chmod +x velox_aggregation_fuzzer_test run-fuzzer-parallel.sh + ./run-fuzzer-parallel.sh $NUM_FUZZER_INSTANCES /tmp/aggregate_fuzzer_repro \ + ./velox_aggregation_fuzzer_test \ + --seed __SEED__ \ --duration_sec $DURATION \ --minloglevel=0 \ --stderrthreshold=2 \ - --log_dir=/tmp/aggregate_fuzzer_repro/logs \ - --repro_persist_path=/tmp/aggregate_fuzzer_repro \ + --log_dir=__LOG_DIR__ \ + --repro_persist_path=__REPRO_DIR__ \ --enable_sorted_aggregations=true \ --presto_url=http://127.0.0.1:8080 \ - && echo -e "\n\nAggregation fuzzer run finished successfully." + --req_timeout_ms=10000 \ + --table_name_prefix=i__INSTANCE_ID__ - name: Archive aggregate production artifacts - if: ${{ !cancelled() }} - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + if: ${{ always() }} + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: presto-sot-aggregate-fuzzer-failure-artifacts path: | @@ -1116,12 +1286,21 @@ jobs: steps: - name: Download presto expression fuzzer - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: name: presto + - name: Download libvelox + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: libvelox + - name: Download fuzzer scripts + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: fuzzer-scripts + - name: Checkout Repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: path: velox ref: ${{ inputs.ref }} @@ -1133,7 +1312,7 @@ jobs: run: git config --global --add safe.directory ${GITHUB_WORKSPACE}/velox - name: Download Signatures - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: name: signatures path: /tmp/signatures @@ -1143,13 +1322,14 @@ jobs: cd velox cp ./scripts/ci/presto/etc/hive.properties $PRESTO_HOME/etc/catalog ls -lR $PRESTO_HOME/etc - $PRESTO_HOME/bin/launcher run -v > /tmp/server.log 2>&1 & + $PRESTO_HOME/bin/launcher run > /tmp/server.log 2>&1 & # Sleep for 60 seconds to allow Presto server to start. sleep 60 /opt/presto-cli --server 127.0.0.1:8080 --execute 'CREATE SCHEMA IF NOT EXISTS hive.tpch;' cd - mkdir -p /tmp/presto_only_bias_function_fuzzer_repro/logs/ chmod -R 777 /tmp/presto_only_bias_function_fuzzer_repro + export LD_LIBRARY_PATH=$(pwd):$LD_LIBRARY_PATH chmod +x velox_expression_fuzzer_test echo "Biased functions: $(< /tmp/signatures/presto_bias_functions)" # Convert the list of function names with tickets into a list of function names only. @@ -1185,8 +1365,8 @@ jobs: --repro_persist_path=/tmp/presto_only_bias_function_fuzzer_repro \ && echo -e "\n\nPresto Fuzzer run finished successfully." - name: Archive Presto only-bias-function expression fuzzer production artifacts - if: ${{ !cancelled() }} - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + if: ${{ always() }} + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: presto-only-bias-function-fuzzer-failure-artifacts path: | @@ -1205,12 +1385,21 @@ jobs: steps: - name: Download aggregation fuzzer - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: name: aggregation + - name: Download libvelox + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: libvelox + - name: Download fuzzer scripts + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: fuzzer-scripts + - name: Checkout Repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: path: velox ref: ${{ inputs.ref }} @@ -1222,7 +1411,7 @@ jobs: run: git config --global --add safe.directory ${GITHUB_WORKSPACE}/velox - name: Download Signatures - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: name: signatures path: /tmp/signatures @@ -1232,13 +1421,14 @@ jobs: cd velox cp ./scripts/ci/presto/etc/hive.properties $PRESTO_HOME/etc/catalog ls -lR $PRESTO_HOME/etc - $PRESTO_HOME/bin/launcher run -v > /tmp/server.log 2>&1 & + $PRESTO_HOME/bin/launcher run > /tmp/server.log 2>&1 & # Sleep for 60 seconds to allow Presto server to start. sleep 60 /opt/presto-cli --server 127.0.0.1:8080 --execute 'CREATE SCHEMA IF NOT EXISTS hive.tpch;' cd - mkdir -p /tmp/aggregate_fuzzer_repro/logs/ chmod -R 777 /tmp/aggregate_fuzzer_repro + export LD_LIBRARY_PATH=$(pwd):$LD_LIBRARY_PATH chmod +x velox_aggregation_fuzzer_test echo "signatures folder" ls /tmp/signatures/ @@ -1259,8 +1449,8 @@ jobs: && echo -e "\n\nAggregation fuzzer run finished successfully." - name: Archive bias aggregate production artifacts - if: ${{ !cancelled() }} - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + if: ${{ always() }} + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: presto-bias-sot-aggregate-fuzzer-failure-artifacts path: | @@ -1269,12 +1459,12 @@ jobs: surface-signature-errors: name: Signature Changes - if: ${{ github.event_name != 'schedule' }} + if: ${{ github.event_name != 'schedule' && needs.compile.outputs.signatures_checked != 'false' }} needs: compile runs-on: ubuntu-latest steps: - name: Download Signatures - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: name: signatures path: /tmp/signatures @@ -1294,21 +1484,30 @@ jobs: presto-java-window-fuzzer-run: name: Window Fuzzer with Presto as source of truth needs: compile - runs-on: ubuntu-latest + runs-on: 4-core-ubuntu container: ghcr.io/facebookincubator/velox-dev:presto-java - timeout-minutes: 120 + timeout-minutes: 30 env: CCACHE_DIR: ${{ github.workspace }}/ccache/ LINUX_DISTRO: centos steps: - name: Download window fuzzer - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: name: window + - name: Download libvelox + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: libvelox + - name: Download fuzzer scripts + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: fuzzer-scripts + - name: Checkout Repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: path: velox ref: ${{ inputs.ref }} @@ -1324,31 +1523,30 @@ jobs: cd velox cp ./scripts/ci/presto/etc/hive.properties $PRESTO_HOME/etc/catalog ls -lR $PRESTO_HOME/etc - $PRESTO_HOME/bin/launcher run -v > /tmp/server.log 2>&1 & + $PRESTO_HOME/bin/launcher run > /tmp/server.log 2>&1 & # Sleep for 60 seconds to allow Presto server to start. sleep 60 /opt/presto-cli --server 127.0.0.1:8080 --execute 'CREATE SCHEMA IF NOT EXISTS hive.tpch;' cd - - mkdir -p /tmp/window_fuzzer_repro/logs/ - chmod -R 777 /tmp/window_fuzzer_repro - chmod +x velox_window_fuzzer_test - random_seed=${RANDOM} - echo "Random seed: ${random_seed}" - ./velox_window_fuzzer_test \ - --seed ${random_seed} \ + export LD_LIBRARY_PATH=$(pwd):$LD_LIBRARY_PATH + chmod +x velox_window_fuzzer_test run-fuzzer-parallel.sh + ./run-fuzzer-parallel.sh $NUM_FUZZER_INSTANCES /tmp/window_fuzzer_repro \ + ./velox_window_fuzzer_test \ + --seed __SEED__ \ --duration_sec $DURATION \ --batch_size=50 \ --minloglevel=0 \ --stderrthreshold=2 \ - --log_dir=/tmp/window_fuzzer_repro/logs \ - --repro_persist_path=/tmp/window_fuzzer_repro \ + --log_dir=__LOG_DIR__ \ + --repro_persist_path=__REPRO_DIR__ \ --enable_window_reference_verification \ --presto_url=http://127.0.0.1:8080 \ - && echo -e "\n\nWindow fuzzer run finished successfully." + --req_timeout_ms=10000 \ + --table_name_prefix=i__INSTANCE_ID__ - name: Archive window production artifacts - if: ${{ !cancelled() }} - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + if: ${{ always() }} + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: presto-sot-window-fuzzer-failure-artifacts path: | @@ -1358,21 +1556,30 @@ jobs: presto-java-writer-fuzzer-run: name: Writer Fuzzer with Presto as source of truth needs: compile - runs-on: ubuntu-latest + runs-on: 4-core-ubuntu container: ghcr.io/facebookincubator/velox-dev:presto-java - timeout-minutes: 120 + timeout-minutes: 30 env: CCACHE_DIR: ${{ github.workspace }}/ccache/ LINUX_DISTRO: centos steps: - name: Download writer fuzzer - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: name: writer + - name: Download libvelox + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: libvelox + - name: Download fuzzer scripts + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: fuzzer-scripts + - name: Checkout Repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: path: velox ref: ${{ inputs.ref }} @@ -1390,31 +1597,29 @@ jobs: ls -lR $PRESTO_HOME/etc echo "jvm config content:" cat $PRESTO_HOME/etc/jvm.config - $PRESTO_HOME/bin/launcher run -v > /tmp/server.log 2>&1 & + $PRESTO_HOME/bin/launcher run > /tmp/server.log 2>&1 & ls -lR /var/log # Sleep for 60 seconds to allow Presto server to start. sleep 60 /opt/presto-cli --version /opt/presto-cli --server 127.0.0.1:8080 --execute 'CREATE SCHEMA IF NOT EXISTS hive.tpch;' cd - - mkdir -p /tmp/writer_fuzzer_repro/logs/ - chmod -R 777 /tmp/writer_fuzzer_repro - chmod +x velox_writer_fuzzer_test - random_seed=${RANDOM} - echo "Random seed: ${random_seed}" - ./velox_writer_fuzzer_test \ - --seed ${random_seed} \ + export LD_LIBRARY_PATH=$(pwd):$LD_LIBRARY_PATH + chmod +x velox_writer_fuzzer_test run-fuzzer-parallel.sh + ./run-fuzzer-parallel.sh $NUM_FUZZER_INSTANCES /tmp/writer_fuzzer_repro \ + ./velox_writer_fuzzer_test \ + --seed __SEED__ \ --duration_sec $DURATION \ --minloglevel=0 \ --stderrthreshold=2 \ --req_timeout_ms 60000 \ - --log_dir=/tmp/writer_fuzzer_repro/logs \ + --log_dir=__LOG_DIR__ \ --presto_url=http://127.0.0.1:8080 \ - && echo -e "\n\Writer fuzzer run finished successfully." + --table_name_prefix=i__INSTANCE_ID__ - name: Archive writer production artifacts - if: ${{ !cancelled() }} - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + if: ${{ always() }} + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: presto-sot-writer-fuzzer-failure-artifacts path: | @@ -1422,6 +1627,48 @@ jobs: /tmp/server.log /var/log + spatial-join-fuzzer-run: + name: Spatial Join Fuzzer + runs-on: 4-core-ubuntu + container: ghcr.io/facebookincubator/velox-dev:centos9 + needs: compile + timeout-minutes: 30 + steps: + + - name: Download spatial join fuzzer + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: spatial_join_fuzzer + + - name: Download libvelox + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: libvelox + - name: Download fuzzer scripts + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + with: + name: fuzzer-scripts + + - name: Run Spatial Join Fuzzer + run: | + export LD_LIBRARY_PATH=$(pwd):$LD_LIBRARY_PATH + chmod +x velox_spatial_join_fuzzer run-fuzzer-parallel.sh + ./run-fuzzer-parallel.sh $NUM_FUZZER_INSTANCES /tmp/spatial_join_fuzzer_repro \ + ./velox_spatial_join_fuzzer \ + --seed __SEED__ \ + --duration_sec $DURATION \ + --minloglevel=0 \ + --stderrthreshold=2 \ + --log_dir=__LOG_DIR__ + + - name: Archive spatial join fuzzer production artifacts + if: ${{ always() }} + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 + with: + name: spatial-join-fuzzer-failure-artifacts + path: | + /tmp/spatial_join_fuzzer_repro + linux-clang: if: ${{ (github.event_name == 'schedule') || (github.event_name == 'workflow_dispatch') }} name: Build with Clang diff --git a/.github/disabled-workflows/tag.yml b/.github/disabled-workflows/tag.yml new file mode 100644 index 00000000000..7710960fc78 --- /dev/null +++ b/.github/disabled-workflows/tag.yml @@ -0,0 +1,76 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +name: Weekly Date Tag + +on: + schedule: + # Runs every Friday at 09:23 UTC, using a odd time avoids jobs + # from being skipped during peak hours. + - cron: 23 9 * * 5 + workflow_dispatch: + inputs: + commit: + description: Which commit to tag + required: true + + patch-version: + description: Additional version component + required: false + default: '00' + +permissions: {} + +jobs: + create-date-tag: + runs-on: ubuntu-latest + permissions: + contents: write # required to push tag + checks: read + env: + COMMIT: ${{ inputs.commit || github.sha }} + + steps: + - name: Checkout repository + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + fetch-depth: 0 + persist-credentials: true + + - name: Set up Git + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + + - name: Check CI Status + # Allow manual triggered workflows to circumvent the check + if: ${{ github.event_name != 'workflow_dispatch' }} + env: + GH_TOKEN: ${{ github.token }} + run: | + # This `gh` invocation returns a json array with the workflow status like this + # [ { "status": "completed" } ] + # If the workflow wasn't success full the array will be empty which we check for + # using `grep` -q = set exit code -v = invert match + gh run list --commit "$COMMIT" \ + --workflow "Linux Build using GCC" \ + --status success --json status | grep -qv '\[\]' + + - name: Create and push date-version tag + env: + PATCH_VERSION: ${{ inputs.patch-version || '00' }} + run: | + MESSAGE="This is convenience tag, not a full release." + DATE_TAG=$(date -u +"v%Y.%m.%d.$PATCH_VERSION") + git tag -m "$MESSAGE" "$DATE_TAG" "$COMMIT" + git push origin "$DATE_TAG" diff --git a/.github/disabled-workflows/ubuntu-bundled-deps.yml b/.github/disabled-workflows/ubuntu-bundled-deps.yml new file mode 100644 index 00000000000..801e93d7d14 --- /dev/null +++ b/.github/disabled-workflows/ubuntu-bundled-deps.yml @@ -0,0 +1,110 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Tests the BUNDLED dependency resolution path on a plain Ubuntu runner +# (no container). Triggered on PRs that change dependency scripts or +# CMake modules, and on a nightly schedule. + +name: Ubuntu Bundled Dependencies + +on: + schedule: + - cron: 0 5 * * * + workflow_dispatch: {} + pull_request: + paths: + - scripts/setup-*.sh + - scripts/setup-common.sh + - scripts/setup-versions.sh + - scripts/setup-helper-functions.sh + - CMake/** + - CMakeLists.txt + - .github/workflows/ubuntu-bundled-deps.yml + +permissions: + contents: read + +concurrency: + group: ${{ github.workflow }}-${{ github.repository }}-${{ github.head_ref || github.sha }} + cancel-in-progress: true + +jobs: + ubuntu-debug-bundled-deps: + runs-on: 16-core-ubuntu + if: ${{ github.repository == 'facebookincubator/velox' }} + name: Ubuntu debug with bundled dependencies + env: + CCACHE_DIR: ${{ github.workspace }}/ccache + defaults: + run: + shell: bash + working-directory: velox + steps: + + - name: Get Ccache Stash + uses: apache/infrastructure-actions/stash/restore@3354c1565d4b0e335b78a76aedd82153a9e144d4 + with: + path: ${{ env.CCACHE_DIR }} + key: ccache-ubuntu-debug-bundled-deps + + - name: Ensure Stash Dirs Exists + working-directory: ${{ github.workspace }} + run: | + mkdir -p "$CCACHE_DIR" + + - uses: actions/checkout@v5 + with: + path: velox + persist-credentials: false + + - name: Install Dependencies + run: | + source scripts/setup-ubuntu.sh && install_apt_deps && install_faiss_deps + + - name: Clear CCache Statistics + run: | + ccache -sz + + - name: Make Debug Build + env: + VELOX_DEPENDENCY_SOURCE: BUNDLED + ICU_SOURCE: SYSTEM + MAKEFLAGS: NUM_THREADS=16 MAX_HIGH_MEM_JOBS=4 MAX_LINK_JOBS=2 + run: | + EXTRA_CMAKE_FLAGS=( + "-DCMAKE_LINK_LIBRARIES_STRATEGY=REORDER_FREELY" + "-DVELOX_ENABLE_BENCHMARKS=ON" + "-DVELOX_ENABLE_EXAMPLES=ON" + "-DVELOX_ENABLE_ARROW=ON" + "-DVELOX_ENABLE_GEO=ON" + "-DVELOX_ENABLE_PARQUET=ON" + "-DVELOX_MONO_LIBRARY=ON" + "-DVELOX_BUILD_SHARED=ON" + "-DVELOX_ENABLE_FAISS=ON" + "-DVELOX_ENABLE_REMOTE_FUNCTIONS=ON" + ) + make debug + + - name: CCache after + run: | + ccache -vs + + - uses: apache/infrastructure-actions/stash/save@3354c1565d4b0e335b78a76aedd82153a9e144d4 + with: + path: ${{ env.CCACHE_DIR }} + key: ccache-ubuntu-debug-bundled-deps + + - name: Run Tests + run: | + cd _build/debug && ctest -j 8 --output-on-failure --no-tests=error diff --git a/.github/matchers/clang-tidy.json b/.github/matchers/clang-tidy.json new file mode 100644 index 00000000000..1ae1bffc022 --- /dev/null +++ b/.github/matchers/clang-tidy.json @@ -0,0 +1,17 @@ +{ + "problemMatcher": [ + { + "owner": "clang-tidy", + "pattern": [ + { + "regexp": "^(.*):(\\d+):(\\d+):\\s+(error|warning):\\s+(.*) \\[([a-z0-9,\\-]+)\\]\\s*$", + "file": 1, + "line": 2, + "column": 3, + "severity": 4, + "message": 5 + } + ] + } + ] +} diff --git a/.github/matchers/gcc.json b/.github/matchers/gcc.json new file mode 100644 index 00000000000..899239f8160 --- /dev/null +++ b/.github/matchers/gcc.json @@ -0,0 +1,18 @@ +{ + "problemMatcher": [ + { + "owner": "gcc", + "severity": "error", + "pattern": [ + { + "regexp": "^(.*):(\\d+):(\\d+):\\s+(?:fatal\\s+)?(warning|error):\\s+(.*)$", + "file": 1, + "line": 2, + "column": 3, + "severity": 4, + "message": 5 + } + ] + } + ] +} diff --git a/.github/matchers/shellcheck.json b/.github/matchers/shellcheck.json new file mode 100644 index 00000000000..431d91ddd9f --- /dev/null +++ b/.github/matchers/shellcheck.json @@ -0,0 +1,18 @@ +{ + "problemMatcher": [ + { + "owner": "shellcheck", + "severity": "error", + "pattern": [ + { + "regexp": "^(.*):(\\d+):(\\d+):\\s+(?:fatal\\s+)?(warning|error):\\s+(.*)$", + "file": 1, + "line": 2, + "column": 3, + "severity": 4, + "message": 5 + } + ] + } + ] +} diff --git a/.github/scripts/detect-build-impact.py b/.github/scripts/detect-build-impact.py new file mode 100755 index 00000000000..6d6e31bbc22 --- /dev/null +++ b/.github/scripts/detect-build-impact.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Detect build impact of changed files using the pre-computed dependency graph. + +Early exits: + - velox/experimental/ or velox/external/ changes → full build recommended + (CUDA-only and vendored code outside the dependency graph). + +File resolution: + For each changed file under velox/, resolves to affected CMake targets: + 1. File API exact match (source file directly mapped to a target) + 2. Header scan match (header → source files via g++ -MM → targets) + For headers, both results are unioned: exact match provides the owning + target, header scan provides all targets whose sources include the header. + Files that match neither are reported as unresolved. + +Then computes the transitive reverse dependency closure and identifies the +minimal set of selective build targets. + +Usage: + python detect-build-impact.py \\ + --graph dependency-graph.json \\ + --changed-files changed_files.txt \\ + --build-type release \\ + --output comment.md +""" + +import argparse +import json +import os +from collections import defaultdict, deque + + +def load_graph(graph_path: str) -> dict: + """Load the dependency graph JSON.""" + with open(graph_path) as fp: + return json.load(fp) + + +def resolve_file_to_targets( + file_path: str, + file_to_targets: dict[str, list[str]], + header_to_sources: dict[str, list[str]], +) -> tuple[set[str], str]: + """Resolve a changed file to its affected targets. + + Returns: + A tuple of (set of target names, resolution method string). + """ + # Step 1: File API exact match (source or header directly in a target). + # Step 2: Header scan match (header → source files → targets). + # For headers, union both: exact gives the owning target, scan gives includers. + exact_targets = set(file_to_targets.get(file_path, [])) + scan_targets: set[str] = set() + + if file_path.endswith((".h", ".hpp", ".cuh")): + for source in header_to_sources.get(file_path, []): + scan_targets.update(file_to_targets.get(source, [])) + + combined = exact_targets | scan_targets + if combined: + if exact_targets and scan_targets: + method = "exact+header-scan" + elif exact_targets: + method = "exact" + else: + method = "header-scan" + return combined, method + + return set(), "unresolved" + + +def compute_reverse_deps(target_deps: dict[str, list[str]]) -> dict[str, set[str]]: + """Invert the dependency graph to get reverse dependencies.""" + reverse: dict[str, set[str]] = defaultdict(set) + for target, deps in target_deps.items(): + for dep in deps: + reverse[dep].add(target) + return reverse + + +def compute_transitive_closure( + start_targets: set[str], reverse_deps: dict[str, set[str]] +) -> set[str]: + """BFS from start_targets through reverse_deps to find all affected targets.""" + visited = set() + queue = deque(start_targets) + while queue: + target = queue.popleft() + if target in visited: + continue + visited.add(target) + for dependent in reverse_deps.get(target, set()): + if dependent not in visited: + queue.append(dependent) + return visited + + +def compute_selective_build_targets( + affected: set[str], target_deps: dict[str, list[str]] +) -> set[str]: + """Find the minimal set of root targets that cover all affected targets. + + These are affected targets that are not a dependency of any other + affected target. + """ + depended_upon = set() + for target in affected: + for dep in target_deps.get(target, []): + if dep in affected: + depended_upon.add(dep) + + return affected - depended_upon + + +def generate_comment( + changed_targets: dict[str, dict], + all_affected: set[str], + selective_targets: set[str], + total_targets: int, + build_type: str, + graph_source: str, + unresolved_files: list[str], +) -> str: + """Generate the PR comment markdown.""" + total_affected = len(all_affected) + + lines = [] + lines.append("## Build Impact Analysis\n") + + # Group changed files by target. + target_files: dict[str, list[str]] = defaultdict(list) + for file_path, info in changed_targets.items(): + for target in info["targets"]: + target_files[target].append(os.path.basename(file_path)) + + # Selective build targets. + selective_sorted = sorted(selective_targets) + lines.append( + f"### Selective Build Targets " + f"(building these covers all {total_affected} affected)" + ) + targets_str = " ".join(selective_sorted) + lines.append("```") + lines.append(f"cmake --build _build/{build_type} --target {targets_str}") + lines.append("```") + lines.append("") + + lines.append(f"**Total affected:** {total_affected}/{total_targets} targets") + lines.append("") + + # Unresolved files warning. + if unresolved_files: + lines.append( + f"> **Warning:** {len(unresolved_files)} file(s) could not be " + f"mapped to any target. A full build may be needed." + ) + lines.append(">") + for f in unresolved_files[:10]: + lines.append(f"> - `{f}`") + if len(unresolved_files) > 10: + lines.append(f"> - ... and {len(unresolved_files) - 10} more") + lines.append("") + + # Collapsible affected targets breakdown. + transitive_only = all_affected - set(target_files.keys()) + lines.append("
") + lines.append(f"Affected targets ({total_affected})") + lines.append("") + lines.append(f"#### Directly changed ({len(target_files)})") + lines.append("") + lines.append("| Target | Changed Files |") + lines.append("|--------|--------------|") + for target in sorted(target_files.keys()): + files_list = sorted(set(target_files[target])) + files = ", ".join(files_list[:5]) + if len(files_list) > 5: + files += f", ... (+{len(files_list) - 5} more)" + lines.append(f"| `{target}` | {files} |") + lines.append("") + if transitive_only: + lines.append(f"#### Transitively affected ({len(transitive_only)})") + lines.append("") + for target in sorted(transitive_only): + lines.append(f"- `{target}`") + lines.append("") + lines.append("
") + lines.append("") + + # Footer. + lines.append("---") + lines.append(f"*{graph_source}*") + + return "\n".join(lines) + + +def main(): + parser = argparse.ArgumentParser( + description="Detect build impact of changed files." + ) + parser.add_argument( + "--graph", + required=True, + help="Path to the dependency graph JSON file.", + ) + parser.add_argument( + "--changed-files", + required=True, + help="Path to a file listing changed files (one per line).", + ) + parser.add_argument( + "--build-type", + default="release", + help="Build type for the cmake command (default: release).", + ) + parser.add_argument( + "--graph-source", + default="", + help="Description of graph source for the comment footer.", + ) + parser.add_argument( + "--output", + default="comment.md", + help="Output file for the PR comment markdown (default: comment.md).", + ) + args = parser.parse_args() + + # Load inputs. + graph = load_graph(args.graph) + file_to_targets = graph["file_to_targets"] + header_to_sources = graph.get("header_to_sources", {}) + target_deps = graph["target_deps"] + total_targets = len(target_deps) + + with open(args.changed_files) as fp: + changed_files = [ + line.strip() for line in fp if line.strip() and not line.startswith("#") + ] + + # Check for experimental/ or external/ changes — these are outside + # the dependency graph (CUDA-only, vendored code) and need a full build. + full_build_prefixes = ("velox/experimental/", "velox/external/") + full_build_files = [f for f in changed_files if f.startswith(full_build_prefixes)] + if full_build_files: + comment = ( + "## Build Impact Analysis\n\n" + "**Full build recommended.** Files outside the dependency graph changed:\n\n" + ) + for f in full_build_files[:10]: + comment += f"- `{f}`\n" + if len(full_build_files) > 10: + comment += f"- ... and {len(full_build_files) - 10} more\n" + comment += ( + "\nThese directories are not fully covered by the dependency graph. " + "A full build is the safest option.\n\n" + "```\n" + f"cmake --build _build/{args.build_type}\n" + "```\n\n" + "---\n" + f"*{args.graph_source or 'Build impact analysis'}*" + ) + with open(args.output, "w") as fp: + fp.write(comment) + print("experimental/external files changed — full build recommended.") + return + + # Only keep files under velox/ that can affect build targets. + source_files = [] + skipped_files = [] + for f in changed_files: + if f.startswith("velox/"): + source_files.append(f) + else: + skipped_files.append(f) + if skipped_files: + print(f" Skipped {len(skipped_files)} non-source files") + changed_files = source_files + + # Build lookup structures. + reverse_deps = compute_reverse_deps(target_deps) + + # Resolve each changed file. + directly_affected: set[str] = set() + changed_targets: dict[str, dict] = {} + unresolved_files: list[str] = [] + + for file_path in changed_files: + targets, method = resolve_file_to_targets( + file_path, + file_to_targets, + header_to_sources, + ) + if targets: + directly_affected.update(targets) + changed_targets[file_path] = { + "targets": sorted(targets), + "method": method, + } + else: + unresolved_files.append(file_path) + + if not directly_affected: + comment = ( + "## Build Impact Analysis\n\n" + "No build targets affected by this change.\n\n" + "---\n" + f"*{args.graph_source or 'Build impact analysis'}*" + ) + with open(args.output, "w") as fp: + fp.write(comment) + print("No targets affected.") + return + + # Compute transitive closure. + all_affected = compute_transitive_closure(directly_affected, reverse_deps) + + # Compute selective build targets. + selective_targets = compute_selective_build_targets(all_affected, target_deps) + + # Generate comment. + graph_source = args.graph_source or "Build impact analysis" + comment = generate_comment( + changed_targets, + all_affected, + selective_targets, + total_targets, + args.build_type, + graph_source, + unresolved_files, + ) + + with open(args.output, "w") as fp: + fp.write(comment) + + print(f"Comment written to {args.output}") + print(f" Directly affected targets: {len(directly_affected)}") + print(f" Total affected (transitive): {len(all_affected)}") + print(f" Selective build targets: {len(selective_targets)}") + if unresolved_files: + print(f" Unresolved files: {len(unresolved_files)}") + + # Also output the selective build targets as a simple list for CI use. + targets_file = os.path.splitext(args.output)[0] + "-targets.txt" + with open(targets_file, "w") as fp: + fp.write("\n".join(sorted(selective_targets))) + print(f" Selective targets list: {targets_file}") + + +if __name__ == "__main__": + main() diff --git a/.github/scripts/extract-test-failures.sh b/.github/scripts/extract-test-failures.sh new file mode 100755 index 00000000000..ee2e29b37f8 --- /dev/null +++ b/.github/scripts/extract-test-failures.sh @@ -0,0 +1,52 @@ +#!/usr/bin/env bash +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Extract failed test names, gtest cases, and failure details from ctest output. +# Writes results to $GITHUB_OUTPUT as failed-tests, failed-cases, and +# failure-details. +# +# Usage: extract-test-failures.sh + +set -euo pipefail + +CTEST_LOG="$1" + +FAILED_TESTS=$(grep -A 1000 'The following tests FAILED:' "$CTEST_LOG" | grep -E '^\s+[0-9]+ - ' | sed 's/.*- \(.*\) (.*/\1/' | head -20 || true) +FAILED_CASES=$(grep -E '^\[ FAILED \]' "$CTEST_LOG" | sed 's/\[ FAILED \] //' | sed 's/ (.*//' | grep '\.' | sort -u | head -20 || true) +FAILURE_DETAILS=$(awk '/^\[ RUN \]/{buf=$0; next} buf{buf=buf"\n"$0} /^\[ FAILED \]/ && buf{print buf; buf=""} /^\[ OK \]/{buf=""}' "$CTEST_LOG" | head -200 || true) + +if [[ -n $FAILED_TESTS ]]; then + { + echo 'failed-tests<>"$GITHUB_OUTPUT" +fi + +if [[ -n $FAILED_CASES ]]; then + { + echo 'failed-cases<>"$GITHUB_OUTPUT" +fi + +if [[ -n $FAILURE_DETAILS ]]; then + { + echo 'failure-details<>"$GITHUB_OUTPUT" +fi diff --git a/.github/scripts/generate-dependency-graph.py b/.github/scripts/generate-dependency-graph.py new file mode 100755 index 00000000000..1e2f17d8c10 --- /dev/null +++ b/.github/scripts/generate-dependency-graph.py @@ -0,0 +1,341 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Parse CMake File API codemodel output and generate a dependency graph JSON. + +This script reads the CMake File API reply directory (codemodel-v2) and +produces a JSON file containing: + - file_to_targets: mapping of source file paths to their owning targets + - target_deps: mapping of each target to its direct dependencies + - header_to_sources: mapping of header file paths to source files that + include them (via g++ -MM scanning of compile_commands.json) + +The total target count can be derived from len(target_deps). + +Usage: + python generate-dependency-graph.py \\ + --build-dir _build/release \\ + --source-dir . \\ + --output dependency-graph.json +""" + +import argparse +import json +import os +import re +import shlex +import subprocess +import sys +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path + + +def find_codemodel_reply(reply_dir: Path) -> dict: + """Find and parse the codemodel reply index file.""" + for f in sorted(reply_dir.iterdir()): + if f.name.startswith("codemodel-v2-") and f.suffix == ".json": + with open(f) as fp: + return json.load(fp) + print("ERROR: No codemodel-v2 reply found in", reply_dir, file=sys.stderr) + sys.exit(1) + + +def parse_target_file(reply_dir: Path, target_json_file: str) -> dict: + """Parse a single target JSON file from the file API reply.""" + target_path = reply_dir / target_json_file + with open(target_path) as fp: + return json.load(fp) + + +def build_dependency_graph( + build_dir: str, source_dir: str +) -> tuple[dict[str, list[str]], dict[str, list[str]]]: + """Build the dependency graph from CMake File API output. + + Returns: + file_to_targets: maps relative source file paths to list of target names + target_deps: maps each target name to its direct dependency target names + """ + reply_dir = Path(build_dir) / ".cmake" / "api" / "v1" / "reply" + if not reply_dir.exists(): + print("ERROR: File API reply directory not found:", reply_dir, file=sys.stderr) + print( + "Ensure cmake was configured with the file API query file.", + file=sys.stderr, + ) + sys.exit(1) + + source_dir = os.path.realpath(source_dir) + build_dir_real = os.path.realpath(build_dir) + + codemodel = find_codemodel_reply(reply_dir) + + file_to_targets: dict[str, list[str]] = {} + target_deps: dict[str, list[str]] = {} + + configurations = codemodel.get("configurations", []) + if not configurations: + print("ERROR: No configurations found in codemodel reply.", file=sys.stderr) + sys.exit(1) + + config = configurations[0] + targets = config.get("targets", []) + + # Build an ID-to-name lookup to avoid re-parsing target files for deps. + id_to_name: dict[str, str] = {} + target_data_cache: dict[str, dict] = {} + for target_ref in targets: + target_json_file = target_ref["jsonFile"] + target_data = parse_target_file(reply_dir, target_json_file) + target_id = target_ref.get("id", "") + target_name = target_data["name"] + id_to_name[target_id] = target_name + target_data_cache[target_id] = target_data + + for target_ref in targets: + target_id = target_ref.get("id", "") + target_data = target_data_cache[target_id] + target_name = id_to_name[target_id] + + # Extract source files. + for source in target_data.get("sources", []): + source_path = source.get("path", "") + if not source_path: + continue + + # Resolve to absolute then make relative to source dir. + if not os.path.isabs(source_path): + abs_path = os.path.normpath(os.path.join(source_dir, source_path)) + else: + abs_path = os.path.normpath(source_path) + + # Skip generated files (those in the build directory). + if abs_path.startswith(build_dir_real): + continue + + # Make path relative to source directory. + try: + rel_path = os.path.relpath(abs_path, source_dir) + except ValueError: + continue + + # Skip paths outside the source tree. + if rel_path.startswith(".."): + continue + + if rel_path not in file_to_targets: + file_to_targets[rel_path] = [] + if target_name not in file_to_targets[rel_path]: + file_to_targets[rel_path].append(target_name) + + # Extract dependencies. + dep_names = [] + for dep in target_data.get("dependencies", []): + dep_id = dep.get("id", "") + if dep_id in id_to_name: + dep_names.append(id_to_name[dep_id]) + target_deps[target_name] = dep_names + + return file_to_targets, target_deps + + +def extract_flags(command: str) -> list[str]: + """Extract include paths, defines, and std flag from a compile command.""" + flags = [] + try: + tokens = shlex.split(command) + except ValueError: + return flags + + i = 0 + while i < len(tokens): + tok = tokens[i] + if tok in ("-I", "-isystem", "-D") and i + 1 < len(tokens): + flags.extend([tok, tokens[i + 1]]) + i += 2 + elif tok.startswith(("-I", "-isystem", "-D")): + flags.append(tok) + i += 1 + elif re.match(r"-std=", tok): + flags.append(tok) + i += 1 + else: + i += 1 + return flags + + +def scan_one_file( + entry: dict, +) -> tuple[str, list[str]]: + """Run g++ -MM -MG on a single compile_commands.json entry. + + Returns (source_file_path, [header_paths]) or (source_file_path, []) + on failure. + """ + source_file = entry["file"] + directory = entry.get("directory", ".") + command = entry.get("command", "") + if not command: + args = entry.get("arguments", []) + command = " ".join(shlex.quote(a) for a in args) + + flags = extract_flags(command) + cmd = ["g++", "-MM", "-MG"] + flags + [source_file] + + try: + result = subprocess.run( + cmd, + cwd=directory, + capture_output=True, + text=True, + timeout=60, + ) + except (subprocess.TimeoutExpired, OSError) as e: + print(f" WARNING: g++ -MM failed for {source_file}: {e}", file=sys.stderr) + return source_file, [] + + if result.returncode != 0: + return source_file, [] + + # Parse make-style output: target.o: source.cpp header1.h header2.h ... + output = result.stdout.replace("\\\n", " ") + colon_idx = output.find(":") + if colon_idx == -1: + return source_file, [] + + deps_str = output[colon_idx + 1 :] + deps = deps_str.split() + # First entry is the source file itself; remaining are headers. + headers = deps[1:] if len(deps) > 1 else [] + return source_file, headers + + +def scan_header_deps(build_dir: str, source_dir: str) -> dict[str, list[str]]: + """Scan compile_commands.json with g++ -MM to build header_to_sources map.""" + compile_commands_path = os.path.join(build_dir, "compile_commands.json") + if not os.path.exists(compile_commands_path): + print( + "WARNING: compile_commands.json not found, skipping header scan.", + file=sys.stderr, + ) + return {} + + with open(compile_commands_path) as fp: + entries = json.load(fp) + + print(f" Scanning {len(entries)} compilation units with g++ -MM ...") + source_dir_real = os.path.realpath(source_dir) + build_dir_real = os.path.realpath(build_dir) + + header_to_sources: dict[str, set[str]] = defaultdict(set) + workers = os.cpu_count() or 4 + + with ThreadPoolExecutor(max_workers=workers) as executor: + results = executor.map(scan_one_file, entries) + scanned = 0 + failed = 0 + for source_file, headers in results: + if not headers: + failed += 1 + continue + scanned += 1 + + # Normalize source path relative to source dir. + if os.path.isabs(source_file): + source_abs = os.path.normpath(source_file) + else: + # Resolve relative to the entry's directory — but we + # don't have it here. The source_file in + # compile_commands.json is typically absolute. + source_abs = os.path.normpath( + os.path.join(source_dir_real, source_file) + ) + + if source_abs.startswith(build_dir_real): + continue + try: + rel_source = os.path.relpath(source_abs, source_dir_real) + except ValueError: + continue + if rel_source.startswith(".."): + continue + + for header in headers: + header_abs = os.path.normpath(header) + if not os.path.isabs(header_abs): + continue + if header_abs.startswith(build_dir_real): + continue + try: + rel_header = os.path.relpath(header_abs, source_dir_real) + except ValueError: + continue + if rel_header.startswith(".."): + continue + header_to_sources[rel_header].add(rel_source) + + # Convert sets to sorted lists for JSON serialization. + result = {k: sorted(v) for k, v in header_to_sources.items()} + print(f" Header scan complete: {scanned} succeeded, {failed} failed") + print(f" Headers mapped: {len(result)}") + return result + + +def main(): + parser = argparse.ArgumentParser( + description="Generate dependency graph from CMake File API output." + ) + parser.add_argument( + "--build-dir", + required=True, + help="Path to the CMake build directory.", + ) + parser.add_argument( + "--source-dir", + default=".", + help="Path to the source directory (default: current directory).", + ) + parser.add_argument( + "--output", + default="dependency-graph.json", + help="Output JSON file path (default: dependency-graph.json).", + ) + args = parser.parse_args() + + file_to_targets, target_deps = build_dependency_graph( + args.build_dir, args.source_dir + ) + + # Phase 2: Scan header dependencies via g++ -MM. + header_to_sources = scan_header_deps(args.build_dir, args.source_dir) + + graph = { + "file_to_targets": file_to_targets, + "header_to_sources": header_to_sources, + "target_deps": target_deps, + } + + with open(args.output, "w") as fp: + json.dump(graph, fp, indent=2, sort_keys=True) + + print(f"Dependency graph written to {args.output}") + print(f" Files mapped: {len(file_to_targets)}") + print(f" Headers mapped: {len(header_to_sources)}") + print(f" Targets: {len(target_deps)}") + + +if __name__ == "__main__": + main() diff --git a/.github/scripts/report-build-status.sh b/.github/scripts/report-build-status.sh new file mode 100755 index 00000000000..2e95cd645bd --- /dev/null +++ b/.github/scripts/report-build-status.sh @@ -0,0 +1,53 @@ +#!/usr/bin/env bash +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Report build status and create failure metadata artifact. +# +# Usage: report-build-status.sh +# +# Required environment variables: +# BUILD_OUTCOME - build step outcome (success/failure/cancelled) +# BUILD_FAILURE_DETAILS - compiler errors extracted from build output (optional) + +set -euo pipefail + +JOB_NAME="$1" + +if [[ -z $BUILD_OUTCOME || $BUILD_OUTCOME == "cancelled" ]]; then + echo "Build was cancelled (likely superseded by a newer push). Skipping." + exit 0 +fi + +if [[ $BUILD_OUTCOME != "success" ]]; then + echo "::error::${JOB_NAME} build failed. Do not land this PR until the build is fixed." + + if [[ -n ${BUILD_FAILURE_DETAILS:-} ]]; then + echo "" + echo "Build errors:" + echo "----------------------------------------" + echo "$BUILD_FAILURE_DETAILS" + echo "----------------------------------------" + fi + + # Write failure metadata for the CI failure analysis workflow. + mkdir -p /tmp/ci-failure + jq -n --arg job "$JOB_NAME" --arg type "build" \ + '{job: $job, type: $type}' \ + >/tmp/ci-failure/failure.json + + exit 1 +fi + +echo "Build succeeded." diff --git a/.github/scripts/report-test-status.sh b/.github/scripts/report-test-status.sh new file mode 100755 index 00000000000..553c3fcc9f9 --- /dev/null +++ b/.github/scripts/report-test-status.sh @@ -0,0 +1,95 @@ +#!/usr/bin/env bash +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Report test status and create failure metadata artifact. +# +# Usage: report-test-status.sh +# +# Required environment variables: +# BUILD_OUTCOME - build step outcome (success/failure/cancelled) +# TEST_OUTCOME - test step outcome (success/failure/cancelled) +# FLAKY - "true" if tests were flaky (failed then passed on retry) +# FAILED_TESTS - newline-separated list of failed ctest targets (optional) +# FAILED_CASES - newline-separated list of failed gtest cases (optional) +# FAILURE_DETAILS - gtest output for failing tests (optional) + +set -euo pipefail + +JOB_NAME="$1" +CONFIG_NAME="$2" + +if [[ -z $BUILD_OUTCOME || $BUILD_OUTCOME == "cancelled" ]]; then + echo "Run was cancelled (likely superseded by a newer push). Skipping." + exit 0 +fi + +if [[ $BUILD_OUTCOME != "success" ]]; then + echo "::error::Build failed — tests did not run. Check the '${JOB_NAME}' build job for compiler errors and build logs." + exit 1 +fi + +if [[ -z $TEST_OUTCOME || $TEST_OUTCOME == "cancelled" ]]; then + echo "Tests were cancelled (likely superseded by a newer push). Skipping." + exit 0 +fi + +if [[ $TEST_OUTCOME != "success" ]]; then + if [[ -n $FAILED_CASES ]]; then + CASE_LIST=$(echo "$FAILED_CASES" | sed 's/^/ - /') + CASE_COUNT=$(echo "$FAILED_CASES" | wc -l | tr -d ' ') + echo "::error::${CASE_COUNT} test case(s) failed in the ${CONFIG_NAME} configuration. Failed: $(echo "$FAILED_CASES" | tr '\n' ', ' | sed 's/,$//'). To see the full test output with failure details, click the '${JOB_NAME}' job in the workflow run, then expand the 'Run Tests' step." + echo "" + echo "Failed test cases:" + echo "$CASE_LIST" + elif [[ -n $FAILED_TESTS ]]; then + TEST_LIST=$(echo "$FAILED_TESTS" | sed 's/^/ - /') + TEST_COUNT=$(echo "$FAILED_TESTS" | wc -l | tr -d ' ') + echo "::error::${TEST_COUNT} test(s) failed in the ${CONFIG_NAME} configuration. Failed tests: $(echo "$FAILED_TESTS" | tr '\n' ', ' | sed 's/,$//'). To see the full test output with failure details, click the '${JOB_NAME}' job in the workflow run, then expand the 'Run Tests' step." + echo "" + echo "Failed tests:" + echo "$TEST_LIST" + else + echo "::error::Tests failed in the ${CONFIG_NAME} configuration but no specific test names were captured. Check the 'Run Tests' step in the '${JOB_NAME}' job for details." + fi + if [[ -n ${FAILURE_DETAILS:-} ]]; then + echo "" + echo "Failure details:" + echo "$FAILURE_DETAILS" + fi + echo "" + echo "To investigate, click the '${JOB_NAME}' job in the workflow run, then expand the 'Run Tests' step for the full ctest output with assertion failures and stack traces." + + # Write failure metadata for the CI failure analysis workflow. + mkdir -p /tmp/ci-failure + if [[ -n $FAILED_CASES ]]; then + TESTS="$FAILED_CASES" + elif [[ -n $FAILED_TESTS ]]; then + TESTS="$FAILED_TESTS" + else + TESTS="" + fi + jq -n --arg job "$JOB_NAME" \ + --arg type "test" \ + --arg failed_tests "$TESTS" \ + '{job: $job, type: $type, failed_tests: $failed_tests}' \ + >/tmp/ci-failure/failure.json + + exit 1 +fi + +if [[ $FLAKY == "true" ]]; then + echo "::warning::Some tests were flaky (failed then passed on retry)." +fi +echo "All tests passed in ${CONFIG_NAME}." diff --git a/.github/workflows/README.md b/.github/workflows/README.md new file mode 100644 index 00000000000..2e36c71c0a6 --- /dev/null +++ b/.github/workflows/README.md @@ -0,0 +1,152 @@ +# CI Workflows + +Velox CI validates builds and tests across Linux (GCC and optionally Clang) and macOS (Apple Clang). Linux workflows run full unit test suites in Docker containers on 32-core runners; macOS workflows verify compilation only. Fuzzer workflows stress-test functions and operators with randomized inputs. + +### Why Docker containers? + +Velox has a [large dependency footprint](../../CMake/resolve_dependency_modules/README.md) — over 20 third-party libraries. Installing and version-pinning these on bare runners would be fragile and slow. Instead, the `docker.yml` workflow builds pre-configured Docker images (`ghcr.io/facebookincubator/velox-dev`) with all dependencies pre-installed. CI build jobs run inside these containers, ensuring reproducible environments and fast startup. The images come in several variants: + +- **`centos9`** — Base CentOS Stream 9 environment used by fuzzer workflows +- **`adapters`** — Full environment with cloud storage adapters (S3, GCS, ABFS, HDFS) and GPU support +- **`ubuntu`** — Standard Ubuntu 22.04 build environment +- **`fedora`** — Fedora environment for testing against different system package versions (e.g., system-provided gRPC, Arrow, Thrift) +- **`presto-java`** / **`spark-server`** — Images with Presto/Spark servers for fuzzer reference result verification + +The `centos9` and `adapters` images are multi-architecture (amd64 + arm64); the `ubuntu` and `fedora` images are amd64-only. Images are rebuilt automatically when Dockerfiles or setup scripts change on `main`. On macOS, dependencies are installed directly via `setup-macos.sh` using Homebrew since Docker is not used. + +For current build times and performance trends, see the [CI performance metrics](https://github.com/facebookincubator/velox/actions/metrics/performance). + +| Workflow | File | Triggers | Purpose | +|----------|------|----------|---------| +| Linux Build using GCC | `linux-build.yml` + `linux-build-base.yml` | push to main, PRs | Main build & test (3 configs) | +| macOS Build | `macos.yml` | push, PRs | Compilation check (debug + release) | +| Breeze Linux Build | `breeze.yml` | push to main, PRs | Tracing module with sanitizers | +| Fuzzer Jobs | `scheduled.yml` | PRs, push to main, daily cron, manual | Randomized correctness testing | +| Run Checks | `preliminary_checks.yml` | PRs | Formatting, linting, PR title | +| Build Impact Analysis | `build-impact.yml` + `build-impact-comment.yml` | push to main, PRs | CMake dependency impact | +| CI Failure Comment | `ci-failure-comment.yml` | workflow_run (on Linux Build / Fuzzer failure) | AI-powered failure analysis on PRs | +| Claude PR Assistants | `claude.yml` + `claude-review.yml` | PR comments (`@claude`) | AI code review | +| Build Pyvelox Wheels | `build_pyvelox.yml` | manual | Python wheel packaging | +| Docker Images | `docker.yml` | push to main, manual | Multi-arch Docker images | +| Weekly Date Tag | `tag.yml` | weekly cron, manual | Version tagging | +| Update Documentation | `docs.yml` | push (docs changes), PRs | Sphinx docs + GitHub Pages | +| Collect Build Metrics | `build-metrics.yml` | PRs, daily, manual | Binary size tracking (disabled) | +| Ubuntu Bundled Deps | `ubuntu-bundled-deps.yml` | nightly, manual, PRs (dep scripts) | Build-from-source validation | + +## Core Build & Test + +### Linux Build using GCC (`linux-build.yml` + `linux-build-base.yml`) + +The main CI workflow for Velox. Triggered on pushes to `main` and on pull requests when relevant files change (source code, CMake files, setup scripts, or the workflow files themselves). The entry point `linux-build.yml` delegates to the reusable `linux-build-base.yml` template, which builds and tests three configurations in parallel on 32-core Ubuntu runners: + +- **Linux adapters release** — Release build using the `velox-dev:adapters` Docker image. Enables cloud storage adapters (S3, GCS, ABFS, HDFS), Parquet, Arrow, geospatial functions, and GPU support (WAVE, cuDF). Tests run with `ctest -j 24` and a 900-second timeout. When cuDF-related files change, a separate cuDF test job runs on a `4-core-ubuntu-gpu-t4` GPU runner. +- **Ubuntu debug** — Debug build using the `velox-dev:ubuntu-22.04` Docker image. Enables benchmarks, examples, Arrow, geospatial, Parquet, shared library (`VELOX_BUILD_SHARED=ON`), and mono library modes. Tests run with `ctest -j 24` and a 1800-second timeout. +- **Fedora debug** — Debug build using the `velox-dev:fedora` Docker image. Validates compilation compatibility with Fedora's system packages including system-provided gRPC and Arrow/Thrift shared libraries. This configuration focuses on compiler and OS compatibility rather than test coverage — it does not have a separate test status job. + +All configurations use ccache for build acceleration (persisted via Apache infrastructure stash). The adapters and ubuntu-debug configurations include a flaky test retry mechanism: if any tests fail on the first run, they are automatically retried with `ctest --rerun-failed`. If the retry passes, the tests are marked as flaky; if it fails again, the specific failed test case names are extracted and reported. + +#### BUILD / TEST status jobs and `continue-on-error` + +Each configuration has separate BUILD and TEST status jobs that appear as independent checks on PRs. This design solves a specific problem: developers need to immediately see whether a CI failure is a build failure (absolute blocker) or a test failure (could be flaky, needs investigation). + +The obvious approach — splitting build and test into separate GitHub Actions jobs — is impractical for Velox because the build artifacts are too large to transfer between jobs. Test binaries are ~3GB each (8 grouped binaries totaling ~24GB), plus the shared `libvelox.so` mono library. We tried multiple compression strategies in PR #16938 (tar+gzip, pigz, zstd, direct upload) and all either exceeded the test runtime or were incompatible with the container environment. Stripping debug symbols would defeat the purpose of a debug build. The sequential overhead of packaging, uploading, downloading, and extracting artifacts negates any parallelism gains. + +Instead, the test step uses `continue-on-error: true` so the main job always completes, and lightweight status jobs read the step outcomes to provide separate pass/fail signals. The main job (which contains the full test output) appears green even when tests fail, but the status jobs compensate in two ways: (1) failed test names, gtest case names, and failure details (assertion messages, compiler errors) are forwarded to the status jobs via job output variables and displayed directly in the status job logs and as `::error::` annotations, so developers clicking the red check see the specific failures immediately; (2) the `ci-failure-comment.yml` workflow uses Claude to analyze the full logs and post a rich diagnostic comment directly on the PR. + +Status jobs handle cancelled runs gracefully — when a job is cancelled (e.g., superseded by a newer push), status jobs exit cleanly without false "build failed" reports. + +### macOS Build (`macos.yml`) + +Builds Velox on macOS 15 (ARM64/Apple Silicon) with both debug and release configurations. Triggered on pushes to any branch and on pull requests when relevant files change. Uses the Ninja build generator and ccache for faster builds. Tests are currently disabled on macOS — the workflow focuses on ensuring compilation compatibility with Apple's toolchain rather than full test coverage. Dependencies are installed via the `setup-macos.sh` script. + +### Breeze Linux Build (`breeze.yml`) + +Experimental workflow for the Breeze tracing/profiling module and Perfetto integration. Builds two configurations: a debug build with Address Sanitizer (ASAN) and Undefined Behavior Sanitizer (UBSAN) enabled for memory safety testing, and a RelWithDebInfo build with CUDA support for the Breeze module. Unlike the main Linux build, Breeze uses `VELOX_DEPENDENCY_SOURCE=BUNDLED` to build dependencies from source rather than using system packages. Tests run with `ctest -j 8` under sanitizers. + +## Fuzzing (`scheduled.yml`) + +A comprehensive fuzzing suite that tests Velox functions and operators against random inputs to catch correctness issues. Triggered on pull requests, pushes to `main`, a daily cron schedule (`0 3 * * *` UTC), and manual `workflow_dispatch`. The daily and main-push triggers are particularly important as they catch regressions that may not be exercised by PR-level test coverage. + +The workflow first compiles all fuzzer binaries in a shared `compile` job, then runs 12+ independent fuzzer targets in parallel: + +- **Presto Fuzzer** and **Spark Fuzzer** — Test Presto and Spark SQL function implementations with random inputs and verify results against reference implementations (DuckDB for Presto, Spark for Spark functions). +- **Aggregation Fuzzers** — Test aggregate functions with random grouping keys and inputs, with Presto as source of truth. +- **Join Fuzzer** — Tests hash join, merge join, and nested loop join operators with random schemas and data. +- **Exchange Fuzzer** — Tests the data exchange/shuffle operator. +- **Window Fuzzer** — Tests window function implementations. +- **Writer Fuzzer** — Tests file format writers (Parquet, DWRF). +- **RowNumber and TopNRowNumber Fuzzers** — Test row numbering operators. +- **Table Evolution Fuzzer** — Tests schema evolution scenarios. +- **Memory Arbitration Fuzzer** — Tests memory management under pressure. +- **Spatial Join Fuzzer** — Tests geospatial join operations. +- **Cache Fuzzer** — Tests caching infrastructure. + +The workflow also includes bias fuzzers that focus specifically on newly added or recently updated functions, and a signature check job that validates function signatures match expected interfaces. + +## PR Checks & Feedback + +### Run Checks (`preliminary_checks.yml`) + +Runs early validation on pull requests before the heavier build workflows. Executes `pre-commit run --all-files` to check code formatting (clang-format), linting (yamllint, zizmor), license headers, and other code quality rules. Also validates the PR title against the conventional commits format (`type(scope): description`), which is required for all PRs. + +### Build Impact Analysis (`build-impact.yml` + `build-impact-comment.yml`) + +Analyzes which build targets are affected by the files changed in a PR using the CMake dependency graph. On pushes to `main`, the workflow generates and caches a fresh dependency graph. On PRs, it uses a two-path strategy: a fast path that reuses the cached graph from `main`, and a slow path that regenerates the graph when CMake files change. The companion `build-impact-comment.yml` workflow posts the analysis results as a PR comment, updating the existing comment if one already exists. + +### CI Failure Comment (`ci-failure-comment.yml`) + +Analyzes CI failures and posts diagnostic comments on PRs. Triggered via `workflow_run` when the "Linux Build using GCC" or "Fuzzer Jobs" workflow completes with a failure. The workflow finds the associated PR, downloads failure metadata artifacts uploaded by the status jobs in `linux-build-base.yml` (or, for fuzzer failures, queries the GitHub API for failed job names), then uses Claude to analyze the failures. Claude fetches the full job logs via the GitHub API, reads the PR diff, searches open issues for known failures, and checks recent main branch CI runs for pre-existing flaky tests. The resulting PR comment includes specific failure details (assertion errors, compiler diagnostics), correlation analysis with the PR changes, links to known issues, and recommended fixes. Also supports manual triggering via `workflow_dispatch` with a run ID and PR number. + +Uses the `workflow_run` pattern because fork PRs have read-only tokens and cannot post comments directly. + +### Claude PR Assistants (`claude.yml` + `claude-review.yml`) + +Two AI-powered code review workflows. The newer `claude.yml` uses a skill-based approach triggered by `@claude` mentions in PR comments, restricted to an authorized user allowlist. The legacy `claude-review.yml` supports `/claude-review` (full code review) and `/claude-query` (targeted questions) commands, parsing PR diffs and generating detailed review feedback with file-by-file analysis, risk assessment, and testing recommendations. + +## Packaging & Release + +### Build Pyvelox Wheels (`build_pyvelox.yml`) + +Builds Python wheel packages for PyVelox across multiple Python versions (3.10–3.13) and platforms (Ubuntu, macOS Intel, macOS ARM). Uses `cibuildwheel` for cross-platform wheel building. Triggered via `workflow_dispatch` and on PRs that modify the workflow file. Optionally publishes to PyPI on success. + +### Docker Images (`docker.yml`) + +Builds and pushes the Docker images described in the [Why Docker containers?](#why-docker-containers) section above. Uses Docker Buildx with BuildKit for multi-arch support (amd64 + arm64) and registry-based layer caching. Triggered on pushes to `main` when Dockerfiles or setup scripts change, and on manual dispatch. + +### Weekly Date Tag (`tag.yml`) + +Automatically creates weekly date-based version tags (e.g., `v2026.04.03.00`) every Friday at 09:23 UTC. Checks CI status before tagging to ensure only successful builds are released. Also supports manual triggering via `workflow_dispatch`. + +## Documentation & Metrics + +### Update Documentation (`docs.yml`) + +Builds Sphinx documentation when files under `velox/docs/` change on pushes or pull requests. Publishes to GitHub Pages only on pushes to the official `facebookincubator/velox` repository. Includes PyVelox Python API documentation. On pull requests, validates that documentation builds without errors but does not deploy. + +### Collect Build Metrics (`build-metrics.yml`) + +Measures binary sizes across four build configurations (debug/release x shared/static) and was designed to upload metrics to conbench for performance regression tracking. Currently disabled because the conbench service is unavailable. Triggered on PRs, daily schedule, and manual dispatch. + +## Dependency Testing + +### Ubuntu Bundled Dependencies (`ubuntu-bundled-deps.yml`) + +Tests that Velox can be built entirely from source on a plain Ubuntu system without pre-installed dependencies (except ICU). Uses `VELOX_DEPENDENCY_SOURCE=BUNDLED` to build all dependencies from source, validating the bundled dependency resolution scripts. Runs on a 16-core runner (no Docker container) with 16 build threads. Includes a comprehensive set of feature flags: benchmarks, examples, Arrow, geospatial, Parquet, mono library, shared library, FAISS, and remote functions. Tests run with `ctest -j 8`. Runs nightly at 5 AM UTC, on manual dispatch, and on PRs that modify dependency scripts. Only runs on the official `facebookincubator/velox` repository, not forks. + +## Architecture Notes + +### Fork PR Permissions + +GitHub restricts fork PR tokens to read-only for security. Workflows that need to post PR comments use the **`workflow_run` pattern**: the main workflow uploads results as artifacts, and a separate `workflow_run`-triggered workflow downloads them and posts comments using the base repo's write permissions. + +This pattern is used by: +- `build-impact-comment.yml` (posts build impact analysis) +- `ci-failure-comment.yml` (posts CI failure analysis) + +### Build Caching + +All build workflows use ccache for compiler output caching, persisted across runs using Apache infrastructure stash actions. Cache keys include the platform, build type, and compiler to avoid cache collisions. + +### Concurrency + +Most workflows use concurrency groups keyed on `workflow + repository + branch/SHA` to automatically cancel in-progress runs when new commits are pushed to the same branch, avoiding wasted CI resources. diff --git a/.github/workflows/linux-build.yml b/.github/workflows/linux-build.yml index c2885d2fb56..0b9020be107 100644 --- a/.github/workflows/linux-build.yml +++ b/.github/workflows/linux-build.yml @@ -30,7 +30,7 @@ concurrency: jobs: adapters: name: Linux release with adapters - runs-on: linux-amd64-gpu-l4-latest-1 + runs-on: linux-amd64-cpu16 container: ghcr.io/facebookincubator/velox-dev:adapters defaults: run: @@ -44,46 +44,38 @@ jobs: Arrow_SOURCE: BUNDLED Thrift_SOURCE: BUNDLED cudf_SOURCE: BUNDLED - CUDA_VERSION: "12.8" + CUDA_VERSION: "12.9" steps: - uses: actions/checkout@v4 with: persist-credentials: false - name: Fix git permissions - # Usually actions/checkout does this but as we run in a container - # it doesn't work run: git config --global --add safe.directory ${GITHUB_WORKSPACE} - name: Install Dependencies run: | - # Allows to install arbitrary cuda-version whithout needing to update - # docker container before. It simplifies testing new/different versions if ! yum list installed cuda-nvcc-$(echo ${CUDA_VERSION} | tr '.' '-') 1>/dev/null || \ ! yum list installed libnvjitlink-devel-$(echo ${CUDA_VERSION} | tr '.' '-') 1>/dev/null; then source scripts/setup-centos-adapters.sh install_cuda ${CUDA_VERSION} fi - - # TODO: Install a newer cmake here until we update the images upstream pip install cmake==3.30.4 - uses: apache/infrastructure-actions/stash/restore@3354c1565d4b0e335b78a76aedd82153a9e144d4 with: - path: '${{ env.CCACHE_DIR }}' + path: "${{ env.CCACHE_DIR }}" key: ccache-linux-adapters - name: Zero Ccache Statistics - run: | - ccache -sz + run: ccache -sz - name: Make Release Build env: MAKEFLAGS: TREAT_WARNINGS_AS_ERRORS=0 NUM_THREADS=16 MAX_HIGH_MEM_JOBS=4 CUDA_ARCHITECTURES: 70 CUDA_COMPILER: /usr/local/cuda-${CUDA_VERSION}/bin/nvcc - # Set compiler to GCC 12 - CUDA_FLAGS: -ccbin /opt/rh/gcc-toolset-12/root/usr/bin + CUDA_FLAGS: -ccbin /opt/rh/gcc-toolset-14/root/usr/bin run: | EXTRA_CMAKE_FLAGS=( "-DVELOX_ENABLE_BENCHMARKS=ON" @@ -98,29 +90,117 @@ jobs: "-DVELOX_ENABLE_CUDF=ON" "-DVELOX_MONO_LIBRARY=ON" ) - make release EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS[*]}" + source /opt/rh/gcc-toolset-14/enable + CC=gcc CXX=g++ make release EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS[*]}" - name: Ccache after run: ccache -s - uses: apache/infrastructure-actions/stash/save@3354c1565d4b0e335b78a76aedd82153a9e144d4 with: - path: '${{ env.CCACHE_DIR }}' + path: "${{ env.CCACHE_DIR }}" key: ccache-linux-adapters - - name: Run Tests + - name: Run non-GPU tests env: LIBHDFS3_CONF: "${{ github.workspace }}/scripts/hdfs-client.xml" working-directory: _build/release run: | - # Can be removed after images are rebuild if [ -f "/opt/miniforge/etc/profile.d/conda.sh" ]; then source "/opt/miniforge/etc/profile.d/conda.sh" conda activate adapters fi export CLASSPATH=`/usr/local/hadoop/bin/hdfs classpath --glob` - ctest -j 8 --output-on-failure --no-tests=error -E "velox_exec_test|velox_hdfs_file_test|velox_s3" + ctest -j 8 --output-on-failure --no-tests=error \ + -E "velox_exec_test|velox_hdfs_file_test|velox_s3|velox_cudf_" + + - name: Copy Shared Libraries for cuDF Test Binaries + env: + CUDF_DIR: _build/release/velox/experimental/cudf + run: | + mkdir -p "$CUDF_DIR/cudf-libs" + deps=$( + find "$CUDF_DIR/tests" -name "velox_cudf_*" -type f -executable | + xargs ldd 2>/dev/null | + sed -n 's/[^\/]*\(\/[^ ]*\) .*/\1/p' | + grep -vE '^/(lib|usr/lib|lib64)/(libcuda|librt\.so|libm\.so|libstdc\+\+\.so|ld-linux-x86-64\.so|libc\.so)' | + sort -u + ) + if [ -z "$deps" ]; then + echo "Error: No dependencies found" + exit 1 + fi + ( + echo "$deps" | + xargs readlink -f | + xargs -I {} cp {} "$CUDF_DIR/cudf-libs/" + ) + lndeps=$(echo "$deps" | xargs -I {} bash -c '[ -L "$1" ] && echo "$1" || true' bash {}) + ( + cd "$CUDF_DIR/cudf-libs" || exit 1 + while IFS= read -r link; do + [[ -L "$link" ]] || continue + target=$(readlink -f "$link") || continue + target_basename=$(basename "$target") || continue + link_basename=$(basename "$link") || continue + ln -sf "$target_basename" "$link_basename" + done <<< "$lndeps" + ) + tar -cvf "$CUDF_DIR/cudf-libs.tar" -C "$CUDF_DIR" cudf-libs + + - name: Upload cuDF Test Binaries + uses: actions/upload-artifact@v4 + with: + name: cudftestbinaries + path: | + velox/dwio/parquet/tests/examples/int.parquet + velox/experimental/cudf/tests/CMakeLists.txt + _build/release/velox/experimental/cudf/tests/velox_cudf_* + _build/release/velox/experimental/cudf/tests/CTestTestfile.cmake + _build/release/velox/experimental/cudf/cudf-libs.tar + + cudf-tests: + name: cuDF GPU tests + runs-on: linux-amd64-gpu-l4-latest-1 + needs: adapters + timeout-minutes: 30 + env: + CUDF_DIR: _build/release/velox/experimental/cudf + steps: + - name: Install Packages + run: | + sudo apt-get update + sudo apt-get install -y cmake patchelf cuda-toolkit-12-9 + export MINIO_VERSION="2022-05-26T05-48-41Z" + export MINIO_BINARY_NAME="minio-2022-05-26" + wget https://dl.min.io/server/minio/release/linux-amd64/archive/minio.RELEASE."${MINIO_VERSION}" -O "${MINIO_BINARY_NAME}" + sudo mv ./"${MINIO_BINARY_NAME}" /usr/local/bin/ + sudo chmod +x /usr/local/bin/"${MINIO_BINARY_NAME}" + + - name: Check NVIDIA Driver Version + run: nvidia-smi + + - name: Download cuDF Test Binaries + uses: actions/download-artifact@v4 + with: + name: cudftestbinaries + + - name: Adapt Downloaded Files + run: | + sed -i 's|/__w/velox/velox|${{ github.workspace }}|g' $CUDF_DIR/tests/CTestTestfile.cmake + grep "Source directory" $CUDF_DIR/tests/CTestTestfile.cmake + grep "Build directory" $CUDF_DIR/tests/CTestTestfile.cmake + (cd $CUDF_DIR && tar -xf cudf-libs.tar) + for exe in $CUDF_DIR/tests/velox_cudf_*; do + patchelf --force-rpath --set-rpath '$ORIGIN/../cudf-libs' "$exe" + chmod +x "$exe" + done + + - name: Run cuDF Tests + run: | + cd $CUDF_DIR/tests/ + ctest --output-on-failure # ubuntu-debug: # runs-on: linux-amd64-cpu16 diff --git a/.github/zizmor.yml b/.github/zizmor.yml index a1baabf91cf..a3dc69a0b9f 100644 --- a/.github/zizmor.yml +++ b/.github/zizmor.yml @@ -15,3 +15,7 @@ rules: use-trusted-publishing: ignore: - build_pyvelox.yml + dangerous-triggers: + ignore: + - build-impact-comment.yml + - ci-failure-comment.yml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9c31eee8fb2..3124eae3d29 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,15 +37,35 @@ repos: hooks: - id: clang-tidy name: clang-tidy - description: Run clang-tidy on C/C++ files + description: >- + Run clang-tidy on C/C++ files, requires 'compile_commands.json' + to be available in the repo root (e.g. symlinked) or BUILD_PATH + set to it's path e.g. _build/release. stages: - - manual # Needs compile_commands.json - entry: clang-tidy + - pre-commit + entry: ./scripts/checks/run-clang-tidy.py + args: [--commit, HEAD] language: python types_or: [c++, c] additional_dependencies: [clang-tidy==18.1.8] require_serial: true + - id: check-header-ownership + name: check-header-ownership + description: Verify every .h file is tracked by a CMakeLists.txt target. + entry: python3 scripts/checks/check-header-ownership.py velox + language: python + pass_filenames: false + files: (\.h|CMakeLists\.txt)$ + + - id: check-readme-blogs + name: check-readme-blogs + description: Ensure README lists the 3 most recent blog posts. + entry: ./scripts/checks/check-readme-blogs.sh + language: script + files: ^(README\.md|website/blog/.*)$ + pass_filenames: false + - id: license-header name: license-header description: Add missing license headers. @@ -73,7 +93,7 @@ repos: name: CMake formatter - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v18.1.3 + rev: v21.1.2 hooks: - id: clang-format # types_or: [c++, c, cuda, metal, objective-c] diff --git a/CMake/FindSnappy.cmake b/CMake/FindSnappy.cmake deleted file mode 100644 index 2d65b3d1766..00000000000 --- a/CMake/FindSnappy.cmake +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# - Try to find snappy -# Once done, this will define -# -# SNAPPY_FOUND - system has Glog -# SNAPPY_INCLUDE_DIRS - deprecated -# SNAPPY_LIBRARIES - deprecated -# Snappy::snappy will be defined based on CMAKE_FIND_LIBRARY_SUFFIXES priority - -include(FindPackageHandleStandardArgs) -include(SelectLibraryConfigurations) - -find_library(SNAPPY_LIBRARY_RELEASE snappy PATHS $SNAPPY_LIBRARYDIR}) -find_library(SNAPPY_LIBRARY_DEBUG snappyd PATHS ${SNAPPY_LIBRARYDIR}) - -find_path(SNAPPY_INCLUDE_DIR snappy.h PATHS ${SNAPPY_INCLUDEDIR}) - -select_library_configurations(SNAPPY) - -find_package_handle_standard_args(Snappy DEFAULT_MSG SNAPPY_LIBRARY SNAPPY_INCLUDE_DIR) - -mark_as_advanced(SNAPPY_LIBRARY SNAPPY_INCLUDE_DIR) - -get_filename_component(libsnappy_ext ${SNAPPY_LIBRARY} EXT) -if(libsnappy_ext STREQUAL ".a") - set(libsnappy_type STATIC) -else() - set(libsnappy_type SHARED) -endif() - -if(NOT TARGET Snappy::snappy) - add_library(Snappy::snappy ${libsnappy_type} IMPORTED) - set_target_properties( - Snappy::snappy - PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${SNAPPY_INCLUDE_DIR}" - ) - set_target_properties( - Snappy::snappy - PROPERTIES IMPORTED_LINK_INTERFACE_LANGUAGES "C" IMPORTED_LOCATION "${SNAPPY_LIBRARIES}" - ) -endif() diff --git a/CMake/Findc-ares.cmake b/CMake/Findc-ares.cmake index 58087c018fa..8c3fa68222a 100644 --- a/CMake/Findc-ares.cmake +++ b/CMake/Findc-ares.cmake @@ -20,7 +20,7 @@ if(c-ares_FOUND) endif() find_path(C_ARES_INCLUDE_DIR NAMES ares.h PATH_SUFFIXES c-ares) -find_library(C_ARES_LIBRARY NAMES c-ares) +find_library(C_ARES_LIBRARY NAMES c-ares cares) include(FindPackageHandleStandardArgs) find_package_handle_standard_args(c-ares DEFAULT_MSG C_ARES_LIBRARY C_ARES_INCLUDE_DIR) diff --git a/CMake/Findlz4.cmake b/CMake/Findlz4.cmake index 2840d017479..8a43e8ea856 100644 --- a/CMake/Findlz4.cmake +++ b/CMake/Findlz4.cmake @@ -23,7 +23,7 @@ include(FindPackageHandleStandardArgs) include(SelectLibraryConfigurations) -find_library(LZ4_LIBRARY_RELEASE lz4 PATHS $LZ4_LIBRARYDIR}) +find_library(LZ4_LIBRARY_RELEASE lz4 PATHS ${LZ4_LIBRARYDIR}) find_library(LZ4_LIBRARY_DEBUG lz4d PATHS ${LZ4_LIBRARYDIR}) find_path(LZ4_INCLUDE_DIR lz4.h PATHS ${LZ4_INCLUDEDIR}) diff --git a/CMake/VeloxConfig.cmake.in b/CMake/VeloxConfig.cmake.in index 0164099c886..158f47277b4 100644 --- a/CMake/VeloxConfig.cmake.in +++ b/CMake/VeloxConfig.cmake.in @@ -57,7 +57,7 @@ block() if("@simdjson_SOURCE@" STREQUAL "SYSTEM") find_dependency(simdjson) endif() - if("@THRIFT_SOURCE@" STREQUAL "SYSTEM") + if("@Thrift_FOUND@") find_dependency(Thrift) endif() if("@xsimd_SOURCE@" STREQUAL "SYSTEM") diff --git a/CMake/VeloxUtils.cmake b/CMake/VeloxUtils.cmake index 8d80d374579..47cdeb7f795 100644 --- a/CMake/VeloxUtils.cmake +++ b/CMake/VeloxUtils.cmake @@ -42,7 +42,8 @@ function(pyvelox_add_module TARGET) install(TARGETS ${TARGET} LIBRARY DESTINATION pyvelox COMPONENT pyvelox_libraries) endfunction() -# TODO use file sets +# Glob-based header install fallback. Kept for directories that have not yet +# migrated to the HEADERS keyword in velox_add_library (which uses FILE_SET). function(velox_install_library_headers) # Find any headers and install them relative to the source tree in include. file(GLOB _hdrs "*.h") @@ -57,6 +58,28 @@ function(velox_install_library_headers) endif() endfunction() +# Associate headers with test/benchmark/fuzzer targets via FILE_SET for CMake +# File API discoverability. For production libraries use the HEADERS keyword +# in velox_add_library() instead. +function(velox_add_test_headers TARGET) + get_target_property(_type ${TARGET} TYPE) + if(_type STREQUAL "INTERFACE_LIBRARY") + set(_scope INTERFACE) + else() + set(_scope PUBLIC) + endif() + + target_sources( + ${TARGET} + ${_scope} + FILE_SET HEADERS + BASE_DIRS + ${PROJECT_SOURCE_DIR} + FILES + ${ARGN} + ) +endfunction() + # Base add velox library call to add a library and install it. function(velox_base_add_library TARGET) add_library(${TARGET} ${ARGN}) @@ -68,10 +91,9 @@ endfunction() function(velox_add_library TARGET) set(options OBJECT STATIC SHARED INTERFACE) set(oneValueArgs) - set(multiValueArgs) + set(multiValueArgs HEADERS) cmake_parse_arguments(VELOX "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - # Remove library type specifiers from ARGN set(library_type) if(VELOX_OBJECT) set(library_type OBJECT) @@ -83,23 +105,23 @@ function(velox_add_library TARGET) set(library_type INTERFACE) endif() - list(REMOVE_ITEM ARGN OBJECT) - list(REMOVE_ITEM ARGN STATIC) - list(REMOVE_ITEM ARGN SHARED) - list(REMOVE_ITEM ARGN INTERFACE) + set(_sources ${VELOX_UNPARSED_ARGUMENTS}) + # Propagate to the underlying add_library and then install the target. if(VELOX_MONO_LIBRARY) if(TARGET velox) # Target already exists, append sources to it. - target_sources(velox PRIVATE ${ARGN}) - install(TARGETS velox LIBRARY DESTINATION pyvelox COMPONENT pyvelox_libraries) + target_sources(velox PRIVATE ${_sources}) + if(VELOX_BUILD_PYTHON_PACKAGE) + install(TARGETS velox LIBRARY DESTINATION pyvelox COMPONENT pyvelox_libraries) + endif() else() set(_type STATIC) if(VELOX_BUILD_SHARED) set(_type SHARED) endif() # Create the target if this is the first invocation. - add_library(velox ${_type} ${ARGN}) + add_library(velox ${_type} ${_sources}) set_target_properties(velox PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/lib) set_target_properties(velox PROPERTIES ARCHIVE_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/lib) install(TARGETS velox DESTINATION lib/velox EXPORT velox_targets) @@ -133,10 +155,10 @@ function(velox_add_library TARGET) list(APPEND system_dependencies Snappy zstd) endif() foreach(system_dependency ${system_dependencies}) - install( - FILES "${PROJECT_SOURCE_DIR}/CMake/Find${system_dependency}.cmake" - DESTINATION "${package_cmake_dir}" - ) + set(velox_find_module "${PROJECT_SOURCE_DIR}/CMake/Find${system_dependency}.cmake") + if(EXISTS "${velox_find_module}") + install(FILES "${velox_find_module}" DESTINATION "${package_cmake_dir}") + endif() endforeach() # TODO: We can enable this once we add version to Velox. # set(version_cmake "${PROJECT_BINARY_DIR}/CMake/VeloxConfigVersion.cmake") @@ -157,8 +179,36 @@ function(velox_add_library TARGET) endif() else() # Create a library for each invocation. - velox_base_add_library(${TARGET} ${library_type} ${ARGN}) + velox_base_add_library(${TARGET} ${library_type} ${_sources}) endif() + + # Associate headers with the target via FILE_SET for tracking and IDE + # integration. The glob-based velox_install_library_headers() remains as + # fallback for directories that have not yet listed their headers explicitly. + if(VELOX_HEADERS) + if(VELOX_MONO_LIBRARY) + set(_header_target velox) + # The velox target is a real (non-INTERFACE) library, so always use PUBLIC. + set(_header_scope PUBLIC) + else() + set(_header_target ${TARGET}) + if(VELOX_INTERFACE) + set(_header_scope INTERFACE) + else() + set(_header_scope PUBLIC) + endif() + endif() + target_sources( + ${_header_target} + ${_header_scope} + FILE_SET HEADERS + BASE_DIRS + ${PROJECT_SOURCE_DIR} + FILES + ${VELOX_HEADERS} + ) + endif() + velox_install_library_headers() endfunction() @@ -218,3 +268,78 @@ function(velox_sources TARGET) target_sources(${TARGET} ${ARGN}) endif() endfunction() + +# Group test sources into batched binaries to reduce link target count on CI. +# On macOS, defaults to OFF so each test source gets its own binary and +# individual tests are discoverable via 'ctest -R '. +if(APPLE) + option(VELOX_ENABLE_GROUPED_TESTS "Group test sources into batched binaries" OFF) +else() + option(VELOX_ENABLE_GROUPED_TESTS "Group test sources into batched binaries" ON) +endif() + +# Number of test source files per grouped test binary. Controls the trade-off +# between link time (fewer groups = faster linking) and ctest parallelism +# (more groups = more parallel test processes). Ignored when +# VELOX_ENABLE_GROUPED_TESTS is OFF. +set(VELOX_TESTS_PER_GROUP 10 CACHE STRING "Number of test source files per grouped test binary") + +# Creates grouped test binaries from a list of test sources. Groups tests into +# batches of VELOX_TESTS_PER_GROUP to reduce link target count while +# maintaining ctest parallelism. When VELOX_ENABLE_GROUPED_TESTS is OFF, each +# source file becomes its own binary named after the source file (without +# extension), making individual tests discoverable via 'ctest -R '. +# +# Usage: +# velox_add_grouped_tests( +# PREFIX velox_exec +# SOURCES ${MY_SOURCES} +# DEPS ${MY_DEPS} +# [EXTRA_SOURCES Main.cpp] +# [WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}] +# ) +function(velox_add_grouped_tests) + cmake_parse_arguments(ARG "" "PREFIX;WORKING_DIRECTORY" "SOURCES;DEPS;EXTRA_SOURCES" ${ARGN}) + + if(NOT ARG_WORKING_DIRECTORY) + set(ARG_WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + endif() + + if(NOT VELOX_ENABLE_GROUPED_TESTS) + # Create one binary per source file, named after the source file. + foreach(_source IN LISTS ARG_SOURCES) + get_filename_component(_name ${_source} NAME_WE) + set(_target "${ARG_PREFIX}_${_name}") + add_executable(${_target} ${_source} ${ARG_EXTRA_SOURCES}) + add_test(NAME ${_target} COMMAND ${_target} WORKING_DIRECTORY ${ARG_WORKING_DIRECTORY}) + target_link_libraries(${_target} ${ARG_DEPS}) + endforeach() + return() + endif() + + list(LENGTH ARG_SOURCES _num_sources) + math( + EXPR + _num_groups + "(${_num_sources} + ${VELOX_TESTS_PER_GROUP} - 1) / ${VELOX_TESTS_PER_GROUP}" + ) + + set(_idx 0) + set(_group 0) + set(_current_sources "") + + foreach(_source IN LISTS ARG_SOURCES) + list(APPEND _current_sources ${_source}) + math(EXPR _idx "${_idx} + 1") + math(EXPR _group_end "(${_group} + 1) * ${VELOX_TESTS_PER_GROUP}") + + if(_idx GREATER_EQUAL _group_end OR _idx EQUAL _num_sources) + set(_target "${ARG_PREFIX}_group${_group}") + add_executable(${_target} ${_current_sources} ${ARG_EXTRA_SOURCES}) + add_test(NAME ${_target} COMMAND ${_target} WORKING_DIRECTORY ${ARG_WORKING_DIRECTORY}) + target_link_libraries(${_target} ${ARG_DEPS}) + set(_current_sources "") + math(EXPR _group "${_group} + 1") + endif() + endforeach() +endfunction() diff --git a/CMake/resolve_dependency_modules/README.md b/CMake/resolve_dependency_modules/README.md index b6fb9dab55a..7c94995e209 100644 --- a/CMake/resolve_dependency_modules/README.md +++ b/CMake/resolve_dependency_modules/README.md @@ -6,46 +6,45 @@ The versions of certain libraries is the default provided by the platform's package manager. Some libraries can be bundled by Velox. See details on bundling below. -| Library Name | Minimum Version | Bundled? | -|-------------------|-----------------|----------| -| ninja | default | No | -| ccache | default | No | -| icu4c | default | Yes | -| gflags | default | Yes | -| glog | default | Yes | -| gtest (testing) | default | Yes | -| libevent | default | No | -| libcudf | default | Yes | -| libsodium | default | No | -| lz4 | default | No | -| snappy | default | No | -| xz | default | No | -| zstd | default | No | -| openssl | default | No | -| protobuf | 21.7 >= x < 22 | Yes | -| boost | 1.77.0 | Yes | -| flex | 2.5.13 | No | -| bison | 3.0.4 | No | -| cmake | 3.28 | No | -| double-conversion | 3.1.5 | No | -| xsimd | 10.0.0 | Yes | -| re2 | 2024-07-02 | Yes | -| fmt | 10.1.1 | Yes | -| simdjson | 3.9.3 | Yes | -| faiss | 1.11.0 | Yes | -| folly | v2025.04.28.00 | Yes | -| fizz | v2025.04.28.00 | No | -| wangle | v2025.04.28.00 | No | -| mvfst | v2025.04.28.00 | No | -| fbthrift | v2025.04.28.00 | No | -| libstemmer | 2.2.0 | Yes | -| DuckDB (testing) | 0.8.1 | Yes | -| cpr (testing) | 1.10.15 | Yes | -| arrow | 15.0.0 | Yes | -| geos | 3.10.7 | Yes | -| fast_float | v8.0.2 | Yes | -| xxhash | default | No | -| thrift | 0.16 | No | +| Library Name | Minimum Version | Bundled? | Comment | +|-------------------|-----------------|----------|---------| +| ninja | default | No || +| ccache | default | No || +| icu4c | default | Yes || +| gflags | default | Yes || +| glog | default | Yes || +| gtest (testing) | default | Yes || +| libevent | default | No || +| libsodium | default | No || +| lz4 | default | No || +| snappy | default | No || +| xz | default | No || +| zstd | default | No || +| openssl | default | No || +| protobuf | 21.7 >= x < 22 | Yes || +| boost | 1.77.0 | Yes || +| flex | 2.5.13 | No || +| bison | 3.0.4 | No || +| cmake | 3.28 | No || +| double-conversion | 3.1.5 | No || +| xsimd | 10.0.0 | Yes || +| re2 | 2024-07-02 | Yes || +| fmt | 11.2.0 | Yes | Used API must be fmt 9 compatible | +| simdjson | 4.1.0 | Yes || +| faiss | 1.11.0 | Yes || +| folly | v2026.01.05.00 | Yes || +| fizz | v2026.01.05.00 | No || +| wangle | v2026.01.05.00 | No || +| mvfst | v2026.01.05.00 | No || +| fbthrift | v2026.01.05.00 | No || +| libstemmer | 2.2.0 | Yes || +| DuckDB (testing) | 0.8.1 | Yes || +| arrow | 15.0.0 | Yes || +| geos | 3.10.7 | Yes || +| s2geometry | 0.12.0 | Yes || +| fast_float | v8.0.2 | Yes || +| xxhash | default | No || +| thrift | 0.16 | No || # Bundled Dependency Management This module provides a dependency management system that allows us to automatically fetch and build dependencies from source if needed. diff --git a/CMake/resolve_dependency_modules/arrow/CMakeLists.txt b/CMake/resolve_dependency_modules/arrow/CMakeLists.txt index 2d421827ab5..4f4e4031934 100644 --- a/CMake/resolve_dependency_modules/arrow/CMakeLists.txt +++ b/CMake/resolve_dependency_modules/arrow/CMakeLists.txt @@ -20,6 +20,19 @@ if(VELOX_ENABLE_ARROW) set(THRIFT_SOURCE "BUNDLED") endif() + # Avoid issues in finding the boost headers and libraries + # by setting the BOOST_ROOT to the install prefix. + # The same logic is used in the setup script-common.sh to install boost. + if(NOT DEFINED ENV{BOOST_ROOT} OR "$ENV{BOOST_ROOT}" STREQUAL "") + if(DEFINED ENV{INSTALL_PREFIX} AND NOT "$ENV{INSTALL_PREFIX}" STREQUAL "") + set(BOOST_ROOT "$ENV{INSTALL_PREFIX}") + else() + set(BOOST_ROOT "/usr/local") + endif() + else() + set(BOOST_ROOT "$ENV{BOOST_ROOT}") + endif() + message(STATUS "Using BOOST_ROOT: ${BOOST_ROOT}") set(ARROW_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/arrow_ep") set( ARROW_CMAKE_ARGS @@ -35,16 +48,18 @@ if(VELOX_ENABLE_ARROW) -DARROW_RUNTIME_SIMD_LEVEL=NONE -DARROW_WITH_UTF8PROC=OFF -DARROW_TESTING=ON + -DCMAKE_INSTALL_LIBDIR=lib -DCMAKE_INSTALL_PREFIX=${ARROW_PREFIX}/install -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DARROW_BUILD_STATIC=ON -DThrift_SOURCE=${THRIFT_SOURCE} -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} + -DBOOST_ROOT=${BOOST_ROOT} -DCMAKE_POLICY_VERSION_MINIMUM=3.5 # Remove with Arrow upgrade to Arrow 20. -DARROW_CXXFLAGS=-Wno-documentation ) - set(ARROW_LIBDIR ${ARROW_PREFIX}/install/${CMAKE_INSTALL_LIBDIR}) + set(ARROW_LIBDIR ${ARROW_PREFIX}/install/lib) add_library(thrift STATIC IMPORTED GLOBAL) if(NOT Thrift_FOUND) @@ -58,10 +73,10 @@ if(VELOX_ENABLE_ARROW) set_property(TARGET thrift PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${THRIFT_INCLUDE_DIR}) set_property(TARGET thrift PROPERTY IMPORTED_LOCATION ${THRIFT_LIB}) - set(VELOX_ARROW_BUILD_VERSION 15.0.0) + set(VELOX_ARROW_BUILD_VERSION 18.0.0) set( VELOX_ARROW_BUILD_SHA256_CHECKSUM - ab74c60c46938505c8cd7599b1d2826c68450645d5860d0ff40f67e371a5d0b5 + 9c473f2c9914c59ab571761c9497cf0e5cfd3ea335f7782ccc6121f5cb99ae9b ) set( VELOX_ARROW_SOURCE_URL @@ -79,8 +94,8 @@ if(VELOX_ENABLE_ARROW) CMAKE_ARGS ${ARROW_CMAKE_ARGS} BUILD_BYPRODUCTS ${ARROW_LIBDIR}/libarrow.a ${ARROW_LIBDIR}/libarrow_testing.a ${THRIFT_LIB} PATCH_COMMAND - git apply ${CMAKE_CURRENT_LIST_DIR}/thrift-download.patch && git apply - ${CMAKE_CURRENT_LIST_DIR}/cmake-compatibility.patch + git apply ${CMAKE_CURRENT_LIST_DIR}/cmake-compatibility.patch && git apply + ${CMAKE_CURRENT_LIST_DIR}/thrift-download.patch ) add_library(arrow STATIC IMPORTED GLOBAL) diff --git a/CMake/resolve_dependency_modules/arrow/cmake-compatibility.patch b/CMake/resolve_dependency_modules/arrow/cmake-compatibility.patch index 249ea609048..41b0528390d 100644 --- a/CMake/resolve_dependency_modules/arrow/cmake-compatibility.patch +++ b/CMake/resolve_dependency_modules/arrow/cmake-compatibility.patch @@ -14,7 +14,7 @@ --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake -@@ -971,7 +971,8 @@ +@@ -979,7 +979,8 @@ -DCMAKE_FIND_PACKAGE_NO_PACKAGE_REGISTRY=${CMAKE_FIND_PACKAGE_NO_PACKAGE_REGISTRY} -DCMAKE_INSTALL_LIBDIR=lib -DCMAKE_OSX_SYSROOT=${CMAKE_OSX_SYSROOT} @@ -22,9 +22,9 @@ + -DCMAKE_VERBOSE_MAKEFILE=${CMAKE_VERBOSE_MAKEFILE} + -DCMAKE_POLICY_VERSION_MINIMUM=3.5) - # Enable s/ccache if set by parent. - if(CMAKE_C_COMPILER_LAUNCHER AND CMAKE_CXX_COMPILER_LAUNCHER) -@@ -1026,6 +1027,7 @@ + # if building with a toolchain file, pass that through + if(CMAKE_TOOLCHAIN_FILE) +@@ -1045,6 +1046,7 @@ set(CMAKE_COMPILE_WARNING_AS_ERROR FALSE) set(CMAKE_EXPORT_NO_PACKAGE_REGISTRY TRUE) set(CMAKE_MACOSX_RPATH ${ARROW_INSTALL_NAME_RPATH}) diff --git a/CMake/resolve_dependency_modules/arrow/thrift-download.patch b/CMake/resolve_dependency_modules/arrow/thrift-download.patch index 421b2a1253c..92b8d87dd08 100644 --- a/CMake/resolve_dependency_modules/arrow/thrift-download.patch +++ b/CMake/resolve_dependency_modules/arrow/thrift-download.patch @@ -28,13 +28,13 @@ "ARROW_S2N_TLS_URL s2n-${ARROW_S2N_TLS_BUILD_VERSION}.tar.gz https://github.com/aws/s2n-tls/archive/${ARROW_S2N_TLS_BUILD_VERSION}.tar.gz" "ARROW_SNAPPY_URL snappy-${ARROW_SNAPPY_BUILD_VERSION}.tar.gz https://github.com/google/snappy/archive/${ARROW_SNAPPY_BUILD_VERSION}.tar.gz" - "ARROW_THRIFT_URL thrift-${ARROW_THRIFT_BUILD_VERSION}.tar.gz https://archive.apache.org/dist/thrift/${ARROW_THRIFT_BUILD_VERSION}/thrift-${ARROW_THRIFT_BUILD_VERSION}.tar.gz" -+ "ARROW_THRIFT_URL thrift-${ARROW_THRIFT_BUILD_VERSION}.tar.gz https://github.com/apache/thrift/archive/refs/tags/v${ARROW_THRIFT_BUILD_VERSION}.tar.gz ++ "ARROW_THRIFT_URL thrift-${ARROW_THRIFT_BUILD_VERSION}.tar.gz https://github.com/apache/thrift/archive/refs/tags/v${ARROW_THRIFT_BUILD_VERSION}.tar.gz" "ARROW_UCX_URL ucx-${ARROW_UCX_BUILD_VERSION}.tar.gz https://github.com/openucx/ucx/archive/v${ARROW_UCX_BUILD_VERSION}.tar.gz" "ARROW_UTF8PROC_URL utf8proc-${ARROW_UTF8PROC_BUILD_VERSION}.tar.gz https://github.com/JuliaStrings/utf8proc/archive/${ARROW_UTF8PROC_BUILD_VERSION}.tar.gz" "ARROW_XSIMD_URL xsimd-${ARROW_XSIMD_BUILD_VERSION}.tar.gz https://github.com/xtensor-stack/xsimd/archive/${ARROW_XSIMD_BUILD_VERSION}.tar.gz" --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake -@@ -809,20 +809,7 @@ +@@ -817,20 +817,7 @@ set(THRIFT_SOURCE_URL "$ENV{ARROW_THRIFT_URL}") else() set_urls(THRIFT_SOURCE_URL diff --git a/CMake/resolve_dependency_modules/boost.cmake b/CMake/resolve_dependency_modules/boost.cmake index 842cba6f133..dfbc6169847 100644 --- a/CMake/resolve_dependency_modules/boost.cmake +++ b/CMake/resolve_dependency_modules/boost.cmake @@ -15,13 +15,6 @@ include_guard(GLOBAL) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/boost) -if(ICU_SOURCE) - if(${ICU_SOURCE} STREQUAL "BUNDLED") - # ensure ICU is built before Boost - add_dependencies(boost_regex ICU ICU::i18n) - endif() -endif() - # This prevents system boost from leaking in set(Boost_NO_SYSTEM_PATHS ON) # We have to keep the FindBoost.cmake in an subfolder to prevent it from diff --git a/CMake/resolve_dependency_modules/cpr.cmake b/CMake/resolve_dependency_modules/cpr.cmake deleted file mode 100644 index 37cacac79cb..00000000000 --- a/CMake/resolve_dependency_modules/cpr.cmake +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -include_guard(GLOBAL) - -set(VELOX_CPR_VERSION 1.10.5) -set( - VELOX_CPR_BUILD_SHA256_CHECKSUM - c8590568996cea918d7cf7ec6845d954b9b95ab2c4980b365f582a665dea08d8 -) -set( - VELOX_CPR_SOURCE_URL - "https://github.com/libcpr/cpr/archive/refs/tags/${VELOX_CPR_VERSION}.tar.gz" -) - -# Add the dependency for curl, so that we can define the source URL for curl in -# curl.cmake. This will override the curl version declared by cpr. -set(curl_SOURCE BUNDLED) -velox_resolve_dependency(curl) - -velox_resolve_dependency_url(CPR) - -message(STATUS "Building cpr from source") -FetchContent_Declare( - cpr - URL ${VELOX_CPR_SOURCE_URL} - URL_HASH ${VELOX_CPR_BUILD_SHA256_CHECKSUM} - PATCH_COMMAND - git apply ${CMAKE_CURRENT_LIST_DIR}/cpr/cpr-libcurl-compatible.patch && git apply - ${CMAKE_CURRENT_LIST_DIR}/cpr/cpr-remove-sancheck.patch -) -set(BUILD_SHARED_LIBS ${VELOX_BUILD_SHARED}) -set(CPR_USE_SYSTEM_CURL OFF) -# ZLIB has already been found by find_package(ZLIB, REQUIRED), set CURL_ZLIB=OFF -# to save compile time. -set(CURL_ZLIB OFF) -FetchContent_MakeAvailable(cpr) -# libcpr in its CMakeLists.txt file disables the BUILD_TESTING globally when -# CPR_USE_SYSTEM_CURL=OFF. unset BUILD_TESTING here. -unset(BUILD_TESTING) -unset(BUILD_SHARED_LIBS) diff --git a/CMake/resolve_dependency_modules/cpr/cpr-libcurl-compatible.patch b/CMake/resolve_dependency_modules/cpr/cpr-libcurl-compatible.patch deleted file mode 100644 index 49821889f2b..00000000000 --- a/CMake/resolve_dependency_modules/cpr/cpr-libcurl-compatible.patch +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This can be removed once we upgrade to curl >= 7.68.0 ---- a/cpr/multiperform.cpp -+++ b/cpr/multiperform.cpp -@@ -97,9 +97,9 @@ void MultiPerform::DoMultiPerform() { - - if (still_running) { - const int timeout_ms{250}; -- error_code = curl_multi_poll(multicurl_->handle, nullptr, 0, timeout_ms, nullptr); -+ error_code = curl_multi_wait(multicurl_->handle, nullptr, 0, timeout_ms, nullptr); - if (error_code) { -- std::cerr << "curl_multi_poll() failed, code " << static_cast(error_code) << std::endl; -+ std::cerr << "curl_multi_wait() failed, code " << static_cast(error_code) << std::endl; - break; - } - } - ---- a/include/cpr/util.h -+++ b/include/cpr/util.h -@@ -23,7 +23,7 @@ size_t writeUserFunction(char* ptr, size_t size, size_t nmemb, const WriteCallba - template - int progressUserFunction(const T* progress, cpr_pf_arg_t dltotal, cpr_pf_arg_t dlnow, cpr_pf_arg_t ultotal, cpr_pf_arg_t ulnow) { - const int cancel_retval{1}; -- static_assert(cancel_retval != CURL_PROGRESSFUNC_CONTINUE); -+ static_assert(cancel_retval != 0x10000001); - return (*progress)(dltotal, dlnow, ultotal, ulnow) ? 0 : cancel_retval; - } - int debugUserFunction(CURL* handle, curl_infotype type, char* data, size_t size, const DebugCallback* debug); diff --git a/CMake/resolve_dependency_modules/cudf.cmake b/CMake/resolve_dependency_modules/cudf.cmake index 278744e9d43..a945660f8e7 100644 --- a/CMake/resolve_dependency_modules/cudf.cmake +++ b/CMake/resolve_dependency_modules/cudf.cmake @@ -17,49 +17,50 @@ include_guard(GLOBAL) # 3.30.4 is the minimum version required by cudf cmake_minimum_required(VERSION 3.30.4) -set(VELOX_rapids_cmake_VERSION 25.10) +# rapids_cmake commit d79e071 from 2026-05-01 +set(VELOX_rapids_cmake_VERSION 26.06) +set(VELOX_rapids_cmake_COMMIT d79e071f805e709771b80008d50a8b3a5bed93ca) set( VELOX_rapids_cmake_BUILD_SHA256_CHECKSUM - 635aff67e017c64021bf3d225d31f843e9541f3bf9c3d07bac72466dc57c917b + d0f9eea4feaef1cc325e86eac787052ec951659fdcf21abdfb06efc337a63179 ) set( VELOX_rapids_cmake_SOURCE_URL - "https://github.com/rapidsai/rapids-cmake/archive/0b111489d1e6f8400e1fc88297623a2a9915fa77.tar.gz" + "https://github.com/rapidsai/rapids-cmake/archive/${VELOX_rapids_cmake_COMMIT}.tar.gz" ) velox_resolve_dependency_url(rapids_cmake) -set(VELOX_rmm_VERSION 25.10) +# rmm commit 2357ddd from 2026-05-04 +set(VELOX_rmm_VERSION 26.06) +set(VELOX_rmm_COMMIT 2357ddddcff042ba378e8de6f89e4a995a23b2db) set( VELOX_rmm_BUILD_SHA256_CHECKSUM - 72dd6a26a1a75e193723571ec7ba8bcb040ea9a38592eb0809e64ebdbf291d76 -) -set( - VELOX_rmm_SOURCE_URL - "https://github.com/rapidsai/rmm/archive/7cef2f5f30e962e9f3b27a3a3f2753a40277c093.tar.gz" + 61492c2da88e7f6a6a4edc7101cce4698c156704d135db0b80119f4b9a2c575c ) +set(VELOX_rmm_SOURCE_URL "https://github.com/rapidsai/rmm/archive/${VELOX_rmm_COMMIT}.tar.gz") velox_resolve_dependency_url(rmm) -set(VELOX_kvikio_VERSION 25.10) +# kvikio commit 247b64e from 2026-05-02 +set(VELOX_kvikio_VERSION 26.06) +set(VELOX_kvikio_COMMIT 247b64e97ecb7cb9ccb06ab123aea87ac571c5b4) set( VELOX_kvikio_BUILD_SHA256_CHECKSUM - 76c217bd925f7665246135311697393b5118185d4bdd4291e8ff4506e4feb6af + 6419490de95e412cefdbafffc73fac6b2162bc2200076611883fe41637028198 ) set( VELOX_kvikio_SOURCE_URL - "https://github.com/rapidsai/kvikio/archive/6efd22dc6ae3389caea7d3e736c7f954b9db0619.tar.gz" + "https://github.com/rapidsai/kvikio/archive/${VELOX_kvikio_COMMIT}.tar.gz" ) velox_resolve_dependency_url(kvikio) -set(VELOX_cudf_VERSION 25.10 CACHE STRING "cudf version") - +# cudf commit d09d10d from 2026-05-04 +set(VELOX_cudf_VERSION 26.06 CACHE STRING "cudf version") +set(VELOX_cudf_COMMIT d09d10d14d3ed932b8de93638809101af5c7fec3) set( VELOX_cudf_BUILD_SHA256_CHECKSUM - c7dfb333ee0cb9f86d5ee94aaa34985ae6cf45d4ed8658d850707cc8e0db8e16 -) -set( - VELOX_cudf_SOURCE_URL - "https://github.com/rapidsai/cudf/archive/2bfd896b4e0c1f0b66402c1e067b4904dbd15c5e.tar.gz" + 5042ec46beb8260eb60d13b9cd44f26357b9756628f7d58659c77e88c67e15d5 ) +set(VELOX_cudf_SOURCE_URL "https://github.com/rapidsai/cudf/archive/${VELOX_cudf_COMMIT}.tar.gz") velox_resolve_dependency_url(cudf) # Use block so we don't leak variables @@ -114,4 +115,5 @@ block(SCOPE_FOR VARIABLES) ) unset(BUILD_SHARED_LIBS) + unset(BUILD_TESTING CACHE) endblock() diff --git a/CMake/resolve_dependency_modules/curl.cmake b/CMake/resolve_dependency_modules/curl.cmake index c749c06a23f..fd4b51adf2b 100644 --- a/CMake/resolve_dependency_modules/curl.cmake +++ b/CMake/resolve_dependency_modules/curl.cmake @@ -13,19 +13,4 @@ # limitations under the License. include_guard(GLOBAL) -set(VELOX_CURL_VERSION 8.4.0) -string(REPLACE "." "_" VELOX_CURL_VERSION_UNDERSCORES ${VELOX_CURL_VERSION}) -set( - VELOX_CURL_BUILD_SHA256_CHECKSUM - 16c62a9c4af0f703d28bda6d7bbf37ba47055ad3414d70dec63e2e6336f2a82d -) -string( - CONCAT - VELOX_CURL_SOURCE_URL - "https://github.com/curl/curl/releases/download/" - "curl-${VELOX_CURL_VERSION_UNDERSCORES}/curl-${VELOX_CURL_VERSION}.tar.xz" -) - -velox_resolve_dependency_url(CURL) - -FetchContent_Declare(curl URL ${VELOX_CURL_SOURCE_URL} URL_HASH ${VELOX_CURL_BUILD_SHA256_CHECKSUM}) +add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/curl) diff --git a/CMake/resolve_dependency_modules/curl/CMakeLists.txt b/CMake/resolve_dependency_modules/curl/CMakeLists.txt new file mode 100644 index 00000000000..867c19cf242 --- /dev/null +++ b/CMake/resolve_dependency_modules/curl/CMakeLists.txt @@ -0,0 +1,42 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set(VELOX_CURL_VERSION 8.4.0) +string(REPLACE "." "_" VELOX_CURL_VERSION_UNDERSCORES ${VELOX_CURL_VERSION}) +set( + VELOX_CURL_BUILD_SHA256_CHECKSUM + 16c62a9c4af0f703d28bda6d7bbf37ba47055ad3414d70dec63e2e6336f2a82d +) +string( + CONCAT + VELOX_CURL_SOURCE_URL + "https://github.com/curl/curl/releases/download/" + "curl-${VELOX_CURL_VERSION_UNDERSCORES}/curl-${VELOX_CURL_VERSION}.tar.xz" +) + +set(BUILD_TESTING OFF) +set(BUILD_SHARED_LIBS ON) + +velox_resolve_dependency_url(CURL) + +message(STATUS "Building CURL from source") + +FetchContent_Declare(curl URL ${VELOX_CURL_SOURCE_URL} URL_HASH ${VELOX_CURL_BUILD_SHA256_CHECKSUM}) + +FetchContent_MakeAvailable(curl) + +# Curl uses CMake option for BUILD_TESTING and BUILD_SHARED_LIBS +# See CMake option semantics. +unset(BUILD_TESTING CACHE) +unset(BUILD_SHARED_LIBS CACHE) diff --git a/CMake/resolve_dependency_modules/duckdb.cmake b/CMake/resolve_dependency_modules/duckdb.cmake index 8859ccb6ac5..bd3fbd7e14f 100644 --- a/CMake/resolve_dependency_modules/duckdb.cmake +++ b/CMake/resolve_dependency_modules/duckdb.cmake @@ -44,6 +44,7 @@ FetchContent_Declare( # that. set(GIT_COMMIT_HASH "6536a77") set(BUILD_UNITTESTS OFF) +set(BUILD_TESTING OFF) set(ENABLE_SANITIZER OFF) set(ENABLE_UBSAN OFF) set(BUILD_SHELL OFF) @@ -65,3 +66,5 @@ endif() set(CMAKE_CXX_FLAGS ${PREVIOUS_CMAKE_CXX_FLAGS}) set(CMAKE_BUILD_TYPE ${PREVIOUS_BUILD_TYPE}) +# Some DuckDB third-party package sets this flags. We cannot control that. +unset(BUILD_TESTING) diff --git a/CMake/resolve_dependency_modules/faiss.cmake b/CMake/resolve_dependency_modules/faiss.cmake index 911ba5c8ef9..db158dbe65d 100644 --- a/CMake/resolve_dependency_modules/faiss.cmake +++ b/CMake/resolve_dependency_modules/faiss.cmake @@ -60,10 +60,14 @@ FetchContent_Declare( # Set build options block() set(BUILD_SHARED_LIBS OFF) + set(BUILD_TESTING OFF) set(CMAKE_BUILD_TYPE Release) set(FAISS_ENABLE_GPU OFF) set(FAISS_ENABLE_PYTHON OFF) set(FAISS_ENABLE_GPU_TESTS OFF) # Make FAISS available FetchContent_MakeAvailable(faiss) + add_library(FAISS::faiss ALIAS faiss) + unset(BUILD_TESTING CACHE) + unset(BUILD_SHARED_LIBS CACHE) endblock() diff --git a/CMake/resolve_dependency_modules/fmt.cmake b/CMake/resolve_dependency_modules/fmt.cmake index fc69a934a83..9e8087f1da0 100644 --- a/CMake/resolve_dependency_modules/fmt.cmake +++ b/CMake/resolve_dependency_modules/fmt.cmake @@ -13,10 +13,10 @@ # limitations under the License. include_guard(GLOBAL) -set(VELOX_FMT_VERSION 10.1.1) +set(VELOX_FMT_VERSION 11.2.0) set( VELOX_FMT_BUILD_SHA256_CHECKSUM - 78b8c0a72b1c35e4443a7e308df52498252d1cefc2b08c9a97bc9ee6cfe61f8b + bc23066d87ab3168f27cef3e97d545fa63314f5c79df5ea444d41d56f962c6af ) set(VELOX_FMT_SOURCE_URL "https://github.com/fmtlib/fmt/archive/${VELOX_FMT_VERSION}.tar.gz") diff --git a/CMake/resolve_dependency_modules/folly/CMakeLists.txt b/CMake/resolve_dependency_modules/folly/CMakeLists.txt index 28af264cfaa..8139c11a57e 100644 --- a/CMake/resolve_dependency_modules/folly/CMakeLists.txt +++ b/CMake/resolve_dependency_modules/folly/CMakeLists.txt @@ -14,13 +14,13 @@ project(Folly) cmake_minimum_required(VERSION 3.28) -velox_set_source(fastfloat) -velox_resolve_dependency(fastfloat CONFIG REQUIRED) +velox_set_source(FastFloat) +velox_resolve_dependency(FastFloat CONFIG REQUIRED) -set(VELOX_FOLLY_BUILD_VERSION v2025.04.28.00) +set(VELOX_FOLLY_BUILD_VERSION v2026.01.05.00) set( VELOX_FOLLY_BUILD_SHA256_CHECKSUM - ccbb7eac662023f9f5beba94e51350d527f33d8a7a036eb5e3d8a5cf1b49d3bc + 8b41494b664fcde3d02f652d6a27381ec5e938ef58cf986d8ce4958d424af101 ) set( VELOX_FOLLY_SOURCE_URL @@ -54,8 +54,6 @@ set(BUILD_SHARED_LIBS ${VELOX_BUILD_SHARED}) # Enable INT128 support set(FOLLY_HAVE_INT128_T ON) -set(PREVIOUS_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DFOLLY_CFG_NO_COROUTINES") FetchContent_MakeAvailable(folly) @@ -66,4 +64,3 @@ add_library(Folly::follybenchmark ALIAS follybenchmark) if(gflags_SOURCE STREQUAL "BUNDLED") add_dependencies(folly glog::glog gflags::gflags fmt::fmt) endif() -set(CMAKE_CXX_FLAGS ${PREVIOUS_CMAKE_CXX_FLAGS}) diff --git a/CMake/resolve_dependency_modules/geos.cmake b/CMake/resolve_dependency_modules/geos.cmake index f092ba823b1..6037d7ac20e 100644 --- a/CMake/resolve_dependency_modules/geos.cmake +++ b/CMake/resolve_dependency_modules/geos.cmake @@ -19,13 +19,13 @@ block() set(VELOX_GEOS_BUILD_VERSION 3.10.7) set( VELOX_GEOS_BUILD_SHA256_CHECKSUM - 8b2ab4d04d660e27f2006550798f49dd11748c3767455cae9f71967dc437da1f + fcde02913159711fd3188afccf9ba78008ecced089d7e6cb6ff9a4bc1eed10a1 ) string( CONCAT VELOX_GEOS_SOURCE_URL - "https://download.osgeo.org/geos/" - "geos-${VELOX_GEOS_BUILD_VERSION}.tar.bz2" + "https://github.com/libgeos/geos/archive/refs/tags/" + "${VELOX_GEOS_BUILD_VERSION}.tar.gz" ) velox_resolve_dependency_url(GEOS) diff --git a/CMake/resolve_dependency_modules/re2.cmake b/CMake/resolve_dependency_modules/re2.cmake index b31bed33b6e..738e0564b7c 100644 --- a/CMake/resolve_dependency_modules/re2.cmake +++ b/CMake/resolve_dependency_modules/re2.cmake @@ -60,3 +60,5 @@ set(re2_INCLUDE_DIRS ${re2_SOURCE_DIR}) set(RE2_ROOT ${re2_BINARY_DIR}) set(re2_ROOT ${re2_BINARY_DIR}) + +unset(BUILD_TESTING CACHE) diff --git a/CMake/resolve_dependency_modules/s2geometry.cmake b/CMake/resolve_dependency_modules/s2geometry.cmake new file mode 100644 index 00000000000..a7f39397578 --- /dev/null +++ b/CMake/resolve_dependency_modules/s2geometry.cmake @@ -0,0 +1,66 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +include_guard(GLOBAL) + +list(PREPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_LIST_DIR}/s2geometry) + +# This creates a separate scope so any changed variables don't affect +# the rest of the build. +block() + # s2geometry needs absl. + if(NOT TARGET absl::base) + velox_set_source(absl) + velox_resolve_dependency(absl) + endif() + + set(VELOX_S2GEOMETRY_BUILD_VERSION 0.12.0) + set( + VELOX_S2GEOMETRY_BUILD_SHA256_CHECKSUM + c09ec751c3043965a0d441e046a73c456c995e6063439a72290f661c1054d611 + ) + string( + CONCAT + VELOX_S2GEOMETRY_SOURCE_URL + "https://github.com/google/s2geometry/archive/refs/tags/" + "v${VELOX_S2GEOMETRY_BUILD_VERSION}.tar.gz" + ) + + velox_resolve_dependency_url(S2GEOMETRY) + + FetchContent_Declare( + s2geometry + URL ${VELOX_S2GEOMETRY_SOURCE_URL} + URL_HASH ${VELOX_S2GEOMETRY_BUILD_SHA256_CHECKSUM} + OVERRIDE_FIND_PACKAGE + SYSTEM + EXCLUDE_FROM_ALL + PATCH_COMMAND git apply ${CMAKE_CURRENT_LIST_DIR}/s2geometry/s2geometry-gcc12-max.patch + ) + + list(APPEND CMAKE_MODULE_PATH "${s2geometry_SOURCE_DIR}/cmake") + set(BUILD_SHARED_LIBS OFF) + set(BUILD_TESTING OFF) + set(BUILD_TESTS OFF) + set(CMAKE_BUILD_TYPE Release) + + FetchContent_MakeAvailable(s2geometry) + + # Clang does not enable C++14 sized deallocation by default, unlike GCC. + # s2geometry's port.h uses ::operator delete(ptr, size) which requires it. + if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") + target_compile_options(s2 PRIVATE -fsized-deallocation) + endif() + + add_library(s2::s2 ALIAS s2) +endblock() diff --git a/CMake/resolve_dependency_modules/s2geometry/Finds2geometry.cmake b/CMake/resolve_dependency_modules/s2geometry/Finds2geometry.cmake new file mode 100644 index 00000000000..a93b3edaf82 --- /dev/null +++ b/CMake/resolve_dependency_modules/s2geometry/Finds2geometry.cmake @@ -0,0 +1,23 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Find the s2geometry library installed on the system. +# s2geometry installs its CMake config as "s2Config.cmake" with the s2::s2 +# target. This shim bridges the package name difference so that +# find_package(s2geometry) works. + +find_package(s2 CONFIG QUIET) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(s2geometry DEFAULT_MSG s2_FOUND) diff --git a/CMake/resolve_dependency_modules/s2geometry/s2geometry-gcc12-max.patch b/CMake/resolve_dependency_modules/s2geometry/s2geometry-gcc12-max.patch new file mode 100644 index 00000000000..185756c399c --- /dev/null +++ b/CMake/resolve_dependency_modules/s2geometry/s2geometry-gcc12-max.patch @@ -0,0 +1,11 @@ +--- a/src/s2/s2cell_id.cc ++++ b/src/s2/s2cell_id.cc +@@ -195,7 +195,7 @@ int S2CellId::GetCommonAncestorLevel(S2CellId other) const { + // Compute the position of the most significant bit, and then map the bit + // position as follows: + // {0} -> 30, {1,2} -> 29, {3,4} -> 28, ... , {59,60} -> 0, {61,62,63} -> -1. +- return max(61 - absl::bit_width(bits), -1) >> 1; ++ return std::max(61 - absl::bit_width(bits), -1) >> 1; + } + + // Print the num_digits low order hex digits. diff --git a/CMake/resolve_dependency_modules/simdjson.cmake b/CMake/resolve_dependency_modules/simdjson.cmake index 962a76fadf1..59894896c8b 100644 --- a/CMake/resolve_dependency_modules/simdjson.cmake +++ b/CMake/resolve_dependency_modules/simdjson.cmake @@ -13,10 +13,10 @@ # limitations under the License. include_guard(GLOBAL) -set(VELOX_SIMDJSON_VERSION 3.9.3) +set(VELOX_SIMDJSON_VERSION 4.1.0) set( VELOX_SIMDJSON_BUILD_SHA256_CHECKSUM - 2e3d10abcde543d3dd8eba9297522cafdcebdd1db4f51b28f3bc95bf1d6ad23c + 78115e37b2e88ec63e6ae20bb148063a9112c55bcd71404c8572078fd8a6ac3e ) set( VELOX_SIMDJSON_SOURCE_URL diff --git a/CMake/third-party/FBThriftCppLibrary.cmake b/CMake/third-party/FBThriftCppLibrary.cmake index 0d05546a26f..416a88b752f 100644 --- a/CMake/third-party/FBThriftCppLibrary.cmake +++ b/CMake/third-party/FBThriftCppLibrary.cmake @@ -4,22 +4,29 @@ include(FBCMakeParseArgs) # Generate a C++ library from a thrift file # -# Parameters: - SERVICES [ ...] The names of the services defined -# in the thrift file. - DEPENDS [ ...] A list of other thrift C++ -# libraries that this library depends on. - OPTIONS [ ...] A list -# of options to pass to the thrift compiler. - INCLUDE_DIR The -# sub-directory where generated headers will be installed. Defaults to "include" -# if not specified. The caller must still call install() to install the thrift -# library if desired. - THRIFT_INCLUDE_DIR The sub-directory where -# generated headers will be installed. Defaults to "${INCLUDE_DIR}/thrift-files" -# if not specified. The caller must still call install() to install the thrift -# library if desired. +# Parameters: +# - SERVICES [ ...] +# The names of the services defined in the thrift file. +# - DEPENDS [ ...] +# A list of other thrift C++ libraries that this library depends on. +# - OPTIONS [ ...] +# A list of options to pass to the thrift compiler. +# - INCLUDE_DIR +# The sub-directory where generated headers will be installed. +# Defaults to "include" if not specified. The caller must still call +# install() to install the thrift library if desired. +# - THRIFT_INCLUDE_DIR +# The sub-directory where generated headers will be installed. +# Defaults to "${INCLUDE_DIR}/thrift-files" if not specified. +# The caller must still call install() to install the thrift library if +# desired. function(add_fbthrift_cpp_library LIB_NAME THRIFT_FILE) # Parse the arguments set(one_value_args INCLUDE_DIR THRIFT_INCLUDE_DIR) set(multi_value_args SERVICES DEPENDS OPTIONS) - fb_cmake_parse_args(ARG "" "${one_value_args}" "${multi_value_args}" - "${ARGN}") + fb_cmake_parse_args( + ARG "" "${one_value_args}" "${multi_value_args}" "${ARGN}" + ) if(NOT DEFINED ARG_INCLUDE_DIR) set(ARG_INCLUDE_DIR "include") endif() @@ -28,134 +35,168 @@ function(add_fbthrift_cpp_library LIB_NAME THRIFT_FILE) endif() get_filename_component(base ${THRIFT_FILE} NAME_WE) - get_filename_component(output_dir ${CMAKE_CURRENT_BINARY_DIR}/${THRIFT_FILE} - DIRECTORY) + get_filename_component( + output_dir + ${CMAKE_CURRENT_BINARY_DIR}/${THRIFT_FILE} + DIRECTORY + ) # Generate relative paths in #includes - file(RELATIVE_PATH include_prefix "${CMAKE_SOURCE_DIR}" - "${CMAKE_CURRENT_SOURCE_DIR}/${THRIFT_FILE}") + file( + RELATIVE_PATH include_prefix + "${CMAKE_SOURCE_DIR}" + "${CMAKE_CURRENT_SOURCE_DIR}/${THRIFT_FILE}" + ) get_filename_component(include_prefix ${include_prefix} DIRECTORY) - if(NOT "${include_prefix}" STREQUAL "") + if (NOT "${include_prefix}" STREQUAL "") list(APPEND ARG_OPTIONS "include_prefix=${include_prefix}") endif() - # CMake 3.12 is finally getting a list(JOIN) function, but until then treating - # the list as a string and replacing the semicolons is good enough. + # CMake 3.12 is finally getting a list(JOIN) function, but until then + # treating the list as a string and replacing the semicolons is good enough. string(REPLACE ";" "," GEN_ARG_STR "${ARG_OPTIONS}") # Compute the list of generated files - list( - APPEND - generated_headers + list(APPEND generated_headers "${output_dir}/gen-cpp2/${base}_constants.h" "${output_dir}/gen-cpp2/${base}_types.h" "${output_dir}/gen-cpp2/${base}_types.tcc" "${output_dir}/gen-cpp2/${base}_types_custom_protocol.h" - "${output_dir}/gen-cpp2/${base}_metadata.h") - list( - APPEND - generated_sources + "${output_dir}/gen-cpp2/${base}_metadata.h" + ) + list(APPEND generated_sources "${output_dir}/gen-cpp2/${base}_constants.cpp" "${output_dir}/gen-cpp2/${base}_data.h" "${output_dir}/gen-cpp2/${base}_data.cpp" "${output_dir}/gen-cpp2/${base}_types.cpp" - "${output_dir}/gen-cpp2/${base}_metadata.cpp") + "${output_dir}/gen-cpp2/${base}_types_binary.cpp" + "${output_dir}/gen-cpp2/${base}_types_compact.cpp" + "${output_dir}/gen-cpp2/${base}_types_serialization.cpp" + "${output_dir}/gen-cpp2/${base}_metadata.cpp" + ) foreach(service IN LISTS ARG_SERVICES) - list( - APPEND - generated_headers + list(APPEND generated_headers "${output_dir}/gen-cpp2/${service}.h" "${output_dir}/gen-cpp2/${service}.tcc" "${output_dir}/gen-cpp2/${service}AsyncClient.h" - "${output_dir}/gen-cpp2/${service}_custom_protocol.h") - list( - APPEND - generated_sources + "${output_dir}/gen-cpp2/${service}_custom_protocol.h" + ) + list(APPEND generated_sources "${output_dir}/gen-cpp2/${service}.cpp" "${output_dir}/gen-cpp2/${service}AsyncClient.cpp" "${output_dir}/gen-cpp2/${service}_processmap_binary.cpp" - "${output_dir}/gen-cpp2/${service}_processmap_compact.cpp") + "${output_dir}/gen-cpp2/${service}_processmap_compact.cpp" + ) endforeach() - # This generator expression gets the list of include directories required for - # all of our dependencies. It requires using COMMAND_EXPAND_LISTS in the - # add_custom_command() call below. COMMAND_EXPAND_LISTS is only available in - # CMake 3.8+ If we really had to support older versions of CMake we would - # probably need to use a wrapper script around the thrift compiler that could - # take the include list as a single argument and split it up before invoking - # the thrift compiler. - if(NOT POLICY CMP0067) + # This generator expression gets the list of include directories required + # for all of our dependencies. + # It requires using COMMAND_EXPAND_LISTS in the add_custom_command() call + # below. COMMAND_EXPAND_LISTS is only available in CMake 3.8+ + # If we really had to support older versions of CMake we would probably need + # to use a wrapper script around the thrift compiler that could take the + # include list as a single argument and split it up before invoking the + # thrift compiler. + if (NOT POLICY CMP0067) message(FATAL_ERROR "add_fbthrift_cpp_library() requires CMake 3.8+") endif() - set(thrift_include_options - "-I;$,;-I;>" + set( + thrift_include_options + "-I;$,;-I;>" ) # Emit the rule to run the thrift compiler add_custom_command( - OUTPUT ${generated_headers} ${generated_sources} + OUTPUT + ${generated_headers} + ${generated_sources} COMMAND_EXPAND_LISTS - COMMAND "${CMAKE_COMMAND}" -E make_directory "${output_dir}" + COMMAND + "${CMAKE_COMMAND}" -E make_directory "${output_dir}" COMMAND "${FBTHRIFT_COMPILER}" - # TODO: this flag does not exist in all fbthrift versions. --legacy-strict - --gen "mstch_cpp2:${GEN_ARG_STR}" "${thrift_include_options}" -I - "${FBTHRIFT_INCLUDE_DIR}" -o "${output_dir}" + --legacy-strict + --gen "mstch_cpp2:${GEN_ARG_STR}" + "${thrift_include_options}" + -I "${FBTHRIFT_INCLUDE_DIR}" + -o "${output_dir}" "${CMAKE_CURRENT_SOURCE_DIR}/${THRIFT_FILE}" - WORKING_DIRECTORY "${CMAKE_BINARY_DIR}" - MAIN_DEPENDENCY "${THRIFT_FILE}" - DEPENDS ${ARG_DEPENDS} "${FBTHRIFT_COMPILER}") + WORKING_DIRECTORY + "${CMAKE_BINARY_DIR}" + MAIN_DEPENDENCY + "${THRIFT_FILE}" + DEPENDS + ${ARG_DEPENDS} + "${FBTHRIFT_COMPILER}" + ) # Now emit the library rule to compile the sources - if(BUILD_SHARED_LIBS) + if (BUILD_SHARED_LIBS) set(LIB_TYPE SHARED) - else() + else () set(LIB_TYPE STATIC) - endif() + endif () - add_library("${LIB_NAME}" ${LIB_TYPE} ${generated_sources}) + add_library( + "${LIB_NAME}" ${LIB_TYPE} + ${generated_sources} + ) target_include_directories( - "${LIB_NAME}" PUBLIC "$" - "$") + "${LIB_NAME}" + PUBLIC + "$" + "$" + ${Xxhash_INCLUDE_DIR} + ) target_link_libraries( "${LIB_NAME}" - PUBLIC ${ARG_DEPENDS} - Folly::folly - FBThrift::thriftcpp2 - # TODO: these symbols require other dependencies that need to be - # correctly handle by Velox's build system before they can be - # enabled. mvfst::mvfst_server_async_tran mvfst::mvfst_server + PUBLIC + ${ARG_DEPENDS} + FBThrift::thriftcpp2 + Folly::folly + mvfst::mvfst_server_async_tran + mvfst::mvfst_server + ${Xxhash_LIBRARY} ) # Add ${generated_headers} to the PUBLIC_HEADER property for ${LIB_NAME} # - # This allows callers to install it using "install(TARGETS ${LIB_NAME} - # PUBLIC_HEADER)" However, note that CMake's PUBLIC_HEADER behavior is rather - # inflexible, and does have any way to preserve header directory structure. - # Callers must be careful to use the correct PUBLIC_HEADER DESTINATION - # parameter when doing this, to put the files the correct directory - # themselves. We define a HEADER_INSTALL_DIR property with the include - # directory prefix, so typically callers should specify the PUBLIC_HEADER - # DESTINATION as "$" - set_property(TARGET "${LIB_NAME}" PROPERTY PUBLIC_HEADER ${generated_headers}) + # This allows callers to install it using + # "install(TARGETS ${LIB_NAME} PUBLIC_HEADER)" + # However, note that CMake's PUBLIC_HEADER behavior is rather inflexible, + # and does have any way to preserve header directory structure. Callers + # must be careful to use the correct PUBLIC_HEADER DESTINATION parameter + # when doing this, to put the files the correct directory themselves. + # We define a HEADER_INSTALL_DIR property with the include directory prefix, + # so typically callers should specify the PUBLIC_HEADER DESTINATION as + # "$" + set_property( + TARGET "${LIB_NAME}" + PROPERTY PUBLIC_HEADER ${generated_headers} + ) # Define a dummy interface library to help propagate the thrift include # directories between dependencies. add_library("${LIB_NAME}.thrift_includes" INTERFACE) target_include_directories( "${LIB_NAME}.thrift_includes" - INTERFACE "$" - "$") + INTERFACE + "$" + "$" + ) foreach(dep IN LISTS ARG_DEPENDS) - target_link_libraries("${LIB_NAME}.thrift_includes" - INTERFACE "${dep}.thrift_includes") + target_link_libraries( + "${LIB_NAME}.thrift_includes" + INTERFACE "${dep}.thrift_includes" + ) endforeach() set_target_properties( "${LIB_NAME}" - PROPERTIES EXPORT_PROPERTIES "THRIFT_INSTALL_DIR" - THRIFT_INSTALL_DIR "${ARG_THRIFT_INCLUDE_DIR}/${include_prefix}" - HEADER_INSTALL_DIR - "${ARG_INCLUDE_DIR}/${include_prefix}/gen-cpp2") + PROPERTIES + EXPORT_PROPERTIES "THRIFT_INSTALL_DIR" + THRIFT_INSTALL_DIR "${ARG_THRIFT_INCLUDE_DIR}/${include_prefix}" + HEADER_INSTALL_DIR "${ARG_INCLUDE_DIR}/${include_prefix}/gen-cpp2" + ) endfunction() diff --git a/CMakeLists.txt b/CMakeLists.txt index 4b048a59739..bd6d45055ed 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -111,7 +111,7 @@ endif() if(VELOX_BUILD_SHARED) message( WARNING - "When building Velox as a shared library it's recommended to build against a shared build of folly to avoid issues with linking of gflags." + "When building Velox as a shared library it's recommended to build against a shared build of folly to avoid issues with linking of gflags. " "This is currently NOT being enforced so user discretion is advised." ) endif() @@ -156,7 +156,7 @@ option(VELOX_ENABLE_ABFS "Build Abfs Connector" OFF) option(VELOX_ENABLE_HDFS "Build Hdfs Connector" OFF) option(VELOX_ENABLE_PARQUET "Enable Parquet support" ON) option(VELOX_ENABLE_ARROW "Enable Arrow support" OFF) -option(VELOX_ENABLE_GEO "Enable Geospatial support" OFF) +option(VELOX_ENABLE_GEO "Enable Geospatial support" ON) option(VELOX_ENABLE_REMOTE_FUNCTIONS "Enable remote function support" OFF) option(VELOX_ENABLE_CCACHE "Use ccache if installed." ON) option(VELOX_ENABLE_COMPRESSION_LZ4 "Enable Lz4 compression support." OFF) @@ -212,9 +212,24 @@ if(VELOX_ENABLE_BENCHMARKS_BASIC) set(VELOX_BUILD_TEST_UTILS ON) endif() +if(VELOX_ENABLE_CUDF) + if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + message(FATAL_ERROR "cuDF requires GCC. Found ${CMAKE_CXX_COMPILER_ID}.") + elseif(CMAKE_CXX_COMPILER_VERSION VERSION_LESS 13.3) + message(FATAL_ERROR "cuDF requires GCC >= 13.3. Found GCC ${CMAKE_CXX_COMPILER_VERSION}.") + endif() + message(STATUS "Building curl from source to satisfy cuDF curl version requirement") + set(CURL_SOURCE BUNDLED) + velox_resolve_dependency(CURL) +endif() + if(VELOX_BUILD_TESTING OR VELOX_BUILD_TEST_UTILS) - set(cpr_SOURCE BUNDLED) - velox_resolve_dependency(cpr) + # cuDF bundles curl since it needs a specific version. + # Use bundled or system curl otherwise. + if(NOT VELOX_ENABLE_CUDF) + velox_set_source(CURL) + velox_resolve_dependency(CURL) + endif() set(VELOX_ENABLE_DUCKDB ON) set(VELOX_ENABLE_PARSE ON) endif() @@ -300,7 +315,7 @@ if(VELOX_ENABLE_S3) if(AWSSDK_ROOT_DIR) list(APPEND CMAKE_PREFIX_PATH ${AWSSDK_ROOT_DIR}) endif() - find_package(AWSSDK REQUIRED COMPONENTS s3;identity-management) + find_package(AWSSDK 1.11.654 REQUIRED COMPONENTS s3;identity-management) add_definitions(-DVELOX_ENABLE_S3) endif() @@ -396,13 +411,6 @@ if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsigned-char") endif() -# Ensure that we don't bring in headers that might have coroutines enabled. The -# dependencies turn off coroutines in folly and the other FBOS dependencies. If -# not explicitly turned off differences in the build using GCC cause libray -# incompatibilities of the thrift server and the remote functions library -# causing SEGVs in the tests. -string(APPEND CMAKE_CXX_FLAGS " -DFOLLY_CFG_NO_COROUTINES") - # Under Ninja, we are able to designate certain targets large enough to require # restricted parallelism. if("${MAX_HIGH_MEM_JOBS}") @@ -445,7 +453,8 @@ if(ENABLE_ALL_WARNINGS) endif() if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL "14.0.0") string(APPEND KNOWN_COMPILER_SPECIFIC_WARNINGS " -Wno-error=template-id-cdtor") - string(APPEND KNOWN_COMPILER_SPECIFIC_WARNINGS " -Wno-error=overloaded-virtual") + string(APPEND KNOWN_COMPILER_SPECIFIC_WARNINGS " -Wno-overloaded-virtual") + string(APPEND KNOWN_COMPILER_SPECIFIC_WARNINGS " -Wno-error=tautological-compare") endif() endif() @@ -472,9 +481,16 @@ if(VELOX_ENABLE_WAVE OR VELOX_ENABLE_CUDF) message(FATAL_ERROR "-DCMAKE_CUDA_ARCHITECTURES= must be set") endif() if(CMAKE_BUILD_TYPE MATCHES Debug) - add_compile_options("$<$:-G>") + set(VELOX_CUDA_DEBUG_FLAGS "$<$:-G>") endif() find_package(CUDAToolkit REQUIRED) + # Suppress deprecated GPU targets warning for CUDA 12.8.x + if( + CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0 + AND CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 13.0.0 + ) + string(APPEND CMAKE_CUDA_FLAGS " -Wno-deprecated-gpu-targets") + endif() if(VELOX_ENABLE_CUDF) foreach(arch ${CMAKE_CUDA_ARCHITECTURES}) if(arch LESS 70) @@ -487,6 +503,16 @@ if(VELOX_ENABLE_WAVE OR VELOX_ENABLE_CUDF) set(VELOX_ENABLE_ARROW ON) velox_set_source(cudf) velox_resolve_dependency(cudf) + if(TARGET aws-cpp-sdk-core) + # Fix for AWS SDK CPP using hardcoded system curl instead of soft link to curl + get_target_property(override_curl_lib aws-cpp-sdk-core INTERFACE_LINK_LIBRARIES) + list(REMOVE_ITEM override_curl_lib "/usr/lib64/libcurl.so") + list(APPEND override_curl_lib "\$") + set_target_properties( + aws-cpp-sdk-core + PROPERTIES INTERFACE_LINK_LIBRARIES "${override_curl_lib}" + ) + endif() endif() endif() @@ -546,15 +572,17 @@ if(NOT TARGET gflags::gflags) # target even when velox is built as a subproject which uses # `find_package(gflags)` which does not create a globally imported target that # we can ALIAS. - add_library(gflags_gflags INTERFACE) - target_link_libraries(gflags_gflags INTERFACE gflags) - add_library(gflags::gflags ALIAS gflags_gflags) + add_library(gflags::gflags ALIAS gflags) endif() if(${gflags_SOURCE} STREQUAL "BUNDLED") - # we force glog from source to avoid issues with a system version built - # against another gflags version (which is likely) - set(glog_SOURCE BUNDLED) + # Allow explicit glog_SOURCE override (e.g. when system glog is built + # against the same gflags version as BUNDLED). + if(DEFINED ENV{glog_SOURCE}) + set(glog_SOURCE $ENV{glog_SOURCE}) + else() + set(glog_SOURCE BUNDLED) + endif() else() set(glog_SOURCE SYSTEM) endif() @@ -574,13 +602,16 @@ if(${VELOX_BUILD_MINIMAL_WITH_DWIO} OR ${VELOX_ENABLE_HIVE_CONNECTOR}) find_package(ZLIB REQUIRED) find_package(zstd REQUIRED) find_package(Snappy REQUIRED) + # Ensure zstd::zstd target exists - handle different zstd package configurations if(NOT TARGET zstd::zstd) if(TARGET zstd::libzstd_static) - set(ZSTD_TYPE static) + add_library(zstd::zstd ALIAS zstd::libzstd_static) + elseif(TARGET zstd::libzstd_shared) + add_library(zstd::zstd ALIAS zstd::libzstd_shared) else() - set(ZSTD_TYPE shared) + # Fallback: use Findzstd.cmake to create the target + include(Findzstd) endif() - add_library(zstd::zstd ALIAS zstd::libzstd_${ZSTD_TYPE}) endif() endif() @@ -602,7 +633,10 @@ if(${VELOX_BUILD_MINIMAL_WITH_DWIO} OR ${VELOX_ENABLE_HIVE_CONNECTOR} OR VELOX_E endif() velox_set_source(simdjson) -velox_resolve_dependency(simdjson 3.9.3) +velox_resolve_dependency(simdjson 4.1.0) + +velox_set_source(FastFloat) +velox_resolve_dependency(FastFloat) velox_set_source(folly) velox_resolve_dependency(folly) @@ -636,9 +670,15 @@ if(VELOX_ENABLE_GCS) endif() # GCC needs to link a library to enable std::filesystem. +# GCC needs to enable co-routines. We only support compiler that +# support coroutines. if("${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU") # Find Threads library find_package(Threads REQUIRED) + add_compile_options($<$:-fcoroutines>) + # Explicitly enable folly coroutines when compiler supports them. + # This is required for folly::coro::AsyncGenerator to work correctly. + add_compile_definitions(FOLLY_HAS_COROUTINES=1) endif() if(VELOX_BUILD_TESTING AND NOT VELOX_ENABLE_DUCKDB) @@ -706,10 +746,6 @@ endif() include_directories(.) -# TODO: Include all other installation files. For now just making sure this -# generates an installable makefile. -install(FILES velox/type/Type.h DESTINATION "include/velox") - # Adding this down here prevents warnings in dependencies from stopping the # build if("${TREAT_WARNINGS_AS_ERRORS}") @@ -728,4 +764,10 @@ if(VELOX_ENABLE_GEO) velox_resolve_dependency(geos) endif() +if(VELOX_ENABLE_GEO) + list(PREPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/CMake/resolve_dependency_modules/s2geometry) + velox_set_source(s2geometry) + velox_resolve_dependency(s2geometry) +endif() + add_subdirectory(velox) diff --git a/CODING_STYLE.md b/CODING_STYLE.md index 43947192cd7..1942d7096ce 100644 --- a/CODING_STYLE.md +++ b/CODING_STYLE.md @@ -83,6 +83,46 @@ line width, indentation and ordering (for includes, using directives and etc).  code paths, and are not recommended for use in production environments. For example, `debug_disable_expression_with_peeling` is used to disable peeling optimization employed in expression evaluation. +* **Never name a file or class `*Utils`, `*Helpers`, or `*Common`.** These + generic names attract unrelated functions over time and lose cohesion. + Instead, name files and classes after the concept they represent. When + extracting shared functions, ask: "What do these functions have in common + beyond being useful?" The answer is the name. Use a class with static methods + to group related operations, and shorten method names since the class name + provides context. + + ```cpp + // ❌ Wrong — generic name, will attract unrelated functions. + class ParserUtils { + static std::vector widenProjectionsForSort(...); + static void sortAndTrimProjections(...); + }; + + // ✅ Correct — named after the concept (sorting + projections). + class SortProjection { + static std::vector widenProjections(...); + static void sortAndTrim(...); + }; + ``` + +* **Do not abbreviate** variable, function, or class names. Use full, + descriptive names. Clarity is more important than brevity. + + | ❌ Avoid | ✅ Prefer | + |----------|-----------| + | `outputCol` | `outputColumn` | + | `val` | `value` | + | `idx` | `index` | + | `agg` | `aggregation` | + | `sel` | `selectivity` | + | `rowCount` | `numRows` | + + * Well-established abbreviations in the domain are acceptable (e.g., `id`, + `url`, `sql`, `expr`). + * Loop indices like `i`, `j`, `k` are acceptable for simple loops. + * Iterator variables named `it` are acceptable. + * `numXxx` naming pattern is acceptable and preferred over `xxxCount` or + `xxxCnt` (e.g., `numRows`, `numKeys`). ## Comments @@ -154,8 +194,40 @@ About comment style: * Include enough context in the comment itself to make clear what will be done, without requiring any references from outside the code. - * Do not include the author’s username. If required, this can always be + * Do not include the author's username. If required, this can always be retrieved from git blame. +* **Start comments with active verbs**, not "This class…" or "This method…". + + | ❌ Avoid | ✅ Prefer | + |----------|-----------| + | `/// This class builds query plans.` | `/// Builds query plans.` | + | `/// This method computes selectivity.` | `/// Computes selectivity.` | + | `/// This function returns the cost.` | `/// Returns the cost.` | + +* **Avoid redundant comments** that simply repeat what the code already says. + Comments should explain *why*, not *what*. + +```cpp +// ❌ Avoid - comment just repeats the code +// Increment the counter. +++counter; + +// ❌ Avoid - comment states the obvious +// Return the result. +return result; + +// ✅ Prefer - no comment needed when code is self-explanatory +++counter; +return result; + +// ✅ Prefer - comment explains WHY, not WHAT +// Use a larger buffer to avoid repeated reallocations for typical queries. +buffer.reserve(1024); +``` + +* **Do not duplicate comments between `.h` and `.cpp`.** Document the function + in the header; the implementation should not repeat the same comment. + Duplicated comments diverge over time. ## Asserts and CHECKs @@ -187,6 +259,12 @@ About comment style: i)` * Note that the values of v1 and v2 are already included in the exception message by default. +* Put runtime information (names, values, types) at the **end** of error + messages, after the static description. + + | ❌ Avoid | ✅ Prefer | + |----------|-----------| + | `VELOX_USER_FAIL("Column '{}' is ambiguous", name);` | `VELOX_USER_FAIL("Column is ambiguous: {}", name);` | ## Variables @@ -208,6 +286,9 @@ About comment style: * `int* foo;` `const Bar& bar;` NOT `int *foo;` `const Bar &bar`; * Beware that `int* foo, bar;` will be parsed as declaring `foo` as an `int*` and `bar` as an `int`. Note that multiple declaration is discouraged. +* Use trailing commas in multi-line initializer lists, enum definitions, and + function-call argument lists that span multiple lines. This produces cleaner + diffs when items are added or reordered. * For member variables: * Group member variable and methods based on their visibility (public, protected and private) @@ -234,8 +315,14 @@ About comment style: * Always use `nullptr` if you need a constant that represents a null pointer (`T*` for some `T`); use `0` otherwise for a zero value. -* For large literal numbers, use ‘ to make it more readable, e.g:  `1’000’000` - instead of `1000000`. +* For large literal numbers, use `'` to improve readability. Use digit + separators for numbers with 4 or more digits. + + | ❌ Avoid | ✅ Prefer | + |----------|-----------| + | `10000` | `10'000` | + | `1000000` | `1'000'000` | + | `100` | `100` (no separator needed) | * For floating point literals, never omit the initial 0 before the decimal point (always `0.5`, not `.5`). * File level variables and constants should be defined in an anonymous @@ -246,16 +333,17 @@ About comment style: * As a general rule, do not use string literals without declaring a named constant for them. * The best way to make a constant string literal is to use constexpr - `std::string_view`/`folly::StringPiece` + `std::string_view` * **NEVER** use `std::string` - this makes your code more prone to SIOF bugs. * Avoid `const char* const` and `const char*` - these are less efficient to convert to `std::string` later on in your program if you ever need to - because `std::string_view`/ `folly::StringPiece` knows its size and can use - a more efficient constructor. `std::string_view`/ `folly::StringPiece` also - has richer interfaces and often works as a drop-in replacement to - `std::string`. + because `std::string_view` knows its size and can use a more efficient + constructor. `std::string_view` also has richer interfaces and often + works as a drop-in replacement to `std::string`. * Need compile-time string concatenation? You can use `folly::FixedString` for that. + * Do not use `folly::StringPiece` in new code, use `std::string_view` + instead. ## Macros @@ -300,6 +388,20 @@ macro names are always upper-snake-case. Also: should opening the `-inl.h` be necessary (or the .cpp file, for that matter). +### CMake Header Ownership + +Every `.h` file in the Velox codebase must be associated with a CMake target. +This ensures headers are discoverable through the CMake File API, enabling +build impact analysis and selective builds. + +* When adding headers to a `velox_add_library()` target, list them under the + `HEADERS` keyword. +* When adding headers to test, benchmark, or fuzzer targets that use + `add_library()` or `add_executable()`, use + `velox_add_test_headers(
)`. +* The `check-header-ownership` pre-commit hook will flag any `.h` file that + is not tracked by a CMake target. + ## Function Arguments * Const @@ -415,3 +517,160 @@ using TypePtr = std::shared_ptr; using ContinueFuture = folly::SemiFuture; using ContinuePromise = VeloxPromise; ``` + +## API Design + +* **Keep the public API surface small.** Do not expose methods or types that + callers don't need. Fewer public symbols mean less coupling and easier + evolution. +* **Prefer free functions in .cpp over class methods.** If a helper doesn't + need access to class state or doesn't need to be called from outside the + translation unit, make it a free function (typically in an anonymous + namespace) rather than a static or private method. +* **Define free functions close to where they are used**, not grouped together + at the top or bottom of the file. +* **Keep method implementations in .cpp.** Except for trivial one-liners, + define methods in the .cpp file to keep headers small and reduce build times. +* **Avoid default arguments** when all callers can pass values explicitly. +* **Never use `friend`, `FRIEND_TEST`, or any friend declarations.** If a test + needs access to private members, redesign the API or test through public + methods instead. + +## Tests + +* **Place new tests next to related existing tests**, not at the end of the + file. Group tests by topic (e.g., place `tryCast` next to `types`, + `notBetween` next to `ifClause` which uses `between`). +* **Use gtest container matchers** (`testing::ElementsAre`, etc.) for + verifying collections: + + ```cpp + // ❌ Avoid - multiple individual assertions + EXPECT_EQ(result.size(), 3); + EXPECT_EQ(result[0], "a"); + EXPECT_EQ(result[1], "b"); + EXPECT_EQ(result[2], "c"); + + // ✅ Prefer - single matcher assertion + EXPECT_THAT(result, testing::ElementsAre("a", "b", "c")); + ``` + + Common matchers: + * `ElementsAre(...)` - exact ordered match + * `UnorderedElementsAre(...)` - exact unordered match + * `Contains(...)` - at least one element matches + * `IsEmpty()` - collection is empty + * `SizeIs(n)` - collection has n elements + + Requires `#include `. + +## Common Mistakes + +These are frequently violated rules. Check every new or modified line against +this list before finishing. + +### Bug fixes without a failing test first + +When fixing a bug, write the test **first**, run it, and confirm it **fails** +before applying the fix. Then apply the fix and confirm the test passes. If you +wrote the test after the fix, temporarily revert the fix and verify the test +fails. A test that passes both with and without the fix proves nothing. This is +the single most important workflow rule for bug fixes. + +### `///` vs `//` — wrong comment style + +`///` is **only** for public API: public classes, public methods, public member +variables. Everything else uses `//`: private/protected members, anonymous +namespace functions and types, comments inside function bodies. + +```cpp +// ❌ Wrong — anonymous-namespace function is not public API. +namespace { +/// Returns true if 'a' is a prefix of 'b'. +bool isPrefix(const ExprVector& a, const ExprVector& b); +} // namespace + +// ✅ Correct. +namespace { +// Returns true if 'a' is a prefix of 'b'. +bool isPrefix(const ExprVector& a, const ExprVector& b); +} // namespace +``` + +### One-letter and abbreviated variable names + +Do not abbreviate. Use full, descriptive names. Loop indices (`i`, `j`) are +acceptable. Everything else — function parameters, lambda parameters, local +variables — must be descriptive. + +```cpp +// ❌ Wrong — one-letter names and abbreviations. +bool sameKeys(const ExprVector& a, const ExprVector& b); +std::sort(groups.begin(), groups.end(), [](const auto& a, const auto& b) { ... }); + +// ✅ Correct — descriptive names. +bool sameKeys(const ExprVector& lhs, const ExprVector& rhs); +std::sort(groups.begin(), groups.end(), [](const auto& lhs, const auto& rhs) { ... }); +``` + +### Undocumented APIs in headers + +Every class, every non-trivial method declaration, and every member variable in +a `.h` file must have a comment. Trivial one-liner getters may be left +undocumented if the name is self-explanatory. + +```cpp +// ❌ Wrong — no comment on method declaration. + ExprPtr translate(const PlanNode* node); + +// ✅ Correct. + // Translates a logical plan node to an executable expression. + ExprPtr translate(const PlanNode* node); +``` + +### Non-trivial implementations in headers + +Keep method implementations in `.cpp` except for trivial one-liners. If a +method body has more than one statement, it belongs in the `.cpp` file. + +### `goto` statements + +Never use `goto`. Restructure the code using early returns, helper functions, or +duplicated code paths instead. `goto` makes control flow hard to follow and is +never necessary in C++. + +### Fitting tests to buggy code + +When a test fails, **never** update the test expectation to match what the code +produces without first verifying that the code is correct. The default +assumption should be that the code is wrong, not the test. + +### Generic file and class names (`*Utils`, `*Helpers`, `*Common`) + +Never name a file or class `*Utils`, `*Helpers`, or `*Common`. These generic +names attract unrelated functions over time and lose cohesion. Instead, name +files and classes after the concept they represent. When extracting shared +functions, ask: "What do these functions have in common beyond being useful?" +The answer is the name. + +### Verify causation before asserting it + +When investigating a test failure or regression, do not attribute it to a +specific commit based on the commit message alone. Verify empirically by +checking out the parent commit and running the test. Incorrect attribution +leads to wrong fixes — e.g., updating test expectations when the real problem +is a bug introduced by a different commit. + +### Silently simplifying an approved plan + +Never silently simplify or skip parts of an approved implementation plan. If a +step turns out to be harder than expected, or you want to defer it, say so +explicitly and get approval before proceeding with a reduced scope. Reporting +"done" when a key piece was dropped is worse than asking for help. + +### Working around infrastructure bugs + +When you discover a bug in test infrastructure, shared helpers, or common +utilities, do **not** silently work around it. Stop, report the finding, and +discuss whether to fix the root cause or work around it. Workarounds accumulate +into technical debt and mask real problems. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a76dd9e334c..751e9f8c960 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -137,6 +137,7 @@ where: * *Type* can be any of the following keywords: * **feat** when new features are being added. * **fix** for bug fixes. + * **perf** for performance improvements. * **build** for build or CI-related improvements. * **test** for adding tests (only). * **docs** for enhancements to documentation (only). @@ -224,6 +225,47 @@ following best practices: your PR. If a component or API does not have a corresponding unit test suite, please consider improving the codebase by first adding a new unit test suite to ensure the existing behavior is correct. + * **Common test workflows**: + ```bash + # Run all tests in parallel. + cd _build/debug && ctest -j 8 + + # Run all test binaries whose ctest name matches a regex. + # On Linux this matches velox_exec_test_group0 … _group7. + # On macOS this matches velox_exec_test_ValuesTest, + # velox_exec_test_HashJoinTest, etc. + cd _build/debug && ctest -R velox_exec + + # Run a single test binary by name (works on macOS where each + # test file produces its own binary). + cd _build/debug && ctest -R ValuesTest + ``` + * **Re-running a CI failure locally**: CI reports a failure in + `velox_exec_test_group3` with `ValuesTest.empty`. On Linux, run the grouped + binary directly. On macOS, the grouped binary does not exist — use the + per-file binary instead: + ```bash + # Linux (grouped binary) + _build/debug/velox/exec/tests/velox_exec_test_group3 --gtest_filter="ValuesTest.empty" + # macOS (per-file binary) + _build/debug/velox/exec/tests/velox_exec_test_ValuesTest --gtest_filter="ValuesTest.empty" + ``` + * **Test binary structure**: Four test suites (`velox/exec/tests`, + `velox/functions/prestosql/aggregates/tests`, `velox/common/caching/tests`, + `velox/serializers/tests`) use grouped binaries on Linux CI (e.g., + `velox_exec_test_group0` through `_group7`) to reduce link times. All other + suites use individual binaries on all platforms. On macOS, grouping is off + by default and each test file gets its own binary (e.g., + `velox_exec_test_ValuesTest`). To disable grouping on Linux, pass + `-DVELOX_ENABLE_GROUPED_TESTS=OFF` to CMake. + * **Adding a test to a grouped suite**: Add the source file to the `SOURCES` + list in the relevant `velox_add_grouped_tests()` call in `CMakeLists.txt`. + It is automatically assigned to a group on Linux and gets its own binary on + macOS. For new test suites, use `velox_add_grouped_tests` when the suite + has many test files (10+) that link against large libraries like velox + core — each individual binary pays the full link cost, so grouping + significantly reduces total CI build time. For suites with only a few + test files or lightweight dependencies, use `add_executable` / `add_test`. 4. **Code Comments**: Appropriately add comments to your code and document APIs. * As a library, Velox code is optimized for the reader, not the writer. diff --git a/README.md b/README.md index 4880335989e..2043c7ca65b 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,9 @@ Velox logo +[![Linux Build using GCC](https://github.com/facebookincubator/velox/actions/workflows/linux-build.yml/badge.svg)](https://github.com/facebookincubator/velox/actions/workflows/linux-build.yml) +[![macOS Build](https://github.com/facebookincubator/velox/actions/workflows/macos.yml/badge.svg)](https://github.com/facebookincubator/velox/actions/workflows/macos.yml) + Velox is a composable execution engine distributed as an open source C++ library. It provides reusable, extensible, and high-performance data processing components that can be (re-)used to build data management systems focused on @@ -88,7 +91,11 @@ found here](velox/examples) Developer guides detailing many aspects of the library, in addition to the list of available functions [can be found here.](https://facebookincubator.github.io/velox) -Blog posts are available [here](https://velox-lib.io/blog). +Recent blog posts ([all posts](https://velox-lib.io/blog)): + +- [FlatMapVector Adoption for Scaling High-Performance AI/ML Data Pre-Processing](https://velox-lib.io/blog/flatmapvector) (2026-05-01) +- [Nimble Cluster Index: Efficient Indexed Lookups on Columnar Data](https://velox-lib.io/blog/nimble-cluster-index) (2026-04-27) +- [Axiom: Composable Query Engines Built on Velox](https://velox-lib.io/blog/axiom-composable-query-engines) (2026-04-23) ## Community @@ -171,7 +178,7 @@ Using the default install location `/usr/local` on macOS is discouraged since th location is used by certain Homebrew versions. Manually add the `INSTALL_PREFIX` value in the IDE or bash environment, -say `export INSTALL_PREFIX=/Users/$USERNAME/velox/deps-install` to `~/.zshrc` so that +say `export INSTALL_PREFIX=/Users/$USER/velox/deps-install` to `~/.zshrc` so that subsequent Velox builds can use the installed packages. *You can reuse `DEPENDENCY_INSTALL` and `INSTALL_PREFIX` for Velox clients such as Prestissimo @@ -252,6 +259,13 @@ Run `make` in the root directory to compile the sources. For development, use `make debug` to build a non-optimized debug version, or `make release` to build an optimized version. Use `make unittest` to build and run tests. +Four test suites use grouped binaries on Linux CI to reduce link times +(`velox/exec/tests`, `velox/functions/prestosql/aggregates/tests`, +`velox/common/caching/tests`, `velox/serializers/tests`). All other suites use +individual binaries on all platforms. On macOS, grouping is off by default. To +disable grouping on Linux, pass `-DVELOX_ENABLE_GROUPED_TESTS=OFF` via +`EXTRA_CMAKE_FLAGS`. + Note that, * Velox requires a compiler at the minimum GCC 11.0 or Clang 15.0. * Velox requires the CPU to support instruction sets: diff --git a/docker-bake.hcl b/docker-bake.hcl index 3a63980522a..26a62dcb9ee 100644 --- a/docker-bake.hcl +++ b/docker-bake.hcl @@ -68,6 +68,13 @@ target "pyvelox" { args = { image = "quay.io/pypa/manylinux_2_28:latest" VELOX_BUILD_SHARED = "OFF" + # pyvelox uses a manylinux base which is el8, not el9. The + # dockerfile's CENTOS_TZDATA_VERSION default is el9 (correct for + # the centos9/adapters targets); override here to the el8 build of + # the same upstream tzdata release. Keep the upstream version + # (2026a) in sync with CENTOS_TZDATA_VERSION at the top of + # scripts/docker/centos-multi.dockerfile. + CENTOS_TZDATA_VERSION = "2026a-1.el8" } matrix = { arch = ["amd64", "arm64"] diff --git a/docker-compose.yml b/docker-compose.yml index 5457efe33b7..c7c47c575f3 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -77,13 +77,17 @@ services: - driver: nvidia count: 1 capabilities: [gpu] + environment: + CC: /opt/rh/gcc-toolset-14/root/bin/gcc # Default to gcc-14 + CXX: /opt/rh/gcc-toolset-14/root/bin/g++ # Default to gcc-14 + CUDAHOSTCXX: /opt/rh/gcc-toolset-14/root/bin/g++ # Default to gcc-14 presto-java: extends: base-block image: ghcr.io/facebookincubator/velox-dev:presto-java build: args: - - PRESTO_VERSION=0.293 + - PRESTO_VERSION=0.295 dockerfile: scripts/docker/java.dockerfile spark-server: @@ -93,6 +97,7 @@ services: args: - SPARK_VERSION=3.5.1 dockerfile: scripts/docker/java.dockerfile + target: spark-server fedora: extends: base-block diff --git a/python/test/test_debugger_runner.py b/python/test/test_debugger_runner.py new file mode 100644 index 00000000000..5f33c10b31f --- /dev/null +++ b/python/test/test_debugger_runner.py @@ -0,0 +1,313 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import pyarrow +from pyvelox.arrow import to_velox +from pyvelox.plan_builder import PlanBuilder +from pyvelox.runner import ( + LocalDebuggerRunner, +) + + +class TestPyVeloxDebuggerRunner(unittest.TestCase): + def setUp(self) -> None: + self._batch_size = 20 + self._num_batches = 10 + self._num_projections = 10 + + array = pyarrow.array(list(range(self._batch_size))) + # type: ignore + batch = pyarrow.record_batch([array], names=["c0"]) + + vector = to_velox(batch) + vectors = [vector] * self._num_batches + + plan_builder = PlanBuilder().values(vectors) + self._node_ids = [] + + for _ in range(self._num_projections): + plan_builder.project(["c0 * 10 as c0"]) + self._node_ids.append(plan_builder.get_plan_node().id()) + + self._plan_node = plan_builder.get_plan_node() + + def _count_next(self, runner): + total_size = 0 + it = runner.execute() + + while True: + try: + vector = it.next() + total_size += vector.size() + except StopIteration: + break + return total_size + + def _count_step(self, runner): + total_size = 0 + it = runner.execute() + + while True: + try: + vector = it.step() + total_size += vector.size() + except StopIteration: + break + return total_size + + def test_undrained(self): + """Ensure tasks won't hang or fail if they are not completely drained.""" + runner = LocalDebuggerRunner(self._plan_node) + runner.execute() + + def test_next_no_breakpoints(self): + runner = LocalDebuggerRunner(self._plan_node) + self.assertEqual(self._count_next(runner), self._batch_size * self._num_batches) + + def test_step_no_breakpoints(self): + runner = LocalDebuggerRunner(self._plan_node) + self.assertEqual(self._count_step(runner), self._batch_size * self._num_batches) + + def test_next_with_breakpoints(self): + runner = LocalDebuggerRunner(self._plan_node) + runner.set_breakpoint(self._node_ids[3]) + runner.set_breakpoint(self._node_ids[5]) + self.assertEqual(self._count_next(runner), self._batch_size * self._num_batches) + + def test_step_with_breakpoints(self): + runner = LocalDebuggerRunner(self._plan_node) + runner.set_breakpoint(self._node_ids[3]) + runner.set_breakpoint(self._node_ids[5]) + runner.set_breakpoint(self._node_ids[8]) + self.assertEqual( + self._count_step(runner), self._batch_size * self._num_batches * 4 + ) + + def test_step_all_breakpoints(self): + runner = LocalDebuggerRunner(self._plan_node) + [runner.set_breakpoint(i) for i in self._node_ids] + self.assertEqual( + self._count_step(runner), + self._batch_size * self._num_batches * (self._num_projections + 1), + ) + + def test_breakpoint_with_aggregate(self): + """Set a breakpoint before aggregation to see pre-aggregated data.""" + batch_size = 100 + + # Create data with some duplicates for aggregation. + values = [i % 10 for i in range(batch_size)] + array = pyarrow.array(values) + batch = pyarrow.record_batch([array], names=["c0"]) + vector = to_velox(batch) + + # Produce the input vector 3 times. + plan_builder = PlanBuilder().values([vector, vector, vector]) + + plan_builder.aggregate(grouping_keys=["c0"], aggregations=["count(1) as cnt"]) + agg_node_id = plan_builder.get_plan_node().id() + + runner = LocalDebuggerRunner(plan_builder.get_plan_node()) + runner.set_breakpoint(agg_node_id) + + it = runner.execute() + + # Get aggregate first input, 100 records. + it.step() + self.assertEqual(it.current().size(), 100) + + # Get aggregate second input, 100 records. + it.step() + self.assertEqual(it.current().size(), 100) + + # Ignore next aggregate input and move to the task output, 10 records. + it.next() + self.assertEqual(it.current().size(), 10) + + # Should be done now. + with self.assertRaises(StopIteration): + it.next() + + def test_iterator_at(self): + runner = LocalDebuggerRunner(self._plan_node) + runner.set_breakpoint(self._node_ids[3]) + runner.set_breakpoint(self._node_ids[8]) + + it = runner.execute() + self.assertEqual(it.at(), "") + + it.step() + self.assertEqual(it.at(), self._node_ids[3]) + + it.step() + self.assertEqual(it.at(), self._node_ids[8]) + + it.step() + self.assertEqual(it.at(), "") + + it.step() + self.assertEqual(it.at(), self._node_ids[3]) + + it.next() + self.assertEqual(it.at(), "") + + it.next() + self.assertEqual(it.at(), "") + + it.step() + self.assertEqual(it.at(), self._node_ids[3]) + + def test_hook_always_stop(self): + """Hook that always returns True should behave like set_breakpoint.""" + runner = LocalDebuggerRunner(self._plan_node) + runner.set_hook(self._node_ids[3], lambda v: True) + runner.set_hook(self._node_ids[5], lambda v: True) + runner.set_hook(self._node_ids[8], lambda v: True) + self.assertEqual( + self._count_step(runner), self._batch_size * self._num_batches * 4 + ) + + def test_hook_never_stop(self): + """Hook that always returns False should skip the breakpoint.""" + runner = LocalDebuggerRunner(self._plan_node) + runner.set_hook(self._node_ids[3], lambda v: False) + runner.set_hook(self._node_ids[5], lambda v: False) + runner.set_hook(self._node_ids[8], lambda v: False) + # Since all hooks return False, step() should behave like next() + self.assertEqual(self._count_step(runner), self._batch_size * self._num_batches) + + def test_hook_conditional(self): + """Hook that conditionally stops based on vector content.""" + runner = LocalDebuggerRunner(self._plan_node) + + # Track how many times the hook is called. + call_count = 0 + + def conditional_hook(vector): + nonlocal call_count + call_count += 1 + # Stop only on odd calls. + return call_count % 2 == 1 + + runner.set_hook(self._node_ids[5], conditional_hook) + + # The hook is called for each batch (10 batches). + # It stops on odd calls (1, 3, 5, 7, 9) = 5 stops. + # Plus 10 task outputs = 15 total vectors. + self.assertEqual(self._count_step(runner), self._batch_size * 15) + self.assertEqual(call_count, self._num_batches) + + def test_hook_mixed_with_breakpoint(self): + """Mix set_breakpoint and set_hook.""" + runner = LocalDebuggerRunner(self._plan_node) + + # set_breakpoint always stops. + runner.set_breakpoint(self._node_ids[3]) + + # set_hook that never stops. + runner.set_hook(self._node_ids[5], lambda v: False) + + # set_hook that always stops. + runner.set_hook(self._node_ids[8], lambda v: True) + + # node_ids[3] stops (10 batches), node_ids[8] stops (10 batches), + # plus 10 task outputs = 30 total vectors. + self.assertEqual( + self._count_step(runner), self._batch_size * self._num_batches * 3 + ) + + def test_hook_inspects_vector(self): + """Verify the hook receives a valid vector.""" + runner = LocalDebuggerRunner(self._plan_node) + + received_sizes = [] + + def inspect_hook(vector): + received_sizes.append(vector.size()) + return True + + runner.set_hook(self._node_ids[5], inspect_hook) + + it = runner.execute() + + # Consume all output. + while True: + try: + it.step() + except StopIteration: + break + + # Hook should have been called for each batch. + self.assertEqual(len(received_sizes), self._num_batches) + # Each vector should have batch_size rows. + self.assertTrue(all(s == self._batch_size for s in received_sizes)) + + def test_step_with_plan_id(self): + """step(plan_id) should only stop at the matching breakpoint.""" + runner = LocalDebuggerRunner(self._plan_node) + runner.set_breakpoint(self._node_ids[3]) + runner.set_breakpoint(self._node_ids[5]) + runner.set_breakpoint(self._node_ids[8]) + + it = runner.execute() + + # Step targeting node_ids[5] should skip node_ids[3] and stop at + # node_ids[5]. + it.step(self._node_ids[5]) + self.assertEqual(it.at(), self._node_ids[5]) + + # Step targeting node_ids[8] should skip node_ids[5] (remaining) and + # stop at node_ids[8]. + it.step(self._node_ids[8]) + self.assertEqual(it.at(), self._node_ids[8]) + + # Step with no filter (default) stops at the next breakpoint or task + # output. + it.step() + self.assertEqual(it.at(), "") + + def test_step_with_plan_id_counts(self): + """step(plan_id) should only produce vectors from the matching + breakpoint plus task outputs.""" + runner = LocalDebuggerRunner(self._plan_node) + runner.set_breakpoint(self._node_ids[3]) + runner.set_breakpoint(self._node_ids[5]) + + it = runner.execute() + total_size = 0 + + # Only step to node_ids[5], skipping node_ids[3]. + while True: + try: + vector = it.step(self._node_ids[5]) + total_size += vector.size() + except StopIteration: + break + + # We expect node_ids[5] breakpoint hits (num_batches) plus task outputs + # (num_batches) = 2 * num_batches vectors, each of batch_size rows. + self.assertEqual(total_size, self._batch_size * self._num_batches * 2) + + def test_step_with_plan_id_default_behavior(self): + """step() with no argument should preserve original behavior.""" + runner = LocalDebuggerRunner(self._plan_node) + runner.set_breakpoint(self._node_ids[3]) + runner.set_breakpoint(self._node_ids[5]) + + # step() with no plan_id should hit both breakpoints + task output. + self.assertEqual( + self._count_step(runner), self._batch_size * self._num_batches * 3 + ) diff --git a/python/test/test_plan_builder.py b/python/test/test_plan_builder.py index c0c1c8d1ec4..f2d1c23d69d 100644 --- a/python/test/test_plan_builder.py +++ b/python/test/test_plan_builder.py @@ -100,3 +100,19 @@ def test_plan_serialization(self): ) self.assertEqual(str(plan_node), str(plan_clone)) self.assertEqual(plan_node.to_string(), plan_clone.to_string()) + + def test_mark_sorted(self): + plan_builder = PlanBuilder() + plan_builder.table_scan(ROW(["c0", "c1"], [BIGINT(), BIGINT()])) + + plan_builder.mark_sorted( + marker_key="is_sorted", + sorting_keys=["c0", "c1 DESC"], + ) + mark_sorted_node = plan_builder.get_plan_node() + self.assertEqual(mark_sorted_node.name(), "MarkSorted") + + self.assertEqual( + str(mark_sorted_node), + "-- MarkSorted[1]\n -- TableScan[0]\n", + ) diff --git a/python/test/test_runner.py b/python/test/test_runner.py index 59a5baf4a43..1f668c68fd4 100644 --- a/python/test/test_runner.py +++ b/python/test/test_runner.py @@ -54,12 +54,24 @@ def test_not_executed(self): plan_builder = PlanBuilder().values() LocalRunner(plan_builder.get_plan_node()) - def test_executed_twice(self): - # Ensure the runner fails if it is executed twice. - plan_builder = PlanBuilder().values() + def test_execute_twice(self): + # Ensure a runner can be executed twice. + vector = to_velox( + pyarrow.record_batch([pyarrow.array(list(range(10)))], names=["c0"]) + ) + + plan_builder = PlanBuilder().values([vector]) runner = LocalRunner(plan_builder.get_plan_node()) - runner.execute() - self.assertRaises(RuntimeError, runner.execute) + + total_size = 0 + for vector in runner.execute(): + total_size += vector.size() + self.assertEqual(total_size, 10) + + total_size = 0 + for vector in runner.execute(): + total_size += vector.size() + self.assertEqual(total_size, 10) def test_values(self): vectors = [] @@ -103,6 +115,55 @@ def test_values_order_limit(self): ) self.assertEqual(output, expected_result) + def test_mark_sorted(self): + # Sorted data: marker column should be all True. + vector = to_velox( + pyarrow.record_batch([pyarrow.array([1, 2, 3, 4, 5])], names=["c0"]) + ) + + plan_builder = ( + PlanBuilder() + .values([vector]) + .mark_sorted( + marker_key="is_sorted", + sorting_keys=["c0"], + ) + ) + runner = LocalRunner(plan_builder.get_plan_node()) + iterator = runner.execute() + output = next(iterator) + self.assertRaises(StopIteration, next, iterator) + + self.assertEqual(output.size(), 5) + # Marker column is at index 1 (after c0). + marker = output.child_at(1) + for i in range(5): + self.assertEqual(marker[i], "true") + + def test_mark_sorted_unsorted(self): + # Unsorted data: row at index 2 breaks sort order (3 -> 2). + vector = to_velox( + pyarrow.record_batch([pyarrow.array([1, 3, 2, 4, 5])], names=["c0"]) + ) + + plan_builder = ( + PlanBuilder() + .values([vector]) + .mark_sorted( + marker_key="is_sorted", + sorting_keys=["c0"], + ) + ) + runner = LocalRunner(plan_builder.get_plan_node()) + iterator = runner.execute() + output = next(iterator) + self.assertRaises(StopIteration, next, iterator) + + marker = output.child_at(1) + expected = ["true", "true", "false", "true", "true"] + for i in range(5): + self.assertEqual(marker[i], expected[i]) + def test_hash_join(self): batch_size = 100 probe = list(range(batch_size)) diff --git a/python/test/test_type.py b/python/test/test_type.py index ff079cd1cb1..b059ee58723 100644 --- a/python/test/test_type.py +++ b/python/test/test_type.py @@ -26,6 +26,7 @@ DOUBLE, VARCHAR, VARBINARY, + JSON, ARRAY, MAP, ROW, @@ -43,6 +44,7 @@ def test_simple_types(self): self.assertTrue(isinstance(DOUBLE(), Type)) self.assertTrue(isinstance(VARCHAR(), Type)) self.assertTrue(isinstance(VARBINARY(), Type)) + self.assertTrue(isinstance(JSON(), Type)) def test_complex_types(self): self.assertTrue(isinstance(ARRAY(VARCHAR()), Type)) @@ -86,3 +88,5 @@ def test_equality(self): self.assertNotEqual(BIGINT(), INTEGER()) self.assertNotEqual(ARRAY(BIGINT()), REAL()) + self.assertNotEqual(VARBINARY(), VARCHAR()) + self.assertNotEqual(JSON(), VARCHAR()) diff --git a/scripts/checks/check-header-ownership.py b/scripts/checks/check-header-ownership.py new file mode 100755 index 00000000000..9d6da3f650a --- /dev/null +++ b/scripts/checks/check-header-ownership.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +CI check: verify every .h file under velox/ is listed in a CMakeLists.txt target. + +Exits with code 1 if orphan headers are found. + +Usage: + python3 .github/scripts/check-header-ownership.py [velox_source_dir] + +The velox_source_dir defaults to "velox" (relative to repo root). +""" + +import os +import re +import sys + +SKIP_DIRS = { + "external", + "experimental", + "facebook", + "public_tld", + "python", +} + +SKIP_PATHS = { + "tpcds/gen/dsdgen/include", + "tpch/gen/dbgen/include", +} + + +def collect_cmake_headers(velox_dir): + """Walk CMakeLists.txt files and collect all .h filenames mentioned.""" + tracked = set() + + for root, _dirs, files in os.walk(velox_dir): + if "CMakeLists.txt" not in files: + continue + + rel_root = os.path.relpath(root, velox_dir) + + cmake_path = os.path.join(root, "CMakeLists.txt") + with open(cmake_path) as f: + content = f.read() + + # Find all .h references in any target call or HEADERS block + for h in re.findall(r"([\w][\w/\-]*\.h)", content): + # Resolve relative to CMakeLists.txt directory + tracked.add(os.path.normpath(os.path.join(rel_root, h))) + + return tracked + + +def collect_fs_headers(velox_dir): + """Walk filesystem for all .h files, respecting skip rules.""" + headers = set() + + for root, dirs, files in os.walk(velox_dir): + rel_root = os.path.relpath(root, velox_dir) + + # Skip entire directories + parts = rel_root.split(os.sep) + if any(p in SKIP_DIRS for p in parts): + dirs.clear() + continue + + if any(rel_root.startswith(sp) for sp in SKIP_PATHS): + dirs.clear() + continue + + for f in files: + if f.endswith(".h"): + headers.add(os.path.normpath(os.path.join(rel_root, f))) + + return headers + + +def main(): + velox_dir = sys.argv[1] if len(sys.argv) > 1 else "velox" + + if not os.path.isdir(velox_dir): + print(f"Error: {velox_dir} is not a directory", file=sys.stderr) + sys.exit(2) + + tracked = collect_cmake_headers(velox_dir) + fs_headers = collect_fs_headers(velox_dir) + + orphans = sorted(fs_headers - tracked) + + if orphans: + print( + f"Found {len(orphans)} header(s) not tracked by any CMakeLists.txt target:" + ) + for h in orphans: + print(f" {h}") + print() + print("To fix, add each header to the appropriate CMakeLists.txt:") + print(" - For velox_add_library() targets: add to the HEADERS list") + print( + " - For test/benchmark/fuzzer targets: use velox_add_test_headers(
)" + ) + sys.exit(1) + else: + print(f"All {len(fs_headers)} headers are tracked by CMakeLists.txt targets.") + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/scripts/checks/check-readme-blogs.sh b/scripts/checks/check-readme-blogs.sh new file mode 100755 index 00000000000..46cdb8c209d --- /dev/null +++ b/scripts/checks/check-readme-blogs.sh @@ -0,0 +1,68 @@ +#!/usr/bin/env bash +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Ensures the "Recent blog posts" section in README.md lists the 3 most recent +# blog posts from website/blog/. Auto-fixes if out of date. + +set -euo pipefail + +BLOG_DIR="website/blog" +NUM_POSTS=3 + +# Build expected blog list lines from the most recent .mdx files. +EXPECTED_LINES=() +for f in $(ls -1 "$BLOG_DIR"/*.mdx | sort -r | head -$NUM_POSTS); do + DATE=$(basename "$f" | grep -oE '^[0-9]{4}-[0-9]{2}-[0-9]{2}') + TITLE=$(grep '^title:' "$f" | sed 's/title: *//; s/"//g') + SLUG=$(grep '^slug:' "$f" | sed 's/slug: *//') + EXPECTED_LINES+=("- [$TITLE](https://velox-lib.io/blog/$SLUG) ($DATE)") +done + +# Extract current blog list: everything between "Recent blog posts" header and +# the next markdown section (##). +CURRENT=$(sed -n '/^Recent blog posts/,/^##/{/^Recent blog posts/d;/^##/d;/^$/d;p;}' README.md) +EXPECTED=$(printf '%s\n' "${EXPECTED_LINES[@]}") + +if [ "$CURRENT" = "$EXPECTED" ]; then + exit 0 +fi + +# Auto-fix: replace everything between "Recent blog posts" header and the next +# section with the expected blog list. +TEMP=$(mktemp) +IN_SECTION=false +while IFS= read -r line; do + if [[ $line == "Recent blog posts"* ]]; then + echo "$line" >>"$TEMP" + echo "" >>"$TEMP" + printf '%s\n' "${EXPECTED_LINES[@]}" >>"$TEMP" + IN_SECTION=true + continue + fi + if $IN_SECTION; then + # Skip until the next section header. + if [[ $line == "##"* ]]; then + IN_SECTION=false + echo "" >>"$TEMP" + echo "$line" >>"$TEMP" + fi + continue + fi + echo "$line" >>"$TEMP" +done = 3: lspan = fields[2].split(",") if len(lspan) <= 1: - lspan.append(0) - - changed_lines[file] = [int(lspan[0]), int(lspan[0]) + int(lspan[1])] + lspan.append("0") - return json.dumps( - [{"name": key, "lines": value} for key, value in changed_lines.items()] - ) + start_line = int(lspan[0]) + line_count = int(lspan[1]) + # Skip invalid line ranges (e.g., +0,0 from deleted files) + if start_line > 0 or line_count > 0: + changed_lines[file] = [start_line, start_line + line_count] -def checks(args): - status, stdout, stderr = util.run( - f"clang-tidy -checks='{CODE_CHECKS}' --list-checks" - ) - print(stdout) + return changed_lines def check_output(output): - return regex.match(r"^/.* warning: ", output) + return re.match(r"(^/.* warning: |^$)", output) def tidy(args): files = util.input_files(args.files) + files = [file for file in files if re.match(r".*(\.cpp|\.h|\.hpp)$", file)] - groups = Multimap() + # Exclude files in cudf, wave, and torchwave directories + # as clang-tidy doesn't support CUDA compiler flags and CUDA headers + files = [file for file in files if "cudf/" not in file and "wave/" not in file] - for file in files: - groups["test" if "/tests/" in file else "main"] = file + # Exclude *-inl.h files: they are designed to be included from their + # corresponding header and cannot be compiled as standalone translation + # units (see git_changed_lines for rationale). + files = [file for file in files if not file.endswith("-inl.h")] - fix = "--fix" if args.fix == "fix" else "" - lines = ( - ("'--line-filter=" + git_changed_lines(args.commit)) + "'" - if args.commit is not None - else "" + in_gha = os.environ.get("GITHUB_ACTIONS") is not None + changed_lines = git_changed_lines(args.commit) + + line_filter = json.dumps( + [{"name": key, "lines": value} for key, value in changed_lines.items()] ) + filtered_files = [*changed_lines.keys()] + if len(filtered_files) == 0: + return 0 + + fix = "--fix" if args.fix == "fix" else "" + lines = f"'--line-filter={line_filter}'" if args.commit is not None else "" ok = True - if groups.get("main", None): - status, stdout, stderr = util.run( - f"xargs clang-tidy -p=build/release/ --format-style=file -header-filter='.*' --checks='{CODE_CHECKS}' --quiet {fix} {lines}", - input=groups["main"], - ) - ok = check_output(stdout) and ok + build_path = args.p or os.getenv("BUILD_PATH") + build_path_str = f"-p {build_path}" if build_path else "" - if groups.get("test", None): - status, stdout, stderr = util.run( - f"xargs clang-tidy -p=build/release/ --format-style=file -header-filter='.*' --checks='{TEST_CHECKS}' --quiet {fix} {lines}", - input=groups["test"], + if build_path_str == "" and not os.path.isfile( + os.getcwd().join("compile_commands.json") + ): + print("compile_commands.json not found, skipping clang-tidy") + return 0 + + status, stdout, stderr = util.run( + f"xargs clang-tidy --format-style=file -header-filter='.*' --quiet {build_path_str} {fix} {lines}", + input=filtered_files, + ) + + if in_gha: + clang_tidy_pattern = ( + r"^(.*):(\d+):(\d+):\s+(error|warning):\s+(.*) \[([a-z0-9,\-]+)\]\s*$" ) - ok = check_output(stdout) and ok + + for stdout_line in stdout.split("\n"): + m = re.match(clang_tidy_pattern, stdout_line) + if m is not None: + file, line, col, severity, message, rule = m.groups() + file = file.removeprefix("/__w/velox/velox/") + print( + f"::{severity} file={file},line={line},col={col},title={rule}::{message}" + ) + + ok = check_output(stdout) + if not ok: + print(stdout) return 0 if ok else 1 def parse_args(): global parser - parser = argparse.ArgumentParser(description="CircliCi Utility") + parser = argparse.ArgumentParser(description="Clang Tidy Utility") parser.add_argument("--commit") parser.add_argument("--fix") + parser.add_argument("-p", help="Path containing 'compile_commands.json'") parser.add_argument("files", metavar="FILES", nargs="+", help="files to process") diff --git a/scripts/checks/util.py b/scripts/checks/util.py index c0b34532a94..fe7f05c86dd 100644 --- a/scripts/checks/util.py +++ b/scripts/checks/util.py @@ -15,7 +15,7 @@ import gzip import json import os -import regex +import re import subprocess import sys @@ -27,7 +27,7 @@ class attrdict(dict): class string(str): def extract(self, rexp): - return regex.match(rexp, self).group(1) + return re.match(rexp, self).group(1) def json(self): return json.loads(self, object_hook=attrdict) diff --git a/scripts/ci/benchmark-alert.py b/scripts/ci/benchmark-alert.py deleted file mode 100755 index a4e7c9f9582..00000000000 --- a/scripts/ci/benchmark-alert.py +++ /dev/null @@ -1,160 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Facebook, Inc. and its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import os -import re -from pprint import pprint -from typing import Optional, Tuple - -import benchalerts.pipeline_steps as steps -import requests -from benchalerts import AlertPipeline -from benchclients import log - -DESCRIPTION = """ -Analyze benchmark runs, post a GitHub Check whether there were regressions or not, and -if it's a merge-commit, post a comment back to the PR that merged the commit in. - -Required environment variables: - GITHUB_REPOSITORY - the default GITHUB_REPOSITORY given in GitHub Actions - GITHUB_APP_ID - the ID of a GitHub App installed to this repo - GITHUB_APP_PRIVATE_KEY - the private key file contents of the GitHub App - CONBENCH_URL - the URL to the Conbench server where benchmark results are stored -""" - - -def parse_args() -> Tuple[str, str, str, float]: - parser = argparse.ArgumentParser( - description=DESCRIPTION, formatter_class=argparse.RawDescriptionHelpFormatter - ) - parser.add_argument( - "--contender-sha", - type=str, - required=True, - help="The SHA hash of the contender commit. The script will download benchmark " - "information about this commit and post a GitHub Check of the full report to " - "this commit.", - ) - parser.add_argument( - "--merge-commit-message", - type=str, - required=True, - help="If the contender commit is a merge-commit, its message. If not, an empty " - "string.", - ) - parser.add_argument( - "--z-score-threshold", - type=float, - required=True, - help="The (positive) z-score threshold. Benchmarks with a z-score more extreme " - "than this threshold will be marked as regressions.", - ) - args = parser.parse_args() - - repo = os.environ["GITHUB_REPOSITORY"] - - return ( - repo, - args.contender_sha, - args.merge_commit_message, - args.z_score_threshold, - ) - - -def _merge_commit_pr_number(commit_message: str) -> Optional[int]: - """If this is a merge-commit run, get the number of the PR that was merged by - grepping the commit message. - - Other alternatives to this strategy: - - - Hit /repos/{owner}/{repo}/commits/{commit_sha}/pulls to get the number. This - doesn't work because of the way facebook-github-bot merges PRs. - - Use the GitHub GraphQL API to search through recent "PR closed" events until you - find the one that's associated with this commit. This does work but actually feels - *more* likely to break than grepping. - """ - if not commit_message: - print("Merge-commit message not given; this must be a PR commit") - return None - - # Look for e.g. (#123) at the end of the first line - res = re.search(r"\(#(\d+)\)$", commit_message.split("\n")[0]) - - if not res: - print("Could not find the PR number in the following merge-commit message:") - print(commit_message) - return None - - print(f"Found a PR number (in the given merge-commit message): {repr(res[1])}") - return int(res[1]) - - -def main( - repo: str, - contender_sha: str, - merge_commit_message: str, - z_score_threshold: float, -): - log.setLevel("DEBUG") - - if merge_commit_message: - # Compare against the parent commit of the merge-commit - baseline_run_type = steps.BaselineRunCandidates.parent - else: - # Compare against the default-branch commit from which the PR was forked - baseline_run_type = steps.BaselineRunCandidates.fork_point - - pipeline_steps = [ - steps.GetConbenchZComparisonStep( - commit_hash=contender_sha, - baseline_run_type=baseline_run_type, - z_score_threshold=z_score_threshold, - step_name="z_comparison", - ), - steps.GitHubCheckStep( - commit_hash=contender_sha, comparison_step_name="z_comparison", repo=repo - ), - ] - - # Only post a comment on merge-commits - if merge_commit_message: - pr_number = _merge_commit_pr_number(merge_commit_message) - pipeline_steps.append( - steps.GitHubPRCommentAboutCheckStep(pr_number=pr_number, repo=repo) - ) - - pipeline = AlertPipeline(steps=pipeline_steps) - pprint(pipeline.run_pipeline()) - - -def test_grepping(): - """This is never called. Use for local dev.""" - _merge_commit_pr_number("") - for commit in requests.get( - "https://api.github.com/repos/facebookincubator/velox/commits" - ).json(): - print(commit["sha"][:7]) - _merge_commit_pr_number(commit["commit"]["message"]) - - -if __name__ == "__main__": - repo, contender_sha, merge_commit_message, z_score_threshold = parse_args() - main( - repo=repo, - contender_sha=contender_sha, - merge_commit_message=merge_commit_message, - z_score_threshold=z_score_threshold, - ) diff --git a/scripts/ci/bm-report/report.qmd b/scripts/ci/bm-report/report.qmd index a91a116a6b9..b9ad14c79a4 100644 --- a/scripts/ci/bm-report/report.qmd +++ b/scripts/ci/bm-report/report.qmd @@ -50,7 +50,7 @@ run_shas <- runs |> jsonlite::fromJSON() run_ids <- mruns(run_shas) |> - filter(commit.branch == "facebookincubator:main", substr(id, 1, 2) == "BM") |> + filter(substr(id, 1, 2) == "BM") |> pull(id) # Speed up local dev by saving 'results' as conbench requests can't be memoised diff --git a/scripts/ci/gersemi_cmd_definitions.py b/scripts/ci/gersemi_cmd_definitions.py index 78aca426e6a..32f1ebe2dd8 100644 --- a/scripts/ci/gersemi_cmd_definitions.py +++ b/scripts/ci/gersemi_cmd_definitions.py @@ -45,6 +45,10 @@ "pyvelox_add_module": pyvelox_add_module, "velox_add_library": velox_add_library, "velox_base_add_library": velox_base_add_library, + "velox_add_test_headers": { + "front_positional_arguments": ["target_name"], + "back_positional_arguments": ["headers"], + }, "velox_build_dependency": { "front_positional_arguments": ["dependency_name"], }, @@ -60,4 +64,8 @@ "front_positional_arguments": ["var_name", "envvar_name", "default"] }, "velox_sources": builtin_commands["target_sources"], + "velox_add_cudf_test": { + "one_value_keywords": ["NAME", "TIMEOUT"], + "multi_value_keywords": ["SOURCES", "LIBS"], + }, } diff --git a/scripts/ci/run-fuzzer-parallel.sh b/scripts/ci/run-fuzzer-parallel.sh new file mode 100755 index 00000000000..c62509378a0 --- /dev/null +++ b/scripts/ci/run-fuzzer-parallel.sh @@ -0,0 +1,113 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Run multiple fuzzer instances in parallel with different seeds. +# +# Usage: run-fuzzer-parallel.sh [args...] +# +# In the fuzzer arguments, use these placeholders which are replaced per instance: +# __SEED__ -> unique random seed per instance +# __INSTANCE_ID__ -> instance number (1, 2, ...) +# __REPRO_DIR__ -> per-instance repro directory (/instance_N) +# __LOG_DIR__ -> per-instance log directory (/instance_N/logs) + +set -euo pipefail + +NUM_INSTANCES=$1 +REPRO_BASE=$2 +BINARY=$3 +shift 3 + +PIDS=() +SEEDS=() +FAILED=0 + +dump_instance() { + local idx=$1 + echo "=== Instance ${idx} stdout (last 200 lines) ===" + tail -200 "${REPRO_BASE}/instance_${idx}/stdout.log" 2>/dev/null || + echo "(no stdout.log)" + echo "=== End instance ${idx} ===" +} + +# When the GH Actions step timeout fires it sends SIGTERM. Without this trap +# `wait` blocks indefinitely, the per-instance stdout is never surfaced, and +# the only signal in the job log is "context canceled" with no diagnostics. +on_signal() { + local sig=$1 + echo "::warning::run-fuzzer-parallel.sh received ${sig}; dumping per-instance state" + for i in "${!PIDS[@]}"; do + local idx=$((i + 1)) + local pid=${PIDS[$i]} + if kill -0 "$pid" 2>/dev/null; then + echo "::error::Instance ${idx} (PID ${pid}, seed=${SEEDS[$i]}) still running at ${sig}; sending SIGTERM" + kill -TERM "$pid" 2>/dev/null || true + fi + done + sleep 5 + for i in "${!PIDS[@]}"; do + local idx=$((i + 1)) + local pid=${PIDS[$i]} + if kill -0 "$pid" 2>/dev/null; then + kill -KILL "$pid" 2>/dev/null || true + fi + dump_instance "$idx" + done + exit 124 +} +trap 'on_signal SIGTERM' TERM +trap 'on_signal SIGINT' INT + +for i in $(seq 1 "$NUM_INSTANCES"); do + REPRO_DIR="${REPRO_BASE}/instance_${i}" + LOG_DIR="${REPRO_DIR}/logs" + mkdir -p "${LOG_DIR}" + SEED=$((RANDOM * 32768 + RANDOM + i)) + + # Replace placeholders in arguments + INSTANCE_ARGS=() + for arg in "$@"; do + arg="${arg//__SEED__/$SEED}" + arg="${arg//__INSTANCE_ID__/$i}" + arg="${arg//__REPRO_DIR__/$REPRO_DIR}" + arg="${arg//__LOG_DIR__/$LOG_DIR}" + INSTANCE_ARGS+=("$arg") + done + + echo "Starting instance ${i}: seed=${SEED}, repro=${REPRO_DIR}" + "$BINARY" "${INSTANCE_ARGS[@]}" >"${REPRO_DIR}/stdout.log" 2>&1 & + PIDS+=($!) + SEEDS+=("$SEED") +done + +echo "Waiting for ${NUM_INSTANCES} instances to complete..." + +for i in "${!PIDS[@]}"; do + IDX=$((i + 1)) + if ! wait "${PIDS[$i]}"; then + echo "::error::Fuzzer instance ${IDX} FAILED (PID ${PIDS[$i]}, seed=${SEEDS[$i]})" + dump_instance "$IDX" + FAILED=$((FAILED + 1)) + else + echo "Instance ${IDX} passed (seed=${SEEDS[$i]})" + fi +done + +if [ "$FAILED" -gt 0 ]; then + echo "${FAILED} of ${NUM_INSTANCES} instance(s) failed" + exit 1 +fi + +echo "All ${NUM_INSTANCES} instances passed" diff --git a/scripts/ci/signature.py b/scripts/ci/signature.py index daa876942f1..f9534e88d0c 100644 --- a/scripts/ci/signature.py +++ b/scripts/ci/signature.py @@ -337,7 +337,7 @@ def parse_args(args): bias_command_parser.add_argument("contender", type=str) bias_command_parser.add_argument("output_path", type=str) bias_command_parser.add_argument( - "ticket_value", type=get_tickets, default=10, nargs="?" + "ticket_value", type=get_tickets, default=20, nargs="?" ) bias_command_parser.add_argument("error_path", type=str, default="") diff --git a/scripts/ci/spark/conf/spark-defaults.conf.example b/scripts/ci/spark/conf/spark-defaults.conf.example index 5b008b44801..55a356aa01b 100644 --- a/scripts/ci/spark/conf/spark-defaults.conf.example +++ b/scripts/ci/spark/conf/spark-defaults.conf.example @@ -1 +1,3 @@ -spark.master local[*] +spark.master local[1] +spark.driver.cores 1 +spark.sql.shuffle.partitions 1 diff --git a/scripts/ci/spark/conf/spark-env.sh.example b/scripts/ci/spark/conf/spark-env.sh.example index 8cd004a8613..4df23910796 100644 --- a/scripts/ci/spark/conf/spark-env.sh.example +++ b/scripts/ci/spark/conf/spark-env.sh.example @@ -1 +1,2 @@ export SPARK_DAEMON_MEMORY=5g +export SPARK_DAEMON_JAVA_OPTS="-XX:+UseG1GC -XX:MaxGCPauseMillis=200" diff --git a/scripts/docker/centos-multi.dockerfile b/scripts/docker/centos-multi.dockerfile index d957e70d511..cdd97603dae 100644 --- a/scripts/docker/centos-multi.dockerfile +++ b/scripts/docker/centos-multi.dockerfile @@ -17,6 +17,14 @@ # - centos9: Our base CI build # - adapters: Based on centos9 with all optional dependencies installed # - pyvelox: Image used by cibuildwheel to build pyvelox +# +# tzdata is pinned to a known-good version across every stage. Without +# the pin, each docker rebuild silently picks up the latest tzdata from +# the CentOS repos — tzdata 2026b's encoding of British Columbia's +# permanent-PDT change broke Velox's bundled libc++ chrono::tzdb parser +# (issue #17522). To bump intentionally: change the default below and +# re-run the Presto-SOT fuzzers locally to confirm before merging. +ARG CENTOS_TZDATA_VERSION=2026a-1.el9 ######################## # Stage 1: Base Build # @@ -24,6 +32,12 @@ ARG image=quay.io/centos/centos:stream9 FROM $image AS base-build +ARG CENTOS_TZDATA_VERSION +# Pin tzdata. `dnf install` is a no-op if already at the pinned version, +# and falls through to `downgrade` when the base image ships a newer one. +RUN dnf -y install "tzdata-${CENTOS_TZDATA_VERSION}.noarch" || \ + dnf -y downgrade "tzdata-${CENTOS_TZDATA_VERSION}.noarch" + COPY scripts/setup-helper-functions.sh / COPY scripts/setup-versions.sh / COPY scripts/setup-common.sh / @@ -32,8 +46,12 @@ COPY CMake/resolve_dependency_modules/arrow/cmake-compatibility.patch / ARG VELOX_BUILD_SHARED=ON # Building libvelox.so requires folly and gflags to be built shared as well for now +# gflags is always both shared and static turned on. ENV VELOX_BUILD_SHARED=${VELOX_BUILD_SHARED} +ARG ARM_BUILD_TARGET=local +ENV ARM_BUILD_TARGET=${ARM_BUILD_TARGET} + RUN mkdir build WORKDIR /build @@ -46,6 +64,11 @@ ENV UV_TOOL_BIN_DIR=/usr/local/bin \ ENV CMAKE_POLICY_VERSION_MINIMUM="3.5" \ VELOX_ARROW_CMAKE_PATCH=/cmake-compatibility.patch +# Ensure libraries installed to INSTALL_PREFIX are found at runtime (e.g. +# thrift1 needs libgflags.so.2.2 when folly links gflags statically but +# other tools still use shared gflags). +ENV LD_LIBRARY_PATH="${INSTALL_PREFIX}/lib:${INSTALL_PREFIX}/lib64" + # Some CMake configs contain the hard coded prefix '/deps', we need to replace that with # the future location to avoid build errors in the base-image RUN bash /setup-centos9.sh && \ @@ -56,6 +79,12 @@ RUN bash /setup-centos9.sh && \ ######################## FROM $image AS base-image +ARG CENTOS_TZDATA_VERSION +# Pin tzdata — see top of file for rationale. Inherited by centos9, +# pyvelox, and (transitively, via the centos9 tag) the java images. +RUN dnf -y install "tzdata-${CENTOS_TZDATA_VERSION}.noarch" || \ + dnf -y downgrade "tzdata-${CENTOS_TZDATA_VERSION}.noarch" + COPY scripts/setup-helper-functions.sh / COPY scripts/setup-versions.sh / COPY scripts/setup-common.sh / @@ -97,19 +126,29 @@ CMD ["/bin/bash"] ######################## FROM base-image AS pyvelox -ENV LD_LIBRARY_PATH="/usr/local/lib:/usr/local/lib64:$LD_LIBRARY_PATH" +RUN echo "/usr/local/lib" > /etc/ld.so.conf.d/velox_deps.conf \ + && echo "/usr/local/lib64" >> /etc/ld.so.conf.d/velox_deps.conf \ + && ldconfig ######################## # Stage: Adapters Build# ######################## FROM $image AS adapters-build +ARG CENTOS_TZDATA_VERSION +# Pin tzdata — see top of file for rationale. +RUN dnf -y install "tzdata-${CENTOS_TZDATA_VERSION}.noarch" || \ + dnf -y downgrade "tzdata-${CENTOS_TZDATA_VERSION}.noarch" + COPY scripts/setup-helper-functions.sh / COPY scripts/setup-versions.sh / COPY scripts/setup-common.sh / COPY scripts/setup-centos9.sh / COPY scripts/setup-centos-adapters.sh / +ARG ARM_BUILD_TARGET=local +ENV ARM_BUILD_TARGET=${ARM_BUILD_TARGET} + RUN mkdir build WORKDIR /build @@ -129,12 +168,22 @@ FROM centos9 AS adapters COPY scripts/setup-centos-adapters.sh / +ARG CUDA_VERSION +ENV CUDA_VERSION=${CUDA_VERSION:-12.9} + RUN bash /setup-centos-adapters.sh install_cuda && \ dnf clean all RUN bash /setup-centos-adapters.sh install_adapters_deps_from_dnf && \ dnf clean all +ARG CENTOS_TZDATA_VERSION +# tzdata-java is pulled in by the Java install above. Pin it to match +# the OS tzdata pinned in centos9 so the JDK and Velox C++ see the +# same timezone rules — issue #17522 was caused by this drift. +RUN dnf -y install "tzdata-java-${CENTOS_TZDATA_VERSION}.noarch" || \ + dnf -y downgrade "tzdata-java-${CENTOS_TZDATA_VERSION}.noarch" + # put CUDA binaries on the PATH ENV PATH=/usr/local/cuda/bin:${PATH} @@ -161,6 +210,11 @@ ENV HADOOP_HOME=/usr/local/hadoop \ COPY --from=adapters-build /deps /usr/local +# thrift1 requires shared libraries copied from /deps to /usr/local. +RUN echo "/usr/local/lib" > /etc/ld.so.conf.d/velox_deps.conf \ + && echo "/usr/local/lib64" >> /etc/ld.so.conf.d/velox_deps.conf \ + && ldconfig + COPY scripts/setup-classpath.sh / ENTRYPOINT ["/bin/bash", "-c", "source /setup-classpath.sh && source /opt/rh/gcc-toolset-12/enable && exec \"$@\"", "--"] CMD ["/bin/bash"] diff --git a/scripts/docker/fedora.dockerfile b/scripts/docker/fedora.dockerfile index d098c03f721..586b6a02133 100644 --- a/scripts/docker/fedora.dockerfile +++ b/scripts/docker/fedora.dockerfile @@ -12,12 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +# tzdata is pinned to a known-good version so docker rebuilds (any time +# scripts/docker/*.dockerfile or scripts/setup-*.sh changes) don't +# silently bump tzdata in the image. See issue #17522 for the bug class +# this prevents — version mismatch between OS tzdata and consumers' +# bundled tzdb code can produce silent 1-hour offsets in TIMESTAMP +# WITH TIME ZONE values. To bump intentionally: change the default and +# rebuild locally to confirm before merging. +ARG FEDORA_TZDATA_VERSION=2025c-1.fc42 + ######################## # Stage 1: Base Build # ######################## ARG base=quay.io/fedora/fedora:42-x86_64 FROM $base AS base-build +ARG FEDORA_TZDATA_VERSION +RUN dnf -y install "tzdata-${FEDORA_TZDATA_VERSION}.noarch" || \ + dnf -y downgrade "tzdata-${FEDORA_TZDATA_VERSION}.noarch" + COPY scripts/setup-helper-functions.sh / COPY scripts/setup-versions.sh / COPY scripts/setup-common.sh / @@ -51,6 +64,10 @@ RUN bash /setup-fedora.sh && \ ######################## FROM $base AS fedora +ARG FEDORA_TZDATA_VERSION +RUN dnf -y install "tzdata-${FEDORA_TZDATA_VERSION}.noarch" || \ + dnf -y downgrade "tzdata-${FEDORA_TZDATA_VERSION}.noarch" + COPY scripts/setup-helper-functions.sh / COPY scripts/setup-versions.sh / COPY scripts/setup-common.sh / @@ -64,6 +81,7 @@ ENV UV_TOOL_BIN_DIR=/usr/local/bin \ RUN /bin/bash -c 'source /setup-fedora.sh && \ install_build_prerequisites && \ install_velox_deps_from_dnf && \ + dnf_install jq gh &&\ dnf clean all' RUN ln -s $(which python3) /usr/bin/python diff --git a/scripts/docker/java.dockerfile b/scripts/docker/java.dockerfile index c51508af98b..be9ad58e140 100644 --- a/scripts/docker/java.dockerfile +++ b/scripts/docker/java.dockerfile @@ -15,7 +15,11 @@ # Global arg default to share across stages ARG SPARK_VERSION=3.5.1 -ARG PRESTO_VERSION=0.293 +ARG PRESTO_VERSION=0.295 +# tzdata version pin — see scripts/docker/centos-multi.dockerfile for the +# rationale. Keep this string in sync with CENTOS_TZDATA_VERSION there so +# the OS tzdata and JDK tzdata-java agree on timezone rules. Issue #17522. +ARG CENTOS_TZDATA_VERSION=2026a-1.el9 ######################### # Stage: Spark Download # @@ -40,7 +44,7 @@ ARG PRESTO_VERSION RUN wget -O presto-server.tar.gz \ https://repo1.maven.org/maven2/com/facebook/presto/presto-server/${PRESTO_VERSION}/presto-server-${PRESTO_VERSION}.tar.gz RUN wget -O presto-cli \ - https://repo1.maven.org/maven2/com/facebook/presto/presto-cli/${PRESTO_VERSION}/presto-cli-${PRESTO_VERSION}-executable.jar + https://github.com/prestodb/presto/releases/download/${PRESTO_VERSION}/presto-cli-${PRESTO_VERSION}-executable.jar RUN tar -xzf presto-server.tar.gz @@ -49,7 +53,17 @@ RUN tar -xzf presto-server.tar.gz ######################### FROM ghcr.io/facebookincubator/velox-dev:centos9 AS java-base -RUN dnf install -y -q --setopt=install_weak_deps=False java-11-openjdk less procps tzdata +ARG CENTOS_TZDATA_VERSION +# Pin tzdata-java to the same version as the OS tzdata pinned in the +# centos9 base. The base image already has tzdata pinned; we re-pin +# here as a defensive `install || downgrade` so this stage builds +# correctly even before the centos9 image has been rebuilt with the pin. +RUN dnf install -y -q --setopt=install_weak_deps=False \ + java-17-openjdk less procps tzdata && \ + (dnf -y install "tzdata-java-${CENTOS_TZDATA_VERSION}.noarch" || \ + dnf -y downgrade "tzdata-java-${CENTOS_TZDATA_VERSION}.noarch") && \ + (dnf -y install "tzdata-${CENTOS_TZDATA_VERSION}.noarch" || \ + dnf -y downgrade "tzdata-${CENTOS_TZDATA_VERSION}.noarch") # We set the timezone to America/Los_Angeles due to issue # detailed here : https://github.com/facebookincubator/velox/issues/8127 diff --git a/scripts/docker/ubuntu-22.04-cpp.dockerfile b/scripts/docker/ubuntu-22.04-cpp.dockerfile index 8175ecd6823..e96efe33733 100644 --- a/scripts/docker/ubuntu-22.04-cpp.dockerfile +++ b/scripts/docker/ubuntu-22.04-cpp.dockerfile @@ -11,9 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# tzdata is pinned to a known-good version so docker rebuilds (any time +# scripts/docker/*.dockerfile or scripts/setup-*.sh changes) don't +# silently bump tzdata in the image. See issue #17522 for the bug class +# this prevents — version mismatch between OS tzdata and consumers' +# bundled tzdb code can produce silent 1-hour offsets in TIMESTAMP +# WITH TIME ZONE values. To bump intentionally: change the default and +# rebuild locally to confirm before merging. +ARG UBUNTU_TZDATA_VERSION=2026a-0ubuntu0.22.04.1 + ARG base=ubuntu:22.04 FROM ${base} +ARG UBUNTU_TZDATA_VERSION +ARG DEBIAN_FRONTEND="noninteractive" +RUN apt-get update && \ + apt-get install -y --allow-downgrades "tzdata=${UBUNTU_TZDATA_VERSION}" + RUN apt update && \ apt install -y sudo \ lsb-release \ @@ -37,4 +51,19 @@ ARG tz="Etc/UTC" ENV TZ=${tz} RUN /bin/bash -o pipefail /velox/scripts/setup-ubuntu.sh +# Install tools needed for CI (gh for GitHub Actions stash, jq for JSON parsing) +RUN apt-get update && \ + apt-get install -y -q --no-install-recommends jq && \ + curl -fsSL https://cli.github.com/packages/githubcli-archive-keyring.gpg \ + | dd of=/usr/share/keyrings/githubcli-archive-keyring.gpg && \ + echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" \ + | tee /etc/apt/sources.list.d/github-cli.list > /dev/null && \ + apt-get update && apt-get install -y -q --no-install-recommends gh && \ + apt-get clean && rm -rf /var/lib/apt/lists/* + +# Pre-download gflags source for BUNDLED builds to avoid downloading at build time. +RUN mkdir -p /velox/deps-sources && \ + curl -fsSL -o /velox/deps-sources/gflags-v2.2.2.tar.gz \ + https://github.com/gflags/gflags/archive/refs/tags/v2.2.2.tar.gz + WORKDIR /velox diff --git a/scripts/review/REVIEW_GUIDE.md b/scripts/review/REVIEW_GUIDE.md new file mode 100644 index 00000000000..591d764b236 --- /dev/null +++ b/scripts/review/REVIEW_GUIDE.md @@ -0,0 +1,106 @@ +# PR Review Style Guide + +## Procedure + +1. Fetch PR using `scripts/review/fetch.py`. This is the only fetch needed — + work from its output for all subsequent analysis. Do not make additional + `gh api` or `gh pr diff` calls. +2. Draft review to `~/.claude/review-drafts/pr-XXXXX-r1-v1.md`. +3. Show draft and get approval before posting. +4. Post using `scripts/review/post.py`. + +## First rule + +Before writing anything, ask: "Do I understand what this PR does end-to-end, +and have I verified the claims?" If the answer is no, dig deeper before +drafting. The most common mistake is focusing on code details before +understanding the change. + +## Tone + +- **Opening:** "Thank you for the fix/contribution!" — then straight to the + points. Don't elaborate on what the PR does or praise specific design choices. +- **Don't restate the PR description.** The author knows what they wrote. +- **Skip "Thank you" when the PR needs fundamental clarification.** Lead with + the question instead. +- **Be extra respectful when asking contributors to engage upstream.** They're + volunteering their time. Suggest filing an issue first — it shows respect for + upstream maintainers' expertise. +- **Be encouraging with new contributors.** Acknowledge the value of the + capability they're adding. + +## Structure + +- Reviews should be concise and actionable. The author wants to get their work + done — don't waste their time with fluff. +- **Order: big picture first.** Documentation, design questions, then code, + then tests. The most impactful feedback should come first. +- **Every point must be actionable.** No observations without asks. If it + doesn't require the author to change something, don't include it. +- **Drop qualifiers when the fix is obvious.** Don't explain why something is + wrong if the fix is self-evident. + +## Rigor + +- **Verify before claiming.** Don't assert facts about the codebase, other + projects, or behavior without checking. When in doubt, read the code. +- **Verify author claims.** When an author says an API doesn't support + something, or references another PR/diff/external behavior, check the source. + Don't accept at face value. +- **Check terminology.** Use precise terms. Don't conflate catalog/connector, + function/method, etc. + +## Re-reviews + +When the author says "addressed comments", re-review the full diff — don't +just check the boxes from the previous round. Docs, naming, design choices, +and new code added while addressing feedback all need fresh eyes. + +## What to check + +### Correctness + +- **PR title and description.** Are they clear, succinct, accurate? Does the + title describe what changed (not the symptom)? Does the description match + what the code actually does? +- **CI status.** Is CI green? If red, is it related to the PR or pre-existing? +- **Design.** Question design choices, don't just document them. Flag + surprising behavior (e.g., auto-selection), unnecessary complexity, or + features that could be simpler. +- **API design.** Flag anti-patterns: setter injection when constructor params + would work, bypassing existing methods with inline boilerplate, unnecessary + type aliases. When a PR adds API surface for an external consumer, question + the consumer's architecture — "why does the caller need this?" matters more + than "what's the cleanest way to expose it?" +- **Registries.** New registries should follow existing patterns (e.g., + query-scoped registries design in #16993). + +### Code quality + +- **Velox coding conventions.** Ensure code follows + [CODING_STYLE.md](../../CODING_STYLE.md) and the rules in + [.claude/CLAUDE.md](../../.claude/CLAUDE.md). +- **Comments.** Flag verbose code comments that restate the code, duplicate + docs elsewhere, or explain obvious behavior. Remove references to other + implementations ("like Java Presto") — logic should stand on its own. +- **Naming.** Check variable names, file names, class names against coding + style conventions. Do not abbreviate parameter names. +- **Enums.** `kPascalCase` enumerators, trailing commas, `VELOX_DEFINE_ENUM_NAME`. + +### Testing + +- **Tests.** Are they sufficient? Do they reproduce the bug? Is there an + integration test, not just a unit test? Are expected values hand-computed + (fragile) or derived from the test input? Are test helpers used to reduce + duplication? Do test names follow conventions? +- **Test files.** Each test file should have one test suite with a matching + name. Empty test fixtures should use `TEST()` instead of `TEST_F()`. + +### Documentation + +- Are new functions/features documented? Are docs updated for changed behavior? + Are README changes accurate (not removing still-valid content)? +- Check if **existing** doc pages need updating — e.g., a change to plan output + may require updating the print-plan-with-stats page, a dependency version + bump may require updating the dependency table. +- **When unsure about conventions**, CC the maintainer rather than guessing. diff --git a/scripts/review/fetch.py b/scripts/review/fetch.py new file mode 100755 index 00000000000..54ecf9b8340 --- /dev/null +++ b/scripts/review/fetch.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fetch all review-relevant information for a GitHub PR in one shot. + +Usage: fetch.py + fetch.py + +Examples: + fetch.py facebookincubator/velox 17495 + fetch.py https://github.com/facebookincubator/velox/pull/17495 +""" + +import json +import re +import subprocess +import sys + + +def run_gh(*args): + result = subprocess.run(["gh"] + list(args), capture_output=True, text=True) + if result.returncode != 0: + print(f"Error: {result.stderr.strip()}", file=sys.stderr) + return None + return result.stdout.strip() + + +def parse_args(): + if len(sys.argv) == 2: + match = re.match(r"https?://github\.com/([^/]+/[^/]+)/pull/(\d+)", sys.argv[1]) + if match: + return match.group(1), match.group(2) + print(f"Invalid URL: {sys.argv[1]}", file=sys.stderr) + sys.exit(1) + elif len(sys.argv) == 3: + return sys.argv[1], sys.argv[2] + else: + print(__doc__.strip(), file=sys.stderr) + sys.exit(1) + + +def fetch_pr_metadata(repo, pr): + raw = run_gh( + "pr", + "view", + pr, + "--repo", + repo, + "--json", + "title,body,author,state,files,additions,deletions,baseRefName,headRefName", + ) + if not raw: + return None + return json.loads(raw) + + +def fetch_diff(repo, pr): + return run_gh("pr", "diff", pr, "--repo", repo) + + +def fetch_issue_comments(repo, pr): + raw = run_gh( + "api", + f"repos/{repo}/issues/{pr}/comments", + "--paginate", + ) + if not raw: + return [] + return json.loads(raw) + + +def fetch_review_comments(repo, pr): + raw = run_gh( + "api", + f"repos/{repo}/pulls/{pr}/comments", + "--paginate", + ) + if not raw: + return [] + return json.loads(raw) + + +def fetch_reviews(repo, pr): + raw = run_gh( + "api", + f"repos/{repo}/pulls/{pr}/reviews", + ) + if not raw: + return [] + return json.loads(raw) + + +def print_section(title, content): + print(f"\n{'=' * 60}") + print(f" {title}") + print(f"{'=' * 60}\n") + print(content) + + +def main(): + repo, pr = parse_args() + + metadata = fetch_pr_metadata(repo, pr) + if not metadata: + sys.exit(1) + + print_section( + f"PR #{pr}: {metadata['title']}", + f"Author: {metadata['author']['login']}\n" + f"State: {metadata['state']}\n" + f"Branch: {metadata['headRefName']} -> {metadata['baseRefName']}\n" + f"+{metadata['additions']} -{metadata['deletions']}\n" + f"Files: {', '.join(f['path'] for f in metadata['files'])}\n" + f"\n{metadata['body']}", + ) + + diff = fetch_diff(repo, pr) + if diff: + print_section("Diff", diff) + + skip = {"netlify[bot]", "github-actions[bot]"} + + comments = fetch_issue_comments(repo, pr) + visible = [c for c in comments if c["user"]["login"] not in skip] + if visible: + lines = [] + for c in visible: + lines.append(f"--- {c['user']['login']} ({c['created_at']}) ---") + lines.append(c["body"]) + lines.append("") + print_section("Comments", "\n".join(lines)) + + reviews = fetch_reviews(repo, pr) + visible_reviews = [r for r in reviews if r.get("body")] + if visible_reviews: + lines = [] + for r in visible_reviews: + lines.append( + f"--- {r['user']['login']} ({r['state']}, {r['submitted_at']}) ---" + ) + lines.append(r["body"]) + lines.append("") + print_section("Reviews", "\n".join(lines)) + + inline = fetch_review_comments(repo, pr) + if inline: + lines = [] + for c in inline: + loc = f"{c['path']}:{c.get('line', c.get('original_line', '?'))}" + lines.append(f"--- {c['user']['login']} on {loc} ({c['created_at']}) ---") + lines.append(c["body"]) + lines.append("") + print_section("Inline Comments", "\n".join(lines)) + + +if __name__ == "__main__": + main() diff --git a/scripts/review/post.py b/scripts/review/post.py new file mode 100755 index 00000000000..1f461daa50a --- /dev/null +++ b/scripts/review/post.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Post a GitHub PR review from a file. + +Usage: post.py + post.py + +Events: APPROVE, REQUEST_CHANGES, COMMENT + +Examples: + post.py facebookincubator/velox 17495 REQUEST_CHANGES /tmp/review.md + post.py https://github.com/facebookincubator/velox/pull/17495 APPROVE /tmp/review.md +""" + +import json +import re +import subprocess +import sys + +VALID_EVENTS = {"APPROVE", "REQUEST_CHANGES", "COMMENT"} + + +def run_gh(*args, stdin=None): + result = subprocess.run( + ["gh"] + list(args), + input=stdin, + capture_output=True, + text=True, + ) + if result.returncode != 0: + print(f"Error: {result.stderr.strip()}", file=sys.stderr) + sys.exit(1) + return result.stdout.strip() + + +def parse_args(): + if len(sys.argv) == 4: + match = re.match(r"https?://github\.com/([^/]+/[^/]+)/pull/(\d+)", sys.argv[1]) + if match: + return match.group(1), match.group(2), sys.argv[2], sys.argv[3] + print(f"Invalid URL: {sys.argv[1]}", file=sys.stderr) + sys.exit(1) + elif len(sys.argv) == 5: + return sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4] + else: + print(__doc__.strip(), file=sys.stderr) + sys.exit(1) + + +def main(): + repo, pr, event, body_file = parse_args() + + event = event.upper() + if event not in VALID_EVENTS: + print( + f"Invalid event: {event}. Must be one of: {', '.join(sorted(VALID_EVENTS))}", + file=sys.stderr, + ) + sys.exit(1) + + with open(body_file) as f: + body = f.read() + + if not body.strip(): + print("Error: review body is empty", file=sys.stderr) + sys.exit(1) + + payload = json.dumps({"event": event, "body": body, "comments": []}) + + result = run_gh( + "api", + f"repos/{repo}/pulls/{pr}/reviews", + "--method", + "POST", + "--input", + "-", + stdin=payload, + ) + + data = json.loads(result) + print(data.get("html_url", "Review submitted")) + + +if __name__ == "__main__": + main() diff --git a/scripts/setup-centos-adapters.sh b/scripts/setup-centos-adapters.sh index 6e8276c4848..ded2cd556d3 100755 --- a/scripts/setup-centos-adapters.sh +++ b/scripts/setup-centos-adapters.sh @@ -23,15 +23,51 @@ # * INSTALL_PREREQUISITES="N": Skip installation of packages for build. # * PROMPT_ALWAYS_RESPOND="n": Automatically respond to interactive prompts. # Use "n" to never wipe directories. -# * VELOX_CUDA_VERSION="12.8": Which version of CUDA to install, will pick up +# * VELOX_CUDA_VERSION="12.9": Which version of CUDA to install, will pick up # CUDA_VERSION from the env +# * VELOX_UCX_VERSION="1.19.0": Which version of ucx to install, will pick up +# UCX_VERSION from the env set -efx -o pipefail -VELOX_CUDA_VERSION=${CUDA_VERSION:-"12.8"} +VELOX_CUDA_VERSION=${CUDA_VERSION:-"12.9"} +VELOX_UCX_VERSION=${UCX_VERSION:-"1.19.0"} SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}") source "$SCRIPT_DIR"/setup-centos9.sh +function install_ucx { + dnf_install rdma-core-devel + local UCX_REPO_NAME="openucx/ucx" + local NEEDS_AUTOGEN=false + + if [ "${VELOX_UCX_VERSION}" == "master" ]; then + github_checkout "${UCX_REPO_NAME}" "${VELOX_UCX_VERSION}" + NEEDS_AUTOGEN=true + else + wget_and_untar https://github.com/openucx/ucx/releases/download/v"${VELOX_UCX_VERSION}"/ucx-"${VELOX_UCX_VERSION}".tar.gz ucx + fi + + ( + cd "${DEPENDENCY_DIR}"/ucx || exit + if [ "${NEEDS_AUTOGEN}" = true ]; then + ./autogen.sh + fi + + local CUDA_FLAG="" + if [ -d "/usr/local/cuda" ]; then + CUDA_FLAG="--with-cuda=/usr/local/cuda" + fi + + mkdir build-linux && cd build-linux + + ../contrib/configure-release --prefix="${INSTALL_PREFIX}" --with-sysroot --enable-cma \ + --enable-mt --with-gnu-ld --with-rdmacm --with-verbs \ + --without-go --without-java ${CUDA_FLAG} + make "-j${NPROC}" + make install + ) +} + function install_cuda { # See https://developer.nvidia.com/cuda-downloads local arch @@ -52,14 +88,15 @@ function install_cuda { dnf config-manager --add-repo "$repo_url" local dashed dashed="$(echo "$version" | tr '.' '-')" - dnf_install --repo cuda-rhel9-"$arch" \ + dnf_install \ cuda-compat-"$dashed" \ cuda-driver-devel-"$dashed" \ cuda-minimal-build-"$dashed" \ cuda-nvrtc-devel-"$dashed" \ libcufile-devel-"$dashed" \ libnvjitlink-devel-"$dashed" \ - numactl-libs + cuda-nvml-devel-"$dashed" \ + numactl-devel } function install_adapters_deps_from_dnf { @@ -79,8 +116,8 @@ function install_s3 { function install_adapters { run_and_time install_adapters_deps_from_dnf run_and_time install_s3 - run_and_time install_gcs-sdk-cpp - run_and_time install_azure-storage-sdk-cpp + run_and_time install_gcs_sdk_cpp + run_and_time install_azure_storage_sdk_cpp run_and_time install_hdfs_deps } diff --git a/scripts/setup-centos9.sh b/scripts/setup-centos9.sh index b1dd08b0d88..44b60be2151 100755 --- a/scripts/setup-centos9.sh +++ b/scripts/setup-centos9.sh @@ -56,11 +56,12 @@ function install_build_prerequisites { dnf config-manager --set-enabled crb dnf update -y fi - dnf_install autoconf automake ccache gcc-toolset-12 git libtool \ - ninja-build python3-pip python3-devel wget which + dnf_install autoconf automake ccache clang compiler-rt \ + gcc-toolset-12 gcc-toolset-14 git libtool \ + llvm ninja-build python3-pip python3-devel wget which install_uv - uv_install cmake + uv_install cmake@3.31.1 if [[ ${USE_CLANG} != "false" ]]; then install_clang15 @@ -85,7 +86,12 @@ function install_gflags { # Remove an older version if present. dnf remove -y gflags wget_and_untar https://github.com/gflags/gflags/archive/"${GFLAGS_VERSION}".tar.gz gflags - cmake_install_dir gflags -DBUILD_SHARED_LIBS=ON -DBUILD_STATIC_LIBS=ON -DBUILD_gflags_LIB=ON -DLIB_SUFFIX=64 + cmake_install_dir gflags -DBUILD_SHARED_LIBS=ON -DBUILD_STATIC_LIBS=ON -DBUILD_gflags_LIB=ON + # CentOS 9 does not include ${INSTALL_PREFIX}/lib in the default ldconfig + # search paths. Register it so that downstream builds (e.g. fbthrift's + # thrift1 compiler) can find libgflags.so at runtime. + echo "${INSTALL_PREFIX}/lib" >/etc/ld.so.conf.d/usr-local.conf + ldconfig } function install_faiss_deps { @@ -114,6 +120,7 @@ function install_velox_deps { run_and_time install_xsimd run_and_time install_simdjson run_and_time install_geos + run_and_time install_s2geometry run_and_time install_faiss } diff --git a/scripts/setup-common.sh b/scripts/setup-common.sh index 59461664942..9699939d2af 100755 --- a/scripts/setup-common.sh +++ b/scripts/setup-common.sh @@ -24,9 +24,12 @@ VELOX_ARROW_CMAKE_PATCH=${VELOX_ARROW_CMAKE_PATCH:-""} # avoid error due to +u CMAKE_BUILD_TYPE="${BUILD_TYPE:-Release}" DEPENDENCY_DIR=${DEPENDENCY_DIR:-$(pwd)} BUILD_GEOS="${BUILD_GEOS:-true}" +BUILD_S2GEOMETRY="${BUILD_S2GEOMETRY:-true}" BUILD_FAISS="${BUILD_FAISS:-true}" BUILD_DUCKDB="${BUILD_DUCKDB:-true}" EXTRA_ARROW_OPTIONS=${EXTRA_ARROW_OPTIONS:-""} +EXTRA_ARROW_PATCH=${EXTRA_ARROW_PATCH:-""} +SIMDJSON_SKIPUTF8VALIDATION=${SIMDJSON_SKIPUTF8VALIDATION:-"OFF"} USE_CLANG="${USE_CLANG:-false}" @@ -44,19 +47,18 @@ function install_fmt { } function install_folly { - # Folly Portability.h being used to decide whether or not support coroutines - # causes issues (build, link) if the selection is not consistent across users of folly. - # shellcheck disable=SC2034 - EXTRA_PKG_CXXFLAGS=" -DFOLLY_CFG_NO_COROUTINES" wget_and_untar https://github.com/facebook/folly/archive/refs/tags/"${FB_OS_VERSION}".tar.gz folly - cmake_install_dir folly -DBUILD_SHARED_LIBS="$VELOX_BUILD_SHARED" -DBUILD_TESTS=OFF -DFOLLY_HAVE_INT128_T=ON + local FOLLY_FLAGS=(-DBUILD_SHARED_LIBS="$VELOX_BUILD_SHARED" -DBUILD_TESTS=OFF -DFOLLY_HAVE_INT128_T=ON) + # When folly is static, use static gflags to avoid dual gflags flag + # registration when .so plugins are dlopen'd (both the binary and plugin + # would register the same flags in a shared gflags registry). + if [[ ${VELOX_BUILD_SHARED} != "ON" ]]; then + FOLLY_FLAGS+=(-DGFLAGS_SHARED=FALSE) + fi + cmake_install_dir folly "${FOLLY_FLAGS[@]}" } function install_fizz { - # Folly Portability.h being used to decide whether or not support coroutines - # causes issues (build, link) if the selection is not consistent across users of folly. - # shellcheck disable=SC2034 - EXTRA_PKG_CXXFLAGS=" -DFOLLY_CFG_NO_COROUTINES" wget_and_untar https://github.com/facebookincubator/fizz/archive/refs/tags/"${FB_OS_VERSION}".tar.gz fizz cmake_install_dir fizz/fizz -DBUILD_TESTS=OFF } @@ -67,28 +69,16 @@ function install_fast_float { } function install_wangle { - # Folly Portability.h being used to decide whether or not support coroutines - # causes issues (build, link) if the selection is not consistent across users of folly. - # shellcheck disable=SC2034 - EXTRA_PKG_CXXFLAGS=" -DFOLLY_CFG_NO_COROUTINES" wget_and_untar https://github.com/facebook/wangle/archive/refs/tags/"${FB_OS_VERSION}".tar.gz wangle cmake_install_dir wangle/wangle -DBUILD_TESTS=OFF } function install_mvfst { - # Folly Portability.h being used to decide whether or not support coroutines - # causes issues (build, link) if the selection is not consistent across users of folly. - # shellcheck disable=SC2034 - EXTRA_PKG_CXXFLAGS=" -DFOLLY_CFG_NO_COROUTINES" wget_and_untar https://github.com/facebook/mvfst/archive/refs/tags/"${FB_OS_VERSION}".tar.gz mvfst cmake_install_dir mvfst -DBUILD_TESTS=OFF } function install_fbthrift { - # Folly Portability.h being used to decide whether or not support coroutines - # causes issues (build, link) if the selection is not consistent across users of folly. - # shellcheck disable=SC2034 - EXTRA_PKG_CXXFLAGS=" -DFOLLY_CFG_NO_COROUTINES" wget_and_untar https://github.com/facebook/fbthrift/archive/refs/tags/"${FB_OS_VERSION}".tar.gz fbthrift cmake_install_dir fbthrift -Denable_tests=OFF -DBUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF } @@ -125,10 +115,25 @@ function install_boost { } function install_protobuf { + install_abseil + wget_and_untar https://github.com/protocolbuffers/protobuf/releases/download/v"${PROTOBUF_VERSION}"/protobuf-all-"${PROTOBUF_VERSION}".tar.gz protobuf cmake_install_dir protobuf -Dprotobuf_BUILD_TESTS=OFF -Dprotobuf_ABSL_PROVIDER=package } +function install_grpc { + wget_and_untar https://github.com/grpc/grpc/archive/refs/tags/v"${GRPC_VERSION}".tar.gz grpc + cmake_install_dir grpc \ + -DgRPC_BUILD_TESTS=OFF \ + -DgRPC_ABSL_PROVIDER=package \ + -DgRPC_ZLIB_PROVIDER=package \ + -DgRPC_CARES_PROVIDER=package \ + -DgRPC_RE2_PROVIDER=package \ + -DgRPC_SSL_PROVIDER=package \ + -DgRPC_PROTOBUF_PROVIDER=package \ + -DgRPC_INSTALL=ON +} + function install_double_conversion { wget_and_untar https://github.com/google/double-conversion/archive/refs/tags/"${DOUBLE_CONVERSION_VERSION}".tar.gz double-conversion cmake_install_dir double-conversion -DBUILD_TESTING=OFF @@ -139,9 +144,29 @@ function install_ranges_v3 { cmake_install_dir ranges_v3 -DRANGES_ENABLE_WERROR=OFF -DRANGE_V3_TESTS=OFF -DRANGE_V3_EXAMPLES=OFF } +function install_abseil { + wget_and_untar https://github.com/abseil/abseil-cpp/archive/refs/tags/"${ABSEIL_VERSION}".tar.gz abseil-cpp + local OS + OS=$(uname) + if [[ $OS == "Darwin" ]]; then + ABSOLUTE_SCRIPTDIR=$(realpath "$SCRIPT_DIR") + ( + cd "${DEPENDENCY_DIR}/abseil-cpp" || exit 1 + git apply $ABSOLUTE_SCRIPTDIR/../CMake/resolve_dependency_modules/absl/absl-macos.patch + ) + fi + cmake_install_dir abseil-cpp \ + -DABSL_BUILD_TESTING=OFF \ + -DCMAKE_CXX_STANDARD=17 \ + -DABSL_PROPAGATE_CXX_STD=ON \ + -DABSL_ENABLE_INSTALL=ON +} + function install_re2 { + install_abseil + wget_and_untar https://github.com/google/re2/archive/refs/tags/"${RE2_VERSION}".tar.gz re2 - cmake_install_dir re2 -DRE2_BUILD_TESTING=OFF + cmake_install_dir re2 -DRE2_BUILD_TESTING=OFF -Dabsl_DIR="${INSTALL_PREFIX}/lib/cmake/absl" } function install_glog { @@ -150,7 +175,7 @@ function install_glog { } function install_lzo { - wget_and_untar http://www.oberhumer.com/opensource/lzo/download/lzo-"${LZO_VERSION}".tar.gz lzo + wget_and_untar https://www.oberhumer.com/opensource/lzo/download/lzo-"${LZO_VERSION}".tar.gz lzo ( cd "${DEPENDENCY_DIR}"/lzo || exit ./configure --prefix="${INSTALL_PREFIX}" --enable-shared --disable-static --docdir=/usr/share/doc/lzo-"${LZO_VERSION}" @@ -171,7 +196,7 @@ function install_xsimd { function install_simdjson { wget_and_untar https://github.com/simdjson/simdjson/archive/refs/tags/v"${SIMDJSON_VERSION}".tar.gz simdjson - cmake_install_dir simdjson + cmake_install_dir simdjson -DSIMDJSON_SKIPUTF8VALIDATION=${SIMDJSON_SKIPUTF8VALIDATION} } function install_arrow { @@ -186,6 +211,10 @@ function install_arrow { cd "$DEPENDENCY_DIR"/arrow || exit 1 git apply "$VELOX_ARROW_CMAKE_PATCH" + # Presto needs this for Arrow Flight + if [[ -n $EXTRA_ARROW_PATCH ]]; then + git apply "$EXTRA_ARROW_PATCH" + fi ) || exit 1 cmake_install_dir arrow/cpp \ @@ -252,6 +281,16 @@ function install_geos { fi } +function install_s2geometry { + if [[ $BUILD_S2GEOMETRY == "true" ]]; then + wget_and_untar https://github.com/google/s2geometry/archive/refs/tags/v"${S2GEOMETRY_VERSION}".tar.gz s2geometry + # Apply the same GCC 12 patch used by the BUNDLED CMake resolver. + patch -p1 -d "${DEPENDENCY_DIR}/s2geometry" < \ + "${SCRIPT_DIR}/../CMake/resolve_dependency_modules/s2geometry/s2geometry-gcc12-max.patch" || true + cmake_install_dir s2geometry -DBUILD_TESTING=OFF -DBUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF + fi +} + function install_faiss_deps { echo "Unsupported platform for faiss" } @@ -277,7 +316,7 @@ function install_aws_deps { local AWS_REPO_NAME="aws/aws-sdk-cpp" github_checkout $AWS_REPO_NAME "$AWS_SDK_VERSION" --depth 1 --recurse-submodules - cmake_install -DCMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE}" -DBUILD_SHARED_LIBS:BOOL=OFF -DMINIMIZE_SIZE:BOOL=ON -DENABLE_TESTING:BOOL=OFF -DBUILD_ONLY:STRING="s3;identity-management" + cmake_install_dir aws-sdk-cpp -DCMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE}" -DBUILD_SHARED_LIBS:BOOL=OFF -DMINIMIZE_SIZE:BOOL=ON -DENABLE_TESTING:BOOL=OFF -DBUILD_ONLY:STRING="s3;identity-management" } function install_minio { @@ -298,62 +337,41 @@ function install_minio { ${SUDO} mv ./"${MINIO_BINARY_NAME}" "$INSTALL_PREFIX"/bin/ } -function install_gcs-sdk-cpp { +function install_gcs_sdk_cpp { # Install gcs dependencies # https://github.com/googleapis/google-cloud-cpp/blob/main/doc/packaging.md#required-libraries - # abseil-cpp - github_checkout abseil/abseil-cpp "${ABSEIL_VERSION}" --depth 1 - cmake_install \ - -DABSL_BUILD_TESTING=OFF \ - -DCMAKE_CXX_STANDARD=17 \ - -DABSL_PROPAGATE_CXX_STD=ON \ - -DABSL_ENABLE_INSTALL=ON - - # protobuf - github_checkout protocolbuffers/protobuf v"${PROTOBUF_VERSION}" --depth 1 - cmake_install \ - -Dprotobuf_BUILD_TESTS=OFF \ - -Dprotobuf_ABSL_PROVIDER=package - - # grpc - github_checkout grpc/grpc "${GRPC_VERSION}" --depth 1 - cmake_install \ - -DgRPC_BUILD_TESTS=OFF \ - -DgRPC_ABSL_PROVIDER=package \ - -DgRPC_ZLIB_PROVIDER=package \ - -DgRPC_CARES_PROVIDER=package \ - -DgRPC_RE2_PROVIDER=package \ - -DgRPC_SSL_PROVIDER=package \ - -DgRPC_PROTOBUF_PROVIDER=package \ - -DgRPC_INSTALL=ON + # abseil-cpp, protobuf, grpc + install_protobuf + install_grpc # crc32 - github_checkout google/crc32c "${CRC32_VERSION}" --depth 1 - cmake_install \ + wget_and_untar https://github.com/google/crc32c/archive/refs/tags/"${CRC32_VERSION}".tar.gz crc32c + cmake_install_dir crc32c \ -DCRC32C_BUILD_TESTS=OFF \ -DCRC32C_BUILD_BENCHMARKS=OFF \ -DCRC32C_USE_GLOG=OFF # nlohmann json - github_checkout nlohmann/json "${NLOHMAN_JSON_VERSION}" --depth 1 - cmake_install \ + wget_and_untar https://github.com/nlohmann/json/archive/refs/tags/v"${NLOHMAN_JSON_VERSION}".tar.gz json + cmake_install_dir json \ -DJSON_BuildTests=OFF # google-cloud-cpp - github_checkout googleapis/google-cloud-cpp "${GOOGLE_CLOUD_CPP_VERSION}" --depth 1 - cmake_install \ + wget_and_untar https://github.com/googleapis/google-cloud-cpp/archive/refs/tags/v"${GOOGLE_CLOUD_CPP_VERSION}".tar.gz google-cloud-cpp + cmake_install_dir google-cloud-cpp \ -DGOOGLE_CLOUD_CPP_ENABLE_EXAMPLES=OFF \ -DGOOGLE_CLOUD_CPP_ENABLE=storage } -function install_azure-storage-sdk-cpp { +function install_azure_storage_sdk_cpp { # Disable VCPKG to install additional static dependencies under the VCPKG installed path # instead of using system pre-installed dependencies. export AZURE_SDK_DISABLE_AUTO_VCPKG=ON vcpkg_commit_id=7a6f366cefd27210f6a8309aed10c31104436509 github_checkout azure/azure-sdk-for-cpp azure-storage-files-datalake_"${AZURE_SDK_VERSION}" - sed -i "s/set(VCPKG_COMMIT_STRING .*)/set(VCPKG_COMMIT_STRING $vcpkg_commit_id)/" cmake-modules/AzureVcpkg.cmake + pushd azure-sdk-for-cpp || exit + sed -i='' "s/set(VCPKG_COMMIT_STRING .*)/set(VCPKG_COMMIT_STRING $vcpkg_commit_id)/" cmake-modules/AzureVcpkg.cmake azure_core_dir="sdk/core/azure-core" if ! grep -q "baseline" $azure_core_dir/vcpkg.json; then @@ -362,8 +380,8 @@ function install_azure-storage-sdk-cpp { if [[ $openssl_version == 1.1.1* ]]; then openssl_version="1.1.1n" fi - sed -i "s/\"version-string\"/\"builtin-baseline\": \"$vcpkg_commit_id\",\"version-string\"/" $azure_core_dir/vcpkg.json - sed -i "s/\"version-string\"/\"overrides\": [{ \"name\": \"openssl\", \"version-string\": \"$openssl_version\" }],\"version-string\"/" $azure_core_dir/vcpkg.json + sed -i='' "s/\"version-string\"/\"builtin-baseline\": \"$vcpkg_commit_id\",\"version-string\"/" $azure_core_dir/vcpkg.json + sed -i='' "s/\"version-string\"/\"overrides\": [{ \"name\": \"openssl\", \"version-string\": \"$openssl_version\" }],\"version-string\"/" $azure_core_dir/vcpkg.json fi ( cd $azure_core_dir || exit @@ -389,13 +407,16 @@ function install_azure-storage-sdk-cpp { cd sdk/storage/azure-storage-files-datalake || exit cmake_install -DCMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE}" -DBUILD_SHARED_LIBS=OFF ) + popd || exit } function install_hdfs_deps { # Dependencies for Hadoop testing - wget_and_untar https://archive.apache.org/dist/hadoop/common/hadoop-"${HADOOP_VERSION}"/hadoop-"${HADOOP_VERSION}".tar.gz hadoop + wget_and_untar https://dlcdn.apache.org/hadoop/common/hadoop-"${HADOOP_VERSION}"/hadoop-"${HADOOP_VERSION}".tar.gz hadoop cp -a "${DEPENDENCY_DIR}"/hadoop "$INSTALL_PREFIX" wget "${WGET_OPTS[@]}" -P "$INSTALL_PREFIX"/hadoop/share/hadoop/common/lib/ https://repo1.maven.org/maven2/junit/junit/4.11/junit-4.11.jar + # Needed for HADOOP 3.3.6 minicluster. Can remove after updating to 3.4.2. + wget "${WGET_OPTS[@]}" -P "$INSTALL_PREFIX"/hadoop/share/hadoop/mapreduce/ https://repo1.maven.org/maven2/org/mockito/mockito-core/2.23.4/mockito-core-2.23.4.jar } function install_uv { diff --git a/scripts/setup-fedora.sh b/scripts/setup-fedora.sh index 619be70ec89..0164bee9a85 100755 --- a/scripts/setup-fedora.sh +++ b/scripts/setup-fedora.sh @@ -55,6 +55,36 @@ function install_build_prerequisites { fi } +function install_velox_deps_from_dnf { + dnf_install \ + bison boost-devel c-ares-devel curl-devel double-conversion-devel \ + elfutils-libelf-devel flex fmt-devel gflags-devel glog-devel gmock-devel \ + gtest-devel libdwarf-devel libevent-devel libicu-devel \ + libsodium-devel libzstd-devel lz4-devel openssl-devel-engine \ + re2-devel snappy-devel thrift-devel xxhash-devel zlib-devel grpc-devel grpc-plugins + + install_faiss_deps +} + +function install_velox_deps { + run_and_time install_velox_deps_from_dnf + run_and_time install_gcs_sdk_cpp #grpc, abseil, protobuf + run_and_time install_fast_float + run_and_time install_folly + run_and_time install_fizz + run_and_time install_wangle + run_and_time install_mvfst + run_and_time install_fbthrift + run_and_time install_duckdb + run_and_time install_stemmer + run_and_time install_arrow + run_and_time install_xsimd # to new in fedora repos + run_and_time install_simdjson # to new in fedora repos + run_and_time install_geos # to new in fedora repos + run_and_time install_s2geometry + run_and_time install_faiss +} + (return 2>/dev/null) && return # If script was sourced, don't run commands. ( @@ -90,6 +120,8 @@ function install_build_prerequisites { set -u fi install_velox_deps + # BUILD_TESTING requires grpc + dnf_install grpc echo "All dependencies for Velox installed!" if [[ ${USE_CLANG} != "false" ]]; then echo "To use clang for the Velox build set the CC and CXX environment variables in your session." diff --git a/scripts/setup-helper-functions.sh b/scripts/setup-helper-functions.sh index 078454f8ca2..a50fb02ae0e 100755 --- a/scripts/setup-helper-functions.sh +++ b/scripts/setup-helper-functions.sh @@ -49,6 +49,7 @@ function prompt { done ) 2>/dev/null } + function github_checkout { local REPO=$1 shift @@ -69,7 +70,6 @@ function github_checkout { if [ ! -d "${DIRNAME}" ]; then git clone -q -b "$VERSION" "${GIT_CLONE_PARAMS[@]}" "https://github.com/${REPO}.git" fi - cd "${DIRNAME}" || exit } # get_cxx_flags [$CPU_ARCH] diff --git a/scripts/setup-macos.sh b/scripts/setup-macos.sh index 5c05c0dcbb8..fbd45af7801 100755 --- a/scripts/setup-macos.sh +++ b/scripts/setup-macos.sh @@ -41,8 +41,9 @@ export OS_CXXFLAGS export CMAKE_POLICY_VERSION_MINIMUM="3.5" DEPENDENCY_DIR=${DEPENDENCY_DIR:-$(pwd)} -MACOS_VELOX_DEPS="bison flex gflags glog googletest icu4c libevent libsodium lz4 openssl protobuf@21 simdjson snappy xz xxhash zstd" - +# gflags and glog are installed from source to ensure version compatibility. +# Homebrew's glog 0.7.x has breaking API changes that are incompatible with folly. +MACOS_VELOX_DEPS="bison flex googletest icu4c libevent libsodium lz4 openssl simdjson snappy xz xxhash zstd" MACOS_BUILD_DEPS="ninja cmake" SUDO="${SUDO:-""}" @@ -103,6 +104,11 @@ function install_velox_deps_from_brew { done } +function install_gflags { + wget_and_untar https://github.com/gflags/gflags/archive/"${GFLAGS_VERSION}".tar.gz gflags + cmake_install_dir gflags -DBUILD_SHARED_LIBS=ON -DBUILD_STATIC_LIBS=ON -DBUILD_gflags_LIB=ON +} + function install_s3 { install_aws_deps @@ -111,11 +117,11 @@ function install_s3 { } function install_gcs { - install_gcs-sdk-cpp + install_gcs_sdk_cpp } function install_abfs { - install_azure-storage-sdk-cpp + install_azure_storage_sdk_cpp } function install_hdfs { @@ -151,7 +157,6 @@ function install_faiss { install_faiss_deps wget_and_untar "https://github.com/facebookresearch/faiss/archive/refs/tags/v${FAISS_VERSION}.tar.gz" faiss - local cmake_args cmake_args=( -DFAISS_ENABLE_GPU=OFF @@ -177,7 +182,10 @@ function install_velox_deps { run_and_time install_ranges_v3 run_and_time install_double_conversion run_and_time install_re2 + run_and_time install_gflags + run_and_time install_glog run_and_time install_boost + run_and_time install_protobuf run_and_time install_fmt run_and_time install_fast_float run_and_time install_folly @@ -193,6 +201,7 @@ function install_velox_deps { run_and_time install_arrow run_and_time install_duckdb_clang run_and_time install_geos + run_and_time install_s2geometry run_and_time install_faiss } diff --git a/scripts/setup-ubuntu.sh b/scripts/setup-ubuntu.sh index 339b129bb85..4bb68643678 100755 --- a/scripts/setup-ubuntu.sh +++ b/scripts/setup-ubuntu.sh @@ -189,7 +189,7 @@ function install_cuda { cuda-minimal-build-"$dashed" \ cuda-nvrtc-dev-"$dashed" \ libcufile-dev-"$dashed" \ - libnvjitlink-devel-"$dashed" \ + libnvjitlink-dev-"$dashed" \ libnuma1 } @@ -202,18 +202,18 @@ function install_s3 { function install_gcs { # Dependencies of GCS, probably a workaround until the docker image is rebuilt - apt install -y --no-install-recommends libc-ares-dev libcurl4-openssl-dev - install_gcs-sdk-cpp + ${SUDO} apt install -y --no-install-recommends libc-ares-dev libcurl4-openssl-dev + install_gcs_sdk_cpp } function install_abfs { # Dependencies of Azure Storage Blob cpp - apt install -y openssl libxml2-dev - install_azure-storage-sdk-cpp + ${SUDO} apt install -y openssl libxml2-dev + install_azure_storage_sdk_cpp } function install_hdfs { - apt install -y --no-install-recommends libxml2-dev libgsasl7-dev uuid-dev openjdk-8-jdk + ${SUDO} apt install -y --no-install-recommends libxml2-dev libgsasl7-dev uuid-dev openjdk-8-jdk install_hdfs_deps } @@ -225,13 +225,14 @@ function install_adapters { } function install_faiss_deps { - sudo apt-get install -y libopenblas-dev libomp-dev + ${SUDO} apt-get install -y libopenblas-dev libomp-dev } function install_velox_deps { run_and_time install_velox_deps_from_apt run_and_time install_fmt run_and_time install_protobuf + run_and_time install_grpc run_and_time install_boost run_and_time install_fast_float run_and_time install_folly @@ -247,6 +248,7 @@ function install_velox_deps { run_and_time install_xsimd run_and_time install_simdjson run_and_time install_geos + run_and_time install_s2geometry run_and_time install_faiss } diff --git a/scripts/setup-versions.sh b/scripts/setup-versions.sh index 43428a7f4d8..a14a631df76 100755 --- a/scripts/setup-versions.sh +++ b/scripts/setup-versions.sh @@ -19,14 +19,20 @@ # The versions should match the declared versions in this file. # Build dependencies versions. -FB_OS_VERSION="v2025.04.28.00" -FMT_VERSION="10.1.1" +# Note: When updating FB_OS_VERSION, ensure that +# CMake/third-party/FBThriftCppLibrary.cmake is updated +# with the matching version from fbthrift: +# /build/fbcode_builder/CMake/FBThriftCppLibrary.cmake +# The new FB_OS version of fbthrift might require changes such that thrift +# files are generated properly on all platforms. +FB_OS_VERSION="v2026.01.05.00" +FMT_VERSION="11.2.0" BOOST_VERSION="boost-1.84.0" -ARROW_VERSION="15.0.0" +ARROW_VERSION="18.0.0" DUCKDB_VERSION="v0.8.1" PROTOBUF_VERSION="21.8" XSIMD_VERSION="10.0.0" -SIMDJSON_VERSION="3.9.3" +SIMDJSON_VERSION="4.1.0" CPR_VERSION="1.10.5" DOUBLE_CONVERSION_VERSION="v3.1.5" RANGE_V3_VERSION="0.12.0" @@ -38,6 +44,7 @@ SNAPPY_VERSION="1.1.8" THRIFT_VERSION="${THRIFT_VERSION:-v0.16.0}" STEMMER_VERSION="2.2.0" GEOS_VERSION="3.10.7" +S2GEOMETRY_VERSION="0.12.0" # shellcheck disable=SC2034 FAISS_VERSION="1.11.0" FAST_FLOAT_VERSION="v8.0.2" @@ -45,12 +52,12 @@ CCACHE_VERSION="4.11.3" # Adapter related versions. ABSEIL_VERSION="20240116.2" -GRPC_VERSION="v1.48.1" +GRPC_VERSION="1.48.1" CRC32_VERSION="1.1.2" -NLOHMAN_JSON_VERSION="v3.11.3" -GOOGLE_CLOUD_CPP_VERSION="v2.22.0" -HADOOP_VERSION="3.3.0" +NLOHMAN_JSON_VERSION="3.11.3" +GOOGLE_CLOUD_CPP_VERSION="2.22.0" +HADOOP_VERSION="3.3.6" AZURE_SDK_VERSION="12.8.0" MINIO_VERSION="2022-05-26T05-48-41Z" MINIO_BINARY_NAME="minio-2022-05-26" -AWS_SDK_VERSION="1.11.321" +AWS_SDK_VERSION="1.11.654" diff --git a/scripts/update-cudf-deps.sh b/scripts/update-cudf-deps.sh new file mode 100755 index 00000000000..a7048f23177 --- /dev/null +++ b/scripts/update-cudf-deps.sh @@ -0,0 +1,250 @@ +#!/usr/bin/env bash +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -euo pipefail + +usage() { + cat < + $0 --pr + $0 --commit + +Options: + --branch Update all cudf dependencies to latest from branch + --pr Update only cudf from a specific PR + --commit Update all dependencies using cudf commit and compatible versions + +Environment Variables: + GH_TOKEN GitHub personal access token for higher API rate limits via curl (optional) + +Examples: + $0 --branch main + $0 --branch release/26.02 + $0 --pr 12345 + $0 --commit abc123def456 + GH_TOKEN=github_pat_xxx $0 --branch main +EOF +} + +[[ $# -eq 0 ]] && usage && exit 1 + +MODE="$1" +ARG="${2:-}" +[[ -z $ARG ]] && echo "Error: $MODE requires an argument" && usage && exit 1 + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +CMAKE_FILE="$SCRIPT_DIR/../CMake/resolve_dependency_modules/cudf.cmake" + +# Use gh CLI if available or fall back to curl supporting GitHub token for higher rate limits. +GH_TOKEN="${GH_TOKEN:-}" + +fetch_github_api() { + local endpoint=$1 + local response http_code + if command -v gh &>/dev/null; then + if response=$(gh api "${endpoint}" 2>&1); then + echo "$response" + return 0 + fi + echo "Error: gh api failed" >&2 + echo "Response: $response" >&2 + else + local curl_opts=(-s --max-time 30 -w '%{http_code}') + if [[ -n $GH_TOKEN ]]; then + curl_opts+=(-H "Authorization: token $GH_TOKEN") + fi + local raw + raw=$(curl "${curl_opts[@]}" "https://api.github.com/${endpoint}" 2>&1) + http_code="${raw: -3}" + response="${raw%???}" + if [[ $http_code == "200" ]]; then + echo "$response" + return 0 + fi + echo "Error: GitHub API returned HTTP $http_code for $endpoint" >&2 + fi + echo "" >&2 + echo "SOLUTION: Authenticate to get 5,000 requests/hour:" >&2 + if command -v gh &>/dev/null; then + echo " gh auth login" >&2 + else + echo " export GH_TOKEN=gh_pat_xx" >&2 + echo " Get a token at: https://github.com/settings/personal-access-tokens/new (access: public repo)" >&2 + fi + return 1 +} + +get_commit_info() { + local repo=$1 branch=$2 + local response + response=$(fetch_github_api "repos/rapidsai/${repo}/commits/${branch}") || return 1 + echo "$response" | jq -r '[.sha, .commit.committer.date[:10]] | join(" ")' 2>/dev/null || { + echo "Error: Failed to parse JSON response from $repo" >&2 + return 1 + } +} + +get_commit_before_date() { + local repo=$1 until_date=$2 + local response + response=$(fetch_github_api "repos/rapidsai/${repo}/commits?sha=main&until=${until_date}&per_page=1") || return 1 + echo "$response" | jq -r '.[0] | [.sha, .commit.committer.date[:10]] | join(" ")' 2>/dev/null || { + echo "Error: Failed to parse response for $repo before $until_date" >&2 + return 1 + } +} + +get_sha256() { + curl -sL --max-time 30 "https://github.com/rapidsai/$1/archive/$2.tar.gz" | sha256sum | cut -d' ' -f1 || { + echo "Error: Failed to compute SHA256 for $1:$2" >&2 + return 1 + } +} + +get_version() { + local branch=$1 + if [[ $branch =~ ^release/([0-9]+\.[0-9]+)$ ]]; then + echo "${BASH_REMATCH[1]}" + else + local response + response=$(fetch_github_api "repos/rapidsai/cudf/contents/VERSION?ref=${branch}") || return 1 + echo "$response" | jq -r '.content' | base64 -d | grep -oP '^[0-9]+\.[0-9]+' || { + echo "Error: Failed to parse VERSION file from branch $branch" >&2 + return 1 + } + fi +} + +update_dependency() { + local var=$1 commit=$2 date=$3 checksum=$4 + sed -i "s/# ${var} commit [a-f0-9]* from [0-9-]*/# ${var} commit ${commit:0:7} from ${date}/" "$CMAKE_FILE" + sed -i "s/set(VELOX_${var}_COMMIT [a-f0-9]*)/set(VELOX_${var}_COMMIT ${commit})/" "$CMAKE_FILE" + + if [[ $var == "cudf" ]]; then + sed -i "s/set(VELOX_${var}_VERSION [0-9.]* CACHE/set(VELOX_${var}_VERSION ${VERSION} CACHE/" "$CMAKE_FILE" + else + sed -i "s/set(VELOX_${var}_VERSION [0-9.]*)/set(VELOX_${var}_VERSION ${VERSION})/" "$CMAKE_FILE" + fi + + awk -v var="VELOX_${var}_BUILD_SHA256_CHECKSUM" -v sum="$checksum" ' + $0 ~ var { found=1 } + found && /^[[:space:]]*[a-f0-9]{64}[[:space:]]*$/ { sub(/[a-f0-9]{64}/, sum); found=0 } + { print } + ' "$CMAKE_FILE" >"${CMAKE_FILE}.tmp" && mv "${CMAKE_FILE}.tmp" "$CMAKE_FILE" +} + +if [[ $MODE == "--pr" ]]; then + echo "Fetching cuDF PR #${ARG}..." + PR_INFO=$(fetch_github_api "repos/rapidsai/cudf/pulls/${ARG}") || exit 1 + SHA=$(echo "$PR_INFO" | jq -r '.head.sha') + BASE=$(echo "$PR_INFO" | jq -r '.base.ref') + VERSION=$(get_version "$BASE") || exit 1 + DATE=$(fetch_github_api "repos/rapidsai/cudf/commits/${SHA}" | jq -r '.commit.committer.date[:10]') || exit 1 + + echo " Base: $BASE (version $VERSION)" + echo " Commit: ${SHA:0:7} from $DATE" + echo " Computing SHA256..." + CHECKSUM=$(get_sha256 "cudf" "$SHA") || exit 1 + echo " SHA256: $CHECKSUM" + echo + + update_dependency "cudf" "$SHA" "$DATE" "$CHECKSUM" + echo "Done! Updated cudf to PR #${ARG}: ${SHA:0:7} ($DATE)" + +elif [[ $MODE == "--commit" ]]; then + echo "Fetching cuDF commit ${ARG:0:7}..." + COMMIT_INFO=$(fetch_github_api "repos/rapidsai/cudf/commits/${ARG}") || exit 1 + SHA=$(echo "$COMMIT_INFO" | jq -r '.sha') + DATE=$(echo "$COMMIT_INFO" | jq -r '.commit.committer.date[:10]') + TIMESTAMP=$(echo "$COMMIT_INFO" | jq -r '.commit.committer.date') + VERSION=$(fetch_github_api "repos/rapidsai/cudf/contents/VERSION?ref=${SHA}" | jq -r '.content' | base64 -d | grep -oP '^[0-9]+\.[0-9]+') || exit 1 + + echo " Commit: ${SHA:0:7} from $DATE" + echo " Version: $VERSION" + echo + + declare -A COMMITS DATES CHECKSUMS + COMMITS[cudf]=$SHA + DATES[cudf]=$DATE + + echo "Finding compatible dependency versions (main branch commits before $TIMESTAMP)..." + echo + + for dep in rapids_cmake rmm kvikio; do + repo=${dep//_/-} + echo "Fetching $repo..." + read -r commit date < <(get_commit_before_date "$repo" "$TIMESTAMP") || exit 1 + echo " Commit: ${commit:0:7} from $date" + echo " Computing SHA256..." + checksum=$(get_sha256 "$repo" "$commit") || exit 1 + echo " SHA256: $checksum" + + COMMITS[$dep]=$commit + DATES[$dep]=$date + CHECKSUMS[$dep]=$checksum + echo + done + + echo "Computing SHA256 for cudf..." + CHECKSUMS[cudf]=$(get_sha256 "cudf" "$SHA") || exit 1 + echo " SHA256: ${CHECKSUMS[cudf]}" + echo + + echo "Updating $CMAKE_FILE..." + for dep in rapids_cmake rmm kvikio cudf; do + update_dependency "$dep" "${COMMITS[$dep]}" "${DATES[$dep]}" "${CHECKSUMS[$dep]}" + done + + echo "Done! Updated dependencies:" + for dep in rapids_cmake rmm kvikio cudf; do + echo " $dep: ${COMMITS[$dep]:0:7} (${DATES[$dep]})" + done + +elif [[ $MODE == "--branch" ]]; then + VERSION=$(get_version "$ARG") || exit 1 + echo "Updating cuDF dependencies from branch $ARG (version $VERSION)" + echo + + declare -A COMMITS DATES CHECKSUMS + + for dep in rapids_cmake rmm kvikio cudf; do + repo=${dep//_/-} + echo "Fetching $repo..." + + read -r commit date < <(get_commit_info "$repo" "$ARG") || exit 1 + echo " Commit: ${commit:0:7} from $date" + echo " Computing SHA256..." + checksum=$(get_sha256 "$repo" "$commit") || exit 1 + echo " SHA256: $checksum" + + COMMITS[$dep]=$commit + DATES[$dep]=$date + CHECKSUMS[$dep]=$checksum + echo + done + + echo "Updating $CMAKE_FILE..." + for dep in rapids_cmake rmm kvikio cudf; do + update_dependency "$dep" "${COMMITS[$dep]}" "${DATES[$dep]}" "${CHECKSUMS[$dep]}" + done + + echo "Done! Updated dependencies:" + for dep in rapids_cmake rmm kvikio cudf; do + echo " $dep: ${COMMITS[$dep]:0:7} (${DATES[$dep]})" + done +else + usage + exit 1 +fi diff --git a/velox/benchmarks/CMakeLists.txt b/velox/benchmarks/CMakeLists.txt index 48dc1fde93f..44da32311e3 100644 --- a/velox/benchmarks/CMakeLists.txt +++ b/velox/benchmarks/CMakeLists.txt @@ -31,12 +31,14 @@ if(${VELOX_BUILD_TEST_UTILS}) ) add_library(velox_benchmark_builder ExpressionBenchmarkBuilder.cpp) - target_link_libraries(velox_benchmark_builder ${velox_benchmark_deps}) + velox_add_test_headers(velox_benchmark_builder ExpressionBenchmarkBuilder.h) + target_link_libraries(velox_benchmark_builder ${velox_benchmark_deps} velox_vector_test_lib) # This is a workaround for the use of VectorTestBase.h which includes gtest.h target_link_libraries(velox_benchmark_builder GTest::gtest) add_library(velox_query_benchmark QueryBenchmarkBase.cpp) + velox_add_test_headers(velox_query_benchmark QueryBenchmarkBase.h) target_link_libraries( velox_query_benchmark @@ -69,5 +71,6 @@ endif() if(${VELOX_ENABLE_BENCHMARKS}) add_subdirectory(tpch) + add_subdirectory(tpcds) add_subdirectory(filesystem) endif() diff --git a/velox/benchmarks/QueryBenchmarkBase.cpp b/velox/benchmarks/QueryBenchmarkBase.cpp index d1cd5d21dfc..2e9a16035c1 100644 --- a/velox/benchmarks/QueryBenchmarkBase.cpp +++ b/velox/benchmarks/QueryBenchmarkBase.cpp @@ -18,6 +18,7 @@ #include #include "velox/common/base/SuccinctPrinter.h" #include "velox/common/file/FileSystems.h" +#include "velox/connectors/ConnectorRegistry.h" #include "velox/connectors/hive/HiveConnector.h" #include "velox/dwio/dwrf/RegisterDwrfReader.h" #include "velox/dwio/parquet/RegisterParquetReader.h" @@ -204,25 +205,32 @@ void QueryBenchmarkBase::initialize() { std::make_unique(FLAGS_num_io_threads); // Add new values into the hive configuration... - auto configurationValues = std::unordered_map(); - configurationValues[connector::hive::HiveConfig::kMaxCoalescedBytes] = - std::to_string(FLAGS_max_coalesced_bytes); - configurationValues[connector::hive::HiveConfig::kMaxCoalescedDistance] = - FLAGS_max_coalesced_distance_bytes; - configurationValues[connector::hive::HiveConfig::kPrefetchRowGroups] = - std::to_string(FLAGS_parquet_prefetch_rowgroups); - auto properties = std::make_shared( - std::move(configurationValues)); + auto properties = makeConnectorProperties(); // Create hive connector with config... connector::hive::HiveConnectorFactory factory; auto hiveConnector = factory.newConnector(kHiveConnectorId, properties, ioExecutor_.get()); - connector::registerConnector(hiveConnector); + connector::ConnectorRegistry::global().insert( + hiveConnector->connectorId(), hiveConnector); parquet::registerParquetReaderFactory(); dwrf::registerDwrfReaderFactory(); } +std::shared_ptr +QueryBenchmarkBase::makeConnectorProperties() { + // Default behaviour identical to the original hard-coded version. + auto configurationValues = std::unordered_map(); + configurationValues[connector::hive::HiveConfig::kMaxCoalescedBytes] = + std::to_string(FLAGS_max_coalesced_bytes); + configurationValues[connector::hive::HiveConfig::kMaxCoalescedDistance] = + FLAGS_max_coalesced_distance_bytes; + configurationValues[connector::hive::HiveConfig::kPrefetchRowGroups] = + std::to_string(FLAGS_parquet_prefetch_rowgroups); + return std::make_shared( + std::move(configurationValues), true); +} + std::vector> QueryBenchmarkBase::listSplits( const std::string& path, @@ -244,13 +252,16 @@ void QueryBenchmarkBase::shutdown() { } std::pair, std::vector> -QueryBenchmarkBase::run(const TpchPlan& tpchPlan) { +QueryBenchmarkBase::run( + const TpchPlan& tpchPlan, + const std::unordered_map& queryConfigs) { int32_t repeat = 0; try { for (;;) { CursorParameters params; params.maxDrivers = FLAGS_num_drivers; params.planNode = tpchPlan.plan; + params.queryConfigs = queryConfigs; params.queryConfigs[core::QueryConfig::kMaxSplitPreloadPerDriver] = std::to_string(FLAGS_split_preload_per_driver); const int numSplitsPerFile = FLAGS_num_splits_per_file; diff --git a/velox/benchmarks/QueryBenchmarkBase.h b/velox/benchmarks/QueryBenchmarkBase.h index e450430f6da..755b67d36a9 100644 --- a/velox/benchmarks/QueryBenchmarkBase.h +++ b/velox/benchmarks/QueryBenchmarkBase.h @@ -55,9 +55,10 @@ class QueryBenchmarkBase { public: virtual ~QueryBenchmarkBase() = default; virtual void initialize(); - void shutdown(); + virtual void shutdown(); std::pair, std::vector> run( - const exec::test::TpchPlan& tpchPlan); + const exec::test::TpchPlan& tpchPlan, + const std::unordered_map& queryConfigs = {}); virtual std::vector> listSplits( const std::string& path, @@ -81,6 +82,8 @@ class QueryBenchmarkBase { void runAllCombinations(); + virtual std::shared_ptr makeConnectorProperties(); + protected: std::unique_ptr ioExecutor_; std::unique_ptr cacheExecutor_; diff --git a/velox/benchmarks/basic/CMakeLists.txt b/velox/benchmarks/basic/CMakeLists.txt index 7ba8c8fee60..22614f0bd2a 100644 --- a/velox/benchmarks/basic/CMakeLists.txt +++ b/velox/benchmarks/basic/CMakeLists.txt @@ -107,12 +107,30 @@ target_link_libraries( add_executable(velox_cast_benchmark CastBenchmark.cpp) target_link_libraries(velox_cast_benchmark ${velox_benchmark_deps} velox_vector_test_lib) +add_executable(velox_numeric_upcast_benchmark NumericUpcastBenchmark.cpp) +target_link_libraries(velox_numeric_upcast_benchmark ${velox_benchmark_deps} velox_vector_test_lib) + +add_executable(velox_benchmark_expr_flat_no_nulls ExprFlatNoNullsBenchmark.cpp) +target_link_libraries( + velox_benchmark_expr_flat_no_nulls + ${velox_benchmark_deps} + velox_functions_prestosql +) + add_executable(velox_format_datetime_benchmark FormatDateTimeBenchmark.cpp) target_link_libraries( velox_format_datetime_benchmark ${velox_benchmark_deps} velox_vector_test_lib - velox_functions_spark velox_functions_prestosql velox_row_fast ) + +add_executable(velox_date_extract_benchmark DateExtractBenchmark.cpp) +target_link_libraries( + velox_date_extract_benchmark + ${velox_benchmark_deps} + velox_vector_test_lib + velox_functions_prestosql + velox_functions_spark +) diff --git a/velox/benchmarks/basic/CastBenchmark.cpp b/velox/benchmarks/basic/CastBenchmark.cpp index a06f75db8e4..314788df030 100644 --- a/velox/benchmarks/basic/CastBenchmark.cpp +++ b/velox/benchmarks/basic/CastBenchmark.cpp @@ -63,6 +63,11 @@ int main(int argc, char** argv) { [&](auto j) { return 123456789 * j; }, nullptr, DECIMAL(18, 6)); + auto smallDecimalInput = vectorMaker.flatVector( + vectorSize, + [&](auto j) { return 123456789 * j; }, + nullptr, + DECIMAL(18, 17)); auto longDecimalInput = vectorMaker.flatVector( vectorSize, [&](auto j) { @@ -160,6 +165,7 @@ int main(int argc, char** argv) { "bigint", "decimal", "short_decimal", + "small_decimal", "long_decimal", "large_real", "small_real", @@ -172,6 +178,7 @@ int main(int argc, char** argv) { bigintInput, decimalInput, shortDecimalInput, + smallDecimalInput, longDecimalInput, largeRealInput, smallRealInput, @@ -189,6 +196,7 @@ int main(int argc, char** argv) { .addExpression( "cast_decimal_to_inline_string", "cast (decimal as varchar)") .addExpression("cast_short_decimal", "cast (short_decimal as varchar)") + .addExpression("cast_small_decimal", "cast (small_decimal as varchar)") .addExpression("cast_long_decimal", "cast (long_decimal as varchar)") .addExpression( "cast_large_real_to_scientific_notation", diff --git a/velox/benchmarks/basic/DateExtractBenchmark.cpp b/velox/benchmarks/basic/DateExtractBenchmark.cpp new file mode 100644 index 00000000000..cec32691684 --- /dev/null +++ b/velox/benchmarks/basic/DateExtractBenchmark.cpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "velox/benchmarks/ExpressionBenchmarkBuilder.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/functions/sparksql/registration/Register.h" + +using namespace facebook; +using namespace facebook::velox; + +int main(int argc, char** argv) { + folly::Init init(&argc, &argv); + memory::MemoryManager::initialize(memory::MemoryManager::Options{}); + functions::prestosql::registerDateTimeFunctions(""); + functions::sparksql::registerFunctions("spark_"); + + ExpressionBenchmarkBuilder benchmarkBuilder; + VectorFuzzer::Options options; + options.vectorSize = 1024; + options.nullRatio = 0; + auto* pool = benchmarkBuilder.pool(); + VectorFuzzer fuzzer(options, pool); + auto vectorMaker = benchmarkBuilder.vectorMaker(); + + // Each set runs a different extraction expression on the same fuzzed input, + // exercising Timestamp::epochToCalendarUtc through getDateTime in + // velox/functions/lib/TimeUtils.h. + + // DATE inputs (int32 days since epoch). + benchmarkBuilder + .addBenchmarkSet( + "year_date", vectorMaker.rowVector({fuzzer.fuzz(DATE())})) + .addExpression("year", "year(c0)") + .disableTesting(); + + benchmarkBuilder + .addBenchmarkSet( + "month_date", vectorMaker.rowVector({fuzzer.fuzz(DATE())})) + .addExpression("month", "month(c0)") + .disableTesting(); + + benchmarkBuilder + .addBenchmarkSet("day_date", vectorMaker.rowVector({fuzzer.fuzz(DATE())})) + .addExpression("day", "day(c0)") + .disableTesting(); + + benchmarkBuilder + .addBenchmarkSet( + "quarter_date", vectorMaker.rowVector({fuzzer.fuzz(DATE())})) + .addExpression("quarter", "quarter(c0)") + .disableTesting(); + + benchmarkBuilder + .addBenchmarkSet( + "day_of_year_date", vectorMaker.rowVector({fuzzer.fuzz(DATE())})) + .addExpression("day_of_year", "day_of_year(c0)") + .disableTesting(); + + benchmarkBuilder + .addBenchmarkSet( + "last_day_of_month_date", + vectorMaker.rowVector({fuzzer.fuzz(DATE())})) + .addExpression("last_day_of_month", "last_day_of_month(c0)") + .disableTesting(); + + // TIMESTAMP inputs (seconds + nanos). + benchmarkBuilder + .addBenchmarkSet( + "year_timestamp", vectorMaker.rowVector({fuzzer.fuzz(TIMESTAMP())})) + .addExpression("year", "year(c0)") + .disableTesting(); + + benchmarkBuilder + .addBenchmarkSet( + "month_timestamp", vectorMaker.rowVector({fuzzer.fuzz(TIMESTAMP())})) + .addExpression("month", "month(c0)") + .disableTesting(); + + benchmarkBuilder + .addBenchmarkSet( + "day_timestamp", vectorMaker.rowVector({fuzzer.fuzz(TIMESTAMP())})) + .addExpression("day", "day(c0)") + .disableTesting(); + + // Inverse direction: (year, month, day) -> Date via Spark make_date. + // Three INTEGER columns of small valid year/month/day values. + VectorFuzzer::Options yearOpts = options; + yearOpts.dataSpec.includeNaN = false; + VectorFuzzer yearFuzzer(yearOpts, pool); + benchmarkBuilder + .addBenchmarkSet( + "make_date", + vectorMaker.rowVector( + {"y", "m", "d"}, + {vectorMaker.flatVector( + 1024, [](auto i) { return 1970 + (int)(i % 80); }), + vectorMaker.flatVector( + 1024, [](auto i) { return 1 + (int)(i % 12); }), + vectorMaker.flatVector( + 1024, [](auto i) { return 1 + (int)(i % 28); })})) + .addExpression("make_date", "spark_make_date(y, m, d)") + .disableTesting(); + + benchmarkBuilder.registerBenchmarks(); + folly::runBenchmarks(); + return 0; +} diff --git a/velox/benchmarks/basic/ExprFlatNoNullsBenchmark.cpp b/velox/benchmarks/basic/ExprFlatNoNullsBenchmark.cpp new file mode 100644 index 00000000000..8694c6dd3fc --- /dev/null +++ b/velox/benchmarks/basic/ExprFlatNoNullsBenchmark.cpp @@ -0,0 +1,121 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/// Benchmark comparing expression evaluation with FlatNoNulls fast path +/// enabled vs disabled via the expression.eval_flat_no_nulls config. + +#include +#include + +#include "velox/benchmarks/ExpressionBenchmarkBuilder.h" +#include "velox/core/QueryConfig.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" + +using namespace facebook; +using namespace facebook::velox; + +namespace { + +void addArithmeticBenchmarks( + ExpressionBenchmarkBuilder& builder, + const std::string& prefix) { + for (auto vectorSize : {1'024, 4'096, 40'960, 99'999}) { + builder + .addBenchmarkSet( + fmt::format("{}_arith_batch{}", prefix, vectorSize), + ROW({"a", "b", "c", "d"}, {DOUBLE(), DOUBLE(), DOUBLE(), DOUBLE()})) + .withFuzzerOptions( + {.vectorSize = static_cast(vectorSize), .nullRatio = 0.0}) + // Simple arithmetic — 1 node. + .addExpression("add_ab", "a + b") + // Complex arithmetic — 7 nodes. + .addExpression("complex_7n", "(a + b) * c + (a - d) * b") + // Deep tree — 15 nodes, depth 8 (left-skewed chain). + .addExpression( + "deep_15n_d8", "((((((a + b) * c - d) + a) * b - c) + d) * a - b)") + .withIterations(1'000); + } +} + +void addComparisonBenchmarks( + ExpressionBenchmarkBuilder& builder, + const std::string& prefix) { + for (auto vectorSize : {1'024, 4'096, 40'960, 99'999}) { + builder + .addBenchmarkSet( + fmt::format("{}_cmp_batch{}", prefix, vectorSize), + ROW({"a", "b"}, {DOUBLE(), DOUBLE()})) + .withFuzzerOptions( + {.vectorSize = static_cast(vectorSize), .nullRatio = 0.0}) + // Comparison — 1 node. + .addExpression("eq_ab", "a = b") + .withIterations(1'000); + } +} + +void addConstMixedBenchmarks( + ExpressionBenchmarkBuilder& builder, + const std::string& prefix) { + for (auto vectorSize : {1'024, 4'096, 40'960, 99'999}) { + builder + .addBenchmarkSet( + fmt::format("{}_const_batch{}", prefix, vectorSize), + ROW({"a", "b"}, {DOUBLE(), DOUBLE()})) + .withFuzzerOptions( + {.vectorSize = static_cast(vectorSize), .nullRatio = 0.0}) + // 1 constant. + .addExpression("const_1", "a + 1.5") + // 2 constants. + .addExpression("const_2", "(a + 1.5) * 2.0") + // 3 constants mixed with columns — 7 nodes. + .addExpression("const_3_7n", "(a + 1.5) * 2.0 + (a - 3.0) * b") + .withIterations(1'000); + } +} + +/// Extends ExpressionBenchmarkBuilder to allow setting query config. +class ConfigurableBenchmarkBuilder : public ExpressionBenchmarkBuilder { + public: + void setConfig(const std::string& key, const std::string& value) { + queryCtx_->testingOverrideConfigUnsafe({{key, value}}); + } +}; + +} // namespace + +int main(int argc, char** argv) { + folly::Init init(&argc, &argv); + memory::MemoryManager::initialize(memory::MemoryManager::Options{}); + functions::prestosql::registerAllScalarFunctions(); + + // Fast path ON (default). + ConfigurableBenchmarkBuilder fastPathOn; + addArithmeticBenchmarks(fastPathOn, "on"); + addComparisonBenchmarks(fastPathOn, "on"); + addConstMixedBenchmarks(fastPathOn, "on"); + + // Fast path OFF via config. + ConfigurableBenchmarkBuilder fastPathOff; + fastPathOff.setConfig(core::QueryConfig::kExprEvalFlatNoNulls, "false"); + addArithmeticBenchmarks(fastPathOff, "off"); + addComparisonBenchmarks(fastPathOff, "off"); + addConstMixedBenchmarks(fastPathOff, "off"); + + fastPathOn.registerBenchmarks(); + fastPathOff.registerBenchmarks(); + folly::runBenchmarks(); + return 0; +} diff --git a/velox/benchmarks/basic/LikeTpchBenchmark.cpp b/velox/benchmarks/basic/LikeTpchBenchmark.cpp index 1f4dab648f9..d904ae5118b 100644 --- a/velox/benchmarks/basic/LikeTpchBenchmark.cpp +++ b/velox/benchmarks/basic/LikeTpchBenchmark.cpp @@ -127,8 +127,9 @@ class LikeFunctionsBenchmark : public FunctionBaseTest, return tpchSupplier->childAt(6); } default: - VELOX_FAIL(fmt::format( - "Tpch data generation for case {} is not supported", tpchCase)); + VELOX_FAIL( + fmt::format( + "Tpch data generation for case {} is not supported", tpchCase)); } } diff --git a/velox/benchmarks/basic/NumericUpcastBenchmark.cpp b/velox/benchmarks/basic/NumericUpcastBenchmark.cpp new file mode 100644 index 00000000000..931828846a6 --- /dev/null +++ b/velox/benchmarks/basic/NumericUpcastBenchmark.cpp @@ -0,0 +1,190 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "velox/benchmarks/ExpressionBenchmarkBuilder.h" + +using namespace facebook; +using namespace facebook::velox; + +int main(int argc, char** argv) { + folly::Init init(&argc, &argv); + memory::MemoryManager::initialize(memory::MemoryManager::Options{}); + + ExpressionBenchmarkBuilder benchmarkBuilder; + const vector_size_t vectorSize = 300'000'000; + auto vectorMaker = benchmarkBuilder.vectorMaker(); + + auto tinyIntInputNullable = vectorMaker.flatVector( + vectorSize, + [](auto j) { return j % std::numeric_limits::max(); }, + [](auto j) { return j % 5 == 0; }); + auto tinyIntInput = vectorMaker.flatVector( + vectorSize, + [](auto j) { return j % std::numeric_limits::max(); }, + nullptr); + + auto smallIntInputNullable = vectorMaker.flatVector( + vectorSize, + [](auto j) { return j % std::numeric_limits::max(); }, + [](auto j) { return j % 5 == 0; }); + auto smallIntInput = vectorMaker.flatVector( + vectorSize, + [](auto j) { return j % std::numeric_limits::max(); }, + nullptr); + + auto integerInputNullable = vectorMaker.flatVector( + vectorSize, + [](auto j) { return j * 2 + 1; }, + [](auto j) { return j % 5 == 0; }); + auto integerInput = vectorMaker.flatVector( + vectorSize, [](auto j) { return j * 2 + 1; }, nullptr); + + auto bigintInputNullable = vectorMaker.flatVector( + vectorSize, + [](auto j) { return j * 2 + 1; }, + [](auto j) { return j % 5 == 0; }); + auto bigintInput = vectorMaker.flatVector( + vectorSize, [](auto j) { return j * 2 + 1; }, nullptr); + + auto realInputNullable = vectorMaker.flatVector( + vectorSize, + [](auto j) { return j * 9.9999; }, + [](auto j) { return j % 5 == 0; }); + auto realInput = vectorMaker.flatVector( + vectorSize, [](auto j) { return j * 9.9999; }, nullptr); + + benchmarkBuilder + .addBenchmarkSet( + "numeric_upcast", + vectorMaker.rowVector( + { + "tinyint_column_nullable", + "tinyint_column", + "smallint_column_nullable", + "smallint_column", + "integer_column_nullable", + "integer_column", + "bigint_column_nullable", + "bigint_column", + "real_column_nullable", + "real_column", + }, + { + tinyIntInputNullable, + tinyIntInput, + smallIntInputNullable, + smallIntInput, + integerInputNullable, + integerInput, + bigintInputNullable, + bigintInput, + realInputNullable, + realInput, + })) + // Cast from tinyint. + .addExpression( + "cast_tinyint_nullable_as_smallint", + "cast(tinyint_column_nullable as smallint)") + .addExpression( + "cast_tinyint_as_smallint", "cast(tinyint_column as smallint)") + + .addExpression( + "cast_tinyint_nullable_as_integer", + "cast(tinyint_column_nullable as integer)") + .addExpression( + "cast_tinyint_as_integer", "cast(tinyint_column as integer)") + + .addExpression( + "cast_tinyint_nullable_as_bigint", + "cast(tinyint_column_nullable as bigint)") + .addExpression("cast_tinyint_as_bigint", "cast(tinyint_column as bigint)") + + .addExpression( + "cast_tinyint_nullable_as_real", + "cast(tinyint_column_nullable as real)") + .addExpression("cast_tinyint_as_real", "cast(tinyint_column as real)") + + .addExpression( + "cast_tinyint_nullable_as_double", + "cast(tinyint_column_nullable as double)") + .addExpression("cast_tinyint_as_double", "cast(tinyint_column as double)") + + // Cast from smallint. + .addExpression( + "cast_smallint_nullable_as_integer", + "cast(smallint_column_nullable as integer)") + .addExpression( + "cast_smallint_as_integer", "cast(smallint_column as integer)") + + .addExpression( + "cast_smallint_nullable_as_bigint", + "cast(smallint_column_nullable as bigint)") + .addExpression( + "cast_smallint_as_bigint", "cast(smallint_column as bigint)") + + .addExpression( + "cast_smallint_nullable_as_real", + "cast(smallint_column_nullable as real)") + .addExpression("cast_smallint_as_real", "cast(smallint_column as real)") + + .addExpression( + "cast_smallint_nullable_as_double", + "cast(smallint_column_nullable as double)") + .addExpression( + "cast_smallint_as_double", "cast(smallint_column as double)") + + // Cast from integer. + .addExpression( + "cast_integer_nullable_as_bigint", + "cast(integer_column_nullable as bigint)") + .addExpression("cast_integer_as_bigint", "cast(integer_column as bigint)") + + .addExpression( + "cast_integer_nullable_as_real", + "cast(integer_column_nullable as real)") + .addExpression("cast_integer_as_real", "cast(integer_column as real)") + + .addExpression( + "cast_integer_nullable_as_double", + "cast(integer_column_nullable as double)") + .addExpression("cast_integer_as_double", "cast(integer_column as double)") + + // Cast from bigint. + .addExpression( + "cast_bigint_nullable_as_real", + "cast(bigint_column_nullable as real)") + .addExpression("cast_bigint_as_real", "cast(bigint_column as real)") + + .addExpression( + "cast_bigint_nullable_as_double", + "cast(bigint_column_nullable as double)") + .addExpression("cast_bigint_as_double", "cast(bigint_column as double)") + + // Cast from real. + .addExpression( + "cast_real_nullable_as_double", + "cast(real_column_nullable as double)") + .addExpression("cast_real_as_double", "cast(real_column as double)") + .withIterations(100) + .disableTesting(); + + benchmarkBuilder.registerBenchmarks(); + folly::runBenchmarks(); + return 0; +} diff --git a/velox/benchmarks/basic/Preproc.cpp b/velox/benchmarks/basic/Preproc.cpp index f3fbca54b4f..e7fdc72e320 100644 --- a/velox/benchmarks/basic/Preproc.cpp +++ b/velox/benchmarks/basic/Preproc.cpp @@ -186,11 +186,12 @@ std::vector> signatures() { std::vector> signatures; for (auto type : {"tinyint", "smallint", "integer", "bigint", "real", "double"}) { - signatures.push_back(exec::FunctionSignatureBuilder() - .returnType(type) - .argumentType(type) - .argumentType(type) - .build()); + signatures.push_back( + exec::FunctionSignatureBuilder() + .returnType(type) + .argumentType(type) + .argumentType(type) + .build()); } return signatures; } @@ -258,15 +259,7 @@ class PreprocBenchmark : public functions::test::FunctionBenchmarkBase { } exec::ExprSet compile(const std::vector& texts) { - std::vector typedExprs; - parse::ParseOptions options; - for (const auto& text : texts) { - auto untyped = parse::parseExpr(text, options); - auto typed = core::Expressions::inferTypes( - untyped, ROW({"c0"}, {REAL()}), execCtx_.pool()); - typedExprs.push_back(typed); - } - return exec::ExprSet(typedExprs, &execCtx_); + return compileExpressions(texts, ROW({{"c0", REAL()}})); } std::string makeExpression(int n, RunConfig config, bool withNulls) { diff --git a/velox/benchmarks/basic/VectorSlice.cpp b/velox/benchmarks/basic/VectorSlice.cpp index 44ad91f8200..54b706265e7 100644 --- a/velox/benchmarks/basic/VectorSlice.cpp +++ b/velox/benchmarks/basic/VectorSlice.cpp @@ -37,9 +37,10 @@ constexpr int kVectorSize = 16 << 10; struct BenchmarkData { BenchmarkData() - : pool_(memory::memoryManager()->addLeafPool( - "BenchmarkData", - FLAGS_use_thread_safe_memory_usage_track)) { + : pool_( + memory::memoryManager()->addLeafPool( + "BenchmarkData", + FLAGS_use_thread_safe_memory_usage_track)) { VectorFuzzer::Options opts; opts.nullRatio = 0.01; opts.vectorSize = kVectorSize; diff --git a/velox/benchmarks/filesystem/CMakeLists.txt b/velox/benchmarks/filesystem/CMakeLists.txt index 5f5b411b94a..9b272573ba3 100644 --- a/velox/benchmarks/filesystem/CMakeLists.txt +++ b/velox/benchmarks/filesystem/CMakeLists.txt @@ -13,6 +13,7 @@ # limitations under the License. add_library(velox_read_benchmark_lib ReadBenchmark.cpp) +velox_add_test_headers(velox_read_benchmark_lib ReadBenchmark.h) target_link_libraries( velox_read_benchmark_lib diff --git a/velox/benchmarks/filesystem/ReadBenchmark.h b/velox/benchmarks/filesystem/ReadBenchmark.h index e033bc953ac..69b26082d64 100644 --- a/velox/benchmarks/filesystem/ReadBenchmark.h +++ b/velox/benchmarks/filesystem/ReadBenchmark.h @@ -180,8 +180,9 @@ class ReadBenchmark { } else { std::vector> ranges; for (auto start = 0; start < rangeSize; start += size + gap) { - ranges.push_back(folly::Range( - globalScratch.buffer.data() + start, size)); + ranges.push_back( + folly::Range( + globalScratch.buffer.data() + start, size)); if (gap && start + gap < rangeSize) { ranges.push_back(folly::Range(nullptr, gap)); } diff --git a/velox/benchmarks/tpcds/CMakeLists.txt b/velox/benchmarks/tpcds/CMakeLists.txt new file mode 100644 index 00000000000..15a216eca8b --- /dev/null +++ b/velox/benchmarks/tpcds/CMakeLists.txt @@ -0,0 +1,44 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_library(velox_tpcds_benchmark_lib TpcdsBenchmark.cpp) +velox_add_test_headers(velox_tpcds_benchmark_lib TpcdsBenchmark.h) + +target_link_libraries( + velox_tpcds_benchmark_lib + velox_query_benchmark + velox_aggregates + velox_window + velox_exec + velox_exec_test_lib + velox_dwio_common + velox_dwio_parquet_reader + velox_dwio_common_test_utils + velox_hive_connector + velox_exception + velox_memory + velox_process + velox_serialization + velox_encode + velox_type + velox_caching + velox_vector_test_lib + Folly::follybenchmark + Folly::folly + fmt::fmt +) + +add_executable(velox_tpcds_benchmark TpcdsBenchmarkMain.cpp) + +target_link_libraries(velox_tpcds_benchmark velox_tpcds_benchmark_lib) diff --git a/velox/benchmarks/tpcds/TpcdsBenchmark.cpp b/velox/benchmarks/tpcds/TpcdsBenchmark.cpp new file mode 100644 index 00000000000..c07b3d7342c --- /dev/null +++ b/velox/benchmarks/tpcds/TpcdsBenchmark.cpp @@ -0,0 +1,523 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/benchmarks/tpcds/TpcdsBenchmark.h" + +#include "velox/connectors/ConnectorRegistry.h" +#include "velox/connectors/hive/HiveConnector.h" +#include "velox/connectors/hive/HiveConnectorSplit.h" +#include "velox/exec/PartitionFunction.h" +#include "velox/exec/PlanNodeStats.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/tests/utils/TpchQueryBuilder.h" +#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/functions/prestosql/window/WindowFunctionsRegistration.h" +#include "velox/parse/TypeResolver.h" + +#include + +using namespace facebook::velox; +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; +using namespace facebook::velox::dwio::common; + +DEFINE_string( + data_path, + "", + "Root path of TPC-DS data. Data layout must follow Hive-style partitioning. " + "Example layout for '--data_path=/data/tpcds_sf100'\n" + " /data/tpcds_sf100/store_sales\n" + " /data/tpcds_sf100/customer\n" + " /data/tpcds_sf100/date_dim\n" + " /data/tpcds_sf100/store\n" + " ...\n" + "Each directory contains data files (e.g. parquet) for the table."); + +namespace { +static bool notEmpty(const char* /*flagName*/, const std::string& value) { + return !value.empty(); +} +} // namespace + +DEFINE_validator(data_path, ¬Empty); + +DEFINE_string( + plan_path, + "", + "Directory containing Velox plan JSON files (Q1.json, Q2.json, ...). " + "These are typically dumped from a Presto worker's plan-dump-dir."); + +DEFINE_validator(plan_path, ¬Empty); + +DEFINE_int32( + run_query_verbose, + -1, + "Run a given query and print execution statistics"); + +void TpcdsBenchmark::initQueryBuilder() { + queryBuilder_ = + std::make_unique(toFileFormat(FLAGS_data_format)); + queryBuilder_->initialize(FLAGS_data_path); +} + +void TpcdsBenchmark::initialize() { + QueryBenchmarkBase::initialize(); + + const std::string prestoPrefix{kPrestoFunctionNamespacePrefix}; + functions::prestosql::registerAllScalarFunctions(prestoPrefix); + aggregate::prestosql::registerAllAggregateFunctions(prestoPrefix); + window::prestosql::registerAllWindowFunctions(prestoPrefix); + + // Register serialization/deserialization for plan nodes. + Type::registerSerDe(); + common::Filter::registerSerDe(); + connector::hive::HiveConnector::registerSerDe(); + core::PlanNode::registerSerDe(); + core::ITypedExpr::registerSerDe(); + exec::registerPartitionFunctionSerDe(); + + // Presto-dumped plans use connector ID "hive", while the base class + // registers under kHiveConnectorId ("test-hive"). Register a properly + // configured connector under "hive" so both the plan and splits match. + const std::string prestoConnectorId{kPrestoHiveConnectorId}; + if (connector::ConnectorRegistry::tryGet(prestoConnectorId) == nullptr) { + auto properties = makeConnectorProperties(); + connector::hive::HiveConnectorFactory factory; + auto hiveConnector = + factory.newConnector(prestoConnectorId, properties, ioExecutor_.get()); + connector::ConnectorRegistry::global().insert( + hiveConnector->connectorId(), hiveConnector); + } + + planDir_ = FLAGS_plan_path; + pool_ = memory::memoryManager()->addLeafPool("TpcdsBenchmark"); + + initQueryBuilder(); +} + +void TpcdsBenchmark::shutdown() { + if (queryBuilder_) { + queryBuilder_->shutdown(); + queryBuilder_.reset(); + } + const std::string prestoConnectorId{kPrestoHiveConnectorId}; + if (connector::ConnectorRegistry::tryGet(prestoConnectorId) != nullptr) { + connector::ConnectorRegistry::global().erase(prestoConnectorId); + } + pool_.reset(); + QueryBenchmarkBase::shutdown(); +} + +std::vector> +TpcdsBenchmark::listSplits( + const std::string& path, + int32_t numSplitsPerFile, + const exec::test::TpchPlan& plan) { + // The base class creates splits with the default (empty) connector ID, + // which doesn't match the plan's connector ID (e.g. "hive" from Presto). + // Re-create splits with the correct connector ID from the query builder. + auto baseSplits = + QueryBenchmarkBase::listSplits(path, numSplitsPerFile, plan); + + const auto& cid = + queryBuilder_ ? queryBuilder_->connectorId() : std::string(); + if (cid.empty()) { + return baseSplits; + } + + // Rebuild each split with the plan's connector ID. + std::vector> result; + result.reserve(baseSplits.size()); + for (const auto& baseSplit : baseSplits) { + auto hiveSplit = + std::dynamic_pointer_cast( + baseSplit); + if (hiveSplit) { + result.push_back( + connector::hive::HiveConnectorSplitBuilder(hiveSplit->filePath) + .connectorId(cid) + .fileFormat(hiveSplit->fileFormat) + .start(hiveSplit->start) + .length(hiveSplit->length) + .build()); + } else { + result.push_back(baseSplit); + } + } + return result; +} + +void TpcdsBenchmark::runQuery(int32_t queryId) { + auto plan = queryBuilder_->getQueryPlan(queryId, planDir_, pool_.get()); + run(plan, queryConfigs_); +} + +void TpcdsBenchmark::runMain( + std::ostream& out, + facebook::velox::RunStats& runStats) { + if (FLAGS_run_query_verbose == -1) { + folly::runBenchmarks(); + } else { + auto plan = queryBuilder_->getQueryPlan( + FLAGS_run_query_verbose, planDir_, pool_.get()); + + auto [cursor, actualResults] = run(plan, queryConfigs_); + VELOX_CHECK(cursor, "Query terminated with error. Exiting"); + auto task = cursor->task(); + ensureTaskCompletion(task.get()); + if (FLAGS_include_results) { + printResults(actualResults, out); + out << std::endl; + } + const auto stats = task->taskStats(); + int64_t rawInputBytes = 0; + for (auto& pipeline : stats.pipelineStats) { + auto& first = pipeline.operatorStats[0]; + if (first.operatorType == "TableScan") { + rawInputBytes += first.rawInputBytes; + } + } + runStats.rawInputBytes = rawInputBytes; + out << fmt::format( + "Execution time: {}", + facebook::velox::succinctMillis( + stats.executionEndTimeMs - stats.executionStartTimeMs)) + << std::endl; + out << fmt::format( + "Splits total: {}, finished: {}", + stats.numTotalSplits, + stats.numFinishedSplits) + << std::endl; + out << printPlanWithStats(*plan.plan, stats, FLAGS_include_custom_stats) + << std::endl; + } +} + +std::unique_ptr tpcdsBenchmark; + +// TPC-DS has 99 queries. Define BENCHMARK macros for each. +// Only a subset may have plan JSONs available; missing plans will cause +// a runtime error for that specific query. +BENCHMARK(tpcds_q1) { + tpcdsBenchmark->runQuery(1); +} +BENCHMARK(tpcds_q2) { + tpcdsBenchmark->runQuery(2); +} +BENCHMARK(tpcds_q3) { + tpcdsBenchmark->runQuery(3); +} +BENCHMARK(tpcds_q4) { + tpcdsBenchmark->runQuery(4); +} +BENCHMARK(tpcds_q5) { + tpcdsBenchmark->runQuery(5); +} +BENCHMARK(tpcds_q6) { + tpcdsBenchmark->runQuery(6); +} +BENCHMARK(tpcds_q7) { + tpcdsBenchmark->runQuery(7); +} +BENCHMARK(tpcds_q8) { + tpcdsBenchmark->runQuery(8); +} +BENCHMARK(tpcds_q9) { + tpcdsBenchmark->runQuery(9); +} +BENCHMARK(tpcds_q10) { + tpcdsBenchmark->runQuery(10); +} +BENCHMARK(tpcds_q11) { + tpcdsBenchmark->runQuery(11); +} +BENCHMARK(tpcds_q12) { + tpcdsBenchmark->runQuery(12); +} +BENCHMARK(tpcds_q13) { + tpcdsBenchmark->runQuery(13); +} +BENCHMARK(tpcds_q14) { + tpcdsBenchmark->runQuery(14); +} +BENCHMARK(tpcds_q15) { + tpcdsBenchmark->runQuery(15); +} +BENCHMARK(tpcds_q16) { + tpcdsBenchmark->runQuery(16); +} +BENCHMARK(tpcds_q17) { + tpcdsBenchmark->runQuery(17); +} +BENCHMARK(tpcds_q18) { + tpcdsBenchmark->runQuery(18); +} +BENCHMARK(tpcds_q19) { + tpcdsBenchmark->runQuery(19); +} +BENCHMARK(tpcds_q20) { + tpcdsBenchmark->runQuery(20); +} +BENCHMARK(tpcds_q21) { + tpcdsBenchmark->runQuery(21); +} +BENCHMARK(tpcds_q22) { + tpcdsBenchmark->runQuery(22); +} +BENCHMARK(tpcds_q23) { + tpcdsBenchmark->runQuery(23); +} +BENCHMARK(tpcds_q24) { + tpcdsBenchmark->runQuery(24); +} +BENCHMARK(tpcds_q25) { + tpcdsBenchmark->runQuery(25); +} +BENCHMARK(tpcds_q26) { + tpcdsBenchmark->runQuery(26); +} +BENCHMARK(tpcds_q27) { + tpcdsBenchmark->runQuery(27); +} +BENCHMARK(tpcds_q28) { + tpcdsBenchmark->runQuery(28); +} +BENCHMARK(tpcds_q29) { + tpcdsBenchmark->runQuery(29); +} +BENCHMARK(tpcds_q30) { + tpcdsBenchmark->runQuery(30); +} +BENCHMARK(tpcds_q31) { + tpcdsBenchmark->runQuery(31); +} +BENCHMARK(tpcds_q32) { + tpcdsBenchmark->runQuery(32); +} +BENCHMARK(tpcds_q33) { + tpcdsBenchmark->runQuery(33); +} +BENCHMARK(tpcds_q34) { + tpcdsBenchmark->runQuery(34); +} +BENCHMARK(tpcds_q35) { + tpcdsBenchmark->runQuery(35); +} +BENCHMARK(tpcds_q36) { + tpcdsBenchmark->runQuery(36); +} +BENCHMARK(tpcds_q37) { + tpcdsBenchmark->runQuery(37); +} +BENCHMARK(tpcds_q38) { + tpcdsBenchmark->runQuery(38); +} +BENCHMARK(tpcds_q39) { + tpcdsBenchmark->runQuery(39); +} +BENCHMARK(tpcds_q40) { + tpcdsBenchmark->runQuery(40); +} +BENCHMARK(tpcds_q41) { + tpcdsBenchmark->runQuery(41); +} +BENCHMARK(tpcds_q42) { + tpcdsBenchmark->runQuery(42); +} +BENCHMARK(tpcds_q43) { + tpcdsBenchmark->runQuery(43); +} +BENCHMARK(tpcds_q44) { + tpcdsBenchmark->runQuery(44); +} +BENCHMARK(tpcds_q45) { + tpcdsBenchmark->runQuery(45); +} +BENCHMARK(tpcds_q46) { + tpcdsBenchmark->runQuery(46); +} +BENCHMARK(tpcds_q47) { + tpcdsBenchmark->runQuery(47); +} +BENCHMARK(tpcds_q48) { + tpcdsBenchmark->runQuery(48); +} +BENCHMARK(tpcds_q49) { + tpcdsBenchmark->runQuery(49); +} +BENCHMARK(tpcds_q50) { + tpcdsBenchmark->runQuery(50); +} +BENCHMARK(tpcds_q51) { + tpcdsBenchmark->runQuery(51); +} +BENCHMARK(tpcds_q52) { + tpcdsBenchmark->runQuery(52); +} +BENCHMARK(tpcds_q53) { + tpcdsBenchmark->runQuery(53); +} +BENCHMARK(tpcds_q54) { + tpcdsBenchmark->runQuery(54); +} +BENCHMARK(tpcds_q55) { + tpcdsBenchmark->runQuery(55); +} +BENCHMARK(tpcds_q56) { + tpcdsBenchmark->runQuery(56); +} +BENCHMARK(tpcds_q57) { + tpcdsBenchmark->runQuery(57); +} +BENCHMARK(tpcds_q58) { + tpcdsBenchmark->runQuery(58); +} +BENCHMARK(tpcds_q59) { + tpcdsBenchmark->runQuery(59); +} +BENCHMARK(tpcds_q60) { + tpcdsBenchmark->runQuery(60); +} +BENCHMARK(tpcds_q61) { + tpcdsBenchmark->runQuery(61); +} +BENCHMARK(tpcds_q62) { + tpcdsBenchmark->runQuery(62); +} +BENCHMARK(tpcds_q63) { + tpcdsBenchmark->runQuery(63); +} +BENCHMARK(tpcds_q64) { + tpcdsBenchmark->runQuery(64); +} +BENCHMARK(tpcds_q65) { + tpcdsBenchmark->runQuery(65); +} +BENCHMARK(tpcds_q66) { + tpcdsBenchmark->runQuery(66); +} +BENCHMARK(tpcds_q67) { + tpcdsBenchmark->runQuery(67); +} +BENCHMARK(tpcds_q68) { + tpcdsBenchmark->runQuery(68); +} +BENCHMARK(tpcds_q69) { + tpcdsBenchmark->runQuery(69); +} +BENCHMARK(tpcds_q70) { + tpcdsBenchmark->runQuery(70); +} +BENCHMARK(tpcds_q71) { + tpcdsBenchmark->runQuery(71); +} +BENCHMARK(tpcds_q72) { + tpcdsBenchmark->runQuery(72); +} +BENCHMARK(tpcds_q73) { + tpcdsBenchmark->runQuery(73); +} +BENCHMARK(tpcds_q74) { + tpcdsBenchmark->runQuery(74); +} +BENCHMARK(tpcds_q75) { + tpcdsBenchmark->runQuery(75); +} +BENCHMARK(tpcds_q76) { + tpcdsBenchmark->runQuery(76); +} +BENCHMARK(tpcds_q77) { + tpcdsBenchmark->runQuery(77); +} +BENCHMARK(tpcds_q78) { + tpcdsBenchmark->runQuery(78); +} +BENCHMARK(tpcds_q79) { + tpcdsBenchmark->runQuery(79); +} +BENCHMARK(tpcds_q80) { + tpcdsBenchmark->runQuery(80); +} +BENCHMARK(tpcds_q81) { + tpcdsBenchmark->runQuery(81); +} +BENCHMARK(tpcds_q82) { + tpcdsBenchmark->runQuery(82); +} +BENCHMARK(tpcds_q83) { + tpcdsBenchmark->runQuery(83); +} +BENCHMARK(tpcds_q84) { + tpcdsBenchmark->runQuery(84); +} +BENCHMARK(tpcds_q85) { + tpcdsBenchmark->runQuery(85); +} +BENCHMARK(tpcds_q86) { + tpcdsBenchmark->runQuery(86); +} +BENCHMARK(tpcds_q87) { + tpcdsBenchmark->runQuery(87); +} +BENCHMARK(tpcds_q88) { + tpcdsBenchmark->runQuery(88); +} +BENCHMARK(tpcds_q89) { + tpcdsBenchmark->runQuery(89); +} +BENCHMARK(tpcds_q90) { + tpcdsBenchmark->runQuery(90); +} +BENCHMARK(tpcds_q91) { + tpcdsBenchmark->runQuery(91); +} +BENCHMARK(tpcds_q92) { + tpcdsBenchmark->runQuery(92); +} +BENCHMARK(tpcds_q93) { + tpcdsBenchmark->runQuery(93); +} +BENCHMARK(tpcds_q94) { + tpcdsBenchmark->runQuery(94); +} +BENCHMARK(tpcds_q95) { + tpcdsBenchmark->runQuery(95); +} +BENCHMARK(tpcds_q96) { + tpcdsBenchmark->runQuery(96); +} +BENCHMARK(tpcds_q97) { + tpcdsBenchmark->runQuery(97); +} +BENCHMARK(tpcds_q98) { + tpcdsBenchmark->runQuery(98); +} +BENCHMARK(tpcds_q99) { + tpcdsBenchmark->runQuery(99); +} + +void tpcdsBenchmarkMain() { + VELOX_CHECK_NOT_NULL(tpcdsBenchmark); + tpcdsBenchmark->initialize(); + if (FLAGS_test_flags_file.empty()) { + RunStats ignore; + tpcdsBenchmark->runMain(std::cout, ignore); + } else { + tpcdsBenchmark->runAllCombinations(); + } + tpcdsBenchmark->shutdown(); +} diff --git a/velox/benchmarks/tpcds/TpcdsBenchmark.h b/velox/benchmarks/tpcds/TpcdsBenchmark.h new file mode 100644 index 00000000000..dc7809e851f --- /dev/null +++ b/velox/benchmarks/tpcds/TpcdsBenchmark.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +#include "velox/benchmarks/QueryBenchmarkBase.h" +#include "velox/exec/tests/utils/TpcdsQueryBuilder.h" + +DECLARE_string(plan_path); + +inline constexpr std::string_view kPrestoHiveConnectorId = "hive"; + +inline constexpr std::string_view kPrestoFunctionNamespacePrefix = + "presto.default."; + +class TpcdsBenchmark : public facebook::velox::QueryBenchmarkBase { + public: + void initialize() override; + + void shutdown() override; + + void runMain(std::ostream& out, facebook::velox::RunStats& runStats) override; + + void runQuery(int32_t queryId); + + /// Override to stamp splits with the plan's connector ID (e.g. "hive") + /// instead of the default "test-hive" used by the base class. + std::vector> + listSplits( + const std::string& path, + int32_t numSplitsPerFile, + const facebook::velox::exec::test::TpchPlan& plan) override; + + protected: + /// Override to create a different query builder (e.g. CudfTpcdsQueryBuilder). + virtual void initQueryBuilder(); + + std::unique_ptr queryBuilder_; + std::unordered_map queryConfigs_; + std::string planDir_; + std::shared_ptr pool_; +}; + +extern std::unique_ptr tpcdsBenchmark; + +void tpcdsBenchmarkMain(); diff --git a/velox/benchmarks/tpcds/TpcdsBenchmarkMain.cpp b/velox/benchmarks/tpcds/TpcdsBenchmarkMain.cpp new file mode 100644 index 00000000000..49eecd8baeb --- /dev/null +++ b/velox/benchmarks/tpcds/TpcdsBenchmarkMain.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "velox/benchmarks/tpcds/TpcdsBenchmark.h" + +int main(int argc, char** argv) { + std::string kUsage( + "This program benchmarks TPC-DS queries. Run " + "'velox_tpcds_benchmark -helpon=TpcdsBenchmark' for available options.\n"); + gflags::SetUsageMessage(kUsage); + folly::Init init{&argc, &argv, false}; + tpcdsBenchmark = std::make_unique(); + tpcdsBenchmarkMain(); +} diff --git a/velox/benchmarks/tpch/CMakeLists.txt b/velox/benchmarks/tpch/CMakeLists.txt index 1dbe29ea7d2..326fe814d40 100644 --- a/velox/benchmarks/tpch/CMakeLists.txt +++ b/velox/benchmarks/tpch/CMakeLists.txt @@ -13,6 +13,7 @@ # limitations under the License. add_library(velox_tpch_benchmark_lib TpchBenchmark.cpp) +velox_add_test_headers(velox_tpch_benchmark_lib TpchBenchmark.h) target_link_libraries( velox_tpch_benchmark_lib diff --git a/velox/benchmarks/tpch/TpchBenchmark.cpp b/velox/benchmarks/tpch/TpchBenchmark.cpp index 1f149f068ec..96935d9b7c6 100644 --- a/velox/benchmarks/tpch/TpchBenchmark.cpp +++ b/velox/benchmarks/tpch/TpchBenchmark.cpp @@ -14,8 +14,9 @@ * limitations under the License. */ +#include "velox/benchmarks/tpch/TpchBenchmark.h" #include -#include "velox/benchmarks/QueryBenchmarkBase.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/PlanNodeStats.h" using namespace facebook::velox; @@ -59,178 +60,165 @@ DEFINE_int32( "include in IO meter query. The columns are sorted by name and the n% first " "are scanned"); -std::shared_ptr queryBuilder; - -class TpchBenchmark : public QueryBenchmarkBase { - public: - void runMain(std::ostream& out, RunStats& runStats) override { - if (FLAGS_run_query_verbose == -1 && FLAGS_io_meter_column_pct == 0) { - folly::runBenchmarks(); - } else { - const auto queryPlan = FLAGS_io_meter_column_pct > 0 - ? queryBuilder->getIoMeterPlan(FLAGS_io_meter_column_pct) - : queryBuilder->getQueryPlan(FLAGS_run_query_verbose); - auto [cursor, actualResults] = run(queryPlan); - if (!cursor) { - LOG(ERROR) << "Query terminated with error. Exiting"; - exit(1); - } - auto task = cursor->task(); - ensureTaskCompletion(task.get()); - if (FLAGS_include_results) { - printResults(actualResults, out); - out << std::endl; - } - const auto stats = task->taskStats(); - int64_t rawInputBytes = 0; - for (auto& pipeline : stats.pipelineStats) { - auto& first = pipeline.operatorStats[0]; - if (first.operatorType == "TableScan") { - rawInputBytes += first.rawInputBytes; - } +void TpchBenchmark::initQueryBuilder() { + queryBuilder_ = + std::make_shared(toFileFormat(FLAGS_data_format)); + queryBuilder_->initialize(FLAGS_data_path); +} + +void TpchBenchmark::initialize() { + QueryBenchmarkBase::initialize(); + initQueryBuilder(); +} + +void TpchBenchmark::shutdown() { + QueryBenchmarkBase::shutdown(); + queryBuilder_.reset(); +} + +void TpchBenchmark::runMain( + std::ostream& out, + facebook::velox::RunStats& runStats) { + if (FLAGS_run_query_verbose == -1 && FLAGS_io_meter_column_pct == 0) { + folly::runBenchmarks(); + } else { + auto queryPlan = FLAGS_io_meter_column_pct > 0 + ? queryBuilder_->getIoMeterPlan(FLAGS_io_meter_column_pct) + : queryBuilder_->getQueryPlan(FLAGS_run_query_verbose); + auto [cursor, actualResults] = run(queryPlan, queryConfigs_); + if (!cursor) { + LOG(ERROR) << "Query terminated with error. Exiting"; + exit(1); + } + auto task = cursor->task(); + ensureTaskCompletion(task.get()); + if (FLAGS_include_results) { + printResults(actualResults, out); + out << std::endl; + } + const auto stats = task->taskStats(); + int64_t rawInputBytes = 0; + for (auto& pipeline : stats.pipelineStats) { + auto& first = pipeline.operatorStats[0]; + if (first.operatorType == OperatorType::kTableScan) { + rawInputBytes += first.rawInputBytes; } - runStats.rawInputBytes = rawInputBytes; - out << fmt::format( - "Execution time: {}", - succinctMillis( - stats.executionEndTimeMs - stats.executionStartTimeMs)) - << std::endl; - out << fmt::format( - "Splits total: {}, finished: {}", - stats.numTotalSplits, - stats.numFinishedSplits) - << std::endl; - out << printPlanWithStats( - *queryPlan.plan, stats, FLAGS_include_custom_stats) - << std::endl; } + runStats.rawInputBytes = rawInputBytes; + out << fmt::format( + "Execution time: {}", + facebook::velox::succinctMillis( + stats.executionEndTimeMs - stats.executionStartTimeMs)) + << std::endl; + out << fmt::format( + "Splits total: {}, finished: {}", + stats.numTotalSplits, + stats.numFinishedSplits) + << std::endl; + out << printPlanWithStats( + *queryPlan.plan, stats, FLAGS_include_custom_stats) + << std::endl; } -}; +} -TpchBenchmark benchmark; +std::unique_ptr benchmark; BENCHMARK(q1) { - const auto planContext = queryBuilder->getQueryPlan(1); - benchmark.run(planContext); + benchmark->runQuery(1); } BENCHMARK(q2) { - const auto planContext = queryBuilder->getQueryPlan(2); - benchmark.run(planContext); + benchmark->runQuery(2); } BENCHMARK(q3) { - const auto planContext = queryBuilder->getQueryPlan(3); - benchmark.run(planContext); + benchmark->runQuery(3); } BENCHMARK(q4) { - const auto planContext = queryBuilder->getQueryPlan(4); - benchmark.run(planContext); + benchmark->runQuery(4); } BENCHMARK(q5) { - const auto planContext = queryBuilder->getQueryPlan(5); - benchmark.run(planContext); + benchmark->runQuery(5); } BENCHMARK(q6) { - const auto planContext = queryBuilder->getQueryPlan(6); - benchmark.run(planContext); + benchmark->runQuery(6); } BENCHMARK(q7) { - const auto planContext = queryBuilder->getQueryPlan(7); - benchmark.run(planContext); + benchmark->runQuery(7); } BENCHMARK(q8) { - const auto planContext = queryBuilder->getQueryPlan(8); - benchmark.run(planContext); + benchmark->runQuery(8); } BENCHMARK(q9) { - const auto planContext = queryBuilder->getQueryPlan(9); - benchmark.run(planContext); + benchmark->runQuery(9); } BENCHMARK(q10) { - const auto planContext = queryBuilder->getQueryPlan(10); - benchmark.run(planContext); + benchmark->runQuery(10); } BENCHMARK(q11) { - const auto planContext = queryBuilder->getQueryPlan(11); - benchmark.run(planContext); + benchmark->runQuery(11); } BENCHMARK(q12) { - const auto planContext = queryBuilder->getQueryPlan(12); - benchmark.run(planContext); + benchmark->runQuery(12); } BENCHMARK(q13) { - const auto planContext = queryBuilder->getQueryPlan(13); - benchmark.run(planContext); + benchmark->runQuery(13); } BENCHMARK(q14) { - const auto planContext = queryBuilder->getQueryPlan(14); - benchmark.run(planContext); + benchmark->runQuery(14); } BENCHMARK(q15) { - const auto planContext = queryBuilder->getQueryPlan(15); - benchmark.run(planContext); + benchmark->runQuery(15); } BENCHMARK(q16) { - const auto planContext = queryBuilder->getQueryPlan(16); - benchmark.run(planContext); + benchmark->runQuery(16); } BENCHMARK(q17) { - const auto planContext = queryBuilder->getQueryPlan(17); - benchmark.run(planContext); + benchmark->runQuery(17); } BENCHMARK(q18) { - const auto planContext = queryBuilder->getQueryPlan(18); - benchmark.run(planContext); + benchmark->runQuery(18); } BENCHMARK(q19) { - const auto planContext = queryBuilder->getQueryPlan(19); - benchmark.run(planContext); + benchmark->runQuery(19); } BENCHMARK(q20) { - const auto planContext = queryBuilder->getQueryPlan(20); - benchmark.run(planContext); + benchmark->runQuery(20); } BENCHMARK(q21) { - const auto planContext = queryBuilder->getQueryPlan(21); - benchmark.run(planContext); + benchmark->runQuery(21); } BENCHMARK(q22) { - const auto planContext = queryBuilder->getQueryPlan(22); - benchmark.run(planContext); + benchmark->runQuery(22); } -int tpchBenchmarkMain() { - benchmark.initialize(); - queryBuilder = - std::make_shared(toFileFormat(FLAGS_data_format)); - queryBuilder->initialize(FLAGS_data_path); +void tpchBenchmarkMain() { + VELOX_CHECK_NOT_NULL(benchmark); + benchmark->initialize(); if (FLAGS_test_flags_file.empty()) { RunStats ignore; - benchmark.runMain(std::cout, ignore); + benchmark->runMain(std::cout, ignore); } else { - benchmark.runAllCombinations(); + benchmark->runAllCombinations(); } - benchmark.shutdown(); - queryBuilder.reset(); - return 0; + benchmark->shutdown(); } diff --git a/velox/benchmarks/tpch/TpchBenchmark.h b/velox/benchmarks/tpch/TpchBenchmark.h index e66e7c53cbc..7297f7db17a 100644 --- a/velox/benchmarks/tpch/TpchBenchmark.h +++ b/velox/benchmarks/tpch/TpchBenchmark.h @@ -15,4 +15,31 @@ */ #pragma once +#include "velox/benchmarks/QueryBenchmarkBase.h" +#include "velox/exec/tests/utils/TpchQueryBuilder.h" + +class TpchBenchmark : public facebook::velox::QueryBenchmarkBase { + public: + void initialize() override; + + void shutdown() override; + + void runMain(std::ostream& out, facebook::velox::RunStats& runStats) override; + + void runQuery(int32_t queryId) { + const auto planContext = queryBuilder_->getQueryPlan(queryId); + run(planContext, queryConfigs_); + } + + protected: + std::unordered_map queryConfigs_; + + private: + void initQueryBuilder(); + + std::shared_ptr queryBuilder_; +}; + +extern std::unique_ptr benchmark; + void tpchBenchmarkMain(); diff --git a/velox/benchmarks/tpch/TpchBenchmarkMain.cpp b/velox/benchmarks/tpch/TpchBenchmarkMain.cpp index 4477455d8f3..0fa9718ca7b 100644 --- a/velox/benchmarks/tpch/TpchBenchmarkMain.cpp +++ b/velox/benchmarks/tpch/TpchBenchmarkMain.cpp @@ -24,5 +24,6 @@ int main(int argc, char** argv) { "This program benchmarks TPC-H queries. Run 'velox_tpch_benchmark -helpon=TpchBenchmark' for available options.\n"); gflags::SetUsageMessage(kUsage); folly::Init init{&argc, &argv, false}; + benchmark = std::make_unique(); tpchBenchmarkMain(); } diff --git a/velox/buffer/Buffer.cpp b/velox/buffer/Buffer.cpp index 80695929540..abb7a3f09d6 100644 --- a/velox/buffer/Buffer.cpp +++ b/velox/buffer/Buffer.cpp @@ -18,9 +18,24 @@ namespace facebook::velox { +std::string Buffer::typeString(Type type) { + switch (type) { + case Type::kPOD: + return "kPOD"; + case Type::kNonPOD: + return "kNonPOD"; + case Type::kPODView: + return "kPODView"; + case Type::kNonPODView: + return "kNonPODView"; + default: + return fmt::format("Unknown({})", static_cast(type)); + } +} + namespace { struct BufferReleaser { - explicit BufferReleaser(const BufferPtr& parent) : parent_(parent) {} + explicit BufferReleaser(BufferPtr parent) : parent_{std::move(parent)} {} void addRef() const {} void release() const {} diff --git a/velox/buffer/Buffer.h b/velox/buffer/Buffer.h index 48cd0a62909..85ceab56712 100644 --- a/velox/buffer/Buffer.h +++ b/velox/buffer/Buffer.h @@ -19,6 +19,8 @@ #include #include +#include +#include #include "velox/common/base/BitUtil.h" #include "velox/common/base/CheckedArithmetic.h" #include "velox/common/base/Exceptions.h" @@ -26,8 +28,7 @@ #include "velox/common/base/SimdUtil.h" #include "velox/common/memory/Memory.h" -namespace facebook { -namespace velox { +namespace facebook::velox { class Buffer; class AlignedBuffer; @@ -58,21 +59,40 @@ class Buffer { // type. Thus the conditions are: trivial destructor (no resources to release) // and trivially copyable (so memcpy works) template - static inline constexpr bool is_pod_like_v = + static constexpr bool is_pod_like_v = std::is_trivially_destructible_v && std::is_trivially_copyable_v; - virtual ~Buffer() {} + virtual ~Buffer() = default; - void addRef() { - referenceCount_.fetch_add(1); + static constexpr uint8_t kPODBit = 0; + static constexpr uint8_t kPODMask = 1 << kPODBit; + static constexpr uint8_t kViewBit = 1; + static constexpr uint8_t kViewMask = 1 << kViewBit; + static_assert(kPODBit != kViewBit); + + enum class Type : uint8_t { + kNonPOD = 0 << kPODBit | 0 << kViewBit, + kPOD = 1 << kPODBit | 0 << kViewBit, + kNonPODView = 0 << kPODBit | 1 << kViewBit, + kPODView = 1 << kPODBit | 1 << kViewBit, + }; + + static std::string typeString(Type type); + + Type type() const { + return type_; + } + + void addRef() noexcept { + referenceCount_.fetch_add(1, std::memory_order_acq_rel); } - int refCount() const { - return referenceCount_; + int refCount() const noexcept { + return referenceCount_.load(std::memory_order_acquire); } void release() { - if (referenceCount_.fetch_sub(1) == 1) { + if (referenceCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) { releaseResources(); if (pool_) { freeToPool(); @@ -86,13 +106,13 @@ class Buffer { const T* as() const { // We can't check actual types, but we can sanity-check POD/non-POD // conversion. `void` is special as it's used in type-erased contexts - VELOX_DCHECK((std::is_same_v) || podType_ == is_pod_like_v); + VELOX_DCHECK(std::is_void_v || isPOD() == is_pod_like_v); return reinterpret_cast(data_); } template Range asRange() { - return Range(as(), 0, size() / sizeof(T)); + return {as(), 0, static_cast(size() / sizeof(T))}; } template @@ -102,16 +122,16 @@ class Buffer { VELOX_CHECK(!isView()); // We can't check actual types, but we can sanity-check POD/non-POD // conversion. `void` is special as it's used in type-erased contexts - VELOX_DCHECK((std::is_same_v) || podType_ == is_pod_like_v); + VELOX_DCHECK(std::is_void_v || isPOD() == is_pod_like_v); return reinterpret_cast(data_); } template MutableRange asMutableRange() { - return MutableRange(asMutable(), 0, size() / sizeof(T)); + return {asMutable(), 0, static_cast(size() / sizeof(T))}; } - size_t size() const { + size_t size() const noexcept { return size_; } @@ -126,24 +146,28 @@ class Buffer { checkEndGuard(); } - uint64_t capacity() const { + uint64_t capacity() const noexcept { return capacity_; } - bool unique() const { - return referenceCount_ == 1; + bool unique() const noexcept { + return refCount() == 1; } - velox::memory::MemoryPool* pool() const { + velox::memory::MemoryPool* pool() const noexcept { return pool_; } - bool isMutable() const { + bool isMutable() const noexcept { return !isView() && unique(); } - virtual bool isView() const { - return false; + bool isView() const { + return (static_cast(type_) & kViewMask) != 0; + } + + bool isPOD() const { + return (static_cast(type_) & kPODMask) != 0; } friend std::ostream& operator<<(std::ostream& os, const Buffer& buffer) { @@ -199,6 +223,14 @@ class Buffer { sizeof(T), is_pod_like_v, buffer, offset, length); } + /// Transfers this buffer to 'pool'. Returns true if the transfer succeeds, or + /// false if the transfer fails. A buffer can be transferred to 'pool' if its + /// original pool and 'pool' are from the same MemoryAllocator and the buffer + /// is not a BufferView. + virtual bool transferTo(velox::memory::MemoryPool* /*pool*/) { + VELOX_NYI("{} unsupported", __FUNCTION__); + } + protected: // Writes a magic word at 'capacity_'. No-op for a BufferView. The actual // logic is inside a separate virtual function, allowing override by derived @@ -241,7 +273,7 @@ class Buffer { virtual void copyFrom(const Buffer* other, size_t bytes) { VELOX_CHECK(!isView()); VELOX_CHECK_GE(capacity_, bytes); - VELOX_CHECK(podType_); + VELOX_CHECK_EQ(type_, Type::kPOD); memcpy(data_, other->data_, bytes); } @@ -252,27 +284,24 @@ class Buffer { } Buffer( - velox::memory::MemoryPool* pool, + Type type, uint8_t* data, size_t capacity, - bool podType) - : pool_(pool), - data_(data), - capacity_(capacity), - referenceCount_(0), - podType_(podType) {} + velox::memory::MemoryPool* pool) + : pool_{pool}, data_{data}, capacity_{capacity}, type_{type} {} velox::memory::MemoryPool* const pool_; uint8_t* const data_; - uint64_t size_ = 0; - uint64_t capacity_ = 0; - std::atomic referenceCount_; - bool podType_ = true; - // Pad to 64 bytes. If using as int32_t[], guarantee that value at index -1 == - // -1. - uint64_t padding_[2] = {static_cast(-1), static_cast(-1)}; - // Needs to use setCapacity() from static method reallocate(). - friend class AlignedBuffer; + + uint64_t size_{0}; + uint64_t capacity_; + std::atomic_int32_t referenceCount_{0}; + + const Type type_; + + // Pad to 64 bytes. + // If using as int32_t[], guarantee that value at index -1 == -1. + uint64_t padding_[2]{static_cast(-1), static_cast(-1)}; private: static BufferPtr sliceBufferZeroCopy( @@ -281,6 +310,9 @@ class Buffer { const BufferPtr& buffer, size_t offset, size_t length); + + // Needs to use setCapacity() from static method reallocate(). + friend class AlignedBuffer; }; static_assert( @@ -289,12 +321,12 @@ static_assert( template <> inline Range Buffer::asRange() { - return Range(as(), 0, size() * 8); + return {as(), 0, static_cast(size() * 8)}; } template <> inline MutableRange Buffer::asMutableRange() { - return MutableRange(asMutable(), 0, size() * 8); + return {asMutable(), 0, static_cast(size() * 8)}; } template <> @@ -304,11 +336,11 @@ BufferPtr Buffer::slice( size_t length, memory::MemoryPool* pool); -static inline void intrusive_ptr_add_ref(Buffer* buffer) { +FOLLY_ALWAYS_INLINE void intrusive_ptr_add_ref(Buffer* buffer) noexcept { buffer->addRef(); } -static inline void intrusive_ptr_release(Buffer* buffer) { +FOLLY_ALWAYS_INLINE void intrusive_ptr_release(Buffer* buffer) noexcept { buffer->release(); } @@ -325,7 +357,7 @@ class AlignedBuffer : public Buffer { static constexpr int32_t kSizeofAlignedBuffer = 64; static constexpr int32_t kPaddedSize = kSizeofAlignedBuffer + simd::kPadding; - ~AlignedBuffer() { + ~AlignedBuffer() override { // This may throw, which is expected to signal an error to the // user. This is better for distributed debugging than killing the // process. In concept this indicates the possibility of memory @@ -337,10 +369,8 @@ class AlignedBuffer : public Buffer { // It's almost like partial specialization, but we redirect all POD types to // the same non-templated class template - using ImplClass = typename std::conditional< - is_pod_like_v, - AlignedBuffer, - NonPODAlignedBuffer>::type; + using ImplClass = std:: + conditional_t, AlignedBuffer, NonPODAlignedBuffer>; /** * Allocates enough memory to store numElements of type T. May @@ -368,7 +398,7 @@ class AlignedBuffer : public Buffer { void* memory = pool->allocate(preferredSize); VELOX_CHECK_NOT_NULL(memory); - auto* buffer = new (memory) ImplClass(pool, preferredSize - kPaddedSize); + auto* buffer = new (memory) ImplClass{pool, preferredSize - kPaddedSize}; // set size explicitly instead of setSize because `fillNewMemory` already // called the constructors buffer->size_ = size; @@ -377,6 +407,18 @@ class AlignedBuffer : public Buffer { return result; } + /// A verbose version of the allocate() with the exact size. + /// May allocate slightly more memory than strictly necessary. Guarantees that + /// simd::kPadding bytes past capacity() are addressable and asserts that + /// these do not get overrun. + template + static BufferPtr allocateExact( + size_t numElements, + velox::memory::MemoryPool* pool, + const std::optional& initValue = std::nullopt) { + return allocate(numElements, pool, initValue, true); + } + // Changes the capacity of '*buffer'. The buffer may grow/shrink in // place or may change addresses. The content is copied up to the // old size() or the new size, whichever is smaller. If the buffer grows, the @@ -418,34 +460,32 @@ class AlignedBuffer : public Buffer { // called the constructors newBuffer->size_ = size; *buffer = std::move(newBuffer); - return; - } - if (!old->unique()) { + } else if (!old->unique()) { auto newBuffer = allocate(numElements, pool); newBuffer->copyFrom(old, std::min(size, old->size())); reinterpret_cast(newBuffer.get()) ->template fillNewMemory(old->size(), size, initValue); newBuffer->size_ = size; *buffer = std::move(newBuffer); - return; - } - auto oldCapacity = checkedPlus(old->capacity(), kPaddedSize); - auto preferredSize = - pool->preferredSize(checkedPlus(size, kPaddedSize)); + } else { + auto oldCapacity = checkedPlus(old->capacity(), kPaddedSize); + auto preferredSize = + pool->preferredSize(checkedPlus(size, kPaddedSize)); - void* newPtr = pool->reallocate(old, oldCapacity, preferredSize); + void* newPtr = pool->reallocate(old, oldCapacity, preferredSize); - // Make the old buffer no longer owned by '*buffer' because reallocate - // freed the old buffer. Reassigning the new buffer to - // '*buffer' would be a double free if we didn't do this. - buffer->detach(); + // Make the old buffer no longer owned by '*buffer' because reallocate + // freed the old buffer. Reassigning the new buffer to + // '*buffer' would be a double free if we didn't do this. + buffer->detach(); - auto newBuffer = - new (newPtr) AlignedBuffer(pool, preferredSize - kPaddedSize); - newBuffer->setSize(size); - newBuffer->fillNewMemory(oldSize, size, initValue); + auto newBuffer = + new (newPtr) AlignedBuffer{pool, preferredSize - kPaddedSize}; + newBuffer->setSize(size); + newBuffer->fillNewMemory(oldSize, size, initValue); - *buffer = newBuffer; + *buffer = newBuffer; + } } // Appends bytes starting at 'items' for a length of 'sizeof(T) * @@ -480,7 +520,7 @@ class AlignedBuffer : public Buffer { } VELOX_CHECK( - bufferPtr->podType_, "Support for non POD types not implemented yet"); + bufferPtr->isPOD(), "Support for non POD types not implemented yet"); // The reason we use uint8_t is because mutableNulls()->size() will return // in byte count. We also don't bother initializing since copyFrom will be @@ -492,13 +532,49 @@ class AlignedBuffer : public Buffer { return newBuffer; } + template + static BufferPtr copy( + const BufferPtr& buffer, + velox::memory::MemoryPool* pool) { + if (buffer == nullptr) { + return nullptr; + } + + // The reason we use uint8_t is because mutableNulls()->size() will return + // in byte count. We also don't bother initializing since copyFrom will be + // overwriting anyway. + BufferPtr newBuffer; + if constexpr (std::is_same_v) { + newBuffer = AlignedBuffer::allocate(buffer->size(), pool); + } else { + const auto numElements = checkedDivide(buffer->size(), sizeof(T)); + newBuffer = AlignedBuffer::allocate(numElements, pool); + } + + newBuffer->copyFrom(buffer.get(), newBuffer->size()); + + return newBuffer; + } + + bool transferTo(velox::memory::MemoryPool* pool) override { + if (pool_ == pool) { + return true; + } + if (pool_->transferTo( + pool, this, checkedPlus(kPaddedSize, capacity_))) { + setPool(pool); + return true; + } + return false; + } + protected: AlignedBuffer(velox::memory::MemoryPool* pool, size_t capacity) - : Buffer( - pool, + : Buffer{ + Type::kPOD, reinterpret_cast(this) + sizeof(*this), capacity, - true /*podType*/) { + pool} { static_assert(sizeof(*this) == kAlignment); static_assert(sizeof(*this) == kSizeofAlignedBuffer); setEndGuard(); @@ -532,7 +608,12 @@ class AlignedBuffer : public Buffer { } } - protected: + void setPool(velox::memory::MemoryPool* pool) { + velox::memory::MemoryPool** poolPtr = + const_cast(&pool_); + *poolPtr = pool; + } + void setEndGuardImpl() override { *reinterpret_cast(data_ + capacity_) = kEndGuard; } @@ -597,13 +678,30 @@ class NonPODAlignedBuffer : public Buffer { } } + bool transferTo(velox::memory::MemoryPool* pool) override { + if (pool_ == pool) { + return true; + } + + if (pool_->transferTo( + pool, + this, + checkedPlus(AlignedBuffer::kPaddedSize, capacity_))) { + velox::memory::MemoryPool** poolPtr = + const_cast(&pool_); + *poolPtr = pool; + return true; + } + return false; + } + protected: NonPODAlignedBuffer(velox::memory::MemoryPool* pool, size_t capacity) - : Buffer( - pool, + : Buffer{ + Type::kNonPOD, reinterpret_cast(this) + sizeof(*this), capacity, - false /*podType*/) { + pool} { static_assert(sizeof(*this) == AlignedBuffer::kAlignment); static_assert(sizeof(*this) == sizeof(AlignedBuffer)); } @@ -611,8 +709,8 @@ class NonPODAlignedBuffer : public Buffer { void releaseResources() override { VELOX_CHECK_EQ(size_ % sizeof(T), 0); size_t numValues = size_ / sizeof(T); - // we can't use asMutable because it checks isMutable and we wan't to - // destroy regardless + // we can't use asMutable because it checks isMutable and we wan't + // to destroy regardless T* ptr = reinterpret_cast(data_); for (int i = 0; i < numValues; ++i) { ptr[i].~T(); @@ -620,6 +718,8 @@ class NonPODAlignedBuffer : public Buffer { } void copyFrom(const Buffer* other, size_t bytes) override { + // TODO: change this to isMutable(). See + // https://github.com/facebookincubator/velox/issues/6562. VELOX_CHECK(!isView()); VELOX_CHECK_GE(size_, bytes); VELOX_DCHECK( @@ -676,47 +776,60 @@ class NonPODAlignedBuffer : public Buffer { template class BufferView : public Buffer { public: - static BufferPtr create( - const uint8_t* data, - size_t size, - Releaser releaser, - bool podType = true) { - BufferView* view = new BufferView(data, size, releaser, podType); - BufferPtr result(view); + template + static BufferPtr + create(const uint8_t* data, size_t size, R&& releaser, bool podType = true) { + auto* view = new BufferView{data, size, std::forward(releaser), podType}; + BufferPtr result{view}; return result; } // Helper method to create a buffer view referencing another existing Buffer. + template static BufferPtr - create(BufferPtr innerBuffer, Releaser releaser, bool podType = true) { + create(const BufferPtr& innerBuffer, R&& releaser, bool podType = true) { return create( - innerBuffer->as(), innerBuffer->size(), releaser, podType); + innerBuffer->as(), + innerBuffer->size(), + std::forward(releaser), + podType); } ~BufferView() override { releaser_.release(); } - bool isView() const override { - return true; + bool transferTo(velox::memory::MemoryPool* pool) override { + if (pool_ == pool) { + return true; + } + return false; } private: - BufferView(const uint8_t* data, size_t size, Releaser releaser, bool podType) + template + BufferView(const uint8_t* data, size_t size, R&& releaser, bool podType) // A BufferView must be created over the data held by a cache // pin, which is typically const. The Buffer enforces const-ness // when returning the pointer. We cast away the const here to // avoid a separate code path for const and non-const Buffer // payloads. - : Buffer(nullptr, const_cast(data), size, podType), - releaser_(releaser) { + : Buffer{podType ? Type::kPODView : Type::kNonPODView, const_cast(data), size, nullptr}, + releaser_{std::forward(releaser)} { size_ = size; - capacity_ = size; releaser_.addRef(); } - Releaser const releaser_; + [[no_unique_address]] const Releaser releaser_; }; -} // namespace velox -} // namespace facebook +} // namespace facebook::velox + +// fmt formatter specialization for Buffer::Type +template <> +struct fmt::formatter : formatter { + auto format(facebook::velox::Buffer::Type s, format_context& ctx) const { + return formatter::format( + facebook::velox::Buffer::typeString(s), ctx); + } +}; diff --git a/velox/buffer/BufferPool.cpp b/velox/buffer/BufferPool.cpp new file mode 100644 index 00000000000..71525c80176 --- /dev/null +++ b/velox/buffer/BufferPool.cpp @@ -0,0 +1,53 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/buffer/BufferPool.h" + +namespace facebook::velox { + +BufferPtr BufferPool::get(uint64_t minBytes) { + for (size_t i = 0; i < buffers_.size(); ++i) { + if (buffers_[i]->capacity() >= minBytes) { + auto result = std::move(buffers_[i]); + buffers_[i] = std::move(buffers_.back()); + buffers_.pop_back(); + return result; + } + } + return nullptr; +} + +BufferPtr BufferPool::get() { + if (buffers_.empty()) { + return nullptr; + } + auto result = std::move(buffers_.back()); + buffers_.pop_back(); + return result; +} + +void BufferPool::release(BufferPtr&& buffer) { + if (buffer == nullptr || !buffer->unique() || buffers_.size() >= kMaxCached) { + return; + } + buffers_.push_back(std::move(buffer)); +} + +size_t BufferPool::size() const { + return buffers_.size(); +} + +} // namespace facebook::velox diff --git a/velox/buffer/BufferPool.h b/velox/buffer/BufferPool.h new file mode 100644 index 00000000000..bf5c1081e89 --- /dev/null +++ b/velox/buffer/BufferPool.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include "velox/buffer/Buffer.h" + +namespace facebook::velox { + +/// Caches BufferPtr objects for reuse across encoding lifetimes. +/// When encoding objects are destroyed, their scratch buffers can be returned +/// to the pool instead of being freed. New encodings can then grab +/// pre-allocated buffers from the pool, avoiding MemoryPool allocation +/// overhead (including mutex contention). +/// +/// Thread safety: NOT thread-safe. Intended for use within a single +/// deserialization stream. +class BufferPool { + public: + /// Returns a cached buffer with capacity >= minBytes, or nullptr if none + /// available. + BufferPtr get(uint64_t minBytes); + + /// Returns a cached buffer (any size), or nullptr if none available. + BufferPtr get(); + + /// Returns a buffer to the pool for future reuse. Drops the buffer if the + /// pool is full or if 'buffer' is null. The buffer must have a unique + /// reference (refcount == 1) to ensure safe reuse. + void release(BufferPtr&& buffer); + + /// Returns the number of cached buffers currently in the pool. + size_t size() const; + + private: + // Maximum number of BufferPtr objects retained in the pool. + static constexpr size_t kMaxCached = 32; + // Cached buffers available for reuse. + std::vector buffers_; +}; + +} // namespace facebook::velox diff --git a/velox/buffer/CMakeLists.txt b/velox/buffer/CMakeLists.txt index ec15197d4f5..01d59ef258e 100644 --- a/velox/buffer/CMakeLists.txt +++ b/velox/buffer/CMakeLists.txt @@ -12,7 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -velox_add_library(velox_buffer Buffer.cpp StringViewBufferHolder.cpp) +velox_add_library( + velox_buffer + Buffer.cpp + BufferPool.cpp + StringViewBufferHolder.cpp + HEADERS + Buffer.h + BufferPool.h + StringViewBufferHolder.h +) velox_link_libraries(velox_buffer velox_memory velox_common_base Folly::folly) diff --git a/velox/buffer/StringViewBufferHolder.cpp b/velox/buffer/StringViewBufferHolder.cpp index 83479bba0a7..750b011b2c1 100644 --- a/velox/buffer/StringViewBufferHolder.cpp +++ b/velox/buffer/StringViewBufferHolder.cpp @@ -26,8 +26,9 @@ StringView StringViewBufferHolder::getOwnedStringView( if (stringBuffers_.empty() || stringBuffers_.back()->size() + size > stringBuffers_.back()->capacity()) { - stringBuffers_.push_back(AlignedBuffer::allocate( - std::max(size, kInitialStringReservation), pool_)); + stringBuffers_.push_back( + AlignedBuffer::allocate( + std::max(size, kInitialStringReservation), pool_)); stringBuffers_.back()->setSize(0); } auto stringBuffer = stringBuffers_.back().get(); diff --git a/velox/buffer/StringViewBufferHolder.h b/velox/buffer/StringViewBufferHolder.h index 99a1f3276ce..962b484fd9e 100644 --- a/velox/buffer/StringViewBufferHolder.h +++ b/velox/buffer/StringViewBufferHolder.h @@ -33,7 +33,7 @@ class StringViewBufferHolder { /// Return a copy of the StringView where the StringView is copied to this /// StringViewBufferHolder if the StringView is not inlined. std::string and - /// folly::StringPiece are also copied to the internal buffers (see the + /// std::string_view are also copied to the internal buffers (see the /// specializations below). /// /// NOTE: Out of convenience, we allow different types to be passed in, but @@ -52,8 +52,8 @@ class StringViewBufferHolder { return getOwnedStringView(value.data(), value.size()); } - /// Specialization for folly::StringPiece type. - StringView getOwnedValue(folly::StringPiece value) { + /// Specialization for std::string_view type. + StringView getOwnedValue(std::string_view value) { return getOwnedStringView(value.data(), value.size()); } diff --git a/velox/buffer/tests/BufferPoolTest.cpp b/velox/buffer/tests/BufferPoolTest.cpp new file mode 100644 index 00000000000..07ff3468cf3 --- /dev/null +++ b/velox/buffer/tests/BufferPoolTest.cpp @@ -0,0 +1,152 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/buffer/BufferPool.h" + +#include + +#include "velox/common/memory/Memory.h" + +namespace facebook::velox::test { + +class BufferPoolTest : public ::testing::Test { + protected: + void SetUp() override { + pool_ = memoryManager_.addLeafPool("BufferPoolTest"); + } + + memory::MemoryManager memoryManager_; + std::shared_ptr pool_; +}; + +TEST_F(BufferPoolTest, emptyPool) { + BufferPool bufferPool; + EXPECT_EQ(bufferPool.size(), 0); + EXPECT_EQ(bufferPool.get(), nullptr); + EXPECT_EQ(bufferPool.get(100), nullptr); +} + +TEST_F(BufferPoolTest, recycleAndGet) { + BufferPool bufferPool; + auto buffer = AlignedBuffer::allocate(1'024, pool_.get()); + auto* rawPtr = buffer.get(); + const auto capacity = buffer->capacity(); + + bufferPool.release(std::move(buffer)); + EXPECT_EQ(bufferPool.size(), 1); + + auto retrieved = bufferPool.get(); + EXPECT_NE(retrieved, nullptr); + EXPECT_EQ(retrieved.get(), rawPtr); + EXPECT_EQ(retrieved->capacity(), capacity); + EXPECT_EQ(bufferPool.size(), 0); +} + +TEST_F(BufferPoolTest, getWithMinBytes) { + BufferPool bufferPool; + + auto small = AlignedBuffer::allocate(128, pool_.get()); + auto large = AlignedBuffer::allocate(4'096, pool_.get()); + auto* largePtr = large.get(); + + bufferPool.release(std::move(small)); + bufferPool.release(std::move(large)); + EXPECT_EQ(bufferPool.size(), 2); + + // Request minBytes that only the large buffer satisfies. + auto retrieved = bufferPool.get(1'024); + EXPECT_NE(retrieved, nullptr); + EXPECT_EQ(retrieved.get(), largePtr); + EXPECT_EQ(bufferPool.size(), 1); + + // The small buffer is still there. + auto remaining = bufferPool.get(); + EXPECT_NE(remaining, nullptr); + EXPECT_EQ(bufferPool.size(), 0); +} + +TEST_F(BufferPoolTest, getWithMinBytesNoMatch) { + BufferPool bufferPool; + + auto small = AlignedBuffer::allocate(128, pool_.get()); + bufferPool.release(std::move(small)); + + auto retrieved = bufferPool.get(1'000'000); + EXPECT_EQ(retrieved, nullptr); + EXPECT_EQ(bufferPool.size(), 1); +} + +TEST_F(BufferPoolTest, recycleNullptr) { + BufferPool bufferPool; + BufferPtr nullBuffer; + bufferPool.release(std::move(nullBuffer)); + EXPECT_EQ(bufferPool.size(), 0); +} + +TEST_F(BufferPoolTest, releaseNonUniqueBuffer) { + BufferPool bufferPool; + auto buffer = AlignedBuffer::allocate(1'024, pool_.get()); + auto copy = buffer; + EXPECT_FALSE(buffer->unique()); + bufferPool.release(std::move(buffer)); + EXPECT_EQ(bufferPool.size(), 0); +} + +TEST_F(BufferPoolTest, releaseMoveSemantics) { + BufferPool bufferPool; + auto buffer = AlignedBuffer::allocate(1'024, pool_.get()); + auto* rawPtr = buffer.get(); + EXPECT_TRUE(buffer->unique()); + // NOLINTNEXTLINE(bugprone-use-after-move) + bufferPool.release(std::move(buffer)); + EXPECT_EQ(buffer, nullptr); + EXPECT_EQ(bufferPool.size(), 1); + auto retrieved = bufferPool.get(); + EXPECT_EQ(retrieved.get(), rawPtr); +} + +TEST_F(BufferPoolTest, maxCachedLimit) { + BufferPool bufferPool; + + // Fill beyond the max cached limit. + for (size_t i = 0; i < 40; ++i) { + auto buffer = AlignedBuffer::allocate(64, pool_.get()); + bufferPool.release(std::move(buffer)); + } + + // Should be capped at 32. + EXPECT_EQ(bufferPool.size(), 32); +} + +TEST_F(BufferPoolTest, multipleRecycleAndGet) { + BufferPool bufferPool; + + for (int i = 0; i < 5; ++i) { + auto buffer = AlignedBuffer::allocate((i + 1) * 256, pool_.get()); + bufferPool.release(std::move(buffer)); + } + EXPECT_EQ(bufferPool.size(), 5); + + // Get all back out. + for (int i = 0; i < 5; ++i) { + auto retrieved = bufferPool.get(); + EXPECT_NE(retrieved, nullptr); + } + EXPECT_EQ(bufferPool.size(), 0); + EXPECT_EQ(bufferPool.get(), nullptr); +} + +} // namespace facebook::velox::test diff --git a/velox/buffer/tests/BufferTest.cpp b/velox/buffer/tests/BufferTest.cpp index 9db4963221d..a45eeff2472 100644 --- a/velox/buffer/tests/BufferTest.cpp +++ b/velox/buffer/tests/BufferTest.cpp @@ -143,6 +143,20 @@ TEST_F(BufferTest, testAlignedBufferExact) { EXPECT_GE(buffer4->capacity(), oneMBMinusPad + 1); } +TEST_F(BufferTest, testAllocateExact) { + const int32_t oneMBMinusPad = 1024 * 1024 - AlignedBuffer::kPaddedSize; + + BufferPtr buffer1 = AlignedBuffer::allocateExact( + oneMBMinusPad + 1, pool_.get(), std::nullopt); + EXPECT_EQ(buffer1->size(), oneMBMinusPad + 1); + EXPECT_GE(buffer1->capacity(), oneMBMinusPad + 1); + + BufferPtr buffer2 = AlignedBuffer::allocateExact(3, pool_.get(), 'i'); + for (size_t i = 0; i < buffer2->size(); i++) { + EXPECT_EQ(buffer2->as()[i], 'i'); + } +} + TEST_F(BufferTest, testAsRange) { // Simple 2 element vector. std::vector testData({5, 255}); @@ -484,7 +498,9 @@ TEST_F(BufferTest, testNonPOD) { TEST_F(BufferTest, testNonPODMemoryUsage) { using T = std::shared_ptr; const int64_t currentBytes = pool_->usedBytes(); - { auto buffer = AlignedBuffer::allocate(0, pool_.get()); } + { + auto buffer = AlignedBuffer::allocate(0, pool_.get()); + } EXPECT_EQ(pool_->usedBytes(), currentBytes); } @@ -535,5 +551,34 @@ TEST_F(BufferTest, sliceBooleanBuffer) { Buffer::slice(bufferPtr, 5, 6, nullptr), "Pool must not be null."); } +TEST_F(BufferTest, testType) { + // Test AlignedBuffer type + auto alignedBuffer = AlignedBuffer::allocate(100, pool_.get()); + EXPECT_EQ(alignedBuffer->type(), Buffer::Type::kPOD); + EXPECT_TRUE(alignedBuffer->isPOD()); + EXPECT_FALSE(alignedBuffer->isView()); + + // Test NonPODAlignedBuffer type + auto nonPODBuffer = AlignedBuffer::allocate(10, pool_.get()); + EXPECT_EQ(nonPODBuffer->type(), Buffer::Type::kNonPOD); + EXPECT_FALSE(nonPODBuffer->isPOD()); + EXPECT_FALSE(nonPODBuffer->isView()); + + // Test BufferView type + MockCachePin pin; + const char* data = "test data"; + auto podBufferView = BufferView::create( + reinterpret_cast(data), 9, pin); + EXPECT_EQ(podBufferView->type(), Buffer::Type::kPODView); + EXPECT_TRUE(podBufferView->isPOD()); + EXPECT_TRUE(podBufferView->isView()); + + auto nonPodBufferView = BufferView::create( + reinterpret_cast(data), 9, pin, false); + EXPECT_EQ(nonPodBufferView->type(), Buffer::Type::kNonPODView); + EXPECT_FALSE(nonPodBufferView->isPOD()); + EXPECT_TRUE(nonPodBufferView->isView()); +} + } // namespace velox } // namespace facebook diff --git a/velox/buffer/tests/CMakeLists.txt b/velox/buffer/tests/CMakeLists.txt index 9ca02cf2545..eabc16c6911 100644 --- a/velox/buffer/tests/CMakeLists.txt +++ b/velox/buffer/tests/CMakeLists.txt @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -add_executable(velox_buffer_test BufferTest.cpp StringViewBufferHolderTest.cpp) +add_executable(velox_buffer_test BufferPoolTest.cpp BufferTest.cpp StringViewBufferHolderTest.cpp) add_test(velox_buffer_test velox_buffer_test) diff --git a/velox/buffer/tests/StringViewBufferHolderTest.cpp b/velox/buffer/tests/StringViewBufferHolderTest.cpp index 2f5c3fc3024..4ae7ece1a38 100644 --- a/velox/buffer/tests/StringViewBufferHolderTest.cpp +++ b/velox/buffer/tests/StringViewBufferHolderTest.cpp @@ -161,23 +161,21 @@ TEST_F(StringViewBufferHolderTest, getOwnedValueCanBeCalledWithStringType) { ASSERT_EQ(1, holder.buffers().size()); } -TEST_F( - StringViewBufferHolderTest, - getOwnedValueCanBeCalledWithStringPieceType) { +TEST_F(StringViewBufferHolderTest, getOwnedValueCanBeCalledWithStringViewType) { const char* buf = "abcdefghijklmnopqrstuvxz"; StringView result; - folly::StringPiece piece; + std::string_view view; auto holder = makeHolder(); ASSERT_EQ(0, holder.buffers().size()); { std::string str = buf; - piece = str; - result = holder.getOwnedValue(piece); + view = str; + result = holder.getOwnedValue(view); } - // `str` is already destructed and piece is invalid. + // `str` is already destructed and `view` is invalid. ASSERT_EQ(StringView(buf), result); ASSERT_EQ(1, holder.buffers().size()); } diff --git a/velox/common/CMakeLists.txt b/velox/common/CMakeLists.txt index 6661427654b..0e8925038b8 100644 --- a/velox/common/CMakeLists.txt +++ b/velox/common/CMakeLists.txt @@ -18,11 +18,29 @@ add_subdirectory(config) add_subdirectory(dynamic_registry) add_subdirectory(encode) add_subdirectory(file) +add_subdirectory(future) +add_subdirectory(geospatial) add_subdirectory(hyperloglog) add_subdirectory(io) add_subdirectory(memory) add_subdirectory(process) +add_subdirectory(rpc) add_subdirectory(serialization) add_subdirectory(time) add_subdirectory(testutil) add_subdirectory(fuzzer) + +if(${VELOX_BUILD_TESTING}) + add_subdirectory(tests) +endif() + +velox_add_library(velox_enum_declare INTERFACE HEADERS EnumDeclare.h) + +velox_add_library(velox_enum_define INTERFACE HEADERS EnumDefine.h) + +velox_add_library(velox_casts INTERFACE HEADERS Casts.h) + +velox_add_library(velox_scoped_registry INTERFACE HEADERS ScopedRegistry.h) +velox_link_libraries(velox_scoped_registry INTERFACE velox_exception Folly::folly) + +velox_install_library_headers() diff --git a/velox/common/Casts.h b/velox/common/Casts.h index 013a60a3b89..8687d56c149 100644 --- a/velox/common/Casts.h +++ b/velox/common/Casts.h @@ -40,13 +40,13 @@ void ensureCastSucceeded(To* casted, From* original) { } // namespace detail -// `checked_pointer_cast` is a dynamic casting tool to throw a Velox exception +// `checkedPointerCast` is a dynamic casting tool to throw a Velox exception // when the casting failed. Use this instead of `std::dynamic_pointer_cast` // when: // 1) Casting must happen // 2) We want a stack trace if it failed. template -std::shared_ptr checked_pointer_cast(const std::shared_ptr& input) { +std::shared_ptr checkedPointerCast(const std::shared_ptr& input) { VELOX_CHECK_NOT_NULL(input.get()); auto casted = std::dynamic_pointer_cast(input); detail::ensureCastSucceeded(casted.get(), input.get()); @@ -54,7 +54,7 @@ std::shared_ptr checked_pointer_cast(const std::shared_ptr& input) { } template -std::unique_ptr checked_pointer_cast(std::unique_ptr input) { +std::unique_ptr checkedPointerCast(std::unique_ptr input) { VELOX_CHECK_NOT_NULL(input.get()); auto* released = input.release(); To* casted{nullptr}; @@ -69,7 +69,7 @@ std::unique_ptr checked_pointer_cast(std::unique_ptr input) { } template -To* checked_pointer_cast(From* input) { +To* checkedPointerCast(From* input) { VELOX_CHECK_NOT_NULL(input); auto* casted = dynamic_cast(input); detail::ensureCastSucceeded(casted, input); @@ -77,7 +77,7 @@ To* checked_pointer_cast(From* input) { } template -std::unique_ptr static_unique_pointer_cast(std::unique_ptr input) { +std::unique_ptr staticUniquePointerCast(std::unique_ptr input) { VELOX_CHECK_NOT_NULL(input.get()); auto* released = input.release(); auto* casted = static_cast(released); @@ -85,24 +85,23 @@ std::unique_ptr static_unique_pointer_cast(std::unique_ptr input) { } template -bool is_instance_of(const std::shared_ptr& input) { +bool isInstanceOf(const std::shared_ptr& input) { VELOX_CHECK_NOT_NULL(input.get()); auto* casted = dynamic_cast(input.get()); return casted != nullptr; } template -bool is_instance_of(const std::unique_ptr& input) { +bool isInstanceOf(const std::unique_ptr& input) { VELOX_CHECK_NOT_NULL(input.get()); auto* casted = dynamic_cast(input.get()); return casted != nullptr; } template -bool is_instance_of(const From* input) { +bool isInstanceOf(const From* input) { VELOX_CHECK_NOT_NULL(input); auto* casted = dynamic_cast(input); return casted != nullptr; } - } // namespace facebook::velox diff --git a/velox/common/EnumDeclare.h b/velox/common/EnumDeclare.h new file mode 100644 index 00000000000..e9ea89a714d --- /dev/null +++ b/velox/common/EnumDeclare.h @@ -0,0 +1,75 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +/// Lightweight declaration macros for enum-to-name mappings. +/// +/// Include this header in .h files. Include EnumDefine.h only in .cpp files +/// that use the DEFINE macros. EnumDefine.h pulls in heavy transitive includes +/// (folly/F14Map.h, Exceptions.h, ~8M preprocessed) that propagate to every +/// file that includes your header. +/// +/// Usage: +/// +/// In the header file, define the enum and declare its name mapping: +/// +/// #include "velox/common/EnumDeclare.h" +/// +/// enum class Foo {...}; +/// +/// VELOX_DECLARE_ENUM_NAME(Foo); +/// +/// In the .cpp file, define the mapping: +/// +/// #include "velox/common/EnumDefine.h" +/// +/// namespace { +/// const auto& fooNames() { +/// static const folly::F14FastMap kNames = { +/// {Foo::kFirst, "FIRST"}, +/// {Foo::kSecond, "SECOND"}, +/// ... +/// }; +/// return kNames; +/// } +/// } // namespace +/// +/// VELOX_DEFINE_ENUM_NAME(Foo, fooNames); +/// +/// In client code, use FooName::toName(Foo::kFirst) to get the name of the +/// enum and FooName::toFoo("FIRST") or FooName::tryToFoo("FIRST") to get the +/// enum value. toFoo throws an exception if the input is not a valid name, +/// while tryToFoo returns std::nullopt. +/// +/// Use _EMBEDDED_ versions of the macros to define enums embedded in other +/// classes. + +#define VELOX_DECLARE_ENUM_NAME(EnumType) \ + struct EnumType##Name { \ + static std::string_view toName(EnumType value); \ + static EnumType to##EnumType(std::string_view name); \ + static std::optional tryTo##EnumType(std::string_view name); \ + }; \ + std::ostream& operator<<(std::ostream& os, const EnumType& value); + +#define VELOX_DECLARE_EMBEDDED_ENUM_NAME(EnumType) \ + static std::string_view toName(EnumType value); \ + static EnumType to##EnumType(std::string_view name); \ + static std::optional tryTo##EnumType(std::string_view name); diff --git a/velox/common/Enums.h b/velox/common/EnumDefine.h similarity index 76% rename from velox/common/Enums.h rename to velox/common/EnumDefine.h index 8f844ea1b05..c883f93a5f8 100644 --- a/velox/common/Enums.h +++ b/velox/common/EnumDefine.h @@ -37,49 +37,9 @@ struct Enums { } // namespace facebook::velox -/// Helper macros to implement bi-direction mappings between enum values and -/// names. +/// DEFINE macros for enum-to-name mappings. Include only in .cpp files. /// -/// Usage: -/// -/// In the header file, define the enum: -/// -/// #include "velox/common/Enums.h" -/// -/// enum class Foo {...}; -/// -/// VELOX_DECLARE_ENUM_NAME(Foo); -/// -/// In the cpp file, define the mapping: -/// -/// namespace { -/// const auto& fooNames() { -/// static const folly::F14FastMap kNames = { -/// {Foo::kFirst, "FIRST"}, -/// {Foo::kSecond, "SECOND"}, -/// ... -/// }; -/// return kNames; -/// } -/// } // namespace -/// -/// VELOX_DEFINE_ENUM_NAME(Foo, fooNames); -/// -/// In the client code, use FooName::toName(Foo::kFirst) to get the name of the -/// enum and FooName::toFoo("FIRST") or FooName::tryToFoo("FIRST") to get the -/// enum value. toFoo throws an exception if input is not a valid name, while -/// tryToFoo returns a std::nullopt. -/// -/// Use _EMBEDDED_ versions of the macros to define enums embedded in other -/// classes. - -#define VELOX_DECLARE_ENUM_NAME(EnumType) \ - struct EnumType##Name { \ - static std::string_view toName(EnumType value); \ - static EnumType to##EnumType(std::string_view name); \ - static std::optional tryTo##EnumType(std::string_view name); \ - }; \ - std::ostream& operator<<(std::ostream& os, const EnumType& value); +/// See EnumDeclare.h for full usage documentation and the DECLARE macros. #define VELOX_DEFINE_ENUM_NAME(EnumType, Names) \ std::string_view EnumType##Name::toName(EnumType value) { \ @@ -113,11 +73,6 @@ struct Enums { return *maybeType; \ } -#define VELOX_DECLARE_EMBEDDED_ENUM_NAME(EnumType) \ - static std::string_view toName(EnumType value); \ - static EnumType to##EnumType(std::string_view name); \ - static std::optional tryTo##EnumType(std::string_view name); - #define VELOX_DEFINE_EMBEDDED_ENUM_NAME(Class, EnumType, Names) \ std::string_view Class::toName(Class::EnumType value) { \ const auto& names = Names(); \ diff --git a/velox/common/ScopedRegistry.h b/velox/common/ScopedRegistry.h new file mode 100644 index 00000000000..0ed1cb61582 --- /dev/null +++ b/velox/common/ScopedRegistry.h @@ -0,0 +1,143 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include "folly/Synchronized.h" +#include "folly/container/F14Map.h" +#include "folly/container/F14Set.h" +#include "velox/common/base/Exceptions.h" + +namespace facebook::velox { + +/// Layered key-value registry. Lookups check the local scope first, then fall +/// back to the parent chain. Supports arbitrary nesting (e.g., global -> +/// session -> query). All operations are synchronized. +/// +/// Two usage modes: +/// +/// - Override mode (parent != nullptr): The registry inherits all entries from +/// the parent and selectively overrides specific entries. Lookups that miss +/// locally fall through to the parent. +/// +/// - Isolation mode (parent == nullptr): The registry is self-contained. Only +/// explicitly registered entries are visible. +/// +/// Thread safety: Each scope has its own lock. Lookups and snapshots acquire +/// locks one scope at a time while walking the parent chain. This is safe and +/// consistent as long as parent registries are not mutated while child scopes +/// exist. Mutations should be limited to leaf scopes (e.g., per-query +/// overrides). Parent scopes may be freely mutated before children are created +/// and after all children are destroyed. +/// +/// @tparam K Key type. Must be copyable, hashable, and equality-comparable. +/// @tparam V Value type. Stored as shared_ptr. +template +class ScopedRegistry { + public: + using ValuePtr = std::shared_ptr; + + /// Create a root registry (no parent). + ScopedRegistry() : parent_{nullptr} {} + + /// Create a derived scope that falls back to 'parent' when a key is not + /// found locally. The parent must outlive this registry. + explicit ScopedRegistry(const ScopedRegistry* parent) : parent_{parent} {} + + /// Insert an entry in the local scope. Returns true if the key was newly + /// inserted. Throws if the key already exists unless 'overwrite' is true, + /// in which case the existing entry is replaced and false is returned. + bool insert(K key, ValuePtr value, bool overwrite = false) { + return local_.withWLock([&](auto& map) { + auto [it, inserted] = map.emplace(std::move(key), value); + if (!inserted) { + VELOX_CHECK(overwrite, "Key already registered: {}", it->first); + it->second = std::move(value); + } + return inserted; + }); + } + + /// Look up a key. Checks local scope first, then walks the parent chain. + /// Returns nullptr if not found. + ValuePtr find(const K& key) const { + auto result = local_.withRLock([&](const auto& map) -> ValuePtr { + auto it = map.find(key); + return it != map.end() ? it->second : nullptr; + }); + if (result) { + return result; + } + return parent_ ? parent_->find(key) : nullptr; + } + + /// Remove an entry from the local scope. Returns true if the entry was + /// removed, false if the key was not found locally. Does not affect parent + /// scopes. + bool erase(const K& key) { + return local_.withWLock([&](auto& map) { return map.erase(key) > 0; }); + } + + /// Remove all entries from the local scope. Does not affect parent scopes. + /// Entries are moved out under the lock and destroyed outside to avoid + /// holding the lock during potentially slow destructors. + void clear() { + folly::F14FastMap entries; + local_.withWLock([&](auto& map) { entries.swap(map); }); + } + + /// Return a snapshot of all visible entries. Local entries take precedence + /// over parent entries with the same key. The snapshot is a copy — no locks + /// are held after this returns. + std::vector> snapshot() const { + // Collect local entries. + std::vector> result; + local_.withRLock([&](const auto& map) { + result.reserve(map.size()); + for (const auto& [key, value] : map) { + result.emplace_back(key, value); + } + }); + + if (!parent_) { + return result; + } + + // Merge parent entries, skipping keys already present locally. + auto parentEntries = parent_->snapshot(); + folly::F14FastSet localKeys; + localKeys.reserve(result.size()); + for (const auto& [key, _] : result) { + localKeys.insert(key); + } + for (auto& [key, value] : parentEntries) { + if (!localKeys.contains(key)) { + result.emplace_back(std::move(key), std::move(value)); + } + } + return result; + } + + private: + folly::Synchronized> local_; + const ScopedRegistry* parent_; +}; + +} // namespace facebook::velox diff --git a/velox/common/base/AdmissionController.cpp b/velox/common/base/AdmissionController.cpp index c7e1b71ea5a..e1dae70a402 100644 --- a/velox/common/base/AdmissionController.cpp +++ b/velox/common/base/AdmissionController.cpp @@ -32,12 +32,10 @@ void AdmissionController::accept(uint64_t resourceUnits) { { std::lock_guard l(mu_); if (unitsUsed_ + resourceUnits > config_.maxLimit) { - auto [unblockPromise, unblockFuture] = makeVeloxContinuePromiseContract(); Request req; req.unitsRequested = resourceUnits; - req.promise = std::move(unblockPromise); + future = req.promise.getSemiFuture(); queue_.push_back(std::move(req)); - future = std::move(unblockFuture); } else { updatedValue = unitsUsed_ += resourceUnits; } diff --git a/velox/common/base/AsyncSource.h b/velox/common/base/AsyncSource.h index 76740dd8681..e6ff198705e 100644 --- a/velox/common/base/AsyncSource.h +++ b/velox/common/base/AsyncSource.h @@ -16,11 +16,14 @@ #pragma once +#include #include #include #include +#include #include #include +#include #include "velox/common/time/CpuWallTimer.h" #include "velox/common/base/Exceptions.h" @@ -31,22 +34,23 @@ namespace facebook::velox { -// A future-like object that prefabricates Items on an executor and -// allows consumer threads to pick items as they are ready. If the -// consumer needs the item before the executor started making it, the -// consumer will make it instead. If multiple consumers request the -// same item, exactly one gets it. Propagates exceptions to the -// consumer. +/// A future-like object that prefabricates Items on an executor and +/// allows consumer threads to pick items as they are ready. If the +/// consumer needs the item before the executor started making it, the +/// consumer will make it instead. If multiple consumers request the +/// same item, exactly one gets it. Propagates exceptions to the +/// consumer. template class AsyncSource { public: - explicit AsyncSource(std::function()> make) - : make_(std::move(make)) { + explicit AsyncSource(std::function()> itemMaker) + : itemMaker_(std::move(itemMaker)) { + VELOX_CHECK_NOT_NULL(itemMaker_); if (process::GetThreadDebugInfo() != nullptr) { auto* currentThreadDebugInfo = process::GetThreadDebugInfo(); // We explicitly leave out the callback when copying the ThreadDebugInfo // as that may have captured state that goes out of scope by the time - // _make is called. + // itemMaker_ is called. threadDebugInfo_ = std::make_optional( {currentThreadDebugInfo->queryId_, currentThreadDebugInfo->taskId_, @@ -55,167 +59,348 @@ class AsyncSource { } ~AsyncSource() { + const auto currentState = state(); VELOX_CHECK( - moved_ || closed_, - "AsyncSource should be properly consumed or closed."); + currentState == State::kFinished || currentState == State::kCancelled || + currentState == State::kFailed, + "AsyncSource should be properly finished, cancelled, or failed, unexpected state: {}", + stateName(currentState)); } - // Makes an item if it is not already made. To be called on a background - // executor. + /// Makes an item if it is not already made. To be called on a background + /// executor. void prepare() { common::testutil::TestValue::adjust( "facebook::velox::AsyncSource::prepare", this); - std::function()> make = nullptr; + std::function()> itemMaker{nullptr}; { std::lock_guard l(mutex_); - if (!make_) { + if (state() != State::kInit) { + VELOX_CHECK_NULL(itemMaker_); return; } - making_ = true; - std::swap(make, make_); - } - std::unique_ptr item; - try { - CpuWallTimer timer(timing_); - item = runMake(make); - } catch (std::exception&) { - std::lock_guard l(mutex_); - exception_ = std::current_exception(); - } - std::unique_ptr promise; - { - std::lock_guard l(mutex_); - VELOX_CHECK_NULL(item_); - if (FOLLY_LIKELY(exception_ == nullptr)) { - item_ = std::move(item); - } - making_ = false; - promise.swap(promise_); - } - if (promise != nullptr) { - promise->setValue(); + setState(State::kMaking); + std::swap(itemMaker, itemMaker_); } + makeItem(std::move(itemMaker)); } - // Returns the item to the first caller and nullptr to subsequent callers. - // If the item is preparing on the executor, waits for the item and - // otherwise makes it on the caller thread. + /// Returns the item to the first caller and nullptr to subsequent callers. + /// If the item is preparing on the executor, waits for the item and + /// otherwise makes it on the caller thread. std::unique_ptr move() { common::testutil::TestValue::adjust( "facebook::velox::AsyncSource::move", this); - std::function()> make = nullptr; + std::function()> itemMaker{nullptr}; ContinueFuture wait; { std::lock_guard l(mutex_); - moved_ = true; - // 'making_' can be read atomically, 'exception' maybe not. So test - // 'making' so as not to read half-assigned 'exception_'. - if (!making_ && exception_) { - std::rethrow_exception(exception_); - } - if (item_) { - return std::move(item_); - } - if (promise_) { - // Somebody else is now waiting for the item to be made. - return nullptr; - } - if (making_) { - promise_ = std::make_unique(); - wait = promise_->getSemiFuture(); - } else { - if (!make_) { + const auto currentState = state(); + switch (currentState) { + case State::kFinished: + case State::kCancelled: return nullptr; - } - std::swap(make, make_); + case State::kFailed: + VELOX_CHECK_NOT_NULL(exception_); + std::rethrow_exception(exception_); + case State::kPrepared: + VELOX_CHECK(promises_.empty()); + setState(State::kFinished); + return std::move(item_); + case State::kMaking: + VELOX_CHECK_NULL(itemMaker_); + VELOX_CHECK_NULL(item_); + if (!promises_.empty()) { + // Somebody else is already waiting for the item to be made. + return nullptr; + } + promises_.emplace_back("AsyncSource::move"); + wait = promises_.back().getSemiFuture(); + break; + case State::kInit: + VELOX_CHECK_NOT_NULL(itemMaker_); + VELOX_CHECK_NULL(item_); + setState(State::kMaking); + std::swap(itemMaker, itemMaker_); + break; } } // Outside of mutex_. - if (make) { - try { - return runMake(make); - } catch (const std::exception&) { - std::lock_guard l(mutex_); - exception_ = std::current_exception(); - throw; - } + if (itemMaker != nullptr) { + makeItem(std::move(itemMaker)); } - auto& exec = folly::QueuedImmediateExecutor::instance(); - std::move(wait).via(&exec).wait(); + + makeWait(std::move(wait)); + std::lock_guard l(mutex_); - if (exception_) { + const auto currentState = state(); + if (exception_ != nullptr) { + checkState(currentState, State::kFailed); std::rethrow_exception(exception_); } + if (currentState == State::kFinished) { + // Another move() or close() might have grabbed the item first. + return nullptr; + } + checkState(currentState, State::kPrepared); + setState(State::kFinished); return std::move(item_); } - // If true, move() will not block. But there is no guarantee that somebody - // else will not get the item first. + /// If true, move() will not block. But there is no guarantee that somebody + /// else will not get the item first. bool hasValue() const { tsan_lock_guard l(mutex_); return item_ != nullptr || exception_ != nullptr; } - /// Returns the timing of prepare(). If the item was made on the calling - /// thread, the timing is 0 since only off-thread activity needs to be added - /// to the caller's timing. - const CpuWallTiming& prepareTiming() { - return timing_; + /// Returns the timing of making the item. If the item was made on the + /// calling thread, the timing is 0 since only off-thread activity needs to + /// be added to the caller's timing. + const CpuWallTiming& prepareTiming() const { + return makeTiming_; + } + + /// Cancels the task if it hasn't started yet or if item is already prepared. + /// If the task is making, the task will continue but AsyncSource + /// is marked as cancelled to allow proper cleanup in destructor. + void cancel() { + std::lock_guard l(mutex_); + const auto currentState = state(); + switch (currentState) { + case State::kInit: + VELOX_CHECK_NOT_NULL(itemMaker_); + VELOX_CHECK_NULL(item_); + itemMaker_ = nullptr; + setState(State::kCancelled); + return; + case State::kPrepared: + VELOX_CHECK_NULL(itemMaker_); + item_ = nullptr; + setState(State::kCancelled); + return; + default: + return; + } } /// This function assists the caller in ensuring that resources allocated in /// AsyncSource are promptly released: - /// 1. Waits for the completion of the 'make_' function if it is executing - /// in the thread pool. - /// 2. Resets the 'make_' function if it has not started yet. - /// 3. Cleans up the 'item_' if 'make_' has completed, but the result + /// 1. Waits for the completion of the 'itemMaker_' function if it is + /// executing in the thread pool. + /// 2. Resets the 'itemMaker_' function if it has not started yet. + /// 3. Cleans up the 'item_' if 'itemMaker_' has completed, but the result /// 'item_' has not been returned to the caller. void close() { - if (closed_ || moved_) { + auto currentState = state(); + if (currentState == State::kFinished || currentState == State::kFailed || + currentState == State::kCancelled) { return; } ContinueFuture wait; { std::lock_guard l(mutex_); - if (making_) { - promise_ = std::make_unique(); - wait = promise_->getSemiFuture(); - } else if (make_) { - make_ = nullptr; + if (tryCloseLocked()) { + return; } + checkState(state(), State::kMaking); + promises_.emplace_back("AsyncSource::close"); + wait = promises_.back().getSemiFuture(); } - auto& exec = folly::QueuedImmediateExecutor::instance(); - std::move(wait).via(&exec).wait(); + makeWait(std::move(wait)); + { std::lock_guard l(mutex_); - if (item_) { - item_ = nullptr; - } - closed_ = true; + const auto closed = tryCloseLocked(); + VELOX_CHECK(closed, "Unexpected close failure"); } } private: - std::unique_ptr runMake(std::function()>& make) { - process::ScopedThreadDebugInfo threadDebugInfo( - threadDebugInfo_.has_value() ? &threadDebugInfo_.value() : nullptr); - return make(); + // State transition diagram: + // + // ┌───────┐ prepare() ┌─────────┐ success ┌──────────┐ + // │ kInit │ ────────────► │ kMaking │ ──────────► │kPrepared │ + // └───────┘ move() └─────────┘ └──────────┘ + // │ │ │ │ │ + // │ │ close() │ exception │ │ cancel() + // │ │ ▼ │ │ + // │ │ ┌──────────┐ move() │ │ + // │ │ │ kFailed │ close() │ │ + // │ │ └──────────┘ │ │ + // │ │ │ │ + // │ └──────────────────────┬─────────────────────┘ │ + // │ ▼ │ + // │ cancel() ┌───────────┐ │ + // │ │ kFinished │ │ + // │ └───────────┘ │ + // │ │ + // └──────────────────────────┬──────────────────────────┘ + // ▼ + // ┌───────────┐ + // │kCancelled │ + // └───────────┘ + // + enum class State : uint8_t { + // Initial state before prepare() or move() is called. + kInit = 0, + // prepare() is executing. + kMaking = 1, + // prepare() has completed and item is ready. + kPrepared = 2, + // prepare() has failed with an exception. + kFailed = 3, + // move() or close() has been called and the AsyncSource is finished. + kFinished = 4, + // cancel() has been called. + kCancelled = 5, + }; + + State state() const { + return state_.load(std::memory_order_acquire); + } + + static std::string stateName(State state) { + switch (state) { + case State::kInit: + return "INIT"; + case State::kMaking: + return "MAKING"; + case State::kPrepared: + return "PREPARED"; + case State::kFailed: + return "FAILED"; + case State::kFinished: + return "FINISHED"; + case State::kCancelled: + return "CANCELLED"; + default: + VELOX_UNREACHABLE("Unknown state: {}", static_cast(state)); + } + } + + inline void checkState(State actualState, State expectedState) const { + VELOX_CHECK( + actualState == expectedState, + "Unexpected state: {}, expected: {}", + stateName(actualState), + stateName(expectedState)); + } + + inline void setState(State newState) { + const auto oldState = state(); + VELOX_CHECK( + isValidStateTransition(oldState, newState), + "Invalid state transition from {} to {}", + stateName(oldState), + stateName(newState)); + state_.store(newState, std::memory_order_release); + } + + static bool isValidStateTransition(State oldState, State newState) { + switch (oldState) { + case State::kInit: + return newState == State::kMaking || newState == State::kFinished || + newState == State::kCancelled; + case State::kMaking: + return newState == State::kPrepared || newState == State::kFailed; + case State::kPrepared: + return newState == State::kFinished || newState == State::kCancelled; + case State::kFailed: + case State::kFinished: + case State::kCancelled: + return false; + default: + VELOX_UNREACHABLE("Unknown state: {}", stateName(oldState)); + } + } + + // Makes item with timing, handles exceptions, state transitions, and promise + // signaling. + void makeItem(std::function()>&& itemMaker) { + VELOX_CHECK_NOT_NULL(itemMaker); + std::unique_ptr item; + std::exception_ptr exceptionPtr; + try { + CpuWallTimer timer(makeTiming_); + process::ScopedThreadDebugInfo threadDebugInfo( + threadDebugInfo_.has_value() ? &threadDebugInfo_.value() : nullptr); + item = itemMaker(); + } catch (std::exception&) { + exceptionPtr = std::current_exception(); + } + + std::vector promises; + { + std::lock_guard l(mutex_); + VELOX_CHECK_NULL(item_); + VELOX_CHECK_NULL(itemMaker_); + checkState(state(), State::kMaking); + if (FOLLY_LIKELY(exceptionPtr == nullptr)) { + item_ = std::move(item); + setState(State::kPrepared); + } else { + setExceptionLocked(std::move(exceptionPtr)); + } + promises.swap(promises_); + } + for (auto& promise : promises) { + promise.setValue(); + } } - // Stored context (if present upon construction) so they can be restored when - // make_ is invoked. + inline void setExceptionLocked(std::exception_ptr exception) { + VELOX_CHECK_NULL(exception_); + exception_ = std::move(exception); + setState(State::kFailed); + } + + // Waits for the promise to be fulfilled and injects a TestValue for testing + // race conditions. + inline void makeWait(ContinueFuture&& wait) { + common::testutil::TestValue::adjust( + "facebook::velox::AsyncSource::makeWait", this); + auto& exec = folly::QueuedImmediateExecutor::instance(); + std::move(wait).via(&exec).wait(); + } + + // Attempts to close immediately while holding the lock. Returns true if + // close completed, false if caller needs to wait for making to complete. + inline bool tryCloseLocked() { + const auto currentState = state(); + switch (currentState) { + case State::kInit: + VELOX_CHECK_NOT_NULL(itemMaker_); + VELOX_CHECK_NULL(item_); + itemMaker_ = nullptr; + setState(State::kFinished); + return true; + case State::kPrepared: + VELOX_CHECK_NULL(itemMaker_); + item_ = nullptr; + setState(State::kFinished); + return true; + case State::kMaking: + return false; + default: + VELOX_CHECK_NULL(itemMaker_); + VELOX_CHECK_NULL(item_); + return true; + } + } + + // Stored context (if present upon construction) so they can be restored + // when itemMaker_ is invoked. std::optional threadDebugInfo_; + std::atomic state_{State::kInit}; + CpuWallTiming makeTiming_; mutable std::mutex mutex_; - // True if 'prepare() is making the item. - bool making_{false}; - std::unique_ptr promise_; + std::vector promises_; std::unique_ptr item_; - std::function()> make_; + std::function()> itemMaker_; std::exception_ptr exception_; - CpuWallTiming timing_; - bool closed_{false}; - bool moved_{false}; }; + } // namespace facebook::velox diff --git a/velox/common/base/BigintIdMap.h b/velox/common/base/BigintIdMap.h index e63445f4d12..97d3abfdb51 100644 --- a/velox/common/base/BigintIdMap.h +++ b/velox/common/base/BigintIdMap.h @@ -31,8 +31,10 @@ class BigintIdMap { static constexpr int64_t kMaxCapacity = 1 << 30; // 1G entries, 12GB BigintIdMap(int32_t capacity, memory::MemoryPool& pool) : pool_(pool) { - makeTable(std::max( - 2 * sizeof(xsimd::batch), bits::nextPowerOfTwo(capacity))); + makeTable( + std::max( + 2 * sizeof(xsimd::batch), + bits::nextPowerOfTwo(capacity))); } BigintIdMap(const BigintIdMap& other) = delete; diff --git a/velox/common/base/BitUtil.cpp b/velox/common/base/BitUtil.cpp index 686bce02b1a..82aa286719c 100644 --- a/velox/common/base/BitUtil.cpp +++ b/velox/common/base/BitUtil.cpp @@ -219,4 +219,102 @@ uint64_t hashBytes(uint64_t seed, const char* data, size_t size) { } return a0 ^ ((a1 * kMul)) ^ (a2 * kMul); } + +void packBitmap(std::span bools, char* bitmap) { + uint64_t* word = reinterpret_cast(bitmap); + const uint64_t loopCount = bools.size() >> 6; + const uint64_t remainder = bools.size() - (loopCount << 6); + const bool* rawBools = bools.data(); + for (uint64_t i = 0; i < loopCount; ++i) { + for (int j = 0; j < 64; ++j) { + *word |= static_cast(*rawBools++) << j; + } + ++word; + } + for (int j = 0; j < remainder; ++j) { + *word |= static_cast(*rawBools++) << j; + } +} + +uint32_t +findSetBit(const char* bitmap, uint32_t begin, uint32_t end, uint32_t n) { + if (begin >= end || n == 0) { + return begin; + } + + const uint64_t* wordPtr = reinterpret_cast(bitmap); + + // Handle bits in the first partial word + uint32_t wordIdx = begin >> 6; + uint32_t bitOffset = begin & 63; + uint64_t word = wordPtr[wordIdx]; + + // Mask out bits before 'begin' + word &= ~((1ULL << bitOffset) - 1); + + while (true) { + // Count set bits in current word + uint32_t setBitsInWord = __builtin_popcountll(word); + + if (setBitsInWord >= n) { + // The n'th set bit is in this word + while (n > 0) { + uint32_t firstSetBit = __builtin_ffsll(static_cast(word)); + if (firstSetBit == 0) { + break; // No more set bits + } + + // __builtin_ffsll returns the index plus one, so subtract 1 + --firstSetBit; + + word &= ~(1ULL << firstSetBit); + --n; + + if (n == 0) { + uint32_t result = (wordIdx << 6) + firstSetBit; + return result < end ? result : end; + } + } + } + + // Move to next word + n -= setBitsInWord; + ++wordIdx; + bitOffset = 0; + + // Check if we've reached the end + uint32_t nextWordStart = wordIdx << 6; + if (nextWordStart >= end) { + return end; + } + + word = wordPtr[wordIdx]; + + // Mask out bits beyond 'end' if this is the last word + if (nextWordStart + 64 > end) { + word &= (1ULL << (end - nextWordStart)) - 1; + } + + // If no bits set in this word, continue to next word + if (word == 0) { + continue; + } + } +} + +void BitmapBuilder::copy(const Bitmap& other, uint32_t begin, uint32_t end) { + auto source = static_cast(other.bits()); + auto dest = static_cast(bitmap_); + auto firstByte = begin / 8; + if (begin % 8) { + uint8_t mask = (1 << (begin % 8)) - 1; + dest[firstByte] = static_cast( + (dest[firstByte] & mask) | (source[firstByte] & ~mask)); + ++firstByte; + } + // @lint-ignore CLANGSECURITY facebook-security-vulnerable-memcpy + std::memcpy( + dest + firstByte, source + firstByte, bits::nbytes(end) - firstByte); +} + } // namespace facebook::velox::bits diff --git a/velox/common/base/BitUtil.h b/velox/common/base/BitUtil.h index 3257cbe4b0f..164c4b85cd4 100644 --- a/velox/common/base/BitUtil.h +++ b/velox/common/base/BitUtil.h @@ -24,6 +24,7 @@ #include #include #include +#include #include #ifdef __BMI2__ @@ -94,6 +95,14 @@ inline void setBit(T* bits, uint64_t idx, bool value) { value ? setBit(bits, idx) : clearBit(bits, idx); } +/// Branchless: sets the bit at idx if value is true, no-op if false. +/// Assumes target memory is pre-zeroed for unset bits. +template +inline void maybeSetBit(T* bits, uint64_t idx, bool value) { + auto* bitsAs8Bit = reinterpret_cast(bits); + bitsAs8Bit[idx / 8] |= (static_cast(value) << (idx % 8)); +} + inline void negateBit(void* bits, uint64_t idx) { auto* bitsAs8Bit = reinterpret_cast(bits); bitsAs8Bit[idx / 8] ^= (1 << (idx % 8)); @@ -125,7 +134,7 @@ constexpr inline T divRoundUp(T value, U factor) { } constexpr inline uint64_t lowMask(int32_t bits) { - return (1UL << bits) - 1; + return (1ULL << bits) - 1; } constexpr inline uint64_t highMask(int32_t bits) { @@ -753,7 +762,7 @@ inline uint64_t nextPowerOfTwo(uint64_t size) { return 2 * lower; } -inline bool isPowerOfTwo(uint64_t size) { +constexpr bool isPowerOfTwo(uint64_t size) { return (size & (size - 1)) == 0; } @@ -1027,6 +1036,100 @@ void storeBitsToByte(uint8_t bits, uint8_t* bytes, unsigned index) { } } +/// Returns the number of bits required to store the value. +/// For a value of 0, returns 1. +inline int bitsRequired(uint64_t value) noexcept { + return 64 - __builtin_clzll(value | 1); +} + +/// Packs bools into bitmap. bitmap must point to a region large enough. +/// Does not clear bitmap first — bit i is set if bools[i] is true OR +/// bit i was already set. +void packBitmap(std::span bools, char* bitmap); + +/// Finds the index of the n'th set bit in [begin, end) in bitmap. +/// Returns end if not found. Returns begin if begin >= end or n == 0. +uint32_t +findSetBit(const char* bitmap, uint32_t begin, uint32_t end, uint32_t n); + +/// Debug: prints bits of a numeric type in nibble groups. +template +std::string printBits(T c) { + std::string result; + for (int i = 0; i < (sizeof(T) << 3); ++i) { + if (i > 0 && i % 4 == 0) { + result += ' '; + } + if (c & 1) { + result += '1'; + } else { + result += '0'; + } + c >>= 1; + } + // We actually want little endian order. + std::reverse(result.begin(), result.end()); + return result; +} + +/// Read-only view over a bitmap stored as char*. +class Bitmap { + public: + Bitmap(const void* bitmap, uint32_t size) + : bitmap_{static_cast(const_cast(bitmap))}, size_{size} {} + + bool test(uint32_t pos) const { + return isBitSet(reinterpret_cast(bitmap_), pos); + } + + uint32_t size() const { + return size_; + } + + const void* bits() const { + return bitmap_; + } + + protected: + char* bitmap_; + uint32_t size_; +}; + +/// Mutable bitmap builder. +class BitmapBuilder : public Bitmap { + public: + BitmapBuilder(void* bitmap, uint32_t size) : Bitmap{bitmap, size} {} + + void set(uint32_t pos) { + setBit(reinterpret_cast(bitmap_), pos); + } + + void maybeSet(uint32_t pos, bool bit) { + maybeSetBit(bitmap_, pos, bit); + } + + void set(uint32_t begin, uint32_t end) { + fillBits( + reinterpret_cast(bitmap_), + static_cast(begin), + static_cast(end), + true); + } + + void clear(uint32_t begin, uint32_t end) { + fillBits( + reinterpret_cast(bitmap_), + static_cast(begin), + static_cast(end), + false); + } + + /// Copy the specified range from the source bitmap into this one. It + /// guarantees |begin| is the beginning bit offset, but may copy more beyond + /// |end|. + void copy(const Bitmap& other, uint32_t begin, uint32_t end); +}; + } // namespace bits } // namespace velox } // namespace facebook diff --git a/velox/common/base/BloomFilter.h b/velox/common/base/BloomFilter.h index 6fab5a4376c..73c421442f9 100644 --- a/velox/common/base/BloomFilter.h +++ b/velox/common/base/BloomFilter.h @@ -80,6 +80,7 @@ class BloomFilter { auto version = stream.read(); VELOX_USER_CHECK_EQ(kBloomFilterV1, version); auto size = stream.read(); + VELOX_CHECK_GE(size, 0, "Invalid BloomFilter size: {}", size); bits_.resize(size); auto bitsdata = reinterpret_cast(serialized + stream.offset()); @@ -110,6 +111,34 @@ class BloomFilter { } } + /// Computes m (total bits of Bloom filter) which is expected to achieve, + /// for the specified expected insertions, the required false positive + /// probability. + /// + /// See + /// http://en.wikipedia.org/wiki/Bloom_filter#Probability_of_false_positives + /// for the formula. + /// + /// @param n expected insertions (must be positive). + /// @param p false positive rate (must be 0 < p < 1). + static int64_t optimalNumOfBits(int64_t n, double p) { + return static_cast( + -n * std::log(p) / (std::log(2.0) * std::log(2.0))); + } + + /// Computes m (total bits of Bloom filter) which is expected to achieve. + /// The smaller the expectedNumItems, the smaller the fpp. + /// + /// @param expectedNumItems expected number of items to insert. + /// @param maxNumItems maximum number of items. + static int64_t optimalNumOfBits( + int64_t expectedNumItems, + int64_t maxNumItems) { + double ratio = static_cast(expectedNumItems) / maxNumItems; + double fpp = kDefaultFpp * std::min(ratio, 1.0); + return optimalNumOfBits(expectedNumItems, fpp); + } + private: // We use 4 independent hash functions by taking 24 bits of // the hash code and breaking these up into 4 groups of 6 bits. Each group @@ -142,6 +171,8 @@ class BloomFilter { } static constexpr int8_t kBloomFilterV1 = 1; + static constexpr double kDefaultFpp = 0.03; + std::vector bits_; }; diff --git a/velox/common/base/CMakeLists.txt b/velox/common/base/CMakeLists.txt index b93691f67fe..c158ea872b5 100644 --- a/velox/common/base/CMakeLists.txt +++ b/velox/common/base/CMakeLists.txt @@ -12,7 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -velox_add_library(velox_exception Exceptions.cpp VeloxException.cpp Exceptions.h) +velox_add_library( + velox_exception + Exceptions.cpp + VeloxException.cpp + Exceptions.h + HEADERS + FmtStdFormatters.h + VeloxException.h +) velox_link_libraries( velox_exception PUBLIC velox_flag_definitions velox_process Folly::folly fmt::fmt gflags::gflags glog::glog @@ -30,10 +38,47 @@ velox_add_library( SimdUtil.cpp SkewedPartitionBalancer.cpp SpillConfig.cpp - SpillStats.cpp + SplitBlockBloomFilter.cpp StatsReporter.cpp SuccinctPrinter.cpp - TraceConfig.cpp + HEADERS + AdmissionController.h + BitSet.h + BitUtil.h + BloomFilter.h + CheckedArithmetic.h + ClassName.h + CoalesceIo.h + ConcurrentCounter.h + CountBits.h + Counters.h + Crc.h + Doubles.h + Fs.h + GTestMacros.h + IOUtils.h + IndexedPriorityQueue.h + Nulls.h + PeriodicStatsReporter.h + Pointers.h + Portability.h + PrefixSortConfig.h + RandomUtil.h + Range.h + RuntimeMetrics.h + ScopedLock.h + SelectivityInfo.h + Semaphore.h + SimdUtil-inl.h + SimdUtil.h + SkewedPartitionBalancer.h + SortingNetwork.h + SpillConfig.h + SplitBlockBloomFilter.h + StatsReporter.h + SuccinctPrinter.h + TreeOfLosers.h + XxHashInline.h ) velox_link_libraries( @@ -50,7 +95,7 @@ if(${VELOX_ENABLE_BENCHMARKS}) add_subdirectory(benchmarks) endif() -velox_add_library(velox_id_map BigintIdMap.cpp) +velox_add_library(velox_id_map BigintIdMap.cpp HEADERS BigintIdMap.h) velox_link_libraries( velox_id_map velox_memory @@ -61,5 +106,15 @@ velox_link_libraries( fmt::fmt ) -velox_add_library(velox_status Status.cpp) +velox_add_library(velox_status Status.cpp HEADERS Status.h) velox_link_libraries(velox_status PUBLIC fmt::fmt Folly::folly PRIVATE glog::glog) + +velox_add_library(velox_compare_flags INTERFACE HEADERS CompareFlags.h) + +velox_add_library(velox_macros INTERFACE HEADERS Macros.h) + +velox_add_library(velox_async_source INTERFACE HEADERS AsyncSource.h) + +velox_add_library(velox_lazy_cpu_thread_pool_executor INTERFACE HEADERS LazyCPUThreadPoolExecutor.h) + +velox_add_library(velox_exception_helper INTERFACE HEADERS ExceptionHelper.h) diff --git a/velox/common/base/CoalesceIo.h b/velox/common/base/CoalesceIo.h index f5f58ba1092..7ad405cf290 100644 --- a/velox/common/base/CoalesceIo.h +++ b/velox/common/base/CoalesceIo.h @@ -65,14 +65,13 @@ CoalesceIoStats coalesceIo( AddRanges addRanges, SkipRange skipRange, IoFunc ioFunc) { - std::vector buffers; int32_t startItem = 0; auto startOffset = offsetFunc(startItem); auto lastEndOffset = startOffset; std::vector ranges; CoalesceIoStats result; for (int32_t i = 0; i < items.size(); ++i) { - auto& item = items[i]; + const auto& item = items[i]; const auto itemOffset = offsetFunc(i); const auto itemSize = sizeFunc(i); result.payloadBytes += itemSize; diff --git a/velox/common/base/CompareFlags.h b/velox/common/base/CompareFlags.h index 705c1ed8729..f7873ab330d 100644 --- a/velox/common/base/CompareFlags.h +++ b/velox/common/base/CompareFlags.h @@ -84,7 +84,7 @@ struct CompareFlags { /// ex: (null, 1) = (null, 1) is indeterminate. /// /// - If all fields compare results are true, then result is true. - /// ex: (1, 1) = (1, 1) is indeterminate. + /// ex: (1, 1) = (1, 1) is true. /// /// 4. Maps: /// - Keys are compared first, if keys are not equal values are not diff --git a/velox/common/base/ConcurrentCounter.h b/velox/common/base/ConcurrentCounter.h index 689c91d1525..5fd16f9771d 100644 --- a/velox/common/base/ConcurrentCounter.h +++ b/velox/common/base/ConcurrentCounter.h @@ -19,6 +19,7 @@ #include #include #include +#include #include "velox/common/base/BitUtil.h" #include "velox/common/base/Exceptions.h" diff --git a/velox/common/base/CountBits.h b/velox/common/base/CountBits.h index b267d2f636e..f40fb95355e 100644 --- a/velox/common/base/CountBits.h +++ b/velox/common/base/CountBits.h @@ -16,6 +16,8 @@ #pragma once +#include + namespace facebook::velox { // Copied from format.h of fmt. diff --git a/velox/common/base/Counters.cpp b/velox/common/base/Counters.cpp index d3353656789..6e88aa653e8 100644 --- a/velox/common/base/Counters.cpp +++ b/velox/common/base/Counters.cpp @@ -45,6 +45,9 @@ void registerVeloxMetrics() { DEFINE_HISTOGRAM_METRIC( kMetricTaskBarrierProcessTimeMs, 1'000, 0, 30'000, 50, 90, 99, 100); + // Tracks the total number of splits received by all tasks. + DEFINE_METRIC(kMetricTaskSplitsCount, facebook::velox::StatType::COUNT); + /// ================== Cache Counters ================= // Tracks hive handle generation latency in range of [0, 100s] and reports @@ -107,8 +110,13 @@ void registerVeloxMetrics() { // was opened to load the cache. DEFINE_METRIC(kMetricCacheMaxAgeSecs, facebook::velox::StatType::AVG); - // Total number of cache entries. - DEFINE_METRIC(kMetricMemoryCacheNumEntries, facebook::velox::StatType::AVG); + // Total number of tiny cache entries. + DEFINE_METRIC( + kMetricMemoryCacheNumTinyEntries, facebook::velox::StatType::AVG); + + // Total number of large cache entries. + DEFINE_METRIC( + kMetricMemoryCacheNumLargeEntries, facebook::velox::StatType::AVG); // Total number of cache entries that do not cache anything. DEFINE_METRIC( @@ -249,6 +257,10 @@ void registerVeloxMetrics() { // Total number of error while writing to SSD cache files. DEFINE_METRIC(kMetricSsdCacheWriteSsdErrors, facebook::velox::StatType::SUM); + // Total number of errors due to SSD no space for writes. + DEFINE_METRIC( + kMetricSsdCacheWriteNoSpaceErrors, facebook::velox::StatType::SUM); + // Total number of errors while writing SSD checkpoint file. DEFINE_METRIC( kMetricSsdCacheWriteCheckpointErrors, facebook::velox::StatType::SUM); @@ -256,6 +268,10 @@ void registerVeloxMetrics() { // Total number of writes dropped due to no cache space. DEFINE_METRIC(kMetricSsdCacheWriteSsdDropped, facebook::velox::StatType::SUM); + // Total number of writes dropped due to entry limit exceeded. + DEFINE_METRIC( + kMetricSsdCacheWriteExceedEntryLimit, facebook::velox::StatType::SUM); + // Total number of errors while reading from SSD cache files. DEFINE_METRIC(kMetricSsdCacheReadSsdErrors, facebook::velox::StatType::SUM); @@ -365,9 +381,6 @@ void registerVeloxMetrics() { kMetricTaskMemoryReclaimWaitTimeoutCount, facebook::velox::StatType::COUNT); - // Tracks the total number of splits received by all tasks. - DEFINE_METRIC(kMetricTaskSplitsCount, facebook::velox::StatType::COUNT); - // The number of times that the memory reclaim fails because the operator is // executing a non-reclaimable section where it is expected to have reserved // enough memory to execute without asking for more. Therefore, it is an @@ -644,6 +657,10 @@ void registerVeloxMetrics() { DEFINE_HISTOGRAM_METRIC( kMetricIndexLookupBlockedWaitTimeMs, 32, 0, 16L << 10, 50, 90, 99, 100); + // The number of index lookup results with error. + DEFINE_METRIC( + kMetricIndexLookupErrorResultCount, facebook::velox::StatType::COUNT); + /// ================== Table Scan Counters ================= // Tracks the averaged table scan batch processing time in milliseconds. DEFINE_METRIC( diff --git a/velox/common/base/Counters.h b/velox/common/base/Counters.h index a936a871546..3a144f1bc88 100644 --- a/velox/common/base/Counters.h +++ b/velox/common/base/Counters.h @@ -23,373 +23,374 @@ namespace facebook::velox { /// Velox metrics Registration. void registerVeloxMetrics(); -constexpr folly::StringPiece kMetricHiveFileHandleGenerateLatencyMs{ +constexpr std::string_view kMetricHiveFileHandleGenerateLatencyMs{ "velox.hive_file_handle_generate_latency_ms"}; -constexpr folly::StringPiece kMetricCacheShrinkCount{ - "velox.cache_shrink_count"}; +constexpr std::string_view kMetricCacheShrinkCount{"velox.cache_shrink_count"}; -constexpr folly::StringPiece kMetricCacheShrinkTimeMs{"velox.cache_shrink_ms"}; +constexpr std::string_view kMetricCacheShrinkTimeMs{"velox.cache_shrink_ms"}; -constexpr folly::StringPiece kMetricMaxSpillLevelExceededCount{ +constexpr std::string_view kMetricMaxSpillLevelExceededCount{ "velox.spill_max_level_exceeded_count"}; -constexpr folly::StringPiece kMetricQueryMemoryReclaimTimeMs{ +constexpr std::string_view kMetricQueryMemoryReclaimTimeMs{ "velox.query_memory_reclaim_time_ms"}; -constexpr folly::StringPiece kMetricQueryMemoryReclaimedBytes{ +constexpr std::string_view kMetricQueryMemoryReclaimedBytes{ "velox.query_memory_reclaim_bytes"}; -constexpr folly::StringPiece kMetricQueryMemoryReclaimCount{ +constexpr std::string_view kMetricQueryMemoryReclaimCount{ "velox.query_memory_reclaim_count"}; -constexpr folly::StringPiece kMetricTaskMemoryReclaimCount{ +constexpr std::string_view kMetricTaskMemoryReclaimCount{ "velox.task_memory_reclaim_count"}; -constexpr folly::StringPiece kMetricTaskMemoryReclaimWaitTimeMs{ +constexpr std::string_view kMetricTaskMemoryReclaimWaitTimeMs{ "velox.task_memory_reclaim_wait_ms"}; -constexpr folly::StringPiece kMetricTaskMemoryReclaimExecTimeMs{ +constexpr std::string_view kMetricTaskMemoryReclaimExecTimeMs{ "velox.task_memory_reclaim_exec_ms"}; -constexpr folly::StringPiece kMetricTaskMemoryReclaimWaitTimeoutCount{ +constexpr std::string_view kMetricTaskMemoryReclaimWaitTimeoutCount{ "velox.task_memory_reclaim_wait_timeout_count"}; -constexpr folly::StringPiece kMetricTaskSplitsCount{"velox.task_splits_count"}; +constexpr std::string_view kMetricTaskSplitsCount{"velox.task_splits_count"}; -constexpr folly::StringPiece kMetricOpMemoryReclaimTimeMs{ +constexpr std::string_view kMetricOpMemoryReclaimTimeMs{ "velox.op_memory_reclaim_time_ms"}; -constexpr folly::StringPiece kMetricOpMemoryReclaimedBytes{ +constexpr std::string_view kMetricOpMemoryReclaimedBytes{ "velox.op_memory_reclaim_bytes"}; -constexpr folly::StringPiece kMetricOpMemoryReclaimCount{ +constexpr std::string_view kMetricOpMemoryReclaimCount{ "velox.op_memory_reclaim_count"}; -constexpr folly::StringPiece kMetricMemoryNonReclaimableCount{ +constexpr std::string_view kMetricMemoryNonReclaimableCount{ "velox.memory_non_reclaimable_count"}; -constexpr folly::StringPiece kMetricMemoryPoolInitialCapacityBytes{ +constexpr std::string_view kMetricMemoryPoolInitialCapacityBytes{ "velox.memory_pool_initial_capacity_bytes"}; -constexpr folly::StringPiece kMetricMemoryPoolCapacityGrowCount{ +constexpr std::string_view kMetricMemoryPoolCapacityGrowCount{ "velox.memory_pool_capacity_growth_count"}; -constexpr folly::StringPiece kMetricMemoryPoolUsageLeakBytes{ +constexpr std::string_view kMetricMemoryPoolUsageLeakBytes{ "velox.memory_pool_usage_leak_bytes"}; -constexpr folly::StringPiece kMetricMemoryPoolReservationLeakBytes{ +constexpr std::string_view kMetricMemoryPoolReservationLeakBytes{ "velox.memory_pool_reservation_leak_bytes"}; -constexpr folly::StringPiece kMetricMemoryAllocatorDoubleFreeCount{ +constexpr std::string_view kMetricMemoryAllocatorDoubleFreeCount{ "velox.memory_allocator_double_free_count"}; -constexpr folly::StringPiece kMetricArbitratorLocalArbitrationCount{ +constexpr std::string_view kMetricArbitratorLocalArbitrationCount{ "velox.arbitrator_local_arbitration_count"}; -constexpr folly::StringPiece kMetricArbitratorGlobalArbitrationCount{ +constexpr std::string_view kMetricArbitratorGlobalArbitrationCount{ "velox.arbitrator_global_arbitration_count"}; -constexpr folly::StringPiece - kMetricArbitratorGlobalArbitrationNumReclaimVictims{ - "velox.arbitrator_global_arbitration_num_reclaim_victims"}; +constexpr std::string_view kMetricArbitratorGlobalArbitrationNumReclaimVictims{ + "velox.arbitrator_global_arbitration_num_reclaim_victims"}; -constexpr folly::StringPiece - kMetricArbitratorGlobalArbitrationFailedVictimCount{ - "velox.arbitrator_global_arbitration_failed_victim_count"}; +constexpr std::string_view kMetricArbitratorGlobalArbitrationFailedVictimCount{ + "velox.arbitrator_global_arbitration_failed_victim_count"}; -constexpr folly::StringPiece kMetricArbitratorGlobalArbitrationBytes{ +constexpr std::string_view kMetricArbitratorGlobalArbitrationBytes{ "velox.arbitrator_global_arbitration_bytes"}; -constexpr folly::StringPiece kMetricArbitratorGlobalArbitrationTimeMs{ +constexpr std::string_view kMetricArbitratorGlobalArbitrationTimeMs{ "velox.arbitrator_global_arbitration_time_ms"}; -constexpr folly::StringPiece kMetricArbitratorGlobalArbitrationWaitCount{ +constexpr std::string_view kMetricArbitratorGlobalArbitrationWaitCount{ "velox.arbitrator_global_arbitration_wait_count"}; -constexpr folly::StringPiece kMetricArbitratorGlobalArbitrationWaitTimeMs{ +constexpr std::string_view kMetricArbitratorGlobalArbitrationWaitTimeMs{ "velox.arbitrator_global_arbitration_wait_time_ms"}; -constexpr folly::StringPiece kMetricArbitratorAbortedCount{ +constexpr std::string_view kMetricArbitratorAbortedCount{ "velox.arbitrator_aborted_count"}; -constexpr folly::StringPiece kMetricArbitratorFailuresCount{ +constexpr std::string_view kMetricArbitratorFailuresCount{ "velox.arbitrator_failures_count"}; -constexpr folly::StringPiece kMetricArbitratorOpExecTimeMs{ +constexpr std::string_view kMetricArbitratorOpExecTimeMs{ "velox.arbitrator_op_exec_time_ms"}; -constexpr folly::StringPiece kMetricArbitratorFreeCapacityBytes{ +constexpr std::string_view kMetricArbitratorFreeCapacityBytes{ "velox.arbitrator_free_capacity_bytes"}; -constexpr folly::StringPiece kMetricArbitratorFreeReservedCapacityBytes{ +constexpr std::string_view kMetricArbitratorFreeReservedCapacityBytes{ "velox.arbitrator_free_reserved_capacity_bytes"}; -constexpr folly::StringPiece kMetricDriverYieldCount{ - "velox.driver_yield_count"}; +constexpr std::string_view kMetricDriverYieldCount{"velox.driver_yield_count"}; -constexpr folly::StringPiece kMetricDriverQueueTimeMs{ +constexpr std::string_view kMetricDriverQueueTimeMs{ "velox.driver_queue_time_ms"}; -constexpr folly::StringPiece kMetricDriverExecTimeMs{ - "velox.driver_exec_time_ms"}; +constexpr std::string_view kMetricDriverExecTimeMs{"velox.driver_exec_time_ms"}; -constexpr folly::StringPiece kMetricSpilledInputBytes{ - "velox.spill_input_bytes"}; +constexpr std::string_view kMetricSpilledInputBytes{"velox.spill_input_bytes"}; -constexpr folly::StringPiece kMetricSpilledBytes{"velox.spill_bytes"}; +constexpr std::string_view kMetricSpilledBytes{"velox.spill_bytes"}; -constexpr folly::StringPiece kMetricSpilledRowsCount{"velox.spill_rows_count"}; +constexpr std::string_view kMetricSpilledRowsCount{"velox.spill_rows_count"}; -constexpr folly::StringPiece kMetricSpilledFilesCount{ - "velox.spill_files_count"}; +constexpr std::string_view kMetricSpilledFilesCount{"velox.spill_files_count"}; -constexpr folly::StringPiece kMetricSpillFillTimeMs{"velox.spill_fill_time_ms"}; +constexpr std::string_view kMetricSpillFillTimeMs{"velox.spill_fill_time_ms"}; -constexpr folly::StringPiece kMetricSpillSortTimeMs{"velox.spill_sort_time_ms"}; +constexpr std::string_view kMetricSpillSortTimeMs{"velox.spill_sort_time_ms"}; -constexpr folly::StringPiece kMetricSpillExtractVectorTimeMs{ +constexpr std::string_view kMetricSpillExtractVectorTimeMs{ "velox.spill_extract_vector_time_ms"}; -constexpr folly::StringPiece kMetricSpillSerializationTimeMs{ +constexpr std::string_view kMetricSpillSerializationTimeMs{ "velox.spill_serialization_time_ms"}; -constexpr folly::StringPiece kMetricSpillWritesCount{ - "velox.spill_writes_count"}; +constexpr std::string_view kMetricSpillWritesCount{"velox.spill_writes_count"}; -constexpr folly::StringPiece kMetricSpillFlushTimeMs{ - "velox.spill_flush_time_ms"}; +constexpr std::string_view kMetricSpillFlushTimeMs{"velox.spill_flush_time_ms"}; -constexpr folly::StringPiece kMetricSpillWriteTimeMs{ - "velox.spill_write_time_ms"}; +constexpr std::string_view kMetricSpillWriteTimeMs{"velox.spill_write_time_ms"}; -constexpr folly::StringPiece kMetricSpillMemoryBytes{ - "velox.spill_memory_bytes"}; +constexpr std::string_view kMetricSpillMemoryBytes{"velox.spill_memory_bytes"}; -constexpr folly::StringPiece kMetricSpillPeakMemoryBytes{ +constexpr std::string_view kMetricSpillPeakMemoryBytes{ "velox.spill_peak_memory_bytes"}; -constexpr folly::StringPiece kMetricFileWriterEarlyFlushedRawBytes{ +constexpr std::string_view kMetricFileWriterEarlyFlushedRawBytes{ "velox.file_writer_early_flushed_raw_bytes"}; -constexpr folly::StringPiece kMetricHiveSortWriterFinishTimeMs{ +constexpr std::string_view kMetricHiveSortWriterFinishTimeMs{ "velox.hive_sort_writer_finish_time_ms"}; -constexpr folly::StringPiece kMetricArbitratorRequestsCount{ +constexpr std::string_view kMetricArbitratorRequestsCount{ "velox.arbitrator_requests_count"}; -constexpr folly::StringPiece kMetricMemoryAllocatorMappedBytes{ +constexpr std::string_view kMetricMemoryAllocatorMappedBytes{ "velox.memory_allocator_mapped_bytes"}; -constexpr folly::StringPiece kMetricMemoryAllocatorExternalMappedBytes{ +constexpr std::string_view kMetricMemoryAllocatorExternalMappedBytes{ "velox.memory_allocator_external_mapped_bytes"}; -constexpr folly::StringPiece kMetricMemoryAllocatorAllocatedBytes{ +constexpr std::string_view kMetricMemoryAllocatorAllocatedBytes{ "velox.memory_allocator_allocated_bytes"}; -constexpr folly::StringPiece kMetricMemoryAllocatorTotalUsedBytes{ +constexpr std::string_view kMetricMemoryAllocatorTotalUsedBytes{ "velox.memory_allocator_total_used_bytes"}; -constexpr folly::StringPiece kMetricMmapAllocatorDelegatedAllocatedBytes{ +constexpr std::string_view kMetricMmapAllocatorDelegatedAllocatedBytes{ "velox.mmap_allocator_delegated_allocated_bytes"}; -constexpr folly::StringPiece kMetricCacheMaxAgeSecs{"velox.cache_max_age_secs"}; +constexpr std::string_view kMetricCacheMaxAgeSecs{"velox.cache_max_age_secs"}; -constexpr folly::StringPiece kMetricMemoryCacheNumEntries{ - "velox.memory_cache_num_entries"}; +constexpr std::string_view kMetricMemoryCacheNumTinyEntries{ + "velox.memory_cache_num_tiny_entries"}; -constexpr folly::StringPiece kMetricMemoryCacheNumEmptyEntries{ +constexpr std::string_view kMetricMemoryCacheNumLargeEntries{ + "velox.memory_cache_num_large_entries"}; + +constexpr std::string_view kMetricMemoryCacheNumEmptyEntries{ "velox.memory_cache_num_empty_entries"}; -constexpr folly::StringPiece kMetricMemoryCacheNumSharedEntries{ +constexpr std::string_view kMetricMemoryCacheNumSharedEntries{ "velox.memory_cache_num_shared_entries"}; -constexpr folly::StringPiece kMetricMemoryCacheNumExclusiveEntries{ +constexpr std::string_view kMetricMemoryCacheNumExclusiveEntries{ "velox.memory_cache_num_exclusive_entries"}; -constexpr folly::StringPiece kMetricMemoryCacheNumPrefetchedEntries{ +constexpr std::string_view kMetricMemoryCacheNumPrefetchedEntries{ "velox.memory_cache_num_prefetched_entries"}; -constexpr folly::StringPiece kMetricMemoryCacheTotalTinyBytes{ +constexpr std::string_view kMetricMemoryCacheTotalTinyBytes{ "velox.memory_cache_total_tiny_bytes"}; -constexpr folly::StringPiece kMetricMemoryCacheTotalLargeBytes{ +constexpr std::string_view kMetricMemoryCacheTotalLargeBytes{ "velox.memory_cache_total_large_bytes"}; -constexpr folly::StringPiece kMetricMemoryCacheTotalTinyPaddingBytes{ +constexpr std::string_view kMetricMemoryCacheTotalTinyPaddingBytes{ "velox.memory_cache_total_tiny_padding_bytes"}; -constexpr folly::StringPiece kMetricMemoryCacheTotalLargePaddingBytes{ +constexpr std::string_view kMetricMemoryCacheTotalLargePaddingBytes{ "velox.memory_cache_total_large_padding_bytes"}; -constexpr folly::StringPiece kMetricMemoryCacheTotalPrefetchBytes{ +constexpr std::string_view kMetricMemoryCacheTotalPrefetchBytes{ "velox.memory_cache_total_prefetched_bytes"}; -constexpr folly::StringPiece kMetricMemoryCacheSumEvictScore{ +constexpr std::string_view kMetricMemoryCacheSumEvictScore{ "velox.memory_cache_sum_evict_score"}; -constexpr folly::StringPiece kMetricMemoryCacheNumHits{ +constexpr std::string_view kMetricMemoryCacheNumHits{ "velox.memory_cache_num_hits"}; -constexpr folly::StringPiece kMetricMemoryCacheHitBytes{ +constexpr std::string_view kMetricMemoryCacheHitBytes{ "velox.memory_cache_hit_bytes"}; -constexpr folly::StringPiece kMetricMemoryCacheNumNew{ +constexpr std::string_view kMetricMemoryCacheNumNew{ "velox.memory_cache_num_new"}; -constexpr folly::StringPiece kMetricMemoryCacheNumEvicts{ +constexpr std::string_view kMetricMemoryCacheNumEvicts{ "velox.memory_cache_num_evicts"}; -constexpr folly::StringPiece kMetricMemoryCacheNumSavableEvicts{ +constexpr std::string_view kMetricMemoryCacheNumSavableEvicts{ "velox.memory_cache_num_savable_evicts"}; -constexpr folly::StringPiece kMetricMemoryCacheNumEvictChecks{ +constexpr std::string_view kMetricMemoryCacheNumEvictChecks{ "velox.memory_cache_num_evict_checks"}; -constexpr folly::StringPiece kMetricMemoryCacheNumWaitExclusive{ +constexpr std::string_view kMetricMemoryCacheNumWaitExclusive{ "velox.memory_cache_num_wait_exclusive"}; -constexpr folly::StringPiece kMetricMemoryCacheNumAllocClocks{ +constexpr std::string_view kMetricMemoryCacheNumAllocClocks{ "velox.memory_cache_num_alloc_clocks"}; -constexpr folly::StringPiece kMetricMemoryCacheNumAgedOutEntries{ +constexpr std::string_view kMetricMemoryCacheNumAgedOutEntries{ "velox.memory_cache_num_aged_out_entries"}; -constexpr folly::StringPiece kMetricMemoryCacheNumStaleEntries{ +constexpr std::string_view kMetricMemoryCacheNumStaleEntries{ "velox.memory_cache_num_stale_entries"}; -constexpr folly::StringPiece kMetricSsdCacheCachedRegions{ +constexpr std::string_view kMetricSsdCacheCachedRegions{ "velox.ssd_cache_cached_regions"}; -constexpr folly::StringPiece kMetricSsdCacheCachedEntries{ +constexpr std::string_view kMetricSsdCacheCachedEntries{ "velox.ssd_cache_cached_entries"}; -constexpr folly::StringPiece kMetricSsdCacheCachedBytes{ +constexpr std::string_view kMetricSsdCacheCachedBytes{ "velox.ssd_cache_cached_bytes"}; -constexpr folly::StringPiece kMetricSsdCacheReadEntries{ +constexpr std::string_view kMetricSsdCacheReadEntries{ "velox.ssd_cache_read_entries"}; -constexpr folly::StringPiece kMetricSsdCacheReadBytes{ +constexpr std::string_view kMetricSsdCacheReadBytes{ "velox.ssd_cache_read_bytes"}; -constexpr folly::StringPiece kMetricSsdCacheWrittenEntries{ +constexpr std::string_view kMetricSsdCacheWrittenEntries{ "velox.ssd_cache_written_entries"}; -constexpr folly::StringPiece kMetricSsdCacheWrittenBytes{ +constexpr std::string_view kMetricSsdCacheWrittenBytes{ "velox.ssd_cache_written_bytes"}; -constexpr folly::StringPiece kMetricSsdCacheAgedOutEntries{ +constexpr std::string_view kMetricSsdCacheAgedOutEntries{ "velox.ssd_cache_aged_out_entries"}; -constexpr folly::StringPiece kMetricSsdCacheAgedOutRegions{ +constexpr std::string_view kMetricSsdCacheAgedOutRegions{ "velox.ssd_cache_aged_out_regions"}; -constexpr folly::StringPiece kMetricSsdCacheOpenSsdErrors{ +constexpr std::string_view kMetricSsdCacheOpenSsdErrors{ "velox.ssd_cache_open_ssd_errors"}; -constexpr folly::StringPiece kMetricSsdCacheOpenCheckpointErrors{ +constexpr std::string_view kMetricSsdCacheOpenCheckpointErrors{ "velox.ssd_cache_open_checkpoint_errors"}; -constexpr folly::StringPiece kMetricSsdCacheOpenLogErrors{ +constexpr std::string_view kMetricSsdCacheOpenLogErrors{ "velox.ssd_cache_open_log_errors"}; -constexpr folly::StringPiece kMetricSsdCacheMetaFileDeleteErrors{ +constexpr std::string_view kMetricSsdCacheMetaFileDeleteErrors{ "velox.ssd_cache_delete_meta_file_errors"}; -constexpr folly::StringPiece kMetricSsdCacheGrowFileErrors{ +constexpr std::string_view kMetricSsdCacheGrowFileErrors{ "velox.ssd_cache_grow_file_errors"}; -constexpr folly::StringPiece kMetricSsdCacheWriteSsdErrors{ +constexpr std::string_view kMetricSsdCacheWriteSsdErrors{ "velox.ssd_cache_write_ssd_errors"}; -constexpr folly::StringPiece kMetricSsdCacheWriteSsdDropped{ +constexpr std::string_view kMetricSsdCacheWriteNoSpaceErrors{ + "velox.ssd_cache_write_no_space_errors"}; + +constexpr std::string_view kMetricSsdCacheWriteSsdDropped{ "velox.ssd_cache_write_ssd_dropped"}; -constexpr folly::StringPiece kMetricSsdCacheWriteCheckpointErrors{ +constexpr std::string_view kMetricSsdCacheWriteExceedEntryLimit{ + "velox.ssd_cache_write_exceed_entry_limit"}; + +constexpr std::string_view kMetricSsdCacheWriteCheckpointErrors{ "velox.ssd_cache_write_checkpoint_errors"}; -constexpr folly::StringPiece kMetricSsdCacheReadCorruptions{ +constexpr std::string_view kMetricSsdCacheReadCorruptions{ "velox.ssd_cache_read_corruptions"}; -constexpr folly::StringPiece kMetricSsdCacheReadSsdErrors{ +constexpr std::string_view kMetricSsdCacheReadSsdErrors{ "velox.ssd_cache_read_ssd_errors"}; -constexpr folly::StringPiece kMetricSsdCacheReadCheckpointErrors{ +constexpr std::string_view kMetricSsdCacheReadCheckpointErrors{ "velox.ssd_cache_read_checkpoint_errors"}; -constexpr folly::StringPiece kMetricSsdCacheReadWithoutChecksum{ +constexpr std::string_view kMetricSsdCacheReadWithoutChecksum{ "velox.ssd_cache_read_without_checksum"}; -constexpr folly::StringPiece kMetricSsdCacheCheckpointsRead{ +constexpr std::string_view kMetricSsdCacheCheckpointsRead{ "velox.ssd_cache_checkpoints_read"}; -constexpr folly::StringPiece kMetricSsdCacheCheckpointsWritten{ +constexpr std::string_view kMetricSsdCacheCheckpointsWritten{ "velox.ssd_cache_checkpoints_written"}; -constexpr folly::StringPiece kMetricSsdCacheRegionsEvicted{ +constexpr std::string_view kMetricSsdCacheRegionsEvicted{ "velox.ssd_cache_regions_evicted"}; -constexpr folly::StringPiece kMetricSsdCacheRecoveredEntries{ +constexpr std::string_view kMetricSsdCacheRecoveredEntries{ "velox.ssd_cache_recovered_entries"}; -constexpr folly::StringPiece kMetricExchangeTransactionCreateDelay{ +constexpr std::string_view kMetricExchangeTransactionCreateDelay{ "velox.exchange.transaction_create_delay_ms"}; -constexpr folly::StringPiece kMetricExchangeDataTimeMs{ +constexpr std::string_view kMetricExchangeDataTimeMs{ "velox.exchange_data_time_ms"}; -constexpr folly::StringPiece kMetricExchangeDataBytes{ +constexpr std::string_view kMetricExchangeDataBytes{ "velox.exchange_data_bytes"}; -constexpr folly::StringPiece kMetricExchangeDataSize{ - "velox.exchange_data_size"}; +constexpr std::string_view kMetricExchangeDataSize{"velox.exchange_data_size"}; -constexpr folly::StringPiece kMetricExchangeDataCount{ +constexpr std::string_view kMetricExchangeDataCount{ "velox.exchange_data_count"}; -constexpr folly::StringPiece kMetricExchangeDataSizeTimeMs{ +constexpr std::string_view kMetricExchangeDataSizeTimeMs{ "velox.exchange_data_size_time_ms"}; -constexpr folly::StringPiece kMetricExchangeDataSizeCount{ +constexpr std::string_view kMetricExchangeDataSizeCount{ "velox.exchange_data_size_count"}; -constexpr folly::StringPiece kMetricStorageThrottledDurationMs{ +constexpr std::string_view kMetricStorageThrottledDurationMs{ "velox.storage_throttled_duration_ms"}; -constexpr folly::StringPiece kMetricStorageLocalThrottled{ +constexpr std::string_view kMetricStorageLocalThrottled{ "velox.storage_local_throttled_count"}; -constexpr folly::StringPiece kMetricStorageGlobalThrottled{ +constexpr std::string_view kMetricStorageGlobalThrottled{ "velox.storage_global_throttled_count"}; -constexpr folly::StringPiece kMetricStorageNetworkThrottled{ +constexpr std::string_view kMetricStorageNetworkThrottled{ "velox.storage_network_throttled_count"}; -constexpr folly::StringPiece kMetricIndexLookupResultRawBytes{ +constexpr std::string_view kMetricIndexLookupResultRawBytes{ "velox.index_lookup_result_raw_bytes"}; -constexpr folly::StringPiece kMetricIndexLookupResultBytes{ +constexpr std::string_view kMetricIndexLookupResultBytes{ "velox.index_lookup_result_bytes"}; -constexpr folly::StringPiece kMetricIndexLookupTimeMs{ +constexpr std::string_view kMetricIndexLookupTimeMs{ "velox.index_lookup_time_ms"}; -constexpr folly::StringPiece kMetricIndexLookupWaitTimeMs{ +constexpr std::string_view kMetricIndexLookupWaitTimeMs{ "velox.index_lookup_wait_time_ms"}; -constexpr folly::StringPiece kMetricIndexLookupBlockedWaitTimeMs{ +constexpr std::string_view kMetricIndexLookupBlockedWaitTimeMs{ "velox.index_lookup_blocked_wait_time_ms"}; -constexpr folly::StringPiece kMetricTableScanBatchProcessTimeMs{ +constexpr std::string_view kMetricIndexLookupErrorResultCount{ + "velox.index_lookup_error_result_count"}; + +constexpr std::string_view kMetricTableScanBatchProcessTimeMs{ "velox.table_scan_batch_process_time_ms"}; -constexpr folly::StringPiece kMetricTableScanBatchBytes{ +constexpr std::string_view kMetricTableScanBatchBytes{ "velox.table_scan_batch_bytes"}; -constexpr folly::StringPiece kMetricTaskBatchProcessTimeMs{ +constexpr std::string_view kMetricTaskBatchProcessTimeMs{ "velox.task_batch_process_time_ms"}; -constexpr folly::StringPiece kMetricTaskBarrierProcessTimeMs{ +constexpr std::string_view kMetricTaskBarrierProcessTimeMs{ "velox.task_barrier_process_time_ms"}; + } // namespace facebook::velox diff --git a/velox/common/base/ExceptionHelper.h b/velox/common/base/ExceptionHelper.h index 193641ce63a..ad9b21c9471 100644 --- a/velox/common/base/ExceptionHelper.h +++ b/velox/common/base/ExceptionHelper.h @@ -34,6 +34,24 @@ struct CompileTimeEmptyString { } }; +/// Wraps a const char* that must originate from a string literal (or other +/// compile-time constant). Provides a distinct type to resolve overload +/// ambiguity with std::string_view parameters in exception constructors. +class CompileTimeStringLiteral { + const char* data_; + + public: + /* implicit */ constexpr CompileTimeStringLiteral(const char* data) + : data_{data} {} + + constexpr operator const char*() const { + return data_; + } + constexpr operator std::string_view() const { + return data_ ? std::string_view(data_) : std::string_view(); + } +}; + // When there is no message passed, we can statically detect this case // and avoid passing even a single unnecessary argument pointer, // minimizing size and thus maximizing eligibility for inlining. diff --git a/velox/common/base/Exceptions.h b/velox/common/base/Exceptions.h index 4485b8b006a..92be6d30989 100644 --- a/velox/common/base/Exceptions.h +++ b/velox/common/base/Exceptions.h @@ -44,6 +44,8 @@ struct VeloxCheckFailArgs { // inline it when it is large. Having an out-of-line error path helps // otherwise-small functions that call error-checking macros stay // small and thus stay eligible for inlining. + +// Overload without messageTemplate: message IS the template (0 or 1 args). template [[noreturn]] void veloxCheckFail(const VeloxCheckFailArgs& args, StringType s) { static_assert( @@ -69,6 +71,39 @@ template args.isRetriable); } +// Overload with messageTemplate: template is the format string before +// interpolation (>=2 args, format string + arguments). The template should +// be a string literal; for runtime format strings, use the single-arg +// VELOX_FAIL("{}", runtimeStr) pattern instead. +template +[[noreturn]] void veloxCheckFail( + const VeloxCheckFailArgs& args, + StringType s, + CompileTimeStringLiteral messageTemplate) { + static_assert( + !std::is_same_v, + "BUG: we should not pass std::string by value to veloxCheckFail"); + if constexpr (!std::is_same_v) { + LOG(ERROR) << "Line: " << args.file << ":" << args.line + << ", Function:" << args.function + << ", Expression: " << args.expression << " " << s + << ", Source: " << args.errorSource + << ", ErrorCode: " << args.errorCode; + } + + ++threadNumVeloxThrow(); + throw Exception( + args.file, + args.line, + args.function, + args.expression, + s, + args.errorSource, + args.errorCode, + args.isRetriable, + messageTemplate); +} + // VeloxCheckFailStringType helps us pass by reference to // veloxCheckFail exactly when the string type is std::string. template @@ -106,21 +141,99 @@ struct VeloxCheckFailStringType { extern template void veloxCheckFail( \ const VeloxCheckFailArgs& args, \ const std::string&); \ + extern template void veloxCheckFail( \ + const VeloxCheckFailArgs& args, \ + CompileTimeEmptyString, \ + CompileTimeStringLiteral); \ + extern template void veloxCheckFail( \ + const VeloxCheckFailArgs& args, \ + const char*, \ + CompileTimeStringLiteral); \ + extern template void veloxCheckFail( \ + const VeloxCheckFailArgs& args, \ + const std::string&, \ + CompileTimeStringLiteral); \ } // namespace detail // Definitions corresponding to DECLARE_CHECK_FAIL_TEMPLATES. Should // only be used in Exceptions.cpp. -#define DEFINE_CHECK_FAIL_TEMPLATES(exception_type) \ - template void veloxCheckFail( \ - const VeloxCheckFailArgs& args, CompileTimeEmptyString); \ - template void veloxCheckFail( \ - const VeloxCheckFailArgs& args, const char*); \ - template void veloxCheckFail( \ - const VeloxCheckFailArgs& args, const std::string&); +#define DEFINE_CHECK_FAIL_TEMPLATES(exception_type) \ + template void veloxCheckFail( \ + const VeloxCheckFailArgs& args, CompileTimeEmptyString); \ + template void veloxCheckFail( \ + const VeloxCheckFailArgs& args, const char*); \ + template void veloxCheckFail( \ + const VeloxCheckFailArgs& args, const std::string&); \ + template void veloxCheckFail( \ + const VeloxCheckFailArgs& args, \ + CompileTimeEmptyString, \ + CompileTimeStringLiteral); \ + template void veloxCheckFail( \ + const VeloxCheckFailArgs& args, const char*, CompileTimeStringLiteral); \ + template void veloxCheckFail( \ + const VeloxCheckFailArgs& args, \ + const std::string&, \ + CompileTimeStringLiteral); } // namespace detail -#define _VELOX_THROW_IMPL( \ +// Macro arg-count detection for dispatching between the no-template +// and with-template overloads of veloxCheckFail. +// 0 or 1 args: message IS the template (no-template overload). +// >=2 args: first arg is the format string template (with-template overload). +#define _VELOX_NARGS_IMPL( \ + _0, \ + _1, \ + _2, \ + _3, \ + _4, \ + _5, \ + _6, \ + _7, \ + _8, \ + _9, \ + _10, \ + _11, \ + _12, \ + _13, \ + _14, \ + _15, \ + _16, \ + N, \ + ...) \ + N +#define _VELOX_NARGS(...) \ + _VELOX_NARGS_IMPL( \ + dummy, \ + ##__VA_ARGS__, \ + 16, \ + 15, \ + 14, \ + 13, \ + 12, \ + 11, \ + 10, \ + 9, \ + 8, \ + 7, \ + 6, \ + 5, \ + 4, \ + 3, \ + 2, \ + 1, \ + 0) + +// Extract the first argument from __VA_ARGS__ (the format string) as a +// CompileTimeStringLiteral. Only used for >=2 args path. +#define _VELOX_MSG_TEMPLATE_PICK(_1, _2, ...) _2 +#define _VELOX_MSG_TEMPLATE(...) \ + ::facebook::velox::CompileTimeStringLiteral( \ + _VELOX_MSG_TEMPLATE_PICK("", ##__VA_ARGS__, "")) + +// _VELOX_THROW_IMPL dispatches to the no-template or with-template path +// based on the number of user-supplied message arguments. +#define _VELOX_THROW_IMPL_BODY_NO_TEMPLATE( \ exception, exprStr, errorSource, errorCode, isRetriable, ...) \ do { \ /* GCC 9.2.1 doesn't accept this code with constexpr. */ \ @@ -140,6 +253,54 @@ struct VeloxCheckFailStringType { decltype(message)>::type>(veloxCheckFailArgs, message); \ } while (0) +#define _VELOX_THROW_IMPL_BODY_WITH_TEMPLATE( \ + exception, exprStr, errorSource, errorCode, isRetriable, ...) \ + do { \ + /* GCC 9.2.1 doesn't accept this code with constexpr. */ \ + static const ::facebook::velox::detail::VeloxCheckFailArgs \ + veloxCheckFailArgs = { \ + __FILE__, \ + __LINE__, \ + __FUNCTION__, \ + exprStr, \ + errorSource, \ + errorCode, \ + isRetriable}; \ + auto message = ::facebook::velox::errorMessage(__VA_ARGS__); \ + ::facebook::velox::detail::veloxCheckFail< \ + exception, \ + typename ::facebook::velox::detail::VeloxCheckFailStringType< \ + decltype(message)>::type>( \ + veloxCheckFailArgs, message, _VELOX_MSG_TEMPLATE(__VA_ARGS__)); \ + } while (0) + +#define _VELOX_THROW_DISPATCH_0 _VELOX_THROW_IMPL_BODY_NO_TEMPLATE +#define _VELOX_THROW_DISPATCH_1 _VELOX_THROW_IMPL_BODY_NO_TEMPLATE +#define _VELOX_THROW_DISPATCH_2 _VELOX_THROW_IMPL_BODY_WITH_TEMPLATE +#define _VELOX_THROW_DISPATCH_3 _VELOX_THROW_IMPL_BODY_WITH_TEMPLATE +#define _VELOX_THROW_DISPATCH_4 _VELOX_THROW_IMPL_BODY_WITH_TEMPLATE +#define _VELOX_THROW_DISPATCH_5 _VELOX_THROW_IMPL_BODY_WITH_TEMPLATE +#define _VELOX_THROW_DISPATCH_6 _VELOX_THROW_IMPL_BODY_WITH_TEMPLATE +#define _VELOX_THROW_DISPATCH_7 _VELOX_THROW_IMPL_BODY_WITH_TEMPLATE +#define _VELOX_THROW_DISPATCH_8 _VELOX_THROW_IMPL_BODY_WITH_TEMPLATE +#define _VELOX_THROW_DISPATCH_9 _VELOX_THROW_IMPL_BODY_WITH_TEMPLATE +#define _VELOX_THROW_DISPATCH_10 _VELOX_THROW_IMPL_BODY_WITH_TEMPLATE +#define _VELOX_THROW_DISPATCH_11 _VELOX_THROW_IMPL_BODY_WITH_TEMPLATE +#define _VELOX_THROW_DISPATCH_12 _VELOX_THROW_IMPL_BODY_WITH_TEMPLATE +#define _VELOX_THROW_DISPATCH_13 _VELOX_THROW_IMPL_BODY_WITH_TEMPLATE +#define _VELOX_THROW_DISPATCH_14 _VELOX_THROW_IMPL_BODY_WITH_TEMPLATE +#define _VELOX_THROW_DISPATCH_15 _VELOX_THROW_IMPL_BODY_WITH_TEMPLATE +#define _VELOX_THROW_DISPATCH_16 _VELOX_THROW_IMPL_BODY_WITH_TEMPLATE + +#define _VELOX_THROW_CONCAT2(a, b) a##b +#define _VELOX_THROW_CONCAT(a, b) _VELOX_THROW_CONCAT2(a, b) +#define _VELOX_THROW_SELECT(n) _VELOX_THROW_CONCAT(_VELOX_THROW_DISPATCH_, n) + +#define _VELOX_THROW_IMPL( \ + exception, exprStr, errorSource, errorCode, isRetriable, ...) \ + _VELOX_THROW_SELECT(_VELOX_NARGS(__VA_ARGS__))( \ + exception, exprStr, errorSource, errorCode, isRetriable, ##__VA_ARGS__) + #define _VELOX_CHECK_AND_THROW_IMPL( \ expr, exprStr, exception, errorSource, errorCode, isRetriable, ...) \ do { \ @@ -334,6 +495,14 @@ DECLARE_CHECK_FAIL_TEMPLATES(::facebook::velox::VeloxRuntimeError) /* isRetriable */ false, \ ##__VA_ARGS__) +#define VELOX_TRACE_LIMIT_EXCEEDED(...) \ + _VELOX_THROW( \ + ::facebook::velox::VeloxRuntimeError, \ + ::facebook::velox::error_source::kErrorSourceRuntime.c_str(), \ + ::facebook::velox::error_code::kTraceLimitExceeded.c_str(), \ + /* isRetriable */ true, \ + ##__VA_ARGS__) + DECLARE_CHECK_FAIL_TEMPLATES(::facebook::velox::VeloxUserError) // For all below macros, an additional message can be passed using a diff --git a/velox/common/base/Macros.h b/velox/common/base/Macros.h index 62664a488f5..2af0303c423 100644 --- a/velox/common/base/Macros.h +++ b/velox/common/base/Macros.h @@ -28,6 +28,21 @@ _Pragma("GCC diagnostic pop"); #endif +// Disable deprecated-declarations for Clang and GCC +#ifdef __clang__ +#define VELOX_SUPPRESS_DEPRECATED_WARNING \ + _Pragma("clang diagnostic push"); \ + _Pragma("clang diagnostic ignored \"-Wdeprecated-declarations\"") + +#define VELOX_UNSUPPRESS_DEPRECATED_WARNING _Pragma("clang diagnostic pop") +#else +#define VELOX_SUPPRESS_DEPRECATED_WARNING \ + _Pragma("GCC diagnostic push"); \ + _Pragma("GCC diagnostic ignored \"-Wdeprecated-declarations\"") + +#define VELOX_UNSUPPRESS_DEPRECATED_WARNING _Pragma("GCC diagnostic pop") +#endif + #define VELOX_CONCAT(x, y) x##y // Need this extra layer to expand __COUNTER__. #define VELOX_VARNAME_IMPL(x, y) VELOX_CONCAT(x, y) diff --git a/velox/common/base/PeriodicStatsReporter.cpp b/velox/common/base/PeriodicStatsReporter.cpp index 8afbabd59c3..fb07906dcfe 100644 --- a/velox/common/base/PeriodicStatsReporter.cpp +++ b/velox/common/base/PeriodicStatsReporter.cpp @@ -139,7 +139,10 @@ void PeriodicStatsReporter::reportCacheStats() { const auto cacheStats = cache_->refreshStats(); // Memory cache snapshot stats. - RECORD_METRIC_VALUE(kMetricMemoryCacheNumEntries, cacheStats.numEntries); + RECORD_METRIC_VALUE( + kMetricMemoryCacheNumTinyEntries, cacheStats.numTinyEntries); + RECORD_METRIC_VALUE( + kMetricMemoryCacheNumLargeEntries, cacheStats.numLargeEntries); RECORD_METRIC_VALUE( kMetricMemoryCacheNumEmptyEntries, cacheStats.numEmptyEntries); RECORD_METRIC_VALUE(kMetricMemoryCacheNumSharedEntries, cacheStats.numShared); @@ -208,8 +211,13 @@ void PeriodicStatsReporter::reportCacheStats() { kMetricSsdCacheGrowFileErrors, deltaSsdStats.growFileErrors); REPORT_IF_NOT_ZERO( kMetricSsdCacheWriteSsdErrors, deltaSsdStats.writeSsdErrors); + REPORT_IF_NOT_ZERO( + kMetricSsdCacheWriteNoSpaceErrors, deltaSsdStats.writeSsdNoSpaceErrors); REPORT_IF_NOT_ZERO( kMetricSsdCacheWriteSsdDropped, deltaSsdStats.writeSsdDropped); + REPORT_IF_NOT_ZERO( + kMetricSsdCacheWriteExceedEntryLimit, + deltaSsdStats.writeSsdExceedEntryLimit); REPORT_IF_NOT_ZERO( kMetricSsdCacheWriteCheckpointErrors, deltaSsdStats.writeCheckpointErrors); diff --git a/velox/common/base/Portability.h b/velox/common/base/Portability.h index 60049fcc54c..bd8daf71e77 100644 --- a/velox/common/base/Portability.h +++ b/velox/common/base/Portability.h @@ -19,7 +19,7 @@ #include #include #include -#include +#include inline size_t count_trailing_zeros(uint64_t x) { return x == 0 ? 64 : __builtin_ctzll(x); @@ -88,15 +88,4 @@ using tsan_lock_guard = TsanEmptyLockGuard; #endif -template -inline void resizeTsanAtomic( - std::vector>& vector, - int32_t newSize) { - std::vector> newVector(newSize); - auto numCopy = std::min(newSize, vector.size()); - for (auto i = 0; i < numCopy; ++i) { - newVector[i] = tsanAtomicValue(vector[i]); - } - vector = std::move(newVector); -} } // namespace facebook::velox diff --git a/velox/common/base/RuntimeMetrics.cpp b/velox/common/base/RuntimeMetrics.cpp index 3130cbcabef..f72b396ce91 100644 --- a/velox/common/base/RuntimeMetrics.cpp +++ b/velox/common/base/RuntimeMetrics.cpp @@ -30,7 +30,7 @@ void RuntimeMetric::addValue(int64_t value) { } void RuntimeMetric::aggregate() { - count = std::min(count, static_cast(1)); + count = std::min(count, static_cast(1)); min = max = sum; } @@ -113,7 +113,7 @@ BaseRuntimeStatWriter* getThreadLocalRunTimeStatWriter() { } void addThreadLocalRuntimeStat( - const std::string& name, + std::string_view name, const RuntimeCounter& value) { if (localRuntimeStatWriter) { localRuntimeStatWriter->addRuntimeStat(name, value); diff --git a/velox/common/base/RuntimeMetrics.h b/velox/common/base/RuntimeMetrics.h index 5b699a41a99..579e54a1435 100644 --- a/velox/common/base/RuntimeMetrics.h +++ b/velox/common/base/RuntimeMetrics.h @@ -19,9 +19,17 @@ #include #include #include +#include namespace facebook::velox { +/// Converts unsigned bigint to signed, capping at int64_t max if overflow +/// happens. Could be replaced by 'std::saturate_cast' since C++26. +inline int64_t saturateCast(uint64_t value) { + return static_cast(std::min( + value, static_cast(std::numeric_limits::max()))); +} + struct RuntimeCounter { enum class Unit { kNone, kNanos, kBytes }; int64_t value; @@ -35,7 +43,7 @@ struct RuntimeMetric { // Sum, min, max have the same unit, count has kNone. RuntimeCounter::Unit unit; int64_t sum{0}; - int64_t count{0}; + uint64_t count{0}; int64_t min{std::numeric_limits::max()}; int64_t max{std::numeric_limits::min()}; @@ -48,6 +56,14 @@ struct RuntimeMetric { RuntimeCounter::Unit _unit = RuntimeCounter::Unit::kNone) : unit(_unit), sum{value}, count{1}, min{value}, max{value} {} + explicit RuntimeMetric( + int64_t _sum, + uint64_t _count, + int64_t _min, + int64_t _max, + RuntimeCounter::Unit _unit = RuntimeCounter::Unit::kNone) + : unit(_unit), sum{_sum}, count{_count}, min{_min}, max{_max} {} + void addValue(int64_t value); /// Aggregate sets 'min' and 'max' to 'sum', also sets 'count' to 1 if @@ -69,7 +85,7 @@ class BaseRuntimeStatWriter { virtual ~BaseRuntimeStatWriter() = default; virtual void addRuntimeStat( - const std::string& /* name */, + std::string_view /* name */, const RuntimeCounter& /* value */) {} }; @@ -85,7 +101,7 @@ BaseRuntimeStatWriter* getThreadLocalRunTimeStatWriter(); /// Writes runtime counter to the current Operator running on that thread. void addThreadLocalRuntimeStat( - const std::string& name, + std::string_view name, const RuntimeCounter& value); /// Scope guard to conveniently set and revert back the current stat writer. diff --git a/velox/common/base/SimdUtil-inl.h b/velox/common/base/SimdUtil-inl.h index 937c86d9f7e..d46c4f21b95 100644 --- a/velox/common/base/SimdUtil-inl.h +++ b/velox/common/base/SimdUtil-inl.h @@ -91,10 +91,24 @@ struct BitMask { #if XSIMD_WITH_NEON static int toBitMask(xsimd::batch_bool mask, const xsimd::neon&) { - alignas(A::alignment()) static const int8_t kShift[] = { - -7, -6, -5, -4, -3, -2, -1, 0, -7, -6, -5, -4, -3, -2, -1, 0}; - int8x16_t vshift = vld1q_s8(kShift); - uint8x16_t vmask = vshlq_u8(vandq_u8(mask, vdupq_n_u8(0x80)), vshift); + static constexpr uint8x16_t vmask_const = { + 0x01, + 0x02, + 0x04, + 0x08, + 0x10, + 0x20, + 0x40, + 0x80, // lower half + 0x01, + 0x02, + 0x04, + 0x08, + 0x10, + 0x20, + 0x40, + 0x80}; // upper half + uint8x16_t vmask = vandq_u8(mask.data, vmask_const); return (vaddv_u8(vget_high_u8(vmask)) << 8) | vaddv_u8(vget_low_u8(vmask)); } #endif @@ -125,6 +139,15 @@ struct BitMask { } #endif +#if XSIMD_WITH_NEON + static int toBitMask(xsimd::batch_bool mask, const xsimd::neon&) { + static constexpr uint16x8_t vmask_const = { + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80}; + uint16x8_t vmask = vandq_u16(mask.data, vmask_const); + return vaddvq_u16(vmask); + } +#endif + static int toBitMask(xsimd::batch_bool mask, const xsimd::generic&) { return genericToBitMask(mask); } @@ -146,6 +169,14 @@ struct BitMask { } #endif +#if XSIMD_WITH_NEON + static int toBitMask(xsimd::batch_bool mask, const xsimd::neon&) { + static constexpr uint32x4_t vmask_const = {0x1, 0x2, 0x4, 0x8}; + uint32x4_t vmask = vandq_u32(mask.data, vmask_const); + return vaddvq_u32(vmask); + } +#endif + static int toBitMask(xsimd::batch_bool mask, const xsimd::generic&) { return genericToBitMask(mask); } @@ -173,6 +204,14 @@ struct BitMask { } #endif +#if XSIMD_WITH_NEON + static int toBitMask(xsimd::batch_bool mask, const xsimd::neon&) { + static constexpr uint64x2_t vmask_const = {0x1, 0x2}; + const uint64x2_t vmask = vandq_u64(mask.data, vmask_const); + return vaddvq_u64(vmask); + } +#endif + static int toBitMask(xsimd::batch_bool mask, const xsimd::generic&) { return genericToBitMask(mask); } @@ -1346,6 +1385,19 @@ struct Filter { template struct Crc32 { + static uint32_t + apply(uint32_t checksum, uint64_t value, const xsimd::generic&) { + checksum ^= static_cast(value); + for (int i = 0; i < 32; ++i) { + checksum = (checksum >> 1) ^ (0x82F63B78 & -(checksum & 1)); + } + checksum ^= static_cast(value >> 32); + for (int i = 0; i < 32; ++i) { + checksum = (checksum >> 1) ^ (0x82F63B78 & -(checksum & 1)); + } + return checksum; + } + #if XSIMD_WITH_SSE4_2 static uint32_t apply(uint32_t checksum, uint64_t value, const xsimd::sse4_2&) { diff --git a/velox/common/base/SimdUtil.h b/velox/common/base/SimdUtil.h index 1aabc0f2952..176b3acae43 100644 --- a/velox/common/base/SimdUtil.h +++ b/velox/common/base/SimdUtil.h @@ -16,7 +16,10 @@ #pragma once +#include #include +#include + #include "velox/common/base/BitUtil.h" #include "velox/common/base/Exceptions.h" @@ -364,6 +367,29 @@ auto toBitMask(xsimd::batch_bool mask, const A& arch = {}) { return detail::BitMask::toBitMask(mask, arch); } +/// Returns true if at least one lane in the mask is true. +template +inline bool any(xsimd::batch_bool mask, const A& arch = {}) { +#if XSIMD_WITH_AVX2 + // x86 bitmasks perform better than xsimd reductions. + return toBitMask(mask, arch) != 0; +#else + (void)arch; + return xsimd::any(mask); +#endif +} + +/// Returns true if no lanes in the mask are true. +template +inline bool none(xsimd::batch_bool mask, const A& arch = {}) { +#if XSIMD_WITH_AVX2 + return toBitMask(mask, arch) == 0; +#else + (void)arch; + return xsimd::none(mask); +#endif +} + // Get a vector mask from bit mask. template xsimd::batch_bool fromBitMask(BitMaskType bitMask, const A& arch = {}) { @@ -377,6 +403,17 @@ auto allSetBitMask(const A& = {}) { return detail::BitMask::kAllSet; } +/// Returns true if every lane in the mask is true. +template +inline bool all(xsimd::batch_bool mask, const A& arch = {}) { +#if XSIMD_WITH_AVX2 + return toBitMask(mask, arch) == allSetBitMask(arch); +#else + (void)arch; + return xsimd::all(mask); +#endif +} + namespace detail { template struct Filter; @@ -430,6 +467,7 @@ uint32_t crc32U64(uint32_t checksum, uint64_t value, const A& arch = {}) { template xsimd::batch iota(const A& = {}); +#ifdef VELOX_ENABLE_LOAD_SIMD_VALUE_BUFFER // Returns a batch with all elements set to value. For batch we // use one bit to represent one element. template @@ -445,6 +483,7 @@ xsimd::batch setAll(T value, const A& = {}) { return xsimd::broadcast(value); } } +#endif // Stores 'data' into 'destination' for the lanes in 'mask'. 'mask' is expected // to specify contiguous lower lanes of 'batch'. For non-SIMD cases, 'mask' is @@ -529,6 +568,43 @@ inline void memcpy(void* to, const void* from, int64_t bytes, const A& = {}); template void memset(void* to, char data, int32_t bytes, const A& = {}); +// Fills 'count' elements of 'output' with 'value' using SIMD broadcast+store +// for types that support it (int32_t, uint32_t, int64_t, uint64_t, float, +// double), and falls back to std::fill for other types. +// +// Uses a simpler structure than std::fill's auto-vectorized code — broadcast + +// tight store loop + overlapping tail — that performs better for repeated calls +// with small variable-length fills. +template +inline void simdFill(T* output, T value, uint32_t count) { + if constexpr ( + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + constexpr auto kBatchSize = xsimd::batch::size; + if (count >= kBatchSize) { + auto batch = xsimd::broadcast(value); + uint32_t i = 0; + for (; i + kBatchSize <= count; i += kBatchSize) { + batch.store_unaligned(output + i); + } + // Handle tail with an overlapping store. Safe because we're filling + // with a constant value, so re-writing already-written positions is + // harmless. The outer 'count >= kBatchSize' guard ensures + // 'count - kBatchSize' never underflows. + if (i < count) { + batch.store_unaligned(output + count - kBatchSize); + } + } else { + for (uint32_t i = 0; i < count; ++i) { + output[i] = value; + } + } + } else { + std::fill(output, output + count, value); + } +} + // Calls a different instantiation of a template function according to // 'numBytes'. #define VELOX_WIDTH_DISPATCH(numBytes, TEMPLATE_FUNC, ...) \ diff --git a/velox/common/base/SkewedPartitionBalancer.cpp b/velox/common/base/SkewedPartitionBalancer.cpp index 52ea170d0c6..7a1b4a6efa3 100644 --- a/velox/common/base/SkewedPartitionBalancer.cpp +++ b/velox/common/base/SkewedPartitionBalancer.cpp @@ -30,9 +30,10 @@ SkewedPartitionRebalancer::SkewedPartitionRebalancer( numTasks_(numTasks), minProcessedBytesRebalanceThresholdPerPartition_( minProcessedBytesRebalanceThresholdPerPartition), - minProcessedBytesRebalanceThreshold_(std::max( - minProcessedBytesRebalanceThreshold, - minProcessedBytesRebalanceThresholdPerPartition_)), + minProcessedBytesRebalanceThreshold_( + std::max( + minProcessedBytesRebalanceThreshold, + minProcessedBytesRebalanceThresholdPerPartition_)), partitionRowCount_(numPartitions_), partitionAssignments_(numPartitions_) { VELOX_CHECK_GT(numPartitions_, 0); diff --git a/velox/common/base/SkewedPartitionBalancer.h b/velox/common/base/SkewedPartitionBalancer.h index 3c998920c1e..e145fa47381 100644 --- a/velox/common/base/SkewedPartitionBalancer.h +++ b/velox/common/base/SkewedPartitionBalancer.h @@ -15,6 +15,8 @@ */ #pragma once +#include + #include "velox/common/base/IndexedPriorityQueue.h" namespace facebook::velox::common { diff --git a/velox/common/base/SpillConfig.cpp b/velox/common/base/SpillConfig.cpp index dd428a41ec7..caf55d77386 100644 --- a/velox/common/base/SpillConfig.cpp +++ b/velox/common/base/SpillConfig.cpp @@ -34,8 +34,10 @@ SpillConfig::SpillConfig( uint64_t _maxSpillRunRows, uint64_t _writerFlushThresholdSize, const std::string& _compressionKind, + uint32_t _numMaxMergeFiles, std::optional _prefixSortConfig, - const std::string& _fileCreateConfig) + const std::string& _fileCreateConfig, + uint32_t _windowMinReadBatchRows) : getSpillDirPathCb(std::move(_getSpillDirPathCb)), updateAndCheckSpillLimitCb(std::move(_updateAndCheckSpillLimitCb)), fileNamePrefix(std::move(_fileNamePrefix)), @@ -53,12 +55,18 @@ SpillConfig::SpillConfig( maxSpillRunRows(_maxSpillRunRows), writerFlushThresholdSize(_writerFlushThresholdSize), compressionKind(common::stringToCompressionKind(_compressionKind)), + numMaxMergeFiles(_numMaxMergeFiles), prefixSortConfig(_prefixSortConfig), - fileCreateConfig(_fileCreateConfig) { + fileCreateConfig(_fileCreateConfig), + windowMinReadBatchRows(_windowMinReadBatchRows) { VELOX_USER_CHECK_GE( spillableReservationGrowthPct, minSpillableReservationPct, "Spillable memory reservation growth pct should not be lower than minimum available pct"); + VELOX_CHECK_NE( + numMaxMergeFiles, + 1, + "NumMaxMergeFiles should not be 1 as merging should take at least 2 files to make progress"); } int32_t SpillConfig::spillLevel(uint8_t startBitOffset) const { diff --git a/velox/common/base/SpillConfig.h b/velox/common/base/SpillConfig.h index 7f30bc6e614..adec0274e9d 100644 --- a/velox/common/base/SpillConfig.h +++ b/velox/common/base/SpillConfig.h @@ -52,6 +52,13 @@ using GetSpillDirectoryPathCB = std::function; /// bytes exceed the set limit. using UpdateAndCheckSpillLimitCB = std::function; +/// Specifies the options for spill to disk. +struct SpillDiskOptions { + std::string spillDirPath; + bool spillDirCreated{true}; + std::function spillDirCreateCb{nullptr}; +}; + /// Specifies the config for spilling. struct SpillConfig { SpillConfig() = default; @@ -71,8 +78,10 @@ struct SpillConfig { uint64_t _maxSpillRunRows, uint64_t _writerFlushThresholdSize, const std::string& _compressionKind, + uint32_t numMaxMergeFiles, std::optional _prefixSortConfig = std::nullopt, - const std::string& _fileCreateConfig = {}); + const std::string& _fileCreateConfig = {}, + uint32_t _windowMinReadBatchRows = 1'000); /// Returns the spilling level with given 'startBitOffset' and /// 'numPartitionBits'. @@ -151,11 +160,22 @@ struct SpillConfig { /// CompressionKind when spilling, CompressionKind_NONE means no compression. common::CompressionKind compressionKind; + /// The max number of files to merge at a time when merging sorted files into + /// a single ordered stream. 0 means unlimited. This is used to reduce memory + /// pressure by capping the number of open files when merging spilled sorted + /// files to avoid using too much memory and causing OOM. Note that this is + /// only applicable for ordered spill, is not applicable for spill scenarios + /// that don't need sorting, e.g. HashJoin. + uint32_t numMaxMergeFiles{0}; + /// Prefix sort config when spilling, enable prefix sort when this config is /// set, otherwise, fallback to timsort. std::optional prefixSortConfig; /// Custom options passed to velox::FileSystem to create spill WriteFile. std::string fileCreateConfig; + + /// The minimum number of rows to read when processing spilled window data. + uint32_t windowMinReadBatchRows; }; } // namespace facebook::velox::common diff --git a/velox/common/base/SpillStats.cpp b/velox/common/base/SpillStats.cpp deleted file mode 100644 index 07b60692db0..00000000000 --- a/velox/common/base/SpillStats.cpp +++ /dev/null @@ -1,267 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "velox/common/base/SpillStats.h" -#include "velox/common/base/Counters.h" -#include "velox/common/base/Exceptions.h" -#include "velox/common/base/StatsReporter.h" -#include "velox/common/base/SuccinctPrinter.h" - -namespace facebook::velox::common { -namespace { -std::vector>& allSpillStats() { - static std::vector> spillStatsList( - std::thread::hardware_concurrency()); - return spillStatsList; -} - -folly::Synchronized& localSpillStats() { - const auto idx = std::hash{}(std::this_thread::get_id()); - auto& spillStatsVector = allSpillStats(); - return spillStatsVector[idx % spillStatsVector.size()]; -} -} // namespace - -SpillStats::SpillStats( - uint64_t _spillRuns, - uint64_t _spilledInputBytes, - uint64_t _spilledBytes, - uint64_t _spilledRows, - uint32_t _spilledPartitions, - uint64_t _spilledFiles, - uint64_t _spillFillTimeNanos, - uint64_t _spillSortTimeNanos, - uint64_t _spillExtractVectorTimeNanos, - uint64_t _spillSerializationTimeNanos, - uint64_t _spillWrites, - uint64_t _spillFlushTimeNanos, - uint64_t _spillWriteTimeNanos, - uint64_t _spillMaxLevelExceededCount, - uint64_t _spillReadBytes, - uint64_t _spillReads, - uint64_t _spillReadTimeNanos, - uint64_t _spillDeserializationTimeNanos) - : spillRuns(_spillRuns), - spilledInputBytes(_spilledInputBytes), - spilledBytes(_spilledBytes), - spilledRows(_spilledRows), - spilledPartitions(_spilledPartitions), - spilledFiles(_spilledFiles), - spillFillTimeNanos(_spillFillTimeNanos), - spillSortTimeNanos(_spillSortTimeNanos), - spillExtractVectorTimeNanos(_spillExtractVectorTimeNanos), - spillSerializationTimeNanos(_spillSerializationTimeNanos), - spillWrites(_spillWrites), - spillFlushTimeNanos(_spillFlushTimeNanos), - spillWriteTimeNanos(_spillWriteTimeNanos), - spillMaxLevelExceededCount(_spillMaxLevelExceededCount), - spillReadBytes(_spillReadBytes), - spillReads(_spillReads), - spillReadTimeNanos(_spillReadTimeNanos), - spillDeserializationTimeNanos(_spillDeserializationTimeNanos) {} - -SpillStats& SpillStats::operator+=(const SpillStats& other) { - spillRuns += other.spillRuns; - spilledInputBytes += other.spilledInputBytes; - spilledBytes += other.spilledBytes; - spilledRows += other.spilledRows; - spilledPartitions += other.spilledPartitions; - spilledFiles += other.spilledFiles; - spillFillTimeNanos += other.spillFillTimeNanos; - spillSortTimeNanos += other.spillSortTimeNanos; - spillExtractVectorTimeNanos += other.spillExtractVectorTimeNanos; - spillSerializationTimeNanos += other.spillSerializationTimeNanos; - spillWrites += other.spillWrites; - spillFlushTimeNanos += other.spillFlushTimeNanos; - spillWriteTimeNanos += other.spillWriteTimeNanos; - spillMaxLevelExceededCount += other.spillMaxLevelExceededCount; - spillReadBytes += other.spillReadBytes; - spillReads += other.spillReads; - spillReadTimeNanos += other.spillReadTimeNanos; - spillDeserializationTimeNanos += other.spillDeserializationTimeNanos; - return *this; -} - -SpillStats SpillStats::operator-(const SpillStats& other) const { - SpillStats result; - result.spillRuns = spillRuns - other.spillRuns; - result.spilledInputBytes = spilledInputBytes - other.spilledInputBytes; - result.spilledBytes = spilledBytes - other.spilledBytes; - result.spilledRows = spilledRows - other.spilledRows; - result.spilledPartitions = spilledPartitions - other.spilledPartitions; - result.spilledFiles = spilledFiles - other.spilledFiles; - result.spillFillTimeNanos = spillFillTimeNanos - other.spillFillTimeNanos; - result.spillSortTimeNanos = spillSortTimeNanos - other.spillSortTimeNanos; - result.spillExtractVectorTimeNanos = - spillExtractVectorTimeNanos - other.spillExtractVectorTimeNanos; - result.spillDeserializationTimeNanos = - spillExtractVectorTimeNanos - other.spillExtractVectorTimeNanos; - result.spillSerializationTimeNanos = - spillSerializationTimeNanos - other.spillSerializationTimeNanos; - result.spillWrites = spillWrites - other.spillWrites; - result.spillFlushTimeNanos = spillFlushTimeNanos - other.spillFlushTimeNanos; - result.spillWriteTimeNanos = spillWriteTimeNanos - other.spillWriteTimeNanos; - result.spillMaxLevelExceededCount = - spillMaxLevelExceededCount - other.spillMaxLevelExceededCount; - result.spillReadBytes = spillReadBytes - other.spillReadBytes; - result.spillReads = spillReads - other.spillReads; - result.spillReadTimeNanos = spillReadTimeNanos - other.spillReadTimeNanos; - result.spillDeserializationTimeNanos = - spillDeserializationTimeNanos - other.spillDeserializationTimeNanos; - return result; -} - -void SpillStats::reset() { - spillRuns = 0; - spilledInputBytes = 0; - spilledBytes = 0; - spilledRows = 0; - spilledPartitions = 0; - spilledFiles = 0; - spillFillTimeNanos = 0; - spillSortTimeNanos = 0; - spillExtractVectorTimeNanos = 0; - spillSerializationTimeNanos = 0; - spillWrites = 0; - spillFlushTimeNanos = 0; - spillWriteTimeNanos = 0; - spillMaxLevelExceededCount = 0; - spillReadBytes = 0; - spillReads = 0; - spillReadTimeNanos = 0; - spillDeserializationTimeNanos = 0; -} - -std::string SpillStats::toString() const { - return fmt::format( - "spillRuns[{}] spilledInputBytes[{}] spilledBytes[{}] spilledRows[{}] " - "spilledPartitions[{}] spilledFiles[{}] spillFillTimeNanos[{}] " - "spillSortTimeNanos[{}] spillExtractVectorTime[{}] spillSerializationTimeNanos[{}] spillWrites[{}] " - "spillFlushTimeNanos[{}] spillWriteTimeNanos[{}] maxSpillExceededLimitCount[{}] " - "spillReadBytes[{}] spillReads[{}] spillReadTimeNanos[{}] " - "spillReadDeserializationTimeNanos[{}]", - spillRuns, - succinctBytes(spilledInputBytes), - succinctBytes(spilledBytes), - spilledRows, - spilledPartitions, - spilledFiles, - succinctNanos(spillFillTimeNanos), - succinctNanos(spillSortTimeNanos), - succinctNanos(spillExtractVectorTimeNanos), - succinctNanos(spillSerializationTimeNanos), - spillWrites, - succinctNanos(spillFlushTimeNanos), - succinctNanos(spillWriteTimeNanos), - spillMaxLevelExceededCount, - succinctBytes(spillReadBytes), - spillReads, - succinctNanos(spillReadTimeNanos), - succinctNanos(spillDeserializationTimeNanos)); -} - -void updateGlobalSpillRunStats(uint64_t numRuns) { - auto statsLocked = localSpillStats().wlock(); - statsLocked->spillRuns += numRuns; -} - -void updateGlobalSpillAppendStats( - uint64_t numRows, - uint64_t serializationTimeNs) { - RECORD_METRIC_VALUE(kMetricSpilledRowsCount, numRows); - RECORD_HISTOGRAM_METRIC_VALUE( - kMetricSpillSerializationTimeMs, serializationTimeNs / 1'000'000); - auto statsLocked = localSpillStats().wlock(); - statsLocked->spilledRows += numRows; - statsLocked->spillSerializationTimeNanos += serializationTimeNs; -} - -void incrementGlobalSpilledPartitionStats() { - ++localSpillStats().wlock()->spilledPartitions; -} - -void updateGlobalSpillFillTime(uint64_t timeNs) { - RECORD_HISTOGRAM_METRIC_VALUE(kMetricSpillFillTimeMs, timeNs / 1'000'000); - localSpillStats().wlock()->spillFillTimeNanos += timeNs; -} - -void updateGlobalSpillSortTime(uint64_t timeNs) { - RECORD_HISTOGRAM_METRIC_VALUE(kMetricSpillSortTimeMs, timeNs / 1'000'000); - localSpillStats().wlock()->spillSortTimeNanos += timeNs; -} - -void updateGlobalSpillExtractVectorTime(uint64_t timeNs) { - RECORD_HISTOGRAM_METRIC_VALUE( - kMetricSpillExtractVectorTimeMs, timeNs / 1'000'000); - localSpillStats().wlock()->spillExtractVectorTimeNanos += timeNs; -} - -void updateGlobalSpillWriteStats( - uint64_t spilledBytes, - uint64_t flushTimeNs, - uint64_t writeTimeNs) { - RECORD_METRIC_VALUE(kMetricSpillWritesCount); - RECORD_METRIC_VALUE(kMetricSpilledBytes, spilledBytes); - RECORD_HISTOGRAM_METRIC_VALUE( - kMetricSpillFlushTimeMs, flushTimeNs / 1'000'000); - RECORD_HISTOGRAM_METRIC_VALUE( - kMetricSpillWriteTimeMs, writeTimeNs / 1'000'000); - auto statsLocked = localSpillStats().wlock(); - ++statsLocked->spillWrites; - statsLocked->spilledBytes += spilledBytes; - statsLocked->spillFlushTimeNanos += flushTimeNs; - statsLocked->spillWriteTimeNanos += writeTimeNs; -} - -void updateGlobalSpillReadStats( - uint64_t spillReads, - uint64_t spillReadBytes, - uint64_t spillReadTimeNs) { - auto statsLocked = localSpillStats().wlock(); - statsLocked->spillReads += spillReads; - statsLocked->spillReadBytes += spillReadBytes; - statsLocked->spillReadTimeNanos += spillReadTimeNs; -} - -void updateGlobalSpillMemoryBytes(uint64_t spilledInputBytes) { - RECORD_METRIC_VALUE(kMetricSpilledInputBytes, spilledInputBytes); - auto statsLocked = localSpillStats().wlock(); - statsLocked->spilledInputBytes += spilledInputBytes; -} - -void incrementGlobalSpilledFiles() { - RECORD_METRIC_VALUE(kMetricSpilledFilesCount); - ++localSpillStats().wlock()->spilledFiles; -} - -void updateGlobalMaxSpillLevelExceededCount( - uint64_t maxSpillLevelExceededCount) { - localSpillStats().wlock()->spillMaxLevelExceededCount += - maxSpillLevelExceededCount; -} - -void updateGlobalSpillDeserializationTimeNs(uint64_t timeNs) { - localSpillStats().wlock()->spillDeserializationTimeNanos += timeNs; -} - -SpillStats globalSpillStats() { - SpillStats gSpillStats; - for (auto& spillStats : allSpillStats()) { - gSpillStats += spillStats.copy(); - } - return gSpillStats; -} -} // namespace facebook::velox::common diff --git a/velox/common/base/SplitBlockBloomFilter.cpp b/velox/common/base/SplitBlockBloomFilter.cpp new file mode 100644 index 00000000000..ef4d5e87c56 --- /dev/null +++ b/velox/common/base/SplitBlockBloomFilter.cpp @@ -0,0 +1,66 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/base/SplitBlockBloomFilter.h" + +#include "velox/common/base/BitUtil.h" +#include "velox/common/base/Exceptions.h" + +#include +#include + +namespace facebook::velox { + +int64_t SplitBlockBloomFilter::numBlocks( + int64_t numElements, + double falsePositive) { + constexpr int K = xsimd::batch::size; + int64_t numBits = std::ceil( + -K * numElements / std::log(1 - std::pow(falsePositive, 1.0 / K))); + return bits::divRoundUp(numBits, 8 * sizeof(Block)); +} + +SplitBlockBloomFilter::SplitBlockBloomFilter(const std::span& blocks) + : blocks_(blocks) { + VELOX_CHECK_EQ(reinterpret_cast(blocks.data()) % sizeof(Block), 0); +} + +std::string SplitBlockBloomFilter::debugString() const { + std::ostringstream out; + out << "numBlocks=" << blocks_.size() << '\n'; + int64_t byBlockSetBits[1 + 8 * sizeof(Block)]{}; + int64_t byBitPosition[8 * sizeof(Block)]{}; + for (auto& block : blocks_) { + int numSetBits = 0; + for (int i = 0; i < xsimd::batch::size; ++i) { + auto n = std::popcount(block.data[i]); + numSetBits += n; + for (int j = 0; j < 32; ++j) { + byBitPosition[j + i * 32] += 1 & (block.data[i] >> j); + } + } + ++byBlockSetBits[numSetBits]; + } + for (int i = 0; i <= 8 * sizeof(Block); ++i) { + out << "Block set bits " << i << ": " << byBlockSetBits[i] << '\n'; + } + for (int i = 0; i < 8 * sizeof(Block); ++i) { + out << "Bit " << i << ": " << byBitPosition[i] << '\n'; + } + return out.str(); +} + +} // namespace facebook::velox diff --git a/velox/common/base/SplitBlockBloomFilter.h b/velox/common/base/SplitBlockBloomFilter.h new file mode 100644 index 00000000000..b5a40f88c3b --- /dev/null +++ b/velox/common/base/SplitBlockBloomFilter.h @@ -0,0 +1,137 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/common/base/SimdUtil.h" + +#include +#include + +namespace facebook::velox { + +/// SIMDized bloom filter implementation. We take 8 or 4 (depending on the SIMD +/// register size) bits from the inserted value, and split them into a block. A +/// block is the same size of a SIMD register. +/// +/// This data structure is not responsible for memory management or hashing. +/// +/// A detailed explanation about how the data structure works can be found here: +/// https://parquet.apache.org/docs/file-format/bloomfilter/ +class SplitBlockBloomFilter { + public: + /// A block is basically a SIMD register. Made public so user can calculate + /// the size needed for memory allocation; otherwise it's implementation + /// detail. + struct alignas(sizeof(xsimd::batch)) Block { + uint32_t data[xsimd::batch::size]; + }; + + /// Calculate the number of blocks needed to satisfy certain number of inserts + /// and false positive rate. A rough estimation between the memory usage and + /// false positive rate can be found below: + /// + /// | Bits of space per insert | False positive probability | + /// |--------------------------|----------------------------| + /// | 6.0 | 10% | + /// | 10.5 | 1% | + /// | 16.9 | 0.1% | + /// | 26.4 | 0.01% | + /// | 41 | 0.001% | + static int64_t numBlocks(int64_t numElements, double falsePositive); + + /// Construct the bloom filter using the blocks memory passed as parameter. + /// The block memory address must be aligned with block size (i.e. SIMD + /// register size). It is recommended to use numBlocks() to calculate the + /// number of block for allocation. + explicit SplitBlockBloomFilter(const std::span& blocks); + + /// Delete copy constructor and assignment to avoid accidentally mutating the + /// same underlying block data from multiple places. + SplitBlockBloomFilter(const SplitBlockBloomFilter&) = delete; + SplitBlockBloomFilter& operator=(const SplitBlockBloomFilter&) = delete; + + SplitBlockBloomFilter(SplitBlockBloomFilter&&) = default; + + /// Insert a hash into the bloom filter. The function used to generate this + /// hash should be avalanching. + void insert(uint64_t hash) { + auto mask = makeMask(hash); + auto* block = blocks_[blockIndex(hash)].data; + (xsimd::load_aligned(block) | mask).store_aligned(block); + } + + /// Check whether a hash has been inserted before. Could return true when it + /// has not been inserted. Never return false when the hash has been + /// inserted. + bool mayContain(uint64_t hash) const { + auto mask = makeMask(hash); + auto block = xsimd::load_aligned(blocks_[blockIndex(hash)].data); +#if XSIMD_WITH_AVX + return _mm256_testc_si256(block, mask); +#else + return xsimd::all(xsimd::bitwise_andnot(mask, block) == 0); +#endif + } + + /// Return the block index of the given hash. + uint64_t blockIndex(uint64_t hash) const { + return ((hash >> 32) * blocks_.size()) >> 32; + } + + std::string debugString() const; + + private: + static_assert(64 % sizeof(Block) == 0); + + template + static xsimd::batch makeSaltsVec() { + constexpr uint32_t kSalts[] = { + 0x2df1424bU, + 0x44974d91U, + 0x47b6137bU, + 0x5c6bfb31U, + 0x705495c7U, + 0x8824ad5bU, + 0x9efc4947U, + 0xa2b7289dU, + }; + if constexpr (xsimd::batch::size == 8) { + return xsimd::batch( + kSalts[0], + kSalts[1], + kSalts[2], + kSalts[3], + kSalts[4], + kSalts[5], + kSalts[6], + kSalts[7]); + } else { + static_assert(xsimd::batch::size == 4); + return xsimd::batch( + kSalts[0], kSalts[2], kSalts[4], kSalts[6]); + } + } + + static xsimd::batch makeMask(uint32_t hash) { + auto shifts = (makeSaltsVec() * xsimd::broadcast(hash)) >> 27; + return xsimd::broadcast(1) << shifts; + } + + std::span const blocks_; +}; + +} // namespace facebook::velox diff --git a/velox/common/base/StatsReporter.h b/velox/common/base/StatsReporter.h index 94d0bd32c4f..9df453a75db 100644 --- a/velox/common/base/StatsReporter.h +++ b/velox/common/base/StatsReporter.h @@ -63,7 +63,7 @@ enum class StatType { HISTOGRAM, }; -inline std::string statTypeString(StatType stat) { +inline std::string_view statTypeString(StatType stat) { switch (stat) { case StatType::AVG: return "Avg"; @@ -76,7 +76,7 @@ inline std::string statTypeString(StatType stat) { case StatType::HISTOGRAM: return "Histogram"; default: - return fmt::format("UNKNOWN: {}", static_cast(stat)); + return "Unknown"; } } @@ -84,17 +84,17 @@ inline std::string statTypeString(StatType stat) { /// different implementations. class BaseStatsReporter { public: - virtual ~BaseStatsReporter() {} + virtual ~BaseStatsReporter() = default; /// Register a stat of the given stat type. /// @param key The key to identify the stat. /// @param statType How the stat is aggregated. virtual void registerMetricExportType(const char* key, StatType statType) - const = 0; + const {} virtual void registerMetricExportType( folly::StringPiece key, - StatType statType) const = 0; + StatType statType) const {} /// Register a histogram with a list of percentiles defined. /// @param key The key to identify the histogram. @@ -107,14 +107,14 @@ class BaseStatsReporter { int64_t bucketWidth, int64_t min, int64_t max, - const std::vector& pcts) const = 0; + const std::vector& pcts) const {} virtual void registerHistogramMetricExportType( folly::StringPiece key, int64_t bucketWidth, int64_t min, int64_t max, - const std::vector& pcts) const = 0; + const std::vector& pcts) const {} /// Register a quantile metric for quantile stats with export types, /// quantiles, and sliding window periods. @@ -127,13 +127,13 @@ class BaseStatsReporter { const char* key, const std::vector& statTypes, const std::vector& pcts, - const std::vector& slidingWindowsSeconds = {60}) const = 0; + const std::vector& slidingWindowsSeconds = {60}) const {} virtual void registerQuantileMetricExportType( folly::StringPiece key, const std::vector& statTypes, const std::vector& pcts, - const std::vector& slidingWindowsSeconds = {60}) const = 0; + const std::vector& slidingWindowsSeconds = {60}) const {} /// Register a dynamic quantile metric with a template key pattern that /// supports runtime substitution. @@ -145,60 +145,60 @@ class BaseStatsReporter { const char* keyPattern, const std::vector& statTypes, const std::vector& pcts, - const std::vector& slidingWindowsSeconds = {60}) const = 0; + const std::vector& slidingWindowsSeconds = {60}) const {} virtual void registerDynamicQuantileMetricExportType( folly::StringPiece keyPattern, const std::vector& statTypes, const std::vector& pcts, - const std::vector& slidingWindowsSeconds = {60}) const = 0; + const std::vector& slidingWindowsSeconds = {60}) const {} /// Add the given value to the stat. - virtual void addMetricValue(const std::string& key, size_t value = 1) - const = 0; + virtual void addMetricValue(const std::string& key, size_t value = 1) const {} - virtual void addMetricValue(const char* key, size_t value = 1) const = 0; + virtual void addMetricValue(const char* key, size_t value = 1) const {} - virtual void addMetricValue(folly::StringPiece key, size_t value = 1) - const = 0; + virtual void addMetricValue(folly::StringPiece key, size_t value = 1) const {} /// Add the given value to the histogram. virtual void addHistogramMetricValue(const std::string& key, size_t value) - const = 0; + const {} - virtual void addHistogramMetricValue(const char* key, size_t value) const = 0; + virtual void addHistogramMetricValue(const char* key, size_t value) const {} virtual void addHistogramMetricValue(folly::StringPiece key, size_t value) - const = 0; + const {} /// Add the given value to a quantile metric. virtual void addQuantileMetricValue(const std::string& key, size_t value = 1) - const = 0; + const {} - virtual void addQuantileMetricValue(const char* key, size_t value = 1) - const = 0; + virtual void addQuantileMetricValue(const char* key, size_t value = 1) const { + } virtual void addQuantileMetricValue(folly::StringPiece key, size_t value = 1) - const = 0; + const {} /// Add the given value to a quantile metric. virtual void addDynamicQuantileMetricValue( const std::string& key, folly::Range subkeys, - size_t value = 1) const = 0; + size_t value = 1) const {} virtual void addDynamicQuantileMetricValue( const char* key, folly::Range subkeys, - size_t value = 1) const = 0; + size_t value = 1) const {} virtual void addDynamicQuantileMetricValue( folly::StringPiece key, folly::Range subkeys, - size_t value = 1) const = 0; + size_t value = 1) const {} /// Return the aggregated metrics in a serialized string format. - virtual std::string fetchMetrics() = 0; + virtual std::string fetchMetrics() { + return ""; + } static bool registered; }; diff --git a/velox/common/base/Status.h b/velox/common/base/Status.h index 72ec5447618..1cf42673bc8 100644 --- a/velox/common/base/Status.h +++ b/velox/common/base/Status.h @@ -530,6 +530,9 @@ void Status::moveFrom(Status& s) { #define _VELOX_RETURN_IMPL(expr, exprStr, error, ...) \ do { \ if (FOLLY_UNLIKELY(expr)) { \ + if (::facebook::velox::threadSkipErrorDetails()) { \ + return error(); \ + } \ auto message = ::facebook::velox::errorMessage(__VA_ARGS__); \ return error( \ ::facebook::velox::internal::generateError(message, exprStr)); \ diff --git a/velox/exec/TreeOfLosers.h b/velox/common/base/TreeOfLosers.h similarity index 100% rename from velox/exec/TreeOfLosers.h rename to velox/common/base/TreeOfLosers.h diff --git a/velox/common/base/VeloxException.cpp b/velox/common/base/VeloxException.cpp index 735953d169a..3f0d0dec5db 100644 --- a/velox/common/base/VeloxException.cpp +++ b/velox/common/base/VeloxException.cpp @@ -16,9 +16,11 @@ #include "velox/common/base/VeloxException.h" -#include +#include #include +#include + namespace facebook { namespace velox { @@ -102,6 +104,35 @@ VeloxException::VeloxException( state.isRetriable = isRetriable; })) {} +VeloxException::VeloxException( + const char* file, + size_t line, + const char* function, + std::string_view failingExpression, + std::string_view message, + std::string_view errorSource, + std::string_view errorCode, + bool isRetriable, + CompileTimeStringLiteral messageTemplate, + Type exceptionType, + std::string_view exceptionName) + : VeloxException(State::make(exceptionType, [&](auto& state) { + state.exceptionType = exceptionType; + state.exceptionName = exceptionName; + state.file = file; + state.line = line; + state.function = function; + state.failingExpression = failingExpression; + state.message = message; + state.messageTemplate = messageTemplate; + state.errorSource = errorSource; + state.errorCode = errorCode; + state.context = getExceptionContext().message(exceptionType); + state.additionalContext = + getAdditionalExceptionContextString(exceptionType, state.context); + state.isRetriable = isRetriable; + })) {} + VeloxException::VeloxException( const std::exception_ptr& e, std::string_view message, @@ -169,6 +200,18 @@ bool isStackTraceEnabled(VeloxException::Type type) { return last->compare_exchange_strong(latest, now, std::memory_order_relaxed); } +void stringAppendNumber(std::string& str, size_t number) { + // Manual implementation of itoa to avoid the cost of std::to_string. + const auto numberStartOffset = str.end() - + str.begin(); // Not `size()`. We need distance. The type is different. + size_t remaining = number; + do { + str += static_cast('0' + remaining % 10); + remaining /= 10; + } while (remaining); + reverse(str.begin() + numberStartOffset, str.end()); +} + } // namespace template @@ -248,13 +291,7 @@ void VeloxException::State::finalize() const { if (line) { elaborateMessage += "Line: "; - auto len = elaborateMessage.size(); - size_t t = line; - do { - elaborateMessage += static_cast('0' + t % 10); - t /= 10; - } while (t); - reverse(elaborateMessage.begin() + len, elaborateMessage.end()); + stringAppendNumber(elaborateMessage, line); elaborateMessage += '\n'; } diff --git a/velox/common/base/VeloxException.h b/velox/common/base/VeloxException.h index b9ffb813f4f..44d08cdd285 100644 --- a/velox/common/base/VeloxException.h +++ b/velox/common/base/VeloxException.h @@ -28,6 +28,8 @@ #include "velox/common/process/StackTrace.h" +#include "velox/common/base/ExceptionHelper.h" + DECLARE_bool(velox_exception_user_stacktrace_enabled); DECLARE_bool(velox_exception_system_stacktrace_enabled); @@ -139,6 +141,8 @@ class VeloxException : public std::exception { public: enum class Type { kUser = 0, kSystem = 1 }; + /// Construct without an explicit message template. messageTemplate() will + /// return the message itself (the message IS the template). VeloxException( const char* file, size_t line, @@ -151,6 +155,21 @@ class VeloxException : public std::exception { Type exceptionType = Type::kSystem, std::string_view exceptionName = "VeloxException"); + /// Construct with an explicit message template (the format string before + /// fmt::format interpolation). + VeloxException( + const char* file, + size_t line, + const char* function, + std::string_view expression, + std::string_view message, + std::string_view errorSource, + std::string_view errorCode, + bool isRetriable, + CompileTimeStringLiteral messageTemplate, + Type exceptionType = Type::kSystem, + std::string_view exceptionName = "VeloxException"); + /// Wrap an std::exception. VeloxException( const std::exception_ptr& e, @@ -202,6 +221,14 @@ class VeloxException : public std::exception { return state_->message; } + /// Returns the format template before fmt::format interpolation. Useful for + /// grouping exceptions by error category in monitoring systems. When no + /// explicit template was provided, returns the message itself. + std::string_view messageTemplate() const { + return state_->messageTemplate ? std::string_view(state_->messageTemplate) + : std::string_view(state_->message); + } + const std::string& errorCode() const { return state_->errorCode; } @@ -248,6 +275,10 @@ class VeloxException : public std::exception { const char* function = nullptr; std::string failingExpression; std::string message; + // The format template before fmt::format interpolation. Points to a + // string literal (static lifetime) when set explicitly, or nullptr when + // the message itself serves as the template. + const char* messageTemplate{nullptr}; std::string errorSource; std::string errorCode; // The current exception context. @@ -284,6 +315,8 @@ class VeloxException : public std::exception { class VeloxUserError : public VeloxException { public: + static constexpr std::string_view kDefaultName = "VeloxUserError"; + VeloxUserError( const char* file, size_t line, @@ -293,7 +326,7 @@ class VeloxUserError : public VeloxException { std::string_view /* errorSource */, std::string_view errorCode, bool isRetriable, - std::string_view exceptionName = "VeloxUserError") + std::string_view exceptionName = kDefaultName) : VeloxException( file, line, @@ -306,12 +339,36 @@ class VeloxUserError : public VeloxException { Type::kUser, exceptionName) {} + VeloxUserError( + const char* file, + size_t line, + const char* function, + std::string_view expression, + std::string_view message, + std::string_view /* errorSource */, + std::string_view errorCode, + bool isRetriable, + CompileTimeStringLiteral messageTemplate, + std::string_view exceptionName = kDefaultName) + : VeloxException( + file, + line, + function, + expression, + message, + error_source::kErrorSourceUser, + errorCode, + isRetriable, + messageTemplate, + Type::kUser, + exceptionName) {} + /// Wrap an std::exception. VeloxUserError( const std::exception_ptr& e, std::string_view message, bool isRetriable, - std::string_view exceptionName = "VeloxUserError") + std::string_view exceptionName = kDefaultName) : VeloxException( e, message, @@ -324,6 +381,30 @@ class VeloxUserError : public VeloxException { class VeloxRuntimeError final : public VeloxException { public: + static constexpr std::string_view kDefaultName = "VeloxRuntimeError"; + + VeloxRuntimeError( + const char* file, + size_t line, + const char* function, + std::string_view expression, + std::string_view message, + std::string_view /* errorSource */, + std::string_view errorCode, + bool isRetriable, + std::string_view exceptionName = kDefaultName) + : VeloxException( + file, + line, + function, + expression, + message, + error_source::kErrorSourceRuntime, + errorCode, + isRetriable, + Type::kSystem, + exceptionName) {} + VeloxRuntimeError( const char* file, size_t line, @@ -333,7 +414,8 @@ class VeloxRuntimeError final : public VeloxException { std::string_view /* errorSource */, std::string_view errorCode, bool isRetriable, - std::string_view exceptionName = "VeloxRuntimeError") + CompileTimeStringLiteral messageTemplate, + std::string_view exceptionName = kDefaultName) : VeloxException( file, line, @@ -343,6 +425,7 @@ class VeloxRuntimeError final : public VeloxException { error_source::kErrorSourceRuntime, errorCode, isRetriable, + messageTemplate, Type::kSystem, exceptionName) {} @@ -351,7 +434,7 @@ class VeloxRuntimeError final : public VeloxException { const std::exception_ptr& e, std::string_view message, bool isRetriable, - std::string_view exceptionName = "VeloxRuntimeError") + std::string_view exceptionName = kDefaultName) : VeloxException( e, message, diff --git a/velox/common/base/XxHashInline.h b/velox/common/base/XxHashInline.h new file mode 100644 index 00000000000..77b6b20ee4f --- /dev/null +++ b/velox/common/base/XxHashInline.h @@ -0,0 +1,20 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#define XXH_INLINE_ALL +#include // @manual=third-party//xxHash:xxhash diff --git a/velox/common/base/benchmarks/CMakeLists.txt b/velox/common/base/benchmarks/CMakeLists.txt index cb7dcb1ab55..b4ffaf2f60a 100644 --- a/velox/common/base/benchmarks/CMakeLists.txt +++ b/velox/common/base/benchmarks/CMakeLists.txt @@ -42,3 +42,11 @@ target_link_libraries( PUBLIC Folly::follybenchmark PRIVATE velox_common_base Folly::folly ) + +add_executable(velox_common_split_block_bloom_filter_benchmark SplitBlockBloomFilterBenchmark.cpp) + +target_link_libraries( + velox_common_split_block_bloom_filter_benchmark + PUBLIC Folly::follybenchmark + PRIVATE velox_common_base Folly::folly +) diff --git a/velox/common/base/benchmarks/SplitBlockBloomFilterBenchmark.cpp b/velox/common/base/benchmarks/SplitBlockBloomFilterBenchmark.cpp new file mode 100644 index 00000000000..6706be84ab2 --- /dev/null +++ b/velox/common/base/benchmarks/SplitBlockBloomFilterBenchmark.cpp @@ -0,0 +1,155 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/base/SplitBlockBloomFilter.h" + +#include +#include +#include + +#define VELOX_BENCHMARK(_make, _name, ...) \ + [[maybe_unused]] auto _name = _make(FOLLY_PP_STRINGIZE(_name), __VA_ARGS__) + +namespace facebook::velox { +namespace { + +template +class SplitBlockBloomFilterBenchmark { + public: + SplitBlockBloomFilterBenchmark( + const char* name, + Hasher hasher, + double falsePositive, + int numInserts, + int numTests) + : hasher_(std::move(hasher)), + numTests_(numTests), + blocks_(SplitBlockBloomFilter::numBlocks(numInserts, falsePositive)), + filter_(blocks_) { + for (int i = 0; i < numInserts; ++i) { + filter_.insert(hasher_(generateValue())); + } + folly::addBenchmark(__FILE__, name, [this] { return run(); }); + } + + private: + static T generateValue() { + if constexpr (sizeof(T) == 8) { + return folly::Random::rand64(); + } else { + static_assert(sizeof(T) == 4); + return folly::Random::rand32(); + } + } + + unsigned run() const { + int numHits = 0; + for (int i = 0; i < numTests_; ++i) { + numHits += filter_.mayContain(hasher_(generateValue())); + } + folly::doNotOptimizeAway(numHits); + return numTests_; + } + + const Hasher hasher_; + const double numTests_; + std::vector blocks_; + SplitBlockBloomFilter filter_; +}; + +template +SplitBlockBloomFilterBenchmark makeBenchmark( + const char* name, + Hasher hasher, + double falsePositive, + int numInserts, + int numTests) { + return SplitBlockBloomFilterBenchmark( + name, std::move(hasher), falsePositive, numInserts, numTests); +} + +} // namespace +} // namespace facebook::velox + +int main(int argc, char* argv[]) { + using namespace facebook::velox; + folly::Init follyInit(&argc, &argv); + VELOX_BENCHMARK( + makeBenchmark, + int32, + folly::hasher(), + 0.01, + 5'000'000, + 10'000'000); + VELOX_BENCHMARK( + makeBenchmark, + int64, + folly::hasher(), + 0.01, + 5'000'000, + 10'000'000); + VELOX_BENCHMARK( + makeBenchmark, + int64_nohash, + folly::identity, + 0.01, + 5'000'000, + 10'000'000); + VELOX_BENCHMARK( + makeBenchmark, + int32_small, + folly::hasher(), + 0.01, + 500'000, + 1'000'000); + VELOX_BENCHMARK( + makeBenchmark, + int64_small, + folly::hasher(), + 0.01, + 500'000, + 1'000'000); + VELOX_BENCHMARK( + makeBenchmark, + int64_nohash_small, + folly::identity, + 0.01, + 500'000, + 1'000'000); + VELOX_BENCHMARK( + makeBenchmark, + int32_large, + folly::hasher(), + 0.01, + 50'000'000, + 100'000'000); + VELOX_BENCHMARK( + makeBenchmark, + int64_large, + folly::hasher(), + 0.01, + 50'000'000, + 100'000'000); + VELOX_BENCHMARK( + makeBenchmark, + int64_nohash_large, + folly::identity, + 0.01, + 50'000'000, + 100'000'000); + folly::runBenchmarks(); + return 0; +} diff --git a/velox/common/base/benchmarks/StringSearchBenchmark.cpp b/velox/common/base/benchmarks/StringSearchBenchmark.cpp index 5fee996709a..cad51cb56fa 100644 --- a/velox/common/base/benchmarks/StringSearchBenchmark.cpp +++ b/velox/common/base/benchmarks/StringSearchBenchmark.cpp @@ -105,11 +105,12 @@ class TestStringSearch { void runSearching(size_t iters) const { if constexpr (alg == SIMD) { FOR_EACH_RANGE (i, 0, iters) - doNotOptimizeAway(simd::simdStrstr( - heyStack_.data(), - heyStack_.size(), - needle_.data(), - needle_.size())); + doNotOptimizeAway( + simd::simdStrstr( + heyStack_.data(), + heyStack_.size(), + needle_.data(), + needle_.size())); } else if constexpr (alg == STD) { FOR_EACH_RANGE (i, 0, iters) doNotOptimizeAway( diff --git a/velox/common/base/tests/AsyncSourceTest.cpp b/velox/common/base/tests/AsyncSourceTest.cpp index 657a7ba8e08..a94cec99f9d 100644 --- a/velox/common/base/tests/AsyncSourceTest.cpp +++ b/velox/common/base/tests/AsyncSourceTest.cpp @@ -22,18 +22,76 @@ #include #include #include "velox/common/base/Exceptions.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/testutil/TestValue.h" using namespace facebook::velox; using namespace std::chrono_literals; -// A sample class to be constructed via AsyncSource. +namespace { struct Gizmo { explicit Gizmo(int32_t _id) : id(_id) {} const int32_t id; }; -TEST(AsyncSourceTest, basic) { +class DataCounter { + public: + DataCounter() { + objectNumber_ = ++numCreatedDataCounters_; + } + + ~DataCounter() { + ++numDeletedDataCounters_; + } + + static uint64_t numCreatedDataCounters() { + return numCreatedDataCounters_; + } + + static uint64_t numDeletedDataCounters() { + return numDeletedDataCounters_; + } + + static void reset() { + numCreatedDataCounters_ = 0; + numDeletedDataCounters_ = 0; + } + + uint64_t objectNumber() const { + return objectNumber_; + } + + private: + inline static std::atomic numCreatedDataCounters_{0}; + inline static std::atomic numDeletedDataCounters_{0}; + + uint64_t objectNumber_{0}; +}; + +void verifyContexts( + const std::string& expectedPoolName, + const std::string& expectedTaskId) { + EXPECT_EQ(process::GetThreadDebugInfo()->taskId_, expectedTaskId); +} +} // namespace + +class AsyncSourceTest : public ::testing::Test { + protected: + static void SetUpTestSuite() { + common::testutil::TestValue::enable(); + } + + void SetUp() override { + DataCounter::reset(); + } + + void TearDown() override { + DataCounter::reset(); + } +}; + +TEST_F(AsyncSourceTest, basic) { AsyncSource gizmo([]() { return std::make_unique(11); }); EXPECT_FALSE(gizmo.hasValue()); gizmo.prepare(); @@ -45,15 +103,247 @@ TEST(AsyncSourceTest, basic) { AsyncSource error( []() -> std::unique_ptr { VELOX_USER_FAIL("Testing error"); }); - EXPECT_THROW(error.move(), VeloxException); + VELOX_ASSERT_USER_THROW(error.move(), "Testing error"); EXPECT_TRUE(error.hasValue()); } -TEST(AsyncSourceTest, threads) { +TEST_F(AsyncSourceTest, close) { + { + auto dateCounter = std::make_shared(); + AsyncSource countAsyncSource([dateCounter]() { + return std::make_unique(dateCounter->objectNumber()); + }); + dateCounter.reset(); + EXPECT_EQ(DataCounter::numCreatedDataCounters(), 1); + EXPECT_EQ(DataCounter::numDeletedDataCounters(), 0); + EXPECT_EQ(0, countAsyncSource.prepareTiming().count); + + countAsyncSource.close(); + + EXPECT_EQ(0, countAsyncSource.prepareTiming().count); + EXPECT_EQ(DataCounter::numCreatedDataCounters(), 1); + EXPECT_EQ(DataCounter::numDeletedDataCounters(), 1); + } + DataCounter::reset(); + + { + auto asyncSource = std::make_shared>( + []() { return std::make_unique(); }); + asyncSource->prepare(); + EXPECT_EQ(DataCounter::numCreatedDataCounters(), 1); + EXPECT_EQ(DataCounter::numDeletedDataCounters(), 0); + EXPECT_EQ(1, asyncSource->prepareTiming().count); + + asyncSource->close(); + EXPECT_EQ(DataCounter::numCreatedDataCounters(), 1); + EXPECT_EQ(DataCounter::numDeletedDataCounters(), 1); + EXPECT_EQ(1, asyncSource->prepareTiming().count); + } + DataCounter::reset(); + + { + folly::Baton<> baton; + auto sleepAsyncSource = + std::make_shared>([&baton]() { + baton.post(); + return std::make_unique(); + }); + auto thread = + std::thread([&sleepAsyncSource] { sleepAsyncSource->prepare(); }); + EXPECT_TRUE(baton.try_wait_for(1s)); + sleepAsyncSource->close(); + EXPECT_EQ(DataCounter::numCreatedDataCounters(), 1); + EXPECT_EQ(DataCounter::numDeletedDataCounters(), 1); + EXPECT_EQ(1, sleepAsyncSource->prepareTiming().count); + thread.join(); + } +} + +TEST_F(AsyncSourceTest, emptyContexts) { + EXPECT_EQ(process::GetThreadDebugInfo(), nullptr); + + AsyncSource src([]() { + verifyContexts("test", "task_id"); + return std::make_unique(true); + }); + + process::ThreadDebugInfo debugInfo{"query_id", "task_id", nullptr}; + process::ScopedThreadDebugInfo scopedDebugInfo(debugInfo); + + verifyContexts("test", "task_id"); + + ASSERT_TRUE(*src.move()); + + verifyContexts("test", "task_id"); +} + +TEST_F(AsyncSourceTest, setContexts) { + process::ThreadDebugInfo debugInfo1{"query_id1", "task_id1", nullptr}; + + std::unique_ptr> src; + process::ScopedThreadDebugInfo scopedDebugInfo1(debugInfo1); + + verifyContexts("test1", "task_id1"); + + src = std::make_unique>(([]() { + verifyContexts("test1", "task_id1"); + return std::make_unique(true); + })); + + process::ThreadDebugInfo debugInfo2{"query_id2", "task_id2", nullptr}; + process::ScopedThreadDebugInfo scopedDebugInfo2(debugInfo2); + + verifyContexts("test2", "task_id2"); + + ASSERT_TRUE(*src->move()); + + verifyContexts("test2", "task_id2"); +} + +TEST_F(AsyncSourceTest, cancel) { + { + auto dataCounter = std::make_shared(); + auto asyncSource = std::make_shared>([dataCounter]() { + return std::make_unique(dataCounter->objectNumber()); + }); + dataCounter.reset(); + + asyncSource->cancel(); + EXPECT_EQ(DataCounter::numDeletedDataCounters(), 1); + } + DataCounter::reset(); + + { + folly::Baton<> startBaton; + folly::Baton<> finishBaton; + auto asyncSource = std::make_shared>( + [&startBaton, &finishBaton]() { + startBaton.post(); + finishBaton.wait(); + return std::make_unique(); + }); + + auto thread = std::thread([&asyncSource] { asyncSource->prepare(); }); + EXPECT_TRUE(startBaton.try_wait_for(1s)); + + asyncSource->cancel(); + + finishBaton.post(); + thread.join(); + + EXPECT_EQ(DataCounter::numCreatedDataCounters(), 1); + EXPECT_TRUE(asyncSource->hasValue()); + asyncSource->close(); + EXPECT_EQ(1, asyncSource->prepareTiming().count); + } + DataCounter::reset(); + + { + auto asyncSource = std::make_shared>( + []() { return std::make_unique(); }); + asyncSource->prepare(); + + asyncSource->cancel(); + + EXPECT_FALSE(asyncSource->hasValue()); + EXPECT_EQ(asyncSource->move(), nullptr); + EXPECT_EQ(1, asyncSource->prepareTiming().count); + } + DataCounter::reset(); + + { + std::atomic_bool taskExecuted{false}; + auto asyncSource = + std::make_shared>([&taskExecuted]() { + taskExecuted = true; + return std::make_unique(); + }); + + asyncSource->cancel(); + asyncSource->prepare(); + EXPECT_FALSE(taskExecuted); + EXPECT_FALSE(asyncSource->hasValue()); + + EXPECT_EQ(asyncSource->move(), nullptr); + EXPECT_FALSE(taskExecuted); + } + + { + auto dataCounter = std::make_shared(); + auto asyncSource = std::make_shared>([dataCounter]() { + return std::make_unique(dataCounter->objectNumber()); + }); + dataCounter.reset(); + + asyncSource->cancel(); + asyncSource->cancel(); + EXPECT_EQ(DataCounter::numDeletedDataCounters(), 1); + EXPECT_EQ(0, asyncSource->prepareTiming().count); + } + DataCounter::reset(); + + { + folly::Baton<> moveStarted; + folly::Baton<> continueMove; + auto asyncSource = std::make_shared>( + [&moveStarted, &continueMove]() { + moveStarted.post(); + continueMove.wait(); + return std::make_unique(); + }); + + auto moveThread = std::thread([&asyncSource] { + auto result = asyncSource->move(); + EXPECT_NE(result, nullptr); + }); + + EXPECT_TRUE(moveStarted.try_wait_for(1s)); + + asyncSource->cancel(); + + continueMove.post(); + moveThread.join(); + EXPECT_EQ(DataCounter::numCreatedDataCounters(), 1); + } + DataCounter::reset(); + + { + auto asyncSource = std::make_shared>( + []() { return std::make_unique(); }); + + auto result = asyncSource->move(); + EXPECT_NE(result, nullptr); + EXPECT_EQ(DataCounter::numCreatedDataCounters(), 1); + + asyncSource->cancel(); + + EXPECT_EQ(1, asyncSource->prepareTiming().count); + EXPECT_EQ(DataCounter::numCreatedDataCounters(), 1); + } + DataCounter::reset(); + + // Cancel called after close() - should be no-op. + { + auto asyncSource = std::make_shared>( + []() { return std::make_unique(); }); + asyncSource->prepare(); + EXPECT_EQ(DataCounter::numCreatedDataCounters(), 1); + + asyncSource->close(); + EXPECT_EQ(DataCounter::numDeletedDataCounters(), 1); + + asyncSource->cancel(); + EXPECT_EQ(DataCounter::numDeletedDataCounters(), 1); + EXPECT_EQ(1, asyncSource->prepareTiming().count); + } +} + +TEST_F(AsyncSourceTest, multithreadedPrepareAndMove) { constexpr int32_t kNumThreads = 10; constexpr int32_t kNumGizmos = 2000; folly::Synchronized> results; std::vector>> gizmos; + gizmos.reserve(kNumGizmos); for (auto i = 0; i < kNumGizmos; ++i) { gizmos.push_back(std::make_shared>([i]() { std::this_thread::sleep_for(std::chrono::milliseconds(1)); // NOLINT @@ -64,16 +354,12 @@ TEST(AsyncSourceTest, threads) { std::vector threads; threads.reserve(kNumThreads); for (int32_t threadIndex = 0; threadIndex < kNumThreads; ++threadIndex) { - threads.push_back(std::thread([threadIndex, &gizmos, &results]() { + threads.emplace_back([threadIndex, &gizmos, &results]() { if (threadIndex < kNumThreads / 2) { - // The first half of the threads prepare Gizmos in the background. for (auto i = 0; i < kNumGizmos; ++i) { gizmos[i]->prepare(); } } else { - // The rest of the threads first get random Gizmos and then do a pass - // over all the Gizmos to make sure all get collected. We assert that - // each Gizmo is obtained once. folly::Random::DefaultGenerator rng; for (auto i = 0; i < kNumGizmos / 3; ++i) { auto gizmo = @@ -95,7 +381,7 @@ TEST(AsyncSourceTest, threads) { } } } - })); + }); } for (auto& thread : threads) { thread.join(); @@ -107,11 +393,12 @@ TEST(AsyncSourceTest, threads) { }); } -TEST(AsyncSourceTest, errorsWithThreads) { +TEST_F(AsyncSourceTest, multithreadedErrorHandling) { constexpr int32_t kNumGizmos = 50; constexpr int32_t kNumThreads = 10; std::vector>> gizmos; std::atomic numErrors{0}; + gizmos.reserve(kNumGizmos); for (auto i = 0; i < kNumGizmos; ++i) { gizmos.push_back( std::make_shared>([]() -> std::unique_ptr { @@ -123,16 +410,12 @@ TEST(AsyncSourceTest, errorsWithThreads) { std::vector threads; threads.reserve(kNumThreads); for (int32_t threadIndex = 0; threadIndex < kNumThreads; ++threadIndex) { - threads.push_back(std::thread([threadIndex, &gizmos, &numErrors]() { + threads.emplace_back([threadIndex, &gizmos, &numErrors]() { if (threadIndex < kNumThreads / 2) { - // The first half of the threads prepare Gizmos in the background. for (auto i = 0; i < kNumGizmos; ++i) { gizmos[i]->prepare(); } } else { - // The rest of the threads get random gizmos. They are - // expected to produce an error or nullptr in the event - // another thread is already waiting for the same gizmo. folly::Random::DefaultGenerator rng; for (auto i = 0; i < kNumGizmos / 3; ++i) { try { @@ -144,155 +427,253 @@ TEST(AsyncSourceTest, errorsWithThreads) { } } } - })); + }); } for (auto& thread : threads) { thread.join(); } - // There will always be errors since the first to wait for any given - // gizmo is sure to get an error. EXPECT_LT(0, numErrors); for (auto& source : gizmos) { source->close(); } } -class DataCounter { - public: - DataCounter() { - objectNumber_ = ++numCreatedDataCounters_; - } +DEBUG_ONLY_TEST_F(AsyncSourceTest, concurrentMoveSteal) { + // Test scenario: first move() waits for making + // it gets signaled through promises, but a second move() comes in + // between the wait completion and lock re-acquisition and steals the item. + // The first move() should get nothing. + folly::Baton<> makingStarted; + folly::Baton<> makingContinue; + folly::Baton<> firstMoveWaiting; + folly::Baton<> secondMoveComplete; + + auto asyncSource = + std::make_shared>([&makingStarted, &makingContinue]() { + makingStarted.post(); + makingContinue.wait(); + return std::make_unique(42); + }); - ~DataCounter() { - ++numDeletedDataCounters_; - } + std::atomic firstMoveResult{nullptr}; + std::atomic secondMoveResult{nullptr}; + std::unique_ptr firstMoveHolder; + std::unique_ptr secondMoveHolder; + + SCOPED_TESTVALUE_SET( + "facebook::velox::AsyncSource::makeWait", + std::function*)>([&](AsyncSource* source) { + // Signal that first move is about to re-acquire lock. + firstMoveWaiting.post(); + // Wait for second move to complete and steal the item. + secondMoveComplete.wait(); + })); + + // Thread 1: First move() - will wait for making and then get blocked by + // TestValue. + auto firstMoveThread = std::thread([&]() { + firstMoveHolder = asyncSource->move(); + firstMoveResult = firstMoveHolder.get(); + }); - static uint64_t numCreatedDataCounters() { - return numCreatedDataCounters_; - } + // Thread 2: prepare() - starts making the item. + auto prepareThread = std::thread([&]() { asyncSource->prepare(); }); - static uint64_t numDeletedDataCounters() { - return numDeletedDataCounters_; - } + // Wait for making to start. + ASSERT_TRUE(makingStarted.try_wait_for(1s)); - static void reset() { - numCreatedDataCounters_ = 0; - numDeletedDataCounters_ = 0; - } + // Let making complete - this will signal the first move's promise. + makingContinue.post(); - uint64_t objectNumber() const { - return objectNumber_; - } + // Wait for first move to be signaled and about to re-acquire lock. + ASSERT_TRUE(firstMoveWaiting.try_wait_for(1s)); - private: - static std::atomic numCreatedDataCounters_; - static std::atomic numDeletedDataCounters_; + // Thread 3: Second move() - steals the item while first move is blocked. + auto secondMoveThread = std::thread([&]() { + secondMoveHolder = asyncSource->move(); + secondMoveResult = secondMoveHolder.get(); + secondMoveComplete.post(); + }); - uint64_t objectNumber_{0}; -}; + firstMoveThread.join(); + secondMoveThread.join(); + prepareThread.join(); -std::atomic DataCounter::numCreatedDataCounters_ = 0; + // Second move should have stolen the item. + EXPECT_NE(secondMoveResult.load(), nullptr); + EXPECT_EQ(secondMoveResult.load()->id, 42); -std::atomic DataCounter::numDeletedDataCounters_ = 0; + // First move should get nothing because second move stole the item. + EXPECT_EQ(firstMoveResult.load(), nullptr); +} -TEST(AsyncSourceTest, close) { - // If 'prepare()' is not executed within the thread pool, invoking 'close()' - // will set 'make_' to nullptr. The deletion of 'dateCounter' is used as a - // verification for this behavior. - auto dateCounter = std::make_shared(); - AsyncSource countAsyncSource([dateCounter]() { - return std::make_unique(dateCounter->objectNumber()); +DEBUG_ONLY_TEST_F(AsyncSourceTest, concurrentMoveCloseRace) { + // Test scenario: Tests the race condition where move() + // preparation and close() sneaks in to grab the item first. + // + // Timeline: + // 1. prepare() starts making the item in a background thread + // 2. move() enters and waits for preparation to complete (blocked on + // promise) + // 3. prepare() completes and signals the promise + // 4. move() wakes up but TestValue blocks it before re-acquiring the lock + // 5. close() runs and transitions state to kFinished, clearing the item + // 6. move() finally re-acquires lock but finds state is kFinished + // 7. move() returns nullptr + // + // Expected: move() gets nullptr because close() grabbed the item first. + folly::Baton<> makingStarted; + folly::Baton<> makingContinue; + folly::Baton<> moveWaiting; + folly::Baton<> closeComplete; + + auto asyncSource = + std::make_shared>([&makingStarted, &makingContinue]() { + makingStarted.post(); + makingContinue.wait(); + return std::make_unique(42); + }); + + std::atomic moveResult{nullptr}; + std::unique_ptr moveHolder; + + SCOPED_TESTVALUE_SET( + "facebook::velox::AsyncSource::makeWait", + std::function*)>([&](AsyncSource* source) { + // Signal that move is about to re-acquire lock. + moveWaiting.post(); + // Wait for close to complete. + closeComplete.wait(); + })); + + // Thread 1: move() - will wait for making and then get blocked by TestValue. + auto moveThread = std::thread([&]() { + moveHolder = asyncSource->move(); + moveResult = moveHolder.get(); }); - dateCounter.reset(); - EXPECT_EQ(DataCounter::numCreatedDataCounters(), 1); - EXPECT_EQ(DataCounter::numDeletedDataCounters(), 0); - countAsyncSource.close(); - EXPECT_EQ(DataCounter::numCreatedDataCounters(), 1); - EXPECT_EQ(DataCounter::numDeletedDataCounters(), 1); - DataCounter::reset(); + // Thread 2: prepare() - starts making the item. + auto prepareThread = std::thread([&]() { asyncSource->prepare(); }); - // If 'prepare()' is executed within the thread pool but 'move()' is not - // invoked, invoking 'close()' will set 'item_' to nullptr. The deletion of - // 'dateCounter' is used as a verification for this behavior. - auto asyncSource = std::make_shared>( - []() { return std::make_unique(); }); - asyncSource->prepare(); - EXPECT_EQ(DataCounter::numCreatedDataCounters(), 1); - EXPECT_EQ(DataCounter::numDeletedDataCounters(), 0); + // Wait for making to start. + ASSERT_TRUE(makingStarted.try_wait_for(1s)); + // Let making complete - this will signal move's promise. + makingContinue.post(); + + // Wait for move to be signaled and about to re-acquire lock. + ASSERT_TRUE(moveWaiting.try_wait_for(1s)); + + // close() comes in and closes the item while move is blocked. asyncSource->close(); - EXPECT_EQ(DataCounter::numCreatedDataCounters(), 1); - EXPECT_EQ(DataCounter::numDeletedDataCounters(), 1); - DataCounter::reset(); + closeComplete.post(); - // If 'prepare()' is currently being executed within the thread pool, - // 'close()' should wait for the completion of 'prepare()' and set 'item_' to - // nullptr. - folly::Baton<> baton; - auto sleepAsyncSource = - std::make_shared>([&baton]() { - baton.post(); - return std::make_unique(); - }); - auto thread1 = - std::thread([&sleepAsyncSource] { sleepAsyncSource->prepare(); }); - EXPECT_TRUE(baton.try_wait_for(1s)); - sleepAsyncSource->close(); - EXPECT_EQ(DataCounter::numCreatedDataCounters(), 1); - EXPECT_EQ(DataCounter::numDeletedDataCounters(), 1); - thread1.join(); -} + moveThread.join(); + prepareThread.join(); -void verifyContexts( - const std::string& expectedPoolName, - const std::string& expectedTaskId) { - EXPECT_EQ(process::GetThreadDebugInfo()->taskId_, expectedTaskId); + // move() should get nothing because close() grabbed the item first. + EXPECT_EQ(moveResult.load(), nullptr); } -TEST(AsyncSourceTest, emptyContexts) { - EXPECT_EQ(process::GetThreadDebugInfo(), nullptr); +DEBUG_ONLY_TEST_F(AsyncSourceTest, concurrentCloseMoveRace) { + // Test scenario: Tests the race condition where close() + // preparation and move() sneaks in to grab the item first. + // + // Timeline: + // 1. prepare() starts making the item in a background thread + // 2. close() enters and waits for preparation to complete (blocked on + // promise) + // 3. prepare() completes and signals the promise + // 4. close() wakes up but TestValue blocks it before re-acquiring the lock + // 5. move() runs and transitions state to kFinished, taking the item + // 6. close() finally re-acquires lock and finds state is kFinished + // 7. close() completes successfully (nothing to close) + // + // Expected: move() gets the item, close() finds nothing to close. + folly::Baton<> makingStarted; + folly::Baton<> makingContinue; + folly::Baton<> closeWaiting; + folly::Baton<> moveComplete; + + auto asyncSource = + std::make_shared>([&makingStarted, &makingContinue]() { + makingStarted.post(); + makingContinue.wait(); + return std::make_unique(42); + }); - AsyncSource src([]() { - // The Contexts at the time this was created were null so we should inherit - // them from the caller. - verifyContexts("test", "task_id"); + std::atomic moveResult{nullptr}; + std::unique_ptr moveHolder; - return std::make_unique(true); - }); + SCOPED_TESTVALUE_SET( + "facebook::velox::AsyncSource::makeWait", + std::function*)>([&](AsyncSource* source) { + // Signal that close is about to re-acquire lock. + closeWaiting.post(); + // Wait for move to complete and take the item. + moveComplete.wait(); + })); - process::ThreadDebugInfo debugInfo{"query_id", "task_id", nullptr}; - process::ScopedThreadDebugInfo scopedDebugInfo(debugInfo); + // Thread 1: prepare() - starts making the item. + auto prepareThread = std::thread([&]() { asyncSource->prepare(); }); - verifyContexts("test", "task_id"); + // Wait for making to start. + ASSERT_TRUE(makingStarted.try_wait_for(1s)); - ASSERT_TRUE(*src.move()); + // Thread 2: close() - will wait for making and then get blocked by TestValue. + auto closeThread = std::thread([&]() { asyncSource->close(); }); - verifyContexts("test", "task_id"); -} + // Wait for close to be signaled and about to re-acquire lock. + ASSERT_TRUE(closeWaiting.try_wait_for(1s)); -TEST(AsyncSourceTest, setContexts) { - process::ThreadDebugInfo debugInfo1{"query_id1", "task_id1", nullptr}; + // Let making complete - this will signal close's promise. + makingContinue.post(); - std::unique_ptr> src; - process::ScopedThreadDebugInfo scopedDebugInfo1(debugInfo1); + std::this_thread::sleep_for(std::chrono::seconds(1)); // NOLINT - verifyContexts("test1", "task_id1"); + // move() comes in and takes the item while close is blocked. + moveHolder = asyncSource->move(); + moveResult = moveHolder.get(); + moveComplete.post(); - src = std::make_unique>(([]() { - // The Contexts at the time this was created were set so we should have - // the same contexts when this is executed. - verifyContexts("test1", "task_id1"); + closeThread.join(); + prepareThread.join(); - return std::make_unique(true); - })); + // move() should have taken the item. + EXPECT_NE(moveResult.load(), nullptr); + EXPECT_EQ(moveResult.load()->id, 42); +} - process::ThreadDebugInfo debugInfo2{"query_id2", "task_id2", nullptr}; - process::ScopedThreadDebugInfo scopedDebugInfo2(debugInfo2); +TEST_F(AsyncSourceTest, prepareTiming) { + auto asyncSource = std::make_shared>([]() { + std::this_thread::sleep_for(1s); + return std::make_unique(42); + }); - verifyContexts("test2", "task_id2"); + asyncSource->prepare(); - ASSERT_TRUE(*src->move()); + const auto& timing = asyncSource->prepareTiming(); + EXPECT_EQ(timing.count, 1); + EXPECT_GE(timing.wallNanos, 1'000'000'000); + asyncSource->close(); +} - verifyContexts("test2", "task_id2"); +TEST_F(AsyncSourceTest, itemMakerReturnsNull) { + // Test when itemMaker returns nullptr via prepare(). + { + auto asyncSource = std::make_shared>( + []() -> std::unique_ptr { return nullptr; }); + asyncSource->prepare(); + EXPECT_FALSE(asyncSource->hasValue()); + auto result = asyncSource->move(); + EXPECT_EQ(result, nullptr); + } + + // Test when itemMaker returns nullptr via move() (inline making). + { + auto asyncSource = std::make_shared>( + []() -> std::unique_ptr { return nullptr; }); + auto result = asyncSource->move(); + EXPECT_EQ(result, nullptr); + } } diff --git a/velox/common/base/tests/BitUtilTest.cpp b/velox/common/base/tests/BitUtilTest.cpp index 4450a5145ca..10a9ce911ad 100644 --- a/velox/common/base/tests/BitUtilTest.cpp +++ b/velox/common/base/tests/BitUtilTest.cpp @@ -18,7 +18,9 @@ #include "velox/common/base/Crc.h" #include "velox/type/HugeInt.h" -#include +#include +#include +#include #include #include @@ -946,6 +948,291 @@ TEST_F(BitUtilTest, divRoundUp) { bits::divRoundUp(testData.value, testData.factor), testData.expected); } } +TEST_F(BitUtilTest, bitsRequired) { + EXPECT_EQ(bitsRequired(0), 1); + EXPECT_EQ(bitsRequired(1), 1); + EXPECT_EQ(bitsRequired(2), 2); + EXPECT_EQ(bitsRequired(3), 2); + EXPECT_EQ(bitsRequired(4), 3); + EXPECT_EQ(bitsRequired(7), 3); + EXPECT_EQ(bitsRequired(8), 4); + EXPECT_EQ(bitsRequired(15), 4); + EXPECT_EQ(bitsRequired(16), 5); + EXPECT_EQ(bitsRequired(255), 8); + EXPECT_EQ(bitsRequired(256), 9); + EXPECT_EQ(bitsRequired(0xFFFFFFFF), 32); + EXPECT_EQ(bitsRequired(0x100000000ULL), 33); + EXPECT_EQ(bitsRequired(std::numeric_limits::max()), 64); + // Powers of two. + for (int i = 0; i < 63; ++i) { + EXPECT_EQ(bitsRequired(1ULL << i), i + 1); + } +} + +TEST_F(BitUtilTest, maybeSetBit) { + uint8_t bytes[16]; + memset(bytes, 0, sizeof(bytes)); + + // Setting with value=true should set the bit. + maybeSetBit(bytes, 0, true); + EXPECT_TRUE(isBitSet(bytes, 0)); + + maybeSetBit(bytes, 7, true); + EXPECT_TRUE(isBitSet(bytes, 7)); + + maybeSetBit(bytes, 64, true); + EXPECT_TRUE(isBitSet(bytes, 64)); + + // Setting with value=false should be a no-op (bit stays as-is). + maybeSetBit(bytes, 5, false); + EXPECT_FALSE(isBitSet(bytes, 5)); + + // Already-set bit stays set even with value=false (OR semantics). + maybeSetBit(bytes, 0, false); + EXPECT_TRUE(isBitSet(bytes, 0)); + + // Verify untouched bits are still clear. + EXPECT_FALSE(isBitSet(bytes, 1)); + EXPECT_FALSE(isBitSet(bytes, 6)); + + // Test with uint64_t pointer type. + uint64_t words[2] = {0, 0}; + maybeSetBit(words, 3, true); + EXPECT_TRUE(isBitSet(words, 3)); + maybeSetBit(words, 3, false); + EXPECT_TRUE(isBitSet(words, 3)); // OR semantics, stays set. + maybeSetBit(words, 100, true); + EXPECT_TRUE(isBitSet(words, 100)); +} + +TEST_F(BitUtilTest, packBitmap) { + // Pack 128 bools into a bitmap. + bool boolArray[128]; + for (int i = 0; i < 128; ++i) { + boolArray[i] = (i % 3 == 0); + } + + char bitmap[16]; + memset(bitmap, 0, sizeof(bitmap)); + packBitmap(std::span(boolArray, 128), bitmap); + + for (int i = 0; i < 128; ++i) { + EXPECT_EQ( + isBitSet(reinterpret_cast(bitmap), i), (i % 3 == 0)) + << "at bit " << i; + } + + // Test empty span. + char emptyBitmap[8]; + memset(emptyBitmap, 0, sizeof(emptyBitmap)); + packBitmap(std::span(), emptyBitmap); + for (int i = 0; i < 64; ++i) { + EXPECT_FALSE(isBitSet(reinterpret_cast(emptyBitmap), i)); + } + + // Test non-64-aligned count (e.g., 100 bools). + bool bools100[100]; + for (int i = 0; i < 100; ++i) { + bools100[i] = (i % 7 == 0); + } + char bitmap100[16]; + memset(bitmap100, 0, sizeof(bitmap100)); + packBitmap(std::span(bools100, 100), bitmap100); + for (int i = 0; i < 100; ++i) { + EXPECT_EQ( + isBitSet(reinterpret_cast(bitmap100), i), (i % 7 == 0)) + << "at bit " << i; + } + + // Test OR semantics: pre-existing bits are preserved. + char bitmapOr[8]; + memset(bitmapOr, 0, sizeof(bitmapOr)); + setBit(reinterpret_cast(bitmapOr), 1); + bool boolsOr[8] = {true, false, false, true, false, false, false, false}; + packBitmap(std::span(boolsOr, 8), bitmapOr); + EXPECT_TRUE(isBitSet(reinterpret_cast(bitmapOr), 0)); + EXPECT_TRUE( + isBitSet(reinterpret_cast(bitmapOr), 1)); // preserved + EXPECT_TRUE(isBitSet(reinterpret_cast(bitmapOr), 3)); +} + +TEST_F(BitUtilTest, findSetBit) { + // Build a bitmap with known set bits. + uint64_t words[4] = {0, 0, 0, 0}; + auto bitmap = reinterpret_cast(words); + + // Set bits at positions 5, 10, 15, 64, 100, 200. + setBit(words, 5); + setBit(words, 10); + setBit(words, 15); + setBit(words, 64); + setBit(words, 100); + setBit(words, 200); + + // Find the 1st set bit starting from 0. + EXPECT_EQ(findSetBit(bitmap, 0, 256, 1), 5u); + // Find the 2nd set bit. + EXPECT_EQ(findSetBit(bitmap, 0, 256, 2), 10u); + // Find the 3rd set bit. + EXPECT_EQ(findSetBit(bitmap, 0, 256, 3), 15u); + // Find the 4th set bit. + EXPECT_EQ(findSetBit(bitmap, 0, 256, 4), 64u); + // Find the 5th set bit. + EXPECT_EQ(findSetBit(bitmap, 0, 256, 5), 100u); + // Find the 6th set bit. + EXPECT_EQ(findSetBit(bitmap, 0, 256, 6), 200u); + // 7th set bit doesn't exist. + EXPECT_EQ(findSetBit(bitmap, 0, 256, 7), 256u); + + // Search starting from a non-zero begin. + EXPECT_EQ(findSetBit(bitmap, 6, 256, 1), 10u); + EXPECT_EQ(findSetBit(bitmap, 11, 256, 1), 15u); + EXPECT_EQ(findSetBit(bitmap, 16, 256, 1), 64u); + + // Search with restricted end. + EXPECT_EQ(findSetBit(bitmap, 0, 14, 3), 14u); // Only 2 set bits in [0,14) + + // Edge cases. + EXPECT_EQ(findSetBit(bitmap, 0, 256, 0), 0u); // n=0 returns begin + EXPECT_EQ(findSetBit(bitmap, 10, 10, 1), 10u); // begin==end returns begin + EXPECT_EQ(findSetBit(bitmap, 20, 10, 1), 20u); // begin>end returns begin + + // All bits set in a word. + uint64_t allSet[2] = {~0ULL, ~0ULL}; + auto allSetBitmap = reinterpret_cast(allSet); + EXPECT_EQ(findSetBit(allSetBitmap, 0, 128, 1), 0u); + EXPECT_EQ(findSetBit(allSetBitmap, 0, 128, 64), 63u); + EXPECT_EQ(findSetBit(allSetBitmap, 0, 128, 65), 64u); + EXPECT_EQ(findSetBit(allSetBitmap, 0, 128, 128), 127u); + EXPECT_EQ(findSetBit(allSetBitmap, 0, 128, 129), 128u); +} + +TEST_F(BitUtilTest, printBitsTest) { + // uint8_t: 0 should be all zeros. + EXPECT_EQ(printBits(static_cast(0)), "0000 0000"); + // uint8_t: 0xFF should be all ones. + EXPECT_EQ(printBits(static_cast(0xFF)), "1111 1111"); + // uint8_t: 0xA5 = 1010 0101. + EXPECT_EQ(printBits(static_cast(0xA5)), "1010 0101"); + // uint16_t. + EXPECT_EQ(printBits(static_cast(0x00FF)), "0000 0000 1111 1111"); + // uint32_t: 1. + EXPECT_EQ( + printBits(static_cast(1)), + "0000 0000 0000 0000 0000 0000 0000 0001"); +} + +TEST_F(BitUtilTest, bitmap) { + uint8_t data[16]; + memset(data, 0, sizeof(data)); + setBit(data, 0); + setBit(data, 7); + setBit(data, 8); + setBit(data, 63); + setBit(data, 64); + setBit(data, 127); + + Bitmap bm(data, 128); + EXPECT_EQ(bm.size(), 128u); + EXPECT_EQ(bm.bits(), data); + EXPECT_TRUE(bm.test(0)); + EXPECT_FALSE(bm.test(1)); + EXPECT_TRUE(bm.test(7)); + EXPECT_TRUE(bm.test(8)); + EXPECT_FALSE(bm.test(9)); + EXPECT_TRUE(bm.test(63)); + EXPECT_TRUE(bm.test(64)); + EXPECT_FALSE(bm.test(65)); + EXPECT_TRUE(bm.test(127)); +} + +TEST_F(BitUtilTest, bitmapBuilder) { + uint8_t data[16]; + memset(data, 0, sizeof(data)); + + BitmapBuilder builder(data, 128); + + // Test single-bit set. + builder.set(5); + EXPECT_TRUE(builder.test(5)); + EXPECT_FALSE(builder.test(4)); + + // Test maybeSet. + builder.maybeSet(10, true); + EXPECT_TRUE(builder.test(10)); + builder.maybeSet(11, false); + EXPECT_FALSE(builder.test(11)); + + // Test range set. + builder.set(20, 30); + for (uint32_t i = 20; i < 30; ++i) { + EXPECT_TRUE(builder.test(i)) << "at " << i; + } + EXPECT_FALSE(builder.test(19)); + EXPECT_FALSE(builder.test(30)); + + // Test range clear. + builder.clear(22, 28); + EXPECT_TRUE(builder.test(20)); + EXPECT_TRUE(builder.test(21)); + for (uint32_t i = 22; i < 28; ++i) { + EXPECT_FALSE(builder.test(i)) << "at " << i; + } + EXPECT_TRUE(builder.test(28)); + EXPECT_TRUE(builder.test(29)); +} + +TEST_F(BitUtilTest, bitmapBuilderCopy) { + uint8_t src[16]; + uint8_t dst[16]; + memset(src, 0, sizeof(src)); + memset(dst, 0, sizeof(dst)); + + // Set some bits in src. + setBit(src, 0); + setBit(src, 3); + setBit(src, 7); + setBit(src, 8); + setBit(src, 15); + setBit(src, 50); + setBit(src, 63); + setBit(src, 64); + setBit(src, 100); + + // Set a bit in dst that is before the copy range; it should be preserved. + setBit(dst, 1); + + Bitmap srcBm(src, 128); + BitmapBuilder dstBuilder(dst, 128); + + // Copy range [3, 101). + dstBuilder.copy(srcBm, 3, 101); + + // Bit 1 in dst should still be set (before copy range). + EXPECT_TRUE(isBitSet(dst, 1)); + // Bits from src in [3, 101) should now be in dst. + EXPECT_TRUE(isBitSet(dst, 3)); + EXPECT_TRUE(isBitSet(dst, 7)); + EXPECT_TRUE(isBitSet(dst, 8)); + EXPECT_TRUE(isBitSet(dst, 15)); + EXPECT_TRUE(isBitSet(dst, 50)); + EXPECT_TRUE(isBitSet(dst, 63)); + EXPECT_TRUE(isBitSet(dst, 64)); + EXPECT_TRUE(isBitSet(dst, 100)); + + // Bits that are not set in src within the range should be clear in dst. + EXPECT_FALSE(isBitSet(dst, 4)); + EXPECT_FALSE(isBitSet(dst, 9)); + EXPECT_FALSE(isBitSet(dst, 65)); + + // Copy byte-aligned range. + memset(dst, 0xFF, sizeof(dst)); + dstBuilder.copy(srcBm, 0, 128); + for (int i = 0; i < 128; ++i) { + EXPECT_EQ(isBitSet(dst, i), isBitSet(src, i)) << "at bit " << i; + } +} + } // namespace bits } // namespace velox } // namespace facebook diff --git a/velox/common/base/tests/BloomFilterTest.cpp b/velox/common/base/tests/BloomFilterTest.cpp index 85dcec0cb47..28cf5585c98 100644 --- a/velox/common/base/tests/BloomFilterTest.cpp +++ b/velox/common/base/tests/BloomFilterTest.cpp @@ -75,8 +75,9 @@ TEST_F(BloomFilterTest, staticMayContain) { bloom.serialize(serializedBloom.data()); int32_t numFalsePositives = 0; for (auto i = 0; i < kSize; ++i) { - EXPECT_TRUE(BloomFilter<>::mayContain( - serializedBloom.data(), folly::hasher()(i))); + EXPECT_TRUE( + BloomFilter<>::mayContain( + serializedBloom.data(), folly::hasher()(i))); const uint64_t smallValueHash = folly::hasher()(i + kSize); const bool isFalsePositiveForSmallValue = @@ -121,3 +122,41 @@ TEST_F(BloomFilterTest, merge) { EXPECT_EQ(bloom.serializedSize(), merge.serializedSize()); } + +TEST_F(BloomFilterTest, corruptMergeSize) { + // Serialization format: int8_t version (1) + int32_t size. + // Craft data with negative size to verify validation. + std::string data(5, '\0'); + data[0] = 1; // kBloomFilterV1 + // Write size = -1 as little-endian int32_t at offset 1. + int32_t badSize = -1; + memcpy(&data[1], &badSize, sizeof(badSize)); + + BloomFilter bloom; + EXPECT_THROW(bloom.merge(data.data()), VeloxRuntimeError); +} + +TEST_F(BloomFilterTest, optimalNumOfBitsWithFpp) { + EXPECT_EQ(BloomFilter<>::optimalNumOfBits(1000, 0.03), 7298); + EXPECT_EQ(BloomFilter<>::optimalNumOfBits(1000000, 0.01), 9585058); + EXPECT_EQ(BloomFilter<>::optimalNumOfBits(1, 0.5), 1); + EXPECT_EQ(BloomFilter<>::optimalNumOfBits(1000, 0.001), 14377); +} + +TEST_F(BloomFilterTest, optimalNumOfBitsWithMaxItems) { + constexpr int64_t kMaxNumItems = 4'000'000L; + + EXPECT_EQ( + BloomFilter<>::optimalNumOfBits(kMaxNumItems, kMaxNumItems), 29'193'763); + + EXPECT_EQ( + BloomFilter<>::optimalNumOfBits(1'000'000L, kMaxNumItems), 10'183'830); + + EXPECT_EQ(BloomFilter<>::optimalNumOfBits(100L, kMaxNumItems), 2935); + + EXPECT_EQ( + BloomFilter<>::optimalNumOfBits(5'000'000L, kMaxNumItems), 36'492'204); + + EXPECT_EQ( + BloomFilter<>::optimalNumOfBits(10'000'000L, kMaxNumItems), 72'984'408); +} diff --git a/velox/common/base/tests/CMakeLists.txt b/velox/common/base/tests/CMakeLists.txt index aae21a5355c..4eb30c7c57a 100644 --- a/velox/common/base/tests/CMakeLists.txt +++ b/velox/common/base/tests/CMakeLists.txt @@ -31,11 +31,12 @@ add_executable( SimdUtilTest.cpp SkewedPartitionBalancerTest.cpp SpillConfigTest.cpp - SpillStatsTest.cpp + SplitBlockBloomFilterTest.cpp StatsReporterTest.cpp StatusTest.cpp SuccinctPrinterTest.cpp ) +velox_add_test_headers(velox_base_test GTestUtils.h) add_test(velox_base_test velox_base_test) @@ -48,7 +49,7 @@ target_link_libraries( velox_time velox_status velox_exception - velox_temp_path + velox_test_util Boost::filesystem Boost::headers Folly::folly @@ -80,3 +81,9 @@ target_link_libraries( velox_memcpy_meter PRIVATE velox_common_base velox_exception velox_time Folly::folly gflags::gflags ) + +velox_add_library(velox_gtest_utils INTERFACE HEADERS GTestUtils.h) + +velox_add_library(velox_stats_reporter_utils INTERFACE HEADERS StatsReporterUtils.h) + +velox_add_library(velox_float_constants INTERFACE HEADERS FloatConstants.h) diff --git a/velox/common/base/tests/ConcurrentCounterTest.cpp b/velox/common/base/tests/ConcurrentCounterTest.cpp index c96fc1ac957..7945cb7e6c8 100644 --- a/velox/common/base/tests/ConcurrentCounterTest.cpp +++ b/velox/common/base/tests/ConcurrentCounterTest.cpp @@ -18,6 +18,7 @@ #include #include +#include #include #include "velox/common/base/tests/GTestUtils.h" @@ -48,7 +49,7 @@ class ConcurrentCounterTest : public testing::TestWithParam { void setupCounter() { counter_ = std::make_unique>( - std::thread::hardware_concurrency()); + folly::available_concurrency()); } const bool useUpdateFn_{GetParam()}; @@ -72,8 +73,8 @@ TEST_P(ConcurrentCounterTest, multithread) { const int32_t numUpdatesPerThread = 5'000; std::vector numThreads; numThreads.push_back(1); - numThreads.push_back(std::thread::hardware_concurrency()); - numThreads.push_back(std::thread::hardware_concurrency() * 2); + numThreads.push_back(folly::available_concurrency()); + numThreads.push_back(folly::available_concurrency() * 2); for (int numThreads : numThreads) { SCOPED_TRACE(fmt::format("numThreads: {}", numThreads)); counter_->testingClear(); diff --git a/velox/common/base/tests/ExceptionTest.cpp b/velox/common/base/tests/ExceptionTest.cpp index d057eb2dae3..c58ca70539b 100644 --- a/velox/common/base/tests/ExceptionTest.cpp +++ b/velox/common/base/tests/ExceptionTest.cpp @@ -16,6 +16,7 @@ #include #include +#include #include #include "velox/common/base/Exceptions.h" @@ -65,7 +66,7 @@ void verifyVeloxException( std::function f, const std::string& messagePrefix) { verifyException(f, [&messagePrefix](const auto& e) { - EXPECT_TRUE(folly::StringPiece{e.what()}.startsWith(messagePrefix)) + EXPECT_TRUE(std::string_view{e.what()}.starts_with(messagePrefix)) << "\nException message prefix mismatch.\n\nExpected prefix: " << messagePrefix << "\n\nActual message: " << e.what(); }); @@ -106,11 +107,12 @@ void testExceptionTraceCollectionControl(bool userException, bool enabled) { false); } } catch (VeloxException& e) { - SCOPED_TRACE(fmt::format( - "enabled: {}, user flag: {}, sys flag: {}", - enabled, - FLAGS_velox_exception_user_stacktrace_enabled, - FLAGS_velox_exception_system_stacktrace_enabled)); + SCOPED_TRACE( + fmt::format( + "enabled: {}, user flag: {}, sys flag: {}", + enabled, + FLAGS_velox_exception_user_stacktrace_enabled, + FLAGS_velox_exception_system_stacktrace_enabled)); ASSERT_EQ(userException, e.exceptionType() == VeloxException::Type::kUser); ASSERT_EQ(enabled, e.stackTrace() != nullptr); } @@ -170,12 +172,13 @@ void testExceptionTraceCollectionRateControl( false); } } catch (VeloxException& e) { - SCOPED_TRACE(fmt::format( - "userException: {}, hasRateLimit: {}, user limit: {}ms, sys limit: {}ms", - userException, - hasRateLimit, - FLAGS_velox_exception_user_stacktrace_rate_limit_ms, - FLAGS_velox_exception_system_stacktrace_rate_limit_ms)); + SCOPED_TRACE( + fmt::format( + "userException: {}, hasRateLimit: {}, user limit: {}ms, sys limit: {}ms", + userException, + hasRateLimit, + FLAGS_velox_exception_user_stacktrace_rate_limit_ms, + FLAGS_velox_exception_system_stacktrace_rate_limit_ms)); ASSERT_EQ( userException, e.exceptionType() == VeloxException::Type::kUser); ASSERT_EQ(!hasRateLimit || ((iter % 2) == 0), e.stackTrace() != nullptr); @@ -1009,12 +1012,123 @@ TEST(ExceptionTest, exceptionMacroInlining) { } catch (const VeloxUserError& ve) { ASSERT_EQ(ve.message(), errorStr); } +} + +TEST(ExceptionTest, messageTemplate) { + using testing::Property; + using testing::StrEq; + using testing::Throws; + + auto noMsg = [] { VELOX_FAIL(); }; + EXPECT_THAT( + noMsg, + Throws( + Property(&VeloxException::messageTemplate, StrEq("")))); + + auto plainMsg = [] { VELOX_FAIL("Something went wrong"); }; + EXPECT_THAT( + plainMsg, + Throws(Property( + &VeloxException::messageTemplate, StrEq("Something went wrong")))); + + // message() is interpolated, messageTemplate() is not. + auto fmtMsg = [] { VELOX_FAIL("Error: {} vs {}", 42, 99); }; + EXPECT_THAT( + fmtMsg, + Throws(Property( + &VeloxException::messageTemplate, StrEq("Error: {} vs {}")))); + EXPECT_THAT( + fmtMsg, + Throws( + Property(&VeloxException::message, StrEq("Error: 42 vs 99")))); + + auto userFail = [] { VELOX_USER_FAIL("Not supported: {}", "rank"); }; + EXPECT_THAT( + userFail, + Throws(Property( + &VeloxException::messageTemplate, StrEq("Not supported: {}")))); + + auto checkFail = [] { VELOX_CHECK(false, "Bad state: {}", "disconnected"); }; + EXPECT_THAT( + checkFail, + Throws( + Property(&VeloxException::messageTemplate, StrEq("Bad state: {}")))); + + auto checkEq = [] { VELOX_CHECK_EQ(1, 2); }; + EXPECT_THAT( + checkEq, + Throws( + Property(&VeloxException::messageTemplate, StrEq("({} vs. {})")))); + + auto checkLt = [] { VELOX_CHECK_LT(10, 5, "Expected smaller"); }; + EXPECT_THAT( + checkLt, + Throws(Property( + &VeloxException::messageTemplate, + StrEq("({} vs. {}) Expected smaller")))); + + auto unsupported = [] { VELOX_UNSUPPORTED("{} not supported", "merge"); }; + EXPECT_THAT( + unsupported, + Throws(Property( + &VeloxException::messageTemplate, StrEq("{} not supported")))); + + auto nyi = [] { VELOX_NYI("{} not implemented", "windowing"); }; + EXPECT_THAT( + nyi, + Throws(Property( + &VeloxException::messageTemplate, StrEq("{} not implemented")))); + + // Wrapped exception has no explicit template, so messageTemplate() returns + // the message itself. + { + auto exceptionPtr = + std::make_exception_ptr(std::invalid_argument("wrapped")); + VeloxUserError ve(exceptionPtr, "wrapped", false); + EXPECT_EQ(ve.messageTemplate(), "wrapped"); + } - // Inlined with the method that passes the errorStr and the next argument via - // fmt::vformat. Should throw format_error. + // messageTemplate not in what(). try { - VELOX_USER_FAIL(errorStr, "definitely"); - } catch (const std::exception& e) { - ASSERT_TRUE(folly::StringPiece{e.what()}.startsWith("argument not found")); + VELOX_FAIL("Count: {}", 42); + FAIL() << "Expected an exception"; + } catch (const VeloxException& e) { + EXPECT_EQ(e.messageTemplate(), "Count: {}"); + EXPECT_EQ( + std::string(e.what()).find("Message Template:"), std::string::npos); + } + + // No message arg. + auto checkNotNull = [] { + int* ptr = nullptr; + VELOX_CHECK_NOT_NULL(ptr); + }; + EXPECT_THAT( + checkNotNull, + Throws( + Property(&VeloxException::messageTemplate, StrEq("")))); + + // VELOX_CHECK_EQ with user-provided format args. + auto checkEqFmt = [] { VELOX_CHECK_EQ(1, 2, "Expected {} == {}", 1, 2); }; + EXPECT_THAT( + checkEqFmt, + Throws(Property( + &VeloxException::messageTemplate, + StrEq("({} vs. {}) Expected {} == {}")))); + + // Exception copy preserves messageTemplate. + try { + VELOX_FAIL("Copy test: {}", 42); + FAIL() << "Expected an exception"; + } catch (const VeloxException& e) { + const VeloxException& copy = e; + EXPECT_EQ(copy.messageTemplate(), "Copy test: {}"); + + auto rethrown = std::current_exception(); + try { + std::rethrow_exception(rethrown); + } catch (const VeloxException& e2) { + EXPECT_EQ(e2.messageTemplate(), "Copy test: {}"); + } } } diff --git a/velox/common/base/tests/FsTest.cpp b/velox/common/base/tests/FsTest.cpp index 10e4dafee67..e50f1639b4d 100644 --- a/velox/common/base/tests/FsTest.cpp +++ b/velox/common/base/tests/FsTest.cpp @@ -17,14 +17,16 @@ #include "velox/common/base/Fs.h" #include #include "boost/filesystem.hpp" -#include "velox/exec/tests/utils/TempDirectoryPath.h" +#include "velox/common/testutil/TempDirectoryPath.h" + +using namespace facebook::velox::common::testutil; namespace facebook::velox::common { class FsTest : public testing::Test {}; TEST_F(FsTest, createDirectory) { - auto dir = exec::test::TempDirectoryPath::create(); + auto dir = TempDirectoryPath::create(); auto rootPath = dir->getPath(); auto tmpDirectoryPath = rootPath + "/first/second/third"; // First time should generate directory successfully. diff --git a/velox/common/base/tests/GTestUtils.h b/velox/common/base/tests/GTestUtils.h index ce88abf0e05..42ec23513c4 100644 --- a/velox/common/base/tests/GTestUtils.h +++ b/velox/common/base/tests/GTestUtils.h @@ -117,6 +117,22 @@ } \ } +#define VELOX_ASSERT_EQ_TYPES(actual, expected) \ + { \ + auto _actualType = (actual); \ + auto _expectedType = (expected); \ + if (_expectedType != nullptr) { \ + ASSERT_TRUE(_actualType != nullptr) \ + << "Expected: " << _expectedType->toString() << ", got null"; \ + ASSERT_EQ(*_actualType, *_expectedType) \ + << "Expected: " << _expectedType->toString() << ", got " \ + << _actualType->toString(); \ + } else { \ + ASSERT_EQ(_actualType, nullptr) \ + << "Expected null, got " << _actualType->toString(); \ + } \ + } + #ifndef NDEBUG #define DEBUG_ONLY_TEST(test_fixture, test_name) TEST(test_fixture, test_name) #define DEBUG_ONLY_TEST_F(test_fixture, test_name) \ diff --git a/velox/common/base/tests/RuntimeMetricsTest.cpp b/velox/common/base/tests/RuntimeMetricsTest.cpp index 7180f682a80..29a5ffe65ae 100644 --- a/velox/common/base/tests/RuntimeMetricsTest.cpp +++ b/velox/common/base/tests/RuntimeMetricsTest.cpp @@ -24,7 +24,7 @@ class RuntimeMetricsTest : public testing::Test { static void testMetric( const RuntimeMetric& rm1, int64_t expectedSum, - int64_t expectedCount, + uint64_t expectedCount, int64_t expectedMin = std::numeric_limits::max(), int64_t expectedMax = std::numeric_limits::min()) { EXPECT_EQ(expectedSum, rm1.sum); @@ -84,4 +84,19 @@ TEST_F(RuntimeMetricsTest, basic) { "sum:2.00us, count:1, min:2.00us, max:2.00us, avg: 2.00us"); } +TEST_F(RuntimeMetricsTest, saturateCast) { + auto maxUint64 = std::numeric_limits::max(); + RuntimeMetric rm{ + saturateCast(maxUint64), + maxUint64, + saturateCast(maxUint64), + saturateCast(maxUint64)}; + + auto maxInt64 = std::numeric_limits::max(); + EXPECT_EQ(rm.sum, maxInt64); + EXPECT_EQ(rm.count, maxUint64); + EXPECT_EQ(rm.min, maxInt64); + EXPECT_EQ(rm.max, maxInt64); +} + } // namespace facebook::velox diff --git a/velox/common/base/tests/SimdUtilTest.cpp b/velox/common/base/tests/SimdUtilTest.cpp index 447cc55a6e2..5573df46028 100644 --- a/velox/common/base/tests/SimdUtilTest.cpp +++ b/velox/common/base/tests/SimdUtilTest.cpp @@ -126,6 +126,7 @@ class SimdUtilTest : public testing::Test { folly::Random::DefaultGenerator rng_; }; +#ifdef VELOX_ENABLE_LOAD_SIMD_VALUE_BUFFER TEST_F(SimdUtilTest, setAll) { auto bits = simd::setAll(true); auto words = reinterpret_cast(&bits); @@ -133,6 +134,7 @@ TEST_F(SimdUtilTest, setAll) { EXPECT_EQ(words[i], -1ll); } } +#endif TEST_F(SimdUtilTest, bitIndices) { testIndices(1); @@ -385,6 +387,10 @@ TEST_F(SimdUtilTest, crc32) { EXPECT_EQ(checksum, 3531890030); checksum = simd::crc32U64(0, 987654321); EXPECT_EQ(checksum, 121285919); + checksum = simd::crc32U64(0, 123456789, xsimd::generic{}); + EXPECT_EQ(checksum, 3531890030); + checksum = simd::crc32U64(0, 987654321, xsimd::generic{}); + EXPECT_EQ(checksum, 121285919); } TEST_F(SimdUtilTest, Batch64_assign) { @@ -654,4 +660,76 @@ TEST_F(SimdUtilTest, randomStringStrStr) { } } +TEST_F(SimdUtilTest, simdFill) { + // Test supported types: SIMD path (int32_t, uint32_t, int64_t, uint64_t, + // float, double). + auto testFill = [](auto value) { + using T = decltype(value); + struct TestParam { + uint32_t count; + std::string debugString() const { + return fmt::format("count {}", count); + } + }; + std::vector testSettings = { + {0}, + {1}, + {3}, + {7}, + {8}, + {9}, + {15}, + {16}, + {17}, + {31}, + {32}, + {33}, + {63}, + {64}, + {100}, + {255}, + {256}, + {1'000}, + {10'000}}; + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + std::vector output(testData.count, T{}); + simd::simdFill(output.data(), value, testData.count); + for (uint32_t i = 0; i < testData.count; ++i) { + ASSERT_EQ(output[i], value) << "at index " << i; + } + } + }; + + testFill(static_cast(42)); + testFill(static_cast(42)); + testFill(static_cast(123'456'789)); + testFill(static_cast(123'456'789)); + testFill(3.14f); + testFill(2.718281828); + + // Test unsupported types: fallback to std::fill (int8_t, int16_t, bool). + { + std::vector output(100, 0); + simd::simdFill(output.data(), static_cast(7), 100u); + for (auto v : output) { + ASSERT_EQ(v, 7); + } + } + { + std::vector output(100, 0); + simd::simdFill(output.data(), static_cast(1'234), 100u); + for (auto v : output) { + ASSERT_EQ(v, 1'234); + } + } + { + bool output[100] = {}; + simd::simdFill(output, true, 100u); + for (int i = 0; i < 100; ++i) { + ASSERT_EQ(output[i], true); + } + } +} + } // namespace diff --git a/velox/common/base/tests/SkewedPartitionBalancerTest.cpp b/velox/common/base/tests/SkewedPartitionBalancerTest.cpp index 5edb57563c6..52e592f9351 100644 --- a/velox/common/base/tests/SkewedPartitionBalancerTest.cpp +++ b/velox/common/base/tests/SkewedPartitionBalancerTest.cpp @@ -19,7 +19,7 @@ #include #include "folly/Random.h" -#include "folly/experimental/EventCount.h" +#include "folly/synchronization/EventCount.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/testutil/TestValue.h" @@ -76,6 +76,10 @@ class SkewedPartitionRebalancerTestHelper { class SkewedPartitionRebalancerTest : public testing::Test { protected: + static void SetUpTestSuite() { + TestValue::enable(); + } + std::unique_ptr createBalancer( uint32_t numPartitions = 128, uint32_t numTasks = 8, @@ -330,10 +334,18 @@ DEBUG_ONLY_TEST_F(SkewedPartitionRebalancerTest, serializedRebalanceExecution) { SkewedPartitionRebalancerTestHelper helper(balancer.get()); folly::EventCount rebalanceWait; std::atomic_bool rebalanceWaitFlag{true}; + + // Used to signal that the background thread has entered the test value + // callback. + folly::EventCount mainThreadRebalancerWait; + std::atomic_bool mainThreadRebalancerWaitFlag{true}; SCOPED_TESTVALUE_SET( "facebook::velox::common::SkewedPartitionRebalancer::rebalancePartitions", std::function( [&](SkewedPartitionRebalancer*) { + // Signal that we've entered the callback. + mainThreadRebalancerWaitFlag = false; + mainThreadRebalancerWait.notifyAll(); rebalanceWait.await([&] { return !rebalanceWaitFlag.load(); }); })); @@ -344,6 +356,12 @@ DEBUG_ONLY_TEST_F(SkewedPartitionRebalancerTest, serializedRebalanceExecution) { std::thread rebalanceThread([&]() { balancer->rebalance(); }); + // Wait for the background thread to enter the test value callback and block + // there. This ensures that when we call rebalance() from the main thread, + // rebalancing_ is already true and our rebalance() call will return early. + mainThreadRebalancerWait.await( + [&] { return !mainThreadRebalancerWaitFlag.load(); }); + balancer->rebalance(); rebalanceWaitFlag = false; @@ -409,8 +427,9 @@ TEST_F(SkewedPartitionRebalancerTest, concurrentFuzz) { threads.emplace_back([&]() { std::mt19937 localRng{200}; for (int iteration = 0; iteration < 1'000; ++iteration) { - SCOPED_TRACE(fmt::format( - "taskCount {}, iteration {}", taskCount, iteration)); + SCOPED_TRACE( + fmt::format( + "taskCount {}, iteration {}", taskCount, iteration)); const uint64_t processedBytes = 1 + folly::Random::rand32(512, localRng); balancer->addProcessedBytes(processedBytes); diff --git a/velox/common/base/tests/SpillConfigTest.cpp b/velox/common/base/tests/SpillConfigTest.cpp index 9949486a486..30251da5c9b 100644 --- a/velox/common/base/tests/SpillConfigTest.cpp +++ b/velox/common/base/tests/SpillConfigTest.cpp @@ -48,6 +48,7 @@ TEST_P(SpillConfigTest, spillLevel) { 0, 0, "none", + 0, prefixSortConfig_); struct { uint8_t bitOffset; @@ -134,6 +135,7 @@ TEST_P(SpillConfigTest, spillLevelLimit) { 0, 0, "none", + 0, prefixSortConfig_); ASSERT_EQ( @@ -181,6 +183,7 @@ TEST_P(SpillConfigTest, spillableReservationPercentages) { 1'000'000, 0, "none", + 0, prefixSortConfig_); }; diff --git a/velox/common/base/tests/SplitBlockBloomFilterTest.cpp b/velox/common/base/tests/SplitBlockBloomFilterTest.cpp new file mode 100644 index 00000000000..2a98ea75634 --- /dev/null +++ b/velox/common/base/tests/SplitBlockBloomFilterTest.cpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/base/SplitBlockBloomFilter.h" +#include "velox/common/testutil/RandomSeed.h" + +#include +#include + +#include + +namespace facebook::velox::test { +namespace { + +template +SplitBlockBloomFilter makeFilter( + const folly::F14FastSet& values, + const Hasher& hasher, + std::vector& blocks) { + blocks.resize(SplitBlockBloomFilter::numBlocks(values.size(), 0.01)); + bzero(blocks.data(), blocks.size() * sizeof(SplitBlockBloomFilter::Block)); + SplitBlockBloomFilter filter(blocks); + for (auto& value : values) { + filter.insert(hasher(value)); + } + return filter; +} + +TEST(SplitBlockBloomFilterTest, numBlocks) { + ASSERT_EQ( + SplitBlockBloomFilter::numBlocks(50'000'000, 0.01) * + sizeof(SplitBlockBloomFilter::Block), + xsimd::batch::size == 8 ? 60509568 : 65766912); + ASSERT_EQ( + SplitBlockBloomFilter::numBlocks(45'523'964, 0.1) * + sizeof(SplitBlockBloomFilter::Block), + xsimd::batch::size == 8 ? 32848640 : 27546352); +} + +TEST(SplitBlockBloomFilterTest, contiguous) { + constexpr int kSize = 100'000; + std::default_random_engine gen(common::testutil::getRandomSeed(42)); + std::uniform_int_distribution<> dist(0, 9); + folly::F14FastSet values; + values.reserve(kSize / 10); + for (int i = 0; i < kSize; ++i) { + if (dist(gen) == 0) { + values.insert(i); + } + } + std::vector blocks; + auto test = [&](auto hasher) { + auto filter = makeFilter(values, hasher, blocks); + int numFalsePositive = 0; + for (int i = 0; i < kSize; ++i) { + if (values.contains(i)) { + ASSERT_TRUE(filter.mayContain(hasher(i))); + } else { + numFalsePositive += filter.mayContain(hasher(i)); + } + } + ASSERT_LT(1.0 * numFalsePositive / kSize, 0.03); + }; + { + SCOPED_TRACE("Folly"); + test(folly::hasher()); + } + { + SCOPED_TRACE("Multiplication"); + test([](auto x) { return x * 0xc6a4a7935bd1e995L; }); + } +} + +TEST(SplitBlockBloomFilterTest, random) { + constexpr int kSize = 100'000; + std::default_random_engine gen(common::testutil::getRandomSeed(42)); + std::uniform_int_distribution dist; + folly::F14FastSet values; + values.reserve(kSize); + for (int i = 0; i < kSize; ++i) { + values.insert(dist(gen)); + } + std::vector blocks; + auto test = [&](auto hasher) { + auto filter = makeFilter(values, hasher, blocks); + for (auto value : values) { + ASSERT_TRUE(filter.mayContain(hasher(value))); + } + int numFalsePositive = 0; + for (int i = 0; i < kSize; ++i) { + auto value = dist(gen); + if (!values.contains(value)) { + numFalsePositive += filter.mayContain(hasher(value)); + } + } + ASSERT_LT(1.0 * numFalsePositive / kSize, 0.03); + }; + { + SCOPED_TRACE("Folly"); + test(folly::hasher()); + } + { + SCOPED_TRACE("Multiplication"); + test([](auto x) { return x * 0xc6a4a7935bd1e995L; }); + } +} + +} // namespace +} // namespace facebook::velox::test diff --git a/velox/common/base/tests/StatsReporterTest.cpp b/velox/common/base/tests/StatsReporterTest.cpp index 773c52db6c8..d76827d32f9 100644 --- a/velox/common/base/tests/StatsReporterTest.cpp +++ b/velox/common/base/tests/StatsReporterTest.cpp @@ -28,7 +28,7 @@ #include "velox/common/caching/SsdCache.h" #include "velox/common/memory/MmapAllocator.h" -namespace facebook::velox { +namespace facebook::velox::test { struct QuantileConfig { std::vector statTypes; @@ -409,69 +409,109 @@ TEST_F(PeriodicStatsReporterTest, basic) { const auto& counterMap = reporter_->counterMap; { std::lock_guard l(reporter_->m); - ASSERT_EQ(counterMap.count(kMetricArbitratorFreeCapacityBytes.str()), 1); - ASSERT_EQ( - counterMap.count(kMetricArbitratorFreeReservedCapacityBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumEntries.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumEmptyEntries.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumSharedEntries.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumExclusiveEntries.str()), 1); - ASSERT_EQ( - counterMap.count(kMetricMemoryCacheNumPrefetchedEntries.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheTotalTinyBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheTotalLargeBytes.str()), 1); - ASSERT_EQ( - counterMap.count(kMetricMemoryCacheTotalTinyPaddingBytes.str()), 1); - ASSERT_EQ( - counterMap.count(kMetricMemoryCacheTotalLargePaddingBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheTotalPrefetchBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheCachedEntries.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheCachedRegions.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheCachedBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricCacheMaxAgeSecs.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryAllocatorMappedBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryAllocatorAllocatedBytes.str()), 1); - ASSERT_EQ( - counterMap.count(kMetricMmapAllocatorDelegatedAllocatedBytes.str()), 1); - ASSERT_EQ( - counterMap.count(kMetricMemoryAllocatorExternalMappedBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSpillMemoryBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSpillPeakMemoryBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryAllocatorTotalUsedBytes.str()), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricArbitratorFreeCapacityBytes)), 1); + ASSERT_EQ( + counterMap.count( + std::string(kMetricArbitratorFreeReservedCapacityBytes)), + 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumTinyEntries)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumLargeEntries)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumEmptyEntries)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumSharedEntries)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumExclusiveEntries)), + 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumPrefetchedEntries)), + 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheTotalTinyBytes)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheTotalLargeBytes)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheTotalTinyPaddingBytes)), + 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheTotalLargePaddingBytes)), + 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheTotalPrefetchBytes)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheCachedEntries)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheCachedRegions)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheCachedBytes)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricCacheMaxAgeSecs)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryAllocatorMappedBytes)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryAllocatorAllocatedBytes)), 1); + ASSERT_EQ( + counterMap.count( + std::string(kMetricMmapAllocatorDelegatedAllocatedBytes)), + 1); + ASSERT_EQ( + counterMap.count( + std::string(kMetricMemoryAllocatorExternalMappedBytes)), + 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSpillMemoryBytes)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSpillPeakMemoryBytes)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryAllocatorTotalUsedBytes)), 1); // Check deltas are not reported - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumHits.str()), 0); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheHitBytes.str()), 0); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumNew.str()), 0); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumEvicts.str()), 0); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumSavableEvicts.str()), 0); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumEvictChecks.str()), 0); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumWaitExclusive.str()), 0); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumAllocClocks.str()), 0); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumAgedOutEntries.str()), 0); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheSumEvictScore.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheReadEntries.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheReadBytes.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheWrittenEntries.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheWrittenBytes.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheOpenSsdErrors.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheOpenCheckpointErrors.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheOpenLogErrors.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheMetaFileDeleteErrors.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheGrowFileErrors.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheWriteSsdErrors.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheWriteSsdDropped.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheWriteCheckpointErrors.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheReadSsdErrors.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheReadCorruptions.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheReadCheckpointErrors.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheCheckpointsRead.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheCheckpointsWritten.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheRegionsEvicted.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheAgedOutEntries.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheAgedOutRegions.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheRecoveredEntries.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheReadWithoutChecksum.str()), 0); - ASSERT_EQ(counterMap.size(), 23); + ASSERT_EQ(counterMap.count(std::string(kMetricMemoryCacheNumHits)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricMemoryCacheHitBytes)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricMemoryCacheNumNew)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricMemoryCacheNumEvicts)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumSavableEvicts)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumEvictChecks)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumWaitExclusive)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumAllocClocks)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumAgedOutEntries)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheSumEvictScore)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheReadEntries)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheReadBytes)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheWrittenEntries)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheWrittenBytes)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheOpenSsdErrors)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheOpenCheckpointErrors)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheOpenLogErrors)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheMetaFileDeleteErrors)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheGrowFileErrors)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheWriteSsdErrors)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheWriteNoSpaceErrors)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheWriteSsdDropped)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheWriteExceedEntryLimit)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheWriteCheckpointErrors)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheReadSsdErrors)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheReadCorruptions)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheReadCheckpointErrors)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheCheckpointsRead)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheCheckpointsWritten)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheRegionsEvicted)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheAgedOutEntries)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheAgedOutRegions)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheRecoveredEntries)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheReadWithoutChecksum)), 0); + ASSERT_EQ(counterMap.size(), 24); } // Update stats @@ -492,7 +532,9 @@ TEST_F(PeriodicStatsReporterTest, basic) { newSsdStats->deleteMetaFileErrors = 10; newSsdStats->growFileErrors = 10; newSsdStats->writeSsdErrors = 10; + newSsdStats->writeSsdNoSpaceErrors = 1; newSsdStats->writeSsdDropped = 10; + newSsdStats->writeSsdExceedEntryLimit = 10; newSsdStats->writeCheckpointErrors = 10; newSsdStats->readSsdErrors = 10; newSsdStats->readSsdCorruptions = 10; @@ -511,8 +553,9 @@ TEST_F(PeriodicStatsReporterTest, basic) { .allocClocks = 10, .sumEvictScore = 10, .ssdStats = newSsdStats}); - arbitrator.updateStats(memory::MemoryArbitrator::Stats( - 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10)); + arbitrator.updateStats( + memory::MemoryArbitrator::Stats( + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10)); std::this_thread::sleep_for(std::chrono::milliseconds(4'000)); // Stop right after sufficient wait to ensure the following reads from main @@ -522,39 +565,56 @@ TEST_F(PeriodicStatsReporterTest, basic) { // Check delta stats are reported { std::lock_guard l(reporter_->m); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumHits.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheHitBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumNew.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumEvicts.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumSavableEvicts.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumEvictChecks.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumWaitExclusive.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumAllocClocks.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumAgedOutEntries.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheSumEvictScore.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheReadEntries.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheReadBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheWrittenEntries.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheWrittenBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheOpenSsdErrors.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheOpenCheckpointErrors.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheOpenLogErrors.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheMetaFileDeleteErrors.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheGrowFileErrors.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheWriteSsdErrors.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheWriteSsdDropped.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheWriteCheckpointErrors.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheReadSsdErrors.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheReadCorruptions.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheReadCheckpointErrors.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheCheckpointsRead.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheCheckpointsWritten.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheRegionsEvicted.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheAgedOutEntries.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheAgedOutRegions.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheRecoveredEntries.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheReadWithoutChecksum.str()), 1); - ASSERT_EQ(counterMap.size(), 55); + ASSERT_EQ(counterMap.count(std::string(kMetricMemoryCacheNumHits)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricMemoryCacheHitBytes)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricMemoryCacheNumNew)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricMemoryCacheNumEvicts)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumSavableEvicts)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumEvictChecks)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumWaitExclusive)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumAllocClocks)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumAgedOutEntries)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheSumEvictScore)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheReadEntries)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheReadBytes)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheWrittenEntries)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheWrittenBytes)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheOpenSsdErrors)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheOpenCheckpointErrors)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheOpenLogErrors)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheMetaFileDeleteErrors)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheGrowFileErrors)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheWriteSsdErrors)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheWriteNoSpaceErrors)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheWriteSsdDropped)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheWriteExceedEntryLimit)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheWriteCheckpointErrors)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheReadSsdErrors)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheReadCorruptions)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheReadCheckpointErrors)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheCheckpointsRead)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheCheckpointsWritten)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheRegionsEvicted)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheAgedOutEntries)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheAgedOutRegions)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheRecoveredEntries)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheReadWithoutChecksum)), 1); + ASSERT_EQ(counterMap.size(), 58); } } @@ -903,7 +963,7 @@ folly::Singleton reporter([]() { return new TestReporter(); }); -} // namespace facebook::velox +} // namespace facebook::velox::test int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); diff --git a/velox/common/base/tests/StatsReporterUtils.h b/velox/common/base/tests/StatsReporterUtils.h index b65bee056e2..f58111c3e29 100644 --- a/velox/common/base/tests/StatsReporterUtils.h +++ b/velox/common/base/tests/StatsReporterUtils.h @@ -26,13 +26,11 @@ #include #include "velox/common/base/StatsReporter.h" -namespace facebook::velox { +namespace facebook::velox::test { -/** - * A test implementation of BaseStatsReporter for use in unit tests. - * This class provides a mock implementation that captures all metric - * registrations and values for verification in tests. - */ +/// A test implementation of BaseStatsReporter for use in unit tests. +/// This class provides a mock implementation that captures all metric +/// registrations and values for verification in tests. class TestReporter : public BaseStatsReporter { public: mutable std::mutex m; @@ -239,10 +237,8 @@ class TestReporter : public BaseStatsReporter { return ss.str(); } - /** - * Get the current counter value for a specific key. - * Returns 0 if the key doesn't exist. - */ + // Get the current counter value for a specific key. + // Returns 0 if the key doesn't exist. size_t getCounterValue(const std::string& key) const { std::lock_guard l(m); auto it = counterMap.find(key); @@ -250,4 +246,4 @@ class TestReporter : public BaseStatsReporter { } }; -} // namespace facebook::velox +} // namespace facebook::velox::test diff --git a/velox/common/base/tests/StatusTest.cpp b/velox/common/base/tests/StatusTest.cpp index 282563d142f..1501c18dd30 100644 --- a/velox/common/base/tests/StatusTest.cpp +++ b/velox/common/base/tests/StatusTest.cpp @@ -211,6 +211,22 @@ TEST(StatusTest, statusMacros) { "Reason: User error occurred.\nExpression: status != nullptr\n")); } +TEST(StatusTest, statusMacrosSkipDetails) { + ScopedThreadSkipErrorDetails skipErrorDetails(true); + ASSERT_EQ(returnMacroCheck(), Status::UserError()); + ASSERT_EQ(returnMacroEmptyMessage(), Status::UserError()); + ASSERT_EQ(returnMacroFormat(), Status::UserError()); + ASSERT_EQ(returnMacroGT(), Status::UserError()); + ASSERT_EQ(returnMacroGE(), Status::UserError()); + ASSERT_EQ(returnMacroLT(), Status::UserError()); + ASSERT_EQ(returnMacroLE(), Status::UserError()); + ASSERT_EQ(returnMacroEQ(), Status::UserError()); + ASSERT_EQ(returnMacroNE(), Status::UserError()); + ASSERT_EQ(returnMacroNULL(), Status::UserError()); + Status status = Status::OK(); + ASSERT_EQ(returnNotNull(&status), Status::UserError()); +} + Expected modulo(int a, int b) { if (b == 0) { return folly::makeUnexpected(Status::UserError("division by zero")); diff --git a/velox/common/caching/AsyncDataCache.cpp b/velox/common/caching/AsyncDataCache.cpp index 717107416d5..9ff5b382323 100644 --- a/velox/common/caching/AsyncDataCache.cpp +++ b/velox/common/caching/AsyncDataCache.cpp @@ -23,7 +23,6 @@ #include "velox/common/base/Exceptions.h" #include "velox/common/base/StatsReporter.h" #include "velox/common/base/SuccinctPrinter.h" -#include "velox/common/caching/FileIds.h" #define VELOX_CACHE_ERROR(errorMessage) \ _VELOX_THROW( \ @@ -44,7 +43,34 @@ AsyncDataCacheEntry::AsyncDataCacheEntry(CacheShard* shard) : shard_(shard) { } AsyncDataCacheEntry::~AsyncDataCacheEntry() { - shard_->cache()->allocator()->freeNonContiguous(data_); + freeData(); +} + +void AsyncDataCacheEntry::freeData() { + auto* cache = shard_->cache(); + const auto nonContiguousPages = nonContiguousData_.numPages(); + if (nonContiguousPages > 0) { + VELOX_CHECK_NULL( + contiguousData_, + "Entry cannot have both non-contiguous and contiguous data"); + VELOX_CHECK( + tinyData_.empty(), + "Entry cannot have both non-contiguous and tiny data"); + cache->incrementCachedPages(-nonContiguousPages); + ClockTimer t(shard_->allocClocks()); + cache->allocator()->freeNonContiguous(nonContiguousData_); + } + if (contiguousData_ != nullptr) { + VELOX_CHECK( + tinyData_.empty(), "Entry cannot have both contiguous and tiny data"); + cache->incrementCachedPages(-memory::AllocationTraits::numPages(size_)); + ClockTimer t(shard_->allocClocks()); + cache->allocator()->freeBytes(contiguousData_, size_); + contiguousData_ = nullptr; + } + tinyData_.clear(); + tinyData_.shrink_to_fit(); + size_ = 0; } void AsyncDataCacheEntry::setExclusiveToShared(bool ssdSavable) { @@ -111,7 +137,7 @@ memory::MachinePageCount AsyncDataCacheEntry::setPrefetch(bool flag) { return shard_->cache()->incrementPrefetchPages(flag ? numPages : -numPages); } -void AsyncDataCacheEntry::initialize(FileCacheKey key) { +void AsyncDataCacheEntry::initialize(FileCacheKey key, bool contiguous) { VELOX_CHECK(isExclusive()); setSsdFile(nullptr, 0); key_ = std::move(key); @@ -120,20 +146,78 @@ void AsyncDataCacheEntry::initialize(FileCacheKey key) { if (size_ < AsyncDataCacheEntry::kTinyDataSize) { tinyData_.resize(size_); tinyData_.shrink_to_fit(); - } else { - tinyData_.clear(); - tinyData_.shrink_to_fit(); - const auto sizePages = memory::AllocationTraits::numPages(size_); - if (cache->allocator()->allocateNonContiguous(sizePages, data_)) { - cache->incrementCachedPages(data().numPages()); - } else { - // No memory to cover 'this'. - release(); - VELOX_CACHE_ERROR(fmt::format( + return; + } + + tinyData_.clear(); + tinyData_.shrink_to_fit(); + if (contiguous) { + contiguousData_ = cache->allocator()->allocateBytes(size_); + if (contiguousData_ != nullptr) { + cache->incrementCachedPages(memory::AllocationTraits::numPages(size_)); + return; + } + release(); + VELOX_CACHE_ERROR( + fmt::format( + "Failed to allocate {} for contiguous cache: {}", + succinctBytes(size_), + cache->allocator()->getAndClearFailureMessage())); + } + + const auto sizePages = memory::AllocationTraits::numPages(size_); + if (cache->allocator()->allocateNonContiguous( + sizePages, nonContiguousData_)) { + cache->incrementCachedPages(nonContiguousData().numPages()); + return; + } + release(); + VELOX_CACHE_ERROR( + fmt::format( "Failed to allocate {} pages for cache: {}", sizePages, cache->allocator()->getAndClearFailureMessage())); - } +} + +std::vector> AsyncDataCacheEntry::dataRanges( + size_t length) { + std::vector> buffers; + if (hasContiguousData()) { + buffers.emplace_back(contiguousData(), std::min(length, size_)); + return buffers; + } + buffers.reserve(nonContiguousData_.numRuns()); + uint64_t offsetInRuns = 0; + for (int i = 0; i < nonContiguousData_.numRuns(); ++i) { + auto run = nonContiguousData_.runAt(i); + const uint64_t bytes = run.numBytes(); + const uint64_t readSize = std::min(bytes, length - offsetInRuns); + buffers.emplace_back(run.data(), readSize); + offsetInRuns += readSize; + } + return buffers; +} + +int64_t AsyncDataCacheEntry::dataCapacity() const { + return tinyData_.capacity() + (contiguousData_ != nullptr ? size_ : 0) + + nonContiguousData_.byteSize(); +} + +void AsyncDataCacheEntry::updateDataStats(CacheStats& stats) const { + if (!tinyData_.empty()) { + VELOX_CHECK_NULL(contiguousData_); + VELOX_CHECK(nonContiguousData_.empty()); + stats.tinySize += tinyData_.size(); + stats.tinyPadding += tinyData_.capacity() - tinyData_.size(); + ++stats.numTinyEntries; + } else if (contiguousData_ != nullptr) { + VELOX_CHECK(nonContiguousData_.empty()); + stats.largeSize += size_; + ++stats.numLargeEntries; + } else { + stats.largeSize += size_; + stats.largePadding += nonContiguousData_.byteSize() - size_; + ++stats.numLargeEntries; } } @@ -151,7 +235,7 @@ std::string AsyncDataCacheEntry::toString() const { numPins_); } -std::unique_ptr CacheShard::getFreeEntry() { +std::unique_ptr CacheShard::getFreeEntryLocked() { std::unique_ptr newEntry; if (freeEntries_.empty()) { newEntry = std::make_unique(this); @@ -162,57 +246,64 @@ std::unique_ptr CacheShard::getFreeEntry() { return newEntry; } +std::optional CacheShard::lookupLocked( + RawFileCacheKey key, + uint64_t size, + folly::SemiFuture* wait) { + ++eventCounter_; + auto it = entryMap_.find(key); + if (it == entryMap_.end()) { + return std::nullopt; + } + auto* foundEntry = it->second; + if (foundEntry->isExclusive()) { + ++numWaitExclusive_; + if (wait != nullptr) { + *wait = foundEntry->getFutureLocked(); + } + return CachePin{}; + } + // size=0 (from find()) always passes since entry size is non-negative. + if (foundEntry->size() < size) { + // This can happen if different load quanta apply to access via different + // connectors. This is not an error but still worth logging. + VELOX_CACHE_LOG_EVERY_MS(WARNING, 1'000) + << "Requested larger entry. Found size " << foundEntry->size() + << " requested size " << size; + RECORD_METRIC_VALUE(kMetricMemoryCacheNumStaleEntries); + ++numStales_; + foundEntry->key_.fileNum.clear(); + entryMap_.erase(it); + return std::nullopt; + } + foundEntry->touch(); + if (foundEntry->isPrefetch()) { + foundEntry->isFirstUse_ = true; + foundEntry->setPrefetch(false); + } else { + ++numHit_; + hitBytes_ += foundEntry->size(); + } + ++foundEntry->numPins_; + CachePin pin; + pin.setEntry(foundEntry); + return pin; +} + CachePin CacheShard::findOrCreate( RawFileCacheKey key, uint64_t size, + bool contiguous, folly::SemiFuture* wait) { AsyncDataCacheEntry* entryToInit = nullptr; { std::lock_guard l(mutex_); - ++eventCounter_; - auto it = entryMap_.find(key); - if (it != entryMap_.end()) { - auto* foundEntry = it->second; - if (foundEntry->isExclusive()) { - ++numWaitExclusive_; - if (wait != nullptr) { - *wait = foundEntry->getFuture(); - } - return CachePin(); - } - - if (foundEntry->size() >= size) { - foundEntry->touch(); - // The entry is in a readable state. Add a pin. - if (foundEntry->isPrefetch()) { - foundEntry->isFirstUse_ = true; - foundEntry->setPrefetch(false); - } else { - ++numHit_; - hitBytes_ += foundEntry->size(); - } - ++foundEntry->numPins_; - CachePin pin; - pin.setEntry(foundEntry); - return pin; - } - - // TODO: add stats to report or send alert in production. - - // This can happen if different load quanta apply to access via different - // connectors. This is not an error but still worth logging. - VELOX_CACHE_LOG_EVERY_MS(WARNING, 1'000) - << "Requested larger entry. Found size " << foundEntry->size() - << " requested size " << size; - // The old entry is superseded. Possible readers of the old entry still - // retain a valid read pin. - RECORD_METRIC_VALUE(kMetricMemoryCacheNumStaleEntries); - ++numStales_; - foundEntry->key_.fileNum.clear(); - entryMap_.erase(it); + auto result = lookupLocked(key, size, wait); + if (result.has_value()) { + return std::move(result.value()); } - auto newEntry = getFreeEntry(); + auto newEntry = getFreeEntryLocked(); // Initialize the members that must be set inside 'mutex_'. newEntry->numPins_ = AsyncDataCacheEntry::kExclusive; newEntry->promise_ = nullptr; @@ -231,7 +322,16 @@ CachePin CacheShard::findOrCreate( entryToInit->size_ = size; entryToInit->isFirstUse_ = true; } - return initEntry(key, entryToInit); + return initEntry(key, contiguous, entryToInit); +} + +std::optional CacheShard::find( + RawFileCacheKey key, + folly::SemiFuture* wait) { + std::lock_guard l(mutex_); + // size=0 means any cached entry size is acceptable, so lookupLocked will + // never trigger the stale-entry eviction path. + return lookupLocked(key, 0, wait); } void CacheShard::makeEvictable(RawFileCacheKey key) { @@ -253,8 +353,18 @@ bool CacheShard::exists(RawFileCacheKey key) const { return false; } +bool CacheShard::testingIsEvictable(RawFileCacheKey key) const { + std::lock_guard l(mutex_); + auto it = entryMap_.find(key); + if (it == entryMap_.end()) { + return false; + } + return it->second->testingIsEvictable(); +} + CachePin CacheShard::initEntry( RawFileCacheKey key, + bool contiguous, AsyncDataCacheEntry* entry) { // The new entry is in the map and is in exclusive mode and is otherwise // uninitialized. Other threads may find it and may add a promise or wait for @@ -262,7 +372,8 @@ CachePin CacheShard::initEntry( // and uninterpretable except for this thread. Non access serializing members // can be set outside of 'mutex_'. entry->initialize( - FileCacheKey{StringIdLease(fileIds(), key.fileNum), key.offset}); + FileCacheKey{StringIdLease(fileIds(), key.fileNum), key.offset}, + contiguous); cache_->incrementNew(entry->size()); CachePin pin; pin.setEntry(entry); @@ -278,6 +389,8 @@ bool CoalescedLoad::loadOrFuture( folly::SemiFuture* wait, bool ssdSavable) { { + common::testutil::TestValue::adjust( + "facebook::velox::cache::CoalescedLoad::loadOrFuture", this); std::lock_guard l(mutex_); if (state_ == State::kCancelled || state_ == State::kLoaded) { return true; @@ -298,6 +411,8 @@ bool CoalescedLoad::loadOrFuture( } // Outside of 'mutex_'. + common::testutil::TestValue::adjust( + "facebook::velox::cache::CoalescedLoad::loadOrFuture::loading", this); try { const auto pins = loadData(/*prefetch=*/wait == nullptr); for (const auto& pin : pins) { @@ -336,7 +451,7 @@ std::unique_ptr> CacheShard::removeEntry( removeEntryLocked(entry); // After the entry is removed from the hash table, a promise can no longer // be made. It is safe to move the promise and realize it. - return entry->movePromise(); + return entry->movePromiseLocked(); } void CacheShard::removeEntryLocked(AsyncDataCacheEntry* entry) { @@ -355,27 +470,58 @@ void CacheShard::removeEntryLocked(AsyncDataCacheEntry* entry) { // to fill it. Free the data and account for the difference. In // eviction, the data of the evicted entries is moved away, so // that freeing while holding the shard mutex is exceptional. - const auto numPages = entry->data().numPages(); - if (numPages > 0) { - cache_->incrementCachedPages(-numPages); - ClockTimer t(allocClocks_); - cache_->allocator()->freeNonContiguous(entry->data()); + entry->freeData(); +} + +void CacheShard::acquireEvictedData( + AsyncDataCacheEntry* entry, + uint64_t& bytesToAcquire, + AcquiredMemory& acquired, + AcquiredMemory& toFree, + int64_t& largeEvicted, + int64_t& tinyEvicted) { + if (entry->contiguousData_ != nullptr) { + VELOX_CHECK(entry->tinyData_.empty()); + VELOX_CHECK(entry->nonContiguousData_.empty()); + const uint64_t bytes = entry->size_; + if (bytesToAcquire > 0) { + bytesToAcquire = bytes > bytesToAcquire ? 0 : bytesToAcquire - bytes; + acquired.byteAllocations.emplace_back(entry->contiguousData_, bytes); + } else { + toFree.byteAllocations.emplace_back(entry->contiguousData_, bytes); + } + entry->contiguousData_ = nullptr; + largeEvicted += memory::AllocationTraits::pageBytes( + memory::AllocationTraits::numPages(bytes)); + } else if (entry->nonContiguousData_.numPages() > 0) { + VELOX_CHECK(entry->tinyData_.empty()); + const auto bytes = entry->nonContiguousData_.byteSize(); + if (bytesToAcquire > 0) { + bytesToAcquire = bytes > bytesToAcquire ? 0 : bytesToAcquire - bytes; + acquired.nonContiguousAllocs.appendMove(entry->nonContiguousData()); + } else { + toFree.nonContiguousAllocs.appendMove(entry->nonContiguousData()); + } + VELOX_DCHECK(entry->nonContiguousData().empty()); + largeEvicted += bytes; + } else { + tinyEvicted += entry->tinyData_.size(); + entry->tinyData_.clear(); + entry->tinyData_.shrink_to_fit(); } - entry->tinyData_.clear(); - entry->tinyData_.shrink_to_fit(); entry->size_ = 0; } uint64_t CacheShard::evict( uint64_t bytesToFree, bool evictAllUnpinned, - MachinePageCount pagesToAcquire, - memory::Allocation& acquired) { + uint64_t bytesToAcquire, + AcquiredMemory& acquired) { auto* ssdCache = cache_->ssdCache(); const bool skipSsdSaveable = (ssdCache != nullptr) && ssdCache->writeInProgress(); auto now = accessTime(); - std::vector toFree; + AcquiredMemory toFree; int64_t tinyEvicted = 0; int64_t largeEvicted = 0; int32_t evictSaveableSkipped = 0; @@ -409,7 +555,7 @@ uint64_t CacheShard::evict( eventCounter_ > entries_.size() / 4 || numChecked > entries_.size() / 8) { now = accessTime(); - calibrateThreshold(); + calibrateThresholdLocked(); numChecked = 0; eventCounter_ = 0; } @@ -425,21 +571,13 @@ uint64_t CacheShard::evict( if (candidate->ssdSaveable()) { ++numSavableEvict_; } - largeEvicted += candidate->data_.byteSize(); - if (pagesToAcquire > 0) { - const auto candidatePages = candidate->data().numPages(); - pagesToAcquire = candidatePages > pagesToAcquire - ? 0 - : pagesToAcquire - candidatePages; - acquired.appendMove(candidate->data()); - VELOX_CHECK(candidate->data().empty()); - } else { - toFree.push_back(std::move(candidate->data())); - } - tinyEvicted += candidate->tinyData_.size(); - candidate->tinyData_.clear(); - candidate->tinyData_.shrink_to_fit(); - candidate->size_ = 0; + acquireEvictedData( + candidate, + bytesToAcquire, + acquired, + toFree, + largeEvicted, + tinyEvicted); removeEntryLocked(candidate); emptySlots_.push_back(entryIndex); @@ -456,7 +594,7 @@ uint64_t CacheShard::evict( } ClockTimer t(allocClocks_); - freeAllocations(toFree); + toFree.free(cache_->allocator()); cache_->incrementCachedPages( -memory::AllocationTraits::numPages(largeEvicted)); if (evictSaveableSkipped) { @@ -483,15 +621,8 @@ void CacheShard::tryAddFreeEntry(std::unique_ptr&& entry) { } } -void CacheShard::freeAllocations(std::vector& allocations) { - for (auto& allocation : allocations) { - cache_->allocator()->freeNonContiguous(allocation); - } - allocations.clear(); -} - -void CacheShard::calibrateThreshold() { - auto numSamples = std::min(10, entries_.size()); +void CacheShard::calibrateThresholdLocked() { + auto numSamples = std::min(kMaxEvictionSamples, entries_.size()); auto now = accessTime(); auto entryIndex = (clockHand_ % entries_.size()); auto step = entries_.size() / numSamples; @@ -510,24 +641,33 @@ void CacheShard::calibrateThreshold() { return score; }, numSamples, - 80); + kEvictionPercentile); } void CacheShard::updateStats(CacheStats& stats) { std::lock_guard l(mutex_); for (auto& entry : entries_) { - if (!entry || !entry->key_.fileNum.hasValue()) { + if (!entry) { ++stats.numEmptyEntries; continue; } if (entry->isExclusive()) { - stats.exclusivePinnedBytes += - entry->data().byteSize() + entry->tinyData_.capacity(); + // We cannot read data() or tinyData_ which are being allocated during + // initialize(). Use size_ as an approximation of the pinned bytes. + stats.exclusivePinnedBytes += entry->size_; ++stats.numExclusive; - } else if (entry->isShared()) { - stats.sharedPinnedBytes += - entry->data().byteSize() + entry->tinyData_.capacity(); + // Skip rest of the field accesses while entry is being initialized. + continue; + } + + if (!entry->key_.fileNum.hasValue()) { + ++stats.numEmptyEntries; + continue; + } + + if (entry->isShared()) { + stats.sharedPinnedBytes += entry->dataCapacity(); ++stats.numShared; } @@ -537,12 +677,7 @@ void CacheShard::updateStats(CacheStats& stats) { } ++stats.numEntries; - stats.tinySize += entry->tinyData_.size(); - stats.tinyPadding += entry->tinyData_.capacity() - entry->tinyData_.size(); - if (entry->tinyData_.empty()) { - stats.largeSize += entry->size_; - stats.largePadding += entry->data_.byteSize() - entry->size_; - } + entry->updateDataStats(stats); } stats.numHit += numHit_; stats.hitBytes += hitBytes_; @@ -591,8 +726,9 @@ bool CacheShard::removeFileEntries( return true; } + int32_t numRemoved = 0; int64_t pagesRemoved = 0; - std::vector toFree; + AcquiredMemory toFree; { std::lock_guard l(mutex_); @@ -610,22 +746,28 @@ bool CacheShard::removeFileEntries( continue; } - numAgedOut_++; - pagesRemoved += static_cast(cacheEntry->data().numPages()); - - toFree.push_back(std::move(cacheEntry->data())); + ++numAgedOut_; + ++numRemoved; + if (cacheEntry->contiguousData_ != nullptr) { + pagesRemoved += memory::AllocationTraits::numPages(cacheEntry->size_); + toFree.byteAllocations.emplace_back( + cacheEntry->contiguousData_, cacheEntry->size_); + cacheEntry->contiguousData_ = nullptr; + } else { + pagesRemoved += cacheEntry->nonContiguousData().numPages(); + toFree.nonContiguousAllocs.appendMove(cacheEntry->nonContiguousData()); + } removeEntryLocked(cacheEntry.get()); emptySlots_.push_back(entryIndex); tryAddFreeEntry(std::move(cacheEntry)); cacheEntry = nullptr; } } - VELOX_CACHE_LOG(INFO) << "Removed " << toFree.size() + VELOX_CACHE_LOG(INFO) << "Removed " << numRemoved << " AsyncDataCache entries."; - // Free the memory allocation out of the cache shard lock. ClockTimer t(allocClocks_); - freeAllocations(toFree); + toFree.free(cache_->allocator()); cache_->incrementCachedPages(-pagesRemoved); return true; @@ -658,17 +800,25 @@ CacheStats CacheStats::operator-(const CacheStats& other) const { AsyncDataCache::AsyncDataCache( memory::MemoryAllocator* allocator, std::unique_ptr ssdCache) - : AsyncDataCache({}, allocator, std::move(ssdCache)){}; + : AsyncDataCache({}, allocator, std::move(ssdCache)) {} AsyncDataCache::AsyncDataCache( const Options& options, memory::MemoryAllocator* allocator, std::unique_ptr ssdCache) : opts_(options), + numShards_(opts_.numShards), + shardMask_(numShards_ - 1), allocator_(allocator), ssdCache_(std::move(ssdCache)), cachedPages_(0) { - for (auto i = 0; i < kNumShards; ++i) { + VELOX_CHECK_GT(numShards_, 0, "numShards must be positive"); + VELOX_CHECK_EQ( + numShards_ & shardMask_, + 0, + "numShards must be a power of 2, got {}", + numShards_); + for (auto i = 0; i < numShards_; ++i) { shards_.push_back(std::make_unique(this, opts_.maxWriteRatio)); } } @@ -719,24 +869,37 @@ void CacheShard::shutdown() { CachePin AsyncDataCache::findOrCreate( RawFileCacheKey key, uint64_t size, + bool contiguous, folly::SemiFuture* wait) { - const int shard = std::hash()(key) & (kShardMask); - return shards_[shard]->findOrCreate(key, size, wait); + const int shard = std::hash()(key) & shardMask_; + return shards_[shard]->findOrCreate(key, size, contiguous, wait); +} + +std::optional AsyncDataCache::find( + RawFileCacheKey key, + folly::SemiFuture* waitFuture) { + const int shard = std::hash()(key) & shardMask_; + return shards_[shard]->find(key, waitFuture); } void AsyncDataCache::makeEvictable(RawFileCacheKey key) { - const int shard = std::hash()(key) & (kShardMask); + const int shard = std::hash()(key) & shardMask_; return shards_[shard]->makeEvictable(key); } bool AsyncDataCache::exists(RawFileCacheKey key) const { - int shard = std::hash()(key) & (kShardMask); + const int shard = std::hash()(key) & shardMask_; return shards_[shard]->exists(key); } +bool AsyncDataCache::testingIsEvictable(RawFileCacheKey key) const { + const int shard = std::hash()(key) & shardMask_; + return shards_[shard]->testingIsEvictable(key); +} + bool AsyncDataCache::makeSpace( MachinePageCount numPages, - std::function allocate) { + std::function allocate) { // Try to allocate and if failed, evict the desired amount and // retry. This is without synchronization, so that other threads may // get what one thread evicted but this will usually work in a @@ -749,25 +912,25 @@ bool AsyncDataCache::makeSpace( // serialize with a mutex because memory arbitration must not be // called from inside a global mutex. - constexpr int32_t kMaxAttempts = kNumShards * 4; + const int32_t kMaxAttempts = numShards_ * 4; // Evict at least 1MB even for small allocations to avoid constantly hitting // the mutex protected evict loop. constexpr int32_t kMinEvictPages = 256; // If requesting less than kSmallSizePages try up to 4x more if // first try failed. constexpr int32_t kSmallSizePages = 2048; // 8MB + const auto requestBytes = memory::AllocationTraits::pageBytes(numPages); float sizeMultiplier = 1.2; // True if this thread is counted in 'numThreadsInAllocate_'. bool isCounted = false; // If more than half the allowed retries are needed, this is the rank in // arrival order of this. int32_t rank = 0; - // Allocation into which evicted pages are moved. - memory::Allocation acquired; - // 'acquired' is not managed by a pool. Make sure it is freed on throw. - // Destruct without pool and non-empty kills the process. + // Memory collected from evicted entries, freed by the allocate callback. + AcquiredMemory acquired; auto guard = folly::makeGuard([&]() { - allocator_->freeNonContiguous(acquired); + // Free on exception or early exit. Normal path frees inside allocate. + acquired.free(allocator_); if (isCounted) { --numThreadsInAllocate_; } @@ -781,10 +944,12 @@ bool AsyncDataCache::makeSpace( isCounted = true; } for (auto nthAttempt = 0; nthAttempt < kMaxAttempts; ++nthAttempt) { - if (canTryAllocate(numPages, acquired)) { + if (canTryAllocate(requestBytes, acquired)) { if (allocate(acquired)) { + VELOX_CHECK(acquired.empty()); return true; } + VELOX_CHECK(acquired.empty()); } if (nthAttempt > 2 && ssdCache_ && ssdCache_->writeInProgress()) { @@ -799,25 +964,26 @@ bool AsyncDataCache::makeSpace( } } if (rank) { - // Free the grabbed allocation before sleep so the contender can make + // Free the acquired memory before sleep so the contender can make // progress. This is only on heavy contention, after 8 missed tries. - allocator_->freeNonContiguous(acquired); + acquired.free(allocator_); backoff(nthAttempt + rank); // If some of the competing threads are done, maybe give this thread a // better rank. rank = std::min(rank, numThreadsInAllocate_); } ++shardCounter_; - int32_t numPagesToAcquire = - acquired.numPages() < numPages ? numPages - acquired.numPages() : 0; + const auto acquiredBytes = acquired.totalBytes(); + const uint64_t bytesToAcquire = + acquiredBytes < requestBytes ? requestBytes - acquiredBytes : 0; // Evict from next shard. If we have gone through all shards once // and still have not made the allocation, we go to desperate mode // with 'evictAllUnpinned' set to true. - shards_[shardCounter_ & (kShardMask)]->evict( + shards_[shardCounter_ & shardMask_]->evict( memory::AllocationTraits::pageBytes( std::max(kMinEvictPages, numPages) * sizeMultiplier), - nthAttempt >= kNumShards, - numPagesToAcquire, + nthAttempt >= numShards_, + bytesToAcquire, acquired); if (numPages < kSmallSizePages && sizeMultiplier < 4) { sizeMultiplier *= 2; @@ -835,22 +1001,22 @@ uint64_t AsyncDataCache::shrink(uint64_t targetBytes) { LOG(INFO) << "Try to shrink cache to free up " << velox::succinctBytes(targetBytes) << " memory"; - const uint64_t minBytesToEvict = 8UL << 20; uint64_t evictedBytes{0}; uint64_t shrinkTimeUs{0}; { MicrosecondTimer timer(&shrinkTimeUs); for (int shard = 0; shard < shards_.size(); ++shard) { - memory::Allocation unused; - evictedBytes += shards_[shardCounter_++ & (kShardMask)]->evict( - std::max(minBytesToEvict, targetBytes - evictedBytes), + AcquiredMemory acquired; + evictedBytes += shards_[shardCounter_++ & shardMask_]->evict( + std::max( + CacheShard::kMinBytesToEvict, targetBytes - evictedBytes), // Cache shrink is triggered when server is under low memory pressure // so need to free up memory as soon as possible. So we always avoid // triggering ssd save to accelerate the cache evictions. true, 0, - unused); - VELOX_CHECK(unused.empty()); + acquired); + VELOX_CHECK(acquired.empty()); if (evictedBytes >= targetBytes) { break; } @@ -869,14 +1035,16 @@ uint64_t AsyncDataCache::shrink(uint64_t targetBytes) { } bool AsyncDataCache::canTryAllocate( - MachinePageCount numPages, - const memory::Allocation& acquired) const { - if (numPages <= acquired.numPages()) { + uint64_t requestBytes, + const AcquiredMemory& acquired) const { + const auto acquiredBytes = acquired.totalBytes(); + if (requestBytes <= acquiredBytes) { return true; } - return numPages - acquired.numPages() <= - (memory::AllocationTraits::numPages(allocator_->capacity())) - - allocator_->numAllocated(); + return requestBytes - acquiredBytes <= + memory::AllocationTraits::pageBytes( + memory::AllocationTraits::numPages(allocator_->capacity()) - + allocator_->numAllocated()); } void AsyncDataCache::backoff(int32_t counter) { @@ -909,11 +1077,15 @@ void AsyncDataCache::possibleSsdSave(uint64_t bytes) { ssdSaveable_ += bytes; if (memory::AllocationTraits::numPages(ssdSaveable_) > - std::max( - static_cast( - memory::AllocationTraits::numPages(opts_.minSsdSavableBytes)), - static_cast( - static_cast(cachedPages_) * opts_.ssdSavableRatio))) { + std::max( + static_cast( + memory::AllocationTraits::numPages(opts_.minSsdSavableBytes)), + static_cast( + static_cast(cachedPages_) * opts_.ssdSavableRatio)) || + (opts_.ssdFlushThresholdBytes > 0 && + memory::AllocationTraits::numPages(ssdSaveable_) > + static_cast(memory::AllocationTraits::numPages( + opts_.ssdFlushThresholdBytes)))) { // Do not start a new save if another one is in progress. if (!ssdCache_->startWrite()) { return; @@ -925,7 +1097,7 @@ void AsyncDataCache::possibleSsdSave(uint64_t bytes) { void AsyncDataCache::saveToSsd(bool saveAll) { std::vector pins; VELOX_CHECK(ssdCache_->writeInProgress()); - ssdSaveable_ = 0; + ssdSaveable_ = false; for (auto& shard : shards_) { shard->appendSsdSaveable(saveAll, pins); } @@ -966,9 +1138,9 @@ CacheStats AsyncDataCache::refreshStats() const { void AsyncDataCache::clear() { for (auto& shard : shards_) { - memory::Allocation unused; - shard->evict(std::numeric_limits::max(), true, 0, unused); - VELOX_CHECK(unused.empty()); + AcquiredMemory acquired; + shard->evict(std::numeric_limits::max(), true, 0, acquired); + VELOX_CHECK(acquired.empty()); } } @@ -1039,33 +1211,34 @@ CoalesceIoStats readPins( [&](int32_t index) { return pins[index].checkedEntry()->size(); }, [&](int32_t index) { return std::max( - 1, pins[index].checkedEntry()->data().numRuns()); + 1, pins[index].checkedEntry()->nonContiguousData().numRuns()); }, [&](const CachePin& pin, std::vector>& ranges) { auto* entry = pin.checkedEntry(); - auto& data = entry->data(); + const auto size = entry->size(); + if (entry->hasContiguousData()) { + VELOX_CHECK_EQ(entry->contiguousDataSize(), size); + ranges.push_back(folly::Range(entry->contiguousData(), size)); + return; + } + const auto& data = entry->nonContiguousData(); + VELOX_CHECK_GT(data.numPages(), 0); uint64_t offsetInRuns = 0; - auto size = entry->size(); - if (data.numPages() == 0) { - ranges.push_back( - folly::Range(pin.checkedEntry()->tinyData(), size)); - offsetInRuns = size; - } else { - for (int i = 0; i < data.numRuns(); ++i) { - const auto run = data.runAt(i); - const uint64_t bytes = run.numBytes(); - const uint64_t readSize = std::min(bytes, size - offsetInRuns); - ranges.push_back(folly::Range(run.data(), readSize)); - offsetInRuns += readSize; - } + for (int i = 0; i < data.numRuns(); ++i) { + const auto run = data.runAt(i); + const uint64_t readSize = + std::min(run.numBytes(), size - offsetInRuns); + ranges.push_back(folly::Range(run.data(), readSize)); + offsetInRuns += readSize; } VELOX_CHECK_EQ(offsetInRuns, size); }, [&](int32_t size, std::vector>& ranges) { // This hack allows us to store the size of the gap in the Range, // without actually allocating a buffer for it. - ranges.push_back(folly::Range( - nullptr, reinterpret_cast(static_cast(size)))); + ranges.push_back( + folly::Range( + nullptr, reinterpret_cast(static_cast(size)))); }, std::move(readFunc)); } diff --git a/velox/common/caching/AsyncDataCache.h b/velox/common/caching/AsyncDataCache.h index 605e95cd568..40d9a4b4ed1 100644 --- a/velox/common/caching/AsyncDataCache.h +++ b/velox/common/caching/AsyncDataCache.h @@ -44,6 +44,7 @@ namespace facebook::velox::cache { class AsyncDataCache; class CacheShard; +struct CacheStats; class SsdCache; struct SsdCacheStats; class SsdFile; @@ -160,27 +161,59 @@ class AsyncDataCacheEntry { explicit AsyncDataCacheEntry(CacheShard* shard); ~AsyncDataCacheEntry(); - /// Sets the key and allocates the entry's memory. Resets - /// all other state. The entry must be held exclusively and must - /// hold no memory when calling this. - void initialize(FileCacheKey key); + /// Sets the key and allocates the entry's memory. Resets all other state. + /// The entry must be held exclusively and must hold no memory when calling + /// this. If 'contiguous' is true, allocates a single contiguous region + /// instead of a potentially non-contiguous Allocation. + void initialize(FileCacheKey key, bool contiguous = false); - memory::Allocation& data() { - return data_; + memory::Allocation& nonContiguousData() { + return nonContiguousData_; } - const memory::Allocation& data() const { - return data_; + const memory::Allocation& nonContiguousData() const { + return nonContiguousData_; } - const char* tinyData() const { - return tinyData_.empty() ? nullptr : tinyData_.data(); + /// Returns a pointer to contiguous data. Valid for entries created + /// with contiguous=true. Covers both tinyData_ (small) and + /// contiguousData_ (larger) paths. Throws if the entry has no + /// contiguous data. + char* contiguousData() { + if (!tinyData_.empty()) { + VELOX_CHECK_NULL(contiguousData_); + return tinyData_.data(); + } + VELOX_CHECK_NOT_NULL( + contiguousData_, "Entry does not have contiguous data"); + return static_cast(contiguousData_); + } + + const char* contiguousData() const { + if (!tinyData_.empty()) { + VELOX_CHECK_NULL(contiguousData_); + return tinyData_.data(); + } + VELOX_CHECK_NOT_NULL( + contiguousData_, "Entry does not have contiguous data"); + return static_cast(contiguousData_); + } + + /// Returns true if this entry has contiguous data. + bool hasContiguousData() const { + return contiguousData_ != nullptr || !tinyData_.empty(); } - char* tinyData() { - return tinyData_.empty() ? nullptr : tinyData_.data(); + uint64_t contiguousDataSize() const { + return size_; } + // Returns writable buffer ranges covering the first 'length' bytes of + // this entry's data. For small entries (tinyData), returns a single range. + // For contiguous entries, returns a single range. + // For larger entries (allocation-backed), returns one range per run. + std::vector> dataRanges(size_t length); + const FileCacheKey& key() const { return key_; } @@ -193,6 +226,13 @@ class AsyncDataCacheEntry { return size_; } + /// Returns the allocated capacity in bytes for this entry's data, + /// including any padding from the allocator. + int64_t dataCapacity() const; + + /// Updates the data-type-specific size and padding fields in 'stats'. + void updateDataStats(CacheStats& stats) const; + void touch() { accessStats_.touch(); } @@ -263,11 +303,10 @@ class AsyncDataCacheEntry { /// Sets access stats so that this is immediately evictable. void makeEvictable(); - /// Moves the promise out of 'this'. Used in order to handle the - /// promise within the lock of the cache shard, so not within private - /// methods of 'this'. - std::unique_ptr> movePromise() { - return std::move(promise_); + /// Returns true if this entry has been marked as immediately evictable + /// (lastUse == 0). + bool testingIsEvictable() const { + return accessStats_.lastUse == 0; } std::string toString() const; @@ -276,9 +315,19 @@ class AsyncDataCacheEntry { void release(); void addReference(); + // Frees all data storage (non-contiguous, contiguous, and tiny) with + // cache page accounting and resets size to zero. + void freeData(); + + // Moves the promise out of 'this'. Must be called inside the mutex of + // 'shard_'. + std::unique_ptr> movePromiseLocked() { + return std::move(promise_); + } + // Returns a future that will be realized when a caller can retry getting // 'this'. Must be called inside the mutex of 'shard_'. - folly::SemiFuture getFuture() { + folly::SemiFuture getFutureLocked() { if (promise_ == nullptr) { promise_ = std::make_unique>(); } @@ -290,8 +339,12 @@ class AsyncDataCacheEntry { CacheShard* const shard_; - // The data being cached. - memory::Allocation data_; + // The data being cached (non-contiguous page runs). + memory::Allocation nonContiguousData_; + + // Contiguous bytes allocated via allocateBytes. Populated when the entry is + // created with contiguous=true and size >= kTinyDataSize. + void* contiguousData_{nullptr}; // Contains the cached data if this is much smaller than a MemoryAllocator // page (kTinyDataSize). @@ -308,7 +361,7 @@ class AsyncDataCacheEntry { // True if 'this' is speculatively loaded. This is reset on first hit. Allows // catching a situation where prefetched entries get evicted before they are // hit. - bool isPrefetch_{false}; + tsan_atomic isPrefetch_{false}; // Sets after first use of a prefetched entry. Cleared by // getAndClearFirstUseFlag(). Does not require synchronization since used for @@ -457,6 +510,10 @@ class CoalescedLoad { return ""; } + /// Returns true if this is a load from SSD cache, false if from remote + /// storage. + virtual bool isSsdLoad() const = 0; + protected: // Makes entries for 'keys_' and loads their content. Elements of 'keys_' that // are already loaded or loading are expected to be left out. The returned @@ -488,14 +545,18 @@ struct CacheStats { /// Total size in 'tinyData_' int64_t tinySize{0}; - /// Total size in 'data_' + /// Total size in 'nonContiguousData_' int64_t largeSize{0}; /// Unused capacity in 'tinyData_'. int64_t tinyPadding{0}; - /// Unused capacity in 'data_'. + /// Unused capacity in 'nonContiguousData_'. int64_t largePadding{0}; /// Total number of entries. int32_t numEntries{0}; + /// Total number of tiny entries. + int32_t numTinyEntries{0}; + /// Total number of large entries. + int32_t numLargeEntries{0}; /// Number of entries that do not cache anything. int32_t numEmptyEntries{0}; /// Number of entries pinned for shared access. @@ -550,20 +611,35 @@ struct CacheStats { std::string toString() const; }; +using AcquiredMemory = memory::AcquiredMemory; + /// Collection of cache entries whose key hashes to the same shard of /// the hash number space. The cache population is divided into shards /// to decrease contention on the mutex for the key to entry mapping /// and other housekeeping. class CacheShard { public: + static constexpr uint64_t kMinBytesToEvict = 8UL << 20; // 8MB + CacheShard(AsyncDataCache* cache, double maxWriteRatio) : cache_(cache), maxWriteRatio_(maxWriteRatio) {} - /// See AsyncDataCache::findOrCreate. + /// See AsyncDataCache::findOrCreate. If 'contiguous' is true, the + /// entry's data is allocated as a single contiguous region. CachePin findOrCreate( RawFileCacheKey key, uint64_t size, - folly::SemiFuture* readyFuture); + bool contiguous = false, + folly::SemiFuture* readyFuture = nullptr); + + /// Finds a cache entry for 'key'. Returns a shared-mode pin if the entry + /// exists and is not exclusive. Returns an empty pin (inside optional) if + /// the entry is exclusive; if 'waitFuture' is not nullptr it is set to a + /// future realized when the entry is no longer exclusive. Returns + /// std::nullopt on miss. Does not create entries. + std::optional find( + RawFileCacheKey key, + folly::SemiFuture* waitFuture = nullptr); /// Marks the cache entry with given cache 'key' as immediate evictable. void makeEvictable(RawFileCacheKey key); @@ -571,6 +647,8 @@ class CacheShard { /// Returns true if there is an entry for 'key'. Updates access time. bool exists(RawFileCacheKey key) const; + bool testingIsEvictable(RawFileCacheKey key) const; + AsyncDataCache* cache() const { return cache_; } @@ -586,15 +664,15 @@ class CacheShard { /// Removes 'bytesToFree' worth of entries or as many entries as are not /// pinned. This favors first removing older and less frequently used entries. /// If 'evictAllUnpinned' is true, anything that is not pinned is evicted at - /// first sight. This is for out of memory emergencies. If 'pagesToAcquire' is - /// set, up to this amount is added to 'allocation'. A smaller amount can be - /// added if not enough evictable data is found. The function returns the - /// total evicted bytes. + /// first sight. This is for out of memory emergencies. If + /// 'bytesToAcquire' is set, up to this amount of non-contiguous pages + /// is moved into 'acquired'. A smaller amount can be added if not enough + /// evictable data is found. The function returns the total evicted bytes. uint64_t evict( uint64_t bytesToFree, bool evictAllUnpinned, - memory::MachinePageCount pagesToAcquire, - memory::Allocation& acquiredAllocation); + uint64_t bytesToAcquire, + AcquiredMemory& acquired); /// Removes 'entry' from 'this'. Removes a possible promise from the entry /// inside the shard mutex and returns it so that it can be realized outside @@ -629,8 +707,10 @@ class CacheShard { private: static constexpr uint32_t kMaxFreeEntries = 1 << 10; static constexpr int32_t kNoThreshold = std::numeric_limits::max(); + static constexpr int32_t kMaxEvictionSamples = 10; + static constexpr int32_t kEvictionPercentile = 80; - void calibrateThreshold(); + void calibrateThresholdLocked(); void removeEntryLocked(AsyncDataCacheEntry* entry); @@ -638,14 +718,35 @@ class CacheShard { // // TODO: consider to pass a size hint so as to select the a free entry which // already has the right amount of memory associated with it. - std::unique_ptr getFreeEntry(); + std::unique_ptr getFreeEntryLocked(); - CachePin initEntry(RawFileCacheKey key, AsyncDataCacheEntry* entry); + CachePin + initEntry(RawFileCacheKey key, bool contiguous, AsyncDataCacheEntry* entry); - void freeAllocations(std::vector& allocations); + // Looks up 'key' in the cache under mutex_. 'size' is the minimum acceptable + // entry size: pass 0 from find() to accept any size, or the required size + // from findOrCreate() to trigger stale-entry eviction when too small. + // Returns std::nullopt on miss (or after evicting a stale entry), + // an empty CachePin if the entry is exclusive, or a shared CachePin on hit. + std::optional lookupLocked( + RawFileCacheKey key, + uint64_t size, + folly::SemiFuture* waitFuture); void tryAddFreeEntry(std::unique_ptr&& entry); + // Acquires or frees evicted data from 'entry', including tiny data + // cleanup. If 'bytesToAcquire' > 0, moves data into 'acquired' and + // decrements 'bytesToAcquire'; otherwise moves data into 'toFree'. + // Increments 'largeEvicted' and 'tinyEvicted' accordingly. + void acquireEvictedData( + AsyncDataCacheEntry* entry, + uint64_t& bytesToAcquire, + AcquiredMemory& acquired, + AcquiredMemory& toFree, + int64_t& largeEvicted, + int64_t& tinyEvicted); + AsyncDataCache* const cache_; const double maxWriteRatio_; @@ -696,14 +797,20 @@ class CacheShard { class AsyncDataCache : public memory::Cache { public: + static constexpr int32_t kDefaultNumShards = 4; + struct Options { Options( double _maxWriteRatio = 0.7, double _ssdSavableRatio = 0.125, - int32_t _minSsdSavableBytes = 1 << 24) + int32_t _minSsdSavableBytes = 1 << 24, + int32_t _numShards = kDefaultNumShards, + uint64_t _ssdFlushThresholdBytes = 0) : maxWriteRatio(_maxWriteRatio), ssdSavableRatio(_ssdSavableRatio), - minSsdSavableBytes(_minSsdSavableBytes) {} + minSsdSavableBytes(_minSsdSavableBytes), + numShards(_numShards), + ssdFlushThresholdBytes(_ssdFlushThresholdBytes) {} /// The max ratio of the number of in-memory cache entries being written to /// SSD cache over the total number of cache entries. This is to control SSD @@ -722,6 +829,16 @@ class AsyncDataCache : public memory::Cache { /// NOTE: we only write to SSD cache when both above conditions satisfy. The /// default is 16MB. int32_t minSsdSavableBytes; + + /// The number of shards for the cache. The cache population is divided into + /// shards to decrease contention on the mutex for the key to entry mapping + /// and other housekeeping. Must be a power of 2. + int32_t numShards; + + /// The maximum threshold in bytes for triggering SSD flush. When the + /// accumulated SSD-savable bytes exceed this value, a flush to SSD is + /// triggered. Set to 0 to disable this threshold (default). + uint64_t ssdFlushThresholdBytes; }; AsyncDataCache( @@ -757,7 +874,7 @@ class AsyncDataCache : public memory::Cache { /// for memory arbitration to work. bool makeSpace( memory::MachinePageCount numPages, - std::function allocate) override; + std::function allocate) override; uint64_t shrink(uint64_t targetBytes) override; @@ -767,18 +884,30 @@ class AsyncDataCache : public memory::Cache { /// Finds or creates a cache entry corresponding to 'key'. The entry /// is returned in 'pin'. If the entry is new, it is pinned in - /// exclusive mode and its 'data_' has uninitialized space for at - /// least 'size' bytes. If the entry is in cache and already filled, - /// the pin is in shared mode. If the entry is in exclusive mode for - /// some other pin, the pin is empty. If 'waitFuture' is not nullptr - /// and the pin is exclusive on some other pin, this is set to a - /// future that is realized when the pin is no longer exclusive. When - /// the future is realized, the caller may retry findOrCreate(). - /// runtime error with code kNoCacheSpace if there is no space to create the - /// new entry after evicting any unpinned content. + /// exclusive mode and its data has uninitialized space for at least + /// 'size' bytes. If 'contiguous' is true, the data is allocated as a + /// single contiguous region; otherwise it is allocated as potentially + /// non-contiguous page runs. If the entry is in cache and already + /// filled, the pin is in shared mode. If the entry is in exclusive + /// mode for some other pin, the pin is empty. If 'waitFuture' is not + /// nullptr and the pin is exclusive on some other pin, this is set to + /// a future that is realized when the pin is no longer exclusive. + /// When the future is realized, the caller may retry findOrCreate(). + /// runtime error with code kNoCacheSpace if there is no space to + /// create the new entry after evicting any unpinned content. CachePin findOrCreate( RawFileCacheKey key, uint64_t size, + bool contiguous = false, + folly::SemiFuture* waitFuture = nullptr); + + /// Finds a cache entry for 'key'. Returns a shared-mode pin if the entry + /// exists and is not exclusive. Returns an empty pin (inside optional) if + /// the entry is exclusive; if 'waitFuture' is not nullptr it is set to a + /// future realized when the entry is no longer exclusive. Returns + /// std::nullopt on miss. + std::optional find( + RawFileCacheKey key, folly::SemiFuture* waitFuture = nullptr); /// Marks the cache entry with given cache 'key' as immediate evictable. @@ -787,6 +916,11 @@ class AsyncDataCache : public memory::Cache { /// Returns true if there is an entry for 'key'. Updates access time. bool exists(RawFileCacheKey key) const; + /// Returns true if the entry for 'key' exists and has been marked as + /// immediately evictable (lastUse == 0). Returns false if entry does not + /// exist or is not evictable. Does not update access stats. + bool testingIsEvictable(RawFileCacheKey key) const; + #if defined(__has_feature) #if __has_feature(thread_sanitizer) __attribute__((__no_sanitize__("thread"))) @@ -794,16 +928,14 @@ class AsyncDataCache : public memory::Cache { #endif /// Returns snapshot of the aggregated stats from all shards and the stats of /// SSD cache if used. - virtual CacheStats - refreshStats() const; + virtual CacheStats refreshStats() const; /// If 'details' is true, returns the stats of the backing memory allocator /// and ssd cache. Otherwise, only returns the cache stats. std::string toString(bool details = true) const; - memory::MachinePageCount incrementCachedPages(int64_t pages) { - // The counter is unsigned and the increment is signed. - return cachedPages_.fetch_add(pages) + pages; + memory::MachinePageCount cachedPages() const { + return cachedPages_.load(); } memory::MachinePageCount incrementPrefetchPages(int64_t pages) { @@ -839,14 +971,16 @@ class AsyncDataCache : public memory::Cache { /// Looks up a pin for each in 'keys' and skips all loading or loaded pins. /// Calls processPin for each exclusive pin. processPin must move its argument /// if it wants to use it afterwards. sizeFunc(i) returns the size of the ith - /// item in 'keys'. + /// item in 'keys'. If 'contiguous' is true, new entries are allocated as + /// single contiguous regions instead of non-contiguous page runs. template void makePins( const std::vector& keys, const SizeFunc& sizeFunc, - const ProcessPin& processPin) { + const ProcessPin& processPin, + bool contiguous = false) { for (size_t i = 0; i < keys.size(); ++i) { - auto pin = findOrCreate(keys[i], sizeFunc(i), nullptr); + auto pin = findOrCreate(keys[i], sizeFunc(i), contiguous); if (pin.empty() || pin.checkedEntry()->isShared()) { continue; } @@ -876,14 +1010,14 @@ class AsyncDataCache : public memory::Cache { void clear(); private: - static constexpr int32_t kNumShards = 4; // Must be power of 2. - static constexpr int32_t kShardMask = kNumShards - 1; + // True if acquired bytes plus available allocator capacity is enough + // for 'requestBytes'. + bool canTryAllocate(uint64_t requestBytes, const AcquiredMemory& acquired) + const; - // True if 'acquired' has more pages than 'numPages' or allocator has space - // for numPages - acquired pages of more allocation. - bool canTryAllocate( - memory::MachinePageCount numPages, - const memory::Allocation& acquired) const; + memory::MachinePageCount incrementCachedPages(int64_t pages) { + return cachedPages_.fetch_add(pages) + pages; + } static AsyncDataCache** getInstancePtr(); @@ -891,6 +1025,10 @@ class AsyncDataCache : public memory::Cache { void backoff(int32_t counter); const Options opts_; + // Number of shards. Must be a power of 2. + const int32_t numShards_; + // Bitmask for efficient shard index calculation (numShards_ - 1). + const int32_t shardMask_; memory::MemoryAllocator* const allocator_; std::unique_ptr ssdCache_; std::vector> shards_; @@ -926,6 +1064,8 @@ class AsyncDataCache : public memory::Cache { // for setting staggered backoff. Mutexes are not allowed for this. std::atomic numThreadsInAllocate_{0}; + friend class AsyncDataCacheEntry; + friend class CacheShard; friend class test::AsyncDataCacheTestHelper; }; diff --git a/velox/common/caching/CMakeLists.txt b/velox/common/caching/CMakeLists.txt index bc09b5552c8..2829d2c798b 100644 --- a/velox/common/caching/CMakeLists.txt +++ b/velox/common/caching/CMakeLists.txt @@ -16,17 +16,31 @@ velox_add_library( velox_caching AsyncDataCache.cpp CacheTTLController.cpp + FileHandle.cpp FileIds.cpp ScanTracker.cpp SsdCache.cpp SsdFile.cpp SsdFileTracker.cpp StringIdMap.cpp + HEADERS + AsyncDataCache.h + CacheTTLController.h + FileGroupStats.h + FileHandle.h + FileIds.h + FileProperties.h + ScanTracker.h + SsdCache.h + SsdFile.h + SsdFileTracker.h + StringIdMap.h ) velox_link_libraries( velox_caching PUBLIC velox_common_base + velox_common_config velox_exception velox_file velox_memory @@ -40,3 +54,7 @@ velox_link_libraries( if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) endif() + +velox_add_library(velox_simple_lru_cache INTERFACE HEADERS SimpleLRUCache.h) + +velox_add_library(velox_cached_factory INTERFACE HEADERS CachedFactory.h) diff --git a/velox/common/caching/FileHandle.cpp b/velox/common/caching/FileHandle.cpp new file mode 100644 index 00000000000..93190caad72 --- /dev/null +++ b/velox/common/caching/FileHandle.cpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/caching/FileHandle.h" +#include "velox/common/base/Counters.h" +#include "velox/common/base/StatsReporter.h" +#include "velox/common/file/FileSystems.h" +#include "velox/common/time/Timer.h" + +#include + +namespace facebook::velox { + +uint64_t FileHandleSizer::operator()(const FileHandle& fileHandle) { + // TODO: add to support variable file cache size support when the file system + // underneath supports. + return 1; +} + +namespace { +// The group tracking is at the level of the directory, i.e. Hive partition. +std::string groupName(const std::string& filename) { + const char* slash = strrchr(filename.c_str(), '/'); + return slash ? std::string(filename.data(), slash - filename.data()) + : filename; +} +} // namespace + +std::unique_ptr FileHandleGenerator::operator()( + const FileHandleKey& key, + const FileProperties* properties, + IoStats* stats) { + // We have seen cases where drivers are stuck when creating file handles. + // Adding a trace here to spot this more easily in future. + process::TraceContext trace("FileHandleGenerator::operator()"); + uint64_t elapsedTimeUs{0}; + std::unique_ptr fileHandle; + { + MicrosecondTimer timer(&elapsedTimeUs); + fileHandle = std::make_unique(); + filesystems::FileOptions options; + options.stats = stats; + options.tokenProvider = key.tokenProvider; + if (properties) { + options.fileSize = properties->fileSize; + options.readRangeHint = properties->readRangeHint; + options.extraFileInfo = properties->extraFileInfo; + options.fileReadOps = properties->fileReadOps; + } + const auto& filename = key.filename; + fileHandle->file = filesystems::getFileSystem(filename, properties_) + ->openFileForRead(filename, options); + fileHandle->uuid = StringIdLease(fileIds(), filename); + fileHandle->groupId = StringIdLease(fileIds(), groupName(filename)); + VLOG(1) << "Generating file handle for: " << filename + << " uuid: " << fileHandle->uuid.id(); + } + RECORD_HISTOGRAM_METRIC_VALUE( + kMetricHiveFileHandleGenerateLatencyMs, elapsedTimeUs / 1000); + return fileHandle; +} + +} // namespace facebook::velox diff --git a/velox/common/caching/FileHandle.h b/velox/common/caching/FileHandle.h new file mode 100644 index 00000000000..71a00ff5db0 --- /dev/null +++ b/velox/common/caching/FileHandle.h @@ -0,0 +1,123 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "velox/common/base/BitUtil.h" +#include "velox/common/caching/CachedFactory.h" +#include "velox/common/caching/FileIds.h" +#include "velox/common/caching/FileProperties.h" +#include "velox/common/config/Config.h" +#include "velox/common/file/File.h" +#include "velox/common/file/TokenProvider.h" + +namespace facebook::velox { + +/// fileReadOps keys for passing table identity (db and table name) +/// through the file handle layer. Written by the connector (FileSplitReader) +/// and read by storage implementations to build per-table access token keys. +constexpr std::string_view kDbNameKey = "dbName"; +constexpr std::string_view kTableNameKey = "tableName"; + +/// File pointer plus cache-friendly identifiers for downstream caching. +struct FileHandle { + std::shared_ptr file; + + // Each time we make a new FileHandle we assign it a uuid and use that id as + // the identifier in downstream data caching structures. This saves a lot of + // memory compared to using the filename as the identifier. + StringIdLease uuid; + + // Id for the group of files this belongs to, e.g. its + // directory. Used for coarse granularity access tracking, for + // example to decide placing on SSD. + StringIdLease groupId; +}; + +/// Estimates the memory usage of a FileHandle object. +struct FileHandleSizer { + uint64_t operator()(const FileHandle& a); +}; + +struct FileHandleKey { + std::string filename; + std::shared_ptr tokenProvider{nullptr}; + + bool operator==(const FileHandleKey& other) const { + if (filename != other.filename) { + return false; + } + + if (tokenProvider == other.tokenProvider) { + return true; + } + + if (!tokenProvider || !other.tokenProvider) { + return false; + } + + return tokenProvider->equals(*other.tokenProvider); + } +}; + +} // namespace facebook::velox + +namespace std { +template <> +struct hash { + size_t operator()(const facebook::velox::FileHandleKey& key) const noexcept { + size_t filenameHash = std::hash()(key.filename); + return key.tokenProvider ? facebook::velox::bits::hashMix( + filenameHash, key.tokenProvider->hash()) + : filenameHash; + } +}; +} // namespace std + +namespace facebook::velox { +using FileHandleCache = + SimpleLRUCache; + +/// Creates FileHandles via the Generator interface the CachedFactory requires. +class FileHandleGenerator { + public: + FileHandleGenerator() {} + FileHandleGenerator(std::shared_ptr properties) + : properties_(std::move(properties)) {} + std::unique_ptr operator()( + const FileHandleKey& filename, + const FileProperties* properties, + IoStats* stats); + + private: + const std::shared_ptr properties_; +}; + +using FileHandleFactory = CachedFactory< + FileHandleKey, + FileHandle, + FileHandleGenerator, + FileProperties, + IoStats, + FileHandleSizer>; + +using FileHandleCachedPtr = CachedPtr; + +using FileHandleCacheStats = SimpleLRUCacheStats; + +} // namespace facebook::velox diff --git a/velox/common/caching/FileProperties.h b/velox/common/caching/FileProperties.h new file mode 100644 index 00000000000..9827a7e71f9 --- /dev/null +++ b/velox/common/caching/FileProperties.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include + +namespace facebook::velox { + +struct FileProperties { + std::optional fileSize; + std::optional modificationTime; + std::optional readRangeHint{std::nullopt}; + std::shared_ptr extraFileInfo{nullptr}; + folly::F14FastMap fileReadOps{}; +}; + +} // namespace facebook::velox diff --git a/velox/common/caching/SsdCache.cpp b/velox/common/caching/SsdCache.cpp index 2347b45982f..50d32d6ba3e 100644 --- a/velox/common/caching/SsdCache.cpp +++ b/velox/common/caching/SsdCache.cpp @@ -32,7 +32,8 @@ SsdCache::SsdCache(const Config& config) : filePrefix_(config.filePrefix), numShards_(config.numShards), groupStats_(std::make_unique()), - executor_(config.executor) { + executor_(config.executor), + maxEntries_(config.maxEntries) { // Make sure the given path of Ssd files has the prefix for local file system. // Local file system would be derived based on the prefix. VELOX_CHECK( @@ -58,6 +59,9 @@ SsdCache::SsdCache(const Config& config) const uint64_t sizeQuantum = numShards_ * SsdFile::kRegionSize; const int32_t fileMaxRegions = bits::roundUp(config.maxBytes, sizeQuantum) / sizeQuantum; + // Distribute maxEntries across shards + const uint64_t maxEntriesPerShard = + maxEntries_ == 0 ? 0 : bits::divRoundUp(maxEntries_, numShards_); for (auto i = 0; i < numShards_; ++i) { const auto fileConfig = SsdFile::Config( fmt::format("{}{}", filePrefix_, i), @@ -67,6 +71,7 @@ SsdCache::SsdCache(const Config& config) config.disableFileCow, config.checksumEnabled, checksumReadVerificationEnabled, + maxEntriesPerShard, executor_); files_.push_back(std::make_unique(fileConfig)); } @@ -90,7 +95,8 @@ bool SsdCache::startWrite() { } void SsdCache::write(std::vector pins) { - VELOX_CHECK_EQ(numShards_, writesInProgress_); + VELOX_CHECK_EQ( + numShards_, writesInProgress_, "startWrite() have not been called"); TestValue::adjust("facebook::velox::cache::SsdCache::write", this); @@ -98,7 +104,7 @@ void SsdCache::write(std::vector pins) { uint64_t bytes = 0; std::vector> shards(numShards_); - for (auto& pin : pins) { + for (const auto& pin : pins) { bytes += pin.checkedEntry()->size(); const auto& target = file(pin.checkedEntry()->key().fileNum.id()); shards[target.shardId()].push_back(std::move(pin)); @@ -135,9 +141,10 @@ void SsdCache::write(std::vector pins) { // Typically occurs every few GB. Allows detecting unusually slow rates // from failing devices. VELOX_SSD_CACHE_LOG(INFO) << fmt::format( - "Wrote {}, {} bytes/s", + "Wrote {} to SSD, {} bytes/s", succinctBytes(bytes), - static_cast(bytes) / (getCurrentTimeMicro() - startTimeUs)); + static_cast(bytes) * 1'000'000 / + (getCurrentTimeMicro() - startTimeUs)); } }); } @@ -191,7 +198,11 @@ std::string SsdCache::toString() const { out << "Ssd cache IO: Write " << succinctBytes(data.bytesWritten) << " read " << succinctBytes(data.bytesRead) << " Size " << succinctBytes(capacity) << " Occupied " << succinctBytes(data.bytesCached); - out << " " << (data.entriesCached >> 10) << "K entries."; + out << " " << (data.entriesCached >> 10) << "K entries"; + if (maxEntries_ > 0) { + out << " (max " << (maxEntries_ >> 10) << "K)"; + } + out << "."; out << "\nGroupStats: " << groupStats_->toString(capacity); return out.str(); } @@ -207,7 +218,7 @@ void SsdCache::shutdown() { VELOX_SSD_CACHE_LOG(INFO) << "SSD cache is shutting down"; while (writesInProgress_) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); // NOLINT + std::this_thread::sleep_for(kWriteWaitMs); } for (auto& file : files_) { file->checkpoint(true); @@ -223,7 +234,7 @@ void SsdCache::clear() { void SsdCache::waitForWriteToFinish() { while (writesInProgress_ != 0) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); // NOLINT + std::this_thread::sleep_for(kWriteWaitMs); } } diff --git a/velox/common/caching/SsdCache.h b/velox/common/caching/SsdCache.h index e4670594948..2ec837958c1 100644 --- a/velox/common/caching/SsdCache.h +++ b/velox/common/caching/SsdCache.h @@ -23,6 +23,8 @@ namespace facebook::velox::cache { #define VELOX_SSD_CACHE_LOG_PREFIX "[SSDCA] " #define VELOX_SSD_CACHE_LOG(severity) \ LOG(severity) << VELOX_SSD_CACHE_LOG_PREFIX +#define VELOX_SSD_CACHE_LOG_EVERY_MS(severity, ms) \ + FB_LOG_EVERY_MS(severity, ms) << VELOX_SSD_CACHE_LOG_PREFIX namespace test { class SsdCacheTestHelper; @@ -41,7 +43,8 @@ class SsdCache { uint64_t _checkpointIntervalBytes = 0, bool _disableFileCow = false, bool _checksumEnabled = false, - bool _checksumReadVerificationEnabled = false) + bool _checksumReadVerificationEnabled = false, + uint64_t _maxEntries = 0) : filePrefix(_filePrefix), maxBytes(_maxBytes), numShards(_numShards), @@ -49,7 +52,8 @@ class SsdCache { disableFileCow(_disableFileCow), checksumEnabled(_checksumEnabled), checksumReadVerificationEnabled(_checksumReadVerificationEnabled), - executor(_executor){}; + executor(_executor), + maxEntries(_maxEntries) {} std::string filePrefix; uint64_t maxBytes; @@ -71,6 +75,10 @@ class SsdCache { /// Executor for async fsync in checkpoint. folly::Executor* executor; + /// Maximum number of SSD cache entries allowed. A value of 0 means no + /// limit. When the limit is reached, new entry writes will be skipped. + uint64_t maxEntries; + std::string toString() const { return fmt::format( "{} shards, capacity {}, checkpoint size {}, file cow {}, checksum {}, read verification {}", @@ -169,6 +177,9 @@ class SsdCache { void waitForWriteToFinish(); private: + // Polling interval for waiting on write completion + static constexpr auto kWriteWaitMs = std::chrono::milliseconds(100); + void checkNotShutdownLocked() { VELOX_CHECK( !shutdown_, "Unexpected write after SSD cache has been shutdown"); @@ -179,6 +190,9 @@ class SsdCache { // Stats for selecting entries to save from AsyncDataCache. const std::unique_ptr groupStats_; folly::Executor* const executor_; + // Maximum number of SSD cache entries allowed. 0 means no limit. + const uint64_t maxEntries_; + mutable std::mutex mutex_; std::vector> files_; diff --git a/velox/common/caching/SsdFile.cpp b/velox/common/caching/SsdFile.cpp index fa5c46f22da..0044b152e37 100644 --- a/velox/common/caching/SsdFile.cpp +++ b/velox/common/caching/SsdFile.cpp @@ -42,11 +42,12 @@ namespace facebook::velox::cache { namespace { void addEntryToIovecs(AsyncDataCacheEntry& entry, std::vector& iovecs) { - if (entry.tinyData() != nullptr) { - iovecs.push_back({entry.tinyData(), static_cast(entry.size())}); + if (entry.hasContiguousData()) { + iovecs.push_back( + {entry.contiguousData(), static_cast(entry.size())}); return; } - const auto& data = entry.data(); + const auto& data = entry.nonContiguousData(); iovecs.reserve(iovecs.size() + data.numRuns()); int64_t bytesLeft = entry.size(); for (auto i = 0; i < data.numRuns(); ++i) { @@ -60,12 +61,11 @@ void addEntryToIovecs(AsyncDataCacheEntry& entry, std::vector& iovecs) { } } -// Returns the number of entries in a cache 'entry'. uint32_t numIoVectorsFromEntry(AsyncDataCacheEntry& entry) { - if (entry.tinyData() != nullptr) { + if (entry.hasContiguousData()) { return 1; } - return entry.data().numRuns(); + return entry.nonContiguousData().numRuns(); } } // namespace @@ -112,9 +112,10 @@ SsdFile::SsdFile(const Config& config) checksumReadVerificationEnabled_( config.checksumEnabled && config.checksumReadVerificationEnabled), shardId_(config.shardId), + maxEntries_(config.maxEntries), + executor_(config.executor), fs_(filesystems::getFileSystem(fileName_, nullptr)), - checkpointIntervalBytes_(config.checkpointIntervalBytes), - executor_(config.executor) { + checkpointIntervalBytes_(config.checkpointIntervalBytes) { process::TraceContext trace("SsdFile::SsdFile"); filesystems::FileOptions fileOptions; fileOptions.shouldThrowOnFileAlreadyExists = false; @@ -137,6 +138,8 @@ SsdFile::SsdFile(const Config& config) regionPins_.resize(maxRegions_, 0); if (checkpointEnabled()) { initializeCheckpoint(); + } else { + removeStaleRecoveryFiles(); } if (disableFileCow_) { @@ -231,6 +234,9 @@ CoalesceIoStats SsdFile::load( read(offset, buffers); }); + common::testutil::TestValue::adjust( + "facebook::velox::cache::SsdFile::load", this); + for (auto i = 0; i < ssdPins.size(); ++i) { pins[i].checkedEntry()->setSsdFile(this, ssdPins[i].run().offset()); auto* entry = pins[i].checkedEntry(); @@ -307,8 +313,19 @@ bool SsdFile::growOrEvictLocked() { } } - auto candidates = - tracker_.findEvictionCandidates(3, numRegions_, regionPins_); + // If SSD is in no space state and future eviction logging cannot go through, + // skip eviction, to avoid data inconsistency for checkpointing. Eviction log + // is up to date. Separately when SSD is in no space state, + // growOrEvictLocked() would not be invoked by write() in the first place. + if (state_.load() == State::kNoSpace) { + VELOX_SSD_CACHE_LOG_EVERY_MS(WARNING, 1'000) + << "Failed to grow cache file " << fileName_ + << " due to SSD in no space state."; + return false; + } + + auto candidates = tracker_.findEvictionCandidates( + kNumEvictionCandidates, numRegions_, regionPins_); if (candidates.empty()) { suspended_ = true; return false; @@ -345,6 +362,23 @@ void SsdFile::clearRegionEntriesLocked(const std::vector& regions) { void SsdFile::write(std::vector& pins) { process::TraceContext trace("SsdFile::write"); + + if (state_.load() == State::kNoSpace) { + ++stats_.writeSsdDropped; + VELOX_SSD_CACHE_LOG_EVERY_MS(WARNING, 10'000) + << "SSD file write is dropped in no space state."; + return; + } + + // Check entry count limit before writing + if (maxEntries_ > 0) { + std::shared_lock l(mutex_); + if (entries_.size() + pins.size() >= maxEntries_) { + ++stats_.writeSsdExceedEntryLimit; + return; + } + } + // Sorts the pins by their file/offset. In this way what is adjacent in // storage is likely adjacent on SSD. std::sort(pins.begin(), pins.end()); @@ -438,16 +472,29 @@ bool SsdFile::write( int64_t offset, int64_t length, const std::vector& iovecs) { + VELOX_DCHECK_NE(state_.load(), State::kNoSpace); + try { writeFile_->write(iovecs, offset, length); return true; } catch (const std::exception&) { + const int err = errno; VELOX_SSD_CACHE_LOG(ERROR) << "Failed to write to SSD, file name: " << fileName_ << ", size: " << iovecs.size() << ", offset: " << offset - << ", error code: " << errno - << ", error string: " << folly::errnoStr(errno); + << ", error code: " << err + << ", error string: " << folly::errnoStr(err); ++stats_.writeSsdErrors; + + if (err == ENOSPC) { + if (state_.exchange(State::kNoSpace) != State::kNoSpace) { + VELOX_SSD_CACHE_LOG(WARNING) + << "State of cache file " << fileName_ << " transits to " + << stateString(State::kNoSpace); + } + ++stats_.writeSsdNoSpaceErrors; + } + return false; } } @@ -469,12 +516,12 @@ void SsdFile::verifyWrite(AsyncDataCacheEntry& entry, SsdRun ssdRun) { const auto rc = readFile_->pread(ssdRun.offset(), entry.size(), testData.get()); VELOX_CHECK_EQ(rc.size(), entry.size()); - if (entry.tinyData() != nullptr) { - if (::memcmp(testData.get(), entry.tinyData(), entry.size()) != 0) { + if (entry.hasContiguousData()) { + if (::memcmp(testData.get(), entry.contiguousData(), entry.size()) != 0) { VELOX_FAIL("bad read back"); } } else { - const auto& data = entry.data(); + const auto& data = entry.nonContiguousData(); int64_t bytesLeft = entry.size(); int64_t offset = 0; for (auto i = 0; i < data.numRuns(); ++i) { @@ -520,6 +567,9 @@ void SsdFile::updateStats(SsdCacheStats& stats) const { stats.deleteMetaFileErrors += stats_.deleteMetaFileErrors; stats.growFileErrors += stats_.growFileErrors; stats.writeSsdErrors += stats_.writeSsdErrors; + stats.writeSsdNoSpaceErrors += stats_.writeSsdNoSpaceErrors; + stats.writeSsdDropped += stats_.writeSsdDropped; + stats.writeSsdExceedEntryLimit += stats_.writeSsdExceedEntryLimit; stats.writeCheckpointErrors += stats_.writeCheckpointErrors; stats.readSsdErrors += stats_.readSsdErrors; stats.readCheckpointErrors += stats_.readCheckpointErrors; @@ -668,6 +718,28 @@ void SsdFile::deleteFile(std::unique_ptr file) { } } +void SsdFile::removeStaleRecoveryFiles() { + const auto checkpointPath = checkpointFilePath(); + if (fs_->exists(checkpointPath)) { + try { + fs_->remove(checkpointPath); + } catch (const std::exception& e) { + VELOX_SSD_CACHE_LOG(WARNING) << "Failed to remove stale checkpoint file " + << checkpointPath << ": " << e.what(); + } + } + const auto logPath = evictLogFilePath(); + if (fs_->exists(logPath)) { + try { + fs_->remove(logPath); + } catch (const std::exception& e) { + VELOX_SSD_CACHE_LOG(WARNING) + << "Failed to remove stale eviction log file " << logPath << ": " + << e.what(); + } + } +} + void SsdFile::checkpointError(int32_t rc, const std::string& error) { VELOX_SSD_CACHE_LOG(ERROR) << error << " with rc=" << rc @@ -711,9 +783,10 @@ void SsdFile::maybeFlushCheckpointBuffer(uint32_t appendBytes, bool force) { (force || checkpointBufferedDataSize_ + appendBytes >= kCheckpointBufferSize)) { VELOX_CHECK_NOT_NULL(checkpointBuffer_); - checkpointWriteFile_->append(std::string_view( - static_cast(checkpointBuffer_), - checkpointBufferedDataSize_)); + checkpointWriteFile_->append( + std::string_view( + static_cast(checkpointBuffer_), + checkpointBufferedDataSize_)); checkpointBufferedDataSize_ = 0; } } @@ -830,6 +903,23 @@ void SsdFile::checkpoint(bool force) { } } +// static +std::string SsdFile::stateString(State state) { + switch (state) { + case State::kActive: + return "Active"; + case State::kNoSpace: + return "NoSpace"; + default: + return fmt::format("UNKNOWN: {}", static_cast(state)); + } +} + +std::ostream& operator<<(std::ostream& out, const SsdFile::State& state) { + out << SsdFile::stateString(state); + return out; +} + void SsdFile::initializeCheckpoint() { if (!checkpointEnabled()) { return; @@ -875,11 +965,11 @@ void SsdFile::initializeCheckpoint() { uint32_t SsdFile::checksumEntry(const AsyncDataCacheEntry& entry) const { bits::Crc32 crc; - if (entry.tinyData()) { - crc.process_bytes(entry.tinyData(), entry.size()); + if (entry.hasContiguousData()) { + crc.process_bytes(entry.contiguousData(), entry.size()); } else { int64_t bytesLeft = entry.size(); - const auto& data = entry.data(); + const auto& data = entry.nonContiguousData(); for (auto i = 0; i < data.numRuns() && bytesLeft > 0; ++i) { const auto run = data.runAt(i); const auto bytesToProcess = std::min(bytesLeft, run.numBytes()); @@ -964,7 +1054,7 @@ void SsdFile::readCheckpoint() { auto checkpointReadFile = fs_->openFileForRead(checkpointPath); stream = std::make_unique( std::move(checkpointReadFile), - 1 << 20, + kCheckpointReadBufferSize, memory::memoryManager()->cachePool()); } catch (std::exception& e) { ++stats_.openCheckpointErrors; diff --git a/velox/common/caching/SsdFile.h b/velox/common/caching/SsdFile.h index c9ab1f5b301..e729d836527 100644 --- a/velox/common/caching/SsdFile.h +++ b/velox/common/caching/SsdFile.h @@ -32,16 +32,18 @@ class SsdFileTestHelper; class SsdCacheTestHelper; } // namespace test -/// A 64 bit word describing a SSD cache entry in an SsdFile. The low 23 bits -/// are the size, for a maximum entry size of 8MB. The high bits are the offset. +/// The 'fileBits_' field is a 64 bit word describing a SSD cache entry in an +/// SsdFile. The low 23 bits are the size, for a maximum entry size of 8MB. The +/// high 41 bits are the offset. The 'checksum_' field is optional and is used +/// only when the checksum feature is enabled, otherwise, its value is always 0. class SsdRun { public: static constexpr int32_t kSizeBits = 23; - SsdRun() : fileBits_(0) {} + SsdRun() = default; SsdRun(uint64_t offset, uint32_t size, uint32_t checksum) - : fileBits_((offset << kSizeBits) | ((size - 1))), checksum_(checksum) { + : fileBits_((offset << kSizeBits) | (size - 1)), checksum_(checksum) { VELOX_CHECK_LT(offset, 1L << (64 - kSizeBits)); VELOX_CHECK_NE(size, 0); VELOX_CHECK_LE(size, 1 << kSizeBits); @@ -58,9 +60,11 @@ class SsdRun { checksum_ = other.checksum_; } - void operator=(SsdRun&& other) { + void operator=(SsdRun&& other) noexcept { fileBits_ = other.fileBits_; checksum_ = other.checksum_; + other.fileBits_ = 0; + other.checksum_ = 0; } uint64_t offset() const { @@ -83,8 +87,8 @@ class SsdRun { private: // Contains the file offset and size. - uint64_t fileBits_; - uint32_t checksum_; + uint64_t fileBits_{0}; + uint32_t checksum_{0}; }; /// Represents an SsdFile entry that is planned for load or being loaded. This @@ -164,7 +168,9 @@ struct SsdCacheStats { deleteMetaFileErrors = tsanAtomicValue(other.deleteMetaFileErrors); growFileErrors = tsanAtomicValue(other.growFileErrors); writeSsdErrors = tsanAtomicValue(other.writeSsdErrors); + writeSsdNoSpaceErrors = tsanAtomicValue(other.writeSsdNoSpaceErrors); writeSsdDropped = tsanAtomicValue(other.writeSsdDropped); + writeSsdExceedEntryLimit = tsanAtomicValue(other.writeSsdExceedEntryLimit); writeCheckpointErrors = tsanAtomicValue(other.writeCheckpointErrors); readSsdErrors = tsanAtomicValue(other.readSsdErrors); readCheckpointErrors = tsanAtomicValue(other.readCheckpointErrors); @@ -193,7 +199,11 @@ struct SsdCacheStats { deleteMetaFileErrors - other.deleteMetaFileErrors; result.growFileErrors = growFileErrors - other.growFileErrors; result.writeSsdErrors = writeSsdErrors - other.writeSsdErrors; + result.writeSsdNoSpaceErrors = + writeSsdNoSpaceErrors - other.writeSsdNoSpaceErrors; result.writeSsdDropped = writeSsdDropped - other.writeSsdDropped; + result.writeSsdExceedEntryLimit = + writeSsdExceedEntryLimit - other.writeSsdExceedEntryLimit; result.writeCheckpointErrors = writeCheckpointErrors - other.writeCheckpointErrors; result.readSsdCorruptions = readSsdCorruptions - other.readSsdCorruptions; @@ -232,7 +242,9 @@ struct SsdCacheStats { tsan_atomic deleteMetaFileErrors{0}; tsan_atomic growFileErrors{0}; tsan_atomic writeSsdErrors{0}; + tsan_atomic writeSsdNoSpaceErrors{0}; tsan_atomic writeSsdDropped{0}; + tsan_atomic writeSsdExceedEntryLimit{0}; tsan_atomic writeCheckpointErrors{0}; tsan_atomic readSsdErrors{0}; tsan_atomic readCheckpointErrors{0}; @@ -257,6 +269,7 @@ class SsdFile { bool _disableFileCow = false, bool _checksumEnabled = false, bool _checksumReadVerificationEnabled = false, + uint64_t _maxEntries = 0, folly::Executor* _executor = nullptr) : fileName(_fileName), shardId(_shardId), @@ -266,7 +279,8 @@ class SsdFile { checksumEnabled(_checksumEnabled), checksumReadVerificationEnabled( _checksumEnabled && _checksumReadVerificationEnabled), - executor(_executor){}; + maxEntries(_maxEntries), + executor(_executor) {} /// Name of cache file, used as prefix for checkpoint files. const std::string fileName; @@ -279,19 +293,28 @@ class SsdFile { /// Checkpoint after every 'checkpointIntervalBytes' written into this /// file. 0 means no checkpointing. This is set to 0 if checkpointing fails. - uint64_t checkpointIntervalBytes; + const uint64_t checkpointIntervalBytes; /// True if copy on write should be disabled. - bool disableFileCow; + const bool disableFileCow; /// If true, checksum write to SSD is enabled. - bool checksumEnabled; + const bool checksumEnabled; /// If true, checksum read verification from SSD is enabled. - bool checksumReadVerificationEnabled; + const bool checksumReadVerificationEnabled; + + /// Maximum number of SSD cache entries allowed. A value of 0 means no + /// limit. When the limit is reached, new entry writes will be skipped. + const uint64_t maxEntries; /// Executor for async fsync in checkpoint. - folly::Executor* executor; + folly::Executor* const executor; + }; + + enum class State : uint8_t { + kActive, + kNoSpace, }; static constexpr uint64_t kRegionSize = 1 << 26; // 64MB @@ -300,6 +323,9 @@ class SsdFile { /// filename. SsdFile(const Config& config); + /// Convert State to std::string. + static std::string stateString(State state); + /// Adds entries of 'pins' to this file. 'pins' must be in read mode and /// those pins that are successfully added to SSD are marked as being on SSD. /// The file of the entries must be a file that is backed by 'this'. @@ -310,6 +336,7 @@ class SsdFile { /// Erases 'key' bool erase(RawFileCacheKey key); + /// Copies the data in 'ssdPins' into 'pins'. Coalesces IO for nearby /// entries if they are in ascending order and near enough. CoalesceIoStats load( @@ -386,8 +413,19 @@ class SsdFile { // Magic number at end of completed checkpoint file. static constexpr int64_t kCheckpointEndMarker = 0xcbedf11e; + // Maximum percentage of erased entries in a region before it becomes + // eligible for clearing and reuse. When more than 50% of a region's + // entries have been erased (e.g., via TTL eviction), the region can be + // cleared and added back to the writable regions pool. static constexpr int kMaxErasedSizePct = 50; + // Number of eviction candidates to consider when selecting regions to + // evict. + static constexpr int32_t kNumEvictionCandidates = 3; + + // Buffer size for reading checkpoint files during recovery. + static constexpr int32_t kCheckpointReadBufferSize = 1 << 20; // 1MB + // Updates the read count of a region. void regionRead(int32_t region, int32_t size) { tracker_.regionRead(region, size); @@ -473,6 +511,10 @@ class SsdFile { if (!checkpointEnabled()) { return false; } + // Once no SSD space, skip the subsequent checkpointing. + if (state_.load() == State::kNoSpace) { + return false; + } return force || (bytesAfterCheckpoint_ >= checkpointIntervalBytes_); } @@ -490,6 +532,22 @@ class SsdFile { // Deletes the given file if it exists. void deleteFile(std::unique_ptr file); + // Removes the checkpoint and eviction-log files left behind by a previous + // SsdFile instance using this same data file. Called from the constructor + // when this instance has checkpointing disabled. + // + // The checkpoint and log together describe which logical keys live at + // which offsets in the data file; recovery on startup reads them to + // rebuild 'entries_'. They are trustworthy only while every write to the + // data file also keeps them up to date. With checkpointing off this + // instance writes into the existing data file but never touches the + // meta files, so a later instance with checkpointing re-enabled would + // recover from a stale checkpoint pointing at overwritten regions and + // silently return wrong bytes. Removing them here keeps the on-disk + // state to either {data + matching checkpoint} or {data alone}; the + // {data + stale checkpoint} state becomes unreachable. + void removeStaleRecoveryFiles(); + // Allocates 'kCheckpointBufferSize' buffer from cache memory pool for // checkpointing. void allocateCheckpointBuffer(); @@ -548,6 +606,12 @@ class SsdFile { // Shard index within 'cache_'. const int32_t shardId_; + // Maximum number of SSD cache entries allowed in this file. 0 means no limit. + const uint64_t maxEntries_; + + // Executor for async fsync in checkpoint. + folly::Executor* const executor_; + // Serializes access to all private data members. mutable std::shared_mutex mutex_; @@ -578,6 +642,8 @@ class SsdFile { // Map of file number and offset to location in file. folly::F14FastMap entries_; + std::atomic state_{State::kActive}; + // File system. std::shared_ptr fs_; @@ -603,9 +669,6 @@ class SsdFile { // means no checkpointing. This is set to 0 if checkpointing fails. int64_t checkpointIntervalBytes_{0}; - // Executor for async fsync in checkpoint. - folly::Executor* executor_; - // Count of bytes written after last checkpoint. std::atomic bytesAfterCheckpoint_{0}; @@ -622,4 +685,16 @@ class SsdFile { friend class test::SsdCacheTestHelper; }; +std::ostream& operator<<(std::ostream& out, const SsdFile::State& state); + } // namespace facebook::velox::cache + +template <> +struct fmt::formatter + : formatter { + auto format(facebook::velox::cache::SsdFile::State state, format_context& ctx) + const { + return formatter::format( + facebook::velox::cache::SsdFile::stateString(state), ctx); + } +}; diff --git a/velox/common/caching/SsdFileTracker.h b/velox/common/caching/SsdFileTracker.h index 5ab825f50b8..e9eed9720dc 100644 --- a/velox/common/caching/SsdFileTracker.h +++ b/velox/common/caching/SsdFileTracker.h @@ -30,7 +30,12 @@ namespace facebook::velox::cache { class SsdFileTracker { public: void resize(int32_t numRegions) { - resizeTsanAtomic(regionScores_, numRegions); + std::vector> newScores(numRegions); + auto numCopy = std::min(numRegions, regionScores_.size()); + for (auto i = 0; i < numCopy; ++i) { + newScores[i] = tsanAtomicValue(regionScores_[i]); + } + regionScores_ = std::move(newScores); } void regionRead(int32_t region, int32_t bytes) { diff --git a/velox/common/caching/StringIdMap.cpp b/velox/common/caching/StringIdMap.cpp index c8c88542da1..e991c5a750e 100644 --- a/velox/common/caching/StringIdMap.cpp +++ b/velox/common/caching/StringIdMap.cpp @@ -31,8 +31,8 @@ void StringIdMap::release(uint64_t id) { std::lock_guard l(mutex_); auto it = idToEntry_.find(id); if (it != idToEntry_.end()) { - VELOX_CHECK_LT( - 0, it->second.numInUse, "Extra release of id in StringIdMap"); + VELOX_CHECK_GT( + it->second.numInUse, 0, "Extra release of id in StringIdMap"); if (--it->second.numInUse == 0) { pinnedSize_ -= it->second.string.size(); auto strIter = stringToId_.find(it->second.string); @@ -60,11 +60,11 @@ uint64_t StringIdMap::makeId(std::string_view string) { if (it != stringToId_.end()) { auto entry = idToEntry_.find(it->second); VELOX_CHECK(entry != idToEntry_.end()); - if (++entry->second.numInUse == 1) { - pinnedSize_ += entry->second.string.size(); - } + VELOX_CHECK_GE(entry->second.numInUse, 1); + ++entry->second.numInUse; return it->second; } + Entry entry; entry.string = string; // Check that we do not use an id twice. In practice this never @@ -91,9 +91,8 @@ uint64_t StringIdMap::recoverId(uint64_t id, std::string_view string) { id, it->second, "Multiple recover ids assigned to {}", string); auto entry = idToEntry_.find(it->second); VELOX_CHECK(entry != idToEntry_.end()); - if (++entry->second.numInUse == 1) { - pinnedSize_ += entry->second.string.size(); - } + VELOX_CHECK_GE(entry->second.numInUse, 1); + ++entry->second.numInUse; return id; } diff --git a/velox/common/caching/tests/AsyncDataCacheTest.cpp b/velox/common/caching/tests/AsyncDataCacheTest.cpp index f60da8f4ba5..7df81482038 100644 --- a/velox/common/caching/tests/AsyncDataCacheTest.cpp +++ b/velox/common/caching/tests/AsyncDataCacheTest.cpp @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "folly/experimental/EventCount.h" +#include "folly/synchronization/EventCount.h" #include "velox/common/base/Semaphore.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/caching/CacheTTLController.h" @@ -25,8 +25,8 @@ #include "velox/common/memory/Memory.h" #include "velox/common/memory/MmapAllocator.h" #include "velox/common/testutil/ScopedTestTime.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/common/testutil/TestValue.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include #include @@ -144,7 +144,7 @@ class AsyncDataCacheTest : public ::testing::TestWithParam { // second creation of cache must find the checkpoint of the // previous one. if (tempDirectory_ == nullptr || eraseCheckpoint) { - tempDirectory_ = exec::test::TempDirectoryPath::create(); + tempDirectory_ = TempDirectoryPath::create(); } SsdCache::Config config( fmt::format("{}/cache", tempDirectory_->getPath()), @@ -158,7 +158,7 @@ class AsyncDataCacheTest : public ::testing::TestWithParam { ssdCache = std::make_unique(config); if (ssdCache != nullptr) { ssdCacheHelper_ = - std::make_unique(ssdCache.get()); + std::make_unique(ssdCache.get()); ASSERT_EQ(ssdCacheHelper_->numShards(), kNumSsdShards); } } @@ -173,7 +173,7 @@ class AsyncDataCacheTest : public ::testing::TestWithParam { cache_ = AsyncDataCache::create(allocator_, std::move(ssdCache), cacheOptions); asyncDataCacheHelper_ = - std::make_unique(cache_.get()); + std::make_unique(cache_.get()); if (filenames_.empty()) { for (auto i = 0; i < kNumFiles; ++i) { auto name = fmt::format("testing_file_{}", i); @@ -244,7 +244,7 @@ class AsyncDataCacheTest : public ::testing::TestWithParam { // Checks that the contents are consistent with what is set in // initializeContents. static void checkContents(const AsyncDataCacheEntry& entry) { - const auto& alloc = entry.data(); + const auto& alloc = entry.nonContiguousData(); const int32_t numBytes = entry.size(); const int64_t expectedSequence = entry.key().fileNum.id() + entry.offset(); int32_t bytesChecked = sizeof(int64_t); @@ -274,7 +274,7 @@ class AsyncDataCacheTest : public ::testing::TestWithParam { folly::SemiFuture wait(false); try { RawFileCacheKey key{filenames_[0].id(), offset}; - auto pin = cache_->findOrCreate(key, size, &wait); + auto pin = cache_->findOrCreate(key, size, /*contiguous=*/false, &wait); EXPECT_FALSE(pin.empty()); EXPECT_TRUE(pin.entry()->isExclusive()); pin.entry()->setPrefetch(); @@ -313,12 +313,12 @@ class AsyncDataCacheTest : public ::testing::TestWithParam { } } - std::shared_ptr tempDirectory_; + std::shared_ptr tempDirectory_; std::unique_ptr manager_; memory::MemoryAllocator* allocator_; std::shared_ptr cache_; - std::unique_ptr asyncDataCacheHelper_; - std::unique_ptr ssdCacheHelper_; + std::unique_ptr asyncDataCacheHelper_; + std::unique_ptr ssdCacheHelper_; std::vector filenames_; std::unique_ptr loadExecutor_; std::unique_ptr ssdExecutor_; @@ -346,7 +346,7 @@ class TestingCoalescedLoad : public CoalescedLoad { pins.push_back(std::move(pin)); }); for (const auto& pin : pins) { - auto& buffer = pin.entry()->data(); + auto& buffer = pin.entry()->nonContiguousData(); AsyncDataCacheTest::initializeContents( pin.entry()->key().offset + pin.entry()->key().fileNum.id(), buffer); } @@ -362,6 +362,10 @@ class TestingCoalescedLoad : public CoalescedLoad { return sum; } + bool isSsdLoad() const override { + return false; + } + protected: const std::shared_ptr cache_; const std::vector requests_; @@ -415,6 +419,10 @@ class TestingCoalescedSsdLoad : public TestingCoalescedLoad { return pins; } + bool isSsdLoad() const override { + return true; + } + private: std::vector ssdPins_; }; @@ -433,7 +441,8 @@ void AsyncDataCacheTest::loadOne( RawFileCacheKey key{fileNum, request.offset}; for (;;) { folly::SemiFuture loadFuture(false); - auto pin = cache_->findOrCreate(key, request.size, &loadFuture); + auto pin = cache_->findOrCreate( + key, request.size, /*contiguous=*/false, &loadFuture); if (pin.empty()) { // The pin was exclusive on another thread. Wait until it is no longer so // and retry. @@ -467,7 +476,8 @@ void AsyncDataCacheTest::loadOne( } // Load from storage. initializeContents( - entry->key().offset + entry->key().fileNum.id(), entry->data()); + entry->key().offset + entry->key().fileNum.id(), + entry->nonContiguousData()); entry->setExclusiveToShared(); return; } @@ -661,35 +671,36 @@ TEST_P(AsyncDataCacheTest, pin) { uint64_t offset = 1000; folly::SemiFuture wait(false); RawFileCacheKey key{file.id(), offset}; - auto pin = cache_->findOrCreate(key, kSize, &wait); + auto pin = cache_->findOrCreate(key, kSize, /*contiguous=*/false, &wait); EXPECT_FALSE(pin.empty()); EXPECT_TRUE(wait.isReady()); EXPECT_TRUE(pin.entry()->isExclusive()); pin.entry()->setPrefetch(); - EXPECT_LE(kSize, pin.entry()->data().byteSize()); + EXPECT_LE(kSize, pin.entry()->nonContiguousData().byteSize()); EXPECT_LT(0, cache_->incrementPrefetchPages(0)); auto stats = cache_->refreshStats(); EXPECT_EQ(1, stats.numExclusive); - EXPECT_LE(kSize, stats.largeSize); + EXPECT_EQ(0, stats.largeSize); CachePin otherPin; EXPECT_THROW(otherPin = pin, VeloxException); EXPECT_TRUE(otherPin.empty()); // Second reference to an exclusive entry. - otherPin = cache_->findOrCreate(key, kSize, &wait); + otherPin = cache_->findOrCreate(key, kSize, /*contiguous=*/false, &wait); EXPECT_FALSE(wait.isReady()); EXPECT_TRUE(otherPin.empty()); bool noLongerExclusive = false; std::move(wait).via(&exec).thenValue([&](bool) { noLongerExclusive = true; }); - initializeContents(key.fileNum + key.offset, pin.checkedEntry()->data()); + initializeContents( + key.fileNum + key.offset, pin.checkedEntry()->nonContiguousData()); pin.checkedEntry()->setExclusiveToShared(); pin.clear(); EXPECT_TRUE(pin.empty()); EXPECT_TRUE(noLongerExclusive); - pin = cache_->findOrCreate(key, kSize, &wait); + pin = cache_->findOrCreate(key, kSize, /*contiguous=*/false, &wait); EXPECT_TRUE(pin.entry()->isShared()); EXPECT_TRUE(pin.entry()->getAndClearFirstUseFlag()); EXPECT_FALSE(pin.entry()->getAndClearFirstUseFlag()); @@ -697,7 +708,8 @@ TEST_P(AsyncDataCacheTest, pin) { otherPin = pin; EXPECT_EQ(2, pin.entry()->numPins()); EXPECT_FALSE(pin.entry()->isPrefetch()); - auto largerPin = cache_->findOrCreate(key, kSize * 2, &wait); + auto largerPin = + cache_->findOrCreate(key, kSize * 2, /*contiguous=*/false, &wait); // We expect a new uninitialized entry with a larger size to displace the // previous one. @@ -719,6 +731,223 @@ TEST_P(AsyncDataCacheTest, pin) { EXPECT_EQ(0, cache_->incrementPrefetchPages(0)); } +TEST_P(AsyncDataCacheTest, contiguousPin) { + initializeCache(1 << 20); + auto& exec = folly::QueuedImmediateExecutor::instance(); + + StringIdLease file(fileIds(), std::string_view("testingfile_contiguous")); + + struct TestParam { + int64_t size; + bool expectTiny; + std::string debugString() const { + return fmt::format("size {}, expectTiny {}", size, expectTiny); + } + }; + + std::vector testSettings = { + {AsyncDataCacheEntry::kTinyDataSize / 2, true}, + {AsyncDataCacheEntry::kTinyDataSize - 1, true}, + {AsyncDataCacheEntry::kTinyDataSize, false}, + {AsyncDataCacheEntry::kTinyDataSize * 4, false}, + {25'000, false}, + }; + + uint64_t offset = 1'000; + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + folly::SemiFuture wait(false); + RawFileCacheKey key{file.id(), offset}; + offset += testData.size; + + auto pin = + cache_->findOrCreate(key, testData.size, /*contiguous=*/true, &wait); + ASSERT_FALSE(pin.empty()); + ASSERT_TRUE(wait.isReady()); + auto* entry = pin.checkedEntry(); + ASSERT_TRUE(entry->isExclusive()); + ASSERT_TRUE(entry->hasContiguousData()); + ASSERT_TRUE(entry->nonContiguousData().empty()); + ASSERT_EQ(entry->contiguousDataSize(), testData.size); + + test::AsyncDataCacheEntryTestHelper entryHelper(entry); + if (testData.expectTiny) { + ASSERT_TRUE(entryHelper.isTinyData()); + ASSERT_FALSE(entryHelper.isContiguousData()); + } else { + ASSERT_FALSE(entryHelper.isTinyData()); + ASSERT_TRUE(entryHelper.isContiguousData()); + } + + ::memset(entry->contiguousData(), 0xCD, testData.size); + entry->setExclusiveToShared(); + + // Verify stats include the contiguous allocation. + { + auto stats = cache_->refreshStats(); + ASSERT_EQ(stats.numEntries, 1); + if (testData.expectTiny) { + ASSERT_EQ(stats.numTinyEntries, 1); + ASSERT_EQ(stats.tinySize, testData.size); + } else { + ASSERT_EQ(stats.numLargeEntries, 1); + ASSERT_EQ(stats.largeSize, testData.size); + } + if (!testData.expectTiny) { + ASSERT_EQ( + cache_->cachedPages(), + memory::AllocationTraits::numPages(testData.size)); + } + } + + auto pin2 = + cache_->findOrCreate(key, testData.size, /*contiguous=*/true, &wait); + ASSERT_FALSE(pin2.empty()); + ASSERT_TRUE(pin2.checkedEntry()->isShared()); + ASSERT_TRUE(pin2.checkedEntry()->hasContiguousData()); + ASSERT_EQ( + static_cast(pin2.checkedEntry()->contiguousData()[0]), 0xCD); + pin2.clear(); + + // Lookup with contiguous=false should return the same contiguous entry. + auto pin3 = + cache_->findOrCreate(key, testData.size, /*contiguous=*/false, &wait); + ASSERT_FALSE(pin3.empty()); + ASSERT_TRUE(pin3.checkedEntry()->isShared()); + ASSERT_TRUE(pin3.checkedEntry()->hasContiguousData()); + pin.clear(); + pin3.clear(); + + // Wait-future: concurrent access to an exclusive entry. + RawFileCacheKey key2{file.id(), offset}; + offset += testData.size; + auto exclusivePin = + cache_->findOrCreate(key2, testData.size, /*contiguous=*/true, &wait); + ASSERT_FALSE(exclusivePin.empty()); + ASSERT_TRUE(exclusivePin.checkedEntry()->isExclusive()); + + auto waitingPin = + cache_->findOrCreate(key2, testData.size, /*contiguous=*/true, &wait); + ASSERT_TRUE(waitingPin.empty()); + ASSERT_FALSE(wait.isReady()); + + bool notified = false; + std::move(wait).via(&exec).thenValue([&](bool) { notified = true; }); + exclusivePin.checkedEntry()->setExclusiveToShared(); + exclusivePin.clear(); + ASSERT_TRUE(notified); + + auto retryPin = + cache_->findOrCreate(key2, testData.size, /*contiguous=*/true, &wait); + ASSERT_FALSE(retryPin.empty()); + ASSERT_TRUE(retryPin.checkedEntry()->isShared()); + retryPin.clear(); + + cache_->clear(); + auto stats = cache_->refreshStats(); + ASSERT_EQ(stats.numEntries, 0); + ASSERT_EQ(stats.largeSize, 0); + ASSERT_EQ(stats.tinySize, 0); + ASSERT_EQ(cache_->cachedPages(), 0); + } +} + +TEST_P(AsyncDataCacheTest, nonContiguousPin) { + initializeCache(1 << 20); + + StringIdLease file(fileIds(), std::string_view("testingfile_noncontiguous")); + + struct TestParam { + int64_t size; + bool expectTiny; + std::string debugString() const { + return fmt::format("size {}, expectTiny {}", size, expectTiny); + } + }; + + std::vector testSettings = { + {AsyncDataCacheEntry::kTinyDataSize / 2, true}, + {AsyncDataCacheEntry::kTinyDataSize - 1, true}, + {AsyncDataCacheEntry::kTinyDataSize, false}, + {AsyncDataCacheEntry::kTinyDataSize * 4, false}, + {25'000, false}, + }; + + uint64_t offset = 1'000; + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + folly::SemiFuture wait(false); + RawFileCacheKey key{file.id(), offset}; + offset += testData.size; + + auto pin = + cache_->findOrCreate(key, testData.size, /*contiguous=*/false, &wait); + ASSERT_FALSE(pin.empty()); + ASSERT_TRUE(wait.isReady()); + auto* entry = pin.checkedEntry(); + ASSERT_TRUE(entry->isExclusive()); + + test::AsyncDataCacheEntryTestHelper entryHelper(entry); + if (testData.expectTiny) { + ASSERT_TRUE(entryHelper.isTinyData()); + ASSERT_FALSE(entryHelper.isContiguousData()); + ASSERT_TRUE(entry->hasContiguousData()); + ASSERT_TRUE(entry->nonContiguousData().empty()); + } else { + ASSERT_FALSE(entryHelper.isTinyData()); + ASSERT_FALSE(entryHelper.isContiguousData()); + ASSERT_FALSE(entry->hasContiguousData()); + ASSERT_FALSE(entry->nonContiguousData().empty()); + } + + entry->setExclusiveToShared(); + + // Verify stats. + { + auto stats = cache_->refreshStats(); + ASSERT_EQ(stats.numEntries, 1); + if (testData.expectTiny) { + ASSERT_EQ(stats.numTinyEntries, 1); + ASSERT_EQ(stats.tinySize, testData.size); + } else { + ASSERT_EQ(stats.numLargeEntries, 1); + } + if (!testData.expectTiny) { + ASSERT_EQ( + cache_->cachedPages(), + memory::AllocationTraits::numPages(testData.size)); + } + } + + // Shared hit with same flag. + auto pin2 = + cache_->findOrCreate(key, testData.size, /*contiguous=*/false, &wait); + ASSERT_FALSE(pin2.empty()); + ASSERT_TRUE(pin2.checkedEntry()->isShared()); + pin2.clear(); + + // Lookup with contiguous=true should return the same non-contiguous entry. + auto pin3 = + cache_->findOrCreate(key, testData.size, /*contiguous=*/true, &wait); + ASSERT_FALSE(pin3.empty()); + ASSERT_TRUE(pin3.checkedEntry()->isShared()); + if (testData.expectTiny) { + ASSERT_TRUE(pin3.checkedEntry()->hasContiguousData()); + } else { + ASSERT_FALSE(pin3.checkedEntry()->hasContiguousData()); + } + pin.clear(); + pin3.clear(); + + cache_->clear(); + auto stats = cache_->refreshStats(); + ASSERT_EQ(stats.numEntries, 0); + ASSERT_EQ(cache_->cachedPages(), 0); + } +} + TEST_P(AsyncDataCacheTest, replace) { constexpr int64_t kMaxBytes = 64 << 20; FLAGS_velox_exception_user_stacktrace_enabled = false; @@ -733,8 +962,7 @@ TEST_P(AsyncDataCacheTest, replace) { EXPECT_LT(0, stats.hitBytes); EXPECT_LT(0, stats.numEvict); EXPECT_GE( - kMaxBytes / memory::AllocationTraits::kPageSize, - cache_->incrementCachedPages(0)); + kMaxBytes / memory::AllocationTraits::kPageSize, cache_->cachedPages()); } TEST_P(AsyncDataCacheTest, evictAccounting) { @@ -776,8 +1004,7 @@ TEST_P(AsyncDataCacheTest, largeEvict) { auto stats = cache_->refreshStats(); EXPECT_LT(0, stats.numEvict); EXPECT_GE( - kMaxBytes / memory::AllocationTraits::kPageSize, - cache_->incrementCachedPages(0)); + kMaxBytes / memory::AllocationTraits::kPageSize, cache_->cachedPages()); LOG(INFO) << "Reties after failed evict: " << numLargeRetries_; } @@ -805,7 +1032,7 @@ TEST_P(AsyncDataCacheTest, outOfCapacity) { ASSERT_FALSE(allocator_->allocateNonContiguous(kSizeInPages, allocation)); // One 4 page entry below the max size of 4K 4 page entries in 16MB of // capacity. - ASSERT_EQ(16384, cache_->incrementCachedPages(0)); + ASSERT_EQ(16384, cache_->cachedPages()); ASSERT_EQ(16384, cache_->incrementPrefetchPages(0)); pins.clear(); @@ -816,7 +1043,7 @@ TEST_P(AsyncDataCacheTest, outOfCapacity) { } allocations.push_back(std::move(allocation)); } - EXPECT_EQ(0, cache_->incrementCachedPages(0)); + EXPECT_EQ(0, cache_->cachedPages()); EXPECT_EQ(0, cache_->incrementPrefetchPages(0)); EXPECT_EQ(16384, allocator_->numAllocated()); clearAllocations(allocations); @@ -1044,7 +1271,7 @@ TEST_P(AsyncDataCacheTest, staleEntry) { const uint64_t size = 200; folly::SemiFuture wait(false); RawFileCacheKey key{file.id(), offset}; - auto pin = cache_->findOrCreate(key, size, &wait); + auto pin = cache_->findOrCreate(key, size, /*contiguous=*/false, &wait); ASSERT_FALSE(pin.empty()); ASSERT_TRUE(wait.isReady()); ASSERT_TRUE(pin.entry()->isExclusive()); @@ -1055,7 +1282,7 @@ TEST_P(AsyncDataCacheTest, staleEntry) { ASSERT_EQ(stats.numEntries, 1); ASSERT_EQ(stats.numHit, 0); - auto validPin = cache_->findOrCreate(key, size, &wait); + auto validPin = cache_->findOrCreate(key, size, /*contiguous=*/false, &wait); ASSERT_FALSE(validPin.empty()); ASSERT_TRUE(wait.isReady()); ASSERT_FALSE(validPin.entry()->isExclusive()); @@ -1065,7 +1292,8 @@ TEST_P(AsyncDataCacheTest, staleEntry) { ASSERT_EQ(stats.numHit, 1); // Stale cache access with large cache size. - auto stalePin = cache_->findOrCreate(key, 2 * size, &wait); + auto stalePin = + cache_->findOrCreate(key, 2 * size, /*contiguous=*/false, &wait); ASSERT_FALSE(stalePin.empty()); ASSERT_TRUE(wait.isReady()); ASSERT_TRUE(stalePin.entry()->isExclusive()); @@ -1125,26 +1353,26 @@ TEST_P(AsyncDataCacheTest, shrinkCache) { for (int i = 0; i < numEntries; ++i) { auto tinyPin = cache_->findOrCreate(tinyCacheKeys[i], kTinyDataSize); ASSERT_FALSE(tinyPin.empty()); - ASSERT_TRUE(tinyPin.entry()->tinyData() != nullptr); - ASSERT_TRUE(tinyPin.entry()->data().empty()); + ASSERT_TRUE(tinyPin.entry()->hasContiguousData()); + ASSERT_TRUE(tinyPin.entry()->nonContiguousData().empty()); ASSERT_FALSE(tinyPin.entry()->isPrefetch()); ASSERT_FALSE(tinyPin.entry()->ssdSaveable()); pins.push_back(std::move(tinyPin)); auto largePin = cache_->findOrCreate(largeCacheKeys[i], kLargeDataSize); - ASSERT_FALSE(largePin.entry()->tinyData() != nullptr); - ASSERT_FALSE(largePin.entry()->data().empty()); + ASSERT_FALSE(largePin.entry()->hasContiguousData()); + ASSERT_FALSE(largePin.entry()->nonContiguousData().empty()); ASSERT_FALSE(largePin.entry()->isPrefetch()); ASSERT_FALSE(largePin.entry()->ssdSaveable()); pins.push_back(std::move(largePin)); } auto stats = cache_->refreshStats(); - ASSERT_EQ(stats.numEntries, numEntries * 2); + ASSERT_EQ(stats.numEntries, 0); ASSERT_EQ(stats.numEmptyEntries, 0); ASSERT_EQ(stats.numExclusive, numEntries * 2); ASSERT_EQ(stats.numEvict, 0); ASSERT_EQ(stats.numHit, 0); - ASSERT_EQ(stats.tinySize, kTinyDataSize * numEntries); - ASSERT_EQ(stats.largeSize, kLargeDataSize * numEntries); + ASSERT_EQ(stats.tinySize, 0); + ASSERT_EQ(stats.largeSize, 0); ASSERT_EQ(stats.sharedPinnedBytes, 0); ASSERT_GE( stats.exclusivePinnedBytes, @@ -1402,9 +1630,10 @@ TEST_P(AsyncDataCacheTest, makeEvictable) { std::vector keys; keys.reserve(cachePins.size()); for (const auto& pin : cachePins) { - keys.push_back(RawFileCacheKey{ - pin.checkedEntry()->key().fileNum.id(), - pin.checkedEntry()->key().offset}); + keys.push_back( + RawFileCacheKey{ + pin.checkedEntry()->key().fileNum.id(), + pin.checkedEntry()->key().offset}); } cachePins.clear(); for (const auto& key : keys) { @@ -1414,7 +1643,7 @@ TEST_P(AsyncDataCacheTest, makeEvictable) { const auto cacheEntries = asyncDataCacheHelper_->cacheEntries(); for (const auto& cacheEntry : cacheEntries) { const auto cacheEntryHelper = - test::AsyncDataCacheEntryTestHelper(cacheEntry); + cache::test::AsyncDataCacheEntryTestHelper(cacheEntry); ASSERT_EQ(cacheEntry->ssdSaveable(), !evictable); ASSERT_EQ(cacheEntryHelper.accessStats().numUses, 0); if (evictable) { @@ -1495,6 +1724,65 @@ TEST_P(AsyncDataCacheTest, ssdWriteOptions) { } } +TEST_P(AsyncDataCacheTest, ssdFlushThresholdBytes) { + constexpr uint64_t kRamBytes = 16UL << 20; // 16 MB + constexpr uint64_t kSsdBytes = 64UL << 20; // 64 MB + + struct { + double maxWriteRatio; + double ssdSavableRatio; + int32_t minSsdSavableBytes; + uint64_t ssdFlushThresholdBytes; + bool expectedSaveToSsd; + + std::string debugString() const { + return fmt::format( + "maxWriteRatio {}, ssdSavableRatio {}, minSsdSavableBytes {}, ssdFlushThresholdBytes {}, expectedSaveToSsd {}", + maxWriteRatio, + ssdSavableRatio, + minSsdSavableBytes, + ssdFlushThresholdBytes, + expectedSaveToSsd); + } + } testSettings[] = { + // Ratio-based threshold not met, ssdFlushThresholdBytes disabled (0). + // No flush expected. + {0.8, 0.95, 32 << 20, 0, false}, + // Ratio-based threshold not met, but ssdFlushThresholdBytes is small + // (1MB). + // Flush expected due to absolute threshold. + {0.8, 0.95, 32 << 20, 1UL << 20, true}, + // Ratio-based threshold met. ssdFlushThresholdBytes disabled. + // Flush expected due to ratio. + {0.8, 0.3, 4 << 20, 0, true}, + // Both thresholds could trigger. Flush expected. + {0.8, 0.3, 4 << 20, 1UL << 20, true}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + initializeCache( + kRamBytes, + kSsdBytes, + 0, + true, + AsyncDataCache::Options( + testData.maxWriteRatio, + testData.ssdSavableRatio, + testData.minSsdSavableBytes, + AsyncDataCache::kDefaultNumShards, + testData.ssdFlushThresholdBytes)); + // Load data half of the in-memory capacity. + loadLoop(0, kRamBytes / 2); + waitForPendingLoads(); + auto stats = cache_->refreshStats(); + if (testData.expectedSaveToSsd) { + EXPECT_GT(stats.ssdStats->entriesWritten, 0); + } else { + EXPECT_EQ(stats.ssdStats->entriesWritten, 0); + } + } +} + TEST_P(AsyncDataCacheTest, appendSsdSaveable) { constexpr uint64_t kRamBytes = 64UL << 20; // 64 MB constexpr uint64_t kSsdBytes = 128UL << 20; // 128 MB @@ -1572,6 +1860,810 @@ TEST_P(AsyncDataCacheTest, checkpoint) { // TODO: add concurrent fuzzer test. +TEST_P(AsyncDataCacheTest, numShardsDefault) { + constexpr uint64_t kRamBytes = 16UL << 20; + + initializeCache(kRamBytes); + ASSERT_EQ( + asyncDataCacheHelper_->numShards(), AsyncDataCache::kDefaultNumShards); +} + +TEST_P(AsyncDataCacheTest, numShardsInvalid) { + constexpr uint64_t kRamBytes = 16UL << 20; + + // Non-power-of-2 should fail. + for (int32_t numShards : {3, 5, 6, 7, 9, 10}) { + AsyncDataCache::Options options; + options.numShards = numShards; + VELOX_ASSERT_THROW( + initializeCache(kRamBytes, 0, 0, false, options), + "numShards must be a power of 2"); + } + + // Zero should fail. + { + AsyncDataCache::Options options; + options.numShards = 0; + VELOX_ASSERT_THROW( + initializeCache(kRamBytes, 0, 0, false, options), + "numShards must be positive"); + } + + // Negative should fail. + { + AsyncDataCache::Options options; + options.numShards = -1; + VELOX_ASSERT_THROW( + initializeCache(kRamBytes, 0, 0, false, options), + "numShards must be positive"); + } +} + +TEST_P(AsyncDataCacheTest, findMiss) { + constexpr int64_t kRamBytes = 32 << 20; + initializeMemoryManager(kRamBytes); + initializeCache(kRamBytes); + + RawFileCacheKey key{filenames_[0].id(), 0}; + auto result = cache_->find(key); + ASSERT_FALSE(result.has_value()); +} + +TEST_P(AsyncDataCacheTest, findHit) { + constexpr int64_t kRamBytes = 32 << 20; + constexpr int32_t kEntrySize = 4096; + initializeMemoryManager(kRamBytes); + initializeCache(kRamBytes); + + RawFileCacheKey key{filenames_[0].id(), 1000}; + + // Populate the entry via findOrCreate. + { + auto pin = cache_->findOrCreate(key, kEntrySize); + ASSERT_FALSE(pin.empty()); + ASSERT_TRUE(pin.entry()->isExclusive()); + initializeContents( + key.offset + key.fileNum, pin.entry()->nonContiguousData()); + pin.entry()->setExclusiveToShared(); + } + + // find should return a shared pin with correct data. + auto result = cache_->find(key); + ASSERT_TRUE(result.has_value()); + ASSERT_FALSE(result->empty()); + auto* entry = result->checkedEntry(); + ASSERT_TRUE(entry->isShared()); + ASSERT_EQ(entry->size(), kEntrySize); + checkContents(*entry); +} + +TEST_P(AsyncDataCacheTest, findExclusiveWithWait) { + constexpr int64_t kRamBytes = 32 << 20; + constexpr int32_t kEntrySize = 4096; + initializeMemoryManager(kRamBytes); + initializeCache(kRamBytes); + + RawFileCacheKey key{filenames_[0].id(), 2000}; + + // Create an exclusive entry. + auto exclusivePin = cache_->findOrCreate(key, kEntrySize); + ASSERT_FALSE(exclusivePin.empty()); + ASSERT_TRUE(exclusivePin.entry()->isExclusive()); + + // find without wait returns empty pin (entry exists but is exclusive). + { + auto result = cache_->find(key); + ASSERT_TRUE(result.has_value()); + ASSERT_TRUE(result->empty()); + } + + // find with wait returns empty pin and sets a future. + folly::SemiFuture waitFuture(false); + { + auto result = cache_->find(key, &waitFuture); + ASSERT_TRUE(result.has_value()); + ASSERT_TRUE(result->empty()); + } + + // The future should not be ready while the entry is exclusive. + ASSERT_FALSE(waitFuture.isReady()); + + // Timed wait should time out while the entry is still exclusive. + { + auto waitCopy = std::move(waitFuture); + auto timedResult = + std::move(waitCopy).via(&folly::QueuedImmediateExecutor::instance()); + ASSERT_FALSE( + std::move(timedResult).wait(std::chrono::seconds(1)).isReady()); + } + + // Re-issue find with wait after the timed-out future was consumed. + waitFuture = folly::SemiFuture(false); + { + auto result = cache_->find(key, &waitFuture); + ASSERT_TRUE(result.has_value()); + ASSERT_TRUE(result->empty()); + } + ASSERT_FALSE(waitFuture.isReady()); + + // Transition to shared makes the future ready. + initializeContents( + key.offset + key.fileNum, exclusivePin.entry()->nonContiguousData()); + exclusivePin.entry()->setExclusiveToShared(); + exclusivePin.clear(); + + auto& exec = folly::QueuedImmediateExecutor::instance(); + ASSERT_TRUE(std::move(waitFuture).via(&exec).wait().isReady()); + + // Now find should return a shared pin. + auto result = cache_->find(key); + ASSERT_TRUE(result.has_value()); + ASSERT_FALSE(result->empty()); + checkContents(*result->checkedEntry()); +} + +TEST_P(AsyncDataCacheTest, fuzz) { + constexpr int64_t kRamBytes = 64 << 20; + constexpr int32_t kNumThreads = 8; + constexpr int32_t kNumFiles = 10; + constexpr int32_t kNumOffsets = 20; + constexpr int32_t kEntrySize = 4096; + constexpr int32_t kTestDurationMs = 10'000; + + initializeMemoryManager(kRamBytes); + initializeCache(kRamBytes); + + std::atomic_bool stop{false}; + + // Worker threads: findOrCreate/find entries, verify content. + auto workerFunc = [&](int32_t threadId) { + std::mt19937 rng(threadId); + while (!stop.load(std::memory_order_relaxed)) { + const auto fileIdx = rng() % kNumFiles; + const auto offsetIdx = rng() % kNumOffsets; + const uint64_t offset = offsetIdx * kEntrySize; + RawFileCacheKey key{filenames_[fileIdx].id(), offset}; + + // Randomly choose between find and findOrCreate. + if (rng() % 3 == 0) { + // find: lookup only. + auto result = cache_->find(key); + if (result.has_value() && !result->empty()) { + checkContents(*result->checkedEntry()); + } + } else { + // findOrCreate: populate if new. + folly::SemiFuture waitFuture(false); + auto pin = cache_->findOrCreate( + key, kEntrySize, /*contiguous=*/false, &waitFuture); + if (pin.empty()) { + auto& exec = folly::QueuedImmediateExecutor::instance(); + std::move(waitFuture).via(&exec).wait(); + continue; + } + auto* entry = pin.checkedEntry(); + if (entry->isExclusive()) { + initializeContents( + key.offset + key.fileNum, entry->nonContiguousData()); + entry->setExclusiveToShared(); + } + checkContents(*entry); + } + } + }; + + // Eviction thread: periodically remove a subset of files. + auto evictFunc = [&]() { + std::mt19937 rng(kNumThreads + 1); + while (!stop.load(std::memory_order_relaxed)) { + std::this_thread::sleep_for(std::chrono::milliseconds(50)); // NOLINT + folly::F14FastSet filesToRemove; + // Remove a random subset of files. + const auto numToRemove = (rng() % kNumFiles) + 1; + for (uint32_t i = 0; i < numToRemove; ++i) { + filesToRemove.insert(filenames_[rng() % kNumFiles].id()); + } + folly::F14FastSet filesRetained; + cache_->removeFileEntries(filesToRemove, filesRetained); + } + }; + + std::vector threads; + for (int32_t i = 0; i < kNumThreads; ++i) { + threads.emplace_back(workerFunc, i); + } + threads.emplace_back(evictFunc); + + std::this_thread::sleep_for( + std::chrono::milliseconds(kTestDurationMs)); // NOLINT + stop.store(true, std::memory_order_relaxed); + + for (auto& thread : threads) { + thread.join(); + } + + auto stats = cache_->refreshStats(); + LOG(INFO) << "fuzz stats: " << stats.numEntries << " entries, " + << stats.numHit << " hits, " << stats.numNew << " new, " + << stats.numEvict << " evicts"; +} + +TEST_P(AsyncDataCacheTest, dataRanges) { + constexpr uint64_t kRamBytes = 64UL << 20; + initializeCache(kRamBytes); + + struct TestParam { + int32_t size; + std::string debugString() const { + return fmt::format("size {}", size); + } + }; + + std::vector testSettings = { + // Tiny entry (< kTinyDataSize). + {AsyncDataCacheEntry::kTinyDataSize - 1}, + // Allocation-backed entry (>= kTinyDataSize). + {AsyncDataCacheEntry::kTinyDataSize * 4}, + }; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + auto pin = newEntry(testData.size, testData.size); + ASSERT_FALSE(pin.empty()); + auto* entry = pin.checkedEntry(); + ASSERT_TRUE(entry->isExclusive()); + + auto ranges = entry->dataRanges(testData.size); + ASSERT_FALSE(ranges.empty()); + + // Verify total bytes across all ranges covers the entry size. + uint64_t totalBytes = 0; + for (const auto& range : ranges) { + ASSERT_GT(range.size(), 0); + totalBytes += range.size(); + } + ASSERT_EQ(totalBytes, testData.size); + + // Verify ranges are writable by writing a pattern and reading it back. + uint8_t pattern = 0; + for (auto& range : ranges) { + ::memset(range.data(), pattern, range.size()); + ++pattern; + } + pattern = 0; + for (const auto& range : ranges) { + for (size_t i = 0; i < range.size(); ++i) { + ASSERT_EQ(static_cast(range.data()[i]), pattern); + } + ++pattern; + } + + if (testData.size < AsyncDataCacheEntry::kTinyDataSize) { + // Tiny entry: single range backed by tinyData. + ASSERT_EQ(ranges.size(), 1); + ASSERT_TRUE(entry->hasContiguousData()); + ASSERT_EQ(ranges[0].data(), entry->contiguousData()); + } else { + // Allocation-backed: one range per run. + ASSERT_EQ(ranges.size(), entry->nonContiguousData().numRuns()); + } + + entry->setExclusiveToShared(); + } +} + +TEST_P(AsyncDataCacheTest, acquiredMemory) { + constexpr uint64_t kRamBytes = 64UL << 20; + initializeCache(kRamBytes); + auto* allocator = cache_->allocator(); + + // Default is empty with zero bytes. + { + AcquiredMemory acquired; + ASSERT_TRUE(acquired.empty()); + ASSERT_EQ(acquired.totalBytes(), 0); + acquired.free(allocator); + ASSERT_TRUE(acquired.empty()); + } + + // Non-contiguous only. + { + AcquiredMemory acquired; + memory::Allocation allocation; + ASSERT_TRUE(allocator->allocateNonContiguous(10, allocation)); + const auto expectedBytes = allocation.byteSize(); + acquired.nonContiguousAllocs.appendMove(allocation); + + ASSERT_FALSE(acquired.empty()); + ASSERT_EQ(acquired.totalBytes(), expectedBytes); + + acquired.free(allocator); + ASSERT_TRUE(acquired.empty()); + ASSERT_EQ(acquired.totalBytes(), 0); + } + + // Byte allocations only. + { + AcquiredMemory acquired; + auto* ptr1 = allocator->allocateBytes(1'024); + auto* ptr2 = allocator->allocateBytes(2'048); + ASSERT_NE(ptr1, nullptr); + ASSERT_NE(ptr2, nullptr); + acquired.byteAllocations.emplace_back(ptr1, 1'024); + acquired.byteAllocations.emplace_back(ptr2, 2'048); + + ASSERT_FALSE(acquired.empty()); + ASSERT_EQ(acquired.totalBytes(), 3'072); + + acquired.free(allocator); + ASSERT_TRUE(acquired.empty()); + ASSERT_EQ(acquired.totalBytes(), 0); + } + + // Mixed non-contiguous and byte allocations. + { + AcquiredMemory acquired; + memory::Allocation allocation; + ASSERT_TRUE(allocator->allocateNonContiguous(10, allocation)); + const auto nonContiguousBytes = allocation.byteSize(); + acquired.nonContiguousAllocs.appendMove(allocation); + + constexpr uint64_t kByteAllocSize = 4'096; + auto* ptr = allocator->allocateBytes(kByteAllocSize); + ASSERT_NE(ptr, nullptr); + acquired.byteAllocations.emplace_back(ptr, kByteAllocSize); + + ASSERT_FALSE(acquired.empty()); + ASSERT_EQ(acquired.totalBytes(), nonContiguousBytes + kByteAllocSize); + + acquired.free(allocator); + ASSERT_TRUE(acquired.empty()); + ASSERT_EQ(acquired.totalBytes(), 0); + } +} + +TEST_P(AsyncDataCacheTest, eviction) { + constexpr uint64_t kRamBytes = 64UL << 20; + initializeCache(kRamBytes); + + constexpr int32_t kEntrySize = 64 * 1'024; + constexpr int32_t kNumEntries = 32; + + enum class EntryType { kContiguous, kNonContiguous, kMixed }; + + struct TestParam { + EntryType entryType; + std::string debugString() const { + switch (entryType) { + case EntryType::kContiguous: + return "contiguous"; + case EntryType::kNonContiguous: + return "nonContiguous"; + case EntryType::kMixed: + return "mixed"; + } + VELOX_UNREACHABLE(); + } + }; + + std::vector testSettings = { + {EntryType::kContiguous}, + {EntryType::kNonContiguous}, + {EntryType::kMixed}, + }; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + std::vector pins; + for (int i = 0; i < kNumEntries; ++i) { + bool contiguous = false; + switch (testData.entryType) { + case EntryType::kContiguous: + contiguous = true; + break; + case EntryType::kNonContiguous: + contiguous = false; + break; + case EntryType::kMixed: + contiguous = (i % 2 == 0); + break; + } + RawFileCacheKey key{ + filenames_[0].id(), static_cast(i) * kEntrySize}; + auto pin = cache_->findOrCreate(key, kEntrySize, contiguous); + ASSERT_FALSE(pin.empty()); + auto* entry = pin.checkedEntry(); + ASSERT_TRUE(entry->isExclusive()); + if (contiguous) { + ASSERT_TRUE(entry->hasContiguousData()); + ::memset(entry->contiguousData(), static_cast(i), kEntrySize); + } + entry->setExclusiveToShared(); + pins.push_back(std::move(pin)); + } + ASSERT_EQ(pins.size(), kNumEntries); + + auto statsBefore = cache_->refreshStats(); + ASSERT_EQ(statsBefore.numEntries, kNumEntries); + const auto evictsBefore = statsBefore.numEvict; + + pins.clear(); + + auto freed = cache_->shrink(kRamBytes); + ASSERT_GT(freed, 0); + + auto statsAfter = cache_->refreshStats(); + ASSERT_EQ(statsAfter.numEntries, 0); + ASSERT_GE(statsAfter.numEvict - evictsBefore, kNumEntries); + } +} + +TEST_P(AsyncDataCacheTest, retryAllocation) { + constexpr uint64_t kRamBytes = 64UL << 20; + initializeCache(kRamBytes); + + constexpr int32_t kEntrySize = 64 * 1'024; + constexpr int32_t kNumEntries = 512; + + enum class EntryType { kContiguous, kNonContiguous, kMixed }; + + struct TestParam { + EntryType entryType; + std::string debugString() const { + switch (entryType) { + case EntryType::kContiguous: + return "contiguous"; + case EntryType::kNonContiguous: + return "nonContiguous"; + case EntryType::kMixed: + return "mixed"; + } + VELOX_UNREACHABLE(); + } + }; + + std::vector testSettings = { + {EntryType::kContiguous}, + {EntryType::kNonContiguous}, + {EntryType::kMixed}, + }; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + // Fill the cache with entries. + std::vector pins; + for (int i = 0; i < kNumEntries; ++i) { + bool contiguous = false; + switch (testData.entryType) { + case EntryType::kContiguous: + contiguous = true; + break; + case EntryType::kNonContiguous: + contiguous = false; + break; + case EntryType::kMixed: + contiguous = (i % 2 == 0); + break; + } + RawFileCacheKey key{ + filenames_[0].id(), static_cast(i) * kEntrySize}; + auto pin = cache_->findOrCreate(key, kEntrySize, contiguous); + ASSERT_FALSE(pin.empty()); + auto* entry = pin.checkedEntry(); + if (contiguous) { + ::memset(entry->contiguousData(), 0xAB, kEntrySize); + } + entry->setExclusiveToShared(); + pins.push_back(std::move(pin)); + } + + // Unpin so entries are evictable. + pins.clear(); + + auto statsBefore = cache_->refreshStats(); + ASSERT_EQ(statsBefore.numEntries, kNumEntries); + + auto* allocator = cache_->allocator(); + + // Allocate non-contiguous pages through the allocator directly. + // Request more than what's free to force eviction via makeSpace. + constexpr uint64_t kAllocBytes = 48UL << 20; + memory::Allocation allocation; + ASSERT_TRUE(allocator->allocateNonContiguous( + memory::AllocationTraits::numPages(kAllocBytes), allocation)); + + auto statsAfter = cache_->refreshStats(); + ASSERT_GT(statsAfter.numEvict, statsBefore.numEvict); + + allocator->freeNonContiguous(allocation); + + cache_->clear(); + auto statsFinal = cache_->refreshStats(); + ASSERT_EQ(statsFinal.numEntries, 0); + } +} + +TEST_P(AsyncDataCacheTest, makePins) { + constexpr uint64_t kRamBytes = 64UL << 20; + initializeCache(kRamBytes); + + constexpr int32_t kEntrySize = 8'192; + + struct TestParam { + int numEntries; + bool contiguous; + std::string debugString() const { + return fmt::format( + "numEntries {}, contiguous {}", numEntries, contiguous); + } + }; + + std::vector testSettings = { + {1, false}, + {1, true}, + {10, false}, + {10, true}, + {100, false}, + {100, true}, + }; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + std::vector keys; + keys.reserve(testData.numEntries); + for (int i = 0; i < testData.numEntries; ++i) { + keys.push_back( + {filenames_[0].id(), static_cast(i) * kEntrySize}); + } + + std::vector pins; + cache_->makePins( + keys, + [&](size_t /*i*/) { return kEntrySize; }, + [&](size_t /*i*/, CachePin&& pin) { pins.push_back(std::move(pin)); }, + testData.contiguous); + + ASSERT_EQ(pins.size(), testData.numEntries); + for (auto& pin : pins) { + auto* entry = pin.checkedEntry(); + ASSERT_TRUE(entry->isExclusive()); + test::AsyncDataCacheEntryTestHelper entryHelper(entry); + if (testData.contiguous) { + ASSERT_TRUE(entryHelper.isContiguousData()); + ASSERT_TRUE(entry->hasContiguousData()); + ASSERT_EQ(entry->contiguousDataSize(), kEntrySize); + } else { + ASSERT_FALSE(entryHelper.isContiguousData()); + ASSERT_FALSE(entry->hasContiguousData()); + } + entry->setExclusiveToShared(); + } + pins.clear(); + cache_->clear(); + } +} + +TEST_P(AsyncDataCacheTest, removeFileEntries) { + constexpr uint64_t kRamBytes = 64UL << 20; + initializeCache(kRamBytes); + + constexpr int32_t kEntrySize = 8'192; + constexpr int kNumEntries = 10; + + struct TestParam { + bool contiguous; + std::string debugString() const { + return fmt::format("contiguous {}", contiguous); + } + }; + + std::vector testSettings = {{false}, {true}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + for (int i = 0; i < kNumEntries; ++i) { + RawFileCacheKey key{ + filenames_[0].id(), static_cast(i) * kEntrySize}; + auto pin = cache_->findOrCreate(key, kEntrySize, testData.contiguous); + ASSERT_FALSE(pin.empty()); + auto* entry = pin.checkedEntry(); + if (testData.contiguous) { + ::memset(entry->contiguousData(), 0xCD, kEntrySize); + } + entry->setExclusiveToShared(); + } + + auto statsBefore = cache_->refreshStats(); + ASSERT_EQ(statsBefore.numEntries, kNumEntries); + ASSERT_GT(statsBefore.largeSize, 0); + + folly::F14FastSet filesToRemove; + filesToRemove.insert(filenames_[0].id()); + folly::F14FastSet filesRetained; + cache_->removeFileEntries(filesToRemove, filesRetained); + ASSERT_TRUE(filesRetained.empty()); + + auto statsAfter = cache_->refreshStats(); + ASSERT_EQ(statsAfter.numEntries, 0); + ASSERT_EQ(statsAfter.largeSize, 0); + ASSERT_EQ(statsAfter.tinySize, 0); + } +} + +TEST_P(AsyncDataCacheTest, mixedBufferFuzz) { + constexpr uint64_t kRamBytes = 256UL << 20; + constexpr int32_t kNumThreads = 8; + constexpr int32_t kNumKeys = 200; + constexpr int32_t kTestDurationMs = 20'000; + + initializeCache(kRamBytes); + + // Per-key properties are deterministic so all threads agree. + struct KeyProps { + int64_t size; + bool contiguous; + }; + std::vector keyProps(kNumKeys); + { + std::mt19937 rng(42); + for (int i = 0; i < kNumKeys; ++i) { + switch (i % 3) { + case 0: + keyProps[i].size = 512 + (rng() % 1'024); + break; + case 1: + keyProps[i].size = + AsyncDataCacheEntry::kTinyDataSize + (rng() % (64 * 1'024)); + break; + default: + // Large entries ensure total data exceeds cache capacity, + // triggering eviction. + keyProps[i].size = 1'024 * 1'024 + (rng() % (3 * 1'024 * 1'024)); + break; + } + keyProps[i].contiguous = (i % 2 == 0); + } + } + + struct EntryState { + uint8_t pattern{0}; + }; + + std::mutex stateMutex; + std::unordered_map entryStates; + + std::atomic_bool stop{false}; + + auto workerFunc = [&](int32_t threadId) { + std::mt19937 rng(threadId * 7 + 13); + while (!stop.load(std::memory_order_relaxed)) { + const auto keyIdx = rng() % kNumKeys; + const uint64_t offset = keyIdx * 128 * 1'024; + RawFileCacheKey key{filenames_[0].id(), offset}; + + const auto& props = keyProps[keyIdx]; + const uint8_t pattern = static_cast(rng()); + + folly::SemiFuture waitFuture(false); + auto pin = + cache_->findOrCreate(key, props.size, props.contiguous, &waitFuture); + if (pin.empty()) { + continue; + } + auto* entry = pin.checkedEntry(); + ASSERT_EQ(entry->size(), props.size); + + if (entry->isExclusive()) { + if (entry->hasContiguousData()) { + ::memset(entry->contiguousData(), pattern, entry->size()); + } else { + for (auto& range : entry->dataRanges(entry->size())) { + ::memset(range.data(), pattern, range.size()); + } + } + { + // Record state before making shared so other threads always + // find the state when they get a shared pin. + std::lock_guard l(stateMutex); + entryStates[offset] = {pattern}; + } + entry->setExclusiveToShared(); + } else { + EntryState expected; + { + std::lock_guard l(stateMutex); + auto it = entryStates.find(offset); + ASSERT_NE(it, entryStates.end()); + expected = it->second; + } + + if (entry->hasContiguousData()) { + const auto* data = entry->contiguousData(); + for (int i = 0; i < entry->size(); ++i) { + ASSERT_EQ(static_cast(data[i]), expected.pattern) + << "Data mismatch at offset " << offset << " byte " << i; + } + } else { + int byteIdx = 0; + for (const auto& range : entry->dataRanges(entry->size())) { + for (size_t i = 0; i < range.size(); ++i) { + ASSERT_EQ(static_cast(range.data()[i]), expected.pattern) + << "Data mismatch at offset " << offset << " byte " + << byteIdx; + ++byteIdx; + } + } + } + } + // Release pin immediately so eviction can reclaim it. + } + }; + + std::vector threads; + for (int32_t i = 0; i < kNumThreads; ++i) { + threads.emplace_back(workerFunc, i); + } + + std::this_thread::sleep_for( + std::chrono::milliseconds(kTestDurationMs)); // NOLINT + stop.store(true, std::memory_order_relaxed); + + for (auto& thread : threads) { + thread.join(); + } + + auto stats = cache_->refreshStats(); + LOG(INFO) << "contiguousFuzz stats: " << stats.numEntries << " entries, " + << stats.numHit << " hits, " << stats.numNew << " new, " + << stats.numEvict << " evicts"; + + // Verify remaining entries have correct data. + int32_t verified = 0; + int32_t evicted = 0; + { + std::lock_guard l(stateMutex); + for (const auto& [offset, expected] : entryStates) { + RawFileCacheKey key{filenames_[0].id(), offset}; + auto result = cache_->find(key); + if (!result.has_value()) { + ++evicted; + continue; + } + ASSERT_FALSE(result->empty()); + auto* entry = result->checkedEntry(); + // Verify the data pattern. + if (entry->hasContiguousData()) { + const auto* data = entry->contiguousData(); + for (int i = 0; i < entry->size(); ++i) { + ASSERT_EQ(static_cast(data[i]), expected.pattern); + } + } else { + for (const auto& range : entry->dataRanges(entry->size())) { + for (size_t i = 0; i < range.size(); ++i) { + ASSERT_EQ(static_cast(range.data()[i]), expected.pattern); + } + } + } + ++verified; + } + } + + LOG(INFO) << "Verified " << verified << " entries, " << evicted + << " evicted from tracked state"; + if (evicted > 0) { + ASSERT_GT(stats.numEvict, 0); + } + + cache_->clear(); + auto finalStats = cache_->refreshStats(); + ASSERT_EQ(finalStats.numEntries, 0); +} + INSTANTIATE_TEST_SUITE_P( AsyncDataCacheTest, AsyncDataCacheTest, diff --git a/velox/common/caching/tests/CMakeLists.txt b/velox/common/caching/tests/CMakeLists.txt index c80dacb3396..2124e9cb7d1 100644 --- a/velox/common/caching/tests/CMakeLists.txt +++ b/velox/common/caching/tests/CMakeLists.txt @@ -19,28 +19,34 @@ target_link_libraries( PRIVATE velox_common_base Folly::folly velox_time glog::glog GTest::gtest GTest::gtest_main ) -add_executable( - velox_cache_test +# Split velox_cache_test into individual test binaries for parallel execution. +set( + VELOX_CACHE_TEST_SOURCES AsyncDataCacheTest.cpp CacheTTLControllerTest.cpp SsdFileTest.cpp SsdFileTrackerTest.cpp StringIdMapTest.cpp ) -add_test(velox_cache_test velox_cache_test) -target_link_libraries( - velox_cache_test - PRIVATE - velox_caching - velox_file - velox_file_test_utils - velox_memory - velox_temp_path - velox_flag_definitions - Folly::folly - glog::glog - GTest::gtest - GTest::gtest_main + +set( + VELOX_CACHE_TEST_DEPS + velox_caching + velox_file + velox_file_test_utils + velox_memory + velox_test_util + velox_flag_definitions + Folly::folly + glog::glog + GTest::gtest + GTest::gtest_main +) + +velox_add_grouped_tests( + PREFIX velox_cache_test + SOURCES ${VELOX_CACHE_TEST_SOURCES} + DEPS ${VELOX_CACHE_TEST_DEPS} ) add_executable(cached_factory_test CachedFactoryTest.cpp) @@ -49,3 +55,5 @@ target_link_libraries( cached_factory_test PRIVATE velox_common_base Folly::folly velox_time glog::glog GTest::gtest GTest::gtest_main ) + +velox_add_library(velox_caching_test_util INTERFACE HEADERS CacheTestUtil.h) diff --git a/velox/common/caching/tests/CacheTTLControllerTest.cpp b/velox/common/caching/tests/CacheTTLControllerTest.cpp index 6f7f39454af..dad3c0ed504 100644 --- a/velox/common/caching/tests/CacheTTLControllerTest.cpp +++ b/velox/common/caching/tests/CacheTTLControllerTest.cpp @@ -31,7 +31,7 @@ class CacheTTLControllerTest : public ::testing::Test { protected: void SetUp() override { allocator_ = std::make_shared( - MmapAllocator::Options{.capacity = 1024L * 1024L}); + MemoryAllocator::Options{.capacity = 1024L * 1024L}); cache_ = AsyncDataCache::create(allocator_.get()); } diff --git a/velox/common/caching/tests/CacheTestUtil.h b/velox/common/caching/tests/CacheTestUtil.h index 731da4e876a..d848f1007c9 100644 --- a/velox/common/caching/tests/CacheTestUtil.h +++ b/velox/common/caching/tests/CacheTestUtil.h @@ -59,6 +59,18 @@ class SsdFileTestHelper { return ssdFile_->checksumReadVerificationEnabled_; } + SsdFile::State state() const { + return ssdFile_->state_.load(); + } + + void setState(SsdFile::State state) { + ssdFile_->state_ = state; + } + + int32_t maxRegions() const { + return ssdFile_->maxRegions_; + } + /// Deletes the backing file. void deleteFile() { process::TraceContext trace("SsdFile::testingDeleteFile"); @@ -135,6 +147,14 @@ class AsyncDataCacheEntryTestHelper { return asyncDataCacheEntry_->isFirstUse_; } + bool isTinyData() const { + return !asyncDataCacheEntry_->tinyData_.empty(); + } + + bool isContiguousData() const { + return asyncDataCacheEntry_->contiguousData_ != nullptr; + } + private: AsyncDataCacheEntry* const asyncDataCacheEntry_; }; diff --git a/velox/common/caching/tests/SsdFileTest.cpp b/velox/common/caching/tests/SsdFileTest.cpp index bf4f2a94fbf..c971ebc87ca 100644 --- a/velox/common/caching/tests/SsdFileTest.cpp +++ b/velox/common/caching/tests/SsdFileTest.cpp @@ -20,7 +20,7 @@ #include "velox/common/file/FileSystems.h" #include "velox/common/file/tests/FaultyFileSystem.h" #include "velox/common/memory/Memory.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include #include @@ -31,6 +31,7 @@ using namespace facebook::velox; using namespace facebook::velox::cache; +using namespace facebook::velox::common::testutil; using namespace facebook::velox::tests::utils; using facebook::velox::memory::MemoryAllocator; @@ -79,10 +80,9 @@ class SsdFileTest : public testing::Test { FLAGS_velox_ssd_odirect = false; cache_ = AsyncDataCache::create(memory::memoryManager()->allocator()); cacheHelper_ = - std::make_unique(cache_.get()); + std::make_unique(cache_.get()); fileName_ = StringIdLease(fileIds(), "fileInStorage"); - tempDirectory_ = - exec::test::TempDirectoryPath::create(enableFaultInjection); + tempDirectory_ = TempDirectoryPath::create(enableFaultInjection); initializeSsdFile( ssdBytes, checkpointIntervalBytes, @@ -96,7 +96,8 @@ class SsdFileTest : public testing::Test { uint64_t checkpointIntervalBytes = 0, bool checksumEnabled = false, bool checksumReadVerificationEnabled = false, - bool disableFileCow = false) { + bool disableFileCow = false, + uint64_t maxEntries = 0) { SsdFile::Config config( fmt::format("{}/ssdtest", tempDirectory_->getPath()), 0, // shardId @@ -105,11 +106,12 @@ class SsdFileTest : public testing::Test { disableFileCow, checksumEnabled, checksumReadVerificationEnabled, + maxEntries, ssdExecutor()); ssdFile_ = std::make_unique(config); if (ssdFile_ != nullptr) { ssdFileHelper_ = - std::make_unique(ssdFile_.get()); + std::make_unique(ssdFile_.get()); } } @@ -197,12 +199,12 @@ class SsdFileTest : public testing::Test { std::vector pins; while (bytesFromCache < totalSize) { pins.push_back( - cache_->findOrCreate(RawFileCacheKey{fileId, offset}, size, nullptr)); + cache_->findOrCreate(RawFileCacheKey{fileId, offset}, size)); bytesFromCache += size; EXPECT_FALSE(pins.back().empty()); auto entry = pins.back().entry(); if (entry && entry->isExclusive()) { - initializeContents(fileId + offset, entry->data()); + initializeContents(fileId + offset, entry->nonContiguousData()); } offset += size; size *= 2; @@ -231,13 +233,14 @@ class SsdFileTest : public testing::Test { std::vector ssdPins; ssdPins.reserve(pins.size()); for (auto& pin : pins) { - ssdPins.push_back(ssdFile_->find(RawFileCacheKey{ - pin.entry()->key().fileNum.id(), pin.entry()->key().offset})); + ssdPins.push_back(ssdFile_->find( + RawFileCacheKey{ + pin.entry()->key().fileNum.id(), pin.entry()->key().offset})); EXPECT_FALSE(ssdPins.back().empty()); } ssdFile_->load(ssdPins, pins); for (auto& pin : pins) { - checkContents(pin.entry()->data(), pin.entry()->size()); + checkContents(pin.entry()->nonContiguousData(), pin.entry()->size()); } } @@ -273,7 +276,7 @@ class SsdFileTest : public testing::Test { EXPECT_FALSE( ssdFile_->find(RawFileCacheKey{fileName_.id(), ssdSize}).empty()); - pins.erase(pins.begin(), pins.begin() + numWritten); + pins.erase(pins.cbegin(), pins.cbegin() + numWritten); ssdFile_->write(pins); for (auto& pin : pins) { if (pin.entry()->ssdFile()) { @@ -299,8 +302,7 @@ class SsdFileTest : public testing::Test { std::vector pins; pins.push_back(cache_->findOrCreate( RawFileCacheKey{entry.key.fileNum.id(), entry.key.offset}, - entry.size, - nullptr)); + entry.size)); if (pins.back().entry()->isExclusive()) { std::vector ssdPins; ssdPins.push_back(ssdFile_->find( @@ -309,21 +311,23 @@ class SsdFileTest : public testing::Test { ++numFound; ssdFile_->load(ssdPins, pins); checkContents( - pins[0].entry()->data(), pins[0].entry()->size(), expectEqual); + pins[0].entry()->nonContiguousData(), + pins[0].entry()->size(), + expectEqual); } } } return numFound; } - std::shared_ptr tempDirectory_; + std::shared_ptr tempDirectory_; std::shared_ptr cache_; - std::unique_ptr cacheHelper_; + std::unique_ptr cacheHelper_; StringIdLease fileName_; std::unique_ptr ssdFile_; - std::unique_ptr ssdFileHelper_; + std::unique_ptr ssdFileHelper_; }; TEST_F(SsdFileTest, writeAndRead) { @@ -375,9 +379,7 @@ TEST_F(SsdFileTest, writeAndRead) { std::vector pins; pins.push_back(cache_->findOrCreate( - RawFileCacheKey{fileName_.id(), entry.key.offset}, - entry.size, - nullptr)); + RawFileCacheKey{fileName_.id(), entry.key.offset}, entry.size)); if (pins.back().entry()->isExclusive()) { std::vector ssdPins; @@ -385,7 +387,8 @@ TEST_F(SsdFileTest, writeAndRead) { ssdFile_->find(RawFileCacheKey{fileName_.id(), entry.key.offset})); if (!ssdPins.back().empty()) { ssdFile_->load(ssdPins, pins); - checkContents(pins[0].entry()->data(), pins[0].entry()->size()); + checkContents( + pins[0].entry()->nonContiguousData(), pins[0].entry()->size()); } } } @@ -415,9 +418,7 @@ TEST_F(SsdFileTest, checkpoint) { makePins(fileName_.id(), startOffset, 4096, 2048 * 1025, 62 * kMB); // Each region has one entry from `fileNameAlt`. pins.push_back(cache_->findOrCreate( - RawFileCacheKey{fileNameAlt.id(), (uint64_t)startOffset}, - 1024, - nullptr)); + RawFileCacheKey{fileNameAlt.id(), (uint64_t)startOffset}, 1024)); ssdFile_->write(pins); for (auto& pin : pins) { EXPECT_EQ(ssdFile_.get(), pin.entry()->ssdFile()); @@ -442,9 +443,7 @@ TEST_F(SsdFileTest, checkpoint) { auto pins = makePins(fileName_.id(), startOffset, 4096, 2048 * 1025, 62 * kMB); pins.push_back(cache_->findOrCreate( - RawFileCacheKey{fileNameAlt.id(), (uint64_t)startOffset}, - 1024, - nullptr)); + RawFileCacheKey{fileNameAlt.id(), (uint64_t)startOffset}, 1024)); readAndCheckPins(pins); } // All entries can be found. @@ -820,6 +819,110 @@ TEST_F(SsdFileTest, ssdReadWithoutChecksumCheck) { #endif } +TEST_F(SsdFileTest, writeInNoSpaceState) { + constexpr int64_t kSsdSize = 16 * SsdFile::kRegionSize; + initializeCache(kSsdSize); + + // Verify the initial state is kActive. + EXPECT_EQ(ssdFileHelper_->state(), SsdFile::State::kActive); + + // Verify the cache write is successful in the initial kActiveState. + auto pins = makePins(fileName_.id(), 0, 4096, 4096, 4096 * 10); + ssdFile_->write(pins); + SsdCacheStats statsBeforeNoSpace; + ssdFile_->updateStats(statsBeforeNoSpace); + EXPECT_GT(statsBeforeNoSpace.entriesWritten, 0); + EXPECT_EQ(statsBeforeNoSpace.writeSsdDropped, 0); + + // Set the state to kNoSpace to simulate the SSD running out of space. + ssdFileHelper_->setState(SsdFile::State::kNoSpace); + EXPECT_EQ(ssdFileHelper_->state(), SsdFile::State::kNoSpace); + + // Verify that writes are dropped and no new entries are written in the + // kNoSpace state. + auto morePins = makePins(fileName_.id(), 4096 * 10, 4096, 4096, 4096 * 5); + ssdFile_->write(morePins); + SsdCacheStats statsAfterNoSpace; + ssdFile_->updateStats(statsAfterNoSpace); + EXPECT_GT( + statsAfterNoSpace.writeSsdDropped, statsBeforeNoSpace.writeSsdDropped); + EXPECT_EQ( + statsAfterNoSpace.entriesWritten, statsBeforeNoSpace.entriesWritten); + + // Verify none of the new entries have ssdFile set. + for (const auto& pin : morePins) { + EXPECT_EQ(pin.entry()->ssdFile(), nullptr); + } +} + +TEST_F(SsdFileTest, checkpointInNoSpaceState) { + constexpr int64_t kSsdSize = 4 * SsdFile::kRegionSize; + const uint64_t checkpointIntervalBytes = 2 * SsdFile::kRegionSize; + initializeCache(kSsdSize, checkpointIntervalBytes); + + // Write some entries to trigger checkpoint eligibility. + auto pins = makePins(fileName_.id(), 0, 4096, 4096, 3 * SsdFile::kRegionSize); + ssdFile_->write(pins); + + SsdCacheStats statsBeforeNoSpace; + ssdFile_->updateStats(statsBeforeNoSpace); + + // Set the state to kNoSpace. + ssdFileHelper_->setState(SsdFile::State::kNoSpace); + EXPECT_EQ(ssdFileHelper_->state(), SsdFile::State::kNoSpace); + + // Verify checkpointing is skipped in the kNoSpace state. + ssdFile_->checkpoint(true); + + SsdCacheStats statsAfterNoSpace; + ssdFile_->updateStats(statsAfterNoSpace); + EXPECT_EQ( + statsAfterNoSpace.checkpointsWritten, + statsBeforeNoSpace.checkpointsWritten); +} + +TEST_F(SsdFileTest, growOrEvictBlockedInNoSpaceState) { + constexpr int64_t kSsdSize = 4 * SsdFile::kRegionSize; + const uint64_t checkpointIntervalBytes = kSsdSize; + initializeCache(kSsdSize, checkpointIntervalBytes); + + // Fill up the SSD cache to trigger eviction on subsequent writes. + for (auto startOffset = 0; startOffset <= kSsdSize - SsdFile::kRegionSize; + startOffset += SsdFile::kRegionSize) { + auto pins = makePins( + fileName_.id(), startOffset, 4096, 4096, SsdFile::kRegionSize - 1024); + ssdFile_->write(pins); + } + + // The SSD cache is at max regins. + SsdCacheStats statsBeforeNoSpace; + ssdFile_->updateStats(statsBeforeNoSpace); + EXPECT_EQ(statsBeforeNoSpace.regionsCached, ssdFileHelper_->maxRegions()); + + ssdFileHelper_->setState(SsdFile::State::kNoSpace); + EXPECT_EQ(ssdFileHelper_->state(), SsdFile::State::kNoSpace); + + // Verify the eviction should not happen since write was dropped before + // eviction. + auto newPins = makePins(fileName_.id(), kSsdSize * 2, 4096, 4096, 4096 * 5); + ssdFile_->write(newPins); + + SsdCacheStats statsAfterNoSpace; + ssdFile_->updateStats(statsAfterNoSpace); + + EXPECT_EQ( + statsAfterNoSpace.regionsEvicted, statsBeforeNoSpace.regionsEvicted); + EXPECT_GT( + statsAfterNoSpace.writeSsdDropped, statsBeforeNoSpace.writeSsdDropped); + EXPECT_EQ( + statsAfterNoSpace.entriesWritten, statsBeforeNoSpace.entriesWritten); + + // Verify none of the new entries have ssdFile set. + for (const auto& pin : newPins) { + EXPECT_EQ(pin.entry()->ssdFile(), nullptr); + } +} + TEST_F(SsdFileTest, dataFileErrorInjection) { constexpr int64_t kSsdSize = 16 * SsdFile::kRegionSize; initializeCache(kSsdSize, 0, false, false, false, true); @@ -868,8 +971,9 @@ TEST_F(SsdFileTest, dataFileErrorInjection) { std::vector ssdPins; ssdPins.reserve(pins.size()); for (auto& pin : pins) { - ssdPins.push_back(ssdFile_->find(RawFileCacheKey{ - pin.entry()->key().fileNum.id(), pin.entry()->key().offset})); + ssdPins.push_back(ssdFile_->find( + RawFileCacheKey{ + pin.entry()->key().fileNum.id(), pin.entry()->key().offset})); } SsdCacheStats statsWithReadErrorInjected; @@ -954,6 +1058,65 @@ TEST_F(SsdFileTest, evictlogFileErrorInjection) { ASSERT_GT(statsAfterRecovery.readCheckpointErrors, 0); } +TEST_F(SsdFileTest, maxEntriesLimit) { + constexpr int64_t kSsdSize = 16 * SsdFile::kRegionSize; + constexpr uint64_t kMaxEntries = 100; + FLAGS_velox_ssd_verify_write = true; + + initializeCache(kSsdSize); + // Re-initialize SSD file with maxEntries limit + initializeSsdFile(kSsdSize, 0, false, false, false, kMaxEntries); + + // Write more entries than the limit + auto pins = makePins(fileName_.id(), 0, 4096, 2048 * 1025, 62 * kMB); + ASSERT_GT(pins.size(), kMaxEntries); + + ssdFile_->write(pins); + + SsdCacheStats stats; + ssdFile_->updateStats(stats); + + // The SSD file should have at most maxEntries + EXPECT_LE(stats.entriesCached, kMaxEntries); + // Some writes should have been dropped due to the entry limit + EXPECT_GT(stats.writeSsdExceedEntryLimit, 0); +} + +TEST_F(SsdFileTest, noWritesDroppedWithinMaxEntriesLimit) { + constexpr int64_t kSsdSize = 16 * SsdFile::kRegionSize; + constexpr uint64_t kMaxEntries = 200; + FLAGS_velox_ssd_verify_write = true; + + initializeCache(kSsdSize); + // Re-initialize SSD file with maxEntries limit + initializeSsdFile(kSsdSize, 0, false, false, false, kMaxEntries); + + // Write fewer entries than the limit + auto pins = makePins(fileName_.id(), 0, 4096, 4096, 4096 * 50); + ASSERT_LT(pins.size(), kMaxEntries); + const auto numPins = pins.size(); + + ssdFile_->write(pins); + + SsdCacheStats stats; + ssdFile_->updateStats(stats); + + // All entries should be cached since we're within the limit + EXPECT_EQ(stats.entriesCached, numPins); + // No writes should be dropped due to exceeding entry limit + EXPECT_EQ(stats.writeSsdExceedEntryLimit, 0); + // No writes should be dropped due to lack of space + EXPECT_EQ(stats.writeSsdDropped, 0); + // All entries should have been written + EXPECT_EQ(stats.entriesWritten, numPins); + + // Verify all entries are readable + for (auto& pin : pins) { + EXPECT_EQ(ssdFile_.get(), pin.entry()->ssdFile()); + } + readAndCheckPins(pins); +} + #ifdef VELOX_SSD_FILE_TEST_SET_NO_COW_FLAG TEST_F(SsdFileTest, disabledCow) { constexpr int64_t kSsdSize = 16 * SsdFile::kRegionSize; diff --git a/velox/common/caching/tests/StringIdMapTest.cpp b/velox/common/caching/tests/StringIdMapTest.cpp index 59d95af2e88..1a1d006748a 100644 --- a/velox/common/caching/tests/StringIdMapTest.cpp +++ b/velox/common/caching/tests/StringIdMapTest.cpp @@ -22,7 +22,7 @@ using namespace facebook::velox; TEST(StringIdMapTest, basic) { - constexpr const char* kFile1 = "file_1"; + constexpr std::string_view kFile1 = "file_1"; StringIdMap map; uint64_t id = 0; { @@ -33,7 +33,7 @@ TEST(StringIdMapTest, basic) { id = lease2.id(); lease1 = lease2; EXPECT_EQ(id, lease1.id()); - EXPECT_EQ(strlen(kFile1), map.pinnedSize()); + EXPECT_EQ(kFile1.size(), map.pinnedSize()); } StringIdLease lease3(map, kFile1); EXPECT_NE(lease3.id(), id); @@ -56,50 +56,48 @@ TEST(StringIdMapTest, rehash) { } TEST(StringIdMapTest, recover) { - constexpr const char* kRecoverFile1 = "file_1"; - constexpr const char* kRecoverFile2 = "file_2"; - constexpr const char* kRecoverFile3 = "file_3"; + constexpr std::string_view kRecoverFile1("file_1"); + constexpr std::string_view kRecoverFile2("file_2"); + constexpr std::string_view kRecoverFile3("file_3"); StringIdMap map; const uint64_t recoverId1{10}; const uint64_t recoverId2{20}; { StringIdLease lease(map, recoverId1, kRecoverFile1); ASSERT_TRUE(lease.hasValue()); - ASSERT_EQ(map.pinnedSize(), ::strlen(kRecoverFile1)); + ASSERT_EQ(map.pinnedSize(), kRecoverFile1.size()); ASSERT_EQ(map.testingLastId(), recoverId1); VELOX_ASSERT_THROW( - std::make_unique(map, recoverId1, kRecoverFile2), + StringIdLease(map, recoverId1, kRecoverFile2), "(1 vs. 0) Reused recover id 10 assigned to file_2"); VELOX_ASSERT_THROW( - std::make_unique(map, recoverId2, kRecoverFile1), + StringIdLease(map, recoverId2, kRecoverFile1), "(20 vs. 10) Multiple recover ids assigned to file_1"); } ASSERT_EQ(map.pinnedSize(), 0); StringIdLease lease1(map, kRecoverFile1); - ASSERT_EQ(map.pinnedSize(), ::strlen(kRecoverFile1)); + ASSERT_EQ(map.pinnedSize(), kRecoverFile1.size()); ASSERT_EQ(map.testingLastId(), recoverId1 + 1); { StringIdLease lease(map, recoverId2, kRecoverFile2); ASSERT_TRUE(lease.hasValue()); ASSERT_EQ(lease.id(), recoverId2); - ASSERT_EQ( - map.pinnedSize(), ::strlen(kRecoverFile1) + ::strlen(kRecoverFile2)); + ASSERT_EQ(map.pinnedSize(), kRecoverFile1.size() + kRecoverFile2.size()); ASSERT_EQ(map.testingLastId(), recoverId2); VELOX_ASSERT_THROW( - std::make_unique(map, recoverId2, kRecoverFile3), + StringIdLease(map, recoverId2, kRecoverFile3), "(1 vs. 0) Reused recover id 20 assigned to file_3"); VELOX_ASSERT_THROW( - std::make_unique(map, recoverId2, kRecoverFile1), + StringIdLease(map, recoverId2, kRecoverFile1), "(20 vs. 11) Multiple recover ids assigned to file_1"); StringIdLease dupLease(map, recoverId2, kRecoverFile2); ASSERT_TRUE(lease.hasValue()); ASSERT_EQ(lease.id(), recoverId2); - ASSERT_EQ( - map.pinnedSize(), ::strlen(kRecoverFile1) + ::strlen(kRecoverFile2)); + ASSERT_EQ(map.pinnedSize(), kRecoverFile1.size() + kRecoverFile2.size()); } ASSERT_EQ(map.testingLastId(), recoverId2); - ASSERT_EQ(map.pinnedSize(), ::strlen(kRecoverFile1)); + ASSERT_EQ(map.pinnedSize(), kRecoverFile1.size()); } diff --git a/velox/common/compression/CMakeLists.txt b/velox/common/compression/CMakeLists.txt index d8743ce3860..bea6c0ad4e9 100644 --- a/velox/common/compression/CMakeLists.txt +++ b/velox/common/compression/CMakeLists.txt @@ -16,7 +16,16 @@ if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) endif() -velox_add_library(velox_common_compression Compression.cpp LzoDecompressor.cpp) +velox_add_library( + velox_common_compression + Compression.cpp + LzoDecompressor.cpp + HEADERS + Compression.h + HadoopCompressionFormat.h + Lz4Compression.h + LzoDecompressor.h +) velox_link_libraries( velox_common_compression PUBLIC velox_status Folly::folly diff --git a/velox/common/compression/Compression.cpp b/velox/common/compression/Compression.cpp index 8473b4fe119..4dde4b95a89 100644 --- a/velox/common/compression/Compression.cpp +++ b/velox/common/compression/Compression.cpp @@ -122,9 +122,10 @@ Expected> Codec::create( const CodecOptions& codecOptions) { if (!isAvailable(kind)) { auto name = compressionKindToString(kind); - return folly::makeUnexpected(Status::Invalid( - "Support for codec '{}' is either not built or not implemented.", - name)); + return folly::makeUnexpected( + Status::Invalid( + "Support for codec '{}' is either not built or not implemented.", + name)); } auto compressionLevel = codecOptions.compressionLevel; @@ -155,9 +156,10 @@ Expected> Codec::create( } VELOX_RETURN_UNEXPECTED_IF( codec == nullptr, - Status::Invalid(fmt::format( - "Support for codec '{}' is either not built or not implemented.", - compressionKindToString(kind)))); + Status::Invalid( + fmt::format( + "Support for codec '{}' is either not built or not implemented.", + compressionKindToString(kind)))); VELOX_RETURN_UNEXPECTED_NOT_OK(codec->init()); @@ -184,8 +186,9 @@ bool Codec::isAvailable(CompressionKind kind) { Expected Codec::getUncompressedLength( const uint8_t* input, uint64_t inputLength) const { - return folly::makeUnexpected(Status::Invalid( - "getUncompressedLength is unsupported with {} format.", name())); + return folly::makeUnexpected( + Status::Invalid( + "getUncompressedLength is unsupported with {} format.", name())); } Expected Codec::compressFixedLength( @@ -203,14 +206,16 @@ bool Codec::supportsStreamingCompression() const { Expected> Codec::makeStreamingCompressor() { - return folly::makeUnexpected(Status::Invalid( - "Streaming compression is unsupported with {} format.", name())); + return folly::makeUnexpected( + Status::Invalid( + "Streaming compression is unsupported with {} format.", name())); } Expected> Codec::makeStreamingDecompressor() { - return folly::makeUnexpected(Status::Invalid( - "Streaming decompression is unsupported with {} format.", name())); + return folly::makeUnexpected( + Status::Invalid( + "Streaming decompression is unsupported with {} format.", name())); } int32_t Codec::compressionLevel() const { diff --git a/velox/common/compression/Lz4Compression.cpp b/velox/common/compression/Lz4Compression.cpp index 58c6ed6662f..f7bf7a66db0 100644 --- a/velox/common/compression/Lz4Compression.cpp +++ b/velox/common/compression/Lz4Compression.cpp @@ -588,8 +588,9 @@ Expected Lz4HadoopCodec::compress( folly::Endian::big(static_cast(inputLength)); const uint32_t compressedLength = folly::Endian::big(static_cast(compressedSize)); - folly::storeUnaligned(output, decompressedLength); - folly::storeUnaligned(output + sizeof(uint32_t), compressedLength); + folly::storeUnaligned(output, decompressedLength); + folly::storeUnaligned( + output + sizeof(uint32_t), compressedLength); return kPrefixLength + compressedSize; }); } diff --git a/velox/common/config/CMakeLists.txt b/velox/common/config/CMakeLists.txt index c2d2acd82b9..9b90fb2a890 100644 --- a/velox/common/config/CMakeLists.txt +++ b/velox/common/config/CMakeLists.txt @@ -16,5 +16,17 @@ if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) endif() -velox_add_library(velox_common_config Config.cpp) +velox_add_library(velox_common_config Config.cpp HEADERS Config.h IConfig.h) velox_link_libraries(velox_common_config PUBLIC velox_common_base velox_exception PRIVATE re2::re2) + +velox_add_library( + velox_config_property + ConfigProperty.cpp + HEADERS + ConfigProperty.h + ConfigProvider.h +) +velox_link_libraries( + velox_config_property + PUBLIC velox_enum_declare velox_enum_define velox_exception +) diff --git a/velox/common/config/Config.cpp b/velox/common/config/Config.cpp index 9a37fa65a08..aae7ed5ab00 100644 --- a/velox/common/config/Config.cpp +++ b/velox/common/config/Config.cpp @@ -15,6 +15,7 @@ */ #include +#include #include "velox/common/config/Config.h" @@ -138,13 +139,18 @@ std::unordered_map ConfigBase::rawConfigsCopy() return configs_; } -std::optional ConfigBase::get(const std::string& key) const { - std::optional val; - std::shared_lock l(mutex_); - auto it = configs_.find(key); - if (it != configs_.end()) { - val = it->second; +std::string ConfigBase::toConfigKey(std::string_view sessionKey) { + std::string configKey{sessionKey}; + std::replace(configKey.begin(), configKey.end(), '_', '-'); + return configKey; +} + +std::optional ConfigBase::access(const std::string& key) const { + std::shared_lock l{mutex_}; + if (auto it = configs_.find(key); it != configs_.end()) { + return it->second; } - return val; + return std::nullopt; } + } // namespace facebook::velox::config diff --git a/velox/common/config/Config.h b/velox/common/config/Config.h index 7aea6575a03..92985e76f4e 100644 --- a/velox/common/config/Config.h +++ b/velox/common/config/Config.h @@ -18,11 +18,13 @@ #include #include +#include #include #include #include "folly/Conv.h" #include "velox/common/base/Exceptions.h" +#include "velox/common/config/IConfig.h" namespace facebook::velox::config { @@ -47,7 +49,7 @@ std::chrono::duration toDuration(const std::string& str); /// The concrete config class should inherit the config base and define all the /// entries. -class ConfigBase { +class ConfigBase : public IConfig { public: template struct Entry { @@ -111,49 +113,51 @@ class ConfigBase { : entry.defaultVal; } + using IConfig::get; + + bool valueExists(const std::string& key) const; + + const std::unordered_map& rawConfigs() const; + + std::unordered_map rawConfigsCopy() const final; + + /// Converts a session key to another config key by replacing + /// '_' with '-'. + static std::string toConfigKey(std::string_view sessionKey); + + /// Returns the value for 'key' if present; otherwise checks 'fallback'. + /// Fallback key is derived from 'key' by replacing '_' with '-'. template - std::optional get( - const std::string& key, - std::function toT = [](auto /* unused */, - auto value) { - return folly::to(value); - }) const { - auto val = get(key); - if (val.has_value()) { - return toT(key, val.value()); - } else { - return std::nullopt; + std::optional getWithFallback( + std::string_view key, + const ConfigBase& fallback) const { + if (auto value = get(std::string(key))) { + return value; } + return fallback.get(toConfigKey(key)); } + /// Deprecated: Do not use in new code. + /// Legacy helper for key pairs that don't follow the standard + /// session-key('_') <-> config-key('-') naming convention. + /// Prefer getWithFallback(key, fallback) instead. template - T get( - const std::string& key, - const T& defaultValue, - std::function toT = [](auto /* unused */, - auto value) { - return folly::to(value); - }) const { - auto val = get(key); - if (val.has_value()) { - return toT(key, val.value()); - } else { - return defaultValue; + std::optional getLegacyWithFallback( + std::string_view key, + const ConfigBase& fallback, + std::string_view fallbackKey) const { + if (auto value = get(std::string(key))) { + return value; } + return fallback.get(std::string(fallbackKey)); } - bool valueExists(const std::string& key) const; - - const std::unordered_map& rawConfigs() const; - - std::unordered_map rawConfigsCopy() const; - protected: mutable std::shared_mutex mutex_; std::unordered_map configs_; private: - std::optional get(const std::string& key) const; + std::optional access(const std::string& key) const final; const bool mutable_; }; diff --git a/velox/common/config/ConfigProperty.cpp b/velox/common/config/ConfigProperty.cpp new file mode 100644 index 00000000000..d2bee05eab6 --- /dev/null +++ b/velox/common/config/ConfigProperty.cpp @@ -0,0 +1,48 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/common/config/ConfigProperty.h" + +#include +#include "velox/common/EnumDefine.h" + +namespace facebook::velox::config { + +namespace { + +const auto& configPropertyTypeNames() { + static const folly::F14FastMap kNames = + { + {ConfigPropertyType::kBoolean, "BOOLEAN"}, + {ConfigPropertyType::kInteger, "INTEGER"}, + {ConfigPropertyType::kDouble, "DOUBLE"}, + {ConfigPropertyType::kString, "STRING"}, + }; + return kNames; +} + +} // namespace + +VELOX_DEFINE_ENUM_NAME(ConfigPropertyType, configPropertyTypeNames); + +namespace detail { + +std::string toStringValue(double value) { + return fmt::format("{}", value); +} + +} // namespace detail + +} // namespace facebook::velox::config diff --git a/velox/common/config/ConfigProperty.h b/velox/common/config/ConfigProperty.h new file mode 100644 index 00000000000..09085b02507 --- /dev/null +++ b/velox/common/config/ConfigProperty.h @@ -0,0 +1,99 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +#include "velox/common/EnumDeclare.h" + +namespace facebook::velox::config { + +enum class ConfigPropertyType { + kBoolean, + kInteger, + kDouble, + kString, +}; + +VELOX_DECLARE_ENUM_NAME(ConfigPropertyType); + +/// Describes a single configuration property: name, type, default, and +/// human-readable description. +struct ConfigProperty { + /// Property name, e.g. "session_timezone". + std::string name; + + ConfigPropertyType type; + + /// Default value as a string, or nullopt if the property has no default. + std::optional defaultValue; + + /// Human-readable description of the property. + std::string description; +}; + +namespace detail { + +std::string toStringValue(double value); + +template +constexpr ConfigPropertyType configPropertyTypeOf() { + if constexpr (std::is_same_v) { + return ConfigPropertyType::kBoolean; + } else if constexpr (std::is_integral_v) { + return ConfigPropertyType::kInteger; + } else if constexpr (std::is_floating_point_v) { + return ConfigPropertyType::kDouble; + } else { + return ConfigPropertyType::kString; + } +} + +template +std::string configPropertyDefaultToString(const T& value) { + if constexpr (std::is_same_v) { + return value ? "true" : "false"; + } else if constexpr (std::is_arithmetic_v) { + if constexpr (std::is_floating_point_v) { + return toStringValue(static_cast(value)); + } else { + static_assert(!std::is_same_v, "Use int, not char."); + return std::to_string(value); + } + } else { + return std::string(value); + } +} + +} // namespace detail + +/// Registers a property from a traits struct into a vector. +/// The traits struct must have: key, type, defaultValue, description. +template +void registerConfigProperty(std::vector& registry) { + registry.push_back({ + PropertyTraits::key, + detail::configPropertyTypeOf(), + detail::configPropertyDefaultToString( + PropertyTraits::defaultValue), + PropertyTraits::description, + }); +} + +} // namespace facebook::velox::config diff --git a/velox/common/config/ConfigProvider.h b/velox/common/config/ConfigProvider.h new file mode 100644 index 00000000000..ddc5e39c971 --- /dev/null +++ b/velox/common/config/ConfigProvider.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/common/config/ConfigProperty.h" + +namespace facebook::velox::config { + +/// Declares and normalizes configuration properties for a component. +/// A ConfigProvider is shared and stateless — it does not hold per-session +/// values. It knows only its own property names (e.g., "session_timezone"), +/// not how they are exposed to the user. +class ConfigProvider { + public: + virtual ~ConfigProvider() = default; + + /// Returns the list of supported configuration properties. + virtual std::vector properties() const = 0; + + /// Validates and returns the canonical form of the value. Called after + /// type validation. Throws if the property name is unknown or the value + /// is invalid. + virtual std::string normalize(std::string_view name, std::string_view value) + const = 0; +}; + +} // namespace facebook::velox::config diff --git a/velox/common/config/IConfig.h b/velox/common/config/IConfig.h new file mode 100644 index 00000000000..11cba5cc029 --- /dev/null +++ b/velox/common/config/IConfig.h @@ -0,0 +1,78 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace facebook::velox::config { + +/// IConfig - Read-only config interface +/// for accessing key-value parameters. +/// Supports value retrieval by key and +/// duplication of the raw configuration data. +/// Can be used by velox::QueryConfig to access +/// externally managed system configuration. +class IConfig { + public: + // Do not inline this member function as lambda. Otherwise, a GCC bug + // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=103186 might be triggered + // with GCC 11.1 and 11.2. + // This is currently a required workaround specifically for Apache Gluten + // that still relies on GCC 11.2: + // https://github.com/apache/gluten/issues/11991. The workaround is needed + // until Apache Gluten migrates away from the old compiler. + template + static T defaultToT(std::string /* unused */, std::string value) { + return folly::to(value); + } + + template + std::optional get( + const std::string& key, + const std::function& toT = + defaultToT) const { + if (auto val = access(key)) { + return toT(key, *val); + } + return std::nullopt; + } + + template + T get( + const std::string& key, + const T& defaultValue, + const std::function& toT = + defaultToT) const { + if (auto val = access(key)) { + return toT(key, *val); + } + return defaultValue; + } + + virtual std::unordered_map rawConfigsCopy() + const = 0; + + virtual ~IConfig() = default; + + private: + virtual std::optional access(const std::string& key) const = 0; +}; + +} // namespace facebook::velox::config diff --git a/velox/common/dynamic_registry/CMakeLists.txt b/velox/common/dynamic_registry/CMakeLists.txt index 31b37df7083..0b8a256b7b4 100644 --- a/velox/common/dynamic_registry/CMakeLists.txt +++ b/velox/common/dynamic_registry/CMakeLists.txt @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -velox_add_library(velox_dynamic_library_loader DynamicLibraryLoader.cpp) +velox_add_library( + velox_dynamic_library_loader + DynamicLibraryLoader.cpp + HEADERS + DynamicLibraryLoader.h +) velox_link_libraries(velox_dynamic_library_loader PRIVATE velox_exception) diff --git a/velox/common/dynamic_registry/tests/DynamicLinkTest.cpp b/velox/common/dynamic_registry/tests/DynamicLinkTest.cpp index 97672761bb8..fbade514dbd 100644 --- a/velox/common/dynamic_registry/tests/DynamicLinkTest.cpp +++ b/velox/common/dynamic_registry/tests/DynamicLinkTest.cpp @@ -152,8 +152,8 @@ TEST_F(DynamicLinkTest, dynamicLoadErrFunc) { dynamicFunctionFail(0, 0), "Scalar function signature is not supported: dynamic_err(BIGINT). Supported signatures: (array(bigint)) -> bigint."); - auto check = makeRowVector( - {makeNullableArrayVector(std::vector>>{ + auto check = makeRowVector({makeNullableArrayVector( + std::vector>>{ {0, 1, 3, 4, 5, 6, 7, 8, 9}})}); // Expecting a success because we are passing in an array. diff --git a/velox/common/encode/Base32.cpp b/velox/common/encode/Base32.cpp new file mode 100644 index 00000000000..b43a2b03f0a --- /dev/null +++ b/velox/common/encode/Base32.cpp @@ -0,0 +1,192 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/common/encode/Base32.h" + +#include +#include +#include + +#include "velox/common/base/CheckedArithmetic.h" +#include "velox/common/base/Exceptions.h" + +namespace facebook::velox::encoding { + +// Reverse lookup table for decoding. 255 means invalid character. +// Only uppercase letters (A-Z) and digits 2-7 are valid per RFC 4648. +// Lowercase letters are NOT supported (matching Google Guava's +// BaseEncoding.base32()). +constexpr const Base32::ReverseIndex kBase32ReverseIndexTable = { + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, // 0-15 + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, // 16-31 + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, // 32-47 + 255, 255, 26, 27, 28, 29, 30, 31, 255, + 255, 255, 255, 255, 255, 255, 255, // 48-63 ('2'-'7') + 255, 0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, // 64-79 ('A'-'O') + 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 255, 255, 255, 255, 255, // 80-95 ('P'-'Z') + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, // 96-111 (lowercase 'a'-'o' - INVALID) + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, // 112-127 (lowercase 'p'-'z' - INVALID) + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, // 128-143 + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, // 144-159 + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, // 160-175 + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, // 176-191 + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, // 192-207 + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, // 208-223 + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, // 224-239 + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255 // 240-255 +}; + +// static +folly::Expected Base32::base32ReverseLookup( + char encodedChar, + const ReverseIndex& reverseIndex) { + auto index = reverseIndex[static_cast(encodedChar)]; + if (index >= 32) { + return folly::makeUnexpected( + Status::UserError("Unrecognized character: {}", encodedChar)); + } + return index; +} + +// static +folly::Expected Base32::calculateDecodedSize( + const char* input, + const size_t inputSize) { + if (inputSize == 0) { + return 0; + } + + // Count valid (non-padding, non-whitespace) characters and validate them + size_t validCharCount = 0; + for (size_t i = 0; i < inputSize; ++i) { + char c = input[i]; + if (c == Base32::kPadding || std::isspace(static_cast(c))) { + continue; + } + + // Validate character first + auto index = kBase32ReverseIndexTable[static_cast(c)]; + if (index >= 32) { + return folly::makeUnexpected( + Status::UserError("Unrecognized character: {}", c)); + } + validCharCount++; + } + + // Validate input length matches Google Guava's Base32 behavior. + // Base32 encoding groups characters into quantums of 8 characters (40 bits). + // Valid character counts (mod 8) are: 0, 2, 4, 5, 7 + // Invalid character counts (mod 8) are: 1, 3, 6 + // These invalid counts leave too many incomplete bits that cannot form + // complete bytes. + if (validCharCount > 0) { + size_t remainder = validCharCount % 8; + if (remainder == 1 || remainder == 3 || remainder == 6) { + return folly::makeUnexpected( + Status::UserError("Invalid input length {}", validCharCount)); + } + } + + // Calculate decoded size + // Each base32 character represents 5 bits + // We need to convert to bytes (8 bits each) + size_t totalBits = checkedMultiply(validCharCount, size_t(5)); + size_t decodedSize = totalBits / 8; + + return decodedSize; +} + +// static +Status Base32::decode( + const char* input, + size_t inputSize, + char* outputBuffer, + size_t outputSize) { + auto decodedSize = decodeImpl( + input, inputSize, outputBuffer, outputSize, kBase32ReverseIndexTable); + if (decodedSize.hasError()) { + return decodedSize.error(); + } + return Status::OK(); +} + +// Decodes Base32 input using the provided reverse lookup table. +// This is the core decoding implementation that accumulates 5-bit values +// from Base32 characters and outputs 8-bit bytes. +// static +folly::Expected Base32::decodeImpl( + const char* input, + size_t inputSize, + char* outputBuffer, + size_t outputSize, + const ReverseIndex& reverseIndex) { + if (inputSize == 0) { + return 0; + } + + size_t outputPos = 0; + uint64_t accumulator = 0; + size_t bitsAccumulated = 0; + + for (size_t i = 0; i < inputSize; ++i) { + char c = input[i]; + + // Skip padding and whitespace (RFC 4648 allows whitespace in encoded data) + if (c == Base32::kPadding || std::isspace(static_cast(c))) { + continue; + } + + // Validate and convert character to 5-bit value + auto value = base32ReverseLookup(c, reverseIndex); + if (value.hasError()) { + return folly::makeUnexpected(value.error()); + } + + // Accumulate 5 bits from each Base32 character + // Each character contributes 5 bits to the bit accumulator + accumulator = (accumulator << 5) | value.value(); + bitsAccumulated += 5; + + // Extract full bytes (8 bits) when we have accumulated enough bits + if (bitsAccumulated >= 8) { + if (outputPos >= outputSize) { + return folly::makeUnexpected( + Status::UserError("Output buffer too small")); + } + outputBuffer[outputPos++] = + static_cast((accumulator >> (bitsAccumulated - 8)) & 0xFF); + bitsAccumulated -= 8; + } + } + + return outputPos; +} + +} // namespace facebook::velox::encoding diff --git a/velox/common/encode/Base32.h b/velox/common/encode/Base32.h new file mode 100644 index 00000000000..246a60d7994 --- /dev/null +++ b/velox/common/encode/Base32.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include + +#include "velox/common/base/Status.h" + +namespace facebook::velox::encoding { + +class Base32 { + public: + static const size_t kCharsetSize = 32; + static const size_t kReverseIndexSize = 256; + + /// Character set used for Base32 encoding. + /// Contains specific characters that form the encoding scheme. + using Charset = std::array; + + /// Reverse lookup table for decoding. + /// Maps each possible encoded character to its corresponding numeric value + /// within the encoding base. + using ReverseIndex = std::array; + + /// Decodes the specified number of characters from the 'input' and writes the + /// result to the 'outputBuffer'. + static Status decode( + const char* input, + size_t inputSize, + char* outputBuffer, + size_t outputSize); + + /// Calculates the decoded size based on encoded input. + static folly::Expected calculateDecodedSize( + const char* input, + const size_t inputSize); + + // Padding character used in encoding. + static const char kPadding = '='; + + // Constants defining the size in bytes of binary and encoded blocks for + // Base32 encoding. Size of a binary block in bytes (5 bytes = 40 bits) + static const int kBinaryBlockByteSize = 5; + // Size of an encoded block in bytes (8 bytes = 40 bits) + static const int kEncodedBlockByteSize = 8; + + private: + // Reverse lookup helper function to get the original index of a Base32 + // character. + static folly::Expected base32ReverseLookup( + char encodedChar, + const ReverseIndex& reverseIndex); + + // Decodes the specified data using the provided reverse lookup table. + static folly::Expected decodeImpl( + const char* input, + size_t inputSize, + char* outputBuffer, + size_t outputSize, + const ReverseIndex& reverseIndex); +}; + +} // namespace facebook::velox::encoding diff --git a/velox/common/encode/Base64.cpp b/velox/common/encode/Base64.cpp index aa521a57d47..2b87fd141ab 100644 --- a/velox/common/encode/Base64.cpp +++ b/velox/common/encode/Base64.cpp @@ -187,14 +187,13 @@ size_t Base64::calculateEncodedSize(size_t inputSize, bool withPadding) { // static void Base64::encode(const char* input, size_t inputSize, char* output) { - encodeImpl( - folly::StringPiece(input, inputSize), kBase64Charset, true, output); + encodeImpl(std::string_view(input, inputSize), kBase64Charset, true, output); } // static void Base64::encodeUrl(const char* input, size_t inputSize, char* output) { encodeImpl( - folly::StringPiece(input, inputSize), kBase64UrlCharset, true, output); + std::string_view(input, inputSize), kBase64UrlCharset, true, output); } // static @@ -249,13 +248,13 @@ void Base64::encodeImpl( } // static -std::string Base64::encode(folly::StringPiece text) { +std::string Base64::encode(std::string_view text) { return encodeImpl(text, kBase64Charset, true); } // static std::string Base64::encode(const char* input, size_t inputSize) { - return encode(folly::StringPiece(input, inputSize)); + return encode(std::string_view(input, inputSize)); } namespace { @@ -308,7 +307,7 @@ std::string Base64::encode(const folly::IOBuf* inputBuffer) { } // static -std::string Base64::decode(folly::StringPiece encodedText) { +std::string Base64::decode(std::string_view encodedText) { std::string decodedResult; decode(std::make_pair(encodedText.data(), encodedText.size()), decodedResult); return decodedResult; @@ -346,9 +345,10 @@ Expected Base64::base64ReverseLookup( const ReverseIndex& reverseIndex) { auto reverseLookupValue = reverseIndex[static_cast(encodedChar)]; if (reverseLookupValue >= 0x40) { - return folly::makeUnexpected(Status::UserError( - "decode() - invalid input string: invalid character '{}'", - encodedChar)); + return folly::makeUnexpected( + Status::UserError( + "decode() - invalid input string: invalid character '{}'", + encodedChar)); } return reverseLookupValue; } @@ -381,8 +381,9 @@ Expected Base64::calculateDecodedSize( // block size if (inputSize % kEncodedBlockByteSize != 0) { return folly::makeUnexpected( - Status::UserError("Base64::decode() - invalid input string: " - "string length is not a multiple of 4.")); + Status::UserError( + "Base64::decode() - invalid input string: " + "string length is not a multiple of 4.")); } auto decodedSize = @@ -403,9 +404,10 @@ Expected Base64::calculateDecodedSize( // Adjust the needed size for extra bytes, if present if (extraBytes) { if (extraBytes == 1) { - return folly::makeUnexpected(Status::UserError( - "Base64::decode() - invalid input string: " - "string length cannot be 1 more than a multiple of 4.")); + return folly::makeUnexpected( + Status::UserError( + "Base64::decode() - invalid input string: " + "string length cannot be 1 more than a multiple of 4.")); } decodedSize += (extraBytes * kBinaryBlockByteSize) / kEncodedBlockByteSize; } @@ -431,8 +433,9 @@ Expected Base64::decodeImpl( if (outputSize < decodedSize.value()) { return folly::makeUnexpected( - Status::UserError("Base64::decode() - invalid output string: " - "output string is too small.")); + Status::UserError( + "Base64::decode() - invalid output string: " + "output string is too small.")); } outputSize = decodedSize.value(); @@ -492,13 +495,13 @@ Expected Base64::decodeImpl( } // static -std::string Base64::encodeUrl(folly::StringPiece text) { +std::string Base64::encodeUrl(std::string_view text) { return encodeImpl(text, kBase64UrlCharset, false); } // static std::string Base64::encodeUrl(const char* input, size_t inputSize) { - return encodeUrl(folly::StringPiece(input, inputSize)); + return encodeUrl(std::string_view(input, inputSize)); } // static @@ -521,7 +524,7 @@ Status Base64::decodeUrl( } // static -std::string Base64::decodeUrl(folly::StringPiece encodedText) { +std::string Base64::decodeUrl(std::string_view encodedText) { std::string decodedOutput; decodeUrl( std::make_pair(encodedText.data(), encodedText.size()), decodedOutput); @@ -628,8 +631,9 @@ Expected Base64::calculateMimeDecodedSize( if (kBase64ReverseIndexTable[static_cast(input[0])] >= 0x40) { return 0; } - return folly::makeUnexpected(Status::UserError( - "Input should at least have 2 bytes for base64 bytes.")); + return folly::makeUnexpected( + Status::UserError( + "Input should at least have 2 bytes for base64 bytes.")); } auto decodedSize = inputSize; // Compute how many true Base64 chars. diff --git a/velox/common/encode/Base64.h b/velox/common/encode/Base64.h index 073cc49cd4f..7dca7d2fdbc 100644 --- a/velox/common/encode/Base64.h +++ b/velox/common/encode/Base64.h @@ -16,7 +16,6 @@ #pragma once -#include #include #include @@ -45,7 +44,7 @@ class Base64 { static std::string encode(const char* input, size_t inputSize); /// Encodes the specified text. - static std::string encode(folly::StringPiece text); + static std::string encode(std::string_view text); /// Encodes the specified IOBuf data. static std::string encode(const folly::IOBuf* inputBuffer); @@ -60,7 +59,7 @@ class Base64 { static std::string encodeUrl(const char* input, size_t inputSize); /// Encodes the specified text using URL encoding. - static std::string encodeUrl(folly::StringPiece text); + static std::string encodeUrl(std::string_view text); /// Encodes the specified IOBuf data using URL encoding. static std::string encodeUrl(const folly::IOBuf* inputBuffer); @@ -72,7 +71,7 @@ class Base64 { encodeUrl(const char* input, size_t inputSize, char* outputBuffer); /// Decodes the input Base64 encoded string. - static std::string decode(folly::StringPiece encodedText); + static std::string decode(std::string_view encodedText); /// Decodes the specified encoded payload and writes the result to the /// 'output'. @@ -94,7 +93,7 @@ class Base64 { size_t outputSize); /// Decodes the input Base64 URL encoded string. - static std::string decodeUrl(folly::StringPiece encodedText); + static std::string decodeUrl(std::string_view encodedText); /// Decodes the specified URL encoded payload and writes the result to the /// 'output'. diff --git a/velox/common/encode/ByteStream.h b/velox/common/encode/ByteStream.h index 70bd2502513..9ebbee1054d 100644 --- a/velox/common/encode/ByteStream.h +++ b/velox/common/encode/ByteStream.h @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #pragma once /** @@ -26,7 +27,6 @@ #include #include -#include namespace facebook::velox::strings { @@ -86,28 +86,28 @@ class ByteSink { * append() will return 0). In particular, this may not be used for * non-blocking behavior. */ - virtual size_t append(folly::StringPiece str) = 0; + virtual size_t append(std::string_view str) = 0; size_t append(const void* data, size_t size) { - return append(folly::StringPiece(static_cast(data), size)); + return append(std::string_view(static_cast(data), size)); } /** - * Append the given string to this ByteSink. The string must remain + * Append the given string to this ByteSink. The string must remain * allocated (and unchanged) until the ByteSink is destroyed. */ - virtual size_t appendAllocated(folly::StringPiece str) { + virtual size_t appendAllocated(std::string_view str) { return append(str); } /** * Convenience function that appends the bitwise representation of count - * objects starting at address obj. The usual caveats about endianness, + * objects starting at address obj. The usual caveats about endianness, * padding apply. */ template size_t appendBitwise(const T* obj, size_t count) { const size_t sz = count * sizeof(T); - return append(folly::StringPiece(reinterpret_cast(obj), sz)); + return append(std::string_view(reinterpret_cast(obj), sz)); } /** @@ -185,8 +185,8 @@ class SByteSink : public ByteSink { public: explicit SByteSink(S* str) : str_(str) {} - size_t append(folly::StringPiece s) override { - str_->append(s.start(), s.size()); + size_t append(std::string_view s) override { + str_->append(s.data(), s.size()); return s.size(); } @@ -237,7 +237,7 @@ class ByteSource { * next() will return false, but bad() will also return false. On error, * next() returns false, and bad() returns true. */ - virtual bool next(folly::StringPiece* chunk) = 0; + virtual bool next(std::string_view* chunk) = 0; /** * Push back the last numBytes returned by the last next() call, so @@ -316,7 +316,7 @@ class ByteSourceBuffer : public std::basic_streambuf { class StringByteSource : public ByteSource { public: explicit StringByteSource( - const folly::StringPiece& str, + const std::string_view& str, size_t maxBytes = kSizeMax) : str_(str), offset_(0), @@ -326,15 +326,17 @@ class StringByteSource : public ByteSource { bool bad() const override { return false; } - bool next(folly::StringPiece* chunk) override { + + bool next(std::string_view* chunk) override { if (offset_ == str_.size()) { return false; } size_t len = std::min(str_.size() - offset_, maxBytes_); - chunk->reset(str_.start() + offset_, len); + *chunk = std::string_view(str_.data() + offset_, len); offset_ += len; return true; } + void backUp(size_t numBytes) override { CHECK_LE(numBytes, maxBytes_); CHECK_GE(offset_, numBytes); @@ -342,7 +344,7 @@ class StringByteSource : public ByteSource { } private: - folly::StringPiece str_; + std::string_view str_; size_t offset_; size_t maxBytes_; }; diff --git a/velox/common/encode/CMakeLists.txt b/velox/common/encode/CMakeLists.txt index 15acfcc232a..f3ed80bb042 100644 --- a/velox/common/encode/CMakeLists.txt +++ b/velox/common/encode/CMakeLists.txt @@ -16,5 +16,14 @@ if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) endif() -velox_add_library(velox_encode Base64.cpp) +velox_add_library( + velox_encode + Base64.cpp + Base32.cpp + HEADERS + Base32.h + Base64.h +) velox_link_libraries(velox_encode PUBLIC velox_status Folly::folly) + +velox_add_library(velox_common_encode INTERFACE HEADERS ByteStream.h Coding.h UInt128.h) diff --git a/velox/common/encode/Coding.h b/velox/common/encode/Coding.h index 55e66d3ce91..e4edaac4bcd 100644 --- a/velox/common/encode/Coding.h +++ b/velox/common/encode/Coding.h @@ -109,20 +109,20 @@ class Varint { char buf[kMaxSize64]; char* p = buf; encode(val, &p); - sink->append(folly::StringPiece(buf, p - buf)); + sink->append(std::string_view(buf, p - buf)); } static void encode128ToByteSink(UInt128 val, strings::ByteSink* sink) { char buf[kMaxSize128]; char* p = buf; encode128(val, &p); - sink->append(folly::StringPiece(buf, p - buf)); + sink->append(std::string_view(buf, p - buf)); } // Returns true if decode can be called without causing a CHECK failure. // The pointers are not adjusted at all - static bool canDecode(folly::StringPiece src) { - src = src.subpiece(0, kMaxSize64); + static bool canDecode(std::string_view src) { + src = src.substr(0, kMaxSize64); return std::any_of( src.begin(), src.end(), [](char v) { return ~v & 0x80; }); } @@ -187,18 +187,18 @@ class Varint { return val; } - // Decode a value from a StringPiece, and advance the StringPiece. - static uint64_t decode(folly::StringPiece* data) { - const char* p = data->start(); + // Decode a value from a string_view, and advance it. + static uint64_t decode(std::string_view* data) { + const char* p = data->data(); uint64_t val = decode(&p, data->size()); - data->advance(p - data->start()); + data->remove_prefix(p - data->data()); return val; } - static UInt128 decode128(folly::StringPiece* data) { - const char* p = data->start(); + static UInt128 decode128(std::string_view* data) { + const char* p = data->data(); UInt128 val = decode128(&p, data->size()); - data->advance(p - data->start()); + data->remove_prefix(p - data->data()); return val; } @@ -207,13 +207,13 @@ class Varint { uint64_t val = 0; int32_t shift = 0; int32_t max_size = kMaxSize64; - folly::StringPiece chunk; + std::string_view chunk; int32_t remaining = 0; const char* p = nullptr; for (;;) { if (remaining == 0) { CHECK(src->next(&chunk)); - p = chunk.start(); + p = chunk.data(); remaining = chunk.size(); DCHECK_GT(remaining, 0); } @@ -238,13 +238,13 @@ class Varint { UInt128 val = 0; int32_t shift = 0; int32_t max_size = kMaxSize128; - folly::StringPiece chunk; + std::string_view chunk; int32_t remaining = 0; const char* p = nullptr; for (;;) { if (remaining == 0) { CHECK(src->next(&chunk)); - p = chunk.start(); + p = chunk.data(); remaining = chunk.size(); DCHECK_GT(remaining, 0); } @@ -292,7 +292,7 @@ namespace detail { class ByteSinkAppender { public: /* implicit */ ByteSinkAppender(strings::ByteSink* out) : out_(out) {} - void operator()(folly::StringPiece sp) { + void operator()(std::string_view sp) { out_->append(sp.data(), sp.size()); } diff --git a/velox/common/encode/tests/Base64Test.cpp b/velox/common/encode/tests/Base64Test.cpp index 91b0fcab908..bc78bd6c501 100644 --- a/velox/common/encode/tests/Base64Test.cpp +++ b/velox/common/encode/tests/Base64Test.cpp @@ -25,27 +25,20 @@ namespace facebook::velox::encoding { class Base64Test : public ::testing::Test {}; TEST_F(Base64Test, fromBase64) { - EXPECT_EQ( - "Hello, World!", - Base64::decode(folly::StringPiece("SGVsbG8sIFdvcmxkIQ=="))); + EXPECT_EQ("Hello, World!", Base64::decode("SGVsbG8sIFdvcmxkIQ==")); EXPECT_EQ( "Base64 encoding is fun.", - Base64::decode(folly::StringPiece("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4="))); - EXPECT_EQ( - "Simple text", Base64::decode(folly::StringPiece("U2ltcGxlIHRleHQ="))); - EXPECT_EQ( - "1234567890", Base64::decode(folly::StringPiece("MTIzNDU2Nzg5MA=="))); + Base64::decode("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=")); + EXPECT_EQ("Simple text", Base64::decode("U2ltcGxlIHRleHQ=")); + EXPECT_EQ("1234567890", Base64::decode("MTIzNDU2Nzg5MA==")); // Check encoded strings without padding - EXPECT_EQ( - "Hello, World!", - Base64::decode(folly::StringPiece("SGVsbG8sIFdvcmxkIQ"))); + EXPECT_EQ("Hello, World!", Base64::decode("SGVsbG8sIFdvcmxkIQ")); EXPECT_EQ( "Base64 encoding is fun.", - Base64::decode(folly::StringPiece("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4"))); - EXPECT_EQ( - "Simple text", Base64::decode(folly::StringPiece("U2ltcGxlIHRleHQ"))); - EXPECT_EQ("1234567890", Base64::decode(folly::StringPiece("MTIzNDU2Nzg5MA"))); + Base64::decode("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4")); + EXPECT_EQ("Simple text", Base64::decode("U2ltcGxlIHRleHQ")); + EXPECT_EQ("1234567890", Base64::decode("MTIzNDU2Nzg5MA")); } TEST_F(Base64Test, calculateDecodedSizeProperSize) { diff --git a/velox/common/file/CMakeLists.txt b/velox/common/file/CMakeLists.txt index b89a4e3a38e..27cba975037 100644 --- a/velox/common/file/CMakeLists.txt +++ b/velox/common/file/CMakeLists.txt @@ -14,7 +14,23 @@ # for generated headers include_directories(.) -velox_add_library(velox_file File.cpp FileInputStream.cpp FileSystems.cpp Utils.cpp) +velox_add_library( + velox_file + File.cpp + FileInputStream.cpp + FileIoTracer.cpp + FileSystems.cpp + FileUtils.cpp + HEADERS + File.h + FileInputStream.h + FileIoTracer.h + FileSystems.h + FileUtils.h + PlainUserNameTokenProvider.h + Region.h + TokenProvider.h +) velox_link_libraries( velox_file PUBLIC velox_exception Folly::folly diff --git a/velox/common/file/File.cpp b/velox/common/file/File.cpp index c1f1e898273..6c098636d32 100644 --- a/velox/common/file/File.cpp +++ b/velox/common/file/File.cpp @@ -30,6 +30,36 @@ namespace facebook::velox { +void IoStats::addCounter(const std::string& name, RuntimeCounter counter) { + auto locked = stats_.wlock(); + auto it = locked->find(name); + if (it == locked->end()) { + auto [ptr, inserted] = locked->emplace(name, RuntimeMetric(counter.unit)); + VELOX_CHECK(inserted); + ptr->second.addValue(counter.value); + } else { + VELOX_CHECK_EQ(it->second.unit, counter.unit); + it->second.addValue(counter.value); + } +} + +void IoStats::merge(const IoStats& other) { + auto otherStats = other.stats(); + auto locked = stats_.wlock(); + for (const auto& [name, metric] : otherStats) { + auto it = locked->find(name); + if (it == locked->end()) { + locked->emplace(name, metric); + } else { + it->second.merge(metric); + } + } +} + +folly::F14FastMap IoStats::stats() const { + return stats_.copy(); +} + #define RETURN_IF_ERROR(func, result) \ result = func; \ if (result < 0) { \ @@ -60,10 +90,10 @@ T getAttribute( std::string ReadFile::pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats) const { + const FileIoContext& context) const { std::string buf; buf.resize(length); - auto res = pread(offset, length, buf.data(), stats); + auto res = pread(offset, length, buf.data(), context); buf.resize(res.size()); return buf; } @@ -71,7 +101,7 @@ std::string ReadFile::pread( uint64_t ReadFile::preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats) const { + const FileIoContext& context) const { auto fileSize = size(); uint64_t numRead = 0; if (offset >= fileSize) { @@ -81,7 +111,7 @@ uint64_t ReadFile::preadv( auto copySize = std::min(range.size(), fileSize - offset); // NOTE: skip the gap in case of coalesce io. if (range.data() != nullptr) { - pread(offset, copySize, range.data(), stats); + pread(offset, copySize, range.data(), context); } offset += copySize; numRead += copySize; @@ -92,18 +122,17 @@ uint64_t ReadFile::preadv( uint64_t ReadFile::preadv( folly::Range regions, folly::Range iobufs, - filesystems::File::IoStats* stats) const { + const FileIoContext& context) const { VELOX_CHECK_EQ(regions.size(), iobufs.size()); uint64_t length = 0; for (size_t i = 0; i < regions.size(); ++i) { const auto& region = regions[i]; auto& output = iobufs[i]; output = folly::IOBuf(folly::IOBuf::CREATE, region.length); - pread(region.offset, region.length, output.writableData(), stats); + pread(region.offset, region.length, output.writableData(), context); output.append(region.length); length += region.length; } - return length; } @@ -111,7 +140,7 @@ std::string_view InMemoryReadFile::pread( uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats) const { + const FileIoContext& context) const { bytesRead_ += length; memcpy(buf, file_.data() + offset, length); return {static_cast(buf), length}; @@ -120,7 +149,7 @@ std::string_view InMemoryReadFile::pread( std::string InMemoryReadFile::pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats) const { + const FileIoContext& context) const { bytesRead_ += length; return std::string(file_.data() + offset, length); } @@ -202,7 +231,7 @@ std::string_view LocalReadFile::pread( uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats) const { + const FileIoContext& context) const { preadInternal(offset, length, static_cast(buf)); return {static_cast(buf), length}; } @@ -210,7 +239,7 @@ std::string_view LocalReadFile::pread( uint64_t LocalReadFile::preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats) const { + const FileIoContext& context) const { // Dropped bytes sized so that a typical dropped range of 50K is not // too many iovecs. static thread_local std::vector droppedBytes(16 * 1024); @@ -267,17 +296,17 @@ uint64_t LocalReadFile::preadv( folly::SemiFuture LocalReadFile::preadvAsync( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats) const { + const FileIoContext& context) const { if (!executor_) { - return ReadFile::preadvAsync(offset, buffers, stats); + return ReadFile::preadvAsync(offset, buffers, context); } auto [promise, future] = folly::makePromiseContract(); executor_->add([this, _promise = std::move(promise), _offset = offset, _buffers = buffers, - _stats = stats]() mutable { - auto delegateFuture = ReadFile::preadvAsync(_offset, _buffers, _stats); + _context = context]() mutable { + auto delegateFuture = ReadFile::preadvAsync(_offset, _buffers, _context); _promise.setTry(std::move(delegateFuture).getTry()); }); return std::move(future); diff --git a/velox/common/file/File.h b/velox/common/file/File.h index 18d1c264ca7..c41451d912d 100644 --- a/velox/common/file/File.h +++ b/velox/common/file/File.h @@ -36,15 +36,65 @@ #include #include +#include +#include #include +#include #include "velox/common/base/Exceptions.h" +#include "velox/common/base/RuntimeMetrics.h" +#include "velox/common/file/FileIoTracer.h" #include "velox/common/file/FileSystems.h" #include "velox/common/file/Region.h" #include "velox/common/io/IoStatistics.h" namespace facebook::velox { +/// Free form statistics for file I/O operations. The keys are arbitrary +/// strings, and values are RuntimeMetric. This class can be used to record +/// observability about filesystem operations. +class IoStats { + public: + IoStats() = default; + + void addCounter(const std::string& name, RuntimeCounter counter); + + void merge(const IoStats& other); + + folly::F14FastMap stats() const; + + private: + folly::Synchronized> stats_; +}; + +struct FileIoContext { + /// Stats for IO operations. + IoStats* ioStats{nullptr}; + + /// Options for file read operations. + folly::F14FastMap fileOpts; + + /// Tracer for IO operations, providing call stack context. + std::shared_ptr ioTracer; + + /// When false, hints to the storage layer that this read should not be cached + /// or should be evicted soon after reading. This is useful for one-time reads + /// where caching would waste resources. + bool cacheable{false}; + + FileIoContext() = default; + + explicit FileIoContext( + IoStats* stats, + folly::F14FastMap fileOpts = {}, + std::shared_ptr tracer = nullptr, + bool cacheable = false) + : ioStats(stats), + fileOpts(std::move(fileOpts)), + ioTracer(std::move(tracer)), + cacheable(cacheable) {} +}; + // A read-only file. All methods in this object should be thread safe. class ReadFile { public: @@ -52,16 +102,12 @@ class ReadFile { // Reads the data at [offset, offset + length) into the provided pre-allocated // buffer 'buf'. The bytes are returned as a string_view pointing to 'buf'. - // - // 'stats' is an IoStatistics pointer passed in by the caller to collect stats - // for this read operation. - // // This method should be thread safe. virtual std::string_view pread( uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats = nullptr) const = 0; + const FileIoContext& context = {}) const = 0; // Same as above, but returns owned data directly. // @@ -69,20 +115,16 @@ class ReadFile { virtual std::string pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats = nullptr) const; + const FileIoContext& context = {}) const; // Reads starting at 'offset' into the memory referenced by the // Ranges in 'buffers'. The buffers are filled left to right. A // buffer with nullptr data will cause its size worth of bytes to be skipped. - // - // 'stats' is an IoStatistics pointer passed in by the caller to collect stats - // for this read operation. - // // This method should be thread safe. virtual uint64_t preadv( uint64_t /*offset*/, const std::vector>& /*buffers*/, - filesystems::File::IoStats* stats = nullptr) const; + const FileIoContext& context = {}) const; // Vectorized read API. Implementations can coalesce and parallelize. // The offsets don't need to be sorted. @@ -93,30 +135,22 @@ class ReadFile { // by the preadv. // Returns the total number of bytes read, which might be different than the // sum of all buffer sizes (for example, if coalescing was used). - // - // 'stats' is an IoStatistics pointer passed in by the caller to collect stats - // for this read operation. - // // This method should be thread safe. virtual uint64_t preadv( folly::Range regions, folly::Range iobufs, - filesystems::File::IoStats* stats = nullptr) const; + const FileIoContext& context = {}) const; /// Like preadv but may execute asynchronously and returns the read size or /// exception via SemiFuture. Use hasPreadvAsync() to check if the /// implementation is in fact asynchronous. - /// - /// 'stats' is an IoStatistics pointer passed in by the caller to collect - /// stats for this read operation. - /// /// This method should be thread safe. virtual folly::SemiFuture preadvAsync( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats = nullptr) const { + const FileIoContext& context = {}) const { try { - return folly::SemiFuture(preadv(offset, buffers, stats)); + return folly::SemiFuture(preadv(offset, buffers, context)); } catch (const std::exception& e) { return folly::makeSemiFuture(e); } @@ -240,12 +274,12 @@ class InMemoryReadFile : public ReadFile { uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats = nullptr) const override; + const FileIoContext& context = {}) const override; std::string pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats) const override; + const FileIoContext& context = {}) const override; uint64_t size() const final { return file_.size(); @@ -259,6 +293,7 @@ class InMemoryReadFile : public ReadFile { void setShouldCoalesce(bool shouldCoalesce) { shouldCoalesce_ = shouldCoalesce; } + bool shouldCoalesce() const final { return shouldCoalesce_; } @@ -311,19 +346,19 @@ class LocalReadFile final : public ReadFile { uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats = nullptr) const final; + const FileIoContext& context = {}) const final; uint64_t size() const final; uint64_t preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats = nullptr) const final; + const FileIoContext& context = {}) const final; folly::SemiFuture preadvAsync( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats = nullptr) const override; + const FileIoContext& context = {}) const override; bool hasPreadvAsync() const override { return executor_ != nullptr; diff --git a/velox/common/file/FileIoTracer.cpp b/velox/common/file/FileIoTracer.cpp new file mode 100644 index 00000000000..b8b4a418ceb --- /dev/null +++ b/velox/common/file/FileIoTracer.cpp @@ -0,0 +1,115 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/file/FileIoTracer.h" + +#include + +namespace facebook::velox { + +std::string IoTag::toString() const { + if (parent == nullptr) { + return name; + } + return parent->toString() + " -> " + name; +} + +size_t IoTag::depth() const { + size_t d = 1; + const IoTag* p = parent; + while (p != nullptr) { + ++d; + p = p->parent; + } + return d; +} + +const IoTag*& threadIoTag() { + thread_local const IoTag* tag = nullptr; + return tag; +} + +ScopedIoTag::ScopedIoTag(std::string_view name) : tag_(name, threadIoTag()) { + threadIoTag() = &tag_; +} + +ScopedIoTag::~ScopedIoTag() { + threadIoTag() = tag_.parent; +} + +const IoTag& ScopedIoTag::tag() const { + return tag_; +} + +std::string toString(IoType type) { + switch (type) { + case IoType::Read: + return "Read"; + case IoType::AsyncRead: + return "AsyncRead"; + case IoType::Write: + return "Write"; + case IoType::AsyncWrite: + return "AsyncWrite"; + } + return "Unknown"; +} + +std::ostream& operator<<(std::ostream& os, IoType type) { + return os << toString(type); +} + +InMemoryFileIoTracer::InMemoryFileIoTracer(std::vector& records) + : records_(records) { + records_.reserve(4096); +} + +std::shared_ptr InMemoryFileIoTracer::create( + std::vector& records) { + return std::shared_ptr( + new InMemoryFileIoTracer(records)); +} + +void InMemoryFileIoTracer::record( + IoType type, + uint64_t offset, + uint64_t length) { + IoRecord record; + record.type = type; + record.offset = offset; + record.length = length; + const IoTag* tag = threadIoTag(); + record.tag = tag != nullptr ? tag->toString() : ""; + + std::lock_guard lock(mutex_); + records_.push_back(std::move(record)); +} + +void InMemoryFileIoTracer::finish() { + LOG(INFO) << "InMemoryFileIoTracer recorded " << records_.size() + << " IO operations"; +} + +std::string IoRecord::toString() const { + std::string result = velox::toString(type); + result += " [" + std::to_string(offset) + ", " + std::to_string(length) + "]"; + if (!tag.empty()) { + result += " " + tag; + } + return result; +} + +} // namespace facebook::velox diff --git a/velox/common/file/FileIoTracer.h b/velox/common/file/FileIoTracer.h new file mode 100644 index 00000000000..a81fb7f6cb5 --- /dev/null +++ b/velox/common/file/FileIoTracer.h @@ -0,0 +1,151 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace facebook::velox { + +/// Represents a tag in the IO call stack. Tags are linked together to form +/// a stack trace of IO operations, useful for debugging and tracing. +struct IoTag { + /// The name/label for this tag in the call stack. + std::string name; + + /// Pointer to the parent tag in the call stack. nullptr if this is the root. + const IoTag* parent{nullptr}; + + IoTag() = default; + + explicit IoTag(std::string_view tagName, const IoTag* parentTag = nullptr) + : name(tagName), parent(parentTag) {} + + /// Returns the full call stack as a string, with tags separated by " -> ". + /// Example: "TableScan -> ColumnReader -> PrefixEncoding" + std::string toString() const; + + /// Returns the depth of this tag in the call stack (1-based). + size_t depth() const; +}; + +/// Thread-local IO tag for the current call stack. +/// Use ScopedIoTag to automatically push/pop tags. +const IoTag*& threadIoTag(); + +/// RAII helper to push/pop IoTag onto the thread-local stack automatically. +/// Usage: +/// ScopedIoTag scopedTag("ColumnReader"); +/// // threadIoTag() now points to the new tag +/// // When scopedTag goes out of scope, threadIoTag() is restored +class ScopedIoTag { + public: + explicit ScopedIoTag(std::string_view name); + + ~ScopedIoTag(); + + ScopedIoTag(const ScopedIoTag&) = delete; + ScopedIoTag& operator=(const ScopedIoTag&) = delete; + + const IoTag& tag() const; + + private: + IoTag tag_; +}; + +/// Type of IO operation. +enum class IoType { + Read, + AsyncRead, + Write, + AsyncWrite, +}; + +/// Returns a string representation of the IoType. +std::string toString(IoType type); + +/// Stream output operator for IoType. +std::ostream& operator<<(std::ostream& os, IoType type); + +/// Abstract tracer for file IO operations. Implementations can record IO +/// operations for debugging, performance analysis, or monitoring purposes. +/// The IO tag is captured from thread-local storage via threadIoTag(). +class FileIoTracer { + public: + virtual ~FileIoTracer() = default; + + /// Records an IO operation. + /// @param type The type of IO operation (read, asyncRead, write, asyncWrite). + /// @param offset The byte offset in the file where the IO starts. + /// @param length The number of bytes involved in the IO operation. + /// Note: The IO tag is captured from threadIoTag() by the implementation. + virtual void record(IoType type, uint64_t offset, uint64_t length) = 0; + + /// Called when all IO operations are complete. + /// Implementations can use this to flush buffers, finalize reports, etc. + virtual void finish() = 0; +}; + +/// Record of a single IO operation captured by InMemoryFileIoTracer. +struct IoRecord { + IoType type; + uint64_t offset{}; + uint64_t length{}; + std::string tag; + + /// Returns a string representation of the record. + /// Format: " [, ] " + std::string toString() const; +}; + +/// In-memory file IO tracer that records IO operations to a vector. +/// Useful for testing and debugging. Thread-safe. +class InMemoryFileIoTracer : public FileIoTracer { + public: + /// Creates a new InMemoryFileIoTracer with a reference to a vector for + /// recording. + /// @param records The vector to record IO operations to. + static std::shared_ptr create( + std::vector& records); + + void record(IoType type, uint64_t offset, uint64_t length) override; + + void finish() override; + + private: + explicit InMemoryFileIoTracer(std::vector& records); + + std::mutex mutex_; + std::vector& records_; +}; + +} // namespace facebook::velox + +/// fmt::formatter specialization for IoType. +template <> +struct fmt::formatter : fmt::formatter { + auto format(facebook::velox::IoType type, fmt::format_context& ctx) const { + return fmt::formatter::format( + facebook::velox::toString(type), ctx); + } +}; diff --git a/velox/common/file/FileSystems.cpp b/velox/common/file/FileSystems.cpp index 4c420afd4e4..ed7658e572c 100644 --- a/velox/common/file/FileSystems.cpp +++ b/velox/common/file/FileSystems.cpp @@ -17,6 +17,7 @@ #include "velox/common/file/FileSystems.h" #include #include +#include #include "velox/common/base/Exceptions.h" #include "velox/common/file/File.h" @@ -90,7 +91,7 @@ class LocalFileSystem : public FileSystem { std::max( 1, static_cast( - std::thread::hardware_concurrency() / 2)), + folly::available_concurrency() / 2)), std::make_shared( "LocalReadahead")) : nullptr) {} diff --git a/velox/common/file/FileSystems.h b/velox/common/file/FileSystems.h index 337478be63b..e5c74fb5e26 100644 --- a/velox/common/file/FileSystems.h +++ b/velox/common/file/FileSystems.h @@ -28,11 +28,9 @@ namespace facebook::velox { namespace config { class ConfigBase; } +class IoStats; class ReadFile; class WriteFile; -namespace filesystems::File { -class IoStats; -} } // namespace facebook::velox namespace facebook::velox::filesystems { @@ -48,10 +46,10 @@ struct FileOptions { /// etc. static constexpr folly::StringPiece kFileCreateConfig{"file-create-config"}; - std::unordered_map values; + std::unordered_map values{}; memory::MemoryPool* pool{nullptr}; /// If specified then can be trusted to be the file size. - std::optional fileSize; + std::optional fileSize{}; /// Whether to create parent directories if they don't exist. /// @@ -74,7 +72,7 @@ struct FileOptions { std::optional> properties{ std::nullopt}; - File::IoStats* stats{nullptr}; + IoStats* stats{nullptr}; /// A raw string that client can encode as anything they want to describe the /// file. For example, extraFileInfo can contain serialized file descriptors @@ -88,6 +86,10 @@ struct FileOptions { /// A token provider that can be used to get tokens for accessing the file. std::shared_ptr tokenProvider{nullptr}; + + /// File read operations metadata that can be passed to the underlying file + /// system for tracking and logging purposes. + folly::F14FastMap fileReadOps{}; }; /// Defines directory options @@ -110,49 +112,6 @@ struct FileSystemOptions { bool readAheadEnabled{false}; }; -/// Free form statistics for a file system. The keys are arbitrary strings, and -/// values are RuntimeMetric. The underlying filesystem implementation can use -/// this class to record observability about filesystem operations. -namespace File { -class IoStats { - public: - IoStats() = default; - - void addCounter(const std::string& name, RuntimeCounter counter) { - auto locked = stats_.wlock(); - auto it = locked->find(name); - if (it == locked->end()) { - auto [ptr, inserted] = locked->emplace(name, RuntimeMetric(counter.unit)); - VELOX_CHECK(inserted); - ptr->second.addValue(counter.value); - } else { - VELOX_CHECK_EQ(it->second.unit, counter.unit); - it->second.addValue(counter.value); - } - } - - void merge(const IoStats& other) { - auto otherStats = other.stats(); - auto locked = stats_.wlock(); - for (const auto& [name, metric] : otherStats) { - auto it = locked->find(name); - if (it == locked->end()) { - locked->emplace(name, metric); - } else { - it->second.merge(metric); - } - } - } - - folly::F14FastMap stats() const { - return stats_.copy(); - } - - private: - folly::Synchronized> stats_; -}; -} // namespace File - /// An abstract FileSystem class FileSystem { public: diff --git a/velox/common/file/Utils.cpp b/velox/common/file/FileUtils.cpp similarity index 96% rename from velox/common/file/Utils.cpp rename to velox/common/file/FileUtils.cpp index bd9bfea9dcb..74a1a7da39e 100644 --- a/velox/common/file/Utils.cpp +++ b/velox/common/file/FileUtils.cpp @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "velox/common/file/Utils.h" +#include "velox/common/file/FileUtils.h" #include "velox/common/base/Exceptions.h" namespace facebook::velox::file::utils { diff --git a/velox/common/file/Utils.h b/velox/common/file/FileUtils.h similarity index 100% rename from velox/common/file/Utils.h rename to velox/common/file/FileUtils.h diff --git a/velox/common/file/Region.h b/velox/common/file/Region.h index ae2ef9c8a2b..5d588479b76 100644 --- a/velox/common/file/Region.h +++ b/velox/common/file/Region.h @@ -17,8 +17,13 @@ #pragma once #include +#include #include +#include + +#include "velox/common/base/SuccinctPrinter.h" + namespace facebook::velox::common { /// Defines a disk region to read. @@ -35,6 +40,14 @@ struct Region { return offset < other.offset || (offset == other.offset && length < other.length); } + + std::string toString() const { + return fmt::format( + "Region{{offset: {}, length: {}, label: {}}}", + offset, + succinctBytes(length), + label); + } }; } // namespace facebook::velox::common diff --git a/velox/common/file/tests/CMakeLists.txt b/velox/common/file/tests/CMakeLists.txt index 47b3c11e3d2..2cf06039bc6 100644 --- a/velox/common/file/tests/CMakeLists.txt +++ b/velox/common/file/tests/CMakeLists.txt @@ -13,10 +13,23 @@ # limitations under the License. add_library(velox_file_test_utils TestUtils.cpp FaultyFile.cpp FaultyFileSystem.cpp) +velox_add_test_headers( + velox_file_test_utils + FaultyFile.h + FaultyFileSystem.h + FaultyFileSystemOperations.h + TestUtils.h +) target_link_libraries(velox_file_test_utils PUBLIC velox_file) -add_executable(velox_file_test FileTest.cpp FileInputStreamTest.cpp UtilsTest.cpp) +add_executable( + velox_file_test + FileTest.cpp + FileInputStreamTest.cpp + FileIoTracerTest.cpp + FileUtilsTest.cpp +) add_test(velox_file_test velox_file_test) target_link_libraries( velox_file_test @@ -24,7 +37,7 @@ target_link_libraries( velox_buffer velox_file velox_file_test_utils - velox_temp_path + velox_test_util GTest::gmock GTest::gtest GTest::gtest_main diff --git a/velox/common/file/tests/FaultyFile.cpp b/velox/common/file/tests/FaultyFile.cpp index 17897fa9921..23f59ea886f 100644 --- a/velox/common/file/tests/FaultyFile.cpp +++ b/velox/common/file/tests/FaultyFile.cpp @@ -34,7 +34,7 @@ std::string_view FaultyReadFile::pread( uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats) const { + const FileIoContext& context) const { if (injectionHook_ != nullptr) { FaultFileReadOperation op(path_, offset, length, buf); injectionHook_(&op); @@ -42,13 +42,13 @@ std::string_view FaultyReadFile::pread( return std::string_view(static_cast(op.buf), op.length); } } - return delegatedFile_->pread(offset, length, buf, stats); + return delegatedFile_->pread(offset, length, buf, context); } uint64_t FaultyReadFile::preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats) const { + const FileIoContext& context) const { if (injectionHook_ != nullptr) { FaultFileReadvOperation op(path_, offset, buffers); injectionHook_(&op); @@ -56,16 +56,16 @@ uint64_t FaultyReadFile::preadv( return op.readBytes; } } - return delegatedFile_->preadv(offset, buffers, stats); + return delegatedFile_->preadv(offset, buffers, context); } folly::SemiFuture FaultyReadFile::preadvAsync( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats) const { + const FileIoContext& context) const { // TODO: add fault injection for async read later. if (delegatedFile_->hasPreadvAsync() || executor_ == nullptr) { - return delegatedFile_->preadvAsync(offset, buffers, stats); + return delegatedFile_->preadvAsync(offset, buffers, context); } auto promise = std::make_unique>(); folly::SemiFuture future = promise->getSemiFuture(); @@ -73,9 +73,9 @@ folly::SemiFuture FaultyReadFile::preadvAsync( _promise = std::move(promise), _offset = offset, _buffers = buffers, - _stats = stats]() { + _context = context]() { auto delegateFuture = - delegatedFile_->preadvAsync(_offset, _buffers, _stats); + delegatedFile_->preadvAsync(_offset, _buffers, _context); _promise->setValue(delegateFuture.wait().value()); }); return future; diff --git a/velox/common/file/tests/FaultyFile.h b/velox/common/file/tests/FaultyFile.h index 2b4818bd7a1..7986064b9e4 100644 --- a/velox/common/file/tests/FaultyFile.h +++ b/velox/common/file/tests/FaultyFile.h @@ -29,7 +29,7 @@ class FaultyReadFile : public ReadFile { FileFaultInjectionHook injectionHook, folly::Executor* executor); - ~FaultyReadFile() override{}; + ~FaultyReadFile() override {} uint64_t size() const override { return delegatedFile_->size(); @@ -39,12 +39,12 @@ class FaultyReadFile : public ReadFile { uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats = nullptr) const override; + const FileIoContext& context = {}) const override; uint64_t preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats = nullptr) const override; + const FileIoContext& context = {}) const override; uint64_t memoryUsage() const override { return delegatedFile_->memoryUsage(); @@ -72,7 +72,7 @@ class FaultyReadFile : public ReadFile { folly::SemiFuture preadvAsync( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats = nullptr) const override; + const FileIoContext& context = {}) const override; private: const std::string path_; @@ -88,7 +88,7 @@ class FaultyWriteFile : public WriteFile { std::shared_ptr delegatedFile, FileFaultInjectionHook injectionHook); - ~FaultyWriteFile() override{}; + ~FaultyWriteFile() override {} void append(std::string_view data) override; diff --git a/velox/common/file/tests/FaultyFileSystem.h b/velox/common/file/tests/FaultyFileSystem.h index 1dbd698df66..c14cc7201ee 100644 --- a/velox/common/file/tests/FaultyFileSystem.h +++ b/velox/common/file/tests/FaultyFileSystem.h @@ -176,7 +176,7 @@ class FaultyFileSystem : public FileSystem { mutable std::mutex mu_; std::optional fileInjections_; std::optional fsInjections_; - folly::Executor* executor_; + folly::Executor* executor_{nullptr}; }; /// Registers the faulty filesystem. diff --git a/velox/common/file/tests/FileInputStreamTest.cpp b/velox/common/file/tests/FileInputStreamTest.cpp index 5727b302571..4d9bc680236 100644 --- a/velox/common/file/tests/FileInputStreamTest.cpp +++ b/velox/common/file/tests/FileInputStreamTest.cpp @@ -19,11 +19,12 @@ #include "velox/common/file/FileInputStream.h" #include "velox/common/file/FileSystems.h" #include "velox/common/memory/MmapAllocator.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include using namespace facebook::velox; +using namespace facebook::velox::common::testutil; using namespace facebook::velox::memory; class FileInputStreamTest : public testing::Test { @@ -42,7 +43,7 @@ class FileInputStreamTest : public testing::Test { mmapAllocator_ = static_cast(memoryManager_->allocator()); pool_ = memoryManager_->addLeafPool("ByteStreamTest"); rng_.seed(124); - tempDirPath_ = exec::test::TempDirectoryPath::create(); + tempDirPath_ = TempDirectoryPath::create(); fs_ = filesystems::getFileSystem(tempDirPath_->getPath(), nullptr); } @@ -70,7 +71,7 @@ class FileInputStreamTest : public testing::Test { MmapAllocator* mmapAllocator_; std::shared_ptr pool_; std::atomic_uint64_t fileId_{0}; - std::shared_ptr tempDirPath_; + std::shared_ptr tempDirPath_; std::shared_ptr fs_; }; diff --git a/velox/common/file/tests/FileIoTracerTest.cpp b/velox/common/file/tests/FileIoTracerTest.cpp new file mode 100644 index 00000000000..28429f0122d --- /dev/null +++ b/velox/common/file/tests/FileIoTracerTest.cpp @@ -0,0 +1,346 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/file/FileIoTracer.h" + +#include +#include +#include + +#include "gtest/gtest.h" + +using namespace facebook::velox; + +class FileIoTracerTest : public ::testing::Test { + protected: + void SetUp() override { + // Reset thread-local tag before each test + threadIoTag() = nullptr; + } + + void TearDown() override { + // Ensure thread-local tag is reset after each test + threadIoTag() = nullptr; + } +}; + +TEST_F(FileIoTracerTest, ioTag) { + { + IoTag tag("TestTag"); + EXPECT_EQ(tag.name, "TestTag"); + EXPECT_EQ(tag.parent, nullptr); + EXPECT_EQ(tag.toString(), "TestTag"); + EXPECT_EQ(tag.depth(), 1); + } + + { + IoTag parent("Parent"); + IoTag child("Child", &parent); + + EXPECT_EQ(child.name, "Child"); + EXPECT_EQ(child.parent, &parent); + EXPECT_EQ(child.toString(), "Parent -> Child"); + EXPECT_EQ(child.depth(), 2); + } + + { + IoTag tag1("Level1"); + IoTag tag2("Level2", &tag1); + IoTag tag3("Level3", &tag2); + IoTag tag4("Level4", &tag3); + + EXPECT_EQ(tag4.toString(), "Level1 -> Level2 -> Level3 -> Level4"); + EXPECT_EQ(tag4.depth(), 4); + EXPECT_EQ(tag3.depth(), 3); + EXPECT_EQ(tag2.depth(), 2); + EXPECT_EQ(tag1.depth(), 1); + } +} + +TEST_F(FileIoTracerTest, threadIoTag) { + { + EXPECT_EQ(threadIoTag(), nullptr); + } + + { + EXPECT_EQ(threadIoTag(), nullptr); + + { + ScopedIoTag tag1("Tag1"); + EXPECT_NE(threadIoTag(), nullptr); + EXPECT_EQ(threadIoTag()->name, "Tag1"); + EXPECT_EQ(threadIoTag()->parent, nullptr); + EXPECT_EQ(threadIoTag()->toString(), "Tag1"); + } + + EXPECT_EQ(threadIoTag(), nullptr); + } + + { + EXPECT_EQ(threadIoTag(), nullptr); + + { + ScopedIoTag tag1("TableScan"); + EXPECT_EQ(threadIoTag()->toString(), "TableScan"); + EXPECT_EQ(threadIoTag()->depth(), 1); + + { + ScopedIoTag tag2("ColumnReader"); + EXPECT_EQ(threadIoTag()->toString(), "TableScan -> ColumnReader"); + EXPECT_EQ(threadIoTag()->depth(), 2); + + { + ScopedIoTag tag3("PrefixEncoding"); + EXPECT_EQ( + threadIoTag()->toString(), + "TableScan -> ColumnReader -> PrefixEncoding"); + EXPECT_EQ(threadIoTag()->depth(), 3); + } + + EXPECT_EQ(threadIoTag()->toString(), "TableScan -> ColumnReader"); + EXPECT_EQ(threadIoTag()->depth(), 2); + } + + EXPECT_EQ(threadIoTag()->toString(), "TableScan"); + EXPECT_EQ(threadIoTag()->depth(), 1); + } + + EXPECT_EQ(threadIoTag(), nullptr); + } + + { + ScopedIoTag tag("TestTag"); + EXPECT_EQ(tag.tag().name, "TestTag"); + EXPECT_EQ(tag.tag().parent, nullptr); + } +} + +TEST_F(FileIoTracerTest, threadLocalIsolation) { + // Verify that thread-local tags are isolated between threads + std::atomic_bool stop{false}; + + std::thread thread1([&]() { + ScopedIoTag tag("Thread1Tag"); + while (!stop) { + EXPECT_EQ(threadIoTag()->name, "Thread1Tag"); + std::this_thread::sleep_for(std::chrono::microseconds(100)); + } + }); + + std::thread thread2([&]() { + ScopedIoTag tag("Thread2Tag"); + while (!stop) { + EXPECT_EQ(threadIoTag()->name, "Thread2Tag"); + std::this_thread::sleep_for(std::chrono::microseconds(100)); + } + }); + + std::this_thread::sleep_for(std::chrono::seconds(2)); + stop = true; + + thread1.join(); + thread2.join(); +} + +TEST_F(FileIoTracerTest, inMemoryFileIoTracer) { + { + std::vector records; + auto tracer = InMemoryFileIoTracer::create(records); + + tracer->record(IoType::Read, 100, 200); + ASSERT_EQ(records.size(), 1); + EXPECT_EQ(records[0].type, IoType::Read); + EXPECT_EQ(records[0].offset, 100); + EXPECT_EQ(records[0].length, 200); + EXPECT_EQ(records[0].tag, ""); + tracer->finish(); + ASSERT_EQ(records.size(), 1); + EXPECT_EQ(records[0].type, IoType::Read); + EXPECT_EQ(records[0].offset, 100); + EXPECT_EQ(records[0].length, 200); + EXPECT_EQ(records[0].tag, ""); + } + + { + std::vector records; + auto tracer = InMemoryFileIoTracer::create(records); + + { + ScopedIoTag tag1("TableScan"); + tracer->record(IoType::Read, 0, 100); + + { + ScopedIoTag tag2("ColumnReader"); + tracer->record(IoType::AsyncRead, 100, 200); + } + + tracer->record(IoType::Write, 300, 50); + } + + tracer->record(IoType::AsyncWrite, 400, 75); + + ASSERT_EQ(records.size(), 4); + + EXPECT_EQ(records[0].type, IoType::Read); + EXPECT_EQ(records[0].offset, 0); + EXPECT_EQ(records[0].length, 100); + EXPECT_EQ(records[0].tag, "TableScan"); + + EXPECT_EQ(records[1].type, IoType::AsyncRead); + EXPECT_EQ(records[1].offset, 100); + EXPECT_EQ(records[1].length, 200); + EXPECT_EQ(records[1].tag, "TableScan -> ColumnReader"); + + EXPECT_EQ(records[2].type, IoType::Write); + EXPECT_EQ(records[2].offset, 300); + EXPECT_EQ(records[2].length, 50); + EXPECT_EQ(records[2].tag, "TableScan"); + + EXPECT_EQ(records[3].type, IoType::AsyncWrite); + EXPECT_EQ(records[3].offset, 400); + EXPECT_EQ(records[3].length, 75); + EXPECT_EQ(records[3].tag, ""); + } +} + +TEST_F(FileIoTracerTest, ioType) { + // Verify enum values exist and are distinct + EXPECT_NE(IoType::Read, IoType::AsyncRead); + EXPECT_NE(IoType::Read, IoType::Write); + EXPECT_NE(IoType::Read, IoType::AsyncWrite); + EXPECT_NE(IoType::AsyncRead, IoType::Write); + EXPECT_NE(IoType::AsyncRead, IoType::AsyncWrite); + EXPECT_NE(IoType::Write, IoType::AsyncWrite); +} + +TEST_F(FileIoTracerTest, ioRecord) { + { + IoRecord record; + record.type = IoType::Read; + record.offset = 1024; + record.length = 4096; + record.tag = "TableScan -> ColumnReader"; + + EXPECT_EQ(record.toString(), "Read [1024, 4096] TableScan -> ColumnReader"); + } + + { + IoRecord record; + record.type = IoType::AsyncRead; + record.offset = 0; + record.length = 512; + record.tag = ""; + + EXPECT_EQ(record.toString(), "AsyncRead [0, 512]"); + } + + { + IoRecord record; + record.type = IoType::Write; + record.offset = 100; + record.length = 200; + record.tag = "Writer"; + + EXPECT_EQ(record.toString(), "Write [100, 200] Writer"); + } + + { + IoRecord record; + record.type = IoType::AsyncWrite; + record.offset = 500; + record.length = 1000; + record.tag = "AsyncWriter"; + + EXPECT_EQ(record.toString(), "AsyncWrite [500, 1000] AsyncWriter"); + } +} + +TEST_F(FileIoTracerTest, inMemoryFileIoTracerFuzz) { + constexpr int kNumThreads = 8; + constexpr int kMaxRecordsPerThread = 1'000; + + std::vector sharedRecords; + auto tracer = InMemoryFileIoTracer::create(sharedRecords); + + std::vector threads; + std::vector> perThreadRecords(kNumThreads); + std::vector ioTypes = { + IoType::Read, IoType::AsyncRead, IoType::Write, IoType::AsyncWrite}; + + threads.reserve(kNumThreads); + for (int threadIdx = 0; threadIdx < kNumThreads; ++threadIdx) { + threads.emplace_back([&, threadIdx]() { + std::mt19937 rng(threadIdx); + std::uniform_int_distribution countDist(1, kMaxRecordsPerThread); + std::uniform_int_distribution offsetDist(0, 1000000); + std::uniform_int_distribution lengthDist(1, 10000); + std::uniform_int_distribution typeDist(0, 3); + + const std::string tagName = "Thread" + std::to_string(threadIdx); + ScopedIoTag tag(tagName); + + const int numRecords = countDist(rng); + for (int i = 0; i < numRecords; ++i) { + IoRecord localRecord; + localRecord.type = ioTypes[typeDist(rng)]; + localRecord.offset = offsetDist(rng); + localRecord.length = lengthDist(rng); + localRecord.tag = tagName; + + perThreadRecords[threadIdx].push_back(localRecord); + tracer->record( + localRecord.type, localRecord.offset, localRecord.length); + } + }); + } + + for (auto& thread : threads) { + thread.join(); + } + + tracer->finish(); + + size_t expectedTotalRecords = 0; + for (const auto& threadRecords : perThreadRecords) { + expectedTotalRecords += threadRecords.size(); + } + EXPECT_EQ(sharedRecords.size(), expectedTotalRecords); + + std::unordered_map> recordsByTag; + for (const auto& record : sharedRecords) { + recordsByTag[record.tag].push_back(record); + } + + for (int threadIdx = 0; threadIdx < kNumThreads; ++threadIdx) { + const std::string tagName = "Thread" + std::to_string(threadIdx); + const auto& expected = perThreadRecords[threadIdx]; + const auto& actual = recordsByTag[tagName]; + + ASSERT_EQ(actual.size(), expected.size()) + << "Thread " << threadIdx << " record count mismatch"; + + for (size_t i = 0; i < expected.size(); ++i) { + EXPECT_EQ(actual[i].type, expected[i].type) + << "Thread " << threadIdx << " record " << i << " type mismatch"; + EXPECT_EQ(actual[i].offset, expected[i].offset) + << "Thread " << threadIdx << " record " << i << " offset mismatch"; + EXPECT_EQ(actual[i].length, expected[i].length) + << "Thread " << threadIdx << " record " << i << " length mismatch"; + EXPECT_EQ(actual[i].tag, expected[i].tag) + << "Thread " << threadIdx << " record " << i << " tag mismatch"; + } + } +} diff --git a/velox/common/file/tests/FileTest.cpp b/velox/common/file/tests/FileTest.cpp index 6194fbcc35b..7da4b360356 100644 --- a/velox/common/file/tests/FileTest.cpp +++ b/velox/common/file/tests/FileTest.cpp @@ -16,17 +16,19 @@ #include #include +#include #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/File.h" #include "velox/common/file/FileSystems.h" #include "velox/common/file/tests/FaultyFileSystem.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" -#include "velox/exec/tests/utils/TempFilePath.h" +#include "velox/common/testutil/TempDirectoryPath.h" +#include "velox/common/testutil/TempFilePath.h" #include "gtest/gtest.h" using namespace facebook::velox; +using namespace facebook::velox::common::testutil; using facebook::velox::common::Region; using namespace facebook::velox::tests::utils; @@ -49,6 +51,17 @@ void writeData(WriteFile* writeFile, bool useIOBuf = false) { } } +TEST(FileIoContextTest, defaultCacheableIsFalse) { + FileIoContext defaultContext; + EXPECT_FALSE(defaultContext.cacheable); + + FileIoContext explicitContext(nullptr); + EXPECT_FALSE(explicitContext.cacheable); + + FileIoContext cacheableContext(nullptr, {}, nullptr, true); + EXPECT_TRUE(cacheableContext.cacheable); +} + void writeDataWithOffset(WriteFile* writeFile) { ASSERT_EQ(writeFile->size(), 0); writeFile->truncate(15 + kOneMB); @@ -168,8 +181,9 @@ TEST(InMemoryFile, preadv) { std::vector values; values.reserve(iobufs.size()); for (auto& iobuf : iobufs) { - values.push_back(std::string{ - reinterpret_cast(iobuf.data()), iobuf.length()}); + values.push_back( + std::string{ + reinterpret_cast(iobuf.data()), iobuf.length()}); } EXPECT_EQ(expected, values); @@ -187,9 +201,7 @@ class LocalFileTest : public ::testing::TestWithParam { const bool useFaultyFs_; const std::unique_ptr executor_ = std::make_unique( - std::max( - 1, - static_cast(std::thread::hardware_concurrency() / 2)), + std::max(1, static_cast(folly::available_concurrency() / 2)), std::make_shared( "LocalFileReadAheadTest")); }; @@ -206,7 +218,7 @@ TEST_P(LocalFileTest, writeAndRead) { for (auto testData : testSettings) { SCOPED_TRACE(testData.debugString()); - auto tempFile = exec::test::TempFilePath::create(useFaultyFs_); + auto tempFile = TempFilePath::create(useFaultyFs_); const auto& filename = tempFile->getPath(); auto fs = filesystems::getFileSystem(filename, {}); fs->remove(filename); @@ -234,7 +246,7 @@ TEST_P(LocalFileTest, writeAndRead) { } TEST_P(LocalFileTest, viaRegistry) { - auto tempFile = exec::test::TempFilePath::create(useFaultyFs_); + auto tempFile = TempFilePath::create(useFaultyFs_); const auto& filename = tempFile->getPath(); auto fs = filesystems::getFileSystem(filename, {}); fs->remove(filename); @@ -250,7 +262,7 @@ TEST_P(LocalFileTest, viaRegistry) { } TEST_P(LocalFileTest, rename) { - const auto tempFolder = ::exec::test::TempDirectoryPath::create(useFaultyFs_); + const auto tempFolder = TempDirectoryPath::create(useFaultyFs_); const auto a = fmt::format("{}/a", tempFolder->getPath()); const auto b = fmt::format("{}/b", tempFolder->getPath()); const auto newA = fmt::format("{}/newA", tempFolder->getPath()); @@ -277,7 +289,7 @@ TEST_P(LocalFileTest, rename) { } TEST_P(LocalFileTest, exists) { - auto tempFolder = ::exec::test::TempDirectoryPath::create(useFaultyFs_); + auto tempFolder = TempDirectoryPath::create(useFaultyFs_); auto a = fmt::format("{}/a", tempFolder->getPath()); auto b = fmt::format("{}/b", tempFolder->getPath()); auto localFs = filesystems::getFileSystem(a, nullptr); @@ -296,7 +308,7 @@ TEST_P(LocalFileTest, exists) { } TEST_P(LocalFileTest, isDirectory) { - auto tempFolder = ::exec::test::TempDirectoryPath::create(useFaultyFs_); + auto tempFolder = TempDirectoryPath::create(useFaultyFs_); auto a = fmt::format("{}/a", tempFolder->getPath()); auto localFs = filesystems::getFileSystem(a, nullptr); auto writeFile = localFs->openFileForWrite(a); @@ -305,7 +317,7 @@ TEST_P(LocalFileTest, isDirectory) { } TEST_P(LocalFileTest, list) { - const auto tempFolder = ::exec::test::TempDirectoryPath::create(useFaultyFs_); + const auto tempFolder = TempDirectoryPath::create(useFaultyFs_); const auto a = fmt::format("{}/1", tempFolder->getPath()); const auto b = fmt::format("{}/2", tempFolder->getPath()); auto localFs = filesystems::getFileSystem(a, nullptr); @@ -328,7 +340,7 @@ TEST_P(LocalFileTest, readFileDestructor) { if (useFaultyFs_) { return; } - auto tempFile = exec::test::TempFilePath::create(useFaultyFs_); + auto tempFile = TempFilePath::create(useFaultyFs_); const auto& filename = tempFile->getPath(); auto fs = filesystems::getFileSystem(filename, {}); fs->remove(filename); @@ -362,7 +374,7 @@ TEST_P(LocalFileTest, readFileDestructor) { } TEST_P(LocalFileTest, mkdirFailIfPresent) { - auto tempFolder = exec::test::TempDirectoryPath::create(useFaultyFs_); + auto tempFolder = TempDirectoryPath::create(useFaultyFs_); std::string path = tempFolder->getPath(); auto localFs = filesystems::getFileSystem(path, nullptr); @@ -384,7 +396,7 @@ TEST_P(LocalFileTest, mkdirFailIfPresent) { } TEST_P(LocalFileTest, mkdir) { - auto tempFolder = exec::test::TempDirectoryPath::create(useFaultyFs_); + auto tempFolder = TempDirectoryPath::create(useFaultyFs_); std::string path = tempFolder->getPath(); auto localFs = filesystems::getFileSystem(path, nullptr); @@ -409,7 +421,7 @@ TEST_P(LocalFileTest, mkdir) { } TEST_P(LocalFileTest, rmdir) { - auto tempFolder = exec::test::TempDirectoryPath::create(useFaultyFs_); + auto tempFolder = TempDirectoryPath::create(useFaultyFs_); std::string path = tempFolder->getPath(); auto localFs = filesystems::getFileSystem(path, nullptr); @@ -442,7 +454,7 @@ TEST_P(LocalFileTest, rmdir) { } TEST_P(LocalFileTest, fileNotFound) { - auto tempFolder = exec::test::TempDirectoryPath::create(useFaultyFs_); + auto tempFolder = TempDirectoryPath::create(useFaultyFs_); auto path = fmt::format("{}/file", tempFolder->getPath()); auto localFs = filesystems::getFileSystem(path, nullptr); VELOX_ASSERT_RUNTIME_THROW_CODE( @@ -452,7 +464,7 @@ TEST_P(LocalFileTest, fileNotFound) { } TEST_P(LocalFileTest, attributes) { - auto tempFile = exec::test::TempFilePath::create(useFaultyFs_); + auto tempFile = TempFilePath::create(useFaultyFs_); const auto& filename = tempFile->getPath(); auto fs = filesystems::getFileSystem(filename, {}); fs->remove(filename); @@ -486,7 +498,7 @@ class FaultyFsTest : public ::testing::Test { } void SetUp() { - dir_ = exec::test::TempDirectoryPath::create(true); + dir_ = TempDirectoryPath::create(true); fs_ = std::dynamic_pointer_cast( filesystems::getFileSystem(dir_->getPath(), {})); VELOX_CHECK_NOT_NULL(fs_); @@ -535,7 +547,7 @@ class FaultyFsTest : public ::testing::Test { } } - std::shared_ptr dir_; + std::shared_ptr dir_; std::string readFilePath_; std::string writeFilePath_; std::shared_ptr fs_; diff --git a/velox/common/file/tests/UtilsTest.cpp b/velox/common/file/tests/FileUtilsTest.cpp similarity index 91% rename from velox/common/file/tests/UtilsTest.cpp rename to velox/common/file/tests/FileUtilsTest.cpp index a2b2d6ee5a0..eaa2b72cb94 100644 --- a/velox/common/file/tests/UtilsTest.cpp +++ b/velox/common/file/tests/FileUtilsTest.cpp @@ -17,7 +17,8 @@ #include #include -#include "velox/common/file/Utils.h" +#include "velox/common/file/FileUtils.h" +#include "velox/common/file/Region.h" #include "velox/common/file/tests/TestUtils.h" using namespace ::testing; @@ -98,11 +99,12 @@ auto getReader( /* minTailRoom*/ 0); for (size_t i = 1; i < size; ++i) { - head->appendToChain(folly::IOBuf::copyBuffer( - buf.data() + offset + i, - /* size */ 1, - /* headroom */ 0, - /* minTailRoom*/ 0)); + head->appendToChain( + folly::IOBuf::copyBuffer( + buf.data() + offset + i, + /* size */ 1, + /* headroom */ 0, + /* minTailRoom*/ 0)); } return head; } else { @@ -289,3 +291,19 @@ INSTANTIATE_TEST_SUITE_P( ReadToIOBufsTest, ValuesIn( std::vector({false, true}))); + +TEST(RegionTest, toString) { + EXPECT_EQ(Region(0, 0).toString(), "Region{offset: 0, length: 0B, label: }"); + EXPECT_EQ( + Region(100, 256).toString(), + "Region{offset: 100, length: 256B, label: }"); + EXPECT_EQ( + Region(1024, 1024, "test").toString(), + "Region{offset: 1024, length: 1.00KB, label: test}"); + EXPECT_EQ( + Region(0, 1'048'576, "stream").toString(), + "Region{offset: 0, length: 1.00MB, label: stream}"); + EXPECT_EQ( + Region(12345, 1'073'741'824).toString(), + "Region{offset: 12345, length: 1.00GB, label: }"); +} diff --git a/velox/common/file/tests/TestUtils.h b/velox/common/file/tests/TestUtils.h index 43355958f52..f0883e29d11 100644 --- a/velox/common/file/tests/TestUtils.h +++ b/velox/common/file/tests/TestUtils.h @@ -23,4 +23,26 @@ namespace facebook::velox::tests::utils { std::vector iobufsToStrings( const std::vector& iobufs); +/// Wraps InMemoryReadFile with pread call counting for test assertions. +class CountingReadFile : public InMemoryReadFile { + public: + using InMemoryReadFile::InMemoryReadFile; + + std::string_view pread( + uint64_t offset, + uint64_t length, + void* buf, + const FileIoContext& context = {}) const override { + ++numReads_; + return InMemoryReadFile::pread(offset, length, buf, context); + } + + uint64_t numReads() const { + return numReads_; + } + + private: + mutable std::atomic_uint64_t numReads_{0}; +}; + } // namespace facebook::velox::tests::utils diff --git a/CMake/resolve_dependency_modules/cpr/cpr-remove-sancheck.patch b/velox/common/future/CMakeLists.txt similarity index 66% rename from CMake/resolve_dependency_modules/cpr/cpr-remove-sancheck.patch rename to velox/common/future/CMakeLists.txt index 4fca92831a2..c9ca4266c65 100644 --- a/CMake/resolve_dependency_modules/cpr/cpr-remove-sancheck.patch +++ b/velox/common/future/CMakeLists.txt @@ -11,14 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# This hangs on CI and is not needed #9116 ---- a/CMakeLists.txt -+++ b/CMakeLists.txt -@@ -84,7 +84,6 @@ endif() - include(GNUInstallDirs) - include(FetchContent) - include(cmake/code_coverage.cmake) --include(cmake/sanitizer.cmake) - include(cmake/clear_variable.cmake) +velox_add_library(velox_future INTERFACE HEADERS VeloxPromise.h) - # So CMake can find FindMbedTLS.cmake +velox_install_library_headers() diff --git a/velox/common/future/VeloxPromise.h b/velox/common/future/VeloxPromise.h index ba8fbdfbce3..7ac170a2fc1 100644 --- a/velox/common/future/VeloxPromise.h +++ b/velox/common/future/VeloxPromise.h @@ -28,12 +28,15 @@ class VeloxPromise : public folly::Promise { VeloxPromise() : folly::Promise() {} explicit VeloxPromise(const std::string& context) - : folly::Promise(), context_(context) {} + : folly::Promise(), context_(context) { + if (context.empty()) { + LOG(WARNING) + << "PROMISE: VeloxPromise must be constructed with a context."; + } + } - VeloxPromise( - folly::futures::detail::EmptyConstruct, - const std::string& context) noexcept - : folly::Promise(folly::Promise::makeEmpty()), context_(context) {} + explicit VeloxPromise(folly::futures::detail::EmptyConstruct) noexcept + : folly::Promise(folly::Promise::makeEmpty()) {} ~VeloxPromise() { if (!this->isFulfilled()) { @@ -52,8 +55,8 @@ class VeloxPromise : public folly::Promise { return *this; } - static VeloxPromise makeEmpty(const std::string& context = "") noexcept { - return VeloxPromise(folly::futures::detail::EmptyConstruct{}, context); + static VeloxPromise makeEmpty() noexcept { + return VeloxPromise(folly::futures::detail::EmptyConstruct{}); } private: @@ -65,8 +68,14 @@ using ContinuePromise = VeloxPromise; using ContinueFuture = folly::SemiFuture; /// Equivalent of folly's makePromiseContract for VeloxPromise. +/// +/// NOTE: When you already have a valid promise, just call +/// Promise::getSemiFuture() on it to get the future, instead of using this +/// function to overwrite the promise. Overwriting valid promise would cause +/// exception throwing and stack unwinding thus performance issue. See +/// https://github.com/prestodb/presto/issues/26094 for details. static inline std::pair -makeVeloxContinuePromiseContract(const std::string& promiseContext = "") { +makeVeloxContinuePromiseContract(const std::string& promiseContext) { auto p = ContinuePromise(promiseContext); auto f = p.getSemiFuture(); return std::make_pair(std::move(p), std::move(f)); diff --git a/velox/common/fuzzer/CMakeLists.txt b/velox/common/fuzzer/CMakeLists.txt index d2e3663b73d..3391f989d02 100644 --- a/velox/common/fuzzer/CMakeLists.txt +++ b/velox/common/fuzzer/CMakeLists.txt @@ -12,11 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -velox_add_library(velox_common_fuzzer_util Utils.cpp) +velox_add_library(velox_common_fuzzer_util Utils.cpp HEADERS Utils.h) velox_link_libraries(velox_common_fuzzer_util velox_type velox_exception) -velox_add_library(velox_constrained_input_generators ConstrainedGenerators.cpp) +velox_add_library( + velox_constrained_input_generators + ConstrainedGenerators.cpp + HEADERS + ConstrainedGenerators.h +) velox_link_libraries( velox_constrained_input_generators diff --git a/velox/common/fuzzer/ConstrainedGenerators.cpp b/velox/common/fuzzer/ConstrainedGenerators.cpp index 708e5ee8721..4334e8cb3bd 100644 --- a/velox/common/fuzzer/ConstrainedGenerators.cpp +++ b/velox/common/fuzzer/ConstrainedGenerators.cpp @@ -17,6 +17,8 @@ #include "velox/common/fuzzer/ConstrainedGenerators.h" #include #include "velox/common/fuzzer/Utils.h" +#include "velox/common/memory/HashStringAllocator.h" +#include "velox/functions/lib/SetDigest.h" #include "velox/functions/lib/TDigest.h" #include "velox/functions/prestosql/types/BingTileType.h" @@ -282,10 +284,84 @@ variant TDigestInputGenerator::generate() { size_t byteSize = digest.serializedByteSize(); std::string serializedDigest(byteSize, '\0'); digest.serialize(&serializedDigest[0]); - StringView serializedView(serializedDigest.data(), serializedDigest.size()); return variant::create(serializedDigest); } +SetDigestInputGenerator::SetDigestInputGenerator( + size_t seed, + const TypePtr& type, + double nullRatio) + : AbstractInputGenerator(seed, type, nullptr, nullRatio), + pool_(velox::memory::memoryManager()->addLeafPool()), + allocator_(std::make_unique(pool_.get())) { + // SetDigest supports int64_t and StringView + static const std::vector kBaseTypes{BIGINT(), VARCHAR()}; + baseType_ = kBaseTypes[rand(rng_, 0, kBaseTypes.size() - 1)]; +} + +SetDigestInputGenerator::~SetDigestInputGenerator() = default; + +template +variant SetDigestInputGenerator::generateTyped() { + velox::functions::SetDigest digest(allocator_.get()); + + // SetDigest defaults to maxHashes=8192. Usually generate small datasets + // (exact mode), but occasionally generate >8192 values to test approximate + // mode. + int numValues = coinToss(rng_, 0.1) ? rand(rng_, 8500, 10000) + : rand(rng_, 10, 100); + for (int i = 0; i < numValues; ++i) { + int64_t value = rand(rng_); + digest.add(value); + } + + size_t byteSize = digest.estimatedSerializedSize(); + std::string serializedDigest(byteSize, '\0'); + digest.serialize(&serializedDigest[0]); + return variant::create(serializedDigest); +} + +template <> +variant SetDigestInputGenerator::generateTyped() { + velox::functions::SetDigest digest(allocator_.get()); + + int numValues = coinToss(rng_, 0.1) ? rand(rng_, 8500, 10000) + : rand(rng_, 10, 100); + static const std::vector encodings{ + UTF8CharList::ASCII, + UTF8CharList::UNICODE_CASE_SENSITIVE, + UTF8CharList::EXTENDED_UNICODE, + UTF8CharList::MATHEMATICAL_SYMBOLS}; +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + std::wstring_convert, char16_t> converter; +#pragma GCC diagnostic pop + + for (int i = 0; i < numValues; ++i) { + auto size = rand(rng_, 0, 100); + std::string result; + auto str = randString(rng_, size, encodings, result, converter); + digest.add(StringView(str)); + } + + size_t byteSize = digest.estimatedSerializedSize(); + std::string serializedDigest(byteSize, '\0'); + digest.serialize(&serializedDigest[0]); + return variant::create(serializedDigest); +} + +variant SetDigestInputGenerator::generate() { + if (coinToss(rng_, nullRatio_)) { + return variant::null(type_->kind()); + } + + if (baseType_->isBigint()) { + return generateTyped(); + } else { + return generateTyped(); + } +} + // BingTileInputGenerator BingTileInputGenerator::BingTileInputGenerator( @@ -537,7 +613,6 @@ CastVarcharInputGenerator::CastVarcharInputGenerator( CastVarcharInputGenerator::~CastVarcharInputGenerator() = default; -#pragma GCC diagnostic ignored "-Wdeprecated-declarations" std::string CastVarcharInputGenerator::generateValidPrimitiveAsString() { switch (castToType_->kind()) { case TypeKind::BOOLEAN: { @@ -575,7 +650,10 @@ std::string CastVarcharInputGenerator::generateValidPrimitiveAsString() { case TypeKind::VARCHAR: { // Generate random string. std::string input; +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" std::wstring_convert, char16_t> converter; +#pragma GCC diagnostic pop auto randomStr = randString( rng_, rand(rng_, 0, 20), @@ -586,9 +664,10 @@ std::string CastVarcharInputGenerator::generateValidPrimitiveAsString() { } default: // cast from varchar doesn't support complex types - VELOX_FAIL_UNSUPPORTED_INPUT_UNCATCHABLE(fmt::format( - "Type `{}` not supported for cast varchar custom generator", - castToType_->kind())); + VELOX_FAIL_UNSUPPORTED_INPUT_UNCATCHABLE( + fmt::format( + "Type `{}` not supported for cast varchar custom generator", + castToType_->kind())); } } diff --git a/velox/common/fuzzer/ConstrainedGenerators.h b/velox/common/fuzzer/ConstrainedGenerators.h index 8d694eb40e0..20d9565e38c 100644 --- a/velox/common/fuzzer/ConstrainedGenerators.h +++ b/velox/common/fuzzer/ConstrainedGenerators.h @@ -21,6 +21,7 @@ #include "folly/json.h" #include "velox/common/fuzzer/Utils.h" +#include "velox/common/memory/HashStringAllocator.h" #include "velox/functions/lib/QuantileDigest.h" #include "velox/type/Type.h" #include "velox/type/Variant.h" @@ -61,6 +62,10 @@ class RandomInputGenerator : public AbstractInputGenerator { if (type_->isDate()) { return variant(randDate(rng_)); } + if (type_->isTime()) { + VELOX_DCHECK(type_->equivalent(*TIME())); + return variant(randTime(rng_)); + } return variant(rand(rng_)); } }; @@ -527,6 +532,23 @@ class TDigestInputGenerator : public AbstractInputGenerator { variant generate() override; }; +class SetDigestInputGenerator : public AbstractInputGenerator { + public: + SetDigestInputGenerator(size_t seed, const TypePtr& type, double nullRatio); + + ~SetDigestInputGenerator() override; + + variant generate() override; + + private: + template + variant generateTyped(); + + TypePtr baseType_; + std::shared_ptr pool_; + std::unique_ptr allocator_; +}; + class BingTileInputGenerator : public AbstractInputGenerator { public: BingTileInputGenerator(size_t seed, const TypePtr& type, double nullRatio); diff --git a/velox/common/fuzzer/Utils.cpp b/velox/common/fuzzer/Utils.cpp index 36f86be60ff..2e5580396e5 100644 --- a/velox/common/fuzzer/Utils.cpp +++ b/velox/common/fuzzer/Utils.cpp @@ -15,6 +15,8 @@ */ #include "velox/common/fuzzer/Utils.h" +#include +#include "velox/type/Time.h" namespace facebook::velox::fuzzer { @@ -117,6 +119,10 @@ int32_t randDate(FuzzerGenerator& rng) { return rand(rng, min, max); } +int32_t randTime(FuzzerGenerator& rng) { + return rand(rng, TIME()->getMin(), TIME()->getMax()); +} + /// Unicode character ranges. Ensure the vector indexes match the UTF8CharList /// enum values. /// @@ -202,4 +208,38 @@ std::string randString( } #pragma GCC diagnostic pop +int16_t generateRandomTimezoneOffset( + FuzzerGenerator& rng, + double frequentlyUsedProbability) { + // 25% probability: pick from frequently used offsets + // 75% probability: generate random offset in range [-840, 840] + if (coinToss(rng, frequentlyUsedProbability)) { + auto index = + rand(rng, 0, kFrequentlyUsedTimezoneOffsets.size() - 1); + return kFrequentlyUsedTimezoneOffsets[index]; + } else { + return rand(rng, -util::kTimeZoneBias, util::kTimeZoneBias); + } +} + +std::string timezoneOffsetToString(int16_t offsetMinutes) { + // Validate range [-840, 840] + VELOX_USER_CHECK( + offsetMinutes >= -util::kTimeZoneBias && + offsetMinutes <= util::kTimeZoneBias, + "Timezone offset {} minutes is out of range [-840, 840]", + offsetMinutes); + + // Determine sign + char sign = (offsetMinutes >= 0) ? '+' : '-'; + + // Calculate hours and minutes using absolute value + int16_t absOffset = std::abs(offsetMinutes); + int16_t hours = absOffset / util::kMinutesInHour; + int16_t minutes = absOffset % util::kMinutesInHour; + + // Format as "+HH:mm" or "-HH:mm" with zero-padding + return fmt::format("{}{:02d}:{:02d}", sign, hours, minutes); +} + } // namespace facebook::velox::fuzzer diff --git a/velox/common/fuzzer/Utils.h b/velox/common/fuzzer/Utils.h index db910007c31..a9b5cdc2309 100644 --- a/velox/common/fuzzer/Utils.h +++ b/velox/common/fuzzer/Utils.h @@ -31,6 +31,17 @@ namespace facebook::velox::fuzzer { using FuzzerGenerator = folly::detail::DefaultGenerator; +// Frequently used timezone offsets in minutes (US timezones including DST) +// -240 = UTC-4:00 (EDT) +// -300 = UTC-5:00 (EST/CDT) +// -360 = UTC-6:00 (CST/MDT) +// -420 = UTC-7:00 (MST/PDT) +// -480 = UTC-8:00 (PST) +// -540 = UTC-9:00 (AKST) +// -600 = UTC-10:00 (HST) +constexpr std::array kFrequentlyUsedTimezoneOffsets = + {-240, -300, -360, -420, -480, -540, -600}; + enum UTF8CharList { ASCII = 0, // Ascii character set. UNICODE_CASE_SENSITIVE = 1, // Unicode scripts that support case. @@ -167,6 +178,35 @@ inline Timestamp rand(FuzzerGenerator& rng, DataSpec /*dataSpec*/) { int32_t randDate(FuzzerGenerator& rng); +int32_t randTime(FuzzerGenerator& rng); + +/// Generate random timezone offset using biased distribution +/// 25% probability: picks from frequently used offsets +/// 75% probability: generates random offset from [-840, 840] +/// +/// @param rng Random number generator +/// @param frequentlyUsedProbability Probability of selecting from frequently +/// used offsets (default 0.25 for 25%) +/// @return Timezone offset in minutes [-840, 840] +int16_t generateRandomTimezoneOffset( + FuzzerGenerator& rng, + double frequentlyUsedProbability = 0.25); + +/// Convert timezone offset in minutes to "+HH:mm" or "-HH:mm" format +/// Always uses Presto-compatible +HH:mm format (never +HH or +HHmm) +/// +/// Examples: +/// - 0 → "+00:00" +/// - 330 → "+05:30" +/// - -300 → "-05:00" +/// - 840 → "+14:00" +/// - -840 → "-14:00" +/// +/// @param offsetMinutes Timezone offset in minutes [-840, 840] +/// @return Timezone offset string in "+HH:mm" or "-HH:mm" format +/// @throws VeloxException if offsetMinutes is out of range [-840, 840] +std::string timezoneOffsetToString(int16_t offsetMinutes); + template < typename T, typename std::enable_if_t, int> = 0> diff --git a/velox/common/fuzzer/tests/ConstrainedGeneratorsTest.cpp b/velox/common/fuzzer/tests/ConstrainedGeneratorsTest.cpp index 5b6c70f15bb..1defe5db7b2 100644 --- a/velox/common/fuzzer/tests/ConstrainedGeneratorsTest.cpp +++ b/velox/common/fuzzer/tests/ConstrainedGeneratorsTest.cpp @@ -22,6 +22,7 @@ #include "velox/functions/prestosql/json/JsonExtractor.h" #include "velox/functions/prestosql/types/JsonType.h" #include "velox/functions/prestosql/types/QDigestType.h" +#include "velox/functions/prestosql/types/SetDigestType.h" #include "velox/functions/prestosql/types/TDigestType.h" #include "velox/type/Variant.h" @@ -405,9 +406,10 @@ TEST_F(ConstrainedGeneratorsTest, jsonPath) { const auto jsonPath = jsonPathGenerator->generate(); if (jsonPath.hasValue()) { if (json.hasValue()) { - EXPECT_NO_THROW(functions::jsonExtract( - json.value(), - jsonPath.value())); + EXPECT_NO_THROW( + functions::jsonExtract( + json.value(), + jsonPath.value())); } } else { hasNull = true; @@ -423,6 +425,13 @@ TEST_F(ConstrainedGeneratorsTest, tdigest) { EXPECT_EQ(value.kind(), TypeKind::VARBINARY); } +TEST_F(ConstrainedGeneratorsTest, setdigest) { + std::unique_ptr generator = + std::make_unique(0, SETDIGEST(), 0.4); + auto value = generator->generate(); + EXPECT_EQ(value.kind(), TypeKind::VARBINARY); +} + TEST_F(ConstrainedGeneratorsTest, qdigest) { std::unique_ptr generator = std::make_unique( diff --git a/velox/common/fuzzer/tests/UtilsTest.cpp b/velox/common/fuzzer/tests/UtilsTest.cpp index ba21bf93a1a..5e82bdc85d6 100644 --- a/velox/common/fuzzer/tests/UtilsTest.cpp +++ b/velox/common/fuzzer/tests/UtilsTest.cpp @@ -26,23 +26,25 @@ namespace facebook::velox::fuzzer::test { class UtilsTest : public testing::Test {}; TEST_F(UtilsTest, testRuleList) { - auto simple = RuleList(std::vector>{ - std::make_shared("Hello"), - std::make_shared(","), - std::make_shared(" "), - std::make_shared("world"), - std::make_shared("!"), - }); + auto simple = RuleList( + std::vector>{ + std::make_shared("Hello"), + std::make_shared(","), + std::make_shared(" "), + std::make_shared("world"), + std::make_shared("!"), + }); ASSERT_EQ(simple.generate(), "Hello, world!"); FuzzerGenerator rng; - auto fuzz = RuleList(std::vector>{ - std::make_shared("Hello"), - std::make_shared(","), - std::make_shared(" "), - std::make_shared(rng), - std::make_shared("!"), - }); + auto fuzz = RuleList( + std::vector>{ + std::make_shared("Hello"), + std::make_shared(","), + std::make_shared(" "), + std::make_shared(rng), + std::make_shared("!"), + }); ASSERT_TRUE( std::regex_match(fuzz.generate(), std::regex("Hello, \\w{1,20}!"))); } @@ -83,36 +85,43 @@ TEST_F(UtilsTest, testConstantRule) { auto rule = std::make_shared("a"); ASSERT_EQ(rule->generate(), "a"); - auto rule_list = RuleList(std::vector>{ - std::make_shared("a"), - std::make_shared("b"), - std::make_shared("c")}); + auto rule_list = RuleList( + std::vector>{ + std::make_shared("a"), + std::make_shared("b"), + std::make_shared("c")}); ASSERT_EQ(rule_list.generate(), "abc"); } TEST_F(UtilsTest, testStringRule) { FuzzerGenerator rng; auto simple = std::make_shared(rng); - ASSERT_TRUE(std::regex_match( - simple->generate(), std::regex("^[\x21-\x7F]+$"))); // printable ascii + ASSERT_TRUE( + std::regex_match( + simple->generate(), std::regex("^[\x21-\x7F]+$"))); // printable ascii ASSERT_FALSE(std::regex_match(simple->generate(), std::regex("^\\w+$"))); ASSERT_FALSE(std::regex_match(simple->generate(), std::regex("^\\d+$"))); auto specified_flexible = std::make_shared( rng, std::vector{UTF8CharList::ASCII}, 3, 7, true); - ASSERT_TRUE(std::regex_match( - specified_flexible->generate(), std::regex("^[\\x21-\\x7F]{3,7}$"))); - ASSERT_FALSE(std::regex_match( - specified_flexible->generate(), std::regex("^[\\x21-\\x7F]{0,2}$"))); - ASSERT_FALSE(std::regex_match( - specified_flexible->generate(), std::regex("^[\\x21-\\x7F]{8,}$"))); + ASSERT_TRUE( + std::regex_match( + specified_flexible->generate(), std::regex("^[\\x21-\\x7F]{3,7}$"))); + ASSERT_FALSE( + std::regex_match( + specified_flexible->generate(), std::regex("^[\\x21-\\x7F]{0,2}$"))); + ASSERT_FALSE( + std::regex_match( + specified_flexible->generate(), std::regex("^[\\x21-\\x7F]{8,}$"))); auto specified_strict = std::make_shared( rng, std::vector{UTF8CharList::ASCII}, 3, 7, false); - ASSERT_TRUE(std::regex_match( - specified_strict->generate(), std::regex("^[\x21-\x7F]{7}$"))); - ASSERT_FALSE(std::regex_match( - specified_strict->generate(), std::regex("^[\x21-\x7F]{8}$"))); + ASSERT_TRUE( + std::regex_match( + specified_strict->generate(), std::regex("^[\x21-\x7F]{7}$"))); + ASSERT_FALSE( + std::regex_match( + specified_strict->generate(), std::regex("^[\x21-\x7F]{8}$"))); } TEST_F(UtilsTest, testWordRule) { @@ -124,16 +133,21 @@ TEST_F(UtilsTest, testWordRule) { ASSERT_FALSE(std::regex_match(simple->generate(), std::regex("^\\W+$"))); auto specified_flexible = std::make_shared(rng, 3, 7, true); - ASSERT_TRUE(std::regex_match( - specified_flexible->generate(), std::regex("^[a-zA-Z]{3,7}$"))); - ASSERT_TRUE(std::regex_match( - specified_flexible->generate(), std::regex("^\\w{3,7}$"))); - ASSERT_FALSE(std::regex_match( - specified_flexible->generate(), std::regex("^\\d{3,7}$"))); - ASSERT_FALSE(std::regex_match( - specified_flexible->generate(), std::regex("^\\w{0,2}$"))); - ASSERT_FALSE(std::regex_match( - specified_flexible->generate(), std::regex("^\\w{8,}$"))); + ASSERT_TRUE( + std::regex_match( + specified_flexible->generate(), std::regex("^[a-zA-Z]{3,7}$"))); + ASSERT_TRUE( + std::regex_match( + specified_flexible->generate(), std::regex("^\\w{3,7}$"))); + ASSERT_FALSE( + std::regex_match( + specified_flexible->generate(), std::regex("^\\d{3,7}$"))); + ASSERT_FALSE( + std::regex_match( + specified_flexible->generate(), std::regex("^\\w{0,2}$"))); + ASSERT_FALSE( + std::regex_match( + specified_flexible->generate(), std::regex("^\\w{8,}$"))); auto specified_strict = std::make_shared(rng, 3, 7, false); ASSERT_TRUE( @@ -149,16 +163,21 @@ TEST_F(UtilsTest, testNumRule) { ASSERT_FALSE(std::regex_match(simple->generate(), std::regex("^\\D+$"))); auto specified_flexible = std::make_shared(rng, 3, 7, true); - ASSERT_TRUE(std::regex_match( - specified_flexible->generate(), std::regex("^\\d{3,7}$"))); - ASSERT_TRUE(std::regex_match( - specified_flexible->generate(), std::regex("^\\w{3,7}$"))); - ASSERT_FALSE(std::regex_match( - specified_flexible->generate(), std::regex("^\\D{3,7}$"))); - ASSERT_FALSE(std::regex_match( - specified_flexible->generate(), std::regex("^\\d{0,2}$"))); - ASSERT_FALSE(std::regex_match( - specified_flexible->generate(), std::regex("^\\d{8,}$"))); + ASSERT_TRUE( + std::regex_match( + specified_flexible->generate(), std::regex("^\\d{3,7}$"))); + ASSERT_TRUE( + std::regex_match( + specified_flexible->generate(), std::regex("^\\w{3,7}$"))); + ASSERT_FALSE( + std::regex_match( + specified_flexible->generate(), std::regex("^\\D{3,7}$"))); + ASSERT_FALSE( + std::regex_match( + specified_flexible->generate(), std::regex("^\\d{0,2}$"))); + ASSERT_FALSE( + std::regex_match( + specified_flexible->generate(), std::regex("^\\d{8,}$"))); auto specified_strict = std::make_shared(rng, 3, 7, false); ASSERT_TRUE( diff --git a/velox/common/geospatial/CMakeLists.txt b/velox/common/geospatial/CMakeLists.txt new file mode 100644 index 00000000000..375f99efeff --- /dev/null +++ b/velox/common/geospatial/CMakeLists.txt @@ -0,0 +1,26 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if(VELOX_ENABLE_GEO) + velox_add_library(velox_common_geospatial_serde GeometrySerde.cpp HEADERS GeometrySerde.h) + velox_link_libraries(velox_common_geospatial_serde velox_expression GEOS::geos) +endif() + +velox_install_library_headers() + +if(${VELOX_BUILD_TESTING}) + add_subdirectory(tests) +endif() + +velox_add_library(velox_geospatial_constants INTERFACE HEADERS GeometryConstants.h) diff --git a/velox/common/geospatial/GeometryConstants.h b/velox/common/geospatial/GeometryConstants.h new file mode 100644 index 00000000000..d598f29e026 --- /dev/null +++ b/velox/common/geospatial/GeometryConstants.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +// This file contains constats for working with geospatial queries. +// They _must not_ require the GEOS library (or any 3p library). + +namespace facebook::velox::common::geospatial { + +enum class GeometrySerializationType : uint8_t { + POINT = 0, + MULTI_POINT = 1, + LINE_STRING = 2, + MULTI_LINE_STRING = 3, + POLYGON = 4, + MULTI_POLYGON = 5, + GEOMETRY_COLLECTION = 6, + ENVELOPE = 7 +}; + +enum class EsriShapeType : uint32_t { + POINT = 1, + POLYLINE = 3, + POLYGON = 5, + MULTI_POINT = 8 +}; + +/// Latitude/Longitude range constraints for spherical coordinates. +constexpr double kMinLatitude = -90.0; +constexpr double kMaxLatitude = 90.0; +constexpr double kMinLongitude = -180.0; +constexpr double kMaxLongitude = 180.0; + +/// BingTile-specific latitude constraints (narrower than standard lat/long). +constexpr double kMinBingTileLatitude = -85.05112878; +constexpr double kMaxBingTileLatitude = 85.05112878; + +} // namespace facebook::velox::common::geospatial diff --git a/velox/functions/prestosql/geospatial/GeometrySerde.cpp b/velox/common/geospatial/GeometrySerde.cpp similarity index 84% rename from velox/functions/prestosql/geospatial/GeometrySerde.cpp rename to velox/common/geospatial/GeometrySerde.cpp index 4c0d98903f8..ac9921441b1 100644 --- a/velox/functions/prestosql/geospatial/GeometrySerde.cpp +++ b/velox/common/geospatial/GeometrySerde.cpp @@ -20,11 +20,20 @@ #include #include "velox/common/base/IOUtils.h" -#include "velox/functions/prestosql/geospatial/GeometrySerde.h" +#include "velox/common/geospatial/GeometrySerde.h" using facebook::velox::common::InputByteStream; +using facebook::velox::common::geospatial::EsriShapeType; +using facebook::velox::common::geospatial::GeometrySerializationType; + +namespace facebook::velox::common::geospatial { + +geos::geom::GeometryFactory* GeometryDeserializer::getGeometryFactory() { + thread_local static geos::geom::GeometryFactory::Ptr geometryFactory = + geos::geom::GeometryFactory::create(); + return geometryFactory.get(); +} -namespace facebook::velox::functions::geospatial { std::unique_ptr GeometryDeserializer::readGeometry( velox::common::InputByteStream& stream, size_t size) { @@ -53,6 +62,15 @@ std::unique_ptr GeometryDeserializer::readGeometry( } } +std::unique_ptr GeometryDeserializer::deserializeNonEmpty( + const StringView& geometry) { + auto envelope = deserializeEnvelope(geometry); + if (envelope->isNull()) { + return nullptr; + } + return deserialize(geometry); +} + const std::unique_ptr GeometryDeserializer::deserializeEnvelope(const StringView& geometry) { velox::common::InputByteStream inputStream(geometry.data()); @@ -174,8 +192,9 @@ std::unique_ptr GeometryDeserializer::readPolyline( std::vector> lineStrings; lineStrings.reserve(partCount); for (size_t i = 0; i < partCount; ++i) { - lineStrings.push_back(getGeometryFactory()->createLineString( - readCoordinates(input, partLengths[i]))); + lineStrings.push_back( + getGeometryFactory()->createLineString( + readCoordinates(input, partLengths[i]))); } if (multiType) { @@ -222,22 +241,35 @@ std::unique_ptr GeometryDeserializer::readPolygon( std::vector> holes; std::vector> polygons; + // Shells _should_ be clockwise and holes _should_ be counter-clockwise, + // but this doesn't always happen for single Polygons. For single Polygons, + // we read the first ring as a shell and the rest as holes. For MultiPolygons, + // we read the first ring as a shell, and any counter-clockwise rings as + // holes, then push a polygon and reset if a clockwise ring is encountered. for (size_t i = 0; i < partCount; i++) { auto coordinates = readCoordinates(input, partLengths[i]); - if (isClockwise(coordinates, 0, coordinates->size())) { - // next polygon has started - if (shell) { - polygons.push_back(getGeometryFactory()->createPolygon( - std::move(shell), std::move(holes))); + if (multiType) { + bool clockwiseFlag = + GeometrySerializer::isClockwise(coordinates, 0, coordinates->size()); + if (shell && clockwiseFlag) { + // next polygon has started + polygons.push_back( + getGeometryFactory()->createPolygon( + std::move(shell), std::move(holes))); holes.clear(); + shell = nullptr; } - shell = getGeometryFactory()->createLinearRing(std::move(coordinates)); + } + + auto ring = getGeometryFactory()->createLinearRing(std::move(coordinates)); + if (shell == nullptr) { + shell = std::move(ring); } else { - holes.push_back( - getGeometryFactory()->createLinearRing(std::move(coordinates))); + holes.push_back(std::move(ring)); } } + polygons.push_back( getGeometryFactory()->createPolygon(std::move(shell), std::move(holes))); @@ -297,4 +329,4 @@ GeometryDeserializer::readGeometryCollection( getGeometryFactory()->createGeometryCollection(rawGeometries)); } -} // namespace facebook::velox::functions::geospatial +} // namespace facebook::velox::common::geospatial diff --git a/velox/functions/prestosql/geospatial/GeometrySerde.h b/velox/common/geospatial/GeometrySerde.h similarity index 72% rename from velox/functions/prestosql/geospatial/GeometrySerde.h rename to velox/common/geospatial/GeometrySerde.h index e85c00acbd1..e65bd4bb0c2 100644 --- a/velox/functions/prestosql/geospatial/GeometrySerde.h +++ b/velox/common/geospatial/GeometrySerde.h @@ -16,32 +16,17 @@ #pragma once +#include +#include #include #include #include "velox/common/base/IOUtils.h" -#include "velox/functions/prestosql/geospatial/GeometryUtils.h" +#include "velox/common/geospatial/GeometryConstants.h" +#include "velox/expression/ComplexViewTypes.h" #include "velox/type/StringView.h" -namespace facebook::velox::functions::geospatial { - -enum class GeometrySerializationType : uint8_t { - POINT = 0, - MULTI_POINT = 1, - LINE_STRING = 2, - MULTI_LINE_STRING = 3, - POLYGON = 4, - MULTI_POLYGON = 5, - GEOMETRY_COLLECTION = 6, - ENVELOPE = 7 -}; - -enum class EsriShapeType : uint32_t { - POINT = 1, - POLYLINE = 3, - POLYGON = 5, - MULTI_POINT = 8 -}; +namespace facebook::velox::common::geospatial { /** * VarbinaryWriter is a utility for serializing raw binary data to a @@ -55,7 +40,8 @@ enum class EsriShapeType : uint32_t { template class VarbinaryWriter { public: - VarbinaryWriter(StringWriter& stringWriter) : stringWriter_(stringWriter) {} + /* implicit */ VarbinaryWriter(StringWriter& stringWriter) + : stringWriter_(stringWriter) {} VarbinaryWriter() = delete; void write(const char* data, size_t size) { @@ -122,6 +108,70 @@ class GeometrySerializer { } } + /// Determines if a ring of coordinates (from `start` to `end`) is oriented + /// clockwise. + FOLLY_ALWAYS_INLINE static bool isClockwise( + const std::unique_ptr& coordinates, + size_t start, + size_t end) { + double sum = 0.0; + for (size_t i = start; i < end - 1; i++) { + const auto& p1 = coordinates->getAt(i); + const auto& p2 = coordinates->getAt(i + 1); + sum += (p2.x - p1.x) * (p2.y + p1.y); + } + return sum > 0.0; + } + + /// Reverses the order of coordinates in the sequence between `start` and + /// `end` + FOLLY_ALWAYS_INLINE static void reverse( + const std::unique_ptr& coordinates, + size_t start, + size_t end) { + for (size_t i = 0; i < (end - start) / 2; ++i) { + auto temp = coordinates->getAt(start + i); + coordinates->setAt(coordinates->getAt(end - 1 - i), start + i); + coordinates->setAt(temp, end - 1 - i); + } + } + + /// Ensures that a polygon ring has the canonical orientation: + /// - Exterior rings (shells) must be clockwise. + /// - Interior rings (holes) must be counter-clockwise. + FOLLY_ALWAYS_INLINE static void canonicalizePolygonCoordinates( + const std::unique_ptr& coordinates, + size_t start, + size_t end, + bool isShell) { + bool isClockwiseFlag = isClockwise(coordinates, start, end); + + if ((isShell && !isClockwiseFlag) || (!isShell && isClockwiseFlag)) { + reverse(coordinates, start, end); + } + } + + /// Applies `canonicalizePolygonCoordinates` to all rings in a polygon. + FOLLY_ALWAYS_INLINE static void canonicalizePolygonCoordinates( + const std::unique_ptr& coordinates, + const std::vector& partIndexes, + const std::vector& shellPart) { + for (size_t part = 0; part < partIndexes.size() - 1; part++) { + canonicalizePolygonCoordinates( + coordinates, + partIndexes[part], + partIndexes[part + 1], + shellPart[part]); + } + if (!partIndexes.empty()) { + canonicalizePolygonCoordinates( + coordinates, + partIndexes.back(), + coordinates->size(), + shellPart.back()); + } + } + private: template static void writeGeometry( @@ -359,6 +409,68 @@ class GeometryDeserializer { static const std::unique_ptr deserializeEnvelope( const StringView& geometry); + /// Deserializes a geometry, returning nullptr for empty geometries + /// (those with null envelopes). + static std::unique_ptr deserializeNonEmpty( + const StringView& geometry); + + template + static std::unique_ptr + deserializePointsToCoordinate( + const exec::ArrayView& input, + const std::string& functionName, + bool forbidDuplicates) { + std::unique_ptr coords = + std::make_unique(input.size(), 2); + + double lastX = std::numeric_limits::signaling_NaN(); + double lastY = std::numeric_limits::signaling_NaN(); + for (int i = 0; i < input.size(); i++) { + if (!input[i].has_value()) { + VELOX_USER_FAIL( + fmt::format( + "Invalid input to {}: input array contains null at index {}.", + functionName, + i)); + } + + StringView view = *input[i]; + + velox::common::InputByteStream inputStream(view.data()); + auto geometryType = inputStream.read(); + if (geometryType != GeometrySerializationType::POINT) { + VELOX_USER_FAIL( + fmt::format( + "Non-point geometry in {} input at index {}.", + functionName, + i)); + } + auto x = inputStream.read(); + auto y = inputStream.read(); + if (std::isnan(x) || std::isnan(y)) { + VELOX_USER_FAIL( + fmt::format( + "Empty point in {} input at index {}.", functionName, i)); + } + if (forbidDuplicates && x == lastX && y == lastY) { + VELOX_USER_FAIL( + fmt::format( + "Repeated point sequence in {}: point {},{} at index {}.", + functionName, + x, + y, + i)); + } + lastX = x; + lastY = y; + coords->setAt({x, y}, i); + } + return coords; + } + + /// Returns the thread-local GEOS geometry factory. + static geos::geom::GeometryFactory* getGeometryFactory(); + private: static std::unique_ptr readGeometry( velox::common::InputByteStream& stream, @@ -408,4 +520,4 @@ class GeometryDeserializer { size_t size); }; -} // namespace facebook::velox::functions::geospatial +} // namespace facebook::velox::common::geospatial diff --git a/velox/common/geospatial/tests/CMakeLists.txt b/velox/common/geospatial/tests/CMakeLists.txt new file mode 100644 index 00000000000..55b2f76cde5 --- /dev/null +++ b/velox/common/geospatial/tests/CMakeLists.txt @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if(VELOX_ENABLE_GEO) + add_executable(velox_common_geospatial_serde_test GeometrySerdeTest.cpp) + + add_test(velox_common_geospatial_serde_test velox_common_geospatial_serde_test) + + target_link_libraries( + velox_common_geospatial_serde_test + velox_common_geospatial_serde + GTest::gtest + GTest::gtest_main + GTest::gmock + GTest::gmock_main + GEOS::geos + ) +endif() diff --git a/velox/functions/prestosql/geospatial/tests/GeometrySerdeTest.cpp b/velox/common/geospatial/tests/GeometrySerdeTest.cpp similarity index 76% rename from velox/functions/prestosql/geospatial/tests/GeometrySerdeTest.cpp rename to velox/common/geospatial/tests/GeometrySerdeTest.cpp index 3dba74c756f..14ee6186a68 100644 --- a/velox/functions/prestosql/geospatial/tests/GeometrySerdeTest.cpp +++ b/velox/common/geospatial/tests/GeometrySerdeTest.cpp @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "velox/functions/prestosql/geospatial/GeometrySerde.h" +#include "velox/common/geospatial/GeometrySerde.h" #include #include #include @@ -23,7 +23,7 @@ using namespace ::testing; -using namespace facebook::velox::functions::geospatial; +using namespace facebook::velox::common::geospatial; void assertRoundtrip(const std::string& wkt) { geos::io::WKTReader reader; @@ -100,3 +100,8 @@ TEST(GeometrySerdeTest, testComplexSerde) { assertRoundtrip( "GEOMETRYCOLLECTION (POLYGON EMPTY, GEOMETRYCOLLECTION ( POINT (1 2), POLYGON ((0 0, 4 0, 4 4, 0 4, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1)), GEOMETRYCOLLECTION EMPTY, MULTIPOLYGON ( ((10 10, 14 10, 14 14, 10 14, 10 10), (11 11, 12 11, 12 12, 11 12, 11 11)), ((-1 -1, -2 -2, -1 -2, -1 -1)) ) ))"); } + +TEST(GeometrySerdeTest, testSmallAreaRing) { + assertRoundtrip( + "MULTIPOLYGON (((18.6317421 49.9605785, 18.6318832 49.9607979, 18.6324683 49.9607312, 18.6332842 49.9605658, 18.6332003 49.9603557, 18.6339711 49.9602283, 18.6341994 49.9601905, 18.6343455 49.96016, 18.6344167 49.9601452, 18.6346696 49.9600919, 18.6349643 49.9600567, 18.6352271 49.9601455, 18.6354493 49.9600501, 18.6358024 49.9601071, 18.6358911 49.9600263, 18.6336542 49.9592453, 18.6334794 49.9591838, 18.6337483 49.9581339, 18.6335303 49.9580562, 18.6331284 49.9579122, 18.6324931 49.9576885, 18.6322503 49.9575998, 18.6321381 49.9581593, 18.6321172 49.9582692, 18.6324683 49.9583852, 18.6325255 49.9584004, 18.6327588 49.958489, 18.6324792 49.9588351, 18.6323941 49.9588049, 18.6323261 49.9587807, 18.6320354 49.9586789, 18.6319443 49.9592903, 18.6326731 49.9595648, 18.6331388 49.9594836, 18.6335981 49.959673, 18.6333065 49.9597934, 18.6328096 49.9600844, 18.6330209 49.9601348, 18.633424 49.9602597, 18.6332263 49.960317, 18.6315633 49.9597642, 18.6309331 49.9600741, 18.6317421 49.9605785)), ((18.6298591 49.9606201, 18.6298592 49.96062, 18.6298589 49.9606193, 18.6298591 49.9606201)))"); +} diff --git a/velox/common/hyperloglog/CMakeLists.txt b/velox/common/hyperloglog/CMakeLists.txt index 649d1bc9851..6a03daed873 100644 --- a/velox/common/hyperloglog/CMakeLists.txt +++ b/velox/common/hyperloglog/CMakeLists.txt @@ -17,6 +17,12 @@ velox_add_library( DenseHll.cpp SparseHll.cpp Murmur3Hash128.cpp + HEADERS + BiasCorrection.h + DenseHll.h + HllUtils.h + Murmur3Hash128.h + SparseHll.h ) velox_link_libraries(velox_common_hyperloglog PUBLIC velox_memory PRIVATE velox_exception) diff --git a/velox/common/hyperloglog/DenseHll.cpp b/velox/common/hyperloglog/DenseHll.cpp index f3a3f54b2d2..b7b7a31fe78 100644 --- a/velox/common/hyperloglog/DenseHll.cpp +++ b/velox/common/hyperloglog/DenseHll.cpp @@ -15,13 +15,13 @@ */ #include "velox/common/hyperloglog/DenseHll.h" -#include -#include +#include "velox/common/base/BitUtil.h" #include "velox/common/base/IOUtils.h" #include "velox/common/hyperloglog/BiasCorrection.h" #include "velox/common/hyperloglog/HllUtils.h" namespace facebook::velox::common::hll { + namespace { const int kBitsPerBucket = 4; const int8_t kMaxDelta = (1 << kBitsPerBucket) - 1; @@ -119,14 +119,26 @@ double correctBias(double rawEstimate, int8_t indexBitLength) { } } // namespace -DenseHll::DenseHll(int8_t indexBitLength, HashStringAllocator* allocator) - : deltas_{StlAllocator(allocator)}, - overflowBuckets_{StlAllocator(allocator)}, - overflowValues_{StlAllocator(allocator)} { +template +DenseHll::DenseHll(int8_t indexBitLength, TAllocator* allocator) + : allocator_(allocator), + deltas_{TStlAllocator(allocator)}, + overflowBuckets_{TStlAllocator(allocator)}, + overflowValues_{TStlAllocator(allocator)} { initialize(indexBitLength); } -void DenseHll::initialize(int8_t indexBitLength) { +template +DenseHll::DenseHll(TAllocator* allocator) + : indexBitLength_(-1), + baselineCount_(0), + allocator_(allocator), + deltas_{TStlAllocator(allocator)}, + overflowBuckets_{TStlAllocator(allocator)}, + overflowValues_{TStlAllocator(allocator)} {} + +template +void DenseHll::initialize(int8_t indexBitLength) { VELOX_CHECK_GE(indexBitLength, 4, "indexBitLength must be in [4, 16] range"); VELOX_CHECK_LE(indexBitLength, 16, "indexBitLength must be in [4, 16] range"); @@ -137,13 +149,15 @@ void DenseHll::initialize(int8_t indexBitLength) { deltas_.resize(numBuckets * kBitsPerBucket / 8); } -void DenseHll::insertHash(uint64_t hash) { +template +void DenseHll::insertHash(uint64_t hash) { auto index = computeIndex(hash, indexBitLength_); auto value = numberOfLeadingZeros(hash, indexBitLength_) + 1; insert(index, value); } -void DenseHll::insert(int32_t index, int8_t value) { +template +void DenseHll::insert(int32_t index, int8_t value) { auto delta = value - baseline_; auto oldDelta = getDelta(index); @@ -238,6 +252,9 @@ DenseHllView deserialize(const char* serialized) { VELOX_CHECK_EQ(kPrestoDenseV2, version); auto indexBitLength = stream.read(); + VELOX_CHECK_GE(indexBitLength, 4, "indexBitLength must be in [4, 16] range"); + VELOX_CHECK_LE(indexBitLength, 16, "indexBitLength must be in [4, 16] range"); + auto baseline = stream.read(); auto numBuckets = 1 << indexBitLength; @@ -245,6 +262,8 @@ DenseHllView deserialize(const char* serialized) { const int8_t* deltas = stream.read(numBuckets / 2); auto overflows = stream.read(); + VELOX_CHECK_GE( + overflows, 0, "Invalid DenseHll overflow count: {}", overflows); const uint16_t* overflowBuckets = overflows ? stream.read(overflows) : nullptr; @@ -261,7 +280,8 @@ DenseHllView deserialize(const char* serialized) { } } // namespace -int64_t DenseHll::cardinality() const { +template +int64_t DenseHll::cardinality() const { DenseHllView hll{ indexBitLength_, baseline_, @@ -272,18 +292,14 @@ int64_t DenseHll::cardinality() const { return cardinalityImpl(hll); } -// static -int64_t DenseHll::cardinality(const char* serialized) { - auto hll = deserialize(serialized); - return cardinalityImpl(hll); -} - -int8_t DenseHll::getDelta(int32_t index) const { +template +int8_t DenseHll::getDelta(int32_t index) const { int slot = index >> 1; return (deltas_[slot] >> shiftForBucket(index)) & kBucketMask; } -void DenseHll::setDelta(int32_t index, int8_t value) { +template +void DenseHll::setDelta(int32_t index, int8_t value) { int slot = index >> 1; // Clear the old value. @@ -295,12 +311,14 @@ void DenseHll::setDelta(int32_t index, int8_t value) { deltas_[slot] |= setMask; } -int8_t DenseHll::getOverflow(int32_t index) const { +template +int8_t DenseHll::getOverflow(int32_t index) const { return getOverflowImpl( index, overflows_, overflowBuckets_.data(), overflowValues_.data()); } -int DenseHll::findOverflowEntry(int32_t index) const { +template +int DenseHll::findOverflowEntry(int32_t index) const { for (auto i = 0; i < overflows_; i++) { if (overflowBuckets_[i] == index) { return i; @@ -309,7 +327,8 @@ int DenseHll::findOverflowEntry(int32_t index) const { return -1; } -void DenseHll::adjustBaselineIfNeeded() { +template +void DenseHll::adjustBaselineIfNeeded() { auto numBuckets = 1 << indexBitLength_; while (baselineCount_ == 0) { @@ -359,7 +378,8 @@ void DenseHll::adjustBaselineIfNeeded() { } } -void DenseHll::sortOverflows() { +template +void DenseHll::sortOverflows() { // traditional insertion sort (ok for small arrays) for (int i = 1; i < overflows_; i++) { auto bucket = overflowBuckets_[i]; @@ -385,7 +405,8 @@ void DenseHll::sortOverflows() { } } -int32_t DenseHll::serializedSize() const { +template +int32_t DenseHll::serializedSize() const { return 1 /* type + version */ + 1 /* indexBitLength */ + 1 /* baseline */ @@ -395,13 +416,17 @@ int32_t DenseHll::serializedSize() const { + overflows_ /* overflow bucket values */; } -// static -bool DenseHll::canDeserialize(const char* input) { +int64_t DenseHlls::cardinality(const char* serialized) { + auto hll = deserialize(serialized); + return cardinalityImpl(hll); +} + +bool DenseHlls::canDeserialize(const char* input) { return *reinterpret_cast(input) == kPrestoDenseV2; } // static -bool DenseHll::canDeserialize(const char* input, int size) { +bool DenseHlls::canDeserialize(const char* input, int size) { if (size < 5) { // Min serialized sparse HLL size is 5 bytes. return false; @@ -459,22 +484,23 @@ bool DenseHll::canDeserialize(const char* input, int size) { return true; } -// static -int8_t DenseHll::deserializeIndexBitLength(const char* input) { +int8_t DenseHlls::deserializeIndexBitLength(const char* input) { common::InputByteStream stream(input); stream.read(); return stream.read(); } -// static -int32_t DenseHll::estimateInMemorySize(int8_t indexBitLength) { +int32_t DenseHlls::estimateInMemorySize(int8_t indexBitLength) { // Note: we don't take into account overflow entries since their number can // vary. - return sizeof(indexBitLength_) + sizeof(baseline_) + sizeof(baselineCount_) + + // return sizeof(indexBitLength_) + sizeof(baseline_) + + // sizeof(baselineCount_) + (1 << indexBitLength) / 2; + return sizeof(int8_t) + sizeof(int8_t) + sizeof(int32_t) + (1 << indexBitLength) / 2; } -void DenseHll::serialize(char* output) { +template +void DenseHll::serialize(char* output) { // sort overflow arrays to get consistent serialization for equivalent HLLs sortOverflows(); @@ -492,10 +518,12 @@ void DenseHll::serialize(char* output) { } } -DenseHll::DenseHll(const char* serialized, HashStringAllocator* allocator) - : deltas_{StlAllocator(allocator)}, - overflowBuckets_{StlAllocator(allocator)}, - overflowValues_{StlAllocator(allocator)} { +template +DenseHll::DenseHll(const char* serialized, TAllocator* allocator) + : allocator_(allocator), + deltas_{TStlAllocator(allocator)}, + overflowBuckets_{TStlAllocator(allocator)}, + overflowValues_{TStlAllocator(allocator)} { auto hll = deserialize(serialized); initialize(hll.indexBitLength); baseline_ = hll.baseline; @@ -525,7 +553,8 @@ DenseHll::DenseHll(const char* serialized, HashStringAllocator* allocator) } } -void DenseHll::mergeWith(const DenseHll& other) { +template +void DenseHll::mergeWith(const DenseHll& other) { VELOX_CHECK_EQ( indexBitLength_, other.indexBitLength_, @@ -539,7 +568,8 @@ void DenseHll::mergeWith(const DenseHll& other) { other.overflowValues_.data()}); } -void DenseHll::mergeWith(const char* serialized) { +template +void DenseHll::mergeWith(const char* serialized) { common::InputByteStream stream(serialized); auto version = stream.read(); @@ -561,7 +591,8 @@ void DenseHll::mergeWith(const char* serialized) { mergeWith({baseline, deltas, overflows, overflowBuckets, overflowValues}); } -std::pair DenseHll::computeNewValue( +template +std::pair DenseHll::computeNewValue( int8_t delta, int8_t otherDelta, int32_t bucket, @@ -585,7 +616,8 @@ std::pair DenseHll::computeNewValue( return {std::max(value1, value2), overflowEntry}; } -void DenseHll::mergeWith(const HllView& other) { +template +void DenseHll::mergeWith(const HllView& other) { // Number of 'delta' bytes that fit in a single SIMD batch. Each 'delta' byte // stores 2 4-bit deltas. constexpr auto batchSize = xsimd::batch::size; @@ -611,7 +643,10 @@ void DenseHll::mergeWith(const HllView& other) { adjustBaselineIfNeeded(); } -int32_t DenseHll::mergeWithSimd(const HllView& other, int8_t newBaseline) { +template +int32_t DenseHll::mergeWithSimd( + const HllView& other, + int8_t newBaseline) { const auto batchSize = xsimd::batch::size; const auto bucketMaskBatch = xsimd::broadcast(kBucketMask); @@ -751,7 +786,10 @@ int32_t DenseHll::mergeWithSimd(const HllView& other, int8_t newBaseline) { return baselineCount; } -int32_t DenseHll::mergeWithScalar(const HllView& other, int8_t newBaseline) { +template +int32_t DenseHll::mergeWithScalar( + const HllView& other, + int8_t newBaseline) { int32_t baselineCount = 0; int bucket = 0; @@ -787,8 +825,11 @@ int32_t DenseHll::mergeWithScalar(const HllView& other, int8_t newBaseline) { return baselineCount; } -int8_t -DenseHll::updateOverflow(int32_t index, int overflowEntry, int8_t delta) { +template +int8_t DenseHll::updateOverflow( + int32_t index, + int overflowEntry, + int8_t delta) { if (delta > kMaxDelta) { if (overflowEntry != -1) { // update existing overflow @@ -804,7 +845,8 @@ DenseHll::updateOverflow(int32_t index, int overflowEntry, int8_t delta) { return delta; } -void DenseHll::addOverflow(int32_t index, int8_t overflow) { +template +void DenseHll::addOverflow(int32_t index, int8_t overflow) { overflowBuckets_.resize(overflows_ + 1); overflowValues_.resize(overflows_ + 1); @@ -813,10 +855,17 @@ void DenseHll::addOverflow(int32_t index, int8_t overflow) { overflows_++; } -void DenseHll::removeOverflow(int overflowEntry) { +template +void DenseHll::removeOverflow(int overflowEntry) { // Remove existing overflow. overflowBuckets_[overflowEntry] = overflowBuckets_[overflows_ - 1]; overflowValues_[overflowEntry] = overflowValues_[overflows_ - 1]; overflows_--; } + +// Explicit template instantiation for both HashStringAllocator (default) and +// memory::MemoryPool +template class DenseHll; +template class DenseHll; + } // namespace facebook::velox::common::hll diff --git a/velox/common/hyperloglog/DenseHll.h b/velox/common/hyperloglog/DenseHll.h index b6b5f03f8cd..6936b6d8085 100644 --- a/velox/common/hyperloglog/DenseHll.h +++ b/velox/common/hyperloglog/DenseHll.h @@ -17,7 +17,45 @@ #include "velox/common/memory/HashStringAllocator.h" namespace facebook::velox::common::hll { -class SparseHll; + +class DenseHlls { + public: + /// Returns cardinality estimate from the specified serialized digest. + /// @param serialized Pointer to serialized DenseHll data + /// @return Estimated cardinality of the HyperLogLog + static int64_t cardinality(const char* serialized); + + /// Returns true if 'input' contains Presto DenseV2 format indicator. + /// @param input Pointer to serialized data to check + /// @return True if the data is in DenseV2 format, false otherwise + static bool canDeserialize(const char* input); + + /// Returns true if 'input' contains Presto DenseV2 format indicator and the + /// rest of the data matches HLL format: + /// 1 byte for version + /// 1 byte for index bit length, index bit length must be in [4,16] + /// 1 byte for baseline value + /// 2^(n-1) bytes for buckets, values in buckets must be in [0,63] + /// 2 bytes for # overflow buckets + /// 3 * #overflow buckets bytes for overflow buckets/values + /// More information here: + /// https://engineering.fb.com/2018/12/13/data-infrastructure/hyperloglog/ + /// @param input Pointer to serialized data to validate + /// @param size Size of the serialized data in bytes + /// @return True if the data is valid DenseV2 format, false otherwise + static bool canDeserialize(const char* input, int size); + + /// Extracts the index bit length from serialized DenseHll data. + /// @param input Pointer to serialized DenseHll data + /// @return The index bit length used in the serialized HLL + static int8_t deserializeIndexBitLength(const char* input); + + /// Returns an estimate of memory usage for DenseHll instance with the + /// specified number of bits per bucket. + /// @param indexBitLength Number of bits per bucket (must be in [4,16]) + /// @return Estimated memory usage in bytes + static int32_t estimateInMemorySize(int8_t indexBitLength); +}; /// HyperLogLog implementation using dense storage layout. /// The number of bits to use as bucket (indexBitLength) is specified by the @@ -26,18 +64,19 @@ class SparseHll; /// /// Memory usage: 2 ^ (indexBitLength - 1) bytes. 2KB for indexBitLength of 12 /// which provides max standard error of 0.023. +template class DenseHll { public: - DenseHll(int8_t indexBitLength, HashStringAllocator* allocator); + template + using TStlAllocator = typename TAllocator::template TStlAllocator; + + DenseHll(int8_t indexBitLength, TAllocator* allocator); - DenseHll(const char* serialized, HashStringAllocator* allocator); + DenseHll(const char* serialized, TAllocator* allocator); /// Creates an uninitialized instance that doesn't allcate any significant /// memory. The caller must call initialize before using the HLL. - explicit DenseHll(HashStringAllocator* allocator) - : deltas_{StlAllocator(allocator)}, - overflowBuckets_{StlAllocator(allocator)}, - overflowValues_{StlAllocator(allocator)} {} + explicit DenseHll(TAllocator* allocator); /// Allocates memory that can fit 2 ^ indexBitLength buckets. void initialize(int8_t indexBitLength); @@ -55,28 +94,9 @@ class DenseHll { int64_t cardinality() const; - static int64_t cardinality(const char* serialized); - /// Serializes internal state using Presto DenseV2 format. void serialize(char* output); - /// Returns true if 'input' contains Presto DenseV2 format indicator. - static bool canDeserialize(const char* input); - - /// Returns true if 'input' contains Presto DenseV2 format indicator and the - /// rest of the data matches HLL format: - /// 1 byte for version - /// 1 byte for index bit length, index bit length must be in [4,16] - /// 1 byte for baseline value - /// 2^(n-1) bytes for buckets, values in buckets must be in [0,63] - /// 2 bytes for # overflow buckets - /// 3 * #overflow buckets bytes for overflow buckets/values - /// More information here: - /// https://engineering.fb.com/2018/12/13/data-infrastructure/hyperloglog/ - static bool canDeserialize(const char* input, int size); - - static int8_t deserializeIndexBitLength(const char* input); - /// Returns the size of the serialized state without serialising. int32_t serializedSize() const; @@ -86,10 +106,6 @@ class DenseHll { void mergeWith(const char* serialized); - /// Returns an estimate of memory usage for DenseHll instance with the - /// specified number of bits per bucket. - static int32_t estimateInMemorySize(int8_t indexBitLength); - private: int8_t getDelta(int32_t index) const; @@ -147,20 +163,19 @@ class DenseHll { /// Number of zero deltas. int32_t baselineCount_; + TAllocator* allocator_; + /// Per-bucket values represented as deltas from the baseline_. Each entry /// stores 2 values, 4 bits each. The maximum value that can be stored is 15. /// Larger values are stored in a separate overflow list. - std::vector> deltas_; - - /// Number of overflowing values, e.g. values where delta from baseline is - /// greater than 15. + std::vector> deltas_; int16_t overflows_{0}; /// List of buckets with overflowing values. - std::vector> overflowBuckets_; + std::vector> overflowBuckets_; /// Overflowing values stored as deltas from the deltas: value - 15 - /// baseline. - std::vector> overflowValues_; + std::vector> overflowValues_; }; } // namespace facebook::velox::common::hll diff --git a/velox/common/hyperloglog/SparseHll.cpp b/velox/common/hyperloglog/SparseHll.cpp index 27d290dd10e..ab29f9d0e88 100644 --- a/velox/common/hyperloglog/SparseHll.cpp +++ b/velox/common/hyperloglog/SparseHll.cpp @@ -34,9 +34,8 @@ inline uint32_t decodeValue(uint32_t entry) { return entry & ((1 << kValueBitLength) - 1); } -int searchIndex( - uint32_t index, - const std::vector>& entries) { +template +int searchIndex(uint32_t index, const VectorType& entries) { int low = 0; int high = entries.size() - 1; @@ -69,7 +68,66 @@ common::InputByteStream initializeInputStream(const char* serialized) { } } // namespace -bool SparseHll::insertHash(uint64_t hash) { +// Static utility functions implementation +int64_t SparseHlls::cardinality(const char* serialized) { + static const int kTotalBuckets = 1 << kIndexBitLength; + + auto stream = initializeInputStream(serialized); + auto size = stream.read(); + + int zeroBuckets = kTotalBuckets - size; + return std::round(linearCounting(zeroBuckets, kTotalBuckets)); +} + +std::string SparseHlls::serializeEmpty(int8_t indexBitLength) { + static const size_t kSize = 4; + + std::string serialized; + serialized.resize(kSize); + + common::OutputByteStream stream(serialized.data()); + stream.appendOne(kPrestoSparseV2); + stream.appendOne(indexBitLength); + stream.appendOne(static_cast(0)); + return serialized; +} + +bool SparseHlls::canDeserialize(const char* input) { + return *reinterpret_cast(input) == kPrestoSparseV2; +} + +int8_t SparseHlls::deserializeIndexBitLength(const char* input) { + common::InputByteStream stream(input); + stream.read(); // Skip version + return stream.read(); // Return indexBitLength +} + +// Template method implementations +template +SparseHll::SparseHll(TAllocator* allocator) + : allocator_(allocator), entries_{TStlAllocator(allocator)} {} + +template +SparseHll::SparseHll(const char* serialized, TAllocator* allocator) + : allocator_(allocator), entries_{TStlAllocator(allocator)} { + common::InputByteStream stream(serialized); + auto version = stream.read(); + VELOX_CHECK_EQ(kPrestoSparseV2, version); + + // Skip indexBitLength from serialized data - we use fixed kIndexBitLength + // internally + stream.read(); + + auto size = stream.read(); + VELOX_CHECK_GE(size, 0, "Invalid SparseHll entry count: {}", size); + entries_.resize(size); + for (auto i = 0; i < size; i++) { + entries_[i] = stream.read(); + } +} + +template +bool SparseHll::insertHash(uint64_t hash) { auto index = computeIndex(hash, kIndexBitLength); auto value = numberOfLeadingZeros(hash, kIndexBitLength); @@ -88,29 +146,21 @@ bool SparseHll::insertHash(uint64_t hash) { return overLimit(); } -int64_t SparseHll::cardinality() const { +template +int64_t SparseHll::cardinality() const { // Estimate the cardinality using linear counting over the theoretical // 2^kIndexBitLength buckets available due to the fact that we're // recording the raw leading kIndexBitLength of the hash. This produces // much better precision while in the sparse regime. - static const int kTotalBuckets = 1 << kIndexBitLength; + const int kTotalBuckets = 1 << kIndexBitLength; int zeroBuckets = kTotalBuckets - entries_.size(); return std::round(linearCounting(zeroBuckets, kTotalBuckets)); } -// static -int64_t SparseHll::cardinality(const char* serialized) { - static const int kTotalBuckets = 1 << kIndexBitLength; - - auto stream = initializeInputStream(serialized); - auto size = stream.read(); - - int zeroBuckets = kTotalBuckets - size; - return std::round(linearCounting(zeroBuckets, kTotalBuckets)); -} - -void SparseHll::serialize(int8_t indexBitLength, char* output) const { +template +void SparseHll::serialize(int8_t indexBitLength, char* output) + const { common::OutputByteStream stream(output); stream.appendOne(kPrestoSparseV2); stream.appendOne(indexBitLength); @@ -120,75 +170,54 @@ void SparseHll::serialize(int8_t indexBitLength, char* output) const { } } -// static -std::string SparseHll::serializeEmpty(int8_t indexBitLength) { - static const size_t kSize = 4; - - std::string serialized; - serialized.resize(kSize); - - common::OutputByteStream stream(serialized.data()); - stream.appendOne(kPrestoSparseV2); - stream.appendOne(indexBitLength); - stream.appendOne(static_cast(0)); - return serialized; -} - -// static -bool SparseHll::canDeserialize(const char* input) { - return *reinterpret_cast(input) == kPrestoSparseV2; -} - -int32_t SparseHll::serializedSize() const { +template +int32_t SparseHll::serializedSize() const { return 1 /* version */ + 1 /* indexBitLength */ + 2 /* number of entries */ + entries_.size() * 4; } -int32_t SparseHll::inMemorySize() const { +template +int32_t SparseHll::inMemorySize() const { return sizeof(uint32_t) * entries_.size(); } -SparseHll::SparseHll(const char* serialized, HashStringAllocator* allocator) - : entries_{StlAllocator(allocator)} { - auto stream = initializeInputStream(serialized); - - auto size = stream.read(); - entries_.resize(size); - for (auto i = 0; i < size; i++) { - entries_[i] = stream.read(); - } -} - -void SparseHll::mergeWith(const SparseHll& other) { +template +void SparseHll::mergeWith(const SparseHll& other) { auto size = other.entries_.size(); // This check prevents merge aggregation from being performed on - // empty_approx_set(), an empty HyperLogLog. The merge function typically does - // not take an empty HyperLogLog structure as an argument. + // empty_approx_set(), an empty HyperLogLog. The merge function typically + // does not take an empty HyperLogLog structure as an argument. if (size) { mergeWith(size, other.entries_.data()); } } -void SparseHll::mergeWith(const char* serialized) { +template +void SparseHll::mergeWith(const char* serialized) { auto stream = initializeInputStream(serialized); auto size = stream.read(); // This check prevents merge aggregation from being performed on - // empty_approx_set(), an empty HyperLogLog. The merge function typically does - // not take an empty HyperLogLog structure as an argument. + // empty_approx_set(), an empty HyperLogLog. The merge function typically + // does not take an empty HyperLogLog structure as an argument. if (size) { mergeWith( size, reinterpret_cast(serialized + stream.offset())); } } -void SparseHll::mergeWith(size_t otherSize, const uint32_t* otherEntries) { +template +void SparseHll::mergeWith( + size_t otherSize, + const uint32_t* otherEntries) { VELOX_CHECK_GT(otherSize, 0); auto size = entries_.size(); - std::vector merged(size + otherSize); + + auto merged = std::vector>( + size + otherSize, TStlAllocator(allocator_)); int pos = 0; int leftPos = 0; @@ -223,7 +252,8 @@ void SparseHll::mergeWith(size_t otherSize, const uint32_t* otherEntries) { } } -void SparseHll::verify() const { +template +void SparseHll::verify() const { if (entries_.size() <= 1) { return; } @@ -236,11 +266,11 @@ void SparseHll::verify() const { } } -void SparseHll::toDense(DenseHll& denseHll) const { +template +void SparseHll::toDense(DenseHll& denseHll) const { auto indexBitLength = denseHll.indexBitLength(); - for (auto i = 0; i < entries_.size(); i++) { - auto entry = entries_[i]; + for (auto entry : entries_) { auto index = entry >> (32 - indexBitLength); auto shiftedValue = entry << indexBitLength; auto zeros = shiftedValue == 0 ? 32 : __builtin_clz(shiftedValue); @@ -257,4 +287,9 @@ void SparseHll::toDense(DenseHll& denseHll) const { } } +// Explicit template instantiation for HashStringAllocator (default) +template class SparseHll; +// Explicit template instantiation for memory::MemoryPool +template class SparseHll; + } // namespace facebook::velox::common::hll diff --git a/velox/common/hyperloglog/SparseHll.h b/velox/common/hyperloglog/SparseHll.h index e881b577627..61a3cb8cb29 100644 --- a/velox/common/hyperloglog/SparseHll.h +++ b/velox/common/hyperloglog/SparseHll.h @@ -18,15 +18,42 @@ #include "velox/common/memory/HashStringAllocator.h" namespace facebook::velox::common::hll { + +class SparseHlls { + public: + /// Returns cardinality estimate from the specified serialized digest. + /// @param serialized Pointer to serialized SparseHll data + /// @return Estimated cardinality of the HyperLogLog + static int64_t cardinality(const char* serialized); + + /// Returns true if 'input' has Presto SparseV2 format. + /// @param input Pointer to serialized data to check + /// @return True if the data is in SparseV2 format, false otherwise + static bool canDeserialize(const char* input); + + /// Creates an empty serialized SparseHll with the specified index bit length. + /// @param indexBitLength Number of bits for indexing (must be in [4,16]) + /// @return Serialized empty SparseHll as a string + static std::string serializeEmpty(int8_t indexBitLength); + + /// Extracts the index bit length from serialized SparseHll data. + /// @param input Pointer to serialized SparseHll data + /// @return The index bit length used in the serialized HLL + static int8_t deserializeIndexBitLength(const char* input); +}; + /// HyperLogLog implementation using sparse storage layout. /// It uses 26-bit buckets and provides high accuracy for low cardinalities. /// Memory usage: 4 bytes for each observed bucket. +template class SparseHll { public: - explicit SparseHll(HashStringAllocator* allocator) - : entries_{StlAllocator(allocator)} {} + template + using TStlAllocator = typename TAllocator::template TStlAllocator; + + explicit SparseHll(TAllocator* allocator); - SparseHll(const char* serialized, HashStringAllocator* allocator); + SparseHll(const char* serialized, TAllocator* allocator); void setSoftMemoryLimit(uint32_t softMemoryLimit) { softNumEntriesLimit_ = softMemoryLimit / 4; @@ -42,17 +69,9 @@ class SparseHll { int64_t cardinality() const; - /// Returns cardinality estimate from the specified serialized digest. - static int64_t cardinality(const char* serialized); - /// Serializes internal state using Presto SparseV2 format. void serialize(int8_t indexBitLength, char* output) const; - static std::string serializeEmpty(int8_t indexBitLength); - - /// Returns true if 'input' has Presto SparseV2 format. - static bool canDeserialize(const char* input); - /// Returns the size of the serialized state without serialising. int32_t serializedSize() const; @@ -63,7 +82,7 @@ class SparseHll { void mergeWith(const char* serialized); /// Merges state into provided instance of DenseHll. - void toDense(DenseHll& denseHll) const; + void toDense(DenseHll& denseHll) const; /// Returns current memory usage. int32_t inMemorySize() const; @@ -84,8 +103,8 @@ class SparseHll { /// A list of observed buckets. Each entry is a 32 bit integer encoding 26-bit /// bucket and 6-bit value (number of zeros in the input hash after the bucket /// + 1). - std::vector> entries_; - + TAllocator* allocator_; + std::vector> entries_; /// Number of entries that can be stored before reaching soft memory limit. uint32_t softNumEntriesLimit_{0}; }; diff --git a/velox/common/hyperloglog/benchmarks/DenseHll.cpp b/velox/common/hyperloglog/benchmarks/DenseHll.cpp index 7233280f1d6..ecda97fac0f 100644 --- a/velox/common/hyperloglog/benchmarks/DenseHll.cpp +++ b/velox/common/hyperloglog/benchmarks/DenseHll.cpp @@ -18,8 +18,7 @@ #include #include "velox/common/memory/HashStringAllocator.h" -#define XXH_INLINE_ALL -#include +#include "velox/common/base/XxHashInline.h" using namespace facebook::velox; @@ -49,7 +48,7 @@ class DenseHllBenchmark { folly::BenchmarkSuspender suspender; HashStringAllocator allocator(pool_); - common::hll::DenseHll hll(hashBits, &allocator); + common::hll::DenseHll<> hll(hashBits, &allocator); suspender.dismiss(); @@ -61,7 +60,7 @@ class DenseHllBenchmark { private: std::string makeSerializedHll(int hashBits, int32_t step) { HashStringAllocator allocator(pool_); - common::hll::DenseHll hll(hashBits, &allocator); + common::hll::DenseHll<> hll(hashBits, &allocator); for (int32_t i = 0; i < 1'000'000; ++i) { auto hash = hashOne(i * step); hll.insertHash(hash); @@ -69,7 +68,7 @@ class DenseHllBenchmark { return serialize(hll); } - static std::string serialize(common::hll::DenseHll& denseHll) { + static std::string serialize(common::hll::DenseHll<>& denseHll) { auto size = denseHll.serializedSize(); std::string serialized; serialized.resize(size); diff --git a/velox/common/hyperloglog/tests/DenseHllTest.cpp b/velox/common/hyperloglog/tests/DenseHllTest.cpp index c688a6918c4..c206a7870ff 100644 --- a/velox/common/hyperloglog/tests/DenseHllTest.cpp +++ b/velox/common/hyperloglog/tests/DenseHllTest.cpp @@ -15,13 +15,11 @@ */ #include "velox/common/hyperloglog/DenseHll.h" +#include #include -#include -#include #include -#define XXH_INLINE_ALL -#include +#include "velox/common/base/XxHashInline.h" #include "velox/common/encode/Base64.h" @@ -34,22 +32,27 @@ uint64_t hashOne(T value) { return XXH64(&value, sizeof(value), 0); } -class DenseHllTest : public ::testing::TestWithParam { +template +class DenseHllTest : public ::testing::Test { protected: static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); } - DenseHll roundTrip(DenseHll& hll) { - auto size = hll.serializedSize(); - std::string serialized; - serialized.resize(size); - hll.serialize(serialized.data()); + void SetUp() override { + if constexpr (std::is_same_v) { + allocator_ = &hsa_; + } else { + allocator_ = pool_.get(); + } + } - return DenseHll(serialized.data(), &allocator_); + DenseHll roundTrip(DenseHll& hll) { + auto serialized = this->serialize(hll); + return DenseHll(serialized.data(), allocator_); } - std::string serialize(DenseHll& denseHll) { + std::string serialize(DenseHll& denseHll) { auto size = denseHll.serializedSize(); std::string serialized; serialized.resize(size); @@ -58,23 +61,19 @@ class DenseHllTest : public ::testing::TestWithParam { } template - void testMergeWith( - int8_t indexBitLength, - const std::vector& left, - const std::vector& right) { - testMergeWith(indexBitLength, left, right, false); - testMergeWith(indexBitLength, left, right, true); + void testMergeWith(const std::vector& left, const std::vector& right) { + testMergeWith(left, right, false); + testMergeWith(left, right, true); } template void testMergeWith( - int8_t indexBitLength, const std::vector& left, const std::vector& right, bool serialized) { - DenseHll hllLeft{indexBitLength, &allocator_}; - DenseHll hllRight{indexBitLength, &allocator_}; - DenseHll expected{indexBitLength, &allocator_}; + DenseHll hllLeft{11, allocator_}; + DenseHll hllRight{11, allocator_}; + DenseHll expected{11, allocator_}; for (auto value : left) { auto hash = hashOne(value); @@ -89,30 +88,51 @@ class DenseHllTest : public ::testing::TestWithParam { } if (serialized) { - auto serializedRight = serialize(hllRight); + auto serializedRight = this->serialize(hllRight); hllLeft.mergeWith(serializedRight.data()); } else { hllLeft.mergeWith(hllRight); } ASSERT_EQ(hllLeft.cardinality(), expected.cardinality()); - ASSERT_EQ(serialize(hllLeft), serialize(expected)); + ASSERT_EQ(this->serialize(hllLeft), this->serialize(expected)); - auto hllLeftSerialized = serialize(hllLeft); + auto hllLeftSerialized = this->serialize(hllLeft); ASSERT_EQ( - DenseHll::cardinality(hllLeftSerialized.data()), + DenseHlls::cardinality(hllLeftSerialized.data()), expected.cardinality()); } std::shared_ptr pool_{ memory::memoryManager()->addLeafPool()}; - HashStringAllocator allocator_{pool_.get()}; + HashStringAllocator hsa_{pool_.get()}; + TAllocator* allocator_; }; -TEST_P(DenseHllTest, basic) { - int8_t indexBitLength = GetParam(); +using AllocatorTypes = + ::testing::Types; + +class NameGenerator { + public: + template + static std::string GetName(int) { + if constexpr (std::is_same_v) { + return "hsa"; + } else if constexpr (std::is_same_v) { + return "pool"; + } else { + VELOX_UNREACHABLE( + "Only HashStringAllocator and MemoryPool are supported allocator types."); + } + } +}; + +TYPED_TEST_SUITE(DenseHllTest, AllocatorTypes, NameGenerator); + +TYPED_TEST(DenseHllTest, basic) { + int8_t indexBitLength = 11; + DenseHll denseHll{indexBitLength, this->allocator_}; - DenseHll denseHll{indexBitLength, &allocator_}; for (int i = 0; i < 1'000; i++) { auto value = i % 17; auto hash = hashOne(value); @@ -131,31 +151,29 @@ TEST_P(DenseHllTest, basic) { ASSERT_EQ(expectedCardinality, denseHll.cardinality()); - DenseHll deserialized = roundTrip(denseHll); + DenseHll deserialized = this->roundTrip(denseHll); ASSERT_EQ(expectedCardinality, deserialized.cardinality()); - auto serialized = serialize(denseHll); - ASSERT_EQ(expectedCardinality, DenseHll::cardinality(serialized.data())); + auto serialized = this->serialize(denseHll); + ASSERT_EQ(expectedCardinality, DenseHlls::cardinality(serialized.data())); } -TEST_P(DenseHllTest, highCardinality) { - int8_t indexBitLength = GetParam(); +TYPED_TEST(DenseHllTest, highCardinality) { + int8_t indexBitLength = 11; + DenseHll denseHll{indexBitLength, this->allocator_}; - DenseHll denseHll{indexBitLength, &allocator_}; for (int i = 0; i < 10'000'000; i++) { auto hash = hashOne(i); denseHll.insertHash(hash); } - if (indexBitLength >= 11) { - ASSERT_NEAR(10'000'000, denseHll.cardinality(), 150'000); - } + ASSERT_NEAR(10'000'000, denseHll.cardinality(), 150'000); - DenseHll deserialized = roundTrip(denseHll); + auto deserialized = this->roundTrip(denseHll); ASSERT_EQ(denseHll.cardinality(), deserialized.cardinality()); - auto serialized = serialize(denseHll); - ASSERT_EQ(denseHll.cardinality(), DenseHll::cardinality(serialized.data())); + auto serialized = this->serialize(denseHll); + ASSERT_EQ(denseHll.cardinality(), DenseHlls::cardinality(serialized.data())); } namespace { @@ -170,62 +188,234 @@ std::vector sequence(T start, T end) { } } // namespace -TEST_P(DenseHllTest, canDeserialize) { +TYPED_TEST(DenseHllTest, mergeWith) { + // small, non-overlapping + this->testMergeWith(sequence(0, 100), sequence(100, 200)); + this->testMergeWith(sequence(100, 200), sequence(0, 100)); + + // small, overlapping + this->testMergeWith(sequence(0, 100), sequence(50, 150)); + this->testMergeWith(sequence(50, 150), sequence(0, 100)); + + // small, same + this->testMergeWith(sequence(0, 100), sequence(0, 100)); + + // large, non-overlapping + this->testMergeWith(sequence(0, 20'000), sequence(20'000, 40'000)); + this->testMergeWith(sequence(20'000, 40'000), sequence(0, 20'000)); + + // large, overlapping + this->testMergeWith(sequence(0, 2'000'000), sequence(1'000'000, 3'000'000)); + this->testMergeWith(sequence(1'000'000, 3'000'000), sequence(0, 2'000'000)); + + // large, same + this->testMergeWith(sequence(0, 2'000'000), sequence(0, 2'000'000)); +} + +// Separate test class for testing various index bit lengths +template +struct AllocatorWithIndexBits { + using AllocatorType = TAllocator; + static constexpr int8_t indexBitLength() { + return IndexBitLength; + } +}; + +template +class DenseHllMergeTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + } + + void SetUp() override { + if constexpr (std::is_same_v< + typename TParam::AllocatorType, + HashStringAllocator>) { + allocator_ = &hsa_; + } else { + allocator_ = pool_.get(); + } + } + + std::string serialize(DenseHll& denseHll) { + auto size = denseHll.serializedSize(); + std::string serialized; + serialized.resize(size); + denseHll.serialize(serialized.data()); + return serialized; + } + + template + void testMergeWith( + int8_t indexBitLength, + const std::vector& left, + const std::vector& right) { + testMergeWith(indexBitLength, left, right, false); + testMergeWith(indexBitLength, left, right, true); + } + + template + void testMergeWith( + int8_t indexBitLength, + const std::vector& left, + const std::vector& right, + bool serialized) { + DenseHll hllLeft{indexBitLength, allocator_}; + DenseHll hllRight{indexBitLength, allocator_}; + DenseHll expected{indexBitLength, allocator_}; + + for (auto value : left) { + auto hash = hashOne(value); + hllLeft.insertHash(hash); + expected.insertHash(hash); + } + + for (auto value : right) { + auto hash = hashOne(value); + hllRight.insertHash(hash); + expected.insertHash(hash); + } + + if (serialized) { + auto serializedRight = this->serialize(hllRight); + hllLeft.mergeWith(serializedRight.data()); + } else { + hllLeft.mergeWith(hllRight); + } + + ASSERT_EQ(hllLeft.cardinality(), expected.cardinality()); + ASSERT_EQ(this->serialize(hllLeft), this->serialize(expected)); + + auto hllLeftSerialized = this->serialize(hllLeft); + ASSERT_EQ( + DenseHlls::cardinality(hllLeftSerialized.data()), + expected.cardinality()); + } + + std::shared_ptr pool_{ + memory::memoryManager()->addLeafPool()}; + HashStringAllocator hsa_{pool_.get()}; + typename TParam::AllocatorType* allocator_; +}; + +using DenseHllMergeTestParams = ::testing::Types< + // HashStringAllocator with all index bit lengths + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + // MemoryPool with all index bit lengths + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits>; + +class ComprehensiveNameGenerator { + public: + template + static std::string GetName(int) { + std::string allocatorName; + if constexpr (std::is_same_v< + typename TParam::AllocatorType, + HashStringAllocator>) { + allocatorName = "hsa"; + } else if constexpr (std::is_same_v< + typename TParam::AllocatorType, + memory::MemoryPool>) { + allocatorName = "pool"; + } else { + VELOX_UNREACHABLE( + "Only HashStringAllocator and MemoryPool are supported allocator types."); + } + return fmt::format("{}_{}", allocatorName, TParam::indexBitLength()); + } +}; + +TYPED_TEST_SUITE( + DenseHllMergeTest, + DenseHllMergeTestParams, + ComprehensiveNameGenerator); + +class DenseHllCanDeserializeTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + } +}; + +TEST_F(DenseHllCanDeserializeTest, canDeserialize) { // These are not valid HLL but all pass canDeserialize version only check. - std::vector invalidStrings{ + std::vector invalidStrings{ "AxIRESUhEzNBFCQWYxEjIzI1ISURNidCMlViIjOSNyATBhYSIiJDUyMBIlcSMDUiEUEiESM1ITckQkQTMSMhMyQx", "Aw==", "AyAAABEQAgAlAgAQAQAlMgAhQQABAwAAERAAAAEQACA=", "AwDuUjGFaQ==", "AwkLD8BYTA9BXyg="}; - for (folly::StringPiece& invalidString : invalidStrings) { + for (const auto invalidString : invalidStrings) { auto invalidHll = Base64::decode(invalidString); - EXPECT_TRUE(DenseHll::canDeserialize(invalidHll.c_str())); + EXPECT_TRUE(DenseHlls::canDeserialize(invalidHll.c_str())); EXPECT_FALSE( - DenseHll::canDeserialize(invalidHll.c_str(), invalidHll.length())); + DenseHlls::canDeserialize(invalidHll.c_str(), invalidHll.length())); } - std::vector validStrings{ + std::vector validStrings{ "AwwAQSVCQ4QUNDJkIzMjaSQmaVUxRSVDQ1FaIiJEYkNxNTEzWBQ0M0IhQSRDQ0RkYkRXMSJjM0MSJWMkQlNUNCJHVIVUM1QzQTVEUyE0ExMyV0NYQSR0NFaSFXI1IzKEJkMjEydDUzOVAjJFIkSTUREzM2MjEkg2U7MjIiIkRyJhMzMyMiFEQ1IyIlMyZFMkIyBzNSRGUUMiLMMTQzNDUiEmM0JxY1IjRFRUNlJjJFY0UxMjMWQkNSFRVBQlM1IzNFIiVDIiMhJiMSQVIjQpMjdWJnM0QzKCIRMjNFQnkyRkRjZFVCYjJScVZnM1QzMSYkMXIyIzQTUzMzQjNSRBVkITMEmAM1JiRAMyUzUXU0RDNkMSJBNSNCIiFCVkQxNDQiEkZFdGE5JEUio1VDRzJTJnMkQ0VDM1UTcFQTQSeCJDNCIiNiljJiMUNSdHSkEiNSMjJpJDIlUxElJiVXRCSDU1ERgQM7MyJDUiQmMUUVFSUhIjATJUJCWGYVUkc0NTISMUd4NUNTQzRiMjU0MzUjEkIzIyhRMkEQEnMiNFJGRSUjE2JVNUJTJEN0MiY2IiJGBXQ1QlUCZCcyIyJTQyRENhIzMjSEJEtnYiMjAjQzNEUyUiNJlDQ2gyImMzRDNWJRMhIVFHJUdkNjYSZDE1VlJAQ0IyNRJCMzVSQXFycjJTRlMVhjVDIxOhIiYWUTZmJSRDJFRjJEQ0gCMzYlN1dDVYIpMwMWdERkVzQxIlcjMiZkEEFCEiNCJnMhFjJiMkUmIqNCNSA1MyZiI2NCRkRkJzM1UiFBM1MzVFEoeDNGBhOTMTZEQiUVQiUjQiN0oldrMSRDEzcjNiI1K2REQlIUMiRUF1MzRmJDNDVTUSM3UFQkI1UyYiQRMxdoIlIjZUM1RSckYmMnMjNhc0RCQkFTMoMjISNyJCQSIlclUUU0IVUVQyNCAjNBUjUyIzRFFEBEMlJjEyM1IlRZIjJEUkZiZCFiNEQjU0FVIRMzNEMyIjQxIzMiIzJDM1lkQSMWYSFBMmQTZkQiYSZRhhNiQiMDcxImIzZWclJaRCCBcyYUVRVVRRZlERNBETQiIyRVBRNTMzVCQyJDVTUjJDUyJQI6iAYlMlNDISETVCZiIyZyRDIkREFEMmJEU4JDQkM0F1RSM4MyQjJSMUIjNHMkJCNBNTNiVJJTQVIjFzFhQTcyRTNERCJEU0MiRlYkIjMkO0YiISKFIyExVjYmE0QxWCMRU0NDJTUylDZCNUE0VAUSEWMjISNDU1VzUkMzM1MiQhYzhxNEQTlENHY0MSckRHEmNiFyNjY1QyAkNDM0JCI3JRIzRlEhREZTJhZkUldVI1YzJ2MzIkUzRRQTNDKUNDI1OIESUjUlNQMmNSUzZCM0FDJrMjEiFVSmMkM0VCYzMiQyYjYhBSM0RTYyISQnEzWjNCNIRkg0Q1M4EnNGYzMkoSVEU1MjJWQSRDVXVCMjMjMiIyUyM2NTN0chNSNUMkNHJzITImUjNIRDYTQ0Mi6UZyNUMYIzNGMDcWYkRUN2EzElMyEyMiEzESQjNUIxYzYhMkdjYVMhRCGEEiJDIkF0RCNTUlQiUUITgjUxWSNBRSJjNFYzJEM0UzRoY0IzM0KkJGIjUyITVUQqQzJDUqFDIzIiYxRjMidEVkJUFyc0Q0I3k2MWIyMkZnNiMRVRIjY1MyNEMxIsMmRYEkJnI1MyNCEjNlFyVDUyckVUU3WTMjdFMkFCRDRCIXNFMyElZUExEkMzIiURQiJRQkJGNDM0YUI0ExFRJiRWQVAiJTpENHZERDMyE3OFIkgiRTUlFGQxZCGFUkVDU0QjMiMSMUF1EiKmRDUVZlFUMVR1MjRVUyYxVSQzJDVTImRDJBUjEzUjMTYTUjFkUUIyEiMiYSE5STI0QTNURGIyNkMTNBVFMiMzU0UkNHY0QyJCcxJBZVRlIRIVJWgyUUdVMzQ1UWQ1KDUgJjIyMENDQgBUQiAiRGIzVSUjRmExJDMYRCYRMxJHKSkTIjEzJRA6NDdDUjZzEiMkBCUlM0IlFBZGIiRDMxJlRyQxNkJSJBQiIzIyczIzOTJhQkMjNyRCQ1MiIiOFMSNyVTJlJkE1UkMzUURmJWEYJzMyNDNAIiQDZTITJEImFiIUJEggIUMzE0k2FDVSczRRJEQTM2QSEUFHM2MSVDKEMkZAYUVhJGOERCV2EiYjJTUxRCoxI4NChBRnNSIiNFU4MhIjRoNSM0UkISYiMhIwgXRBQldUglUyVFQyEwVBIhOjNjMRNDd0hoUiJUoyJDMiJTNVQlMSMTpRkzNCNTJlMiNUEkMlYjRGJ4KUszNRISQTIBImIyFCIEFlQVc0MSIiI1JjRjNhI0YkUkRVI0UxVGSRFDgzIkUxRiNgElMiNLJaMyBGc0MzIQRFUyNEMyQnVxNUMSM1U0EzJTNENDdJMRUyMVEyNDFDSUQVRjNVIyE0RTRnMkJkYwEnJyEnUCFDMxRhUiVEMTIwUmNiFENTMyRRdFMjQRIohiJDM2KDMjM2NUNCZlIkJUMzNCMhQjMxEnREESZFUDZ0M0MTJRMkImQxYjcyQ0IjcoIxYyIzonc2JDRCUlA1JEFGJkRHVHMzI2IjUmRTMVJCMyUnJSUzNFQ1QiRjFEMTNmIUckMzRFVBIjU0UzIxI0JTI4JCM0FVYnlkNDFRJEJUQlEiI0ImNCJUIiaGFDYkQ2QSQxdFUjQnVHIzMmghQlNSZUIiRRQ0MAAA==", }; - for (folly::StringPiece& validString : validStrings) { + for (const auto validString : validStrings) { auto validHll = Base64::decode(validString); - EXPECT_TRUE(DenseHll::canDeserialize(validHll.c_str())); - EXPECT_TRUE(DenseHll::canDeserialize(validHll.c_str(), validHll.length())); + EXPECT_TRUE(DenseHlls::canDeserialize(validHll.c_str())); + EXPECT_TRUE(DenseHlls::canDeserialize(validHll.c_str(), validHll.length())); } } -TEST_P(DenseHllTest, mergeWith) { - int8_t indexBitLength = GetParam(); +TYPED_TEST(DenseHllMergeTest, mergeWith) { + int8_t indexBitLength = TypeParam::indexBitLength(); // small, non-overlapping - testMergeWith(indexBitLength, sequence(0, 100), sequence(100, 200)); - testMergeWith(indexBitLength, sequence(100, 200), sequence(0, 100)); + this->testMergeWith(indexBitLength, sequence(0, 100), sequence(100, 200)); + this->testMergeWith(indexBitLength, sequence(100, 200), sequence(0, 100)); // small, overlapping - testMergeWith(indexBitLength, sequence(0, 100), sequence(50, 150)); - testMergeWith(indexBitLength, sequence(50, 150), sequence(0, 100)); + this->testMergeWith(indexBitLength, sequence(0, 100), sequence(50, 150)); + this->testMergeWith(indexBitLength, sequence(50, 150), sequence(0, 100)); // small, same - testMergeWith(indexBitLength, sequence(0, 100), sequence(0, 100)); + this->testMergeWith(indexBitLength, sequence(0, 100), sequence(0, 100)); // large, non-overlapping - testMergeWith(indexBitLength, sequence(0, 20'000), sequence(20'000, 40'000)); - testMergeWith(indexBitLength, sequence(20'000, 40'000), sequence(0, 20'000)); + this->testMergeWith( + indexBitLength, sequence(0, 20'000), sequence(20'000, 40'000)); + this->testMergeWith( + indexBitLength, sequence(20'000, 40'000), sequence(0, 20'000)); // large, overlapping - testMergeWith( + this->testMergeWith( indexBitLength, sequence(0, 2'000'000), sequence(1'000'000, 3'000'000)); - testMergeWith( + this->testMergeWith( indexBitLength, sequence(1'000'000, 3'000'000), sequence(0, 2'000'000)); // large, same - testMergeWith(indexBitLength, sequence(0, 2'000'000), sequence(0, 2'000'000)); + this->testMergeWith( + indexBitLength, sequence(0, 2'000'000), sequence(0, 2'000'000)); } - -INSTANTIATE_TEST_SUITE_P( - DenseHllTest, - DenseHllTest, - ::testing::Values(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)); diff --git a/velox/common/hyperloglog/tests/SparseHllTest.cpp b/velox/common/hyperloglog/tests/SparseHllTest.cpp index 299e2c8aebd..b0f1e697872 100644 --- a/velox/common/hyperloglog/tests/SparseHllTest.cpp +++ b/velox/common/hyperloglog/tests/SparseHllTest.cpp @@ -15,9 +15,9 @@ */ #include "velox/common/hyperloglog/SparseHll.h" -#define XXH_INLINE_ALL -#include +#include "velox/common/base/XxHashInline.h" +#include #include using namespace facebook::velox; @@ -28,12 +28,21 @@ uint64_t hashOne(T value) { return XXH64(&value, sizeof(value), 0); } +template class SparseHllTest : public ::testing::Test { protected: static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); } + void SetUp() override { + if constexpr (std::is_same_v) { + allocator_ = &hsa_; + } else { + allocator_ = pool_.get(); + } + } + template void testMergeWith(const std::vector& left, const std::vector& right) { testMergeWith(left, right, false); @@ -45,9 +54,9 @@ class SparseHllTest : public ::testing::Test { const std::vector& left, const std::vector& right, bool serialized) { - SparseHll hllLeft{&allocator_}; - SparseHll hllRight{&allocator_}; - SparseHll expected{&allocator_}; + SparseHll hllLeft{allocator_}; + SparseHll hllRight{allocator_}; + SparseHll expected{allocator_}; for (auto value : left) { auto hash = hashOne(value); @@ -77,16 +86,20 @@ class SparseHllTest : public ::testing::Test { auto hllLeftSerialized = serialize(11, hllLeft); ASSERT_EQ( - SparseHll::cardinality(hllLeftSerialized.data()), + SparseHlls::cardinality(hllLeftSerialized.data()), expected.cardinality()); } - SparseHll roundTrip(SparseHll& hll) { - auto serialized = serialize(11, hll); - return SparseHll(serialized.data(), &allocator_); + SparseHll roundTrip( + SparseHll& hll, + int8_t indexBitLength = 11) { + auto serialized = serialize(indexBitLength, hll); + return SparseHll(serialized.data(), allocator_); } - std::string serialize(int8_t indexBitLength, const SparseHll& sparseHll) { + std::string serialize( + int8_t indexBitLength, + const SparseHll& sparseHll) { auto size = sparseHll.serializedSize(); std::string serialized; serialized.resize(size); @@ -94,7 +107,7 @@ class SparseHllTest : public ::testing::Test { return serialized; } - std::string serialize(DenseHll& denseHll) { + std::string serialize(DenseHll& denseHll) { auto size = denseHll.serializedSize(); std::string serialized; serialized.resize(size); @@ -104,11 +117,46 @@ class SparseHllTest : public ::testing::Test { std::shared_ptr pool_{ memory::memoryManager()->addLeafPool()}; - HashStringAllocator allocator_{pool_.get()}; + HashStringAllocator hsa_{pool_.get()}; + TAllocator* allocator_; }; -TEST_F(SparseHllTest, basic) { - SparseHll sparseHll{&allocator_}; +using AllocatorTypes = + ::testing::Types; + +class NameGenerator { + public: + template + static std::string GetName(int) { + if constexpr (std::is_same_v) { + return "hsa"; + } else if constexpr (std::is_same_v) { + return "pool"; + } else { + VELOX_UNREACHABLE( + "Only HashStringAllocator and MemoryPool are supported allocator types."); + } + } +}; + +TYPED_TEST_SUITE(SparseHllTest, AllocatorTypes, NameGenerator); + +TYPED_TEST(SparseHllTest, corruptEntryCount) { + // SparseHll format: int8_t version (2) + int8_t indexBitLength + int16_t + // size. A negative size causes std::terminate in vector::resize due to + // fbcode's noexcept _M_check_len. + std::string serialized(4, '\0'); + serialized[0] = 2; // kPrestoSparseV2 + serialized[1] = 11; // indexBitLength + // Write size = -1 as little-endian int16_t. + serialized[2] = static_cast(0xFF); + serialized[3] = static_cast(0xFF); + + EXPECT_ANY_THROW(SparseHll(serialized.data(), this->allocator_)); +} + +TYPED_TEST(SparseHllTest, basic) { + SparseHll sparseHll{this->allocator_}; for (int i = 0; i < 1'000; i++) { auto value = i % 17; auto hash = hashOne(value); @@ -118,16 +166,16 @@ TEST_F(SparseHllTest, basic) { sparseHll.verify(); ASSERT_EQ(17, sparseHll.cardinality()); - auto deserialized = roundTrip(sparseHll); + auto deserialized = this->roundTrip(sparseHll); deserialized.verify(); ASSERT_EQ(17, deserialized.cardinality()); - auto serialized = serialize(11, sparseHll); - ASSERT_EQ(17, SparseHll::cardinality(serialized.data())); + auto serialized = this->serialize(11, sparseHll); + ASSERT_EQ(17, SparseHlls::cardinality(serialized.data())); } -TEST_F(SparseHllTest, highCardinality) { - SparseHll sparseHll{&allocator_}; +TYPED_TEST(SparseHllTest, highCardinality) { + SparseHll sparseHll{this->allocator_}; for (int i = 0; i < 1'000; i++) { auto hash = hashOne(i); sparseHll.insertHash(hash); @@ -136,12 +184,12 @@ TEST_F(SparseHllTest, highCardinality) { sparseHll.verify(); ASSERT_EQ(1'000, sparseHll.cardinality()); - auto deserialized = roundTrip(sparseHll); + auto deserialized = this->roundTrip(sparseHll); deserialized.verify(); ASSERT_EQ(1'000, deserialized.cardinality()); - auto serialized = serialize(11, sparseHll); - ASSERT_EQ(1'000, SparseHll::cardinality(serialized.data())); + auto serialized = this->serialize(11, sparseHll); + ASSERT_EQ(1'000, SparseHlls::cardinality(serialized.data())); } namespace { @@ -156,30 +204,80 @@ std::vector sequence(T start, T end) { } } // namespace -TEST_F(SparseHllTest, mergeWith) { +TYPED_TEST(SparseHllTest, mergeWith) { // with overlap - testMergeWith(sequence(0, 100), sequence(50, 150)); - testMergeWith(sequence(50, 150), sequence(0, 100)); + this->testMergeWith(sequence(0, 100), sequence(50, 150)); + this->testMergeWith(sequence(50, 150), sequence(0, 100)); // no overlap - testMergeWith(sequence(0, 100), sequence(200, 300)); - testMergeWith(sequence(200, 300), sequence(0, 100)); + this->testMergeWith(sequence(0, 100), sequence(200, 300)); + this->testMergeWith(sequence(200, 300), sequence(0, 100)); // idempotent - testMergeWith(sequence(0, 100), sequence(0, 100)); + this->testMergeWith(sequence(0, 100), sequence(0, 100)); // empty sequence - testMergeWith(sequence(0, 100), {}); - testMergeWith({}, sequence(100, 300)); + this->testMergeWith(sequence(0, 100), {}); + this->testMergeWith({}, sequence(100, 300)); } -class SparseHllToDenseTest : public ::testing::TestWithParam { +TYPED_TEST(SparseHllTest, toDense) { + int8_t indexBitLength = 11; + + SparseHll sparseHll{this->allocator_}; + DenseHll expectedHll{indexBitLength, this->allocator_}; + for (int i = 0; i < 1'000; i++) { + auto hash = hashOne(i); + sparseHll.insertHash(hash); + expectedHll.insertHash(hash); + } + + DenseHll denseHll{indexBitLength, this->allocator_}; + sparseHll.toDense(denseHll); + ASSERT_EQ(denseHll.cardinality(), expectedHll.cardinality()); + ASSERT_EQ(this->serialize(denseHll), this->serialize(expectedHll)); +} + +TYPED_TEST(SparseHllTest, testNumberOfZeros) { + int8_t indexBitLength = 11; + for (int i = 0; i < 64 - indexBitLength; ++i) { + auto hash = 1ull << i; + SparseHll sparseHll(this->allocator_); + sparseHll.insertHash(hash); + DenseHll expectedHll(indexBitLength, this->allocator_); + expectedHll.insertHash(hash); + DenseHll denseHll(indexBitLength, this->allocator_); + sparseHll.toDense(denseHll); + ASSERT_EQ(this->serialize(denseHll), this->serialize(expectedHll)); + } +} + +template +struct AllocatorWithIndexBits { + using AllocatorType = TAllocator; + static constexpr int8_t indexBitLength() { + return IndexBitLength; + } +}; + +template +class SparseHllToDenseTest : public ::testing::Test { protected: static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); } - std::string serialize(DenseHll& denseHll) { + void SetUp() override { + if constexpr (std::is_same_v< + typename TParam::AllocatorType, + HashStringAllocator>) { + allocator_ = &hsa_; + } else { + allocator_ = pool_.get(); + } + } + + std::string serialize(DenseHll& denseHll) { auto size = denseHll.serializedSize(); std::string serialized; serialized.resize(size); @@ -189,41 +287,93 @@ class SparseHllToDenseTest : public ::testing::TestWithParam { std::shared_ptr pool_{ memory::memoryManager()->addLeafPool()}; - HashStringAllocator allocator_{pool_.get()}; + HashStringAllocator hsa_{pool_.get()}; + typename TParam::AllocatorType* allocator_; +}; + +using SparseHllToDenseTestParams = ::testing::Types< + // HashStringAllocator with various index bit lengths + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + // MemoryPool with various index bit lengths + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits>; + +class ToDenseNameGenerator { + public: + template + static std::string GetName(int) { + std::string allocatorName; + if constexpr (std::is_same_v< + typename TParam::AllocatorType, + HashStringAllocator>) { + allocatorName = "hsa"; + } else if constexpr (std::is_same_v< + typename TParam::AllocatorType, + memory::MemoryPool>) { + allocatorName = "pool"; + } else { + VELOX_UNREACHABLE( + "Only HashStringAllocator and MemoryPool are supported allocator types."); + } + return fmt::format("{}_{}", allocatorName, TParam::indexBitLength()); + } }; -TEST_P(SparseHllToDenseTest, toDense) { - int8_t indexBitLength = GetParam(); +TYPED_TEST_SUITE( + SparseHllToDenseTest, + SparseHllToDenseTestParams, + ToDenseNameGenerator); + +TYPED_TEST(SparseHllToDenseTest, toDense) { + int8_t indexBitLength = TypeParam::indexBitLength(); - SparseHll sparseHll{&allocator_}; - DenseHll expectedHll{indexBitLength, &allocator_}; + SparseHll sparseHll{this->allocator_}; + DenseHll expectedHll{indexBitLength, this->allocator_}; for (int i = 0; i < 1'000; i++) { auto hash = hashOne(i); sparseHll.insertHash(hash); expectedHll.insertHash(hash); } - DenseHll denseHll{indexBitLength, &allocator_}; + DenseHll denseHll{indexBitLength, this->allocator_}; sparseHll.toDense(denseHll); ASSERT_EQ(denseHll.cardinality(), expectedHll.cardinality()); - ASSERT_EQ(serialize(denseHll), serialize(expectedHll)); + ASSERT_EQ(this->serialize(denseHll), this->serialize(expectedHll)); } -TEST_P(SparseHllToDenseTest, testNumberOfZeros) { - auto indexBitLength = GetParam(); +TYPED_TEST(SparseHllToDenseTest, testNumberOfZeros) { + auto indexBitLength = TypeParam::indexBitLength(); for (int i = 0; i < 64 - indexBitLength; ++i) { auto hash = 1ull << i; - SparseHll sparseHll(&allocator_); + SparseHll sparseHll(this->allocator_); sparseHll.insertHash(hash); - DenseHll expectedHll(indexBitLength, &allocator_); + DenseHll expectedHll(indexBitLength, this->allocator_); expectedHll.insertHash(hash); - DenseHll denseHll(indexBitLength, &allocator_); + DenseHll denseHll(indexBitLength, this->allocator_); sparseHll.toDense(denseHll); - ASSERT_EQ(serialize(denseHll), serialize(expectedHll)); + ASSERT_EQ(this->serialize(denseHll), this->serialize(expectedHll)); } } - -INSTANTIATE_TEST_SUITE_P( - SparseHllToDenseTest, - SparseHllToDenseTest, - ::testing::Values(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)); diff --git a/velox/common/io/CMakeLists.txt b/velox/common/io/CMakeLists.txt index 3498214b4fd..4031afb8780 100644 --- a/velox/common/io/CMakeLists.txt +++ b/velox/common/io/CMakeLists.txt @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -velox_add_library(velox_common_io IoStatistics.cpp) +velox_add_library(velox_common_io IoStatistics.cpp HEADERS IoStatistics.h Options.h) velox_link_libraries(velox_common_io Folly::folly glog::glog) diff --git a/velox/common/io/IoStatistics.cpp b/velox/common/io/IoStatistics.cpp index 74d8383c308..21423486324 100644 --- a/velox/common/io/IoStatistics.cpp +++ b/velox/common/io/IoStatistics.cpp @@ -42,8 +42,8 @@ uint64_t IoStatistics::outputBatchSize() const { return outputBatchSize_.load(std::memory_order_relaxed); } -uint64_t IoStatistics::totalScanTime() const { - return totalScanTime_.load(std::memory_order_relaxed); +uint64_t IoStatistics::totalScanTimeNs() const { + return totalScanTimeNs_.load(std::memory_order_relaxed); } uint64_t IoStatistics::writeIOTimeUs() const { @@ -70,8 +70,8 @@ uint64_t IoStatistics::incRawOverreadBytes(int64_t v) { return rawOverreadBytes_.fetch_add(v, std::memory_order_relaxed); } -uint64_t IoStatistics::incTotalScanTime(int64_t v) { - return totalScanTime_.fetch_add(v, std::memory_order_relaxed); +uint64_t IoStatistics::incTotalScanTimeNs(int64_t v) { + return totalScanTimeNs_.fetch_add(v, std::memory_order_relaxed); } uint64_t IoStatistics::incWriteIOTimeUs(int64_t v) { @@ -111,14 +111,19 @@ IoStatistics::operationStats() const { void IoStatistics::merge(const IoStatistics& other) { rawBytesRead_ += other.rawBytesRead_; rawBytesWritten_ += other.rawBytesWritten_; - totalScanTime_ += other.totalScanTime_; + totalScanTimeNs_ += other.totalScanTimeNs_; rawOverreadBytes_ += other.rawOverreadBytes_; prefetch_.merge(other.prefetch_); read_.merge(other.read_); ramHit_.merge(other.ramHit_); ssdRead_.merge(other.ssdRead_); - queryThreadIoLatency_.merge(other.queryThreadIoLatency_); + queryThreadIoLatencyUs_.merge(other.queryThreadIoLatencyUs_); + storageReadLatencyUs_.merge(other.storageReadLatencyUs_); + ssdCacheReadLatencyUs_.merge(other.ssdCacheReadLatencyUs_); + cacheWaitLatencyUs_.merge(other.cacheWaitLatencyUs_); + coalescedSsdLoadLatencyUs_.merge(other.coalescedSsdLoadLatencyUs_); + coalescedStorageLoadLatencyUs_.merge(other.coalescedStorageLoadLatencyUs_); { const auto& otherOperationStats = other.operationStats(); std::lock_guard l(operationStatsMutex_); diff --git a/velox/common/io/IoStatistics.h b/velox/common/io/IoStatistics.h index 2111a8877b4..97d39e358e7 100644 --- a/velox/common/io/IoStatistics.h +++ b/velox/common/io/IoStatistics.h @@ -43,6 +43,24 @@ struct OperationCounters { class IoCounter { public: + IoCounter& operator=(const IoCounter& other) noexcept { + if (this != &other) { + count_.store( + other.count_.load(std::memory_order_relaxed), + std::memory_order_relaxed); + sum_.store( + other.sum_.load(std::memory_order_relaxed), + std::memory_order_relaxed); + min_.store( + other.min_.load(std::memory_order_relaxed), + std::memory_order_relaxed); + max_.store( + other.max_.load(std::memory_order_relaxed), + std::memory_order_relaxed); + } + return *this; + } + uint64_t count() const { return count_; } @@ -96,7 +114,7 @@ class IoStatistics { uint64_t rawBytesWritten() const; uint64_t inputBatchSize() const; uint64_t outputBatchSize() const; - uint64_t totalScanTime() const; + uint64_t totalScanTimeNs() const; uint64_t writeIOTimeUs() const; uint64_t incRawBytesRead(int64_t); @@ -104,7 +122,7 @@ class IoStatistics { uint64_t incRawBytesWritten(int64_t); uint64_t incInputBatchSize(int64_t); uint64_t incOutputBatchSize(int64_t); - uint64_t incTotalScanTime(int64_t); + uint64_t incTotalScanTimeNs(int64_t); uint64_t incWriteIOTimeUs(int64_t); IoCounter& prefetch() { @@ -123,8 +141,28 @@ class IoStatistics { return ramHit_; } - IoCounter& queryThreadIoLatency() { - return queryThreadIoLatency_; + IoCounter& queryThreadIoLatencyUs() { + return queryThreadIoLatencyUs_; + } + + IoCounter& storageReadLatencyUs() { + return storageReadLatencyUs_; + } + + IoCounter& ssdCacheReadLatencyUs() { + return ssdCacheReadLatencyUs_; + } + + IoCounter& cacheWaitLatencyUs() { + return cacheWaitLatencyUs_; + } + + IoCounter& coalescedSsdLoadLatencyUs() { + return coalescedSsdLoadLatencyUs_; + } + + IoCounter& coalescedStorageLoadLatencyUs() { + return coalescedStorageLoadLatencyUs_; } void incOperationCounters( @@ -151,7 +189,7 @@ class IoStatistics { std::atomic inputBatchSize_{0}; std::atomic outputBatchSize_{0}; std::atomic rawOverreadBytes_{0}; - std::atomic totalScanTime_{0}; + std::atomic totalScanTimeNs_{0}; std::atomic writeIOTimeUs_{0}; // Planned read from storage or SSD. @@ -169,7 +207,24 @@ class IoStatistics { // Time spent by a query processing thread waiting for synchronously issued IO // or for an in-progress read-ahead to finish. - IoCounter queryThreadIoLatency_; + IoCounter queryThreadIoLatencyUs_; + + // Breakdown of queryThreadIoLatencyUs_ by I/O type: + + // Time spent waiting for remote storage reads (S3, HDFS, etc.) + IoCounter storageReadLatencyUs_; + + // Time spent waiting for SSD cache reads + IoCounter ssdCacheReadLatencyUs_; + + // Time spent waiting for EXCLUSIVE cache entries (another thread is loading) + IoCounter cacheWaitLatencyUs_; + + // Time spent waiting for coalesced loads from SSD cache + IoCounter coalescedSsdLoadLatencyUs_; + + // Time spent waiting for coalesced loads from remote storage + IoCounter coalescedStorageLoadLatencyUs_; std::unordered_map operationStats_; mutable std::mutex operationStatsMutex_; diff --git a/velox/common/io/Options.h b/velox/common/io/Options.h index 32cb2387868..568c581cdfc 100644 --- a/velox/common/io/Options.h +++ b/velox/common/io/Options.h @@ -16,6 +16,10 @@ #pragma once +#include + +#include "velox/common/base/Exceptions.h" +#include "velox/common/io/IoStatistics.h" #include "velox/common/memory/Memory.h" namespace facebook::velox::io { @@ -63,14 +67,25 @@ class ReaderOptions { static constexpr int32_t kDefaultCoalesceBytes = 128 << 20; // 128M static constexpr int32_t kDefaultPrefetchRowGroups = 1; - explicit ReaderOptions(velox::memory::MemoryPool* pool) - : memoryPool_(pool), - autoPreloadLength_(DEFAULT_AUTO_PRELOAD_SIZE), - prefetchMode_(PrefetchMode::PREFETCH) {} + explicit ReaderOptions(velox::memory::MemoryPool* pool) : pool_{pool} { + VELOX_CHECK_NOT_NULL(pool_); + } + + ReaderOptions& setDataIoStats(std::shared_ptr stats) { + VELOX_CHECK_NULL(dataIoStats_, "dataIoStats already set"); + dataIoStats_ = std::move(stats); + return *this; + } + + ReaderOptions& setMetadataIoStats(std::shared_ptr stats) { + VELOX_CHECK_NULL(metadataIoStats_, "metadataIoStats already set"); + metadataIoStats_ = std::move(stats); + return *this; + } - /// Sets the memory pool for allocation. - ReaderOptions& setMemoryPool(velox::memory::MemoryPool& pool) { - memoryPool_ = &pool; + ReaderOptions& setIndexIoStats(std::shared_ptr stats) { + VELOX_CHECK_NULL(indexIoStats_, "indexIoStats already set"); + indexIoStats_ = std::move(stats); return *this; } @@ -112,7 +127,7 @@ class ReaderOptions { /// Gets the memory allocator. velox::memory::MemoryPool& memoryPool() const { - return *memoryPool_; + return *pool_; } uint64_t autoPreloadLength() const { @@ -139,22 +154,52 @@ class ReaderOptions { return prefetchRowGroups_; } - bool noCacheRetention() const { - return noCacheRetention_; + bool cacheable() const { + return cacheable_; + } + + void setCacheable(bool cacheable) { + cacheable_ = cacheable; } - void setNoCacheRetention(bool noCacheRetention) { - noCacheRetention_ = noCacheRetention; + const std::shared_ptr& ioExecutor() const { + return ioExecutor_; + } + + void setIOExecutor(std::shared_ptr ioExecutor) { + ioExecutor_ = std::move(ioExecutor); + } + + /// IO statistics for tracking storage reads, SSD reads, RAM cache hits, + /// and overread bytes for data stream IO. + const std::shared_ptr& dataIoStats() const { + return dataIoStats_; + } + + /// IO statistics for tracking storage reads, SSD reads, RAM cache hits, + /// and overread bytes for metadata IO (footer, stripe groups, index). + const std::shared_ptr& metadataIoStats() const { + return metadataIoStats_; + } + + const std::shared_ptr& indexIoStats() const { + return indexIoStats_; } protected: - velox::memory::MemoryPool* memoryPool_; - uint64_t autoPreloadLength_; - PrefetchMode prefetchMode_; + velox::memory::MemoryPool* pool_; + std::shared_ptr dataIoStats_; + std::shared_ptr metadataIoStats_; + std::shared_ptr indexIoStats_; + + std::shared_ptr ioExecutor_; + + uint64_t autoPreloadLength_{DEFAULT_AUTO_PRELOAD_SIZE}; + PrefetchMode prefetchMode_{PrefetchMode::PREFETCH}; int32_t loadQuantum_{kDefaultLoadQuantum}; int32_t maxCoalesceDistance_{kDefaultCoalesceDistance}; int64_t maxCoalesceBytes_{kDefaultCoalesceBytes}; int32_t prefetchRowGroups_{kDefaultPrefetchRowGroups}; - bool noCacheRetention_{false}; + bool cacheable_{true}; }; } // namespace facebook::velox::io diff --git a/velox/common/io/tests/IoStatisticsTest.cpp b/velox/common/io/tests/IoStatisticsTest.cpp new file mode 100644 index 00000000000..72a9d051e58 --- /dev/null +++ b/velox/common/io/tests/IoStatisticsTest.cpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/io/IoStatistics.h" +#include + +namespace facebook::velox::io { + +TEST(IoStatisticsTest, latencyBreakdown) { + IoStatistics stats; + + stats.storageReadLatencyUs().increment(100); + stats.ssdCacheReadLatencyUs().increment(50); + stats.cacheWaitLatencyUs().increment(25); + stats.coalescedSsdLoadLatencyUs().increment(30); + stats.coalescedStorageLoadLatencyUs().increment(45); + + EXPECT_EQ(stats.storageReadLatencyUs().count(), 1); + EXPECT_EQ(stats.storageReadLatencyUs().sum(), 100); + EXPECT_EQ(stats.ssdCacheReadLatencyUs().sum(), 50); + EXPECT_EQ(stats.cacheWaitLatencyUs().sum(), 25); + EXPECT_EQ(stats.coalescedSsdLoadLatencyUs().sum(), 30); + EXPECT_EQ(stats.coalescedStorageLoadLatencyUs().sum(), 45); + + IoStatistics stats2; + stats2.storageReadLatencyUs().increment(200); + stats.merge(stats2); + + EXPECT_EQ(stats.storageReadLatencyUs().count(), 2); + EXPECT_EQ(stats.storageReadLatencyUs().sum(), 300); +} + +TEST(IoStatisticsTest, totalScanTimeNs) { + IoStatistics stats; + EXPECT_EQ(stats.totalScanTimeNs(), 0); + + stats.incTotalScanTimeNs(1'000); + EXPECT_EQ(stats.totalScanTimeNs(), 1'000); + + stats.incTotalScanTimeNs(2'500); + EXPECT_EQ(stats.totalScanTimeNs(), 3'500); + + IoStatistics stats2; + stats2.incTotalScanTimeNs(500); + stats.merge(stats2); + EXPECT_EQ(stats.totalScanTimeNs(), 4'000); +} + +} // namespace facebook::velox::io diff --git a/velox/common/io/tests/ReaderOptionsTest.cpp b/velox/common/io/tests/ReaderOptionsTest.cpp new file mode 100644 index 00000000000..80ff4831c87 --- /dev/null +++ b/velox/common/io/tests/ReaderOptionsTest.cpp @@ -0,0 +1,175 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/io/Options.h" + +#include +#include + +#include "velox/common/base/tests/GTestUtils.h" + +namespace facebook::velox::io { + +class ReaderOptionsTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance({}); + } + + const std::shared_ptr pool_{ + memory::memoryManager()->addLeafPool("ReaderOptionsTest")}; + std::shared_ptr dataIoStats_{std::make_shared()}; + std::shared_ptr metadataIoStats_{ + std::make_shared()}; + std::shared_ptr indexIoStats_{std::make_shared()}; +}; + +TEST_F(ReaderOptionsTest, constructor) { + ReaderOptions options(pool_.get()); + + EXPECT_EQ(&options.memoryPool(), pool_.get()); + EXPECT_EQ(options.dataIoStats(), nullptr); + EXPECT_EQ(options.metadataIoStats(), nullptr); + EXPECT_EQ(options.indexIoStats(), nullptr); + EXPECT_EQ(options.ioExecutor(), nullptr); + EXPECT_EQ(options.autoPreloadLength(), DEFAULT_AUTO_PRELOAD_SIZE); + EXPECT_EQ(options.prefetchMode(), PrefetchMode::PREFETCH); + EXPECT_EQ(options.loadQuantum(), ReaderOptions::kDefaultLoadQuantum); + EXPECT_EQ( + options.maxCoalesceDistance(), ReaderOptions::kDefaultCoalesceDistance); + EXPECT_EQ(options.maxCoalesceBytes(), ReaderOptions::kDefaultCoalesceBytes); + EXPECT_EQ( + options.prefetchRowGroups(), ReaderOptions::kDefaultPrefetchRowGroups); + EXPECT_TRUE(options.cacheable()); +} + +TEST_F(ReaderOptionsTest, setters) { + ReaderOptions options(pool_.get()); + + options.setAutoPreloadLength(1'024); + EXPECT_EQ(options.autoPreloadLength(), 1'024); + + options.setPrefetchMode(PrefetchMode::NOT_SET); + EXPECT_EQ(options.prefetchMode(), PrefetchMode::NOT_SET); + + options.setLoadQuantum(4 << 20); + EXPECT_EQ(options.loadQuantum(), 4 << 20); + + options.setMaxCoalesceDistance(1 << 20); + EXPECT_EQ(options.maxCoalesceDistance(), 1 << 20); + + options.setMaxCoalesceBytes(64 << 20); + EXPECT_EQ(options.maxCoalesceBytes(), 64 << 20); + + options.setPrefetchRowGroups(4); + EXPECT_EQ(options.prefetchRowGroups(), 4); + + options.setCacheable(false); + EXPECT_FALSE(options.cacheable()); +} + +TEST_F(ReaderOptionsTest, ioExecutor) { + ReaderOptions options(pool_.get()); + EXPECT_EQ(options.ioExecutor(), nullptr); + + auto executor = std::make_shared(2); + options.setIOExecutor(executor); + EXPECT_EQ(options.ioExecutor(), executor); + + options.setIOExecutor(nullptr); + EXPECT_EQ(options.ioExecutor(), nullptr); +} + +TEST_F(ReaderOptionsTest, ioStats) { + ReaderOptions options(pool_.get()); + options.setDataIoStats(dataIoStats_); + options.setMetadataIoStats(metadataIoStats_); + options.setIndexIoStats(indexIoStats_); + + EXPECT_EQ(options.dataIoStats(), dataIoStats_); + EXPECT_EQ(options.metadataIoStats(), metadataIoStats_); + EXPECT_EQ(options.indexIoStats(), indexIoStats_); + + options.dataIoStats()->read().increment(100); + EXPECT_EQ(dataIoStats_->read().count(), 1); + EXPECT_EQ(dataIoStats_->read().sum(), 100); + + options.metadataIoStats()->read().increment(50); + EXPECT_EQ(metadataIoStats_->read().count(), 1); + EXPECT_EQ(metadataIoStats_->read().sum(), 50); + + options.indexIoStats()->read().increment(25); + EXPECT_EQ(indexIoStats_->read().count(), 1); + EXPECT_EQ(indexIoStats_->read().sum(), 25); +} + +TEST_F(ReaderOptionsTest, chainingSetters) { + ReaderOptions options(pool_.get()); + auto& result = options.setDataIoStats(dataIoStats_) + .setMetadataIoStats(metadataIoStats_) + .setIndexIoStats(indexIoStats_) + .setAutoPreloadLength(1'024) + .setPrefetchMode(PrefetchMode::PRELOAD) + .setLoadQuantum(4 << 20) + .setMaxCoalesceDistance(1 << 20) + .setMaxCoalesceBytes(64 << 20) + .setPrefetchRowGroups(4); + + EXPECT_EQ(&result, &options); + EXPECT_EQ(options.dataIoStats(), dataIoStats_); + EXPECT_EQ(options.metadataIoStats(), metadataIoStats_); + EXPECT_EQ(options.indexIoStats(), indexIoStats_); + EXPECT_EQ(options.autoPreloadLength(), 1'024); + EXPECT_EQ(options.prefetchMode(), PrefetchMode::PRELOAD); + EXPECT_EQ(options.loadQuantum(), 4 << 20); + EXPECT_EQ(options.maxCoalesceDistance(), 1 << 20); + EXPECT_EQ(options.maxCoalesceBytes(), 64 << 20); + EXPECT_EQ(options.prefetchRowGroups(), 4); +} + +TEST_F(ReaderOptionsTest, doubleSetIoStatsThrows) { + ReaderOptions options(pool_.get()); + options.setDataIoStats(dataIoStats_); + VELOX_ASSERT_THROW( + options.setDataIoStats(dataIoStats_), "dataIoStats already set"); + + ReaderOptions options2(pool_.get()); + options2.setMetadataIoStats(metadataIoStats_); + VELOX_ASSERT_THROW( + options2.setMetadataIoStats(metadataIoStats_), + "metadataIoStats already set"); + + ReaderOptions options3(pool_.get()); + options3.setIndexIoStats(indexIoStats_); + VELOX_ASSERT_THROW( + options3.setIndexIoStats(indexIoStats_), "indexIoStats already set"); +} + +TEST_F(ReaderOptionsTest, copyConstruct) { + ReaderOptions options(pool_.get()); + options.setDataIoStats(dataIoStats_); + options.setMetadataIoStats(metadataIoStats_); + options.setLoadQuantum(4 << 20); + + ReaderOptions copy(options); + EXPECT_EQ(©.memoryPool(), pool_.get()); + EXPECT_EQ(copy.dataIoStats(), dataIoStats_); + EXPECT_EQ(copy.metadataIoStats(), metadataIoStats_); + EXPECT_EQ(copy.indexIoStats(), nullptr); + EXPECT_EQ(copy.loadQuantum(), 4 << 20); +} + +} // namespace facebook::velox::io diff --git a/velox/common/memory/Allocation.cpp b/velox/common/memory/Allocation.cpp index 884af7c82cf..63fb4db4baa 100644 --- a/velox/common/memory/Allocation.cpp +++ b/velox/common/memory/Allocation.cpp @@ -35,10 +35,11 @@ void Allocation::append(uint8_t* address, MachinePageCount numPages) { runs_.empty() || address != runs_.back().data(), "Appending a duplicate address into a PageRun"); if (FOLLY_UNLIKELY(numPages > Allocation::PageRun::kMaxPagesInRun)) { - VELOX_MEM_ALLOC_ERROR(fmt::format( - "The number of pages to append {} exceeds the PageRun limit {}", - numPages, - Allocation::PageRun::kMaxPagesInRun)); + VELOX_MEM_ALLOC_ERROR( + fmt::format( + "The number of pages to append {} exceeds the PageRun limit {}", + numPages, + Allocation::PageRun::kMaxPagesInRun)); } numPages_ += numPages; runs_.emplace_back(address, numPages); diff --git a/velox/common/memory/ArbitrationParticipant.cpp b/velox/common/memory/ArbitrationParticipant.cpp index dd679f81b06..7d74c98c731 100644 --- a/velox/common/memory/ArbitrationParticipant.cpp +++ b/velox/common/memory/ArbitrationParticipant.cpp @@ -441,8 +441,9 @@ ArbitrationTimedLock::ArbitrationTimedLock( uint64_t timeoutNs) : mutex_(mutex) { if (!mutex_.try_lock_for(std::chrono::nanoseconds(timeoutNs))) { - VELOX_MEM_ARBITRATION_TIMEOUT(fmt::format( - "Memory arbitration lock timed out when reclaiming from arbitration participant.")); + VELOX_MEM_ARBITRATION_TIMEOUT( + fmt::format( + "Memory arbitration lock timed out when reclaiming from arbitration participant.")); } } diff --git a/velox/common/memory/ByteStream.cpp b/velox/common/memory/ByteStream.cpp index 4fe93b56921..ebae051fc3a 100644 --- a/velox/common/memory/ByteStream.cpp +++ b/velox/common/memory/ByteStream.cpp @@ -16,19 +16,21 @@ #include "velox/common/memory/ByteStream.h" +#include + namespace facebook::velox { +static ByteRange convByteRange(folly::ByteRange br) { + return {const_cast(br.data()), folly::to_signed(br.size()), 0}; +} + std::vector byteRangesFromIOBuf(folly::IOBuf* iobuf) { if (iobuf == nullptr) { return {}; } std::vector byteRanges; - auto* current = iobuf; - do { - byteRanges.push_back( - {current->writableData(), static_cast(current->length()), 0}); - current = current->next(); - } while (current != iobuf); + auto dst = std::back_inserter(byteRanges); + std::transform(iobuf->begin(), iobuf->end(), dst, convByteRange); return byteRanges; } @@ -55,9 +57,7 @@ std::string BufferInputStream::toString() const { } bool BufferInputStream::atEnd() const { - if (current_ == nullptr) { - return false; - } + VELOX_DCHECK_NOT_NULL(current_); if (current_->position < current_->size) { return false; } @@ -75,9 +75,7 @@ size_t BufferInputStream::size() const { } size_t BufferInputStream::remainingSize() const { - if (ranges_.empty()) { - return 0; - } + VELOX_DCHECK(!ranges_.empty()); const auto* lastRange = &ranges_.back(); auto* cur = current_; size_t remainingBytes = cur->availableBytes(); @@ -88,10 +86,8 @@ size_t BufferInputStream::remainingSize() const { } std::streampos BufferInputStream::tellp() const { - if (ranges_.empty()) { - return 0; - } - assert(current_); + VELOX_DCHECK(!ranges_.empty()); + VELOX_DCHECK_NOT_NULL(current_); int64_t size = 0; for (auto& range : ranges_) { if (&range == current_) { @@ -103,9 +99,7 @@ std::streampos BufferInputStream::tellp() const { } void BufferInputStream::seekp(std::streampos position) { - if (ranges_.empty() && position == 0) { - return; - } + VELOX_DCHECK(!ranges_.empty()); int64_t toSkip = position; for (auto& range : ranges_) { if (toSkip <= range.size) { diff --git a/velox/common/memory/ByteStream.h b/velox/common/memory/ByteStream.h index fee24a1f2e8..b8bfc1b20f9 100644 --- a/velox/common/memory/ByteStream.h +++ b/velox/common/memory/ByteStream.h @@ -259,6 +259,8 @@ class BufferInputStream : public ByteInputStream { return ranges_; } + // The byte ranges backing this stream. Guaranteed to be non-empty after + // construction. std::vector ranges_; }; @@ -361,9 +363,10 @@ class ByteOutputStream { } if (current_->position + sizeof(T) * values.size() > current_->size) { - appendStringView(std::string_view( - reinterpret_cast(&values[0]), - values.size() * sizeof(T))); + appendStringView( + std::string_view( + reinterpret_cast(&values[0]), + values.size() * sizeof(T))); return; } auto* target = current_->buffer + current_->position; @@ -414,7 +417,7 @@ class ByteOutputStream { auto* buffer = current_->buffer + (position >> 3); auto value = folly::loadUnaligned(buffer); value = (value & mask) | (bits[0] << offset); - folly::storeUnaligned(buffer, value); + folly::storeUnaligned(buffer, value); current_->position += end; return; } @@ -537,9 +540,10 @@ class AppendWindow { ~AppendWindow() noexcept { if (scratchPtr_.size()) { try { - stream_.appendStringView(std::string_view( - reinterpret_cast(scratchPtr_.get()), - scratchPtr_.size() * sizeof(T))); + stream_.appendStringView( + std::string_view( + reinterpret_cast(scratchPtr_.get()), + scratchPtr_.size() * sizeof(T))); } catch (const std::exception& e) { // This is impossible because construction ensures there is space for // the bytes in the stream. diff --git a/velox/common/memory/CMakeLists.txt b/velox/common/memory/CMakeLists.txt index a36e0f06a6b..d6353c696cc 100644 --- a/velox/common/memory/CMakeLists.txt +++ b/velox/common/memory/CMakeLists.txt @@ -33,6 +33,25 @@ velox_add_library( RawVector.cpp SharedArbitrator.cpp StreamArena.cpp + HEADERS + Allocation.h + AllocationPool.h + ArbitrationOperation.h + ArbitrationParticipant.h + ByteStream.h + CompactDoubleList.h + HashStringAllocator.h + MallocAllocator.h + Memory.h + MemoryAllocator.h + MemoryArbitrator.h + MemoryPool.h + MmapAllocator.h + MmapArena.h + RawVector.h + Scratch.h + SharedArbitrator.h + StreamArena.h ) velox_link_libraries( diff --git a/velox/common/memory/HashStringAllocator.h b/velox/common/memory/HashStringAllocator.h index 253bc4f27a1..6e153894dfc 100644 --- a/velox/common/memory/HashStringAllocator.h +++ b/velox/common/memory/HashStringAllocator.h @@ -19,7 +19,6 @@ #include "velox/common/memory/AllocationPool.h" #include "velox/common/memory/ByteStream.h" #include "velox/common/memory/CompactDoubleList.h" -#include "velox/common/memory/Memory.h" #include "velox/common/memory/StreamArena.h" #include "velox/type/StringView.h" @@ -27,6 +26,9 @@ namespace facebook::velox { +template +struct StlAllocator; + /// Implements an arena backed by memory::Allocation. This is for backing /// ByteOutputStream or for allocating single blocks. Blocks can be individually /// freed. Adjacent frees are coalesced and free blocks are kept in a free list. @@ -41,6 +43,9 @@ namespace facebook::velox { /// backing a HashStringAllocator is set to kArenaEnd. class HashStringAllocator : public StreamArena { public: + template + using TStlAllocator = StlAllocator; + /// The minimum allocation must have space after the header for the free list /// pointers and the trailing length. static constexpr int32_t kMinAlloc = @@ -659,8 +664,8 @@ class RowSizeTracker { counter_(counter) {} ~RowSizeTracker() { - auto delta = allocator_->currentBytes() - size_; - if (delta) { + const auto delta = allocator_->currentBytes() - size_; + if (delta != 0) { saturatingIncrement(&counter_, delta); } } @@ -668,7 +673,7 @@ class RowSizeTracker { private: // Increments T at *pointer without wrapping around at overflow. void saturatingIncrement(T* pointer, int64_t delta) { - auto value = *reinterpret_cast(pointer) + delta; + const auto value = *reinterpret_cast(pointer) + delta; *reinterpret_cast(pointer) = std::min(value, std::numeric_limits::max()); } @@ -688,8 +693,12 @@ struct StlAllocator { VELOX_CHECK_NOT_NULL(allocator); } + // We can use "explicit" here based on the C++ standard. But + // libstdc++ 12 or older doesn't work for std::vector and + // "explicit". We can avoid it by not using "explicit" here. + // See also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=115854 template - explicit StlAllocator(const StlAllocator& allocator) + StlAllocator(const StlAllocator& allocator) : allocator_{allocator.allocator()} { VELOX_CHECK_NOT_NULL(allocator_); } diff --git a/velox/common/memory/MallocAllocator.cpp b/velox/common/memory/MallocAllocator.cpp index 297b2a2781c..dc6da02c6e6 100644 --- a/velox/common/memory/MallocAllocator.cpp +++ b/velox/common/memory/MallocAllocator.cpp @@ -15,15 +15,17 @@ */ #include "velox/common/memory/MallocAllocator.h" +#include #include "velox/common/memory/Memory.h" #include namespace facebook::velox::memory { -MallocAllocator::MallocAllocator(size_t capacity, uint32_t reservationByteLimit) +MallocAllocator::MallocAllocator(const Options& options) : kind_(MemoryAllocator::Kind::kMalloc), - capacity_(capacity), - reservationByteLimit_(reservationByteLimit), + mallocContiguousEnabled_(options.mallocContiguousEnabled), + capacity_(options.capacity), + reservationByteLimit_(options.reservationByteLimit), reserveFunc_( [this](uint32_t& counter, uint32_t increment, std::mutex& lock) { return incrementUsageWithReservationFunc(counter, increment, lock); @@ -33,7 +35,7 @@ MallocAllocator::MallocAllocator(size_t capacity, uint32_t reservationByteLimit) decrementUsageWithReservationFunc(counter, decrement, lock); return true; }), - reservations_(std::thread::hardware_concurrency()) {} + reservations_(folly::available_concurrency()) {} MallocAllocator::~MallocAllocator() { // TODO: Remove the check when memory leak issue is resolved. @@ -144,11 +146,7 @@ bool MallocAllocator::allocateContiguousImpl( } auto numContiguousCollateralPages = allocation.numPages(); if (numContiguousCollateralPages > 0) { - useHugePages(allocation, false); - if (::munmap(allocation.data(), allocation.maxSize()) < 0) { - VELOX_MEM_LOG(ERROR) << "munmap got " << folly::errnoStr(errno) << "for " - << allocation.data() << ", " << allocation.size(); - } + dispatchFreeContiguous(allocation); numMapped_.fetch_sub(numContiguousCollateralPages); numAllocated_.fetch_sub(numContiguousCollateralPages); numExternalMapped_.fetch_sub(numContiguousCollateralPages); @@ -174,19 +172,23 @@ bool MallocAllocator::allocateContiguousImpl( numAllocated_.fetch_add(numPages); numMapped_.fetch_add(numPages); numExternalMapped_.fetch_add(numPages); - void* data = ::mmap( - nullptr, - AllocationTraits::pageBytes(maxPages), - PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, - -1, - 0); - // TODO: add handling of MAP_FAILED. - allocation.set( - data, - AllocationTraits::pageBytes(numPages), - AllocationTraits::pageBytes(maxPages)); - useHugePages(allocation, true); + const auto maxBytes = AllocationTraits::pageBytes(maxPages); + void* data = dispatchAllocateContiguous(maxBytes); + if (FOLLY_UNLIKELY(data == nullptr)) { + numAllocated_.fetch_sub(numPages); + numMapped_.fetch_sub(numPages); + numExternalMapped_.fetch_sub(numPages); + decrementUsage(static_cast(AllocationTraits::pageBytes(numPages))); + const auto errorMsg = fmt::format( + "Failed to allocate {} of contiguous memory", succinctBytes(maxBytes)); + VELOX_MEM_LOG(WARNING) << errorMsg; + setAllocatorFailureMessage(errorMsg); + return false; + } + allocation.set(data, AllocationTraits::pageBytes(numPages), maxBytes); + if (!mallocContiguousEnabled_) { + useHugePages(allocation, true); + } return true; } @@ -221,14 +223,9 @@ void MallocAllocator::freeContiguousImpl(ContiguousAllocation& allocation) { if (allocation.empty()) { return; } - useHugePages(allocation, false); const auto bytes = allocation.size(); const auto numPages = allocation.numPages(); - if (::munmap(allocation.data(), allocation.maxSize()) < 0) { - VELOX_MEM_LOG(ERROR) << "Error for munmap(" << allocation.data() << ", " - << succinctBytes(bytes) << "): '" - << folly::errnoStr(errno) << "'"; - } + dispatchFreeContiguous(allocation); numMapped_.fetch_sub(numPages); numAllocated_.fetch_sub(numPages); numExternalMapped_.fetch_sub(numPages); @@ -236,6 +233,36 @@ void MallocAllocator::freeContiguousImpl(ContiguousAllocation& allocation) { allocation.clear(); } +void* MallocAllocator::dispatchAllocateContiguous(size_t maxBytes) { + if (testingHasInjectedFailure(InjectedFailure::kAllocate)) { + return nullptr; + } + if (mallocContiguousEnabled_) { + return ::aligned_alloc(AllocationTraits::kPageSize, maxBytes); + } + // TODO: add handling of MAP_FAILED. + return ::mmap( + nullptr, + maxBytes, + PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, + -1, + 0); +} + +void MallocAllocator::dispatchFreeContiguous(ContiguousAllocation& allocation) { + if (mallocContiguousEnabled_) { + ::free(allocation.data()); + } else { + useHugePages(allocation, false); + if (::munmap(allocation.data(), allocation.maxSize()) < 0) { + VELOX_MEM_LOG(ERROR) << "Error for munmap(" << allocation.data() << ", " + << succinctBytes(allocation.size()) << "): '" + << folly::errnoStr(errno) << "'"; + } + } +} + bool MallocAllocator::growContiguousWithoutRetry( MachinePageCount increment, ContiguousAllocation& allocation) { @@ -308,6 +335,61 @@ void* MallocAllocator::allocateZeroFilledWithoutRetry(uint64_t bytes) { return result; } +void* MallocAllocator::reallocateBytesWithoutRetry( + void* p, + uint64_t oldSize, + uint64_t newSize, + uint16_t alignment) { + if (p == nullptr || alignment > kMinAlignment || + (reinterpret_cast(p) % alignment) != 0) { + return nullptr; + } + if (!isAlignmentValid(newSize, alignment)) { + VELOX_FAIL( + "Alignment check failed, reallocateBytes {}, alignmentBytes {}", + newSize, + alignment); + } + const auto delta = + static_cast(newSize) - static_cast(oldSize); + if (delta > 0 && !incrementUsage(delta)) { + auto errorMsg = fmt::format( + "Failed to reallocate: exceeded memory allocator limit of {}. " + "Old size: {}, new size: {}", + succinctBytes(capacity_), + succinctBytes(oldSize), + succinctBytes(newSize)); + VELOX_MEM_LOG_EVERY_MS(WARNING, 1000) << errorMsg; + setAllocatorFailureMessage(errorMsg); + return nullptr; + } + void* result = ::realloc(p, newSize); // NOLINT + if (result == nullptr) { + // realloc failed. The original pointer is still valid. + if (delta > 0) { + decrementUsage(delta); + } + VELOX_MEM_LOG(ERROR) << "Failed to reallocateBytes from " + << succinctBytes(oldSize) << " to " + << succinctBytes(newSize); + return nullptr; + } + // ::realloc() must return a pointer aligned to the requested alignment. + // The C standard guarantees alignof(max_align_t), and we already enforced + // alignment <= kMinAlignment above, so this should always hold for a + // conformant allocator. Fail loudly if it doesn't (e.g., a non-conformant + // LD_PRELOAD interposer). + VELOX_CHECK_EQ( + reinterpret_cast(result) % alignment, + 0, + "::realloc returned a pointer not aligned to requested alignment: {}", + alignment); + if (delta < 0) { + decrementUsage(-delta); + } + return result; +} + void MallocAllocator::freeBytes(void* p, uint64_t bytes) noexcept { ::free(p); // NOLINT decrementUsage(bytes); diff --git a/velox/common/memory/MallocAllocator.h b/velox/common/memory/MallocAllocator.h index aadeb5518b7..f3d020e05ac 100644 --- a/velox/common/memory/MallocAllocator.h +++ b/velox/common/memory/MallocAllocator.h @@ -26,7 +26,7 @@ namespace facebook::velox::memory { /// The implementation of MemoryAllocator using malloc. class MallocAllocator : public MemoryAllocator { public: - MallocAllocator(size_t capacity, uint32_t reservationByteLimit); + explicit MallocAllocator(const Options& options); ~MallocAllocator() override; @@ -102,12 +102,40 @@ class MallocAllocator : public MemoryAllocator { ContiguousAllocation& allocation, MachinePageCount maxPages); + // Allocates 'maxBytes' of contiguous memory using malloc or mmap depending + // on 'mallocContiguousEnabled_'. Returns the allocated pointer, or nullptr + // on failure. + void* dispatchAllocateContiguous(size_t maxBytes); + + // Frees contiguous memory previously allocated by + // dispatchAllocateContiguous. + void dispatchFreeContiguous(ContiguousAllocation& allocation); + void freeContiguousImpl(ContiguousAllocation& allocation); void* allocateBytesWithoutRetry(uint64_t bytes, uint16_t alignment) override; void* allocateZeroFilledWithoutRetry(uint64_t bytes) override; + /// Attempts in-place reallocation via ::realloc(). Returns nullptr when + /// ::realloc() cannot satisfy the request, in which case the caller falls + /// back to allocate + memcpy + free. Preconditions for using ::realloc(): + /// 1. p must be non-null. ::realloc(nullptr, n) would behave like malloc, + /// but we route nullptr through allocateBytes so the caller picks up + /// cache eviction on the fallback. + /// 2. The requested alignment must be at most kMinAlignment. ::realloc() + /// guarantees only kMinAlignment on the returned pointer (and may move + /// the data to a new address), so larger alignments cannot be honored. + /// 3. p must already be aligned to the requested alignment. MemoryAllocator + /// is a standalone module that may be used outside MemoryPool, so we + /// cannot rely on the caller to pass an alignment matching the original + /// allocation. + void* reallocateBytesWithoutRetry( + void* p, + uint64_t oldSize, + uint64_t newSize, + uint16_t alignment) override; + // Increments current usage and check current 'allocatedBytes_' counter to // make sure current usage does not go above 'capacity_'. If it goes above // 'capacity_', the increment will not be applied. Returns true if within @@ -206,16 +234,20 @@ class MallocAllocator : public MemoryAllocator { if (originalBytes - bytes < 0) { // In case of inconsistency while freeing memory, do not revert in this // case because free is guaranteed to happen. - VELOX_MEM_ALLOC_ERROR(fmt::format( - "Trying to free {} bytes, which is larger than current allocated " - "bytes {}", - bytes, - originalBytes)) + VELOX_MEM_ALLOC_ERROR( + fmt::format( + "Trying to free {} bytes, which is larger than current allocated " + "bytes {}", + bytes, + originalBytes)) } } const Kind kind_; + // If true, use malloc for contiguous allocations instead of mmap/munmap. + const bool mallocContiguousEnabled_; + // Capacity in bytes. Total allocation byte is not allowed to exceed this // value. const size_t capacity_; diff --git a/velox/common/memory/Memory.cpp b/velox/common/memory/Memory.cpp index 01de413c830..45a185cae7e 100644 --- a/velox/common/memory/Memory.cpp +++ b/velox/common/memory/Memory.cpp @@ -46,17 +46,21 @@ SingletonState& singletonState() { std::shared_ptr createAllocator( const MemoryManager::Options& options) { + MemoryAllocator::Options allocatorOptions; + allocatorOptions.capacity = options.allocatorCapacity; if (options.useMmapAllocator) { - MmapAllocator::Options mmapOptions; - mmapOptions.capacity = options.allocatorCapacity; - mmapOptions.largestSizeClass = options.largestSizeClassPages; - mmapOptions.useMmapArena = options.useMmapArena; - mmapOptions.mmapArenaCapacityRatio = options.mmapArenaCapacityRatio; - return std::make_shared(mmapOptions); + allocatorOptions.largestSizeClass = options.largestSizeClassPages; + allocatorOptions.useMmapArena = options.useMmapArena; + allocatorOptions.mmapArenaCapacityRatio = options.mmapArenaCapacityRatio; + allocatorOptions.smallAllocationReservePct = + options.smallAllocationReservePct; + allocatorOptions.maxMallocBytes = options.maxMallocBytes; + return std::make_shared(allocatorOptions); } else { - return std::make_shared( - options.allocatorCapacity, - options.allocationSizeThresholdWithReservation); + allocatorOptions.reservationByteLimit = + options.allocationSizeThresholdWithReservation; + allocatorOptions.mallocContiguousEnabled = options.mallocContiguousEnabled; + return std::make_shared(allocatorOptions); } } @@ -382,23 +386,23 @@ std::string MemoryManager::toString(bool detail) const { : succinctBytes(allocatorCapacity)) << " alignment " << succinctBytes(alignment_) << " usedBytes " << succinctBytes(getTotalBytes()) << " number of pools " << numPools() - << "\n"; + << '\n'; out << "List of root pools:\n"; if (detail) { out << sysRoot_->treeMemoryUsage(false); } else { - out << "\t" << sysRoot_->name() << "\n"; + out << '\t' << sysRoot_->name() << '\n'; } std::vector> pools = getAlivePools(); for (const auto& pool : pools) { if (detail) { out << pool->treeMemoryUsage(false); } else { - out << "\t" << pool->name() << "\n"; + out << '\t' << pool->name() << '\n'; } - out << "\trefcount " << pool.use_count() << "\n"; + out << "\trefcount " << pool.use_count() << '\n'; } - out << allocator_->toString() << "\n"; + out << allocator_->toString() << '\n'; out << arbitrator_->toString(); out << "]"; return out.str(); diff --git a/velox/common/memory/Memory.h b/velox/common/memory/Memory.h index ce8caf1821f..675f72a3fcc 100644 --- a/velox/common/memory/Memory.h +++ b/velox/common/memory/Memory.h @@ -41,7 +41,6 @@ #include "velox/common/memory/MemoryPool.h" DECLARE_bool(velox_memory_leak_check_enabled); -DECLARE_bool(velox_memory_pool_debug_enabled); DECLARE_bool(velox_enable_memory_usage_track_in_default_memory_pool); namespace facebook::velox::memory { @@ -140,6 +139,12 @@ class MemoryManager { /// NOTE: this only applies for MallocAllocator. uint32_t allocationSizeThresholdWithReservation{1 << 20}; + /// If true, MallocAllocator uses malloc for contiguous allocations instead + /// of mmap/munmap. + /// + /// NOTE: this only applies for MallocAllocator. + bool mallocContiguousEnabled{false}; + /// ================== 'MemoryArbitrator' settings ================= /// Memory capacity available for query/task memory pools. This capacity diff --git a/velox/common/memory/MemoryAllocator.cpp b/velox/common/memory/MemoryAllocator.cpp index 2a997ad0045..21aae7906e3 100644 --- a/velox/common/memory/MemoryAllocator.cpp +++ b/velox/common/memory/MemoryAllocator.cpp @@ -15,10 +15,11 @@ */ #include "velox/common/memory/MemoryAllocator.h" -#include "velox/common/memory/MallocAllocator.h" #include #include +#include +#include #include #include @@ -29,6 +30,14 @@ DECLARE_bool(velox_memory_use_hugepages); namespace facebook::velox::memory { +void AcquiredMemory::free(MemoryAllocator* allocator) { + allocator->freeNonContiguous(nonContiguousAllocs); + for (auto& [ptr, size] : byteAllocations) { + allocator->freeBytes(ptr, size); + } + byteAllocations.clear(); +} + // static std::vector MemoryAllocator::makeSizeClassSizes( MachinePageCount largest) { @@ -218,8 +227,9 @@ bool MemoryAllocator::allocateNonContiguous( success = allocateNonContiguousWithoutRetry(mix, out); } else { success = cache()->makeSpace( - pagesToAcquire(numPages, out.numPages()), [&](Allocation& acquired) { - freeNonContiguous(acquired); + pagesToAcquire(numPages, out.numPages()), + [&](AcquiredMemory& acquired) { + acquired.free(this); return allocateNonContiguousWithoutRetry(mix, out); }); } @@ -286,8 +296,8 @@ bool MemoryAllocator::allocateContiguous( } else { success = cache()->makeSpace( pagesToAcquire(numPages, numCollateralPages), - [&](Allocation& acquired) { - freeNonContiguous(acquired); + [&](AcquiredMemory& acquired) { + acquired.free(this); return allocateContiguousWithoutRetry( numPages, collateral, allocation, maxPages); }); @@ -320,8 +330,8 @@ bool MemoryAllocator::growContiguous( if (cache() == nullptr) { success = growContiguousWithoutRetry(increment, allocation); } else { - success = cache()->makeSpace(increment, [&](Allocation& acquired) { - freeNonContiguous(acquired); + success = cache()->makeSpace(increment, [&](AcquiredMemory& acquired) { + acquired.free(this); return growContiguousWithoutRetry(increment, allocation); }); } @@ -337,8 +347,8 @@ void* MemoryAllocator::allocateBytes(uint64_t bytes, uint16_t alignment) { } void* result = nullptr; cache()->makeSpace( - AllocationTraits::numPages(bytes), [&](Allocation& acquired) { - freeNonContiguous(acquired); + AllocationTraits::numPages(bytes), [&](AcquiredMemory& acquired) { + acquired.free(this); result = allocateBytesWithoutRetry(bytes, alignment); return result != nullptr; }); @@ -351,8 +361,8 @@ void* MemoryAllocator::allocateZeroFilled(uint64_t bytes) { } void* result = nullptr; cache()->makeSpace( - AllocationTraits::numPages(bytes), [&](Allocation& acquired) { - freeNonContiguous(acquired); + AllocationTraits::numPages(bytes), [&](AcquiredMemory& acquired) { + acquired.free(this); result = allocateZeroFilledWithoutRetry(bytes); return result != nullptr; }); @@ -367,6 +377,40 @@ void* MemoryAllocator::allocateZeroFilledWithoutRetry(uint64_t bytes) { return result; } +void* MemoryAllocator::reallocateBytes( + void* p, + uint64_t oldSize, + uint64_t newSize, + uint16_t alignment) { + // Try in-place reallocation first (supported by MallocAllocator via + // ::realloc(), which jemalloc can often service without moving data). + void* result = reallocateBytesWithoutRetry(p, oldSize, newSize, alignment); + if (result != nullptr) { + return result; + } + // Fallback: allocate new + memcpy + free old. This path also handles + // cache eviction via allocateBytes. + void* newP = allocateBytes(newSize, alignment); + if (newP == nullptr) { + return nullptr; + } + if (p != nullptr) { + ::memcpy(newP, p, std::min(oldSize, newSize)); + freeBytes(p, oldSize); + } + return newP; +} + +void* MemoryAllocator::reallocateBytesWithoutRetry( + void* /*p*/, + uint64_t /*oldSize*/, + uint64_t /*newSize*/, + uint16_t /*alignment*/) { + // Default: in-place reallocation not supported. Caller will fall back to + // allocateBytes + memcpy + freeBytes. + return nullptr; +} + Stats Stats::operator-(const Stats& other) const { Stats result; for (auto i = 0; i < sizes.size(); ++i) { diff --git a/velox/common/memory/MemoryAllocator.h b/velox/common/memory/MemoryAllocator.h index 621fe2e2a88..bfa5dde7857 100644 --- a/velox/common/memory/MemoryAllocator.h +++ b/velox/common/memory/MemoryAllocator.h @@ -140,23 +140,44 @@ struct Stats { class MemoryAllocator; +/// Memory collected from evicted cache entries. Non-contiguous page +/// allocations and byte allocations (from allocateBytes()) are tracked +/// separately so they can be freed back to the allocator before retrying +/// allocation. +struct AcquiredMemory { + Allocation nonContiguousAllocs; + // Byte allocations from allocateBytes(). + std::vector> byteAllocations; + + uint64_t totalBytes() const { + uint64_t bytes = nonContiguousAllocs.byteSize(); + for (const auto& [_, size] : byteAllocations) { + bytes += size; + } + return bytes; + } + + void free(MemoryAllocator* allocator); + + bool empty() const { + return nonContiguousAllocs.empty() && byteAllocations.empty(); + } +}; + /// A general cache interface using 'MemoryAllocator' to allocate memory, that /// is also able to free up memory upon request by shrinking itself. class Cache { public: virtual ~Cache() = default; - /// This method should be implemented so that it tries to - /// accommodate the passed in 'allocate' by freeing up space from - /// 'this' if needed. 'numPages' is the number of pages 'allocate - /// needs to be free for allocate to succeed. This should return - /// true if 'allocate' succeeds, and false otherwise. 'numPages' can - /// be less than the planned allocation, even 0 but not - /// negative. This is possible if 'allocate' brings its own memory - /// that is exchanged for the new allocation. + /// Tries to accommodate 'allocate' by freeing up space from 'this' + /// if needed. 'numPages' is the number of pages 'allocate' needs to + /// be free to succeed. Returns true if 'allocate' succeeds. + /// 'allocate' receives an 'AcquiredMemory' containing evicted memory + /// that should be freed before retrying allocation. virtual bool makeSpace( memory::MachinePageCount numPages, - std::function allocate) = 0; + std::function allocate) = 0; /// This method is implemented to shrink the cache space with the specified /// 'targetBytes'. The method returns the actually freed cache space in bytes. @@ -191,6 +212,45 @@ std::string getAndClearCacheFailureMessage(); /// tracking while delegating the allocation to a root allocator. class MemoryAllocator : public std::enable_shared_from_this { public: + struct Options { + /// Capacity in bytes, default unlimited. + /// Applies to: MallocAllocator, MmapAllocator. + size_t capacity{static_cast(std::numeric_limits::max())}; + + /// --- MallocAllocator-only options --- + + /// Allocation size threshold below which allocations use sharded local + /// counters instead of updating the global counter. Default 1MB. + uint32_t reservationByteLimit{1 << 20}; + + /// If true, use malloc for contiguous allocations instead of mmap/munmap. + bool mallocContiguousEnabled{false}; + + /// --- MmapAllocator-only options --- + + /// Number of pages in the largest size class. + int32_t largestSizeClass{256}; + + /// If set true, allocations larger than largest size class size will be + /// delegated to ManagedMmapArena. Otherwise a system mmap call will be + /// issued for each such allocation. + bool useMmapArena{false}; + + /// Used to determine MmapArena capacity. The ratio represents system + /// memory capacity to single MmapArena capacity ratio. + int32_t mmapArenaCapacityRatio{10}; + + /// If not zero, reserve 'smallAllocationReservePct'% of space from + /// 'capacity' for ad hoc small allocations delegated to std::malloc. + /// If 'maxMallocBytes' is 0, this value will be disregarded. + uint32_t smallAllocationReservePct{0}; + + /// The allocation threshold less than which an allocation is delegated + /// to std::malloc(). If zero, no allocations are delegated to malloc + /// and 'smallAllocationReservePct' is automatically set to 0. + int32_t maxMallocBytes{3072}; + }; + /// Defines the memory allocator kinds. enum class Kind { /// The default memory allocator kind which is implemented by @@ -311,6 +371,20 @@ class MemoryAllocator : public std::enable_shared_from_this { /// 'cache()' if registered. But sufficient space is not guaranteed. void* allocateZeroFilled(uint64_t bytes); + /// Reallocates contiguous memory. Tries in-place reallocation first (via the + /// allocator-specific reallocateBytesWithoutRetry), falling back to + /// allocateBytes + memcpy + freeBytes if the allocator does not support + /// in-place reallocation. Returns nullptr on failure. + /// + /// When the underlying allocator is MallocAllocator (backed by jemalloc), + /// this uses ::realloc() which can often expand the allocation in-place, + /// avoiding the expensive memcpy. + void* reallocateBytes( + void* p, + uint64_t oldSize, + uint64_t newSize, + uint16_t alignment = kMinAlignment); + /// Frees contiguous memory allocated by allocateBytes, allocateZeroFilled, /// reallocateBytes. virtual void freeBytes(void* p, uint64_t size) noexcept = 0; @@ -451,6 +525,16 @@ class MemoryAllocator : public std::enable_shared_from_this { virtual void* allocateZeroFilledWithoutRetry(uint64_t bytes); + // Attempts to reallocate 'p' from 'oldSize' to 'newSize' bytes without + // retry through cache eviction. Returns nullptr if in-place reallocation + // is not supported or fails. The default implementation always returns + // nullptr; MallocAllocator overrides this to use ::realloc(). + virtual void* reallocateBytesWithoutRetry( + void* p, + uint64_t oldSize, + uint64_t newSize, + uint16_t alignment); + virtual bool growContiguousWithoutRetry( MachinePageCount increment, ContiguousAllocation& allocation) = 0; diff --git a/velox/common/memory/MemoryArbitrator.cpp b/velox/common/memory/MemoryArbitrator.cpp index eb817300608..da6f1a812a1 100644 --- a/velox/common/memory/MemoryArbitrator.cpp +++ b/velox/common/memory/MemoryArbitrator.cpp @@ -99,7 +99,11 @@ class NoopArbitrator : public MemoryArbitrator { } void removePool(MemoryPool* pool) override { - VELOX_CHECK_EQ(pool->reservedBytes(), 0); + VELOX_CHECK_EQ( + pool->reservedBytes(), + 0, + "Memory pool has unexpected reserved bytes on removal: {}", + pool->name()); } // Noop arbitrator has no memory capacity limit so no operation needed for @@ -263,9 +267,10 @@ uint64_t MemoryReclaimer::reclaim( nonReclaimableCandidates.push_back(Candidate{std::move(child), 0}); continue; } - candidates.push_back(Candidate{ - std::move(child), - static_cast(reclaimableBytesOpt.value())}); + candidates.push_back( + Candidate{ + std::move(child), + static_cast(reclaimableBytesOpt.value())}); } } } @@ -539,10 +544,14 @@ ScopedReclaimedBytesRecorder::~ScopedReclaimedBytesRecorder() { const int64_t reservedBytesAfterReclaim = pool_->reservedBytes(); if (reservedBytesAfterReclaim > reservedBytesBeforeReclaim_) { LOG(ERROR) << "Unexpected reserved bytes growth from " << pool_->name() + << ", root pool: " << pool_->root()->name() << " after memory reclaim from " << succinctBytes(reservedBytesBeforeReclaim_) << " to " - << succinctBytes(reservedBytesAfterReclaim) << ", current usage " - << succinctBytes(pool_->usedBytes()); + << succinctBytes(reservedBytesAfterReclaim) + << ", used: " << succinctBytes(pool_->usedBytes()) + << ", reservation: " << succinctBytes(pool_->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool_->root()->reservedBytes()); } *reclaimedBytes_ = reservedBytesBeforeReclaim_ - reservedBytesAfterReclaim; } diff --git a/velox/common/memory/MemoryPool.cpp b/velox/common/memory/MemoryPool.cpp index 2a05c4774c5..1a902589bd3 100644 --- a/velox/common/memory/MemoryPool.cpp +++ b/velox/common/memory/MemoryPool.cpp @@ -18,6 +18,7 @@ #include +#include "velox/common/Casts.h" #include "velox/common/base/Counters.h" #include "velox/common/base/StatsReporter.h" #include "velox/common/base/SuccinctPrinter.h" @@ -31,6 +32,13 @@ DEFINE_bool( false, "Whether allow to memory capacity transfer between memory pools from different tasks, which might happen in use case like Spark-Gluten"); +DEFINE_bool( + velox_enable_inplace_realloc, + true, + "If true, MemoryPool::reallocate uses MemoryAllocator::reallocateBytes " + "which tries in-place reallocation via ::realloc() (jemalloc can often " + "expand without memcpy). If false, uses the legacy alloc+memcpy+free path."); + DECLARE_bool(velox_suppress_memory_capacity_exceeding_error_message); using facebook::velox::common::testutil::TestValue; @@ -38,14 +46,14 @@ using facebook::velox::common::testutil::TestValue; namespace facebook::velox::memory { namespace { // Check if memory operation is allowed and increment the named stats. -#define CHECK_AND_INC_MEM_OP_STATS(stats) \ +#define CHECK_AND_INC_MEM_OP_STATS(pool, stats) \ do { \ - if (FOLLY_UNLIKELY(kind_ != Kind::kLeaf)) { \ + if (FOLLY_UNLIKELY(pool->kind_ != Kind::kLeaf)) { \ VELOX_FAIL( \ "Memory operation is only allowed on leaf memory pool: {}", \ - toString()); \ + pool->toString()); \ } \ - ++num##stats##_; \ + ++pool->num##stats##_; \ } while (0) // Check if memory operation is allowed and increment the named stats. @@ -153,9 +161,9 @@ std::string capacityToString(int64_t capacity) { return capacity == kMaxMemory ? "UNLIMITED" : succinctBytes(capacity); } -#define DEBUG_RECORD_ALLOC(...) \ - if (FOLLY_UNLIKELY(debugEnabled())) { \ - recordAllocDbg(__VA_ARGS__); \ +#define DEBUG_RECORD_ALLOC(pool, ...) \ + if (FOLLY_UNLIKELY(pool->debugEnabled())) { \ + pool->recordAllocDbg(__VA_ARGS__); \ } #define DEBUG_RECORD_FREE(...) \ if (FOLLY_UNLIKELY(debugEnabled())) { \ @@ -521,60 +529,93 @@ void* MemoryPoolImpl::allocate( } } - CHECK_AND_INC_MEM_OP_STATS(Allocs); + CHECK_AND_INC_MEM_OP_STATS(this, Allocs); const auto alignedSize = sizeAlign(size); reserve(alignedSize); void* buffer = allocator_->allocateBytes(alignedSize, alignment_); if (FOLLY_UNLIKELY(buffer == nullptr)) { release(alignedSize); - handleAllocationFailure(fmt::format( - "{} failed with {} from {} {}", - __FUNCTION__, - succinctBytes(size), - toString(), - allocator_->getAndClearFailureMessage())); - } - DEBUG_RECORD_ALLOC(buffer, size); + handleAllocationFailure( + fmt::format( + "{} failed with {} from {} {}", + __FUNCTION__, + succinctBytes(size), + toString(), + allocator_->getAndClearFailureMessage())); + } + DEBUG_RECORD_ALLOC(this, buffer, size); return buffer; } void* MemoryPoolImpl::allocateZeroFilled(int64_t numEntries, int64_t sizeEach) { - CHECK_AND_INC_MEM_OP_STATS(Allocs); + CHECK_AND_INC_MEM_OP_STATS(this, Allocs); const auto size = sizeEach * numEntries; const auto alignedSize = sizeAlign(size); reserve(alignedSize); void* buffer = allocator_->allocateZeroFilled(alignedSize); if (FOLLY_UNLIKELY(buffer == nullptr)) { release(alignedSize); - handleAllocationFailure(fmt::format( - "{} failed with {} entries and {} each from {} {}", - __FUNCTION__, - numEntries, - succinctBytes(sizeEach), - toString(), - allocator_->getAndClearFailureMessage())); - } - DEBUG_RECORD_ALLOC(buffer, size); + handleAllocationFailure( + fmt::format( + "{} failed with {} entries and {} each from {} {}", + __FUNCTION__, + numEntries, + succinctBytes(sizeEach), + toString(), + allocator_->getAndClearFailureMessage())); + } + DEBUG_RECORD_ALLOC(this, buffer, size); return buffer; } void* MemoryPoolImpl::reallocate(void* p, int64_t size, int64_t newSize) { - CHECK_AND_INC_MEM_OP_STATS(Allocs); + CHECK_AND_INC_MEM_OP_STATS(this, Allocs); const auto alignedNewSize = sizeAlign(newSize); + const auto alignedOldSize = sizeAlign(size); reserve(alignedNewSize); - void* newP = allocator_->allocateBytes(alignedNewSize, alignment_); + // Two fully separate branches are kept here intentionally so the legacy + // path can be deleted as a single block once the in-place path is rolled + // out and FLAGS_velox_enable_inplace_realloc is removed. + if (FLAGS_velox_enable_inplace_realloc) { + // In-place path: try ::realloc() via reallocateBytes, which jemalloc can + // often service without moving data (avoiding the expensive memcpy). + void* const newP = allocator_->reallocateBytes( + p, alignedOldSize, alignedNewSize, alignment_); + if (newP == nullptr) { + release(alignedNewSize); + handleAllocationFailure( + fmt::format( + "{} failed with new {} and old {} from {} {}", + __FUNCTION__, + succinctBytes(newSize), + succinctBytes(size), + toString(), + allocator_->getAndClearFailureMessage())); + } + if (p != nullptr) { + INC_MEM_OP_STATS(Frees); + DEBUG_RECORD_FREE(p, size); + release(alignedOldSize); + } + DEBUG_RECORD_ALLOC(this, newP, newSize); + return newP; + } + + // Legacy path: allocate new + memcpy + free old. + void* const newP = allocator_->allocateBytes(alignedNewSize, alignment_); if (FOLLY_UNLIKELY(newP == nullptr)) { release(alignedNewSize); - handleAllocationFailure(fmt::format( - "{} failed with new {} and old {} from {} {}", - __FUNCTION__, - succinctBytes(newSize), - succinctBytes(size), - toString(), - allocator_->getAndClearFailureMessage())); - } - DEBUG_RECORD_ALLOC(newP, newSize); + handleAllocationFailure( + fmt::format( + "{} failed with new {} and old {} from {} {}", + __FUNCTION__, + succinctBytes(newSize), + succinctBytes(size), + toString(), + allocator_->getAndClearFailureMessage())); + } + DEBUG_RECORD_ALLOC(this, newP, newSize); if (p != nullptr) { ::memcpy(newP, p, std::min(size, newSize)); free(p, size); @@ -583,18 +624,40 @@ void* MemoryPoolImpl::reallocate(void* p, int64_t size, int64_t newSize) { } void MemoryPoolImpl::free(void* p, int64_t size) { - CHECK_AND_INC_MEM_OP_STATS(Frees); + CHECK_AND_INC_MEM_OP_STATS(this, Frees); const auto alignedSize = sizeAlign(size); DEBUG_RECORD_FREE(p, size); allocator_->freeBytes(p, alignedSize); release(alignedSize); } +bool MemoryPoolImpl::transferTo(MemoryPool* dest, void* buffer, uint64_t size) { + if (!isLeaf() || !dest->isLeaf()) { + return false; + } + VELOX_CHECK_NOT_NULL(dest); + auto* destImpl = checkedPointerCast(dest); + if (allocator_ != destImpl->allocator_) { + return false; + } + + CHECK_AND_INC_MEM_OP_STATS(destImpl, Allocs); + const auto alignedSize = sizeAlign(size); + destImpl->reserve(alignedSize); + DEBUG_RECORD_ALLOC(destImpl, buffer, size); + + CHECK_AND_INC_MEM_OP_STATS(this, Frees); + DEBUG_RECORD_FREE(buffer, size); + release(alignedSize); + + return true; +} + void MemoryPoolImpl::allocateNonContiguous( MachinePageCount numPages, Allocation& out, MachinePageCount minSizeClass) { - CHECK_AND_INC_MEM_OP_STATS(Allocs); + CHECK_AND_INC_MEM_OP_STATS(this, Allocs); if (!out.empty()) { INC_MEM_OP_STATS(Frees); } @@ -615,21 +678,22 @@ void MemoryPoolImpl::allocateNonContiguous( }, minSizeClass)) { VELOX_CHECK(out.empty()); - handleAllocationFailure(fmt::format( - "{} failed with {} pages from {} {}", - __FUNCTION__, - numPages, - toString(), - allocator_->getAndClearFailureMessage())); - } - DEBUG_RECORD_ALLOC(out); + handleAllocationFailure( + fmt::format( + "{} failed with {} pages from {} {}", + __FUNCTION__, + numPages, + toString(), + allocator_->getAndClearFailureMessage())); + } + DEBUG_RECORD_ALLOC(this, out); VELOX_CHECK(!out.empty()); VELOX_CHECK_NULL(out.pool()); out.setPool(this); } void MemoryPoolImpl::freeNonContiguous(Allocation& allocation) { - CHECK_AND_INC_MEM_OP_STATS(Frees); + CHECK_AND_INC_MEM_OP_STATS(this, Frees); DEBUG_RECORD_FREE(allocation); const int64_t freedBytes = allocator_->freeNonContiguous(allocation); VELOX_CHECK(allocation.empty()); @@ -648,7 +712,7 @@ void MemoryPoolImpl::allocateContiguous( MachinePageCount numPages, ContiguousAllocation& out, MachinePageCount maxPages) { - CHECK_AND_INC_MEM_OP_STATS(Allocs); + CHECK_AND_INC_MEM_OP_STATS(this, Allocs); if (!out.empty()) { INC_MEM_OP_STATS(Frees); } @@ -667,21 +731,22 @@ void MemoryPoolImpl::allocateContiguous( }, maxPages)) { VELOX_CHECK(out.empty()); - handleAllocationFailure(fmt::format( - "{} failed with {} pages from {} {}", - __FUNCTION__, - numPages, - toString(), - allocator_->getAndClearFailureMessage())); - } - DEBUG_RECORD_ALLOC(out); + handleAllocationFailure( + fmt::format( + "{} failed with {} pages from {} {}", + __FUNCTION__, + numPages, + toString(), + allocator_->getAndClearFailureMessage())); + } + DEBUG_RECORD_ALLOC(this, out); VELOX_CHECK(!out.empty()); VELOX_CHECK_NULL(out.pool()); out.setPool(this); } void MemoryPoolImpl::freeContiguous(ContiguousAllocation& allocation) { - CHECK_AND_INC_MEM_OP_STATS(Frees); + CHECK_AND_INC_MEM_OP_STATS(this, Frees); const int64_t bytesToFree = allocation.size(); DEBUG_RECORD_FREE(allocation); allocator_->freeContiguous(allocation); @@ -700,12 +765,13 @@ void MemoryPoolImpl::growContiguous( release(allocBytes); } })) { - handleAllocationFailure(fmt::format( - "{} failed with {} pages from {} {}", - __FUNCTION__, - increment, - toString(), - allocator_->getAndClearFailureMessage())); + handleAllocationFailure( + fmt::format( + "{} failed with {} pages from {} {}", + __FUNCTION__, + increment, + toString(), + allocator_->getAndClearFailureMessage())); } if (FOLLY_UNLIKELY(debugEnabled())) { recordGrowDbg(allocation.data(), allocation.size()); @@ -775,7 +841,7 @@ std::shared_ptr MemoryPoolImpl::genChild( } bool MemoryPoolImpl::maybeReserve(uint64_t increment) { - CHECK_AND_INC_MEM_OP_STATS(Reserves); + CHECK_AND_INC_MEM_OP_STATS(this, Reserves); TestValue::adjust( "facebook::velox::common::memory::MemoryPoolImpl::maybeReserve", this); // TODO: make this a configurable memory pool option. @@ -881,9 +947,16 @@ void MemoryPoolImpl::growCapacity(MemoryPool* requestor, uint64_t size) { VELOX_CHECK(requestor->isLeaf()); ++numCapacityGrowths_; - { + try { MemoryPoolArbitrationSection arbitrationSection(requestor); arbitrator_->growCapacity(this, size); + } catch (const VeloxRuntimeError& veloxError) { + if (FOLLY_UNLIKELY( + debugEnabled() && + veloxError.errorCode() == error_code::kMemCapExceeded)) { + std::rethrow_exception(wrapExceptionDbg(veloxError)); + } + throw; } // The memory pool might have been aborted during the time it leaves the // arbitration no matter the arbitration succeed or not. @@ -923,7 +996,7 @@ void MemoryPoolImpl::incrementReservationLocked(uint64_t bytes) { } void MemoryPoolImpl::release() { - CHECK_AND_INC_MEM_OP_STATS(Releases); + CHECK_AND_INC_MEM_OP_STATS(this, Releases); release(0, true); } @@ -980,6 +1053,21 @@ void MemoryPoolImpl::decrementReservation(uint64_t size) noexcept { sanityCheckLocked(); } +std::string MemoryPoolImpl::toString(bool detail) const { + std::string result; + { + std::lock_guard l(mutex_); + result = toStringLocked(); + } + if (detail) { + result += "\n" + treeMemoryUsage(); + } + if (FOLLY_UNLIKELY(debugEnabled())) { + result += "\n" + dumpRecordsDbg(); + } + return result; +} + std::string MemoryPoolImpl::treeMemoryUsage(bool skipEmptyPool) const { if (parent_ != nullptr) { return parent_->treeMemoryUsage(skipEmptyPool); @@ -1208,10 +1296,7 @@ void MemoryPoolImpl::recordAllocDbg(const void* addr, uint64_t size) { debugWarnThresholdExceeded_) { return; } - const auto usedBytes = [this]() -> int64_t { - std::lock_guard l(mutex_); - return reservedBytes(); - }(); + const auto usedBytes = reservedBytes(); if (usedBytes >= debugOptions_->debugPoolWarnThresholdBytes) { debugWarnThresholdExceeded_ = true; VELOX_MEM_LOG(WARNING) << fmt::format( @@ -1225,7 +1310,7 @@ void MemoryPoolImpl::recordAllocDbg(const void* addr, uint64_t size) { succinctBytes(size), succinctBytes(usedBytes), it->second.callStack.toString(), - dumpRecordsDbg()); + dumpRecordsDbgLocked()); } } @@ -1259,16 +1344,17 @@ void MemoryPoolImpl::recordFreeDbg(const void* addr, uint64_t size) { const auto allocRecord = allocResult->second; if (allocRecord.size != size) { const auto freeStackTrace = process::StackTrace().toString(); - VELOX_FAIL(fmt::format( - "[MemoryPool] Trying to free {} bytes on an allocation of {} bytes.\n" - "======== Allocation Stack ========\n" - "{}\n" - "============ Free Stack ==========\n" - "{}\n", - size, - allocRecord.size, - allocRecord.callStack.toString(), - freeStackTrace)); + VELOX_FAIL( + fmt::format( + "[MemoryPool] Trying to free {} bytes on an allocation of {} bytes.\n" + "======== Allocation Stack ========\n" + "{}\n" + "============ Free Stack ==========\n" + "{}\n", + size, + allocRecord.size, + allocRecord.callStack.toString(), + freeStackTrace)); } debugAllocRecords_.erase(addrUint64); } @@ -1308,16 +1394,79 @@ void MemoryPoolImpl::leakCheckDbg() { if (debugAllocRecords_.empty()) { return; } - VELOX_FAIL(fmt::format( - "[MemoryPool] Leak check failed for '{}' pool - {}", - name_, - dumpRecordsDbg())); + VELOX_FAIL( + fmt::format( + "[MemoryPool] Leak check failed for '{}' pool - {}", + name_, + dumpRecordsDbg())); +} + +void MemoryPoolImpl::treeAllocationRecordsDbg( + std::vector& poolDumps) const { + VELOX_CHECK(debugEnabled()); + { + std::lock_guard debugAllocLock(debugAllocMutex_); + if (!debugAllocRecords_.empty()) { + MemoryPoolDump dump{ + .dumpedRecords = fmt::format( + "Memory pool '{}' - {}", name(), dumpRecordsDbgLocked()), + .bytes = reservedBytes(), + }; + poolDumps.emplace_back(std::move(dump)); + } + } + if (isLeaf()) { + return; + } + visitChildren([&poolDumps](MemoryPool* pool) { + toImpl(pool)->treeAllocationRecordsDbg(poolDumps); + return true; + }); } -std::string MemoryPoolImpl::dumpRecordsDbg() { +std::exception_ptr MemoryPoolImpl::wrapExceptionDbg( + const VeloxRuntimeError& veloxError) const { + VELOX_CHECK(debugEnabled()); + VELOX_CHECK(isRoot()); + std::vector poolAllocationsSorted; + treeAllocationRecordsDbg(poolAllocationsSorted); + std::stringstream oss; + if (poolAllocationsSorted.empty()) { + oss << "No allocation records found."; + } else { + std::sort( + poolAllocationsSorted.begin(), + poolAllocationsSorted.end(), + [](const auto& a, const auto& b) { return a.bytes > b.bytes; }); + for (const auto& record : poolAllocationsSorted) { + oss << record.dumpedRecords << "\n\n"; + } + } + const auto wrappedMessage = fmt::format( + "{}\n\n" + "======= Current Allocations ======\n" + "{}", + veloxError.message(), + oss.str()); + return std::make_exception_ptr(VeloxRuntimeError( + veloxError.file(), + veloxError.line(), + veloxError.function(), + veloxError.failingExpression(), + wrappedMessage, + veloxError.errorSource(), + veloxError.errorCode(), + veloxError.isRetriable(), + veloxError.exceptionName())); +} + +std::string MemoryPoolImpl::dumpRecordsDbgLocked() const { VELOX_CHECK(debugEnabled()); std::stringstream oss; - oss << fmt::format("Found {} allocations:\n", debugAllocRecords_.size()); + oss << fmt::format( + "Found {} allocations with {} total size:\n", + debugAllocRecords_.size(), + succinctBytes(reservedBytes())); struct AllocationStats { uint64_t size{0}; uint64_t numAllocations{0}; diff --git a/velox/common/memory/MemoryPool.h b/velox/common/memory/MemoryPool.h index 601f3ed272b..a24edfbb4b7 100644 --- a/velox/common/memory/MemoryPool.h +++ b/velox/common/memory/MemoryPool.h @@ -29,8 +29,8 @@ #include "velox/common/memory/MemoryArbitrator.h" DECLARE_bool(velox_memory_leak_check_enabled); -DECLARE_bool(velox_memory_pool_debug_enabled); DECLARE_bool(velox_memory_pool_capacity_transfer_across_tasks); +DECLARE_bool(velox_enable_inplace_realloc); namespace facebook::velox::exec { class ParallelMemoryReclaimer; @@ -42,6 +42,9 @@ class MemoryManager; constexpr int64_t kMaxMemory = std::numeric_limits::max(); +template +class StlAllocator; + /// This class provides the memory allocation interfaces for a query execution. /// Each query execution entity creates a dedicated memory pool object. The /// memory pool objects from a query are organized as a tree with four levels @@ -91,6 +94,9 @@ constexpr int64_t kMaxMemory = std::numeric_limits::max(); /// also provides memory usage accounting. class MemoryPool : public std::enable_shared_from_this { public: + template + using TStlAllocator = StlAllocator; + /// Defines the kinds of a memory pool. enum class Kind { /// The leaf memory pool is used for memory allocation. User can allocate @@ -246,6 +252,13 @@ class MemoryPool : public std::enable_shared_from_this { /// Frees an allocated buffer. virtual void free(void* p, int64_t size) = 0; + /// Transfer the ownership of memory at 'buffer' for 'size' bytes to the + /// memory pool 'dest'. Returns true if the transfer succeeds. + virtual bool + transferTo(MemoryPool* /*dest*/, void* /*buffer*/, uint64_t /*size*/) { + return false; + } + /// Allocates one or more runs that add up to at least 'numPages', with the /// smallest run being at least 'minSizeClass' pages. 'minSizeClass' must be /// <= the size of the largest size class. The new memory is returned in 'out' @@ -610,6 +623,8 @@ class MemoryPoolImpl : public MemoryPool { void free(void* p, int64_t size) override; + bool transferTo(MemoryPool* dest, void* buffer, uint64_t size) override; + void allocateNonContiguous( MachinePageCount numPages, Allocation& out, @@ -673,17 +688,7 @@ class MemoryPoolImpl : public MemoryPool { void setDestructionCallback(const DestructionCallback& callback); - std::string toString(bool detail = false) const override { - std::string result; - { - std::lock_guard l(mutex_); - result = toStringLocked(); - } - if (detail) { - result += "\n" + treeMemoryUsage(); - } - return result; - } + std::string toString(bool detail = false) const override; /// Detailed debug pool state printout by traversing the pool structure from /// the root memory pool. @@ -779,8 +784,7 @@ class MemoryPoolImpl : public MemoryPool { } FOLLY_ALWAYS_INLINE int64_t sizeAlign(int64_t size) const { - const auto remainder = size & (alignment_ - 1); - return (remainder == 0) ? size : (size + alignment_ - remainder); + return (size + alignment_ - 1) & ~(alignment_ - 1); } // Returns a rounded up delta based on adding 'delta' to 'size'. Adding the @@ -1003,9 +1007,32 @@ class MemoryPoolImpl : public MemoryPool { // pool is enabled. void leakCheckDbg(); + // Holds formatted string of dumped allocation records for a leaf memory pool, + // along with the total pool size in bytes. + struct MemoryPoolDump { + std::string dumpedRecords; + int64_t bytes; + }; + + // Recursively collects 'MemoryPoolDump' records for this memory pool and + // all its descendants in the tree. Called during memory capacity-exceeded + // exceptions to extend the error message with additional debug information. + void treeAllocationRecordsDbg(std::vector& poolDumps) const; + + // Wraps the message of a memory capacity exceeded exception with debug + // allocation records from all memory pools in the subtree. This function is + // called from the root memory pool. + std::exception_ptr wrapExceptionDbg( + const VeloxRuntimeError& veloxError) const; + // Dump the recorded call sites of the memory allocations in // 'debugAllocRecords_' to the string. - std::string dumpRecordsDbg(); + std::string dumpRecordsDbgLocked() const; + + std::string dumpRecordsDbg() const { + std::lock_guard l(debugAllocMutex_); + return dumpRecordsDbgLocked(); + } void handleAllocationFailure(const std::string& failureMessage); @@ -1070,7 +1097,7 @@ class MemoryPoolImpl : public MemoryPool { std::atomic_uint64_t numCapacityGrowths_{0}; // Mutex for 'debugAllocRecords_'. - std::mutex debugAllocMutex_; + mutable std::mutex debugAllocMutex_; // Map from address to 'AllocationRecord'. std::unordered_map debugAllocRecords_; @@ -1084,25 +1111,29 @@ template class StlAllocator { public: typedef T value_type; - MemoryPool& pool; + MemoryPool* pool; - /* implicit */ StlAllocator(MemoryPool& pool) : pool{pool} {} + /* implicit */ StlAllocator(MemoryPool& pool) : pool{&pool} {} + + explicit StlAllocator(MemoryPool* pool) : pool{pool} { + VELOX_CHECK_NOT_NULL(pool); + } template /* implicit */ StlAllocator(const StlAllocator& a) : pool{a.pool} {} T* allocate(size_t n) { - return static_cast(pool.allocate(checkedMultiply(n, sizeof(T)))); + return static_cast(pool->allocate(checkedMultiply(n, sizeof(T)))); } void deallocate(T* p, size_t n) { - pool.free(p, checkedMultiply(n, sizeof(T))); + pool->free(p, checkedMultiply(n, sizeof(T))); } template bool operator==(const StlAllocator& rhs) const { if constexpr (std::is_same_v) { - return &this->pool == &rhs.pool; + return this->pool == rhs.pool; } return false; } diff --git a/velox/common/memory/MmapAllocator.cpp b/velox/common/memory/MmapAllocator.cpp index 178ed8b7a50..23676255267 100644 --- a/velox/common/memory/MmapAllocator.cpp +++ b/velox/common/memory/MmapAllocator.cpp @@ -33,9 +33,11 @@ MmapAllocator::MmapAllocator(const Options& options) maxMallocBytes_ == 0 ? 0 : options.capacity * options.smallAllocationReservePct / 100), - capacity_(bits::roundUp( - AllocationTraits::numPages(options.capacity - mallocReservedBytes_), - 64 * sizeClassSizes_.back())) { + capacity_( + bits::roundUp( + AllocationTraits::numPages( + options.capacity - mallocReservedBytes_), + 64 * sizeClassSizes_.back())) { for (const auto& size : sizeClassSizes_) { sizeClasses_.push_back(std::make_unique(capacity_ / size, size)); } @@ -675,8 +677,7 @@ bool MmapAllocator::SizeClass::allocateLocked( namespace { bool isAllZero(xsimd::batch bits) { - return simd::allSetBitMask() == - simd::toBitMask(bits == xsimd::broadcast(0)); + return simd::all(bits == xsimd::broadcast(0)); } } // namespace diff --git a/velox/common/memory/MmapAllocator.h b/velox/common/memory/MmapAllocator.h index f771db0071d..2568ca45863 100644 --- a/velox/common/memory/MmapAllocator.h +++ b/velox/common/memory/MmapAllocator.h @@ -51,37 +51,6 @@ using ClassPageCount = int32_t; /// malloc. class MmapAllocator : public MemoryAllocator { public: - struct Options { - /// Capacity in bytes, default unlimited. - uint64_t capacity{kMaxMemory}; - - int32_t largestSizeClass{256}; - - /// If set true, allocations larger than largest size class size will be - /// delegated to ManagedMmapArena. Otherwise a system mmap call will be - /// issued for each such allocation. - bool useMmapArena = false; - - /// Used to determine MmapArena capacity. The ratio represents system memory - /// capacity to single MmapArena capacity ratio. - int32_t mmapArenaCapacityRatio = 10; - - /// If not zero, reserve 'smallAllocationReservePct'% of space from - /// 'capacity' for ad hoc small allocations. And those allocations are - /// delegated to std::malloc. - /// - /// NOTE: if 'maxMallocBytes' is 0, this value will be disregarded. - uint32_t smallAllocationReservePct = 0; - - /// The allocation threshold less than which an allocation is delegated to - /// std::malloc(). - /// - /// NOTE: if it is zero, then we don't delegate any allocation std::malloc, - /// and 'smallAllocationReservePct' will be automatically set to 0 - /// disregarding any passed in value. - int32_t maxMallocBytes = 3072; - }; - explicit MmapAllocator(const Options& options); ~MmapAllocator(); diff --git a/velox/common/memory/RawVector.cpp b/velox/common/memory/RawVector.cpp index 2463d8d3da3..c1ec915d094 100644 --- a/velox/common/memory/RawVector.cpp +++ b/velox/common/memory/RawVector.cpp @@ -24,7 +24,7 @@ namespace { raw_vector iotaData; bool initializeIota() { - iotaData.resize(10000); + iotaData.resize(1'000'000); iotaData.resize(iotaData.capacity()); std::iota(iotaData.begin(), iotaData.end(), 0); return true; diff --git a/velox/common/memory/RawVector.h b/velox/common/memory/RawVector.h index a55c57c17ee..686edbf56f9 100644 --- a/velox/common/memory/RawVector.h +++ b/velox/common/memory/RawVector.h @@ -135,6 +135,22 @@ class raw_vector { size_ = 0; } + /// Releases unused capacity. If empty, frees all memory. + void shrink_to_fit() { + if (size_ == 0) { + free(); + capacity_ = 0; + return; + } + if (calculateCapacity(size_) < capacity_) { + auto* newData = allocateData(size_); + memcpy(newData, data_, size_ * sizeof(T)); + free(); + data_ = newData; + capacity_ = calculateCapacity(size_); + } + } + void resize(int64_t size) { if (LIKELY(size <= capacity_)) { size_ = size; @@ -212,9 +228,12 @@ class raw_vector { // Clear the word below the pointer so that we do not get read of // uninitialized when reading a partial word that extends below // the pointer. + // Suppress GCC14 warning. "error: writing 8 bytes into a region of size 0" + VELOX_SUPPRESS_STRINGOP_OVERFLOW_WARNING *reinterpret_cast( reinterpret_cast(getDataFromBuffer(buffer)) - sizeof(int64_t)) = 0; + VELOX_UNSUPPRESS_STRINGOP_OVERFLOW_WARNING return getDataFromBuffer(buffer); } diff --git a/velox/common/memory/Scratch.h b/velox/common/memory/Scratch.h index dcde26e90d9..bed226a0f94 100644 --- a/velox/common/memory/Scratch.h +++ b/velox/common/memory/Scratch.h @@ -79,14 +79,17 @@ class Scratch { // stringop-overflow warning when 'newCapacity' is 0. folly::assume(capacity_ >= 0); if (newCapacity > capacity_) { - Item* newItems = - reinterpret_cast(::malloc(sizeof(Item) * newCapacity)); + auto* newItems = + reinterpret_cast(::malloc(sizeof(Item) * newCapacity)); if (fill_ > 0) { ::memcpy(newItems, items_, fill_ * sizeof(Item)); } - ::memset(newItems + fill_, 0, (newCapacity - fill_) * sizeof(Item)); + ::memset( + newItems + fill_ * sizeof(Item), + 0, + (newCapacity - fill_) * sizeof(Item)); ::free(items_); - items_ = newItems; + items_ = reinterpret_cast(newItems); capacity_ = newCapacity; } fill_ = std::min(fill_, newCapacity); diff --git a/velox/common/memory/SharedArbitrator.cpp b/velox/common/memory/SharedArbitrator.cpp index e726d76222e..bd98f4e9f4d 100644 --- a/velox/common/memory/SharedArbitrator.cpp +++ b/velox/common/memory/SharedArbitrator.cpp @@ -15,6 +15,7 @@ */ #include "velox/common/memory/SharedArbitrator.h" +#include #include #include #include @@ -58,25 +59,28 @@ namespace { } #define MEM_POOL_CAP_EXCEEDED(errorMessage, requestPool) \ - VELOX_MEM_POOL_CAP_EXCEEDED(fmt::format( \ - "Exceeded memory pool capacity. {}\n{}\n\n{}", \ - errorMessage, \ - this->toString(), \ - requestPool->toString(true))); + VELOX_MEM_POOL_CAP_EXCEEDED( \ + fmt::format( \ + "Exceeded memory pool capacity. {}\n{}\n\n{}", \ + errorMessage, \ + this->toString(), \ + requestPool->toString(true))); #define LOCAL_MEM_ARBITRATION_FAILED(errorMessage, requestPool) \ - VELOX_MEM_ARBITRATION_FAILED(fmt::format( \ - "Local arbitration failure. {}\n{}\n\n{}", \ - errorMessage, \ - this->toString(), \ - requestPool->toString(true))); + VELOX_MEM_ARBITRATION_FAILED( \ + fmt::format( \ + "Local arbitration failure. {}\n{}\n\n{}", \ + errorMessage, \ + this->toString(), \ + requestPool->toString(true))); #define GLOBAL_MEM_ARBITRATION_FAILED(errorMessage, requestPool) \ - VELOX_MEM_ARBITRATION_FAILED(fmt::format( \ - "Global arbitration failure. {}\n{}\n\n{}", \ - errorMessage, \ - this->toString(), \ - requestPool->toString(true))); + VELOX_MEM_ARBITRATION_FAILED( \ + fmt::format( \ + "Global arbitration failure. {}\n{}\n\n{}", \ + errorMessage, \ + this->toString(), \ + requestPool->toString(true))); template T getConfig( @@ -126,10 +130,11 @@ uint64_t SharedArbitrator::ExtraConfig::memoryPoolReservedCapacity( uint64_t SharedArbitrator::ExtraConfig::maxMemoryArbitrationTimeNs( const std::unordered_map& configs) { return std::chrono::duration_cast( - config::toDuration(getConfig( - configs, - kMaxMemoryArbitrationTime, - std::string(kDefaultMaxMemoryArbitrationTime)))) + config::toDuration( + getConfig( + configs, + kMaxMemoryArbitrationTime, + std::string(kDefaultMaxMemoryArbitrationTime)))) .count(); } @@ -292,8 +297,7 @@ SharedArbitrator::SharedArbitrator(const Config& config) "memoryReclaimThreadsHwMultiplier_ needs to be positive"); const uint64_t numReclaimThreads = std::max( - 1, - std::thread::hardware_concurrency() * memoryReclaimThreadsHwMultiplier_); + 1, folly::available_concurrency() * memoryReclaimThreadsHwMultiplier_); memoryReclaimExecutor_ = std::make_unique( numReclaimThreads, std::make_shared("MemoryReclaim")); @@ -539,14 +543,14 @@ void SharedArbitrator::addPool(const std::shared_ptr& pool) { } void SharedArbitrator::removePool(MemoryPool* pool) { - VELOX_CHECK_EQ(pool->reservedBytes(), 0); + VELOX_CHECK_EQ(pool->reservedBytes(), 0, "{}", pool->name()); const uint64_t freedBytes = shrinkPool(pool, 0); - VELOX_CHECK_EQ(pool->capacity(), 0); + VELOX_CHECK_EQ(pool->capacity(), 0, "{}", pool->name()); freeCapacity(freedBytes); std::unique_lock guard{participantLock_}; const auto ret = participants_.erase(pool->name()); - VELOX_CHECK_EQ(ret, 1); + VELOX_CHECK_EQ(ret, 1, "{}", pool->name()); } std::vector SharedArbitrator::getCandidates( @@ -893,13 +897,16 @@ void SharedArbitrator::growCapacity(ArbitrationOperation& op) { RETURN_IF_TRUE(maybeGrowFromSelf(op)); if (!ensureCapacity(op)) { + const auto maxCapacity = op.participant()->maxCapacity(); MEM_POOL_CAP_EXCEEDED( fmt::format( - "Can't grow {} capacity with {}. This will exceed its max capacity " + "Can't grow {} capacity with {}. This will exceed its {} " "{}, current capacity {}.", op.participant()->name(), succinctBytes(op.requestBytes()), - succinctBytes(op.participant()->maxCapacity()), + capacity_ < maxCapacity ? "arbitrator capacity" + : "memory pool capacity", + succinctBytes(std::min(capacity_, maxCapacity)), succinctBytes(op.participant()->capacity())), op.participant()->pool()); } @@ -1139,10 +1146,11 @@ void SharedArbitrator::checkIfAborted(ArbitrationOperation& op) { void SharedArbitrator::checkIfTimeout(ArbitrationOperation& op) { if (FOLLY_UNLIKELY(op.hasTimeout())) { - VELOX_MEM_ARBITRATION_TIMEOUT(fmt::format( - "Memory arbitration timed out on memory pool: {} after running {}", - op.participant()->name(), - succinctNanos(op.executionTimeNs()))); + VELOX_MEM_ARBITRATION_TIMEOUT( + fmt::format( + "Memory arbitration timed out on memory pool: {} after running {}", + op.participant()->name(), + succinctNanos(op.executionTimeNs()))); } } @@ -1354,14 +1362,15 @@ uint64_t SharedArbitrator::reclaimUsedMemoryByAbort(bool force) { // after abort operation. const auto currentCapacity = victim.participant->pool()->capacity(); try { - VELOX_MEM_POOL_ABORTED(fmt::format( - "Memory pool aborted to reclaim used memory, current capacity {}, " - "requesting capacity from global arbitration {} memory pool " - "stats:\n{}\n{}", - succinctBytes(currentCapacity), - succinctBytes(victim.participant->globalArbitrationGrowCapacity()), - victim.participant->pool()->toString(), - victim.participant->pool()->treeMemoryUsage())); + VELOX_MEM_POOL_ABORTED( + fmt::format( + "Memory pool aborted to reclaim used memory, current capacity {}, " + "requesting capacity from global arbitration {} memory pool " + "stats:\n{}\n{}", + succinctBytes(currentCapacity), + succinctBytes(victim.participant->globalArbitrationGrowCapacity()), + victim.participant->pool()->toString(), + victim.participant->pool()->treeMemoryUsage())); } catch (VeloxRuntimeError&) { abort(victim.participant, std::current_exception()); return currentCapacity; @@ -1395,7 +1404,6 @@ uint64_t SharedArbitrator::reclaim( if (participant->aborted()) { removeGlobalArbitrationWaiter(participant->id()); } - freeCapacity(reclaimedBytes); updateMemoryReclaimStats( reclaimedBytes, reclaimTimeNs, localArbitration, stats); @@ -1407,6 +1415,8 @@ uint64_t SharedArbitrator::reclaim( << " stats " << succinctBytes(stats.reclaimedBytes) << " numNonReclaimableAttempts " << stats.numNonReclaimableAttempts; + + freeCapacity(reclaimedBytes); if (reclaimedBytes == 0) { FB_LOG_EVERY_MS(WARNING, 1'000) << fmt::format( "Nothing reclaimed from memory pool {} with reclaim target {}, memory pool stats:\n{}\n{}", diff --git a/velox/common/memory/SharedArbitrator.h b/velox/common/memory/SharedArbitrator.h index 16efd247110..e9ab03e6b5b 100644 --- a/velox/common/memory/SharedArbitrator.h +++ b/velox/common/memory/SharedArbitrator.h @@ -284,17 +284,17 @@ class SharedArbitrator : public memory::MemoryArbitrator { /// Operator level runtime stats reported for an arbitration operation /// execution. - static inline const std::string kMemoryArbitrationWallNanos{ + static constexpr std::string_view kMemoryArbitrationWallNanos{ "memoryArbitrationWallNanos"}; - static inline const std::string kLocalArbitrationCount{ + static constexpr std::string_view kLocalArbitrationCount{ "localArbitrationCount"}; - static inline const std::string kLocalArbitrationWaitWallNanos{ + static constexpr std::string_view kLocalArbitrationWaitWallNanos{ "localArbitrationWaitWallNanos"}; - static inline const std::string kLocalArbitrationExecutionWallNanos{ + static constexpr std::string_view kLocalArbitrationExecutionWallNanos{ "localArbitrationExecutionWallNanos"}; - static inline const std::string kGlobalArbitrationWaitCount{ + static constexpr std::string_view kGlobalArbitrationWaitCount{ "globalArbitrationWaitCount"}; - static inline const std::string kGlobalArbitrationWaitWallNanos{ + static constexpr std::string_view kGlobalArbitrationWaitWallNanos{ "globalArbitrationWaitWallNanos"}; private: diff --git a/velox/common/memory/StreamArena.cpp b/velox/common/memory/StreamArena.cpp index 15b74ea2a96..a9b50097686 100644 --- a/velox/common/memory/StreamArena.cpp +++ b/velox/common/memory/StreamArena.cpp @@ -43,11 +43,12 @@ void StreamArena::newRange( const int32_t numRuns = allocation_.numRuns(); if (currentRun_ >= numRuns) { if (numRuns > 0) { + // No need to push an empty Allocation into 'allocations_'. allocations_.push_back( std::make_unique(std::move(allocation_))); } pool_->allocateNonContiguous( - std::max(allocationQuantum_, numPages), allocation_); + std::max(kAllocationQuantum, numPages), allocation_); currentRun_ = 0; currentOffset_ = 0; size_ += allocation_.byteSize(); @@ -61,7 +62,7 @@ void StreamArena::newRange( VELOX_DCHECK_LE(currentOffset_, run.numBytes()); if (currentOffset_ == run.numBytes()) { ++currentRun_; - ++currentOffset_ = 0; + currentOffset_ = 0; } } diff --git a/velox/common/memory/StreamArena.h b/velox/common/memory/StreamArena.h index b00b9dff950..fcf09e855c9 100644 --- a/velox/common/memory/StreamArena.h +++ b/velox/common/memory/StreamArena.h @@ -14,6 +14,10 @@ * limitations under the License. */ #pragma once + +#include +#include + #include "velox/common/memory/Memory.h" namespace facebook::velox { @@ -30,20 +34,18 @@ class StreamArena { virtual ~StreamArena() = default; - /// Sets range to the request 'bytes' of writable memory owned by - /// 'this'. We allocate non-contiguous memory to store range bytes - /// if requested 'bytes' is equal or less than the largest class - /// page size. Otherwise, we allocate from contiguous - /// memory. 'range' is set to point to the allocated memory. If - /// 'lastRange' is non-nullptr, it is the last range of the stream - /// to which we are adding the new range. 'lastRange' is nullptr if - /// adding the first range to a stream. The memory is stays owned by - /// 'this' in all cases. Used by HashStringAllocator when extending - /// a multipart entry. The previously last part has its last 8 bytes - /// moved to the next part and gets a pointer to the next part as - /// its last 8 bytes. When extending, we need to update the entry so - /// that the next pointer is not seen when reading the content and - /// is also not counted in the payload size of the multipart entry. + /// Sets range to the request 'bytes' of writable memory owned by 'this'. We + /// allocate non-contiguous memory to store range bytes if requested 'bytes' + /// is equal or less than the largest class page size. Otherwise, we allocate + /// from contiguous memory. 'range' is set to point to the allocated memory. + /// If 'lastRange' is non-nullptr, it is the last range of the stream to which + /// we are adding the new range. 'lastRange' is nullptr if adding the first + /// range to a stream. The memory is stays owned by 'this' in all cases. Used + /// by HashStringAllocator when extending a multipart entry. The previously + /// last part has its last 8 bytes moved to the next part and gets a pointer + /// to the next part as its last 8 bytes. When extending, we need to update + /// the entry so that the next pointer is not seen when reading the content + /// and is also not counted in the payload size of the multipart entry. /// /// NOTE: The method does not guarantee returned 'range' has size of 'bytes', /// it is caller's responsibility to check. @@ -63,13 +65,13 @@ class StreamArena { virtual void clear(); memory::MachinePageCount testingAllocationQuantum() const { - return allocationQuantum_; + return kAllocationQuantum; } private: - memory::MemoryPool* const pool_; + static constexpr memory::MachinePageCount kAllocationQuantum{2}; - const memory::MachinePageCount allocationQuantum_{2}; + memory::MemoryPool* const pool_; // All non-contiguous allocations. std::vector> allocations_; diff --git a/velox/common/memory/tests/AllocationTest.cpp b/velox/common/memory/tests/AllocationTest.cpp index ea6d3e58115..351cb5dfcec 100644 --- a/velox/common/memory/tests/AllocationTest.cpp +++ b/velox/common/memory/tests/AllocationTest.cpp @@ -14,13 +14,11 @@ * limitations under the License. */ +#include "velox/common/memory/Allocation.h" + #include #include "velox/common/base/tests/GTestUtils.h" -#include "velox/common/memory/Memory.h" - -using namespace ::testing; -using namespace facebook::velox::memory; namespace facebook::velox::memory { @@ -105,7 +103,6 @@ TEST_F(AllocationTest, maxPageRunLimit) { "The number of pages to append 131070 exceeds the PageRun limit 65535"); ASSERT_EQ(allocation.numPages(), Allocation::PageRun::kMaxPagesInRun); ASSERT_EQ(allocation.numRuns(), 1); - LOG(ERROR) << "here"; allocation.clear(); } diff --git a/velox/common/memory/tests/ArbitrationParticipantTest.cpp b/velox/common/memory/tests/ArbitrationParticipantTest.cpp index 148fdf773f3..fab25762ca9 100644 --- a/velox/common/memory/tests/ArbitrationParticipantTest.cpp +++ b/velox/common/memory/tests/ArbitrationParticipantTest.cpp @@ -21,8 +21,8 @@ #include #include -#include "folly/experimental/EventCount.h" #include "folly/futures/Barrier.h" +#include "folly/synchronization/EventCount.h" #include "gmock/gmock-matchers.h" #include "velox/common/base/SuccinctPrinter.h" @@ -33,10 +33,10 @@ #include "velox/common/memory/Memory.h" #include "velox/common/memory/MemoryArbitrator.h" #include "velox/common/memory/SharedArbitrator.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/common/testutil/TestValue.h" #include "velox/exec/OperatorUtils.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" DECLARE_bool(velox_memory_leak_check_enabled); DECLARE_bool(velox_suppress_memory_capacity_exceeding_error_message); @@ -233,11 +233,12 @@ class MockTask : public std::enable_shared_from_this { ReclaimInjectionCallback reclaimInjectCb, ArbitrationInjectionCallback arbitrationInjectCb) { root_->setReclaimer(RootMemoryReclaimer::create(shared_from_this())); - pool_->setReclaimer(std::make_unique( - shared_from_this(), - reclaimable, - std::move(reclaimInjectCb), - std::move(arbitrationInjectCb))); + pool_->setReclaimer( + std::make_unique( + shared_from_this(), + reclaimable, + std::move(reclaimInjectCb), + std::move(arbitrationInjectCb))); } const std::shared_ptr& pool() const { diff --git a/velox/common/memory/tests/ByteStreamTest.cpp b/velox/common/memory/tests/ByteStreamTest.cpp index 7ef7eff64b3..1ff755bc18b 100644 --- a/velox/common/memory/tests/ByteStreamTest.cpp +++ b/velox/common/memory/tests/ByteStreamTest.cpp @@ -19,11 +19,12 @@ #include "velox/common/file/FileInputStream.h" #include "velox/common/file/FileSystems.h" #include "velox/common/memory/MmapAllocator.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include using namespace facebook::velox; +using namespace facebook::velox::common::testutil; using namespace facebook::velox::memory; class ByteStreamTest : public testing::Test { @@ -248,10 +249,11 @@ TEST_F(ByteStreamTest, newRangeAllocation) { byteStream.startWrite(0); for (int i = 0; i < testData.newRangeSizes.size(); ++i) { const auto newRangeSize = testData.newRangeSizes[i]; - SCOPED_TRACE(fmt::format( - "iteration {} allocation size {}", - i, - succinctBytes(testData.newRangeSizes[i]))); + SCOPED_TRACE( + fmt::format( + "iteration {} allocation size {}", + i, + succinctBytes(testData.newRangeSizes[i]))); std::string value(newRangeSize, 'a'); byteStream.appendStringView(value); ASSERT_EQ(arena->size(), testData.expectedArenaAllocationSizes[i]); @@ -456,7 +458,7 @@ class InputByteStreamTest : public ByteStreamTest, void SetUp() override { ByteStreamTest::SetUp(); - tempDirPath_ = exec::test::TempDirectoryPath::create(); + tempDirPath_ = TempDirectoryPath::create(); fs_ = filesystems::getFileSystem(tempDirPath_->getPath(), nullptr); } @@ -470,8 +472,9 @@ class InputByteStreamTest : public ByteStreamTest, fmt::format("{}/{}", tempDirPath_->getPath(), fileId_++); auto writeFile = fs_->openFileForWrite(filePath); for (auto& byteRange : byteRanges) { - writeFile->append(std::string_view( - reinterpret_cast(byteRange.buffer), byteRange.size)); + writeFile->append( + std::string_view( + reinterpret_cast(byteRange.buffer), byteRange.size)); } writeFile->close(); return std::make_unique( @@ -480,7 +483,7 @@ class InputByteStreamTest : public ByteStreamTest, } std::atomic_uint64_t fileId_{0}; - std::shared_ptr tempDirPath_; + std::shared_ptr tempDirPath_; std::shared_ptr fs_; }; diff --git a/velox/common/memory/tests/CMakeLists.txt b/velox/common/memory/tests/CMakeLists.txt index bb07501386c..77ce4489f9d 100644 --- a/velox/common/memory/tests/CMakeLists.txt +++ b/velox/common/memory/tests/CMakeLists.txt @@ -42,7 +42,6 @@ target_link_libraries( velox_exec velox_exec_test_lib velox_memory - velox_temp_path velox_test_util velox_vector_fuzzer Folly::folly @@ -69,3 +68,5 @@ if(VELOX_ENABLE_BENCHMARKS) target_link_libraries(velox_concurrent_allocation_benchmark PRIVATE velox_memory velox_time) endif() + +velox_add_library(velox_memory_test_util INTERFACE HEADERS SharedArbitratorTestUtil.h) diff --git a/velox/common/memory/tests/FragmentationBenchmark.cpp b/velox/common/memory/tests/FragmentationBenchmark.cpp index eb9124f2235..9d5910114c0 100644 --- a/velox/common/memory/tests/FragmentationBenchmark.cpp +++ b/velox/common/memory/tests/FragmentationBenchmark.cpp @@ -150,7 +150,7 @@ class FragmentationTest { } void initMemory(size_t sizeCap) { - MmapAllocator::Options options; + MemoryAllocator::Options options; options.capacity = sizeCap + (64 << 20); memory_ = std::make_shared(options); } diff --git a/velox/common/memory/tests/MemoryAllocatorTest.cpp b/velox/common/memory/tests/MemoryAllocatorTest.cpp index 7b9f123f96a..1526a59ab89 100644 --- a/velox/common/memory/tests/MemoryAllocatorTest.cpp +++ b/velox/common/memory/tests/MemoryAllocatorTest.cpp @@ -258,8 +258,9 @@ class MemoryAllocatorTest : public testing::TestWithParam { allocations.clear(); } - void clearAllocations(std::vector>>& - allocationsVector) { + void clearAllocations( + std::vector>>& + allocationsVector) { for (auto& allocations : allocationsVector) { for (auto& allocation : allocations) { instance_->freeNonContiguous(*allocation); @@ -431,7 +432,7 @@ TEST_P(MemoryAllocatorTest, mmapAllocatorInit) { return; } { - MmapAllocator::Options options; + MemoryAllocator::Options options; options.capacity = kCapacityBytes; options.smallAllocationReservePct = 39; options.maxMallocBytes = 2999; @@ -447,7 +448,7 @@ TEST_P(MemoryAllocatorTest, mmapAllocatorInit) { EXPECT_EQ(smallAllocationBytes, mmapAllocator->mallocReservedBytes()); } { - MmapAllocator::Options options; + MemoryAllocator::Options options; options.capacity = kCapacityBytes; options.smallAllocationReservePct = 39; options.maxMallocBytes = 0; @@ -461,7 +462,7 @@ TEST_P(MemoryAllocatorTest, mmapAllocatorInit) { EXPECT_EQ(0, mmapAllocator->mallocReservedBytes()); } { - MmapAllocator::Options options; + MemoryAllocator::Options options; options.capacity = 64 * 256 * AllocationTraits::kPageSize - 100; options.smallAllocationReservePct = 10; options.maxMallocBytes = 3072; @@ -831,65 +832,66 @@ TEST_P(MemoryAllocatorTest, nonContiguousFailure) { static_cast(numNewPages), injectedFailure); } - } testSettings[] = {// Cap failure injection. - {0, 100, MemoryAllocator::InjectedFailure::kCap}, - {0, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - {0, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kCap}, - {100, 100, MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun / 2, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kCap}, - {200, 100, MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun / 2 + 100, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun - 1, - MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - // Allocate failure injection. - {0, 100, MemoryAllocator::InjectedFailure::kAllocate}, - {0, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kAllocate}, - {0, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kCap}, - {100, 100, MemoryAllocator::InjectedFailure::kAllocate}, - {Allocation::PageRun::kMaxPagesInRun / 2, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kAllocate}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kAllocate}, - {200, 100, MemoryAllocator::InjectedFailure::kAllocate}, - {Allocation::PageRun::kMaxPagesInRun / 2 + 100, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kAllocate}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun - 1, - MemoryAllocator::InjectedFailure::kAllocate}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kAllocate}, - // Madvise failure injection. - {0, 100, MemoryAllocator::InjectedFailure::kMadvise}, - {0, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMadvise}, - {0, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kMadvise}, - {200, 100, MemoryAllocator::InjectedFailure::kMadvise}}; + } testSettings[] = { + // Cap failure injection. + {0, 100, MemoryAllocator::InjectedFailure::kCap}, + {0, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + {0, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kCap}, + {100, 100, MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun / 2, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kCap}, + {200, 100, MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun / 2 + 100, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun - 1, + MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + // Allocate failure injection. + {0, 100, MemoryAllocator::InjectedFailure::kAllocate}, + {0, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kAllocate}, + {0, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kCap}, + {100, 100, MemoryAllocator::InjectedFailure::kAllocate}, + {Allocation::PageRun::kMaxPagesInRun / 2, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kAllocate}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kAllocate}, + {200, 100, MemoryAllocator::InjectedFailure::kAllocate}, + {Allocation::PageRun::kMaxPagesInRun / 2 + 100, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kAllocate}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun - 1, + MemoryAllocator::InjectedFailure::kAllocate}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kAllocate}, + // Madvise failure injection. + {0, 100, MemoryAllocator::InjectedFailure::kMadvise}, + {0, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMadvise}, + {0, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kMadvise}, + {200, 100, MemoryAllocator::InjectedFailure::kMadvise}}; std::unordered_map expectedErrorMsg = { {MemoryAllocator::InjectedFailure::kAllocate, @@ -1249,6 +1251,109 @@ TEST_P(MemoryAllocatorTest, allocateBytes) { ASSERT_TRUE(instance_->checkConsistency()); } +TEST_P(MemoryAllocatorTest, reallocateBytes) { + // Test grow: allocate, fill, reallocate larger, verify data preserved. + { + const uint64_t kInitialSize = 1024; + const uint64_t kNewSize = 4096; + void* p = instance_->allocateBytes(kInitialSize); + ASSERT_NE(nullptr, p); + ::memset(p, 0xAB, kInitialSize); + + auto usedBefore = instance_->totalUsedBytes(); + void* newP = instance_->reallocateBytes(p, kInitialSize, kNewSize); + ASSERT_NE(nullptr, newP); + // Verify original data is preserved. + for (uint64_t i = 0; i < kInitialSize; ++i) { + ASSERT_EQ(static_cast(0xAB), static_cast(newP)[i]) + << "at byte " << i; + } + // Verify memory accounting: usage should have increased by the delta. + if (!useMmap_) { + ASSERT_EQ( + usedBefore + (kNewSize - kInitialSize), instance_->totalUsedBytes()); + } + instance_->freeBytes(newP, kNewSize); + } + + // Test shrink: allocate, fill, reallocate smaller, verify data preserved. + { + const uint64_t kInitialSize = 4096; + const uint64_t kNewSize = 1024; + void* p = instance_->allocateBytes(kInitialSize); + ASSERT_NE(nullptr, p); + ::memset(p, 0xCD, kInitialSize); + + void* newP = instance_->reallocateBytes(p, kInitialSize, kNewSize); + ASSERT_NE(nullptr, newP); + for (uint64_t i = 0; i < kNewSize; ++i) { + ASSERT_EQ(static_cast(0xCD), static_cast(newP)[i]) + << "at byte " << i; + } + instance_->freeBytes(newP, kNewSize); + } + + // Test with nullptr input: should behave like allocateBytes. + { + const uint64_t kSize = 2048; + void* newP = instance_->reallocateBytes(nullptr, 0, kSize); + ASSERT_NE(nullptr, newP); + ::memset(newP, 0xEF, kSize); + instance_->freeBytes(newP, kSize); + } + + // Test memory accounting after full cycle. + if (!useMmap_) { + ASSERT_EQ(0, instance_->totalUsedBytes()); + } + ASSERT_TRUE(instance_->checkConsistency()); +} + +TEST_P(MemoryAllocatorTest, reallocateBytesWithAlignment) { + // Non-default alignment should fall back to allocate+memcpy+free + // (MallocAllocator's reallocateBytesWithoutRetry returns nullptr for + // non-default alignment since ::realloc() cannot guarantee it). + const uint64_t kInitialSize = 1024; + const uint64_t kNewSize = 4096; + const uint16_t kAlignment = MemoryAllocator::kMaxAlignment; + void* p = instance_->allocateBytes(kInitialSize, kAlignment); + ASSERT_NE(nullptr, p); + ASSERT_EQ(0, reinterpret_cast(p) % kAlignment); + ::memset(p, 0x42, kInitialSize); + + void* newP = + instance_->reallocateBytes(p, kInitialSize, kNewSize, kAlignment); + ASSERT_NE(nullptr, newP); + ASSERT_EQ(0, reinterpret_cast(newP) % kAlignment); + for (uint64_t i = 0; i < kInitialSize; ++i) { + ASSERT_EQ(static_cast(0x42), static_cast(newP)[i]) + << "at byte " << i; + } + instance_->freeBytes(newP, kNewSize); + ASSERT_TRUE(instance_->checkConsistency()); +} + +TEST_P(MemoryAllocatorTest, reallocateBytesCapacityExceeded) { + if (useMmap_) { + // MmapAllocator has different capacity semantics; skip. + return; + } + // Allocate most of capacity, then try to reallocate beyond it. + const uint64_t kInitialSize = 1024; + const uint64_t kOverCapacity = kCapacityBytes + 1; + void* p = instance_->allocateBytes(kInitialSize); + ASSERT_NE(nullptr, p); + + void* newP = instance_->reallocateBytes(p, kInitialSize, kOverCapacity); + ASSERT_EQ(nullptr, newP); + // Original pointer should still be valid after failed reallocation. + // Memory accounting should be unchanged. + ASSERT_EQ(kInitialSize, instance_->totalUsedBytes()); + instance_->freeBytes(p, kInitialSize); + ASSERT_EQ(0, instance_->totalUsedBytes()); + ASSERT_TRUE(instance_->checkConsistency()); +} + TEST_P(MemoryAllocatorTest, allocateBytesWithAlignment) { struct { uint64_t allocateBytes; @@ -1991,4 +2096,263 @@ TEST_F(MmapConfigTest, sizeClasses) { } } +class MallocContiguousTest : public testing::TestWithParam { + protected: + static void SetUpTestCase() { + FLAGS_velox_memory_leak_check_enabled = true; + } + + void SetUp() override { + MemoryAllocator::Options options; + options.capacity = kCapacityBytes; + options.reservationByteLimit = 0; + options.mallocContiguousEnabled = GetParam(); + allocator_ = std::make_shared(options); + } + + std::shared_ptr allocator_; +}; + +TEST_P(MallocContiguousTest, allocateAndFreeContiguous) { + constexpr MachinePageCount kNumPages = 16; + ContiguousAllocation allocation; + ASSERT_TRUE(allocator_->allocateContiguous(kNumPages, nullptr, allocation)); + ASSERT_FALSE(allocation.empty()); + ASSERT_EQ(allocation.numPages(), kNumPages); + ASSERT_NE(allocation.data(), nullptr); + + // Verify we can write and read from the allocated memory. + auto* data = reinterpret_cast(allocation.data()); + for (int64_t i = 0; i < kNumPages; ++i) { + data[i] = i * 42; + } + for (int64_t i = 0; i < kNumPages; ++i) { + ASSERT_EQ(data[i], i * 42); + } + + allocator_->freeContiguous(allocation); + ASSERT_TRUE(allocation.empty()); + ASSERT_EQ(allocator_->numAllocated(), 0); + ASSERT_EQ(allocator_->numMapped(), 0); +} + +TEST_P(MallocContiguousTest, allocateContiguousWithMaxPages) { + constexpr MachinePageCount kNumPages = 8; + constexpr MachinePageCount kMaxPages = 32; + ContiguousAllocation allocation; + ASSERT_TRUE(allocator_->allocateContiguous( + kNumPages, nullptr, allocation, nullptr, kMaxPages)); + ASSERT_EQ(allocation.numPages(), kNumPages); + ASSERT_EQ(allocation.maxSize(), AllocationTraits::pageBytes(kMaxPages)); + + allocator_->freeContiguous(allocation); + ASSERT_TRUE(allocation.empty()); +} + +TEST_P(MallocContiguousTest, growContiguous) { + constexpr MachinePageCount kNumPages = 8; + constexpr MachinePageCount kMaxPages = 32; + ContiguousAllocation allocation; + ASSERT_TRUE(allocator_->allocateContiguous( + kNumPages, nullptr, allocation, nullptr, kMaxPages)); + + // Write data before growing. + auto* data = reinterpret_cast(allocation.data()); + const auto numWords = static_cast( + AllocationTraits::pageBytes(kNumPages) / sizeof(int64_t)); + for (int64_t i = 0; i < numWords; ++i) { + data[i] = i + 1; + } + + // Grow within maxPages. + constexpr MachinePageCount kIncrement = 8; + ASSERT_TRUE(allocator_->growContiguousWithoutRetry(kIncrement, allocation)); + ASSERT_EQ(allocation.numPages(), kNumPages + kIncrement); + + // Verify original data is intact after grow. + data = reinterpret_cast(allocation.data()); + for (int64_t i = 0; i < numWords; ++i) { + ASSERT_EQ(data[i], i + 1); + } + + allocator_->freeContiguous(allocation); + ASSERT_TRUE(allocation.empty()); +} + +TEST_P(MallocContiguousTest, freeContiguousCollateral) { + constexpr MachinePageCount kFirstPages = 16; + constexpr MachinePageCount kSecondPages = 8; + ContiguousAllocation allocation; + ASSERT_TRUE(allocator_->allocateContiguous(kFirstPages, nullptr, allocation)); + ASSERT_EQ(allocation.numPages(), kFirstPages); + ASSERT_EQ(allocator_->numAllocated(), kFirstPages); + + // Allocate again into the same allocation — old allocation is freed as + // contiguous collateral. + ASSERT_TRUE( + allocator_->allocateContiguous(kSecondPages, nullptr, allocation)); + ASSERT_EQ(allocation.numPages(), kSecondPages); + ASSERT_EQ(allocator_->numAllocated(), kSecondPages); + + allocator_->freeContiguous(allocation); + ASSERT_EQ(allocator_->numAllocated(), 0); +} + +TEST_P(MallocContiguousTest, allocateContiguousFailure) { + constexpr MachinePageCount kNumPages = 16; + ContiguousAllocation allocation; + + // Inject allocation failure so dispatchAllocateContiguous returns nullptr. + allocator_->testingSetFailureInjection( + MemoryAllocator::InjectedFailure::kAllocate, true); + + ASSERT_FALSE(allocator_->allocateContiguous(kNumPages, nullptr, allocation)); + ASSERT_TRUE(allocation.empty()); + // Verify all counter increments are properly rolled back. + ASSERT_EQ(allocator_->numAllocated(), 0); + ASSERT_EQ(allocator_->numMapped(), 0); + auto failureMsg = allocator_->getAndClearFailureMessage(); + EXPECT_THAT(failureMsg, testing::HasSubstr("Failed to allocate")); + ASSERT_TRUE(allocator_->checkConsistency()); +} + +TEST_P(MallocContiguousTest, allocContiguous) { + struct { + MachinePageCount nonContiguousPages; + MachinePageCount oldContiguousPages; + MachinePageCount newContiguousPages; + } testSettings[] = { + {100, 100, 200}, + {100, 200, 200}, + {200, 100, 200}, + {200, 100, 400}, + {0, 100, 100}, + {0, 200, 100}, + {0, 100, 200}, + {100, 0, 100}, + {200, 0, 100}, + {100, 0, 200}}; + for (const auto& testData : testSettings) { + SCOPED_TRACE( + fmt::format( + "nonContiguousPages:{} oldContiguousPages:{} newContiguousPages:{}", + testData.nonContiguousPages, + testData.oldContiguousPages, + testData.newContiguousPages)); + SetUp(); + Allocation allocation; + if (testData.nonContiguousPages != 0) { + allocator_->allocateNonContiguous( + testData.nonContiguousPages, allocation); + } + ContiguousAllocation contiguousAllocation; + if (testData.oldContiguousPages != 0) { + allocator_->allocateContiguous( + testData.oldContiguousPages, nullptr, contiguousAllocation); + } + allocator_->allocateContiguous( + testData.newContiguousPages, &allocation, contiguousAllocation); + ASSERT_EQ(allocator_->numAllocated(), testData.newContiguousPages); + ASSERT_EQ(allocator_->numMapped(), testData.newContiguousPages); + + allocator_->freeContiguous(contiguousAllocation); + ASSERT_EQ(allocator_->numMapped(), 0); + ASSERT_EQ(allocator_->numAllocated(), 0); + ASSERT_TRUE(allocator_->checkConsistency()); + } +} + +TEST_P(MallocContiguousTest, allocContiguousFail) { + struct { + MachinePageCount nonContiguousPages; + MachinePageCount oldContiguousPages; + MachinePageCount newContiguousPages; + } testSettings[] = {{200, 100, 400}, {0, 100, 200}, {100, 0, 200}}; + for (const auto& testData : testSettings) { + SCOPED_TRACE( + fmt::format( + "nonContiguousPages:{} oldContiguousPages:{} newContiguousPages:{}", + testData.nonContiguousPages, + testData.oldContiguousPages, + testData.newContiguousPages)); + SetUp(); + Allocation allocation; + if (testData.nonContiguousPages != 0) { + allocator_->allocateNonContiguous( + testData.nonContiguousPages, allocation); + } + ContiguousAllocation contiguousAllocation; + if (testData.oldContiguousPages != 0) { + allocator_->allocateContiguous( + testData.oldContiguousPages, nullptr, contiguousAllocation); + } + ASSERT_EQ( + allocator_->numAllocated(), + testData.oldContiguousPages + testData.nonContiguousPages); + + allocator_->testingSetFailureInjection( + MemoryAllocator::InjectedFailure::kCap, true); + + ASSERT_FALSE(allocator_->allocateContiguous( + testData.newContiguousPages, &allocation, contiguousAllocation)); + auto failureMsg = allocator_->getAndClearFailureMessage(); + EXPECT_THAT( + failureMsg, testing::HasSubstr("Exceeded memory allocator limit")); + ASSERT_EQ(allocator_->numAllocated(), 0); + ASSERT_EQ(allocator_->numMapped(), 0); + ASSERT_TRUE(allocator_->checkConsistency()); + } +} + +TEST_P(MallocContiguousTest, allocContiguousGrow) { + auto largestClass = allocator_->sizeClasses().back(); + constexpr int32_t kInitialLarge = 1024; + constexpr int32_t kMinGrow = 1024; + MachinePageCount numPages = 0; + std::vector small; + auto freeSmall = [&](int32_t toFree) { + int32_t freed = 0; + while (!small.empty() && freed < toFree) { + freed += small.back().numPages(); + allocator_->freeNonContiguous(small.back()); + small.pop_back(); + } + }; + + for (; numPages < kCapacityPages - kInitialLarge; numPages += largestClass) { + Allocation temp; + allocator_->allocateNonContiguous(largestClass, temp); + small.push_back(std::move(temp)); + } + ContiguousAllocation large; + EXPECT_FALSE(allocator_->allocateContiguous( + kInitialLarge * 2, nullptr, large, nullptr, kCapacityPages)); + EXPECT_TRUE(allocator_->allocateContiguous( + kInitialLarge, nullptr, large, nullptr, kCapacityPages)); + EXPECT_FALSE(allocator_->growContiguous(kMinGrow, large)); + auto failureMsg = allocator_->getAndClearFailureMessage(); + EXPECT_THAT( + failureMsg, testing::HasSubstr("Exceeded memory allocator limit")); + freeSmall(kMinGrow); + EXPECT_TRUE(allocator_->growContiguous(kMinGrow, large)); + EXPECT_EQ(allocator_->numAllocated(), kCapacityPages); + freeSmall(4 * kMinGrow); + EXPECT_TRUE(allocator_->growContiguous(4 * kMinGrow, large)); + EXPECT_THROW( + allocator_->growContiguous(100000 * kMinGrow, large), VeloxException); + allocator_->freeContiguous(large); + EXPECT_EQ( + kCapacityPages - kInitialLarge - 5 * kMinGrow, + allocator_->numAllocated()); + freeSmall(kCapacityPages); +} + +INSTANTIATE_TEST_SUITE_P( + MallocContiguousTests, + MallocContiguousTest, + testing::Bool(), + [](const testing::TestParamInfo& info) { + return info.param ? "mallocContiguous" : "mmapContiguous"; + }); + } // namespace facebook::velox::memory diff --git a/velox/common/memory/tests/MemoryArbitratorTest.cpp b/velox/common/memory/tests/MemoryArbitratorTest.cpp index 4cf4c492059..5fd819033fe 100644 --- a/velox/common/memory/tests/MemoryArbitratorTest.cpp +++ b/velox/common/memory/tests/MemoryArbitratorTest.cpp @@ -713,12 +713,16 @@ TEST_F(MemoryReclaimerTest, scopedReclaimedBytesRecorder) { auto childPool = root->addLeafChild("memoryReclaimRecorder", true); ASSERT_EQ(childPool->reservedBytes(), 0); int64_t reclaimedBytes{0}; - { ScopedReclaimedBytesRecorder recorder(childPool.get(), &reclaimedBytes); } + { + ScopedReclaimedBytesRecorder recorder(childPool.get(), &reclaimedBytes); + } ASSERT_EQ(reclaimedBytes, 0); void* buffer = childPool->allocate(1 << 20); ASSERT_EQ(childPool->reservedBytes(), 1 << 20); - { ScopedReclaimedBytesRecorder recorder(childPool.get(), &reclaimedBytes); } + { + ScopedReclaimedBytesRecorder recorder(childPool.get(), &reclaimedBytes); + } ASSERT_EQ(reclaimedBytes, 0); reclaimedBytes = 0; diff --git a/velox/common/memory/tests/MemoryCapExceededTest.cpp b/velox/common/memory/tests/MemoryCapExceededTest.cpp index 48eab65ba81..38c739f27f1 100644 --- a/velox/common/memory/tests/MemoryCapExceededTest.cpp +++ b/velox/common/memory/tests/MemoryCapExceededTest.cpp @@ -69,7 +69,7 @@ TEST_P(MemoryCapExceededTest, singleDriver) { // why). std::vector expectedTexts = { "Can't grow ", - "capacity with 2.00MB. This will exceed its max capacity 5.00MB, current " + "capacity with 2.00MB. This will exceed its memory pool capacity 5.00MB, current " "capacity 5.00MB.\n" "ARBITRATOR[SHARED CAPACITY[6.00GB] STATS[numRequests 1 numRunning 1 " "numSucceded 0 numAborted 0 numFailures 0 numNonReclaimableAttempts 0 " @@ -113,8 +113,9 @@ TEST_P(MemoryCapExceededTest, singleDriver) { .orderBy({"c0"}, false) .planNode(); auto queryCtx = core::QueryCtx::create(executor_.get()); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), kMaxBytes, exec::MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), kMaxBytes, exec::MemoryReclaimer::create())); CursorParameters params; params.planNode = plan; params.queryCtx = queryCtx; @@ -171,8 +172,9 @@ TEST_P(MemoryCapExceededTest, multipleDrivers) { .singleAggregation({"c0"}, {"sum(c1)"}) .planNode(); auto queryCtx = core::QueryCtx::create(executor_.get()); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), kMaxBytes, exec::MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), kMaxBytes, exec::MemoryReclaimer::create())); const int32_t numDrivers = 10; CursorParameters params; diff --git a/velox/common/memory/tests/MemoryManagerTest.cpp b/velox/common/memory/tests/MemoryManagerTest.cpp index 6cc3ef797a7..f1862a03673 100644 --- a/velox/common/memory/tests/MemoryManagerTest.cpp +++ b/velox/common/memory/tests/MemoryManagerTest.cpp @@ -31,13 +31,14 @@ DECLARE_bool(velox_enable_memory_usage_track_in_default_memory_pool); using namespace ::testing; namespace facebook::velox::memory { - namespace { -constexpr folly::StringPiece kSysRootName{"__sys_root__"}; + +constexpr std::string_view kSysRootName{"__sys_root__"}; MemoryManager& toMemoryManager(MemoryManager& manager) { return *static_cast(&manager); } + } // namespace class MemoryManagerTest : public testing::Test { @@ -221,14 +222,18 @@ TEST_F(MemoryManagerTest, addPool) { auto rootPool = manager.addRootPool("duplicateRootPool", kMaxMemory); ASSERT_EQ(rootPool->capacity(), kMaxMemory); ASSERT_EQ(rootPool->maxCapacity(), kMaxMemory); - { ASSERT_ANY_THROW(manager.addRootPool("duplicateRootPool", kMaxMemory)); } + { + ASSERT_ANY_THROW(manager.addRootPool("duplicateRootPool", kMaxMemory)); + } auto threadSafeLeafPool = manager.addLeafPool("leafPool", true); ASSERT_EQ(threadSafeLeafPool->capacity(), kMaxMemory); ASSERT_EQ(threadSafeLeafPool->maxCapacity(), kMaxMemory); auto nonThreadSafeLeafPool = manager.addLeafPool("duplicateLeafPool", true); ASSERT_EQ(nonThreadSafeLeafPool->capacity(), kMaxMemory); ASSERT_EQ(nonThreadSafeLeafPool->maxCapacity(), kMaxMemory); - { ASSERT_ANY_THROW(manager.addLeafPool("duplicateLeafPool")); } + { + ASSERT_ANY_THROW(manager.addLeafPool("duplicateLeafPool")); + } const int64_t poolCapacity = 1 << 20; auto rootPoolWithMaxCapacity = manager.addRootPool("rootPoolWithCapacity", poolCapacity); @@ -273,7 +278,9 @@ TEST_F(MemoryManagerTest, addPoolWithArbitrator) { auto nonThreadSafeLeafPool = manager.addLeafPool("duplicateLeafPool", true); ASSERT_EQ(nonThreadSafeLeafPool->capacity(), kMaxMemory); ASSERT_EQ(nonThreadSafeLeafPool->maxCapacity(), kMaxMemory); - { ASSERT_ANY_THROW(manager.addLeafPool("duplicateLeafPool")); } + { + ASSERT_ANY_THROW(manager.addLeafPool("duplicateLeafPool")); + } const int64_t poolCapacity = 1 << 30; auto rootPoolWithMaxCapacity = manager.addRootPool( "rootPoolWithCapacity", poolCapacity, MemoryReclaimer::create()); @@ -346,8 +353,9 @@ TEST_F(MemoryManagerTest, defaultMemoryManager) { for (int i = 0; i < 32; ++i) { ASSERT_THAT( managerA.toString(true), - testing::HasSubstr(fmt::format( - "__sys_shared_leaf__{} usage 0B reserved 0B peak 0B\n", i))); + testing::HasSubstr( + fmt::format( + "__sys_shared_leaf__{} usage 0B reserved 0B peak 0B\n", i))); } } @@ -487,7 +495,7 @@ TEST_F(MemoryManagerTest, globalMemoryManager) { auto childII = manager->addLeafPool("another_child"); ASSERT_EQ(childII->kind(), MemoryPool::Kind::kLeaf); ASSERT_EQ(rootI.getChildCount(), kSharedPoolCount + 2); - ASSERT_EQ(childII->parent()->name(), kSysRootName.str()); + ASSERT_EQ(childII->parent()->name(), kSysRootName); childII.reset(); ASSERT_EQ(rootI.getChildCount(), kSharedPoolCount + 1); ASSERT_EQ(rootII.getChildCount(), kSharedPoolCount + 1); diff --git a/velox/common/memory/tests/MemoryPoolTest.cpp b/velox/common/memory/tests/MemoryPoolTest.cpp index 9b9ea02985a..f57e28d0659 100644 --- a/velox/common/memory/tests/MemoryPoolTest.cpp +++ b/velox/common/memory/tests/MemoryPoolTest.cpp @@ -31,7 +31,6 @@ #include "velox/common/testutil/TestValue.h" DECLARE_bool(velox_memory_leak_check_enabled); -DECLARE_bool(velox_memory_pool_debug_enabled); DECLARE_int32(velox_memory_num_shared_leaf_pools); using namespace ::testing; @@ -61,15 +60,16 @@ struct TestParam { class MemoryPoolTest : public testing::TestWithParam { public: static const std::vector getTestParams() { - std::vector params; - params.push_back({true, true, false}); - params.push_back({true, false, false}); - params.push_back({false, true, false}); - params.push_back({false, false, false}); - params.push_back({true, true, true}); - params.push_back({true, false, true}); - params.push_back({false, true, true}); - params.push_back({false, false, true}); + std::vector params = { + {true, true, false}, + {true, false, false}, + {false, true, false}, + {false, false, false}, + {true, true, true}, + {true, false, true}, + {false, true, true}, + {false, false, true}, + }; return params; } @@ -165,13 +165,14 @@ TEST_P(MemoryPoolTest, ctor) { auto fakeRoot = std::make_shared( &manager, "fake_root", MemoryPool::Kind::kAggregate, nullptr, nullptr); // We can't construct an aggregate memory pool with non-thread safe. - ASSERT_ANY_THROW(std::make_shared( - &manager, - "fake_root", - MemoryPool::Kind::kAggregate, - nullptr, - nullptr, - MemoryPool::Options{.threadSafe = false})); + ASSERT_ANY_THROW( + std::make_shared( + &manager, + "fake_root", + MemoryPool::Kind::kAggregate, + nullptr, + nullptr, + MemoryPool::Options{.threadSafe = false})); ASSERT_EQ("fake_root", fakeRoot->name()); ASSERT_EQ( static_cast(root.get())->testingAllocator(), @@ -1061,6 +1062,22 @@ TEST_P(MemoryPoolTest, allocatorOverflow) { EXPECT_THROW(alloc.deallocate(nullptr, 1ULL << 62), VeloxException); } +TEST_P(MemoryPoolTest, allocatorSwap) { + MemoryManager& manager = *getMemoryManager(); + auto root = manager.addRootPool("swapRoot"); + auto leaf1 = root->addLeafChild("leaf1"); + auto leaf2 = root->addLeafChild("leaf2"); + + StlAllocator alloc1(*leaf1); + StlAllocator alloc2(*leaf2); + ASSERT_EQ(alloc1.pool, leaf1.get()); + ASSERT_EQ(alloc2.pool, leaf2.get()); + + std::swap(alloc1, alloc2); + EXPECT_EQ(alloc1.pool, leaf2.get()); + EXPECT_EQ(alloc2.pool, leaf1.get()); +} + TEST_P(MemoryPoolTest, contiguousAllocate) { auto manager = getMemoryManager(); auto pool = manager->addLeafPool("contiguousAllocate"); @@ -1402,71 +1419,73 @@ TEST_P(MemoryPoolTest, persistentNonContiguousAllocateFailure) { static_cast(numNewPages), injectedFailure); } - } testSettings[] = {// Cap failure injection. - {0, 100, MemoryAllocator::InjectedFailure::kCap}, - {0, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - {0, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kCap}, - {100, 100, MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun / 2, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kCap}, - {200, 100, MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun / 2 + 100, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun - 1, - MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - // Allocate failure injection. - {0, 100, MemoryAllocator::InjectedFailure::kAllocate}, - {0, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kAllocate}, - {0, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kCap}, - {100, 100, MemoryAllocator::InjectedFailure::kAllocate}, - {Allocation::PageRun::kMaxPagesInRun / 2, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kAllocate}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kAllocate}, - {200, 100, MemoryAllocator::InjectedFailure::kAllocate}, - {Allocation::PageRun::kMaxPagesInRun / 2 + 100, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kAllocate}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun - 1, - MemoryAllocator::InjectedFailure::kAllocate}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kAllocate}, - // Madvise failure injection. - {0, 100, MemoryAllocator::InjectedFailure::kMadvise}, - {0, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMadvise}, - {0, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kMadvise}, - {200, 100, MemoryAllocator::InjectedFailure::kMadvise}}; + } testSettings[] = { + // Cap failure injection. + {0, 100, MemoryAllocator::InjectedFailure::kCap}, + {0, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + {0, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kCap}, + {100, 100, MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun / 2, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kCap}, + {200, 100, MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun / 2 + 100, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun - 1, + MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + // Allocate failure injection. + {0, 100, MemoryAllocator::InjectedFailure::kAllocate}, + {0, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kAllocate}, + {0, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kCap}, + {100, 100, MemoryAllocator::InjectedFailure::kAllocate}, + {Allocation::PageRun::kMaxPagesInRun / 2, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kAllocate}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kAllocate}, + {200, 100, MemoryAllocator::InjectedFailure::kAllocate}, + {Allocation::PageRun::kMaxPagesInRun / 2 + 100, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kAllocate}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun - 1, + MemoryAllocator::InjectedFailure::kAllocate}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kAllocate}, + // Madvise failure injection. + {0, 100, MemoryAllocator::InjectedFailure::kMadvise}, + {0, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMadvise}, + {0, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kMadvise}, + {200, 100, MemoryAllocator::InjectedFailure::kMadvise}}; for (const auto& testData : testSettings) { - SCOPED_TRACE(fmt::format( - "{}, useMmap:{}, useCache:{}", - testData.debugString(), - useMmap_, - useCache_)); + SCOPED_TRACE( + fmt::format( + "{}, useMmap:{}, useCache:{}", + testData.debugString(), + useMmap_, + useCache_)); if ((testData.injectedFailure != MemoryAllocator::InjectedFailure::kAllocate) && !useMmap_) { @@ -1584,11 +1603,12 @@ TEST_P(MemoryPoolTest, transientNonContiguousAllocateFailure) { MemoryAllocator::InjectedFailure::kMadvise}, {200, 100, 100, MemoryAllocator::InjectedFailure::kMadvise}}; for (const auto& testData : testSettings) { - SCOPED_TRACE(fmt::format( - "{}, useMmap:{}, useCache:{}", - testData.debugString(), - useMmap_, - useCache_)); + SCOPED_TRACE( + fmt::format( + "{}, useMmap:{}, useCache:{}", + testData.debugString(), + useMmap_, + useCache_)); if ((testData.injectedFailure != MemoryAllocator::InjectedFailure::kAllocate) && !useMmap_) { @@ -1672,81 +1692,82 @@ TEST_P(MemoryPoolTest, persistentContiguousAllocateFailure) { numNewPages, injectedFailure); } - } testSettings[] = {// Cap failure injection. - {0, 100, MemoryAllocator::InjectedFailure::kCap}, - {0, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - {0, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kCap}, - {100, 100, MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun / 2, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kCap}, - {200, 100, MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun / 2 + 100, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun - 1, - MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - // Mmap failure injection. - {0, 100, MemoryAllocator::InjectedFailure::kMmap}, - {0, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMmap}, - {0, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kCap}, - {100, 100, MemoryAllocator::InjectedFailure::kMmap}, - {Allocation::PageRun::kMaxPagesInRun / 2, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMmap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kMmap}, - {200, 100, MemoryAllocator::InjectedFailure::kMmap}, - {Allocation::PageRun::kMaxPagesInRun / 2 + 100, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMmap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun - 1, - MemoryAllocator::InjectedFailure::kMmap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMmap}, - // Madvise failure injection. - {0, 100, MemoryAllocator::InjectedFailure::kMadvise}, - {0, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMadvise}, - {0, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kMadvise}, - {100, 100, MemoryAllocator::InjectedFailure::kMadvise}, - {Allocation::PageRun::kMaxPagesInRun / 2, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMadvise}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kMadvise}, - {200, 100, MemoryAllocator::InjectedFailure::kMadvise}, - {Allocation::PageRun::kMaxPagesInRun / 2 + 100, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMadvise}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun - 1, - MemoryAllocator::InjectedFailure::kMadvise}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMadvise}}; + } testSettings[] = { + // Cap failure injection. + {0, 100, MemoryAllocator::InjectedFailure::kCap}, + {0, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + {0, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kCap}, + {100, 100, MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun / 2, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kCap}, + {200, 100, MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun / 2 + 100, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun - 1, + MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + // Mmap failure injection. + {0, 100, MemoryAllocator::InjectedFailure::kMmap}, + {0, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMmap}, + {0, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kCap}, + {100, 100, MemoryAllocator::InjectedFailure::kMmap}, + {Allocation::PageRun::kMaxPagesInRun / 2, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMmap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kMmap}, + {200, 100, MemoryAllocator::InjectedFailure::kMmap}, + {Allocation::PageRun::kMaxPagesInRun / 2 + 100, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMmap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun - 1, + MemoryAllocator::InjectedFailure::kMmap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMmap}, + // Madvise failure injection. + {0, 100, MemoryAllocator::InjectedFailure::kMadvise}, + {0, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMadvise}, + {0, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kMadvise}, + {100, 100, MemoryAllocator::InjectedFailure::kMadvise}, + {Allocation::PageRun::kMaxPagesInRun / 2, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMadvise}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kMadvise}, + {200, 100, MemoryAllocator::InjectedFailure::kMadvise}, + {Allocation::PageRun::kMaxPagesInRun / 2 + 100, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMadvise}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun - 1, + MemoryAllocator::InjectedFailure::kMadvise}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMadvise}}; for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); if (!useMmap_) { @@ -1790,87 +1811,89 @@ TEST_P(MemoryPoolTest, transientContiguousAllocateFailure) { numNewPages, injectedFailure); } - } testSettings[] = {// Cap failure injection. - {0, 100, MemoryAllocator::InjectedFailure::kCap}, - {0, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - {0, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kCap}, - {100, 100, MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun / 2, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kCap}, - {200, 100, MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun / 2 + 100, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun - 1, - MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - // Mmap failure injection. - {0, 100, MemoryAllocator::InjectedFailure::kMmap}, - {0, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMmap}, - {0, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kCap}, - {100, 100, MemoryAllocator::InjectedFailure::kMmap}, - {Allocation::PageRun::kMaxPagesInRun / 2, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMmap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kMmap}, - {200, 100, MemoryAllocator::InjectedFailure::kMmap}, - {Allocation::PageRun::kMaxPagesInRun / 2 + 100, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMmap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun - 1, - MemoryAllocator::InjectedFailure::kMmap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMmap}, - // Madvise failure injection. - {0, 100, MemoryAllocator::InjectedFailure::kMadvise}, - {0, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMadvise}, - {0, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kMadvise}, - {100, 100, MemoryAllocator::InjectedFailure::kMadvise}, - {Allocation::PageRun::kMaxPagesInRun / 2, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMadvise}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kMadvise}, - {200, 100, MemoryAllocator::InjectedFailure::kMadvise}, - {Allocation::PageRun::kMaxPagesInRun / 2 + 100, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMadvise}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun - 1, - MemoryAllocator::InjectedFailure::kMadvise}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMadvise}}; + } testSettings[] = { + // Cap failure injection. + {0, 100, MemoryAllocator::InjectedFailure::kCap}, + {0, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + {0, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kCap}, + {100, 100, MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun / 2, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kCap}, + {200, 100, MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun / 2 + 100, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun - 1, + MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + // Mmap failure injection. + {0, 100, MemoryAllocator::InjectedFailure::kMmap}, + {0, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMmap}, + {0, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kCap}, + {100, 100, MemoryAllocator::InjectedFailure::kMmap}, + {Allocation::PageRun::kMaxPagesInRun / 2, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMmap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kMmap}, + {200, 100, MemoryAllocator::InjectedFailure::kMmap}, + {Allocation::PageRun::kMaxPagesInRun / 2 + 100, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMmap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun - 1, + MemoryAllocator::InjectedFailure::kMmap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMmap}, + // Madvise failure injection. + {0, 100, MemoryAllocator::InjectedFailure::kMadvise}, + {0, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMadvise}, + {0, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kMadvise}, + {100, 100, MemoryAllocator::InjectedFailure::kMadvise}, + {Allocation::PageRun::kMaxPagesInRun / 2, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMadvise}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kMadvise}, + {200, 100, MemoryAllocator::InjectedFailure::kMadvise}, + {Allocation::PageRun::kMaxPagesInRun / 2 + 100, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMadvise}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun - 1, + MemoryAllocator::InjectedFailure::kMadvise}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMadvise}}; for (const auto& testData : testSettings) { - SCOPED_TRACE(fmt::format( - "{}, useCache:{} , useMmap:{}", - testData.debugString(), - useCache_, - useMmap_)); + SCOPED_TRACE( + fmt::format( + "{}, useCache:{} , useMmap:{}", + testData.debugString(), + useCache_, + useMmap_)); if (!useMmap_) { // No failure injections supported for contiguous allocation of // MallocAllocator. @@ -1944,33 +1967,34 @@ TEST_P(MemoryPoolTest, persistentContiguousGrowAllocateFailure) { injectedFailure, expectedErrorMessage); } - } testSettings[] = {// Cap failure injection. - {10, - 100, - MemoryAllocator::InjectedFailure::kCap, - "growContiguous failed with 100 pages from Memory Pool"}, - {100, - 10, - MemoryAllocator::InjectedFailure::kCap, - "growContiguous failed with 10 pages from Memory Pool"}, - // Mmap failure injection. - {10, - 100, - MemoryAllocator::InjectedFailure::kMmap, - "growContiguous failed with 100 pages from Memory Pool"}, - {100, - 10, - MemoryAllocator::InjectedFailure::kMmap, - "growContiguous failed with 10 pages from Memory Pool"}, - // Madvise failure injection. - {10, - 100, - MemoryAllocator::InjectedFailure::kMadvise, - "growContiguous failed with 100 pages from Memory Pool"}, - {100, - 10, - MemoryAllocator::InjectedFailure::kMadvise, - "growContiguous failed with 10 pages from Memory Pool"}}; + } testSettings[] = { + // Cap failure injection. + {10, + 100, + MemoryAllocator::InjectedFailure::kCap, + "growContiguous failed with 100 pages from Memory Pool"}, + {100, + 10, + MemoryAllocator::InjectedFailure::kCap, + "growContiguous failed with 10 pages from Memory Pool"}, + // Mmap failure injection. + {10, + 100, + MemoryAllocator::InjectedFailure::kMmap, + "growContiguous failed with 100 pages from Memory Pool"}, + {100, + 10, + MemoryAllocator::InjectedFailure::kMmap, + "growContiguous failed with 10 pages from Memory Pool"}, + // Madvise failure injection. + {10, + 100, + MemoryAllocator::InjectedFailure::kMadvise, + "growContiguous failed with 100 pages from Memory Pool"}, + {100, + 10, + MemoryAllocator::InjectedFailure::kMadvise, + "growContiguous failed with 10 pages from Memory Pool"}}; for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); if (!useMmap_) { @@ -2014,21 +2038,23 @@ TEST_P(MemoryPoolTest, transientContiguousGrowAllocateFailure) { numGrowPages, injectedFailure); } - } testSettings[] = {// Cap failure injection. - {10, 100, MemoryAllocator::InjectedFailure::kCap}, - {100, 10, MemoryAllocator::InjectedFailure::kCap}, - // Mmap failure injection. - {10, 100, MemoryAllocator::InjectedFailure::kMmap}, - {100, 10, MemoryAllocator::InjectedFailure::kMmap}, - // Madvise failure injection. - {10, 100, MemoryAllocator::InjectedFailure::kMadvise}, - {100, 10, MemoryAllocator::InjectedFailure::kMadvise}}; + } testSettings[] = { + // Cap failure injection. + {10, 100, MemoryAllocator::InjectedFailure::kCap}, + {100, 10, MemoryAllocator::InjectedFailure::kCap}, + // Mmap failure injection. + {10, 100, MemoryAllocator::InjectedFailure::kMmap}, + {100, 10, MemoryAllocator::InjectedFailure::kMmap}, + // Madvise failure injection. + {10, 100, MemoryAllocator::InjectedFailure::kMadvise}, + {100, 10, MemoryAllocator::InjectedFailure::kMadvise}}; for (const auto& testData : testSettings) { - SCOPED_TRACE(fmt::format( - "{}, useCache:{} , useMmap:{}", - testData.debugString(), - useCache_, - useMmap_)); + SCOPED_TRACE( + fmt::format( + "{}, useCache:{} , useMmap:{}", + testData.debugString(), + useCache_, + useMmap_)); if (!useMmap_) { // No failure injections supported for contiguous allocation of // MallocAllocator. @@ -2520,7 +2546,6 @@ TEST_P(MemoryPoolTest, concurrentUpdateToDifferentPools) { } TEST_P(MemoryPoolTest, concurrentUpdatesToTheSamePool) { - FLAGS_velox_memory_pool_debug_enabled = true; if (!isLeafThreadSafe_) { return; } @@ -2706,7 +2731,6 @@ TEST(MemoryPoolTest, visitChildren) { } TEST(MemoryPoolTest, debugMode) { - FLAGS_velox_memory_pool_debug_enabled = true; constexpr int64_t kMaxMemory = 10 * GB; constexpr int64_t kNumIterations = 100; const std::vector kAllocSizes = {128, 8 * KB, 2 * MB}; @@ -2726,12 +2750,23 @@ TEST(MemoryPoolTest, debugMode) { ->addLeafChild("child"); const auto& allocRecords = std::dynamic_pointer_cast(pool) ->testingDebugAllocRecords(); + std::vector smallAllocs; + smallAllocs.reserve(kNumIterations); for (int32_t i = 0; i < kNumIterations; i++) { smallAllocs.push_back(pool->allocate(kAllocSizes[0])); } EXPECT_EQ(allocRecords.size(), kNumIterations); checkAllocs(allocRecords, kAllocSizes[0]); + + // Check toString() works with debug mode enabled + const auto poolString = pool->toString(); + EXPECT_FALSE(poolString.empty()); + EXPECT_TRUE( + poolString.find( + "======== 100 allocations of 12.50KB total size ========") != + std::string::npos); + for (int32_t i = 0; i < kNumIterations; i++) { pool->free(smallAllocs[i], kAllocSizes[0]); } @@ -2779,14 +2814,16 @@ TEST(MemoryPoolTest, debugModeWithFilter) { "root0", kMaxMemory, nullptr, - debugEnabled ? std::optional(MemoryPool::DebugOptions{ - .debugPoolNameRegex = "NO-MATCH"}) - : std::nullopt); + debugEnabled + ? std::optional( + MemoryPool::DebugOptions{.debugPoolNameRegex = "NO-MATCH"}) + : std::nullopt); auto pool0_0 = root0->addLeafChild("PartialAggregation.0.0"); auto* buffer0 = pool0_0->allocate(1 * KB); - EXPECT_TRUE(std::dynamic_pointer_cast(pool0_0) - ->testingDebugAllocRecords() - .empty()); + EXPECT_TRUE( + std::dynamic_pointer_cast(pool0_0) + ->testingDebugAllocRecords() + .empty()); pool0_0->free(buffer0, 1 * KB); // leaf child created from MemoryPool, match filter @@ -2794,8 +2831,9 @@ TEST(MemoryPoolTest, debugModeWithFilter) { "root1", kMaxMemory, nullptr, - debugEnabled ? std::optional(MemoryPool::DebugOptions{ - .debugPoolNameRegex = ".*PartialAggregation.*"}) + debugEnabled ? std::optional( + MemoryPool::DebugOptions{ + .debugPoolNameRegex = ".*PartialAggregation.*"}) : std::nullopt); auto pool1_0 = root1->addLeafChild("PartialAggregation.0.1"); auto* buffer1 = pool1_0->allocate(1 * KB); @@ -2816,9 +2854,10 @@ TEST(MemoryPoolTest, debugModeWithFilter) { // old pool from root0 should not be affected by root1 buffer0 = pool0_0->allocate(1 * KB); - EXPECT_TRUE(std::dynamic_pointer_cast(pool0_0) - ->testingDebugAllocRecords() - .empty()); + EXPECT_TRUE( + std::dynamic_pointer_cast(pool0_0) + ->testingDebugAllocRecords() + .empty()); pool0_0->free(buffer0, 1 * KB); // leaf child created from MemoryPool, match filter @@ -2826,9 +2865,10 @@ TEST(MemoryPoolTest, debugModeWithFilter) { "root2", kMaxMemory, nullptr, - debugEnabled ? std::optional(MemoryPool::DebugOptions{ - .debugPoolNameRegex = ".*OrderBy.*"}) - : std::nullopt); + debugEnabled + ? std::optional( + MemoryPool::DebugOptions{.debugPoolNameRegex = ".*OrderBy.*"}) + : std::nullopt); auto pool2_0 = root2->addLeafChild("OrderBy.0.0"); auto* buffer2 = pool2_0->allocate(1 * KB); if (!debugEnabled) { @@ -2879,13 +2919,51 @@ TEST(MemoryPoolTest, debugModeWithFilter) { // leaf child created from MemoryManager, not match filter auto sysLeaf = manager.addLeafPool("Arbitrator.0.0"); auto* buffer5 = sysLeaf->allocate(1 * KB); - EXPECT_TRUE(std::dynamic_pointer_cast(sysLeaf) - ->testingDebugAllocRecords() - .empty()); + EXPECT_TRUE( + std::dynamic_pointer_cast(sysLeaf) + ->testingDebugAllocRecords() + .empty()); sysLeaf->free(buffer5, 1 * KB); } } +TEST_P(MemoryPoolTest, debugModeWrapCapException) { + const uint64_t kMaxCap = 128L * MB; + MemoryManager::Options options; + options.allocatorCapacity = kMaxCap; + options.arbitratorCapacity = kMaxCap; + options.extraArbitratorConfigs = { + {std::string(SharedArbitrator::ExtraConfig::kReservedCapacity), + folly::to(kMaxCap / 2) + "B"}}; + setupMemory(options); + auto manager = getMemoryManager(); + auto root = + manager->addRootPool("MemoryCapExceptions", kMaxCap, nullptr, {{".*"}}); + auto pool1 = root->addLeafChild("static_quota_1", isLeafThreadSafe_); + auto pool2 = root->addLeafChild("static_quota_2", isLeafThreadSafe_); + { + std::vector buffers{ + pool1->allocate(64L * MB), pool1->allocate(64L * MB)}; + try { + pool2->allocate(1L * MB); + } catch (const velox::VeloxRuntimeError& ex) { + ASSERT_EQ(error_source::kErrorSourceRuntime.c_str(), ex.errorSource()); + ASSERT_EQ(error_code::kMemCapExceeded.c_str(), ex.errorCode()); + EXPECT_TRUE( + ex.message().find( + "Exceeded memory pool capacity.\n\n" + "======= Current Allocations ======\n" + "Memory pool 'static_quota_1' - Found 2 allocations with 128.00MB total size:\n" + "======== 2 allocations of 128.00MB total size ========") != + std::string::npos) + << "Actual error message: " << ex.message(); + } + for (auto buffer : buffers) { + pool1->free(buffer, 64L * MB); + } + } +} + TEST_P(MemoryPoolTest, shrinkAndGrowAPIs) { MemoryManager& manager = *getMemoryManager(); std::vector capacities = {kMaxMemory, 128 * MB}; @@ -3968,9 +4046,13 @@ TEST_P(MemoryPoolTest, abort) { ASSERT_TRUE(rootPool->aborted()); // Allocate more buffer to trigger reservation increment at the root. - { VELOX_ASSERT_THROW(leafPool->allocate(capacity / 2), ""); } + { + VELOX_ASSERT_THROW(leafPool->allocate(capacity / 2), ""); + } // Allocate more buffer to trigger memory arbitration at the root. - { VELOX_ASSERT_THROW(leafPool->allocate(capacity * 2), ""); } + { + VELOX_ASSERT_THROW(leafPool->allocate(capacity * 2), ""); + } // Allocate without trigger memory reservation increment. void* buf2 = leafPool->allocate(128); ASSERT_EQ(leafPool->usedBytes(), 256); @@ -4037,6 +4119,213 @@ TEST_P(MemoryPoolTest, allocationWithCoveredCollateral) { pool->freeContiguous(contiguousAllocation); } +TEST_P(MemoryPoolTest, transferTo) { + MemoryManager::Options options; + options.alignment = MemoryAllocator::kMinAlignment; + options.allocatorCapacity = kDefaultCapacity; + setupMemory(options); + auto manager = getMemoryManager(); + + auto largestSizeClass = manager->allocator()->largestSizeClass(); + std::vector pageCounts{ + largestSizeClass, + largestSizeClass + 1, + largestSizeClass / 10, + 1, + largestSizeClass * 2, + largestSizeClass * 3 + 1}; + + auto assertEqualBytes = [](const memory::MemoryPool* pool, + int64_t usedBytes, + int64_t peakBytes, + int64_t reservedBytes) { + EXPECT_EQ(pool->usedBytes(), usedBytes); + EXPECT_EQ(pool->peakBytes(), peakBytes); + EXPECT_EQ(pool->reservedBytes(), reservedBytes); + }; + + auto assertZeroByte = [](const memory::MemoryPool* pool) { + EXPECT_EQ(pool->usedBytes(), 0); + EXPECT_EQ(pool->reservedBytes(), 0); + }; + + auto getMemoryBytes = [](const memory::MemoryPool* pool) { + return std::make_tuple( + pool->usedBytes(), pool->peakBytes(), pool->reservedBytes()); + }; + + auto createPools = [&manager](bool betweenDifferentRoots) { + auto root1 = manager->addRootPool("root1"); + auto root2 = manager->addRootPool("root2"); + std::shared_ptr from; + std::shared_ptr to; + if (betweenDifferentRoots) { + from = root1->addLeafChild("from"); + to = root2->addLeafChild("to"); + } else { + from = root1->addLeafChild("from"); + to = root1->addLeafChild("to"); + } + return std::make_tuple(root1, root2, from, to); + }; + + auto testTransferAllocate = [&assertZeroByte, + &assertEqualBytes, + &getMemoryBytes, + &createPools](bool betweenDifferentRoots) { + auto [root1, root2, from, to] = createPools(betweenDifferentRoots); + assertZeroByte(from.get()); + assertZeroByte(to.get()); + assertZeroByte(from->root()); + assertZeroByte(to->root()); + + const auto kSize = 1024; + int64_t usedBytes, rootUsedBytes; + int64_t peakBytes, rootPeakBytes; + int64_t reservedBytes, rootReservedBytes; + auto buffer = from->allocate(kSize); + // Transferring between non-leaf pools is not allowed. + EXPECT_FALSE(from->root()->transferTo(to.get(), buffer, kSize)); + EXPECT_FALSE(from->transferTo(to->root(), buffer, kSize)); + + std::tie(usedBytes, peakBytes, reservedBytes) = getMemoryBytes(from.get()); + std::tie(rootUsedBytes, rootPeakBytes, rootReservedBytes) = + getMemoryBytes(from->root()); + from->transferTo(to.get(), buffer, kSize); + assertEqualBytes(to.get(), usedBytes, peakBytes, reservedBytes); + if (from->root() == to->root()) { + rootPeakBytes *= 2; + } + assertEqualBytes( + to->root(), rootUsedBytes, rootPeakBytes, rootReservedBytes); + to->free(buffer, kSize); + assertZeroByte(from.get()); + assertZeroByte(to.get()); + assertZeroByte(from->root()); + assertZeroByte(to->root()); + }; + + auto testTransferAllocateZeroFilled = + [&assertZeroByte, &assertEqualBytes, &getMemoryBytes, &createPools]( + bool betweenDifferentRoots) { + auto [root1, root2, from, to] = createPools(betweenDifferentRoots); + assertZeroByte(from.get()); + assertZeroByte(to.get()); + assertZeroByte(from->root()); + assertZeroByte(to->root()); + + const auto kSize = 1024; + int64_t usedBytes, rootUsedBytes; + int64_t peakBytes, rootPeakBytes; + int64_t reservedBytes, rootReservedBytes; + auto buffer = from->allocateZeroFilled(8, kSize / 8); + std::tie(usedBytes, peakBytes, reservedBytes) = + getMemoryBytes(from.get()); + std::tie(rootUsedBytes, rootPeakBytes, rootReservedBytes) = + getMemoryBytes(from->root()); + from->transferTo(to.get(), buffer, kSize); + assertEqualBytes(to.get(), usedBytes, peakBytes, reservedBytes); + if (from->root() == to->root()) { + rootPeakBytes *= 2; + } + assertEqualBytes( + to->root(), rootUsedBytes, rootPeakBytes, rootReservedBytes); + to->free(buffer, kSize); + assertZeroByte(from.get()); + assertZeroByte(to.get()); + assertZeroByte(from->root()); + assertZeroByte(to->root()); + }; + + auto testTransferAllocateContiguous = + [&assertZeroByte, &assertEqualBytes, &getMemoryBytes, &createPools]( + uint64_t pageCount, bool betweenDifferentRoots) { + auto [root1, root2, from, to] = createPools(betweenDifferentRoots); + assertZeroByte(from.get()); + assertZeroByte(to.get()); + assertZeroByte(from->root()); + assertZeroByte(to->root()); + + int64_t usedBytes, rootUsedBytes; + int64_t peakBytes, rootPeakBytes; + int64_t reservedBytes, rootReservedBytes; + ContiguousAllocation out; + from->allocateContiguous(pageCount, out); + std::tie(usedBytes, peakBytes, reservedBytes) = + getMemoryBytes(from.get()); + std::tie(rootUsedBytes, rootPeakBytes, rootReservedBytes) = + getMemoryBytes(from->root()); + from->transferTo(to.get(), out.data(), out.size()); + assertEqualBytes(to.get(), usedBytes, peakBytes, reservedBytes); + if (from->root() == to->root()) { + rootPeakBytes *= 2; + } + assertEqualBytes( + to->root(), rootUsedBytes, rootPeakBytes, rootReservedBytes); + to->freeContiguous(out); + assertZeroByte(from.get()); + assertZeroByte(to.get()); + assertZeroByte(from->root()); + assertZeroByte(to->root()); + }; + + auto testTransferAllocateNonContiguous = + [&assertZeroByte, &assertEqualBytes, &getMemoryBytes, &createPools]( + uint64_t pageCount, bool betweenDifferentRoots) { + auto [root1, root2, from, to] = createPools(betweenDifferentRoots); + assertZeroByte(from.get()); + assertZeroByte(to.get()); + assertZeroByte(from->root()); + assertZeroByte(to->root()); + + int64_t usedBytes, rootUsedBytes; + int64_t peakBytes, rootPeakBytes; + int64_t reservedBytes, rootReservedBytes; + Allocation out; + from->allocateNonContiguous(pageCount, out); + std::tie(usedBytes, peakBytes, reservedBytes) = + getMemoryBytes(from.get()); + std::tie(rootUsedBytes, rootPeakBytes, rootReservedBytes) = + getMemoryBytes(from->root()); + for (auto i = 0; i < out.numRuns(); ++i) { + const auto& run = out.runAt(i); + from->transferTo(to.get(), run.data(), run.numBytes()); + } + assertEqualBytes(to.get(), usedBytes, peakBytes, reservedBytes); + if (from->root() == to->root()) { + EXPECT_EQ(to->root()->usedBytes(), rootUsedBytes); + // We reserve and release memory run-by-run, so the peak bytes would + // be no greater than twice of the original peak bytes. + EXPECT_LE(to->root()->peakBytes(), rootPeakBytes * 2); + EXPECT_EQ(to->root()->reservedBytes(), rootReservedBytes); + } else { + assertEqualBytes( + to->root(), rootUsedBytes, rootPeakBytes, rootReservedBytes); + } + to->freeNonContiguous(out); + assertZeroByte(from.get()); + assertZeroByte(to.get()); + assertZeroByte(from->root()); + assertZeroByte(to->root()); + }; + + // Test transfer between siblings of the same root pool. + testTransferAllocate(false); + testTransferAllocateZeroFilled(false); + for (auto pageCount : pageCounts) { + testTransferAllocateContiguous(pageCount, false); + testTransferAllocateNonContiguous(pageCount, false); + } + + // Test transfer between different root pools. + testTransferAllocate(true); + testTransferAllocateZeroFilled(true); + for (auto pageCount : pageCounts) { + testTransferAllocateContiguous(pageCount, true); + testTransferAllocateNonContiguous(pageCount, true); + } +} + VELOX_INSTANTIATE_TEST_SUITE_P( MemoryPoolTestSuite, MemoryPoolTest, diff --git a/velox/common/memory/tests/MockSharedArbitratorTest.cpp b/velox/common/memory/tests/MockSharedArbitratorTest.cpp index a1569f547cf..830aa1e699b 100644 --- a/velox/common/memory/tests/MockSharedArbitratorTest.cpp +++ b/velox/common/memory/tests/MockSharedArbitratorTest.cpp @@ -21,17 +21,17 @@ #include #include #include -#include "folly/experimental/EventCount.h" +#include "folly/synchronization/EventCount.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/memory/MallocAllocator.h" #include "velox/common/memory/Memory.h" #include "velox/common/memory/MemoryArbitrator.h" #include "velox/common/memory/SharedArbitrator.h" #include "velox/common/memory/tests/SharedArbitratorTestUtil.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/common/testutil/TestValue.h" #include "velox/exec/OperatorUtils.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" DECLARE_bool(velox_memory_leak_check_enabled); DECLARE_bool(velox_suppress_memory_capacity_exceeding_error_message); @@ -49,7 +49,7 @@ class TestRuntimeStatWriter : public BaseRuntimeStatWriter { std::unordered_map& stats) : stats_{stats} {} - void addRuntimeStat(const std::string& name, const RuntimeCounter& value) + void addRuntimeStat(std::string_view name, const RuntimeCounter& value) override { addOperatorRuntimeStats(name, value, stats_); } @@ -1451,12 +1451,21 @@ DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, localArbitrationsFromSameQuery) { setThreadLocalRunTimeStatWriter(statsWriter.get()); runPool->allocate(memoryCapacity / 2); ASSERT_EQ( - runtimeStats[SharedArbitrator::kMemoryArbitrationWallNanos].count, 1); + runtimeStats[std::string(SharedArbitrator::kMemoryArbitrationWallNanos)] + .count, + 1); ASSERT_GT( - runtimeStats[SharedArbitrator::kMemoryArbitrationWallNanos].sum, 0); + runtimeStats[std::string(SharedArbitrator::kMemoryArbitrationWallNanos)] + .sum, + 0); + ASSERT_EQ( + runtimeStats[std::string(SharedArbitrator::kGlobalArbitrationWaitCount)] + .count, + 0); ASSERT_EQ( - runtimeStats[SharedArbitrator::kGlobalArbitrationWaitCount].count, 0); - ASSERT_EQ(runtimeStats[SharedArbitrator::kLocalArbitrationCount].count, 0); + runtimeStats[std::string(SharedArbitrator::kLocalArbitrationCount)] + .count, + 0); ++allocationCount; }); @@ -1467,13 +1476,24 @@ DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, localArbitrationsFromSameQuery) { setThreadLocalRunTimeStatWriter(statsWriter.get()); waitPool->allocate(memoryCapacity / 2 + MB); ASSERT_EQ( - runtimeStats[SharedArbitrator::kMemoryArbitrationWallNanos].count, 1); + runtimeStats[std::string(SharedArbitrator::kMemoryArbitrationWallNanos)] + .count, + 1); ASSERT_GT( - runtimeStats[SharedArbitrator::kMemoryArbitrationWallNanos].sum, 0); + runtimeStats[std::string(SharedArbitrator::kMemoryArbitrationWallNanos)] + .sum, + 0); ASSERT_EQ( - runtimeStats[SharedArbitrator::kGlobalArbitrationWaitCount].count, 0); - ASSERT_EQ(runtimeStats[SharedArbitrator::kLocalArbitrationCount].count, 1); - ASSERT_EQ(runtimeStats[SharedArbitrator::kLocalArbitrationCount].sum, 1); + runtimeStats[std::string(SharedArbitrator::kGlobalArbitrationWaitCount)] + .count, + 0); + ASSERT_EQ( + runtimeStats[std::string(SharedArbitrator::kLocalArbitrationCount)] + .count, + 1); + ASSERT_EQ( + runtimeStats[std::string(SharedArbitrator::kLocalArbitrationCount)].sum, + 1); ++allocationCount; }); @@ -1534,12 +1554,21 @@ DEBUG_ONLY_TEST_F( op1->allocate(MB); ASSERT_EQ(task1->capacity(), 8 * MB); ASSERT_EQ( - runtimeStats[SharedArbitrator::kMemoryArbitrationWallNanos].count, 1); + runtimeStats[std::string(SharedArbitrator::kMemoryArbitrationWallNanos)] + .count, + 1); ASSERT_GT( - runtimeStats[SharedArbitrator::kMemoryArbitrationWallNanos].sum, 0); + runtimeStats[std::string(SharedArbitrator::kMemoryArbitrationWallNanos)] + .sum, + 0); + ASSERT_EQ( + runtimeStats[std::string(SharedArbitrator::kGlobalArbitrationWaitCount)] + .count, + 0); ASSERT_EQ( - runtimeStats[SharedArbitrator::kGlobalArbitrationWaitCount].count, 0); - ASSERT_EQ(runtimeStats[SharedArbitrator::kLocalArbitrationCount].count, 1); + runtimeStats[std::string(SharedArbitrator::kLocalArbitrationCount)] + .count, + 1); ++allocationCount; }); @@ -1550,12 +1579,21 @@ DEBUG_ONLY_TEST_F( op2->allocate(MB); ASSERT_EQ(task2->capacity(), 8 * MB); ASSERT_EQ( - runtimeStats[SharedArbitrator::kMemoryArbitrationWallNanos].count, 1); + runtimeStats[std::string(SharedArbitrator::kMemoryArbitrationWallNanos)] + .count, + 1); ASSERT_GT( - runtimeStats[SharedArbitrator::kMemoryArbitrationWallNanos].sum, 0); + runtimeStats[std::string(SharedArbitrator::kMemoryArbitrationWallNanos)] + .sum, + 0); ASSERT_EQ( - runtimeStats[SharedArbitrator::kGlobalArbitrationWaitCount].count, 0); - ASSERT_EQ(runtimeStats[SharedArbitrator::kLocalArbitrationCount].count, 1); + runtimeStats[std::string(SharedArbitrator::kGlobalArbitrationWaitCount)] + .count, + 0); + ASSERT_EQ( + runtimeStats[std::string(SharedArbitrator::kLocalArbitrationCount)] + .count, + 1); ++allocationCount; }); @@ -1780,20 +1818,37 @@ DEBUG_ONLY_TEST_F( } // We expect global arbitration has been triggered. ASSERT_GE( - runtimeStats[SharedArbitrator::kMemoryArbitrationWallNanos].count, 1); + runtimeStats[std::string(SharedArbitrator::kMemoryArbitrationWallNanos)] + .count, + 1); ASSERT_GT( - runtimeStats[SharedArbitrator::kMemoryArbitrationWallNanos].sum, 0); + runtimeStats[std::string(SharedArbitrator::kMemoryArbitrationWallNanos)] + .sum, + 0); ASSERT_GT( - runtimeStats[SharedArbitrator::kGlobalArbitrationWaitCount].count, 0); + runtimeStats[std::string(SharedArbitrator::kGlobalArbitrationWaitCount)] + .count, + 0); ASSERT_GT( - runtimeStats[SharedArbitrator::kGlobalArbitrationWaitCount].sum, 0); - ASSERT_EQ(runtimeStats[SharedArbitrator::kLocalArbitrationCount].count, 0); - ASSERT_EQ(runtimeStats[SharedArbitrator::kLocalArbitrationCount].sum, 0); + runtimeStats[std::string(SharedArbitrator::kGlobalArbitrationWaitCount)] + .sum, + 0); + ASSERT_EQ( + runtimeStats[std::string(SharedArbitrator::kLocalArbitrationCount)] + .count, + 0); + ASSERT_EQ( + runtimeStats[std::string(SharedArbitrator::kLocalArbitrationCount)].sum, + 0); ASSERT_GT( - runtimeStats[SharedArbitrator::kGlobalArbitrationWaitWallNanos].count, + runtimeStats[std::string( + SharedArbitrator::kGlobalArbitrationWaitWallNanos)] + .count, 0); ASSERT_GT( - runtimeStats[SharedArbitrator::kGlobalArbitrationWaitWallNanos].sum, + runtimeStats[std::string( + SharedArbitrator::kGlobalArbitrationWaitWallNanos)] + .sum, 1'000'000'000); }); @@ -1845,6 +1900,26 @@ DEBUG_ONLY_TEST_F( auto* localArbitrationOp = localArbitrationTask->addMemoryOp(true); localArbitrationOp->allocate(memoryPoolCapacity); + // Install the test value callback BEFORE spawning the thread to avoid a race + // condition where the thread triggers global arbitration before the callback + // is registered, causing the main thread to deadlock. (GitHub issue #15336) + std::atomic_bool globalArbitrationStarted{false}; + folly::EventCount globalArbitrationStartWait; + std::atomic_bool globalArbitrationWaitFlag{true}; + folly::EventCount globalArbitrationWait; + SCOPED_TESTVALUE_SET( + "facebook::velox::memory::SharedArbitrator::runGlobalArbitration", + std::function( + ([&](const SharedArbitrator* /*unused*/) { + if (globalArbitrationStarted.exchange(true)) { + return; + } + globalArbitrationStartWait.notifyAll(); + + globalArbitrationWait.await( + [&]() { return !globalArbitrationWaitFlag.load(); }); + }))); + auto globalArbitrationTriggerThread = std::thread([&]() { std::unordered_map runtimeStats; auto statsWriter = std::make_unique(runtimeStats); @@ -1863,39 +1938,40 @@ DEBUG_ONLY_TEST_F( } // We expect global arbitration has been triggered. ASSERT_GE( - runtimeStats[SharedArbitrator::kMemoryArbitrationWallNanos].count, 1); + runtimeStats[std::string(SharedArbitrator::kMemoryArbitrationWallNanos)] + .count, + 1); ASSERT_GT( - runtimeStats[SharedArbitrator::kMemoryArbitrationWallNanos].sum, 0); + runtimeStats[std::string(SharedArbitrator::kMemoryArbitrationWallNanos)] + .sum, + 0); ASSERT_GE( - runtimeStats[SharedArbitrator::kGlobalArbitrationWaitCount].count, 1); + runtimeStats[std::string(SharedArbitrator::kGlobalArbitrationWaitCount)] + .count, + 1); ASSERT_GE( - runtimeStats[SharedArbitrator::kGlobalArbitrationWaitCount].sum, 1); - ASSERT_EQ(runtimeStats[SharedArbitrator::kLocalArbitrationCount].count, 0); - ASSERT_EQ(runtimeStats[SharedArbitrator::kLocalArbitrationCount].sum, 0); + runtimeStats[std::string(SharedArbitrator::kGlobalArbitrationWaitCount)] + .sum, + 1); + ASSERT_EQ( + runtimeStats[std::string(SharedArbitrator::kLocalArbitrationCount)] + .count, + 0); + ASSERT_EQ( + runtimeStats[std::string(SharedArbitrator::kLocalArbitrationCount)].sum, + 0); ASSERT_GE( - runtimeStats[SharedArbitrator::kGlobalArbitrationWaitWallNanos].count, + runtimeStats[std::string( + SharedArbitrator::kGlobalArbitrationWaitWallNanos)] + .count, 1); ASSERT_GT( - runtimeStats[SharedArbitrator::kGlobalArbitrationWaitWallNanos].sum, 1); + runtimeStats[std::string( + SharedArbitrator::kGlobalArbitrationWaitWallNanos)] + .sum, + 1); }); - std::atomic_bool globalArbitrationStarted{false}; - folly::EventCount globalArbitrationStartWait; - std::atomic_bool globalArbitrationWaitFlag{true}; - folly::EventCount globalArbitrationWait; - SCOPED_TESTVALUE_SET( - "facebook::velox::memory::SharedArbitrator::runGlobalArbitration", - std::function( - ([&](const SharedArbitrator* /*unused*/) { - if (globalArbitrationStarted.exchange(true)) { - return; - } - globalArbitrationStartWait.notifyAll(); - - globalArbitrationWait.await( - [&]() { return !globalArbitrationWaitFlag.load(); }); - }))); - globalArbitrationStartWait.await( [&]() { return globalArbitrationStarted.load(); }); @@ -1912,10 +1988,19 @@ DEBUG_ONLY_TEST_F( globalArbitrationTriggerThread.join(); ASSERT_EQ(localArbitrationOp->capacity(), memoryPoolReservedCapacity); ASSERT_EQ( - runtimeStats[SharedArbitrator::kGlobalArbitrationWaitCount].count, 0); - ASSERT_EQ(runtimeStats[SharedArbitrator::kGlobalArbitrationWaitCount].sum, 0); - ASSERT_EQ(runtimeStats[SharedArbitrator::kLocalArbitrationCount].count, 1); - ASSERT_EQ(runtimeStats[SharedArbitrator::kLocalArbitrationCount].sum, 1); + runtimeStats[std::string(SharedArbitrator::kGlobalArbitrationWaitCount)] + .count, + 0); + ASSERT_EQ( + runtimeStats[std::string(SharedArbitrator::kGlobalArbitrationWaitCount)] + .sum, + 0); + ASSERT_EQ( + runtimeStats[std::string(SharedArbitrator::kLocalArbitrationCount)].count, + 1); + ASSERT_EQ( + runtimeStats[std::string(SharedArbitrator::kLocalArbitrationCount)].sum, + 1); // Global arbitration thread may still be running in the background, // triggerring ASAN failure. Wait until it exits. @@ -1966,14 +2051,25 @@ DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, globalArbitrationAbortTimeRatio) { op1->allocate(memoryCapacity / 2); ASSERT_EQ( - runtimeStats[SharedArbitrator::kMemoryArbitrationWallNanos].count, 1); + runtimeStats[std::string(SharedArbitrator::kMemoryArbitrationWallNanos)] + .count, + 1); ASSERT_GT( - runtimeStats[SharedArbitrator::kMemoryArbitrationWallNanos].sum, 0); + runtimeStats[std::string(SharedArbitrator::kMemoryArbitrationWallNanos)] + .sum, + 0); ASSERT_EQ( - runtimeStats[SharedArbitrator::kGlobalArbitrationWaitCount].count, 1); + runtimeStats[std::string(SharedArbitrator::kGlobalArbitrationWaitCount)] + .count, + 1); ASSERT_EQ( - runtimeStats[SharedArbitrator::kGlobalArbitrationWaitCount].sum, 1); - ASSERT_EQ(runtimeStats[SharedArbitrator::kLocalArbitrationCount].count, 0); + runtimeStats[std::string(SharedArbitrator::kGlobalArbitrationWaitCount)] + .sum, + 1); + ASSERT_EQ( + runtimeStats[std::string(SharedArbitrator::kLocalArbitrationCount)] + .count, + 0); ASSERT_TRUE(task1->error() == nullptr); ASSERT_EQ(task1->capacity(), memoryCapacity); ASSERT_TRUE(task2->error() != nullptr); @@ -2019,12 +2115,24 @@ TEST_F(MockSharedArbitrationTest, globalArbitrationWithoutSpill) { triggerOp->allocate(memoryCapacity / 2); ASSERT_EQ( - runtimeStats[SharedArbitrator::kMemoryArbitrationWallNanos].count, 1); - ASSERT_GT(runtimeStats[SharedArbitrator::kMemoryArbitrationWallNanos].sum, 0); + runtimeStats[std::string(SharedArbitrator::kMemoryArbitrationWallNanos)] + .count, + 1); + ASSERT_GT( + runtimeStats[std::string(SharedArbitrator::kMemoryArbitrationWallNanos)] + .sum, + 0); ASSERT_EQ( - runtimeStats[SharedArbitrator::kGlobalArbitrationWaitCount].count, 1); - ASSERT_EQ(runtimeStats[SharedArbitrator::kGlobalArbitrationWaitCount].sum, 1); - ASSERT_EQ(runtimeStats[SharedArbitrator::kLocalArbitrationCount].count, 0); + runtimeStats[std::string(SharedArbitrator::kGlobalArbitrationWaitCount)] + .count, + 1); + ASSERT_EQ( + runtimeStats[std::string(SharedArbitrator::kGlobalArbitrationWaitCount)] + .sum, + 1); + ASSERT_EQ( + runtimeStats[std::string(SharedArbitrator::kLocalArbitrationCount)].count, + 0); ASSERT_TRUE(triggerTask->error() == nullptr); ASSERT_EQ(triggerTask->capacity(), memoryCapacity); @@ -2066,12 +2174,24 @@ TEST_F(MockSharedArbitrationTest, globalArbitrationSmallParticipantLargeGrow) { VELOX_ASSERT_THROW(op0->allocate(kMemoryCapacity / 2), "aborted"); ASSERT_EQ( - runtimeStats[SharedArbitrator::kMemoryArbitrationWallNanos].count, 1); - ASSERT_GT(runtimeStats[SharedArbitrator::kMemoryArbitrationWallNanos].sum, 0); + runtimeStats[std::string(SharedArbitrator::kMemoryArbitrationWallNanos)] + .count, + 1); + ASSERT_GT( + runtimeStats[std::string(SharedArbitrator::kMemoryArbitrationWallNanos)] + .sum, + 0); + ASSERT_EQ( + runtimeStats[std::string(SharedArbitrator::kGlobalArbitrationWaitCount)] + .count, + 1); + ASSERT_EQ( + runtimeStats[std::string(SharedArbitrator::kGlobalArbitrationWaitCount)] + .sum, + 1); ASSERT_EQ( - runtimeStats[SharedArbitrator::kGlobalArbitrationWaitCount].count, 1); - ASSERT_EQ(runtimeStats[SharedArbitrator::kGlobalArbitrationWaitCount].sum, 1); - ASSERT_EQ(runtimeStats[SharedArbitrator::kLocalArbitrationCount].count, 0); + runtimeStats[std::string(SharedArbitrator::kLocalArbitrationCount)].count, + 0); ASSERT_TRUE(task1->error() == nullptr); ASSERT_EQ(task1->capacity(), kMemoryCapacity / 2); @@ -2277,14 +2397,25 @@ DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, multipleGlobalRuns) { setThreadLocalRunTimeStatWriter(statsWriter.get()); waitPool->allocate(memoryCapacity / 2); ASSERT_EQ( - runtimeStats[SharedArbitrator::kMemoryArbitrationWallNanos].count, 1); + runtimeStats[std::string(SharedArbitrator::kMemoryArbitrationWallNanos)] + .count, + 1); ASSERT_GT( - runtimeStats[SharedArbitrator::kMemoryArbitrationWallNanos].sum, 0); + runtimeStats[std::string(SharedArbitrator::kMemoryArbitrationWallNanos)] + .sum, + 0); ASSERT_EQ( - runtimeStats[SharedArbitrator::kGlobalArbitrationWaitCount].count, 1); + runtimeStats[std::string(SharedArbitrator::kGlobalArbitrationWaitCount)] + .count, + 1); ASSERT_EQ( - runtimeStats[SharedArbitrator::kGlobalArbitrationWaitCount].sum, 1); - ASSERT_EQ(runtimeStats[SharedArbitrator::kLocalArbitrationCount].count, 0); + runtimeStats[std::string(SharedArbitrator::kGlobalArbitrationWaitCount)] + .sum, + 1); + ASSERT_EQ( + runtimeStats[std::string(SharedArbitrator::kLocalArbitrationCount)] + .count, + 0); ++allocations; }); @@ -2294,14 +2425,25 @@ DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, multipleGlobalRuns) { setThreadLocalRunTimeStatWriter(statsWriter.get()); runPool->allocate(memoryCapacity / 2); ASSERT_EQ( - runtimeStats[SharedArbitrator::kMemoryArbitrationWallNanos].count, 1); + runtimeStats[std::string(SharedArbitrator::kMemoryArbitrationWallNanos)] + .count, + 1); ASSERT_GT( - runtimeStats[SharedArbitrator::kMemoryArbitrationWallNanos].sum, 0); + runtimeStats[std::string(SharedArbitrator::kMemoryArbitrationWallNanos)] + .sum, + 0); + ASSERT_EQ( + runtimeStats[std::string(SharedArbitrator::kGlobalArbitrationWaitCount)] + .count, + 1); ASSERT_EQ( - runtimeStats[SharedArbitrator::kGlobalArbitrationWaitCount].count, 1); + runtimeStats[std::string(SharedArbitrator::kGlobalArbitrationWaitCount)] + .sum, + 1); ASSERT_EQ( - runtimeStats[SharedArbitrator::kGlobalArbitrationWaitCount].sum, 1); - ASSERT_EQ(runtimeStats[SharedArbitrator::kLocalArbitrationCount].count, 0); + runtimeStats[std::string(SharedArbitrator::kLocalArbitrationCount)] + .count, + 0); ++allocations; }); diff --git a/velox/common/memory/tests/RawVectorTest.cpp b/velox/common/memory/tests/RawVectorTest.cpp index cfd9611b57c..21cbb9bc464 100644 --- a/velox/common/memory/tests/RawVectorTest.cpp +++ b/velox/common/memory/tests/RawVectorTest.cpp @@ -26,6 +26,10 @@ struct TestParam { bool useMemoryPool{false}; }; +inline void PrintTo(const TestParam& param, std::ostream* os) { + *os << (param.useMemoryPool ? "withPool" : "noPool"); +} + class RawVectorTest : public testing::WithParamInterface, public testing::Test { protected: @@ -148,12 +152,18 @@ TEST_P(RawVectorTest, copyAndMove) { } TEST_P(RawVectorTest, iota) { + constexpr int kSizeThreshold = 1'000'000; raw_vector storage = makeRawVector(0, GetParam().useMemoryPool ? pool_.get() : nullptr); // Small sizes are preallocated. EXPECT_EQ(11, iota(12, storage)[11]); + EXPECT_EQ(6, iota(12, storage, 5)[1]); + EXPECT_EQ( + kSizeThreshold - 1, iota(kSizeThreshold, storage)[kSizeThreshold - 1]); EXPECT_TRUE(storage.empty()); - EXPECT_EQ(110000, iota(110001, storage)[110000]); + EXPECT_EQ( + 2 * kSizeThreshold - 1, + iota(2 * kSizeThreshold, storage)[2 * kSizeThreshold - 1]); // Larger sizes are allocated in 'storage'. EXPECT_FALSE(storage.empty()); } @@ -185,6 +195,74 @@ TEST_P(RawVectorTest, toStdVector) { } } +TEST_P(RawVectorTest, shrinkToFit) { + raw_vector data = + makeRawVector(0, GetParam().useMemoryPool ? pool_.get() : nullptr); + + // Empty vector — shrink should be a no-op. + data.shrink_to_fit(); + EXPECT_EQ(0, data.size()); + EXPECT_EQ(0, data.capacity()); + EXPECT_EQ(nullptr, data.data()); + + // Shrink after clear should free all memory. + for (int i = 0; i < 100; ++i) { + data.push_back(i); + } + EXPECT_EQ(100, data.size()); + EXPECT_GE(data.capacity(), 100); + data.clear(); + data.shrink_to_fit(); + EXPECT_EQ(0, data.size()); + EXPECT_EQ(0, data.capacity()); + EXPECT_EQ(nullptr, data.data()); + + // Size equals capacity — shrink should be a no-op preserving data. + for (int i = 0; i < 10; ++i) { + data.push_back(i * 10); + } + const auto exactCapacity = data.capacity(); + data.shrink_to_fit(); + EXPECT_EQ(10, data.size()); + EXPECT_EQ(exactCapacity, data.capacity()); + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(i * 10, data[i]); + } + + // Capacity larger than size — shrink should reduce capacity. + data.reserve(10'000); + const auto largeCapacity = data.capacity(); + EXPECT_GT(largeCapacity, exactCapacity); + data.shrink_to_fit(); + EXPECT_EQ(10, data.size()); + EXPECT_LT(data.capacity(), largeCapacity); + EXPECT_LE(data.capacity(), exactCapacity); + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(i * 10, data[i]); + } + + // Push back after shrink should work correctly. + for (int i = 10; i < 20; ++i) { + data.push_back(i * 10); + } + EXPECT_EQ(20, data.size()); + for (int i = 0; i < 20; ++i) { + EXPECT_EQ(i * 10, data[i]); + } + + // Push back after shrink from empty should work correctly. + data.clear(); + data.shrink_to_fit(); + EXPECT_EQ(nullptr, data.data()); + for (int i = 0; i < 5; ++i) { + data.push_back(i * 100); + } + EXPECT_EQ(5, data.size()); + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(i * 100, data[i]); + } +} + VELOX_INSTANTIATE_TEST_SUITE_P( RawVectorTest, RawVectorTest, diff --git a/velox/common/memory/tests/SharedArbitratorTest.cpp b/velox/common/memory/tests/SharedArbitratorTest.cpp index f01cc043da5..69a84e66dc5 100644 --- a/velox/common/memory/tests/SharedArbitratorTest.cpp +++ b/velox/common/memory/tests/SharedArbitratorTest.cpp @@ -22,7 +22,7 @@ #include #include #include -#include "folly/experimental/EventCount.h" +#include "folly/synchronization/EventCount.h" #include "velox/common/base/Exceptions.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/memory/MallocAllocator.h" @@ -292,23 +292,30 @@ class SharedArbitrationTest : public testing::WithParamInterface, if (expectGlobalArbitration) { VELOX_CHECK_EQ( stats.customStats.count( - SharedArbitrator::kGlobalArbitrationWaitCount), + std::string(SharedArbitrator::kGlobalArbitrationWaitCount)), 1); VELOX_CHECK_GE( - stats.customStats.at(SharedArbitrator::kGlobalArbitrationWaitCount) + stats.customStats + .at(std::string(SharedArbitrator::kGlobalArbitrationWaitCount)) .sum, 1); VELOX_CHECK_EQ( - stats.customStats.count(SharedArbitrator::kLocalArbitrationCount), 0); + stats.customStats.count( + std::string(SharedArbitrator::kLocalArbitrationCount)), + 0); } else { VELOX_CHECK_EQ( - stats.customStats.count(SharedArbitrator::kLocalArbitrationCount), 1); + stats.customStats.count( + std::string(SharedArbitrator::kLocalArbitrationCount)), + 1); VELOX_CHECK_EQ( - stats.customStats.at(SharedArbitrator::kLocalArbitrationCount).sum, + stats.customStats + .at(std::string(SharedArbitrator::kLocalArbitrationCount)) + .sum, 1); VELOX_CHECK_EQ( stats.customStats.count( - SharedArbitrator::kGlobalArbitrationWaitCount), + std::string(SharedArbitrator::kGlobalArbitrationWaitCount)), 0); } } @@ -358,18 +365,19 @@ DEBUG_ONLY_TEST_P( queryCtxStateChecked = true; }))); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); TestScopedSpillInjection scopedSpillInjection(100); core::PlanNodeId aggregationNodeId; newQueryBuilder() .queryCtx(queryCtx) .spillDirectory(spillDirectory->getPath()) .config(core::QueryConfig::kSpillEnabled, "true") - .plan(PlanBuilder() - .values(vectors) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .capturePlanNodeId(aggregationNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .capturePlanNodeId(aggregationNodeId) + .planNode()) .assertResults("SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"); ASSERT_TRUE(queryCtxStateChecked); ASSERT_FALSE(queryCtx->testingUnderArbitration()); @@ -406,7 +414,7 @@ DEBUG_ONLY_TEST_P( }))); std::thread queryThread([&] { - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); core::PlanNodeId aggregationNodeId; auto plan = PlanBuilder() .values(vectors) @@ -485,7 +493,7 @@ DEBUG_ONLY_TEST_P( .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) .planNode(); std::thread spillableThread([&]() { - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); newQueryBuilder(spillPlan) .queryCtx(queryCtx) .spillDirectory(spillDirectory->getPath()) @@ -590,11 +598,12 @@ DEBUG_ONLY_TEST_P(SharedArbitrationTestWithThreadingModes, reclaimToOrderBy) { newQueryBuilder() .queryCtx(orderByQueryCtx) .serialExecution(isSerialExecutionMode_) - .plan(PlanBuilder() - .values(vectors) - .orderBy({"c0 ASC NULLS LAST"}, false) - .capturePlanNodeId(orderByNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .orderBy({"c0 ASC NULLS LAST"}, false) + .capturePlanNodeId(orderByNodeId) + .planNode()) .assertResults("SELECT * FROM tmp ORDER BY c0 ASC NULLS LAST"); auto taskStats = exec::toPlanStats(task->taskStats()); auto& stats = taskStats.at(orderByNodeId); @@ -607,12 +616,13 @@ DEBUG_ONLY_TEST_P(SharedArbitrationTestWithThreadingModes, reclaimToOrderBy) { newQueryBuilder() .queryCtx(fakeMemoryQueryCtx) .serialExecution(isSerialExecutionMode_) - .plan(PlanBuilder() - .values(vectors) - .addNode([&](std::string id, core::PlanNodePtr input) { - return std::make_shared(id, input); - }) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .addNode([&](std::string id, core::PlanNodePtr input) { + return std::make_shared(id, input); + }) + .planNode()) .assertResults("SELECT * FROM tmp"); }); @@ -691,11 +701,12 @@ DEBUG_ONLY_TEST_P( newQueryBuilder() .queryCtx(aggregationQueryCtx) .serialExecution(isSerialExecutionMode_) - .plan(PlanBuilder() - .values(vectors) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .capturePlanNodeId(aggregationNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .capturePlanNodeId(aggregationNodeId) + .planNode()) .assertResults( "SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"); auto taskStats = exec::toPlanStats(task->taskStats()); @@ -709,12 +720,13 @@ DEBUG_ONLY_TEST_P( newQueryBuilder() .queryCtx(fakeMemoryQueryCtx) .serialExecution(isSerialExecutionMode_) - .plan(PlanBuilder() - .values(vectors) - .addNode([&](std::string id, core::PlanNodePtr input) { - return std::make_shared(id, input); - }) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .addNode([&](std::string id, core::PlanNodePtr input) { + return std::make_shared(id, input); + }) + .planNode()) .assertResults("SELECT * FROM tmp"); }); @@ -755,7 +767,7 @@ DEBUG_ONLY_TEST_P( folly::EventCount taskPauseWait; auto taskPauseWaitKey = taskPauseWait.prepareWait(); - const auto fakeAllocationSize = kMemoryCapacity - (32L << 20); + const auto fakeAllocationSize = kMemoryCapacity - (2L << 20); std::atomic injectAllocationOnce{true}; fakeOperatorFactory_->setAllocationCallback([&](Operator* op) { @@ -822,12 +834,13 @@ DEBUG_ONLY_TEST_P( newQueryBuilder() .queryCtx(fakeMemoryQueryCtx) .serialExecution(isSerialExecutionMode_) - .plan(PlanBuilder() - .values(vectors) - .addNode([&](std::string id, core::PlanNodePtr input) { - return std::make_shared(id, input); - }) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .addNode([&](std::string id, core::PlanNodePtr input) { + return std::make_shared(id, input); + }) + .planNode()) .assertResults("SELECT * FROM tmp"); }); @@ -940,7 +953,7 @@ DEBUG_ONLY_TEST_P( }))); const int numDrivers = 1; - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); std::thread queryThread([&]() { VELOX_ASSERT_THROW( newQueryBuilder() @@ -950,12 +963,13 @@ DEBUG_ONLY_TEST_P( .config(core::QueryConfig::kJoinSpillEnabled, "true") .config(core::QueryConfig::kSpillNumPartitionBits, "2") .maxDrivers(numDrivers) - .plan(PlanBuilder() - .values(vectors) - .localPartition({"c0", "c1"}) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .localPartition(std::vector{}) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .localPartition({"c0", "c1"}) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .localPartition(std::vector{}) + .planNode()) .assertResults( "SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"), "Aborted for external error"); @@ -1013,7 +1027,7 @@ DEBUG_ONLY_TEST_P( [&]() { return aggregationAllocationUnblocked.load(); }); }))); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); std::shared_ptr task; std::thread queryThread([&]() { task = newQueryBuilder() @@ -1022,12 +1036,13 @@ DEBUG_ONLY_TEST_P( .config(core::QueryConfig::kSpillEnabled, "true") .config(core::QueryConfig::kJoinSpillEnabled, "true") .config(core::QueryConfig::kSpillNumPartitionBits, "2") - .plan(PlanBuilder() - .values(vectors) - .localPartition({"c0", "c1"}) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .localPartition(std::vector{}) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .localPartition({"c0", "c1"}) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .localPartition(std::vector{}) + .planNode()) .assertResults( "SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"); }); @@ -1083,7 +1098,7 @@ DEBUG_ONLY_TEST_P(SharedArbitrationTestWithThreadingModes, runtimeStats) { values->pool()->free(buffer, fakeAllocationSize); }))); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); const auto outputDirectory = TempDirectoryPath::create(); const auto queryCtx = newQueryCtx(memoryManager_.get(), executor_.get(), memoryCapacity); @@ -1194,23 +1209,25 @@ DEBUG_ONLY_TEST_P( if (sameDriver) { task = newQueryBuilder() .queryCtx(queryCtx) - .plan(PlanBuilder() - .values(vectors) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .capturePlanNodeId(aggregationNodeId) - .localPartition(std::vector{}) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .capturePlanNodeId(aggregationNodeId) + .localPartition(std::vector{}) + .planNode()) .assertResults( "SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"); } else { task = newQueryBuilder() .queryCtx(queryCtx) - .plan(PlanBuilder() - .values(vectors) - .localPartition({"c0", "c1"}) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .capturePlanNodeId(aggregationNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .localPartition({"c0", "c1"}) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .capturePlanNodeId(aggregationNodeId) + .planNode()) .assertResults( "SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"); } @@ -1369,6 +1386,7 @@ TEST_P( if (e.errorCode() != error_code::kMemCapExceeded.c_str() && e.errorCode() != error_code::kMemAborted.c_str() && e.errorCode() != error_code::kMemAllocError.c_str() && + e.errorCode() != error_code::kMemArbitrationTimeout.c_str() && (e.message() != "Aborted for external error")) { std::rethrow_exception(std::current_exception()); } @@ -1424,12 +1442,20 @@ TEST_P(SharedArbitrationTestWithThreadingModes, reserveReleaseCounters) { VELOX_INSTANTIATE_TEST_SUITE_P( SharedArbitrationTest, SharedArbitrationTestWithParallelExecutionModeOnly, - testing::ValuesIn(std::vector{{false}})); + testing::ValuesIn(std::vector{{false}}), + [](const testing::TestParamInfo& info) { + return fmt::format( + "{}", info.param.isSerialExecutionMode ? "serial" : "parallel"); + }); VELOX_INSTANTIATE_TEST_SUITE_P( SharedArbitrationTest, SharedArbitrationTestWithThreadingModes, - testing::ValuesIn(std::vector{{false}, {true}})); + testing::ValuesIn(std::vector{{false}, {true}}), + [](const testing::TestParamInfo& info) { + return fmt::format( + "{}", info.param.isSerialExecutionMode ? "serial" : "parallel"); + }); } // namespace facebook::velox::memory int main(int argc, char** argv) { diff --git a/velox/common/memory/tests/StreamArenaTest.cpp b/velox/common/memory/tests/StreamArenaTest.cpp index d8db86a468b..9caa06a6ef3 100644 --- a/velox/common/memory/tests/StreamArenaTest.cpp +++ b/velox/common/memory/tests/StreamArenaTest.cpp @@ -127,9 +127,11 @@ TEST_F(StreamArenaTest, randomRange) { ByteRange range; for (int i = 0; i < numRanges; ++i) { if (folly::Random::oneIn(3)) { - const int requestSize = - AllocationTraits::pageBytes(pool_->largestSizeClass()) + - (folly::Random::rand32() % (4 << 20)); + const int requestSize = std::min( + static_cast( + AllocationTraits::pageBytes(pool_->largestSizeClass()) + + (folly::Random::rand32() % (4 << 20))), + 2 << 20); arena->newRange(requestSize, nullptr, &range); ASSERT_EQ(AllocationTraits::roundUpPageBytes(requestSize), range.size); } else { diff --git a/velox/common/process/CMakeLists.txt b/velox/common/process/CMakeLists.txt index 2fb856caa54..2c4dc583a53 100644 --- a/velox/common/process/CMakeLists.txt +++ b/velox/common/process/CMakeLists.txt @@ -19,6 +19,13 @@ velox_add_library( ThreadDebugInfo.cpp TraceContext.cpp TraceHistory.cpp + HEADERS + ProcessBase.h + StackTrace.h + ThreadDebugInfo.h + ThreadLocalRegistry.h + TraceContext.h + TraceHistory.h ) velox_link_libraries( @@ -29,6 +36,7 @@ velox_link_libraries( # Profiler need not be part of the core Velox library add_library(velox_profiler OBJECT Profiler.cpp) +velox_add_test_headers(velox_profiler Profiler.h) target_link_libraries( velox_profiler PUBLIC velox_flag_definitions Folly::folly diff --git a/velox/common/process/Profiler.cpp b/velox/common/process/Profiler.cpp index 82a5e836526..012c3ce3524 100644 --- a/velox/common/process/Profiler.cpp +++ b/velox/common/process/Profiler.cpp @@ -153,12 +153,13 @@ void Profiler::copyToResult(const std::string* data) { auto now = nowSeconds(); auto elapsed = (now - sampleStartTime_); auto cpu = cpuSeconds(); - out->append(fmt::format( - "Profile from {} to {} at {}% CPU\n\n", + out->append( + fmt::format( + "Profile from {} to {} at {}% CPU\n\n", - timeString(sampleStartTime_), - timeString(now), - 100 * (cpu - cpuAtSampleStart_) / std::max(1, elapsed))); + timeString(sampleStartTime_), + timeString(now), + 100 * (cpu - cpuAtSampleStart_) / std::max(1, elapsed))); out->append(std::string_view(buffer, resultSize)); if (extraReport_) { std::string extra = extraReport_(); @@ -191,18 +192,19 @@ std::thread Profiler::startSample() { // and killing it with SIGINT produces a corrupt perf.data // file. The perf.data file generated when called via system() is // good, though. Unsolved mystery. - system(fmt::format( - "(cd {}; /usr/bin/perf record --pid {} {};" - "perf report --sort symbol > perf ;" - "sed --in-place 's/ / /'g perf;" - "sed --in-place 's/ / /'g perf; date) " - ">> {}/perftrace 2>>{}/perftrace2", - FLAGS_profiler_tmp_dir, - getpid(), - FLAGS_profiler_perf_flags, - FLAGS_profiler_tmp_dir, - FLAGS_profiler_tmp_dir) - .c_str()); // NOLINT + system( + fmt::format( + "(cd {}; /usr/bin/perf record --pid {} {};" + "perf report --sort symbol > perf ;" + "sed --in-place 's/ / /'g perf;" + "sed --in-place 's/ / /'g perf; date) " + ">> {}/perftrace 2>>{}/perftrace2", + FLAGS_profiler_tmp_dir, + getpid(), + FLAGS_profiler_perf_flags, + FLAGS_profiler_tmp_dir, + FLAGS_profiler_tmp_dir) + .c_str()); // NOLINT if (shouldSaveResult_) { copyToResult(); } diff --git a/velox/common/process/StackTrace.cpp b/velox/common/process/StackTrace.cpp index 05da291a4c9..30738249985 100644 --- a/velox/common/process/StackTrace.cpp +++ b/velox/common/process/StackTrace.cpp @@ -30,12 +30,12 @@ #include #include #include -#include +#include #include "velox/common/process/ProcessBase.h" #ifdef __linux__ -#include // @manual +#include // @manual #include // @manual #endif @@ -96,7 +96,7 @@ const std::vector& StackTrace::toStrVector() const { btVector_.reserve(btPtrs_.size()); for (auto ptr : btPtrs_) { auto framename = translateFrame(ptr); - if (folly::StringPiece(framename).startsWith(*myname)) { + if (framename.starts_with(*myname)) { continue; // ignore frames in the StackTrace class } btVector_.push_back(fmt::format("# {:<2d} {}", frame++, framename)); diff --git a/velox/common/process/ThreadDebugInfo.cpp b/velox/common/process/ThreadDebugInfo.cpp index ae681a2dcc4..2502572c7c3 100644 --- a/velox/common/process/ThreadDebugInfo.cpp +++ b/velox/common/process/ThreadDebugInfo.cpp @@ -16,7 +16,7 @@ #include "velox/common/process/ThreadDebugInfo.h" -#include +#include #include namespace facebook::velox::process { diff --git a/velox/common/rpc/CMakeLists.txt b/velox/common/rpc/CMakeLists.txt new file mode 100644 index 00000000000..c19c3927ac1 --- /dev/null +++ b/velox/common/rpc/CMakeLists.txt @@ -0,0 +1,31 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +velox_install_library_headers() + +velox_add_library(velox_rpc_types INTERFACE RPCTypes.h) + +velox_add_library(velox_rpc_client INTERFACE IRPCClient.h) + +velox_link_libraries(velox_rpc_client INTERFACE velox_rpc_types Folly::folly) + +velox_add_library(velox_mock_rpc_client clients/MockRPCClient.cpp HEADERS clients/MockRPCClient.h) + +velox_link_libraries( + velox_mock_rpc_client + velox_rpc_client + velox_common_base + velox_exception + Folly::folly +) diff --git a/velox/common/rpc/IRPCClient.h b/velox/common/rpc/IRPCClient.h new file mode 100644 index 00000000000..02e2666fb32 --- /dev/null +++ b/velox/common/rpc/IRPCClient.h @@ -0,0 +1,105 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include + +#include "velox/common/rpc/RPCTypes.h" + +namespace facebook::velox::core { +class QueryConfig; +} // namespace facebook::velox::core + +namespace facebook::velox::rpc { + +/// Interface for RPC clients (transport layer). +/// +/// IRPCClient is concerned with how to send requests and receive responses +/// over the network. It is decoupled from the business logic — domain-specific +/// request/response formatting is handled by AsyncRPCFunction (in +/// velox/expression/rpc/). +/// +/// Implementations provide the actual transport (e.g., Thrift, gRPC, mock). +/// +/// Thread safety: Implementations MUST be thread-safe for concurrent calls. +/// The operator may dispatch multiple RPCs concurrently from a single thread, +/// and completion callbacks run on the client's executor threads. +class IRPCClient { + public: + virtual ~IRPCClient() = default; + + /// Execute a single RPC call asynchronously. + /// @param request The request to send. + /// @return A SemiFuture that will contain the response when complete. + virtual folly::SemiFuture call(const RPCRequest& request) = 0; + + /// Execute a batch of RPC calls as a single request. + /// Default implementation fans out to individual call()s. + /// Override for backends that support native batching (e.g., batch + /// inference). + /// @param requests The batch of requests to send. + /// @return A SemiFuture that will contain all responses when complete. + virtual folly::SemiFuture> callBatch( + const std::vector& requests) { + std::vector> futures; + futures.reserve(requests.size()); + for (const auto& request : requests) { + futures.push_back(call(request)); + } + // Capture rowIds to preserve them in error responses. + std::vector rowIds; + rowIds.reserve(requests.size()); + for (const auto& request : requests) { + rowIds.push_back(request.rowId); + } + return folly::collectAll(std::move(futures)) + .deferValue([rowIds = std::move(rowIds)]( + std::vector> tries) { + std::vector responses; + responses.reserve(tries.size()); + for (size_t i = 0; i < tries.size(); ++i) { + if (tries[i].hasValue()) { + responses.push_back(std::move(tries[i].value())); + } else { + RPCResponse errorResp; + errorResp.rowId = rowIds[i]; + errorResp.error = tries[i].exception().what().toStdString(); + responses.push_back(std::move(errorResp)); + } + } + return responses; + }); + } + + /// Returns the service tier key for rate limiting (e.g., + /// "service.backend.prod"). Requests from clients sharing the same tier key + /// share a concurrency budget in RPCRateLimiter. + /// Empty string means "no tier configured" — uses the global default limit. + virtual std::string tierKey() const { + return ""; + } + + /// Set the query config for session-level parameters. + /// Called by the operator before dispatching RPCs. + virtual void setQueryConfig(const core::QueryConfig* /*config*/) {} +}; + +} // namespace facebook::velox::rpc diff --git a/velox/common/rpc/RPCTypes.h b/velox/common/rpc/RPCTypes.h new file mode 100644 index 00000000000..f2ffbfccccc --- /dev/null +++ b/velox/common/rpc/RPCTypes.h @@ -0,0 +1,118 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#include "velox/vector/TypeAliases.h" + +namespace facebook::velox::rpc { + +/// Well-known option key constants for RPCRequest.options. +/// Use these instead of raw string literals to prevent typo bugs. +namespace keys { +inline constexpr std::string_view kModel = "model"; +inline constexpr std::string_view kTemperature = "temperature"; +inline constexpr std::string_view kMaxTokens = "max_tokens"; +inline constexpr std::string_view kSystemPrompt = "systemPrompt"; +inline constexpr std::string_view kJsonSchema = "json_schema"; +inline constexpr std::string_view kMetagenKey = "metagen_key"; +inline constexpr std::string_view kTierOverride = "tier_override"; +inline constexpr std::string_view kCatToken = "cat_token"; +inline constexpr std::string_view kPollIntervalMs = "poll_interval_ms"; +inline constexpr std::string_view kOwnerUnixname = "owner_unixname"; +inline constexpr std::string_view kIsQuery = "is_query"; +inline constexpr std::string_view kPrefixDim = "prefix_dim"; +} // namespace keys + +/// Streaming mode for RPC execution. +/// Controls how RPC results are emitted to downstream operators. +enum class RPCStreamingMode { + /// Emit rows as they complete individually (default). + /// Lower tail latency for high-variance workloads (e.g., LLM). + kPerRow, + + /// Wait for all rows in batch before emitting. + /// Lower overhead, useful for uniform-latency workloads. + kBatch +}; + +/// Parse streaming mode from config string. +/// Returns kPerRow (default) unless explicitly set to "batch". +inline RPCStreamingMode parseStreamingMode(const std::string& value) { + if (value == "batch") { + return RPCStreamingMode::kBatch; + } + return RPCStreamingMode::kPerRow; +} + +/// Generic request structure for RPC calls. +/// This is a minimal, domain-agnostic structure that works for any backend. +/// Domain-specific formatting (e.g., LLM prompts, embedding inputs) is handled +/// by the plan node's buildRequests() method. +struct RPCRequest { + /// Row ID for tracking which row this request belongs to. + /// This is a globally unique ID assigned by the operator. + int64_t rowId{0}; + + /// Original row index in the input batch. + /// This is used to slice the correct row from input columns when storing + /// passthrough data. Unlike rowId (which is globally unique across batches), + /// this is the index within the current input batch and is set by + /// prepareRequests() based on the SelectivityVector iteration. + /// CRITICAL: When prepareRequests() skips null rows, originalRowIndex + /// tracks the actual input position to avoid slicing mismatch. + vector_size_t originalRowIndex{0}; + + /// Whether this row has a null primary input. + /// When true, the transport should short-circuit and return an error + /// response so that buildOutput() produces SQL NULL for this row. + /// Replaces the former "__null_input" magic string in options. + bool isNull{false}; + + /// The request payload (opaque to the framework). + std::string payload; + + /// Type-safe options for backend-specific parameters. + std::map options; +}; + +/// Generic response structure from RPC calls. +/// This is a minimal, domain-agnostic structure that works for any backend. +struct RPCResponse { + /// Row ID for correlating response with the original request. + int64_t rowId{0}; + + /// The response result (opaque to the framework). + std::string result; + + /// Type-safe metadata from the backend. + std::map metadata; + + /// Error message if the request failed. + std::optional error; + + /// Returns true if this response represents an error. + bool hasError() const { + return error.has_value(); + } +}; + +} // namespace facebook::velox::rpc diff --git a/velox/common/rpc/clients/MockRPCClient.cpp b/velox/common/rpc/clients/MockRPCClient.cpp new file mode 100644 index 00000000000..2a266a2c686 --- /dev/null +++ b/velox/common/rpc/clients/MockRPCClient.cpp @@ -0,0 +1,124 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/rpc/clients/MockRPCClient.h" + +#include +#include + +namespace facebook::velox::rpc { + +namespace { +// Function-local static pattern for thread-local RNG to avoid +// NonPodStaticDeclaration lint warning. +std::mt19937& threadLocalRng() { + thread_local std::mt19937 rng{std::random_device{}()}; + return rng; +} +} // namespace + +MockRPCClient::MockRPCClient( + std::chrono::milliseconds latency, + double errorRate, + std::shared_ptr executor) + : latency_(latency), errorRate_(errorRate) { + if (executor) { + executor_ = std::move(executor); + } else { + ownedExecutor_ = std::make_shared(4); + executor_ = ownedExecutor_; + } +} + +MockRPCClient::~MockRPCClient() = default; + +RPCResponse MockRPCClient::generateResponse( + const RPCRequest& request, + bool isError) { + if (isError) { + return RPCResponse{ + .rowId = request.rowId, + .result = "", + .metadata = {}, + .error = "Simulated error for row " + std::to_string(request.rowId)}; + } + + // Generate a mock response + std::string responseText = "Response for: "; + if (request.payload.size() > 30) { + responseText += request.payload.substr(0, 30) + "..."; + } else { + responseText += request.payload; + } + + return RPCResponse{ + .rowId = request.rowId, + .result = std::move(responseText), + .metadata = {}, + .error = std::nullopt}; +} + +folly::SemiFuture MockRPCClient::call(const RPCRequest& request) { + callCount_.fetch_add(1); + + // Determine if this request should fail + std::uniform_real_distribution dist(0.0, 1.0); + bool shouldError = dist(threadLocalRng()) < errorRate_; + + // Use folly::via with the thread pool executor for safe async execution + return folly::via( + executor_.get(), + [this, request = request, shouldError, latency = latency_]() + -> RPCResponse { + // Simulate network latency + /* sleep override */ std::this_thread::sleep_for(latency); + // Generate and return the response + return generateResponse(request, shouldError); + }); +} + +folly::SemiFuture> MockRPCClient::callBatch( + const std::vector& requests) { + // Capture error rate for thread safety + double errorRate = errorRate_; + + // Use folly::via with the thread pool executor for safe async execution + return folly::via( + executor_.get(), + [this, requests, errorRate, latency = latency_]() + -> std::vector { + // Simulate network latency (single batch = single latency) + /* sleep override */ std::this_thread::sleep_for(latency); + + std::vector responses; + responses.reserve(requests.size()); + + // Create RNG inside lambda to avoid thread-local access issues. + // Each executor thread will have its own properly initialized RNG. + thread_local std::mt19937 localRng{std::random_device{}()}; + std::uniform_real_distribution dist(0.0, 1.0); + + for (const auto& request : requests) { + callCount_.fetch_add(1); + bool shouldError = dist(localRng) < errorRate; + responses.push_back(generateResponse(request, shouldError)); + } + + return responses; + }); +} + +} // namespace facebook::velox::rpc diff --git a/velox/common/rpc/clients/MockRPCClient.h b/velox/common/rpc/clients/MockRPCClient.h new file mode 100644 index 00000000000..417196701cb --- /dev/null +++ b/velox/common/rpc/clients/MockRPCClient.h @@ -0,0 +1,76 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include + +#include "velox/common/rpc/IRPCClient.h" + +namespace facebook::velox::rpc { + +/// Mock RPC client that simulates backend latency for testing. +/// Thread-safe for concurrent use. Uses a thread pool executor for async +/// execution — either a shared executor passed in, or a local one created +/// per-client. +class MockRPCClient : public IRPCClient { + public: + /// Creates a mock client with configurable latency and error rate. + /// @param latency Simulated RPC latency (default 200ms). + /// @param errorRate Probability of error per request (0.0-1.0, default 0). + /// @param executor Shared executor for async work. If nullptr, creates a + /// local thread pool. Pass a shared executor for global throttling across + /// query instances. + explicit MockRPCClient( + std::chrono::milliseconds latency = std::chrono::milliseconds(200), + double errorRate = 0.0, + std::shared_ptr executor = nullptr); + + ~MockRPCClient() override; + + folly::SemiFuture call(const RPCRequest& request) override; + + folly::SemiFuture> callBatch( + const std::vector& requests) override; + + /// Returns the total number of RPC calls made. + int64_t callCount() const { + return callCount_.load(); + } + + /// Resets the call counter. + void resetCallCount() { + callCount_.store(0); + } + + private: + RPCResponse generateResponse(const RPCRequest& request, bool isError); + + const std::chrono::milliseconds latency_; + const double errorRate_; + std::atomic callCount_{0}; + + /// Shared executor (may be shared across clients for global throttling). + std::shared_ptr executor_; + /// Locally-owned executor (created when no shared executor is provided). + std::shared_ptr ownedExecutor_; +}; + +} // namespace facebook::velox::rpc diff --git a/velox/common/serialization/CMakeLists.txt b/velox/common/serialization/CMakeLists.txt index 772987a61f3..9aad6dc4b26 100644 --- a/velox/common/serialization/CMakeLists.txt +++ b/velox/common/serialization/CMakeLists.txt @@ -12,7 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -velox_add_library(velox_serialization DeserializationRegistry.cpp) +velox_add_library( + velox_serialization + DeserializationRegistry.cpp + HEADERS + DeserializationRegistry.h + Registry.h + Serializable.h +) velox_link_libraries(velox_serialization PUBLIC velox_exception Folly::folly glog::glog) diff --git a/velox/common/serialization/DeserializationRegistry.h b/velox/common/serialization/DeserializationRegistry.h index 78cbf41a9e2..c96b499ed52 100644 --- a/velox/common/serialization/DeserializationRegistry.h +++ b/velox/common/serialization/DeserializationRegistry.h @@ -53,6 +53,16 @@ struct is_templated_create< T, std::void_t( std::declval()))>> : std::true_type {}; + +template +struct is_templated_create_with_context : std::false_type {}; + +template +struct is_templated_create_with_context< + T, + std::void_t( + std::declval(), + std::declval()))>> : std::true_type {}; } // namespace detail template @@ -66,5 +76,16 @@ void registerDeserializer() { } } +template +void registerDeserializerWithContext() { + if constexpr (detail::is_templated_create_with_context::value) { + DeserializationWithContextRegistryForSharedPtr().Register( + T::getClassName(), T::template create); + } else { + DeserializationWithContextRegistryForSharedPtr().Register( + T::getClassName(), T::create); + } +} + } // namespace velox } // namespace facebook diff --git a/velox/common/serialization/Registry.h b/velox/common/serialization/Registry.h index a290dd7ecd9..7029a2eda1f 100644 --- a/velox/common/serialization/Registry.h +++ b/velox/common/serialization/Registry.h @@ -114,7 +114,7 @@ class Registry { } VELOX_UNSUPPORTED( - typeid(ReturnType).name(), " is not nullable return type"); + "{} is not nullable return type", typeid(ReturnType).name()); } return it->second(types...); } diff --git a/velox/common/serialization/Serializable.h b/velox/common/serialization/Serializable.h index 3db8daf1976..148d12ae2c7 100644 --- a/velox/common/serialization/Serializable.h +++ b/velox/common/serialization/Serializable.h @@ -298,8 +298,7 @@ class ISerializable { } template < - template - typename TMap, + template typename TMap, typename TKey, typename TMapped, typename... TArgs, diff --git a/velox/common/serialization/tests/SerializableTest.cpp b/velox/common/serialization/tests/SerializableTest.cpp index 90b60852b9b..a9692a8da16 100644 --- a/velox/common/serialization/tests/SerializableTest.cpp +++ b/velox/common/serialization/tests/SerializableTest.cpp @@ -160,8 +160,7 @@ TEST(SerializableTest, context) { } template < - template - typename TMap, + template typename TMap, typename TKey, typename TMapped, typename TIt, diff --git a/velox/functions/prestosql/geospatial/tests/CMakeLists.txt b/velox/common/tests/CMakeLists.txt similarity index 77% rename from velox/functions/prestosql/geospatial/tests/CMakeLists.txt rename to velox/common/tests/CMakeLists.txt index b2b460f3acc..a36ecf629bc 100644 --- a/velox/functions/prestosql/geospatial/tests/CMakeLists.txt +++ b/velox/common/tests/CMakeLists.txt @@ -11,17 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -if(VELOX_ENABLE_GEO) - add_executable(velox_geospatial_test GeometryFunctionsTest.cpp) -endif() - -add_test(velox_geospatial_test velox_geospatial_test) +add_executable(velox_scoped_registry_test ScopedRegistryTest.cpp) +add_test(velox_scoped_registry_test velox_scoped_registry_test) target_link_libraries( - velox_geospatial_test + velox_scoped_registry_test + velox_common_base + GTest::gmock GTest::gtest GTest::gtest_main - GTest::gmock - GTest::gmock_main + Folly::folly ) diff --git a/velox/common/tests/ScopedRegistryTest.cpp b/velox/common/tests/ScopedRegistryTest.cpp new file mode 100644 index 00000000000..2844fffdd67 --- /dev/null +++ b/velox/common/tests/ScopedRegistryTest.cpp @@ -0,0 +1,165 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/ScopedRegistry.h" + +#include +#include +#include +#include +#include + +#include + +namespace facebook::velox { +namespace { + +// Minimal value type for testing. +class TestEntry { + public: + explicit TestEntry(std::string name) : name_{std::move(name)} {} + + const std::string& name() const { + return name_; + } + + private: + std::string name_; +}; + +TEST(ScopedRegistryTest, insertAndFind) { + ScopedRegistry registry; + auto entry = std::make_shared("test-1"); + registry.insert("test-1", entry); + + EXPECT_EQ(registry.find("test-1"), entry); + EXPECT_EQ(registry.find("nonexistent"), nullptr); +} + +TEST(ScopedRegistryTest, insertDuplicateThrows) { + ScopedRegistry registry; + registry.insert("key", std::make_shared("key")); + + EXPECT_THROW( + registry.insert("key", std::make_shared("key")), + VeloxRuntimeError); +} + +TEST(ScopedRegistryTest, insertOverwrite) { + ScopedRegistry registry; + auto first = std::make_shared("first"); + auto second = std::make_shared("second"); + registry.insert("key", first); + registry.insert("key", second, /*overwrite=*/true); + + EXPECT_EQ(registry.find("key"), second); +} + +TEST(ScopedRegistryTest, erase) { + ScopedRegistry registry; + registry.insert("key", std::make_shared("key")); + + EXPECT_TRUE(registry.erase("key")); + EXPECT_EQ(registry.find("key"), nullptr); + EXPECT_FALSE(registry.erase("key")); +} + +TEST(ScopedRegistryTest, clear) { + ScopedRegistry registry; + registry.insert("a", std::make_shared("a")); + registry.insert("b", std::make_shared("b")); + registry.clear(); + + EXPECT_EQ(registry.find("a"), nullptr); + EXPECT_EQ(registry.find("b"), nullptr); +} + +TEST(ScopedRegistryTest, snapshot) { + ScopedRegistry registry; + registry.insert("a", std::make_shared("a")); + registry.insert("b", std::make_shared("b")); + + auto entries = registry.snapshot(); + EXPECT_EQ(entries.size(), 2); + + std::set keys; + for (const auto& [key, _] : entries) { + keys.insert(key); + } + EXPECT_EQ(keys.size(), 2); + EXPECT_TRUE(keys.count("a")); + EXPECT_TRUE(keys.count("b")); +} + +TEST(ScopedRegistryTest, parentFallback) { + ScopedRegistry parent; + auto entry = std::make_shared("from-parent"); + parent.insert("key", entry); + + const ScopedRegistry child(&parent); + EXPECT_EQ(child.find("key"), entry); +} + +TEST(ScopedRegistryTest, childOverridesParent) { + ScopedRegistry parent; + auto parentEntry = std::make_shared("parent"); + parent.insert("key", parentEntry); + + ScopedRegistry child(&parent); + auto childEntry = std::make_shared("child"); + child.insert("key", childEntry); + + EXPECT_EQ(child.find("key"), childEntry); + EXPECT_EQ(parent.find("key"), parentEntry); +} + +TEST(ScopedRegistryTest, childEraseDoesNotAffectParent) { + ScopedRegistry parent; + auto entry = std::make_shared("parent"); + parent.insert("key", entry); + + ScopedRegistry child(&parent); + child.insert("key", std::make_shared("child")); + child.erase("key"); + + // Child erased its own override; parent entry is still visible via fallback. + EXPECT_EQ(child.find("key"), entry); + EXPECT_EQ(parent.find("key"), entry); +} + +TEST(ScopedRegistryTest, snapshotMergesParent) { + ScopedRegistry parent; + parent.insert("a", std::make_shared("a")); + parent.insert("b", std::make_shared("b-parent")); + + ScopedRegistry child(&parent); + child.insert("b", std::make_shared("b-child")); + child.insert("c", std::make_shared("c")); + + auto entries = child.snapshot(); + EXPECT_EQ(entries.size(), 3); + + std::map snapshot; + for (const auto& [key, value] : entries) { + snapshot[key] = value->name(); + } + EXPECT_EQ(snapshot["a"], "a"); + EXPECT_EQ(snapshot["b"], "b-child"); + EXPECT_EQ(snapshot["c"], "c"); +} + +} // namespace +} // namespace facebook::velox diff --git a/velox/common/testutil/CMakeLists.txt b/velox/common/testutil/CMakeLists.txt index 5b5dca326d9..efe4631817c 100644 --- a/velox/common/testutil/CMakeLists.txt +++ b/velox/common/testutil/CMakeLists.txt @@ -12,9 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -velox_add_library(velox_test_util ScopedTestTime.cpp TestValue.cpp RandomSeed.cpp) +velox_add_library( + velox_test_util + ScopedTestTime.cpp + TestValue.cpp + RandomSeed.cpp + TempFilePath.cpp + TempDirectoryPath.cpp + HEADERS + OptionalEmpty.h + RandomSeed.h + ScopedTestTime.h + TempDirectoryPath.h + TempFilePath.h + TestValue.h +) -velox_link_libraries(velox_test_util PUBLIC velox_exception PRIVATE glog::glog Folly::folly) +velox_link_libraries(velox_test_util PUBLIC velox_exception Folly::folly PRIVATE glog::glog) if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) diff --git a/velox/exec/tests/utils/TempDirectoryPath.cpp b/velox/common/testutil/TempDirectoryPath.cpp similarity index 90% rename from velox/exec/tests/utils/TempDirectoryPath.cpp rename to velox/common/testutil/TempDirectoryPath.cpp index 358aa322964..5e63ac777df 100644 --- a/velox/exec/tests/utils/TempDirectoryPath.cpp +++ b/velox/common/testutil/TempDirectoryPath.cpp @@ -14,11 +14,11 @@ * limitations under the License. */ -#include "velox/exec/tests/utils/TempDirectoryPath.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "boost/filesystem.hpp" -namespace facebook::velox::exec::test { +namespace facebook::velox::common::testutil { std::shared_ptr TempDirectoryPath::create(bool injectFault) { auto* tempDirPath = new TempDirectoryPath(injectFault); @@ -42,4 +42,4 @@ std::string TempDirectoryPath::createTempDirectory() { return tempDirectoryPath; } -} // namespace facebook::velox::exec::test +} // namespace facebook::velox::common::testutil diff --git a/velox/common/testutil/TempDirectoryPath.h b/velox/common/testutil/TempDirectoryPath.h new file mode 100644 index 00000000000..fb2476a8a6b --- /dev/null +++ b/velox/common/testutil/TempDirectoryPath.h @@ -0,0 +1,67 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +#include "velox/common/base/Exceptions.h" + +namespace facebook::velox::common::testutil { + +/// Manages the lifetime of a temporary directory. +class TempDirectoryPath { + public: + /// If 'enableFaultInjection' is true, we enable fault injection on the + /// created file directory. + static std::shared_ptr create( + bool enableFaultInjection = false); + + virtual ~TempDirectoryPath(); + + TempDirectoryPath(const TempDirectoryPath&) = delete; + TempDirectoryPath& operator=(const TempDirectoryPath&) = delete; + + /// If fault injection is enabled, the returned file path will have the faulty + /// file system prefix scheme. The velox fs then opens the directory through + /// the faulty file system. The file operation will then either fail or be + /// delegated to the actual file. + const std::string& getPath() const { + return path_; + } + + /// The actual file path if fault injection is enabled. + const std::string& getDelegatePath() const { + return tempPath_; + } + + private: + static std::string createTempDirectory(); + + explicit TempDirectoryPath(bool enableFaultInjection) + : enableFaultInjection_(enableFaultInjection), + tempPath_(createTempDirectory()), + path_( + enableFaultInjection_ ? fmt::format("faulty:{}", tempPath_) + : tempPath_) {} + + const bool enableFaultInjection_{false}; + const std::string tempPath_; + const std::string path_; +}; +} // namespace facebook::velox::common::testutil diff --git a/velox/exec/tests/utils/TempFilePath.cpp b/velox/common/testutil/TempFilePath.cpp similarity index 88% rename from velox/exec/tests/utils/TempFilePath.cpp rename to velox/common/testutil/TempFilePath.cpp index 263abce930e..bb3ddd02edb 100644 --- a/velox/exec/tests/utils/TempFilePath.cpp +++ b/velox/common/testutil/TempFilePath.cpp @@ -14,13 +14,15 @@ * limitations under the License. */ -#include "velox/exec/tests/utils/TempFilePath.h" +#include "velox/common/testutil/TempFilePath.h" -namespace facebook::velox::exec::test { +namespace facebook::velox::common::testutil { TempFilePath::~TempFilePath() { ::unlink(tempPath_.c_str()); - ::close(fd_); + if (fd_ != -1) { + ::close(fd_); + } } std::shared_ptr TempFilePath::create(bool enableFaultInjection) { @@ -46,4 +48,4 @@ std::vector toFilePaths( } return filePaths; } -} // namespace facebook::velox::exec::test +} // namespace facebook::velox::common::testutil diff --git a/velox/common/testutil/TempFilePath.h b/velox/common/testutil/TempFilePath.h new file mode 100644 index 00000000000..17a19be09ea --- /dev/null +++ b/velox/common/testutil/TempFilePath.h @@ -0,0 +1,97 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "velox/common/base/Exceptions.h" + +namespace facebook::velox::common::testutil { + +/// Manages the lifetime of a temporary file. +class TempFilePath { + public: + /// If 'enableFaultInjection' is true, we enable fault injection on the + /// created file. + static std::shared_ptr create( + bool enableFaultInjection = false); + + ~TempFilePath(); + + TempFilePath(const TempFilePath&) = delete; + TempFilePath& operator=(const TempFilePath&) = delete; + + void append(const std::string& data) { + std::ofstream file(tempPath_, std::ios_base::app); + file << data; + file.flush(); + file.close(); + } + + int64_t fileSize() { + struct stat st{}; + ::stat(tempPath_.data(), &st); + return st.st_size; + } + + int64_t fileModifiedTime() { + struct stat st{}; + ::stat(tempPath_.data(), &st); + return st.st_mtime; + } + + /// If fault injection is enabled, the returned the file path has the faulty + /// file system prefix scheme. The velox fs then opens the file through the + /// faulty file system. The actual file operation might either fails or + /// delegate to the actual file. + const std::string& getPath() const { + return path_; + } + + // Returns the delegated file path if fault injection is enabled. + const std::string& tempFilePath() const { + return tempPath_; + } + + private: + static std::string createTempFile(TempFilePath* tempFilePath); + + explicit TempFilePath(bool enableFaultInjection) + : enableFaultInjection_(enableFaultInjection), + tempPath_(createTempFile(this)), + path_( + enableFaultInjection_ ? fmt::format("faulty:{}", tempPath_) + : tempPath_) { + VELOX_CHECK_NE(fd_, -1); + } + + const bool enableFaultInjection_; + // initialized to -1 before tempPath_ is initialized, + // since members are initialized in declaration order + int fd_{-1}; + const std::string tempPath_; + const std::string path_; +}; + +std::vector toFilePaths( + const std::vector>& tempFiles); + +} // namespace facebook::velox::common::testutil diff --git a/velox/common/testutil/TestValue.cpp b/velox/common/testutil/TestValue.cpp index b833b980ac4..bf14f66d807 100644 --- a/velox/common/testutil/TestValue.cpp +++ b/velox/common/testutil/TestValue.cpp @@ -20,7 +20,7 @@ namespace facebook::velox::common::testutil { std::mutex TestValue::mutex_; bool TestValue::enabled_ = false; -std::unordered_map TestValue::injectionMap_; +folly::F14FastMap TestValue::injectionMap_; #ifndef NDEBUG void TestValue::enable() { @@ -38,12 +38,12 @@ bool TestValue::enabled() { return enabled_; } -void TestValue::clear(const std::string& injectionPoint) { +void TestValue::clear(std::string_view injectionPoint) { std::lock_guard l(mutex_); injectionMap_.erase(injectionPoint); } -void TestValue::adjust(const std::string& injectionPoint, void* testData) { +void TestValue::adjust(std::string_view injectionPoint, void* testData) { Callback injectionCb; { std::lock_guard l(mutex_); @@ -60,7 +60,7 @@ void TestValue::disable() {} bool TestValue::enabled() { return false; } -void TestValue::clear(const std::string& injectionPoint) {} +void TestValue::clear(std::string_view injectionPoint) {} #endif } // namespace facebook::velox::common::testutil diff --git a/velox/common/testutil/TestValue.h b/velox/common/testutil/TestValue.h index ba988a6935b..321936bd67c 100644 --- a/velox/common/testutil/TestValue.h +++ b/velox/common/testutil/TestValue.h @@ -15,9 +15,10 @@ */ #pragma once +#include #include #include -#include +#include #include "velox/common/base/Exceptions.h" #include "velox/common/base/Macros.h" @@ -45,24 +46,24 @@ class TestValue { /// injected callback hook. template static void set( - const std::string& injectionPoint, + std::string_view injectionPoint, std::function injectionCb); /// Invoked by the test code to unregister a callback hook at the specified /// execution point. - static void clear(const std::string& injectionPoint); + static void clear(std::string_view injectionPoint); /// Invoked by the production code to try to invoke the test callback hook /// with 'testData' if there is one registered at the specified execution /// point. 'testData' capture the mutable production execution state. - static void adjust(const std::string& injectionPoint, void* testData); + static void adjust(std::string_view injectionPoint, void* testData); private: using Callback = std::function; static std::mutex mutex_; static bool enabled_; - static std::unordered_map injectionMap_; + static folly::F14FastMap injectionMap_; }; class ScopedTestValue { @@ -79,13 +80,13 @@ class ScopedTestValue { } private: - const std::string point_; + std::string point_; }; #ifndef NDEBUG template void TestValue::set( - const std::string& injectionPoint, + std::string_view injectionPoint, std::function injectionCb) { std::lock_guard l(mutex_); if (!enabled_) { @@ -99,15 +100,14 @@ void TestValue::set( #else template void TestValue::set( - const std::string& injectionPoint, + std::string_view injectionPoint, std::function injectionCb) {} #endif #ifdef NDEBUG // Keep the definition in header so that it can be inlined (elided). -inline void TestValue::adjust( - const std::string& injectionPoint, - void* testData) {} +inline void TestValue::adjust(std::string_view injectionPoint, void* testData) { +} #endif #define SCOPED_TESTVALUE_SET(point, ...) \ diff --git a/velox/common/testutil/tests/CastsTest.cpp b/velox/common/testutil/tests/CastsTest.cpp index 182c73ebeb6..b5a9cf60052 100644 --- a/velox/common/testutil/tests/CastsTest.cpp +++ b/velox/common/testutil/tests/CastsTest.cpp @@ -93,16 +93,16 @@ class CastsTest : public ::testing::Test { DerivedClass* derivedRawPtr_; }; -// Tests for checked_pointer_cast with shared_ptr +// Tests for checkedPointerCast with shared_ptr TEST_F(CastsTest, checkedPointerCastSharedPtrSuccess) { // Cast derived to base (should always work) - auto result = checked_pointer_cast(derivedPtr_); + auto result = checkedPointerCast(derivedPtr_); EXPECT_NE(result, nullptr); EXPECT_EQ(result->getValue(), 100); // Cast base to derived when it actually is derived std::shared_ptr basePtrToDerived = derivedPtr_; - auto derivedResult = checked_pointer_cast(basePtrToDerived); + auto derivedResult = checkedPointerCast(basePtrToDerived); EXPECT_NE(derivedResult, nullptr); EXPECT_EQ(derivedResult->getValue(), 100); EXPECT_EQ(derivedResult->getDerivedValue(), 200); @@ -110,22 +110,22 @@ TEST_F(CastsTest, checkedPointerCastSharedPtrSuccess) { TEST_F(CastsTest, checkedPointerCastSharedPtrFailure) { // Try to cast base to derived when it's not actually derived - VELOX_ASSERT_THROW(checked_pointer_cast(basePtr_), ""); + VELOX_ASSERT_THROW(checkedPointerCast(basePtr_), ""); // Try to cast to unrelated class - VELOX_ASSERT_THROW(checked_pointer_cast(derivedPtr_), ""); + VELOX_ASSERT_THROW(checkedPointerCast(derivedPtr_), ""); } TEST_F(CastsTest, checkedPointerCastSharedPtrNullInput) { std::shared_ptr nullPtr; - VELOX_ASSERT_THROW(checked_pointer_cast(nullPtr), ""); + VELOX_ASSERT_THROW(checkedPointerCast(nullPtr), ""); } -// Tests for checked_pointer_cast with unique_ptr +// Tests for checkedPointerCast with unique_ptr TEST_F(CastsTest, checkedPointerCastUniquePtrSuccess) { // Cast derived to base auto derivedForCast = std::make_unique(); - auto result = checked_pointer_cast(std::move(derivedForCast)); + auto result = checkedPointerCast(std::move(derivedForCast)); EXPECT_NE(result, nullptr); EXPECT_EQ(result->getValue(), 100); @@ -133,7 +133,7 @@ TEST_F(CastsTest, checkedPointerCastUniquePtrSuccess) { std::unique_ptr basePtrToDerived = std::make_unique(); auto derivedResult = - checked_pointer_cast(std::move(basePtrToDerived)); + checkedPointerCast(std::move(basePtrToDerived)); EXPECT_NE(derivedResult, nullptr); EXPECT_EQ(derivedResult->getValue(), 100); EXPECT_EQ(derivedResult->getDerivedValue(), 200); @@ -143,25 +143,24 @@ TEST_F(CastsTest, checkedPointerCastUniquePtrFailure) { // Try to cast base to derived when it's not actually derived auto baseForCast = std::make_unique(); VELOX_ASSERT_THROW( - checked_pointer_cast(std::move(baseForCast)), ""); + checkedPointerCast(std::move(baseForCast)), ""); } TEST_F(CastsTest, checkedPointerCastUniquePtrNullInput) { std::unique_ptr nullPtr; - VELOX_ASSERT_THROW( - checked_pointer_cast(std::move(nullPtr)), ""); + VELOX_ASSERT_THROW(checkedPointerCast(std::move(nullPtr)), ""); } -// Tests for checked_pointer_cast with raw pointers +// Tests for checkedPointerCast with raw pointers TEST_F(CastsTest, checkedPointerCastRawPtrSuccess) { // Cast derived to base - auto result = checked_pointer_cast(derivedRawPtr_); + auto result = checkedPointerCast(derivedRawPtr_); EXPECT_NE(result, nullptr); EXPECT_EQ(result->getValue(), 100); // Cast base to derived when it actually is derived BaseClass* basePtrToDerived = derivedRawPtr_; - auto derivedResult = checked_pointer_cast(basePtrToDerived); + auto derivedResult = checkedPointerCast(basePtrToDerived); EXPECT_NE(derivedResult, nullptr); EXPECT_EQ(derivedResult->getValue(), 100); EXPECT_EQ(derivedResult->getDerivedValue(), 200); @@ -169,21 +168,20 @@ TEST_F(CastsTest, checkedPointerCastRawPtrSuccess) { TEST_F(CastsTest, checkedPointerCastRawPtrFailure) { // Try to cast base to derived when it's not actually derived - VELOX_ASSERT_THROW(checked_pointer_cast(baseRawPtr_), ""); + VELOX_ASSERT_THROW(checkedPointerCast(baseRawPtr_), ""); } TEST_F(CastsTest, checkedPointerCastRawPtrNullInput) { BaseClass* nullPtr = nullptr; - VELOX_ASSERT_THROW(checked_pointer_cast(nullPtr), ""); + VELOX_ASSERT_THROW(checkedPointerCast(nullPtr), ""); } -// Tests for static_unique_pointer_cast +// Tests for staticUniquePointerCast TEST_F(CastsTest, staticUniquePointerCastSuccess) { // Create a unique_ptr to derived and cast to base auto derivedForCast = std::make_unique(); auto originalPtr = derivedForCast.get(); - auto result = - static_unique_pointer_cast(std::move(derivedForCast)); + auto result = staticUniquePointerCast(std::move(derivedForCast)); EXPECT_NE(result, nullptr); EXPECT_EQ(result.get(), originalPtr); // Should be the same pointer @@ -193,65 +191,65 @@ TEST_F(CastsTest, staticUniquePointerCastSuccess) { TEST_F(CastsTest, staticUniquePointerCastNullInput) { std::unique_ptr nullPtr; VELOX_ASSERT_THROW( - static_unique_pointer_cast(std::move(nullPtr)), ""); + staticUniquePointerCast(std::move(nullPtr)), ""); } -// Tests for is_instance_of with shared_ptr +// Tests for isInstanceOf with shared_ptr TEST_F(CastsTest, isInstanceOfSharedPtr) { // Test positive cases - EXPECT_TRUE(is_instance_of(derivedPtr_)); - EXPECT_TRUE(is_instance_of(derivedPtr_)); - EXPECT_TRUE(is_instance_of(anotherDerivedPtr_)); - EXPECT_TRUE(is_instance_of(anotherDerivedPtr_)); + EXPECT_TRUE(isInstanceOf(derivedPtr_)); + EXPECT_TRUE(isInstanceOf(derivedPtr_)); + EXPECT_TRUE(isInstanceOf(anotherDerivedPtr_)); + EXPECT_TRUE(isInstanceOf(anotherDerivedPtr_)); // Test negative cases - EXPECT_FALSE(is_instance_of(basePtr_)); - EXPECT_FALSE(is_instance_of(derivedPtr_)); - EXPECT_FALSE(is_instance_of(anotherDerivedPtr_)); - EXPECT_FALSE(is_instance_of(derivedPtr_)); + EXPECT_FALSE(isInstanceOf(basePtr_)); + EXPECT_FALSE(isInstanceOf(derivedPtr_)); + EXPECT_FALSE(isInstanceOf(anotherDerivedPtr_)); + EXPECT_FALSE(isInstanceOf(derivedPtr_)); } TEST_F(CastsTest, isInstanceOfSharedPtrNullInput) { std::shared_ptr nullPtr; - VELOX_ASSERT_THROW(is_instance_of(nullPtr), ""); + VELOX_ASSERT_THROW(isInstanceOf(nullPtr), ""); } -// Tests for is_instance_of with unique_ptr +// Tests for isInstanceOf with unique_ptr TEST_F(CastsTest, isInstanceOfUniquePtr) { // Test positive cases - EXPECT_TRUE(is_instance_of(derivedUniquePtr_)); - EXPECT_TRUE(is_instance_of(derivedUniquePtr_)); + EXPECT_TRUE(isInstanceOf(derivedUniquePtr_)); + EXPECT_TRUE(isInstanceOf(derivedUniquePtr_)); // Test negative cases - EXPECT_FALSE(is_instance_of(baseUniquePtr_)); - EXPECT_FALSE(is_instance_of(derivedUniquePtr_)); + EXPECT_FALSE(isInstanceOf(baseUniquePtr_)); + EXPECT_FALSE(isInstanceOf(derivedUniquePtr_)); } TEST_F(CastsTest, isInstanceOfUniquePtrNullInput) { std::unique_ptr nullPtr; - VELOX_ASSERT_THROW(is_instance_of(nullPtr), ""); + VELOX_ASSERT_THROW(isInstanceOf(nullPtr), ""); } -// Tests for is_instance_of with raw pointers +// Tests for isInstanceOf with raw pointers TEST_F(CastsTest, isInstanceOfRawPtr) { // Test positive cases - EXPECT_TRUE(is_instance_of(derivedRawPtr_)); - EXPECT_TRUE(is_instance_of(derivedRawPtr_)); + EXPECT_TRUE(isInstanceOf(derivedRawPtr_)); + EXPECT_TRUE(isInstanceOf(derivedRawPtr_)); // Test negative cases - EXPECT_FALSE(is_instance_of(baseRawPtr_)); - EXPECT_FALSE(is_instance_of(derivedRawPtr_)); + EXPECT_FALSE(isInstanceOf(baseRawPtr_)); + EXPECT_FALSE(isInstanceOf(derivedRawPtr_)); } TEST_F(CastsTest, isInstanceOfRawPtrNullInput) { BaseClass* nullPtr = nullptr; - VELOX_ASSERT_THROW(is_instance_of(nullPtr), ""); + VELOX_ASSERT_THROW(isInstanceOf(nullPtr), ""); } // Test error messages contain useful information TEST_F(CastsTest, errorMessageContent) { try { - checked_pointer_cast(basePtr_); + checkedPointerCast(basePtr_); FAIL() << "Expected VeloxException to be thrown"; } catch (const VeloxException& e) { const std::string& message = e.message(); @@ -266,12 +264,12 @@ TEST_F(CastsTest, errorMessageContent) { TEST_F(CastsTest, objectIdentityPreserved) { // For shared_ptr std::shared_ptr basePtrToDerived = derivedPtr_; - auto castedShared = checked_pointer_cast(basePtrToDerived); + auto castedShared = checkedPointerCast(basePtrToDerived); EXPECT_EQ(castedShared.get(), derivedPtr_.get()); // For raw ptr BaseClass* basePtrToDerivedRaw = derivedRawPtr_; - auto castedRaw = checked_pointer_cast(basePtrToDerivedRaw); + auto castedRaw = checkedPointerCast(basePtrToDerivedRaw); EXPECT_EQ(castedRaw, derivedRawPtr_); } @@ -280,7 +278,7 @@ TEST_F(CastsTest, uniquePtrExceptionSafety) { auto baseForCast = std::make_unique(); try { - checked_pointer_cast(std::move(baseForCast)); + checkedPointerCast(std::move(baseForCast)); FAIL() << "Expected VeloxException to be thrown"; } catch (const VeloxException&) { // The unique_ptr should have been restored and the object should still diff --git a/velox/common/time/CMakeLists.txt b/velox/common/time/CMakeLists.txt index 46c878968f0..fff0f5b52af 100644 --- a/velox/common/time/CMakeLists.txt +++ b/velox/common/time/CMakeLists.txt @@ -15,5 +15,18 @@ if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) endif() -velox_add_library(velox_time CpuWallTimer.cpp Timer.cpp) +velox_add_library( + velox_time + CpuWallTimer.cpp + Timer.cpp + HEADERS + CpuWallTimer.h + Timer.h +) velox_link_libraries(velox_time PUBLIC velox_process velox_test_util Folly::folly fmt::fmt) + +velox_add_library(velox_hierarchical_timer HierarchicalTimer.cpp HEADERS HierarchicalTimer.h) +velox_link_libraries( + velox_hierarchical_timer + PUBLIC velox_exception velox_common_base Folly::folly fmt::fmt glog::glog +) diff --git a/velox/common/time/CpuWallTimer.h b/velox/common/time/CpuWallTimer.h index 231c15f66c2..06feb989848 100644 --- a/velox/common/time/CpuWallTimer.h +++ b/velox/common/time/CpuWallTimer.h @@ -18,6 +18,7 @@ #include #include +#include "velox/common/base/Macros.h" #include "velox/common/base/SuccinctPrinter.h" #include "velox/common/process/ProcessBase.h" @@ -29,8 +30,13 @@ struct CpuWallTiming { uint64_t wallNanos = 0; uint64_t cpuNanos = 0; + auto operator<=>(const CpuWallTiming&) const = default; + void add(const CpuWallTiming& other) { + // Suppress spurious warnings in GCC 13. + VELOX_SUPPRESS_STRINGOP_OVERFLOW_WARNING count += other.count; + VELOX_UNSUPPRESS_STRINGOP_OVERFLOW_WARNING cpuNanos += other.cpuNanos; wallNanos += other.wallNanos; } diff --git a/velox/common/time/HierarchicalTimer.cpp b/velox/common/time/HierarchicalTimer.cpp new file mode 100644 index 00000000000..398ca337308 --- /dev/null +++ b/velox/common/time/HierarchicalTimer.cpp @@ -0,0 +1,556 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/time/HierarchicalTimer.h" + +#include +#include +#include + +#include +#include + +#include "velox/common/base/Exceptions.h" +#include "velox/common/base/SuccinctPrinter.h" + +#ifdef ENABLE_HW_TIMER +#if !defined(__x86_64__) +#error "HierarchicalTimer RDTSCP mode requires x86-64 architecture" +#endif +#include +#include "folly/chrono/Hardware.h" +#endif + +namespace facebook::velox { + +namespace { + +#ifdef ENABLE_HW_TIMER +double estimateTscFreqGhz( + std::chrono::milliseconds window = std::chrono::milliseconds(100)) { + static const double cached = [&]() { + using clock = std::chrono::steady_clock; + // Warm-up (reduces first-use noise). + (void)folly::hardware_timestamp(); + auto t0 = clock::now(); + uint64_t c0 = folly::hardware_timestamp(); + // Busy-wait until window elapsed (less jitter than sleep). + while (clock::now() - t0 < window) { + _mm_pause(); + } + auto t1 = clock::now(); + uint64_t c1 = folly::hardware_timestamp(); + std::chrono::duration seconds = t1 - t0; + return (static_cast(c1 - c0) / seconds.count()) / 1e9; + }(); + return cached; +} +#endif + +/// Tracks the currently active TimerNode for stack-based auto-nesting. +thread_local TimerNode* activeNode_ = + nullptr; // NOLINT(facebook-avoid-non-const-global-variables) + +/// Section column width for the formatted tree output. +/// Must accommodate prefix (4 chars per depth level) plus node name. +constexpr int kSectionWidth = 60; + +std::string buildPrefix(const std::vector& ancestorIsLast) { + if (ancestorIsLast.empty()) { + return ""; + } + std::string prefix; + // Draw continuation lines for ancestors (all but the current node). + for (size_t i = 0; i + 1 < ancestorIsLast.size(); ++i) { + prefix += ancestorIsLast[i] ? " " : "│ "; + } + // Draw the branch for the current node. + prefix += ancestorIsLast.back() ? "└── " : "├── "; + return prefix; +} + +/// Computes display width by counting UTF-8 codepoints (not bytes). +/// Box-drawing chars (│├└─) are 3 bytes each but 1 display column. +int displayWidth(const std::string& s) { + int width = 0; + for (unsigned char c : s) { + // Count only non-continuation bytes (not 0x80-0xBF). + if ((c & 0xC0) != 0x80) { + ++width; + } + } + return width; +} + +/// Pads the section string to kSectionWidth display columns for alignment. +std::string padSection(const std::string& section) { + const int dw = displayWidth(section); + if (dw >= kSectionWidth) { + return section; + } + return section + std::string(kSectionWidth - dw, ' '); +} + +} // namespace + +// --- TimerNode --- + +TimerNode::TimerNode(const std::string& name, TimerNode* parent) + : name_{name}, parent_{parent} {} + +TimerNode* TimerNode::getOrCreateChild(const std::string& childName) { + auto it = childrenByName_.find(childName); + if (it != childrenByName_.end()) { + return it->second; + } + auto child = std::make_unique(childName, this); + auto* raw = child.get(); + childrenByName_[childName] = raw; + children_.push_back(std::move(child)); + return raw; +} + +void TimerNode::addTime(uint64_t ns) { + totalTimeNs_ += ns; + if (ns < minTimeNs_) { + minTimeNs_ = ns; + } + if (ns > maxTimeNs_) { + maxTimeNs_ = ns; + } +} + +void TimerNode::addCpuTime(uint64_t ns) { + totalCpuNs_ += ns; + if (ns < minCpuNs_) { + minCpuNs_ = ns; + } + if (ns > maxCpuNs_) { + maxCpuNs_ = ns; + } +} + +void TimerNode::incrementCallCount() { + ++callCount_; +} + +void TimerNode::reset() { + totalTimeNs_ = 0; + totalCpuNs_ = 0; + callCount_ = 0; + minTimeNs_ = std::numeric_limits::max(); + maxTimeNs_ = 0; + minCpuNs_ = std::numeric_limits::max(); + maxCpuNs_ = 0; + for (auto& child : children_) { + child->reset(); + } +} + +uint64_t TimerNode::totalTimeNs() const { + return totalTimeNs_; +} + +uint64_t TimerNode::totalCpuNs() const { + return totalCpuNs_; +} + +uint64_t TimerNode::callCount() const { + return callCount_; +} + +uint64_t TimerNode::averageTimeNs() const { + if (callCount_ == 0) { + return 0; + } + return totalTimeNs_ / callCount_; +} + +uint64_t TimerNode::minTimeNs() const { + return minTimeNs_; +} + +uint64_t TimerNode::maxTimeNs() const { + return maxTimeNs_; +} + +uint64_t TimerNode::averageCpuNs() const { + if (callCount_ == 0) { + return 0; + } + return totalCpuNs_ / callCount_; +} + +uint64_t TimerNode::minCpuNs() const { + return minCpuNs_; +} + +uint64_t TimerNode::maxCpuNs() const { + return maxCpuNs_; +} + +const std::string& TimerNode::name() const { + return name_; +} + +TimerNode* TimerNode::parent() const { + return parent_; +} + +const std::vector>& TimerNode::children() const { + return children_; +} + +void TimerNode::format( + std::string& out, + int depth, + bool isLast, + uint64_t parentTimeNs, + uint64_t parentCpuTimeNs, + bool verbose) const { + std::vector ancestorTrail; + formatImpl( + out, + depth, + isLast, + parentTimeNs, + parentCpuTimeNs, + verbose, + ancestorTrail); +} + +void TimerNode::formatImpl( + std::string& out, + int depth, + bool isLast, + uint64_t parentTimeNs, + uint64_t parentCpuTimeNs, + bool verbose, + std::vector& ancestorTrail) const { + if (depth > 0) { + ancestorTrail.push_back(isLast); + } + const auto prefix = buildPrefix(ancestorTrail); + const auto wall = succinctNanos(totalTimeNs_); + const auto cpu = succinctNanos(totalCpuNs_); + const auto calls = + callCount_ > 0 ? std::to_string(callCount_) : std::string{"-"}; + + std::string wallPct; + if (parentTimeNs > 0 && depth > 0) { + wallPct = fmt::format( + "{:.1f}%", + 100.0 * static_cast(totalTimeNs_) / + static_cast(parentTimeNs)); + } else { + wallPct = "-"; + } + + std::string cpuPct; + if (parentCpuTimeNs > 0 && depth > 0) { + cpuPct = fmt::format( + "{:.1f}%", + 100.0 * static_cast(totalCpuNs_) / + static_cast(parentCpuTimeNs)); + } else { + cpuPct = "-"; + } + + const auto section = padSection(prefix + name_); + + out += fmt::format( + "{} {:>10s} {:>8s} {:>10s} {:>8s} {:>8s}", + section, + wall, + wallPct, + cpu, + cpuPct, + calls); + + if (verbose) { + std::string avgWall, minWall, maxWall; + std::string avgCpu, minCpu, maxCpu; + if (callCount_ > 0) { + avgWall = succinctNanos(averageTimeNs()); + minWall = succinctNanos(minTimeNs_); + maxWall = succinctNanos(maxTimeNs_); + avgCpu = succinctNanos(averageCpuNs()); + minCpu = succinctNanos(maxCpuNs_ > 0 ? minCpuNs_ : 0); + maxCpu = succinctNanos(maxCpuNs_); + } else { + avgWall = "-"; + minWall = "-"; + maxWall = "-"; + avgCpu = "-"; + minCpu = "-"; + maxCpu = "-"; + } + out += fmt::format( + " {:>10s} {:>10s} {:>10s} {:>10s} {:>10s} {:>10s}", + avgWall, + minWall, + maxWall, + avgCpu, + minCpu, + maxCpu); + } + out += "\n"; + + // Determine if "(other)" row will be shown so we can set isLast correctly + // for the final real child. + uint64_t childrenSum = 0; + uint64_t childrenCpuSum = 0; + for (const auto& child : children_) { + childrenSum += child->totalTimeNs(); + childrenCpuSum += child->totalCpuNs(); + } + const bool hasOther = !children_.empty() && totalTimeNs_ > childrenSum; + + // Format children. + for (size_t i = 0; i < children_.size(); ++i) { + const bool lastChild = (i == children_.size() - 1) && !hasOther; + children_[i]->formatImpl( + out, + depth + 1, + lastChild, + totalTimeNs_, + totalCpuNs_, + verbose, + ancestorTrail); + } + + // Show unaccounted time if this node has children and time exceeds children + // sum. + if (hasOther) { + const uint64_t other = totalTimeNs_ - childrenSum; + const uint64_t cpuOther = + (totalCpuNs_ > childrenCpuSum) ? totalCpuNs_ - childrenCpuSum : 0; + ancestorTrail.push_back(true); + const auto otherPrefix = buildPrefix(ancestorTrail); + const auto otherSection = padSection(otherPrefix + "(other)"); + const auto otherWallPct = fmt::format( + "{:.1f}%", + 100.0 * static_cast(other) / static_cast(totalTimeNs_)); + std::string otherCpuPct; + if (totalCpuNs_ > 0) { + otherCpuPct = fmt::format( + "{:.1f}%", + 100.0 * static_cast(cpuOther) / + static_cast(totalCpuNs_)); + } else { + otherCpuPct = "-"; + } + out += fmt::format( + "{} {:>10s} {:>8s} {:>10s} {:>8s} {:>8s}", + otherSection, + succinctNanos(other), + otherWallPct, + succinctNanos(cpuOther), + otherCpuPct, + "-"); + if (verbose) { + out += fmt::format( + " {:>10s} {:>10s} {:>10s} {:>10s} {:>10s} {:>10s}", + "-", + "-", + "-", + "-", + "-", + "-"); + } + out += "\n"; + ancestorTrail.pop_back(); + } + + if (depth > 0 && !ancestorTrail.empty()) { + ancestorTrail.pop_back(); + } +} + +// --- TimerTree --- + +TimerTree::TimerTree(const std::string& name) + : name_{name}, root_{std::make_unique("root")} {} + +TimerTree::~TimerTree() { + printStats(); +} + +uint64_t TimerTree::now() const { +#ifdef ENABLE_HW_TIMER + return static_cast( + static_cast(folly::hardware_timestamp()) / estimateTscFreqGhz()); +#else + return std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); +#endif +} + +uint64_t TimerTree::cpuNow() const { + struct timespec ts{}; + clock_gettime(CLOCK_THREAD_CPUTIME_ID, &ts); + return static_cast(ts.tv_sec) * 1'000'000'000ULL + + static_cast(ts.tv_nsec); +} + +TimerNode* TimerTree::getOrCreateNode(const std::string& path) { + TimerNode* current = root_.get(); + size_t start = 0; + while (start < path.size()) { + auto pos = path.find('/', start); + if (pos == std::string::npos) { + pos = path.size(); + } + const auto segment = path.substr(start, pos - start); + if (!segment.empty()) { + current = current->getOrCreateChild(segment); + } + start = pos + 1; + } + return current; +} + +void TimerTree::reset() { + root_->reset(); +} + +void TimerTree::clear() { + root_ = std::make_unique("root"); +} + +std::string TimerTree::toString(bool verbose) const { + const auto& topChildren = root_->children(); + if (topChildren.empty()) { + return "(no timers recorded)\n"; + } + + // Compact: Section(60) + Wall(10) + Wall%(8) + CPU(10) + CPU%(8) + Calls(8) + // + 6 separating spaces = 110. + // Verbose adds: AvgWall(10) + MinWall(10) + MaxWall(10) + AvgCPU(10) + + // MinCPU(10) + MaxCPU(10) + 6 separating spaces = 66. + constexpr int kCompactWidth = 110; + constexpr int kVerboseWidth = 176; + const int totalWidth = verbose ? kVerboseWidth : kCompactWidth; + + std::string out; + const std::string separator(totalWidth, '='); + const std::string dashSeparator(totalWidth, '-'); + + out += separator + "\n"; + if (!name_.empty()) { + out += fmt::format("{:^{}s}\n", name_, totalWidth); + } + out += fmt::format("{:^{}s}\n", "HIERARCHICAL TIMING BREAKDOWN", totalWidth); + out += separator + "\n\n"; + + out += fmt::format( + "{:<60s} {:>10s} {:>8s} {:>10s} {:>8s} {:>8s}", + "Section", + "Wall", + "Wall %", + "CPU", + "CPU %", + "Calls"); + if (verbose) { + out += fmt::format( + " {:>10s} {:>10s} {:>10s} {:>10s} {:>10s} {:>10s}", + "Avg Wall", + "Min Wall", + "Max Wall", + "Avg CPU", + "Min CPU", + "Max CPU"); + } + out += "\n"; + out += dashSeparator + "\n"; + + for (size_t i = 0; i < topChildren.size(); ++i) { + topChildren[i]->format( + out, + /*depth=*/0, + /*isLast=*/false, + /*parentTimeNs=*/0, + /*parentCpuTimeNs=*/0, + verbose); + } + + out += separator + "\n"; + return out; +} + +void TimerTree::printStats(bool verbose) { + const auto& topChildren = root_->children(); + if (topChildren.empty()) { + return; + } + LOG(INFO) << "\n" << toString(verbose); + clear(); +} + +const TimerNode& TimerTree::root() const { + return *root_; +} + +const std::string& TimerTree::name() const { + return name_; +} + +TimerTree& TimerTree::threadInstance() { + static thread_local TimerTree instance; + return instance; +} + +// --- ScopedTimer --- + +ScopedTimer::ScopedTimer(const std::string& name) + : tree_(TimerTree::threadInstance()), + restoreActive_(true), + startNs_(tree_.now()), + cpuStartNs_(tree_.cpuNow()) { + previousActive_ = activeNode_; + if (activeNode_) { + node_ = activeNode_->getOrCreateChild(name); + } else { + node_ = tree_.root_->getOrCreateChild(name); + } + activeNode_ = node_; + node_->incrementCallCount(); +} + +ScopedTimer::ScopedTimer(TimerTree& tree, const std::string& path) + : tree_(tree), + node_(tree.getOrCreateNode(path)), + restoreActive_(false), + startNs_(tree.now()), + cpuStartNs_(tree.cpuNow()) { + node_->incrementCallCount(); +} + +ScopedTimer::~ScopedTimer() { + const uint64_t endNs = tree_.now(); + const uint64_t cpuEndNs = tree_.cpuNow(); + VELOX_CHECK_GE(endNs, startNs_); + node_->addTime(endNs - startNs_); + if (cpuEndNs >= cpuStartNs_) { + node_->addCpuTime(cpuEndNs - cpuStartNs_); + } + if (restoreActive_) { + activeNode_ = previousActive_; + } +} + +} // namespace facebook::velox diff --git a/velox/common/time/HierarchicalTimer.h b/velox/common/time/HierarchicalTimer.h new file mode 100644 index 00000000000..4246dacecb9 --- /dev/null +++ b/velox/common/time/HierarchicalTimer.h @@ -0,0 +1,262 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace facebook::velox { + +/// A node in the hierarchical timer tree. +/// +/// Each node tracks cumulative wall time, CPU time, and call count for a named +/// section. Nodes form a tree: "parent/child" creates a "child" node under +/// "parent". Tracks wall total, CPU total, count, min, and max time per node. +/// +/// Not thread-safe. All access must be from a single thread. +class TimerNode { + public: + explicit TimerNode(const std::string& name, TimerNode* parent = nullptr); + + /// Returns or creates a direct child with the given name. + TimerNode* getOrCreateChild(const std::string& childName); + + /// Adds elapsed nanoseconds to this node's wall time total, updates min/max. + void addTime(uint64_t ns); + + /// Adds elapsed nanoseconds to this node's CPU time total, updates min/max. + void addCpuTime(uint64_t ns); + + /// Increments the call counter by one. + void incrementCallCount(); + + /// Zeros all counters for this node and all descendants. + /// Keeps the tree structure intact. + void reset(); + + uint64_t totalTimeNs() const; + uint64_t totalCpuNs() const; + uint64_t callCount() const; + uint64_t averageTimeNs() const; + uint64_t minTimeNs() const; + uint64_t maxTimeNs() const; + uint64_t averageCpuNs() const; + uint64_t minCpuNs() const; + uint64_t maxCpuNs() const; + const std::string& name() const; + TimerNode* parent() const; + const std::vector>& children() const; + + /// Formats this node (and descendants) into a human-readable table row. + /// `depth` controls indentation; `isLast` controls tree-drawing glyphs. + void format( + std::string& out, + int depth, + bool isLast, + uint64_t parentTimeNs, + uint64_t parentCpuTimeNs, + bool verbose) const; + + private: + /// Internal recursive formatter that tracks ancestor isLast flags for + /// drawing proper │ continuation lines in the tree. + void formatImpl( + std::string& out, + int depth, + bool isLast, + uint64_t parentTimeNs, + uint64_t parentCpuTimeNs, + bool verbose, + std::vector& ancestorTrail) const; + std::string name_; + TimerNode* parent_; + std::vector> children_; + std::unordered_map childrenByName_; + uint64_t totalTimeNs_{0}; + uint64_t totalCpuNs_{0}; + uint64_t callCount_{0}; + uint64_t minTimeNs_{std::numeric_limits::max()}; + uint64_t maxTimeNs_{0}; + uint64_t minCpuNs_{std::numeric_limits::max()}; + uint64_t maxCpuNs_{0}; +}; + +/// Owns a forest of TimerNode trees and provides path-based access. +/// +/// Paths are `/`-separated: "a/b/c" resolves to root -> a -> b -> c. +/// Multiple top-level names create separate subtrees. +/// +/// Uses RDTSCP for high-precision timing when `ENABLE_HW_TIMER` is defined, +/// automatically falls back to `steady_clock` otherwise. +/// +/// Uses ScopedTimer for RAII timing. On destruction (or when printStats() is +/// called), prints a hierarchical timing breakdown and clears all data. +/// +/// Not thread-safe. All access must be from a single thread. +/// +/// ## Usage +/// +/// ### Thread-local (recommended) -- auto-nesting via call stack: +/// +/// void Reader::loadStripe(int i) { +/// ScopedTimer t("loadStripe"); +/// { +/// ScopedTimer t2("readIO"); // auto-nested under loadStripe +/// readIO(i); +/// } +/// { +/// ScopedTimer t3("decode"); // auto-nested under loadStripe +/// decode(i); +/// } +/// } +/// +/// void Writer::flush() { +/// ScopedTimer t("flush"); +/// doFlush(); +/// } +/// +/// // Print results at any point: +/// TimerTree::threadInstance().printStats(); +/// +/// ### Explicit tree -- for isolated profiling sessions: +/// +/// TimerTree tree("my benchmark"); +/// for (int i = 0; i < numStripes; ++i) { +/// ScopedTimer t(tree, "loadStripe"); +/// loadStripe(i); +/// } +/// // Destructor prints results, or call tree.printStats() explicitly. +/// +/// ## Example output +/// +/// =============================================================...========= +/// HIERARCHICAL TIMING BREAKDOWN +/// =============================================================...========= +/// +/// Section Wall +/// CPU Calls Avg Wall ... % Parent +/// -------------------------------------------------------------...--------- +/// loadStripe 120.50ms +/// 45.20ms 5 24.10ms ... - +/// ├── readIO 80.20ms +/// 10.50ms 5 16.04ms ... 66.6% +/// ├── decode 35.10ms +/// 30.00ms 5 7.02ms ... 29.1% +/// └── (other) 5.20ms +/// 4.70ms - - ... 4.3% +/// flush 15.30ms +/// 14.80ms 1 15.30ms ... - +/// =============================================================...========= +class TimerTree { + public: + explicit TimerTree(const std::string& name = ""); + + TimerTree(const TimerTree&) = delete; + TimerTree& operator=(const TimerTree&) = delete; + TimerTree(TimerTree&&) = delete; + TimerTree& operator=(TimerTree&&) = delete; + + /// Prints the results table if any entries were recorded. + ~TimerTree(); + + /// Returns the current wall timestamp in nanoseconds. + /// Uses RDTSCP when ENABLE_HW_TIMER is defined, steady_clock otherwise. + uint64_t now() const; + + /// Returns the current thread CPU time in nanoseconds. + /// Uses clock_gettime(CLOCK_THREAD_CPUTIME_ID) to measure user+system CPU + /// time for the calling thread. + uint64_t cpuNow() const; + + /// Walks (or creates) the node addressed by a `/`-separated path. + TimerNode* getOrCreateNode(const std::string& path); + + /// Zeros all counters but keeps the tree structure. + void reset(); + + /// Destroys the entire tree (removes all nodes). + void clear(); + + /// Returns a formatted hierarchical summary table. + std::string toString(bool verbose = false) const; + + /// Prints the results table via LOG(INFO) and clears all data. + void printStats(bool verbose = false); + + /// Access root node for inspection. + const TimerNode& root() const; + + /// Returns the timer name. + const std::string& name() const; + + /// Returns the thread-local TimerTree instance. Creates one on first access + /// per thread. Results are printed when the thread exits (thread_local + /// destructor). Call printStats() to print and reset intermediate results. + static TimerTree& threadInstance(); + + private: + friend class ScopedTimer; + + std::string name_; + std::unique_ptr root_; +}; + +/// RAII timer that measures a scoped block and records it into a TimerTree. +/// +/// On construction, records the current wall and CPU time and increments the +/// call count for the node. On destruction, computes elapsed wall and CPU +/// nanoseconds and adds them to the node. +/// +/// Two constructors: +/// - ScopedTimer(name): records into TimerTree::threadInstance() with +/// stack-based auto-nesting. Nested ScopedTimers automatically form +/// parent-child relationships based on the call stack. The name is treated +/// as a simple child name (no path splitting). +/// - ScopedTimer(tree, path): records into an explicit TimerTree using +/// path-based node resolution (splits on '/'). +/// +/// Not thread-safe. The referenced TimerTree must outlive this object. +class ScopedTimer { + public: + /// Records into the thread-local TimerTree with stack-based auto-nesting. + /// The name is used as a direct child name under the currently active node + /// (or root if no timer is active). + explicit ScopedTimer(const std::string& name); + + /// Records into an explicit TimerTree using path-based node resolution. + ScopedTimer(TimerTree& tree, const std::string& path); + ~ScopedTimer(); + + ScopedTimer(const ScopedTimer&) = delete; + ScopedTimer& operator=(const ScopedTimer&) = delete; + ScopedTimer(ScopedTimer&&) = delete; + ScopedTimer& operator=(ScopedTimer&&) = delete; + + private: + TimerTree& tree_; + TimerNode* node_; + TimerNode* previousActive_{nullptr}; + bool restoreActive_; + uint64_t startNs_; + uint64_t cpuStartNs_; +}; + +} // namespace facebook::velox diff --git a/velox/common/time/tests/CMakeLists.txt b/velox/common/time/tests/CMakeLists.txt index cb602e651a3..2330d095bae 100644 --- a/velox/common/time/tests/CMakeLists.txt +++ b/velox/common/time/tests/CMakeLists.txt @@ -18,3 +18,19 @@ add_executable(velox_time_test CpuWallTimerTest.cpp) target_link_libraries(velox_time_test PRIVATE velox_time glog::glog GTest::gtest GTest::gtest_main) gtest_add_tests(velox_time_test "" AUTO) + +add_executable(velox_hierarchical_timer_test HierarchicalTimerTest.cpp) + +target_link_libraries( + velox_hierarchical_timer_test + PRIVATE + velox_hierarchical_timer + velox_exception + velox_common_base + fmt::fmt + glog::glog + GTest::gtest + GTest::gtest_main +) + +gtest_add_tests(velox_hierarchical_timer_test "" AUTO) diff --git a/velox/common/time/tests/HierarchicalTimerTest.cpp b/velox/common/time/tests/HierarchicalTimerTest.cpp new file mode 100644 index 00000000000..ea30945eec7 --- /dev/null +++ b/velox/common/time/tests/HierarchicalTimerTest.cpp @@ -0,0 +1,1029 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "velox/common/base/SuccinctPrinter.h" +#include "velox/common/base/VeloxException.h" +#include "velox/common/time/HierarchicalTimer.h" + +using namespace facebook::velox; + +namespace facebook::velox::test { + +namespace { + +// Intercepts glog INFO messages so tests can inspect logged output. +class LogCapture : public google::LogSink { + public: + LogCapture() { + google::AddLogSink(this); + } + + ~LogCapture() override { + google::RemoveLogSink(this); + } + + void send( + google::LogSeverity /*severity*/, + const char* /*full_filename*/, + const char* /*base_filename*/, + int /*line*/, + const struct ::tm* /*tm_time*/, + const char* message, + size_t message_len) override { + captured_ += std::string(message, message_len); + } + + LogCapture(const LogCapture&) = delete; + LogCapture& operator=(const LogCapture&) = delete; + + const std::string& captured() const { + return captured_; + } + + private: + std::string captured_; +}; + +void doWork() { + int x = 0; + for (int i = 0; i < 100; ++i) { + x += i; + } + // Prevent the compiler from optimizing away the loop. + EXPECT_GE(x, 0); +} + +/// Builds the expected header block for toString() output. +std::string expectedHeader( + const std::string& title = "", + bool verbose = false) { + constexpr int kCompactWidth = 110; + constexpr int kVerboseWidth = 176; + const int totalWidth = verbose ? kVerboseWidth : kCompactWidth; + + std::string out; + out += std::string(totalWidth, '=') + "\n"; + if (!title.empty()) { + out += fmt::format("{:^{}s}\n", title, totalWidth); + } + out += fmt::format("{:^{}s}\n", "HIERARCHICAL TIMING BREAKDOWN", totalWidth); + out += std::string(totalWidth, '=') + "\n\n"; + out += fmt::format( + "{:<60s} {:>10s} {:>8s} {:>10s} {:>8s} {:>8s}", + "Section", + "Wall", + "Wall %", + "CPU", + "CPU %", + "Calls"); + if (verbose) { + out += fmt::format( + " {:>10s} {:>10s} {:>10s} {:>10s} {:>10s} {:>10s}", + "Avg Wall", + "Min Wall", + "Max Wall", + "Avg CPU", + "Min CPU", + "Max CPU"); + } + out += "\n"; + out += std::string(totalWidth, '-') + "\n"; + return out; +} + +/// Builds a single expected compact node line (Section, Wall, Wall%, CPU, CPU%, +/// Calls). +std::string expectedLine( + const std::string& prefix, + const std::string& name, + const std::string& wall, + const std::string& wallPct, + const std::string& cpu, + const std::string& cpuPct, + const std::string& calls) { + auto section = prefix + name; + constexpr int kSectionWidth = 60; + // Use display width for padding (UTF-8 aware). + int dw = 0; + for (unsigned char c : section) { + if ((c & 0xC0) != 0x80) { + ++dw; + } + } + if (dw < kSectionWidth) { + section += std::string(kSectionWidth - dw, ' '); + } + return fmt::format( + "{} {:>10s} {:>8s} {:>10s} {:>8s} {:>8s}\n", + section, + wall, + wallPct, + cpu, + cpuPct, + calls); +} + +/// Builds a single expected verbose node line (compact columns + avg/min/max). +std::string expectedVerboseLine( + const std::string& prefix, + const std::string& name, + const std::string& wall, + const std::string& wallPct, + const std::string& cpu, + const std::string& cpuPct, + const std::string& calls, + const std::string& avgWall, + const std::string& minWall, + const std::string& maxWall, + const std::string& avgCpu, + const std::string& minCpu, + const std::string& maxCpu) { + auto section = prefix + name; + constexpr int kSectionWidth = 60; + int dw = 0; + for (unsigned char c : section) { + if ((c & 0xC0) != 0x80) { + ++dw; + } + } + if (dw < kSectionWidth) { + section += std::string(kSectionWidth - dw, ' '); + } + return fmt::format( + "{} {:>10s} {:>8s} {:>10s} {:>8s} {:>8s} {:>10s} {:>10s} {:>10s} {:>10s} {:>10s} {:>10s}\n", + section, + wall, + wallPct, + cpu, + cpuPct, + calls, + avgWall, + minWall, + maxWall, + avgCpu, + minCpu, + maxCpu); +} + +std::string expectedFooter(bool verbose = false) { + constexpr int kCompactWidth = 110; + constexpr int kVerboseWidth = 176; + const int totalWidth = verbose ? kVerboseWidth : kCompactWidth; + return std::string(totalWidth, '=') + "\n"; +} + +} // namespace + +// ===================================================================== +// Tests - Tree structure tests +// ===================================================================== + +class HierarchicalTimerTest : public ::testing::Test { + protected: + TimerTree tree_; +}; + +TEST_F(HierarchicalTimerTest, hierarchicalNesting) { + auto* node = tree_.getOrCreateNode("a/b/c"); + + const auto& roots = tree_.root().children(); + ASSERT_EQ(1, roots.size()); + EXPECT_EQ("a", roots[0]->name()); + + const auto& bChildren = roots[0]->children(); + ASSERT_EQ(1, bChildren.size()); + EXPECT_EQ("b", bChildren[0]->name()); + + const auto& cChildren = bChildren[0]->children(); + ASSERT_EQ(1, cChildren.size()); + EXPECT_EQ("c", cChildren[0]->name()); + EXPECT_EQ(node, cChildren[0].get()); +} + +TEST_F(HierarchicalTimerTest, multipleRoots) { + tree_.getOrCreateNode("alpha"); + tree_.getOrCreateNode("beta"); + tree_.getOrCreateNode("gamma"); + + const auto& roots = tree_.root().children(); + ASSERT_EQ(3, roots.size()); + EXPECT_EQ("alpha", roots[0]->name()); + EXPECT_EQ("beta", roots[1]->name()); + EXPECT_EQ("gamma", roots[2]->name()); +} + +TEST_F(HierarchicalTimerTest, resetZerosCountersKeepsStructure) { + tree_.getOrCreateNode("a/b"); + auto* node = tree_.getOrCreateNode("a/b"); + node->addTime(100); + node->addCpuTime(50); + node->incrementCallCount(); + + tree_.reset(); + + const auto& roots = tree_.root().children(); + ASSERT_EQ(1, roots.size()); + EXPECT_EQ("a", roots[0]->name()); + ASSERT_EQ(1, roots[0]->children().size()); + + const auto& b = *roots[0]->children()[0]; + EXPECT_EQ(0, b.totalTimeNs()); + EXPECT_EQ(0, b.totalCpuNs()); + EXPECT_EQ(0, b.callCount()); +} + +TEST_F(HierarchicalTimerTest, clearDestroysTree) { + tree_.getOrCreateNode("a/b/c"); + tree_.clear(); + + EXPECT_TRUE(tree_.root().children().empty()); +} + +TEST_F(HierarchicalTimerTest, averageTimeIsZeroWhenNoCallsRecorded) { + auto* node = tree_.getOrCreateNode("empty"); + EXPECT_EQ(0, node->averageTimeNs()); +} + +TEST_F(HierarchicalTimerTest, emptyTreeToString) { + const auto output = tree_.toString(); + EXPECT_NE(std::string::npos, output.find("no timers recorded")); +} + +// ===================================================================== +// Tests - ScopedTimer tests (explicit tree, path-based) +// ===================================================================== + +TEST_F(HierarchicalTimerTest, basicSingleTimer) { + { + ScopedTimer timer(tree_, "op"); + doWork(); + } + + const auto& topChildren = tree_.root().children(); + ASSERT_EQ(1, topChildren.size()); + + const auto& node = *topChildren[0]; + EXPECT_EQ("op", node.name()); + EXPECT_EQ(1, node.callCount()); + EXPECT_GT(node.totalTimeNs(), 0); + EXPECT_GT(node.totalCpuNs(), 0); + EXPECT_EQ(node.totalTimeNs(), node.averageTimeNs()); +} + +TEST_F(HierarchicalTimerTest, multipleCallsAccumulate) { + constexpr int kNumCalls = 5; + for (int i = 0; i < kNumCalls; ++i) { + ScopedTimer timer(tree_, "repeated"); + doWork(); + } + + const auto& node = *tree_.root().children()[0]; + EXPECT_EQ(kNumCalls, node.callCount()); + EXPECT_GT(node.totalTimeNs(), 0); +} + +TEST_F(HierarchicalTimerTest, nestedScopedTimers) { + { + ScopedTimer outer(tree_, "loadStripe"); + doWork(); + { + ScopedTimer inner(tree_, "loadStripe/readIO"); + doWork(); + } + { + ScopedTimer inner(tree_, "loadStripe/decode"); + doWork(); + } + } + + const auto& roots = tree_.root().children(); + ASSERT_EQ(1, roots.size()); + + const auto& loadStripe = *roots[0]; + EXPECT_EQ("loadStripe", loadStripe.name()); + EXPECT_EQ(1, loadStripe.callCount()); + EXPECT_GT(loadStripe.totalTimeNs(), 0); + + const auto& topChildren = loadStripe.children(); + ASSERT_EQ(2, topChildren.size()); + EXPECT_EQ("readIO", topChildren[0]->name()); + EXPECT_EQ("decode", topChildren[1]->name()); + EXPECT_EQ(1, topChildren[0]->callCount()); + EXPECT_EQ(1, topChildren[1]->callCount()); + + const uint64_t childrenSum = + topChildren[0]->totalTimeNs() + topChildren[1]->totalTimeNs(); + EXPECT_GE(loadStripe.totalTimeNs(), childrenSum); +} + +TEST_F(HierarchicalTimerTest, deeplyNestedScopedTimers) { + { + ScopedTimer t1(tree_, "a"); + { + ScopedTimer t2(tree_, "a/b"); + { + ScopedTimer t3(tree_, "a/b/c"); + doWork(); + } + } + } + + const auto& a = *tree_.root().children()[0]; + EXPECT_EQ("a", a.name()); + EXPECT_EQ(1, a.callCount()); + + const auto& b = *a.children()[0]; + EXPECT_EQ("b", b.name()); + EXPECT_EQ(1, b.callCount()); + + const auto& c = *b.children()[0]; + EXPECT_EQ("c", c.name()); + EXPECT_EQ(1, c.callCount()); + EXPECT_GT(c.totalTimeNs(), 0); + + EXPECT_GE(b.totalTimeNs(), c.totalTimeNs()); + EXPECT_GE(a.totalTimeNs(), b.totalTimeNs()); +} + +TEST_F(HierarchicalTimerTest, nestedScopedTimersMultipleIterations) { + constexpr int kIterations = 3; + for (int i = 0; i < kIterations; ++i) { + ScopedTimer outer(tree_, "loop"); + { + ScopedTimer inner(tree_, "loop/work"); + doWork(); + } + } + + const auto& loop = *tree_.root().children()[0]; + EXPECT_EQ(kIterations, loop.callCount()); + + const auto& work = *loop.children()[0]; + EXPECT_EQ(kIterations, work.callCount()); + EXPECT_GE(loop.totalTimeNs(), work.totalTimeNs()); +} + +TEST_F(HierarchicalTimerTest, multipleSameLevelScopedTimers) { + { + ScopedTimer outer(tree_, "process"); + { + ScopedTimer s1(tree_, "process/parse"); + doWork(); + } + { + ScopedTimer s2(tree_, "process/validate"); + doWork(); + } + { + ScopedTimer s3(tree_, "process/transform"); + doWork(); + } + { + ScopedTimer s4(tree_, "process/serialize"); + doWork(); + } + } + + const auto& process = *tree_.root().children()[0]; + EXPECT_EQ("process", process.name()); + EXPECT_EQ(1, process.callCount()); + + const auto& topChildren = process.children(); + ASSERT_EQ(4, topChildren.size()); + + const std::vector kExpectedNames{ + "parse", "validate", "transform", "serialize"}; + for (size_t i = 0; i < topChildren.size(); ++i) { + EXPECT_EQ(kExpectedNames[i], topChildren[i]->name()); + EXPECT_EQ(1, topChildren[i]->callCount()); + EXPECT_GT(topChildren[i]->totalTimeNs(), 0); + } + + uint64_t childrenSum = 0; + for (const auto& child : topChildren) { + childrenSum += child->totalTimeNs(); + } + EXPECT_GE(process.totalTimeNs(), childrenSum); +} + +TEST_F(HierarchicalTimerTest, sameLevelScopedTimersWithNesting) { + { + ScopedTimer outer(tree_, "root"); + { + ScopedTimer s1(tree_, "root/io"); + { + ScopedTimer s1a(tree_, "root/io/read"); + doWork(); + } + { + ScopedTimer s1b(tree_, "root/io/write"); + doWork(); + } + } + { + ScopedTimer s2(tree_, "root/compute"); + { + ScopedTimer s2a(tree_, "root/compute/map"); + doWork(); + } + { + ScopedTimer s2b(tree_, "root/compute/reduce"); + doWork(); + } + } + } + + const auto& rootNode = *tree_.root().children()[0]; + const auto& topChildren = rootNode.children(); + ASSERT_EQ(2, topChildren.size()); + + const auto& io = *topChildren[0]; + EXPECT_EQ("io", io.name()); + ASSERT_EQ(2, io.children().size()); + EXPECT_EQ("read", io.children()[0]->name()); + EXPECT_EQ("write", io.children()[1]->name()); + + const auto& compute = *topChildren[1]; + EXPECT_EQ("compute", compute.name()); + ASSERT_EQ(2, compute.children().size()); + EXPECT_EQ("map", compute.children()[0]->name()); + EXPECT_EQ("reduce", compute.children()[1]->name()); + + const uint64_t ioChildrenSum = + io.children()[0]->totalTimeNs() + io.children()[1]->totalTimeNs(); + EXPECT_GE(io.totalTimeNs(), ioChildrenSum); + + const uint64_t computeChildrenSum = compute.children()[0]->totalTimeNs() + + compute.children()[1]->totalTimeNs(); + EXPECT_GE(compute.totalTimeNs(), computeChildrenSum); +} + +TEST_F(HierarchicalTimerTest, scopedTimerIsNonCopyableNonMovable) { + EXPECT_FALSE(std::is_copy_constructible_v); + EXPECT_FALSE(std::is_copy_assignable_v); + EXPECT_FALSE(std::is_move_constructible_v); + EXPECT_FALSE(std::is_move_assignable_v); +} + +// ===================================================================== +// Tests - toString exact output tests +// ===================================================================== + +TEST_F(HierarchicalTimerTest, toStringFormattedOutputSingleNode) { + auto* node = tree_.getOrCreateNode("op"); + node->addTime(5'000'000); // 5ms + node->incrementCallCount(); + + std::string expected = expectedHeader(); + expected += expectedLine("", "op", "5.00ms", "-", "0ns", "-", "1"); + expected += expectedFooter(); + + EXPECT_EQ(expected, tree_.toString()); +} + +TEST_F(HierarchicalTimerTest, toStringFormattedOutputWithChildren) { + auto* parent = tree_.getOrCreateNode("loadStripe"); + parent->addTime(100'000'000); // 100ms + parent->incrementCallCount(); + + auto* child1 = tree_.getOrCreateNode("loadStripe/readIO"); + child1->addTime(60'000'000); // 60ms + child1->incrementCallCount(); + + auto* child2 = tree_.getOrCreateNode("loadStripe/decode"); + child2->addTime(30'000'000); // 30ms + child2->incrementCallCount(); + + std::string expected = expectedHeader(); + expected += expectedLine("", "loadStripe", "100.00ms", "-", "0ns", "-", "1"); + expected += + expectedLine("├── ", "readIO", "60.00ms", "60.0%", "0ns", "-", "1"); + expected += + expectedLine("├── ", "decode", "30.00ms", "30.0%", "0ns", "-", "1"); + expected += + expectedLine("└── ", "(other)", "10.00ms", "10.0%", "0ns", "-", "-"); + expected += expectedFooter(); + + EXPECT_EQ(expected, tree_.toString()); +} + +TEST_F(HierarchicalTimerTest, toStringFormattedOutputDeepNesting) { + auto* a = tree_.getOrCreateNode("a"); + a->addTime(1'000'000'000); // 1s + a->incrementCallCount(); + + auto* b = tree_.getOrCreateNode("a/b"); + b->addTime(500'000'000); // 500ms + b->incrementCallCount(); + b->incrementCallCount(); + + auto* c = tree_.getOrCreateNode("a/b/c"); + c->addTime(200'000'000); // 200ms + c->incrementCallCount(); + c->incrementCallCount(); + c->incrementCallCount(); + c->incrementCallCount(); + + std::string expected = expectedHeader(); + expected += expectedLine("", "a", "1.00s", "-", "0ns", "-", "1"); + expected += expectedLine("├── ", "b", "500.00ms", "50.0%", "0ns", "-", "2"); + expected += + expectedLine("│ ├── ", "c", "200.00ms", "40.0%", "0ns", "-", "4"); + expected += + expectedLine("│ └── ", "(other)", "300.00ms", "60.0%", "0ns", "-", "-"); + expected += + expectedLine("└── ", "(other)", "500.00ms", "50.0%", "0ns", "-", "-"); + expected += expectedFooter(); + + EXPECT_EQ(expected, tree_.toString()); +} + +TEST_F(HierarchicalTimerTest, toStringFormattedOutputMultipleRoots) { + auto* alpha = tree_.getOrCreateNode("alpha"); + alpha->addTime(10'000'000); // 10ms + alpha->incrementCallCount(); + + auto* beta = tree_.getOrCreateNode("beta"); + beta->addTime(20'000'000); // 20ms + beta->incrementCallCount(); + + std::string expected = expectedHeader(); + expected += expectedLine("", "alpha", "10.00ms", "-", "0ns", "-", "1"); + expected += expectedLine("", "beta", "20.00ms", "-", "0ns", "-", "1"); + expected += expectedFooter(); + + EXPECT_EQ(expected, tree_.toString()); +} + +TEST_F( + HierarchicalTimerTest, + toStringFormattedOutputMultipleSameLevelChildren) { + auto* parent = tree_.getOrCreateNode("pipeline"); + parent->addTime(200'000'000); // 200ms + parent->incrementCallCount(); + + auto* c1 = tree_.getOrCreateNode("pipeline/parse"); + c1->addTime(50'000'000); // 50ms + c1->incrementCallCount(); + + auto* c2 = tree_.getOrCreateNode("pipeline/validate"); + c2->addTime(30'000'000); // 30ms + c2->incrementCallCount(); + + auto* c3 = tree_.getOrCreateNode("pipeline/transform"); + c3->addTime(80'000'000); // 80ms + c3->incrementCallCount(); + + auto* c4 = tree_.getOrCreateNode("pipeline/emit"); + c4->addTime(20'000'000); // 20ms + c4->incrementCallCount(); + + std::string expected = expectedHeader(); + expected += expectedLine("", "pipeline", "200.00ms", "-", "0ns", "-", "1"); + expected += + expectedLine("├── ", "parse", "50.00ms", "25.0%", "0ns", "-", "1"); + expected += + expectedLine("├── ", "validate", "30.00ms", "15.0%", "0ns", "-", "1"); + expected += + expectedLine("├── ", "transform", "80.00ms", "40.0%", "0ns", "-", "1"); + expected += expectedLine("├── ", "emit", "20.00ms", "10.0%", "0ns", "-", "1"); + expected += + expectedLine("└── ", "(other)", "20.00ms", "10.0%", "0ns", "-", "-"); + expected += expectedFooter(); + + EXPECT_EQ(expected, tree_.toString()); +} + +// ===================================================================== +// Tests - Standalone tests +// ===================================================================== + +TEST(HierarchicalTimerStandaloneTest, noEntriesNoOutput) { + LogCapture capture; + { + TimerTree tree; + } + EXPECT_TRUE(capture.captured().empty()) << capture.captured(); +} + +TEST(HierarchicalTimerStandaloneTest, destructorPrintsTitleAndResults) { + LogCapture capture; + + { + TimerTree tree("MyBenchmarkTitle"); + { + ScopedTimer t(tree, "sessionTimer"); + doWork(); + } + } // Destructor prints the table. + + const auto& output = capture.captured(); + EXPECT_NE(output.find("MyBenchmarkTitle"), std::string::npos) << output; + EXPECT_NE(output.find("sessionTimer"), std::string::npos) << output; + EXPECT_NE(output.find("HIERARCHICAL TIMING BREAKDOWN"), std::string::npos) + << output; +} + +TEST(HierarchicalTimerStandaloneTest, printStatsClearsEntries) { + LogCapture capture1; + TimerTree tree("Session"); + + { + ScopedTimer t(tree, "firstTimer"); + doWork(); + } + tree.printStats(); + + const auto& output1 = capture1.captured(); + EXPECT_NE(output1.find("Session"), std::string::npos) << output1; + EXPECT_NE(output1.find("firstTimer"), std::string::npos) << output1; + + // After printStats(), entries should be cleared. + // Destructor should produce no further output. + LogCapture capture2; + tree.printStats(); + EXPECT_TRUE(capture2.captured().empty()) << capture2.captured(); +} + +// ===================================================================== +// Tests - min/max tracking and CPU time +// ===================================================================== + +TEST_F(HierarchicalTimerTest, minMaxTracking) { + auto* node = tree_.getOrCreateNode("tracked"); + node->addTime(100); + node->incrementCallCount(); + node->addTime(500); + node->incrementCallCount(); + node->addTime(200); + node->incrementCallCount(); + + EXPECT_EQ(800, node->totalTimeNs()); + EXPECT_EQ(3, node->callCount()); + EXPECT_EQ(100, node->minTimeNs()); + EXPECT_EQ(500, node->maxTimeNs()); + EXPECT_EQ(266, node->averageTimeNs()); // 800 / 3 = 266 +} + +TEST_F(HierarchicalTimerTest, toStringIncludesWallCpuColumns) { + auto* node = tree_.getOrCreateNode("op"); + node->addTime(5'000'000); // 5ms + node->incrementCallCount(); + + const auto output = tree_.toString(); + EXPECT_NE(output.find("Wall"), std::string::npos) << output; + EXPECT_NE(output.find("CPU"), std::string::npos) << output; + EXPECT_NE(output.find("Wall %"), std::string::npos) << output; + EXPECT_NE(output.find("CPU %"), std::string::npos) << output; + + // Verbose columns should not appear in compact mode. + const auto verboseOutput = tree_.toString(true); + EXPECT_NE(verboseOutput.find("Min Wall"), std::string::npos) << verboseOutput; + EXPECT_NE(verboseOutput.find("Max Wall"), std::string::npos) << verboseOutput; + EXPECT_NE(verboseOutput.find("Avg CPU"), std::string::npos) << verboseOutput; + EXPECT_NE(verboseOutput.find("Min CPU"), std::string::npos) << verboseOutput; + EXPECT_NE(verboseOutput.find("Max CPU"), std::string::npos) << verboseOutput; +} + +TEST_F(HierarchicalTimerTest, toStringWithTitle) { + TimerTree namedTree("MyBench"); + auto* node = namedTree.getOrCreateNode("op"); + node->addTime(1'000'000); + node->incrementCallCount(); + + const auto output = namedTree.toString(); + EXPECT_NE(output.find("MyBench"), std::string::npos) << output; + EXPECT_NE(output.find("HIERARCHICAL TIMING BREAKDOWN"), std::string::npos) + << output; +} + +TEST_F(HierarchicalTimerTest, resetAlsoResetsMinMaxAndCpu) { + auto* node = tree_.getOrCreateNode("tracked"); + node->addTime(100); + node->addCpuTime(50); + node->incrementCallCount(); + + EXPECT_EQ(100, node->minTimeNs()); + EXPECT_EQ(100, node->maxTimeNs()); + EXPECT_EQ(50, node->totalCpuNs()); + EXPECT_EQ(50, node->minCpuNs()); + EXPECT_EQ(50, node->maxCpuNs()); + + tree_.reset(); + + EXPECT_EQ(std::numeric_limits::max(), node->minTimeNs()); + EXPECT_EQ(0, node->maxTimeNs()); + EXPECT_EQ(0, node->totalCpuNs()); + EXPECT_EQ(std::numeric_limits::max(), node->minCpuNs()); + EXPECT_EQ(0, node->maxCpuNs()); +} + +TEST_F(HierarchicalTimerTest, cpuPerCallStats) { + auto* node = tree_.getOrCreateNode("tracked"); + node->addCpuTime(100); + node->addTime(200); + node->incrementCallCount(); + node->addCpuTime(500); + node->addTime(600); + node->incrementCallCount(); + node->addCpuTime(300); + node->addTime(400); + node->incrementCallCount(); + + EXPECT_EQ(900, node->totalCpuNs()); + EXPECT_EQ(3, node->callCount()); + EXPECT_EQ(300, node->averageCpuNs()); // 900 / 3 + EXPECT_EQ(100, node->minCpuNs()); + EXPECT_EQ(500, node->maxCpuNs()); +} + +TEST_F(HierarchicalTimerTest, toStringFormattedOutputWithCpuPerCall) { + auto* node = tree_.getOrCreateNode("op"); + node->addTime(10'000'000); // 10ms wall + node->addCpuTime(5'000'000); // 5ms CPU + node->incrementCallCount(); + node->addTime(20'000'000); // 20ms wall + node->addCpuTime(15'000'000); // 15ms CPU + node->incrementCallCount(); + + // Compact mode: no avg/min/max columns. + std::string expected = expectedHeader(); + expected += expectedLine("", "op", "30.00ms", "-", "20.00ms", "-", "2"); + expected += expectedFooter(); + + EXPECT_EQ(expected, tree_.toString()); +} + +TEST_F(HierarchicalTimerTest, toStringVerboseOutput) { + auto* node = tree_.getOrCreateNode("op"); + node->addTime(10'000'000); // 10ms wall + node->addCpuTime(5'000'000); // 5ms CPU + node->incrementCallCount(); + node->addTime(20'000'000); // 20ms wall + node->addCpuTime(15'000'000); // 15ms CPU + node->incrementCallCount(); + + std::string expected = expectedHeader("", true); + expected += expectedVerboseLine( + "", + "op", + "30.00ms", // total wall + "-", // wall % + "20.00ms", // total CPU + "-", // cpu % + "2", + "15.00ms", // avg wall: 30/2 + "10.00ms", // min wall + "20.00ms", // max wall + "10.00ms", // avg CPU: 20/2 + "5.00ms", // min CPU + "15.00ms"); // max CPU + expected += expectedFooter(true); + + EXPECT_EQ(expected, tree_.toString(true)); +} + +TEST_F(HierarchicalTimerTest, cpuNowReturnsNonZero) { + EXPECT_GT(tree_.cpuNow(), 0); +} + +TEST_F(HierarchicalTimerTest, cpuTimeTrackedByScopedTimer) { + { + ScopedTimer t(tree_, "cpuWork"); + doWork(); + } + + const auto& roots = tree_.root().children(); + ASSERT_EQ(1, roots.size()); + EXPECT_GT(roots[0]->totalTimeNs(), 0); + EXPECT_GT(roots[0]->totalCpuNs(), 0); +} + +// ===================================================================== +// Thread-local TimerTree tests +// ===================================================================== + +TEST(HierarchicalTimerThreadLocalTest, scopedTimerUsesThreadInstance) { + auto& tree = TimerTree::threadInstance(); + tree.clear(); + + { + ScopedTimer t("threadLocalOp"); + doWork(); + } + + const auto& roots = tree.root().children(); + ASSERT_EQ(1, roots.size()); + EXPECT_EQ("threadLocalOp", roots[0]->name()); + EXPECT_EQ(1, roots[0]->callCount()); + EXPECT_GT(roots[0]->totalTimeNs(), 0); + EXPECT_GT(roots[0]->totalCpuNs(), 0); + tree.clear(); +} + +TEST(HierarchicalTimerThreadLocalTest, threadInstanceIsSameAcrossCalls) { + auto& tree1 = TimerTree::threadInstance(); + auto& tree2 = TimerTree::threadInstance(); + EXPECT_EQ(&tree1, &tree2); +} + +// Simulates two different classes recording into the same thread-local tree +// using stack-based auto-nesting. +namespace { +void classAWork() { + ScopedTimer t1("classA"); + { + ScopedTimer t2("process"); + doWork(); + } +} + +void classBWork() { + ScopedTimer t1("classB"); + { + ScopedTimer t2("compute"); + doWork(); + } +} +} // namespace + +TEST(HierarchicalTimerThreadLocalTest, sharedAcrossClasses) { + auto& tree = TimerTree::threadInstance(); + tree.clear(); + + classAWork(); + classBWork(); + classAWork(); + + const auto& roots = tree.root().children(); + ASSERT_EQ(2, roots.size()); + EXPECT_EQ("classA", roots[0]->name()); + EXPECT_EQ("classB", roots[1]->name()); + + EXPECT_EQ(2, roots[0]->callCount()); + const auto& classAChildren = roots[0]->children(); + ASSERT_EQ(1, classAChildren.size()); + EXPECT_EQ("process", classAChildren[0]->name()); + EXPECT_EQ(2, classAChildren[0]->callCount()); + + EXPECT_EQ(1, roots[1]->callCount()); + const auto& classBChildren = roots[1]->children(); + ASSERT_EQ(1, classBChildren.size()); + EXPECT_EQ("compute", classBChildren[0]->name()); + EXPECT_EQ(1, classBChildren[0]->callCount()); + tree.clear(); +} + +TEST(HierarchicalTimerThreadLocalTest, nestedThreadLocalTimers) { + auto& tree = TimerTree::threadInstance(); + tree.clear(); + + { + ScopedTimer outer("pipeline"); + { + ScopedTimer inner("step1"); + doWork(); + } + { + ScopedTimer inner("step2"); + doWork(); + } + } + + const auto& roots = tree.root().children(); + ASSERT_EQ(1, roots.size()); + EXPECT_EQ("pipeline", roots[0]->name()); + EXPECT_EQ(1, roots[0]->callCount()); + ASSERT_EQ(2, roots[0]->children().size()); + EXPECT_EQ("step1", roots[0]->children()[0]->name()); + EXPECT_EQ("step2", roots[0]->children()[1]->name()); + tree.clear(); +} + +// ===================================================================== +// Tests - Stack-based auto-nesting +// ===================================================================== + +TEST(HierarchicalTimerAutoNestingTest, basicAutoNesting) { + auto& tree = TimerTree::threadInstance(); + tree.clear(); + + { + ScopedTimer t1("benchmark"); + { + ScopedTimer t2("read"); + { + ScopedTimer t3("decode"); + doWork(); + } + } + } + + const auto& roots = tree.root().children(); + ASSERT_EQ(1, roots.size()); + EXPECT_EQ("benchmark", roots[0]->name()); + EXPECT_EQ(1, roots[0]->callCount()); + + const auto& readChildren = roots[0]->children(); + ASSERT_EQ(1, readChildren.size()); + EXPECT_EQ("read", readChildren[0]->name()); + EXPECT_EQ(1, readChildren[0]->callCount()); + + const auto& decodeChildren = readChildren[0]->children(); + ASSERT_EQ(1, decodeChildren.size()); + EXPECT_EQ("decode", decodeChildren[0]->name()); + EXPECT_EQ(1, decodeChildren[0]->callCount()); + EXPECT_GT(decodeChildren[0]->totalTimeNs(), 0); + + EXPECT_GE(readChildren[0]->totalTimeNs(), decodeChildren[0]->totalTimeNs()); + EXPECT_GE(roots[0]->totalTimeNs(), readChildren[0]->totalTimeNs()); + tree.clear(); +} + +TEST(HierarchicalTimerAutoNestingTest, autoNestingRestoresActiveNode) { + auto& tree = TimerTree::threadInstance(); + tree.clear(); + + { + ScopedTimer t1("root"); + { + ScopedTimer t2("child1"); + doWork(); + } + // After child1 is destroyed, active node should be restored to "root". + { + ScopedTimer t3("child2"); + doWork(); + } + } + + const auto& roots = tree.root().children(); + ASSERT_EQ(1, roots.size()); + EXPECT_EQ("root", roots[0]->name()); + + // Both children should be siblings under "root", not nested. + const auto& children = roots[0]->children(); + ASSERT_EQ(2, children.size()); + EXPECT_EQ("child1", children[0]->name()); + EXPECT_EQ("child2", children[1]->name()); + tree.clear(); +} + +TEST(HierarchicalTimerAutoNestingTest, sequentialTopLevelTimers) { + auto& tree = TimerTree::threadInstance(); + tree.clear(); + + { + ScopedTimer t1("alpha"); + doWork(); + } + { + ScopedTimer t2("beta"); + doWork(); + } + + // Sequential (non-nested) timers should create separate top-level entries. + const auto& roots = tree.root().children(); + ASSERT_EQ(2, roots.size()); + EXPECT_EQ("alpha", roots[0]->name()); + EXPECT_EQ("beta", roots[1]->name()); + tree.clear(); +} + +TEST(HierarchicalTimerAutoNestingTest, autoNestingCpuTimeTracked) { + auto& tree = TimerTree::threadInstance(); + tree.clear(); + + { + ScopedTimer t1("outer"); + { + ScopedTimer t2("inner"); + doWork(); + } + } + + const auto& roots = tree.root().children(); + ASSERT_EQ(1, roots.size()); + EXPECT_GT(roots[0]->totalCpuNs(), 0); + EXPECT_GT(roots[0]->children()[0]->totalCpuNs(), 0); + tree.clear(); +} + +} // namespace facebook::velox::test diff --git a/velox/connectors/CMakeLists.txt b/velox/connectors/CMakeLists.txt index 3fd17dde2d2..cb174e8bad4 100644 --- a/velox/connectors/CMakeLists.txt +++ b/velox/connectors/CMakeLists.txt @@ -11,9 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -velox_add_library(velox_connector Connector.cpp) +velox_add_library(velox_connector Connector.cpp HEADERS Connector.h ConnectorRegistryInternal.h) -velox_link_libraries(velox_connector velox_common_config velox_vector) +velox_link_libraries(velox_connector velox_common_config velox_exec_spill_stats velox_vector) + +velox_add_library(velox_connector_registry ConnectorRegistry.cpp HEADERS ConnectorRegistry.h) + +velox_link_libraries(velox_connector_registry velox_connector velox_scoped_registry velox_core) add_subdirectory(fuzzer) diff --git a/velox/connectors/Connector.cpp b/velox/connectors/Connector.cpp index e7c72af478b..25912dc717c 100644 --- a/velox/connectors/Connector.cpp +++ b/velox/connectors/Connector.cpp @@ -16,92 +16,52 @@ #include "velox/connectors/Connector.h" -namespace facebook::velox::connector { -namespace { -std::unordered_map>& -connectorFactories() { - static std::unordered_map> - factories; - return factories; -} +#include "velox/common/EnumDefine.h" -std::unordered_map>& connectors() { - static std::unordered_map> connectors; - return connectors; -} -} // namespace +#include +#include -bool DataSink::Stats::empty() const { - return numWrittenBytes == 0 && numWrittenFiles == 0 && spillStats.empty(); -} +#include "velox/common/ScopedRegistry.h" +#include "velox/common/base/Exceptions.h" +#include "velox/connectors/ConnectorRegistryInternal.h" -std::string DataSink::Stats::toString() const { - return fmt::format( - "numWrittenBytes {} numWrittenFiles {} {}", - succinctBytes(numWrittenBytes), - numWrittenFiles, - spillStats.toString()); -} - -bool registerConnectorFactory(std::shared_ptr factory) { - bool ok = - connectorFactories().insert({factory->connectorName(), factory}).second; - VELOX_CHECK( - ok, - "ConnectorFactory with name '{}' is already registered", - factory->connectorName()); - return true; -} - -bool hasConnectorFactory(const std::string& connectorName) { - return connectorFactories().count(connectorName) == 1; -} - -bool unregisterConnectorFactory(const std::string& connectorName) { - auto count = connectorFactories().erase(connectorName); - return count == 1; -} +namespace facebook::velox::connector { -std::shared_ptr getConnectorFactory( - const std::string& connectorName) { - auto it = connectorFactories().find(connectorName); - VELOX_CHECK( - it != connectorFactories().end(), - "ConnectorFactory with name '{}' not registered", - connectorName); - return it->second; +ScopedRegistry& connectors() { + static ScopedRegistry instance; + return instance; } -bool registerConnector(std::shared_ptr connector) { - bool ok = connectors().insert({connector->connectorId(), connector}).second; - VELOX_CHECK( - ok, - "Connector with ID '{}' is already registered", - connector->connectorId()); +bool registerConnector(const std::shared_ptr& connector) { + connectors().insert(connector->connectorId(), connector); return true; } bool unregisterConnector(const std::string& connectorId) { - auto count = connectors().erase(connectorId); - return count == 1; + return connectors().erase(connectorId); } std::shared_ptr getConnector(const std::string& connectorId) { - auto it = connectors().find(connectorId); - VELOX_CHECK( - it != connectors().end(), - "Connector with ID '{}' not registered", - connectorId); - return it->second; + auto connector = connectors().find(connectorId); + VELOX_CHECK_NOT_NULL( + connector, "Connector with ID is not registered: {}", connectorId); + return connector; } bool hasConnector(const std::string& connectorId) { - return connectors().find(connectorId) != connectors().end(); + return connectors().find(connectorId) != nullptr; +} + +bool DataSink::Stats::empty() const { + return numWrittenBytes == 0 && numWrittenFiles == 0 && spillStats.empty(); } -const std::unordered_map>& -getAllConnectors() { - return connectors(); +std::string DataSink::Stats::toString() const { + return fmt::format( + "numWrittenBytes {} numWrittenFiles {} {}", + succinctBytes(numWrittenBytes), + numWrittenFiles, + spillStats.toString()); } folly::Synchronized< diff --git a/velox/connectors/Connector.h b/velox/connectors/Connector.h index 10cf82b342c..72ad3dc4b1e 100644 --- a/velox/connectors/Connector.h +++ b/velox/connectors/Connector.h @@ -16,17 +16,20 @@ #pragma once #include "folly/CancellationToken.h" -#include "velox/common/Enums.h" +#include "velox/common/EnumDeclare.h" #include "velox/common/base/AsyncSource.h" #include "velox/common/base/PrefixSortConfig.h" #include "velox/common/base/RuntimeMetrics.h" #include "velox/common/base/SpillConfig.h" -#include "velox/common/base/SpillStats.h" #include "velox/common/caching/AsyncDataCache.h" #include "velox/common/caching/ScanTracker.h" +#include "velox/common/config/ConfigProvider.h" #include "velox/common/file/TokenProvider.h" #include "velox/common/future/VeloxPromise.h" #include "velox/core/ExpressionEvaluator.h" +#include "velox/core/QueryConfig.h" +#include "velox/core/ScanBatchEvent.h" +#include "velox/exec/SpillStats.h" #include "velox/type/Filter.h" #include "velox/vector/ComplexVector.h" @@ -44,7 +47,7 @@ class ConfigBase; namespace facebook::velox::core { class ITypedExpr; -} +} // namespace facebook::velox::core namespace facebook::velox::core { struct IndexLookupCondition; @@ -80,7 +83,11 @@ struct ConnectorSplit : public ISerializable { return 0; } - virtual ~ConnectorSplit() {} + virtual ~ConnectorSplit() { + if (dataSource) { + dataSource->close(); + } + } virtual std::string toString() const { return fmt::format( @@ -134,6 +141,16 @@ class ConnectorTableHandle : public ISerializable { return false; } + /// Returns true if this table handle requires splits for index lookup. + /// Default implementation returns false. Subclasses can override this to + /// indicate that splits need to be provided to the index source before + /// performing lookups. + /// + /// NOTE: this only applies if supportsIndexLookup() returns true. + virtual bool needsIndexSplit() const { + return false; + } + virtual std::string toString() const { return name(); } @@ -191,7 +208,7 @@ class DataSink { uint64_t recodeTimeNs{0}; uint64_t compressionTimeNs{0}; - common::SpillStats spillStats; + exec::SpillStats spillStats; bool empty() const; @@ -266,7 +283,20 @@ class DataSource { /// Returns the number of input rows processed so far. virtual uint64_t getCompletedRows() = 0; - virtual std::unordered_map runtimeStats() = 0; + /// Stores a callback to fire after each scan batch. + void setScanBatchCallback(core::ScanBatchCallback callback) { + scanBatchCallback_ = std::move(callback); + } + + /// Called by TableScan after each non-empty batch with generic scan stats. + /// Default is a no-op. Subclasses should override to create a + /// connector-specific event (e.g., FileScanBatchEvent), enrich it with + /// connector-specific fields, and call scanBatchCallback_. + virtual void fireScanBatchCallback(core::ScanBatchEvent /*event*/) {} + + virtual std::unordered_map getRuntimeStats() { + return {}; + } /// Returns true if 'this' has initiated all the prefetch this will initiate. /// This means that the caller should schedule next splits to prefetch in the @@ -308,33 +338,46 @@ class DataSource { /// connector implementation decides how to support the cancellation if /// needed. virtual void cancel() {} + + protected: + core::ScanBatchCallback scanBatchCallback_; }; class IndexSource { public: virtual ~IndexSource() = default; + /// Adds splits to the index source for lookup. This is called when + /// the table handle's needsIndexSplit() returns true. This method must be + /// called before the first call to lookup(). This method is expected to be + /// called only once. Default implementation throws as most index sources + /// don't require splits. + virtual void addSplits( + std::vector> /*splits*/) { + VELOX_UNSUPPORTED("This IndexSource does not support splits"); + } + /// Represents a lookup request for a given input. - struct LookupRequest { + struct Request { /// Contains the input column vectors used by lookup join and range /// conditions. RowVectorPtr input; - explicit LookupRequest(RowVectorPtr input) : input(std::move(input)) {} + explicit Request(RowVectorPtr input) : input(std::move(input)) {} }; /// Represents the lookup result for a subset of input produced by the - /// 'LookupResultIterator'. - struct LookupResult { + /// 'ResultIterator'. + struct Result { /// Specifies the indices of input row in the lookup request that have /// matches in 'output'. It contains the input indices in the order /// of the input rows in the lookup request. Any gap in the indices means /// the input rows that has no matches in output. /// /// Example: - /// LookupRequest: input = [0, 1, 2, 3, 4] - /// LookupResult: inputHits = [0, 0, 2, 2, 3, 4, 4, 4] - /// output = [0, 1, 2, 3, 4, 5, 6, 7] + /// Request: input = [0, 1, 2, 3, 4] + /// Result: inputHits = [0, 0, 2, 2, 3, 4, 4, 4] + /// output = [0, 1, 2, 3, 4, 5, 6, 7] /// /// Here is match results for each input row: /// input row #0: match with output rows #0 and #1. @@ -343,7 +386,7 @@ class IndexSource { /// input row #3: match with output row #4. /// input row #4: match with output rows #5, #6 and #7. /// - /// 'LookupResultIterator' must also produce the output result in order of + /// 'ResultIterator' must also produce the output result in order of /// input rows. BufferPtr inputHits; @@ -354,7 +397,7 @@ class IndexSource { return output->size(); } - LookupResult(BufferPtr _inputHits, RowVectorPtr _output) + Result(BufferPtr _inputHits, RowVectorPtr _output) : inputHits(std::move(_inputHits)), output(std::move(_output)) { VELOX_CHECK_EQ(inputHits->size() / sizeof(vector_size_t), output->size()); } @@ -362,24 +405,28 @@ class IndexSource { /// The lookup result iterator used to fetch the lookup result in batch for a /// given lookup request. - class LookupResultIterator { + class ResultIterator { public: - virtual ~LookupResultIterator() = default; + virtual ~ResultIterator() = default; + + /// Invoked to check if there are more lookup results available to fetch. + /// Returns true if there are more results, false otherwise. This allows + /// the caller to determine whether to continue calling 'next()'. + virtual bool hasNext() = 0; /// Invoked to fetch up to 'size' number of output rows. Returns nullptr if /// all the lookup results have been fetched. Returns std::nullopt and sets /// the 'future' if started asynchronous work and needs to wait for it to /// complete to continue processing. The caller will wait for the 'future' /// to complete before calling 'next' again. - virtual std::optional> next( + virtual std::optional> next( vector_size_t size, velox::ContinueFuture& future) { VELOX_UNSUPPORTED(); } }; - virtual std::shared_ptr lookup( - const LookupRequest& request) = 0; + virtual std::shared_ptr lookup(const Request& request) = 0; virtual std::unordered_map runtimeStats() = 0; }; @@ -501,14 +548,24 @@ class ConnectorQueryCtx { return cancellationToken_; } + /// Deprecated: Use FileConfig::kSelectiveNimbleReaderEnabledSession instead. bool selectiveNimbleReaderEnabled() const { return selectiveNimbleReaderEnabled_; } + /// Deprecated: Use connector session properties instead. void setSelectiveNimbleReaderEnabled(bool value) { selectiveNimbleReaderEnabled_ = value; } + core::QueryConfig::RowSizeTrackingMode rowSizeTrackingMode() const { + return rowSizeTrackingEnabled_; + } + + void setRowSizeTrackingMode(core::QueryConfig::RowSizeTrackingMode value) { + rowSizeTrackingEnabled_ = value; + } + std::shared_ptr fsTokenProvider() const { return fsTokenProvider_; } @@ -531,6 +588,30 @@ class ConnectorQueryCtx { const folly::CancellationToken cancellationToken_; const std::shared_ptr fsTokenProvider_; bool selectiveNimbleReaderEnabled_{false}; + core::QueryConfig::RowSizeTrackingMode rowSizeTrackingEnabled_{ + core::QueryConfig::RowSizeTrackingMode::ENABLED_FOR_ALL}; +}; + +class Connector; + +class ConnectorFactory { + public: + explicit ConnectorFactory(const char* name) : name_(name) {} + + virtual ~ConnectorFactory() = default; + + const std::string& connectorName() const { + return name_; + } + + virtual std::shared_ptr newConnector( + const std::string& id, + std::shared_ptr config, + folly::Executor* ioExecutor = nullptr, + folly::Executor* cpuExecutor = nullptr) = 0; + + private: + const std::string name_; }; class Connector { @@ -550,6 +631,12 @@ class Connector { return config_; } + /// Returns the config provider for this connector's session properties, + /// or nullptr if the connector has no session-overridable properties. + virtual const config::ConfigProvider* configProvider() const { + return nullptr; + } + /// Returns true if this connector would accept a filter dynamically /// generated during query execution. virtual bool canAddDynamicFilter() const { @@ -566,15 +653,9 @@ class Connector { /// ConnectorSplit in addSplit(). If so, TableScan can preload splits /// so that file opening and metadata operations are off the Driver' /// thread. -#ifdef VELOX_ENABLE_BACKWARD_COMPATIBILITY - virtual bool supportsSplitPreload() { - return false; - } -#else virtual bool supportsSplitPreload() const { return false; } -#endif /// Returns true if the connector supports index lookup, otherwise false. virtual bool supportsIndexLookup() const { @@ -582,14 +663,14 @@ class Connector { } /// Creates index source for index join lookup. - /// @param inputType The list of probe-side columns that either used in - /// equi-clauses or join conditions. - /// @param numJoinKeys The number of key columns used in join equi-clauses. - /// The first 'numJoinKeys' columns in 'inputType' form a prefix of the - /// index, and the rest of the columns in inputType are expected to be used in - /// 'joinConditions'. - /// @param joinConditions The join conditions. It expects inputs columns from - /// the 'tail' of 'inputType' and from 'columnHandles'. + /// @param inputType The list of probe-side columns used in join conditions. + /// @param joinConditions The join conditions that specify how to perform the + /// index lookup. This includes: + /// - EqualIndexLookupCondition: For equi-join conditions. + /// - InIndexLookupCondition: For IN-list conditions. + /// - BetweenIndexLookupCondition: For range conditions. + /// The index source can determine which columns form the index prefix by + /// examining EqualIndexLookupCondition objects where !isFilter(). /// @param outputType The lookup output type from index source. /// @param tableHandle The index table handle. /// @param columnHandles The column handles which maps from column name @@ -607,9 +688,12 @@ class Connector { /// /// Here, /// - 'inputType' is ROW{t.sid, t.event_list} - /// - 'numJoinKeys' is 1 since only t.sid is used in join equi-clauses. - /// - 'joinConditions' specifies the join condition: contains(t.event_list, - /// u.event_type) + /// - 'joinConditions' includes: + /// - EqualIndexLookupCondition(u.sid, t.sid) for the equi-join + /// - InIndexLookupCondition(u.event_type, t.event_list) for the IN + /// condition + /// - BetweenIndexLookupCondition(u.ds, '2024-01-01', '2024-01-07') for the + /// BETWEEN condition /// - 'outputType' is ROW{u.event_value} /// - 'tableHandle' specifies the metadata of the index table. /// - 'columnHandles' is a map from 'u.event_type' (in 'joinConditions') and @@ -619,7 +703,6 @@ class Connector { /// virtual std::shared_ptr createIndexSource( const RowTypePtr& inputType, - size_t numJoinKeys, const std::vector>& joinConditions, const RowTypePtr& outputType, @@ -658,9 +741,37 @@ class Connector { /// The name of the common runtime stats collected and reported by connector /// data/index sources. - static inline const std::string kTotalRemainingFilterTime{ + static constexpr std::string_view kTotalRemainingFilterTime{ "totalRemainingFilterWallNanos"}; + /// Total CPU time spent on remaining filter evaluation. + static inline const std::string kTotalRemainingFilterCpuTime{ + "totalRemainingFilterCpuNanos"}; + + /// Total time spent waiting for synchronously issued IO or for an in-progress + /// read-ahead to finish. + static constexpr std::string_view kIoWaitWallNanos{"ioWaitWallNanos"}; + + /// Time spent waiting for remote storage reads (S3, HDFS, etc.) + static constexpr std::string_view kStorageReadWallNanos{ + "storageReadWallNanos"}; + + /// Time spent waiting for SSD cache reads. + static constexpr std::string_view kSsdCacheReadWallNanos{ + "ssdCacheReadWallNanos"}; + + /// Time spent waiting for EXCLUSIVE cache entries (another thread is + /// loading). + static constexpr std::string_view kCacheWaitWallNanos{"cacheWaitWallNanos"}; + + /// Time spent waiting for coalesced loads from SSD cache. + static constexpr std::string_view kCoalescedSsdLoadWallNanos{ + "coalescedSsdLoadWallNanos"}; + + /// Time spent waiting for coalesced loads from remote storage. + static constexpr std::string_view kCoalescedStorageLoadWallNanos{ + "coalescedStorageLoadWallNanos"}; + private: static void unregisterTracker(cache::ScanTracker* tracker); @@ -672,66 +783,18 @@ class Connector { trackers_; }; -class ConnectorFactory { - public: - explicit ConnectorFactory(const char* name) : name_(name) {} - - virtual ~ConnectorFactory() = default; - - const std::string& connectorName() const { - return name_; - } - - virtual std::shared_ptr newConnector( - const std::string& id, - std::shared_ptr config, - folly::Executor* ioExecutor = nullptr, - folly::Executor* cpuExecutor = nullptr) = 0; +/// Deprecated free functions. Use ConnectorRegistry methods instead. - private: - const std::string name_; -}; +[[deprecated("Use ConnectorRegistry::global().insert() instead.")]] +bool registerConnector(const std::shared_ptr& connector); -/// Adds a factory for creating connectors to the registry using connector -/// name as the key. Throws if factor with the same name is already present. -/// Always returns true. The return value makes it easy to use with -/// FB_ANONYMOUS_VARIABLE. -bool registerConnectorFactory(std::shared_ptr factory); - -/// Returns true if a connector with the specified name has been registered, -/// false otherwise. -bool hasConnectorFactory(const std::string& connectorName); - -/// Unregister a connector factory by name. -/// Returns true if a connector with the specified name has been -/// unregistered, false otherwise. -bool unregisterConnectorFactory(const std::string& connectorName); - -/// Returns a factory for creating connectors with the specified name. -/// Throws if factory doesn't exist. -std::shared_ptr getConnectorFactory( - const std::string& connectorName); - -/// Adds connector instance to the registry using connector ID as the key. -/// Throws if connector with the same ID is already present. Always returns -/// true. The return value makes it easy to use with FB_ANONYMOUS_VARIABLE. -bool registerConnector(std::shared_ptr connector); - -/// Returns true if a connector with the specified ID has been registered, false -/// otherwise. +[[deprecated("Use ConnectorRegistry::tryGet() instead.")]] bool hasConnector(const std::string& connectorId); -/// Removes the connector with specified ID from the registry. Returns true -/// if connector was removed and false if connector didn't exist. +[[deprecated("Use ConnectorRegistry::global().erase() instead.")]] bool unregisterConnector(const std::string& connectorId); -/// Returns a connector with specified ID. Throws if connector doesn't -/// exist. +[[deprecated("Use ConnectorRegistry::tryGet() instead.")]] std::shared_ptr getConnector(const std::string& connectorId); -/// Returns a map of all (connectorId -> connector) pairs currently -/// registered. -const std::unordered_map>& -getAllConnectors(); - } // namespace facebook::velox::connector diff --git a/velox/connectors/ConnectorRegistry.cpp b/velox/connectors/ConnectorRegistry.cpp new file mode 100644 index 00000000000..23155a435a1 --- /dev/null +++ b/velox/connectors/ConnectorRegistry.cpp @@ -0,0 +1,84 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/ConnectorRegistry.h" + +#include +#include +#include +#include + +#include "velox/connectors/Connector.h" +#include "velox/connectors/ConnectorRegistryInternal.h" +#include "velox/core/QueryCtx.h" + +namespace facebook::velox::connector { + +namespace { + +ConnectorRegistry::Registry& registryFor(const core::QueryCtx& queryCtx) { + auto registry = queryCtx.registry( + ConnectorRegistry::kRegistryKey); + return registry ? *registry : ConnectorRegistry::global(); +} + +} // namespace + +// static +ConnectorRegistry::Registry& ConnectorRegistry::global() { + return connectors(); +} + +// static +std::shared_ptr ConnectorRegistry::create( + const Registry* parent) { + return std::make_shared(parent); +} + +// static +std::shared_ptr ConnectorRegistry::tryGet( + const core::QueryCtx& queryCtx, + const std::string& connectorId) { + return registryFor(queryCtx).find(connectorId); +} + +// static +std::shared_ptr ConnectorRegistry::tryGet( + const std::string& connectorId) { + return global().find(connectorId); +} + +// static +void ConnectorRegistry::unregisterAll(const core::QueryCtx& queryCtx) { + auto registry = queryCtx.registry( + ConnectorRegistry::kRegistryKey); + if (registry) { + registry->clear(); + } +} + +// static +void ConnectorRegistry::unregisterAll() { + global().clear(); +} + +// static +std::vector>> +ConnectorRegistry::snapshot(const core::QueryCtx& queryCtx) { + return registryFor(queryCtx).snapshot(); +} + +} // namespace facebook::velox::connector diff --git a/velox/connectors/ConnectorRegistry.h b/velox/connectors/ConnectorRegistry.h new file mode 100644 index 00000000000..c2a61ec81d4 --- /dev/null +++ b/velox/connectors/ConnectorRegistry.h @@ -0,0 +1,112 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include "velox/common/ScopedRegistry.h" +#include "velox/connectors/Connector.h" + +namespace facebook::velox::core { +class QueryCtx; +} // namespace facebook::velox::core + +namespace facebook::velox::connector { + +/// Manages connector registration and lookup. All methods are thread-safe. +/// +/// Two groups of APIs: +/// +/// - Query-scoped APIs take a QueryCtx& and check for per-query registry +/// overrides before falling back to the global registry. Use these in +/// operator and expression evaluation code where a QueryCtx is available. +/// +/// - Global APIs operate directly on the global registry. Use these for +/// process-level operations: startup registration, shutdown cleanup, and +/// process-wide lookups (e.g., periodic stats reporting). +/// All methods are thread-safe. +class ConnectorRegistry { + public: + using Registry = ScopedRegistry; + + /// Registry key for per-query connector overrides on QueryCtx. + static constexpr std::string_view kRegistryKey = "connectors"; + + /// Return the global registry (root scope). + static Registry& global(); + + /// Create a per-query registry. If 'parent' is provided, lookups fall back + /// to it. Pass nullptr for isolation mode (no fallback). + static std::shared_ptr create(const Registry* parent = nullptr); + + /// Return the connector with the specified ID, or nullptr if not registered. + /// Checks per-query override on QueryCtx first, falls back to the global + /// registry if no override is set. + static std::shared_ptr tryGet( + const core::QueryCtx& queryCtx, + const std::string& connectorId); + + /// Return the connector with the specified ID from the global registry, or + /// nullptr if not registered. + static std::shared_ptr tryGet(const std::string& connectorId); + + /// Return all connectors whose implementation is of type T. Checks per-query + /// override on QueryCtx first, falls back to the global registry if no + /// override is set. + template + static std::vector> findAll( + const core::QueryCtx& queryCtx) { + std::vector> result; + for (auto& [_, connector] : snapshot(queryCtx)) { + if (auto casted = std::dynamic_pointer_cast(connector)) { + result.push_back(std::move(casted)); + } + } + return result; + } + + /// Return all connectors from the global registry whose implementation is + /// of type T. + template + static std::vector> findAll() { + std::vector> result; + for (auto& [_, connector] : global().snapshot()) { + if (auto casted = std::dynamic_pointer_cast(connector)) { + result.push_back(std::move(casted)); + } + } + return result; + } + + /// Unregister all connectors from the registry visible to the given query. + static void unregisterAll(const core::QueryCtx& queryCtx); + + /// Unregister all connectors from the global registry. + static void unregisterAll(); + + private: + // Return a snapshot of all connectors visible to the given query. Uses + // per-query registry if set, otherwise the global registry. + static std::vector>> + snapshot(const core::QueryCtx& queryCtx); +}; + +} // namespace facebook::velox::connector diff --git a/velox/connectors/ConnectorRegistryInternal.h b/velox/connectors/ConnectorRegistryInternal.h new file mode 100644 index 00000000000..feb55b5e1a4 --- /dev/null +++ b/velox/connectors/ConnectorRegistryInternal.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "velox/common/ScopedRegistry.h" +#include "velox/connectors/Connector.h" + +namespace facebook::velox::connector { + +// Internal helper shared by Connector.cpp and ConnectorRegistry.cpp. +// Not part of the public API. Do not include from outside velox/connectors/. +ScopedRegistry& connectors(); + +} // namespace facebook::velox::connector diff --git a/velox/connectors/fuzzer/CMakeLists.txt b/velox/connectors/fuzzer/CMakeLists.txt index 8e21030f99d..09f9772cc01 100644 --- a/velox/connectors/fuzzer/CMakeLists.txt +++ b/velox/connectors/fuzzer/CMakeLists.txt @@ -13,6 +13,7 @@ # limitations under the License. add_library(velox_fuzzer_connector OBJECT FuzzerConnector.cpp) +velox_add_test_headers(velox_fuzzer_connector FuzzerConnector.h FuzzerConnectorSplit.h) target_link_libraries(velox_fuzzer_connector velox_connector velox_vector_fuzzer) diff --git a/velox/connectors/fuzzer/FuzzerConnector.h b/velox/connectors/fuzzer/FuzzerConnector.h index 53e94b5f638..5b3f9bf74e2 100644 --- a/velox/connectors/fuzzer/FuzzerConnector.h +++ b/velox/connectors/fuzzer/FuzzerConnector.h @@ -77,7 +77,7 @@ class FuzzerDataSource : public DataSource { return completedBytes_; } - std::unordered_map runtimeStats() override { + std::unordered_map getRuntimeStats() override { // TODO: Which stats do we want to expose here? return {}; } diff --git a/velox/connectors/fuzzer/tests/CMakeLists.txt b/velox/connectors/fuzzer/tests/CMakeLists.txt index 6f0dbf3b647..14528bf0b01 100644 --- a/velox/connectors/fuzzer/tests/CMakeLists.txt +++ b/velox/connectors/fuzzer/tests/CMakeLists.txt @@ -24,3 +24,5 @@ target_link_libraries( GTest::gtest GTest::gtest_main ) + +velox_add_library(velox_fuzzer_connector_lib INTERFACE HEADERS FuzzerConnectorTestBase.h) diff --git a/velox/connectors/fuzzer/tests/FuzzerConnectorTestBase.h b/velox/connectors/fuzzer/tests/FuzzerConnectorTestBase.h index b0c703f54bc..7b38fb23cdc 100644 --- a/velox/connectors/fuzzer/tests/FuzzerConnectorTestBase.h +++ b/velox/connectors/fuzzer/tests/FuzzerConnectorTestBase.h @@ -14,6 +14,7 @@ * limitations under the License. */ +#include "velox/connectors/ConnectorRegistry.h" #include "velox/connectors/fuzzer/FuzzerConnector.h" #include "velox/exec/tests/utils/OperatorTestBase.h" #include "velox/vector/fuzzer/VectorFuzzer.h" @@ -28,11 +29,12 @@ class FuzzerConnectorTestBase : public exec::test::OperatorTestBase { OperatorTestBase::SetUp(); connector::fuzzer::FuzzerConnectorFactory factory; auto fuzzerConnector = factory.newConnector(kFuzzerConnectorId, nullptr); - connector::registerConnector(fuzzerConnector); + connector::ConnectorRegistry::global().insert( + fuzzerConnector->connectorId(), fuzzerConnector); } void TearDown() override { - connector::unregisterConnector(kFuzzerConnectorId); + connector::ConnectorRegistry::global().erase(kFuzzerConnectorId); OperatorTestBase::TearDown(); } diff --git a/velox/connectors/hive/BufferedInputBuilder.cpp b/velox/connectors/hive/BufferedInputBuilder.cpp new file mode 100644 index 00000000000..082c77c1922 --- /dev/null +++ b/velox/connectors/hive/BufferedInputBuilder.cpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/BufferedInputBuilder.h" +#include "velox/connectors/hive/HiveConnectorUtil.h" + +namespace facebook::velox::connector::hive { + +class DefaultBufferInputBuilder : public BufferedInputBuilder { + public: + std::unique_ptr create( + const FileHandle& fileHandle, + const dwio::common::ReaderOptions& readerOpts, + const ConnectorQueryCtx* connectorQueryCtx, + std::shared_ptr ioStatistics, + std::shared_ptr ioStats, + folly::Executor* executor, + const folly::F14FastMap& fileReadOps) override { + return createBufferedInput( + fileHandle, + readerOpts, + connectorQueryCtx, + ioStatistics, + ioStats, + executor, + fileReadOps); + } +}; + +// static +std::shared_ptr BufferedInputBuilder::builder_ = + std::make_shared(); + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/BufferedInputBuilder.h b/velox/connectors/hive/BufferedInputBuilder.h new file mode 100644 index 00000000000..583ec71c5c1 --- /dev/null +++ b/velox/connectors/hive/BufferedInputBuilder.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +#include "velox/connectors/Connector.h" +#include "velox/connectors/hive/FileHandle.h" +#include "velox/dwio/common/BufferedInput.h" +#include "velox/dwio/common/Reader.h" + +namespace facebook::velox::connector::hive { + +/// Registering a different implementation of BufferedInput is allowed using +/// 'registerBuilder' API. +class BufferedInputBuilder { + public: + virtual ~BufferedInputBuilder() = default; + + static const std::shared_ptr& getInstance() { + VELOX_CHECK_NOT_NULL(builder_, "Builder is not registered"); + return builder_; + } + + static void registerBuilder(std::shared_ptr builder) { + VELOX_CHECK_NOT_NULL(builder); + builder_ = std::move(builder); + } + + virtual std::unique_ptr create( + const FileHandle& fileHandle, + const dwio::common::ReaderOptions& readerOpts, + const ConnectorQueryCtx* connectorQueryCtx, + std::shared_ptr ioStatistics, + std::shared_ptr ioStats, + folly::Executor* executor, + const folly::F14FastMap& fileReadOps = {}) = 0; + + private: + static std::shared_ptr builder_; +}; + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/CMakeLists.txt b/velox/connectors/hive/CMakeLists.txt index e7bbb90fe04..f2b941d6e71 100644 --- a/velox/connectors/hive/CMakeLists.txt +++ b/velox/connectors/hive/CMakeLists.txt @@ -12,31 +12,77 @@ # See the License for the specific language governing permissions and # limitations under the License. -velox_add_library(velox_hive_config OBJECT HiveConfig.cpp) +velox_add_library(velox_hive_config OBJECT FileConfig.cpp HiveConfig.cpp) velox_link_libraries(velox_hive_config velox_core velox_exception) add_subdirectory(iceberg) +add_subdirectory(paimon) velox_add_library( velox_hive_connector OBJECT - FileHandle.cpp + BufferedInputBuilder.cpp + ExtractionUtils.cpp + ExtractionUtils.h + FileColumnHandle.cpp + FileConfig.cpp + FileConnectorUtil.cpp + FileDataSink.cpp + FileDataSource.cpp + FileIndexReader.cpp + FileSplitReader.cpp HiveConfig.cpp + HiveConfigProvider.cpp HiveConnector.cpp - HiveConnectorUtil.cpp HiveConnectorSplit.cpp + HiveConnectorUtil.cpp HiveDataSink.cpp HiveDataSource.cpp - HivePartitionUtil.cpp + HiveIndexSource.cpp + HivePartitionName.cpp + HiveSplitReader.cpp PartitionIdGenerator.cpp - SplitReader.cpp TableHandle.cpp + HEADERS + BufferedInputBuilder.h + FileColumnHandle.h + FileConfig.h + FileConnectorSplit.h + FileConnectorUtil.h + FileDataSink.h + FileDataSource.h + FileHandle.h + FileIndexReader.h + FileProperties.h + FileSplitReader.h + FileTableHandle.h + HiveConfig.h + HiveConfigMacrosDefine.h + HiveConfigMacrosUndef.h + HiveConfigProvider.h + HiveConnector.h + HiveConnectorSplit.h + HiveConnectorUtil.h + HiveDataSink.h + HiveDataSource.h + HiveIndexSource.h + HivePartitionFunction.h + HivePartitionName.h + HiveSplitReader.h + IndexReader.h + PartitionIdGenerator.h + TableHandle.h ) velox_link_libraries( velox_hive_connector - PUBLIC velox_hive_iceberg_splitreader - PRIVATE velox_common_io velox_connector velox_dwio_catalog_fbhive velox_hive_partition_function + PRIVATE + velox_common_io + velox_connector + velox_dwio_catalog_fbhive + velox_exec + velox_hive_partition_function + velox_key_encoder ) velox_add_library(velox_hive_partition_function HivePartitionFunction.cpp) diff --git a/velox/connectors/hive/ExtractionUtils.cpp b/velox/connectors/hive/ExtractionUtils.cpp new file mode 100644 index 00000000000..b660c1cecae --- /dev/null +++ b/velox/connectors/hive/ExtractionUtils.cpp @@ -0,0 +1,531 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/ExtractionUtils.h" + +#include "velox/type/Filter.h" +#include "velox/vector/FlatVector.h" + +namespace facebook::velox::connector::hive { + +namespace { + +/// Extract sizes from a MapVector or ArrayVector as a FlatVector. +VectorPtr extractSizes(const VectorPtr& input, memory::MemoryPool* pool) { + auto numRows = input->size(); + auto sizesBuf = AlignedBuffer::allocate(numRows, pool); + auto* sizesData = sizesBuf->asMutable(); + + const vector_size_t* rawSizes; + if (input->typeKind() == TypeKind::MAP) { + rawSizes = input->as()->rawSizes(); + } else { + rawSizes = input->as()->rawSizes(); + } + for (vector_size_t i = 0; i < numRows; ++i) { + sizesData[i] = rawSizes[i]; + } + + return std::make_shared>( + pool, + BIGINT(), + input->nulls(), + numRows, + std::move(sizesBuf), + std::vector{}); +} + +/// Filter a MapVector to only keep entries with matching keys. +VectorPtr filterMapByKeys( + const VectorPtr& input, + const ExtractionPathElement& step, + memory::MemoryPool* pool) { + auto* map = input->as(); + auto numRows = map->size(); + auto* rawOffsets = map->rawOffsets(); + auto* rawSizes = map->rawSizes(); + auto mapKeys = map->mapKeys(); + auto mapValues = map->mapValues(); + + // Build output offsets and sizes, and a mapping of which key-value pairs + // to keep. + auto newOffsetsBuf = AlignedBuffer::allocate(numRows, pool); + auto newSizesBuf = AlignedBuffer::allocate(numRows, pool); + auto* newOffsets = newOffsetsBuf->asMutable(); + auto* newSizes = newSizesBuf->asMutable(); + + // Build keep flags for all key-value pairs across all map entries. + std::vector keepFlags; + vector_size_t totalInputElements = 0; + for (vector_size_t i = 0; i < numRows; ++i) { + totalInputElements += rawSizes[i]; + } + keepFlags.resize(totalInputElements, false); + + if (totalInputElements == 0) { + // All maps are empty; return the input as-is with zeroed sizes. + return std::make_shared( + pool, + input->type(), + map->nulls(), + numRows, + std::move(newOffsetsBuf), + std::move(newSizesBuf), + mapKeys, + mapValues); + } + + auto* stringFilter = + dynamic_cast(&step); + if (stringFilter) { + std::unordered_set keySet( + stringFilter->filterKeys().begin(), stringFilter->filterKeys().end()); + auto* keyVector = mapKeys->as>(); + for (vector_size_t i = 0; i < totalInputElements; ++i) { + VELOX_DCHECK_LT( + static_cast(i), keepFlags.size(), "Index out of bounds"); + if (!keyVector->isNullAt(i) && + keySet.count(std::string(keyVector->valueAt(i))) > 0) { + keepFlags[i] = true; + } + } + } else { + auto& intFilter = + static_cast(step); + std::unordered_set keySet( + intFilter.filterKeys().begin(), intFilter.filterKeys().end()); + auto* keyVector = mapKeys->as>(); + for (vector_size_t i = 0; i < totalInputElements; ++i) { + VELOX_DCHECK_LT( + static_cast(i), keepFlags.size(), "Index out of bounds"); + if (!keyVector->isNullAt(i) && keySet.count(keyVector->valueAt(i)) > 0) { + keepFlags[i] = true; + } + } + } + + // Count kept entries per map. + vector_size_t totalKept = 0; + for (vector_size_t i = 0; i < numRows; ++i) { + newOffsets[i] = totalKept; + vector_size_t kept = 0; + for (vector_size_t j = 0; j < rawSizes[i]; ++j) { + VELOX_DCHECK_LT( + static_cast(rawOffsets[i] + j), + keepFlags.size(), + "Index out of bounds"); + if (keepFlags[rawOffsets[i] + j]) { + ++kept; + } + } + newSizes[i] = kept; + totalKept += kept; + } + + // Build index mapping for kept entries. + auto indexBuf = AlignedBuffer::allocate(totalKept, pool); + auto* indices = indexBuf->asMutable(); + vector_size_t outIdx = 0; + for (vector_size_t i = 0; i < numRows; ++i) { + for (vector_size_t j = 0; j < rawSizes[i]; ++j) { + auto srcIdx = rawOffsets[i] + j; + VELOX_DCHECK_LT( + static_cast(srcIdx), keepFlags.size(), "Index out of bounds"); + if (keepFlags[srcIdx]) { + indices[outIdx++] = srcIdx; + } + } + } + + // Create filtered keys and values using dictionary wrapping. + auto filteredKeys = + BaseVector::wrapInDictionary(nullptr, indexBuf, totalKept, mapKeys); + auto filteredValues = + BaseVector::wrapInDictionary(nullptr, indexBuf, totalKept, mapValues); + + return std::make_shared( + pool, + input->type(), + map->nulls(), + numRows, + std::move(newOffsetsBuf), + std::move(newSizesBuf), + filteredKeys, + filteredValues); +} + +/// Recursive implementation of extraction chain application. +VectorPtr applyExtractionChainImpl( + const VectorPtr& input, + const std::vector& chain, + size_t index, + memory::MemoryPool* pool) { + if (index >= chain.size()) { + return input; + } + + const auto& step = chain[index]; + switch (step->step()) { + case ExtractionStep::kStructField: { + auto* row = input->as(); + VELOX_CHECK_NOT_NULL(row); + auto& fieldName = + static_cast(*step) + .fieldName(); + auto childIdx = input->type()->asRow().getChildIdx(fieldName); + return applyExtractionChainImpl( + row->childAt(childIdx), chain, index + 1, pool); + } + + case ExtractionStep::kMapKeys: { + auto* map = input->as(); + VELOX_CHECK_NOT_NULL(map); + if (index + 1 >= chain.size()) { + return std::make_shared( + pool, + ARRAY(map->mapKeys()->type()), + map->nulls(), + map->size(), + map->offsets(), + map->sizes(), + map->mapKeys()); + } + VELOX_CHECK_EQ( + static_cast(chain[index + 1]->step()), + static_cast(ExtractionStep::kArrayElements)); + auto transformedKeys = + applyExtractionChainImpl(map->mapKeys(), chain, index + 2, pool); + return std::make_shared( + pool, + ARRAY(transformedKeys->type()), + map->nulls(), + map->size(), + map->offsets(), + map->sizes(), + transformedKeys); + } + + case ExtractionStep::kMapValues: { + auto* map = input->as(); + VELOX_CHECK_NOT_NULL(map); + if (index + 1 >= chain.size()) { + return std::make_shared( + pool, + ARRAY(map->mapValues()->type()), + map->nulls(), + map->size(), + map->offsets(), + map->sizes(), + map->mapValues()); + } + VELOX_CHECK_EQ( + static_cast(chain[index + 1]->step()), + static_cast(ExtractionStep::kArrayElements)); + auto transformedValues = + applyExtractionChainImpl(map->mapValues(), chain, index + 2, pool); + return std::make_shared( + pool, + ARRAY(transformedValues->type()), + map->nulls(), + map->size(), + map->offsets(), + map->sizes(), + transformedValues); + } + + case ExtractionStep::kMapKeyFilter: { + auto filtered = filterMapByKeys(input, *step, pool); + return applyExtractionChainImpl(filtered, chain, index + 1, pool); + } + + case ExtractionStep::kArrayElements: { + auto* array = input->as(); + VELOX_CHECK_NOT_NULL(array); + if (index + 1 >= chain.size()) { + return input; + } + auto transformedElements = + applyExtractionChainImpl(array->elements(), chain, index + 1, pool); + return std::make_shared( + pool, + ARRAY(transformedElements->type()), + array->nulls(), + array->size(), + array->offsets(), + array->sizes(), + transformedElements); + } + + case ExtractionStep::kSize: { + return extractSizes(input, pool); + } + } + VELOX_UNREACHABLE(); +} + +/// Analyze extraction chains on a ROW type to determine which fields are +/// needed. +std::unordered_set analyzeStructNeeds( + const std::vector& extractions) { + std::unordered_set neededFields; + for (const auto& extraction : extractions) { + if (extraction.chain.empty()) { + // Pass-through: need all fields. + return {}; + } + if (extraction.chain[0]->step() == ExtractionStep::kStructField) { + neededFields.insert( + static_cast( + *extraction.chain[0]) + .fieldName()); + } else { + // Non-struct step on a struct: need everything. + return {}; + } + } + return neededFields; +} + +/// Build sub-chains by stripping the first step from each extraction chain. +/// For MapKeys/MapValues, also strips the following ArrayElements step. +/// Only includes extractions whose first step matches 'firstStep'. +/// For ROW StructField, only includes extractions targeting 'fieldName'. +std::vector buildSubChains( + const std::vector& extractions, + ExtractionStep firstStep, + const std::string& fieldName = "") { + std::vector subChains; + for (const auto& extraction : extractions) { + if (extraction.chain.empty()) { + continue; + } + if (extraction.chain[0]->step() != firstStep) { + continue; + } + if (firstStep == ExtractionStep::kStructField) { + const auto& name = static_cast( + *extraction.chain[0]) + .fieldName(); + if (name != fieldName) { + continue; + } + } + + // Determine how many leading steps to skip. + size_t skip = 1; + if ((firstStep == ExtractionStep::kMapKeys || + firstStep == ExtractionStep::kMapValues) && + skip < extraction.chain.size() && + extraction.chain[skip]->step() == ExtractionStep::kArrayElements) { + ++skip; + } + + if (skip < extraction.chain.size()) { + NamedExtraction sub; + sub.outputName = extraction.outputName; + sub.chain = std::vector( + extraction.chain.begin() + static_cast(skip), + extraction.chain.end()); + sub.dataType = extraction.dataType; + subChains.push_back(std::move(sub)); + } + } + return subChains; +} + +} // namespace + +VectorPtr applyExtractionChain( + const VectorPtr& input, + const std::vector& chain, + memory::MemoryPool* pool) { + return applyExtractionChainImpl(input, chain, 0, pool); +} + +void configureExtractionScanSpec( + const TypePtr& hiveType, + const std::vector& extractions, + common::ScanSpec& spec, + memory::MemoryPool* pool) { + if (extractions.empty()) { + return; + } + + switch (hiveType->kind()) { // NOLINT(clang-diagnostic-switch-enum) + case TypeKind::MAP: { + // Determine the extraction type from the first step of all chains. + bool allKeys = true; + bool allValues = true; + bool allSize = true; + bool allMapKeyFilter = true; + for (const auto& extraction : extractions) { + if (extraction.chain.empty()) { + allKeys = false; + allValues = false; + allSize = false; + allMapKeyFilter = false; + break; + } + auto firstStep = extraction.chain[0]->step(); + if (firstStep != ExtractionStep::kMapKeys) { + allKeys = false; + } + if (firstStep != ExtractionStep::kMapValues) { + allValues = false; + } + if (firstStep != ExtractionStep::kSize) { + allSize = false; + } + if (firstStep != ExtractionStep::kMapKeyFilter) { + allMapKeyFilter = false; + } + } + + if (allSize) { + spec.setExtractionType(common::ScanSpec::ExtractionType::kSize); + } else if (allKeys) { + spec.setExtractionType(common::ScanSpec::ExtractionType::kKeys); + auto subChains = buildSubChains(extractions, ExtractionStep::kMapKeys); + if (!subChains.empty()) { + auto* keysSpec = + spec.childByName(common::ScanSpec::kMapKeysFieldName); + if (keysSpec) { + configureExtractionScanSpec( + hiveType->asMap().keyType(), subChains, *keysSpec, pool); + } + } + } else if (allValues) { + spec.setExtractionType(common::ScanSpec::ExtractionType::kValues); + auto subChains = + buildSubChains(extractions, ExtractionStep::kMapValues); + if (!subChains.empty()) { + auto* valuesSpec = + spec.childByName(common::ScanSpec::kMapValuesFieldName); + if (valuesSpec) { + configureExtractionScanSpec( + hiveType->asMap().valueType(), subChains, *valuesSpec, pool); + } + } + } else if (allMapKeyFilter) { + // kMapKeyFilter is type-preserving (MAP -> MAP), so no ExtractionType + // is set. Instead, add an IN filter on the map keys ScanSpec so the + // reader can skip non-matching key-value pairs. + auto* keysSpec = spec.childByName(common::ScanSpec::kMapKeysFieldName); + if (keysSpec) { + // Merge filter keys from all extraction chains. + std::vector mergedStringKeys; + std::vector mergedIntKeys; + bool useStringKeys = false; + for (const auto& extraction : extractions) { + if (auto* strFilter = dynamic_cast< + const StringMapKeyFilterExtractionPathElement*>( + extraction.chain[0].get())) { + useStringKeys = true; + mergedStringKeys.insert( + mergedStringKeys.end(), + strFilter->filterKeys().begin(), + strFilter->filterKeys().end()); + } else if ( + auto* intFilter = + dynamic_cast( + extraction.chain[0].get())) { + mergedIntKeys.insert( + mergedIntKeys.end(), + intFilter->filterKeys().begin(), + intFilter->filterKeys().end()); + } + } + + if (useStringKeys) { + keysSpec->setFilter( + std::make_unique( + mergedStringKeys, /*nullAllowed=*/false)); + } else { + keysSpec->setFilter( + common::createBigintValues( + mergedIntKeys, /*nullAllowed=*/false)); + } + } + + // Recurse with remaining chains (strip the kMapKeyFilter step). + auto subChains = + buildSubChains(extractions, ExtractionStep::kMapKeyFilter); + if (!subChains.empty()) { + configureExtractionScanSpec(hiveType, subChains, spec, pool); + } + } + break; + } + case TypeKind::ARRAY: { + bool allSize = true; + for (const auto& extraction : extractions) { + if (extraction.chain.empty() || + extraction.chain[0]->step() != ExtractionStep::kSize) { + allSize = false; + break; + } + } + if (allSize) { + spec.setExtractionType(common::ScanSpec::ExtractionType::kSize); + } + break; + } + case TypeKind::ROW: { + auto neededFields = analyzeStructNeeds(extractions); + if (neededFields.empty()) { + // Need all fields: no pruning. + break; + } + const auto& rowType = hiveType->asRow(); + + // If exactly one field is needed, set kField extraction so the struct + // reader produces the field's vector directly instead of a RowVector. + if (neededFields.size() == 1) { + auto& onlyFieldName = *neededFields.begin(); + auto fieldIdx = rowType.getChildIdx(onlyFieldName); + spec.setExtractionType(common::ScanSpec::ExtractionType::kField); + spec.setExtractionFieldIndex(fieldIdx); + } + + for (uint32_t i = 0; i < rowType.size(); ++i) { + auto& fieldName = rowType.nameOf(i); + if (neededFields.count(fieldName) == 0) { + auto* child = spec.childByName(fieldName); + if (child) { + child->setConstantValue( + BaseVector::createNullConstant(rowType.childAt(i), 1, pool)); + } + } else { + // Recurse into needed fields with their sub-chains. + auto subChains = buildSubChains( + extractions, ExtractionStep::kStructField, fieldName); + if (!subChains.empty()) { + auto* child = spec.childByName(fieldName); + if (child) { + configureExtractionScanSpec( + rowType.childAt(i), subChains, *child, pool); + } + } + } + } + break; + } + default: + break; + } +} + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/ExtractionUtils.h b/velox/connectors/hive/ExtractionUtils.h new file mode 100644 index 00000000000..a803a66f35a --- /dev/null +++ b/velox/connectors/hive/ExtractionUtils.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/connectors/hive/TableHandle.h" +#include "velox/dwio/common/ScanSpec.h" +#include "velox/vector/ComplexVector.h" + +namespace facebook::velox::connector::hive { + +/// Apply an extraction chain to a vector, producing the output vector. +/// The input vector must have a type compatible with the chain's expected +/// input type. The output vector has the type derived by the chain. +/// +/// For multiple named extractions from the same column, call this once +/// per extraction chain and assemble the results into a RowVector. +VectorPtr applyExtractionChain( + const VectorPtr& input, + const std::vector& chain, + memory::MemoryPool* pool); + +/// Configure a ScanSpec for a column that has extraction chains. +/// Analyzes the extraction chains and marks unneeded sub-streams as +/// constant null so the reader can skip them (DWRF/Nimble pushdown). +/// +/// This is an optimization — correctness is guaranteed by the +/// post-read extraction in applyExtractionChain even without these hints. +void configureExtractionScanSpec( + const TypePtr& hiveType, + const std::vector& extractions, + common::ScanSpec& spec, + memory::MemoryPool* pool); + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/FileColumnHandle.cpp b/velox/connectors/hive/FileColumnHandle.cpp new file mode 100644 index 00000000000..07e55acec33 --- /dev/null +++ b/velox/connectors/hive/FileColumnHandle.cpp @@ -0,0 +1,296 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/FileColumnHandle.h" + +namespace facebook::velox::connector::hive { +namespace { + +const std::unordered_map& +columnTypeNames() { + static const std::unordered_map + kColumnTypeNames = { + {FileColumnHandle::ColumnType::kPartitionKey, "PartitionKey"}, + {FileColumnHandle::ColumnType::kRegular, "Regular"}, + {FileColumnHandle::ColumnType::kSynthesized, "Synthesized"}, + {FileColumnHandle::ColumnType::kRowIndex, "RowIndex"}, + {FileColumnHandle::ColumnType::kRowId, "RowId"}, + }; + return kColumnTypeNames; +} + +template +std::unordered_map invertMap(const std::unordered_map& mapping) { + std::unordered_map inverted; + for (const auto& [key, value] : mapping) { + inverted.emplace(value, key); + } + return inverted; +} + +std::unordered_map extractionStepNames() { + return { + {ExtractionStep::kStructField, "STRUCT_FIELD"}, + {ExtractionStep::kMapKeys, "MAP_KEYS"}, + {ExtractionStep::kMapValues, "MAP_VALUES"}, + {ExtractionStep::kMapKeyFilter, "MAP_KEY_FILTER"}, + {ExtractionStep::kArrayElements, "ARRAY_ELEMENTS"}, + {ExtractionStep::kSize, "SIZE"}, + }; +} + +// Validate the extraction chain against the input type for step +// compatibility. Each step's output feeds as input to the next step. +void validateExtractionChain( + const TypePtr& inputType, + const std::vector& chain) { + auto currentType = inputType; + for (size_t i = 0; i < chain.size(); ++i) { + const auto& element = chain[i]; + switch (element->step()) { + case ExtractionStep::kStructField: { + VELOX_USER_CHECK( + currentType->isRow(), + "Extraction step StructField requires ROW input, got: {}", + currentType->toString()); + const auto& rowType = currentType->asRow(); + auto& fieldName = + static_cast(*element) + .fieldName(); + VELOX_USER_CHECK( + rowType.containsChild(fieldName), + "Extraction step StructField references non-existent field: {}", + fieldName); + currentType = rowType.findChild(fieldName); + break; + } + case ExtractionStep::kMapKeys: { + VELOX_USER_CHECK( + currentType->isMap(), + "Extraction step MapKeys requires MAP input, got: {}", + currentType->toString()); + currentType = ARRAY(currentType->asMap().keyType()); + break; + } + case ExtractionStep::kMapValues: { + VELOX_USER_CHECK( + currentType->isMap(), + "Extraction step MapValues requires MAP input, got: {}", + currentType->toString()); + currentType = ARRAY(currentType->asMap().valueType()); + break; + } + case ExtractionStep::kMapKeyFilter: { + VELOX_USER_CHECK( + currentType->isMap(), + "Extraction step MapKeyFilter requires MAP input, got: {}", + currentType->toString()); + // Type-preserving: MAP(K, V) -> MAP(K, V). + break; + } + case ExtractionStep::kArrayElements: { + VELOX_USER_CHECK( + currentType->isArray(), + "Extraction step ArrayElements requires ARRAY input, got: {}", + currentType->toString()); + currentType = currentType->asArray().elementType(); + break; + } + case ExtractionStep::kSize: { + VELOX_USER_CHECK( + currentType->isMap() || currentType->isArray(), + "Extraction step Size requires MAP or ARRAY input, got: {}", + currentType->toString()); + VELOX_USER_CHECK_EQ( + i, + chain.size() - 1, + "Extraction step Size must be the last step in the chain."); + currentType = BIGINT(); + break; + } + } + } +} + +// Recursively derive the output type from the input type and extraction +// chain. This follows the derivation rules in the design document. +TypePtr deriveOutputTypeImpl( + const TypePtr& inputType, + const std::vector& chain, + size_t index) { + if (index >= chain.size()) { + return inputType; + } + + const auto& element = chain[index]; + switch (element->step()) { + case ExtractionStep::kStructField: { + VELOX_CHECK(inputType->isRow()); + auto& fieldName = + static_cast(*element) + .fieldName(); + auto childType = inputType->asRow().findChild(fieldName); + return deriveOutputTypeImpl(childType, chain, index + 1); + } + case ExtractionStep::kMapKeys: { + VELOX_CHECK(inputType->isMap()); + auto keyType = inputType->asMap().keyType(); + if (index + 1 >= chain.size()) { + return ARRAY(keyType); + } + VELOX_CHECK_EQ( + static_cast(chain[index + 1]->step()), + static_cast(ExtractionStep::kArrayElements)); + return ARRAY(deriveOutputTypeImpl(keyType, chain, index + 2)); + } + case ExtractionStep::kMapValues: { + VELOX_CHECK(inputType->isMap()); + auto valueType = inputType->asMap().valueType(); + if (index + 1 >= chain.size()) { + return ARRAY(valueType); + } + VELOX_CHECK_EQ( + static_cast(chain[index + 1]->step()), + static_cast(ExtractionStep::kArrayElements)); + return ARRAY(deriveOutputTypeImpl(valueType, chain, index + 2)); + } + case ExtractionStep::kMapKeyFilter: { + VELOX_CHECK(inputType->isMap()); + return deriveOutputTypeImpl(inputType, chain, index + 1); + } + case ExtractionStep::kArrayElements: { + VELOX_CHECK(inputType->isArray()); + auto elementType = inputType->asArray().elementType(); + return ARRAY(deriveOutputTypeImpl(elementType, chain, index + 1)); + } + case ExtractionStep::kSize: { + VELOX_CHECK(inputType->isMap() || inputType->isArray()); + return BIGINT(); + } + } + VELOX_UNREACHABLE(); +} + +} // namespace + +bool extractionPathElementEquals( + const ExtractionPathElement& lhs, + const ExtractionPathElement& rhs) { + if (lhs.step() != rhs.step()) { + return false; + } + switch (lhs.step()) { + case ExtractionStep::kStructField: + return static_cast(lhs) + .fieldName() == + static_cast(rhs).fieldName(); + case ExtractionStep::kMapKeyFilter: { + if (auto* lStr = + dynamic_cast( + &lhs)) { + auto* rStr = + dynamic_cast(&rhs); + return rStr && lStr->filterKeys() == rStr->filterKeys(); + } + if (auto* lInt = + dynamic_cast(&lhs)) { + auto* rInt = + dynamic_cast(&rhs); + return rInt && lInt->filterKeys() == rInt->filterKeys(); + } + return false; + } + case ExtractionStep::kMapKeys: + case ExtractionStep::kMapValues: + case ExtractionStep::kArrayElements: + case ExtractionStep::kSize: + return true; + } + VELOX_UNREACHABLE(); +} + +/*static*/ std::shared_ptr +ExtractionPathElement::simple(ExtractionStep step) { + return std::make_shared(step); +} + +/*static*/ std::shared_ptr +ExtractionPathElement::structField(const std::string& name) { + return std::make_shared(name); +} + +/*static*/ std::shared_ptr +ExtractionPathElement::mapKeyFilter(std::vector keys) { + return std::make_shared( + std::move(keys)); +} + +/*static*/ std::shared_ptr +ExtractionPathElement::mapKeyFilter(std::vector keys) { + return std::make_shared( + std::move(keys)); +} + +bool NamedExtraction::operator==(const NamedExtraction& other) const { + if (outputName != other.outputName || chain.size() != other.chain.size() || + !dataType->equivalent(*other.dataType)) { + return false; + } + for (size_t i = 0; i < chain.size(); ++i) { + if (!extractionPathElementEquals(*chain[i], *other.chain[i])) { + return false; + } + } + return true; +} + +std::string extractionStepName(ExtractionStep step) { + static const auto names = extractionStepNames(); + return names.at(step); +} + +ExtractionStep extractionStepFromName(const std::string& name) { + static const auto nameToStep = invertMap(extractionStepNames()); + return nameToStep.at(name); +} + +TypePtr deriveExtractionOutputType( + const TypePtr& inputType, + const std::vector& chain) { + validateExtractionChain(inputType, chain); + return deriveOutputTypeImpl(inputType, chain, 0); +} + +std::string FileColumnHandle::columnTypeName( + FileColumnHandle::ColumnType columnType) { + const auto& names = columnTypeNames(); + auto it = names.find(columnType); + VELOX_CHECK(it != names.end(), "Unknown column type"); + return it->second; +} + +FileColumnHandle::ColumnType FileColumnHandle::columnTypeFromName( + const std::string& name) { + const auto& names = columnTypeNames(); + for (const auto& [type, typeName] : names) { + if (typeName == name) { + return type; + } + } + VELOX_FAIL("Unknown column type name: {}", name); +} + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/FileColumnHandle.h b/velox/connectors/hive/FileColumnHandle.h new file mode 100644 index 00000000000..4a6d564d35c --- /dev/null +++ b/velox/connectors/hive/FileColumnHandle.h @@ -0,0 +1,244 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/connectors/Connector.h" +#include "velox/type/Subfield.h" +#include "velox/type/Type.h" + +namespace facebook::velox::connector::hive { + +/// Type of extraction to apply at one nesting level. +enum class ExtractionStep : uint8_t { + /// Navigate into a struct field. Input must be ROW. + kStructField, + /// Extract map keys as ARRAY. Input must be MAP. + kMapKeys, + /// Extract map values as ARRAY. Input must be MAP. + kMapValues, + /// Filter map to specific keys. Input must be MAP. Type-preserving. + kMapKeyFilter, + /// Navigate into array elements. Input must be ARRAY. + kArrayElements, + /// Extract size as BIGINT. Input must be MAP or ARRAY. Terminal. + kSize, +}; + +/// Base class for one step in the extraction chain. +class ExtractionPathElement { + public: + virtual ~ExtractionPathElement() = default; + + /// Return the step type. + virtual ExtractionStep step() const = 0; + + /// Create a simple step (MapKeys, MapValues, ArrayElements, Size). + static std::shared_ptr simple( + ExtractionStep step); + + /// Create a StructField step. + static std::shared_ptr structField( + const std::string& name); + + /// Create a MapKeyFilter step with string keys. + static std::shared_ptr mapKeyFilter( + std::vector keys); + + /// Create a MapKeyFilter step with integer keys. + static std::shared_ptr mapKeyFilter( + std::vector keys); +}; + +using ExtractionPathElementPtr = std::shared_ptr; + +/// Simple extraction step without extra data (MapKeys, MapValues, +/// ArrayElements, Size). +class SimpleExtractionPathElement : public ExtractionPathElement { + public: + explicit SimpleExtractionPathElement(ExtractionStep step) : step_(step) {} + + ExtractionStep step() const override { + return step_; + } + + private: + ExtractionStep step_; +}; + +/// Struct field extraction step. +class StructFieldExtractionPathElement : public ExtractionPathElement { + public: + explicit StructFieldExtractionPathElement(std::string name) + : fieldName_(std::move(name)) {} + + ExtractionStep step() const override { + return ExtractionStep::kStructField; + } + + const std::string& fieldName() const { + return fieldName_; + } + + private: + std::string fieldName_; +}; + +/// Map key filter extraction step with string keys. +class StringMapKeyFilterExtractionPathElement : public ExtractionPathElement { + public: + explicit StringMapKeyFilterExtractionPathElement( + std::vector keys) + : filterKeys_(std::move(keys)) {} + + ExtractionStep step() const override { + return ExtractionStep::kMapKeyFilter; + } + + const std::vector& filterKeys() const { + return filterKeys_; + } + + private: + std::vector filterKeys_; +}; + +/// Map key filter extraction step with integer keys. +class IntMapKeyFilterExtractionPathElement : public ExtractionPathElement { + public: + explicit IntMapKeyFilterExtractionPathElement(std::vector keys) + : filterKeys_(std::move(keys)) {} + + ExtractionStep step() const override { + return ExtractionStep::kMapKeyFilter; + } + + const std::vector& filterKeys() const { + return filterKeys_; + } + + private: + std::vector filterKeys_; +}; + +/// Compare two ExtractionPathElements for equality. Non-virtual to avoid +/// slicing hazards with virtual operator==. +bool extractionPathElementEquals( + const ExtractionPathElement& lhs, + const ExtractionPathElement& rhs); + +/// Named extraction chain producing one output column. +struct NamedExtraction { + /// Output column name in the scan's outputType. + std::string outputName; + + /// Extraction chain to apply. Empty means pass-through (no extraction). + std::vector chain; + + /// Output type after applying the chain. + TypePtr dataType; + + bool operator==(const NamedExtraction& other) const; +}; + +/// Return the string name for an ExtractionStep enum value. +std::string extractionStepName(ExtractionStep step); + +/// Parse an ExtractionStep from its string name. +ExtractionStep extractionStepFromName(const std::string& name); + +/// Derive the output type by applying the extraction chain to the input type. +/// Throws if the chain is invalid for the given input type. +TypePtr deriveExtractionOutputType( + const TypePtr& inputType, + const std::vector& chain); + +/// Base class for column handles in file-based connectors. +/// +/// Define the common interface for column metadata needed by +/// FileDataSource and FileSplitReader. Connector-specific column handles +/// (HiveColumnHandle, etc.) extend this class. +class FileColumnHandle : public ColumnHandle { + public: + /// Classify columns by their role in the scan pipeline. + enum class ColumnType { + kPartitionKey, + kRegular, + kSynthesized, + /// A zero-based row number of type BIGINT auto-generated by the connector. + /// Row numbers are unique within a single file only. + kRowIndex, + kRowId, + }; + + virtual ColumnType columnType() const = 0; + + /// The type of this column as defined in the table schema (metastore). + /// May differ from dataType() when extraction changes the output type. + /// Subclasses must provide this. + virtual const TypePtr& schemaType() const = 0; + + /// The target data type for this column in the output. Defaults to + /// schemaType() (no extraction or type coercion). Override when the + /// output type differs from the table schema type. + virtual const TypePtr& dataType() const { + return schemaType(); + } + + /// Subfields required by the query for pruning complex types. + virtual const std::vector& requiredSubfields() const = 0; + + virtual bool isPartitionKey() const { + return columnType() == ColumnType::kPartitionKey; + } + + /// Return true if partition date values are encoded as days since epoch + /// (e.g., Iceberg) rather than ISO 8601 strings (e.g., Hive). + virtual bool isPartitionDateValueDaysSinceEpoch() const = 0; + + /// Named extraction chains. Empty means no extraction (default behavior). + virtual const std::vector& extractions() const { + static const std::vector kEmpty; + return kEmpty; + } + + /// Optional row-wise post processor applied to this column after reading and + /// filtering. Must not change the vector size. + virtual const std::function& postProcessor() const = 0; + + static std::string columnTypeName(ColumnType columnType); + + static ColumnType columnTypeFromName(const std::string& name); +}; + +using FileColumnHandlePtr = std::shared_ptr; +using FileColumnHandleMap = + std::unordered_map; + +} // namespace facebook::velox::connector::hive + +template <> +struct fmt::formatter< + facebook::velox::connector::hive::FileColumnHandle::ColumnType> + : formatter { + auto format( + facebook::velox::connector::hive::FileColumnHandle::ColumnType type, + format_context& ctx) const { + return formatter::format( + facebook::velox::connector::hive::FileColumnHandle::columnTypeName( + type), + ctx); + } +}; diff --git a/velox/connectors/hive/FileConfig.cpp b/velox/connectors/hive/FileConfig.cpp new file mode 100644 index 00000000000..eec3b9cd724 --- /dev/null +++ b/velox/connectors/hive/FileConfig.cpp @@ -0,0 +1,102 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/FileConfig.h" + +#include "velox/common/config/Config.h" + +namespace facebook::velox::connector::hive { + +const std::vector& FileConfig::registeredProperties() { + static const std::vector kProperties = [] { + std::vector properties; +#define VELOX_HIVE_CONFIG_REGISTER(constName) \ + config::registerConfigProperty(properties) + + VELOX_HIVE_CONFIG_REGISTER(kOrcUseColumnNamesSession); + VELOX_HIVE_CONFIG_REGISTER(kParquetUseColumnNamesSession); + VELOX_HIVE_CONFIG_REGISTER(kAllowInt32NarrowingSession); + VELOX_HIVE_CONFIG_REGISTER(kReadTimestampPartitionValueAsLocalTimeSession); + VELOX_HIVE_CONFIG_REGISTER(kPreserveFlatMapsInMemorySession); + VELOX_HIVE_CONFIG_REGISTER(kReaderCollectColumnCpuMetricsSession); + VELOX_HIVE_CONFIG_REGISTER(kOrcFooterSpeculativeIoSizeSession); + VELOX_HIVE_CONFIG_REGISTER(kParquetFooterSpeculativeIoSizeSession); + VELOX_HIVE_CONFIG_REGISTER(kNimbleFooterSpeculativeIoSizeSession); + VELOX_HIVE_CONFIG_REGISTER(kNimbleStringDecoderZeroCopySession); + VELOX_HIVE_CONFIG_REGISTER(kNimblePreserveDictionaryEncodingSession); + VELOX_HIVE_CONFIG_REGISTER(kFileColumnNamesReadAsLowerCaseSession); + VELOX_HIVE_CONFIG_REGISTER(kIgnoreMissingFilesSession); + VELOX_HIVE_CONFIG_REGISTER(kMaxCoalescedBytesSession); + VELOX_HIVE_CONFIG_REGISTER(kLoadQuantumSession); + VELOX_HIVE_CONFIG_REGISTER(kReadStatsBasedFilterReorderDisabledSession); + VELOX_HIVE_CONFIG_REGISTER(kIndexEnabledSession); + VELOX_HIVE_CONFIG_REGISTER(kFileMetadataCacheEnabledSession); + VELOX_HIVE_CONFIG_REGISTER(kPinFileMetadataSession); + VELOX_HIVE_CONFIG_REGISTER(kSelectiveNimbleReaderEnabledSession); + VELOX_HIVE_CONFIG_REGISTER(kMaxCoalescedDistanceSession); + VELOX_HIVE_CONFIG_REGISTER(kParallelUnitLoadCountSession); + VELOX_HIVE_CONFIG_REGISTER(kReadTimestampUnitSession); + +#undef VELOX_HIVE_CONFIG_REGISTER + + return properties; + }(); + return kProperties; +} + +int32_t FileConfig::maxCoalescedDistanceBytes( + const config::ConfigBase* session) const { + const auto distance = config::toCapacity( + session->get( + kMaxCoalescedDistanceSession, + config_->get(kMaxCoalescedDistance, "512kB")), + config::CapacityUnit::BYTE); + VELOX_USER_CHECK_LE( + distance, + std::numeric_limits::max(), + "The max merge distance to combine read requests must be less than 2GB." + " Got {} bytes.", + distance); + return int32_t(distance); +} + +int32_t FileConfig::prefetchRowGroups() const { + return config_->get(kPrefetchRowGroups, 1); +} + +size_t FileConfig::parallelUnitLoadCount( + const config::ConfigBase* session) const { + auto count = session->get( + kParallelUnitLoadCountSession, + config_->get(kParallelUnitLoadCount, 0)); + VELOX_CHECK_LE(count, 100, "parallelUnitLoadCount too large: {}", count); + return count; +} + +uint64_t FileConfig::filePreloadThreshold() const { + return config_->get(kFilePreloadThreshold, 8UL << 20); +} + +uint8_t FileConfig::readTimestampUnit(const config::ConfigBase* session) const { + const auto unit = sessionValue( + session, kReadTimestampUnitSession, kReadTimestampUnit, 3 /*milli*/); + VELOX_CHECK( + unit == 3 || unit == 6 /*micro*/ || unit == 9 /*nano*/, + "Invalid timestamp unit."); + return unit; +} + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/FileConfig.h b/velox/connectors/hive/FileConfig.h new file mode 100644 index 00000000000..d932a2a066b --- /dev/null +++ b/velox/connectors/hive/FileConfig.h @@ -0,0 +1,331 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include "velox/common/base/Exceptions.h" +#include "velox/common/config/Config.h" +#include "velox/common/config/ConfigProperty.h" +#include "velox/connectors/hive/HiveConfigMacrosDefine.h" + +namespace facebook::velox::connector::hive { + +/// Configuration for file-based data sources. Contains settings for file +/// reading, I/O coalescing, format-specific options, and metadata caching. +/// HiveConfig extends this with Hive-specific settings like partitioning, +/// bucketing, and write-path options. +class FileConfig { + public: + static const std::vector& registeredProperties(); + + // --- VELOX_HIVE_CONFIG_LEGACY properties --- + + VELOX_HIVE_CONFIG_LEGACY( + kOrcUseColumnNamesSession, + kOrcUseColumnNames, + isOrcUseColumnNames, + "orc_use_column_names", + "orc.use-column-names", + bool, + false, + "Map ORC table field names to file field names using names, not indices.") + + VELOX_HIVE_CONFIG_LEGACY( + kParquetUseColumnNamesSession, + kParquetUseColumnNames, + isParquetUseColumnNames, + "parquet_use_column_names", + "parquet.use-column-names", + bool, + false, + "Map Parquet table field names to file field names using names, not indices.") + + VELOX_HIVE_CONFIG_LEGACY( + kAllowInt32NarrowingSession, + kAllowInt32Narrowing, + allowInt32Narrowing, + "allow_int32_narrowing", + "parquet.allow-int32-narrowing", + bool, + false, + "Allow reading INT32 Parquet columns as a narrower integer type.") + + VELOX_HIVE_CONFIG_LEGACY( + kReadTimestampPartitionValueAsLocalTimeSession, + kReadTimestampPartitionValueAsLocalTime, + readTimestampPartitionValueAsLocalTime, + "reader.timestamp_partition_value_as_local_time", + "reader.timestamp-partition-value-as-local-time", + bool, + true, + "Read timestamp partition values as local time.") + + VELOX_HIVE_CONFIG_LEGACY( + kPreserveFlatMapsInMemorySession, + kPreserveFlatMapsInMemory, + preserveFlatMapsInMemory, + "preserve_flat_maps_in_memory", + "preserve-flat-maps-in-memory", + bool, + false, + "Preserve flat maps in memory as FlatMapVectors.") + + VELOX_HIVE_CONFIG_LEGACY( + kReaderCollectColumnCpuMetricsSession, + kReaderCollectColumnCpuMetrics, + readerCollectColumnCpuMetrics, + "reader.collect_column_cpu_metrics", + "reader.collect-column-cpu-metrics", + bool, + false, + "Collect per-column CPU timing stats.") + + VELOX_HIVE_CONFIG_LEGACY( + kOrcFooterSpeculativeIoSizeSession, + kOrcFooterSpeculativeIoSize, + orcFooterSpeculativeIoSize, + "orc_footer_speculative_io_size", + "orc.footer-speculative-io-size", + uint64_t, + 256UL << 10, + "Speculative tail-read size in bytes for ORC files.") + + VELOX_HIVE_CONFIG_LEGACY( + kParquetFooterSpeculativeIoSizeSession, + kParquetFooterSpeculativeIoSize, + parquetFooterSpeculativeIoSize, + "parquet_footer_speculative_io_size", + "parquet.footer-speculative-io-size", + uint64_t, + 256UL << 10, + "Speculative tail-read size in bytes for Parquet files.") + + VELOX_HIVE_CONFIG_LEGACY( + kNimbleFooterSpeculativeIoSizeSession, + kNimbleFooterSpeculativeIoSize, + nimbleFooterSpeculativeIoSize, + "nimble_footer_speculative_io_size", + "nimble.footer-speculative-io-size", + uint64_t, + 8UL << 20, + "Speculative tail-read size in bytes for Nimble files.") + + VELOX_HIVE_CONFIG_LEGACY( + kNimbleStringDecoderZeroCopySession, + kNimbleStringDecoderZeroCopy, + nimbleStringDecoderZeroCopy, + "nimble_string_decoder_zero_copy", + "nimble.string-decoder-zero-copy", + bool, + false, + "Enable zero-copy string decoding in Nimble selective reader.") + + VELOX_HIVE_CONFIG_LEGACY( + kNimblePreserveDictionaryEncodingSession, + kNimblePreserveDictionaryEncoding, + nimblePreserveDictionaryEncoding, + "nimble_preserve_dictionary_encoding", + "nimble.preserve-dictionary-encoding", + bool, + false, + "Preserve dictionary encoding for Nimble string column reads.") + + // --- VELOX_HIVE_CONFIG properties --- + + VELOX_HIVE_CONFIG( + kFileColumnNamesReadAsLowerCaseSession, + isFileColumnNamesReadAsLowerCase, + "file_column_names_read_as_lower_case", + bool, + false, + "Read source file column names as lower case.") + static constexpr const char* kFileColumnNamesReadAsLowerCase = + "file-column-names-read-as-lower-case"; + + VELOX_HIVE_CONFIG_PROPERTY( + kIgnoreMissingFilesSession, + "ignore_missing_files", + bool, + false, + "Ignore missing files during table scan.") + bool ignoreMissingFiles(const config::ConfigBase* session) const { + return session->get( + kIgnoreMissingFilesSession, + kIgnoreMissingFilesSessionProperty::defaultValue); + } + + VELOX_HIVE_CONFIG( + kMaxCoalescedBytesSession, + maxCoalescedBytes, + "max-coalesced-bytes", + int64_t, + 128 << 20, + "Maximum coalesced bytes for a read request.") + static constexpr const char* kMaxCoalescedBytes = "max-coalesced-bytes"; + + VELOX_HIVE_CONFIG( + kLoadQuantumSession, + loadQuantum, + "load-quantum", + int32_t, + 8 << 20, + "Total size in bytes for a direct coalesce request.") + static constexpr const char* kLoadQuantum = "load-quantum"; + + VELOX_HIVE_CONFIG( + kReadStatsBasedFilterReorderDisabledSession, + readStatsBasedFilterReorderDisabled, + "stats_based_filter_reorder_disabled", + bool, + false, + "Disable stats-based filter reordering.") + static constexpr const char* kReadStatsBasedFilterReorderDisabled = + "stats-based-filter-reorder-disabled"; + + VELOX_HIVE_CONFIG( + kIndexEnabledSession, + indexEnabled, + "index_enabled", + bool, + false, + "Use cluster index for filter-based row pruning.") + static constexpr const char* kIndexEnabled = "index-enabled"; + + VELOX_HIVE_CONFIG( + kFileMetadataCacheEnabledSession, + fileMetadataCacheEnabled, + "file_metadata_cache_enabled", + bool, + false, + "Cache file metadata in AsyncDataCache.") + static constexpr const char* kFileMetadataCacheEnabled = + "file-metadata-cache-enabled"; + + VELOX_HIVE_CONFIG( + kPinFileMetadataSession, + pinFileMetadata, + "pin_file_metadata", + bool, + false, + "Pin parsed metadata objects in reader cache.") + static constexpr const char* kPinFileMetadata = "pin-file-metadata"; + + VELOX_HIVE_CONFIG( + kSelectiveNimbleReaderEnabledSession, + selectiveNimbleReaderEnabled, + "selective_nimble_reader_enabled", + bool, + true, + "Enable selective Nimble reader.") + + // --- VELOX_HIVE_CONFIG_PROPERTY properties --- + + VELOX_HIVE_CONFIG_PROPERTY( + kMaxCoalescedDistanceSession, + "orc_max_merge_distance", + std::string, + "512kB", + "Maximum merge distance to combine read requests.") + static constexpr const char* kMaxCoalescedDistance = "max-coalesced-distance"; + + VELOX_HIVE_CONFIG_PROPERTY( + kParallelUnitLoadCountSession, + "parallel_unit_load_count", + size_t, + 0, + "Number of units to load in parallel. 0 disables.") + static constexpr const char* kParallelUnitLoadCount = + "parallel-unit-load-count"; + + VELOX_HIVE_CONFIG_PROPERTY( + kReadTimestampUnitSession, + "reader.timestamp_unit", + uint8_t, + 3, + "Unit for reading timestamps (0=second, 3=millisecond, 6=microsecond, 9=nanosecond).") + static constexpr const char* kReadTimestampUnit = "reader.timestamp-unit"; + + // --- Server-only properties (no macro) --- + + /// The number of prefetch rowgroups. + static constexpr const char* kPrefetchRowGroups = "prefetch-rowgroups"; + + /// The threshold of file size in bytes when the whole file is fetched with + /// meta data together. Optimization to decrease the small IO requests. + static constexpr const char* kFilePreloadThreshold = "file-preload-threshold"; + + explicit FileConfig(std::shared_ptr config) { + VELOX_CHECK_NOT_NULL( + config, "Config is null for FileConfig initialization"); + config_ = std::move(config); + } + + virtual ~FileConfig() = default; + + int32_t maxCoalescedDistanceBytes(const config::ConfigBase* session) const; + + int32_t prefetchRowGroups() const; + + size_t parallelUnitLoadCount(const config::ConfigBase* session) const; + + uint64_t filePreloadThreshold() const; + + // Returns the timestamp unit used when reading timestamps from files. + uint8_t readTimestampUnit(const config::ConfigBase* session) const; + + const std::shared_ptr& config() const { + return config_; + } + + protected: + static constexpr const char* kLegacyPrefix = "hive."; + + // Looks up a config value by 'key', falling back to 'hive.' + key if not + // found. Used during migration away from redundant 'hive.' prefix in + // config keys. + template + T configValue(const std::string& key, const T& defaultValue) const { + if (auto val = config_->get(key)) { + return val.value(); + } + return config_->get(std::string(kLegacyPrefix) + key, defaultValue); + } + + // Looks up a session property by 'sessionKey' (then 'hive.' + sessionKey), + // falling back to connector config by 'configKey' (then 'hive.' + configKey). + // Used during migration away from redundant 'hive.' prefix in property + // names. + template + T sessionValue( + const config::ConfigBase* session, + const std::string& sessionKey, + const std::string& configKey, + const T& defaultValue) const { + if (auto val = session->get(sessionKey)) { + return val.value(); + } + if (auto val = session->get(std::string(kLegacyPrefix) + sessionKey)) { + return val.value(); + } + return configValue(configKey, defaultValue); + } + + std::shared_ptr config_; +}; + +} // namespace facebook::velox::connector::hive + +#include "velox/connectors/hive/HiveConfigMacrosUndef.h" diff --git a/velox/connectors/hive/FileConnectorSplit.h b/velox/connectors/hive/FileConnectorSplit.h new file mode 100644 index 00000000000..efba64c5b8b --- /dev/null +++ b/velox/connectors/hive/FileConnectorSplit.h @@ -0,0 +1,77 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +#include "velox/connectors/Connector.h" +#include "velox/connectors/hive/FileProperties.h" +#include "velox/dwio/common/Options.h" + +namespace facebook::velox::connector::hive { + +/// Base split for file-based connectors. Represents a byte range within a +/// single data file. Connector-specific splits (HiveConnectorSplit, +/// HiveIcebergSplit) extend this to add synthesized columns, serde parameters, +/// and other format-specific fields. +struct FileConnectorSplit : public ConnectorSplit { + const std::string filePath; + dwio::common::FileFormat fileFormat; + const uint64_t start; + const uint64_t length; + + /// File properties like file size used while opening the file handle. + std::optional properties; + + /// Mapping from partition keys to values. Values are specified as strings + /// formatted the same way as CAST(x as VARCHAR). Null values are specified as + /// std::nullopt. Date values must be formatted using ISO 8601 as YYYY-MM-DD. + const std::unordered_map> + partitionKeys; + + FileConnectorSplit( + const std::string& connectorId, + const std::string& _filePath, + dwio::common::FileFormat _fileFormat, + uint64_t _start = 0, + uint64_t _length = std::numeric_limits::max(), + int64_t splitWeight = 0, + bool cacheable = true, + std::optional _properties = std::nullopt, + const std::unordered_map>& + _partitionKeys = {}) + : ConnectorSplit(connectorId, splitWeight, cacheable), + filePath(_filePath), + fileFormat(_fileFormat), + start(_start), + length(_length), + properties(std::move(_properties)), + partitionKeys(_partitionKeys) {} + + ~FileConnectorSplit() override = default; + + uint64_t size() const override { + return length; + } + + std::string getFileName() const { + const auto i = filePath.rfind('/'); + return i == std::string::npos ? filePath : filePath.substr(i + 1); + } +}; + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/FileConnectorUtil.cpp b/velox/connectors/hive/FileConnectorUtil.cpp new file mode 100644 index 00000000000..fbc314fa494 --- /dev/null +++ b/velox/connectors/hive/FileConnectorUtil.cpp @@ -0,0 +1,289 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/FileConnectorUtil.h" + +#include "velox/connectors/hive/FileColumnHandle.h" +#include "velox/connectors/hive/FileConfig.h" +#include "velox/connectors/hive/FileConnectorSplit.h" +#include "velox/connectors/hive/FileTableHandle.h" + +namespace facebook::velox::connector::hive { + +void configureReaderOptions( + const std::shared_ptr& fileConfig, + const ConnectorQueryCtx* connectorQueryCtx, + const FileTableHandlePtr& tableHandle, + const std::shared_ptr& fileSplit, + dwio::common::ReaderOptions& readerOptions) { + configureReaderOptions( + fileConfig, + connectorQueryCtx, + tableHandle->dataColumns(), + fileSplit, + tableHandle->tableParameters(), + readerOptions); +} + +void configureReaderOptions( + const std::shared_ptr& fileConfig, + const ConnectorQueryCtx* connectorQueryCtx, + const RowTypePtr& fileSchema, + const std::shared_ptr& fileSplit, + const std::unordered_map& /*tableParameters*/, + dwio::common::ReaderOptions& readerOptions) { + auto sessionProperties = connectorQueryCtx->sessionProperties(); + readerOptions.setLoadQuantum(fileConfig->loadQuantum(sessionProperties)); + readerOptions.setMaxCoalesceBytes( + fileConfig->maxCoalescedBytes(sessionProperties)); + readerOptions.setMaxCoalesceDistance( + fileConfig->maxCoalescedDistanceBytes(sessionProperties)); + readerOptions.setFileColumnNamesReadAsLowerCase( + fileConfig->isFileColumnNamesReadAsLowerCase(sessionProperties)); + readerOptions.setAllowEmptyFile(true); + bool useColumnNamesForColumnMapping = false; + switch (fileSplit->fileFormat) { + case dwio::common::FileFormat::DWRF: + case dwio::common::FileFormat::ORC: { + useColumnNamesForColumnMapping = + fileConfig->isOrcUseColumnNames(sessionProperties); + break; + } + case dwio::common::FileFormat::PARQUET: { + useColumnNamesForColumnMapping = + fileConfig->isParquetUseColumnNames(sessionProperties); + readerOptions.setAllowInt32Narrowing( + fileConfig->allowInt32Narrowing(sessionProperties)); + break; + } + default: + useColumnNamesForColumnMapping = false; + } + + readerOptions.setUseColumnNamesForColumnMapping( + useColumnNamesForColumnMapping); + readerOptions.setFileSchema(fileSchema); + readerOptions.setFilePreloadThreshold(fileConfig->filePreloadThreshold()); + readerOptions.setPrefetchRowGroups(fileConfig->prefetchRowGroups()); + readerOptions.setCacheable(fileSplit->cacheable); + const auto& sessionTzName = connectorQueryCtx->sessionTimezone(); + if (!sessionTzName.empty()) { + const auto timezone = tz::locateZone(sessionTzName); + readerOptions.setSessionTimezone(timezone); + } + readerOptions.setAdjustTimestampToTimezone( + connectorQueryCtx->adjustTimestampToTimezone()); + // Prefer connector session property (FileConfig). Fall back to + // ConnectorQueryCtx (threaded from QueryConfig) for backward compatibility + // with callers that set it as a query config. + if (sessionProperties->valueExists( + FileConfig::kSelectiveNimbleReaderEnabledSession)) { + readerOptions.setSelectiveNimbleReaderEnabled( + fileConfig->selectiveNimbleReaderEnabled(sessionProperties)); + } else { + readerOptions.setSelectiveNimbleReaderEnabled( + connectorQueryCtx->selectiveNimbleReaderEnabled()); + } + readerOptions.setFileMetadataCacheEnabled( + fileConfig->fileMetadataCacheEnabled(sessionProperties)); + readerOptions.setPinFileMetadata( + fileConfig->pinFileMetadata(sessionProperties)); + + // Set footer speculative IO size based on file format. + switch (fileSplit->fileFormat) { + case dwio::common::FileFormat::DWRF: + case dwio::common::FileFormat::ORC: + readerOptions.setFooterSpeculativeIoSize( + fileConfig->orcFooterSpeculativeIoSize(sessionProperties)); + break; + case dwio::common::FileFormat::PARQUET: + readerOptions.setFooterSpeculativeIoSize( + fileConfig->parquetFooterSpeculativeIoSize(sessionProperties)); + break; + case dwio::common::FileFormat::NIMBLE: + readerOptions.setFooterSpeculativeIoSize( + fileConfig->nimbleFooterSpeculativeIoSize(sessionProperties)); + break; + default: + // Use ORC default for unknown formats. + readerOptions.setFooterSpeculativeIoSize( + fileConfig->orcFooterSpeculativeIoSize(sessionProperties)); + break; + } + + if (readerOptions.fileFormat() != dwio::common::FileFormat::UNKNOWN) { + VELOX_CHECK( + readerOptions.fileFormat() == fileSplit->fileFormat, + "HiveDataSource received splits of different formats: {} and {}", + dwio::common::toString(readerOptions.fileFormat()), + dwio::common::toString(fileSplit->fileFormat)); + } else { + readerOptions.setFileFormat(fileSplit->fileFormat); + } +} + +void configureRowReaderOptions( + const std::unordered_map& tableParameters, + const std::shared_ptr& scanSpec, + std::shared_ptr metadataFilter, + const RowTypePtr& rowType, + const std::shared_ptr& fileSplit, + const std::shared_ptr& fileConfig, + const config::ConfigBase* sessionProperties, + folly::Executor* const ioExecutor, + dwio::common::RowReaderOptions& rowReaderOptions) { + auto skipRowsIt = + tableParameters.find(dwio::common::TableParameter::kSkipHeaderLineCount); + if (skipRowsIt != tableParameters.end()) { + rowReaderOptions.setSkipRows(folly::to(skipRowsIt->second)); + } + rowReaderOptions.setScanSpec(scanSpec); + rowReaderOptions.setIOExecutor(ioExecutor); + rowReaderOptions.setMetadataFilter(std::move(metadataFilter)); + rowReaderOptions.setRequestedType(rowType); + rowReaderOptions.range(fileSplit->start, fileSplit->length); + if (fileConfig && sessionProperties) { + rowReaderOptions.setTimestampPrecision( + static_cast( + fileConfig->readTimestampUnit(sessionProperties))); + rowReaderOptions.setPreserveFlatMapsInMemory( + fileConfig->preserveFlatMapsInMemory(sessionProperties)); + rowReaderOptions.setParallelUnitLoadCount( + fileConfig->parallelUnitLoadCount(sessionProperties)); + rowReaderOptions.setIndexEnabled( + fileConfig->indexEnabled(sessionProperties)); + rowReaderOptions.setCollectColumnCpuMetrics( + fileConfig->readerCollectColumnCpuMetrics(sessionProperties)); + } +} + +namespace { + +bool applyPartitionFilter( + const TypePtr& type, + const std::string& partitionValue, + bool isPartitionDateDaysSinceEpoch, + const common::Filter* filter, + bool asLocalTime) { + if (type->isDate()) { + int32_t result = 0; + // days_since_epoch partition values are integers in string format. Eg. + // Iceberg partition values. + if (isPartitionDateDaysSinceEpoch) { + result = folly::to(partitionValue); + } else { + result = DATE()->toDays(partitionValue); + } + return applyFilter(*filter, result); + } + + switch (type->kind()) { + case TypeKind::BIGINT: + case TypeKind::INTEGER: + case TypeKind::SMALLINT: + case TypeKind::TINYINT: { + return applyFilter(*filter, folly::to(partitionValue)); + } + case TypeKind::REAL: + case TypeKind::DOUBLE: { + return applyFilter(*filter, folly::to(partitionValue)); + } + case TypeKind::BOOLEAN: { + return applyFilter(*filter, folly::to(partitionValue)); + } + case TypeKind::TIMESTAMP: { + auto result = util::fromTimestampString( + StringView(partitionValue), util::TimestampParseMode::kPrestoCast); + VELOX_CHECK(!result.hasError()); + if (asLocalTime) { + result.value().toGMT(Timestamp::defaultTimezone()); + } + return applyFilter(*filter, result.value()); + } + case TypeKind::VARCHAR: { + return applyFilter(*filter, partitionValue); + } + default: + VELOX_FAIL( + "Bad type {} for partition value: {}", type->kind(), partitionValue); + } +} + +} // namespace + +bool testFilters( + const common::ScanSpec* scanSpec, + const dwio::common::Reader* reader, + const std::string& filePath, + const std::unordered_map>& + partitionKeys, + const std::unordered_map& + partitionKeysHandle, + bool asLocalTime) { + const auto totalRows = reader->numberOfRows(); + const auto& fileTypeWithId = reader->typeWithId(); + const auto& rowType = reader->rowType(); + for (const auto& child : scanSpec->children()) { + if (child->filter()) { + const auto& name = child->fieldName(); + auto iter = partitionKeys.find(name); + // By design, the partition key columns for Iceberg tables are included in + // the data files to facilitate partition transform and partition + // evolution, so we need to test both cases. + if (!rowType->containsChild(name) || iter != partitionKeys.end()) { + if (iter != partitionKeys.end() && iter->second.has_value()) { + const auto handlesIter = partitionKeysHandle.find(name); + VELOX_CHECK(handlesIter != partitionKeysHandle.end()); + + // This is a non-null partition key + return applyPartitionFilter( + handlesIter->second->dataType(), + iter->second.value(), + handlesIter->second->isPartitionDateValueDaysSinceEpoch(), + child->filter(), + asLocalTime); + } + // Column is missing, most likely due to schema evolution. Or it's a + // partition key but the partition value is NULL. + if (child->filter()->isDeterministic() && + !child->filter()->testNull()) { + VLOG(1) << "Skipping " << filePath + << " because the filter testNull() failed for column " + << child->fieldName(); + return false; + } + } else { + const auto& typeWithId = fileTypeWithId->childByName(name); + const auto columnStats = reader->columnStatistics(typeWithId->id()); + if (columnStats != nullptr && + !testFilter( + child->filter(), + columnStats.get(), + totalRows.value(), + typeWithId->type())) { + VLOG(1) << "Skipping " << filePath + << " based on stats and filter for column " + << child->fieldName(); + return false; + } + } + } + } + + return true; +} + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/FileConnectorUtil.h b/velox/connectors/hive/FileConnectorUtil.h new file mode 100644 index 00000000000..1312593976f --- /dev/null +++ b/velox/connectors/hive/FileConnectorUtil.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "velox/connectors/Connector.h" +#include "velox/connectors/hive/FileHandle.h" +#include "velox/dwio/common/Reader.h" + +namespace facebook::velox::connector::hive { + +class FileColumnHandle; +using FileColumnHandlePtr = std::shared_ptr; +class FileTableHandle; +using FileTableHandlePtr = std::shared_ptr; +class FileConfig; +struct FileConnectorSplit; + +/// Configures reader options for reading a data file. This is the generic +/// version that does not apply serde (serialization/deserialization) options. +/// For Hive tables that need serde options, use the overload in +/// HiveConnectorUtil.h that accepts serdeParameters. +void configureReaderOptions( + const std::shared_ptr& config, + const ConnectorQueryCtx* connectorQueryCtx, + const FileTableHandlePtr& tableHandle, + const std::shared_ptr& fileSplit, + dwio::common::ReaderOptions& readerOptions); + +void configureReaderOptions( + const std::shared_ptr& fileConfig, + const ConnectorQueryCtx* connectorQueryCtx, + const RowTypePtr& fileSchema, + const std::shared_ptr& fileSplit, + const std::unordered_map& tableParameters, + dwio::common::ReaderOptions& readerOptions); + +/// Configures row reader options for reading rows from a data file. This is the +/// generic version that does not set serde parameters. For Hive tables that +/// need serde parameters, use the overload in HiveConnectorUtil.h. +void configureRowReaderOptions( + const std::unordered_map& tableParameters, + const std::shared_ptr& scanSpec, + std::shared_ptr metadataFilter, + const RowTypePtr& rowType, + const std::shared_ptr& fileSplit, + const std::shared_ptr& fileConfig, + const config::ConfigBase* sessionProperties, + folly::Executor* ioExecutor, + dwio::common::RowReaderOptions& rowReaderOptions); + +/// Tests whether a file should be read based on partition key values and +/// column statistics. Returns true if the file passes all filters, false +/// if it can be skipped. +bool testFilters( + const common::ScanSpec* scanSpec, + const dwio::common::Reader* reader, + const std::string& filePath, + const std::unordered_map>& + partitionKey, + const std::unordered_map& + partitionKeysHandle, + bool asLocalTime); + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/FileDataSink.cpp b/velox/connectors/hive/FileDataSink.cpp new file mode 100644 index 00000000000..1999e549daf --- /dev/null +++ b/velox/connectors/hive/FileDataSink.cpp @@ -0,0 +1,523 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/FileDataSink.h" + +#include "velox/common/testutil/TestValue.h" +#include "velox/exec/OperatorUtils.h" + +using facebook::velox::common::testutil::TestValue; + +namespace facebook::velox::connector::hive { +namespace { +#define WRITER_NON_RECLAIMABLE_SECTION_GUARD(index) \ + memory::NonReclaimableSectionGuard nonReclaimableGuard( \ + writerInfo_[(index)]->nonReclaimableSectionHolder.get()) + +std::shared_ptr createSinkPool( + const std::shared_ptr& writerPool) { + return writerPool->addLeafChild(fmt::format("{}.sink", writerPool->name())); +} + +std::shared_ptr createSortPool( + const std::shared_ptr& writerPool) { + return writerPool->addLeafChild(fmt::format("{}.sort", writerPool->name())); +} +} // namespace + +const WriterId& WriterId::unpartitionedId() { + static const WriterId writerId{0}; + return writerId; +} + +std::string WriterId::toString() const { + if (partitionId.has_value() && bucketId.has_value()) { + return fmt::format("part[{}.{}]", partitionId.value(), bucketId.value()); + } + + if (partitionId.has_value() && !bucketId.has_value()) { + return fmt::format("part[{}]", partitionId.value()); + } + + // This WriterId is used to add an identifier in the MemoryPools. This could + // indicate unpart, but the bucket number needs to be disambiguated. So + // creating a new label using bucket. + if (!partitionId.has_value() && bucketId.has_value()) { + return fmt::format("bucket[{}]", bucketId.value()); + } + + return "unpart"; +} + +RowTypePtr FileDataSink::getNonPartitionTypes( + const std::vector& dataCols, + const RowTypePtr& inputType) { + std::vector childNames; + std::vector childTypes; + const auto& dataSize = dataCols.size(); + childNames.reserve(dataSize); + childTypes.reserve(dataSize); + for (auto dataCol : dataCols) { + childNames.push_back(inputType->nameOf(dataCol)); + childTypes.push_back(inputType->childAt(dataCol)); + } + + return ROW(std::move(childNames), std::move(childTypes)); +} + +RowVectorPtr FileDataSink::makeDataInput( + const std::vector& dataCols, + const RowVectorPtr& input) { + std::vector childVectors; + childVectors.reserve(dataCols.size()); + for (auto dataCol : dataCols) { + childVectors.push_back(input->childAt(dataCol)); + } + + return std::make_shared( + input->pool(), + getNonPartitionTypes(dataCols, asRowType(input->type())), + input->nulls(), + input->size(), + std::move(childVectors), + input->getNullCount()); +} + +FileDataSink::FileDataSink( + RowTypePtr inputType, + const ConnectorQueryCtx* connectorQueryCtx, + CommitStrategy commitStrategy, + dwio::common::FileFormat storageFormat, + uint32_t maxOpenWriters, + std::vector partitionChannels, + std::vector dataChannels, + int32_t bucketCount, + std::unique_ptr bucketFunction, + std::unique_ptr partitionIdGenerator, + std::shared_ptr writerFactory, + uint64_t maxTargetFileBytes, + bool partitionKeyAsLowerCase, + const common::SpillConfig* spillConfig, + uint64_t sortWriterFinishTimeSliceLimitMs) + : inputType_(std::move(inputType)), + connectorQueryCtx_(connectorQueryCtx), + commitStrategy_(commitStrategy), + storageFormat_(storageFormat), + maxOpenWriters_(maxOpenWriters), + partitionChannels_(std::move(partitionChannels)), + partitionIdGenerator_(std::move(partitionIdGenerator)), + dataChannels_(std::move(dataChannels)), + bucketCount_(bucketCount), + bucketFunction_(std::move(bucketFunction)), + writerFactory_(std::move(writerFactory)), + spillConfig_(spillConfig), + sortWriterFinishTimeSliceLimitMs_(sortWriterFinishTimeSliceLimitMs), + maxTargetFileBytes_(maxTargetFileBytes), + partitionKeyAsLowerCase_(partitionKeyAsLowerCase) { + fileSystemStats_ = std::make_unique(); +} + +void FileDataSink::appendData(RowVectorPtr input) { + checkRunning(); + + // Lazy load all the input columns. + input->loadedVector(); + + // Write to unpartitioned (and unbucketed) table. + if (!isPartitioned() && !isBucketed()) { + const auto index = ensureWriter(WriterId::unpartitionedId()); + write(index, input); + return; + } + + // Compute partition and bucket numbers. + computePartitionAndBucketIds(input); + + // All inputs belong to a single non-bucketed partition. The partition id + // must be zero. + if (!isBucketed() && partitionIdGenerator_->numPartitions() == 1) { + const auto index = ensureWriter(WriterId{0}); + write(index, input); + return; + } + + splitInputRowsAndEnsureWriters(); + + for (auto index = 0; index < writers_.size(); ++index) { + const vector_size_t partitionSize = partitionSizes_[index]; + if (partitionSize == 0) { + continue; + } + + RowVectorPtr writerInput = partitionSize == input->size() + ? input + : exec::wrap(partitionSize, partitionRows_[index], input); + write(index, writerInput); + } +} + +void FileDataSink::write(size_t index, RowVectorPtr input) { + WRITER_NON_RECLAIMABLE_SECTION_GUARD(index); + auto dataInput = makeDataInput(dataChannels_, input); + + if (writers_[index] == nullptr) { + writers_[index] = createWriterForIndex(index); + } + + writers_[index]->write(dataInput); + writerInfo_[index]->inputSizeInBytes += dataInput->estimateFlatSize(); + writerInfo_[index]->numWrittenRows += dataInput->size(); + writerInfo_[index]->currentFileWrittenRows += dataInput->size(); + + // File rotation is not supported for bucketed tables (require one file per + // bucket with predictable name) or sorted writes (SortingWriter not + // recreated). + if (maxTargetFileBytes_ == 0 || isBucketed() || sortWrite()) { + return; + } + + const auto currentFileBytes = getCurrentFileBytes(index); + if (currentFileBytes >= maxTargetFileBytes_) { + rotateWriter(index); + } +} + +uint64_t FileDataSink::getCurrentFileBytes(size_t writerIndex) const { + VELOX_CHECK_LT(writerIndex, ioStats_.size()); + VELOX_CHECK_LT(writerIndex, writerInfo_.size()); + const auto totalBytes = ioStats_[writerIndex]->rawBytesWritten(); + const auto baselineBytes = writerInfo_[writerIndex]->cumulativeWrittenBytes; + // Sanity check: total should always be >= baseline since ioStats is + // never reset and cumulative is a snapshot of rawBytesWritten at rotation. + VELOX_DCHECK_GE(totalBytes, baselineBytes); + return totalBytes - baselineBytes; +} + +void FileDataSink::finalizeWriterFile(size_t index) { + VELOX_CHECK_LT(index, writerInfo_.size()); + VELOX_CHECK_LT(index, ioStats_.size()); + + auto& info = writerInfo_[index]; + + // Capture current file stats AFTER close to include footer bytes. + const auto currentFileBytes = getCurrentFileBytes(index); + + // Finalize the current file into writtenFiles using the stored names. + if (currentFileBytes > 0) { + FileInfo fileInfo; + fileInfo.writeFileName = info->currentWriteFileName; + fileInfo.targetFileName = info->currentTargetFileName; + fileInfo.fileSize = currentFileBytes; + fileInfo.numRows = info->currentFileWrittenRows; + // Reset for next file. + info->currentFileWrittenRows = 0; + info->writtenFiles.push_back(std::move(fileInfo)); + } + + // Update cumulative stats as a snapshot of total stats so far. + // This becomes the baseline for the next file. + info->cumulativeWrittenBytes = ioStats_[index]->rawBytesWritten(); +} + +void FileDataSink::rotateWriter(size_t index) { + VELOX_CHECK_LT(index, writers_.size()); + VELOX_CHECK_LT(index, writerInfo_.size()); + + auto& info = writerInfo_[index]; + + // Close the writer first to flush all data including footer. + writers_[index]->close(); + + // Finalize the current file state. + finalizeWriterFile(index); + + // Release old writer's memory pools. The new writer will be created lazily + // on the next write to avoid creating empty files. + writers_[index].reset(); + + ++info->fileSequenceNumber; +} + +std::string FileDataSink::stateString(State state) { + switch (state) { + case State::kRunning: + return "RUNNING"; + case State::kFinishing: + return "FLUSHING"; + case State::kClosed: + return "CLOSED"; + case State::kAborted: + return "ABORTED"; + default: + VELOX_UNREACHABLE("BAD STATE: {}", static_cast(state)); + } +} + +DataSink::Stats FileDataSink::stats() const { + Stats stats; + if (state_ == State::kAborted) { + return stats; + } + + for (const auto& ioStats : ioStats_) { + stats.numWrittenBytes += ioStats->rawBytesWritten(); + stats.writeIOTimeUs += ioStats->writeIOTimeUs(); + } + + if (state_ != State::kClosed) { + return stats; + } + + // Count total files written, including rotated files. + stats.numWrittenFiles = 0; + for (size_t i = 0; i < writerInfo_.size(); ++i) { + const auto& info = writerInfo_.at(i); + VELOX_CHECK_NOT_NULL(info); + stats.numWrittenFiles += info->writtenFiles.size(); + if (!info->spillStats->empty()) { + stats.spillStats += *info->spillStats; + } + } + return stats; +} + +std::unordered_map FileDataSink::runtimeStats() + const { + std::unordered_map runtimeStats; + + const auto fsStatsMap = fileSystemStats_->stats(); + for (const auto& [statName, statValue] : fsStatsMap) { + runtimeStats.emplace( + statName, RuntimeCounter(statValue.sum, statValue.unit)); + } + + return runtimeStats; +} + +std::shared_ptr FileDataSink::createWriterPool( + const WriterId& writerId) { + auto* connectorPool = connectorQueryCtx_->connectorMemoryPool(); + return connectorPool->addAggregateChild( + fmt::format("{}.{}", connectorPool->name(), writerId.toString())); +} + +void FileDataSink::setMemoryReclaimers( + WriterInfo* /*writerInfo*/, + io::IoStatistics* /*ioStats*/) { + // Default no-op. Subclasses override to set up format-specific reclaimers. +} + +void FileDataSink::setState(State newState) { + checkStateTransition(state_, newState); + state_ = newState; +} + +void FileDataSink::checkStateTransition(State oldState, State newState) { + switch (oldState) { + case State::kRunning: + if (newState == State::kAborted || newState == State::kFinishing) { + return; + } + break; + case State::kFinishing: + if (newState == State::kAborted || newState == State::kClosed || + // The finishing state is reentry state if we yield in the middle of + // finish processing if a single run takes too long. + newState == State::kFinishing) { + return; + } + [[fallthrough]]; + case State::kAborted: + case State::kClosed: + default: + break; + } + VELOX_FAIL("Unexpected state transition from {} to {}", oldState, newState); +} + +bool FileDataSink::finish() { + // Flush is reentry state. + setState(State::kFinishing); + + // As for now, only sorted writer needs flush buffered data. For non-sorted + // writer, data is directly written to the underlying file writer. + if (!sortWrite()) { + return true; + } + + const uint64_t startTimeMs = getCurrentTimeMs(); + for (auto i = 0; i < writers_.size(); ++i) { + WRITER_NON_RECLAIMABLE_SECTION_GUARD(i); + if (!writers_[i]->finish()) { + return false; + } + if (getCurrentTimeMs() - startTimeMs > sortWriterFinishTimeSliceLimitMs_) { + return false; + } + } + return true; +} + +std::vector FileDataSink::close() { + setState(State::kClosed); + closeInternal(); + return commitMessage(); +} + +void FileDataSink::abort() { + setState(State::kAborted); + closeInternal(); +} + +void FileDataSink::closeInternal() { + VELOX_CHECK_NE(state_, State::kRunning); + VELOX_CHECK_NE(state_, State::kFinishing); + + TestValue::adjust( + "facebook::velox::connector::hive::FileDataSink::closeInternal", this); + + // NOTE: writers_[i] can be nullptr during file rotation. In rotateWriter(), + // we call writers_[index].reset() to release the old writer before creating + // a new one. If an error occurs during new writer creation, or if abort is + // called during this window, the writer slot may be empty. + if (state_ == State::kClosed) { + for (int i = 0; i < writers_.size(); ++i) { + if (writers_[i] == nullptr) { + continue; + } + WRITER_NON_RECLAIMABLE_SECTION_GUARD(i); + writers_[i]->close(); + finalizeWriterFile(i); + } + } else { + for (int i = 0; i < writers_.size(); ++i) { + if (writers_[i] == nullptr) { + continue; + } + WRITER_NON_RECLAIMABLE_SECTION_GUARD(i); + writers_[i]->abort(); + } + } +} + +uint32_t FileDataSink::ensureWriter(const WriterId& id) { + auto it = writerIndexMap_.find(id); + if (it != writerIndexMap_.end()) { + return it->second; + } + return appendWriter(id); +} + +uint32_t FileDataSink::appendWriter(const WriterId& id) { + // Check max open writers. + VELOX_USER_CHECK_LE( + writers_.size(), maxOpenWriters_, "Exceeded open writer limit"); + VELOX_CHECK_EQ(writers_.size(), writerInfo_.size()); + VELOX_CHECK_EQ(writerIndexMap_.size(), writerInfo_.size()); + + std::optional partitionName; + if (isPartitioned()) { + partitionName = getPartitionName(id.partitionId.value()); + } + + // Without explicitly setting flush policy, the default memory based flush + // policy is used. + auto writerParameters = getWriterParameters(partitionName, id.bucketId); + auto writerPool = createWriterPool(id); + auto sinkPool = createSinkPool(writerPool); + std::shared_ptr sortPool{nullptr}; + if (sortWrite()) { + sortPool = createSortPool(writerPool); + } + writerInfo_.emplace_back( + std::make_shared( + std::move(writerParameters), + std::move(writerPool), + std::move(sinkPool), + std::move(sortPool))); + ioStats_.emplace_back(std::make_unique()); + + setMemoryReclaimers(writerInfo_.back().get(), ioStats_.back().get()); + writers_.emplace_back(createWriterForIndex(writerInfo_.size() - 1)); + addThreadLocalRuntimeStat( + fmt::format("{}WriterCount", dwio::common::toString(storageFormat_)), + RuntimeCounter(1)); + // Extends the buffer used for partition rows calculations. + partitionSizes_.emplace_back(0); + partitionRows_.emplace_back(nullptr); + rawPartitionRows_.emplace_back(nullptr); + + writerIndexMap_.emplace(id, writers_.size() - 1); + return writerIndexMap_[id]; +} + +WriterId FileDataSink::getWriterId(vector_size_t row) const { + std::optional partitionId; + if (isPartitioned()) { + VELOX_CHECK_LT(partitionIds_[row], std::numeric_limits::max()); + partitionId = static_cast(partitionIds_[row]); + } + + std::optional bucketId; + if (isBucketed()) { + bucketId = bucketIds_[row]; + } + return WriterId{partitionId, bucketId}; +} + +void FileDataSink::updatePartitionRows( + uint32_t index, + vector_size_t numRows, + vector_size_t row) { + VELOX_DCHECK_LT(index, partitionSizes_.size()); + VELOX_DCHECK_EQ(partitionSizes_.size(), partitionRows_.size()); + VELOX_DCHECK_EQ(partitionRows_.size(), rawPartitionRows_.size()); + if (FOLLY_UNLIKELY(partitionRows_[index] == nullptr) || + (partitionRows_[index]->capacity() < numRows * sizeof(vector_size_t))) { + partitionRows_[index] = + allocateIndices(numRows, connectorQueryCtx_->memoryPool()); + rawPartitionRows_[index] = + partitionRows_[index]->asMutable(); + } + rawPartitionRows_[index][partitionSizes_[index]] = row; + ++partitionSizes_[index]; +} + +void FileDataSink::splitInputRowsAndEnsureWriters() { + VELOX_CHECK(isPartitioned() || isBucketed()); + if (isBucketed() && isPartitioned()) { + VELOX_CHECK_EQ(bucketIds_.size(), partitionIds_.size()); + } + + std::fill(partitionSizes_.begin(), partitionSizes_.end(), 0); + + const auto numRows = static_cast( + isPartitioned() ? partitionIds_.size() : bucketIds_.size()); + for (vector_size_t row = 0; row < numRows; ++row) { + const auto id = getWriterId(row); + const uint32_t index = ensureWriter(id); + updatePartitionRows(index, numRows, row); + } + + for (uint32_t i = 0; i < partitionSizes_.size(); ++i) { + if (partitionSizes_[i] != 0) { + VELOX_CHECK_NOT_NULL(partitionRows_[i]); + partitionRows_[i]->setSize(partitionSizes_[i] * sizeof(vector_size_t)); + } + } +} + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/FileDataSink.h b/velox/connectors/hive/FileDataSink.h new file mode 100644 index 00000000000..d427cc1db22 --- /dev/null +++ b/velox/connectors/hive/FileDataSink.h @@ -0,0 +1,447 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/connectors/Connector.h" +#include "velox/connectors/hive/PartitionIdGenerator.h" +#include "velox/dwio/common/Options.h" +#include "velox/dwio/common/Writer.h" +#include "velox/dwio/common/WriterFactory.h" +#include "velox/exec/MemoryReclaimer.h" + +namespace facebook::velox::connector::hive { + +/// Parameters for file writers. +class WriterParameters { + public: + enum class UpdateMode { + kNew, // Write files to a new directory. + kOverwrite, // Overwrite an existing directory. + // Append mode is currently only supported for unpartitioned tables. + kAppend, // Append to an unpartitioned table. + }; + + /// @param updateMode Write the files to a new directory, or append to an + /// existing directory or overwrite an existing directory. + /// @param partitionName Partition name in the typical Hive style, which is + /// also the partition subdirectory part of the partition path. + /// @param targetFileName The final name of a file after committing. + /// @param targetDirectory The final directory that a file should be in after + /// committing. + /// @param writeFileName The temporary name of the file that a running writer + /// writes to. If a running writer writes directory to the target file, set + /// writeFileName to targetFileName by default. + /// @param writeDirectory The temporary directory that a running writer writes + /// to. If a running writer writes directory to the target directory, set + /// writeDirectory to targetDirectory by default. + WriterParameters( + UpdateMode updateMode, + std::optional partitionName, + std::string targetFileName, + std::string targetDirectory, + const std::optional& writeFileName = std::nullopt, + const std::optional& writeDirectory = std::nullopt) + : updateMode_(updateMode), + partitionName_(std::move(partitionName)), + targetFileName_(std::move(targetFileName)), + targetDirectory_(std::move(targetDirectory)), + writeFileName_(writeFileName.value_or(targetFileName_)), + writeDirectory_(writeDirectory.value_or(targetDirectory_)) {} + + UpdateMode updateMode() const { + return updateMode_; + } + + static std::string updateModeToString(UpdateMode updateMode) { + switch (updateMode) { + case UpdateMode::kNew: + return "NEW"; + case UpdateMode::kOverwrite: + return "OVERWRITE"; + case UpdateMode::kAppend: + return "APPEND"; + default: + VELOX_UNSUPPORTED("Unsupported update mode."); + } + } + + const std::optional& partitionName() const { + return partitionName_; + } + + const std::string& targetFileName() const { + return targetFileName_; + } + + const std::string& writeFileName() const { + return writeFileName_; + } + + const std::string& targetDirectory() const { + return targetDirectory_; + } + + const std::string& writeDirectory() const { + return writeDirectory_; + } + + private: + const UpdateMode updateMode_; + const std::optional partitionName_; + const std::string targetFileName_; + const std::string targetDirectory_; + const std::string writeFileName_; + const std::string writeDirectory_; +}; + +/// Information about a single file written as part of a writer's output. +/// When file rotation occurs, multiple FileInfo entries are created. +struct FileInfo { + /// The temporary file name used during writing (in the staging directory). + std::string writeFileName; + /// The final file name after commit (in the target directory). + std::string targetFileName; + /// Size of the file in bytes. + uint64_t fileSize{0}; + /// Number of rows in the file. + uint64_t numRows{0}; +}; + +struct WriterInfo { + WriterInfo( + WriterParameters parameters, + std::shared_ptr _writerPool, + std::shared_ptr _sinkPool, + std::shared_ptr _sortPool) + : writerParameters(std::move(parameters)), + nonReclaimableSectionHolder(new tsan_atomic(false)), + spillStats(std::make_unique()), + writerPool(std::move(_writerPool)), + sinkPool(std::move(_sinkPool)), + sortPool(std::move(_sortPool)) {} + + // Writer configuration: update mode, partition, file paths. + const WriterParameters writerParameters; + // Guards non-reclaimable sections during write operations. + const std::unique_ptr> nonReclaimableSectionHolder; + // Collects the spill stats from sort writer if the spilling has been + // triggered. + const std::unique_ptr spillStats; + // Memory pool for the writer itself. + const std::shared_ptr writerPool; + // Memory pool for the file sink (serialization layer). + const std::shared_ptr sinkPool; + // Memory pool for sort buffers (nullptr if not a sorted write). + const std::shared_ptr sortPool; + // Total rows written by this writer across all files. + uint64_t numWrittenRows{0}; + // Rows written to the current file; reset to 0 when the file is finalized. + uint64_t currentFileWrittenRows{0}; + uint64_t inputSizeInBytes{0}; + /// File sequence number for tracking multiple files written due to size-based + /// splitting. Incremented each time the writer rotates to a new file. + /// Used to generate sequenced file names (e.g., file_1.orc, file_2.orc). + /// Invariant during write: fileSequenceNumber == writtenFiles.size() + /// After close: fileSequenceNumber + 1 == writtenFiles.size() (final file + /// added) + uint32_t fileSequenceNumber{0}; + /// Tracks all files written by this writer. + /// During write: contains only rotated (completed) files. + /// After close: contains all files including the final one (via + /// finalizeWriterFile). + std::vector writtenFiles; + /// Snapshot of total bytes written at the start of the current file. + /// Used as baseline to calculate current file size: rawBytesWritten() - this. + /// Updated to ioStats->rawBytesWritten() after each rotation. + uint64_t cumulativeWrittenBytes{0}; + /// Current file's write filename (set when file is created/rotated). + /// This avoids recomputing makeSequencedFileName() in commitMessage(). + std::string currentWriteFileName; + /// Current file's target filename (set when file is created/rotated). + std::string currentTargetFileName; +}; + +/// Identifies a writer by partition and bucket. +struct WriterId { + std::optional partitionId{std::nullopt}; + std::optional bucketId{std::nullopt}; + + WriterId() = default; + + explicit WriterId( + std::optional _partitionId, + std::optional _bucketId = std::nullopt) + : partitionId(_partitionId), bucketId(_bucketId) {} + + /// Returns the special writer id for the un-partitioned (and non-bucketed) + /// table. + static const WriterId& unpartitionedId(); + + std::string toString() const; + + bool operator==(const WriterId& other) const { + return std::tie(partitionId, bucketId) == + std::tie(other.partitionId, other.bucketId); + } +}; + +struct WriterIdHasher { + std::size_t operator()(const WriterId& id) const { + return bits::hashMix( + id.partitionId.value_or(std::numeric_limits::max()), + id.bucketId.value_or(std::numeric_limits::max())); + } +}; + +struct WriterIdEq { + bool operator()(const WriterId& lhs, const WriterId& rhs) const { + return lhs == rhs; + } +}; + +/// Base class for file-based data sinks that write data to columnar file +/// formats (ORC, Parquet, etc.). Provides the generic write pipeline: state +/// machine, row routing to writers, file rotation, and stats aggregation. +/// +/// Connector-specific data sinks (Hive, Paimon, etc.) extend this class to +/// add format-specific behavior like commit protocols, partition naming, +/// writer creation, and memory reclamation. +class FileDataSink : public DataSink { + public: + /// Defines the execution states of a file data sink. + enum class State { + /// The data sink accepts new append data in this state. + kRunning = 0, + /// The data sink flushes any buffered data to the underlying file writer + /// but no more data can be appended. + kFinishing = 1, + /// The data sink is aborted on error and no more data can be appended. + kAborted = 2, + /// The data sink is closed on error and no more data can be appended. + kClosed = 3 + }; + static std::string stateString(State state); + + FileDataSink( + RowTypePtr inputType, + const ConnectorQueryCtx* connectorQueryCtx, + CommitStrategy commitStrategy, + dwio::common::FileFormat storageFormat, + uint32_t maxOpenWriters, + std::vector partitionChannels, + std::vector dataChannels, + int32_t bucketCount, + std::unique_ptr bucketFunction, + std::unique_ptr partitionIdGenerator, + std::shared_ptr writerFactory, + uint64_t maxTargetFileBytes, + bool partitionKeyAsLowerCase, + const common::SpillConfig* spillConfig, + uint64_t sortWriterFinishTimeSliceLimitMs); + + void appendData(RowVectorPtr input) override; + + bool finish() override; + + Stats stats() const override; + + std::unordered_map runtimeStats() const override; + + std::vector close() override; + + void abort() override; + + protected: + // Validates the state transition from 'oldState' to 'newState'. + void checkStateTransition(State oldState, State newState); + void setState(State newState); + + // Generates commit messages for all writers containing metadata about written + // files. Creates a JSON object for each writer with partition name, + // file paths, file names, data sizes, and row counts. This metadata is used + // by the coordinator to commit the transaction and update the metastore. + // + // @return Vector of JSON strings, one per writer. + virtual std::vector commitMessage() const = 0; + + // Returns the type of non-partition data columns. + static RowTypePtr getNonPartitionTypes( + const std::vector& dataCols, + const RowTypePtr& inputType); + + // Filters out partition columns from the input, returning only data columns. + static RowVectorPtr makeDataInput( + const std::vector& dataCols, + const RowVectorPtr& input); + + FOLLY_ALWAYS_INLINE bool sortWrite() const { + return sortWrite_; + } + + // Returns true if the table is partitioned. + FOLLY_ALWAYS_INLINE bool isPartitioned() const { + return partitionIdGenerator_ != nullptr; + } + + // Returns true if the table is bucketed. + FOLLY_ALWAYS_INLINE bool isBucketed() const { + return bucketCount_ != 0; + } + + FOLLY_ALWAYS_INLINE bool isCommitRequired() const { + return commitStrategy_ != CommitStrategy::kNoCommit; + } + + std::shared_ptr createWriterPool( + const WriterId& writerId); + + // Sets up memory reclaimers for writer pools. Override to install + // format-specific reclaimers. Default is a no-op. + virtual void setMemoryReclaimers( + WriterInfo* writerInfo, + io::IoStatistics* ioStats); + + // Returns the bytes written to the current file for the specified writer. + uint64_t getCurrentFileBytes(size_t writerIndex) const; + + // Compute the partition id and bucket id for each row in 'input'. + virtual void computePartitionAndBucketIds(const RowVectorPtr& input) = 0; + + // Get the HiveWriter corresponding to the row + // from partitionIds and bucketIds. + WriterId getWriterId(vector_size_t row) const; + + // Computes the number of input rows as well as the actual input row indices + // to each corresponding (bucketed) partition based on the partition and + // bucket ids calculated by 'computePartitionAndBucketIds'. + void splitInputRowsAndEnsureWriters(); + + // Makes sure to create one writer for the given writer id. The function + // returns the corresponding index in 'writers_'. + virtual uint32_t ensureWriter(const WriterId& id); + + // Appends a new writer for the given 'id'. The function returns the index of + // the newly created writer in 'writers_'. + uint32_t appendWriter(const WriterId& id); + + // Creates a writer for the given index using the current file sequence. + virtual std::unique_ptr + createWriterForIndex(size_t writerIndex) = 0; + + // Creates and configures WriterOptions based on file format. + virtual std::shared_ptr createWriterOptions() + const = 0; + + virtual std::shared_ptr createWriterOptions( + size_t writerIndex) const = 0; + + // Returns the partition directory name for the given partition ID. + virtual std::string getPartitionName(uint32_t partitionId) const = 0; + + // Returns writer parameters for the given partition and bucket. + virtual WriterParameters getWriterParameters( + const std::optional& partition, + std::optional bucketId) const = 0; + + // Records a row index for a specific partition. + void + updatePartitionRows(uint32_t index, vector_size_t numRows, vector_size_t row); + + FOLLY_ALWAYS_INLINE void checkRunning() const { + VELOX_CHECK_EQ(state_, State::kRunning, "File data sink is not running"); + } + + // Invoked to write 'input' to the specified file writer. + void write(size_t index, RowVectorPtr input); + + // Rotates the writer at the given index to a new file. + virtual void rotateWriter(size_t index); + + // Finalizes the current file for the writer at the given index. + void finalizeWriterFile(size_t index); + + virtual void closeInternal(); + + // IMPORTANT NOTE: these are passed to writers as raw pointers. FileDataSink + // owns the lifetime of these objects, and therefore must destroy them last. + std::vector> ioStats_; + // Generic filesystem stats, exposed as RuntimeStats + std::unique_ptr fileSystemStats_; + + const RowTypePtr inputType_; + const ConnectorQueryCtx* const connectorQueryCtx_; + const CommitStrategy commitStrategy_; + const dwio::common::FileFormat storageFormat_; + const uint32_t maxOpenWriters_; + const std::vector partitionChannels_; + const std::unique_ptr partitionIdGenerator_; + // Indices of dataChannel are stored in ascending order + const std::vector dataChannels_; + const int32_t bucketCount_{0}; + const std::unique_ptr bucketFunction_; + const std::shared_ptr writerFactory_; + const common::SpillConfig* const spillConfig_; + const uint64_t sortWriterFinishTimeSliceLimitMs_{0}; + const uint64_t maxTargetFileBytes_{0}; + const bool partitionKeyAsLowerCase_; + + /// Whether this sink uses sorted writes. Set by subclass after computing + /// sort columns in its constructor. + bool sortWrite_{false}; + + State state_{State::kRunning}; + + tsan_atomic nonReclaimableSection_{false}; + + // The map from writer id to the writer index in 'writers_' and 'writerInfo_'. + folly::F14FastMap + writerIndexMap_; + + // Below are structures for partitions from all inputs. writerInfo_ and + // writers_ are both indexed by partitionId. + std::vector> writerInfo_; + std::vector> writers_; + + // Below are structures updated when processing current input. partitionIds_ + // are indexed by the row of input_. partitionRows_, rawPartitionRows_ and + // partitionSizes_ are indexed by partitionId. + raw_vector partitionIds_; + std::vector partitionRows_; + std::vector rawPartitionRows_; + std::vector partitionSizes_; + + // Reusable buffers for bucket id calculations. + std::vector bucketIds_; +}; + +FOLLY_ALWAYS_INLINE std::ostream& operator<<( + std::ostream& os, + FileDataSink::State state) { + os << FileDataSink::stateString(state); + return os; +} +} // namespace facebook::velox::connector::hive + +template <> +struct fmt::formatter + : formatter { + auto format( + facebook::velox::connector::hive::FileDataSink::State s, + format_context& ctx) const { + return formatter::format( + facebook::velox::connector::hive::FileDataSink::stateString(s), ctx); + } +}; diff --git a/velox/connectors/hive/FileDataSource.cpp b/velox/connectors/hive/FileDataSource.cpp new file mode 100644 index 00000000000..260b6764a7c --- /dev/null +++ b/velox/connectors/hive/FileDataSource.cpp @@ -0,0 +1,735 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/FileDataSource.h" + +#include +#include +#include + +#include "velox/common/Casts.h" +#include "velox/common/testutil/TestValue.h" +#include "velox/common/time/CpuWallTimer.h" +#include "velox/connectors/hive/ExtractionUtils.h" +#include "velox/connectors/hive/FileConfig.h" +#include "velox/expression/FieldReference.h" + +using facebook::velox::common::testutil::TestValue; + +namespace facebook::velox::connector::hive { + +namespace { + +inline void addIoCounterMetric( + io::IoCounter& counter, + const std::string& key, + std::unordered_map& res) { + if (counter.count() > 0) { + res.insert({key, RuntimeMetric(counter.count())}); + } +} + +inline void addIoCounterMetric( + uint64_t value, + const std::string& key, + RuntimeCounter::Unit unit, + std::unordered_map& res) { + if (value > 0) { + res.insert({key, RuntimeMetric(value, unit)}); + } +} + +inline void addIoStatsMetric( + io::IoCounter& counter, + const std::string& key, + RuntimeCounter::Unit unit, + std::unordered_map& res) { + if (counter.count() > 0) { + res.insert( + {key, + RuntimeMetric( + saturateCast(counter.sum()), + counter.count(), + saturateCast(counter.min()), + saturateCast(counter.max()), + unit)}); + } +} + +inline void addIoLatencyMetric( + io::IoCounter& counter, + const std::string& key, + std::unordered_map& res) { + if (counter.count() > 0) { + res.insert( + {key, + RuntimeMetric( + saturateCast(counter.sum() * 1'000), + counter.count(), + saturateCast(counter.min() * 1'000), + saturateCast(counter.max() * 1'000), + RuntimeCounter::Unit::kNanos)}); + } +} + +void addIoStatsToRuntimeStats( + io::IoStatistics& ioStats, + std::string_view prefix, + std::unordered_map& res) { + auto key = [&](std::string_view name) { + return prefix.empty() ? std::string(name) + : fmt::format("{}.{}", prefix, name); + }; + + addIoLatencyMetric( + ioStats.queryThreadIoLatencyUs(), key(Connector::kIoWaitWallNanos), res); + addIoLatencyMetric( + ioStats.storageReadLatencyUs(), + key(Connector::kStorageReadWallNanos), + res); + addIoLatencyMetric( + ioStats.ssdCacheReadLatencyUs(), + key(Connector::kSsdCacheReadWallNanos), + res); + addIoLatencyMetric( + ioStats.cacheWaitLatencyUs(), key(Connector::kCacheWaitWallNanos), res); + addIoLatencyMetric( + ioStats.coalescedSsdLoadLatencyUs(), + key(Connector::kCoalescedSsdLoadWallNanos), + res); + addIoLatencyMetric( + ioStats.coalescedStorageLoadLatencyUs(), + key(Connector::kCoalescedStorageLoadWallNanos), + res); + + addIoCounterMetric( + ioStats.prefetch(), key(FileDataSource::kNumPrefetch), res); + addIoStatsMetric( + ioStats.prefetch(), + key(FileDataSource::kPrefetchBytes), + RuntimeCounter::Unit::kBytes, + res); + addIoCounterMetric( + ioStats.totalScanTimeNs(), + key(FileDataSource::kTotalScanTime), + RuntimeCounter::Unit::kNanos, + res); + addIoCounterMetric( + ioStats.rawOverreadBytes(), + key(FileDataSource::kOverreadBytes), + RuntimeCounter::Unit::kBytes, + res); + + addIoStatsMetric( + ioStats.read(), + key(FileDataSource::kStorageReadBytes), + RuntimeCounter::Unit::kBytes, + res); + addIoCounterMetric( + ioStats.ssdRead(), key(FileDataSource::kNumLocalRead), res); + addIoStatsMetric( + ioStats.ssdRead(), + key(FileDataSource::kLocalReadBytes), + RuntimeCounter::Unit::kBytes, + res); + addIoCounterMetric(ioStats.ramHit(), key(FileDataSource::kNumRamRead), res); + addIoStatsMetric( + ioStats.ramHit(), + key(FileDataSource::kRamReadBytes), + RuntimeCounter::Unit::kBytes, + res); +} + +} // namespace + +void FileDataSource::processColumnHandle(const FileColumnHandlePtr& handle) { + switch (handle->columnType()) { + case FileColumnHandle::ColumnType::kRegular: + break; + case FileColumnHandle::ColumnType::kPartitionKey: + partitionKeys_.emplace(handle->name(), handle); + break; + case FileColumnHandle::ColumnType::kSynthesized: + infoColumns_.emplace(handle->name(), handle); + break; + case FileColumnHandle::ColumnType::kRowIndex: + specialColumns_.rowIndex = handle->name(); + break; + case FileColumnHandle::ColumnType::kRowId: + specialColumns_.rowId = handle->name(); + break; + } +} + +FileDataSource::FileDataSource( + const RowTypePtr& outputType, + const connector::ConnectorTableHandlePtr& tableHandle, + const connector::ColumnHandleMap& assignments, + FileHandleFactory* fileHandleFactory, + folly::Executor* ioExecutor, + const ConnectorQueryCtx* connectorQueryCtx, + const std::shared_ptr& fileConfig) + : fileHandleFactory_(fileHandleFactory), + ioExecutor_(ioExecutor), + connectorQueryCtx_(connectorQueryCtx), + fileConfig_(fileConfig), + pool_(connectorQueryCtx->memoryPool()), + outputType_(outputType), + expressionEvaluator_(connectorQueryCtx->expressionEvaluator()) { + tableHandle_ = checkedPointerCast(tableHandle); + + folly::F14FastMap columnHandles; + // Column handles keyed on the table column name. + for (const auto& [_, columnHandle] : assignments) { + auto handle = checkedPointerCast(columnHandle); + const auto [it, unique] = + columnHandles.emplace(handle->name(), handle.get()); + if (!unique) { + // This should not happen normally, but there are cases where we get + // duplicate assignments for partitioning columns. + checkColumnHandleConsistent(*handle, *it->second); + VELOX_CHECK_EQ( + handle->columnType(), + FileColumnHandle::ColumnType::kPartitionKey, + "Cannot map from same table column to different outputs in table scan; a project node should be used instead: {}", + handle->name()); + continue; + } + processColumnHandle(handle); + } + for (auto& handle : tableHandle_->filterColumnHandles()) { + auto it = columnHandles.find(handle->name()); + if (it != columnHandles.end()) { + checkColumnHandleConsistent(*handle, *it->second); + continue; + } + processColumnHandle(handle); + } + + std::vector readColumnNames; + auto readColumnTypes = outputType_->children(); + for (const auto& outputName : outputType_->names()) { + auto it = assignments.find(outputName); + VELOX_CHECK( + it != assignments.end(), + "ColumnHandle is missing for output column: {}", + outputName); + + auto* handle = static_cast(it->second.get()); + readColumnNames.push_back(handle->name()); + for (auto& subfield : handle->requiredSubfields()) { + VELOX_USER_CHECK_EQ( + getColumnName(subfield), + handle->name(), + "Required subfield does not match column name"); + subfields_[handle->name()].push_back(&subfield); + } + columnPostProcessors_.push_back(handle->postProcessor()); + } + + if (fileConfig_->isFileColumnNamesReadAsLowerCase( + connectorQueryCtx->sessionProperties())) { + checkColumnNameLowerCase(outputType_); + checkColumnNameLowerCase(tableHandle_->subfieldFilters(), infoColumns_); + checkColumnNameLowerCase(tableHandle_->remainingFilter()); + } + + for (const auto& [k, v] : tableHandle_->subfieldFilters()) { + filters_.emplace(k.clone(), v); + } + double sampleRate = tableHandle_->sampleRate(); + auto remainingFilter = extractFiltersFromRemainingFilter( + tableHandle_->remainingFilter(), + expressionEvaluator_, + filters_, + sampleRate); + if (sampleRate != 1) { + randomSkip_ = std::make_shared(sampleRate); + } + + if (remainingFilter) { + remainingFilterExprSet_ = expressionEvaluator_->compile(remainingFilter); + auto& remainingFilterExpr = remainingFilterExprSet_->expr(0); + folly::F14FastMap columnNames; + for (int i = 0; i < readColumnNames.size(); ++i) { + columnNames[readColumnNames[i]] = i; + } + for (auto& input : remainingFilterExpr->distinctFields()) { + auto it = columnNames.find(input->field()); + if (it != columnNames.end()) { + if (shouldEagerlyMaterialize(*remainingFilterExpr, *input)) { + multiReferencedFields_.push_back(it->second); + } + continue; + } + // Remaining filter may reference columns that are not used otherwise, + // e.g. are not being projected out and are not used in range filters. + // Make sure to add these columns to readerOutputType_. + readColumnNames.push_back(input->field()); + readColumnTypes.push_back(input->type()); + } + remainingFilterSubfields_ = remainingFilterExpr->extractSubfields(); + if (VLOG_IS_ON(1)) { + VLOG(1) << fmt::format( + "Extracted subfields from remaining filter: [{}]", + fmt::join(remainingFilterSubfields_, ", ")); + } + for (auto& subfield : remainingFilterSubfields_) { + const auto& name = getColumnName(subfield); + auto it = subfields_.find(name); + if (it != subfields_.end()) { + // Some subfields of the column are already projected out, we append the + // remainingFilter subfield + it->second.push_back(&subfield); + } else if (columnNames.count(name) == 0) { + // remainingFilter subfield's column is not projected out, we add the + // column and append the subfield + subfields_[name].push_back(&subfield); + } + } + } + + readerOutputType_ = + ROW(std::move(readColumnNames), std::move(readColumnTypes)); + scanSpec_ = makeScanSpec( + readerOutputType_, + subfields_, + filters_, + /*indexColumns=*/{}, + tableHandle_->dataColumns(), + partitionKeys_, + infoColumns_, + specialColumns_, + fileConfig_->readStatsBasedFilterReorderDisabled( + connectorQueryCtx_->sessionProperties()), + pool_); + if (remainingFilter) { + metadataFilter_ = std::make_shared( + *scanSpec_, *remainingFilter, expressionEvaluator_); + } + + // Detect extraction columns and reconfigure scanSpec_ if needed. + bool hasExtractions = false; + readColumnTypes = readerOutputType_->children(); + for (int outputIdx = 0; outputIdx < outputType->size(); ++outputIdx) { + const auto& outputName = outputType->nameOf(outputIdx); + auto it = assignments.find(outputName); + if (it == assignments.end()) { + continue; + } + auto* handle = static_cast(it->second.get()); + if (!handle->extractions().empty()) { + // Column has extraction chains. Read with schemaType from file, then + // apply extraction post-read. Extractions and requiredSubfields are + // mutually exclusive (enforced by the column handle constructor). + auto readerIdx = readerOutputType_->getChildIdxIfExists(handle->name()); + if (readerIdx.has_value()) { + readColumnTypes[*readerIdx] = handle->schemaType(); + extractionColumns_[*readerIdx] = handle; + hasExtractions = true; + } + } + } + + if (hasExtractions) { + // Rebuild readerOutputType_ with schemaType for extraction columns. + readerOutputType_ = + ROW(std::vector( + readerOutputType_->names().begin(), + readerOutputType_->names().end()), + std::move(readColumnTypes)); + // Rebuild scanSpec_ with the updated readerOutputType_. + scanSpec_ = makeScanSpec( + readerOutputType_, + subfields_, + filters_, + /*indexColumns=*/{}, + tableHandle_->dataColumns(), + partitionKeys_, + infoColumns_, + specialColumns_, + fileConfig_->readStatsBasedFilterReorderDisabled( + connectorQueryCtx->sessionProperties()), + pool_); + configureExtractionColumns(); + } + + dataIoStats_ = std::make_shared(); + metadataIoStats_ = std::make_shared(); + ioStats_ = std::make_shared(); +} + +void FileDataSource::configureExtractionColumns() { + // Configure extraction columns on the ScanSpec. For each column with + // extractions, this: + // 1. Sets pruning hints so DWRF/Nimble readers skip unneeded sub-streams. + // 2. Sets a transform function on the ScanSpec node so the reader applies + // extraction chains and produces the output type directly. + for (auto& [colIdx, handle] : extractionColumns_) { + auto* fieldSpec = scanSpec_->childByName(readerOutputType_->nameOf(colIdx)); + if (!fieldSpec) { + continue; + } + const auto& extractions = handle->extractions(); + auto extractionOutputType = handle->dataType(); + + // For multiple extractions, do NOT call configureExtractionScanSpec -- + // keep ExtractionType as kNone and use full chains in the transform. + // This ensures the text reader (which does not handle ExtractionType + // natively) produces correct results. + if (extractions.size() == 1) { + configureExtractionScanSpec( + handle->schemaType(), extractions, *fieldSpec, pool_); + } + if (extractions.size() == 1) { + // Store a full-chain transform so hasTransform() returns true. This + // signals to the delta update path that extraction is configured. + // The full chain is captured for PrismSplitReader to replace it. + fieldSpec->setTransform( + [fullChain = extractions[0].chain]( + const VectorPtr& input, memory::MemoryPool* pool) -> VectorPtr { + return applyExtractionChain(input, fullChain, pool); + }, + extractionOutputType); + } else { + // Multiple extractions: do NOT set ExtractionType on the ScanSpec. + // Use full chains in the transform so the text reader (which does + // not handle ExtractionType natively) produces correct results. + // TODO: Optimization: for agreeing multiple extractions, set + // ExtractionType and use remaining chains. Requires text reader + // to handle ExtractionType natively. + struct ExtractionInfo { + std::string outputName; + std::vector chain; + }; + + std::vector infos; + for (const auto& extraction : extractions) { + infos.push_back({extraction.outputName, extraction.chain}); + } + // Always need a transform for multiple extractions to assemble ROW. + fieldSpec->setTransform( + [infos = std::move(infos)]( + const VectorPtr& input, memory::MemoryPool* pool) -> VectorPtr { + std::vector children; + std::vector names; + std::vector types; + children.reserve(infos.size()); + names.reserve(infos.size()); + types.reserve(infos.size()); + for (const auto& info : infos) { + VectorPtr extracted; + if (info.chain.empty()) { + extracted = input; + } else { + extracted = applyExtractionChain(input, info.chain, pool); + } + names.push_back(info.outputName); + types.push_back(extracted->type()); + children.push_back(std::move(extracted)); + } + return std::make_shared( + pool, + ROW(std::move(names), std::move(types)), + nullptr, + input->size(), + std::move(children)); + }, + extractionOutputType); + } + } + + // Build readerProducedType_ -- the actual type the reader will produce. + // For extraction columns where the reader handles extraction natively + // (ExtractionType != kNone), the output type differs from schemaType. + { + auto names = readerOutputType_->names(); + auto types = readerOutputType_->children(); + bool needsSeparateType = false; + for (auto& [colIdx, handle] : extractionColumns_) { + auto* fieldSpec = + scanSpec_->childByName(readerOutputType_->nameOf(colIdx)); + if (fieldSpec && + fieldSpec->extractionType() != + common::ScanSpec::ExtractionType::kNone) { + VELOX_CHECK_LT(static_cast(colIdx), types.size()); + types[colIdx] = handle->dataType(); + needsSeparateType = true; + } + } + if (needsSeparateType) { + readerProducedType_ = + ROW(std::vector(names.begin(), names.end()), + std::move(types)); + } + } +} + +std::unique_ptr FileDataSource::createSplitReader() { + return FileSplitReader::create( + split_, + tableHandle_, + &partitionKeys_, + connectorQueryCtx_, + fileConfig_, + readerOutputType_, + dataIoStats_, + metadataIoStats_, + ioStats_, + fileHandleFactory_, + ioExecutor_, + scanSpec_, + /*subfieldFiltersForValidation=*/&filters_); +} + +void FileDataSource::addSplit(std::shared_ptr split) { + VELOX_CHECK_NULL( + split_, + "Previous split has not been processed yet. Call next to process the split."); + split_ = checkedPointerCast(split); + + VLOG(1) << "Adding split " << split_->toString(); + + if (splitReader_) { + splitReader_.reset(); + } + + splitReader_ = createSplitReader(); + + // Split reader subclasses may need to use the reader options in prepareSplit + // so we initialize it beforehand. + splitReader_->configureReaderOptions(randomSkip_); + splitReader_->prepareSplit(metadataFilter_, runtimeStats_); + readerOutputType_ = splitReader_->readerOutputType(); +} + +std::optional FileDataSource::next( + uint64_t size, + velox::ContinueFuture& /*future*/) { + VELOX_CHECK(split_ != nullptr, "No split to process. Call addSplit first."); + VELOX_CHECK_NOT_NULL(splitReader_, "No split reader present"); + + TestValue::adjust( + "facebook::velox::connector::hive::FileDataSource::next", this); + + if (splitReader_->emptySplit()) { + resetSplit(); + return nullptr; + } + + // Subclass reader may add extra columns to reader output (e.g. for bucket + // conversion or delta update). + auto& outputRowType = + readerProducedType_ ? readerProducedType_ : readerOutputType_; + auto needsExtraColumn = [&] { + return output_->asUnchecked()->childrenSize() < + outputRowType->size(); + }; + if (!output_ || needsExtraColumn()) { + output_ = BaseVector::create(outputRowType, 0, pool_); + } + + const auto rowsScanned = splitReader_->next(size, output_); + completedRows_ += rowsScanned; + if (rowsScanned == 0) { + splitReader_->updateRuntimeStats(runtimeStats_); + resetSplit(); + return nullptr; + } + + VELOX_CHECK( + !output_->mayHaveNulls(), "Top-level row vector cannot have nulls"); + auto rowsRemaining = output_->size(); + if (rowsRemaining == 0) { + // no rows passed the pushed down filters. + return getEmptyOutput(); + } + + auto rowVector = std::dynamic_pointer_cast(output_); + + // In case there is a remaining filter that excludes some but not all + // rows, collect the indices of the passing rows. If there is no filter, + // or it passes on all rows, leave this as null and let exec::wrap skip + // wrapping the results. + BufferPtr remainingIndices; + filterRows_.resize(rowVector->size()); + + if (remainingFilterExprSet_) { + rowsRemaining = evaluateRemainingFilter(rowVector); + VELOX_CHECK_LE(rowsRemaining, rowsScanned); + if (rowsRemaining == 0) { + // No rows passed the remaining filter. + return getEmptyOutput(); + } + + if (rowsRemaining < rowVector->size()) { + // Some, but not all rows passed the remaining filter. + remainingIndices = filterEvalCtx_.selectedIndices; + } + } + + if (outputType_->size() == 0) { + return exec::wrap(rowsRemaining, remainingIndices, rowVector); + } + + std::vector outputColumns; + outputColumns.reserve(outputType_->size()); + for (int i = 0; i < outputType_->size(); ++i) { + auto& child = rowVector->childAt(i); + if (remainingIndices) { + // Disable dictionary values caching in expression eval so that we + // don't need to reallocate the result for every batch. + child->disableMemo(); + } + auto column = exec::wrapChild(rowsRemaining, remainingIndices, child); + if (columnPostProcessors_[i]) { + columnPostProcessors_[i](column); + } + outputColumns.push_back(std::move(column)); + } + + return std::make_shared( + pool_, outputType_, BufferPtr(nullptr), rowsRemaining, outputColumns); +} + +void FileDataSource::addDynamicFilter( + column_index_t outputChannel, + const std::shared_ptr& filter) { + auto& fieldSpec = scanSpec_->getChildByChannel(outputChannel); + fieldSpec.setFilter(filter); + scanSpec_->resetCachedValues(true); + if (splitReader_) { + splitReader_->resetFilterCaches(); + } +} + +void FileDataSource::fireScanBatchCallback(core::ScanBatchEvent event) { + if (!scanBatchCallback_) { + return; + } + FileScanBatchEvent fileEvent; + fileEvent.numRows = event.numRows; + fileEvent.wallTimeMicros = event.wallTimeMicros; + if (tableHandle_) { + fileEvent.tableName = tableHandle_->name(); + } + if (split_) { + fileEvent.filePath = split_->filePath; + if (!split_->partitionKeys.empty()) { + fileEvent.partitionKeys = &split_->partitionKeys; + } + } + scanBatchCallback_(fileEvent); +} + +std::unordered_map +FileDataSource::getRuntimeStats() { + auto res = runtimeStats_.toRuntimeMetricMap(); + addIoStatsToRuntimeStats(*dataIoStats_, "", res); + addIoStatsToRuntimeStats(*metadataIoStats_, kMetadataPrefix, res); + res.insert( + {{std::string(Connector::kTotalRemainingFilterTime), + RuntimeMetric( + totalRemainingFilterTime_.load(std::memory_order_relaxed), + RuntimeCounter::Unit::kNanos)}, + {Connector::kTotalRemainingFilterCpuTime, + RuntimeMetric( + totalRemainingFilterCpuTime_.load(std::memory_order_relaxed), + RuntimeCounter::Unit::kNanos)}}); + + const auto ioStatsMap = ioStats_->stats(); + for (const auto& [key, value] : ioStatsMap) { + // IoStats may carry a ReadFile-layer storageReadBytes that reflects the + // actual bytes fetched from remote storage. Use it to override the + // DWIO-level estimate (IoStatistics). + if (key == kStorageReadBytes) { + res[std::string(key)] = value; + } else { + res.emplace(key, value); + } + } + return res; +} + +void FileDataSource::setFromDataSource( + std::unique_ptr sourceUnique) { + auto source = dynamic_cast(sourceUnique.get()); + VELOX_CHECK_NOT_NULL(source, "Bad DataSource type"); + + split_ = std::move(source->split_); + runtimeStats_.skippedSplits += source->runtimeStats_.skippedSplits; + runtimeStats_.processedSplits += source->runtimeStats_.processedSplits; + runtimeStats_.skippedSplitBytes += source->runtimeStats_.skippedSplitBytes; + readerOutputType_ = std::move(source->readerOutputType_); + readerProducedType_ = std::move(source->readerProducedType_); + extractionColumns_ = std::move(source->extractionColumns_); + source->scanSpec_->moveAdaptationFrom(*scanSpec_); + scanSpec_ = std::move(source->scanSpec_); + metadataFilter_ = std::move(source->metadataFilter_); + splitReader_ = std::move(source->splitReader_); + splitReader_->setConnectorQueryCtx(connectorQueryCtx_); + // New io will be accounted on the stats of 'source'. Add the existing + // balance to that. + source->dataIoStats_->merge(*dataIoStats_); + dataIoStats_ = std::move(source->dataIoStats_); + source->metadataIoStats_->merge(*metadataIoStats_); + metadataIoStats_ = std::move(source->metadataIoStats_); + source->ioStats_->merge(*ioStats_); + ioStats_ = std::move(source->ioStats_); +} + +int64_t FileDataSource::estimatedRowSize() { + if (splitReader_ == nullptr) { + return kUnknownRowSize; + } + auto rowSize = splitReader_->estimatedRowSize(); + TestValue::adjust( + "facebook::velox::connector::hive::FileDataSource::estimatedRowSize", + &rowSize); + return rowSize; +} + +vector_size_t FileDataSource::evaluateRemainingFilter(RowVectorPtr& rowVector) { + for (auto fieldIndex : multiReferencedFields_) { + LazyVector::ensureLoadedRows( + rowVector->childAt(fieldIndex), + filterRows_, + filterLazyDecoded_, + filterLazyBaseRows_); + } + CpuWallTiming filterTiming; + vector_size_t rowsRemaining{0}; + { + CpuWallTimer timer(filterTiming); + expressionEvaluator_->evaluate( + remainingFilterExprSet_.get(), filterRows_, *rowVector, filterResult_); + rowsRemaining = exec::processFilterResults( + filterResult_, filterRows_, filterEvalCtx_, pool_); + } + totalRemainingFilterTime_.fetch_add( + filterTiming.wallNanos, std::memory_order_relaxed); + totalRemainingFilterCpuTime_.fetch_add( + filterTiming.cpuNanos, std::memory_order_relaxed); + return rowsRemaining; +} + +void FileDataSource::resetSplit() { + split_.reset(); + splitReader_->resetSplit(); + // Keep readers around to hold adaptation. +} + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/FileDataSource.h b/velox/connectors/hive/FileDataSource.h new file mode 100644 index 00000000000..b03c19deff6 --- /dev/null +++ b/velox/connectors/hive/FileDataSource.h @@ -0,0 +1,231 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/common/base/RandomUtil.h" +#include "velox/common/file/FileSystems.h" +#include "velox/common/io/IoStatistics.h" +#include "velox/connectors/Connector.h" +#include "velox/connectors/hive/FileConnectorSplit.h" +#include "velox/connectors/hive/FileHandle.h" +#include "velox/connectors/hive/FileSplitReader.h" +#include "velox/connectors/hive/FileTableHandle.h" +#include "velox/connectors/hive/HiveConnectorUtil.h" +#include "velox/dwio/common/Statistics.h" +#include "velox/exec/OperatorUtils.h" +#include "velox/expression/Expr.h" + +namespace facebook::velox::connector::hive { + +/// File-specific scan batch event with split metadata. +struct FileScanBatchEvent : public core::ScanBatchEvent { + /// Table name from the connector table handle. + std::string_view tableName; + /// File path of the current split. + std::string_view filePath; + /// Non-owning pointer to the current split's partition keys. + /// Null when partition keys are not available. + const std::unordered_map>* + partitionKeys{nullptr}; +}; + +class FileConfig; + +/// Base class for file-based data sources that read from columnar file formats +/// (ORC, Parquet, etc.) using FileSplitReader. Provides the common scan +/// pipeline: column resolution, filter extraction, scan spec construction, +/// split reading, remaining filter evaluation, and runtime stats collection. +/// +/// Connector-specific data sources (Hive, Paimon, etc.) extend this class to +/// add format-specific behavior like bucket conversion, multi-file splits, or +/// merge-on-read. +class FileDataSource : public DataSource { + public: + /// Runtime stat keys for file-based data sources. Data IO stats use the + /// keys directly (e.g., "storageReadBytes"). Metadata IO stats use the + /// kMetadataPrefix (e.g., "metadata.storageReadBytes"). + static constexpr std::string_view kMetadataPrefix{"metadata"}; + static constexpr std::string_view kNumPrefetch{"numPrefetch"}; + static constexpr std::string_view kPrefetchBytes{"prefetchBytes"}; + static constexpr std::string_view kTotalScanTime{"totalScanTime"}; + static constexpr std::string_view kOverreadBytes{"overreadBytes"}; + static constexpr std::string_view kStorageReadBytes{"storageReadBytes"}; + static constexpr std::string_view kNumLocalRead{"numLocalRead"}; + static constexpr std::string_view kLocalReadBytes{"localReadBytes"}; + static constexpr std::string_view kNumRamRead{"numRamRead"}; + static constexpr std::string_view kRamReadBytes{"ramReadBytes"}; + + FileDataSource( + const RowTypePtr& outputType, + const connector::ConnectorTableHandlePtr& tableHandle, + const connector::ColumnHandleMap& assignments, + FileHandleFactory* fileHandleFactory, + folly::Executor* ioExecutor, + const ConnectorQueryCtx* connectorQueryCtx, + const std::shared_ptr& fileConfig); + + void addSplit(std::shared_ptr split) override; + + std::optional next(uint64_t size, velox::ContinueFuture& future) + override; + + void addDynamicFilter( + column_index_t outputChannel, + const std::shared_ptr& filter) override; + + uint64_t getCompletedBytes() override { + return dataIoStats_->rawBytesRead(); + } + + uint64_t getCompletedRows() override { + return completedRows_; + } + + void fireScanBatchCallback(core::ScanBatchEvent event) override; + + std::unordered_map getRuntimeStats() override; + + bool allPrefetchIssued() const override { + return splitReader_ && splitReader_->allPrefetchIssued(); + } + + void setFromDataSource(std::unique_ptr sourceUnique) override; + + int64_t estimatedRowSize() override; + + const common::SubfieldFilters* getFilters() const override { + return &filters_; + } + + const ConnectorQueryCtx* testingConnectorQueryCtx() const { + return connectorQueryCtx_; + } + + protected: + virtual std::unique_ptr createSplitReader(); + + FileHandleFactory* const fileHandleFactory_; + folly::Executor* const ioExecutor_; + const ConnectorQueryCtx* const connectorQueryCtx_; + const std::shared_ptr fileConfig_; + memory::MemoryPool* const pool_; + + std::shared_ptr split_; + FileTableHandlePtr tableHandle_; + std::shared_ptr scanSpec_; + VectorPtr output_; + std::unique_ptr splitReader_; + + /// Output type from file reader. This is different from outputType_ in that + /// it contains column names before assignment, and columns that are only used + /// in the remaining filter. + RowTypePtr readerOutputType_; + + /// Column handles for the partition key columns keyed on partition key column + /// name. + std::unordered_map partitionKeys_; + + std::shared_ptr dataIoStats_; + std::shared_ptr metadataIoStats_; + std::shared_ptr ioStats_; + + /// Column handles for the split info columns keyed on their column names. + std::unordered_map infoColumns_; + SpecialColumnNames specialColumns_{}; + + /// Subfield pruning info collected from column handles and remaining filter. + folly::F14FastMap> + subfields_; + common::SubfieldFilters filters_; + + const exec::ExprSet* remainingFilterExprSet() const { + return remainingFilterExprSet_.get(); + } + + const std::shared_ptr& metadataFilter() const { + return metadataFilter_; + } + + // Actual type produced by the reader after extraction pushdown. Differs + // from readerOutputType_ when the reader handles extraction natively + // (e.g., MapKeys -> ARRAY). Null if no extraction pushdown is active. + RowTypePtr readerProducedType_; + + // Output columns that have extraction chains. Maps output column index to + // the column handle. These columns are read with schemaType and transformed + // post-read using the extraction chains. + folly::F14FastMap extractionColumns_; + + dwio::common::RuntimeStatistics runtimeStats_; + + private: + // Configure extraction columns on the ScanSpec and build + // readerProducedType_. Called from the constructor after scanSpec_ is + // created. + void configureExtractionColumns(); + + /// Adds the information from column handle to the corresponding fields in + /// this object. + void processColumnHandle(const FileColumnHandlePtr& handle); + + /// Evaluates remainingFilter_ on the specified vector. Returns number of rows + /// passed. Populates filterEvalCtx_.selectedIndices and selectedBits if only + /// some rows passed the filter. If none or all rows passed + /// filterEvalCtx_.selectedIndices and selectedBits are not updated. + vector_size_t evaluateRemainingFilter(RowVectorPtr& rowVector); + + /// Clears split_ after split has been fully processed. Keeps readers around + /// to hold adaptation. + void resetSplit(); + + const RowVectorPtr& getEmptyOutput() { + if (!emptyOutput_) { + emptyOutput_ = RowVector::createEmpty(outputType_, pool_); + } + return emptyOutput_; + } + + /// The row type for the data source output, not including filter-only + /// columns. + const RowTypePtr outputType_; + core::ExpressionEvaluator* const expressionEvaluator_; + + std::vector remainingFilterSubfields_; + /// Optional post-processors for each output column, collected from + /// HiveColumnHandle::postProcessor(). Applied after reading and filtering to + /// transform column values. Indexed by output column position. + std::vector> columnPostProcessors_; + std::shared_ptr metadataFilter_; + std::unique_ptr remainingFilterExprSet_; + RowVectorPtr emptyOutput_; + std::atomic_uint64_t totalRemainingFilterTime_{0}; + std::atomic_uint64_t totalRemainingFilterCpuTime_{0}; + uint64_t completedRows_{0}; + /// Field indices referenced in both remaining filter and output type. These + /// columns need to be materialized eagerly to avoid missing values in output. + std::vector multiReferencedFields_; + + std::shared_ptr randomSkip_; + + /// Reusable memory for remaining filter evaluation. + VectorPtr filterResult_; + SelectivityVector filterRows_; + DecodedVector filterLazyDecoded_; + SelectivityVector filterLazyBaseRows_; + exec::FilterEvalCtx filterEvalCtx_; +}; + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/FileHandle.cpp b/velox/connectors/hive/FileHandle.cpp index 267691cce2e..94b6bffb615 100644 --- a/velox/connectors/hive/FileHandle.cpp +++ b/velox/connectors/hive/FileHandle.cpp @@ -14,65 +14,6 @@ * limitations under the License. */ +// Implementation moved to velox/common/caching/FileHandle.cpp. +// This file is kept so the hive file_handle BUCK target still compiles. #include "velox/connectors/hive/FileHandle.h" -#include "velox/common/base/Counters.h" -#include "velox/common/base/StatsReporter.h" -#include "velox/common/file/FileSystems.h" -#include "velox/common/time/Timer.h" - -#include - -namespace facebook::velox { - -uint64_t FileHandleSizer::operator()(const FileHandle& fileHandle) { - // TODO: add to support variable file cache size support when the file system - // underneath supports. - return 1; -} - -namespace { -// The group tracking is at the level of the directory, i.e. Hive partition. -std::string groupName(const std::string& filename) { - const char* slash = strrchr(filename.c_str(), '/'); - return slash ? std::string(filename.data(), slash - filename.data()) - : filename; -} -} // namespace - -std::unique_ptr FileHandleGenerator::operator()( - const FileHandleKey& key, - const FileProperties* properties, - filesystems::File::IoStats* stats) { - // We have seen cases where drivers are stuck when creating file handles. - // Adding a trace here to spot this more easily in future. - process::TraceContext trace("FileHandleGenerator::operator()"); - uint64_t elapsedTimeUs{0}; - std::unique_ptr fileHandle; - { - MicrosecondTimer timer(&elapsedTimeUs); - fileHandle = std::make_unique(); - filesystems::FileOptions options; - options.stats = stats; - options.tokenProvider = key.tokenProvider; - if (properties) { - options.fileSize = properties->fileSize; - options.readRangeHint = properties->readRangeHint; - options.extraFileInfo = properties->extraFileInfo; - } - const auto& filename = key.filename; - fileHandle->file = filesystems::getFileSystem(filename, properties_) - ->openFileForRead(filename, options); - fileHandle->uuid = StringIdLease(fileIds(), filename); - fileHandle->groupId = StringIdLease(fileIds(), groupName(filename)); - VLOG(1) << "Generating file handle for: " << filename - << " uuid: " << fileHandle->uuid.id(); - } - RECORD_HISTOGRAM_METRIC_VALUE( - kMetricHiveFileHandleGenerateLatencyMs, elapsedTimeUs / 1000); - // TODO: build the hash map/etc per file type -- presumably after reading - // the appropriate magic number from the file, or perhaps we include the file - // type in the file handle key. - return fileHandle; -} - -} // namespace facebook::velox diff --git a/velox/connectors/hive/FileHandle.h b/velox/connectors/hive/FileHandle.h index 6f9b4050c31..7ac67f1a33c 100644 --- a/velox/connectors/hive/FileHandle.h +++ b/velox/connectors/hive/FileHandle.h @@ -14,117 +14,8 @@ * limitations under the License. */ -// A FileHandle is a File pointer plus some (optional, file-type-dependent) -// extra information for speeding up loading columnar data. For example, when -// we open a file we might build a hash map saying what region(s) on disk -// correspond to a given column in a given stripe. -// -// The FileHandle will normally be used in conjunction with a CachedFactory -// to speed up queries that hit the same files repeatedly; see the -// FileHandleCache and FileHandleFactory. - #pragma once -#include "velox/common/base/BitUtil.h" -#include "velox/common/caching/CachedFactory.h" -#include "velox/common/caching/FileIds.h" -#include "velox/common/config/Config.h" -#include "velox/common/file/File.h" -#include "velox/common/file/TokenProvider.h" -#include "velox/connectors/hive/FileProperties.h" - -namespace facebook::velox { - -// See the file comment. -struct FileHandle { - std::shared_ptr file; - - // Each time we make a new FileHandle we assign it a uuid and use that id as - // the identifier in downstream data caching structures. This saves a lot of - // memory compared to using the filename as the identifier. - StringIdLease uuid; - - // Id for the group of files this belongs to, e.g. its - // directory. Used for coarse granularity access tracking, for - // example to decide placing on SSD. - StringIdLease groupId; - - // We'll want to have a hash map here to record the identifier->byte range - // mappings. Different formats may have different identifiers, so we may need - // a union of maps. For example in orc you need 3 integers (I think, to be - // confirmed with xldb): the row bundle, the node, and the sequence. For the - // first diff we'll not include the map. -}; - -/// Estimates the memory usage of a FileHandle object. -struct FileHandleSizer { - uint64_t operator()(const FileHandle& a); -}; - -struct FileHandleKey { - std::string filename; - std::shared_ptr tokenProvider{nullptr}; - - bool operator==(const FileHandleKey& other) const { - if (filename != other.filename) { - return false; - } - - if (tokenProvider == other.tokenProvider) { - return true; - } - - if (!tokenProvider || !other.tokenProvider) { - return false; - } - - return tokenProvider->equals(*other.tokenProvider); - } -}; - -} // namespace facebook::velox - -namespace std { -template <> -struct hash { - size_t operator()(const facebook::velox::FileHandleKey& key) const noexcept { - size_t filenameHash = std::hash()(key.filename); - return key.tokenProvider ? facebook::velox::bits::hashMix( - filenameHash, key.tokenProvider->hash()) - : filenameHash; - } -}; -} // namespace std - -namespace facebook::velox { -using FileHandleCache = - SimpleLRUCache; - -// Creates FileHandles via the Generator interface the CachedFactory requires. -class FileHandleGenerator { - public: - FileHandleGenerator() {} - FileHandleGenerator(std::shared_ptr properties) - : properties_(std::move(properties)) {} - std::unique_ptr operator()( - const FileHandleKey& filename, - const FileProperties* properties, - filesystems::File::IoStats* stats); - - private: - const std::shared_ptr properties_; -}; - -using FileHandleFactory = CachedFactory< - FileHandleKey, - FileHandle, - FileHandleGenerator, - FileProperties, - filesystems::File::IoStats, - FileHandleSizer>; - -using FileHandleCachedPtr = CachedPtr; - -using FileHandleCacheStats = SimpleLRUCacheStats; - -} // namespace facebook::velox +// Moved to velox/common/caching/FileHandle.h. +// This header is kept for backward compatibility. +#include "velox/common/caching/FileHandle.h" diff --git a/velox/connectors/hive/FileIndexReader.cpp b/velox/connectors/hive/FileIndexReader.cpp new file mode 100644 index 00000000000..844f3bdc594 --- /dev/null +++ b/velox/connectors/hive/FileIndexReader.cpp @@ -0,0 +1,458 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/FileIndexReader.h" + +#include "velox/common/Casts.h" +#include "velox/connectors/hive/BufferedInputBuilder.h" +#include "velox/connectors/hive/FileConfig.h" +#include "velox/connectors/hive/HiveConnectorSplit.h" +#include "velox/connectors/hive/HiveConnectorUtil.h" +#include "velox/connectors/hive/TableHandle.h" +#include "velox/dwio/common/ReaderFactory.h" +#include "velox/serializers/KeyEncoder.h" + +namespace facebook::velox::connector::hive { +namespace { + +// Gets the index of a column in the request type by name. +// Throws if the column is not found. +column_index_t getRequestColumnIndex( + const std::string& name, + const RowTypePtr& requestType) { + const auto idxOpt = requestType->getChildIdxIfExists(name); + VELOX_CHECK( + idxOpt.has_value(), "Request column {} not found in request type", name); + return idxOpt.value(); +} + +// Processes a between condition bound (lower or upper). +// Returns true if the bound is a field access (non-constant). +void processBetweenBound( + const core::TypedExprPtr& bound, + size_t boundIndex, + const char* boundName, + std::vector>& constantBoundValues, + std::vector& requestColumnIndices, + const RowTypePtr& requestType) { + if (auto constantExpr = + std::dynamic_pointer_cast(bound)) { + VELOX_CHECK( + !constantExpr->hasValueVector(), + "Complex constant values not supported for between condition {} bound", + boundName); + VELOX_CHECK( + !constantExpr->value().isNull(), + "Null constant value not allowed for between condition {} bound", + boundName); + constantBoundValues[boundIndex] = constantExpr->value(); + requestColumnIndices[boundIndex] = kConstantChannel; + } else { + auto fieldAccess = + checkedPointerCast(bound); + requestColumnIndices[boundIndex] = + getRequestColumnIndex(fieldAccess->name(), requestType); + } +} + +} // namespace + +FileIndexReader::FileIndexReader( + std::shared_ptr hiveSplit, + const std::shared_ptr& hiveTableHandle, + const ConnectorQueryCtx* connectorQueryCtx, + const std::shared_ptr& fileConfig, + const std::shared_ptr& scanSpec, + const std::vector& indexLookupConditions, + const RowTypePtr& requestType, + const RowTypePtr& outputType, + const std::shared_ptr& ioStatistics, + const std::shared_ptr& ioStats, + FileHandleFactory* fileHandleFactory, + folly::Executor* ioExecutor, + uint32_t maxRowsPerRequest) + : tableHandle_{hiveTableHandle}, + connectorQueryCtx_{connectorQueryCtx}, + fileConfig_{fileConfig}, + fileHandleFactory_{fileHandleFactory}, + requestType_{requestType}, + outputType_{outputType}, + ioStatistics_{ioStatistics}, + ioStats_{ioStats}, + ioExecutor_{ioExecutor}, + pool_{connectorQueryCtx->memoryPool()}, + scanSpec_{scanSpec}, + indexLookupConditions_{indexLookupConditions}, + maxRowsPerRequest_{maxRowsPerRequest}, + hiveSplit_{std::move(hiveSplit)}, + fileReader_{createFileReader()}, + indexReader_{createIndexReader()} { + parseIndexLookupConditions(); +} + +void FileIndexReader::parseIndexLookupConditions() { + VELOX_CHECK( + !indexLookupConditions_.empty(), + "Index lookup conditions cannot be empty"); + const auto& indexColumns = tableHandle_->indexColumns(); + VELOX_CHECK( + !indexColumns.empty(), "Index columns not set in hive table handle"); + VELOX_CHECK_LE(indexLookupConditions_.size(), indexColumns.size()); + VELOX_CHECK_LE( + indexLookupConditions_.size(), + indexColumns.size(), + "Too many index lookup conditions"); + const auto numIndexLookupConditions = indexLookupConditions_.size(); + requestColumnIndices_.resize(numIndexLookupConditions); + constantBoundValues_.resize(numIndexLookupConditions); + + // For building indexBoundType_ during condition processing. + std::vector indexColumnNames; + indexColumnNames.reserve(numIndexLookupConditions); + std::vector indexColumnTypes; + indexColumnTypes.reserve(numIndexLookupConditions); + + // Validate and process index lookup conditions: + // - Index lookup conditions combined with index filters must cover all table + // index + // columns in order. + // - For each index column, it must have either an index lookup condition OR a + // filter + // in the scan spec, but not both. + // - At least one index lookup condition must have a non-constant value (field + // access + // from probe side). + // - For between conditions, at least one bound must be non-constant. + // - Processing stops when we encounter a range filter or between condition. + size_t numValidIndexLookupConditions{0}; + bool hasNonConstantCondition{false}; + for (size_t i = 0; i < indexLookupConditions_.size(); ++i) { + const auto& indexColumn = indexColumns[i]; + // Process the index lookup condition for this index column. + const auto& condition = indexLookupConditions_[i]; + // Validate that the condition's key column matches the expected index + // column. + const auto& conditionKeyName = + condition->key->asUnchecked()->name(); + VELOX_CHECK_EQ( + conditionKeyName, + indexColumn, + "Index lookup condition key column does not match expected index column at position {}. Index lookup conditions must follow index column order.", + i); + const auto* spec = scanSpec_->childByName(indexColumn); + VELOX_CHECK( + spec == nullptr || !spec->hasFilter(), + "Index column '{}' cannot have both an index lookup condition and a filter at position {}", + indexColumn, + i); + + // Determine the request column index for the condition value. + if (auto equalCondition = + std::dynamic_pointer_cast( + condition)) { + // Check if the value is a constant or a field access. + if (auto constantValue = + std::dynamic_pointer_cast( + equalCondition->value)) { + // Constant value - store in constantBoundValues_ and use + // kConstantChannel. + VELOX_CHECK( + !constantValue->hasValueVector(), + "Complex constant values not supported for equal condition value"); + VELOX_CHECK( + !constantValue->value().isNull(), + "Null constant value not allowed for equal condition value"); + constantBoundValues_[i].resize(1); + constantBoundValues_[i][0] = constantValue->value(); + requestColumnIndices_[i] = {kConstantChannel}; + } else { + // Field access - get request column index. + const auto requestFieldAccess = + checkedPointerCast( + equalCondition->value); + requestColumnIndices_[i] = { + getRequestColumnIndex(requestFieldAccess->name(), requestType_)}; + hasNonConstantCondition = true; + } + // Collect column name and type for indexBoundType_. + indexColumnNames.push_back(indexColumn); + indexColumnTypes.push_back(condition->key->type()); + ++numValidIndexLookupConditions; + continue; + } + + if (auto betweenCondition = + std::dynamic_pointer_cast( + condition)) { + constantBoundValues_[i].resize(2); + requestColumnIndices_[i].resize(2); + + processBetweenBound( + betweenCondition->lower, + 0, + "lower", + constantBoundValues_[i], + requestColumnIndices_[i], + requestType_); + processBetweenBound( + betweenCondition->upper, + 1, + "upper", + constantBoundValues_[i], + requestColumnIndices_[i], + requestType_); + + // Track if this between condition has at least one non-constant bound. + if (requestColumnIndices_[i][0] != kConstantChannel || + requestColumnIndices_[i][1] != kConstantChannel) { + hasNonConstantCondition = true; + } + + // Collect column name and type for indexBoundType_. + indexColumnNames.push_back(indexColumn); + indexColumnTypes.push_back(condition->key->type()); + ++numValidIndexLookupConditions; + // Between condition is a range condition, stop processing further. + break; + } + + VELOX_FAIL( + "Unsupported index lookup condition type: {}", condition->toString()); + } + VELOX_CHECK_EQ(numValidIndexLookupConditions, indexLookupConditions_.size()); + VELOX_CHECK( + hasNonConstantCondition, + "At least one index lookup condition must have a non-constant value"); + + // Build and cache the index bound row type. + indexBoundType_ = + ROW(std::move(indexColumnNames), std::move(indexColumnTypes)); +} + +std::unique_ptr FileIndexReader::createFileReader() { + VELOX_CHECK_NOT_NULL(hiveSplit_); + + dwio::common::ReaderOptions readerOpts(connectorQueryCtx_->memoryPool()); + // TODO: Use separate IoStatistics for data and metadata. + readerOpts.setDataIoStats(ioStatistics_); + readerOpts.setMetadataIoStats(ioStatistics_); + hive::configureReaderOptions( + fileConfig_, + connectorQueryCtx_, + tableHandle_, + hiveSplit_, + hiveSplit_->serdeParameters, + readerOpts); + readerOpts.setScanSpec(scanSpec_); + readerOpts.setFileFormat(hiveSplit_->fileFormat); + VELOX_CHECK_NULL(readerOpts.randomSkip()); + + FileHandleKey fileHandleKey{ + .filename = hiveSplit_->filePath, + .tokenProvider = connectorQueryCtx_->fsTokenProvider()}; + + auto fileProperties = hiveSplit_->properties.value_or(FileProperties{}); + auto fileHandleCachePtr = fileHandleFactory_->generate( + fileHandleKey, &fileProperties, ioStats_ ? ioStats_.get() : nullptr); + VELOX_CHECK_NOT_NULL(fileHandleCachePtr.get()); + + auto baseFileInput = BufferedInputBuilder::getInstance()->create( + *fileHandleCachePtr, + readerOpts, + connectorQueryCtx_, + ioStatistics_, + ioStats_, + ioExecutor_, + /*fileReadOps=*/{}); + + auto reader = dwio::common::getReaderFactory(readerOpts.fileFormat()) + ->createReader(std::move(baseFileInput), readerOpts); + VELOX_CHECK_NOT_NULL(reader); + return reader; +} + +std::unique_ptr +FileIndexReader::createIndexReader() { + VELOX_CHECK_NOT_NULL(fileReader_); + VELOX_CHECK( + hiveSplit_->fileFormat == dwio::common::FileFormat::NIMBLE || + hiveSplit_->fileFormat == dwio::common::FileFormat::FLUX || + hiveSplit_->fileFormat == dwio::common::FileFormat::SST, + "FileIndexReader only supports Nimble, Flux and SST file formats"); + + dwio::common::RowReaderOptions rowReaderOpts; + configureRowReaderOptions( + tableHandle_->tableParameters(), + scanSpec_, + /*metadataFilter=*/nullptr, + outputType_, + hiveSplit_, + hiveSplit_->serdeParameters, + fileConfig_, + connectorQueryCtx_->sessionProperties(), + ioExecutor_, + rowReaderOpts); + if (hiveSplit_->fileFormat != dwio::common::FileFormat::SST) { + rowReaderOpts.setIndexEnabled(true); + // Disable eager first stripe load since FileIndexReader loads stripes + // on-demand based on index lookup results. + rowReaderOpts.setEagerFirstStripeLoad(false); + } + return fileReader_->createIndexReader(rowReaderOpts); +} + +void FileIndexReader::startLookup( + const Request& request, + const Options& options) { + // Empty files have no cluster index, so indexReader_ is null. + if (indexReader_ == nullptr) { + return; + } + VELOX_CHECK( + !indexReader_->hasNext(), + "Previous request not finished. Call next() first."); + VELOX_CHECK_NOT_NULL(request.input); + VELOX_CHECK(requestType_->equivalent(*request.input->type())); + + // Use caller-provided maxRowsPerRequest if set, otherwise fall back to + // the construction-time default. + const auto maxRows = options.maxRowsPerRequest != 0 + ? options.maxRowsPerRequest + : static_cast(maxRowsPerRequest_); + auto indexBounds = buildRequestIndexBounds(request.input); + indexReader_->startLookup(indexBounds, {.maxRowsPerRequest = maxRows}); +} + +serializer::IndexBounds FileIndexReader::buildRequestIndexBounds( + const RowVectorPtr& request) { + VELOX_CHECK_NOT_NULL(request); + VELOX_CHECK_NOT_NULL(indexBoundType_); + + const auto numRows = request->size(); + + // Resize and clear reusable column vectors. + lowerBoundColumns_.resize(indexLookupConditions_.size()); + upperBoundColumns_.resize(indexLookupConditions_.size()); + + for (size_t i = 0; i < indexLookupConditions_.size(); ++i) { + const auto& condition = indexLookupConditions_[i]; + const auto& type = condition->key->type(); + + if (auto equalCondition = + std::dynamic_pointer_cast( + condition)) { + // For equal condition, lower and upper bounds have the same value. + const auto colIdx = requestColumnIndices_[i][0]; + if (colIdx == kConstantChannel) { + auto constVector = BaseVector::createConstant( + type, constantBoundValues_[i][0].value(), numRows, pool_); + lowerBoundColumns_[i] = constVector; + upperBoundColumns_[i] = constVector; + } else { + auto valueVector = request->childAt(colIdx); + lowerBoundColumns_[i] = valueVector; + upperBoundColumns_[i] = valueVector; + } + } else if ( + auto betweenCondition = + std::dynamic_pointer_cast( + condition)) { + // Handle lower bound. + if (constantBoundValues_[i][0].has_value()) { + auto constVector = BaseVector::createConstant( + type, constantBoundValues_[i][0].value(), numRows, pool_); + lowerBoundColumns_[i] = constVector; + } else { + const auto colIdx = requestColumnIndices_[i][0]; + lowerBoundColumns_[i] = request->childAt(colIdx); + } + + // Handle upper bound. + if (constantBoundValues_[i][1].has_value()) { + auto constVector = BaseVector::createConstant( + type, constantBoundValues_[i][1].value(), numRows, pool_); + upperBoundColumns_[i] = constVector; + } else { + const auto colIdx = requestColumnIndices_[i][1]; + upperBoundColumns_[i] = request->childAt(colIdx); + } + } else { + VELOX_FAIL( + "Unsupported index lookup condition type: {}", condition->toString()); + } + } + + // Build RowVectors for lower and upper bounds. + auto lowerBoundVector = std::make_shared( + pool_, indexBoundType_, nullptr, numRows, lowerBoundColumns_); + auto upperBoundVector = std::make_shared( + pool_, indexBoundType_, nullptr, numRows, upperBoundColumns_); + + // Collect column names for indexBounds. + std::vector indexColumnNames; + indexColumnNames.reserve(indexLookupConditions_.size()); + for (const auto& col : indexBoundType_->names()) { + indexColumnNames.push_back(col); + } + + serializer::IndexBounds bounds; + bounds.indexColumns = std::move(indexColumnNames); + bounds.set( + serializer::IndexBound{std::move(lowerBoundVector), /*inclusive=*/true}, + serializer::IndexBound{std::move(upperBoundVector), /*inclusive=*/true}); + return bounds; +} + +bool FileIndexReader::hasNext() { + return indexReader_ != nullptr && indexReader_->hasNext(); +} + +std::unique_ptr FileIndexReader::next( + vector_size_t maxOutputRows) { + VELOX_CHECK_NOT_NULL(indexReader_); + auto result = indexReader_->next(maxOutputRows); + if (result != nullptr) { + numIndexOutputRows_ += result->size(); + } + return result; +} + +std::unordered_map FileIndexReader::runtimeStats() { + std::unordered_map stats; + if (numIndexOutputRows_ != 0) { + stats[std::string(kNumIndexReaderOutputRows)] = RuntimeMetric( + static_cast(numIndexOutputRows_), RuntimeCounter::Unit::kNone); + } + if (indexReader_ != nullptr) { + for (auto& [key, metric] : indexReader_->stats()) { + auto [_, inserted] = stats.emplace(key, metric); + VELOX_CHECK( + inserted, "Duplicate runtime stat '{}' from index reader", key); + } + } + return stats; +} + +std::string FileIndexReader::toString() const { + return fmt::format( + "FileIndexReader: split={} scanSpec={} requestType={} outputType={}", + hiveSplit_->toString(), + scanSpec_->toString(), + requestType_->toString(), + outputType_->toString()); +} + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/FileIndexReader.h b/velox/connectors/hive/FileIndexReader.h new file mode 100644 index 00000000000..1d7d85f3bb9 --- /dev/null +++ b/velox/connectors/hive/FileIndexReader.h @@ -0,0 +1,155 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/common/file/FileSystems.h" +#include "velox/connectors/Connector.h" +#include "velox/connectors/hive/FileHandle.h" +#include "velox/connectors/hive/IndexReader.h" +#include "velox/core/PlanNode.h" +#include "velox/dwio/common/Reader.h" +#include "velox/serializers/KeyEncoder.h" + +namespace facebook::velox::connector { +class ConnectorQueryCtx; +} // namespace facebook::velox::connector + +namespace facebook::velox::connector::hive { + +struct HiveConnectorSplit; +class HiveTableHandle; +class HiveColumnHandle; +class FileConfig; + +/// Handles index lookups for a single Nimble file with cluster indexes. +/// Focuses on: +/// - Creating index bounds from index lookup conditions +/// - Delegating actual index lookups to the format-specific IndexReader +/// +/// Each FileIndexReader operates on exactly one file (split). HiveIndexSource +/// handles multi-split orchestration (union, partition routing) on top. +/// +/// The format-specific IndexReader (e.g., SelectiveNimbleIndexReader) handles: +/// - Encoding keys into format-specific representations +/// - Stripe iteration and row range computation +/// - Data reading and output assembly +class FileIndexReader : public SplitIndexReader { + public: + FileIndexReader( + std::shared_ptr hiveSplit, + const std::shared_ptr& hiveTableHandle, + const ConnectorQueryCtx* connectorQueryCtx, + const std::shared_ptr& fileConfig, + const std::shared_ptr& scanSpec, + const std::vector& indexLookupConditions, + const RowTypePtr& requestType, + const RowTypePtr& outputType, + const std::shared_ptr& ioStats, + const std::shared_ptr& fsStats, + FileHandleFactory* fileHandleFactory, + folly::Executor* ioExecutor, + uint32_t maxRowsPerRequest = 0); + + ~FileIndexReader() override = default; + + /// Sets the input for index lookup. Each row in 'input' will be converted + /// to index bounds and passed to the format-specific IndexReader. + /// + /// @param request The lookup request containing input row vector with lookup + /// keys. + /// @param options Options controlling index reader behavior (e.g., + /// maxRowsPerRequest). Defaults to no limit. + void startLookup(const Request& request, const Options& options = {}) + override; + + /// Returns true if there are more results to fetch from the current lookup. + bool hasNext() override; + + /// Returns the next batch of matching rows for the current input rows. + /// The result from a single request row is never split across multiple + /// calls to next(). + /// + /// @param maxOutputRows Maximum number of output rows to return. + /// @return Result containing matched rows and input hit indices, or nullptr + /// if no more results. + std::unique_ptr next(vector_size_t maxOutputRows) override; + + std::unordered_map runtimeStats() override; + + std::string toString() const; + + private: + // Creates the file reader for reading file metadata and schema. + std::unique_ptr createFileReader(); + + // Creates the format-specific index reader. + std::unique_ptr createIndexReader(); + + // Parses index lookup conditions to extract column indices and constant + // values. + void parseIndexLookupConditions(); + + // Builds IndexBounds from the request row vector. + serializer::IndexBounds buildRequestIndexBounds(const RowVectorPtr& request); + + const std::shared_ptr tableHandle_; + const ConnectorQueryCtx* connectorQueryCtx_; + const std::shared_ptr fileConfig_; + FileHandleFactory* const fileHandleFactory_; + const RowTypePtr requestType_; + const RowTypePtr outputType_; + const std::shared_ptr ioStatistics_; + const std::shared_ptr ioStats_; + folly::Executor* const ioExecutor_; + memory::MemoryPool* const pool_; + const std::shared_ptr scanSpec_; + // Index lookup conditions (including equal conditions converted from lookup + // keys). + const std::vector indexLookupConditions_; + // Maximum output rows per index request batch. 0 means no limit. + const uint32_t maxRowsPerRequest_; + + // Split must be initialized before fileReader_/indexReader_ since they + // depend on it during construction. + std::shared_ptr hiveSplit_; + const std::unique_ptr fileReader_; + const std::unique_ptr indexReader_; + + // Request column indices for each index lookup condition (for probe side + // columns). For EqualIndexLookupCondition, stores {valueIndex}. For + // BetweenIndexLookupCondition, stores {lowerIndex, upperIndex}. + std::vector> requestColumnIndices_; + + // For BetweenIndexLookupCondition with constant bounds, stores the constant + // values directly. The outer vector is indexed by index lookup condition + // index. The inner vector has size 2 for between conditions (lower, upper). + // If a bound is a constant, the corresponding optional contains the value; + // otherwise it's std::nullopt and the value should be decoded from request. + std::vector>> constantBoundValues_; + + // Cached row type for index bounds (column names and types from index lookup + // conditions). + RowTypePtr indexBoundType_; + + // Reusable column vectors for building index bounds. + std::vector lowerBoundColumns_; + std::vector upperBoundColumns_; + + uint64_t numIndexOutputRows_{0}; +}; + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/FileProperties.h b/velox/connectors/hive/FileProperties.h index d3ed9e3cbd6..67c409a04bb 100644 --- a/velox/connectors/hive/FileProperties.h +++ b/velox/connectors/hive/FileProperties.h @@ -14,26 +14,8 @@ * limitations under the License. */ -// A FileHandle is a File pointer plus some (optional, file-type-dependent) -// extra information for speeding up loading columnar data. For example, when -// we open a file we might build a hash map saying what region(s) on disk -// correspond to a given column in a given stripe. -// -// The FileHandle will normally be used in conjunction with a CachedFactory -// to speed up queries that hit the same files repeatedly; see the -// FileHandleCache and FileHandleFactory. - #pragma once -#include - -namespace facebook::velox { - -struct FileProperties { - std::optional fileSize; - std::optional modificationTime; - std::optional readRangeHint{std::nullopt}; - std::shared_ptr extraFileInfo{nullptr}; -}; - -} // namespace facebook::velox +// Moved to velox/common/caching/FileProperties.h. +// This header is kept for backward compatibility. +#include "velox/common/caching/FileProperties.h" diff --git a/velox/connectors/hive/SplitReader.cpp b/velox/connectors/hive/FileSplitReader.cpp similarity index 54% rename from velox/connectors/hive/SplitReader.cpp rename to velox/connectors/hive/FileSplitReader.cpp index 166a1f449b6..01895ead23f 100644 --- a/velox/connectors/hive/SplitReader.cpp +++ b/velox/connectors/hive/FileSplitReader.cpp @@ -14,152 +14,178 @@ * limitations under the License. */ -#include "velox/connectors/hive/SplitReader.h" +#include "velox/connectors/hive/FileSplitReader.h" #include "velox/common/caching/CacheTTLController.h" -#include "velox/connectors/hive/HiveConfig.h" -#include "velox/connectors/hive/HiveConnectorSplit.h" -#include "velox/connectors/hive/HiveConnectorUtil.h" -#include "velox/connectors/hive/TableHandle.h" -#include "velox/connectors/hive/iceberg/IcebergSplitReader.h" +#include "velox/connectors/hive/BufferedInputBuilder.h" +#include "velox/connectors/hive/FileConfig.h" +#include "velox/connectors/hive/FileConnectorSplit.h" +#include "velox/connectors/hive/FileConnectorUtil.h" #include "velox/dwio/common/ReaderFactory.h" +#include "velox/type/DecimalUtil.h" namespace facebook::velox::connector::hive { namespace { template -VectorPtr newConstantFromString( +VectorPtr newConstantFromStringImpl( const TypePtr& type, const std::optional& value, - vector_size_t size, velox::memory::MemoryPool* pool, - const std::string& sessionTimezone, - bool asLocalTime, - bool isPartitionDateDaysSinceEpoch = false) { + bool isLocalTimestamp, + bool isDaysSinceEpoch) { using T = typename TypeTraits::NativeType; if (!value.has_value()) { - return std::make_shared>(pool, size, true, type, T()); + return std::make_shared>(pool, 1, true, type, T()); } if (type->isDate()) { int32_t days = 0; // For Iceberg, the date partition values are already in daysSinceEpoch // form. - if (isPartitionDateDaysSinceEpoch) { + if (isDaysSinceEpoch) { days = folly::to(value.value()); } else { - days = DATE()->toDays(static_cast(value.value())); + days = DATE()->toDays(value.value()); } return std::make_shared>( - pool, size, false, type, std::move(days)); + pool, 1, false, type, std::move(days)); } + if constexpr (std::is_same_v || std::is_same_v) { + if (type->isDecimal()) { + T decimalValue = 0; + auto [precision, scale] = getDecimalPrecisionScale(*type); + auto status = DecimalUtil::castFromString( + StringView(value.value()), precision, scale, decimalValue); + if (!status.ok()) { + VELOX_USER_FAIL(status.message()); + } + return std::make_shared>( + pool, 1, false, type, std::move(decimalValue)); + } + } if constexpr (std::is_same_v) { return std::make_shared>( - pool, size, false, type, StringView(value.value())); + pool, 1, false, type, StringView(value.value())); } else { auto copy = velox::util::Converter::tryCast(value.value()) .thenOrThrow(folly::identity, [&](const Status& status) { VELOX_USER_FAIL("{}", status.message()); }); if constexpr (kind == TypeKind::TIMESTAMP) { - if (asLocalTime) { + if (isLocalTimestamp) { copy.toGMT(Timestamp::defaultTimezone()); } } return std::make_shared>( - pool, size, false, type, std::move(copy)); + pool, 1, false, type, std::move(copy)); } } } // namespace -std::unique_ptr SplitReader::create( - const std::shared_ptr& hiveSplit, - const HiveTableHandlePtr& hiveTableHandle, - const std::unordered_map* partitionKeys, +VectorPtr newConstantFromString( + const TypePtr& type, + const std::optional& value, + velox::memory::MemoryPool* pool, + bool isLocalTimestamp, + bool isDaysSinceEpoch) { + return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH_ALL( + newConstantFromStringImpl, + type->kind(), + type, + value, + pool, + isLocalTimestamp, + isDaysSinceEpoch); +} + +std::unique_ptr FileSplitReader::create( + const std::shared_ptr& fileSplit, + const FileTableHandlePtr& tableHandle, + const std::unordered_map* partitionKeys, const ConnectorQueryCtx* connectorQueryCtx, - const std::shared_ptr& hiveConfig, + const std::shared_ptr& fileConfig, const RowTypePtr& readerOutputType, - const std::shared_ptr& ioStats, - const std::shared_ptr& fsStats, + const std::shared_ptr& dataIoStats, + const std::shared_ptr& metadataIoStats, + const std::shared_ptr& ioStats, FileHandleFactory* fileHandleFactory, folly::Executor* ioExecutor, - const std::shared_ptr& scanSpec) { - // Create the SplitReader based on hiveSplit->customSplitInfo["table_format"] - if (hiveSplit->customSplitInfo.count("table_format") > 0 && - hiveSplit->customSplitInfo["table_format"] == "hive-iceberg") { - return std::make_unique( - hiveSplit, - hiveTableHandle, - partitionKeys, - connectorQueryCtx, - hiveConfig, - readerOutputType, - ioStats, - fsStats, - fileHandleFactory, - ioExecutor, - scanSpec); - } else { - return std::unique_ptr(new SplitReader( - hiveSplit, - hiveTableHandle, - partitionKeys, - connectorQueryCtx, - hiveConfig, - readerOutputType, - ioStats, - fsStats, - fileHandleFactory, - ioExecutor, - scanSpec)); - } + const std::shared_ptr& scanSpec, + const common::SubfieldFilters* subfieldFiltersForValidation) { + return std::unique_ptr(new FileSplitReader( + fileSplit, + tableHandle, + partitionKeys, + connectorQueryCtx, + fileConfig, + readerOutputType, + dataIoStats, + metadataIoStats, + ioStats, + fileHandleFactory, + ioExecutor, + scanSpec, + subfieldFiltersForValidation)); } -SplitReader::SplitReader( - const std::shared_ptr& hiveSplit, - const HiveTableHandlePtr& hiveTableHandle, - const std::unordered_map* partitionKeys, +FileSplitReader::FileSplitReader( + const std::shared_ptr& fileSplit, + const FileTableHandlePtr& tableHandle, + const std::unordered_map* partitionKeys, const ConnectorQueryCtx* connectorQueryCtx, - const std::shared_ptr& hiveConfig, + const std::shared_ptr& fileConfig, const RowTypePtr& readerOutputType, - const std::shared_ptr& ioStats, - const std::shared_ptr& fsStats, + const std::shared_ptr& dataIoStats, + const std::shared_ptr& metadataIoStats, + const std::shared_ptr& ioStats, FileHandleFactory* fileHandleFactory, folly::Executor* ioExecutor, - const std::shared_ptr& scanSpec) - : hiveSplit_(hiveSplit), - hiveTableHandle_(hiveTableHandle), + const std::shared_ptr& scanSpec, + const common::SubfieldFilters* subfieldFiltersForValidation) + : tableHandle_(tableHandle), partitionKeys_(partitionKeys), - connectorQueryCtx_(connectorQueryCtx), - hiveConfig_(hiveConfig), - readerOutputType_(readerOutputType), + fileConfig_(fileConfig), + dataIoStats_(dataIoStats), + metadataIoStats_(metadataIoStats), ioStats_(ioStats), - fsStats_(fsStats), fileHandleFactory_(fileHandleFactory), ioExecutor_(ioExecutor), pool_(connectorQueryCtx->memoryPool()), scanSpec_(scanSpec), + subfieldFiltersForValidation_(subfieldFiltersForValidation), + fileSplit_(fileSplit), + connectorQueryCtx_(connectorQueryCtx), + readerOutputType_(readerOutputType), baseReaderOpts_(connectorQueryCtx->memoryPool()), - emptySplit_(false) {} + emptySplit_(false) { + baseReaderOpts_.setDataIoStats(dataIoStats_); + baseReaderOpts_.setMetadataIoStats(metadataIoStats_); +} -void SplitReader::configureReaderOptions( +void FileSplitReader::configureReaderOptions( std::shared_ptr randomSkip) { + configureBaseReaderOptions(); + baseReaderOpts_.setRandomSkip(std::move(randomSkip)); + baseReaderOpts_.setScanSpec(scanSpec_); + baseReaderOpts_.setFileFormat(fileSplit_->fileFormat); +} + +void FileSplitReader::configureBaseReaderOptions() { hive::configureReaderOptions( - hiveConfig_, + fileConfig_, connectorQueryCtx_, - hiveTableHandle_, - hiveSplit_, + tableHandle_, + fileSplit_, baseReaderOpts_); - baseReaderOpts_.setRandomSkip(std::move(randomSkip)); - baseReaderOpts_.setScanSpec(scanSpec_); - baseReaderOpts_.setFileFormat(hiveSplit_->fileFormat); } -void SplitReader::prepareSplit( +void FileSplitReader::prepareSplit( std::shared_ptr metadataFilter, - dwio::common::RuntimeStatistics& runtimeStats) { - createReader(); + dwio::common::RuntimeStatistics& runtimeStats, + const folly::F14FastMap& fileReadOps) { + createReader(fileReadOps); if (emptySplit_) { return; } @@ -170,77 +196,33 @@ void SplitReader::prepareSplit( return; } - createRowReader(std::move(metadataFilter), std::move(rowType)); + createRowReader(std::move(metadataFilter), std::move(rowType), std::nullopt); } -void SplitReader::setBucketConversion( - std::vector bucketChannels) { - bucketChannels_ = {bucketChannels.begin(), bucketChannels.end()}; - partitionFunction_ = std::make_unique( - hiveSplit_->bucketConversion->tableBucketCount, - std::move(bucketChannels)); -} - -std::vector SplitReader::bucketConversionRows( - const RowVector& vector) { - partitions_.clear(); - partitionFunction_->partition(vector, partitions_); - const auto bucketToKeep = *hiveSplit_->tableBucketNumber; - const auto partitionBucketCount = - hiveSplit_->bucketConversion->partitionBucketCount; - std::vector ranges; - for (vector_size_t i = 0; i < vector.size(); ++i) { - VELOX_CHECK_EQ((partitions_[i] - bucketToKeep) % partitionBucketCount, 0); - if (partitions_[i] == bucketToKeep) { - auto& r = ranges.emplace_back(); - r.sourceIndex = i; - r.targetIndex = ranges.size() - 1; - r.count = 1; - } - } - return ranges; -} - -void SplitReader::applyBucketConversion( - VectorPtr& output, - const std::vector& ranges) { - auto filtered = - BaseVector::create(output->type(), ranges.size(), output->pool()); - filtered->copyRanges(output.get(), ranges); - output = std::move(filtered); -} - -uint64_t SplitReader::next(uint64_t size, VectorPtr& output) { - uint64_t numScanned; +uint64_t FileSplitReader::next(uint64_t size, VectorPtr& output) { if (!baseReaderOpts_.randomSkip()) { - numScanned = baseRowReader_->next(size, output); - } else { - dwio::common::Mutation mutation; - mutation.randomSkip = baseReaderOpts_.randomSkip().get(); - numScanned = baseRowReader_->next(size, output, &mutation); - } - if (numScanned > 0 && output->size() > 0 && partitionFunction_) { - applyBucketConversion( - output, bucketConversionRows(*output->asChecked())); + return baseRowReader_->next(size, output); } - return numScanned; + dwio::common::Mutation mutation; + mutation.randomSkip = baseReaderOpts_.randomSkip().get(); + return baseRowReader_->next(size, output, &mutation); } -void SplitReader::resetFilterCaches() { +void FileSplitReader::resetFilterCaches() { if (baseRowReader_) { baseRowReader_->resetFilterCaches(); } } -bool SplitReader::emptySplit() const { +bool FileSplitReader::emptySplit() const { return emptySplit_; } -void SplitReader::resetSplit() { - hiveSplit_.reset(); +void FileSplitReader::resetSplit() { + fileSplit_.reset(); } -int64_t SplitReader::estimatedRowSize() const { +int64_t FileSplitReader::estimatedRowSize() const { if (!baseRowReader_) { return DataSource::kUnknownRowSize; } @@ -249,31 +231,31 @@ int64_t SplitReader::estimatedRowSize() const { return size.value_or(DataSource::kUnknownRowSize); } -void SplitReader::updateRuntimeStats( +void FileSplitReader::updateRuntimeStats( dwio::common::RuntimeStatistics& stats) const { if (baseRowReader_) { baseRowReader_->updateRuntimeStats(stats); } } -bool SplitReader::allPrefetchIssued() const { +bool FileSplitReader::allPrefetchIssued() const { return baseRowReader_ && baseRowReader_->allPrefetchIssued(); } -void SplitReader::setConnectorQueryCtx( +void FileSplitReader::setConnectorQueryCtx( const ConnectorQueryCtx* connectorQueryCtx) { connectorQueryCtx_ = connectorQueryCtx; } -std::string SplitReader::toString() const { +std::string FileSplitReader::toString() const { std::string partitionKeys; std::for_each( partitionKeys_->begin(), partitionKeys_->end(), [&](const auto& column) { partitionKeys += " " + column.second->toString(); }); return fmt::format( - "SplitReader: hiveSplit_{} scanSpec_{} readerOutputType_{} partitionKeys_{} reader{} rowReader{}", - hiveSplit_->toString(), + "FileSplitReader: fileSplit_{} scanSpec_{} readerOutputType_{} partitionKeys_{} reader{} rowReader{}", + fileSplit_->toString(), scanSpec_->toString(), readerOutputType_->toString(), partitionKeys, @@ -281,23 +263,32 @@ std::string SplitReader::toString() const { static_cast(baseRowReader_.get())); } -void SplitReader::createReader() { +void FileSplitReader::createReader( + const folly::F14FastMap& fileReadOps) { VELOX_CHECK_NE( baseReaderOpts_.fileFormat(), dwio::common::FileFormat::UNKNOWN); FileHandleCachedPtr fileHandleCachePtr; FileHandleKey fileHandleKey{ - .filename = hiveSplit_->filePath, + .filename = fileSplit_->filePath, .tokenProvider = connectorQueryCtx_->fsTokenProvider()}; + + auto fileProperties = fileSplit_->properties.value_or(FileProperties{}); + fileProperties.fileReadOps = fileReadOps; + if (!tableHandle_->dbName().empty()) { + fileProperties.fileReadOps[kDbNameKey] = tableHandle_->dbName(); + } + if (!tableHandle_->name().empty()) { + fileProperties.fileReadOps[kTableNameKey] = tableHandle_->name(); + } + try { fileHandleCachePtr = fileHandleFactory_->generate( - fileHandleKey, - hiveSplit_->properties.has_value() ? &*hiveSplit_->properties : nullptr, - fsStats_ ? fsStats_.get() : nullptr); + fileHandleKey, &fileProperties, ioStats_ ? ioStats_.get() : nullptr); VELOX_CHECK_NOT_NULL(fileHandleCachePtr.get()); } catch (const VeloxRuntimeError& e) { if (e.errorCode() == error_code::kFileNotFound && - hiveConfig_->ignoreMissingFiles( + fileConfig_->ignoreMissingFiles( connectorQueryCtx_->sessionProperties())) { emptySplit_ = true; return; @@ -312,13 +303,14 @@ void SplitReader::createReader() { if (auto* cacheTTLController = cache::CacheTTLController::getInstance()) { cacheTTLController->addOpenFileInfo(fileHandleCachePtr->uuid.id()); } - auto baseFileInput = createBufferedInput( + auto baseFileInput = BufferedInputBuilder::getInstance()->create( *fileHandleCachePtr, baseReaderOpts_, connectorQueryCtx_, + dataIoStats_, ioStats_, - fsStats_, - ioExecutor_); + ioExecutor_, + fileReadOps); baseReader_ = dwio::common::getReaderFactory(baseReaderOpts_.fileFormat()) ->createReader(std::move(baseFileInput), baseReaderOpts_); @@ -327,32 +319,32 @@ void SplitReader::createReader() { } } -RowTypePtr SplitReader::getAdaptedRowType() const { +RowTypePtr FileSplitReader::getAdaptedRowType() const { auto& fileType = baseReader_->rowType(); auto columnTypes = adaptColumns(fileType, baseReaderOpts_.fileSchema()); auto columnNames = fileType->names(); return ROW(std::move(columnNames), std::move(columnTypes)); } -bool SplitReader::filterOnStats( +bool FileSplitReader::filterOnStats( dwio::common::RuntimeStatistics& runtimeStats) const { if (testFilters( scanSpec_.get(), baseReader_.get(), - hiveSplit_->filePath, - hiveSplit_->partitionKeys, + fileSplit_->filePath, + fileSplit_->partitionKeys, *partitionKeys_, - hiveConfig_->readTimestampPartitionValueAsLocalTime( + fileConfig_->readTimestampPartitionValueAsLocalTime( connectorQueryCtx_->sessionProperties()))) { ++runtimeStats.processedSplits; return true; } ++runtimeStats.skippedSplits; - runtimeStats.skippedSplitBytes += hiveSplit_->length; + runtimeStats.skippedSplitBytes += fileSplit_->length; return false; } -bool SplitReader::checkIfSplitIsEmpty( +bool FileSplitReader::checkIfSplitIsEmpty( dwio::common::RuntimeStatistics& runtimeStats) { // emptySplit_ may already be set if the data file is not found. In this case // we don't need to test further. @@ -366,23 +358,42 @@ bool SplitReader::checkIfSplitIsEmpty( return emptySplit_; } -void SplitReader::createRowReader( +void FileSplitReader::createRowReader( std::shared_ptr metadataFilter, - RowTypePtr rowType) { + RowTypePtr rowType, + std::optional rowSizeTrackingEnabled) { VELOX_CHECK_NULL(baseRowReader_); - configureRowReaderOptions( - hiveTableHandle_->tableParameters(), + configureBaseRowReaderOptions(std::move(metadataFilter), std::move(rowType)); + baseRowReaderOpts_.setStringDecoderZeroCopy( + fileConfig_->nimbleStringDecoderZeroCopy( + connectorQueryCtx_->sessionProperties())); + baseRowReaderOpts_.setNimblePreserveDictionaryEncoding( + fileConfig_->nimblePreserveDictionaryEncoding( + connectorQueryCtx_->sessionProperties())); + baseRowReaderOpts_.setTrackRowSize( + rowSizeTrackingEnabled.has_value() + ? *rowSizeTrackingEnabled + : connectorQueryCtx_->rowSizeTrackingMode() != + core::QueryConfig::RowSizeTrackingMode::DISABLED); + baseRowReader_ = baseReader_->createRowReader(baseRowReaderOpts_); +} + +void FileSplitReader::configureBaseRowReaderOptions( + std::shared_ptr metadataFilter, + RowTypePtr rowType) { + hive::configureRowReaderOptions( + tableHandle_->tableParameters(), scanSpec_, std::move(metadataFilter), std::move(rowType), - hiveSplit_, - hiveConfig_, + fileSplit_, + fileConfig_, connectorQueryCtx_->sessionProperties(), + ioExecutor_, baseRowReaderOpts_); - baseRowReader_ = baseReader_->createRowReader(baseRowReaderOpts_); } -std::vector SplitReader::adaptColumns( +std::vector FileSplitReader::adaptColumns( const RowTypePtr& fileType, const std::shared_ptr& tableSchema) const { // Keep track of schema types for columns in file, used by ColumnSelector. @@ -393,34 +404,20 @@ std::vector SplitReader::adaptColumns( auto* childSpec = childrenSpecs[i].get(); const std::string& fieldName = childSpec->fieldName(); - if (auto it = hiveSplit_->partitionKeys.find(fieldName); - it != hiveSplit_->partitionKeys.end()) { - setPartitionValue(childSpec, fieldName, it->second); - } else if (auto iter = hiveSplit_->infoColumns.find(fieldName); - iter != hiveSplit_->infoColumns.end()) { - auto infoColumnType = - readerOutputType_->childAt(readerOutputType_->getChildIdx(fieldName)); - auto constant = VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH_ALL( - newConstantFromString, - infoColumnType->kind(), - infoColumnType, - iter->second, - 1, - connectorQueryCtx_->memoryPool(), - connectorQueryCtx_->sessionTimezone(), - hiveConfig_->readTimestampPartitionValueAsLocalTime( - connectorQueryCtx_->sessionProperties())); - childSpec->setConstantValue(constant); + if (auto partitionIt = fileSplit_->partitionKeys.find(fieldName); + partitionIt != fileSplit_->partitionKeys.end()) { + setPartitionValue(childSpec, fieldName, partitionIt->second); } else if ( childSpec->columnType() == common::ScanSpec::ColumnType::kRegular) { auto fileTypeIdx = fileType->getChildIdxIfExists(fieldName); if (!fileTypeIdx.has_value()) { // Column is missing. Most likely due to schema evolution. VELOX_CHECK(tableSchema, "Unable to resolve column '{}'", fieldName); - childSpec->setConstantValue(BaseVector::createNullConstant( - tableSchema->findChild(fieldName), - 1, - connectorQueryCtx_->memoryPool())); + childSpec->setConstantValue( + BaseVector::createNullConstant( + tableSchema->findChild(fieldName), + 1, + connectorQueryCtx_->memoryPool())); } else { // Column no longer missing, reset constant value set on the spec. childSpec->setConstantValue(nullptr); @@ -446,25 +443,21 @@ std::vector SplitReader::adaptColumns( return columnTypes; } -void SplitReader::setPartitionValue( +void FileSplitReader::setPartitionValue( common::ScanSpec* spec, const std::string& partitionKey, const std::optional& value) const { - auto it = partitionKeys_->find(partitionKey); + const auto it = partitionKeys_->find(partitionKey); VELOX_CHECK( it != partitionKeys_->end(), "ColumnHandle is missing for partition key {}", partitionKey); - auto type = it->second->dataType(); - auto constant = VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH_ALL( - newConstantFromString, - type->kind(), + const auto type = it->second->dataType(); + const auto constant = newConstantFromString( type, value, - 1, connectorQueryCtx_->memoryPool(), - connectorQueryCtx_->sessionTimezone(), - hiveConfig_->readTimestampPartitionValueAsLocalTime( + fileConfig_->readTimestampPartitionValueAsLocalTime( connectorQueryCtx_->sessionProperties()), it->second->isPartitionDateValueDaysSinceEpoch()); spec->setConstantValue(constant); diff --git a/velox/connectors/hive/FileSplitReader.h b/velox/connectors/hive/FileSplitReader.h new file mode 100644 index 00000000000..04e59016218 --- /dev/null +++ b/velox/connectors/hive/FileSplitReader.h @@ -0,0 +1,245 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/common/base/RandomUtil.h" +#include "velox/common/file/FileSystems.h" +#include "velox/connectors/hive/FileColumnHandle.h" +#include "velox/connectors/hive/FileConnectorSplit.h" +#include "velox/connectors/hive/FileHandle.h" +#include "velox/connectors/hive/FileTableHandle.h" +#include "velox/dwio/common/Options.h" +#include "velox/dwio/common/Reader.h" + +namespace facebook::velox { +class BaseVector; +using VectorPtr = std::shared_ptr; +} // namespace facebook::velox + +namespace facebook::velox::common { +class MetadataFilter; +class ScanSpec; +} // namespace facebook::velox::common + +namespace facebook::velox::connector { +class ConnectorQueryCtx; +} // namespace facebook::velox::connector + +namespace facebook::velox::dwio::common { +struct RuntimeStatistics; +} // namespace facebook::velox::dwio::common + +namespace facebook::velox::memory { +class MemoryPool; +} + +namespace facebook::velox::connector::hive { + +/// Creates a constant vector of size 1 from a string representation of a value. +/// +/// Used to materialize partition column values and info columns (e.g., $path, +/// $file_size) when reading Hive and Iceberg tables. Partition values are +/// stored as strings in HiveConnectorSplit::partitionKeys and need to be +/// converted to their appropriate types. +/// +/// @param type The target Velox type for the constant vector. Supports all +/// scalar types including primitives, dates, timestamps. +/// @param value The string representation of the value to convert, formatted +/// the same way as CAST(x as VARCHAR). Date values must be formatted using ISO +/// 8601 as YYYY-MM-DD. If nullopt, creates a null constant vector. +/// @param pool Memory pool for allocating the constant vector. +/// @param isLocalTimestamp If true and type is TIMESTAMP, interprets the string +/// value as local time and converts it to GMT. If false, treats the value +/// as already in GMT. +/// @param isDaysSinceEpoch If true and type is DATE, treats the string value as +/// an integer representing days since epoch (used by Iceberg). If false, parses +/// the string as a date string in ISO 8601 format (used by Hive). +/// +/// @return A constant vector of size 1 containing the converted value, or a +/// null constant if value is nullopt. +/// @throws VeloxUserError if the string cannot be converted to the target type. +VectorPtr newConstantFromString( + const TypePtr& type, + const std::optional& value, + velox::memory::MemoryPool* pool, + bool isLocalTimestamp, + bool isDaysSinceEpoch); + +class FileConfig; + +class FileSplitReader { + public: + static std::unique_ptr create( + const std::shared_ptr& fileSplit, + const FileTableHandlePtr& tableHandle, + const std::unordered_map* partitionKeys, + const ConnectorQueryCtx* connectorQueryCtx, + const std::shared_ptr& fileConfig, + const RowTypePtr& readerOutputType, + const std::shared_ptr& dataIoStats, + const std::shared_ptr& metadataIoStats, + const std::shared_ptr& ioStats, + FileHandleFactory* fileHandleFactory, + folly::Executor* ioExecutor, + const std::shared_ptr& scanSpec, + const common::SubfieldFilters* subfieldFiltersForValidation = nullptr); + + virtual ~FileSplitReader() = default; + + void configureReaderOptions( + std::shared_ptr randomSkip); + + /// This function is used by different table formats like Iceberg and Hudi to + /// do additional preparations before reading the split, e.g. Open delete + /// files or log files, and add column adapatations for metadata columns. It + /// would be called only once per incoming split + virtual void prepareSplit( + std::shared_ptr metadataFilter, + dwio::common::RuntimeStatistics& runtimeStats, + const folly::F14FastMap& fileReadOps = {}); + + virtual uint64_t next(uint64_t size, VectorPtr& output); + + void resetFilterCaches(); + + bool emptySplit() const; + + void resetSplit(); + + int64_t estimatedRowSize() const; + + void updateRuntimeStats(dwio::common::RuntimeStatistics& stats) const; + + bool allPrefetchIssued() const; + + void setConnectorQueryCtx(const ConnectorQueryCtx* connectorQueryCtx); + + const RowTypePtr& readerOutputType() const { + return readerOutputType_; + } + + std::string toString() const; + + protected: + FileSplitReader( + const std::shared_ptr& fileSplit, + const FileTableHandlePtr& tableHandle, + const std::unordered_map* partitionKeys, + const ConnectorQueryCtx* connectorQueryCtx, + const std::shared_ptr& fileConfig, + const RowTypePtr& readerOutputType, + const std::shared_ptr& dataIoStats, + const std::shared_ptr& metadataIoStats, + const std::shared_ptr& ioStats, + FileHandleFactory* fileHandleFactory, + folly::Executor* executor, + const std::shared_ptr& scanSpec, + const common::SubfieldFilters* subfieldFiltersForValidation = nullptr); + + /// Create the dwio::common::Reader object baseReader_ + /// read the data file's metadata and schema + void createReader( + const folly::F14FastMap& fileReadOps = {}); + + // Adjust the scan spec according to the current split, then return the + // adapted row type. + RowTypePtr getAdaptedRowType() const; + + // Check if the filters pass on the column statistics. When delta update is + // present, the corresonding filter should be disabled before calling this + // function. + bool filterOnStats(dwio::common::RuntimeStatistics& runtimeStats) const; + + /// Check if the fileSplit_ is empty. The split is considered empty when + /// 1) The data file is missing but the user chooses to ignore it + /// 2) The file does not contain any rows + /// 3) The data in the file does not pass the filters. The test is based on + /// the file metadata and partition key values + /// This function needs to be called after baseReader_ is created. + bool checkIfSplitIsEmpty(dwio::common::RuntimeStatistics& runtimeStats); + + /// Create the dwio::common::RowReader object baseRowReader_, which owns the + /// ColumnReaders that will be used to read the data + void createRowReader( + std::shared_ptr metadataFilter, + RowTypePtr rowType, + std::optional rowSizeTrackingEnabled); + + /// Sets a constant partition value on the scanSpec for a partition column. + /// Converts the partition key string value to the appropriate type and sets + /// it as a constant value in the scanSpec, so the column will be filled + /// with this constant value. + /// + /// @param spec The scan spec to set the constant value on. + /// @param partitionKey The name of the partition column. + void setPartitionValue( + common::ScanSpec* spec, + const std::string& partitionKey, + const std::optional& value) const; + + /// Virtual hook called by configureReaderOptions() to set format-specific + /// reader options on baseReaderOpts_. The base implementation calls the + /// generic configureReaderOptions() from FileConnectorUtil. Subclasses + /// (e.g., HiveSplitReader) override to call the Hive-specific version + /// that also applies serde options. + virtual void configureBaseReaderOptions(); + + /// Virtual hook called by createRowReader() to set format-specific row + /// reader options on baseRowReaderOpts_. The base implementation calls the + /// generic configureRowReaderOptions() from FileConnectorUtil. Subclasses + /// (e.g., HiveSplitReader) override to call the Hive-specific version + /// that also applies serde parameters. + virtual void configureBaseRowReaderOptions( + std::shared_ptr metadataFilter, + RowTypePtr rowType); + + private: + /// Different table formats may have different meatadata columns. + /// This function will be used to update the scanSpec for these columns. + virtual std::vector adaptColumns( + const RowTypePtr& fileType, + const RowTypePtr& tableSchema) const; + + protected: + const FileTableHandlePtr tableHandle_; + const std::unordered_map* const + partitionKeys_; + const std::shared_ptr fileConfig_; + const std::shared_ptr dataIoStats_; + const std::shared_ptr metadataIoStats_; + const std::shared_ptr ioStats_; + FileHandleFactory* const fileHandleFactory_; + folly::Executor* const ioExecutor_; + memory::MemoryPool* const pool_; + const std::shared_ptr scanSpec_; + // Subfield filters from HiveDataSource, includes both original + // subfieldFilters and filters extracted from remainingFilter. Used by + // subclasses (e.g., HiveSplitReader) for synthesized column filter + // validation. + const common::SubfieldFilters* const subfieldFiltersForValidation_; + + std::shared_ptr fileSplit_; + const ConnectorQueryCtx* connectorQueryCtx_; + RowTypePtr readerOutputType_; + std::unique_ptr baseReader_; + std::unique_ptr baseRowReader_; + dwio::common::ReaderOptions baseReaderOpts_; + dwio::common::RowReaderOptions baseRowReaderOpts_; + bool emptySplit_; +}; + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/FileTableHandle.h b/velox/connectors/hive/FileTableHandle.h new file mode 100644 index 00000000000..fb5e1ac656b --- /dev/null +++ b/velox/connectors/hive/FileTableHandle.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/connectors/Connector.h" +#include "velox/connectors/hive/FileColumnHandle.h" +#include "velox/core/ITypedExpr.h" +#include "velox/type/Filter.h" +#include "velox/type/Subfield.h" + +namespace facebook::velox::connector::hive { + +/// Base class for table handles in file-based connectors. +/// +/// Define the common interface for table-level metadata needed by +/// FileDataSource and FileSplitReader. Connector-specific table handles +/// (HiveTableHandle, etc.) extend this class. +class FileTableHandle : public ConnectorTableHandle { + public: + using ConnectorTableHandle::ConnectorTableHandle; + + /// Single-field filters that can be applied efficiently during file reading. + virtual const common::SubfieldFilters& subfieldFilters() const = 0; + + /// Remaining filter expression that cannot be converted into subfield + /// filters. Usually less efficient but supports arbitrary boolean + /// expressions. + virtual const core::TypedExprPtr& remainingFilter() const = 0; + + /// Sampling rate between 0 and 1 (excluding 0). 0.1 means 10% sampling. + /// 1.0 means no sampling. + virtual double sampleRate() const = 0; + + /// Subset of schema stored in data files (non-partitioning columns). + /// Needed for reading TEXTFILE, handling schema evolution, etc. + virtual const RowTypePtr& dataColumns() const = 0; + + /// Extra parameters passed down to the file format reader layer. + virtual const std::unordered_map& tableParameters() + const = 0; + + /// Extra columns used in filters but not in the output. Returns column + /// handles as FileColumnHandlePtr for use in the generic scan pipeline. + virtual std::vector filterColumnHandles() const = 0; + + /// Database or namespace name for this table. + virtual const std::string& dbName() const = 0; +}; + +using FileTableHandlePtr = std::shared_ptr; + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/HiveConfig.cpp b/velox/connectors/hive/HiveConfig.cpp index 3463ec767c6..b05ffa1281d 100644 --- a/velox/connectors/hive/HiveConfig.cpp +++ b/velox/connectors/hive/HiveConfig.cpp @@ -15,10 +15,9 @@ */ #include "velox/connectors/hive/HiveConfig.h" -#include "velox/common/config/Config.h" -#include "velox/core/QueryConfig.h" #include +#include "velox/common/config/Config.h" namespace facebook::velox::connector::hive { @@ -39,6 +38,34 @@ stringToInsertExistingPartitionsBehavior(const std::string& strValue) { } // namespace +const std::vector& HiveConfig::registeredProperties() { + static const std::vector kProperties = [] { + // Start with FileConfig properties. + auto properties = FileConfig::registeredProperties(); +#define VELOX_HIVE_CONFIG_REGISTER(constName) \ + config::registerConfigProperty(properties) + + VELOX_HIVE_CONFIG_REGISTER(kInsertExistingPartitionsBehaviorSession); + VELOX_HIVE_CONFIG_REGISTER(kSortWriterMaxOutputBytesSession); + VELOX_HIVE_CONFIG_REGISTER(kMaxTargetFileSizeSession); + VELOX_HIVE_CONFIG_REGISTER(kMaxPartitionsPerWritersSession); + VELOX_HIVE_CONFIG_REGISTER(kAllowNullPartitionKeysSession); + VELOX_HIVE_CONFIG_REGISTER(kPartitionPathAsLowerCaseSession); + VELOX_HIVE_CONFIG_REGISTER(kSortWriterMaxOutputRowsSession); + VELOX_HIVE_CONFIG_REGISTER(kSortWriterFinishTimeSliceLimitMsSession); + VELOX_HIVE_CONFIG_REGISTER(kUser); + VELOX_HIVE_CONFIG_REGISTER(kSource); + VELOX_HIVE_CONFIG_REGISTER(kSchema); + VELOX_HIVE_CONFIG_REGISTER(kMaxBucketCountSession); + VELOX_HIVE_CONFIG_REGISTER(kMaxRowsPerIndexRequestSession); + +#undef VELOX_HIVE_CONFIG_REGISTER + + return properties; + }(); + return kProperties; +} + // static std::string HiveConfig::insertExistingPartitionsBehaviorString( InsertExistingPartitionsBehavior behavior) { @@ -60,108 +87,39 @@ HiveConfig::insertExistingPartitionsBehavior( config_->get(kInsertExistingPartitionsBehavior, "ERROR"))); } -uint32_t HiveConfig::maxPartitionsPerWriters( - const config::ConfigBase* session) const { - return session->get( - kMaxPartitionsPerWritersSession, - config_->get(kMaxPartitionsPerWriters, 128)); -} - -uint32_t HiveConfig::maxBucketCount(const config::ConfigBase* session) const { - return session->get( - kMaxBucketCountSession, config_->get(kMaxBucketCount, 100'000)); -} - bool HiveConfig::immutablePartitions() const { - return config_->get(kImmutablePartitions, false); + return configValue(kImmutablePartitions, false); } std::string HiveConfig::gcsEndpoint() const { - return config_->get(kGcsEndpoint, std::string("")); + return configValue(kGcsEndpoint, std::string("")); } std::string HiveConfig::gcsCredentialsPath() const { - return config_->get(kGcsCredentialsPath, std::string("")); + return configValue(kGcsCredentialsPath, std::string("")); } std::optional HiveConfig::gcsMaxRetryCount() const { - return static_cast>(config_->get(kGcsMaxRetryCount)); + if (auto val = config_->get(kGcsMaxRetryCount)) { + return val; + } + return config_->get(std::string(kLegacyPrefix) + kGcsMaxRetryCount); } std::optional HiveConfig::gcsMaxRetryTime() const { - return static_cast>( - config_->get(kGcsMaxRetryTime)); + if (auto val = config_->get(kGcsMaxRetryTime)) { + return val; + } + return config_->get( + std::string(kLegacyPrefix) + kGcsMaxRetryTime); } std::optional HiveConfig::gcsAuthAccessTokenProvider() const { - return static_cast>( - config_->get(kGcsAuthAccessTokenProvider)); -} - -bool HiveConfig::isOrcUseColumnNames(const config::ConfigBase* session) const { - return session->get( - kOrcUseColumnNamesSession, config_->get(kOrcUseColumnNames, false)); -} - -bool HiveConfig::isParquetUseColumnNames( - const config::ConfigBase* session) const { - return session->get( - kParquetUseColumnNamesSession, - config_->get(kParquetUseColumnNames, false)); -} - -bool HiveConfig::isFileColumnNamesReadAsLowerCase( - const config::ConfigBase* session) const { - return session->get( - kFileColumnNamesReadAsLowerCaseSession, - config_->get(kFileColumnNamesReadAsLowerCase, false)); -} - -bool HiveConfig::isPartitionPathAsLowerCase( - const config::ConfigBase* session) const { - return session->get(kPartitionPathAsLowerCaseSession, true); -} - -bool HiveConfig::allowNullPartitionKeys( - const config::ConfigBase* session) const { - return session->get( - kAllowNullPartitionKeysSession, - config_->get(kAllowNullPartitionKeys, true)); -} - -bool HiveConfig::ignoreMissingFiles(const config::ConfigBase* session) const { - return session->get(kIgnoreMissingFilesSession, false); -} - -int64_t HiveConfig::maxCoalescedBytes(const config::ConfigBase* session) const { - return session->get( - kMaxCoalescedBytesSession, - config_->get(kMaxCoalescedBytes, 128 << 20)); // 128MB -} - -int32_t HiveConfig::maxCoalescedDistanceBytes( - const config::ConfigBase* session) const { - const auto distance = config::toCapacity( - session->get( - kMaxCoalescedDistanceSession, - config_->get(kMaxCoalescedDistance, "512kB")), - config::CapacityUnit::BYTE); - VELOX_USER_CHECK_LE( - distance, - std::numeric_limits::max(), - "The max merge distance to combine read requests must be less than 2GB." - " Got {} bytes.", - distance); - return int32_t(distance); -} - -int32_t HiveConfig::prefetchRowGroups() const { - return config_->get(kPrefetchRowGroups, 1); -} - -int32_t HiveConfig::loadQuantum(const config::ConfigBase* session) const { - return session->get( - kLoadQuantumSession, config_->get(kLoadQuantum, 8 << 20)); + if (auto val = config_->get(kGcsAuthAccessTokenProvider)) { + return val; + } + return config_->get( + std::string(kLegacyPrefix) + kGcsAuthAccessTokenProvider); } int32_t HiveConfig::numCacheFileHandles() const { @@ -177,14 +135,11 @@ bool HiveConfig::isFileHandleCacheEnabled() const { } std::string HiveConfig::writeFileCreateConfig() const { - return config_->get(kWriteFileCreateConfig, ""); -} - -uint32_t HiveConfig::sortWriterMaxOutputRows( - const config::ConfigBase* session) const { - return session->get( - kSortWriterMaxOutputRowsSession, - config_->get(kSortWriterMaxOutputRows, 1024)); + // Legacy key used snake_case: "hive.write_file_create_config". + if (auto val = config_->get(kWriteFileCreateConfig)) { + return val.value(); + } + return config_->get("hive.write_file_create_config", ""); } uint64_t HiveConfig::sortWriterMaxOutputBytes( @@ -196,58 +151,13 @@ uint64_t HiveConfig::sortWriterMaxOutputBytes( config::CapacityUnit::BYTE); } -uint64_t HiveConfig::sortWriterFinishTimeSliceLimitMs( - const config::ConfigBase* session) const { - return session->get( - kSortWriterFinishTimeSliceLimitMsSession, - config_->get(kSortWriterFinishTimeSliceLimitMs, 5'000)); -} - -uint64_t HiveConfig::footerEstimatedSize() const { - return config_->get(kFooterEstimatedSize, 256UL << 10); -} - -uint64_t HiveConfig::filePreloadThreshold() const { - return config_->get(kFilePreloadThreshold, 8UL << 20); -} - -uint8_t HiveConfig::readTimestampUnit(const config::ConfigBase* session) const { - const auto unit = session->get( - kReadTimestampUnitSession, - config_->get(kReadTimestampUnit, 3 /*milli*/)); - VELOX_CHECK( - unit == 3 || unit == 6 /*micro*/ || unit == 9 /*nano*/, - "Invalid timestamp unit."); - return unit; -} - -bool HiveConfig::readTimestampPartitionValueAsLocalTime( - const config::ConfigBase* session) const { - return session->get( - kReadTimestampPartitionValueAsLocalTimeSession, - config_->get(kReadTimestampPartitionValueAsLocalTime, true)); -} - -bool HiveConfig::readStatsBasedFilterReorderDisabled( - const config::ConfigBase* session) const { - return session->get( - kReadStatsBasedFilterReorderDisabledSession, - config_->get(kReadStatsBasedFilterReorderDisabled, false)); -} - -std::string HiveConfig::hiveLocalDataPath() const { - return config_->get(kLocalDataPath, ""); -} - -std::string HiveConfig::hiveLocalFileFormat() const { - return config_->get(kLocalFileFormat, ""); -} - -bool HiveConfig::preserveFlatMapsInMemory( +uint64_t HiveConfig::maxTargetFileSizeBytes( const config::ConfigBase* session) const { - return session->get( - kPreserveFlatMapsInMemorySession, - config_->get(kPreserveFlatMapsInMemory, false)); + return config::toCapacity( + session->get( + kMaxTargetFileSizeSession, + config_->get(kMaxTargetFileSize, "0B")), + config::CapacityUnit::BYTE); } } // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/HiveConfig.h b/velox/connectors/hive/HiveConfig.h index 6cf3d911af1..6835bcd3d16 100644 --- a/velox/connectors/hive/HiveConfig.h +++ b/velox/connectors/hive/HiveConfig.h @@ -15,19 +15,20 @@ */ #pragma once -#include -#include -#include "velox/common/base/Exceptions.h" - -namespace facebook::velox::config { -class ConfigBase; -} +#include "velox/connectors/hive/FileConfig.h" +#include "velox/connectors/hive/HiveConfigMacrosDefine.h" namespace facebook::velox::connector::hive { -/// Hive connector configs. -class HiveConfig { +/// Hive connector configs. Extends FileConfig with Hive-specific settings +/// for partitioning, bucketing, write-path options, GCS storage, and other +/// Hive table format concerns. +class HiveConfig : public FileConfig { public: + /// Returns all registered session-overridable properties from both + /// FileConfig and HiveConfig. + static const std::vector& registeredProperties(); + enum class InsertExistingPartitionsBehavior { kError, kOverwrite, @@ -36,93 +37,147 @@ class HiveConfig { static std::string insertExistingPartitionsBehaviorString( InsertExistingPartitionsBehavior behavior); - /// Behavior on insert into existing partitions. - static constexpr const char* kInsertExistingPartitionsBehaviorSession = - "insert_existing_partitions_behavior"; + // --- VELOX_HIVE_CONFIG_PROPERTY properties --- + + VELOX_HIVE_CONFIG_PROPERTY( + kInsertExistingPartitionsBehaviorSession, + "insert_existing_partitions_behavior", + std::string, + "ERROR", + "Behavior when inserting into existing partitions: ERROR, OVERWRITE.") static constexpr const char* kInsertExistingPartitionsBehavior = "insert-existing-partitions-behavior"; - /// Maximum number of (bucketed) partitions per a single table writer - /// instance. + VELOX_HIVE_CONFIG_PROPERTY( + kSortWriterMaxOutputBytesSession, + "sort_writer_max_output_bytes", + std::string, + "10MB", + "Maximum output bytes for sort writer.") + static constexpr const char* kSortWriterMaxOutputBytes = + "sort-writer-max-output-bytes"; + + VELOX_HIVE_CONFIG_PROPERTY( + kMaxTargetFileSizeSession, + "max_target_file_size", + std::string, + "0B", + "Maximum target file size. 0 means no limit.") + static constexpr const char* kMaxTargetFileSize = "max-target-file-size"; + + // --- VELOX_HIVE_CONFIG properties --- + + VELOX_HIVE_CONFIG( + kMaxPartitionsPerWritersSession, + maxPartitionsPerWriters, + "max_partitions_per_writers", + uint32_t, + 128, + "Maximum partitions per writer.") static constexpr const char* kMaxPartitionsPerWriters = "max-partitions-per-writers"; - static constexpr const char* kMaxPartitionsPerWritersSession = - "max_partitions_per_writers"; - /// Maximum number of buckets allowed to output by the table writers. - static constexpr const char* kMaxBucketCount = "hive.max-bucket-count"; - static constexpr const char* kMaxBucketCountSession = "hive.max_bucket_count"; + VELOX_HIVE_CONFIG( + kAllowNullPartitionKeysSession, + allowNullPartitionKeys, + "allow_null_partition_keys", + bool, + true, + "Allow null partition keys.") + static constexpr const char* kAllowNullPartitionKeys = + "allow-null-partition-keys"; + + VELOX_HIVE_CONFIG_PROPERTY( + kPartitionPathAsLowerCaseSession, + "partition_path_as_lower_case", + bool, + true, + "Write partition path segments as lower case.") + bool isPartitionPathAsLowerCase(const config::ConfigBase* session) const { + return session->get( + kPartitionPathAsLowerCaseSession, + kPartitionPathAsLowerCaseSessionProperty::defaultValue); + } + + VELOX_HIVE_CONFIG( + kSortWriterMaxOutputRowsSession, + sortWriterMaxOutputRows, + "sort_writer_max_output_rows", + uint32_t, + 1024, + "Maximum output rows for sort writer.") + static constexpr const char* kSortWriterMaxOutputRows = + "sort-writer-max-output-rows"; + + VELOX_HIVE_CONFIG( + kSortWriterFinishTimeSliceLimitMsSession, + sortWriterFinishTimeSliceLimitMs, + "sort_writer_finish_time_slice_limit_ms", + uint64_t, + 5000, + "Time slice limit in ms for sort writer finish. 0 means no limit.") + static constexpr const char* kSortWriterFinishTimeSliceLimitMs = + "sort-writer-finish-time-slice-limit-ms"; + + VELOX_HIVE_CONFIG(kUser, user, "user", std::string, "", "User of the query.") + + VELOX_HIVE_CONFIG( + kSource, + source, + "source", + std::string, + "", + "Source of the query.") + + VELOX_HIVE_CONFIG( + kSchema, + schema, + "schema", + std::string, + "", + "Schema of the query.") + + // --- VELOX_HIVE_CONFIG_LEGACY properties --- + + VELOX_HIVE_CONFIG_LEGACY( + kMaxBucketCountSession, + kMaxBucketCount, + maxBucketCount, + "max_bucket_count", + "max-bucket-count", + uint32_t, + 100'000, + "Maximum bucket count.") + + VELOX_HIVE_CONFIG_LEGACY( + kMaxRowsPerIndexRequestSession, + kMaxRowsPerIndexRequest, + maxRowsPerIndexRequest, + "max_rows_per_index_request", + "max-rows-per-index-request", + uint32_t, + 0, + "Maximum rows per index lookup request. 0 means no limit.") + // --- Server-only properties (no macro) --- /// Whether new data can be inserted into an unpartition table. /// Velox currently does not support appending data to existing partitions. - static constexpr const char* kImmutablePartitions = - "hive.immutable-partitions"; + static constexpr const char* kImmutablePartitions = "immutable-partitions"; /// The GCS storage endpoint server. - static constexpr const char* kGcsEndpoint = "hive.gcs.endpoint"; + static constexpr const char* kGcsEndpoint = "gcs.endpoint"; /// The GCS service account configuration JSON key file. - static constexpr const char* kGcsCredentialsPath = - "hive.gcs.json-key-file-path"; + static constexpr const char* kGcsCredentialsPath = "gcs.json-key-file-path"; /// The GCS maximum retry counter of transient errors. - static constexpr const char* kGcsMaxRetryCount = "hive.gcs.max-retry-count"; + static constexpr const char* kGcsMaxRetryCount = "gcs.max-retry-count"; /// The GCS maximum time allowed to retry transient errors. - static constexpr const char* kGcsMaxRetryTime = "hive.gcs.max-retry-time"; + static constexpr const char* kGcsMaxRetryTime = "gcs.max-retry-time"; static constexpr const char* kGcsAuthAccessTokenProvider = - "hive.gcs.auth.access-token-provider"; - - /// Maps table field names to file field names using names, not indices. - // TODO: remove hive_orc_use_column_names since it doesn't exist in presto, - // right now this is only used for testing. - static constexpr const char* kOrcUseColumnNames = "hive.orc.use-column-names"; - static constexpr const char* kOrcUseColumnNamesSession = - "hive_orc_use_column_names"; - - /// Maps table field names to file field names using names, not indices. - static constexpr const char* kParquetUseColumnNames = - "hive.parquet.use-column-names"; - static constexpr const char* kParquetUseColumnNamesSession = - "parquet_use_column_names"; - - /// Reads the source file column name as lower case. - static constexpr const char* kFileColumnNamesReadAsLowerCase = - "file-column-names-read-as-lower-case"; - static constexpr const char* kFileColumnNamesReadAsLowerCaseSession = - "file_column_names_read_as_lower_case"; - - static constexpr const char* kPartitionPathAsLowerCaseSession = - "partition_path_as_lower_case"; - - static constexpr const char* kAllowNullPartitionKeys = - "allow-null-partition-keys"; - static constexpr const char* kAllowNullPartitionKeysSession = - "allow_null_partition_keys"; - - static constexpr const char* kIgnoreMissingFilesSession = - "ignore_missing_files"; - - /// The max coalesce bytes for a request. - static constexpr const char* kMaxCoalescedBytes = "max-coalesced-bytes"; - static constexpr const char* kMaxCoalescedBytesSession = - "max-coalesced-bytes"; - - /// The max merge distance to combine read requests. - /// Note: The session property name differs from the constant name for - /// backward compatibility with Presto. - static constexpr const char* kMaxCoalescedDistance = "max-coalesced-distance"; - static constexpr const char* kMaxCoalescedDistanceSession = - "orc_max_merge_distance"; - - /// The number of prefetch rowgroups - static constexpr const char* kPrefetchRowGroups = "prefetch-rowgroups"; - - /// The total size in bytes for a direct coalesce request. Up to 8MB load - /// quantum size is supported when SSD cache is enabled. - static constexpr const char* kLoadQuantum = "load-quantum"; - static constexpr const char* kLoadQuantumSession = "load-quantum"; + "gcs.auth.access-token-provider"; /// Maximum number of entries in the file handle cache. static constexpr const char* kNumCacheFileHandles = "num_cached_file_handles"; @@ -137,73 +192,15 @@ class HiveConfig { static constexpr const char* kEnableFileHandleCache = "file-handle-cache-enabled"; - /// The size in bytes to be fetched with Meta data together, used when the - /// data after meta data will be used later. Optimization to decrease small IO - /// request - static constexpr const char* kFooterEstimatedSize = "footer-estimated-size"; - - /// The threshold of file size in bytes when the whole file is fetched with - /// meta data together. Optimization to decrease the small IO requests - static constexpr const char* kFilePreloadThreshold = "file-preload-threshold"; - /// Config used to create write files. This config is provided to underlying /// file system through hive connector and data sink. The config is free form. /// The form should be defined by the underlying file system. static constexpr const char* kWriteFileCreateConfig = - "hive.write_file_create_config"; - - /// Maximum number of rows for sort writer in one batch of output. - static constexpr const char* kSortWriterMaxOutputRows = - "sort-writer-max-output-rows"; - static constexpr const char* kSortWriterMaxOutputRowsSession = - "sort_writer_max_output_rows"; - - /// Maximum bytes for sort writer in one batch of output. - static constexpr const char* kSortWriterMaxOutputBytes = - "sort-writer-max-output-bytes"; - static constexpr const char* kSortWriterMaxOutputBytesSession = - "sort_writer_max_output_bytes"; - - /// Sort Writer will exit finish() method after this many milliseconds even if - /// it has not completed its work yet. Zero means no time limit. - static constexpr const char* kSortWriterFinishTimeSliceLimitMs = - "sort-writer_finish_time_slice_limit_ms"; - static constexpr const char* kSortWriterFinishTimeSliceLimitMsSession = - "sort_writer_finish_time_slice_limit_ms"; - - // The unit for reading timestamps from files. - static constexpr const char* kReadTimestampUnit = - "hive.reader.timestamp-unit"; - static constexpr const char* kReadTimestampUnitSession = - "hive.reader.timestamp_unit"; - - static constexpr const char* kReadTimestampPartitionValueAsLocalTime = - "hive.reader.timestamp-partition-value-as-local-time"; - static constexpr const char* kReadTimestampPartitionValueAsLocalTimeSession = - "hive.reader.timestamp_partition_value_as_local_time"; - - static constexpr const char* kReadStatsBasedFilterReorderDisabled = - "stats-based-filter-reorder-disabled"; - static constexpr const char* kReadStatsBasedFilterReorderDisabledSession = - "stats_based_filter_reorder_disabled"; - - static constexpr const char* kLocalDataPath = "hive_local_data_path"; - static constexpr const char* kLocalFileFormat = "hive_local_file_format"; - - /// Whether to preserve flat maps in memory as FlatMapVectors instead of - /// converting them to MapVectors. - static constexpr const char* kPreserveFlatMapsInMemory = - "hive.preserve-flat-maps-in-memory"; - static constexpr const char* kPreserveFlatMapsInMemorySession = - "hive.preserve_flat_maps_in_memory"; + "write-file-create-config"; InsertExistingPartitionsBehavior insertExistingPartitionsBehavior( const config::ConfigBase* session) const; - uint32_t maxPartitionsPerWriters(const config::ConfigBase* session) const; - - uint32_t maxBucketCount(const config::ConfigBase* session) const; - bool immutablePartitions() const; std::string gcsEndpoint() const; @@ -216,27 +213,6 @@ class HiveConfig { std::optional gcsAuthAccessTokenProvider() const; - bool isOrcUseColumnNames(const config::ConfigBase* session) const; - - bool isParquetUseColumnNames(const config::ConfigBase* session) const; - - bool isFileColumnNamesReadAsLowerCase( - const config::ConfigBase* session) const; - - bool isPartitionPathAsLowerCase(const config::ConfigBase* session) const; - - bool allowNullPartitionKeys(const config::ConfigBase* session) const; - - bool ignoreMissingFiles(const config::ConfigBase* session) const; - - int64_t maxCoalescedBytes(const config::ConfigBase* session) const; - - int32_t maxCoalescedDistanceBytes(const config::ConfigBase* session) const; - - int32_t prefetchRowGroups() const; - - int32_t loadQuantum(const config::ConfigBase* session) const; - int32_t numCacheFileHandles() const; uint64_t fileHandleExpirationDurationMs() const; @@ -247,55 +223,14 @@ class HiveConfig { std::string writeFileCreateConfig() const; - uint32_t sortWriterMaxOutputRows(const config::ConfigBase* session) const; - uint64_t sortWriterMaxOutputBytes(const config::ConfigBase* session) const; - uint64_t sortWriterFinishTimeSliceLimitMs( - const config::ConfigBase* session) const; + uint64_t maxTargetFileSizeBytes(const config::ConfigBase* session) const; - uint64_t footerEstimatedSize() const; - - uint64_t filePreloadThreshold() const; - - // Returns the timestamp unit used when reading timestamps from files. - uint8_t readTimestampUnit(const config::ConfigBase* session) const; - - // Whether to read timestamp partition value as local time. If false, read as - // UTC. - bool readTimestampPartitionValueAsLocalTime( - const config::ConfigBase* session) const; - - /// Returns true if the stats based filter reorder for read is disabled. - bool readStatsBasedFilterReorderDisabled( - const config::ConfigBase* session) const; - - /// Returns the file system path containing local data. If non-empty, - /// initializes LocalHiveConnectorMetadata to provide metadata for the tables - /// in the directory. - std::string hiveLocalDataPath() const; - - /// Returns the name of the file format to use in interpreting the contents of - /// hiveLocalDataPath(). - std::string hiveLocalFileFormat() const; - - /// Whether to preserve flat maps in memory as FlatMapVectors instead of - /// converting them to MapVectors. - bool preserveFlatMapsInMemory(const config::ConfigBase* session) const; - - HiveConfig(std::shared_ptr config) { - VELOX_CHECK_NOT_NULL( - config, "Config is null for HiveConfig initialization"); - config_ = std::move(config); - // TODO: add sanity check - } - - const std::shared_ptr& config() const { - return config_; - } - - private: - std::shared_ptr config_; + explicit HiveConfig(std::shared_ptr config) + : FileConfig(std::move(config)) {} }; } // namespace facebook::velox::connector::hive + +#include "velox/connectors/hive/HiveConfigMacrosUndef.h" diff --git a/velox/connectors/hive/HiveConfigMacrosDefine.h b/velox/connectors/hive/HiveConfigMacrosDefine.h new file mode 100644 index 00000000000..a86a5ddc397 --- /dev/null +++ b/velox/connectors/hive/HiveConfigMacrosDefine.h @@ -0,0 +1,102 @@ +// NOLINT(facebook-hte-MultipleIncludeGuardMissing) +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// Macros for defining Hive connector config properties. Same traits-based +// pattern as VELOX_QUERY_CONFIG in QueryConfig.h. +// +// No #pragma once — this file is intentionally included multiple times +// (once per header that defines config properties). Each inclusion must +// be paired with HiveConfigMacrosUndef.h to scope the macros. +// +// To add a new property, add two lines (order doesn't matter, but +// grouping related properties together helps readability): +// +// 1. In FileConfig.h or HiveConfig.h (inside the class body): +// VELOX_HIVE_CONFIG(kMaxPartitionsPerWritersSession, +// maxPartitionsPerWriters, "max_partitions_per_writers", +// uint32_t, 128, "Maximum partitions per writer.") +// +// 2. In the matching .cpp file (inside registeredProperties()): +// VELOX_HIVE_CONFIG_REGISTER(kMaxPartitionsPerWritersSession); +// +// VELOX_HIVE_CONFIG generates: +// - struct kMaxPartitionsPerWritersSessionProperty { +// using type = uint32_t; +// static constexpr const char* key = "max_partitions_per_writers"; +// static constexpr auto defaultValue = 128; +// static constexpr const char* description = "Maximum partitions..."; +// }; +// - static constexpr const char* kMaxPartitionsPerWritersSession = +// "max_partitions_per_writers" +// - uint32_t maxPartitionsPerWriters(const ConfigBase* session) const { +// return session->get(kMaxPartitionsPerWritersSession, +// config_->get(toConfigKey(...), 128)); +// } +// +// VELOX_HIVE_CONFIG_PROPERTY generates the same but without the accessor. +// Used for properties with custom accessor logic: capacity parsing, +// validation, or session-only properties that should not fall back to +// connector config (e.g., ignoreMissingFiles, isPartitionPathAsLowerCase). +// New properties should use VELOX_HIVE_CONFIG and put validation in +// HiveConfigProvider::normalize() and HiveConnector constructor. +// TODO: Unify validation into a single path. +// +// VELOX_HIVE_CONFIG_LEGACY is the same as VELOX_HIVE_CONFIG but takes an +// explicit config key and uses sessionValue() for legacy 'hive.' prefix +// fallback. + +#ifdef VELOX_HIVE_CONFIG_MACROS_DEFINED +#error "HiveConfigMacrosDefine.h included twice without HiveConfigMacrosUndef.h" +#endif +// NOLINTBEGIN(facebook-modularize-issue-check) +#define VELOX_HIVE_CONFIG_MACROS_DEFINED + +#define VELOX_HIVE_CONFIG_PROPERTY( \ + constName, keyStr, CppType, defaultVal, desc) \ + struct constName##Property { \ + using type = CppType; \ + static constexpr const char* key = keyStr; \ + static constexpr auto defaultValue = defaultVal; \ + static constexpr const char* description = desc; \ + }; \ + static constexpr const char* constName = keyStr; + +#define VELOX_HIVE_CONFIG( \ + constName, accessorName, key, CppType, defaultVal, desc) \ + VELOX_HIVE_CONFIG_PROPERTY(constName, key, CppType, defaultVal, desc) \ + CppType accessorName(const config::ConfigBase* session) const { \ + return session->get( \ + constName, \ + config_->get( \ + config::ConfigBase::toConfigKey(constName), defaultVal)); \ + } + +#define VELOX_HIVE_CONFIG_LEGACY( \ + constName, \ + configConstName, \ + accessorName, \ + key, \ + configKey, \ + CppType, \ + defaultVal, \ + desc) \ + VELOX_HIVE_CONFIG_PROPERTY(constName, key, CppType, defaultVal, desc) \ + static constexpr const char* configConstName = configKey; \ + CppType accessorName(const config::ConfigBase* session) const { \ + return sessionValue( \ + session, constName, configConstName, defaultVal); \ + } +// NOLINTEND(facebook-modularize-issue-check) diff --git a/velox/connectors/hive/HiveConfigMacrosUndef.h b/velox/connectors/hive/HiveConfigMacrosUndef.h new file mode 100644 index 00000000000..c2d0e5685a7 --- /dev/null +++ b/velox/connectors/hive/HiveConfigMacrosUndef.h @@ -0,0 +1,29 @@ +// NOLINT(facebook-hte-MultipleIncludeGuardMissing) +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Undefines macros from HiveConfigMacrosDefine.h. +// Must be included after all macro usage. +// No #pragma once — see HiveConfigMacrosDefine.h for explanation. + +#ifndef VELOX_HIVE_CONFIG_MACROS_DEFINED +#error "HiveConfigMacrosUndef.h included without HiveConfigMacrosDefine.h" +#endif +#undef VELOX_HIVE_CONFIG_MACROS_DEFINED + +#undef VELOX_HIVE_CONFIG +#undef VELOX_HIVE_CONFIG_LEGACY +#undef VELOX_HIVE_CONFIG_PROPERTY diff --git a/velox/connectors/hive/HiveConfigProvider.cpp b/velox/connectors/hive/HiveConfigProvider.cpp new file mode 100644 index 00000000000..c6107a1872f --- /dev/null +++ b/velox/connectors/hive/HiveConfigProvider.cpp @@ -0,0 +1,32 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/connectors/hive/HiveConfigProvider.h" + +#include "velox/connectors/hive/HiveConfig.h" + +namespace facebook::velox::connector::hive { + +std::vector HiveConfigProvider::properties() const { + return HiveConfig::registeredProperties(); +} + +std::string HiveConfigProvider::normalize( + std::string_view /*name*/, + std::string_view value) const { + return std::string(value); +} + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/HiveConfigProvider.h b/velox/connectors/hive/HiveConfigProvider.h new file mode 100644 index 00000000000..752962d4585 --- /dev/null +++ b/velox/connectors/hive/HiveConfigProvider.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/common/config/ConfigProvider.h" + +namespace facebook::velox::connector::hive { + +/// Exposes Hive connector session-overridable properties. +class HiveConfigProvider : public config::ConfigProvider { + public: + /// Returns all session-overridable Hive connector properties. + std::vector properties() const override; + + /// Validates and normalizes a property value. + std::string normalize(std::string_view name, std::string_view value) + const override; +}; + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/HiveConnector.cpp b/velox/connectors/hive/HiveConnector.cpp index e04828e83aa..062a507fc64 100644 --- a/velox/connectors/hive/HiveConnector.cpp +++ b/velox/connectors/hive/HiveConnector.cpp @@ -17,8 +17,10 @@ #include "velox/connectors/hive/HiveConnector.h" #include "velox/connectors/hive/HiveConfig.h" +#include "velox/connectors/hive/HiveConfigProvider.h" #include "velox/connectors/hive/HiveDataSink.h" #include "velox/connectors/hive/HiveDataSource.h" +#include "velox/connectors/hive/HiveIndexSource.h" #include "velox/connectors/hive/HivePartitionFunction.h" #include @@ -53,10 +55,15 @@ HiveConnector::HiveConnector( } } +const config::ConfigProvider* HiveConnector::configProvider() const { + static const HiveConfigProvider kProvider; + return &kProvider; +} + std::unique_ptr HiveConnector::createDataSource( const RowTypePtr& outputType, const ConnectorTableHandlePtr& tableHandle, - const std::unordered_map& columnHandles, + const ColumnHandleMap& columnHandles, ConnectorQueryCtx* connectorQueryCtx) { return std::make_unique( outputType, @@ -86,6 +93,30 @@ std::unique_ptr HiveConnector::createDataSink( hiveConfig_); } +std::shared_ptr HiveConnector::createIndexSource( + const RowTypePtr& inputType, + const std::vector>& + joinConditions, + const RowTypePtr& outputType, + const ConnectorTableHandlePtr& tableHandle, + const ColumnHandleMap& columnHandles, + ConnectorQueryCtx* connectorQueryCtx) { + auto hiveTableHandle = + std::dynamic_pointer_cast(tableHandle); + VELOX_CHECK_NOT_NULL( + hiveTableHandle, "Hive connector expecting hive table handle!"); + return std::make_shared( + inputType, + joinConditions, + outputType, + hiveTableHandle, + columnHandles, + &fileHandleFactory_, + connectorQueryCtx, + hiveConfig_, + ioExecutor_); +} + // static void HiveConnector::registerSerDe() { HiveTableHandle::registerSerDe(); diff --git a/velox/connectors/hive/HiveConnector.h b/velox/connectors/hive/HiveConnector.h index c6b91392976..95c175c4f69 100644 --- a/velox/connectors/hive/HiveConnector.h +++ b/velox/connectors/hive/HiveConnector.h @@ -34,6 +34,8 @@ class HiveConnector : public Connector { std::shared_ptr config, folly::Executor* executor); + const config::ConfigProvider* configProvider() const override; + bool canAddDynamicFilter() const override { return true; } @@ -44,15 +46,13 @@ class HiveConnector : public Connector { const connector::ColumnHandleMap& columnHandles, ConnectorQueryCtx* connectorQueryCtx) override; -#ifdef VELOX_ENABLE_BACKWARD_COMPATIBILITY - bool supportsSplitPreload() override { + bool supportsSplitPreload() const override { return true; } -#else - bool supportsSplitPreload() const override { + + bool supportsIndexLookup() const override { return true; } -#endif std::unique_ptr createDataSink( RowTypePtr inputType, @@ -60,6 +60,15 @@ class HiveConnector : public Connector { ConnectorQueryCtx* connectorQueryCtx, CommitStrategy commitStrategy) override; + std::shared_ptr createIndexSource( + const RowTypePtr& inputType, + const std::vector>& + joinConditions, + const RowTypePtr& outputType, + const ConnectorTableHandlePtr& tableHandle, + const ColumnHandleMap& columnHandles, + ConnectorQueryCtx* connectorQueryCtx) override; + folly::Executor* ioExecutor() const override { return ioExecutor_; } @@ -95,7 +104,7 @@ class HiveConnectorFactory : public ConnectorFactory { const std::string& id, std::shared_ptr config, folly::Executor* ioExecutor = nullptr, - folly::Executor* cpuExecutor = nullptr) override { + [[maybe_unused]] folly::Executor* cpuExecutor = nullptr) override { return std::make_shared(id, config, ioExecutor); } }; diff --git a/velox/connectors/hive/HiveConnectorSplit.cpp b/velox/connectors/hive/HiveConnectorSplit.cpp index f38336d92ab..de5a55e8c14 100644 --- a/velox/connectors/hive/HiveConnectorSplit.cpp +++ b/velox/connectors/hive/HiveConnectorSplit.cpp @@ -30,15 +30,6 @@ std::string HiveConnectorSplit::toString() const { return fmt::format("Hive: {} {} - {}", filePath, start, length); } -uint64_t HiveConnectorSplit::size() const { - return length; -} - -std::string HiveConnectorSplit::getFileName() const { - const auto i = filePath.rfind('/'); - return i == std::string::npos ? filePath : filePath.substr(i + 1); -} - folly::dynamic HiveConnectorSplit::serialize() const { folly::dynamic obj = folly::dynamic::object; obj["name"] = "HiveConnectorSplit"; @@ -148,8 +139,10 @@ std::shared_ptr HiveConnectorSplit::create( std::vector> bucketColumnHandles; for (const auto& bucketColumnHandleObj : bucketConversionObj["bucketColumnHandles"]) { - bucketColumnHandles.push_back(std::const_pointer_cast( - ISerializable::deserialize(bucketColumnHandleObj))); + bucketColumnHandles.push_back( + std::const_pointer_cast( + ISerializable::deserialize( + bucketColumnHandleObj))); } bucketConversion = HiveBucketConversion{ .tableBucketCount = static_cast( diff --git a/velox/connectors/hive/HiveConnectorSplit.h b/velox/connectors/hive/HiveConnectorSplit.h index 3485c2330fa..8f54e348167 100644 --- a/velox/connectors/hive/HiveConnectorSplit.h +++ b/velox/connectors/hive/HiveConnectorSplit.h @@ -17,10 +17,8 @@ #include #include -#include "velox/connectors/Connector.h" -#include "velox/connectors/hive/FileProperties.h" +#include "velox/connectors/hive/FileConnectorSplit.h" #include "velox/connectors/hive/TableHandle.h" -#include "velox/dwio/common/Options.h" namespace facebook::velox::connector::hive { @@ -42,31 +40,17 @@ struct RowIdProperties { std::string tableGuid; }; -struct HiveConnectorSplit : public connector::ConnectorSplit { - const std::string filePath; - dwio::common::FileFormat fileFormat; - const uint64_t start; - const uint64_t length; - - /// Mapping from partition keys to values. Values are specified as strings - /// formatted the same way as CAST(x as VARCHAR). Null values are specified as - /// std::nullopt. Date values must be formatted using ISO 8601 as YYYY-MM-DD. - /// All scalar types and date type are supported. - const std::unordered_map> - partitionKeys; +struct HiveConnectorSplit : public FileConnectorSplit { + /// Synthesized columns like $path, $file_size associated with the split. + std::unordered_map infoColumns; + + /// Format-specific reader parameters (e.g., ORC serde options from Hive + /// metastore). + std::unordered_map serdeParameters; + std::optional tableBucketNumber; std::unordered_map customSplitInfo; std::shared_ptr extraFileInfo; - // Parameters that are provided as the serialization options. - std::unordered_map serdeParameters; - - /// These represent columns like $file_size, $file_modified_time that are - /// associated with the HiveSplit. - std::unordered_map infoColumns; - - /// These represent file properties like file size that are used while opening - /// the file handle. - std::optional properties; std::optional rowIdProperties; @@ -91,29 +75,28 @@ struct HiveConnectorSplit : public connector::ConnectorSplit { std::optional _rowIdProperties = std::nullopt, const std::optional& _bucketConversion = std::nullopt) - : ConnectorSplit(connectorId, splitWeight, cacheable), - filePath(_filePath), - fileFormat(_fileFormat), - start(_start), - length(_length), - partitionKeys(_partitionKeys), + : FileConnectorSplit( + connectorId, + _filePath, + _fileFormat, + _start, + _length, + splitWeight, + cacheable, + std::move(_properties), + _partitionKeys), + infoColumns(_infoColumns), + serdeParameters(_serdeParameters), tableBucketNumber(_tableBucketNumber), customSplitInfo(_customSplitInfo), extraFileInfo(_extraFileInfo), - serdeParameters(_serdeParameters), - infoColumns(_infoColumns), - properties(_properties), rowIdProperties(_rowIdProperties), bucketConversion(_bucketConversion) {} ~HiveConnectorSplit() = default; - uint64_t size() const override; - std::string toString() const override; - std::string getFileName() const; - folly::dynamic serialize() const override; static std::shared_ptr create(const folly::dynamic& obj); diff --git a/velox/connectors/hive/HiveConnectorUtil.cpp b/velox/connectors/hive/HiveConnectorUtil.cpp index e06ee973ec7..0381ee2ecb5 100644 --- a/velox/connectors/hive/HiveConnectorUtil.cpp +++ b/velox/connectors/hive/HiveConnectorUtil.cpp @@ -16,12 +16,18 @@ #include "velox/connectors/hive/HiveConnectorUtil.h" -#include "velox/connectors/hive/HiveConfig.h" -#include "velox/connectors/hive/HiveConnectorSplit.h" +#include "velox/connectors/hive/FileColumnHandle.h" + +#include "velox/connectors/hive/FileConfig.h" +#include "velox/connectors/hive/FileConnectorSplit.h" +#include "velox/connectors/hive/FileConnectorUtil.h" +#include "velox/connectors/hive/FileTableHandle.h" #include "velox/dwio/common/CachedBufferedInput.h" #include "velox/dwio/common/DirectBufferedInput.h" #include "velox/expression/Expr.h" +#include "velox/expression/ExprConstants.h" #include "velox/expression/ExprToSubfieldFilter.h" +#include "velox/expression/FieldReference.h" namespace facebook::velox::connector::hive { namespace { @@ -73,14 +79,15 @@ std::unique_ptr makeFloatingPointMapKeyFilter( if (lowerUnbounded && upperUnbounded) { continue; } - filters.push_back(std::make_unique>( - lower, - lowerUnbounded, - lowerExclusive, - upper, - upperUnbounded, - upperExclusive, - false)); + filters.push_back( + std::make_unique>( + lower, + lowerUnbounded, + lowerExclusive, + upper, + upperUnbounded, + upperExclusive, + false)); } if (filters.size() == 1) { return std::move(filters[0]); @@ -105,19 +112,16 @@ void addSubfields( } } subfields.resize(newSize); + switch (type.kind()) { case TypeKind::ROW: { folly::F14FastMap> required; for (auto& subfield : subfields) { auto* element = subfield.subfield->path()[level].get(); - auto* nestedField = element->as(); - VELOX_CHECK( - nestedField, - "Unsupported for row subfields pruning: {}", - element->toString()); + auto* nestedField = element->asChecked(); required[nestedField->name()].push_back(subfield); } - auto& rowType = type.asRow(); + const auto& rowType = type.asRow(); for (int i = 0; i < rowType.size(); ++i) { auto& childName = rowType.nameOf(i); auto& childType = rowType.childAt(i); @@ -239,7 +243,7 @@ inline uint8_t parseDelimiter(const std::string& delim) { inline bool isSynthesizedColumn( const std::string& name, - const std::unordered_map& infoColumns) { + const std::unordered_map& infoColumns) { return infoColumns.count(name) != 0; } @@ -253,9 +257,8 @@ bool isSpecialColumn( const std::string& getColumnName(const common::Subfield& subfield) { VELOX_CHECK_GT(subfield.path().size(), 0); - auto* field = dynamic_cast( + auto* field = checkedPointerCast( subfield.path()[0].get()); - VELOX_CHECK_NOT_NULL(field); return field->name(); } @@ -285,7 +288,7 @@ void checkColumnNameLowerCase(const TypePtr& type) { void checkColumnNameLowerCase( const common::SubfieldFilters& filters, - const std::unordered_map& infoColumns) { + const std::unordered_map& infoColumns) { for (const auto& filterIt : filters) { const auto name = filterIt.first.toString(); if (isSynthesizedColumn(name, infoColumns)) { @@ -333,9 +336,10 @@ void processFieldSpec( } }); if (dataColumns) { - auto i = dataColumns->getChildIdxIfExists(fieldSpec.fieldName()); - if (i.has_value()) { - if (dataColumns->childAt(*i)->isMap() && outputType->isRow()) { + const auto childIdxOpt = + dataColumns->getChildIdxIfExists(fieldSpec.fieldName()); + if (childIdxOpt.has_value()) { + if (dataColumns->childAt(*childIdxOpt)->isMap() && outputType->isRow()) { fieldSpec.setFlatMapAsStruct(true); } } @@ -344,14 +348,57 @@ void processFieldSpec( } // namespace +void checkColumnHandleConsistent( + const FileColumnHandle& x, + const FileColumnHandle& y) { + VELOX_CHECK_EQ( + x.columnType(), + y.columnType(), + "Inconsistent column handle type: {}, expected {}, got {}", + x.name(), + FileColumnHandle::columnTypeName(x.columnType()), + FileColumnHandle::columnTypeName(y.columnType())); + VELOX_CHECK( + x.dataType()->equivalent(*y.dataType()), + "Inconsistent column handle data type: {}, expected {}, got {}", + x.name(), + x.dataType()->toString(), + y.dataType()->toString()); +} + std::shared_ptr makeScanSpec( const RowTypePtr& rowType, const folly::F14FastMap>& outputSubfields, - const common::SubfieldFilters& filters, + const common::SubfieldFilters& subfieldFilters, const RowTypePtr& dataColumns, - const std::unordered_map& partitionKeys, - const std::unordered_map& infoColumns, + const std::unordered_map& partitionKeys, + const std::unordered_map& infoColumns, + const SpecialColumnNames& specialColumns, + bool disableStatsBasedFilterReorder, + memory::MemoryPool* pool) { + return makeScanSpec( + rowType, + outputSubfields, + subfieldFilters, + /*indexColumns=*/{}, + dataColumns, + partitionKeys, + infoColumns, + specialColumns, + disableStatsBasedFilterReorder, + pool); +} + +std::shared_ptr makeScanSpec( + const RowTypePtr& rowType, + const folly::F14FastMap>& + outputSubfields, + const common::SubfieldFilters& subfieldFilters, + const std::vector& indexColumns, + const RowTypePtr& dataColumns, + const std::unordered_map& partitionKeys, + const std::unordered_map& infoColumns, const SpecialColumnNames& specialColumns, bool disableStatsBasedFilterReorder, memory::MemoryPool* pool) { @@ -359,7 +406,7 @@ std::shared_ptr makeScanSpec( folly::F14FastMap> filterSubfields; std::vector subfieldSpecs; - for (auto& [subfield, _] : filters) { + for (const auto& [subfield, _] : subfieldFilters) { if (auto name = subfield.toString(); !isSynthesizedColumn(name, infoColumns) && partitionKeys.count(name) == 0) { @@ -426,16 +473,19 @@ std::shared_ptr makeScanSpec( } } - for (auto& pair : filters) { + // Process index columns from join conditions to ensure they are read. + // These columns are not projected out, only used for index lookup. + for (const auto& keyName : indexColumns) { + VELOX_CHECK_NOT_NULL(dataColumns); + if (spec->childByName(keyName) == nullptr) { + // This is required so that we can set filter on the index column in the + // selective reader later. + spec->getOrCreateChild(keyName); + } + } + + for (auto& pair : subfieldFilters) { const auto name = pair.first.toString(); - // SelectiveColumnReader doesn't support constant columns with filters, - // hence, we can't have a filter for a $path or $bucket column. - // - // Unfortunately, Presto happens to specify a filter for $path, $file_size, - // $file_modified_time or $bucket column. This filter is redundant and needs - // to be removed. - // TODO Remove this check when Presto is fixed to not specify a filter - // on $path and $bucket column. if (isSynthesizedColumn(name, infoColumns)) { continue; } @@ -523,84 +573,52 @@ std::unique_ptr parseSerdeParameters( } void configureReaderOptions( - const std::shared_ptr& hiveConfig, + const std::shared_ptr& fileConfig, const ConnectorQueryCtx* connectorQueryCtx, - const std::shared_ptr& hiveTableHandle, - const std::shared_ptr& hiveSplit, + const FileTableHandlePtr& tableHandle, + const std::shared_ptr& fileSplit, + const std::unordered_map& serdeParameters, dwio::common::ReaderOptions& readerOptions) { configureReaderOptions( - hiveConfig, + fileConfig, connectorQueryCtx, - hiveTableHandle->dataColumns(), - hiveSplit, - hiveTableHandle->tableParameters(), + tableHandle->dataColumns(), + fileSplit, + tableHandle->tableParameters(), + serdeParameters, readerOptions); } void configureReaderOptions( - const std::shared_ptr& hiveConfig, + const std::shared_ptr& fileConfig, const ConnectorQueryCtx* connectorQueryCtx, const RowTypePtr& fileSchema, - const std::shared_ptr& hiveSplit, + const std::shared_ptr& fileSplit, const std::unordered_map& tableParameters, + const std::unordered_map& serdeParameters, dwio::common::ReaderOptions& readerOptions) { - auto sessionProperties = connectorQueryCtx->sessionProperties(); - readerOptions.setLoadQuantum(hiveConfig->loadQuantum(sessionProperties)); - readerOptions.setMaxCoalesceBytes( - hiveConfig->maxCoalescedBytes(sessionProperties)); - readerOptions.setMaxCoalesceDistance( - hiveConfig->maxCoalescedDistanceBytes(sessionProperties)); - readerOptions.setFileColumnNamesReadAsLowerCase( - hiveConfig->isFileColumnNamesReadAsLowerCase(sessionProperties)); - readerOptions.setAllowEmptyFile(true); - bool useColumnNamesForColumnMapping = false; - switch (hiveSplit->fileFormat) { - case dwio::common::FileFormat::DWRF: - case dwio::common::FileFormat::ORC: { - useColumnNamesForColumnMapping = - hiveConfig->isOrcUseColumnNames(sessionProperties); - break; - } - case dwio::common::FileFormat::PARQUET: { - useColumnNamesForColumnMapping = - hiveConfig->isParquetUseColumnNames(sessionProperties); - break; - } - default: - useColumnNamesForColumnMapping = false; - } - - readerOptions.setUseColumnNamesForColumnMapping( - useColumnNamesForColumnMapping); - readerOptions.setFileSchema(fileSchema); - readerOptions.setFooterEstimatedSize(hiveConfig->footerEstimatedSize()); - readerOptions.setFilePreloadThreshold(hiveConfig->filePreloadThreshold()); - readerOptions.setPrefetchRowGroups(hiveConfig->prefetchRowGroups()); - readerOptions.setNoCacheRetention(!hiveSplit->cacheable); - const auto& sessionTzName = connectorQueryCtx->sessionTimezone(); - if (!sessionTzName.empty()) { - const auto timezone = tz::locateZone(sessionTzName); - readerOptions.setSessionTimezone(timezone); - } - readerOptions.setAdjustTimestampToTimezone( - connectorQueryCtx->adjustTimestampToTimezone()); - readerOptions.setSelectiveNimbleReaderEnabled( - connectorQueryCtx->selectiveNimbleReaderEnabled()); - - if (readerOptions.fileFormat() != dwio::common::FileFormat::UNKNOWN) { - VELOX_CHECK( - readerOptions.fileFormat() == hiveSplit->fileFormat, - "HiveDataSource received splits of different formats: {} and {}", - dwio::common::toString(readerOptions.fileFormat()), - dwio::common::toString(hiveSplit->fileFormat)); - } else { - auto serDeOptions = - parseSerdeParameters(hiveSplit->serdeParameters, tableParameters); + // Check if format was UNKNOWN before the generic call, since the generic + // version will set it. Serde options should only be applied on the first + // call (when format transitions from UNKNOWN). + const bool formatWasUnknown = + readerOptions.fileFormat() == dwio::common::FileFormat::UNKNOWN; + + // Call the generic version which handles everything except serde. + // Use fully qualified call to the generic overload (6 params, no serde). + hive::configureReaderOptions( + fileConfig, + connectorQueryCtx, + fileSchema, + fileSplit, + tableParameters, + readerOptions); + + // Apply serde options on top (Hive-specific). + if (formatWasUnknown) { + auto serDeOptions = parseSerdeParameters(serdeParameters, tableParameters); if (serDeOptions) { readerOptions.setSerDeOptions(*serDeOptions); } - - readerOptions.setFileFormat(hiveSplit->fileFormat); } } @@ -609,151 +627,37 @@ void configureRowReaderOptions( const std::shared_ptr& scanSpec, std::shared_ptr metadataFilter, const RowTypePtr& rowType, - const std::shared_ptr& hiveSplit, - const std::shared_ptr& hiveConfig, + const std::shared_ptr& fileSplit, + const std::unordered_map& serdeParameters, + const std::shared_ptr& fileConfig, const config::ConfigBase* sessionProperties, + folly::Executor* const ioExecutor, dwio::common::RowReaderOptions& rowReaderOptions) { - auto skipRowsIt = - tableParameters.find(dwio::common::TableParameter::kSkipHeaderLineCount); - if (skipRowsIt != tableParameters.end()) { - rowReaderOptions.setSkipRows(folly::to(skipRowsIt->second)); - } - rowReaderOptions.setScanSpec(scanSpec); - rowReaderOptions.setMetadataFilter(std::move(metadataFilter)); - rowReaderOptions.setRequestedType(rowType); - rowReaderOptions.range(hiveSplit->start, hiveSplit->length); - if (hiveConfig && sessionProperties) { - rowReaderOptions.setTimestampPrecision(static_cast( - hiveConfig->readTimestampUnit(sessionProperties))); - rowReaderOptions.setPreserveFlatMapsInMemory( - hiveConfig->preserveFlatMapsInMemory(sessionProperties)); - } - rowReaderOptions.setSerdeParameters(hiveSplit->serdeParameters); -} - -namespace { - -bool applyPartitionFilter( - const TypePtr& type, - const std::string& partitionValue, - bool isPartitionDateDaysSinceEpoch, - const common::Filter* filter, - bool asLocalTime) { - if (type->isDate()) { - int32_t result = 0; - // days_since_epoch partition values are integers in string format. Eg. - // Iceberg partition values. - if (isPartitionDateDaysSinceEpoch) { - result = folly::to(partitionValue); - } else { - result = DATE()->toDays(static_cast(partitionValue)); - } - return applyFilter(*filter, result); - } - - switch (type->kind()) { - case TypeKind::BIGINT: - case TypeKind::INTEGER: - case TypeKind::SMALLINT: - case TypeKind::TINYINT: { - return applyFilter(*filter, folly::to(partitionValue)); - } - case TypeKind::REAL: - case TypeKind::DOUBLE: { - return applyFilter(*filter, folly::to(partitionValue)); - } - case TypeKind::BOOLEAN: { - return applyFilter(*filter, folly::to(partitionValue)); - } - case TypeKind::TIMESTAMP: { - auto result = util::fromTimestampString( - StringView(partitionValue), util::TimestampParseMode::kPrestoCast); - VELOX_CHECK(!result.hasError()); - if (asLocalTime) { - result.value().toGMT(Timestamp::defaultTimezone()); - } - return applyFilter(*filter, result.value()); - } - case TypeKind::VARCHAR: { - return applyFilter(*filter, partitionValue); - } - default: - VELOX_FAIL( - "Bad type {} for partition value: {}", type->kind(), partitionValue); - } -} - -} // namespace - -bool testFilters( - const common::ScanSpec* scanSpec, - const dwio::common::Reader* reader, - const std::string& filePath, - const std::unordered_map>& - partitionKeys, - const std::unordered_map& - partitionKeysHandle, - bool asLocalTime) { - const auto totalRows = reader->numberOfRows(); - const auto& fileTypeWithId = reader->typeWithId(); - const auto& rowType = reader->rowType(); - for (const auto& child : scanSpec->children()) { - if (child->filter()) { - const auto& name = child->fieldName(); - auto iter = partitionKeys.find(name); - // By design, the partition key columns for Iceberg tables are included in - // the data files to facilitate partition transform and partition - // evolution, so we need to test both cases. - if (!rowType->containsChild(name) || iter != partitionKeys.end()) { - if (iter != partitionKeys.end() && iter->second.has_value()) { - const auto handlesIter = partitionKeysHandle.find(name); - VELOX_CHECK(handlesIter != partitionKeysHandle.end()); - - // This is a non-null partition key - return applyPartitionFilter( - handlesIter->second->dataType(), - iter->second.value(), - handlesIter->second->isPartitionDateValueDaysSinceEpoch(), - child->filter(), - asLocalTime); - } - // Column is missing, most likely due to schema evolution. Or it's a - // partition key but the partition value is NULL. - if (child->filter()->isDeterministic() && - !child->filter()->testNull()) { - VLOG(1) << "Skipping " << filePath - << " because the filter testNull() failed for column " - << child->fieldName(); - return false; - } - } else { - const auto& typeWithId = fileTypeWithId->childByName(name); - const auto columnStats = reader->columnStatistics(typeWithId->id()); - if (columnStats != nullptr && - !testFilter( - child->filter(), - columnStats.get(), - totalRows.value(), - typeWithId->type())) { - VLOG(1) << "Skipping " << filePath - << " based on stats and filter for column " - << child->fieldName(); - return false; - } - } - } - } - - return true; + // Call the generic version which handles everything except serde. + // Use fully qualified call to the generic overload (9 params, no serde). + hive::configureRowReaderOptions( + tableParameters, + scanSpec, + std::move(metadataFilter), + rowType, + fileSplit, + fileConfig, + sessionProperties, + ioExecutor, + rowReaderOptions); + + // Apply serde parameters on top (Hive-specific). + rowReaderOptions.setSerdeParameters(serdeParameters); } std::unique_ptr createBufferedInput( const FileHandle& fileHandle, const dwio::common::ReaderOptions& readerOpts, const ConnectorQueryCtx* connectorQueryCtx, - std::shared_ptr ioStats, - std::shared_ptr fsStats, - folly::Executor* executor) { + std::shared_ptr ioStatistics, + std::shared_ptr ioStats, + folly::Executor* executor, + const folly::F14FastMap& fileReadOps) { if (connectorQueryCtx->cache()) { return std::make_unique( fileHandle.file, @@ -763,10 +667,11 @@ std::unique_ptr createBufferedInput( Connector::getTracker( connectorQueryCtx->scanId(), readerOpts.loadQuantum()), fileHandle.groupId, - ioStats, - std::move(fsStats), + ioStatistics, + std::move(ioStats), executor, - readerOpts); + readerOpts, + fileReadOps); } if (readerOpts.fileFormat() == dwio::common::FileFormat::NIMBLE) { // Nimble streams (in case of single chunk) are compressed as whole and need @@ -777,8 +682,11 @@ std::unique_ptr createBufferedInput( fileHandle.file, readerOpts.memoryPool(), dwio::common::MetricsLog::voidLog(), + ioStatistics.get(), ioStats.get(), - fsStats.get()); + dwio::common::BufferedInput::kMaxMergeDistance, + std::nullopt, + fileReadOps); } return std::make_unique( fileHandle.file, @@ -787,10 +695,11 @@ std::unique_ptr createBufferedInput( Connector::getTracker( connectorQueryCtx->scanId(), readerOpts.loadQuantum()), fileHandle.groupId, + std::move(ioStatistics), std::move(ioStats), - std::move(fsStats), executor, - readerOpts); + readerOpts, + fileReadOps); } namespace { @@ -803,16 +712,7 @@ core::CallTypedExprPtr replaceInputs( } bool endWith(const std::string& str, const char* suffix) { - int len = strlen(suffix); - if (str.size() < len) { - return false; - } - for (int i = 0, j = str.size() - len; i < len; ++i, ++j) { - if (str[j] != suffix[i]) { - return false; - } - } - return true; + return str.ends_with(suffix); } bool isNotExpr( @@ -860,8 +760,6 @@ double getPrestoSampleRate( return std::max(0.0, std::min(1.0, rate->value().value())); } -} // namespace - core::TypedExprPtr extractFiltersFromRemainingFilter( const core::TypedExprPtr& expr, core::ExpressionEvaluator* evaluator, @@ -874,10 +772,10 @@ core::TypedExprPtr extractFiltersFromRemainingFilter( } common::Filter* oldFilter = nullptr; try { - common::Subfield subfield; - if (auto filter = exec::ExprToSubfieldFilterParser::getInstance() - ->leafCallToSubfieldFilter( - *call, subfield, evaluator, negated)) { + if (auto subfieldAndFilter = + exec::ExprToSubfieldFilterParser::getInstance() + ->leafCallToSubfieldFilter(*call, evaluator, negated)) { + auto& [subfield, filter] = subfieldAndFilter.value(); if (auto it = filters.find(subfield); it != filters.end()) { oldFilter = it->second.get(); filter = filter->mergeWith(oldFilter); @@ -899,20 +797,73 @@ core::TypedExprPtr extractFiltersFromRemainingFilter( return inner ? replaceInputs(call, {inner}) : nullptr; } - if ((call->name() == "and" && !negated) || - (call->name() == "or" && negated)) { - auto lhs = extractFiltersFromRemainingFilter( - call->inputs()[0], evaluator, negated, filters, sampleRate); - auto rhs = extractFiltersFromRemainingFilter( - call->inputs()[1], evaluator, negated, filters, sampleRate); - if (!lhs) { - return rhs; + if ((call->name() == expression::kAnd && !negated) || + (call->name() == expression::kOr && negated)) { + std::vector args; + args.reserve(call->inputs().size()); + for (const auto& input : call->inputs()) { + if (auto arg = extractFiltersFromRemainingFilter( + input, evaluator, negated, filters, sampleRate)) { + args.push_back(std::move(arg)); + } + // If extractFiltersFromRemainingFilter returns nullptr, it means + // everything in input is converted to filters. + } + if (args.empty()) { + return nullptr; + } + if (args.size() == 1) { + return std::move(args[0]); + } + return replaceInputs(call, std::move(args)); + } + + if ((call->name() == expression::kAnd && negated) || + (call->name() == expression::kOr && !negated)) { + std::vector> disjuncts; + common::Subfield subfield; + + for (const auto& input : call->inputs()) { + common::SubfieldFilters tmpFilters; + double tmpSampleRate = 1; + auto tmpRemaining = extractFiltersFromRemainingFilter( + input, evaluator, negated, tmpFilters, tmpSampleRate); + + if (tmpRemaining != nullptr || tmpSampleRate != 1 || + tmpFilters.size() != 1) { + disjuncts.clear(); + break; + } + + if (disjuncts.empty()) { + subfield = tmpFilters.begin()->first.clone(); + } else if (!(subfield == tmpFilters.begin()->first)) { + disjuncts.clear(); + break; + } + + disjuncts.push_back(tmpFilters.begin()->second->clone()); } - if (!rhs) { - return lhs; + + if (!disjuncts.empty()) { + auto filter = + exec::ExprToSubfieldFilterParser::makeOrFilter(std::move(disjuncts)); + + if (filter == nullptr) { + return expr; + } + + auto it = filters.find(subfield); + if (it != filters.end()) { + filter = filter->mergeWith(it->second.get()); + } + + filters.insert_or_assign(std::move(subfield), std::move(filter)); + + return nullptr; } - return replaceInputs(call, {lhs, rhs}); } + if (!negated) { double rate = getPrestoSampleRate(expr, call, evaluator); if (rate != -1) { @@ -920,6 +871,118 @@ core::TypedExprPtr extractFiltersFromRemainingFilter( return nullptr; } } + return expr; } +} // namespace + +core::TypedExprPtr extractFiltersFromRemainingFilter( + const core::TypedExprPtr& expr, + core::ExpressionEvaluator* evaluator, + common::SubfieldFilters& filters, + double& sampleRate) { + return extractFiltersFromRemainingFilter( + expr, evaluator, /*negated=*/false, filters, sampleRate); +} + +bool shouldEagerlyMaterialize( + const exec::Expr& remainingFilter, + const exec::FieldReference& field) { + const auto isMember = [](const std::vector& fields, + const exec::FieldReference& field) { + return std::find(fields.begin(), fields.end(), &field) != fields.end(); + }; + + if (!remainingFilter.evaluatesArgumentsOnNonIncreasingSelection()) { + return true; + } + for (auto& input : remainingFilter.inputs()) { + if (isMember(input->distinctFields(), field) && input->hasConditionals()) { + return true; + } + } + return false; +} + +namespace { +template +std::unique_ptr createRangeFilterInternal( + const variant& lower, + const variant& upper) { + using T = typename TypeTraits::NativeType; + const bool lowerUnbounded = lower.isNull(); + const bool upperUnbounded = upper.isNull(); + + if constexpr ( + kind == TypeKind::TINYINT || kind == TypeKind::SMALLINT || + kind == TypeKind::INTEGER || kind == TypeKind::BIGINT) { + return std::make_unique( + lowerUnbounded ? std::numeric_limits::min() : lower.value(), + upperUnbounded ? std::numeric_limits::max() : upper.value(), + /*nullAllowed=*/false); + } else if constexpr (kind == TypeKind::REAL) { + return std::make_unique( + lowerUnbounded ? std::numeric_limits::lowest() + : lower.value(), + lowerUnbounded, + /*lowerExclusive=*/false, + upperUnbounded ? std::numeric_limits::max() : upper.value(), + upperUnbounded, + /*upperExclusive=*/false, + /*nullAllowed=*/false); + } else if constexpr (kind == TypeKind::DOUBLE) { + return std::make_unique( + lowerUnbounded ? std::numeric_limits::lowest() + : lower.value(), + lowerUnbounded, + /*lowerExclusive=*/false, + upperUnbounded ? std::numeric_limits::max() : upper.value(), + upperUnbounded, + /*upperExclusive=*/false, + /*nullAllowed=*/false); + } else if constexpr ( + kind == TypeKind::VARCHAR || kind == TypeKind::VARBINARY) { + return std::make_unique( + lowerUnbounded ? "" : std::string(lower.value()), + lowerUnbounded, + /*lowerExclusive=*/false, + upperUnbounded ? "" : std::string(upper.value()), + upperUnbounded, + /*upperExclusive=*/false, + /*nullAllowed=*/false); + } else if constexpr (kind == TypeKind::BOOLEAN) { + VELOX_CHECK( + !lowerUnbounded && !upperUnbounded, + "Boolean range filter requires both bounds"); + return std::make_unique( + lower.value(), /*nullAllowed=*/false); + } else { + VELOX_UNSUPPORTED( + "Unsupported type kind for filter creation: {}", + TypeKindName::toName(kind)); + } +} +} // namespace + +std::unique_ptr createPointFilter( + const TypePtr& type, + const variant& value) { + VELOX_CHECK(!value.isNull(), "Value cannot be null"); + + return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( + createRangeFilterInternal, type->kind(), value, value); +} + +std::unique_ptr createRangeFilter( + const TypePtr& type, + const variant& lower, + const variant& upper) { + VELOX_CHECK( + !lower.isNull() || !upper.isNull(), + "At least one of lower or upper bound must be set"); + + return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( + createRangeFilterInternal, type->kind(), lower, upper); +} + } // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/HiveConnectorUtil.h b/velox/connectors/hive/HiveConnectorUtil.h index d649b12d093..191e03186ab 100644 --- a/velox/connectors/hive/HiveConnectorUtil.h +++ b/velox/connectors/hive/HiveConnectorUtil.h @@ -15,13 +15,15 @@ */ #pragma once -#include #include -#include "velox/connectors/Connector.h" -#include "velox/connectors/hive/FileHandle.h" +#include "velox/connectors/hive/FileConnectorUtil.h" #include "velox/dwio/common/BufferedInput.h" -#include "velox/dwio/common/Reader.h" + +namespace facebook::velox::exec { +class Expr; +class FieldReference; +} // namespace facebook::velox::exec namespace facebook::velox::connector::hive { @@ -32,13 +34,11 @@ struct HiveConnectorSplit; const std::string& getColumnName(const common::Subfield& subfield); -void checkColumnNameLowerCase(const std::shared_ptr& type); +void checkColumnNameLowerCase(const TypePtr& type); void checkColumnNameLowerCase( const common::SubfieldFilters& filters, - const std::unordered_map< - std::string, - std::shared_ptr>& infoColumns); + const std::unordered_map& infoColumns); void checkColumnNameLowerCase(const core::TypedExprPtr& typeExpr); @@ -47,35 +47,84 @@ struct SpecialColumnNames { std::optional rowId; }; +/// Check that two FileColumnHandle instances are consistent in terms of +/// column type, data type, and file type. Throw if inconsistent. +void checkColumnHandleConsistent( + const FileColumnHandle& x, + const FileColumnHandle& y); + +/// Creates a ScanSpec for reading data from a Hive table. +/// +/// The ScanSpec describes which columns to read and what filters to apply. +/// It handles several types of columns: +/// - Regular data columns from the file +/// - Partition key columns (values from file path) +/// - Synthesized columns (e.g., $path, $bucket) +/// - Special columns (e.g., row index, row ID) +/// - Index columns for index lookup joins +/// +/// @param rowType Schema of columns to be projected in the output. +/// @param outputSubfields Map of column names to subfields that need to be +/// read. Used for pruning nested structures. +/// @param subfieldFilters Map of subfields to filters to apply during scan. +/// @param indexColumns Column names used for index lookup joins. These columns +/// are added to the scan spec even if they are not in the output +/// projection, ensuring they are read from the file for join key matching. +/// @param dataColumns Full schema of all columns in the data file. Used to +/// look up column types when a column is referenced in filters or index +/// columns but not in the output projection. +/// @param partitionKeys Map of partition column names to their handles. +/// Partition columns are not read from the file. +/// @param infoColumns Map of synthesized column names (e.g., $path) to their +/// handles. +/// @param specialColumns Names of special columns like row index and row ID. +/// @param disableStatsBasedFilterReorder If true, disables reordering of +/// filters based on statistics. +/// @param pool Memory pool for allocations during scan spec construction. +/// @return A ScanSpec that can be used to configure a reader. std::shared_ptr makeScanSpec( const RowTypePtr& rowType, const folly::F14FastMap>& outputSubfields, - const common::SubfieldFilters& filters, + const common::SubfieldFilters& subfieldFilters, + const RowTypePtr& dataColumns, + const std::unordered_map& partitionKeys, + const std::unordered_map& infoColumns, + const SpecialColumnNames& specialColumns, + bool disableStatsBasedFilterReorder, + memory::MemoryPool* pool); + +/// @deprecated Use the overload without indexColumns parameter instead. +/// This overload is kept for backward compatibility and will be removed in a +/// future release. +std::shared_ptr makeScanSpec( + const RowTypePtr& rowType, + const folly::F14FastMap>& + outputSubfields, + const common::SubfieldFilters& subfieldFilters, + const std::vector& indexColumns, const RowTypePtr& dataColumns, - const std::unordered_map< - std::string, - std::shared_ptr>& partitionKeys, - const std::unordered_map< - std::string, - std::shared_ptr>& infoColumns, + const std::unordered_map& partitionKeys, + const std::unordered_map& infoColumns, const SpecialColumnNames& specialColumns, bool disableStatsBasedFilterReorder, memory::MemoryPool* pool); void configureReaderOptions( - const std::shared_ptr& config, + const std::shared_ptr& config, const ConnectorQueryCtx* connectorQueryCtx, - const std::shared_ptr& hiveTableHandle, - const std::shared_ptr& hiveSplit, + const FileTableHandlePtr& tableHandle, + const std::shared_ptr& fileSplit, + const std::unordered_map& serdeParameters, dwio::common::ReaderOptions& readerOptions); void configureReaderOptions( - const std::shared_ptr& hiveConfig, + const std::shared_ptr& fileConfig, const ConnectorQueryCtx* connectorQueryCtx, const RowTypePtr& fileSchema, - const std::shared_ptr& hiveSplit, + const std::shared_ptr& fileSplit, const std::unordered_map& tableParameters, + const std::unordered_map& serdeParameters, dwio::common::ReaderOptions& readerOptions); void configureRowReaderOptions( @@ -83,35 +132,107 @@ void configureRowReaderOptions( const std::shared_ptr& scanSpec, std::shared_ptr metadataFilter, const RowTypePtr& rowType, - const std::shared_ptr& hiveSplit, - const std::shared_ptr& hiveConfig, + const std::shared_ptr& fileSplit, + const std::unordered_map& serdeParameters, + const std::shared_ptr& fileConfig, const config::ConfigBase* sessionProperties, + folly::Executor* ioExecutor, dwio::common::RowReaderOptions& rowReaderOptions); -bool testFilters( - const common::ScanSpec* scanSpec, - const dwio::common::Reader* reader, - const std::string& filePath, - const std::unordered_map>& - partitionKey, - const std::unordered_map< - std::string, - std::shared_ptr>& partitionKeysHandle, - bool asLocalTime); - std::unique_ptr createBufferedInput( const FileHandle& fileHandle, const dwio::common::ReaderOptions& readerOpts, const ConnectorQueryCtx* connectorQueryCtx, - std::shared_ptr ioStats, - std::shared_ptr fsStats, - folly::Executor* executor); - + std::shared_ptr ioStatistics, + std::shared_ptr ioStats, + folly::Executor* executor, + const folly::F14FastMap& fileReadOps = {}); + +/// Given a boolean expression, breaks it up into conjuncts and sorts these into +/// single-column comparisons with constants (filters), rand() < sampleRate, and +/// the rest (return value). +/// +/// Multiple rand() < K conjuncts are combined into a single sampleRate by +/// multiplying individual sample rates. rand() < 0.1 and rand() < 0.2 produces +/// sampleRate = 0.02. +/// +/// Multiple single-column comparisons with constants that reference the same +/// column or subfield are combined into a single filter. Pre-existing entries +/// in 'filters' are preserved and combined with the ones extracted from the +/// 'expr'. +/// +/// NOT(x OR y) is converted to (NOT x) AND (NOT y). +/// +/// @param expr Boolean expression to break up. +/// @param evaluator Expression evaluator to use. +/// @param filters Mapping from a column or a subfield to comparison with +/// constant. +/// @param sampleRate Sample rate extracted from rand() < sampleRate conjuncts. +/// @return Expression with filters and rand() < sampleRate conjuncts removed. +/// +/// Examples: +/// expr := a = 1 AND b > 0 +/// filters := {a: eq(1), b: gt(0)} +// sampleRate left unmodified +// return value is nullptr +/// +/// expr: not (a > 0 or b > 10) +/// filters := {a: le(0), b: le(10)} +/// sampleRate left unmodified +/// return value is nullptr +/// +/// expr := a > 0 AND a < b AND rand() < 0.1 +/// filters := {a: gt(0)} +/// sampleRate := 0.1 +/// return value is a < b core::TypedExprPtr extractFiltersFromRemainingFilter( const core::TypedExprPtr& expr, core::ExpressionEvaluator* evaluator, - bool negated, common::SubfieldFilters& filters, double& sampleRate); +/// Determines whether a field referenced in the remaining filter should be +/// eagerly materialized (loaded upfront) or can be lazily loaded. +/// +/// Returns true (eager materialization needed) when: +/// 1. The remaining filter is NOT an AND expression (e.g., OR), because row +/// access patterns are unpredictable. +/// 2. The field is used within a conditional sub-expression (IF, CASE, nested +/// AND/OR) of an AND expression, because the conditional may access rows +/// unpredictably. +/// +/// Returns false (lazy loading OK) when the remaining filter is an AND +/// expression and the field is only used in simple, non-conditional conjuncts. +/// +/// @param remainingFilter The compiled remaining filter expression. +/// @param field The field reference to check. +/// @return true if the field should be eagerly materialized. +bool shouldEagerlyMaterialize( + const exec::Expr& remainingFilter, + const exec::FieldReference& field); + +/// Creates a point lookup filter from a variant value. +/// Null values are not allowed. +/// Supports TINYINT, SMALLINT, INTEGER, BIGINT, REAL, DOUBLE, BOOLEAN, +/// VARCHAR, and VARBINARY types. +/// @param type The type of the value. +/// @param value The filter value (must not be null). +/// @return A filter for point lookup, or nullptr if type is not supported. +std::unique_ptr createPointFilter( + const TypePtr& type, + const variant& value); + +/// Creates a range filter from two variant values. +/// Both lower and upper bounds are inclusive. Null values are not allowed. +/// Supports TINYINT, SMALLINT, INTEGER, BIGINT, REAL, DOUBLE, VARCHAR, +/// and VARBINARY types. +/// @param type The type of the values. +/// @param lower The lower bound value. +/// @param upper The upper bound value. +/// @return A filter for range lookup, or nullptr if type is not supported. +std::unique_ptr createRangeFilter( + const TypePtr& type, + const variant& lower, + const variant& upper); + } // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/HiveDataSink.cpp b/velox/connectors/hive/HiveDataSink.cpp index 36c797e4eaf..f627df624ed 100644 --- a/velox/connectors/hive/HiveDataSink.cpp +++ b/velox/connectors/hive/HiveDataSink.cpp @@ -19,21 +19,17 @@ #include "velox/common/base/Counters.h" #include "velox/common/base/Fs.h" #include "velox/common/base/StatsReporter.h" -#include "velox/common/testutil/TestValue.h" #include "velox/connectors/hive/HiveConfig.h" -#include "velox/connectors/hive/HiveConnectorUtil.h" #include "velox/connectors/hive/HivePartitionFunction.h" #include "velox/connectors/hive/TableHandle.h" #include "velox/dwio/common/Options.h" #include "velox/dwio/common/SortingWriter.h" -#include "velox/exec/OperatorUtils.h" #include "velox/exec/SortBuffer.h" #include #include #include - -using facebook::velox::common::testutil::TestValue; +#include namespace facebook::velox::connector::hive { namespace { @@ -41,72 +37,67 @@ namespace { memory::NonReclaimableSectionGuard nonReclaimableGuard( \ writerInfo_[(index)]->nonReclaimableSectionHolder.get()) -// Returns the type of non-partition data columns. -RowTypePtr getNonPartitionTypes( - const std::vector& dataCols, - const RowTypePtr& inputType) { - std::vector childNames; - std::vector childTypes; - const auto& dataSize = dataCols.size(); - childNames.reserve(dataSize); - childTypes.reserve(dataSize); - for (int dataCol : dataCols) { - childNames.push_back(inputType->nameOf(dataCol)); - childTypes.push_back(inputType->childAt(dataCol)); - } - - return ROW(std::move(childNames), std::move(childTypes)); -} - -// Filters out partition columns if there is any. -RowVectorPtr makeDataInput( - const std::vector& dataCols, - const RowVectorPtr& input) { - std::vector childVectors; - childVectors.reserve(dataCols.size()); - for (int dataCol : dataCols) { - childVectors.push_back(input->childAt(dataCol)); - } - - return std::make_shared( - input->pool(), - getNonPartitionTypes(dataCols, asRowType(input->type())), - input->nulls(), - input->size(), - std::move(childVectors), - input->getNullCount()); +// Appends a sequence number to a filename for file rotation. +// Returns the original filename if sequenceNumber is 0 (no rotation yet). +// Example: "file.orc" with seq 0 remains "file.orc" +// Example: "file.orc" with seq 2 becomes "file_2.orc" +std::string makeSequencedFileName( + const std::string& filename, + uint32_t sequenceNumber) { + if (sequenceNumber == 0) { + return filename; + } + const auto dotPos = filename.rfind('.'); + if (dotPos == std::string::npos) { + // No extension, just append the sequence number + return fmt::format("{}_{}", filename, sequenceNumber); + } + // Insert sequence number before the extension + return fmt::format( + "{}_{}{}", + filename.substr(0, dotPos), + sequenceNumber, + filename.substr(dotPos)); } -// Returns a subset of column indices corresponding to partition keys. -std::vector getPartitionChannels( - const std::shared_ptr& insertTableHandle) { - std::vector channels; - - for (column_index_t i = 0; i < insertTableHandle->inputColumns().size(); - i++) { - if (insertTableHandle->inputColumns()[i]->isPartitionKey()) { - channels.push_back(i); - } - } - - return channels; +std::unique_ptr createHiveFileSink( + const std::string& path, + const std::shared_ptr& hiveConfig, + memory::MemoryPool* sinkPool, + io::IoStatistics* ioStats, + IoStats* fileSystemStats, + const std::unordered_map& storageParameters) { + return dwio::common::FileSink::create( + path, + { + .bufferWrite = false, + .connectorProperties = hiveConfig->config(), + .fileCreateConfig = hiveConfig->writeFileCreateConfig(), + .pool = sinkPool, + .metricLogger = dwio::common::MetricsLog::voidLog(), + .stats = ioStats, + .fileSystemStats = fileSystemStats, + .storageParameters = storageParameters, + }); } -// Returns the column indices of non-partition data columns. -std::vector getNonPartitionChannels( - const std::vector& partitionChannels, - const column_index_t childrenSize) { - std::vector dataChannels; - dataChannels.reserve(childrenSize - partitionChannels.size()); - - for (column_index_t i = 0; i < childrenSize; i++) { - if (std::find(partitionChannels.cbegin(), partitionChannels.cend(), i) == - partitionChannels.cend()) { - dataChannels.push_back(i); - } +// Creates a PartitionIdGenerator if the table is partitioned, otherwise returns +// nullptr. +std::unique_ptr createPartitionIdGenerator( + const RowTypePtr& inputType, + const std::shared_ptr& insertTableHandle, + const std::shared_ptr& hiveConfig, + const ConnectorQueryCtx* connectorQueryCtx) { + auto partitionChannels = insertTableHandle->partitionChannels(); + if (partitionChannels.empty()) { + return nullptr; } - - return dataChannels; + return std::make_unique( + inputType, + partitionChannels, + hiveConfig->maxPartitionsPerWriters( + connectorQueryCtx->sessionProperties()), + connectorQueryCtx->memoryPool()); } std::string makePartitionDirectory( @@ -174,16 +165,6 @@ std::string computeBucketedFileName( "0{:0>{}}_0_{}", bucketValueStr, kMaxBucketCountPadding, queryId); } -std::shared_ptr createSinkPool( - const std::shared_ptr& writerPool) { - return writerPool->addLeafChild(fmt::format("{}.sink", writerPool->name())); -} - -std::shared_ptr createSortPool( - const std::shared_ptr& writerPool) { - return writerPool->addLeafChild(fmt::format("{}.sort", writerPool->name())); -} - uint64_t getFinishTimeSliceLimitMsFromHiveConfig( const std::shared_ptr& config, const config::ConfigBase* sessions) { @@ -200,32 +181,31 @@ FOLLY_ALWAYS_INLINE int32_t getBucketCount(const HiveBucketProperty* bucketProperty) { return bucketProperty == nullptr ? 0 : bucketProperty->bucketCount(); } -} // namespace - -const HiveWriterId& HiveWriterId::unpartitionedId() { - static const HiveWriterId writerId{0}; - return writerId; -} -std::string HiveWriterId::toString() const { - if (partitionId.has_value() && bucketId.has_value()) { - return fmt::format("part[{}.{}]", partitionId.value(), bucketId.value()); - } - - if (partitionId.has_value() && !bucketId.has_value()) { - return fmt::format("part[{}]", partitionId.value()); +std::vector computePartitionChannels( + const std::vector>& inputColumns) { + std::vector channels; + for (auto i = 0; i < inputColumns.size(); i++) { + if (inputColumns[i]->isPartitionKey()) { + channels.push_back(i); + } } + return channels; +} - // This WriterId is used to add an identifier in the MemoryPools. This could - // indicate unpart, but the bucket number needs to be disambiguated. So - // creating a new label using bucket. - if (!partitionId.has_value() && bucketId.has_value()) { - return fmt::format("bucket[{}]", bucketId.value()); +std::vector computeNonPartitionChannels( + const std::vector>& inputColumns) { + std::vector channels; + for (auto i = 0; i < inputColumns.size(); i++) { + if (!inputColumns[i]->isPartitionKey()) { + channels.push_back(i); + } } - - return "unpart"; + return channels; } +} // namespace + const std::string LocationHandle::tableTypeName( LocationHandle::TableType type) { static const auto tableTypes = tableTypeNames(); @@ -365,6 +345,54 @@ std::string HiveBucketProperty::toString() const { return out.str(); } +HiveInsertTableHandle::HiveInsertTableHandle( + std::vector> inputColumns, + std::shared_ptr locationHandle, + dwio::common::FileFormat storageFormat, + std::shared_ptr bucketProperty, + std::optional compressionKind, + const std::unordered_map& serdeParameters, + const std::shared_ptr& writerOptions, + // When this option is set the HiveDataSink will always write a file even + // if there's no data. This is useful when the table is bucketed, but the + // engine handles ensuring a 1 to 1 mapping from task to bucket. + const bool ensureFiles, + std::shared_ptr fileNameGenerator, + const std::unordered_map& storageParameters) + : inputColumns_(std::move(inputColumns)), + locationHandle_(std::move(locationHandle)), + storageFormat_(storageFormat), + bucketProperty_(std::move(bucketProperty)), + compressionKind_(compressionKind), + serdeParameters_(serdeParameters), + writerOptions_(writerOptions), + ensureFiles_(ensureFiles), + fileNameGenerator_(std::move(fileNameGenerator)), + storageParameters_(storageParameters), + partitionChannels_(computePartitionChannels(inputColumns_)), + nonPartitionChannels_(computeNonPartitionChannels(inputColumns_)) { + if (compressionKind.has_value()) { + VELOX_CHECK( + compressionKind.value() != common::CompressionKind_MAX, + "Unsupported compression type: CompressionKind_MAX"); + } + + if (ensureFiles_) { + // If ensureFiles is set and either the bucketProperty is set or some + // partition keys are in the data, there is not a 1:1 mapping from Task to + // files so we can't proactively create writers. + VELOX_CHECK( + bucketProperty_ == nullptr || bucketProperty_->bucketCount() == 0, + "ensureFiles is not supported with bucketing"); + + for (const auto& inputColumn : inputColumns_) { + VELOX_CHECK( + !inputColumn->isPartitionKey(), + "ensureFiles is not supported with partition keys in the data"); + } + } +} + HiveDataSink::HiveDataSink( RowTypePtr inputType, std::shared_ptr insertTableHandle, @@ -382,7 +410,14 @@ HiveDataSink::HiveDataSink( ? createBucketFunction( *insertTableHandle->bucketProperty(), inputType) - : nullptr) {} + : nullptr, + insertTableHandle->partitionChannels(), + insertTableHandle->nonPartitionChannels(), + createPartitionIdGenerator( + inputType, + insertTableHandle, + hiveConfig, + connectorQueryCtx)) {} HiveDataSink::HiveDataSink( RowTypePtr inputType, @@ -391,38 +426,35 @@ HiveDataSink::HiveDataSink( CommitStrategy commitStrategy, const std::shared_ptr& hiveConfig, uint32_t bucketCount, - std::unique_ptr bucketFunction) - : inputType_(std::move(inputType)), + std::unique_ptr bucketFunction, + const std::vector& partitionChannels, + const std::vector& dataChannels, + std::unique_ptr partitionIdGenerator) + : FileDataSink( + std::move(inputType), + connectorQueryCtx, + commitStrategy, + insertTableHandle->storageFormat(), + hiveConfig->maxPartitionsPerWriters( + connectorQueryCtx->sessionProperties()), + partitionChannels, + dataChannels, + static_cast(bucketCount), + std::move(bucketFunction), + std::move(partitionIdGenerator), + dwio::common::getWriterFactory(insertTableHandle->storageFormat()), + hiveConfig->maxTargetFileSizeBytes( + connectorQueryCtx->sessionProperties()), + hiveConfig->isPartitionPathAsLowerCase( + connectorQueryCtx->sessionProperties()), + connectorQueryCtx->spillConfig(), + getFinishTimeSliceLimitMsFromHiveConfig( + hiveConfig, + connectorQueryCtx->sessionProperties())), insertTableHandle_(std::move(insertTableHandle)), - connectorQueryCtx_(connectorQueryCtx), - commitStrategy_(commitStrategy), hiveConfig_(hiveConfig), updateMode_(getUpdateMode()), - maxOpenWriters_(hiveConfig_->maxPartitionsPerWriters( - connectorQueryCtx->sessionProperties())), - partitionChannels_(getPartitionChannels(insertTableHandle_)), - partitionIdGenerator_( - !partitionChannels_.empty() - ? std::make_unique( - inputType_, - partitionChannels_, - maxOpenWriters_, - connectorQueryCtx_->memoryPool(), - hiveConfig_->isPartitionPathAsLowerCase( - connectorQueryCtx->sessionProperties())) - : nullptr), - dataChannels_( - getNonPartitionChannels(partitionChannels_, inputType_->size())), - bucketCount_(static_cast(bucketCount)), - bucketFunction_(std::move(bucketFunction)), - writerFactory_( - dwio::common::getWriterFactory(insertTableHandle_->storageFormat())), - spillConfig_(connectorQueryCtx->spillConfig()), - sortWriterFinishTimeSliceLimitMs_(getFinishTimeSliceLimitMsFromHiveConfig( - hiveConfig_, - connectorQueryCtx->sessionProperties())), fileNameGenerator_(insertTableHandle_->fileNameGenerator()) { - fileSystemStats_ = std::make_unique(); if (isBucketed()) { VELOX_USER_CHECK_LT( bucketCount_, @@ -439,7 +471,7 @@ HiveDataSink::HiveDataSink( VELOX_CHECK( !isPartitioned() && !isBucketed(), "ensureFiles is not supported with bucketing or partition keys in the data"); - ensureWriter(HiveWriterId::unpartitionedId()); + ensureWriter(WriterId::unpartitionedId()); } if (!isBucketed()) { @@ -462,6 +494,7 @@ HiveDataSink::HiveDataSink( CompareFlags::NullHandlingMode::kNullAsValue}); } } + sortWrite_ = !sortColumnIndices_.empty(); } } @@ -472,69 +505,6 @@ bool HiveDataSink::canReclaim() const { insertTableHandle_->storageFormat() == dwio::common::FileFormat::NIMBLE); } -void HiveDataSink::appendData(RowVectorPtr input) { - checkRunning(); - - // Lazy load all the input columns. - input->loadedVector(); - - // Write to unpartitioned (and unbucketed) table. - if (!isPartitioned() && !isBucketed()) { - const auto index = ensureWriter(HiveWriterId::unpartitionedId()); - write(index, input); - return; - } - - // Compute partition and bucket numbers. - computePartitionAndBucketIds(input); - - // All inputs belong to a single non-bucketed partition. The partition id - // must be zero. - if (!isBucketed() && partitionIdGenerator_->numPartitions() == 1) { - const auto index = ensureWriter(HiveWriterId{0}); - write(index, input); - return; - } - - splitInputRowsAndEnsureWriters(); - - for (auto index = 0; index < writers_.size(); ++index) { - const vector_size_t partitionSize = partitionSizes_[index]; - if (partitionSize == 0) { - continue; - } - - RowVectorPtr writerInput = partitionSize == input->size() - ? input - : exec::wrap(partitionSize, partitionRows_[index], input); - write(index, writerInput); - } -} - -void HiveDataSink::write(size_t index, RowVectorPtr input) { - WRITER_NON_RECLAIMABLE_SECTION_GUARD(index); - auto dataInput = makeDataInput(dataChannels_, input); - - writers_[index]->write(dataInput); - writerInfo_[index]->inputSizeInBytes += dataInput->estimateFlatSize(); - writerInfo_[index]->numWrittenRows += dataInput->size(); -} - -std::string HiveDataSink::stateString(State state) { - switch (state) { - case State::kRunning: - return "RUNNING"; - case State::kFinishing: - return "FLUSHING"; - case State::kClosed: - return "CLOSED"; - case State::kAborted: - return "ABORTED"; - default: - VELOX_UNREACHABLE("BAD STATE: {}", static_cast(state)); - } -} - void HiveDataSink::computePartitionAndBucketIds(const RowVectorPtr& input) { VELOX_CHECK(isPartitioned() || isBucketed()); if (isPartitioned()) { @@ -561,59 +531,8 @@ void HiveDataSink::computePartitionAndBucketIds(const RowVectorPtr& input) { } } -DataSink::Stats HiveDataSink::stats() const { - Stats stats; - if (state_ == State::kAborted) { - return stats; - } - - int64_t numWrittenBytes{0}; - int64_t writeIOTimeUs{0}; - for (const auto& ioStats : ioStats_) { - numWrittenBytes += ioStats->rawBytesWritten(); - writeIOTimeUs += ioStats->writeIOTimeUs(); - } - stats.numWrittenBytes = numWrittenBytes; - stats.writeIOTimeUs = writeIOTimeUs; - - if (state_ != State::kClosed) { - return stats; - } - - stats.numWrittenFiles = writers_.size(); - for (int i = 0; i < writerInfo_.size(); ++i) { - const auto& info = writerInfo_.at(i); - VELOX_CHECK_NOT_NULL(info); - const auto spillStats = info->spillStats->rlock(); - if (!spillStats->empty()) { - stats.spillStats += *spillStats; - } - } - return stats; -} - -std::unordered_map HiveDataSink::runtimeStats() - const { - std::unordered_map runtimeStats; - - const auto fsStatsMap = fileSystemStats_->stats(); - for (const auto& [statName, statValue] : fsStatsMap) { - runtimeStats.emplace( - statName, RuntimeCounter(statValue.sum, statValue.unit)); - } - - return runtimeStats; -} - -std::shared_ptr HiveDataSink::createWriterPool( - const HiveWriterId& writerId) { - auto* connectorPool = connectorQueryCtx_->connectorMemoryPool(); - return connectorPool->addAggregateChild( - fmt::format("{}.{}", connectorPool->name(), writerId.toString())); -} - void HiveDataSink::setMemoryReclaimers( - HiveWriterInfo* writerInfo, + WriterInfo* writerInfo, io::IoStatistics* ioStats) { auto* connectorPool = connectorQueryCtx_->connectorMemoryPool(); if (connectorPool->reclaimer() == nullptr) { @@ -626,157 +545,51 @@ void HiveDataSink::setMemoryReclaimers( // writer. } -void HiveDataSink::setState(State newState) { - checkStateTransition(state_, newState); - state_ = newState; -} - -/// Validates the state transition from 'oldState' to 'newState'. -void HiveDataSink::checkStateTransition(State oldState, State newState) { - switch (oldState) { - case State::kRunning: - if (newState == State::kAborted || newState == State::kFinishing) { - return; - } - break; - case State::kFinishing: - if (newState == State::kAborted || newState == State::kClosed || - // The finishing state is reentry state if we yield in the middle of - // finish processing if a single run takes too long. - newState == State::kFinishing) { - return; - } - [[fallthrough]]; - case State::kAborted: - case State::kClosed: - default: - break; - } - VELOX_FAIL("Unexpected state transition from {} to {}", oldState, newState); -} - -bool HiveDataSink::finish() { - // Flush is reentry state. - setState(State::kFinishing); - - // As for now, only sorted writer needs flush buffered data. For non-sorted - // writer, data is directly written to the underlying file writer. - if (!sortWrite()) { - return true; - } - - // TODO: we might refactor to move the data sorting logic into hive data sink. - const uint64_t startTimeMs = getCurrentTimeMs(); - for (auto i = 0; i < writers_.size(); ++i) { - WRITER_NON_RECLAIMABLE_SECTION_GUARD(i); - if (!writers_[i]->finish()) { - return false; - } - if (getCurrentTimeMs() - startTimeMs > sortWriterFinishTimeSliceLimitMs_) { - return false; - } - } - return true; -} - -std::vector HiveDataSink::close() { - setState(State::kClosed); - closeInternal(); - +std::vector HiveDataSink::commitMessage() const { std::vector partitionUpdates; partitionUpdates.reserve(writerInfo_.size()); for (int i = 0; i < writerInfo_.size(); ++i) { const auto& info = writerInfo_.at(i); VELOX_CHECK_NOT_NULL(info); + + folly::dynamic fileWriteInfosArray = folly::dynamic::array; + for (const auto& fileInfo : info->writtenFiles) { + fileWriteInfosArray.push_back( + folly::dynamic::object( + HiveCommitMessage::kWriteFileName, fileInfo.writeFileName)( + HiveCommitMessage::kTargetFileName, fileInfo.targetFileName)( + HiveCommitMessage::kFileSize, fileInfo.fileSize)); + } + // clang-format off auto partitionUpdateJson = folly::toJson( folly::dynamic::object - ("name", info->writerParameters.partitionName().value_or("")) - ("updateMode", - HiveWriterParameters::updateModeToString( + (HiveCommitMessage::kName, info->writerParameters.partitionName().value_or("")) + (HiveCommitMessage::kUpdateMode, + WriterParameters::updateModeToString( info->writerParameters.updateMode())) - ("writePath", info->writerParameters.writeDirectory()) - ("targetPath", info->writerParameters.targetDirectory()) - ("fileWriteInfos", folly::dynamic::array( - folly::dynamic::object - ("writeFileName", info->writerParameters.writeFileName()) - ("targetFileName", info->writerParameters.targetFileName()) - ("fileSize", ioStats_.at(i)->rawBytesWritten()))) - ("rowCount", info->numWrittenRows) - ("inMemoryDataSizeInBytes", info->inputSizeInBytes) - ("onDiskDataSizeInBytes", ioStats_.at(i)->rawBytesWritten()) - ("containsNumberedFileNames", true)); + (HiveCommitMessage::kWritePath, info->writerParameters.writeDirectory()) + (HiveCommitMessage::kTargetPath, info->writerParameters.targetDirectory()) + (HiveCommitMessage::kFileWriteInfos, std::move(fileWriteInfosArray)) + (HiveCommitMessage::kRowCount, info->numWrittenRows) + (HiveCommitMessage::kInMemoryDataSizeInBytes, info->inputSizeInBytes) + (HiveCommitMessage::kOnDiskDataSizeInBytes, ioStats_.at(i)->rawBytesWritten()) + (HiveCommitMessage::kContainsNumberedFileNames, true)); // clang-format on partitionUpdates.push_back(partitionUpdateJson); } return partitionUpdates; } -void HiveDataSink::abort() { - setState(State::kAborted); - closeInternal(); -} - -void HiveDataSink::closeInternal() { - VELOX_CHECK_NE(state_, State::kRunning); - VELOX_CHECK_NE(state_, State::kFinishing); - - TestValue::adjust( - "facebook::velox::connector::hive::HiveDataSink::closeInternal", this); - - if (state_ == State::kClosed) { - for (int i = 0; i < writers_.size(); ++i) { - WRITER_NON_RECLAIMABLE_SECTION_GUARD(i); - writers_[i]->close(); - } - } else { - for (int i = 0; i < writers_.size(); ++i) { - WRITER_NON_RECLAIMABLE_SECTION_GUARD(i); - writers_[i]->abort(); - } - } -} - -uint32_t HiveDataSink::ensureWriter(const HiveWriterId& id) { - auto it = writerIndexMap_.find(id); - if (it != writerIndexMap_.end()) { - return it->second; - } - return appendWriter(id); +std::shared_ptr HiveDataSink::createWriterOptions() + const { + // Default: use the last writer's info (for appendWriter which just added it) + return createWriterOptions(writerInfo_.size() - 1); } -uint32_t HiveDataSink::appendWriter(const HiveWriterId& id) { - // Check max open writers. - VELOX_USER_CHECK_LE( - writers_.size(), maxOpenWriters_, "Exceeded open writer limit"); - VELOX_CHECK_EQ(writers_.size(), writerInfo_.size()); - VELOX_CHECK_EQ(writerIndexMap_.size(), writerInfo_.size()); - - std::optional partitionName; - if (isPartitioned()) { - partitionName = - partitionIdGenerator_->partitionName(id.partitionId.value()); - } - - // Without explicitly setting flush policy, the default memory based flush - // policy is used. - auto writerParameters = getWriterParameters(partitionName, id.bucketId); - const auto writePath = fs::path(writerParameters.writeDirectory()) / - writerParameters.writeFileName(); - auto writerPool = createWriterPool(id); - auto sinkPool = createSinkPool(writerPool); - std::shared_ptr sortPool{nullptr}; - if (sortWrite()) { - sortPool = createSortPool(writerPool); - } - writerInfo_.emplace_back(std::make_shared( - std::move(writerParameters), - std::move(writerPool), - std::move(sinkPool), - std::move(sortPool))); - ioStats_.emplace_back(std::make_unique()); - - setMemoryReclaimers(writerInfo_.back().get(), ioStats_.back().get()); +std::shared_ptr HiveDataSink::createWriterOptions( + size_t writerIndex) const { + VELOX_CHECK_LT(writerIndex, writerInfo_.size()); // Take the writer options provided by the user as a starting point, or // allocate a new one. @@ -793,10 +606,6 @@ uint32_t HiveDataSink::appendWriter(const HiveWriterId& id) { options->schema = getNonPartitionTypes(dataChannels_, inputType_); } - if (options->memoryPool == nullptr) { - options->memoryPool = writerInfo_.back()->writerPool.get(); - } - if (!options->compressionKind) { options->compressionKind = insertTableHandle_->compressionKind(); } @@ -805,10 +614,12 @@ uint32_t HiveDataSink::appendWriter(const HiveWriterId& id) { options->spillConfig = spillConfig_; } - if (options->nonReclaimableSection == nullptr) { - options->nonReclaimableSection = - writerInfo_.back()->nonReclaimableSectionHolder.get(); - } + // Always set per-writer options to the current writer's values. + // Since insertTableHandle_->writerOptions() returns a shared_ptr, each writer + // needs its own memory pool and nonReclaimableSection pointer. + options->memoryPool = writerInfo_[writerIndex]->writerPool.get(); + options->nonReclaimableSection = + writerInfo_[writerIndex]->nonReclaimableSectionHolder.get(); if (options->memoryReclaimerFactory == nullptr || options->memoryReclaimerFactory() == nullptr) { @@ -827,50 +638,70 @@ uint32_t HiveDataSink::appendWriter(const HiveWriterId& id) { options->adjustTimestampToTimezone = connectorQueryCtx_->adjustTimestampToTimezone(); options->processConfigs(*hiveConfig_->config(), *connectorSessionProperties); + return options; +} + +std::unique_ptr HiveDataSink::createWriterForIndex( + size_t writerIndex) { + VELOX_CHECK_LT(writerIndex, writerInfo_.size()); + VELOX_CHECK_LT(writerIndex, ioStats_.size()); + + auto& info = writerInfo_[writerIndex]; + const auto& params = info->writerParameters; + + // Compute and store the new file names. + info->currentWriteFileName = + makeSequencedFileName(params.writeFileName(), info->fileSequenceNumber); + info->currentTargetFileName = + makeSequencedFileName(params.targetFileName(), info->fileSequenceNumber); + + const auto writePath = + (fs::path(params.writeDirectory()) / info->currentWriteFileName).string(); + + auto options = createWriterOptions(writerIndex); // Prevents the memory allocation during the writer creation. - WRITER_NON_RECLAIMABLE_SECTION_GUARD(writerInfo_.size() - 1); + WRITER_NON_RECLAIMABLE_SECTION_GUARD(writerIndex); auto writer = writerFactory_->createWriter( - dwio::common::FileSink::create( + createHiveFileSink( writePath, - { - .bufferWrite = false, - .connectorProperties = hiveConfig_->config(), - .fileCreateConfig = hiveConfig_->writeFileCreateConfig(), - .pool = writerInfo_.back()->sinkPool.get(), - .metricLogger = dwio::common::MetricsLog::voidLog(), - .stats = ioStats_.back().get(), - .fileSystemStats = fileSystemStats_.get(), - }), + hiveConfig_, + info->sinkPool.get(), + ioStats_[writerIndex].get(), + fileSystemStats_.get(), + insertTableHandle_->storageParameters()), options); - writer = maybeCreateBucketSortWriter(std::move(writer)); - writers_.emplace_back(std::move(writer)); - // Extends the buffer used for partition rows calculations. - partitionSizes_.emplace_back(0); - partitionRows_.emplace_back(nullptr); - rawPartitionRows_.emplace_back(nullptr); - - writerIndexMap_.emplace(id, writers_.size() - 1); - return writerIndexMap_[id]; + return maybeCreateBucketSortWriter(writerIndex, std::move(writer)); +} + +std::string HiveDataSink::getPartitionName(uint32_t partitionId) const { + VELOX_CHECK_NOT_NULL(partitionIdGenerator_); + + return HivePartitionName::partitionName( + partitionId, + partitionIdGenerator_->partitionValues(), + partitionKeyAsLowerCase_); } std::unique_ptr HiveDataSink::maybeCreateBucketSortWriter( + size_t writerIndex, std::unique_ptr writer) { if (!sortWrite()) { return writer; } - auto* sortPool = writerInfo_.back()->sortPool.get(); + auto* sortPool = writerInfo_[writerIndex]->sortPool.get(); VELOX_CHECK_NOT_NULL(sortPool); auto sortBuffer = std::make_unique( getNonPartitionTypes(dataChannels_, inputType_), sortColumnIndices_, sortCompareFlags_, sortPool, - writerInfo_.back()->nonReclaimableSectionHolder.get(), + writerInfo_[writerIndex]->nonReclaimableSectionHolder.get(), connectorQueryCtx_->prefixSortConfig(), spillConfig_, - writerInfo_.back()->spillStats.get()); + writerInfo_[writerIndex]->spillStats.get()); + return std::make_unique( std::move(writer), std::move(sortBuffer), @@ -881,62 +712,12 @@ HiveDataSink::maybeCreateBucketSortWriter( sortWriterFinishTimeSliceLimitMs_); } -HiveWriterId HiveDataSink::getWriterId(size_t row) const { - std::optional partitionId; - if (isPartitioned()) { - VELOX_CHECK_LT(partitionIds_[row], std::numeric_limits::max()); - partitionId = static_cast(partitionIds_[row]); - } - - std::optional bucketId; - if (isBucketed()) { - bucketId = bucketIds_[row]; - } - return HiveWriterId{partitionId, bucketId}; -} - -void HiveDataSink::splitInputRowsAndEnsureWriters() { - VELOX_CHECK(isPartitioned() || isBucketed()); - if (isBucketed() && isPartitioned()) { - VELOX_CHECK_EQ(bucketIds_.size(), partitionIds_.size()); - } - - std::fill(partitionSizes_.begin(), partitionSizes_.end(), 0); - - const auto numRows = - isPartitioned() ? partitionIds_.size() : bucketIds_.size(); - for (auto row = 0; row < numRows; ++row) { - const auto id = getWriterId(row); - const uint32_t index = ensureWriter(id); - - VELOX_DCHECK_LT(index, partitionSizes_.size()); - VELOX_DCHECK_EQ(partitionSizes_.size(), partitionRows_.size()); - VELOX_DCHECK_EQ(partitionRows_.size(), rawPartitionRows_.size()); - if (FOLLY_UNLIKELY(partitionRows_[index] == nullptr) || - (partitionRows_[index]->capacity() < numRows * sizeof(vector_size_t))) { - partitionRows_[index] = - allocateIndices(numRows, connectorQueryCtx_->memoryPool()); - rawPartitionRows_[index] = - partitionRows_[index]->asMutable(); - } - rawPartitionRows_[index][partitionSizes_[index]] = row; - ++partitionSizes_[index]; - } - - for (uint32_t i = 0; i < partitionSizes_.size(); ++i) { - if (partitionSizes_[i] != 0) { - VELOX_CHECK_NOT_NULL(partitionRows_[i]); - partitionRows_[i]->setSize(partitionSizes_[i] * sizeof(vector_size_t)); - } - } -} - -HiveWriterParameters HiveDataSink::getWriterParameters( +WriterParameters HiveDataSink::getWriterParameters( const std::optional& partition, std::optional bucketId) const { auto [targetFileName, writeFileName] = getWriterFileNames(bucketId); - return HiveWriterParameters{ + return WriterParameters{ updateMode_, partition, targetFileName, @@ -996,6 +777,8 @@ std::pair HiveInsertFileNameGenerator::gen( connectorQueryCtx.queryId(), hiveConfig->maxBucketCount(connectorQueryCtx.sessionProperties()), bucketId.value()); + // queryId may contain unsafe characters. + sanitizeFileName(targetFileName); } else if (generateFileName) { // targetFileName includes planNodeId and Uuid. As a result, different // table writers run by the same task driver or the same table writer @@ -1006,7 +789,10 @@ std::pair HiveInsertFileNameGenerator::gen( connectorQueryCtx.driverId(), connectorQueryCtx.planNodeId(), makeUuid()); + // taskId, planNodeId may contain unsafe characters. + sanitizeFileName(targetFileName); } + // do not try to sanitize user provided targetFileName VELOX_CHECK(!targetFileName.empty()); const std::string writeFileName = commitRequired ? fmt::format(".tmp.velox.{}_{}", targetFileName, makeUuid()) @@ -1020,6 +806,11 @@ std::pair HiveInsertFileNameGenerator::gen( return {targetFileName, writeFileName}; } +void HiveInsertFileNameGenerator::sanitizeFileName(std::string& name) { + static const re2::RE2 re("[^a-zA-Z0-9._-]"); + re2::RE2::GlobalReplace(&name, re, "_"); +} + folly::dynamic HiveInsertFileNameGenerator::serialize() const { folly::dynamic obj = folly::dynamic::object; obj["name"] = "HiveInsertFileNameGenerator"; @@ -1043,16 +834,16 @@ std::string HiveInsertFileNameGenerator::toString() const { return "HiveInsertFileNameGenerator"; } -HiveWriterParameters::UpdateMode HiveDataSink::getUpdateMode() const { +WriterParameters::UpdateMode HiveDataSink::getUpdateMode() const { if (insertTableHandle_->isExistingTable()) { if (insertTableHandle_->isPartitioned()) { const auto insertBehavior = hiveConfig_->insertExistingPartitionsBehavior( connectorQueryCtx_->sessionProperties()); switch (insertBehavior) { case HiveConfig::InsertExistingPartitionsBehavior::kOverwrite: - return HiveWriterParameters::UpdateMode::kOverwrite; + return WriterParameters::UpdateMode::kOverwrite; case HiveConfig::InsertExistingPartitionsBehavior::kError: - return HiveWriterParameters::UpdateMode::kNew; + return WriterParameters::UpdateMode::kNew; default: VELOX_UNSUPPORTED( "Unsupported insert existing partitions behavior: {}", @@ -1063,10 +854,10 @@ HiveWriterParameters::UpdateMode HiveDataSink::getUpdateMode() const { if (hiveConfig_->immutablePartitions()) { VELOX_USER_FAIL("Unpartitioned Hive tables are immutable."); } - return HiveWriterParameters::UpdateMode::kAppend; + return WriterParameters::UpdateMode::kAppend; } } else { - return HiveWriterParameters::UpdateMode::kNew; + return WriterParameters::UpdateMode::kNew; } } @@ -1114,6 +905,13 @@ folly::dynamic HiveInsertTableHandle::serialize() const { params[key] = value; } obj["serdeParameters"] = params; + + folly::dynamic storageParams = folly::dynamic::object; + for (const auto& [key, value] : storageParameters_) { + storageParams[key] = value; + } + obj["storageParameters"] = storageParams; + obj["ensureFiles"] = ensureFiles_; obj["fileNameGenerator"] = fileNameGenerator_->serialize(); return obj; @@ -1145,6 +943,13 @@ HiveInsertTableHandlePtr HiveInsertTableHandle::create( serdeParameters.emplace(pair.first.asString(), pair.second.asString()); } + std::unordered_map storageParameters; + if (obj.count("storageParameters") > 0) { + for (const auto& pair : obj["storageParameters"].items()) { + storageParameters.emplace(pair.first.asString(), pair.second.asString()); + } + } + bool ensureFiles = obj["ensureFiles"].asBool(); auto fileNameGenerator = @@ -1158,7 +963,8 @@ HiveInsertTableHandlePtr HiveInsertTableHandle::create( serdeParameters, nullptr, // writerOptions is not serializable ensureFiles, - fileNameGenerator); + fileNameGenerator, + storageParameters); } void HiveInsertTableHandle::registerSerDe() { @@ -1231,7 +1037,7 @@ LocationHandlePtr LocationHandle::create(const folly::dynamic& obj) { std::unique_ptr HiveDataSink::WriterReclaimer::create( HiveDataSink* dataSink, - HiveWriterInfo* writerInfo, + WriterInfo* writerInfo, io::IoStatistics* ioStats) { return std::unique_ptr( new HiveDataSink::WriterReclaimer(dataSink, writerInfo, ioStats)); @@ -1261,9 +1067,13 @@ uint64_t HiveDataSink::WriterReclaimer::reclaim( if (*writerInfo_->nonReclaimableSectionHolder.get()) { RECORD_METRIC_VALUE(kMetricMemoryNonReclaimableCount); LOG(WARNING) << "Can't reclaim from hive writer pool " << pool->name() - << " which is under non-reclaimable section, " - << " reserved memory: " - << succinctBytes(pool->reservedBytes()); + << " which is under non-reclaimable section, root pool: " + << pool->root()->name() + << ", state: " << stateString(dataSink_->state_) + << ", used: " << succinctBytes(pool->usedBytes()) + << ", reservation: " << succinctBytes(pool->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool->root()->reservedBytes()); ++stats.numNonReclaimableAttempts; return 0; } @@ -1277,7 +1087,8 @@ uint64_t HiveDataSink::WriterReclaimer::reclaim( ioStats_->rawBytesWritten() - writtenBytesBeforeReclaim; addThreadLocalRuntimeStat( kEarlyFlushedRawBytes, - RuntimeCounter(earlyFlushedRawBytes, RuntimeCounter::Unit::kBytes)); + RuntimeCounter( + saturateCast(earlyFlushedRawBytes), RuntimeCounter::Unit::kBytes)); if (earlyFlushedRawBytes > 0) { RECORD_METRIC_VALUE( kMetricFileWriterEarlyFlushedRawBytes, earlyFlushedRawBytes); diff --git a/velox/connectors/hive/HiveDataSink.h b/velox/connectors/hive/HiveDataSink.h index 8c305b7595d..8eb9ec04f0f 100644 --- a/velox/connectors/hive/HiveDataSink.h +++ b/velox/connectors/hive/HiveDataSink.h @@ -17,7 +17,9 @@ #include "velox/common/compression/Compression.h" #include "velox/connectors/Connector.h" +#include "velox/connectors/hive/FileDataSink.h" #include "velox/connectors/hive/HiveConfig.h" +#include "velox/connectors/hive/HivePartitionName.h" #include "velox/connectors/hive/PartitionIdGenerator.h" #include "velox/connectors/hive/TableHandle.h" #include "velox/dwio/common/Options.h" @@ -231,6 +233,9 @@ class HiveInsertFileNameGenerator : public FileNameGenerator { void* context); std::string toString() const override; + + /// Replaces potentially unsafe characters in a file name with underscores + static void sanitizeFileName(std::string& name); }; /// Represents a request for Hive write. @@ -250,37 +255,9 @@ class HiveInsertTableHandle : public ConnectorInsertTableHandle { // engine handles ensuring a 1 to 1 mapping from task to bucket. const bool ensureFiles = false, std::shared_ptr fileNameGenerator = - std::make_shared()) - : inputColumns_(std::move(inputColumns)), - locationHandle_(std::move(locationHandle)), - storageFormat_(storageFormat), - bucketProperty_(std::move(bucketProperty)), - compressionKind_(compressionKind), - serdeParameters_(serdeParameters), - writerOptions_(writerOptions), - ensureFiles_(ensureFiles), - fileNameGenerator_(std::move(fileNameGenerator)) { - if (compressionKind.has_value()) { - VELOX_CHECK( - compressionKind.value() != common::CompressionKind_MAX, - "Unsupported compression type: CompressionKind_MAX"); - } - - if (ensureFiles_) { - // If ensureFiles is set and either the bucketProperty is set or some - // partition keys are in the data, there is not a 1:1 mapping from Task to - // files so we can't proactively create writers. - VELOX_CHECK( - bucketProperty_ == nullptr || bucketProperty_->bucketCount() == 0, - "ensureFiles is not supported with bucketing"); - - for (const auto& inputColumn : inputColumns_) { - VELOX_CHECK( - !inputColumn->isPartitionKey(), - "ensureFiles is not supported with partition keys in the data"); - } - } - } + std::make_shared(), + const std::unordered_map& storageParameters = + {}); virtual ~HiveInsertTableHandle() = default; @@ -301,10 +278,19 @@ class HiveInsertTableHandle : public ConnectorInsertTableHandle { return storageFormat_; } + /// Format specific options. const std::unordered_map& serdeParameters() const { return serdeParameters_; } + /// Storage specific options. + const std::unordered_map& storageParameters() + const { + return storageParameters_; + } + + /// Avoid this in future usages. Format specific change should go through + /// serdeParameters. const std::shared_ptr& writerOptions() const { return writerOptions_; } @@ -329,6 +315,16 @@ class HiveInsertTableHandle : public ConnectorInsertTableHandle { bool isExistingTable() const; + /// Returns a subset of column indices corresponding to partition keys. + const std::vector& partitionChannels() const { + return partitionChannels_; + } + + /// Returns the column indices of non-partition data columns. + const std::vector& nonPartitionChannels() const { + return nonPartitionChannels_; + } + folly::dynamic serialize() const override; static HiveInsertTableHandlePtr create(const folly::dynamic& obj); @@ -337,9 +333,11 @@ class HiveInsertTableHandle : public ConnectorInsertTableHandle { std::string toString() const override; - private: + protected: const std::vector> inputColumns_; const std::shared_ptr locationHandle_; + + private: const dwio::common::FileFormat storageFormat_; const std::shared_ptr bucketProperty_; const std::optional compressionKind_; @@ -347,173 +345,82 @@ class HiveInsertTableHandle : public ConnectorInsertTableHandle { const std::shared_ptr writerOptions_; const bool ensureFiles_; const std::shared_ptr fileNameGenerator_; + const std::unordered_map storageParameters_; + const std::vector partitionChannels_; + const std::vector nonPartitionChannels_; }; -/// Parameters for Hive writers. -class HiveWriterParameters { - public: - enum class UpdateMode { - kNew, // Write files to a new directory. - kOverwrite, // Overwrite an existing directory. - // Append mode is currently only supported for unpartitioned tables. - kAppend, // Append to an unpartitioned table. - }; - - /// @param updateMode Write the files to a new directory, or append to an - /// existing directory or overwrite an existing directory. - /// @param partitionName Partition name in the typical Hive style, which is - /// also the partition subdirectory part of the partition path. - /// @param targetFileName The final name of a file after committing. - /// @param targetDirectory The final directory that a file should be in after - /// committing. - /// @param writeFileName The temporary name of the file that a running writer - /// writes to. If a running writer writes directory to the target file, set - /// writeFileName to targetFileName by default. - /// @param writeDirectory The temporary directory that a running writer writes - /// to. If a running writer writes directory to the target directory, set - /// writeDirectory to targetDirectory by default. - HiveWriterParameters( - UpdateMode updateMode, - std::optional partitionName, - std::string targetFileName, - std::string targetDirectory, - std::optional writeFileName = std::nullopt, - std::optional writeDirectory = std::nullopt) - : updateMode_(updateMode), - partitionName_(std::move(partitionName)), - targetFileName_(std::move(targetFileName)), - targetDirectory_(std::move(targetDirectory)), - writeFileName_(writeFileName.value_or(targetFileName_)), - writeDirectory_(writeDirectory.value_or(targetDirectory_)) {} - - UpdateMode updateMode() const { - return updateMode_; - } - - static std::string updateModeToString(UpdateMode updateMode) { - switch (updateMode) { - case UpdateMode::kNew: - return "NEW"; - case UpdateMode::kOverwrite: - return "OVERWRITE"; - case UpdateMode::kAppend: - return "APPEND"; - default: - VELOX_UNSUPPORTED("Unsupported update mode."); - } - } - - const std::optional& partitionName() const { - return partitionName_; - } - - const std::string& targetFileName() const { - return targetFileName_; - } - - const std::string& writeFileName() const { - return writeFileName_; - } - - const std::string& targetDirectory() const { - return targetDirectory_; - } - - const std::string& writeDirectory() const { - return writeDirectory_; - } - - private: - const UpdateMode updateMode_; - const std::optional partitionName_; - const std::string targetFileName_; - const std::string targetDirectory_; - const std::string writeFileName_; - const std::string writeDirectory_; -}; - -struct HiveWriterInfo { - HiveWriterInfo( - HiveWriterParameters parameters, - std::shared_ptr _writerPool, - std::shared_ptr _sinkPool, - std::shared_ptr _sortPool) - : writerParameters(std::move(parameters)), - nonReclaimableSectionHolder(new tsan_atomic(false)), - spillStats(std::make_unique>()), - writerPool(std::move(_writerPool)), - sinkPool(std::move(_sinkPool)), - sortPool(std::move(_sortPool)) {} - - const HiveWriterParameters writerParameters; - const std::unique_ptr> nonReclaimableSectionHolder; - /// Collects the spill stats from sort writer if the spilling has been - /// triggered. - const std::unique_ptr> spillStats; - const std::shared_ptr writerPool; - const std::shared_ptr sinkPool; - const std::shared_ptr sortPool; - int64_t numWrittenRows = 0; - int64_t inputSizeInBytes = 0; -}; - -/// Identifies a hive writer. -struct HiveWriterId { - std::optional partitionId{std::nullopt}; - std::optional bucketId{std::nullopt}; - - HiveWriterId() = default; - - HiveWriterId( - std::optional _partitionId, - std::optional _bucketId = std::nullopt) - : partitionId(_partitionId), bucketId(_bucketId) {} - - /// Returns the special writer id for the un-partitioned (and non-bucketed) - /// table. - static const HiveWriterId& unpartitionedId(); - - std::string toString() const; - - bool operator==(const HiveWriterId& other) const { - return std::tie(partitionId, bucketId) == - std::tie(other.partitionId, other.bucketId); - } -}; - -struct HiveWriterIdHasher { - std::size_t operator()(const HiveWriterId& id) const { - return bits::hashMix( - id.partitionId.value_or(std::numeric_limits::max()), - id.bucketId.value_or(std::numeric_limits::max())); - } +/// JSON field names for the partition update object produced by each writer +/// and consumed by the Presto coordinator to finalize files and update the +/// metastore. +/// +/// JSON structure: +/// { +/// "name": "", +/// "updateMode": "NEW" | "APPEND" | "OVERWRITE", +/// "writePath": "", +/// "targetPath": "", +/// "fileWriteInfos": [ +/// { +/// "writeFileName": "", +/// "targetFileName": "", +/// "fileSize": +/// } +/// ], +/// "rowCount": , +/// "inMemoryDataSizeInBytes": , +/// "onDiskDataSizeInBytes": , +/// "containsNumberedFileNames": true | false +/// } +struct HiveCommitMessage { + /// Partition directory name in Hive format (e.g., "ds=2024-01-01/region=us"). + /// Empty string for unpartitioned tables. + static constexpr const char* kName = "name"; + /// Write mode: "NEW", "APPEND", or "OVERWRITE". Controls how the committer + /// handles metastore updates and existing file conflicts. + static constexpr const char* kUpdateMode = "updateMode"; + /// Staging directory where files were written during execution. + static constexpr const char* kWritePath = "writePath"; + /// Final destination directory. Files are renamed from writePath to + /// targetPath during commit. + static constexpr const char* kTargetPath = "targetPath"; + /// Array of per-file metadata objects. One entry per file written, including + /// rotated files. + static constexpr const char* kFileWriteInfos = "fileWriteInfos"; + /// Temporary filename used during writing (in the staging directory). + static constexpr const char* kWriteFileName = "writeFileName"; + /// Final filename after commit (in the target directory). + static constexpr const char* kTargetFileName = "targetFileName"; + /// Size of individual file in bytes. + static constexpr const char* kFileSize = "fileSize"; + /// Total rows written to this partition across all files. + static constexpr const char* kRowCount = "rowCount"; + /// Uncompressed input data size in bytes. + static constexpr const char* kInMemoryDataSizeInBytes = + "inMemoryDataSizeInBytes"; + /// Compressed bytes written to disk. + static constexpr const char* kOnDiskDataSizeInBytes = "onDiskDataSizeInBytes"; + /// Whether filenames follow a numbered sequence from file rotation. + static constexpr const char* kContainsNumberedFileNames = + "containsNumberedFileNames"; }; -struct HiveWriterIdEq { - bool operator()(const HiveWriterId& lhs, const HiveWriterId& rhs) const { - return lhs == rhs; - } -}; - -class HiveDataSink : public DataSink { +class HiveDataSink : public FileDataSink { public: /// The list of runtime stats reported by hive data sink static constexpr const char* kEarlyFlushedRawBytes = "earlyFlushedRawBytes"; - /// Defines the execution states of a hive data sink running internally. - enum class State { - /// The data sink accepts new append data in this state. - kRunning = 0, - /// The data sink flushes any buffered data to the underlying file writer - /// but no more data can be appended. - kFinishing = 1, - /// The data sink is aborted on error and no more data can be appended. - kAborted = 2, - /// The data sink is closed on error and no more data can be appended. - kClosed = 3 - }; - static std::string stateString(State state); - + /// Creates a HiveDataSink for writing data to Hive table files. + /// + /// @param inputType The schema of input data rows to be written. + /// @param insertTableHandle Metadata about the table write operation, + /// including storage format, compression, bucketing, and partitioning + /// configuration. + /// @param connectorQueryCtx Query context with session properties, memory + /// pools, and spill configuration. + /// @param commitStrategy Strategy for committing written data (kNoCommit or + /// kTaskCommit). + /// @param hiveConfig Hive connector configuration. HiveDataSink( RowTypePtr inputType, std::shared_ptr insertTableHandle, @@ -521,6 +428,31 @@ class HiveDataSink : public DataSink { CommitStrategy commitStrategy, const std::shared_ptr& hiveConfig); + /// Constructor with explicit bucketing and partitioning parameters. + /// + /// @param inputType The schema of input data rows to be written. + /// @param insertTableHandle Metadata about the table write operation, + /// including storage format, compression, location, and serialization + /// parameters. + /// @param connectorQueryCtx Query context with session properties, memory + /// pools, and spill configuration. + /// @param commitStrategy Strategy for committing written data (kNoCommit or + /// kTaskCommit). Determines whether temporary files need to be renamed on + /// commit. + /// @param hiveConfig Hive connector configuration with settings for max + /// partitions, bucketing limits etc. + /// @param bucketCount Number of buckets for bucketed tables (0 if not + /// bucketed). Must be less than the configured max bucket count. + /// @param bucketFunction Function to compute bucket IDs from row data + /// (nullptr if not bucketed). Used to distribute rows across buckets. + /// @param partitionChannels Column indices used for partitioning (empty if + /// not partitioned). These columns are extracted to determine partition + /// directories. + /// @param dataChannels Column indices for the actual data columns to be + /// written. + /// @param partitionIdGenerator Generates partition IDs from partition column + /// values (nullptr if not partitioned). Compute partition key combinations to + /// unique IDs. HiveDataSink( RowTypePtr inputType, std::shared_ptr insertTableHandle, @@ -528,32 +460,21 @@ class HiveDataSink : public DataSink { CommitStrategy commitStrategy, const std::shared_ptr& hiveConfig, uint32_t bucketCount, - std::unique_ptr bucketFunction); - - void appendData(RowVectorPtr input) override; - - bool finish() override; - - Stats stats() const override; - - std::unordered_map runtimeStats() const override; - - std::vector close() override; - - void abort() override; + std::unique_ptr bucketFunction, + const std::vector& partitionChannels, + const std::vector& dataChannels, + std::unique_ptr partitionIdGenerator); bool canReclaim() const; - private: - // Validates the state transition from 'oldState' to 'newState'. - void checkStateTransition(State oldState, State newState); - void setState(State newState); + protected: + std::vector commitMessage() const override; class WriterReclaimer : public exec::MemoryReclaimer { public: static std::unique_ptr create( HiveDataSink* dataSink, - HiveWriterInfo* writerInfo, + WriterInfo* writerInfo, io::IoStatistics* ioStats); bool reclaimableBytes( @@ -569,7 +490,7 @@ class HiveDataSink : public DataSink { private: WriterReclaimer( HiveDataSink* dataSink, - HiveWriterInfo* writerInfo, + WriterInfo* writerInfo, io::IoStatistics* ioStats) : exec::MemoryReclaimer(0), dataSink_(dataSink), @@ -581,161 +502,57 @@ class HiveDataSink : public DataSink { } HiveDataSink* const dataSink_; - HiveWriterInfo* const writerInfo_; + WriterInfo* const writerInfo_; io::IoStatistics* const ioStats_; }; - FOLLY_ALWAYS_INLINE bool sortWrite() const { - return !sortColumnIndices_.empty(); - } - - // Returns true if the table is partitioned. - FOLLY_ALWAYS_INLINE bool isPartitioned() const { - return partitionIdGenerator_ != nullptr; - } - - // Returns true if the table is bucketed. - FOLLY_ALWAYS_INLINE bool isBucketed() const { - return bucketCount_ != 0; - } - - FOLLY_ALWAYS_INLINE bool isCommitRequired() const { - return commitStrategy_ != CommitStrategy::kNoCommit; - } - - std::shared_ptr createWriterPool( - const HiveWriterId& writerId); - - void setMemoryReclaimers( - HiveWriterInfo* writerInfo, - io::IoStatistics* ioStats); + void setMemoryReclaimers(WriterInfo* writerInfo, io::IoStatistics* ioStats) + override; // Compute the partition id and bucket id for each row in 'input'. - void computePartitionAndBucketIds(const RowVectorPtr& input); + void computePartitionAndBucketIds(const RowVectorPtr& input) override; - // Get the HiveWriter corresponding to the row - // from partitionIds and bucketIds. - FOLLY_ALWAYS_INLINE HiveWriterId getWriterId(size_t row) const; + std::unique_ptr createWriterForIndex( + size_t writerIndex) override; - // Computes the number of input rows as well as the actual input row indices - // to each corresponding (bucketed) partition based on the partition and - // bucket ids calculated by 'computePartitionAndBucketIds'. The function also - // ensures that there is a writer created for each (bucketed) partition. - void splitInputRowsAndEnsureWriters(); + // Creates and configures WriterOptions based on file format. + std::shared_ptr createWriterOptions() + const override; - // Makes sure to create one writer for the given writer id. The function - // returns the corresponding index in 'writers_'. - uint32_t ensureWriter(const HiveWriterId& id); + virtual std::shared_ptr createWriterOptions( + size_t writerIndex) const override; - // Appends a new writer for the given 'id'. The function returns the index of - // the newly created writer in 'writers_'. - uint32_t appendWriter(const HiveWriterId& id); + // Returns the Hive partition directory name for the given partition ID. + virtual std::string getPartitionName(uint32_t partitionId) const override; std::unique_ptr maybeCreateBucketSortWriter( + size_t writerIndex, std::unique_ptr writer); - HiveWriterParameters getWriterParameters( + WriterParameters getWriterParameters( const std::optional& partition, - std::optional bucketId) const; + std::optional bucketId) const override; - // Gets write and target file names for a writer based on the table commit - // strategy as well as table partitioned type. If commit is not required, the - // write file and target file has the same name. If not, add a temp file - // prefix to the target file for write file name. The coordinator (or driver - // for Presto on spark) will rename the write file to target file to commit - // the table write when update the metadata store. If it is a bucketed table, - // the file name encodes the corresponding bucket id. + // Gets write and target file names for a writer. std::pair getWriterFileNames( std::optional bucketId) const; - HiveWriterParameters::UpdateMode getUpdateMode() const; - - FOLLY_ALWAYS_INLINE void checkRunning() const { - VELOX_CHECK_EQ(state_, State::kRunning, "Hive data sink is not running"); - } - - // Invoked to write 'input' to the specified file writer. - void write(size_t index, RowVectorPtr input); - - void closeInternal(); + WriterParameters::UpdateMode getUpdateMode() const; - // IMPORTANT NOTE: these are passed to writers as raw pointers. HiveDataSink - // owns the lifetime of these objects, and therefore must destroy them last. - // Additionally, we must assume that no objects which hold a reference to - // these stats will outlive the HiveDataSink instance. This is a reasonable - // assumption given the semantics of these stats objects. - std::vector> ioStats_; - // Generic filesystem stats, exposed as RuntimeStats - std::unique_ptr fileSystemStats_; - - const RowTypePtr inputType_; const std::shared_ptr insertTableHandle_; - const ConnectorQueryCtx* const connectorQueryCtx_; - const CommitStrategy commitStrategy_; const std::shared_ptr hiveConfig_; - const HiveWriterParameters::UpdateMode updateMode_; - const uint32_t maxOpenWriters_; - const std::vector partitionChannels_; - const std::unique_ptr partitionIdGenerator_; - // Indices of dataChannel are stored in ascending order - const std::vector dataChannels_; - const int32_t bucketCount_{0}; - const std::unique_ptr bucketFunction_; - const std::shared_ptr writerFactory_; - const common::SpillConfig* const spillConfig_; - const uint64_t sortWriterFinishTimeSliceLimitMs_{0}; + const WriterParameters::UpdateMode updateMode_; std::vector sortColumnIndices_; std::vector sortCompareFlags_; - State state_{State::kRunning}; - - tsan_atomic nonReclaimableSection_{false}; - - // The map from writer id to the writer index in 'writers_' and 'writerInfo_'. - folly::F14FastMap - writerIndexMap_; - - // Below are structures for partitions from all inputs. writerInfo_ and - // writers_ are both indexed by partitionId. - std::vector> writerInfo_; - std::vector> writers_; - - // Below are structures updated when processing current input. partitionIds_ - // are indexed by the row of input_. partitionRows_, rawPartitionRows_ and - // partitionSizes_ are indexed by partitionId. - raw_vector partitionIds_; - std::vector partitionRows_; - std::vector rawPartitionRows_; - std::vector partitionSizes_; - - // Reusable buffers for bucket id calculations. - std::vector bucketIds_; - // Strategy for naming writer files std::shared_ptr fileNameGenerator_; }; -FOLLY_ALWAYS_INLINE std::ostream& operator<<( - std::ostream& os, - HiveDataSink::State state) { - os << HiveDataSink::stateString(state); - return os; -} } // namespace facebook::velox::connector::hive -template <> -struct fmt::formatter - : formatter { - auto format( - facebook::velox::connector::hive::HiveDataSink::State s, - format_context& ctx) const { - return formatter::format( - facebook::velox::connector::hive::HiveDataSink::stateString(s), ctx); - } -}; - template <> struct fmt::formatter< facebook::velox::connector::hive::LocationHandle::TableType> diff --git a/velox/connectors/hive/HiveDataSource.cpp b/velox/connectors/hive/HiveDataSource.cpp index d0c59b6392c..1db6ba3f543 100644 --- a/velox/connectors/hive/HiveDataSource.cpp +++ b/velox/connectors/hive/HiveDataSource.cpp @@ -16,220 +16,47 @@ #include "velox/connectors/hive/HiveDataSource.h" -#include -#include -#include +#include -#include "velox/common/testutil/TestValue.h" #include "velox/connectors/hive/HiveConfig.h" -#include "velox/expression/FieldReference.h" - -using facebook::velox::common::testutil::TestValue; +#include "velox/connectors/hive/HiveConnectorSplit.h" +#include "velox/connectors/hive/HiveConnectorUtil.h" +#include "velox/connectors/hive/HiveSplitReader.h" namespace facebook::velox::connector::hive { -namespace { - -bool isMember( - const std::vector& fields, - const exec::FieldReference& field) { - return std::find(fields.begin(), fields.end(), &field) != fields.end(); -} - -bool shouldEagerlyMaterialize( - const exec::Expr& remainingFilter, - const exec::FieldReference& field) { - if (!remainingFilter.evaluatesArgumentsOnNonIncreasingSelection()) { - return true; - } - for (auto& input : remainingFilter.inputs()) { - if (isMember(input->distinctFields(), field) && input->hasConditionals()) { - return true; - } - } - return false; -} - -} // namespace - HiveDataSource::HiveDataSource( const RowTypePtr& outputType, const connector::ConnectorTableHandlePtr& tableHandle, - const connector::ColumnHandleMap& columnHandles, + const connector::ColumnHandleMap& assignments, FileHandleFactory* fileHandleFactory, folly::Executor* ioExecutor, const ConnectorQueryCtx* connectorQueryCtx, const std::shared_ptr& hiveConfig) - : fileHandleFactory_(fileHandleFactory), - ioExecutor_(ioExecutor), - connectorQueryCtx_(connectorQueryCtx), - hiveConfig_(hiveConfig), - pool_(connectorQueryCtx->memoryPool()), - outputType_(outputType), - expressionEvaluator_(connectorQueryCtx->expressionEvaluator()) { - // Column handled keyed on the column alias, the name used in the query. - for (const auto& [canonicalizedName, columnHandle] : columnHandles) { - auto handle = - std::dynamic_pointer_cast(columnHandle); - VELOX_CHECK_NOT_NULL( - handle, - "ColumnHandle must be an instance of HiveColumnHandle for {}", - canonicalizedName); - switch (handle->columnType()) { - case HiveColumnHandle::ColumnType::kRegular: - break; - case HiveColumnHandle::ColumnType::kPartitionKey: - partitionKeys_.emplace(handle->name(), handle); - break; - case HiveColumnHandle::ColumnType::kSynthesized: - infoColumns_.emplace(handle->name(), handle); - break; - case HiveColumnHandle::ColumnType::kRowIndex: - specialColumns_.rowIndex = handle->name(); - break; - case HiveColumnHandle::ColumnType::kRowId: - specialColumns_.rowId = handle->name(); - break; - } - } - - std::vector readColumnNames; - auto readColumnTypes = outputType_->children(); - for (const auto& outputName : outputType_->names()) { - auto it = columnHandles.find(outputName); - VELOX_CHECK( - it != columnHandles.end(), - "ColumnHandle is missing for output column: {}", - outputName); - - auto* handle = static_cast(it->second.get()); - readColumnNames.push_back(handle->name()); - for (auto& subfield : handle->requiredSubfields()) { - VELOX_USER_CHECK_EQ( - getColumnName(subfield), - handle->name(), - "Required subfield does not match column name"); - subfields_[handle->name()].push_back(&subfield); - } - } - - hiveTableHandle_ = - std::dynamic_pointer_cast(tableHandle); - VELOX_CHECK_NOT_NULL( - hiveTableHandle_, "TableHandle must be an instance of HiveTableHandle"); - if (hiveConfig_->isFileColumnNamesReadAsLowerCase( - connectorQueryCtx->sessionProperties())) { - checkColumnNameLowerCase(outputType_); - checkColumnNameLowerCase(hiveTableHandle_->subfieldFilters(), infoColumns_); - checkColumnNameLowerCase(hiveTableHandle_->remainingFilter()); - } - - for (const auto& [k, v] : hiveTableHandle_->subfieldFilters()) { - filters_.emplace(k.clone(), v); - } - double sampleRate = 1; - auto remainingFilter = extractFiltersFromRemainingFilter( - hiveTableHandle_->remainingFilter(), - expressionEvaluator_, - false, - filters_, - sampleRate); - if (sampleRate != 1) { - randomSkip_ = std::make_shared(sampleRate); - } - - if (remainingFilter) { - remainingFilterExprSet_ = expressionEvaluator_->compile(remainingFilter); - auto& remainingFilterExpr = remainingFilterExprSet_->expr(0); - folly::F14FastMap columnNames; - for (int i = 0; i < readColumnNames.size(); ++i) { - columnNames[readColumnNames[i]] = i; - } - for (auto& input : remainingFilterExpr->distinctFields()) { - auto it = columnNames.find(input->field()); - if (it != columnNames.end()) { - if (shouldEagerlyMaterialize(*remainingFilterExpr, *input)) { - multiReferencedFields_.push_back(it->second); - } - continue; - } - // Remaining filter may reference columns that are not used otherwise, - // e.g. are not being projected out and are not used in range filters. - // Make sure to add these columns to readerOutputType_. - readColumnNames.push_back(input->field()); - readColumnTypes.push_back(input->type()); - } - remainingFilterSubfields_ = remainingFilterExpr->extractSubfields(); - if (VLOG_IS_ON(1)) { - VLOG(1) << fmt::format( - "Extracted subfields from remaining filter: [{}]", - fmt::join(remainingFilterSubfields_, ", ")); - } - for (auto& subfield : remainingFilterSubfields_) { - const auto& name = getColumnName(subfield); - auto it = subfields_.find(name); - if (it != subfields_.end()) { - // Some subfields of the column are already projected out, we append the - // remainingFilter subfield - it->second.push_back(&subfield); - } else if (columnNames.count(name) == 0) { - // remainingFilter subfield's column is not projected out, we add the - // column and append the subfield - subfields_[name].push_back(&subfield); - } - } - } - - readerOutputType_ = - ROW(std::move(readColumnNames), std::move(readColumnTypes)); - scanSpec_ = makeScanSpec( - readerOutputType_, - subfields_, - filters_, - hiveTableHandle_->dataColumns(), - partitionKeys_, - infoColumns_, - specialColumns_, - hiveConfig_->readStatsBasedFilterReorderDisabled( - connectorQueryCtx_->sessionProperties()), - pool_); - if (remainingFilter) { - metadataFilter_ = std::make_shared( - *scanSpec_, *remainingFilter, expressionEvaluator_); - } - - ioStats_ = std::make_shared(); - fsStats_ = std::make_shared(); -} - -std::unique_ptr HiveDataSource::createSplitReader() { - return SplitReader::create( - split_, - hiveTableHandle_, - &partitionKeys_, - connectorQueryCtx_, - hiveConfig_, - readerOutputType_, - ioStats_, - fsStats_, - fileHandleFactory_, - ioExecutor_, - scanSpec_); -} + : FileDataSource( + outputType, + tableHandle, + assignments, + fileHandleFactory, + ioExecutor, + connectorQueryCtx, + hiveConfig), + hiveConfig_(hiveConfig) {} std::vector HiveDataSource::setupBucketConversion() { + auto hiveSplit = checkedPointerCast(split_); VELOX_CHECK_NE( - split_->bucketConversion->tableBucketCount, - split_->bucketConversion->partitionBucketCount); - VELOX_CHECK(split_->tableBucketNumber.has_value()); - VELOX_CHECK_NOT_NULL(hiveTableHandle_->dataColumns()); + hiveSplit->bucketConversion->tableBucketCount, + hiveSplit->bucketConversion->partitionBucketCount); + VELOX_CHECK(hiveSplit->tableBucketNumber.has_value()); + VELOX_CHECK_NOT_NULL(tableHandle_->dataColumns()); ++numBucketConversion_; bool rebuildScanSpec = false; std::vector names; std::vector types; std::vector bucketChannels; - for (auto& handle : split_->bucketConversion->bucketColumnHandles) { - VELOX_CHECK(handle->columnType() == HiveColumnHandle::ColumnType::kRegular); + for (auto& handle : hiveSplit->bucketConversion->bucketColumnHandles) { + VELOX_CHECK(handle->columnType() == FileColumnHandle::ColumnType::kRegular); if (subfields_.erase(handle->name()) > 0) { rebuildScanSpec = true; } @@ -241,8 +68,7 @@ std::vector HiveDataSource::setupBucketConversion() { } index = names.size(); names.push_back(handle->name()); - types.push_back( - hiveTableHandle_->dataColumns()->findChild(handle->name())); + types.push_back(tableHandle_->dataColumns()->findChild(handle->name())); rebuildScanSpec = true; } bucketChannels.push_back(*index); @@ -255,11 +81,12 @@ std::vector HiveDataSource::setupBucketConversion() { readerOutputType_, subfields_, filters_, - hiveTableHandle_->dataColumns(), + /*indexColumns=*/{}, + tableHandle_->dataColumns(), partitionKeys_, infoColumns_, specialColumns_, - hiveConfig_->readStatsBasedFilterReorderDisabled( + fileConfig_->readStatsBasedFilterReorderDisabled( connectorQueryCtx_->sessionProperties()), pool_); newScanSpec->moveAdaptationFrom(*scanSpec_); @@ -269,8 +96,9 @@ std::vector HiveDataSource::setupBucketConversion() { } void HiveDataSource::setupRowIdColumn() { - VELOX_CHECK(split_->rowIdProperties.has_value()); - const auto& props = *split_->rowIdProperties; + auto hiveSplit = checkedPointerCast(split_); + VELOX_CHECK(hiveSplit->rowIdProperties.has_value()); + const auto& props = *hiveSplit->rowIdProperties; auto* rowId = scanSpec_->childByName(*specialColumns_.rowId); VELOX_CHECK_NOT_NULL(rowId); auto& rowIdType = @@ -292,247 +120,65 @@ void HiveDataSource::setupRowIdColumn() { connectorQueryCtx_->memoryPool()); } -void HiveDataSource::addSplit(std::shared_ptr split) { - VELOX_CHECK_NULL( - split_, - "Previous split has not been processed yet. Call next to process the split."); - split_ = std::dynamic_pointer_cast(split); - VELOX_CHECK_NOT_NULL(split_, "Wrong type of split"); - - VLOG(1) << "Adding split " << split_->toString(); - - if (splitReader_) { - splitReader_.reset(); - } - +std::vector HiveDataSource::prepareSplit() { + auto hiveSplit = checkedPointerCast(split_); + ++numSplitsByFileFormat_[split_->fileFormat]; std::vector bucketChannels; - if (split_->bucketConversion.has_value()) { + if (hiveSplit->bucketConversion.has_value()) { bucketChannels = setupBucketConversion(); } if (specialColumns_.rowId.has_value()) { setupRowIdColumn(); } - - splitReader_ = createSplitReader(); - if (!bucketChannels.empty()) { - splitReader_->setBucketConversion(std::move(bucketChannels)); - } - // Split reader subclasses may need to use the reader options in prepareSplit - // so we initialize it beforehand. - splitReader_->configureReaderOptions(randomSkip_); - splitReader_->prepareSplit(metadataFilter_, runtimeStats_); - readerOutputType_ = splitReader_->readerOutputType(); + return bucketChannels; } -std::optional HiveDataSource::next( - uint64_t size, - velox::ContinueFuture& /*future*/) { - VELOX_CHECK(split_ != nullptr, "No split to process. Call addSplit first."); - VELOX_CHECK_NOT_NULL(splitReader_, "No split reader present"); - - TestValue::adjust( - "facebook::velox::connector::hive::HiveDataSource::next", this); - - if (splitReader_->emptySplit()) { - resetSplit(); - return nullptr; - } - - // Bucket conversion or delta update could add extra column to reader output. - auto needsExtraColumn = [&] { - return output_->asUnchecked()->childrenSize() < - readerOutputType_->size(); - }; - if (!output_ || needsExtraColumn()) { - output_ = BaseVector::create(readerOutputType_, 0, pool_); - } - - const auto rowsScanned = splitReader_->next(size, output_); - completedRows_ += rowsScanned; - if (rowsScanned == 0) { - splitReader_->updateRuntimeStats(runtimeStats_); - resetSplit(); - return nullptr; - } - - VELOX_CHECK( - !output_->mayHaveNulls(), "Top-level row vector cannot have nulls"); - auto rowsRemaining = output_->size(); - if (rowsRemaining == 0) { - // no rows passed the pushed down filters. - return getEmptyOutput(); - } - - auto rowVector = std::dynamic_pointer_cast(output_); - - // In case there is a remaining filter that excludes some but not all - // rows, collect the indices of the passing rows. If there is no filter, - // or it passes on all rows, leave this as null and let exec::wrap skip - // wrapping the results. - BufferPtr remainingIndices; - filterRows_.resize(rowVector->size()); - - if (remainingFilterExprSet_) { - rowsRemaining = evaluateRemainingFilter(rowVector); - VELOX_CHECK_LE(rowsRemaining, rowsScanned); - if (rowsRemaining == 0) { - // No rows passed the remaining filter. - return getEmptyOutput(); - } - - if (rowsRemaining < rowVector->size()) { - // Some, but not all rows passed the remaining filter. - remainingIndices = filterEvalCtx_.selectedIndices; - } - } - - if (outputType_->size() == 0) { - return exec::wrap(rowsRemaining, remainingIndices, rowVector); - } - - std::vector outputColumns; - outputColumns.reserve(outputType_->size()); - for (int i = 0; i < outputType_->size(); ++i) { - auto& child = rowVector->childAt(i); - if (remainingIndices) { - // Disable dictionary values caching in expression eval so that we - // don't need to reallocate the result for every batch. - child->disableMemo(); - } - outputColumns.emplace_back( - exec::wrapChild(rowsRemaining, remainingIndices, child)); - } - - return std::make_shared( - pool_, outputType_, BufferPtr(nullptr), rowsRemaining, outputColumns); -} +std::unique_ptr HiveDataSource::createSplitReader() { + auto bucketChannels = prepareSplit(); + auto hiveSplit = checkedPointerCast(split_); -void HiveDataSource::addDynamicFilter( - column_index_t outputChannel, - const std::shared_ptr& filter) { - auto& fieldSpec = scanSpec_->getChildByChannel(outputChannel); - fieldSpec.setFilter(filter); - scanSpec_->resetCachedValues(true); - if (splitReader_) { - splitReader_->resetFilterCaches(); - } + return std::make_unique( + hiveSplit, + tableHandle_, + &partitionKeys_, + connectorQueryCtx_, + fileConfig_, + readerOutputType_, + dataIoStats_, + metadataIoStats_, + ioStats_, + fileHandleFactory_, + ioExecutor_, + scanSpec_, + &infoColumns_, + std::move(bucketChannels), + /*subfieldFiltersForValidation=*/&filters_); } -std::unordered_map HiveDataSource::runtimeStats() { - auto res = runtimeStats_.toMap(); - res.insert( - {{"numPrefetch", RuntimeCounter(ioStats_->prefetch().count())}, - {"prefetchBytes", - RuntimeCounter( - ioStats_->prefetch().sum(), RuntimeCounter::Unit::kBytes)}, - {"totalScanTime", - RuntimeCounter( - ioStats_->totalScanTime(), RuntimeCounter::Unit::kNanos)}, - {Connector::kTotalRemainingFilterTime, - RuntimeCounter( - totalRemainingFilterTime_.load(std::memory_order_relaxed), - RuntimeCounter::Unit::kNanos)}, - {"ioWaitWallNanos", - RuntimeCounter( - ioStats_->queryThreadIoLatency().sum() * 1000, - RuntimeCounter::Unit::kNanos)}, - {"maxSingleIoWaitWallNanos", - RuntimeCounter( - ioStats_->queryThreadIoLatency().max() * 1000, - RuntimeCounter::Unit::kNanos)}, - {"overreadBytes", - RuntimeCounter( - ioStats_->rawOverreadBytes(), RuntimeCounter::Unit::kBytes)}}); - if (ioStats_->read().count() > 0) { - res.insert({"numStorageRead", RuntimeCounter(ioStats_->read().count())}); - res.insert( - {"storageReadBytes", - RuntimeCounter(ioStats_->read().sum(), RuntimeCounter::Unit::kBytes)}); - } - if (ioStats_->ssdRead().count() > 0) { - res.insert({"numLocalRead", RuntimeCounter(ioStats_->ssdRead().count())}); - res.insert( - {"localReadBytes", - RuntimeCounter( - ioStats_->ssdRead().sum(), RuntimeCounter::Unit::kBytes)}); - } - if (ioStats_->ramHit().count() > 0) { - res.insert({"numRamRead", RuntimeCounter(ioStats_->ramHit().count())}); - res.insert( - {"ramReadBytes", - RuntimeCounter( - ioStats_->ramHit().sum(), RuntimeCounter::Unit::kBytes)}); - } +std::unordered_map +HiveDataSource::getRuntimeStats() { + auto result = FileDataSource::getRuntimeStats(); if (numBucketConversion_ > 0) { - res.insert({"numBucketConversion", RuntimeCounter(numBucketConversion_)}); + result.insert( + {std::string(kNumBucketConversion), + RuntimeMetric(numBucketConversion_)}); } - - const auto fsStats = fsStats_->stats(); - for (const auto& storageStats : fsStats) { - res.emplace( - storageStats.first, - RuntimeCounter(storageStats.second.sum, storageStats.second.unit)); + for (const auto& [format, count] : numSplitsByFileFormat_) { + result.insert( + {fmt::format("{}{}", kFileFormat, dwio::common::toString(format)), + RuntimeMetric(count)}); } - return res; + return result; } void HiveDataSource::setFromDataSource( std::unique_ptr sourceUnique) { - auto source = dynamic_cast(sourceUnique.get()); - VELOX_CHECK_NOT_NULL(source, "Bad DataSource type"); - - split_ = std::move(source->split_); - runtimeStats_.skippedSplits += source->runtimeStats_.skippedSplits; - runtimeStats_.processedSplits += source->runtimeStats_.processedSplits; - runtimeStats_.skippedSplitBytes += source->runtimeStats_.skippedSplitBytes; - readerOutputType_ = std::move(source->readerOutputType_); - source->scanSpec_->moveAdaptationFrom(*scanSpec_); - scanSpec_ = std::move(source->scanSpec_); - splitReader_ = std::move(source->splitReader_); - splitReader_->setConnectorQueryCtx(connectorQueryCtx_); - // New io will be accounted on the stats of 'source'. Add the existing - // balance to that. - source->ioStats_->merge(*ioStats_); - ioStats_ = std::move(source->ioStats_); - source->fsStats_->merge(*fsStats_); - fsStats_ = std::move(source->fsStats_); - + auto* source = checkedPointerCast(sourceUnique.get()); numBucketConversion_ += source->numBucketConversion_; -} - -int64_t HiveDataSource::estimatedRowSize() { - if (!splitReader_) { - return kUnknownRowSize; + for (const auto& [format, count] : source->numSplitsByFileFormat_) { + numSplitsByFileFormat_[format] += count; } - return splitReader_->estimatedRowSize(); -} - -vector_size_t HiveDataSource::evaluateRemainingFilter(RowVectorPtr& rowVector) { - for (auto fieldIndex : multiReferencedFields_) { - LazyVector::ensureLoadedRows( - rowVector->childAt(fieldIndex), - filterRows_, - filterLazyDecoded_, - filterLazyBaseRows_); - } - uint64_t filterTimeUs{0}; - vector_size_t rowsRemaining{0}; - { - MicrosecondTimer timer(&filterTimeUs); - expressionEvaluator_->evaluate( - remainingFilterExprSet_.get(), filterRows_, *rowVector, filterResult_); - rowsRemaining = exec::processFilterResults( - filterResult_, filterRows_, filterEvalCtx_, pool_); - } - totalRemainingFilterTime_.fetch_add( - filterTimeUs * 1000, std::memory_order_relaxed); - return rowsRemaining; -} - -void HiveDataSource::resetSplit() { - split_.reset(); - splitReader_->resetSplit(); - // Keep readers around to hold adaptation. + FileDataSource::setFromDataSource(std::move(sourceUnique)); } HiveDataSource::WaveDelegateHookFunction HiveDataSource::waveDelegateHook_; @@ -540,8 +186,10 @@ HiveDataSource::WaveDelegateHookFunction HiveDataSource::waveDelegateHook_; std::shared_ptr HiveDataSource::toWaveDataSource() { VELOX_CHECK_NOT_NULL(waveDelegateHook_); if (!waveDataSource_) { + auto hiveTableHandle = + checkedPointerCast(tableHandle_); waveDataSource_ = waveDelegateHook_( - hiveTableHandle_, + hiveTableHandle, scanSpec_, readerOutputType_, &partitionKeys_, @@ -549,9 +197,9 @@ std::shared_ptr HiveDataSource::toWaveDataSource() { ioExecutor_, connectorQueryCtx_, hiveConfig_, - ioStats_, - remainingFilterExprSet_.get(), - metadataFilter_); + dataIoStats_, + remainingFilterExprSet(), + metadataFilter()); } return waveDataSource_; } @@ -560,6 +208,5 @@ std::shared_ptr HiveDataSource::toWaveDataSource() { void HiveDataSource::registerWaveDelegateHook(WaveDelegateHookFunction hook) { waveDelegateHook_ = hook; } -std::shared_ptr toWaveDataSource(); } // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/HiveDataSource.h b/velox/connectors/hive/HiveDataSource.h index 64aa3d6420b..df6dbd4c539 100644 --- a/velox/connectors/hive/HiveDataSource.h +++ b/velox/connectors/hive/HiveDataSource.h @@ -15,65 +15,31 @@ */ #pragma once -#include "velox/common/base/RandomUtil.h" -#include "velox/common/file/FileSystems.h" -#include "velox/common/io/IoStatistics.h" -#include "velox/connectors/Connector.h" -#include "velox/connectors/hive/FileHandle.h" -#include "velox/connectors/hive/HiveConnectorSplit.h" -#include "velox/connectors/hive/HiveConnectorUtil.h" -#include "velox/connectors/hive/SplitReader.h" +#include "velox/connectors/hive/FileDataSource.h" #include "velox/connectors/hive/TableHandle.h" -#include "velox/dwio/common/Statistics.h" -#include "velox/exec/OperatorUtils.h" -#include "velox/expression/Expr.h" namespace facebook::velox::connector::hive { class HiveConfig; -class HiveDataSource : public DataSource { +class HiveDataSource : public FileDataSource { public: + static constexpr std::string_view kNumBucketConversion{"numBucketConversion"}; + static constexpr std::string_view kFileFormat{"fileFormat."}; + HiveDataSource( const RowTypePtr& outputType, const connector::ConnectorTableHandlePtr& tableHandle, - const connector::ColumnHandleMap& columnHandles, + const connector::ColumnHandleMap& assignments, FileHandleFactory* fileHandleFactory, - folly::Executor* executor, + folly::Executor* ioExecutor, const ConnectorQueryCtx* connectorQueryCtx, const std::shared_ptr& hiveConfig); - void addSplit(std::shared_ptr split) override; - - std::optional next(uint64_t size, velox::ContinueFuture& future) - override; - - void addDynamicFilter( - column_index_t outputChannel, - const std::shared_ptr& filter) override; - - uint64_t getCompletedBytes() override { - return ioStats_->rawBytesRead(); - } - - uint64_t getCompletedRows() override { - return completedRows_; - } - - std::unordered_map runtimeStats() override; - - bool allPrefetchIssued() const override { - return splitReader_ && splitReader_->allPrefetchIssued(); - } + std::unordered_map getRuntimeStats() override; void setFromDataSource(std::unique_ptr sourceUnique) override; - int64_t estimatedRowSize() override; - - const common::SubfieldFilters* getFilters() const override { - return &filters_; - } - std::shared_ptr toWaveDataSource() override; using WaveDelegateHookFunction = @@ -81,12 +47,12 @@ class HiveDataSource : public DataSource { const HiveTableHandlePtr& hiveTableHandle, const std::shared_ptr& scanSpec, const RowTypePtr& readerOutputType, - std::unordered_map* partitionKeys, + std::unordered_map* partitionKeys, FileHandleFactory* fileHandleFactory, folly::Executor* executor, const ConnectorQueryCtx* connectorQueryCtx, const std::shared_ptr& hiveConfig, - const std::shared_ptr& ioStats, + const std::shared_ptr& ioStatistics, const exec::ExprSet* remainingFilter, std::shared_ptr metadataFilter)>; @@ -94,94 +60,25 @@ class HiveDataSource : public DataSource { static void registerWaveDelegateHook(WaveDelegateHookFunction hook); - const ConnectorQueryCtx* testingConnectorQueryCtx() const { - return connectorQueryCtx_; - } - protected: - virtual std::unique_ptr createSplitReader(); - - FileHandleFactory* const fileHandleFactory_; - folly::Executor* const ioExecutor_; - const ConnectorQueryCtx* const connectorQueryCtx_; - const std::shared_ptr hiveConfig_; - memory::MemoryPool* const pool_; + std::unique_ptr createSplitReader() override; - std::shared_ptr split_; - HiveTableHandlePtr hiveTableHandle_; - std::shared_ptr scanSpec_; - VectorPtr output_; - std::unique_ptr splitReader_; - - // Output type from file reader. This is different from outputType_ that it - // contains column names before assignment, and columns that only used in - // remaining filter. - RowTypePtr readerOutputType_; - - // Column handles for the partition key columns keyed on partition key column - // name. - std::unordered_map partitionKeys_; - - std::shared_ptr ioStats_; - std::shared_ptr fsStats_; + /// Pre-creation setup: stats tracking, bucket conversion, rowId. + /// Returns bucket channels (empty if none). + std::vector prepareSplit(); private: std::vector setupBucketConversion(); void setupRowIdColumn(); - // Evaluates remainingFilter_ on the specified vector. Returns number of rows - // passed. Populates filterEvalCtx_.selectedIndices and selectedBits if only - // some rows passed the filter. If none or all rows passed - // filterEvalCtx_.selectedIndices and selectedBits are not updated. - vector_size_t evaluateRemainingFilter(RowVectorPtr& rowVector); - - // Clear split_ after split has been fully processed. Keep readers around to - // hold adaptation. - void resetSplit(); - - const RowVectorPtr& getEmptyOutput() { - if (!emptyOutput_) { - emptyOutput_ = RowVector::createEmpty(outputType_, pool_); - } - return emptyOutput_; - } - - // The row type for the data source output, not including filter-only columns - const RowTypePtr outputType_; - core::ExpressionEvaluator* const expressionEvaluator_; - - // Column handles for the Split info columns keyed on their column names. - std::unordered_map infoColumns_; - SpecialColumnNames specialColumns_{}; - std::vector remainingFilterSubfields_; - folly::F14FastMap> - subfields_; - common::SubfieldFilters filters_; - std::shared_ptr metadataFilter_; - std::unique_ptr remainingFilterExprSet_; - RowVectorPtr emptyOutput_; - dwio::common::RuntimeStatistics runtimeStats_; - std::atomic totalRemainingFilterTime_{0}; - uint64_t completedRows_ = 0; - - // Field indices referenced in both remaining filter and output type. These - // columns need to be materialized eagerly to avoid missing values in output. - std::vector multiReferencedFields_; - - std::shared_ptr randomSkip_; + const std::shared_ptr hiveConfig_; int64_t numBucketConversion_ = 0; - // Reusable memory for remaining filter evaluation. - VectorPtr filterResult_; - SelectivityVector filterRows_; - DecodedVector filterLazyDecoded_; - SelectivityVector filterLazyBaseRows_; - exec::FilterEvalCtx filterEvalCtx_; + // Tracks the number of splits read per file format. + std::unordered_map numSplitsByFileFormat_; - // Remembers the WaveDataSource. Successive calls to toWaveDataSource() will - // return the same. std::shared_ptr waveDataSource_; }; } // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/HiveIndexSource.cpp b/velox/connectors/hive/HiveIndexSource.cpp new file mode 100644 index 00000000000..225fd6ea3cd --- /dev/null +++ b/velox/connectors/hive/HiveIndexSource.cpp @@ -0,0 +1,1370 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/connectors/hive/HiveIndexSource.h" + +#include +#include +#include "velox/common/base/RuntimeMetrics.h" +#include "velox/common/time/CpuWallTimer.h" +#include "velox/common/time/Timer.h" +#include "velox/connectors/hive/FileDataSource.h" +#include "velox/connectors/hive/FileIndexReader.h" +#include "velox/connectors/hive/FileSplitReader.h" +#include "velox/connectors/hive/HiveConfig.h" +#include "velox/connectors/hive/HiveConnectorUtil.h" +#include "velox/exec/OperatorUtils.h" +#include "velox/expression/FieldReference.h" +#include "velox/vector/LazyVector.h" + +namespace facebook::velox::connector::hive { +namespace { + +// Extracts a constant value from a point lookup filter (single value equality). +// Returns nullopt if the filter is not a point lookup or cannot be converted. +std::optional extractPointLookupValue(const common::Filter* filter) { + VELOX_CHECK_NOT_NULL(filter); + + switch (filter->kind()) { + case common::FilterKind::kBigintRange: { + const auto* range = filter->as(); + if (range->isSingleValue()) { + return variant(range->lower()); + } + return std::nullopt; + } + case common::FilterKind::kDoubleRange: { + const auto* range = filter->as(); + if (!range->lowerUnbounded() && !range->upperUnbounded() && + range->lower() == range->upper() && !range->lowerExclusive() && + !range->upperExclusive()) { + return variant(range->lower()); + } + return std::nullopt; + } + case common::FilterKind::kFloatRange: { + const auto* range = filter->as(); + if (!range->lowerUnbounded() && !range->upperUnbounded() && + range->lower() == range->upper() && !range->lowerExclusive() && + !range->upperExclusive()) { + return variant(range->lower()); + } + return std::nullopt; + } + case common::FilterKind::kBytesRange: { + const auto* range = filter->as(); + if (range->isSingleValue()) { + return variant(range->lower()); + } + return std::nullopt; + } + case common::FilterKind::kBytesValues: { + const auto* values = filter->as(); + if (values->values().size() == 1) { + return variant(*values->values().begin()); + } + return std::nullopt; + } + case common::FilterKind::kBoolValue: { + const auto* boolFilter = filter->as(); + // BoolValue doesn't expose value() getter - use testBool to determine the + // value + return variant(boolFilter->testBool(true)); + } + default: + return std::nullopt; + } +} + +// Extracts range bounds from a range filter. +// Returns a pair of (lower, upper) variants. If a bound is unbounded, the +// corresponding variant will be null. +// Returns nullopt if the filter is not a range filter or cannot be converted. +std::optional> extractRangeBounds( + const common::Filter* filter) { + VELOX_CHECK_NOT_NULL(filter); + + switch (filter->kind()) { + case common::FilterKind::kBigintRange: { + const auto* range = filter->as(); + return std::make_pair(variant(range->lower()), variant(range->upper())); + } + case common::FilterKind::kDoubleRange: { + const auto* range = filter->as(); + if (range->lowerUnbounded() || range->upperUnbounded()) { + // Cannot convert unbounded ranges to BetweenCondition. + return std::nullopt; + } + return std::make_pair(variant(range->lower()), variant(range->upper())); + } + case common::FilterKind::kFloatRange: { + const auto* range = filter->as(); + if (range->lowerUnbounded() || range->upperUnbounded()) { + return std::nullopt; + } + return std::make_pair(variant(range->lower()), variant(range->upper())); + } + case common::FilterKind::kBytesRange: { + const auto* range = filter->as(); + if (!range->isSingleValue()) { + // Only convert bounded range filters. + return std::make_pair(variant(range->lower()), variant(range->upper())); + } + return std::nullopt; + } + default: + return std::nullopt; + } +} + +// Creates an EqualIndexLookupCondition with a constant value. +core::IndexLookupConditionPtr createEqualConditionWithConstant( + const std::string& columnName, + const TypePtr& type, + const variant& value) { + auto keyExpr = std::make_shared(type, columnName); + auto constantExpr = std::make_shared(type, value); + return std::make_shared( + std::move(keyExpr), std::move(constantExpr)); +} + +// Creates a BetweenIndexLookupCondition with constant bounds. +core::IndexLookupConditionPtr createBetweenConditionWithConstants( + const std::string& columnName, + const TypePtr& type, + const variant& lowerValue, + const variant& upperValue) { + auto keyExpr = std::make_shared(type, columnName); + auto lowerExpr = std::make_shared(type, lowerValue); + auto upperExpr = std::make_shared(type, upperValue); + return std::make_shared( + std::move(keyExpr), std::move(lowerExpr), std::move(upperExpr)); +} + +// Checks that a HiveColumnHandle is a regular column type. +void checkColumnHandleIsRegular(const FileColumnHandle& handle) { + VELOX_CHECK_EQ( + handle.columnType(), + FileColumnHandle::ColumnType::kRegular, + "Expected regular column, got {} for column {}", + FileColumnHandle::columnTypeName(handle.columnType()), + handle.name()); +} + +// Gets the table column name from assignments for a given input column name. +// Throws if not found in assignments. +std::string getTableColumnName( + const std::string& inputColumnName, + const connector::ColumnHandleMap& assignments) { + auto it = assignments.find(inputColumnName); + VELOX_USER_CHECK( + it != assignments.end(), + "Column not found in assignments: {}", + inputColumnName); + const auto* handle = + checkedPointerCast(it->second.get()); + return handle->name(); +} + +// Creates a new FieldAccessTypedExpr with the given column name but same type. +core::FieldAccessTypedExprPtr renameFieldAccess( + const core::FieldAccessTypedExprPtr& field, + const std::string& newName) { + return std::make_shared(field->type(), newName); +} + +// Converts an index lookup condition's key name from input column name to table +// column name. Returns a new condition with the converted key name. +core::IndexLookupConditionPtr convertConditionKeyName( + const core::IndexLookupConditionPtr& condition, + const connector::ColumnHandleMap& assignments) { + const auto tableColumnName = + getTableColumnName(condition->key->name(), assignments); + auto newKey = renameFieldAccess(condition->key, tableColumnName); + + if (auto equalCondition = + std::dynamic_pointer_cast( + condition)) { + return std::make_shared( + std::move(newKey), equalCondition->value); + } + if (auto betweenCondition = + std::dynamic_pointer_cast( + condition)) { + return std::make_shared( + std::move(newKey), betweenCondition->lower, betweenCondition->upper); + } + + VELOX_UNREACHABLE( + "Unsupported IndexLookupCondition type: {}", condition->toString()); +} + +// Filters input indices based on selected indices from filter evaluation. +BufferPtr filterIndices( + vector_size_t numRows, + const BufferPtr& selectedIndices, + const BufferPtr& inputIndices, + memory::MemoryPool* pool) { + auto resultIndices = allocateIndices(numRows, pool); + auto* rawResultIndices = resultIndices->asMutable(); + const auto* rawSelected = selectedIndices->as(); + const auto* rawInputIndices = inputIndices->as(); + for (vector_size_t i = 0; i < numRows; ++i) { + rawResultIndices[i] = rawInputIndices[rawSelected[i]]; + } + return resultIndices; +} +} // namespace + +namespace { + +// Merges results from multiple split-level ResultIterators in inputHit order. +// Each split independently produces results with sorted inputHits. This +// iterator interleaves rows across splits to maintain the global non-decreasing +// inputHit ordering required by IndexLookupJoin (for left join missed-row +// detection and the check in prepareLookupResult). +// +// Uses a k-way merge: buffers one Result per split, then repeatedly picks +// the split with the smallest current request index and copies all its rows +// for that index before moving on. +class UnionResultIterator : public IndexSource::ResultIterator { + public: + UnionResultIterator( + std::vector> splitIters, + const RowTypePtr& outputType, + memory::MemoryPool* pool) + : outputType_(outputType), pool_(pool) { + VELOX_CHECK_GT( + splitIters.size(), + 1, + "UnionResultIterator requires at least two iterators"); + splits_.reserve(splitIters.size()); + for (auto& iter : splitIters) { + splits_.emplace_back(std::move(iter)); + } + } + + bool hasNext() override { + for (const auto& split : splits_) { + if (!split.hasExhausted()) { + return true; + } + } + return false; + } + + std::optional> next( + vector_size_t size, + ContinueFuture& future) override { + // Fetch results for all non-exhausted splits that need data. + for (auto& split : splits_) { + if (split.needFetchResult()) { + if (!split.fetchResult(size, future)) { + VELOX_CHECK(future.valid(), "Async return requires a valid future"); + return std::nullopt; + } + } + } + + // Merge rows from all active splits in inputHit order by repeatedly + // picking the split with the smallest current request index and copying + // all its rows for that index. + auto mergedInputHits = allocateIndices(size, pool_); + auto* rawMergedHits = mergedInputHits->asMutable(); + auto mergedOutput = BaseVector::create(outputType_, size, pool_); + + vector_size_t numOutput = 0; + while (numOutput < size) { + int minSplitIndex = -1; + auto minRequestIndex = std::numeric_limits::max(); + for (size_t i = 0; i < splits_.size(); ++i) { + if (splits_[i].hasResult() && + splits_[i].currentRequestIndex() < minRequestIndex) { + minSplitIndex = static_cast(i); + minRequestIndex = splits_[i].currentRequestIndex(); + } + } + if (minSplitIndex < 0) { + // All splits are exhausted with no buffered data. + break; + } + + auto& split = splits_[minSplitIndex]; + VELOX_CHECK_LE(numOutput, size); + numOutput += split.fillResult( + minRequestIndex, + mergedOutput, + numOutput, + rawMergedHits, + size - numOutput); + + // Stop if this split's buffer is consumed but not exhausted. We must + // refill it before continuing to avoid emitting larger inputHits from + // other splits that would violate non-decreasing order across next() + // calls. + if (split.needFetchResult()) { + break; + } + } + + if (numOutput == 0) { + return nullptr; + } + + mergedInputHits->setSize(numOutput * sizeof(vector_size_t)); + mergedOutput->resize(numOutput); + return std::make_unique( + std::move(mergedInputHits), std::move(mergedOutput)); + } + + private: + // Tracks iteration state for a single split's ResultIterator. Buffers + // one Result at a time and tracks the current read position within it. + struct SplitState { + explicit SplitState(std::shared_ptr splitIter) + : iter(std::move(splitIter)) {} + + const std::shared_ptr iter; + // Current buffered result from this split, or nullptr if not yet fetched. + std::unique_ptr result; + // Next row to read within 'result'. + vector_size_t resultOffset{0}; + // True when the underlying iterator has no more results. + bool exhausted{false}; + + // Returns true if there are unconsumed rows in the current buffered result. + bool hasResult() const { + return result != nullptr && resultOffset < result->size(); + } + + // Returns true if the underlying iterator has no more results. + bool hasExhausted() const { + return exhausted; + } + + // Returns true if the split needs to fetch the next result batch. + bool needFetchResult() const { + return !hasResult() && !hasExhausted(); + } + + // Returns the request index (inputHit) of the current row in the buffer. + vector_size_t currentRequestIndex() const { + VELOX_CHECK(hasResult()); + return result->inputHits->as()[resultOffset]; + } + + // Copies buffered rows matching the given request index to the output, + // up to maxRows. Returns the number of rows copied. + vector_size_t fillResult( + vector_size_t requestIndex, + const RowVectorPtr& output, + vector_size_t outputOffset, + vector_size_t* rawHits, + vector_size_t maxRows) { + VELOX_CHECK(hasResult()); + // Count contiguous rows with the same request index. + const auto* hits = result->inputHits->as(); + vector_size_t count = 0; + while (count < maxRows && resultOffset + count < result->size() && + hits[resultOffset + count] == requestIndex) { + ++count; + } + if (count > 0) { + output->copy(result->output.get(), outputOffset, resultOffset, count); + std::fill( + rawHits + outputOffset, + rawHits + outputOffset + count, + requestIndex); + resultOffset += count; + } + return count; + } + + // Fetches the next non-empty result from the underlying iterator. + // Returns true when data is ready (or split is exhausted), false if async. + bool fetchResult(vector_size_t size, ContinueFuture& future) { + VELOX_CHECK( + !hasResult(), "Must consume current result before fetching next"); + while (!exhausted) { + if (!iter->hasNext()) { + exhausted = true; + return true; + } + auto resultOpt = iter->next(size, future); + if (!resultOpt.has_value()) { + return false; + } + auto fetchedResult = std::move(resultOpt).value(); + if (fetchedResult == nullptr) { + exhausted = true; + return true; + } + // Skip empty results (e.g., when all rows are filtered out by + // remaining filter) and continue fetching. + if (fetchedResult->size() == 0) { + continue; + } + result = std::move(fetchedResult); + resultOffset = 0; + return true; + } + VELOX_UNREACHABLE(); + } + }; + + // Output schema used to allocate merged result vectors. + const RowTypePtr outputType_; + memory::MemoryPool* const pool_; + // Per-split state for buffering and tracking iteration progress. Not const + // because elements are mutated during iteration (result, resultOffset, + // exhausted), and const vector makes elements const via const T& access. + std::vector splits_; +}; + +} // namespace + +// Scope-attached timer that accumulates wall and CPU time into the +// corresponding iterationStats fields when the attached block exits. +// Expects a local variable named 'iterationStats' of type +// HiveIndexSource::IterationStats. +// +// Usage: +// RECORD_CPU_WALL(setup) { +// // timed work +// } +#define RECORD_CPU_WALL(name) \ + if (DeltaCpuWallTimer _timer_##name([&](const CpuWallTiming& _t) { \ + iterationStats.name##WallNs += _t.wallNanos; \ + iterationStats.name##CpuNs += _t.cpuNanos; \ + }); \ + true) + +/// Iterates over results from a SplitIndexReader and applies HiveIndexSource's +/// format-agnostic orchestration: remaining filter evaluation and output +/// projection. +class HiveLookupIterator : public IndexSource::ResultIterator { + public: + HiveLookupIterator( + std::shared_ptr indexSource, + SplitIndexReader* indexReader, + IndexSource::Request request, + SplitIndexReader::Options options) + : indexSource_(std::move(indexSource)), + indexReader_(indexReader), + request_(std::move(request)), + options_(options) {} + + bool hasNext() override { + return state_ != State::kEnd; + } + + std::optional> next( + vector_size_t size, + ContinueFuture& /*unused*/) override { + if (state_ == State::kEnd) { + return nullptr; + } + + HiveIndexSource::IterationStats iterationStats; + SCOPE_EXIT { + indexSource_->recordIterationStats(iterationStats); + }; + + // Initialize lookup on first call. + if (state_ == State::kInit) { + RECORD_CPU_WALL(setup) { + indexReader_->startLookup(request_, options_); + } + setState(State::kRead); + } + + if (!indexReader_->hasNext()) { + setState(State::kEnd); + return nullptr; + } + return getOutput(size, iterationStats); + } + + private: + // State of the iterator. + enum class State { + // Initial state after creation. + kInit, + // After lookup request has been set in index reader. + kRead, + // After all data has been read. + kEnd, + }; + + // Sets the state with validation of allowed transitions. + // Allowed transitions: + // kInit -> kRead (when request is set) + // kInit -> kEnd (when no matches on first call) + // kRead -> kEnd (when all data has been read) + void setState(State newState) { + switch (state_) { + case State::kInit: + VELOX_CHECK( + newState == State::kRead || newState == State::kEnd, + "Invalid state transition from {} to {}", + stateName(state_), + stateName(newState)); + break; + case State::kRead: + VELOX_CHECK( + newState == State::kEnd, + "Invalid state transition from {} to {}", + stateName(state_), + stateName(newState)); + break; + case State::kEnd: + VELOX_FAIL( + "Invalid state transition from {} to {}", + stateName(state_), + stateName(newState)); + } + state_ = newState; + } + + static std::string stateName(State state) { + switch (state) { + case State::kInit: + return "kInit"; + case State::kRead: + return "kRead"; + case State::kEnd: + return "kEnd"; + default: + VELOX_UNREACHABLE("Unknown state {}", static_cast(state)); + } + } + + std::unique_ptr getOutput( + vector_size_t size, + HiveIndexSource::IterationStats& iterationStats) { + std::unique_ptr result; + RECORD_CPU_WALL(read) { + result = indexReader_->next(size); + } + if (result == nullptr) { + VELOX_CHECK(!indexReader_->hasNext()); + setState(State::kEnd); + return nullptr; + } + + if (!indexSource_->nonIndexConditions_.empty()) { + result = applyNonIndexCondition(std::move(result)); + } + + if (indexSource_->remainingFilterExprSet_ == nullptr) { + RECORD_CPU_WALL(output) { + result->output = indexSource_->projectOutput( + result->output->size(), nullptr, result->output); + } + } else { + result = evaluateRemainingFilter(std::move(result), iterationStats); + } + return result; + } + + std::unique_ptr applyNonIndexCondition( + std::unique_ptr result) { + BufferPtr passingIndices{nullptr}; + const auto numPassing = indexSource_->applyNonIndexConditions( + request_.input, result->output, result->inputHits, passingIndices); + if (numPassing == 0) { + return getEmptyResult(); + } + if (passingIndices == nullptr) { + return result; + } + result->inputHits = filterIndices( + numPassing, passingIndices, result->inputHits, indexSource_->pool_); + result->output = + exec::wrap(numPassing, std::move(passingIndices), result->output); + return result; + } + + std::unique_ptr evaluateRemainingFilter( + std::unique_ptr result, + HiveIndexSource::IterationStats& iterationStats) { + auto& output = result->output; + vector_size_t numRemainingRows; + RECORD_CPU_WALL(filter) { + numRemainingRows = indexSource_->evaluateRemainingFilter(output); + } + + if (numRemainingRows == 0) { + return getEmptyResult(); + } + + BufferPtr remainingIndices{nullptr}; + if (numRemainingRows != output->size()) { + remainingIndices = indexSource_->remainingFilterEvalCtx_.selectedIndices; + result->inputHits = filterIndices( + numRemainingRows, + remainingIndices, + result->inputHits, + indexSource_->pool_); + } + RECORD_CPU_WALL(output) { + output = indexSource_->projectOutput( + numRemainingRows, remainingIndices, output); + } + return result; + } + + std::unique_ptr getEmptyResult() { + if (emptyResult_ == nullptr) { + emptyResult_ = std::make_unique( + allocateIndices(0, indexSource_->pool_), + indexSource_->getEmptyOutput()); + } + return std::make_unique( + emptyResult_->inputHits, emptyResult_->output); + } + + const std::shared_ptr indexSource_; + // Raw pointer to index reader for lookup operations. + SplitIndexReader* const indexReader_; + const IndexSource::Request request_; + const SplitIndexReader::Options options_; + + State state_{State::kInit}; + // Cached empty result for reuse when no rows pass the remaining filter. + std::unique_ptr emptyResult_; +}; + +HiveIndexSource::HiveIndexSource( + const RowTypePtr& requestType, + const std::vector& indexLookupConditions, + const RowTypePtr& outputType, + HiveTableHandlePtr tableHandle, + const ColumnHandleMap& columnHandles, + FileHandleFactory* fileHandleFactory, + ConnectorQueryCtx* connectorQueryCtx, + const std::shared_ptr& hiveConfig, + folly::Executor* executor) + : fileHandleFactory_(fileHandleFactory), + connectorQueryCtx_(connectorQueryCtx), + hiveConfig_(hiveConfig), + pool_(connectorQueryCtx->memoryPool()), + expressionEvaluator_(connectorQueryCtx->expressionEvaluator()), + maxRowsPerIndexRequest_(hiveConfig_->maxRowsPerIndexRequest( + connectorQueryCtx_->sessionProperties())), + tableHandle_(std::move(tableHandle)), + requestType_(requestType), + outputType_(outputType), + executor_(executor), + ioStatistics_(std::make_shared()), + ioStats_(std::make_shared()) { + init(columnHandles, indexLookupConditions); +} + +void HiveIndexSource::initConditions( + const std::vector& indexLookupConditions, + const ColumnHandleMap& assignments, + const folly::F14FastMap& + columnHandles, + std::vector& readColumnNames, + std::vector& readColumnTypes) { + const auto& indexColumns = tableHandle_->indexColumns(); + const auto& dataColumns = tableHandle_->dataColumns(); + + // Build a map from condition key name to condition for quick lookup. The key + // name in IndexLookupCondition references input column name, we need to + // convert it to the table column name using assignments. + folly::F14FastMap conditionMap; + for (const auto& condition : indexLookupConditions) { + auto convertedCondition = convertConditionKeyName(condition, assignments); + const auto& columnName = convertedCondition->key->name(); + VELOX_USER_CHECK( + conditionMap.emplace(columnName, std::move(convertedCondition)).second, + "Duplicate lookup key found in indexLookupConditions: {}", + columnName); + } + + indexLookupConditions_.reserve(indexColumns.size()); + size_t numValidIndexLookupConditions{0}; + folly::F14FastSet indexConditions; + // Process index columns in order, converting filters to index lookup + // conditions where possible. A range filter/condition stops further + // processing. + for (const auto& indexColumn : indexColumns) { + const common::Subfield subfield(indexColumn); + const auto filterIt = filters_.find(subfield); + const bool hasFilter = filterIt != filters_.end(); + const auto conditionIt = conditionMap.find(indexColumn); + const bool hasIndexLookupCondition = conditionIt != conditionMap.end(); + + // Cannot have both a filter and an index lookup condition on the same + // column. + VELOX_CHECK( + !(hasFilter && hasIndexLookupCondition), + "Cannot have both filter and index lookup condition on index column {}", + indexColumn); + + if (!hasFilter && !hasIndexLookupCondition) { + // No filter or index lookup condition on this column - stop processing. + break; + } + + // Get column type from data columns. + const auto typeIdx = dataColumns->getChildIdxIfExists(indexColumn); + VELOX_CHECK( + typeIdx.has_value(), + "Index column {} not found in data columns", + indexColumn); + + if (hasIndexLookupCondition) { + // Use the existing index lookup condition as-is. + const auto& condition = conditionIt->second; + indexLookupConditions_.push_back(condition); + VELOX_CHECK(!condition->isFilter()); + ++numValidIndexLookupConditions; + indexConditions.insert(indexColumn); + + // Check if this is a range condition (Between) - stops further + // processing. + if (std::dynamic_pointer_cast( + condition)) { + break; + } + continue; + } + + // Has filter - try to convert to index lookup condition. + VELOX_CHECK(hasFilter); + const auto& columnType = dataColumns->childAt(*typeIdx); + const auto* filter = filterIt->second.get(); + // Try point lookup conversion first. + auto pointValue = extractPointLookupValue(filter); + if (pointValue.has_value()) { + auto condition = createEqualConditionWithConstant( + indexColumn, columnType, pointValue.value()); + indexLookupConditions_.push_back(condition); + // Remove converted filter from filters_ map. + filters_.erase(filterIt); + continue; + } + + // Try range conversion. + auto rangeBounds = extractRangeBounds(filter); + if (rangeBounds.has_value()) { + auto condition = createBetweenConditionWithConstants( + indexColumn, columnType, rangeBounds->first, rangeBounds->second); + indexLookupConditions_.push_back(condition); + // Remove converted filter from filters_ map. + filters_.erase(filterIt); + // Range condition stops further processing. + break; + } + + // Filter cannot be converted - leave it in filters_ map and stop + // processing. + break; + } + + // Process remaining conditions not consumed as index conditions. + folly::F14FastSet readColumnNameSet( + readColumnNames.begin(), readColumnNames.end()); + for (const auto& [columnName, condition] : conditionMap) { + if (indexConditions.count(columnName) > 0) { + continue; + } + + // Reject conditions on index columns that weren't consumed — this + // indicates a prefix-gap violation (e.g., u0 and u2 but not u1). + VELOX_CHECK( + std::find(indexColumns.begin(), indexColumns.end(), columnName) == + indexColumns.end(), + "Unprocessed join condition on index column " + "(conditions must follow index column order as a prefix): {}", + columnName); + + // Non-index conditions become post-read equality filters. + const auto equalCondition = + std::dynamic_pointer_cast(condition); + VELOX_CHECK_NOT_NULL( + equalCondition, + "Non-index join condition must be an equal condition: {}", + columnName); + VELOX_CHECK( + !equalCondition->isFilter(), + "Non-index join condition cannot be a constant filter: {}", + columnName); + + // Ensure the column is in the read set. + if (readColumnNameSet.count(columnName) == 0) { + auto handleIt = columnHandles.find(columnName); + VELOX_CHECK( + handleIt != columnHandles.end(), + "Non-index condition column missing from assignments: {}", + columnName); + readColumnNames.emplace_back(columnName); + readColumnTypes.push_back(handleIt->second->dataType()); + readColumnNameSet.insert(columnName); + } + + // Resolve column indices for runtime evaluation. + const auto probeFieldAccess = + checkedPointerCast( + equalCondition->value); + auto outputColumnIt = + std::find(readColumnNames.begin(), readColumnNames.end(), columnName); + const auto outputColumnIndex = + std::distance(readColumnNames.begin(), outputColumnIt); + const auto requestColumnIndex = + requestType_->getChildIdx(probeFieldAccess->name()); + nonIndexConditions_.push_back( + {static_cast(outputColumnIndex), + static_cast(requestColumnIndex)}); + } + + VELOX_CHECK_EQ( + numValidIndexLookupConditions + nonIndexConditions_.size(), + indexLookupConditions.size(), + "Not all join conditions were processed"); + + VELOX_CHECK( + !indexLookupConditions_.empty(), + "No index column join conditions found. At least one must be an " + "index column"); +} + +void HiveIndexSource::initRemainingFilter( + std::vector& readColumnNames, + std::vector& readColumnTypes) { + VELOX_CHECK_NULL(remainingFilterExprSet_); + double sampleRate = tableHandle_->sampleRate(); + auto remainingFilter = extractFiltersFromRemainingFilter( + tableHandle_->remainingFilter(), + expressionEvaluator_, + filters_, + sampleRate); + // TODO: support sample rate later. + VELOX_CHECK_EQ( + sampleRate, 1, "Sample rate is not supported for index source"); + + if (remainingFilter == nullptr) { + return; + } + + remainingFilterExprSet_ = expressionEvaluator_->compile(remainingFilter); + + const auto& remainingFilterExpr = remainingFilterExprSet_->expr(0); + folly::F14FastMap columnNames; + for (int i = 0; i < readColumnNames.size(); ++i) { + columnNames[readColumnNames[i]] = i; + } + for (const auto* input : remainingFilterExpr->distinctFields()) { + auto it = columnNames.find(input->field()); + if (it != columnNames.end()) { + if (shouldEagerlyMaterialize(*remainingFilterExpr, *input)) { + remainingEagerlyLoadFields_.push_back(it->second); + } + continue; + } + // Remaining filter may reference columns that are not used otherwise, + // e.g. are not being projected out and are not used in range filters. + // Make sure to add these columns to readerOutputType_. + readColumnNames.push_back(input->field()); + readColumnTypes.push_back(input->type()); + } + + remainingFilterSubfields_ = remainingFilterExpr->extractSubfields(); + for (auto& subfield : remainingFilterSubfields_) { + const auto& name = getColumnName(subfield); + auto it = projectedSubfields_.find(name); + if (it != projectedSubfields_.end()) { + // Some subfields of the column are already projected out, we append + // the remainingFilter subfield + it->second.push_back(&subfield); + } else if (columnNames.count(name) == 0) { + // remainingFilter subfield's column is not projected out, we add the + // column and append the subfield + projectedSubfields_[name].push_back(&subfield); + } + } +} + +void HiveIndexSource::init( + const ColumnHandleMap& assignments, + const std::vector& indexLookupConditions) { + VELOX_CHECK_NOT_NULL(tableHandle_); + + folly::F14FastMap columnHandles; + // Column handles keyed on the table column name. + for (const auto& [_, columnHandle] : assignments) { + auto handle = checkedPointerCast(columnHandle); + const auto [it, unique] = + columnHandles.emplace(handle->name(), handle.get()); + VELOX_CHECK(unique, "Duplicate column handle for {}", handle->name()); + // Allow regular and partition key columns. Partition keys are not read + // from files — their values are synthesized from split metadata. + VELOX_CHECK( + handle->columnType() == HiveColumnHandle::ColumnType::kRegular || + handle->columnType() == HiveColumnHandle::ColumnType::kPartitionKey, + "Unsupported column type {} for column {}", + HiveColumnHandle::columnTypeName(handle->columnType()), + handle->name()); + if (handle->columnType() == HiveColumnHandle::ColumnType::kPartitionKey) { + partitionKeyHandles_.emplace(handle->name(), handle); + } + } + + for (const auto& handle : tableHandle_->filterColumnHandles()) { + auto it = columnHandles.find(handle->name()); + if (it != columnHandles.end()) { + checkColumnHandleConsistent(*handle, *it->second); + continue; + } + checkColumnHandleIsRegular(*handle); + } + + std::vector readColumnNames; + auto readColumnTypes = outputType_->children(); + for (const auto& outputName : outputType_->names()) { + auto it = assignments.find(outputName); + VELOX_CHECK( + it != assignments.end(), + "ColumnHandle is missing for output column: {}", + outputName); + + auto handle = checkedPointerCast(it->second); + readColumnNames.push_back(handle->name()); + if (handle->columnType() == HiveColumnHandle::ColumnType::kPartitionKey) { + // Subfield projection / postProcessor checks below don't apply to + // partition columns; their values come from split metadata. + continue; + } + + for (auto& subfield : handle->requiredSubfields()) { + VELOX_USER_CHECK_EQ( + getColumnName(subfield), + handle->name(), + "Required subfield does not match column name"); + projectedSubfields_[handle->name()].push_back(&subfield); + } + VELOX_CHECK_NULL(handle->postProcessor(), "Post processor not supported"); + } + + if (hiveConfig_->isFileColumnNamesReadAsLowerCase( + connectorQueryCtx_->sessionProperties())) { + checkColumnNameLowerCase(outputType_); + checkColumnNameLowerCase(tableHandle_->subfieldFilters(), {}); + checkColumnNameLowerCase(tableHandle_->remainingFilter()); + } + + for (const auto& [subfield, filter] : tableHandle_->subfieldFilters()) { + filters_.emplace(subfield.clone(), filter); + } + + initRemainingFilter(readColumnNames, readColumnTypes); + + initConditions( + indexLookupConditions, + assignments, + columnHandles, + readColumnNames, + readColumnTypes); + + readerOutputType_ = + ROW(std::move(readColumnNames), std::move(readColumnTypes)); + scanSpec_ = makeScanSpec( + readerOutputType_, + projectedSubfields_, + filters_, + /*indexColumns=*/{}, + tableHandle_->dataColumns(), + partitionKeyHandles_, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + hiveConfig_->readStatsBasedFilterReorderDisabled( + connectorQueryCtx_->sessionProperties()), + pool_); +} + +void HiveIndexSource::setPartitionValues( + const std::vector>& splits) { + if (partitionKeyHandles_.empty() || splits.empty()) { + return; + } + const auto firstSplit = + checkedPointerCast(splits[0]); + const auto& partitionValues = firstSplit->partitionKeys; + + // Sanity check all splits share the same partition values. + for (size_t i = 1; i < splits.size(); ++i) { + const auto& split = checkedPointerCast(splits[i]); + VELOX_CHECK( + split->partitionKeys == partitionValues, + "Split {} partition values differ from split 0: split[0].path={}, split[{}].path={}", + i, + firstSplit->filePath, + i, + split->filePath); + } + + // Iterate scan spec children and set constant values for partition columns + // found in the split metadata. + auto& childrenSpecs = scanSpec_->children(); + for (size_t i = 0; i < childrenSpecs.size(); ++i) { + auto* childSpec = childrenSpecs[i].get(); + const auto& fieldName = childSpec->fieldName(); + + if (auto partitionIt = partitionValues.find(fieldName); + partitionIt != partitionValues.end()) { + setPartitionValue(childSpec, fieldName, partitionIt->second); + } + } +} + +void HiveIndexSource::setPartitionValue( + common::ScanSpec* spec, + const std::string& partitionKey, + const std::optional& value) const { + const auto it = partitionKeyHandles_.find(partitionKey); + VELOX_CHECK( + it != partitionKeyHandles_.end(), + "ColumnHandle is missing for partition key {}", + partitionKey); + const auto type = it->second->dataType(); + const auto constant = newConstantFromString( + type, + value, + pool_, + hiveConfig_->readTimestampPartitionValueAsLocalTime( + connectorQueryCtx_->sessionProperties()), + it->second->isPartitionDateValueDaysSinceEpoch()); + spec->setConstantValue(constant); +} + +void HiveIndexSource::addSplits( + std::vector> splits) { + VELOX_CHECK( + readers_.empty(), + "addSplits can only be called once for HiveIndexSource"); + + setPartitionValues(splits); + + auto* registry = IndexReaderFactoryRegistry::getInstance(); + for (auto& split : splits) { + auto hiveSplit = checkedPointerCast(split); + const auto* factory = registry->getFactory(hiveSplit->fileFormat); + if (factory != nullptr) { + createCustomIndexReader(*factory, std::move(hiveSplit)); + } else { + VELOX_CHECK( + hiveSplit->fileFormat == dwio::common::FileFormat::NIMBLE || + hiveSplit->fileFormat == dwio::common::FileFormat::FLUX || + hiveSplit->fileFormat == dwio::common::FileFormat::SST, + "No IndexReaderFactory registered for format: {}", + dwio::common::toString(hiveSplit->fileFormat)); + createFileIndexReader(std::move(hiveSplit)); + } + } + + VELOX_CHECK(!readers_.empty(), "No index readers created from splits"); +} + +std::shared_ptr HiveIndexSource::lookup( + const Request& request) { + VELOX_CHECK(!readers_.empty(), "No index readers available for lookup"); + + auto options = SplitIndexReader::Options{ + .maxRowsPerRequest = static_cast(maxRowsPerIndexRequest_)}; + + if (readers_.size() == 1) { + return std::make_shared( + shared_from_this(), readers_[0].get(), request, options); + } + + return createUnionLookupIterator(request, options); +} + +std::shared_ptr +HiveIndexSource::createUnionLookupIterator( + const Request& request, + const SplitIndexReader::Options& options) { + // Wraps each reader in a HiveLookupIterator and unions their results + // via UnionResultIterator. Each reader searches its own split + // independently with the same request. + VELOX_CHECK_GT(readers_.size(), 1, "Union requires at least two readers"); + std::vector> splitIters; + splitIters.reserve(readers_.size()); + for (auto& reader : readers_) { + splitIters.push_back( + std::make_shared( + shared_from_this(), reader.get(), request, options)); + } + return std::make_shared( + std::move(splitIters), outputType_, pool_); +} + +std::unordered_map HiveIndexSource::runtimeStats() { + // Start with accumulated per-call timing stats. + auto stats = runtimeStats_; + + if (remainingFilterTimeNs_ != 0) { + stats[std::string(Connector::kTotalRemainingFilterTime)] = + RuntimeMetric(remainingFilterTimeNs_, RuntimeCounter::Unit::kNanos); + } + + // Merge stats from all readers. + for (auto& reader : readers_) { + for (auto& [key, metric] : reader->runtimeStats()) { + auto it = stats.find(key); + if (it != stats.end()) { + it->second.merge(metric); + } else { + stats.emplace(key, metric); + } + } + } + // Add I/O stats from ioStatistics_ (storage read, ram cache, ssd cache). + if (ioStatistics_) { + const auto& read = ioStatistics_->read(); + if (read.count() > 0) { + stats[std::string(FileDataSource::kStorageReadBytes)] = RuntimeMetric( + static_cast(read.sum()), + read.count(), + static_cast(read.min()), + static_cast(read.max()), + RuntimeCounter::Unit::kBytes); + } + const auto& ramHit = ioStatistics_->ramHit(); + if (ramHit.count() > 0) { + stats[std::string(FileDataSource::kNumRamRead)] = + RuntimeMetric(static_cast(ramHit.count())); + stats[std::string(FileDataSource::kRamReadBytes)] = RuntimeMetric( + static_cast(ramHit.sum()), + ramHit.count(), + static_cast(ramHit.min()), + static_cast(ramHit.max()), + RuntimeCounter::Unit::kBytes); + } + const auto& ssdRead = ioStatistics_->ssdRead(); + if (ssdRead.count() > 0) { + stats[std::string(FileDataSource::kNumLocalRead)] = + RuntimeMetric(static_cast(ssdRead.count())); + stats[std::string(FileDataSource::kLocalReadBytes)] = RuntimeMetric( + static_cast(ssdRead.sum()), + ssdRead.count(), + static_cast(ssdRead.min()), + static_cast(ssdRead.max()), + RuntimeCounter::Unit::kBytes); + } + } + + return stats; +} + +vector_size_t HiveIndexSource::applyNonIndexConditions( + const RowVectorPtr& request, + const RowVectorPtr& output, + const BufferPtr& inputHits, + BufferPtr& passingIndices) { + const auto numRows = output->size(); + const auto* hits = inputHits->as(); + + SelectivityVector passingRows(numRows, true); + + for (const auto& condition : nonIndexConditions_) { + const auto& outputVector = output->childAt(condition.outputColumnIndex); + const auto& requestVector = request->childAt(condition.requestColumnIndex); + for (vector_size_t row = 0; row < numRows; ++row) { + const auto requestRow = hits[row]; + // Either null means not equal (standard SQL join semantics). + if (outputVector->isNullAt(row) || requestVector->isNullAt(requestRow) || + !outputVector->equalValueAt(requestVector.get(), row, requestRow)) { + passingRows.setValid(row, false); + } + } + } + passingRows.updateBounds(); + + const auto numPassing = passingRows.countSelected(); + if (numPassing == numRows) { + return numRows; + } + + passingIndices = allocateIndices(numPassing, pool_); + auto* rawIndices = passingIndices->asMutable(); + vector_size_t numSelected = 0; + passingRows.applyToSelected( + [&](vector_size_t row) { rawIndices[numSelected++] = row; }); + VELOX_CHECK_EQ(numSelected, numPassing); + return numPassing; +} + +vector_size_t HiveIndexSource::evaluateRemainingFilter( + RowVectorPtr& rowVector) { + remainingFilterRows_.resize(rowVector->size()); + for (auto fieldIndex : remainingEagerlyLoadFields_) { + LazyVector::ensureLoadedRows( + rowVector->childAt(fieldIndex), + remainingFilterRows_, + remainingFilterLazyDecoded_, + remainingFilterLazyBaseRows_); + } + + uint64_t filterTimeNs{0}; + vector_size_t numRemainingRows{0}; + { + NanosecondTimer timer(&filterTimeNs); + expressionEvaluator_->evaluate( + remainingFilterExprSet_.get(), + remainingFilterRows_, + *rowVector, + remainingFilterResult_); + numRemainingRows = exec::processFilterResults( + remainingFilterResult_, + remainingFilterRows_, + remainingFilterEvalCtx_, + pool_); + } + remainingFilterTimeNs_ += filterTimeNs; + return numRemainingRows; +} + +void HiveIndexSource::recordIterationStats( + const IterationStats& iterationStats) { + const auto totalWallNs = iterationStats.setupWallNs + + iterationStats.readWallNs + iterationStats.outputWallNs + + iterationStats.filterWallNs; + if (totalWallNs != 0) { + exec::addOperatorRuntimeStats( + IterationStats::kConnectorLookupWallNanos, + RuntimeCounter( + static_cast(totalWallNs), RuntimeCounter::Unit::kNanos), + runtimeStats_); + } + if (iterationStats.setupWallNs != 0) { + exec::addOperatorRuntimeStats( + IterationStats::kIndexSetupWallNanos, + RuntimeCounter( + static_cast(iterationStats.setupWallNs), + RuntimeCounter::Unit::kNanos), + runtimeStats_); + } + if (iterationStats.setupCpuNs != 0) { + exec::addOperatorRuntimeStats( + IterationStats::kIndexSetupCpuNanos, + RuntimeCounter( + static_cast(iterationStats.setupCpuNs), + RuntimeCounter::Unit::kNanos), + runtimeStats_); + } + if (iterationStats.readWallNs != 0) { + exec::addOperatorRuntimeStats( + IterationStats::kIndexReadWallNanos, + RuntimeCounter( + static_cast(iterationStats.readWallNs), + RuntimeCounter::Unit::kNanos), + runtimeStats_); + } + if (iterationStats.readCpuNs != 0) { + exec::addOperatorRuntimeStats( + IterationStats::kIndexReadCpuNanos, + RuntimeCounter( + static_cast(iterationStats.readCpuNs), + RuntimeCounter::Unit::kNanos), + runtimeStats_); + } + if (iterationStats.outputCpuNs != 0) { + exec::addOperatorRuntimeStats( + IterationStats::kConnectorResultPrepareCpuNanos, + RuntimeCounter( + static_cast(iterationStats.outputCpuNs), + RuntimeCounter::Unit::kNanos), + runtimeStats_); + } + if (iterationStats.filterWallNs != 0) { + exec::addOperatorRuntimeStats( + IterationStats::kPostFilterWallNanos, + RuntimeCounter( + static_cast(iterationStats.filterWallNs), + RuntimeCounter::Unit::kNanos), + runtimeStats_); + } + if (iterationStats.filterCpuNs != 0) { + exec::addOperatorRuntimeStats( + IterationStats::kPostFilterCpuNanos, + RuntimeCounter( + static_cast(iterationStats.filterCpuNs), + RuntimeCounter::Unit::kNanos), + runtimeStats_); + } +} + +RowVectorPtr HiveIndexSource::projectOutput( + vector_size_t numRows, + const BufferPtr& remainingIndices, + const RowVectorPtr& rowVector) { + if (outputType_->size() == 0) { + return exec::wrap(numRows, remainingIndices, rowVector); + } + + std::vector outputColumns; + outputColumns.reserve(outputType_->size()); + for (int i = 0; i < outputType_->size(); ++i) { + auto& child = rowVector->childAt(i); + if (remainingIndices) { + // Disable dictionary values caching in expression eval so that we + // don't need to reallocate the result for every batch. + child->disableMemo(); + } + auto column = exec::wrapChild(numRows, remainingIndices, child); + outputColumns.push_back(std::move(column)); + } + + return std::make_shared( + pool_, outputType_, BufferPtr(nullptr), numRows, outputColumns); +} + +void HiveIndexSource::createCustomIndexReader( + const IndexReaderFactory& factory, + std::shared_ptr split) { + VELOX_CHECK_NOT_NULL(split); + auto reader = factory(split, tableHandle_, connectorQueryCtx_); + VELOX_CHECK_NOT_NULL( + reader, + "IndexReaderFactory returned null for format: {}", + dwio::common::toString(split->fileFormat)); + readers_.push_back(std::move(reader)); +} + +void HiveIndexSource::createFileIndexReader( + std::shared_ptr split) { + VELOX_CHECK_NOT_NULL(split); + readers_.push_back( + std::make_unique( + std::move(split), + tableHandle_, + connectorQueryCtx_, + hiveConfig_, + scanSpec_, + indexLookupConditions_, + requestType_, + readerOutputType_, + ioStatistics_, + ioStats_, + fileHandleFactory_, + executor_, + maxRowsPerIndexRequest_)); +} + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/HiveIndexSource.h b/velox/connectors/hive/HiveIndexSource.h new file mode 100644 index 00000000000..40e2eed4047 --- /dev/null +++ b/velox/connectors/hive/HiveIndexSource.h @@ -0,0 +1,274 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include "velox/common/io/IoStatistics.h" +#include "velox/connectors/Connector.h" +#include "velox/connectors/hive/FileHandle.h" +#include "velox/connectors/hive/HiveConnectorSplit.h" +#include "velox/connectors/hive/IndexReader.h" +#include "velox/connectors/hive/TableHandle.h" +#include "velox/core/PlanNode.h" +#include "velox/dwio/common/ScanSpec.h" +#include "velox/exec/OperatorUtils.h" +#include "velox/expression/Expr.h" +#include "velox/vector/DecodedVector.h" + +namespace facebook::velox::connector::hive { + +class HiveConfig; + +/// Provides index lookup for Hive-metastore-backed tables. +/// +/// Owns format-agnostic orchestration: remaining filter evaluation and output +/// projection. Delegates format-specific reads to either the built-in +/// FileIndexReader (for Nimble) or external SplitIndexReader implementations +/// registered via IndexReaderFactoryRegistry. Supports partitioned tables by +/// synthesizing partition column values from split metadata via scan-spec +/// constants. +class HiveIndexSource : public IndexSource, + public std::enable_shared_from_this { + public: + HiveIndexSource( + const RowTypePtr& requestType, + const std::vector& indexLookupConditions, + const RowTypePtr& outputType, + HiveTableHandlePtr tableHandle, + const ColumnHandleMap& columnHandles, + FileHandleFactory* fileHandleFactory, + ConnectorQueryCtx* connectorQueryCtx, + const std::shared_ptr& hiveConfig, + folly::Executor* executor); + + ~HiveIndexSource() override = default; + + void addSplits(std::vector> splits) override; + + std::shared_ptr lookup(const Request& request) override; + + std::unordered_map runtimeStats() override; + + private: + friend class HiveLookupIterator; + + memory::MemoryPool* pool() const { + return pool_; + } + + const RowTypePtr& outputType() const { + return outputType_; + } + + const RowVectorPtr& getEmptyOutput() { + if (!emptyOutput_) { + emptyOutput_ = RowVector::createEmpty(outputType_, pool_); + } + return emptyOutput_; + } + + // Evaluates remainingFilter on the specified vector. Returns number of rows + // passed. Populates filterEvalCtx_.selectedIndices and selectedBits if only + // some rows passed the filter. If none or all rows passed + // filterEvalCtx_.selectedIndices and selectedBits are not updated. + vector_size_t evaluateRemainingFilter(RowVectorPtr& rowVector); + + // Applies non-index equi-join conditions as post-read equality filters. + // Compares request (probe) values against reader output column values + // for each non-index join condition. Returns the number of rows that pass + // all conditions. Populates 'passingIndices' with the indices of rows that + // pass when not all rows pass. + vector_size_t applyNonIndexConditions( + const RowVectorPtr& request, + const RowVectorPtr& output, + const BufferPtr& inputHits, + BufferPtr& passingIndices); + + // Projects output from reader output to the final output type. Wraps child + // vectors with remaining filter indices if provided. Partition + // columns are emitted by the reader directly via setConstantValue() on + // the scan spec, so they are projected like any other column here. + RowVectorPtr projectOutput( + vector_size_t numRows, + const BufferPtr& remainingIndices, + const RowVectorPtr& rowVector); + + // Sets partition column constants on scanSpec_ from the first split's + // partition values. No-op if there are no partition columns or no splits. + void setPartitionValues( + const std::vector>& splits); + + // Sets a constant partition value on the scanSpec for a partition column. + void setPartitionValue( + common::ScanSpec* spec, + const std::string& partitionKey, + const std::optional& value) const; + + void init( + const ColumnHandleMap& assignments, + const std::vector& indexLookupConditions); + + // Initializes all join conditions: + // - Index conditions are pushed to indexLookupConditions_ (filters on + // index columns are converted to index lookup conditions). + // - Non-index conditions are resolved to nonIndexConditions_ for + // post-read equality filtering, and their columns are added to + // readColumnNames/readColumnTypes if not already present. + void initConditions( + const std::vector& indexLookupConditions, + const ColumnHandleMap& assignments, + const folly::F14FastMap& + columnHandles, + std::vector& readColumnNames, + std::vector& readColumnTypes); + + // Initializes the remaining filter: + // - Compiles the remaining filter expression. + // - Identifies columns to eagerly materialize. + // - Adds columns referenced by remaining filter but not projected. + // - Extracts and processes subfields from the remaining filter. + void initRemainingFilter( + std::vector& readColumnNames, + std::vector& readColumnTypes); + + // Creates a UnionResultIterator that unions results from all readers. + std::shared_ptr createUnionLookupIterator( + const Request& request, + const SplitIndexReader::Options& options); + + // Creates a SplitIndexReader using a registered IndexReaderFactory. + void createCustomIndexReader( + const IndexReaderFactory& factory, + std::shared_ptr split); + + // Creates a FileIndexReader for a single split. + void createFileIndexReader(std::shared_ptr split); + + // Per-iteration timing breakdown for index lookups. + struct IterationStats { + uint64_t setupWallNs{0}; + uint64_t setupCpuNs{0}; + uint64_t readWallNs{0}; + uint64_t readCpuNs{0}; + uint64_t outputWallNs{0}; + uint64_t outputCpuNs{0}; + uint64_t filterWallNs{0}; + uint64_t filterCpuNs{0}; + + static constexpr std::string_view kConnectorLookupWallNanos{ + "connectorLookupWallNanos"}; + static constexpr std::string_view kConnectorResultPrepareCpuNanos{ + "connectorResultPrepareCpuNanos"}; + static constexpr std::string_view kIndexSetupWallNanos{ + "connectorIndexSetupWallNanos"}; + static constexpr std::string_view kIndexSetupCpuNanos{ + "connectorIndexSetupCpuNanos"}; + static constexpr std::string_view kIndexReadWallNanos{ + "connectorIndexReadWallNanos"}; + static constexpr std::string_view kIndexReadCpuNanos{ + "connectorIndexReadCpuNanos"}; + static constexpr std::string_view kPostFilterWallNanos{ + "connectorPostFilterWallNanos"}; + static constexpr std::string_view kPostFilterCpuNanos{ + "connectorPostFilterCpuNanos"}; + }; + + // Records per-iteration timing breakdown using addOperatorRuntimeStats to + // preserve per-call count/min/max granularity. + void recordIterationStats(const IterationStats& iterationStats); + + FileHandleFactory* const fileHandleFactory_; + ConnectorQueryCtx* const connectorQueryCtx_; + const std::shared_ptr hiveConfig_; + memory::MemoryPool* const pool_; + core::ExpressionEvaluator* const expressionEvaluator_; + const uint32_t maxRowsPerIndexRequest_; + + const HiveTableHandlePtr tableHandle_; + const RowTypePtr requestType_; + const RowTypePtr outputType_; + folly::Executor* const executor_; + + // All index lookup conditions including equal lookup keys converted to + // EqualIndexLookupConditions and original non-filter index lookup conditions. + // This is passed to FileIndexReader. + std::vector indexLookupConditions_; + + // Non-index equi-join conditions: join keys that are not index columns + // (e.g., bucket columns used for colocated joins). Applied as post-read + // equality filters during lookup. + struct NonIndexCondition { + column_index_t outputColumnIndex; + column_index_t requestColumnIndex; + }; + std::vector nonIndexConditions_; + + // Partition column handles, keyed by table column name (handle->name()). + // Populated from kPartitionKey assignments in init(). Used to feed + // partition handles into makeScanSpec() and to look up dataType() when + // parsing partition value strings into typed constants. + std::unordered_map partitionKeyHandles_; + + // Filters for pushdown. Includes subfield filters from table handle and + // filters converted from constant index lookup conditions. + common::SubfieldFilters filters_; + + // Remaining filter expression set after filter pushdown. + std::unique_ptr remainingFilterExprSet_; + // Subfields referenced by the remaining filter. + std::vector remainingFilterSubfields_; + // Total time spent on evaluating remaining filter in nanoseconds. + uint64_t remainingFilterTimeNs_{0}; + // Field indices referenced in both remaining filter and output type. These + // columns need to be materialized eagerly to avoid missing values in output. + std::vector remainingEagerlyLoadFields_; + + // Filter evaluation state for remaining filter. + SelectivityVector remainingFilterRows_; + VectorPtr remainingFilterResult_; + DecodedVector remainingFilterLazyDecoded_; + SelectivityVector remainingFilterLazyBaseRows_; + exec::FilterEvalCtx remainingFilterEvalCtx_; + + // Points to subfields from both output columns' required subfields and + // remainingFilterSubfields_. Used to tell the reader which nested fields to + // project out. + folly::F14FastMap> + projectedSubfields_; + + // Scan spec for the index reader. + std::shared_ptr scanSpec_; + // Output type for the index reader. Includes partition columns when they + // appear in outputType_; the per-split reader emits their values as + // constants via the scan spec's setConstantValue(). + RowTypePtr readerOutputType_; + // All index readers (both built-in and external). FileIndexReader + // (Nimble) and external readers both implement SplitIndexReader. + // Created by addSplits(). + std::vector> readers_; + + // Cached empty output vector. + RowVectorPtr emptyOutput_; + + std::shared_ptr ioStatistics_; + std::shared_ptr ioStats_; + + // Per-call timing stats accumulated via addOperatorRuntimeStats(). + std::unordered_map runtimeStats_; +}; + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/HivePartitionFunction.cpp b/velox/connectors/hive/HivePartitionFunction.cpp index d273cc8163e..548aff99e15 100644 --- a/velox/connectors/hive/HivePartitionFunction.cpp +++ b/velox/connectors/hive/HivePartitionFunction.cpp @@ -15,6 +15,8 @@ */ #include "velox/connectors/hive/HivePartitionFunction.h" +#include + namespace facebook::velox::connector::hive { namespace { @@ -31,8 +33,7 @@ int32_t hashInt64(int64_t value) { __attribute__((no_sanitize("integer"))) #endif #endif -uint32_t -hashBytes(StringView bytes, int32_t initialValue) { +uint32_t hashBytes(StringView bytes, int32_t initialValue) { uint32_t hash = initialValue; auto* data = bytes.data(); for (auto i = 0; i < bytes.size(); ++i) { @@ -461,7 +462,7 @@ HivePartitionFunction::HivePartitionFunction( std::vector keyChannels, const std::vector& constValues) : numBuckets_{numBuckets}, - bucketToPartition_{bucketToPartition}, + bucketToPartition_{std::move(bucketToPartition)}, keyChannels_{std::move(keyChannels)} { precomputedHashes_.resize(keyChannels_.size()); size_t constChannel{0}; @@ -495,7 +496,7 @@ std::optional HivePartitionFunction::partition( } } - static const int32_t kInt32Max = std::numeric_limits::max(); + static constexpr int32_t kInt32Max = std::numeric_limits::max(); if (bucketToPartition_.empty()) { // NOTE: if bucket to partition mapping is empty, then we do diff --git a/velox/connectors/hive/HivePartitionName.cpp b/velox/connectors/hive/HivePartitionName.cpp new file mode 100644 index 00000000000..2d7e7b9b665 --- /dev/null +++ b/velox/connectors/hive/HivePartitionName.cpp @@ -0,0 +1,102 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/HivePartitionName.h" +#include "velox/common/encode/Base64.h" +#include "velox/dwio/catalog/fbhive/FileUtils.h" +#include "velox/type/DecimalUtil.h" + +namespace facebook::velox::connector::hive { + +using namespace facebook::velox::dwio::catalog::fbhive; + +namespace { + +template +std::string formatDecimal(T value, const TypePtr& type) { + const auto& [p, s] = getDecimalPrecisionScale(*type); + const auto& maxSize = DecimalUtil::maxStringViewSize(p, s); + std::string buffer(maxSize, '\0'); + const auto& actualSize = + DecimalUtil::castToString(value, s, maxSize, buffer.data()); + buffer.resize(actualSize); + return buffer; +} + +} // namespace + +std::string HivePartitionName::toName(int32_t value, const TypePtr& type) { + if (type->isDate()) { + return DateType::toIso8601(value); + } + return fmt::to_string(value); +} + +std::string HivePartitionName::toName(int64_t value, const TypePtr& type) { + if (type->isShortDecimal()) { + return formatDecimal(value, type); + } + return fmt::to_string(value); +} + +std::string HivePartitionName::toName(int128_t value, const TypePtr& type) { + if (type->isLongDecimal()) { + return formatDecimal(value, type); + } + return fmt::to_string(value); +} + +std::string HivePartitionName::toName(Timestamp value, const TypePtr& type) { + value.toTimezone(Timestamp::defaultTimezone()); + TimestampToStringOptions options; + options.dateTimeSeparator = ' '; + // Set the precision to milliseconds, and enable the skipTrailingZeros match + // the timestamp precision and truncation behavior of Presto. + options.precision = TimestampPrecision::kMilliseconds; + options.skipTrailingZeros = true; + + auto result = value.toString(options); + + // Presto's java.sql.Timestamp.toString() always keeps at least one decimal + // place even when all fractional seconds are zero. + // If skipTrailingZeros removed all fractional digits, add back ".0" to match + // Presto's behavior. + if (auto dotPos = result.find_last_of('.'); dotPos == std::string::npos) { + // No decimal point found, add ".0" + result += ".0"; + } + + return result; +} + +std::string HivePartitionName::partitionName( + uint32_t partitionId, + const RowVectorPtr& partitionValues, + bool partitionKeyAsLowerCase) { + auto toPartitionName = + [](auto value, const TypePtr& type, int /*columnIndex*/) { + return HivePartitionName::toName(value, type); + }; + return FileUtils::makePartName( + partitionKeyValues( + partitionId, + partitionValues, + /*nullValueString=*/"", + toPartitionName), + partitionKeyAsLowerCase); +} + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/HivePartitionName.h b/velox/connectors/hive/HivePartitionName.h new file mode 100644 index 00000000000..3e519528866 --- /dev/null +++ b/velox/connectors/hive/HivePartitionName.h @@ -0,0 +1,172 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include "velox/vector/ComplexVector.h" +#include "velox/vector/SimpleVector.h" + +namespace facebook::velox::connector::hive { + +/// Converting partition values to their string representations. +/// Provides template methods for formatting different data types according to +/// Hive partitioning conventions. +class HivePartitionName { + public: + /// Generic template for formatting partition values to strings using + /// fmt::to_string. Specialized for types that need special handling + /// (int32_t, int64_t, int128_t, Timestamp). + template + FOLLY_ALWAYS_INLINE static std::string toName(T value, const TypePtr& type) { + return fmt::to_string(value); + } + + /// Format int32_t partition values. Specialized to handle DATE type which + /// requires ISO-8601 formatting (YYYY-MM-DD) instead of raw integer value. + static std::string toName(int32_t value, const TypePtr& type); + + /// Format int64_t partition values. Specialized to handle short DECIMAL type + /// which requires decimal string formatting with proper precision and scale + /// instead of raw integer value. + static std::string toName(int64_t value, const TypePtr& type); + + /// Format int128_t partition values. Specialized to handle long DECIMAL type + /// which requires decimal string formatting with proper precision and scale + /// instead of raw integer value. + static std::string toName(int128_t value, const TypePtr& type); + + /// Format Timestamp partition values. Specialized to: + /// 1. Convert to default timezone + /// 2. Use space as date-time separator (not 'T') + /// 3. Use millisecond precision with trailing zeros skipped + /// 4. Always keep at least ".0" for fractional seconds (Presto compatibility) + static std::string toName(Timestamp value, const TypePtr& type); + + /// Build partition key-value pairs from partition values. + /// Returns a vector of (key, value) pairs for all partition columns. + /// @tparam F A callable that converts a value to a partition string. + /// Takes (value, type, columnIndex) and returns string. + /// @param partitionId The partition ID (row index) to extract values from. + /// @param partitionValues RowVector containing partition values. + /// @param nullValueString The string to use for null values. + /// @param toPartitionName Callable to convert a value to a string. + template + static std::vector> partitionKeyValues( + uint32_t partitionId, + const RowVectorPtr& partitionValues, + const std::string& nullValueString, + const F& toPartitionName); + + /// Generate a Hive partition directory name from partition values for + /// partitionId. + /// + /// @param partitionId The row index in partitionValues to extract values + /// from. + /// @param partitionValues RowVector containing partition values. Each + /// child vector represents a partition column, and the row at + /// partitionId contains the values for this partition. + /// @param partitionKeyAsLowerCase Controls whether partition column names + /// should be converted to lowercase in the output. When true, column + /// names are lowercased (e.g., "year=2025"); when false, original + /// casing is preserved (e.g., "Year=2025"). + /// @return A formatted partition directory name string. Null values are + /// represented as __HIVE_DEFAULT_PARTITION__. + static std::string partitionName( + uint32_t partitionId, + const RowVectorPtr& partitionValues, + bool partitionKeyAsLowerCase); +}; + +namespace detail { + +// Unified template function to extract partition key-value string from a +// vector. Used by both Hive and Iceberg partition name generators. +// +// @tparam Kind The TypeKind of the partition column. +// @tparam F A callable that converts a value to a partition string. +// @param partitionVector The vector containing partition values. +// @param row The row index to extract the value from. +// @param type The type of the partition column. +// @param columnIndex The column index in the partition values. +// @param toPartitionName Callable to convert a value to a partition string. +// @return A pair of (column_name, formatted_value). +template +std::string makePartitionKeyValueString( + const BaseVector& partitionVector, + vector_size_t row, + const TypePtr& type, + int columnIndex, + const F& toPartitionName) { + using T = typename TypeTraits::NativeType; + + return toPartitionName( + partitionVector.as>()->valueAt(row), type, columnIndex); +} + +#define PARTITION_TYPE_DISPATCH(TEMPLATE_FUNC, typeKind, ...) \ + [&]() { \ + switch (typeKind) { \ + case TypeKind::BOOLEAN: \ + case TypeKind::TINYINT: \ + case TypeKind::SMALLINT: \ + case TypeKind::INTEGER: \ + case TypeKind::BIGINT: \ + case TypeKind::VARCHAR: \ + case TypeKind::VARBINARY: \ + case TypeKind::TIMESTAMP: \ + return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( \ + TEMPLATE_FUNC, typeKind, __VA_ARGS__); \ + default: \ + VELOX_UNSUPPORTED( \ + "Unsupported partition type: {}", TypeKindName::toName(typeKind)); \ + } \ + }() + +} // namespace detail + +template +std::vector> +HivePartitionName::partitionKeyValues( + uint32_t partitionId, + const RowVectorPtr& partitionValues, + const std::string& nullValueString, + const F& toPartitionName) { + std::vector> partitionKeyValuePairs; + for (auto i = 0; i < partitionValues->childrenSize(); i++) { + const auto& child = partitionValues->childAt(i); + const auto& name = partitionValues->rowType()->nameOf(i); + if (child->isNullAt(partitionId)) { + partitionKeyValuePairs.emplace_back( + std::make_pair(name, nullValueString)); + continue; + } + + partitionKeyValuePairs.emplace_back( + std::make_pair( + name, + PARTITION_TYPE_DISPATCH( + detail::makePartitionKeyValueString, + child->typeKind(), + *child->loadedVector(), + partitionId, + child->type(), + i, + toPartitionName))); + } + return partitionKeyValuePairs; +} + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/HivePartitionUtil.cpp b/velox/connectors/hive/HivePartitionUtil.cpp deleted file mode 100644 index cb95b916df3..00000000000 --- a/velox/connectors/hive/HivePartitionUtil.cpp +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "velox/connectors/hive/HivePartitionUtil.h" -#include "velox/vector/SimpleVector.h" - -namespace facebook::velox::connector::hive { - -#define PARTITION_TYPE_DISPATCH(TEMPLATE_FUNC, typeKind, ...) \ - [&]() { \ - switch (typeKind) { \ - case TypeKind::BOOLEAN: \ - case TypeKind::TINYINT: \ - case TypeKind::SMALLINT: \ - case TypeKind::INTEGER: \ - case TypeKind::BIGINT: \ - case TypeKind::VARCHAR: \ - case TypeKind::VARBINARY: \ - case TypeKind::TIMESTAMP: \ - return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( \ - TEMPLATE_FUNC, typeKind, __VA_ARGS__); \ - default: \ - VELOX_UNSUPPORTED( \ - "Unsupported partition type: {}", mapTypeKindToName(typeKind)); \ - } \ - }() - -namespace { -template -inline std::string makePartitionValueString(T value) { - return folly::to(value); -} - -template <> -inline std::string makePartitionValueString(bool value) { - return value ? "true" : "false"; -} - -template <> -inline std::string makePartitionValueString(Timestamp value) { - value.toTimezone(Timestamp::defaultTimezone()); - TimestampToStringOptions options; - options.dateTimeSeparator = ' '; - // Set the precision to milliseconds, and enable the skipTrailingZeros match - // the timestamp precision and truncation behavior of Presto. - options.precision = TimestampPrecision::kMilliseconds; - options.skipTrailingZeros = true; - - auto result = value.toString(options); - - // Presto's java.sql.Timestamp.toString() always keeps at least one decimal - // place even when all fractional seconds are zero. - // If skipTrailingZeros removed all fractional digits, add back ".0" to match - // Presto's behavior. - if (auto dotPos = result.find_last_of('.'); dotPos == std::string::npos) { - // No decimal point found, add ".0" - result += ".0"; - } - - return result; -} - -template -std::pair makePartitionKeyValueString( - const BaseVector* partitionVector, - vector_size_t row, - const std::string& name, - bool isDate) { - using T = typename TypeTraits::NativeType; - if (partitionVector->as>()->isNullAt(row)) { - return std::make_pair(name, ""); - } - if (isDate) { - return std::make_pair( - name, - DATE()->toString( - partitionVector->as>()->valueAt(row))); - } - return std::make_pair( - name, - makePartitionValueString( - partitionVector->as>()->valueAt(row))); -} - -} // namespace - -std::vector> extractPartitionKeyValues( - const RowVectorPtr& partitionsVector, - vector_size_t row) { - std::vector> partitionKeyValues; - for (auto i = 0; i < partitionsVector->childrenSize(); i++) { - partitionKeyValues.push_back(PARTITION_TYPE_DISPATCH( - makePartitionKeyValueString, - partitionsVector->childAt(i)->typeKind(), - partitionsVector->childAt(i)->loadedVector(), - row, - asRowType(partitionsVector->type())->nameOf(i), - partitionsVector->childAt(i)->type()->isDate())); - } - return partitionKeyValues; -} - -} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/HiveSplitReader.cpp b/velox/connectors/hive/HiveSplitReader.cpp new file mode 100644 index 00000000000..b6b1aa5a84a --- /dev/null +++ b/velox/connectors/hive/HiveSplitReader.cpp @@ -0,0 +1,237 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/HiveSplitReader.h" + +#include "velox/connectors/hive/FileConfig.h" +#include "velox/connectors/hive/HiveConnectorSplit.h" +#include "velox/connectors/hive/HiveConnectorUtil.h" + +namespace facebook::velox::connector::hive { + +HiveSplitReader::HiveSplitReader( + const std::shared_ptr& hiveSplit, + const FileTableHandlePtr& tableHandle, + const std::unordered_map* partitionKeys, + const ConnectorQueryCtx* connectorQueryCtx, + const std::shared_ptr& fileConfig, + const RowTypePtr& readerOutputType, + const std::shared_ptr& dataIoStats, + const std::shared_ptr& metadataIoStats, + const std::shared_ptr& ioStats, + FileHandleFactory* fileHandleFactory, + folly::Executor* ioExecutor, + const std::shared_ptr& scanSpec, + const std::unordered_map* infoColumns, + std::vector bucketChannels, + const common::SubfieldFilters* subfieldFiltersForValidation) + : FileSplitReader( + hiveSplit, + tableHandle, + partitionKeys, + connectorQueryCtx, + fileConfig, + readerOutputType, + dataIoStats, + metadataIoStats, + ioStats, + fileHandleFactory, + ioExecutor, + scanSpec, + subfieldFiltersForValidation), + hiveSplit_(hiveSplit), + infoColumns_(infoColumns) { + if (!bucketChannels.empty()) { + bucketChannels_ = {bucketChannels.begin(), bucketChannels.end()}; + partitionFunction_ = std::make_unique( + hiveSplit_->bucketConversion->tableBucketCount, + std::move(bucketChannels)); + } +} + +void HiveSplitReader::prepareSplit( + std::shared_ptr metadataFilter, + dwio::common::RuntimeStatistics& runtimeStats, + const folly::F14FastMap& fileReadOps) { + validateSynthesizedColumnFilters(); + FileSplitReader::prepareSplit( + std::move(metadataFilter), runtimeStats, fileReadOps); +} + +void HiveSplitReader::validateSynthesizedColumnFilters() const { + if (!subfieldFiltersForValidation_ || !infoColumns_) { + return; + } + const auto& splitInfoColumns = hiveSplit_->infoColumns; + for (const auto& [subfield, filter] : *subfieldFiltersForValidation_) { + const auto& fieldName = subfield.toString(); + auto infoColIter = splitInfoColumns.find(fieldName); + if (infoColIter == splitInfoColumns.end()) { + continue; + } + bool passed = false; + const auto& value = infoColIter->second; + auto handleIter = infoColumns_->find(fieldName); + VELOX_CHECK( + handleIter != infoColumns_->end(), + "Column handle for synthesized column '{}' not found in infoColumns", + fieldName); + TypeKind typeKind = handleIter->second->dataType()->kind(); + switch (typeKind) { + case TypeKind::BIGINT: + case TypeKind::INTEGER: + passed = common::applyFilter(*filter, folly::to(value)); + break; + case TypeKind::VARCHAR: + passed = common::applyFilter(*filter, value); + break; + default: + VELOX_FAIL("Unexpected type for synthesized column '{}'.", fieldName); + } + VELOX_CHECK( + passed, + "Synthesized column '{}' failed filter validation. " + "Filter: {}, Value: '{}'. Split: {}", + fieldName, + filter->toString(), + value, + fileSplit_->toString()); + } +} + +std::vector HiveSplitReader::adaptColumns( + const RowTypePtr& fileType, + const RowTypePtr& tableSchema) const { + std::vector columnTypes = fileType->children(); + + auto& childrenSpecs = scanSpec_->children(); + for (size_t i = 0; i < childrenSpecs.size(); ++i) { + auto* childSpec = childrenSpecs[i].get(); + const std::string& fieldName = childSpec->fieldName(); + + auto partitionIt = fileSplit_->partitionKeys.find(fieldName); + if (partitionIt != fileSplit_->partitionKeys.end()) { + setPartitionValue(childSpec, fieldName, partitionIt->second); + } else if ( + hiveSplit_->infoColumns.find(fieldName) != + hiveSplit_->infoColumns.end()) { + auto iter = hiveSplit_->infoColumns.find(fieldName); + auto infoColumnType = + readerOutputType_->childAt(readerOutputType_->getChildIdx(fieldName)); + auto constant = newConstantFromString( + infoColumnType, + iter->second, + connectorQueryCtx_->memoryPool(), + fileConfig_->readTimestampPartitionValueAsLocalTime( + connectorQueryCtx_->sessionProperties()), + false); + childSpec->setConstantValue(constant); + } else if ( + childSpec->columnType() == common::ScanSpec::ColumnType::kRegular) { + auto fileTypeIdx = fileType->getChildIdxIfExists(fieldName); + if (!fileTypeIdx.has_value()) { + VELOX_CHECK(tableSchema, "Unable to resolve column '{}'", fieldName); + childSpec->setConstantValue( + BaseVector::createNullConstant( + tableSchema->findChild(fieldName), + 1, + connectorQueryCtx_->memoryPool())); + } else { + childSpec->setConstantValue(nullptr); + auto outputTypeIdx = readerOutputType_->getChildIdxIfExists(fieldName); + if (outputTypeIdx.has_value()) { + auto& outputType = readerOutputType_->childAt(*outputTypeIdx); + auto& columnType = columnTypes[*fileTypeIdx]; + if (childSpec->isFlatMapAsStruct()) { + VELOX_CHECK(outputType->isRow() && columnType->isMap()); + } else { + columnType = outputType; + } + } + } + } + } + + scanSpec_->resetCachedValues(false); + + return columnTypes; +} + +void HiveSplitReader::configureBaseReaderOptions() { + hive::configureReaderOptions( + fileConfig_, + connectorQueryCtx_, + tableHandle_, + fileSplit_, + hiveSplit_->serdeParameters, + baseReaderOpts_); +} + +void HiveSplitReader::configureBaseRowReaderOptions( + std::shared_ptr metadataFilter, + RowTypePtr rowType) { + hive::configureRowReaderOptions( + tableHandle_->tableParameters(), + scanSpec_, + std::move(metadataFilter), + rowType, + fileSplit_, + hiveSplit_->serdeParameters, + fileConfig_, + connectorQueryCtx_->sessionProperties(), + ioExecutor_, + baseRowReaderOpts_); +} + +std::vector HiveSplitReader::bucketConversionRows( + const RowVector& vector) { + partitions_.clear(); + partitionFunction_->partition(vector, partitions_); + const auto bucketToKeep = *hiveSplit_->tableBucketNumber; + const auto partitionBucketCount = + hiveSplit_->bucketConversion->partitionBucketCount; + std::vector ranges; + for (vector_size_t i = 0; i < vector.size(); ++i) { + VELOX_CHECK_EQ((partitions_[i] - bucketToKeep) % partitionBucketCount, 0); + if (partitions_[i] == bucketToKeep) { + auto& r = ranges.emplace_back(); + r.sourceIndex = i; + r.targetIndex = ranges.size() - 1; + r.count = 1; + } + } + return ranges; +} + +void HiveSplitReader::applyBucketConversion( + VectorPtr& output, + const std::vector& ranges) { + auto filtered = + BaseVector::create(output->type(), ranges.size(), output->pool()); + filtered->copyRanges(output.get(), ranges); + output = std::move(filtered); +} + +uint64_t HiveSplitReader::next(uint64_t size, VectorPtr& output) { + auto numScanned = FileSplitReader::next(size, output); + if (numScanned > 0 && output->size() > 0 && partitionFunction_) { + applyBucketConversion( + output, bucketConversionRows(*output->asChecked())); + } + return numScanned; +} + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/HiveSplitReader.h b/velox/connectors/hive/HiveSplitReader.h new file mode 100644 index 00000000000..e50e87ed285 --- /dev/null +++ b/velox/connectors/hive/HiveSplitReader.h @@ -0,0 +1,94 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/connectors/hive/FileSplitReader.h" +#include "velox/connectors/hive/HivePartitionFunction.h" + +namespace facebook::velox::connector::hive { + +struct HiveConnectorSplit; + +/// Hive-specific FileSplitReader that adds bucket conversion support. +/// +/// Bucket conversion is needed when a table's bucket count is increased but +/// old partitions still use the original bucket count. In that case, a single +/// file may contain rows for multiple new buckets, and the reader must filter +/// to keep only rows belonging to the target bucket. +class HiveSplitReader : public FileSplitReader { + public: + HiveSplitReader( + const std::shared_ptr& hiveSplit, + const FileTableHandlePtr& tableHandle, + const std::unordered_map* partitionKeys, + const ConnectorQueryCtx* connectorQueryCtx, + const std::shared_ptr& fileConfig, + const RowTypePtr& readerOutputType, + const std::shared_ptr& dataIoStats, + const std::shared_ptr& metadataIoStats, + const std::shared_ptr& ioStats, + FileHandleFactory* fileHandleFactory, + folly::Executor* ioExecutor, + const std::shared_ptr& scanSpec, + const std::unordered_map* infoColumns, + std::vector bucketChannels = {}, + const common::SubfieldFilters* subfieldFiltersForValidation = nullptr); + + ~HiveSplitReader() override = default; + + void prepareSplit( + std::shared_ptr metadataFilter, + dwio::common::RuntimeStatistics& runtimeStats, + const folly::F14FastMap& fileReadOps = {}) + override; + + uint64_t next(uint64_t size, VectorPtr& output) override; + + const folly::F14FastSet& bucketChannels() const { + return bucketChannels_; + } + + protected: + void configureBaseReaderOptions() override; + + void configureBaseRowReaderOptions( + std::shared_ptr metadataFilter, + RowTypePtr rowType) override; + + std::vector bucketConversionRows( + const RowVector& vector); + + void applyBucketConversion( + VectorPtr& output, + const std::vector& ranges); + + void validateSynthesizedColumnFilters() const; + + const std::shared_ptr hiveSplit_; + + private: + std::vector adaptColumns( + const RowTypePtr& fileType, + const RowTypePtr& tableSchema) const override; + + const std::unordered_map* infoColumns_{ + nullptr}; + folly::F14FastSet bucketChannels_; + std::unique_ptr partitionFunction_; + std::vector partitions_; +}; + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/IndexReader.h b/velox/connectors/hive/IndexReader.h new file mode 100644 index 00000000000..3ad7756c985 --- /dev/null +++ b/velox/connectors/hive/IndexReader.h @@ -0,0 +1,148 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +#include + +#include "velox/common/base/Exceptions.h" +#include "velox/connectors/Connector.h" +#include "velox/dwio/common/Reader.h" + +namespace facebook::velox::connector::hive { + +struct HiveConnectorSplit; + +/// Abstract interface for format-specific index readers. +/// +/// Provides the customization point for storage formats to plug into +/// HiveIndexSource. Each format implements this interface +/// with its own I/O and decoding logic. HiveIndexSource owns all +/// format-agnostic orchestration (partition routing, remaining filter +/// evaluation, non-index condition filtering, output projection) and delegates +/// storage-specific reads to SplitIndexReader implementations. +/// +/// File-based readers (e.g., Nimble) can internally use ScanSpec and filter +/// pushdown without leaking those concepts through this interface. +class SplitIndexReader { + public: + using Request = IndexSource::Request; + using Result = IndexSource::Result; + using Options = dwio::common::IndexReader::Options; + + /// The total number of output rows returned across all next() calls. + static constexpr std::string_view kNumIndexReaderOutputRows{ + "numIndexReaderOutputRows"}; + + virtual ~SplitIndexReader() = default; + + /// Initializes a lookup for the given probe request. + /// + /// The reader handles format-specific I/O (e.g., file reads for Nimble, + /// network RPCs for remote stores). HiveIndexSource handles everything above + /// this level (filters, projection, partition routing). + /// + /// After calling startLookup(), the caller iterates results via hasNext() + /// and next(). + /// + /// @param request The probe-side input rows containing lookup keys. + /// @param options Lookup options (e.g., max rows per request batch). + virtual void startLookup( + const Request& request, + const Options& options = {}) = 0; + + /// Returns true if there are more result batches to read. + virtual bool hasNext() = 0; + + /// Returns the next batch of results, or nullptr if no more results. + /// + /// @param maxOutputRows Maximum number of rows to return in this batch. + virtual std::unique_ptr next(vector_size_t maxOutputRows) = 0; + + /// Returns runtime statistics collected by this reader. + virtual std::unordered_map runtimeStats() = 0; +}; + +/// Factory function type for creating IndexReader instances. +/// +/// Creates one IndexReader per split during HiveIndexSource::addSplits(). +/// Receives all the context needed to set up a reader for the given storage +/// format. +/// +/// @param split The split to create the reader for. +/// @param tableHandle The table handle containing table metadata. +/// @param connectorQueryCtx Query context (memory pool, session config, etc.). +/// @return A unique_ptr to the created IndexReader. +using IndexReaderFactory = std::function( + const std::shared_ptr& split, + const ConnectorTableHandlePtr& tableHandle, + ConnectorQueryCtx* connectorQueryCtx)>; + +/// Thread-safe registry for IndexReaderFactory instances keyed by file format. +/// +/// External storage formats register their reader +/// factories at application startup. HiveIndexSource consults this registry +/// during addSplits() to find the appropriate reader for each split's file +/// format. +/// +/// Example registration: +/// IndexReaderFactoryRegistry::getInstance()->registerFactory( +/// FileFormat::DWRF, myDwrfReaderFactory); +class IndexReaderFactoryRegistry { + public: + FOLLY_EXPORT static IndexReaderFactoryRegistry* getInstance() { + static IndexReaderFactoryRegistry instance; + return &instance; + } + + /// Registers a factory for the given file format. Throws if a factory is + /// already registered for the format. + void registerFactory( + dwio::common::FileFormat format, + IndexReaderFactory factory) { + std::lock_guard lock(mutex_); + auto [it, inserted] = factories_.emplace(format, std::move(factory)); + VELOX_CHECK( + inserted, + "IndexReaderFactory already registered for format: {}", + dwio::common::toString(format)); + } + + /// Unregisters the factory for the given file format. Returns true if a + /// factory was removed, false if none was registered. + bool unregisterFactory(dwio::common::FileFormat format) { + std::lock_guard lock(mutex_); + return factories_.erase(format) > 0; + } + + /// Returns the registered factory for the given format, or nullptr if none + /// is registered. + const IndexReaderFactory* getFactory(dwio::common::FileFormat format) const { + std::lock_guard lock(mutex_); + auto it = factories_.find(format); + return it != factories_.end() ? &it->second : nullptr; + } + + private: + IndexReaderFactoryRegistry() = default; + + mutable std::mutex mutex_; + std::unordered_map factories_; +}; + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/PartitionIdGenerator.cpp b/velox/connectors/hive/PartitionIdGenerator.cpp index bb20301b678..a4773da58a1 100644 --- a/velox/connectors/hive/PartitionIdGenerator.cpp +++ b/velox/connectors/hive/PartitionIdGenerator.cpp @@ -16,38 +16,30 @@ #include "velox/connectors/hive/PartitionIdGenerator.h" -#include "velox/connectors/hive/HivePartitionUtil.h" -#include "velox/dwio/catalog/fbhive/FileUtils.h" - -using namespace facebook::velox::dwio::catalog::fbhive; - namespace facebook::velox::connector::hive { PartitionIdGenerator::PartitionIdGenerator( const RowTypePtr& inputType, std::vector partitionChannels, uint32_t maxPartitions, - memory::MemoryPool* pool, - bool partitionPathAsLowerCase) + memory::MemoryPool* pool) : pool_(pool), partitionChannels_(std::move(partitionChannels)), - maxPartitions_(maxPartitions), - partitionPathAsLowerCase_(partitionPathAsLowerCase) { + maxPartitions_(maxPartitions) { VELOX_USER_CHECK( !partitionChannels_.empty(), "There must be at least one partition key."); for (auto channel : partitionChannels_) { hashers_.emplace_back( exec::VectorHasher::create(inputType->childAt(channel), channel)); + VELOX_USER_CHECK( + hashers_.back()->typeSupportsValueIds(), + "Unsupported partition type: {}.", + inputType->childAt(channel)->toString()); } std::vector partitionKeyTypes; std::vector partitionKeyNames; for (auto channel : partitionChannels_) { - VELOX_USER_CHECK( - exec::VectorHasher::typeKindSupportsValueIds( - inputType->childAt(channel)->kind()), - "Unsupported partition type: {}.", - inputType->childAt(channel)->toString()); partitionKeyTypes.push_back(inputType->childAt(channel)); partitionKeyNames.push_back(inputType->nameOf(channel)); } @@ -97,12 +89,6 @@ void PartitionIdGenerator::run( } } -std::string PartitionIdGenerator::partitionName(uint64_t partitionId) const { - return FileUtils::makePartName( - extractPartitionKeyValues(partitionValues_, partitionId), - partitionPathAsLowerCase_); -} - void PartitionIdGenerator::computeValueIds( const RowVectorPtr& input, raw_vector& valueIds) { diff --git a/velox/connectors/hive/PartitionIdGenerator.h b/velox/connectors/hive/PartitionIdGenerator.h index c4e0320b46c..0a53252829c 100644 --- a/velox/connectors/hive/PartitionIdGenerator.h +++ b/velox/connectors/hive/PartitionIdGenerator.h @@ -29,14 +29,11 @@ class PartitionIdGenerator { /// @param maxPartitions The max number of distinct partitions. /// @param pool Memory pool. Used to allocate memory for storing unique /// partition key values. - /// @param partitionPathAsLowerCase Used to control whether the partition path - /// need to convert to lower case. PartitionIdGenerator( const RowTypePtr& inputType, std::vector partitionChannels, uint32_t maxPartitions, - memory::MemoryPool* pool, - bool partitionPathAsLowerCase); + memory::MemoryPool* pool); /// Generate sequential partition IDs for input vector. /// @param input Input RowVector. @@ -48,11 +45,16 @@ class PartitionIdGenerator { return partitionIds_.size(); } - /// Return partition name for the given partition id in the typical Hive - /// style. It is derived from the partitionValues_ at index partitionId. - /// Partition keys appear in the order of partition columns in the table - /// schema. - std::string partitionName(uint64_t partitionId) const; + /// Returns the RowVector containing transformed partition keys. + /// Each row in this vector corresponds to a partition ID (row index = + /// partition ID). + /// Should be called after calling run() method. + /// + /// @return RowVector with one column per partition column, columns in same + /// order as partitionChannels_. + const RowVectorPtr& partitionValues() const { + return partitionValues_; + } private: static constexpr const int32_t kHasherReservePct = 20; @@ -81,8 +83,6 @@ class PartitionIdGenerator { const uint32_t maxPartitions_; - const bool partitionPathAsLowerCase_; - std::vector> hashers_; bool hasMultiplierSet_ = false; diff --git a/velox/connectors/hive/SplitReader.h b/velox/connectors/hive/SplitReader.h deleted file mode 100644 index 72a42b56b0b..00000000000 --- a/velox/connectors/hive/SplitReader.h +++ /dev/null @@ -1,204 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "velox/common/base/RandomUtil.h" -#include "velox/common/file/FileSystems.h" -#include "velox/connectors/hive/FileHandle.h" -#include "velox/connectors/hive/HivePartitionFunction.h" -#include "velox/dwio/common/Options.h" -#include "velox/dwio/common/Reader.h" - -namespace facebook::velox { -class BaseVector; -using VectorPtr = std::shared_ptr; -} // namespace facebook::velox - -namespace facebook::velox::common { -class MetadataFilter; -class ScanSpec; -} // namespace facebook::velox::common - -namespace facebook::velox::connector { -class ConnectorQueryCtx; -} // namespace facebook::velox::connector - -namespace facebook::velox::dwio::common { -struct RuntimeStatistics; -} // namespace facebook::velox::dwio::common - -namespace facebook::velox::memory { -class MemoryPool; -} - -namespace facebook::velox::connector::hive { - -struct HiveConnectorSplit; -class HiveTableHandle; -class HiveColumnHandle; -class HiveConfig; - -class SplitReader { - public: - static std::unique_ptr create( - const std::shared_ptr& hiveSplit, - const std::shared_ptr& hiveTableHandle, - const std::unordered_map< - std::string, - std::shared_ptr>* partitionKeys, - const ConnectorQueryCtx* connectorQueryCtx, - const std::shared_ptr& hiveConfig, - const RowTypePtr& readerOutputType, - const std::shared_ptr& ioStats, - const std::shared_ptr& fsStats, - FileHandleFactory* fileHandleFactory, - folly::Executor* ioExecutor, - const std::shared_ptr& scanSpec); - - virtual ~SplitReader() = default; - - void configureReaderOptions( - std::shared_ptr randomSkip); - - /// This function is used by different table formats like Iceberg and Hudi to - /// do additional preparations before reading the split, e.g. Open delete - /// files or log files, and add column adapatations for metadata columns. It - /// would be called only once per incoming split - virtual void prepareSplit( - std::shared_ptr metadataFilter, - dwio::common::RuntimeStatistics& runtimeStats); - - virtual uint64_t next(uint64_t size, VectorPtr& output); - - void resetFilterCaches(); - - bool emptySplit() const; - - void resetSplit(); - - int64_t estimatedRowSize() const; - - void updateRuntimeStats(dwio::common::RuntimeStatistics& stats) const; - - bool allPrefetchIssued() const; - - void setConnectorQueryCtx(const ConnectorQueryCtx* connectorQueryCtx); - - void setBucketConversion(std::vector bucketChannels); - - const RowTypePtr& readerOutputType() const { - return readerOutputType_; - } - - std::string toString() const; - - protected: - SplitReader( - const std::shared_ptr& hiveSplit, - const std::shared_ptr& hiveTableHandle, - const std::unordered_map< - std::string, - std::shared_ptr>* partitionKeys, - const ConnectorQueryCtx* connectorQueryCtx, - const std::shared_ptr& hiveConfig, - const RowTypePtr& readerOutputType, - const std::shared_ptr& ioStats, - const std::shared_ptr& fsStats, - FileHandleFactory* fileHandleFactory, - folly::Executor* executor, - const std::shared_ptr& scanSpec); - - /// Create the dwio::common::Reader object baseReader_, which will be used to - /// read the data file's metadata and schema - void createReader(); - - // Adjust the scan spec according to the current split, then return the - // adapted row type. - RowTypePtr getAdaptedRowType() const; - - // Check if the filters pass on the column statistics. When delta update is - // present, the corresonding filter should be disabled before calling this - // function. - bool filterOnStats(dwio::common::RuntimeStatistics& runtimeStats) const; - - /// Check if the hiveSplit_ is empty. The split is considered empty when - /// 1) The data file is missing but the user chooses to ignore it - /// 2) The file does not contain any rows - /// 3) The data in the file does not pass the filters. The test is based on - /// the file metadata and partition key values - /// This function needs to be called after baseReader_ is created. - bool checkIfSplitIsEmpty(dwio::common::RuntimeStatistics& runtimeStats); - - /// Create the dwio::common::RowReader object baseRowReader_, which owns the - /// ColumnReaders that will be used to read the data - void createRowReader( - std::shared_ptr metadataFilter, - RowTypePtr rowType); - - const folly::F14FastSet& bucketChannels() const { - return bucketChannels_; - } - - std::vector bucketConversionRows( - const RowVector& vector); - - void applyBucketConversion( - VectorPtr& output, - const std::vector& ranges); - - private: - /// Different table formats may have different meatadata columns. - /// This function will be used to update the scanSpec for these columns. - std::vector adaptColumns( - const RowTypePtr& fileType, - const std::shared_ptr& tableSchema) const; - - void setPartitionValue( - common::ScanSpec* spec, - const std::string& partitionKey, - const std::optional& value) const; - - protected: - std::shared_ptr hiveSplit_; - const std::shared_ptr hiveTableHandle_; - const std::unordered_map< - std::string, - std::shared_ptr>* const partitionKeys_; - const ConnectorQueryCtx* connectorQueryCtx_; - const std::shared_ptr hiveConfig_; - - RowTypePtr readerOutputType_; - const std::shared_ptr ioStats_; - const std::shared_ptr fsStats_; - FileHandleFactory* const fileHandleFactory_; - folly::Executor* const ioExecutor_; - memory::MemoryPool* const pool_; - - std::shared_ptr scanSpec_; - std::unique_ptr baseReader_; - std::unique_ptr baseRowReader_; - dwio::common::ReaderOptions baseReaderOpts_; - dwio::common::RowReaderOptions baseRowReaderOpts_; - bool emptySplit_; - - private: - folly::F14FastSet bucketChannels_; - std::unique_ptr partitionFunction_; - std::vector partitions_; -}; - -} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/TableHandle.cpp b/velox/connectors/hive/TableHandle.cpp index 3f7c8b6f93d..fc8b26508c6 100644 --- a/velox/connectors/hive/TableHandle.cpp +++ b/velox/connectors/hive/TableHandle.cpp @@ -19,37 +19,141 @@ namespace facebook::velox::connector::hive { namespace { -std::unordered_map -columnTypeNames() { - return { - {HiveColumnHandle::ColumnType::kPartitionKey, "PartitionKey"}, - {HiveColumnHandle::ColumnType::kRegular, "Regular"}, - {HiveColumnHandle::ColumnType::kSynthesized, "Synthesized"}, - {HiveColumnHandle::ColumnType::kRowIndex, "RowIndex"}, - }; + +folly::dynamic serializeExtractionPathElement( + const ExtractionPathElement& element) { + folly::dynamic obj = folly::dynamic::object; + obj["step"] = extractionStepName(element.step()); + if (element.step() == ExtractionStep::kStructField) { + obj["fieldName"] = + static_cast(element) + .fieldName(); + } + if (element.step() == ExtractionStep::kMapKeyFilter) { + if (auto* stringFilter = + dynamic_cast( + &element)) { + folly::dynamic keys = folly::dynamic::array; + for (const auto& key : stringFilter->filterKeys()) { + keys.push_back(key); + } + obj["stringFilterKeys"] = keys; + } else if ( + auto* intFilter = + dynamic_cast( + &element)) { + folly::dynamic keys = folly::dynamic::array; + for (const auto& key : intFilter->filterKeys()) { + keys.push_back(key); + } + obj["intFilterKeys"] = keys; + } + } + return obj; } -template -std::unordered_map invertMap(const std::unordered_map& mapping) { - std::unordered_map inverted; - for (const auto& [key, value] : mapping) { - inverted.emplace(value, key); +ExtractionPathElementPtr deserializeExtractionPathElement( + const folly::dynamic& obj) { + auto step = extractionStepFromName(obj["step"].asString()); + switch (step) { + case ExtractionStep::kStructField: + return ExtractionPathElement::structField(obj["fieldName"].asString()); + case ExtractionStep::kMapKeyFilter: { + std::vector stringKeys; + std::vector intKeys; + if (auto it = obj.find("stringFilterKeys"); it != obj.items().end()) { + for (const auto& key : it->second) { + stringKeys.push_back(key.asString()); + } + } + if (auto it = obj.find("intFilterKeys"); it != obj.items().end()) { + for (const auto& key : it->second) { + intKeys.push_back(key.asInt()); + } + } + if (!stringKeys.empty()) { + return ExtractionPathElement::mapKeyFilter(std::move(stringKeys)); + } + return ExtractionPathElement::mapKeyFilter(std::move(intKeys)); + } + case ExtractionStep::kMapKeys: + case ExtractionStep::kMapValues: + case ExtractionStep::kArrayElements: + case ExtractionStep::kSize: + return ExtractionPathElement::simple(step); } - return inverted; + VELOX_UNREACHABLE(); } -} // namespace +folly::dynamic serializeNamedExtraction(const NamedExtraction& extraction) { + folly::dynamic obj = folly::dynamic::object; + obj["outputName"] = extraction.outputName; + obj["dataType"] = extraction.dataType->serialize(); + folly::dynamic chain = folly::dynamic::array; + for (const auto& element : extraction.chain) { + chain.push_back(serializeExtractionPathElement(*element)); + } + obj["chain"] = chain; + return obj; +} -std::string HiveColumnHandle::columnTypeName( - HiveColumnHandle::ColumnType type) { - static const auto ctNames = columnTypeNames(); - return ctNames.at(type); +NamedExtraction deserializeNamedExtraction(const folly::dynamic& obj) { + NamedExtraction extraction; + extraction.outputName = obj["outputName"].asString(); + extraction.dataType = ISerializable::deserialize(obj["dataType"]); + const auto& chainArr = obj["chain"]; + extraction.chain.reserve(chainArr.size()); + for (const auto& element : chainArr) { + extraction.chain.push_back(deserializeExtractionPathElement(element)); + } + return extraction; } -HiveColumnHandle::ColumnType HiveColumnHandle::columnTypeFromName( - const std::string& name) { - static const auto nameColumnTypes = invertMap(columnTypeNames()); - return nameColumnTypes.at(name); +} // namespace + +HiveColumnHandle::HiveColumnHandle( + const std::string& name, + ColumnType columnType, + TypePtr dataType, + TypePtr hiveType, + std::vector requiredSubfields, + std::vector extractions, + ColumnParseParameters columnParseParameters, + std::function postProcessor) + : name_(name), + columnType_(columnType), + dataType_(std::move(dataType)), + hiveType_(std::move(hiveType)), + requiredSubfields_(std::move(requiredSubfields)), + extractions_(std::move(extractions)), + columnParseParameters_(columnParseParameters), + postProcessor_(std::move(postProcessor)) { + VELOX_USER_CHECK( + extractions_.empty() || requiredSubfields_.empty(), + "Extractions and requiredSubfields are mutually exclusive on column: {}", + name_); + + if (extractions_.empty()) { + // No extractions: dataType and hiveType must match (existing behavior). + VELOX_USER_CHECK( + dataType_->equivalent(*hiveType_), + "data type {} and hive type {} do not match", + dataType_->toString(), + hiveType_->toString()); + } else { + // Validate each extraction chain against hiveType and verify output types. + for (const auto& extraction : extractions_) { + auto derivedType = + deriveExtractionOutputType(hiveType_, extraction.chain); + VELOX_USER_CHECK( + derivedType->equivalent(*extraction.dataType), + "Extraction '{}' declared output type {} does not match " + "derived type: {}", + extraction.outputName, + extraction.dataType->toString(), + derivedType->toString()); + } + } } folly::dynamic HiveColumnHandle::serialize() const { @@ -63,6 +167,13 @@ folly::dynamic HiveColumnHandle::serialize() const { requiredSubfields.push_back(subfield.toString()); } obj["requiredSubfields"] = requiredSubfields; + if (!extractions_.empty()) { + folly::dynamic extractions = folly::dynamic::array; + for (const auto& extraction : extractions_) { + extractions.push_back(serializeNamedExtraction(extraction)); + } + obj["extractions"] = extractions; + } return obj; } @@ -77,7 +188,57 @@ std::string HiveColumnHandle::toString() const { for (const auto& subfield : requiredSubfields_) { out << " " << subfield.toString(); } - out << " ]]"; + out << " ]"; + if (!extractions_.empty()) { + out << ", extractions: ["; + for (size_t i = 0; i < extractions_.size(); ++i) { + if (i > 0) { + out << ", "; + } + const auto& extraction = extractions_[i]; + out << "{outputName: " << extraction.outputName << ", chain: ["; + for (size_t j = 0; j < extraction.chain.size(); ++j) { + if (j > 0) { + out << ", "; + } + const auto& elem = *extraction.chain[j]; + out << extractionStepName(elem.step()); + if (elem.step() == ExtractionStep::kStructField) { + out << "(" + << static_cast(elem) + .fieldName() + << ")"; + } + if (elem.step() == ExtractionStep::kMapKeyFilter) { + out << "("; + if (auto* strFilter = + dynamic_cast( + &elem)) { + for (size_t k = 0; k < strFilter->filterKeys().size(); ++k) { + if (k > 0) { + out << ", "; + } + out << "\"" << strFilter->filterKeys()[k] << "\""; + } + } else if ( + auto* intFilter = + dynamic_cast( + &elem)) { + for (size_t k = 0; k < intFilter->filterKeys().size(); ++k) { + if (k > 0) { + out << ", "; + } + out << intFilter->filterKeys()[k]; + } + } + out << ")"; + } + } + out << "], dataType: " << extraction.dataType->toString() << "}"; + } + out << "]"; + } + out << "]"; return out.str(); } @@ -94,8 +255,20 @@ ColumnHandlePtr HiveColumnHandle::create(const folly::dynamic& obj) { requiredSubfields.emplace_back(s.asString()); } + std::vector extractions; + if (auto it = obj.find("extractions"); it != obj.items().end()) { + for (const auto& extraction : it->second) { + extractions.push_back(deserializeNamedExtraction(extraction)); + } + } + return std::make_shared( - name, columnType, dataType, hiveType, std::move(requiredSubfields)); + name, + columnType, + dataType, + hiveType, + std::move(requiredSubfields), + std::move(extractions)); } void HiveColumnHandle::registerSerDe() { @@ -106,18 +279,47 @@ void HiveColumnHandle::registerSerDe() { HiveTableHandle::HiveTableHandle( std::string connectorId, const std::string& tableName, - bool filterPushdownEnabled, common::SubfieldFilters subfieldFilters, const core::TypedExprPtr& remainingFilter, const RowTypePtr& dataColumns, - const std::unordered_map& tableParameters) - : ConnectorTableHandle(std::move(connectorId)), + std::vector indexColumns, + const std::unordered_map& tableParameters, + std::vector filterColumnHandles, + double sampleRate, + std::string dbName) + : FileTableHandle(std::move(connectorId)), tableName_(tableName), - filterPushdownEnabled_(filterPushdownEnabled), subfieldFilters_(std::move(subfieldFilters)), remainingFilter_(remainingFilter), + sampleRate_(sampleRate), dataColumns_(dataColumns), - tableParameters_(tableParameters) {} + indexColumns_(std::move(indexColumns)), + tableParameters_(tableParameters), + filterColumnHandles_(std::move(filterColumnHandles)), + dbName_(std::move(dbName)) { + VELOX_CHECK_GT(sampleRate_, 0.0, "Sample rate must be positive"); + VELOX_CHECK_LE(sampleRate_, 1.0, "Sample rate must not exceed 1.0"); +} + +HiveTableHandle::HiveTableHandle( + std::string connectorId, + const std::string& tableName, + common::SubfieldFilters subfieldFilters, + const core::TypedExprPtr& remainingFilter, + const RowTypePtr& dataColumns, + const std::unordered_map& tableParameters, + std::vector filterColumnHandles, + double sampleRate) + : HiveTableHandle( + std::move(connectorId), + tableName, + std::move(subfieldFilters), + remainingFilter, + dataColumns, + /*indexColumns=*/{}, + tableParameters, + std::move(filterColumnHandles), + sampleRate) {} std::string HiveTableHandle::toString() const { std::stringstream out; @@ -139,6 +341,9 @@ std::string HiveTableHandle::toString() const { } out << "]"; } + if (sampleRate_ < 1.0) { + out << ", sample rate: " << sampleRate_; + } if (remainingFilter_) { out << ", remaining filter: (" << remainingFilter_->toString() << ")"; } @@ -159,13 +364,25 @@ std::string HiveTableHandle::toString() const { } out << "]"; } + if (!filterColumnHandles_.empty()) { + out << ", filter column handles: ["; + bool first = true; + for (auto& handle : filterColumnHandles_) { + if (first) { + first = false; + } else { + out << ", "; + } + out << handle->toString(); + } + out << "]"; + } return out.str(); } folly::dynamic HiveTableHandle::serialize() const { folly::dynamic obj = ConnectorTableHandle::serializeBase("HiveTableHandle"); obj["tableName"] = tableName_; - obj["filterPushdownEnabled"] = filterPushdownEnabled_; folly::dynamic subfieldFilters = folly::dynamic::array; for (const auto& [subfield, filter] : subfieldFilters_) { @@ -179,6 +396,11 @@ folly::dynamic HiveTableHandle::serialize() const { if (remainingFilter_) { obj["remainingFilter"] = remainingFilter_->serialize(); } + + if (sampleRate_ < 1.0) { + obj["sampleRate"] = sampleRate_; + } + if (dataColumns_) { obj["dataColumns"] = dataColumns_->serialize(); } @@ -187,6 +409,24 @@ folly::dynamic HiveTableHandle::serialize() const { tableParameters[param.first] = param.second; } obj["tableParameters"] = tableParameters; + if (!filterColumnHandles_.empty()) { + folly::dynamic filterColumnHandles = folly::dynamic::array; + for (const auto& handle : filterColumnHandles_) { + filterColumnHandles.push_back(handle->serialize()); + } + obj["filterColumnHandles"] = filterColumnHandles; + } + if (!indexColumns_.empty()) { + folly::dynamic indexColumns = folly::dynamic::array; + for (const auto& column : indexColumns_) { + indexColumns.push_back(column); + } + obj["indexColumns"] = indexColumns; + } + + if (!dbName_.empty()) { + obj["dbName"] = dbName_; + } return obj; } @@ -196,7 +436,6 @@ ConnectorTableHandlePtr HiveTableHandle::create( void* context) { auto connectorId = obj["connectorId"].asString(); auto tableName = obj["tableName"].asString(); - auto filterPushdownEnabled = obj["filterPushdownEnabled"].asBool(); core::TypedExprPtr remainingFilter; if (auto it = obj.find("remainingFilter"); it != obj.items().end()) { @@ -214,6 +453,11 @@ ConnectorTableHandlePtr HiveTableHandle::create( filter->clone(); } + double sampleRate = 1.0; + if (obj.count("sampleRate")) { + sampleRate = obj["sampleRate"].asDouble(); + } + RowTypePtr dataColumns; if (auto it = obj.find("dataColumns"); it != obj.items().end()) { dataColumns = ISerializable::deserialize(it->second, context); @@ -226,14 +470,37 @@ ConnectorTableHandlePtr HiveTableHandle::create( tableParameters.emplace(key.asString(), value.asString()); } + std::vector filterColumnHandles; + if (auto it = obj.find("filterColumnHandles"); it != obj.items().end()) { + for (const auto& handle : it->second) { + filterColumnHandles.push_back( + ISerializable::deserialize(handle, context)); + } + } + + std::vector indexColumns; + if (auto it = obj.find("indexColumns"); it != obj.items().end()) { + for (const auto& column : it->second) { + indexColumns.push_back(column.asString()); + } + } + + std::string dbName; + if (auto it = obj.find("dbName"); it != obj.items().end()) { + dbName = it->second.asString(); + } + return std::make_shared( connectorId, tableName, - filterPushdownEnabled, std::move(subfieldFilters), remainingFilter, dataColumns, - tableParameters); + std::move(indexColumns), + tableParameters, + std::move(filterColumnHandles), + sampleRate, + std::move(dbName)); } void HiveTableHandle::registerSerDe() { diff --git a/velox/connectors/hive/TableHandle.h b/velox/connectors/hive/TableHandle.h index c017eb41879..c0d5463c535 100644 --- a/velox/connectors/hive/TableHandle.h +++ b/velox/connectors/hive/TableHandle.h @@ -16,6 +16,8 @@ #pragma once #include "velox/connectors/Connector.h" +#include "velox/connectors/hive/FileColumnHandle.h" +#include "velox/connectors/hive/FileTableHandle.h" #include "velox/core/ITypedExpr.h" #include "velox/type/Filter.h" #include "velox/type/Subfield.h" @@ -23,18 +25,8 @@ namespace facebook::velox::connector::hive { -class HiveColumnHandle : public ColumnHandle { +class HiveColumnHandle : public FileColumnHandle { public: - enum class ColumnType { - kPartitionKey, - kRegular, - kSynthesized, - /// A zero-based row number of type BIGINT auto-generated by the connector. - /// Rows numbers are unique within a single file only. - kRowIndex, - kRowId, - }; - struct ColumnParseParameters { enum PartitionDateValueFormat { kISO8601, @@ -42,47 +34,68 @@ class HiveColumnHandle : public ColumnHandle { } partitionDateValueFormat; }; - /// NOTE: 'dataType' is the column type in target write table. 'hiveType' is + /// NOTE: 'dataType' is the column type in target write table. 'hiveType' is /// converted type of the corresponding column in source table which might not /// be the same type, and the table scan needs to do data coercion if needs. /// The table writer also needs to respect the type difference when processing /// input data such as bucket id calculation. + /// + /// 'extractions' specifies named extraction chains. When non-empty, + /// 'requiredSubfields' must be empty (mutually exclusive). When a single + /// extraction is present, 'dataType' is that extraction's dataType. When + /// multiple extractions are present, 'dataType' is a ROW type whose fields + /// are the outputNames with their corresponding dataTypes. HiveColumnHandle( const std::string& name, ColumnType columnType, TypePtr dataType, TypePtr hiveType, std::vector requiredSubfields = {}, - ColumnParseParameters columnParseParameters = {}) - : name_(name), - columnType_(columnType), - dataType_(std::move(dataType)), - hiveType_(std::move(hiveType)), - requiredSubfields_(std::move(requiredSubfields)), - columnParseParameters_(columnParseParameters) { - VELOX_USER_CHECK( - dataType_->equivalent(*hiveType_), - "data type {} and hive type {} do not match", - dataType_->toString(), - hiveType_->toString()); - } + std::vector extractions = {}, + ColumnParseParameters columnParseParameters = {}, + std::function postProcessor = {}); + + /// Legacy constructor without extractions for backward compatibility. + HiveColumnHandle( + const std::string& name, + ColumnType columnType, + TypePtr dataType, + TypePtr hiveType, + std::vector requiredSubfields, + ColumnParseParameters columnParseParameters, + std::function postProcessor = {}) + : HiveColumnHandle( + name, + columnType, + std::move(dataType), + std::move(hiveType), + std::move(requiredSubfields), + /*extractions=*/{}, + columnParseParameters, + std::move(postProcessor)) {} const std::string& name() const override { return name_; } - ColumnType columnType() const { + ColumnType columnType() const override { return columnType_; } - const TypePtr& dataType() const { + const TypePtr& dataType() const override { return dataType_; } + /// The type of this column as stored in the Hive source table. May differ + /// from dataType() when type coercion is needed for schema evolution. const TypePtr& hiveType() const { return hiveType_; } + const TypePtr& schemaType() const override { + return hiveType(); + } + /// Applies to columns of complex types: arrays, maps and structs. When a /// query uses only some of the subfields, the engine provides the complete /// list of required subfields and the connector is free to prune the rest. @@ -97,30 +110,53 @@ class HiveColumnHandle : public ColumnHandle { /// /// Pruning arrays means dropping values with indices larger than maximum /// required index. - const std::vector& requiredSubfields() const { + const std::vector& requiredSubfields() const override { return requiredSubfields_; } - bool isPartitionKey() const { + /// Named extraction chains. Empty means no extraction (current behavior). + /// When a single entry is present, the column handle's dataType is that + /// entry's dataType. When multiple entries are present, the column + /// handle's dataType is a ROW type whose fields are the outputNames with + /// their corresponding dataTypes. + /// Mutually exclusive with requiredSubfields — if extractions is non-empty, + /// requiredSubfields must be empty. + const std::vector& extractions() const override { + return extractions_; + } + + bool isPartitionKey() const override { return columnType_ == ColumnType::kPartitionKey; } - bool isPartitionDateValueDaysSinceEpoch() const { + bool isPartitionDateValueDaysSinceEpoch() const override { return columnParseParameters_.partitionDateValueFormat == ColumnParseParameters::kDaysSinceEpoch; } + /// Apply some row-wise post processing to this column when it is present in + /// output. + /// + /// It's not allowed to change the size of the vector in the processor. The + /// top level vector is guaranteed to be safe to change. Any inner vectors + /// and buffers need to check the reference count before doing any change in + /// place, otherwise you need to allocate new vectors and buffers. + /// + /// For lazy vector, this will be applied after the lazy vector is loaded. + /// This is only applied after all the filtering is done; the filters (both + /// subfield filters and remaining filter) still apply to values before post + /// processing. ValueHook usage will be disabled if a post processor is + /// present. + const std::function& postProcessor() const override { + return postProcessor_; + } + std::string toString() const override; folly::dynamic serialize() const override; static ColumnHandlePtr create(const folly::dynamic& obj); - static std::string columnTypeName(HiveColumnHandle::ColumnType columnType); - - static HiveColumnHandle::ColumnType columnTypeFromName( - const std::string& name); - static void registerSerDe(); private: @@ -129,23 +165,42 @@ class HiveColumnHandle : public ColumnHandle { const TypePtr dataType_; const TypePtr hiveType_; const std::vector requiredSubfields_; + const std::vector extractions_; const ColumnParseParameters columnParseParameters_; + const std::function postProcessor_; }; using HiveColumnHandlePtr = std::shared_ptr; using HiveColumnHandleMap = std::unordered_map; -class HiveTableHandle : public ConnectorTableHandle { +class HiveTableHandle : public FileTableHandle { public: + /// @param sampleRate Sampling rate in (0, 1] range. 0.1 means 10% sampling. + /// 1.0 means no sampling. Default is no sampling. HiveTableHandle( std::string connectorId, const std::string& tableName, - bool filterPushdownEnabled, common::SubfieldFilters subfieldFilters, const core::TypedExprPtr& remainingFilter, const RowTypePtr& dataColumns = nullptr, - const std::unordered_map& tableParameters = {}); + std::vector indexColumns = {}, + const std::unordered_map& tableParameters = {}, + std::vector filterColumnHandles = {}, + double sampleRate = 1.0, + std::string dbName = ""); + + /// Legacy constructor without indexColumns parameter for backward + /// compatibility. + HiveTableHandle( + std::string connectorId, + const std::string& tableName, + common::SubfieldFilters subfieldFilters, + const core::TypedExprPtr& remainingFilter, + const RowTypePtr& dataColumns, + const std::unordered_map& tableParameters, + std::vector filterColumnHandles, + double sampleRate = 1.0); const std::string& tableName() const { return tableName_; @@ -155,27 +210,55 @@ class HiveTableHandle : public ConnectorTableHandle { return tableName(); } - bool isFilterPushdownEnabled() const { - return filterPushdownEnabled_; + bool supportsIndexLookup() const override { + return !indexColumns_.empty(); + } + + bool needsIndexSplit() const override { + return true; } - const common::SubfieldFilters& subfieldFilters() const { + const common::SubfieldFilters& subfieldFilters() const override { return subfieldFilters_; } - const core::TypedExprPtr& remainingFilter() const { + const core::TypedExprPtr& remainingFilter() const override { return remainingFilter_; } - // Schema of the table. Need this for reading TEXTFILE. - const RowTypePtr& dataColumns() const { + double sampleRate() const override { + return sampleRate_; + } + + const RowTypePtr& dataColumns() const override { return dataColumns_; } - const std::unordered_map& tableParameters() const { + /// Returns the names of the index columns for the table. + const std::vector& indexColumns() const { + return indexColumns_; + } + + const std::unordered_map& tableParameters() + const override { return tableParameters_; } + /// Return filter column handles as FileColumnHandlePtr for the generic scan + /// pipeline. + std::vector filterColumnHandles() const override { + return {filterColumnHandles_.begin(), filterColumnHandles_.end()}; + } + + /// Return filter column handles with their concrete Hive type. + const std::vector& hiveFilterColumnHandles() const { + return filterColumnHandles_; + } + + const std::string& dbName() const override { + return dbName_; + } + std::string toString() const override; folly::dynamic serialize() const override; @@ -188,13 +271,19 @@ class HiveTableHandle : public ConnectorTableHandle { private: const std::string tableName_; - const bool filterPushdownEnabled_; const common::SubfieldFilters subfieldFilters_; const core::TypedExprPtr remainingFilter_; + const double sampleRate_; const RowTypePtr dataColumns_; + const std::vector indexColumns_; const std::unordered_map tableParameters_; + const std::vector filterColumnHandles_; + const std::string dbName_; }; using HiveTableHandlePtr = std::shared_ptr; } // namespace facebook::velox::connector::hive + +// The fmt::formatter for FileColumnHandle::ColumnType is defined in +// FileColumnHandle.h. diff --git a/velox/connectors/hive/benchmarks/HivePartitionFunctionBenchmark.cpp b/velox/connectors/hive/benchmarks/HivePartitionFunctionBenchmark.cpp index f503cb6e2e8..b55c02d7c07 100644 --- a/velox/connectors/hive/benchmarks/HivePartitionFunctionBenchmark.cpp +++ b/velox/connectors/hive/benchmarks/HivePartitionFunctionBenchmark.cpp @@ -97,7 +97,7 @@ class HivePartitionFunctionBenchmark void run(HivePartitionFunction* function) { if (rowVectors_.find(KIND) == rowVectors_.end()) { throw std::runtime_error( - fmt::format("Unsupported type {}.", mapTypeKindToName(KIND))); + fmt::format("Unsupported type {}.", TypeKindName::toName(KIND))); } function->partition(*rowVectors_[KIND], partitions_); } diff --git a/velox/connectors/hive/iceberg/CMakeLists.txt b/velox/connectors/hive/iceberg/CMakeLists.txt index 329998b5d40..f6369cb6d9e 100644 --- a/velox/connectors/hive/iceberg/CMakeLists.txt +++ b/velox/connectors/hive/iceberg/CMakeLists.txt @@ -11,15 +11,69 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +set( + ICEBERG_SOURCES + DeletionVectorReader.cpp + DeletionVectorWriter.cpp + EqualityDeleteFileReader.cpp + IcebergColumnHandle.cpp + IcebergConfig.cpp + IcebergConnector.cpp + IcebergDataFileStatistics.cpp + IcebergDataSink.cpp + IcebergDataSource.cpp + IcebergPartitionName.cpp + IcebergSplit.cpp + IcebergSplitReader.cpp + PartitionSpec.cpp + PositionalDeleteFileReader.cpp + TransformEvaluator.cpp + TransformExprBuilder.cpp + WriterOptionsAdapter.cpp +) + +if(VELOX_ENABLE_PARQUET) + list(APPEND ICEBERG_SOURCES IcebergParquetStatsCollector.cpp) +endif() velox_add_library( velox_hive_iceberg_splitreader - IcebergSplitReader.cpp - IcebergSplit.cpp - PositionalDeleteFileReader.cpp + ${ICEBERG_SOURCES} + HEADERS + DeletionVectorReader.h + DeletionVectorWriter.h + EqualityDeleteFileReader.h + IcebergColumnHandle.h + IcebergConfig.h + IcebergConnector.h + IcebergDataFileStatistics.h + IcebergDataSink.h + IcebergDataSource.h + IcebergDeleteFile.h + IcebergMetadataColumns.h + IcebergParquetStatsCollector.h + IcebergPartitionName.h + IcebergSplit.h + IcebergSplitReader.h + PartitionSpec.h + PositionalDeleteFileReader.h + TransformEvaluator.h + TransformExprBuilder.h + WriterOptionsAdapter.h +) + +velox_link_libraries( + velox_hive_iceberg_splitreader + velox_connector + velox_dwio_parquet_field_id + velox_functions_iceberg + velox_dwio_dwrf_writer + Folly::folly ) -velox_link_libraries(velox_hive_iceberg_splitreader velox_connector Folly::folly) +if(VELOX_ENABLE_PARQUET) + velox_link_libraries(velox_hive_iceberg_splitreader velox_dwio_arrow_parquet_writer) +endif() if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) diff --git a/velox/connectors/hive/iceberg/DeletionVectorReader.cpp b/velox/connectors/hive/iceberg/DeletionVectorReader.cpp new file mode 100644 index 00000000000..6d612a3905c --- /dev/null +++ b/velox/connectors/hive/iceberg/DeletionVectorReader.cpp @@ -0,0 +1,443 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/DeletionVectorReader.h" + +#include + +#include "velox/common/base/BitUtil.h" +#include "velox/common/base/Exceptions.h" +#include "velox/common/file/FileSystems.h" + +namespace facebook::velox::connector::hive::iceberg { + +namespace { +static constexpr uint32_t kSerialCookieNoRun = 12'346; +static constexpr uint32_t kSerialCookie = 12'347; +static constexpr uint32_t kRunContainersNoOffsetThreshold = 4; +} // namespace + +DeletionVectorReader::DeletionVectorReader( + const IcebergDeleteFile& dvFile, + uint64_t splitOffset, + memory::MemoryPool* /*pool*/) + : dvFile_(dvFile), splitOffset_(splitOffset) { + VELOX_CHECK( + dvFile_.content == FileContent::kDeletionVector, + "Expected deletion vector file but got content type: {}", + static_cast(dvFile_.content)); + VELOX_CHECK_GT(dvFile_.recordCount, 0, "Empty deletion vector."); + + static constexpr int64_t kMaxDeletionVectorRecordCount = 10'000'000'000LL; + VELOX_CHECK_LE( + dvFile_.recordCount, + kMaxDeletionVectorRecordCount, + "Deletion vector record count exceeds sanity limit: {}", + dvFile_.recordCount); +} + +void DeletionVectorReader::loadBitmap() { + if (loaded_) { + return; + } + loaded_ = true; + + // Prefer the typed contentOffset / contentLength fields. The legacy + // bounds-map encoding (kDvOffsetFieldId / kDvLengthFieldId) is kept as a + // fallback for callers that have not migrated yet. + uint64_t blobOffset = static_cast(dvFile_.contentOffset); + uint64_t blobLength = dvFile_.contentLength > 0 + ? static_cast(dvFile_.contentLength) + : dvFile_.fileSizeInBytes; + + if (dvFile_.contentLength == 0) { + if (auto it = dvFile_.lowerBounds.find(kDvOffsetFieldId); + it != dvFile_.lowerBounds.end()) { + try { + blobOffset = std::stoull(it->second); + } catch (const std::exception& e) { + VELOX_FAIL( + "Failed to parse DV blob offset from bounds map: {}", e.what()); + } + } + if (auto it = dvFile_.upperBounds.find(kDvLengthFieldId); + it != dvFile_.upperBounds.end()) { + try { + blobLength = std::stoull(it->second); + } catch (const std::exception& e) { + VELOX_FAIL( + "Failed to parse DV blob length from bounds map: {}", e.what()); + } + } + } + + auto fs = filesystems::getFileSystem(dvFile_.filePath, nullptr); + auto readFile = fs->openFileForRead(dvFile_.filePath); + + auto fileSize = readFile->size(); + VELOX_CHECK_LE( + blobOffset, + fileSize, + "DV blob offset {} exceeds file size {}.", + blobOffset, + fileSize); + VELOX_CHECK_LE( + blobLength, + fileSize - blobOffset, + "DV blob range [{}, {}) exceeds file size {}.", + blobOffset, + blobOffset + blobLength, + fileSize); + + std::string blobData(blobLength, '\0'); + readFile->pread(blobOffset, blobLength, blobData.data()); + + // Detect format: 64-bit Roaring64Bitmap vs 32-bit RoaringBitmap. + // 64-bit format starts with [numGroups: uint64]. If the first 4 bytes + // match a 32-bit cookie (12346 or 12347), it's a legacy 32-bit bitmap. + // Otherwise, interpret as 64-bit format. + deserializeRoaring64Bitmap(blobData); + + std::sort(deletedPositions_.begin(), deletedPositions_.end()); +} + +void DeletionVectorReader::deserializeRoaring64Bitmap(const std::string& data) { + if (data.size() < 8) { + VELOX_FAIL( + "Deletion vector blob too small: {} bytes, expected at least 8.", + data.size()); + } + + const auto* ptr = reinterpret_cast(data.data()); + const auto* end = ptr + data.size(); + + // Peek at first 4 bytes to detect 32-bit vs 64-bit format. + uint32_t firstWord; + std::memcpy(&firstWord, ptr, sizeof(uint32_t)); + firstWord = folly::Endian::little(firstWord); + + bool is32BitFormat = (firstWord == kSerialCookieNoRun) || + ((firstWord & 0xFFFF) == kSerialCookie); + + if (is32BitFormat) { + // Legacy 32-bit RoaringBitmap — all positions in [0, 2^32). + deserialize32BitRoaringBitmap(ptr, end, 0); + return; + } + + // 64-bit Roaring64Bitmap format: + // [numGroups: uint64] + // For each group (sorted by highBits): + // [highBits: uint32] + // [32-bit RoaringBitmap in portable format] + uint64_t numGroups; + std::memcpy(&numGroups, ptr, sizeof(uint64_t)); + numGroups = folly::Endian::little(numGroups); + ptr += sizeof(uint64_t); + + static constexpr uint64_t kMaxGroups = 1'000'000; + VELOX_CHECK_LE( + numGroups, + kMaxGroups, + "Roaring64Bitmap group count exceeds sanity limit: {}", + numGroups); + + for (uint64_t g = 0; g < numGroups; ++g) { + VELOX_CHECK_GE( + static_cast(end - ptr), + sizeof(uint32_t), + "Truncated Roaring64Bitmap group header."); + + uint32_t highBits; + std::memcpy(&highBits, ptr, sizeof(uint32_t)); + highBits = folly::Endian::little(highBits); + ptr += sizeof(uint32_t); + + int64_t highBitsOffset = static_cast(highBits) << 32; + + // Deserialize the 32-bit bitmap for this group. + // We need to find its size first by parsing the header. + deserialize32BitRoaringBitmap(ptr, end, highBitsOffset); + + // Advance ptr past the 32-bit bitmap we just parsed. + // Re-parse the header to compute the size. + const auto* bitmapStart = ptr; + + uint32_t cookie; + std::memcpy(&cookie, bitmapStart, sizeof(uint32_t)); + cookie = folly::Endian::little(cookie); + + bool hasRunContainers = false; + uint32_t numContainers = 0; + + if ((cookie & 0xFFFF) == kSerialCookie) { + hasRunContainers = true; + numContainers = (cookie >> 16) + 1; + ptr += sizeof(uint32_t); + } else if (cookie == kSerialCookieNoRun) { + ptr += sizeof(uint32_t); + uint32_t containerCount; + std::memcpy(&containerCount, ptr, sizeof(uint32_t)); + numContainers = folly::Endian::little(containerCount); + ptr += sizeof(uint32_t); + } else { + VELOX_FAIL("Unknown roaring bitmap cookie in 64-bit group: {}", cookie); + } + + if (numContainers == 0) { + continue; + } + + // Skip run bitmap if present. + if (hasRunContainers) { + uint32_t runBitmapBytes = (numContainers + 7) / 8; + ptr += runBitmapBytes; + } + + // Read key-cardinality pairs to compute container data sizes. + struct ContainerMeta { + uint16_t key; + uint32_t cardinality; + bool isRun; + }; + std::vector containers(numContainers); + + // Re-read run bitmap for container type detection. + const auto* runBitmapPtr = + hasRunContainers ? bitmapStart + sizeof(uint32_t) : nullptr; + + for (uint32_t i = 0; i < numContainers; ++i) { + uint16_t key, cardMinus1; + std::memcpy(&key, ptr, sizeof(uint16_t)); + key = folly::Endian::little(key); + ptr += sizeof(uint16_t); + std::memcpy(&cardMinus1, ptr, sizeof(uint16_t)); + cardMinus1 = folly::Endian::little(cardMinus1); + ptr += sizeof(uint16_t); + bool isRun = hasRunContainers && runBitmapPtr + ? ((runBitmapPtr[i / 8] >> (i % 8)) & 1) + : false; + containers[i] = {key, static_cast(cardMinus1) + 1, isRun}; + } + + // Skip offset section + const bool hasOffsetSection = + !hasRunContainers || numContainers >= kRunContainersNoOffsetThreshold; + if (hasOffsetSection) { + ptr += numContainers * sizeof(uint32_t); + } + + // Skip container data. + for (uint32_t i = 0; i < numContainers; ++i) { + if (containers[i].isRun) { + uint16_t numRuns; + std::memcpy(&numRuns, ptr, sizeof(uint16_t)); + numRuns = folly::Endian::little(numRuns); + ptr += sizeof(uint16_t) + static_cast(numRuns) * 4; + } else if (containers[i].cardinality <= 4'096) { + ptr += static_cast(containers[i].cardinality) * 2; + } else { + ptr += 8'192; + } + } + } +} + +void DeletionVectorReader::deserialize32BitRoaringBitmap( + const uint8_t* ptr, + const uint8_t* end, + int64_t highBitsOffset) { + VELOX_CHECK_GE(static_cast(end - ptr), 8, "32-bit bitmap too small."); + + uint32_t cookie; + std::memcpy(&cookie, ptr, sizeof(uint32_t)); + cookie = folly::Endian::little(cookie); + ptr += sizeof(uint32_t); + + bool hasRunContainers = false; + uint32_t numContainers = 0; + + if ((cookie & 0xFFFF) == kSerialCookie) { + hasRunContainers = true; + numContainers = (cookie >> 16) + 1; + } else if (cookie == kSerialCookieNoRun) { + std::memcpy(&numContainers, ptr, sizeof(uint32_t)); + numContainers = folly::Endian::little(numContainers); + ptr += sizeof(uint32_t); + } else { + VELOX_FAIL( + "Unknown roaring bitmap cookie: {}. Expected {} or {}.", + cookie, + kSerialCookieNoRun, + kSerialCookie); + } + + if (numContainers == 0) { + return; + } + + // Read run bitmap if present. + std::vector isRunContainer(numContainers, false); + if (hasRunContainers) { + uint32_t runBitmapBytes = (numContainers + 7) / 8; + VELOX_CHECK_GE( + static_cast(end - ptr), + runBitmapBytes, + "Truncated run bitmap."); + for (uint32_t i = 0; i < numContainers; ++i) { + isRunContainer[i] = (ptr[i / 8] >> (i % 8)) & 1; + } + ptr += runBitmapBytes; + } + + // Read key-cardinality pairs. + struct ContainerMeta { + uint16_t key; + uint32_t cardinality; + }; + std::vector containers(numContainers); + + VELOX_CHECK_GE( + static_cast(end - ptr), + numContainers * 4, + "Truncated container metadata."); + for (uint32_t i = 0; i < numContainers; ++i) { + uint16_t key, cardMinus1; + std::memcpy(&key, ptr, sizeof(uint16_t)); + key = folly::Endian::little(key); + ptr += sizeof(uint16_t); + std::memcpy(&cardMinus1, ptr, sizeof(uint16_t)); + cardMinus1 = folly::Endian::little(cardMinus1); + ptr += sizeof(uint16_t); + containers[i] = {key, static_cast(cardMinus1) + 1}; + } + + // Skip offset section + const bool hasOffsetSection = + !hasRunContainers || numContainers >= kRunContainersNoOffsetThreshold; + if (hasOffsetSection) { + VELOX_CHECK_GE( + static_cast(end - ptr), + numContainers * 4, + "Truncated offset section."); + ptr += numContainers * sizeof(uint32_t); + } + + // dvFile_.recordCount was already validated against + // kMaxDeletionVectorRecordCount in the constructor. + deletedPositions_.reserve(deletedPositions_.size() + dvFile_.recordCount); + + // Read container data. + for (uint32_t i = 0; i < numContainers; ++i) { + int64_t containerBase = + highBitsOffset | (static_cast(containers[i].key) << 16); + uint32_t cardinality = containers[i].cardinality; + + if (isRunContainer[i]) { + uint16_t numRuns; + VELOX_CHECK_GE( + static_cast(end - ptr), 2u, "Truncated run container."); + std::memcpy(&numRuns, ptr, sizeof(uint16_t)); + numRuns = folly::Endian::little(numRuns); + ptr += sizeof(uint16_t); + + VELOX_CHECK_GE( + static_cast(end - ptr), + static_cast(numRuns) * 4, + "Truncated run container data."); + for (uint16_t r = 0; r < numRuns; ++r) { + uint16_t start, lengthMinus1; + std::memcpy(&start, ptr, sizeof(uint16_t)); + start = folly::Endian::little(start); + ptr += sizeof(uint16_t); + std::memcpy(&lengthMinus1, ptr, sizeof(uint16_t)); + lengthMinus1 = folly::Endian::little(lengthMinus1); + ptr += sizeof(uint16_t); + for (uint32_t v = start; + v <= static_cast(start) + lengthMinus1; + ++v) { + deletedPositions_.push_back(containerBase | v); + } + } + } else if (cardinality <= 4'096) { + VELOX_CHECK_GE( + static_cast(end - ptr), + cardinality * 2, + "Truncated array container."); + for (uint32_t j = 0; j < cardinality; ++j) { + uint16_t val; + std::memcpy(&val, ptr, sizeof(uint16_t)); + val = folly::Endian::little(val); + ptr += sizeof(uint16_t); + deletedPositions_.push_back(containerBase | val); + } + } else { + static constexpr size_t kBitsetBytes = 8'192; + VELOX_CHECK_GE( + static_cast(end - ptr), + kBitsetBytes, + "Truncated bitset container."); + for (uint32_t word = 0; word < 1'024; ++word) { + uint64_t bits; + std::memcpy(&bits, ptr + word * 8, sizeof(uint64_t)); + bits = folly::Endian::little(bits); + while (bits != 0) { + uint32_t bit = __builtin_ctzll(bits); + deletedPositions_.push_back( + containerBase | static_cast(word * 64 + bit)); + bits &= bits - 1; + } + } + ptr += kBitsetBytes; + } + } +} + +void DeletionVectorReader::readDeletePositions( + uint64_t baseReadOffset, + uint64_t size, + BufferPtr deleteBitmap) { + loadBitmap(); + + if (deletedPositions_.empty()) { + return; + } + + auto* bitmap = deleteBitmap->asMutable(); + int64_t rowNumberLowerBound = + static_cast(splitOffset_ + baseReadOffset); + int64_t rowNumberUpperBound = + rowNumberLowerBound + static_cast(size); + + while (positionIndex_ < deletedPositions_.size() && + deletedPositions_[positionIndex_] < rowNumberLowerBound) { + ++positionIndex_; + } + + while (positionIndex_ < deletedPositions_.size() && + deletedPositions_[positionIndex_] < rowNumberUpperBound) { + auto bitIndex = static_cast( + deletedPositions_[positionIndex_] - rowNumberLowerBound); + bits::setBit(bitmap, bitIndex); + ++positionIndex_; + } +} + +bool DeletionVectorReader::noMoreData() const { + return loaded_ && positionIndex_ >= deletedPositions_.size(); +} + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/DeletionVectorReader.h b/velox/connectors/hive/iceberg/DeletionVectorReader.h new file mode 100644 index 00000000000..5c3cb4323c9 --- /dev/null +++ b/velox/connectors/hive/iceberg/DeletionVectorReader.h @@ -0,0 +1,94 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include "velox/common/base/BitUtil.h" +#include "velox/common/memory/Memory.h" +#include "velox/connectors/hive/iceberg/IcebergDeleteFile.h" + +namespace facebook::velox::connector::hive::iceberg { + +/// Reads an Iceberg V3 deletion vector and applies it to the delete bitmap. +/// +/// Supports both 64-bit Roaring64Bitmap format (used by Java's Roaring64Bitmap +/// at Meta's 500 PB scale) and legacy 32-bit RoaringBitmap format. +/// +/// 64-bit format: [numGroups: uint64] then for each group: +/// [highBits: uint32] [32-bit RoaringBitmap in portable format] +/// +/// 32-bit format: [cookie: uint32] [containerCount: uint32] ... +/// Detected by checking if the first 8 bytes match a 32-bit cookie (12346 +/// or 12347). +class DeletionVectorReader { + public: + /// @param dvFile Iceberg delete file metadata containing the Puffin file + /// path, + /// blob offset, and blob length. + /// @param splitOffset File position of the first row in the split. + /// @param pool Memory pool for internal allocations. + DeletionVectorReader( + const IcebergDeleteFile& dvFile, + uint64_t splitOffset, + memory::MemoryPool* pool); + + /// Reads deleted positions from the DV and sets corresponding bits in the + /// deleteBitmap for the current batch range. + void readDeletePositions( + uint64_t baseReadOffset, + uint64_t size, + BufferPtr deleteBitmap); + + /// Returns true when there is no more data. + bool noMoreData() const; + + static constexpr int32_t kDvOffsetFieldId = 100; + static constexpr int32_t kDvLengthFieldId = 101; + + private: + void loadBitmap(); + + // Deserializes a 64-bit roaring bitmap (Roaring64Bitmap format). + void deserializeRoaring64Bitmap(const std::string& data); + + // Deserializes a 32-bit roaring bitmap from portable binary format. + // Positions are offset by highBitsOffset (upper 32 bits for 64-bit mode, + // 0 for legacy 32-bit mode). + void deserialize32BitRoaringBitmap( + const uint8_t* ptr, + const uint8_t* end, + int64_t highBitsOffset); + + // The deletion vector file metadata from the Iceberg manifest. + const IcebergDeleteFile dvFile_; + + // Base offset of the split within the data file (for position mapping). + const uint64_t splitOffset_; + + // Sorted list of deleted row positions (absolute, file-level positions). + std::vector deletedPositions_; + + // Current scan position within deletedPositions_. + size_t positionIndex_{0}; + + // Whether the bitmap has been loaded from the DV file. + bool loaded_{false}; +}; + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/DeletionVectorWriter.cpp b/velox/connectors/hive/iceberg/DeletionVectorWriter.cpp new file mode 100644 index 00000000000..b2721e4f444 --- /dev/null +++ b/velox/connectors/hive/iceberg/DeletionVectorWriter.cpp @@ -0,0 +1,256 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/DeletionVectorWriter.h" + +#include +#include + +#include +#include + +#include "velox/common/base/Exceptions.h" + +namespace facebook::velox::connector::hive::iceberg { + +namespace { + +// Roaring Bitmap portable format constants. +constexpr uint32_t kSerialCookieNoRun = 12'346; +constexpr uint32_t kMaxArrayContainerCardinality = 4'096; +// Full bitmap container: 2^16 bits = 1024 uint64 words = 8192 bytes. +constexpr size_t kBitmapContainerBytes = 8'192; +constexpr size_t kBitmapContainerWords = 1'024; + +// Puffin file format constants (per Iceberg spec). +constexpr char kPuffinMagic[] = {'\x50', '\x55', '\x46', '\x31'}; +constexpr size_t kPuffinMagicSize = 4; +constexpr uint32_t kPuffinFooterFlags = 0; + +// Puffin blob metadata constants (per Iceberg V3 deletion vector spec). +constexpr char kDeletionVectorBlobType[] = "deletion-vector-v1"; +constexpr char kCompressionCodecNone[] = "none"; +// Iceberg spec: source-field-id for whole-row deletes is INT_MAX - 1. +constexpr int32_t kWholeRowDeleteFieldId = 2'147'483'646; + +void writeLittleEndian(std::string& out, uint16_t val) { + val = folly::Endian::little(val); + out.append(reinterpret_cast(&val), sizeof(val)); +} + +void writeLittleEndian(std::string& out, uint32_t val) { + val = folly::Endian::little(val); + out.append(reinterpret_cast(&val), sizeof(val)); +} + +void writeLittleEndian(std::string& out, uint64_t val) { + val = folly::Endian::little(val); + out.append(reinterpret_cast(&val), sizeof(val)); +} + +// Serializes the key-cardinality header for a 32-bit Roaring Bitmap. +void serializeKeyCardinality( + std::string& data, + const std::vector>>& containers) { + for (const auto& [key, values] : containers) { + writeLittleEndian(data, key); + auto cardMinus1 = static_cast(values.size() - 1); + writeLittleEndian(data, cardMinus1); + } +} + +// Serializes the offset section for a 32-bit Roaring Bitmap. +// Offsets point to where each container's data begins, measured from the start +// of the serialized bitmap. The header consists of: cookie (4) + count (4) + +// key-cardinality pairs (numContainers * 4) + offset section (numContainers * +// 4). Offsets are relative to byte 0 of the serialized output, so the first +// container's data starts immediately after the full header including offsets. +void serializeOffsets( + std::string& data, + const std::vector>>& containers) { + auto numContainers = static_cast(containers.size()); + uint32_t headerSize = 4 + 4 + numContainers * 4 + numContainers * 4; + uint32_t runningOffset = headerSize; + for (const auto& [key, values] : containers) { + writeLittleEndian(data, runningOffset); + if (values.size() <= kMaxArrayContainerCardinality) { + runningOffset += static_cast(values.size()) * 2; + } else { + runningOffset += kBitmapContainerBytes; + } + } +} + +// Serializes container data (array or bitmap) for a 32-bit Roaring Bitmap. +// Array containers store sorted uint16 values directly. Bitmap containers +// store a 65536-bit bitset as 1024 little-endian uint64 words, covering the +// full uint16 range [0, 65535]. +void serializeContainerData( + std::string& data, + const std::vector>>& containers) { + for (const auto& [key, values] : containers) { + if (values.size() <= kMaxArrayContainerCardinality) { + for (auto value : values) { + writeLittleEndian(data, value); + } + } else { + std::vector bitmap(kBitmapContainerWords, 0); + for (auto value : values) { + bitmap[value / 64] |= (1ULL << (value % 64)); + } + for (auto word : bitmap) { + writeLittleEndian(data, word); + } + } + } +} + +} // namespace + +void DeletionVectorWriter::addDeletedPosition(int64_t position) { + VELOX_CHECK_GE(position, 0, "Deleted position must be non-negative."); + positions_.push_back(position); +} + +void DeletionVectorWriter::addDeletedPositions( + const std::vector& positions) { + for (auto pos : positions) { + addDeletedPosition(pos); + } +} + +std::string DeletionVectorWriter::serialize32( + const std::vector& sorted) const { + if (sorted.empty()) { + std::string data; + writeLittleEndian(data, kSerialCookieNoRun); + uint32_t zero = 0; + writeLittleEndian(data, zero); + return data; + } + + // Group values by high 16 bits (container key). + std::map> containerMap; + for (auto val : sorted) { + auto key = static_cast(val >> 16); + auto low = static_cast(val & 0xFFFF); + containerMap[key].push_back(low); + } + + std::vector>> containers( + containerMap.begin(), containerMap.end()); + auto numContainers = static_cast(containers.size()); + + std::string data; + writeLittleEndian(data, kSerialCookieNoRun); + writeLittleEndian(data, numContainers); + serializeKeyCardinality(data, containers); + serializeOffsets(data, containers); + serializeContainerData(data, containers); + return data; +} + +std::string DeletionVectorWriter::serialize() const { + if (positions_.empty()) { + std::string data; + writeLittleEndian(data, static_cast(0)); + return data; + } + + // Sort and deduplicate positions. + std::vector sorted = positions_; + std::sort(sorted.begin(), sorted.end()); + sorted.erase(std::unique(sorted.begin(), sorted.end()), sorted.end()); + + // Partition into 32-bit high groups. + // Roaring64Bitmap format: [numGroups: uint64] then for each group: + // [highBits: uint32] [serialized 32-bit RoaringBitmap] + std::map> groups; + for (auto pos : sorted) { + // Safe cast: addDeletedPosition() rejects negative values. + auto upos = static_cast(pos); + groups[static_cast(upos >> 32)].push_back( + static_cast(upos & 0xFFFFFFFF)); + } + + std::string data; + writeLittleEndian(data, static_cast(groups.size())); + + for (auto& [highBits, lowValues] : groups) { + writeLittleEndian(data, highBits); + data.append(serialize32(lowValues)); + } + + return data; +} + +void DeletionVectorWriter::clear() { + positions_.clear(); +} + +std::pair writePuffinFile( + const std::string& filePath, + const std::string& blobData, + const std::string& referencedDataFile) { + uint64_t blobOffset = kPuffinMagicSize; + uint64_t blobLength = blobData.size(); + + folly::dynamic blobMeta = folly::dynamic::object( + "type", kDeletionVectorBlobType)( + "fields", + folly::dynamic::array( + folly::dynamic::object("source-field-id", kWholeRowDeleteFieldId))); + blobMeta["offset"] = blobOffset; + blobMeta["length"] = blobLength; + blobMeta["compression-codec"] = kCompressionCodecNone; + + folly::dynamic properties = folly::dynamic::object; + properties["referenced-data-file"] = referencedDataFile; + blobMeta["properties"] = properties; + + folly::dynamic footer = folly::dynamic::object; + footer["blobs"] = folly::dynamic::array(blobMeta); + footer["properties"] = folly::dynamic::object; + + std::string footerJson = folly::toJson(footer); + uint32_t footerPayloadSize = static_cast(footerJson.size()); + + std::string fileContent; + fileContent.append(kPuffinMagic, kPuffinMagicSize); + fileContent.append(blobData); + fileContent.append(footerJson); + uint32_t littleEndianSize = folly::Endian::little(footerPayloadSize); + fileContent.append( + reinterpret_cast(&littleEndianSize), + sizeof(littleEndianSize)); + uint32_t littleEndianFlags = folly::Endian::little(kPuffinFooterFlags); + fileContent.append( + reinterpret_cast(&littleEndianFlags), + sizeof(littleEndianFlags)); + fileContent.append(kPuffinMagic, kPuffinMagicSize); + + std::ofstream out(filePath, std::ios::binary | std::ios::trunc); + VELOX_CHECK( + out.good(), "Failed to open Puffin file for writing: {}", filePath); + out.write( + fileContent.data(), static_cast(fileContent.size())); + out.close(); + VELOX_CHECK(!out.fail(), "Failed to write Puffin file: {}", filePath); + + return {blobOffset, blobLength}; +} + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/DeletionVectorWriter.h b/velox/connectors/hive/iceberg/DeletionVectorWriter.h new file mode 100644 index 00000000000..a6d7f4f0b2d --- /dev/null +++ b/velox/connectors/hive/iceberg/DeletionVectorWriter.h @@ -0,0 +1,102 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace facebook::velox::connector::hive::iceberg { + +/// Writes Iceberg V3 deletion vectors as serialized 64-bit roaring bitmaps. +/// +/// Deletion vectors are an Iceberg V3 feature (format version 3). The roaring +/// bitmap encoding itself is version-independent, but the Puffin blob type +/// "deletion-vector-v1" is defined by the Iceberg V3 spec. +/// +/// DVs are compact roaring bitmaps stored as blobs inside Puffin files to mark +/// deleted rows in data files. This writer collects deleted row positions and +/// serializes them in the Roaring64Bitmap portable format, compatible with +/// Java's Roaring64Bitmap used by the Presto coordinator. +/// +/// The 64-bit format partitions positions by their upper 32 bits into groups, +/// each group containing a standard 32-bit RoaringBitmap for the lower 32 +/// bits. This supports files with more than 4 billion rows. +/// +/// Usage: +/// DeletionVectorWriter writer; +/// writer.addDeletedPosition(5); +/// writer.addDeletedPosition(5'000'000'000LL); +/// std::string blob = writer.serialize(); +class DeletionVectorWriter { + public: + DeletionVectorWriter() = default; + + /// Adds a deleted row position (0-based file row offset). + void addDeletedPosition(int64_t position); + + /// Adds multiple deleted row positions. + void addDeletedPositions(const std::vector& positions); + + /// Returns the number of deleted positions collected so far. + size_t numPositions() const { + return positions_.size(); + } + + /// Serializes collected positions into a 64-bit roaring bitmap. + /// + /// Format: Roaring64Bitmap portable serialization — + /// [numGroups: uint64] + /// For each group (sorted by highBits): + /// [highBits: uint32] + /// [32-bit RoaringBitmap in portable format] + /// + /// Each 32-bit RoaringBitmap uses SERIAL_COOKIE_NO_RUNCONTAINER (12346) + /// with array containers (cardinality <= 4096) or bitmap containers. + /// + /// @return Binary string containing the serialized 64-bit roaring bitmap. + std::string serialize() const; + + /// Clears all collected positions. + void clear(); + + private: + /// Serializes a single 32-bit roaring bitmap from sorted, deduplicated + /// values in the range [0, 2^32). + std::string serialize32(const std::vector& sorted) const; + + std::vector positions_; +}; + +/// Writes a Puffin file containing a single deletion vector blob. +/// +/// The Puffin file format consists of: +/// - 4-byte magic: "PUF1" +/// - Blob data (the serialized roaring bitmap) +/// - Footer: blob metadata + footer payload size + magic +/// +/// @param filePath Path to write the Puffin file. +/// @param blobData Serialized roaring bitmap bytes. +/// @param referencedDataFile Path of the data file this DV applies to. +/// @return Pair of (blobOffset, blobLength) within the written file. +std::pair writePuffinFile( + const std::string& filePath, + const std::string& blobData, + const std::string& referencedDataFile); + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/EqualityDeleteFileReader.cpp b/velox/connectors/hive/iceberg/EqualityDeleteFileReader.cpp new file mode 100644 index 00000000000..6d24875f375 --- /dev/null +++ b/velox/connectors/hive/iceberg/EqualityDeleteFileReader.cpp @@ -0,0 +1,360 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/EqualityDeleteFileReader.h" + +#include "velox/common/base/BitUtil.h" +#include "velox/connectors/hive/BufferedInputBuilder.h" +#include "velox/connectors/hive/HiveConnectorUtil.h" +#include "velox/connectors/hive/iceberg/IcebergMetadataColumns.h" +#include "velox/dwio/common/ReaderFactory.h" + +namespace facebook::velox::connector::hive::iceberg { + +namespace { + +// Hashes a single value from a vector at the given index. +// Handles lazy vectors via loadedVector(). Returns 0 for null values. +uint64_t hashValue(const VectorPtr& vectorPtr, vector_size_t index) { + const auto* vector = vectorPtr->loadedVector(); + if (vector->isNullAt(index)) { + return 0; + } + + auto type = vector->type(); + switch (type->kind()) { // NOLINT(clang-diagnostic-switch-enum) + case TypeKind::BOOLEAN: + return std::hash{}( + vector->as>()->valueAt(index)); + case TypeKind::TINYINT: + return std::hash{}( + vector->as>()->valueAt(index)); + case TypeKind::SMALLINT: + return std::hash{}( + vector->as>()->valueAt(index)); + case TypeKind::INTEGER: + return std::hash{}( + vector->as>()->valueAt(index)); + case TypeKind::BIGINT: + return std::hash{}( + vector->as>()->valueAt(index)); + case TypeKind::REAL: + return std::hash{}( + vector->as>()->valueAt(index)); + case TypeKind::DOUBLE: + return std::hash{}( + vector->as>()->valueAt(index)); + case TypeKind::VARCHAR: + case TypeKind::VARBINARY: { + auto stringView = vector->as>()->valueAt(index); + return folly::hasher{}( + std::string_view(stringView.data(), stringView.size())); + } + case TypeKind::TIMESTAMP: { + auto ts = vector->as>()->valueAt(index); + return std::hash{}(ts.toNanos()); + } + default: + VELOX_NYI( + "Equality delete hash not implemented for type: {}", + type->toString()); + } +} + +// Compares two values from vectors at given indices. +// Handles lazy vectors via loadedVector(). +bool compareValues( + const VectorPtr& leftPtr, + vector_size_t leftIndex, + const VectorPtr& rightPtr, + vector_size_t rightIndex) { + const auto* left = leftPtr->loadedVector(); + const auto* right = rightPtr->loadedVector(); + bool leftNull = left->isNullAt(leftIndex); + bool rightNull = right->isNullAt(rightIndex); + if (leftNull && rightNull) { + return true; + } + if (leftNull || rightNull) { + return false; + } + + auto type = left->type(); + switch (type->kind()) { // NOLINT(clang-diagnostic-switch-enum) + case TypeKind::BOOLEAN: + return left->as>()->valueAt(leftIndex) == + right->as>()->valueAt(rightIndex); + case TypeKind::TINYINT: + return left->as>()->valueAt(leftIndex) == + right->as>()->valueAt(rightIndex); + case TypeKind::SMALLINT: + return left->as>()->valueAt(leftIndex) == + right->as>()->valueAt(rightIndex); + case TypeKind::INTEGER: + return left->as>()->valueAt(leftIndex) == + right->as>()->valueAt(rightIndex); + case TypeKind::BIGINT: + return left->as>()->valueAt(leftIndex) == + right->as>()->valueAt(rightIndex); + case TypeKind::REAL: + return left->as>()->valueAt(leftIndex) == + right->as>()->valueAt(rightIndex); + case TypeKind::DOUBLE: + return left->as>()->valueAt(leftIndex) == + right->as>()->valueAt(rightIndex); + case TypeKind::VARCHAR: + case TypeKind::VARBINARY: { + auto leftValue = left->as>()->valueAt(leftIndex); + auto rightValue = + right->as>()->valueAt(rightIndex); + return std::string_view(leftValue.data(), leftValue.size()) == + std::string_view(rightValue.data(), rightValue.size()); + } + case TypeKind::TIMESTAMP: + return left->as>()->valueAt(leftIndex) == + right->as>()->valueAt(rightIndex); + default: + VELOX_NYI( + "Equality delete comparison not implemented for type: {}", + type->toString()); + } +} + +} // namespace + +EqualityDeleteFileReader::EqualityDeleteFileReader( + const IcebergDeleteFile& deleteFile, + const std::vector& equalityColumnNames, + const std::vector& equalityColumnTypes, + const std::string& /*baseFilePath*/, + FileHandleFactory* fileHandleFactory, + const ConnectorQueryCtx* connectorQueryCtx, + folly::Executor* executor, + const std::shared_ptr& fileConfig, + const std::shared_ptr& ioStatistics, + const std::shared_ptr& ioStats, + dwio::common::RuntimeStatistics& runtimeStats, + const std::string& connectorId) + : equalityColumnNames_(equalityColumnNames), + equalityColumnTypes_(equalityColumnTypes), + pool_(connectorQueryCtx->memoryPool()) { + VELOX_CHECK( + deleteFile.content == FileContent::kEqualityDeletes, + "Expected equality delete file but got content type: {}", + static_cast(deleteFile.content)); + VELOX_CHECK_GT(deleteFile.recordCount, 0, "Empty equality delete file."); + VELOX_CHECK( + !equalityColumnNames_.empty(), + "Equality delete file must specify at least one column."); + VELOX_CHECK_EQ( + equalityColumnNames_.size(), + equalityColumnTypes_.size(), + "Equality column names and types must have the same size."); + + // Build the file schema for the equality delete columns only. + auto deleteFileSchema = + ROW(std::vector(equalityColumnNames_), + std::vector(equalityColumnTypes_)); + + // Create a ScanSpec that reads only the equality delete columns. + auto scanSpec = std::make_shared(""); + for (size_t i = 0; i < equalityColumnNames_.size(); ++i) { + scanSpec->addField(equalityColumnNames_[i], static_cast(i)); + } + + auto deleteSplit = std::make_shared( + connectorId, + deleteFile.filePath, + deleteFile.fileFormat, + 0, + deleteFile.fileSizeInBytes); + + dwio::common::ReaderOptions deleteReaderOpts(pool_); + // TODO: Use separate IoStatistics for data and metadata. + deleteReaderOpts.setDataIoStats(ioStatistics); + deleteReaderOpts.setMetadataIoStats(ioStatistics); + configureReaderOptions( + fileConfig, + connectorQueryCtx, + deleteFileSchema, + deleteSplit, + /*tableParameters=*/{}, + deleteReaderOpts); + + const FileHandleKey fileHandleKey{ + .filename = deleteFile.filePath, + .tokenProvider = connectorQueryCtx->fsTokenProvider()}; + auto deleteFileHandleCachePtr = fileHandleFactory->generate(fileHandleKey); + auto deleteFileInput = BufferedInputBuilder::getInstance()->create( + *deleteFileHandleCachePtr, + deleteReaderOpts, + connectorQueryCtx, + ioStatistics, + ioStats, + executor); + + auto deleteReader = + dwio::common::getReaderFactory(deleteReaderOpts.fileFormat()) + ->createReader(std::move(deleteFileInput), deleteReaderOpts); + + if (!testFilters( + scanSpec.get(), + deleteReader.get(), + deleteSplit->filePath, + deleteSplit->partitionKeys, + {}, + fileConfig->readTimestampPartitionValueAsLocalTime( + connectorQueryCtx->sessionProperties()))) { + runtimeStats.skippedSplitBytes += static_cast(deleteSplit->length); + return; + } + + dwio::common::RowReaderOptions deleteRowReaderOpts; + configureRowReaderOptions( + {}, + scanSpec, + nullptr, + deleteFileSchema, + deleteSplit, + nullptr, + nullptr, + nullptr, + deleteRowReaderOpts); + + auto deleteRowReader = deleteReader->createRowReader(deleteRowReaderOpts); + + // Read the entire equality delete file and build the hash set. + VectorPtr output; + output = BaseVector::create(deleteFileSchema, 0, pool_); + + while (true) { + auto rowsRead = deleteRowReader->next( + std::max(static_cast(1'000), deleteFile.recordCount), output); + if (rowsRead == 0) { + break; + } + + auto numRows = output->size(); + if (numRows == 0) { + continue; + } + + output->loadedVector(); + auto rowOutput = std::dynamic_pointer_cast(output); + VELOX_CHECK_NOT_NULL(rowOutput); + + size_t batchIndex = deleteRows_.size(); + deleteRows_.push_back(rowOutput); + + // Resolve column indices on the first batch. + if (deleteColumnIndices_.empty()) { + for (const auto& colName : equalityColumnNames_) { + auto idx = rowOutput->type()->as().getChildIdx(colName); + deleteColumnIndices_.push_back(static_cast(idx)); + } + } + + // Hash each row and insert into the multimap. + for (vector_size_t i = 0; i < numRows; ++i) { + uint64_t hash = hashRow(rowOutput, i, deleteColumnIndices_); + deleteKeyHashes_.emplace(hash, DeleteKeyEntry{batchIndex, i}); + } + + // Reset output for next batch. + output = BaseVector::create(deleteFileSchema, 0, pool_); + } +} + +void EqualityDeleteFileReader::applyDeletes( + const RowVectorPtr& output, + BufferPtr deleteBitmap) { + if (deleteKeyHashes_.empty() || output->size() == 0) { + return; + } + + auto* bitmap = deleteBitmap->asMutable(); + + // For each row in the output, compute its hash and probe the delete set. + for (vector_size_t i = 0; i < output->size(); ++i) { + // Skip rows already deleted by positional/DV deletes. + if (bits::isBitSet(bitmap, i)) { + continue; + } + + uint64_t hash = hashRow(output, i, resolveOutputColumnIndices(output)); + auto range = deleteKeyHashes_.equal_range(hash); + + for (auto it = range.first; it != range.second; ++it) { + auto& entry = it->second; + if (equalRows(output, i, deleteRows_[entry.batchIndex], entry.rowIndex)) { + bits::setBit(bitmap, i); + break; + } + } + } +} + +const std::vector& +EqualityDeleteFileReader::resolveOutputColumnIndices( + const RowVectorPtr& row) const { + if (outputColumnIndices_.empty()) { + const auto& rowType = row->type()->asRow(); + outputColumnIndices_.reserve(equalityColumnNames_.size()); + for (const auto& colName : equalityColumnNames_) { + auto colIdx = rowType.getChildIdxIfExists(colName); + VELOX_CHECK( + colIdx.has_value(), + "Equality delete column not found in the output columns: {}", + colName); + outputColumnIndices_.push_back(static_cast(*colIdx)); + } + } + return outputColumnIndices_; +} + +uint64_t EqualityDeleteFileReader::hashRow( + const RowVectorPtr& row, + vector_size_t index, + const std::vector& colIndices) const { + uint64_t hash = 0; + + for (auto colIdx : colIndices) { + auto colHash = hashValue(row->childAt(colIdx), index); + hash ^= colHash + 0x9e3779b97f4a7c15ULL + (hash << 6) + (hash >> 2); + } + return hash; +} + +bool EqualityDeleteFileReader::equalRows( + const RowVectorPtr& left, + vector_size_t leftIndex, + const RowVectorPtr& right, + vector_size_t rightIndex) const { + const auto& leftColIndices = resolveOutputColumnIndices(left); + + for (size_t i = 0; i < leftColIndices.size(); ++i) { + if (!compareValues( + left->childAt(leftColIndices[i]), + leftIndex, + right->childAt(deleteColumnIndices_[i]), + rightIndex)) { + return false; + } + } + return true; +} + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/EqualityDeleteFileReader.h b/velox/connectors/hive/iceberg/EqualityDeleteFileReader.h new file mode 100644 index 00000000000..ad3a2be359a --- /dev/null +++ b/velox/connectors/hive/iceberg/EqualityDeleteFileReader.h @@ -0,0 +1,150 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include "velox/connectors/Connector.h" +#include "velox/connectors/hive/FileConfig.h" +#include "velox/connectors/hive/FileHandle.h" +#include "velox/connectors/hive/HiveConnectorSplit.h" +#include "velox/connectors/hive/iceberg/IcebergDeleteFile.h" +#include "velox/dwio/common/Reader.h" + +namespace facebook::velox::connector::hive::iceberg { + +/// Reads an Iceberg equality delete file and filters base data rows whose +/// equality delete column values match any row in the delete file. +/// +/// Iceberg equality delete files contain rows with values for one or more +/// columns (identified by equalityFieldIds). A base data row is deleted if +/// its values match ALL specified columns of ANY row in the delete file. +/// +/// Unlike positional deletes (which set bits before reading), equality deletes +/// require reading the base data first, then probing each row against the +/// delete set. The reader eagerly loads all delete key tuples from the file +/// into an in-memory hash set during construction. +/// +/// The equality delete column names are resolved from equalityFieldIds via +/// the table schema provided by the caller. +class EqualityDeleteFileReader { + public: + /// Constructs a reader for a single equality delete file. + /// + /// Eagerly reads the entire delete file and builds an in-memory hash set + /// of delete key tuples. The delete file is fully consumed during + /// construction. + /// + /// @param deleteFile Metadata about the equality delete file. Must have + /// content == FileContent::kEqualityDeletes and non-empty + /// equalityFieldIds. + /// @param equalityColumnNames Ordered column names corresponding to + /// equalityFieldIds, resolved by the caller from the table schema. + /// @param equalityColumnTypes Ordered column types corresponding to + /// equalityFieldIds. + /// @param baseFilePath Path of the base data file being read. + /// @param fileHandleFactory Factory for creating file handles. + /// @param connectorQueryCtx Query context for memory and config. + /// @param executor IO executor for async reads. + /// @param hiveConfig Hive configuration. + /// @param ioStatistics IO statistics collector. + /// @param ioStats IO stats tracker. + /// @param runtimeStats Runtime statistics for recording skipped bytes. + /// @param connectorId Connector identifier. + EqualityDeleteFileReader( + const IcebergDeleteFile& deleteFile, + const std::vector& equalityColumnNames, + const std::vector& equalityColumnTypes, + const std::string& baseFilePath, + FileHandleFactory* fileHandleFactory, + const ConnectorQueryCtx* connectorQueryCtx, + folly::Executor* executor, + const std::shared_ptr& fileConfig, + const std::shared_ptr& ioStatistics, + const std::shared_ptr& ioStats, + dwio::common::RuntimeStatistics& runtimeStats, + const std::string& connectorId); + + /// Applies equality deletes to the output vector by setting bits in the + /// delete bitmap for rows whose equality column values match any delete + /// key tuple. + /// + /// @param output The base data output vector to filter. + /// @param deleteBitmap Output bitmap. Bit i is set if row i matches an + /// equality delete. The bitmap must be pre-allocated to cover at least + /// output->size() rows. + void applyDeletes(const RowVectorPtr& output, BufferPtr deleteBitmap); + + /// Returns the number of delete key tuples loaded from the file. + size_t numDeleteKeys() const { + return deleteKeyHashes_.size(); + } + + /// Returns true if this reader has no delete keys (file was skipped or + /// empty). When true, applyDeletes() is a no-op. + bool empty() const { + return deleteKeyHashes_.empty(); + } + + private: + // Resolves column indices for the given row type, caching the result in + // outputColumnIndices_ for reuse across rows. + const std::vector& resolveOutputColumnIndices( + const RowVectorPtr& row) const; + + // Hashes a single row's equality delete columns into a uint64_t key. + uint64_t hashRow( + const RowVectorPtr& row, + vector_size_t index, + const std::vector& colIndices) const; + + // Checks whether two rows are equal on all equality delete columns. + bool equalRows( + const RowVectorPtr& left, + vector_size_t leftIndex, + const RowVectorPtr& right, + vector_size_t rightIndex) const; + + // Column names and types for equality delete comparison. + std::vector equalityColumnNames_; + std::vector equalityColumnTypes_; + + // Column indices in the delete file output vector. + std::vector deleteColumnIndices_; + + // Cached column indices for the output (probe) row type. Resolved lazily + // on first applyDeletes() call to avoid repeated name lookups per row. + mutable std::vector outputColumnIndices_; + + // All rows read from the equality delete file, stored for equality + // comparison during probing. + std::vector deleteRows_; + + // Hash multimap storing (hash → (batch index, row index)) for all delete + // key tuples. Used for O(1) average-case probing. + struct DeleteKeyEntry { + size_t batchIndex; + vector_size_t rowIndex; + }; + std::unordered_multimap deleteKeyHashes_; + + memory::MemoryPool* const pool_; +}; + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergColumnHandle.cpp b/velox/connectors/hive/iceberg/IcebergColumnHandle.cpp new file mode 100644 index 00000000000..28666689e6f --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergColumnHandle.cpp @@ -0,0 +1,51 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "velox/connectors/hive/TableHandle.h" +#include "velox/connectors/hive/iceberg/IcebergColumnHandle.h" +#include "velox/dwio/parquet/ParquetFieldId.h" +#include "velox/type/Subfield.h" +#include "velox/type/Type.h" + +namespace facebook::velox::connector::hive::iceberg { + +IcebergColumnHandle::IcebergColumnHandle( + const std::string& name, + ColumnType columnType, + TypePtr dataType, + parquet::ParquetFieldId icebergField, + std::vector requiredSubfields, + std::optional initialDefaultValue) + : HiveColumnHandle( + name, + columnType, + dataType, + dataType, + std::move(requiredSubfields), + ColumnParseParameters{ColumnParseParameters:: + PartitionDateValueFormat::kDaysSinceEpoch}), + field_(std::move(icebergField)), + initialDefaultValue_(std::move(initialDefaultValue)) {} + +const parquet::ParquetFieldId& IcebergColumnHandle::field() const { + return field_; +} + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergColumnHandle.h b/velox/connectors/hive/iceberg/IcebergColumnHandle.h new file mode 100644 index 00000000000..54b722dae9e --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergColumnHandle.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +#include "velox/connectors/hive/TableHandle.h" +#include "velox/dwio/parquet/ParquetFieldId.h" +#include "velox/type/Subfield.h" +#include "velox/type/Type.h" + +namespace facebook::velox::connector::hive::iceberg { + +class IcebergColumnHandle : public HiveColumnHandle { + public: + IcebergColumnHandle( + const std::string& name, + ColumnType columnType, + TypePtr dataType, + parquet::ParquetFieldId icebergField, + std::vector requiredSubfields = {}, + std::optional initialDefaultValue = std::nullopt); + + const parquet::ParquetFieldId& field() const; + + const std::optional& initialDefaultValue() const { + return initialDefaultValue_; + } + + private: + const parquet::ParquetFieldId field_; + const std::optional initialDefaultValue_; +}; + +using IcebergColumnHandlePtr = std::shared_ptr; + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergConfig.cpp b/velox/connectors/hive/iceberg/IcebergConfig.cpp new file mode 100644 index 00000000000..1b34b7c4eb7 --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergConfig.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/IcebergConfig.h" + +#include "velox/common/config/Config.h" + +namespace facebook::velox::connector::hive::iceberg { + +IcebergConfig::IcebergConfig( + const std::shared_ptr& config) + : config_(config) { + VELOX_CHECK_NOT_NULL( + config_, "Config is null for IcebergConfig initialization"); +} + +std::string IcebergConfig::functionPrefix() const { + return config_->get( + kFunctionPrefixConfig, kDefaultFunctionPrefix); +} + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergConfig.h b/velox/connectors/hive/iceberg/IcebergConfig.h new file mode 100644 index 00000000000..9eba3bd20b9 --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergConfig.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +#include "velox/common/config/Config.h" + +namespace facebook::velox::connector::hive::iceberg { + +/// Iceberg-specific connector configuration wrapper. +/// Provides accessors for Iceberg-only settings while sharing the same +/// underlying ConfigBase with HiveConfig. +class IcebergConfig { + public: + /// Iceberg function prefix. + static constexpr const char* kFunctionPrefixConfig = + "presto.iceberg-namespace"; + + /// Default prefix used to register Iceberg transform functions when no + /// connector config override is provided. + static constexpr const char* kDefaultFunctionPrefix = "$internal$.iceberg."; + + explicit IcebergConfig( + const std::shared_ptr& config); + + const std::shared_ptr& config() const { + return config_; + } + + std::string functionPrefix() const; + + private: + const std::shared_ptr config_; +}; + +using IcebergConfigPtr = std::shared_ptr; + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergConnector.cpp b/velox/connectors/hive/iceberg/IcebergConnector.cpp new file mode 100644 index 00000000000..685be818d2c --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergConnector.cpp @@ -0,0 +1,83 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/IcebergConnector.h" + +#include "velox/connectors/hive/HiveConnector.h" +#include "velox/connectors/hive/iceberg/IcebergConfig.h" +#include "velox/connectors/hive/iceberg/IcebergDataSink.h" +#include "velox/connectors/hive/iceberg/IcebergDataSource.h" + +namespace facebook::velox::connector::hive::iceberg { + +namespace { + +// Registers Iceberg partition transform functions with prefix. +// NOTE: These functions are registered for internal transform usage only. +// Upstream engines such as Prestissimo and Gluten should register the same +// functions with different prefixes to avoid conflicts. +void registerIcebergInternalFunctions(const std::string& prefix) { + static std::once_flag registerFlag; + + std::call_once(registerFlag, [prefix]() { + functions::iceberg::registerFunctions(prefix); + }); +} + +} // namespace + +IcebergConnector::IcebergConnector( + const std::string& id, + std::shared_ptr config, + folly::Executor* ioExecutor) + : HiveConnector(id, config, ioExecutor), + icebergConfig_(std::make_shared(connectorConfig())) { + registerIcebergInternalFunctions(icebergConfig_->functionPrefix()); +} + +std::unique_ptr IcebergConnector::createDataSource( + const RowTypePtr& outputType, + const ConnectorTableHandlePtr& tableHandle, + const ColumnHandleMap& columnHandles, + ConnectorQueryCtx* connectorQueryCtx) { + return std::make_unique( + outputType, + tableHandle, + columnHandles, + &fileHandleFactory_, + ioExecutor_, + connectorQueryCtx, + hiveConfig_); +} + +std::unique_ptr IcebergConnector::createDataSink( + RowTypePtr inputType, + ConnectorInsertTableHandlePtr connectorInsertTableHandle, + ConnectorQueryCtx* connectorQueryCtx, + CommitStrategy commitStrategy) { + auto icebergInsertHandle = checkedPointerCast( + connectorInsertTableHandle); + + return std::make_unique( + inputType, + icebergInsertHandle, + connectorQueryCtx, + commitStrategy, + hiveConfig_, + icebergConfig_); +} + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergConnector.h b/velox/connectors/hive/iceberg/IcebergConnector.h new file mode 100644 index 00000000000..32c8206d293 --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergConnector.h @@ -0,0 +1,97 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/connectors/hive/HiveConnector.h" +#include "velox/connectors/hive/iceberg/IcebergConfig.h" + +namespace facebook::velox::connector::hive::iceberg { + +/// Provides Iceberg table format support. +/// - Creates IcebergDataSource instances for reading Iceberg tables with +/// support for delete files and schema evolution. +/// - Creates IcebergDataSink instances for writing data with Iceberg-specific +/// partition transforms and commit metadata. +class IcebergConnector final : public HiveConnector { + public: + IcebergConnector( + const std::string& id, + std::shared_ptr config, + folly::Executor* ioExecutor); + + /// Creates IcebergDataSource for reading from Iceberg tables. + /// + /// @param outputType The schema of the output data to read. + /// @param tableHandle The table handle containing table metadata. + /// @param columnHandles Map of column names to column handles. + /// @param connectorQueryCtx Query context for the read operation. + /// @return IcebergDataSource instance configured for the read operation. + std::unique_ptr createDataSource( + const RowTypePtr& outputType, + const ConnectorTableHandlePtr& tableHandle, + const ColumnHandleMap& columnHandles, + ConnectorQueryCtx* connectorQueryCtx) override; + + /// Creates IcebergDataSink for writing to Iceberg tables. + /// + /// @param inputType The schema of the input data to write. + /// @param connectorInsertTableHandle Must be an IcebergInsertTableHandle + /// containing Iceberg-specific write configuration. + /// @param connectorQueryCtx Query context for the write operation. + /// @param commitStrategy Strategy for committing the write operation. Only + /// CommitStrategy::kNoCommit is supported for Iceberg tables. Files + /// are written directly with their final names and commit metadata is + /// returned for the coordinator to update the Iceberg metadata tables. + /// @return IcebergDataSink instance configured for the write operation. + std::unique_ptr createDataSink( + RowTypePtr inputType, + ConnectorInsertTableHandlePtr connectorInsertTableHandle, + ConnectorQueryCtx* connectorQueryCtx, + CommitStrategy commitStrategy) override; + + private: + const std::shared_ptr icebergConfig_; +}; + +class IcebergConnectorFactory final : public ConnectorFactory { + public: + static constexpr const char* kIcebergConnectorName = "iceberg"; + + IcebergConnectorFactory() : ConnectorFactory(kIcebergConnectorName) {} + + /// Creates a new IcebergConnector instance. + /// + /// @param id Unique identifier for this connector instance (typically the + /// catalog name). + /// @param config Connector configuration properties + /// @param ioExecutor Optional executor for asynchronous I/O operations such + /// as split preloading and file prefetching. When provided, enables + /// background file operations off the main driver thread. If nullptr, I/O + /// operations run synchronously. + /// @param cpuExecutor ConnectorFactory interface to support other connector + /// types that may need CPU-bound async work. Currently unused by + /// IcebergConnector. + /// @return Shared pointer to the newly created IcebergConnector instance + std::shared_ptr newConnector( + const std::string& id, + std::shared_ptr config, + folly::Executor* ioExecutor = nullptr, + [[maybe_unused]] folly::Executor* cpuExecutor = nullptr) override { + return std::make_shared(id, config, ioExecutor); + } +}; + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergDataFileStatistics.cpp b/velox/connectors/hive/iceberg/IcebergDataFileStatistics.cpp new file mode 100644 index 00000000000..018f3f8f68c --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergDataFileStatistics.cpp @@ -0,0 +1,58 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/IcebergDataFileStatistics.h" + +namespace facebook::velox::connector::hive::iceberg { + +folly::dynamic IcebergDataFileStatistics::toJson() const { + folly::dynamic json = folly::dynamic::object; + json["recordCount"] = numRecords; + + folly::dynamic columnSizes = folly::dynamic::object; + folly::dynamic valueCounts = folly::dynamic::object; + folly::dynamic nullValueCounts = folly::dynamic::object; + folly::dynamic nanValueCounts = folly::dynamic::object; + folly::dynamic lowerBounds = folly::dynamic::object; + folly::dynamic upperBounds = folly::dynamic::object; + + for (const auto& [fieldId, stats] : columnStats) { + auto fieldIdStr = folly::to(fieldId); + columnSizes[fieldIdStr] = stats.columnSize; + valueCounts[fieldIdStr] = stats.valueCount; + nullValueCounts[fieldIdStr] = stats.nullValueCount; + if (stats.nanValueCount.has_value()) { + nanValueCounts[fieldIdStr] = stats.nanValueCount.value(); + } + if (stats.lowerBound.has_value()) { + lowerBounds[fieldIdStr] = stats.lowerBound.value(); + } + if (stats.upperBound.has_value()) { + upperBounds[fieldIdStr] = stats.upperBound.value(); + } + } + + json["columnSizes"] = std::move(columnSizes); + json["valueCounts"] = std::move(valueCounts); + json["nullValueCounts"] = std::move(nullValueCounts); + json["nanValueCounts"] = std::move(nanValueCounts); + json["lowerBounds"] = std::move(lowerBounds); + json["upperBounds"] = std::move(upperBounds); + + return json; +} + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergDataFileStatistics.h b/velox/connectors/hive/iceberg/IcebergDataFileStatistics.h new file mode 100644 index 00000000000..5bcfb84b83f --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergDataFileStatistics.h @@ -0,0 +1,71 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include +#include + +namespace facebook::velox::connector::hive::iceberg { + +/// Statistics for an Iceberg data file, corresponding to the `data_file` +/// structure defined in the Iceberg specification: +/// https://iceberg.apache.org/spec/#data-file-fields. +/// +/// All column-level statistics maps are keyed by Iceberg field IDs (`int32_t`), +/// which uniquely identify columns in the Iceberg schema independent of column +/// names or physical column positions. +struct IcebergDataFileStatistics { + struct ColumnStats { + int64_t columnSize{0}; + + /// Total number of values for this field ID in the file, including null and + /// NaN values. + /// + /// For primitive (flat) columns, this is equal to the number of rows in the + /// file: numRows = valueCount = (nonNullValues + numNulls + numNaNs). + /// + /// For nested columns (e.g. elements inside an array), this represents the + /// total occurrences of the field across all rows, which is not necessarily + /// related to the top-level record count. + int64_t valueCount{0}; + int64_t nullValueCount{0}; + std::optional nanValueCount; + /// Base64 encoded lower bound. + std::optional lowerBound; + /// Base64 encoded upper bound. + std::optional upperBound; + }; + + int64_t numRecords{0}; + folly::F14FastMap columnStats; + + /// Returns a IcebergDataFileStatistics with all values set to zero/empty. + /// Useful for empty data files that have no actual data. + static IcebergDataFileStatistics empty() { + return IcebergDataFileStatistics{.numRecords = 0, .columnStats = {}}; + } + + folly::dynamic toJson() const; +}; + +using IcebergDataFileStatisticsPtr = std::shared_ptr; + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergDataSink.cpp b/velox/connectors/hive/iceberg/IcebergDataSink.cpp new file mode 100644 index 00000000000..f7eb3e39f5c --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergDataSink.cpp @@ -0,0 +1,545 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/IcebergDataSink.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "velox/common/base/Fs.h" +#include "velox/common/encode/Base64.h" +#include "velox/common/memory/MemoryArbitrator.h" +#include "velox/common/testutil/TestValue.h" +#include "velox/connectors/hive/PartitionIdGenerator.h" +#include "velox/connectors/hive/iceberg/IcebergColumnHandle.h" + +#ifdef VELOX_ENABLE_PARQUET +#include "velox/connectors/hive/iceberg/IcebergParquetStatsCollector.h" +#include "velox/dwio/parquet/writer/Writer.h" +#endif + +#include "velox/connectors/hive/iceberg/TransformExprBuilder.h" +#include "velox/connectors/hive/iceberg/WriterOptionsAdapter.h" +#include "velox/dwio/dwrf/writer/Writer.h" +#include "velox/exec/OperatorUtils.h" +#include "velox/type/Type.h" + +using facebook::velox::common::testutil::TestValue; + +namespace facebook::velox::connector::hive::iceberg { + +namespace { + +template +folly::dynamic extractPartitionValue( + const VectorPtr& child, + vector_size_t row) { + using T = typename TypeTraits::NativeType; + return child->asChecked>()->valueAt(row); +} + +template <> +folly::dynamic extractPartitionValue( + const VectorPtr& child, + vector_size_t row) { + return child->asChecked>()->valueAt(row).str(); +} + +template <> +folly::dynamic extractPartitionValue( + const VectorPtr& child, + vector_size_t row) { + return encoding::Base64::encode( + child->asChecked>()->valueAt(row)); +} + +template <> +folly::dynamic extractPartitionValue( + const VectorPtr& child, + vector_size_t row) { + return child->asChecked>()->valueAt(row).toMicros(); +} + +class IcebergFileNameGenerator : public FileNameGenerator { + public: + std::pair gen( + std::optional bucketId, + const std::shared_ptr insertTableHandle, + const ConnectorQueryCtx& connectorQueryCtx, + bool commitRequired) const override; + + folly::dynamic serialize() const override; + + std::string toString() const override; +}; + +std::string makeUuid() { + return boost::lexical_cast(boost::uuids::random_generator()()); +} + +std::pair IcebergFileNameGenerator::gen( + std::optional bucketId, + const std::shared_ptr insertTableHandle, + const ConnectorQueryCtx& connectorQueryCtx, + bool commitRequired) const { + auto targetFileName = insertTableHandle->locationHandle()->targetFileName(); + if (targetFileName.empty()) { + targetFileName = fmt::format("{}", makeUuid()); + } + auto fileFormat = dwio::common::toString(insertTableHandle->storageFormat()); + auto fileName = fmt::format("{}.{}", targetFileName, fileFormat); + return {fileName, fileName}; +} + +folly::dynamic IcebergFileNameGenerator::serialize() const { + folly::dynamic obj = folly::dynamic::object; + obj["name"] = "IcebergFileNameGenerator"; + return obj; +} + +std::string IcebergFileNameGenerator::toString() const { + return "IcebergFileNameGenerator"; +} + +} // namespace + +IcebergInsertTableHandle::IcebergInsertTableHandle( + std::vector inputColumns, + LocationHandlePtr locationHandle, + dwio::common::FileFormat tableStorageFormat, + IcebergPartitionSpecPtr partitionSpec, + std::optional compressionKind, + const std::unordered_map& serdeParameters) + : HiveInsertTableHandle( + std::vector( + inputColumns.begin(), + inputColumns.end()), + std::move(locationHandle), + tableStorageFormat, + nullptr, + compressionKind, + serdeParameters, + nullptr, + false, + std::make_shared()), + partitionSpec_(partitionSpec) { + VELOX_USER_CHECK( + !inputColumns_.empty(), + "Input columns cannot be empty for Iceberg tables."); + VELOX_USER_CHECK_NOT_NULL( + locationHandle_, "Location handle is required for Iceberg tables."); + VELOX_USER_CHECK( + isSupportedFileFormat(tableStorageFormat), + "Unsupported file format for writing Iceberg tables: {}", + dwio::common::toString(tableStorageFormat)); +} + +namespace { + +// Creates partition channels by mapping partition spec fields to input column +// indices. For each field in the partition spec, finds the corresponding +// partition key column in the input columns and records its index. +// +// @param inputColumns The input columns from the insert table handle. +// @param partitionSpec The Iceberg partition specification, or nullptr if +// unpartitioned. +// @return A vector of column indices representing the partition channels. Each +// index corresponds to a partition field in the spec and points to the +// matching partition key column in the input. Returns an empty vector if +// partitionSpec is nullptr. +std::vector createPartitionChannels( + const std::vector& inputColumns, + const IcebergPartitionSpecPtr& partitionSpec) { + std::vector channels; + if (!partitionSpec) { + return channels; + } + + // Build a map from partition key column names to their indices in the input. + std::unordered_map partitionKeyMap; + for (auto i = 0; i < inputColumns.size(); ++i) { + if (inputColumns[i]->isPartitionKey()) { + partitionKeyMap[inputColumns[i]->name()] = i; + } + } + + // For each field in the partition spec, find its corresponding input column + // index. + channels.reserve(partitionSpec->fields.size()); + for (const auto& field : partitionSpec->fields) { + if (auto it = partitionKeyMap.find(field.name); + it != partitionKeyMap.end()) { + channels.push_back(it->second); + } + } + + return channels; +} + +std::vector createDataChannels( + const IcebergInsertTableHandlePtr& insertTableHandle) { + std::vector dataChannels( + insertTableHandle->inputColumns().size()); + std::iota(dataChannels.begin(), dataChannels.end(), 0); + return dataChannels; +} + +// Creates a RowType schema for transformed partition values based on the +// partition specification. This RowType is used to wrap the transformed +// partition columns before passing them to the partition ID generator. +// +// For each partition field in the spec: +// - The column type is the result type of the partition transform (e.g., +// INTEGER for year transform, DATE for day transform). +// - The column name is the source column name for identity transforms, or +// "columnName_transformName" for non-identity transforms (e.g., "birth_year" +// for a year transform on a birth column). +// +// @param partitionSpec The Iceberg partition specification, or nullptr if +// unpartitioned. +// @return A RowType containing one column per partition field with appropriate +// names and types. Returns nullptr if partitionSpec is nullptr. +RowTypePtr createPartitionRowType( + const IcebergPartitionSpecPtr& partitionSpec) { + if (!partitionSpec) { + return nullptr; + } + + std::vector partitionKeyTypes; + std::vector partitionKeyNames; + + // Build column names and types for each partition field. + // Identity transforms use the source column name directly. + // Non-identity transforms use "columnName_transformName" format. + for (const auto& field : partitionSpec->fields) { + partitionKeyTypes.emplace_back(field.resultType()); + std::string key = field.transformType == TransformType::kIdentity + ? field.name + : fmt::format( + "{}_{}", + field.name, + TransformTypeName::toName(field.transformType)); + partitionKeyNames.emplace_back(std::move(key)); + } + + return ROW(std::move(partitionKeyNames), std::move(partitionKeyTypes)); +} + +} // namespace + +IcebergDataSink::IcebergDataSink( + RowTypePtr inputType, + IcebergInsertTableHandlePtr insertTableHandle, + const ConnectorQueryCtx* connectorQueryCtx, + CommitStrategy commitStrategy, + const std::shared_ptr& hiveConfig, + const IcebergConfigPtr& icebergConfig) + : IcebergDataSink( + std::move(inputType), + insertTableHandle, + connectorQueryCtx, + commitStrategy, + hiveConfig, + createPartitionChannels( + insertTableHandle->inputColumns(), + insertTableHandle->partitionSpec()), + createDataChannels(insertTableHandle), + createPartitionRowType(insertTableHandle->partitionSpec()), + icebergConfig) {} + +IcebergDataSink::IcebergDataSink( + RowTypePtr inputType, + IcebergInsertTableHandlePtr insertTableHandle, + const ConnectorQueryCtx* connectorQueryCtx, + CommitStrategy commitStrategy, + const std::shared_ptr& hiveConfig, + const std::vector& partitionChannels, + const std::vector& dataChannels, + RowTypePtr partitionRowType, + const IcebergConfigPtr& icebergConfig) + : HiveDataSink( + inputType, + insertTableHandle, + connectorQueryCtx, + commitStrategy, + hiveConfig, + 0, + nullptr, + partitionChannels, + dataChannels, + !partitionChannels.empty() + ? std::make_unique( + partitionRowType, + [&partitionChannels]() { + std::vector transformedChannels( + partitionChannels.size()); + std::iota( + transformedChannels.begin(), + transformedChannels.end(), + 0); + return transformedChannels; + }(), + hiveConfig->maxPartitionsPerWriters( + connectorQueryCtx->sessionProperties()), + connectorQueryCtx->memoryPool()) + : nullptr), + partitionSpec_(insertTableHandle->partitionSpec()), + transformEvaluator_( + !partitionChannels.empty() ? std::make_unique( + TransformExprBuilder::toExpressions( + partitionSpec_, + partitionChannels_, + inputType_, + icebergConfig->functionPrefix()), + connectorQueryCtx_) + : nullptr), + icebergPartitionName_( + partitionSpec_ != nullptr + ? std::make_unique(partitionSpec_) + : nullptr), + partitionRowType_(std::move(partitionRowType)), + icebergInsertTableHandle_(insertTableHandle) { + commitPartitionValue_.resize(maxOpenWriters_); + +#ifdef VELOX_ENABLE_PARQUET + // Only initialize Parquet stats collector for Parquet format tables + if (insertTableHandle->storageFormat() == dwio::common::FileFormat::PARQUET) { + std::vector columnHandles; + columnHandles.reserve(insertTableHandle->inputColumns().size()); + for (auto& column : insertTableHandle->inputColumns()) { + columnHandles.emplace_back( + checkedPointerCast(column)); + } + parquetStatsCollector_ = std::make_shared( + std::move(columnHandles)); + } +#endif +} + +std::vector IcebergDataSink::commitMessage() const { + std::vector commitTasks; + commitTasks.reserve(writerInfo_.size()); + + for (auto i = 0; i < writerInfo_.size(); ++i) { + const auto& writerInfo = writerInfo_.at(i); + VELOX_CHECK_NOT_NULL(writerInfo); + + // Following metadata (json format) is consumed by Presto CommitTaskData. + // It contains the minimal subset of metadata. + VELOX_CHECK_EQ(writerInfo->writtenFiles.size(), dataFileStats_[i].size()); + for (auto fileIdx = 0; fileIdx < writerInfo->writtenFiles.size(); + ++fileIdx) { + const auto& fileInfo = writerInfo->writtenFiles[fileIdx]; + // clang-format off + folly::dynamic commitData = folly::dynamic::object( + "path", (fs::path(writerInfo->writerParameters.targetDirectory()) / + fileInfo.targetFileName).string()) + ("fileSizeInBytes", fileInfo.fileSize) + ("metrics", dataFileStats_[i][fileIdx]->toJson()) + ("partitionSpecJson", + icebergInsertTableHandle_->partitionSpec() ? + icebergInsertTableHandle_->partitionSpec()->specId : 0) + // Sort order evolution is not supported. Set default id to 0 ( unsorted order). + ("sortOrderId", 0) + ("fileFormat", toManifestFormatString(icebergInsertTableHandle_->storageFormat())) + ("content", "DATA"); + // clang-format on + if (!commitPartitionValue_.empty() && + !commitPartitionValue_[i].isNull()) { + commitData["partitionDataJson"] = folly::toJson( + folly::dynamic::object( + "partitionValues", commitPartitionValue_[i])); + } + auto commitDataJson = folly::toJson(commitData); + commitTasks.push_back(commitDataJson); + } + } + return commitTasks; +} + +void IcebergDataSink::computePartitionAndBucketIds(const RowVectorPtr& input) { + VELOX_CHECK(isPartitioned()); + VELOX_CHECK_NOT_NULL(transformEvaluator_); + VELOX_CHECK_NOT_NULL(partitionIdGenerator_); + // Step 1: Apply transforms to input partition columns. + auto transformedColumns = transformEvaluator_->evaluate(input); + + // Step 2: Create RowVector based on transformed columns. + const auto& transformedRowVector = std::make_shared( + connectorQueryCtx_->memoryPool(), + partitionRowType_, + nullptr, + input->size(), + std::move(transformedColumns)); + partitionIdGenerator_->run(transformedRowVector, partitionIds_); +} + +std::string IcebergDataSink::getPartitionName(uint32_t partitionId) const { + VELOX_CHECK_NOT_NULL(icebergPartitionName_); + + return icebergPartitionName_->partitionName( + partitionId, + partitionIdGenerator_->partitionValues(), + partitionKeyAsLowerCase_); +} + +uint32_t IcebergDataSink::ensureWriter(const WriterId& id) { + auto writerId = HiveDataSink::ensureWriter(id); + if (isPartitioned() && commitPartitionValue_[writerId].isNull()) { + commitPartitionValue_[writerId] = makeCommitPartitionValue(writerId); + } + return writerId; +} + +std::shared_ptr +IcebergDataSink::createWriterOptions(size_t writerIndex) const { + auto options = HiveDataSink::createWriterOptions(writerIndex); + + // Dispatch format-specific Iceberg overrides through the adapter so each + // supported format (Parquet, DWRF, Nimble) gets its pre/post-processConfigs + // hooks applied uniformly. + const auto adapter = + createWriterOptionsAdapter(icebergInsertTableHandle_->storageFormat()); + if (adapter != nullptr) { + adapter->applyPreConfigs(*options); + } + +#ifdef VELOX_ENABLE_PARQUET + // Iceberg-runtime stats collector is not a static config; wire it inline. + if (auto parquetOptions = + std::dynamic_pointer_cast(options)) { + if (parquetStatsCollector_) { + parquetOptions->parquetFieldIds = + parquetStatsCollector_->parquetFieldIds().children; + } + } +#endif + + options->processConfigs( + *hiveConfig_->config(), *connectorQueryCtx_->sessionProperties()); + + if (adapter != nullptr) { + adapter->applyPostConfigs(*options); + } + + return options; +} + +folly::dynamic IcebergDataSink::makeCommitPartitionValue( + uint32_t writerIndex) const { + folly::dynamic partitionValues = folly::dynamic::array(); + const auto& transformedValues = partitionIdGenerator_->partitionValues(); + for (auto i = 0; i < partitionChannels_.size(); ++i) { + const auto& child = transformedValues->childAt(i); + if (child->isNullAt(writerIndex)) { + partitionValues.push_back(nullptr); + } else { + partitionValues.push_back(VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( + extractPartitionValue, child->typeKind(), child, writerIndex)); + } + } + return partitionValues; +} + +void IcebergDataSink::closeWriterAndCollectStats(size_t index) { + auto metadata = writers_[index]->close(); + const bool fileAdded = getCurrentFileBytes(index) > 0; + + // Finalize file info (capture file size, add to writtenFiles). + finalizeWriterFile(index); + + if (!fileAdded) { + return; + } +#ifdef VELOX_ENABLE_PARQUET + if (parquetStatsCollector_) { + dataFileStats_[index].emplace_back( + parquetStatsCollector_->aggregate(std::move(metadata))); + return; + } +#endif + dataFileStats_[index].emplace_back( + std::make_shared( + IcebergDataFileStatistics::empty())); +} + +void IcebergDataSink::rotateWriter(size_t index) { + VELOX_CHECK_LT(index, writers_.size()); + VELOX_CHECK_NOT_NULL(writers_[index]); + + // Ensure dataFileStats_ has an entry for this writer index. + if (dataFileStats_.size() <= index) { + dataFileStats_.resize(index + 1); + } + + // Close the writer to flush the footer and obtain file metadata, then + // aggregate Iceberg stats from the metadata. The base rotateWriter() would + // also call writers_[index]->close() but discards the returned metadata. + // We close the writer ourselves to capture the metadata, then reset the + // writer to prevent double close. + { + const memory::NonReclaimableSectionGuard nonReclaimableGuard( + writerInfo_[index]->nonReclaimableSectionHolder.get()); + closeWriterAndCollectStats(index); + } + + // Release old writer. The new writer will be created lazily on the next + // write call. + writers_[index].reset(); + + ++writerInfo_[index]->fileSequenceNumber; +} + +void IcebergDataSink::closeInternal() { + VELOX_CHECK_NE(state_, State::kRunning); + VELOX_CHECK_NE(state_, State::kFinishing); + + TestValue::adjust( + "facebook::velox::connector::hive::FileDataSink::closeInternal", this); + + if (state_ == State::kClosed) { + // Ensure dataFileStats_ has entries for all writers. + dataFileStats_.resize(writers_.size()); + + for (auto i = 0; i < writers_.size(); ++i) { + if (writers_[i] == nullptr) { + // Writer was rotated and is null. Stats for rotated files were already + // collected in rotateWriter(). No final file to close. + continue; + } + const memory::NonReclaimableSectionGuard nonReclaimableGuard( + writerInfo_[i]->nonReclaimableSectionHolder.get()); + closeWriterAndCollectStats(i); + } + } else { + for (auto i = 0; i < writers_.size(); ++i) { + if (writers_[i] == nullptr) { + continue; + } + memory::NonReclaimableSectionGuard nonReclaimableGuard( + writerInfo_[i]->nonReclaimableSectionHolder.get()); + writers_[i]->abort(); + } + } +} + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergDataSink.h b/velox/connectors/hive/iceberg/IcebergDataSink.h new file mode 100644 index 00000000000..8d4ee101a27 --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergDataSink.h @@ -0,0 +1,261 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#include "velox/connectors/hive/HiveDataSink.h" +#include "velox/connectors/hive/TableHandle.h" +#include "velox/connectors/hive/iceberg/IcebergColumnHandle.h" +#include "velox/connectors/hive/iceberg/IcebergDataFileStatistics.h" + +#ifdef VELOX_ENABLE_PARQUET +#include "velox/connectors/hive/iceberg/IcebergParquetStatsCollector.h" +#endif + +#include "velox/connectors/hive/iceberg/IcebergConfig.h" +#include "velox/connectors/hive/iceberg/IcebergPartitionName.h" +#include "velox/connectors/hive/iceberg/PartitionSpec.h" +#include "velox/connectors/hive/iceberg/TransformEvaluator.h" +#include "velox/functions/iceberg/Register.h" + +namespace facebook::velox::connector::hive::iceberg { + +/// Represents a request for Iceberg write. +class IcebergInsertTableHandle final : public HiveInsertTableHandle { + public: + /// @param inputColumns Columns from the table schema to write. + /// The input RowVector must have the same number of columns and matching + /// types in the same order. + /// Column names in the RowVector may differ from those in inputColumns, + /// only position and type must align. All columns present in the input + /// data must be included, mismatches can lead to write failure. + /// @param locationHandle Contains the target location information including: + /// - Base directory path where data files will be written. + /// - File naming scheme and temporary directory paths. + /// @param tableStorageFormat File format to use for writing data files. + /// @param partitionSpec Optional partition specification defining how to + /// partition the data. If nullptr, the table is unpartitioned and all data + /// is written to a single directory. + /// @param compressionKind Optional compression to apply to data files. + /// @param serdeParameters Additional serialization/deserialization parameters + /// for the file format. + IcebergInsertTableHandle( + std::vector inputColumns, + LocationHandlePtr locationHandle, + dwio::common::FileFormat tableStorageFormat, + IcebergPartitionSpecPtr partitionSpec, + std::optional compressionKind = {}, + const std::unordered_map& serdeParameters = {}); + + /// Returns the Iceberg partition specification that defines how the table + /// is partitioned. + const IcebergPartitionSpecPtr& partitionSpec() const { + return partitionSpec_; + } + + private: + const IcebergPartitionSpecPtr partitionSpec_; +}; + +using IcebergInsertTableHandlePtr = + std::shared_ptr; + +class IcebergDataSink : public HiveDataSink { + public: + IcebergDataSink( + RowTypePtr inputType, + IcebergInsertTableHandlePtr insertTableHandle, + const ConnectorQueryCtx* connectorQueryCtx, + CommitStrategy commitStrategy, + const std::shared_ptr& hiveConfig, + const IcebergConfigPtr& icebergConfig); + + /// Generates Iceberg-specific commit messages for all writers containing + /// metadata about written files. Creates a JSON object for each writer + /// in the format expected by Presto and Spark for Iceberg tables. + /// + /// Each commit message contains: + /// - path: full file path where data was written. + /// - fileSizeInBytes: raw bytes written to disk. + /// - metrics: object with recordCount (number of rows written). + /// - partitionSpecJson: partition specification. + /// - fileFormat: storage format. Either "PARQUET" or "ORC". DWRF files + /// are reported as "ORC" because Iceberg's file-format vocabulary has + /// no DWRF enum. + /// - content: file content type ("DATA" for data files). + /// + /// See + /// https://github.com/prestodb/presto/blob/master/presto-iceberg/src/main/java/com/facebook/presto/iceberg/CommitTaskData.java + /// + /// Note: Complete Iceberg metrics are not yet implemented, which results in + /// incomplete manifest files that may lead to suboptimal query planning. + /// + /// @return Vector of JSON strings, one per writer, formatted according to + /// Presto and Spark Iceberg commit protocol. + std::vector commitMessage() const override; + + private: + IcebergDataSink( + RowTypePtr inputType, + IcebergInsertTableHandlePtr insertTableHandle, + const ConnectorQueryCtx* connectorQueryCtx, + CommitStrategy commitStrategy, + const std::shared_ptr& hiveConfig, + const std::vector& partitionChannels, + const std::vector& dataChannels, + RowTypePtr partitionRowType, + const IcebergConfigPtr& icebergConfig); + + // Computes partition IDs for each row in the input batch by applying Iceberg + // partition transforms and generating unique partition identifiers. + // + // Performs a two-step process: + // 1. Applies Iceberg partition transforms (e.g., year, month, day, hour, + // bucket, truncate) to the input partition columns using + // transformEvaluator_ to produce transformed partition values. + // 2. Wraps the transformed columns in a RowVector with partitionRowType_ + // schema and passes it to partitionIdGenerator_ to compute partition IDs. + // + // The resulting partition IDs are stored in partitionIds_ buffer, where each + // element corresponds to a row in the input. These IDs are used to: + // - Route rows to the appropriate writer (one writer per unique partition). + // - Generate partition directory names via getPartitionName(). + // + // Note: Iceberg does not support bucketing, so this method only computes + // partition IDs, not bucket IDs. + // + // @param input The input RowVector containing rows to be partitioned. + void computePartitionAndBucketIds(const RowVectorPtr& input) override; + + // Returns the Iceberg partition directory name for the given partition ID. + // Converts the transformed partition values associated with the partition ID + // into an Iceberg compliant directory path + // (e.g., "date_year=2023/id_bucket=5"). + std::string getPartitionName(uint32_t partitionId) const override; + + // Ensures a writer exists for the given writer ID and returns its index. + // If the writer doesn't exist, creates it by calling appendWriter(). + // Additionally, extracts and stores the transformed partition values for + // the writer in commitPartitionValue_ if not already set, which will be + // included in the commit message as "partitionDataJson". + uint32_t ensureWriter(const WriterId& id) override; + + // Creates writer options configured for Iceberg table writes. Extends the + // base HiveDataSink writer options with Iceberg-specific settings: + // - Sets timestamp timezone to nullopt (UTC) for Iceberg compliance. + // - Sets timestamp precision to microseconds. + std::shared_ptr createWriterOptions( + size_t writerIndex) const override; + + // Extracts partition values for a specific writer to be included in the + // commit message. Converts the transformed partition values from columnar + // storage (partitionIdGenerator_->partitionValues() where each partition + // field is a separate column) to row storage (a folly::dynamic array of + // values for the given writer index) for JSON serialization. + // Returns nullptr for null partition values. + folly::dynamic makeCommitPartitionValue(uint32_t writerIndex) const; + + // Closes the active writer at 'index' to flush its file footer, captures + // the file metadata for Iceberg stats aggregation (via + // closeWriterAndCollectStats), then resets the writer so a new one is + // created lazily on the next write. Differs from the base + // FileDataSink::rotateWriter by also collecting per-file Iceberg stats + // before discarding the writer. + void rotateWriter(size_t index) override; + + // Closes all remaining writers and aggregates their file metadata into + // per-writer Iceberg stats (when state == kClosed). On any other state, + // aborts the writers without collecting stats. Stats for already-rotated + // files were collected during rotateWriter(). + void closeInternal() override; + + // Closes the writer at 'index', captures the resulting file metadata, and + // appends a per-file IcebergDataFileStatistics entry to dataFileStats_ + // (Parquet stats when the format provides them; an empty entry otherwise). + // Caller is responsible for the surrounding NonReclaimableSectionGuard. + void closeWriterAndCollectStats(size_t index); + + // Iceberg partition specification defining how the table is partitioned. + // Contains partition fields with source column names, transform types + // (e.g., identity, year, month, day, hour, bucket, truncate), transform + // parameters, and result types. Null if the table is unpartitioned. + const IcebergPartitionSpecPtr partitionSpec_; + + // Evaluates Iceberg partition transforms on input rows to produce transformed + // partition keys. Applies transforms defined in partitionSpec_ (e.g., + // year(date_col), bucket(id, 16)) to the corresponding input columns and + // returns a vector of transformed columns. The transformed keys are then + // wrapped in a RowVector and passed to IcebergPartitionIdGenerator. + // Null if the table is unpartitioned. + const std::unique_ptr transformEvaluator_; + + // Generates Iceberg compliant partition directory names from partition IDs. + // Converts transformed partition values to human-readable strings based on + // their transform types (e.g., year -> "2025", month -> "2025-11", hour -> + // "2025-11-12-13") and constructs URL-encoded partition paths. + // Null if the table is unpartitioned. + const std::unique_ptr icebergPartitionName_; + + // RowType schema for the transformed partition values RowVector. + // Contains one column per partition field in partitionSpec, where each + // column has: + // - Type: The result type of the partition transform (e.g., INTEGER for year + // transform, DATE for day transform). + // - Name: Source column name for identity transforms, or + // "columnName_transformName" for non-identity transforms (e.g., + // "date_year"). + // Used to construct the RowVector that wraps the transformed partition + // columns before passing them to IcebergPartitionIdGenerator for partition ID + // generation and to IcebergPartitionNameGenerator for partition path name + // generation. + RowTypePtr partitionRowType_; + + // Stores the transformed partition values for each writer to be included in + // the commit message sent to Presto. Indexed by writer index. Each entry + // contains the transformed partition values (as a folly::dynamic array) for + // that writer's partition, which are serialized to JSON as + // "partitionDataJson" in the commit protocol. These values represent the same + // transformed partition data as partitionIdGenerator_->partitionValues(), but + // converted from columnar storage (where each partition field is a separate + // column in the RowVector) to row storage (where each writer has a + // folly::dynamic array of values across all partition fields), ready for JSON + // serialization. + std::vector commitPartitionValue_; + + // Statistics for all data files written by this sink, organized by writer + // index and file index within each writer. These statistics are populated + // during rotateWriter() (for rotated files) and during closeInternal() + // (for the final file of each writer). These metrics are subsequently used + // to construct Iceberg commit messages. + // Outer vector: indexed by writer index (same as writerInfo_). + // Inner vector: one entry per file written by that writer (including + // rotated files and the final file). Each entry corresponds to one + // individual data file. + std::vector> dataFileStats_; + + const IcebergInsertTableHandlePtr icebergInsertTableHandle_; + +#ifdef VELOX_ENABLE_PARQUET + std::shared_ptr parquetStatsCollector_; +#endif +}; + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergDataSource.cpp b/velox/connectors/hive/iceberg/IcebergDataSource.cpp new file mode 100644 index 00000000000..30fa219e9ab --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergDataSource.cpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/IcebergDataSource.h" + +#include "velox/connectors/hive/iceberg/IcebergSplit.h" +#include "velox/connectors/hive/iceberg/IcebergSplitReader.h" + +namespace facebook::velox::connector::hive::iceberg { + +IcebergDataSource::IcebergDataSource( + const RowTypePtr& outputType, + const ConnectorTableHandlePtr& tableHandle, + const ColumnHandleMap& assignments, + FileHandleFactory* fileHandleFactory, + folly::Executor* ioExecutor, + const ConnectorQueryCtx* connectorQueryCtx, + const std::shared_ptr& hiveConfig) + : HiveDataSource( + outputType, + tableHandle, + assignments, + fileHandleFactory, + ioExecutor, + connectorQueryCtx, + hiveConfig), + columnHandles_(std::make_shared(assignments)) {} + +std::unique_ptr IcebergDataSource::createSplitReader() { + prepareSplit(); + auto icebergSplit = checkedPointerCast(split_); + + auto reader = std::make_unique( + icebergSplit, + tableHandle_, + &partitionKeys_, + connectorQueryCtx_, + fileConfig_, + readerOutputType_, + dataIoStats_, + metadataIoStats_, + ioStats_, + fileHandleFactory_, + ioExecutor_, + scanSpec_, + columnHandles_); + + return reader; +} + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergDataSource.h b/velox/connectors/hive/iceberg/IcebergDataSource.h new file mode 100644 index 00000000000..243ba61b043 --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergDataSource.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/connectors/hive/HiveDataSource.h" + +namespace facebook::velox::connector::hive::iceberg { + +/// Iceberg-specific data source that extends HiveDataSource. +/// +/// Provides Iceberg table format support by creating +/// IcebergSplitReader instances that handle: +/// - Positional delete files for row-level deletes. +/// - Schema evolution with column adaptation. +/// - Iceberg-specific metadata columns. +class IcebergDataSource : public HiveDataSource { + public: + IcebergDataSource( + const RowTypePtr& outputType, + const ConnectorTableHandlePtr& tableHandle, + const ColumnHandleMap& assignments, + FileHandleFactory* fileHandleFactory, + folly::Executor* ioExecutor, + const ConnectorQueryCtx* connectorQueryCtx, + const std::shared_ptr& hiveConfig); + + protected: + /// Creates an IcebergSplitReader for reading Iceberg data files. + /// + /// Unlike the base HiveDataSource which creates a generic FileSplitReader, + /// this method creates an IcebergSplitReader that handles Iceberg-specific + /// features like positional delete files and schema evolution. + std::unique_ptr createSplitReader() override; + + private: + /// Column handles map for accessing column metadata. + std::shared_ptr columnHandles_; +}; + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergDeleteFile.h b/velox/connectors/hive/iceberg/IcebergDeleteFile.h index 2f9206dfc26..a346fcccfef 100644 --- a/velox/connectors/hive/iceberg/IcebergDeleteFile.h +++ b/velox/connectors/hive/iceberg/IcebergDeleteFile.h @@ -27,6 +27,12 @@ enum class FileContent { kData, kPositionalDeletes, kEqualityDeletes, + /// Iceberg V3 deletion vector. A serialized roaring bitmap of deleted row + /// positions stored as a blob inside a Puffin file. More compact than V2 + /// positional delete files and avoids sorted merge of multiple delete files. + /// The coordinator extracts the blob offset and length from the Puffin + /// footer and provides them via IcebergDeleteFile fields. + kDeletionVector, }; struct IcebergDeleteFile { @@ -47,6 +53,28 @@ struct IcebergDeleteFile { // 1 is in range [10, 50], then upperBounds will contain entry <1, "50"> std::unordered_map upperBounds; + /// Data sequence number of this delete file, assigned by the Iceberg snapshot + /// that produced it. Per the Iceberg spec (V2+), an equality delete file must + /// only be applied to data files whose data sequence number is strictly less + /// than the delete file's data sequence number. A value of 0 means + /// "unassigned" (legacy V1 tables) and disables sequence number filtering. + int64_t dataSequenceNumber{0}; + + /// Byte offset of the deletion-vector blob inside the Puffin file pointed to + /// by 'filePath'. Only meaningful for kDeletionVector content. A value of 0 + /// is allowed (and is the default for non-DV files). + int64_t contentOffset{0}; + + /// Length in bytes of the deletion-vector blob inside the Puffin file pointed + /// to by 'filePath'. Only meaningful for kDeletionVector content. A value of + /// 0 means the consumer should fall back to reading until end-of-file. + int64_t contentLength{0}; + + /// For kDeletionVector content: path of the data file this DV applies to. + /// When set (non-empty), the reader can skip this DV for any other data file + /// as a belt-and-suspenders pruning step. Empty string disables the filter. + std::string referencedDataFile{}; + IcebergDeleteFile( FileContent _content, const std::string& _filePath, @@ -55,7 +83,11 @@ struct IcebergDeleteFile { uint64_t _fileSizeInBytes, std::vector _equalityFieldIds = {}, std::unordered_map _lowerBounds = {}, - std::unordered_map _upperBounds = {}) + std::unordered_map _upperBounds = {}, + int64_t _dataSequenceNumber = 0, + int64_t _contentOffset = 0, + int64_t _contentLength = 0, + std::string _referencedDataFile = {}) : content(_content), filePath(_filePath), fileFormat(_fileFormat), @@ -63,7 +95,11 @@ struct IcebergDeleteFile { fileSizeInBytes(_fileSizeInBytes), equalityFieldIds(_equalityFieldIds), lowerBounds(_lowerBounds), - upperBounds(_upperBounds) {} + upperBounds(_upperBounds), + dataSequenceNumber(_dataSequenceNumber), + contentOffset(_contentOffset), + contentLength(_contentLength), + referencedDataFile(std::move(_referencedDataFile)) {} }; } // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergMetadataColumns.h b/velox/connectors/hive/iceberg/IcebergMetadataColumns.h index 4cbf2a7862b..cacdb6009ec 100644 --- a/velox/connectors/hive/iceberg/IcebergMetadataColumns.h +++ b/velox/connectors/hive/iceberg/IcebergMetadataColumns.h @@ -28,6 +28,21 @@ struct IcebergMetadataColumn { std::shared_ptr type; std::string doc; + // Reserved Field IDs for Iceberg tables; see + // https://iceberg.apache.org/spec/#reserved-field-ids + static constexpr int32_t kPosId = 2'147'483'545; + static constexpr int32_t kFilePathId = 2'147'483'546; + static constexpr int32_t kRowId = 2'147'483'540; + static constexpr int32_t kLastUpdatedSequenceNumber = 2'147'483'539; + + static constexpr const char* kRowIdColumnName = "_row_id"; + static constexpr const char* kLastUpdatedSequenceNumberColumnName = + "_last_updated_sequence_number"; + // Info column keys provided in the split's infoColumns map. + static constexpr const char* kFirstRowIdInfoColumn = "$first_row_id"; + static constexpr const char* kDataSequenceNumberInfoColumn = + "$data_sequence_number"; + IcebergMetadataColumn( int _id, const std::string& _name, @@ -37,7 +52,7 @@ struct IcebergMetadataColumn { static std::shared_ptr icebergDeleteFilePathColumn() { return std::make_shared( - 2147483546, + kFilePathId, "file_path", VARCHAR(), "Path of a file in which a deleted row is stored"); @@ -45,7 +60,7 @@ struct IcebergMetadataColumn { static std::shared_ptr icebergDeletePosColumn() { return std::make_shared( - 2147483545, + kPosId, "pos", BIGINT(), "Ordinal position of a deleted row in the data file"); diff --git a/velox/connectors/hive/iceberg/IcebergParquetStatsCollector.cpp b/velox/connectors/hive/iceberg/IcebergParquetStatsCollector.cpp new file mode 100644 index 00000000000..10b04574c0a --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergParquetStatsCollector.cpp @@ -0,0 +1,174 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/IcebergParquetStatsCollector.h" + +#include "velox/common/Casts.h" +#include "velox/common/base/Exceptions.h" +#include "velox/common/encode/Base64.h" +#include "velox/connectors/hive/iceberg/IcebergColumnHandle.h" +#include "velox/connectors/hive/iceberg/IcebergDataFileStatistics.h" +#include "velox/dwio/common/FileMetadata.h" +#include "velox/dwio/parquet/writer/Writer.h" +#include "velox/dwio/parquet/writer/arrow/Metadata.h" +#include "velox/dwio/parquet/writer/arrow/Statistics.h" +#include "velox/type/Type.h" + +namespace facebook::velox::connector::hive::iceberg { + +namespace { + +void addAllRecursive( + const parquet::ParquetFieldId& field, + const TypePtr& type, + std::unordered_set& fieldIds) { + fieldIds.insert(field.fieldId); + + VELOX_CHECK_EQ(field.children.size(), type->size()); + for (auto i = 0; i < type->size(); ++i) { + addAllRecursive(field.children[i], type->childAt(i), fieldIds); + } +} + +// Recursively collects field IDs that should skip bounds collection. +// Repeated fields (e.g. MAP and ARRAY) are not currently supported by Iceberg. +// These fields, along with all their descendants, should skip bounds +// collection. +// @param field The Parquet field ID structure to process. +// @param type The Velox type corresponding to this field. +// @param fieldIds Output set to populate with field IDs to skip. +void collectSkipBoundsFieldIds( + const parquet::ParquetFieldId& field, + const TypePtr& type, + std::unordered_set& fieldIds) { + VELOX_CHECK_NOT_NULL(type, "Input column type cannot be null."); + + if (type->isMap() || type->isArray()) { + addAllRecursive(field, type, fieldIds); + return; + } + + VELOX_CHECK_EQ(field.children.size(), type->size()); + for (auto i = 0; i < type->size(); ++i) { + collectSkipBoundsFieldIds(field.children[i], type->childAt(i), fieldIds); + } +} + +} // namespace + +IcebergParquetStatsCollector::IcebergParquetStatsCollector( + const std::vector& inputColumns) { + parquetFieldIds_.children.reserve(inputColumns.size()); + for (const auto& columnHandle : inputColumns) { + parquetFieldIds_.children.emplace_back(columnHandle->field()); + collectSkipBoundsFieldIds( + columnHandle->field(), columnHandle->dataType(), skipBoundsFieldIds_); + } +} + +IcebergDataFileStatisticsPtr IcebergParquetStatsCollector::aggregate( + std::unique_ptr fileMetadata) { + // Empty data file. + if (!fileMetadata) { + return std::make_shared( + IcebergDataFileStatistics::empty()); + } + + auto parquetMetadata = + checkedPointerCast(std::move(fileMetadata)); + auto metadata = parquetMetadata->arrowMetadata(); + auto dataFileStats = std::make_shared(); + dataFileStats->numRecords = metadata->numRows(); + const auto numRowGroups = metadata->numRowGroups(); + + // Track global min/max statistics for each column across all row groups. + // Key: Iceberg field ID. + // Value: A pair of Statistics objects where: + // - first: The statistics from the row group containing the global minimum + // value. + // - second: The statistics from the row group containing the global maximum + // value. Two separate objects are stored because the global minimum and + // global maximum for a single column may originate from different row groups. + folly::F14FastMap< + int32_t, + std::pair< + std::shared_ptr, + std::shared_ptr>> + globalMinMaxStats; + + std::unordered_set fieldIds; + for (auto i = 0; i < numRowGroups; ++i) { + const auto& rowGroup = metadata->rowGroup(i); + + for (auto j = 0; j < rowGroup->numColumns(); ++j) { + const auto& columnChunk = rowGroup->columnChunk(j); + const auto fieldId = columnChunk->fieldId(); + fieldIds.insert(fieldId); + + auto& stats = dataFileStats->columnStats[fieldId]; + stats.valueCount += columnChunk->numValues(); + stats.columnSize += columnChunk->totalCompressedSize(); + + const auto& columnChunkStats = columnChunk->statistics(); + if (columnChunkStats) { + stats.nullValueCount += columnChunkStats->nullCount(); + + if (columnChunkStats->hasMinMax() && shouldStoreBounds(fieldId)) { + auto [it, inserted] = globalMinMaxStats.emplace( + fieldId, std::pair{columnChunkStats, columnChunkStats}); + + if (!inserted) { + auto& [minStats, maxStats] = it->second; + + if (columnChunkStats->maxGreaterThan(*maxStats)) { + maxStats = columnChunkStats; + } + if (columnChunkStats->minLessThan(*minStats)) { + minStats = columnChunkStats; + } + } + } + } + } + } + + for (const auto fieldId : fieldIds) { + const auto& [nanCount, hasNanCount] = metadata->getNaNCount(fieldId); + if (hasNanCount) { + dataFileStats->columnStats[fieldId].nanValueCount = nanCount; + } + } + + for (const auto& [fieldId, stats] : globalMinMaxStats) { + const auto& [minStats, maxStats] = stats; + + auto& columnStats = dataFileStats->columnStats[fieldId]; + const auto& lowerBound = + minStats->icebergLowerBoundInclusive(kDefaultTruncateLength); + columnStats.lowerBound = + encoding::Base64::encode(lowerBound.data(), lowerBound.size()); + + const auto upperBound = + maxStats->icebergUpperBoundExclusive(kDefaultTruncateLength); + if (upperBound.has_value()) { + columnStats.upperBound = + encoding::Base64::encode(upperBound->data(), upperBound->size()); + } + } + return dataFileStats; +} + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergParquetStatsCollector.h b/velox/connectors/hive/iceberg/IcebergParquetStatsCollector.h new file mode 100644 index 00000000000..8600f816d4d --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergParquetStatsCollector.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include "velox/connectors/hive/iceberg/IcebergColumnHandle.h" +#include "velox/connectors/hive/iceberg/IcebergDataFileStatistics.h" +#include "velox/dwio/common/FileMetadata.h" +#include "velox/dwio/parquet/ParquetFieldId.h" + +namespace facebook::velox::connector::hive::iceberg { + +/// Aggregates per-file Iceberg column statistics (column sizes, value counts, +/// null/NaN counts, lower/upper bounds) from Parquet row group metadata +/// produced by the Parquet writer. One instance is created per IcebergDataSink +/// and reused across all writers; aggregate() is invoked after each Parquet +/// file is closed. Also exposes the Parquet field IDs for the input columns +/// so they can be threaded into the writer's column metadata. +class IcebergParquetStatsCollector { + public: + explicit IcebergParquetStatsCollector( + const std::vector& inputColumns); + + /// Returns the Parquet field IDs for all input columns. + /// The field IDs are written to the Parquet data file's column metadata. + /// The return object describes a multi-column input. + const parquet::ParquetFieldId& parquetFieldIds() const { + return parquetFieldIds_; + } + + /// Aggregates Parquet file metadata into Iceberg data file statistics. + /// Iterates through all row groups and columns to collect: + /// - Record count, split offsets, value counts, column sizes, null counts. + /// - Min/max bounds (base64-encoded). Currently not collected for MAP and + /// ARRAY types and all their descendants. + /// @param fileMetadata The Parquet file metadata to aggregate. + IcebergDataFileStatisticsPtr aggregate( + std::unique_ptr fileMetadata); + + /// TODO: Need to support this config property. + /// 16 is default value. See DEFAULT_WRITE_METRICS_MODE_DEFAULT in + /// org.apache.iceberg.TableProperties. + constexpr static int32_t kDefaultTruncateLength{16}; + + private: + bool shouldStoreBounds(int32_t fieldId) const { + return !skipBoundsFieldIds_.contains(fieldId); + } + + // Hierarchical Parquet field IDs for all input columns. A single + // ParquetFieldId can describe all the columns including their nested + // children. + parquet::ParquetFieldId parquetFieldIds_; + + // Set of field IDs for which bounds collection should be skipped. + // This includes MAP and ARRAY types and all their descendants. + std::unordered_set skipBoundsFieldIds_; +}; + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergPartitionName.cpp b/velox/connectors/hive/iceberg/IcebergPartitionName.cpp new file mode 100644 index 00000000000..97a0f565b8b --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergPartitionName.cpp @@ -0,0 +1,137 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/connectors/hive/iceberg/IcebergPartitionName.h" +#include "velox/common/encode/Base64.h" +#include "velox/dwio/catalog/fbhive/FileUtils.h" +#include "velox/functions/prestosql/URLFunctions.h" + +namespace facebook::velox::connector::hive::iceberg { + +namespace { + +std::string escapePathName(const std::string& name) { + std::string encoded; + // Pre-allocate for worst case: every byte is invalid UTF-8. + // urlEscape() writes directly into the pre-allocated buffer and + // calls resize() at the end to shrink to the actual size used. + encoded.resize(name.size() * 9); + functions::detail::urlEscape(encoded, name); + return encoded; +} + +} // namespace + +IcebergPartitionName::IcebergPartitionName( + const IcebergPartitionSpecPtr& partitionSpec) { + VELOX_CHECK_NOT_NULL(partitionSpec); + transformTypes_.reserve(partitionSpec->fields.size()); + for (const auto& field : partitionSpec->fields) { + transformTypes_.emplace_back(field.transformType); + } +} + +std::string IcebergPartitionName::partitionName( + uint32_t partitionId, + const RowVectorPtr& partitionValues, + bool partitionKeyAsLowerCase) const { + auto toPartitionName = [this]( + auto value, const TypePtr& type, int columnIndex) { + return IcebergPartitionName::toName( + value, type, transformTypes_[columnIndex]); + }; + + return dwio::catalog::fbhive::FileUtils::makePartName( + HivePartitionName::partitionKeyValues( + partitionId, + partitionValues, + /*nullValueString=*/"null", + toPartitionName), + partitionKeyAsLowerCase, + /*useDefaultPartitionValue=*/false, + escapePathName); +} + +std::string IcebergPartitionName::toName( + int32_t value, + const TypePtr& type, + TransformType transformType) { + constexpr int32_t kEpochYear = 1970; + switch (transformType) { + case TransformType::kIdentity: { + if (type->isDate()) { + return DateType::toIso8601(value); + } + return fmt::to_string(value); + } + case TransformType::kDay: + return DATE()->toString(value); + case TransformType::kYear: + return fmt::format("{:04d}", kEpochYear + value); + case TransformType::kMonth: { + int32_t year = kEpochYear + value / 12; + int32_t month = 1 + value % 12; + if (month <= 0) { + month += 12; + year -= 1; + } + return fmt::format("{:04d}-{:02d}", year, month); + } + case TransformType::kHour: { + int64_t seconds = static_cast(value) * 3600; + std::tm tmValue; + VELOX_USER_CHECK( + Timestamp::epochToCalendarUtc(seconds, tmValue), + "Failed to convert seconds to time: {}", + seconds); + return fmt::format( + "{:04d}-{:02d}-{:02d}-{:02d}", + tmValue.tm_year + 1900, + tmValue.tm_mon + 1, + tmValue.tm_mday, + tmValue.tm_hour); + } + default: + return fmt::to_string(value); + } +} + +std::string IcebergPartitionName::toName( + Timestamp value, + const TypePtr& type, + TransformType transformType) { + VELOX_CHECK(transformType == TransformType::kIdentity); + TimestampToStringOptions options; + options.precision = TimestampPrecision::kMilliseconds; + options.zeroPaddingYear = true; + options.skipTrailingZeros = true; + options.leadingPositiveSign = true; + return value.toString(options); +} + +std::string IcebergPartitionName::toName( + StringView value, + const TypePtr& type, + TransformType transformType) { + VELOX_CHECK( + transformType == TransformType::kIdentity || + transformType == TransformType::kTruncate); + if (type->isVarbinary()) { + return encoding::Base64::encode(value.data(), value.size()); + } + return std::string(value); +} + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergPartitionName.h b/velox/connectors/hive/iceberg/IcebergPartitionName.h new file mode 100644 index 00000000000..751f3620c80 --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergPartitionName.h @@ -0,0 +1,114 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/connectors/hive/HivePartitionName.h" +#include "velox/connectors/hive/iceberg/PartitionSpec.h" + +namespace facebook::velox::connector::hive::iceberg { + +/// Generates Iceberg-compliant partition path names. +/// Converts partition keys to human-readable strings based on their transform +/// types (e.g., year, month, day, hour, identity, truncate) and constructs +/// URL-encoded partition paths in the format "key1=value1/key2=value2/...". +class IcebergPartitionName { + public: + /// @param partitionSpec Iceberg partition specification containing transform + /// definitions for each partition field. Used to get transform type and call + /// different format functions to convert transformed partition values to + /// human-readable strings. + IcebergPartitionName(const IcebergPartitionSpecPtr& partitionSpec); + + /// Generates an Iceberg compliant partition path string for the given + /// partition ID. + /// + /// Constructs a partition path in the format "key1=value1/key2=value2/..." + /// where: + /// - Keys are partition column names for identity transforms, or + /// "columnName_transformName" for non-identity transforms (e.g., + /// "date_year") + /// - Values are human-readable string representations of the transformed + /// partition keys, formatted according to their transform types + /// - Both keys and values are URL-encoded per java.net.URLEncoder.encode() + /// + /// Example: "store_id=123/date_year=2025/address_bucket=1" + /// + /// Typically called once per partition ID when creating a new writer for that + /// partition. + /// + /// @param partitionId Sequential partition ID (0-based) used as the row index + /// into partitionValues. Must be less than partitionValues->size(). + /// @param partitionValues RowVector containing transformed partition keys + /// for all partitions. Each row represents one unique partition, with + /// columns corresponding to partition fields in partitionSpec. Row at + /// partitionId contains the keys for this specific partition. + /// @param partitionKeyAsLowerCase Whether to convert partition keys to + /// lowercase in the generated partition path. When true, partition keys like + /// "Year" become "year" in the path "year=2025/...". + /// @return URL-encoded partition path string suitable for use in file paths. + std::string partitionName( + uint32_t partitionId, + const RowVectorPtr& partitionValues, + bool partitionKeyAsLowerCase) const; + + /// Generic template for formatting simple types that just need string + /// conversion. Specialized for types that need special handling. + template + FOLLY_ALWAYS_INLINE static std::string + toName(T value, const TypePtr& type, TransformType transformType) { + return HivePartitionName::toName(value, type); + } + + /// Converts an int32_t partition key to its string representation based on + /// the transform type: + /// - kIdentity: For DATE type return "YYYY-MM-DD" format (e.g., + /// "2025-11-07"). + /// For other types return the value as-is (e.g., "-123"). + /// - kDay: Returns date in "YYYY-MM-DD" format (e.g., "2025-11-07"). + /// - kYear: Returns 4-digit year "YYYY" (e.g., "2025"). + /// - kMonth: Returns "YYYY-MM" format (e.g., "2025-01"). + /// - kHour: Returns "YYYY-MM-DD-HH" format (e.g., "2025-11-07-21"). + static std::string + toName(int32_t value, const TypePtr& type, TransformType transformType); + + /// Returns timestamp formatted with milliseconds precision, zero-padded + /// year, trailing zeros skipped, and leading positive sign for years >= + /// 10000. Examples: + /// - Timestamp(0, 0) -> "1970-01-01T00:00:00". + /// - Timestamp(1609459200, 999000000) -> "2021-01-01T00:00:00.999". + /// - Timestamp(1640995200, 500000000) -> "2022-01-01T00:00:00.5". + /// - Timestamp(-1, 999000000) -> "1969-12-31T23:59:59.999". + /// - Timestamp(253402300800, 100000000) -> "+10000-01-01T00:00:00.1". + static std::string + toName(Timestamp value, const TypePtr& type, TransformType transformType); + + /// Converts a StringView partition key to its string representation. + /// - For VARBINARY type returns Base64-encoded string. + /// - For VARCHAR type returns the string value as-is. + static std::string + toName(StringView value, const TypePtr& type, TransformType transformType); + + private: + // Cached transform types, one per partition column. Created once in + // constructor and reused for all formatting operations. Index corresponds to + // column index in partitionSpec_->fields. + std::vector transformTypes_; +}; + +using IcebergPartitionNamePtr = std::shared_ptr; + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergSplit.cpp b/velox/connectors/hive/iceberg/IcebergSplit.cpp index e5cdf63c33b..a6bbea066d6 100644 --- a/velox/connectors/hive/iceberg/IcebergSplit.cpp +++ b/velox/connectors/hive/iceberg/IcebergSplit.cpp @@ -33,7 +33,8 @@ HiveIcebergSplit::HiveIcebergSplit( const std::shared_ptr& extraFileInfo, bool cacheable, const std::unordered_map& infoColumns, - std::optional properties) + std::optional properties, + int64_t dataSequenceNumber) : HiveConnectorSplit( connectorId, filePath, @@ -50,7 +51,8 @@ HiveIcebergSplit::HiveIcebergSplit( infoColumns, properties, std::nullopt, - std::nullopt) { + std::nullopt), + dataSequenceNumber(dataSequenceNumber) { // TODO: Deserialize _extraFileInfo to get deleteFiles; } @@ -69,7 +71,8 @@ HiveIcebergSplit::HiveIcebergSplit( bool cacheable, std::vector deletes, const std::unordered_map& infoColumns, - std::optional properties) + std::optional properties, + int64_t dataSequenceNumber) : HiveConnectorSplit( connectorId, filePath, @@ -87,5 +90,6 @@ HiveIcebergSplit::HiveIcebergSplit( properties, std::nullopt, std::nullopt), - deleteFiles(std::move(deletes)) {} + deleteFiles(std::move(deletes)), + dataSequenceNumber(dataSequenceNumber) {} } // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergSplit.h b/velox/connectors/hive/iceberg/IcebergSplit.h index eb2448dabd1..982fbeb4e19 100644 --- a/velox/connectors/hive/iceberg/IcebergSplit.h +++ b/velox/connectors/hive/iceberg/IcebergSplit.h @@ -25,6 +25,13 @@ namespace facebook::velox::connector::hive::iceberg { struct HiveIcebergSplit : public connector::hive::HiveConnectorSplit { std::vector deleteFiles; + /// Data sequence number of the base data file in this split. Per the Iceberg + /// spec (V2+), an equality delete file should only apply to data files whose + /// data sequence number is strictly less than the delete file's sequence + /// number. A value of 0 means "unassigned" (legacy V1 tables) and disables + /// sequence number filtering. + int64_t dataSequenceNumber{0}; + HiveIcebergSplit( const std::string& connectorId, const std::string& filePath, @@ -38,7 +45,8 @@ struct HiveIcebergSplit : public connector::hive::HiveConnectorSplit { const std::shared_ptr& extraFileInfo = {}, bool cacheable = true, const std::unordered_map& infoColumns = {}, - std::optional fileProperties = std::nullopt); + std::optional fileProperties = std::nullopt, + int64_t dataSequenceNumber = 0); // For tests only HiveIcebergSplit( @@ -55,7 +63,8 @@ struct HiveIcebergSplit : public connector::hive::HiveConnectorSplit { bool cacheable = true, std::vector deletes = {}, const std::unordered_map& infoColumns = {}, - std::optional fileProperties = std::nullopt); + std::optional fileProperties = std::nullopt, + int64_t dataSequenceNumber = 0); }; } // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergSplitReader.cpp b/velox/connectors/hive/iceberg/IcebergSplitReader.cpp index 8b8e9fe6ffa..63167152f3f 100644 --- a/velox/connectors/hive/iceberg/IcebergSplitReader.cpp +++ b/velox/connectors/hive/iceberg/IcebergSplitReader.cpp @@ -16,86 +16,493 @@ #include "velox/connectors/hive/iceberg/IcebergSplitReader.h" +#include + +#include +#include + +#include "velox/common/base/Exceptions.h" +#include "velox/common/encode/Base64.h" +#include "velox/connectors/hive/iceberg/IcebergColumnHandle.h" #include "velox/connectors/hive/iceberg/IcebergDeleteFile.h" +#include "velox/connectors/hive/iceberg/IcebergMetadataColumns.h" #include "velox/connectors/hive/iceberg/IcebergSplit.h" #include "velox/dwio/common/BufferUtil.h" +#include "velox/vector/DecodedVector.h" using namespace facebook::velox::dwio::common; +namespace { + +// Fills null slots in a BIGINT child vector with values produced by valueAt(i), +// where i is the output row index. If the whole vector is a null constant, +// replaces it with a flat vector. If only some slots are null, copies non-null +// values and fills null slots with valueAt(i). No-op when no nulls are present. +template +void fillNullsWithInt64( + facebook::velox::VectorPtr& child, + facebook::velox::memory::MemoryPool* pool, + F valueAt) { + using namespace facebook::velox; + child = BaseVector::loadedVectorShared(child); + const auto vectorSize = child->size(); + if (vectorSize == 0) { + return; + } + if (child->isConstantEncoding() && child->isNullAt(0)) { + auto flat = + BaseVector::create>(BIGINT(), vectorSize, pool); + for (vector_size_t i = 0; i < vectorSize; ++i) { + flat->set(i, valueAt(i)); + } + child = std::move(flat); + } else if (child->mayHaveNulls()) { + const DecodedVector decoded(*child); + if (decoded.mayHaveNulls()) { + auto flat = + BaseVector::create>(BIGINT(), vectorSize, pool); + for (vector_size_t i = 0; i < vectorSize; ++i) { + if (decoded.isNullAt(i)) { + flat->set(i, valueAt(i)); + } else { + flat->set(i, decoded.valueAt(i)); + } + } + child = std::move(flat); + } + } +} + +} // namespace + namespace facebook::velox::connector::hive::iceberg { +namespace { + +/// Returns true if a delete/update file should be skipped based on sequence +/// number conflict resolution. Per the Iceberg spec (V2+): +/// - Equality deletes apply when deleteSeqNum > dataSeqNum (i.e., skip when +/// deleteSeqNum <= dataSeqNum). +/// - Positional deletes, deletion vectors, and positional updates apply when +/// deleteSeqNum >= dataSeqNum (i.e., skip when deleteSeqNum < dataSeqNum), +/// because same-snapshot positional deletes SHOULD apply. +/// - A sequence number of 0 means "unassigned" (legacy V1 tables) and +/// disables filtering (never skip). +bool shouldSkipBySequenceNumber( + int64_t deleteFileSeqNum, + int64_t dataSeqNum, + bool isEqualityDelete) { + if (deleteFileSeqNum <= 0 || dataSeqNum <= 0) { + return false; + } + return isEqualityDelete ? (deleteFileSeqNum <= dataSeqNum) + : (deleteFileSeqNum < dataSeqNum); +} + +} // namespace IcebergSplitReader::IcebergSplitReader( - const std::shared_ptr& hiveSplit, - const HiveTableHandlePtr& hiveTableHandle, - const std::unordered_map* partitionKeys, + const std::shared_ptr& icebergSplit, + const FileTableHandlePtr& tableHandle, + const std::unordered_map* partitionKeys, const ConnectorQueryCtx* connectorQueryCtx, - const std::shared_ptr& hiveConfig, + const std::shared_ptr& fileConfig, const RowTypePtr& readerOutputType, - const std::shared_ptr& ioStats, - const std::shared_ptr& fsStats, + const std::shared_ptr& dataIoStats, + const std::shared_ptr& metadataIoStats, + const std::shared_ptr& ioStats, FileHandleFactory* const fileHandleFactory, folly::Executor* executor, - const std::shared_ptr& scanSpec) - : SplitReader( - hiveSplit, - hiveTableHandle, + const std::shared_ptr& scanSpec, + std::shared_ptr columnHandles) + : FileSplitReader( + icebergSplit, + tableHandle, partitionKeys, connectorQueryCtx, - hiveConfig, + fileConfig, readerOutputType, + dataIoStats, + metadataIoStats, ioStats, - fsStats, fileHandleFactory, executor, scanSpec), + icebergSplit_(icebergSplit), baseReadOffset_(0), splitOffset_(0), - deleteBitmap_(nullptr) {} + deleteBitmap_(nullptr), + columnHandles_(std::move(columnHandles)) {} void IcebergSplitReader::prepareSplit( std::shared_ptr metadataFilter, - dwio::common::RuntimeStatistics& runtimeStats) { - createReader(); + dwio::common::RuntimeStatistics& runtimeStats, + const folly::F14FastMap& fileReadOps) { + // Temporarily extend the file schema with projected row-lineage columns + // (_row_id, _last_updated_sequence_number) so the reader allocates output + // slots for them. Scoped to createReader() so getAdaptedRowType() sees the + // original schema. + { + auto originalFileSchema = baseReaderOpts_.fileSchema(); + auto restorer = folly::makeGuard([&] { + if (originalFileSchema) { + baseReaderOpts_.setFileSchema(originalFileSchema); + } + }); + if (originalFileSchema) { + auto names = originalFileSchema->names(); + auto types = originalFileSchema->children(); + bool modified = false; + for (const auto* colName : + {IcebergMetadataColumn::kRowIdColumnName, + IcebergMetadataColumn::kLastUpdatedSequenceNumberColumnName}) { + if (readerOutputType_->containsChild(colName) && + !originalFileSchema->containsChild(colName)) { + names.emplace_back(std::string(colName)); + types.emplace_back(BIGINT()); + modified = true; + } + } + if (modified) { + baseReaderOpts_.setFileSchema(ROW(std::move(names), std::move(types))); + } + } + createReader(fileReadOps); + } + if (emptySplit_) { return; } + + configureEqualityDeleteColumns(); + + firstRowId_ = std::nullopt; + if (auto it = icebergSplit_->infoColumns.find( + IcebergMetadataColumn::kFirstRowIdInfoColumn); + it != icebergSplit_->infoColumns.end()) { + try { + firstRowId_ = folly::to(it->second); + } catch (const folly::ConversionError&) { + VELOX_FAIL( + "Invalid $first_row_id value in split info columns: {}", it->second); + } + VELOX_CHECK_GE(*firstRowId_, 0, "First row ID must be non-negative"); + } + + dataSequenceNumber_ = std::nullopt; + if (auto it = icebergSplit_->infoColumns.find( + IcebergMetadataColumn::kDataSequenceNumberInfoColumn); + it != icebergSplit_->infoColumns.end()) { + try { + dataSequenceNumber_ = folly::to(it->second); + } catch (const folly::ConversionError&) { + VELOX_FAIL( + "Invalid $data_sequence_number value in split info columns: {}", + it->second); + } + VELOX_CHECK_GE( + *dataSequenceNumber_, 0, "Data sequence number must be non-negative"); + } + + // getAdaptedRowType() calls adaptColumns(), which may set + // _last_updated_sequence_number to a constant. Must run after + // dataSequenceNumber_ and firstRowId_ are initialized. auto rowType = getAdaptedRowType(); + lastUpdatedSeqNumOutputIndex_ = std::nullopt; + if (dataSequenceNumber_.has_value()) { + auto* seqNumSpec = scanSpec_->childByName( + IcebergMetadataColumn::kLastUpdatedSequenceNumberColumnName); + if (seqNumSpec && !seqNumSpec->isConstant()) { + lastUpdatedSeqNumOutputIndex_ = readerOutputType_->getChildIdxIfExists( + IcebergMetadataColumn::kLastUpdatedSequenceNumberColumnName); + } + } + + rowIdOutputIndex_ = std::nullopt; + if (firstRowId_.has_value()) { + rowIdOutputIndex_ = readerOutputType_->getChildIdxIfExists( + IcebergMetadataColumn::kRowIdColumnName); + } + if (checkIfSplitIsEmpty(runtimeStats)) { VELOX_CHECK(emptySplit_); return; } - createRowReader(std::move(metadataFilter), std::move(rowType)); + // Inject a row-number column when filters, random-skip, or positional + // deletes make the output-to-file-position mapping non-contiguous. + // Check split metadata rather than positionalDeleteFileReaders_ because + // the row reader must be configured before delete files are opened. + const bool hasPositionalDeletes = std::any_of( + icebergSplit_->deleteFiles.begin(), + icebergSplit_->deleteFiles.end(), + [](const IcebergDeleteFile& deleteFile) { + return deleteFile.content == FileContent::kPositionalDeletes && + deleteFile.recordCount > 0; + }); + useRowNumberColumn_ = rowIdOutputIndex_.has_value() && + (scanSpec_->hasFilter() || baseReaderOpts_.randomSkip() != nullptr || + hasPositionalDeletes); + if (useRowNumberColumn_) { + dwio::common::RowNumberColumnInfo rowNumInfo; + rowNumInfo.insertPosition = readerOutputType_->size(); + rowNumInfo.name = ""; + baseRowReaderOpts_.setRowNumberColumnInfo(rowNumInfo); + } else { + baseRowReaderOpts_.setRowNumberColumnInfo(std::nullopt); + } + + createRowReader(std::move(metadataFilter), std::move(rowType), std::nullopt); - std::shared_ptr icebergSplit = - std::dynamic_pointer_cast(hiveSplit_); baseReadOffset_ = 0; splitOffset_ = baseRowReader_->nextRowNumber(); positionalDeleteFileReaders_.clear(); + deletionVectorReaders_.clear(); + equalityDeleteFileReaders_.clear(); - const auto& deleteFiles = icebergSplit->deleteFiles; + const auto& deleteFiles = icebergSplit_->deleteFiles; for (const auto& deleteFile : deleteFiles) { if (deleteFile.content == FileContent::kPositionalDeletes) { if (deleteFile.recordCount > 0) { + if (shouldSkipBySequenceNumber( + deleteFile.dataSequenceNumber, + icebergSplit_->dataSequenceNumber, + /*isEqualityDelete=*/false)) { + continue; + } + + // Skip the delete file if all delete positions are before this split. + // TODO: Skip delete files where all positions are after the split, if + // split row count becomes available. + if (auto iter = + deleteFile.upperBounds.find(IcebergMetadataColumn::kPosId); + iter != deleteFile.upperBounds.end()) { + auto decodedBound = encoding::Base64::decode(iter->second); + VELOX_CHECK_EQ( + decodedBound.size(), + sizeof(uint64_t), + "Unexpected decoded size for positional delete upper bound."); + uint64_t posDeleteUpperBound; + std::memcpy( + &posDeleteUpperBound, decodedBound.data(), sizeof(uint64_t)); + posDeleteUpperBound = folly::Endian::little(posDeleteUpperBound); + if (posDeleteUpperBound < splitOffset_) { + continue; + } + } positionalDeleteFileReaders_.push_back( std::make_unique( deleteFile, - hiveSplit_->filePath, + fileSplit_->filePath, fileHandleFactory_, connectorQueryCtx_, ioExecutor_, - hiveConfig_, + fileConfig_, + dataIoStats_, ioStats_, - fsStats_, runtimeStats, splitOffset_, - hiveSplit_->connectorId)); + fileSplit_->connectorId)); + } + } else if (deleteFile.content == FileContent::kEqualityDeletes) { + if (deleteFile.recordCount > 0 && !deleteFile.equalityFieldIds.empty()) { + if (shouldSkipBySequenceNumber( + deleteFile.dataSequenceNumber, + icebergSplit_->dataSequenceNumber, + /*isEqualityDelete=*/true)) { + continue; + } + + auto [equalityColumnNames, equalityColumnTypes] = + resolveEqualityColumns(deleteFile); + + if (!equalityColumnNames.empty()) { + equalityDeleteFileReaders_.push_back( + std::make_unique( + deleteFile, + equalityColumnNames, + equalityColumnTypes, + fileSplit_->filePath, + fileHandleFactory_, + connectorQueryCtx_, + ioExecutor_, + fileConfig_, + dataIoStats_, + ioStats_, + runtimeStats, + fileSplit_->connectorId)); + } + } + } else if (deleteFile.content == FileContent::kDeletionVector) { + if (deleteFile.recordCount > 0) { + if (shouldSkipBySequenceNumber( + deleteFile.dataSequenceNumber, + icebergSplit_->dataSequenceNumber, + /*isEqualityDelete=*/false)) { + continue; + } + + // If 'referencedDataFile' is set and does not match the split's + // data file, log a warning. Do NOT skip the DV: silently dropping + // a coordinator-shipped DV could mask deletes if the planner and + // worker disagree on path normalization (trailing slash, scheme + // prefix like s3:// vs s3a://, percent-encoding). The coordinator + // already filters per-split, so the worker-side check is purely + // diagnostic. + if (!deleteFile.referencedDataFile.empty() && + deleteFile.referencedDataFile != fileSplit_->filePath) { + LOG(WARNING) + << "Iceberg DV referencedDataFile does not match split path. " + << "Applying DV anyway. referencedDataFile='" + << deleteFile.referencedDataFile << "' splitPath='" + << fileSplit_->filePath << "'"; + } + + deletionVectorReaders_.push_back( + std::make_unique( + deleteFile, splitOffset_, connectorQueryCtx_->memoryPool())); } } else { - VELOX_NYI(); + VELOX_NYI( + "Unsupported delete file content type: {}", + static_cast(deleteFile.content)); + } + } +} + +void IcebergSplitReader::configureEqualityDeleteColumns() { + // Reset partition-column tracking from any prior split before re-augmenting. + equalityAugmentedPartitionColumns_.clear(); + + std::vector extraEqualityColumns; + std::vector extraNames; + std::vector extraTypes; + const auto& deleteFiles = icebergSplit_->deleteFiles; + const auto& splitPartitionKeys = icebergSplit_->partitionKeys; + + for (const auto& deleteFile : deleteFiles) { + if (deleteFile.content != FileContent::kEqualityDeletes || + deleteFile.recordCount == 0 || deleteFile.equalityFieldIds.empty()) { + continue; + } + if (shouldSkipBySequenceNumber( + deleteFile.dataSequenceNumber, + icebergSplit_->dataSequenceNumber, + /*isEqualityDelete=*/true)) { + continue; + } + + auto [equalityColumnNames, equalityColumnTypes] = + resolveEqualityColumns(deleteFile); + for (size_t i = 0; i < equalityColumnNames.size(); ++i) { + const auto& name = equalityColumnNames[i]; + // Skip if this column was already added by a previous delete file. + if (std::find( + extraEqualityColumns.begin(), extraEqualityColumns.end(), name) != + extraEqualityColumns.end()) { + continue; + } + auto* fieldSpec = scanSpec_->childByName(name); + const bool alreadyInOutput = + readerOutputType_->getChildIdxIfExists(name).has_value(); + if (fieldSpec != nullptr && fieldSpec->projectOut() && alreadyInOutput) { + // Already projected by the user (or a previous augmentation) AND + // present in the output type by name. The equality-delete reader can + // probe it directly. + continue; + } + // Either no spec exists, or one exists but is filter-only, or the + // scan-spec child is projected but the column is missing from + // 'readerOutputType_'. In all cases ensure the column ends up in + // 'readerOutputType_' AND has a projected scan-spec child with a + // non-conflicting channel. + if (fieldSpec == nullptr) { + fieldSpec = scanSpec_->getOrCreateChild(name); + } + fieldSpec->setProjectOut(true); + fieldSpec->setChannel( + static_cast( + readerOutputType_->size() + extraEqualityColumns.size())); + + // For partition columns set the partition value directly as a constant + // on the scan-spec child. This is independent of whether the data file + // contains the partition column physically. With the constant set + // up-front, 'adaptColumns' does not need any special-case logic for + // augmented partition columns and the read does not depend on the + // writer's choice of including the partition column in the file. + auto partitionIt = splitPartitionKeys.find(name); + if (partitionIt != splitPartitionKeys.end()) { + // Iceberg encodes DATE partition values as the integer number of + // days since the Unix epoch (e.g. "19345"). The standard + // 'setPartitionValue' helper learns this from the planner-supplied + // ColumnHandle via 'isPartitionDateValueDaysSinceEpoch()', but no + // ColumnHandle is available here when the partition column is not + // in the user's projection. Derive the flag from the column type + // instead — Iceberg always uses days-since-epoch for DATE. + const bool isDaysSinceEpoch = equalityColumnTypes[i]->isDate(); + auto constant = newConstantFromString( + equalityColumnTypes[i], + partitionIt->second, + connectorQueryCtx_->memoryPool(), + fileConfig_->readTimestampPartitionValueAsLocalTime( + connectorQueryCtx_->sessionProperties()), + isDaysSinceEpoch); + fieldSpec->setConstantValue(constant); + // Mirror Java's PARTITION_KEY column-type marking: this column's + // value MUST come from the partition metadata, never from the file + // body. Track it so 'adaptColumns' Branch 1 does not later wipe the + // constant when the file happens to also carry the column. + equalityAugmentedPartitionColumns_.insert(name); + } + + extraEqualityColumns.push_back(name); + extraNames.push_back(name); + extraTypes.push_back(equalityColumnTypes[i]); } } + + if (extraEqualityColumns.empty()) { + return; + } + + // Extend 'readerOutputType_' so the upstream FileDataSource allocates the + // output RowVector wide enough for the augmented scan-spec channels. The + // original projection columns remain at indices [0, originalSize), so + // FileDataSource's positional projection still returns exactly the + // user-requested columns. + auto names = readerOutputType_->names(); + auto types = readerOutputType_->children(); + names.insert(names.end(), extraNames.begin(), extraNames.end()); + types.insert(types.end(), extraTypes.begin(), extraTypes.end()); + readerOutputType_ = ROW(std::move(names), std::move(types)); +} + +std::pair, std::vector> +IcebergSplitReader::resolveEqualityColumns( + const IcebergDeleteFile& deleteFile) const { + std::vector equalityColumnNames; + std::vector equalityColumnTypes; + + const auto& dataColumns = tableHandle_->dataColumns(); + VELOX_CHECK( + dataColumns != nullptr, + "Iceberg equality delete file '{}' cannot be processed because " + "table data columns are not available in HiveTableHandle.", + deleteFile.filePath); + for (const auto& eqFieldId : deleteFile.equalityFieldIds) { + // Field IDs are 1-based sequential for non-evolved schemas. + auto colIdx = static_cast(eqFieldId - 1); + VELOX_CHECK_LT( + colIdx, + dataColumns->size(), + "Equality delete field ID {} out of range. This may indicate " + "schema evolution with non-sequential field IDs, which is " + "not yet supported.", + eqFieldId); + equalityColumnNames.push_back(dataColumns->nameOf(colIdx)); + equalityColumnTypes.push_back(dataColumns->childAt(colIdx)); + } + return {std::move(equalityColumnNames), std::move(equalityColumnTypes)}; } uint64_t IcebergSplitReader::next(uint64_t size, VectorPtr& output) { @@ -114,7 +521,8 @@ uint64_t IcebergSplitReader::next(uint64_t size, VectorPtr& output) { return 0; } - if (!positionalDeleteFileReaders_.empty()) { + if (!positionalDeleteFileReaders_.empty() || + !deletionVectorReaders_.empty()) { auto numBytes = bits::nbytes(actualSize); dwio::common::ensureCapacity( deleteBitmap_, numBytes, connectorQueryCtx_->memoryPool(), false, true); @@ -129,6 +537,17 @@ uint64_t IcebergSplitReader::next(uint64_t size, VectorPtr& output) { ++iter; } } + + for (auto iter = deletionVectorReaders_.begin(); + iter != deletionVectorReaders_.end();) { + (*iter)->readDeletePositions(baseReadOffset_, actualSize, deleteBitmap_); + + if ((*iter)->noMoreData()) { + iter = deletionVectorReaders_.erase(iter); + } else { + ++iter; + } + } } mutation.deletedRows = deleteBitmap_ && deleteBitmap_->size() > 0 @@ -137,7 +556,261 @@ uint64_t IcebergSplitReader::next(uint64_t size, VectorPtr& output) { auto rowsScanned = baseRowReader_->next(actualSize, output, &mutation); + auto* pool = connectorQueryCtx_->memoryPool(); + + if (rowsScanned > 0 && + (lastUpdatedSeqNumOutputIndex_.has_value() || + rowIdOutputIndex_.has_value())) { + auto* rowOutput = output->as(); + VELOX_DCHECK_NOT_NULL( + rowOutput, "Expected RowVector output from table scan"); + + // If lastUpdatedSeqNumOutputIndex_ is present but dataSequenceNumber_ is + // not, the output is a constant NULL value vector already populated by the + // reader via adaptColumns. + if (lastUpdatedSeqNumOutputIndex_.has_value() && + dataSequenceNumber_.has_value()) { + auto& seqNumChild = rowOutput->childAt(*lastUpdatedSeqNumOutputIndex_); + const int64_t seqNum = static_cast(*dataSequenceNumber_); + fillNullsWithInt64( + seqNumChild, pool, [seqNum](vector_size_t) { return seqNum; }); + } + + if (rowIdOutputIndex_.has_value() && firstRowId_.has_value()) { + auto& rowIdChild = rowOutput->childAt(*rowIdOutputIndex_); + if (useRowNumberColumn_) { + // Use the injected row-number column for file-absolute positions. + // The row-number column is always appended last (at index + // readerOutputType_->size()). + const DecodedVector decodedRowNums( + *rowOutput->childAt(readerOutputType_->size())); + const int64_t firstRowId = static_cast(*firstRowId_); + fillNullsWithInt64(rowIdChild, pool, [&](vector_size_t i) { + return firstRowId + decodedRowNums.valueAt(i); + }); + + // Strip the injected row-number column (always last) from the output. + auto children = rowOutput->children(); + children.pop_back(); + output = std::make_shared( + rowOutput->pool(), + readerOutputType_, + rowOutput->nulls(), + rowOutput->size(), + std::move(children)); + } else { + // Contiguous output: _row_id = firstRowId_ + file_pos. + const int64_t base = static_cast(*firstRowId_) + + static_cast(splitOffset_ + baseReadOffset_); + fillNullsWithInt64(rowIdChild, pool, [base](vector_size_t i) { + return base + static_cast(i); + }); + } + } + } + + // Apply equality deletes after reading base data. Unlike positional deletes + // (which set bits before reading), equality deletes require the data values + // to be available for comparison. + if (rowsScanned > 0 && !equalityDeleteFileReaders_.empty()) { + auto outputRowVector = std::dynamic_pointer_cast(output); + VELOX_CHECK_NOT_NULL( + outputRowVector, "Output must be a RowVector for equality deletes."); + + auto numRows = outputRowVector->size(); + + // Use a separate bitmap for equality deletes to track which rows to + // remove from the output. + BufferPtr eqDeleteBitmap = AlignedBuffer::allocate(numRows, pool); + std::memset( + eqDeleteBitmap->asMutable(), 0, eqDeleteBitmap->size()); + + for (auto& reader : equalityDeleteFileReaders_) { + reader->applyDeletes(outputRowVector, eqDeleteBitmap); + } + + // Count surviving rows and compact the output if any rows were deleted. + auto* eqBitmap = eqDeleteBitmap->as(); + vector_size_t numDeleted = 0; + for (vector_size_t i = 0; i < numRows; ++i) { + if (bits::isBitSet(eqBitmap, i)) { + ++numDeleted; + } + } + + if (numDeleted > 0) { + vector_size_t numSurviving = numRows - numDeleted; + if (numSurviving == 0) { + // All rows in this batch were deleted by equality deletes. Do not + // return 0 here — that would be interpreted as end-of-split and + // prematurely stop scanning remaining rows in the data file. + // Instead, set output to an empty vector and return the original + // scanned count so the caller continues reading. + output = BaseVector::create(outputRowVector->type(), 0, pool); + } else { + // Build a list of surviving row ranges and use it to compact. + std::vector ranges; + ranges.reserve(numSurviving); + vector_size_t targetIdx = 0; + for (vector_size_t i = 0; i < numRows; ++i) { + if (!bits::isBitSet(eqBitmap, i)) { + ranges.push_back({i, targetIdx++, 1}); + } + } + + auto newOutput = + BaseVector::create(outputRowVector->type(), numSurviving, pool); + newOutput->copyRanges(outputRowVector.get(), ranges); + newOutput->resize(numSurviving); + output = newOutput; + rowsScanned = numSurviving; + } + } + + return rowsScanned; + } + return rowsScanned; } +std::vector IcebergSplitReader::adaptColumns( + const RowTypePtr& fileType, + const RowTypePtr& tableSchema) const { + std::vector columnTypes = fileType->children(); + auto& childrenSpecs = scanSpec_->children(); + const auto& splitInfoColumns = icebergSplit_->infoColumns; + const bool readTimestampAsLocalTime = + fileConfig_->readTimestampPartitionValueAsLocalTime( + connectorQueryCtx_->sessionProperties()); + // Iceberg table stores all column's data in data file. + for (const auto& childSpec : childrenSpecs) { + const std::string& fieldName = childSpec->fieldName(); + if (auto iter = splitInfoColumns.find(fieldName); + iter != splitInfoColumns.end()) { + auto infoColumnType = readerOutputType_->findChild(fieldName); + auto constant = newConstantFromString( + infoColumnType, + iter->second, + connectorQueryCtx_->memoryPool(), + readTimestampAsLocalTime, + false); + childSpec->setConstantValue(constant); + } else { + auto fileTypeIdx = fileType->getChildIdxIfExists(fieldName); + auto outputTypeIdx = readerOutputType_->getChildIdxIfExists(fieldName); + if (outputTypeIdx.has_value() && fileTypeIdx.has_value()) { + if (equalityAugmentedPartitionColumns_.count(fieldName) > 0) { + // This column was pre-installed as a partition-value constant by + // 'configureEqualityDeleteColumns'. Mirror Java's PARTITION_KEY + // semantics — the value comes from partition metadata, never from + // the file body, even though the file body happens to carry the + // column. Leave the constant in place; the column type entry for + // the file column index stays as the file type so Velox does not + // try to bind this output channel to a file read. + continue; + } + childSpec->setConstantValue(nullptr); + auto& outputType = readerOutputType_->childAt(*outputTypeIdx); + columnTypes[*fileTypeIdx] = outputType; + } else if (!fileTypeIdx.has_value()) { + // Handle columns missing from the data file in several scenarios: + // 1. Partition columns from a Hive-migrated table where partition + // column values are stored in partition metadata rather than in + // the data file itself. + // 2. Schema evolution with default values (Iceberg V3): Column was + // added with an initial-default value. Use the default instead of + // NULL. + // 3. Schema evolution: Column was added after the data file was + // written and doesn't exist in older data files. + // 4. _last_updated_sequence_number: For Iceberg V3 row lineage, if + // the column is not in the file, inherit the data sequence number + // from the file's manifest entry. + // 5. _row_id: For Iceberg V3 row lineage, if the column is not in + // the file, set as NULL constant here. When first_row_id is + // available, next() will replace NULL with first_row_id + file + // position. + // 6. Equality-delete partition columns not in the user's projection: + // 'configureEqualityDeleteColumns' has already pre-installed the + // partition value as a constant, so this branch leaves it alone. + if (fieldName == + IcebergMetadataColumn::kLastUpdatedSequenceNumberColumnName) { + // Column is absent from the data file. If a data sequence number is + // available (V3 snapshot), set a constant so every row inherits it. + // Otherwise (pre-V3 snapshot where the sequence number chain is + // unassigned), return null — analogous to the _row_id null path. + if (dataSequenceNumber_.has_value()) { + childSpec->setConstantValue( + std::make_shared>( + connectorQueryCtx_->memoryPool(), + 1, + false, + BIGINT(), + static_cast(*dataSequenceNumber_))); + } else { + childSpec->setConstantValue( + BaseVector::createNullConstant( + BIGINT(), 1, connectorQueryCtx_->memoryPool())); + } + } else if (fieldName == IcebergMetadataColumn::kRowIdColumnName) { + childSpec->setConstantValue( + BaseVector::createNullConstant( + BIGINT(), 1, connectorQueryCtx_->memoryPool())); + } else if (childSpec->isConstant()) { + // Constant already set (case 6, or set on a previous prepareSplit + // call for the same scanSpec). Nothing to do. + continue; + } else if (auto partitionIt = fileSplit_->partitionKeys.find(fieldName); + partitionIt != fileSplit_->partitionKeys.end()) { + setPartitionValue(childSpec.get(), fieldName, partitionIt->second); + } else { + // Check if column has an initial-default value (Iceberg V3) + bool hasDefaultValue = false; + // The columnHandles_ map is keyed by output name (which may be an + // alias). We need to find the column handle where the handle's name() + // matches fieldName. fieldName is the table column name from + // readerOutputType_. + for (const auto& [outputName, handle] : *columnHandles_) { + if (handle->name() == fieldName) { + auto icebergColumnHandle = + std::dynamic_pointer_cast(handle); + if (icebergColumnHandle && + icebergColumnHandle->initialDefaultValue().has_value()) { + // Use initial-default value for schema evolution. + auto columnType = tableSchema->findChild(fieldName); + VELOX_CHECK_NOT_NULL( + columnType, + "Column '{}' not found in table schema", + fieldName); + auto constant = newConstantFromString( + columnType, + icebergColumnHandle->initialDefaultValue().value(), + connectorQueryCtx_->memoryPool(), + readTimestampAsLocalTime, + false); + childSpec->setConstantValue(constant); + hasDefaultValue = true; + break; + } + } + } + + // Fall back to NULL if no default value + if (!hasDefaultValue) { + auto columnType = tableSchema->findChild(fieldName); + VELOX_CHECK_NOT_NULL( + columnType, "Column '{}' not found in table schema", fieldName); + childSpec->setConstantValue( + BaseVector::createNullConstant( + columnType, 1, connectorQueryCtx_->memoryPool())); + } + } + } + } + } + + scanSpec_->resetCachedValues(false); + + return columnTypes; +} + } // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergSplitReader.h b/velox/connectors/hive/iceberg/IcebergSplitReader.h index 4b3c6b90104..7bba8464b57 100644 --- a/velox/connectors/hive/iceberg/IcebergSplitReader.h +++ b/velox/connectors/hive/iceberg/IcebergSplitReader.h @@ -16,45 +16,169 @@ #pragma once +#include + #include "velox/connectors/Connector.h" -#include "velox/connectors/hive/SplitReader.h" +#include "velox/connectors/hive/FileSplitReader.h" +#include "velox/connectors/hive/iceberg/DeletionVectorReader.h" +#include "velox/connectors/hive/iceberg/EqualityDeleteFileReader.h" #include "velox/connectors/hive/iceberg/PositionalDeleteFileReader.h" namespace facebook::velox::connector::hive::iceberg { +struct HiveIcebergSplit; struct IcebergDeleteFile; -class IcebergSplitReader : public SplitReader { +class IcebergSplitReader : public FileSplitReader { public: IcebergSplitReader( - const std::shared_ptr& hiveSplit, - const HiveTableHandlePtr& hiveTableHandle, - const std::unordered_map* partitionKeys, + const std::shared_ptr& icebergSplit, + const FileTableHandlePtr& tableHandle, + const std::unordered_map* partitionKeys, const ConnectorQueryCtx* connectorQueryCtx, - const std::shared_ptr& hiveConfig, + const std::shared_ptr& fileConfig, const RowTypePtr& readerOutputType, - const std::shared_ptr& ioStats, - const std::shared_ptr& fsStats, + const std::shared_ptr& dataIoStats, + const std::shared_ptr& metadataIoStats, + const std::shared_ptr& ioStats, FileHandleFactory* fileHandleFactory, folly::Executor* executor, - const std::shared_ptr& scanSpec); + const std::shared_ptr& scanSpec, + std::shared_ptr columnHandles); ~IcebergSplitReader() override = default; void prepareSplit( std::shared_ptr metadataFilter, - dwio::common::RuntimeStatistics& runtimeStats) override; + dwio::common::RuntimeStatistics& runtimeStats, + const folly::F14FastMap& fileReadOps = {}) + override; uint64_t next(uint64_t size, VectorPtr& output) override; private: - // The read offset to the beginning of the split in number of rows for the - // current batch for the base data file + /// Adapts the data file schema to match the table schema expected by the + /// query. + /// + /// This method reconciles differences between the physical data file schema + /// and the logical table schema, handling various scenarios where columns may + /// be missing, added, or need special treatment. + /// + /// @param fileType The schema read from the data file's metadata. This + /// represents the actual columns physically present in the Parquet/ORC file. + /// @param tableSchema The logical schema defined in the catalog (e.g., from + /// DDL). This represents the current table schema that queries expect. + /// + /// @return A vector of column types adapted to match the query's + /// expectations, with appropriate type conversions and constant values set + /// for missing or special columns. + /// + /// The method handles the following scenarios for each column in the scan + /// spec: + /// + /// 1. Info columns (e.g., $path, $data_sequence_number, $deleted) + /// These are virtual columns that provide metadata about the file itself. + /// Values are read from the split's infoColumns map and set as constant + /// values in the scanSpec so they're materialized for all rows. + /// + /// 2. Regular columns present in File: + /// Column exists in both fileType and readerOutputType. Type is adapted + /// from fileType to match the expected output type, handling schema + /// evolution where column types may have changed. + /// + /// 3. Columns missing from File: + /// a) Partition columns (hive-migrated tables): + /// Column is marked as partition key in splitPartitionKeys_. + /// In Hive-written Iceberg tables, partition column values are stored + /// in partition metadata, not in the data file itself. Value is read + /// from partition metadata and set as a constant. + /// b) Schema evolution (newly added columns): + /// Column was added to the table schema after this data file was + /// written. Set as NULL constant since the old file doesn't contain + /// this column. + /// c) Row lineage (_last_updated_sequence_number): + /// For Iceberg V3 row lineage, if the column is not in the file, + /// inherit the data sequence number from the file's manifest entry + /// (provided via $data_sequence_number info column). Per the spec, + /// null values indicate the value should be inherited. + /// d) Row lineage (_row_id): + /// Per the spec, null _row_id values are assigned as + /// first_row_id + file position. When first_row_id is available from + /// the split info column $first_row_id, the value is computed + /// in next(). When first_row_id is not available (e.g., + /// pre-V3 tables), NULL is returned. + std::vector adaptColumns( + const RowTypePtr& fileType, + const RowTypePtr& tableSchema) const override; + + // Resolves the equality field IDs of an equality-delete file to the + // corresponding column names and types in the table schema. In Iceberg, + // field IDs for top-level columns are assigned sequentially starting from + // 1, matching the column order in the table schema. + std::pair, std::vector> + resolveEqualityColumns(const IcebergDeleteFile& deleteFile) const; + + // Discovers equality-delete columns that are not in the user's projection + // and augments 'scanSpec_' and 'readerOutputType_' so they are physically + // read and made available in the output RowVector. For partition columns + // the partition value is set as a constant on the scan-spec child so the + // augmentation works regardless of whether the data file physically + // contains the partition column. Augmented columns are appended at the end + // of 'readerOutputType_' so the upstream FileDataSource's positional + // projection naturally drops them from the operator output. + void configureEqualityDeleteColumns(); + + // Names of scan-spec children that 'configureEqualityDeleteColumns' + // pre-installed a partition-value constant on for the current split. + // Mirrors the Java 'PARTITION_KEY' column-type distinction in + // 'IcebergUtil.getColumns(fields, schema, partitionSpec, typeManager)': + // for an equality-delete column that is also an identity-partition + // column, the value MUST come from the split's partition metadata, never + // from the data file body — even if the file body physically contains + // the column. 'adaptColumns' uses this set to skip clearing the partition + // constant on Branch 1, which would otherwise leave the column stuck as + // a constant null and break equality-delete matching after schema + // evolution adds new columns at indices that shift the augmented column + // out of its file-natural position. + folly::F14FastSet equalityAugmentedPartitionColumns_; + + const std::shared_ptr icebergSplit_; + + /// Read offset to the beginning of the split in number of rows for the + /// current batch for the base data file. uint64_t baseReadOffset_; - // The file position for the first row in the split + /// File position for the first row in the split. uint64_t splitOffset_; + // Active readers for positional delete files associated with this split. std::list> positionalDeleteFileReaders_; + // Bitmap of deleted rows in the current batch; set bits mark deleted rows. BufferPtr deleteBitmap_; + // Output column index of _last_updated_sequence_number, if projected. + std::optional lastUpdatedSeqNumOutputIndex_; + // Data sequence number from the split's manifest entry, used to populate + // _last_updated_sequence_number for rows whose stored value is null. + std::optional dataSequenceNumber_; + // First row ID from the split's $first_row_id info column, used to compute + // _row_id as first_row_id + file position for rows whose stored value is + // null. + std::optional firstRowId_; + // Output column index of _row_id, if projected. + std::optional rowIdOutputIndex_; + // Whether an implicit row-number column is needed for _row_id computation + // (set when filters, random-skip, or positional deletes make output + // positions non-contiguous). + bool useRowNumberColumn_{false}; + + /// Readers for Iceberg V3 deletion vectors (Puffin-encoded roaring bitmaps). + std::list> deletionVectorReaders_; + + /// Readers for equality delete files. + std::list> + equalityDeleteFileReaders_; + + /// Column handles map shared with IcebergDataSource. + /// Used for accessing column metadata including initial-default values. + std::shared_ptr columnHandles_; }; } // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/PartitionSpec.cpp b/velox/connectors/hive/iceberg/PartitionSpec.cpp new file mode 100644 index 00000000000..38b4f0b67c3 --- /dev/null +++ b/velox/connectors/hive/iceberg/PartitionSpec.cpp @@ -0,0 +1,159 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/PartitionSpec.h" + +#include "velox/common/EnumDefine.h" + +#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" + +namespace facebook::velox::connector::hive::iceberg { + +namespace { + +TransformCategory getTransformCategory(TransformType transformType) { + switch (transformType) { + case TransformType::kIdentity: + return TransformCategory::kIdentity; + case TransformType::kYear: + case TransformType::kMonth: + case TransformType::kDay: + case TransformType::kHour: + return TransformCategory::kTemporal; + case TransformType::kBucket: + return TransformCategory::kBucket; + case TransformType::kTruncate: + return TransformCategory::kTruncate; + default: + VELOX_UNREACHABLE("Unknown transform type"); + } +} + +bool isValidPartitionType(const TypePtr& type) { + return !( + type->isRow() || type->isArray() || type->isMap() || type->isDouble() || + type->isReal() || isTimestampWithTimeZoneType(type)); +} + +bool canTransform(TransformType transformType, const TypePtr& type) { + switch (transformType) { + case TransformType::kIdentity: + return type->isTinyint() || type->isSmallint() || type->isInteger() || + type->isBigint() || type->isBoolean() || type->isDecimal() || + type->isDate() || type->isTimestamp() || type->isVarchar() || + type->isVarbinary(); + case TransformType::kYear: + case TransformType::kMonth: + case TransformType::kDay: + return type->isDate() || type->isTimestamp(); + case TransformType::kHour: + return type->isTimestamp(); + case TransformType::kBucket: + return type->isInteger() || type->isBigint() || type->isDecimal() || + type->isVarchar() || type->isVarbinary() || type->isDate() || + type->isTimestamp(); + case TransformType::kTruncate: + return type->isInteger() || type->isBigint() || type->isDecimal() || + type->isVarchar() || type->isVarbinary(); + default: + VELOX_UNREACHABLE("Unsupported partition transform type."); + } +} + +const auto& transformTypeNames() { + static const folly::F14FastMap + kTransformNames = { + {TransformType::kIdentity, "identity"}, + {TransformType::kHour, "hour"}, + {TransformType::kDay, "day"}, + {TransformType::kMonth, "month"}, + {TransformType::kYear, "year"}, + {TransformType::kBucket, "bucket"}, + {TransformType::kTruncate, "trunc"}, + }; + return kTransformNames; +} + +const auto& transformCategoryNames() { + static const folly::F14FastMap + kTransformCategoryNames = { + {TransformCategory::kIdentity, "Identity"}, + {TransformCategory::kBucket, "Bucket"}, + {TransformCategory::kTruncate, "Truncate"}, + {TransformCategory::kTemporal, "Temporal"}, + }; + return kTransformCategoryNames; +} + +} // namespace + +VELOX_DEFINE_ENUM_NAME(TransformType, transformTypeNames); + +VELOX_DEFINE_ENUM_NAME(TransformCategory, transformCategoryNames); + +void IcebergPartitionSpec::checkCompatibility() const { + folly::F14FastMap> + columnTransforms; + + for (const auto& field : fields) { + const auto& type = field.type; + const auto& name = field.name; + VELOX_USER_CHECK( + isValidPartitionType(type), + "Type is not supported as a partition column: {}", + type->name()); + + VELOX_USER_CHECK( + canTransform(field.transformType, type), + "Transform is not supported for partition column. Column: '{}', Type: '{}', Transform: '{}'.", + name, + type->name(), + TransformTypeName::toName(field.transformType)); + + columnTransforms[name].emplace_back(field.transformType); + } + + // Check for duplicate transform categories per column. + std::vector errors; + for (const auto& [columnName, transforms] : columnTransforms) { + folly::F14FastSet seenCategories; + for (const auto& transform : transforms) { + auto category = getTransformCategory(transform); + if (!seenCategories.insert(category).second) { + std::vector transformNames; + for (const auto& t : transforms) { + transformNames.emplace_back( + std::string(TransformTypeName::toName(t))); + } + errors.emplace_back( + fmt::format( + "Column: '{}', Category: {}, Transforms: [{}]", + columnName, + TransformCategoryName::toName(category), + folly::join(", ", transformNames))); + break; + } + } + } + + VELOX_USER_CHECK( + errors.empty(), + "Multiple transforms of the same category on a column are not allowed. " + "Each transform category can appear at most once per column. {}", + folly::join("; ", errors)); +} + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/PartitionSpec.h b/velox/connectors/hive/iceberg/PartitionSpec.h new file mode 100644 index 00000000000..99297da0edc --- /dev/null +++ b/velox/connectors/hive/iceberg/PartitionSpec.h @@ -0,0 +1,146 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/type/Type.h" + +namespace facebook::velox::connector::hive::iceberg { + +/// Partition transform types. +/// Defines how source column values are converted into partition keys. +/// See https://iceberg.apache.org/spec/#partition-transforms. +enum class TransformType { + /// Use the source value as-is (no transformation). + kIdentity, + /// Extract a timestamp hour, as hours from 1970-01-01 00:00:00. + kHour, + /// Extract a date or timestamp day, as days from 1970-01-01. + kDay, + /// Extract a date or timestamp month, as months from 1970-01. + kMonth, + /// Extract a date or timestamp year, as years from 1970. + kYear, + /// Hash the value into N buckets for even distribution. Requires an integer + /// parameter specifying the bucket count. + kBucket, + /// Truncate strings or numbers to a specified width. Requires an integer + /// parameter specifying the truncate width. + kTruncate +}; + +VELOX_DECLARE_ENUM_NAME(TransformType); + +/// A single column can be used to produce multiple partition keys, but with +/// following restrictions: +/// - Transforms are organized into 4 categories: Identity, Temporal, +/// Bucket, and Truncate. +/// - Each category can appear at most once per column. +/// - Sample valid specs on same column: ['truncate(a,2)', 'bucket(a,16)', 'a'] +/// or ['year(b)', 'bucket(b, 16)', 'b'] +enum class TransformCategory { + kIdentity, + /// Year/Month/Day/Hour + kTemporal, + kBucket, + kTruncate, +}; + +VELOX_DECLARE_ENUM_NAME(TransformCategory); + +/// Represents how to produce partition data for an Iceberg table. +/// +/// This structure corresponds to the Iceberg Java PartitionSpec class but +/// contains only the necessary fields for Velox. Partition keys are computed +/// by transforming columns in a table. +/// +/// The upstream engine processes this specification through the Iceberg Java +/// library to validate column types, detect duplicates, and generate the +/// partition spec that is passed to Velox. +/// +/// IMPORTANT: Iceberg spec uses field IDs to identify source columns, but +/// Velox RowType only supports matching fields by name. Therefore, Velox uses +/// the partition field name to match against the table schema column names. +/// Callers must ensure that partition field names exactly match the column +/// names in the table schema. +/// +/// The partition spec contains: +/// - Unique ID for versioning and evolution. +/// - Which source columns in current table schema to use for partitioning +/// (identified by field name, not field ID as in the Iceberg spec). +/// - What transforms to apply (identity, bucket, truncate etc.). +/// - Transform parameters (e.g., bucket count, truncate width). +struct IcebergPartitionSpec { + struct Field { + /// Column name as defined in table schema. This column's value is used to + /// compute partition key by applying 'transformType' transformation. + const std::string name; + + /// Column type. + const TypePtr type; + + /// Transform to apply. Callers must ensure the transform is compatible with + /// the column type. + const TransformType transformType; + + /// Optional parameter for transforms that require configuration. + const std::optional parameter; + + /// Returns the result type after applying this transform. + TypePtr resultType() const { + switch (transformType) { + case TransformType::kBucket: + case TransformType::kYear: + case TransformType::kMonth: + case TransformType::kHour: + return INTEGER(); + case TransformType::kDay: + return DATE(); + case TransformType::kIdentity: + case TransformType::kTruncate: + return type; + } + VELOX_UNREACHABLE("Unknown transform type"); + } + }; + + const int32_t specId; + const std::vector fields; + + /// Constructor with validation that: + /// - Each field's type is supported for partitioning. + /// - Each field's transform type is compatible with its data type. + /// - No transform category appears more than once per column (Identity, + /// Temporal, Bucket, and Truncate are separate categories). + /// + /// @param _specId Partition specification ID. + /// @param _fields Vector of partition fields. When empty indicates no + /// partition. + /// @throws VeloxUserError if validation fails. + IcebergPartitionSpec(int32_t _specId, std::vector _fields) + : specId(_specId), fields(std::move(_fields)) { + checkCompatibility(); + } + + private: + // Validates partition fields for correctness. + // Checks type/transform compatibility and transform combination rules. + void checkCompatibility() const; +}; + +using IcebergPartitionSpecPtr = std::shared_ptr; + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/PositionalDeleteFileReader.cpp b/velox/connectors/hive/iceberg/PositionalDeleteFileReader.cpp index d36550ac66a..4dd3100eb4a 100644 --- a/velox/connectors/hive/iceberg/PositionalDeleteFileReader.cpp +++ b/velox/connectors/hive/iceberg/PositionalDeleteFileReader.cpp @@ -16,8 +16,8 @@ #include "velox/connectors/hive/iceberg/PositionalDeleteFileReader.h" -#include "velox/connectors/hive/HiveConnectorUtil.h" -#include "velox/connectors/hive/TableHandle.h" +#include "velox/connectors/hive/BufferedInputBuilder.h" +#include "velox/connectors/hive/FileConnectorUtil.h" #include "velox/connectors/hive/iceberg/IcebergDeleteFile.h" #include "velox/connectors/hive/iceberg/IcebergMetadataColumns.h" #include "velox/dwio/common/ReaderFactory.h" @@ -30,9 +30,9 @@ PositionalDeleteFileReader::PositionalDeleteFileReader( FileHandleFactory* fileHandleFactory, const ConnectorQueryCtx* connectorQueryCtx, folly::Executor* executor, - const std::shared_ptr& hiveConfig, - const std::shared_ptr& ioStats, - const std::shared_ptr& fsStats, + const std::shared_ptr& fileConfig, + const std::shared_ptr& ioStatistics, + const std::shared_ptr& ioStats, dwio::common::RuntimeStatistics& runtimeStats, uint64_t splitOffset, const std::string& connectorId) @@ -41,9 +41,9 @@ PositionalDeleteFileReader::PositionalDeleteFileReader( fileHandleFactory_(fileHandleFactory), executor_(executor), connectorQueryCtx_(connectorQueryCtx), - hiveConfig_(hiveConfig), + fileConfig_(fileConfig), + ioStatistics_(ioStatistics), ioStats_(ioStats), - fsStats_(fsStats), pool_(connectorQueryCtx->memoryPool()), filePathColumn_(IcebergMetadataColumn::icebergDeleteFilePathColumn()), posColumn_(IcebergMetadataColumn::icebergDeletePosColumn()), @@ -56,15 +56,13 @@ PositionalDeleteFileReader::PositionalDeleteFileReader( VELOX_CHECK(deleteFile_.content == FileContent::kPositionalDeletes); VELOX_CHECK(deleteFile_.recordCount); - // TODO: check if the lowerbounds and upperbounds in deleteFile overlap with - // this batch. If not, no need to proceed. - // Create the ScanSpec for this delete file auto scanSpec = std::make_shared(""); scanSpec->addField(posColumn_->name, 0); auto* pathSpec = scanSpec->getOrCreateChild(filePathColumn_->name); - pathSpec->setFilter(std::make_unique( - std::vector({baseFilePath_}), false)); + pathSpec->setFilter( + std::make_unique( + std::vector({baseFilePath_}), false)); // Create the file schema (in RowType) and split that will be used by readers std::vector deleteColumnNames( @@ -84,8 +82,11 @@ PositionalDeleteFileReader::PositionalDeleteFileReader( // Create the Reader and RowReader dwio::common::ReaderOptions deleteReaderOpts(pool_); + // TODO: Use separate IoStatistics for data and metadata. + deleteReaderOpts.setDataIoStats(ioStatistics_); + deleteReaderOpts.setMetadataIoStats(ioStatistics_); configureReaderOptions( - hiveConfig_, + fileConfig_, connectorQueryCtx, deleteFileSchema, deleteSplit_, @@ -96,12 +97,12 @@ PositionalDeleteFileReader::PositionalDeleteFileReader( .filename = deleteFile_.filePath, .tokenProvider = connectorQueryCtx_->fsTokenProvider()}; auto deleteFileHandleCachePtr = fileHandleFactory_->generate(fileHandleKey); - auto deleteFileInput = createBufferedInput( + auto deleteFileInput = BufferedInputBuilder::getInstance()->create( *deleteFileHandleCachePtr, deleteReaderOpts, connectorQueryCtx, + ioStatistics_, ioStats_, - fsStats_, executor_); auto deleteReader = @@ -118,7 +119,7 @@ PositionalDeleteFileReader::PositionalDeleteFileReader( deleteSplit_->filePath, deleteSplit_->partitionKeys, {}, - hiveConfig_->readTimestampPartitionValueAsLocalTime( + fileConfig_->readTimestampPartitionValueAsLocalTime( connectorQueryCtx_->sessionProperties()))) { // We only count the number of base splits skipped as skippedSplits runtime // statistics in Velox. Skipped delta split is only counted as skipped @@ -137,6 +138,7 @@ PositionalDeleteFileReader::PositionalDeleteFileReader( deleteSplit_, nullptr, nullptr, + nullptr, deleteRowReaderOpts); deleteRowReader_.reset(); @@ -251,15 +253,17 @@ void PositionalDeleteFileReader::updateDeleteBitmap( deletePositionsOffset_++; } - deleteBitmapBuffer->setSize(std::max( - static_cast(deleteBitmapBuffer->size()), - deletePositionsOffset_ == 0 || - (deletePositionsOffset_ < deletePositionsVector->size() && - deletePositions[deletePositionsOffset_] >= rowNumberUpperBound) - ? 0 - : bits::nbytes( - deletePositions[deletePositionsOffset_ - 1] + 1 - - rowNumberLowerBound))); + deleteBitmapBuffer->setSize( + std::max( + static_cast(deleteBitmapBuffer->size()), + deletePositionsOffset_ == 0 || + (deletePositionsOffset_ < deletePositionsVector->size() && + deletePositions[deletePositionsOffset_] >= + rowNumberUpperBound) + ? 0 + : bits::nbytes( + deletePositions[deletePositionsOffset_ - 1] + 1 - + rowNumberLowerBound))); } bool PositionalDeleteFileReader::readFinishedForBatch( diff --git a/velox/connectors/hive/iceberg/PositionalDeleteFileReader.h b/velox/connectors/hive/iceberg/PositionalDeleteFileReader.h index 211359d7fb9..7367de7818d 100644 --- a/velox/connectors/hive/iceberg/PositionalDeleteFileReader.h +++ b/velox/connectors/hive/iceberg/PositionalDeleteFileReader.h @@ -20,74 +20,139 @@ #include #include "velox/connectors/Connector.h" +#include "velox/connectors/hive/FileConfig.h" #include "velox/connectors/hive/FileHandle.h" -#include "velox/connectors/hive/HiveConfig.h" #include "velox/connectors/hive/HiveConnectorSplit.h" +#include "velox/connectors/hive/iceberg/IcebergDeleteFile.h" +#include "velox/connectors/hive/iceberg/IcebergMetadataColumns.h" #include "velox/dwio/common/Reader.h" namespace facebook::velox::connector::hive::iceberg { -struct IcebergDeleteFile; -struct IcebergMetadataColumn; - +/// Reads a positional delete file and produces a deletion bitmap for the base +/// data during merge-on-read. The delete file schema per the Iceberg V2 spec +/// is: +/// (file_path: VARCHAR, pos: BIGINT) +/// +/// For each row in the delete file matching the base data file path, this +/// reader records the deleted row position. During read, these positions are +/// converted into a bitmap where set bits indicate rows that should be excluded +/// from the output. +/// +/// The reader processes delete positions incrementally, batch by batch. +/// Positions from the delete file that fall within the current batch range are +/// converted to bitmap bits. Leftover positions beyond the current batch are +/// preserved and carried over to the next readDeletePositions() call. class PositionalDeleteFileReader { public: + /// Constructs a reader for a single positional delete file. + /// + /// Opens the delete file, sets up a filter on file_path = baseFilePath to + /// read only delete positions relevant to the current base data file, and + /// applies testFilters() to skip the entire delete file when possible (e.g., + /// when the delete file contains no entries for this base file). + /// + /// @param deleteFile Metadata about the delete file (path, format, bounds, + /// record count). + /// @param baseFilePath Path of the base data file being read. Used to filter + /// delete records that apply to this specific data file. + /// @param fileHandleFactory Factory for creating file handles. + /// @param connectorQueryCtx Query context providing memory pool, session + /// properties, and filesystem token. + /// @param executor Executor for async I/O operations. + /// @param fileConfig Hive connector configuration. + /// @param ioStatistics Shared I/O statistics counters. + /// @param ioStats Shared I/O stats tracker. + /// @param runtimeStats Runtime statistics to record skipped bytes when the + /// delete file is filtered out entirely. + /// @param splitOffset Row number offset of the current split within the base + /// data file. Delete positions are absolute within the file, so this offset + /// is used to translate them to split-relative positions. + /// @param connectorId Connector ID for constructing the internal split. PositionalDeleteFileReader( const IcebergDeleteFile& deleteFile, const std::string& baseFilePath, FileHandleFactory* fileHandleFactory, const ConnectorQueryCtx* connectorQueryCtx, folly::Executor* executor, - const std::shared_ptr& hiveConfig, - const std::shared_ptr& ioStats, - const std::shared_ptr& fsStats, + const std::shared_ptr& fileConfig, + const std::shared_ptr& ioStatistics, + const std::shared_ptr& ioStats, dwio::common::RuntimeStatistics& runtimeStats, uint64_t splitOffset, const std::string& connectorId); + /// Reads delete positions for the current batch and sets corresponding bits + /// in the deletion bitmap. + /// + /// Processes delete positions in the range + /// [splitOffset + baseReadOffset, splitOffset + baseReadOffset + size). + /// Positions from the delete file that fall within this range have their + /// corresponding bits set in deleteBitmap. Positions beyond the range are + /// buffered for subsequent calls. + /// + /// @param baseReadOffset The read offset from the beginning of the split in + /// number of rows for the current batch. + /// @param size The number of rows in the current batch, before deletion. + /// @param deleteBitmap Output bitmap buffer where set bits mark rows to + /// delete. Bit positions are relative to baseReadOffset within the split. void readDeletePositions( uint64_t baseReadOffset, uint64_t size, BufferPtr deleteBitmap); + /// Returns true when all delete positions have been read and consumed from + /// this file. bool noMoreData(); private: + // Converts delete positions from deletePositionsVector into set bits in the + // deleteBitmapBuffer for positions within + // [splitOffset + baseReadOffset, rowNumberUpperBound). void updateDeleteBitmap( VectorPtr deletePositionsVector, uint64_t baseReadOffset, int64_t rowNumberUpperBound, BufferPtr deleteBitmapBuffer); + // Returns true if enough delete positions have been read for the current + // batch, either because EOF or the next unprocessed position >= + // rowNumberUpperBound. bool readFinishedForBatch(int64_t rowNumberUpperBound); - const IcebergDeleteFile& deleteFile_; - const std::string& baseFilePath_; + const IcebergDeleteFile deleteFile_; + const std::string baseFilePath_; FileHandleFactory* const fileHandleFactory_; folly::Executor* const executor_; - const ConnectorQueryCtx* connectorQueryCtx_; - const std::shared_ptr hiveConfig_; - const std::shared_ptr ioStats_; - const std::shared_ptr fsStats_; - const std::shared_ptr fsStats; + const ConnectorQueryCtx* const connectorQueryCtx_; + const std::shared_ptr fileConfig_; + const std::shared_ptr ioStatistics_; + const std::shared_ptr ioStats_; memory::MemoryPool* const pool_; - std::shared_ptr filePathColumn_; - std::shared_ptr posColumn_; - uint64_t splitOffset_; + // Iceberg metadata column descriptors for the delete file schema. + const std::shared_ptr filePathColumn_; + const std::shared_ptr posColumn_; + + // Row number offset of the current split within the base data file. + const uint64_t splitOffset_; + // Internal split and row reader for the delete file. Reset to nullptr when + // the delete file is fully consumed or skipped by testFilters(). std::shared_ptr deleteSplit_; std::unique_ptr deleteRowReader_; - // The vector to hold the delete positions read from the positional delete - // file. These positions are relative to the start of the whole base data - // file. + + // Holds the raw output from reading the delete file. Contains a RowVector + // with a single pos column. Positions are absolute row numbers within the + // base data file. VectorPtr deletePositionsOutput_; - // The index of deletePositionsOutput_ that indicates up to where the delete - // positions have been converted into the bitmap + + // Index into deletePositionsOutput_ indicating how far delete positions + // have been consumed and converted into bitmap bits. uint64_t deletePositionsOffset_; - // Total number of rows read from this positional delete file reader, - // including the rows filtered out from filters on both filePathColumn_ and - // posColumn_. + + // Total number of rows read from this delete file, including rows filtered + // out by the file_path filter. Used to detect end-of-file. uint64_t totalNumRowsScanned_; }; diff --git a/velox/connectors/hive/iceberg/TransformEvaluator.cpp b/velox/connectors/hive/iceberg/TransformEvaluator.cpp new file mode 100644 index 00000000000..2744bddafa7 --- /dev/null +++ b/velox/connectors/hive/iceberg/TransformEvaluator.cpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/TransformEvaluator.h" + +#include "velox/expression/Expr.h" + +namespace facebook::velox::connector::hive::iceberg { + +TransformEvaluator::TransformEvaluator( + const std::vector& expressions, + const ConnectorQueryCtx* connectorQueryCtx) + : connectorQueryCtx_(connectorQueryCtx) { + VELOX_CHECK_NOT_NULL(connectorQueryCtx_); + exprSet_ = connectorQueryCtx_->expressionEvaluator()->compile(expressions); + VELOX_CHECK_NOT_NULL(exprSet_); +} + +std::vector TransformEvaluator::evaluate( + const RowVectorPtr& input) const { + const auto numRows = input->size(); + const auto numExpressions = exprSet_->exprs().size(); + + std::vector results(numExpressions); + SelectivityVector rows(numRows); + + // Evaluate all expressions in one pass. + connectorQueryCtx_->expressionEvaluator()->evaluate( + exprSet_.get(), rows, *input, results); + + return results; +} + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/TransformEvaluator.h b/velox/connectors/hive/iceberg/TransformEvaluator.h new file mode 100644 index 00000000000..ee7b26f7db8 --- /dev/null +++ b/velox/connectors/hive/iceberg/TransformEvaluator.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/connectors/Connector.h" +#include "velox/core/QueryCtx.h" +#include "velox/expression/Expr.h" + +namespace facebook::velox::connector::hive::iceberg { + +/// Evaluates multiple expressions efficiently using batch evaluation. +/// Expressions are compiled once in the constructor and reused across multiple +/// input batches. +class TransformEvaluator { + public: + /// Creates an evaluator with the given expressions and connector query + /// context. Compiles the expressions once for reuse across multiple + /// evaluations. + /// + /// @param expressions Vector of typed expressions to evaluate. These are + /// typically built using TransformExprBuilder::toExpressions() for Iceberg + /// partition transforms, but can be any valid Velox expressions. The + /// expressions are compiled once during construction. + /// @param connectorQueryCtx Connector query context providing access to the + /// expression evaluator (for compilation and evaluation) and memory pool. + /// Must remain valid for the lifetime of this TransformEvaluator. + TransformEvaluator( + const std::vector& expressions, + const ConnectorQueryCtx* connectorQueryCtx); + + /// Evaluates all expressions on the input data in a single pass. + /// Uses the pre-compiled ExprSet from the constructor for efficiency. + /// + /// The input RowType must match the RowType used when building the + /// expressions (passed to TransformExprBuilder::toExpressions). The column + /// positions, names and types must align. Create new TransformEvaluator for + /// input that has different RowType with the one when building the + /// expressions. + /// + /// @param input Input row vector containing the source data. Must have the + /// same RowType (column positions, names and types) as used when building the + /// expressions in the constructor. + /// @return Vector of result columns, one for each expression, in the same + /// order as the expressions provided to the constructor. + std::vector evaluate(const RowVectorPtr& input) const; + + private: + const ConnectorQueryCtx* connectorQueryCtx_; + std::unique_ptr exprSet_; +}; + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/TransformExprBuilder.cpp b/velox/connectors/hive/iceberg/TransformExprBuilder.cpp new file mode 100644 index 00000000000..4befbfb50b0 --- /dev/null +++ b/velox/connectors/hive/iceberg/TransformExprBuilder.cpp @@ -0,0 +1,111 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/connectors/hive/iceberg/TransformExprBuilder.h" +#include "velox/core/Expressions.h" + +namespace facebook::velox::connector::hive::iceberg { + +namespace { + +/// Converts a single partition field to a typed expression. +/// +/// Builds an expression tree for one partition transform. Identity transforms +/// become FieldAccessTypedExpr, while other transforms (bucket, truncate, +/// year, month, day, hour) become CallTypedExpr with appropriate function +/// names and parameters. +/// +/// @param field Partition field containing transform type, source column +/// type, and optional parameter (e.g., bucket count, truncate width). +/// @param inputFieldName Name of the source column in the input RowVector. +/// @param icebergFuncPrefix Prefix of iceberg transform function names. +/// @return Typed expression representing the transform. +core::TypedExprPtr toExpression( + const IcebergPartitionSpec::Field& field, + const std::string& inputFieldName, + const std::string& icebergFuncPrefix) { + // For identity transform, just return a field access expression. + if (field.transformType == TransformType::kIdentity) { + return std::make_shared( + field.type, inputFieldName); + } + + // For other transforms, build a CallTypedExpr with the appropriate function. + std::string functionName; + switch (field.transformType) { + case TransformType::kBucket: + functionName = icebergFuncPrefix + "bucket"; + break; + case TransformType::kTruncate: + functionName = icebergFuncPrefix + "truncate"; + break; + case TransformType::kYear: + functionName = icebergFuncPrefix + "years"; + break; + case TransformType::kMonth: + functionName = icebergFuncPrefix + "months"; + break; + case TransformType::kDay: + functionName = icebergFuncPrefix + "days"; + break; + case TransformType::kHour: + functionName = icebergFuncPrefix + "hours"; + break; + case TransformType::kIdentity: + break; + } + + // Build the expression arguments. + std::vector exprArgs; + if (field.parameter.has_value()) { + exprArgs.emplace_back( + std::make_shared( + INTEGER(), Variant(field.parameter.value()))); + } + exprArgs.emplace_back( + std::make_shared(field.type, inputFieldName)); + + return std::make_shared( + field.resultType(), std::move(exprArgs), functionName); +} + +} // namespace + +std::vector TransformExprBuilder::toExpressions( + const IcebergPartitionSpecPtr& partitionSpec, + const std::vector& partitionChannels, + const RowTypePtr& inputType, + const std::string& icebergFuncPrefix) { + VELOX_CHECK_EQ( + partitionSpec->fields.size(), + partitionChannels.size(), + "Number of partition fields must match number of partition channels"); + + const auto numTransforms = partitionChannels.size(); + std::vector transformExprs; + transformExprs.reserve(numTransforms); + + for (auto i = 0; i < numTransforms; i++) { + const auto channel = partitionChannels[i]; + transformExprs.emplace_back(toExpression( + partitionSpec->fields.at(i), + inputType->nameOf(channel), + icebergFuncPrefix)); + } + + return transformExprs; +} + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/TransformExprBuilder.h b/velox/connectors/hive/iceberg/TransformExprBuilder.h new file mode 100644 index 00000000000..b583adcf97d --- /dev/null +++ b/velox/connectors/hive/iceberg/TransformExprBuilder.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/connectors/hive/iceberg/PartitionSpec.h" +#include "velox/expression/Expr.h" + +namespace facebook::velox::connector::hive::iceberg { + +/// Converts Iceberg partition specification to Velox expressions. +class TransformExprBuilder { + public: + /// Converts partition specification to a list of typed expressions. + /// + /// @param partitionSpec Iceberg partition specification containing transform + /// definitions for each partition field. + /// @param partitionChannels Column indices (0-based) in the input RowVector + /// that correspond to each partition field. Must have the same size as + /// partitionSpec->fields. Provides the positional mapping from partition spec + /// fields to input RowVector columns. + /// @param inputType The row type of the input data. This is necessary for + /// building expressions because the column names in partitionSpec reference + /// table schema names, which might not match the column names in inputType + /// (e.g., inputType may use generated names like c0, c1, c2). The + /// FieldAccessTypedExpr must be built using the actual column names from + /// inputType that will be present at runtime. The partitionChannels provide + /// the positional mapping to locate the correct columns. + /// @param icebergFuncPrefix Prefix for Iceberg transform function names. + /// @return Vector of typed expressions, one for each partition field. + static std::vector toExpressions( + const IcebergPartitionSpecPtr& partitionSpec, + const std::vector& partitionChannels, + const RowTypePtr& inputType, + const std::string& icebergFuncPrefix); +}; + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/WriterOptionsAdapter.cpp b/velox/connectors/hive/iceberg/WriterOptionsAdapter.cpp new file mode 100644 index 00000000000..8bd0fc08e75 --- /dev/null +++ b/velox/connectors/hive/iceberg/WriterOptionsAdapter.cpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/WriterOptionsAdapter.h" + +#include "velox/common/base/Exceptions.h" +#include "velox/dwio/dwrf/writer/Writer.h" +#include "velox/dwio/parquet/writer/WriterConfig.h" + +namespace facebook::velox::connector::hive::iceberg { + +namespace { + +// Manifest format string emitted in Iceberg commit messages for files that +// share the ORC on-disk family. Iceberg's manifest vocabulary has no DWRF or +// NIMBLE enum, so DWRF and NIMBLE files are reported as "ORC" per the +// cross-engine convention shared with the Java planner (see +// FileFormat.{DWRF,NIMBLE}.toIceberg() in presto-facebook-iceberg). +constexpr std::string_view kOrcManifestFormat{"ORC"}; + +class ParquetWriterOptionsAdapter : public WriterOptionsAdapter { + public: + std::string manifestFormatString() const override { + return "PARQUET"; + } + + void applyPreConfigs(dwio::common::WriterOptions& options) const override { + // Per Iceberg spec (https://iceberg.apache.org/spec/#parquet): + // - Timestamps must be stored with microsecond precision. + // - Timestamps must NOT be adjusted to UTC; written as-is without + // timezone conversion (empty string disables conversion). + // + // Settings are routed through serdeParameters to avoid pulling in + // parquet-specific headers. Keys must match + // kParquetSerdeTimestampUnit and kParquetSerdeTimestampTimezone in + // velox/dwio/parquet/writer/Writer.h. The value "6" represents + // microseconds (TimestampPrecision::kMicroseconds). + options.serdeParameters[parquet::WriterConfig::kParquetSerdeTimestampUnit] = + "6"; + options.serdeParameters + [parquet::WriterConfig::kParquetSerdeTimestampTimezone] = ""; + } +}; + +class DwrfWriterOptionsAdapter : public WriterOptionsAdapter { + public: + std::string manifestFormatString() const override { + return std::string{kOrcManifestFormat}; + } + + void applyPostConfigs(dwio::common::WriterOptions& options) const override { + // DWRF stores microsecond-precision timestamps natively, so no + // precision conversion is required; only timezone adjustment must be + // disabled per the Iceberg spec. Unlike Parquet, DWRF exposes + // timestamp configuration as direct fields on dwrf::WriterOptions + // rather than serdeParameters. + auto* dwrfOptions = dynamic_cast(&options); + if (dwrfOptions == nullptr) { + return; + } + dwrfOptions->adjustTimestampToTimezone = false; + dwrfOptions->sessionTimezone = nullptr; + } +}; + +class NimbleWriterOptionsAdapter : public WriterOptionsAdapter { + public: + // Reports NIMBLE files as ORC in the manifest so cross-engine readers + // (Presto coordinator, catalog) can interpret the commit message. The + // actual on-disk format is identified at read time via the file + // extension and on-disk magic bytes, not via this string. + std::string manifestFormatString() const override { + return std::string{kOrcManifestFormat}; + } +}; + +} // namespace + +std::unique_ptr createWriterOptionsAdapter( + dwio::common::FileFormat format) { + // ORC is intentionally excluded until a dedicated ORC end-to-end test + // exists. + // NOLINTNEXTLINE(clang-diagnostic-switch-enum) + switch (format) { + case dwio::common::FileFormat::PARQUET: + return std::make_unique(); + case dwio::common::FileFormat::DWRF: + return std::make_unique(); + case dwio::common::FileFormat::NIMBLE: + return std::make_unique(); + default: + return nullptr; + } +} + +bool isSupportedFileFormat(dwio::common::FileFormat format) { + return createWriterOptionsAdapter(format) != nullptr; +} + +std::string toManifestFormatString(dwio::common::FileFormat format) { + auto adapter = createWriterOptionsAdapter(format); + VELOX_CHECK_NOT_NULL( + adapter, + "Unsupported file format for Iceberg manifest: {}", + dwio::common::toString(format)); + return adapter->manifestFormatString(); +} + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/WriterOptionsAdapter.h b/velox/connectors/hive/iceberg/WriterOptionsAdapter.h new file mode 100644 index 00000000000..29dae1dd7ee --- /dev/null +++ b/velox/connectors/hive/iceberg/WriterOptionsAdapter.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include "velox/dwio/common/Options.h" + +namespace facebook::velox::connector::hive::iceberg { + +/// Format-specific adapter for an Iceberg-bound WriterOptions instance. +/// Encapsulates the per-file-format behavior (manifest format string, +/// pre/post-processConfigs hooks on the writer options) behind a virtual +/// interface. Dispatched at runtime via createWriterOptionsAdapter() based +/// on the table's storage format. +class WriterOptionsAdapter { + public: + virtual ~WriterOptionsAdapter() = default; + + /// Identifier reported in Iceberg manifest commit messages. Iceberg's + /// file-format vocabulary has no DWRF enum; per the Iceberg SDK + /// convention, DWRF files are reported as "ORC" so downstream consumers + /// (Presto coordinator, catalog) can interpret the commit message. + virtual std::string manifestFormatString() const = 0; + + /// Hook applied to WriterOptions BEFORE processConfigs() runs. Used for + /// settings that flow through serdeParameters. + virtual void applyPreConfigs(dwio::common::WriterOptions& /*options*/) const { + } + + /// Hook applied to WriterOptions AFTER processConfigs() runs. Used for + /// direct field assignments that must not be overwritten by + /// config-driven processing. + virtual void applyPostConfigs( + dwio::common::WriterOptions& /*options*/) const {} +}; + +/// Returns the adapter for the given file format, or nullptr for +/// unsupported formats. Single source of truth for which file formats the +/// Iceberg DataSink supports on the write path. +std::unique_ptr createWriterOptionsAdapter( + dwio::common::FileFormat format); + +/// True if the Iceberg DataSink can write the given file format. +bool isSupportedFileFormat(dwio::common::FileFormat format); + +/// Maps a Velox file format to the string used in Iceberg manifest commit +/// messages. Throws if the format is not supported. +std::string toManifestFormatString(dwio::common::FileFormat format); + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/tests/CMakeLists.txt b/velox/connectors/hive/iceberg/tests/CMakeLists.txt index 3e54d543175..8931bb6d0e9 100644 --- a/velox/connectors/hive/iceberg/tests/CMakeLists.txt +++ b/velox/connectors/hive/iceberg/tests/CMakeLists.txt @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. add_library(velox_dwio_iceberg_reader_benchmark_lib IcebergSplitReaderBenchmark.cpp) +velox_add_test_headers(velox_dwio_iceberg_reader_benchmark_lib IcebergSplitReaderBenchmark.h) target_link_libraries( velox_dwio_iceberg_reader_benchmark_lib velox_exec_test_lib @@ -30,6 +31,7 @@ if(VELOX_ENABLE_BENCHMARKS) velox_exec_test_lib velox_exec velox_hive_connector + velox_hive_iceberg_splitreader Folly::folly Folly::follybenchmark ${TEST_LINK_LIBS} @@ -56,9 +58,118 @@ if(NOT VELOX_DISABLE_GOOGLETEST) GTest::gtest GTest::gtest_main ) + + add_executable( + velox_hive_iceberg_insert_test + IcebergConnectorTest.cpp + IcebergInsertTest.cpp + IcebergParquetStatsTest.cpp + IcebergTestBase.cpp + Main.cpp + PartitionNameTest.cpp + PartitionSpecTest.cpp + PartitionValueFormatterTest.cpp + TransformE2ETest.cpp + TransformTest.cpp + ) + velox_add_test_headers(velox_hive_iceberg_insert_test IcebergTestBase.h) + + add_test(velox_hive_iceberg_insert_test velox_hive_iceberg_insert_test) + + target_link_libraries( + velox_hive_iceberg_insert_test + velox_exec_test_lib + velox_hive_connector + velox_hive_iceberg_splitreader + velox_vector_fuzzer + GTest::gtest + GTest::gtest_main + ) + + add_executable(velox_hive_iceberg_deletion_vector_test DeletionVectorReaderTest.cpp) + add_test(velox_hive_iceberg_deletion_vector_test velox_hive_iceberg_deletion_vector_test) + + target_link_libraries( + velox_hive_iceberg_deletion_vector_test + velox_hive_connector + velox_hive_iceberg_splitreader + velox_exec_test_lib + velox_dwio_common_test_utils + GTest::gtest + GTest::gtest_main + ) + if(VELOX_ENABLE_PARQUET) target_link_libraries(velox_hive_iceberg_test velox_dwio_parquet_reader) file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/examples DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) + + target_link_libraries( + velox_hive_iceberg_insert_test + velox_dwio_parquet_reader + velox_dwio_parquet_writer + ) endif() + + add_executable(velox_hive_iceberg_equality_delete_test EqualityDeleteFileReaderTest.cpp Main.cpp) + add_test(velox_hive_iceberg_equality_delete_test velox_hive_iceberg_equality_delete_test) + + target_link_libraries( + velox_hive_iceberg_equality_delete_test + velox_hive_connector + velox_hive_iceberg_splitreader + velox_exec_test_lib + velox_dwio_common_test_utils + GTest::gtest + ) + + add_executable(velox_hive_iceberg_deletion_vector_writer_test DeletionVectorWriterTest.cpp) + add_test( + velox_hive_iceberg_deletion_vector_writer_test + velox_hive_iceberg_deletion_vector_writer_test + ) + + target_link_libraries( + velox_hive_iceberg_deletion_vector_writer_test + velox_hive_connector + velox_hive_iceberg_splitreader + velox_exec_test_lib + velox_dwio_common_test_utils + GTest::gtest + GTest::gtest_main + ) + + add_executable( + velox_hive_iceberg_dwrf_insert_test + IcebergDwrfInsertTest.cpp + IcebergTestBase.cpp + Main.cpp + ) + add_test(velox_hive_iceberg_dwrf_insert_test velox_hive_iceberg_dwrf_insert_test) + + target_link_libraries( + velox_hive_iceberg_dwrf_insert_test + velox_hive_connector + velox_hive_iceberg_splitreader + velox_exec_test_lib + velox_dwio_common_test_utils + velox_dwio_dwrf_reader + velox_dwio_dwrf_writer + velox_dwio_parquet_reader + velox_dwio_parquet_writer + GTest::gtest + GTest::gtest_main + ) + + add_executable(velox_hive_writer_options_adapter_test WriterOptionsAdapterTest.cpp) + add_test(velox_hive_writer_options_adapter_test velox_hive_writer_options_adapter_test) + + target_link_libraries( + velox_hive_writer_options_adapter_test + velox_hive_connector + velox_hive_iceberg_splitreader + velox_dwio_common_exception + GTest::gtest + GTest::gtest_main + ) endif() diff --git a/velox/connectors/hive/iceberg/tests/DeletionVectorReaderTest.cpp b/velox/connectors/hive/iceberg/tests/DeletionVectorReaderTest.cpp new file mode 100644 index 00000000000..001443fffbf --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/DeletionVectorReaderTest.cpp @@ -0,0 +1,613 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/DeletionVectorReader.h" + +#include + +#include + +#include "velox/common/base/BitUtil.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/file/FileSystems.h" +#include "velox/common/memory/Memory.h" +#include "velox/common/testutil/TempFilePath.h" + +using namespace facebook::velox; +using namespace facebook::velox::connector::hive::iceberg; +using namespace facebook::velox::common::testutil; + +namespace { + +// Serializes a roaring bitmap in the portable format (no-run variant, +// cookie = 12346). Supports only array containers (cardinality <= 4096). +// This is the simplest format the DeletionVectorReader needs to parse. +std::string serializeRoaringBitmapNoRun(const std::vector& positions) { + if (positions.empty()) { + // Empty bitmap: cookie + 0 containers. + std::string data(8, '\0'); + uint32_t cookie = 12346; + uint32_t numContainers = 0; + std::memcpy(data.data(), &cookie, 4); + std::memcpy(data.data() + 4, &numContainers, 4); + return data; + } + + // Group positions by high 16 bits. + std::map> containers; + for (auto pos : positions) { + auto key = static_cast(pos >> 16); + auto low = static_cast(pos & 0xFFFF); + containers[key].push_back(low); + } + + for (auto& [key, vals] : containers) { + std::sort(vals.begin(), vals.end()); + } + + uint32_t numContainers = static_cast(containers.size()); + + std::string data; + // Cookie. + constexpr uint32_t cookie = 12346; + data.append(reinterpret_cast(&cookie), 4); + // Container count. + data.append(reinterpret_cast(&numContainers), 4); + + // Key-cardinality pairs. + for (auto& [key, vals] : containers) { + uint16_t cardMinus1 = static_cast(vals.size() - 1); + data.append(reinterpret_cast(&key), 2); + data.append(reinterpret_cast(&cardMinus1), 2); + } + + // Offset section (required for > 0 containers) + if (numContainers > 0) { + uint32_t offset = 4 + 4 + numContainers * 4 + numContainers * 4; + for (auto& [key, vals] : containers) { + data.append(reinterpret_cast(&offset), 4); + offset += static_cast(vals.size()) * 2; + } + } + + // Container data (array containers: sorted uint16 values). + for (auto& [key, vals] : containers) { + for (auto v : vals) { + data.append(reinterpret_cast(&v), 2); + } + } + + return data; +} + +// Serializes a roaring bitmap in the portable format with run containers +// (cookie = 12347). All containers are run-encoded. +std::string serializeRoaringBitmapWithRuns( + const std::vector< + std::pair>>>& + containerRuns) { + // containerRuns: vector of (highBitsKey, vector of (start, lengthMinus1)). + uint32_t numContainers = static_cast(containerRuns.size()); + + // Cookie: low 16 bits = 12347, high 16 bits = numContainers - 1. + uint32_t cookie = static_cast(12347) | ((numContainers - 1) << 16); + + std::string data; + data.append(reinterpret_cast(&cookie), 4); + + // Run bitmap: all containers are run containers. ceil(numContainers / 8) + // bytes. + uint32_t runBitmapBytes = (numContainers + 7) / 8; + std::vector runBitmap(runBitmapBytes, 0xFF); + data.append(reinterpret_cast(runBitmap.data()), runBitmapBytes); + + // Compute cardinality for each container. + std::vector cardinalities; + for (auto& [key, runs] : containerRuns) { + uint32_t card = 0; + for (auto& [start, lenMinus1] : runs) { + card += static_cast(lenMinus1) + 1; + } + cardinalities.push_back(card); + } + + // Key-cardinality pairs. + for (size_t i = 0; i < containerRuns.size(); ++i) { + uint16_t key = containerRuns[i].first; + uint16_t cardMinus1 = static_cast(cardinalities[i] - 1); + data.append(reinterpret_cast(&key), 2); + data.append(reinterpret_cast(&cardMinus1), 2); + } + + // Offset section (required for >= 4) + constexpr uint32_t kRunContainersNoOffsetThreshold = 4; + if (numContainers >= kRunContainersNoOffsetThreshold) { + // First container offset = cookie (4) + runBitmap (runBitmapBytes) + // + descriptive header (4 * numContainers) + offset header + // (4 * numContainers). + uint32_t offset = + 4 + runBitmapBytes + 4 * numContainers + 4 * numContainers; + for (auto& [key, runs] : containerRuns) { + data.append(reinterpret_cast(&offset), 4); + // Each run container occupies 2 + 4 * numRuns bytes. + offset += 2 + 4 * static_cast(runs.size()); + } + } + + // Container data: each run container has numRuns (uint16) followed by + // (start, lengthMinus1) pairs. + for (auto& [key, runs] : containerRuns) { + uint16_t numRuns = static_cast(runs.size()); + data.append(reinterpret_cast(&numRuns), 2); + for (auto& [start, lenMinus1] : runs) { + data.append(reinterpret_cast(&start), 2); + data.append(reinterpret_cast(&lenMinus1), 2); + } + } + + return data; +} + +// Expands a list of run-encoded containers into the full set of positions +// they represent. Each container is keyed by its high 16 bits and contains +// runs of (start, lengthMinus1) +std::vector expandRuns( + const std::vector< + std::pair>>>& + containerRuns) { + std::vector result; + for (const auto& [key, runs] : containerRuns) { + const uint64_t base = static_cast(key) * 65536; + for (const auto& [start, lengthMinus1] : runs) { + for (uint64_t i = 0; i <= lengthMinus1; ++i) { + result.push_back(base + start + i); + } + } + } + return result; +} + +// Writes binary data to a temp file and returns the path. +std::shared_ptr writeDvFile(const std::string& bitmapData) { + auto tempFile = TempFilePath::create(); + // Write directly via C++ streams since TempFilePath already creates the + // file and the local filesystem openFileForWrite may not overwrite. + std::ofstream out(tempFile->getPath(), std::ios::binary | std::ios::trunc); + VELOX_CHECK(out.good(), "Failed to open temp file for writing"); + out.write(bitmapData.data(), static_cast(bitmapData.size())); + out.close(); + return tempFile; +} + +// Creates an IcebergDeleteFile for a deletion vector. Uses the typed +/// 'contentOffset' / 'contentLength' fields rather than the legacy bounds-map +/// encoding. +IcebergDeleteFile makeDvDeleteFile( + const std::string& filePath, + uint64_t recordCount, + uint64_t fileSize, + uint64_t blobOffset = 0, + std::optional blobLength = std::nullopt) { + const int64_t contentLength = blobLength.has_value() + ? static_cast(blobLength.value()) + : static_cast(fileSize); + + return IcebergDeleteFile( + FileContent::kDeletionVector, + filePath, + dwio::common::FileFormat::DWRF, + recordCount, + fileSize, + /*equalityFieldIds=*/{}, + /*lowerBounds=*/{}, + /*upperBounds=*/{}, + /*dataSequenceNumber=*/0, + /*contentOffset=*/static_cast(blobOffset), + /*contentLength=*/contentLength); +} + +// Extracts which bits are set in a bitmap buffer. +std::vector getSetBits(const BufferPtr& bitmap, uint64_t size) { + auto* raw = bitmap->as(); + std::vector result; + for (uint64_t i = 0; i < size; ++i) { + if (bits::isBitSet(raw, i)) { + result.push_back(i); + } + } + return result; +} + +} // namespace + +class DeletionVectorReaderTest : public ::testing::Test { + protected: + static void SetUpTestSuite() { + memory::MemoryManager::initialize(memory::MemoryManager::Options{}); + } + + void SetUp() override { + filesystems::registerLocalFileSystem(); + pool_ = memory::memoryManager()->addLeafPool("DeletionVectorReaderTest"); + } + + BufferPtr allocateBitmap(uint64_t numBits) { + auto numBytes = bits::nbytes(numBits); + auto buffer = AlignedBuffer::allocate(numBytes, pool_.get(), 0); + return buffer; + } + + std::shared_ptr pool_; +}; + +TEST_F(DeletionVectorReaderTest, basicArrayContainer) { + // Positions: 0, 5, 10, 99. + std::vector positions = {0, 5, 10, 99}; + auto bitmapData = serializeRoaringBitmapNoRun(positions); + auto tempFile = writeDvFile(bitmapData); + auto fileSize = static_cast(bitmapData.size()); + + auto dvFile = + makeDvDeleteFile(tempFile->getPath(), positions.size(), fileSize); + + DeletionVectorReader reader(dvFile, 0, pool_.get()); + EXPECT_FALSE(reader.noMoreData()); + + auto bitmap = allocateBitmap(100); + reader.readDeletePositions(0, 100, bitmap); + + auto setBits = getSetBits(bitmap, 100); + EXPECT_EQ(setBits, (std::vector{0, 5, 10, 99})); + EXPECT_TRUE(reader.noMoreData()); +} + +TEST_F(DeletionVectorReaderTest, batchRangeFiltering) { + // Positions: 10, 20, 30, 40, 50. + std::vector positions = {10, 20, 30, 40, 50}; + auto bitmapData = serializeRoaringBitmapNoRun(positions); + auto tempFile = writeDvFile(bitmapData); + auto fileSize = static_cast(bitmapData.size()); + + auto dvFile = + makeDvDeleteFile(tempFile->getPath(), positions.size(), fileSize); + + DeletionVectorReader reader(dvFile, 0, pool_.get()); + + // First batch: rows 0-24 (should contain positions 10, 20). + auto bitmap1 = allocateBitmap(25); + reader.readDeletePositions(0, 25, bitmap1); + auto bits1 = getSetBits(bitmap1, 25); + EXPECT_EQ(bits1, (std::vector{10, 20})); + EXPECT_FALSE(reader.noMoreData()); + + // Second batch: rows 25-49 (should contain positions 30, 40). + auto bitmap2 = allocateBitmap(25); + reader.readDeletePositions(25, 25, bitmap2); + auto bits2 = getSetBits(bitmap2, 25); + EXPECT_EQ(bits2, (std::vector{5, 15})); + EXPECT_FALSE(reader.noMoreData()); + + // Third batch: rows 50-74 (should contain position 50). + auto bitmap3 = allocateBitmap(25); + reader.readDeletePositions(50, 25, bitmap3); + auto bits3 = getSetBits(bitmap3, 25); + EXPECT_EQ(bits3, (std::vector{0})); + EXPECT_TRUE(reader.noMoreData()); +} + +TEST_F(DeletionVectorReaderTest, splitOffset) { + // Positions: 100, 105, 110. + std::vector positions = {100, 105, 110}; + auto bitmapData = serializeRoaringBitmapNoRun(positions); + auto tempFile = writeDvFile(bitmapData); + auto fileSize = static_cast(bitmapData.size()); + + auto dvFile = + makeDvDeleteFile(tempFile->getPath(), positions.size(), fileSize); + + // Split starts at row 100. + DeletionVectorReader reader(dvFile, 100, pool_.get()); + + auto bitmap = allocateBitmap(20); + reader.readDeletePositions(0, 20, bitmap); + + // Positions 100, 105, 110 relative to splitOffset=100, baseReadOffset=0 + // become bit indices 0, 5, 10. + auto setBits = getSetBits(bitmap, 20); + EXPECT_EQ(setBits, (std::vector{0, 5, 10})); + EXPECT_TRUE(reader.noMoreData()); +} + +TEST_F(DeletionVectorReaderTest, splitOffsetWithBaseReadOffset) { + // Positions: 200, 210, 220. + std::vector positions = {200, 210, 220}; + auto bitmapData = serializeRoaringBitmapNoRun(positions); + auto tempFile = writeDvFile(bitmapData); + auto fileSize = static_cast(bitmapData.size()); + + auto dvFile = + makeDvDeleteFile(tempFile->getPath(), positions.size(), fileSize); + + // Split starts at row 100. + DeletionVectorReader reader(dvFile, 100, pool_.get()); + + // First batch: baseReadOffset=100, so file positions [200, 300). + // Positions 200, 210, 220 are all in range. + auto bitmap = allocateBitmap(100); + reader.readDeletePositions(100, 100, bitmap); + + auto setBits = getSetBits(bitmap, 100); + EXPECT_EQ(setBits, (std::vector{0, 10, 20})); + EXPECT_TRUE(reader.noMoreData()); +} + +TEST_F(DeletionVectorReaderTest, noDeletesInRange) { + // Positions: 1000, 2000. + std::vector positions = {1000, 2000}; + auto bitmapData = serializeRoaringBitmapNoRun(positions); + auto tempFile = writeDvFile(bitmapData); + auto fileSize = static_cast(bitmapData.size()); + + auto dvFile = + makeDvDeleteFile(tempFile->getPath(), positions.size(), fileSize); + + DeletionVectorReader reader(dvFile, 0, pool_.get()); + + // Batch covers rows 0-99, no deletions in this range. + auto bitmap = allocateBitmap(100); + reader.readDeletePositions(0, 100, bitmap); + + auto setBits = getSetBits(bitmap, 100); + EXPECT_TRUE(setBits.empty()); + EXPECT_FALSE(reader.noMoreData()); +} + +TEST_F(DeletionVectorReaderTest, runContainers) { + // Use run-encoded containers: positions 10-19 and 50-59. + std::vector>>> + containerRuns = { + {0, {{10, 9}, {50, 9}}}, + }; + auto bitmapData = serializeRoaringBitmapWithRuns(containerRuns); + auto tempFile = writeDvFile(bitmapData); + auto fileSize = static_cast(bitmapData.size()); + + auto expected = expandRuns(containerRuns); + auto dvFile = + makeDvDeleteFile(tempFile->getPath(), expected.size(), fileSize); + + DeletionVectorReader reader(dvFile, 0, pool_.get()); + + auto bitmap = allocateBitmap(100); + reader.readDeletePositions(0, 100, bitmap); + + auto setBits = getSetBits(bitmap, 100); + EXPECT_EQ(setBits, expected); + EXPECT_TRUE(reader.noMoreData()); +} + +TEST_F(DeletionVectorReaderTest, runContainersWithOffsetHeader) { + std::vector>>> + containerRuns = { + {0, {{10, 4}}}, // positions 10-14 + {1, {{0, 2}, {100, 1}}}, // positions 65536-65538, 65636-65637 + {2, {{500, 0}}}, // position 131072+500 = 131572 + {3, {{1000, 9}}}, // positions 196608+1000..196608+1009 + }; + auto bitmapData = serializeRoaringBitmapWithRuns(containerRuns); + auto tempFile = writeDvFile(bitmapData); + auto fileSize = static_cast(bitmapData.size()); + + auto expected = expandRuns(containerRuns); + auto dvFile = + makeDvDeleteFile(tempFile->getPath(), expected.size(), fileSize); + + DeletionVectorReader reader(dvFile, 0, pool_.get()); + + const uint64_t numRows = 200'000; + auto bitmap = allocateBitmap(numRows); + reader.readDeletePositions(0, numRows, bitmap); + + auto setBits = getSetBits(bitmap, numRows); + EXPECT_EQ(setBits, expected); + EXPECT_TRUE(reader.noMoreData()); +} + +TEST_F(DeletionVectorReaderTest, largePositionsMultipleContainers) { + // Positions spanning two containers: one in container 0 (key=0), one in + // container 1 (key=1, i.e. pos >= 65536). + std::vector positions = {5, 100, 65536, 65600}; + auto bitmapData = serializeRoaringBitmapNoRun(positions); + auto tempFile = writeDvFile(bitmapData); + auto fileSize = static_cast(bitmapData.size()); + + auto dvFile = + makeDvDeleteFile(tempFile->getPath(), positions.size(), fileSize); + + DeletionVectorReader reader(dvFile, 0, pool_.get()); + + // Read a batch covering all positions. + auto bitmap = allocateBitmap(66000); + reader.readDeletePositions(0, 66000, bitmap); + + auto setBits = getSetBits(bitmap, 66000); + EXPECT_EQ(setBits, (std::vector{5, 100, 65536, 65600})); + EXPECT_TRUE(reader.noMoreData()); +} + +TEST_F(DeletionVectorReaderTest, blobOffset) { + // Write a file with some padding before the actual bitmap data. + std::vector positions = {3, 7, 11}; + auto bitmapData = serializeRoaringBitmapNoRun(positions); + + // Prepend 64 bytes of padding. + std::string padding(64, 'X'); + std::string fileContent = padding + bitmapData; + + auto tempFile = writeDvFile(fileContent); + auto fileSize = static_cast(fileContent.size()); + + auto dvFile = makeDvDeleteFile( + tempFile->getPath(), positions.size(), fileSize, 64, bitmapData.size()); + + DeletionVectorReader reader(dvFile, 0, pool_.get()); + + auto bitmap = allocateBitmap(20); + reader.readDeletePositions(0, 20, bitmap); + + auto setBits = getSetBits(bitmap, 20); + EXPECT_EQ(setBits, (std::vector{3, 7, 11})); + EXPECT_TRUE(reader.noMoreData()); +} + +TEST_F(DeletionVectorReaderTest, constructorRejectsWrongContentType) { + auto tempFile = TempFilePath::create(); + { + std::ofstream out(tempFile->getPath(), std::ios::binary | std::ios::trunc); + out.write("dummy", 5); + } + + IcebergDeleteFile badFile( + FileContent::kPositionalDeletes, + tempFile->getPath(), + dwio::common::FileFormat::DWRF, + 1, + 5); + + VELOX_ASSERT_THROW( + DeletionVectorReader(badFile, 0, pool_.get()), + "Expected deletion vector file"); +} + +TEST_F(DeletionVectorReaderTest, constructorRejectsEmptyDv) { + auto tempFile = TempFilePath::create(); + { + std::ofstream out(tempFile->getPath(), std::ios::binary | std::ios::trunc); + out.write("dummy", 5); + } + + IcebergDeleteFile emptyDv( + FileContent::kDeletionVector, + tempFile->getPath(), + dwio::common::FileFormat::DWRF, + 0, + 5); + + VELOX_ASSERT_THROW( + DeletionVectorReader(emptyDv, 0, pool_.get()), "Empty deletion vector"); +} + +TEST_F(DeletionVectorReaderTest, noMoreDataAfterAllConsumed) { + std::vector positions = {0, 1, 2}; + auto bitmapData = serializeRoaringBitmapNoRun(positions); + auto tempFile = writeDvFile(bitmapData); + auto fileSize = static_cast(bitmapData.size()); + + auto dvFile = + makeDvDeleteFile(tempFile->getPath(), positions.size(), fileSize); + + DeletionVectorReader reader(dvFile, 0, pool_.get()); + EXPECT_FALSE(reader.noMoreData()); + + auto bitmap = allocateBitmap(10); + reader.readDeletePositions(0, 10, bitmap); + EXPECT_TRUE(reader.noMoreData()); + + // Additional reads should be no-ops. + auto bitmap2 = allocateBitmap(10); + reader.readDeletePositions(10, 10, bitmap2); + auto setBits2 = getSetBits(bitmap2, 10); + EXPECT_TRUE(setBits2.empty()); + EXPECT_TRUE(reader.noMoreData()); +} + +TEST_F(DeletionVectorReaderTest, singlePosition) { + std::vector positions = {42}; + auto bitmapData = serializeRoaringBitmapNoRun(positions); + auto tempFile = writeDvFile(bitmapData); + auto fileSize = static_cast(bitmapData.size()); + + auto dvFile = + makeDvDeleteFile(tempFile->getPath(), positions.size(), fileSize); + + DeletionVectorReader reader(dvFile, 0, pool_.get()); + + auto bitmap = allocateBitmap(100); + reader.readDeletePositions(0, 100, bitmap); + + auto setBits = getSetBits(bitmap, 100); + EXPECT_EQ(setBits, (std::vector{42})); +} + +TEST_F(DeletionVectorReaderTest, consecutivePositions) { + // Positions: 0 through 99 (100 consecutive positions). + std::vector positions; + positions.reserve(100); + for (int64_t i = 0; i < 100; ++i) { + positions.push_back(i); + } + auto bitmapData = serializeRoaringBitmapNoRun(positions); + auto tempFile = writeDvFile(bitmapData); + auto fileSize = static_cast(bitmapData.size()); + + auto dvFile = + makeDvDeleteFile(tempFile->getPath(), positions.size(), fileSize); + + DeletionVectorReader reader(dvFile, 0, pool_.get()); + + auto bitmap = allocateBitmap(100); + reader.readDeletePositions(0, 100, bitmap); + + auto setBits = getSetBits(bitmap, 100); + std::vector expected; + expected.reserve(100); + for (uint64_t i = 0; i < 100; ++i) { + expected.push_back(i); + } + EXPECT_EQ(setBits, expected); +} + +TEST_F(DeletionVectorReaderTest, invalidBitmapTooSmall) { + // Write a file that is too small to contain a valid roaring bitmap header. + std::string tinyData(4, '\0'); + auto tempFile = writeDvFile(tinyData); + + auto dvFile = makeDvDeleteFile(tempFile->getPath(), 1, tinyData.size()); + + DeletionVectorReader reader(dvFile, 0, pool_.get()); + + auto bitmap = allocateBitmap(10); + VELOX_ASSERT_THROW(reader.readDeletePositions(0, 10, bitmap), "too small"); +} + +TEST_F(DeletionVectorReaderTest, invalidBitmapBadCookie) { + // Write a file with an invalid cookie. Data must be large enough to pass + // the minimum size check (8 bytes for 64-bit header) so that the cookie + // validation is reached. + std::string badData(64, '\0'); + uint32_t badCookie = 99999; + std::memcpy(badData.data(), &badCookie, 4); + auto tempFile = writeDvFile(badData); + + auto dvFile = makeDvDeleteFile(tempFile->getPath(), 1, badData.size()); + + DeletionVectorReader reader(dvFile, 0, pool_.get()); + + auto bitmap = allocateBitmap(10); + VELOX_ASSERT_THROW( + reader.readDeletePositions(0, 10, bitmap), + "Unknown roaring bitmap cookie"); +} diff --git a/velox/connectors/hive/iceberg/tests/DeletionVectorWriterTest.cpp b/velox/connectors/hive/iceberg/tests/DeletionVectorWriterTest.cpp new file mode 100644 index 00000000000..fc869ee7954 --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/DeletionVectorWriterTest.cpp @@ -0,0 +1,319 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/DeletionVectorWriter.h" + +#include + +#include + +#include "velox/common/base/BitUtil.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/file/FileSystems.h" +#include "velox/common/memory/Memory.h" +#include "velox/common/testutil/TempFilePath.h" +#include "velox/connectors/hive/iceberg/DeletionVectorReader.h" + +using namespace facebook::velox; +using namespace facebook::velox::connector::hive::iceberg; +using namespace facebook::velox::common::testutil; + +namespace { + +/// Extracts which bits are set in a bitmap buffer. +std::vector getSetBits(const BufferPtr& bitmap, uint64_t size) { + auto* raw = bitmap->as(); + std::vector result; + for (uint64_t i = 0; i < size; ++i) { + if (bits::isBitSet(raw, i)) { + result.push_back(i); + } + } + return result; +} + +} // namespace + +class DeletionVectorWriterTest : public ::testing::Test { + protected: + static void SetUpTestSuite() { + memory::MemoryManager::initialize(memory::MemoryManager::Options{}); + } + + void SetUp() override { + filesystems::registerLocalFileSystem(); + pool_ = memory::memoryManager()->addLeafPool("DeletionVectorWriterTest"); + } + + BufferPtr allocateBitmap(uint64_t numBits) { + auto numBytes = bits::nbytes(numBits); + return AlignedBuffer::allocate(numBytes, pool_.get(), 0); + } + + /// Writes serialized bitmap to a temp file, reads it back with + /// DeletionVectorReader, and verifies the positions match. + void verifyRoundTrip( + const std::vector& positions, + uint64_t batchSize) { + DeletionVectorWriter writer; + writer.addDeletedPositions(positions); + EXPECT_EQ(writer.numPositions(), positions.size()); + + auto blobData = writer.serialize(); + + auto tempFile = TempFilePath::create(); + { + std::ofstream out( + tempFile->getPath(), std::ios::binary | std::ios::trunc); + out.write(blobData.data(), static_cast(blobData.size())); + } + + auto fileSize = static_cast(blobData.size()); + + // Create IcebergDeleteFile with DV metadata. + std::unordered_map lowerBounds; + std::unordered_map upperBounds; + lowerBounds[DeletionVectorReader::kDvOffsetFieldId] = "0"; + upperBounds[DeletionVectorReader::kDvLengthFieldId] = + std::to_string(fileSize); + + IcebergDeleteFile dvFile( + FileContent::kDeletionVector, + tempFile->getPath(), + dwio::common::FileFormat::DWRF, + positions.size(), + fileSize, + {}, + lowerBounds, + upperBounds); + + DeletionVectorReader reader(dvFile, 0, pool_.get()); + + // Collect all set bits across batches. + std::vector allSetBits; + int64_t maxPos = positions.empty() + ? 0 + : *std::max_element(positions.begin(), positions.end()); + uint64_t totalRows = static_cast(maxPos) + batchSize; + + for (uint64_t offset = 0; offset < totalRows; offset += batchSize) { + auto bitmap = allocateBitmap(batchSize); + reader.readDeletePositions(offset, batchSize, bitmap); + auto bits = getSetBits(bitmap, batchSize); + for (auto b : bits) { + allSetBits.push_back(offset + b); + } + } + + // Sort and deduplicate the expected positions. + std::vector expected = positions; + std::sort(expected.begin(), expected.end()); + expected.erase( + std::unique(expected.begin(), expected.end()), expected.end()); + + EXPECT_EQ(allSetBits.size(), expected.size()); + for (size_t i = 0; i < expected.size(); ++i) { + EXPECT_EQ(allSetBits[i], static_cast(expected[i])); + } + } + + std::shared_ptr pool_; +}; + +TEST_F(DeletionVectorWriterTest, emptyBitmap) { + DeletionVectorWriter writer; + EXPECT_EQ(writer.numPositions(), 0); + + auto data = writer.serialize(); + // Empty 64-bit bitmap: numGroups=0 as uint64 (8 bytes). + EXPECT_EQ(data.size(), 8); +} + +TEST_F(DeletionVectorWriterTest, singlePosition) { + verifyRoundTrip({42}, 100); +} + +TEST_F(DeletionVectorWriterTest, multiplePositions) { + verifyRoundTrip({0, 5, 10, 99}, 100); +} + +TEST_F(DeletionVectorWriterTest, consecutivePositions) { + std::vector positions; + positions.reserve(100); + for (int64_t i = 0; i < 100; ++i) { + positions.push_back(i); + } + verifyRoundTrip(positions, 200); +} + +TEST_F(DeletionVectorWriterTest, multipleContainers) { + // Positions spanning two containers (key=0 and key=1). + verifyRoundTrip({5, 100, 65536, 65600}, 70000); +} + +TEST_F(DeletionVectorWriterTest, largeCardinalityBitmapContainer) { + // More than 4096 positions in a single container triggers bitmap container. + std::vector positions; + positions.reserve(5000); + for (int64_t i = 0; i < 5000; ++i) { + positions.push_back(i * 2); // Even numbers 0..9998. + } + verifyRoundTrip(positions, 10100); +} + +TEST_F(DeletionVectorWriterTest, duplicatePositions) { + // addDeletedPosition() does not deduplicate — numPositions() counts all + // insertions including duplicates. serialize() deduplicates via std::set. + DeletionVectorWriter writer; + writer.addDeletedPosition(5); + writer.addDeletedPosition(5); + writer.addDeletedPosition(10); + writer.addDeletedPosition(10); + writer.addDeletedPosition(10); + EXPECT_EQ(writer.numPositions(), 5); + + auto data = writer.serialize(); + + auto tempFile = TempFilePath::create(); + { + std::ofstream out(tempFile->getPath(), std::ios::binary | std::ios::trunc); + out.write(data.data(), static_cast(data.size())); + } + + std::unordered_map lowerBounds; + std::unordered_map upperBounds; + lowerBounds[DeletionVectorReader::kDvOffsetFieldId] = "0"; + upperBounds[DeletionVectorReader::kDvLengthFieldId] = + std::to_string(data.size()); + + IcebergDeleteFile dvFile( + FileContent::kDeletionVector, + tempFile->getPath(), + dwio::common::FileFormat::DWRF, + 2, // Only 2 unique positions. + data.size(), + {}, + lowerBounds, + upperBounds); + + DeletionVectorReader reader(dvFile, 0, pool_.get()); + + auto bitmap = allocateBitmap(20); + reader.readDeletePositions(0, 20, bitmap); + + auto setBits = getSetBits(bitmap, 20); + EXPECT_EQ(setBits, (std::vector{5, 10})); +} + +TEST_F(DeletionVectorWriterTest, clearPositions) { + DeletionVectorWriter writer; + writer.addDeletedPosition(1); + writer.addDeletedPosition(2); + EXPECT_EQ(writer.numPositions(), 2); + + writer.clear(); + EXPECT_EQ(writer.numPositions(), 0); + + auto data = writer.serialize(); + EXPECT_EQ(data.size(), 8); // Empty bitmap. +} + +TEST_F(DeletionVectorWriterTest, negativePositionRejected) { + DeletionVectorWriter writer; + VELOX_ASSERT_THROW( + writer.addDeletedPosition(-1), "Deleted position must be non-negative"); +} + +TEST_F(DeletionVectorWriterTest, fourOrMoreContainersWithOffsets) { + // With >= 4 containers, the roaring format includes an offset section. + std::vector positions; + positions.reserve(5); + for (int i = 0; i < 5; ++i) { + positions.push_back(static_cast(i) * 65536 + 42); + } + verifyRoundTrip(positions, 5 * 65536 + 100); +} + +TEST_F(DeletionVectorWriterTest, puffinFileRoundTrip) { + DeletionVectorWriter writer; + writer.addDeletedPositions({3, 7, 42, 100}); + auto blobData = writer.serialize(); + + auto tempFile = TempFilePath::create(); + auto [blobOffset, blobLength] = writePuffinFile( + tempFile->getPath(), blobData, "/data/test-data-file.parquet"); + + EXPECT_EQ(blobOffset, 4); // After "PUF1" magic. + EXPECT_EQ(blobLength, blobData.size()); + + // Read the blob back from the Puffin file using DeletionVectorReader. + std::unordered_map lowerBounds; + std::unordered_map upperBounds; + lowerBounds[DeletionVectorReader::kDvOffsetFieldId] = + std::to_string(blobOffset); + upperBounds[DeletionVectorReader::kDvLengthFieldId] = + std::to_string(blobLength); + + // Get full file size. + std::ifstream in(tempFile->getPath(), std::ios::binary | std::ios::ate); + auto fileSize = static_cast(in.tellg()); + + IcebergDeleteFile dvFile( + FileContent::kDeletionVector, + tempFile->getPath(), + dwio::common::FileFormat::DWRF, + 4, + fileSize, + {}, + lowerBounds, + upperBounds); + + DeletionVectorReader reader(dvFile, 0, pool_.get()); + + auto bitmap = allocateBitmap(200); + reader.readDeletePositions(0, 200, bitmap); + + auto setBits = getSetBits(bitmap, 200); + EXPECT_EQ(setBits, (std::vector{3, 7, 42, 100})); +} + +/// Verifies 64-bit positions (>4 billion) serialize and deserialize correctly. +/// This exercises the Roaring64Bitmap group partitioning for large data files. +TEST_F(DeletionVectorWriterTest, largePositions64Bit) { + // Positions beyond the 32-bit range. + std::vector positions = { + 100, + 65'536, + 5'000'000'000LL, + 5'000'000'001LL, + 10'000'000'000LL, + }; + verifyRoundTrip(positions, 1'024); +} + +/// Verifies mixed 32-bit and 64-bit positions in the same bitmap. +TEST_F(DeletionVectorWriterTest, mixed32And64BitPositions) { + std::vector positions = { + 0, + 1, + 65'535, + 65'536, + 4'294'967'295LL, + 4'294'967'296LL, + 8'589'934'592LL, + }; + verifyRoundTrip(positions, 2'048); +} diff --git a/velox/connectors/hive/iceberg/tests/EqualityDeleteFileReaderTest.cpp b/velox/connectors/hive/iceberg/tests/EqualityDeleteFileReaderTest.cpp new file mode 100644 index 00000000000..e0124e716aa --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/EqualityDeleteFileReaderTest.cpp @@ -0,0 +1,1188 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "velox/common/file/FileSystems.h" +#include "velox/common/testutil/TempDirectoryPath.h" +#include "velox/connectors/ConnectorRegistry.h" +#include "velox/connectors/hive/iceberg/IcebergConnector.h" +#include "velox/connectors/hive/iceberg/IcebergDeleteFile.h" +#include "velox/connectors/hive/iceberg/IcebergSplit.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" + +namespace facebook::velox::connector::hive::iceberg { + +using namespace facebook::velox::exec::test; + +namespace { + +const std::string kIcebergConnectorId = "test-iceberg-eq-delete"; + +} // namespace + +/// End-to-end tests for equality deletes via the IcebergSplitReader. +/// These tests write DWRF data files and delete files, then execute +/// table scans verifying that matching rows are filtered out. +class EqualityDeleteFileReaderTest : public HiveConnectorTestBase { + protected: + void SetUp() override { + HiveConnectorTestBase::SetUp(); + IcebergConnectorFactory icebergFactory; + auto icebergConnector = icebergFactory.newConnector( + kIcebergConnectorId, + std::make_shared( + std::unordered_map()), + ioExecutor_.get()); + connector::ConnectorRegistry::global().insert( + icebergConnector->connectorId(), icebergConnector); + } + + void TearDown() override { + connector::ConnectorRegistry::global().erase(kIcebergConnectorId); + HiveConnectorTestBase::TearDown(); + } + + uint64_t getFileSize(const std::string& path) { + return filesystems::getFileSystem(path, nullptr) + ->openFileForRead(path) + ->size(); + } + + /// Writes a DWRF data file containing the given vectors. + std::shared_ptr writeDataFile( + const std::vector& data) { + auto file = common::testutil::TempFilePath::create(); + writeToFile(file->getPath(), data); + return file; + } + + /// Writes a DWRF delete file containing the equality delete rows. + std::shared_ptr writeEqDeleteFile( + const std::vector& deleteData) { + auto file = common::testutil::TempFilePath::create(); + writeToFile(file->getPath(), deleteData); + return file; + } + + /// Creates splits with equality delete files attached. + std::vector> makeSplits( + const std::string& dataFilePath, + const std::vector& deleteFiles = {}, + int64_t dataSequenceNumber = 0) { + return makeSplits( + dataFilePath, + /*partitionKeys=*/{}, + deleteFiles, + dataSequenceNumber); + } + + /// Creates splits with equality delete files and partition keys attached. + /// Use this overload to exercise the equality-delete augmentation for + /// partition columns missing from the user's projection. + std::vector> makeSplits( + const std::string& dataFilePath, + const std::unordered_map>& + partitionKeys, + const std::vector& deleteFiles, + int64_t dataSequenceNumber = 0) { + auto fileSize = getFileSize(dataFilePath); + return {std::make_shared( + kIcebergConnectorId, + dataFilePath, + dwio::common::FileFormat::DWRF, + 0, + fileSize, + partitionKeys, + std::nullopt, + std::unordered_map{}, + nullptr, + /*cacheable=*/true, + deleteFiles, + std::unordered_map{}, + std::nullopt, + dataSequenceNumber)}; + } + + /// Builds a table scan plan node with the given schema. + core::PlanNodePtr makeTableScanPlan(const RowTypePtr& rowType) { + return makeTableScanPlan(rowType, rowType); + } + + /// Builds a table scan plan node with separate output and table column + /// schemas. Use this when the user's projection ('outputType') does not + /// contain every column referenced by an equality delete file + /// ('dataColumns' must contain the full table schema so the equality + /// column resolution can map field IDs to names). + core::PlanNodePtr makeTableScanPlan( + const RowTypePtr& outputType, + const RowTypePtr& dataColumns) { + return PlanBuilder() + .startTableScan() + .connectorId(kIcebergConnectorId) + .outputType(outputType) + .dataColumns(dataColumns) + .endTableScan() + .planNode(); + } +}; + +/// Verifies that base rows matching the equality delete file are removed. +TEST_F(EqualityDeleteFileReaderTest, basicSingleColumnDelete) { + auto rowType = ROW({"id", "value"}, {BIGINT(), VARCHAR()}); + + auto baseData = makeRowVector( + {"id", "value"}, + { + makeFlatVector({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), + makeFlatVector( + {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"}), + }); + auto dataFile = writeDataFile({baseData}); + + // Delete rows where id == 3 or id == 7. + auto deleteData = makeRowVector( + {"id"}, + { + makeFlatVector({3, 7}), + }); + auto eqDeleteFile = writeEqDeleteFile({deleteData}); + + IcebergDeleteFile icebergDeleteFile( + FileContent::kEqualityDeletes, + eqDeleteFile->getPath(), + dwio::common::FileFormat::DWRF, + 2, + getFileSize(eqDeleteFile->getPath()), + /*equalityFieldIds=*/{1}); // field ID 1 = column 0 = "id" + + auto splits = makeSplits(dataFile->getPath(), {icebergDeleteFile}); + auto plan = makeTableScanPlan(rowType); + auto result = AssertQueryBuilder(plan).splits(splits).copyResults(pool()); + + auto expected = makeRowVector( + {"id", "value"}, + { + makeFlatVector({0, 1, 2, 4, 5, 6, 8, 9}), + makeFlatVector({"a", "b", "c", "e", "f", "g", "i", "j"}), + }); + + assertEqualResults({expected}, {result}); +} + +/// Regression test for the bug where IcebergSplitReader fails with +/// "Column not found in row: " when an equality-delete column is not +/// part of the user's projection. The reader must augment its scan spec to +/// physically read the equality-delete column, apply the delete, and then +/// project the column away from the output before returning to the operator. +TEST_F(EqualityDeleteFileReaderTest, equalityColumnNotInProjection) { + auto tableType = ROW({"id", "value"}, {BIGINT(), VARCHAR()}); + // The user only selects 'value'. The equality delete is on 'id', which is + // NOT in the projection — this is the case that previously failed. + auto outputType = ROW({"value"}, {VARCHAR()}); + + auto baseData = makeRowVector( + {"id", "value"}, + { + makeFlatVector({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), + makeFlatVector( + {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"}), + }); + auto dataFile = writeDataFile({baseData}); + + // Delete rows where id == 3 or id == 7. + auto deleteData = makeRowVector( + {"id"}, + { + makeFlatVector({3, 7}), + }); + auto eqDeleteFile = writeEqDeleteFile({deleteData}); + + IcebergDeleteFile icebergDeleteFile( + FileContent::kEqualityDeletes, + eqDeleteFile->getPath(), + dwio::common::FileFormat::DWRF, + 2, + getFileSize(eqDeleteFile->getPath()), + /*equalityFieldIds=*/{1}); // field ID 1 = column 0 = "id" + + auto splits = makeSplits(dataFile->getPath(), {icebergDeleteFile}); + auto plan = makeTableScanPlan(outputType, tableType); + auto result = AssertQueryBuilder(plan).splits(splits).copyResults(pool()); + + // The 'id' column must not appear in the output; only 'value' is projected. + // Rows with id=3 ("d") and id=7 ("h") are removed by the equality delete. + auto expected = makeRowVector( + {"value"}, + { + makeFlatVector({"a", "b", "c", "e", "f", "g", "i", "j"}), + }); + + assertEqualResults({expected}, {result}); +} + +/// Verifies that two equality-delete files referencing the SAME column not in +/// the user's projection only augment 'scanSpec_' once. Exercises the +/// de-duplication branch in 'IcebergSplitReader::prepareSplit'. +TEST_F(EqualityDeleteFileReaderTest, multipleDeleteFilesSameMissingColumn) { + auto tableType = ROW({"id", "value"}, {BIGINT(), VARCHAR()}); + auto outputType = ROW({"value"}, {VARCHAR()}); + + auto baseData = makeRowVector( + {"id", "value"}, + { + makeFlatVector({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), + makeFlatVector( + {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"}), + }); + auto dataFile = writeDataFile({baseData}); + + // Two delete files, both targeting 'id' (which is NOT in the projection). + auto deleteData1 = makeRowVector( + {"id"}, + { + makeFlatVector({2, 5}), + }); + auto eqDeleteFile1 = writeEqDeleteFile({deleteData1}); + IcebergDeleteFile icebergDeleteFile1( + FileContent::kEqualityDeletes, + eqDeleteFile1->getPath(), + dwio::common::FileFormat::DWRF, + 2, + getFileSize(eqDeleteFile1->getPath()), + /*equalityFieldIds=*/{1}); + + auto deleteData2 = makeRowVector( + {"id"}, + { + makeFlatVector({0, 9}), + }); + auto eqDeleteFile2 = writeEqDeleteFile({deleteData2}); + IcebergDeleteFile icebergDeleteFile2( + FileContent::kEqualityDeletes, + eqDeleteFile2->getPath(), + dwio::common::FileFormat::DWRF, + 2, + getFileSize(eqDeleteFile2->getPath()), + /*equalityFieldIds=*/{1}); + + auto splits = + makeSplits(dataFile->getPath(), {icebergDeleteFile1, icebergDeleteFile2}); + auto plan = makeTableScanPlan(outputType, tableType); + auto result = AssertQueryBuilder(plan).splits(splits).copyResults(pool()); + + // Rows with id=0, 2, 5, 9 are removed (across both delete files). + auto expected = makeRowVector( + {"value"}, + { + makeFlatVector({"b", "d", "e", "g", "h", "i"}), + }); + + assertEqualResults({expected}, {result}); +} + +/// Verifies a multi-column equality-delete file where some columns ARE in the +/// user's projection and some are NOT. Both must end up in the read output for +/// the equality probe to succeed, while only the projected columns appear in +/// the operator-visible result. +TEST_F(EqualityDeleteFileReaderTest, equalityMixedInAndOutOfProjection) { + auto tableType = ROW({"a", "b", "c"}, {INTEGER(), VARCHAR(), BIGINT()}); + // User selects only 'b' and 'c'. 'a' is referenced by the equality delete + // but not part of the projection. + auto outputType = ROW({"b", "c"}, {VARCHAR(), BIGINT()}); + + auto baseData = makeRowVector( + {"a", "b", "c"}, + { + makeFlatVector({1, 2, 3, 4, 5}), + makeFlatVector({"x", "y", "z", "x", "y"}), + makeFlatVector({10, 20, 30, 40, 50}), + }); + auto dataFile = writeDataFile({baseData}); + + // Delete rows where (a=2, b="y") -- removes row 1. + // Also (a=1, b="y") -- no match (row with a=1 has b="x"). + auto deleteData = makeRowVector( + {"a", "b"}, + { + makeFlatVector({2, 1}), + makeFlatVector({"y", "y"}), + }); + auto eqDeleteFile = writeEqDeleteFile({deleteData}); + + IcebergDeleteFile icebergDeleteFile( + FileContent::kEqualityDeletes, + eqDeleteFile->getPath(), + dwio::common::FileFormat::DWRF, + 3, + getFileSize(eqDeleteFile->getPath()), + /*equalityFieldIds=*/{1, 2}); // field IDs 1,2 = columns "a","b" + + auto splits = makeSplits(dataFile->getPath(), {icebergDeleteFile}); + auto plan = makeTableScanPlan(outputType, tableType); + auto result = AssertQueryBuilder(plan).splits(splits).copyResults(pool()); + + // Row 1 (a=2, b="y", c=20) is deleted. The remaining rows project to + // (b, c). + auto expected = makeRowVector( + {"b", "c"}, + { + makeFlatVector({"x", "z", "x", "y"}), + makeFlatVector({10, 30, 40, 50}), + }); + + assertEqualResults({expected}, {result}); +} + +/// Verifies that an equality delete on a partition column that IS in the data +/// file (Iceberg-style) but NOT in the user's projection works correctly. The +/// augmentation should set the partition value as a constant; the file-read +/// path should then leave the constant in place because the column is present +/// in 'fileType'. +TEST_F( + EqualityDeleteFileReaderTest, + equalityPartitionColumnInFileNotInProjection) { + auto tableType = ROW({"part", "value"}, {INTEGER(), VARCHAR()}); + auto outputType = ROW({"value"}, {VARCHAR()}); + + // Data file contains both 'part' and 'value', all rows in partition 2. + auto baseData = makeRowVector( + {"part", "value"}, + { + makeFlatVector({2, 2, 2, 2}), + makeFlatVector({"a", "b", "c", "d"}), + }); + auto dataFile = writeDataFile({baseData}); + + auto deleteData = makeRowVector( + {"part", "value"}, + { + makeFlatVector({2, 2}), + makeFlatVector({"b", "d"}), + }); + auto eqDeleteFile = writeEqDeleteFile({deleteData}); + + IcebergDeleteFile icebergDeleteFile( + FileContent::kEqualityDeletes, + eqDeleteFile->getPath(), + dwio::common::FileFormat::DWRF, + 2, + getFileSize(eqDeleteFile->getPath()), + /*equalityFieldIds=*/{1, 2}); + + auto splits = makeSplits( + dataFile->getPath(), + /*partitionKeys=*/{{"part", std::optional{"2"}}}, + {icebergDeleteFile}); + auto plan = makeTableScanPlan(outputType, tableType); + auto result = AssertQueryBuilder(plan).splits(splits).copyResults(pool()); + + auto expected = makeRowVector( + {"value"}, + { + makeFlatVector({"a", "c"}), + }); + + assertEqualResults({expected}, {result}); +} + +/// Same as above but the partition value does NOT match the equality-delete +/// value, so no rows should be removed. +TEST_F( + EqualityDeleteFileReaderTest, + equalityPartitionColumnNonMatchingPartition) { + auto tableType = ROW({"part", "value"}, {INTEGER(), VARCHAR()}); + auto outputType = ROW({"value"}, {VARCHAR()}); + + // Data file holds rows in partition 2. + auto baseData = makeRowVector( + {"part", "value"}, + { + makeFlatVector({2, 2, 2}), + makeFlatVector({"a", "b", "c"}), + }); + auto dataFile = writeDataFile({baseData}); + + // Delete (part=99, value="b"). No file row matches part=99. + auto deleteData = makeRowVector( + {"part", "value"}, + { + makeFlatVector({99}), + makeFlatVector({"b"}), + }); + auto eqDeleteFile = writeEqDeleteFile({deleteData}); + + IcebergDeleteFile icebergDeleteFile( + FileContent::kEqualityDeletes, + eqDeleteFile->getPath(), + dwio::common::FileFormat::DWRF, + 1, + getFileSize(eqDeleteFile->getPath()), + /*equalityFieldIds=*/{1, 2}); + + auto splits = makeSplits( + dataFile->getPath(), + /*partitionKeys=*/{{"part", std::optional{"2"}}}, + {icebergDeleteFile}); + auto plan = makeTableScanPlan(outputType, tableType); + auto result = AssertQueryBuilder(plan).splits(splits).copyResults(pool()); + + auto expected = makeRowVector( + {"value"}, + { + makeFlatVector({"a", "b", "c"}), + }); + + assertEqualResults({expected}, {result}); +} + +/// Multi-column equality delete where ONE column is a partition column not in +/// the projection and ANOTHER is a regular data column not in the projection. +/// Both must be augmented; the partition column gets a constant value, the +/// regular column is read from the file. +TEST_F( + EqualityDeleteFileReaderTest, + equalityMixedPartitionAndRegularNotInProjection) { + auto tableType = + ROW({"part", "id", "value"}, {INTEGER(), BIGINT(), VARCHAR()}); + auto outputType = ROW({"value"}, {VARCHAR()}); + + // Data file contains all three columns, all rows in partition 7. + auto baseData = makeRowVector( + {"part", "id", "value"}, + { + makeFlatVector({7, 7, 7, 7}), + makeFlatVector({10, 20, 30, 40}), + makeFlatVector({"a", "b", "c", "d"}), + }); + auto dataFile = writeDataFile({baseData}); + + // Delete (part=7, id=20) and (part=7, id=40). Should remove "b" and "d". + auto deleteData = makeRowVector( + {"part", "id"}, + { + makeFlatVector({7, 7}), + makeFlatVector({20, 40}), + }); + auto eqDeleteFile = writeEqDeleteFile({deleteData}); + + IcebergDeleteFile icebergDeleteFile( + FileContent::kEqualityDeletes, + eqDeleteFile->getPath(), + dwio::common::FileFormat::DWRF, + 2, + getFileSize(eqDeleteFile->getPath()), + /*equalityFieldIds=*/{1, 2}); // field IDs 1,2 = part, id + + auto splits = makeSplits( + dataFile->getPath(), + /*partitionKeys=*/{{"part", std::optional{"7"}}}, + {icebergDeleteFile}); + auto plan = makeTableScanPlan(outputType, tableType); + auto result = AssertQueryBuilder(plan).splits(splits).copyResults(pool()); + + auto expected = makeRowVector( + {"value"}, + { + makeFlatVector({"a", "c"}), + }); + + assertEqualResults({expected}, {result}); +} + +/// Verifies equality delete on a DATE partition column not in projection. +/// Iceberg encodes DATE partition values as days-since-epoch (e.g. "19345"). +/// This exercises the type-derived 'isDaysSinceEpoch' flag in +/// 'configureEqualityDeleteColumns' — for DATE columns the partition string +/// must be parsed as an integer day count, NOT as an ISO-8601 date string. +TEST_F( + EqualityDeleteFileReaderTest, + equalityDatePartitionColumnNotInProjection) { + auto tableType = ROW({"part_date", "value"}, {DATE(), VARCHAR()}); + auto outputType = ROW({"value"}, {VARCHAR()}); + + // 19345 days since 1970-01-01 == 2022-12-22. All file rows belong to that + // partition. + constexpr int32_t kPartitionDays = 19345; + auto baseData = makeRowVector( + {"part_date", "value"}, + { + makeFlatVector( + {kPartitionDays, kPartitionDays, kPartitionDays}, DATE()), + makeFlatVector({"a", "b", "c"}), + }); + auto dataFile = writeDataFile({baseData}); + + // Delete (part_date=19345, value="b"). + auto deleteData = makeRowVector( + {"part_date", "value"}, + { + makeFlatVector({kPartitionDays}, DATE()), + makeFlatVector({"b"}), + }); + auto eqDeleteFile = writeEqDeleteFile({deleteData}); + + IcebergDeleteFile icebergDeleteFile( + FileContent::kEqualityDeletes, + eqDeleteFile->getPath(), + dwio::common::FileFormat::DWRF, + 1, + getFileSize(eqDeleteFile->getPath()), + /*equalityFieldIds=*/{1, 2}); + + auto splits = makeSplits( + dataFile->getPath(), + /*partitionKeys=*/ + {{"part_date", + std::optional{std::to_string(kPartitionDays)}}}, + {icebergDeleteFile}); + auto plan = makeTableScanPlan(outputType, tableType); + auto result = AssertQueryBuilder(plan).splits(splits).copyResults(pool()); + + auto expected = makeRowVector( + {"value"}, + { + makeFlatVector({"a", "c"}), + }); + + assertEqualResults({expected}, {result}); +} + +/// Exercises the filter-only column upgrade path in +/// 'configureEqualityDeleteColumns'. The equality-delete column 'id' is +/// referenced by a WHERE predicate (so the planner installs a scan-spec +/// child with 'projectOut=false') but is NOT in the user's SELECT +/// projection. The augmentation must upgrade the existing scan-spec child +/// from filter-only to 'projectOut=true' and assign a non-conflicting +/// channel so the equality-delete reader can probe by name. +TEST_F(EqualityDeleteFileReaderTest, equalityFilterOnlyColumnNotInProjection) { + auto tableType = ROW({"id", "value"}, {BIGINT(), VARCHAR()}); + auto outputType = ROW({"value"}, {VARCHAR()}); + + auto baseData = makeRowVector( + {"id", "value"}, + { + makeFlatVector({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), + makeFlatVector( + {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"}), + }); + auto dataFile = writeDataFile({baseData}); + + // Equality delete removes id == 4 and id == 8. + auto deleteData = makeRowVector( + {"id"}, + { + makeFlatVector({4, 8}), + }); + auto eqDeleteFile = writeEqDeleteFile({deleteData}); + + IcebergDeleteFile icebergDeleteFile( + FileContent::kEqualityDeletes, + eqDeleteFile->getPath(), + dwio::common::FileFormat::DWRF, + 2, + getFileSize(eqDeleteFile->getPath()), + /*equalityFieldIds=*/{1}); // field ID 1 = "id" + + auto splits = makeSplits(dataFile->getPath(), {icebergDeleteFile}); + // WHERE id >= 3 keeps rows {3,4,5,6,7,8,9} from the file; the equality + // delete then removes id=4 and id=8, leaving values {d->skipped} no, we + // expect surviving values for ids {3,5,6,7,9}, projected as 'value' only. + auto plan = PlanBuilder() + .startTableScan() + .connectorId(kIcebergConnectorId) + .outputType(outputType) + .dataColumns(tableType) + .subfieldFilter("id >= 3") + .endTableScan() + .planNode(); + auto result = AssertQueryBuilder(plan).splits(splits).copyResults(pool()); + + auto expected = makeRowVector( + {"value"}, + { + makeFlatVector({"d", "f", "g", "h", "j"}), + }); + + assertEqualResults({expected}, {result}); +} + +/// Regression test for the schema-evolution + +/// partition-column-not-in-projection scenario surfaced by the Presto Iceberg +/// integration test 'testEqualityDeleteWithPartitionColumnMissingInSelect'. +/// +/// Setup (mirrors the Presto test for the older data file): +/// - Full table schema: (a, b, c, d) with a, c, d as partition columns. +/// - The data file under test was written BEFORE 'd' was added, so it +/// physically contains only (a, b, c). +/// - User projection: (a, b, d). 'd' must be NULL-filled (not in file); +/// 'c' is NOT in the projection but IS referenced by the equality +/// delete and IS the file's identity-partition column. +/// +/// The equality delete (a=6, c=2, b=1006) targets the file row +/// (6, '1006', 2). The augmentation must: +/// 1. Add 'c' to 'scanSpec_' / 'readerOutputType_' so the eq-delete +/// probe can find it by name. +/// 2. Leave 'd' alone — 'd' is in the user projection and gets the +/// standard schema-evolution NULL-fill from 'adaptColumns'. +/// 3. Honour the existing partition-key constant on 'c' regardless of +/// whether the file physically contains 'c'. +TEST_F( + EqualityDeleteFileReaderTest, + equalityPartitionColumnNotInProjectionWithEvolvedSchema) { + // Full evolved table schema (after 'ALTER TABLE ADD COLUMN d'). + auto tableType = + ROW({"a", "b", "c", "d"}, {INTEGER(), VARCHAR(), INTEGER(), VARCHAR()}); + // User selects 'a', 'b', 'd'. Note 'c' is NOT projected. + auto outputType = ROW({"a", "b", "d"}, {INTEGER(), VARCHAR(), VARCHAR()}); + + // Data file contains only (a, b, c) — written before 'd' was added. + // Both rows are in the (a=6, c=2) partition. + auto baseData = makeRowVector( + {"a", "b", "c"}, + { + makeFlatVector({6, 6}), + makeFlatVector({"1006", "1009"}), + makeFlatVector({2, 2}), + }); + auto dataFile = writeDataFile({baseData}); + + // Equality delete on (a, b, c) with values (6, '1006', 2). Field IDs + // are in field-id order = [1, 2, 3]. + auto deleteData = makeRowVector( + {"a", "b", "c"}, + { + makeFlatVector({6}), + makeFlatVector({"1006"}), + makeFlatVector({2}), + }); + auto eqDeleteFile = writeEqDeleteFile({deleteData}); + + IcebergDeleteFile icebergDeleteFile( + FileContent::kEqualityDeletes, + eqDeleteFile->getPath(), + dwio::common::FileFormat::DWRF, + 1, + getFileSize(eqDeleteFile->getPath()), + /*equalityFieldIds=*/{1, 2, 3}); + + auto splits = makeSplits( + dataFile->getPath(), + /*partitionKeys=*/ + {{"a", std::optional{"6"}}, + {"c", std::optional{"2"}}}, + {icebergDeleteFile}); + auto plan = makeTableScanPlan(outputType, tableType); + auto result = AssertQueryBuilder(plan).splits(splits).copyResults(pool()); + + // Row (6, '1006', 2) is deleted by the equality delete; (6, '1009', 2) + // survives. 'd' is NULL because the data file was written before 'd' + // was added. + auto expected = makeRowVector( + {"a", "b", "d"}, + { + makeFlatVector({6}), + makeFlatVector({"1009"}), + makeNullableFlatVector({std::nullopt}), + }); + + assertEqualResults({expected}, {result}); +} + +/// Verifies multi-column equality deletes (both columns must match). +TEST_F(EqualityDeleteFileReaderTest, multiColumnDelete) { + auto rowType = ROW({"a", "b", "c"}, {INTEGER(), VARCHAR(), BIGINT()}); + + auto baseData = makeRowVector( + {"a", "b", "c"}, + { + makeFlatVector({1, 2, 3, 4, 5}), + makeFlatVector({"x", "y", "z", "x", "y"}), + makeFlatVector({10, 20, 30, 40, 50}), + }); + auto dataFile = writeDataFile({baseData}); + + // Delete rows where (a=2, b="y") — matches row index 1. + // Also (a=5, b="y") — matches row index 4. + // But (a=1, b="y") — no match (a=1 has b="x"). + auto deleteData = makeRowVector( + {"a", "b"}, + { + makeFlatVector({2, 5, 1}), + makeFlatVector({"y", "y", "y"}), + }); + auto eqDeleteFile = writeEqDeleteFile({deleteData}); + + IcebergDeleteFile icebergDeleteFile( + FileContent::kEqualityDeletes, + eqDeleteFile->getPath(), + dwio::common::FileFormat::DWRF, + 3, + getFileSize(eqDeleteFile->getPath()), + /*equalityFieldIds=*/{1, 2}); // field IDs 1,2 = columns "a","b" + + auto splits = makeSplits(dataFile->getPath(), {icebergDeleteFile}); + auto plan = makeTableScanPlan(rowType); + auto result = AssertQueryBuilder(plan).splits(splits).copyResults(pool()); + + // Rows 0, 2, 3 survive (rows 1 and 4 deleted). + auto expected = makeRowVector( + {"a", "b", "c"}, + { + makeFlatVector({1, 3, 4}), + makeFlatVector({"x", "z", "x"}), + makeFlatVector({10, 30, 40}), + }); + + assertEqualResults({expected}, {result}); +} + +/// Verifies that when no rows match, all rows survive. +TEST_F(EqualityDeleteFileReaderTest, noMatchingDeletes) { + auto rowType = ROW({"id"}, {BIGINT()}); + + auto baseData = makeRowVector( + {"id"}, + { + makeFlatVector({1, 2, 3}), + }); + auto dataFile = writeDataFile({baseData}); + + // Delete file has values not present in base data. + auto deleteData = makeRowVector( + {"id"}, + { + makeFlatVector({100, 200}), + }); + auto eqDeleteFile = writeEqDeleteFile({deleteData}); + + IcebergDeleteFile icebergDeleteFile( + FileContent::kEqualityDeletes, + eqDeleteFile->getPath(), + dwio::common::FileFormat::DWRF, + 2, + getFileSize(eqDeleteFile->getPath()), + /*equalityFieldIds=*/{1}); + + auto splits = makeSplits(dataFile->getPath(), {icebergDeleteFile}); + auto plan = makeTableScanPlan(rowType); + auto result = AssertQueryBuilder(plan).splits(splits).copyResults(pool()); + + auto expected = makeRowVector( + {"id"}, + { + makeFlatVector({1, 2, 3}), + }); + + assertEqualResults({expected}, {result}); +} + +/// Verifies that all rows are deleted when every base row matches. +TEST_F(EqualityDeleteFileReaderTest, allRowsDeleted) { + auto rowType = ROW({"id"}, {BIGINT()}); + + auto baseData = makeRowVector( + {"id"}, + { + makeFlatVector({1, 2, 3}), + }); + auto dataFile = writeDataFile({baseData}); + + auto deleteData = makeRowVector( + {"id"}, + { + makeFlatVector({1, 2, 3}), + }); + auto eqDeleteFile = writeEqDeleteFile({deleteData}); + + IcebergDeleteFile icebergDeleteFile( + FileContent::kEqualityDeletes, + eqDeleteFile->getPath(), + dwio::common::FileFormat::DWRF, + 3, + getFileSize(eqDeleteFile->getPath()), + /*equalityFieldIds=*/{1}); + + auto splits = makeSplits(dataFile->getPath(), {icebergDeleteFile}); + auto plan = makeTableScanPlan(rowType); + auto result = AssertQueryBuilder(plan).splits(splits).copyResults(pool()); + + EXPECT_EQ(result->size(), 0); +} + +/// Verifies equality deletes with VARCHAR columns. +TEST_F(EqualityDeleteFileReaderTest, stringColumnDelete) { + auto rowType = ROW({"name", "age"}, {VARCHAR(), INTEGER()}); + + auto baseData = makeRowVector( + {"name", "age"}, + { + makeFlatVector({"alice", "bob", "charlie", "dave"}), + makeFlatVector({25, 30, 35, 40}), + }); + auto dataFile = writeDataFile({baseData}); + + // Delete rows where name is "bob" or "dave". + auto deleteData = makeRowVector( + {"name"}, + { + makeFlatVector({"bob", "dave"}), + }); + auto eqDeleteFile = writeEqDeleteFile({deleteData}); + + IcebergDeleteFile icebergDeleteFile( + FileContent::kEqualityDeletes, + eqDeleteFile->getPath(), + dwio::common::FileFormat::DWRF, + 2, + getFileSize(eqDeleteFile->getPath()), + /*equalityFieldIds=*/{1}); // field ID 1 = "name" + + auto splits = makeSplits(dataFile->getPath(), {icebergDeleteFile}); + auto plan = makeTableScanPlan(rowType); + auto result = AssertQueryBuilder(plan).splits(splits).copyResults(pool()); + + auto expected = makeRowVector( + {"name", "age"}, + { + makeFlatVector({"alice", "charlie"}), + makeFlatVector({25, 35}), + }); + + assertEqualResults({expected}, {result}); +} + +/// Verifies equality deletes on a non-first column (field ID 2). +TEST_F(EqualityDeleteFileReaderTest, deleteOnSecondColumn) { + auto rowType = ROW({"id", "category"}, {BIGINT(), VARCHAR()}); + + auto baseData = makeRowVector( + {"id", "category"}, + { + makeFlatVector({1, 2, 3, 4, 5}), + makeFlatVector({"A", "B", "A", "C", "B"}), + }); + auto dataFile = writeDataFile({baseData}); + + // Delete rows where category == "B". + auto deleteData = makeRowVector( + {"category"}, + { + makeFlatVector({"B"}), + }); + auto eqDeleteFile = writeEqDeleteFile({deleteData}); + + IcebergDeleteFile icebergDeleteFile( + FileContent::kEqualityDeletes, + eqDeleteFile->getPath(), + dwio::common::FileFormat::DWRF, + 1, + getFileSize(eqDeleteFile->getPath()), + /*equalityFieldIds=*/{2}); // field ID 2 = column 1 = "category" + + auto splits = makeSplits(dataFile->getPath(), {icebergDeleteFile}); + auto plan = makeTableScanPlan(rowType); + auto result = AssertQueryBuilder(plan).splits(splits).copyResults(pool()); + + // Rows with category="B" (indices 1,4) deleted. + auto expected = makeRowVector( + {"id", "category"}, + { + makeFlatVector({1, 3, 4}), + makeFlatVector({"A", "A", "C"}), + }); + + assertEqualResults({expected}, {result}); +} + +/// Verifies that equality deletes apply when the delete file has a higher +/// sequence number than the data file (per the Iceberg V2+ spec). +TEST_F(EqualityDeleteFileReaderTest, sequenceNumberDeleteApplies) { + auto rowType = ROW({"id", "value"}, {BIGINT(), VARCHAR()}); + + auto baseData = makeRowVector( + {"id", "value"}, + { + makeFlatVector({1, 2, 3, 4, 5}), + makeFlatVector({"a", "b", "c", "d", "e"}), + }); + auto dataFile = writeDataFile({baseData}); + + auto deleteData = makeRowVector( + {"id"}, + { + makeFlatVector({2, 4}), + }); + auto eqDeleteFile = writeEqDeleteFile({deleteData}); + + // Delete file has sequence number 5, data file has sequence number 3. + // Since deleteSeq (5) > dataSeq (3), the delete should apply. + IcebergDeleteFile icebergDeleteFile( + FileContent::kEqualityDeletes, + eqDeleteFile->getPath(), + dwio::common::FileFormat::DWRF, + 2, + getFileSize(eqDeleteFile->getPath()), + /*equalityFieldIds=*/{1}, + /*lowerBounds=*/{}, + /*upperBounds=*/{}, + /*dataSequenceNumber=*/5); + + auto splits = makeSplits( + dataFile->getPath(), + {icebergDeleteFile}, + /*dataSequenceNumber=*/3); + auto plan = makeTableScanPlan(rowType); + auto result = AssertQueryBuilder(plan).splits(splits).copyResults(pool()); + + // Rows with id=2 and id=4 are deleted. + auto expected = makeRowVector( + {"id", "value"}, + { + makeFlatVector({1, 3, 5}), + makeFlatVector({"a", "c", "e"}), + }); + + assertEqualResults({expected}, {result}); +} + +/// Verifies that equality deletes are skipped when the delete file has a +/// lower or equal sequence number compared to the data file. +TEST_F(EqualityDeleteFileReaderTest, sequenceNumberDeleteSkipped) { + auto rowType = ROW({"id", "value"}, {BIGINT(), VARCHAR()}); + + auto baseData = makeRowVector( + {"id", "value"}, + { + makeFlatVector({1, 2, 3}), + makeFlatVector({"a", "b", "c"}), + }); + auto dataFile = writeDataFile({baseData}); + + auto deleteData = makeRowVector( + {"id"}, + { + makeFlatVector({1, 2, 3}), + }); + auto eqDeleteFile = writeEqDeleteFile({deleteData}); + + // Delete file has sequence number 2, data file has sequence number 5. + // Since deleteSeq (2) <= dataSeq (5), the delete should be skipped. + IcebergDeleteFile icebergDeleteFile( + FileContent::kEqualityDeletes, + eqDeleteFile->getPath(), + dwio::common::FileFormat::DWRF, + 3, + getFileSize(eqDeleteFile->getPath()), + /*equalityFieldIds=*/{1}, + /*lowerBounds=*/{}, + /*upperBounds=*/{}, + /*dataSequenceNumber=*/2); + + auto splits = makeSplits( + dataFile->getPath(), + {icebergDeleteFile}, + /*dataSequenceNumber=*/5); + auto plan = makeTableScanPlan(rowType); + auto result = AssertQueryBuilder(plan).splits(splits).copyResults(pool()); + + // All rows survive because the delete file is skipped. + auto expected = makeRowVector( + {"id", "value"}, + { + makeFlatVector({1, 2, 3}), + makeFlatVector({"a", "b", "c"}), + }); + + assertEqualResults({expected}, {result}); +} + +/// Verifies that equality deletes are skipped when the delete file has the +/// same sequence number as the data file (edge case of the <= check). +TEST_F(EqualityDeleteFileReaderTest, sequenceNumberEqualSkipped) { + auto rowType = ROW({"id", "value"}, {BIGINT(), VARCHAR()}); + + auto baseData = makeRowVector( + {"id", "value"}, + { + makeFlatVector({1, 2, 3}), + makeFlatVector({"a", "b", "c"}), + }); + auto dataFile = writeDataFile({baseData}); + + auto deleteData = makeRowVector( + {"id"}, + { + makeFlatVector({1, 2, 3}), + }); + auto eqDeleteFile = writeEqDeleteFile({deleteData}); + + // Delete file and data file have the same sequence number (5). + // Since deleteSeq (5) <= dataSeq (5), the delete should be skipped. + IcebergDeleteFile icebergDeleteFile( + FileContent::kEqualityDeletes, + eqDeleteFile->getPath(), + dwio::common::FileFormat::DWRF, + 3, + getFileSize(eqDeleteFile->getPath()), + /*equalityFieldIds=*/{1}, + /*lowerBounds=*/{}, + /*upperBounds=*/{}, + /*dataSequenceNumber=*/5); + + auto splits = makeSplits( + dataFile->getPath(), + {icebergDeleteFile}, + /*dataSequenceNumber=*/5); + auto plan = makeTableScanPlan(rowType); + auto result = AssertQueryBuilder(plan).splits(splits).copyResults(pool()); + + // All rows survive because the delete file is skipped (equal seq#). + auto expected = makeRowVector( + {"id", "value"}, + { + makeFlatVector({1, 2, 3}), + makeFlatVector({"a", "b", "c"}), + }); + + assertEqualResults({expected}, {result}); +} + +/// Verifies that when either sequence number is 0 (unassigned/legacy V1), +/// the delete file is always applied (filtering is disabled). +TEST_F(EqualityDeleteFileReaderTest, sequenceNumberZeroAlwaysApplies) { + auto rowType = ROW({"id"}, {BIGINT()}); + + auto baseData = makeRowVector( + {"id"}, + { + makeFlatVector({1, 2, 3}), + }); + auto dataFile = writeDataFile({baseData}); + + auto deleteData = makeRowVector( + {"id"}, + { + makeFlatVector({2}), + }); + auto eqDeleteFile = writeEqDeleteFile({deleteData}); + + // Delete file has sequence number 0 (legacy), data file has sequence 10. + // Since deleteSeq is 0, filtering is disabled and the delete applies. + IcebergDeleteFile icebergDeleteFile( + FileContent::kEqualityDeletes, + eqDeleteFile->getPath(), + dwio::common::FileFormat::DWRF, + 1, + getFileSize(eqDeleteFile->getPath()), + /*equalityFieldIds=*/{1}, + /*lowerBounds=*/{}, + /*upperBounds=*/{}, + /*dataSequenceNumber=*/0); + + auto splits = makeSplits( + dataFile->getPath(), + {icebergDeleteFile}, + /*dataSequenceNumber=*/10); + auto plan = makeTableScanPlan(rowType); + auto result = AssertQueryBuilder(plan).splits(splits).copyResults(pool()); + + // Row id=2 is deleted because sequence number filtering is disabled. + auto expected = makeRowVector( + {"id"}, + { + makeFlatVector({1, 3}), + }); + + assertEqualResults({expected}, {result}); +} + +/// Verifies that when multiple delete files have different sequence numbers, +/// only those with higher sequence numbers than the data file are applied. +TEST_F(EqualityDeleteFileReaderTest, mixedSequenceNumbers) { + auto rowType = ROW({"id", "value"}, {BIGINT(), VARCHAR()}); + + auto baseData = makeRowVector( + {"id", "value"}, + { + makeFlatVector({1, 2, 3, 4, 5}), + makeFlatVector({"a", "b", "c", "d", "e"}), + }); + auto dataFile = writeDataFile({baseData}); + + // First delete file: seqNum=10 (higher than data seqNum=5) → applied. + auto deleteData1 = makeRowVector( + {"id"}, + { + makeFlatVector({2}), + }); + auto eqDeleteFile1 = writeEqDeleteFile({deleteData1}); + IcebergDeleteFile icebergDeleteFile1( + FileContent::kEqualityDeletes, + eqDeleteFile1->getPath(), + dwio::common::FileFormat::DWRF, + 1, + getFileSize(eqDeleteFile1->getPath()), + /*equalityFieldIds=*/{1}, + /*lowerBounds=*/{}, + /*upperBounds=*/{}, + /*dataSequenceNumber=*/10); + + // Second delete file: seqNum=3 (lower than data seqNum=5) → skipped. + auto deleteData2 = makeRowVector( + {"id"}, + { + makeFlatVector({4}), + }); + auto eqDeleteFile2 = writeEqDeleteFile({deleteData2}); + IcebergDeleteFile icebergDeleteFile2( + FileContent::kEqualityDeletes, + eqDeleteFile2->getPath(), + dwio::common::FileFormat::DWRF, + 1, + getFileSize(eqDeleteFile2->getPath()), + /*equalityFieldIds=*/{1}, + /*lowerBounds=*/{}, + /*upperBounds=*/{}, + /*dataSequenceNumber=*/3); + + auto splits = makeSplits( + dataFile->getPath(), + {icebergDeleteFile1, icebergDeleteFile2}, + /*dataSequenceNumber=*/5); + auto plan = makeTableScanPlan(rowType); + auto result = AssertQueryBuilder(plan).splits(splits).copyResults(pool()); + + // Only id=2 is deleted (from delete file 1 with seqNum=10). + // id=4 survives because delete file 2 (seqNum=3) is skipped. + auto expected = makeRowVector( + {"id", "value"}, + { + makeFlatVector({1, 3, 4, 5}), + makeFlatVector({"a", "c", "d", "e"}), + }); + + assertEqualResults({expected}, {result}); +} + +// TODO: Add a Parquet-format equality delete test. Currently all equality +// delete tests use DWRF because writeToFile() (from HiveConnectorTestBase) +// only supports DWRF. Adding a Parquet test requires adding Parquet writer +// dependencies to this test target's BUCK file and a Parquet write helper. + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/tests/IcebergConnectorTest.cpp b/velox/connectors/hive/iceberg/tests/IcebergConnectorTest.cpp new file mode 100644 index 00000000000..248bccaaf3d --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/IcebergConnectorTest.cpp @@ -0,0 +1,72 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/IcebergConnector.h" +#include +#include "velox/connectors/ConnectorRegistry.h" +#include "velox/connectors/hive/HiveConfig.h" +#include "velox/connectors/hive/iceberg/tests/IcebergTestBase.h" + +namespace facebook::velox::connector::hive::iceberg { + +namespace { + +class IcebergConnectorTest : public test::IcebergTestBase { + protected: + static void resetIcebergConnector( + const std::shared_ptr& config) { + ConnectorRegistry::global().erase(test::kIcebergConnectorId); + + IcebergConnectorFactory factory; + auto icebergConnector = + factory.newConnector(test::kIcebergConnectorId, config); + ConnectorRegistry::global().insert( + icebergConnector->connectorId(), icebergConnector); + } +}; + +TEST_F(IcebergConnectorTest, connectorConfiguration) { + auto customConfig = std::make_shared( + std::unordered_map{ + {hive::HiveConfig::kEnableFileHandleCache, "true"}, + {hive::HiveConfig::kNumCacheFileHandles, "1000"}}); + + resetIcebergConnector(customConfig); + + // Verify connector was registered successfully with custom config. + auto icebergConnector = ConnectorRegistry::tryGet(test::kIcebergConnectorId); + ASSERT_NE(icebergConnector, nullptr); + + auto config = icebergConnector->connectorConfig(); + ASSERT_NE(config, nullptr); + + hive::HiveConfig hiveConfig(config); + ASSERT_TRUE(hiveConfig.isFileHandleCacheEnabled()); + ASSERT_EQ(hiveConfig.numCacheFileHandles(), 1000); +} + +TEST_F(IcebergConnectorTest, connectorProperties) { + auto icebergConnector = ConnectorRegistry::tryGet(test::kIcebergConnectorId); + ASSERT_NE(icebergConnector, nullptr); + + ASSERT_TRUE(icebergConnector->canAddDynamicFilter()); + ASSERT_TRUE(icebergConnector->supportsSplitPreload()); + ASSERT_NE(icebergConnector->ioExecutor(), nullptr); +} + +} // namespace + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/tests/IcebergDwrfInsertTest.cpp b/velox/connectors/hive/iceberg/tests/IcebergDwrfInsertTest.cpp new file mode 100644 index 00000000000..088f7b57832 --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/IcebergDwrfInsertTest.cpp @@ -0,0 +1,183 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/HiveConfig.h" +#include "velox/connectors/hive/iceberg/tests/IcebergTestBase.h" +#include "velox/dwio/dwrf/RegisterDwrfReader.h" +#include "velox/dwio/dwrf/RegisterDwrfWriter.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/PlanBuilder.h" + +using namespace facebook::velox::common::testutil; + +namespace facebook::velox::connector::hive::iceberg { +namespace { + +/// End-to-end tests for writing and reading Iceberg tables using the DWRF file +/// format. Exercises the full write path (IcebergDataSink -> DWRF writer) and +/// the full read path (IcebergSplitReader -> DWRF reader), verifying data +/// round-trip correctness. +class IcebergDwrfInsertTest : public test::IcebergTestBase { + protected: + void SetUp() override { + IcebergTestBase::SetUp(); + dwrf::registerDwrfReaderFactory(); + dwrf::registerDwrfWriterFactory(); + fileFormat_ = dwio::common::FileFormat::DWRF; + } + + /// Write test data using DWRF format, then read it back and verify results. + void test(const RowTypePtr& rowType, double nullRatio = 0.0) { + const auto outputDirectory = TempDirectoryPath::create(); + const auto dataPath = outputDirectory->getPath(); + constexpr int32_t numBatches = 10; + constexpr int32_t vectorSize = 5'000; + const auto vectors = + createTestData(rowType, numBatches, vectorSize, nullRatio); + const auto dataSink = createDataSinkAndAppendData(vectors, dataPath); + const auto commitTasks = dataSink->close(); + + auto splits = createSplitsForDirectory(dataPath); + ASSERT_EQ(splits.size(), commitTasks.size()); + auto plan = exec::test::PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .planNode(); + exec::test::AssertQueryBuilder(plan).splits(splits).assertResults(vectors); + } +}; + +TEST_F(IcebergDwrfInsertTest, basic) { + auto rowType = + ROW({"c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8"}, + {BIGINT(), + INTEGER(), + SMALLINT(), + BOOLEAN(), + REAL(), + VARCHAR(), + VARBINARY(), + DOUBLE()}); + test(rowType, 0.2); +} + +TEST_F(IcebergDwrfInsertTest, mapAndArray) { + auto rowType = + ROW({"c1", "c2"}, {MAP(INTEGER(), VARCHAR()), ARRAY(VARCHAR())}); + test(rowType); +} + +/// Verify the commit message maps DWRF format to "ORC" per Iceberg SDK +/// convention (Iceberg has no DWRF enum; DWRF files use the ORC identifier). +TEST_F(IcebergDwrfInsertTest, commitMessageFormat) { + const auto outputDirectory = TempDirectoryPath::create(); + const auto dataPath = outputDirectory->getPath(); + auto rowType = ROW({"c1", "c2"}, {BIGINT(), VARCHAR()}); + const auto vectors = createTestData(rowType, 2, 100); + const auto dataSink = createDataSinkAndAppendData(vectors, dataPath); + const auto commitTasks = dataSink->close(); + + ASSERT_GT(commitTasks.size(), 0); + for (const auto& task : commitTasks) { + auto taskJson = folly::parseJson(task); + ASSERT_TRUE(taskJson.count("fileFormat") > 0); + ASSERT_EQ(taskJson["fileFormat"].asString(), "ORC"); + } +} + +/// Round-trips TIMESTAMP values through the DWRF write path with the session +/// configured for non-UTC timezone and adjustTimestampToTimezone=true. The +/// Iceberg spec requires timestamps NOT be adjusted to UTC; the DataSink +/// enforces this by overriding the DWRF WriterOptions fields via +/// DwrfWriterOptionsAdapter::applyPostConfigs. +/// +/// TODO: This test is a symmetric Velox-only round-trip and cannot, by +/// itself, detect a regression where the DataSink stops overriding the DWRF +/// timezone fields — any write-side shift is exactly cancelled by the +/// matching read-side shift. The adapter's override contract is locked down +/// at the unit level by +/// WriterOptionsAdapterTest::dwrfPostConfigsOverridesTimestampFields. A +/// true cross-engine validation (e.g., reading Velox-written Iceberg files +/// with a Java Spark reader) is needed to verify the on-disk timestamp +/// matches the spec. +TEST_F(IcebergDwrfInsertTest, timestampRoundTrip) { + recreateConnectorQueryCtx( + /*sessionTimezone=*/"America/Los_Angeles", + /*adjustTimestampToTimezone=*/true); + auto rowType = ROW({"c1", "c2"}, {BIGINT(), TIMESTAMP()}); + test(rowType); +} + +/// End-to-end test for partitioned DWRF writes. Mirrors the identity- +/// transform partition coverage that exists for Parquet and exercises the +/// commitPartitionValue_ accounting on the DWRF code path. +TEST_F(IcebergDwrfInsertTest, partitioned) { + auto rowType = ROW({"c1", "c2"}, {BIGINT(), VARCHAR()}); + const auto outputDirectory = TempDirectoryPath::create(); + const auto dataPath = outputDirectory->getPath(); + const auto vectors = createTestData(rowType, 2, 50, 0.2); + + std::vector partitionTransforms = { + {0, TransformType::kIdentity, std::nullopt}}; + const auto dataSink = + createDataSinkAndAppendData(vectors, dataPath, partitionTransforms); + const auto commitTasks = dataSink->close(); + ASSERT_GT(commitTasks.size(), 0); + + for (const auto& task : commitTasks) { + auto taskJson = folly::parseJson(task); + ASSERT_EQ(taskJson["fileFormat"].asString(), "ORC"); + EXPECT_GT(taskJson.count("partitionDataJson"), 0); + } + + auto splits = createSplitsForDirectory(dataPath); + ASSERT_EQ(splits.size(), commitTasks.size()); + auto plan = exec::test::PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .planNode(); + exec::test::AssertQueryBuilder(plan).splits(splits).assertResults(vectors); +} + +/// Regression test for the isPartitioned() guard added to ensureWriter(). +/// Without the guard, calling ensureWriter() on a non-partitioned table +/// invoked makeCommitPartitionValue(), which dereferences +/// partitionIdGenerator_ — null for unpartitioned tables — causing a crash. +/// Exercises the unpartitioned write path explicitly so any future +/// regression is caught with a named test. +TEST_F(IcebergDwrfInsertTest, ensureWriterNonPartitioned) { + auto rowType = ROW({"c1", "c2"}, {BIGINT(), VARCHAR()}); + const auto outputDirectory = TempDirectoryPath::create(); + const auto dataPath = outputDirectory->getPath(); + const auto vectors = createTestData(rowType, 1, 50); + + // No partitionFields => unpartitioned table, partitionIdGenerator_ stays + // null inside the sink. appendData triggers ensureWriter(). + const auto dataSink = createDataSinkAndAppendData(vectors, dataPath); + const auto commitTasks = dataSink->close(); + + ASSERT_EQ(commitTasks.size(), 1); + auto taskJson = folly::parseJson(commitTasks[0]); + // Unpartitioned tables must not emit partitionDataJson. + EXPECT_EQ(taskJson.count("partitionDataJson"), 0); +} + +} // namespace +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/tests/IcebergInsertTest.cpp b/velox/connectors/hive/iceberg/tests/IcebergInsertTest.cpp new file mode 100644 index 00000000000..f7fbee5efe1 --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/IcebergInsertTest.cpp @@ -0,0 +1,272 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/HiveConfig.h" +#include "velox/connectors/hive/iceberg/IcebergConnector.h" +#include "velox/connectors/hive/iceberg/tests/IcebergTestBase.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/PlanBuilder.h" + +using namespace facebook::velox::common::testutil; + +namespace facebook::velox::connector::hive::iceberg { +namespace { + +#ifdef VELOX_ENABLE_PARQUET + +class IcebergInsertTest : public test::IcebergTestBase { + protected: + void test(const RowTypePtr& rowType, double nullRatio = 0.0) { + const auto outputDirectory = TempDirectoryPath::create(); + const auto dataPath = outputDirectory->getPath(); + constexpr int32_t numBatches = 10; + constexpr int32_t vectorSize = 5'000; + const auto vectors = + createTestData(rowType, numBatches, vectorSize, nullRatio); + const auto dataSink = createDataSinkAndAppendData(vectors, dataPath); + const auto commitTasks = dataSink->close(); + + auto splits = createSplitsForDirectory(dataPath); + ASSERT_EQ(splits.size(), commitTasks.size()); + auto plan = exec::test::PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .planNode(); + exec::test::AssertQueryBuilder(plan).splits(splits).assertResults(vectors); + } +}; + +TEST_F(IcebergInsertTest, basic) { + auto rowType = + ROW({"c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10", "c11"}, + {BIGINT(), + INTEGER(), + SMALLINT(), + BOOLEAN(), + REAL(), + DECIMAL(18, 5), + VARCHAR(), + VARBINARY(), + DATE(), + TIMESTAMP(), + ROW({"id", "name"}, {INTEGER(), VARCHAR()})}); + test(rowType, 0.2); +} + +TEST_F(IcebergInsertTest, mapAndArray) { + auto rowType = + ROW({"c1", "c2"}, {MAP(INTEGER(), VARCHAR()), ARRAY(VARCHAR())}); + test(rowType); +} + +TEST_F(IcebergInsertTest, bigDecimal) { + auto rowType = ROW({"c1"}, {DECIMAL(38, 5)}); + fileFormat_ = dwio::common::FileFormat::PARQUET; + test(rowType); +} + +TEST_F(IcebergInsertTest, singleColumnPartition) { + struct TestCase { + std::string name; + TypePtr type; + }; + + std::vector testCases = { + {"c1", BIGINT()}, + {"c2", INTEGER()}, + {"c3", SMALLINT()}, + {"c4", DECIMAL(18, 5)}, + {"c5", BOOLEAN()}, + {"c6", VARCHAR()}, + {"c7", DATE()}, + {"c8", TIMESTAMP()}}; + + for (const auto& testCase : testCases) { + const auto outputDirectory = TempDirectoryPath::create(); + constexpr int32_t numBatches = 2; + constexpr int32_t vectorSize = 50; + auto rowType = ROW({testCase.name}, {testCase.type}); + + const auto vectors = createTestData(rowType, numBatches, vectorSize, 0.5); + std::vector partitionTransforms = { + {0, TransformType::kIdentity, std::nullopt}}; + const auto dataSink = createDataSinkAndAppendData( + vectors, outputDirectory->getPath(), partitionTransforms); + const auto commitTasks = dataSink->close(); + auto splits = createSplitsForDirectory(outputDirectory->getPath()); + + ASSERT_GT(commitTasks.size(), 0); + ASSERT_EQ(splits.size(), commitTasks.size()); + + for (const auto& task : commitTasks) { + auto taskJson = folly::parseJson(task); + ASSERT_TRUE(taskJson.count("partitionDataJson") > 0); + } + + auto plan = exec::test::PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .planNode(); + exec::test::AssertQueryBuilder(plan).splits(splits).assertResults(vectors); + } +} + +TEST_F(IcebergInsertTest, partitionNullColumn) { + struct TestCase { + std::string name; + TypePtr type; + }; + + std::vector testCases = { + {"c1", BIGINT()}, + {"c2", INTEGER()}, + {"c3", SMALLINT()}, + {"c4", DECIMAL(18, 5)}, + {"c5", BOOLEAN()}, + {"c6", VARCHAR()}, + {"c7", DATE()}, + {"c8", TIMESTAMP()}}; + + for (const auto& testCase : testCases) { + const auto outputDirectory = TempDirectoryPath::create(); + constexpr int32_t numBatches = 2; + constexpr int32_t vectorSize = 100; + auto rowType = ROW({testCase.name}, {testCase.type}); + // nullRatio = 1.0 + const auto vectors = createTestData(rowType, numBatches, vectorSize, 1.0); + + std::vector partitionTransforms = { + {0, TransformType::kIdentity, std::nullopt}}; + const auto dataSink = createDataSinkAndAppendData( + vectors, outputDirectory->getPath(), partitionTransforms); + + const auto commitTasks = dataSink->close(); + ASSERT_EQ(1, commitTasks.size()); + auto taskJson = folly::parseJson(commitTasks.at(0)); + ASSERT_EQ(1, taskJson.count("partitionDataJson")); + auto partitionData = + folly::parseJson(taskJson["partitionDataJson"].asString()); + ASSERT_EQ(1, partitionData.count("partitionValues")); + auto partitionValues = partitionData["partitionValues"]; + ASSERT_TRUE(partitionValues.isArray()); + ASSERT_TRUE(partitionValues[0].isNull()); + + auto files = listFiles(outputDirectory->getPath()); + ASSERT_EQ(files.size(), 1); + + for (const auto& file : files) { + auto partitionKeys = extractPartitionKeys(file); + ASSERT_EQ(partitionKeys.size(), 1); + ASSERT_TRUE(partitionKeys.contains(testCase.name)); + ASSERT_FALSE(partitionKeys.at(testCase.name).has_value()); + } + } +} + +TEST_F(IcebergInsertTest, partitionMultiColumns) { + auto rowType = + ROW({"c1", "c2", "c3", "c4"}, + { + BIGINT(), + INTEGER(), + SMALLINT(), + DECIMAL(18, 5), + }); + std::vector> columnCombinations = { + {0, 1}, // BIGINT, INTEGER. + {2, 1}, // SMALLINT, INTEGER. + {2, 3}, // SMALLINT, DECIMAL. + {0, 2, 1} // BIGINT, SMALLINT, INTEGER. + }; + + for (const auto& combination : columnCombinations) { + const auto outputDirectory = TempDirectoryPath::create(); + constexpr int32_t numBatches = 2; + constexpr int32_t vectorSize = 50; + + std::vector vectors; + vectors.reserve(numBatches); + for (int32_t batch = 0; batch < numBatches; ++batch) { + vectors.push_back(makeRowVector( + rowType->names(), + { + makeFlatVector( + vectorSize, [](auto row) { return row * 100; }), + makeFlatVector( + vectorSize, [](auto row) { return row * 10; }), + makeFlatVector(vectorSize, [](auto row) { return row; }), + makeFlatVector( + vectorSize, + [](auto row) { return (row * 1000); }, + nullptr, + DECIMAL(18, 5)), + })); + } + + std::vector partitionTransforms; + for (auto colIndex : combination) { + partitionTransforms.push_back( + {colIndex, TransformType::kIdentity, std::nullopt}); + } + + const auto dataSink = createDataSinkAndAppendData( + vectors, outputDirectory->getPath(), partitionTransforms); + + const auto commitTasks = dataSink->close(); + auto splits = createSplitsForDirectory(outputDirectory->getPath()); + + ASSERT_EQ(commitTasks.size(), vectorSize); + ASSERT_EQ(splits.size(), commitTasks.size()); + + auto plan = exec::test::PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .planNode(); + exec::test::AssertQueryBuilder(plan).splits(splits).assertResults(vectors); + } +} + +TEST_F(IcebergInsertTest, maxTargetFileSizeRotation) { + setConnectorSessionProperty(HiveConfig::kMaxTargetFileSizeSession, "4KB"); + + const auto outputPath = TempDirectoryPath::create()->getPath(); + const auto rowType = ROW({"c0", "c1"}, {BIGINT(), VARCHAR()}); + const auto vectors = createTestData(rowType, 10, 1'000); + const auto dataSink = createDataSinkAndAppendData(vectors, outputPath); + const auto commitTasks = dataSink->close(); + + ASSERT_EQ(listFiles(outputPath).size(), 5); + + auto splits = createSplitsForDirectory(outputPath); + auto plan = exec::test::PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .planNode(); + exec::test::AssertQueryBuilder(plan).splits(splits).assertResults(vectors); +} + +#endif + +} // namespace +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/tests/IcebergParquetStatsTest.cpp b/velox/connectors/hive/iceberg/tests/IcebergParquetStatsTest.cpp new file mode 100644 index 00000000000..ebab93a2510 --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/IcebergParquetStatsTest.cpp @@ -0,0 +1,881 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "velox/common/encode/Base64.h" +#include "velox/connectors/hive/iceberg/IcebergDataFileStatistics.h" +#include "velox/connectors/hive/iceberg/tests/IcebergTestBase.h" + +using namespace facebook::velox::common::testutil; + +namespace facebook::velox::connector::hive::iceberg { + +namespace { + +#ifdef VELOX_ENABLE_PARQUET + +class IcebergParquetStatsTest : public test::IcebergTestBase { + protected: + static IcebergDataFileStatisticsPtr statsFromMetrics( + const folly::dynamic& metrics) { + VELOX_CHECK(metrics.isObject()); + VELOX_CHECK(metrics.count("recordCount") > 0); + auto stats = std::make_shared(); + stats->numRecords = metrics["recordCount"].asInt(); + + auto setIntField = [&](const folly::dynamic& map, auto setter) { + if (!map.isObject()) { + return; + } + for (const auto& item : map.items()) { + const auto fieldId = folly::to(item.first.asString()); + auto& column = stats->columnStats[fieldId]; + setter(column, item.second); + } + }; + + setIntField(metrics["columnSizes"], [](auto& column, const auto& value) { + column.columnSize = value.asInt(); + }); + setIntField(metrics["valueCounts"], [](auto& column, const auto& value) { + column.valueCount = value.asInt(); + }); + setIntField( + metrics["nullValueCounts"], [](auto& column, const auto& value) { + column.nullValueCount = value.asInt(); + }); + setIntField(metrics["nanValueCounts"], [](auto& column, const auto& value) { + column.nanValueCount = value.asInt(); + }); + + const auto& lowerBounds = metrics["lowerBounds"]; + if (lowerBounds.isObject()) { + for (const auto& item : lowerBounds.items()) { + const auto fieldId = folly::to(item.first.asString()); + stats->columnStats[fieldId].lowerBound = item.second.asString(); + } + } + const auto& upperBounds = metrics["upperBounds"]; + if (upperBounds.isObject()) { + for (const auto& item : upperBounds.items()) { + const auto fieldId = folly::to(item.first.asString()); + stats->columnStats[fieldId].upperBound = item.second.asString(); + } + } + + return stats; + } + + static std::vector statsFromCommitTasks( + const std::vector& commitTasks) { + std::vector stats; + stats.reserve(commitTasks.size()); + for (const auto& task : commitTasks) { + auto taskJson = folly::parseJson(task); + VELOX_CHECK(taskJson.isObject()); + VELOX_CHECK(taskJson.count("metrics") > 0); + stats.emplace_back(statsFromMetrics(taskJson["metrics"])); + } + return stats; + } + + // Write data and get all stats (for partitioned tables). + std::vector> + writeDataAndGetAllStats( + const RowVectorPtr& data, + const std::vector& partitionFields = {}) { + const auto outputDir = TempDirectoryPath::create(); + auto dataSink = createDataSinkAndAppendData( + {data}, outputDir->getPath(), partitionFields); + auto commitTasks = dataSink->close(); + EXPECT_FALSE(commitTasks.empty()); + return statsFromCommitTasks(commitTasks); + } + + // Decode and extract typed value from base64 encoded bounds. + template + static std::pair decodeBounds( + const std::shared_ptr& stats, + int32_t fieldId) { + auto decode = [](const std::string& base64Encoded) { + const std::string decoded = encoding::Base64::decode(base64Encoded); + T value; + std::memcpy(&value, decoded.data(), sizeof(T)); + return value; + }; + + const auto& columnStats = stats->columnStats.at(fieldId); + VELOX_CHECK(columnStats.lowerBound.has_value()); + VELOX_CHECK(columnStats.upperBound.has_value()); + return { + decode(columnStats.lowerBound.value()), + decode(columnStats.upperBound.value()), + }; + } + + // Verify basic statistics (record count, value counts, null counts). + static void verifyBasicStats( + const std::shared_ptr& stats, + int64_t expectedRecords, + const std::unordered_map& expectedValueCounts, + const std::unordered_map& expectedNullCounts) { + EXPECT_EQ(stats->numRecords, expectedRecords); + + for (const auto& [fieldId, count] : expectedValueCounts) { + ASSERT_TRUE(stats->columnStats.contains(fieldId)); + EXPECT_EQ(stats->columnStats.at(fieldId).valueCount, count); + } + + if (!expectedNullCounts.empty()) { + for (const auto& [fieldId, count] : expectedNullCounts) { + ASSERT_TRUE(stats->columnStats.contains(fieldId)); + EXPECT_EQ(stats->columnStats.at(fieldId).nullValueCount, count); + } + } + } + + // Verify bounds exist for given field IDs. + static void verifyBoundsExist( + const std::shared_ptr& stats, + const std::vector& fieldIds) { + for (const int32_t fieldId : fieldIds) { + ASSERT_TRUE(stats->columnStats.contains(fieldId)); + const auto& columnStats = stats->columnStats.at(fieldId); + ASSERT_TRUE(columnStats.lowerBound.has_value()); + ASSERT_TRUE(columnStats.upperBound.has_value()); + EXPECT_FALSE(columnStats.lowerBound.value().empty()); + EXPECT_FALSE(columnStats.upperBound.value().empty()); + } + } + + // Verify bounds do not exist for given field IDs. + static void verifyBoundsNotExist( + const std::shared_ptr& stats, + const std::vector& fieldIds) { + for (const int32_t fieldId : fieldIds) { + if (stats->columnStats.contains(fieldId)) { + const auto& columnStats = stats->columnStats.at(fieldId); + ASSERT_FALSE(columnStats.lowerBound.has_value()); + ASSERT_FALSE(columnStats.upperBound.has_value()); + } + } + } +}; + +TEST_F(IcebergParquetStatsTest, mixedNull) { + constexpr vector_size_t size = 100; + constexpr int32_t expectedIntNulls = 34; + constexpr int32_t intColId = 1; + + const auto& stats = + writeDataAndGetAllStats(makeRowVector({makeFlatVector( + size, [](vector_size_t row) { return row * 10; }, nullEvery(3))})); + verifyBasicStats( + stats[0], size, {{intColId, size}}, {{intColId, expectedIntNulls}}); + verifyBoundsExist(stats[0], {intColId}); + + const auto& [minVal, maxVal] = decodeBounds(stats[0], intColId); + EXPECT_EQ(minVal, 10); + EXPECT_EQ(maxVal, 980); +} + +TEST_F(IcebergParquetStatsTest, bigint) { + constexpr vector_size_t size = 100; + constexpr int32_t expectedNulls = 25; + constexpr int32_t bigintColId = 1; + + const auto& stats = + writeDataAndGetAllStats(makeRowVector({makeFlatVector( + size, + [](vector_size_t row) { return row * 1'000'000'000LL; }, + nullEvery(4))})); + verifyBasicStats( + stats[0], size, {{bigintColId, size}}, {{bigintColId, expectedNulls}}); + verifyBoundsExist(stats[0], {bigintColId}); + + const auto& [minVal, maxVal] = decodeBounds(stats[0], bigintColId); + EXPECT_EQ(minVal, 1'000'000'000LL); + EXPECT_EQ(maxVal, 99'000'000'000LL); +} + +TEST_F(IcebergParquetStatsTest, decimal) { + constexpr vector_size_t size = 100; + constexpr int32_t expectedNulls = 20; + constexpr int32_t decimalColId = 1; + + const auto& stats = + writeDataAndGetAllStats(makeRowVector({makeFlatVector( + size, + [](vector_size_t row) { return HugeInt::build(row, row * 123); }, + nullEvery(5), + DECIMAL(38, 3))})); + verifyBasicStats( + stats[0], size, {{decimalColId, size}}, {{decimalColId, expectedNulls}}); + verifyBoundsExist(stats[0], {decimalColId}); +} + +TEST_F(IcebergParquetStatsTest, varchar) { + constexpr vector_size_t size = 100; + constexpr int32_t varcharColId = 1; + + const auto& stats = + writeDataAndGetAllStats(makeRowVector({makeFlatVector( + size, + [](vector_size_t row) { + return "Customer#00000" + std::to_string(row) + "_" + + std::string(row % 10, 'a'); + }, + nullEvery(6))})); + + constexpr int32_t expectedNulls = 17; + verifyBasicStats( + stats[0], size, {{varcharColId, size}}, {{varcharColId, expectedNulls}}); + verifyBoundsExist(stats[0], {varcharColId}); + + EXPECT_EQ( + encoding::Base64::decode( + stats[0]->columnStats.at(varcharColId).lowerBound.value()), + "Customer#0000010"); + EXPECT_EQ( + encoding::Base64::decode( + stats[0]->columnStats.at(varcharColId).upperBound.value()), + "Customer#000009`"); +} + +TEST_F(IcebergParquetStatsTest, varbinary) { + constexpr vector_size_t size = 100; + constexpr int32_t varbinaryColId = 1; + + auto rowVector = makeRowVector({makeFlatVector( + size, + [](vector_size_t row) { + std::string value(17, 11); + value[0] = static_cast(row % 256); + value[1] = static_cast((row * 3) % 256); + value[2] = static_cast((row * 7) % 256); + value[3] = static_cast((row * 11) % 256); + return value; + }, + nullEvery(5), + VARBINARY())}); + + const auto& stats = writeDataAndGetAllStats(rowVector); + constexpr int32_t expectedNulls = 20; + verifyBasicStats( + stats[0], + size, + {{varbinaryColId, size}}, + {{varbinaryColId, expectedNulls}}); + verifyBoundsExist(stats[0], {varbinaryColId}); +} + +TEST_F(IcebergParquetStatsTest, varbinaryWithTransform) { + const auto& fileStats = writeDataAndGetAllStats( + makeRowVector({makeFlatVector( + {"01020304", + "05060708", + "090A0B0C", + "0D0E0F10", + "11121314", + "15161718", + "191A1B1C", + "1D1E1F20", + "21222324", + "25262728"}, + VARBINARY())}), + {{0, TransformType::kBucket, 4}}); + ASSERT_EQ(fileStats.size(), 3); + const auto& stats = fileStats[0]; + EXPECT_EQ(stats->numRecords, 5); + constexpr int32_t varbinaryColId = 1; + EXPECT_EQ(stats->columnStats.at(varbinaryColId).valueCount, 5); +} + +TEST_F(IcebergParquetStatsTest, multipleDataTypes) { + constexpr vector_size_t size = 100; + constexpr int32_t intColId = 1; + constexpr int32_t bigintColId = 2; + constexpr int32_t decimalColId = 3; + constexpr int32_t varcharColId = 4; + constexpr int32_t varbinaryColId = 5; + + constexpr int32_t expectedIntNulls = 34; + constexpr int32_t expectedBigintNulls = 25; + constexpr int32_t expectedDecimalNulls = 20; + constexpr int32_t expectedVarcharNulls = 17; + constexpr int32_t expectedVarbinaryNulls = 15; + + auto rowVector = makeRowVector( + {makeFlatVector( + size, [](vector_size_t row) { return row * 10; }, nullEvery(3)), + makeFlatVector( + size, + [](vector_size_t row) { return row * 1'000'000'000LL; }, + nullEvery(4)), + makeFlatVector( + size, + [](vector_size_t row) { return HugeInt::build(row, row * 12'345); }, + nullEvery(5), + DECIMAL(38, 3)), + makeFlatVector( + size, + [](vector_size_t row) { return "str_" + std::to_string(row); }, + nullEvery(6)), + makeFlatVector( + size, + [](vector_size_t row) { + std::string value(4, 0); + value[0] = static_cast(row % 256); + value[1] = static_cast((row * 3) % 256); + value[2] = static_cast((row * 7) % 256); + value[3] = static_cast((row * 11) % 256); + return value; + }, + nullEvery(7), + VARBINARY())}); + const auto& stats = writeDataAndGetAllStats(rowVector); + + verifyBasicStats( + stats[0], + size, + { + {intColId, size}, + {bigintColId, size}, + {decimalColId, size}, + {varcharColId, size}, + {varbinaryColId, size}, + }, + { + {intColId, expectedIntNulls}, + {bigintColId, expectedBigintNulls}, + {decimalColId, expectedDecimalNulls}, + {varcharColId, expectedVarcharNulls}, + {varbinaryColId, expectedVarbinaryNulls}, + }); + + verifyBoundsExist( + stats[0], + {intColId, bigintColId, decimalColId, varcharColId, varbinaryColId}); +} + +TEST_F(IcebergParquetStatsTest, date) { + constexpr vector_size_t size = 100; + constexpr int32_t expectedNulls = 20; + constexpr int32_t dateColId = 1; + + const auto& stats = + writeDataAndGetAllStats(makeRowVector({makeFlatVector( + size, + [](vector_size_t row) { return 18262 + row; }, + nullEvery(5), + DATE())})); + verifyBasicStats( + stats[0], size, {{dateColId, size}}, {{dateColId, expectedNulls}}); + verifyBoundsExist(stats[0], {dateColId}); + + const auto& [minVal, maxVal] = decodeBounds(stats[0], dateColId); + EXPECT_EQ(minVal, 18263); + EXPECT_EQ(maxVal, 18262 + 99); +} + +TEST_F(IcebergParquetStatsTest, boolean) { + constexpr vector_size_t size = 100; + constexpr int32_t expectedNulls = 10; + constexpr int32_t boolColId = 1; + + const auto& stats = + writeDataAndGetAllStats(makeRowVector({makeFlatVector( + size, + [](vector_size_t row) { return row % 2 == 1; }, + nullEvery(10), + BOOLEAN())})); + verifyBasicStats( + stats[0], size, {{boolColId, size}}, {{boolColId, expectedNulls}}); + verifyBoundsExist(stats[0], {boolColId}); + + // For boolean, the lower bound should be false (0) and upper bound should be + // true (1) if both values are present. + const auto& [minVal, maxVal] = decodeBounds(stats[0], boolColId); + EXPECT_FALSE(minVal); + EXPECT_TRUE(maxVal); +} + +TEST_F(IcebergParquetStatsTest, empty) { + const auto outputDir = TempDirectoryPath::create(); + auto dataSink = createDataSinkAndAppendData( + {makeRowVector( + {makeFlatVector(0), makeFlatVector(0)})}, + outputDir->getPath()); + auto commitTasks = dataSink->close(); + EXPECT_TRUE(commitTasks.empty()); +} + +TEST_F(IcebergParquetStatsTest, nullValues) { + constexpr vector_size_t size = 100; + + const auto& stats = writeDataAndGetAllStats(makeRowVector( + {makeNullConstant(TypeKind::INTEGER, size), + makeNullConstant(TypeKind::VARCHAR, size)})); + EXPECT_EQ(stats[0]->numRecords, size); + ASSERT_EQ(stats[0]->columnStats.at(1).nullValueCount, size); + // Do not collect lower and upper bounds for NULLs. + for (const auto& [fieldId, columnStats] : stats[0]->columnStats) { + ASSERT_FALSE(columnStats.lowerBound.has_value()); + ASSERT_FALSE(columnStats.upperBound.has_value()); + } +} + +TEST_F(IcebergParquetStatsTest, real) { + constexpr vector_size_t size = 100; + constexpr int32_t expectedNulls = 20; + constexpr int32_t realColId = 1; + int32_t expectedNaNs = 0; + + const auto& stats = + writeDataAndGetAllStats(makeRowVector({makeFlatVector( + size, + [&expectedNaNs](vector_size_t row) { + if (row % 6 == 0) { + expectedNaNs++; + return std::numeric_limits::quiet_NaN(); + } + return row * 1.5f; + }, + nullEvery(5), + REAL())})); + verifyBasicStats( + stats[0], size, {{realColId, size}}, {{realColId, expectedNulls}}); + + EXPECT_EQ( + stats[0]->columnStats.at(realColId).nanValueCount.value_or(0), + expectedNaNs); + verifyBoundsExist(stats[0], {realColId}); + const auto& [minVal, maxVal] = decodeBounds(stats[0], realColId); + EXPECT_FLOAT_EQ(minVal, 1.5f); + EXPECT_FLOAT_EQ(maxVal, 148.5f); +} + +TEST_F(IcebergParquetStatsTest, double) { + constexpr vector_size_t size = 100; + constexpr int32_t expectedNulls = 15; + constexpr int32_t doubleColId = 1; + int32_t expectedNaNs = 0; + + auto rowVector = makeRowVector({makeFlatVector( + size, + [&expectedNaNs](vector_size_t row) { + if (row % 3 == 0) { + expectedNaNs++; + return std::numeric_limits::quiet_NaN(); + } + if (row % 4 == 0) { + return std::numeric_limits::infinity(); + } + if (row % 5 == 0) { + return -std::numeric_limits::infinity(); + } + return row * 2.5; + }, + nullEvery(7), + DOUBLE())}); + + const auto& stats = writeDataAndGetAllStats(rowVector); + verifyBasicStats( + stats[0], size, {{doubleColId, size}}, {{doubleColId, expectedNulls}}); + + EXPECT_EQ( + stats[0]->columnStats.at(doubleColId).nanValueCount.value_or(0), + expectedNaNs); + + verifyBoundsExist(stats[0], {doubleColId}); + + // Verify bounds are set correctly and NaN/infinity values don't affect + // min/max incorrectly. + const auto& [minVal, maxVal] = decodeBounds(stats[0], doubleColId); + EXPECT_DOUBLE_EQ(minVal, -std::numeric_limits::infinity()) + << "Lower bound should be -infinity"; + EXPECT_DOUBLE_EQ(maxVal, std::numeric_limits::infinity()) + << "Upper bound should be infinity"; +} + +TEST_F(IcebergParquetStatsTest, mixedDoubleFloat) { + constexpr vector_size_t size = 6; + + auto rowVector = makeRowVector( + {makeFlatVector(size, [](vector_size_t row) { return 1; }), + makeFlatVector( + size, + [](vector_size_t row) { + return -std::numeric_limits::infinity(); + }), + makeFlatVector( + size, + [](vector_size_t row) { + return std::numeric_limits::infinity(); + }), + makeFlatVector(size, [](vector_size_t row) { + switch (row) { + case 0: + return 1.23; + case 1: + return -1.23; + case 2: + return std::numeric_limits::infinity(); + case 3: + return 2.23; + case 4: + return -std::numeric_limits::infinity(); + default: + return -2.23; + } + })}); + + const auto& stats = writeDataAndGetAllStats(rowVector); + constexpr int32_t doubleColId = 4; + verifyBasicStats(stats[0], size, {{doubleColId, size}}, {{doubleColId, 0}}); + const auto& [minVal, maxVal] = decodeBounds(stats[0], doubleColId); + EXPECT_DOUBLE_EQ(minVal, -std::numeric_limits::infinity()); + EXPECT_DOUBLE_EQ(maxVal, std::numeric_limits::infinity()); + + constexpr int32_t floatColId = 2; + const auto& [minFloatVal, maxFloatVal] = + decodeBounds(stats[0], floatColId); + EXPECT_FLOAT_EQ(minFloatVal, -std::numeric_limits::infinity()); + EXPECT_FLOAT_EQ(maxFloatVal, -std::numeric_limits::infinity()); +} + +TEST_F(IcebergParquetStatsTest, NaN) { + constexpr vector_size_t size = 1'000; + constexpr int32_t expectedNulls = 500; + constexpr int32_t doubleColId = 1; + int32_t expectedNaNs = 0; + + const auto& stats = + writeDataAndGetAllStats(makeRowVector({makeFlatVector( + size, + [&expectedNaNs](vector_size_t /*row*/) { + expectedNaNs++; + return std::numeric_limits::quiet_NaN(); + }, + nullEvery(2), + DOUBLE())})); + verifyBasicStats( + stats[0], size, {{doubleColId, size}}, {{doubleColId, expectedNulls}}); + + EXPECT_EQ( + stats[0]->columnStats.at(doubleColId).nanValueCount.value_or(0), + expectedNaNs); + // Do not collect bounds for NULLs and NaNs. + for (const auto& [fieldId, columnStats] : stats[0]->columnStats) { + ASSERT_FALSE(columnStats.lowerBound.has_value()); + ASSERT_FALSE(columnStats.upperBound.has_value()); + } +} + +TEST_F(IcebergParquetStatsTest, partitionedTable) { + std::vector partitionTransforms = { + {0, TransformType::kBucket, 4}, + {1, TransformType::kDay, std::nullopt}, + {2, TransformType::kTruncate, 2}, + }; + + constexpr vector_size_t size = 100; + + auto rowVector = makeRowVector( + {makeFlatVector(size, [](vector_size_t row) { return row; }), + makeFlatVector( + size, + [](vector_size_t row) { return 18262 + (row % 5); }, + nullptr, + DATE()), + makeFlatVector(size, [](vector_size_t row) { + return fmt::format("str{}", row % 10); + })}); + + const auto& fileStats = + writeDataAndGetAllStats(rowVector, partitionTransforms); + + EXPECT_GT(fileStats.size(), 1) + << "Expected multiple files due to partitioning"; + + for (const auto& stats : fileStats) { + EXPECT_GT(stats->numRecords, 0); + ASSERT_FALSE(stats->columnStats.empty()); + + constexpr int32_t intColId = 1; + constexpr int32_t dateColId = 2; + constexpr int32_t varcharColId = 3; + EXPECT_EQ(stats->columnStats.at(intColId).valueCount, stats->numRecords); + EXPECT_EQ(stats->columnStats.at(dateColId).valueCount, stats->numRecords); + EXPECT_EQ( + stats->columnStats.at(varcharColId).valueCount, stats->numRecords); + + for (const auto fieldId : {intColId, dateColId, varcharColId}) { + const auto& columnStats = stats->columnStats.at(fieldId); + ASSERT_TRUE(columnStats.lowerBound.has_value()); + ASSERT_TRUE(columnStats.upperBound.has_value()); + EXPECT_FALSE(columnStats.lowerBound.value().empty()); + EXPECT_FALSE(columnStats.upperBound.value().empty()); + } + } + + // Verify total record count across all partitions. + int64_t totalRecords = 0; + for (const auto& stats : fileStats) { + totalRecords += stats->numRecords; + } + EXPECT_EQ(totalRecords, size); +} + +TEST_F(IcebergParquetStatsTest, multiplePartitionTransforms) { + std::vector partitionTransforms = { + {0, TransformType::kBucket, 2}, + {1, TransformType::kYear, std::nullopt}, + {2, TransformType::kTruncate, 3}, + {3, TransformType::kIdentity, std::nullopt}}; + + constexpr vector_size_t size = 100; + + auto rowVector = makeRowVector( + {makeFlatVector( + size, [](vector_size_t row) { return row * 10; }), + makeFlatVector( + size, + [](vector_size_t row) { return 18262 + (row * 100); }, + nullptr, + DATE()), + makeFlatVector( + size, + [](vector_size_t row) { + return fmt::format("prefix{}_value", row % 5); + }), + makeFlatVector( + size, [](vector_size_t row) { return (row % 3) * 1'000; })}); + + const auto& fileStats = + writeDataAndGetAllStats(rowVector, partitionTransforms); + EXPECT_GT(fileStats.size(), 1) + << "Expected multiple files due to partitioning"; + // Check each file's stats. + for (const auto& stats : fileStats) { + EXPECT_GT(stats->numRecords, 0); + constexpr int32_t intColId = 1; + constexpr int32_t dateColId = 2; + constexpr int32_t bigintColId = 4; + + if (stats->columnStats.contains(intColId)) { + const auto& [minVal, maxVal] = decodeBounds(stats, intColId); + EXPECT_LE(minVal, maxVal) + << "Lower bound should be <= upper bound for int column"; + } + + if (stats->columnStats.contains(dateColId)) { + const auto& [minVal, maxVal] = decodeBounds(stats, dateColId); + EXPECT_LE(minVal, maxVal) + << "Lower bound should be <= upper bound for date column"; + } + + if (stats->columnStats.contains(bigintColId)) { + const auto& [minVal, maxVal] = decodeBounds(stats, bigintColId); + EXPECT_LE(minVal, maxVal) + << "Lower bound should be <= upper bound for bigint column"; + } + } + int64_t totalRecords = 0; + for (const auto& stats : fileStats) { + totalRecords += stats->numRecords; + } + EXPECT_EQ(totalRecords, size); +} + +TEST_F(IcebergParquetStatsTest, partitionedTableWithNulls) { + constexpr vector_size_t size = 100; + constexpr int32_t expectedIntNulls = 20; + constexpr int32_t expectedDateNulls = 15; + constexpr int32_t expectedVarcharNulls = 10; + + std::vector partitionTransforms = { + {0, TransformType::kIdentity, std::nullopt}, + {1, TransformType::kMonth, std::nullopt}, + {2, TransformType::kTruncate, 2}}; + auto rowVector = makeRowVector( + {makeFlatVector( + size, + [](vector_size_t row) { return row % 10; }, + nullEvery(5), + INTEGER()), + makeFlatVector( + size, + [](vector_size_t row) { return 18262 + (row % 3) * 30; }, + nullEvery(7), + DATE()), + makeFlatVector( + size, + [](vector_size_t row) { return fmt::format("val{}", row % 5); }, + nullEvery(11))}); + const auto& fileStats = + writeDataAndGetAllStats(rowVector, partitionTransforms); + int32_t totalIntNulls = 0; + int32_t totalDateNulls = 0; + int32_t totalVarcharNulls = 0; + int32_t totalRecords = 0; + + constexpr int32_t intColId = 1; + constexpr int32_t dateColId = 2; + constexpr int32_t varcharColId = 3; + + for (const auto& stats : fileStats) { + totalRecords += stats->numRecords; + // Add null counts if present. + if (stats->columnStats.contains(intColId)) { + totalIntNulls += stats->columnStats.at(intColId).nullValueCount; + } + + if (stats->columnStats.contains(dateColId)) { + totalDateNulls += stats->columnStats.at(dateColId).nullValueCount; + } + + if (stats->columnStats.contains(varcharColId)) { + totalVarcharNulls += stats->columnStats.at(varcharColId).nullValueCount; + } + } + + // Verify total counts match expected. + EXPECT_EQ(totalRecords, size); + EXPECT_EQ(totalIntNulls, expectedIntNulls); + EXPECT_EQ(totalDateNulls, expectedDateNulls); + EXPECT_EQ(totalVarcharNulls, expectedVarcharNulls); +} + +TEST_F(IcebergParquetStatsTest, mapType) { + constexpr vector_size_t size = 100; + constexpr int32_t intColId = 1; + constexpr int32_t mapValueColId = 3; // Map value field ID. + + std::vector>>>> + mapData; + for (auto i = 0; i < size; ++i) { + std::vector>> mapRow; + for (auto j = 0; j < 5; ++j) { + mapRow.emplace_back(j, fmt::format("value_{}", i * 5 + j)); + } + mapData.push_back(std::move(mapRow)); + } + + const auto& stats = writeDataAndGetAllStats(makeRowVector({ + makeFlatVector(size, [](auto row) { return row * 10; }), + makeNullableMapVector(mapData), + })); + verifyBasicStats(stats[0], size, {{intColId, size}}, {{intColId, 0}}); + + EXPECT_EQ(stats[0]->columnStats.at(mapValueColId).valueCount, size * 5); + // Map values have stats but no bounds (skipBounds=true for maps). + verifyBoundsNotExist(stats[0], {mapValueColId}); +} + +TEST_F(IcebergParquetStatsTest, arrayType) { + constexpr vector_size_t size = 100; + constexpr int32_t intColId = 1; + constexpr int32_t arrayElementColId = 3; // Array element field ID. + + std::vector>> arrayData; + for (auto i = 0; i < size; ++i) { + std::vector> arrayRow; + for (auto j = 0; j < 3; ++j) { + arrayRow.emplace_back(fmt::format("item_{}", i * 3 + j)); + } + arrayData.push_back(std::move(arrayRow)); + } + + const auto& stats = writeDataAndGetAllStats(makeRowVector( + {makeFlatVector(size, [](auto row) { return row * 10; }), + makeNullableArrayVector(arrayData)})); + verifyBasicStats(stats[0], size, {{intColId, size}}, {{intColId, 0}}); + + EXPECT_EQ(stats[0]->columnStats.at(arrayElementColId).valueCount, size * 3); + // Array elements have stats but no bounds (skipBounds=true for arrays). + verifyBoundsNotExist(stats[0], {arrayElementColId}); +} + +// Test statistics collection for nested struct fields. +// Field ID assignment: +// int_col: 1 +// struct_col: 2 (parent, no stats) +// first_level_id: 3 +// first_level_name: 4 +// nested_struct: 5 (parent, no stats) +// second_level_id: 6 +// second_level_name: 7 +// Statistics collected for leaf fields: [1, 3, 4, 6, 7] +TEST_F(IcebergParquetStatsTest, structType) { + constexpr vector_size_t size = 100; + constexpr int32_t intColId = 1; + constexpr int32_t firstLevelIdColId = 3; + constexpr int32_t secondLevelIdColId = 6; + constexpr int32_t secondLevelNameColId = 7; + + auto firstLevelId = makeFlatVector( + size, [](vector_size_t row) { return row % size; }, nullEvery(5)); + + auto firstLevelName = makeFlatVector( + size, + [](vector_size_t row) { return fmt::format("name_{}", row * 10); }, + nullEvery(7)); + + auto secondLevelId = makeFlatVector( + size, [](vector_size_t row) { return row * size; }, nullEvery(6)); + + auto secondLevelName = makeFlatVector( + size, + [](vector_size_t row) { return fmt::format("nested_{}", row * 100); }, + nullEvery(8)); + + auto nestedStruct = makeRowVector({secondLevelId, secondLevelName}); + auto structVector = + makeRowVector({firstLevelId, firstLevelName, nestedStruct}); + auto rowVector = makeRowVector( + {makeFlatVector(size, [](auto row) { return row * 10; }), + structVector}); + + const auto& stats = writeDataAndGetAllStats(rowVector); + EXPECT_EQ(stats[0]->numRecords, size); + EXPECT_EQ(stats[0]->columnStats.size(), 5); + + verifyBasicStats( + stats[0], + size, + {{intColId, size}, {firstLevelIdColId, size}, {secondLevelIdColId, size}}, + {{intColId, 0}, {firstLevelIdColId, 20}}); + + EXPECT_EQ( + encoding::Base64::decode( + stats[0]->columnStats.at(secondLevelNameColId).lowerBound.value()), + "nested_100"); + EXPECT_EQ( + encoding::Base64::decode( + stats[0]->columnStats.at(secondLevelNameColId).upperBound.value()), + "nested_9900"); +} + +#endif + +} // namespace + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/tests/IcebergReadTest.cpp b/velox/connectors/hive/iceberg/tests/IcebergReadTest.cpp index c8269083d52..71c1ae4b621 100644 --- a/velox/connectors/hive/iceberg/tests/IcebergReadTest.cpp +++ b/velox/connectors/hive/iceberg/tests/IcebergReadTest.cpp @@ -14,29 +14,47 @@ * limitations under the License. */ +#include +#include #include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/encode/Base64.h" #include "velox/common/file/FileSystems.h" +#include "velox/connectors/ConnectorRegistry.h" #include "velox/connectors/hive/HiveConnectorSplit.h" +#include "velox/connectors/hive/iceberg/IcebergColumnHandle.h" +#include "velox/connectors/hive/iceberg/IcebergConnector.h" #include "velox/connectors/hive/iceberg/IcebergDeleteFile.h" #include "velox/connectors/hive/iceberg/IcebergMetadataColumns.h" #include "velox/connectors/hive/iceberg/IcebergSplit.h" #include "velox/dwio/common/tests/utils/DataFiles.h" #include "velox/exec/PlanNodeStats.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" + #ifdef VELOX_ENABLE_PARQUET #include "velox/dwio/parquet/RegisterParquetReader.h" #endif -#include - using namespace facebook::velox::exec::test; using namespace facebook::velox::exec; using namespace facebook::velox::dwio::common; using namespace facebook::velox::test; +using namespace facebook::velox::common::testutil; namespace facebook::velox::connector::hive::iceberg { +namespace { +// Return the file size for the given path using Velox's filesystem API. +uint64_t getTestFileSize(const std::string& path) { + return filesystems::getFileSystem(path, nullptr) + ->openFileForRead(path) + ->size(); +} +} // namespace + +static const char* kIcebergConnectorId = "test-iceberg"; + class HiveIcebergTest : public HiveConnectorTestBase { public: void SetUp() override { @@ -44,6 +62,20 @@ class HiveIcebergTest : public HiveConnectorTestBase { #ifdef VELOX_ENABLE_PARQUET parquet::registerParquetReaderFactory(); #endif + // Register IcebergConnector. + IcebergConnectorFactory icebergFactory; + auto icebergConnector = icebergFactory.newConnector( + kIcebergConnectorId, + std::make_shared( + std::unordered_map()), + ioExecutor_.get()); + connector::ConnectorRegistry::global().insert( + icebergConnector->connectorId(), icebergConnector); + } + + void TearDown() override { + connector::ConnectorRegistry::global().erase(kIcebergConnectorId); + HiveConnectorTestBase::TearDown(); } HiveIcebergTest() @@ -155,7 +187,7 @@ class HiveIcebergTest : public HiveConnectorTestBase { return values; } - std::vector makeContinuousIncreasingValues( + static std::vector makeContinuousIncreasingValues( int64_t begin, int64_t end) { std::vector values; @@ -219,10 +251,9 @@ class HiveIcebergTest : public HiveConnectorTestBase { IcebergDeleteFile icebergDeleteFile( FileContent::kPositionalDeletes, deleteFilePath, - fileFomat_, + fileFormat_, deleteFilePaths[deleteFileName].first, - testing::internal::GetFileSize( - std::fopen(deleteFilePath.c_str(), "r"))); + getTestFileSize(deleteFilePath)); deleteFiles.push_back(icebergDeleteFile); } } @@ -234,13 +265,17 @@ class HiveIcebergTest : public HiveConnectorTestBase { std::string duckdbSql = getDuckDBQuery(rowGroupSizesForFiles, deleteFilesForBaseDatafiles); - auto plan = tableScanNode(); - auto task = HiveConnectorTestBase::assertQuery( - plan, splits, duckdbSql, numPrefetchSplits); + auto plan = PlanBuilder() + .startTableScan() + .connectorId(kIcebergConnectorId) + .outputType(ROW({"c0"}, {BIGINT()})) + .endTableScan() + .planNode(); + auto task = assertQuery(plan, splits, duckdbSql, numPrefetchSplits); auto planStats = toPlanStats(task->taskStats()); - auto scanNodeId = plan->id(); - auto it = planStats.find(scanNodeId); + + auto it = planStats.find(plan->id()); ASSERT_TRUE(it != planStats.end()); ASSERT_TRUE(it->second.peakMemoryBytes > 0); } @@ -250,6 +285,64 @@ class HiveIcebergTest : public HiveConnectorTestBase { protected: std::shared_ptr config_; std::function()> flushPolicyFactory_; + FileFormat fileFormat_{FileFormat::DWRF}; + + /// Helper to create a standard c0 HiveColumnHandle (BIGINT). + std::shared_ptr makeC0Handle() { + return std::make_shared( + "c0", + HiveColumnHandle::ColumnType::kRegular, + BIGINT(), + BIGINT(), + std::vector{}); + } + + /// Helper to create an IcebergColumnHandle with default value. + std::shared_ptr makeIcebergHandle( + const std::string& name, + const TypePtr& type, + int fieldId, + const std::string& defaultValue) { + return std::make_shared( + name, + HiveColumnHandle::ColumnType::kRegular, + type, + parquet::ParquetFieldId(fieldId), + std::vector{}, + std::optional{defaultValue}); + } + + /// Helper function to test schema evolution with initial default values. + void assertDefaultValues( + const RowTypePtr& outputType, + const RowTypePtr& scanSpecType, + const ColumnHandleMap& assignments, + const std::vector& data, + const std::vector& expected, + const std::unordered_map& sessionProperties = + {}) { + // Write data file with old schema + auto dataFilePath = TempFilePath::create(); + writeToFile(dataFilePath->getPath(), data); + auto icebergSplits = makeIcebergSplits(dataFilePath->getPath()); + + // Build plan + auto plan = PlanBuilder() + .startTableScan() + .connectorId(kIcebergConnectorId) + .outputType(outputType) + .dataColumns(scanSpecType) + .assignments(assignments) + .endTableScan() + .planNode(); + + // Build query and add session properties if provided + auto queryBuilder = AssertQueryBuilder(plan); + for (const auto& [key, value] : sessionProperties) { + queryBuilder.connectorSessionProperty(kIcebergConnectorId, key, value); + } + queryBuilder.splits(icebergSplits).assertResults(expected); + } std::vector> makeIcebergSplits( const std::string& dataFilePath, @@ -257,33 +350,57 @@ class HiveIcebergTest : public HiveConnectorTestBase { const std::unordered_map>& partitionKeys = {}, const uint32_t splitCount = 1) { - std::unordered_map customSplitInfo; - customSplitInfo["table_format"] = "hive-iceberg"; - auto file = filesystems::getFileSystem(dataFilePath, nullptr) ->openFileForRead(dataFilePath); const int64_t fileSize = file->size(); - std::vector> splits; const uint64_t splitSize = std::floor((fileSize) / splitCount); + std::vector> splits; + splits.reserve(splitCount); + for (int i = 0; i < splitCount; ++i) { - splits.emplace_back(std::make_shared( - kHiveConnectorId, - dataFilePath, - fileFomat_, - i * splitSize, - splitSize, - partitionKeys, - std::nullopt, - customSplitInfo, - nullptr, - /*cacheable=*/true, - deleteFiles)); + splits.emplace_back( + std::make_shared( + kIcebergConnectorId, + dataFilePath, + fileFormat_, + i * splitSize, + splitSize, + partitionKeys, + std::nullopt, + std::unordered_map{}, + nullptr, + /*cacheable=*/true, + deleteFiles)); } return splits; } + ColumnHandleMap makeColumnHandles( + const RowTypePtr& rowType, + const std::unordered_set& partitionIndices = {}) { + ColumnHandleMap assignments; + for (auto i = 0; i < rowType->size(); ++i) { + const auto& columnName = rowType->nameOf(i); + const auto& columnType = rowType->childAt(i); + auto columnHandleType = partitionIndices.contains(i) + ? FileColumnHandle::ColumnType::kPartitionKey + : FileColumnHandle::ColumnType::kRegular; + + assignments.insert( + {columnName, + std::make_shared( + columnName, + columnHandleType, + columnType, + columnType, + std::vector{})}); + } + + return assignments; + } + #ifdef VELOX_ENABLE_PARQUET std::vector> createParquetDeleteFileAndSplits( const std::string& path, @@ -299,39 +416,185 @@ class HiveIcebergTest : public HiveConnectorTestBase { static_cast(deletedPositionSize), [&](vector_size_t) { return path; }), makeFlatVector(deletePositionsVec), - })}, - config_, - flushPolicyFactory_); + })}); IcebergDeleteFile icebergDeleteFile( FileContent::kPositionalDeletes, deleteFilePath->getPath(), - fileFomat_, + fileFormat_, deletedPositionSize, - testing::internal::GetFileSize( - std::fopen(deleteFilePath->getPath().c_str(), "r"))); + getTestFileSize(deleteFilePath->getPath())); auto fileSize = filesystems::getFileSystem(path, nullptr) ->openFileForRead(path) ->size(); - std::unordered_map customSplitInfo{ - {"table_format", "hive-iceberg"}}; std::unordered_map> partitionKeys; return {std::make_shared( - kHiveConnectorId, + kIcebergConnectorId, path, - dwio::common::FileFormat::PARQUET, + FileFormat::PARQUET, 0, fileSize, partitionKeys, std::nullopt, - customSplitInfo, + std::unordered_map{}, nullptr, /*cacheable=*/true, std::vector{icebergDeleteFile})}; } #endif + /// Creates a single HiveIcebergSplit from the full data file with info + /// columns (e.g. $first_row_id, $data_sequence_number) and optional + /// delete files. + std::shared_ptr makeIcebergSplitWithInfoColumns( + const std::string& dataFilePath, + const std::unordered_map& infoColumns, + const std::vector& deleteFiles = {}) { + auto file = filesystems::getFileSystem(dataFilePath, nullptr) + ->openFileForRead(dataFilePath); + return std::make_shared( + kIcebergConnectorId, + dataFilePath, + fileFormat_, + 0, + file->size(), + std::unordered_map>{}, + std::nullopt, + std::unordered_map{}, + nullptr, + /*cacheable=*/true, + deleteFiles, + infoColumns); + } + + struct RowLineageTestCase { + std::vector values; + // Physically stored _row_id values; nullopt entries write a file null. + // Absent means no _row_id column in the file; reader derives from + // firstRowId. Always paired with storedSequenceNumbers. + std::optional>> storedRowIds; + // Physically stored _last_updated_sequence_number values; nullopt entries + // write a file null. Absent means no column in the file; reader derives + // from dataSequenceNumber. Always paired with storedRowIds. + std::optional>> storedSequenceNumbers; + // Passed as first_row_id in the split's info columns. + std::optional firstRowId; + // Passed as data_sequence_number in the split's info columns. + std::optional dataSequenceNumber; + // File positions to delete; empty means no delete file is created. + std::vector deletePositions; + // Subfield filter expression (e.g., "c0 > 20"); empty means no filter. + std::string subfieldFilter; + // Expected output rows: (c0, _row_id, _last_updated_sequence_number). + std::vector expectedVectors; + }; + + // Writes tc to a temp data file, executes a table scan over + // (c0, _row_id, _last_updated_sequence_number), and asserts the result. + void assertRowLineage(const RowLineageTestCase& tc) { + VELOX_CHECK_EQ( + tc.storedRowIds.has_value(), + tc.storedSequenceNumbers.has_value(), + "rowIds and sequenceNumbers must both be set or both absent."); + + auto pathColumn = IcebergMetadataColumn::icebergDeleteFilePathColumn(); + auto posColumn = IcebergMetadataColumn::icebergDeletePosColumn(); + + // Build the data file vectors from the explicit column fields. + std::vector inputVectors; + if (!tc.storedRowIds.has_value()) { + // No physical lineage columns: file contains only c0. + inputVectors = {makeRowVector({makeFlatVector(tc.values)})}; + } else { + // Physical lineage columns are present in the file. + static const std::vector kFileColumns = { + "c0", "_row_id", "_last_updated_sequence_number"}; + inputVectors = {makeRowVector( + kFileColumns, + { + makeFlatVector(tc.values), + makeNullableFlatVector(*tc.storedRowIds), + makeNullableFlatVector(*tc.storedSequenceNumbers), + })}; + } + + auto dataFilePath = TempFilePath::create(); + writeToFile(dataFilePath->getPath(), inputVectors); + + std::vector deleteFiles; + std::shared_ptr deleteFilePath; + if (!tc.deletePositions.empty()) { + deleteFilePath = TempFilePath::create(); + writeToFile( + deleteFilePath->getPath(), + {makeRowVector( + {pathColumn->name, posColumn->name}, + {makeFlatVector( + static_cast(tc.deletePositions.size()), + [&](auto) { return dataFilePath->getPath(); }), + makeFlatVector(tc.deletePositions)})}); + + const uint64_t upperBound = static_cast(*std::max_element( + tc.deletePositions.begin(), tc.deletePositions.end())); + const auto upperBoundLE = folly::Endian::little(upperBound); + const auto encodedUpperBound = encoding::Base64::encode( + std::string_view( + reinterpret_cast(&upperBoundLE), + sizeof(upperBoundLE))); + + deleteFiles.push_back(IcebergDeleteFile( + FileContent::kPositionalDeletes, + deleteFilePath->getPath(), + fileFormat_, + static_cast(tc.deletePositions.size()), + getTestFileSize(deleteFilePath->getPath()), + {}, + {}, + {{posColumn->id, encodedUpperBound}})); + } + + std::unordered_map infoColumns; + if (tc.firstRowId.has_value()) { + infoColumns[IcebergMetadataColumn::kFirstRowIdInfoColumn] = + std::to_string(*tc.firstRowId); + } + if (tc.dataSequenceNumber.has_value()) { + infoColumns[IcebergMetadataColumn::kDataSequenceNumberInfoColumn] = + std::to_string(*tc.dataSequenceNumber); + } + + auto split = makeIcebergSplitWithInfoColumns( + dataFilePath->getPath(), infoColumns, deleteFiles); + + const auto outputType = + ROW({"c0", "_row_id", "_last_updated_sequence_number"}, + {BIGINT(), BIGINT(), BIGINT()}); + const auto tableDataColumns = ROW({"c0"}, {BIGINT()}); + + core::PlanNodePtr plan; + if (!tc.subfieldFilter.empty()) { + plan = PlanBuilder() + .startTableScan() + .connectorId(kIcebergConnectorId) + .outputType(outputType) + .dataColumns(tableDataColumns) + .subfieldFilter(tc.subfieldFilter) + .endTableScan() + .planNode(); + } else { + plan = PlanBuilder() + .startTableScan() + .connectorId(kIcebergConnectorId) + .outputType(outputType) + .dataColumns(tableDataColumns) + .endTableScan() + .planNode(); + } + + AssertQueryBuilder(plan).splits({split}).assertResults(tc.expectedVectors); + } + private: std::map> writeDataFiles( std::map> rowGroupSizesForFiles) { @@ -348,11 +611,7 @@ class HiveIcebergTest : public HiveConnectorTestBase { // files. This is to make constructing DuckDB queries easier std::vector dataVectors = makeVectors(dataFile.second, startingValue); - writeToFile( - dataFilePaths[dataFile.first]->getPath(), - dataVectors, - config_, - flushPolicyFactory_); + writeToFile(dataFilePaths[dataFile.first]->getPath(), dataVectors); for (int i = 0; i < dataVectors.size(); i++) { dataVectorsJoined.push_back(dataVectors[i]); @@ -409,11 +668,7 @@ class HiveIcebergTest : public HiveConnectorTestBase { totalPositionsInDeleteFile += positionsInRowGroup.size(); } - writeToFile( - deleteFilePath->getPath(), - deleteFileVectors, - config_, - flushPolicyFactory_); + writeToFile(deleteFilePath->getPath(), deleteFileVectors); deleteFilePaths[deleteFileName] = std::make_pair(totalPositionsInDeleteFile, deleteFilePath); @@ -432,8 +687,7 @@ class HiveIcebergTest : public HiveConnectorTestBase { for (int j = 0; j < vectorSizes.size(); j++) { auto data = makeContinuousIncreasingValues( startingValue, startingValue + vectorSizes[j]); - VectorPtr c0 = makeFlatVector(data); - vectors.push_back(makeRowVector({"c0"}, {c0})); + vectors.push_back(makeRowVector({makeFlatVector(data)})); startingValue += vectorSizes[j]; } @@ -460,9 +714,9 @@ class HiveIcebergTest : public HiveConnectorTestBase { // Group the delete vectors by baseFileName std::map>> deletePosVectorsForAllBaseFiles; - for (auto deleteFile : deleteFilesForBaseDatafiles) { + for (auto& deleteFile : deleteFilesForBaseDatafiles) { auto deleteFileContent = deleteFile.second; - for (auto rowGroup : deleteFileContent) { + for (auto& rowGroup : deleteFileContent) { auto baseFileName = rowGroup.first; deletePosVectorsForAllBaseFiles[baseFileName].push_back( rowGroup.second); @@ -475,7 +729,7 @@ class HiveIcebergTest : public HiveConnectorTestBase { std::map> flattenedDeletePosVectorsForAllBaseFiles; int64_t totalNumDeletePositions = 0; - for (auto deleteVectorsForBaseFile : deletePosVectorsForAllBaseFiles) { + for (auto& deleteVectorsForBaseFile : deletePosVectorsForAllBaseFiles) { auto baseFileName = deleteVectorsForBaseFile.first; auto deletePositionVectors = deleteVectorsForBaseFile.second; std::vector deletePositionVector = @@ -488,14 +742,18 @@ class HiveIcebergTest : public HiveConnectorTestBase { // Now build the DuckDB queries if (totalNumDeletePositions == 0) { return "SELECT * FROM tmp"; - } else if (totalNumDeletePositions >= totalNumRowsInAllBaseFiles) { + } + + if (totalNumDeletePositions >= totalNumRowsInAllBaseFiles) { return "SELECT * FROM tmp WHERE 1 = 0"; - } else { + } + + { // Convert the delete positions in all base files into column values std::vector allDeleteValues; int64_t numRowsInPreviousBaseFiles = 0; - for (auto baseFileSize : baseFileSizes) { + for (auto& baseFileSize : baseFileSizes) { auto deletePositions = flattenedDeletePosVectorsForAllBaseFiles[baseFileSize.first]; @@ -515,7 +773,7 @@ class HiveIcebergTest : public HiveConnectorTestBase { return fmt::format( "SELECT * FROM tmp WHERE c0 NOT IN ({})", - makeNotInList(allDeleteValues)); + folly::join(", ", allDeleteValues)); } } @@ -523,7 +781,7 @@ class HiveIcebergTest : public HiveConnectorTestBase { const std::vector>& deletePositionVectors, int64_t baseFileSize) { std::vector deletePositionVector; - for (auto vec : deletePositionVectors) { + for (auto& vec : deletePositionVectors) { for (auto pos : vec) { if (pos >= 0 && pos < baseFileSize) { deletePositionVector.push_back(pos); @@ -539,29 +797,9 @@ class HiveIcebergTest : public HiveConnectorTestBase { return deletePositionVector; } - std::string makeNotInList(const std::vector& deletePositionVector) { - if (deletePositionVector.empty()) { - return ""; - } - - return std::accumulate( - deletePositionVector.begin() + 1, - deletePositionVector.end(), - std::to_string(deletePositionVector[0]), - [](const std::string& a, int64_t b) { - return a + ", " + std::to_string(b); - }); - } - - core::PlanNodePtr tableScanNode() { - return PlanBuilder(pool_.get()).tableScan(rowType_).planNode(); - } - - dwio::common::FileFormat fileFomat_{dwio::common::FileFormat::DWRF}; - - RowTypePtr rowType_{ROW({"c0"}, {BIGINT()})}; std::shared_ptr pathColumn_ = IcebergMetadataColumn::icebergDeleteFilePathColumn(); + std::shared_ptr posColumn_ = IcebergMetadataColumn::icebergDeletePosColumn(); }; @@ -591,7 +829,7 @@ TEST_F(HiveIcebergTest, singleBaseFileSinglePositionalDeleteFile) { /// delete positions. The parameter passed to /// assertSingleBaseFileSingleDeleteFile is the delete positions.for the middle /// base file. -TEST_F(HiveIcebergTest, MultipleBaseFilesSinglePositionalDeleteFile) { +TEST_F(HiveIcebergTest, multipleBaseFilesSinglePositionalDeleteFile) { folly::SingletonVault::singleton()->registrationComplete(); assertMultipleBaseFileSingleDeleteFile({0, 1, 2, 3}); @@ -754,88 +992,686 @@ TEST_F(HiveIcebergTest, positionalDeletesMultipleSplits) { assertMultipleSplits({1000, 9000, 20000}, 1, 0, 20000, 3); } -TEST_F(HiveIcebergTest, testPartitionedRead) { - RowTypePtr rowType{ROW({"c0", "ds"}, {BIGINT(), DateType::get()})}; - std::unordered_map> partitionKeys; - // Iceberg API sets partition values for dates to daysSinceEpoch, so - // in velox, we do not need to convert it to days. - // Test query on two partitions ds=17627(2018-04-06), ds=17628(2018-04-07) - std::vector> splits; - std::vector> dataFilePaths; - for (int i = 0; i <= 1; ++i) { - std::vector dataVectors; - int32_t daysSinceEpoch = 17627 + i; - VectorPtr c0 = makeFlatVector((std::vector){i}); - VectorPtr ds = - makeFlatVector((std::vector){daysSinceEpoch}); - dataVectors.push_back(makeRowVector({"c0", "ds"}, {c0, ds})); +TEST_F(HiveIcebergTest, schemaEvolutionRemoveColumn) { + auto oldRowType = ROW({"c0", "c1", "c2"}, {BIGINT(), INTEGER(), VARCHAR()}); + auto newRowType = ROW({"c0", "c2"}, {BIGINT(), VARCHAR()}); - auto dataFilePath = TempFilePath::create(); - dataFilePaths.push_back(dataFilePath); - writeToFile( - dataFilePath->getPath(), dataVectors, config_, flushPolicyFactory_); - partitionKeys["ds"] = std::to_string(daysSinceEpoch); + // Write data file with old schema (c0, c1, c2). + std::vector dataVectors; + dataVectors.push_back(makeRowVector( + oldRowType->names(), + { + makeFlatVector({1, 2, 3, 4, 5}), + makeFlatVector({10, 20, 30, 40, 50}), + makeFlatVector({"a", "b", "c", "d", "e"}), + })); + + auto dataFilePath = TempFilePath::create(); + writeToFile(dataFilePath->getPath(), dataVectors); + + auto icebergSplits = makeIcebergSplits(dataFilePath->getPath()); + + // Expected result: c0 and c2 have values, c1 is not present. + std::vector expectedVectors; + expectedVectors.push_back(makeRowVector( + newRowType->names(), + { + dataVectors[0]->childAt(0), + dataVectors[0]->childAt(2), + })); + + // Read with new schema (c0 and c2 only, c1 removed). + auto plan = PlanBuilder() + .startTableScan() + .connectorId(kIcebergConnectorId) + .outputType(newRowType) + .endTableScan() + .planNode(); + AssertQueryBuilder(plan).splits(icebergSplits).assertResults(expectedVectors); +} + +TEST_F(HiveIcebergTest, schemaEvolutionAddColumns) { + auto oldRowType = ROW({"c0"}, {BIGINT()}); + auto newRowType = ROW({"c0", "c1", "c2"}, {BIGINT(), INTEGER(), VARCHAR()}); + + // Write data file with old schema (only c0). + std::vector dataVectors; + dataVectors.push_back(makeRowVector({ + makeFlatVector({100, 200, 300}), + })); + auto dataFilePath = TempFilePath::create(); + writeToFile(dataFilePath->getPath(), dataVectors); + auto icebergSplits = makeIcebergSplits(dataFilePath->getPath()); + + // Expected result: c0 has values, c1 and c2 are NULL. + std::vector expectedVectors; + expectedVectors.push_back(makeRowVector({ + dataVectors[0]->childAt(0), + makeNullConstant(TypeKind::INTEGER, 3), + makeNullConstant(TypeKind::VARCHAR, 3), + })); + + // Read with new schema (c0, c1, and c2). + auto plan = PlanBuilder() + .startTableScan() + .connectorId(kIcebergConnectorId) + .outputType(newRowType) + .dataColumns(newRowType) + .endTableScan() + .planNode(); + AssertQueryBuilder(plan).splits(icebergSplits).assertResults(expectedVectors); +} + +// Test Iceberg V3 initial-default: a column added after data files were written +// should return its initial-default value (not NULL) for those historical rows. +TEST_F(HiveIcebergTest, addColumnWithDefault) { + auto newRowType = ROW({"c0", "country"}, {BIGINT(), VARCHAR()}); + + std::vector dataVectors; + dataVectors.push_back(makeRowVector({makeFlatVector({1, 2, 3})})); + + ColumnHandleMap assignments; + assignments["c0"] = makeC0Handle(); + assignments["country"] = makeIcebergHandle("country", VARCHAR(), 2, "IN"); + + std::vector expectedVectors; + expectedVectors.push_back(makeRowVector( + newRowType->names(), + {dataVectors[0]->childAt(0), + makeFlatVector({"IN", "IN", "IN"})})); + + assertDefaultValues( + newRowType, newRowType, assignments, dataVectors, expectedVectors); +} + +TEST_F(HiveIcebergTest, addColumnWithDefaultAndAlias) { + auto outputType = ROW({"c0", "region"}, {BIGINT(), VARCHAR()}); + auto dataColumns = ROW({"c0", "country"}, {BIGINT(), VARCHAR()}); + + std::vector dataVectors; + dataVectors.push_back(makeRowVector({makeFlatVector({1, 2, 3})})); + + ColumnHandleMap assignments; + assignments["c0"] = makeC0Handle(); + // Key is "region" (alias), but handle refers to "country" (table column) + assignments["region"] = makeIcebergHandle("country", VARCHAR(), 2, "IN"); + + std::vector expectedVectors; + expectedVectors.push_back(makeRowVector( + outputType->names(), + {dataVectors[0]->childAt(0), + makeFlatVector({"IN", "IN", "IN"})})); + + assertDefaultValues( + outputType, dataColumns, assignments, dataVectors, expectedVectors); +} + +TEST_F(HiveIcebergTest, fileValueOverridesDefault) { + auto rowType = ROW({"c0", "country"}, {BIGINT(), VARCHAR()}); + + std::vector dataVectors; + dataVectors.push_back(makeRowVector( + {makeFlatVector({1, 2, 3}), + makeFlatVector({"US", "UK", "CA"})})); + + ColumnHandleMap assignments; + assignments["c0"] = makeC0Handle(); + assignments["country"] = makeIcebergHandle("country", VARCHAR(), 2, "IN"); + + // Expected: file values ("US", "UK", "CA"), NOT the default "IN" + std::vector expectedVectors; + expectedVectors.push_back(makeRowVector( + rowType->names(), + {dataVectors[0]->childAt(0), dataVectors[0]->childAt(1)})); + + assertDefaultValues( + rowType, rowType, assignments, dataVectors, expectedVectors); +} + +TEST_F(HiveIcebergTest, addColumnWithDefaultAllTypes) { + auto newRowType = + ROW({"c0", + "tiny_val", + "small_val", + "int_val", + "big_val", + "real_val", + "double_val", + "bool_val", + "str_val", + "short_decimal", + "long_decimal", + "date_val", + "timestamp_val"}, + {BIGINT(), + TINYINT(), + SMALLINT(), + INTEGER(), + BIGINT(), + REAL(), + DOUBLE(), + BOOLEAN(), + VARCHAR(), + DECIMAL(10, 2), + DECIMAL(38, 10), + DATE(), + TIMESTAMP()}); + + std::vector dataVectors; + dataVectors.push_back(makeRowVector({makeFlatVector({1, 2, 3})})); + + ColumnHandleMap assignments; + assignments["c0"] = makeC0Handle(); + assignments["tiny_val"] = makeIcebergHandle("tiny_val", TINYINT(), 2, "10"); + assignments["small_val"] = + makeIcebergHandle("small_val", SMALLINT(), 3, "100"); + assignments["int_val"] = makeIcebergHandle("int_val", INTEGER(), 4, "1000"); + assignments["big_val"] = makeIcebergHandle("big_val", BIGINT(), 5, "10000"); + assignments["real_val"] = makeIcebergHandle("real_val", REAL(), 6, "3.14"); + assignments["double_val"] = + makeIcebergHandle("double_val", DOUBLE(), 7, "2.718"); + assignments["bool_val"] = makeIcebergHandle("bool_val", BOOLEAN(), 8, "true"); + assignments["str_val"] = + makeIcebergHandle("str_val", VARCHAR(), 9, "default_string"); + assignments["short_decimal"] = + makeIcebergHandle("short_decimal", DECIMAL(10, 2), 10, "99.99"); + assignments["long_decimal"] = makeIcebergHandle( + "long_decimal", + DECIMAL(38, 10), + 11, + "123456789012345678901234567.8901234567"); + assignments["date_val"] = + makeIcebergHandle("date_val", DATE(), 12, "2024-01-15"); + assignments["timestamp_val"] = makeIcebergHandle( + "timestamp_val", TIMESTAMP(), 13, "2024-01-15 10:30:00"); + + std::vector expectedVectors; + expectedVectors.push_back(makeRowVector( + newRowType->names(), + {dataVectors[0]->childAt(0), + makeFlatVector({10, 10, 10}), + makeFlatVector({100, 100, 100}), + makeFlatVector({1000, 1000, 1000}), + makeFlatVector({10000, 10000, 10000}), + makeFlatVector({3.14f, 3.14f, 3.14f}), + makeFlatVector({2.718, 2.718, 2.718}), + makeFlatVector({true, true, true}), + makeFlatVector( + {"default_string", "default_string", "default_string"}), + makeFlatVector({9999, 9999, 9999}, DECIMAL(10, 2)), + makeFlatVector( + {HugeInt::parse("1234567890123456789012345678901234567"), + HugeInt::parse("1234567890123456789012345678901234567"), + HugeInt::parse("1234567890123456789012345678901234567")}, + DECIMAL(38, 10)), + makeFlatVector({19737, 19737, 19737}, DATE()), + makeFlatVector( + {Timestamp(1705314600, 0), + Timestamp(1705314600, 0), + Timestamp(1705314600, 0)})})); + + assertDefaultValues( + newRowType, + newRowType, + assignments, + dataVectors, + expectedVectors, + {{HiveConfig::kReadTimestampPartitionValueAsLocalTimeSession, "false"}}); +} + +TEST_F(HiveIcebergTest, addColumnWithInvalidDefault) { + auto newRowType = ROW({"c0", "age"}, {BIGINT(), INTEGER()}); + + std::vector dataVectors; + dataVectors.push_back(makeRowVector({makeFlatVector({1, 2, 3})})); + auto dataFilePath = TempFilePath::create(); + writeToFile(dataFilePath->getPath(), dataVectors); + auto icebergSplits = makeIcebergSplits(dataFilePath->getPath()); + + ColumnHandleMap assignments; + assignments["c0"] = makeC0Handle(); + assignments["age"] = makeIcebergHandle("age", INTEGER(), 2, "IN"); + + auto plan = PlanBuilder() + .startTableScan() + .connectorId(kIcebergConnectorId) + .outputType(newRowType) + .dataColumns(newRowType) + .assignments(assignments) + .endTableScan() + .planNode(); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan).splits(icebergSplits).copyResults(pool()), + "Invalid"); +} + +TEST_F(HiveIcebergTest, addColumnWithEmptyStringDefault) { + auto newRowType = ROW({"c0", "name"}, {BIGINT(), VARCHAR()}); + + std::vector dataVectors; + dataVectors.push_back(makeRowVector({makeFlatVector({1, 2, 3})})); + + ColumnHandleMap assignments; + assignments["c0"] = makeC0Handle(); + assignments["name"] = makeIcebergHandle("name", VARCHAR(), 2, ""); + + std::vector expectedVectors; + expectedVectors.push_back(makeRowVector( + newRowType->names(), + {dataVectors[0]->childAt(0), makeFlatVector({"", "", ""})})); + + assertDefaultValues( + newRowType, newRowType, assignments, dataVectors, expectedVectors, {}); +} + +TEST_F(HiveIcebergTest, defaultValueWithDeletesAndFilters) { + auto newRowType = ROW({"c0", "country"}, {BIGINT(), VARCHAR()}); + + // Write data file with old schema (only c0) containing rows 1-10. + std::vector dataVectors; + dataVectors.push_back(makeRowVector( + {makeFlatVector({1, 2, 3, 4, 5, 6, 7, 8, 9, 10})})); + auto dataFilePath = TempFilePath::create(); + writeToFile(dataFilePath->getPath(), dataVectors); + + // Create delete file that deletes positions 1, 3, 5 (rows 2, 4, 6). + auto deleteFilePath = TempFilePath::create(); + std::vector deletePositions = {1, 3, 5}; + auto pathColumn = IcebergMetadataColumn::icebergDeleteFilePathColumn(); + auto posColumn = IcebergMetadataColumn::icebergDeletePosColumn(); + writeToFile( + deleteFilePath->getPath(), + {makeRowVector( + {pathColumn->name, posColumn->name}, + {makeFlatVector( + 3, [&](vector_size_t) { return dataFilePath->getPath(); }), + makeFlatVector(deletePositions)})}); + + IcebergDeleteFile icebergDeleteFile( + FileContent::kPositionalDeletes, + deleteFilePath->getPath(), + dwio::common::FileFormat::DWRF, + 3, + getTestFileSize(deleteFilePath->getPath())); + + ColumnHandleMap assignments; + assignments["c0"] = makeC0Handle(); + assignments["country"] = makeIcebergHandle("country", VARCHAR(), 2, "IN"); + + // Test 1: No filter - rows 1,3,5,7,8,9,10 (after deletes: 2,4,6 removed) + { + auto icebergSplits = + makeIcebergSplits(dataFilePath->getPath(), {icebergDeleteFile}, {}, 1); + std::vector expectedVectors; + expectedVectors.push_back(makeRowVector( + newRowType->names(), + {makeFlatVector({1, 3, 5, 7, 8, 9, 10}), + makeFlatVector( + {"IN", "IN", "IN", "IN", "IN", "IN", "IN"})})); + + auto plan = PlanBuilder() + .startTableScan() + .connectorId(kIcebergConnectorId) + .outputType(newRowType) + .dataColumns(newRowType) + .assignments(assignments) + .endTableScan() + .planNode(); + AssertQueryBuilder(plan) + .splits(icebergSplits) + .assertResults(expectedVectors); + } + + // Test 2: Filter on file column (c0 > 5) with deletes + // After deletes: 1,3,5,7,8,9,10 remain. Filter c0 > 5: 7,8,9,10 + { + auto icebergSplits = + makeIcebergSplits(dataFilePath->getPath(), {icebergDeleteFile}, {}, 1); + std::vector expectedVectors; + expectedVectors.push_back(makeRowVector( + newRowType->names(), + {makeFlatVector({7, 8, 9, 10}), + makeFlatVector({"IN", "IN", "IN", "IN"})})); + + auto plan = PlanBuilder() + .startTableScan() + .connectorId(kIcebergConnectorId) + .outputType(newRowType) + .dataColumns(newRowType) + .assignments(assignments) + .endTableScan() + .filter("c0 > 5") + .planNode(); + AssertQueryBuilder(plan) + .splits(icebergSplits) + .assertResults(expectedVectors); + } + + // Test 3: Filter on default value column (country = 'IN') with deletes + // All remaining rows should match since default is 'IN' + { + auto icebergSplits = + makeIcebergSplits(dataFilePath->getPath(), {icebergDeleteFile}, {}, 1); + std::vector expectedVectors; + expectedVectors.push_back(makeRowVector( + newRowType->names(), + {makeFlatVector({1, 3, 5, 7, 8, 9, 10}), + makeFlatVector( + {"IN", "IN", "IN", "IN", "IN", "IN", "IN"})})); + + auto plan = PlanBuilder() + .startTableScan() + .connectorId(kIcebergConnectorId) + .outputType(newRowType) + .dataColumns(newRowType) + .assignments(assignments) + .endTableScan() + .filter("country = 'IN'") + .planNode(); + AssertQueryBuilder(plan) + .splits(icebergSplits) + .assertResults(expectedVectors); + } + + // Test 4: Combined filter (c0 > 3 AND country = 'IN') with deletes + // After deletes: 1,3,5,7,8,9,10. Filter c0 > 3: 5,7,8,9,10 + { auto icebergSplits = - makeIcebergSplits(dataFilePath->getPath(), {}, partitionKeys); - splits.insert(splits.end(), icebergSplits.begin(), icebergSplits.end()); + makeIcebergSplits(dataFilePath->getPath(), {icebergDeleteFile}, {}, 1); + std::vector expectedVectors; + expectedVectors.push_back(makeRowVector( + newRowType->names(), + {makeFlatVector({5, 7, 8, 9, 10}), + makeFlatVector({"IN", "IN", "IN", "IN", "IN"})})); + + auto plan = PlanBuilder() + .startTableScan() + .connectorId(kIcebergConnectorId) + .outputType(newRowType) + .dataColumns(newRowType) + .assignments(assignments) + .endTableScan() + .filter("c0 > 3 AND country = 'IN'") + .planNode(); + AssertQueryBuilder(plan) + .splits(icebergSplits) + .assertResults(expectedVectors); } +} - connector::ColumnHandleMap assignments; - assignments.insert( - {"c0", - std::make_shared( - "c0", - HiveColumnHandle::ColumnType::kRegular, - rowType->childAt(0), - rowType->childAt(0))}); - - std::vector requiredSubFields; - HiveColumnHandle::ColumnParseParameters columnParseParameters; - columnParseParameters.partitionDateValueFormat = - HiveColumnHandle::ColumnParseParameters::kDaysSinceEpoch; - assignments.insert( - {"ds", - std::make_shared( - "ds", - HiveColumnHandle::ColumnType::kPartitionKey, - rowType->childAt(1), - rowType->childAt(1), - std::move(requiredSubFields), - columnParseParameters)}); - - auto plan = PlanBuilder(pool_.get()) - .tableScan(rowType, {}, "", nullptr, assignments) +// Test reading partition columns from Hive-migrated tables. +// This tests the adaptColumns method handling partition columns that are not +// stored in the data file but provided via partitionKeys map. +// This scenario occurs when reading Hive-written data files where partition +// column values are stored in partition metadata rather than in the data file. +TEST_F(HiveIcebergTest, partitionColumnsFromHive) { + auto fileRowType = ROW({"c0", "c1"}, {BIGINT(), INTEGER()}); + auto tableRowType = + ROW({"c0", "c1", "region", "year"}, + {BIGINT(), INTEGER(), VARCHAR(), INTEGER()}); + + // Write data file with only non-partition columns (c0, c1). + std::vector dataVectors; + dataVectors.push_back(makeRowVector({ + makeFlatVector({1, 2, 3}), + makeFlatVector({10, 20, 30}), + })); + auto dataFilePath = TempFilePath::create(); + writeToFile(dataFilePath->getPath(), dataVectors); + + // Set partition keys for region and year. + std::unordered_map> partitionKeys; + partitionKeys["region"] = "US"; + partitionKeys["year"] = "2025"; + + auto icebergSplits = + makeIcebergSplits(dataFilePath->getPath(), {}, partitionKeys); + auto assignments = makeColumnHandles(tableRowType, {2, 3}); + + // Expected result: c0 and c1 from file, region and year from partition keys. + std::vector expectedVectors; + expectedVectors.push_back(makeRowVector( + tableRowType->names(), + { + dataVectors[0]->childAt(0), + dataVectors[0]->childAt(1), + makeFlatVector({"US", "US", "US"}), + makeFlatVector({2025, 2025, 2025}), + })); + + // Read with table schema including partition columns. + auto plan = PlanBuilder() + .startTableScan() + .connectorId(kIcebergConnectorId) + .outputType(tableRowType) + .dataColumns(tableRowType) + .assignments(assignments) + .endTableScan() .planNode(); + AssertQueryBuilder(plan).splits(icebergSplits).assertResults(expectedVectors); +} - HiveConnectorTestBase::assertQuery( - plan, - splits, - "SELECT * FROM (VALUES (0, '2018-04-06'), (1, '2018-04-07'))", - 0); +// Test that positional delete files are skipped when their position upper bound +// is before the split offset. When a delete file's upperBound is less than the +// split's starting row, all deletes in that file are not relevant to this +// split. +TEST_F(HiveIcebergTest, skipDeleteFileByPositionUpperBound) { + folly::SingletonVault::singleton()->registrationComplete(); - // Test filter on non-partitioned non-date column - std::vector nonPartitionFilters = {"c0 = 1"}; - plan = PlanBuilder(pool_.get()) - .tableScan(rowType, nonPartitionFilters, "", nullptr, assignments) - .planNode(); + auto pathColumn = IcebergMetadataColumn::icebergDeleteFilePathColumn(); + auto posColumn = IcebergMetadataColumn::icebergDeletePosColumn(); - HiveConnectorTestBase::assertQuery(plan, splits, "SELECT 1, '2018-04-07'"); + // Create a data file with 100 rows. + auto dataFilePath = TempFilePath::create(); + std::vector dataVectors = {makeRowVector( + {makeFlatVector(makeContinuousIncreasingValues(0, 100))})}; + writeToFile(dataFilePath->getPath(), dataVectors); - // Test filter on non-partitioned date column - std::vector filters = {"ds = date'2018-04-06'"}; - plan = PlanBuilder(pool_.get()).tableScan(rowType, filters).planNode(); + // Create a delete file targeting positions 0, 1, 2. + auto deleteFilePath = TempFilePath::create(); + std::vector deleteVectors = {makeRowVector( + {pathColumn->name, posColumn->name}, + {makeFlatVector( + 3, [&](auto) { return dataFilePath->getPath(); }), + makeFlatVector({0, 1, 2})})}; + writeToFile(deleteFilePath->getPath(), deleteVectors); + + // upperBound "2" is the max position in the delete file. Iceberg stores + // long bounds as 8-byte little-endian binary, then Base64 encodes them. + uint64_t upperBound = 2; + auto upperBoundLE = folly::Endian::little(upperBound); + auto encodedUpperBound = encoding::Base64::encode( + std::string_view( + reinterpret_cast(&upperBoundLE), sizeof(upperBoundLE))); + IcebergDeleteFile deleteFile( + FileContent::kPositionalDeletes, + deleteFilePath->getPath(), + FileFormat::DWRF, + 3, + getTestFileSize(deleteFilePath->getPath()), + {}, + {}, + {{posColumn->id, encodedUpperBound}}); + + auto plan = PlanBuilder() + .startTableScan() + .connectorId(kIcebergConnectorId) + .outputType(ROW({"c0"}, {BIGINT()})) + .endTableScan() + .planNode(); - splits.clear(); - for (auto& dataFilePath : dataFilePaths) { - auto icebergSplits = makeIcebergSplits(dataFilePath->getPath()); - splits.insert(splits.end(), icebergSplits.begin(), icebergSplits.end()); - } + // Create a split that starts at the middle of the file. The split offset + // will be greater than the delete file's upper bound (2), so the delete + // file should be skipped completely. + auto file = filesystems::getFileSystem(dataFilePath->getPath(), nullptr) + ->openFileForRead(dataFilePath->getPath()); + const int64_t fileSize = file->size(); + std::vector deleteFiles = {deleteFile}; + auto split = std::make_shared( + kIcebergConnectorId, + dataFilePath->getPath(), + FileFormat::DWRF, + static_cast(fileSize / 2), + static_cast(fileSize / 2), + std::unordered_map>{}, + std::nullopt, + std::unordered_map{}, + std::shared_ptr{}, + true, + deleteFiles); + + // The second half of the file should be returned with no rows deleted. + createDuckDbTable({makeRowVector( + {makeFlatVector(makeContinuousIncreasingValues(50, 50))})}); + assertQuery(plan, {split}, "SELECT * FROM tmp", 0); +} + +// Row lineage scenarios for _row_id and _last_updated_sequence_number: +// 1. Pre-V3: no info columns, no physical columns → both null. +// 2. V3 new insert: no physical columns; derived from info columns. +// 3. V3 rewrite: physical values take precedence over info columns. +// 4. Physical columns all null: falls back to info column derivation. +// 5. Mixed null/non-null: null slots derived, non-null slots preserved. +// 6. first_row_id = 0 is a valid value. +// 7. Positional deletes: _row_id uses file-absolute positions. +// 8. Subfield filter: _row_id uses file-absolute positions, not output +// indices. +TEST_F(HiveIcebergTest, rowLineage) { + folly::SingletonVault::singleton()->registrationComplete(); - HiveConnectorTestBase::assertQuery(plan, splits, "SELECT 0, '2018-04-06'"); + static const std::vector kOutputNames = { + "c0", "_row_id", "_last_updated_sequence_number"}; + + // 1. Pre-V3. + assertRowLineage({ + .values = {1, 2, 3}, + .expectedVectors = {makeRowVector( + kOutputNames, + { + makeFlatVector({1, 2, 3}), + makeNullableFlatVector( + {std::nullopt, std::nullopt, std::nullopt}), + makeNullableFlatVector( + {std::nullopt, std::nullopt, std::nullopt}), + })}, + }); + + // 2. V3 new insert. + assertRowLineage({ + .values = {10, 20, 30}, + .firstRowId = 100, + .dataSequenceNumber = 7, + .expectedVectors = {makeRowVector( + kOutputNames, + { + makeFlatVector({10, 20, 30}), + makeFlatVector({100, 101, 102}), + makeFlatVector({7, 7, 7}), + })}, + }); + + // 3. V3 rewrite: physical values must not be overridden by info columns. + assertRowLineage({ + .values = {1, 2, 3}, + .storedRowIds = {{500, 501, 502}}, + .storedSequenceNumbers = {{3, 5, 3}}, + .firstRowId = 999, + .dataSequenceNumber = 99, + .expectedVectors = {makeRowVector( + kOutputNames, + { + makeFlatVector({1, 2, 3}), + makeFlatVector({500, 501, 502}), + makeFlatVector({3, 5, 3}), + })}, + }); + + // 4. Physical columns all null: falls back to info column derivation. + assertRowLineage({ + .values = {1, 2, 3}, + .storedRowIds = {{std::nullopt, std::nullopt, std::nullopt}}, + .storedSequenceNumbers = {{std::nullopt, std::nullopt, std::nullopt}}, + .firstRowId = 50, + .dataSequenceNumber = 42, + .expectedVectors = {makeRowVector( + kOutputNames, + { + makeFlatVector({1, 2, 3}), + makeFlatVector({50, 51, 52}), + makeFlatVector({42, 42, 42}), + })}, + }); + + // 5. Mixed null/non-null: null slots derived from info columns, non-null + // preserved. + // pos 0: _row_id=null→10+0=10, seq_num=null→42 + // pos 1: _row_id=99, seq_num=5 + // pos 2: _row_id=null→10+2=12, seq_num=null→42 + // pos 3: _row_id=77, seq_num=10 + assertRowLineage({ + .values = {10, 20, 30, 40}, + .storedRowIds = {{std::nullopt, 99, std::nullopt, 77}}, + .storedSequenceNumbers = {{std::nullopt, 5, std::nullopt, 10}}, + .firstRowId = 10, + .dataSequenceNumber = 42, + .expectedVectors = {makeRowVector( + kOutputNames, + { + makeFlatVector({10, 20, 30, 40}), + makeFlatVector({10, 99, 12, 77}), + makeFlatVector({42, 5, 42, 10}), + })}, + }); + + // 6. first_row_id = 0 is a valid value; _row_id starts at zero. + assertRowLineage({ + .values = {5, 6, 7}, + .firstRowId = 0, + .dataSequenceNumber = 5, + .expectedVectors = {makeRowVector( + kOutputNames, + { + makeFlatVector({5, 6, 7}), + makeFlatVector({0, 1, 2}), + makeFlatVector({5, 5, 5}), + })}, + }); + + // 7. Positional deletes: _row_id uses file-absolute positions. + assertRowLineage({ + .values = {10, 20, 30, 40, 50}, + .firstRowId = 200, + .dataSequenceNumber = 42, + .deletePositions = {1, 3}, + .expectedVectors = {makeRowVector( + kOutputNames, + { + makeFlatVector({10, 30, 50}), + makeFlatVector({200, 202, 204}), + makeFlatVector({42, 42, 42}), + })}, + }); + + // 8. Subfield filter: _row_id uses file-absolute positions, not output + // indices. + assertRowLineage({ + .values = {10, 20, 30, 40, 50}, + .firstRowId = 100, + .dataSequenceNumber = 15, + .subfieldFilter = "c0 > 20", + .expectedVectors = {makeRowVector( + kOutputNames, + { + makeFlatVector({30, 40, 50}), + makeFlatVector({102, 103, 104}), + makeFlatVector({15, 15, 15}), + })}, + }); } #ifdef VELOX_ENABLE_PARQUET -TEST_F(HiveIcebergTest, testPositionalDeleteFileWithRowGroupFilter) { +TEST_F(HiveIcebergTest, positionalDeleteFileWithRowGroupFilter) { // This file contains three row groups, each with about 100 rows. // Each row group has min/max values: [200, 299], [0, 99], [100, 199]. // The filter here is id >= 100, which will cause the parquet reader to filter @@ -850,14 +1686,294 @@ TEST_F(HiveIcebergTest, testPositionalDeleteFileWithRowGroupFilter) { deletedPositionSize); // allocate 100 elements, [100, 199]. std::iota(deletePositionsVec.begin(), deletePositionsVec.end(), 100); auto deleteFilePath = TempFilePath::create(); - HiveConnectorTestBase::assertQuery( - PlanBuilder(pool_.get()) - .tableScan(ROW({"id"}, {BIGINT()}), {"id >= 100"}) + assertQuery( + PlanBuilder() + .startTableScan() + .connectorId(kIcebergConnectorId) + .outputType(ROW({"id"}, {BIGINT()})) + .remainingFilter("id >= 100") + .endTableScan() .planNode(), + createParquetDeleteFileAndSplits( path, deletePositionsVec, deletedPositionSize, deleteFilePath), "SELECT i AS id FROM range(100, 300) AS t(i)", 0); } #endif + +// Sequence number filtering tests for positional deletes (Diff 2). +// Per the Iceberg V2+ spec, a positional delete file should only apply to +// data files whose dataSequenceNumber is strictly less than the delete file's. + +TEST_F(HiveIcebergTest, positionalDeleteSequenceNumberApplied) { + folly::SingletonVault::singleton()->registrationComplete(); + + auto pathColumn = IcebergMetadataColumn::icebergDeleteFilePathColumn(); + auto posColumn = IcebergMetadataColumn::icebergDeletePosColumn(); + auto rowType = ROW({"c0"}, {BIGINT()}); + + // Write base data file: c0 = [0, 1, 2, 3, 4]. + auto dataFilePath = TempFilePath::create(); + writeToFile( + dataFilePath->getPath(), + {makeRowVector({makeFlatVector({0, 1, 2, 3, 4})})}); + + // Write positional delete file targeting positions 1 and 3. + auto deleteFilePath = TempFilePath::create(); + auto baseFilePath = dataFilePath->getPath(); + writeToFile( + deleteFilePath->getPath(), + {makeRowVector( + {pathColumn->name, posColumn->name}, + { + makeFlatVector( + 2, [&](vector_size_t) { return baseFilePath; }), + makeFlatVector({1, 3}), + })}); + + // Delete file seqNum=10 > data seqNum=5 → delete should be applied. + IcebergDeleteFile deleteFile( + FileContent::kPositionalDeletes, + deleteFilePath->getPath(), + dwio::common::FileFormat::DWRF, + 2, + getTestFileSize(deleteFilePath->getPath()), + {}, + {}, + {}, + /*dataSequenceNumber=*/10); + + auto file = filesystems::getFileSystem(baseFilePath, nullptr) + ->openFileForRead(baseFilePath); + auto split = std::make_shared( + kIcebergConnectorId, + baseFilePath, + dwio::common::FileFormat::DWRF, + 0, + file->size(), + std::unordered_map>{}, + std::nullopt, + std::unordered_map{}, + nullptr, + true, + std::vector{deleteFile}, + std::unordered_map{}, + std::nullopt, + /*dataSequenceNumber=*/5); + + auto plan = PlanBuilder() + .startTableScan() + .connectorId(kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .planNode(); + + // Positions 1 and 3 deleted → remaining: [0, 2, 4]. + auto expected = makeRowVector({makeFlatVector({0, 2, 4})}); + AssertQueryBuilder(plan).split(split).assertResults({expected}); +} + +TEST_F(HiveIcebergTest, positionalDeleteSequenceNumberSkipped) { + folly::SingletonVault::singleton()->registrationComplete(); + + auto pathColumn = IcebergMetadataColumn::icebergDeleteFilePathColumn(); + auto posColumn = IcebergMetadataColumn::icebergDeletePosColumn(); + auto rowType = ROW({"c0"}, {BIGINT()}); + + auto dataFilePath = TempFilePath::create(); + writeToFile( + dataFilePath->getPath(), + {makeRowVector({makeFlatVector({0, 1, 2, 3, 4})})}); + + auto deleteFilePath = TempFilePath::create(); + auto baseFilePath = dataFilePath->getPath(); + writeToFile( + deleteFilePath->getPath(), + {makeRowVector( + {pathColumn->name, posColumn->name}, + { + makeFlatVector( + 2, [&](vector_size_t) { return baseFilePath; }), + makeFlatVector({1, 3}), + })}); + + // Delete file seqNum=5 < data seqNum=10 → delete should be skipped. + IcebergDeleteFile deleteFile( + FileContent::kPositionalDeletes, + deleteFilePath->getPath(), + dwio::common::FileFormat::DWRF, + 2, + getTestFileSize(deleteFilePath->getPath()), + {}, + {}, + {}, + /*dataSequenceNumber=*/5); + + auto file = filesystems::getFileSystem(baseFilePath, nullptr) + ->openFileForRead(baseFilePath); + auto split = std::make_shared( + kIcebergConnectorId, + baseFilePath, + dwio::common::FileFormat::DWRF, + 0, + file->size(), + std::unordered_map>{}, + std::nullopt, + std::unordered_map{}, + nullptr, + true, + std::vector{deleteFile}, + std::unordered_map{}, + std::nullopt, + /*dataSequenceNumber=*/10); + + auto plan = PlanBuilder() + .startTableScan() + .connectorId(kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .planNode(); + + // Delete skipped → all rows survive: [0, 1, 2, 3, 4]. + auto expected = makeRowVector({makeFlatVector({0, 1, 2, 3, 4})}); + AssertQueryBuilder(plan).split(split).assertResults({expected}); +} + +// Verifies that same-snapshot positional deletes apply (deleteSeqNum == +// dataSeqNum). Per the Iceberg spec, positional deletes in the same snapshot +// SHOULD apply, so the skip condition uses strict < (not <=). +TEST_F(HiveIcebergTest, positionalDeleteSequenceNumberEqualApplied) { + folly::SingletonVault::singleton()->registrationComplete(); + + auto pathColumn = IcebergMetadataColumn::icebergDeleteFilePathColumn(); + auto posColumn = IcebergMetadataColumn::icebergDeletePosColumn(); + auto rowType = ROW({"c0"}, {BIGINT()}); + + auto dataFilePath = TempFilePath::create(); + writeToFile( + dataFilePath->getPath(), + {makeRowVector({makeFlatVector({0, 1, 2, 3, 4})})}); + + auto deleteFilePath = TempFilePath::create(); + auto baseFilePath = dataFilePath->getPath(); + writeToFile( + deleteFilePath->getPath(), + {makeRowVector( + {pathColumn->name, posColumn->name}, + { + makeFlatVector( + 2, [&](vector_size_t) { return baseFilePath; }), + makeFlatVector({1, 3}), + })}); + + // Delete file seqNum=5 == data seqNum=5 → delete should be applied + // (same-snapshot positional deletes apply per Iceberg spec). + IcebergDeleteFile deleteFile( + FileContent::kPositionalDeletes, + deleteFilePath->getPath(), + dwio::common::FileFormat::DWRF, + 2, + getTestFileSize(deleteFilePath->getPath()), + {}, + {}, + {}, + /*dataSequenceNumber=*/5); + + auto file = filesystems::getFileSystem(baseFilePath, nullptr) + ->openFileForRead(baseFilePath); + auto split = std::make_shared( + kIcebergConnectorId, + baseFilePath, + dwio::common::FileFormat::DWRF, + 0, + file->size(), + std::unordered_map>{}, + std::nullopt, + std::unordered_map{}, + nullptr, + true, + std::vector{deleteFile}, + std::unordered_map{}, + std::nullopt, + /*dataSequenceNumber=*/5); + + auto plan = PlanBuilder() + .startTableScan() + .connectorId(kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .planNode(); + + // Same-snapshot delete applied: positions 1 and 3 deleted → [0, 2, 4]. + auto expected = makeRowVector({makeFlatVector({0, 2, 4})}); + AssertQueryBuilder(plan).split(split).assertResults({expected}); +} + +TEST_F(HiveIcebergTest, positionalDeleteSequenceNumberZeroDisablesFilter) { + folly::SingletonVault::singleton()->registrationComplete(); + + auto pathColumn = IcebergMetadataColumn::icebergDeleteFilePathColumn(); + auto posColumn = IcebergMetadataColumn::icebergDeletePosColumn(); + auto rowType = ROW({"c0"}, {BIGINT()}); + + auto dataFilePath = TempFilePath::create(); + writeToFile( + dataFilePath->getPath(), + {makeRowVector({makeFlatVector({0, 1, 2, 3, 4})})}); + + auto deleteFilePath = TempFilePath::create(); + auto baseFilePath = dataFilePath->getPath(); + writeToFile( + deleteFilePath->getPath(), + {makeRowVector( + {pathColumn->name, posColumn->name}, + { + makeFlatVector( + 2, [&](vector_size_t) { return baseFilePath; }), + makeFlatVector({1, 3}), + })}); + + // Delete file seqNum=0 (unassigned/V1 legacy) → always applied regardless. + IcebergDeleteFile deleteFile( + FileContent::kPositionalDeletes, + deleteFilePath->getPath(), + dwio::common::FileFormat::DWRF, + 2, + getTestFileSize(deleteFilePath->getPath()), + {}, + {}, + {}, + /*dataSequenceNumber=*/0); + + auto file = filesystems::getFileSystem(baseFilePath, nullptr) + ->openFileForRead(baseFilePath); + auto split = std::make_shared( + kIcebergConnectorId, + baseFilePath, + dwio::common::FileFormat::DWRF, + 0, + file->size(), + std::unordered_map>{}, + std::nullopt, + std::unordered_map{}, + nullptr, + true, + std::vector{deleteFile}, + std::unordered_map{}, + std::nullopt, + /*dataSequenceNumber=*/100); + + auto plan = PlanBuilder() + .startTableScan() + .connectorId(kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .planNode(); + + // SeqNum=0 disables filtering → delete applied: [0, 2, 4]. + auto expected = makeRowVector({makeFlatVector({0, 2, 4})}); + AssertQueryBuilder(plan).split(split).assertResults({expected}); +} + } // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/tests/IcebergSplitReaderBenchmark.cpp b/velox/connectors/hive/iceberg/tests/IcebergSplitReaderBenchmark.cpp index 2c2c26297fd..e1ba2c4269f 100644 --- a/velox/connectors/hive/iceberg/tests/IcebergSplitReaderBenchmark.cpp +++ b/velox/connectors/hive/iceberg/tests/IcebergSplitReaderBenchmark.cpp @@ -17,6 +17,8 @@ #include "velox/connectors/hive/iceberg/tests/IcebergSplitReaderBenchmark.h" #include +#include "velox/connectors/hive/HiveConfig.h" + using namespace facebook::velox; using namespace facebook::velox::dwio; using namespace facebook::velox::dwio::common; @@ -99,8 +101,6 @@ IcebergSplitReaderBenchmark::makeIcebergSplit( const std::string& dataFilePath, const std::vector& deleteFiles) { std::unordered_map> partitionKeys; - std::unordered_map customSplitInfo; - customSplitInfo["table_format"] = "hive-iceberg"; auto readFile = std::make_shared(dataFilePath); const int64_t fileSize = readFile->size(); @@ -108,12 +108,12 @@ IcebergSplitReaderBenchmark::makeIcebergSplit( return std::make_shared( kHiveConnectorId, dataFilePath, - fileFomat_, + fileFormat_, 0, fileSize, partitionKeys, std::nullopt, - customSplitInfo, + std::unordered_map{}, nullptr, /*cacheable=*/true, deleteFiles); @@ -175,7 +175,7 @@ IcebergSplitReaderBenchmark::createIcebergSplitsWithPositionalDelete( IcebergDeleteFile deleteFile( FileContent::kPositionalDeletes, deleteFilePath, - fileFomat_, + fileFormat_, deleteRowsCount, testing::internal::GetFileSize( std::fopen(deleteFilePath.c_str(), "r"))); @@ -284,7 +284,6 @@ void IcebergSplitReaderBenchmark::readSingleColumn( std::make_shared( "kHiveConnectorId", "tableName", - false, std::move(filters), remainingFilterExpr, rowType); @@ -293,10 +292,11 @@ void IcebergSplitReaderBenchmark::readSingleColumn( std::make_shared(std::make_shared( std::unordered_map(), true)); const RowTypePtr readerOutputType; - const std::shared_ptr ioStats = + const std::shared_ptr ioStatistics = + std::make_shared(); + const std::shared_ptr metadataIoStatistics = std::make_shared(); - const std::shared_ptr fsStats = - std::make_shared(); + const std::shared_ptr ioStats = std::make_shared(); std::shared_ptr root = memory::memoryManager()->addRootPool( @@ -331,21 +331,24 @@ void IcebergSplitReaderBenchmark::readSingleColumn( suspender.dismiss(); uint64_t resultSize = 0; - for (std::shared_ptr split : splits) { + for (const auto& split : splits) { scanSpec->resetCachedValues(true); + auto icebergSplit = checkedPointerCast(split); std::unique_ptr icebergSplitReader = std::make_unique( - split, + icebergSplit, hiveTableHandle, nullptr, connectorQueryCtx_.get(), hiveConfig, rowType, + ioStatistics, + metadataIoStatistics, ioStats, - fsStats, &fileHandleFactory, nullptr, - scanSpec); + scanSpec, + nullptr); std::shared_ptr randomSkip; icebergSplitReader->configureReaderOptions(randomSkip); diff --git a/velox/connectors/hive/iceberg/tests/IcebergSplitReaderBenchmark.h b/velox/connectors/hive/iceberg/tests/IcebergSplitReaderBenchmark.h index 3408fa4ce83..5d96cc21005 100644 --- a/velox/connectors/hive/iceberg/tests/IcebergSplitReaderBenchmark.h +++ b/velox/connectors/hive/iceberg/tests/IcebergSplitReaderBenchmark.h @@ -16,6 +16,7 @@ #pragma once #include "velox/common/file/FileSystems.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/connectors/hive/TableHandle.h" #include "velox/connectors/hive/iceberg/IcebergDeleteFile.h" #include "velox/connectors/hive/iceberg/IcebergMetadataColumns.h" @@ -24,7 +25,6 @@ #include "velox/dwio/common/tests/utils/DataSetBuilder.h" #include "velox/dwio/dwrf/RegisterDwrfReader.h" #include "velox/dwio/dwrf/writer/Writer.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/vector/tests/utils/VectorTestBase.h" #include @@ -32,6 +32,8 @@ namespace facebook::velox::iceberg::reader::test { +using TempDirectoryPath = common::testutil::TempDirectoryPath; + constexpr uint32_t kNumRowsPerBatch = 20000; constexpr uint32_t kNumBatches = 50; constexpr uint32_t kNumRowsPerRowGroup = 10000; @@ -108,10 +110,10 @@ class IcebergSplitReaderBenchmark { private: const std::string fileName_ = "test.data"; - const std::shared_ptr fileFolder_ = - exec::test::TempDirectoryPath::create(); - const std::shared_ptr deleteFileFolder_ = - exec::test::TempDirectoryPath::create(); + const std::shared_ptr fileFolder_ = + TempDirectoryPath::create(); + const std::shared_ptr deleteFileFolder_ = + TempDirectoryPath::create(); std::unique_ptr dataSetBuilder_; std::shared_ptr rootPool_; @@ -119,7 +121,7 @@ class IcebergSplitReaderBenchmark { std::unique_ptr writer_; dwio::common::RuntimeStatistics runtimeStats_; - dwio::common::FileFormat fileFomat_{dwio::common::FileFormat::DWRF}; + dwio::common::FileFormat fileFormat_{dwio::common::FileFormat::DWRF}; const std::string kHiveConnectorId = "hive-iceberg"; }; diff --git a/velox/connectors/hive/iceberg/tests/IcebergTestBase.cpp b/velox/connectors/hive/iceberg/tests/IcebergTestBase.cpp new file mode 100644 index 00000000000..eb1312eb1a2 --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/IcebergTestBase.cpp @@ -0,0 +1,330 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/tests/IcebergTestBase.h" + +#include + +#include "velox/connectors/ConnectorRegistry.h" +#include "velox/connectors/hive/iceberg/IcebergColumnHandle.h" +#include "velox/connectors/hive/iceberg/IcebergConfig.h" +#include "velox/connectors/hive/iceberg/IcebergConnector.h" +#include "velox/connectors/hive/iceberg/IcebergDataSink.h" +#include "velox/connectors/hive/iceberg/IcebergSplit.h" +#include "velox/connectors/hive/iceberg/PartitionSpec.h" +#include "velox/expression/Expr.h" + +namespace facebook::velox::connector::hive::iceberg::test { + +const std::string kIcebergConnectorId{"test-iceberg"}; + +void IcebergTestBase::SetUp() { + HiveConnectorTestBase::SetUp(); +#ifdef VELOX_ENABLE_PARQUET + parquet::registerParquetReaderFactory(); + parquet::registerParquetWriterFactory(); +#endif + Type::registerSerDe(); + + // Register IcebergConnector. + IcebergConnectorFactory icebergFactory; + auto icebergConnector = icebergFactory.newConnector( + kIcebergConnectorId, + std::make_shared( + std::unordered_map()), + ioExecutor_.get()); + ConnectorRegistry::global().insert( + icebergConnector->connectorId(), icebergConnector); + + connectorSessionProperties_ = std::make_shared( + std::unordered_map(), true); + + hiveConfig_ = + std::make_shared(std::make_shared( + std::unordered_map())); + + icebergConfig_ = + std::make_shared(std::make_shared( + std::unordered_map{ + {IcebergConfig::kFunctionPrefixConfig, + IcebergConfig::kDefaultFunctionPrefix}})); + + setupMemoryPools(); + + fuzzerOptions_.vectorSize = 100; + fuzzerOptions_.nullRatio = 0.1; + fuzzer_ = std::make_unique(fuzzerOptions_, opPool_.get(), 1); +} + +void IcebergTestBase::TearDown() { + fuzzer_.reset(); + connectorQueryCtx_.reset(); + connectorPool_.reset(); + opPool_.reset(); + root_.reset(); + queryCtx_.reset(); + ConnectorRegistry::global().erase(kIcebergConnectorId); + HiveConnectorTestBase::TearDown(); +} + +void IcebergTestBase::setupMemoryPools() { + root_.reset(); + opPool_.reset(); + connectorPool_.reset(); + connectorQueryCtx_.reset(); + queryCtx_.reset(); + + root_ = memory::memoryManager()->addRootPool( + "IcebergTest", 1L << 30, exec::MemoryReclaimer::create()); + opPool_ = root_->addLeafChild("operator"); + connectorPool_ = + root_->addAggregateChild("connector", exec::MemoryReclaimer::create()); + + recreateConnectorQueryCtx(/*sessionTimezone=*/"", false); +} + +void IcebergTestBase::recreateConnectorQueryCtx( + const std::string& sessionTimezone, + bool adjustTimestampToTimezone) { + connectorQueryCtx_.reset(); + queryCtx_.reset(); + + queryCtx_ = core::QueryCtx::create(nullptr, core::QueryConfig({})); + auto expressionEvaluator = std::make_unique( + queryCtx_.get(), opPool_.get()); + + connectorQueryCtx_ = std::make_unique( + opPool_.get(), + connectorPool_.get(), + connectorSessionProperties_.get(), + nullptr, + common::PrefixSortConfig(), + std::move(expressionEvaluator), + nullptr, + "query.IcebergTest", + "task.IcebergTest", + "planNodeId.IcebergTest", + 0, + sessionTimezone, + adjustTimestampToTimezone); +} + +std::vector IcebergTestBase::createTestData( + RowTypePtr rowType, + int32_t numBatches, + vector_size_t rowsPerBatch, + double nullRatio) { + std::vector vectors; + vectors.reserve(numBatches); + + fuzzerOptions_.nullRatio = nullRatio; + fuzzerOptions_.allowDictionaryVector = false; + fuzzerOptions_.timestampPrecision = + fuzzer::FuzzerTimestampPrecision::kMilliSeconds; + fuzzer_->setOptions(fuzzerOptions_); + + for (auto i = 0; i < numBatches; ++i) { + vectors.push_back(fuzzer_->fuzzRow(rowType, rowsPerBatch, false)); + } + + return vectors; +} + +void IcebergTestBase::setConnectorSessionProperty( + const std::string& key, + const std::string& value) { + VELOX_CHECK_NOT_NULL(connectorSessionProperties_); + connectorSessionProperties_->set(key, value); +} + +std::shared_ptr IcebergTestBase::createPartitionSpec( + const RowTypePtr& rowType, + const std::vector& partitionFields) { + std::vector fields; + for (const auto& partitionField : partitionFields) { + fields.push_back( + IcebergPartitionSpec::Field{ + rowType->nameOf(partitionField.id), + rowType->childAt(partitionField.id), + partitionField.type, + partitionField.parameter}); + } + + return fields.empty() ? nullptr + : std::make_shared(1, fields); +} + +namespace { + +parquet::ParquetFieldId makeField(const TypePtr& type, int32_t& fieldId) { + const int32_t currentId = fieldId++; + std::vector children; + children.reserve(type->size()); + for (auto i = 0; i < type->size(); ++i) { + children.push_back(makeField(type->childAt(i), fieldId)); + } + return parquet::ParquetFieldId{currentId, children}; +} + +void addColumnHandles( + const RowTypePtr& rowType, + const std::vector& partitionFields, + std::vector& columnHandles) { + std::unordered_set partitionColumnIds; + for (const auto& field : partitionFields) { + partitionColumnIds.insert(field.id); + } + + int32_t fieldId = 1; + columnHandles.reserve(rowType->size()); + for (auto i = 0; i < rowType->size(); ++i) { + const auto& columnName = rowType->nameOf(i); + const auto& type = rowType->childAt(i); + auto field = makeField(type, fieldId); + columnHandles.push_back( + std::make_shared( + columnName, + partitionColumnIds.contains(i) + ? FileColumnHandle::ColumnType::kPartitionKey + : FileColumnHandle::ColumnType::kRegular, + type, + field)); + } +} + +} // namespace + +IcebergInsertTableHandlePtr IcebergTestBase::createInsertTableHandle( + const RowTypePtr& rowType, + const std::string& outputDirectoryPath, + const std::vector& partitionFields) { + std::vector columnHandles; + addColumnHandles(rowType, partitionFields, columnHandles); + + auto locationHandle = std::make_shared( + outputDirectoryPath, + outputDirectoryPath, + LocationHandle::TableType::kNew); + + auto partitionSpec = createPartitionSpec(rowType, partitionFields); + + return std::make_shared( + /*inputColumns=*/columnHandles, + locationHandle, + /*tableStorageFormat=*/fileFormat_, + partitionSpec, + /*compressionKind=*/common::CompressionKind::CompressionKind_ZSTD); +} + +std::shared_ptr IcebergTestBase::createDataSink( + const RowTypePtr& rowType, + const std::string& outputDirectoryPath, + const std::vector& partitionFields) { + auto tableHandle = + createInsertTableHandle(rowType, outputDirectoryPath, partitionFields); + return std::make_shared( + rowType, + tableHandle, + connectorQueryCtx_.get(), + CommitStrategy::kNoCommit, + hiveConfig_, + icebergConfig_); +} + +std::shared_ptr IcebergTestBase::createDataSinkAndAppendData( + const std::vector& vectors, + const std::string& dataPath, + const std::vector& partitionFields) { + VELOX_CHECK(!vectors.empty(), "vectors cannot be empty"); + + auto rowType = vectors.front()->rowType(); + auto dataSink = createDataSink(rowType, dataPath, partitionFields); + + for (const auto& vector : vectors) { + dataSink->appendData(vector); + } + EXPECT_TRUE(dataSink->finish()); + return dataSink; +} + +std::vector IcebergTestBase::listFiles( + const std::string& dirPath) { + std::vector files; + if (!std::filesystem::exists(dirPath)) { + return files; + } + + for (auto& dirEntry : + std::filesystem::recursive_directory_iterator(dirPath)) { + if (dirEntry.is_regular_file()) { + files.push_back(dirEntry.path().string()); + } + } + return files; +} + +std::unordered_map> +IcebergTestBase::extractPartitionKeys(const std::string& filePath) { + std::unordered_map> partitionKeys; + + std::vector pathComponents; + folly::split("/", filePath, pathComponents); + for (const auto& component : pathComponents) { + if (component.find('=') != std::string::npos) { + std::vector keys; + folly::split('=', component, keys); + if (keys.size() == 2) { + if (keys[1] == "null") { + partitionKeys[keys[0]] = std::nullopt; + } else { + partitionKeys[keys[0]] = keys[1]; + } + } + } + } + + return partitionKeys; +} + +std::vector> +IcebergTestBase::createSplitsForDirectory(const std::string& directory) { + std::vector> splits; + + auto files = listFiles(directory); + for (const auto& filePath : files) { + auto partitionKeys = extractPartitionKeys(filePath); + + const auto file = filesystems::getFileSystem(filePath, nullptr) + ->openFileForRead(filePath); + splits.push_back( + std::make_shared( + kIcebergConnectorId, + filePath, + fileFormat_, + 0, + file->size(), + partitionKeys, + std::nullopt, + std::unordered_map{}, + nullptr, + /*cacheable=*/true, + std::vector())); + } + + return splits; +} + +} // namespace facebook::velox::connector::hive::iceberg::test diff --git a/velox/connectors/hive/iceberg/tests/IcebergTestBase.h b/velox/connectors/hive/iceberg/tests/IcebergTestBase.h new file mode 100644 index 00000000000..b3eefbd4726 --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/IcebergTestBase.h @@ -0,0 +1,122 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#include "velox/common/testutil/TempDirectoryPath.h" +#include "velox/connectors/hive/iceberg/IcebergConfig.h" +#include "velox/connectors/hive/iceberg/IcebergDataSink.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/vector/fuzzer/VectorFuzzer.h" +#ifdef VELOX_ENABLE_PARQUET +#include "velox/dwio/parquet/RegisterParquetWriter.h" +#include "velox/dwio/parquet/reader/ParquetReader.h" +#endif + +namespace facebook::velox::connector::hive::iceberg::test { + +using TempDirectoryPath = common::testutil::TempDirectoryPath; + +extern const std::string kIcebergConnectorId; + +struct PartitionField { + // 0-based column index. + int32_t id; + TransformType type; + std::optional parameter; +}; + +class IcebergTestBase : public exec::test::HiveConnectorTestBase { + protected: + void SetUp() override; + + void TearDown() override; + + std::vector createTestData( + RowTypePtr rowType, + int32_t numBatches, + vector_size_t rowsPerBatch, + double nullRatio = 0.0); + + std::shared_ptr createDataSink( + const RowTypePtr& rowType, + const std::string& outputDirectoryPath, + const std::vector& partitionFields = {}); + + std::shared_ptr createDataSinkAndAppendData( + const std::vector& vectors, + const std::string& dataPath, + const std::vector& partitionFields = {}); + + std::vector> createSplitsForDirectory( + const std::string& directory); + + std::vector listFiles(const std::string& dirPath); + + std::shared_ptr createPartitionSpec( + const RowTypePtr& rowType, + const std::vector& partitionFields); + + void setConnectorSessionProperty( + const std::string& key, + const std::string& value); + + /// Recreates the connector query context with the given session timezone + /// and timestamp-adjustment flag. Tests use this to exercise non-UTC + /// session configurations and verify timezone-sensitive behavior. + void recreateConnectorQueryCtx( + const std::string& sessionTimezone, + bool adjustTimestampToTimezone); + + /// Extracts partition key-value pairs from a file path. + /// Returns a map where keys are partition column names and values are + /// partition values (std::nullopt for null values). + /// Example: "/path/to/c1=10/c2=null/file.parquet" returns + /// {{"c1", "10"}, {"c2", std::nullopt}}. + static std::unordered_map> + extractPartitionKeys(const std::string& filePath); + + dwio::common::FileFormat fileFormat_{dwio::common::FileFormat::PARQUET}; + std::shared_ptr opPool_; + std::unique_ptr connectorQueryCtx_; + + private: + IcebergInsertTableHandlePtr createInsertTableHandle( + const RowTypePtr& rowType, + const std::string& outputDirectoryPath, + const std::vector& partitionFields = {}); + + std::vector listPartitionDirectories( + const std::string& dataPath); + + void setupMemoryPools(); + + std::shared_ptr root_; + std::shared_ptr connectorPool_; + std::shared_ptr connectorSessionProperties_; + std::shared_ptr hiveConfig_; + std::shared_ptr icebergConfig_; + VectorFuzzer::Options fuzzerOptions_; + std::unique_ptr fuzzer_; + std::shared_ptr queryCtx_; +}; + +} // namespace facebook::velox::connector::hive::iceberg::test diff --git a/velox/connectors/hive/iceberg/tests/Main.cpp b/velox/connectors/hive/iceberg/tests/Main.cpp new file mode 100644 index 00000000000..3c9dd661505 --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/Main.cpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/process/ThreadDebugInfo.h" + +#include +#include + +// This main is needed for some tests on linux. +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + // Signal handler required for ThreadDebugInfoTest + facebook::velox::process::addDefaultFatalSignalHandler(); + folly::Init init(&argc, &argv, false); + return RUN_ALL_TESTS(); +} diff --git a/velox/connectors/hive/iceberg/tests/PartitionNameTest.cpp b/velox/connectors/hive/iceberg/tests/PartitionNameTest.cpp new file mode 100644 index 00000000000..8e9bafbe453 --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/PartitionNameTest.cpp @@ -0,0 +1,481 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "velox/common/encode/Base64.h" +#include "velox/connectors/hive/iceberg/IcebergConfig.h" +#include "velox/connectors/hive/iceberg/IcebergPartitionName.h" +#include "velox/connectors/hive/iceberg/TransformEvaluator.h" +#include "velox/connectors/hive/iceberg/TransformExprBuilder.h" +#include "velox/connectors/hive/iceberg/tests/IcebergTestBase.h" + +namespace facebook::velox::connector::hive::iceberg { + +using namespace facebook::velox; + +namespace { + +class PartitionNameTest : public test::IcebergTestBase { + protected: + // Generates partition IDs for the input rows and verifies that the resulting + // partition paths match the expected paths. Each row is processed + // independently, and its generated partition path is compared against the + // corresponding entry in expectedPaths. + // + // @param input Input data to generate partition IDs from. Must have the + // same size as expectedPaths. + // @param partitionSpec The IcebergPartitionSpec defining the partition + // transforms. The partition channels are determined by matching field names + // from the spec to column names in the input's type. + // @param expectedPaths Expected partition path strings, one per row. Each + // path should be the complete partition directory name (e.g., "col1=val1"). + void verifyPartitionPaths( + const RowVectorPtr& input, + const std::shared_ptr& partitionSpec, + const std::vector& expectedPaths) const { + ASSERT_EQ(expectedPaths.size(), input->size()); + std::vector partitionChannels(partitionSpec->fields.size()); + auto rowType = input->rowType(); + for (auto i = 0; i < partitionSpec->fields.size(); ++i) { + partitionChannels[i] = + rowType->getChildIdx(partitionSpec->fields[i].name); + } + + // Step 1: Build transform expressions and create evaluator. + auto transformExpressions = TransformExprBuilder::toExpressions( + partitionSpec, + partitionChannels, + rowType, + std::string(IcebergConfig::kDefaultFunctionPrefix)); + auto transformEvaluator = std::make_unique( + transformExpressions, connectorQueryCtx_.get()); + + // Step 2: Apply transforms to input partition columns. + auto transformedColumns = transformEvaluator->evaluate(input); + + std::vector partitionKeyTypes; + std::vector partitionKeyNames; + for (const auto& field : partitionSpec->fields) { + partitionKeyTypes.emplace_back(field.resultType()); + std::string key = field.transformType == TransformType::kIdentity + ? field.name + : fmt::format( + "{}_{}", + field.name, + TransformTypeName::toName(field.transformType)); + partitionKeyNames.emplace_back(std::move(key)); + } + + auto partitionRowType = + ROW(std::move(partitionKeyNames), std::move(partitionKeyTypes)); + // Step 3: Create RowVector based on transformed columns. + auto transformedRowVector = std::make_shared( + connectorQueryCtx_->memoryPool(), + partitionRowType, + nullptr, + input->size(), + std::move(transformedColumns)); + + // Step 4: Generate partition IDs from transformed data. + // The transformed row vector has columns in the same order as partition + // spec fields, so channels are sequential: 0, 1, 2, ... + std::vector transformedChannels( + partitionSpec->fields.size()); + std::iota(transformedChannels.begin(), transformedChannels.end(), 0); + + auto idGenerator = std::make_unique( + partitionRowType, + transformedChannels, + /*maxPartitions=*/128, + connectorQueryCtx_->memoryPool()); + + auto nameGenerator = std::make_unique(partitionSpec); + + raw_vector partitionIds(input->size()); + idGenerator->run(transformedRowVector, partitionIds); + + for (auto i = 0; i < input->size(); ++i) { + std::string partitionName = nameGenerator->partitionName( + partitionIds[i], idGenerator->partitionValues(), false); + ASSERT_EQ(partitionName, expectedPaths[i]); + } + } +}; + +TEST_F(PartitionNameTest, identity) { + std::vector> input = { + {INTEGER(), makeConstant(42, 1), "42"}, + {BIGINT(), makeConstant(9'876'543'210, 1), "9876543210"}, + {VARCHAR(), + makeConstant("test string partition column name", 1), + "test+string+partition+column+name"}, + {VARBINARY(), + makeConstant("\x48\x65\x6c\x6c\x6f", 1, VARBINARY()), + "SGVsbG8%3D"}, + {DECIMAL(18, 4), + makeConstant(12'345'678'901'234, 1, DECIMAL(18, 4)), + "1234567890.1234"}, + {BOOLEAN(), makeConstant(true, 1), "true"}, + {DATE(), makeConstant(18'262, 1, DATE()), "2020-01-01"}, + }; + + for (const auto& [type, value, expectedValue] : input) { + const auto& partitionSpec = createPartitionSpec( + ROW({"c0"}, {type}), {{0, TransformType::kIdentity, std::nullopt}}); + + verifyPartitionPaths( + makeRowVector({value}), + partitionSpec, + {fmt::format("c0={}", expectedValue)}); + } +} + +TEST_F(PartitionNameTest, timestamp) { + std::vector timestamps = { + Timestamp(253402300800, 100000000), // +10000-01-01T00:00:00.1. + Timestamp(-62170000000, 0), // -0001-11-29T19:33:20. + Timestamp(-62135577748, 999000000), // 0001-01-01T05:17:32.999. + Timestamp(0, 0), // 1970-01-01T00:00. + Timestamp(1609459200, 999000000), // 2021-01-01T00:00. + Timestamp(1640995200, 500000000), // 2022-01-01T00:00:00.5. + Timestamp(1672531200, 123000000), // 2023-01-01T00:00:00.123. + Timestamp(-1, 999000000), // 1969-12-31T23:59:59.999. + Timestamp(1, 1000000), // 1970-01-01T00:00:01.001. + Timestamp(-62167219199, 0), // 0000-01-01T00:00:01. + Timestamp(-377716279140, 321000000), // -10000-01-01T01:01:00.321. + Timestamp(253402304660, 321000000), // +10000-01-01T01:01:00.321. + Timestamp(951782400, 0), // 2000-02-29T00:00:00 (leap year). + Timestamp(4107456000, 0), // 2100-02-28T00:00:00. + Timestamp(-86400, 0), // 1969-12-31T00:00:00. + }; + + std::vector expectedPartitionNames = { + "c0=%2B10000-01-01T00%3A00%3A00.1", + "c0=-0001-11-29T19%3A33%3A20", + "c0=0001-01-01T05%3A17%3A32.999", + "c0=1970-01-01T00%3A00%3A00", + "c0=2021-01-01T00%3A00%3A00.999", + "c0=2022-01-01T00%3A00%3A00.5", + "c0=2023-01-01T00%3A00%3A00.123", + "c0=1969-12-31T23%3A59%3A59.999", + "c0=1970-01-01T00%3A00%3A01.001", + "c0=0000-01-01T00%3A00%3A01", + "c0=-10000-08-24T19%3A21%3A00.321", + "c0=%2B10000-01-01T01%3A04%3A20.321", + "c0=2000-02-29T00%3A00%3A00", + "c0=2100-02-28T00%3A00%3A00", + "c0=1969-12-31T00%3A00%3A00", + }; + + const auto& partitionSpec = createPartitionSpec( + ROW({"c0"}, {TIMESTAMP()}), + {{0, TransformType::kIdentity, std::nullopt}}); + + verifyPartitionPaths( + makeRowVector({makeFlatVector(timestamps)}), + partitionSpec, + expectedPartitionNames); +} + +TEST_F(PartitionNameTest, null) { + std::vector< + std::tuple, VectorPtr>> + input = { + {INTEGER(), + TransformType::kBucket, + 32, + makeConstant(std::nullopt, 1)}, + {VARCHAR(), + TransformType::kTruncate, + 100, + makeConstant(std::nullopt, 1)}, + {DECIMAL(18, 3), + TransformType::kIdentity, + std::nullopt, + makeConstant(std::nullopt, 1, DECIMAL(18, 3))}, + {TIMESTAMP(), + TransformType::kYear, + std::nullopt, + makeConstant(std::nullopt, 1)}, + {TIMESTAMP(), + TransformType::kMonth, + std::nullopt, + makeConstant(std::nullopt, 1)}, + {DATE(), + TransformType::kDay, + std::nullopt, + makeConstant(std::nullopt, 1, DATE())}, + {TIMESTAMP(), + TransformType::kHour, + std::nullopt, + makeConstant(std::nullopt, 1)}, + }; + + for (const auto& [type, transformType, parameter, value] : input) { + auto rowType = ROW({"c0"}, {type}); + const auto& partitionSpec = + createPartitionSpec(rowType, {{0, transformType, parameter}}); + if (transformType == TransformType::kIdentity) { + verifyPartitionPaths(makeRowVector({value}), partitionSpec, {"c0=null"}); + } else { + verifyPartitionPaths( + makeRowVector({value}), + partitionSpec, + {fmt::format( + "c0_{}=null", TransformTypeName::toName(transformType))}); + } + } +} + +// test both partition column name and partition key encoding. +TEST_F(PartitionNameTest, specialChars) { + std::vector> inputs = { + {"abc123", "abc123"}, + {"ABC123", "ABC123"}, + {"a.b-c_d*e", "a.b-c_d*e"}, + {"space test", "space+test"}, + {"slash/test", "slash%2Ftest"}, + {"question?test", "question%3Ftest"}, + {"percent%test", "percent%25test"}, + {"hash#test", "hash%23test"}, + {"ampersand&test", "ampersand%26test"}, + {"equals=test", "equals%3Dtest"}, + {"plus+test", "plus%2Btest"}, + {"comma,test", "comma%2Ctest"}, + {"semicolon;test", "semicolon%3Btest"}, + {"at@test", "at%40test"}, + {"exclamation!test", "exclamation%21test"}, + {"dollar$test", "dollar%24test"}, + {"backslash\\test", "backslash%5Ctest"}, + {"quote\"test", "quote%22test"}, + {"apostrophe'test", "apostrophe%27test"}, + {"paren(test", "paren%28test"}, + {"paren)test", "paren%29test"}, + {"lessthan", "greater%3Ethan"}, + {"colon:test", "colon%3Atest"}, + {"pipe|test", "pipe%7Ctest"}, + {"bracket[test", "bracket%5Btest"}, + {"bracket]test", "bracket%5Dtest"}, + {"brace{test", "brace%7Btest"}, + {"brace}test", "brace%7Dtest"}, + {"caret^test", "caret%5Etest"}, + {"tilde~test", "tilde%7Etest"}, + {"backtick`test", "backtick%60test"}, + {"newline\ntest", "newline%0Atest"}, + {"carriage\rreturn", "carriage%0Dreturn"}, + {"tab\ttest", "tab%09test"}, + {"unicode\u00A9test", "unicode%C2%A9test"}, + {"email@example.com", "email%40example.com"}, + {"user:password@host:port/path", "user%3Apassword%40host%3Aport%2Fpath"}, + {"https://github.com/facebookincubator/velox", + "https%3A%2F%2Fgithub.com%2Ffacebookincubator%2Fvelox"}, + {"a+b=c&d=e+f", "a%2Bb%3Dc%26d%3De%2Bf"}, + {"a#b=c/d e", "a%23b%3Dc%2Fd+e"}, + {"special!@#$%^&*()_+", "special%21%40%23%24%25%5E%26*%28%29_%2B"}, + }; + + for (const auto& [input, encodedValue] : inputs) { + const auto& partitionSpec = createPartitionSpec( + ROW({input}, {VARCHAR()}), + {{0, TransformType::kIdentity, std::nullopt}}); + + verifyPartitionPaths( + makeRowVector( + {input}, {makeConstant(StringView(input), 1)}), + partitionSpec, + {fmt::format("{}={}", encodedValue, encodedValue)}); + } +} + +TEST_F(PartitionNameTest, multipleRows) { + const auto& partitionSpec = createPartitionSpec( + ROW({"c0", "c1"}, {INTEGER(), VARCHAR()}), + { + {0, TransformType::kBucket, 8}, + {1, TransformType::kIdentity, std::nullopt}, + }); + + verifyPartitionPaths( + makeRowVector({ + makeFlatVector({10, 20, 30, -100}), + makeFlatVector({"value1", "VALue2", "VALUE3", ""}), + }), + partitionSpec, + { + "c0_bucket=4/c1=value1", + "c0_bucket=3/c1=VALue2", + "c0_bucket=3/c1=VALUE3", + "c0_bucket=6/c1=", + }); +} + +TEST_F(PartitionNameTest, year) { + const auto& partitionSpec = createPartitionSpec( + ROW({"c0"}, {TIMESTAMP()}), {{0, TransformType::kYear, std::nullopt}}); + + std::vector timestamps = { + Timestamp(0, 0), + Timestamp(1609459200, 0), + Timestamp(1640995200, 0), + Timestamp(-31536000, 0), + Timestamp(253402300800, 0), + }; + + verifyPartitionPaths( + makeRowVector({makeFlatVector(timestamps)}), + partitionSpec, + { + "c0_year=1970", + "c0_year=2021", + "c0_year=2022", + "c0_year=1969", + "c0_year=10000", + }); +} + +TEST_F(PartitionNameTest, yearWithDate) { + const auto& partitionSpec = createPartitionSpec( + ROW({"c0"}, {DATE()}), {{0, TransformType::kYear, std::nullopt}}); + + verifyPartitionPaths( + makeRowVector({makeFlatVector({0, 365, 18262, -365}, DATE())}), + partitionSpec, + { + "c0_year=1970", + "c0_year=1971", + "c0_year=2020", + "c0_year=1969", + }); +} + +TEST_F(PartitionNameTest, month) { + const auto& partitionSpec = createPartitionSpec( + ROW({"c0"}, {TIMESTAMP()}), {{0, TransformType::kMonth, std::nullopt}}); + + verifyPartitionPaths( + makeRowVector({makeFlatVector({ + Timestamp(0, 0), + Timestamp(2678400, 0), + Timestamp(1609459200, 0), + Timestamp(1640995200, 0), + Timestamp(-2678400, 0), + })}), + partitionSpec, + { + "c0_month=1970-01", + "c0_month=1970-02", + "c0_month=2021-01", + "c0_month=2022-01", + "c0_month=1969-12", + }); +} + +TEST_F(PartitionNameTest, monthWithDate) { + const auto& partitionSpec = createPartitionSpec( + ROW({"c0"}, {DATE()}), {{0, TransformType::kMonth, std::nullopt}}); + + verifyPartitionPaths( + makeRowVector({makeFlatVector({0, 31, 365, -31}, DATE())}), + partitionSpec, + { + "c0_month=1970-01", + "c0_month=1970-02", + "c0_month=1971-01", + "c0_month=1969-12", + }); +} + +TEST_F(PartitionNameTest, day) { + const auto& partitionSpec = createPartitionSpec( + ROW({"c0"}, {TIMESTAMP()}), {{0, TransformType::kDay, std::nullopt}}); + + std::vector timestamps = { + Timestamp(0, 0), + Timestamp(86400, 0), + Timestamp(1577836800, 0), + Timestamp(-86400, 0), + }; + + verifyPartitionPaths( + makeRowVector({makeFlatVector(timestamps)}), + partitionSpec, + { + "c0_day=1970-01-01", + "c0_day=1970-01-02", + "c0_day=2020-01-01", + "c0_day=1969-12-31", + }); +} + +TEST_F(PartitionNameTest, dayWithDate) { + const auto& partitionSpec = createPartitionSpec( + ROW({"c0"}, {DATE()}), {{0, TransformType::kDay, std::nullopt}}); + + verifyPartitionPaths( + makeRowVector({makeFlatVector({0, 1, 18262, -1}, DATE())}), + partitionSpec, + { + "c0_day=1970-01-01", + "c0_day=1970-01-02", + "c0_day=2020-01-01", + "c0_day=1969-12-31", + }); +} + +TEST_F(PartitionNameTest, hour) { + const auto& partitionSpec = createPartitionSpec( + ROW({"c0"}, {TIMESTAMP()}), {{0, TransformType::kHour, std::nullopt}}); + + std::vector timestamps = { + Timestamp(0, 0), + Timestamp(3600, 0), + Timestamp(86400, 0), + Timestamp(1577836800, 0), + Timestamp(-3600, 0), + }; + + verifyPartitionPaths( + makeRowVector({makeFlatVector(timestamps)}), + partitionSpec, + { + "c0_hour=1970-01-01-00", + "c0_hour=1970-01-01-01", + "c0_hour=1970-01-02-00", + "c0_hour=2020-01-01-00", + "c0_hour=1969-12-31-23", + }); +} + +TEST_F(PartitionNameTest, multipleTransformsSameColumn) { + const auto& partitionSpec = createPartitionSpec( + ROW({"c0"}, {TIMESTAMP()}), + { + {0, TransformType::kIdentity, std::nullopt}, + {0, TransformType::kYear, std::nullopt}, + {0, TransformType::kBucket, 10}, + }); + + verifyPartitionPaths( + makeRowVector({makeFlatVector({Timestamp(1609459200, 0)})}), + partitionSpec, + {"c0=2021-01-01T00%3A00%3A00/c0_year=2021/c0_bucket=0"}); +} + +} // namespace + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/tests/PartitionSpecTest.cpp b/velox/connectors/hive/iceberg/tests/PartitionSpecTest.cpp new file mode 100644 index 00000000000..23a06340aad --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/PartitionSpecTest.cpp @@ -0,0 +1,134 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/PartitionSpec.h" + +#include +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" +#include "velox/type/Type.h" + +namespace facebook::velox::connector::hive::iceberg { + +namespace { + +TEST(PartitionSpecTest, invalidColumnType) { + auto makeSpec = [](const TypePtr& type) { + std::vector fields = { + {"c0", type, TransformType::kIdentity, std::nullopt}, + }; + return std::make_shared(1, fields); + }; + + VELOX_ASSERT_USER_THROW( + makeSpec(ROW({{"a", INTEGER()}})), + "Type is not supported as a partition column: ROW"); + VELOX_ASSERT_USER_THROW( + makeSpec(ARRAY(INTEGER())), + "Type is not supported as a partition column: ARRAY"); + VELOX_ASSERT_USER_THROW( + makeSpec(MAP(VARCHAR(), INTEGER())), + "Type is not supported as a partition column: MAP"); + VELOX_ASSERT_USER_THROW( + makeSpec(TIMESTAMP_WITH_TIME_ZONE()), + "Type is not supported as a partition column: TIMESTAMP WITH TIME ZONE"); +} + +TEST(PartitionSpecTest, invalidMultipleTransforms) { + { + std::vector fields = { + {"c0", VARCHAR(), TransformType::kIdentity, std::nullopt}, + {"c0", VARCHAR(), TransformType::kIdentity, std::nullopt}, + }; + VELOX_ASSERT_USER_THROW( + std::make_shared(1, fields), + "Column: 'c0', Category: Identity, Transforms: [identity, identity]"); + } + + { + std::vector fields = { + {"c0", VARCHAR(), TransformType::kBucket, 16}, + {"c0", VARCHAR(), TransformType::kBucket, 32}, + }; + VELOX_ASSERT_USER_THROW( + std::make_shared(1, fields), + "Column: 'c0', Category: Bucket, Transforms: [bucket, bucket]"); + } + + { + std::vector fields = { + {"c0", VARCHAR(), TransformType::kTruncate, 2}, + {"c0", VARCHAR(), TransformType::kTruncate, 5}, + }; + VELOX_ASSERT_USER_THROW( + std::make_shared(1, fields), + "Column: 'c0', Category: Truncate, Transforms: [trunc, trunc]"); + } + + { + std::vector fields4 = { + {"c0", TIMESTAMP(), TransformType::kYear, std::nullopt}, + {"c0", TIMESTAMP(), TransformType::kMonth, std::nullopt}, + {"c0", TIMESTAMP(), TransformType::kDay, std::nullopt}, + {"c0", TIMESTAMP(), TransformType::kHour, std::nullopt}, + }; + VELOX_ASSERT_USER_THROW( + std::make_shared(1, fields4), + "Column: 'c0', Category: Temporal, Transforms: [year, month, day, hour]"); + } +} + +TEST(PartitionSpecTest, invalidMultipleTransformsMultipleColumns) { + std::vector fields = { + {"c0", DATE(), TransformType::kYear, std::nullopt}, + {"c0", DATE(), TransformType::kMonth, std::nullopt}, + {"c1", VARCHAR(), TransformType::kBucket, 16}, + {"c1", VARCHAR(), TransformType::kBucket, 32}, + }; + // order may vary due to map iteration. + VELOX_ASSERT_USER_THROW( + std::make_shared(1, fields), + "Column: 'c0', Category: Temporal, Transforms: [year, month]"); + VELOX_ASSERT_USER_THROW( + std::make_shared(1, fields), + "Column: 'c1', Category: Bucket, Transforms: [bucket, bucket]"); +} + +TEST(PartitionSpecTest, validMultipleTransforms) { + { + std::vector fields = { + {"c0", VARCHAR(), TransformType::kIdentity, std::nullopt}, + {"c0", VARCHAR(), TransformType::kBucket, 16}, + {"c0", VARCHAR(), TransformType::kTruncate, 10}, + }; + auto spec = std::make_shared(1, fields); + EXPECT_EQ(spec->fields.size(), 3); + } + + { + std::vector fields = { + {"c0", DATE(), TransformType::kYear, std::nullopt}, + {"c0", DATE(), TransformType::kBucket, 16}, + {"c0", DATE(), TransformType::kIdentity, std::nullopt}, + }; + auto spec = std::make_shared(1, fields); + EXPECT_EQ(spec->fields.size(), 3); + } +} + +} // namespace + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/tests/PartitionValueFormatterTest.cpp b/velox/connectors/hive/iceberg/tests/PartitionValueFormatterTest.cpp new file mode 100644 index 00000000000..13cda035cbf --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/PartitionValueFormatterTest.cpp @@ -0,0 +1,145 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "velox/connectors/hive/iceberg/IcebergPartitionName.h" +#include "velox/type/Type.h" + +namespace facebook::velox::connector::hive::iceberg { + +namespace { + +template +std::string toPath(TransformType transform, T value, const TypePtr& type) { + return IcebergPartitionName::toName(value, type, transform); +} + +std::string timestampToPath(const Timestamp& timestamp) { + return toPath(TransformType::kIdentity, timestamp, TIMESTAMP()); +} + +std::string testString( + const std::string& value, + const TypePtr& typePtr = VARCHAR()) { + auto identityResult = + toPath(TransformType::kIdentity, StringView(value), typePtr); + auto truncateResult = + toPath(TransformType::kTruncate, StringView(value), typePtr); + EXPECT_EQ(identityResult, truncateResult); + return identityResult; +} + +std::string testVarbinary(const std::string& value) { + return testString(value, VARBINARY()); +} + +std::string testInteger(int32_t value) { + auto identityResult = toPath(TransformType::kIdentity, value, INTEGER()); + auto bucketResult = toPath(TransformType::kBucket, value, INTEGER()); + auto truncResult = toPath(TransformType::kTruncate, value, INTEGER()); + EXPECT_EQ(identityResult, truncResult); + EXPECT_EQ(bucketResult, truncResult); + return truncResult; +} + +TEST(IcebergPartitionPathTest, integer) { + EXPECT_EQ(testInteger(0), "0"); + EXPECT_EQ(testInteger(1), "1"); + EXPECT_EQ(testInteger(100), "100"); + EXPECT_EQ(testInteger(-100), "-100"); + EXPECT_EQ(testInteger(128), "128"); + EXPECT_EQ(testInteger(1024), "1024"); +} + +TEST(IcebergPartitionPathTest, date) { + EXPECT_EQ(toPath(TransformType::kIdentity, 18'262, DATE()), "2020-01-01"); + EXPECT_EQ(toPath(TransformType::kIdentity, 0, DATE()), "1970-01-01"); + EXPECT_EQ(toPath(TransformType::kIdentity, -1, DATE()), "1969-12-31"); + EXPECT_EQ(toPath(TransformType::kIdentity, 2'932'897, DATE()), "10000-01-01"); +} + +TEST(IcebergPartitionPathTest, boolean) { + EXPECT_EQ(toPath(TransformType::kIdentity, true, BOOLEAN()), "true"); + EXPECT_EQ(toPath(TransformType::kIdentity, false, BOOLEAN()), "false"); +} + +TEST(IcebergPartitionPathTest, string) { + EXPECT_EQ(testString("a/b/c=d"), "a/b/c=d"); + EXPECT_EQ(testString(""), ""); + EXPECT_EQ(testString("abc"), "abc"); +} + +TEST(IcebergPartitionPathTest, varbinary) { + EXPECT_EQ(testVarbinary("\x48\x65\x6c\x6c\x6f"), "SGVsbG8="); + EXPECT_EQ(testVarbinary("\x1\x2\x3"), "AQID"); + EXPECT_EQ(testVarbinary(""), ""); +} + +TEST(IcebergPartitionPathTest, timestamp) { + EXPECT_EQ(timestampToPath(Timestamp(0, 0)), "1970-01-01T00:00:00"); + EXPECT_EQ( + timestampToPath(Timestamp(1'609'459'200, 999'000'000)), + "2021-01-01T00:00:00.999"); + EXPECT_EQ( + timestampToPath(Timestamp(1'640'995'200, 500'000'000)), + "2022-01-01T00:00:00.5"); + EXPECT_EQ( + timestampToPath(Timestamp(-1, 999'000'000)), "1969-12-31T23:59:59.999"); + EXPECT_EQ( + timestampToPath(Timestamp(253'402'300'800, 100'000'000)), + "+10000-01-01T00:00:00.1"); + EXPECT_EQ( + timestampToPath(Timestamp(-62'170'000'000, 0)), "-0001-11-29T19:33:20"); + EXPECT_EQ( + timestampToPath(Timestamp(-62'167'219'199, 0)), "0000-01-01T00:00:01"); +} + +TEST(IcebergPartitionPathTest, year) { + EXPECT_EQ(toPath(TransformType::kYear, 0, INTEGER()), "1970"); + EXPECT_EQ(toPath(TransformType::kYear, 1, INTEGER()), "1971"); + EXPECT_EQ(toPath(TransformType::kYear, 8'030, INTEGER()), "10000"); + EXPECT_EQ(toPath(TransformType::kYear, -1, INTEGER()), "1969"); + EXPECT_EQ(toPath(TransformType::kYear, -50, INTEGER()), "1920"); +} + +TEST(IcebergPartitionPathTest, month) { + EXPECT_EQ(toPath(TransformType::kMonth, 0, INTEGER()), "1970-01"); + EXPECT_EQ(toPath(TransformType::kMonth, 1, INTEGER()), "1970-02"); + EXPECT_EQ(toPath(TransformType::kMonth, 11, INTEGER()), "1970-12"); + EXPECT_EQ(toPath(TransformType::kMonth, 612, INTEGER()), "2021-01"); + EXPECT_EQ(toPath(TransformType::kMonth, -1, INTEGER()), "1969-12"); + EXPECT_EQ(toPath(TransformType::kMonth, -13, INTEGER()), "1968-12"); +} + +TEST(IcebergPartitionPathTest, day) { + EXPECT_EQ(toPath(TransformType::kDay, 0, DATE()), "1970-01-01"); + EXPECT_EQ(toPath(TransformType::kDay, 1, DATE()), "1970-01-02"); + EXPECT_EQ(toPath(TransformType::kDay, 18'262, DATE()), "2020-01-01"); + EXPECT_EQ(toPath(TransformType::kDay, -1, DATE()), "1969-12-31"); +} + +TEST(IcebergPartitionPathTest, hour) { + EXPECT_EQ(toPath(TransformType::kHour, 0, INTEGER()), "1970-01-01-00"); + EXPECT_EQ(toPath(TransformType::kHour, 1, INTEGER()), "1970-01-01-01"); + EXPECT_EQ(toPath(TransformType::kHour, 24, INTEGER()), "1970-01-02-00"); + EXPECT_EQ(toPath(TransformType::kHour, 438'288, INTEGER()), "2020-01-01-00"); + EXPECT_EQ(toPath(TransformType::kHour, -1, INTEGER()), "1969-12-31-23"); +} + +} // namespace + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/tests/TransformE2ETest.cpp b/velox/connectors/hive/iceberg/tests/TransformE2ETest.cpp new file mode 100644 index 00000000000..cd613219836 --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/TransformE2ETest.cpp @@ -0,0 +1,799 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +#include "velox/common/encode/Base64.h" +#include "velox/common/file/FileSystems.h" +#include "velox/common/testutil/TempDirectoryPath.h" +#include "velox/connectors/hive/iceberg/IcebergSplit.h" +#include "velox/connectors/hive/iceberg/tests/IcebergTestBase.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/functions/iceberg/Murmur3Hash32.h" + +using namespace facebook::velox::common::testutil; + +namespace facebook::velox::connector::hive::iceberg { + +using namespace facebook::velox::exec::test; + +namespace { + +#ifdef VELOX_ENABLE_PARQUET + +class TransformE2ETest : public test::IcebergTestBase { + protected: + static constexpr int32_t kDefaultNumBatches = 2; + static constexpr int32_t kDefaultRowsPerBatch = 100; + + std::vector createTimestampTestData() { + std::vector batches; + // 1. Aligned to hour boundaries (no minutes/seconds) since hour is the + // finest granularity transform tested. + // 2. Include negative epoch values (pre-1970) to test edge cases. + // 3. Span multiple years and months. + static const std::vector timestamps = { + // 1969-01-01 00:00:00 (negative epoch) + Timestamp(-31536000, 0), + // 1969-12-31 00:00:00 (day before epoch) + Timestamp(-86400, 0), + // 1970-01-01 00:00:00 + Timestamp(0, 0), + // 1970-01-01 01:00:00 + Timestamp(3600, 0), + // 1970-01-02 00:00:00 + Timestamp(86400, 0), + // 1970-01-31 00:00:00 + Timestamp(2592000, 0), + // 1971-01-01 00:00:00 + Timestamp(31536000, 0), + // 2021-01-01 00:00:00 + Timestamp(1609459200, 0), + // 2021-01-02 00:00:00 + Timestamp(1609545600, 0), + // 2021-02-01 00:00:00 + Timestamp(1612224000, 0), + // 2022-01-01 00:00:00 + Timestamp(1640995200, 0), + // 2023-01-01 00:00:00 + Timestamp(1672531200, 0), + // 2100-01-01 00:00:00 + Timestamp(4102444800, 0), + }; + + for (auto i = 0; i < kDefaultNumBatches; i++) { + auto timestampVector = makeFlatVector( + kDefaultRowsPerBatch, + [](auto row) { return timestamps[row % timestamps.size()]; }); + batches.push_back(makeRowVector({timestampVector})); + } + return batches; + } + + std::shared_ptr writeBatchesWithTransforms( + const std::vector& batches, + const std::vector& partitionFields) { + VELOX_CHECK(!batches.empty(), "input cannot be empty"); + + int64_t expectedRowCount = 0; + for (const auto& batch : batches) { + expectedRowCount += batch->size(); + } + + auto rowType = batches.front()->rowType(); + auto outputDirectory = TempDirectoryPath::create(); + const auto dataSink = createDataSinkAndAppendData( + batches, outputDirectory->getPath(), partitionFields); + dataSink->close(); + verifyTotalRowCount(rowType, outputDirectory->getPath(), expectedRowCount); + return outputDirectory; + } + + // Generate a key from a timestamp and transform type. + // The key format depends on the transform type: + // - kMonth: "YYYY-MM" + // - kDay: "YYYY-MM-DD" + // - kHour: "YYYY-MM-DD-HH" + static std::string timestampToKey( + const Timestamp& ts, + TransformType transformType) { + std::tm tm; + if (!Timestamp::epochToCalendarUtc(ts.getSeconds(), tm)) { + return ""; + } + + int32_t year = tm.tm_year + 1900; + int32_t month = tm.tm_mon + 1; + int32_t day = tm.tm_mday; + int32_t hour = tm.tm_hour; + + switch (transformType) { + case TransformType::kMonth: + return fmt::format("{:04d}-{:02d}", year, month); + case TransformType::kDay: + return fmt::format("{:04d}-{:02d}-{:02d}", year, month, day); + case TransformType::kHour: + return fmt::format( + "{:04d}-{:02d}-{:02d}-{:02d}", year, month, day, hour); + default: + VELOX_UNREACHABLE(); + } + } + + // Helper function to build expected counts map from timestamp batches. + // The key format depends on the transform type: + // - kMonth: "YYYY-MM" + // - kDay: "YYYY-MM-DD" + // - kHour: "YYYY-MM-DD-HH" + static std::unordered_map + buildExpectedCountsFromTimestamps( + const std::vector& batches, + TransformType transformType) { + std::unordered_map expectedCounts; + + for (const auto& batch : batches) { + auto timestampVector = batch->childAt(0)->as>(); + for (auto i = 0; i < batch->size(); i++) { + Timestamp ts = timestampVector->valueAt(i); + std::string key = timestampToKey(ts, transformType); + expectedCounts[key]++; + } + } + return expectedCounts; + } + + static std::string dirName(const std::string& path) { + return std::filesystem::path(path).filename().string(); + } + + static std::vector firstLevelDirectories( + const std::string& basePath) { + std::vector directories; + for (const auto& entry : std::filesystem::directory_iterator(basePath)) { + if (entry.is_directory()) { + directories.push_back(entry.path().string()); + } + } + return directories; + } + + static std::vector listDirectoriesRecursively( + const std::string& path) { + std::vector directories; + auto firstLevelDirs = firstLevelDirectories(path); + + for (const auto& dir : firstLevelDirs) { + directories.push_back(dirName(dir)); + auto subDirs = listDirectoriesRecursively(dir); + directories.insert(directories.end(), subDirs.begin(), subDirs.end()); + } + + return directories; + } + + static std::vector verifyPartitionCount( + const std::string& outputPath, + int32_t expectedCount) { + const auto partitionDirs = firstLevelDirectories(outputPath); + EXPECT_EQ(partitionDirs.size(), expectedCount); + for (const auto& dir : partitionDirs) { + const auto name = dirName(dir); + EXPECT_TRUE(name.find('=') != std::string::npos) + << "Partition directory " << name + << " does not follow Iceberg naming convention"; + } + return partitionDirs; + } + + // Verify the total row count across all partitions. + void verifyTotalRowCount( + const RowTypePtr& rowType, + const std::string& outputPath, + int32_t expectedRowCount) { + auto splits = createSplitsForDirectory(outputPath); + + const auto plan = PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .planNode(); + + const auto actualRowCount = + AssertQueryBuilder(plan).splits(splits).countResults(); + + ASSERT_EQ(actualRowCount, expectedRowCount); + } + + // Verify data in a specific partition. + void verifyPartitionData( + const RowTypePtr& rowType, + const std::string& partitionPath, + const std::string& partitionFilter, + const int32_t expectedRowCount) { + auto scanPlan = PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .planNode(); + + const auto actualRowCount = + AssertQueryBuilder(scanPlan) + .splits(createSplitsForDirectory(partitionPath)) + .countResults(); + + ASSERT_EQ(actualRowCount, expectedRowCount); + + const auto filterPlan = PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .filter(partitionFilter) + .planNode(); + const auto filteredRowCount = + AssertQueryBuilder(filterPlan) + .splits(createSplitsForDirectory(partitionPath)) + .countResults(); + ASSERT_EQ(expectedRowCount, filteredRowCount); + } + + static std::pair parsePartitionDirName( + const std::string& name) { + auto eq = name.find('='); + VELOX_CHECK(eq != std::string::npos); + auto us = name.rfind('_', eq - 1); + auto columnName = name.substr(0, us); + auto value = name.substr(eq + 1); + return {columnName, value}; + } + + static int32_t computeBucketHash( + const StringView& value, + int32_t numBuckets) { + int32_t hash = functions::iceberg::Murmur3Hash32::hashBytes( + value.data(), value.size()); + return ((hash & 0x7FFFFFFF) % numBuckets); + } + + folly::dynamic getPartitionValuesFromCommitMessage( + const RowVectorPtr& rowVector) { + auto outputDirectory = TempDirectoryPath::create(); + std::vector partitionTransforms = { + {0, TransformType::kIdentity, std::nullopt}, + }; + const auto dataSink = createDataSinkAndAppendData( + {rowVector}, outputDirectory->getPath(), partitionTransforms); + + dataSink->close(); + auto commitMessages = dataSink->commitMessage(); + VELOX_CHECK_EQ(commitMessages.size(), 1); + auto commitData = folly::parseJson(commitMessages[0]); + auto partitionDataJson = + folly::parseJson(commitData["partitionDataJson"].asString()); + auto partitionValues = partitionDataJson["partitionValues"]; + VELOX_CHECK_EQ(partitionValues.size(), 1); + return partitionValues[0]; + } +}; + +TEST_F(TransformE2ETest, identity) { + constexpr auto rowsPerBatch = 10; + constexpr auto duplicates = 5; + auto rowType = ROW({"c0"}, {INTEGER()}); + auto baseVectors = createTestData(rowType, kDefaultNumBatches, rowsPerBatch); + + // Duplicate each row to create multiple rows with the same partition key. + std::vector vectors; + for (const auto& baseVector : baseVectors) { + auto duplicatedColumn = wrapInDictionary( + makeIndices( + baseVector->size() * duplicates, + [duplicates](auto row) { return row / duplicates; }), + baseVector->size() * duplicates, + baseVector->childAt(0)); + vectors.push_back(makeRowVector({duplicatedColumn})); + } + + auto outputDirectory = writeBatchesWithTransforms( + vectors, {{0, TransformType::kIdentity, std::nullopt}}); + + auto partitionDirs = verifyPartitionCount( + outputDirectory->getPath(), kDefaultNumBatches * rowsPerBatch); + + for (const auto& dir : partitionDirs) { + const auto name = dirName(dir); + // Each partition should have duplicates rows. + verifyPartitionData(rowType, dir, name, duplicates); + } +} + +TEST_F(TransformE2ETest, partitionNamingConventions) { + auto rowVector = makeRowVector( + { + "c_int", + "c_bigint", + "c_varchar", + "c_varchar2", + "c_decimal", + "c_varbinary", + }, + { + makeConstant(42, 1, INTEGER()), + makeConstant(static_cast(9'876'543'210), 1, BIGINT()), + makeConstant("test string", 1, VARCHAR()), + makeNullConstant(TypeKind::VARCHAR, 1), + makeConstant(static_cast(1'234'567'890), 1, DECIMAL(18, 3)), + makeConstant("binarydata\1\2\3", 1, VARBINARY()), + }); + + std::vector partitionTransforms = { + {0, TransformType::kIdentity, std::nullopt}, // c_int. + {1, TransformType::kIdentity, std::nullopt}, // c_bigint. + {2, TransformType::kIdentity, std::nullopt}, // c_varchar. + {4, TransformType::kIdentity, std::nullopt}, // c_decimal. + {5, TransformType::kIdentity, std::nullopt}, // c_varbinary. + {3, TransformType::kIdentity, std::nullopt}, // c_varchar2. + }; + + auto outputDirectory = + writeBatchesWithTransforms({rowVector}, partitionTransforms); + + const auto actualPartitionNames = + listDirectoriesRecursively(outputDirectory->getPath()); + + // Build expected partition folder names. + std::vector expectedPartitionNames = { + "c_int=42", + "c_bigint=9876543210", + "c_varchar=test+string", + "c_decimal=1234567.890", + "c_varbinary=YmluYXJ5ZGF0YQECAw%3D%3D", + "c_varchar2=null", + }; + + ASSERT_EQ(actualPartitionNames, expectedPartitionNames) + << "Partition folder names do not match expected values"; +} + +TEST_F(TransformE2ETest, varbinaryPartitionCommitMessage) { + std::vector testData = { + "binarydata\x01\x02\x03", + "", + ".-_*/?%#&=+,;@!$\\\"'()<>:|[]{}^~`\n\r\t\u00A9", + }; + + for (const auto& binaryData : testData) { + SCOPED_TRACE(testing::Message() << "binaryData: " << binaryData); + auto rowVector = makeRowVector({ + makeConstant(binaryData, 1, VARBINARY()), + }); + + auto partitionValue = getPartitionValuesFromCommitMessage(rowVector); + ASSERT_EQ( + partitionValue.asString(), + encoding::Base64::encode(binaryData.data(), binaryData.size())); + } +} + +TEST_F(TransformE2ETest, nullPartitionValue) { + const std::vector typeKinds = { + TypeKind::BOOLEAN, + TypeKind::INTEGER, + TypeKind::BIGINT, + TypeKind::VARCHAR, + TypeKind::VARBINARY, + TypeKind::TIMESTAMP, + }; + + for (const auto& typeKind : typeKinds) { + auto rowVector = makeRowVector({ + makeNullConstant(typeKind, 1), + }); + + auto partitionValue = getPartitionValuesFromCommitMessage(rowVector); + ASSERT_TRUE(partitionValue.isNull()); + } +} + +TEST_F(TransformE2ETest, bucket) { + constexpr int32_t numBuckets = 4; + auto rowType = ROW({"c_varchar"}, {VARCHAR()}); + auto vectors = + createTestData(rowType, kDefaultNumBatches, kDefaultRowsPerBatch); + + auto outputDirectory = writeBatchesWithTransforms( + vectors, {{0, TransformType::kBucket, numBuckets}}); + + const auto partitionDirs = + verifyPartitionCount(outputDirectory->getPath(), numBuckets); + + int32_t totalRows = 0; + for (const auto& dir : partitionDirs) { + auto splits = createSplitsForDirectory(dir); + auto countPlan = PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .planNode(); + auto partitionRowCount = + AssertQueryBuilder(countPlan).splits(splits).countResults(); + + totalRows += partitionRowCount; + ASSERT_GT(partitionRowCount, 0); + } + + ASSERT_EQ(totalRows, kDefaultNumBatches * kDefaultRowsPerBatch); + + std::unordered_map valueToExpectedBucket; + + for (const auto& dir : partitionDirs) { + const auto name = dirName(dir); + const auto [k, v] = parsePartitionDirName(name); + const int32_t expectedBucket = std::stoi(v); + + auto dataPlan = PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .project({"c_varchar"}) + .planNode(); + const auto& dataResult = AssertQueryBuilder(dataPlan) + .splits(createSplitsForDirectory(dir)) + .copyResults(opPool_.get()); + + auto varcharColumn = dataResult->childAt(0)->asFlatVector(); + for (auto i = 0; i < dataResult->size(); i++) { + auto value = varcharColumn->valueAt(i); + + auto computedBucket = computeBucketHash(value, numBuckets); + ASSERT_EQ(computedBucket, expectedBucket); + } + } +} + +TEST_F(TransformE2ETest, truncate) { + auto rowType = ROW({"c_int"}, {INTEGER()}); + + std::vector batches; + for (auto i = 0; i < kDefaultNumBatches; i++) { + std::vector columns; + columns.push_back( + makeFlatVector(50, [](auto row) { return row % 100; })); + auto vectors = makeRowVector(rowType->names(), columns); + batches.push_back(vectors); + } + + auto outputDirectory = + writeBatchesWithTransforms(batches, {{0, TransformType::kTruncate, 10}}); + + auto partitionDirs = verifyPartitionCount(outputDirectory->getPath(), 5); + + for (const auto& dir : partitionDirs) { + const std::string name = dirName(dir); + auto [c, v] = parsePartitionDirName(name); + const std::string filter = fmt::format( + "{}>={} AND {}<{}", c, v, c, std::to_string(std::stoi(v) + 10)); + + verifyPartitionData( + rowType, dir, filter, 20); // 10 values per batch * 2 batches. + } +} + +TEST_F(TransformE2ETest, year) { + auto rowType = ROW({"c_date"}, {DATE()}); + static const std::vector dates = { + 18'262, 18'628, 18'993, 19'358, 19'723, 20'181}; + std::vector batches; + for (auto i = 0; i < kDefaultNumBatches; i++) { + auto dateVector = makeFlatVector( + kDefaultRowsPerBatch, + [](auto row) { return dates[row % dates.size()]; }, + nullptr, + DATE()); + batches.emplace_back(makeRowVector(rowType->names(), {dateVector})); + } + + auto outputDirectory = writeBatchesWithTransforms( + batches, {{0, TransformType::kYear, std::nullopt}}); + + auto partitionDirs = verifyPartitionCount(outputDirectory->getPath(), 6); + + for (int32_t year = 2020; year <= 2025; year++) { + const auto expectedDirName = fmt::format("c_date_year={}", year); + bool foundPartition = false; + auto yearFilter = [](int32_t year) -> std::string { + return fmt::format("YEAR(DATE '{}-01-01')={}", year, year); + }; + + for (const auto& dir : partitionDirs) { + SCOPED_TRACE(year); + const auto name = dirName(dir); + if (name == expectedDirName) { + foundPartition = true; + auto datePlan = PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .filter(yearFilter(year)) + .planNode(); + + auto partitionRowCount = AssertQueryBuilder(datePlan) + .splits(createSplitsForDirectory(dir)) + .countResults(); + + auto countPlan = PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .planNode(); + auto totalPartitionCount = AssertQueryBuilder(countPlan) + .splits(createSplitsForDirectory(dir)) + .countResults(); + ASSERT_EQ(partitionRowCount, totalPartitionCount); + break; + } + } + ASSERT_TRUE(foundPartition); + } +} + +TEST_F(TransformE2ETest, month) { + auto batches = createTimestampTestData(); + + auto outputDirectory = writeBatchesWithTransforms( + batches, {{0, TransformType::kMonth, std::nullopt}}); + + auto expectedCounts = + buildExpectedCountsFromTimestamps(batches, TransformType::kMonth); + const auto partitionDirs = firstLevelDirectories(outputDirectory->getPath()); + + for (const auto& dir : partitionDirs) { + const auto name = dirName(dir); + auto [c, v] = parsePartitionDirName(name); + size_t dashPos = v.find('-'); + ASSERT_NE(dashPos, std::string::npos) << "Invalid month format: " << v; + + int32_t year = std::stoi(v.substr(0, dashPos)); + int32_t month = std::stoi(v.substr(dashPos + 1)); + std::string filter = + fmt::format("YEAR(c0) = {} AND MONTH(c0) = {}", year, month); + std::string monthKey = fmt::format("{:04d}-{:02d}", year, month); + verifyPartitionData( + ROW({"c0"}, {TIMESTAMP()}), dir, filter, expectedCounts[monthKey]); + } +} + +TEST_F(TransformE2ETest, day) { + auto batches = createTimestampTestData(); + + auto outputDirectory = writeBatchesWithTransforms( + batches, {{0, TransformType::kDay, std::nullopt}}); + + auto expectedCounts = + buildExpectedCountsFromTimestamps(batches, TransformType::kDay); + const auto partitionDirs = firstLevelDirectories(outputDirectory->getPath()); + + for (const auto& dir : partitionDirs) { + const auto name = dirName(dir); + auto [c, v] = parsePartitionDirName(name); + std::vector dateParts; + folly::split('-', v, dateParts); + ASSERT_EQ(dateParts.size(), 3) << "Invalid day format: " << v; + + int32_t year = std::stoi(dateParts[0]); + int32_t month = std::stoi(dateParts[1]); + int32_t day = std::stoi(dateParts[2]); + + std::string filter = fmt::format( + "YEAR(c0) = {} AND MONTH(c0) = {} AND DAY(c0) = {}", year, month, day); + std::string dayKey = fmt::format("{:04d}-{:02d}-{:02d}", year, month, day); + verifyPartitionData( + ROW({"c0"}, {TIMESTAMP()}), dir, filter, expectedCounts[dayKey]); + } +} + +TEST_F(TransformE2ETest, hour) { + auto batches = createTimestampTestData(); + + auto outputDirectory = writeBatchesWithTransforms( + batches, {{0, TransformType::kHour, std::nullopt}}); + + auto expectedCounts = + buildExpectedCountsFromTimestamps(batches, TransformType::kHour); + const auto partitionDirs = firstLevelDirectories(outputDirectory->getPath()); + + for (const auto& dir : partitionDirs) { + const auto name = dirName(dir); + auto [c, v] = parsePartitionDirName(name); + std::vector dateParts; + folly::split('-', v, dateParts); + ASSERT_EQ(dateParts.size(), 4) << "Invalid hour format: " << v; + + int32_t year = std::stoi(dateParts[0]); + int32_t month = std::stoi(dateParts[1]); + int32_t day = std::stoi(dateParts[2]); + int32_t hour = std::stoi(dateParts[3]); + + std::string filter = fmt::format( + "YEAR(c0) = {} AND MONTH(c0) = {} AND " + "DAY(c0) = {} AND HOUR(c0) = {}", + year, + month, + day, + hour); + std::string hourKey = + fmt::format("{:04d}-{:02d}-{:02d}-{:02d}", year, month, day, hour); + verifyPartitionData( + ROW({"c0"}, {TIMESTAMP()}), dir, filter, expectedCounts[hourKey]); + } +} + +TEST_F(TransformE2ETest, multipleTransformsOnSameColumn) { + auto rowType = ROW( + { + "c_int", + "c_bigint", + }, + { + INTEGER(), + BIGINT(), + }); + + auto vectors = createTestData(rowType, 2, 20); + auto outputDirectory = writeBatchesWithTransforms( + vectors, + { + {0, TransformType::kIdentity, std::nullopt}, // c_int. + {0, TransformType::kTruncate, 10}, // truncate(c_int, 10). + {0, TransformType::kBucket, 4}, // bucket(c_int, 4). + }); + + auto firstLevelDirs = firstLevelDirectories(outputDirectory->getPath()); + ASSERT_GT(firstLevelDirs.size(), 0); + for (const auto& dir : firstLevelDirs) { + const auto name = dirName(dir); + ASSERT_TRUE(name.find("c_int=") != std::string::npos) + << "First level directory " << name << " should use identity transform"; + + auto secondLevelDirs = firstLevelDirectories(dir); + ASSERT_GT(secondLevelDirs.size(), 0) + << "No second level directories found in " << dir; + + for (const auto& secondDir : secondLevelDirs) { + const auto secondName = dirName(secondDir); + ASSERT_TRUE(secondName.find("c_int_trunc=") != std::string::npos) + << "Second level directory " << secondName + << " should use truncate transform"; + + auto thirdLevelDirs = firstLevelDirectories(secondDir); + ASSERT_GT(thirdLevelDirs.size(), 0) + << "No third level directories found in " << secondDir; + + for (const auto& thirdDir : thirdLevelDirs) { + const auto thirdName = dirName(thirdDir); + ASSERT_TRUE(thirdName.find("c_int_bucket=") != std::string::npos) + << "Third level directory " << thirdName + << " should use bucket transform"; + + // Verify the partition has data. + auto splits = createSplitsForDirectory(thirdDir); + auto countPlan = PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .planNode(); + auto rowCount = + AssertQueryBuilder(countPlan).splits(splits).countResults(); + ASSERT_GT(rowCount, 0) + << "Leaf partition directory " << thirdDir << " has no data"; + } + } + } +} + +TEST_F(TransformE2ETest, dateIdentityPartitionWithFilter) { + auto rowType = ROW({"c_date", "c_value"}, {DATE(), INTEGER()}); + + static const std::vector dates = {20147, 19816}; + std::vector batches; + for (auto i = 0; i < kDefaultNumBatches; i++) { + batches.emplace_back(makeRowVector( + rowType->names(), + {makeFlatVector( + kDefaultRowsPerBatch, + [](auto row) { return dates[row % dates.size()]; }, + nullptr, + DATE()), + makeFlatVector( + kDefaultRowsPerBatch, [](auto row) { return row; })})); + } + + auto outputDirectory = writeBatchesWithTransforms( + batches, {{0, TransformType::kIdentity, std::nullopt}}); + + auto partitionDirs = verifyPartitionCount(outputDirectory->getPath(), 2); + + std::vector> splits; + + for (const auto& dir : partitionDirs) { + const auto daysSinceEpoch = + dirName(dir) == "c_date=2025-02-28" ? "20147" : "19816"; + + for (const auto& filePath : listFiles(dir)) { + const auto file = filesystems::getFileSystem(filePath, nullptr) + ->openFileForRead(filePath); + splits.push_back( + std::make_shared( + test::kIcebergConnectorId, + filePath, + fileFormat_, + 0, + file->size(), + std::unordered_map>{ + {"c_date", daysSinceEpoch}}, + std::nullopt, + std::unordered_map{}, + nullptr, + /*cacheable=*/true, + std::vector())); + } + } + + ColumnHandleMap assignments{ + {"c_date", + std::make_shared( + "c_date", + FileColumnHandle::ColumnType::kPartitionKey, + DATE(), + parquet::ParquetFieldId{0, {}}, + std::vector{})}, + {"c_value", + std::make_shared( + "c_value", + FileColumnHandle::ColumnType::kRegular, + INTEGER(), + parquet::ParquetFieldId{1, {}})}, + }; + + auto filterPlan = PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .assignments(assignments) + .remainingFilter("c_date = DATE '2025-02-28'") + .endTableScan() + .planNode(); + + const auto filteredRowCount = + AssertQueryBuilder(filterPlan).splits(splits).countResults(); + + ASSERT_EQ(filteredRowCount, kDefaultRowsPerBatch); +} +#endif + +} // namespace + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/tests/TransformTest.cpp b/velox/connectors/hive/iceberg/tests/TransformTest.cpp new file mode 100644 index 00000000000..05660e6cf8c --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/TransformTest.cpp @@ -0,0 +1,332 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/encode/Base64.h" +#include "velox/connectors/hive/iceberg/IcebergConfig.h" +#include "velox/connectors/hive/iceberg/PartitionSpec.h" +#include "velox/connectors/hive/iceberg/TransformEvaluator.h" +#include "velox/connectors/hive/iceberg/TransformExprBuilder.h" +#include "velox/connectors/hive/iceberg/tests/IcebergTestBase.h" + +namespace facebook::velox::connector::hive::iceberg { + +namespace { + +class TransformTest : public test::IcebergTestBase { + protected: + void testTransform( + const IcebergPartitionSpecPtr& spec, + const RowVectorPtr& input, + const RowVectorPtr& expected) const { + std::vector partitionChannels; + for (const auto& field : spec->fields) { + partitionChannels.push_back(input->rowType()->getChildIdx(field.name)); + } + // Build and evaluate transform expressions. + auto transformExprs = TransformExprBuilder::toExpressions( + spec, + partitionChannels, + input->rowType(), + std::string(IcebergConfig::kDefaultFunctionPrefix)); + auto transformEvaluator = std::make_unique( + transformExprs, connectorQueryCtx_.get()); + auto result = transformEvaluator->evaluate(input); + + ASSERT_EQ(result.size(), expected->childrenSize()); + for (auto i = 0; i < result.size(); ++i) { + velox::test::assertEqualVectors(expected->childAt(i), result[i]); + } + } +}; + +TEST_F(TransformTest, identity) { + const auto& rowType = + ROW({"c0", "c1", "c2", "c3", "c4"}, + {INTEGER(), BIGINT(), VARCHAR(), VARBINARY(), TIMESTAMP()}); + const auto& partitionSpec = createPartitionSpec( + rowType, + { + {0, TransformType::kIdentity, std::nullopt}, + {1, TransformType::kIdentity, std::nullopt}, + {2, TransformType::kIdentity, std::nullopt}, + {3, TransformType::kIdentity, std::nullopt}, + {4, TransformType::kIdentity, std::nullopt}, + }); + + const std::vector input = { + makeFlatVector({1, -1}), + makeFlatVector({1L, -1L}), + makeFlatVector({("test data"), ("")}), + makeFlatVector({("\x01\x02\x03"), ("")}, VARBINARY()), + makeFlatVector({Timestamp(0, 0), Timestamp(1609459200, 0)}), + }; + + testTransform(partitionSpec, makeRowVector(input), makeRowVector(input)); +} + +TEST_F(TransformTest, nulls) { + const auto& rowType = + ROW({"c0", "c1", "c2", "c3", "c4", "c5", "c6"}, + {INTEGER(), + VARCHAR(), + VARBINARY(), + DATE(), + TIMESTAMP(), + TIMESTAMP(), + TIMESTAMP()}); + const auto& partitionSpec = createPartitionSpec( + rowType, + { + {0, TransformType::kIdentity, std::nullopt}, + {1, TransformType::kBucket, 8}, + {2, TransformType::kTruncate, 16}, + {3, TransformType::kYear, std::nullopt}, + {4, TransformType::kMonth, std::nullopt}, + {5, TransformType::kDay, std::nullopt}, + {6, TransformType::kHour, std::nullopt}, + }); + testTransform( + partitionSpec, + makeRowVector({ + makeNullableFlatVector({std::nullopt}), + makeNullableFlatVector({std::nullopt}), + makeNullableFlatVector({std::nullopt}, VARBINARY()), + makeNullableFlatVector({std::nullopt}, DATE()), + makeNullableFlatVector({std::nullopt}), + makeNullableFlatVector({std::nullopt}), + makeNullableFlatVector({std::nullopt}), + }), + makeRowVector({ + makeNullableFlatVector({std::nullopt}), + makeNullableFlatVector({std::nullopt}), + makeNullableFlatVector({std::nullopt}, VARBINARY()), + makeNullableFlatVector({std::nullopt}), + makeNullableFlatVector({std::nullopt}), + makeNullableFlatVector({std::nullopt}, DATE()), + makeNullableFlatVector({std::nullopt}), + })); +} + +TEST_F(TransformTest, bucket) { + const auto& rowType = + ROW({"c0", "c1", "c2", "c3", "c4", "c5"}, + {INTEGER(), BIGINT(), VARCHAR(), VARBINARY(), DATE(), TIMESTAMP()}); + const auto& partitionSpec = createPartitionSpec( + rowType, + { + {0, TransformType::kBucket, 4}, + {1, TransformType::kBucket, 8}, + {2, TransformType::kBucket, 16}, + {3, TransformType::kBucket, 32}, + {4, TransformType::kBucket, 10}, + {5, TransformType::kBucket, 8}, + }); + + testTransform( + partitionSpec, + makeRowVector({ + makeFlatVector({8, 34, 0}), + makeFlatVector({34L, 0L, -34L}), + makeFlatVector({"abcdefg", "测试", ""}), + makeFlatVector( + {"\x61\x62\x64\x00\x00", "\x01\x02\x03\x04", "\x00"}, + VARBINARY()), + makeFlatVector({0, 365, 18'262}), + makeFlatVector( + {Timestamp(0, 0), + Timestamp(-31536000, 0), + Timestamp(1612224000, 0)}), + }), + makeRowVector({ + makeFlatVector({3, 3, 0}), + makeFlatVector({3, 4, 5}), + makeFlatVector({6, 8, 0}), + makeFlatVector({26, 5, 0}), + makeFlatVector({6, 1, 3}), + makeFlatVector({4, 3, 5}), + })); +} + +TEST_F(TransformTest, year) { + const auto& rowType = ROW({"c0", "c1"}, {DATE(), TIMESTAMP()}); + const auto& partitionSpec = createPartitionSpec( + rowType, + { + {0, TransformType::kYear, std::nullopt}, + {1, TransformType::kYear, std::nullopt}, + }); + + testTransform( + partitionSpec, + makeRowVector({ + makeFlatVector({0, 18'262, -365}), + makeFlatVector( + {Timestamp(0, 0), + Timestamp(31536000, 0), + Timestamp(-31536000, 0)}), + }), + makeRowVector({ + makeFlatVector({0, 50, -1}), + makeFlatVector({0, 1, -1}), + })); +} + +TEST_F(TransformTest, month) { + const auto& rowType = ROW({"c0", "c1"}, {DATE(), TIMESTAMP()}); + const auto& partitionSpec = createPartitionSpec( + rowType, + { + {0, TransformType::kMonth, std::nullopt}, + {1, TransformType::kMonth, std::nullopt}, + }); + + testTransform( + partitionSpec, + makeRowVector({ + makeFlatVector({0, 18'262, -365}), + makeFlatVector( + {Timestamp(0, 0), + Timestamp(31536000, 0), + Timestamp(-2678400, 0)}), + }), + makeRowVector({ + makeFlatVector({0, 600, -12}), + makeFlatVector({0, 12, -1}), + })); +} + +TEST_F(TransformTest, day) { + const auto& rowType = ROW({"c0", "c1"}, {DATE(), TIMESTAMP()}); + const auto& partitionSpec = createPartitionSpec( + rowType, + { + {0, TransformType::kDay, std::nullopt}, + {1, TransformType::kDay, std::nullopt}, + }); + + testTransform( + partitionSpec, + makeRowVector({ + makeFlatVector({0, 17532, -1}, DATE()), + makeFlatVector( + {Timestamp(0, 0), + Timestamp(1514764800, 0), + Timestamp(-86400, 0)}), + }), + makeRowVector({ + makeFlatVector({0, 17532, -1}, DATE()), + makeFlatVector({0, 17532, -1}, DATE()), + })); +} + +TEST_F(TransformTest, hour) { + const auto& partitionSpec = createPartitionSpec( + ROW({"c0"}, {TIMESTAMP()}), {{0, TransformType::kHour, std::nullopt}}); + + testTransform( + partitionSpec, + makeRowVector({makeFlatVector({ + Timestamp(0, 0), + Timestamp(3600, 0), + Timestamp(-3600, 0), + })}), + makeRowVector({makeFlatVector({0, 1, -1})})); +} + +TEST_F(TransformTest, truncate) { + const auto& rowType = ROW( + {"c0", "c1", "c2", "c3"}, {INTEGER(), BIGINT(), VARCHAR(), VARBINARY()}); + const auto& partitionSpec = createPartitionSpec( + rowType, + { + {0, TransformType::kTruncate, 10}, + {1, TransformType::kTruncate, 100}, + {2, TransformType::kTruncate, 5}, + {3, TransformType::kTruncate, 3}, + }); + + testTransform( + partitionSpec, + makeRowVector({ + makeFlatVector({11, -11, 5}), + makeFlatVector({123L, -123L, 50L}), + makeFlatVector({"abcdefg", "测试data", "x"}), + makeFlatVector( + {"abcdefg", "\x01\x02\x03\x04", "\x05"}, VARBINARY()), + }), + makeRowVector({ + makeFlatVector({10, -20, 0}), + makeFlatVector({100L, -200L, 0L}), + makeFlatVector({"abcde", "测试dat", "x"}), + makeFlatVector( + {"abc", "\x01\x02\x03", "\x05"}, VARBINARY()), + })); +} + +TEST_F(TransformTest, multipleTransforms) { + const auto& rowType = ROW({"c0", "c1", "c2"}, {INTEGER(), DATE(), VARCHAR()}); + const auto& partitionSpec = createPartitionSpec( + rowType, + { + {0, TransformType::kBucket, 4}, + {1, TransformType::kYear, std::nullopt}, + {2, TransformType::kTruncate, 3}, + }); + + testTransform( + partitionSpec, + makeRowVector({ + makeFlatVector({8, 34}), + makeFlatVector({0, 17532}), + makeFlatVector({"abcdefg", "ab c"}), + }), + makeRowVector({ + makeFlatVector({3, 3}), + makeFlatVector({0, 48}), + makeFlatVector({"abc", "ab "}), + })); +} + +TEST_F(TransformTest, multipleTransformsOnSameColumn) { + const auto& rowType = ROW({"c0", "c1"}, {DATE(), VARCHAR()}); + const auto& partitionSpec = createPartitionSpec( + rowType, + { + {0, TransformType::kYear, std::nullopt}, + {0, TransformType::kBucket, 10}, + {1, TransformType::kTruncate, 5}, + {1, TransformType::kBucket, 8}, + }); + + testTransform( + partitionSpec, + makeRowVector( + rowType->names(), + { + makeFlatVector({0, 17532}), + makeFlatVector({"abcdefg", "test"}), + }), + makeRowVector({ + makeFlatVector({0, 48}), + makeFlatVector({6, 7}), + makeFlatVector({"abcde", "test"}), + makeFlatVector({6, 3}), + })); +} + +} // namespace + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/tests/WriterOptionsAdapterTest.cpp b/velox/connectors/hive/iceberg/tests/WriterOptionsAdapterTest.cpp new file mode 100644 index 00000000000..854abc4a0c2 --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/WriterOptionsAdapterTest.cpp @@ -0,0 +1,126 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/WriterOptionsAdapter.h" + +#include +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/dwio/dwrf/writer/Writer.h" +#include "velox/dwio/parquet/writer/WriterConfig.h" + +namespace facebook::velox::connector::hive::iceberg { +namespace { + +// Verifies the dispatch table in createWriterOptionsAdapter(): +// PARQUET, DWRF, NIMBLE return non-null adapters; everything else returns +// null. This is the single source of truth for which file formats the +// Iceberg DataSink supports on the write path. +TEST(WriterOptionsAdapterTest, createWriterOptionsAdapterDispatch) { + EXPECT_NE( + createWriterOptionsAdapter(dwio::common::FileFormat::PARQUET), nullptr); + EXPECT_NE( + createWriterOptionsAdapter(dwio::common::FileFormat::DWRF), nullptr); + EXPECT_NE( + createWriterOptionsAdapter(dwio::common::FileFormat::NIMBLE), nullptr); + + // ORC, TEXT, JSON, ALPHA, etc. are intentionally unsupported on the + // write path until each gets its own end-to-end coverage. + EXPECT_EQ(createWriterOptionsAdapter(dwio::common::FileFormat::ORC), nullptr); + EXPECT_EQ( + createWriterOptionsAdapter(dwio::common::FileFormat::TEXT), nullptr); + EXPECT_EQ( + createWriterOptionsAdapter(dwio::common::FileFormat::JSON), nullptr); +} + +// Verifies isSupportedFileFormat() agrees with createWriterOptionsAdapter(). +TEST(WriterOptionsAdapterTest, isSupportedFileFormatMatchesDispatch) { + EXPECT_TRUE(isSupportedFileFormat(dwio::common::FileFormat::PARQUET)); + EXPECT_TRUE(isSupportedFileFormat(dwio::common::FileFormat::DWRF)); + EXPECT_TRUE(isSupportedFileFormat(dwio::common::FileFormat::NIMBLE)); + + EXPECT_FALSE(isSupportedFileFormat(dwio::common::FileFormat::ORC)); + EXPECT_FALSE(isSupportedFileFormat(dwio::common::FileFormat::TEXT)); + EXPECT_FALSE(isSupportedFileFormat(dwio::common::FileFormat::JSON)); +} + +// Verifies the manifest format string written into Iceberg commit messages +// matches the cross-engine convention. Iceberg's manifest vocabulary has no +// DWRF/NIMBLE enum, so both report "ORC" (matching the Java planner's +// FileFormat.{DWRF,NIMBLE}.toIceberg()) so downstream Iceberg consumers can +// interpret the message. Parquet reports "PARQUET" because Iceberg has a +// native enum for it. The on-disk format is identified at read time via +// the file extension and on-disk magic bytes, so writing "ORC" for NIMBLE +// is safe. +TEST(WriterOptionsAdapterTest, toManifestFormatString) { + EXPECT_EQ( + toManifestFormatString(dwio::common::FileFormat::PARQUET), "PARQUET"); + EXPECT_EQ(toManifestFormatString(dwio::common::FileFormat::DWRF), "ORC"); + EXPECT_EQ(toManifestFormatString(dwio::common::FileFormat::NIMBLE), "ORC"); +} + +// Verifies the Parquet adapter's pre-config hook installs the Iceberg-spec +// timestamp serdeParameters. These values must be set before +// processConfigs() runs because the Parquet writer reads them from +// serdeParameters during config processing. +TEST(WriterOptionsAdapterTest, parquetPreConfigsSetsTimestampSerdeParameters) { + auto adapter = createWriterOptionsAdapter(dwio::common::FileFormat::PARQUET); + ASSERT_NE(adapter, nullptr); + + dwio::common::WriterOptions options; + adapter->applyPreConfigs(options); + + EXPECT_EQ( + options + .serdeParameters[parquet::WriterConfig::kParquetSerdeTimestampUnit], + "6"); + EXPECT_EQ( + options.serdeParameters + [parquet::WriterConfig::kParquetSerdeTimestampTimezone], + ""); +} + +// Verifies the DWRF adapter's post-config hook overrides timestamp settings +// regardless of what processConfigs() left in place. The Iceberg spec +// requires timestamps NOT be adjusted to UTC; if the DataSink stops calling +// applyPostConfigs after processConfigs, this test still locks the adapter's +// override contract — IcebergDataSink::createWriterOptions must use it. +TEST(WriterOptionsAdapterTest, dwrfPostConfigsOverridesTimestampFields) { + auto adapter = createWriterOptionsAdapter(dwio::common::FileFormat::DWRF); + ASSERT_NE(adapter, nullptr); + + dwrf::WriterOptions options; + options.adjustTimestampToTimezone = true; + options.sessionTimezone = tz::locateZone("America/Los_Angeles"); + + adapter->applyPostConfigs(options); + + EXPECT_FALSE(options.adjustTimestampToTimezone); + EXPECT_EQ(options.sessionTimezone, nullptr); +} + +// Verifies toManifestFormatString() throws for unsupported formats rather +// than silently returning an incorrect string. +TEST(WriterOptionsAdapterTest, toManifestFormatStringThrowsForUnsupported) { + VELOX_ASSERT_THROW( + toManifestFormatString(dwio::common::FileFormat::ORC), + "Unsupported file format for Iceberg manifest"); + VELOX_ASSERT_THROW( + toManifestFormatString(dwio::common::FileFormat::TEXT), + "Unsupported file format for Iceberg manifest"); +} + +} // namespace +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/paimon/CMakeLists.txt b/velox/connectors/hive/paimon/CMakeLists.txt new file mode 100644 index 00000000000..615b1e260c1 --- /dev/null +++ b/velox/connectors/hive/paimon/CMakeLists.txt @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +velox_add_library( + velox_hive_paimon_split + PaimonConnectorSplit.cpp + PaimonDataFileMeta.cpp + PaimonDeletionFile.cpp + PaimonRowKind.cpp + HEADERS + PaimonConnectorSplit.h + PaimonDataFileMeta.h + PaimonDeletionFile.h + PaimonRowKind.h +) + +velox_link_libraries(velox_hive_paimon_split velox_connector velox_hive_connector fmt::fmt) + +velox_add_library( + velox_hive_paimon_connector + PaimonConnector.cpp + PaimonDataSource.cpp + PaimonSplitReader.cpp + HEADERS + PaimonConfig.h + PaimonConnector.h + PaimonDataSource.h + PaimonSplitReader.h +) + +velox_link_libraries(velox_hive_paimon_connector velox_hive_paimon_split velox_hive_connector) + +if(${VELOX_BUILD_TESTING}) + add_subdirectory(tests) +endif() diff --git a/velox/experimental/wave/dwio/nimble/Encoding.h b/velox/connectors/hive/paimon/PaimonConfig.h similarity index 60% rename from velox/experimental/wave/dwio/nimble/Encoding.h rename to velox/connectors/hive/paimon/PaimonConfig.h index 5ed9e0cac1e..c53f76eebc4 100644 --- a/velox/experimental/wave/dwio/nimble/Encoding.h +++ b/velox/connectors/hive/paimon/PaimonConfig.h @@ -13,19 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -namespace facebook::wave::nimble { -class Encoding { +#include "velox/connectors/hive/FileConfig.h" + +namespace facebook::velox::connector::hive::paimon { + +/// Paimon-specific connector configuration. +/// Extends FileConfig with Paimon-specific settings. +class PaimonConfig : public FileConfig { public: - // The binary layout for each Encoding begins with the same prefix: - // 1 byte: EncodingType - // 1 byte: DataType - // 4 bytes: uint32_t num rows - static constexpr int kEncodingTypeOffset = 0; - static constexpr int kDataTypeOffset = 1; - static constexpr int kRowCountOffset = 2; - static constexpr int kPrefixSize = 6; + explicit PaimonConfig(std::shared_ptr config) + : FileConfig(std::move(config)) {} }; -} // namespace facebook::wave::nimble + +} // namespace facebook::velox::connector::hive::paimon diff --git a/velox/connectors/hive/paimon/PaimonConnector.cpp b/velox/connectors/hive/paimon/PaimonConnector.cpp new file mode 100644 index 00000000000..2b04959f5cc --- /dev/null +++ b/velox/connectors/hive/paimon/PaimonConnector.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/connectors/hive/paimon/PaimonConnector.h" + +#include "velox/connectors/hive/paimon/PaimonDataSource.h" + +namespace facebook::velox::connector::hive::paimon { + +PaimonConnector::PaimonConnector( + const std::string& id, + std::shared_ptr config, + folly::Executor* ioExecutor) + : HiveConnector(id, config, ioExecutor), + paimonConfig_(std::make_shared(connectorConfig())) {} + +std::unique_ptr PaimonConnector::createDataSource( + const RowTypePtr& outputType, + const ConnectorTableHandlePtr& tableHandle, + const ColumnHandleMap& columnHandles, + ConnectorQueryCtx* connectorQueryCtx) { + return std::make_unique( + outputType, + tableHandle, + columnHandles, + &fileHandleFactory_, + ioExecutor_, + connectorQueryCtx, + paimonConfig_); +} + +} // namespace facebook::velox::connector::hive::paimon diff --git a/velox/connectors/hive/paimon/PaimonConnector.h b/velox/connectors/hive/paimon/PaimonConnector.h new file mode 100644 index 00000000000..97bf4a50aa8 --- /dev/null +++ b/velox/connectors/hive/paimon/PaimonConnector.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/connectors/hive/HiveConnector.h" +#include "velox/connectors/hive/paimon/PaimonConfig.h" + +namespace facebook::velox::connector::hive::paimon { + +/// Provides Paimon table format support by extending HiveConnector. +/// +/// Creates PaimonDataSource instances that handle Paimon's multi-file splits +/// (one split = one bucket with multiple data files across LSM-tree levels). +/// Reuses HiveConnector's ORC/Parquet readers directly — no Arrow bridge. +class PaimonConnector final : public HiveConnector { + public: + PaimonConnector( + const std::string& id, + std::shared_ptr config, + folly::Executor* ioExecutor); + + /// Creates PaimonDataSource for reading from Paimon tables. + std::unique_ptr createDataSource( + const RowTypePtr& outputType, + const ConnectorTableHandlePtr& tableHandle, + const ColumnHandleMap& columnHandles, + ConnectorQueryCtx* connectorQueryCtx) override; + + private: + const std::shared_ptr paimonConfig_; +}; + +class PaimonConnectorFactory final : public ConnectorFactory { + public: + static constexpr const char* kPaimonConnectorName = "paimon"; + + PaimonConnectorFactory() : ConnectorFactory(kPaimonConnectorName) {} + + std::shared_ptr newConnector( + const std::string& id, + std::shared_ptr config, + folly::Executor* ioExecutor = nullptr, + [[maybe_unused]] folly::Executor* cpuExecutor = nullptr) override { + return std::make_shared(id, config, ioExecutor); + } +}; + +} // namespace facebook::velox::connector::hive::paimon diff --git a/velox/connectors/hive/paimon/PaimonConnectorSplit.cpp b/velox/connectors/hive/paimon/PaimonConnectorSplit.cpp new file mode 100644 index 00000000000..176159c33bb --- /dev/null +++ b/velox/connectors/hive/paimon/PaimonConnectorSplit.cpp @@ -0,0 +1,214 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/connectors/hive/paimon/PaimonConnectorSplit.h" + +#include + +#include "velox/common/base/Exceptions.h" + +namespace facebook::velox::connector::hive::paimon { + +std::string paimonTableTypeString(PaimonTableType type) { + switch (type) { + case PaimonTableType::kAppendOnly: + return "APPEND_ONLY"; + case PaimonTableType::kPrimaryKey: + return "PRIMARY_KEY"; + default: + VELOX_FAIL("Unknown PaimonTableType: {}", static_cast(type)); + } +} + +PaimonTableType paimonTableTypeFromString(const std::string& str) { + if (str == "APPEND_ONLY") { + return PaimonTableType::kAppendOnly; + } + if (str == "PRIMARY_KEY") { + return PaimonTableType::kPrimaryKey; + } + VELOX_FAIL("Unknown PaimonTableType: {}", str); +} + +PaimonConnectorSplit::PaimonConnectorSplit( + const std::string& connectorId, + int64_t snapshotId, + PaimonTableType tableType, + dwio::common::FileFormat fileFormat, + const std::vector& dataFiles, + std::unordered_map> partitionKeys, + std::optional tableBucketNumber, + bool rawConvertible) + : ConnectorSplit(connectorId), + snapshotId_(snapshotId), + tableType_(tableType), + fileFormat_(fileFormat), + dataFiles_(dataFiles), + partitionKeys_(std::move(partitionKeys)), + tableBucketNumber_(tableBucketNumber), + rawConvertible_(rawConvertible) { + VELOX_CHECK( + !dataFiles_.empty(), "PaimonConnectorSplit requires non-empty dataFiles"); + + if (rawConvertible_) { + for (const auto& file : dataFiles_) { + VELOX_CHECK_EQ( + file.deleteRowCount, + 0, + "rawConvertible split cannot have files with deleteRowCount > 0: {}", + file.toString()); + } + } +} + +std::string PaimonConnectorSplit::toString() const { + std::string dataFilesStr; + for (const auto& file : dataFiles_) { + if (!dataFilesStr.empty()) { + dataFilesStr += ", "; + } + dataFilesStr += file.toString(); + } + + return fmt::format( + "PaimonConnectorSplit[snapshot {}, type {}, rawConvertible {}, " + "connector '{}', dataFiles=[{}]]", + snapshotId_, + paimonTableTypeString(tableType_), + rawConvertible_, + connectorId, + dataFilesStr); +} + +folly::dynamic PaimonConnectorSplit::serialize() const { + folly::dynamic obj = folly::dynamic::object; + obj["name"] = "PaimonConnectorSplit"; + obj["connectorId"] = connectorId; + obj["snapshotId"] = snapshotId_; + obj["tableType"] = paimonTableTypeString(tableType_); + obj["rawConvertible"] = rawConvertible_; + + folly::dynamic filesArray = folly::dynamic::array; + for (const auto& file : dataFiles_) { + filesArray.push_back(file.serialize()); + } + obj["dataFiles"] = filesArray; + + folly::dynamic partitionKeysObj = folly::dynamic::object; + for (const auto& [key, value] : partitionKeys_) { + partitionKeysObj[key] = + value.has_value() ? folly::dynamic(value.value()) : nullptr; + } + obj["partitionKeys"] = partitionKeysObj; + + obj["tableBucketNumber"] = tableBucketNumber_.has_value() + ? folly::dynamic(tableBucketNumber_.value()) + : nullptr; + + obj["fileFormat"] = dwio::common::toString(fileFormat_); + + return obj; +} + +// static +std::shared_ptr PaimonConnectorSplit::create( + const folly::dynamic& obj) { + const auto connectorId = obj["connectorId"].asString(); + const auto snapshotId = obj["snapshotId"].asInt(); + const auto tableType = paimonTableTypeFromString(obj["tableType"].asString()); + const auto rawConvertible = obj["rawConvertible"].asBool(); + + std::vector dataFiles; + for (const auto& fileObj : obj["dataFiles"]) { + dataFiles.emplace_back(PaimonDataFile::create(fileObj)); + } + + std::unordered_map> partitionKeys; + for (const auto& [key, value] : obj["partitionKeys"].items()) { + partitionKeys[key.asString()] = value.isNull() + ? std::nullopt + : std::optional(value.asString()); + } + + const auto tableBucketNumber = obj["tableBucketNumber"].isNull() + ? std::nullopt + : std::optional(obj["tableBucketNumber"].asInt()); + + const auto fileFormat = + dwio::common::toFileFormat(obj["fileFormat"].asString()); + + return std::make_shared( + connectorId, + snapshotId, + tableType, + fileFormat, + dataFiles, + std::move(partitionKeys), + tableBucketNumber, + rawConvertible); +} + +// static +void PaimonConnectorSplit::registerSerDe() { + auto& registry = DeserializationRegistryForSharedPtr(); + registry.Register("PaimonConnectorSplit", PaimonConnectorSplit::create); +} + +// --- Builder --- + +PaimonConnectorSplitBuilder& PaimonConnectorSplitBuilder::addFile( + std::string filePath, + uint64_t fileSize, + int32_t level) { + PaimonDataFile meta; + meta.path = std::move(filePath); + meta.size = fileSize; + meta.level = level; + dataFiles_.emplace_back(std::move(meta)); + return *this; +} + +PaimonConnectorSplitBuilder& PaimonConnectorSplitBuilder::partitionKey( + std::string name, + std::optional value) { + partitionKeys_.emplace(std::move(name), std::move(value)); + return *this; +} + +PaimonConnectorSplitBuilder& PaimonConnectorSplitBuilder::tableBucketNumber( + int32_t bucketId) { + tableBucketNumber_ = bucketId; + return *this; +} + +PaimonConnectorSplitBuilder& PaimonConnectorSplitBuilder::rawConvertible( + bool value) { + rawConvertible_ = value; + return *this; +} + +std::shared_ptr PaimonConnectorSplitBuilder::build() { + return std::make_shared( + connectorId_, + snapshotId_, + tableType_, + fileFormat_, + dataFiles_, + partitionKeys_, + tableBucketNumber_, + rawConvertible_); +} + +} // namespace facebook::velox::connector::hive::paimon diff --git a/velox/connectors/hive/paimon/PaimonConnectorSplit.h b/velox/connectors/hive/paimon/PaimonConnectorSplit.h new file mode 100644 index 00000000000..9182b51b22a --- /dev/null +++ b/velox/connectors/hive/paimon/PaimonConnectorSplit.h @@ -0,0 +1,587 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +#include "velox/connectors/Connector.h" +#include "velox/connectors/hive/paimon/PaimonDataFileMeta.h" +#include "velox/dwio/common/Options.h" + +namespace facebook::velox::connector::hive::paimon { + +/// Paimon table type determines read and write semantics. +/// +/// # Table Types +/// +/// Append-only: No primary key. Each write appends new independent files. All +/// files are at LSM level 0 and rawConvertible is always true. Files can be +/// read in any order for batch queries. No merge-on-read needed. Compaction +/// only reduces the number of small files (concatenation, no deduplication). +/// +/// Primary-key: Uses an LSM tree to organize data files per bucket. New writes +/// go to level 0; background compaction merges files into higher levels, +/// deduplicating by primary key. The merge engine determines how duplicate +/// keys are resolved: +/// +/// Deduplicate (upsert): Latest write wins — entire row replaced. +/// Write {id=1, name="Alice"} then {id=1, name="Bob"} +/// → result: {id=1, name="Bob"} +/// +/// Partial-update: Only non-null columns in new record overwrite old values. +/// Write {id=1, name="Alice", age=25} then {id=1, name=NULL, age=26} +/// → result: {id=1, name="Alice", age=26} +/// +/// # Partitioning and Bucketing +/// +/// Paimon uses a fixed physical layout: partition-dirs / bucket-dir / files. +/// The hierarchy is strictly partition → bucket → files with no nesting or +/// reordering (no bucket-before-partition or bucket-after-bucket). +/// +/// Partitioning: Hive-style directory partitioning by column values. Zero or +/// more partition columns, ordered. Each unique combination of partition values +/// creates a physical directory. Partition columns are NOT stored in the data +/// files (values are in the directory path). Enables partition pruning. +/// +/// Bucketing: Hash-distributes rows within a partition into N buckets. Each +/// bucket is a physical directory (bucket-0/, bucket-1/, ...) and acts as an +/// independent LSM tree instance. For primary-key tables, the bucket key +/// defaults to the primary key, ensuring each key maps to exactly one bucket +/// (merge-on-read only needs to look within a single bucket). For append-only +/// tables, bucket key can be any column(s) for write parallelism. +/// +/// Unbucketed tables use a single implicit bucket-0/ directory. All four +/// combinations of partitioned/unpartitioned × bucketed/unbucketed are valid: +/// +/// Partitioned + bucketed: dt=2024-01-01/bucket-0/data.orc +/// Partitioned + unbucketed: dt=2024-01-01/bucket-0/data.orc (bucket-0 only) +/// Unpartitioned + bucketed: bucket-0/data.orc, bucket-1/data.orc, ... +/// Unpartitioned + unbucketed: bucket-0/data.orc (bucket-0 only) +/// +/// Each PaimonConnectorSplit represents one partition + one bucket. The +/// partitionKeys field carries partition column values, and tableBucketNumber +/// identifies the bucket (nullopt for unbucketed tables). +/// +/// # LSM Tree Structure (Primary-Key Tables) +/// +/// Level 0: Newest data. Each flush creates a new file. Files within level 0 +/// CAN have overlapping keys (multiple files may contain the same key). +/// minSequenceNumber == maxSequenceNumber (single commit per file). +/// +/// Level 1+: Compacted data. Compaction merges lower-level files, deduplicating +/// keys and producing non-overlapping key ranges. No two files at the same +/// level share a key. minSequenceNumber < maxSequenceNumber (merged from +/// multiple commits). +/// +/// # Compaction Strategy (Universal Compaction) +/// +/// Primary-key tables use Universal Compaction (similar to RocksDB): +/// +/// Sorted runs: Each L0 file is its own sorted run. All files at the same +/// higher level (L1, L2, ...) together form one sorted run with +/// non-overlapping keys. Compaction merges sorted runs. +/// +/// Trigger: When the number of sorted runs exceeds +/// num-sorted-run.compaction-trigger (default: 5), compaction is triggered. +/// +/// Picking strategy: Compaction picks the OLDEST sorted runs first and merges +/// them until the count drops below the trigger threshold. The output level +/// depends on which sorted runs are merged. Multiple non-L0 levels can +/// exist (L1, L2, ...). Within each level, files have non-overlapping key +/// ranges. L0 files can overlap with any level. +/// +/// Who runs compaction: +/// - Default: the writer process (Flink task) runs compaction asynchronously +/// after each commit. +/// - Dedicated mode: a separate compaction job can be configured to offload +/// compaction from writers. +/// - Either way, compaction creates a new snapshot with updated manifest +/// entries (REMOVE old files, ADD compacted file). +/// +/// Compaction output: +/// - Data file: deduplicated keys, latest value wins (by sequence number). +/// - With changelog-producer=lookup: also generates a changelog file. +/// - With deletion-vectors.enabled: may generate deletion bitmaps instead of +/// rewriting data files. +/// +/// Append-only compaction: Only concatenates small files into larger ones (no +/// key deduplication). All files stay at level 0. Triggered when the number +/// of small files exceeds a threshold. +/// +/// rawConvertible: Per-split flag set by the Paimon planner. +/// false = keys overlap across levels, merge-on-read required. +/// true = fully compacted (single level, no overlapping keys), read +/// directly. Always true for append-only tables. +/// +/// # Snapshots and Manifests +/// +/// Every Paimon read (batch or streaming) goes through snapshots — immutable +/// point-in-time views of the table. A snapshot points to a manifest-list, +/// which references manifest files that record which data files are active. +/// +/// Physical layout on storage: +/// +/// table-path/ +/// snapshot/ +/// snapshot-1 ← JSON: points to manifest-list-1 +/// snapshot-2 ← JSON: points to manifest-list-2 +/// manifest/ +/// manifest-list-1 ← lists which manifest files to read +/// manifest-1 ← data file entries (ADD/REMOVE) +/// manifest-2 +/// dt=2024-01-01/ +/// bucket-0/ +/// data-001.orc ← actual data files +/// data-002.orc +/// schema/ +/// schema-0 ← table schema +/// +/// Each manifest entry contains both structured metadata and the file path: +/// +/// {action: ADD, +/// partition: {dt: "2024-01-01"}, ← structured, for fast pruning +/// bucket: 0, ← structured, for fast pruning +/// filePath: "dt=2024-01-01/bucket-0/data-001.orc", ← for I/O +/// level: 0, minSeq: 10, maxSeq: 10, ...} +/// +/// The partition/bucket values are redundant with the file path (which +/// encodes them in the directory structure) — the structured fields enable +/// fast partition/bucket pruning without path parsing. +/// +/// The Paimon planner reads the manifest, applies partition pruning, and +/// generates one PaimonConnectorSplit per partition × bucket combination: +/// +/// Manifest entries after pruning for dt='2024-01-01': +/// partition={dt: "2024-01-01"}, bucket=0: [data-001.orc, data-002.orc] +/// partition={dt: "2024-01-01"}, bucket=1: [data-001.orc] +/// → Split 1: partitionKeys={dt: "2024-01-01"}, bucket=0, 2 files +/// → Split 2: partitionKeys={dt: "2024-01-01"}, bucket=1, 1 file +/// +/// Snapshot isolation: Readers see only files referenced by their snapshot. +/// Concurrent writers creating new snapshots don't affect in-flight reads. +/// Old data files are only deleted after snapshot expiration (no snapshot +/// references them). The manifest is authoritative — partition directories +/// may contain old compacted-away files that no snapshot references. +/// +/// # Read Modes +/// +/// All reads are snapshot-based. The mode determines how many snapshots are +/// read and what output is produced: +/// +/// Batch (snapshot read): Reads the full file set at ONE snapshot. For +/// primary-key tables, merge-on-read ALWAYS required to deduplicate by key +/// and return the current state — regardless of changelog mode. This is the +/// Velox/Presto read path. +/// +/// Time travel: Batch read at a user-specified older snapshot (instead of +/// the latest). Same mechanism, different snapshot ID. Only works if the +/// snapshot hasn't been expired. +/// +/// Streaming: Reads DELTAS between consecutive snapshots (N→N+1→N+2→...). +/// Only processes newly added files (the manifest diff), emitting changelog +/// records. Runs continuously, polling for new snapshots. Specifies only a +/// start snapshot — the end is unbounded. Merge-on-read is NOT needed for +/// streaming — the reader only sees delta files, not the full file set, so +/// there is nothing to merge against. For input changelog mode, RowKind is +/// stored in the files and emitted directly. For upsert mode, the reader +/// infers changelog from the delta. Compaction deltas are a special case: +/// the changelog is collapsed, so changelog-producer=lookup is needed to +/// regenerate meaningful changelog records during compaction. +/// +/// Incremental: Bounded streaming — reads deltas from snapshot N to M, then +/// stops. Same delta mechanism as streaming but with an end bound. +/// +/// Merge-on-read summary for primary-key tables: +/// Batch/time-travel: Always required (full file set, must deduplicate). +/// Streaming/incremental: Never required (delta files only, no merge). +/// +/// # Coordinator/Reader Responsibility Split +/// +/// Split generation follows a coordinator/reader architecture (mirrors +/// Flink's enumerator/reader pattern): +/// +/// Coordinator (Java planner): +/// - Batch: reads manifest at one snapshot, groups files by partition × +/// bucket, generates one PaimonConnectorSplit per group. +/// - Streaming: starts at a given snapshot, continuously polls for new +/// snapshots. For each new snapshot (N→N+1), computes the manifest diff +/// (newly added files), groups delta files by partition × bucket, and +/// generates one split per group. Repeats indefinitely. +/// - Incremental: same as streaming but stops at the end snapshot M. +/// - With changelog-producer=lookup: sends changelog files (kCompact) for +/// streaming, data files (kAppend) for batch. +/// - Without changelog-producer: sends data files (kAppend) for all modes. +/// +/// Reader (Velox worker): +/// - Stateless — receives splits and reads files. Does not track snapshots, +/// does not compute deltas, does not decide which file type to read. +/// - For batch: performs merge-on-read if rawConvertible is false. +/// - For streaming: emits rows directly (as +I without changelog-producer, +/// or with stored RowKind for changelog files / input changelog mode). +/// +/// # Changelog Semantics +/// +/// Primary-key tables can produce a changelog stream with four row kinds: +/// +I (INSERT): New key added. +/// -U (UPDATE_BEFORE): Old value being replaced (for retraction). +/// +U (UPDATE_AFTER): New value replacing it. +/// -D (DELETE): Key removed. +/// +/// -U and +U are always emitted as a consecutive pair for the same key. +/// Intermediate updates are collapsed — the consumer sees only the net +/// change (oldest value → newest value), not every step. +/// +/// # RowKind and the _rowkind Column +/// +/// Every primary-key table supports a hidden system column `_rowkind` that +/// encodes the change type per row. At the file format level (Parquet, ORC, +/// Nimble), `_rowkind` is a regular TINYINT column — the file format has no +/// special knowledge of it. It is "hidden" only at the SQL layer (not shown +/// in SELECT *). Values: 0=+I, 1=-U, 2=+U, 3=-D. +/// +/// Whether `_rowkind` is actually written to data files depends on the mode: +/// +/// Upsert mode (normal writes): `_rowkind` is omitted from data files since +/// all records are implicitly +I/+U (insert or update, latest wins). The +/// streaming reader INFERS -U/+U pairs during merge-on-read by comparing +/// old vs new values for the same key across LSM levels. When a SQL DELETE +/// is issued, Paimon writes a record with `_rowkind=-D` to level 0, along +/// with the full before-image (all column values populated). Compaction +/// removes the key entirely. +/// +/// Limitation: Streaming reads in upsert mode only process delta files +/// (newly added files between snapshots) — the reader does NOT look back +/// at previous snapshots. A key appearing once in the delta is emitted as +/// +I, even if that key already existed in a prior snapshot (should be +U). +/// For correct +I vs +U distinction, use changelog-producer=lookup which +/// looks up old values from existing files to determine the actual change +/// type. This limitation does NOT apply to input changelog mode — the +/// source provides the correct RowKind (+I/+U/-U/-D) explicitly. +/// +/// Input changelog mode (CDC source writes): `_rowkind` IS written for every +/// row. The external source (Flink CDC, Kafka) provides the change type. +/// The streaming reader emits rows with their stored RowKind directly — no +/// merge inference needed for uncompacted data. Files must be read in +/// sequence-number order since -U/+U pairs may span files. Compaction +/// collapses the stored changelog; use changelog-producer=lookup to +/// regenerate it during compaction. +/// +/// rowkind.field mode: Instead of the hidden `_rowkind` column, the table +/// designates a user-visible column (e.g., "op") to carry the change type. +/// Useful when the CDC source provides change type as a regular field. At +/// the file level, it's just a regular column — the Paimon engine +/// interprets its values as RowKind. +/// +/// # Deletes +/// +/// A delete record is NOT a row with null values. It's a complete row (full +/// before-image) tagged with RowKind=-D. The mechanism depends on the source: +/// +/// SQL DELETE: Paimon engine writes {_rowkind=-D, id=1, name="Alice"} +/// CDC source: External source provides -D record with full values +/// rowkind.field: User column carries the delete marker +/// +/// The full old column values (before-image) are carried for streaming/ +/// changelog consumers that need retraction semantics. Without old values, +/// the consumer would need its own state lookup to know what was deleted: +/// +/// Example — aggregation counting users by country: +/// State: {id=1, name="Alice", country="US"} → US count = 1 +/// +/// -D {id=1} (key only): +/// Consumer doesn't know which country to decrement — must look up. +/// -D {id=1, name="Alice", country="US"} (full before-image): +/// Consumer decrements US count directly. No lookup needed. +/// +/// Same reason -U (UPDATE_BEFORE) carries the full old row — the consumer +/// retracts old values before applying +U (UPDATE_AFTER) with new values. +/// For batch reads, the old values don't matter (compaction removes the key). +/// +/// Separately, deletion files (PaimonDeletionFile) are a bitmap-based +/// mechanism that marks row positions as deleted within an existing data +/// file without rewriting it. This is orthogonal to RowKind-based deletes. +/// +/// For batch/snapshot reads (Presto), changelog semantics don't apply — the +/// reader returns the current state of each row, not the change history. +/// +/// Append-only tables only produce +I (INSERT) records — no updates or +/// deletes since there's no primary key to identify rows. Streaming reads +/// from append-only tables emit all new rows as +I; file read order doesn't +/// matter since there are no retraction pairs. +/// +/// # Sequence Numbers +/// +/// Auto-generated per-commit, monotonically increasing. Used within a +/// snapshot for merge-on-read ordering (higher sequence number wins for +/// duplicate keys). NOT used for snapshot isolation (that's handled by the +/// manifest chain). Sequence numbers are per-file metadata, not per-row — +/// after compaction, per-row attribution is lost but the LSM level order +/// makes it unnecessary. +/// +/// # changelog-producer=lookup +/// +/// When configured on a primary-key table, the compactor generates TWO +/// outputs per compaction: (1) a data file (no `_rowkind` in upsert mode) +/// and (2) a separate changelog file containing the pre-computed changelog +/// with correct RowKind (+I/+U/-U/-D, includes `_rowkind`). Both files are +/// marked as PaimonFileSource::kCompact (produced by compaction). Without +/// this flag, compaction produces only a data file. +/// +/// The coordinator (Java planner) sends the right files based on read mode: +/// +/// Batch read: sends data files only → merge-on-read as usual +/// Streaming read: sends changelog files only → RowKind already correct, +/// no inference or merge needed +/// +/// Without changelog-producer=lookup, streaming reads work as follows: +/// +/// Streaming read (no changelog-producer): +/// Coordinator sends the delta data files (newly added between snapshots). +/// The reader does NOT do merge-on-read — it only processes delta files, +/// not the full file set. Each row in the delta is emitted as +I. The +/// limitation is that the reader cannot distinguish +I (new key) from +U +/// (existing key updated) because it doesn't look back at previous +/// snapshots. For many use cases (e.g., downstream upsert sinks) this is +/// acceptable since the consumer treats +I and +U identically. +/// +/// Compaction snapshots are problematic without changelog-producer: the +/// compacted file contains ALL keys (not just changed ones), so the reader +/// would emit +I for every key — including unchanged ones. With +/// changelog-producer=lookup, the compactor compares old vs new values and +/// generates a changelog file with only actual changes. +/// +/// If the same key is updated across multiple consecutive snapshots (N→N+1, +/// N+1→N+2, ...), each delta is processed independently. The reader emits +/// one +I per delta that touches the key. The downstream consumer must be +/// idempotent (e.g., upsert by key) to handle this correctly. Without +/// merge-on-read, the reader has no version information — it does not know +/// whether a key is new or updated, nor can it deduplicate across deltas. +/// This is by design: streaming reads trade correctness of change types for +/// simplicity and performance. +/// +/// Example — key updated across two snapshots without changelog-producer: +/// +/// Delta N→N+1: {id=1, name="Alice"} → emitted as +I +/// Delta N+1→N+2: {id=1, name="Bob"} → emitted as +I (not +U) +/// +/// This works for upsert sinks (consumer does INSERT-or-UPDATE by key — +/// final state is correct regardless of +I vs +U). It breaks for retraction +/// sinks (e.g., aggregation counting by name): +/// +/// Correct: +I "Alice" (count=1), -U "Alice" + +U "Bob" (count=1) +/// Without changelog-producer: +I "Alice" (count=1), +I "Bob" (count=2) ✗ +/// +/// For exact changelog semantics, use changelog-producer=lookup. +/// +/// The file origin (write vs compaction) is indicated by PaimonFileSource in +/// PaimonDataFile::fileSource. Both data files and changelog files use +/// the same physical file format. Note that PaimonFileSource indicates HOW +/// the file was produced, not WHETHER it is a data or changelog file — both +/// are marked kCompact when produced by compaction. The coordinator tracks +/// which is which via separate manifest entries. +/// +/// Schema difference: In upsert mode (no input changelog), data files do NOT +/// contain the `_rowkind` column (all records are implicitly +I/+U). Changelog +/// files generated by changelog-producer=lookup DO contain `_rowkind` since +/// they carry explicit change types. The coordinator knows which file type it +/// is sending, so the reader schema is adjusted accordingly. In input changelog +/// mode, both data and changelog files contain `_rowkind`. +/// +/// # Velox Connector Scope +/// +/// Currently only batch reads (Presto) are supported. Only data files are +/// needed — changelog files and streaming semantics are not required. +/// +/// Future streaming support: We will require changelog-producer=lookup (or +/// full-compaction) to be enabled on the table. This means the compactor +/// generates pre-computed changelog files, and the streaming reader consumes +/// them directly. We do NOT plan to support beforeFiles-based diffing (where +/// the reader merges old and new file sets to infer changelog at read time). +/// This avoids expensive read-time merge and simplifies the reader — the +/// coordinator sends changelog files for compaction snapshots and delta data +/// files for non-compaction snapshots. +enum class PaimonTableType { + /// No primary key. Data files are independent — no merge-on-read needed. + /// All files at level 0, rawConvertible always true. + kAppendOnly, + /// Has a primary key. Uses LSM tree with merge-on-read to deduplicate + /// records by key when data spans multiple compaction levels. + kPrimaryKey, +}; + +/// Returns the string name of the table type (e.g., "APPEND_ONLY"). +std::string paimonTableTypeString(PaimonTableType type); + +/// Parses a table type from its string name. +PaimonTableType paimonTableTypeFromString(const std::string& str); + +FOLLY_ALWAYS_INLINE std::ostream& operator<<( + std::ostream& os, + PaimonTableType type) { + os << paimonTableTypeString(type); + return os; +} + +} // namespace facebook::velox::connector::hive::paimon + +template <> +struct fmt::formatter + : formatter { + auto format( + facebook::velox::connector::hive::paimon::PaimonTableType type, + format_context& ctx) const { + return formatter::format( + facebook::velox::connector::hive::paimon::paimonTableTypeString(type), + ctx); + } +}; + +namespace facebook::velox::connector::hive::paimon { + +/// Represents a Paimon DataSplit — a collection of data files in one +/// partition/bucket. +/// +/// Unlike HiveConnectorSplit (one file per split), a Paimon split maps to a +/// logical bucket which may contain multiple physical files across LSM-tree +/// levels. +/// +/// NOTE: Table-wide metadata (primary key columns, merge engine, table schema) +/// is NOT carried in the split. A future PaimonTableHandle (extending +/// HiveTableHandle) will carry this information, needed for merge-on-read +/// (key deduplication) and schema evolution. For batch reads of +/// rawConvertible=true splits, the existing HiveTableHandle suffices. +class PaimonConnectorSplit : public connector::ConnectorSplit { + public: + /// @param connectorId Connector identifier. + /// @param snapshotId Paimon table snapshot version this split was generated + /// from. + /// @param tableType Whether this is an append-only or primary-key table. + /// @param fileFormat File format of the data files (e.g., ORC, Parquet). + /// @param dataFiles Data files in this split, each representing a physical + /// file in the LSM-tree. + /// @param partitionKeys Partition key-value pairs. Keys map to partition + /// column names; values are nullopt for null partitions. + /// @param tableBucketNumber Bucket number within the Paimon table's bucket + /// distribution. Nullopt for unbucketed tables. + /// @param rawConvertible Per-split flag indicating whether all files in this + /// split (partition × bucket) can be read without merge-on-read (no + /// key deduplication needed across LSM levels). Set by the Paimon + /// planner based on the compaction state of the entire bucket. Only + /// meaningful for primary-key tables. + PaimonConnectorSplit( + const std::string& connectorId, + int64_t snapshotId, + PaimonTableType tableType, + dwio::common::FileFormat fileFormat, + const std::vector& dataFiles, + std::unordered_map> partitionKeys, + std::optional tableBucketNumber, + bool rawConvertible = true); + + int64_t snapshotId() const { + return snapshotId_; + } + + PaimonTableType tableType() const { + return tableType_; + } + + const std::vector& dataFiles() const { + return dataFiles_; + } + + const std::unordered_map>& + partitionKeys() const { + return partitionKeys_; + } + + std::optional tableBucketNumber() const { + return tableBucketNumber_; + } + + bool rawConvertible() const { + return rawConvertible_; + } + + dwio::common::FileFormat fileFormat() const { + return fileFormat_; + } + + std::string toString() const override; + + folly::dynamic serialize() const override; + + static std::shared_ptr create( + const folly::dynamic& obj); + + static void registerSerDe(); + + private: + const int64_t snapshotId_; + const PaimonTableType tableType_; + const dwio::common::FileFormat fileFormat_; + const std::vector dataFiles_; + const std::unordered_map> + partitionKeys_; + const std::optional tableBucketNumber_; + const bool rawConvertible_; +}; + +/// Builder for PaimonConnectorSplit construction. +class PaimonConnectorSplitBuilder { + public: + PaimonConnectorSplitBuilder( + std::string connectorId, + int64_t snapshotId, + PaimonTableType tableType, + dwio::common::FileFormat fileFormat) + : connectorId_(std::move(connectorId)), + snapshotId_(snapshotId), + tableType_(tableType), + fileFormat_(fileFormat) {} + + PaimonConnectorSplitBuilder& + addFile(std::string filePath, uint64_t fileSize, int32_t level = 0); + + PaimonConnectorSplitBuilder& partitionKey( + std::string name, + std::optional value); + + PaimonConnectorSplitBuilder& tableBucketNumber(int32_t bucketId); + + PaimonConnectorSplitBuilder& rawConvertible(bool value); + + std::shared_ptr build(); + + private: + const std::string connectorId_; + const int64_t snapshotId_; + const PaimonTableType tableType_; + const dwio::common::FileFormat fileFormat_; + std::vector dataFiles_; + std::unordered_map> partitionKeys_; + std::optional tableBucketNumber_; + bool rawConvertible_{true}; +}; + +} // namespace facebook::velox::connector::hive::paimon diff --git a/velox/connectors/hive/paimon/PaimonDataFileMeta.cpp b/velox/connectors/hive/paimon/PaimonDataFileMeta.cpp new file mode 100644 index 00000000000..82b3437028f --- /dev/null +++ b/velox/connectors/hive/paimon/PaimonDataFileMeta.cpp @@ -0,0 +1,120 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/connectors/hive/paimon/PaimonDataFileMeta.h" + +#include "velox/common/base/Exceptions.h" + +namespace facebook::velox::connector::hive::paimon { + +// static +std::string PaimonDataFile::typeString(Type type) { + switch (type) { + case Type::kData: + return "DATA"; + case Type::kChangelog: + return "CHANGELOG"; + default: + VELOX_FAIL("Unknown PaimonDataFile::Type: {}", static_cast(type)); + } +} + +// static +PaimonDataFile::Type PaimonDataFile::typeFromString(const std::string& str) { + if (str == "DATA") { + return Type::kData; + } + if (str == "CHANGELOG") { + return Type::kChangelog; + } + VELOX_FAIL("Unknown PaimonDataFile::Type: {}", str); +} + +// static +std::string PaimonDataFile::sourceString(Source source) { + switch (source) { + case Source::kAppend: + return "APPEND"; + case Source::kCompact: + return "COMPACT"; + default: + VELOX_FAIL( + "Unknown PaimonDataFile::Source: {}", static_cast(source)); + } +} + +// static +PaimonDataFile::Source PaimonDataFile::sourceFromString( + const std::string& str) { + if (str == "APPEND") { + return Source::kAppend; + } + if (str == "COMPACT") { + return Source::kCompact; + } + VELOX_FAIL("Unknown PaimonDataFile::Source: {}", str); +} + +std::string PaimonDataFile::toString() const { + return fmt::format( + "{{path={}, size={}, rows={}, level={}, type={}, source={}, " + "deletionFile={}}}", + path, + size, + rowCount, + level, + typeString(type), + sourceString(source), + deletionFile.has_value() ? deletionFile->toString() : "none"); +} + +folly::dynamic PaimonDataFile::serialize() const { + folly::dynamic obj = folly::dynamic::object; + obj["filePath"] = path; + obj["fileSize"] = size; + obj["rowCount"] = rowCount; + obj["level"] = level; + obj["minSequenceNumber"] = minSequenceNumber; + obj["maxSequenceNumber"] = maxSequenceNumber; + obj["deleteRowCount"] = deleteRowCount; + obj["creationTimeMs"] = creationTimeMs; + obj["fileType"] = typeString(type); + obj["sourceType"] = sourceString(source); + if (deletionFile.has_value()) { + obj["deletionFile"] = deletionFile->serialize(); + } + return obj; +} + +// static +PaimonDataFile PaimonDataFile::create(const folly::dynamic& obj) { + PaimonDataFile file; + file.path = obj["filePath"].asString(); + file.size = static_cast(obj["fileSize"].asInt()); + file.rowCount = static_cast(obj["rowCount"].asInt()); + file.level = static_cast(obj["level"].asInt()); + file.minSequenceNumber = obj["minSequenceNumber"].asInt(); + file.maxSequenceNumber = obj["maxSequenceNumber"].asInt(); + file.deleteRowCount = obj["deleteRowCount"].asInt(); + file.creationTimeMs = obj["creationTimeMs"].asInt(); + file.type = typeFromString(obj["fileType"].asString()); + file.source = sourceFromString(obj["sourceType"].asString()); + if (obj.count("deletionFile") > 0) { + file.deletionFile = PaimonDeletionFile::create(obj["deletionFile"]); + } + return file; +} + +} // namespace facebook::velox::connector::hive::paimon diff --git a/velox/connectors/hive/paimon/PaimonDataFileMeta.h b/velox/connectors/hive/paimon/PaimonDataFileMeta.h new file mode 100644 index 00000000000..5011ba5e523 --- /dev/null +++ b/velox/connectors/hive/paimon/PaimonDataFileMeta.h @@ -0,0 +1,199 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +#include "velox/connectors/hive/paimon/PaimonDeletionFile.h" + +namespace facebook::velox::connector::hive::paimon { + +/// Represents a single Paimon file (data or changelog) within a split. +/// Mirrors Apache Paimon's DataFileMeta structure. +/// +/// Both data files and changelog files use the same physical file format +/// (Parquet, ORC, Nimble) and carry the same metadata. The distinction is +/// captured by `type`: +/// - kData: regular data file (current state of records). +/// - kChangelog: changelog file with RowKind (+I/-U/+U/-D) for streaming. +struct PaimonDataFile { + /// Whether this file is a data file or a changelog file. Both use the + /// same physical format — this enum distinguishes their semantic role. + /// + /// Data files contain the current state of records. Changelog files + /// contain the change history with RowKind tags. The coordinator sends + /// data files for batch reads and changelog files for streaming reads + /// (when changelog-producer=lookup is configured). + enum class Type { + /// Regular data file containing current record state. + kData, + + /// Changelog file containing records tagged with RowKind (+I/-U/+U/-D). + /// Only produced when changelog-producer=lookup is configured on the table. + /// The coordinator sends these for streaming reads instead of data files. + kChangelog, + }; + + /// Origin of a Paimon file — how it was produced, not what it contains. + /// Mirrors Paimon's DataFileMeta.FileSource enum. + /// + /// This does NOT indicate whether the file contains `_rowkind`. The presence + /// of `_rowkind` depends on the changelog mode (upsert vs input changelog), + /// not the source type. + enum class Source { + /// File produced by a normal write (flush/append). + kAppend, + + /// File produced by compaction. This could be: + /// - Without changelog-producer=lookup: a regular data file (same as + /// kAppend but produced by compaction instead of a write). + /// - With changelog-producer=lookup: compaction produces TWO files — + /// (1) a data file (no `_rowkind` in upsert mode) and (2) a separate + /// changelog file (with `_rowkind` containing correct +I/-U/+U/-D). + /// Both are marked kCompact. The coordinator sends the right one based + /// on the read mode (data file for batch, changelog file for + /// streaming). + kCompact, + }; + + /// Returns the string name of the file type (e.g., "DATA"). + static std::string typeString(Type type); + + /// Parses a file type from its string name. + static Type typeFromString(const std::string& str); + + /// Returns the string name of the source type (e.g., "APPEND"). + static std::string sourceString(Source source); + + /// Parses a source type from its string name. + static Source sourceFromString(const std::string& str); + + /// Path to the file (ORC, Parquet, etc.). + std::string path; + + /// Size of the file in bytes. + uint64_t size{0}; + + /// Number of rows in this file. + uint64_t rowCount{0}; + + /// LSM-tree level of this file. Level 0 contains the newest (unflushed) + /// data; higher levels contain progressively more compacted data. Within + /// level 0, files CAN have overlapping keys. Within level 1+, compaction + /// guarantees non-overlapping key ranges across files. + /// Always 0 for append-only tables (no compaction). + int32_t level{0}; + + /// Sequence number range of records in this file. Auto-generated per-commit, + /// monotonically increasing. For level 0 files, min == max (single commit). + /// For compacted files (level 1+), min < max (merged from multiple commits). + /// Used during merge-on-read to resolve duplicate keys — higher sequence + /// number wins. Per-file metadata only (not per-row); after compaction, + /// per-row sequence attribution is lost but LSM level order makes it + /// unnecessary. + int64_t minSequenceNumber{0}; + int64_t maxSequenceNumber{0}; + + /// Number of rows in this file with RowKind = DELETE or UPDATE_BEFORE. + /// row_count = addRowCount + deleteRowCount, where addRowCount is the + /// number of INSERT or UPDATE_AFTER rows. + /// + /// Only applicable to primary-key tables. Append-only tables have no + /// _rowkind column and every row is implicitly +I, so deleteRowCount is + /// always 0. + /// + /// This is independent of deletionFile — deleteRowCount counts changelog + /// records stored inside the file, while deletionFile is an external bitmap + /// of positionally deleted rows. + /// + /// Used to determine rawConvertible: if deleteRowCount > 0, the file + /// contains changelog records and cannot be read raw (needs RowKind + /// filtering during merge-on-read). + int64_t deleteRowCount{0}; + + /// Timestamp (epoch millis) when this file was created. + int64_t creationTimeMs{0}; + + /// Whether this is a data file or changelog file. + Type type{Type::kData}; + + /// How this file was produced (write vs compaction). + Source source{Source::kAppend}; + + /// Deletion file for this data file. Contains a roaring bitmap of deleted + /// row positions. Nullopt if no rows have been deleted from this file. + /// Applies to both primary-key and append-only tables (when + /// deletion-vectors.enabled is set). Orthogonal to deleteRowCount — + /// deletionFile marks positional deletes, while deleteRowCount counts + /// RowKind-based changelog records. + /// See PaimonDeletionFile for details. + std::optional deletionFile; + + std::string toString() const; + folly::dynamic serialize() const; + static PaimonDataFile create(const folly::dynamic& obj); +}; + +FOLLY_ALWAYS_INLINE std::ostream& operator<<( + std::ostream& os, + PaimonDataFile::Type type) { + os << PaimonDataFile::typeString(type); + return os; +} + +FOLLY_ALWAYS_INLINE std::ostream& operator<<( + std::ostream& os, + PaimonDataFile::Source source) { + os << PaimonDataFile::sourceString(source); + return os; +} + +} // namespace facebook::velox::connector::hive::paimon + +template <> +struct fmt::formatter< + facebook::velox::connector::hive::paimon::PaimonDataFile::Type> + : formatter { + auto format( + facebook::velox::connector::hive::paimon::PaimonDataFile::Type type, + format_context& ctx) const { + return formatter::format( + facebook::velox::connector::hive::paimon::PaimonDataFile::typeString( + type), + ctx); + } +}; + +template <> +struct fmt::formatter< + facebook::velox::connector::hive::paimon::PaimonDataFile::Source> + : formatter { + auto format( + facebook::velox::connector::hive::paimon::PaimonDataFile::Source source, + format_context& ctx) const { + return formatter::format( + facebook::velox::connector::hive::paimon::PaimonDataFile::sourceString( + source), + ctx); + } +}; diff --git a/velox/connectors/hive/paimon/PaimonDataSource.cpp b/velox/connectors/hive/paimon/PaimonDataSource.cpp new file mode 100644 index 00000000000..8ef7a14a9fc --- /dev/null +++ b/velox/connectors/hive/paimon/PaimonDataSource.cpp @@ -0,0 +1,86 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/connectors/hive/paimon/PaimonDataSource.h" + +#include "velox/common/Casts.h" +#include "velox/connectors/hive/paimon/PaimonConnectorSplit.h" +#include "velox/connectors/hive/paimon/PaimonSplitReader.h" + +namespace facebook::velox::connector::hive::paimon { + +PaimonDataSource::PaimonDataSource( + const RowTypePtr& outputType, + const ConnectorTableHandlePtr& tableHandle, + const ColumnHandleMap& assignments, + FileHandleFactory* fileHandleFactory, + folly::Executor* ioExecutor, + const ConnectorQueryCtx* connectorQueryCtx, + const std::shared_ptr& paimonConfig) + : FileDataSource( + outputType, + tableHandle, + assignments, + fileHandleFactory, + ioExecutor, + connectorQueryCtx, + paimonConfig) {} + +void PaimonDataSource::addSplit(std::shared_ptr split) { + paimonSplit_ = checkedPointerCast(split); + + if (!paimonSplit_->rawConvertible()) { + VELOX_NYI( + "Paimon merge-on-read is not yet implemented. " + "Primary-key tables with rawConvertible=false require merge-on-read " + "to deduplicate records across LSM levels."); + } + + // Create a FileConnectorSplit for the first data file and delegate to + // FileDataSource::addSplit(), which calls createSplitReader() to create + // a PaimonSplitReader that handles all files internally. + const auto& firstFile = paimonSplit_->dataFiles().front(); + auto firstFileSplit = std::make_shared( + paimonSplit_->connectorId, + firstFile.path, + paimonSplit_->fileFormat(), + /*_start=*/0, + /*_length=*/std::numeric_limits::max(), + /*splitWeight=*/0, + /*cacheable=*/true, + /*_properties=*/std::nullopt, + paimonSplit_->partitionKeys()); + + FileDataSource::addSplit(std::move(firstFileSplit)); +} + +std::unique_ptr PaimonDataSource::createSplitReader() { + return std::make_unique( + split_, + paimonSplit_, + tableHandle_, + &partitionKeys_, + connectorQueryCtx_, + fileConfig_, + readerOutputType_, + dataIoStats_, + metadataIoStats_, + ioStats_, + fileHandleFactory_, + ioExecutor_, + scanSpec_); +} + +} // namespace facebook::velox::connector::hive::paimon diff --git a/velox/connectors/hive/paimon/PaimonDataSource.h b/velox/connectors/hive/paimon/PaimonDataSource.h new file mode 100644 index 00000000000..bad88951042 --- /dev/null +++ b/velox/connectors/hive/paimon/PaimonDataSource.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/connectors/hive/FileDataSource.h" +#include "velox/connectors/hive/paimon/PaimonConfig.h" + +namespace facebook::velox::connector::hive::paimon { + +class PaimonConnectorSplit; + +/// Paimon-specific data source that extends FileDataSource. +/// +/// Acts as an orchestrator that selects the right reading strategy based on +/// split properties: +/// +/// rawConvertible=true (append-only or fully-compacted primary-key): +/// Creates a PaimonSplitReader that handles multi-file iteration, +/// deletion vectors, and _rowkind internally. FileDataSource::next() +/// drives the read loop — no override needed. +/// +/// rawConvertible=false (primary-key with overlapping keys): +/// Merge path. A PaimonMergeReader opens all files simultaneously and +/// performs a sorted merge by primary key, deduplicating by sequence +/// number. Not yet implemented. +class PaimonDataSource : public FileDataSource { + public: + PaimonDataSource( + const RowTypePtr& outputType, + const ConnectorTableHandlePtr& tableHandle, + const ColumnHandleMap& assignments, + FileHandleFactory* fileHandleFactory, + folly::Executor* ioExecutor, + const ConnectorQueryCtx* connectorQueryCtx, + const std::shared_ptr& paimonConfig); + + void addSplit(std::shared_ptr split) override; + + protected: + /// Creates a PaimonSplitReader with all data files from the Paimon split. + /// Called by FileDataSource::addSplit() during split initialization. + std::unique_ptr createSplitReader() override; + + private: + // The original Paimon split. Stored so createSplitReader() can access + // the full list of data files. + std::shared_ptr paimonSplit_; +}; + +} // namespace facebook::velox::connector::hive::paimon diff --git a/velox/connectors/hive/paimon/PaimonDeletionFile.cpp b/velox/connectors/hive/paimon/PaimonDeletionFile.cpp new file mode 100644 index 00000000000..67ded98d65d --- /dev/null +++ b/velox/connectors/hive/paimon/PaimonDeletionFile.cpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/connectors/hive/paimon/PaimonDeletionFile.h" + +#include "velox/common/base/Exceptions.h" + +namespace facebook::velox::connector::hive::paimon { + +PaimonDeletionFile::PaimonDeletionFile( + std::string _path, + uint64_t _offset, + uint64_t _length, + uint64_t _cardinality) + : path(std::move(_path)), + offset(_offset), + length(_length), + cardinality(_cardinality) { + VELOX_CHECK_GT(length, 0, "PaimonDeletionFile length must be > 0"); + VELOX_CHECK_GT(cardinality, 0, "PaimonDeletionFile cardinality must be > 0"); +} + +std::string PaimonDeletionFile::toString() const { + return fmt::format( + "{{path={}, offset={}, length={}, cardinality={}}}", + path, + offset, + length, + cardinality); +} + +folly::dynamic PaimonDeletionFile::serialize() const { + folly::dynamic obj = folly::dynamic::object; + obj["path"] = path; + obj["offset"] = offset; + obj["length"] = length; + obj["cardinality"] = cardinality; + return obj; +} + +// static +PaimonDeletionFile PaimonDeletionFile::create(const folly::dynamic& obj) { + return PaimonDeletionFile( + obj["path"].asString(), + static_cast(obj["offset"].asInt()), + static_cast(obj["length"].asInt()), + static_cast(obj["cardinality"].asInt())); +} + +} // namespace facebook::velox::connector::hive::paimon diff --git a/velox/connectors/hive/paimon/PaimonDeletionFile.h b/velox/connectors/hive/paimon/PaimonDeletionFile.h new file mode 100644 index 00000000000..671d4e1bd4f --- /dev/null +++ b/velox/connectors/hive/paimon/PaimonDeletionFile.h @@ -0,0 +1,127 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace facebook::velox::connector::hive::paimon { + +/// Paimon deletion file — a serialized Roaring Bitmap tracking deleted row +/// positions within an existing data file. Only used for primary-key tables +/// with deletion-vectors.enabled=true. Append-only tables have no concept of +/// row-level deletes (no primary key to identify specific rows). +/// +/// # How Deletion Vectors Work +/// +/// When a primary-key table has deletion vectors enabled, an UPDATE or DELETE +/// does not rewrite the original data file. Instead: +/// 1. The new/updated record is written to a new L0 data file (as usual). +/// 2. A deletion bitmap is written marking the old row's position in the +/// original data file (0-based row index). +/// 3. The reader loads the bitmap and skips those positions during scanning. +/// +/// This improves read performance: without deletion vectors, the reader must +/// do merge-on-read to discover that a key in an older file has been superseded +/// by a newer file. With deletion vectors, the reader skips deleted rows +/// immediately — no merge needed. A split with deletion vectors can be +/// rawConvertible=true because the bitmaps already tell the reader which rows +/// to skip (old and new values don't overlap from the reader's perspective). +/// +/// Similar to Iceberg's positional delete files. +/// +/// # Manifest Integration +/// +/// When deletion vectors are written, Paimon creates a new snapshot with +/// updated manifest entries for the affected data file: +/// REMOVE: old manifest entry (without or with old deletion file) +/// ADD: same data file, now with deletion file reference attached +/// The data file itself is NOT rewritten — only the manifest entry changes. +/// +/// Each data file has at most ONE deletion file at any point in time. If more +/// rows are deleted later, the bitmap is replaced with a new one containing +/// all deleted positions (old + new, merged). The old bitmap becomes obsolete. +/// +/// Deletion files can be attached to data files at ANY LSM level (L0, L1, +/// L2, etc.), not just L0. +/// +/// # File Format +/// +/// Deletion files are NOT stored in a columnar format (not Parquet/ORC/Nimble) +/// and not in a row-based format (not compact row). They use a custom binary +/// format: +/// - Core payload: standard RoaringBitmap portable binary serialization +/// (cross-language spec from RoaringFormatSpec). C++ can read this with +/// the CRoaring library; Java uses org.roaringbitmap.RoaringBitmap. +/// - Container packing: when multiple bitmaps are packed into a single file, +/// Paimon uses a simple custom binary layout with length-prefixed entries +/// mapping data file names to their bitmap bytes. +/// +/// Each bitmap contains 0-based row positions of deleted rows within its +/// associated data file. +/// +/// Multiple deletion bitmaps can be packed into a single container file. +/// 'offset' and 'length' specify the byte range within the container where +/// this bitmap's data lives. For standalone deletion files, offset is 0 and +/// length equals the file size. +/// +/// # Relationship to RowKind-Based Deletes +/// +/// Deletion files are orthogonal to RowKind-based deletes: +/// +/// RowKind=-D: A full record written to a NEW data file at level 0, +/// containing the key + all column values tagged with -D. Used for logical +/// deletes (CDC source, SQL DELETE). Compaction resolves it by removing the +/// key entirely. +/// +/// Deletion file (this struct): A bitmap referencing row positions in an +/// EXISTING data file. No new data record written. Used as a physical +/// optimization to avoid rewriting files during compaction or partial +/// deletes. +/// +/// For batch reads, the reader loads deletion bitmaps and filters out deleted +/// row positions during scanning. +struct PaimonDeletionFile { + /// @param path Path to the deletion file. + /// @param offset Byte offset within the container file. + /// @param length Number of bytes of bitmap data (must be > 0). + /// @param cardinality Number of deleted rows (must be > 0). + PaimonDeletionFile( + std::string path, + uint64_t offset, + uint64_t length, + uint64_t cardinality); + + std::string path; + + // Byte offset within the container file where this bitmap starts. + uint64_t offset; + + // Number of bytes of bitmap data. Must be > 0. + uint64_t length; + + // Number of deleted rows (pre-computed for stats without reading bitmap). + // Must be > 0. + uint64_t cardinality; + + std::string toString() const; + folly::dynamic serialize() const; + static PaimonDeletionFile create(const folly::dynamic& obj); +}; + +} // namespace facebook::velox::connector::hive::paimon diff --git a/velox/connectors/hive/paimon/PaimonRowKind.cpp b/velox/connectors/hive/paimon/PaimonRowKind.cpp new file mode 100644 index 00000000000..cfa2d156067 --- /dev/null +++ b/velox/connectors/hive/paimon/PaimonRowKind.cpp @@ -0,0 +1,83 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/connectors/hive/paimon/PaimonRowKind.h" + +#include "velox/common/base/Exceptions.h" + +namespace facebook::velox::connector::hive::paimon { + +std::string paimonRowKindString(PaimonRowKind kind) { + switch (kind) { + case PaimonRowKind::kInsert: + return "+I"; + case PaimonRowKind::kUpdateBefore: + return "-U"; + case PaimonRowKind::kUpdateAfter: + return "+U"; + case PaimonRowKind::kDelete: + return "-D"; + default: + VELOX_FAIL("Unknown PaimonRowKind: {}", static_cast(kind)); + } +} + +PaimonRowKind paimonRowKindFromValue(int8_t value) { + switch (value) { + case 0: + return PaimonRowKind::kInsert; + case 1: + return PaimonRowKind::kUpdateBefore; + case 2: + return PaimonRowKind::kUpdateAfter; + case 3: + return PaimonRowKind::kDelete; + default: + VELOX_FAIL("Unknown PaimonRowKind value: {}", value); + } +} + +std::string paimonChangelogModeString(PaimonChangelogMode mode) { + switch (mode) { + case PaimonChangelogMode::kNone: + return "NONE"; + case PaimonChangelogMode::kInput: + return "INPUT"; + case PaimonChangelogMode::kLookup: + return "LOOKUP"; + case PaimonChangelogMode::kFullCompaction: + return "FULL_COMPACTION"; + default: + VELOX_FAIL("Unknown PaimonChangelogMode: {}", static_cast(mode)); + } +} + +PaimonChangelogMode paimonChangelogModeFromString(const std::string& str) { + if (str == "NONE") { + return PaimonChangelogMode::kNone; + } + if (str == "INPUT") { + return PaimonChangelogMode::kInput; + } + if (str == "LOOKUP") { + return PaimonChangelogMode::kLookup; + } + if (str == "FULL_COMPACTION") { + return PaimonChangelogMode::kFullCompaction; + } + VELOX_FAIL("Unknown PaimonChangelogMode: {}", str); +} + +} // namespace facebook::velox::connector::hive::paimon diff --git a/velox/connectors/hive/paimon/PaimonRowKind.h b/velox/connectors/hive/paimon/PaimonRowKind.h new file mode 100644 index 00000000000..4c1546696b0 --- /dev/null +++ b/velox/connectors/hive/paimon/PaimonRowKind.h @@ -0,0 +1,137 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +#include + +namespace facebook::velox::connector::hive::paimon { + +/// Name of the hidden system column that stores the change type per row. +/// Only present in primary-key table files — append-only tables do not +/// have this column (every row is implicitly +I). +/// At the file format level (Parquet, ORC, Nimble), this is a regular +/// TINYINT column — the file format has no special knowledge of it. +/// It is "hidden" only at the SQL layer (not shown in SELECT *). +static constexpr std::string_view kRowKindColumn = "_rowkind"; + +/// Row-level change type stored in the `_rowkind` column. +/// Values match the Paimon Java RowKind enum (0-3). +/// +/// Only meaningful for primary-key tables. Append-only tables have no +/// concept of updates or deletes — every row is implicitly +I. +enum class PaimonRowKind : int8_t { + /// +I: New row inserted. + kInsert = 0, + /// -U: Old value being replaced (retraction). Always followed by +U. + kUpdateBefore = 1, + /// +U: New value replacing the old one. Always preceded by -U. + kUpdateAfter = 2, + /// -D: Row deleted. Carries the full before-image (all column values). + kDelete = 3, +}; + +std::string paimonRowKindString(PaimonRowKind kind); +PaimonRowKind paimonRowKindFromValue(int8_t value); + +FOLLY_ALWAYS_INLINE std::ostream& operator<<( + std::ostream& os, + PaimonRowKind kind) { + os << paimonRowKindString(kind); + return os; +} + +/// Changelog mode determines how a table produces changelog records and +/// whether `_rowkind` is physically stored in data files. This is a +/// table-level property set via the `changelog-producer` option. +/// Only meaningful for primary-key tables — append-only tables have no +/// updates or deletes, so every row is implicitly +I regardless of this +/// setting. +/// +/// # _rowkind Presence in Files +/// +/// Mode | Data files | Changelog files (from compaction) +/// -----------------+-------------+---------------------------------- +/// kNone | No | N/A (no changelog files produced) +/// kInput | Yes | Yes +/// kLookup | No | Yes +/// kFullCompaction | No | Yes +/// +/// For batch reads, `_rowkind` is not needed — the reader returns +/// current state, not changelog. This enum becomes relevant for streaming. +enum class PaimonChangelogMode { + /// Default for primary-key tables. No changelog producer configured. + /// Data files do NOT contain `_rowkind`. Compaction does NOT generate + /// changelog files. Streaming reads emit all delta rows as +I (cannot + /// distinguish INSERT from UPDATE). + kNone, + /// Source provides RowKind explicitly (CDC/input changelog). + /// Data files DO contain `_rowkind` for every row. + kInput, + /// Compactor generates changelog files by looking up old values during + /// every compaction (both partial and full). Data files do NOT contain + /// `_rowkind`, but changelog files do. Provides low-latency changelog + /// for streaming consumers at the cost of higher write amplification. + kLookup, + /// Same mechanism as kLookup (looking up old values to produce changelog), + /// but changelog files are only generated during full compaction — partial + /// compactions do NOT produce changelog. Between full compactions, + /// streaming consumers get degraded output (all rows emitted as +I). + /// Lower write cost than kLookup, but higher changelog latency. + kFullCompaction, +}; + +std::string paimonChangelogModeString(PaimonChangelogMode mode); +PaimonChangelogMode paimonChangelogModeFromString(const std::string& str); + +FOLLY_ALWAYS_INLINE std::ostream& operator<<( + std::ostream& os, + PaimonChangelogMode mode) { + os << paimonChangelogModeString(mode); + return os; +} + +} // namespace facebook::velox::connector::hive::paimon + +template <> +struct fmt::formatter + : formatter { + auto format( + facebook::velox::connector::hive::paimon::PaimonRowKind kind, + format_context& ctx) const { + return formatter::format( + facebook::velox::connector::hive::paimon::paimonRowKindString(kind), + ctx); + } +}; + +template <> +struct fmt::formatter< + facebook::velox::connector::hive::paimon::PaimonChangelogMode> + : formatter { + auto format( + facebook::velox::connector::hive::paimon::PaimonChangelogMode mode, + format_context& ctx) const { + return formatter::format( + facebook::velox::connector::hive::paimon::paimonChangelogModeString( + mode), + ctx); + } +}; diff --git a/velox/connectors/hive/paimon/PaimonSplitReader.cpp b/velox/connectors/hive/paimon/PaimonSplitReader.cpp new file mode 100644 index 00000000000..446e47897dc --- /dev/null +++ b/velox/connectors/hive/paimon/PaimonSplitReader.cpp @@ -0,0 +1,146 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/connectors/hive/paimon/PaimonSplitReader.h" + +#include + +namespace facebook::velox::connector::hive::paimon { + +PaimonSplitReader::PaimonSplitReader( + const std::shared_ptr& fileSplit, + std::shared_ptr paimonSplit, + const FileTableHandlePtr& tableHandle, + const std::unordered_map* partitionKeys, + const ConnectorQueryCtx* connectorQueryCtx, + const std::shared_ptr& fileConfig, + const RowTypePtr& readerOutputType, + const std::shared_ptr& dataIoStats, + const std::shared_ptr& metadataIoStats, + const std::shared_ptr& ioStats, + FileHandleFactory* fileHandleFactory, + folly::Executor* executor, + const std::shared_ptr& scanSpec) + : FileSplitReader( + fileSplit, + tableHandle, + partitionKeys, + connectorQueryCtx, + fileConfig, + readerOutputType, + dataIoStats, + metadataIoStats, + ioStats, + fileHandleFactory, + executor, + scanSpec), + paimonSplit_(std::move(paimonSplit)) { + // Validate all data files upfront. + for (const auto& dataFile : paimonSplit_->dataFiles()) { + VELOX_CHECK( + !dataFile.deletionFile.has_value(), + "Paimon deletion vector reading is not yet implemented. " + "File '{}' has a deletion vector with {} deleted rows.", + dataFile.path, + dataFile.deletionFile.has_value() ? dataFile.deletionFile->cardinality + : 0); + VELOX_CHECK_EQ( + dataFile.type, + PaimonDataFile::Type::kData, + "Paimon changelog file reading is not yet supported. " + "File '{}' has type {}.", + dataFile.path, + dataFile.type); + } + + // Build FileConnectorSplits for all data files so we can switch between + // them during multi-file iteration. + const auto& dataFiles = paimonSplit_->dataFiles(); + fileSplits_.reserve(dataFiles.size()); + for (const auto& dataFile : dataFiles) { + fileSplits_.emplace_back(makeFileConnectorSplit(dataFile)); + } +} + +void PaimonSplitReader::prepareSplit( + std::shared_ptr metadataFilter, + dwio::common::RuntimeStatistics& runtimeStats, + const folly::F14FastMap& /*fileReadOps*/) { + // Save for re-use when advancing to subsequent files. + metadataFilter_ = std::move(metadataFilter); + runtimeStats_ = &runtimeStats; +} + +uint64_t PaimonSplitReader::next(uint64_t size, VectorPtr& output) { + while (ensureFileSplitReader()) { + // When deletion vectors are implemented, apply the deletion bitmap + // here via Mutation.deletedRows, following the same pattern as + // IcebergSplitReader::next(). + const auto rowsRead = FileSplitReader::next(size, output); + if (rowsRead > 0) { + return rowsRead; + } + // Current file exhausted, try the next one. + finishFileSplitReader(); + } + return 0; +} + +bool PaimonSplitReader::ensureFileSplitReader() { + if (baseReader_ != nullptr) { + VELOX_CHECK_NOT_NULL(baseRowReader_); + return true; + } + VELOX_CHECK_NULL(baseRowReader_); + + while (currentFileIndex_ < fileSplits_.size()) { + // Point to the current file. + fileSplit_ = fileSplits_[currentFileIndex_]; + + // Initialize the base reader for this file. + FileSplitReader::prepareSplit(metadataFilter_, *runtimeStats_); + if (!emptySplit_) { + return true; + } + ++currentFileIndex_; + } + return false; // All files exhausted. +} + +void PaimonSplitReader::finishFileSplitReader() { + VELOX_CHECK_LT(currentFileIndex_, fileSplits_.size()); + VELOX_CHECK_NOT_NULL(baseReader_); + VELOX_CHECK_NOT_NULL(baseRowReader_); + ++currentFileIndex_; + baseReader_ = nullptr; + baseRowReader_ = nullptr; + emptySplit_ = false; +} + +std::shared_ptr PaimonSplitReader::makeFileConnectorSplit( + const PaimonDataFile& dataFile) const { + return std::make_shared( + paimonSplit_->connectorId, + dataFile.path, + paimonSplit_->fileFormat(), + /*_start=*/0, + /*_length=*/std::numeric_limits::max(), + /*splitWeight=*/0, + /*cacheable=*/true, + /*_properties=*/std::nullopt, + paimonSplit_->partitionKeys()); +} + +} // namespace facebook::velox::connector::hive::paimon diff --git a/velox/connectors/hive/paimon/PaimonSplitReader.h b/velox/connectors/hive/paimon/PaimonSplitReader.h new file mode 100644 index 00000000000..c21e660ee2b --- /dev/null +++ b/velox/connectors/hive/paimon/PaimonSplitReader.h @@ -0,0 +1,105 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/connectors/hive/FileSplitReader.h" +#include "velox/connectors/hive/paimon/PaimonConnectorSplit.h" + +namespace facebook::velox::connector::hive::paimon { + +/// Reader for rawConvertible Paimon splits. Extends FileSplitReader to handle +/// Paimon-specific concerns: +/// +/// - Multi-file iteration: A single Paimon split may contain multiple data +/// files (one per LSM level in a bucket). This reader iterates through +/// all files sequentially, transparently advancing to the next file when +/// the current one is exhausted. +/// +/// - Deletion vectors: Roaring bitmaps marking deleted row positions within +/// a data file. Loaded in prepareSplit(), applied in next() via +/// Mutation.deletedRows (same pattern as IcebergSplitReader). (NYI) +/// +/// - _rowkind system column: Change type (+I/-U/+U/-D) for changelog files +/// and input changelog mode. Handled in adaptColumns(). (NYI) +/// +/// Also serves as the building block for PaimonMergeReader, which composes +/// multiple PaimonSplitReaders to perform merge-on-read for primary-key +/// tables with rawConvertible=false. +class PaimonSplitReader : public FileSplitReader { + public: + /// @param fileSplit FileConnectorSplit for the first data file (set by + /// FileDataSource::addSplit before calling createSplitReader). + /// @param paimonSplit The full Paimon split. PaimonSplitReader builds + /// FileConnectorSplits for all data files internally. + PaimonSplitReader( + const std::shared_ptr& fileSplit, + std::shared_ptr paimonSplit, + const FileTableHandlePtr& tableHandle, + const std::unordered_map* partitionKeys, + const ConnectorQueryCtx* connectorQueryCtx, + const std::shared_ptr& fileConfig, + const RowTypePtr& readerOutputType, + const std::shared_ptr& dataIoStats, + const std::shared_ptr& metadataIoStats, + const std::shared_ptr& ioStats, + FileHandleFactory* fileHandleFactory, + folly::Executor* executor, + const std::shared_ptr& scanSpec); + + ~PaimonSplitReader() override = default; + + /// Saves metadata filter and runtime stats for use by + /// ensureFileSplitReader(). File validation is done in the constructor. + void prepareSplit( + std::shared_ptr metadataFilter, + dwio::common::RuntimeStatistics& runtimeStats, + const folly::F14FastMap& fileReadOps = {}) + override; + + /// Reads from the current file. When a file is exhausted, transparently + /// advances to the next file in the split. Returns 0 only when all files + /// are exhausted. + uint64_t next(uint64_t size, VectorPtr& output) override; + + private: + // Builds a FileConnectorSplit from a PaimonDataFile. + std::shared_ptr makeFileConnectorSplit( + const PaimonDataFile& dataFile) const; + + // Ensures a file split reader is ready to produce rows. If the current + // reader is exhausted, advances to the next non-empty file. Returns true + // if a reader is ready, false when all files are exhausted. + bool ensureFileSplitReader(); + + // Cleans up the current file reader after it is exhausted and advances + // currentFileIndex_ so ensureFileSplitReader() opens the next file. + void finishFileSplitReader(); + + const std::shared_ptr paimonSplit_; + + // FileConnectorSplits built from paimonSplit_->dataFiles(), one per file. + // Used to switch the base reader between files during iteration. + std::vector> fileSplits_; + + // Index of the next data file to open. + size_t currentFileIndex_{0}; + + // Saved from prepareSplit() for re-use when advancing to subsequent files. + std::shared_ptr metadataFilter_; + dwio::common::RuntimeStatistics* runtimeStats_{nullptr}; +}; + +} // namespace facebook::velox::connector::hive::paimon diff --git a/velox/connectors/hive/paimon/tests/CMakeLists.txt b/velox/connectors/hive/paimon/tests/CMakeLists.txt new file mode 100644 index 00000000000..f8265678b67 --- /dev/null +++ b/velox/connectors/hive/paimon/tests/CMakeLists.txt @@ -0,0 +1,60 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if(NOT VELOX_DISABLE_GOOGLETEST) + add_executable(velox_hive_paimon_split_test PaimonConnectorSplitTest.cpp) + add_test(velox_hive_paimon_split_test velox_hive_paimon_split_test) + + target_link_libraries( + velox_hive_paimon_split_test + velox_hive_paimon_split + velox_dwio_common + GTest::gtest + GTest::gtest_main + GTest::gmock + ) + + add_executable(velox_hive_paimon_data_file_meta_test PaimonDataFileMetaTest.cpp) + add_test(velox_hive_paimon_data_file_meta_test velox_hive_paimon_data_file_meta_test) + + target_link_libraries( + velox_hive_paimon_data_file_meta_test + velox_hive_paimon_split + GTest::gtest + GTest::gtest_main + GTest::gmock + ) + + add_executable(velox_hive_paimon_deletion_file_test PaimonDeletionFileTest.cpp) + add_test(velox_hive_paimon_deletion_file_test velox_hive_paimon_deletion_file_test) + + target_link_libraries( + velox_hive_paimon_deletion_file_test + velox_hive_paimon_split + GTest::gtest + GTest::gtest_main + GTest::gmock + ) + + add_executable(velox_hive_paimon_row_kind_test PaimonRowKindTest.cpp) + add_test(velox_hive_paimon_row_kind_test velox_hive_paimon_row_kind_test) + + target_link_libraries( + velox_hive_paimon_row_kind_test + velox_hive_paimon_split + GTest::gtest + GTest::gtest_main + fmt::fmt + ) +endif() diff --git a/velox/connectors/hive/paimon/tests/PaimonConnectorSplitTest.cpp b/velox/connectors/hive/paimon/tests/PaimonConnectorSplitTest.cpp new file mode 100644 index 00000000000..0ebb2680fd7 --- /dev/null +++ b/velox/connectors/hive/paimon/tests/PaimonConnectorSplitTest.cpp @@ -0,0 +1,373 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/connectors/hive/paimon/PaimonConnectorSplit.h" + +#include +#include +#include +#include + +#include "velox/common/base/tests/GTestUtils.h" + +using namespace facebook::velox::connector::hive::paimon; +using namespace facebook::velox; +using namespace facebook::velox::dwio::common; + +namespace { +using PartitionKeys = + std::unordered_map>; +} // namespace + +class PaimonConnectorSplitTest : public testing::Test { + protected: + const std::string kConnectorId{"test-connector"}; +}; + +TEST_F(PaimonConnectorSplitTest, basic) { + PaimonDataFile file0; + file0.path = "s3://bucket/table/dt=2024-01-01/bucket-0/data-001.orc"; + file0.size = 1024; + file0.rowCount = 100; + file0.level = 0; + + PaimonDataFile file1; + file1.path = "s3://bucket/table/dt=2024-01-01/bucket-0/data-002.orc"; + file1.size = 2048; + file1.rowCount = 200; + file1.level = 1; + + std::vector files{file0, file1}; + PartitionKeys partitionKeys{{"dt", "2024-01-01"}}; + + auto split = std::make_shared( + kConnectorId, + /*snapshotId=*/1, + PaimonTableType::kPrimaryKey, + FileFormat::DWRF, + files, + partitionKeys, + 0); + + EXPECT_EQ(split->snapshotId(), 1); + EXPECT_EQ(split->tableType(), PaimonTableType::kPrimaryKey); + EXPECT_EQ(split->fileFormat(), FileFormat::DWRF); + EXPECT_TRUE(split->rawConvertible()); + EXPECT_EQ(split->tableBucketNumber(), 0); + ASSERT_EQ(split->dataFiles().size(), 2); + EXPECT_EQ(split->dataFiles()[0].path, files[0].path); + EXPECT_EQ(split->dataFiles()[0].size, 1024); + EXPECT_EQ(split->dataFiles()[1].path, files[1].path); + EXPECT_EQ(split->dataFiles()[1].size, 2048); + ASSERT_EQ(split->partitionKeys().count("dt"), 1); + EXPECT_EQ(split->partitionKeys().at("dt"), "2024-01-01"); +} + +TEST_F(PaimonConnectorSplitTest, nimbleFormat) { + PaimonDataFile file0; + file0.path = "s3://bucket/table/data-001.nimble"; + file0.size = 4096; + file0.rowCount = 100; + std::vector files{file0}; + + auto split = std::make_shared( + kConnectorId, + /*snapshotId=*/1, + PaimonTableType::kAppendOnly, + FileFormat::NIMBLE, + files, + PartitionKeys{}, + std::nullopt); + + ASSERT_EQ(split->dataFiles().size(), 1); + EXPECT_EQ(split->tableType(), PaimonTableType::kAppendOnly); + EXPECT_EQ(split->fileFormat(), FileFormat::NIMBLE); +} + +TEST_F(PaimonConnectorSplitTest, parquetFormat) { + PaimonDataFile file0; + file0.path = "s3://bucket/table/data-001.parquet"; + file0.size = 4096; + file0.rowCount = 100; + std::vector files{file0}; + + auto split = std::make_shared( + kConnectorId, + /*snapshotId=*/1, + PaimonTableType::kAppendOnly, + FileFormat::PARQUET, + files, + PartitionKeys{}, + std::nullopt); + + ASSERT_EQ(split->dataFiles().size(), 1); + EXPECT_EQ(split->fileFormat(), FileFormat::PARQUET); +} + +TEST_F(PaimonConnectorSplitTest, emptyFilesThrows) { + std::vector noFiles; + VELOX_ASSERT_THROW( + std::make_shared( + kConnectorId, + /*snapshotId=*/1, + PaimonTableType::kAppendOnly, + FileFormat::NIMBLE, + noFiles, + PartitionKeys{}, + std::nullopt), + "PaimonConnectorSplit requires non-empty dataFiles"); +} + +TEST_F(PaimonConnectorSplitTest, rawConvertibleWithDeleteRowCountThrows) { + PaimonDataFile file; + file.path = "data-001.orc"; + file.size = 1024; + file.rowCount = 100; + file.level = 1; + file.deleteRowCount = 5; + + std::vector files{file}; + VELOX_ASSERT_THROW( + std::make_shared( + kConnectorId, + /*snapshotId=*/1, + PaimonTableType::kPrimaryKey, + FileFormat::DWRF, + files, + PartitionKeys{}, + std::nullopt, + /*rawConvertible=*/true), + "rawConvertible split cannot have files with deleteRowCount > 0"); +} + +TEST_F(PaimonConnectorSplitTest, notRawConvertibleWithDeleteRowCount) { + PaimonDataFile file; + file.path = "data-001.orc"; + file.size = 1024; + file.rowCount = 100; + file.level = 1; + file.deleteRowCount = 5; + + std::vector files{file}; + auto split = std::make_shared( + kConnectorId, + /*snapshotId=*/1, + PaimonTableType::kPrimaryKey, + FileFormat::DWRF, + files, + PartitionKeys{}, + std::nullopt, + /*rawConvertible=*/false); + + EXPECT_FALSE(split->rawConvertible()); + EXPECT_EQ(split->dataFiles()[0].deleteRowCount, 5); +} + +TEST_F(PaimonConnectorSplitTest, toStringOutput) { + PaimonDataFile f0; + f0.path = "data-001.orc"; + f0.size = 1024; + f0.rowCount = 100; + + PaimonDataFile f1; + f1.path = "data-002.orc"; + f1.size = 2048; + f1.rowCount = 200; + f1.level = 1; + + PaimonDataFile f2; + f2.path = "data-003.orc"; + f2.size = 512; + f2.rowCount = 50; + f2.level = 2; + + std::vector files{f0, f1, f2}; + + auto split = std::make_shared( + kConnectorId, + /*snapshotId=*/42, + PaimonTableType::kPrimaryKey, + FileFormat::DWRF, + files, + PartitionKeys{}, + std::nullopt); + + EXPECT_THAT(split->toString(), testing::HasSubstr("snapshot 42")); + EXPECT_THAT(split->toString(), testing::HasSubstr("PRIMARY_KEY")); + EXPECT_THAT(split->toString(), testing::HasSubstr(kConnectorId)); + EXPECT_THAT(split->toString(), testing::HasSubstr("data-001.orc")); + EXPECT_THAT(split->toString(), testing::HasSubstr("data-003.orc")); + EXPECT_THAT(split->toString(), testing::HasSubstr("size=1024")); + EXPECT_THAT(split->toString(), testing::HasSubstr("rows=100")); + EXPECT_THAT(split->toString(), testing::HasSubstr("level=1")); + EXPECT_THAT(split->toString(), testing::HasSubstr("deletionFile=none")); +} + +TEST_F(PaimonConnectorSplitTest, nullPartitionValue) { + PaimonDataFile file0; + file0.path = "data-001.orc"; + file0.size = 1024; + file0.rowCount = 100; + std::vector files{file0}; + PartitionKeys partitionKeys{{"country", std::nullopt}}; + + auto split = std::make_shared( + kConnectorId, + /*snapshotId=*/1, + PaimonTableType::kAppendOnly, + FileFormat::DWRF, + files, + partitionKeys, + std::nullopt); + + ASSERT_EQ(split->partitionKeys().count("country"), 1); + EXPECT_EQ(split->partitionKeys().at("country"), std::nullopt); +} + +TEST_F(PaimonConnectorSplitTest, builder) { + auto split = PaimonConnectorSplitBuilder( + kConnectorId, + /*snapshotId=*/5, + PaimonTableType::kPrimaryKey, + FileFormat::PARQUET) + .addFile("s3://bucket/data-001.parquet", 1024, 0) + .addFile("s3://bucket/data-002.parquet", 2048, 1) + .partitionKey("dt", "2024-01-01") + .partitionKey("country", std::nullopt) + .tableBucketNumber(3) + .build(); + + EXPECT_EQ(split->snapshotId(), 5); + EXPECT_EQ(split->fileFormat(), FileFormat::PARQUET); + EXPECT_TRUE(split->rawConvertible()); + + ASSERT_EQ(split->dataFiles().size(), 2); + EXPECT_EQ(split->dataFiles()[0].path, "s3://bucket/data-001.parquet"); + EXPECT_EQ(split->dataFiles()[0].size, 1024); + EXPECT_EQ(split->dataFiles()[1].path, "s3://bucket/data-002.parquet"); + EXPECT_EQ(split->dataFiles()[1].size, 2048); + EXPECT_EQ(split->tableBucketNumber(), 3); + EXPECT_EQ(split->partitionKeys().at("dt"), "2024-01-01"); + EXPECT_EQ(split->partitionKeys().at("country"), std::nullopt); +} + +TEST_F(PaimonConnectorSplitTest, builderDefaults) { + auto split = PaimonConnectorSplitBuilder( + kConnectorId, + /*snapshotId=*/1, + PaimonTableType::kAppendOnly, + FileFormat::NIMBLE) + .addFile("data.nimble", 512) + .build(); + + ASSERT_EQ(split->dataFiles().size(), 1); + EXPECT_EQ(split->fileFormat(), FileFormat::NIMBLE); + EXPECT_EQ(split->tableBucketNumber(), std::nullopt); + EXPECT_TRUE(split->partitionKeys().empty()); +} + +TEST_F(PaimonConnectorSplitTest, serializeRoundTrip) { + PaimonDataFile meta; + meta.path = "s3://bucket/data-001.parquet"; + meta.size = 4096; + meta.rowCount = 1000; + meta.level = 2; + meta.minSequenceNumber = 10; + meta.maxSequenceNumber = 20; + meta.deleteRowCount = 5; + meta.creationTimeMs = 1700000000; + meta.source = PaimonDataFile::Source::kCompact; + meta.deletionFile = + PaimonDeletionFile{"s3://bucket/deletion-001.bin", 0, 128, 5}; + + std::vector files{meta}; + PartitionKeys partitionKeys{{"dt", "2024-01-01"}, {"country", std::nullopt}}; + + auto original = std::make_shared( + kConnectorId, + /*snapshotId=*/42, + PaimonTableType::kPrimaryKey, + FileFormat::PARQUET, + files, + partitionKeys, + /*tableBucketNumber=*/7, + /*rawConvertible=*/false); + + auto serialized = original->serialize(); + auto deserialized = PaimonConnectorSplit::create(serialized); + + EXPECT_EQ(deserialized->connectorId, kConnectorId); + EXPECT_EQ(deserialized->snapshotId(), 42); + EXPECT_EQ(deserialized->tableType(), PaimonTableType::kPrimaryKey); + EXPECT_EQ(deserialized->fileFormat(), FileFormat::PARQUET); + EXPECT_FALSE(deserialized->rawConvertible()); + EXPECT_EQ(deserialized->tableBucketNumber(), 7); + + ASSERT_EQ(deserialized->dataFiles().size(), 1); + const auto& file = deserialized->dataFiles()[0]; + EXPECT_EQ(file.path, "s3://bucket/data-001.parquet"); + EXPECT_EQ(file.size, 4096); + EXPECT_EQ(file.rowCount, 1000); + EXPECT_EQ(file.level, 2); + EXPECT_EQ(file.minSequenceNumber, 10); + EXPECT_EQ(file.maxSequenceNumber, 20); + EXPECT_EQ(file.deleteRowCount, 5); + EXPECT_EQ(file.creationTimeMs, 1700000000); + EXPECT_EQ(file.source, PaimonDataFile::Source::kCompact); + + ASSERT_TRUE(file.deletionFile.has_value()); + EXPECT_EQ(file.deletionFile->path, "s3://bucket/deletion-001.bin"); + EXPECT_EQ(file.deletionFile->offset, 0); + EXPECT_EQ(file.deletionFile->length, 128); + EXPECT_EQ(file.deletionFile->cardinality, 5); + + ASSERT_EQ(deserialized->partitionKeys().count("dt"), 1); + EXPECT_EQ(deserialized->partitionKeys().at("dt"), "2024-01-01"); + ASSERT_EQ(deserialized->partitionKeys().count("country"), 1); + EXPECT_EQ(deserialized->partitionKeys().at("country"), std::nullopt); + + ASSERT_EQ(deserialized->dataFiles().size(), 1); + EXPECT_EQ(deserialized->dataFiles()[0].path, "s3://bucket/data-001.parquet"); + EXPECT_EQ(deserialized->dataFiles()[0].size, 4096); +} + +TEST_F(PaimonConnectorSplitTest, paimonTableTypeStringAndParse) { + EXPECT_EQ(paimonTableTypeString(PaimonTableType::kAppendOnly), "APPEND_ONLY"); + EXPECT_EQ(paimonTableTypeString(PaimonTableType::kPrimaryKey), "PRIMARY_KEY"); + + EXPECT_EQ( + paimonTableTypeFromString("APPEND_ONLY"), PaimonTableType::kAppendOnly); + EXPECT_EQ( + paimonTableTypeFromString("PRIMARY_KEY"), PaimonTableType::kPrimaryKey); + + VELOX_ASSERT_THROW( + paimonTableTypeFromString("UNKNOWN"), "Unknown PaimonTableType: UNKNOWN"); +} + +TEST_F(PaimonConnectorSplitTest, tableTypeStreamAndFormat) { + { + std::ostringstream os; + os << PaimonTableType::kAppendOnly; + EXPECT_EQ(os.str(), "APPEND_ONLY"); + } + { + std::ostringstream os; + os << PaimonTableType::kPrimaryKey; + EXPECT_EQ(os.str(), "PRIMARY_KEY"); + } + + EXPECT_EQ(fmt::format("{}", PaimonTableType::kAppendOnly), "APPEND_ONLY"); + EXPECT_EQ(fmt::format("{}", PaimonTableType::kPrimaryKey), "PRIMARY_KEY"); +} diff --git a/velox/connectors/hive/paimon/tests/PaimonConnectorTest.cpp b/velox/connectors/hive/paimon/tests/PaimonConnectorTest.cpp new file mode 100644 index 00000000000..4bfb10dc774 --- /dev/null +++ b/velox/connectors/hive/paimon/tests/PaimonConnectorTest.cpp @@ -0,0 +1,329 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/connectors/hive/paimon/PaimonConnector.h" + +#include + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/connectors/ConnectorRegistry.h" +#include "velox/connectors/hive/HiveConfig.h" +#include "velox/connectors/hive/paimon/PaimonConnectorSplit.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" + +namespace facebook::velox::connector::hive::paimon { +namespace { + +static const std::string kPaimonConnectorId = "test-paimon"; + +class PaimonConnectorTest : public exec::test::HiveConnectorTestBase { + protected: + void SetUp() override { + HiveConnectorTestBase::SetUp(); + auto config = std::make_shared( + std::unordered_map{}); + auto connector = + PaimonConnectorFactory().newConnector(kPaimonConnectorId, config); + ConnectorRegistry::global().insert(connector->connectorId(), connector); + } + + void TearDown() override { + ConnectorRegistry::global().erase(kPaimonConnectorId); + HiveConnectorTestBase::TearDown(); + } + + /// Creates a table handle for the Paimon connector. + static std::shared_ptr makePaimonTableHandle( + const RowTypePtr& dataColumns = nullptr, + common::SubfieldFilters subfieldFilters = {}, + const core::TypedExprPtr& remainingFilter = nullptr) { + return std::make_shared( + kPaimonConnectorId, + "paimon_table", + std::move(subfieldFilters), + remainingFilter, + dataColumns); + } + + /// Creates column assignments for all columns as regular columns. + static connector::ColumnHandleMap makePaimonColumnHandles( + const RowTypePtr& rowType) { + connector::ColumnHandleMap assignments; + assignments.reserve(rowType->size()); + for (uint32_t i = 0; i < rowType->size(); ++i) { + const auto& name = rowType->nameOf(i); + assignments[name] = std::make_shared( + name, + HiveColumnHandle::ColumnType::kRegular, + rowType->childAt(i), + rowType->childAt(i)); + } + return assignments; + } + + /// Builds a table scan plan using the Paimon connector. + core::PlanNodePtr makePaimonScanPlan( + const RowTypePtr& outputType, + const RowTypePtr& dataColumns = nullptr, + common::SubfieldFilters subfieldFilters = {}, + const core::TypedExprPtr& remainingFilter = nullptr) { + auto tableHandle = makePaimonTableHandle( + dataColumns ? dataColumns : outputType, + std::move(subfieldFilters), + remainingFilter); + auto assignments = + makePaimonColumnHandles(dataColumns ? dataColumns : outputType); + return exec::test::PlanBuilder() + .startTableScan() + .connectorId(kPaimonConnectorId) + .outputType(outputType) + .tableHandle(tableHandle) + .assignments(assignments) + .endTableScan() + .planNode(); + } + + /// Creates a PaimonConnectorSplit from file paths. + std::shared_ptr makePaimonSplit( + const std::vector& filePaths, + PaimonTableType tableType = PaimonTableType::kAppendOnly, + dwio::common::FileFormat format = dwio::common::FileFormat::DWRF, + const std::unordered_map>& + partitionKeys = {}, + bool rawConvertible = true) { + PaimonConnectorSplitBuilder builder( + kPaimonConnectorId, /*snapshotId=*/1, tableType, format); + for (const auto& filePath : filePaths) { + builder.addFile(filePath, /*fileSize=*/0); + } + for (const auto& [key, value] : partitionKeys) { + builder.partitionKey(key, value); + } + builder.rawConvertible(rawConvertible); + return builder.build(); + } +}; + +TEST_F(PaimonConnectorTest, connectorRegistration) { + auto connector = ConnectorRegistry::tryGet(kPaimonConnectorId); + ASSERT_NE(connector, nullptr); + ASSERT_NE(connector->connectorConfig(), nullptr); +} + +TEST_F(PaimonConnectorTest, connectorFactory) { + PaimonConnectorFactory factory; + EXPECT_EQ( + std::string(PaimonConnectorFactory::kPaimonConnectorName), "paimon"); + + auto config = std::make_shared( + std::unordered_map{ + {HiveConfig::kEnableFileHandleCache, "true"}, + {HiveConfig::kNumCacheFileHandles, "500"}}); + + auto connector = factory.newConnector("test-paimon-2", config); + ASSERT_NE(connector, nullptr); + + HiveConfig hiveConfig(connector->connectorConfig()); + EXPECT_TRUE(hiveConfig.isFileHandleCacheEnabled()); + EXPECT_EQ(hiveConfig.numCacheFileHandles(), 500); +} + +// E2E test: read a single file from an append-only Paimon split. +TEST_F(PaimonConnectorTest, appendOnlySingleFile) { + auto rowType = ROW({"c0", "c1"}, {BIGINT(), VARCHAR()}); + auto vectors = makeVectors(rowType, 1, 100); + + auto filePaths = makeFilePaths(1); + writeToFile(filePaths[0]->getPath(), vectors); + + auto split = makePaimonSplit({filePaths[0]->getPath()}); + auto plan = makePaimonScanPlan(rowType); + + exec::test::AssertQueryBuilder(plan).split(split).assertResults(vectors); +} + +// E2E test: read multiple files from a single append-only Paimon split. +TEST_F(PaimonConnectorTest, appendOnlyMultipleFiles) { + auto rowType = ROW({"c0", "c1"}, {BIGINT(), VARCHAR()}); + auto vectors1 = makeVectors(rowType, 1, 100); + auto vectors2 = makeVectors(rowType, 1, 50); + + auto filePaths = makeFilePaths(2); + writeToFile(filePaths[0]->getPath(), vectors1); + writeToFile(filePaths[1]->getPath(), vectors2); + + auto split = + makePaimonSplit({filePaths[0]->getPath(), filePaths[1]->getPath()}); + auto plan = makePaimonScanPlan(rowType); + + // Expected result is all rows from both files. + std::vector expected; + expected.insert(expected.end(), vectors1.begin(), vectors1.end()); + expected.insert(expected.end(), vectors2.begin(), vectors2.end()); + + exec::test::AssertQueryBuilder(plan).split(split).assertResults(expected); +} + +// Verify that empty splits (no data files) are rejected by +// PaimonConnectorSplit. +TEST_F(PaimonConnectorTest, rejectsEmptySplit) { + VELOX_ASSERT_THROW( + makePaimonSplit({}), "PaimonConnectorSplit requires non-empty dataFiles"); +} + +// E2E test: read with partition keys from an append-only Paimon split. +TEST_F(PaimonConnectorTest, appendOnlyWithPartitionKeys) { + // Data file columns (not including partition column). + auto dataRowType = ROW({"c0"}, {BIGINT()}); + auto vectors = makeVectors(dataRowType, 1, 100); + + auto filePaths = makeFilePaths(1); + writeToFile(filePaths[0]->getPath(), vectors); + + // The output includes both the data column and the partition column. + auto outputType = ROW({"c0", "p0"}, {BIGINT(), VARCHAR()}); + auto tableHandle = makePaimonTableHandle(outputType); + + connector::ColumnHandleMap assignments; + assignments["c0"] = std::make_shared( + "c0", HiveColumnHandle::ColumnType::kRegular, BIGINT(), BIGINT()); + assignments["p0"] = std::make_shared( + "p0", HiveColumnHandle::ColumnType::kPartitionKey, VARCHAR(), VARCHAR()); + + auto plan = exec::test::PlanBuilder() + .startTableScan() + .connectorId(kPaimonConnectorId) + .outputType(outputType) + .tableHandle(tableHandle) + .assignments(assignments) + .endTableScan() + .planNode(); + + auto split = makePaimonSplit( + {filePaths[0]->getPath()}, + PaimonTableType::kAppendOnly, + dwio::common::FileFormat::DWRF, + {{"p0", std::optional("2024-01-01")}}); + + // Build expected output with the partition value filled in. + auto expectedC0 = vectors[0]->childAt(0); + auto numRows = vectors[0]->size(); + auto expectedP0 = + BaseVector::createConstant(VARCHAR(), "2024-01-01", numRows, pool()); + + auto expected = makeRowVector({"c0", "p0"}, {expectedC0, expectedP0}); + + exec::test::AssertQueryBuilder(plan).split(split).assertResults({expected}); +} + +// Verify that non-rawConvertible splits trigger NYI (merge-on-read). +TEST_F(PaimonConnectorTest, rejectsNonRawConvertible) { + auto rowType = ROW({"c0"}, {BIGINT()}); + auto vectors = makeVectors(rowType, 1, 10); + + auto filePaths = makeFilePaths(1); + writeToFile(filePaths[0]->getPath(), vectors); + + auto split = makePaimonSplit( + {filePaths[0]->getPath()}, + PaimonTableType::kAppendOnly, + dwio::common::FileFormat::DWRF, + /*partitionKeys=*/{}, + /*rawConvertible=*/false); + + auto plan = makePaimonScanPlan(rowType); + + VELOX_ASSERT_THROW( + exec::test::AssertQueryBuilder(plan).split(split).copyResults(pool()), + "Paimon merge-on-read is not yet implemented"); +} + +// E2E test: primary-key table with rawConvertible=true reads successfully. +// When fully compacted (rawConvertible), primary-key tables can be read +// the same way as append-only — each file is independent, no merge needed. +TEST_F(PaimonConnectorTest, primaryKeyRawConvertible) { + auto rowType = ROW({"c0"}, {BIGINT()}); + auto vectors = makeVectors(rowType, 1, 10); + + auto filePaths = makeFilePaths(1); + writeToFile(filePaths[0]->getPath(), vectors); + + auto split = makePaimonSplit( + {filePaths[0]->getPath()}, + PaimonTableType::kPrimaryKey, + dwio::common::FileFormat::DWRF); + + auto plan = makePaimonScanPlan(rowType); + + exec::test::AssertQueryBuilder(plan).split(split).assertResults(vectors); +} + +// Verify that primary-key table with rawConvertible=false triggers NYI. +TEST_F(PaimonConnectorTest, rejectsPrimaryKeyNonRawConvertible) { + auto rowType = ROW({"c0"}, {BIGINT()}); + auto vectors = makeVectors(rowType, 1, 10); + + auto filePaths = makeFilePaths(1); + writeToFile(filePaths[0]->getPath(), vectors); + + auto split = makePaimonSplit( + {filePaths[0]->getPath()}, + PaimonTableType::kPrimaryKey, + dwio::common::FileFormat::DWRF, + /*partitionKeys=*/{}, + /*rawConvertible=*/false); + + auto plan = makePaimonScanPlan(rowType); + + VELOX_ASSERT_THROW( + exec::test::AssertQueryBuilder(plan).split(split).copyResults(pool()), + "Paimon merge-on-read is not yet implemented"); +} + +// E2E test: multiple splits each with multiple files. +TEST_F(PaimonConnectorTest, appendOnlyMultipleSplits) { + auto rowType = ROW({"c0", "c1"}, {BIGINT(), VARCHAR()}); + + // Split 1 with 2 files. + auto vectors1a = makeVectors(rowType, 1, 50); + auto vectors1b = makeVectors(rowType, 1, 30); + // Split 2 with 1 file. + auto vectors2a = makeVectors(rowType, 1, 40); + + auto filePaths = makeFilePaths(3); + writeToFile(filePaths[0]->getPath(), vectors1a); + writeToFile(filePaths[1]->getPath(), vectors1b); + writeToFile(filePaths[2]->getPath(), vectors2a); + + auto split1 = + makePaimonSplit({filePaths[0]->getPath(), filePaths[1]->getPath()}); + auto split2 = makePaimonSplit({filePaths[2]->getPath()}); + + auto plan = makePaimonScanPlan(rowType); + + std::vector expected; + expected.insert(expected.end(), vectors1a.begin(), vectors1a.end()); + expected.insert(expected.end(), vectors1b.begin(), vectors1b.end()); + expected.insert(expected.end(), vectors2a.begin(), vectors2a.end()); + + exec::test::AssertQueryBuilder(plan) + .splits({split1, split2}) + .assertResults(expected); +} + +} // namespace +} // namespace facebook::velox::connector::hive::paimon diff --git a/velox/connectors/hive/paimon/tests/PaimonDataFileMetaTest.cpp b/velox/connectors/hive/paimon/tests/PaimonDataFileMetaTest.cpp new file mode 100644 index 00000000000..e728191992b --- /dev/null +++ b/velox/connectors/hive/paimon/tests/PaimonDataFileMetaTest.cpp @@ -0,0 +1,211 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/connectors/hive/paimon/PaimonDataFileMeta.h" + +#include +#include +#include +#include + +#include "velox/common/base/Exceptions.h" +#include "velox/common/base/tests/GTestUtils.h" + +using namespace facebook::velox::connector::hive::paimon; +using namespace facebook::velox; + +TEST(PaimonDataFileTest, serializeRoundTrip) { + PaimonDataFile file; + file.path = "data.parquet"; + file.size = 4096; + file.rowCount = 500; + file.level = 1; + file.minSequenceNumber = 5; + file.maxSequenceNumber = 15; + file.deleteRowCount = 3; + file.creationTimeMs = 1700000000; + file.type = PaimonDataFile::Type::kChangelog; + file.source = PaimonDataFile::Source::kCompact; + file.deletionFile = PaimonDeletionFile{"del.bin", 0, 64, 2}; + + auto serialized = file.serialize(); + auto deserialized = PaimonDataFile::create(serialized); + + EXPECT_EQ(deserialized.path, "data.parquet"); + EXPECT_EQ(deserialized.size, 4096); + EXPECT_EQ(deserialized.rowCount, 500); + EXPECT_EQ(deserialized.level, 1); + EXPECT_EQ(deserialized.minSequenceNumber, 5); + EXPECT_EQ(deserialized.maxSequenceNumber, 15); + EXPECT_EQ(deserialized.deleteRowCount, 3); + EXPECT_EQ(deserialized.creationTimeMs, 1700000000); + EXPECT_EQ(deserialized.type, PaimonDataFile::Type::kChangelog); + EXPECT_EQ(deserialized.source, PaimonDataFile::Source::kCompact); + ASSERT_TRUE(deserialized.deletionFile.has_value()); + EXPECT_EQ(deserialized.deletionFile->path, "del.bin"); + EXPECT_EQ(deserialized.deletionFile->offset, 0); + EXPECT_EQ(deserialized.deletionFile->length, 64); + EXPECT_EQ(deserialized.deletionFile->cardinality, 2); +} + +TEST(PaimonDataFileTest, serializeNoDeletionFile) { + PaimonDataFile file; + file.path = "data.orc"; + file.size = 1024; + + auto serialized = file.serialize(); + auto deserialized = PaimonDataFile::create(serialized); + + EXPECT_EQ(deserialized.path, "data.orc"); + EXPECT_EQ(deserialized.size, 1024); + EXPECT_FALSE(deserialized.deletionFile.has_value()); +} + +TEST(PaimonDataFileTest, serializeDefaultValues) { + PaimonDataFile file; + file.path = "data.orc"; + + auto serialized = file.serialize(); + auto deserialized = PaimonDataFile::create(serialized); + + EXPECT_EQ(deserialized.path, "data.orc"); + EXPECT_EQ(deserialized.size, 0); + EXPECT_EQ(deserialized.rowCount, 0); + EXPECT_EQ(deserialized.level, 0); + EXPECT_EQ(deserialized.minSequenceNumber, 0); + EXPECT_EQ(deserialized.maxSequenceNumber, 0); + EXPECT_EQ(deserialized.deleteRowCount, 0); + EXPECT_EQ(deserialized.creationTimeMs, 0); + EXPECT_EQ(deserialized.type, PaimonDataFile::Type::kData); + EXPECT_EQ(deserialized.source, PaimonDataFile::Source::kAppend); + EXPECT_FALSE(deserialized.deletionFile.has_value()); +} + +TEST(PaimonDataFileTest, toString) { + PaimonDataFile file; + file.path = "data.orc"; + file.size = 1024; + file.rowCount = 100; + file.level = 0; + + auto str = file.toString(); + EXPECT_THAT(str, testing::HasSubstr("data.orc")); + EXPECT_THAT(str, testing::HasSubstr("size=1024")); + EXPECT_THAT(str, testing::HasSubstr("rows=100")); + EXPECT_THAT(str, testing::HasSubstr("level=0")); + EXPECT_THAT(str, testing::HasSubstr("type=DATA")); + EXPECT_THAT(str, testing::HasSubstr("source=APPEND")); + EXPECT_THAT(str, testing::HasSubstr("deletionFile=none")); +} + +TEST(PaimonDataFileTest, toStringChangelog) { + PaimonDataFile file; + file.path = "changelog.orc"; + file.size = 2048; + file.rowCount = 200; + file.level = 1; + file.type = PaimonDataFile::Type::kChangelog; + file.source = PaimonDataFile::Source::kCompact; + file.deletionFile = PaimonDeletionFile{"del.bin", 0, 64, 2}; + + auto str = file.toString(); + EXPECT_THAT(str, testing::HasSubstr("type=CHANGELOG")); + EXPECT_THAT(str, testing::HasSubstr("source=COMPACT")); + EXPECT_THAT(str, testing::HasSubstr("del.bin")); + EXPECT_THAT(str, testing::HasSubstr("cardinality=2")); +} + +TEST(PaimonDataFileTest, defaultValues) { + PaimonDataFile file; + EXPECT_TRUE(file.path.empty()); + EXPECT_EQ(file.size, 0); + EXPECT_EQ(file.rowCount, 0); + EXPECT_EQ(file.level, 0); + EXPECT_EQ(file.minSequenceNumber, 0); + EXPECT_EQ(file.maxSequenceNumber, 0); + EXPECT_EQ(file.deleteRowCount, 0); + EXPECT_EQ(file.creationTimeMs, 0); + EXPECT_EQ(file.type, PaimonDataFile::Type::kData); + EXPECT_EQ(file.source, PaimonDataFile::Source::kAppend); + EXPECT_FALSE(file.deletionFile.has_value()); +} + +TEST(PaimonDataFileTest, typeStringAndParse) { + EXPECT_EQ(PaimonDataFile::typeString(PaimonDataFile::Type::kData), "DATA"); + EXPECT_EQ( + PaimonDataFile::typeString(PaimonDataFile::Type::kChangelog), + "CHANGELOG"); + + EXPECT_EQ( + PaimonDataFile::typeFromString("DATA"), PaimonDataFile::Type::kData); + EXPECT_EQ( + PaimonDataFile::typeFromString("CHANGELOG"), + PaimonDataFile::Type::kChangelog); + + VELOX_ASSERT_THROW( + PaimonDataFile::typeFromString("UNKNOWN"), + "Unknown PaimonDataFile::Type: UNKNOWN"); +} + +TEST(PaimonDataFileTest, sourceStringAndParse) { + EXPECT_EQ( + PaimonDataFile::sourceString(PaimonDataFile::Source::kAppend), "APPEND"); + EXPECT_EQ( + PaimonDataFile::sourceString(PaimonDataFile::Source::kCompact), + "COMPACT"); + + EXPECT_EQ( + PaimonDataFile::sourceFromString("APPEND"), + PaimonDataFile::Source::kAppend); + EXPECT_EQ( + PaimonDataFile::sourceFromString("COMPACT"), + PaimonDataFile::Source::kCompact); + + VELOX_ASSERT_THROW( + PaimonDataFile::sourceFromString("UNKNOWN"), + "Unknown PaimonDataFile::Source: UNKNOWN"); +} + +TEST(PaimonDataFileTest, typeStreamAndFormat) { + { + std::ostringstream os; + os << PaimonDataFile::Type::kData; + EXPECT_EQ(os.str(), "DATA"); + } + { + std::ostringstream os; + os << PaimonDataFile::Type::kChangelog; + EXPECT_EQ(os.str(), "CHANGELOG"); + } + + EXPECT_EQ(fmt::format("{}", PaimonDataFile::Type::kData), "DATA"); + EXPECT_EQ(fmt::format("{}", PaimonDataFile::Type::kChangelog), "CHANGELOG"); +} + +TEST(PaimonDataFileTest, sourceStreamAndFormat) { + { + std::ostringstream os; + os << PaimonDataFile::Source::kAppend; + EXPECT_EQ(os.str(), "APPEND"); + } + { + std::ostringstream os; + os << PaimonDataFile::Source::kCompact; + EXPECT_EQ(os.str(), "COMPACT"); + } + + EXPECT_EQ(fmt::format("{}", PaimonDataFile::Source::kAppend), "APPEND"); + EXPECT_EQ(fmt::format("{}", PaimonDataFile::Source::kCompact), "COMPACT"); +} diff --git a/velox/connectors/hive/paimon/tests/PaimonDeletionFileTest.cpp b/velox/connectors/hive/paimon/tests/PaimonDeletionFileTest.cpp new file mode 100644 index 00000000000..9290f6c57b5 --- /dev/null +++ b/velox/connectors/hive/paimon/tests/PaimonDeletionFileTest.cpp @@ -0,0 +1,100 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/connectors/hive/paimon/PaimonDeletionFile.h" + +#include +#include + +#include "velox/common/base/Exceptions.h" +#include "velox/common/base/tests/GTestUtils.h" + +using namespace facebook::velox::connector::hive::paimon; +using namespace facebook::velox; + +TEST(PaimonDeletionFileTest, serializeRoundTrip) { + PaimonDeletionFile df("s3://bucket/deletion.bin", 64, 256, 10); + + auto serialized = df.serialize(); + auto deserialized = PaimonDeletionFile::create(serialized); + + EXPECT_EQ(deserialized.path, "s3://bucket/deletion.bin"); + EXPECT_EQ(deserialized.offset, 64); + EXPECT_EQ(deserialized.length, 256); + EXPECT_EQ(deserialized.cardinality, 10); +} + +TEST(PaimonDeletionFileTest, serializeRoundTripZeroOffset) { + PaimonDeletionFile df("deletion-standalone.bin", 0, 512, 42); + + auto serialized = df.serialize(); + auto deserialized = PaimonDeletionFile::create(serialized); + + EXPECT_EQ(deserialized.path, "deletion-standalone.bin"); + EXPECT_EQ(deserialized.offset, 0); + EXPECT_EQ(deserialized.length, 512); + EXPECT_EQ(deserialized.cardinality, 42); +} + +TEST(PaimonDeletionFileTest, serializeRoundTripLargeValues) { + PaimonDeletionFile df( + "s3://bucket/container.bin", 1ULL << 32, 1ULL << 20, 1000000); + + auto serialized = df.serialize(); + auto deserialized = PaimonDeletionFile::create(serialized); + + EXPECT_EQ(deserialized.path, "s3://bucket/container.bin"); + EXPECT_EQ(deserialized.offset, 1ULL << 32); + EXPECT_EQ(deserialized.length, 1ULL << 20); + EXPECT_EQ(deserialized.cardinality, 1000000); +} + +TEST(PaimonDeletionFileTest, toString) { + PaimonDeletionFile df("del.bin", 0, 128, 5); + + auto str = df.toString(); + EXPECT_THAT(str, testing::HasSubstr("del.bin")); + EXPECT_THAT(str, testing::HasSubstr("offset=0")); + EXPECT_THAT(str, testing::HasSubstr("length=128")); + EXPECT_THAT(str, testing::HasSubstr("cardinality=5")); +} + +TEST(PaimonDeletionFileTest, toStringContainerFile) { + PaimonDeletionFile df("s3://bucket/container.bin", 1024, 256, 15); + + auto str = df.toString(); + EXPECT_THAT(str, testing::HasSubstr("s3://bucket/container.bin")); + EXPECT_THAT(str, testing::HasSubstr("offset=1024")); + EXPECT_THAT(str, testing::HasSubstr("length=256")); + EXPECT_THAT(str, testing::HasSubstr("cardinality=15")); +} + +TEST(PaimonDeletionFileTest, zeroLengthThrows) { + VELOX_ASSERT_THROW( + PaimonDeletionFile("del.bin", 0, 0, 5), + "PaimonDeletionFile length must be > 0"); +} + +TEST(PaimonDeletionFileTest, zeroCardinalityThrows) { + VELOX_ASSERT_THROW( + PaimonDeletionFile("del.bin", 0, 128, 0), + "PaimonDeletionFile cardinality must be > 0"); +} + +TEST(PaimonDeletionFileTest, zeroLengthAndCardinalityThrows) { + VELOX_ASSERT_THROW( + PaimonDeletionFile("del.bin", 0, 0, 0), + "PaimonDeletionFile length must be > 0"); +} diff --git a/velox/connectors/hive/paimon/tests/PaimonRowKindTest.cpp b/velox/connectors/hive/paimon/tests/PaimonRowKindTest.cpp new file mode 100644 index 00000000000..0e58526d4a5 --- /dev/null +++ b/velox/connectors/hive/paimon/tests/PaimonRowKindTest.cpp @@ -0,0 +1,116 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/connectors/hive/paimon/PaimonRowKind.h" + +#include +#include +#include + +#include "velox/common/base/Exceptions.h" +#include "velox/common/base/tests/GTestUtils.h" + +using namespace facebook::velox::connector::hive::paimon; +using namespace facebook::velox; + +TEST(PaimonRowKindTest, rowKindString) { + EXPECT_EQ(paimonRowKindString(PaimonRowKind::kInsert), "+I"); + EXPECT_EQ(paimonRowKindString(PaimonRowKind::kUpdateBefore), "-U"); + EXPECT_EQ(paimonRowKindString(PaimonRowKind::kUpdateAfter), "+U"); + EXPECT_EQ(paimonRowKindString(PaimonRowKind::kDelete), "-D"); +} + +TEST(PaimonRowKindTest, rowKindFromValue) { + EXPECT_EQ(paimonRowKindFromValue(0), PaimonRowKind::kInsert); + EXPECT_EQ(paimonRowKindFromValue(1), PaimonRowKind::kUpdateBefore); + EXPECT_EQ(paimonRowKindFromValue(2), PaimonRowKind::kUpdateAfter); + EXPECT_EQ(paimonRowKindFromValue(3), PaimonRowKind::kDelete); + + VELOX_ASSERT_THROW( + paimonRowKindFromValue(4), "Unknown PaimonRowKind value: 4"); + VELOX_ASSERT_THROW( + paimonRowKindFromValue(-1), "Unknown PaimonRowKind value: -1"); +} + +TEST(PaimonRowKindTest, rowKindStreamAndFormat) { + { + std::ostringstream os; + os << PaimonRowKind::kInsert; + EXPECT_EQ(os.str(), "+I"); + } + { + std::ostringstream os; + os << PaimonRowKind::kDelete; + EXPECT_EQ(os.str(), "-D"); + } + + EXPECT_EQ(fmt::format("{}", PaimonRowKind::kInsert), "+I"); + EXPECT_EQ(fmt::format("{}", PaimonRowKind::kUpdateBefore), "-U"); + EXPECT_EQ(fmt::format("{}", PaimonRowKind::kUpdateAfter), "+U"); + EXPECT_EQ(fmt::format("{}", PaimonRowKind::kDelete), "-D"); +} + +TEST(PaimonRowKindTest, rowKindValues) { + EXPECT_EQ(static_cast(PaimonRowKind::kInsert), 0); + EXPECT_EQ(static_cast(PaimonRowKind::kUpdateBefore), 1); + EXPECT_EQ(static_cast(PaimonRowKind::kUpdateAfter), 2); + EXPECT_EQ(static_cast(PaimonRowKind::kDelete), 3); +} + +TEST(PaimonRowKindTest, changelogModeStringAndParse) { + EXPECT_EQ(paimonChangelogModeString(PaimonChangelogMode::kNone), "NONE"); + EXPECT_EQ(paimonChangelogModeString(PaimonChangelogMode::kInput), "INPUT"); + EXPECT_EQ(paimonChangelogModeString(PaimonChangelogMode::kLookup), "LOOKUP"); + EXPECT_EQ( + paimonChangelogModeString(PaimonChangelogMode::kFullCompaction), + "FULL_COMPACTION"); + + EXPECT_EQ(paimonChangelogModeFromString("NONE"), PaimonChangelogMode::kNone); + EXPECT_EQ( + paimonChangelogModeFromString("INPUT"), PaimonChangelogMode::kInput); + EXPECT_EQ( + paimonChangelogModeFromString("LOOKUP"), PaimonChangelogMode::kLookup); + EXPECT_EQ( + paimonChangelogModeFromString("FULL_COMPACTION"), + PaimonChangelogMode::kFullCompaction); + + VELOX_ASSERT_THROW( + paimonChangelogModeFromString("UNKNOWN"), + "Unknown PaimonChangelogMode: UNKNOWN"); +} + +TEST(PaimonRowKindTest, changelogModeStreamAndFormat) { + { + std::ostringstream os; + os << PaimonChangelogMode::kNone; + EXPECT_EQ(os.str(), "NONE"); + } + { + std::ostringstream os; + os << PaimonChangelogMode::kLookup; + EXPECT_EQ(os.str(), "LOOKUP"); + } + + EXPECT_EQ(fmt::format("{}", PaimonChangelogMode::kNone), "NONE"); + EXPECT_EQ(fmt::format("{}", PaimonChangelogMode::kInput), "INPUT"); + EXPECT_EQ(fmt::format("{}", PaimonChangelogMode::kLookup), "LOOKUP"); + EXPECT_EQ( + fmt::format("{}", PaimonChangelogMode::kFullCompaction), + "FULL_COMPACTION"); +} + +TEST(PaimonRowKindTest, rowKindColumnName) { + EXPECT_EQ(kRowKindColumn, "_rowkind"); +} diff --git a/velox/connectors/hive/storage_adapters/CMakeLists.txt b/velox/connectors/hive/storage_adapters/CMakeLists.txt index bd7c37f8164..9cbad36cf77 100644 --- a/velox/connectors/hive/storage_adapters/CMakeLists.txt +++ b/velox/connectors/hive/storage_adapters/CMakeLists.txt @@ -16,3 +16,10 @@ add_subdirectory(s3fs) add_subdirectory(hdfs) add_subdirectory(gcs) add_subdirectory(abfs) + +velox_add_library( + velox_hive_storage_adapters_test_common + INTERFACE + HEADERS + test_common/InsertTest.h +) diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsPath.h b/velox/connectors/hive/storage_adapters/abfs/AbfsPath.h index 0dbaa0a4eae..3a4e0d99f09 100644 --- a/velox/connectors/hive/storage_adapters/abfs/AbfsPath.h +++ b/velox/connectors/hive/storage_adapters/abfs/AbfsPath.h @@ -54,6 +54,11 @@ static constexpr const char* kAzureOAuthAuthType = "OAuth"; static constexpr const char* kAzureSASAuthType = "SAS"; +// For performance, re - use SAS tokens until the expiry is within this number +// of seconds. +static constexpr const char* kAzureSasTokenRenewPeriod = + "fs.azure.sas.token.renew.period.for.streams"; + // Helper class to parse and extract information from a given ABFS path. class AbfsPath { public: diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.cpp b/velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.cpp index 9747e5e04c4..2d819a051d7 100644 --- a/velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.cpp +++ b/velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.cpp @@ -19,11 +19,14 @@ #include "velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h" #include "velox/connectors/hive/storage_adapters/abfs/AzureClientProviderFactories.h" +#include + namespace facebook::velox::filesystems { class AbfsReadFile::Impl { - constexpr static uint64_t kNaturalReadSize = 4 << 20; // 4M + constexpr static uint64_t kNaturalReadSize = 4'194'304; // 4M constexpr static uint64_t kReadConcurrency = 8; + constexpr static size_t kDiscardBufferSize = 262'144; // 256K public: explicit Impl(std::string_view path, const config::ConfigBase& config) { @@ -57,13 +60,13 @@ class AbfsReadFile::Impl { uint64_t offset, uint64_t length, void* buffer, - File::IoStats* stats) const { + const FileIoContext& context) const { preadInternal(offset, length, static_cast(buffer)); return {static_cast(buffer), length}; } - std::string pread(uint64_t offset, uint64_t length, File::IoStats* stats) - const { + std::string + pread(uint64_t offset, uint64_t length, const FileIoContext& context) const { std::string result(length, 0); preadInternal(offset, length, result.data()); return result; @@ -72,21 +75,12 @@ class AbfsReadFile::Impl { uint64_t preadv( uint64_t offset, const std::vector>& buffers, - File::IoStats* stats) const { - size_t length = 0; - auto size = buffers.size(); - for (auto& range : buffers) { + const FileIoContext& context) const { + size_t length{0}; + for (const auto& range : buffers) { length += range.size(); } - std::string result(length, 0); - preadInternal(offset, length, static_cast(result.data())); - size_t resultOffset = 0; - for (auto range : buffers) { - if (range.data()) { - memcpy(range.data(), &(result.data()[resultOffset]), range.size()); - } - resultOffset += range.size(); - } + preadvInternal(offset, length, buffers); return length; } @@ -94,14 +88,14 @@ class AbfsReadFile::Impl { uint64_t preadv( folly::Range regions, folly::Range iobufs, - File::IoStats* stats) const { + const FileIoContext& context) const { size_t length = 0; VELOX_CHECK_EQ(regions.size(), iobufs.size()); for (size_t i = 0; i < regions.size(); ++i) { const auto& region = regions[i]; auto& output = iobufs[i]; output = folly::IOBuf(folly::IOBuf::CREATE, region.length); - pread(region.offset, region.length, output.writableData(), stats); + pread(region.offset, region.length, output.writableData(), context); output.append(region.length); length += region.length; } @@ -133,8 +127,8 @@ class AbfsReadFile::Impl { void preadInternal(uint64_t offset, uint64_t length, char* position) const { // Read the desired range of bytes. Azure::Core::Http::HttpRange range; - range.Offset = offset; - range.Length = length; + range.Offset = static_cast(offset); + range.Length = static_cast(length); Azure::Storage::Blobs::DownloadBlobOptions blob; blob.Range = range; @@ -143,6 +137,40 @@ class AbfsReadFile::Impl { reinterpret_cast(position), length); } + void preadvInternal( + uint64_t offset, + uint64_t length, + const std::vector>& buffers) const { + Azure::Core::Http::HttpRange range; + range.Offset = static_cast(offset); + range.Length = static_cast(length); + + Azure::Storage::Blobs::DownloadBlobOptions blob; + blob.Range = range; + auto response = fileClient_->download(blob); + + std::vector discardBuffer; + for (const auto& buffer : buffers) { + auto remaining = buffer.size(); + if (buffer.data() != nullptr) { + response.Value.BodyStream->ReadToCount( + reinterpret_cast(buffer.data()), remaining); + continue; + } + + const auto discardBufferSize = std::min(remaining, kDiscardBufferSize); + if (discardBuffer.size() < discardBufferSize) { + discardBuffer.resize(discardBufferSize); + } + + while (remaining > 0) { + const auto readSize = std::min(remaining, discardBuffer.size()); + response.Value.BodyStream->ReadToCount(discardBuffer.data(), readSize); + remaining -= readSize; + } + } + } + std::string filePath_; std::unique_ptr fileClient_; int64_t length_ = -1; @@ -155,36 +183,36 @@ AbfsReadFile::AbfsReadFile( } void AbfsReadFile::initialize(const FileOptions& options) { - return impl_->initialize(options); + impl_->initialize(options); } std::string_view AbfsReadFile::pread( uint64_t offset, uint64_t length, void* buffer, - File::IoStats* stats) const { - return impl_->pread(offset, length, buffer, stats); + const FileIoContext& context) const { + return impl_->pread(offset, length, buffer, context); } std::string AbfsReadFile::pread( uint64_t offset, uint64_t length, - File::IoStats* stats) const { - return impl_->pread(offset, length, stats); + const FileIoContext& context) const { + return impl_->pread(offset, length, context); } uint64_t AbfsReadFile::preadv( uint64_t offset, const std::vector>& buffers, - File::IoStats* stats) const { - return impl_->preadv(offset, buffers, stats); + const FileIoContext& context) const { + return impl_->preadv(offset, buffers, context); } uint64_t AbfsReadFile::preadv( folly::Range regions, folly::Range iobufs, - File::IoStats* stats) const { - return impl_->preadv(regions, iobufs, stats); + const FileIoContext& context) const { + return impl_->preadv(regions, iobufs, context); } uint64_t AbfsReadFile::size() const { diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.h b/velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.h index 942439c06c1..56c99f2aa7d 100644 --- a/velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.h +++ b/velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.h @@ -35,22 +35,22 @@ class AbfsReadFile final : public ReadFile { uint64_t offset, uint64_t length, void* buf, - File::IoStats* stats = nullptr) const final; + const FileIoContext& context = {}) const final; std::string pread( uint64_t offset, uint64_t length, - File::IoStats* stats = nullptr) const final; + const FileIoContext& context = {}) const final; uint64_t preadv( uint64_t offset, const std::vector>& buffers, - File::IoStats* stats = nullptr) const final; + const FileIoContext& context = {}) const final; uint64_t preadv( folly::Range regions, folly::Range iobufs, - File::IoStats* stats = nullptr) const final; + const FileIoContext& context = {}) const final; uint64_t size() const final; diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.cpp b/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.cpp new file mode 100644 index 00000000000..5aefc098386 --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h" +#include "velox/common/config/Config.h" +#include "velox/connectors/hive/storage_adapters/abfs/AbfsPath.h" + +namespace facebook::velox::filesystems { + +std::vector extractCacheKeyFromConfig( + const config::ConfigBase& config) { + std::vector cacheKeys; + constexpr std::string_view authTypePrefix{kAzureAccountAuthType}; + for (const auto& [key, value] : config.rawConfigs()) { + if (key.find(authTypePrefix) == 0) { + // Extract the accountName after "fs.azure.account.auth.type.". + auto remaining = std::string_view(key).substr(authTypePrefix.size() + 1); + auto dot = remaining.find("."); + VELOX_USER_CHECK_NE( + dot, + std::string_view::npos, + "Invalid Azure account auth type key: {}", + key); + cacheKeys.emplace_back(CacheKey{remaining.substr(0, dot), value}); + } + } + return cacheKeys; +} + +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h b/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h index 925c6f91ece..1a6cf6e0a0e 100644 --- a/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h +++ b/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h @@ -26,6 +26,16 @@ constexpr std::string_view kAbfsScheme{"abfs://"}; constexpr std::string_view kAbfssScheme{"abfss://"}; } // namespace +class ConfigBase; + +struct CacheKey { + const std::string accountName; + const std::string authType; + + CacheKey(std::string_view accountName, std::string_view authType) + : accountName(accountName), authType(authType) {} +}; + inline bool isAbfsFile(const std::string_view filename) { return filename.find(kAbfsScheme) == 0 || filename.find(kAbfssScheme) == 0; } @@ -45,4 +55,7 @@ inline std::string throwStorageExceptionWithOperationDetails( VELOX_FAIL(errMsg); } +std::vector extractCacheKeyFromConfig( + const config::ConfigBase& config); + } // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/AzureClientProvider.h b/velox/connectors/hive/storage_adapters/abfs/AzureClientProvider.h index 291cc8e73a6..1a1a68f6d87 100644 --- a/velox/connectors/hive/storage_adapters/abfs/AzureClientProvider.h +++ b/velox/connectors/hive/storage_adapters/abfs/AzureClientProvider.h @@ -16,6 +16,8 @@ #pragma once +#include "velox/common/config/Config.h" +#include "velox/connectors/hive/storage_adapters/abfs/AbfsPath.h" #include "velox/connectors/hive/storage_adapters/abfs/AzureBlobClient.h" #include "velox/connectors/hive/storage_adapters/abfs/AzureDataLakeFileClient.h" diff --git a/velox/connectors/hive/storage_adapters/abfs/CMakeLists.txt b/velox/connectors/hive/storage_adapters/abfs/CMakeLists.txt index e01169bbd96..ed93e4df610 100644 --- a/velox/connectors/hive/storage_adapters/abfs/CMakeLists.txt +++ b/velox/connectors/hive/storage_adapters/abfs/CMakeLists.txt @@ -14,7 +14,23 @@ # for generated headers -velox_add_library(velox_abfs RegisterAbfsFileSystem.cpp) +velox_add_library( + velox_abfs + RegisterAbfsFileSystem.cpp + HEADERS + AbfsFileSystem.h + AbfsPath.h + AbfsReadFile.h + AbfsUtil.h + AbfsWriteFile.h + AzureBlobClient.h + AzureClientProvider.h + AzureClientProviderFactories.h + AzureClientProviderImpl.h + AzureDataLakeFileClient.h + DynamicSasTokenClientProvider.h + RegisterAbfsFileSystem.h +) if(VELOX_ENABLE_ABFS) velox_sources( @@ -23,9 +39,11 @@ if(VELOX_ENABLE_ABFS) AbfsFileSystem.cpp AbfsPath.cpp AbfsReadFile.cpp + AbfsUtil.cpp AbfsWriteFile.cpp AzureClientProviderFactories.cpp AzureClientProviderImpl.cpp + DynamicSasTokenClientProvider.cpp ) velox_link_libraries( diff --git a/velox/connectors/hive/storage_adapters/abfs/DynamicSasTokenClientProvider.cpp b/velox/connectors/hive/storage_adapters/abfs/DynamicSasTokenClientProvider.cpp new file mode 100644 index 00000000000..b98ae99d48e --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/DynamicSasTokenClientProvider.cpp @@ -0,0 +1,225 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/storage_adapters/abfs/DynamicSasTokenClientProvider.h" + +#include + +namespace facebook::velox::filesystems { + +namespace { + +constexpr int64_t kDefaultSasTokenRenewPeriod = 120; // in seconds + +Azure::DateTime getExpiry(const std::string_view& token) { + if (token.empty()) { + return Azure::DateTime::clock::time_point::min(); + } + + static constexpr std::string_view kSignedExpiry{"se="}; + static constexpr int32_t kSignedExpiryLen = 3; + + auto start = token.find(kSignedExpiry); + if (start == std::string::npos) { + return Azure::DateTime::clock::time_point::min(); + } + start += kSignedExpiryLen; + + auto end = token.find("&", start); + auto seValue = (end == std::string::npos) + ? std::string(token.substr(start)) + : std::string(token.substr(start, end - start)); + + seValue = Azure::Core::Url::Decode(seValue); + auto seDate = + Azure::DateTime::Parse(seValue, Azure::DateTime::DateFormat::Rfc3339); + + static constexpr std::string_view kSignedKeyExpiry = "ske="; + static constexpr int32_t kSignedKeyExpiryLen = 4; + + start = token.find(kSignedKeyExpiry); + if (start == std::string::npos) { + return seDate; + } + start += kSignedKeyExpiryLen; + + end = token.find("&", start); + auto skeValue = (end == std::string::npos) + ? std::string(token.substr(start)) + : std::string(token.substr(start, end - start)); + + skeValue = Azure::Core::Url::Decode(skeValue); + auto skeDate = + Azure::DateTime::Parse(skeValue, Azure::DateTime::DateFormat::Rfc3339); + + return std::min(skeDate, seDate); +} + +bool isNearExpiry(Azure::DateTime expiration, int64_t minExpirationInSeconds) { + if (expiration == Azure::DateTime::clock::time_point::min()) { + return true; + } + auto remaining = std::chrono::duration_cast( + expiration - Azure::DateTime::clock::now()) + .count(); + return remaining <= minExpirationInSeconds; +} + +class DynamicSasTokenDataLakeFileClient final : public AzureDataLakeFileClient { + public: + DynamicSasTokenDataLakeFileClient( + const std::shared_ptr& abfsPath, + const std::shared_ptr& sasKeyGenerator, + int64_t sasTokenRenewPeriod) + : abfsPath_(abfsPath), + sasKeyGenerator_(sasKeyGenerator), + sasTokenRenewPeriod_(sasTokenRenewPeriod) {} + + void create() override { + getWriteClient()->Create(); + } + + Azure::Storage::Files::DataLake::Models::PathProperties getProperties() + override { + return getReadClient()->GetProperties().Value; + } + + void append(const uint8_t* buffer, size_t size, uint64_t offset) override { + auto bodyStream = Azure::Core::IO::MemoryBodyStream(buffer, size); + getWriteClient()->Append(bodyStream, offset); + } + + void flush(uint64_t position) override { + getWriteClient()->Flush(position); + } + + void close() override {} + + std::string getUrl() override { + return getWriteClient()->GetUrl(); + } + + private: + std::shared_ptr abfsPath_; + std::shared_ptr sasKeyGenerator_; + int64_t sasTokenRenewPeriod_; + + std::unique_ptr writeClient_{nullptr}; + Azure::DateTime writeSasExpiration_{ + Azure::DateTime::clock::time_point::min()}; + + std::unique_ptr readClient_{nullptr}; + Azure::DateTime readSasExpiration_{Azure::DateTime::clock::time_point::min()}; + + DataLakeFileClient* getWriteClient() { + if (writeClient_ == nullptr || + isNearExpiry(writeSasExpiration_, sasTokenRenewPeriod_)) { + const auto& sas = sasKeyGenerator_->getSasToken( + abfsPath_->fileSystem(), abfsPath_->filePath(), kAbfsWriteOperation); + writeSasExpiration_ = getExpiry(sas); + writeClient_ = std::make_unique( + fmt::format("{}?{}", abfsPath_->getUrl(false), sas)); + } + return writeClient_.get(); + } + + DataLakeFileClient* getReadClient() { + if (readClient_ == nullptr || + isNearExpiry(readSasExpiration_, sasTokenRenewPeriod_)) { + const auto& sas = sasKeyGenerator_->getSasToken( + abfsPath_->fileSystem(), abfsPath_->filePath(), kAbfsReadOperation); + readSasExpiration_ = getExpiry(sas); + readClient_ = std::make_unique( + fmt::format("{}?{}", abfsPath_->getUrl(false), sas)); + } + return readClient_.get(); + } +}; + +class DynamicSasTokenBlobClient : public AzureBlobClient { + public: + DynamicSasTokenBlobClient( + const std::shared_ptr& abfsPath, + const std::shared_ptr& sasTokenProvider, + int64_t sasTokenRenewPeriod) + : abfsPath_(abfsPath), + sasTokenProvider_(sasTokenProvider), + sasTokenRenewPeriod_(sasTokenRenewPeriod) {} + + Azure::Response getProperties() + override { + return getBlobClient()->GetProperties(); + } + + Azure::Response download( + const Azure::Storage::Blobs::DownloadBlobOptions& options) override { + return getBlobClient()->Download(options); + } + + std::string getUrl() override { + return getBlobClient()->GetUrl(); + } + + private: + std::shared_ptr abfsPath_; + std::shared_ptr sasTokenProvider_; + int64_t sasTokenRenewPeriod_; + + std::unique_ptr blobClient_{nullptr}; + Azure::DateTime sasExpiration_{Azure::DateTime::clock::time_point::min()}; + + BlobClient* getBlobClient() { + if (blobClient_ == nullptr || + isNearExpiry(sasExpiration_, sasTokenRenewPeriod_)) { + const auto& sas = sasTokenProvider_->getSasToken( + abfsPath_->fileSystem(), abfsPath_->filePath(), kAbfsReadOperation); + sasExpiration_ = getExpiry(sas); + blobClient_ = std::make_unique( + fmt::format("{}?{}", abfsPath_->getUrl(true), sas)); + } + return blobClient_.get(); + } +}; + +} // namespace + +DynamicSasTokenClientProvider::DynamicSasTokenClientProvider( + const std::shared_ptr& sasTokenProvider) + : AzureClientProvider(), sasTokenProvider_(sasTokenProvider) {} + +void DynamicSasTokenClientProvider::init(const config::ConfigBase& config) { + sasTokenRenewPeriod_ = config.get( + kAzureSasTokenRenewPeriod, kDefaultSasTokenRenewPeriod); +} + +std::unique_ptr +DynamicSasTokenClientProvider::getReadFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) { + init(config); + return std::make_unique( + abfsPath, sasTokenProvider_, sasTokenRenewPeriod_); +} + +std::unique_ptr +DynamicSasTokenClientProvider::getWriteFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) { + init(config); + return std::make_unique( + abfsPath, sasTokenProvider_, sasTokenRenewPeriod_); +} +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/DynamicSasTokenClientProvider.h b/velox/connectors/hive/storage_adapters/abfs/DynamicSasTokenClientProvider.h new file mode 100644 index 00000000000..ab1d53f0045 --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/DynamicSasTokenClientProvider.h @@ -0,0 +1,73 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/connectors/hive/storage_adapters/abfs/AzureBlobClient.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureClientProvider.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureClientProviderFactories.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureDataLakeFileClient.h" + +namespace facebook::velox::filesystems { + +/// SAS permissions reference: +/// https://learn.microsoft.com/en-us/rest/api/storageservices/create-service-sas#permissions-for-a-directory-container-or-blob +/// +/// ReadClient uses "read" permission for Download and GetProperties. +/// WriteClient uses "read" permission for GetProperties, and "write" permission +/// for other operations. +static const std::string kAbfsReadOperation{"read"}; +static const std::string kAbfsWriteOperation{"write"}; + +/// Interface for providing SAS tokens for ABFS file system operations. +/// Adapted from the Hadoop Azure implementation: +/// org.apache.hadoop.fs.azurebfs.extensions.SASTokenProvider +class SasTokenProvider { + public: + virtual ~SasTokenProvider() = default; + + virtual std::string getSasToken( + const std::string& fileSystem, + const std::string& path, + const std::string& operation) = 0; +}; + +/// Client provider that dynamically refreshes SAS tokens based on the +/// expiration time of the token. A SasTokenProvider for retrieving SAS tokens +/// must be provided to this class. Example for generating the SAS token can be +/// found in: +/// https://github.com/Azure/azure-sdk-for-cpp/blob/3d917e7c178f0a49b189395a907180084857cc70/sdk/storage/azure-storage-blobs/samples/blob_sas.cpp +class DynamicSasTokenClientProvider : public AzureClientProvider { + public: + explicit DynamicSasTokenClientProvider( + const std::shared_ptr& sasTokenProvider); + + std::unique_ptr getReadFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) override; + + std::unique_ptr getWriteFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) override; + + private: + void init(const config::ConfigBase& config); + + std::shared_ptr sasTokenProvider_; + int64_t sasTokenRenewPeriod_; +}; + +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.cpp b/velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.cpp index dcb46aa0529..df7a4cea829 100644 --- a/velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.cpp +++ b/velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.cpp @@ -52,7 +52,8 @@ std::unique_ptr abfsWriteFileSinkGenerator( fileSystem->openFileForWrite(fileURI), fileURI, options.metricLogger, - options.stats); + options.stats, + options.fileSystemStats); } return nullptr; } @@ -68,40 +69,28 @@ void registerAbfsFileSystem() { void registerAzureClientProvider(const config::ConfigBase& config) { #ifdef VELOX_ENABLE_ABFS - for (const auto& [key, value] : config.rawConfigs()) { - constexpr std::string_view authTypePrefix{kAzureAccountAuthType}; - if (key.find(authTypePrefix) == 0) { - std::string_view skey = key; - // Extract the accountName after "fs.azure.account.auth.type.". - auto remaining = skey.substr(authTypePrefix.size() + 1); - auto dot = remaining.find("."); - VELOX_USER_CHECK_NE( - dot, - std::string_view::npos, - "Invalid Azure account auth type key: {}", - key); - auto accountName = std::string(remaining.substr(0, dot)); - if (value == kAzureSharedKeyAuthType) { - AzureClientProviderFactories::registerFactory( - accountName, [](const std::string&) { - return std::make_unique(); - }); - } else if (value == kAzureOAuthAuthType) { - AzureClientProviderFactories::registerFactory( - accountName, [](const std::string&) { - return std::make_unique(); - }); - } else if (value == kAzureSASAuthType) { - AzureClientProviderFactories::registerFactory( - accountName, [](const std::string&) { - return std::make_unique(); - }); - } else { - VELOX_USER_FAIL( - "Unsupported auth type {}, supported auth types are SharedKey, OAuth and SAS.", - value); - } + for (const auto& [accountName, authType] : + extractCacheKeyFromConfig(config)) { + if (authType == kAzureSharedKeyAuthType) { + AzureClientProviderFactories::registerFactory( + accountName, [](const std::string&) { + return std::make_unique(); + }); + } else if (authType == kAzureOAuthAuthType) { + AzureClientProviderFactories::registerFactory( + accountName, [](const std::string&) { + return std::make_unique(); + }); + } else if (authType == kAzureSASAuthType) { + AzureClientProviderFactories::registerFactory( + accountName, [](const std::string&) { + return std::make_unique(); + }); + } else { + VELOX_USER_FAIL( + "Unsupported auth type {}, supported auth types are SharedKey, OAuth and SAS.", + authType); } } #endif diff --git a/velox/connectors/hive/storage_adapters/abfs/tests/AbfsFileSystemTest.cpp b/velox/connectors/hive/storage_adapters/abfs/tests/AbfsFileSystemTest.cpp index 245c6931cb3..6b8db42f4e1 100644 --- a/velox/connectors/hive/storage_adapters/abfs/tests/AbfsFileSystemTest.cpp +++ b/velox/connectors/hive/storage_adapters/abfs/tests/AbfsFileSystemTest.cpp @@ -14,10 +14,11 @@ * limitations under the License. */ +#include #include #include -#include #include +#include #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/config/Config.h" @@ -30,6 +31,7 @@ #include "connectors/hive/storage_adapters/abfs/AzureClientProviderFactories.h" #include "connectors/hive/storage_adapters/abfs/AzureClientProviderImpl.h" #include "connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.h" +#include "velox/common/testutil/TempFilePath.h" #include "velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.h" #include "velox/connectors/hive/storage_adapters/abfs/AbfsWriteFile.h" #include "velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.h" @@ -37,15 +39,87 @@ #include "velox/connectors/hive/storage_adapters/abfs/tests/MockDataLakeFileClient.h" #include "velox/dwio/common/FileSink.h" #include "velox/exec/tests/utils/PortUtil.h" -#include "velox/exec/tests/utils/TempFilePath.h" using namespace facebook::velox; using namespace facebook::velox::filesystems; +using namespace facebook::velox::common::testutil; using ::facebook::velox::common::Region; namespace { -constexpr int kOneMB = 1 << 20; +constexpr int kOneMB = 1'048'576; + +struct RecordedDownload { + int64_t offset{0}; + int64_t length{0}; +}; + +struct InMemoryReadState { + std::string data; + std::vector downloads; +}; + +class InMemoryAzureBlobClient final : public AzureBlobClient { + public: + explicit InMemoryAzureBlobClient(std::shared_ptr state) + : state_(std::move(state)) {} + + Azure::Response getProperties() + override { + VELOX_FAIL("Unexpected getProperties call."); + } + + Azure::Response download( + const Azure::Storage::Blobs::DownloadBlobOptions& options) override { + const auto& range = options.Range.Value(); + const auto offset = range.Offset; + const auto length = range.Length.Value(); + + VELOX_CHECK_GE(offset, 0); + VELOX_CHECK_GE(length, 0); + VELOX_CHECK_LE(offset + length, static_cast(state_->data.size())); + + state_->downloads.push_back({offset, length}); + + Azure::Storage::Blobs::Models::DownloadBlobResult result; + result.BodyStream = std::make_unique( + reinterpret_cast(state_->data.data() + offset), + static_cast(length)); + return Azure::Response( + std::move(result), nullptr); + } + + std::string getUrl() override { + return std::string{kUrl}; + } + + private: + static constexpr std::string_view kUrl = + "https://unit.blob.core.windows.net/container/test-file"; + + std::shared_ptr state_; +}; + +class InMemoryAzureClientProvider final : public AzureClientProvider { + public: + explicit InMemoryAzureClientProvider(std::shared_ptr state) + : state_(std::move(state)) {} + + std::unique_ptr getReadFileClient( + const std::shared_ptr& path, + const config::ConfigBase& config) override { + return std::make_unique(state_); + } + + std::unique_ptr getWriteFileClient( + const std::shared_ptr& path, + const config::ConfigBase& config) override { + VELOX_FAIL("Unexpected getWriteFileClient call."); + } + + private: + std::shared_ptr state_; +}; class TestAzureClientProvider final : public AzureClientProvider { public: @@ -97,22 +171,24 @@ class AbfsFileSystemTest : public testing::Test { } static std::string generateRandomData(int size) { - static const char charset[] = + static constexpr std::string_view kCharacters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; + std::mt19937 generator(std::random_device{}()); + std::uniform_int_distribution distribution( + 0, kCharacters.size() - 1); std::string data(size, ' '); for (int i = 0; i < size; ++i) { - int index = rand() % (sizeof(charset) - 1); - data[i] = charset[index]; + data[i] = kCharacters[distribution(generator)]; } return data; } private: - static std::shared_ptr<::exec::test::TempFilePath> createFile() { - auto tempFile = exec::test::TempFilePath::create(); + static std::shared_ptr createFile() { + auto tempFile = TempFilePath::create(); tempFile->append("aaaaa"); tempFile->append("bbbbb"); tempFile->append(std::string(kOneMB, 'c')); @@ -121,6 +197,8 @@ class AbfsFileSystemTest : public testing::Test { } }; +namespace { + void readData(ReadFile* readFile) { ASSERT_EQ(readFile->size(), 15 + kOneMB); char buffer1[5]; @@ -169,6 +247,8 @@ void readData(ReadFile* readFile) { "ccccc"); } +} // namespace + TEST_F(AbfsFileSystemTest, readFile) { auto readFile = abfs_->openFileForRead(azuriteServer_->fileURI()); readData(readFile.get()); @@ -181,6 +261,74 @@ TEST_F(AbfsFileSystemTest, openFileForReadWithOptions) { readData(readFile.get()); } +TEST(AbfsReadFileTest, preadvUsesSingleDownloadForBuffersWithGaps) { + auto state = std::make_shared(); + state->data = "0123456789abcdefghijklmn"; + + registerAzureClientProviderFactory("unit", [state](const std::string&) { + return std::make_unique(state); + }); + + AbfsReadFile readFile{ + "abfs://container@unit.dfs.core.windows.net/test-file", + config::ConfigBase({})}; + FileOptions options; + options.fileSize = state->data.size(); + readFile.initialize(options); + + char firstBuffer[5]; + char secondBuffer[5]; + const std::vector> buffers = { + folly::Range(firstBuffer, sizeof(firstBuffer)), + folly::Range(nullptr, 7), + folly::Range(secondBuffer, sizeof(secondBuffer)), + }; + + ASSERT_EQ(17, readFile.preadv(2, buffers)); + ASSERT_EQ(state->downloads.size(), 1); + EXPECT_EQ(state->downloads[0].offset, 2); + EXPECT_EQ(state->downloads[0].length, 17); + EXPECT_EQ(std::string_view(firstBuffer, sizeof(firstBuffer)), "23456"); + EXPECT_EQ(std::string_view(secondBuffer, sizeof(secondBuffer)), "efghi"); +} + +TEST(AbfsReadFileTest, preadvUsesSingleDownloadForDefaultCoalescedGap) { + constexpr size_t kLeadingReadSize = 4; + constexpr size_t kGapSize = static_cast(512) * 1'024; + constexpr size_t kTrailingReadSize = 4; + + auto state = std::make_shared(); + state->data = std::string(kLeadingReadSize, 'a') + + std::string(kGapSize, 'x') + std::string(kTrailingReadSize, 'b'); + + registerAzureClientProviderFactory( + "unit-large-gap", [state](const std::string&) { + return std::make_unique(state); + }); + + AbfsReadFile readFile{ + "abfs://container@unit-large-gap.dfs.core.windows.net/test-file", + config::ConfigBase({})}; + FileOptions options; + options.fileSize = state->data.size(); + readFile.initialize(options); + + char firstBuffer[kLeadingReadSize]; + char secondBuffer[kTrailingReadSize]; + const std::vector> buffers = { + folly::Range(firstBuffer, sizeof(firstBuffer)), + folly::Range(nullptr, kGapSize), + folly::Range(secondBuffer, sizeof(secondBuffer)), + }; + + ASSERT_EQ(state->data.size(), readFile.preadv(0, buffers)); + ASSERT_EQ(state->downloads.size(), 1); + EXPECT_EQ(state->downloads[0].offset, 0); + EXPECT_EQ(state->downloads[0].length, state->data.size()); + EXPECT_EQ(std::string_view(firstBuffer, sizeof(firstBuffer)), "aaaa"); + EXPECT_EQ(std::string_view(secondBuffer, sizeof(secondBuffer)), "bbbb"); +} + TEST_F(AbfsFileSystemTest, openFileForReadWithInvalidOptions) { FileOptions options; options.fileSize = -kOneMB; @@ -211,7 +359,7 @@ TEST_F(AbfsFileSystemTest, multipleThreadsWithReadFile) { 0, sleepTimesInMicroseconds.size() - 1); for (int i = 0; i < 10; i++) { auto thread = std::thread([&] { - int index = distribution(generator); + const auto index = distribution(generator); while (!startThreads) { std::this_thread::yield(); } @@ -243,10 +391,9 @@ TEST(AbfsWriteFileTest, openFileForWriteTest) { reinterpret_cast(mockClient.get())->path(); AbfsWriteFile abfsWriteFile(kAbfsFile, mockClient); EXPECT_EQ(abfsWriteFile.size(), 0); - std::string dataContent = ""; + std::string dataContent; uint64_t totalSize = 0; - std::string randomData = - AbfsFileSystemTest::generateRandomData(1 * 1024 * 1024); + std::string randomData = AbfsFileSystemTest::generateRandomData(kOneMB); for (int i = 0; i < 8; ++i) { abfsWriteFile.append(randomData); dataContent += randomData; @@ -255,11 +402,11 @@ TEST(AbfsWriteFileTest, openFileForWriteTest) { abfsWriteFile.flush(); EXPECT_EQ(abfsWriteFile.size(), totalSize); - randomData = AbfsFileSystemTest::generateRandomData(9 * 1024 * 1024); + randomData = AbfsFileSystemTest::generateRandomData(9 * kOneMB); dataContent += randomData; abfsWriteFile.append(randomData); totalSize += randomData.size(); - randomData = AbfsFileSystemTest::generateRandomData(2 * 1024 * 1024); + randomData = AbfsFileSystemTest::generateRandomData(2 * kOneMB); dataContent += randomData; totalSize += randomData.size(); abfsWriteFile.append(randomData); @@ -301,7 +448,7 @@ TEST_F(AbfsFileSystemTest, clientProviderFactoryNotRegistered) { } TEST_F(AbfsFileSystemTest, registerAbfsFileSink) { - static const std::vector paths = { + const std::vector paths = { "abfs://test@test.dfs.core.windows.net/test", "abfss://test@test.dfs.core.windows.net/test"}; std::unordered_map config( diff --git a/velox/connectors/hive/storage_adapters/abfs/tests/AzuriteServer.h b/velox/connectors/hive/storage_adapters/abfs/tests/AzuriteServer.h index 037c4d9f38d..9921db93845 100644 --- a/velox/connectors/hive/storage_adapters/abfs/tests/AzuriteServer.h +++ b/velox/connectors/hive/storage_adapters/abfs/tests/AzuriteServer.h @@ -15,7 +15,7 @@ */ #include "velox/common/config/Config.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include #include @@ -28,6 +28,7 @@ namespace facebook::velox::filesystems { using namespace Azure::Storage::Blobs; +using TempDirectoryPath = common::testutil::TempDirectoryPath; static std::string_view kAzuriteServerExecutableName{"azurite-blob"}; static std::string_view kAzuriteSearchPath{":/usr/bin/azurite"}; diff --git a/velox/connectors/hive/storage_adapters/abfs/tests/CMakeLists.txt b/velox/connectors/hive/storage_adapters/abfs/tests/CMakeLists.txt index 8deb76801b0..a8a34147d78 100644 --- a/velox/connectors/hive/storage_adapters/abfs/tests/CMakeLists.txt +++ b/velox/connectors/hive/storage_adapters/abfs/tests/CMakeLists.txt @@ -18,9 +18,11 @@ add_executable( AbfsUtilTest.cpp AzureClientProvidersTest.cpp AzureClientProviderFactoriesTest.cpp + DynamicSasTokenClientProviderTest.cpp AzuriteServer.cpp MockDataLakeFileClient.cpp ) +velox_add_test_headers(velox_abfs_test AzuriteServer.h MockDataLakeFileClient.h) add_test(velox_abfs_test velox_abfs_test) target_link_libraries( diff --git a/velox/connectors/hive/storage_adapters/abfs/tests/DynamicSasTokenClientProviderTest.cpp b/velox/connectors/hive/storage_adapters/abfs/tests/DynamicSasTokenClientProviderTest.cpp new file mode 100644 index 00000000000..8793aea080b --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/tests/DynamicSasTokenClientProviderTest.cpp @@ -0,0 +1,143 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/storage_adapters/abfs/DynamicSasTokenClientProvider.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureClientProviderFactories.h" +#include "velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.h" + +#include "gtest/gtest.h" + +#include +#include +#include + +using namespace facebook::velox::filesystems; +using namespace facebook::velox; + +namespace { + +class MyDynamicAbfsSasTokenProvider : public SasTokenProvider { + public: + MyDynamicAbfsSasTokenProvider(int64_t expiration) + : expirationSeconds_(expiration) {} + + std::string getSasToken( + const std::string& fileSystem, + const std::string& path, + const std::string& operation) override { + const auto lastSlash = path.find_last_of("/"); + const auto containerName = path.substr(0, lastSlash); + const auto blobName = path.substr(lastSlash + 1); + + Azure::Storage::Sas::BlobSasBuilder sasBuilder; + sasBuilder.ExpiresOn = Azure::DateTime::clock::now() + + std::chrono::seconds(expirationSeconds_); + sasBuilder.BlobContainerName = containerName; + sasBuilder.BlobName = blobName; + sasBuilder.Resource = Azure::Storage::Sas::BlobSasResource::Blob; + sasBuilder.SetPermissions( + Azure::Storage::Sas::BlobSasPermissions::Read & + Azure::Storage::Sas::BlobSasPermissions::Write); + + std::string sasToken = sasBuilder.GenerateSasToken( + Azure::Storage::StorageSharedKeyCredential( + "test", + "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==")); + + // Remove the leading '?' from the SAS token. + if (sasToken[0] == '?') { + sasToken = sasToken.substr(1); + } + + return sasToken; + } + + private: + int64_t expirationSeconds_; +}; + +} // namespace + +TEST(DynamicSasTokenClientProviderTest, dynamicSasToken) { + { + const std::string account = "account1"; + const config::ConfigBase config( + {{"fs.azure.account.auth.type.account1.dfs.core.windows.net", "SAS"}, + {"fs.azure.sas.token.renew.period.for.streams", "1"}}, + false); + registerAzureClientProviderFactory(account, [](const std::string&) { + auto sasTokenProvider = + std::make_shared(3); + return std::make_unique(sasTokenProvider); + }); + + auto abfsPath = std::make_shared( + fmt::format("abfs://abc@{}.dfs.core.windows.net/file", account)); + auto readClient = + AzureClientProviderFactories::getReadFileClient(abfsPath, config); + auto writeClient = + AzureClientProviderFactories::getWriteFileClient(abfsPath, config); + + auto readUrl = readClient->getUrl(); + auto writeUrl = writeClient->getUrl(); + + // Let the current time pass 3 seconds to ensure the SAS token is expired. + std::this_thread::sleep_for(std::chrono::seconds(3)); // NOLINT + + auto newReadUrl = readClient->getUrl(); + ASSERT_NE(readUrl, newReadUrl); + // The SAS token should be reused. + ASSERT_EQ(newReadUrl, readClient->getUrl()); + + auto newWriteUrl = writeClient->getUrl(); + ASSERT_NE(writeUrl, newWriteUrl); + // The SAS token should be reused. + ASSERT_EQ(newWriteUrl, writeClient->getUrl()); + } + + { + // SAS token expired by setting the renewal period to 120 seconds. + const std::string account = "account2"; + const config::ConfigBase config( + {{"fs.azure.account.auth.type.account2.dfs.core.windows.net", "SAS"}, + {"fs.azure.sas.token.renew.period.for.streams", "120"}}, + false); + registerAzureClientProviderFactory(account, [](const std::string&) { + auto sasTokenProvider = + std::make_shared(60); + return std::make_unique(sasTokenProvider); + }); + + auto abfsPath = std::make_shared( + fmt::format("abfs://abc@{}.dfs.core.windows.net/file", account)); + auto readClient = + AzureClientProviderFactories::getReadFileClient(abfsPath, config); + auto writeClient = + AzureClientProviderFactories::getWriteFileClient(abfsPath, config); + + auto readUrl = readClient->getUrl(); + auto writeUrl = writeClient->getUrl(); + + // Let the current time pass 3 seconds to ensure the timestamp in the SAS + // token is updated. + std::this_thread::sleep_for(std::chrono::seconds(3)); // NOLINT + + // Sas token should be renewed because the time left is less than the + // renewal period. + ASSERT_NE(readUrl, readClient->getUrl()); + ASSERT_NE(writeUrl, writeClient->getUrl()); + } +} diff --git a/velox/connectors/hive/storage_adapters/abfs/tests/MockDataLakeFileClient.h b/velox/connectors/hive/storage_adapters/abfs/tests/MockDataLakeFileClient.h index 560294414c8..d021a7d4259 100644 --- a/velox/connectors/hive/storage_adapters/abfs/tests/MockDataLakeFileClient.h +++ b/velox/connectors/hive/storage_adapters/abfs/tests/MockDataLakeFileClient.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#include "velox/exec/tests/utils/TempFilePath.h" +#include "velox/common/testutil/TempFilePath.h" #include "velox/connectors/hive/storage_adapters/abfs/AzureDataLakeFileClient.h" using namespace Azure::Storage::Files::DataLake::Models; +using namespace facebook::velox::common::testutil; namespace facebook::velox::filesystems { @@ -26,7 +27,7 @@ namespace facebook::velox::filesystems { class MockDataLakeFileClient : public AzureDataLakeFileClient { public: MockDataLakeFileClient() { - auto tempFile = velox::exec::test::TempFilePath::create(); + auto tempFile = TempFilePath::create(); filePath_ = tempFile->getPath(); } diff --git a/velox/connectors/hive/storage_adapters/gcs/CMakeLists.txt b/velox/connectors/hive/storage_adapters/gcs/CMakeLists.txt index 7e110edac19..2991a5e3123 100644 --- a/velox/connectors/hive/storage_adapters/gcs/CMakeLists.txt +++ b/velox/connectors/hive/storage_adapters/gcs/CMakeLists.txt @@ -14,7 +14,17 @@ # for generated headers -velox_add_library(velox_gcs RegisterGcsFileSystem.cpp) +velox_add_library( + velox_gcs + RegisterGcsFileSystem.cpp + HEADERS + GcsFileSystem.h + GcsOAuthCredentialsProvider.h + GcsReadFile.h + GcsUtil.h + GcsWriteFile.h + RegisterGcsFileSystem.h +) if(VELOX_ENABLE_GCS) velox_sources(velox_gcs PRIVATE GcsFileSystem.cpp GcsUtil.cpp GcsWriteFile.cpp GcsReadFile.cpp) diff --git a/velox/connectors/hive/storage_adapters/gcs/GcsFileSystem.cpp b/velox/connectors/hive/storage_adapters/gcs/GcsFileSystem.cpp index 897f7056171..8ec4873e28d 100644 --- a/velox/connectors/hive/storage_adapters/gcs/GcsFileSystem.cpp +++ b/velox/connectors/hive/storage_adapters/gcs/GcsFileSystem.cpp @@ -72,8 +72,9 @@ class GcsFileSystem::Impl { public: Impl(const std::string& bucket, const config::ConfigBase* config) : bucket_(bucket), - hiveConfig_(std::make_shared( - std::make_shared(config->rawConfigsCopy()))) {} + hiveConfig_( + std::make_shared(std::make_shared( + config->rawConfigsCopy()))) {} ~Impl() = default; diff --git a/velox/connectors/hive/storage_adapters/gcs/GcsReadFile.cpp b/velox/connectors/hive/storage_adapters/gcs/GcsReadFile.cpp index 072a4f7a37f..3379e15bd6c 100644 --- a/velox/connectors/hive/storage_adapters/gcs/GcsReadFile.cpp +++ b/velox/connectors/hive/storage_adapters/gcs/GcsReadFile.cpp @@ -59,7 +59,7 @@ class GcsReadFile::Impl { uint64_t length, void* buffer, std::atomic& bytesRead, - filesystems::File::IoStats* stats = nullptr) const { + const FileIoContext& context) const { preadInternal(offset, length, static_cast(buffer), bytesRead); return {static_cast(buffer), length}; } @@ -68,7 +68,7 @@ class GcsReadFile::Impl { uint64_t offset, uint64_t length, std::atomic& bytesRead, - filesystems::File::IoStats* stats = nullptr) const { + const FileIoContext& context) const { std::string result(length, 0); char* position = result.data(); preadInternal(offset, length, position, bytesRead); @@ -79,7 +79,7 @@ class GcsReadFile::Impl { uint64_t offset, const std::vector>& buffers, std::atomic& bytesRead, - filesystems::File::IoStats* stats = nullptr) const { + const FileIoContext& context) const { // 'buffers' contains Ranges(data, size) with some gaps (data = nullptr) in // between. This call must populate the ranges (except gap ranges) // sequentially starting from 'offset'. If a range pointer is nullptr, the @@ -158,21 +158,21 @@ std::string_view GcsReadFile::pread( uint64_t offset, uint64_t length, void* buffer, - filesystems::File::IoStats* stats) const { - return impl_->pread(offset, length, buffer, bytesRead_, stats); + const FileIoContext& context) const { + return impl_->pread(offset, length, buffer, bytesRead_, context); } std::string GcsReadFile::pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats) const { - return impl_->pread(offset, length, bytesRead_, stats); + const FileIoContext& context) const { + return impl_->pread(offset, length, bytesRead_, context); } uint64_t GcsReadFile::preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats) const { - return impl_->preadv(offset, buffers, bytesRead_, stats); + const FileIoContext& context) const { + return impl_->preadv(offset, buffers, bytesRead_, context); } uint64_t GcsReadFile::size() const { diff --git a/velox/connectors/hive/storage_adapters/gcs/GcsReadFile.h b/velox/connectors/hive/storage_adapters/gcs/GcsReadFile.h index a3d328996ec..035bc597ab5 100644 --- a/velox/connectors/hive/storage_adapters/gcs/GcsReadFile.h +++ b/velox/connectors/hive/storage_adapters/gcs/GcsReadFile.h @@ -38,17 +38,17 @@ class GcsReadFile : public ReadFile { uint64_t offset, uint64_t length, void* buffer, - filesystems::File::IoStats* stats = nullptr) const override; + const FileIoContext& context = {}) const override; std::string pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats = nullptr) const override; + const FileIoContext& context = {}) const override; uint64_t preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats = nullptr) const override; + const FileIoContext& context = {}) const override; uint64_t size() const override; diff --git a/velox/connectors/hive/storage_adapters/gcs/RegisterGcsFileSystem.cpp b/velox/connectors/hive/storage_adapters/gcs/RegisterGcsFileSystem.cpp index d1ae189c7f8..13e9c94e1db 100644 --- a/velox/connectors/hive/storage_adapters/gcs/RegisterGcsFileSystem.cpp +++ b/velox/connectors/hive/storage_adapters/gcs/RegisterGcsFileSystem.cpp @@ -103,7 +103,8 @@ std::unique_ptr gcsWriteFileSinkGenerator( fileSystem->openFileForWrite(fileURI, {{}, options.pool, std::nullopt}), fileURI, options.metricLogger, - options.stats); + options.stats, + options.fileSystemStats); } return nullptr; } diff --git a/velox/connectors/hive/storage_adapters/gcs/tests/CMakeLists.txt b/velox/connectors/hive/storage_adapters/gcs/tests/CMakeLists.txt index f462c45f9d3..b75f7492c5a 100644 --- a/velox/connectors/hive/storage_adapters/gcs/tests/CMakeLists.txt +++ b/velox/connectors/hive/storage_adapters/gcs/tests/CMakeLists.txt @@ -13,6 +13,7 @@ # limitations under the License. add_executable(velox_gcs_file_test GcsUtilTest.cpp GcsFileSystemTest.cpp) +velox_add_test_headers(velox_gcs_file_test GcsEmulator.h) add_test(velox_gcs_file_test velox_gcs_file_test) target_link_libraries( velox_gcs_file_test @@ -23,7 +24,7 @@ target_link_libraries( velox_file velox_gcs velox_hive_connector - velox_temp_path + velox_test_util GTest::gmock GTest::gtest GTest::gtest_main @@ -37,6 +38,8 @@ target_link_libraries( velox_gcs velox_hive_config velox_core + velox_dwio_parquet_reader + velox_dwio_parquet_writer velox_exec_test_lib velox_dwio_common_exception velox_exec diff --git a/velox/connectors/hive/storage_adapters/gcs/tests/GcsFileSystemTest.cpp b/velox/connectors/hive/storage_adapters/gcs/tests/GcsFileSystemTest.cpp index 34c62d4a472..4aa0df8d123 100644 --- a/velox/connectors/hive/storage_adapters/gcs/tests/GcsFileSystemTest.cpp +++ b/velox/connectors/hive/storage_adapters/gcs/tests/GcsFileSystemTest.cpp @@ -17,13 +17,15 @@ #include "velox/connectors/hive/storage_adapters/gcs/GcsFileSystem.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/File.h" +#include "velox/common/testutil/TempFilePath.h" #include "velox/connectors/hive/storage_adapters/gcs/GcsUtil.h" #include "velox/connectors/hive/storage_adapters/gcs/RegisterGcsFileSystem.h" #include "velox/connectors/hive/storage_adapters/gcs/tests/GcsEmulator.h" -#include "velox/exec/tests/utils/TempFilePath.h" #include "gtest/gtest.h" +using namespace facebook::velox::common::testutil; + namespace facebook::velox::filesystems { namespace { @@ -60,8 +62,8 @@ TEST_F(GcsFileSystemTest, readFile) { EXPECT_EQ(size, ref_size); EXPECT_EQ(readFile->pread(0, size), kLoremIpsum); - char buffer1[size]; - ASSERT_EQ(readFile->pread(0, size, &buffer1), kLoremIpsum); + std::vector buffer1(size); + ASSERT_EQ(readFile->pread(0, size, buffer1.data()), kLoremIpsum); ASSERT_EQ(readFile->size(), ref_size); char buffer2[50]; @@ -252,7 +254,7 @@ TEST_F(GcsFileSystemTest, credentialsConfig) { "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/foo-email%40foo-project.iam.gserviceaccount.com" })"""; - auto jsonFile = exec::test::TempFilePath::create(); + auto jsonFile = TempFilePath::create(); std::ofstream credsOut(jsonFile->getPath()); credsOut << kCreds; credsOut.close(); diff --git a/velox/connectors/hive/storage_adapters/gcs/tests/GcsMultipleEndpointsTest.cpp b/velox/connectors/hive/storage_adapters/gcs/tests/GcsMultipleEndpointsTest.cpp index 00bb3310a06..e8ae75c715e 100644 --- a/velox/connectors/hive/storage_adapters/gcs/tests/GcsMultipleEndpointsTest.cpp +++ b/velox/connectors/hive/storage_adapters/gcs/tests/GcsMultipleEndpointsTest.cpp @@ -18,6 +18,7 @@ #include #include "gtest/gtest.h" +#include "velox/connectors/ConnectorRegistry.h" #include "velox/connectors/hive/HiveConnector.h" #include "velox/connectors/hive/HiveConnectorSplit.h" #include "velox/connectors/hive/storage_adapters/gcs/RegisterGcsFileSystem.h" @@ -70,8 +71,10 @@ class GcsMultipleEndpointsTest : public testing::Test, std::string(connectorId2), gcsEmulatorTwo_->hiveConfig(config2Override), ioExecutor_.get()); - connector::registerConnector(hiveConnector1); - connector::registerConnector(hiveConnector2); + connector::ConnectorRegistry::global().insert( + hiveConnector1->connectorId(), hiveConnector1); + connector::ConnectorRegistry::global().insert( + hiveConnector2->connectorId(), hiveConnector2); } void TearDown() override { @@ -100,7 +103,8 @@ class GcsMultipleEndpointsTest : public testing::Test, // Second column contains details about written files. auto details = results->childAt(exec::TableWriteTraits::kFragmentChannel) ->as>(); - folly::dynamic obj = folly::parseJson(details->valueAt(1)); + folly::dynamic obj = + folly::parseJson(std::string_view(details->valueAt(1))); return obj["fileWriteInfos"]; } @@ -191,8 +195,8 @@ TEST_F(GcsMultipleEndpointsTest, baseEndpoints) { testJoin(kExpectedRows, gcsBucket, kConnectorId1, kConnectorId2); - connector::unregisterConnector(std::string(kConnectorId1)); - connector::unregisterConnector(std::string(kConnectorId2)); + connector::ConnectorRegistry::global().erase(std::string(kConnectorId1)); + connector::ConnectorRegistry::global().erase(std::string(kConnectorId2)); } } // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/hdfs/CMakeLists.txt b/velox/connectors/hive/storage_adapters/hdfs/CMakeLists.txt index 8c49f3e11ed..087397c1c38 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/CMakeLists.txt +++ b/velox/connectors/hive/storage_adapters/hdfs/CMakeLists.txt @@ -14,7 +14,16 @@ # for generated headers -velox_add_library(velox_hdfs RegisterHdfsFileSystem.cpp) +velox_add_library( + velox_hdfs + RegisterHdfsFileSystem.cpp + HEADERS + HdfsFileSystem.h + HdfsReadFile.h + HdfsUtil.h + HdfsWriteFile.h + RegisterHdfsFileSystem.h +) if(VELOX_ENABLE_HDFS) velox_sources(velox_hdfs PRIVATE HdfsFileSystem.cpp HdfsReadFile.cpp HdfsWriteFile.cpp) diff --git a/velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.cpp b/velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.cpp index aedd1fe44d3..46522ab2a6f 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.cpp +++ b/velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.cpp @@ -56,11 +56,29 @@ class HdfsFileSystem::Impl { } ~Impl() { - LOG(INFO) << "Disconnecting HDFS file system"; - int disconnectResult = driver_->Disconnect(hdfsClient_); - if (disconnectResult != 0) { - LOG(WARNING) << "hdfs disconnect failure in HdfsReadFile close: " - << errno; + if (!closed_) { + LOG(WARNING) + << "The HdfsFileSystem instance is not closed upon destruction. You must explicitly call the close() API before JVM termination to ensure proper disconnection."; + } + } + + // The HdfsFileSystem::Disconnect operation requires the JVM method + // definitions to be loaded within an active JVM process. + // Therefore, it must be invoked before the JVM shuts down. + + // To address this, we’ve introduced a new close() API that performs the + // disconnect operation. Third-party applications can call this close() method + // prior to JVM termination to ensure proper cleanup. + void close() { + if (!closed_) { + LOG(WARNING) << "Disconnecting HDFS file system"; + int disconnectResult = driver_->Disconnect(hdfsClient_); + if (disconnectResult != 0) { + LOG(WARNING) << "hdfs disconnect failure in HdfsReadFile close: " + << errno; + } + + closed_ = true; } } @@ -75,6 +93,7 @@ class HdfsFileSystem::Impl { private: hdfsFS hdfsClient_; filesystems::arrow::io::internal::LibHdfsShim* driver_; + bool closed_ = false; }; HdfsFileSystem::HdfsFileSystem( @@ -109,6 +128,10 @@ std::unique_ptr HdfsFileSystem::openFileForWrite( impl_->hdfsShim(), impl_->hdfsClient(), path); } +void HdfsFileSystem::close() { + impl_->close(); +} + bool HdfsFileSystem::isHdfsFile(const std::string_view filePath) { return (filePath.find(kScheme) == 0) || (filePath.find(kViewfsScheme) == 0); } diff --git a/velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.h b/velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.h index cebe40aa890..9720bb13034 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.h +++ b/velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.h @@ -61,6 +61,8 @@ class HdfsFileSystem : public FileSystem { std::string_view path, const FileOptions& options = {}) override; + void close(); + // Deletes the hdfs files. void remove(std::string_view path) override; diff --git a/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.cpp b/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.cpp index affc1dfd2ed..94adb31a742 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.cpp +++ b/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.cpp @@ -103,12 +103,17 @@ class HdfsReadFile::Impl { } } - std::string_view pread(uint64_t offset, uint64_t length, void* buf) const { + std::string_view pread( + uint64_t offset, + uint64_t length, + void* buf, + const FileIoContext& context) const { preadInternal(offset, length, static_cast(buf)); return {static_cast(buf), length}; } - std::string pread(uint64_t offset, uint64_t length) const { + std::string + pread(uint64_t offset, uint64_t length, const FileIoContext& context) const { std::string result(length, 0); char* pos = result.data(); preadInternal(offset, length, pos); @@ -163,15 +168,15 @@ std::string_view HdfsReadFile::pread( uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats) const { - return pImpl->pread(offset, length, buf); + const FileIoContext& context) const { + return pImpl->pread(offset, length, buf, context); } std::string HdfsReadFile::pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats) const { - return pImpl->pread(offset, length); + const FileIoContext& context) const { + return pImpl->pread(offset, length, context); } uint64_t HdfsReadFile::size() const { diff --git a/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.h b/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.h index ddd35e511a7..6208702eab5 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.h +++ b/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.h @@ -38,12 +38,12 @@ class HdfsReadFile final : public ReadFile { uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats = nullptr) const final; + const FileIoContext& context = {}) const final; std::string pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats = nullptr) const final; + const FileIoContext& context = {}) const final; uint64_t size() const final; diff --git a/velox/connectors/hive/storage_adapters/hdfs/HdfsWriteFile.cpp b/velox/connectors/hive/storage_adapters/hdfs/HdfsWriteFile.cpp index be668a3133e..26d43ccb910 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/HdfsWriteFile.cpp +++ b/velox/connectors/hive/storage_adapters/hdfs/HdfsWriteFile.cpp @@ -54,12 +54,14 @@ HdfsWriteFile::~HdfsWriteFile() { void HdfsWriteFile::close() { int success = driver_->CloseFile(hdfsClient_, hdfsFile_); + common::testutil::TestValue::adjust( + "facebook::velox::connectors::hive::HdfsWriteFile::close", &success); + hdfsFile_ = nullptr; VELOX_CHECK_EQ( success, 0, "Failed to close hdfs file: {}", driver_->GetLastExceptionRootCause()); - hdfsFile_ = nullptr; } void HdfsWriteFile::flush() { diff --git a/velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.cpp b/velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.cpp index 1f23179f0a7..d8b31e806d1 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.cpp +++ b/velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.cpp @@ -28,44 +28,48 @@ namespace facebook::velox::filesystems { #ifdef VELOX_ENABLE_HDFS std::mutex mtx; +folly::ConcurrentHashMap> + registeredFilesystems; + std::function(std::shared_ptr, std::string_view)> hdfsFileSystemGenerator() { - static auto filesystemGenerator = [](std::shared_ptr - properties, - std::string_view filePath) { - static folly::ConcurrentHashMap> - filesystems; - static folly:: - ConcurrentHashMap> - hdfsInitiationFlags; - HdfsServiceEndpoint endpoint = - HdfsFileSystem::getServiceEndpoint(filePath, properties.get()); - std::string hdfsIdentity = endpoint.identity(); - if (filesystems.find(hdfsIdentity) != filesystems.end()) { - return filesystems[hdfsIdentity]; - } - std::unique_lock lk(mtx, std::defer_lock); - /// If the init flag for a given hdfs identity is not found, - /// create one for init use. It's a singleton. - if (hdfsInitiationFlags.find(hdfsIdentity) == hdfsInitiationFlags.end()) { - lk.lock(); - if (hdfsInitiationFlags.find(hdfsIdentity) == hdfsInitiationFlags.end()) { - std::shared_ptr initiationFlagPtr = - std::make_shared(); - hdfsInitiationFlags.insert(hdfsIdentity, initiationFlagPtr); - } - lk.unlock(); - } - folly::call_once( - *hdfsInitiationFlags[hdfsIdentity].get(), - [&properties, endpoint, hdfsIdentity]() { - auto filesystem = - std::make_shared(properties, endpoint); - filesystems.insert(hdfsIdentity, filesystem); - }); - return filesystems[hdfsIdentity]; - }; + static auto filesystemGenerator = + [](std::shared_ptr properties, + std::string_view filePath) { + static folly:: + ConcurrentHashMap> + hdfsInitiationFlags; + HdfsServiceEndpoint endpoint = + HdfsFileSystem::getServiceEndpoint(filePath, properties.get()); + std::string hdfsIdentity = endpoint.identity(); + if (registeredFilesystems.find(hdfsIdentity) != + registeredFilesystems.end()) { + return registeredFilesystems[hdfsIdentity]; + } + std::unique_lock lk(mtx, std::defer_lock); + /// If the init flag for a given hdfs identity is not found, + /// create one for init use. It's a singleton. + if (hdfsInitiationFlags.find(hdfsIdentity) == + hdfsInitiationFlags.end()) { + lk.lock(); + if (hdfsInitiationFlags.find(hdfsIdentity) == + hdfsInitiationFlags.end()) { + std::shared_ptr initiationFlagPtr = + std::make_shared(); + hdfsInitiationFlags.insert(hdfsIdentity, initiationFlagPtr); + } + lk.unlock(); + } + folly::call_once( + *hdfsInitiationFlags[hdfsIdentity].get(), + [&properties, endpoint, hdfsIdentity]() { + auto filesystem = + std::make_shared(properties, endpoint); + registeredFilesystems.insert(hdfsIdentity, filesystem); + }); + return registeredFilesystems[hdfsIdentity]; + }; return filesystemGenerator; } @@ -85,7 +89,8 @@ hdfsWriteFileSinkGenerator() { fileSystem->openFileForWrite(pathSuffix), fileURI, options.metricLogger, - options.stats); + options.stats, + options.fileSystemStats); } return static_cast>( nullptr); diff --git a/velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.h b/velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.h index 6f6f0c032bd..18eef4aca17 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.h +++ b/velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.h @@ -16,8 +16,15 @@ #pragma once +#include "folly/concurrency/ConcurrentHashMap.h" + namespace facebook::velox::filesystems { +class HdfsFileSystem; + +extern folly::ConcurrentHashMap> + registeredFilesystems; + // Register the HDFS. void registerHdfsFileSystem(); diff --git a/velox/connectors/hive/storage_adapters/hdfs/tests/CMakeLists.txt b/velox/connectors/hive/storage_adapters/hdfs/tests/CMakeLists.txt index a0ad9d67e99..7a6b0832765 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/tests/CMakeLists.txt +++ b/velox/connectors/hive/storage_adapters/hdfs/tests/CMakeLists.txt @@ -13,6 +13,7 @@ # limitations under the License. add_executable(velox_hdfs_file_test HdfsFileSystemTest.cpp HdfsMiniCluster.cpp HdfsUtilTest.cpp) +velox_add_test_headers(velox_hdfs_file_test HdfsMiniCluster.h) add_test(velox_hdfs_file_test velox_hdfs_file_test) target_link_libraries( @@ -32,11 +33,15 @@ target_link_libraries( target_compile_options(velox_hdfs_file_test PRIVATE -Wno-deprecated-declarations) add_executable(velox_hdfs_insert_test HdfsInsertTest.cpp HdfsMiniCluster.cpp HdfsUtilTest.cpp) +velox_add_test_headers(velox_hdfs_insert_test HdfsMiniCluster.h) add_test(velox_hdfs_insert_test velox_hdfs_insert_test) target_link_libraries( velox_hdfs_insert_test + velox_dwio_parquet_reader + velox_dwio_parquet_writer + velox_hdfs velox_exec_test_lib velox_exec GTest::gtest @@ -49,7 +54,7 @@ target_compile_options(velox_hdfs_insert_test PRIVATE -Wno-deprecated-declaratio # velox_hdfs_insert_test and velox_hdfs_file_test two tests can't run in # parallel due to the port conflict in Hadoop NameNode and DataNode. The # namenode port conflict can be resolved using the -nnport configuration in -# hadoop-mapreduce-client-jobclient-3.3.0-tests.jar. However the data node port +# hadoop-mapreduce-client-jobclient-3.3.6-tests.jar. However the data node port # cannot be configured. Therefore, we need to make sure that # velox_hdfs_file_test runs only after velox_hdfs_insert_test has finished. set_tests_properties(velox_hdfs_insert_test PROPERTIES DEPENDS velox_hdfs_file_test) diff --git a/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsFileSystemTest.cpp b/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsFileSystemTest.cpp index e5c0883284b..e20e8761a07 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsFileSystemTest.cpp +++ b/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsFileSystemTest.cpp @@ -21,16 +21,18 @@ #include "gtest/gtest.h" #include "velox/common/base/Exceptions.h" #include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/testutil/TempFilePath.h" +#include "velox/common/testutil/TestValue.h" #include "velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.h" #include "velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.h" #include "velox/connectors/hive/storage_adapters/hdfs/tests/HdfsMiniCluster.h" #include "velox/core/QueryConfig.h" -#include "velox/exec/tests/utils/TempFilePath.h" #include "velox/external/hdfs/ArrowHdfsInternal.h" #include using namespace facebook::velox; +using namespace facebook::velox::common::testutil; using filesystems::arrow::io::internal::LibHdfsShim; @@ -72,6 +74,11 @@ class HdfsFileSystemTest : public testing::Test { } static void TearDownTestSuite() { + for (const auto& [_, filesystem] : + facebook::velox::filesystems::registeredFilesystems) { + filesystem->close(); + } + miniCluster->stop(); } @@ -88,8 +95,8 @@ class HdfsFileSystemTest : public testing::Test { static std::string fullDestinationPath_; private: - static std::shared_ptr<::exec::test::TempFilePath> createFile() { - auto tempFile = exec::test::TempFilePath::create(); + static std::shared_ptr createFile() { + auto tempFile = TempFilePath::create(); tempFile->append("aaaaa"); tempFile->append("bbbbb"); tempFile->append(std::string(kOneMB, 'c')); @@ -520,3 +527,34 @@ TEST_F(HdfsFileSystemTest, readFailures) { std::string(miniCluster->nameNodePort())); verifyFailures(driver, hdfs); } + +DEBUG_ONLY_TEST_F(HdfsFileSystemTest, writeFilePreventsDoubleClose) { + common::testutil::TestValue::enable(); + + int closeCallCount = 0; + + SCOPED_TESTVALUE_SET( + "facebook::velox::connectors::hive::HdfsWriteFile::close", + std::function([&closeCallCount](int* success) { + ++closeCallCount; + if (closeCallCount == 1) { + *success = -1; + } + })); + + auto writeFile = openFileForWrite("/test_double_close.txt"); + + writeFile->append("test data"); + writeFile->flush(); + + VELOX_ASSERT_THROW(writeFile->close(), "Failed to close hdfs file:"); + + EXPECT_EQ(closeCallCount, 1); + + // Destructor should not call close() again because hdfsFile_ is nullptr + // The closeCallCount should remain 1. + writeFile.reset(); + EXPECT_EQ(closeCallCount, 1); + + common::testutil::TestValue::disable(); +} diff --git a/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsInsertTest.cpp b/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsInsertTest.cpp index 9ec9a125415..ed2287a7c42 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsInsertTest.cpp +++ b/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsInsertTest.cpp @@ -17,6 +17,7 @@ #include "gtest/gtest.h" #include "velox/common/memory/Memory.h" +#include "velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.h" #include "velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.h" #include "velox/connectors/hive/storage_adapters/hdfs/tests/HdfsMiniCluster.h" #include "velox/connectors/hive/storage_adapters/test_common/InsertTest.h" @@ -47,6 +48,10 @@ class HdfsInsertTest : public testing::Test, public InsertTest { } void TearDown() override { + for (const auto& [_, filesystem] : + facebook::velox::filesystems::registeredFilesystems) { + filesystem->close(); + } InsertTest::TearDown(); miniCluster->stop(); } diff --git a/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsMiniCluster.h b/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsMiniCluster.h index c54ae9589b3..62b45609d80 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsMiniCluster.h +++ b/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsMiniCluster.h @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "velox/exec/tests/utils/TempDirectoryPath.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include #include @@ -22,11 +22,14 @@ #include "boost/process.hpp" namespace facebook::velox::filesystems::test { + +using TempDirectoryPath = common::testutil::TempDirectoryPath; + static const std::string kMiniClusterExecutableName{"hadoop"}; static const std::string kHadoopSearchPath{":/usr/local/hadoop/bin"}; static const std::string kJarCommand{"jar"}; static const std::string kMiniclusterJar{ - "/share/hadoop/mapreduce/hadoop-mapreduce-client-jobclient-3.3.0-tests.jar"}; + "/share/hadoop/mapreduce/hadoop-mapreduce-client-jobclient-3.3.6-tests.jar"}; static const std::string kMiniclusterCommand{"minicluster"}; static const std::string kNoMapReduceOption{"-nomr"}; static const std::string kFormatNameNodeOption{"-format"}; diff --git a/velox/connectors/hive/storage_adapters/s3fs/CMakeLists.txt b/velox/connectors/hive/storage_adapters/s3fs/CMakeLists.txt index 741f01a61b3..91be9ee864c 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/CMakeLists.txt +++ b/velox/connectors/hive/storage_adapters/s3fs/CMakeLists.txt @@ -14,7 +14,18 @@ # for generated headers -velox_add_library(velox_s3fs RegisterS3FileSystem.cpp) +velox_add_library( + velox_s3fs + RegisterS3FileSystem.cpp + HEADERS + RegisterS3FileSystem.h + S3Config.h + S3Counters.h + S3FileSystem.h + S3ReadFile.h + S3Util.h + S3WriteFile.h +) if(VELOX_ENABLE_S3) velox_sources( velox_s3fs diff --git a/velox/connectors/hive/storage_adapters/s3fs/RegisterS3FileSystem.cpp b/velox/connectors/hive/storage_adapters/s3fs/RegisterS3FileSystem.cpp index 62d014090a4..630bb411327 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/RegisterS3FileSystem.cpp +++ b/velox/connectors/hive/storage_adapters/s3fs/RegisterS3FileSystem.cpp @@ -102,7 +102,8 @@ std::unique_ptr s3WriteFileSinkGenerator( fileSystem->openFileForWrite(fileURI, {{}, options.pool, std::nullopt}), fileURI, options.metricLogger, - options.stats); + options.stats, + options.fileSystemStats); } return nullptr; } diff --git a/velox/connectors/hive/storage_adapters/s3fs/S3Config.cpp b/velox/connectors/hive/storage_adapters/s3fs/S3Config.cpp index a82a3be4562..b006a4e42e9 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/S3Config.cpp +++ b/velox/connectors/hive/storage_adapters/s3fs/S3Config.cpp @@ -21,6 +21,9 @@ namespace facebook::velox::filesystems { +static constexpr size_t kMinimumMultipartMinPartSize = 5U << 20; // 5MB +static constexpr size_t kMaximumMultipartMinPartSize = 5U << 30; // 5GB + std::string S3Config::cacheKey( std::string_view bucket, std::shared_ptr config) { @@ -72,6 +75,15 @@ S3Config::S3Config( } payloadSigningPolicy_ = properties->get(kS3PayloadSigningPolicy, "Never"); + + VELOX_CHECK_GE( + minPartSize(), + kMinimumMultipartMinPartSize, + "The min-part-size S3 configuration must exceed 5MB."); + VELOX_CHECK_LE( + minPartSize(), + kMaximumMultipartMinPartSize, + "The min-part-size S3 configuration must not exceed 5GB."); } std::optional S3Config::endpointRegion() const { @@ -87,4 +99,10 @@ std::optional S3Config::endpointRegion() const { return region; } +size_t S3Config::minPartSize() const { + return config::toCapacity( + config_.find(Keys::kMultipartMinPartSize)->second.value(), + config::CapacityUnit::BYTE); +} + } // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/s3fs/S3Config.h b/velox/connectors/hive/storage_adapters/s3fs/S3Config.h index 4fad4379925..cd094a59832 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/S3Config.h +++ b/velox/connectors/hive/storage_adapters/s3fs/S3Config.h @@ -76,6 +76,8 @@ class S3Config { kRetryMode, kUseProxyFromEnv, kCredentialsProvider, + kIMDSEnabled, + kMultipartMinPartSize, kEnd }; @@ -114,6 +116,9 @@ class S3Config { std::make_pair("use-proxy-from-env", "false")}, {Keys::kCredentialsProvider, std::make_pair("aws-credentials-provider", std::nullopt)}, + {Keys::kIMDSEnabled, std::make_pair("aws-imds-enabled", "true")}, + {Keys::kMultipartMinPartSize, + std::make_pair("min-part-size", "10MB")}, }; return config; } @@ -243,6 +248,14 @@ class S3Config { return config_.find(Keys::kCredentialsProvider)->second; } + /// Returns true if IMDS is enabled in the configuration settings + bool useIMDS() const { + auto value = config_.find(Keys::kIMDSEnabled)->second.value(); + return folly::to(value); + } + + size_t minPartSize() const; + private: std::unordered_map> config_; std::string payloadSigningPolicy_; diff --git a/velox/connectors/hive/storage_adapters/s3fs/S3FileSystem.cpp b/velox/connectors/hive/storage_adapters/s3fs/S3FileSystem.cpp index cdc456ee94b..d96fd7707dc 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/S3FileSystem.cpp +++ b/velox/connectors/hive/storage_adapters/s3fs/S3FileSystem.cpp @@ -229,23 +229,30 @@ void registerCredentialsProvider( class S3FileSystem::Impl { public: - Impl(const S3Config& s3Config) { + explicit Impl(std::shared_ptr s3Config) + : s3Config_(std::move(s3Config)) { VELOX_CHECK(getAwsInstance()->isInitialized(), "S3 is not initialized"); - Aws::S3::S3ClientConfiguration clientConfig; - if (s3Config.endpoint().has_value()) { - clientConfig.endpointOverride = s3Config.endpoint().value(); + Aws::Client::ClientConfigurationInitValues initValues; + initValues.shouldDisableIMDS = !s3Config_->useIMDS(); + Aws::S3::S3ClientConfiguration clientConfig(initValues); + clientConfig.checksumConfig.requestChecksumCalculation = + Aws::Client::RequestChecksumCalculation::WHEN_REQUIRED; + clientConfig.checksumConfig.responseChecksumValidation = + Aws::Client::ResponseChecksumValidation::WHEN_REQUIRED; + if (s3Config_->endpoint().has_value()) { + clientConfig.endpointOverride = s3Config_->endpoint().value(); } - if (s3Config.endpointRegion().has_value()) { - clientConfig.region = s3Config.endpointRegion().value(); + if (s3Config_->endpointRegion().has_value()) { + clientConfig.region = s3Config_->endpointRegion().value(); } - if (s3Config.useProxyFromEnv()) { + if (s3Config_->useProxyFromEnv()) { auto proxyConfig = S3ProxyConfigurationBuilder( - s3Config.endpoint().has_value() ? s3Config.endpoint().value() - : "") - .useSsl(s3Config.useSSL()) + s3Config_->endpoint().has_value() ? s3Config_->endpoint().value() + : "") + .useSsl(s3Config_->useSSL()) .build(); if (proxyConfig.has_value()) { clientConfig.proxyScheme = Aws::Http::SchemeMapper::FromString( @@ -257,42 +264,42 @@ class S3FileSystem::Impl { } } - if (s3Config.useSSL()) { + if (s3Config_->useSSL()) { clientConfig.scheme = Aws::Http::Scheme::HTTPS; } else { clientConfig.scheme = Aws::Http::Scheme::HTTP; } - if (s3Config.connectTimeout().has_value()) { + if (s3Config_->connectTimeout().has_value()) { clientConfig.connectTimeoutMs = std::chrono::duration_cast( facebook::velox::config::toDuration( - s3Config.connectTimeout().value())) + s3Config_->connectTimeout().value())) .count(); } - if (s3Config.socketTimeout().has_value()) { + if (s3Config_->socketTimeout().has_value()) { clientConfig.requestTimeoutMs = std::chrono::duration_cast( facebook::velox::config::toDuration( - s3Config.socketTimeout().value())) + s3Config_->socketTimeout().value())) .count(); } - if (s3Config.maxConnections().has_value()) { - clientConfig.maxConnections = s3Config.maxConnections().value(); + if (s3Config_->maxConnections().has_value()) { + clientConfig.maxConnections = s3Config_->maxConnections().value(); } - auto retryStrategy = getRetryStrategy(s3Config); + auto retryStrategy = getRetryStrategy(s3Config_); if (retryStrategy.has_value()) { clientConfig.retryStrategy = retryStrategy.value(); } - clientConfig.useVirtualAddressing = s3Config.useVirtualAddressing(); + clientConfig.useVirtualAddressing = s3Config_->useVirtualAddressing(); clientConfig.payloadSigningPolicy = - inferPayloadSign(s3Config.payloadSigningPolicy()); + inferPayloadSign(s3Config_->payloadSigningPolicy()); - auto credentialsProvider = getCredentialsProvider(s3Config); + auto credentialsProvider = getCredentialsProvider(*s3Config_); client_ = std::make_shared( credentialsProvider, nullptr /* endpointProvider */, clientConfig); @@ -376,9 +383,9 @@ class S3FileSystem::Impl { // Return a client RetryStrategy based on the config. std::optional> getRetryStrategy( - const S3Config& s3Config) const { - auto retryMode = s3Config.retryMode(); - auto maxAttempts = s3Config.maxAttempts(); + const std::shared_ptr& s3Config) const { + auto retryMode = s3Config->retryMode(); + auto maxAttempts = s3Config->maxAttempts(); if (retryMode.has_value()) { if (retryMode.value() == "standard") { if (maxAttempts.has_value()) { @@ -441,16 +448,21 @@ class S3FileSystem::Impl { return getAwsInstance()->getLogPrefix(); } + std::shared_ptr getS3Config() { + return s3Config_; + } + private: std::shared_ptr client_; + std::shared_ptr s3Config_; }; S3FileSystem::S3FileSystem( std::string_view bucketName, const std::shared_ptr config) : FileSystem(config) { - S3Config s3Config(bucketName, config); - impl_ = std::make_shared(s3Config); + auto s3Config = std::make_shared(bucketName, config); + impl_ = std::make_shared(std::move(s3Config)); } std::string S3FileSystem::getLogLevelName() const { @@ -474,8 +486,8 @@ std::unique_ptr S3FileSystem::openFileForWrite( std::string_view s3Path, const FileOptions& options) { const auto path = getPath(s3Path); - auto s3file = - std::make_unique(path, impl_->s3Client(), options.pool); + auto s3file = std::make_unique( + path, impl_->s3Client(), options.pool, impl_->getS3Config()); return s3file; } diff --git a/velox/connectors/hive/storage_adapters/s3fs/S3ReadFile.cpp b/velox/connectors/hive/storage_adapters/s3fs/S3ReadFile.cpp index 06d180b19f7..f19587114eb 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/S3ReadFile.cpp +++ b/velox/connectors/hive/storage_adapters/s3fs/S3ReadFile.cpp @@ -79,13 +79,13 @@ class S3ReadFile ::Impl { uint64_t offset, uint64_t length, void* buffer, - File::IoStats* stats) const { + const FileIoContext& context) const { preadInternal(offset, length, static_cast(buffer)); return {static_cast(buffer), length}; } - std::string pread(uint64_t offset, uint64_t length, File::IoStats* stats) - const { + std::string + pread(uint64_t offset, uint64_t length, const FileIoContext& context) const { std::string result(length, 0); char* position = result.data(); preadInternal(offset, length, position); @@ -95,7 +95,7 @@ class S3ReadFile ::Impl { uint64_t preadv( uint64_t offset, const std::vector>& buffers, - File::IoStats* stats) const { + const FileIoContext& context) const { // 'buffers' contains Ranges(data, size) with some gaps (data = nullptr) in // between. This call must populate the ranges (except gap ranges) // sequentially starting from 'offset'. AWS S3 GetObject does not support @@ -183,22 +183,22 @@ std::string_view S3ReadFile::pread( uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats) const { - return impl_->pread(offset, length, buf, stats); + const FileIoContext& context) const { + return impl_->pread(offset, length, buf, context); } std::string S3ReadFile::pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats) const { - return impl_->pread(offset, length, stats); + const FileIoContext& context) const { + return impl_->pread(offset, length, context); } uint64_t S3ReadFile::preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats) const { - return impl_->preadv(offset, buffers, stats); + const FileIoContext& context) const { + return impl_->preadv(offset, buffers, context); } uint64_t S3ReadFile::size() const { diff --git a/velox/connectors/hive/storage_adapters/s3fs/S3ReadFile.h b/velox/connectors/hive/storage_adapters/s3fs/S3ReadFile.h index 0b08ed0ec18..deca155abb0 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/S3ReadFile.h +++ b/velox/connectors/hive/storage_adapters/s3fs/S3ReadFile.h @@ -35,17 +35,17 @@ class S3ReadFile : public ReadFile { uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats = nullptr) const final; + const FileIoContext& context = {}) const final; std::string pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats = nullptr) const final; + const FileIoContext& context = {}) const final; uint64_t preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats = nullptr) const final; + const FileIoContext& context = {}) const final; uint64_t size() const final; diff --git a/velox/connectors/hive/storage_adapters/s3fs/S3Util.cpp b/velox/connectors/hive/storage_adapters/s3fs/S3Util.cpp index 9470f858884..8f982131d4c 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/S3Util.cpp +++ b/velox/connectors/hive/storage_adapters/s3fs/S3Util.cpp @@ -157,7 +157,7 @@ std::optional parseAWSStandardRegionName( const std::string_view kAmazonHostSuffix = ".amazonaws.com"; auto index = endpoint.size() - kAmazonHostSuffix.size(); // Handle the case where the endpoint ends in a trailing slash. - if (endpoint.back() == '/') { + if (!endpoint.empty() && endpoint.back() == '/') { index--; } if (endpoint.rfind(kAmazonHostSuffix) != index) { diff --git a/velox/connectors/hive/storage_adapters/s3fs/S3Util.h b/velox/connectors/hive/storage_adapters/s3fs/S3Util.h index ab2e25790d0..125d9cc805b 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/S3Util.h +++ b/velox/connectors/hive/storage_adapters/s3fs/S3Util.h @@ -204,7 +204,7 @@ std::optional parseAWSStandardRegionName( class S3ProxyConfigurationBuilder { public: S3ProxyConfigurationBuilder(const std::string& s3Endpoint) - : s3Endpoint_(s3Endpoint){}; + : s3Endpoint_(s3Endpoint) {} S3ProxyConfigurationBuilder& useSsl(const bool& useSsl) { useSsl_ = useSsl; @@ -237,7 +237,7 @@ class StringViewStream : Aws::Utils::Stream::PreallocatedStreamBuf, template <> struct fmt::formatter : formatter { - auto format(Aws::Http::HttpResponseCode s, format_context& ctx) { + auto format(Aws::Http::HttpResponseCode s, format_context& ctx) const { return formatter::format(static_cast(s), ctx); } }; diff --git a/velox/connectors/hive/storage_adapters/s3fs/S3WriteFile.cpp b/velox/connectors/hive/storage_adapters/s3fs/S3WriteFile.cpp index fcccfe240ab..1acf2ee08fa 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/S3WriteFile.cpp +++ b/velox/connectors/hive/storage_adapters/s3fs/S3WriteFile.cpp @@ -16,6 +16,7 @@ #include "velox/connectors/hive/storage_adapters/s3fs/S3WriteFile.h" #include "velox/common/base/StatsReporter.h" +#include "velox/connectors/hive/storage_adapters/s3fs/S3Config.h" #include "velox/connectors/hive/storage_adapters/s3fs/S3Counters.h" #include "velox/connectors/hive/storage_adapters/s3fs/S3Util.h" #include "velox/dwio/common/DataBuffer.h" @@ -29,6 +30,7 @@ #include #include #include +#include #include namespace facebook::velox::filesystems { @@ -38,13 +40,14 @@ class S3WriteFile::Impl { explicit Impl( std::string_view path, Aws::S3::S3Client* client, - memory::MemoryPool* pool) - : client_(client), pool_(pool) { + memory::MemoryPool* pool, + const std::shared_ptr& s3Config) + : client_(client), pool_(pool), minPartSize_(s3Config->minPartSize()) { VELOX_CHECK_NOT_NULL(client); VELOX_CHECK_NOT_NULL(pool); getBucketAndKeyFromPath(path, bucket_, key_); currentPart_ = std::make_unique>(*pool_); - currentPart_->reserve(kPartUploadSize); + currentPart_->reserve(minPartSize_); // Check that the object doesn't exist, if it does throw an error. { Aws::S3::Model::HeadObjectRequest request; @@ -77,36 +80,60 @@ class S3WriteFile::Impl { outcome, "Failed to create S3 bucket", bucket_, ""); } } + fileSize_ = 0; + } - // Initiate the multi-part upload. - { - Aws::S3::Model::CreateMultipartUploadRequest request; - request.SetBucket(awsString(bucket_)); - request.SetKey(awsString(key_)); + void createMultipartUploadRequest() { + Aws::S3::Model::CreateMultipartUploadRequest request; + request.SetBucket(awsString(bucket_)); + request.SetKey(awsString(key_)); - /// If we do not set anything then the SDK will default to application/xml - /// which confuses some tools - /// (https://github.com/apache/arrow/issues/11934). So we instead default - /// to application/octet-stream which is less misleading. - request.SetContentType(kApplicationOctetStream); - // The default algorithm used is MD5. However, MD5 is not supported with - // fips and can cause a SIGSEGV. Set CRC32 instead which is a standard for - // checksum computation and is not restricted by fips. - request.SetChecksumAlgorithm(Aws::S3::Model::ChecksumAlgorithm::CRC32); + /// If we do not set anything then the SDK will default to application/xml + /// which confuses some tools + /// (https://github.com/apache/arrow/issues/11934). So we instead default + /// to application/octet-stream which is less misleading. + request.SetContentType(kApplicationOctetStream); + auto outcome = client_->CreateMultipartUpload(request); + VELOX_CHECK_AWS_OUTCOME( + outcome, "Failed initiating multiple part upload", bucket_, key_); + uploadState_.id = outcome.GetResult().GetUploadId(); + } - auto outcome = client_->CreateMultipartUpload(request); - VELOX_CHECK_AWS_OUTCOME( - outcome, "Failed initiating multiple part upload", bucket_, key_); - uploadState_.id = outcome.GetResult().GetUploadId(); - } + // This uploads the buffer as a single object. This is used if we deal with + // buffers smaller than min-part-size. - fileSize_ = 0; + void putObjectRequest() { + Aws::S3::Model::PutObjectRequest request; + request.SetBucket(awsString(bucket_)); + request.SetKey(awsString(key_)); + + /// If we do not set anything then the SDK will default to application/xml + /// which confuses some tools + /// (https://github.com/apache/arrow/issues/11934). So we instead default + /// to application/octet-stream which is less misleading. + request.SetContentType(kApplicationOctetStream); + request.SetContentLength(currentPart_->size()); + request.SetBody( + std::make_shared( + currentPart_->data(), currentPart_->size())); + RECORD_METRIC_VALUE(kMetricS3StartedUploads); + auto outcome = client_->PutObject(request); + VELOX_CHECK_AWS_OUTCOME( + outcome, "Failed single object upload", bucket_, key_); + if (outcome.IsSuccess()) { + RECORD_METRIC_VALUE(kMetricS3SuccessfulUploads); + } else { + RECORD_METRIC_VALUE(kMetricS3FailedUploads); + } } // Appends data to the end of the file. void append(std::string_view data) { VELOX_CHECK(!closed(), "File is closed"); - if (data.size() + currentPart_->size() >= kPartUploadSize) { + if (data.size() + currentPart_->size() > minPartSize_) { + if (uploadState_.partNumber == 0) { + createMultipartUploadRequest(); + } upload(data); } else { // Append to current part. @@ -118,16 +145,26 @@ class S3WriteFile::Impl { // No-op. void flush() { VELOX_CHECK(!closed(), "File is closed"); - /// currentPartSize must be less than kPartUploadSize since - /// append() would have already flushed after reaching kUploadPartSize. - VELOX_CHECK_LT(currentPart_->size(), kPartUploadSize); + /// currentPartSize must be less than minPartSize since + /// append() would have already flushed after reaching minPartSize. + VELOX_CHECK_LT(currentPart_->size(), minPartSize_); } - // Complete the multipart upload and close the file. + // Send the buffer in a single request or complete the multipart upload. + // Then close the file. void close() { if (closed()) { return; } + // If we haven't sent anything yet, that is the file is less then + // minFileSize_. lets put the object that is in the buffer. + if (uploadState_.partNumber == 0) { + // Send single request. + putObjectRequest(); + currentPart_->clear(); + return; + } + RECORD_METRIC_VALUE(kMetricS3StartedUploads); uploadPart({currentPart_->data(), currentPart_->size()}, true); VELOX_CHECK_EQ(uploadState_.partNumber, uploadState_.completedParts.size()); @@ -163,7 +200,6 @@ class S3WriteFile::Impl { } private: - static constexpr int64_t kPartUploadSize = 10 * 1024 * 1024; static constexpr const char* kApplicationOctetStream = "application/octet-stream"; @@ -179,8 +215,8 @@ class S3WriteFile::Impl { }; UploadState uploadState_; - // Data can be smaller or larger than the kPartUploadSize. - // Complete the currentPart_ and upload kPartUploadSize chunks of data. + // Data can be smaller or larger than the minPartSize_. + // Complete the currentPart_ and upload minPartSize_ chunks of data. // Save the remaining into currentPart_. void upload(const std::string_view data) { auto dataPtr = data.data(); @@ -191,18 +227,18 @@ class S3WriteFile::Impl { uploadPart({currentPart_->data(), currentPart_->size()}); dataPtr += remainingBufferSize; dataSize -= remainingBufferSize; - while (dataSize > kPartUploadSize) { - uploadPart({dataPtr, kPartUploadSize}); - dataPtr += kPartUploadSize; - dataSize -= kPartUploadSize; + while (dataSize > minPartSize_) { + uploadPart({dataPtr, minPartSize_}); + dataPtr += minPartSize_; + dataSize -= minPartSize_; } // Stash the remaining at the beginning of currentPart. currentPart_->unsafeAppend(0, dataPtr, dataSize); } void uploadPart(const std::string_view part, bool isLast = false) { - // Only the last part can be less than kPartUploadSize. - VELOX_CHECK(isLast || (!isLast && (part.size() == kPartUploadSize))); + // Only the last part can be less than minPartSize. + VELOX_CHECK(isLast || (!isLast && (part.size() == minPartSize_))); // Upload the part. { Aws::S3::Model::UploadPartRequest request; @@ -213,10 +249,6 @@ class S3WriteFile::Impl { request.SetContentLength(part.size()); request.SetBody( std::make_shared(part.data(), part.size())); - // The default algorithm used is MD5. However, MD5 is not supported with - // fips and can cause a SIGSEGV. Set CRC32 instead which is a standard for - // checksum computation and is not restricted by fips. - request.SetChecksumAlgorithm(Aws::S3::Model::ChecksumAlgorithm::CRC32); auto outcome = client_->UploadPart(request); VELOX_CHECK_AWS_OUTCOME(outcome, "Failed to upload", bucket_, key_); // Append ETag and part number for this uploaded part. @@ -241,13 +273,15 @@ class S3WriteFile::Impl { std::string bucket_; std::string key_; size_t fileSize_ = -1; + const size_t minPartSize_; }; S3WriteFile::S3WriteFile( std::string_view path, Aws::S3::S3Client* client, - memory::MemoryPool* pool) { - impl_ = std::make_shared(path, client, pool); + memory::MemoryPool* pool, + const std::shared_ptr& s3Config) { + impl_ = std::make_shared(path, client, pool, s3Config); } void S3WriteFile::append(std::string_view data) { diff --git a/velox/connectors/hive/storage_adapters/s3fs/S3WriteFile.h b/velox/connectors/hive/storage_adapters/s3fs/S3WriteFile.h index 929eed20c37..ef4ad2968d2 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/S3WriteFile.h +++ b/velox/connectors/hive/storage_adapters/s3fs/S3WriteFile.h @@ -25,20 +25,25 @@ class S3Client; namespace facebook::velox::filesystems { +class S3Config; + /// S3WriteFile uses the Apache Arrow implementation as a reference. /// AWS C++ SDK allows streaming writes via the MultiPart upload API. /// Multipart upload allows you to upload a single object as a set of parts. /// Each part is a contiguous portion of the object's data. /// While AWS and Minio support different sizes for each -/// part (only requiring a minimum of 5MB), Cloudflare R2 requires that every -/// part be exactly equal (except for the last part). We set this to 10 MiB, so -/// that in combination with the maximum number of parts of 10,000, this gives a -/// file limit of 100k MiB (or about 98 GiB). -/// (see https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html) -/// (for rational, see: https://github.com/apache/arrow/issues/34363) -/// You can upload these object parts independently and in any order. -/// After all parts of your object are uploaded, Amazon S3 assembles these parts -/// and creates the object. +/// part (only requiring a minimum of 5MB - but not enforced), Apache Ozone +/// enforces the minimum 5MB (smaller parts are ignored), and Cloudflare R2 +/// requires that every part be exactly equal (except for the last part). We set +/// this to minPartSize (default 10MB), so that in combination with the maximum +/// number of parts of 10,000, this gives a file limit of 100k MiB (or about 98 +/// GiB). +/// (for rationale, see: https://github.com/apache/arrow/issues/34363) +/// For AWS limits see +/// https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html). You can +/// upload these object parts independently and in any order. After all parts of +/// your object are uploaded, Amazon S3 assembles these parts and creates the +/// object. /// https://docs.aws.amazon.com/AmazonS3/latest/userguide/mpuoverview.html /// https://github.com/apache/arrow/blob/main/cpp/src/arrow/filesystem/s3fs.cc /// S3WriteFile is not thread-safe. @@ -50,7 +55,8 @@ class S3WriteFile : public WriteFile { S3WriteFile( std::string_view path, Aws::S3::S3Client* client, - memory::MemoryPool* pool); + memory::MemoryPool* pool, + const std::shared_ptr& s3Config); /// Appends data to the end of the file. /// Uploads a part on reaching part size limit. @@ -69,6 +75,8 @@ class S3WriteFile : public WriteFile { int numPartsUploaded() const; protected: + void createMultipartUploadRequest(); + class Impl; std::shared_ptr impl_; }; diff --git a/velox/connectors/hive/storage_adapters/s3fs/tests/CMakeLists.txt b/velox/connectors/hive/storage_adapters/s3fs/tests/CMakeLists.txt index 9d92727e767..4de74413313 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/tests/CMakeLists.txt +++ b/velox/connectors/hive/storage_adapters/s3fs/tests/CMakeLists.txt @@ -23,6 +23,7 @@ target_link_libraries( ) add_executable(velox_s3file_test S3FileSystemTest.cpp S3UtilTest.cpp) +velox_add_test_headers(velox_s3file_test MinioServer.h S3Test.h) add_test(velox_s3file_test velox_s3file_test) target_link_libraries( velox_s3file_test diff --git a/velox/connectors/hive/storage_adapters/s3fs/tests/MinioServer.h b/velox/connectors/hive/storage_adapters/s3fs/tests/MinioServer.h index 591ed403f35..164bb590518 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/tests/MinioServer.h +++ b/velox/connectors/hive/storage_adapters/s3fs/tests/MinioServer.h @@ -17,13 +17,15 @@ #pragma once #include "velox/common/config/Config.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/exec/tests/utils/PortUtil.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "boost/process.hpp" using namespace facebook::velox; +using TempDirectoryPath = common::testutil::TempDirectoryPath; + namespace { constexpr char const* kMinioExecutableName{"minio-2022-05-26"}; constexpr char const* kMinioAccessKey{"minio"}; @@ -34,7 +36,7 @@ constexpr char const* kMinioSecretKey{"miniopass"}; // Adapted from the Apache Arrow library. class MinioServer { public: - MinioServer() : tempPath_(::exec::test::TempDirectoryPath::create()) { + MinioServer() : tempPath_(TempDirectoryPath::create()) { constexpr auto kHostAddressTemplate = "127.0.0.1:{}"; auto ports = facebook::velox::exec::test::getFreePorts(2); connectionString_ = fmt::format(kHostAddressTemplate, ports[0]); @@ -74,7 +76,7 @@ class MinioServer { } private: - const std::shared_ptr tempPath_; + const std::shared_ptr tempPath_; std::string connectionString_; std::string consoleAddress_; const std::string accessKey_ = kMinioAccessKey; diff --git a/velox/connectors/hive/storage_adapters/s3fs/tests/S3ConfigTest.cpp b/velox/connectors/hive/storage_adapters/s3fs/tests/S3ConfigTest.cpp index e8d48fe6edd..20b10268677 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/tests/S3ConfigTest.cpp +++ b/velox/connectors/hive/storage_adapters/s3fs/tests/S3ConfigTest.cpp @@ -15,6 +15,7 @@ */ #include "velox/connectors/hive/storage_adapters/s3fs/S3Config.h" +#include "velox/common/base/tests/GTestUtils.h" #include "velox/common/config/Config.h" #include @@ -37,6 +38,8 @@ TEST(S3ConfigTest, defaultConfig) { ASSERT_EQ(s3Config.payloadSigningPolicy(), "Never"); ASSERT_EQ(s3Config.cacheKey("foo", config), "foo"); ASSERT_EQ(s3Config.bucket(), ""); + ASSERT_EQ(s3Config.useIMDS(), true); + ASSERT_EQ(s3Config.minPartSize(), 10485760); } TEST(S3ConfigTest, overrideConfig) { @@ -53,7 +56,9 @@ TEST(S3ConfigTest, overrideConfig) { {S3Config::baseConfigKey(S3Config::Keys::kIamRole), "iam"}, {S3Config::baseConfigKey(S3Config::Keys::kIamRoleSessionName), "velox"}, {S3Config::baseConfigKey(S3Config::Keys::kCredentialsProvider), - "my-credentials-provider"}}; + "my-credentials-provider"}, + {S3Config::baseConfigKey(S3Config::Keys::kIMDSEnabled), "false"}, + {S3Config::baseConfigKey(S3Config::Keys::kMultipartMinPartSize), "20MB"}}; auto configBase = std::make_shared(std::move(configFromFile)); auto s3Config = S3Config("bucket", configBase); @@ -71,6 +76,8 @@ TEST(S3ConfigTest, overrideConfig) { ASSERT_EQ(s3Config.cacheKey("bar", configBase), "endpoint-bar"); ASSERT_EQ(s3Config.bucket(), "bucket"); ASSERT_EQ(s3Config.credentialsProvider(), "my-credentials-provider"); + ASSERT_EQ(s3Config.useIMDS(), false); + ASSERT_EQ(s3Config.minPartSize(), 20971520); } TEST(S3ConfigTest, overrideBucketConfig) { @@ -95,7 +102,9 @@ TEST(S3ConfigTest, overrideBucketConfig) { {S3Config::baseConfigKey(S3Config::Keys::kCredentialsProvider), "my-credentials-provider"}, {S3Config::bucketConfigKey(S3Config::Keys::kCredentialsProvider, bucket), - "override-credentials-provider"}}; + "override-credentials-provider"}, + {S3Config::baseConfigKey(S3Config::Keys::kIMDSEnabled), "false"}, + {S3Config::baseConfigKey(S3Config::Keys::kMultipartMinPartSize), "20MB"}}; auto configBase = std::make_shared(std::move(bucketConfigFromFile)); auto s3Config = S3Config(bucket, configBase); @@ -115,6 +124,50 @@ TEST(S3ConfigTest, overrideBucketConfig) { "bucket.s3-region.amazonaws.com-bucket"); ASSERT_EQ(s3Config.cacheKey("foo", configBase), "endpoint-foo"); ASSERT_EQ(s3Config.credentialsProvider(), "override-credentials-provider"); + ASSERT_EQ(s3Config.useIMDS(), false); + ASSERT_EQ(s3Config.minPartSize(), 20971520); +} + +TEST(S3ConfigTest, minPartSizeValidation) { + // Test that setting min-part-size below 5MB throws an error. + std::unordered_map configFromFile = { + {S3Config::baseConfigKey(S3Config::Keys::kMultipartMinPartSize), "4MB"}}; + auto configBase = + std::make_shared(std::move(configFromFile)); + + VELOX_ASSERT_THROW( + S3Config("bucket", configBase), + "The min-part-size S3 configuration must exceed 5MB"); + + configFromFile = { + {S3Config::baseConfigKey(S3Config::Keys::kMultipartMinPartSize), "10GB"}}; + configBase = std::make_shared(std::move(configFromFile)); + VELOX_ASSERT_THROW( + S3Config("bucket", configBase), + "The min-part-size S3 configuration must not exceed 5GB"); +} + +TEST(S3ConfigTest, minPartSizeValidationBucketConfig) { + // Test that setting bucket-specific min-part-size below 5MB throws an error. + std::string_view bucket = "testbucket"; + std::unordered_map configFromFile = { + {S3Config::bucketConfigKey(S3Config::Keys::kMultipartMinPartSize, bucket), + "3MB"}}; + auto configBase = + std::make_shared(std::move(configFromFile)); + + VELOX_ASSERT_THROW( + S3Config(bucket, configBase), + "The min-part-size S3 configuration must exceed 5MB"); + + configFromFile = { + {S3Config::bucketConfigKey(S3Config::Keys::kMultipartMinPartSize, bucket), + "10GB"}}; + configBase = std::make_shared(std::move(configFromFile)); + + VELOX_ASSERT_THROW( + S3Config(bucket, configBase), + "The min-part-size S3 configuration must not exceed 5GB"); } } // namespace diff --git a/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemMetricsTest.cpp b/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemMetricsTest.cpp index aba2ac90236..907acf92783 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemMetricsTest.cpp +++ b/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemMetricsTest.cpp @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #include #include @@ -24,8 +25,9 @@ #include -namespace facebook::velox::filesystems { +namespace facebook::velox::filesystems::test { namespace { + class S3TestReporter : public BaseStatsReporter { public: mutable std::mutex m; @@ -40,6 +42,7 @@ class S3TestReporter : public BaseStatsReporter { statTypeMap.clear(); histogramPercentilesMap.clear(); } + void registerMetricExportType(const char* key, StatType statType) const override { statTypeMap[key] = statType; @@ -68,18 +71,6 @@ class S3TestReporter : public BaseStatsReporter { histogramPercentilesMap[key.str()] = pcts; } - void registerQuantileMetricExportType( - const char* /* key */, - const std::vector& /* statTypes */, - const std::vector& /* pcts */, - const std::vector& /* slidingWindowsSeconds */) const override {} - - void registerQuantileMetricExportType( - folly::StringPiece /* key */, - const std::vector& /* statTypes */, - const std::vector& /* pcts */, - const std::vector& /* slidingWindowsSeconds */) const override {} - void addMetricValue(const std::string& key, const size_t value) const override { std::lock_guard l(m); @@ -110,42 +101,6 @@ class S3TestReporter : public BaseStatsReporter { counterMap[key.str()] = std::max(counterMap[key.str()], value); } - void addQuantileMetricValue(const std::string& /* key */, size_t /* value */) - const override {} - - void addQuantileMetricValue(const char* /* key */, size_t /* value */) - const override {} - - void addQuantileMetricValue(folly::StringPiece /* key */, size_t /* value */) - const override {} - - void registerDynamicQuantileMetricExportType( - const char* /* keyPattern */, - const std::vector& /* statTypes */, - const std::vector& /* pcts */, - const std::vector& /* slidingWindowsSeconds */) const override {} - - void registerDynamicQuantileMetricExportType( - folly::StringPiece /* keyPattern */, - const std::vector& /* statTypes */, - const std::vector& /* pcts */, - const std::vector& /* slidingWindowsSeconds */) const override {} - - void addDynamicQuantileMetricValue( - const std::string& /* key */, - folly::Range /* subkeys */, - size_t /* value */) const override {} - - void addDynamicQuantileMetricValue( - const char* /* key */, - folly::Range /* subkeys */, - size_t /* value */) const override {} - - void addDynamicQuantileMetricValue( - folly::StringPiece /* key */, - folly::Range /* subkeys */, - size_t /* value */) const override {} - std::string fetchMetrics() override { std::stringstream ss; ss << "["; @@ -217,7 +172,7 @@ TEST_F(S3FileSystemMetricsTest, metrics) { EXPECT_EQ(1, s3Reporter->counterMap[std::string{kMetricS3GetObjectCalls}]); } -} // namespace facebook::velox::filesystems +} // namespace facebook::velox::filesystems::test int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); diff --git a/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemTest.cpp b/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemTest.cpp index 3fea62365bb..a94e9b4c352 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemTest.cpp +++ b/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemTest.cpp @@ -175,7 +175,7 @@ TEST_F(S3FileSystemTest, noBackendServer) { minioServer_->stop(); VELOX_ASSERT_THROW( s3fs.openFileForRead(kDummyPath), - "Failed to get metadata for S3 object due to: 'Network connection'. Path:'s3://dummy/foo.txt', SDK Error Type:99, HTTP Status Code:-1, S3 Service:'Unknown', Message:'curlCode: 7, Couldn't connect to server'"); + "Failed to get metadata for S3 object due to: 'Network connection'. Path:'s3://dummy/foo.txt', SDK Error Type:99, HTTP Status Code:-1, S3 Service:'Unknown', Message:'curlCode: 7, Couldn't connect to server"); // Start Minio again. minioServer_->start(); } diff --git a/velox/connectors/hive/storage_adapters/s3fs/tests/S3InsertTest.cpp b/velox/connectors/hive/storage_adapters/s3fs/tests/S3InsertTest.cpp index 2f1e753e7ba..0861df83d48 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/tests/S3InsertTest.cpp +++ b/velox/connectors/hive/storage_adapters/s3fs/tests/S3InsertTest.cpp @@ -28,18 +28,21 @@ class S3InsertTest : public S3Test, public test::InsertTest { protected: static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + filesystems::registerS3FileSystem(); + } + + static void TearDownTestCase() { + filesystems::finalizeS3FileSystem(); } void SetUp() override { S3Test::SetUp(); - filesystems::registerS3FileSystem(); InsertTest::SetUp(minioServer_->hiveConfig(), ioExecutor_.get()); } void TearDown() override { S3Test::TearDown(); InsertTest::TearDown(); - filesystems::finalizeS3FileSystem(); } }; } // namespace @@ -51,6 +54,22 @@ TEST_F(S3InsertTest, s3InsertTest) { runInsertTest(kOutputDirectory, kExpectedRows, pool()); } + +// Test with data exceeding the default 5MB minPartSize to trigger multipart +// upload. This test generates enough data to exceed 5MB, which should trigger +// at least one multipart upload part and a remainder. +TEST_F(S3InsertTest, s3MultipartUploadTest) { + // Generate enough rows to exceed 5MB. + // Each row has 4 columns: BIGINT (8 bytes), INTEGER (4 bytes), + // SMALLINT (2 bytes), DOUBLE (8 bytes) = 22 bytes per row minimum. + // To exceed 5MB (5 * 1024 * 1024 = 5,242,880 bytes), we need at least + // 5,242,880 / 22 ≈ 238,313 rows. Let's use 300,000 rows to be safe. + const int64_t kExpectedRows = 300'000; + const std::string_view kOutputDirectory{"s3://multipartdata/"}; + minioServer_->addBucket("multipartdata"); + + runInsertTest(kOutputDirectory, kExpectedRows, pool()); +} } // namespace facebook::velox::filesystems int main(int argc, char** argv) { diff --git a/velox/connectors/hive/storage_adapters/s3fs/tests/S3MultipleEndpointsTest.cpp b/velox/connectors/hive/storage_adapters/s3fs/tests/S3MultipleEndpointsTest.cpp index cf446b12d2e..3c90521bb56 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/tests/S3MultipleEndpointsTest.cpp +++ b/velox/connectors/hive/storage_adapters/s3fs/tests/S3MultipleEndpointsTest.cpp @@ -17,6 +17,7 @@ #include #include "gtest/gtest.h" +#include "velox/connectors/ConnectorRegistry.h" #include "velox/connectors/hive/HiveConnector.h" #include "velox/connectors/hive/storage_adapters/s3fs/RegisterS3FileSystem.h" #include "velox/connectors/hive/storage_adapters/s3fs/S3Util.h" @@ -71,8 +72,10 @@ class S3MultipleEndpoints : public S3Test, public ::test::VectorTestBase { std::string(connectorId2), minioSecondServer_->hiveConfig(config2Override), ioExecutor_.get()); - connector::registerConnector(hiveConnector1); - connector::registerConnector(hiveConnector2); + connector::ConnectorRegistry::global().insert( + hiveConnector1->connectorId(), hiveConnector1); + connector::ConnectorRegistry::global().insert( + hiveConnector2->connectorId(), hiveConnector2); } void TearDown() override { @@ -102,7 +105,8 @@ class S3MultipleEndpoints : public S3Test, public ::test::VectorTestBase { // Second column contains details about written files. auto details = results->childAt(exec::TableWriteTraits::kFragmentChannel) ->as>(); - folly::dynamic obj = folly::parseJson(details->valueAt(1)); + folly::dynamic obj = + folly::parseJson(std::string_view(details->valueAt(1))); return obj["fileWriteInfos"]; } @@ -190,8 +194,8 @@ TEST_F(S3MultipleEndpoints, baseEndpoints) { testJoin(kExpectedRows, outputDirectory, kConnectorId1, kConnectorId2); - connector::unregisterConnector(std::string(kConnectorId1)); - connector::unregisterConnector(std::string(kConnectorId2)); + connector::ConnectorRegistry::global().erase(std::string(kConnectorId1)); + connector::ConnectorRegistry::global().erase(std::string(kConnectorId2)); } TEST_F(S3MultipleEndpoints, bucketEndpoints) { @@ -217,8 +221,8 @@ TEST_F(S3MultipleEndpoints, bucketEndpoints) { testJoin(kExpectedRows, outputDirectory, kConnectorId1, kConnectorId2); - connector::unregisterConnector(std::string(kConnectorId1)); - connector::unregisterConnector(std::string(kConnectorId2)); + connector::ConnectorRegistry::global().erase(std::string(kConnectorId1)); + connector::ConnectorRegistry::global().erase(std::string(kConnectorId2)); } } // namespace facebook::velox diff --git a/velox/connectors/hive/storage_adapters/s3fs/tests/S3ReadTest.cpp b/velox/connectors/hive/storage_adapters/s3fs/tests/S3ReadTest.cpp index bbd08483751..9062eed288a 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/tests/S3ReadTest.cpp +++ b/velox/connectors/hive/storage_adapters/s3fs/tests/S3ReadTest.cpp @@ -18,6 +18,7 @@ #include #include "velox/common/memory/Memory.h" +#include "velox/connectors/ConnectorRegistry.h" #include "velox/connectors/hive/HiveConnector.h" #include "velox/connectors/hive/storage_adapters/s3fs/RegisterS3FileSystem.h" #include "velox/connectors/hive/storage_adapters/s3fs/tests/S3Test.h" @@ -43,14 +44,15 @@ class S3ReadTest : public S3Test, public ::test::VectorTestBase { connector::hive::HiveConnectorFactory factory; auto hiveConnector = factory.newConnector(kHiveConnectorId, minioServer_->hiveConfig()); - connector::registerConnector(hiveConnector); + connector::ConnectorRegistry::global().insert( + hiveConnector->connectorId(), hiveConnector); parquet::registerParquetReaderFactory(); } void TearDown() override { parquet::unregisterParquetReaderFactory(); filesystems::finalizeS3FileSystem(); - connector::unregisterConnector(kHiveConnectorId); + connector::ConnectorRegistry::global().erase(kHiveConnectorId); S3Test::TearDown(); } }; diff --git a/velox/connectors/hive/storage_adapters/s3fs/tests/S3Test.h b/velox/connectors/hive/storage_adapters/s3fs/tests/S3Test.h index 6a190684e2b..16e2a425649 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/tests/S3Test.h +++ b/velox/connectors/hive/storage_adapters/s3fs/tests/S3Test.h @@ -16,16 +16,17 @@ #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/File.h" +#include "velox/common/testutil/TempFilePath.h" #include "velox/connectors/hive/FileHandle.h" #include "velox/connectors/hive/storage_adapters/s3fs/S3FileSystem.h" #include "velox/connectors/hive/storage_adapters/s3fs/S3Util.h" #include "velox/connectors/hive/storage_adapters/s3fs/tests/MinioServer.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" -#include "velox/exec/tests/utils/TempFilePath.h" #include "gtest/gtest.h" using namespace facebook::velox; +using TempFilePath = common::testutil::TempFilePath; constexpr int kOneMB = 1 << 20; diff --git a/velox/connectors/hive/storage_adapters/test_common/InsertTest.h b/velox/connectors/hive/storage_adapters/test_common/InsertTest.h index 0700ca7a334..c595beca383 100644 --- a/velox/connectors/hive/storage_adapters/test_common/InsertTest.h +++ b/velox/connectors/hive/storage_adapters/test_common/InsertTest.h @@ -18,7 +18,9 @@ #include #include "velox/common/memory/Memory.h" +#include "velox/connectors/ConnectorRegistry.h" #include "velox/connectors/hive/HiveConnector.h" +#include "velox/connectors/hive/HiveDataSink.h" #include "velox/dwio/parquet/RegisterParquetReader.h" #include "velox/dwio/parquet/RegisterParquetWriter.h" #include "velox/exec/TableWriter.h" @@ -37,7 +39,8 @@ class InsertTest : public velox::test::VectorTestBase { connector::hive::HiveConnectorFactory factory; auto hiveConnector = factory.newConnector( exec::test::kHiveConnectorId, hiveConfig, ioExecutor); - connector::registerConnector(hiveConnector); + connector::ConnectorRegistry::global().insert( + hiveConnector->connectorId(), hiveConnector); parquet::registerParquetReaderFactory(); parquet::registerParquetWriterFactory(); @@ -47,7 +50,7 @@ class InsertTest : public velox::test::VectorTestBase { parquet::unregisterParquetReaderFactory(); parquet::unregisterParquetWriterFactory(); - connector::unregisterConnector(exec::test::kHiveConnectorId); + connector::ConnectorRegistry::global().erase(exec::test::kHiveConnectorId); } void runInsertTest( @@ -87,19 +90,26 @@ class InsertTest : public velox::test::VectorTestBase { ->as>(); ASSERT_TRUE(details->isNullAt(0)); ASSERT_FALSE(details->isNullAt(1)); - folly::dynamic obj = folly::parseJson(details->valueAt(1)); + folly::dynamic obj = + folly::parseJson(std::string_view(details->valueAt(1))); - ASSERT_EQ(numRows, obj["rowCount"].asInt()); - auto fileWriteInfos = obj["fileWriteInfos"]; + ASSERT_EQ( + numRows, obj[connector::hive::HiveCommitMessage::kRowCount].asInt()); + auto fileWriteInfos = + obj[connector::hive::HiveCommitMessage::kFileWriteInfos]; ASSERT_EQ(1, fileWriteInfos.size()); - auto writeFileName = fileWriteInfos[0]["writeFileName"].asString(); + auto writeFileName = + fileWriteInfos[0][connector::hive::HiveCommitMessage::kWriteFileName] + .asString(); // Read from 'writeFileName' and verify the data matches the original. plan = exec::test::PlanBuilder().tableScan(rowType).planNode(); auto filePath = fmt::format("{}{}", outputDirectory, writeFileName); - const int64_t fileSize = fileWriteInfos[0]["fileSize"].asInt(); + const int64_t fileSize = + fileWriteInfos[0][connector::hive::HiveCommitMessage::kFileSize] + .asInt(); auto split = exec::test::HiveConnectorSplitBuilder(filePath) .fileFormat(dwio::common::FileFormat::PARQUET) .length(fileSize) diff --git a/velox/connectors/hive/tests/CMakeLists.txt b/velox/connectors/hive/tests/CMakeLists.txt index 0a24b6b10c4..9cceaa511e3 100644 --- a/velox/connectors/hive/tests/CMakeLists.txt +++ b/velox/connectors/hive/tests/CMakeLists.txt @@ -13,6 +13,10 @@ # limitations under the License. add_executable( velox_hive_connector_test + FileColumnHandleTest.cpp + FileConfigTest.cpp + FileConnectorSplitTest.cpp + FileConnectorUtilTest.cpp FileHandleTest.cpp HiveConfigTest.cpp HiveDataSinkTest.cpp @@ -20,7 +24,7 @@ add_executable( HiveConnectorUtilTest.cpp HiveConnectorSerDeTest.cpp HivePartitionFunctionTest.cpp - HivePartitionUtilTest.cpp + HivePartitionNameTest.cpp HiveSplitTest.cpp PartitionIdGeneratorTest.cpp TableHandleTest.cpp @@ -36,6 +40,7 @@ target_link_libraries( velox_vector_test_lib velox_exec velox_exec_test_lib + GTest::gmock GTest::gtest GTest::gtest_main ) diff --git a/velox/connectors/hive/tests/FileColumnHandleTest.cpp b/velox/connectors/hive/tests/FileColumnHandleTest.cpp new file mode 100644 index 00000000000..ba955b57444 --- /dev/null +++ b/velox/connectors/hive/tests/FileColumnHandleTest.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/FileColumnHandle.h" +#include +#include +#include "velox/common/base/tests/GTestUtils.h" + +using namespace facebook::velox::connector::hive; + +TEST(FileColumnHandleTest, columnTypeName) { + EXPECT_EQ( + FileColumnHandle::columnTypeName( + FileColumnHandle::ColumnType::kPartitionKey), + "PartitionKey"); + EXPECT_EQ( + FileColumnHandle::columnTypeName(FileColumnHandle::ColumnType::kRegular), + "Regular"); + EXPECT_EQ( + FileColumnHandle::columnTypeName( + FileColumnHandle::ColumnType::kSynthesized), + "Synthesized"); + EXPECT_EQ( + FileColumnHandle::columnTypeName(FileColumnHandle::ColumnType::kRowIndex), + "RowIndex"); + EXPECT_EQ( + FileColumnHandle::columnTypeName(FileColumnHandle::ColumnType::kRowId), + "RowId"); +} + +TEST(FileColumnHandleTest, columnTypeFromName) { + EXPECT_EQ( + FileColumnHandle::columnTypeFromName("PartitionKey"), + FileColumnHandle::ColumnType::kPartitionKey); + EXPECT_EQ( + FileColumnHandle::columnTypeFromName("Regular"), + FileColumnHandle::ColumnType::kRegular); + EXPECT_EQ( + FileColumnHandle::columnTypeFromName("Synthesized"), + FileColumnHandle::ColumnType::kSynthesized); + EXPECT_EQ( + FileColumnHandle::columnTypeFromName("RowIndex"), + FileColumnHandle::ColumnType::kRowIndex); + EXPECT_EQ( + FileColumnHandle::columnTypeFromName("RowId"), + FileColumnHandle::ColumnType::kRowId); +} + +TEST(FileColumnHandleTest, columnTypeFromNameInvalid) { + VELOX_ASSERT_THROW( + FileColumnHandle::columnTypeFromName("Unknown"), + "Unknown column type name: Unknown"); +} + +TEST(FileColumnHandleTest, columnTypeRoundTrip) { + const std::vector allTypes = { + FileColumnHandle::ColumnType::kPartitionKey, + FileColumnHandle::ColumnType::kRegular, + FileColumnHandle::ColumnType::kSynthesized, + FileColumnHandle::ColumnType::kRowIndex, + FileColumnHandle::ColumnType::kRowId, + }; + for (auto type : allTypes) { + EXPECT_EQ( + FileColumnHandle::columnTypeFromName( + FileColumnHandle::columnTypeName(type)), + type); + } +} + +TEST(FileColumnHandleTest, fmtFormatter) { + EXPECT_EQ( + fmt::format("{}", FileColumnHandle::ColumnType::kRegular), "Regular"); + EXPECT_EQ( + fmt::format("{}", FileColumnHandle::ColumnType::kPartitionKey), + "PartitionKey"); +} diff --git a/velox/connectors/hive/tests/FileConfigTest.cpp b/velox/connectors/hive/tests/FileConfigTest.cpp new file mode 100644 index 00000000000..65377e41ca3 --- /dev/null +++ b/velox/connectors/hive/tests/FileConfigTest.cpp @@ -0,0 +1,174 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/FileConfig.h" +#include +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/config/Config.h" + +using namespace facebook::velox; +using namespace facebook::velox::connector::hive; + +TEST(FileConfigTest, defaultConfig) { + FileConfig config( + std::make_shared( + std::unordered_map())); + const auto emptySession = std::make_unique( + std::unordered_map()); + + EXPECT_FALSE(config.isOrcUseColumnNames(emptySession.get())); + EXPECT_FALSE(config.isParquetUseColumnNames(emptySession.get())); + EXPECT_FALSE(config.isFileColumnNamesReadAsLowerCase(emptySession.get())); + EXPECT_FALSE(config.ignoreMissingFiles(emptySession.get())); + EXPECT_EQ(config.maxCoalescedBytes(emptySession.get()), 128 << 20); + EXPECT_EQ(config.maxCoalescedDistanceBytes(emptySession.get()), 512 << 10); + EXPECT_EQ(config.prefetchRowGroups(), 1); + EXPECT_EQ(config.parallelUnitLoadCount(emptySession.get()), 0); + EXPECT_EQ(config.loadQuantum(emptySession.get()), 8 << 20); + EXPECT_EQ(config.filePreloadThreshold(), 8UL << 20); + EXPECT_EQ(config.readTimestampUnit(emptySession.get()), 3); + EXPECT_TRUE( + config.readTimestampPartitionValueAsLocalTime(emptySession.get())); + EXPECT_FALSE(config.readStatsBasedFilterReorderDisabled(emptySession.get())); + EXPECT_FALSE(config.preserveFlatMapsInMemory(emptySession.get())); + EXPECT_FALSE(config.indexEnabled(emptySession.get())); + EXPECT_FALSE(config.readerCollectColumnCpuMetrics(emptySession.get())); + EXPECT_FALSE(config.fileMetadataCacheEnabled(emptySession.get())); + EXPECT_FALSE(config.pinFileMetadata(emptySession.get())); + EXPECT_EQ(config.orcFooterSpeculativeIoSize(emptySession.get()), 256UL << 10); + EXPECT_EQ( + config.parquetFooterSpeculativeIoSize(emptySession.get()), 256UL << 10); + EXPECT_EQ( + config.nimbleFooterSpeculativeIoSize(emptySession.get()), 8UL << 20); + EXPECT_FALSE(config.nimbleStringDecoderZeroCopy(emptySession.get())); + EXPECT_FALSE(config.nimblePreserveDictionaryEncoding(emptySession.get())); +} + +TEST(FileConfigTest, overrideConfig) { + std::unordered_map configFromFile = { + {FileConfig::kOrcUseColumnNames, "true"}, + {FileConfig::kParquetUseColumnNames, "true"}, + {FileConfig::kFileColumnNamesReadAsLowerCase, "true"}, + {FileConfig::kMaxCoalescedBytes, "100"}, + {FileConfig::kMaxCoalescedDistance, "100kB"}, + {FileConfig::kPrefetchRowGroups, "4"}, + {FileConfig::kLoadQuantum, std::to_string(4 << 20)}, + {FileConfig::kFilePreloadThreshold, std::to_string(16UL << 20)}, + {FileConfig::kReadStatsBasedFilterReorderDisabled, "true"}, + {FileConfig::kPreserveFlatMapsInMemory, "true"}, + {FileConfig::kIndexEnabled, "true"}, + {FileConfig::kReaderCollectColumnCpuMetrics, "true"}, + {FileConfig::kFileMetadataCacheEnabled, "true"}, + {FileConfig::kPinFileMetadata, "true"}, + {FileConfig::kOrcFooterSpeculativeIoSize, std::to_string(512UL << 10)}, + {FileConfig::kParquetFooterSpeculativeIoSize, std::to_string(1UL << 20)}, + {FileConfig::kNimbleFooterSpeculativeIoSize, std::to_string(4UL << 20)}, + {FileConfig::kNimbleStringDecoderZeroCopy, "true"}, + {FileConfig::kNimblePreserveDictionaryEncoding, "true"}, + }; + FileConfig config( + std::make_shared(std::move(configFromFile))); + const auto emptySession = std::make_unique( + std::unordered_map()); + + EXPECT_TRUE(config.isOrcUseColumnNames(emptySession.get())); + EXPECT_TRUE(config.isParquetUseColumnNames(emptySession.get())); + EXPECT_TRUE(config.isFileColumnNamesReadAsLowerCase(emptySession.get())); + EXPECT_EQ(config.maxCoalescedBytes(emptySession.get()), 100); + EXPECT_EQ(config.maxCoalescedDistanceBytes(emptySession.get()), 100 << 10); + EXPECT_EQ(config.prefetchRowGroups(), 4); + EXPECT_EQ(config.loadQuantum(emptySession.get()), 4 << 20); + EXPECT_EQ(config.filePreloadThreshold(), 16UL << 20); + EXPECT_TRUE(config.readStatsBasedFilterReorderDisabled(emptySession.get())); + EXPECT_TRUE(config.preserveFlatMapsInMemory(emptySession.get())); + EXPECT_TRUE(config.indexEnabled(emptySession.get())); + EXPECT_TRUE(config.readerCollectColumnCpuMetrics(emptySession.get())); + EXPECT_TRUE(config.fileMetadataCacheEnabled(emptySession.get())); + EXPECT_TRUE(config.pinFileMetadata(emptySession.get())); + EXPECT_EQ(config.orcFooterSpeculativeIoSize(emptySession.get()), 512UL << 10); + EXPECT_EQ( + config.parquetFooterSpeculativeIoSize(emptySession.get()), 1UL << 20); + EXPECT_EQ( + config.nimbleFooterSpeculativeIoSize(emptySession.get()), 4UL << 20); + EXPECT_TRUE(config.nimbleStringDecoderZeroCopy(emptySession.get())); + EXPECT_TRUE(config.nimblePreserveDictionaryEncoding(emptySession.get())); +} + +TEST(FileConfigTest, overrideSession) { + FileConfig config( + std::make_shared( + std::unordered_map())); + std::unordered_map sessionOverride = { + {FileConfig::kOrcUseColumnNamesSession, "true"}, + {FileConfig::kParquetUseColumnNamesSession, "true"}, + {FileConfig::kFileColumnNamesReadAsLowerCaseSession, "true"}, + {FileConfig::kIgnoreMissingFilesSession, "true"}, + {FileConfig::kMaxCoalescedDistanceSession, "3MB"}, + {FileConfig::kLoadQuantumSession, std::to_string(4 << 20)}, + {FileConfig::kReadStatsBasedFilterReorderDisabledSession, "true"}, + {FileConfig::kPreserveFlatMapsInMemorySession, "true"}, + {FileConfig::kIndexEnabledSession, "true"}, + {FileConfig::kReaderCollectColumnCpuMetricsSession, "true"}, + {FileConfig::kFileMetadataCacheEnabledSession, "true"}, + {FileConfig::kPinFileMetadataSession, "true"}, + {FileConfig::kOrcFooterSpeculativeIoSizeSession, + std::to_string(128UL << 10)}, + {FileConfig::kParquetFooterSpeculativeIoSizeSession, + std::to_string(512UL << 10)}, + {FileConfig::kNimbleFooterSpeculativeIoSizeSession, + std::to_string(2UL << 20)}, + {FileConfig::kNimbleStringDecoderZeroCopySession, "true"}, + {FileConfig::kNimblePreserveDictionaryEncodingSession, "true"}, + }; + const auto session = + std::make_unique(std::move(sessionOverride)); + + EXPECT_TRUE(config.isOrcUseColumnNames(session.get())); + EXPECT_TRUE(config.isParquetUseColumnNames(session.get())); + EXPECT_TRUE(config.isFileColumnNamesReadAsLowerCase(session.get())); + EXPECT_TRUE(config.ignoreMissingFiles(session.get())); + EXPECT_EQ(config.maxCoalescedDistanceBytes(session.get()), 3 << 20); + EXPECT_EQ(config.loadQuantum(session.get()), 4 << 20); + EXPECT_TRUE(config.readStatsBasedFilterReorderDisabled(session.get())); + EXPECT_TRUE(config.preserveFlatMapsInMemory(session.get())); + EXPECT_TRUE(config.indexEnabled(session.get())); + EXPECT_TRUE(config.readerCollectColumnCpuMetrics(session.get())); + EXPECT_TRUE(config.fileMetadataCacheEnabled(session.get())); + EXPECT_TRUE(config.pinFileMetadata(session.get())); + EXPECT_EQ(config.orcFooterSpeculativeIoSize(session.get()), 128UL << 10); + EXPECT_EQ(config.parquetFooterSpeculativeIoSize(session.get()), 512UL << 10); + EXPECT_EQ(config.nimbleFooterSpeculativeIoSize(session.get()), 2UL << 20); + EXPECT_TRUE(config.nimbleStringDecoderZeroCopy(session.get())); + EXPECT_TRUE(config.nimblePreserveDictionaryEncoding(session.get())); +} + +TEST(FileConfigTest, nullConfig) { + VELOX_ASSERT_THROW( + FileConfig(nullptr), "Config is null for FileConfig initialization"); +} + +TEST(FileConfigTest, invalidTimestampUnit) { + FileConfig config( + std::make_shared( + std::unordered_map())); + std::unordered_map sessionOverride = { + {FileConfig::kReadTimestampUnitSession, "5"}, + }; + const auto session = + std::make_unique(std::move(sessionOverride)); + VELOX_ASSERT_THROW( + config.readTimestampUnit(session.get()), "Invalid timestamp unit."); +} diff --git a/velox/connectors/hive/tests/FileConnectorSplitTest.cpp b/velox/connectors/hive/tests/FileConnectorSplitTest.cpp new file mode 100644 index 00000000000..08fc673b7a0 --- /dev/null +++ b/velox/connectors/hive/tests/FileConnectorSplitTest.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/FileConnectorSplit.h" +#include + +using namespace facebook::velox; +using namespace facebook::velox::connector::hive; + +TEST(FileConnectorSplitTest, construction) { + FileConnectorSplit split( + "connectorId", + "/path/to/file.parquet", + dwio::common::FileFormat::PARQUET, + 100, + 5000, + /*splitWeight=*/2, + /*cacheable=*/false); + + EXPECT_EQ(split.connectorId, "connectorId"); + EXPECT_EQ(split.filePath, "/path/to/file.parquet"); + EXPECT_EQ(split.fileFormat, dwio::common::FileFormat::PARQUET); + EXPECT_EQ(split.start, 100); + EXPECT_EQ(split.length, 5000); + EXPECT_EQ(split.splitWeight, 2); + EXPECT_FALSE(split.cacheable); + EXPECT_FALSE(split.properties.has_value()); +} + +TEST(FileConnectorSplitTest, defaults) { + FileConnectorSplit split( + "connectorId", "/file.orc", dwio::common::FileFormat::ORC); + + EXPECT_EQ(split.start, 0); + EXPECT_EQ(split.length, std::numeric_limits::max()); + EXPECT_EQ(split.splitWeight, 0); + EXPECT_TRUE(split.cacheable); + EXPECT_FALSE(split.properties.has_value()); +} + +TEST(FileConnectorSplitTest, size) { + FileConnectorSplit split( + "connectorId", "/file.dwrf", dwio::common::FileFormat::DWRF, 0, 12345); + + EXPECT_EQ(split.size(), 12345); +} + +TEST(FileConnectorSplitTest, getFileName) { + FileConnectorSplit split( + "connectorId", + "/path/to/data/part-00000.parquet", + dwio::common::FileFormat::PARQUET); + + EXPECT_EQ(split.getFileName(), "part-00000.parquet"); +} + +TEST(FileConnectorSplitTest, getFileNameNoSlash) { + FileConnectorSplit split( + "connectorId", "file.orc", dwio::common::FileFormat::ORC); + + EXPECT_EQ(split.getFileName(), "file.orc"); +} + +TEST(FileConnectorSplitTest, fileProperties) { + FileProperties props = {.fileSize = 1024, .modificationTime = 999}; + FileConnectorSplit split( + "connectorId", + "/file.parquet", + dwio::common::FileFormat::PARQUET, + 0, + std::numeric_limits::max(), + 0, + true, + props); + + ASSERT_TRUE(split.properties.has_value()); + EXPECT_EQ(split.properties->fileSize.value(), 1024); + EXPECT_EQ(split.properties->modificationTime.value(), 999); +} diff --git a/velox/connectors/hive/tests/FileConnectorUtilTest.cpp b/velox/connectors/hive/tests/FileConnectorUtilTest.cpp new file mode 100644 index 00000000000..ef70a1e2e34 --- /dev/null +++ b/velox/connectors/hive/tests/FileConnectorUtilTest.cpp @@ -0,0 +1,481 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/FileConnectorUtil.h" + +#include + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/io/IoStatistics.h" +#include "velox/connectors/hive/FileConfig.h" +#include "velox/connectors/hive/FileConnectorSplit.h" +#include "velox/connectors/hive/TableHandle.h" +#include "velox/dwio/dwrf/reader/DwrfReader.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/type/Filter.h" + +namespace facebook::velox::connector { + +class FileConnectorUtilTest : public exec::test::HiveConnectorTestBase { + protected: + struct QueryCtxHolder { + std::shared_ptr sessionProperties; + std::unique_ptr ctx; + }; + + QueryCtxHolder makeConnectorQueryCtx( + std::unordered_map sessionProps = {}) { + QueryCtxHolder holder; + holder.sessionProperties = + std::make_shared(std::move(sessionProps), true); + holder.ctx = std::make_unique( + pool_.get(), + pool_.get(), + holder.sessionProperties.get(), + nullptr, + common::PrefixSortConfig(), + nullptr, + nullptr, + "query.FileConnectorUtilTest", + "task.FileConnectorUtilTest", + "planNodeId.FileConnectorUtilTest", + 0, + ""); + return holder; + } + + std::shared_ptr makeFileConfig( + std::unordered_map props = {}) { + return std::make_shared( + std::make_shared(std::move(props))); + } + + std::shared_ptr makeSplit( + dwio::common::FileFormat format = dwio::common::FileFormat::DWRF, + const std::string& path = "/tmp/testfile") { + return std::make_shared( + "testConnectorId", path, format); + } + + std::string writeDataFile(const RowVectorPtr& data) { + auto path = exec::test::TempFilePath::create(); + auto filePath = path->getPath(); + tempPaths_.push_back(std::move(path)); + writeToFile(filePath, {data}); + return filePath; + } + + std::unique_ptr makeReader(const std::string& path) { + dwio::common::ReaderOptions readerOpts(pool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); + readerOpts.setFileFormat(dwio::common::FileFormat::DWRF); + auto readFile = std::make_shared(path); + auto input = std::make_unique( + std::move(readFile), readerOpts.memoryPool()); + return dwrf::DwrfReader::create(std::move(input), readerOpts); + } + + std::shared_ptr dataIoStats_ = + std::make_shared(); + std::shared_ptr metadataIoStats_ = + std::make_shared(); + + private: + std::vector> tempPaths_; +}; + +TEST_F(FileConnectorUtilTest, configureReaderOptions) { + auto fileConfig = makeFileConfig(); + + // Test with DWRF format. + { + auto holder = makeConnectorQueryCtx(); + auto split = makeSplit(dwio::common::FileFormat::DWRF); + dwio::common::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + hive::configureReaderOptions( + fileConfig, + holder.ctx.get(), + /*fileSchema=*/nullptr, + split, + /*tableParameters=*/{}, + readerOptions); + + EXPECT_EQ(readerOptions.fileFormat(), dwio::common::FileFormat::DWRF); + EXPECT_FALSE(readerOptions.fileColumnNamesReadAsLowerCase()); + EXPECT_FALSE(readerOptions.useColumnNamesForColumnMapping()); + } + + // Test with ORC format and useColumnNames enabled via session. + { + auto holder = makeConnectorQueryCtx( + {{hive::FileConfig::kOrcUseColumnNamesSession, "true"}}); + auto split = makeSplit(dwio::common::FileFormat::ORC); + dwio::common::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + hive::configureReaderOptions( + fileConfig, + holder.ctx.get(), + /*fileSchema=*/nullptr, + split, + /*tableParameters=*/{}, + readerOptions); + + EXPECT_EQ(readerOptions.fileFormat(), dwio::common::FileFormat::ORC); + EXPECT_TRUE(readerOptions.useColumnNamesForColumnMapping()); + } + + // Test with Parquet format and useColumnNames enabled via session. + { + auto holder = makeConnectorQueryCtx( + {{hive::FileConfig::kParquetUseColumnNamesSession, "true"}}); + auto split = makeSplit(dwio::common::FileFormat::PARQUET); + dwio::common::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + hive::configureReaderOptions( + fileConfig, + holder.ctx.get(), + /*fileSchema=*/nullptr, + split, + /*tableParameters=*/{}, + readerOptions); + + EXPECT_EQ(readerOptions.fileFormat(), dwio::common::FileFormat::PARQUET); + EXPECT_TRUE(readerOptions.useColumnNamesForColumnMapping()); + } + + // Test format mismatch throws. + { + auto holder = makeConnectorQueryCtx(); + auto split = makeSplit(dwio::common::FileFormat::DWRF); + dwio::common::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setFileFormat(dwio::common::FileFormat::PARQUET); + VELOX_ASSERT_THROW( + hive::configureReaderOptions( + fileConfig, + holder.ctx.get(), + /*fileSchema=*/nullptr, + split, + /*tableParameters=*/{}, + readerOptions), + "received splits of different formats"); + } +} + +TEST_F(FileConnectorUtilTest, configureRowReaderOptions) { + auto holder = makeConnectorQueryCtx(); + auto fileConfig = makeFileConfig(); + auto split = makeSplit(dwio::common::FileFormat::DWRF); + auto scanSpec = std::make_shared(""); + auto rowType = ROW({"c0"}, {BIGINT()}); + + dwio::common::RowReaderOptions rowReaderOptions; + hive::configureRowReaderOptions( + /*tableParameters=*/{}, + scanSpec, + /*metadataFilter=*/nullptr, + rowType, + split, + fileConfig, + holder.ctx->sessionProperties(), + /*ioExecutor=*/nullptr, + rowReaderOptions); + + EXPECT_EQ(rowReaderOptions.scanSpec(), scanSpec); + EXPECT_EQ(rowReaderOptions.offset(), 0); + EXPECT_EQ(rowReaderOptions.length(), std::numeric_limits::max()); +} + +TEST_F(FileConnectorUtilTest, configureRowReaderOptionsSkipRows) { + auto holder = makeConnectorQueryCtx(); + auto fileConfig = makeFileConfig(); + auto split = makeSplit(dwio::common::FileFormat::DWRF); + auto scanSpec = std::make_shared(""); + auto rowType = ROW({"c0"}, {BIGINT()}); + + std::unordered_map tableParameters = { + {dwio::common::TableParameter::kSkipHeaderLineCount, "5"}, + }; + + dwio::common::RowReaderOptions rowReaderOptions; + hive::configureRowReaderOptions( + tableParameters, + scanSpec, + /*metadataFilter=*/nullptr, + rowType, + split, + fileConfig, + holder.ctx->sessionProperties(), + /*ioExecutor=*/nullptr, + rowReaderOptions); + + EXPECT_EQ(rowReaderOptions.skipRows(), 5); +} + +TEST_F(FileConnectorUtilTest, configureRowReaderOptionsSplitRange) { + auto holder = makeConnectorQueryCtx(); + auto fileConfig = makeFileConfig(); + auto split = std::make_shared( + "testConnectorId", + "/tmp/testfile", + dwio::common::FileFormat::DWRF, + /*start=*/100, + /*length=*/5000); + auto scanSpec = std::make_shared(""); + auto rowType = ROW({"c0"}, {BIGINT()}); + + dwio::common::RowReaderOptions rowReaderOptions; + hive::configureRowReaderOptions( + /*tableParameters=*/{}, + scanSpec, + /*metadataFilter=*/nullptr, + rowType, + split, + fileConfig, + holder.ctx->sessionProperties(), + /*ioExecutor=*/nullptr, + rowReaderOptions); + + EXPECT_EQ(rowReaderOptions.offset(), 100); + EXPECT_EQ(rowReaderOptions.length(), 5000); +} + +TEST_F(FileConnectorUtilTest, testFiltersNoFilters) { + auto rowType = ROW({"c0"}, {BIGINT()}); + auto batch = + makeRowVector({"c0"}, {makeFlatVector(100, folly::identity)}); + auto filePath = writeDataFile(batch); + auto reader = makeReader(filePath); + + auto scanSpec = std::make_shared(""); + scanSpec->addField("c0", 0); + + EXPECT_TRUE( + hive::testFilters( + scanSpec.get(), + reader.get(), + filePath, + /*partitionKey=*/{}, + /*partitionKeysHandle=*/{}, + /*asLocalTime=*/false)); +} + +TEST_F(FileConnectorUtilTest, testFiltersPartitionKeyPasses) { + auto rowType = ROW({"c0"}, {BIGINT()}); + auto batch = + makeRowVector({"c0"}, {makeFlatVector(100, folly::identity)}); + auto filePath = writeDataFile(batch); + auto reader = makeReader(filePath); + + auto scanSpec = std::make_shared(""); + scanSpec->addField("c0", 0); + auto* dsSpec = scanSpec->addField("ds", 1); + dsSpec->setFilter( + std::make_unique( + std::vector{"2024-01-01"}, false)); + + std::unordered_map> partitionKeys = { + {"ds", "2024-01-01"}, + }; + + auto dsHandle = std::make_shared( + "ds", + hive::HiveColumnHandle::ColumnType::kPartitionKey, + VARCHAR(), + VARCHAR()); + std::unordered_map + partitionKeysHandle = { + {"ds", dsHandle}, + }; + + EXPECT_TRUE( + hive::testFilters( + scanSpec.get(), + reader.get(), + filePath, + partitionKeys, + partitionKeysHandle, + /*asLocalTime=*/false)); +} + +TEST_F(FileConnectorUtilTest, testFiltersPartitionKeyFails) { + auto rowType = ROW({"c0"}, {BIGINT()}); + auto batch = + makeRowVector({"c0"}, {makeFlatVector(100, folly::identity)}); + auto filePath = writeDataFile(batch); + auto reader = makeReader(filePath); + + auto scanSpec = std::make_shared(""); + scanSpec->addField("c0", 0); + auto* dsSpec = scanSpec->addField("ds", 1); + dsSpec->setFilter( + std::make_unique( + std::vector{"2024-01-01"}, false)); + + std::unordered_map> partitionKeys = { + {"ds", "2024-02-15"}, + }; + + auto dsHandle = std::make_shared( + "ds", + hive::HiveColumnHandle::ColumnType::kPartitionKey, + VARCHAR(), + VARCHAR()); + std::unordered_map + partitionKeysHandle = { + {"ds", dsHandle}, + }; + + EXPECT_FALSE( + hive::testFilters( + scanSpec.get(), + reader.get(), + filePath, + partitionKeys, + partitionKeysHandle, + /*asLocalTime=*/false)); +} + +TEST_F(FileConnectorUtilTest, testFiltersNullPartitionKeyRejectsNotNull) { + auto rowType = ROW({"c0"}, {BIGINT()}); + auto batch = + makeRowVector({"c0"}, {makeFlatVector(100, folly::identity)}); + auto filePath = writeDataFile(batch); + auto reader = makeReader(filePath); + + auto scanSpec = std::make_shared(""); + scanSpec->addField("c0", 0); + auto* dsSpec = scanSpec->addField("ds", 1); + dsSpec->setFilter( + std::make_unique( + std::vector{"2024-01-01"}, false)); + + std::unordered_map> partitionKeys = { + {"ds", std::nullopt}, + }; + + auto dsHandle = std::make_shared( + "ds", + hive::HiveColumnHandle::ColumnType::kPartitionKey, + VARCHAR(), + VARCHAR()); + std::unordered_map + partitionKeysHandle = { + {"ds", dsHandle}, + }; + + EXPECT_FALSE( + hive::testFilters( + scanSpec.get(), + reader.get(), + filePath, + partitionKeys, + partitionKeysHandle, + /*asLocalTime=*/false)); +} + +TEST_F(FileConnectorUtilTest, testFiltersIntegerPartitionKey) { + auto rowType = ROW({"c0"}, {BIGINT()}); + auto batch = + makeRowVector({"c0"}, {makeFlatVector(100, folly::identity)}); + auto filePath = writeDataFile(batch); + auto reader = makeReader(filePath); + + auto scanSpec = std::make_shared(""); + scanSpec->addField("c0", 0); + auto* yearSpec = scanSpec->addField("year", 1); + yearSpec->setFilter(std::make_unique(2024, 2024, false)); + + // Matching partition value. + { + std::unordered_map> partitionKeys = + {{"year", "2024"}}; + auto yearHandle = std::make_shared( + "year", + hive::HiveColumnHandle::ColumnType::kPartitionKey, + BIGINT(), + BIGINT()); + std::unordered_map + partitionKeysHandle = {{"year", yearHandle}}; + + EXPECT_TRUE( + hive::testFilters( + scanSpec.get(), + reader.get(), + filePath, + partitionKeys, + partitionKeysHandle, + /*asLocalTime=*/false)); + } + + // Non-matching partition value. + { + std::unordered_map> partitionKeys = + {{"year", "2023"}}; + auto yearHandle = std::make_shared( + "year", + hive::HiveColumnHandle::ColumnType::kPartitionKey, + BIGINT(), + BIGINT()); + std::unordered_map + partitionKeysHandle = {{"year", yearHandle}}; + + EXPECT_FALSE( + hive::testFilters( + scanSpec.get(), + reader.get(), + filePath, + partitionKeys, + partitionKeysHandle, + /*asLocalTime=*/false)); + } +} + +TEST_F(FileConnectorUtilTest, testFiltersMissingColumn) { + auto rowType = ROW({"c0"}, {BIGINT()}); + auto batch = + makeRowVector({"c0"}, {makeFlatVector(100, folly::identity)}); + auto filePath = writeDataFile(batch); + auto reader = makeReader(filePath); + + auto scanSpec = std::make_shared(""); + scanSpec->addField("c0", 0); + // Filter on a column that doesn't exist in the file and is not a partition + // key. This simulates schema evolution where the column was added later. + auto* newColSpec = scanSpec->addField("newCol", 1); + // Filter that rejects null -- should cause the split to be skipped since + // the column is missing (will be all nulls). + newColSpec->setFilter( + std::make_unique( + std::vector{"someValue"}, false)); + + EXPECT_FALSE( + hive::testFilters( + scanSpec.get(), + reader.get(), + filePath, + /*partitionKey=*/{}, + /*partitionKeysHandle=*/{}, + /*asLocalTime=*/false)); +} + +} // namespace facebook::velox::connector diff --git a/velox/connectors/hive/tests/FileHandleTest.cpp b/velox/connectors/hive/tests/FileHandleTest.cpp index 6c3e71c42fe..f0320d715f1 100644 --- a/velox/connectors/hive/tests/FileHandleTest.cpp +++ b/velox/connectors/hive/tests/FileHandleTest.cpp @@ -20,14 +20,15 @@ #include "velox/common/caching/SimpleLRUCache.h" #include "velox/common/file/File.h" #include "velox/common/file/FileSystems.h" -#include "velox/exec/tests/utils/TempFilePath.h" +#include "velox/common/testutil/TempFilePath.h" using namespace facebook::velox; +using namespace facebook::velox::common::testutil; TEST(FileHandleTest, localFile) { filesystems::registerLocalFileSystem(); - auto tempFile = exec::test::TempFilePath::create(); + auto tempFile = TempFilePath::create(); const auto& filename = tempFile->getPath(); remove(filename.c_str()); @@ -52,7 +53,7 @@ TEST(FileHandleTest, localFile) { TEST(FileHandleTest, localFileWithProperties) { filesystems::registerLocalFileSystem(); - auto tempFile = exec::test::TempFilePath::create(); + auto tempFile = TempFilePath::create(); const auto& filename = tempFile->getPath(); remove(filename.c_str()); diff --git a/velox/connectors/hive/tests/HiveConfigTest.cpp b/velox/connectors/hive/tests/HiveConfigTest.cpp index ef6262cd4b2..6a58752c134 100644 --- a/velox/connectors/hive/tests/HiveConfigTest.cpp +++ b/velox/connectors/hive/tests/HiveConfigTest.cpp @@ -23,8 +23,9 @@ using namespace facebook::velox::connector::hive; using facebook::velox::connector::hive::HiveConfig; TEST(HiveConfigTest, defaultConfig) { - HiveConfig hiveConfig(std::make_shared( - std::unordered_map())); + HiveConfig hiveConfig( + std::make_shared( + std::unordered_map())); const auto emptySession = std::make_unique( std::unordered_map()); ASSERT_EQ( @@ -54,6 +55,17 @@ TEST(HiveConfigTest, defaultConfig) { ASSERT_TRUE(hiveConfig.allowNullPartitionKeys(emptySession.get())); ASSERT_EQ(hiveConfig.loadQuantum(emptySession.get()), 8 << 20); ASSERT_FALSE(hiveConfig.preserveFlatMapsInMemory(emptySession.get())); + ASSERT_FALSE(hiveConfig.indexEnabled(emptySession.get())); + ASSERT_FALSE(hiveConfig.fileMetadataCacheEnabled(emptySession.get())); + ASSERT_EQ( + hiveConfig.orcFooterSpeculativeIoSize(emptySession.get()), 256UL << 10); + ASSERT_EQ( + hiveConfig.parquetFooterSpeculativeIoSize(emptySession.get()), + 256UL << 10); + ASSERT_EQ( + hiveConfig.nimbleFooterSpeculativeIoSize(emptySession.get()), 8UL << 20); + ASSERT_FALSE(hiveConfig.nimbleStringDecoderZeroCopy(emptySession.get())); + ASSERT_FALSE(hiveConfig.nimblePreserveDictionaryEncoding(emptySession.get())); } TEST(HiveConfigTest, overrideConfig) { @@ -77,7 +89,15 @@ TEST(HiveConfigTest, overrideConfig) { {HiveConfig::kReadStatsBasedFilterReorderDisabled, "true"}, {HiveConfig::kLoadQuantum, std::to_string(4 << 20)}, {HiveConfig::kMaxBucketCount, std::to_string(100'000)}, - {HiveConfig::kPreserveFlatMapsInMemory, "true"}}; + {HiveConfig::kPreserveFlatMapsInMemory, "true"}, + {HiveConfig::kIndexEnabled, "true"}, + {HiveConfig::kFileMetadataCacheEnabled, "true"}, + {HiveConfig::kOrcFooterSpeculativeIoSize, std::to_string(512UL << 10)}, + {HiveConfig::kParquetFooterSpeculativeIoSize, std::to_string(1UL << 20)}, + {HiveConfig::kNimbleFooterSpeculativeIoSize, std::to_string(4UL << 20)}, + {HiveConfig::kNimbleStringDecoderZeroCopy, "true"}, + {HiveConfig::kNimblePreserveDictionaryEncoding, "true"}, + }; HiveConfig hiveConfig( std::make_shared(std::move(configFromFile))); auto emptySession = std::make_shared( @@ -109,11 +129,22 @@ TEST(HiveConfigTest, overrideConfig) { ASSERT_EQ(hiveConfig.loadQuantum(emptySession.get()), 4 << 20); ASSERT_EQ(hiveConfig.maxBucketCount(emptySession.get()), 100'000); ASSERT_TRUE(hiveConfig.preserveFlatMapsInMemory(emptySession.get())); + ASSERT_TRUE(hiveConfig.indexEnabled(emptySession.get())); + ASSERT_TRUE(hiveConfig.fileMetadataCacheEnabled(emptySession.get())); + ASSERT_EQ( + hiveConfig.orcFooterSpeculativeIoSize(emptySession.get()), 512UL << 10); + ASSERT_EQ( + hiveConfig.parquetFooterSpeculativeIoSize(emptySession.get()), 1UL << 20); + ASSERT_EQ( + hiveConfig.nimbleFooterSpeculativeIoSize(emptySession.get()), 4UL << 20); + ASSERT_TRUE(hiveConfig.nimbleStringDecoderZeroCopy(emptySession.get())); + ASSERT_TRUE(hiveConfig.nimblePreserveDictionaryEncoding(emptySession.get())); } TEST(HiveConfigTest, overrideSession) { - HiveConfig hiveConfig(std::make_shared( - std::unordered_map())); + HiveConfig hiveConfig( + std::make_shared( + std::unordered_map())); std::unordered_map sessionOverride = { {HiveConfig::kInsertExistingPartitionsBehaviorSession, "OVERWRITE"}, {HiveConfig::kOrcUseColumnNamesSession, "true"}, @@ -128,6 +159,16 @@ TEST(HiveConfigTest, overrideSession) { {HiveConfig::kReadStatsBasedFilterReorderDisabledSession, "true"}, {HiveConfig::kLoadQuantumSession, std::to_string(4 << 20)}, {HiveConfig::kPreserveFlatMapsInMemorySession, "true"}, + {HiveConfig::kIndexEnabledSession, "true"}, + {HiveConfig::kFileMetadataCacheEnabledSession, "true"}, + {HiveConfig::kOrcFooterSpeculativeIoSizeSession, + std::to_string(128UL << 10)}, + {HiveConfig::kParquetFooterSpeculativeIoSizeSession, + std::to_string(512UL << 10)}, + {HiveConfig::kNimbleFooterSpeculativeIoSizeSession, + std::to_string(2UL << 20)}, + {HiveConfig::kNimbleStringDecoderZeroCopySession, "true"}, + {HiveConfig::kNimblePreserveDictionaryEncodingSession, "true"}, }; const auto session = std::make_unique(std::move(sessionOverride)); @@ -155,4 +196,12 @@ TEST(HiveConfigTest, overrideSession) { ASSERT_TRUE(hiveConfig.readStatsBasedFilterReorderDisabled(session.get())); ASSERT_EQ(hiveConfig.loadQuantum(session.get()), 4 << 20); ASSERT_TRUE(hiveConfig.preserveFlatMapsInMemory(session.get())); + ASSERT_TRUE(hiveConfig.indexEnabled(session.get())); + ASSERT_TRUE(hiveConfig.fileMetadataCacheEnabled(session.get())); + ASSERT_EQ(hiveConfig.orcFooterSpeculativeIoSize(session.get()), 128UL << 10); + ASSERT_EQ( + hiveConfig.parquetFooterSpeculativeIoSize(session.get()), 512UL << 10); + ASSERT_EQ(hiveConfig.nimbleFooterSpeculativeIoSize(session.get()), 2UL << 20); + ASSERT_TRUE(hiveConfig.nimbleStringDecoderZeroCopy(session.get())); + ASSERT_TRUE(hiveConfig.nimblePreserveDictionaryEncoding(session.get())); } diff --git a/velox/connectors/hive/tests/HiveConnectorSerDeTest.cpp b/velox/connectors/hive/tests/HiveConnectorSerDeTest.cpp index 345050cc55e..9bec9c0c5a7 100644 --- a/velox/connectors/hive/tests/HiveConnectorSerDeTest.cpp +++ b/velox/connectors/hive/tests/HiveConnectorSerDeTest.cpp @@ -115,26 +115,44 @@ TEST_F(HiveConnectorSerDeTest, hiveTableHandle) { auto rowType = ROW({"c0c0", "c1", "c2", "c3", "c4", "c5"}, {INTEGER(), BIGINT(), DOUBLE(), BOOLEAN(), BIGINT(), VARCHAR()}); - auto tableHandle = makeTableHandle( - common::test::SubfieldFiltersBuilder() - .add("c0.c0", isNotNull()) - .add( - "c1", - lessThanOrEqualHugeint(std::numeric_limits::max())) - .add("c2", greaterThanOrEqualDouble(3.1415)) - .add("c3", boolEqual(true)) - .add("c4", in({0xdeadbeaf, 0xcafecafe})) - .add("c2", notIn({0xdeadbeaf, 0xcafecafe})) - .add( - "c5", - orFilter(between("abc", "efg"), greaterThanOrEqual("dragon"))) - .build(), - parseExpr("c1 > c4 and c3 = true", rowType), - "hive_table", - ROW({"c0", "c1"}, {BIGINT(), VARCHAR()}), - true, - {{dwio::common::TableParameter::kSkipHeaderLineCount, "1"}}); - testSerde(*tableHandle); + + for (bool withIndexColumns : {false, true}) { + SCOPED_TRACE(fmt::format("withIndexColumns: {}", withIndexColumns)); + + std::vector indexColumns = withIndexColumns + ? std::vector{"c0c0", "c1"} + : std::vector{}; + + auto tableHandle = makeTableHandle( + common::test::SubfieldFiltersBuilder() + .add("c0.c0", isNotNull()) + .add( + "c1", + lessThanOrEqualHugeint(std::numeric_limits::max())) + .add("c2", greaterThanOrEqualDouble(3.1415)) + .add("c3", boolEqual(true)) + .add("c4", in({0xdeadbeaf, 0xcafecafe})) + .add("c2", notIn({0xdeadbeaf, 0xcafecafe})) + .add( + "c5", + orFilter(between("abc", "efg"), greaterThanOrEqual("dragon"))) + .build(), + parseExpr("c1 > c4 and c3 = true", rowType), + "hive_table", + ROW({"c0", "c1"}, {BIGINT(), VARCHAR()}), + indexColumns, + {{dwio::common::TableParameter::kSkipHeaderLineCount, "1"}}); + + EXPECT_EQ(tableHandle->supportsIndexLookup(), withIndexColumns); + EXPECT_TRUE(tableHandle->needsIndexSplit()); + EXPECT_EQ(tableHandle->indexColumns().empty(), !withIndexColumns); + if (withIndexColumns) { + EXPECT_EQ(tableHandle->indexColumns().size(), 2); + EXPECT_EQ(tableHandle->indexColumns()[0], "c0c0"); + EXPECT_EQ(tableHandle->indexColumns()[1], "c1"); + } + testSerde(*tableHandle); + } } TEST_F(HiveConnectorSerDeTest, hiveColumnHandle) { @@ -150,10 +168,11 @@ TEST_F(HiveConnectorSerDeTest, hiveColumnHandle) { }); auto columnHandleTypes = { - HiveColumnHandle::ColumnType::kPartitionKey, - HiveColumnHandle::ColumnType::kRegular, - HiveColumnHandle::ColumnType::kSynthesized, - HiveColumnHandle::ColumnType::kRowIndex, + FileColumnHandle::ColumnType::kPartitionKey, + FileColumnHandle::ColumnType::kRegular, + FileColumnHandle::ColumnType::kSynthesized, + FileColumnHandle::ColumnType::kRowIndex, + FileColumnHandle::ColumnType::kRowId, }; for (auto columnHandleType : columnHandleTypes) { @@ -211,6 +230,11 @@ TEST_F(HiveConnectorSerDeTest, hiveInsertTableHandle) { {"key2", "value2"}, }; + std::unordered_map storageParameters = { + {"key3", "value3"}, + {"key4", "value4"}, + }; + auto hiveInsertTableHandle = exec::test::HiveConnectorTestBase::makeHiveInsertTableHandle( tableColumnNames, @@ -220,7 +244,10 @@ TEST_F(HiveConnectorSerDeTest, hiveInsertTableHandle) { locationHandle, dwio::common::FileFormat::NIMBLE, common::CompressionKind::CompressionKind_SNAPPY, - serdeParameters); + serdeParameters, + nullptr, // writerOptions + false, // ensureFiles + storageParameters); testSerde(*hiveInsertTableHandle); } diff --git a/velox/connectors/hive/tests/HiveConnectorTest.cpp b/velox/connectors/hive/tests/HiveConnectorTest.cpp index 6fdc8d130c9..180e347b4b6 100644 --- a/velox/connectors/hive/tests/HiveConnectorTest.cpp +++ b/velox/connectors/hive/tests/HiveConnectorTest.cpp @@ -18,9 +18,15 @@ #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/file/FileSystems.h" +#include "velox/connectors/hive/ExtractionUtils.h" +#include "velox/connectors/hive/FileHandle.h" #include "velox/connectors/hive/HiveConfig.h" #include "velox/connectors/hive/HiveConnectorUtil.h" #include "velox/connectors/hive/HiveDataSource.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/expression/ExprConstants.h" #include "velox/expression/ExprToSubfieldFilter.h" namespace facebook::velox::connector::hive { @@ -92,17 +98,20 @@ TEST_F(HiveConnectorTest, makeScanSpecRequiredSubfieldsMultilevel) { auto rowType = ROW({{"c0", columnType}}); auto subfields = makeSubfields({"c0.c0c1[3][\"foo\"].c0c1c0"}); for (bool statsBasedFilterReorderDisabled : {false, true}) { - SCOPED_TRACE(fmt::format( - "statsBasedFilterReorderDisabled {}", statsBasedFilterReorderDisabled)); + SCOPED_TRACE( + fmt::format( + "statsBasedFilterReorderDisabled {}", + statsBasedFilterReorderDisabled)); auto scanSpec = makeScanSpec( rowType, groupSubfields(subfields), - {}, - nullptr, - {}, - {}, - {}, + /*subfieldFilters=*/{}, + /*indexColumns=*/{}, + /*dataColumns=*/nullptr, + /*partitionKeys=*/{}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, statsBasedFilterReorderDisabled, pool_.get()); ASSERT_EQ( @@ -141,12 +150,13 @@ TEST_F(HiveConnectorTest, makeScanSpecRequiredSubfieldsMergeFields) { rowType, groupSubfields(makeSubfields( {"c0.c0c0.c0c0c0", "c0.c0c0.c0c0c2", "c0.c0c1", "c0.c0c1.c0c1c0"})), - {}, - nullptr, - {}, - {}, - {}, - false, + /*subfieldFilters=*/{}, + /*indexColumns=*/{}, + /*dataColumns=*/nullptr, + /*partitionKeys=*/{}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + /*disableStatsBasedFilterReorder=*/false, pool_.get()); auto* c0c0 = scanSpec->childByName("c0")->childByName("c0c0"); ASSERT_FALSE(c0c0->childByName("c0c0c0")->isConstant()); @@ -166,12 +176,13 @@ TEST_F(HiveConnectorTest, makeScanSpecRequiredSubfieldsMergeArray) { auto scanSpec = makeScanSpec( rowType, groupSubfields(makeSubfields({"c0[1].c0c0", "c0[2].c0c2"})), - {}, - nullptr, - {}, - {}, - {}, - false, + /*subfieldFilters=*/{}, + /*indexColumns=*/{}, + /*dataColumns=*/nullptr, + /*partitionKeys=*/{}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + /*disableStatsBasedFilterReorder=*/false, pool_.get()); auto* c0 = scanSpec->childByName("c0"); ASSERT_EQ(c0->maxArrayElementsCount(), 2); @@ -192,12 +203,13 @@ TEST_F(HiveConnectorTest, makeScanSpecRequiredSubfieldsMergeArrayNegative) { makeScanSpec( rowType, groupedSubfields, - {}, - nullptr, - {}, - {}, - {}, - false, + /*subfieldFilters=*/{}, + /*indexColumns=*/{}, + /*dataColumns=*/nullptr, + /*partitionKeys=*/{}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + /*disableStatsBasedFilterReorder=*/false, pool_.get()), "Non-positive array subscript cannot be push down"); } @@ -210,12 +222,13 @@ TEST_F(HiveConnectorTest, makeScanSpecRequiredSubfieldsMergeMap) { auto scanSpec = makeScanSpec( rowType, groupSubfields(makeSubfields({"c0[10].c0c0", "c0[20].c0c2"})), - {}, - nullptr, - {}, - {}, - {}, - false, + /*subfieldFilters=*/{}, + /*indexColumns=*/{}, + /*dataColumns=*/nullptr, + /*partitionKeys=*/{}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + /*disableStatsBasedFilterReorder=*/false, pool_.get()); auto* c0 = scanSpec->childByName("c0"); ASSERT_EQ( @@ -241,12 +254,13 @@ TEST_F( auto scanSpec = makeScanSpec( rowType, groupSubfields(makeSubfields({"c0[\"foo\"]"})), - {}, - nullptr, - {}, - {}, - {}, - false, + /*subfieldFilters=*/{}, + /*indexColumns=*/{}, + /*dataColumns=*/nullptr, + /*partitionKeys=*/{}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + /*disableStatsBasedFilterReorder=*/false, pool_.get()); auto* c0 = scanSpec->childByName("c0"); ASSERT_EQ(c0->flatMapFeatureSelection(), std::vector({"foo"})); @@ -263,12 +277,13 @@ TEST_F(HiveConnectorTest, makeScanSpecRequiredSubfieldsAllSubscripts) { auto scanSpec = makeScanSpec( rowType, groupSubfields(makeSubfields({path})), - {}, - nullptr, - {}, - {}, - {}, - false, + /*subfieldFilters=*/{}, + /*indexColumns=*/{}, + /*dataColumns=*/nullptr, + /*partitionKeys=*/{}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + /*disableStatsBasedFilterReorder=*/false, pool_.get()); auto* c0 = scanSpec->childByName("c0"); ASSERT_TRUE(c0->flatMapFeatureSelection().empty()); @@ -285,12 +300,13 @@ TEST_F(HiveConnectorTest, makeScanSpecRequiredSubfieldsAllSubscripts) { auto scanSpec = makeScanSpec( rowType, groupSubfields(makeSubfields({"c0[*][*].c0c0"})), - {}, - nullptr, - {}, - {}, - {}, - false, + /*subfieldFilters=*/{}, + /*indexColumns=*/{}, + /*dataColumns=*/nullptr, + /*partitionKeys=*/{}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + /*disableStatsBasedFilterReorder=*/false, pool_.get()); auto* c0 = scanSpec->childByName("c0"); ASSERT_TRUE(mapKeyIsNotNull(*c0)); @@ -310,12 +326,13 @@ TEST_F(HiveConnectorTest, makeScanSpecRequiredSubfieldsDoubleMapKey) { auto scanSpec = makeScanSpec( rowType, groupSubfields(makeSubfields({"c0[0]", "c1[-1]"})), - {}, - nullptr, - {}, - {}, - {}, - false, + /*subfieldFilters=*/{}, + /*indexColumns=*/{}, + /*dataColumns=*/nullptr, + /*partitionKeys=*/{}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + /*disableStatsBasedFilterReorder=*/false, pool_.get()); auto* keysFilter = scanSpec->childByName("c0") ->childByName(ScanSpec::kMapKeysFieldName) @@ -340,12 +357,13 @@ TEST_F(HiveConnectorTest, makeScanSpecRequiredSubfieldsDoubleMapKey) { rowType, groupSubfields(makeSubfields( {"c0[-9223372036854775808]", "c1[9223372036854775807]"})), - {}, - nullptr, - {}, - {}, - {}, - false, + /*subfieldFilters=*/{}, + /*indexColumns=*/{}, + /*dataColumns=*/nullptr, + /*partitionKeys=*/{}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + /*disableStatsBasedFilterReorder=*/false, pool_.get()); keysFilter = scanSpec->childByName("c0") ->childByName(ScanSpec::kMapKeysFieldName) @@ -361,12 +379,13 @@ TEST_F(HiveConnectorTest, makeScanSpecRequiredSubfieldsDoubleMapKey) { rowType, groupSubfields(makeSubfields( {"c0[9223372036854775807]", "c0[-9223372036854775808]"})), - {}, - nullptr, - {}, - {}, - {}, - false, + /*subfieldFilters=*/{}, + /*indexColumns=*/{}, + /*dataColumns=*/nullptr, + /*partitionKeys=*/{}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + /*disableStatsBasedFilterReorder=*/false, pool_.get()); keysFilter = scanSpec->childByName("c0") ->childByName(ScanSpec::kMapKeysFieldName) @@ -379,12 +398,13 @@ TEST_F(HiveConnectorTest, makeScanSpecRequiredSubfieldsDoubleMapKey) { scanSpec = makeScanSpec( rowType, groupSubfields(makeSubfields({"c0[-100000000]", "c0[100000000]"})), - {}, - nullptr, - {}, - {}, - {}, - false, + /*subfieldFilters=*/{}, + /*indexColumns=*/{}, + /*dataColumns=*/nullptr, + /*partitionKeys=*/{}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + /*disableStatsBasedFilterReorder=*/false, pool_.get()); keysFilter = scanSpec->childByName("c0") ->childByName(ScanSpec::kMapKeysFieldName) @@ -422,11 +442,12 @@ TEST_F(HiveConnectorTest, makeScanSpecRequiredSubfieldsOnlyInFilters) { readerOutputType, groupSubfields(makeSubfields({"c0.c0c1", "c0.c0c3"})), filters, - ROW({{"c0", c0Type}, {"c1", c1Type}}), - {}, - {}, - {}, - false, + /*indexColumns=*/{}, + /*dataColumns=*/ROW({{"c0", c0Type}, {"c1", c1Type}}), + /*partitionKeys=*/{}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + /*disableStatsBasedFilterReorder=*/false, pool_.get()); auto c0 = scanSpec->childByName("c0"); @@ -503,12 +524,13 @@ TEST_F(HiveConnectorTest, makeScanSpecDuplicateSubfields) { rowType, groupSubfields(makeSubfields( {"c0[10][1]", "c0[10][2]", "c1[\"foo\"][1]", "c1[\"foo\"][2]"})), - {}, - nullptr, - {}, - {}, - {}, - false, + /*subfieldFilters=*/{}, + /*indexColumns=*/{}, + /*dataColumns=*/nullptr, + /*partitionKeys=*/{}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + /*disableStatsBasedFilterReorder=*/false, pool_.get()); auto* c0 = scanSpec->childByName("c0"); ASSERT_EQ(c0->children().size(), 2); @@ -523,13 +545,14 @@ TEST_F(HiveConnectorTest, makeScanSpecFilterPartitionKey) { filters.emplace(Subfield("ds"), exec::equal("2023-10-13")); auto scanSpec = makeScanSpec( rowType, - {}, + /*outputSubfields=*/{}, filters, - rowType, - {{"ds", nullptr}}, - {}, - {}, - false, + /*indexColumns=*/{}, + /*dataColumns=*/rowType, + /*partitionKeys=*/{{"ds", nullptr}}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + /*disableStatsBasedFilterReorder=*/false, pool_.get()); ASSERT_TRUE(scanSpec->childByName("c0")->projectOut()); ASSERT_FALSE(scanSpec->childByName("ds")->projectOut()); @@ -544,12 +567,13 @@ TEST_F(HiveConnectorTest, makeScanSpecPrunedMapNonNullMapKey) { auto scanSpec = makeScanSpec( rowType, groupSubfields(makeSubfields({"c0.c0c1"})), - {}, - nullptr, - {}, - {}, - {}, - false, + /*subfieldFilters=*/{}, + /*indexColumns=*/{}, + /*dataColumns=*/nullptr, + /*partitionKeys=*/{}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + /*disableStatsBasedFilterReorder=*/false, pool_.get()); auto* c0 = scanSpec->childByName("c0"); ASSERT_EQ(c0->children().size(), 2); @@ -560,12 +584,13 @@ TEST_F(HiveConnectorTest, makeScanSpecPrunedMapNonNullMapKey) { scanSpec = makeScanSpec( rowType, groupSubfields(makeSubfields({"c0.c0c0"})), - {}, - nullptr, - {}, - {}, - {}, - false, + /*subfieldFilters=*/{}, + /*indexColumns=*/{}, + /*dataColumns=*/nullptr, + /*partitionKeys=*/{}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + /*disableStatsBasedFilterReorder=*/false, pool_.get()); c0 = scanSpec->childByName("c0"); ASSERT_EQ(c0->children().size(), 2); @@ -581,8 +606,8 @@ TEST_F(HiveConnectorTest, extractFiltersFromRemainingFilter) { auto expr = parseExpr("not (c0 > 0 or c1 > 0)", rowType); SubfieldFilters filters; double sampleRate = 1; - auto remaining = extractFiltersFromRemainingFilter( - expr, &evaluator, false, filters, sampleRate); + auto remaining = + extractFiltersFromRemainingFilter(expr, &evaluator, filters, sampleRate); ASSERT_FALSE(remaining); ASSERT_EQ(sampleRate, 1); ASSERT_EQ(filters.size(), 2); @@ -591,8 +616,8 @@ TEST_F(HiveConnectorTest, extractFiltersFromRemainingFilter) { expr = parseExpr("not (c0 > 0 or c1 > c0)", rowType); filters.clear(); - remaining = extractFiltersFromRemainingFilter( - expr, &evaluator, false, filters, sampleRate); + remaining = + extractFiltersFromRemainingFilter(expr, &evaluator, filters, sampleRate); ASSERT_EQ(sampleRate, 1); ASSERT_EQ(filters.size(), 1); ASSERT_GT(filters.count(Subfield("c0")), 0); @@ -602,14 +627,70 @@ TEST_F(HiveConnectorTest, extractFiltersFromRemainingFilter) { expr = parseExpr( "not (c2 > 1::decimal(20, 0) or c2 < 0::decimal(20, 0))", rowType); filters.clear(); - remaining = extractFiltersFromRemainingFilter( - expr, &evaluator, false, filters, sampleRate); + remaining = + extractFiltersFromRemainingFilter(expr, &evaluator, filters, sampleRate); ASSERT_EQ(sampleRate, 1); ASSERT_GT(filters.count(Subfield("c2")), 0); // Change these once HUGEINT filter merge is fixed. ASSERT_TRUE(remaining); ASSERT_EQ( remaining->toString(), "not(lt(ROW[\"c2\"],cast(0 as DECIMAL(20, 0))))"); + + // parseExpr gives AND/OR with 2 arguments. We need to construct the node + // manually to have more than 2. + expr = std::make_shared( + BOOLEAN(), + expression::kAnd, + parseExpr("c0 > 0", rowType), + parseExpr("c1 > 0", rowType), + parseExpr("c2 > 0::decimal(20, 0)", rowType)); + filters.clear(); + remaining = + extractFiltersFromRemainingFilter(expr, &evaluator, filters, sampleRate); + ASSERT_EQ(sampleRate, 1); + ASSERT_EQ(filters.size(), 3); + ASSERT_TRUE(filters.contains(Subfield("c0"))); + ASSERT_TRUE(filters.contains(Subfield("c1"))); + ASSERT_TRUE(filters.contains(Subfield("c2"))); + ASSERT_FALSE(remaining); + + expr = std::make_shared( + BOOLEAN(), + expression::kAnd, + parseExpr("c0 % 2 = 0", rowType), + parseExpr("c1 % 3 = 0", rowType), + parseExpr("c2 > 0::decimal(20, 0)", rowType)); + filters.clear(); + remaining = + extractFiltersFromRemainingFilter(expr, &evaluator, filters, sampleRate); + ASSERT_EQ(sampleRate, 1); + ASSERT_EQ(filters.size(), 1); + ASSERT_TRUE(filters.contains(Subfield("c2"))); + ASSERT_TRUE(remaining); + ASSERT_EQ( + remaining->toString(), + "and(eq(mod(ROW[\"c0\"],2),0),eq(mod(ROW[\"c1\"],3),0))"); + + // Test VARCHAR OR filter pushdown: + // n_name = 'FRANCE' OR n_name = 'GERMANY' should push down as + // BytesValues('FRANCE', 'GERMANY'). + { + auto varcharRowType = ROW({"n_name"}, {VARCHAR()}); + expr = parseExpr("n_name = 'FRANCE' or n_name = 'GERMANY'", varcharRowType); + filters.clear(); + remaining = extractFiltersFromRemainingFilter( + expr, &evaluator, filters, sampleRate); + ASSERT_FALSE(remaining); + ASSERT_EQ(sampleRate, 1); + ASSERT_EQ(filters.size(), 1); + ASSERT_TRUE(filters.contains(Subfield("n_name"))); + auto* filter = filters.at(Subfield("n_name")).get(); + ASSERT_TRUE(filter->is(FilterKind::kBytesValues)); + auto* bytesValues = filter->as(); + ASSERT_EQ(bytesValues->values().size(), 2); + ASSERT_TRUE(bytesValues->values().count("FRANCE")); + ASSERT_TRUE(bytesValues->values().count("GERMANY")); + } } TEST_F(HiveConnectorTest, prestoTableSampling) { @@ -620,8 +701,8 @@ TEST_F(HiveConnectorTest, prestoTableSampling) { auto expr = parseExpr("rand() < 0.5", rowType); SubfieldFilters filters; double sampleRate = 1; - auto remaining = extractFiltersFromRemainingFilter( - expr, &evaluator, false, filters, sampleRate); + auto remaining = + extractFiltersFromRemainingFilter(expr, &evaluator, filters, sampleRate); ASSERT_FALSE(remaining); ASSERT_EQ(sampleRate, 0.5); ASSERT_TRUE(filters.empty()); @@ -629,8 +710,8 @@ TEST_F(HiveConnectorTest, prestoTableSampling) { expr = parseExpr("c0 > 0 and rand() < 0.5", rowType); filters.clear(); sampleRate = 1; - remaining = extractFiltersFromRemainingFilter( - expr, &evaluator, false, filters, sampleRate); + remaining = + extractFiltersFromRemainingFilter(expr, &evaluator, filters, sampleRate); ASSERT_FALSE(remaining); ASSERT_EQ(sampleRate, 0.5); ASSERT_EQ(filters.size(), 1); @@ -639,8 +720,8 @@ TEST_F(HiveConnectorTest, prestoTableSampling) { expr = parseExpr("rand() < 0.5 and rand() < 0.5", rowType); filters.clear(); sampleRate = 1; - remaining = extractFiltersFromRemainingFilter( - expr, &evaluator, false, filters, sampleRate); + remaining = + extractFiltersFromRemainingFilter(expr, &evaluator, filters, sampleRate); ASSERT_FALSE(remaining); ASSERT_EQ(sampleRate, 0.25); ASSERT_TRUE(filters.empty()); @@ -648,13 +729,532 @@ TEST_F(HiveConnectorTest, prestoTableSampling) { expr = parseExpr("c0 > 0 or rand() < 0.5", rowType); filters.clear(); sampleRate = 1; - remaining = extractFiltersFromRemainingFilter( - expr, &evaluator, false, filters, sampleRate); + remaining = + extractFiltersFromRemainingFilter(expr, &evaluator, filters, sampleRate); ASSERT_TRUE(remaining); ASSERT_EQ(*remaining, *expr); ASSERT_EQ(sampleRate, 1); ASSERT_TRUE(filters.empty()); } +#define VELOX_ASSERT_FILTER(expected, actual) \ + ASSERT_TRUE(expected->testingEquals(*actual)) \ + << expected->toString() << " vs " << actual->toString(); + +TEST_F(HiveConnectorTest, disjuncts) { + auto queryCtx = core::QueryCtx::create(); + exec::SimpleExpressionEvaluator evaluator(queryCtx.get(), pool_.get()); + auto rowType = ROW({"c0", "c1", "c2"}, {BIGINT(), BIGINT(), DECIMAL(20, 0)}); + + { + auto expr = + parseExpr("(c0 > 0 and c0 < 10) or (c0 > 5 and c0 < 15)", rowType); + + SubfieldFilters filters; + double sampleRate = 1; + auto remaining = extractFiltersFromRemainingFilter( + expr, &evaluator, filters, sampleRate); + ASSERT_TRUE(remaining == nullptr); + ASSERT_EQ(sampleRate, 1); + ASSERT_EQ(filters.size(), 1); + VELOX_ASSERT_FILTER(exec::between(1, 14), filters.begin()->second); + } + + { + auto expr = parseExpr("(c0 between -10 and 12)", rowType); + + SubfieldFilters filters; + double sampleRate = 1; + auto remaining = extractFiltersFromRemainingFilter( + expr, &evaluator, filters, sampleRate); + ASSERT_TRUE(remaining == nullptr); + ASSERT_EQ(sampleRate, 1); + ASSERT_EQ(filters.size(), 1); + ASSERT_EQ(filters.begin()->first, Subfield("c0")); + VELOX_ASSERT_FILTER(exec::between(-10, 12), filters.begin()->second); + + expr = parseExpr("(c0 > 0 and c0 < 10) or (c0 > 5 and c0 < 15)", rowType); + remaining = extractFiltersFromRemainingFilter( + expr, &evaluator, filters, sampleRate); + ASSERT_TRUE(remaining == nullptr); + ASSERT_EQ(sampleRate, 1); + ASSERT_EQ(filters.size(), 1); + ASSERT_EQ(filters.begin()->first, Subfield("c0")); + VELOX_ASSERT_FILTER(exec::between(1, 12), filters.begin()->second); + } + + { + auto expr = parseExpr("c0 not in (1, 3) or c0 in (1, 2)", rowType); + SubfieldFilters filters; + double sampleRate = 1; + auto remaining = extractFiltersFromRemainingFilter( + expr, &evaluator, filters, sampleRate); + ASSERT_EQ(remaining, expr); + ASSERT_EQ(sampleRate, 1); + ASSERT_TRUE(filters.empty()); + } +} + +#undef VELOX_ASSERT_FILTER + +/// A mock filesystem that delegates to the local filesystem but captures +/// the FileOptions::fileReadOps passed to openFileForRead. Used to verify +/// that FileSplitReader::createReader() propagates table identity (dbName, +/// tableName) into fileReadOps. +class CapturingFileSystem : public filesystems::FileSystem { + public: + static constexpr std::string_view kScheme = "capture:"; + + static folly::F14FastMap& capturedFileReadOps() { + static folly::F14FastMap instance; + return instance; + } + + explicit CapturingFileSystem(std::shared_ptr config) + : FileSystem(std::move(config)) {} + + std::string name() const override { + return "capture"; + } + + std::string_view extractPath(std::string_view path) const override { + if (path.substr(0, kScheme.size()) == kScheme) { + return path.substr(kScheme.size()); + } + return path; + } + + std::unique_ptr openFileForRead( + std::string_view path, + const filesystems::FileOptions& options) override { + capturedFileReadOps() = options.fileReadOps; + auto localPath = extractPath(path); + return filesystems::getFileSystem(localPath, config_) + ->openFileForRead(localPath, options); + } + + std::unique_ptr openFileForWrite( + std::string_view, + const filesystems::FileOptions&) override { + VELOX_UNSUPPORTED(); + } + + void remove(std::string_view) override { + VELOX_UNSUPPORTED(); + } + + void rename(std::string_view, std::string_view, bool) override { + VELOX_UNSUPPORTED(); + } + + bool exists(std::string_view path) override { + auto localPath = extractPath(path); + return filesystems::getFileSystem(localPath, config_)->exists(localPath); + } + + std::vector list(std::string_view) override { + VELOX_UNSUPPORTED(); + } + + void mkdir(std::string_view, const filesystems::DirectoryOptions&) override { + VELOX_UNSUPPORTED(); + } + + void rmdir(std::string_view) override { + VELOX_UNSUPPORTED(); + } +}; + +TEST_F(HiveConnectorTest, fileReadOpsTableIdentityPropagation) { + // Register the capturing filesystem once. + static bool registered = false; + if (!registered) { + filesystems::registerFileSystem( + [](std::string_view path) { + return path.find(CapturingFileSystem::kScheme) == 0; + }, + [](std::shared_ptr config, std::string_view) { + return std::make_shared(std::move(config)); + }); + registered = true; + } + + // Write test data to a local temp file. + auto rowType = ROW({"c0"}, {BIGINT()}); + auto vector = makeRowVector({"c0"}, {makeFlatVector({1, 2, 3})}); + auto filePath = TempFilePath::create(); + writeToFile(filePath->getPath(), vector); + + // Create a table handle with dbName and tableName set. + auto tableHandle = std::make_shared( + kHiveConnectorId, + "test_table", + SubfieldFilters{}, + /*remainingFilter=*/nullptr, + /*dataColumns=*/nullptr, + /*indexColumns=*/std::vector{}, + /*tableParameters=*/std::unordered_map{}, + /*filterColumnHandles=*/std::vector{}, + /*sampleRate=*/1.0, + /*dbName=*/"test_db"); + + // Build the split using the capturing filesystem scheme so that + // openFileForRead captures the fileReadOps populated by FileSplitReader. + auto split = exec::test::HiveConnectorSplitBuilder( + fmt::format("capture:{}", filePath->getPath())) + .fileFormat(dwio::common::FileFormat::DWRF) + .build(); + + // Build and run a table scan. This exercises the full pipeline: + // FileSplitReader::createReader() -> FileHandleGenerator -> + // CapturingFileSystem. + auto plan = PlanBuilder() + .startTableScan() + .outputType(rowType) + .tableHandle(tableHandle) + .assignments(allRegularColumns(rowType)) + .endTableScan() + .planNode(); + + auto result = AssertQueryBuilder(plan).split(split).copyResults(pool_.get()); + ASSERT_EQ(result->size(), 3); + + // Verify that FileSplitReader propagated dbName and tableName into + // fileReadOps. + auto& captured = CapturingFileSystem::capturedFileReadOps(); + ASSERT_EQ(captured.at(std::string(kDbNameKey)), "test_db"); + ASSERT_EQ(captured.at(std::string(kTableNameKey)), "test_table"); +} + +/// Verifies that getRuntimeStats() merge logic allows IoStats (ReadFile-layer) +/// to override the storageReadBytes value from IoStatistics (DWIO-level). +TEST_F(HiveConnectorTest, ioStatsOverridesStorageReadBytes) { + // Step 1: Simulate DWIO-level IoStatistics with an inaccurate estimate. + auto ioStatistics = std::make_shared(); + ioStatistics->read().increment(200); // DWIO estimate (undercounts gaps) + + // Step 2: Simulate ReadFile-layer IoStats with ground-truth values. + IoStats ioStats; + ioStats.addCounter( + std::string(FileDataSource::kStorageReadBytes), + RuntimeCounter(50, RuntimeCounter::Unit::kBytes)); + ioStats.addCounter( + "extentCacheHitBytes", RuntimeCounter(250, RuntimeCounter::Unit::kBytes)); + + // Replicate the merge logic from FileDataSource::getRuntimeStats(). + std::unordered_map res; + + // IoStatistics inserts first (Step 2 in getRuntimeStats). + res.insert( + {std::string(FileDataSource::kStorageReadBytes), + RuntimeMetric( + ioStatistics->read().sum(), + ioStatistics->read().count(), + ioStatistics->read().min(), + ioStatistics->read().max(), + RuntimeCounter::Unit::kBytes)}); + + // IoStats merge (Step 3 in getRuntimeStats) -- override storageReadBytes. + const auto ioStatsMap = ioStats.stats(); + for (const auto& [key, value] : ioStatsMap) { + if (key == FileDataSource::kStorageReadBytes) { + res[std::string(key)] = value; + } else { + res.emplace(key, value); + } + } + + // Ground-truth storageReadBytes from IoStats should override DWIO estimate. + ASSERT_EQ(res.at("storageReadBytes").sum, 50); + ASSERT_EQ(res.at("storageReadBytes").unit, RuntimeCounter::Unit::kBytes); + // extentCacheHitBytes should be added as a new diagnostic counter. + ASSERT_EQ(res.at("extentCacheHitBytes").sum, 250); +} + +/// Verifies that without IoStats override, storageReadBytes retains the +/// IoStatistics value. +TEST_F(HiveConnectorTest, storageReadBytesWithoutOverride) { + auto ioStatistics = std::make_shared(); + ioStatistics->read().increment(200); + + // IoStats has unrelated counters only -- no storageReadBytes. + IoStats ioStats; + ioStats.addCounter( + "wsInRegionReadBytes", RuntimeCounter(300, RuntimeCounter::Unit::kBytes)); + + std::unordered_map res; + res.insert( + {std::string(FileDataSource::kStorageReadBytes), + RuntimeMetric( + ioStatistics->read().sum(), + ioStatistics->read().count(), + ioStatistics->read().min(), + ioStatistics->read().max(), + RuntimeCounter::Unit::kBytes)}); + + const auto ioStatsMap = ioStats.stats(); + for (const auto& [key, value] : ioStatsMap) { + if (key == FileDataSource::kStorageReadBytes) { + res[std::string(key)] = value; + } else { + res.emplace(key, value); + } + } + + // storageReadBytes should retain the IoStatistics value. + ASSERT_EQ(res.at("storageReadBytes").sum, 200); + // Unrelated IoStats counters should still be added. + ASSERT_EQ(res.at("wsInRegionReadBytes").sum, 300); +} + +// --- ScanSpec extraction pushdown tests (file reader layer) --- + +TEST_F(HiveConnectorTest, extractionScanSpecMapKeepsBothChildren) { + // Map readers read keys and values together, so extraction from maps + // does not prune map children. The extraction is applied post-read. + auto hiveType = MAP(VARCHAR(), BIGINT()); + auto rowType = ROW({{"col", hiveType}}); + auto scanSpec = makeScanSpec( + rowType, + /*outputSubfields=*/{}, + /*subfieldFilters=*/{}, + /*indexColumns=*/{}, + /*dataColumns=*/nullptr, + /*partitionKeys=*/{}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + /*disableStatsBasedFilterReorder=*/false, + pool_.get()); + + // MapKeys extraction -- map children should still be readable. + std::vector extractions = { + {"keys", + {ExtractionPathElement::simple(ExtractionStep::kMapKeys)}, + ARRAY(VARCHAR())}}; + + auto* colSpec = scanSpec->childByName("col"); + ASSERT_NE(colSpec, nullptr); + configureExtractionScanSpec(hiveType, extractions, *colSpec, pool_.get()); + + auto* keysSpec = colSpec->childByName(ScanSpec::kMapKeysFieldName); + ASSERT_NE(keysSpec, nullptr); + ASSERT_FALSE(keysSpec->isConstant()); + auto* valuesSpec = colSpec->childByName(ScanSpec::kMapValuesFieldName); + ASSERT_NE(valuesSpec, nullptr); + ASSERT_FALSE(valuesSpec->isConstant()); +} + +TEST_F(HiveConnectorTest, extractionScanSpecSizeKeepsBothChildren) { + // Size extraction on a map -- map children should still be readable + // (extraction is post-read). + auto hiveType = MAP(VARCHAR(), BIGINT()); + auto rowType = ROW({{"col", hiveType}}); + auto scanSpec = makeScanSpec( + rowType, + /*outputSubfields=*/{}, + /*subfieldFilters=*/{}, + /*indexColumns=*/{}, + /*dataColumns=*/nullptr, + /*partitionKeys=*/{}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + /*disableStatsBasedFilterReorder=*/false, + pool_.get()); + + std::vector extractions = { + {"sz", {ExtractionPathElement::simple(ExtractionStep::kSize)}, BIGINT()}}; + + auto* colSpec = scanSpec->childByName("col"); + configureExtractionScanSpec(hiveType, extractions, *colSpec, pool_.get()); + + auto* keysSpec = colSpec->childByName(ScanSpec::kMapKeysFieldName); + ASSERT_FALSE(keysSpec->isConstant()); + auto* valuesSpec = colSpec->childByName(ScanSpec::kMapValuesFieldName); + ASSERT_FALSE(valuesSpec->isConstant()); +} + +TEST_F(HiveConnectorTest, extractionScanSpecStructFieldPruning) { + // Extraction with StructField should prune unneeded fields. + auto hiveType = ROW({{"x", INTEGER()}, {"y", DOUBLE()}, {"z", VARCHAR()}}); + auto rowType = ROW({{"col", hiveType}}); + auto scanSpec = makeScanSpec( + rowType, + /*outputSubfields=*/{}, + /*subfieldFilters=*/{}, + /*indexColumns=*/{}, + /*dataColumns=*/nullptr, + /*partitionKeys=*/{}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + /*disableStatsBasedFilterReorder=*/false, + pool_.get()); + + std::vector extractions = { + {"col_x", {ExtractionPathElement::structField("x")}, INTEGER()}}; + + auto* colSpec = scanSpec->childByName("col"); + configureExtractionScanSpec(hiveType, extractions, *colSpec, pool_.get()); + + // x should be readable, y and z should be constant null. + ASSERT_FALSE(colSpec->childByName("x")->isConstant()); + ASSERT_TRUE(colSpec->childByName("y")->isConstant()); + ASSERT_TRUE(colSpec->childByName("z")->isConstant()); +} + +TEST_F(HiveConnectorTest, scanSpecTransformApplied) { + // Verify that a transform set on ScanSpec is callable and produces + // the correct output type. + auto hiveType = MAP(VARCHAR(), BIGINT()); + auto scanSpec = std::make_shared("col"); + scanSpec->addFieldRecursively("col", *hiveType, 0); + + auto* colSpec = scanSpec->childByName("col"); + ASSERT_NE(colSpec, nullptr); + ASSERT_FALSE(colSpec->hasTransform()); + + // Set a MapKeys extraction transform. + auto chain = std::vector{ + ExtractionPathElement::simple(ExtractionStep::kMapKeys)}; + colSpec->setTransform( + [chain](const VectorPtr& input, memory::MemoryPool* pool) -> VectorPtr { + return applyExtractionChain(input, chain, pool); + }, + ARRAY(VARCHAR())); + + ASSERT_TRUE(colSpec->hasTransform()); + + // Create a test MapVector and apply the transform. + auto keys = makeFlatVector({"a", "b", "c"}); + auto values = makeFlatVector({1, 2, 3}); + auto mapVector = makeMapVector({0, 2}, keys, values); + + auto result = colSpec->transform()(mapVector, pool_.get()); + ASSERT_TRUE(result->type()->isArray()); + ASSERT_EQ(result->size(), 2); + auto* array = result->as(); + ASSERT_EQ(array->sizeAt(0), 2); + ASSERT_EQ(array->sizeAt(1), 1); +} + +TEST_F(HiveConnectorTest, extractionScanSpecMapKeyFilterString) { + // kMapKeyFilter with string keys should set an IN filter on the keys + // ScanSpec. ExtractionType should remain kNone since kMapKeyFilter is + // type-preserving. + auto hiveType = MAP(VARCHAR(), BIGINT()); + auto rowType = ROW({{"col", hiveType}}); + auto scanSpec = makeScanSpec( + rowType, + /*outputSubfields=*/{}, + /*subfieldFilters=*/{}, + /*indexColumns=*/{}, + /*dataColumns=*/nullptr, + /*partitionKeys=*/{}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + /*disableStatsBasedFilterReorder=*/false, + pool_.get()); + + std::vector extractions = { + {"filtered", + {ExtractionPathElement::mapKeyFilter( + std::vector{"a", "b"})}, + hiveType}}; + + auto* colSpec = scanSpec->childByName("col"); + ASSERT_NE(colSpec, nullptr); + configureExtractionScanSpec(hiveType, extractions, *colSpec, pool_.get()); + + // ExtractionType should remain kNone. + ASSERT_EQ(colSpec->extractionType(), ScanSpec::ExtractionType::kNone); + + // Keys ScanSpec should have an IN filter set. + auto* keysSpec = colSpec->childByName(ScanSpec::kMapKeysFieldName); + ASSERT_NE(keysSpec, nullptr); + ASSERT_NE(keysSpec->filter(), nullptr); + // Filter should pass "a" and "b" but not "c". + ASSERT_TRUE(keysSpec->filter()->testStringView(StringView("a"))); + ASSERT_TRUE(keysSpec->filter()->testStringView(StringView("b"))); + ASSERT_FALSE(keysSpec->filter()->testStringView(StringView("c"))); +} + +TEST_F(HiveConnectorTest, extractionScanSpecMapKeyFilterInt) { + // kMapKeyFilter with integer keys should set an IN filter on the keys + // ScanSpec. + auto hiveType = MAP(BIGINT(), VARCHAR()); + auto rowType = ROW({{"col", hiveType}}); + auto scanSpec = makeScanSpec( + rowType, + /*outputSubfields=*/{}, + /*subfieldFilters=*/{}, + /*indexColumns=*/{}, + /*dataColumns=*/nullptr, + /*partitionKeys=*/{}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + /*disableStatsBasedFilterReorder=*/false, + pool_.get()); + + std::vector extractions = { + {"filtered", + {ExtractionPathElement::mapKeyFilter(std::vector{10, 20, 30})}, + hiveType}}; + + auto* colSpec = scanSpec->childByName("col"); + ASSERT_NE(colSpec, nullptr); + configureExtractionScanSpec(hiveType, extractions, *colSpec, pool_.get()); + + // ExtractionType should remain kNone. + ASSERT_EQ(colSpec->extractionType(), ScanSpec::ExtractionType::kNone); + + // Keys ScanSpec should have an IN filter set. + auto* keysSpec = colSpec->childByName(ScanSpec::kMapKeysFieldName); + ASSERT_NE(keysSpec, nullptr); + ASSERT_NE(keysSpec->filter(), nullptr); + // Filter should pass 10, 20, 30 but not 5. + ASSERT_TRUE(keysSpec->filter()->testInt64(10)); + ASSERT_TRUE(keysSpec->filter()->testInt64(20)); + ASSERT_TRUE(keysSpec->filter()->testInt64(30)); + ASSERT_FALSE(keysSpec->filter()->testInt64(5)); +} + +TEST_F(HiveConnectorTest, extractionScanSpecMapKeyFilterThenMapKeys) { + // kMapKeyFilter followed by kMapKeys should set a filter on keys and then + // configure kKeys extraction on the remaining chain. + auto hiveType = MAP(VARCHAR(), BIGINT()); + auto rowType = ROW({{"col", hiveType}}); + auto scanSpec = makeScanSpec( + rowType, + /*outputSubfields=*/{}, + /*subfieldFilters=*/{}, + /*indexColumns=*/{}, + /*dataColumns=*/nullptr, + /*partitionKeys=*/{}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + /*disableStatsBasedFilterReorder=*/false, + pool_.get()); + + std::vector extractions = { + {"keys", + {ExtractionPathElement::mapKeyFilter(std::vector{"x", "y"}), + ExtractionPathElement::simple(ExtractionStep::kMapKeys)}, + ARRAY(VARCHAR())}}; + + auto* colSpec = scanSpec->childByName("col"); + ASSERT_NE(colSpec, nullptr); + configureExtractionScanSpec(hiveType, extractions, *colSpec, pool_.get()); + + // Filter should be set on keys. + auto* keysSpec = colSpec->childByName(ScanSpec::kMapKeysFieldName); + ASSERT_NE(keysSpec, nullptr); + ASSERT_NE(keysSpec->filter(), nullptr); + ASSERT_TRUE(keysSpec->filter()->testStringView(StringView("x"))); + ASSERT_FALSE(keysSpec->filter()->testStringView(StringView("z"))); + + // After stripping kMapKeyFilter, remaining chain is [kMapKeys], so + // kKeys extraction should be set. + ASSERT_EQ(colSpec->extractionType(), ScanSpec::ExtractionType::kKeys); +} + } // namespace } // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/tests/HiveConnectorUtilTest.cpp b/velox/connectors/hive/tests/HiveConnectorUtilTest.cpp index 7a0653ad22a..c7e22a8629f 100644 --- a/velox/connectors/hive/tests/HiveConnectorUtilTest.cpp +++ b/velox/connectors/hive/tests/HiveConnectorUtilTest.cpp @@ -16,10 +16,18 @@ #include "velox/connectors/hive/HiveConnectorUtil.h" #include +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/io/IoStatistics.h" #include "velox/connectors/hive/HiveConfig.h" #include "velox/connectors/hive/HiveConnectorSplit.h" #include "velox/connectors/hive/TableHandle.h" +#include "velox/core/Expressions.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/expression/Expr.h" +#include "velox/expression/ExprToSubfieldFilter.h" +#include "velox/expression/FieldReference.h" +#include "velox/parse/ExpressionsParser.h" +#include "velox/parse/TypeResolver.h" #include "velox/dwio/dwrf/writer/Writer.h" @@ -31,6 +39,24 @@ namespace facebook::velox::connector { using namespace dwio::common; +namespace { +// Unsupported types for createPointFilter and createRangeFilter with test +// values. +struct UnsupportedFilterType { + TypePtr type; + variant value; +}; + +const std::vector kUnsupportedFilterTypes = { + {TIMESTAMP(), variant(Timestamp(0, 0))}, + {ARRAY(BIGINT()), variant::array({variant::create(1)})}, + {MAP(VARCHAR(), BIGINT()), + variant::map({{variant("key"), variant::create(1)}})}, + {ROW({{"a", BIGINT()}}), + variant::row({variant::create(1)})}, +}; +} // namespace + class HiveConnectorUtilTest : public exec::test::HiveConnectorTestBase { protected: static bool compareSerDeOptions( @@ -43,6 +69,10 @@ class HiveConnectorUtilTest : public exec::test::HiveConnectorTestBase { std::shared_ptr pool_ = memory::memoryManager()->addLeafPool(); + std::shared_ptr dataIoStats_ = + std::make_shared(); + std::shared_ptr metadataIoStats_ = + std::make_shared(); }; TEST_F(HiveConnectorUtilTest, configureReaderOptions) { @@ -69,6 +99,8 @@ TEST_F(HiveConnectorUtilTest, configureReaderOptions) { // Dynamic parameters. dwio::common::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); FileFormat fileFormat{FileFormat::DWRF}; std::unordered_map tableParameters; std::unordered_map serdeParameters; @@ -78,10 +110,10 @@ TEST_F(HiveConnectorUtilTest, configureReaderOptions) { return std::make_shared( "testConnectorId", "testTable", - false, common::SubfieldFilters{}, nullptr, nullptr, + /*indexColumns=*/std::vector{}, tableParameters); }; @@ -103,11 +135,18 @@ TEST_F(HiveConnectorUtilTest, configureReaderOptions) { auto tableHandle = createTableHandle(); auto split = createSplit(); configureReaderOptions( - hiveConfig, connectorQueryCtx.get(), tableHandle, split, readerOptions); + hiveConfig, + connectorQueryCtx.get(), + tableHandle, + split, + split->serdeParameters, + readerOptions); }; auto clearDynamicParameters = [&](FileFormat newFileFormat) { readerOptions = dwio::common::ReaderOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); fileFormat = newFileFormat; tableParameters.clear(); serdeParameters.clear(); @@ -142,11 +181,12 @@ TEST_F(HiveConnectorUtilTest, configureReaderOptions) { readerOptions.fileColumnNamesReadAsLowerCase(), hiveConfig->isFileColumnNamesReadAsLowerCase(&sessionProperties)); checkUseColumnNamesForColumnMapping(); - EXPECT_EQ( - readerOptions.footerEstimatedSize(), hiveConfig->footerEstimatedSize()); EXPECT_EQ( readerOptions.filePreloadThreshold(), hiveConfig->filePreloadThreshold()); EXPECT_EQ(readerOptions.prefetchRowGroups(), hiveConfig->prefetchRowGroups()); + EXPECT_EQ( + readerOptions.fileMetadataCacheEnabled(), + hiveConfig->fileMetadataCacheEnabled(&sessionProperties)); // Modify field delimiter and change the file format. clearDynamicParameters(FileFormat::TEXT); @@ -234,9 +274,10 @@ TEST_F(HiveConnectorUtilTest, configureReaderOptions) { customHiveConfigProps[hive::HiveConfig::kFileColumnNamesReadAsLowerCase] = "true"; customHiveConfigProps[hive::HiveConfig::kOrcUseColumnNames] = "true"; - customHiveConfigProps[hive::HiveConfig::kFooterEstimatedSize] = "1111"; customHiveConfigProps[hive::HiveConfig::kFilePreloadThreshold] = "9999"; customHiveConfigProps[hive::HiveConfig::kPrefetchRowGroups] = "10"; + customHiveConfigProps[hive::HiveConfig::kFileMetadataCacheEnabled] = "true"; + customHiveConfigProps[hive::HiveConfig::kOrcFooterSpeculativeIoSize] = "1111"; hiveConfig = std::make_shared( std::make_shared(std::move(customHiveConfigProps))); performConfigure(); @@ -252,10 +293,12 @@ TEST_F(HiveConnectorUtilTest, configureReaderOptions) { readerOptions.fileColumnNamesReadAsLowerCase(), hiveConfig->isFileColumnNamesReadAsLowerCase(&sessionProperties)); EXPECT_EQ( - readerOptions.footerEstimatedSize(), hiveConfig->footerEstimatedSize()); + readerOptions.footerSpeculativeIoSize(), + hiveConfig->orcFooterSpeculativeIoSize(&sessionProperties)); EXPECT_EQ( readerOptions.filePreloadThreshold(), hiveConfig->filePreloadThreshold()); EXPECT_EQ(readerOptions.prefetchRowGroups(), hiveConfig->prefetchRowGroups()); + EXPECT_TRUE(readerOptions.fileMetadataCacheEnabled()); clearDynamicParameters(FileFormat::ORC); performConfigure(); checkUseColumnNamesForColumnMapping(); @@ -264,18 +307,210 @@ TEST_F(HiveConnectorUtilTest, configureReaderOptions) { checkUseColumnNamesForColumnMapping(); } +TEST_F(HiveConnectorUtilTest, footerSpeculativeIoSizeByFormat) { + config::ConfigBase sessionProperties{ + std::unordered_map{}}; + auto connectorQueryCtx = std::make_unique( + pool_.get(), + pool_.get(), + &sessionProperties, + nullptr, + common::PrefixSortConfig(), + nullptr, + nullptr, + "query.HiveConnectorUtilTest", + "task.HiveConnectorUtilTest", + "planNodeId.HiveConnectorUtilTest", + 0, + ""); + + std::unordered_map customHiveConfigProps; + customHiveConfigProps[hive::HiveConfig::kOrcFooterSpeculativeIoSize] = "1111"; + customHiveConfigProps[hive::HiveConfig::kParquetFooterSpeculativeIoSize] = + "2222"; + customHiveConfigProps[hive::HiveConfig::kNimbleFooterSpeculativeIoSize] = + "3333"; + auto hiveConfig = std::make_shared( + std::make_shared(std::move(customHiveConfigProps))); + + const std::unordered_map> + partitionKeys; + const std::unordered_map customSplitInfo; + std::unordered_map serdeParameters; + + auto createTableHandle = []() { + return std::make_shared( + "testConnectorId", + "testTable", + common::SubfieldFilters{}, + nullptr, + nullptr, + std::vector{}, + std::unordered_map{}); + }; + + auto createSplit = [&](FileFormat format) { + return std::make_shared( + "testConnectorId", + "/tmp/", + format, + 0UL, + std::numeric_limits::max(), + partitionKeys, + std::nullopt, + customSplitInfo, + nullptr, + serdeParameters); + }; + + // Test ORC format. + { + dwio::common::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto tableHandle = createTableHandle(); + auto split = createSplit(FileFormat::ORC); + configureReaderOptions( + hiveConfig, + connectorQueryCtx.get(), + tableHandle, + split, + split->serdeParameters, + readerOptions); + EXPECT_EQ( + readerOptions.footerSpeculativeIoSize(), + hiveConfig->orcFooterSpeculativeIoSize(&sessionProperties)); + EXPECT_EQ(readerOptions.footerSpeculativeIoSize(), 1111); + } + + // Test DWRF format (uses ORC config). + { + dwio::common::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto tableHandle = createTableHandle(); + auto split = createSplit(FileFormat::DWRF); + configureReaderOptions( + hiveConfig, + connectorQueryCtx.get(), + tableHandle, + split, + split->serdeParameters, + readerOptions); + EXPECT_EQ( + readerOptions.footerSpeculativeIoSize(), + hiveConfig->orcFooterSpeculativeIoSize(&sessionProperties)); + EXPECT_EQ(readerOptions.footerSpeculativeIoSize(), 1111); + } + + // Test Parquet format. + { + dwio::common::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto tableHandle = createTableHandle(); + auto split = createSplit(FileFormat::PARQUET); + configureReaderOptions( + hiveConfig, + connectorQueryCtx.get(), + tableHandle, + split, + split->serdeParameters, + readerOptions); + EXPECT_EQ( + readerOptions.footerSpeculativeIoSize(), + hiveConfig->parquetFooterSpeculativeIoSize(&sessionProperties)); + EXPECT_EQ(readerOptions.footerSpeculativeIoSize(), 2222); + } + + // Test Nimble format. + { + dwio::common::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto tableHandle = createTableHandle(); + auto split = createSplit(FileFormat::NIMBLE); + configureReaderOptions( + hiveConfig, + connectorQueryCtx.get(), + tableHandle, + split, + split->serdeParameters, + readerOptions); + EXPECT_EQ( + readerOptions.footerSpeculativeIoSize(), + hiveConfig->nimbleFooterSpeculativeIoSize(&sessionProperties)); + EXPECT_EQ(readerOptions.footerSpeculativeIoSize(), 3333); + } +} + +TEST_F(HiveConnectorUtilTest, fileMetadataCacheEnabledSessionOverride) { + // Verify default is off. + dwio::common::ReaderOptions defaultOptions(pool_.get()); + defaultOptions.setDataIoStats(dataIoStats_); + defaultOptions.setMetadataIoStats(metadataIoStats_); + ASSERT_FALSE(defaultOptions.fileMetadataCacheEnabled()); + + for (bool enabled : {true, false}) { + SCOPED_TRACE(fmt::format("fileMetadataCacheEnabled={}", enabled)); + + config::ConfigBase sessionProperties( + std::unordered_map{ + {hive::HiveConfig::kFileMetadataCacheEnabledSession, + enabled ? "true" : "false"}}); + auto connectorQueryCtx = std::make_unique( + pool_.get(), + pool_.get(), + &sessionProperties, + nullptr, + common::PrefixSortConfig(), + nullptr, + nullptr, + "query.HiveConnectorUtilTest", + "task.HiveConnectorUtilTest", + "planNodeId.HiveConnectorUtilTest", + 0, + ""); + auto hiveConfig = + std::make_shared(std::make_shared( + std::unordered_map())); + dwio::common::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + + auto tableHandle = std::make_shared( + "testConnectorId", + "testTable", + common::SubfieldFilters{}, + nullptr, + nullptr, + std::vector{}, + std::unordered_map{}); + auto split = std::make_shared( + "testConnectorId", "/tmp/", FileFormat::DWRF); + configureReaderOptions( + hiveConfig, + connectorQueryCtx.get(), + tableHandle, + split, + split->serdeParameters, + readerOptions); + ASSERT_EQ(readerOptions.fileMetadataCacheEnabled(), enabled); + } +} + TEST_F(HiveConnectorUtilTest, cacheRetention) { struct { bool splitCacheable; - bool expectedNoCacheRetention; + bool expectedCacheable; std::string debugString() const { return fmt::format( - "splitCacheable {}, expectedNoCacheRetention {}", + "splitCacheable {}, expectedCacheable {}", splitCacheable, - expectedNoCacheRetention); + expectedCacheable); } - } testSettings[] = {{false, true}, {true, false}}; + } testSettings[] = {{false, false}, {true, true}}; for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); @@ -300,14 +535,16 @@ TEST_F(HiveConnectorUtilTest, cacheRetention) { ""); dwio::common::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto tableHandle = std::make_shared( "testConnectorId", "testTable", - false, common::SubfieldFilters{}, nullptr, nullptr, + /*indexColumns=*/std::vector{}, std::unordered_map{}); auto hiveSplit = std::make_shared( @@ -329,10 +566,10 @@ TEST_F(HiveConnectorUtilTest, cacheRetention) { connectorQueryCtx.get(), tableHandle, hiveSplit, + hiveSplit->serdeParameters, readerOptions); - ASSERT_EQ( - readerOptions.noCacheRetention(), testData.expectedNoCacheRetention); + ASSERT_EQ(readerOptions.cacheable(), testData.expectedCacheable); } } @@ -349,9 +586,11 @@ TEST_F(HiveConnectorUtilTest, configureSstRowReaderOptions) { /*scanSpec=*/nullptr, /*metadataFilter=*/nullptr, /*rowType=*/nullptr, - /*hiveSplit=*/hiveSplit, - /*hiveConfig=*/nullptr, + /*fileSplit=*/hiveSplit, + /*serdeParameters=*/hiveSplit->serdeParameters, + /*fileConfig=*/nullptr, /*sessionProperties=*/nullptr, + /*ioExecutor=*/nullptr, /*rowReaderOptions=*/rowReaderOpts); EXPECT_EQ(rowReaderOpts.serdeParameters(), hiveSplit->serdeParameters); @@ -374,9 +613,11 @@ TEST_F(HiveConnectorUtilTest, configureRowReaderOptionsFromConfig) { /*scanSpec=*/nullptr, /*metadataFilter=*/nullptr, /*rowType=*/nullptr, - /*hiveSplit=*/hiveSplit, - /*hiveConfig=*/hiveConfig, + /*fileSplit=*/hiveSplit, + /*serdeParameters=*/{}, + /*fileConfig=*/hiveConfig, /*sessionProperties=*/&sessionProperties, + /*ioExecutor=*/nullptr, /*rowReaderOptions=*/rowReaderOpts); EXPECT_FALSE(rowReaderOpts.preserveFlatMapsInMemory()); @@ -399,9 +640,11 @@ TEST_F(HiveConnectorUtilTest, configureRowReaderOptionsFromConfig) { /*scanSpec=*/nullptr, /*metadataFilter=*/nullptr, /*rowType=*/nullptr, - /*hiveSplit=*/hiveSplit, - /*hiveConfig=*/hiveConfig, + /*fileSplit=*/hiveSplit, + /*serdeParameters=*/{}, + /*fileConfig=*/hiveConfig, /*sessionProperties=*/&sessionProperties, + /*ioExecutor=*/nullptr, /*rowReaderOptions=*/rowReaderOpts); EXPECT_TRUE(rowReaderOpts.preserveFlatMapsInMemory()); @@ -425,9 +668,11 @@ TEST_F(HiveConnectorUtilTest, configureRowReaderOptionsFromConfig) { /*scanSpec=*/nullptr, /*metadataFilter=*/nullptr, /*rowType=*/nullptr, - /*hiveSplit=*/hiveSplit, - /*hiveConfig=*/hiveConfig, + /*fileSplit=*/hiveSplit, + /*serdeParameters=*/{}, + /*fileConfig=*/hiveConfig, /*sessionProperties=*/&sessionProperties, + /*ioExecutor=*/nullptr, /*rowReaderOptions=*/rowReaderOpts); EXPECT_TRUE(rowReaderOpts.preserveFlatMapsInMemory()); @@ -452,13 +697,466 @@ TEST_F(HiveConnectorUtilTest, configureRowReaderOptionsFromConfig) { /*scanSpec=*/nullptr, /*metadataFilter=*/nullptr, /*rowType=*/nullptr, - /*hiveSplit=*/hiveSplit, - /*hiveConfig=*/hiveConfig, + /*fileSplit=*/hiveSplit, + /*serdeParameters=*/{}, + /*fileConfig=*/hiveConfig, /*sessionProperties=*/&sessionProperties, + /*ioExecutor=*/nullptr, /*rowReaderOptions=*/rowReaderOpts); EXPECT_TRUE(rowReaderOpts.preserveFlatMapsInMemory()); } } +TEST_F(HiveConnectorUtilTest, checkColumnHandleConsistent) { + // Create two consistent column handles + auto handle1 = std::make_shared( + "col1", hive::FileColumnHandle::ColumnType::kRegular, BIGINT(), BIGINT()); + auto handle2 = std::make_shared( + "col1", hive::FileColumnHandle::ColumnType::kRegular, BIGINT(), BIGINT()); + + // Should not throw for consistent handles + EXPECT_NO_THROW(hive::checkColumnHandleConsistent(*handle1, *handle2)); + + // Test inconsistent column type + auto handlePartition = std::make_shared( + "col1", + hive::FileColumnHandle::ColumnType::kPartitionKey, + BIGINT(), + BIGINT()); + VELOX_ASSERT_THROW( + hive::checkColumnHandleConsistent(*handle1, *handlePartition), + "Inconsistent column handle type: col1, expected Regular, got PartitionKey"); + + // Test inconsistent data type + auto handleVarchar = std::make_shared( + "col1", + hive::FileColumnHandle::ColumnType::kRegular, + VARCHAR(), + VARCHAR()); + VELOX_ASSERT_THROW( + hive::checkColumnHandleConsistent(*handle1, *handleVarchar), + "Inconsistent column handle data type: col1, expected BIGINT, got VARCHAR"); +} + +TEST_F(HiveConnectorUtilTest, makeScanSpecWithIndexColumns) { + // Data columns schema - all columns available in the file. + const auto dataColumns = ROW( + {{"a", BIGINT()}, + {"b", VARCHAR()}, + {"c", INTEGER()}, + {"d", DOUBLE()}, + {"e", ROW({{"x", INTEGER()}, {"y", VARCHAR()}})}}); + + struct TestCase { + std::string name; + RowTypePtr rowType; + folly::F14FastMap> + outputSubfields; + std::function makeSubfieldFilters; + std::vector indexColumns; + std::vector expectedColumns; + std::vector unexpectedColumns; + + std::string debugString() const { + return fmt::format( + "name: {}, indexColumns: [{}], expectedColumns: [{}], unexpectedColumns: [{}]", + name, + folly::join(", ", indexColumns), + folly::join(", ", expectedColumns), + folly::join(", ", unexpectedColumns)); + } + }; + + // Subfields for nested column 'e'. + const common::Subfield subfieldEx("e.x"); + const common::Subfield subfieldEy("e.y"); + + const std::vector testCases = { + { + "Index columns not in output projection", + ROW({"a", "b"}, {BIGINT(), VARCHAR()}), + /*outputSubfields=*/{}, + /*makeSubfieldFilters=*/nullptr, + /*indexColumns=*/{"c", "d"}, + /*expectedColumns=*/{"a", "b", "c", "d"}, + /*unexpectedColumns=*/{"e"}, + }, + { + "Index column already in output projection", + ROW({"a", "b"}, {BIGINT(), VARCHAR()}), + /*outputSubfields=*/{}, + /*makeSubfieldFilters=*/nullptr, + /*indexColumns=*/{"a", "c"}, + /*expectedColumns=*/{"a", "b", "c"}, + /*unexpectedColumns=*/{"d", "e"}, + }, + { + "Empty index columns", + ROW({"a", "b"}, {BIGINT(), VARCHAR()}), + /*outputSubfields=*/{}, + /*makeSubfieldFilters=*/nullptr, + /*indexColumns=*/{}, + /*expectedColumns=*/{"a", "b"}, + /*unexpectedColumns=*/{"c", "d", "e"}, + }, + { + "Output subfield without index columns", + ROW({"e"}, {ROW({{"x", INTEGER()}, {"y", VARCHAR()}})}), + /*outputSubfields=*/{{"e", {&subfieldEx}}}, + /*makeSubfieldFilters=*/nullptr, + /*indexColumns=*/{}, + /*expectedColumns=*/{"e"}, + /*unexpectedColumns=*/{"a", "b", "c", "d"}, + }, + { + "Output subfield with different index column", + ROW({"e"}, {ROW({{"x", INTEGER()}, {"y", VARCHAR()}})}), + /*outputSubfields=*/{{"e", {&subfieldEx}}}, + /*makeSubfieldFilters=*/nullptr, + /*indexColumns=*/{"c"}, + /*expectedColumns=*/{"e", "c"}, + /*unexpectedColumns=*/{"a", "b", "d"}, + }, + { + "Output subfield with same parent as index column", + ROW({"a", "e"}, + {BIGINT(), ROW({{"x", INTEGER()}, {"y", VARCHAR()}})}), + /*outputSubfields=*/{{"e", {&subfieldEx}}}, + /*makeSubfieldFilters=*/nullptr, + /*indexColumns=*/{"a", "e"}, + /*expectedColumns=*/{"a", "e"}, + /*unexpectedColumns=*/{"b", "c", "d"}, + }, + { + "Subfield filter without index column", + ROW({"a", "b"}, {BIGINT(), VARCHAR()}), + /*outputSubfields=*/{}, + /*makeSubfieldFilters=*/ + []() { + common::SubfieldFilters filters; + filters.emplace( + common::Subfield("c"), exec::greaterThanOrEqual(10)); + return filters; + }, + /*indexColumns=*/{"d"}, + /*expectedColumns=*/{"a", "b", "c", "d"}, + /*unexpectedColumns=*/{"e"}, + }, + { + "Subfield filter without index column", + ROW({"a", "b"}, {BIGINT(), VARCHAR()}), + /*outputSubfields=*/{}, + /*makeSubfieldFilters=*/ + []() { + common::SubfieldFilters filters; + filters.emplace( + common::Subfield("c"), exec::greaterThanOrEqual(10)); + return filters; + }, + /*indexColumns=*/{"c"}, + /*expectedColumns=*/{"a", "b", "c"}, + /*unexpectedColumns=*/{"e"}, + }, + }; + + for (const auto& testCase : testCases) { + SCOPED_TRACE(testCase.debugString()); + + common::SubfieldFilters subfieldFilters; + if (testCase.makeSubfieldFilters) { + subfieldFilters = testCase.makeSubfieldFilters(); + } + + auto scanSpec = hive::makeScanSpec( + testCase.rowType, + testCase.outputSubfields, + subfieldFilters, + testCase.indexColumns, + dataColumns, + /*partitionKeys=*/{}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + /*disableStatsBasedFilterReorder=*/false, + pool_.get()); + + for (const auto& col : testCase.expectedColumns) { + EXPECT_NE(scanSpec->childByName(col), nullptr) + << "Expected column " << col << " to be in scan spec"; + } + for (const auto& col : testCase.unexpectedColumns) { + EXPECT_EQ(scanSpec->childByName(col), nullptr) + << "Unexpected column " << col << " should not be in scan spec"; + } + } +} + +TEST_F(HiveConnectorUtilTest, makeScanSpecWithIndexColumnsError) { + // Test that makeScanSpec throws when index columns are set but dataColumns + // is null. + const auto rowType = ROW({"a", "b"}, {BIGINT(), VARCHAR()}); + + VELOX_ASSERT_THROW( + hive::makeScanSpec( + rowType, + /*outputSubfields=*/{}, + /*subfieldFilters=*/{}, + /*indexColumns=*/{"c"}, + /*dataColumns=*/nullptr, + /*partitionKeys=*/{}, + /*infoColumns=*/{}, + /*specialColumns=*/{}, + /*disableStatsBasedFilterReorder=*/false, + pool_.get()), + ""); +} + +TEST_F(HiveConnectorUtilTest, shouldEagerlyMaterialize) { + auto queryCtx = core::QueryCtx::create(); + auto execCtx = std::make_unique(pool_.get(), queryCtx.get()); + + auto compileExpression = [&](const std::string& expr, + const RowTypePtr& rowType) { + auto untyped = parse::DuckSqlExpressionsParser().parseExpr(expr); + auto typedExpr = + core::Expressions::inferTypes(untyped, rowType, pool_.get()); + std::vector expressions = {typedExpr}; + return std::make_unique( + std::move(expressions), execCtx.get()); + }; + + const auto rowType = ROW({"a", "b", "c"}, {BIGINT(), BIGINT(), BIGINT()}); + + // Test 1: OR expression doesn't evaluate arguments on non-increasing + // selection, so should return true (eager materialization needed). + { + auto exprSet = compileExpression("a > 10 OR b > 20", rowType); + auto& expr = *exprSet->exprs().front(); + for (const auto& field : expr.distinctFields()) { + EXPECT_TRUE(hive::shouldEagerlyMaterialize(expr, *field)); + } + } + + // Test 2: AND expression evaluates arguments on non-increasing selection. + // Field used in simple conjunct (no conditionals) should not be eagerly + // materialized. + { + auto exprSet = compileExpression("a > 10 AND b > 20", rowType); + auto& expr = *exprSet->exprs().front(); + for (const auto& field : expr.distinctFields()) { + EXPECT_FALSE(hive::shouldEagerlyMaterialize(expr, *field)); + } + } + + // Test 3: AND expression with field used in IF conditional. + // Field used in input with conditionals should be eagerly materialized. + { + auto exprSet = + compileExpression("a > 10 AND if(b > 20, c < 30, c > 5)", rowType); + auto& expr = *exprSet->exprs().front(); + for (const auto& field : expr.distinctFields()) { + if (field->field() == "c" || field->field() == "b") { + EXPECT_TRUE(hive::shouldEagerlyMaterialize(expr, *field)); + } else { + EXPECT_FALSE(hive::shouldEagerlyMaterialize(expr, *field)); + } + } + } + + // Test 4: AND expression where field is used in simple conjunct, + // not in conditional. + { + auto exprSet = + compileExpression("a > 10 OR if(b > 20, c < 30, c < 5)", rowType); + auto& expr = *exprSet->exprs().front(); + // Find field 'a' which is used in simple comparison, not in conditional. + for (const auto& field : expr.distinctFields()) { + EXPECT_TRUE(hive::shouldEagerlyMaterialize(expr, *field)); + } + } +} + +TEST_F(HiveConnectorUtilTest, createPointFilter) { + // Test BIGINT point filter. + { + auto filter = hive::createPointFilter( + BIGINT(), variant::create(42)); + ASSERT_NE(filter, nullptr); + EXPECT_TRUE(filter->testInt64(42)); + EXPECT_FALSE(filter->testInt64(41)); + EXPECT_FALSE(filter->testInt64(43)); + EXPECT_FALSE(filter->testNull()); + } + + // Test INTEGER point filter. + { + auto filter = hive::createPointFilter(INTEGER(), variant(100)); + ASSERT_NE(filter, nullptr); + EXPECT_TRUE(filter->testInt64(100)); + EXPECT_FALSE(filter->testInt64(99)); + EXPECT_FALSE(filter->testNull()); + } + + // Test DOUBLE point filter. + { + auto filter = hive::createPointFilter(DOUBLE(), variant(3.14)); + ASSERT_NE(filter, nullptr); + EXPECT_TRUE(filter->testDouble(3.14)); + EXPECT_FALSE(filter->testDouble(3.15)); + EXPECT_FALSE(filter->testNull()); + } + + // Test VARCHAR point filter. + { + auto filter = hive::createPointFilter(VARCHAR(), variant("hello")); + ASSERT_NE(filter, nullptr); + EXPECT_TRUE(filter->testBytes("hello", 5)); + EXPECT_FALSE(filter->testBytes("world", 5)); + EXPECT_FALSE(filter->testNull()); + } + + // Test BOOLEAN point filter. + { + auto filter = hive::createPointFilter(BOOLEAN(), variant(true)); + ASSERT_NE(filter, nullptr); + EXPECT_TRUE(filter->testBool(true)); + EXPECT_FALSE(filter->testBool(false)); + EXPECT_FALSE(filter->testNull()); + } + + // Test null value throws. + { + VELOX_ASSERT_THROW( + hive::createPointFilter(BIGINT(), variant::null(TypeKind::BIGINT)), + "Value cannot be null"); + } + + // Test unsupported types throw. + for (const auto& unsupported : kUnsupportedFilterTypes) { + SCOPED_TRACE( + fmt::format("Unsupported type: {}", unsupported.type->toString())); + VELOX_ASSERT_THROW( + hive::createPointFilter(unsupported.type, unsupported.value), ""); + } +} + +TEST_F(HiveConnectorUtilTest, createRangeFilter) { + // Test BIGINT range filter. + { + auto filter = hive::createRangeFilter( + BIGINT(), + variant::create(10), + variant::create(20)); + ASSERT_NE(filter, nullptr); + EXPECT_TRUE(filter->testInt64(10)); + EXPECT_TRUE(filter->testInt64(15)); + EXPECT_TRUE(filter->testInt64(20)); + EXPECT_FALSE(filter->testInt64(9)); + EXPECT_FALSE(filter->testInt64(21)); + EXPECT_FALSE(filter->testNull()); + } + + // Test INTEGER range filter. + { + auto filter = hive::createRangeFilter(INTEGER(), variant(0), variant(100)); + ASSERT_NE(filter, nullptr); + EXPECT_TRUE(filter->testInt64(0)); + EXPECT_TRUE(filter->testInt64(50)); + EXPECT_TRUE(filter->testInt64(100)); + EXPECT_FALSE(filter->testInt64(-1)); + EXPECT_FALSE(filter->testInt64(101)); + } + + // Test DOUBLE range filter. + { + auto filter = hive::createRangeFilter(DOUBLE(), variant(1.0), variant(2.0)); + ASSERT_NE(filter, nullptr); + EXPECT_TRUE(filter->testDouble(1.0)); + EXPECT_TRUE(filter->testDouble(1.5)); + EXPECT_TRUE(filter->testDouble(2.0)); + EXPECT_FALSE(filter->testDouble(0.9)); + EXPECT_FALSE(filter->testDouble(2.1)); + } + + // Test VARCHAR range filter. + { + auto filter = + hive::createRangeFilter(VARCHAR(), variant("apple"), variant("banana")); + ASSERT_NE(filter, nullptr); + EXPECT_TRUE(filter->testBytes("apple", 5)); + EXPECT_TRUE(filter->testBytes("ball", 4)); + EXPECT_TRUE(filter->testBytes("banana", 6)); + EXPECT_FALSE(filter->testBytes("aaa", 3)); + EXPECT_FALSE(filter->testBytes("cherry", 6)); + } + + // Test lower bound only (upper unbounded) - BIGINT. + { + auto filter = hive::createRangeFilter( + BIGINT(), + variant::create(10), + variant::null(TypeKind::BIGINT)); + ASSERT_NE(filter, nullptr); + EXPECT_TRUE(filter->testInt64(10)); + EXPECT_TRUE(filter->testInt64(100)); + EXPECT_TRUE(filter->testInt64(std::numeric_limits::max())); + EXPECT_FALSE(filter->testInt64(9)); + EXPECT_FALSE(filter->testNull()); + } + + // Test upper bound only (lower unbounded) - BIGINT. + { + auto filter = hive::createRangeFilter( + BIGINT(), + variant::null(TypeKind::BIGINT), + variant::create(20)); + ASSERT_NE(filter, nullptr); + EXPECT_TRUE(filter->testInt64(20)); + EXPECT_TRUE(filter->testInt64(0)); + EXPECT_TRUE(filter->testInt64(std::numeric_limits::min())); + EXPECT_FALSE(filter->testInt64(21)); + EXPECT_FALSE(filter->testNull()); + } + + // Test lower bound only - DOUBLE. + { + auto filter = hive::createRangeFilter( + DOUBLE(), variant(1.5), variant::null(TypeKind::DOUBLE)); + ASSERT_NE(filter, nullptr); + EXPECT_TRUE(filter->testDouble(1.5)); + EXPECT_TRUE(filter->testDouble(100.0)); + EXPECT_FALSE(filter->testDouble(1.0)); + } + + // Test upper bound only - VARCHAR. + { + auto filter = hive::createRangeFilter( + VARCHAR(), variant::null(TypeKind::VARCHAR), variant("banana")); + ASSERT_NE(filter, nullptr); + EXPECT_TRUE(filter->testBytes("banana", 6)); + EXPECT_TRUE(filter->testBytes("aaa", 3)); + EXPECT_FALSE(filter->testBytes("cherry", 6)); + } + + // Test both bounds null throws. + { + VELOX_ASSERT_THROW( + hive::createRangeFilter( + BIGINT(), + variant::null(TypeKind::BIGINT), + variant::null(TypeKind::BIGINT)), + "At least one of lower or upper bound must be set"); + } + + // Test unsupported types throw. + for (const auto& unsupported : kUnsupportedFilterTypes) { + SCOPED_TRACE( + fmt::format("Unsupported type: {}", unsupported.type->toString())); + VELOX_ASSERT_THROW( + hive::createRangeFilter( + unsupported.type, unsupported.value, unsupported.value), + ""); + } +} + } // namespace facebook::velox::connector diff --git a/velox/connectors/hive/tests/HiveDataSinkTest.cpp b/velox/connectors/hive/tests/HiveDataSinkTest.cpp index 4e3f0a19f4e..0fd25a9805b 100644 --- a/velox/connectors/hive/tests/HiveDataSinkTest.cpp +++ b/velox/connectors/hive/tests/HiveDataSinkTest.cpp @@ -16,14 +16,18 @@ #include #include "velox/common/caching/AsyncDataCache.h" +#include "velox/common/io/IoStatistics.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include +#include #include #include "velox/common/base/Fs.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/testutil/TestValue.h" +#include "velox/connectors/ConnectorRegistry.h" #include "velox/connectors/hive/HiveConnector.h" +#include "velox/connectors/hive/HiveDataSink.h" #include "velox/dwio/common/BufferedInput.h" #include "velox/dwio/common/Options.h" #include "velox/dwio/dwrf/reader/DwrfReader.h" @@ -37,8 +41,8 @@ #include "velox/dwio/parquet/writer/Writer.h" #endif +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/vector/fuzzer/VectorFuzzer.h" namespace facebook::velox::connector::hive { @@ -73,7 +77,7 @@ class HiveDataSinkTest : public exec::test::HiveConnectorTestBase { setupMemoryPools(); spillExecutor_ = std::make_unique( - std::thread::hardware_concurrency()); + folly::available_concurrency()); } void TearDown() override { @@ -113,7 +117,8 @@ class HiveDataSinkTest : public exec::test::HiveConnectorTestBase { 0, 0, writerFlushThreshold, - "none"); + "none", + 0); } void setupMemoryPools() { @@ -165,7 +170,7 @@ class HiveDataSinkTest : public exec::test::HiveConnectorTestBase { connector::hive::LocationHandle::TableType::kNew), fileFormat, CompressionKind::CompressionKind_ZSTD, - {}, + {}, // serdeParameters writerOptions, ensureFiles); } @@ -205,7 +210,7 @@ class HiveDataSinkTest : public exec::test::HiveConnectorTestBase { return files; } - void verifyWrittenData(const std::string& dirPath, int32_t numFiles = 1) { + void verifyWrittenData(const std::string& dirPath, uint32_t numFiles = 1) { const std::vector filePaths = listFiles(dirPath); ASSERT_EQ(filePaths.size(), numFiles); std::vector> splits; @@ -239,6 +244,10 @@ class HiveDataSinkTest : public exec::test::HiveConnectorTestBase { std::make_shared(std::make_shared( std::unordered_map())); std::unique_ptr spillExecutor_; + std::shared_ptr dataIoStats_ = + std::make_shared(); + std::shared_ptr metadataIoStats_ = + std::make_shared(); }; TEST_F(HiveDataSinkTest, hiveSortingColumn) { @@ -616,7 +625,7 @@ TEST_F(HiveDataSinkTest, close) { const auto partitions = dataSink->close(); // Can't append after close. VELOX_ASSERT_THROW( - dataSink->appendData(vectors.back()), "Hive data sink is not running"); + dataSink->appendData(vectors.back()), "File data sink is not running"); VELOX_ASSERT_THROW( dataSink->close(), "Unexpected state transition from CLOSED to CLOSED"); VELOX_ASSERT_THROW( @@ -664,7 +673,7 @@ TEST_F(HiveDataSinkTest, abort) { "Unexpected state transition from ABORTED to ABORTED"); // Can't append after abort. VELOX_ASSERT_THROW( - dataSink->appendData(vectors.back()), "Hive data sink is not running"); + dataSink->appendData(vectors.back()), "File data sink is not running"); } } @@ -739,7 +748,7 @@ DEBUG_ONLY_TEST_F(HiveDataSinkTest, memoryReclaim) { std::shared_ptr spillDirectory; std::unique_ptr spillConfig; if (testData.writerSpillEnabled) { - spillDirectory = exec::test::TempDirectoryPath::create(); + spillDirectory = TempDirectoryPath::create(); spillConfig = getSpillConfig( spillDirectory->getPath(), testData.writerFlushThreshold); auto connectorQueryCtx = std::make_unique( @@ -881,7 +890,7 @@ TEST_F(HiveDataSinkTest, memoryReclaimAfterClose) { std::shared_ptr spillDirectory; std::unique_ptr spillConfig; if (testData.writerSpillEnabled) { - spillDirectory = exec::test::TempDirectoryPath::create(); + spillDirectory = TempDirectoryPath::create(); spillConfig = getSpillConfig(spillDirectory->getPath(), 0); auto connectorQueryCtx = std::make_unique( opPool_.get(), @@ -1017,7 +1026,7 @@ TEST_F(HiveDataSinkTest, sortWriterMemoryReclaimDuringFinish) { std::make_shared( "c1", core::SortOrder{false, false})}); std::shared_ptr spillDirectory = - exec::test::TempDirectoryPath::create(); + TempDirectoryPath::create(); std::unique_ptr spillConfig = getSpillConfig(spillDirectory->getPath(), 1); connectorSessionProperties_->set( @@ -1084,7 +1093,7 @@ DEBUG_ONLY_TEST_F(HiveDataSinkTest, sortWriterFailureTest) { std::make_shared( "c1", core::SortOrder{false, false})}); const std::shared_ptr spillDirectory = - exec::test::TempDirectoryPath::create(); + TempDirectoryPath::create(); std::unique_ptr spillConfig = getSpillConfig(spillDirectory->getPath(), 0); // Triggers the memory reservation in sort buffer. @@ -1167,7 +1176,9 @@ TEST_F(HiveDataSinkTest, flushPolicyWithParquet) { ASSERT_TRUE(dataSink->finish()); dataSink->close(); - dwio::common::ReaderOptions readerOpts{pool_.get()}; + dwio::common::ReaderOptions readerOpts(pool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); const std::vector filePaths = listFiles(outputDirectory->getPath()); auto bufferedInput = std::make_unique( @@ -1205,7 +1216,9 @@ TEST_F(HiveDataSinkTest, flushPolicyWithDWRF) { ASSERT_TRUE(dataSink->finish()); dataSink->close(); - dwio::common::ReaderOptions readerOpts{pool_.get()}; + dwio::common::ReaderOptions readerOpts(pool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); const std::vector filePaths = listFiles(outputDirectory->getPath()); auto bufferedInput = std::make_unique( @@ -1288,7 +1301,7 @@ TEST_F(HiveDataSinkTest, ensureFilesUnsupported) { dwio::common::FileFormat::DWRF, CompressionKind::CompressionKind_ZSTD, {}, // serdeParameters - nullptr, // writeOptions + nullptr, // writerOptions true // ensureFiles ), "ensureFiles is not supported with partition keys in the data"); @@ -1311,12 +1324,532 @@ TEST_F(HiveDataSinkTest, ensureFilesUnsupported) { dwio::common::FileFormat::DWRF, CompressionKind::CompressionKind_ZSTD, {}, // serdeParameters - nullptr, // writeOptions + nullptr, // writerOptions true // ensureFiles ), "ensureFiles is not supported with bucketing"); } +TEST_F(HiveDataSinkTest, fileRotationBasic) { + const auto outputDirectory = TempDirectoryPath::create(); + + std::unordered_map connectorConfig; + connectorConfig.emplace("max-target-file-size", "1MB"); + connectorConfig.emplace("hive.orc.writer.stripe-max-size", "256KB"); + connectorConfig_ = std::make_shared( + std::make_shared(std::move(connectorConfig))); + + auto dataSink = createDataSink(rowType_, outputDirectory->getPath()); + + const int numBatches = 100; + const auto vectors = createVectors(1000, numBatches); + for (const auto& vector : vectors) { + dataSink->appendData(vector); + } + + ASSERT_TRUE(dataSink->finish()); + const auto partitions = dataSink->close(); + const auto stats = dataSink->stats(); + + ASSERT_GT(stats.numWrittenFiles, 1); + ASSERT_EQ(partitions.size(), 1); + + const auto partitionJson = folly::parseJson(partitions[0]); + ASSERT_TRUE(partitionJson.count(HiveCommitMessage::kFileWriteInfos) > 0); + const auto& fileWriteInfos = + partitionJson[HiveCommitMessage::kFileWriteInfos]; + ASSERT_GT(fileWriteInfos.size(), 1); + + for (size_t i = 0; i < fileWriteInfos.size(); ++i) { + ASSERT_TRUE(fileWriteInfos[i].count(HiveCommitMessage::kWriteFileName) > 0); + ASSERT_TRUE( + fileWriteInfos[i].count(HiveCommitMessage::kTargetFileName) > 0); + ASSERT_TRUE(fileWriteInfos[i].count(HiveCommitMessage::kFileSize) > 0); + ASSERT_GT(fileWriteInfos[i][HiveCommitMessage::kFileSize].asInt(), 0); + } + createDuckDbTable(vectors); + verifyWrittenData(outputDirectory->getPath(), stats.numWrittenFiles); +} + +TEST_F(HiveDataSinkTest, fileRotationNoEmptyTrailingFile) { + const auto outputDirectory = TempDirectoryPath::create(); + + std::unordered_map connectorConfig; + connectorConfig.emplace("max-target-file-size", "1KB"); + connectorConfig.emplace("hive.orc.writer.stripe-max-size", "1KB"); + connectorConfig_ = std::make_shared( + std::make_shared(std::move(connectorConfig))); + + auto dataSink = createDataSink(rowType_, outputDirectory->getPath()); + + const auto vectors = createVectors(2000, 10); + for (const auto& vector : vectors) { + dataSink->appendData(vector); + } + + ASSERT_TRUE(dataSink->finish()); + const auto partitions = dataSink->close(); + const auto stats = dataSink->stats(); + + ASSERT_EQ(partitions.size(), 1); + const auto partitionJson = folly::parseJson(partitions[0]); + ASSERT_TRUE(partitionJson.count(HiveCommitMessage::kFileWriteInfos) > 0); + const auto& fileWriteInfos = + partitionJson[HiveCommitMessage::kFileWriteInfos]; + ASSERT_EQ(fileWriteInfos.size(), 5); + + const auto filePaths = listFiles(outputDirectory->getPath()); + ASSERT_EQ(filePaths.size(), fileWriteInfos.size()); + ASSERT_EQ(filePaths.size(), stats.numWrittenFiles); + createDuckDbTable(vectors); + verifyWrittenData(outputDirectory->getPath(), stats.numWrittenFiles); +} + +TEST_F(HiveDataSinkTest, fileRotationDisabledForBucketedTables) { + const auto outputDirectory = TempDirectoryPath::create(); + + std::unordered_map connectorConfig; + connectorConfig.emplace("max-target-file-size", "100KB"); + connectorConfig_ = std::make_shared( + std::make_shared(std::move(connectorConfig))); + + const int32_t numBuckets = 2; + auto bucketProperty = std::make_shared( + HiveBucketProperty::Kind::kHiveCompatible, + numBuckets, + std::vector{"c0"}, + std::vector{BIGINT()}, + std::vector>{}); + + auto dataSink = createDataSink( + rowType_, + outputDirectory->getPath(), + dwio::common::FileFormat::DWRF, + {}, + bucketProperty); + + const int numBatches = 20; + const auto vectors = createVectors(500, numBatches); + for (const auto& vector : vectors) { + dataSink->appendData(vector); + } + + ASSERT_TRUE(dataSink->finish()); + const auto partitions = dataSink->close(); + const auto stats = dataSink->stats(); + + ASSERT_EQ(stats.numWrittenFiles, numBuckets); + ASSERT_EQ(partitions.size(), numBuckets); + + for (const auto& partition : partitions) { + const auto partitionJson = folly::parseJson(partition); + ASSERT_TRUE(partitionJson.count(HiveCommitMessage::kFileWriteInfos) > 0); + const auto& fileWriteInfos = + partitionJson[HiveCommitMessage::kFileWriteInfos]; + ASSERT_EQ(fileWriteInfos.size(), 1); + } + + createDuckDbTable(vectors); + verifyWrittenData(outputDirectory->getPath(), numBuckets); +} + +TEST_F(HiveDataSinkTest, fileRotationDisabledByDefault) { + const auto outputDirectory = TempDirectoryPath::create(); + + // Don't set max-target-file-size (use default which is disabled) + auto dataSink = createDataSink(rowType_, outputDirectory->getPath()); + + // Write a lot of data + const int numBatches = 20; + const auto vectors = createVectors(500, numBatches); + for (const auto& vector : vectors) { + dataSink->appendData(vector); + } + + ASSERT_TRUE(dataSink->finish()); + const auto partitions = dataSink->close(); + const auto stats = dataSink->stats(); + + // Should have exactly 1 file (no rotation when disabled) + ASSERT_EQ(stats.numWrittenFiles, 1) + << "Without maxTargetFileSize, should create exactly 1 file"; + ASSERT_GT(stats.numWrittenBytes, 0); + ASSERT_EQ(partitions.size(), 1); + + // Verify partition update has correct file info + const auto partitionJson = folly::parseJson(partitions[0]); + ASSERT_TRUE(partitionJson.count(HiveCommitMessage::kFileWriteInfos) > 0); + const auto& fileWriteInfos = + partitionJson[HiveCommitMessage::kFileWriteInfos]; + ASSERT_EQ(fileWriteInfos.size(), 1) + << "Should have exactly 1 file entry when rotation disabled"; + + // Verify file info fields + const auto& fileInfo = fileWriteInfos[0]; + ASSERT_TRUE(fileInfo.count(HiveCommitMessage::kWriteFileName) > 0); + ASSERT_TRUE(fileInfo.count(HiveCommitMessage::kTargetFileName) > 0); + ASSERT_TRUE(fileInfo.count(HiveCommitMessage::kFileSize) > 0); + ASSERT_FALSE(fileInfo[HiveCommitMessage::kWriteFileName].asString().empty()); + ASSERT_FALSE(fileInfo[HiveCommitMessage::kTargetFileName].asString().empty()); + + const auto reportedFileSize = + static_cast(fileInfo[HiveCommitMessage::kFileSize].asInt()); + ASSERT_GT(reportedFileSize, 0); + + // File size in fileWriteInfos should match stats.numWrittenBytes + ASSERT_EQ(reportedFileSize, stats.numWrittenBytes); + + // onDiskDataSizeInBytes should also match + const auto onDiskBytes = static_cast( + partitionJson[HiveCommitMessage::kOnDiskDataSizeInBytes].asInt()); + ASSERT_EQ(onDiskBytes, stats.numWrittenBytes); + + // Verify actual file on disk matches reported size + const auto filePaths = listFiles(outputDirectory->getPath()); + ASSERT_EQ(filePaths.size(), 1); + ASSERT_EQ(fs::file_size(filePaths[0]), reportedFileSize); + + createDuckDbTable(vectors); + verifyWrittenData(outputDirectory->getPath()); +} + +TEST_F(HiveDataSinkTest, fileRotationIoStatsAccumulation) { + const auto outputDirectory = TempDirectoryPath::create(); + + std::unordered_map connectorConfig; + connectorConfig.emplace("max-target-file-size", "1MB"); + connectorConfig.emplace("hive.orc.writer.stripe-max-size", "256KB"); + connectorConfig_ = std::make_shared( + std::make_shared(std::move(connectorConfig))); + + auto dataSink = createDataSink(rowType_, outputDirectory->getPath()); + + const int numBatches = 100; + const auto vectors = createVectors(1000, numBatches); + for (const auto& vector : vectors) { + dataSink->appendData(vector); + } + + const auto statsBeforeClose = dataSink->stats(); + ASSERT_GT(statsBeforeClose.numWrittenBytes, 0); + + ASSERT_TRUE(dataSink->finish()); + const auto partitions = dataSink->close(); + const auto statsAfterClose = dataSink->stats(); + + ASSERT_GT(statsAfterClose.numWrittenFiles, 1); + ASSERT_GT(statsAfterClose.numWrittenBytes, 0); + + const auto filePaths = listFiles(outputDirectory->getPath()); + ASSERT_EQ(filePaths.size(), statsAfterClose.numWrittenFiles); + + uint64_t totalFileSizeOnDisk = 0; + for (const auto& filePath : filePaths) { + totalFileSizeOnDisk += fs::file_size(filePath); + } + + ASSERT_GT( + statsAfterClose.numWrittenBytes, + static_cast(totalFileSizeOnDisk * 0.9)); + ASSERT_LT( + statsAfterClose.numWrittenBytes, + static_cast(totalFileSizeOnDisk * 1.1)); + + ASSERT_EQ(partitions.size(), 1); + const auto partitionJson = folly::parseJson(partitions[0]); + const auto onDiskBytes = static_cast( + partitionJson[HiveCommitMessage::kOnDiskDataSizeInBytes].asInt()); + ASSERT_EQ(onDiskBytes, statsAfterClose.numWrittenBytes); + + createDuckDbTable(vectors); + verifyWrittenData( + outputDirectory->getPath(), statsAfterClose.numWrittenFiles); +} + +TEST_F(HiveDataSinkTest, fileRotationFileInfoConsistency) { + // Tests that file info (names, sizes) in partition updates is consistent + // with actual files on disk after rotation. + const auto outputDirectory = TempDirectoryPath::create(); + + std::unordered_map connectorConfig; + connectorConfig.emplace("max-target-file-size", "500KB"); + connectorConfig.emplace("hive.orc.writer.stripe-max-size", "128KB"); + connectorConfig_ = std::make_shared( + std::make_shared(std::move(connectorConfig))); + + auto dataSink = createDataSink(rowType_, outputDirectory->getPath()); + + const int numBatches = 50; + const auto vectors = createVectors(500, numBatches); + for (const auto& vector : vectors) { + dataSink->appendData(vector); + } + + ASSERT_TRUE(dataSink->finish()); + const auto partitions = dataSink->close(); + const auto stats = dataSink->stats(); + + ASSERT_GT(stats.numWrittenFiles, 1); + ASSERT_EQ(partitions.size(), 1); + + const auto partitionJson = folly::parseJson(partitions[0]); + const auto& fileWriteInfos = + partitionJson[HiveCommitMessage::kFileWriteInfos]; + + // Verify file count matches + ASSERT_EQ(fileWriteInfos.size(), stats.numWrittenFiles); + + // Get actual files on disk + const auto filePaths = listFiles(outputDirectory->getPath()); + ASSERT_EQ(filePaths.size(), stats.numWrittenFiles); + + // Build a map of file names to sizes from partition info + std::map reportedFiles; + uint64_t totalReportedSize = 0; + for (size_t i = 0; i < fileWriteInfos.size(); ++i) { + const auto& info = fileWriteInfos[i]; + const auto fileName = info[HiveCommitMessage::kWriteFileName].asString(); + const auto fileSize = + static_cast(info[HiveCommitMessage::kFileSize].asInt()); + reportedFiles[fileName] = fileSize; + totalReportedSize += fileSize; + + // Verify each file has non-empty name and positive size + ASSERT_FALSE(fileName.empty()) << "File name at index " << i << " is empty"; + ASSERT_GT(fileSize, 0) << "File size at index " << i << " is zero"; + } + + // Verify onDiskDataSizeInBytes matches sum of file sizes + const auto onDiskBytes = static_cast( + partitionJson[HiveCommitMessage::kOnDiskDataSizeInBytes].asInt()); + ASSERT_EQ(onDiskBytes, totalReportedSize); + ASSERT_EQ(onDiskBytes, stats.numWrittenBytes); + + // Verify actual file sizes on disk match reported sizes exactly. + // Stats are captured after close, so footer bytes are included. + for (const auto& filePath : filePaths) { + const auto fileName = fs::path(filePath).filename().string(); + ASSERT_TRUE(reportedFiles.count(fileName) > 0) + << "File on disk not found in partition info: " << fileName; + + const auto actualSize = fs::file_size(filePath); + const auto reportedSize = reportedFiles[fileName]; + // Sizes should match exactly since stats are captured after close. + ASSERT_EQ(actualSize, reportedSize) + << "File size mismatch for " << fileName; + } + + createDuckDbTable(vectors); + verifyWrittenData(outputDirectory->getPath(), stats.numWrittenFiles); +} + +TEST_F(HiveDataSinkTest, fileRotationStatsProgressDuringWrite) { + // Tests that stats are correctly reported during writing (not just at close). + // Verifies stats grow monotonically even when rotation happens mid-write. + const auto outputDirectory = TempDirectoryPath::create(); + + std::unordered_map connectorConfig; + connectorConfig.emplace("max-target-file-size", "256KB"); + connectorConfig.emplace("hive.orc.writer.stripe-max-size", "64KB"); + connectorConfig_ = std::make_shared( + std::make_shared(std::move(connectorConfig))); + + auto dataSink = createDataSink(rowType_, outputDirectory->getPath()); + + const auto vectors = createVectors(500, 50); + uint64_t previousBytes = 0; + + for (size_t i = 0; i < vectors.size(); ++i) { + dataSink->appendData(vectors[i]); + const auto currentStats = dataSink->stats(); + + // Stats should be monotonically increasing + ASSERT_GE(currentStats.numWrittenBytes, previousBytes) + << "Stats decreased at batch " << i; + previousBytes = currentStats.numWrittenBytes; + } + + const auto statsBeforeFinish = dataSink->stats(); + ASSERT_GT(statsBeforeFinish.numWrittenBytes, 0); + + ASSERT_TRUE(dataSink->finish()); + const auto partitions = dataSink->close(); + const auto finalStats = dataSink->stats(); + + // Final stats should be >= stats before finish + ASSERT_GE(finalStats.numWrittenBytes, statsBeforeFinish.numWrittenBytes); + ASSERT_GT(finalStats.numWrittenFiles, 1); + + createDuckDbTable(vectors); + verifyWrittenData(outputDirectory->getPath(), finalStats.numWrittenFiles); +} + +TEST_F(HiveDataSinkTest, fileRotationWithPartitionedTable) { + // Tests file rotation with partitioned tables. + const auto outputDirectory = TempDirectoryPath::create(); + + std::unordered_map connectorConfig; + connectorConfig.emplace("max-target-file-size", "256KB"); + connectorConfig.emplace("hive.orc.writer.stripe-max-size", "64KB"); + connectorConfig_ = std::make_shared( + std::make_shared(std::move(connectorConfig))); + + // Create a partitioned table with partition column + auto partitionedRowType = ROW( + {"c0", "c1", "c2", "p0"}, {BIGINT(), INTEGER(), SMALLINT(), VARCHAR()}); + + auto dataSink = createDataSink( + partitionedRowType, + outputDirectory->getPath(), + dwio::common::FileFormat::DWRF, + {"p0"}); + + // Create vectors with a few distinct partition values + const int numBatches = 30; + std::vector vectors; + vectors.reserve(numBatches); + for (int i = 0; i < numBatches; ++i) { + auto c0 = makeFlatVector(500, [](auto row) { return row; }); + auto c1 = makeFlatVector(500, [](auto row) { return row * 2; }); + auto c2 = makeFlatVector(500, [](auto row) { return row % 100; }); + // Two partitions: "part_a" and "part_b" + auto p0 = makeFlatVector(500, [i](auto row) { + return (i + row) % 2 == 0 ? "part_a" : "part_b"; + }); + vectors.push_back(makeRowVector({c0, c1, c2, p0})); + } + + for (const auto& vector : vectors) { + dataSink->appendData(vector); + } + + ASSERT_TRUE(dataSink->finish()); + const auto partitions = dataSink->close(); + const auto stats = dataSink->stats(); + + // Should have 2 partitions + ASSERT_EQ(partitions.size(), 2); + + // Each partition may have multiple files due to rotation + uint32_t totalFilesFromPartitions = 0; + for (const auto& partition : partitions) { + const auto partitionJson = folly::parseJson(partition); + ASSERT_TRUE(partitionJson.count(HiveCommitMessage::kFileWriteInfos) > 0); + const auto& fileWriteInfos = + partitionJson[HiveCommitMessage::kFileWriteInfos]; + ASSERT_GT(fileWriteInfos.size(), 0); + totalFilesFromPartitions += fileWriteInfos.size(); + + // Verify each file has valid info + for (const auto& fileInfo : fileWriteInfos) { + ASSERT_TRUE(fileInfo.count(HiveCommitMessage::kWriteFileName) > 0); + ASSERT_TRUE(fileInfo.count(HiveCommitMessage::kTargetFileName) > 0); + ASSERT_TRUE(fileInfo.count(HiveCommitMessage::kFileSize) > 0); + ASSERT_GT(fileInfo[HiveCommitMessage::kFileSize].asInt(), 0); + } + } + + ASSERT_EQ(totalFilesFromPartitions, stats.numWrittenFiles); +} + +TEST_F(HiveDataSinkTest, fileRotationWriteIOTimeAccumulation) { + // Tests that writeIOTimeUs is correctly accumulated across rotated files. + const auto outputDirectory = TempDirectoryPath::create(); + + std::unordered_map connectorConfig; + connectorConfig.emplace("max-target-file-size", "512KB"); + connectorConfig.emplace("hive.orc.writer.stripe-max-size", "128KB"); + connectorConfig_ = std::make_shared( + std::make_shared(std::move(connectorConfig))); + + auto dataSink = createDataSink(rowType_, outputDirectory->getPath()); + + const int numBatches = 60; + const auto vectors = createVectors(500, numBatches); + for (const auto& vector : vectors) { + dataSink->appendData(vector); + } + + ASSERT_TRUE(dataSink->finish()); + dataSink->close(); + const auto stats = dataSink->stats(); + + // Should have multiple files from rotation + ASSERT_GT(stats.numWrittenFiles, 1); + + // writeIOTimeUs should be non-zero (actual I/O happened) + // Note: This may be 0 on very fast systems, so we just check it's >= 0 + ASSERT_GE(stats.writeIOTimeUs, 0); + + createDuckDbTable(vectors); + verifyWrittenData(outputDirectory->getPath(), stats.numWrittenFiles); +} + +TEST_F(HiveDataSinkTest, fileRotationWithMemoryReclaim) { + // Tests that file rotation works correctly when memory reclamation is + // enabled. This verifies the single ioStats approach - we reuse the same + // ioStats object during rotation instead of replacing it, which prevents + // race conditions with WriterReclaimer. + const auto outputDirectory = TempDirectoryPath::create(); + + std::unordered_map connectorConfig; + // Use small file size to trigger multiple rotations + connectorConfig.emplace("max-target-file-size", "256KB"); + connectorConfig.emplace("hive.orc.writer.stripe-max-size", "64KB"); + connectorConfig_ = std::make_shared( + std::make_shared(std::move(connectorConfig))); + + // Setup memory pools with spill config to enable reclaim + auto spillDirectory = TempDirectoryPath::create(); + auto spillConfig = getSpillConfig(spillDirectory->getPath(), 1 << 30); + auto connectorQueryCtx = std::make_unique( + opPool_.get(), + connectorPool_.get(), + connectorSessionProperties_.get(), + spillConfig.get(), + common::PrefixSortConfig(), + nullptr, + nullptr, + "query.HiveDataSinkTest", + "task.HiveDataSinkTest", + "planNodeId.HiveDataSinkTest", + 0, + ""); + setConnectorQueryContext(std::move(connectorQueryCtx)); + + auto dataSink = createDataSink(rowType_, outputDirectory->getPath()); + auto* hiveDataSink = static_cast(dataSink.get()); + + // Verify reclaim is enabled + ASSERT_TRUE(hiveDataSink->canReclaim()); + + const int numBatches = 50; + const auto vectors = createVectors(500, numBatches); + + for (const auto& vector : vectors) { + dataSink->appendData(vector); + } + + // Trigger memory reclaim if there's reclaimable memory + auto reclaimableBytes = root_->reclaimableBytes(); + if (reclaimableBytes.has_value() && reclaimableBytes.value() > 0) { + memory::MemoryReclaimer::Stats reclaimStats; + root_->reclaim(1L << 20, 0, reclaimStats); + } + + ASSERT_TRUE(dataSink->finish()); + const auto partitions = dataSink->close(); + const auto stats = dataSink->stats(); + + // Should have multiple files from rotation + ASSERT_GT(stats.numWrittenFiles, 1); + ASSERT_EQ(partitions.size(), 1); + + // Verify data integrity + createDuckDbTable(vectors); + verifyWrittenData(outputDirectory->getPath(), stats.numWrittenFiles); +} + TEST_F(HiveDataSinkTest, raceWithCacheEviction) { /// This test ensures that LRU cache staleness and StringIdMap cache /// eviction do not cause issues with file reads. @@ -1324,7 +1857,7 @@ TEST_F(HiveDataSinkTest, raceWithCacheEviction) { auto cacheCleaner = std::async(std::launch::async, [&] { auto cache = cache::AsyncDataCache::getInstance(); auto hiveConnector = std::dynamic_pointer_cast( - getConnector(exec::test::kHiveConnectorId)); + ConnectorRegistry::tryGet(exec::test::kHiveConnectorId)); while (!stop) { std::this_thread::sleep_for(std::chrono::milliseconds(10)); cache->clear(); @@ -1375,6 +1908,99 @@ TEST_F(HiveDataSinkTest, lazyVectorForParquet) { } #endif +// Test to verify that each writer has its own nonReclaimableSection +// pointer when writerOptions is shared. +TEST_F(HiveDataSinkTest, sharedWriterOptionsWithMultipleWriters) { + const auto outputDirectory = TempDirectoryPath::create(); + + const int32_t numBuckets = 3; + auto bucketProperty = std::make_shared( + HiveBucketProperty::Kind::kHiveCompatible, + numBuckets, + std::vector{"c0"}, + std::vector{BIGINT()}, + std::vector>{}); + + // Create shared writer options (this simulates the scenario where + // insertTableHandle_->writerOptions() returns a shared object) + auto sharedWriterOptions = std::make_shared(); + + // Create a data sink with multiple writers (one for each bucket) + auto dataSink = createDataSink( + rowType_, + outputDirectory->getPath(), + dwio::common::FileFormat::DWRF, + {}, + bucketProperty, + sharedWriterOptions); + + const auto vectors = createVectors(200, 3); + + // Write data - this should work without throwing exceptions + for (const auto& vector : vectors) { + dataSink->appendData(vector); + } + + while (!dataSink->finish()) { + } + const auto partitions = dataSink->close(); + + ASSERT_GT(partitions.size(), 1); + createDuckDbTable(vectors); + verifyWrittenData( + outputDirectory->getPath(), static_cast(partitions.size())); +} + +DEBUG_ONLY_TEST_F(HiveDataSinkTest, perWriterMemoryPool) { + const auto outputDirectory = TempDirectoryPath::create(); + auto writerOptions = std::make_shared(); + + const auto rowType = ROW({"c0", "p0"}, {BIGINT(), VARCHAR()}); + auto dataSink = createDataSink( + rowType, + outputDirectory->getPath(), + dwio::common::FileFormat::DWRF, + {"p0"}, + nullptr, + writerOptions); + + std::set writerPoolNames; + SCOPED_TESTVALUE_SET( + "facebook::velox::dwrf::Writer::write", + std::function([&](dwrf::Writer* writer) { + // Memory pool hierarchy: + // Hive writer pool -> DWRF writer pool -> DWRF category leaf pools. + auto* innerPool = writer->getContext() + .getMemoryPool(dwrf::MemoryUsageCategory::GENERAL) + .parent(); + ASSERT_NE(innerPool, nullptr); + auto* writerPool = innerPool->parent(); + ASSERT_NE(writerPool, nullptr); + writerPoolNames.insert(writerPool->name()); + })); + + dataSink->appendData(makeRowVector({ + makeFlatVector(200, folly::identity), + makeFlatVector( + 200, [](auto row) { return row % 2 == 0 ? "part_0" : "part_1"; }), + })); + + ASSERT_EQ(writerPoolNames.size(), 2); + ASSERT_TRUE(dataSink->finish()); + ASSERT_EQ(dataSink->close().size(), 2); +} + +TEST_F(HiveDataSinkTest, sanitizeFileName) { + auto sanitizeFileName = [](std::string fileName) { + HiveInsertFileNameGenerator::sanitizeFileName(fileName); + return fileName; + }; + ASSERT_EQ(sanitizeFileName("abc"), "abc"); + ASSERT_EQ(sanitizeFileName("abc_.-ABC012"), "abc_.-ABC012"); + ASSERT_EQ(sanitizeFileName("abc_.-ABC012\\/"), "abc_.-ABC012__"); + ASSERT_EQ(sanitizeFileName("local://abc/bcd/"), "local___abc_bcd_"); +} + } // namespace } // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/tests/HivePartitionFunctionTest.cpp b/velox/connectors/hive/tests/HivePartitionFunctionTest.cpp index 4fdf68684b0..1adfb387445 100644 --- a/velox/connectors/hive/tests/HivePartitionFunctionTest.cpp +++ b/velox/connectors/hive/tests/HivePartitionFunctionTest.cpp @@ -832,10 +832,11 @@ TEST_F(HivePartitionFunctionTest, skew) { } std::vector partitionedInputs; for (int partition = 0; partition < numRemotePartitions; ++partition) { - partitionedInputs.push_back(exec::wrap( - partitionSizeVectors[partition], - partitionIndicesVector[partition], - input)); + partitionedInputs.push_back( + exec::wrap( + partitionSizeVectors[partition], + partitionIndicesVector[partition], + input)); } // Checks that the bad hive partition function (using round-robin map from diff --git a/velox/connectors/hive/tests/HivePartitionUtilTest.cpp b/velox/connectors/hive/tests/HivePartitionNameTest.cpp similarity index 85% rename from velox/connectors/hive/tests/HivePartitionUtilTest.cpp rename to velox/connectors/hive/tests/HivePartitionNameTest.cpp index 8598f46742f..b236951946f 100644 --- a/velox/connectors/hive/tests/HivePartitionUtilTest.cpp +++ b/velox/connectors/hive/tests/HivePartitionNameTest.cpp @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "velox/connectors/hive/HivePartitionUtil.h" +#include "velox/connectors/hive/HivePartitionName.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/dwio/catalog/fbhive/FileUtils.h" #include "velox/vector/tests/utils/VectorTestBase.h" @@ -26,7 +26,7 @@ using namespace facebook::velox; using namespace facebook::velox::connector::hive; using namespace facebook::velox::dwio::catalog::fbhive; -class HivePartitionUtilTest : public ::testing::Test, +class HivePartitionNameTest : public ::testing::Test, public velox::test::VectorTestBase { protected: static void SetUpTestCase() { @@ -62,9 +62,26 @@ class HivePartitionUtilTest : public ::testing::Test, input->size(), partitions); } + + static auto toPartitionName() { + return [](auto value, const TypePtr& type, int /*columnIndex*/) { + return HivePartitionName::toName(value, type); + }; + } + + std::vector> extractPartitionKeyValues( + RowVectorPtr input, + const std::vector& partitionChannels, + vector_size_t rowIndex = 0) { + return HivePartitionName::partitionKeyValues( + rowIndex, + makePartitionsVector(input, partitionChannels), + /*nullValueString=*/"", + toPartitionName()); + } }; -TEST_F(HivePartitionUtilTest, partitionName) { +TEST_F(HivePartitionNameTest, partitionName) { { RowVectorPtr input = makeRowVector( {"flat_bool_col", @@ -102,9 +119,7 @@ TEST_F(HivePartitionUtilTest, partitionName) { EXPECT_EQ( FileUtils::makePartName( - extractPartitionKeyValues( - makePartitionsVector(input, partitionChannels), 0), - true), + extractPartitionKeyValues(input, partitionChannels), true), folly::join( "/", std::vector( @@ -124,14 +139,12 @@ TEST_F(HivePartitionUtilTest, partitionName) { VELOX_ASSERT_THROW( FileUtils::makePartName( - extractPartitionKeyValues( - makePartitionsVector(input, partitionChannels), 0), - true), + extractPartitionKeyValues(input, partitionChannels), true), "Unsupported partition type: MAP"); } } -TEST_F(HivePartitionUtilTest, partitionNameForNull) { +TEST_F(HivePartitionNameTest, partitionNameForNull) { std::vector partitionColumnNames{ "flat_bool_col", "flat_tinyint_col", @@ -155,15 +168,14 @@ TEST_F(HivePartitionUtilTest, partitionNameForNull) { for (auto i = 0; i < partitionColumnNames.size(); i++) { std::vector partitionChannels = {(column_index_t)i}; - auto partitionEntries = extractPartitionKeyValues( - makePartitionsVector(input, partitionChannels), 0); + auto partitionEntries = extractPartitionKeyValues(input, partitionChannels); EXPECT_EQ(1, partitionEntries.size()); EXPECT_EQ(partitionColumnNames[i], partitionEntries[0].first); EXPECT_EQ("", partitionEntries[0].second); } } -TEST_F(HivePartitionUtilTest, timestampPartitionValueFormatting) { +TEST_F(HivePartitionNameTest, timestampPartitionValueFormatting) { // Test timestamp partition value formatting to match Presto's // java.sql.Timestamp.toString() behavior: removes trailing zeros but keeps at // least one decimal place @@ -192,11 +204,10 @@ TEST_F(HivePartitionUtilTest, timestampPartitionValueFormatting) { makeRowVector({"timestamp_col"}, {makeFlatVector(timestamps)}); std::vector partitionChannels{0}; - auto partitionsVector = makePartitionsVector(input, partitionChannels); for (size_t i = 0; i < timestamps.size(); i++) { - auto partitionEntries = extractPartitionKeyValues( - partitionsVector, static_cast(i)); + auto partitionEntries = + extractPartitionKeyValues(input, partitionChannels, i); EXPECT_EQ(1, partitionEntries.size()); EXPECT_EQ("timestamp_col", partitionEntries[0].first); diff --git a/velox/connectors/hive/tests/PartitionIdGeneratorTest.cpp b/velox/connectors/hive/tests/PartitionIdGeneratorTest.cpp index 271e4599d3f..7dcb0d5e195 100644 --- a/velox/connectors/hive/tests/PartitionIdGeneratorTest.cpp +++ b/velox/connectors/hive/tests/PartitionIdGeneratorTest.cpp @@ -16,6 +16,7 @@ #include "velox/connectors/hive/PartitionIdGenerator.h" #include "velox/common/base/tests/GTestUtils.h" +#include "velox/connectors/hive/HivePartitionName.h" #include "velox/type/TimestampConversion.h" #include "velox/vector/tests/utils/VectorTestBase.h" @@ -34,7 +35,7 @@ class PartitionIdGeneratorTest : public ::testing::Test, TEST_F(PartitionIdGeneratorTest, consecutiveIdsSingleKey) { auto numPartitions = 100; - PartitionIdGenerator idGenerator(ROW({VARCHAR()}), {0}, 100, pool(), true); + PartitionIdGenerator idGenerator(ROW({VARCHAR()}), {0}, 100, pool()); auto input = makeRowVector( {makeFlatVector(numPartitions * 3, [&](auto row) { @@ -56,7 +57,7 @@ TEST_F(PartitionIdGeneratorTest, consecutiveIdsSingleKey) { TEST_F(PartitionIdGeneratorTest, consecutiveIdsMultipleKeys) { PartitionIdGenerator idGenerator( - ROW({VARCHAR(), INTEGER()}), {0, 1}, 100, pool(), true); + ROW({VARCHAR(), INTEGER()}), {0, 1}, 100, pool()); auto input = makeRowVector({ makeFlatVector( @@ -83,7 +84,7 @@ TEST_F(PartitionIdGeneratorTest, consecutiveIdsMultipleKeys) { TEST_F(PartitionIdGeneratorTest, multipleBoolKeys) { PartitionIdGenerator idGenerator( - ROW({BOOLEAN(), BOOLEAN()}), {0, 1}, 100, pool(), true); + ROW({BOOLEAN(), BOOLEAN()}), {0, 1}, 100, pool()); auto input = makeRowVector({ makeFlatVector( @@ -109,7 +110,7 @@ TEST_F(PartitionIdGeneratorTest, multipleBoolKeys) { } TEST_F(PartitionIdGeneratorTest, stableIdsSingleKey) { - PartitionIdGenerator idGenerator(ROW({BIGINT()}), {0}, 100, pool(), true); + PartitionIdGenerator idGenerator(ROW({BIGINT()}), {0}, 100, pool()); auto numPartitions = 40; auto input = makeRowVector({ @@ -136,7 +137,7 @@ TEST_F(PartitionIdGeneratorTest, stableIdsSingleKey) { TEST_F(PartitionIdGeneratorTest, stableIdsMultipleKeys) { PartitionIdGenerator idGenerator( - ROW({BIGINT(), VARCHAR(), INTEGER()}), {1, 2}, 100, pool(), true); + ROW({BIGINT(), VARCHAR(), INTEGER()}), {1, 2}, 100, pool()); const vector_size_t size = 1'000; auto input = makeRowVector({ @@ -175,7 +176,7 @@ TEST_F(PartitionIdGeneratorTest, stableIdsMultipleKeys) { TEST_F(PartitionIdGeneratorTest, partitionKeysCaseSensitive) { PartitionIdGenerator idGenerator( - ROW({"cc0", "Cc1"}, {BIGINT(), VARCHAR()}), {1}, 100, pool(), false); + ROW({"cc0", "Cc+1"}, {BIGINT(), VARCHAR()}), {1}, 100, pool()); auto input = makeRowVector({ makeFlatVector({1, 2, 3}), @@ -184,12 +185,19 @@ TEST_F(PartitionIdGeneratorTest, partitionKeysCaseSensitive) { raw_vector firstTimeIds; idGenerator.run(input, firstTimeIds); - EXPECT_EQ("Cc1=apple", idGenerator.partitionName(0)); - EXPECT_EQ("Cc1=orange", idGenerator.partitionName(1)); + + EXPECT_EQ( + "Cc+1=apple", + HivePartitionName::partitionName( + 0, idGenerator.partitionValues(), /*partitionKeyAsLowerCase=*/false)); + EXPECT_EQ( + "Cc+1=orange", + HivePartitionName::partitionName( + 1, idGenerator.partitionValues(), /*partitionKeyAsLowerCase=*/false)); } TEST_F(PartitionIdGeneratorTest, numPartitions) { - PartitionIdGenerator idGenerator(ROW({BIGINT()}), {0}, 100, pool(), true); + PartitionIdGenerator idGenerator(ROW({BIGINT()}), {0}, 100, pool()); // First run to process partition 0,..,9. Total num of partitions processed by // far is 10. @@ -224,7 +232,7 @@ TEST_F(PartitionIdGeneratorTest, limitOfPartitionNumber) { auto maxPartitions = 100; PartitionIdGenerator idGenerator( - ROW({INTEGER()}), {0}, maxPartitions, pool(), true); + ROW({INTEGER()}), {0}, maxPartitions, pool()); auto input = makeRowVector({ makeFlatVector(maxPartitions + 1, [](auto row) { return row; }), @@ -239,7 +247,7 @@ TEST_F(PartitionIdGeneratorTest, limitOfPartitionNumber) { TEST_F(PartitionIdGeneratorTest, timestampPartitionKeyComparasion) { PartitionIdGenerator idGenerator( - ROW({"timestamp_col"}, {TIMESTAMP()}), {0}, 100, pool(), true); + ROW({"timestamp_col"}, {TIMESTAMP()}), {0}, 100, pool()); auto timestampResult = util::fromTimestampString( "2025-01-02 00:00:00.0", util::TimestampParseMode::kPrestoCast); auto input = makeRowVector({ @@ -247,13 +255,17 @@ TEST_F(PartitionIdGeneratorTest, timestampPartitionKeyComparasion) { }); raw_vector testTimeIds; idGenerator.run(input, testTimeIds); + EXPECT_EQ( - idGenerator.partitionName(testTimeIds[0]), + HivePartitionName::partitionName( + testTimeIds[0], + idGenerator.partitionValues(), + /*partitionKeyAsLowerCase=*/true), "timestamp_col=2025-01-01 16%3A00%3A00.0"); } TEST_F(PartitionIdGeneratorTest, timestampPartitionKey) { - PartitionIdGenerator idGenerator(ROW({TIMESTAMP()}), {0}, 100, pool(), true); + PartitionIdGenerator idGenerator(ROW({TIMESTAMP()}), {0}, 100, pool()); auto numPartitions = 50; auto input = makeRowVector({ @@ -325,8 +337,7 @@ TEST_F(PartitionIdGeneratorTest, supportedPartitionKeyTypes) { }), {0, 1, 2, 3, 4, 5, 6, 7}, 100, - pool(), - true); + pool()); auto input = makeRowVector( {makeNullableFlatVector( @@ -362,8 +373,7 @@ TEST_F(PartitionIdGeneratorTest, supportedPartitionKeyTypes) { for (column_index_t i = 1; i < input->childrenSize(); i++) { VELOX_ASSERT_THROW( - PartitionIdGenerator( - asRowType(input->type()), {i}, 100, pool(), true), + PartitionIdGenerator(asRowType(input->type()), {i}, 100, pool()), fmt::format( "Unsupported partition type: {}.", input->childAt(i)->type()->toString())); diff --git a/velox/connectors/hive/tests/TableHandleTest.cpp b/velox/connectors/hive/tests/TableHandleTest.cpp index 53cb95a2722..c342add7f4e 100644 --- a/velox/connectors/hive/tests/TableHandleTest.cpp +++ b/velox/connectors/hive/tests/TableHandleTest.cpp @@ -16,11 +16,14 @@ #include "velox/connectors/hive/TableHandle.h" +#include #include "gtest/gtest.h" #include "velox/common/base/tests/GTestUtils.h" +#include "velox/connectors/hive/ExtractionUtils.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" using namespace facebook::velox; +using namespace facebook::velox::connector::hive; TEST(FileHandleTest, hiveColumnHandle) { Type::registerSerDe(); @@ -35,7 +38,7 @@ TEST(FileHandleTest, hiveColumnHandle) { ASSERT_EQ(columnHandle->name(), "columnHandle"); ASSERT_EQ( columnHandle->columnType(), - connector::hive::HiveColumnHandle::ColumnType::kRegular); + connector::hive::FileColumnHandle::ColumnType::kRegular); ASSERT_EQ(columnHandle->dataType(), columnType); ASSERT_EQ(columnHandle->hiveType(), columnType); ASSERT_FALSE(columnHandle->isPartitionKey()); @@ -55,3 +58,727 @@ TEST(FileHandleTest, hiveColumnHandle) { {"c0.c0c1[3][\"foo\"].c0c1c0"}), "data type ROW>>> and hive type ROW do not match"); } + +TEST(TableHandleTest, hiveTableHandleDbName) { + connector::hive::HiveTableHandle::registerSerDe(); + + // Default dbName is empty. + auto handleNoDb = std::make_shared( + "test-connector", + "test_table", + common::SubfieldFilters{}, + /*remainingFilter=*/nullptr); + ASSERT_TRUE(handleNoDb->dbName().empty()); + ASSERT_EQ(handleNoDb->tableName(), "test_table"); + + // Explicit dbName is preserved. + auto handleWithDb = std::make_shared( + "test-connector", + "test_table", + common::SubfieldFilters{}, + /*remainingFilter=*/nullptr, + /*dataColumns=*/nullptr, + /*indexColumns=*/std::vector{}, + /*tableParameters=*/std::unordered_map{}, + /*filterColumnHandles=*/ + std::vector{}, + /*sampleRate=*/1.0, + /*dbName=*/"test_db"); + ASSERT_EQ(handleWithDb->dbName(), "test_db"); + ASSERT_EQ(handleWithDb->tableName(), "test_table"); + + // Serialization round-trip preserves dbName. + auto obj = handleWithDb->serialize(); + auto clone = ISerializable::deserialize( + obj, /*context=*/nullptr); + ASSERT_EQ(clone->dbName(), "test_db"); + ASSERT_EQ(clone->tableName(), "test_table"); + + // Round-trip with empty dbName omits the field. + auto objNoDb = handleNoDb->serialize(); + auto cloneNoDb = ISerializable::deserialize( + objNoDb, /*context=*/nullptr); + ASSERT_TRUE(cloneNoDb->dbName().empty()); +} + +TEST(TableHandleTest, hiveTableHandleIndexSupport) { + // Test HiveTableHandle without index columns. + auto tableHandleWithoutIndex = + std::make_shared( + "test-connector", + "test_table", + common::SubfieldFilters{}, + /*remainingFilter=*/nullptr, + /*dataColumns=*/nullptr, + /*indexColumns=*/std::vector{}); + + ASSERT_FALSE(tableHandleWithoutIndex->supportsIndexLookup()); + ASSERT_TRUE(tableHandleWithoutIndex->needsIndexSplit()); + ASSERT_TRUE(tableHandleWithoutIndex->indexColumns().empty()); + + // Test HiveTableHandle with index columns. + auto tableHandleWithIndex = + std::make_shared( + "test-connector", + "test_table", + common::SubfieldFilters{}, + /*remainingFilter=*/nullptr, + /*dataColumns=*/nullptr, + /*indexColumns=*/std::vector{"col1", "col2"}); + + ASSERT_TRUE(tableHandleWithIndex->supportsIndexLookup()); + ASSERT_TRUE(tableHandleWithIndex->needsIndexSplit()); + ASSERT_EQ(tableHandleWithIndex->indexColumns().size(), 2); + ASSERT_EQ(tableHandleWithIndex->indexColumns()[0], "col1"); + ASSERT_EQ(tableHandleWithIndex->indexColumns()[1], "col2"); +} + +// --- Column extraction pushdown tests --- + +using EPE = ExtractionPathElementPtr; + +TEST(ExtractionStepTest, nameRoundTrip) { + // Verify all extraction steps can be converted to names and back. + std::vector steps = { + ExtractionStep::kStructField, + ExtractionStep::kMapKeys, + ExtractionStep::kMapValues, + ExtractionStep::kMapKeyFilter, + ExtractionStep::kArrayElements, + ExtractionStep::kSize, + }; + for (auto step : steps) { + auto name = extractionStepName(step); + ASSERT_EQ(extractionStepFromName(name), step); + } +} + +TEST(ExtractionStepTest, nameValues) { + ASSERT_EQ(extractionStepName(ExtractionStep::kStructField), "STRUCT_FIELD"); + ASSERT_EQ(extractionStepName(ExtractionStep::kMapKeys), "MAP_KEYS"); + ASSERT_EQ(extractionStepName(ExtractionStep::kMapValues), "MAP_VALUES"); + ASSERT_EQ( + extractionStepName(ExtractionStep::kMapKeyFilter), "MAP_KEY_FILTER"); + ASSERT_EQ( + extractionStepName(ExtractionStep::kArrayElements), "ARRAY_ELEMENTS"); + ASSERT_EQ(extractionStepName(ExtractionStep::kSize), "SIZE"); +} + +TEST(DeriveExtractionOutputTypeTest, mapKeys) { + // map_keys(col) where col: MAP(VARCHAR, BIGINT) -> ARRAY(VARCHAR). + auto hiveType = MAP(VARCHAR(), BIGINT()); + std::vector chain = { + ExtractionPathElement::simple(ExtractionStep::kMapKeys)}; + auto outputType = deriveExtractionOutputType(hiveType, chain); + ASSERT_TRUE(outputType->equivalent(*ARRAY(VARCHAR()))); +} + +TEST(DeriveExtractionOutputTypeTest, mapValues) { + // map_values(col) where col: MAP(VARCHAR, BIGINT) -> ARRAY(BIGINT). + auto hiveType = MAP(VARCHAR(), BIGINT()); + std::vector chain = { + ExtractionPathElement::simple(ExtractionStep::kMapValues)}; + auto outputType = deriveExtractionOutputType(hiveType, chain); + ASSERT_TRUE(outputType->equivalent(*ARRAY(BIGINT()))); +} + +TEST(DeriveExtractionOutputTypeTest, size) { + // cardinality(col) where col: MAP(VARCHAR, BIGINT) -> BIGINT. + auto mapType = MAP(VARCHAR(), BIGINT()); + std::vector chain = { + ExtractionPathElement::simple(ExtractionStep::kSize)}; + auto outputType = deriveExtractionOutputType(mapType, chain); + ASSERT_TRUE(outputType->equivalent(*BIGINT())); + + // cardinality(col) where col: ARRAY(BIGINT) -> BIGINT. + auto arrayType = ARRAY(BIGINT()); + outputType = deriveExtractionOutputType(arrayType, chain); + ASSERT_TRUE(outputType->equivalent(*BIGINT())); +} + +TEST(DeriveExtractionOutputTypeTest, structFieldMapKeys) { + // map_keys(col.a.b) where col: ROW(a: ROW(b: MAP(K, V))) -> ARRAY(K). + auto hiveType = ROW({{"a", ROW({{"b", MAP(VARCHAR(), BIGINT())}})}}); + std::vector chain = { + ExtractionPathElement::structField("a"), + ExtractionPathElement::structField("b"), + ExtractionPathElement::simple(ExtractionStep::kMapKeys)}; + auto outputType = deriveExtractionOutputType(hiveType, chain); + ASSERT_TRUE(outputType->equivalent(*ARRAY(VARCHAR()))); +} + +TEST(DeriveExtractionOutputTypeTest, structFieldSize) { + // cardinality(col.features) where col: ROW(features: ARRAY(FLOAT), + // label: INT) -> BIGINT. + auto hiveType = ROW({{"features", ARRAY(REAL())}, {"label", INTEGER()}}); + std::vector chain = { + ExtractionPathElement::structField("features"), + ExtractionPathElement::simple(ExtractionStep::kSize)}; + auto outputType = deriveExtractionOutputType(hiveType, chain); + ASSERT_TRUE(outputType->equivalent(*BIGINT())); +} + +TEST(DeriveExtractionOutputTypeTest, mapValuesArrayElementsMapKeys) { + // map_keys(map_values(col)) where col: MAP(K1, MAP(K2, V)) + // -> ARRAY(ARRAY(K2)). + auto hiveType = MAP(VARCHAR(), MAP(INTEGER(), DOUBLE())); + std::vector chain = { + ExtractionPathElement::simple(ExtractionStep::kMapValues), + ExtractionPathElement::simple(ExtractionStep::kArrayElements), + ExtractionPathElement::simple(ExtractionStep::kMapKeys)}; + auto outputType = deriveExtractionOutputType(hiveType, chain); + ASSERT_TRUE(outputType->equivalent(*ARRAY(ARRAY(INTEGER())))); +} + +TEST(DeriveExtractionOutputTypeTest, nestedArrayElements) { + // map_keys(array_elements(map_values(col))) + // where col: MAP(K1, ARRAY(MAP(K2, V))) -> ARRAY(ARRAY(ARRAY(K2))). + auto hiveType = MAP(VARCHAR(), ARRAY(MAP(INTEGER(), DOUBLE()))); + std::vector chain = { + ExtractionPathElement::simple(ExtractionStep::kMapValues), + ExtractionPathElement::simple(ExtractionStep::kArrayElements), + ExtractionPathElement::simple(ExtractionStep::kArrayElements), + ExtractionPathElement::simple(ExtractionStep::kMapKeys)}; + auto outputType = deriveExtractionOutputType(hiveType, chain); + ASSERT_TRUE(outputType->equivalent(*ARRAY(ARRAY(ARRAY(INTEGER()))))); +} + +TEST(DeriveExtractionOutputTypeTest, mapValuesStructField) { + // map_values(col).x where col: MAP(K, ROW(x: INT, y: INT)) + // -> ARRAY(INT). + auto hiveType = MAP(VARCHAR(), ROW({{"x", INTEGER()}, {"y", INTEGER()}})); + std::vector chain = { + ExtractionPathElement::simple(ExtractionStep::kMapValues), + ExtractionPathElement::simple(ExtractionStep::kArrayElements), + ExtractionPathElement::structField("x")}; + auto outputType = deriveExtractionOutputType(hiveType, chain); + ASSERT_TRUE(outputType->equivalent(*ARRAY(INTEGER()))); +} + +TEST(DeriveExtractionOutputTypeTest, mapKeyFilter) { + // map_subset(col, ARRAY['a', 'b']) where col: MAP(VARCHAR, BIGINT) + // -> MAP(VARCHAR, BIGINT). + auto hiveType = MAP(VARCHAR(), BIGINT()); + std::vector chain = { + ExtractionPathElement::mapKeyFilter(std::vector{"a", "b"})}; + auto outputType = deriveExtractionOutputType(hiveType, chain); + ASSERT_TRUE(outputType->equivalent(*MAP(VARCHAR(), BIGINT()))); +} + +TEST(DeriveExtractionOutputTypeTest, mapKeyFilterMapValuesStructField) { + // element_at(col, 'foo').x via extraction chain: + // [MapKeyFilter(["foo"]), MapValues, ArrayElements, StructField("x")] + // where col: MAP(VARCHAR, ROW(x: INT, y: INT)) -> ARRAY(INT). + auto hiveType = MAP(VARCHAR(), ROW({{"x", INTEGER()}, {"y", INTEGER()}})); + std::vector chain = { + ExtractionPathElement::mapKeyFilter(std::vector{"foo"}), + ExtractionPathElement::simple(ExtractionStep::kMapValues), + ExtractionPathElement::simple(ExtractionStep::kArrayElements), + ExtractionPathElement::structField("x")}; + auto outputType = deriveExtractionOutputType(hiveType, chain); + ASSERT_TRUE(outputType->equivalent(*ARRAY(INTEGER()))); +} + +TEST(DeriveExtractionOutputTypeTest, nestedKeyFilter) { + // Nested key filter: MAP(K1, MAP(VARCHAR, ROW(x: INT, y: INT))) + // Chain: [MapValues, AE, MapKeyFilter(["foo"]), MapValues, AE, + // StructField("x")] + // -> ARRAY(ARRAY(INT)). + auto hiveType = + MAP(VARCHAR(), MAP(VARCHAR(), ROW({{"x", INTEGER()}, {"y", DOUBLE()}}))); + std::vector chain = { + ExtractionPathElement::simple(ExtractionStep::kMapValues), + ExtractionPathElement::simple(ExtractionStep::kArrayElements), + ExtractionPathElement::mapKeyFilter(std::vector{"foo"}), + ExtractionPathElement::simple(ExtractionStep::kMapValues), + ExtractionPathElement::simple(ExtractionStep::kArrayElements), + ExtractionPathElement::structField("x")}; + auto outputType = deriveExtractionOutputType(hiveType, chain); + ASSERT_TRUE(outputType->equivalent(*ARRAY(ARRAY(INTEGER())))); +} + +TEST(DeriveExtractionOutputTypeTest, emptyChain) { + // Empty chain means pass-through. + auto hiveType = MAP(VARCHAR(), BIGINT()); + std::vector chain = {}; + auto outputType = deriveExtractionOutputType(hiveType, chain); + ASSERT_TRUE(outputType->equivalent(*hiveType)); +} + +TEST(DeriveExtractionOutputTypeTest, errorMissingArrayElementsAfterMapValues) { + // [MapValues, MapKeys] on MAP(K1, ARRAY(MAP(K2, V))) — missing + // ArrayElements. + auto hiveType = MAP(VARCHAR(), ARRAY(MAP(INTEGER(), DOUBLE()))); + std::vector chain = { + ExtractionPathElement::simple(ExtractionStep::kMapValues), + ExtractionPathElement::simple(ExtractionStep::kMapKeys)}; + VELOX_ASSERT_THROW( + deriveExtractionOutputType(hiveType, chain), + "MapKeys requires MAP input, got: ARRAY"); +} + +TEST(DeriveExtractionOutputTypeTest, errorMissingArrayElementsAfterMapKeys) { + // [MapKeys, StructField("x")] on MAP(ROW(x: INT, y: INT), V) — missing + // ArrayElements. + auto hiveType = MAP(ROW({{"x", INTEGER()}, {"y", INTEGER()}}), BIGINT()); + std::vector chain = { + ExtractionPathElement::simple(ExtractionStep::kMapKeys), + ExtractionPathElement::structField("x")}; + VELOX_ASSERT_THROW( + deriveExtractionOutputType(hiveType, chain), + "StructField requires ROW input, got: ARRAY"); +} + +TEST(DeriveExtractionOutputTypeTest, errorSizeNotTerminal) { + // Size must be the last step. + auto hiveType = MAP(VARCHAR(), BIGINT()); + std::vector chain = { + ExtractionPathElement::simple(ExtractionStep::kSize), + ExtractionPathElement::simple(ExtractionStep::kMapKeys)}; + VELOX_ASSERT_THROW( + deriveExtractionOutputType(hiveType, chain), + "Size must be the last step"); +} + +TEST(DeriveExtractionOutputTypeTest, errorStructFieldOnMap) { + auto hiveType = MAP(VARCHAR(), BIGINT()); + std::vector chain = {ExtractionPathElement::structField("x")}; + VELOX_ASSERT_THROW( + deriveExtractionOutputType(hiveType, chain), + "StructField requires ROW input"); +} + +TEST(DeriveExtractionOutputTypeTest, errorMapKeysOnArray) { + auto hiveType = ARRAY(BIGINT()); + std::vector chain = { + ExtractionPathElement::simple(ExtractionStep::kMapKeys)}; + VELOX_ASSERT_THROW( + deriveExtractionOutputType(hiveType, chain), + "MapKeys requires MAP input"); +} + +TEST(DeriveExtractionOutputTypeTest, errorSizeOnScalar) { + auto hiveType = BIGINT(); + std::vector chain = { + ExtractionPathElement::simple(ExtractionStep::kSize)}; + VELOX_ASSERT_THROW( + deriveExtractionOutputType(hiveType, chain), + "Size requires MAP or ARRAY input"); +} + +TEST(DeriveExtractionOutputTypeTest, errorNonExistentField) { + auto hiveType = ROW({{"a", INTEGER()}}); + std::vector chain = {ExtractionPathElement::structField("nonexistent")}; + VELOX_ASSERT_THROW( + deriveExtractionOutputType(hiveType, chain), + "non-existent field: nonexistent"); +} + +TEST(ColumnExtractionTest, singleExtraction) { + // Single extraction: map_keys(col) where col: MAP(VARCHAR, BIGINT). + Type::registerSerDe(); + HiveColumnHandle::registerSerDe(); + + auto hiveType = MAP(VARCHAR(), BIGINT()); + auto outputType = ARRAY(VARCHAR()); + std::vector extractions = { + {"keys", + {ExtractionPathElement::simple(ExtractionStep::kMapKeys)}, + outputType}}; + + auto handle = std::make_shared( + "col", + HiveColumnHandle::ColumnType::kRegular, + outputType, + hiveType, + std::vector{}, + std::move(extractions)); + + ASSERT_EQ(handle->name(), "col"); + ASSERT_TRUE(handle->dataType()->equivalent(*ARRAY(VARCHAR()))); + ASSERT_TRUE(handle->hiveType()->equivalent(*MAP(VARCHAR(), BIGINT()))); + ASSERT_TRUE(handle->requiredSubfields().empty()); + ASSERT_EQ(handle->extractions().size(), 1); + ASSERT_EQ(handle->extractions()[0].outputName, "keys"); + ASSERT_EQ(handle->extractions()[0].chain.size(), 1); + ASSERT_EQ( + handle->extractions()[0].chain[0]->step(), ExtractionStep::kMapKeys); + ASSERT_TRUE(handle->extractions()[0].dataType->equivalent(*ARRAY(VARCHAR()))); + + // Serialization round-trip. + auto obj = handle->serialize(); + auto clone = ISerializable::deserialize(obj); + ASSERT_EQ(clone->toString(), handle->toString()); + ASSERT_EQ(clone->extractions().size(), 1); + ASSERT_EQ(clone->extractions()[0].outputName, "keys"); + ASSERT_TRUE(clone->extractions()[0].dataType->equivalent(*ARRAY(VARCHAR()))); +} + +TEST(ColumnExtractionTest, multipleExtractions) { + // Multiple extractions from the same column: + // map_keys(col), cardinality(col) where col: MAP(VARCHAR, BIGINT). + Type::registerSerDe(); + HiveColumnHandle::registerSerDe(); + + auto hiveType = MAP(VARCHAR(), BIGINT()); + auto keysType = ARRAY(VARCHAR()); + auto sizeType = BIGINT(); + auto rowOutputType = ROW({{"keys", keysType}, {"size", sizeType}}); + std::vector extractions = { + {"keys", + {ExtractionPathElement::simple(ExtractionStep::kMapKeys)}, + keysType}, + {"size", + {ExtractionPathElement::simple(ExtractionStep::kSize)}, + sizeType}}; + + auto handle = std::make_shared( + "col", + HiveColumnHandle::ColumnType::kRegular, + rowOutputType, + hiveType, + std::vector{}, + std::move(extractions)); + + ASSERT_EQ(handle->extractions().size(), 2); + ASSERT_EQ(handle->extractions()[0].outputName, "keys"); + ASSERT_EQ(handle->extractions()[1].outputName, "size"); + + // Serialization round-trip. + auto obj = handle->serialize(); + auto clone = ISerializable::deserialize(obj); + ASSERT_EQ(clone->extractions().size(), 2); + ASSERT_EQ(clone->extractions()[0].outputName, "keys"); + ASSERT_EQ(clone->extractions()[1].outputName, "size"); + ASSERT_TRUE(clone->extractions()[0].dataType->equivalent(*ARRAY(VARCHAR()))); + ASSERT_TRUE(clone->extractions()[1].dataType->equivalent(*BIGINT())); +} + +TEST(ColumnExtractionTest, complexChainWithKeyFilter) { + // Chain: [MapKeyFilter(["foo", "bar"]), MapValues, AE, StructField("x")] + // on MAP(VARCHAR, ROW(x: INT, y: INT)) -> ARRAY(INT). + Type::registerSerDe(); + HiveColumnHandle::registerSerDe(); + + auto hiveType = MAP(VARCHAR(), ROW({{"x", INTEGER()}, {"y", INTEGER()}})); + auto outputType = ARRAY(INTEGER()); + std::vector extractions = { + {"col_x", + {ExtractionPathElement::mapKeyFilter( + std::vector{"foo", "bar"}), + ExtractionPathElement::simple(ExtractionStep::kMapValues), + ExtractionPathElement::simple(ExtractionStep::kArrayElements), + ExtractionPathElement::structField("x")}, + outputType}}; + + auto handle = std::make_shared( + "col", + HiveColumnHandle::ColumnType::kRegular, + outputType, + hiveType, + std::vector{}, + std::move(extractions)); + + ASSERT_EQ(handle->extractions().size(), 1); + ASSERT_EQ(handle->extractions()[0].chain.size(), 4); + + // Verify key filter keys are preserved. + ASSERT_THAT( + static_cast( + *handle->extractions()[0].chain[0]) + .filterKeys(), + testing::ElementsAre("foo", "bar")); + + // Serialization round-trip. + auto obj = handle->serialize(); + auto clone = ISerializable::deserialize(obj); + ASSERT_EQ(clone->extractions().size(), 1); + ASSERT_EQ(clone->extractions()[0].chain.size(), 4); + ASSERT_THAT( + static_cast( + *clone->extractions()[0].chain[0]) + .filterKeys(), + testing::ElementsAre("foo", "bar")); +} + +TEST(ColumnExtractionTest, intFilterKeys) { + // MapKeyFilter with integer keys. + Type::registerSerDe(); + HiveColumnHandle::registerSerDe(); + + auto hiveType = MAP(BIGINT(), VARCHAR()); + auto outputType = MAP(BIGINT(), VARCHAR()); + std::vector extractions = { + {"filtered", + {ExtractionPathElement::mapKeyFilter(std::vector{1, 2, 3})}, + outputType}}; + + auto handle = std::make_shared( + "col", + HiveColumnHandle::ColumnType::kRegular, + outputType, + hiveType, + std::vector{}, + std::move(extractions)); + + ASSERT_THAT( + static_cast( + *handle->extractions()[0].chain[0]) + .filterKeys(), + testing::ElementsAre(1, 2, 3)); + + // Serialization round-trip. + auto obj = handle->serialize(); + auto clone = ISerializable::deserialize(obj); + ASSERT_THAT( + static_cast( + *clone->extractions()[0].chain[0]) + .filterKeys(), + testing::ElementsAre(1, 2, 3)); +} + +TEST(ColumnExtractionTest, mutualExclusivity) { + // Cannot set both extractions and requiredSubfields. + auto hiveType = MAP(VARCHAR(), BIGINT()); + std::vector requiredSubfields; + requiredSubfields.emplace_back("col[\"foo\"]"); + std::vector extractions = { + {"keys", + {ExtractionPathElement::simple(ExtractionStep::kMapKeys)}, + ARRAY(VARCHAR())}}; + + VELOX_ASSERT_THROW( + std::make_shared( + "col", + HiveColumnHandle::ColumnType::kRegular, + ARRAY(VARCHAR()), + hiveType, + std::move(requiredSubfields), + std::move(extractions)), + "mutually exclusive"); +} + +TEST(ColumnExtractionTest, extractionOutputTypeMismatch) { + // Declared output type doesn't match derived type. + auto hiveType = MAP(VARCHAR(), BIGINT()); + std::vector extractions = { + {"keys", + {ExtractionPathElement::simple(ExtractionStep::kMapKeys)}, + // Wrong: should be ARRAY(VARCHAR), not ARRAY(BIGINT). + ARRAY(BIGINT())}}; + + VELOX_ASSERT_THROW( + std::make_shared( + "col", + HiveColumnHandle::ColumnType::kRegular, + ARRAY(BIGINT()), + hiveType, + std::vector{}, + std::move(extractions)), + "does not match derived type"); +} + +TEST(ColumnExtractionTest, noExtractionsBackwardCompatible) { + // Empty extractions: existing behavior, dataType must match hiveType. + auto type = MAP(VARCHAR(), BIGINT()); + auto handle = std::make_shared( + "col", HiveColumnHandle::ColumnType::kRegular, type, type); + + ASSERT_TRUE(handle->extractions().empty()); + ASSERT_TRUE(handle->requiredSubfields().empty()); + ASSERT_TRUE(handle->dataType()->equivalent(*type)); +} + +TEST(ColumnExtractionTest, threeExtractionsFromSameColumn) { + // Three extractions: size, keys, and value subfield. + // col: MAP(BIGINT, ROW(a: VARCHAR, b: DOUBLE, c: INT)) + Type::registerSerDe(); + HiveColumnHandle::registerSerDe(); + + auto hiveType = + MAP(BIGINT(), ROW({{"a", VARCHAR()}, {"b", DOUBLE()}, {"c", INTEGER()}})); + auto szType = BIGINT(); + auto keysType = ARRAY(BIGINT()); + auto valsAType = ARRAY(VARCHAR()); + auto rowOutputType = + ROW({{"sz", szType}, {"keys", keysType}, {"vals_a", valsAType}}); + + std::vector extractions = { + {"sz", {ExtractionPathElement::simple(ExtractionStep::kSize)}, szType}, + {"keys", + {ExtractionPathElement::simple(ExtractionStep::kMapKeys)}, + keysType}, + {"vals_a", + {ExtractionPathElement::simple(ExtractionStep::kMapValues), + ExtractionPathElement::simple(ExtractionStep::kArrayElements), + ExtractionPathElement::structField("a")}, + valsAType}}; + + auto handle = std::make_shared( + "col", + HiveColumnHandle::ColumnType::kRegular, + rowOutputType, + hiveType, + std::vector{}, + std::move(extractions)); + + ASSERT_EQ(handle->extractions().size(), 3); + + // Serialization round-trip. + auto obj = handle->serialize(); + auto clone = ISerializable::deserialize(obj); + ASSERT_EQ(clone->extractions().size(), 3); + ASSERT_EQ(clone->extractions()[0].outputName, "sz"); + ASSERT_EQ(clone->extractions()[1].outputName, "keys"); + ASSERT_EQ(clone->extractions()[2].outputName, "vals_a"); + ASSERT_TRUE(clone->extractions()[0].dataType->equivalent(*BIGINT())); + ASSERT_TRUE(clone->extractions()[1].dataType->equivalent(*ARRAY(BIGINT()))); + ASSERT_TRUE(clone->extractions()[2].dataType->equivalent(*ARRAY(VARCHAR()))); +} +// --- Runtime extraction application tests --- + +class ExtractionUtilsTest : public testing::Test, public test::VectorTestBase { + protected: + static void SetUpTestSuite() { + memory::MemoryManager::testingSetInstance({}); + } +}; + +TEST_F(ExtractionUtilsTest, mapKeys) { + // Apply MapKeys to a MapVector. + auto mapVector = makeMapVector( + {{{{"a", 1}, {"b", 2}}}, {{{"c", 3}}}}); + + std::vector chain = { + ExtractionPathElement::simple(ExtractionStep::kMapKeys)}; + auto result = applyExtractionChain(mapVector, chain, pool()); + + ASSERT_TRUE(result->type()->isArray()); + auto* array = result->as(); + ASSERT_EQ(array->size(), 2); + // First map has 2 keys, second has 1. + ASSERT_EQ(array->sizeAt(0), 2); + ASSERT_EQ(array->sizeAt(1), 1); +} + +TEST_F(ExtractionUtilsTest, mapValues) { + // Apply MapValues to a MapVector. + auto mapVector = makeMapVector( + {{{{"a", 10}, {"b", 20}}}, {{{"c", 30}}}}); + + std::vector chain = { + ExtractionPathElement::simple(ExtractionStep::kMapValues)}; + auto result = applyExtractionChain(mapVector, chain, pool()); + + ASSERT_TRUE(result->type()->isArray()); + auto* array = result->as(); + ASSERT_EQ(array->size(), 2); + ASSERT_EQ(array->sizeAt(0), 2); + ASSERT_EQ(array->sizeAt(1), 1); + // Check values. + auto* elements = array->elements()->as>(); + ASSERT_EQ(elements->valueAt(0), 10); + ASSERT_EQ(elements->valueAt(1), 20); + ASSERT_EQ(elements->valueAt(2), 30); +} + +TEST_F(ExtractionUtilsTest, sizeOnMap) { + // Apply Size to a MapVector. + auto mapVector = makeMapVector( + {{{{"a", 1}, {"b", 2}, {"c", 3}}}, {{{"d", 4}}}}); + + std::vector chain = { + ExtractionPathElement::simple(ExtractionStep::kSize)}; + auto result = applyExtractionChain(mapVector, chain, pool()); + + ASSERT_TRUE(result->type()->isBigint()); + auto* sizes = result->as>(); + ASSERT_EQ(sizes->size(), 2); + ASSERT_EQ(sizes->valueAt(0), 3); + ASSERT_EQ(sizes->valueAt(1), 1); +} + +TEST_F(ExtractionUtilsTest, sizeOnArray) { + // Apply Size to an ArrayVector. + auto arrayVector = makeArrayVector({{1, 2, 3}, {4, 5}}); + + std::vector chain = { + ExtractionPathElement::simple(ExtractionStep::kSize)}; + auto result = applyExtractionChain(arrayVector, chain, pool()); + + ASSERT_TRUE(result->type()->isBigint()); + auto* sizes = result->as>(); + ASSERT_EQ(sizes->size(), 2); + ASSERT_EQ(sizes->valueAt(0), 3); + ASSERT_EQ(sizes->valueAt(1), 2); +} + +TEST_F(ExtractionUtilsTest, structField) { + // Apply StructField to a RowVector. + auto rowVector = makeRowVector( + {"x", "y"}, + {makeFlatVector({10, 20}), makeFlatVector({30, 40})}); + + std::vector chain = {ExtractionPathElement::structField("x")}; + auto result = applyExtractionChain(rowVector, chain, pool()); + + ASSERT_TRUE(result->type()->isInteger()); + auto* flat = result->as>(); + ASSERT_EQ(flat->valueAt(0), 10); + ASSERT_EQ(flat->valueAt(1), 20); +} + +TEST_F(ExtractionUtilsTest, mapKeyFilter) { + // Apply MapKeyFilter to filter specific keys. + auto mapVector = makeMapVector( + {{{{"a", 1}, {"b", 2}, {"c", 3}}}, {{{"a", 10}, {"d", 40}}}}); + + std::vector chain = { + ExtractionPathElement::mapKeyFilter(std::vector{"a", "c"})}; + auto result = applyExtractionChain(mapVector, chain, pool()); + + ASSERT_TRUE(result->type()->isMap()); + auto* filteredMap = result->as(); + ASSERT_EQ(filteredMap->size(), 2); + // First map: "a" and "c" kept (2 out of 3). + ASSERT_EQ(filteredMap->sizeAt(0), 2); + // Second map: only "a" kept (1 out of 2). + ASSERT_EQ(filteredMap->sizeAt(1), 1); +} + +TEST_F(ExtractionUtilsTest, chainMapValuesStructField) { + // Chain: [MapValues, ArrayElements, StructField("x")] + // on MAP(VARCHAR, ROW(x: INT, y: INT)) -> ARRAY(INT). + auto keys = makeFlatVector({"k1", "k2", "k3"}); + auto structValues = makeRowVector( + {"x", "y"}, + {makeFlatVector({10, 20, 30}), + makeFlatVector({100, 200, 300})}); + auto mapVector = makeMapVector({0, 2}, keys, structValues); + + std::vector chain = { + ExtractionPathElement::simple(ExtractionStep::kMapValues), + ExtractionPathElement::simple(ExtractionStep::kArrayElements), + ExtractionPathElement::structField("x")}; + auto result = applyExtractionChain(mapVector, chain, pool()); + + // Output should be ARRAY(INT). + ASSERT_TRUE(result->type()->isArray()); + auto* array = result->as(); + ASSERT_EQ(array->size(), 2); + // First map had 2 entries, second had 1. + ASSERT_EQ(array->sizeAt(0), 2); + ASSERT_EQ(array->sizeAt(1), 1); + auto* elements = array->elements()->as>(); + ASSERT_EQ(elements->valueAt(0), 10); + ASSERT_EQ(elements->valueAt(1), 20); + ASSERT_EQ(elements->valueAt(2), 30); +} + +TEST_F(ExtractionUtilsTest, emptyChain) { + // Empty chain should return the input unchanged. + auto mapVector = makeMapVector({{{{"a", 1}}}}); + std::vector chain = {}; + auto result = applyExtractionChain(mapVector, chain, pool()); + ASSERT_EQ(result.get(), mapVector.get()); +} diff --git a/velox/connectors/tests/CMakeLists.txt b/velox/connectors/tests/CMakeLists.txt index b70f227b41e..ff2f31ea83a 100644 --- a/velox/connectors/tests/CMakeLists.txt +++ b/velox/connectors/tests/CMakeLists.txt @@ -17,6 +17,10 @@ add_test(velox_connector_test velox_connector_test) target_link_libraries( velox_connector_test velox_connector + velox_connector_registry + velox_core + velox_memory + GTest::gmock GTest::gtest GTest::gtest_main glog::glog diff --git a/velox/connectors/tests/ConnectorTest.cpp b/velox/connectors/tests/ConnectorTest.cpp index a58cf822777..2adc5ee91d0 100644 --- a/velox/connectors/tests/ConnectorTest.cpp +++ b/velox/connectors/tests/ConnectorTest.cpp @@ -15,15 +15,18 @@ */ #include "velox/connectors/Connector.h" -#include "velox/common/base/tests/GTestUtils.h" -#include "velox/common/config/Config.h" -#include +#include +#include -namespace facebook::velox::connector { +#include +#include -class ConnectorTest : public testing::Test {}; +#include "velox/common/memory/Memory.h" +#include "velox/connectors/ConnectorRegistry.h" +#include "velox/core/QueryCtx.h" +namespace facebook::velox::connector { namespace { class TestConnector : public connector::Connector { @@ -47,54 +50,135 @@ class TestConnector : public connector::Connector { } }; -class TestConnectorFactory : public connector::ConnectorFactory { - public: - static constexpr const char* kConnectorFactoryName = "test-factory"; - - TestConnectorFactory() : ConnectorFactory(kConnectorFactoryName) {} - - std::shared_ptr newConnector( - const std::string& id, - std::shared_ptr /*config*/, - folly::Executor* /*ioExecutor*/ = nullptr, - folly::Executor* /*cpuExecutor*/ = nullptr) override { - return std::make_shared(id); - } -}; - -} // namespace - -TEST_F(ConnectorTest, getAllConnectors) { - registerConnectorFactory(std::make_shared()); - VELOX_ASSERT_THROW( - registerConnectorFactory(std::make_shared()), - "ConnectorFactory with name 'test-factory' is already registered"); - EXPECT_TRUE(hasConnectorFactory(TestConnectorFactory::kConnectorFactoryName)); +TEST(ConnectorTest, registryOperations) { const int32_t numConnectors = 10; for (int32_t i = 0; i < numConnectors; i++) { - registerConnector( - getConnectorFactory(TestConnectorFactory::kConnectorFactoryName) - ->newConnector( - fmt::format("connector-{}", i), - std::make_shared( - std::unordered_map()))); + auto connector = + std::make_shared(fmt::format("connector-{}", i)); + auto connectorId = connector->connectorId(); + ConnectorRegistry::global().insert( + std::move(connectorId), std::move(connector)); } - const auto& connectors = getAllConnectors(); - EXPECT_EQ(connectors.size(), numConnectors); + for (int32_t i = 0; i < numConnectors; i++) { - EXPECT_EQ(connectors.count(fmt::format("connector-{}", i)), 1); + EXPECT_NE( + ConnectorRegistry::tryGet(fmt::format("connector-{}", i)), nullptr); } - for (int32_t i = 0; i < numConnectors; i++) { - unregisterConnector(fmt::format("connector-{}", i)); + EXPECT_EQ(ConnectorRegistry::tryGet("nonexistent"), nullptr); + + auto allTestConnectors = ConnectorRegistry::findAll(); + EXPECT_EQ(allTestConnectors.size(), numConnectors); + + ConnectorRegistry::unregisterAll(); + EXPECT_EQ(ConnectorRegistry::findAll().size(), 0); +} + +class ConnectorRegistryTest : public testing::Test { + protected: + static void SetUpTestSuite() { + memory::MemoryManager::testingSetInstance({}); } - EXPECT_EQ(getAllConnectors().size(), 0); - EXPECT_TRUE( - unregisterConnectorFactory(TestConnectorFactory::kConnectorFactoryName)); - EXPECT_FALSE( - unregisterConnectorFactory(TestConnectorFactory::kConnectorFactoryName)); +}; + +TEST_F(ConnectorRegistryTest, queryScopedOverride) { + auto globalConnector = std::make_shared("global"); + ConnectorRegistry::global().insert("catalog", globalConnector); + + auto queryCtx = core::QueryCtx::create(); + auto queryRegistry = ConnectorRegistry::create(&ConnectorRegistry::global()); + auto queryConnector = std::make_shared("query-override"); + queryRegistry->insert("catalog", queryConnector); + queryCtx->setRegistry(ConnectorRegistry::kRegistryKey, queryRegistry); + + // Query-scoped lookup returns the override. + EXPECT_EQ(ConnectorRegistry::tryGet(*queryCtx, "catalog"), queryConnector); + // Global lookup returns the global connector. + EXPECT_EQ(ConnectorRegistry::tryGet("catalog"), globalConnector); + + ConnectorRegistry::unregisterAll(); +} + +TEST_F(ConnectorRegistryTest, queryScopedFallbackToGlobal) { + auto globalConnector = std::make_shared("global"); + ConnectorRegistry::global().insert("catalog", globalConnector); + + auto queryCtx = core::QueryCtx::create(); + auto queryRegistry = ConnectorRegistry::create(&ConnectorRegistry::global()); + queryCtx->setRegistry(ConnectorRegistry::kRegistryKey, queryRegistry); + + // Empty per-query registry falls back to global. + EXPECT_EQ(ConnectorRegistry::tryGet(*queryCtx, "catalog"), globalConnector); + + ConnectorRegistry::unregisterAll(); +} + +TEST_F(ConnectorRegistryTest, noQueryRegistryFallsBackToGlobal) { + auto globalConnector = std::make_shared("global"); + ConnectorRegistry::global().insert("catalog", globalConnector); + + // QueryCtx with no per-query registry set. + auto queryCtx = core::QueryCtx::create(); + EXPECT_EQ(ConnectorRegistry::tryGet(*queryCtx, "catalog"), globalConnector); + + ConnectorRegistry::unregisterAll(); +} + +TEST_F(ConnectorRegistryTest, queryScopedUnregisterAll) { + auto globalConnector = std::make_shared("global"); + ConnectorRegistry::global().insert("catalog", globalConnector); + + auto queryCtx = core::QueryCtx::create(); + auto queryRegistry = ConnectorRegistry::create(&ConnectorRegistry::global()); + queryRegistry->insert("catalog", std::make_shared("query")); + queryCtx->setRegistry(ConnectorRegistry::kRegistryKey, queryRegistry); + + ConnectorRegistry::unregisterAll(*queryCtx); + + // Query-scoped registry cleared; falls back to global. + EXPECT_EQ(ConnectorRegistry::tryGet(*queryCtx, "catalog"), globalConnector); + // Global is untouched. + EXPECT_EQ(ConnectorRegistry::tryGet("catalog"), globalConnector); + + ConnectorRegistry::unregisterAll(); } -TEST_F(ConnectorTest, connectorSplit) { +// Verify that unregisterAll on a queryCtx without a per-query registry does +// not clear the global registry. +TEST_F(ConnectorRegistryTest, unregisterAllNoQueryRegistry) { + auto globalConnector = std::make_shared("global"); + ConnectorRegistry::global().insert("catalog", globalConnector); + + auto queryCtx = core::QueryCtx::create(); + ConnectorRegistry::unregisterAll(*queryCtx); + + // Global registry is untouched. + EXPECT_EQ(ConnectorRegistry::tryGet("catalog"), globalConnector); + + ConnectorRegistry::unregisterAll(); +} + +TEST_F(ConnectorRegistryTest, queryScopedFindAll) { + ConnectorRegistry::global().insert( + "global-cat", std::make_shared("global-cat")); + + auto queryCtx = core::QueryCtx::create(); + auto queryRegistry = ConnectorRegistry::create(&ConnectorRegistry::global()); + queryRegistry->insert( + "query-cat", std::make_shared("query-cat")); + queryCtx->setRegistry(ConnectorRegistry::kRegistryKey, queryRegistry); + + // findAll with queryCtx sees both query-scoped and global connectors. + auto all = ConnectorRegistry::findAll(*queryCtx); + EXPECT_EQ(all.size(), 2); + + // findAll without queryCtx sees only global. + auto globalOnly = ConnectorRegistry::findAll(); + EXPECT_EQ(globalOnly.size(), 1); + + ConnectorRegistry::unregisterAll(); +} + +TEST(ConnectorTest, connectorSplit) { { const ConnectorSplit split("test", 100, true); ASSERT_EQ(split.connectorId, "test"); @@ -114,4 +198,5 @@ TEST_F(ConnectorTest, connectorSplit) { "[split: connector id test, weight 50, cacheable false]"); } } +} // namespace } // namespace facebook::velox::connector diff --git a/velox/connectors/tpcds/CMakeLists.txt b/velox/connectors/tpcds/CMakeLists.txt index 4c92a290f13..d05ff0b3ce0 100644 --- a/velox/connectors/tpcds/CMakeLists.txt +++ b/velox/connectors/tpcds/CMakeLists.txt @@ -12,7 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -velox_add_library(velox_tpcds_connector TpcdsConnector.cpp) +velox_add_library( + velox_tpcds_connector + TpcdsConnector.cpp + HEADERS + TpcdsConnector.h + TpcdsConnectorSplit.h +) velox_link_libraries(velox_tpcds_connector PUBLIC velox_connector velox_tpcds_gen PRIVATE fmt::fmt) if(${VELOX_BUILD_TESTING}) diff --git a/velox/connectors/tpcds/TpcdsConnector.cpp b/velox/connectors/tpcds/TpcdsConnector.cpp index 5902f49da4c..b981bda903a 100644 --- a/velox/connectors/tpcds/TpcdsConnector.cpp +++ b/velox/connectors/tpcds/TpcdsConnector.cpp @@ -65,7 +65,7 @@ TpcdsDataSource::TpcdsDataSource( handle, "ColumnHandle must be an instance of TpcdsColumnHandle " "for '{}' on table '{}'", - handle->name(), + it->second->name(), toTableName(table_)); auto idx = tpcdsTableSchema->getChildIdxIfExists(handle->name()); diff --git a/velox/connectors/tpcds/TpcdsConnector.h b/velox/connectors/tpcds/TpcdsConnector.h index 88ad845ff39..329114d9005 100644 --- a/velox/connectors/tpcds/TpcdsConnector.h +++ b/velox/connectors/tpcds/TpcdsConnector.h @@ -102,7 +102,7 @@ class TpcdsDataSource : public velox::connector::DataSource { return completedBytes_; } - std::unordered_map runtimeStats() override { + std::unordered_map getRuntimeStats() override { return {}; } diff --git a/velox/connectors/tpcds/TpcdsConnectorSplit.h b/velox/connectors/tpcds/TpcdsConnectorSplit.h index 1f51eb22af6..80fb8863e76 100644 --- a/velox/connectors/tpcds/TpcdsConnectorSplit.h +++ b/velox/connectors/tpcds/TpcdsConnectorSplit.h @@ -53,8 +53,8 @@ template <> struct fmt::formatter : formatter { auto format( - facebook::velox::connector::tpcds::TpcdsConnectorSplit s, - format_context& ctx) { + facebook::velox::connector::tpcds::TpcdsConnectorSplit const& s, + format_context& ctx) const { return formatter::format(s.toString(), ctx); } }; @@ -64,8 +64,9 @@ struct fmt::formatter< std::shared_ptr> : formatter { auto format( - std::shared_ptr s, - format_context& ctx) { + std::shared_ptr< + facebook::velox::connector::tpcds::TpcdsConnectorSplit> const& s, + format_context& ctx) const { return formatter::format(s->toString(), ctx); } }; diff --git a/velox/connectors/tpcds/tests/TpcdsConnectorTest.cpp b/velox/connectors/tpcds/tests/TpcdsConnectorTest.cpp index 987a0a9e058..c03cddbd8bf 100644 --- a/velox/connectors/tpcds/tests/TpcdsConnectorTest.cpp +++ b/velox/connectors/tpcds/tests/TpcdsConnectorTest.cpp @@ -17,6 +17,7 @@ #include "velox/connectors/tpcds/TpcdsConnector.h" #include #include "gtest/gtest.h" +#include "velox/connectors/ConnectorRegistry.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/OperatorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" @@ -35,18 +36,20 @@ class TpcdsConnectorTest : public exec::test::OperatorTestBase { kTpcdsConnectorId, std::make_shared( std::unordered_map())); - connector::registerConnector(connector); + connector::ConnectorRegistry::global().insert( + connector->connectorId(), connector); } void TearDown() override { - connector::unregisterConnector(kTpcdsConnectorId); + connector::ConnectorRegistry::global().erase(kTpcdsConnectorId); OperatorTestBase::TearDown(); } exec::Split makeTpcdsSplit(size_t totalParts = 1, size_t partNumber = 0) const { - return exec::Split(std::make_shared( - kTpcdsConnectorId, /*cacheable=*/true, totalParts, partNumber)); + return exec::Split( + std::make_shared( + kTpcdsConnectorId, /*cacheable=*/true, totalParts, partNumber)); } RowVectorPtr getResults( @@ -118,8 +121,9 @@ TEST_F(TpcdsConnectorTest, singleColumnWithAlias) { auto plan = exec::test::PlanBuilder() .startTableScan() .outputType(outputType) - .tableHandle(std::make_shared( - kTpcdsConnectorId, velox::tpcds::Table::TBL_ITEM)) + .tableHandle( + std::make_shared( + kTpcdsConnectorId, velox::tpcds::Table::TBL_ITEM)) .assignments({ {aliasedName, std::make_shared("i_product_name")}, diff --git a/velox/connectors/tpch/CMakeLists.txt b/velox/connectors/tpch/CMakeLists.txt index 43ebe514b28..20c3fc632bb 100644 --- a/velox/connectors/tpch/CMakeLists.txt +++ b/velox/connectors/tpch/CMakeLists.txt @@ -12,7 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -velox_add_library(velox_tpch_connector OBJECT TpchConnector.cpp) +velox_add_library( + velox_tpch_connector + OBJECT + TpchConnector.cpp + HEADERS + TpchConnector.h + TpchConnectorSplit.h +) velox_link_libraries( velox_tpch_connector diff --git a/velox/connectors/tpch/TpchConnector.h b/velox/connectors/tpch/TpchConnector.h index 5d006490bad..3325b9dee5c 100644 --- a/velox/connectors/tpch/TpchConnector.h +++ b/velox/connectors/tpch/TpchConnector.h @@ -34,10 +34,28 @@ class TpchColumnHandle : public ColumnHandle { public: explicit TpchColumnHandle(const std::string& name) : name_(name) {} - const std::string& name() const { + const std::string& name() const override { return name_; } + folly::dynamic serialize() const override { + folly::dynamic obj = folly::dynamic::object; + obj["name"] = TpchColumnHandle::getClassName(); + obj["columnName"] = name_; + return obj; + } + + static std::shared_ptr create(const folly::dynamic& obj) { + auto name = obj["columnName"].asString(); + return std::make_shared(name); + } + + static void registerSerDe() { + registerDeserializer(); + } + + VELOX_DEFINE_CLASS_NAME(TpchColumnHandle) + private: const std::string name_; }; @@ -83,6 +101,40 @@ class TpchTableHandle : public ConnectorTableHandle { return scaleFactor_; } + folly::dynamic serialize() const override { + folly::dynamic obj = folly::dynamic::object; + obj["name"] = TpchTableHandle::getClassName(); + obj["connectorId"] = connectorId(); + obj["table"] = static_cast(table_); + obj["scaleFactor"] = scaleFactor_; + if (filterExpression_) { + obj["filterExpression"] = filterExpression_->serialize(); + } + return obj; + } + + static ConnectorTableHandlePtr create( + const folly::dynamic& obj, + void* context) { + auto connectorId = obj["connectorId"].asString(); + auto table = static_cast(obj["table"].asInt()); + auto scaleFactor = obj["scaleFactor"].asDouble(); + velox::core::TypedExprPtr filterExpression = nullptr; + if (obj.count("filterExpression") && !obj["filterExpression"].isNull()) { + filterExpression = + velox::ISerializable::deserialize( + obj["filterExpression"], context); + } + return std::make_shared( + connectorId, table, scaleFactor, std::move(filterExpression)); + } + + static void registerSerDe() { + registerDeserializerWithContext(); + } + + VELOX_DEFINE_CLASS_NAME(TpchTableHandle) + private: const velox::tpch::Table table_; double scaleFactor_; @@ -117,7 +169,7 @@ class TpchDataSource : public DataSource { return completedBytes_; } - std::unordered_map runtimeStats() override { + std::unordered_map getRuntimeStats() override { // TODO: Which stats do we want to expose here? return {}; } @@ -179,6 +231,12 @@ class TpchConnector final : public Connector { CommitStrategy /*commitStrategy*/) override final { VELOX_NYI("TpchConnector does not support data sink."); } + + static void registerSerDe() { + TpchTableHandle::registerSerDe(); + TpchColumnHandle::registerSerDe(); + TpchConnectorSplit::registerSerDe(); + } }; class TpchConnectorFactory : public ConnectorFactory { diff --git a/velox/connectors/tpch/TpchConnectorSplit.h b/velox/connectors/tpch/TpchConnectorSplit.h index b4b5420d1f9..7aa962d6a4e 100644 --- a/velox/connectors/tpch/TpchConnectorSplit.h +++ b/velox/connectors/tpch/TpchConnectorSplit.h @@ -32,19 +32,45 @@ struct TpchConnectorSplit : public connector::ConnectorSplit { bool cacheable, size_t totalParts, size_t partNumber) - : ConnectorSplit(connectorId, /*_splitWeight=*/0, cacheable), + : ConnectorSplit(connectorId, /*_splitWeight=*/1, cacheable), totalParts(totalParts), partNumber(partNumber) { VELOX_CHECK_GE(totalParts, 1, "totalParts must be >= 1"); VELOX_CHECK_GT(totalParts, partNumber, "totalParts must be > partNumber"); } + folly::dynamic serialize() const override { + folly::dynamic obj = folly::dynamic::object; + obj["name"] = TpchConnectorSplit::getClassName(); + obj["connectorId"] = connectorId; + obj["splitWeight"] = splitWeight; + obj["cacheable"] = cacheable; + obj["totalParts"] = totalParts; + obj["partNumber"] = partNumber; + return obj; + } + + static std::shared_ptr create(const folly::dynamic& obj) { + auto connectorId = obj["connectorId"].asString(); + auto cacheable = obj["cacheable"].asBool(); + auto totalParts = static_cast(obj["totalParts"].asInt()); + auto partNumber = static_cast(obj["partNumber"].asInt()); + return std::make_shared( + connectorId, cacheable, totalParts, partNumber); + } + + static void registerSerDe() { + registerDeserializer(); + } + // In how many parts the generated TPC-H table will be segmented, roughly // `rowCount / totalParts` size_t totalParts{1}; // Which of these parts will be read by this split. size_t partNumber{0}; + + VELOX_DEFINE_CLASS_NAME(TpchConnectorSplit) }; } // namespace facebook::velox::connector::tpch diff --git a/velox/connectors/tpch/tests/CMakeLists.txt b/velox/connectors/tpch/tests/CMakeLists.txt index eb4f2718516..63477670d4f 100644 --- a/velox/connectors/tpch/tests/CMakeLists.txt +++ b/velox/connectors/tpch/tests/CMakeLists.txt @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -add_executable(velox_tpch_connector_test TpchConnectorTest.cpp) +add_executable(velox_tpch_connector_test TpchConnectorTest.cpp TpchConnectorSerDeTest.cpp) add_test(velox_tpch_connector_test velox_tpch_connector_test) diff --git a/velox/connectors/tpch/tests/SpeedTest.cpp b/velox/connectors/tpch/tests/SpeedTest.cpp index 180713b0a4e..1874b3cb460 100644 --- a/velox/connectors/tpch/tests/SpeedTest.cpp +++ b/velox/connectors/tpch/tests/SpeedTest.cpp @@ -19,6 +19,7 @@ #include #include "velox/common/memory/Memory.h" +#include "velox/connectors/ConnectorRegistry.h" #include "velox/connectors/tpch/TpchConnector.h" #include "velox/connectors/tpch/TpchConnectorSplit.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" @@ -61,11 +62,12 @@ class TpchSpeedTest { kTpchConnectorId_, std::make_shared( std::unordered_map())); - connector::registerConnector(tpchConnector); + connector::ConnectorRegistry::global().insert( + tpchConnector->connectorId(), tpchConnector); } ~TpchSpeedTest() { - connector::unregisterConnector(kTpchConnectorId_); + connector::ConnectorRegistry::global().erase(kTpchConnectorId_); } void run(tpch::Table table, size_t scaleFactor, size_t numSplits) { @@ -117,8 +119,9 @@ class TpchSpeedTest { for (size_t i = 0; i < numSplits; ++i) { task.addSplit( scanId, - exec::Split(std::make_shared( - kTpchConnectorId_, /*cacheable=*/true, numSplits, i))); + exec::Split( + std::make_shared( + kTpchConnectorId_, /*cacheable=*/true, numSplits, i))); } task.noMoreSplits(scanId); diff --git a/velox/connectors/tpch/tests/TpchConnectorSerDeTest.cpp b/velox/connectors/tpch/tests/TpchConnectorSerDeTest.cpp new file mode 100644 index 00000000000..5f700faf19c --- /dev/null +++ b/velox/connectors/tpch/tests/TpchConnectorSerDeTest.cpp @@ -0,0 +1,152 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "velox/connectors/tpch/TpchConnector.h" +#include "velox/connectors/tpch/TpchConnectorSplit.h" +#include "velox/core/ITypedExpr.h" +#include "velox/type/Type.h" + +namespace facebook::velox::connector::tpch::test { +namespace { + +class TpchConnectorSerDeTest : public testing::Test { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + } + + TpchConnectorSerDeTest() { + Type::registerSerDe(); + core::ITypedExpr::registerSerDe(); + TpchConnector::registerSerDe(); + } + + template + static void testSerde(const T& handle) { + auto str = handle.toString(); + auto obj = handle.serialize(); + auto clone = ISerializable::deserialize(obj); + ASSERT_EQ(clone->toString(), str); + } + + static void testSerde(const TpchColumnHandle& handle) { + auto obj = handle.serialize(); + auto clone = ISerializable::deserialize(obj); + ASSERT_EQ(handle.name(), clone->name()); + } + + static void testSerde(const TpchTableHandle& handle) { + auto str = handle.toString(); + auto obj = handle.serialize(); + auto pool = memory::memoryManager()->addLeafPool(); + auto clone = ISerializable::deserialize(obj, pool.get()); + ASSERT_EQ(clone->toString(), str); + ASSERT_EQ(handle.connectorId(), clone->connectorId()); + ASSERT_EQ(handle.getTable(), clone->getTable()); + ASSERT_EQ(handle.getScaleFactor(), clone->getScaleFactor()); + if (handle.filterExpression()) { + ASSERT_NE(clone->filterExpression(), nullptr); + ASSERT_EQ( + handle.filterExpression()->toString(), + clone->filterExpression()->toString()); + } else { + ASSERT_EQ(clone->filterExpression(), nullptr); + } + } + + static void testSerde(const TpchConnectorSplit& split) { + auto str = split.toString(); + auto obj = split.serialize(); + auto clone = ISerializable::deserialize(obj); + ASSERT_EQ(clone->toString(), str); + ASSERT_EQ(split.connectorId, clone->connectorId); + ASSERT_EQ(split.splitWeight, clone->splitWeight); + ASSERT_EQ(split.cacheable, clone->cacheable); + ASSERT_EQ(split.totalParts, clone->totalParts); + ASSERT_EQ(split.partNumber, clone->partNumber); + } +}; + +TEST_F(TpchConnectorSerDeTest, tpchColumnHandle) { + auto handle1 = TpchColumnHandle("n_nationkey"); + testSerde(handle1); + + auto handle2 = TpchColumnHandle("l_orderkey"); + testSerde(handle2); + + auto handle3 = TpchColumnHandle("c_name"); + testSerde(handle3); +} + +TEST_F(TpchConnectorSerDeTest, tpchTableHandle) { + const std::string connectorId = "test-tpch"; + + auto handle1 = TpchTableHandle(connectorId, velox::tpch::Table::TBL_NATION); + testSerde(handle1); + + auto handle2 = + TpchTableHandle(connectorId, velox::tpch::Table::TBL_LINEITEM, 10.0); + testSerde(handle2); + + auto handle3 = + TpchTableHandle(connectorId, velox::tpch::Table::TBL_ORDERS, 0.01); + testSerde(handle3); + + // Test with filterExpression + auto filterExpr = + std::make_shared(BIGINT(), "n_nationkey"); + auto handle4 = TpchTableHandle( + connectorId, velox::tpch::Table::TBL_NATION, 1.0, filterExpr); + testSerde(handle4); + + std::vector tables = { + velox::tpch::Table::TBL_NATION, + velox::tpch::Table::TBL_REGION, + velox::tpch::Table::TBL_PART, + velox::tpch::Table::TBL_SUPPLIER, + velox::tpch::Table::TBL_PARTSUPP, + velox::tpch::Table::TBL_CUSTOMER, + velox::tpch::Table::TBL_ORDERS, + velox::tpch::Table::TBL_LINEITEM, + }; + + for (auto table : tables) { + testSerde(TpchTableHandle(connectorId, table, 1.0)); + } + + std::vector scaleFactors = {0.01, 0.1, 1.0, 5.0, 10.0, 100.0, 1000.0}; + for (auto sf : scaleFactors) { + testSerde( + TpchTableHandle(connectorId, velox::tpch::Table::TBL_CUSTOMER, sf)); + } +} + +TEST_F(TpchConnectorSerDeTest, tpchConnectorSplit) { + const std::string connectorId = "test-tpch"; + + auto split1 = TpchConnectorSplit(connectorId, false, 10, 5); + testSerde(split1); + + auto split2 = TpchConnectorSplit(connectorId, true, 100, 99); + testSerde(split2); + + auto split3 = TpchConnectorSplit(connectorId, 1, 0); + testSerde(split3); +} + +} // namespace +} // namespace facebook::velox::connector::tpch::test diff --git a/velox/connectors/tpch/tests/TpchConnectorTest.cpp b/velox/connectors/tpch/tests/TpchConnectorTest.cpp index 1d62b7f9577..2e991549f01 100644 --- a/velox/connectors/tpch/tests/TpchConnectorTest.cpp +++ b/velox/connectors/tpch/tests/TpchConnectorTest.cpp @@ -18,6 +18,7 @@ #include #include "gtest/gtest.h" #include "velox/common/base/tests/GTestUtils.h" +#include "velox/connectors/ConnectorRegistry.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/OperatorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" @@ -44,18 +45,20 @@ class TpchConnectorTest : public exec::test::OperatorTestBase { kTpchConnectorId, std::make_shared( std::unordered_map())); - connector::registerConnector(tpchConnector); + connector::ConnectorRegistry::global().insert( + tpchConnector->connectorId(), tpchConnector); } void TearDown() override { - connector::unregisterConnector(kTpchConnectorId); + connector::ConnectorRegistry::global().erase(kTpchConnectorId); OperatorTestBase::TearDown(); } exec::Split makeTpchSplit(size_t totalParts = 1, size_t partNumber = 0) const { - return exec::Split(std::make_shared( - kTpchConnectorId, /*cacheable=*/true, totalParts, partNumber)); + return exec::Split( + std::make_shared( + kTpchConnectorId, /*cacheable=*/true, totalParts, partNumber)); } RowVectorPtr getResults( @@ -132,8 +135,9 @@ TEST_F(TpchConnectorTest, singleColumnWithAlias) { PlanBuilder() .startTableScan() .outputType(outputType) - .tableHandle(std::make_shared( - kTpchConnectorId, Table::TBL_NATION)) + .tableHandle( + std::make_shared( + kTpchConnectorId, Table::TBL_NATION)) .assignments({ {aliasedName, std::make_shared("n_name")}, {"other_name", std::make_shared("n_name")}, @@ -158,8 +162,9 @@ void TpchConnectorTest::runScaleFactorTest(double scaleFactor) { auto plan = PlanBuilder() .startTableScan() .outputType(ROW({}, {})) - .tableHandle(std::make_shared( - kTpchConnectorId, Table::TBL_SUPPLIER, scaleFactor)) + .tableHandle( + std::make_shared( + kTpchConnectorId, Table::TBL_SUPPLIER, scaleFactor)) .endTableScan() .singleAggregation({}, {"count(1)"}) .planNode(); @@ -187,8 +192,9 @@ TEST_F(TpchConnectorTest, lineitemTinyRowCount) { auto plan = PlanBuilder() .startTableScan() .outputType(ROW({}, {})) - .tableHandle(std::make_shared( - kTpchConnectorId, Table::TBL_LINEITEM, 0.01)) + .tableHandle( + std::make_shared( + kTpchConnectorId, Table::TBL_LINEITEM, 0.01)) .endTableScan() .singleAggregation({}, {"count(1)"}) .planNode(); diff --git a/velox/core/CMakeLists.txt b/velox/core/CMakeLists.txt index 26fdad62a49..f68a88900da 100644 --- a/velox/core/CMakeLists.txt +++ b/velox/core/CMakeLists.txt @@ -25,6 +25,19 @@ velox_add_library( QueryConfig.cpp QueryCtx.cpp SimpleFunctionMetadata.cpp + TableWriteTraits.cpp + HEADERS + CoreTypeSystem.h + Expressions.h + ITypedExpr.h + PlanConsistencyChecker.h + PlanFragment.h + PlanNode.h + QueryConfig.h + QueryCtx.h + ScanBatchEvent.h + SimpleFunctionMetadata.h + TableWriteTraits.h ) velox_link_libraries( @@ -33,6 +46,7 @@ velox_link_libraries( velox_arrow_bridge velox_caching velox_common_config + velox_config_property velox_connector velox_exception velox_expression_functions @@ -44,3 +58,10 @@ velox_link_libraries( fmt::fmt PRIVATE velox_encode ) + +velox_add_library(velox_query_config_provider QueryConfigProvider.cpp HEADERS QueryConfigProvider.h) +velox_link_libraries(velox_query_config_provider PUBLIC velox_core velox_config_property) + +velox_add_library(velox_metaprogramming INTERFACE HEADERS Metaprogramming.h) + +velox_add_library(velox_expression_evaluator INTERFACE HEADERS ExpressionEvaluator.h) diff --git a/velox/core/Expressions.cpp b/velox/core/Expressions.cpp index 7ab4c78d048..b1b512b4331 100644 --- a/velox/core/Expressions.cpp +++ b/velox/core/Expressions.cpp @@ -14,9 +14,13 @@ * limitations under the License. */ #include "velox/core/Expressions.h" + +#include + #include "velox/common/Casts.h" #include "velox/common/encode/Base64.h" #include "velox/vector/ComplexVector.h" +#include "velox/vector/ConstantVector.h" #include "velox/vector/SimpleVector.h" #include "velox/vector/VectorSaver.h" @@ -68,6 +72,7 @@ void ITypedExpr::registerSerDe() { registry.Register("FieldAccessTypedExpr", core::FieldAccessTypedExpr::create); registry.Register("InputTypedExpr", core::InputTypedExpr::create); registry.Register("LambdaTypedExpr", core::LambdaTypedExpr::create); + registry.Register("NullIfTypedExpr", core::NullIfTypedExpr::create); } void InputTypedExpr::accept( @@ -87,6 +92,19 @@ TypedExprPtr InputTypedExpr::create(const folly::dynamic& obj, void* context) { return std::make_shared(std::move(type)); } +std::optional ConstantTypedExpr::toBool() const { + VELOX_CHECK( + this->type()->isBoolean(), + "Expected boolean expression, but got {}", + this->type()->toString()); + + if (!isNull()) { + return valueVector_ ? valueVector_->as>()->valueAt(0) + : value_.value(); + } + return std::nullopt; +} + void ConstantTypedExpr::accept( const ITypedExprVisitor& visitor, ITypedExprVisitorContext& context) const { @@ -128,6 +146,12 @@ TypedExprPtr ConstantTypedExpr::create( return std::make_shared(restoreVector(dataStream, pool)); } +// static +TypedExprPtr ConstantTypedExpr::makeNull(const TypePtr& type) { + return std::make_shared( + type, Variant::null(type->kind())); +} + std::string ConstantTypedExpr::toString() const { if (hasValueVector()) { return valueVector_->toString(0); @@ -404,7 +428,8 @@ uint64_t hashImpl(const TypePtr& type, const Variant& value) { } // namespace size_t ConstantTypedExpr::localHash() const { - static const size_t kBaseHash = std::hash()("ConstantTypedExpr"); + static const size_t kBaseHash = + folly::hasher()("ConstantTypedExpr"); uint64_t h; @@ -466,7 +491,7 @@ TypedExprPtr FieldAccessTypedExpr::rewriteInputNames( VELOX_CHECK_EQ(1, newInputs.size()); // Only rewrite name if input in InputTypedExpr. Rewrite in other // cases(like dereference) is unsound. - if (!is_instance_of(newInputs[0])) { + if (!newInputs[0]->isInputKind()) { return std::make_shared(type(), newInputs[0], name_); } auto it = mapping.find(name_); @@ -663,4 +688,63 @@ TypedExprPtr CastTypedExpr::create(const folly::dynamic& obj, void* context) { std::move(type), std::move(inputs), obj["isTryCast"].asBool()); } +NullIfTypedExpr::NullIfTypedExpr( + TypedExprPtr value, + TypedExprPtr comparand, + TypePtr commonType) + : ITypedExpr{ExprKind::kNullIf, value->type(), {std::move(value), std::move(comparand)}}, + commonType_(std::move(commonType)) {} + +TypedExprPtr NullIfTypedExpr::rewriteInputNames( + const std::unordered_map& mapping) const { + auto newInputs = rewriteInputsRecursive(mapping); + return std::make_shared( + std::move(newInputs[0]), std::move(newInputs[1]), commonType_); +} + +size_t NullIfTypedExpr::localHash() const { + static const size_t kBaseHash = + folly::hasher()("NullIfTypedExpr"); + return bits::hashMix(kBaseHash, commonType_->hashKind()); +} + +bool NullIfTypedExpr::operator==(const ITypedExpr& other) const { + const auto* otherNullIf = dynamic_cast(&other); + if (!otherNullIf) { + return false; + } + return *type() == *otherNullIf->type() && + *commonType_ == *otherNullIf->commonType_ && + *inputs()[0] == *otherNullIf->inputs()[0] && + *inputs()[1] == *otherNullIf->inputs()[1]; +} + +std::string NullIfTypedExpr::toString() const { + return fmt::format( + "nullif({}, {})", inputs()[0]->toString(), inputs()[1]->toString()); +} + +void NullIfTypedExpr::accept( + const ITypedExprVisitor& visitor, + ITypedExprVisitorContext& context) const { + visitor.visit(*this, context); +} + +folly::dynamic NullIfTypedExpr::serialize() const { + auto obj = ITypedExpr::serializeBase("NullIfTypedExpr"); + obj["commonType"] = commonType_->serialize(); + return obj; +} + +// static +TypedExprPtr NullIfTypedExpr::create(const folly::dynamic& obj, void* context) { + auto type = core::deserializeType(obj, context); + auto inputs = deserializeInputs(obj, context); + auto commonType = ISerializable::deserialize(obj["commonType"]); + + VELOX_CHECK_EQ(inputs.size(), 2); + return std::make_shared( + std::move(inputs[0]), std::move(inputs[1]), std::move(commonType)); +} + } // namespace facebook::velox::core diff --git a/velox/core/Expressions.h b/velox/core/Expressions.h index 699351d98bd..0eb5db006bd 100644 --- a/velox/core/Expressions.h +++ b/velox/core/Expressions.h @@ -15,6 +15,8 @@ */ #pragma once +#include + #include "velox/common/Casts.h" #include "velox/common/base/Exceptions.h" #include "velox/core/ITypedExpr.h" @@ -28,7 +30,7 @@ class InputTypedExpr : public ITypedExpr { : ITypedExpr{ExprKind::kInput, std::move(type)} {} bool operator==(const ITypedExpr& other) const final { - return is_instance_of(&other); + return other.isInputKind(); } std::string toString() const override { @@ -36,7 +38,8 @@ class InputTypedExpr : public ITypedExpr { } size_t localHash() const override { - static const size_t kBaseHash = std::hash()("InputTypedExpr"); + static const size_t kBaseHash = + folly::hasher()("InputTypedExpr"); return kBaseHash; } @@ -61,7 +64,13 @@ class ConstantTypedExpr : public ITypedExpr { // Variant::null() value is supported. ConstantTypedExpr(TypePtr type, Variant value) : ITypedExpr{ExprKind::kConstant, std::move(type)}, - value_{std::move(value)} {} + value_{std::move(value)} { + VELOX_CHECK( + value_.isTypeCompatible(ITypedExpr::type()), + "Expression type {} does not match variant type {}", + ITypedExpr::type()->toString(), + value_.inferType()->toString()); + } // Creates constant expression of scalar or complex type. The value comes from // index zero. @@ -108,6 +117,10 @@ class ConstantTypedExpr : public ITypedExpr { return BaseVector::createConstant(type(), value_, 1, pool); } + /// Returns value of boolean expression, std::nullopt for null booleans. + /// Throws an error if expression is not of boolean type. + std::optional toBool() const; + const std::vector& inputs() const { static const std::vector kEmpty{}; return kEmpty; @@ -141,6 +154,9 @@ class ConstantTypedExpr : public ITypedExpr { static TypedExprPtr create(const folly::dynamic& obj, void* context); + /// Returns a NULL constant expression of given type. + static TypedExprPtr makeNull(const TypePtr& type); + private: const Variant value_; const VectorPtr valueVector_; @@ -213,8 +229,9 @@ class CallTypedExpr : public ITypedExpr { std::string toString() const override; size_t localHash() const override { - static const size_t kBaseHash = std::hash()("CallTypedExpr"); - return bits::hashMix(kBaseHash, std::hash()(name_)); + static const size_t kBaseHash = + folly::hasher()("CallTypedExpr"); + return bits::hashMix(kBaseHash, folly::hasher()(name_)); } void accept( @@ -237,10 +254,10 @@ class CallTypedExpr : public ITypedExpr { return false; } return std::equal( - this->inputs().begin(), - this->inputs().end(), - other.inputs().begin(), - other.inputs().end(), + this->inputs().cbegin(), + this->inputs().cend(), + other.inputs().cbegin(), + other.inputs().cend(), [](const auto& p1, const auto& p2) { return *p1 == *p2; }); } @@ -268,7 +285,7 @@ class FieldAccessTypedExpr : public ITypedExpr { FieldAccessTypedExpr(TypePtr type, TypedExprPtr input, std::string name) : ITypedExpr{ExprKind::kFieldAccess, std::move(type), {std::move(input)}}, name_(std::move(name)), - isInputColumn_(is_instance_of(inputs()[0].get())) {} + isInputColumn_(inputs()[0]->isInputKind()) {} const std::string& name() const { return name_; @@ -282,8 +299,8 @@ class FieldAccessTypedExpr : public ITypedExpr { size_t localHash() const override { static const size_t kBaseHash = - std::hash()("FieldAccessTypedExpr"); - return bits::hashMix(kBaseHash, std::hash()(name_)); + folly::hasher()("FieldAccessTypedExpr"); + return bits::hashMix(kBaseHash, folly::hasher()(name_)); } void accept( @@ -306,10 +323,10 @@ class FieldAccessTypedExpr : public ITypedExpr { return false; } return std::equal( - this->inputs().begin(), - this->inputs().end(), - other.inputs().begin(), - other.inputs().end(), + this->inputs().cbegin(), + this->inputs().cend(), + other.inputs().cbegin(), + other.inputs().cend(), [](const auto& p1, const auto& p2) { return *p1 == *p2; }); } @@ -338,7 +355,7 @@ class DereferenceTypedExpr : public ITypedExpr { index_(index) { // Make sure this isn't being used to access a top level column. VELOX_USER_CHECK( - !is_instance_of(inputs()[0]), + !inputs()[0]->isInputKind(), "DereferenceTypedExpr select a subfeild cannot be used to access a top level column"); } @@ -360,12 +377,16 @@ class DereferenceTypedExpr : public ITypedExpr { } std::string toString() const override { - return fmt::format("{}[{}]", inputs()[0]->toString(), name()); + const auto& fieldName = name(); + if (fieldName.empty()) { + return fmt::format("{}[{}]", inputs()[0]->toString(), index_); + } + return fmt::format("{}[{}]", inputs()[0]->toString(), fieldName); } size_t localHash() const override { static const size_t kBaseHash = - std::hash()("DereferenceTypedExpr"); + folly::hasher()("DereferenceTypedExpr"); return bits::hashMix(kBaseHash, index_); } @@ -386,10 +407,10 @@ class DereferenceTypedExpr : public ITypedExpr { return false; } return std::equal( - this->inputs().begin(), - this->inputs().end(), - other.inputs().begin(), - other.inputs().end(), + this->inputs().cbegin(), + this->inputs().cend(), + other.inputs().cbegin(), + other.inputs().cend(), [](const auto& p1, const auto& p2) { return *p1 == *p2; }); } @@ -420,7 +441,8 @@ class ConcatTypedExpr : public ITypedExpr { std::string toString() const override; size_t localHash() const override { - static const size_t kBaseHash = std::hash()("ConcatTypedExpr"); + static const size_t kBaseHash = + folly::hasher()("ConcatTypedExpr"); return kBaseHash; } @@ -441,10 +463,10 @@ class ConcatTypedExpr : public ITypedExpr { return false; } return std::equal( - this->inputs().begin(), - this->inputs().end(), - other.inputs().begin(), - other.inputs().end(), + this->inputs().cbegin(), + this->inputs().cend(), + other.inputs().cbegin(), + other.inputs().cend(), [](const auto& p1, const auto& p2) { return *p1 == *p2; }); } @@ -484,7 +506,8 @@ class LambdaTypedExpr : public ITypedExpr { } size_t localHash() const override { - static const size_t kBaseHash = std::hash()("LambdaTypedExpr"); + static const size_t kBaseHash = + folly::hasher()("LambdaTypedExpr"); return bits::hashMix(kBaseHash, body_->hash()); } @@ -522,7 +545,7 @@ using LambdaTypedExprPtr = std::shared_ptr; class CastTypedExpr : public ITypedExpr { public: /// @param type Type to convert to. This is the return type of the CAST - /// expresion. + /// expression. /// @param input Single input. The type of input is referred to as from-type /// and expected to be different from to-type. /// @param isTryCast Whether this expression is used for `try_cast`. @@ -548,7 +571,8 @@ class CastTypedExpr : public ITypedExpr { std::string toString() const override; size_t localHash() const override { - static const size_t kBaseHash = std::hash()("CastTypedExpr"); + static const size_t kBaseHash = + folly::hasher()("CastTypedExpr"); return bits::hashMix(kBaseHash, std::hash()(isTryCast_)); } @@ -586,12 +610,58 @@ class CastTypedExpr : public ITypedExpr { using CastTypedExprPtr = std::shared_ptr; +/// NULLIF(a, b) expression. Returns NULL if a equals b, otherwise returns a. +/// +/// The comparison uses the common supertype of a and b, but the return type is +/// a's original type. The common type is stored as metadata and used internally +/// to cast both inputs for comparison only. +class NullIfTypedExpr : public ITypedExpr { + public: + /// @param value The first argument. Its type determines the return type. + /// @param comparand The second argument to compare against. + /// @param commonType The common supertype used to cast both inputs for + /// comparison. + NullIfTypedExpr( + TypedExprPtr value, + TypedExprPtr comparand, + TypePtr commonType); + + /// Returns the common supertype used for comparison. + const TypePtr& commonType() const { + return commonType_; + } + + TypedExprPtr rewriteInputNames( + const std::unordered_map& mapping) + const override; + + std::string toString() const override; + + size_t localHash() const override; + + void accept( + const ITypedExprVisitor& visitor, + ITypedExprVisitorContext& context) const override; + + bool operator==(const ITypedExpr& other) const override; + + folly::dynamic serialize() const override; + + static TypedExprPtr create(const folly::dynamic& obj, void* context); + + private: + // The common supertype used to cast both inputs for comparison. + const TypePtr commonType_; +}; + +using NullIfTypedExprPtr = std::shared_ptr; + /// A collection of convenience methods for working with expressions. class TypedExprs { public: /// Returns true if 'expr' is a field access expression. static bool isFieldAccess(const TypedExprPtr& expr) { - return is_instance_of(expr); + return expr->isFieldAccessKind(); } /// Returns 'expr' as FieldAccessTypedExprPtr or null if not field access @@ -602,7 +672,7 @@ class TypedExprs { /// Returns true if 'expr' is a constant expression. static bool isConstant(const TypedExprPtr& expr) { - return is_instance_of(expr); + return expr->isConstantKind(); } /// Returns 'expr' as ConstantTypedExprPtr or null if not a constant @@ -613,7 +683,7 @@ class TypedExprs { /// Returns true if 'expr' is a lambda expression. static bool isLambda(const TypedExprPtr& expr) { - return is_instance_of(expr); + return expr->isLambdaKind(); } /// Returns 'expr' as LambdaTypedExprPtr or null if not a lambda expression. @@ -658,6 +728,9 @@ class ITypedExprVisitor { virtual void visit(const LambdaTypedExpr& expr, ITypedExprVisitorContext& ctx) const = 0; + virtual void visit(const NullIfTypedExpr& expr, ITypedExprVisitorContext& ctx) + const = 0; + protected: void visitInputs(const ITypedExpr& expr, ITypedExprVisitorContext& ctx) const { @@ -710,6 +783,11 @@ class DefaultTypedExprVisitor : public ITypedExprVisitor { const override { visitInputs(expr, ctx); } + + void visit(const NullIfTypedExpr& expr, ITypedExprVisitorContext& ctx) + const override { + visitInputs(expr, ctx); + } }; } // namespace facebook::velox::core diff --git a/velox/core/ITypedExpr.cpp b/velox/core/ITypedExpr.cpp index ba531da895c..d11156e3fb5 100644 --- a/velox/core/ITypedExpr.cpp +++ b/velox/core/ITypedExpr.cpp @@ -16,6 +16,8 @@ #include "velox/core/ITypedExpr.h" +#include "velox/common/EnumDefine.h" + namespace facebook::velox::core { namespace { @@ -35,4 +37,14 @@ const auto& exprKindNames() { } // namespace VELOX_DEFINE_ENUM_NAME(ExprKind, exprKindNames); + +size_t ITypedExprHasher::operator()(const ITypedExpr* expr) const { + return expr->hash(); +} + +bool ITypedExprComparer::operator()( + const ITypedExpr* lhs, + const ITypedExpr* rhs) const { + return *lhs == *rhs; +} } // namespace facebook::velox::core diff --git a/velox/core/ITypedExpr.h b/velox/core/ITypedExpr.h index bcdf55f233f..6cbb443c755 100644 --- a/velox/core/ITypedExpr.h +++ b/velox/core/ITypedExpr.h @@ -28,6 +28,7 @@ enum class ExprKind : int32_t { kConstant = 6, kConcat = 7, kLambda = 8, + kNullIf = 9, }; VELOX_DECLARE_ENUM_NAME(ExprKind); @@ -38,6 +39,14 @@ class ITypedExprVisitorContext; using TypedExprPtr = std::shared_ptr; +struct ITypedExprHasher { + size_t operator()(const ITypedExpr* expr) const; +}; + +struct ITypedExprComparer { + bool operator()(const ITypedExpr* lhs, const ITypedExpr* rhs) const; +}; + /// Strongly-typed expression, e.g. literal, function call, etc. class ITypedExpr : public ISerializable { public: @@ -93,6 +102,10 @@ class ITypedExpr : public ISerializable { return kind_ == ExprKind::kLambda; } + bool isNullIfKind() const { + return kind_ == ExprKind::kNullIf; + } + template const T* asUnchecked() const { return dynamic_cast(this); @@ -115,8 +128,17 @@ class ITypedExpr : public ISerializable { virtual std::string toString() const = 0; + /// Returns a hash value for this expression node only, not including inputs. + /// Implementations must use a stable hash like folly::hasher to ensure + /// stable hashing across processes and builds. virtual size_t localHash() const = 0; + /// Returns a hash value for the entire expression tree rooted at this node. + /// The hash is computed by combining localHash() with the type's hash and + /// the hashes of all input expressions. + /// + /// STABILITY GUARANTEE: This hash is stable across different processes, + /// builds, and machines. size_t hash() const { size_t hash = bits::hashMix(type_->hashKind(), localHash()); for (size_t i = 0; i < inputs_.size(); ++i) { diff --git a/velox/core/Metaprogramming.h b/velox/core/Metaprogramming.h index 955ea4122b4..95906ef2140 100644 --- a/velox/core/Metaprogramming.h +++ b/velox/core/Metaprogramming.h @@ -124,11 +124,10 @@ template struct has_method { private: template - static constexpr auto check(T*) -> - typename std::is_same< - decltype(std::declval().template resolve( - std::declval()...)), - TRet>::type { + static constexpr auto check(T*) -> typename std::is_same< + decltype(std::declval().template resolve( + std::declval()...)), + TRet>::type { return {}; } diff --git a/velox/core/PlanConsistencyChecker.cpp b/velox/core/PlanConsistencyChecker.cpp index 89c95c0d654..35db704a3f0 100644 --- a/velox/core/PlanConsistencyChecker.cpp +++ b/velox/core/PlanConsistencyChecker.cpp @@ -15,15 +15,47 @@ */ #include "velox/core/PlanConsistencyChecker.h" +#include "velox/common/base/Exceptions.h" namespace facebook::velox::core { namespace { +// Returns a message describing the plan node for exception context. +std::string planNodeMessage(VeloxException::Type /*exceptionType*/, void* arg) { + auto* node = static_cast(arg); + return fmt::format("Plan node: {}", node->toString(/*detailed=*/true)); +} + class Checker : public PlanNodeVisitor { public: + explicit Checker(PlanNodeId rootId) : rootId_{std::move(rootId)} {} + void visit(const AggregationNode& node, PlanNodeVisitorContext& ctx) const override { + const auto& rowType = node.sources().at(0)->outputType(); + for (const auto& expr : node.groupingKeys()) { + checkInputs(expr, rowType); + } + + for (const auto& expr : node.preGroupedKeys()) { + checkInputs(expr, rowType); + } + + for (const auto& aggregate : node.aggregates()) { + checkInputs(aggregate.call, rowType); + + for (const auto& expr : aggregate.sortingKeys) { + checkInputs(expr, rowType); + } + + if (aggregate.mask) { + checkInputs(aggregate.mask, rowType); + } + } + + verifyOutputNames(node); + visitSources(&node, ctx); } @@ -66,6 +98,26 @@ class Checker : public PlanNodeVisitor { void visit(const HashJoinNode& node, PlanNodeVisitorContext& ctx) const override { + std::unordered_set> keyNames; + for (auto i = 0; i < node.leftKeys().size(); ++i) { + const auto& leftKey = node.leftKeys().at(i); + const auto& rightKey = node.rightKeys().at(i); + + auto [_, inserted] = keyNames.insert({leftKey->name(), rightKey->name()}); + VELOX_USER_CHECK( + inserted, + "Duplicate join condition: \"{}\" = \"{}\"", + leftKey->name(), + rightKey->name()); + } + + if (node.filter()) { + const auto& leftRowType = node.sources().at(0)->outputType(); + const auto& rightRowType = node.sources().at(1)->outputType(); + auto rowType = leftRowType->unionWith(rightRowType); + checkInputs(node.filter(), rowType); + } + visitSources(&node, ctx); } @@ -89,11 +141,26 @@ class Checker : public PlanNodeVisitor { visitSources(&node, ctx); } + void visit(const MixedUnionNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + void visit(const MarkDistinctNode& node, PlanNodeVisitorContext& ctx) const override { visitSources(&node, ctx); } + void visit(const EnforceDistinctNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const MarkSortedNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + void visit(const MergeExchangeNode& node, PlanNodeVisitorContext& ctx) const override { visitSources(&node, ctx); @@ -106,6 +173,15 @@ class Checker : public PlanNodeVisitor { void visit(const NestedLoopJoinNode& node, PlanNodeVisitorContext& ctx) const override { + if (node.joinCondition() != nullptr) { + const auto& leftRowType = node.sources().at(0)->outputType(); + const auto& rightRowType = node.sources().at(1)->outputType(); + auto rowType = leftRowType->unionWith(rightRowType); + checkInputs(node.joinCondition(), rowType); + } + + verifyOutputNames(node); + visitSources(&node, ctx); } @@ -131,12 +207,10 @@ class Checker : public PlanNodeVisitor { checkInputs(expr, rowType); } - // Verify that output column names are not empty and unique. - std::unordered_set names; - for (const auto& name : node.outputType()->names()) { - VELOX_USER_CHECK(!name.empty(), "Output column name cannot be empty"); - VELOX_USER_CHECK( - names.insert(name).second, "Duplicate output column: {}", name); + // The root ProjectNode may have empty or duplicate output names when used + // to apply user-specified column aliases (e.g. SELECT 1 as x, 2 as x). + if (node.id() != rootId_) { + verifyOutputNames(node); } visitSources(&node, ctx); @@ -160,6 +234,22 @@ class Checker : public PlanNodeVisitor { void visit(const TableScanNode& node, PlanNodeVisitorContext& ctx) const override { + verifyOutputNames(node); + + // Verify assignments match outputType 1:1. + const auto& names = node.outputType()->names(); + VELOX_USER_CHECK_EQ( + names.size(), + node.assignments().size(), + "Column assignments must match output type"); + + for (const auto& name : names) { + VELOX_USER_CHECK( + node.assignments().contains(name), + "Column assignment is missing for {}", + name); + } + visitSources(&node, ctx); } @@ -209,10 +299,22 @@ class Checker : public PlanNodeVisitor { private: void visitSources(const PlanNode* node, PlanNodeVisitorContext& ctx) const { for (auto& source : node->sources()) { + ExceptionContextSetter exceptionContext( + {planNodeMessage, (void*)source.get()}); source->accept(*this, ctx); } } + // Verify that output column names are not empty and unique. + static void verifyOutputNames(const PlanNode& node) { + folly::F14FastSet names; + for (const auto& name : node.outputType()->names()) { + VELOX_USER_CHECK(!name.empty(), "Output column name cannot be empty"); + VELOX_USER_CHECK( + names.emplace(name).second, "Duplicate output column: {}", name); + } + } + static void checkInputs( const core::TypedExprPtr& expr, const RowTypePtr& rowType) { @@ -233,16 +335,27 @@ class Checker : public PlanNodeVisitor { } } + if (expr->isLambdaKind()) { + const auto& lambda = expr->asUnchecked(); + checkInputs(lambda->body(), lambda->signature()->unionWith(rowType)); + } + for (const auto& input : expr->inputs()) { checkInputs(input, rowType); } } + + // ID of the root node. Used to skip output name validation on the root + // ProjectNode, which may have user-specified aliases that are empty or + // duplicated. + const PlanNodeId rootId_; }; } // namespace void PlanConsistencyChecker::check(const core::PlanNodePtr& plan) { + ExceptionContextSetter exceptionContext({planNodeMessage, (void*)plan.get()}); PlanNodeVisitorContext ctx; - Checker checker; + Checker checker{plan->id()}; plan->accept(checker, ctx); } }; // namespace facebook::velox::core diff --git a/velox/core/PlanNode.cpp b/velox/core/PlanNode.cpp index 2cf53c91020..f47b957aced 100644 --- a/velox/core/PlanNode.cpp +++ b/velox/core/PlanNode.cpp @@ -15,12 +15,15 @@ */ #include +#include "velox/common/Casts.h" #include "velox/common/encode/Base64.h" #include "velox/core/PlanNode.h" + +#include "velox/common/EnumDefine.h" +#include "velox/core/TableWriteTraits.h" #include "velox/vector/VectorSaver.h" namespace facebook::velox::core { - namespace { void appendComma(int32_t i, std::stringstream& sql) { @@ -59,7 +62,11 @@ IndexLookupConditionPtr createIndexJoinCondition( } } // namespace -std::vector deserializeJoinConditions( +/// Deserializes lookup conditions from dynamic object for index lookup joins. +/// These conditions are more complex than simple equality join conditions and +/// can include IN, BETWEEN, and EQUAL conditions that involve both left and +/// right side columns. +std::vector deserializejoinConditions( const folly::dynamic& obj, void* context) { if (obj.count("joinConditions") == 0) { @@ -162,6 +169,7 @@ AggregationNode::AggregationNode( const std::vector& globalGroupingSets, const std::optional& groupId, bool ignoreNullKeys, + bool noGroupsSpanBatches, PlanNodePtr source) : PlanNode(id), step_(step), @@ -172,6 +180,7 @@ AggregationNode::AggregationNode( ignoreNullKeys_(ignoreNullKeys), groupId_(groupId), globalGroupingSets_(globalGroupingSets), + noGroupsSpanBatches_(noGroupsSpanBatches), sources_{source}, outputType_(getAggregationOutputType( groupingKeys_, @@ -216,6 +225,10 @@ AggregationNode::AggregationNode( VELOX_USER_CHECK( groupId_.has_value(), "Global grouping sets require GroupId key"); } + + VELOX_USER_CHECK( + !noGroupsSpanBatches_ || isPreGrouped(), + "noGroupsSpanBatches can only be set for streaming aggregation (pre-grouped)"); } AggregationNode::AggregationNode( @@ -226,6 +239,7 @@ AggregationNode::AggregationNode( const std::vector& aggregateNames, const std::vector& aggregates, bool ignoreNullKeys, + bool noGroupsSpanBatches, PlanNodePtr source) : AggregationNode( id, @@ -237,6 +251,7 @@ AggregationNode::AggregationNode( kDefaultGlobalGroupingSets, kDefaultGroupId, ignoreNullKeys, + noGroupsSpanBatches, source) {} namespace { @@ -273,19 +288,12 @@ void addSortingKeys( } } -void addVectorSerdeKind(VectorSerde::Kind kind, std::stringstream& stream) { - stream << VectorSerde::kindName(kind); +void addVectorSerdeKind(const std::string& kind, std::stringstream& stream) { + stream << kind; } } // namespace bool AggregationNode::canSpill(const QueryConfig& queryConfig) const { - // TODO: Add spilling for aggregations over distinct inputs. - // https://github.com/facebookincubator/velox/issues/7454 - for (const auto& aggregate : aggregates_) { - if (aggregate.distinct) { - return false; - } - } // TODO: add spilling for pre-grouped aggregation later: // https://github.com/facebookincubator/velox/issues/3264 return (isFinal() || isSingle()) && !groupingKeys().empty() && @@ -331,6 +339,10 @@ void AggregationNode::addDetails(std::stringstream& stream) const { if (groupId_.has_value()) { stream << " Group Id key: " << groupId_.value()->name(); } + + if (noGroupsSpanBatches_) { + stream << " noGroupsSpanBatches"; + } } namespace { @@ -369,6 +381,7 @@ folly::dynamic AggregationNode::serialize() const { obj["groupId"] = ISerializable::serialize(groupId_.value()); } obj["ignoreNullKeys"] = ignoreNullKeys_; + obj["noGroupsSpanBatches"] = noGroupsSpanBatches_; return obj; } @@ -386,6 +399,10 @@ std::vector deserializeFields( array, context); } +FieldAccessTypedExprPtr deserializeField(const folly::dynamic& obj) { + return ISerializable::deserialize(obj); +} + std::vector deserializeStrings(const folly::dynamic& array) { return ISerializable::deserialize>(array); } @@ -485,6 +502,8 @@ PlanNodePtr AggregationNode::create(const folly::dynamic& obj, void* context) { globalGroupingSets, groupId, obj["ignoreNullKeys"].asBool(), + obj.count("noGroupsSpanBatches") ? obj["noGroupsSpanBatches"].asBool() + : false, deserializeSingleSource(obj, context)); } @@ -892,6 +911,13 @@ class SummarizeExprVisitor : public ITypedExprVisitor { myCtx.expressionCounts()["lambda"]++; expr.body()->accept(*this, ctx); } + + void visit(const NullIfTypedExpr& expr, ITypedExprVisitorContext& ctx) + const override { + auto& myCtx = static_cast(ctx); + myCtx.expressionCounts()["nullif"]++; + visitInputs(expr, ctx); + } }; void appendCounts( @@ -1129,7 +1155,7 @@ std::vector allNames( const std::vector& names, const std::vector& moreNames) { auto result = names; - result.insert(result.end(), moreNames.begin(), moreNames.end()); + result.insert(result.cend(), moreNames.cbegin(), moreNames.cend()); return result; } @@ -1141,13 +1167,14 @@ std::vector flattenExprs( const PlanNodePtr& input) { std::vector result; for (auto& group : exprs) { - result.insert(result.end(), group.begin(), group.end()); + result.insert(result.cend(), group.cbegin(), group.cend()); } const auto& sourceType = input->outputType(); for (auto& name : moreNames) { - result.push_back(std::make_shared( - sourceType->findChild(name), name)); + result.push_back( + std::make_shared( + sourceType->findChild(name), name)); } return result; } @@ -1319,7 +1346,7 @@ void ExchangeNode::addDetails(std::stringstream& stream) const { folly::dynamic ExchangeNode::serialize() const { auto obj = PlanNode::serialize(); obj["outputType"] = ExchangeNode::outputType()->serialize(); - obj["serdeKind"] = VectorSerde::kindName(serdeKind_); + obj["serdeKind"] = serdeKind_; return obj; } @@ -1334,7 +1361,7 @@ PlanNodePtr ExchangeNode::create(const folly::dynamic& obj, void* context) { return std::make_shared( deserializePlanNodeId(obj), deserializeRowType(obj["outputType"]), - VectorSerde::kindByName(obj["serdeKind"].asString())); + obj["serdeKind"].asString()); } UnnestNode::UnnestNode( @@ -1343,14 +1370,34 @@ UnnestNode::UnnestNode( std::vector unnestVariables, std::vector unnestNames, std::optional ordinalityName, - std::optional emptyUnnestValueName, + std::optional markerName, + const PlanNodePtr& source) + : UnnestNode( + id, + std::move(replicateVariables), + std::move(unnestVariables), + std::move(unnestNames), + std::move(ordinalityName), + std::move(markerName), + std::nullopt, + source) {} + +UnnestNode::UnnestNode( + const PlanNodeId& id, + std::vector replicateVariables, + std::vector unnestVariables, + std::vector unnestNames, + std::optional ordinalityName, + std::optional markerName, + std::optional splitOutput, const PlanNodePtr& source) : PlanNode(id), replicateVariables_{std::move(replicateVariables)}, unnestVariables_{std::move(unnestVariables)}, unnestNames_{std::move(unnestNames)}, ordinalityName_{std::move(ordinalityName)}, - emptyUnnestValueName_(std::move(emptyUnnestValueName)), + markerName_(std::move(markerName)), + splitOutput_(splitOutput), sources_{source} { // Calculate output type. First come "replicate" columns, followed by // "unnest" columns, followed by an optional ordinality column. @@ -1387,8 +1434,8 @@ UnnestNode::UnnestNode( types.emplace_back(BIGINT()); } - if (emptyUnnestValueName_.has_value()) { - names.emplace_back(emptyUnnestValueName_.value()); + if (markerName_.has_value()) { + names.emplace_back(markerName_.value()); types.emplace_back(BOOLEAN()); } @@ -1408,8 +1455,11 @@ folly::dynamic UnnestNode::serialize() const { if (ordinalityName_.has_value()) { obj["ordinalityName"] = ordinalityName_.value(); } - if (emptyUnnestValueName_.has_value()) { - obj["emptyUnnestValueName"] = emptyUnnestValueName_.value(); + if (markerName_.has_value()) { + obj["markerName"] = markerName_.value(); + } + if (splitOutput_.has_value()) { + obj["splitOutput"] = splitOutput_.value(); } return obj; } @@ -1431,9 +1481,13 @@ PlanNodePtr UnnestNode::create(const folly::dynamic& obj, void* context) { if (obj.count("ordinalityName")) { ordinalityName = obj["ordinalityName"].asString(); } - std::optional emptyUnnestValueName = std::nullopt; - if (obj.count("emptyUnnestValueName")) { - emptyUnnestValueName = obj["emptyUnnestValueName"].asString(); + std::optional markerName = std::nullopt; + if (obj.count("markerName")) { + markerName = obj["markerName"].asString(); + } + std::optional splitOutput = std::nullopt; + if (obj.count("splitOutput")) { + splitOutput = obj["splitOutput"].asBool(); } return std::make_shared( deserializePlanNodeId(obj), @@ -1441,7 +1495,8 @@ PlanNodePtr UnnestNode::create(const folly::dynamic& obj, void* context) { std::move(unnestVariables), std::move(unnestNames), std::move(ordinalityName), - std::move(emptyUnnestValueName), + std::move(markerName), + splitOutput, std::move(source)); } @@ -1509,7 +1564,8 @@ void AbstractJoinNode::validate() const { // Output of left semi and anti joins cannot include columns from the right // side. bool outputMayIncludeRightColumns = - !(isLeftSemiFilterJoin() || isLeftSemiProjectJoin() || isAntiJoin()); + !(isLeftSemiFilterJoin() || isLeftSemiProjectJoin() || isAntiJoin() || + isCountingJoin()); for (auto i = 0; i < numOutputColumns; ++i) { auto name = outputType_->nameOf(i); @@ -1570,28 +1626,51 @@ const auto& joinTypeNames() { {JoinType::kLeftSemiProject, "LEFT SEMI (PROJECT)"}, {JoinType::kRightSemiProject, "RIGHT SEMI (PROJECT)"}, {JoinType::kAnti, "ANTI"}, + {JoinType::kCountingAnti, "COUNTING ANTI"}, + {JoinType::kCountingLeftSemiFilter, "COUNTING LEFT SEMI (FILTER)"}, }; return kNames; } // Check that each output of the join is in exactly one of the inputs. -void checkJoinColumnNames( +void checkJoinOutput( const RowTypePtr& leftType, const RowTypePtr& rightType, const RowTypePtr& outputType, uint32_t numColumnsToCheck) { for (auto i = 0; i < numColumnsToCheck; ++i) { - const auto name = outputType->nameOf(i); - const bool leftContains = leftType->containsChild(name); - const bool rightContains = rightType->containsChild(name); + const auto& name = outputType->nameOf(i); + const auto& type = outputType->childAt(i); + + const auto leftIndex = leftType->getChildIdxIfExists(name); + const auto rightIndex = rightType->getChildIdxIfExists(name); + VELOX_USER_CHECK( - !(leftContains && rightContains), + !(leftIndex.has_value() && rightIndex.has_value()), "Duplicate column name found on join's left and right sides: {}", name); VELOX_USER_CHECK( - leftContains || rightContains, + leftIndex.has_value() || rightIndex.has_value(), "Join's output column not found in either left or right sides: {}", name); + + if (leftIndex.has_value()) { + const auto& expectedType = leftType->childAt(leftIndex.value()); + VELOX_USER_CHECK( + expectedType->equivalent(*type), + "Join output column type must match the input type: {} vs. {}", + type->toString(), + expectedType->toString()); + } + + if (rightIndex.has_value()) { + const auto& expectedType = rightType->childAt(rightIndex.value()); + VELOX_USER_CHECK( + expectedType->equivalent(*type), + "Join output column type must match the input type: {} vs. {}", + type->toString(), + expectedType->toString()); + } } } @@ -1604,11 +1683,16 @@ void HashJoinNode::addDetails(std::stringstream& stream) const { if (nullAware_) { stream << ", null aware"; } + if (nullAsValue_) { + stream << ", null as value"; + } } folly::dynamic HashJoinNode::serialize() const { auto obj = serializeBase(); obj["nullAware"] = nullAware_; + obj["nullAsValue"] = nullAsValue_; + obj["useHashTableCache"] = useHashTableCache_; return obj; } @@ -1624,6 +1708,8 @@ PlanNodePtr HashJoinNode::create(const folly::dynamic& obj, void* context) { VELOX_CHECK_EQ(2, sources.size()); auto nullAware = obj["nullAware"].asBool(); + auto nullAsValue = obj.getDefault("nullAsValue", false).asBool(); + auto useHashTableCache = obj.getDefault("useHashTableCache", false).asBool(); auto leftKeys = deserializeFields(obj["leftKeys"], context); auto rightKeys = deserializeFields(obj["rightKeys"], context); @@ -1643,7 +1729,9 @@ PlanNodePtr HashJoinNode::create(const folly::dynamic& obj, void* context) { filter, sources[0], sources[1], - outputType); + outputType, + useHashTableCache, + nullAsValue); } MergeJoinNode::MergeJoinNode( @@ -1730,7 +1818,8 @@ IndexLookupJoinNode::IndexLookupJoinNode( const std::vector& leftKeys, const std::vector& rightKeys, const std::vector& joinConditions, - bool includeMatchColumn, + TypedExprPtr filter, + bool hasMarker, PlanNodePtr left, TableScanNodePtr right, RowTypePtr outputType) @@ -1739,13 +1828,13 @@ IndexLookupJoinNode::IndexLookupJoinNode( joinType, leftKeys, rightKeys, - /*filter=*/nullptr, + std::move(filter), std::move(left), right, outputType), lookupSourceNode_(std::move(right)), joinConditions_(joinConditions), - includeMatchColumn_(includeMatchColumn) { + hasMarker_(hasMarker) { VELOX_USER_CHECK( !leftKeys.empty(), "The index lookup join node requires at least one join key"); @@ -1789,7 +1878,7 @@ IndexLookupJoinNode::IndexLookupJoinNode( } auto numOutputColumns = outputType_->size(); - if (includeMatchColumn_) { + if (hasMarker_) { VELOX_USER_CHECK( isLeftJoin(), "Index join match column can only present for {} but not {}", @@ -1809,7 +1898,7 @@ IndexLookupJoinNode::IndexLookupJoinNode( VELOX_USER_CHECK(!rightType->containsChild(name)); } - checkJoinColumnNames(leftType, rightType, outputType_, numOutputColumns); + checkJoinOutput(leftType, rightType, outputType_, numOutputColumns); } PlanNodePtr IndexLookupJoinNode::create( @@ -1818,17 +1907,19 @@ PlanNodePtr IndexLookupJoinNode::create( auto sources = deserializeSources(obj, context); VELOX_CHECK_EQ(2, sources.size()); TableScanNodePtr lookupSource = - std::dynamic_pointer_cast(sources[1]); - VELOX_CHECK_NOT_NULL(lookupSource); + checkedPointerCast(sources[1]); auto leftKeys = deserializeFields(obj["leftKeys"], context); auto rightKeys = deserializeFields(obj["rightKeys"], context); - VELOX_CHECK_EQ(obj.count("filter"), 0); + TypedExprPtr filter; + if (obj.count("filter")) { + filter = ISerializable::deserialize(obj["filter"], context); + } - auto joinConditions = deserializeJoinConditions(obj, context); + auto joinConditions = deserializejoinConditions(obj, context); - const bool includeMatchColumn = obj["includeMatchColumn"].asBool(); + const bool hasMarker = obj["hasMarker"].asBool(); auto outputType = deserializeRowType(obj["outputType"]); @@ -1838,7 +1929,8 @@ PlanNodePtr IndexLookupJoinNode::create( std::move(leftKeys), std::move(rightKeys), std::move(joinConditions), - includeMatchColumn, + filter, + hasMarker, sources[0], std::move(lookupSource), std::move(outputType)); @@ -1853,24 +1945,32 @@ folly::dynamic IndexLookupJoinNode::serialize() const { } obj["joinConditions"] = std::move(serializedJoins); } - obj["includeMatchColumn"] = includeMatchColumn_; + if (filter_) { + obj["filter"] = filter_->serialize(); + } + obj["hasMarker"] = hasMarker_; return obj; } +bool IndexLookupJoinNode::needsIndexSplit() const { + return lookupSourceNode_->tableHandle()->needsIndexSplit(); +} + void IndexLookupJoinNode::addDetails(std::stringstream& stream) const { AbstractJoinNode::addDetails(stream); if (joinConditions_.empty()) { return; } - std::vector joinConditionStrs; - joinConditionStrs.reserve(joinConditions_.size()); + std::vector joinConditionstrs; + joinConditionstrs.reserve(joinConditions_.size()); for (const auto& joinCondition : joinConditions_) { - joinConditionStrs.push_back(joinCondition->toString()); + joinConditionstrs.push_back(joinCondition->toString()); } - stream << ", joinConditions: [" << folly::join(", ", joinConditionStrs) - << " ], includeMatchColumn: [" - << (includeMatchColumn_ ? "true" : "false") << "]"; + stream << ", joinConditions: [" << folly::join(", ", joinConditionstrs) + << " ], filter: [" + << (filter_ == nullptr ? "null" : filter_->toString()) + << "], hasMarker: [" << (hasMarker_ ? "true" : "false") << "]"; } void IndexLookupJoinNode::accept( @@ -1892,9 +1992,7 @@ bool IndexLookupJoinNode::isSupported(JoinType joinType) { } bool isIndexLookupJoin(const PlanNode* planNode) { - const auto* indexLookupJoin = - dynamic_cast(planNode); - return indexLookupJoin != nullptr; + return isInstanceOf(planNode); } // static @@ -1936,7 +2034,7 @@ NestedLoopJoinNode::NestedLoopJoinNode( name); } - checkJoinColumnNames(leftType, rightType, outputType_, numOutputColumns); + checkJoinOutput(leftType, rightType, outputType_, numOutputColumns); } NestedLoopJoinNode::NestedLoopJoinNode( @@ -2374,6 +2472,158 @@ PlanNodePtr MarkDistinctNode::create(const folly::dynamic& obj, void* context) { deserializePlanNodeId(obj), markerName, distinctKeys, source); } +EnforceDistinctNode::EnforceDistinctNode( + PlanNodeId id, + std::vector distinctKeys, + std::vector preGroupedKeys, + std::string errorMessage, + PlanNodePtr source) + : PlanNode(std::move(id)), + distinctKeys_(std::move(distinctKeys)), + preGroupedKeys_(std::move(preGroupedKeys)), + errorMessage_(std::move(errorMessage)), + sources_{std::move(source)} { + VELOX_USER_CHECK(!distinctKeys_.empty(), "distinctKeys must not be empty."); + VELOX_USER_CHECK(!errorMessage_.empty(), "errorMessage must not be empty"); + + using TypedExprSet = folly:: + F14FastSet; + + TypedExprSet distinctKeySet; + const auto& inputType = sources_.front()->outputType(); + for (const auto& key : distinctKeys_) { + VELOX_USER_CHECK( + key->isInputColumn(), + "Distinct key must be a column reference: {}.", + key->toString()); + + VELOX_USER_CHECK( + distinctKeySet.insert(key.get()).second, + "Duplicate distinct key: {}.", + key->toString()); + + VELOX_USER_CHECK( + inputType->containsChild(key->name()), + "Distinct key must be present in the input: {}.", + key->toString()); + } + + TypedExprSet preGroupedKeySet; + for (const auto& key : preGroupedKeys_) { + VELOX_USER_CHECK( + preGroupedKeySet.insert(key.get()).second, + "Duplicate pre-grouped key: {}.", + key->name()); + VELOX_USER_CHECK( + distinctKeySet.contains(key.get()), + "Pre-grouped key must be one of the distinct keys: {}.", + key->name()); + } +} + +folly::dynamic EnforceDistinctNode::serialize() const { + auto obj = PlanNode::serialize(); + obj["distinctKeys"] = ISerializable::serialize(this->distinctKeys_); + obj["preGroupedKeys"] = ISerializable::serialize(this->preGroupedKeys_); + obj["errorMessage"] = this->errorMessage_; + return obj; +} + +void EnforceDistinctNode::accept( + const PlanNodeVisitor& visitor, + PlanNodeVisitorContext& context) const { + visitor.visit(*this, context); +} + +// static +PlanNodePtr EnforceDistinctNode::create( + const folly::dynamic& obj, + void* context) { + auto source = deserializeSingleSource(obj, context); + auto distinctKeys = deserializeFields(obj["distinctKeys"], context); + auto preGroupedKeys = deserializeFields(obj["preGroupedKeys"], context); + auto errorMessage = obj["errorMessage"].asString(); + + return std::make_shared( + deserializePlanNodeId(obj), + distinctKeys, + preGroupedKeys, + errorMessage, + source); +} + +RowTypePtr getMarkSortedOutputType( + const RowTypePtr& inputType, + const std::string& markerName) { + std::vector names = inputType->names(); + std::vector types = inputType->children(); + + names.emplace_back(markerName); + types.emplace_back(BOOLEAN()); + return ROW(std::move(names), std::move(types)); +} + +MarkSortedNode::MarkSortedNode( + PlanNodeId id, + std::string markerName, + std::vector sortingKeys, + std::vector sortingOrders, + PlanNodePtr source) + : PlanNode(std::move(id)), + markerName_(std::move(markerName)), + sortingKeys_(std::move(sortingKeys)), + sortingOrders_(std::move(sortingOrders)), + sources_{std::move(source)}, + outputType_( + getMarkSortedOutputType(sources_[0]->outputType(), markerName_)) { + VELOX_USER_CHECK_GT(markerName_.size(), 0); + VELOX_USER_CHECK_GT(sortingKeys_.size(), 0); + VELOX_USER_CHECK_EQ( + sortingKeys_.size(), + sortingOrders_.size(), + "Number of sorting keys and sorting orders must be the same"); +} + +void MarkSortedNode::addDetails(std::stringstream& stream) const { + stream << "marker: " << markerName_ << ", keys: ["; + for (auto i = 0; i < sortingKeys_.size(); ++i) { + if (i > 0) { + stream << ", "; + } + stream << sortingKeys_[i]->name() << " " << sortingOrders_[i].toString(); + } + stream << "]"; +} + +folly::dynamic MarkSortedNode::serialize() const { + auto obj = PlanNode::serialize(); + obj["markerName"] = this->markerName_; + obj["sortingKeys"] = ISerializable::serialize(this->sortingKeys_); + obj["sortingOrders"] = ISerializable::serialize(this->sortingOrders_); + return obj; +} + +void MarkSortedNode::accept( + const PlanNodeVisitor& visitor, + PlanNodeVisitorContext& context) const { + visitor.visit(*this, context); +} + +// static +PlanNodePtr MarkSortedNode::create(const folly::dynamic& obj, void* context) { + auto source = deserializeSingleSource(obj, context); + auto markerName = obj["markerName"].asString(); + auto sortingKeys = deserializeFields(obj["sortingKeys"], context); + auto sortingOrders = deserializeSortingOrders(obj["sortingOrders"]); + + return std::make_shared( + deserializePlanNodeId(obj), + markerName, + sortingKeys, + sortingOrders, + source); +} + namespace { RowTypePtr getRowNumberOutputType( const RowTypePtr& inputType, @@ -2486,7 +2736,7 @@ const char* TopNRowNumberNode::rankFunctionName( static const auto kFunctionNames = rankFunctionNames(); auto it = kFunctionNames.find(function); VELOX_CHECK( - it != kFunctionNames.end(), + it != kFunctionNames.cend(), "Invalid rank function {}", static_cast(function)); return it->second.c_str(); @@ -2497,7 +2747,7 @@ TopNRowNumberNode::RankFunction TopNRowNumberNode::rankFunctionFromName( std::string_view name) { static const auto kFunctionNames = invertMap(rankFunctionNames()); auto it = kFunctionNames.find(name.data()); - VELOX_CHECK(it != kFunctionNames.end(), "Invalid rank function {}", name); + VELOX_CHECK(it != kFunctionNames.cend(), "Invalid rank function {}", name); return it->second; } @@ -2640,8 +2890,135 @@ PlanNodePtr LocalMergeNode::create(const folly::dynamic& obj, void* context) { std::move(sources)); } +namespace { +// Validates that grouping keys in 'spec' are present in 'type' and have no +// duplicates. 'context' is used in error messages (e.g. "written columns", +// "source output"). +void validateGroupingKeys( + const ColumnStatsSpec& spec, + const RowType& type, + std::string_view context) { + folly::F14FastSet seenKeys; + for (const auto& key : spec.groupingKeys) { + VELOX_USER_CHECK( + type.containsChild(key->name()), + "Grouping key not found in {}: {}", + context, + key->name()); + VELOX_USER_CHECK( + seenKeys.insert(key->name()).second, + "Duplicate grouping key: {}", + key->name()); + } +} +} // namespace + +TableWriteNode::TableWriteNode( + const PlanNodeId& id, + const RowTypePtr& columns, + const std::vector& columnNames, + std::optional columnStatsSpec, + std::shared_ptr insertTableHandle, + bool hasPartitioningScheme, + RowTypePtr outputType, + connector::CommitStrategy commitStrategy, + const PlanNodePtr& source) + : PlanNode(id), + sources_{source}, + columns_{columns}, + columnNames_{columnNames}, + columnStatsSpec_(std::move(columnStatsSpec)), + insertTableHandle_(std::move(insertTableHandle)), + hasPartitioningScheme_(hasPartitioningScheme), + outputType_(std::move(outputType)), + commitStrategy_(commitStrategy) { + VELOX_USER_CHECK_NOT_NULL(sources_[0]); + VELOX_USER_CHECK_NOT_NULL(insertTableHandle_); + VELOX_USER_CHECK_EQ(columns_->size(), columnNames_.size()); + for (const auto& column : columns_->names()) { + VELOX_USER_CHECK( + sources_[0]->outputType()->containsChild(column), + "Column not found in TableWrite input: {}", + column); + } + if (columnStatsSpec_.has_value()) { + VELOX_USER_CHECK( + columnStatsSpec_->aggregationStep == AggregationNode::Step::kSingle || + columnStatsSpec_->aggregationStep == + AggregationNode::Step::kPartial, + "TableWriteNode requires aggregation step to be single or partial"); + validateGroupingKeys( + columnStatsSpec_.value(), *columns_, "written columns"); + } + // Single-column BIGINT output with no stats spec. Used by Spark/Gluten + // and other non-Prestissimo integrations. + if (outputType_->size() == 1 && !columnStatsSpec_.has_value()) { + VELOX_USER_CHECK_EQ( + outputType_->childAt(0)->kind(), + TypeKind::BIGINT, + "Single-column outputType must be BIGINT"); + return; + } + const auto expectedType = TableWriteTraits::outputType(columnStatsSpec_); + VELOX_USER_CHECK( + outputType_->equivalent(*expectedType), + "TableWriteNode outputType mismatch: {} vs computed {}", + outputType_->toString(), + expectedType->toString()); +} + +namespace { +void addStatsSpecDetails( + std::stringstream& stream, + const std::optional& spec) { + if (!spec.has_value()) { + return; + } + stream << "stats[" << AggregationNode::toName(spec->aggregationStep); + if (!spec->groupingKeys.empty()) { + stream << " ["; + addFields(stream, spec->groupingKeys); + stream << "]"; + } + stream << ": "; + for (auto i = 0; i < spec->aggregates.size(); ++i) { + appendComma(i, stream); + stream << spec->aggregates[i].call->toString(); + } + stream << "]"; +} +} // namespace + void TableWriteNode::addDetails(std::stringstream& stream) const { - stream << insertTableHandle_->connectorInsertTableHandle()->toString(); + stream << insertTableHandle_->connectorId() << ", " + << folly::join(", ", columnNames_); + if (columnStatsSpec_.has_value()) { + stream << ", "; + addStatsSpecDetails(stream, columnStatsSpec_); + } +} + +RowTypePtr ColumnStatsSpec::outputType() const { + // Create output type based on the column stats collection specs. + std::vector names; + std::vector types; + + const auto numAggregates = aggregates.size(); + const auto outputTypeSize = groupingKeys.size() + numAggregates; + + names.reserve(outputTypeSize); + types.reserve(outputTypeSize); + + for (const auto& key : groupingKeys) { + names.push_back(key->name()); + types.push_back(key->type()); + } + + for (auto i = 0; i < numAggregates; ++i) { + names.push_back(aggregateNames[i]); + types.push_back(aggregates[i].call->type()); + } + return ROW(std::move(names), std::move(types)); } folly::dynamic ColumnStatsSpec::serialize() const { @@ -2709,11 +3086,6 @@ PlanNodePtr TableWriteNode::create(const folly::dynamic& obj, void* context) { auto columns = deserializeRowType(obj["columns"]); auto columnNames = ISerializable::deserialize>(obj["columnNames"]); - AggregationNodePtr aggregationNode; - if (obj.count("aggregationNode") != 0) { - aggregationNode = ISerializable::deserialize( - obj["aggregationNode"], context); - } auto connectorId = obj["connectorId"].asString(); auto connectorInsertTableHandle = ISerializable::deserialize( @@ -2739,7 +3111,36 @@ PlanNodePtr TableWriteNode::create(const folly::dynamic& obj, void* context) { deserializeSingleSource(obj, context)); } -void TableWriteMergeNode::addDetails(std::stringstream& /* stream */) const {} +TableWriteMergeNode::TableWriteMergeNode( + const PlanNodeId& id, + RowTypePtr outputType, + std::optional columnStatsSpec, + PlanNodePtr source) + : PlanNode(id), + columnStatsSpec_(std::move(columnStatsSpec)), + sources_{std::move(source)}, + outputType_(std::move(outputType)) { + VELOX_USER_CHECK_NOT_NULL(sources_[0]); + const auto expectedType = TableWriteTraits::outputType(columnStatsSpec_); + VELOX_USER_CHECK( + outputType_->equivalent(*expectedType), + "TableWriteMergeNode outputType mismatch: {} vs computed {}", + outputType_->toString(), + expectedType->toString()); + if (hasColumnStatsSpec()) { + VELOX_USER_CHECK( + columnStatsSpec_->aggregationStep == AggregationNode::Step::kFinal || + columnStatsSpec_->aggregationStep == + AggregationNode::Step::kIntermediate, + "TableWriteMergeNode requires aggregation step to be intermediate or final"); + validateGroupingKeys( + columnStatsSpec_.value(), *sources_[0]->outputType(), "source output"); + } +} + +void TableWriteMergeNode::addDetails(std::stringstream& stream) const { + addStatsSpecDetails(stream, columnStatsSpec_); +} folly::dynamic TableWriteMergeNode::serialize() const { auto obj = PlanNode::serialize(); @@ -2780,8 +3181,8 @@ MergeExchangeNode::MergeExchangeNode( const RowTypePtr& type, const std::vector& sortingKeys, const std::vector& sortingOrders, - VectorSerde::Kind serdeKind) - : ExchangeNode(id, type, serdeKind), + std::string serdeKind) + : ExchangeNode(id, type, std::move(serdeKind)), sortingKeys_(sortingKeys), sortingOrders_(sortingOrders) {} @@ -2796,7 +3197,7 @@ folly::dynamic MergeExchangeNode::serialize() const { obj["outputType"] = ExchangeNode::outputType()->serialize(); obj["sortingKeys"] = ISerializable::serialize(sortingKeys_); obj["sortingOrders"] = serializeSortingOrders(sortingOrders_); - obj["serdeKind"] = VectorSerde::kindName(serdeKind()); + obj["serdeKind"] = serdeKind(); return obj; } @@ -2813,7 +3214,7 @@ PlanNodePtr MergeExchangeNode::create( const auto outputType = deserializeRowType(obj["outputType"]); const auto sortingKeys = deserializeFields(obj["sortingKeys"], context); const auto sortingOrders = deserializeSortingOrders(obj["sortingOrders"]); - const auto serdeKind = VectorSerde::kindByName(obj["serdeKind"].asString()); + const auto serdeKind = obj["serdeKind"].asString(); return std::make_shared( deserializePlanNodeId(obj), outputType, @@ -2883,7 +3284,7 @@ PartitionedOutputNode::PartitionedOutputNode( bool replicateNullsAndAny, PartitionFunctionSpecPtr partitionFunctionSpec, RowTypePtr outputType, - VectorSerde::Kind serdeKind, + std::string serdeKind, PlanNodePtr source) : PlanNode(id), kind_(kind), @@ -2892,7 +3293,7 @@ PartitionedOutputNode::PartitionedOutputNode( numPartitions_(numPartitions), replicateNullsAndAny_(replicateNullsAndAny), partitionFunctionSpec_(std::move(partitionFunctionSpec)), - serdeKind_(serdeKind), + serdeKind_(std::move(serdeKind)), outputType_(std::move(outputType)) { VELOX_USER_CHECK_GT(numPartitions_, 0); if (numPartitions_ == 1) { @@ -2917,7 +3318,7 @@ std::shared_ptr PartitionedOutputNode::broadcast( const PlanNodeId& id, int numPartitions, RowTypePtr outputType, - VectorSerde::Kind serdeKind, + std::string serdeKind, PlanNodePtr source) { std::vector noKeys; return std::make_shared( @@ -2936,7 +3337,7 @@ std::shared_ptr PartitionedOutputNode::broadcast( std::shared_ptr PartitionedOutputNode::arbitrary( const PlanNodeId& id, RowTypePtr outputType, - VectorSerde::Kind serdeKind, + std::string serdeKind, PlanNodePtr source) { std::vector noKeys; return std::make_shared( @@ -2955,7 +3356,7 @@ std::shared_ptr PartitionedOutputNode::arbitrary( std::shared_ptr PartitionedOutputNode::single( const PlanNodeId& id, RowTypePtr outputType, - VectorSerde::Kind serdeKind, + std::string serdeKind, PlanNodePtr source) { std::vector noKeys; return std::make_shared( @@ -3037,7 +3438,7 @@ folly::dynamic PartitionedOutputNode::serialize() const { obj["keys"] = ISerializable::serialize(keys_); obj["replicateNullsAndAny"] = replicateNullsAndAny_; obj["partitionFunctionSpec"] = partitionFunctionSpec_->serialize(); - obj["serdeKind"] = VectorSerde::kindName(serdeKind_); + obj["serdeKind"] = serdeKind_; obj["outputType"] = outputType_->serialize(); return obj; } @@ -3061,7 +3462,7 @@ PlanNodePtr PartitionedOutputNode::create( ISerializable::deserialize( obj["partitionFunctionSpec"], context), deserializeRowType(obj["outputType"]), - VectorSerde::kindByName(obj["serdeKind"].asString()), + obj["serdeKind"].asString(), deserializeSingleSource(obj, context)); } @@ -3069,21 +3470,30 @@ SpatialJoinNode::SpatialJoinNode( const PlanNodeId& id, JoinType joinType, TypedExprPtr joinCondition, + FieldAccessTypedExprPtr probeGeometry, + FieldAccessTypedExprPtr buildGeometry, + std::optional radius, PlanNodePtr left, PlanNodePtr right, RowTypePtr outputType) : PlanNode(id), joinType_(joinType), joinCondition_(std::move(joinCondition)), + probeGeometry_(std::move(probeGeometry)), + buildGeometry_(std::move(buildGeometry)), + radius_(std::move(radius)), sources_({std::move(left), std::move(right)}), outputType_(std::move(outputType)) { VELOX_USER_CHECK( isSupported(joinType_), "The join type is not supported by spatial join: {}", JoinTypeName::toName(joinType_)); - VELOX_USER_CHECK( - joinCondition_ != nullptr, - "The join condition must not be null for spatial join"); + VELOX_USER_CHECK_NOT_NULL( + joinCondition_, "The join condition must not be null for spatial join"); + VELOX_USER_CHECK_NOT_NULL( + probeGeometry_, "Probe geometery must not be null for spatial joins"); + VELOX_USER_CHECK_NOT_NULL( + buildGeometry_, "Build geometery must not be null for spatial joins"); VELOX_USER_CHECK_EQ( sources_.size(), 2, "Must have 2 sources for spatial joins"); VELOX_USER_CHECK( @@ -3092,7 +3502,7 @@ SpatialJoinNode::SpatialJoinNode( sources_[1] != nullptr, "Right source must not be null for spatial joins"); - checkJoinColumnNames( + checkJoinOutput( sources_[0]->outputType(), sources_[1]->outputType(), outputType_, @@ -3115,6 +3525,11 @@ void SpatialJoinNode::addDetails(std::stringstream& stream) const { if (joinCondition_) { stream << ", joinCondition: " << joinCondition_->toString(); } + stream << ", probeGeometry: " << probeGeometry_->name(); + stream << ", buildGeometry: " << buildGeometry_->name(); + if (radius_) { + stream << ", radius: " << radius_.value()->name(); + } } folly::dynamic SpatialJoinNode::serialize() const { @@ -3124,6 +3539,11 @@ folly::dynamic SpatialJoinNode::serialize() const { obj["joinCondition"] = joinCondition_->serialize(); } obj["outputType"] = outputType_->serialize(); + obj["probeGeometry"] = probeGeometry_->serialize(); + obj["buildGeometry"] = buildGeometry_->serialize(); + if (radius_) { + obj["radius"] = radius_.value()->serialize(); + } return obj; } @@ -3144,11 +3564,20 @@ PlanNodePtr SpatialJoinNode::create(const folly::dynamic& obj, void* context) { } auto outputType = deserializeRowType(obj["outputType"]); + auto probeGeometry = deserializeField(obj["probeGeometry"]); + auto buildGeometry = deserializeField(obj["buildGeometry"]); + std::optional radius; + if (obj.count("radius")) { + radius = deserializeField(obj["radius"]); + } return std::make_shared( deserializePlanNodeId(obj), JoinTypeName::toJoinType(obj["joinType"].asString()), joinCondition, + probeGeometry, + buildGeometry, + radius, sources[0], sources[1], outputType); @@ -3298,6 +3727,14 @@ void MarkDistinctNode::addDetails(std::stringstream& stream) const { addFields(stream, distinctKeys_); } +void EnforceDistinctNode::addDetails(std::stringstream& stream) const { + if (isPreGrouped()) { + stream << "STREAMING "; + } + addFields(stream, distinctKeys_); + stream << " " << errorMessage_; +} + void PlanNode::toString( std::stringstream& stream, bool detailed, @@ -3425,6 +3862,25 @@ void PlanNode::toSkeletonString( } } +// static +const PlanNode* PlanNode::findFirstNode( + const PlanNode* root, + const std::function& predicate) { + VELOX_CHECK_NOT_NULL(root); + if (predicate(root)) { + return root; + } + + // Recursively go further through the sources. + for (const auto& source : root->sources()) { + const auto* ret = PlanNode::findFirstNode(source.get(), predicate); + if (ret != nullptr) { + return ret; + } + } + return nullptr; +} + namespace { void collectLeafPlanNodeIds( const PlanNode& planNode, @@ -3455,6 +3911,7 @@ void PlanNode::registerSerDe() { registry.Register("AggregationNode", AggregationNode::create); registry.Register("AssignUniqueIdNode", AssignUniqueIdNode::create); registry.Register("EnforceSingleRowNode", EnforceSingleRowNode::create); + registry.Register("EnforceDistinctNode", EnforceDistinctNode::create); registry.Register("ExchangeNode", ExchangeNode::create); registry.Register("ExpandNode", ExpandNode::create); registry.Register("FilterNode", FilterNode::create); @@ -3467,6 +3924,8 @@ void PlanNode::registerSerDe() { registry.Register("LimitNode", LimitNode::create); registry.Register("LocalMergeNode", LocalMergeNode::create); registry.Register("LocalPartitionNode", LocalPartitionNode::create); + registry.Register("MarkDistinctNode", MarkDistinctNode::create); + registry.Register("MarkSortedNode", MarkSortedNode::create); registry.Register("OrderByNode", OrderByNode::create); registry.Register("PartitionedOutputNode", PartitionedOutputNode::create); registry.Register("ProjectNode", ProjectNode::create); @@ -3482,6 +3941,8 @@ void PlanNode::registerSerDe() { registry.Register("ValuesNode", ValuesNode::create); registry.Register("WindowNode", WindowNode::create); registry.Register("MarkDistinctNode", MarkDistinctNode::create); + registry.Register("MixedUnionNode", MixedUnionNode::create); + registry.Register("RPCNode", RPCNode::create); registry.Register( "GatherPartitionFunctionSpec", GatherPartitionFunctionSpec::deserialize); } @@ -3587,7 +4048,7 @@ folly::dynamic IndexLookupCondition::serialize() const { } bool InIndexLookupCondition::isFilter() const { - return std::dynamic_pointer_cast(list) != nullptr; + return list->isConstantKind(); } folly::dynamic InIndexLookupCondition::serialize() const { @@ -3605,16 +4066,13 @@ void InIndexLookupCondition::validate() const { VELOX_CHECK_NOT_NULL(key); VELOX_CHECK_NOT_NULL(list); VELOX_CHECK( - std::dynamic_pointer_cast(list) || - std::dynamic_pointer_cast(list), + list->isFieldAccessKind() || list->isConstantKind(), "Invalid condition list {}", list->toString()); - const auto listType = - std::dynamic_pointer_cast(list->type()); - VELOX_CHECK_NOT_NULL(listType); + const auto& listType = list->type()->asArray(); VELOX_CHECK_EQ( key->type()->kind(), - listType->elementType()->kind(), + listType.elementType()->kind(), "In condition key and list condition element must have the same type"); } @@ -3632,9 +4090,7 @@ IndexLookupConditionPtr InIndexLookupCondition::create( } bool BetweenIndexLookupCondition::isFilter() const { - return (std::dynamic_pointer_cast(lower) != - nullptr) && - (std::dynamic_pointer_cast(upper) != nullptr); + return lower->isConstantKind() && upper->isConstantKind(); } folly::dynamic BetweenIndexLookupCondition::serialize() const { @@ -3669,14 +4125,12 @@ void BetweenIndexLookupCondition::validate() const { VELOX_CHECK_NOT_NULL(lower); VELOX_CHECK_NOT_NULL(upper); VELOX_CHECK( - std::dynamic_pointer_cast(lower) || - std::dynamic_pointer_cast(lower), + lower->isFieldAccessKind() || lower->isConstantKind(), "Invalid lower between condition {}", lower->toString()); VELOX_CHECK( - std::dynamic_pointer_cast(upper) || - std::dynamic_pointer_cast(upper), + upper->isFieldAccessKind() || upper->isConstantKind(), "Invalid upper between condition {}", upper->toString()); @@ -3692,7 +4146,7 @@ void BetweenIndexLookupCondition::validate() const { } bool EqualIndexLookupCondition::isFilter() const { - return std::dynamic_pointer_cast(value) != nullptr; + return value->isConstantKind(); } folly::dynamic EqualIndexLookupCondition::serialize() const { @@ -3718,9 +4172,15 @@ IndexLookupConditionPtr EqualIndexLookupCondition::create( void EqualIndexLookupCondition::validate() const { VELOX_CHECK_NOT_NULL(key); VELOX_CHECK_NOT_NULL(value); - VELOX_CHECK_NOT_NULL( - std::dynamic_pointer_cast(value), - "Equal condition value must be a constant expression: {}", + // Value can be either a constant expression or a field access expression + // (probe side column). + const bool isConstant = + std::dynamic_pointer_cast(value) != nullptr; + const bool isFieldAccess = + std::dynamic_pointer_cast(value) != nullptr; + VELOX_CHECK( + isConstant || isFieldAccess, + "Equal condition value must be a constant or field access expression: {}", value->toString()); VELOX_CHECK_EQ( @@ -3730,4 +4190,184 @@ void EqualIndexLookupCondition::validate() const { key->type()->toString(), value->type()->toString()); } + +void MixedUnionNode::accept( + const PlanNodeVisitor& visitor, + PlanNodeVisitorContext& context) const { + visitor.visit(*this, context); +} + +folly::dynamic MixedUnionNode::serialize() const { + auto obj = PlanNode::serialize(); + return obj; +} + +// static +PlanNodePtr MixedUnionNode::create(const folly::dynamic& obj, void* context) { + auto sources = deserializeSources(obj, context); + + return std::make_shared( + deserializePlanNodeId(obj), std::move(sources)); +} + +RPCNode::RPCNode( + const PlanNodeId& id, + PlanNodePtr source, + std::string functionName, + TypePtr functionResultType, + std::string outputColumn, + RowTypePtr outputType, + std::vector argumentColumns, + std::vector argumentTypes, + std::vector constantInputs, + rpc::RPCStreamingMode streamingMode, + int32_t dispatchBatchSize) + : PlanNode(id), + sources_{std::move(source)}, + functionName_(std::move(functionName)), + resultType_(std::move(functionResultType)), + outputColumn_(std::move(outputColumn)), + outputType_(std::move(outputType)), + argumentColumns_(std::move(argumentColumns)), + argumentTypes_(std::move(argumentTypes)), + constantInputs_(std::move(constantInputs)), + streamingMode_(streamingMode), + dispatchBatchSize_(dispatchBatchSize) { + VELOX_CHECK_EQ( + argumentColumns_.size(), + argumentTypes_.size(), + "argumentColumns and argumentTypes must have the same size"); + VELOX_CHECK_EQ( + argumentColumns_.size(), + constantInputs_.size(), + "argumentColumns and constantInputs must have the same size"); + VELOX_CHECK( + outputType_->containsChild(outputColumn_), + "RPCNode outputType must contain the RPC result column: {}", + outputColumn_); +} + +void RPCNode::addDetails(std::stringstream& stream) const { + stream << "function: " << functionName_ << ", outputColumn: " << outputColumn_ + << ", streamingMode: " + << (streamingMode_ == rpc::RPCStreamingMode::kBatch ? "BATCH" + : "PER_ROW"); + if (dispatchBatchSize_ > 0) { + stream << ", dispatchBatchSize: " << dispatchBatchSize_; + } +} + +folly::dynamic RPCNode::serialize() const { + auto obj = PlanNode::serialize(); + obj["functionName"] = functionName_; + obj["resultType"] = resultType_->serialize(); + + // Serialize argument columns (string names). + auto colsArray = folly::dynamic::array(); + for (const auto& col : argumentColumns_) { + colsArray.push_back(col); + } + obj["argumentColumns"] = std::move(colsArray); + + // Serialize argument types. + obj["argumentTypes"] = ISerializable::serialize(argumentTypes_); + + // Serialize constant inputs as ConstantTypedExpr for round-trip fidelity. + auto constArray = folly::dynamic::array(); + for (size_t i = 0; i < constantInputs_.size(); ++i) { + if (constantInputs_[i]) { + auto constExpr = + std::make_shared(constantInputs_[i]); + constArray.push_back(constExpr->serialize()); + } else { + constArray.push_back(nullptr); + } + } + obj["constantInputs"] = std::move(constArray); + + obj["outputColumn"] = outputColumn_; + obj["outputType"] = outputType_->serialize(); + obj["streamingMode"] = + streamingMode_ == rpc::RPCStreamingMode::kBatch ? "BATCH" : "PER_ROW"; + obj["dispatchBatchSize"] = dispatchBatchSize_; + return obj; +} + +// static +PlanNodePtr RPCNode::create(const folly::dynamic& obj, void* context) { + auto source = deserializeSingleSource(obj, context); + auto functionName = obj["functionName"].asString(); + auto resultType = ISerializable::deserialize(obj["resultType"]); + + // Deserialize argument columns. + std::vector argumentColumns; + if (obj.count("argumentColumns")) { + for (const auto& col : obj["argumentColumns"]) { + argumentColumns.push_back(col.asString()); + } + } + + // Deserialize argument types. + auto argumentTypes = + ISerializable::deserialize>(obj["argumentTypes"]); + + // Deserialize constant inputs from ConstantTypedExpr. + std::vector constantInputs; + if (obj.count("constantInputs")) { + for (const auto& item : obj["constantInputs"]) { + if (item.isNull()) { + constantInputs.push_back(nullptr); + } else { + auto constExpr = std::dynamic_pointer_cast( + ISerializable::deserialize(item, context)); + VELOX_CHECK_NOT_NULL( + constExpr, "Expected ConstantTypedExpr for constant input"); + auto* pool = static_cast(context); + constantInputs.push_back(constExpr->toConstantVector(pool)); + } + } + } + + auto outputColumn = obj["outputColumn"].asString(); + + // Deserialize explicit output type. + RowTypePtr outputType; + if (obj.count("outputType")) { + outputType = std::dynamic_pointer_cast( + ISerializable::deserialize(obj["outputType"])); + } else { + // Backward compat: derive from source + result column. + std::vector names; + std::vector types; + if (source) { + auto sourceType = source->outputType(); + for (int32_t i = 0; i < sourceType->size(); ++i) { + names.emplace_back(sourceType->nameOf(i)); + types.push_back(sourceType->childAt(i)); + } + } + names.push_back(outputColumn); + types.push_back(resultType); + outputType = ROW(std::move(names), std::move(types)); + } + + auto streamingMode = obj["streamingMode"].asString() == "BATCH" + ? rpc::RPCStreamingMode::kBatch + : rpc::RPCStreamingMode::kPerRow; + auto dispatchBatchSize = + static_cast(obj["dispatchBatchSize"].asInt()); + return std::make_shared( + deserializePlanNodeId(obj), + std::move(source), + std::move(functionName), + std::move(resultType), + std::move(outputColumn), + std::move(outputType), + std::move(argumentColumns), + std::move(argumentTypes), + std::move(constantInputs), + streamingMode, + dispatchBatchSize); +} + } // namespace facebook::velox::core diff --git a/velox/core/PlanNode.h b/velox/core/PlanNode.h index b5c82d29a2c..4d7d1017f4b 100644 --- a/velox/core/PlanNode.h +++ b/velox/core/PlanNode.h @@ -19,7 +19,8 @@ #include -#include "velox/common/Enums.h" +#include "velox/common/EnumDeclare.h" +#include "velox/common/rpc/RPCTypes.h" #include "velox/connectors/Connector.h" #include "velox/core/Expressions.h" #include "velox/core/QueryConfig.h" @@ -206,6 +207,14 @@ class PlanNode : public ISerializable { return false; } + /// Returns true if this plan node requires single-threaded execution + /// (maxDrivers = 1). For example, ValuesNode, final OrderByNode, final + /// LimitNode, MergeExchangeNode, LocalMergeNode, and + /// LocalPartitionNode(Gather) all require single-threaded execution. + virtual bool requiresSingleThread() const { + return false; + } + /// Returns true if this plan node operator supports task barrier processing. /// To support barrier processing, the operator must be able to drain its /// buffered output when it receives the drain signal at split boundary. Not @@ -259,23 +268,28 @@ class PlanNode : public ISerializable { /// The name of the plan node, used in toString. virtual std::string_view name() const = 0; + template + bool is() const { + return dynamic_cast(this) != nullptr; + } + + template + const T* as() const { + return dynamic_cast(this); + } + /// Recursively checks the node tree for a first node that satisfy a given /// condition. Returns pointer to the node if found, nullptr if not. static const PlanNode* findFirstNode( - const PlanNode* node, - const std::function& predicate) { - if (predicate(node)) { - return node; - } + const PlanNode* root, + const std::function& predicate); - // Recursively go further through the sources. - for (const auto& source : node->sources()) { - const auto* ret = PlanNode::findFirstNode(source.get(), predicate); - if (ret != nullptr) { - return ret; - } - } - return nullptr; + /// @return PlanNode with matching ID or nullptr if not found. + static const PlanNode* findNodeById( + const PlanNode* root, + const PlanNodeId& id) { + return findFirstNode( + root, [&](const auto* node) { return node->id() == id; }); } private: @@ -386,6 +400,10 @@ class ValuesNode : public PlanNode { void accept(const PlanNodeVisitor& visitor, PlanNodeVisitorContext& context) const override; + bool requiresSingleThread() const override { + return !parallelizable_; + } + const std::vector& values() const { return values_; } @@ -488,6 +506,10 @@ class ArrowStreamNode : public PlanNode { void accept(const PlanNodeVisitor& visitor, PlanNodeVisitorContext& context) const override; + bool requiresSingleThread() const override { + return true; + } + const std::shared_ptr& arrowStream() const { return arrowStream_; } @@ -1103,14 +1125,14 @@ class AggregationNode : public PlanNode { /// Optional name of input column to use as a mask. Column type must be /// BOOLEAN. - FieldAccessTypedExprPtr mask; + FieldAccessTypedExprPtr mask{}; /// Optional list of input columns to sort by before applying aggregate /// function. - std::vector sortingKeys; + std::vector sortingKeys{}; /// A list of sorting orders that goes together with 'sortingKeys'. - std::vector sortingOrders; + std::vector sortingOrders{}; /// Boolean indicating whether inputs must be de-duplicated before /// aggregating. @@ -1129,6 +1151,7 @@ class AggregationNode : public PlanNode { const std::vector& aggregateNames, const std::vector& aggregates, bool ignoreNullKeys, + bool noGroupsSpanBatches, PlanNodePtr source); /// @param globalGroupingSets Group IDs of the global grouping sets produced @@ -1150,6 +1173,7 @@ class AggregationNode : public PlanNode { const std::vector& globalGroupingSets, const std::optional& groupId, bool ignoreNullKeys, + bool noGroupsSpanBatches, PlanNodePtr source); class Builder { @@ -1166,6 +1190,7 @@ class AggregationNode : public PlanNode { globalGroupingSets_ = other.globalGroupingSets(); groupId_ = other.groupId(); ignoreNullKeys_ = other.ignoreNullKeys(); + noGroupsSpanBatches_ = other.noGroupsSpanBatches(); VELOX_CHECK_EQ(other.sources().size(), 1); source_ = other.sources()[0]; } @@ -1216,6 +1241,11 @@ class AggregationNode : public PlanNode { return *this; } + Builder& noGroupsSpanBatches(bool noGroupsSpanBatches) { + noGroupsSpanBatches_ = noGroupsSpanBatches; + return *this; + } + Builder& source(PlanNodePtr source) { source_ = std::move(source); return *this; @@ -1250,6 +1280,7 @@ class AggregationNode : public PlanNode { globalGroupingSets_, groupId_, ignoreNullKeys_.value(), + noGroupsSpanBatches_, source_.value()); } @@ -1263,6 +1294,7 @@ class AggregationNode : public PlanNode { std::vector globalGroupingSets_ = kDefaultGlobalGroupingSets; std::optional groupId_ = kDefaultGroupId; std::optional ignoreNullKeys_; + bool noGroupsSpanBatches_{false}; std::optional source_; }; @@ -1296,10 +1328,10 @@ class AggregationNode : public PlanNode { bool isPreGrouped() const { return !preGroupedKeys_.empty() && std::equal( - preGroupedKeys_.begin(), - preGroupedKeys_.end(), - groupingKeys_.begin(), - groupingKeys_.end(), + preGroupedKeys_.cbegin(), + preGroupedKeys_.cend(), + groupingKeys_.cbegin(), + groupingKeys_.cend(), [](const FieldAccessTypedExprPtr& x, const FieldAccessTypedExprPtr& y) -> bool { return (*x == *y); @@ -1322,10 +1354,19 @@ class AggregationNode : public PlanNode { return globalGroupingSets_; } - std::optional groupId() const { + const std::optional& groupId() const { return groupId_; } + /// When true, indicates that for streaming aggregation, no sort group spans + /// across input batches. Each input batch contains complete data for its + /// groups - no group will appear in any subsequent input batch. This allows + /// the streaming aggregation operator to immediately produce the aggregation + /// result for all the groups in each input batch. + bool noGroupsSpanBatches() const { + return noGroupsSpanBatches_; + } + std::string_view name() const override { return "Aggregation"; } @@ -1362,8 +1403,15 @@ class AggregationNode : public PlanNode { const std::vector aggregates_; const bool ignoreNullKeys_; - std::optional groupId_; - std::vector globalGroupingSets_; + const std::optional groupId_; + const std::vector globalGroupingSets_; + + // When true, indicates that for streaming aggregation, no sort group spans + // across input batches. Each input batch contains complete data for its + // groups - no group will appear in any subsequent input batch. This allows + // the streaming aggregation operator to immediately produce the aggregation + // result for all the groups in each input batch. + const bool noGroupsSpanBatches_; const std::vector sources_; const RowTypePtr outputType_; @@ -1433,13 +1481,91 @@ struct ColumnStatsSpec : public ISerializable { VELOX_CHECK_EQ(aggregates.size(), aggregateNames.size()); } + /// Returns the output row type that will be produced by this column stats + /// spec. The output type is determined by the grouping keys and aggregate + /// functions specified in the object. + RowTypePtr outputType() const; + folly::dynamic serialize() const override; static ColumnStatsSpec create(const folly::dynamic& obj, void* context); }; +/// Writes input rows to a table via a connector-specific DataSink and +/// optionally collects per-column statistics (count, min, max, +/// approx_distinct) using an embedded ColumnStatsCollector. +/// +/// Two output modes depending on outputType: +/// +/// 1. Single-column BIGINT (Spark/Gluten): output is a single row count. +/// No stats, no fragments, no commit context. Used when columnStatsSpec +/// is not set. +/// +/// 2. Multiplexed format (Prestissimo, see TableWriteTraits): +/// Channel 0 (rows): row count or NULL +/// Channel 1 (fragments): file fragment data or NULL +/// Channel 2 (context): commit context JSON (always present) +/// Channel 3+ (stats): aggregated statistics columns (if configured) +/// +/// Each operator instance (one per driver) produces three kinds of rows: +/// - Statistics rows: rows=NULL, fragments=NULL, stats populated. +/// One row per partition (or one row for unpartitioned tables). +/// - Fragment rows: rows=NULL, fragments=non-NULL, stats=NULL. +/// One row per output file. +/// - Summary row: rows=totalCount, fragments=NULL, stats=NULL. +/// One per driver, emitted last. +/// +/// The context column (channel 2) is a JSON object. See +/// TableWriteTraits for field names (taskId, lifespan, +/// pageSinkCommitStrategy, lastPage). All rows carry context; the +/// summary row has lastPage=true. +/// +/// When columnStatsSpec is set, the aggregation step controls output types: +/// - kSingle: produces final statistics values (single-driver, no merge). +/// - kPartial: produces intermediate aggregation state (requires a +/// downstream TableWriteMergeNode to finalize). +/// +/// Typical plan topologies (data flows left to right): +/// +/// Single-node, single-driver: +/// Input → TableWrite(kSingle) +/// +/// Single-node, multi-driver: +/// Input → TableWrite(kPartial) → LocalGather → TableWriteMerge(kFinal) +/// +/// Multi-node, multi-driver: +/// Worker: Input → TableWrite(kPartial) → LocalGather +/// → TableWriteMerge(kIntermediate) → PartitionedOutput +/// Coordinator: Exchange → TableWriteMerge(kFinal) +/// In Prestissimo, the coordinator uses Presto's Java +/// TableFinishOperator instead of TableWriteMerge. class TableWriteNode : public PlanNode { public: + /// @param id Plan node ID. + /// @param columns Subset of source output columns to write, potentially + /// reordered. The names in this type must match columns in the source + /// output (used to build the input-to-output column mapping). + /// @param columnNames Target table column names for the written data. + /// Aligned 1:1 with 'columns'. May differ from 'columns' names when the + /// query renames columns (e.g. source has "expr_0" but table column is + /// "key"). The DataSink receives data using these names. + /// @param columnStatsSpec Optional specification for column statistics + /// collection. When set, the operator collects per-column aggregates + /// (count, min, max, approx_distinct) alongside the write. Restrictions: + /// - aggregation step must be kSingle or kPartial. + /// - grouping keys must be a subset of 'columns' (partition columns). + /// - grouping keys must not contain duplicates. + /// @param insertTableHandle Connector-specific handle identifying the + /// target table and write operation. + /// @param hasPartitioningScheme Whether a partitioning scheme is configured + /// for shuffles. Controls which query config determines the number of + /// writer operator instances: 'task_partitioned_writer_count' if true, + /// 'task_writer_count' if false. + /// @param outputType Output row type. For Prestissimo, must match + /// TableWriteTraits::outputType(columnStatsSpec). For Spark/Gluten, a + /// single-column BIGINT type with no columnStatsSpec. + /// @param commitStrategy Commit strategy for the write operation. + /// @param source Input plan node providing rows to write. TableWriteNode( const PlanNodeId& id, const RowTypePtr& columns, @@ -1449,25 +1575,7 @@ class TableWriteNode : public PlanNode { bool hasPartitioningScheme, RowTypePtr outputType, connector::CommitStrategy commitStrategy, - const PlanNodePtr& source) - : PlanNode(id), - sources_{source}, - columns_{columns}, - columnNames_{columnNames}, - columnStatsSpec_(std::move(columnStatsSpec)), - insertTableHandle_(std::move(insertTableHandle)), - hasPartitioningScheme_(hasPartitioningScheme), - outputType_(std::move(outputType)), - commitStrategy_(commitStrategy) { - VELOX_USER_CHECK_EQ(columns_->size(), columnNames_.size()); - for (const auto& column : columns_->names()) { - VELOX_USER_CHECK( - source->outputType()->containsChild(column), - "Column {} not found in TableWriter input: {}", - column, - source->outputType()->toString()); - } - } + const PlanNodePtr& source); class Builder { public: @@ -1605,8 +1713,7 @@ class TableWriteNode : public PlanNode { /// Indicates if this table write plan node has specified partitioning /// scheme for remote and local shuffles. If true, the task creates a /// number of table write operators based on the query config - /// 'task_partitioned_writer_count', otherwise based on - /// x'task_writer_count'. + /// 'task_partitioned_writer_count', otherwise based on 'task_writer_count'. bool hasPartitioningScheme() const { return hasPartitioningScheme_; } @@ -1615,17 +1722,19 @@ class TableWriteNode : public PlanNode { return commitStrategy_; } - /// Returns true of this table write plan node has configured column - /// statistics collection. bool hasColumnStatsSpec() const { return columnStatsSpec_.has_value(); } - /// Optional spec for column statistics collection. const std::optional& columnStatsSpec() const { return columnStatsSpec_; } + bool requiresSingleThread() const override { + return !insertTableHandle_->connectorInsertTableHandle() + ->supportsMultiThreading(); + } + bool canSpill(const QueryConfig& queryConfig) const override { return queryConfig.writerSpillEnabled(); } @@ -1653,29 +1762,50 @@ class TableWriteNode : public PlanNode { using TableWriteNodePtr = std::shared_ptr; +/// Merges output from multiple TableWrite operators. Collects fragments, +/// accumulates row counts, and aggregates column statistics using an +/// embedded ColumnStatsCollector. +/// +/// Input rows are classified per-row using TableWriteTraits::isStatisticsRow +/// (see TableWriteTraits). Input batches may contain a mix of statistics and +/// data rows (e.g. when receiving batched output from an exchange): +/// - Statistics rows (rows=NULL, fragments=NULL): routed to the stats +/// collector for aggregation. +/// - Data rows (rows or fragments non-NULL): row counts are accumulated, +/// fragments are buffered, and commit context is validated for +/// consistency (all inputs must share the same commit strategy; +/// taskId may differ in cross-worker merge). +/// +/// Output follows the same three-phase protocol as TableWriteNode: +/// 1. Fragment rows (emitted first, to free memory). +/// 2. Aggregated statistics rows from the stats collector. +/// 3. Summary row with total row count and lastPage=true. +/// +/// The aggregation step in ColumnStatsSpec must be kIntermediate or kFinal: +/// - kIntermediate: reads partial state, produces partial state (for +/// further merging downstream). +/// - kFinal: reads partial state, produces final scalar values. +/// +/// Supports both single-task multi-driver merge (via LocalGather) and +/// cross-task merge (via Exchange from multiple workers). class TableWriteMergeNode : public PlanNode { public: - /// 'outputType' specifies the type to store the metadata of table write - /// output which contains the following columns: 'numWrittenRows', - /// 'fragment' and 'tableCommitContext'. + /// @param id Plan node ID. + /// @param columnStatsSpec Optional specification for column statistics + /// aggregation. Restrictions: + /// - aggregation step must be kIntermediate or kFinal. + /// - grouping keys must be present in source output type. + /// - grouping keys must not contain duplicates. + /// @param outputType Output row type. Column names may differ from + /// TableWriteTraits defaults (e.g. Prestissimo appends node ID suffixes). + /// Types must match TableWriteTraits::outputType(columnStatsSpec). + /// @param source Input plan node, typically a LocalGather over + /// TableWriteNode(s). TableWriteMergeNode( const PlanNodeId& id, RowTypePtr outputType, std::optional columnStatsSpec, - PlanNodePtr source) - : PlanNode(id), - columnStatsSpec_(std::move(columnStatsSpec)), - sources_{std::move(source)}, - outputType_(std::move(outputType)) { - if (hasColumnStatsSpec()) { - VELOX_USER_CHECK( - columnStatsSpec_->aggregationStep == - core::AggregationNode::Step::kFinal || - columnStatsSpec_->aggregationStep == - core::AggregationNode::Step::kIntermediate, - "TableWriteMergeNode requires aggregation step to be intermediate or final"); - } - } + PlanNodePtr source); class Builder { public: @@ -1738,6 +1868,10 @@ class TableWriteMergeNode : public PlanNode { return columnStatsSpec_; } + bool requiresSingleThread() const override { + return true; + } + const std::vector& sources() const override { return sources_; } @@ -2044,11 +2178,8 @@ using GroupIdNodePtr = std::shared_ptr; class ExchangeNode : public PlanNode { public: - ExchangeNode( - const PlanNodeId& id, - RowTypePtr type, - VectorSerde::Kind serdeKind) - : PlanNode(id), outputType_(type), serdeKind_(serdeKind) {} + ExchangeNode(const PlanNodeId& id, RowTypePtr type, std::string serdeKind) + : PlanNode(id), outputType_(type), serdeKind_(std::move(serdeKind)) {} class Builder { public: @@ -2070,8 +2201,8 @@ class ExchangeNode : public PlanNode { return *this; } - Builder& serdeKind(VectorSerde::Kind serdeKind) { - serdeKind_ = serdeKind; + Builder& serdeKind(std::string serdeKind) { + serdeKind_ = std::move(serdeKind); return *this; } @@ -2089,7 +2220,7 @@ class ExchangeNode : public PlanNode { private: std::optional id_; std::optional outputType_; - std::optional serdeKind_; + std::optional serdeKind_; }; const RowTypePtr& outputType() const override { @@ -2113,7 +2244,7 @@ class ExchangeNode : public PlanNode { return "Exchange"; } - VectorSerde::Kind serdeKind() const { + const std::string& serdeKind() const { return serdeKind_; } @@ -2125,7 +2256,7 @@ class ExchangeNode : public PlanNode { void addDetails(std::stringstream& stream) const override; const RowTypePtr outputType_; - const VectorSerde::Kind serdeKind_; + const std::string serdeKind_; }; using ExchangeNodePtr = std::shared_ptr; @@ -2137,7 +2268,7 @@ class MergeExchangeNode : public ExchangeNode { const RowTypePtr& type, const std::vector& sortingKeys, const std::vector& sortingOrders, - VectorSerde::Kind serdeKind); + std::string serdeKind); class Builder { public: @@ -2171,8 +2302,8 @@ class MergeExchangeNode : public ExchangeNode { return *this; } - Builder& serdeKind(VectorSerde::Kind serdeKind) { - serdeKind_ = serdeKind; + Builder& serdeKind(std::string serdeKind) { + serdeKind_ = std::move(serdeKind); return *this; } @@ -2201,7 +2332,7 @@ class MergeExchangeNode : public ExchangeNode { std::optional outputType_; std::optional> sortingKeys_; std::optional> sortingOrders_; - std::optional serdeKind_; + std::optional serdeKind_; }; const std::vector& sortingKeys() const { @@ -2212,6 +2343,10 @@ class MergeExchangeNode : public ExchangeNode { return sortingOrders_; } + bool requiresSingleThread() const override { + return true; + } + void accept(const PlanNodeVisitor& visitor, PlanNodeVisitorContext& context) const override; @@ -2310,6 +2445,10 @@ class LocalMergeNode : public PlanNode { void accept(const PlanNodeVisitor& visitor, PlanNodeVisitorContext& context) const override; + bool requiresSingleThread() const override { + return true; + } + const std::vector& sortingKeys() const { return sortingKeys_; } @@ -2521,6 +2660,10 @@ class LocalPartitionNode : public PlanNode { return scaleWriter_; } + bool requiresSingleThread() const override { + return type_ == Type::kGather; + } + bool supportsBarrier() const override { return !scaleWriter_; } @@ -2581,26 +2724,26 @@ class PartitionedOutputNode : public PlanNode { bool replicateNullsAndAny, PartitionFunctionSpecPtr partitionFunctionSpec, RowTypePtr outputType, - VectorSerde::Kind serdeKind, + std::string serdeKind, PlanNodePtr source); static std::shared_ptr broadcast( const PlanNodeId& id, int numPartitions, RowTypePtr outputType, - VectorSerde::Kind serdeKind, + std::string serdeKind, PlanNodePtr source); static std::shared_ptr arbitrary( const PlanNodeId& id, RowTypePtr outputType, - VectorSerde::Kind serdeKind, + std::string serdeKind, PlanNodePtr source); static std::shared_ptr single( const PlanNodeId& id, RowTypePtr outputType, - VectorSerde::Kind VectorSerde, + std::string serdeKind, PlanNodePtr source); class Builder { @@ -2655,8 +2798,8 @@ class PartitionedOutputNode : public PlanNode { return *this; } - Builder& serdeKind(VectorSerde::Kind serdeKind) { - serdeKind_ = serdeKind; + Builder& serdeKind(std::string serdeKind) { + serdeKind_ = std::move(serdeKind); return *this; } @@ -2708,7 +2851,7 @@ class PartitionedOutputNode : public PlanNode { std::optional replicateNullsAndAny_; std::optional partitionFunctionSpec_; std::optional outputType_; - std::optional serdeKind_; + std::optional serdeKind_; std::optional source_; }; @@ -2751,7 +2894,7 @@ class PartitionedOutputNode : public PlanNode { return kind_; } - VectorSerde::Kind serdeKind() const { + const std::string& serdeKind() const { return serdeKind_; } @@ -2789,7 +2932,7 @@ class PartitionedOutputNode : public PlanNode { const int numPartitions_; const bool replicateNullsAndAny_; const PartitionFunctionSpecPtr partitionFunctionSpec_; - const VectorSerde::Kind serdeKind_; + const std::string serdeKind_; const RowTypePtr outputType_; }; @@ -2830,6 +2973,12 @@ enum class JoinType { // equal to the cardinality of the left side. kLeftSemiFilter = 4, + // Multiset version of kLeftSemiFilter. The build side deduplicates keys and + // stores a per-key count. On probe, each match decrements the count; the + // probe row is emitted only while the count is greater than zero. Implements + // INTERSECT ALL semantics. + kCountingLeftSemiFilter = 5, + // Return each row from the left side with a boolean flag indicating whether // there exists a match on the right side. For this join type, cardinality // of the output equals the cardinality of the left side. @@ -2838,12 +2987,12 @@ enum class JoinType { // 'nullAware' boolean specified separately. // // Null-aware join follows IN semantic. Regular join follows EXISTS semantic. - kLeftSemiProject = 5, + kLeftSemiProject = 6, // Opposite of kLeftSemiFilter. Return a subset of rows from the right side // which have a match on the left side. For this join type, cardinality of // the output is less than or equal to the cardinality of the right side. - kRightSemiFilter = 6, + kRightSemiFilter = 7, // Opposite of kLeftSemiProject. Return each row from the right side with a // boolean flag indicating whether there exists a match on the left side. @@ -2854,7 +3003,7 @@ enum class JoinType { // 'nullAware' boolean specified separately. // // Null-aware join follows IN semantic. Regular join follows EXISTS semantic. - kRightSemiProject = 7, + kRightSemiProject = 8, // Return each row from the left side which has no match on the right side. // The handling of the rows with nulls in the join key depends on the @@ -2869,9 +3018,15 @@ enum class JoinType { // Regular anti join follows NOT EXISTS semantic: // (1) ignore right-side rows with nulls in the join keys; // (2) unconditionally return left side rows with nulls in the join keys. - kAnti = 8, + kAnti = 9, - kNumJoinTypes = 9, + // Multiset version of kAnti. The build side deduplicates keys and stores a + // per-key count. On probe, each match decrements the count; the probe row is + // emitted only when the count reaches zero or no match is found. Implements + // EXCEPT ALL semantics. + kCountingAnti = 10, + + kNumJoinTypes = 11, }; VELOX_DECLARE_ENUM_NAME(JoinType); @@ -2912,6 +3067,26 @@ inline bool isAntiJoin(JoinType joinType) { return joinType == JoinType::kAnti; } +inline bool isCountingAntiJoin(JoinType joinType) { + return joinType == JoinType::kCountingAnti; +} + +inline bool isCountingLeftSemiFilterJoin(JoinType joinType) { + return joinType == JoinType::kCountingLeftSemiFilter; +} + +inline bool isCountingJoin(JoinType joinType) { + return isCountingAntiJoin(joinType) || isCountingLeftSemiFilterJoin(joinType); +} + +/// Returns true if the join type is "probe-only", meaning the output includes +/// only columns from the probe side (plus possibly a mark column). +inline bool isProbeOnlyJoin(JoinType joinType) { + return joinType == JoinType::kLeftSemiFilter || + joinType == JoinType::kLeftSemiProject || joinType == JoinType::kAnti || + isCountingJoin(joinType); +} + inline bool isNullAwareSupported(JoinType joinType) { return joinType == JoinType::kAnti || joinType == JoinType::kLeftSemiProject || @@ -3050,10 +3225,34 @@ class AbstractJoinNode : public PlanNode { return joinType_ == JoinType::kAnti; } + bool isCountingAntiJoin() const { + return joinType_ == JoinType::kCountingAnti; + } + + bool isCountingLeftSemiFilterJoin() const { + return joinType_ == JoinType::kCountingLeftSemiFilter; + } + + bool isCountingJoin() const { + return core::isCountingJoin(joinType_); + } + bool isPreservingProbeOrder() const { return isInnerJoin() || isLeftJoin() || isAntiJoin(); } + /// Indicates if this joinNode can drop duplicate rows with same join key. + /// For left semi and anti join, it is not necessary to store duplicate rows. + /// For counting joins, duplicates are folded into a per-key count. + bool canDropDuplicates() const { + // Left semi and anti join with no extra filter only needs to know whether + // there is a match. Hence, no need to store entries with duplicate keys. + // Counting joins always deduplicate and track counts. + return isCountingJoin() || + (!filter() && + (isLeftSemiFilterJoin() || isLeftSemiProjectJoin() || isAntiJoin())); + } + const std::vector& leftKeys() const { return leftKeys_; } @@ -3093,6 +3292,12 @@ class AbstractJoinNode : public PlanNode { /// EXISTS. class HashJoinNode : public AbstractJoinNode { public: + /// @param nullAware Applies to semi and anti joins only. When true, the + /// join semantic is IN / NOT IN (three-valued NULL logic). When false, the + /// join semantic is EXISTS / NOT EXISTS. + /// @param nullAsValue When true, join keys use IS NOT DISTINCT FROM + /// semantics where NULL equals NULL. Used to implement SQL set operations + /// (EXCEPT, INTERSECT). Mutually exclusive with nullAware. HashJoinNode( const PlanNodeId& id, JoinType joinType, @@ -3102,7 +3307,9 @@ class HashJoinNode : public AbstractJoinNode { TypedExprPtr filter, PlanNodePtr left, PlanNodePtr right, - RowTypePtr outputType) + RowTypePtr outputType, + bool useHashTableCache = false, + bool nullAsValue = false) : AbstractJoinNode( id, joinType, @@ -3112,9 +3319,21 @@ class HashJoinNode : public AbstractJoinNode { std::move(left), std::move(right), std::move(outputType)), - nullAware_{nullAware} { + nullAware_{nullAware}, + nullAsValue_{nullAsValue}, + useHashTableCache_{useHashTableCache} { validate(); + VELOX_USER_CHECK( + !nullAware || !nullAsValue, + "nullAware and nullAsValue are mutually exclusive"); + + if (isCountingJoin()) { + VELOX_USER_CHECK( + !nullAware, "Counting joins do not support null-aware flag"); + VELOX_USER_CHECK(!filter_, "Counting joins do not support extra filter"); + } + if (nullAware) { VELOX_USER_CHECK( isNullAwareSupported(joinType), @@ -3137,6 +3356,8 @@ class HashJoinNode : public AbstractJoinNode { explicit Builder(const HashJoinNode& other) : AbstractJoinNode::Builder(other) { nullAware_ = other.isNullAware(); + nullAsValue_ = other.isNullAsValue(); + useHashTableCache_ = other.useHashTableCache(); } Builder& nullAware(bool value) { @@ -3144,6 +3365,16 @@ class HashJoinNode : public AbstractJoinNode { return *this; } + Builder& nullAsValue(bool value) { + nullAsValue_ = value; + return *this; + } + + Builder& useHashTableCache(bool value) { + useHashTableCache_ = value; + return *this; + } + std::shared_ptr build() const { VELOX_USER_CHECK(id_.has_value(), "HashJoinNode id is not set"); VELOX_USER_CHECK( @@ -3170,11 +3401,15 @@ class HashJoinNode : public AbstractJoinNode { filter_.value_or(nullptr), left_.value(), right_.value(), - outputType_.value()); + outputType_.value(), + useHashTableCache_.value_or(false), + nullAsValue_.value_or(false)); } private: std::optional nullAware_; + std::optional nullAsValue_; + std::optional useHashTableCache_; }; std::string_view name() const override { @@ -3193,10 +3428,27 @@ class HashJoinNode : public AbstractJoinNode { queryConfig.joinSpillEnabled(); } + bool requiresSingleThread() const override { + return isRightSemiProjectJoin() && nullAware_; + } + bool isNullAware() const { return nullAware_; } + /// Returns true when join keys use IS NOT DISTINCT FROM semantics where + /// NULL equals NULL. Used to implement SQL set operations (EXCEPT, + /// INTERSECT). + bool isNullAsValue() const { + return nullAsValue_; + } + + /// Returns whether hash table caching is enabled for broadcast joins. + /// Only used by Presto-on-Spark. + bool useHashTableCache() const { + return useHashTableCache_; + } + folly::dynamic serialize() const override; static PlanNodePtr create(const folly::dynamic& obj, void* context); @@ -3205,6 +3457,8 @@ class HashJoinNode : public AbstractJoinNode { void addDetails(std::stringstream& stream) const override; const bool nullAware_; + const bool nullAsValue_; + const bool useHashTableCache_; }; using HashJoinNodePtr = std::shared_ptr; @@ -3260,6 +3514,10 @@ class MergeJoinNode : public AbstractJoinNode { } }; + bool requiresSingleThread() const override { + return true; + } + std::string_view name() const override { return "MergeJoin"; } @@ -3368,8 +3626,9 @@ struct BetweenIndexLookupCondition : public IndexLookupCondition { using BetweenIndexLookupConditionPtr = std::shared_ptr; -/// Represents EQUAL index lookup condition: 'key' = 'value'. 'value' must be a -/// constant value with the same type as 'key'. +/// Represents EQUAL index lookup condition: 'key' = 'value'. 'value' can be +/// either a constant value or a field access expression (probe side column) +/// with the same type as 'key'. struct EqualIndexLookupCondition : public IndexLookupCondition { /// The value to compare against. TypedExprPtr value; @@ -3431,7 +3690,15 @@ class IndexLookupJoinNode : public AbstractJoinNode { public: /// @param joinType Specifies the lookup join type. Only INNER and LEFT joins /// are supported. - /// @param includeMatchColumn if true, the output type includes a boolean + /// @param leftKeys Left side join keys used for index lookup. + /// @param rightKeys Right side join keys that form the index prefix. + /// @param joinConditions Additional conditions for index lookup that can't + /// be converted into simple equality join conditions. These conditions use + /// columns from both left and right and exactly one index column from + /// the right side.sides + /// @param filter Additional filter to apply on join results. This supports + /// filters that can't be converted into join conditions. + /// @param hasMarker if true, the output type includes a boolean /// column at the end to indicate if a join output row has a match or not. /// This only applies for left join. IndexLookupJoinNode( @@ -3440,7 +3707,8 @@ class IndexLookupJoinNode : public AbstractJoinNode { const std::vector& leftKeys, const std::vector& rightKeys, const std::vector& joinConditions, - bool includeMatchColumn, + TypedExprPtr filter, + bool hasMarker, PlanNodePtr left, TableScanNodePtr right, RowTypePtr outputType); @@ -3453,16 +3721,27 @@ class IndexLookupJoinNode : public AbstractJoinNode { explicit Builder(const IndexLookupJoinNode& other) : AbstractJoinNode::Builder(other) { joinConditions_ = other.joinConditions(); + filter_ = other.filter(); + hasMarker_ = other.hasMarker(); } + /// Set lookup conditions for index lookup that can't be converted into + /// simple equality join conditions. Builder& joinConditions( std::vector joinConditions) { joinConditions_ = std::move(joinConditions); return *this; } - Builder& includeMatchColumn(bool includeMatchColumn) { - includeMatchColumn_ = includeMatchColumn; + /// Set additional filter to apply on join results. + Builder& filter(TypedExprPtr filter) { + filter_ = std::move(filter); + return *this; + } + + /// Set whether to include a marker column for left joins. + Builder& hasMarker(bool hasMarker) { + hasMarker_ = hasMarker; return *this; } @@ -3480,25 +3759,23 @@ class IndexLookupJoinNode : public AbstractJoinNode { right_.has_value(), "IndexLookupJoinNode right source is not set"); VELOX_USER_CHECK( outputType_.has_value(), "IndexLookupJoinNode outputType is not set"); - VELOX_USER_CHECK( - joinConditions_.has_value(), - "IndexLookupJoinNode join conditions are not set"); return std::make_shared( id_.value(), joinType_.value(), leftKeys_.value(), rightKeys_.value(), - joinConditions_.value(), - includeMatchColumn_, + joinConditions_, + filter_.value_or(nullptr), + hasMarker_, left_.value(), std::dynamic_pointer_cast(right_.value()), outputType_.value()); } private: - std::optional> joinConditions_; - bool includeMatchColumn_; + std::vector joinConditions_; + bool hasMarker_{false}; }; bool supportsBarrier() const override { @@ -3509,6 +3786,12 @@ class IndexLookupJoinNode : public AbstractJoinNode { return lookupSourceNode_; } + /// Returns true if the lookup source requires splits for index lookup. + /// This delegates to the table handle's needsIndexSplit() method. + bool needsIndexSplit() const; + + /// Returns the join conditions for index lookup that can't be converted into + /// simple equality join conditions. const std::vector& joinConditions() const { return joinConditions_; } @@ -3517,8 +3800,9 @@ class IndexLookupJoinNode : public AbstractJoinNode { return "IndexLookupJoin"; } - bool includeMatchColumn() const { - return includeMatchColumn_; + /// Returns whether this node includes a marker column for left joins. + bool hasMarker() const { + return hasMarker_; } void accept(const PlanNodeVisitor& visitor, PlanNodeVisitorContext& context) @@ -3534,11 +3818,16 @@ class IndexLookupJoinNode : public AbstractJoinNode { private: void addDetails(std::stringstream& stream) const override; + /// The table scan node that provides the lookup source for index operations. const TableScanNodePtr lookupSourceNode_; + /// Join conditions that can't be converted into simple equality join + /// conditions. These conditions involve columns from both left and right + /// sides and exactly one index column from the right side. const std::vector joinConditions_; - const bool includeMatchColumn_; + /// Whether to include a marker column for left joins to indicate matches. + const bool hasMarker_; }; using IndexLookupJoinNodePtr = std::shared_ptr; @@ -3806,6 +4095,10 @@ class OrderByNode : public PlanNode { void accept(const PlanNodeVisitor& visitor, PlanNodeVisitorContext& context) const override; + bool requiresSingleThread() const override { + return !isPartial_; + } + // True if this node only sorts a portion of the final result. If it is // true, a local merge or merge exchange is required to merge the sorted // runs. @@ -3852,10 +4145,29 @@ class SpatialJoinNode : public PlanNode { const PlanNodeId& id, JoinType joinType, TypedExprPtr joinCondition, + FieldAccessTypedExprPtr probeGeometry, + FieldAccessTypedExprPtr buildGeometry, + std::optional radius, PlanNodePtr left, PlanNodePtr right, RowTypePtr outputType); + SpatialJoinNode( + const PlanNodeId& id, + JoinType joinType, + TypedExprPtr joinCondition, + PlanNodePtr left, + PlanNodePtr right, + RowTypePtr outputType); + + PlanNodePtr leftNode() const { + return sources()[0]; + } + + PlanNodePtr rightNode() const { + return sources()[1]; + } + class Builder { public: Builder() = default; @@ -3864,6 +4176,9 @@ class SpatialJoinNode : public PlanNode { id_ = other.id(); joinType_ = other.joinType(); joinCondition_ = other.joinCondition(); + probeGeometry_ = other.probeGeometry(); + buildGeometry_ = other.buildGeometry(); + radius_ = other.radius(); VELOX_CHECK_EQ(other.sources().size(), 2); left_ = other.sources()[0]; right_ = other.sources()[1]; @@ -3885,6 +4200,21 @@ class SpatialJoinNode : public PlanNode { return *this; } + Builder& probeGeometry(FieldAccessTypedExprPtr probeGeometry) { + probeGeometry_ = std::move(probeGeometry); + return *this; + } + + Builder& buildGeometry(FieldAccessTypedExprPtr buildGeometry) { + buildGeometry_ = std::move(buildGeometry); + return *this; + } + + Builder& radius(FieldAccessTypedExprPtr radius) { + radius_ = std::move(radius); + return *this; + } + Builder& left(PlanNodePtr left) { left_ = std::move(left); return *this; @@ -3908,11 +4238,38 @@ class SpatialJoinNode : public PlanNode { right_.has_value(), "SpatialJoinNode right source is not set"); VELOX_USER_CHECK( outputType_.has_value(), "SpatialJoinNode outputType is not set"); + VELOX_USER_CHECK( + probeGeometry_.has_value(), + "SpatialJoinNode probe geometry is not set"); + VELOX_USER_CHECK( + buildGeometry_.has_value(), + "SpatialJoinNode build geometry is not set"); + + VELOX_USER_CHECK( + (probeGeometry_.has_value() && buildGeometry_.has_value()) || + (!probeGeometry_.has_value() && !buildGeometry_.has_value()), + "Either probe and build geometry must both be set, or neither"); + + if (probeGeometry_.has_value() && buildGeometry_.has_value()) { + return std::make_shared( + id_.value(), + joinType_, + joinCondition_, + probeGeometry_.value(), + buildGeometry_.value(), + radius_, + left_.value(), + right_.value(), + outputType_.value()); + } return std::make_shared( id_.value(), joinType_, joinCondition_, + probeGeometry_.value(), + buildGeometry_.value(), + radius_, left_.value(), right_.value(), outputType_.value()); @@ -3922,6 +4279,9 @@ class SpatialJoinNode : public PlanNode { std::optional id_; JoinType joinType_ = kDefaultJoinType; TypedExprPtr joinCondition_; + std::optional probeGeometry_; + std::optional buildGeometry_; + std::optional radius_; std::optional left_; std::optional right_; std::optional outputType_; @@ -3946,6 +4306,18 @@ class SpatialJoinNode : public PlanNode { return joinCondition_; } + const FieldAccessTypedExprPtr& probeGeometry() const { + return probeGeometry_; + } + + const FieldAccessTypedExprPtr& buildGeometry() const { + return buildGeometry_; + } + + const std::optional& radius() const { + return radius_; + } + JoinType joinType() const { return joinType_; } @@ -3964,6 +4336,9 @@ class SpatialJoinNode : public PlanNode { const JoinType joinType_; const TypedExprPtr joinCondition_; + const FieldAccessTypedExprPtr probeGeometry_; + const FieldAccessTypedExprPtr buildGeometry_; + const std::optional radius_; const std::vector sources_; const RowTypePtr outputType_; }; @@ -4071,6 +4446,10 @@ class TopNNode : public PlanNode { void accept(const PlanNodeVisitor& visitor, PlanNodeVisitorContext& context) const override; + bool requiresSingleThread() const override { + return !isPartial_; + } + int32_t count() const { return count_; } @@ -4206,6 +4585,10 @@ class LimitNode : public PlanNode { return count_; } + bool requiresSingleThread() const override { + return !isPartial_; + } + bool isPartial() const { return isPartial_; } @@ -4245,17 +4628,33 @@ class UnnestNode : public PlanNode { /// names must appear in the same order as unnestVariables. /// @param ordinalityName Optional name for the ordinality columns. If not /// present, ordinality column is not produced. - /// @param emptyUnnestValueName Optional name for column which indicates an - /// output row has empty unnest value or not. If not present, emptyUnnestValue - /// column is not provided and the unnest operator also skips producing output - /// rows with empty unnest value. + /// @param markerName Optional name for column which indicates whether an + /// output row has non-empty unnested value. If not present, marker column is + /// not provided and the unnest operator also skips producing output rows + /// with empty unnest value. + /// @param splitOutput Optional flag to control whether it should output 1 + /// batch for each input batch, or split output batches if they are too large. + /// If true, output is split into batches according to Operator's + /// outputBatchRows logic. If false, output is not split and output batches + /// match input batches 1:1. If not set, defaults to the value of the + /// unnest_split_output config in the QueryConfig. + UnnestNode( + const PlanNodeId& id, + std::vector replicateVariables, + std::vector unnestVariables, + std::vector unnestNames, + std::optional ordinalityName, + std::optional markerName, + const PlanNodePtr& source); + UnnestNode( const PlanNodeId& id, std::vector replicateVariables, std::vector unnestVariables, std::vector unnestNames, std::optional ordinalityName, - std::optional emptyUnnestValueName, + std::optional markerName, + std::optional splitOutput, const PlanNodePtr& source); class Builder { @@ -4268,6 +4667,7 @@ class UnnestNode : public PlanNode { unnestVariables_ = other.unnestVariables(); unnestNames_ = other.unnestNames_; ordinalityName_ = other.ordinalityName_; + splitOutput_ = other.splitOutput_; VELOX_CHECK_EQ(other.sources().size(), 1); source_ = other.sources()[0]; } @@ -4304,9 +4704,13 @@ class UnnestNode : public PlanNode { return *this; } - Builder& emptyUnnestValueName( - std::optional emptyUnnestValueName) { - emptyUnnestValueName_ = std::move(emptyUnnestValueName); + Builder& markerName(std::optional markerName) { + markerName_ = std::move(markerName); + return *this; + } + + Builder& splitOutput(std::optional splitOutput) { + splitOutput_ = splitOutput; return *this; } @@ -4328,7 +4732,8 @@ class UnnestNode : public PlanNode { unnestVariables_.value(), unnestNames_.value(), ordinalityName_, - emptyUnnestValueName_, + markerName_, + splitOutput_, source_.value()); } @@ -4338,8 +4743,9 @@ class UnnestNode : public PlanNode { std::optional> unnestVariables_; std::optional> unnestNames_; std::optional ordinalityName_; - std::optional emptyUnnestValueName_; + std::optional markerName_; std::optional source_; + std::optional splitOutput_; }; bool supportsBarrier() const override { @@ -4380,12 +4786,16 @@ class UnnestNode : public PlanNode { return ordinalityName_.has_value(); } - const std::optional& emptyUnnestValueName() const { - return emptyUnnestValueName_; + const std::optional& markerName() const { + return markerName_; } - bool hasEmptyUnnestValue() const { - return emptyUnnestValueName_.has_value(); + bool hasMarker() const { + return markerName_.has_value(); + } + + const std::optional& splitOutput() const { + return splitOutput_; } std::string_view name() const override { @@ -4403,7 +4813,8 @@ class UnnestNode : public PlanNode { const std::vector unnestVariables_; const std::vector unnestNames_; const std::optional ordinalityName_; - const std::optional emptyUnnestValueName_; + const std::optional markerName_; + const std::optional splitOutput_; const std::vector sources_; RowTypePtr outputType_; }; @@ -4462,6 +4873,14 @@ class EnforceSingleRowNode : public PlanNode { return sources_; } + /// Validates that input produces exactly one row, so the pipeline must + /// observe all rows sequentially on a single driver. Multiple drivers + /// would each independently produce a row (or NULL on empty input), + /// breaking the single-row contract. + bool requiresSingleThread() const override { + return true; + } + void accept(const PlanNodeVisitor& visitor, PlanNodeVisitorContext& context) const override; @@ -5060,6 +5479,10 @@ class MarkDistinctNode : public PlanNode { return "MarkDistinct"; } + bool canSpill(const QueryConfig& queryConfig) const override { + return queryConfig.markDistinctSpillEnabled(); + } + const std::string& markerName() const { return markerName_; } @@ -5086,6 +5509,208 @@ class MarkDistinctNode : public PlanNode { using MarkDistinctNodePtr = std::shared_ptr; +/// Checks that input rows have unique values in the specified key columns. +/// Passes through all input rows unchanged. Raises an exception if duplicate +/// key values are detected. +/// +/// Used to validate uniqueness constraints, such as ensuring a scalar subquery +/// returns at most one row per group. +/// +class EnforceDistinctNode : public PlanNode { + public: + /// @param distinctKeys Columns that must have unique values. + /// @param preGroupedKeys Subset of distinctKeys that input is already + /// clustered on. When preGroupedKeys equals distinctKeys, a streaming + /// implementation is used that compares consecutive rows instead of using a + /// hash table. + /// @param errorMessage Custom error message to show when duplicates are + /// found. + EnforceDistinctNode( + PlanNodeId id, + std::vector distinctKeys, + std::vector preGroupedKeys, + std::string errorMessage, + PlanNodePtr source); + + const std::vector& sources() const override { + return sources_; + } + + void accept(const PlanNodeVisitor& visitor, PlanNodeVisitorContext& context) + const override; + + const RowTypePtr& outputType() const override { + return sources_[0]->outputType(); + } + + std::string_view name() const override { + return "EnforceDistinct"; + } + + const std::vector& distinctKeys() const { + return distinctKeys_; + } + + const std::vector& preGroupedKeys() const { + return preGroupedKeys_; + } + + /// Returns true if all distinct keys are pre-grouped, meaning input is + /// clustered on distinct keys and streaming enforcement can be used. + bool isPreGrouped() const { + return preGroupedKeys_.size() == distinctKeys_.size(); + } + + const std::string& errorMessage() const { + return errorMessage_; + } + + folly::dynamic serialize() const override; + + static PlanNodePtr create(const folly::dynamic& obj, void* context); + + private: + void addDetails(std::stringstream& stream) const override; + + const std::vector distinctKeys_; + const std::vector preGroupedKeys_; + const std::string errorMessage_; + const std::vector sources_; +}; + +using EnforceDistinctNodePtr = std::shared_ptr; + +/// The MarkSorted operator marks rows where the sort key changes. +/// The result is put in a new markerName column alongside the original input. +/// The first row is always marked true. Subsequent rows are marked true if +/// they compare as sorted relative to the previous row based on sortingKeys +/// and sortingOrders. +/// @param markerName Name of the output marker channel. +/// @param sortingKeys Keys to check for sorted order. +/// @param sortingOrders Sort orders (ascending/descending, nulls first/last). +class MarkSortedNode : public PlanNode { + public: + MarkSortedNode( + PlanNodeId id, + std::string markerName, + std::vector sortingKeys, + std::vector sortingOrders, + PlanNodePtr source); + + class Builder { + public: + Builder() = default; + + explicit Builder(const MarkSortedNode& other) { + id_ = other.id(); + markerName_ = other.markerName(); + sortingKeys_ = other.sortingKeys(); + sortingOrders_ = other.sortingOrders(); + VELOX_CHECK_EQ(other.sources().size(), 1); + source_ = other.sources()[0]; + } + + Builder& id(PlanNodeId id) { + id_ = std::move(id); + return *this; + } + + Builder& markerName(std::string markerName) { + markerName_ = std::move(markerName); + return *this; + } + + Builder& sortingKeys(std::vector sortingKeys) { + sortingKeys_ = std::move(sortingKeys); + return *this; + } + + Builder& sortingOrders(std::vector sortingOrders) { + sortingOrders_ = std::move(sortingOrders); + return *this; + } + + Builder& source(PlanNodePtr source) { + source_ = std::move(source); + return *this; + } + + std::shared_ptr build() const { + VELOX_USER_CHECK(id_.has_value(), "MarkSortedNode id is not set"); + VELOX_USER_CHECK( + markerName_.has_value(), "MarkSortedNode markerName is not set"); + VELOX_USER_CHECK( + sortingKeys_.has_value(), "MarkSortedNode sortingKeys is not set"); + VELOX_USER_CHECK( + sortingOrders_.has_value(), + "MarkSortedNode sortingOrders is not set"); + VELOX_USER_CHECK(source_.has_value(), "MarkSortedNode source is not set"); + + return std::make_shared( + id_.value(), + markerName_.value(), + sortingKeys_.value(), + sortingOrders_.value(), + source_.value()); + } + + private: + std::optional id_; + std::optional markerName_; + std::optional> sortingKeys_; + std::optional> sortingOrders_; + std::optional source_; + }; + + const std::vector& sources() const override { + return sources_; + } + + void accept(const PlanNodeVisitor& visitor, PlanNodeVisitorContext& context) + const override; + + /// The outputType is the concatenation of the input columns and marker + /// column. + const RowTypePtr& outputType() const override { + return outputType_; + } + + std::string_view name() const override { + return "MarkSorted"; + } + + const std::string& markerName() const { + return markerName_; + } + + const std::vector& sortingKeys() const { + return sortingKeys_; + } + + const std::vector& sortingOrders() const { + return sortingOrders_; + } + + folly::dynamic serialize() const override; + + static PlanNodePtr create(const folly::dynamic& obj, void* context); + + private: + void addDetails(std::stringstream& stream) const override; + + const std::string markerName_; + + const std::vector sortingKeys_; + + const std::vector sortingOrders_; + + const std::vector sources_; + + const RowTypePtr outputType_; +}; + +using MarkSortedNodePtr = std::shared_ptr; + /// Optimized version of a WindowNode for a single row_number, rank or /// dense_rank function with a limit over sorted partitions. The output of this /// node contains all input columns followed by an optional @@ -5294,6 +5919,149 @@ class TopNRowNumberNode : public PlanNode { using TopNRowNumberNodePtr = std::shared_ptr; +/// Union operator that combines data from multiple inputs. +/// Supports both serial mode (process inputs one at a time) and +/// mixed mode (process inputs simultaneously and combine results). +class MixedUnionNode : public PlanNode { + public: + MixedUnionNode(const PlanNodeId& id, std::vector sources) + : MixedUnionNode(id, std::move(sources), {}) {} + + MixedUnionNode( + const PlanNodeId& id, + std::vector sources, + std::vector batchSizesPerSource) + : PlanNode(id), + sources_(std::move(sources)), + batchSizesPerSource_(std::move(batchSizesPerSource)) { + VELOX_USER_CHECK( + !sources_.empty(), "Union node must have at least one source"); + + // All sources must have the same output type + outputType_ = sources_[0]->outputType(); + for (size_t i = 1; i < sources_.size(); ++i) { + VELOX_USER_CHECK( + outputType_->equivalent(*sources_[i]->outputType()), + "All Union sources must have the same output type. " + "Source 0 type: {}, Source {} type: {}", + outputType_->toString(), + i, + sources_[i]->outputType()->toString()); + } + } + + class Builder { + public: + Builder() = default; + + explicit Builder(const MixedUnionNode& other) { + id_ = other.id(); + sources_ = other.sources(); + batchSizesPerSource_ = other.batchSizesPerSource(); + } + + Builder& id(PlanNodeId id) { + id_ = std::move(id); + return *this; + } + + Builder& sources(std::vector sources) { + sources_ = std::move(sources); + return *this; + } + + Builder& source(PlanNodePtr source) { + if (!sources_.has_value()) { + sources_ = std::vector{}; + } + sources_->push_back(std::move(source)); + return *this; + } + + Builder& batchSizesPerSource(std::vector batchSizes) { + batchSizesPerSource_ = std::move(batchSizes); + return *this; + } + + Builder& batchSizeForSource(int32_t sourceIndex, int64_t batchSize) { + if (!batchSizesPerSource_.has_value()) { + batchSizesPerSource_ = std::vector{}; + } + if (sourceIndex >= batchSizesPerSource_->size()) { + batchSizesPerSource_->resize(sourceIndex + 1, 0); + } + (*batchSizesPerSource_)[sourceIndex] = batchSize; + return *this; + } + + std::shared_ptr build() const { + VELOX_USER_CHECK(id_.has_value(), "MixedUnionNode id is not set"); + VELOX_USER_CHECK( + sources_.has_value() && !sources_->empty(), + "MixedUnionNode sources is not set or empty"); + + return std::make_shared( + id_.value(), + sources_.value(), + batchSizesPerSource_.value_or(std::vector{})); + } + + private: + std::optional id_; + std::optional> sources_; + std::optional> batchSizesPerSource_; + }; + + const std::vector& sources() const override { + return sources_; + } + + const RowTypePtr& outputType() const override { + return outputType_; + } + + /// Returns the batch sizes per source index. + /// This controls how many rows are taken from each source when mixing. + const std::vector& batchSizesPerSource() const { + return batchSizesPerSource_; + } + + /// Get batch size for a specific source index (returns 0 if not set). + int64_t getBatchSizeForSource(int32_t sourceIndex) const { + if (sourceIndex < 0 || sourceIndex >= batchSizesPerSource_.size()) { + return 0; + } + return batchSizesPerSource_[sourceIndex]; + } + + bool requiresSingleThread() const override { + return true; + } + + void accept(const PlanNodeVisitor& visitor, PlanNodeVisitorContext& context) + const override; + + std::string_view name() const override { + return "MixedUnion"; + } + + folly::dynamic serialize() const override; + + static PlanNodePtr create(const folly::dynamic& obj, void* context); + + bool supportsBarrier() const override { + return true; + } + + private: + void addDetails(std::stringstream& /* stream */) const override {} + const std::vector sources_; + RowTypePtr outputType_; + std::vector batchSizesPerSource_; +}; + +using MixedUnionNodePtr = std::shared_ptr; + class PlanNodeVisitorContext { public: virtual ~PlanNodeVisitorContext() = default; @@ -5349,6 +6117,13 @@ class PlanNodeVisitor { virtual void visit(const MarkDistinctNode& node, PlanNodeVisitorContext& ctx) const = 0; + virtual void visit( + const EnforceDistinctNode& node, + PlanNodeVisitorContext& ctx) const = 0; + + virtual void visit(const MarkSortedNode& node, PlanNodeVisitorContext& ctx) + const = 0; + virtual void visit(const MergeExchangeNode& node, PlanNodeVisitorContext& ctx) const = 0; @@ -5407,6 +6182,9 @@ class PlanNodeVisitor { virtual void visit(const WindowNode& node, PlanNodeVisitorContext& ctx) const = 0; + virtual void visit(const MixedUnionNode& node, PlanNodeVisitorContext& ctx) + const = 0; + /// Used to visit custom PlanNodes that extend the set provided by Velox. virtual void visit(const PlanNode& node, PlanNodeVisitorContext& ctx) const = 0; @@ -5419,6 +6197,126 @@ class PlanNodeVisitor { } }; +/// Plan node for async RPC execution (e.g., LLM inference, embeddings). +/// +/// Stores the function name, result type, argument columns, and streaming +/// mode. The RPCNode does NOT evaluate argument expressions — a ProjectNode +/// inserted before this node by the plan rewriter computes argument columns. +/// +/// Architecture: +/// SQL: SELECT rpc_function(col1, 'model_name') FROM table +/// | +/// v (Plan Rewriter) +/// ProjectNode (__rpc_arg_0 = col1, __rpc_arg_1 = 'model_name') +/// | +/// RPCNode (argumentColumns = [__rpc_arg_0, __rpc_arg_1]) +/// | +/// source[0] +/// | +/// TableScan +class RPCNode : public PlanNode { + public: + /// @param id Unique identifier for this plan node. + /// @param source Data source (the only source). + /// @param functionName Name of the registered AsyncRPCFunction. + /// @param functionResultType Velox type of the RPC result column. + /// @param outputColumn Name of the output column for RPC responses. + /// @param outputType Explicit output type. Must contain outputColumn + /// and any passthrough source columns needed by downstream. + /// Specified explicitly (like AbstractJoinNode) to support column + /// pruning. + /// @param argumentColumns Names of input columns containing pre-evaluated + /// argument values. RPCOperator reads these columns in addInput(). + /// @param argumentTypes Types of each argument (aligned with + /// argumentColumns). Passed to AsyncRPCFunction::initialize(). + /// @param constantInputs Constant argument values (aligned with + /// argumentColumns). nullptr for non-constant args, single-element + /// ConstantVectors for constant args. Passed to initialize(). + /// @param streamingMode The streaming mode for RPC execution. + /// @param dispatchBatchSize For BATCH mode pipelining: fire callBatch() + /// every N rows during addInput() instead of collecting all rows. + RPCNode( + const PlanNodeId& id, + PlanNodePtr source, + std::string functionName, + TypePtr functionResultType, + std::string outputColumn, + RowTypePtr outputType, + std::vector argumentColumns, + std::vector argumentTypes, + std::vector constantInputs, + rpc::RPCStreamingMode streamingMode = rpc::RPCStreamingMode::kPerRow, + int32_t dispatchBatchSize = 0); + + const PlanNodePtr& source() const { + return sources_[0]; + } + + const std::string& functionName() const { + return functionName_; + } + + const TypePtr& rpcResultType() const { + return resultType_; + } + + const std::string& outputColumn() const { + return outputColumn_; + } + + const std::vector& argumentColumns() const { + return argumentColumns_; + } + + const std::vector& argumentTypes() const { + return argumentTypes_; + } + + const std::vector& constantInputs() const { + return constantInputs_; + } + + rpc::RPCStreamingMode streamingMode() const { + return streamingMode_; + } + + int32_t dispatchBatchSize() const { + return dispatchBatchSize_; + } + + std::string_view name() const override { + return "RPC"; + } + + const RowTypePtr& outputType() const override { + return outputType_; + } + + const std::vector& sources() const override { + return sources_; + } + + folly::dynamic serialize() const override; + + static PlanNodePtr create(const folly::dynamic& obj, void* context); + + private: + void addDetails(std::stringstream& stream) const override; + + std::vector sources_; + std::string functionName_; + TypePtr resultType_; + std::string outputColumn_; + RowTypePtr outputType_; + std::vector argumentColumns_; + std::vector argumentTypes_; + std::vector constantInputs_; + rpc::RPCStreamingMode streamingMode_; + int32_t dispatchBatchSize_{0}; +}; + +using RPCNodePtr = std::shared_ptr; + } // namespace facebook::velox::core template <> diff --git a/velox/core/QueryConfig.cpp b/velox/core/QueryConfig.cpp index 3d5b25ff948..9a247e35a2c 100644 --- a/velox/core/QueryConfig.cpp +++ b/velox/core/QueryConfig.cpp @@ -22,29 +22,281 @@ namespace facebook::velox::core { -QueryConfig::QueryConfig( - const std::unordered_map& values) - : config_{std::make_unique( - std::unordered_map(values))} { - validateConfig(); +const std::vector& QueryConfig::registeredProperties() { + static const std::vector kProperties = [] { + std::vector properties; +#define VELOX_REGISTER_QUERY_CONFIG(constName) \ + config::registerConfigProperty(properties) + + // Memory. + VELOX_REGISTER_QUERY_CONFIG(kQueryMaxMemoryPerNode); + + // Session. + VELOX_REGISTER_QUERY_CONFIG(kSessionTimezone); + VELOX_REGISTER_QUERY_CONFIG(kSessionStartTime); + VELOX_REGISTER_QUERY_CONFIG(kAdjustTimestampToTimezone); + + // Expression evaluation. + VELOX_REGISTER_QUERY_CONFIG(kExprEvalSimplified); + VELOX_REGISTER_QUERY_CONFIG(kExprEvalFlatNoNulls); + VELOX_REGISTER_QUERY_CONFIG(kExprTrackCpuUsage); + VELOX_REGISTER_QUERY_CONFIG(kExprTrackCpuUsageForFunctions); + VELOX_REGISTER_QUERY_CONFIG(kExprAdaptiveCpuSampling); + VELOX_REGISTER_QUERY_CONFIG(kExprAdaptiveCpuSamplingMaxOverheadPct); + VELOX_REGISTER_QUERY_CONFIG(kExprDedupNonDeterministic); + VELOX_REGISTER_QUERY_CONFIG(kExprMaxArraySizeInReduce); + VELOX_REGISTER_QUERY_CONFIG(kExprMaxCompiledRegexes); + + // Operator. + VELOX_REGISTER_QUERY_CONFIG(kOperatorTrackCpuUsage); + + // Cast. + VELOX_REGISTER_QUERY_CONFIG(kLegacyCast); + VELOX_REGISTER_QUERY_CONFIG(kCastMatchStructByName); + + // Local exchange. + VELOX_REGISTER_QUERY_CONFIG(kMaxLocalExchangeBufferSize); + VELOX_REGISTER_QUERY_CONFIG(kMaxLocalExchangePartitionCount); + VELOX_REGISTER_QUERY_CONFIG( + kMinLocalExchangePartitionCountToUsePartitionBuffer); + VELOX_REGISTER_QUERY_CONFIG(kMaxLocalExchangePartitionBufferSize); + VELOX_REGISTER_QUERY_CONFIG(kLocalExchangePartitionBufferPreserveEncoding); + VELOX_REGISTER_QUERY_CONFIG(kLocalMergeSourceQueueSize); + + // Exchange. + VELOX_REGISTER_QUERY_CONFIG(kMaxExchangeBufferSize); + VELOX_REGISTER_QUERY_CONFIG(kMaxMergeExchangeBufferSize); + VELOX_REGISTER_QUERY_CONFIG(kMinExchangeOutputBatchBytes); + + // Aggregation. + VELOX_REGISTER_QUERY_CONFIG(kMaxPartialAggregationMemory); + VELOX_REGISTER_QUERY_CONFIG(kMaxExtendedPartialAggregationMemory); + VELOX_REGISTER_QUERY_CONFIG(kAbandonPartialAggregationMinRows); + VELOX_REGISTER_QUERY_CONFIG(kAbandonPartialAggregationMinPct); + VELOX_REGISTER_QUERY_CONFIG(kAggregationCompactionBytesThreshold); + VELOX_REGISTER_QUERY_CONFIG(kAggregationCompactionUnusedMemoryRatio); + VELOX_REGISTER_QUERY_CONFIG(kAggregationMemoryCompactionReclaimEnabled); + + // TopN row number. + VELOX_REGISTER_QUERY_CONFIG(kAbandonPartialTopNRowNumberMinRows); + VELOX_REGISTER_QUERY_CONFIG(kAbandonPartialTopNRowNumberMinPct); + + // Hash build dedup. + VELOX_REGISTER_QUERY_CONFIG(kAbandonDedupHashMapMinRows); + VELOX_REGISTER_QUERY_CONFIG(kAbandonDedupHashMapMinPct); + + // Miscellaneous. + VELOX_REGISTER_QUERY_CONFIG(kMaxElementsSizeInRepeatAndSequence); + + // Partitioned output. + VELOX_REGISTER_QUERY_CONFIG(kPartitionedOutputEagerFlush); + VELOX_REGISTER_QUERY_CONFIG(kMaxPartitionedOutputBufferSize); + VELOX_REGISTER_QUERY_CONFIG(kMaxOutputBufferSize); + + // Output batch. + VELOX_REGISTER_QUERY_CONFIG(kPreferredOutputBatchBytes); + VELOX_REGISTER_QUERY_CONFIG(kPreferredOutputBatchRows); + VELOX_REGISTER_QUERY_CONFIG(kMaxOutputBatchRows); + VELOX_REGISTER_QUERY_CONFIG(kMergeJoinOutputBatchStartSize); + + // Table scan. + VELOX_REGISTER_QUERY_CONFIG(kTableScanGetOutputTimeLimitMs); + VELOX_REGISTER_QUERY_CONFIG(kTableScanOutputBatchRowsOverride); + + // Hash table. + VELOX_REGISTER_QUERY_CONFIG(kHashAdaptivityEnabled); + VELOX_REGISTER_QUERY_CONFIG(kAdaptiveFilterReorderingEnabled); + VELOX_REGISTER_QUERY_CONFIG(kParallelOutputJoinBuildRowsEnabled); + + // Spill. + VELOX_REGISTER_QUERY_CONFIG(kSpillEnabled); + VELOX_REGISTER_QUERY_CONFIG(kAggregationSpillEnabled); + VELOX_REGISTER_QUERY_CONFIG(kJoinSpillEnabled); + VELOX_REGISTER_QUERY_CONFIG(kMixedGroupedModeHashJoinSpillEnabled); + VELOX_REGISTER_QUERY_CONFIG(kOrderBySpillEnabled); + VELOX_REGISTER_QUERY_CONFIG(kWindowSpillEnabled); + VELOX_REGISTER_QUERY_CONFIG(kWindowSpillMinReadBatchRows); + VELOX_REGISTER_QUERY_CONFIG(kWriterSpillEnabled); + VELOX_REGISTER_QUERY_CONFIG(kRowNumberSpillEnabled); + VELOX_REGISTER_QUERY_CONFIG(kMarkDistinctSpillEnabled); + VELOX_REGISTER_QUERY_CONFIG(kTopNRowNumberSpillEnabled); + VELOX_REGISTER_QUERY_CONFIG(kLocalMergeSpillEnabled); + VELOX_REGISTER_QUERY_CONFIG(kMaxSpillRunRows); + VELOX_REGISTER_QUERY_CONFIG(kMaxSpillBytes); + VELOX_REGISTER_QUERY_CONFIG(kMaxSpillLevel); + VELOX_REGISTER_QUERY_CONFIG(kMaxSpillFileSize); + VELOX_REGISTER_QUERY_CONFIG(kSpillCompressionKind); + VELOX_REGISTER_QUERY_CONFIG(kSpillNumMaxMergeFiles); + VELOX_REGISTER_QUERY_CONFIG(kSpillPrefixSortEnabled); + VELOX_REGISTER_QUERY_CONFIG(kSpillWriteBufferSize); + VELOX_REGISTER_QUERY_CONFIG(kSpillReadBufferSize); + VELOX_REGISTER_QUERY_CONFIG(kSpillFileCreateConfig); + VELOX_REGISTER_QUERY_CONFIG(kAggregationSpillFileCreateConfig); + VELOX_REGISTER_QUERY_CONFIG(kHashJoinSpillFileCreateConfig); + VELOX_REGISTER_QUERY_CONFIG(kRowNumberSpillFileCreateConfig); + VELOX_REGISTER_QUERY_CONFIG(kSpillStartPartitionBit); + VELOX_REGISTER_QUERY_CONFIG(kSpillNumPartitionBits); + VELOX_REGISTER_QUERY_CONFIG(kMinSpillableReservationPct); + VELOX_REGISTER_QUERY_CONFIG(kSpillableReservationGrowthPct); + + // Writer. + VELOX_REGISTER_QUERY_CONFIG(kWriterFlushThresholdBytes); + + // Presto-specific. + VELOX_REGISTER_QUERY_CONFIG(kPrestoArrayAggIgnoreNulls); + + // Spark-specific. + VELOX_REGISTER_QUERY_CONFIG(kSparkAnsiEnabled); + VELOX_REGISTER_QUERY_CONFIG(kSparkBloomFilterExpectedNumItems); + VELOX_REGISTER_QUERY_CONFIG(kSparkBloomFilterNumBits); + VELOX_REGISTER_QUERY_CONFIG(kSparkBloomFilterMaxNumBits); + VELOX_REGISTER_QUERY_CONFIG(kSparkBloomFilterMaxNumItems); + VELOX_REGISTER_QUERY_CONFIG(kSparkLegacyDateFormatter); + VELOX_REGISTER_QUERY_CONFIG(kSparkLegacyStatisticalAggregate); + VELOX_REGISTER_QUERY_CONFIG(kSparkJsonIgnoreNullFields); + VELOX_REGISTER_QUERY_CONFIG(kSparkCollectListIgnoreNulls); + + // Task writer. + VELOX_REGISTER_QUERY_CONFIG(kTaskWriterCount); + + // Hash probe. + VELOX_REGISTER_QUERY_CONFIG(kHashProbeFinishEarlyOnEmptyBuild); + VELOX_REGISTER_QUERY_CONFIG(kHashProbeDynamicFilterPushdownEnabled); + VELOX_REGISTER_QUERY_CONFIG(kHashProbeStringDynamicFilterPushdownEnabled); + VELOX_REGISTER_QUERY_CONFIG(kHashProbeBloomFilterPushdownMaxSize); + VELOX_REGISTER_QUERY_CONFIG(kMinTableRowsForParallelJoinBuild); + + // Debug and validation. + VELOX_REGISTER_QUERY_CONFIG(kValidateOutputFromOperators); + VELOX_REGISTER_QUERY_CONFIG(kEnableExpressionEvaluationCache); + VELOX_REGISTER_QUERY_CONFIG(kMaxSharedSubexprResultsCached); + + // Split preload. + VELOX_REGISTER_QUERY_CONFIG(kMaxSplitPreloadPerDriver); + + // Driver. + VELOX_REGISTER_QUERY_CONFIG(kDriverCpuTimeSliceLimitMs); + + // Window. + VELOX_REGISTER_QUERY_CONFIG(kWindowNumSubPartitions); + + // Prefix sort. + VELOX_REGISTER_QUERY_CONFIG(kPrefixSortNormalizedKeyMaxBytes); + VELOX_REGISTER_QUERY_CONFIG(kPrefixSortMinRows); + VELOX_REGISTER_QUERY_CONFIG(kPrefixSortMaxStringPrefixLength); + + // Query trace. + VELOX_REGISTER_QUERY_CONFIG(kQueryTraceEnabled); + VELOX_REGISTER_QUERY_CONFIG(kQueryTraceDir); + VELOX_REGISTER_QUERY_CONFIG(kQueryTraceNodeId); + VELOX_REGISTER_QUERY_CONFIG(kQueryTraceMaxBytes); + VELOX_REGISTER_QUERY_CONFIG(kQueryTraceTaskRegExp); + VELOX_REGISTER_QUERY_CONFIG(kQueryTraceDryRun); + VELOX_REGISTER_QUERY_CONFIG(kOpTraceDirectoryCreateConfig); + + // Debug expression. + VELOX_REGISTER_QUERY_CONFIG(kDebugDisableExpressionWithPeeling); + VELOX_REGISTER_QUERY_CONFIG(kDebugDisableCommonSubExpressions); + VELOX_REGISTER_QUERY_CONFIG(kDebugDisableExpressionWithMemoization); + VELOX_REGISTER_QUERY_CONFIG(kDebugDisableExpressionWithLazyInputs); + VELOX_REGISTER_QUERY_CONFIG(kDebugMemoryPoolNameRegex); + VELOX_REGISTER_QUERY_CONFIG(kDebugMemoryPoolWarnThresholdBytes); + VELOX_REGISTER_QUERY_CONFIG(kDebugLambdaFunctionEvaluationBatchSize); + VELOX_REGISTER_QUERY_CONFIG(kDebugBingTileChildrenMaxZoomShift); + + // Nimble (deprecated, kept for backward compatibility). + VELOX_REGISTER_QUERY_CONFIG(kSelectiveNimbleReaderEnabled); + + // Scale writer. + VELOX_REGISTER_QUERY_CONFIG(kScaleWriterRebalanceMaxMemoryUsageRatio); + VELOX_REGISTER_QUERY_CONFIG(kScaleWriterMaxPartitionsPerWriter); + VELOX_REGISTER_QUERY_CONFIG( + kScaleWriterMinPartitionProcessedBytesRebalanceThreshold); + VELOX_REGISTER_QUERY_CONFIG( + kScaleWriterMinProcessedBytesRebalanceThreshold); + + // Table scan scaling. + VELOX_REGISTER_QUERY_CONFIG(kTableScanScaledProcessingEnabled); + VELOX_REGISTER_QUERY_CONFIG(kTableScanScaleUpMemoryUsageRatio); + + // Shuffle. + VELOX_REGISTER_QUERY_CONFIG(kShuffleCompressionKind); + VELOX_REGISTER_QUERY_CONFIG(kMinShuffleCompressionPageSizeBytes); + + // Map. + VELOX_REGISTER_QUERY_CONFIG(kThrowExceptionOnDuplicateMapKeys); + + // Index lookup join. + VELOX_REGISTER_QUERY_CONFIG(kIndexLookupJoinMaxPrefetchBatches); + VELOX_REGISTER_QUERY_CONFIG(kIndexLookupJoinSplitOutput); + + // Exchange request. + VELOX_REGISTER_QUERY_CONFIG(kRequestDataSizesMaxWaitSec); + + // Streaming aggregation. + VELOX_REGISTER_QUERY_CONFIG(kStreamingAggregationMinOutputBatchRows); + VELOX_REGISTER_QUERY_CONFIG(kStreamingAggregationEagerFlush); + + // Exchange optimization. + VELOX_REGISTER_QUERY_CONFIG(kSkipRequestDataSizeWithSingleSourceEnabled); + VELOX_REGISTER_QUERY_CONFIG(kExchangeLazyFetchingEnabled); + + // JSON cast. + VELOX_REGISTER_QUERY_CONFIG(kFieldNamesInJsonCastEnabled); + + // Operator stats. + VELOX_REGISTER_QUERY_CONFIG(kOperatorTrackExpressionStats); + VELOX_REGISTER_QUERY_CONFIG(kEnableOperatorBatchSizeStats); + + // Unnest. + VELOX_REGISTER_QUERY_CONFIG(kUnnestSplitOutput); + + // Memory reclaimer. + VELOX_REGISTER_QUERY_CONFIG(kQueryMemoryReclaimerPriority); + + // Splits. + VELOX_REGISTER_QUERY_CONFIG(kMaxNumSplitsListenedTo); + + // Source and tags. + VELOX_REGISTER_QUERY_CONFIG(kSource); + VELOX_REGISTER_QUERY_CONFIG(kClientTags); + + // Row size tracking. + VELOX_REGISTER_QUERY_CONFIG(kRowSizeTrackingMode); + + // Join build. + VELOX_REGISTER_QUERY_CONFIG(kJoinBuildVectorHasherMaxNumDistinct); + + // Mark sorted. + VELOX_REGISTER_QUERY_CONFIG(kMarkSortedZeroCopyThreshold); + +#undef VELOX_REGISTER_QUERY_CONFIG + + return properties; + }(); + return kProperties; } -QueryConfig::QueryConfig(std::unordered_map&& values) - : config_{std::make_unique(std::move(values))} { +QueryConfig::QueryConfig(std::unordered_map values) + : QueryConfig{ + ConfigTag{}, + std::make_shared(std::move(values))} {} + +QueryConfig::QueryConfig( + ConfigTag /*tag*/, + std::shared_ptr config) + : config_{std::move(config)} { validateConfig(); } void QueryConfig::validateConfig() { // Validate if timezone name can be recognized. - if (config_->valueExists(QueryConfig::kSessionTimezone)) { + if (auto tz = config_->get(QueryConfig::kSessionTimezone)) { VELOX_USER_CHECK( - tz::getTimeZoneID( - config_->get(QueryConfig::kSessionTimezone).value(), - false) != -1, - fmt::format( - "session '{}' set with invalid value '{}'", - QueryConfig::kSessionTimezone, - config_->get(QueryConfig::kSessionTimezone).value())); + tz::getTimeZoneID(*tz, false) != -1, + "session '{}' set with invalid value '{}'", + QueryConfig::kSessionTimezone, + *tz); } } diff --git a/velox/core/QueryConfig.h b/velox/core/QueryConfig.h index 02cde3df9b4..2031009585c 100644 --- a/velox/core/QueryConfig.h +++ b/velox/core/QueryConfig.h @@ -15,726 +15,1484 @@ */ #pragma once -#include "velox/common/compression/Compression.h" #include "velox/common/config/Config.h" +#include "velox/common/config/ConfigProperty.h" #include "velox/vector/TypeAliases.h" namespace facebook::velox::core { -/// A simple wrapper around velox::ConfigBase. Defines constants for query +/// Macros for defining query config properties. +/// +/// To add a new property, add two lines (order doesn't matter, but +/// grouping related properties together helps readability): +/// +/// 1. In QueryConfig.h (inside the class body): +/// VELOX_QUERY_CONFIG(kSpillEnabled, spillEnabled, "spill_enabled", +/// bool, false, "Global enable spilling flag."); +/// +/// 2. In QueryConfig.cpp (inside registeredProperties()): +/// VELOX_REGISTER_QUERY_CONFIG(kSpillEnabled); +/// +/// VELOX_QUERY_CONFIG generates: +/// - struct kSpillEnabledProperty { +/// using type = bool; +/// static constexpr const char* key = "spill_enabled"; +/// static constexpr auto defaultValue = false; +/// static constexpr const char* description = "Global enable..."; +/// }; +/// - static constexpr const char* kSpillEnabled = "spill_enabled" +/// - bool spillEnabled() const { +/// return get(kSpillEnabled, false); +/// } +/// +/// VELOX_QUERY_CONFIG_PROPERTY generates the same but without the accessor. +/// Used for legacy properties with custom accessor logic (capacity parsing, +/// validation, clamping). New properties should use VELOX_QUERY_CONFIG and +/// put validation in QueryConfig::validateConfig() and +/// QueryConfigProvider::normalize(). +/// TODO: Unify validateConfig() and normalize() into a single validation path. +#define VELOX_QUERY_CONFIG_PROPERTY( \ + constName, keyStr, CppType, defaultVal, desc) \ + struct constName##Property { \ + using type = CppType; \ + static constexpr const char* key = keyStr; \ + static constexpr auto defaultValue = defaultVal; \ + static constexpr const char* description = desc; \ + }; \ + static constexpr const char* constName = keyStr; + +#define VELOX_QUERY_CONFIG( \ + constName, accessorName, keyStr, CppType, defaultVal, desc) \ + VELOX_QUERY_CONFIG_PROPERTY(constName, keyStr, CppType, defaultVal, desc) \ + CppType accessorName() const { \ + return get(constName, defaultVal); \ + } + +/// A simple wrapper around velox::IConfig. Defines constants for query /// config properties and accessor methods. /// Create per query context. Does not have a singleton instance. /// Does not allow altering properties on the fly. Only at creation time. class QueryConfig { public: + explicit QueryConfig(std::unordered_map values); + + // This is needed only to resolve correct ctor for cases like + // QueryConfig{{}} or QueryConfig({}). + struct ConfigTag {}; + explicit QueryConfig( - const std::unordered_map& values); + ConfigTag /*tag*/, + std::shared_ptr config); - explicit QueryConfig(std::unordered_map&& values); + /// Returns all registered query config properties. + static const std::vector& registeredProperties(); /// Maximum memory that a query can use on a single host. - static constexpr const char* kQueryMaxMemoryPerNode = - "query_max_memory_per_node"; + VELOX_QUERY_CONFIG_PROPERTY( + kQueryMaxMemoryPerNode, + "query_max_memory_per_node", + std::string, + "0B", + "Maximum memory that a query can use on a single host.") /// User provided session timezone. Stores a string with the actual timezone /// name, e.g: "America/Los_Angeles". - static constexpr const char* kSessionTimezone = "session_timezone"; + VELOX_QUERY_CONFIG( + kSessionTimezone, + sessionTimezone, + "session_timezone", + std::string, + "", + "Session timezone name, e.g. 'America/Los_Angeles'.") + + /// Session start time in milliseconds since Unix epoch. This represents when + /// the query session began execution. Used for functions that need to know + /// the session start time (e.g., current_date, localtime). + VELOX_QUERY_CONFIG( + kSessionStartTime, + sessionStartTimeMs, + "start_time", + int64_t, + 0, + "Session start time in milliseconds since Unix epoch.") /// If true, timezone-less timestamp conversions (e.g. string to timestamp, /// when the string does not specify a timezone) will be adjusted to the user /// provided session timezone (if any). - /// - /// For instance: - /// - /// if this option is true and user supplied "America/Los_Angeles", - /// "1970-01-01" will be converted to -28800 instead of 0. - /// - /// False by default. - static constexpr const char* kAdjustTimestampToTimezone = - "adjust_timestamp_to_session_timezone"; + VELOX_QUERY_CONFIG( + kAdjustTimestampToTimezone, + adjustTimestampToTimezone, + "adjust_timestamp_to_session_timezone", + bool, + false, + "Adjust timezone-less timestamp conversions to session timezone.") /// Whether to use the simplified expression evaluation path. False by /// default. - static constexpr const char* kExprEvalSimplified = - "expression.eval_simplified"; + VELOX_QUERY_CONFIG( + kExprEvalSimplified, + exprEvalSimplified, + "expression.eval_simplified", + bool, + false, + "Use simplified expression evaluation path.") + + /// Whether to enable the FlatNoNulls fast path for expression evaluation. + /// When enabled, expressions skip null checking and vector decoding when all + /// inputs are flat-encoded with no nulls. True by default. + VELOX_QUERY_CONFIG( + kExprEvalFlatNoNulls, + exprEvalFlatNoNulls, + "expression.eval_flat_no_nulls", + bool, + true, + "Enable FlatNoNulls fast path for expression evaluation.") /// Whether to track CPU usage for individual expressions (supported by call /// and cast expressions). False by default. Can be expensive when processing /// small batches, e.g. < 10K rows. - static constexpr const char* kExprTrackCpuUsage = - "expression.track_cpu_usage"; + VELOX_QUERY_CONFIG( + kExprTrackCpuUsage, + exprTrackCpuUsage, + "expression.track_cpu_usage", + bool, + false, + "Track CPU usage for individual expressions.") + + /// Takes a comma separated list of function names to track CPU usage for. + /// Only applicable when kExprTrackCpuUsage is set to false. Is empty by + /// default. + VELOX_QUERY_CONFIG( + kExprTrackCpuUsageForFunctions, + exprTrackCpuUsageForFunctions, + "expression.track_cpu_usage_for_functions", + std::string, + "", + "Comma-separated function names to track CPU usage for.") + + /// Enables adaptive per-function CPU usage sampling. When enabled, each + /// function is calibrated over the first 6 batches (1 warmup + 5 + /// calibration) to measure the overhead of CPU tracking (clock_gettime). + VELOX_QUERY_CONFIG( + kExprAdaptiveCpuSampling, + exprAdaptiveCpuSampling, + "expression.adaptive_cpu_sampling", + bool, + false, + "Enable adaptive per-function CPU usage sampling.") + + /// Maximum acceptable overhead percentage for CPU tracking per function. + /// Used with kExprAdaptiveCpuSampling. + VELOX_QUERY_CONFIG( + kExprAdaptiveCpuSamplingMaxOverheadPct, + exprAdaptiveCpuSamplingMaxOverheadPct, + "expression.adaptive_cpu_sampling_max_overhead_pct", + double, + 1.0, + "Maximum acceptable overhead percentage for CPU tracking per function.") + + /// Controls whether non-deterministic expressions are deduplicated during + /// compilation. If set to false, non-deterministic functions (such as + /// rand()) will not be deduplicated. + VELOX_QUERY_CONFIG( + kExprDedupNonDeterministic, + exprDedupNonDeterministic, + "expression.dedup_non_deterministic", + bool, + true, + "Deduplicate non-deterministic expressions during compilation.") /// Whether to track CPU usage for stages of individual operators. True by /// default. Can be expensive when processing small batches, e.g. < 10K rows. - static constexpr const char* kOperatorTrackCpuUsage = - "track_operator_cpu_usage"; + VELOX_QUERY_CONFIG( + kOperatorTrackCpuUsage, + operatorTrackCpuUsage, + "track_operator_cpu_usage", + bool, + true, + "Track CPU usage for stages of individual operators.") /// Flags used to configure the CAST operator: - - static constexpr const char* kLegacyCast = "legacy_cast"; + VELOX_QUERY_CONFIG( + kLegacyCast, + isLegacyCast, + "legacy_cast", + bool, + false, + "Use legacy CAST behavior.") /// This flag makes the Row conversion to by applied in a way that the casting /// row field are matched by name instead of position. - static constexpr const char* kCastMatchStructByName = - "cast_match_struct_by_name"; + VELOX_QUERY_CONFIG( + kCastMatchStructByName, + isMatchStructByName, + "cast_match_struct_by_name", + bool, + false, + "Match Row fields by name instead of position in CAST.") /// Reduce() function will throw an error if encountered an array of size /// greater than this. - static constexpr const char* kExprMaxArraySizeInReduce = - "expression.max_array_size_in_reduce"; + VELOX_QUERY_CONFIG( + kExprMaxArraySizeInReduce, + exprMaxArraySizeInReduce, + "expression.max_array_size_in_reduce", + uint64_t, + 100'000, + "Maximum array size allowed in reduce() function.") /// Controls maximum number of compiled regular expression patterns per /// function instance per thread of execution. - static constexpr const char* kExprMaxCompiledRegexes = - "expression.max_compiled_regexes"; + VELOX_QUERY_CONFIG( + kExprMaxCompiledRegexes, + exprMaxCompiledRegexes, + "expression.max_compiled_regexes", + uint64_t, + 100, + "Maximum compiled regex patterns per function instance per thread.") /// Used for backpressure to block local exchange producers when the local /// exchange buffer reaches or exceeds this size. - static constexpr const char* kMaxLocalExchangeBufferSize = - "max_local_exchange_buffer_size"; + VELOX_QUERY_CONFIG( + kMaxLocalExchangeBufferSize, + maxLocalExchangeBufferSize, + "max_local_exchange_buffer_size", + uint64_t, + 32UL << 20, + "Max local exchange buffer size in bytes for backpressure.") /// Limits the number of partitions created by a local exchange. - /// Partitioning data too granularly can lead to poor performance. - /// This setting allows increasing the task concurrency for all - /// pipelines except the ones that require a local partitioning. - /// Affects the number of drivers for pipelines containing - /// LocalPartitionNode and cannot exceed the maximum number of - /// pipeline drivers configured for the task. - static constexpr const char* kMaxLocalExchangePartitionCount = - "max_local_exchange_partition_count"; + VELOX_QUERY_CONFIG( + kMaxLocalExchangePartitionCount, + maxLocalExchangePartitionCount, + "max_local_exchange_partition_count", + uint32_t, + std::numeric_limits::max(), + "Maximum number of partitions in local exchange.") /// Minimum number of local exchange output partitions to use buffered /// partitioning. - /// - /// When the number of output partitions is low, it is preferred to process - /// one input vector at a time. For example, with 10 output partitions - /// splitting a single 100KB input vector into 10 10KB vectors is acceptable. - /// However, when the number of output partitions is high it may result in a - /// large number of tiny vectors generated. For example, with 100 output - /// partitions splitting a single 100KB input vector results in 100 1KB - /// vectors. Exchanging and processing tiny vectors may negatively impact - /// performance. To avoid this, buffered partitioning is used to accumulate - /// larger vectors. - static constexpr const char* - kMinLocalExchangePartitionCountToUsePartitionBuffer = - "min_local_exchange_partition_count_to_use_partition_buffer"; + VELOX_QUERY_CONFIG( + kMinLocalExchangePartitionCountToUsePartitionBuffer, + minLocalExchangePartitionCountToUsePartitionBuffer, + "min_local_exchange_partition_count_to_use_partition_buffer", + uint32_t, + 33, + "Minimum output partitions to use buffered partitioning.") /// Maximum size in bytes to accumulate for a single partition of a local /// exchange before flushing. - /// - /// The total amount of memory used by a single - /// local exchange operator is the sum of the sizes of all partitions. For - /// example, if the number of downstream pipeline drivers is 10 and the max - /// local exchange partition buffer size is 100KB, then the total memory used - /// by a single local exchange operator is 1MB. The total memory needed to - /// perform a local exchange is equal to the single local exchange - /// operator memory multiplied by the number of upstream pipeline drivers. For - /// example, if the number of upstream pipeline drivers is 10 the total memory - /// used by the local exchange operator is 10MB. - static constexpr const char* kMaxLocalExchangePartitionBufferSize = - "max_local_exchange_partition_buffer_size"; + VELOX_QUERY_CONFIG( + kMaxLocalExchangePartitionBufferSize, + maxLocalExchangePartitionBufferSize, + "max_local_exchange_partition_buffer_size", + uint64_t, + 64UL * 1024, + "Max bytes to accumulate per local exchange partition.") /// Try to preserve the encoding of the input vector when copying it to the /// buffer. - static constexpr const char* kLocalExchangePartitionBufferPreserveEncoding = - "local_exchange_partition_buffer_preserve_encoding"; + VELOX_QUERY_CONFIG( + kLocalExchangePartitionBufferPreserveEncoding, + localExchangePartitionBufferPreserveEncoding, + "local_exchange_partition_buffer_preserve_encoding", + bool, + false, + "Preserve encoding of input vector in partition buffer.") /// Maximum number of vectors buffered in each local merge source before /// blocking to wait for consumers. - static constexpr const char* kLocalMergeSourceQueueSize = - "local_merge_source_queue_size"; + VELOX_QUERY_CONFIG( + kLocalMergeSourceQueueSize, + localMergeSourceQueueSize, + "local_merge_source_queue_size", + uint32_t, + 2, + "Maximum vectors buffered in each local merge source.") /// Maximum size in bytes to accumulate in ExchangeQueue. Enforced /// approximately, not strictly. - static constexpr const char* kMaxExchangeBufferSize = - "exchange.max_buffer_size"; + VELOX_QUERY_CONFIG( + kMaxExchangeBufferSize, + maxExchangeBufferSize, + "exchange.max_buffer_size", + uint64_t, + 32UL << 20, + "Maximum bytes to accumulate in ExchangeQueue.") /// Maximum size in bytes to accumulate among all sources of the merge /// exchange. Enforced approximately, not strictly. - static constexpr const char* kMaxMergeExchangeBufferSize = - "merge_exchange.max_buffer_size"; + VELOX_QUERY_CONFIG( + kMaxMergeExchangeBufferSize, + maxMergeExchangeBufferSize, + "merge_exchange.max_buffer_size", + uint64_t, + 128UL << 20, + "Maximum bytes to accumulate among all merge exchange sources.") /// The minimum number of bytes to accumulate in the ExchangeQueue - /// before unblocking a consumer. This is used to avoid creating tiny - /// batches which may have a negative impact on performance when the - /// cost of creating vectors is high (for example, when there are many - /// columns). To avoid latency degradation, the exchange client unblocks a - /// consumer when 1% of the data size observed so far is accumulated. - static constexpr const char* kMinExchangeOutputBatchBytes = - "min_exchange_output_batch_bytes"; - - static constexpr const char* kMaxPartialAggregationMemory = - "max_partial_aggregation_memory"; - - static constexpr const char* kMaxExtendedPartialAggregationMemory = - "max_extended_partial_aggregation_memory"; - - static constexpr const char* kAbandonPartialAggregationMinRows = - "abandon_partial_aggregation_min_rows"; - - static constexpr const char* kAbandonPartialAggregationMinPct = - "abandon_partial_aggregation_min_pct"; - - static constexpr const char* kAbandonPartialTopNRowNumberMinRows = - "abandon_partial_topn_row_number_min_rows"; - - static constexpr const char* kAbandonPartialTopNRowNumberMinPct = - "abandon_partial_topn_row_number_min_pct"; - - static constexpr const char* kMaxElementsSizeInRepeatAndSequence = - "max_elements_size_in_repeat_and_sequence"; + /// before unblocking a consumer. + VELOX_QUERY_CONFIG( + kMinExchangeOutputBatchBytes, + minExchangeOutputBatchBytes, + "min_exchange_output_batch_bytes", + uint64_t, + 2UL << 20, + "Minimum bytes to accumulate before unblocking an exchange consumer.") + + VELOX_QUERY_CONFIG( + kMaxPartialAggregationMemory, + maxPartialAggregationMemoryUsage, + "max_partial_aggregation_memory", + uint64_t, + 1L << 24, + "Maximum memory for partial aggregation.") + + VELOX_QUERY_CONFIG( + kMaxExtendedPartialAggregationMemory, + maxExtendedPartialAggregationMemoryUsage, + "max_extended_partial_aggregation_memory", + uint64_t, + 1L << 26, + "Maximum memory for extended partial aggregation.") + + VELOX_QUERY_CONFIG( + kAbandonPartialAggregationMinRows, + abandonPartialAggregationMinRows, + "abandon_partial_aggregation_min_rows", + int32_t, + 100'000, + "Minimum input rows before checking whether to abandon partial aggregation.") + + VELOX_QUERY_CONFIG( + kAbandonPartialAggregationMinPct, + abandonPartialAggregationMinPct, + "abandon_partial_aggregation_min_pct", + int32_t, + 80, + "Abandon partial aggregation if reduction percentage exceeds this.") + + /// Memory threshold in bytes for triggering string compaction during + /// global aggregation. Disabled by default (0). + VELOX_QUERY_CONFIG( + kAggregationCompactionBytesThreshold, + aggregationCompactionBytesThreshold, + "aggregation_compaction_bytes_threshold", + uint64_t, + 0, + "Memory threshold in bytes for triggering string compaction during aggregation.") + + /// Ratio of unused (evicted) bytes to total bytes that triggers compaction. + VELOX_QUERY_CONFIG( + kAggregationCompactionUnusedMemoryRatio, + aggregationCompactionUnusedMemoryRatio, + "aggregation_compaction_unused_memory_ratio", + double, + 0.25, + "Ratio of unused bytes to total bytes that triggers compaction.") + + /// If true, enables lightweight memory compaction before spilling during + /// memory reclaim in aggregation. + VELOX_QUERY_CONFIG( + kAggregationMemoryCompactionReclaimEnabled, + aggregationMemoryCompactionReclaimEnabled, + "aggregation_memory_compaction_reclaim_enabled", + bool, + false, + "Enable memory compaction before spilling during aggregation reclaim.") + + VELOX_QUERY_CONFIG( + kAbandonPartialTopNRowNumberMinRows, + abandonPartialTopNRowNumberMinRows, + "abandon_partial_topn_row_number_min_rows", + int32_t, + 100'000, + "Minimum rows before checking whether to abandon partial TopN row number.") + + VELOX_QUERY_CONFIG( + kAbandonPartialTopNRowNumberMinPct, + abandonPartialTopNRowNumberMinPct, + "abandon_partial_topn_row_number_min_pct", + int32_t, + 80, + "Abandon partial TopN row number if reduction percentage exceeds this.") + + /// Number of input rows to receive before starting to check whether to + /// abandon building a HashTable without duplicates in HashBuild. + VELOX_QUERY_CONFIG( + kAbandonDedupHashMapMinRows, + abandonHashBuildDedupMinRows, + "abandon_dedup_hashmap_min_rows", + int32_t, + 100'000, + "Minimum rows before checking whether to abandon dedup hash map.") + + /// Abandons building a HashTable without duplicates in HashBuild for left + /// semi/anti join if the percentage of distinct keys exceeds this threshold. + VELOX_QUERY_CONFIG( + kAbandonDedupHashMapMinPct, + abandonHashBuildDedupMinPct, + "abandon_dedup_hashmap_min_pct", + int32_t, + 0, + "Abandon dedup hash map if distinct key percentage exceeds this. 0 disables.") + + VELOX_QUERY_CONFIG( + kMaxElementsSizeInRepeatAndSequence, + maxElementsSizeInRepeatAndSequence, + "max_elements_size_in_repeat_and_sequence", + int32_t, + 10'000, + "Maximum elements size in repeat and sequence functions.") + + /// If true, the PartitionedOutput operator will flush rows eagerly. + VELOX_QUERY_CONFIG( + kPartitionedOutputEagerFlush, + partitionedOutputEagerFlush, + "partitioned_output_eager_flush", + bool, + false, + "Flush PartitionedOutput rows eagerly without buffering.") /// The maximum number of bytes to buffer in PartitionedOutput operator to /// avoid creating tiny SerializedPages. - /// - /// For PartitionedOutputNode::Kind::kPartitioned, PartitionedOutput operator - /// would buffer up to that number of bytes / number of destinations for each - /// destination before producing a SerializedPage. - static constexpr const char* kMaxPartitionedOutputBufferSize = - "max_page_partitioning_buffer_size"; + VELOX_QUERY_CONFIG( + kMaxPartitionedOutputBufferSize, + maxPartitionedOutputBufferSize, + "max_page_partitioning_buffer_size", + uint64_t, + 32UL << 20, + "Maximum bytes to buffer in PartitionedOutput operator.") /// The maximum size in bytes for the task's buffered output. - /// - /// The producer Drivers are blocked when the buffered size exceeds - /// this. The Drivers are resumed when the buffered size goes below - /// OutputBufferManager::kContinuePct % of this. - static constexpr const char* kMaxOutputBufferSize = "max_output_buffer_size"; + VELOX_QUERY_CONFIG( + kMaxOutputBufferSize, + maxOutputBufferSize, + "max_output_buffer_size", + uint64_t, + 32UL << 20, + "Maximum size in bytes for the task's buffered output.") /// Preferred size of batches in bytes to be returned by operators from - /// Operator::getOutput. It is used when an estimate of average row size is - /// known. Otherwise kPreferredOutputBatchRows is used. - static constexpr const char* kPreferredOutputBatchBytes = - "preferred_output_batch_bytes"; + /// Operator::getOutput. + VELOX_QUERY_CONFIG( + kPreferredOutputBatchBytes, + preferredOutputBatchBytes, + "preferred_output_batch_bytes", + uint64_t, + 10UL << 20, + "Preferred size of output batches in bytes.") /// Preferred number of rows to be returned by operators from - /// Operator::getOutput. It is used when an estimate of average row size is - /// not known. When the estimate of average row size is known, - /// kPreferredOutputBatchBytes is used. - static constexpr const char* kPreferredOutputBatchRows = - "preferred_output_batch_rows"; + /// Operator::getOutput. Used when average row size is not known. + VELOX_QUERY_CONFIG_PROPERTY( + kPreferredOutputBatchRows, + "preferred_output_batch_rows", + uint32_t, + 1024, + "Preferred number of rows in output batches.") /// Max number of rows that could be return by operators from - /// Operator::getOutput. It is used when an estimate of average row size is - /// known and kPreferredOutputBatchBytes is used to compute the number of - /// output rows. - static constexpr const char* kMaxOutputBatchRows = "max_output_batch_rows"; + /// Operator::getOutput. + VELOX_QUERY_CONFIG_PROPERTY( + kMaxOutputBatchRows, + "max_output_batch_rows", + uint32_t, + 10000, + "Maximum number of rows in output batches.") + + /// Initial output batch size in rows for MergeJoin operator. + VELOX_QUERY_CONFIG_PROPERTY( + kMergeJoinOutputBatchStartSize, + "merge_join_output_batch_start_size", + uint32_t, + 0, + "Initial output batch size in rows for MergeJoin. 0 disables dynamic adjustment.") /// TableScan operator will exit getOutput() method after this many /// milliseconds even if it has no data to return yet. Zero means 'no time /// limit'. - static constexpr const char* kTableScanGetOutputTimeLimitMs = - "table_scan_getoutput_time_limit_ms"; + VELOX_QUERY_CONFIG( + kTableScanGetOutputTimeLimitMs, + tableScanGetOutputTimeLimitMs, + "table_scan_getoutput_time_limit_ms", + uint32_t, + 5'000, + "Time limit in ms for TableScan getOutput(). 0 means no limit.") + + /// If non-zero, overrides the number of rows in each output batch produced + /// by the TableScan operator. Zero means 'no override'. + VELOX_QUERY_CONFIG( + kTableScanOutputBatchRowsOverride, + tableScanOutputBatchRowsOverride, + "table_scan_output_batch_rows_override", + uint32_t, + 0, + "Override number of rows in TableScan output batches. 0 means no override.") /// If false, the 'group by' code is forced to use generic hash mode /// hashtable. - static constexpr const char* kHashAdaptivityEnabled = - "hash_adaptivity_enabled"; + VELOX_QUERY_CONFIG( + kHashAdaptivityEnabled, + hashAdaptivityEnabled, + "hash_adaptivity_enabled", + bool, + true, + "Enable hash adaptivity for group-by hash mode selection.") /// If true, the conjunction expression can reorder inputs based on the time /// taken to calculate them. - static constexpr const char* kAdaptiveFilterReorderingEnabled = - "adaptive_filter_reordering_enabled"; + VELOX_QUERY_CONFIG( + kAdaptiveFilterReorderingEnabled, + adaptiveFilterReorderingEnabled, + "adaptive_filter_reordering_enabled", + bool, + true, + "Allow conjunction to reorder inputs based on evaluation time.") + + /// If true, allow hash probe drivers to generate build-side rows in parallel. + VELOX_QUERY_CONFIG( + kParallelOutputJoinBuildRowsEnabled, + parallelOutputJoinBuildRowsEnabled, + "parallel_output_join_build_rows_enabled", + bool, + false, + "Allow hash probe drivers to generate build-side rows in parallel.") /// Global enable spilling flag. - static constexpr const char* kSpillEnabled = "spill_enabled"; + VELOX_QUERY_CONFIG( + kSpillEnabled, + spillEnabled, + "spill_enabled", + bool, + false, + "Global enable spilling flag.") /// Aggregation spilling flag, only applies if "spill_enabled" flag is set. - static constexpr const char* kAggregationSpillEnabled = - "aggregation_spill_enabled"; + VELOX_QUERY_CONFIG( + kAggregationSpillEnabled, + aggregationSpillEnabled, + "aggregation_spill_enabled", + bool, + true, + "Enable aggregation spilling. Requires spill_enabled.") /// Join spilling flag, only applies if "spill_enabled" flag is set. - static constexpr const char* kJoinSpillEnabled = "join_spill_enabled"; + VELOX_QUERY_CONFIG( + kJoinSpillEnabled, + joinSpillEnabled, + "join_spill_enabled", + bool, + true, + "Enable join spilling. Requires spill_enabled.") /// Config to enable hash join spill for mixed grouped execution mode. - static constexpr const char* kMixedGroupedModeHashJoinSpillEnabled = - "mixed_grouped_mode_hash_join_spill_enabled"; + VELOX_QUERY_CONFIG( + kMixedGroupedModeHashJoinSpillEnabled, + mixedGroupedModeHashJoinSpillEnabled, + "mixed_grouped_mode_hash_join_spill_enabled", + bool, + false, + "Enable hash join spill for mixed grouped execution mode.") /// OrderBy spilling flag, only applies if "spill_enabled" flag is set. - static constexpr const char* kOrderBySpillEnabled = "order_by_spill_enabled"; + VELOX_QUERY_CONFIG( + kOrderBySpillEnabled, + orderBySpillEnabled, + "order_by_spill_enabled", + bool, + true, + "Enable order-by spilling. Requires spill_enabled.") /// Window spilling flag, only applies if "spill_enabled" flag is set. - static constexpr const char* kWindowSpillEnabled = "window_spill_enabled"; + VELOX_QUERY_CONFIG( + kWindowSpillEnabled, + windowSpillEnabled, + "window_spill_enabled", + bool, + true, + "Enable window spilling. Requires spill_enabled.") + + /// When processing spilled window data, read batches of whole partitions + /// having at least that many rows. + VELOX_QUERY_CONFIG( + kWindowSpillMinReadBatchRows, + windowSpillMinReadBatchRows, + "window_spill_min_read_batch_rows", + uint32_t, + 1'000, + "Minimum rows to read per batch when processing spilled window data.") /// If true, the memory arbitrator will reclaim memory from table writer by - /// flushing its buffered data to disk. only applies if "spill_enabled" flag - /// is set. - static constexpr const char* kWriterSpillEnabled = "writer_spill_enabled"; + /// flushing its buffered data to disk. Requires spill_enabled. + VELOX_QUERY_CONFIG( + kWriterSpillEnabled, + writerSpillEnabled, + "writer_spill_enabled", + bool, + true, + "Enable writer spilling by flushing buffered data to disk.") /// RowNumber spilling flag, only applies if "spill_enabled" flag is set. - static constexpr const char* kRowNumberSpillEnabled = - "row_number_spill_enabled"; + VELOX_QUERY_CONFIG( + kRowNumberSpillEnabled, + rowNumberSpillEnabled, + "row_number_spill_enabled", + bool, + true, + "Enable RowNumber spilling. Requires spill_enabled.") + + /// MarkDistinct spilling flag, only applies if "spill_enabled" flag is set. + VELOX_QUERY_CONFIG( + kMarkDistinctSpillEnabled, + markDistinctSpillEnabled, + "mark_distinct_spill_enabled", + bool, + false, + "Enable MarkDistinct spilling. Requires spill_enabled.") /// TopNRowNumber spilling flag, only applies if "spill_enabled" flag is set. - static constexpr const char* kTopNRowNumberSpillEnabled = - "topn_row_number_spill_enabled"; + VELOX_QUERY_CONFIG( + kTopNRowNumberSpillEnabled, + topNRowNumberSpillEnabled, + "topn_row_number_spill_enabled", + bool, + true, + "Enable TopNRowNumber spilling. Requires spill_enabled.") /// LocalMerge spilling flag, only applies if "spill_enabled" flag is set. - static constexpr const char* kLocalMergeSpillEnabled = - "local_merge_spill_enabled"; + VELOX_QUERY_CONFIG( + kLocalMergeSpillEnabled, + localMergeSpillEnabled, + "local_merge_spill_enabled", + bool, + false, + "Enable LocalMerge spilling. Requires spill_enabled.") /// Specify the max number of local sources to merge at a time. + /// Uses std::numeric_limits::max() as default. static constexpr const char* kLocalMergeMaxNumMergeSources = "local_merge_max_num_merge_sources"; - /// The max row numbers to fill and spill for each spill run. This is used to - /// cap the memory used for spilling. If it is zero, then there is no limit - /// and spilling might run out of memory. - /// Based on offline test results, the default value is set to 12 million rows - /// which uses ~128MB memory when to fill a spill run. - static constexpr const char* kMaxSpillRunRows = "max_spill_run_rows"; - - /// The max spill bytes limit set for each query. This is used to cap the - /// storage used for spilling. If it is zero, then there is no limit and - /// spilling might exhaust the storage or takes too long to run. The default - /// value is set to 100 GB. - static constexpr const char* kMaxSpillBytes = "max_spill_bytes"; + /// The max row numbers to fill and spill for each spill run. Default is + /// 12 million rows which uses ~128MB memory. + VELOX_QUERY_CONFIG( + kMaxSpillRunRows, + maxSpillRunRows, + "max_spill_run_rows", + uint64_t, + 12UL << 20, + "Maximum rows to fill and spill per spill run. 0 means no limit.") + + /// The max spill bytes limit set for each query. Default is 100 GB. + VELOX_QUERY_CONFIG( + kMaxSpillBytes, + maxSpillBytes, + "max_spill_bytes", + uint64_t, + 100UL << 30, + "Maximum total spill bytes per query. 0 means no limit.") /// The max allowed spilling level with zero being the initial spilling level. - /// This only applies for hash build spilling which might trigger recursive - /// spilling when the build table is too big. If it is set to -1, then there - /// is no limit and then some extreme large query might run out of spilling - /// partition bits (see kSpillPartitionBits) at the end. The max spill level - /// is used in production to prevent some bad user queries from using too much - /// io and cpu resources. - static constexpr const char* kMaxSpillLevel = "max_spill_level"; + VELOX_QUERY_CONFIG( + kMaxSpillLevel, + maxSpillLevel, + "max_spill_level", + int32_t, + 1, + "Maximum allowed spilling level. -1 means no limit.") /// The max allowed spill file size. If it is zero, then there is no limit. - static constexpr const char* kMaxSpillFileSize = "max_spill_file_size"; - - static constexpr const char* kSpillCompressionKind = - "spill_compression_codec"; - - /// Enable the prefix sort or fallback to timsort in spill. The prefix sort is - /// faster than std::sort but requires the memory to build normalized prefix - /// keys, which might have potential risk of running out of server memory. - static constexpr const char* kSpillPrefixSortEnabled = - "spill_prefixsort_enabled"; - - /// Specifies spill write buffer size in bytes. The spiller tries to buffer - /// serialized spill data up to the specified size before write to storage - /// underneath for io efficiency. If it is set to zero, then spill write - /// buffering is disabled. - static constexpr const char* kSpillWriteBufferSize = - "spill_write_buffer_size"; - - /// Specifies the buffer size in bytes to read from one spilled file. If the - /// underlying filesystem supports async read, we do read-ahead with double - /// buffering, which doubles the buffer used to read from each spill file. - static constexpr const char* kSpillReadBufferSize = "spill_read_buffer_size"; - - /// Config used to create spill files. This config is provided to underlying - /// file system and the config is free form. The form should be defined by the - /// underlying file system. - static constexpr const char* kSpillFileCreateConfig = - "spill_file_create_config"; - - /// Default offset spill start partition bit. It is used with - /// 'kSpillNumPartitionBits' together to - /// calculate the spilling partition number for join spill or aggregation - /// spill. - static constexpr const char* kSpillStartPartitionBit = - "spiller_start_partition_bit"; - - /// Default number of spill partition bits. It is the number of bits used to - /// calculate the spill partition number for hash join and RowNumber. The - /// number of spill partitions will be power of two. - /// - /// NOTE: as for now, we only support up to 8-way spill partitioning. - static constexpr const char* kSpillNumPartitionBits = - "spiller_num_partition_bits"; + VELOX_QUERY_CONFIG( + kMaxSpillFileSize, + maxSpillFileSize, + "max_spill_file_size", + uint64_t, + 0, + "Maximum spill file size. 0 means no limit.") + + VELOX_QUERY_CONFIG( + kSpillCompressionKind, + spillCompressionKind, + "spill_compression_codec", + std::string, + "none", + "Compression codec for spill data.") + + /// The max number of files to merge at a time when merging sorted spilled + /// files. 0 means unlimited. + VELOX_QUERY_CONFIG( + kSpillNumMaxMergeFiles, + spillNumMaxMergeFiles, + "spill_num_max_merge_files", + uint32_t, + 0, + "Maximum files to merge at a time when merging sorted spill files. 0 means unlimited.") + + /// Enable the prefix sort or fallback to timsort in spill. + VELOX_QUERY_CONFIG( + kSpillPrefixSortEnabled, + spillPrefixSortEnabled, + "spill_prefixsort_enabled", + bool, + false, + "Enable prefix sort in spill instead of timsort.") + + /// Specifies spill write buffer size in bytes. 0 disables buffering. + VELOX_QUERY_CONFIG( + kSpillWriteBufferSize, + spillWriteBufferSize, + "spill_write_buffer_size", + uint64_t, + 1L << 20, + "Spill write buffer size in bytes. 0 disables buffering.") + + /// Specifies the buffer size in bytes to read from one spilled file. + VELOX_QUERY_CONFIG( + kSpillReadBufferSize, + spillReadBufferSize, + "spill_read_buffer_size", + uint64_t, + 1L << 20, + "Buffer size in bytes to read from one spilled file.") + + /// Config used to create spill files. + VELOX_QUERY_CONFIG( + kSpillFileCreateConfig, + spillFileCreateConfig, + "spill_file_create_config", + std::string, + "", + "Config for creating spill files, passed to underlying file system.") + + /// Config used to create aggregation spill files. + VELOX_QUERY_CONFIG( + kAggregationSpillFileCreateConfig, + aggregationSpillFileCreateConfig, + "aggregation_spill_file_create_config", + std::string, + "", + "Config for creating aggregation spill files.") + + /// Config used to create hash join spill files. + VELOX_QUERY_CONFIG( + kHashJoinSpillFileCreateConfig, + hashJoinSpillFileCreateConfig, + "hash_join_spill_file_create_config", + std::string, + "", + "Config for creating hash join spill files.") + + /// Config used to create row number spill files. + VELOX_QUERY_CONFIG( + kRowNumberSpillFileCreateConfig, + rowNumberSpillFileCreateConfig, + "row_number_spill_file_create_config", + std::string, + "", + "Config for creating row number spill files.") + + /// Default offset spill start partition bit. + VELOX_QUERY_CONFIG( + kSpillStartPartitionBit, + spillStartPartitionBit, + "spiller_start_partition_bit", + uint8_t, + 48, + "Start partition bit offset for spilling.") + + /// Default number of spill partition bits. Max 3 (8-way partitioning). + VELOX_QUERY_CONFIG_PROPERTY( + kSpillNumPartitionBits, + "spiller_num_partition_bits", + uint8_t, + 3, + "Number of bits for spill partition calculation.") /// The minimal available spillable memory reservation in percentage of the - /// current memory usage. Suppose the current memory usage size of M, - /// available memory reservation size of N and min reservation percentage of - /// P, if M * P / 100 > N, then spiller operator needs to grow the memory - /// reservation with percentage of spillableReservationGrowthPct(). This - /// ensures we have sufficient amount of memory reservation to process the - /// large input outlier. - static constexpr const char* kMinSpillableReservationPct = - "min_spillable_reservation_pct"; - - /// The spillable memory reservation growth percentage of the previous memory - /// reservation size. 10 means exponential growth along a series of integer - /// powers of 11/10. The reservation grows by this much until it no longer - /// can, after which it starts spilling. - static constexpr const char* kSpillableReservationGrowthPct = - "spillable_reservation_growth_pct"; + /// current memory usage. + VELOX_QUERY_CONFIG( + kMinSpillableReservationPct, + minSpillableReservationPct, + "min_spillable_reservation_pct", + int32_t, + 5, + "Minimum available spillable memory reservation as percentage of usage.") + + /// The spillable memory reservation growth percentage. + VELOX_QUERY_CONFIG( + kSpillableReservationGrowthPct, + spillableReservationGrowthPct, + "spillable_reservation_growth_pct", + int32_t, + 10, + "Spillable memory reservation growth percentage.") /// Minimum memory footprint size required to reclaim memory from a file /// writer by flushing its buffered data to disk. - static constexpr const char* kWriterFlushThresholdBytes = - "writer_flush_threshold_bytes"; + VELOX_QUERY_CONFIG( + kWriterFlushThresholdBytes, + writerFlushThresholdBytes, + "writer_flush_threshold_bytes", + uint64_t, + 96L << 20, + "Minimum memory footprint to reclaim from a file writer by flushing.") /// If true, array_agg() aggregation function will ignore nulls in the input. - static constexpr const char* kPrestoArrayAggIgnoreNulls = - "presto.array_agg.ignore_nulls"; - - /// If true, Spark function's behavior is ANSI-compliant, e.g. throws runtime - /// exception instead of returning null on invalid inputs. It affects only - /// functions explicitly marked as "ANSI compliant". - /// Note: This feature is still under development to achieve full ANSI - /// compliance. Users can refer to the Spark function documentation to verify - /// the current support status of a specific function. - static constexpr const char* kSparkAnsiEnabled = "spark.ansi_enabled"; + VELOX_QUERY_CONFIG( + kPrestoArrayAggIgnoreNulls, + prestoArrayAggIgnoreNulls, + "presto.array_agg.ignore_nulls", + bool, + false, + "If true, array_agg() ignores nulls in the input.") + + /// If true, Spark function's behavior is ANSI-compliant. + VELOX_QUERY_CONFIG( + kSparkAnsiEnabled, + sparkAnsiEnabled, + "spark.ansi_enabled", + bool, + false, + "Enable ANSI-compliant behavior for Spark functions.") /// The default number of expected items for the bloomfilter. - static constexpr const char* kSparkBloomFilterExpectedNumItems = - "spark.bloom_filter.expected_num_items"; + VELOX_QUERY_CONFIG( + kSparkBloomFilterExpectedNumItems, + sparkBloomFilterExpectedNumItems, + "spark.bloom_filter.expected_num_items", + int64_t, + 1'000'000L, + "Default number of expected items for the Spark bloom filter.") /// The default number of bits to use for the bloom filter. - static constexpr const char* kSparkBloomFilterNumBits = - "spark.bloom_filter.num_bits"; + VELOX_QUERY_CONFIG( + kSparkBloomFilterNumBits, + sparkBloomFilterNumBits, + "spark.bloom_filter.num_bits", + int64_t, + 8'388'608L, + "Default number of bits for the Spark bloom filter.") /// The max number of bits to use for the bloom filter. - static constexpr const char* kSparkBloomFilterMaxNumBits = - "spark.bloom_filter.max_num_bits"; - - /// The current spark partition id. + VELOX_QUERY_CONFIG( + kSparkBloomFilterMaxNumBits, + sparkBloomFilterMaxNumBits, + "spark.bloom_filter.max_num_bits", + int64_t, + 67'108'864, + "Maximum number of bits for the Spark bloom filter.") + + /// The max number of items to use for the bloom filter. + VELOX_QUERY_CONFIG( + kSparkBloomFilterMaxNumItems, + sparkBloomFilterMaxNumItems, + "spark.bloom_filter.max_num_items", + int64_t, + 4'000'000L, + "Maximum number of items for the Spark bloom filter.") + + /// The current spark partition id. No default (throws if not set). static constexpr const char* kSparkPartitionId = "spark.partition_id"; /// If true, simple date formatter is used for time formatting and parsing. - /// Joda date formatter is used by default. - static constexpr const char* kSparkLegacyDateFormatter = - "spark.legacy_date_formatter"; - - /// If true, Spark statistical aggregation functions including skewness, - /// kurtosis, stddev, stddev_samp, variance, var_samp, covar_samp and corr - /// will return NaN instead of NULL when dividing by zero during expression - /// evaluation. - static constexpr const char* kSparkLegacyStatisticalAggregate = - "spark.legacy_statistical_aggregate"; + VELOX_QUERY_CONFIG( + kSparkLegacyDateFormatter, + sparkLegacyDateFormatter, + "spark.legacy_date_formatter", + bool, + false, + "Use simple date formatter instead of Joda for Spark.") + + /// If true, Spark statistical aggregation functions return NaN instead of + /// NULL when dividing by zero. + VELOX_QUERY_CONFIG( + kSparkLegacyStatisticalAggregate, + sparkLegacyStatisticalAggregate, + "spark.legacy_statistical_aggregate", + bool, + false, + "Return NaN instead of NULL for Spark statistical aggregation on divide-by-zero.") /// If true, ignore null fields when generating JSON string. - /// If false, null fields are included with a null value. - static constexpr const char* kSparkJsonIgnoreNullFields = - "spark.json_ignore_null_fields"; + VELOX_QUERY_CONFIG( + kSparkJsonIgnoreNullFields, + sparkJsonIgnoreNullFields, + "spark.json_ignore_null_fields", + bool, + true, + "Ignore null fields when generating JSON string in Spark.") + + /// If true, collect_list aggregate function will ignore nulls in the input. + VELOX_QUERY_CONFIG( + kSparkCollectListIgnoreNulls, + sparkCollectListIgnoreNulls, + "spark.collect_list.ignore_nulls", + bool, + true, + "If true, Spark collect_list() ignores nulls in the input.") /// The number of local parallel table writer operators per task. - static constexpr const char* kTaskWriterCount = "task_writer_count"; + VELOX_QUERY_CONFIG( + kTaskWriterCount, + taskWriterCount, + "task_writer_count", + uint32_t, + 4, + "Number of local parallel table writer operators per task.") /// The number of local parallel table writer operators per task for - /// partitioned writes. If not set, use "task_writer_count". + /// partitioned writes. If not set, use "task_writer_count". No default. static constexpr const char* kTaskPartitionedWriterCount = "task_partitioned_writer_count"; /// If true, finish the hash probe on an empty build table for a specific set /// of hash joins. - static constexpr const char* kHashProbeFinishEarlyOnEmptyBuild = - "hash_probe_finish_early_on_empty_build"; + VELOX_QUERY_CONFIG( + kHashProbeFinishEarlyOnEmptyBuild, + hashProbeFinishEarlyOnEmptyBuild, + "hash_probe_finish_early_on_empty_build", + bool, + false, + "Finish hash probe early on empty build table.") + + /// Whether hash probe can generate any dynamic filter and push down to + /// upstream operators. + VELOX_QUERY_CONFIG( + kHashProbeDynamicFilterPushdownEnabled, + hashProbeDynamicFilterPushdownEnabled, + "hash_probe_dynamic_filter_pushdown_enabled", + bool, + true, + "Enable dynamic filter generation and pushdown from hash probe.") + + /// Whether hash probe can generate dynamic filter for string types. + VELOX_QUERY_CONFIG( + kHashProbeStringDynamicFilterPushdownEnabled, + hashProbeStringDynamicFilterPushdownEnabled, + "hash_probe_string_dynamic_filter_pushdown_enabled", + bool, + false, + "Enable dynamic filter pushdown for string types from hash probe.") + + /// The maximum byte size of Bloom filter from hash probe. 0 = disabled. + VELOX_QUERY_CONFIG( + kHashProbeBloomFilterPushdownMaxSize, + hashProbeBloomFilterPushdownMaxSize, + "hash_probe_bloom_filter_pushdown_max_size", + uint64_t, + 0, + "Maximum byte size of Bloom filter from hash probe. 0 disables.") /// The minimum number of table rows that can trigger the parallel hash join /// table build. - static constexpr const char* kMinTableRowsForParallelJoinBuild = - "min_table_rows_for_parallel_join_build"; - - /// If set to true, then during execution of tasks, the output vectors of - /// every operator are validated for consistency. This is an expensive check - /// so should only be used for debugging. It can help debug issues where - /// malformed vector cause failures or crashes by helping identify which - /// operator is generating them. - static constexpr const char* kValidateOutputFromOperators = - "debug.validate_output_from_operators"; - - /// If true, enable caches in expression evaluation for performance, including - /// ExecCtx::vectorPool_, ExecCtx::decodedVectorPool_, - /// ExecCtx::selectivityVectorPool_, Expr::baseDictionary_, - /// Expr::dictionaryCache_, and Expr::cachedDictionaryIndices_. Otherwise, - /// disable the caches. - static constexpr const char* kEnableExpressionEvaluationCache = - "enable_expression_evaluation_cache"; + VELOX_QUERY_CONFIG( + kMinTableRowsForParallelJoinBuild, + minTableRowsForParallelJoinBuild, + "min_table_rows_for_parallel_join_build", + uint32_t, + 1'000, + "Minimum table rows to trigger parallel hash join table build.") + + /// If set to true, validate output vectors of every operator for consistency. + VELOX_QUERY_CONFIG( + kValidateOutputFromOperators, + validateOutputFromOperators, + "debug.validate_output_from_operators", + bool, + false, + "Validate output vectors from every operator for consistency.") + + /// If true, enable caches in expression evaluation for performance. + VELOX_QUERY_CONFIG( + kEnableExpressionEvaluationCache, + isExpressionEvaluationCacheEnabled, + "enable_expression_evaluation_cache", + bool, + true, + "Enable caches in expression evaluation for performance.") /// For a given shared subexpression, the maximum distinct sets of inputs we - /// cache results for. Lambdas can call the same expression with different - /// inputs many times, causing the results we cache to explode in size. - /// Putting a limit contains the memory usage. - static constexpr const char* kMaxSharedSubexprResultsCached = - "max_shared_subexpr_results_cached"; + /// cache results for. + VELOX_QUERY_CONFIG( + kMaxSharedSubexprResultsCached, + maxSharedSubexprResultsCached, + "max_shared_subexpr_results_cached", + uint32_t, + 10, + "Maximum distinct input sets to cache for shared subexpressions.") /// Maximum number of splits to preload. Set to 0 to disable preloading. - static constexpr const char* kMaxSplitPreloadPerDriver = - "max_split_preload_per_driver"; + VELOX_QUERY_CONFIG( + kMaxSplitPreloadPerDriver, + maxSplitPreloadPerDriver, + "max_split_preload_per_driver", + int32_t, + 2, + "Maximum number of splits to preload. 0 disables preloading.") /// If not zero, specifies the cpu time slice limit in ms that a driver thread - /// can continuously run without yielding. If it is zero, then there is no - /// limit. - static constexpr const char* kDriverCpuTimeSliceLimitMs = - "driver_cpu_time_slice_limit_ms"; - - /// Maximum number of bytes to use for the normalized key in prefix-sort. Use - /// 0 to disable prefix-sort. - static constexpr const char* kPrefixSortNormalizedKeyMaxBytes = - "prefixsort_normalized_key_max_bytes"; - - /// Minimum number of rows to use prefix-sort. The default value has been - /// derived using micro-benchmarking. - static constexpr const char* kPrefixSortMinRows = "prefixsort_min_rows"; + /// can continuously run without yielding. + VELOX_QUERY_CONFIG( + kDriverCpuTimeSliceLimitMs, + driverCpuTimeSliceLimitMs, + "driver_cpu_time_slice_limit_ms", + uint32_t, + 0, + "CPU time slice limit in ms before yielding. 0 means no limit.") + + /// Window operator sub-partition count per thread. + VELOX_QUERY_CONFIG( + kWindowNumSubPartitions, + windowNumSubPartitions, + "window_num_sub_partitions", + uint32_t, + 1, + "Number of sub-partitions per thread in Window operator.") + + /// Maximum number of bytes to use for the normalized key in prefix-sort. + VELOX_QUERY_CONFIG( + kPrefixSortNormalizedKeyMaxBytes, + prefixSortNormalizedKeyMaxBytes, + "prefixsort_normalized_key_max_bytes", + uint32_t, + 128, + "Maximum bytes for normalized key in prefix-sort. 0 disables.") + + /// Minimum number of rows to use prefix-sort. + VELOX_QUERY_CONFIG( + kPrefixSortMinRows, + prefixSortMinRows, + "prefixsort_min_rows", + uint32_t, + 128, + "Minimum rows to use prefix-sort.") /// Maximum number of bytes to be stored in prefix-sort buffer for a string /// key. - static constexpr const char* kPrefixSortMaxStringPrefixLength = - "prefixsort_max_string_prefix_length"; + VELOX_QUERY_CONFIG( + kPrefixSortMaxStringPrefixLength, + prefixSortMaxStringPrefixLength, + "prefixsort_max_string_prefix_length", + uint32_t, + 16, + "Maximum bytes stored in prefix-sort buffer for a string key.") /// Enable query tracing flag. - static constexpr const char* kQueryTraceEnabled = "query_trace_enabled"; + VELOX_QUERY_CONFIG( + kQueryTraceEnabled, + queryTraceEnabled, + "query_trace_enabled", + bool, + false, + "Enable query tracing.") /// Base dir of a query to store tracing data. - static constexpr const char* kQueryTraceDir = "query_trace_dir"; + VELOX_QUERY_CONFIG( + kQueryTraceDir, + queryTraceDir, + "query_trace_dir", + std::string, + "", + "Base directory for storing query trace data.") /// The plan node id whose input data will be traced. - /// Empty string if only want to trace the query metadata. - static constexpr const char* kQueryTraceNodeId = "query_trace_node_id"; + VELOX_QUERY_CONFIG( + kQueryTraceNodeId, + queryTraceNodeId, + "query_trace_node_id", + std::string, + "", + "Plan node id whose input data will be traced.") /// The max trace bytes limit. Tracing is disabled if zero. - static constexpr const char* kQueryTraceMaxBytes = "query_trace_max_bytes"; - - /// The regexp of traced task id. We only enable trace on a task if its id - /// matches. - static constexpr const char* kQueryTraceTaskRegExp = - "query_trace_task_reg_exp"; - - /// If true, we only collect the input trace for a given operator but without - /// the actual execution. - static constexpr const char* kQueryTraceDryRun = "query_trace_dry_run"; - - /// Config used to create operator trace directory. This config is provided to - /// underlying file system and the config is free form. The form should be - /// defined by the underlying file system. - static constexpr const char* kOpTraceDirectoryCreateConfig = - "op_trace_directory_create_config"; + VELOX_QUERY_CONFIG( + kQueryTraceMaxBytes, + queryTraceMaxBytes, + "query_trace_max_bytes", + uint64_t, + 0, + "Maximum trace bytes. Tracing disabled if zero.") + + /// The regexp of traced task id. + VELOX_QUERY_CONFIG( + kQueryTraceTaskRegExp, + queryTraceTaskRegExp, + "query_trace_task_reg_exp", + std::string, + "", + "Regexp for task ids to enable tracing on.") + + /// If true, only collect input trace without actual execution. + VELOX_QUERY_CONFIG( + kQueryTraceDryRun, + queryTraceDryRun, + "query_trace_dry_run", + bool, + false, + "Collect input trace without actual execution.") + + /// Config used to create operator trace directory. + VELOX_QUERY_CONFIG( + kOpTraceDirectoryCreateConfig, + opTraceDirectoryCreateConfig, + "op_trace_directory_create_config", + std::string, + "", + "Config for creating operator trace directory.") /// Disable optimization in expression evaluation to peel common dictionary /// layer from inputs. - static constexpr const char* kDebugDisableExpressionWithPeeling = - "debug_disable_expression_with_peeling"; + VELOX_QUERY_CONFIG( + kDebugDisableExpressionWithPeeling, + debugDisableExpressionsWithPeeling, + "debug_disable_expression_with_peeling", + bool, + false, + "Disable dictionary peeling optimization in expression evaluation.") + + /// Minimum number of rows in the selectivity vector for peeling to be + /// applied during expression evaluation. For small batches, the overhead of + /// peeling can outweigh the benefits. + VELOX_QUERY_CONFIG( + kMinRowsForPeeling, + minRowsForPeeling, + "expression.min_rows_for_peeling", + int32_t, + 0, + "Minimum number of rows to process for peeling optimization in expression evaluation to be active."); /// Disable optimization in expression evaluation to re-use cached results for /// common sub-expressions. - static constexpr const char* kDebugDisableCommonSubExpressions = - "debug_disable_common_sub_expressions"; + VELOX_QUERY_CONFIG( + kDebugDisableCommonSubExpressions, + debugDisableCommonSubExpressions, + "debug_disable_common_sub_expressions", + bool, + false, + "Disable common sub-expression caching in expression evaluation.") /// Disable optimization in expression evaluation to re-use cached results - /// between subsequent input batches that are dictionary encoded and have the - /// same alphabet(underlying flat vector). - static constexpr const char* kDebugDisableExpressionWithMemoization = - "debug_disable_expression_with_memoization"; + /// between subsequent dictionary-encoded input batches. + VELOX_QUERY_CONFIG( + kDebugDisableExpressionWithMemoization, + debugDisableExpressionsWithMemoization, + "debug_disable_expression_with_memoization", + bool, + false, + "Disable dictionary memoization in expression evaluation.") /// Disable optimization in expression evaluation to delay loading of lazy /// inputs unless required. - static constexpr const char* kDebugDisableExpressionWithLazyInputs = - "debug_disable_expression_with_lazy_inputs"; - - /// Fix the random seed used to create data structure used in - /// approx_percentile. This makes the query result deterministic on single - /// node; multi-node partial aggregation is still subject to non-determinism - /// due to non-deterministic merge order. + VELOX_QUERY_CONFIG( + kDebugDisableExpressionWithLazyInputs, + debugDisableExpressionsWithLazyInputs, + "debug_disable_expression_with_lazy_inputs", + bool, + false, + "Disable lazy input loading optimization in expression evaluation.") + + /// Fix the random seed used in approx_percentile. No default (optional). static constexpr const char* kDebugAggregationApproxPercentileFixedRandomSeed = "debug_aggregation_approx_percentile_fixed_random_seed"; - /// When debug is enabled for memory manager, this is used to match the memory - /// pools that need allocation callsites tracking. Default to track nothing. - static constexpr const char* kDebugMemoryPoolNameRegex = - "debug_memory_pool_name_regex"; - - /// Warning threshold in bytes for debug memory pools. When set to a - /// non-zero value, a warning will be logged once per memory pool when - /// allocations cause the pool to exceed this threshold. This is useful for - /// identifying memory usage patterns during debugging. Requires allocation - /// tracking to be enabled via `debug_memory_pool_name_regex` for the pool. A - /// value of 0 means no warning threshold is enforced. - static constexpr const char* kDebugMemoryPoolWarnThresholdBytes = - "debug_memory_pool_warn_threshold_bytes"; - - /// Some lambda functions over arrays and maps are evaluated in batches of the - /// underlying elements that comprise the arrays/maps. This is done to make - /// the batch size manageable as array vectors can have thousands of elements - /// each and hit scaling limits as implementations typically expect - /// BaseVectors to a couple of thousand entries. This lets up tune those batch - /// sizes. - static constexpr const char* kDebugLambdaFunctionEvaluationBatchSize = - "debug_lambda_function_evaluation_batch_size"; - - /// The UDF `bing_tile_children` generates the children of a Bing tile based - /// on a specified target zoom level. The number of children produced is - /// determined by the difference between the target zoom level and the zoom - /// level of the input tile. This configuration limits the number of children - /// by capping the maximum zoom level difference, with a default value set - /// to 5. This cap is necessary to prevent excessively large array outputs, - /// which can exceed the size limits of the elements vector in the Velox array - /// vector. - static constexpr const char* kDebugBingTileChildrenMaxZoomShift = - "debug_bing_tile_children_max_zoom_shift"; - - /// Temporary flag to control whether selective Nimble reader should be used - /// in this query or not. Will be removed after the selective Nimble reader - /// is fully rolled out. - static constexpr const char* kSelectiveNimbleReaderEnabled = - "selective_nimble_reader_enabled"; - - /// The max ratio of a query used memory to its max capacity, and the scale - /// writer exchange stops scaling writer processing if the query's current - /// memory usage exceeds this ratio. The value is in the range of (0, 1]. - static constexpr const char* kScaleWriterRebalanceMaxMemoryUsageRatio = - "scaled_writer_rebalance_max_memory_usage_ratio"; - - /// The max number of logical table partitions that can be assigned to a - /// single table writer thread. The logical table partition is used by local - /// exchange writer for writer scaling, and multiple physical table - /// partitions can be mapped to the same logical table partition based on the - /// hash value of calculated partitioned ids. - static constexpr const char* kScaleWriterMaxPartitionsPerWriter = - "scaled_writer_max_partitions_per_writer"; - - /// Minimum amount of data processed by a logical table partition to trigger - /// writer scaling if it is detected as overloaded by scale wrirer exchange. - static constexpr const char* - kScaleWriterMinPartitionProcessedBytesRebalanceThreshold = - "scaled_writer_min_partition_processed_bytes_rebalance_threshold"; - - /// Minimum amount of data processed by all the logical table partitions to - /// trigger skewed partition rebalancing by scale writer exchange. - static constexpr const char* kScaleWriterMinProcessedBytesRebalanceThreshold = - "scaled_writer_min_processed_bytes_rebalance_threshold"; - - /// If true, enables the scaled table scan processing. For each table scan - /// plan node, a scan controller is used to control the number of running scan - /// threads based on the query memory usage. It keeps increasing the number of - /// running threads until the query memory usage exceeds the threshold defined - /// by 'table_scan_scale_up_memory_usage_ratio'. - static constexpr const char* kTableScanScaledProcessingEnabled = - "table_scan_scaled_processing_enabled"; - - /// The query memory usage ratio used by scan controller to decide if it can - /// increase the number of running scan threads. When the query memory usage - /// is below this ratio, the scan controller keeps increasing the running scan - /// thread for scale up, and stop once exceeds this ratio. The value is in the - /// range of [0, 1]. - /// - /// NOTE: this only applies if 'table_scan_scaled_processing_enabled' is true. - static constexpr const char* kTableScanScaleUpMemoryUsageRatio = - "table_scan_scale_up_memory_usage_ratio"; - - /// Specifies the shuffle compression kind which is defined by - /// CompressionKind. If it is CompressionKind_NONE, then no compression. - static constexpr const char* kShuffleCompressionKind = - "shuffle_compression_codec"; - - /// If a key is found in multiple given maps, by default that key's value in - /// the resulting map comes from the last one of those maps. When true, throw - /// exception on duplicate map key. - static constexpr const char* kThrowExceptionOnDuplicateMapKeys = - "throw_exception_on_duplicate_map_keys"; - - /// Specifies the max number of input batches to prefetch to do index lookup - /// ahead. If it is zero, then process one input batch at a time. - static constexpr const char* kIndexLookupJoinMaxPrefetchBatches = - "index_lookup_join_max_prefetch_batches"; - - /// If this is true, then the index join operator might split output for each - /// input batch based on the output batch size control. Otherwise, it tries to - /// produce a single output for each input batch. - static constexpr const char* kIndexLookupJoinSplitOutput = - "index_lookup_join_split_output"; + /// Regex to match memory pools for allocation callsite tracking. + VELOX_QUERY_CONFIG( + kDebugMemoryPoolNameRegex, + debugMemoryPoolNameRegex, + "debug_memory_pool_name_regex", + std::string, + "", + "Regex to match memory pools for allocation callsite tracking.") + + /// Warning threshold in bytes for debug memory pools. Uses toCapacity + /// for parsing. + VELOX_QUERY_CONFIG_PROPERTY( + kDebugMemoryPoolWarnThresholdBytes, + "debug_memory_pool_warn_threshold_bytes", + std::string, + "0B", + "Warning threshold in bytes for debug memory pools.") + + /// Batch size for lambda function evaluation over arrays and maps. + VELOX_QUERY_CONFIG( + kDebugLambdaFunctionEvaluationBatchSize, + debugLambdaFunctionEvaluationBatchSize, + "debug_lambda_function_evaluation_batch_size", + int32_t, + 10'000, + "Batch size for lambda function evaluation over arrays and maps.") + + /// Max zoom level difference for bing_tile_children UDF. + VELOX_QUERY_CONFIG( + kDebugBingTileChildrenMaxZoomShift, + debugBingTileChildrenMaxZoomShift, + "debug_bing_tile_children_max_zoom_shift", + uint8_t, + 7, + "Maximum zoom level difference for bing_tile_children.") + + /// Deprecated: Use FileConfig::kSelectiveNimbleReaderEnabledSession instead. + /// Kept for backward compatibility with Presto session properties. + VELOX_QUERY_CONFIG( + kSelectiveNimbleReaderEnabled, + selectiveNimbleReaderEnabled, + "selective_nimble_reader_enabled", + bool, + true, + "Enable selective Nimble reader.") + + /// The max ratio of query memory usage to max capacity for scale writer + /// exchange to stop scaling. + VELOX_QUERY_CONFIG( + kScaleWriterRebalanceMaxMemoryUsageRatio, + scaleWriterRebalanceMaxMemoryUsageRatio, + "scaled_writer_rebalance_max_memory_usage_ratio", + double, + 0.7, + "Max memory usage ratio before scale writer exchange stops scaling.") + + /// The max number of logical table partitions per writer thread. + VELOX_QUERY_CONFIG( + kScaleWriterMaxPartitionsPerWriter, + scaleWriterMaxPartitionsPerWriter, + "scaled_writer_max_partitions_per_writer", + uint32_t, + 128, + "Max logical table partitions per table writer thread.") + + /// Minimum data processed per partition to trigger writer scaling. + VELOX_QUERY_CONFIG( + kScaleWriterMinPartitionProcessedBytesRebalanceThreshold, + scaleWriterMinPartitionProcessedBytesRebalanceThreshold, + "scaled_writer_min_partition_processed_bytes_rebalance_threshold", + uint64_t, + 128 << 20, + "Minimum bytes processed per partition to trigger writer scaling.") + + /// Minimum total data processed to trigger skewed partition rebalancing. + VELOX_QUERY_CONFIG( + kScaleWriterMinProcessedBytesRebalanceThreshold, + scaleWriterMinProcessedBytesRebalanceThreshold, + "scaled_writer_min_processed_bytes_rebalance_threshold", + uint64_t, + 256 << 20, + "Minimum total bytes processed to trigger skewed partition rebalancing.") + + /// If true, enables the scaled table scan processing. + VELOX_QUERY_CONFIG( + kTableScanScaledProcessingEnabled, + tableScanScaledProcessingEnabled, + "table_scan_scaled_processing_enabled", + bool, + false, + "Enable scaled table scan processing based on memory usage.") + + /// The query memory usage ratio for scan controller to decide if it can + /// increase running scan threads. + VELOX_QUERY_CONFIG( + kTableScanScaleUpMemoryUsageRatio, + tableScanScaleUpMemoryUsageRatio, + "table_scan_scale_up_memory_usage_ratio", + double, + 0.7, + "Memory usage ratio threshold for scaling up scan threads.") + + /// Specifies the shuffle compression kind. + VELOX_QUERY_CONFIG( + kShuffleCompressionKind, + shuffleCompressionKind, + "shuffle_compression_codec", + std::string, + "none", + "Compression codec for shuffle data.") + + /// Minimum serialized page size in bytes to attempt shuffle compression. + VELOX_QUERY_CONFIG( + kMinShuffleCompressionPageSizeBytes, + minShuffleCompressionPageSizeBytes, + "min_shuffle_compression_page_size_bytes", + int32_t, + 0, + "Minimum serialized page size in bytes to attempt shuffle compression.") + + /// When true, throw exception on duplicate map key. + VELOX_QUERY_CONFIG( + kThrowExceptionOnDuplicateMapKeys, + throwExceptionOnDuplicateMapKeys, + "throw_exception_on_duplicate_map_keys", + bool, + false, + "Throw exception on duplicate map keys.") + + /// Max number of input batches to prefetch for index lookup. 0 = disabled. + VELOX_QUERY_CONFIG( + kIndexLookupJoinMaxPrefetchBatches, + indexLookupJoinMaxPrefetchBatches, + "index_lookup_join_max_prefetch_batches", + uint32_t, + 0, + "Maximum input batches to prefetch for index lookup. 0 disables.") + + /// If true, index join operator may split output per input batch. + VELOX_QUERY_CONFIG( + kIndexLookupJoinSplitOutput, + indexLookupJoinSplitOutput, + "index_lookup_join_split_output", + bool, + true, + "Allow index join operator to split output based on batch size.") // Max wait time for exchange request in seconds. - static constexpr const char* kRequestDataSizesMaxWaitSec = - "request_data_sizes_max_wait_sec"; - - /// In streaming aggregation, wait until we have enough number of output rows - /// to produce a batch of size specified by this. If set to 0, then - /// Operator::outputBatchRows will be used as the min output batch rows. - static constexpr const char* kStreamingAggregationMinOutputBatchRows = - "streaming_aggregation_min_output_batch_rows"; + VELOX_QUERY_CONFIG( + kRequestDataSizesMaxWaitSec, + requestDataSizesMaxWaitSec, + "request_data_sizes_max_wait_sec", + int32_t, + 10, + "Maximum wait time for exchange request in seconds.") + + /// In streaming aggregation, min output rows before producing a batch. + VELOX_QUERY_CONFIG( + kStreamingAggregationMinOutputBatchRows, + streamingAggregationMinOutputBatchRows, + "streaming_aggregation_min_output_batch_rows", + int32_t, + 0, + "Minimum rows to accumulate before producing streaming aggregation output. 0 uses default.") /// TODO: Remove after dependencies are cleaned up. - static constexpr const char* kStreamingAggregationEagerFlush = - "streaming_aggregation_eager_flush"; - - /// If this is true, then it allows you to get the struct field names - /// as json element names when casting a row to json. - static constexpr const char* kFieldNamesInJsonCastEnabled = - "field_names_in_json_cast_enabled"; - - /// If this is true, then operators that evaluate expressions will track - /// stats for expressions that are not special forms and return them as - /// part of their operator stats. Tracking these stats can be expensive - /// (especially if operator stats are retrieved frequently) and this allows - /// the user to explicitly enable it. - static constexpr const char* kOperatorTrackExpressionStats = - "operator_track_expression_stats"; - - /// If this is true, enable the operator input/output batch size stats - /// collection in driver execution. This can be expensive for data types with - /// a large number of columns (e.g., ROW types) as it calls estimateFlatSize() - /// which recursively calculates sizes for all child vectors. - static constexpr const char* kEnableOperatorBatchSizeStats = - "enable_operator_batch_size_stats"; - - /// If this is true, then the unnest operator might split output for each - /// input batch based on the output batch size control. Otherwise, it produces - /// a single output for each input batch. - static constexpr const char* kUnnestSplitOutput = "unnest_split_output"; - - /// Priority of the query in the memory pool reclaimer. Lower value means - /// higher priority. This is used in global arbitration victim selection. - static constexpr const char* kQueryMemoryReclaimerPriority = - "query_memory_reclaimer_priority"; - - /// The max number of input splits to listen to by SplitListener per table - /// scan node per worker. It's up to the SplitListener implementation to - /// respect this config. - static constexpr const char* kMaxNumSplitsListenedTo = - "max_num_splits_listened_to"; - - /// Source of the query. Used by Presto to identify the file system username. - static constexpr const char* kSource = "source"; - - /// Client tags of the query. Used by Presto to identify the file system - /// username. - static constexpr const char* kClientTags = "client_tags"; - - bool selectiveNimbleReaderEnabled() const { - return get(kSelectiveNimbleReaderEnabled, false); - } - - bool debugDisableExpressionsWithPeeling() const { - return get(kDebugDisableExpressionWithPeeling, false); - } - - bool debugDisableCommonSubExpressions() const { - return get(kDebugDisableCommonSubExpressions, false); - } - - bool debugDisableExpressionsWithMemoization() const { - return get(kDebugDisableExpressionWithMemoization, false); - } - - bool debugDisableExpressionsWithLazyInputs() const { - return get(kDebugDisableExpressionWithLazyInputs, false); - } - - std::string debugMemoryPoolNameRegex() const { - return get(kDebugMemoryPoolNameRegex, ""); - } - - uint64_t debugMemoryPoolWarnThresholdBytes() const { - return config::toCapacity( - get(kDebugMemoryPoolWarnThresholdBytes, "0B"), - config::CapacityUnit::BYTE); - } - - std::optional debugAggregationApproxPercentileFixedRandomSeed() - const { - return get(kDebugAggregationApproxPercentileFixedRandomSeed); - } - - int32_t debugLambdaFunctionEvaluationBatchSize() const { - return get(kDebugLambdaFunctionEvaluationBatchSize, 10'000); - } - - uint8_t debugBingTileChildrenMaxZoomShift() const { - return get(kDebugBingTileChildrenMaxZoomShift, 5); - } + VELOX_QUERY_CONFIG( + kStreamingAggregationEagerFlush, + streamingAggregationEagerFlush, + "streaming_aggregation_eager_flush", + bool, + false, + "Enable eager flush for streaming aggregation.") + + // If true, skip request data size if there is only single source. + VELOX_QUERY_CONFIG( + kSkipRequestDataSizeWithSingleSourceEnabled, + singleSourceExchangeOptimizationEnabled, + "skip_request_data_size_with_single_source_enabled", + bool, + false, + "Skip request data size check with single exchange source.") + + /// If true, exchange clients defer data fetching until next() is called. + VELOX_QUERY_CONFIG( + kExchangeLazyFetchingEnabled, + exchangeLazyFetchingEnabled, + "exchange_lazy_fetching_enabled", + bool, + false, + "Defer exchange data fetching until next() is called.") + + /// If true, use struct field names as JSON element names when casting row to + /// json. + VELOX_QUERY_CONFIG( + kFieldNamesInJsonCastEnabled, + isFieldNamesInJsonCastEnabled, + "field_names_in_json_cast_enabled", + bool, + false, + "Use struct field names as JSON element names in CAST to JSON.") + + /// If true, operators track stats for non-special-form expressions. + VELOX_QUERY_CONFIG( + kOperatorTrackExpressionStats, + operatorTrackExpressionStats, + "operator_track_expression_stats", + bool, + false, + "Track expression stats in operators that evaluate expressions.") + + /// If true, enable operator input/output batch size stats collection. + VELOX_QUERY_CONFIG( + kEnableOperatorBatchSizeStats, + enableOperatorBatchSizeStats, + "enable_operator_batch_size_stats", + bool, + true, + "Enable input/output batch size stats collection in driver execution.") + + /// If true, the unnest operator may split output per input batch. + VELOX_QUERY_CONFIG( + kUnnestSplitOutput, + unnestSplitOutput, + "unnest_split_output", + bool, + true, + "Allow unnest operator to split output based on batch size.") + + /// Priority of the query in the memory pool reclaimer. Lower means higher + /// priority. + VELOX_QUERY_CONFIG( + kQueryMemoryReclaimerPriority, + queryMemoryReclaimerPriority, + "query_memory_reclaimer_priority", + int32_t, + std::numeric_limits::max(), + "Query priority in memory pool reclaimer. Lower means higher priority.") + + /// The max number of input splits to listen to by SplitListener. + VELOX_QUERY_CONFIG( + kMaxNumSplitsListenedTo, + maxNumSplitsListenedTo, + "max_num_splits_listened_to", + int32_t, + 0, + "Max input splits to listen to by SplitListener per scan node per worker.") + + /// Source of the query. + VELOX_QUERY_CONFIG( + kSource, + source, + "source", + std::string, + "", + "Source of the query.") + + /// Client tags of the query. + VELOX_QUERY_CONFIG( + kClientTags, + clientTags, + "client_tags", + std::string, + "", + "Client tags of the query.") + + /// Enable (reader) row size tracker. Uses enum stored as int32_t. + VELOX_QUERY_CONFIG_PROPERTY( + kRowSizeTrackingMode, + "row_size_tracking_mode", + int32_t, + 2, + "Row size tracking mode: 0=disabled, 1=exclude delta splits, 2=enabled for all.") + + /// Maximum number of distinct values to keep when merging vector hashers in + /// join HashBuild. + VELOX_QUERY_CONFIG( + kJoinBuildVectorHasherMaxNumDistinct, + joinBuildVectorHasherMaxNumDistinct, + "join_build_vector_hasher_max_num_distinct", + uint32_t, + 1'000'000, + "Max distinct values to keep when merging vector hashers in join HashBuild.") + + /// Batch size threshold for zero-copy optimization in MarkSorted operator. + VELOX_QUERY_CONFIG( + kMarkSortedZeroCopyThreshold, + markSortedZeroCopyThreshold, + "mark_sorted_zero_copy_threshold", + int32_t, + 1000, + "Batch size threshold for zero-copy in MarkSorted operator.") + + // --- Hand-written accessors for properties that need custom logic --- + + // Generated by VELOX_QUERY_CONFIG for simple properties above. uint64_t queryMaxMemoryPerNode() const { return config::toCapacity( @@ -742,112 +1500,6 @@ class QueryConfig { config::CapacityUnit::BYTE); } - uint64_t maxPartialAggregationMemoryUsage() const { - static constexpr uint64_t kDefault = 1L << 24; - return get(kMaxPartialAggregationMemory, kDefault); - } - - uint64_t maxExtendedPartialAggregationMemoryUsage() const { - static constexpr uint64_t kDefault = 1L << 26; - return get(kMaxExtendedPartialAggregationMemory, kDefault); - } - - int32_t abandonPartialAggregationMinRows() const { - return get(kAbandonPartialAggregationMinRows, 100'000); - } - - int32_t abandonPartialAggregationMinPct() const { - return get(kAbandonPartialAggregationMinPct, 80); - } - - int32_t abandonPartialTopNRowNumberMinRows() const { - return get(kAbandonPartialTopNRowNumberMinRows, 100'000); - } - - int32_t abandonPartialTopNRowNumberMinPct() const { - return get(kAbandonPartialTopNRowNumberMinPct, 80); - } - - int32_t maxElementsSizeInRepeatAndSequence() const { - return get(kMaxElementsSizeInRepeatAndSequence, 10'000); - } - - uint64_t maxSpillRunRows() const { - static constexpr uint64_t kDefault = 12UL << 20; - return get(kMaxSpillRunRows, kDefault); - } - - uint64_t maxSpillBytes() const { - static constexpr uint64_t kDefault = 100UL << 30; - return get(kMaxSpillBytes, kDefault); - } - - uint64_t maxPartitionedOutputBufferSize() const { - static constexpr uint64_t kDefault = 32UL << 20; - return get(kMaxPartitionedOutputBufferSize, kDefault); - } - - uint64_t maxOutputBufferSize() const { - static constexpr uint64_t kDefault = 32UL << 20; - return get(kMaxOutputBufferSize, kDefault); - } - - uint64_t maxLocalExchangeBufferSize() const { - static constexpr uint64_t kDefault = 32UL << 20; - return get(kMaxLocalExchangeBufferSize, kDefault); - } - - uint32_t maxLocalExchangePartitionCount() const { - // defaults to unlimited - static constexpr uint32_t kDefault = std::numeric_limits::max(); - return get(kMaxLocalExchangePartitionCount, kDefault); - } - - uint32_t minLocalExchangePartitionCountToUsePartitionBuffer() const { - // Use non buffering mode if the partition count 32 or less - // The default value is 32 is chosen rather conservatively. A - // significant performance degradation of a non-buffered approach is - // observed after 16 partitions. - static constexpr uint64_t kDefault = 33; - return get( - kMinLocalExchangePartitionCountToUsePartitionBuffer, kDefault); - } - - uint64_t maxLocalExchangePartitionBufferSize() const { - /// The default partition buffer size is 64KB. - static constexpr uint64_t kDefault = 64UL * 1024; - return get(kMaxLocalExchangePartitionBufferSize, kDefault); - } - - bool localExchangePartitionBufferPreserveEncoding() const { - /// Trying to preserve encoding can be expensive. Disabled by default. - return get(kLocalExchangePartitionBufferPreserveEncoding, false); - } - - uint32_t localMergeSourceQueueSize() const { - return get(kLocalMergeSourceQueueSize, 2); - } - - uint64_t maxExchangeBufferSize() const { - static constexpr uint64_t kDefault = 32UL << 20; - return get(kMaxExchangeBufferSize, kDefault); - } - - uint64_t maxMergeExchangeBufferSize() const { - static constexpr uint64_t kDefault = 128UL << 20; - return get(kMaxMergeExchangeBufferSize, kDefault); - } - - uint64_t minExchangeOutputBatchBytes() const { - static constexpr uint64_t kDefault = 2UL << 20; - return get(kMinExchangeOutputBatchBytes, kDefault); - } - - uint64_t preferredOutputBatchBytes() const { - static constexpr uint64_t kDefault = 10UL << 20; - return get(kPreferredOutputBatchBytes, kDefault); - } - vector_size_t preferredOutputBatchRows() const { const uint32_t batchRows = get(kPreferredOutputBatchRows, 1024); VELOX_USER_CHECK_LE(batchRows, std::numeric_limits::max()); @@ -861,94 +1513,10 @@ class QueryConfig { return maxBatchRows; } - uint32_t tableScanGetOutputTimeLimitMs() const { - return get(kTableScanGetOutputTimeLimitMs, 5'000); - } - - bool hashAdaptivityEnabled() const { - return get(kHashAdaptivityEnabled, true); - } - - uint32_t writeStrideSize() const { - static constexpr uint32_t kDefault = 100'000; - return kDefault; - } - - bool flushPerBatch() const { - static constexpr bool kDefault = true; - return kDefault; - } - - bool adaptiveFilterReorderingEnabled() const { - return get(kAdaptiveFilterReorderingEnabled, true); - } - - bool isLegacyCast() const { - return get(kLegacyCast, false); - } - - bool isMatchStructByName() const { - return get(kCastMatchStructByName, false); - } - - uint64_t exprMaxArraySizeInReduce() const { - return get(kExprMaxArraySizeInReduce, 100'000); - } - - uint64_t exprMaxCompiledRegexes() const { - return get(kExprMaxCompiledRegexes, 100); - } - - bool adjustTimestampToTimezone() const { - return get(kAdjustTimestampToTimezone, false); - } - - std::string sessionTimezone() const { - return get(kSessionTimezone, ""); - } - - bool exprEvalSimplified() const { - return get(kExprEvalSimplified, false); - } - - bool spillEnabled() const { - return get(kSpillEnabled, false); - } - - bool aggregationSpillEnabled() const { - return get(kAggregationSpillEnabled, true); - } - - bool joinSpillEnabled() const { - return get(kJoinSpillEnabled, true); - } - - bool mixedGroupedModeHashJoinSpillEnabled() const { - return get(kMixedGroupedModeHashJoinSpillEnabled, false); - } - - bool orderBySpillEnabled() const { - return get(kOrderBySpillEnabled, true); - } - - bool windowSpillEnabled() const { - return get(kWindowSpillEnabled, true); - } - - bool writerSpillEnabled() const { - return get(kWriterSpillEnabled, true); - } - - bool rowNumberSpillEnabled() const { - return get(kRowNumberSpillEnabled, true); - } - - bool topNRowNumberSpillEnabled() const { - return get(kTopNRowNumberSpillEnabled, true); - } - - bool localMergeSpillEnabled() const { - return get(kLocalMergeSpillEnabled, false); + vector_size_t mergeJoinOutputBatchStartSize() const { + const uint32_t batchRows = get(kMergeJoinOutputBatchStartSize, 0); + VELOX_USER_CHECK_LE(batchRows, std::numeric_limits::max()); + return batchRows; } uint32_t localMergeMaxNumMergeSources() const { @@ -958,15 +1526,6 @@ class QueryConfig { return maxNumMergeSources; } - int32_t maxSpillLevel() const { - return get(kMaxSpillLevel, 1); - } - - uint8_t spillStartPartitionBit() const { - constexpr uint8_t kDefaultStartBit = 48; - return get(kSpillStartPartitionBit, kDefaultStartBit); - } - uint8_t spillNumPartitionBits() const { constexpr uint8_t kDefaultBits = 3; constexpr uint8_t kMaxBits = 3; @@ -974,109 +1533,6 @@ class QueryConfig { kMaxBits, get(kSpillNumPartitionBits, kDefaultBits)); } - uint64_t writerFlushThresholdBytes() const { - return get(kWriterFlushThresholdBytes, 96L << 20); - } - - uint64_t maxSpillFileSize() const { - constexpr uint64_t kDefaultMaxFileSize = 0; - return get(kMaxSpillFileSize, kDefaultMaxFileSize); - } - - std::string spillCompressionKind() const { - return get(kSpillCompressionKind, "none"); - } - - bool spillPrefixSortEnabled() const { - return get(kSpillPrefixSortEnabled, false); - } - - uint64_t spillWriteBufferSize() const { - // The default write buffer size set to 1MB. - return get(kSpillWriteBufferSize, 1L << 20); - } - - uint64_t spillReadBufferSize() const { - // The default read buffer size set to 1MB. - return get(kSpillReadBufferSize, 1L << 20); - } - - std::string spillFileCreateConfig() const { - return get(kSpillFileCreateConfig, ""); - } - - int32_t minSpillableReservationPct() const { - constexpr int32_t kDefaultPct = 5; - return get(kMinSpillableReservationPct, kDefaultPct); - } - - int32_t spillableReservationGrowthPct() const { - constexpr int32_t kDefaultPct = 10; - return get(kSpillableReservationGrowthPct, kDefaultPct); - } - - bool queryTraceEnabled() const { - return get(kQueryTraceEnabled, false); - } - - std::string queryTraceDir() const { - // The default query trace dir, empty by default. - return get(kQueryTraceDir, ""); - } - - std::string queryTraceNodeId() const { - // The default query trace node ID, empty by default. - return get(kQueryTraceNodeId, ""); - } - - uint64_t queryTraceMaxBytes() const { - return get(kQueryTraceMaxBytes, 0); - } - - std::string queryTraceTaskRegExp() const { - // The default query trace task regexp, empty by default. - return get(kQueryTraceTaskRegExp, ""); - } - - bool queryTraceDryRun() const { - return get(kQueryTraceDryRun, false); - } - - std::string opTraceDirectoryCreateConfig() const { - return get(kOpTraceDirectoryCreateConfig, ""); - } - - bool prestoArrayAggIgnoreNulls() const { - return get(kPrestoArrayAggIgnoreNulls, false); - } - - bool sparkAnsiEnabled() const { - return get(kSparkAnsiEnabled, false); - } - - int64_t sparkBloomFilterExpectedNumItems() const { - constexpr int64_t kDefault = 1'000'000L; - return get(kSparkBloomFilterExpectedNumItems, kDefault); - } - - int64_t sparkBloomFilterNumBits() const { - constexpr int64_t kDefault = 8'388'608L; - return get(kSparkBloomFilterNumBits, kDefault); - } - - // Spark kMaxNumBits is 67'108'864, but velox has memory limit sizeClassSizes - // 256, so decrease it to not over memory limit. - int64_t sparkBloomFilterMaxNumBits() const { - constexpr int64_t kDefault = 4'096 * 1024; - auto value = get(kSparkBloomFilterMaxNumBits, kDefault); - VELOX_USER_CHECK_LE( - value, - kDefault, - "{} cannot exceed the default value", - kSparkBloomFilterMaxNumBits); - return value; - } - int32_t sparkPartitionId() const { auto id = get(kSparkPartitionId); VELOX_CHECK(id.has_value(), "Spark partition id is not set."); @@ -1085,182 +1541,45 @@ class QueryConfig { return value; } - bool sparkLegacyDateFormatter() const { - return get(kSparkLegacyDateFormatter, false); - } - - bool sparkLegacyStatisticalAggregate() const { - return get(kSparkLegacyStatisticalAggregate, false); - } - - bool sparkJsonIgnoreNullFields() const { - return get(kSparkJsonIgnoreNullFields, true); - } - - bool exprTrackCpuUsage() const { - return get(kExprTrackCpuUsage, false); - } - - bool operatorTrackCpuUsage() const { - return get(kOperatorTrackCpuUsage, true); - } - - uint32_t taskWriterCount() const { - return get(kTaskWriterCount, 4); - } - uint32_t taskPartitionedWriterCount() const { return get(kTaskPartitionedWriterCount) .value_or(taskWriterCount()); } - bool hashProbeFinishEarlyOnEmptyBuild() const { - return get(kHashProbeFinishEarlyOnEmptyBuild, false); - } - - uint32_t minTableRowsForParallelJoinBuild() const { - return get(kMinTableRowsForParallelJoinBuild, 1'000); - } - - bool validateOutputFromOperators() const { - return get(kValidateOutputFromOperators, false); - } - - bool isExpressionEvaluationCacheEnabled() const { - return get(kEnableExpressionEvaluationCache, true); - } - - uint32_t maxSharedSubexprResultsCached() const { - // 10 was chosen as a default as there are cases where a shared - // subexpression can be called in 2 different places and a particular - // argument may be peeled in one and not peeled in another. 10 is large - // enough to handle this happening for a few arguments in different - // combinations. - // - // For example, when the UDF at the root of a shared subexpression does not - // have default null behavior and takes an input that is dictionary encoded - // with nulls set in the DictionaryVector. That dictionary - // encoding may be peeled depending on whether or not there is a UDF above - // it in the expression tree that has default null behavior and takes the - // same input as an argument. - return get(kMaxSharedSubexprResultsCached, 10); - } - - int32_t maxSplitPreloadPerDriver() const { - return get(kMaxSplitPreloadPerDriver, 2); - } - - uint32_t driverCpuTimeSliceLimitMs() const { - return get(kDriverCpuTimeSliceLimitMs, 0); - } - - uint32_t prefixSortNormalizedKeyMaxBytes() const { - return get(kPrefixSortNormalizedKeyMaxBytes, 128); - } - - uint32_t prefixSortMinRows() const { - return get(kPrefixSortMinRows, 128); - } - - uint32_t prefixSortMaxStringPrefixLength() const { - return get(kPrefixSortMaxStringPrefixLength, 16); - } - - double scaleWriterRebalanceMaxMemoryUsageRatio() const { - return get(kScaleWriterRebalanceMaxMemoryUsageRatio, 0.7); - } - - uint32_t scaleWriterMaxPartitionsPerWriter() const { - return get(kScaleWriterMaxPartitionsPerWriter, 128); - } - - uint64_t scaleWriterMinPartitionProcessedBytesRebalanceThreshold() const { - return get( - kScaleWriterMinPartitionProcessedBytesRebalanceThreshold, 128 << 20); - } - - uint64_t scaleWriterMinProcessedBytesRebalanceThreshold() const { - return get( - kScaleWriterMinProcessedBytesRebalanceThreshold, 256 << 20); - } - - bool tableScanScaledProcessingEnabled() const { - return get(kTableScanScaledProcessingEnabled, false); - } - - double tableScanScaleUpMemoryUsageRatio() const { - return get(kTableScanScaleUpMemoryUsageRatio, 0.7); - } - - uint32_t indexLookupJoinMaxPrefetchBatches() const { - return get(kIndexLookupJoinMaxPrefetchBatches, 0); - } - - bool indexLookupJoinSplitOutput() const { - return get(kIndexLookupJoinSplitOutput, true); - } - - std::string shuffleCompressionKind() const { - return get(kShuffleCompressionKind, "none"); - } - - int32_t requestDataSizesMaxWaitSec() const { - return get(kRequestDataSizesMaxWaitSec, 10); - } - - bool throwExceptionOnDuplicateMapKeys() const { - return get(kThrowExceptionOnDuplicateMapKeys, false); - } - - /// TODO: Remove after dependencies are cleaned up. - bool streamingAggregationEagerFlush() const { - return get(kStreamingAggregationEagerFlush, false); - } - - int32_t streamingAggregationMinOutputBatchRows() const { - return get(kStreamingAggregationMinOutputBatchRows, 0); - } - - bool isFieldNamesInJsonCastEnabled() const { - return get(kFieldNamesInJsonCastEnabled, false); - } - - bool operatorTrackExpressionStats() const { - return get(kOperatorTrackExpressionStats, false); - } - - bool enableOperatorBatchSizeStats() const { - return get(kEnableOperatorBatchSizeStats, true); - } - - bool unnestSplitOutput() const { - return get(kUnnestSplitOutput, true); - } - - int32_t queryMemoryReclaimerPriority() const { - return get( - kQueryMemoryReclaimerPriority, std::numeric_limits::max()); + std::optional debugAggregationApproxPercentileFixedRandomSeed() + const { + return get(kDebugAggregationApproxPercentileFixedRandomSeed); } - int32_t maxNumSplitsListenedTo() const { - return get(kMaxNumSplitsListenedTo, 0); + uint64_t debugMemoryPoolWarnThresholdBytes() const { + return config::toCapacity( + get(kDebugMemoryPoolWarnThresholdBytes, "0B"), + config::CapacityUnit::BYTE); } - std::string source() const { - return get(kSource, ""); - } + enum class RowSizeTrackingMode { + DISABLED = 0, + EXCLUDE_DELTA_SPLITS = 1, + ENABLED_FOR_ALL = 2, + }; - std::string clientTags() const { - return get(kClientTags, ""); + RowSizeTrackingMode rowSizeTrackingMode() const { + return get( + kRowSizeTrackingMode, RowSizeTrackingMode::ENABLED_FOR_ALL); } template T get(const std::string& key, const T& defaultValue) const { return config_->get(key, defaultValue); } + template std::optional get(const std::string& key) const { - return std::optional(config_->get(key)); + return config_->get(key); + } + + const std::shared_ptr& config() const { + return config_; } /// Test-only method to override the current query config properties. @@ -1273,6 +1592,10 @@ class QueryConfig { private: void validateConfig(); - std::unique_ptr config_; + std::shared_ptr config_; }; + +#undef VELOX_QUERY_CONFIG +#undef VELOX_QUERY_CONFIG_PROPERTY + } // namespace facebook::velox::core diff --git a/velox/core/QueryConfigProvider.cpp b/velox/core/QueryConfigProvider.cpp new file mode 100644 index 00000000000..4ecd836ee09 --- /dev/null +++ b/velox/core/QueryConfigProvider.cpp @@ -0,0 +1,32 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/core/QueryConfigProvider.h" + +#include "velox/core/QueryConfig.h" + +namespace facebook::velox::core { + +std::vector QueryConfigProvider::properties() const { + return QueryConfig::registeredProperties(); +} + +std::string QueryConfigProvider::normalize( + std::string_view /*name*/, + std::string_view value) const { + return std::string(value); +} + +} // namespace facebook::velox::core diff --git a/velox/core/QueryConfigProvider.h b/velox/core/QueryConfigProvider.h new file mode 100644 index 00000000000..21ad4cde39d --- /dev/null +++ b/velox/core/QueryConfigProvider.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/common/config/ConfigProvider.h" + +namespace facebook::velox::core { + +/// Exposes all QueryConfig properties as ConfigProperty entries. +class QueryConfigProvider : public config::ConfigProvider { + public: + std::vector properties() const override; + + std::string normalize(std::string_view name, std::string_view value) + const override; +}; + +} // namespace facebook::velox::core diff --git a/velox/core/QueryCtx.cpp b/velox/core/QueryCtx.cpp index 6297a271882..6cfcb3950ba 100644 --- a/velox/core/QueryCtx.cpp +++ b/velox/core/QueryCtx.cpp @@ -15,8 +15,8 @@ */ #include "velox/core/QueryCtx.h" +#include "velox/common/base/Exceptions.h" #include "velox/common/base/SpillConfig.h" -#include "velox/common/base/TraceConfig.h" #include "velox/common/config/Config.h" namespace facebook::velox::core { @@ -30,18 +30,35 @@ std::shared_ptr QueryCtx::create( cache::AsyncDataCache* cache, std::shared_ptr pool, folly::Executor* spillExecutor, - const std::string& queryId, + std::string queryId, std::shared_ptr tokenProvider) { + return QueryCtx::Builder() + .executor(executor) + .queryConfig(std::move(queryConfig)) + .connectorConfigs(std::move(connectorConfigs)) + .asyncDataCache(cache) + .pool(std::move(pool)) + .spillExecutor(spillExecutor) + .queryId(std::move(queryId)) + .tokenProvider(std::move(tokenProvider)) + .build(); +} + +std::shared_ptr QueryCtx::Builder::build() { std::shared_ptr queryCtx(new QueryCtx( - executor, - std::move(queryConfig), - std::move(connectorConfigs), - cache, - std::move(pool), - spillExecutor, - queryId, - std::move(tokenProvider))); + executor_, + std::move(queryConfig_), + std::move(connectorConfigs_), + cache_, + std::move(pool_), + spillExecutor_, + std::move(queryId_), + std::move(tokenProvider_), + std::move(traceCtxProvider_))); queryCtx->maybeSetReclaimer(); + for (auto& cb : releaseCallbacks_) { + queryCtx->addReleaseCallback(std::move(cb)); + } return queryCtx; } @@ -54,7 +71,8 @@ QueryCtx::QueryCtx( std::shared_ptr pool, folly::Executor* spillExecutor, const std::string& queryId, - std::shared_ptr tokenProvider) + std::shared_ptr tokenProvider, + TraceCtxProvider traceCtxProvider) : queryId_(queryId), executor_(executor), spillExecutor_(spillExecutor), @@ -62,15 +80,29 @@ QueryCtx::QueryCtx( connectorSessionProperties_(connectorSessionProperties), pool_(std::move(pool)), queryConfig_{std::move(queryConfig)}, - fsTokenProvider_(std::move(tokenProvider)) { + fsTokenProvider_(std::move(tokenProvider)), + traceCtxProvider_(std::move(traceCtxProvider)) { initPool(queryId); } +QueryCtx::~QueryCtx() { + for (auto& cb : releaseCallbacks_) { + try { + cb(); + } catch (const std::exception& e) { + LOG(ERROR) << "Release callback threw exception: " << e.what(); + } catch (...) { + LOG(ERROR) << "Release callback threw unknown exception"; + } + } + VELOX_CHECK(!underArbitration_); +} + /*static*/ std::string QueryCtx::generatePoolName(const std::string& queryId) { // We attach a monotonically increasing sequence number to ensure the pool // name is unique. static std::atomic seqNum{0}; - return fmt::format("query.{}.{}", queryId.c_str(), seqNum++); + return fmt::format("query.{}.{}", queryId, seqNum++); } void QueryCtx::maybeSetReclaimer() { @@ -86,18 +118,19 @@ void QueryCtx::updateSpilledBytesAndCheckLimit(uint64_t bytes) { const auto numSpilledBytes = numSpilledBytes_.fetch_add(bytes) + bytes; if (queryConfig_.maxSpillBytes() > 0 && numSpilledBytes > queryConfig_.maxSpillBytes()) { - VELOX_SPILL_LIMIT_EXCEEDED(fmt::format( - "Query exceeded per-query local spill limit of {}", - succinctBytes(queryConfig_.maxSpillBytes()))); + VELOX_SPILL_LIMIT_EXCEEDED( + fmt::format( + "Query exceeded per-query local spill limit of {}", + succinctBytes(queryConfig_.maxSpillBytes()))); } } void QueryCtx::updateTracedBytesAndCheckLimit(uint64_t bytes) { if (numTracedBytes_.fetch_add(bytes) + bytes >= queryConfig_.queryTraceMaxBytes()) { - VELOX_TRACE_LIMIT_EXCEEDED(fmt::format( + VELOX_TRACE_LIMIT_EXCEEDED( "Query exceeded per-query local trace limit of {}", - succinctBytes(queryConfig_.queryTraceMaxBytes()))); + succinctBytes(queryConfig_.queryTraceMaxBytes())); } } diff --git a/velox/core/QueryCtx.h b/velox/core/QueryCtx.h index 423e93ea312..749c6d0be3f 100644 --- a/velox/core/QueryCtx.h +++ b/velox/core/QueryCtx.h @@ -17,32 +17,99 @@ #pragma once #include -#include +#include +#include +#include +#include +#include +#include +#include "velox/common/base/Exceptions.h" #include "velox/common/caching/AsyncDataCache.h" #include "velox/common/memory/Memory.h" #include "velox/core/QueryConfig.h" +#include "velox/core/ScanBatchEvent.h" #include "velox/vector/DecodedVector.h" #include "velox/vector/VectorPool.h" -namespace facebook::velox { -class Config; -} +namespace facebook::velox::exec::trace { +class TraceCtx; +} // namespace facebook::velox::exec::trace namespace facebook::velox::core { +struct PlanFragment; + +/// Query execution context that manages resources and configuration for a +/// query. +/// +/// QueryCtx encapsulates query-level state and resources including: +/// +/// - Memory pool management and memory arbitration. +/// - Query and connector-specific configuration. +/// - Executor for parallel task execution. +/// - Async data cache for IO operations. +/// - Spill executor for disk-based operations. +/// - Query tracing and metrics tracking. +/// +/// Usage Contexts: +/// +/// - Multi-threaded execution: Used with Task::start() where an executor must +/// be provided and its lifetime must outlive all tasks using this context. +/// - Single-threaded execution: Used with ExecCtx or Task::next() for +/// expression evaluation where no executor is required. +/// +/// Construction: +/// +/// To construct a QueryCtx, prefer to use the builder pattern: +/// +/// @code +/// auto queryCtx = QueryCtx::Builder() +/// .executor(myExecutor) +/// .queryConfig(configMap) +/// .queryId("query-123") +/// .pool(myMemoryPool) +/// .build(); +/// @endcode +/// +/// Memory Management: +/// +/// - Automatically creates a root memory pool if not provided +/// - Supports memory arbitration and reclamation under memory pressure +/// - Tracks spilled bytes with configurable limits +/// - Thread-safe memory pool operations +/// +/// Thread-safety: QueryCtx is thread-safe for concurrent access across +/// multiple tasks and operators within a query execution. class QueryCtx : public std::enable_shared_from_this { public: - ~QueryCtx() { - VELOX_CHECK(!underArbitration_); - } + using ReleaseCallback = std::function; + using TraceCtxProvider = std::function( + core::QueryCtx&, + const core::PlanFragment&)>; - /// QueryCtx is used in different places. When used with `Task::start()`, it's - /// required that the caller supplies the executor and ensure its lifetime - /// outlives the tasks that use it. In contrast, when used in expression - /// evaluation through `ExecCtx` or 'Task::next()' for single thread execution - /// mode, executor is not needed. Hence, we don't require executor to always - /// be passed in here, but instead, ensure that executor exists when actually - /// being used. + ~QueryCtx(); + + /// Creates a new QueryCtx instance with the specified configuration. + /// + /// This factory method constructs a QueryCtx with all necessary resources + /// and automatically sets up memory reclamation if not already configured. + /// + /// @param executor Optional executor for parallel task execution. Required + /// when used with Task::start(), but not needed for + /// expression evaluation or single-threaded execution. + /// @param queryConfig Query-level configuration settings. + /// @param connectorConfigs Connector-specific configuration mappings. + /// @param cache Async data cache for IO operations (defaults to global + /// instance). + /// @param pool Memory pool for query execution (auto-created if nullptr). + /// @param spillExecutor Optional executor for spilling operations. + /// @param queryId Unique identifier for this query. + /// @param tokenProvider Optional filesystem token provider for + /// authentication. + /// @return Shared pointer to the newly created QueryCtx. + /// + /// Note: The caller must ensure the executor's lifetime outlives all tasks + /// using this QueryCtx when executor is provided. static std::shared_ptr create( folly::Executor* executor = nullptr, QueryConfig&& queryConfig = QueryConfig{{}}, @@ -51,9 +118,109 @@ class QueryCtx : public std::enable_shared_from_this { cache::AsyncDataCache* cache = cache::AsyncDataCache::getInstance(), std::shared_ptr pool = nullptr, folly::Executor* spillExecutor = nullptr, - const std::string& queryId = "", + std::string queryId = "", std::shared_ptr tokenProvider = {}); + /// Builder pattern for constructing QueryCtx instances. + /// + /// Provides a fluent interface for creating QueryCtx with optional + /// parameters. This is the recommended approach for improved readability, + /// especially when only setting a subset of configuration options. + /// + /// Example: + /// @code + /// auto ctx = QueryCtx::Builder() + /// .queryId("my-query") + /// .executor(myExecutor) + /// .queryConfig(QueryConfig{mySettings}) + /// .build(); + /// @endcode + class Builder { + public: + Builder& executor(folly::Executor* executor) { + executor_ = executor; + return *this; + } + + Builder& queryConfig(QueryConfig queryConfig) { + queryConfig_ = std::move(queryConfig); + return *this; + } + + Builder& connectorConfigs( + std::unordered_map> + connectorConfigs) { + connectorConfigs_ = std::move(connectorConfigs); + return *this; + } + + Builder& asyncDataCache(cache::AsyncDataCache* cache) { + cache_ = cache; + return *this; + } + + Builder& pool(std::shared_ptr pool) { + pool_ = std::move(pool); + return *this; + } + + Builder& spillExecutor(folly::Executor* spillExecutor) { + spillExecutor_ = spillExecutor; + return *this; + } + + Builder& queryId(std::string queryId) { + queryId_ = std::move(queryId); + return *this; + } + + Builder& tokenProvider( + std::shared_ptr tokenProvider) { + tokenProvider_ = std::move(tokenProvider); + return *this; + } + + /// Adds a callback to be invoked when the QueryCtx is destroyed. + /// Multiple callbacks can be added by calling this method multiple times. + Builder& releaseCallback(ReleaseCallback callback) { + releaseCallbacks_.push_back(std::move(callback)); + return *this; + } + + Builder& traceCtxProvider(TraceCtxProvider provider) { + traceCtxProvider_ = std::move(provider); + return *this; + } + + /// Constructs and returns a QueryCtx with the configured parameters. + /// + /// @return Shared pointer to the newly created QueryCtx instance + std::shared_ptr build(); + + private: + folly::Executor* executor_{nullptr}; + QueryConfig queryConfig_{QueryConfig{{}}}; + std::unordered_map> + connectorConfigs_; + cache::AsyncDataCache* cache_{cache::AsyncDataCache::getInstance()}; + std::shared_ptr pool_; + folly::Executor* spillExecutor_{nullptr}; + std::string queryId_; + std::shared_ptr tokenProvider_; + std::deque releaseCallbacks_; + TraceCtxProvider traceCtxProvider_; + }; + + /// Generates a unique memory pool name for a query. + /// + /// Creates a pool name by combining the provided query ID with a + /// monotonically increasing sequence number to ensure uniqueness across + /// multiple pool creations, even for the same query ID. + /// + /// @param queryId The query identifier to incorporate into the pool name + /// @return A unique pool name in the format "query.{queryId}.{seqNum}" + /// + /// Thread-safe: Uses atomic operations for sequence number generation. static std::string generatePoolName(const std::string& queryId); memory::MemoryPool* pool() const { @@ -94,6 +261,22 @@ class QueryCtx : public std::enable_shared_from_this { return fsTokenProvider_; } + /// Registers a callback to be invoked when this QueryCtx is destroyed. + /// This allows external resources tied to the query's lifetime to be cleaned + /// up before the QueryCtx and its members are destructed. For example, + /// resources that have allocations in the query's memory pool. + /// + /// Example: HashTableCache uses this to remove cached hash tables when a + /// query completes. The cache entry holds a child memory pool of the query + /// pool, so it must be released before the query pool is destroyed. + /// + /// Note: Callbacks are invoked in registration order. Exceptions thrown by + /// callbacks are caught and logged; they do not prevent subsequent callbacks + /// from running. + void addReleaseCallback(ReleaseCallback callback) { + releaseCallbacks_.push_back(std::move(callback)); + } + /// Overrides the previous configuration. Note that this function is NOT /// thread-safe and should probably only be used in tests. void testingOverrideConfigUnsafe( @@ -131,6 +314,67 @@ class QueryCtx : public std::enable_shared_from_this { /// the max query trace bytes limit. void updateTracedBytesAndCheckLimit(uint64_t bytes); + TraceCtxProvider traceCtxProvider() { + return traceCtxProvider_; + } + + void setTraceCtxProvider(TraceCtxProvider provider) { + traceCtxProvider_ = std::move(provider); + } + + /// Sets an optional callback fired by TableScan after each non-empty batch. + void setScanBatchCallback(ScanBatchCallback callback) { + scanBatchCallback_ = std::move(callback); + } + + const ScanBatchCallback& scanBatchCallback() const { + return scanBatchCallback_; + } + + /// Store a per-query registry override. Each subsystem defines its own key + /// (e.g., "connectors", "vectorFunctions"). The registry is stored as a + /// type-erased shared_ptr; callers must use the same type T for setRegistry + /// and registry calls with the same key. Returns true if the key was newly + /// inserted. Throws if the key already exists unless 'overwrite' is true, + /// in which case the existing entry is replaced and false is returned. + template + bool setRegistry( + std::string_view key, + std::shared_ptr registry, + bool overwrite = false) { + return registries_.withWLock([&](auto& map) { + auto it = map.find(std::string(key)); + if (it != map.end()) { + VELOX_CHECK(overwrite, "Registry already set: {}", key); + it->second = {std::move(registry), std::type_index(typeid(T))}; + return false; + } + map.emplace( + std::string(key), + RegistryEntry{std::move(registry), std::type_index(typeid(T))}); + return true; + }); + } + + /// Retrieve a per-query registry override. Returns nullptr if no override + /// was set for this key. Asserts that the stored type matches T. + template + std::shared_ptr registry(std::string_view key) const { + return registries_.withRLock([&](const auto& map) -> std::shared_ptr { + auto it = map.find(std::string(key)); + if (it == map.end()) { + return nullptr; + } + VELOX_CHECK( + it->second.type == std::type_index(typeid(T)), + "Registry type mismatch for key '{}': expected {}, got {}", + key, + typeid(T).name(), + it->second.type.name()); + return std::static_pointer_cast(it->second.ptr); + }); + } + void testingOverrideMemoryPool(std::shared_ptr pool) { pool_ = std::move(pool); } @@ -159,7 +403,8 @@ class QueryCtx : public std::enable_shared_from_this { std::shared_ptr pool = nullptr, folly::Executor* spillExecutor = nullptr, const std::string& queryId = "", - std::shared_ptr tokenProvider = {}); + std::shared_ptr tokenProvider = {}, + TraceCtxProvider traceCtxProvider = nullptr); class MemoryReclaimer : public memory::MemoryReclaimer { public: @@ -214,6 +459,7 @@ class QueryCtx : public std::enable_shared_from_this { // Invoked to start memory arbitration on this query. void startArbitration(); + // Invoked to stop memory arbitration on this query. void finishArbitration(); @@ -230,10 +476,29 @@ class QueryCtx : public std::enable_shared_from_this { std::atomic numTracedBytes_{0}; mutable std::mutex mutex_; + // Indicates if this query is under memory arbitration or not. std::atomic_bool underArbitration_{false}; std::vector arbitrationPromises_; std::shared_ptr fsTokenProvider_; + // Callbacks invoked before destruction to clean up external resources. + std::deque releaseCallbacks_; + + // A function that constructs a custom trace ctx object. + TraceCtxProvider traceCtxProvider_; + + // Optional per-batch scan stats callback. + ScanBatchCallback scanBatchCallback_; + + // Type-erased registry entry for per-query overrides. + struct RegistryEntry { + std::shared_ptr ptr; + std::type_index type; + }; + + // Per-query registry overrides keyed by subsystem name. + folly::Synchronized> + registries_; }; // Represents the state of one thread of query execution. @@ -260,6 +525,7 @@ class ExecCtx { !queryConfig.debugDisableExpressionsWithMemoization() && exprEvalCacheEnabled; peelingEnabled = !queryConfig.debugDisableExpressionsWithPeeling(); + minRowsForPeeling = queryConfig.minRowsForPeeling(); sharedSubExpressionReuseEnabled = !queryConfig.debugDisableCommonSubExpressions(); deferredLazyLoadingEnabled = @@ -279,6 +545,9 @@ class ExecCtx { bool dictionaryMemoizationEnabled; /// True if peeling is enabled during experssion evaluation. bool peelingEnabled; + /// Minimum number of rows required for peeling to be applied during + /// expression evaluation. + int32_t minRowsForPeeling; /// True if shared subexpression reuse is enabled during experssion /// evaluation. bool sharedSubExpressionReuseEnabled; diff --git a/velox/core/ScanBatchEvent.h b/velox/core/ScanBatchEvent.h new file mode 100644 index 00000000000..27a01c59ec5 --- /dev/null +++ b/velox/core/ScanBatchEvent.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +namespace facebook::velox::core { + +/// Per-batch scan statistics event fired by TableScan after each batch. +struct ScanBatchEvent { + virtual ~ScanBatchEvent() = default; + + /// Post-pushdown, pre-remaining-filter row count. + uint64_t numRows{0}; + /// Wall time spent producing this batch in microseconds. + uint64_t wallTimeMicros{0}; +}; + +using ScanBatchCallback = std::function; + +} // namespace facebook::velox::core diff --git a/velox/core/SimpleFunctionMetadata.h b/velox/core/SimpleFunctionMetadata.h index 05fb26a4491..d160778cfff 100644 --- a/velox/core/SimpleFunctionMetadata.h +++ b/velox/core/SimpleFunctionMetadata.h @@ -80,6 +80,21 @@ struct udf_canonical_name< static constexpr exec::FunctionCanonicalName value = T::canonical_name; }; +// If a UDF doesn't declare a default owner +template +struct udf_owner { + static constexpr std::string_view value() { + return ""; + } +}; + +template +struct udf_owner> { + static constexpr std::string_view value() { + return T::owner; + } +}; + // Has the value true, unless a Variadic Type appears anywhere but at the end // of the parameters. template @@ -243,13 +258,14 @@ struct TypeAnalysis> { } else { auto typeVariableName = fmt::format("__user_T{}", T::getId()); results.out << typeVariableName; - results.addVariable(exec::SignatureVariable( - typeVariableName, - std::nullopt, - exec::ParameterType::kTypeParameter, - false, - orderable, - comparable)); + results.addVariable( + exec::SignatureVariable( + typeVariableName, + std::nullopt, + exec::ParameterType::kTypeParameter, + false, + orderable, + comparable)); } results.stats.hasGeneric = true; results.physicalType = UNKNOWN(); @@ -264,10 +280,12 @@ struct TypeAnalysis> { const auto p = P::name(); const auto s = S::name(); results.out << fmt::format("decimal({},{})", p, s); - results.addVariable(exec::SignatureVariable( - p, std::nullopt, exec::ParameterType::kIntegerParameter)); - results.addVariable(exec::SignatureVariable( - s, std::nullopt, exec::ParameterType::kIntegerParameter)); + results.addVariable( + exec::SignatureVariable( + p, std::nullopt, exec::ParameterType::kIntegerParameter)); + results.addVariable( + exec::SignatureVariable( + s, std::nullopt, exec::ParameterType::kIntegerParameter)); results.physicalType = BIGINT(); } }; @@ -280,10 +298,12 @@ struct TypeAnalysis> { const auto p = P::name(); const auto s = S::name(); results.out << fmt::format("decimal({},{})", p, s); - results.addVariable(exec::SignatureVariable( - p, std::nullopt, exec::ParameterType::kIntegerParameter)); - results.addVariable(exec::SignatureVariable( - s, std::nullopt, exec::ParameterType::kIntegerParameter)); + results.addVariable( + exec::SignatureVariable( + p, std::nullopt, exec::ParameterType::kIntegerParameter)); + results.addVariable( + exec::SignatureVariable( + s, std::nullopt, exec::ParameterType::kIntegerParameter)); results.physicalType = HUGEINT(); } }; @@ -295,8 +315,9 @@ struct TypeAnalysis> { const auto e = E::name(); results.out << fmt::format("bigint_enum({})", e); - results.addVariable(exec::SignatureVariable( - e, std::nullopt, exec::ParameterType::kEnumParameter)); + results.addVariable( + exec::SignatureVariable( + e, std::nullopt, exec::ParameterType::kEnumParameter)); results.physicalType = BIGINT(); } }; @@ -308,8 +329,9 @@ struct TypeAnalysis> { const auto e = E::name(); results.out << fmt::format("varchar_enum({})", e); - results.addVariable(exec::SignatureVariable( - e, std::nullopt, exec::ParameterType::kEnumParameter)); + results.addVariable( + exec::SignatureVariable( + e, std::nullopt, exec::ParameterType::kEnumParameter)); results.physicalType = VARCHAR(); } }; @@ -431,6 +453,9 @@ class ISimpleFunctionMetadata { virtual std::string getName() const = 0; virtual bool isDeterministic() const = 0; virtual bool defaultNullBehavior() const = 0; + // Return the owner of the function. This is used for logging and + // attribution. + virtual std::string_view owner() const = 0; virtual uint32_t priority() const = 0; virtual const std::shared_ptr signature() const = 0; virtual const TypePtr& resultPhysicalType() const = 0; @@ -545,6 +570,10 @@ class SimpleFunctionMetadata : public ISimpleFunctionMetadata { return udf_is_deterministic(); } + std::string_view owner() const final { + return udf_owner::value(); + } + bool defaultNullBehavior() const final { return defaultNullBehavior_; } @@ -874,8 +903,28 @@ class UDFHolder { (udf_has_callAscii_return_void && udf_has_call_return_bool)), "The return type for callAscii() must match the return type for call()."); - // initialize(): - static constexpr bool udf_has_initialize = util::has_method< + // Detects if initialize() is a template method using SFINAE. + // Template methods can match any signature via template parameter deduction, + // causing false positives in trait detection. We probe with a dummy type + // that's not in our expected signature to identify templates. + struct DummyProbeType {}; + + template + struct has_template_initialize : std::false_type {}; + + template + struct has_template_initialize< + U, + util::detail::void_t().initialize( + std::declval&>(), + std::declval(), + std::declval()))>> : std::true_type {}; + + static constexpr bool is_initialize_template = + has_template_initialize::value; + + // Check for initialize() without MemoryPool parameter. + static constexpr bool udf_has_initialize_without_pool = util::has_method< Fun, initialize_method_resolver, void, @@ -883,6 +932,25 @@ class UDFHolder { const core::QueryConfig&, const exec_arg_type*...>::value; + // Check for initialize() with MemoryPool parameter. + // Excludes template methods to prevent them from incorrectly matching + // via template parameter substitution (e.g., T=MemoryPool). + static constexpr bool udf_has_initialize_with_pool = + !is_initialize_template && + util::has_method< + Fun, + initialize_method_resolver, + void, + const std::vector&, + const core::QueryConfig&, + memory::MemoryPool*, + const exec_arg_type*...>::value; + + // Combined trait for backward compatibility: true if ANY initialize exists + // This preserves the original meaning of udf_has_initialize + static constexpr bool udf_has_initialize = + udf_has_initialize_with_pool || udf_has_initialize_without_pool; + // TODO Remove static constexpr bool udf_has_legacy_initialize = util::has_method< Fun, @@ -958,6 +1026,10 @@ class UDFHolder { return udf_is_deterministic(); } + std::string_view owner() const { + return udf_owner::value(); + } + static constexpr bool isVariadic() { if constexpr (num_args == 0) { return false; @@ -969,9 +1041,16 @@ class UDFHolder { FOLLY_ALWAYS_INLINE void initialize( const std::vector& inputTypes, const core::QueryConfig& config, + memory::MemoryPool* memoryPool, const typename exec_resolver::in_type*... constantArgs) { - if constexpr (udf_has_initialize) { + // Prefer non-MemoryPool signature first to handle template methods + // correctly. Template initialize() methods can match any signature via + // template parameter deduction, so we avoid passing MemoryPool to them. + if constexpr (udf_has_initialize_without_pool) { return instance_.initialize(inputTypes, config, constantArgs...); + } else if constexpr (udf_has_initialize_with_pool) { + return instance_.initialize( + inputTypes, config, memoryPool, constantArgs...); } } diff --git a/velox/core/TableWriteTraits.cpp b/velox/core/TableWriteTraits.cpp new file mode 100644 index 00000000000..bc6e1a31832 --- /dev/null +++ b/velox/core/TableWriteTraits.cpp @@ -0,0 +1,140 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/core/TableWriteTraits.h" +#include "velox/vector/ComplexVector.h" +#include "velox/vector/ConstantVector.h" +#include "velox/vector/FlatVector.h" + +namespace facebook::velox::core { + +// static +RowVectorPtr TableWriteTraits::createAggregationStatsOutput( + RowTypePtr outputType, + RowVectorPtr aggregationOutput, + StringView tableCommitContext, + velox::memory::MemoryPool* pool) { + // TODO: record aggregation stats output time. + if (aggregationOutput == nullptr) { + return nullptr; + } + VELOX_CHECK_GT(aggregationOutput->childrenSize(), 0); + const vector_size_t numOutputRows = aggregationOutput->childAt(0)->size(); + std::vector columns; + for (int channel = 0; channel < outputType->size(); channel++) { + if (channel < TableWriteTraits::kContextChannel) { + // 1. Set null rows column. + // 2. Set null fragments column. + columns.push_back( + BaseVector::createNullConstant( + outputType->childAt(channel), numOutputRows, pool)); + continue; + } + if (channel == TableWriteTraits::kContextChannel) { + // 3. Set commitcontext column. + columns.push_back( + std::make_shared>( + pool, + numOutputRows, + false /*isNull*/, + VARBINARY(), + // Note that we move tableCommitContext here, so ensure this + // branch is only executed once in the loop. + std::move(tableCommitContext))); + continue; + } + // 4. Set statistics columns. + columns.push_back( + aggregationOutput->childAt(channel - TableWriteTraits::kStatsChannel)); + } + return std::make_shared( + pool, outputType, nullptr, numOutputRows, columns); +} + +std::string TableWriteTraits::rowCountColumnName() { + static const std::string kRowCountName = "rows"; + return kRowCountName; +} + +std::string TableWriteTraits::fragmentColumnName() { + static const std::string kFragmentName = "fragments"; + return kFragmentName; +} + +std::string TableWriteTraits::contextColumnName() { + static const std::string kContextName = "commitcontext"; + return kContextName; +} + +const TypePtr& TableWriteTraits::rowCountColumnType() { + static const TypePtr kRowCountType = BIGINT(); + return kRowCountType; +} + +const TypePtr& TableWriteTraits::fragmentColumnType() { + static const TypePtr kFragmentType = VARBINARY(); + return kFragmentType; +} + +const TypePtr& TableWriteTraits::contextColumnType() { + static const TypePtr kContextType = VARBINARY(); + return kContextType; +} + +// static. +RowTypePtr TableWriteTraits::outputType( + const std::optional& columnStatsSpec) { + static const auto kOutputTypeWithoutStats = + ROW({rowCountColumnName(), fragmentColumnName(), contextColumnName()}, + {rowCountColumnType(), fragmentColumnType(), contextColumnType()}); + if (!columnStatsSpec.has_value()) { + return kOutputTypeWithoutStats; + } + return kOutputTypeWithoutStats->unionWith(columnStatsSpec->outputType()); +} + +bool TableWriteTraits::isStatisticsRow( + const RowVectorPtr& output, + vector_size_t index) { + VELOX_DCHECK_LT(index, output->size()); + return output->childAt(kRowCountChannel)->isNullAt(index) && + output->childAt(kFragmentChannel)->isNullAt(index); +} + +folly::dynamic TableWriteTraits::getTableCommitContext( + const RowVectorPtr& input) { + VELOX_CHECK_GT(input->size(), 0); + auto* contextVector = + input->childAt(kContextChannel)->as>(); + return folly::parseJson( + std::string_view(contextVector->valueAt(input->size() - 1))); +} + +int64_t TableWriteTraits::getRowCount(const RowVectorPtr& output) { + VELOX_CHECK_GT(output->size(), 0); + auto* rowCountVector = + output->childAt(kRowCountChannel)->as>(); + VELOX_CHECK_NOT_NULL(rowCountVector); + int64_t rowCount{0}; + for (int i = 0; i < output->size(); ++i) { + if (!rowCountVector->isNullAt(i)) { + rowCount += rowCountVector->valueAt(i); + } + } + return rowCount; +} + +} // namespace facebook::velox::core diff --git a/velox/core/TableWriteTraits.h b/velox/core/TableWriteTraits.h new file mode 100644 index 00000000000..1c671c4fe08 --- /dev/null +++ b/velox/core/TableWriteTraits.h @@ -0,0 +1,112 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "velox/core/PlanNode.h" +#include "velox/type/Type.h" + +namespace facebook::velox::core { + +/// Defines table writer output related config properties that are shared +/// between TableWriter and TableWriteMerger. +/// +/// TODO: the table write output processing is Prestissimo specific. Consider +/// move these part logic to Prestissimo and pass to Velox through a customized +/// output processing callback. +class TableWriteTraits { + public: + /// Defines the column names/types in table write output. + static std::string rowCountColumnName(); + static std::string fragmentColumnName(); + static std::string contextColumnName(); + + static const TypePtr& rowCountColumnType(); + static const TypePtr& fragmentColumnType(); + static const TypePtr& contextColumnType(); + + /// Defines the column channels in table write output. + /// Both the statistics and the row_count + fragments are transferred over the + /// same communication link between the TableWriter and TableFinish. Thus the + /// multiplexing is needed. + /// + /// The transferred page layout looks like: + /// [row_count_channel], [fragment_channel], [context_channel], + /// [statistic_channel_1] ... [statistic_channel_N]] + /// + /// [row_count_channel] - contains number of rows processed by a TableWriter + /// [fragment_channel] - contains data provided by the DataSink#finish + /// [context_channel] - JSON-serialized commit context (VARBINARY). + /// Present on every row. See k*ContextKey constants below. + /// [statistic_channel_1] ...[statistic_channel_N] - + /// contain aggregated statistics computed by the statistics aggregation + /// within the TableWriter + /// + /// For convenience, we never set both: [row_count_channel] + + /// [fragment_channel] and the [statistic_channel_1] ... + /// [statistic_channel_N]. + /// + /// If this is a row that holds statistics - the [row_count_channel] + + /// [fragment_channel] will be NULL. + /// + /// If this is a row that holds the row count + /// or the fragment - all the statistics channels will be set to NULL. + static constexpr int32_t kRowCountChannel = 0; + static constexpr int32_t kFragmentChannel = 1; + static constexpr int32_t kContextChannel = 2; + static constexpr int32_t kStatsChannel = 3; + + /// Field names in the commit context JSON object (context_channel). + /// The context is a JSON object serialized as VARBINARY, e.g.: + /// {"taskId":"...", "lifespan":"TaskWide", + /// "pageSinkCommitStrategy":"NO_COMMIT", "lastPage":true} + static constexpr std::string_view kLifeSpanContextKey = "lifespan"; + static constexpr std::string_view kTaskIdContextKey = "taskId"; + static constexpr std::string_view kCommitStrategyContextKey = + "pageSinkCommitStrategy"; + static constexpr std::string_view klastPageContextKey = "lastPage"; + + static RowTypePtr outputType( + const std::optional& columnStatsSpec); + + /// Returns true if row 'index' in 'output' is a statistics row (both row + /// count and fragment channels are NULL). Statistics rows carry aggregated + /// per-column stats; data rows carry row counts and/or file fragments. + static bool isStatisticsRow( + const RowVectorPtr& output, + vector_size_t index = 0); + + /// Returns the parsed commit context from table writer 'output'. + static folly::dynamic getTableCommitContext(const RowVectorPtr& output); + + /// Returns the sum of row counts from table writer 'output'. + static int64_t getRowCount(const RowVectorPtr& output); + + /// Creates the statistics output. + /// Statistics page layout (aggregate by partition): + /// row fragments context [partition] stats1 stats2 ... + /// null null X [X] X X + /// null null X [X] X X + static RowVectorPtr createAggregationStatsOutput( + RowTypePtr outputType, + RowVectorPtr aggregationOutput, + StringView tableCommitContext, + velox::memory::MemoryPool* pool); +}; + +} // namespace facebook::velox::core diff --git a/velox/core/tests/CMakeLists.txt b/velox/core/tests/CMakeLists.txt index 19cdc3611c6..d2c87a04bca 100644 --- a/velox/core/tests/CMakeLists.txt +++ b/velox/core/tests/CMakeLists.txt @@ -16,6 +16,7 @@ add_executable( velox_core_test ConstantTypedExprTest.cpp PlanFragmentTest.cpp + PlanNodeBuilderTest.cpp PlanNodeTest.cpp QueryConfigTest.cpp QueryCtxTest.cpp @@ -39,11 +40,18 @@ target_link_libraries( GTest::gtest_main ) +add_executable(velox_query_config_provider_test QueryConfigProviderTest.cpp) +add_test(velox_query_config_provider_test velox_query_config_provider_test) +target_link_libraries( + velox_query_config_provider_test + PRIVATE velox_query_config_provider velox_core GTest::gtest GTest::gtest_main +) + add_executable(velox_core_plan_consistency_checker_test PlanConsistencyCheckerTest.cpp) add_test(velox_core_plan_consistency_checker_test velox_core_plan_consistency_checker_test) target_link_libraries( velox_core_plan_consistency_checker_test - PRIVATE velox_core GTest::gtest GTest::gtest_main + PRIVATE velox_core velox_exec GTest::gtest GTest::gtest_main ) diff --git a/velox/core/tests/ConstantTypedExprTest.cpp b/velox/core/tests/ConstantTypedExprTest.cpp index 3d067f9ede8..2107d373e8a 100644 --- a/velox/core/tests/ConstantTypedExprTest.cpp +++ b/velox/core/tests/ConstantTypedExprTest.cpp @@ -15,15 +15,179 @@ */ #include +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/memory/Memory.h" #include "velox/core/Expressions.h" #include "velox/functions/prestosql/types/HyperLogLogType.h" #include "velox/functions/prestosql/types/JsonType.h" #include "velox/functions/prestosql/types/TDigestType.h" #include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" +#include "velox/type/Variant.h" +#include "velox/vector/BaseVector.h" +#include "velox/vector/tests/utils/VectorTestBase.h" namespace facebook::velox::core::test { -TEST(ConstantTypedExprTest, null) { +namespace { +struct TestOpaqueStruct { + int value; + std::string name; + + TestOpaqueStruct(int v, std::string n) : value(v), name(std::move(n)) {} + + bool operator==(const TestOpaqueStruct& other) const { + return value == other.value && name == other.name; + } +}; + +} // namespace + +class ConstantTypedExprTest : public ::testing::Test, + public velox::test::VectorTestBase { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + } + + void SetUp() override { + pool_ = memory::memoryManager()->addLeafPool(); + + // Register serialization/deserialization functions needed for the tests + Type::registerSerDe(); + ITypedExpr::registerSerDe(); + + // Register OPAQUE type serialization for TestOpaqueStruct + static folly::once_flag once; + folly::call_once(once, []() { + OpaqueType::registerSerialization( + "TestOpaqueStruct", + [](const std::shared_ptr& obj) -> std::string { + return folly::json::serialize( + folly::dynamic::object("value", obj->value)("name", obj->name), + folly::json::serialization_opts{}); + }, + [](const std::string& json) -> std::shared_ptr { + folly::dynamic obj = folly::parseJson(json); + return std::make_shared( + obj["value"].asInt(), obj["name"].asString()); + }); + }); + } + + // Helper functions + std::shared_ptr createVariantExpr( + const TypePtr& type, + const Variant& value) { + return std::make_shared(type, value); + } + + std::shared_ptr createNullVariantExpr( + const TypePtr& type) { + return std::make_shared( + type, variant::null(type->kind())); + } + + std::shared_ptr createVectorExpr(const VectorPtr& vector) { + return std::make_shared(vector); + } + + template + VectorPtr createConstantVector(const TypePtr& type, const T& value) { + return BaseVector::createConstant(type, variant(value), 1, pool_.get()); + } + + VectorPtr createNullConstantVector(const TypePtr& type) { + return BaseVector::createNullConstant(type, 1, pool_.get()); + } + + // Test Data + struct TestValues { + variant nullValue; + std::vector nonNullValues; + + TestValues(TypeKind kind) : nullValue(variant::null(kind)) {} + }; + + TestValues getTestValues(TypeKind kind) { + TestValues values(kind); + + switch (kind) { + case TypeKind::BOOLEAN: + values.nonNullValues = {variant(true), variant(false)}; + break; + case TypeKind::TINYINT: + values.nonNullValues = { + variant(int8_t(0)), variant(int8_t(127)), variant(int8_t(-128))}; + break; + case TypeKind::SMALLINT: + values.nonNullValues = { + variant(int16_t(0)), + variant(int16_t(32767)), + variant(int16_t(-32768))}; + break; + case TypeKind::INTEGER: + values.nonNullValues = { + variant(int32_t(0)), + variant(int32_t(2147483647)), + variant(int32_t(-2147483648))}; + break; + case TypeKind::BIGINT: + values.nonNullValues = { + variant(int64_t(0)), + variant(int64_t(9223372036854775807LL)), + variant(int64_t(-9223372036854775808ULL))}; + break; + case TypeKind::REAL: + values.nonNullValues = {variant(0.0f), variant(3.14f), variant(-1.5f)}; + break; + case TypeKind::DOUBLE: + values.nonNullValues = { + variant(0.0), variant(3.14159), variant(-2.71828)}; + break; + case TypeKind::VARCHAR: + values.nonNullValues = { + variant(""), variant("hello"), variant("test string")}; + break; + case TypeKind::VARBINARY: + values.nonNullValues = { + variant::binary(""), + variant::binary("binary data"), + variant::binary("\x00\x01\x02")}; + break; + case TypeKind::TIMESTAMP: + values.nonNullValues = { + variant(Timestamp(0, 0)), + variant(Timestamp(1234567890, 123456789))}; + break; + case TypeKind::HUGEINT: + values.nonNullValues = { + variant(int128_t(0)), + variant(int128_t(123)), + variant(int128_t(-456))}; + break; + default: + // For complex types, we'll handle them within individual tests. + break; + } + return values; + } + + std::shared_ptr pool_; + const std::vector scalarTypes_ = { + TypeKind::BOOLEAN, + TypeKind::TINYINT, + TypeKind::SMALLINT, + TypeKind::INTEGER, + TypeKind::BIGINT, + TypeKind::REAL, + TypeKind::DOUBLE, + TypeKind::VARCHAR, + TypeKind::VARBINARY, + TypeKind::TIMESTAMP, + TypeKind::HUGEINT}; +}; + +TEST_F(ConstantTypedExprTest, null) { auto makeNull = [](const TypePtr& type) { return std::make_shared( type, variant::null(type->kind())); @@ -67,4 +231,386 @@ TEST(ConstantTypedExprTest, null) { *makeNull(ROW({"x", "y"}, {INTEGER(), REAL()}))); } +TEST_F(ConstantTypedExprTest, hashScalarTypes) { + // Tests the consistency of the hash value returned by the ConstantTypedExpr + // between its construction using variant and Velox vectors. + for (auto kind : scalarTypes_) { + auto type = createScalarType(kind); + auto testValues = getTestValues(kind); + + // null values + auto nullVariantExpr = createNullVariantExpr(type); + auto nullVectorExpr = createVectorExpr(createNullConstantVector(type)); + EXPECT_EQ(nullVariantExpr->hash(), nullVectorExpr->hash()) + << "Hash mismatch for null " << TypeKindName::toName(kind); + + // non-null values + for (const auto& value : testValues.nonNullValues) { + auto variantExpr = std::make_shared(type, value); + auto vectorExpr = createVectorExpr( + BaseVector::createConstant(type, value, 1, pool_.get())); + EXPECT_EQ(variantExpr->hash(), vectorExpr->hash()) + << "Hash mismatch for non-null " << TypeKindName::toName(kind) + << " with value " << value.toJson(type); + } + } +} + +TEST_F(ConstantTypedExprTest, hashComplexTypes) { + // ARRAY + auto arrayType = ARRAY(INTEGER()); + + // null values + auto nullArrayVariantExpr = createNullVariantExpr(arrayType); + auto nullArrayVectorExpr = + createVectorExpr(createNullConstantVector(arrayType)); + EXPECT_EQ(nullArrayVariantExpr->hash(), nullArrayVectorExpr->hash()) + << "Hash mismatch for null ARRAY variant vs vector"; + + // non-null values + auto arrayVariant = Variant::array({1, 2, 3}); + auto arrayVariantExpr = + std::make_shared(arrayType, arrayVariant); + auto arrayVector = makeArrayVector({{1, 2, 3}}); + auto arrayVectorExpr = createVectorExpr(arrayVector); + EXPECT_EQ(arrayVariantExpr->hash(), arrayVectorExpr->hash()) + << "Hash mismatch for non-null ARRAY variant vs vector"; + + // MAP + auto mapType = MAP(VARCHAR(), INTEGER()); + + // null values + auto nullMapVariantExpr = createNullVariantExpr(mapType); + auto nullMapVectorExpr = createVectorExpr(createNullConstantVector(mapType)); + EXPECT_EQ(nullMapVariantExpr->hash(), nullMapVectorExpr->hash()) + << "Hash mismatch for null MAP variant vs vector"; + + // non-null values + std::map mapData = {{"key1", 1}, {"key2", 2}}; + auto mapVariant = Variant::map(mapData); + auto mapVariantExpr = + std::make_shared(mapType, mapVariant); + auto mapVector = + makeMapVector({{{"key1", 1}, {"key2", 2}}}); + auto mapVectorExpr = createVectorExpr(mapVector); + EXPECT_EQ(mapVariantExpr->hash(), mapVectorExpr->hash()) + << "Hash mismatch for non-null MAP variant vs vector"; + + // ROW + auto rowType = ROW({{"a", INTEGER()}, {"b", VARCHAR()}}); + + // null values + auto nullRowVariantExpr = createNullVariantExpr(rowType); + auto nullRowVectorExpr = createVectorExpr(createNullConstantVector(rowType)); + EXPECT_EQ(nullRowVariantExpr->hash(), nullRowVectorExpr->hash()) + << "Hash mismatch for null ROW variant vs vector"; + + // non-null values + auto rowVariant = Variant::row({42, "hello"}); + auto rowVariantExpr = + std::make_shared(rowType, rowVariant); + auto rowVector = makeRowVector( + {makeFlatVector({42}), makeFlatVector({"hello"})}); + auto rowVectorExpr = createVectorExpr(rowVector); + EXPECT_EQ(rowVariantExpr->hash(), rowVectorExpr->hash()) + << "Hash mismatch for non-null ROW variant vs vector"; + + // OPAQUE + auto testObj = std::make_shared(42, "test_data"); + auto opaqueType = OPAQUE(); + + // null values + auto nullOpaqueVariantExpr = createNullVariantExpr(opaqueType); + auto nullOpaqueVectorExpr = + createVectorExpr(createNullConstantVector(opaqueType)); + EXPECT_EQ(nullOpaqueVariantExpr->hash(), nullOpaqueVectorExpr->hash()) + << "Hash mismatch for null OPAQUE"; + + // non-null values + auto opaqueVariant = Variant::opaque(testObj); + auto opaqueVariantExpr = + std::make_shared(opaqueType, opaqueVariant); + auto opaqueVectorExpr = createVectorExpr( + BaseVector::createConstant(opaqueType, opaqueVariant, 1, pool_.get())); + EXPECT_EQ(opaqueVariantExpr->hash(), opaqueVectorExpr->hash()) + << "Hash mismatch for non-null OPAQUE"; +} + +TEST_F(ConstantTypedExprTest, serdeScalarTypes) { + // Test serialize/deserialize APIs for scalar types to ensure backward + // compatibility. + for (auto kind : scalarTypes_) { + auto type = createScalarType(kind); + auto testValues = getTestValues(kind); + + // null values + auto nullVariantExpr = createNullVariantExpr(type); + auto serialized = nullVariantExpr->serialize(); + auto deserialized = ConstantTypedExpr::create(serialized, pool_.get()); + EXPECT_TRUE(*nullVariantExpr == *deserialized) + << "Serialize/deserialize mismatch for null variant " + << TypeKindName::toName(kind); + auto nullVectorExpr = createVectorExpr(createNullConstantVector(type)); + serialized = nullVectorExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, pool_.get()); + EXPECT_TRUE(*nullVectorExpr == *deserialized) + << "Serialize/deserialize mismatch for null vector " + << TypeKindName::toName(kind); + + // non-null values + for (const auto& value : testValues.nonNullValues) { + auto variantExpr = std::make_shared(type, value); + serialized = variantExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, pool_.get()); + EXPECT_TRUE(*variantExpr == *deserialized) + << "Serialize/deserialize mismatch for variant " + << TypeKindName::toName(kind); + + auto vectorExpr = createVectorExpr( + BaseVector::createConstant(type, value, 1, pool_.get())); + serialized = vectorExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, pool_.get()); + EXPECT_TRUE(*vectorExpr == *deserialized) + << "Serialize/deserialize mismatch for vector " + << TypeKindName::toName(kind) << " with value " << value.toJson(type); + } + } +} + +TEST_F(ConstantTypedExprTest, serdeComplexTypes) { + // ARRAY + auto arrayType = ARRAY(INTEGER()); + + // null values + auto nullArrayVariantExpr = createNullVariantExpr(arrayType); + auto serialized = nullArrayVariantExpr->serialize(); + auto deserialized = ConstantTypedExpr::create(serialized, nullptr); + EXPECT_TRUE(*nullArrayVariantExpr == *deserialized) + << "Serialize/deserialize mismatch for null ARRAY variant"; + auto nullArrayVectorExpr = + createVectorExpr(createNullConstantVector(arrayType)); + serialized = nullArrayVectorExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, pool_.get()); + EXPECT_TRUE(*nullArrayVectorExpr == *deserialized) + << "Serialize/deserialize mismatch for null ARRAY vector"; + + // non-null values + auto arrayVariant = Variant::array({1, 2, 3}); + auto arrayVariantExpr = + std::make_shared(arrayType, arrayVariant); + serialized = arrayVariantExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, nullptr); + EXPECT_TRUE(*arrayVariantExpr == *deserialized) + << "Serialize/deserialize mismatch for ARRAY variant with data"; + auto arrayVector = makeArrayVector({{1, 2, 3}}); + auto arrayVectorExpr = createVectorExpr(arrayVector); + serialized = arrayVectorExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, pool_.get()); + EXPECT_TRUE(*arrayVectorExpr == *deserialized) + << "Serialize/deserialize mismatch for ARRAY vector with data"; + + // MAP + auto mapType = MAP(VARCHAR(), INTEGER()); + // null values + auto nullMapVariantExpr = createNullVariantExpr(mapType); + serialized = nullMapVariantExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, nullptr); + EXPECT_TRUE(*nullMapVariantExpr == *deserialized) + << "Serialize/deserialize mismatch for null MAP variant"; + + // non-null values + std::map mapData = {{"key1", 1}, {"key2", 2}}; + auto mapVariant = Variant::map(mapData); + auto mapVariantExpr = + std::make_shared(mapType, mapVariant); + serialized = mapVariantExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, nullptr); + EXPECT_TRUE(*mapVariantExpr == *deserialized) + << "Serialize/deserialize mismatch for MAP variant with data"; + + // ROW + auto rowType = ROW({{"a", INTEGER()}, {"b", VARCHAR()}}); + // null values + auto nullRowVariantExpr = createNullVariantExpr(rowType); + serialized = nullRowVariantExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, nullptr); + EXPECT_TRUE(*nullRowVariantExpr == *deserialized) + << "Serialize/deserialize mismatch for null ROW variant"; + + // non-null values + auto rowVariant = Variant::row({42, "hello"}); + auto rowVariantExpr = + std::make_shared(rowType, rowVariant); + serialized = rowVariantExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, nullptr); + EXPECT_TRUE(*rowVariantExpr == *deserialized) + << "Serialize/deserialize mismatch for ROW variant with data"; + + // OPAQUE + auto opaqueType = OPAQUE(); + + // null values + auto nullOpaqueVariantExpr = createNullVariantExpr(opaqueType); + serialized = nullOpaqueVariantExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, nullptr); + EXPECT_TRUE(*nullOpaqueVariantExpr == *deserialized) + << "Serialize/deserialize mismatch for null OPAQUE variant"; + auto nullOpaqueVectorExpr = + createVectorExpr(createNullConstantVector(opaqueType)); + serialized = nullOpaqueVectorExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, pool_.get()); + EXPECT_TRUE(*nullOpaqueVectorExpr == *deserialized) + << "Serialize/deserialize mismatch for null OPAQUE vector"; + + // non-null values + auto testObj = std::make_shared(42, "test_data"); + auto opaqueVariant = Variant::opaque(testObj); + auto opaqueVariantExpr = + std::make_shared(opaqueType, opaqueVariant); + serialized = opaqueVariantExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, nullptr); + auto actualObj = static_pointer_cast(deserialized) + ->value() + .value() + .obj; + EXPECT_EQ(*testObj, *static_pointer_cast(actualObj)); +} + +TEST_F(ConstantTypedExprTest, toStringScalarTypes) { + for (auto kind : scalarTypes_) { + auto type = createScalarType(kind); + auto testValues = getTestValues(kind); + + // null values + auto nullVariantExpr = createNullVariantExpr(type); + auto nullVectorExpr = createVectorExpr(createNullConstantVector(type)); + EXPECT_EQ(nullVariantExpr->toString(), nullVectorExpr->toString()) + << "toString mismatch for null " << TypeKindName::toName(kind); + + // non-null values + for (const auto& value : testValues.nonNullValues) { + auto variantExpr = std::make_shared(type, value); + auto vectorExpr = createVectorExpr( + BaseVector::createConstant(type, value, 1, pool_.get())); + EXPECT_EQ(variantExpr->toString(), vectorExpr->toString()) + << "toString mismatch for " << TypeKindName::toName(kind) + << " with value " << value.toJson(type); + } + } +} + +TEST_F(ConstantTypedExprTest, toStringComplexTypes) { + // ARRAY + auto arrayType = ARRAY(INTEGER()); + + // null values + auto nullArrayVariantExpr = createNullVariantExpr(arrayType); + auto nullArrayVectorExpr = + createVectorExpr(createNullConstantVector(arrayType)); + EXPECT_EQ(nullArrayVariantExpr->toString(), nullArrayVectorExpr->toString()) + << "toString mismatch for null ARRAY"; + + // non-null values + auto arrayVariant = Variant::array({1, 2, 3}); + auto arrayVariantExpr = + std::make_shared(arrayType, arrayVariant); + auto arrayVector = makeArrayVector({{1, 2, 3}}); + auto arrayVectorExpr = createVectorExpr(arrayVector); + EXPECT_EQ(arrayVariantExpr->toString(), arrayVectorExpr->toString()) + << "toString mismatch for ARRAY variant vs vector"; + + // MAP + auto mapType = MAP(VARCHAR(), INTEGER()); + + // null values + auto nullMapVariantExpr = createNullVariantExpr(mapType); + auto nullMapVectorExpr = createVectorExpr(createNullConstantVector(mapType)); + EXPECT_EQ(nullMapVariantExpr->toString(), nullMapVectorExpr->toString()) + << "toString mismatch for null MAP"; + + // non-null values + std::map mapData = {{"key1", 1}, {"key2", 2}}; + auto mapVariant = Variant::map(mapData); + auto mapVariantExpr = + std::make_shared(mapType, mapVariant); + auto mapVector = + makeMapVector({{{"key1", 1}, {"key2", 2}}}); + auto mapVectorExpr = createVectorExpr(mapVector); + EXPECT_EQ(mapVariantExpr->toString(), mapVectorExpr->toString()) + << "toString mismatch for MAP variant vs vector"; + + // ROW + auto rowType = ROW({{"a", INTEGER()}, {"b", VARCHAR()}}); + + // null values + auto nullRowVariantExpr = createNullVariantExpr(rowType); + auto nullRowVectorExpr = createVectorExpr(createNullConstantVector(rowType)); + EXPECT_EQ(nullRowVariantExpr->toString(), nullRowVectorExpr->toString()) + << "toString mismatch for null ROW"; + + // non-null values + auto rowVariant = Variant::row({42, "hello"}); + auto rowVariantExpr = + std::make_shared(rowType, rowVariant); + auto rowVector = makeRowVector( + {makeFlatVector({42}), makeFlatVector({"hello"})}); + auto rowVectorExpr = createVectorExpr(rowVector); + EXPECT_EQ(rowVariantExpr->toString(), rowVectorExpr->toString()) + << "toString mismatch for ROW variant vs vector"; + + // OPAQUE + auto opaqueType = OPAQUE(); + + // null values + auto nullOpaqueVariantExpr = createNullVariantExpr(opaqueType); + auto nullOpaqueVectorExpr = + createVectorExpr(createNullConstantVector(opaqueType)); + EXPECT_EQ(nullOpaqueVariantExpr->toString(), nullOpaqueVectorExpr->toString()) + << "toString mismatch for null OPAQUE"; + + // non-null values + auto testObj = std::make_shared(42, "test_data"); + auto opaqueVariant = Variant::opaque(testObj); + auto opaqueVariantExpr = + std::make_shared(opaqueType, opaqueVariant); + auto opaqueVectorExpr = createVectorExpr( + BaseVector::createConstant(opaqueType, opaqueVariant, 1, pool_.get())); + EXPECT_EQ(opaqueVariantExpr->toString(), opaqueVectorExpr->toString()) + << "toString mismatch for OPAQUE variant vs vector"; +} + +TEST_F(ConstantTypedExprTest, variantTypeCheck) { + auto testVariantExpr = [&](const Variant& value, + const TypePtr& type, + const TypePtr& expectedType) { + VELOX_ASSERT_THROW( + createVariantExpr(type, value), + fmt::format( + "Expression type {} does not match variant type {}", + type->toString(), + expectedType->toString())); + if (type->isPrimitiveType()) { + VELOX_ASSERT_THROW( + createVariantExpr(type, Variant::null(expectedType->kind())), + fmt::format( + "Expression type {} does not match variant type {}", + type->toString(), + expectedType->toString())); + } else { + ASSERT_NO_THROW( + createVariantExpr(type, Variant::null(expectedType->kind()))); + } + }; + + testVariantExpr("abc", INTEGER(), VARCHAR()); + testVariantExpr(variant(123LL), INTEGER(), BIGINT()); + testVariantExpr(2.0, BIGINT(), DOUBLE()); + testVariantExpr( + variant::array({1, 2, 3}), ARRAY(VARCHAR()), ARRAY(INTEGER())); + testVariantExpr( + variant::map({{2.0, "xyz"}}), + MAP(INTEGER(), VARCHAR()), + MAP(DOUBLE(), VARCHAR())); +} + } // namespace facebook::velox::core::test diff --git a/velox/core/tests/PlanConsistencyCheckerTest.cpp b/velox/core/tests/PlanConsistencyCheckerTest.cpp index 41c7992b7dd..33b99e191da 100644 --- a/velox/core/tests/PlanConsistencyCheckerTest.cpp +++ b/velox/core/tests/PlanConsistencyCheckerTest.cpp @@ -17,6 +17,7 @@ #include "velox/common/base/tests/GTestUtils.h" #include "velox/core/PlanConsistencyChecker.h" +#include "velox/parse/PlanNodeIdGenerator.h" namespace facebook::velox::core { @@ -26,6 +27,16 @@ class PlanConsistencyCheckerTest : public testing::Test { static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); } + + void SetUp() override { + idGenerator_.reset(); + } + + std::string nextId() { + return idGenerator_.next(); + } + + PlanNodeIdGenerator idGenerator_; }; TypedExprPtr Lit(Variant value) { @@ -33,28 +44,28 @@ TypedExprPtr Lit(Variant value) { return std::make_shared(std::move(type), std::move(value)); } -TypedExprPtr Col(TypePtr type, std::string name) { +FieldAccessTypedExprPtr Col(TypePtr type, std::string name) { return std::make_shared( std::move(type), std::move(name)); } TEST_F(PlanConsistencyCheckerTest, filter) { auto valuesNode = - std::make_shared("0", std::vector{}); + std::make_shared(nextId(), std::vector{}); auto projectNode = std::make_shared( - "2", + nextId(), std::vector{"a", "b", "c"}, std::vector{Lit(true), Lit(1), Lit(0.1)}, valuesNode); auto filterNode = - std::make_shared("1", Col(BOOLEAN(), "a"), projectNode); + std::make_shared(nextId(), Col(BOOLEAN(), "a"), projectNode); ASSERT_NO_THROW(PlanConsistencyChecker::check(filterNode)); // Wrong type. filterNode = - std::make_shared("1", Col(BOOLEAN(), "b"), projectNode); + std::make_shared(nextId(), Col(BOOLEAN(), "b"), projectNode); VELOX_ASSERT_THROW( PlanConsistencyChecker::check(filterNode), @@ -62,42 +73,402 @@ TEST_F(PlanConsistencyCheckerTest, filter) { // Wrong name. filterNode = - std::make_shared("1", Col(BOOLEAN(), "x"), projectNode); + std::make_shared(nextId(), Col(BOOLEAN(), "x"), projectNode); VELOX_ASSERT_THROW( PlanConsistencyChecker::check(filterNode), "Field not found: x"); + + // Non-existent column referenced in a lambda expression. + filterNode = std::make_shared( + nextId(), + std::make_shared( + BOOLEAN(), + "any_match", + Lit(Variant::array({1, 2, 3})), + std::make_shared( + ROW("x", INTEGER()), + std::make_shared( + BOOLEAN(), + "lt", + Col(INTEGER(), "x"), + Col(INTEGER(), "blah")))), + projectNode); + + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(filterNode), "Field not found: blah"); } TEST_F(PlanConsistencyCheckerTest, project) { auto valuesNode = - std::make_shared("0", std::vector{}); + std::make_shared(nextId(), std::vector{}); + + { + auto projectNode = std::make_shared( + nextId(), + std::vector{"a", "b", "c"}, + std::vector{Lit(true), Lit(1), Lit(0.1)}, + valuesNode); + ASSERT_NO_THROW(PlanConsistencyChecker::check(projectNode)); + } + + { + // Duplicate output name in the root ProjectNode is allowed. This is used to + // apply user-specified column aliases (e.g. SELECT 1 AS x, 2 AS x). + auto projectNode = std::make_shared( + nextId(), + std::vector{"a", "a", "c"}, + std::vector{Lit(true), Lit(1), Lit(0.1)}, + valuesNode); + ASSERT_NO_THROW(PlanConsistencyChecker::check(projectNode)); + + // Duplicate output name in a non-root ProjectNode is not allowed. + auto outputProject = std::make_shared( + nextId(), + std::vector{"x", "y", "z"}, + std::vector{ + Col(BOOLEAN(), "a"), Col(BOOLEAN(), "a"), Col(DOUBLE(), "c")}, + projectNode); + + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(outputProject), + "Duplicate output column: a"); + } + + // Wrong column name. + { + auto projectNode = std::make_shared( + nextId(), + std::vector{"a", "a", "c"}, + std::vector{Lit(true), Col(REAL(), "x"), Lit(0.1)}, + valuesNode); + + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(projectNode), "Field not found: x"); + } +} + +TEST_F(PlanConsistencyCheckerTest, aggregation) { + auto valuesNode = + std::make_shared(nextId(), std::vector{}); auto projectNode = std::make_shared( - "2", + nextId(), std::vector{"a", "b", "c"}, std::vector{Lit(true), Lit(1), Lit(0.1)}, valuesNode); ASSERT_NO_THROW(PlanConsistencyChecker::check(projectNode)); + { + auto aggregationNode = std::make_shared( + nextId(), + AggregationNode::Step::kPartial, + std::vector{}, + std::vector{}, + std::vector{"sum", "cnt"}, + std::vector{ + { + .call = std::make_shared( + BIGINT(), "sum", Col(INTEGER(), "x")), + .rawInputTypes = {BIGINT()}, + }, + { + .call = std::make_shared(BIGINT(), "count"), + .rawInputTypes = {}, + }, + }, + /*ignoreNullKeys=*/false, + /*noGroupsSpanBatches=*/false, + projectNode); + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(aggregationNode), "Field not found: x"); + } + + { + auto aggregationNode = std::make_shared( + nextId(), + AggregationNode::Step::kPartial, + std::vector{Col(INTEGER(), "y")}, + std::vector{}, + std::vector{"sum", "cnt"}, + std::vector{ + { + .call = std::make_shared( + BIGINT(), "sum", Col(INTEGER(), "b")), + .rawInputTypes = {BIGINT()}, + }, + { + .call = std::make_shared(BIGINT(), "count"), + .rawInputTypes = {}, + }, + }, + /*ignoreNullKeys=*/false, + /*noGroupsSpanBatches=*/false, + projectNode); + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(aggregationNode), "Field not found: y"); + } + + { + auto aggregationNode = std::make_shared( + nextId(), + AggregationNode::Step::kPartial, + std::vector{}, + std::vector{}, + std::vector{"sum", "cnt"}, + std::vector{ + { + .call = std::make_shared( + BIGINT(), "sum", Col(INTEGER(), "b")), + .rawInputTypes = {BIGINT()}, + .mask = Col(BOOLEAN(), "z"), + }, + { + .call = std::make_shared(BIGINT(), "count"), + .rawInputTypes = {}, + }, + }, + /*ignoreNullKeys=*/false, + /*noGroupsSpanBatches=*/false, + projectNode); + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(aggregationNode), "Field not found: z"); + } + + { + auto aggregationNode = std::make_shared( + nextId(), + AggregationNode::Step::kPartial, + std::vector{}, + std::vector{}, + std::vector{"sum", "sum"}, + std::vector{ + { + .call = std::make_shared( + BIGINT(), "sum", Col(INTEGER(), "b")), + .rawInputTypes = {BIGINT()}, + }, + { + .call = std::make_shared(BIGINT(), "count"), + .rawInputTypes = {}, + }, + }, + /*ignoreNullKeys=*/false, + /*noGroupsSpanBatches=*/false, + projectNode); + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(aggregationNode), + "Duplicate output column: sum"); + } +} + +TEST_F(PlanConsistencyCheckerTest, hashJoin) { + auto leftValuesNode = + std::make_shared(nextId(), std::vector{}); + + auto leftProjectNode = std::make_shared( + nextId(), + std::vector{"a", "b"}, + std::vector{Lit(1), Lit(2)}, + leftValuesNode); + ASSERT_NO_THROW(PlanConsistencyChecker::check(leftValuesNode)); + + auto rightValuesNode = + std::make_shared(nextId(), std::vector{}); + + auto rightProjectNode = std::make_shared( + nextId(), + std::vector{"c", "d"}, + std::vector{Lit(1), Lit(2)}, + leftValuesNode); + ASSERT_NO_THROW(PlanConsistencyChecker::check(rightProjectNode)); + + // Invalid reference in the filter. + { + auto joinNode = std::make_shared( + nextId(), + JoinType::kLeft, + /*nullAware=*/false, + std::vector{Col(INTEGER(), "a")}, + std::vector{Col(INTEGER(), "c")}, + std::make_shared( + BOOLEAN(), "lt", Col(INTEGER(), "b"), Col(INTEGER(), "blah")), + leftProjectNode, + rightProjectNode, + ROW({})); + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(joinNode), + "Field not found: blah. Available fields are: a, b, c, d."); + } + + // Duplicate join condition. + { + auto joinNode = std::make_shared( + nextId(), + JoinType::kLeft, + /*nullAware=*/false, + std::vector{ + Col(INTEGER(), "a"), Col(INTEGER(), "a")}, + std::vector{ + Col(INTEGER(), "c"), Col(INTEGER(), "c")}, + /*filter=*/nullptr, + leftProjectNode, + rightProjectNode, + ROW({})); + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(joinNode), + "Duplicate join condition: \"a\" = \"c\""); + } +} + +TEST_F(PlanConsistencyCheckerTest, nestedLoopJoin) { + auto leftValuesNode = + std::make_shared(nextId(), std::vector{}); + + auto leftProjectNode = std::make_shared( + nextId(), + std::vector{"a", "b"}, + std::vector{Lit(1), Lit(2)}, + leftValuesNode); + ASSERT_NO_THROW(PlanConsistencyChecker::check(leftValuesNode)); + + auto rightValuesNode = + std::make_shared(nextId(), std::vector{}); + + auto rightProjectNode = std::make_shared( + nextId(), + std::vector{"c", "d"}, + std::vector{Lit(1), Lit(2)}, + leftValuesNode); + ASSERT_NO_THROW(PlanConsistencyChecker::check(rightProjectNode)); + + // Invalid reference in the filter. + { + auto joinNode = std::make_shared( + nextId(), + JoinType::kLeft, + std::make_shared( + BOOLEAN(), "lt", Col(INTEGER(), "b"), Col(INTEGER(), "blah")), + leftProjectNode, + rightProjectNode, + ROW({})); + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(joinNode), + "Field not found: blah. Available fields are: a, b, c, d."); + } + // Duplicate output name. - projectNode = std::make_shared( - "2", - std::vector{"a", "a", "c"}, - std::vector{Lit(true), Lit(1), Lit(0.1)}, - valuesNode); + { + auto joinNode = std::make_shared( + nextId(), + leftProjectNode, + rightProjectNode, + ROW({"a", "c", "a"}, INTEGER())); + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(joinNode), "Duplicate output column: a"); + } +} - VELOX_ASSERT_THROW( - PlanConsistencyChecker::check(projectNode), "Duplicate output column: a"); +namespace { +class TestTableHandle : public connector::ConnectorTableHandle { + public: + explicit TestTableHandle(std::string connectorId, std::string name) + : connector::ConnectorTableHandle(std::move(connectorId)), + name_{std::move(name)} {} - // Wrong column name. - projectNode = std::make_shared( - "2", - std::vector{"a", "a", "c"}, - std::vector{Lit(true), Col(REAL(), "x"), Lit(0.1)}, - valuesNode); + const std::string& name() const override { + return name_; + } - VELOX_ASSERT_THROW( - PlanConsistencyChecker::check(projectNode), "Field not found: x"); + private: + const std::string name_; +}; + +class TestColumnHandle : public connector::ColumnHandle { + public: + explicit TestColumnHandle(std::string name) : name_{std::move(name)} {} + + const std::string& name() const override { + return name_; + } + + private: + const std::string name_; +}; +} // namespace + +TEST_F(PlanConsistencyCheckerTest, tableScan) { + // Empty output column name. + { + auto scanNode = std::make_shared( + nextId(), + ROW({"", "b"}, INTEGER()), + std::make_shared("test", "t"), + connector::ColumnHandleMap{}); + + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(scanNode), + "Output column name cannot be empty"); + } + + // Duplicate output column name. + { + auto scanNode = std::make_shared( + nextId(), + ROW({"a", "b", "a"}, INTEGER()), + std::make_shared("test", "t"), + connector::ColumnHandleMap{}); + + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(scanNode), "Duplicate output column: a"); + } + + // Missing assignments. + { + auto scanNode = std::make_shared( + nextId(), + ROW({"a", "b", "c"}, INTEGER()), + std::make_shared("test", "t"), + connector::ColumnHandleMap{}); + + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(scanNode), + "Column assignments must match output type"); + } + + { + connector::ColumnHandleMap assignments{ + {"a", std::make_shared("x")}, + {"b", std::make_shared("y")}, + {"blah", std::make_shared("z")}, + }; + + auto scanNode = std::make_shared( + nextId(), + ROW({"a", "b", "c"}, INTEGER()), + std::make_shared("test", "t"), + assignments); + + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(scanNode), + "Column assignment is missing for c"); + } + + // No issues. + { + connector::ColumnHandleMap assignments{ + {"a", std::make_shared("x")}, + {"b", std::make_shared("y")}, + {"c", std::make_shared("z")}, + }; + + auto scanNode = std::make_shared( + nextId(), + ROW({"a", "b", "c"}, INTEGER()), + std::make_shared("test", "t"), + assignments); + + ASSERT_NO_THROW(PlanConsistencyChecker::check(scanNode)); + } } } // namespace diff --git a/velox/core/tests/PlanFragmentTest.cpp b/velox/core/tests/PlanFragmentTest.cpp index 4e05757dfc1..64bc4e587df 100644 --- a/velox/core/tests/PlanFragmentTest.cpp +++ b/velox/core/tests/PlanFragmentTest.cpp @@ -188,7 +188,8 @@ TEST_F(PlanFragmentTest, aggregationCanSpill) { testData.hasPreAggregation ? preGroupingKeys : emptyPreGroupingKeys, testData.isDistinct ? emptyAggregateNames : aggregateNames, testData.isDistinct ? emptyAggregates : aggregates, - false, + /*ignoreNullKeys=*/false, + /*noGroupsSpanBatches=*/false, valueNode_); auto queryCtx = getSpillQueryCtx( testData.isSpillEnabled, diff --git a/velox/core/tests/PlanNodeBuilderTest.cpp b/velox/core/tests/PlanNodeBuilderTest.cpp index 840baacb712..f199083b51a 100644 --- a/velox/core/tests/PlanNodeBuilderTest.cpp +++ b/velox/core/tests/PlanNodeBuilderTest.cpp @@ -17,6 +17,7 @@ #include "velox/common/memory/Memory.h" #include "velox/core/PlanNode.h" +#include "velox/core/TableWriteTraits.h" #include "velox/duckdb/conversion/DuckParser.h" #include "velox/exec/tests/utils/AggregationResolver.h" #include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" @@ -62,7 +63,7 @@ class PlanNodeBuilderTest : public testing::Test, public test::VectorTestBase { core::AggregationNode::Aggregate agg; agg.call = std::dynamic_pointer_cast( - core::Expressions::inferTypes(untypedExpr.expr, type, pool())); + core::Expressions::inferTypes(untypedExpr, type, pool())); if (step == core::AggregationNode::Step::kPartial || step == core::AggregationNode::Step::kSingle) { @@ -74,14 +75,14 @@ class PlanNodeBuilderTest : public testing::Test, public test::VectorTestBase { agg.rawInputTypes = rawInputArgs[i]; } - VELOX_CHECK_NULL(untypedExpr.maskExpr); - VELOX_CHECK(!untypedExpr.distinct); - VELOX_CHECK(untypedExpr.orderBy.empty()); + VELOX_CHECK_NULL(untypedExpr->filter()); + VELOX_CHECK(!untypedExpr->isDistinct()); + VELOX_CHECK(untypedExpr->orderBy().empty()); aggs.emplace_back(agg); - if (untypedExpr.expr->alias().has_value()) { - names.push_back(untypedExpr.expr->alias().value()); + if (untypedExpr->alias().has_value()) { + names.push_back(untypedExpr->alias().value()); } else { names.push_back(fmt::format("a{}", i)); } @@ -358,7 +359,6 @@ TEST_F(PlanNodeBuilderTest, tableWriteNode) { const PlanNodeId id = "table_write_node_id"; const RowTypePtr columns = ROW({"c0"}, {INTEGER()}); const std::vector columnNames{"c0"}; - const RowTypePtr outputType = ROW({"c1"}, {BIGINT()}); const bool hasPartitioningScheme = true; const auto commitStrategy = connector::CommitStrategy::kNoCommit; @@ -367,6 +367,7 @@ TEST_F(PlanNodeBuilderTest, tableWriteNode) { std::vector{}, AggregationNode::Step::kPartial, std::vector{"sum(c0)"}); + const auto outputType = TableWriteTraits::outputType(statsSpec); const auto insertTableHandle = std::make_shared("connector_id", nullptr); @@ -378,7 +379,7 @@ TEST_F(PlanNodeBuilderTest, tableWriteNode) { EXPECT_EQ(node->insertTableHandle(), insertTableHandle); EXPECT_TRUE(node->hasColumnStatsSpec()); EXPECT_EQ(node->hasPartitioningScheme(), hasPartitioningScheme); - EXPECT_EQ(node->outputType(), outputType); + EXPECT_TRUE(node->outputType()->equivalent(*outputType)); EXPECT_EQ(node->commitStrategy(), commitStrategy); EXPECT_EQ(node->sources(), std::vector{source_}); }; @@ -402,19 +403,19 @@ TEST_F(PlanNodeBuilderTest, tableWriteNode) { TEST_F(PlanNodeBuilderTest, tableWriteMergeNode) { const PlanNodeId id = "table_write_merge_node_id"; - const RowTypePtr outputType = ROW({"c0"}, {BIGINT()}); const auto statsSpec = createStatsSpec( - outputType, + source_->outputType(), std::vector{}, AggregationNode::Step::kIntermediate, std::vector{"sum(c0)"}, {{BIGINT()}}); + const auto outputType = TableWriteTraits::outputType(statsSpec); const auto verify = [&](const std::shared_ptr& node) { EXPECT_EQ(node->id(), id); - EXPECT_EQ(node->outputType(), outputType); + EXPECT_TRUE(node->outputType()->equivalent(*outputType)); EXPECT_TRUE(node->hasColumnStatsSpec()); EXPECT_EQ(node->sources()[0], source_); }; @@ -496,7 +497,7 @@ TEST_F(PlanNodeBuilderTest, groupIdNode) { TEST_F(PlanNodeBuilderTest, exchangeNode) { const PlanNodeId id = "exchange_node_id"; const RowTypePtr type = ROW({"c0"}, {BIGINT()}); - const auto serdeKind = VectorSerde::Kind::kPresto; + const auto serdeKind = "Presto"; const auto verify = [&](const std::shared_ptr& node) { EXPECT_EQ(node->id(), id); @@ -518,7 +519,7 @@ TEST_F(PlanNodeBuilderTest, exchangeNode) { TEST_F(PlanNodeBuilderTest, mergeExchangeNode) { const PlanNodeId id = "merge_exchange_node_id"; const RowTypePtr type = ROW({"c0"}, {BIGINT()}); - const auto serdeKind = VectorSerde::Kind::kPresto; + const auto serdeKind = "Presto"; const std::vector sortingKeys = { std::make_shared(BIGINT(), "c1")}; const std::vector sortingOrders = {SortOrder(true, false)}; @@ -611,7 +612,7 @@ TEST_F(PlanNodeBuilderTest, partitionedOutputNode) { const auto partitionFunctionSpec = std::make_shared(); const RowTypePtr outputType = ROW({"c0"}, {BIGINT()}); - const auto serdeKind = VectorSerde::Kind::kPresto; + const auto serdeKind = "Presto"; const auto verify = [&](const std::shared_ptr& node) { @@ -670,6 +671,7 @@ TEST_F(PlanNodeBuilderTest, hashJoinNode) { const auto verify = [&](const std::shared_ptr& node) { EXPECT_EQ(node->id(), id); EXPECT_EQ(node->isNullAware(), nullAware); + EXPECT_FALSE(node->isNullAsValue()); EXPECT_EQ(node->joinType(), joinType); EXPECT_EQ(node->leftKeys(), leftKeys); EXPECT_EQ(node->rightKeys(), rightKeys); @@ -756,8 +758,8 @@ TEST_F(PlanNodeBuilderTest, indexLookupJoinNode) { const std::vector joinConditions{ std::make_shared( std::make_shared(BIGINT(), "c0"), - std::make_shared(BIGINT(), variant(1)), - std::make_shared(BIGINT(), variant(2)))}; + std::make_shared(BIGINT(), Variant(1LL)), + std::make_shared(BIGINT(), Variant(2LL)))}; const auto left = ValuesNode::Builder() .id("values_node_id_1") @@ -768,8 +770,9 @@ TEST_F(PlanNodeBuilderTest, indexLookupJoinNode) { TableScanNode::Builder() .id("values_node_id_2") .outputType(ROW({"c1"}, {VARCHAR()})) - .tableHandle(std::make_shared( - "connector_id")) + .tableHandle( + std::make_shared( + "connector_id")) .assignments({{"c1", std::make_shared()}}) .build(); const auto outputType = ROW({"c0"}, {BIGINT()}); @@ -809,7 +812,7 @@ TEST_F(PlanNodeBuilderTest, nestedLoopJoinNode) { const PlanNodeId id = "nested_loop_join_node_id"; const auto joinType = JoinType::kLeft; const auto joinCondition = - std::make_shared(BOOLEAN(), variant(true)); + std::make_shared(BOOLEAN(), Variant(true)); const auto left = ValuesNode::Builder() .id("values_node_id_1") @@ -881,25 +884,35 @@ TEST_F(PlanNodeBuilderTest, spatialJoinNode) { const PlanNodeId id = "spatial_join_node_id"; const auto joinType = JoinType::kInner; const auto joinCondition = - std::make_shared(BOOLEAN(), variant(true)); - const auto left = - ValuesNode::Builder() - .id("values_node_id_1") - .values({makeRowVector( - {"c0"}, {makeFlatVector(std::vector{1})})}) - .build(); - const auto right = - ValuesNode::Builder() - .id("values_node_id_2") - .values({makeRowVector( - {"c1"}, {makeFlatVector(std::vector{2})})}) - .build(); + std::make_shared(BOOLEAN(), Variant(true)); + const auto left = ValuesNode::Builder() + .id("values_node_id_1") + .values({makeRowVector( + {"c0", "g0"}, + {makeFlatVector(std::vector{1}), + makeFlatVector( + std::vector{"POINT(0 0)"})})}) + .build(); + const auto right = ValuesNode::Builder() + .id("values_node_id_2") + .values({makeRowVector( + {"c1", "g1"}, + {makeFlatVector(std::vector{2}), + makeFlatVector( + std::vector{"POINT(0 0)"})})}) + .build(); const auto outputType = ROW({"c0"}, {BIGINT()}); + const auto probeGeom = + std::make_shared(VARCHAR(), "g0"); + const auto buildGeom = + std::make_shared(VARCHAR(), "g1"); const auto verify = [&](const std::shared_ptr& node) { EXPECT_EQ(node->id(), id); EXPECT_EQ(node->joinType(), joinType); EXPECT_EQ(node->joinCondition(), joinCondition); + EXPECT_EQ(node->probeGeometry(), probeGeom); + EXPECT_EQ(node->buildGeometry(), buildGeom); EXPECT_EQ(node->sources()[0], left); EXPECT_EQ(node->sources()[1], right); EXPECT_EQ(node->outputType(), outputType); @@ -911,6 +924,8 @@ TEST_F(PlanNodeBuilderTest, spatialJoinNode) { .joinCondition(joinCondition) .left(left) .right(right) + .probeGeometry(probeGeom) + .buildGeometry(buildGeom) .outputType(outputType) .build(); verify(node); @@ -988,6 +1003,7 @@ TEST_F(PlanNodeBuilderTest, unnestNode) { std::vector unnestNames{"b"}; std::optional ordinalityName = std::make_optional("ord"); + std::optional splitOutput = false; const auto verify = [&](const std::shared_ptr& node) { EXPECT_EQ(node->id(), id); @@ -995,6 +1011,7 @@ TEST_F(PlanNodeBuilderTest, unnestNode) { EXPECT_EQ(node->unnestVariables(), unnestVariables); EXPECT_TRUE(node->hasOrdinality()); EXPECT_EQ(node->sources()[0], source_); + EXPECT_EQ(node->splitOutput(), splitOutput); for (int i = 0; i < node->outputType()->size(); ++i) { if (i < replicateVariables.size()) { @@ -1017,6 +1034,7 @@ TEST_F(PlanNodeBuilderTest, unnestNode) { .unnestNames(unnestNames) .ordinalityName(ordinalityName) .source(source_) + .splitOutput(splitOutput) .build(); verify(node); diff --git a/velox/core/tests/PlanNodeTest.cpp b/velox/core/tests/PlanNodeTest.cpp index c09d3b44cd7..30ccd68f825 100644 --- a/velox/core/tests/PlanNodeTest.cpp +++ b/velox/core/tests/PlanNodeTest.cpp @@ -13,10 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "velox/core/PlanNode.h" #include - #include "velox/common/base/tests/GTestUtils.h" -#include "velox/core/PlanNode.h" +#include "velox/core/Expressions.h" +#include "velox/parse/PlanNodeIdGenerator.h" #include "velox/vector/fuzzer/VectorFuzzer.h" #include "velox/vector/tests/utils/VectorTestBase.h" @@ -85,6 +86,55 @@ TEST_F(PlanNodeTest, findFirstNode) { })); } +TEST_F(PlanNodeTest, findNodeById) { + auto values = std::make_shared("1", std::vector{}); + auto project = std::make_shared( + "2", + std::vector{"a", "b"}, + std::vector{ + std::make_shared(DOUBLE(), "rand"), + std::make_shared(DOUBLE(), "rand"), + }, + values); + + auto filter = std::make_shared( + "3", + std::make_shared( + BOOLEAN(), + "gt", + std::make_shared(DOUBLE(), "a"), + std::make_shared(DOUBLE(), 0.5)), + project); + + auto limit = std::make_shared("4", 0, 10, false, filter); + + ASSERT_EQ(PlanNode::findNodeById(limit.get(), "1"), values.get()); + ASSERT_EQ(PlanNode::findNodeById(limit.get(), "2"), project.get()); + ASSERT_EQ(PlanNode::findNodeById(limit.get(), "3"), filter.get()); + ASSERT_EQ(PlanNode::findNodeById(limit.get(), "4"), limit.get()); + + ASSERT_EQ(PlanNode::findNodeById(limit.get(), "5"), nullptr); + ASSERT_EQ(PlanNode::findNodeById(project.get(), "4"), nullptr); +} + +TEST_F(PlanNodeTest, is) { + auto values = std::make_shared("1", std::vector{}); + auto project = std::make_shared( + "2", + std::vector{"a", "b"}, + std::vector{ + std::make_shared(DOUBLE(), "rand"), + std::make_shared(DOUBLE(), "rand"), + }, + values); + + ASSERT_TRUE(values->is()); + ASSERT_FALSE(values->is()); + + ASSERT_FALSE(project->is()); + ASSERT_TRUE(project->is()); +} + TEST_F(PlanNodeTest, sortOrder) { struct { SortOrder order1; @@ -132,6 +182,7 @@ TEST_F(PlanNodeTest, duplicateSortKeys) { "orderBy", sortingKeys, sortingOrders, false, nullptr), "Duplicate sorting keys are not allowed: c0"); } + class TestIndexTableHandle : public connector::ConnectorTableHandle { public: TestIndexTableHandle() @@ -163,7 +214,32 @@ class TestIndexTableHandle : public connector::ConnectorTableHandle { } }; -TEST_F(PlanNodeTest, isIndexLookupJoin) { +TEST_F(PlanNodeTest, nestedLoopJoin) { + auto leftData = makeRowVector( + {"a"}, + { + makeFlatVector({1, 2, 3, 4, 5}), + }); + + auto rightData = makeRowVector({ + makeFlatVector({1, 2, 3, 4, 5}), + }); + + core::PlanNodeIdGenerator planNodeIdGenerator; + auto nextId = [&planNodeIdGenerator]() { return planNodeIdGenerator.next(); }; + + auto leftValues = std::make_shared( + nextId(), std::vector{leftData}); + auto rightValues = std::make_shared( + nextId(), std::vector{rightData}); + + VELOX_ASSERT_THROW( + std::make_shared( + nextId(), leftValues, rightValues, ROW({"a"}, VARCHAR())), + "Join output column type must match the input type: VARCHAR vs. INTEGER"); +} + +TEST_F(PlanNodeTest, indexLookupJoin) { const auto rowType = ROW({"name"}, {BIGINT()}); const auto valueNode = std::make_shared("orderBy", rowData_); ASSERT_FALSE(isIndexLookupJoin(valueNode.get())); @@ -193,12 +269,17 @@ TEST_F(PlanNodeTest, isIndexLookupJoin) { leftKeys, rightKeys, std::vector{}, - /*includeMatchColumn=*/false, + /*filter=*/nullptr, + /*hasMarker=*/false, probeNode, buildNode, outputType); ASSERT_TRUE(isIndexLookupJoin(indexJoinNodeWithInnerJoin.get())); - ASSERT_FALSE(indexJoinNodeWithInnerJoin->includeMatchColumn()); + ASSERT_FALSE(indexJoinNodeWithInnerJoin->hasMarker()); + ASSERT_EQ(indexJoinNodeWithInnerJoin->filter(), nullptr); + ASSERT_EQ( + indexJoinNodeWithInnerJoin->toString(/*detailed=*/true), + "-- IndexLookupJoin[indexJoinNode][INNER c0=c1] -> c0:BIGINT, c1:BIGINT\n"); } { const RowTypePtr outputTypeWithMatchColumn = @@ -210,12 +291,39 @@ TEST_F(PlanNodeTest, isIndexLookupJoin) { leftKeys, rightKeys, std::vector{}, - /*includeMatchColumn=*/true, + /*filter=*/nullptr, + /*hasMarker=*/true, probeNode, buildNode, outputTypeWithMatchColumn); ASSERT_TRUE(isIndexLookupJoin(indexJoinNodeWithLeftJoin.get())); - ASSERT_TRUE(indexJoinNodeWithLeftJoin->includeMatchColumn()); + ASSERT_TRUE(indexJoinNodeWithLeftJoin->hasMarker()); + ASSERT_EQ(indexJoinNodeWithLeftJoin->filter(), nullptr); + ASSERT_EQ( + indexJoinNodeWithLeftJoin->toString(/*detailed=*/true), + "-- IndexLookupJoin[indexJoinNode][LEFT c0=c1] -> c0:BIGINT, c1:BIGINT, c2:BOOLEAN\n"); + } + { + // Test IndexLookupJoinNode with filter + const auto filterExpr = std::make_shared( + BOOLEAN(), "filter_column"); + const auto indexJoinNodeWithFilter = std::make_shared( + "indexJoinNodeWithFilter", + core::JoinType::kInner, + leftKeys, + rightKeys, + std::vector{}, + /*filter=*/filterExpr, + /*hasMarker=*/false, + probeNode, + buildNode, + outputType); + ASSERT_TRUE(isIndexLookupJoin(indexJoinNodeWithFilter.get())); + ASSERT_FALSE(indexJoinNodeWithFilter->hasMarker()); + ASSERT_EQ(indexJoinNodeWithFilter->filter(), filterExpr); + ASSERT_EQ( + indexJoinNodeWithFilter->toString(/*detailed=*/true), + "-- IndexLookupJoin[indexJoinNodeWithFilter][INNER c0=c1, filter: \"filter_column\"] -> c0:BIGINT, c1:BIGINT\n"); } // Error case. { @@ -226,7 +334,8 @@ TEST_F(PlanNodeTest, isIndexLookupJoin) { leftKeys, rightKeys, std::vector{}, - /*includeMatchColumn=*/true, + /*filter=*/nullptr, + /*hasMarker=*/true, probeNode, buildNode, outputType), @@ -240,7 +349,8 @@ TEST_F(PlanNodeTest, isIndexLookupJoin) { leftKeys, rightKeys, std::vector{}, - /*includeMatchColumn=*/true, + /*filter=*/nullptr, + /*hasMarker=*/true, probeNode, buildNode, outputType), @@ -256,12 +366,28 @@ TEST_F(PlanNodeTest, isIndexLookupJoin) { leftKeys, rightKeys, std::vector{}, - /*includeMatchColumn=*/true, + /*filter=*/nullptr, + /*hasMarker=*/true, probeNode, buildNode, outputTypeWithDuplicateMatchColumn), ""); } + { + VELOX_ASSERT_THROW( + std::make_shared( + "indexJoinNode", + core::JoinType::kLeft, + leftKeys, + rightKeys, + std::vector{}, + /*filter=*/nullptr, + /*hasMarker=*/false, + probeNode, + buildNode, + ROW({"c0", "c1"}, {VARCHAR(), BIGINT()})), + "Join output column type must match the input type: VARCHAR vs. BIGINT"); + } } TEST_F(PlanNodeTest, partitionedOutputNode) { @@ -272,7 +398,7 @@ TEST_F(PlanNodeTest, partitionedOutputNode) { std::make_shared(BIGINT(), "c0")}; const PartitionFunctionSpecPtr partitionFunctionSpec = std::make_shared(); - const VectorSerde::Kind serdeKind = VectorSerde::Kind::kPresto; + const std::string serdeKind = "Presto"; PlanNodePtr source = std::make_shared("source", rowData_); { @@ -373,4 +499,239 @@ TEST_F(PlanNodeTest, partitionedOutputNode) { source), "partitioning doesn't allow for partitioning keys"); } + +TEST_F(PlanNodeTest, aggregationNodeNoGroupsSpanBatches) { + auto values = std::make_shared("values", rowData_); + + const std::vector groupingKeys{ + std::make_shared(BIGINT(), "c0")}; + const std::vector preGroupedKeys{ + std::make_shared(BIGINT(), "c0")}; + const std::vector aggregateNames{"sum"}; + const std::vector aggregates{ + {.call = std::make_shared(BIGINT(), "sum"), + .rawInputTypes = {BIGINT()}}}; + + // noGroupsSpanBatches=true with preGroupedKeys (streaming aggregation) should + // succeed and the accessor should return true. + { + auto aggNode = std::make_shared( + "agg", + AggregationNode::Step::kSingle, + groupingKeys, + preGroupedKeys, + aggregateNames, + aggregates, + /*ignoreNullKeys=*/false, + /*noGroupsSpanBatches=*/true, + values); + ASSERT_TRUE(aggNode->noGroupsSpanBatches()); + ASSERT_TRUE(aggNode->isPreGrouped()); + ASSERT_EQ( + aggNode->toString(true), + "-- Aggregation[agg][SINGLE STREAMING [c0] sum := sum() noGroupsSpanBatches] -> c0:BIGINT, sum:BIGINT\n"); + } + + // noGroupsSpanBatches=false with preGroupedKeys should succeed and the + // accessor should return false. + { + auto aggNode = std::make_shared( + "agg", + AggregationNode::Step::kSingle, + groupingKeys, + preGroupedKeys, + aggregateNames, + aggregates, + /*ignoreNullKeys=*/false, + /*noGroupsSpanBatches=*/false, + values); + ASSERT_FALSE(aggNode->noGroupsSpanBatches()); + ASSERT_TRUE(aggNode->isPreGrouped()); + ASSERT_EQ( + aggNode->toString(true), + "-- Aggregation[agg][SINGLE STREAMING [c0] sum := sum()] -> c0:BIGINT, sum:BIGINT\n"); + } + + // noGroupsSpanBatches=true without preGroupedKeys (non-streaming aggregation) + // should fail. + VELOX_ASSERT_THROW( + std::make_shared( + "agg", + AggregationNode::Step::kSingle, + groupingKeys, + /*preGroupedKeys=*/std::vector{}, + aggregateNames, + aggregates, + /*ignoreNullKeys=*/false, + /*noGroupsSpanBatches=*/true, + values), + "noGroupsSpanBatches can only be set for streaming aggregation (pre-grouped)"); + + // noGroupsSpanBatches=false without preGroupedKeys should succeed. + { + auto aggNode = std::make_shared( + "agg", + AggregationNode::Step::kSingle, + groupingKeys, + /*preGroupedKeys=*/std::vector{}, + aggregateNames, + aggregates, + /*ignoreNullKeys=*/false, + /*noGroupsSpanBatches=*/false, + values); + ASSERT_FALSE(aggNode->noGroupsSpanBatches()); + ASSERT_FALSE(aggNode->isPreGrouped()); + ASSERT_EQ( + aggNode->toString(true), + "-- Aggregation[agg][SINGLE [c0] sum := sum()] -> c0:BIGINT, sum:BIGINT\n"); + } +} +TEST_F(PlanNodeTest, rpcNodeSerdePerRowMode) { + Type::registerSerDe(); + core::PlanNode::registerSerDe(); + core::ITypedExpr::registerSerDe(); + + // Create a ValuesNode as the source with a "prompt" column. + auto sourceData = + makeRowVector({"prompt"}, {makeFlatVector({"hello"})}); + auto valuesNode = std::make_shared( + "values-1", std::vector{sourceData}); + + auto rpcNode = std::make_shared( + "rpc-1", + valuesNode, + "test_function", + VARCHAR(), + "response", + ROW({"prompt", "response"}, {VARCHAR(), VARCHAR()}), + std::vector{"prompt"}, + std::vector{VARCHAR()}, + std::vector{nullptr}, + rpc::RPCStreamingMode::kPerRow, + 0); + + // Serialize and deserialize. + const auto serialized = rpcNode->serialize(); + auto copy = + ISerializable::deserialize(serialized, pool_.get()); + + // Compare detailed string representation. + ASSERT_EQ(rpcNode->toString(true, true), copy->toString(true, true)); + + // Verify deserialized fields. + auto* copyRpc = dynamic_cast(copy.get()); + ASSERT_NE(copyRpc, nullptr); + EXPECT_EQ(copyRpc->functionName(), "test_function"); + EXPECT_EQ(copyRpc->outputColumn(), "response"); + EXPECT_EQ(copyRpc->streamingMode(), rpc::RPCStreamingMode::kPerRow); + EXPECT_EQ(copyRpc->dispatchBatchSize(), 0); + EXPECT_EQ(*copyRpc->rpcResultType(), *VARCHAR()); + EXPECT_EQ(copyRpc->argumentColumns().size(), 1); + EXPECT_EQ(copyRpc->argumentColumns()[0], "prompt"); + EXPECT_EQ(copyRpc->argumentTypes().size(), 1); + EXPECT_EQ(*copyRpc->argumentTypes()[0], *VARCHAR()); +} + +TEST_F(PlanNodeTest, rpcNodeSerdeBatchMode) { + Type::registerSerDe(); + core::PlanNode::registerSerDe(); + core::ITypedExpr::registerSerDe(); + + auto sourceData = makeRowVector( + {"prompt", "model"}, + {makeFlatVector({"hello"}), + makeFlatVector({"llama"})}); + auto valuesNode = std::make_shared( + "values-1", std::vector{sourceData}); + + auto rpcNode = std::make_shared( + "rpc-2", + valuesNode, + "batch_function", + VARCHAR(), + "result", + ROW({"prompt", "model", "result"}, {VARCHAR(), VARCHAR(), VARCHAR()}), + std::vector{"prompt", "model"}, + std::vector{VARCHAR(), VARCHAR()}, + std::vector{nullptr, nullptr}, + rpc::RPCStreamingMode::kBatch, + 50); + + const auto serialized = rpcNode->serialize(); + auto copy = + ISerializable::deserialize(serialized, pool_.get()); + + ASSERT_EQ(rpcNode->toString(true, true), copy->toString(true, true)); + + auto* copyRpc = dynamic_cast(copy.get()); + ASSERT_NE(copyRpc, nullptr); + EXPECT_EQ(copyRpc->functionName(), "batch_function"); + EXPECT_EQ(copyRpc->outputColumn(), "result"); + EXPECT_EQ(copyRpc->streamingMode(), rpc::RPCStreamingMode::kBatch); + EXPECT_EQ(copyRpc->dispatchBatchSize(), 50); + EXPECT_EQ(copyRpc->argumentColumns().size(), 2); + EXPECT_EQ(copyRpc->argumentColumns()[0], "prompt"); + EXPECT_EQ(copyRpc->argumentColumns()[1], "model"); +} + +TEST_F(PlanNodeTest, rpcNodeSerdeWithConstants) { + Type::registerSerDe(); + core::PlanNode::registerSerDe(); + core::ITypedExpr::registerSerDe(); + + auto sourceData = + makeRowVector({"prompt"}, {makeFlatVector({"hello"})}); + auto valuesNode = std::make_shared( + "values-1", std::vector{sourceData}); + + // Create constant vectors for model and system_prompt arguments. + auto modelConstant = makeConstant("llama3", 1); + auto systemPromptConstant = makeConstant("You are helpful.", 1); + + auto rpcNode = std::make_shared( + "rpc-3", + valuesNode, + "test_function", + VARCHAR(), + "response", + ROW({"prompt", "response"}, {VARCHAR(), VARCHAR()}), + std::vector{"prompt", "model", "system_prompt"}, + std::vector{VARCHAR(), VARCHAR(), VARCHAR()}, + std::vector{nullptr, modelConstant, systemPromptConstant}); + + // Verify constants before serde. + ASSERT_EQ(rpcNode->constantInputs().size(), 3); + EXPECT_EQ(rpcNode->constantInputs()[0], nullptr); + EXPECT_NE(rpcNode->constantInputs()[1], nullptr); + EXPECT_NE(rpcNode->constantInputs()[2], nullptr); + + // Serialize and deserialize. + const auto serialized = rpcNode->serialize(); + auto copy = + ISerializable::deserialize(serialized, pool_.get()); + + auto* copyRpc = dynamic_cast(copy.get()); + ASSERT_NE(copyRpc, nullptr); + EXPECT_EQ(copyRpc->functionName(), "test_function"); + EXPECT_EQ(copyRpc->argumentColumns().size(), 3); + + // Verify constants survive the round-trip. + ASSERT_EQ(copyRpc->constantInputs().size(), 3); + EXPECT_EQ(copyRpc->constantInputs()[0], nullptr); + ASSERT_NE(copyRpc->constantInputs()[1], nullptr); + ASSERT_NE(copyRpc->constantInputs()[2], nullptr); + + // Verify constant values. + auto modelVec = copyRpc->constantInputs()[1]; + EXPECT_TRUE(modelVec->isConstantEncoding()); + EXPECT_EQ( + modelVec->as>()->valueAt(0).str(), "llama3"); + + auto promptVec = copyRpc->constantInputs()[2]; + EXPECT_TRUE(promptVec->isConstantEncoding()); + EXPECT_EQ( + promptVec->as>()->valueAt(0).str(), + "You are helpful."); +} + } // namespace diff --git a/velox/core/tests/QueryConfigProviderTest.cpp b/velox/core/tests/QueryConfigProviderTest.cpp new file mode 100644 index 00000000000..4eceef24859 --- /dev/null +++ b/velox/core/tests/QueryConfigProviderTest.cpp @@ -0,0 +1,111 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/core/QueryConfigProvider.h" +#include +#include "velox/core/QueryConfig.h" + +using namespace facebook::velox::config; +using namespace facebook::velox::core; + +class QueryConfigProviderTest : public ::testing::Test { + protected: + QueryConfigProvider provider_; +}; + +TEST_F(QueryConfigProviderTest, propertiesNotEmpty) { + auto props = provider_.properties(); + EXPECT_GT(props.size(), 140); +} + +TEST_F(QueryConfigProviderTest, allNamesNonEmpty) { + for (const auto& prop : provider_.properties()) { + EXPECT_FALSE(prop.name.empty()) << "Found property with empty name"; + EXPECT_FALSE(prop.description.empty()) + << "Property " << prop.name << " has empty description"; + } +} + +TEST_F(QueryConfigProviderTest, noDuplicateNames) { + auto props = provider_.properties(); + std::set names; + for (const auto& prop : props) { + EXPECT_TRUE(names.insert(prop.name).second) + << "Duplicate property name: " << prop.name; + } +} + +TEST_F(QueryConfigProviderTest, knownProperties) { + auto props = provider_.properties(); + + auto findProp = + [&](const std::string& name) -> std::optional { + for (const auto& prop : props) { + if (prop.name == name) { + return prop; + } + } + return std::nullopt; + }; + + // Check a boolean property. + auto spillEnabled = findProp(QueryConfig::kSpillEnabled); + ASSERT_TRUE(spillEnabled.has_value()); + EXPECT_EQ(spillEnabled->type, ConfigPropertyType::kBoolean); + EXPECT_EQ(spillEnabled->defaultValue, "false"); + + // Check a string property. + auto sessionTz = findProp(QueryConfig::kSessionTimezone); + ASSERT_TRUE(sessionTz.has_value()); + EXPECT_EQ(sessionTz->type, ConfigPropertyType::kString); + EXPECT_EQ(sessionTz->defaultValue, ""); + + // Check an integer property (macro-registered). + auto startTime = findProp(QueryConfig::kSessionStartTime); + ASSERT_TRUE(startTime.has_value()); + EXPECT_EQ(startTime->type, ConfigPropertyType::kInteger); + EXPECT_EQ(startTime->defaultValue, "0"); + + // Check a double property (macro-registered). + auto cpuOverhead = + findProp(QueryConfig::kExprAdaptiveCpuSamplingMaxOverheadPct); + ASSERT_TRUE(cpuOverhead.has_value()); + EXPECT_EQ(cpuOverhead->type, ConfigPropertyType::kDouble); + EXPECT_EQ(cpuOverhead->defaultValue, "1"); +} + +TEST_F(QueryConfigProviderTest, normalizePassthrough) { + EXPECT_EQ(provider_.normalize("spill_enabled", "true"), "true"); + EXPECT_EQ(provider_.normalize("session_timezone", "UTC"), "UTC"); +} + +TEST_F(QueryConfigProviderTest, configPropertyTypeNames) { + EXPECT_EQ( + ConfigPropertyTypeName::toName(ConfigPropertyType::kBoolean), "BOOLEAN"); + EXPECT_EQ( + ConfigPropertyTypeName::toName(ConfigPropertyType::kInteger), "INTEGER"); + EXPECT_EQ( + ConfigPropertyTypeName::toName(ConfigPropertyType::kDouble), "DOUBLE"); + EXPECT_EQ( + ConfigPropertyTypeName::toName(ConfigPropertyType::kString), "STRING"); + + EXPECT_EQ( + ConfigPropertyTypeName::toConfigPropertyType("BOOLEAN"), + ConfigPropertyType::kBoolean); + EXPECT_EQ( + ConfigPropertyTypeName::toConfigPropertyType("INTEGER"), + ConfigPropertyType::kInteger); +} diff --git a/velox/core/tests/QueryConfigTest.cpp b/velox/core/tests/QueryConfigTest.cpp index daab3fad2a0..ce7d7124d7c 100644 --- a/velox/core/tests/QueryConfigTest.cpp +++ b/velox/core/tests/QueryConfigTest.cpp @@ -168,8 +168,9 @@ TEST_F(QueryConfigTest, expressionEvaluationRelatedConfigs) { std::make_shared(pool.get(), queryCtx.get()); auto evalCtx = std::make_shared(execCtx.get()); + SelectivityVector rows(100, true); ASSERT_EQ( - evalCtx->peelingEnabled(), + evalCtx->peelingEnabled(rows), !queryConfig.debugDisableExpressionsWithPeeling()); ASSERT_EQ( evalCtx->sharedSubExpressionReuseEnabled(), @@ -203,6 +204,151 @@ TEST_F(QueryConfigTest, expressionEvaluationRelatedConfigs) { testConfig(createConfig(false, true, false, false)); testConfig(createConfig(false, false, true, false)); testConfig(createConfig(false, false, false, true)); + + // Verify minRowsForPeeling: peeling is suppressed when the number of + // selected rows is below the threshold. + { + std::unordered_map configData( + {{core::QueryConfig::kMinRowsForPeeling, "50"}}); + auto queryCtx = + core::QueryCtx::create(nullptr, QueryConfig{std::move(configData)}); + auto execCtx = std::make_shared(pool.get(), queryCtx.get()); + auto evalCtx = std::make_shared(execCtx.get()); + + SelectivityVector belowThreshold(30, true); + ASSERT_FALSE(evalCtx->peelingEnabled(belowThreshold)); + + SelectivityVector atThreshold(50, true); + ASSERT_TRUE(evalCtx->peelingEnabled(atThreshold)); + + SelectivityVector aboveThreshold(100, true); + ASSERT_TRUE(evalCtx->peelingEnabled(aboveThreshold)); + } +} + +TEST_F(QueryConfigTest, sessionStartTime) { + // Test with no session start time set + { + auto queryCtx = QueryCtx::create(nullptr, QueryConfig{{}}); + const QueryConfig& config = queryCtx->queryConfig(); + + EXPECT_EQ(config.sessionStartTimeMs(), 0); + } + + // Test with session start time set + { + int64_t startTimeMs = 1674123456789; // Some timestamp in milliseconds + std::unordered_map configData( + {{QueryConfig::kSessionStartTime, std::to_string(startTimeMs)}}); + auto queryCtx = + QueryCtx::create(nullptr, QueryConfig{std::move(configData)}); + const QueryConfig& config = queryCtx->queryConfig(); + + EXPECT_EQ(config.sessionStartTimeMs(), startTimeMs); + } + + // Test with negative session start time (should be valid) + { + int64_t negativeStartTime = -1000; + std::unordered_map configData( + {{QueryConfig::kSessionStartTime, std::to_string(negativeStartTime)}}); + auto queryCtx = + QueryCtx::create(nullptr, QueryConfig{std::move(configData)}); + const QueryConfig& config = queryCtx->queryConfig(); + + EXPECT_EQ(config.sessionStartTimeMs(), negativeStartTime); + } + + // Test with maximum int64_t value + { + int64_t maxTime = std::numeric_limits::max(); + std::unordered_map configData( + {{QueryConfig::kSessionStartTime, std::to_string(maxTime)}}); + auto queryCtx = + QueryCtx::create(nullptr, QueryConfig{std::move(configData)}); + const QueryConfig& config = queryCtx->queryConfig(); + + EXPECT_EQ(config.sessionStartTimeMs(), maxTime); + } +} + +TEST_F(QueryConfigTest, singleSourceExchangeOptimizationConfig) { + // Test default value (should be false) + { + auto queryCtx = QueryCtx::create(nullptr, QueryConfig{{}}); + const QueryConfig& config = queryCtx->queryConfig(); + EXPECT_FALSE(config.singleSourceExchangeOptimizationEnabled()); + } + + // Test with optimization enabled + { + std::unordered_map configData( + {{QueryConfig::kSkipRequestDataSizeWithSingleSourceEnabled, "true"}}); + auto queryCtx = + QueryCtx::create(nullptr, QueryConfig{std::move(configData)}); + const QueryConfig& config = queryCtx->queryConfig(); + EXPECT_TRUE(config.singleSourceExchangeOptimizationEnabled()); + } + + // Test with optimization explicitly disabled + { + std::unordered_map configData( + {{QueryConfig::kSkipRequestDataSizeWithSingleSourceEnabled, "false"}}); + auto queryCtx = + QueryCtx::create(nullptr, QueryConfig{std::move(configData)}); + const QueryConfig& config = queryCtx->queryConfig(); + EXPECT_FALSE(config.singleSourceExchangeOptimizationEnabled()); + } +} + +TEST_F(QueryConfigTest, operatorSpillFileCreateConfig) { + // Test default values (empty strings) + { + auto queryCtx = QueryCtx::create(nullptr, QueryConfig{{}}); + const QueryConfig& config = queryCtx->queryConfig(); + EXPECT_EQ(config.aggregationSpillFileCreateConfig(), ""); + EXPECT_EQ(config.hashJoinSpillFileCreateConfig(), ""); + } + + // Test with aggregation spill file create config set + { + std::unordered_map configData( + {{QueryConfig::kAggregationSpillFileCreateConfig, + "aggregation_config_value"}}); + auto queryCtx = + QueryCtx::create(nullptr, QueryConfig{std::move(configData)}); + const QueryConfig& config = queryCtx->queryConfig(); + EXPECT_EQ( + config.aggregationSpillFileCreateConfig(), "aggregation_config_value"); + EXPECT_EQ(config.hashJoinSpillFileCreateConfig(), ""); + } + + // Test with hash join spill file create config set + { + std::unordered_map configData( + {{QueryConfig::kHashJoinSpillFileCreateConfig, + "hashjoin_config_value"}}); + auto queryCtx = + QueryCtx::create(nullptr, QueryConfig{std::move(configData)}); + const QueryConfig& config = queryCtx->queryConfig(); + EXPECT_EQ(config.aggregationSpillFileCreateConfig(), ""); + EXPECT_EQ(config.hashJoinSpillFileCreateConfig(), "hashjoin_config_value"); + } + + // Test with both configs set + { + std::unordered_map configData( + {{QueryConfig::kAggregationSpillFileCreateConfig, + "aggregation_config_value"}, + {QueryConfig::kHashJoinSpillFileCreateConfig, + "hashjoin_config_value"}}); + auto queryCtx = + QueryCtx::create(nullptr, QueryConfig{std::move(configData)}); + const QueryConfig& config = queryCtx->queryConfig(); + EXPECT_EQ( + config.aggregationSpillFileCreateConfig(), "aggregation_config_value"); + EXPECT_EQ(config.hashJoinSpillFileCreateConfig(), "hashjoin_config_value"); + } } } // namespace facebook::velox::core::test diff --git a/velox/core/tests/QueryCtxTest.cpp b/velox/core/tests/QueryCtxTest.cpp index 52ce900640b..44dbc5927c4 100644 --- a/velox/core/tests/QueryCtxTest.cpp +++ b/velox/core/tests/QueryCtxTest.cpp @@ -51,4 +51,90 @@ TEST_F(QueryCtxTest, withSysRootPool) { ASSERT_FALSE( queryPool->reclaimer()->reclaimableBytes(*queryPool, reclaimableBytes)); } + +TEST_F(QueryCtxTest, releaseCallbacks) { + int callbackCount = 0; + std::string capturedQueryId; + + { + auto queryCtx = QueryCtx::create( + nullptr, + QueryConfig{{}}, + std::unordered_map>{}, + nullptr, + nullptr, + nullptr, + "test_query_id"); + + // Add multiple callbacks. + queryCtx->addReleaseCallback([&callbackCount]() { ++callbackCount; }); + + queryCtx->addReleaseCallback( + [&callbackCount, &capturedQueryId, id = queryCtx->queryId()]() { + ++callbackCount; + capturedQueryId = id; + }); + + // Callbacks should not be invoked yet. + ASSERT_EQ(callbackCount, 0); + } + + // After QueryCtx destruction, all callbacks should have been invoked. + ASSERT_EQ(callbackCount, 2); + ASSERT_EQ(capturedQueryId, "test_query_id"); +} + +TEST_F(QueryCtxTest, releaseCallbackException) { + int callbackCount = 0; + + { + auto queryCtx = QueryCtx::create( + nullptr, + QueryConfig{{}}, + std::unordered_map>{}, + nullptr, + nullptr, + nullptr, + "test_query_id"); + + // First callback succeeds. + queryCtx->addReleaseCallback([&callbackCount]() { ++callbackCount; }); + + // Second callback throws an exception. + queryCtx->addReleaseCallback( + []() { throw std::runtime_error("Test exception"); }); + + // Third callback should still execute despite the previous exception. + queryCtx->addReleaseCallback([&callbackCount]() { ++callbackCount; }); + } + + // All callbacks should have been attempted, with exception caught and logged. + // First and third callbacks should have incremented the counter. + ASSERT_EQ(callbackCount, 2); +} + +TEST_F(QueryCtxTest, builderReleaseCallbacks) { + int callbackCount = 0; + std::string capturedQueryId; + + { + // Use builder to add release callbacks during construction. + auto queryCtx = + QueryCtx::Builder() + .queryId("builder_test_query_id") + .releaseCallback([&callbackCount]() { ++callbackCount; }) + .releaseCallback([&callbackCount, &capturedQueryId]() { + ++callbackCount; + capturedQueryId = "builder_test_query_id"; + }) + .build(); + + // Callbacks should not be invoked yet. + ASSERT_EQ(callbackCount, 0); + } + + // After QueryCtx destruction, all callbacks should have been invoked. + ASSERT_EQ(callbackCount, 2); + ASSERT_EQ(capturedQueryId, "builder_test_query_id"); +} } // namespace facebook::velox::core::test diff --git a/velox/core/tests/TypedExprHashConsistencyTest.cpp b/velox/core/tests/TypedExprHashConsistencyTest.cpp new file mode 100644 index 00000000000..5926510f233 --- /dev/null +++ b/velox/core/tests/TypedExprHashConsistencyTest.cpp @@ -0,0 +1,286 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "velox/common/memory/Memory.h" +#include "velox/core/Expressions.h" +#include "velox/type/Variant.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +namespace facebook::velox::core::test { + +/// Tests for expression hash consistency. +/// +/// These tests verify that hash functions are deterministic and stable: +/// 1. Same expression hashed multiple times produces same result +/// 2. Semantically equivalent expressions have same hash +/// 3. Hash survives serialization roundtrip +/// 4. Different expressions produce different hashes +/// +/// Note: We do NOT use hardcoded expected hash values as that makes tests +/// brittle. Instead we test hash properties. +class TypedExprHashConsistencyTest : public ::testing::Test, + public velox::test::VectorTestBase { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + } + + void SetUp() override { + pool_ = memory::memoryManager()->addLeafPool(); + Type::registerSerDe(); + ITypedExpr::registerSerDe(); + } + + std::shared_ptr makeConstantExpr( + const TypePtr& type, + const Variant& value) { + return std::make_shared(type, value); + } + + std::shared_ptr pool_; +}; + +// Test that hashing the same expression multiple times gives same result +TEST_F(TypedExprHashConsistencyTest, idempotency) { + // InputTypedExpr + auto inputExpr = std::make_shared(INTEGER()); + EXPECT_EQ(inputExpr->hash(), inputExpr->hash()); + + // ConstantTypedExpr with various types + auto boolExpr = makeConstantExpr(BOOLEAN(), variant(true)); + EXPECT_EQ(boolExpr->hash(), boolExpr->hash()); + + auto intExpr = makeConstantExpr(INTEGER(), variant(int32_t(42))); + EXPECT_EQ(intExpr->hash(), intExpr->hash()); + + auto strExpr = makeConstantExpr(VARCHAR(), variant("hello")); + EXPECT_EQ(strExpr->hash(), strExpr->hash()); + + // CallTypedExpr + auto callExpr = std::make_shared( + BIGINT(), + std::vector{ + std::make_shared(BIGINT(), "a")}, + "plus"); + EXPECT_EQ(callExpr->hash(), callExpr->hash()); + + // FieldAccessTypedExpr + auto fieldExpr = std::make_shared(BIGINT(), "column"); + EXPECT_EQ(fieldExpr->hash(), fieldExpr->hash()); + + // CastTypedExpr + auto castExpr = std::make_shared( + VARCHAR(), + std::vector{ + std::make_shared(INTEGER(), "x")}, + false); + EXPECT_EQ(castExpr->hash(), castExpr->hash()); +} + +// Test that equivalent expressions created separately have same hash +TEST_F(TypedExprHashConsistencyTest, equivalentExpressions) { + // Create two equivalent InputTypedExpr + auto input1 = std::make_shared(INTEGER()); + auto input2 = std::make_shared(INTEGER()); + EXPECT_EQ(input1->hash(), input2->hash()); + + // Create two equivalent ConstantTypedExpr + auto const1 = makeConstantExpr(INTEGER(), variant(int32_t(42))); + auto const2 = makeConstantExpr(INTEGER(), variant(int32_t(42))); + EXPECT_EQ(const1->hash(), const2->hash()); + + // Create two equivalent CallTypedExpr + auto call1 = std::make_shared( + BIGINT(), + std::vector{ + std::make_shared(BIGINT(), "x")}, + "negate"); + auto call2 = std::make_shared( + BIGINT(), + std::vector{ + std::make_shared(BIGINT(), "x")}, + "negate"); + EXPECT_EQ(call1->hash(), call2->hash()); + + // Create two equivalent FieldAccessTypedExpr + auto field1 = std::make_shared(VARCHAR(), "name"); + auto field2 = std::make_shared(VARCHAR(), "name"); + EXPECT_EQ(field1->hash(), field2->hash()); + + // Create two equivalent CastTypedExpr (cast) + auto cast1 = std::make_shared( + VARCHAR(), + std::vector{ + std::make_shared(INTEGER(), "y")}, + false); + auto cast2 = std::make_shared( + VARCHAR(), + std::vector{ + std::make_shared(INTEGER(), "y")}, + false); + EXPECT_EQ(cast1->hash(), cast2->hash()); + + // Create two equivalent CastTypedExpr (try_cast) + auto tryCast1 = std::make_shared( + VARCHAR(), + std::vector{ + std::make_shared(INTEGER(), "z")}, + true); + auto tryCast2 = std::make_shared( + VARCHAR(), + std::vector{ + std::make_shared(INTEGER(), "z")}, + true); + EXPECT_EQ(tryCast1->hash(), tryCast2->hash()); +} + +// Test that hash survives serialization roundtrip +TEST_F(TypedExprHashConsistencyTest, serializationRoundtrip) { + auto testRoundtrip = [this](const TypedExprPtr& expr) { + auto originalHash = expr->hash(); + + // Serialize + auto serialized = expr->serialize(); + + // Deserialize + auto deserialized = + ISerializable::deserialize(serialized, pool_.get()); + + // Hash should be the same + EXPECT_EQ(originalHash, deserialized->hash()) + << "Hash changed after serialization roundtrip for: " + << expr->toString(); + }; + + // Test various expression types + testRoundtrip(std::make_shared(INTEGER())); + testRoundtrip(makeConstantExpr(INTEGER(), variant(int32_t(42)))); + testRoundtrip(makeConstantExpr(VARCHAR(), variant("test"))); + testRoundtrip(std::make_shared(BIGINT(), "col")); + testRoundtrip( + std::make_shared( + BIGINT(), + std::vector{ + std::make_shared(BIGINT(), "a")}, + "abs")); + testRoundtrip( + std::make_shared( + VARCHAR(), + std::vector{ + std::make_shared(INTEGER(), "x")}, + false)); +} + +// Test that different expressions produce different hashes +TEST_F(TypedExprHashConsistencyTest, distinctness) { + // Different constant values should have different hashes + auto int1 = makeConstantExpr(INTEGER(), variant(int32_t(1))); + auto int2 = makeConstantExpr(INTEGER(), variant(int32_t(2))); + EXPECT_NE(int1->hash(), int2->hash()); + + // Different types should have different hashes + auto intExpr = makeConstantExpr(INTEGER(), variant(int32_t(42))); + auto bigintExpr = makeConstantExpr(BIGINT(), variant(int64_t(42))); + EXPECT_NE(intExpr->hash(), bigintExpr->hash()); + + // Different field names should have different hashes + auto field1 = std::make_shared(INTEGER(), "a"); + auto field2 = std::make_shared(INTEGER(), "b"); + EXPECT_NE(field1->hash(), field2->hash()); + + // Different function names should have different hashes + auto call1 = std::make_shared( + BIGINT(), + std::vector{ + std::make_shared(BIGINT(), "x")}, + "abs"); + auto call2 = std::make_shared( + BIGINT(), + std::vector{ + std::make_shared(BIGINT(), "x")}, + "negate"); + EXPECT_NE(call1->hash(), call2->hash()); + + // cast vs try_cast should have different hashes + auto cast = std::make_shared( + VARCHAR(), + std::vector{ + std::make_shared(INTEGER(), "x")}, + false); + auto tryCast = std::make_shared( + VARCHAR(), + std::vector{ + std::make_shared(INTEGER(), "x")}, + true); + EXPECT_NE(cast->hash(), tryCast->hash()); +} + +// Test complex nested expressions +TEST_F(TypedExprHashConsistencyTest, nestedExpressions) { + // Create a complex nested expression: cast(plus(a, b) as varchar) + auto fieldA = std::make_shared(BIGINT(), "a"); + auto fieldB = std::make_shared(BIGINT(), "b"); + auto plusExpr = std::make_shared( + BIGINT(), std::vector{fieldA, fieldB}, "plus"); + auto castExpr = std::make_shared( + VARCHAR(), std::vector{plusExpr}, false); + + // Create an equivalent expression + auto fieldA2 = std::make_shared(BIGINT(), "a"); + auto fieldB2 = std::make_shared(BIGINT(), "b"); + auto plusExpr2 = std::make_shared( + BIGINT(), std::vector{fieldA2, fieldB2}, "plus"); + auto castExpr2 = std::make_shared( + VARCHAR(), std::vector{plusExpr2}, false); + + // Equivalent nested expressions should have same hash + EXPECT_EQ(castExpr->hash(), castExpr2->hash()); + + // Idempotency + EXPECT_EQ(castExpr->hash(), castExpr->hash()); +} + +// Test lambda expressions +TEST_F(TypedExprHashConsistencyTest, lambdaExpressions) { + auto signature = ROW({"x"}, {INTEGER()}); + auto body = std::make_shared(INTEGER(), "x"); + auto lambda1 = std::make_shared(signature, body); + auto lambda2 = std::make_shared(signature, body); + + EXPECT_EQ(lambda1->hash(), lambda2->hash()); + EXPECT_EQ(lambda1->hash(), lambda1->hash()); +} + +// Test concat expressions +TEST_F(TypedExprHashConsistencyTest, concatExpressions) { + auto expr1 = std::make_shared( + std::vector{"a", "b"}, + std::vector{ + std::make_shared(INTEGER(), "x"), + std::make_shared(VARCHAR(), "y")}); + auto expr2 = std::make_shared( + std::vector{"a", "b"}, + std::vector{ + std::make_shared(INTEGER(), "x"), + std::make_shared(VARCHAR(), "y")}); + + EXPECT_EQ(expr1->hash(), expr2->hash()); + EXPECT_EQ(expr1->hash(), expr1->hash()); +} + +} // namespace facebook::velox::core::test diff --git a/velox/core/tests/TypedExprSerdeTest.cpp b/velox/core/tests/TypedExprSerdeTest.cpp index 62c6ae93006..16e32f48529 100644 --- a/velox/core/tests/TypedExprSerdeTest.cpp +++ b/velox/core/tests/TypedExprSerdeTest.cpp @@ -67,6 +67,18 @@ TEST_F(TypedExprSerDeTest, fieldAccess) { 0); testSerde(expression); ASSERT_EQ(expression->toString(), "\"ab\"[a]"); + + expression = std::make_shared( + INTEGER(), + std::make_shared(ROW({INTEGER(), VARCHAR()}), "x"), + 0); + ASSERT_EQ(expression->toString(), "\"x\"[0]"); + + expression = std::make_shared( + VARCHAR(), + std::make_shared(ROW({INTEGER(), VARCHAR()}), "x"), + 1); + ASSERT_EQ(expression->toString(), "\"x\"[1]"); } TEST_F(TypedExprSerDeTest, constant) { @@ -144,4 +156,30 @@ TEST_F(TypedExprSerDeTest, lambda) { testSerde(expression); } +TEST_F(TypedExprSerDeTest, nullIf) { + TypedExprPtr expression = std::make_shared( + std::make_shared(INTERVAL_DAY_TIME(), "a"), + std::make_shared(INTERVAL_DAY_TIME(), "b"), + INTERVAL_DAY_TIME()); + testSerde(expression); + + expression = std::make_shared( + std::make_shared(DATE(), "a"), + std::make_shared(DATE(), "b"), + DATE()); + testSerde(expression); + + expression = std::make_shared( + std::make_shared(TIME(), "a"), + std::make_shared(TIME(), "b"), + TIME()); + testSerde(expression); + + expression = std::make_shared( + std::make_shared(TIME_MICRO_UTC(), "a"), + std::make_shared(TIME_MICRO_UTC(), "b"), + TIME_MICRO_UTC()); + testSerde(expression); +} + } // namespace facebook::velox::core::test diff --git a/velox/docs/conf.py b/velox/docs/conf.py index e6401fa5fd2..a87507e85da 100644 --- a/velox/docs/conf.py +++ b/velox/docs/conf.py @@ -51,6 +51,7 @@ "pr", "spark", "iceberg", + "delta", "sphinx.ext.autodoc", "sphinx.ext.doctest", "sphinx.ext.mathjax", diff --git a/velox/docs/configs.rst b/velox/docs/configs.rst index 21eebecf8df..a30f901d552 100644 --- a/velox/docs/configs.rst +++ b/velox/docs/configs.rst @@ -27,6 +27,12 @@ Generic Configuration - 10000 - Max number of rows that could be return by operators from Operator::getOutput. It is used when an estimate of average row size is known and preferred_output_batch_bytes is used to compute the number of output rows. + * - merge_join_output_batch_start_size + - integer + - 0 + - Initial output batch size in rows for MergeJoin operator. When non-zero, the batch size starts at this value + and is dynamically adjusted based on the average row size of previous output batches. When zero (default), + dynamic adjustment is disabled and the batch size is fixed at preferred_output_batch_rows. * - max_elements_size_in_repeat_and_sequence - integer - 10000 @@ -43,6 +49,17 @@ Generic Configuration - integer - 80 - Abandons partial TopNRowNumber if number of output rows equals or exceeds this percentage of the number of input rows. + * - abandon_dedup_hashmap_min_rows + - integer + - 100,000 + - Number of input rows to receive before starting to check whether to abandon building a HashTable without + duplicates in HashBuild for left semi/anti join. + * - abandon_dedup_hashmap_min_pct + - integer + - 0 + - Abandons building a HashTable without duplicates in HashBuild for left semi/anti join if the percentage of + distinct keys in the HashTable exceeds this threshold. Zero means 'disable this optimization'. + Does not apply to counting joins (kCountingAnti, kCountingLeftSemiFilter) which always require deduplication. * - session_timezone - string - @@ -64,7 +81,7 @@ Generic Configuration - bool - true - If true, the driver will collect the operator's input/output batch size through vector flat size estimation, otherwise not. - - We might turn this off in use cases which have very wide column width and batch size estimation has non-trivial cpu cost. + We might turn this off in use cases which have very wide column width and batch size estimation has non-trivial cpu cost. * - hash_adaptivity_enabled - bool - true @@ -73,6 +90,11 @@ Generic Configuration - bool - true - If true, the conjunction expression can reorder inputs based on the time taken to calculate them. + * - parallel_join_build_rows_enabled + - bool + - false + - If true, the hash probe drivers can output build\-side rows in parallel for full and right joins (only when spilling is not + enabled by hash probe). If false, only the last prober is allowed to output build\-side rows. * - max_local_exchange_buffer_size - integer - 32MB @@ -106,6 +128,12 @@ Generic Configuration client. Enforced approximately, not strictly. A larger size can increase network throughput for larger clusters and thus decrease query processing time at the expense of reducing the amount of memory available for other usage. + * - skip_request_data_size_with_single_source_enabled + - bool + - false + - If true, skip request data size if there is only single source. + This is used to optimize the Presto-on-Spark use case where each exchange client + has only one shuffle partition source. * - local_merge_source_queue_size - integer - 2 @@ -116,6 +144,10 @@ Generic Configuration - The maximum size in bytes for the task's buffered output when output is partitioned using hash of partitioning keys. See PartitionedOutputNode::Kind::kPartitioned. The producer Drivers are blocked when the buffered size exceeds this. The Drivers are resumed when the buffered size goes below OutputBufferManager::kContinuePct (90)% of this. + * - partitioned_output_eager_flush + - bool + - false + - If true, the PartitionedOutput operator will flush rows eagerly, without waiting until buffers reach certain size. Default is false. * - max_output_buffer_size - integer - 32MB @@ -126,6 +158,21 @@ Generic Configuration - integer - 1000 - The minimum number of table rows that can trigger the parallel hash join table build. + * - hash_probe_dynamic_filter_pushdown_enabled + - bool + - true + - Whether hash probe can generate any dynamic filter (including Bloom filter) and push down to upstream operators. + * - hash_probe_string_dynamic_filter_pushdown_enabled + - bool + - false + - Whether hash probe can generate dynamic filter for string types and push down to upstream operators. + * - hash_probe_bloom_filter_pushdown_max_size + - integer + - 0 + - The maximum byte size of Bloom filter that can be generated from hash + probe. When set to 0, no Bloom filter will be generated. To achieve + optimal performance, this should not be too larger than the CPU cache + size on the host. * - debug.validate_output_from_operators - bool - false @@ -148,6 +195,12 @@ Generic Configuration - 0 - If it is not zero, specifies the time limit that a driver can continuously run on a thread before yield. If it is zero, then it no limit. + * - window_num_sub_partitions + - integer + - 1 + - Window operator can be configured to sub-divide window partitions on each thread of execution into groups of + sub partitions for sequential processing. This setting specifies how many sub-partitions to create for each + thread. Use 1 to disable sub partitioning. * - prefixsort_normalized_key_max_bytes - integer - 128 @@ -200,7 +253,6 @@ Generic Configuration be expensive (especially if operator stats are retrieved frequently) and this allows the user to explicitly enable it. -.. _expression-evaluation-conf: Expression Evaluation Configuration ----------------------------------- @@ -216,11 +268,40 @@ Expression Evaluation Configuration - boolean - false - Whether to use the simplified expression evaluation path. + * - expression.eval_flat_no_nulls + - boolean + - true + - Whether to enable the FlatNoNulls fast path for expression evaluation. When enabled, expressions skip null + checking and vector decoding when all inputs are flat-encoded with no nulls. Set to false to disable this + optimization. * - expression.track_cpu_usage - boolean - false - Whether to track CPU usage for individual expressions (supported by call and cast expressions). Can be expensive when processing small batches, e.g. < 10K rows. + * - expression.track_cpu_usage_for_functions + - string + - "" + - Comma-separated list of function names to selectively track CPU usage for. Only applicable when + ``expression.track_cpu_usage`` is set to false. Function names are case-insensitive and will be normalized + to lowercase. This allows fine-grained control over CPU tracking overhead when only specific functions need to + be monitored. + * - expression.adaptive_cpu_sampling + - boolean + - false + - Enables adaptive per-function CPU usage sampling. Each function is calibrated over 6 batches (1 warmup + 5 + calibration) to measure the overhead of CPU tracking (clock_gettime) relative to the function's execution time. + The timer overhead is measured once per ExprSet and shared across all functions. Functions where tracking overhead + is acceptable are always tracked; functions where overhead exceeds ``expression.adaptive_cpu_sampling_max_overhead_pct`` + are sampled at a rate proportional to their overhead. Sampled timing stats are extrapolated to approximate + full-population values. + * - expression.adaptive_cpu_sampling_max_overhead_pct + - float + - 1.0 + - Maximum acceptable CPU tracking overhead percentage per function, used with ``expression.adaptive_cpu_sampling``. + Functions whose tracking overhead exceeds this threshold are sampled at a rate of + ceil(overhead_pct / max_overhead_pct). For example, with max_overhead=1.0, a function with 70% tracking overhead + is sampled every 70th batch, bounding its effective overhead to ~1%. Must be greater than 0. * - legacy_cast - bool - false @@ -335,6 +416,12 @@ Spilling - boolean - true - When `spill_enabled` is true, determines whether Window operator can spill to disk under memory pressure. + * - window_spill_min_read_batch_rows + - integer + - 1000 + - When processing spilled window data, read batches of whole partitions having at least that many rows. Set to 1 to + read one whole partition at a time. Each driver processing the Window operator will process that much data at + once. * - row_number_spill_enabled - boolean - true @@ -343,6 +430,10 @@ Spilling - boolean - true - When `spill_enabled` is true, determines whether TopNRowNumber operator can spill to disk under memory pressure. + * - mark_distinct_spill_enabled + - boolean + - false + - When `spill_enabled` is true, determines whether MarkDistinct operator can spill to disk under memory pressure. * - writer_spill_enabled - boolean - true @@ -428,6 +519,12 @@ Spilling - Specifies the compression algorithm type to compress the spilled data before write to disk to trade CPU for IO efficiency. The supported compression codecs are: zlib, snappy, lzo, zstd, lz4 and gzip. none means no compression. + * - spill_num_max_merge_files + - integer + - 0 + - The max number of files to merge at a time when merging sorted files into a single ordered stream. 0 means unlimited. + This is used to reduce memory pressure by capping the number of open files when merging spilled sorted files to + avoid using too much memory and causing OOM. Note that this is only applicable for ordered spill. * - spill_prefixsort_enabled - bool - false @@ -465,6 +562,28 @@ Aggregation - integer - 80 - Abandons partial aggregation if number of groups equals or exceeds this percentage of the number of input rows. + * - aggregation_compaction_bytes_threshold + - integer + - 0 + - Memory threshold in bytes for triggering string compaction during global + aggregation. When total string storage exceeds this limit with high unused + memory ratio, compaction is triggered to reclaim dead strings. Disabled by + default (0). Currently only applies to approx_most_frequent aggregate with + StringView type during global aggregation. + * - aggregation_compaction_unused_memory_ratio + - double + - 0.25 + - Ratio of unused (evicted) bytes to total bytes that triggers compaction. + The value is in the range of [0, 1). Currently only applies to approx_most_frequent + aggregate with StringView type during global aggregation. May be extended + to other aggregation types on-demand. + * - aggregation_memory_compaction_reclaim_enabled + - bool + - false + - If true, enables lightweight memory compaction before spilling during + memory reclaim in aggregation. When enabled, the aggregation operator + will try to compact aggregate function state (e.g., free dead strings) + before resorting to spilling. * - streaming_aggregation_min_output_batch_rows - integer - 0 @@ -503,6 +622,13 @@ Table Scan increasing the number of running scan threads, and stop once exceeds this ratio. The value is in the range of [0, 1]. This only applies if 'table_scan_scaled_processing_enabled' is true. + * - table_scan_output_batch_rows_override + - integer + - 0 + - If non-zero, overrides the number of rows in each output batch produced + by the TableScan operator, bypassing the dynamic batch size calculation. + This is useful for correctness testing where a fixed batch size is needed + to produce deterministic results. Zero means 'no override'. Table Writer ------------ @@ -548,11 +674,50 @@ Table Writer - Minimum amount of data processed by all the logical table partitions to trigger skewed partition rebalancing by scale writer exchange. +Connector Config +---------------- +Connector config is initialized on velox runtime startup and is shared among queries as the default config across all connectors. +Each query can override the config by setting corresponding query session properties such as in Prestissimo. + +.. list-table:: + :widths: 20 20 10 10 70 + :header-rows: 1 + + * - Configuration Property Name + - Session Property Name + - Type + - Default Value + - Description + * - user + - + - string + - "" + - The user of the query. Used for storage logging. + * - source + - + - string + - "" + - The source of the query. Used for storage access and logging. + * - schema + - + - string + - "" + - The schema of the query. Used for storage logging. + Hive Connector -------------- Hive Connector config is initialized on velox runtime startup and is shared among queries as the default config. Each query can override the config by setting corresponding query session properties such as in Prestissimo. +Configuration property names use kebab-case (e.g., ``max-bucket-count``). +Session property names use snake_case (e.g., ``max_bucket_count``). +Properties without a session property name in the table below are fixed for the lifetime of the process +and cannot be modified per query. + +Properties of type ``capacity`` accept human-readable size strings such as ``512kB``, ``128MB``, ``1GB``, etc. +Properties of type ``integer`` with byte-valued defaults (shown as ``256KB``, ``8MB``, etc. for readability) +must be specified as raw byte counts. + .. list-table:: :widths: 20 20 10 10 70 :header-rows: 1 @@ -562,13 +727,13 @@ Each query can override the config by setting corresponding query session proper - Type - Default Value - Description - * - hive.max-partitions-per-writers - - + * - max-partitions-per-writers + - max_partitions_per_writers - integer - - 100 + - 128 - Maximum number of (bucketed) partitions per a single table writer instance. - * - hive.max-bucket-count - - hive.max_bucket_count + * - max-bucket-count + - max_bucket_count - integer - 100000 - Maximum number of buckets that a table writer is allowed to write to. @@ -580,7 +745,7 @@ Each query can override the config by setting corresponding query session proper the update mode field of the table writer operator output. ``OVERWRITE`` sets the update mode to indicate overwriting a partition if exists. ``ERROR`` sets the update mode to indicate error throwing if writing to an existing partition. - * - hive.immutable-partitions + * - immutable-partitions - - bool - false @@ -617,7 +782,7 @@ Each query can override the config by setting corresponding query session proper - Maximum size in bytes to coalesce requests to be fetched in a single request. * - max-coalesced-distance - - - integer + - capacity - 512KB - Maximum distance in capacity units between chunks to be fetched that may be coalesced into a single request. * - load-quantum @@ -644,21 +809,23 @@ Each query can override the config by setting corresponding query session proper - Maximum number of rows for sort writer in one batch of output. This is to limit the memory usage of sort writer. * - sort-writer-max-output-bytes - sort_writer_max_output_bytes - - string + - capacity - 10MB - Maximum bytes for sort writer in one batch of output. This is to limit the memory usage of sort writer. + * - max-target-file-size + - max_target_file_size + - capacity + - 0B + - Maximum target file size for writers. When a file exceeds this size during writing, the writer + closes the current file and starts writing to a new file. Accepts human-readable values like + "1GB". Zero means no limit (default). File rotation is not supported for bucketed tables or + sorted writes. * - file-preload-threshold - - integer - 8MB - Usually Velox fetches the meta data firstly then fetch the rest of file. But if the file is very small, Velox can fetch the whole file directly to avoid multiple IO requests. The parameter controls the threshold when whole file is fetched. - * - footer-estimated-size - - - - integer - - 1MB - - Define the estimation of footer size in ORC and Parquet format. The footer data includes version, schema, and meta data for every columns which may or may not need to be fetched later. - The parameter controls the size when footer is fetched each time. Bigger value can decrease the IO requests but may fetch more useless meta data. * - cache.no_retention - cache.no_retention - bool @@ -667,25 +834,74 @@ Each query can override the config by setting corresponding query session proper and also skip staging to the ssd cache. This helps to prevent the cache space pollution from the one-time table scan by large batch query when mixed running with interactive query which has high data locality. - * - hive.reader.stats_based_filter_reorder_disabaled - - hive.reader.stats_based_filter_reorder_disabaled + * - stats-based-filter-reorder-disabled + - stats_based_filter_reorder_disabled - bool - false - If true, disable the stats based filter reordering during the read processing, and the filter execution order is totally determined by the filter type. Otherwise, the file reader will dynamically adjust the filter execution order based on the past filter execution stats. - * - hive.reader.timestamp-partition-value-as-local-time - - hive.reader.timestamp_partition_value_as_local_time + * - reader.timestamp-partition-value-as-local-time + - reader.timestamp_partition_value_as_local_time - bool - true - Reads timestamp partition value as local time if true. Otherwise, reads as UTC. - * - hive.preserve-flat-maps-in-memory - - hive.preserve_flat_maps_in_memory + * - preserve-flat-maps-in-memory + - preserve_flat_maps_in_memory - bool - false - Whether to preserve flat maps in memory as FlatMapVectors instead of converting them to MapVectors. This is only applied during data reading inside the DWRF and Nimble readers, not during downstream processing like expression evaluation etc. - + * - max-rows-per-index-request + - max_rows_per_index_request + - integer + - 0 + - Maximum number of output rows to return per index lookup request. The limit is applied to the actual output rows + after filtering. 0 means no limit (default). + * - file-metadata-cache-enabled + - file_metadata_cache_enabled + - bool + - false + - Whether to cache file metadata (footer, stripes, index) in the process-wide AsyncDataCache. When enabled, + the first reader performs a speculative tail read and populates the cache; subsequent readers on the same file + serve metadata from cache with zero file IO. Currently only supported by Nimble format. + * - pin-file-metadata + - pin_file_metadata + - bool + - false + - Whether to pin parsed metadata objects (e.g., StripeGroup, IndexGroup) in the reader's metadata cache with + strong references so they are never evicted. This avoids re-reading and re-parsing metadata on every stripe + access when weak-pointer cache entries would otherwise expire. Can be used independently of + file-metadata-cache-enabled. Currently only supported by Nimble format. + * - reader.collect-column-cpu-metrics + - reader.collect_column_cpu_metrics + - bool + - false + - If true, enables collection of per-column timing statistics during file reading. This includes + decompression and decode CPU time metrics for each column, reported as runtime metrics in the format + ``column_..decompressCPUTimeNanos`` and ``column_..decodeCPUTimeNanos``. + Useful for performance analysis and identifying slow columns. + * - orc.footer-speculative-io-size + - orc_footer_speculative_io_size + - integer + - 256KB + - Speculative tail-read size in bytes when opening ORC files. Controls how many bytes are read from the end + of the file to load the footer and nearby metadata in a single IO operation. + Set to 0 for adaptive mode. + * - parquet.footer-speculative-io-size + - parquet_footer_speculative_io_size + - integer + - 256KB + - Speculative tail-read size in bytes when opening Parquet files. Controls how many bytes are read from the end + of the file to load the footer and nearby metadata in a single IO operation. + Set to 0 for adaptive mode. + * - nimble.footer-speculative-io-size + - nimble_footer_speculative_io_size + - integer + - 8MB + - Speculative tail-read size in bytes when opening Nimble files. Controls how many bytes are read from the end + of the file to load the footer and nearby metadata in a single IO operation. + Set to 0 for adaptive mode. ``ORC File Format Configuration`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -778,7 +994,7 @@ Each query can override the config by setting corresponding query session proper - 1024 - Batch size used when writing into Parquet through Arrow bridge. * - hive.parquet.writer.created-by - - + - hive.parquet.writer.created_by - string - parquet-cpp-velox version 0.0.0 - Created-by value used when writing to Parquet. @@ -879,6 +1095,19 @@ Each query can override the config by setting corresponding query session proper - - A custom credential provider, if specified, will be used to create the client in favor of other authentication mechanisms. The provider must be registered using "registerAWSCredentialsProvider" before it can be used. + * - hive.s3.aws-imds-enabled + - bool + - true + - AWS Instance Metadata Service (IMDS) is an AWS EC2 instance component used by applications to securely access metadata. + We must disable it on other instances to avoid high first-time read latency from S3 compatible object storages. + * - hive.s3.min-part-size + - string + - 10MB + - Minimum multi-part upload part size. The smallest allowed value is 5MB. The largest allowed value is 5GB. + If a file is less than this size, the file is sent as a single put request. + Otherwise, the file is split into multiple equal sized chunks of this part size excluding the last chunk. + The `AWS specification `_ limits the part size between 5MB and 5GB. + Some S3 backend providers enforce these limits strictly. Bucket Level Configuration """""""""""""""""""""""""" @@ -898,23 +1127,23 @@ These semantics are similar to the `Apache Hadoop-Aws module /oauth2/token`. + * - fs.azure.sas.token.renew.period.for.streams + - string + - 120 + - Specifies the period in seconds to re-use SAS tokens until the expiry is within this number of seconds. + This configuration is used together with `registerSasTokenProvider` for dynamic SAS token renewal. + When a SAS token is close to expiry, it will be renewed by getting a new token from the provider. Presto-specific Configuration ----------------------------- @@ -1012,15 +1247,18 @@ Spark-specific Configuration - integer - 1000000 - The default number of expected items for the bloom filter in :spark:func:`bloom_filter_agg` function. + * - spark.bloom_filter.max_num_items + - integer + - 4000000 + - The maximum number of items for the bloom filter in :spark:func:`bloom_filter_agg` function. * - spark.bloom_filter.num_bits - integer - 8388608 - The default number of bits to use for the bloom filter in :spark:func:`bloom_filter_agg` function. * - spark.bloom_filter.max_num_bits - integer - - 4194304 - - The maximum number of bits to use for the bloom filter in :spark:func:`bloom_filter_agg` function, - the value of this config can not exceed the default value. + - 67108864 + - The maximum number of bits to use for the bloom filter in :spark:func:`bloom_filter_agg` function. * - spark.partition_id - integer - @@ -1078,3 +1316,145 @@ Tracing - false - If true, we only collect the input trace for a given operator but without the actual execution. This is used for crash debugging. + +Cudf-specific Configuration (Experimental) +------------------------------------------ +These configurations are available when `compiled with cuDF `_. +Note: These configurations are experimental and subject to change. + +.. list-table:: + :widths: 30 10 10 70 + :header-rows: 1 + + * - Property Name + - Type + - Default Value + - Description + * - cudf.enabled + - bool + - true + - If true, enable cuDF. By default, it is enabled if compiled with cuDF. + * - cudf.memory_resource + - string + - async + - The memory resource to use for cuDF. Possible values are (cuda, pool, async, arena, managed, managed_pool, prefetch_managed, prefetch_managed_pool). + The prefetch options enable automatic prefetching for better GPU memory performance: prefetch_managed uses CUDA unified memory with prefetching, + prefetch_managed_pool uses a pooled version of CUDA unified memory with prefetching. + * - cudf.memory_percent + - integer + - 50 + - The initial percent of GPU memory to allocate for pool or arena memory resources. + * - cudf.function_name_prefix + - string + - "" + - The prefix to use for the function names in cuDF. + * - cudf.ast_expression_enabled + - bool + - true + - If true, enable using cuDF AST-based expression evaluation when supported. + * - cudf.ast_expression_priority + - integer + - 100 + - Priority of cuDF AST expressions. Higher value wins when multiple cuDF execution options are available for the same Velox expression. Standalone cuDF functions have priority 50. If enabled, with a default priority of 100, AST will be chosen as replacement for cudf execution. + * - cudf.allow_cpu_fallback + - bool + - true + - If true, allow falling back to Velox CPU execution when an operation is not supported in cuDF execution. If false, an error will be thrown if an operation is not supported in cuDF execution. + * - cudf.debug_enabled + - bool + - false + - If true, enable debug printing. + * - cudf.log_fallback + - bool + - true + - If true, log a reason for falling back to Velox CPU execution, when an operation is not supported in cuDF execution. + * - cudf.function_engine + - string + - presto + * - cudf.timestamp_unit + - string + - ns + - Timestamp precision unit for cuDF timestamp types. Valid values are: "s" (seconds), "ms" (milliseconds), "us" (microseconds), "ns" (nanoseconds). This controls the precision of timestamp data when converting between Velox and cuDF formats. + - Register the function for a specific engine. The optional values are presto or spark. + +Cudf Hive Connector Configuration (Experimental) +------------------------------------------------ +These configurations apply to the cuDF Hive connector (Parquet reader/writer via cuDF). +Connector config is initialized on velox runtime startup and is shared among queries as the default config. +Each query can override the config by setting the corresponding session property. +Reader options map to `libcudf parquet_reader_options `_ (used with the chunked reader when applicable). + +.. list-table:: + :widths: 20 20 10 10 70 + :header-rows: 1 + + * - Configuration Property Name + - Session Property Name + - Type + - Default Value + - Description + * - cudf.hive.use-buffered-input + - cudf.hive.use_buffered_input + - bool + - true + - Whether to use BufferedInput for CudfHiveDataSource (can use AsyncDataCache when HiveConfig file handle cache is enabled). + * - cudf.hive.use-experimental-reader + - cudf.hive.use_experimental_reader + - bool + - false + - Whether to use the experimental cuDF Parquet reader (Hybrid Scan) for highly selective filters. When enabled, uses `libcudf hybrid_scan_reader `_. + * - parquet.reader.use-pandas-metadata + - parquet.reader.use_pandas_metadata + - bool + - true + - Enable or disable use of pandas metadata while reading. Maps to ``enable_use_pandas_metadata`` in `libcudf parquet_reader_options `_. + * - parquet.reader.use-arrow-schema + - parquet.reader.use_arrow_schema + - bool + - true + - Enable or disable use of Arrow schema while reading. Maps to ``enable_use_arrow_schema`` in `libcudf parquet_reader_options `_. + * - parquet.reader.allow-mismatched-parquet-schemas + - parquet.reader.allow_mismatched_parquet_schemas + - bool + - false + - Enable or disable reading matching projected and filter columns from mismatched Parquet sources. Maps to ``enable_allow_mismatched_pq_schemas`` in `libcudf parquet_reader_options `_. + * - parquet.reader.timestamp-type + - parquet.reader.timestamp_type + - string + - TIMESTAMP_MILLISECONDS + - Timestamp type used to cast all timestamp columns (e.g. TIMESTAMP_DAYS, TIMESTAMP_SECONDS, TIMESTAMP_MILLISECONDS, TIMESTAMP_MICROSECONDS, TIMESTAMP_NANOSECONDS). Maps to ``set_timestamp_type`` in `libcudf parquet_reader_options `_. + * - parquet.reader.chunk-read-limit + - parquet.reader.chunk_read_limit + - integer + - 0 + - Limit on total number of bytes to be returned per read (per table chunk); 0 means no limit. Maps to ``chunk_read_limit`` in `libcudf hybrid_scan_reader `_ (e.g. ``setup_chunking_for_filter_columns``, ``setup_chunking_for_payload_columns``). + * - parquet.reader.pass-read-limit + - parquet.reader.pass_read_limit + - integer + - 0 + - Limit on the amount of memory (bytes) used for reading and decompressing data; 0 means no limit. This is a hint, not an absolute limit—if a single row group cannot fit within the limit, it will still be loaded. Affects how many row groups can be read at a time by limiting decompression space. Maps to ``pass_read_limit`` in `libcudf hybrid_scan_reader `_ (e.g. ``setup_chunking_for_filter_columns``, ``setup_chunking_for_payload_columns``). + * - parquet.reader.convert-strings-to-categories + - parquet.reader.convert_strings_to_categories + - bool + - false + - Whether to store string data as categorical type. + * - parquet.writer.write-timestamps-as-utc + - parquet.writer.write_timestamps_as_utc + - bool + - true + - Whether to write timestamps as UTC. + * - sort-writer_finish_time_slice_limit_ms + - sort_writer_finish_time_slice_limit_ms + - integer + - 5000 + - Sort writer exits finish() after this many milliseconds even if work is not complete; 0 means no time limit. + * - parquet.writer.write-arrow-schema + - parquet.writer.write_arrow_schema + - bool + - false + - Whether to write ARROW schema. + * - parquet.writer.write-v2-page-headers + - parquet.writer.write_v2_page_headers + - bool + - false + - Whether to write V2 page headers. diff --git a/velox/docs/designs/column-extraction-pushdown.md b/velox/docs/designs/column-extraction-pushdown.md new file mode 100644 index 00000000000..ba5dd4e1c24 --- /dev/null +++ b/velox/docs/designs/column-extraction-pushdown.md @@ -0,0 +1,767 @@ +# Column Extraction Pushdown Proposal + +*2026-03-04* + +## Motivation + +Queries like `SELECT map_keys(col) FROM t`, `SELECT cardinality(col) FROM t`, +or `SELECT map_values(col).x FROM t` currently read the entire complex-typed +column from storage, even though only a subset of its physical streams is +needed. For a `MAP(K, V)` column, `map_keys` only needs keys (not values), and +`cardinality` only needs the lengths stream (neither keys nor values). + +This waste is significant for wide map/struct values — the reader +materializes data that the engine immediately discards. + +**Column extraction pushdown** is a new pushdown mechanism, separate from +subfield pruning, that tells the reader to extract a specific component from a +complex type and produce it as a **different output type**. For example, +extracting map keys produces an `ARRAY(K)` instead of a `MAP(K, V)`. + +### Distinction from subfield pruning + +| | Subfield pruning | Column extraction | +|----------------|--------------------------------------|----------------------------------------------------| +| Output type | Same as input (MAP stays MAP) | Different (MAP -> ARRAY, MAP/ARRAY -> BIGINT) | +| Semantics | Drop unused parts, null-fill | Transform the type structure | +| Mechanism | `requiredSubfields` on column handle | `extractions` (NamedExtraction list) on column handle | +| Composability | N/A | N/A — mutually exclusive with subfield pruning | + +### Why not extend subfield pruning? + +Consider this query: + +```sql +SELECT cardinality(col) AS a, col['foo'] AS b FROM t +``` + +where `col` is `MAP(VARCHAR, BIGINT)`. + +Both `a` and `b` reference the same column `col`, but they need fundamentally +different things: + +- `a` needs only the size — no keys, no values. +- `b` needs the key `"foo"` and its corresponding value. + +With subfield pruning, a single column handle represents `col`. The required +subfields from both references are merged: + +``` +requiredSubfields: ["col[$]", "col[\"foo\"]"] +``` + +These are contradictory: `[$]` says "skip keys and values" while `["foo"]` says +"read key `foo` and its value." The reader cannot satisfy both with one pass — +it must read the whole column in order to return maps with correct size, and +if it only reads key `foo` the result would be incorrect. + +Column extraction solves this with a single column handle that carries +**multiple named extraction chains**. The column is read once, and each +extraction is applied independently: + +``` +assignments: { + "col": { + name: "col", + hiveType: MAP(VARCHAR, BIGINT), + dataType: ROW({ "a": BIGINT, "b": ARRAY(BIGINT) }), + extractions: [ + { outputName: "a", chain: [Size], dataType: BIGINT }, + { outputName: "b", chain: [MapKeyFilter(["foo"]), MapValues], dataType: ARRAY(BIGINT) } + ], + } +} +``` + +The reader reads `col` once — it reads sizes without restriction (`a`), and read +keys and values with a filter on the key (since `b` needs only key `"foo"`), and +further apply remaining of the chain in `b` in the column reader of map values. + +This is the core reason column extraction is a new mechanism rather than an +extension of subfield pruning. Subfield pruning merges all references to a +column into one output, which forces the reader to satisfy the union of all +requirements. Column extraction keeps a single column handle but produces +multiple named outputs, each with its own extraction chain. + +## Protocol: Extraction Chain + +The extraction is expressed as an ordered chain of steps. Each step operates on +one nesting level of the source type. The chain is a new field on the column +handle, mutually exclusive with `requiredSubfields`. + +### Extraction steps + +``` +ExtractionStep = + | StructField(name: string) // Navigate into a struct field. + | MapKeys // Extract map keys. MAP(K, V) → ARRAY(K). + | MapValues // Extract map values. MAP(K, V) → ARRAY(V). + | MapKeyFilter(keys: list) // Filter map to specific keys. + | // MAP(K, V) → MAP(K, V). Type-preserving. + | // Keys are strings or integers depending on + | // the map's key type. + | ArrayElements // Navigate into array elements. + | Size // Extract size. MAP/ARRAY → BIGINT. Terminal. +``` + +### Step input/output types (validation) + +Each step's output feeds as input to the next step. This forms a linear +pipeline for **type validation**: + +| Step | Required input | Output | +|--------------------|-----------------------------|--------------| +| `StructField(f)` | `ROW(..., f: T, ...)` | `T` | +| `MapKeys` | `MAP(K, V)` | `ARRAY(K)` | +| `MapValues` | `MAP(K, V)` | `ARRAY(V)` | +| `MapKeyFilter(ks)` | `MAP(K, V)` | `MAP(K, V)` | +| `ArrayElements` | `ARRAY(T)` | `T` | +| `Size` | `MAP(K, V)` or `ARRAY(T)` | `BIGINT` | + +**Rule:** `MapKeys` and `MapValues` produce `ARRAY(...)`. Any subsequent step +sees `ARRAY` as input. Therefore, `MapKeys`/`MapValues` **must** be followed +by `ArrayElements` unless it is the last step in the chain. `MapKeyFilter` is +type-preserving (`MAP` → `MAP`) and does NOT require `ArrayElements` after it. + +### Output type derivation + +The validation pipeline loses nesting information (each `ArrayElements` +unwraps one `ARRAY` layer). The actual output type is derived recursively: + +``` +derive(T, []) = T + +derive(ROW(.., f:T, ..), [StructField(f), ...rest]) = derive(T, rest) + +derive(MAP(K, V), [MapKeys, ArrayElements, ...rest]) = ARRAY(derive(K, rest)) +derive(MAP(K, V), [MapKeys]) = ARRAY(K) + +derive(MAP(K, V), [MapValues, ArrayElements, ...rest]) = ARRAY(derive(V, rest)) +derive(MAP(K, V), [MapValues]) = ARRAY(V) + +derive(MAP(K, V), [MapKeyFilter, ...rest]) = derive(MAP(K, V), rest) + +derive(ARRAY(T), [ArrayElements, ...rest]) = ARRAY(derive(T, rest)) + +derive(MAP|ARRAY, [Size]) = BIGINT +``` + +`MapKeys`/`MapValues` + `ArrayElements` are consumed as a pair — the +extraction wraps in `ARRAY`, and `ArrayElements` enters it. `MapKeyFilter` is +type-preserving and does not consume an `ArrayElements`. A standalone +`ArrayElements` (without a preceding extraction step) handles a source `ARRAY` +level. + +### `ArrayElements` roles + +| Context | Role | +|------------------------------|-----------------------------------------------------------------------------------------------| +| After `MapKeys`/`MapValues` | Consumes the `ARRAY` produced by extraction. **Mandatory** to continue the chain. | +| On a source `ARRAY` type | Navigates into the array's elements. **Mandatory** when the source type has an `ARRAY` level.| + +Every `ARRAY` boundary in the chain — whether from extraction or from the +source type — is explicitly represented by an `ArrayElements` step. + +### Mutual exclusivity with subfield pruning + +Column extraction (`extractions`) and subfield pruning (`requiredSubfields`) +**cannot coexist** on the same column handle. A column handle uses either +`requiredSubfields` (type-preserving pruning) or `extractions` (extraction +chains), but not both. + +This simplifies the reader contract — the reader either prunes subfields +within the existing type structure, or applies extraction chains to produce +new types. There is no ambiguity about which operation runs first or how +their type changes interact. + +**When to use which:** + +- Use `requiredSubfields` when the column's output type stays the same as the + file type (struct field pruning, map key filtering, array index truncation). +- Use `extractions` when the column's output type changes (MapKeys, MapValues, + Size, StructField extraction, or any chain that transforms the type). + +**Expressing subfield-like operations in extraction chains:** + +Operations that would have used `requiredSubfields` can be expressed as +extraction steps instead: + +- Struct field access → `StructField(name)` in the chain +- Map key filtering → `MapKeyFilter(keys)` in the chain +- Multiple struct fields → multiple `NamedExtraction` entries, each with a + `StructField` chain + +For example, instead of `requiredSubfields: ["col.x", "col.y"]`, use: +``` +extractions: [ + { outputName: "x", chain: [StructField("x")], dataType: INT }, + { outputName: "y", chain: [StructField("y")], dataType: INT } +] +``` + +### Overlap between `MapKeyFilter` and map key pruning + +`MapKeyFilter` and subfield pruning's map key subscripts (`["key1"]`, `["key2"]`) +both filter a map to specific keys. They overlap in functionality but differ in +where they sit: + +- **Subfield pruning** (`requiredSubfields: ["col[\"foo\"]"]`): filters keys on + the source column, output stays `MAP(K, V)`. No extraction chain involved. +- **`MapKeyFilter`**: filters keys as a step in the extraction chain, output + stays `MAP(K, V)`. Can compose with other extraction steps. + +Since extraction and subfield pruning are mutually exclusive on the same column +handle, use one or the other: + +- Use `requiredSubfields` when key filtering is the only operation needed and + no type transformation is involved. +- Use `MapKeyFilter` in an extraction chain when key filtering combines with + other extraction steps (e.g., `[MapKeyFilter(["foo"]), MapValues]`). + +### Overlap between `StructField` and struct subfield pruning + +`StructField` and subfield pruning's nested field paths (`.field1`, `.field2`) +both navigate into struct fields. They overlap when accessing struct children: + +- **Subfield pruning** (`requiredSubfields: ["col.x"]`): prunes the struct to + only field `x`, output stays `ROW(...)` with other fields null-filled. +- **`StructField("x")`**: extracts field `x` as a step in the extraction chain, + output becomes the field's type directly (e.g., `INT`). + +Since extraction and subfield pruning are mutually exclusive: + +- Use `requiredSubfields` for struct pruning when no type transformation is + needed (output stays ROW). +- Use `StructField` in an extraction chain when struct navigation combines + with other extraction steps (e.g., `[StructField("a"), MapKeys]`). + +## Examples + +### `map_keys(col)` — `col: MAP(K, V)` + +``` +Chain: [MapKeys] +Validation: MAP(K,V) → MapKeys → ARRAY(K) ✓ +Output: ARRAY(K) +``` + +### `map_keys(col.a.b)` — `col: ROW(a: ROW(b: MAP(K, V)))` + +``` +Chain: [StructField("a"), StructField("b"), MapKeys] +Validation: ROW → StructField("a") → ROW → StructField("b") → MAP(K,V) → MapKeys → ARRAY(K) ✓ +Output: ARRAY(K) +``` + +### `cardinality(col)` — `col: MAP(K, V)` or `ARRAY(T)` + +``` +Chain: [Size] +Validation: MAP(K,V) → Size → BIGINT ✓ +Output: BIGINT +``` + +### `col.x` — `col: ROW(x: INT, y: INT)` + +Struct field access with no other extraction steps. Prefer subfield pruning +(`requiredSubfields: ["col.x"]`) which keeps the `ROW` type and uses existing +infrastructure. `StructField` is only needed when it is part of a larger +extraction chain (see `map_keys(col.a.b)` and `cardinality(col.features)`). + +``` +-- Preferred: subfield pruning +extraction: [] +requiredSubfields: ["col.x"] + +-- Also valid but unnecessary: extraction +Chain: [StructField("x")] +Output: INT +``` + +### `cardinality(col.features)` — `col: ROW(features: ARRAY(FLOAT), label: INT)` + +Size extraction on a nested field. Navigate into the struct, then extract +size. + +``` +Chain: [StructField("features"), Size] +Validation: ROW → StructField("features") → ARRAY(FLOAT) → Size → BIGINT ✓ +Output: BIGINT +``` + +### `map_keys(map_values(col))` — `col: MAP(K1, MAP(K2, V))` + +This represents `transform(map_values(col), x -> map_keys(x))` — for each +value (which is `MAP(K2, V)`), extract its keys. + +``` +Chain: [MapValues, ArrayElements, MapKeys] +Validation: MAP(K1, MAP(K2,V)) → MapValues → ARRAY(MAP(K2,V)) + → ArrayElements → MAP(K2,V) + → MapKeys → ARRAY(K2) ✓ +Output: derive(MAP(K1, MAP(K2,V)), [MV, AE, MK]) + = ARRAY(derive(MAP(K2,V), [MK])) + = ARRAY(ARRAY(K2)) +``` + +### `map_keys(array_elements(map_values(col)))` — `col: MAP(K1, ARRAY(MAP(K2, V)))` + +``` +Chain: [MapValues, ArrayElements, ArrayElements, MapKeys] +Validation: MAP(K1, ARRAY(MAP(K2,V))) → MapValues → ARRAY(ARRAY(MAP(K2,V))) + → ArrayElements → ARRAY(MAP(K2,V)) + → ArrayElements → MAP(K2,V) + → MapKeys → ARRAY(K2) ✓ +Output: derive(MAP(K1, ARRAY(MAP(K2,V))), [MV, AE, AE, MK]) + = ARRAY(derive(ARRAY(MAP(K2,V)), [AE, MK])) + = ARRAY(ARRAY(derive(MAP(K2,V), [MK]))) + = ARRAY(ARRAY(ARRAY(K2))) +``` + +Two `ArrayElements` — the first consumes the `ARRAY` from `MapValues`, the +second navigates the source `ARRAY` level. + +### `map_values(col).x` — `col: MAP(K, ROW(x: INT, y: INT))` + +Extraction is used, so `requiredSubfields` cannot be set. Use `StructField` +in the chain to extract the specific field. + +``` +Chain: [MapValues, ArrayElements, StructField("x")] +Validation: MAP(K, ROW(x,y)) → MapValues → ARRAY(ROW(x,y)) + → ArrayElements → ROW(x,y) + → StructField → INT ✓ +Output: ARRAY(INT) +``` + +### `map_subset(col, ARRAY['a', 'b'])` — `col: MAP(VARCHAR, BIGINT)` + +Map key filtering with no other extraction steps. Prefer subfield pruning +(`requiredSubfields: ["col[\"a\"]", "col[\"b\"]"]`) which uses existing +infrastructure. `MapKeyFilter` is only needed when it is part of a larger +extraction chain (see the nested key filter example below). + +``` +-- Preferred: subfield pruning +extraction: [] +requiredSubfields: ["col[\"a\"]", "col[\"b\"]"] + +-- Also valid but unnecessary: extraction +Chain: [MapKeyFilter(["a", "b"])] +Output: MAP(VARCHAR, BIGINT) +``` + +### `element_at(col, 'foo').x` — `col: MAP(VARCHAR, ROW(x: INT, y: INT))` + +Single-key filter with single struct field access. Since `MapKeyFilter` is +the only extraction step, prefer subfield pruning: + +``` +-- Preferred: subfield pruning (no extraction chain) +extraction: [] +requiredSubfields: ["col[\"foo\"].x"] + +-- Also valid: extraction chain avoids materializing the full ROW, which can +-- be beneficial when the struct has many fields. +Chain: [MapKeyFilter(["foo"]), MapValues, ArrayElements, StructField("x")] +Output: ARRAY(INT) +``` + +### Non-pushable: `map_keys(map_filter(col, (k, v) -> v > 10))` — `col: MAP(VARCHAR, BIGINT)` + +The `map_filter` predicate depends on values (`v > 10`). If we pushed +`MapKeys` extraction to the reader, the reader would skip values, making it +impossible to evaluate the filter. **Extraction cannot be pushed** through +intermediate expressions that depend on skipped data. The entire map must be +read, `map_filter` applied in the engine, and then `map_keys` applied to the +filtered result. + +### Nested key filter — `col: MAP(K1, MAP(VARCHAR, ROW(x: INT, y: INT)))` + +SQL: `transform(map_values(col), m -> element_at(m, 'foo').x)` + +Extract values from outer map, filter inner map to key `"foo"`, extract +subfield `x`. Extraction is used, so `requiredSubfields` cannot be set. +Use `StructField` in the chain. + +``` +Chain: [MapValues, ArrayElements, MapKeyFilter(["foo"]), MapValues, ArrayElements, StructField("x")] +Validation: MAP(K1, MAP(VARCHAR, ROW(x,y))) → MapValues → ARRAY(MAP(VARCHAR, ROW(x,y))) + → ArrayElements → MAP(VARCHAR, ROW(x,y)) + → MapKeyFilter → MAP(VARCHAR, ROW(x,y)) + → MapValues → ARRAY(ROW(x,y)) + → ArrayElements → ROW(x,y) + → StructField → INT ✓ +Output: ARRAY(ARRAY(INT)) +``` + +### Error: missing `ArrayElements` after `MapValues` + +``` +Chain: [MapValues, MapKeys] on MAP(K1, ARRAY(MAP(K2, V))) +Validation: MAP(K1, ARRAY(MAP(K2,V))) → MapValues → ARRAY(ARRAY(MAP(K2,V))) + → MapKeys → ERROR: expects MAP, got ARRAY +``` + +### Error: missing `ArrayElements` after `MapKeys` + +``` +Chain: [MapKeys, StructField("x")] on MAP(ROW(x: INT, y: INT), V) +Validation: MAP(ROW(x,y), V) → MapKeys → ARRAY(ROW(x,y)) + → StructField("x") → ERROR: expects ROW, got ARRAY +``` + +## Column Handle Protocol + +### `HiveColumnHandle` (C++) + +```cpp +/// Type of extraction to apply at one nesting level. +enum class ExtractionStep : uint8_t { + /// Navigate into a struct field. Input must be ROW. + kStructField, + /// Extract map keys as ARRAY. Input must be MAP. + kMapKeys, + /// Extract map values as ARRAY. Input must be MAP. + kMapValues, + /// Filter map to specific keys. Input must be MAP. Type-preserving. + kMapKeyFilter, + /// Navigate into array elements. Input must be ARRAY. + kArrayElements, + /// Extract size as BIGINT. Input must be MAP or ARRAY. Terminal. + kSize, +}; + +/// Base class for one step in the extraction chain. Subclasses carry +/// step-specific data (field name, filter keys). +class ExtractionPathElement { + public: + virtual ~ExtractionPathElement() = default; + virtual ExtractionStep step() const = 0; + + /// Factory methods. + static std::shared_ptr simple(ExtractionStep); + static std::shared_ptr structField( + const std::string& name); + static std::shared_ptr mapKeyFilter( + std::vector keys); + static std::shared_ptr mapKeyFilter( + std::vector keys); +}; + +using ExtractionPathElementPtr = std::shared_ptr; + +/// Concrete subclasses: +/// SimpleExtractionPathElement — MapKeys, MapValues, ArrayElements, Size +/// StructFieldExtractionPathElement — kStructField; carries fieldName() +/// MapKeyFilterExtractionPathElement — kMapKeyFilter; carries +/// stringFilterKeys() / intFilterKeys() + +/// Named extraction chain producing one output column. +struct NamedExtraction { + /// Output column name in the scan's outputType. + std::string outputName; + + /// Extraction chain to apply. Empty means pass-through (no extraction). + std::vector chain; + + /// Output type after applying the chain. + TypePtr dataType; +}; + +class HiveColumnHandle : public connector::ColumnHandle { + public: + HiveColumnHandle( + const std::string& name, + ColumnType columnType, + TypePtr dataType, + TypePtr hiveType, + std::vector requiredSubfields = {}, + std::vector extractions = {}, + ...); + + /// Named extraction chains. Empty means no extraction (current behavior). + /// When a single entry is present, the column handle's dataType is that + /// entry's dataType. When multiple entries are present, the column + /// handle's dataType is a ROW type whose fields are the outputNames with + /// their corresponding dataTypes. + /// Mutually exclusive with requiredSubfields — if extractions is non-empty, + /// requiredSubfields must be empty. + const std::vector& extractions() const { + return extractions_; + } + + private: + ... + std::vector extractions_; +}; +``` + +### `HiveColumnHandle` (Java / Presto coordinator) + +```java +public class HiveColumnHandle extends BaseHiveColumnHandle { + ... + // Existing + private final List requiredSubfields; + + // New: named extraction chains + private final List extractions; + + public record NamedExtraction( + String outputName, + List chain, + TypeSignature dataType + ) {} + + public sealed interface ExtractionPathElement { + ExtractionStep step(); + + // MapKeys, MapValues, ArrayElements, Size + record Simple(ExtractionStep step) + implements ExtractionPathElement {} + + record StructField(String fieldName) + implements ExtractionPathElement { + public ExtractionStep step() { return ExtractionStep.STRUCT_FIELD; } + } + + record MapKeyFilter( + List stringFilterKeys, + List intFilterKeys) + implements ExtractionPathElement { + public ExtractionStep step() { return ExtractionStep.MAP_KEY_FILTER; } + } + } + + public enum ExtractionStep { + STRUCT_FIELD, + MAP_KEYS, + MAP_VALUES, + MAP_KEY_FILTER, + ARRAY_ELEMENTS, + SIZE + } +} +``` + +### Serialization + +The named extraction chains are serialized as a JSON array in the plan +fragment. Note that `requiredSubfields` and `extractions` are mutually +exclusive — when `extractions` is non-empty, `requiredSubfields` must be +empty: + +```json +{ + "name": "col", + "hiveType": "map(varchar, array(map(integer, double)))", + "requiredSubfields": [], + "extractions": [ + { + "outputName": "col_keys", + "dataType": "array(array(array(integer)))", + "chain": [ + {"step": "MAP_VALUES"}, + {"step": "ARRAY_ELEMENTS"}, + {"step": "ARRAY_ELEMENTS"}, + {"step": "MAP_KEYS"} + ] + } + ] +} +``` + +With key filter and struct field extraction: + +```json +{ + "name": "col", + "hiveType": "map(varchar, row(x integer, y integer))", + "requiredSubfields": [], + "extractions": [ + { + "outputName": "col_x", + "dataType": "array(integer)", + "chain": [ + {"step": "MAP_KEY_FILTER", "stringFilterKeys": ["foo", "bar"]}, + {"step": "MAP_VALUES"}, + {"step": "ARRAY_ELEMENTS"}, + {"step": "STRUCT_FIELD", "fieldName": "x"} + ] + } + ] +} +``` + +Multiple extractions from the same column: + +```json +{ + "name": "col", + "hiveType": "map(varchar, row(x integer, y double))", + "requiredSubfields": [], + "extractions": [ + { + "outputName": "col_size", + "dataType": "bigint", + "chain": [{"step": "SIZE"}] + }, + { + "outputName": "col_keys", + "dataType": "array(varchar)", + "chain": [{"step": "MAP_KEYS"}] + }, + { + "outputName": "col_x", + "dataType": "array(integer)", + "chain": [ + {"step": "MAP_VALUES"}, + {"step": "ARRAY_ELEMENTS"}, + {"step": "STRUCT_FIELD", "fieldName": "x"} + ] + } + ] +} +``` + +### Contract + +- `extractions` is empty → current behavior, `requiredSubfields` may be used + for type-preserving pruning. +- `extractions` is non-empty → worker applies each chain, producing the + corresponding `dataType`. `requiredSubfields` must be empty. +- `extractions` and `requiredSubfields` are **mutually exclusive** on the same + column handle. +- `requiredSubfields` operates on the file type (existing behavior). +- When multiple extractions are needed from the same column, use the + `NamedExtraction` list on a single column handle (see "Multiple extractions + per column" below). Do NOT create multiple column handles for the same + source column. +- The worker validates `hiveType` + each extraction chain → its `dataType` + and rejects mismatches. + +### Multiple extractions per column + +When a query references the same column with different extractions (e.g., +`SELECT map_keys(col) AS keys, cardinality(col) AS size FROM t`), a single +column handle carries all extraction chains via `NamedExtraction`. The column +handle's `dataType` is a **ROW** whose fields are the output names with their +corresponding types: + +``` +assignments: { + "col": HiveColumnHandle { + name: "col", + hiveType: MAP(K, V), + dataType: ROW({ "keys": ARRAY(K), "size": BIGINT }), + extractions: [ + { outputName: "keys", chain: [MapKeys], dataType: ARRAY(K) }, + { outputName: "size", chain: [Size], dataType: BIGINT } + ] + } +} +``` + +The column is read once from the file. Each extraction chain is applied +independently to produce a field in the output ROW. + +**Examples with multiple extractions:** + +``` +-- col: MAP(VARCHAR, ROW(x: INT, y: INT, z: INT)) +-- Query: SELECT map_keys(col) AS keys, map_values(col).x AS vals_x FROM t +-- +-- Two extractions: keys and a specific value subfield. +-- Use StructField in the values chain to extract only field x. + +HiveColumnHandle { + name: "col", + hiveType: MAP(VARCHAR, ROW(x: INT, y: INT, z: INT)), + dataType: ROW({ "keys": ARRAY(VARCHAR), "vals_x": ARRAY(INT) }), + extractions: [ + { outputName: "keys", chain: [MapKeys], dataType: ARRAY(VARCHAR) }, + { outputName: "vals_x", chain: [MapValues, ArrayElements, StructField("x")], dataType: ARRAY(INT) } + ] +} +``` + +``` +-- col: MAP(BIGINT, ROW(a: VARCHAR, b: DOUBLE, c: INT)) +-- Query: SELECT cardinality(col) AS sz, map_values(col).a AS vals_a, +-- map_values(col).b AS vals_b FROM t +-- +-- Three outputs: size and two value subfields. +-- Each subfield gets its own NamedExtraction with a StructField chain. + +HiveColumnHandle { + name: "col", + hiveType: MAP(BIGINT, ROW(a: VARCHAR, b: DOUBLE, c: INT)), + dataType: ROW({ "sz": BIGINT, "vals_a": ARRAY(VARCHAR), "vals_b": ARRAY(DOUBLE) }), + extractions: [ + { outputName: "sz", chain: [Size], dataType: BIGINT }, + { outputName: "vals_a", chain: [MapValues, ArrayElements, StructField("a")], dataType: ARRAY(VARCHAR) }, + { outputName: "vals_b", chain: [MapValues, ArrayElements, StructField("b")], dataType: ARRAY(DOUBLE) } + ] +} +``` + +## Future Extensions + +### `ArraySlice` — sequence truncating for ML + +ML models often consume variable-length sequences (user activity histories, +embedding arrays, feature maps) but have a maximum sequence length. Reading +10K-element arrays when the model only uses the last 128 is wasteful. + +`ArraySlice(offset, length)` would be a **type-preserving** step that truncates +arrays at the reader level. A negative `length` selects from the end of the +array: `ArraySlice(0, -128)` means "take the last 128 elements." + +| Step | Required input | Output | +|------------------------|----------------|-------------| +| `ArraySlice(off, len)` | `ARRAY(T)` | `ARRAY(T)` | + +Since the output type is the same as the input, `ArraySlice` does NOT require +`ArrayElements` after it (unlike `MapKeys`/`MapValues` which change the type). + +Derivation rule: + +``` +derive(ARRAY(T), [ArraySlice, ...rest]) = derive(ARRAY(T), rest) +``` + +Composed examples: + +``` +-- Feature map with array values, filter to specific keys, truncate each to 128 +col: MAP(VARCHAR, ARRAY(FLOAT)) +Chain: [MapKeyFilter(["feat1", "feat2"]), MapValues, ArrayElements, ArraySlice(0, 128)] +Output: ARRAY(ARRAY(FLOAT)) +``` + +``` +-- User history, take first 128 events, extract only item_id +col: ARRAY(ROW(item_id BIGINT, timestamp BIGINT)) +Chain: [ArraySlice(0, 128), ArrayElements, StructField("item_id")] +Output: ARRAY(BIGINT) +``` + +``` +-- User history, take LAST 128 events (most recent) +col: ARRAY(ROW(item_id BIGINT, timestamp BIGINT)) +Chain: [ArraySlice(0, -128)] +Output: ARRAY(ROW(item_id BIGINT, timestamp BIGINT)) +``` + +``` +-- Feature map with array values, filter to specific keys, take last 64 from each +col: MAP(VARCHAR, ARRAY(FLOAT)) +Chain: [MapKeyFilter(["feat1", "feat2"]), MapValues, ArrayElements, ArraySlice(0, -64)] +Output: ARRAY(ARRAY(FLOAT)) +``` + +This is a natural extension — it fits into the chain as a type-preserving step, +composable with extraction steps before and after it. Velox already has +`ScanSpec::maxArrayElementsCount_` for the prefix case; `ArraySlice` generalizes +it to work within extraction chains. diff --git a/velox/docs/develop.rst b/velox/docs/develop.rst index 48c4641e71c..c4df460f191 100644 --- a/velox/docs/develop.rst +++ b/velox/docs/develop.rst @@ -21,8 +21,10 @@ This guide is intended for Velox contributors and developers of Velox-based appl develop/connectors develop/joins develop/anti-join + develop/hash-table-caching develop/operators develop/task + develop/task-barrier develop/simd develop/memory develop/spilling diff --git a/velox/docs/develop/TpchBenchmark.rst b/velox/docs/develop/TpchBenchmark.rst index 23ba28b5d02..fd0a517f682 100644 --- a/velox/docs/develop/TpchBenchmark.rst +++ b/velox/docs/develop/TpchBenchmark.rst @@ -11,7 +11,7 @@ query engine. Benchmarking in Velox is made easy with the optionally built TpchBenchmark (velox_tpch_benchmark) executable. To build the benchmark executable -(*_build/release/velox/benchamrks/tpch/velox_tpch_benchmark*), use the +(*_build/release/velox/benchmarks/tpch/velox_tpch_benchmark*), use the following command line to do the build with S3 support: .. code:: shell @@ -45,9 +45,9 @@ responsible for both **driver** threads and I/O threads. Multiple Process Executor Use Case ---------------------------------- -This use case is used by Spark + `Gluten `_ +This use case is used by Spark + `Gluten `_ and it differs from the Presto use case where parallelism is concerned. Spark -uses multiple processes where each process is a Gluten+Velox query processor. +uses multiple processes where each process is a Gluten + Velox query processor. Spark scales by using many Linux processes for query processing. In this case this means that the **drivers** are outside of Velox and Gluten and is defined by the Spark configuration and number of workers. Gluten takes on the role of @@ -89,7 +89,7 @@ use a single dash not a double dash; i.e. -option and not --option): **NOTE:** *There is a limitation on the implementation of the AWS SDK that will cause failures (curl error 28) if the **driver** *threads times I/O threads -grow much beyond 350 threads. This only really effects the multi-threaded +grow much beyond 350 threads. This only really affects the multi-threaded **drivers** *use case like the benchmark tool. It is only known to be an issue when running against AWS S3. However, the error is coming from the libcurl library so it is possible other Cloud storage APIs could also be affected.* @@ -99,7 +99,7 @@ Velox exposes other options used for tuning that are of interest: * *max_coalesce_bytes* - Size of coalesced data, has small improvements as size grows. -* *max_coalesce_distance_bytes* - Maximum gap bytes between data that can +* *max_coalesce_distance_bytes* - Maximum gap bytes between data that can be coalesced. Larger may mean more fetched data but at greater bytes/sec. Top Optimization Recommendations @@ -149,14 +149,14 @@ chunks as opposed to many smaller requests. This configuration option is useful for workloads that read the same data several times per query but only applies to the single process use case. -*NOTE: There is a SSD Caching option in Velox but it to is ONLY useful in +*NOTE: There is a SSD Caching option in Velox but it too is ONLY useful in the single process use case.* **num_splits_per_file** ----------------------- This configuration option is best when the data set count of row groups -matches this value. The affect in overall performance appears based on +matches this value. The effect on overall performance appears based on testing to be small, however. Optimizations for All Workloads (Both Use Cases) @@ -183,8 +183,8 @@ fine-tuned for the workload being run. Summary ======= -If a use of Velox matches the use case of the TcphBenchmark then it is a good -tool to test, I/O and driver performance for specific TCP-H queries. This would +If a use of Velox matches the use case of the TpchBenchmark then it is a good +tool to test, I/O and driver performance for specific TPC-H queries. This would benefit execution of specific production workloads that are like the chosen queries. If in multi-process use case, like Spark/Gluten/Velox configuration, the recommendation is to oversubscribe I/O threads between 2X and 3X vCPUs and diff --git a/velox/docs/develop/aggregate-functions.rst b/velox/docs/develop/aggregate-functions.rst index e989f624976..b0146f2f4b3 100644 --- a/velox/docs/develop/aggregate-functions.rst +++ b/velox/docs/develop/aggregate-functions.rst @@ -137,8 +137,9 @@ Simple Function Interface This section describes the main concepts and the simple interface of aggregation functions. Examples of aggregation functions implemented through -the simple-function interface can be found at velox/exec/tests/SimpleAverageAggregate.cpp -and velox/exec/tests/SimpleArrayAggAggregate.cpp. +the simple-function interface can be found at SimpleAverageAggregate.cpp, +SimpleArrayAggAggregate.cpp, SimpleVariadicSumAggregate.cpp, and +SimpleVariadicArrayAggAggregate.cpp under velox/exec/tests. A simple aggregation function is implemented as a class as the following. @@ -451,16 +452,12 @@ null should be written to the final result vector. Limitations ^^^^^^^^^^^ -The simple aggregation function interface currently has three limitations. +The simple aggregation function interface currently has two limitations. -1. All values read or written by the aggrgeaiton function must be part of the - accumulators. This means that there cannot be function-level states kept - outside of accumulators. - -2. Optimizations on constant inputs is not supported. I.e., constant input +1. Optimizations on constant inputs is not supported. I.e., constant input arguments are processed once per row in the same way as non-constant inputs. -3. Aggregation pushdown to table scan is not supported yet. We're planning to +2. Aggregation pushdown to table scan is not supported yet. We're planning to add this support. Vector Function Interface diff --git a/velox/docs/develop/aggregations.rst b/velox/docs/develop/aggregations.rst index 934252f7c14..46c04818914 100644 --- a/velox/docs/develop/aggregations.rst +++ b/velox/docs/develop/aggregations.rst @@ -112,6 +112,27 @@ encounters a row with a different values in pre-grouped keys. This helps reduce the total amount of memory used and allows to unblock downstream operators faster. +noGroupsSpanBatches Optimization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +AggregationNode supports an optional ``noGroupsSpanBatches`` flag that can be set +to true for streaming aggregations. When enabled, this flag indicates that no +sort group spans across input batches - each input batch contains complete data +for its groups, and no group will appear in any subsequent input batch. + +This optimization allows the StreamingAggregation operator to immediately produce +aggregation results for all groups in each input batch without waiting to see if +more data for those groups will arrive in subsequent batches. This can +significantly improve output latency and reduce memory usage since the operator +doesn't need to hold onto partial aggregation state across batches. + +The ``noGroupsSpanBatches`` flag can only be set when the aggregation is +pre-grouped (streaming). Setting it on a non-streaming aggregation will result +in an error. + +This optimization is typically set automatically by query optimizers when they can +guarantee that the input data meets the required properties. + Push-Down into Table Scan ------------------------- @@ -343,5 +364,6 @@ accumulators. Many aggregate functions implement toIntermediate() fast path. Some examples include: :func:`min`, :func:`max`, :func:`array_agg`, :func:`set_agg`, :func:`map_agg`, :func:`map_union`. -One can use runtime statistic `abandonedPartialAggregation` to tell whether -partial aggregation was abandoned. +Runtime statistic `abandonedPartialAggregationRows` counts rows that bypassed +partial aggregation after it was abandoned. A value greater than 0 indicates +that partial aggregation was abandoned. diff --git a/velox/docs/develop/connectors.rst b/velox/docs/develop/connectors.rst index 9e550e007e7..b4c4275cdb8 100644 --- a/velox/docs/develop/connectors.rst +++ b/velox/docs/develop/connectors.rst @@ -89,7 +89,7 @@ S3 is supported using the `AWS SDK for C++ ` S3 supported schemes are `s3://` (Amazon S3, Minio), `s3a://` (Hadoop 3.x), `s3n://` (Deprecated in Hadoop 3.x), `oss://` (Alibaba cloud storage), and `cos://`, `cosn://` (Tencent cloud storage). -HDFS is supported using the +HDFS is supported using the `Apache Hadoop libhdfs.so `_ and `Apache Hawk libhdfs3 `_ library. HDFS supported schemes are `hdfs://`. @@ -121,3 +121,8 @@ This is the behavior when the proxy settings are enabled: 4. The no_proxy/NO_PROXY list is comma separated. 5. Use . or \*. to indicate domain suffix matching, e.g. `.foobar.com` will match `test.foobar.com` or `foo.foobar.com`. + +HDFS Storage adapter +******************** + +Velox currently supports HDFS by dynamically loading libhdfs.so from the environment's ${HADOOP_HOME}/native/lib directory. If you prefer to use libhdfs3 instead, you can create a symbolic link from libhdfs.so to libhdfs3.so within the same directory. diff --git a/velox/docs/develop/debugging/print-plan-with-stats.rst b/velox/docs/develop/debugging/print-plan-with-stats.rst index 3bb482763c6..f215705d2e4 100644 --- a/velox/docs/develop/debugging/print-plan-with-stats.rst +++ b/velox/docs/develop/debugging/print-plan-with-stats.rst @@ -283,3 +283,22 @@ TableScan operator shows how many rows were processed by pushing down aggregatio .. code-block:: loadedToValueHook sum: 50000, count: 5, min: 10000, max: 10000 + +GPU Operator Stats +~~~~~~~~~~~~~~~~~~ + +When cuDF GPU operators are enabled, the stats output includes GPU-specific +operators. Adapter operators (``CudfToVelox``, ``CudfFromVelox``) appear as +operator-type breakdown lines under their parent plan node: + +.. code-block:: + + -- Aggregation[4][FINAL] -> a0:DOUBLE + Output: 2 rows, Cpu time: 546us, Wall time: 601us + CudfToVelox: Input: 1 rows, Cpu time: 193us, Wall time: 224us + CudfReduceFINAL: Input: 1 rows, Cpu time: 352us, Wall time: 377us + +Here, ``Aggregation[4]`` has two operators: ``CudfReduceFINAL`` (the GPU +aggregation) and ``CudfToVelox`` (GPU-to-CPU format conversion). The summary +line shows their combined stats. This is the same multi-operator mechanism +used by ``HashBuild``/``HashProbe`` under join nodes. diff --git a/velox/docs/develop/expression-evaluation.rst b/velox/docs/develop/expression-evaluation.rst index ec9ef94a76b..7b2293aea62 100644 --- a/velox/docs/develop/expression-evaluation.rst +++ b/velox/docs/develop/expression-evaluation.rst @@ -371,13 +371,13 @@ depth-first order. For each node a sequence of operations is performed. Flat No-Nulls Fast Path ``````````````````````` -When evaluating simple expressions on short vectors (< 1000 rows), the overhead -of handling nulls and encodings is visible. To optimize these use cases, -expression evaluation takes flat-no-nulls fast path -(Expr::evalFlatNoNulls). This path applies automatically when inputs are flat -vectors or constants with no nulls and all sub-expressions are guaranteed to -produce flat-or-constant-no-nulls results given flat-or-constant-no-nulls -inputs. +When evaluating simple expressions, the overhead of handling nulls and encodings +is visible. To optimize these use cases, expression evaluation takes +flat-no-nulls fast path (Expr::evalFlatNoNulls). This path applies automatically +when inputs are flat vectors or constants with no nulls and all sub-expressions +are guaranteed to produce flat-or-constant-no-nulls results given +flat-or-constant-no-nulls inputs. The optimization can be disabled by setting the +``expression.eval_flat_no_nulls`` configuration property to false. An example of a workload that benefits from this optimization is basic arithmetic over non-null floats found in many machine learning pre-processing workloads. @@ -457,3 +457,23 @@ SWITCH expression evaluation goes through the following steps: SWITCH expression sets EvalCtx::isFinalSelection flag to false. The expressions are expected to use this flag to decide whether the partially populated result vector must be preserved or can be overwritten. + +Expression Hashing +`````````````````` + +Each expression node provides a ``hash()`` method that computes a stable hash +value for the expression tree. This hash can be used for deduplication +and expression comparison. + +**Stability Guarantee**: Expression hashes are stable across different +processes, builds, and machines. This is achieved by using a stable hasher +like ``folly::hasher``. + +The hash is computed recursively: + +* ``localHash()`` computes a hash for the expression node itself (type name, + function name, field name, etc.) +* ``hash()`` combines ``localHash()`` with the type's hash and the hashes of + all input expressions + +This enables use cases like expression-based sampling and deduplication. diff --git a/velox/docs/develop/hash-table-caching.rst b/velox/docs/develop/hash-table-caching.rst new file mode 100644 index 00000000000..056189854d7 --- /dev/null +++ b/velox/docs/develop/hash-table-caching.rst @@ -0,0 +1,510 @@ +======================== +Broadcast Build Caching +======================== + +Background +---------- + +In materialized execution engines like Spark and Presto on Spark, for broadcast joins, +the build side splits are replicated to all join tasks due to upfront split planning. +This kind of upfront split planning allows these engines to provide task level fault tolerance +as the input splits of the tasks are tracked and output data can be discarded, +thus enabling task level retries. + +But due to this, each task independently builds an identical hash table from the +same data. For large build sides this is wasteful: every task spends CPU and memory +constructing the same hash table that another task in the same query has already built. + +The Build IO Tax +^^^^^^^^^^^^^^^^ + +Also, the broadcast data follows a write-once-read-many I/O pattern. Each task re-reads +the build side data independently. When the number of tasks is large --- +O(100k) tasks across 10k+ workers --- these concurrent reads overwhelm the I/O +service layer, leading to throttling. + +Throttling causes tasks to stall for seconds to minutes waiting for I/O. When +queries are charged for reserved workers, these stalls mean reserved resources +sit idle, increasing query cost. Beyond I/O fetch delays, when the hash table is +large (in the gigabyte range), the CPU cost of rebuilding it per task is also +significant and wasteful. + +Hash table caching eliminates this redundant work by allowing the first task to +build the hash table and making it available to all subsequent tasks in the same +Velox instance. This is a build-once, reuse-many paradigm. In Sapphire-Velox, +this implements a once-per-worker model that yields more than an order of +magnitude savings, since the number of tasks far exceeds the number of workers. + +Enabling Hash Table Caching +---------------------------- + +Hash table caching is enabled by setting the ``useHashTableCache`` flag to +``true`` on the ``HashJoinNode``: + +.. code-block:: c++ + + auto joinNode = + core::HashJoinNode::Builder() + .id(planNodeIdGenerator->next()) + .joinType(core::JoinType::kInner) + .nullAware(false) + .leftKeys({leftKeyField}) + .rightKeys({rightKeyField}) + .left(probeNode) + .right(buildNode) + .outputType(outputType) + .useHashTableCache(true) + .build(); + +When ``useHashTableCache`` is false (the default), the hash join behaves +exactly as before. The flag is only intended for broadcast joins and is +currently used by Presto-on-Spark. + +Overall Design +-------------- + +Hash table caching introduces a global singleton ``HashTableCache`` that stores +built hash tables keyed by ``queryId:planNodeId``. The cache coordinates +between tasks so that exactly one task builds the hash table while other tasks +wait and then reuse the result. + +The ``HashTableCache`` is a process-wide singleton in the Velox instance, +alongside the ``AsyncDataCache`` and ``MemoryManager``. The cache and its +methods provide building blocks for drivers within a task and tasks within a +worker to coordinate hash table construction and reuse. + +The design has three main components: + +1. **HashTableCache** - A process-wide singleton that stores and manages cached + hash table entries. +2. **HashTableCacheEntry** - A cache entry that holds the hash table, build + coordination state, and a dedicated memory pool. +3. **HashBuild operator integration** - Logic in the HashBuild operator to + check the cache, build or wait, and store the result. + +Cache Structure +--------------- + +``HashTableCache`` is a thread-safe singleton that maps cache keys to cache +entries: + +.. code-block:: c++ + + class HashTableCache { + std::mutex lock_; + std::unordered_map> tables_; + }; + +Each ``HashTableCacheEntry`` contains: + +.. list-table:: + :widths: 25 75 + :header-rows: 1 + + * - Field + - Description + * - ``cacheKey`` + - The key used to look up this entry (``queryId:planNodeId``). + * - ``builderTaskId`` + - The task ID of the task that is responsible for building the table. + * - ``tablePool`` + - A leaf memory pool under the query pool used for table allocations. + * - ``table`` + - The built ``BaseHashTable``, set once build is complete. + * - ``hasNullKeys`` + - Whether the build side contained null join keys. + * - ``buildComplete`` + - Atomic flag indicating whether the table has been fully built. + * - ``buildPromises`` + - Promises used to notify waiting tasks when build completes. + +Cache API +--------- + +The ``HashTableCache`` exposes three methods. All decisions are made under a +single ``std::mutex``. The lock is held only for map lookups, inserts, and +promise creation --- never during table building or memory allocation. + +get() +^^^^^ + +The ``get()`` method is the central coordination point. It is called by every +``HashBuild`` operator during ``initialize()`` and determines the caller's +role under the mutex: + +.. code-block:: text + + get(key, taskId, queryCtx, *future): + lock(lock_) + Case 1 – No entry: create entry, set builderTaskId → return entry (Builder) + Case 2 – Same task: return entry (Builder, coordinate via JoinBridge) + Case 3 – Diff task, not complete: push promise → return entry + future (Waiter) + Case 4 – Complete: return entry (Late Arrival) + +When creating a new entry, ``get()`` allocates a ``tablePool`` as a leaf child +of the query pool and registers a ``QueryCtx`` release callback that calls +``drop()`` on query destruction. + +Key design decisions in ``get()``: + +- **Memory pool ownership**: The ``tablePool`` is a leaf child of the first + caller's ``QueryCtx`` root pool. All drivers in the builder task share this + pool for partial table allocations (via ``HashBuild::tableMemoryPool()``), + tying the cached table's memory accounting to the originating query. +- **Cleanup callback**: ``QueryCtx::addReleaseCallback`` ensures the cache entry + is dropped when the query finishes. ``drop()`` resets the + ``shared_ptr`` outside the lock to free memory before the entry + is destroyed. +- **Lock scope**: All decisions are made under a single ``std::mutex``. The lock + is held only for map lookups/inserts and promise creation --- never during + table building or memory allocation. + +put() +^^^^^ + +Called by the last driver of the builder task after merging all partial tables. +Publishes the table and wakes all waiters: + +.. code-block:: text + + put(key, table, hasNullKeys): + lock(lock_) + entry.table = table + entry.buildComplete = true + promises = move(entry.buildPromises) + unlock(lock_) + for each promise: promise.setValue() // wake waiters outside lock + +drop() +^^^^^^ + +Removes a cache entry and frees the table memory. Called by the ``QueryCtx`` +cleanup callback when the query is destroyed: + +.. code-block:: text + + drop(key): + lock(lock_) + entry = move(tables_[key]) + tables_.erase(key) + unlock(lock_) + entry.table.reset() // free memory outside lock + +Build Coordination +------------------ + +When hash table caching is enabled, the HashBuild operator calls +``HashTableCache::get()`` during initialization. The cache uses the first +caller's task as the builder and makes subsequent callers wait. + +Builder Task +^^^^^^^^^^^^ + +The first task to call ``get()`` for a given key creates the cache entry and +becomes the builder. This task proceeds through the normal HashBuild flow: +all its drivers build partial hash tables, the last driver merges them, and +the merged table is stored in the cache via ``HashTableCache::put()``. + +Drivers within the builder task coordinate with each other through the +existing ``HashJoinBridge`` mechanism. The cache does not interfere with +intra-task driver synchronization. + +Waiter Tasks +^^^^^^^^^^^^ + +When a task calls ``get()`` and finds that another task is already building the +table (``builderTaskId`` differs from its own task ID and ``buildComplete`` is +false), it receives a ``ContinueFuture`` and transitions to the +``kWaitForBuild`` state. The task is suspended until the builder task calls +``put()``, which fulfills all waiting promises. + +Once notified, the waiter task calls ``noMoreInput()`` which finds the table +in the cache and passes it directly to the ``HashJoinBridge`` without building +anything. The probe side then runs normally against the cached table. + +.. code-block:: text + + Task 1 (Builder) Task 2 (Waiter) Task 3 (Waiter) + ──────────────── ─────────────── ─────────────── + get() → creates entry get() → sees builder get() → sees builder + builds hash table receives future receives future + put() → sets table (suspended) (suspended) + notifies waiters ──────────→ wakes up wakes up + uses cached table uses cached table + +Cache Hit +^^^^^^^^^ + +If a task calls ``get()`` and finds ``buildComplete`` is already true, the +cached table is returned immediately. The HashBuild operator skips all build +logic and passes the table to the ``HashJoinBridge``. + +The HashBuild operator reports cache hits and misses via runtime statistics: + +- ``hashtable.cacheHit`` - Table was found in the cache and reused. +- ``hashtable.cacheMiss`` - Table was not in the cache; this task built it. + +Usage by HashBuild +------------------ + +The HashBuild operator uses the cache in a three-phase protocol: build, synchronize, +and probe. + +Step 1: Build Phase (Producer) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +When a ``HashBuild`` operator is initialized, it checks the cache via +``setupCachedHashTable()``. + +- **Cache Miss (Builder)**: The first task to find a miss becomes the Builder. + It creates a cache entry, pulls data from storage, builds the + ``BaseHashTable``, and calls ``put()`` to publish it. Within the Builder task, + subsequent drivers also call ``get()`` and receive the same entry (since + ``builderTaskId == taskId``). Each driver calls ``setupTable()`` to allocate + its own partial ``BaseHashTable`` using ``cacheEntry->tablePool``, receives its + subset of input via ``addInput()``, and builds a partial table. Intra-task + coordination between these drivers uses the standard ``allPeersFinished()`` / + ``JoinBridge`` mechanism, not the cache. + +- **The Wait (Waiters)**: If other tasks arrive while the Builder is building, + they encounter the pending state and transition to ``kWaitForBuild``, waiting + on a ``ContinueFuture`` provided by the cache. + +- **Short-circuiting Upstream**: Once the Builder publishes the table, waiters + are unblocked. Upon receiving the cached table, waiter tasks set their + no-more-input flags. This short-circuits their source operators (e.g., + ``TableScan``), immediately stopping further data retrieval. + +Step 2: Synchronization (JoinBridge) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The ``HashJoinBridge`` acts as the hand-off point between the build and probe +sides. Even if the table was retrieved from the cache rather than built locally, +the bridge ensures the probe side is notified that the data is ready for +processing. Both builder and waiter tasks call +``joinBridge.setHashTable()`` to publish the table (or cached table) to the +probe operators. + +Step 3: Probe Phase (Consumer) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The ``HashProbe`` operator takes the cached table from the bridge and executes +as usual. Because the table is held as a ``shared_ptr``, the probe operator's +reference prevents the cache from freeing the table while a join is actively +scanning it. Once the probe finishes, the reference count is decremented. The +table is ultimately freed when the ``QueryCtx`` release callback calls +``drop()``. + +HashBuild Lifecycle +------------------- + +The following pseudocode shows the complete lifecycle of a ``HashBuild`` +operator when hash table caching is enabled. Only the key function calls are +shown. + +Initialization +^^^^^^^^^^^^^^ + +.. code-block:: text + + initialize(): + cacheKey = "queryId:planNodeId" + cacheEntry = HashTableCache::instance()->get(cacheKey, taskId, queryCtx, &future_) + + if cacheEntry.buildComplete: // Late Arrival + noMoreInput() // → finishHashBuild() → getHashTableFromCache() + return + + if future_.valid(): // Waiter + state = kWaitForBuild + return + + // Builder: proceed with normal table setup + setupTable() // allocate BaseHashTable using cacheEntry.tablePool + setupSpiller() // no-op: canSpill() returns false with cache + +Build and Publish +^^^^^^^^^^^^^^^^^ + +.. code-block:: text + + noMoreInput() → finishHashBuild(): + if not allPeersFinished(): // wait for peer drivers in same task + state = kWaitForBuild + return + + if getHashTableFromCache(): // Waiter or Late Arrival: cache has table + joinBridge.setHashTable(cacheEntry.table, hasNullKeys) + return + + // Builder (last driver): merge and publish + table_.prepareJoinTable(otherTables) + HashTableCache::instance()->put(cacheKey, table_, hasNullKeys) + joinBridge.setHashTable(table_, hasNullKeys) + +Waiter Wake-up +^^^^^^^^^^^^^^ + +.. code-block:: text + + isBlocked(): + case kWaitForBuild: + if receivedCachedHashTable(): // future_ fulfilled, buildComplete == true + setRunning() + noMoreInput() // → finishHashBuild() → getHashTableFromCache() + +Skipping Source Reads +--------------------- + +Waiter tasks never read any data from storage. No splits are fetched, no +exchanges are initiated. + +In Velox, a build-side pipeline is a chain of operators ending with +``HashBuild`` as the sink: + +.. code-block:: text + + [TableScan / Exchange] → ... → [HashBuild] + operators_[0] operators_[last] + +The ``Driver::runInternal()`` loop iterates through operator pairs ``(op, +nextOp)`` and, for each pair, follows this sequence: + +1. Check ``op->isBlocked()`` --- if blocked, suspend the Driver. +2. Check ``nextOp->isBlocked()`` --- if blocked, suspend the Driver. +3. Check ``nextOp->needsInput()`` --- if false, skip pulling from ``op``. +4. Call ``op->getOutput()`` and feed the result to ``nextOp->addInput()``. + +The critical point is that ``nextOp->isBlocked()`` is checked **before** +``op->getOutput()`` is ever called. When the ``HashBuild`` operator is in the +``kWaitForBuild`` state, it returns a blocked status, which prevents the +driver from pulling data from any upstream operator (e.g., ``TableScan`` or +``Exchange``). Once the cached table arrives and the waiter calls +``noMoreInput()``, source operators are short-circuited immediately --- +they never execute at all. + +This is a key benefit of the caching design: waiter tasks incur zero I/O cost. + +Memory Management +----------------- + +Cached hash tables must outlive the task that built them because waiter tasks +from the same query need to access the table after the builder task has +finished. To support this, cached hash tables use a dedicated leaf memory pool +created under the **query** memory pool rather than the operator's task-level +pool. + +Pool Hierarchy +^^^^^^^^^^^^^^ + +.. code-block:: text + + Query Pool + ├── Task 1 Pool (builder - may finish first) + │ └── Operator Pool + └── cached_table_ Pool ← hash table lives here + (created by HashTableCache) + +The ``tablePool`` is created by the first call to ``get()`` as a leaf child of +the caller's ``QueryCtx`` root pool. All drivers in the builder task share this +pool for their partial table allocations via ``HashBuild::tableMemoryPool()``. +This ties the cached table's memory accounting to the originating query rather +than to any individual task, allowing the table to survive task completion. + +Cleanup Callback +^^^^^^^^^^^^^^^^ + +When a cache entry is created, ``HashTableCache::get()`` registers a release +callback on the ``QueryCtx``. When the query context is destroyed, this +callback calls ``HashTableCache::drop()`` to remove the entry and free the +table's memory before the query pool is torn down. ``drop()`` resets the +``shared_ptr`` outside the lock to free memory before the entry +itself is destroyed. This ensures there are no dangling references to +destroyed memory pools. + +``HashBuild::tableMemoryPool()`` returns the cache entry's ``tablePool`` when +caching is enabled, or the operator's own ``pool()`` for regular joins. + +Ownership and Shared Pointers +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Without caching, the hash table is transferred to the ``HashJoinBridge`` as a +``unique_ptr``. With caching enabled, the table is stored as a ``shared_ptr`` +in the cache entry and a copy of the ``shared_ptr`` is passed to the bridge. +This allows the cache to retain ownership while the bridge and probe operator +also hold references. The ``HashJoinBridge::setHashTable()`` signature was +changed to accept ``shared_ptr`` to support this. + +Reference counting ensures that the table is not freed while any probe operator +is actively scanning it. Once the probe finishes, the reference count is +decremented. The table is ultimately freed when the ``QueryCtx`` release +callback calls ``drop()``. + +Spilling +-------- + +Spilling is not supported when hash table caching is enabled. Both +``HashBuild::canSpill()`` and ``HashBuild::canReclaim()`` return false when +``useHashTableCache`` is true: + +.. code-block:: c++ + + bool HashBuild::canSpill() const { + // ... + if (useHashTableCache()) { + return false; + } + // ... + } + +This is because spilling clears the hash table from memory and rebuilds it +later, which would corrupt the cached table that other tasks may be using. +Specifically: + +- **Builder task**: Cannot spill because the table is shared via the cache. + Spilling would invalidate the ``shared_ptr`` held by waiter tasks. +- **Waiter tasks**: Cannot spill because they use the cached table directly + and never build their own. +- **Coordination complexity**: Rebuild-after-spill would require re-coordinating + across all tasks sharing the cached table. + +Broadcast joins (the primary use case for this cache) are generally expected to +fit in memory. If a build-side relation is large enough to require spilling, it +should bypass the cache and use a standard partitioned hash join with spilling +enabled. + +Eviction +-------- + +Cache eviction is not currently supported. Entries remain in the cache until +the query context is destroyed, at which point the release callback removes +them. + +Future memory pressure-based eviction would need to address: + +1. **Tracking total memory**: Summing the memory held by all cached tables. +2. **Eviction policy**: Deciding which entries to evict (e.g., LRU, by size). +3. **Reference invalidation**: Safely handling eviction while probe operators + hold references via ``shared_ptr``. +4. **Rebuild fallback**: Allowing tasks to re-build the table if it was evicted. + +The ``drop()`` method already provides the mechanism for removing individual +entries and could be extended to support eviction driven by the memory manager +or arbitration framework. + +Limitations and Future Work +--------------------------- + +- **No spilling**: Cached tables must reside entirely in memory. See + :ref:`Spilling ` above. +- **No eviction**: Cached entries live for the full query lifetime. Memory + pressure-based eviction is planned. +- **Single-query scope**: The cache key includes the ``queryId``, so tables are + not shared across different queries even if the build side data is identical. + Cross-query sharing is a potential future optimization. +- **No sanity checks on table sharing during probe**: For right joins, we rely on + the planner to not do a broadcast join and skip using cached tables. But velox + as a library does not do checks during probe that it is in fact running a join + that does not mutate the hash table. Mutating the cached hash table can cause + incorrect execution results. We should add this check diff --git a/velox/docs/develop/hash-table.rst b/velox/docs/develop/hash-table.rst index 1d82a1427ca..421668256c5 100644 --- a/velox/docs/develop/hash-table.rst +++ b/velox/docs/develop/hash-table.rst @@ -4,7 +4,7 @@ .. role:: m(math) ========== -Hash table +Hash Table ========== The hash table used in Velox is similar to the @@ -30,7 +30,7 @@ These are referred to as padding. :align: center A hash table is never full. There are always some empty slots. Velox allows the hash table to fill up to -:raw-html:`` of capacity before resizing. +:raw-html:`0.7` of capacity before resizing. On resize the hash table’s capacity doubles. Individual buckets may be completely empty, partially filled or full. Buckets are filled left to right. @@ -117,7 +117,7 @@ the hash table is never full and there are enough gaps in the form of empty slot Resizing -------- -If the hash table fills up beyond :raw-html:`` +If the hash table fills up beyond :raw-html:`0.7` of capacity, it needs to be resized. Each resize doubles the capacity. A new hash table is allocated and all existing entries inserted using the “Inserting an entry” process. Since we know that all entries are unique, the “Inserting an entry” process can be simplified to @@ -126,11 +126,181 @@ insert an entry, we compute a hash, extract tag and bucket number, go to the buc entry if there is space. If the bucket is full, we proceed to the next bucket and continue until we find a bucket with an empty slot. We insert the new entry there. +Hash Modes +---------- + +The description above covers the default bucket-based hash table (kHash mode). +Velox also supports two optimized modes that avoid per-entry hashing and +bucket probing when the key values allow it. The hash table analyzes the key +data during build and selects the best mode automatically. + +The three modes are: + +* **kArray** — Direct array lookup. Does not use the bucket-based hash table at + all. Each key combination maps to an index in a flat array. Lookup is O(1) + with no hashing or probing. Used when the combined key space is small enough + to fit in an array. + +* **kNormalizedKey** — Bucket-based (same layout as kHash), but keys are + encoded into a single 64-bit normalized key stored alongside each row. Key + comparison uses this normalized key instead of comparing individual columns, + which is faster for multi-column keys. + +* **kHash** — Bucket-based with full key comparison. Used when keys cannot be + mapped to value IDs or normalized into 64 bits (e.g., complex types like + ARRAY, MAP, ROW). + +kArray Mode +~~~~~~~~~~~ + +In kArray mode, the bucket-based hash table is not used at all. Instead, +``table_`` is a flat array of pointers indexed directly by a value ID computed +from the key columns. Lookup is a single array access — no hashing, no tag +comparison, no probing. + +VectorHasher tracks the range (min, max) and distinct values for each key +column. Each column is assigned a *multiplier* so that multi-column keys +produce a unique combined index: + +.. code-block:: text + + index = valueId(col0) + valueId(col1) * multiplier1 + valueId(col2) * multiplier2 + ... + +The value ID for a column is computed using one of two approaches: + +1. **Range-based**: for numeric types, the value ID is ``value - min``. The + array dimension for the column is the range (max - min + 1). The combined + product of all column ranges must be < 2M. This is preferred when the range + is within 20x of the distinct count (to avoid wasting array space on sparse + ranges). + +2. **Distinct-value-based**: VectorHasher maintains a mapping from each unique + value to a consecutive integer ID (0, 1, 2, ...). This works for all + supported types including VARCHAR, where each unique string gets its own + ID. The combined product of per-column distinct counts must be < 2M. + This is used when ranges are too large or not applicable (e.g., for + VARCHAR, where values don't have a numeric range). + +The array size is the product of all per-column dimensions (ranges or distinct +counts), capped at ``kArrayHashMaxSize`` (2M entries = 16MB of pointer +storage). + +**Supported types**: BOOLEAN, TINYINT, SMALLINT, INTEGER, BIGINT, VARCHAR, +VARBINARY, TIMESTAMP. Types like REAL, DOUBLE, ARRAY, MAP, ROW do not support +value ID tracking and cannot use kArray mode. + +**Examples**: + +* Two BIGINT columns, 500 rows with values 0..499. Range per column is 500, + combined range is 500 * 500 = 250'000 < 2M. Uses range-based kArray. + (See ``HashTableTest.int2DenseArray``.) + +* One VARCHAR column, 500 rows. Each unique string is assigned a consecutive + ID (e.g., "apple" → 0, "banana" → 1, ...). With 500 distinct values, the + array has 500 entries. (See ``HashTableTest.string1DenseArray``.) + +* Two BIGINT columns, 500 rows with spacing 1'000 (values 0, 1000, 2000, ...). + Range per column is 500'000, combined range is 250B — too large. But distinct + count per column is 500, combined 250'000 < 2M. Uses distinct-value-based + kArray. (See ``HashTableTest.int2SparseArray``.) + +kNormalizedKey Mode +~~~~~~~~~~~~~~~~~~~ + +When the combined key space exceeds 2M entries but can be encoded into a single +64-bit integer, the table uses kNormalizedKey mode. This uses the same +bucket-based layout as kHash, but stores a 64-bit *normalized key* immediately +before each row in the RowContainer. + +The normalized key is computed using the same multiplier-based encoding as +kArray mode: + +.. code-block:: text + + normalizedKey = valueId(col0) + valueId(col1) * multiplier1 + ... + +During lookups, the normalized key is compared first — a single 64-bit integer +comparison. If it doesn't match, the full per-column key comparison is skipped. +This is particularly effective for multi-column keys where comparing individual +columns would require multiple memory accesses and type-specific comparisons. + +**Examples**: + +* Two VARCHAR columns, 5'000 rows. Distinct count per column exceeds what fits + in a flat array, but the combined distinct values fit in 64 bits. + (See ``HashTableTest.string2Normalized``.) + +* Two BIGINT columns, 10'000 rows with spacing 1'000 (values 0, 1000, 2000, + ...). Range per column is 10M, combined range overflows the 2M array limit, + but fits in a 64-bit normalized key. + (See ``HashTableTest.int2SparseNormalized``.) + +Adaptive Prefetching in hashRows +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +When computing hashes in kNormalizedKey mode, ``hashRows`` reads the normalized +key stored immediately before each row pointer. Because rows are allocated from +a RowContainer arena and accessed in hash-partitioned order, successive row +pointers typically reference different cache lines. When +the working set exceeds the CPU's last-level cache, each normalized key read +incurs a DRAM access. + +To hide this latency, ``hashRows`` uses the ``AdaptivePrefetch`` class. During +the first 16 iterations, the class measures per-iteration time using a +conservative look-ahead of 4. After measurement, it computes an optimal look-ahead +distance based on the ratio of assumed DRAM latency to measured iteration time, +multiplied by a coefficient of 4. The result is clamped to [4, 32] — values +above 32 risk polluting L1 cache with too many outstanding prefetches. + +kHash Mode +~~~~~~~~~~ + +This is the fallback mode used when: + +* Key types don't support value IDs (e.g., ARRAY, MAP, ROW, DOUBLE, REAL). +* A single key column has more than 10'000 distinct values and the range + overflows (cannot use normalized keys). +* Both the combined range and combined distinct count overflow 64 bits. + +In this mode, lookups compute a hash, probe buckets, compare tags, and then +compare actual key values by following pointers to the RowContainer. + +**Examples**: + +* One ROW(BIGINT, VARCHAR, BIGINT) column. ROW type does not support value IDs. + (See ``HashTableTest.structKey``.) + +* Six columns (5 BIGINT + 1 VARCHAR), 100'000 rows with spacing 1'000. The + combined cardinality overflows 64 bits. + (See ``HashTableTest.mixed6Sparse``.) + +Mode Selection +~~~~~~~~~~~~~~ + +The mode is selected by ``decideHashMode()`` using this priority: + +1. If combined ranges < 2M → **kArray** (range-based). +2. If best combination of per-column ranges/distincts < 2M → **kArray** + (mixed). +3. If combined ranges fit in 64 bits → **kNormalizedKey**. +4. If single key column with > 10'000 distincts → **kHash** (normalized key + not worthwhile for a single wide column). +5. If combined distincts < 2M → **kArray** (distinct-value-based). +6. If both ranges and distincts overflow → **kHash**. +7. Otherwise → **kNormalizedKey** (combined distincts fit in 64 bits). + +The selected mode is reported in the ``hashtable.hashMode`` runtime stat: +0 for kHash, 1 for kArray, 2 for kNormalizedKey. + +See ``HashTableTest`` in ``velox/exec/tests/HashTableTest.cpp`` for tests +covering all three modes. + Use Cases --------- -The main use cases for the hash table are `Join `_ and -`Aggregation `_ operators. +The main use cases for the hash table are :doc:`Join ` and +:doc:`Aggregation ` operators. It is also used by RowNumber, +TopNRowNumber, and MarkDistinct operators. The HashBuild operator builds the hash table to store unique values of the join keys found on the build side of the join. The HashProbe operator looks up entries in the hash table using join keys from the diff --git a/velox/docs/develop/images/task-barrier-multi-driver.png b/velox/docs/develop/images/task-barrier-multi-driver.png new file mode 100644 index 00000000000..33a173cb512 Binary files /dev/null and b/velox/docs/develop/images/task-barrier-multi-driver.png differ diff --git a/velox/docs/develop/images/task-barrier-single-driver.png b/velox/docs/develop/images/task-barrier-single-driver.png new file mode 100644 index 00000000000..cd861aafb5a Binary files /dev/null and b/velox/docs/develop/images/task-barrier-single-driver.png differ diff --git a/velox/docs/develop/joins.rst b/velox/docs/develop/joins.rst index 0a1bf3d6a5c..5f91fea6f43 100644 --- a/velox/docs/develop/joins.rst +++ b/velox/docs/develop/joins.rst @@ -3,10 +3,13 @@ Joins ===== Velox supports inner, left, right, full outer, left semi filter, left semi -project, right semi filter, right semi project, and anti hash joins using -either partitioned or broadcast distribution strategies. Semi project and +project, right semi filter, right semi project, anti, and counting hash joins +using either partitioned or broadcast distribution strategies. Semi project and anti joins support additional null-aware flag to distinguish between IN -(null aware) and EXISTS (regular) semantics. Velox also supports cross joins. +(null aware) and EXISTS (regular) semantics. Anti, left semi filter, and +counting joins support a null-as-value flag that enables IS NOT DISTINCT FROM +semantics for join keys (NULL equals NULL), used to implement SQL set operations +(EXCEPT, INTERSECT, EXCEPT ALL, INTERSECT ALL). Velox also supports cross joins. Velox also supports inner and left merge join for the case where join inputs are sorted on the join keys. Right, full, left semi, right semi, and anti merge joins @@ -24,11 +27,18 @@ values need to match, and an optional filter to apply to join results. :align: center The join type can be one of kInner, kLeft, kRight, kFull, kLeftSemiFilter, -kLeftSemiProject, kRightSemiFilter, kRightSemiProject, or kAnti. +kCountingLeftSemiFilter, kLeftSemiProject, kRightSemiFilter, kRightSemiProject, +kAnti, or kCountingAnti. kLeftSemiProject, kRightSemiProject and kAnti joins support an additional nullAware flag to distinguish between IN (null aware) and EXISTS (regular) -semantics. +semantics. See :doc:`Anti joins ` for detailed semantics. + +kAnti, kLeftSemiFilter, kCountingAnti, and kCountingLeftSemiFilter joins support +an additional nullAsValue flag that makes join keys use IS NOT DISTINCT FROM +semantics where NULL equals NULL. This is used to implement SQL set operations +(EXCEPT, INTERSECT, EXCEPT ALL, INTERSECT ALL) which require NULLs to match. +The nullAsValue and nullAware flags are mutually exclusive. Filter is optional. If specified it can be any expression over the results of the join. This expression will be evaluated using the same expression @@ -119,7 +129,7 @@ HashProbe operator gets access to via a special mechanism: JoinBridge. :align: center Both HashBuild and HashAggregation operators use the same data structure for the -hash table: `velox::exec::HashTable `_. The payload, the non-join key columns +hash table: :doc:`HashTable `. The payload, the non-join key columns referred to as dependent columns, are stored row-wise in the RowContainer. Using the hash table in join and aggregation allows for a future optimization @@ -198,6 +208,12 @@ the join is executed using broadcast or partitioned strategy has no effect on the join execution itself. The only difference is that broadcast execution allows for dynamic filter pushdown while partitioned execution does not. +HashJoinNode supports a ``useHashTableCache`` flag (used only by Presto-on-Spark) +that enables caching of the hash table built for broadcast joins. When enabled, +the first task to build the hash table stores it in a global cache, and subsequent +tasks from same query reuse the cached table instead of rebuilding it. See +:doc:`Broadcast Build Caching ` for details. + PartitionedOutput operator and OutputBufferManager support broadcasting the results of the plan evaluation. This functionality is enabled by setting boolean flag "broadcast" in the PartitionedOutputNode to true. @@ -218,7 +234,7 @@ regular anti join is used for queries with NOT EXISTS clause. Broadly-speaking anti join returns probe-side rows which have no match on the build side. However, the exact semantics are a bit tricky. These are -described in detail in :doc:`Anti joins <../develop/anti-join>`. +described in detail in :doc:`Anti joins `. At a high level, null-aware anti join without extra filter behaves as follows: @@ -249,7 +265,7 @@ only if the whole build side is empty, allowing to implement semantic safe because that row cannot possibly match anything on these destinations. Semi Joins ----------- +~~~~~~~~~~ Semi filter joins are used for queries with IN and EXISTS clauses. Left semi filter join should be used when cardinality of the outer @@ -306,13 +322,63 @@ configuring exec::HashTable to set the "allowDuplicates" flag to false. This optimization reduces memory usage of the hash table in case the build side contains duplicate join keys. +Counting Joins +~~~~~~~~~~~~~~ + +Counting joins (kCountingAnti and kCountingLeftSemiFilter) are multiset variants +of kAnti and kLeftSemiFilter used to implement EXCEPT ALL and INTERSECT ALL +respectively. + +EXCEPT and INTERSECT treat both inputs as sets: they remove duplicates and +return distinct rows. EXCEPT ALL and INTERSECT ALL preserve duplicates and +operate on multisets: each row is considered independently. + +.. code-block:: sql + + -- EXCEPT removes duplicates: + -- {A, A, A, B, B} EXCEPT {A, C} = {A, B} (just "is A present?") + + -- EXCEPT ALL counts duplicates: + -- {A, A, A, B, B} EXCEPT ALL {A, C} = {A, A, B, B} (remove one A) + + -- INTERSECT removes duplicates: + -- {A, A, A, B, B} INTERSECT {A, A, C} = {A} (just "is A present?") + + -- INTERSECT ALL counts duplicates: + -- {A, A, A, B, B} INTERSECT ALL {A, A, C} = {A, A} (keep min count) + +EXCEPT ALL and INTERSECT ALL cannot be implemented using regular semi and anti +joins because those only check for the existence of a match, not the number of +matches. Counting joins solve this by tracking per-key counts. + +The build side deduplicates keys and stores a per-key count. Unlike semi and +anti joins which skip duplicate keys entirely (see `Skipping Duplicate Keys`_ +above), counting joins keep exactly one entry per key with a count of how many +times that key appeared. + +On probe, each match decrements the count. kCountingAnti emits a probe row when +the count reaches zero or no match is found (EXCEPT ALL semantics). +kCountingLeftSemiFilter emits a probe row while the count is greater than zero +(INTERSECT ALL semantics). Counting joins do not support extra filter or +null-aware mode. Spilling is not yet supported. + +Because the probe side modifies per-key counts in the hash table (decrementing +on each match), it must run either single-threaded or with the probe input +partitioned on join keys across threads. Without partitioning, multiple probe +threads could decrement the same key's count concurrently, producing incorrect +results. The build side can run multi-threaded: each build driver constructs its +own hash table, and these are merged by summing per-key counts for duplicate +keys. + Execution Statistics ~~~~~~~~~~~~~~~~~~~~ -HashBuild operator reports the range and number of distinct values for each join -key if these are not too large and allow for array-based join or use of -normalized keys. +HashBuild operator reports the hash table mode, the range and number of distinct +values for each join key if these are not too large and allow for array-based +join or use of normalized keys. +* hashtable.hashMode - the hash mode of the table: 0 for kHash, 1 for kArray, + 2 for kNormalizedKey * rangeKey - the range of values for the join key #N * distinctKey - the number of distinct values for the join key #N @@ -339,18 +405,19 @@ to HiveConnector. Memory Layout ------------- -Inside hash table we keep the row values in `RowContainer`. This is a row-wise -storage and each row consists the following components: +The :doc:`hash table ` stores row values in RowContainer. This is +a row-wise storage and each row consists of the following components: 1. Null flags (1 bit per item) for 1. Keys (only if nullable) 2. Dependants -2. Has-probed flag (1 bit) +2. Has-probed flag (1 bit, right and full outer joins only) 3. Free flag (1 bit) 4. Keys 5. Dependants 6. Variable size (32 bit) -7. Next offset (64 bit pointer) +7. Next offset (64 bit pointer, joins that allow duplicate keys) +8. Count (32 bit, counting joins only, mutually exclusive with Next offset) Merge Join Implementation @@ -377,5 +444,6 @@ pipeline. CallbackSink is installed at the end of the right-side pipeline. Usage Examples -------------- -Check out velox/exec/tests/HashJoinTest.cpp and MergeJoinTest.cpp for examples -of how to build and execute a plan with a hash or merge join. +Check out velox/exec/tests/HashJoinTest.cpp, CountingJoinTest.cpp, and +MergeJoinTest.cpp for examples of how to build and execute a plan with a hash, +counting, or merge join. diff --git a/velox/docs/develop/memory.rst b/velox/docs/develop/memory.rst index 5d879a0137d..3675f821816 100644 --- a/velox/docs/develop/memory.rst +++ b/velox/docs/develop/memory.rst @@ -122,7 +122,7 @@ Memory Manager :alt: Memory Manager The memory manager is created on server startup with the provided -*MemoryManagerOption*. It creates a memory allocator instance to manage the +*MemoryManager::Options*. It creates a memory allocator instance to manage the physical memory allocations for both query memory allocated through memory pool and cache memory allocated through the file cache. It ensures the total allocated memory is within the system memory limit (specified by @@ -539,7 +539,7 @@ between queries by adjusting their memory pool’s capacities accordingly (see The *MemoryArbitrator* is defined to support different implementations for different query systems. As for now, we implement *SharedArbitrator* for both -Prestissimo and Prestissimo-on-Spark. `Gluten `_ implements its own memory +Prestissimo and Prestissimo-on-Spark. `Gluten `_ implements its own memory arbitrator to integrate with the `Spark memory system `_. *SharedArbitrator* ensures the total allocated memory capacity is within the query memory limit (*MemoryManager::Options::arbitratorCapacity*), and also ensures each individual @@ -667,7 +667,7 @@ Here is the memory reclaim process within a query: reclamation through disk spilling and table writer flush. *Operator::reclaim* is added to support memory reclamation with the default implementation does nothing. Only spillable operators override that method: *OrderBy*, *HashBuild*, - *HashAggregation*, *RowNumber*, *TopNRowNumber*, *Window* and *TableWriter*. + *HashAggregation*, *RowNumber*, *TopNRowNumber*, *MarkDistinct*, *Window* and *TableWriter*. As for now, we simply spill everything from the spillable operator’s row container to free up memory. After we add memory compaction support for row containers, we could leverage fine-grained disk spilling features in Velox diff --git a/velox/docs/develop/operators.rst b/velox/docs/develop/operators.rst index 37457bf5cec..c9dbb53cb86 100644 --- a/velox/docs/develop/operators.rst +++ b/velox/docs/develop/operators.rst @@ -56,10 +56,12 @@ ValuesNode Values Y LocalMergeNode LocalMerge LocalPartitionNode LocalPartition and LocalExchange EnforceSingleRowNode EnforceSingleRow +EnforceDistinctNode EnforceDistinct or StreamingEnforceDistinct AssignUniqueIdNode AssignUniqueId WindowNode Window RowNumberNode RowNumber TopNRowNumberNode TopNRowNumber +MixedUnionNode MixedUnion ========================== ============================================== =========================== Plan Nodes @@ -547,6 +549,10 @@ and emitting results. - Join type: inner, left, right, full, left semi filter, left semi project, right semi filter, right semi project, anti. You can read about different join types in this `blog post `_. * - nullAware - Applies to anti and semi project joins only. Indicates whether the join semantic is IN (nullAware = true) or EXISTS (nullAware = false). + * - nullAsValue + - Optional. When true, join keys use IS NOT DISTINCT FROM semantics where NULL equals NULL. Used to implement SQL set operations (EXCEPT, INTERSECT, EXCEPT ALL, INTERSECT ALL). Mutually exclusive with nullAware. + * - useHashTableCache + - Optional. Used only by Presto-on-Spark. When true, enables caching of the hash table built for broadcast joins so that subsequent tasks can reuse it. * - leftKeys - Columns from the left hand side input that are part of the equality condition. At least one must be specified. * - rightKeys @@ -850,6 +856,35 @@ values set to null. If input contains more than one row raises an exception. Used for queries with non-correlated sub-queries. +EnforceDistinctNode +~~~~~~~~~~~~~~~~~~~ + +The EnforceDistinct operator ensures that input rows have unique values for +specified key columns. It passes through all input rows unchanged, but throws +an exception with a custom error message if any duplicate key values are +detected. This is useful for validating uniqueness constraints at runtime, +such as ensuring a correlated scalar subquery returns at most one row per group. + +When preGroupedKeys equals distinctKeys (i.e., input is clustered on the +distinct keys), the streaming implementation is used which requires only O(1) +memory. Otherwise, the hash-based implementation is used which requires O(n) +memory to track all unique key combinations seen so far. + +.. list-table:: + :widths: 10 30 + :align: left + :header-rows: 1 + + * - Property + - Description + * - distinctKeys + - List of columns that must have unique values. + * - preGroupedKeys + - Optional subset of distinctKeys that input is already clustered on. When + equal to distinctKeys, uses streaming enforcement with O(1) memory. + * - errorMessage + - Error message to include in the exception when duplicates are found. + AssignUniqueIdNode ~~~~~~~~~~~~~~~~~~ @@ -947,7 +982,7 @@ results available before seeing all input. TopNRowNumberNode ~~~~~~~~~~~~~~~~~ -An optimized version of a WindowNode with a single row_number function and a +An optimized version of a WindowNode with a single row_number, rank or dense_rank function and a limit over sorted partitions. Partitions the input using specified partitioning keys and maintains up to @@ -955,11 +990,11 @@ a 'limit' number of top rows for each partition. After receiving all input, assigns row numbers within each partition starting from 1. This operator accumulates state: a hash table mapping partition keys to a list -of top 'limit' rows within that partition. Returning the row numbers as +of top 'limit' rows within that partition. Returning the row number or rank as a column in the output is optional. This operator supports spilling as well. This operator is logically equivalent to a WindowNode followed by -FilterNode(row_number <= limit), but it uses less memory and CPU. +FilterNode(rank/row_number <= limit), but it uses less memory and CPU. .. list-table:: :widths: 10 30 @@ -985,6 +1020,11 @@ MarkDistinctNode The MarkDistinct operator is used to produce aggregate mask columns for aggregations over distinct values, e.g. agg(DISTINCT a). Mask is a boolean column set to true for a subset of input rows that collectively represent a set of unique values of 'distinctKeys'. +This operator supports spilling. The spill mechanism follows the same pattern as RowNumber: when memory pressure +triggers spilling, the hash table contents and future input are partitioned and written to disk. During restore, +each partition's hash table is rebuilt from the spilled data, preserving knowledge of which keys were already seen. +Disabled by default; enable with `mark_distinct_spill_enabled` configuration property. + .. list-table:: :widths: 10 30 :align: left @@ -997,6 +1037,35 @@ Mask is a boolean column set to true for a subset of input rows that collectivel * - distinctKeys - Names of grouping keys. +MixedUnionNode +~~~~~~~~~~~~~~ + +The mixed union operation combines data from multiple input sources concurrently, +producing a single output stream that interleaves rows from all sources. It does +not enforce a sort order but does attempt to mix input sources according to +specified ratios; after exhaustion it continues with remaining sources. + +All sources must produce the same output schema. + +MixedUnion runs single-threaded. Each source runs on its own pipeline and feeds +data into the MixedUnion operator via a merge source queue. + +This operator performs a UNION ALL. It does not deduplicate rows. + +.. list-table:: + :widths: 10 30 + :align: left + :header-rows: 1 + + * - Property + - Description + * - sources + - Two or more input plan nodes. All sources must have the same output type. + * - batchSizesPerSource + - Optional list of per-source batch sizes that controls how many rows are + taken from each source when mixing. If not specified or set to zero for a + source, a default batch size is used. + Examples -------- @@ -1059,3 +1128,41 @@ ALL. .. image:: images/local-exchange.png :width: 400 :align: center + +GPU Operators (cuDF) +-------------------- + +When cuDF is enabled, CPU operators are replaced with GPU equivalents at +pipeline construction time via the ``OperatorAdapterRegistry``. For example, +``FilterProject`` becomes ``CudfFilterProject``, ``Aggregation`` becomes +``CudfGroupby`` or ``CudfReduce``, and ``HashJoin`` becomes +``CudfHashJoinBuild``/``CudfHashJoinProbe``. + +Adapter operators are automatically inserted at GPU/CPU boundaries: + +* ``CudfFromVelox`` — inserted before a GPU operator when the preceding + operator produces CPU data (host-to-device conversion). +* ``CudfToVelox`` — inserted after a GPU operator when the next operator + or the pipeline output requires CPU data (device-to-host conversion). + +Adapter operators use synthetic planNodeIds (e.g. ``4-to-velox``) at runtime, +but redirect their stats to the parent plan node via ``setStatSplitter``. +This means they appear in ``printPlanWithStats`` output as operator-type +breakdown lines under their parent node (the same mechanism used by +``HashBuild``/``HashProbe`` under ``HashJoinNode``). + +All cuDF operators extend ``CudfOperatorBase``, which provides: + +* Template method pattern (``doAddInput``/``doGetOutput``/``doClose``) +* NVTX profiling ranges + +GPU operators are identified by their ``Cudf`` prefix in operator type names +(e.g. ``CudfFilterProject``, ``CudfLocalPartition``, ``CudfReduceFINAL``). +Operators without this prefix (e.g. ``LocalExchange``) run on CPU. +``TableScan`` uses the cuDF GPU parquet reader via the connector layer but +is not itself a ``CudfOperatorBase`` subclass. + +In stats output, "Cpu time" and "Wall time" for GPU operators reflect the +host-side duration of ``addInput``/``getOutput`` calls, which includes +enqueuing GPU work and synchronizing. These are not GPU hardware execution +times. diff --git a/velox/docs/develop/scalar-functions.rst b/velox/docs/develop/scalar-functions.rst index 663aa6bd0a7..5b24a766c18 100644 --- a/velox/docs/develop/scalar-functions.rst +++ b/velox/docs/develop/scalar-functions.rst @@ -216,6 +216,38 @@ An example of such function is rand(): } }; +Function Owner +^^^^^^^^^^^^^^ + +Functions can specify an owner (team or individual) responsible for the function. +This information is included in exception context messages to help with error +attribution and debugging. By default, no owner is set. + +To specify a custom owner, define a static ``owner`` variable in your function class: + +.. code-block:: c++ + + template + struct MyFunction { + VELOX_DEFINE_FUNCTION_TYPES(TExec); + + // Specify the team or individual responsible for this function. + static constexpr std::string_view owner = "my-team"; + + FOLLY_ALWAYS_INLINE bool call(int64_t& result, const int64_t& input) { + result = input * 2; + return true; + } + }; + +When an exception occurs during function evaluation, the error context will +include the owner information (if specified), making it easier to identify which team should +investigate the issue. For example, an error message might look like: + +.. code-block:: text + + Owner: my-team. Expression: my_function(c0) + All-ASCII Fast Path ^^^^^^^^^^^^^^^^^^^ @@ -870,13 +902,32 @@ Use exec::registerVectorFunction to register a stateless vector function. exec::registerVectorFunction takes a name, a list of supported signatures and unique_ptr to an instance of the function. It takes an optional 'metadata' parameter that specifies whether a function is deterministic, has default null -behavior, and other properties. A helper VectorFunctionMetadataBuilder class +behavior, owner, and other properties. A helper VectorFunctionMetadataBuilder class allows to easily construct 'metadata'. For example, .. code-block:: c++ VectorFunctionMetadataBuilder().defaultNullBehavior(false).build(); +To specify a custom owner for a vector function, use the owner() method of +VectorFunctionMetadataBuilder: + +.. code-block:: c++ + + VectorFunctionMetadataBuilder() + .defaultNullBehavior(false) + .owner("my-team") + .build(); + +The owner information is included in exception context messages to help with +error attribution and debugging. By default, no owner is set. +When an exception occurs during function evaluation, the error message will +include the owner (if specified), for example: + +.. code-block:: text + + Owner: my-team. Expression: my_vector_function(c0) + An optional “overwrite” flag specifies whether to overwrite a function if a function with the specified name already exists. @@ -1152,6 +1203,72 @@ Benchmarks are a great way to check if an optimization is working, evaluate how much benefit it brings and decide whether it is worth the additional complexity. +Listening to Function Calls +-------------------------- + +Velox supports observing VectorFunction::apply calls via a global listener +registry. This is useful for monitoring, access control, auditing, or any +cross-cutting concern that needs to observe function execution without modifying +function implementations. + +Listener factories are registered globally via +``registerVectorFunctionListenerFactory()``. During expression compilation, +ExprCompiler calls each factory's ``create()`` method once per resolved scalar +function. The factory receives the function name, ``VectorFunctionMetadata``, +and ``QueryConfig``, and returns a ``VectorFunctionListeners`` struct containing +optional pre and/or post listeners, or ``std::nullopt`` to skip that function. + +.. code-block:: c++ + + #include "velox/expression/VectorFunctionListener.h" + + class MyListenerFactory : public VectorFunctionListenerFactory { + public: + std::optional create( + std::string_view functionName, + const VectorFunctionMetadata& metadata, + const core::QueryConfig& queryConfig) override { + if (functionName != "target_fn") { + return std::nullopt; + } + return VectorFunctionListeners{ + std::make_shared( + [](std::string_view functionName, + const SelectivityVector& rows, + const std::vector& args, + const TypePtr& outputType, + const EvalCtx& context) { + // Called before VectorFunction::apply. + }), + std::make_shared( + [](std::string_view functionName, + const SelectivityVector& rows, + const std::vector& args, + const TypePtr& outputType, + const EvalCtx& context, + const VectorPtr& result, + std::exception_ptr error) { + // Called after VectorFunction::apply, even if apply threw. + // 'error' is non-null when apply threw; the framework + // rethrows it after all post-listeners have executed. + }), + }; + } + }; + + // Register globally (typically at startup). + auto factory = std::make_shared(); + registerVectorFunctionListenerFactory(factory); + +Key properties: + +- **Multiple factories** can be registered independently, each observing + different concerns without coordination. +- **Pre-listener** exceptions propagate immediately and abort the function call. +- **Post-listener** exceptions are caught and logged (rate-limited); they do not + mask the original apply error or prevent other post-listeners from running. +- **Special forms** (AND, OR, CAST, etc.) are not subject to listening. + Documenting ----------- diff --git a/velox/docs/develop/task-barrier.rst b/velox/docs/develop/task-barrier.rst new file mode 100644 index 00000000000..4bc31856559 --- /dev/null +++ b/velox/docs/develop/task-barrier.rst @@ -0,0 +1,425 @@ +============ +Task Barrier +============ + +Motivation & Context +-------------------- + +The introduction of Task Barrier support in Velox was driven by two distinct but +related high-performance workloads: AI training data loading and Real-time +Streaming Processing. Both require strict control over task lifecycle and state +management. + +These workloads share three critical requirements: + +1. **High Efficiency (Task Reuse)**: AI training feeds data split-by-split. + Creating a new VeloxTask for every single split incurs some overhead (memory + allocation, plan optimization, operator initialization). + +2. **Checkpointing & Consistency (Streaming)**: Streaming systems need a way to + safely "pause" the stream to take a consistent snapshot. To do this, the + system must ensure all data belonging to a specific time window or epoch is + fully processed and flushed before moving to the next. + +3. **Deterministic Execution**: To ensure experiments are reproducible, data + order and processing must be identical across runs. + +The Solution: Task Barrier +-------------------------- + +To solve this, Velox uses Sequential Task Execution combined with a Task +Barrier. This allows a single Velox Task to be reused indefinitely. + +The "Barrier" is a synchronization mechanism that forces the task to "pause" +and drain all in-flight data—including buffered data in stateful operators. + +* **For AI Loading**: This ensures the task is clean and ready for the next + batch of splits without overhead. + +* **For Streaming**: This acts as a "consistent cut," ensuring that all state + modifications for the current epoch are finalized and emitted before the next + epoch begins. + +API & Usage +----------- + +The core API introduced is ``Task::requestBarrier()``. This signals the task to +finish processing all currently queued splits and fully drain all stateful +operators. + +Workflow +^^^^^^^^ + +1. **Feed Splits**: The application adds a set of splits to the task, providing + exactly one split per each data source (e.g., one file split for every + TableScan node). + +2. **Request Barrier**: Immediately after feeding the splits, the application + calls ``requestBarrier()``. This signals that no more splits will be added + after this Split Set until it has been fully processed, all resulting data + has been produced, and the task has completely drained and signaled its + completion (reached the barrier). + +3. **Process & Barrier Detection**: The application continuously calls + ``task->next()`` to fetch results. From a user perspective, the Barrier + Reached state is detected precisely when: + + * ``task->next()`` returns no data (nullptr), AND + * The returned ContinueFuture is not set (invalid) + + This specific combination confirms that the task is not merely blocked + waiting for I/O, but has fully drained all vectors from the current Split + Set and is now idle. + + (Note: The internal mechanism of how drivers coordinate this draining + sequence and signal completion to the Task is explained in detail in the + Implementation Mechanism section). + +4. **Decision (Cycle or Finish)**: Once the barrier is reached, the application + must take one of two actions: + + * **Finish Task**: If there are no more splits to process (e.g., end of + dataset), call ``noMoreSplits()`` on the task to signal the final end of + the job and terminate. + * **Repeat Cycle**: If more splits exist, proceed to Step 5 to add the next + set. + +5. **Resume**: The application adds the new Split Set and repeats the process + from Step 1. + +Code Example (C++) +^^^^^^^^^^^^^^^^^^ + +The following pseudo-code illustrates how an AI data loader interacts with the +Velox Task Barrier: + +.. code-block:: c++ + + // AI data loading loop + // 1. Get splits from the runtime + // 2. Feed them to Velox + // 3. Wait for barrier to ensure all data for those splits is produced + + bool addSplits(); + + for (;;) { + // Data consumption loop + while (true) { + ContinueFuture dataFuture = ContinueFuture::makeEmpty(); + // Fetch next batch of results (velox vector) + auto data = veloxTask->next(&dataFuture); + + if (data != nullptr) { + // Consume the data (e.g., feed to training loop) + consume(data); + continue; + } + + // If no data is returned, check if we are blocked + // BARRIER REACHED CONDITION: Data is null AND future is invalid. + if (!dataFuture.valid()) { + // The Task is now idle at the Barrier. + + // Attempt to fetch new splits for the next Split Set + if (getSplits()) { + continue; // New splits added -> Resume processing (Repeat Cycle) + } + + // No more splits available -> Execution finished. + return; + } + + // Wait for the task to produce more data + wait(dataFuture); + } + } + + // Helper to add splits and request barrier + bool getSplits() { + auto splitSet = nextSplitSet(); + + // TERMINATION CONDITION: + // If no more splits are available from the source (e.g., end of dataset), + // we must explicitly signal the Task that input is finished. + if (!splitSet.has_value()) { + for (auto& [op, planNode] : veloxPlan.leafNodes) { + veloxTask->noMoreSplits(planNode->id()); + } + return false; + } + + // Add splits for all leaf nodes + for (auto& [op, planNode] : veloxPlan.leafNodes) { + veloxTask->addSplit(planNode->id(), splitSet.at(planNode->id())); + } + + // Request a barrier immediately after adding the split set + // This acts as a seal for the current split set, ensuring the task + // drains fully. + veloxTask->requestBarrier(); + return true; + } + +Implementation Mechanism +------------------------ + +The implementation relies on a special Barrier Split and a cooperative draining +mechanism within the Driver pipeline. + +The Barrier Split +^^^^^^^^^^^^^^^^^ + +When ``requestBarrier()`` is called: + +* Velox injects a special "Barrier Split" into the split queue of every source + operator in the leaf drivers (the drivers that read raw input splits). +* This split acts as a Sentinel that flows through the pipeline. +* It is excluded from standard task split processing statistics. + +Draining Stateful Operators +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Standard stateless operators (e.g., Filter, Project) process data +vector-by-vector, typically producing one output vector for each input vector +without buffering. They handle the barrier sentinel simply by passing it +downstream once the current input vector is finished. + +However, stateful operators (like Aggregations or Joins) actively buffer data +across multiple input vectors. The barrier mechanism forces them to flush this +buffer via a BarrierState maintained by each Driver. + +The Draining Process: + +* **Trigger**: When a source operator receives the Barrier Split, it enters + "draining mode" by calling the driver's ``drainOutput()`` method. +* **Action**: The driver then propagates this state by calling ``startDrain()`` + on the next operator in the pipeline. +* **Flush**: If that operator is stateful, it flushes all buffered data to the + output. +* **Forward**: The operator then calls ``startDrain()`` on its downstream + operator, ensuring the signal flows operator-by-operator within the driver. + +Diagram: Draining Flow (Single Driver) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The diagram below shows how the signal flows. Notice how the Stateful Operator +must flush its buffer before passing the signal to the Sink. + +.. image:: images/task-barrier-single-driver.png + :alt: Draining Flow (Single Driver) + :width: 100% + +Cross-Pipeline Propagation (The DAG) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +A complex query plan in Velox is executed as a Directed Acyclic Graph (DAG) of +pipelines. The barrier must propagate through this entire graph in topological +order. + +**Mechanism: LocalExchange via Exchange Queue** + +The pipelines in Velox are connected by LocalExchange through an Exchange Queue +mechanism. This structure allows the barrier signal to "jump" across pipeline +boundaries. + +* **Leaf Pipeline Completion**: The barrier processing starts at the leaf + drivers. The signal propagates down the operators until it reaches the Sink + Operator (e.g., LocalPartition, LocalExchangeSink). + +* **The Bridge (Queue)**: When the Sink Operator receives the ``startDrain()`` + signal, it first flushes any buffered data it holds. It then places a special + Barrier Token into the connecting Exchange Queue. This marker sits + immediately after the last data vector generated by the current Split Set. + +* **Downstream Pipeline Activation**: The downstream driver (e.g., the Join + pipeline) reads from this queue. When its Source Operator encounters the + Barrier Token, it interprets this as the signal to initiate the barrier + sequence for its own pipeline. + +* **Synchronization**: This ensures strict ordering. Pipeline B cannot finish + its barrier processing until Pipeline A has fully flushed its data and passed + the barrier token. The Task monitors the completion of all drivers via the + internal ``finishDriverBarrier()`` mechanism; only when the final sink in the + final pipeline has finished draining is the Barrier Future fulfilled. + +Illustration: MergeJoin Example +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Consider a MergeJoin with two upstream pipelines (Left and Right) feeding into a +Join pipeline. + +* **Step 1**: The Task injects Barrier Splits into the TableScan operators of + both the Left and Right pipelines. +* **Step 2**: Both pipelines process their splits, flush their local operators + (e.g., filters), and push data to their respective Exchange Queues. +* **Step 3**: Both pipelines finish by pushing a Barrier Token into their + queues. +* **Step 4**: The Join Pipeline reads from these queues. It processes data until + it hits the Barrier Token on both inputs. +* **Step 5**: The MergeJoin operator receives ``startDrain()``. It performs the + "Drop-Input" optimization (checking if it can stop early) and flushes any + remaining matches. +* **Step 6**: The Join pipeline completes, fulfilling the user's promise. + +.. image:: images/task-barrier-multi-driver.png + :alt: MergeJoin Barrier Propagation + :width: 100% + +TableScan & Source Operator Mechanics +------------------------------------- + +The TableScan (or any Source Operator) is the entry point for the barrier +signal. Its behavior is critical for determining when to "Pause" (Drain) and +when to "Resume." + +Delivery (Entering the Barrier) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Injection**: The ``requestBarrier`` API directly injects a BarrierSplit into + the Task's split queue. This split is distinct from regular data splits (e.g., + FileConnectorSplit). +* **Detection**: The Driver loop responsible for the TableScan fetches the next + split from the queue. When it pops the BarrierSplit, it does not attempt to + read data. +* **Signal Initiation**: Instead of calling getOutput, the driver invokes + ``startDrain()`` on the TableScan operator. This initiates the cascading drain + sequence described in The Draining Process Section. +* **Exclusion**: The TableScan ensures this split does not increment task + progress counters (e.g., "completed splits"), ensuring the barrier is + transparent to execution metrics. + +Resuming (Exit & Next Split) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Idle State**: Once the TableScan has propagated the drain signal and the + downstream pipeline has fully drained, the Driver completes that specific + execution cycle. The Task marks the barrier future as complete. +* **Wait**: The TableScan driver does not terminate; it simply returns to + checking the split queue. If the queue is empty, the driver yields (waits). +* **Re-activation**: When the application calls ``addSplit()`` with new data + splits for the next logical Split Set: + + 1. The splits are pushed to the queue. + 2. The TableScan driver wakes up, pops a valid data split, and resumes + standard data processing (getOutput). + +This seamless transition allows the same Task and Driver instances—and their +associated memory pools and caches—to be reused indefinitely. + +Operator-Specific Implementation Details +---------------------------------------- + +The core of the barrier implementation lies in the ``startDrain()`` virtual +method. While stateless operators simply pass the signal through, stateful +operators must actively manage their buffers to ensure data determinism. + +Index Lookup Join +^^^^^^^^^^^^^^^^^ + +* **Function**: Connects a stream of input rows (probe side) with a remote or + local index (build side), often fetching matches asynchronously. + +**Draining Complexity**: The operator typically buffers input rows to perform +batched lookups or asynchronous prefetches. Simply stopping input isn't enough; +pending lookups inside the execution engine must complete and be emitted. + +**Barrier Logic**: + +* **Flush Prefetch Buffers**: When the barrier signal is received, the operator + forces the flushing of any "in-flight" batches that have been sent for lookup + but haven't yet produced output. +* **Drain Output**: It ensures that ``getOutput()`` continues to return results + until all buffered lookup requests have been resolved and emitted. + +Unnest +^^^^^^ + +* **Function**: Expands complex types (arrays, maps) into multiple rows. + +**Draining Complexity**: An Unnest operator might require multiple calls to +``getOutput()`` to fully process a single large input vector due to output batch +size limits. If a barrier arrives, the operator cannot simply stop; it must +finish expanding the current vector. + +**Barrier Logic**: + +* **Process Remaining Rows**: The ``startDrain()`` logic ensures that the + operator does not accept new input vectors but continues to generate output + until the current input vector is fully unnested. +* **Clean State**: Once the current input vector is exhausted, the operator + transitions to a finished state for that specific barrier cycle, ensuring no + partial arrays are left "half-expanded" across the split boundary. + +Sort Merge Join +^^^^^^^^^^^^^^^ + +* **Function**: Assumes both inputs are sorted and finds matches by traversing + them simultaneously. It is highly stateful and relies on the relative order of + keys. + +**Draining Complexity**: "Over-Reading" in Draining Mode + +In a standard Sort Merge Join, the operator buffers rows from one side while +iterating through the other. When a Barrier acts as a temporary "End of Stream," +the operator might continue reading from one input even after the other input +has finished, to find potential matches. However, because the inputs are sorted, +the operator often knows that no further matches are possible, making this extra +I/O wasteful. + +**Barrier Logic & The "Drop-Input" Optimization** + +This operator implements an optimization to prevent wasteful processing: + +1. **Draining Matches**: When ``startDrain()`` is invoked, the operator forces + the processing of any matches currently possible with the buffered data. + +2. **Early Cutoff (Drop-Input)**: + + * **Logic**: If one side of the join is fully processed/drained (hit the + barrier), the operator compares the max key from that finished side against + the current key of the active side. + * **Decision**: If the active side's key has already passed the finished + side's max key (based on sort order), no future matches are possible. + * **Action**: The operator sets the ``dropInputOpId`` in the BarrierState. + This signals the driver to stop pulling data for the remaining side + immediately. + +This optimization saves significant CPU and I/O resources while still +guaranteeing the same correct result as a full drain. + +Streaming Aggregation +^^^^^^^^^^^^^^^^^^^^^ + +* **Function**: Computes aggregates (SUM, COUNT, etc.) on data already sorted by + grouping keys. + +**Draining Complexity**: The operator buffers the current group's accumulation +in memory. If the split ends exactly at a barrier, that partial accumulation +exists in the operator's state but hasn't been emitted because the operator +hasn't seen a "new" key to trigger the flush. + +**Barrier Logic**: + +* **Force Group Flush**: Upon receiving the barrier signal via ``startDrain()``, + the operator treats the event as a virtual "end of stream" for the current + batch. It forces the emission of the accumulated result for the current group + key, even if the key hasn't changed yet. +* **Reset Accumulators**: After flushing, the accumulators are reset. This + ensures that when the task resumes after the barrier with new splits, it + starts fresh. + +Limitations +----------- + +* **Sequential Execution Mode Only**: The barrier logic currently requires the + task to be running in sequential execution mode (single-threaded logic per + driver). + +* **Restricted Data Sources**: The barrier execution strictly supports only + TableScan source nodes. Other source nodes, such as Values (value source node) + or RemoteExchange (remote exchange node), are not supported. + +* **HashJoin**: Not supported. HashJoin involves complex build-side state + management that is less common in the streaming/AI workloads targeted by this + feature. These workloads typically rely on MergeJoin or IndexJoin to maintain + order and lower memory footprint. diff --git a/velox/docs/develop/testing.rst b/velox/docs/develop/testing.rst index 93f92b50a04..e9dabed5229 100644 --- a/velox/docs/develop/testing.rst +++ b/velox/docs/develop/testing.rst @@ -11,5 +11,6 @@ Testing Tools testing/join-fuzzer testing/memory-arbitration-fuzzer testing/row-number-fuzzer + testing/spatial-join-fuzzer testing/writer-fuzzer testing/spark-query-runner.rst diff --git a/velox/docs/develop/testing/memory-arbitration-fuzzer.rst b/velox/docs/develop/testing/memory-arbitration-fuzzer.rst index b4a15c89245..005b80d4ddb 100644 --- a/velox/docs/develop/testing/memory-arbitration-fuzzer.rst +++ b/velox/docs/develop/testing/memory-arbitration-fuzzer.rst @@ -9,8 +9,8 @@ It works as follows: 1. Data Generation: It starts by generating a random set of input data, also known as a vector. This data can have a variety of encodings and data layouts to ensure thorough testing. -2. Plan Generation: Generate multiple plans with different query shapes. Currently, it supports HashJoin and - HashAggregation plans. +2. Plan Generation: Generate multiple plans with different query shapes. Currently, it supports HashJoin, + HashAggregation, RowNumber, TopNRowNumber, and OrderBy plans. 3. Query Execution: Create multiple threads, each thread randomly picks a plan with spill enabled or not, and repeatedly running this process until ${iteration_duration_sec} seconds. The query thread expects query to succeed or fail with query OOM or abort errors, otherwise it throws. diff --git a/velox/docs/develop/testing/row-number-fuzzer.rst b/velox/docs/develop/testing/row-number-fuzzer.rst index f90b33f7621..769acfe2269 100644 --- a/velox/docs/develop/testing/row-number-fuzzer.rst +++ b/velox/docs/develop/testing/row-number-fuzzer.rst @@ -1,30 +1,51 @@ -================ -RowNumber Fuzzer -================ +================================== +RowNumber and TopNRowNumber Fuzzer +================================== -The RowNumberFuzzer is a testing tool that automatically generate equivalent query plans and then executes these plans -to validate the consistency of the results. It works as follows: +The RowNumberFuzzer and TopNRowNumberFuzzer are testing tools that automatically generate equivalent query plans that +use the RowNumber and TopNRowNumber Velox plan nodes, and then execute these plans to validate the consistency of +the results. They works as follows: -1. Data Generation: It starts by generating a random set of input data, also known as a vector. This data can +1. Data Generation: Generate a random set of input data, also known as a vector. This data can have a variety of encodings and data layouts to ensure thorough testing. -2. Plan Generation: Generate two equivalent query plans, one is row-number over ValuesNode as the base plan. - and the other is over TableScanNode as the alter plan. +2. Plan Generation: Generate equivalent query plans and validate results across all of them. + + For RowNumberFuzzer: + + * Base plan: RowNumber over ValuesNode. + * Alternative plan: RowNumber over TableScanNode. + + For TopNRowNumberFuzzer: + + * Base plan: TopNRowNumber over ValuesNode using a randomly chosen rank function + (``row_number``, ``rank``, or ``dense_rank``) with a random row limit. + * Alternative plan: WindowNode over ValuesNode using the same rank function and partition/sort + keys, followed by a filter on the rank value to apply the same limit. This validates that the + optimised TopNRowNumber operator produces results consistent with the general Window operator. + * Alternative plan: TopNRowNumber over TableScanNode using the same rank function and limit. + 3. Query Execution: Executes those equivalent query plans using the generated data and asserts that the results are consistent across different plans. + i. Execute the base plan, compare the result with the reference (DuckDB or Presto) and use it as the expected result. - #. Execute the alter plan multiple times with and without spill, and compare each result with the + #. Execute the alternative plans multiple times with and without spill, and compare each result with the expected result. + 4. Iteration: This process is repeated multiple times to ensure reliability and robustness. How to run ---------- -Use velox_row_number_fuzzer binary to run rowNumber fuzzer: - +Use velox_row_number_fuzzer to run RowNumberFuzzer :: velox/exec/fuzzer/velox_row_number_fuzzer --seed 123 --duration_sec 60 +Similarly, use velox_topn_row_number_fuzzer to run TopNRowNumberFuzzer +:: + + velox/exec/fuzzer/velox_topn_row_number_fuzzer --seed 123 --duration_sec 60 + By default, the fuzzer will go through 10 iterations. Use --steps or --duration-sec flag to run fuzzer for longer. Use --seed to reproduce fuzzer failures. diff --git a/velox/docs/develop/testing/spatial-join-fuzzer.rst b/velox/docs/develop/testing/spatial-join-fuzzer.rst new file mode 100644 index 00000000000..519ab359093 --- /dev/null +++ b/velox/docs/develop/testing/spatial-join-fuzzer.rst @@ -0,0 +1,118 @@ +==================== +Spatial Join Fuzzer +==================== + +Overview +======== + +The Spatial Join Fuzzer tests the correctness of the SpatialJoin operator by generating random geometry data and spatial join plans. It verifies that SpatialJoin produces the same results as NestedLoopJoin for equivalent queries. + + +Supported Features +================== + +Join Types +---------- + +The fuzzer tests the two join types supported by SpatialJoin (as defined in ``SpatialJoinNode::isSupported()``): + +* **INNER** - Only matching rows from both sides +* **LEFT** - All rows from left side, matched rows from right side + +Spatial Predicates +------------------ + +The fuzzer tests these spatial predicates: + +* ``ST_Intersects(geometry1, geometry2)`` - Tests if geometries intersect +* ``ST_Contains(geometry1, geometry2)`` - Tests if one geometry contains another +* ``ST_Within(geometry1, geometry2)`` - Tests if one geometry is within another +* ``ST_Distance(geometry1, geometry2) < threshold`` - Tests distance with threshold + +Geometry Types +-------------- + +The fuzzer generates Well-Known Text (WKT) strings for three geometry types: + +* **POINT** - Single coordinate point (e.g., ``POINT (10.5 20.3)``) +* **POLYGON** - Closed shape with vertices +* **LINESTRING** - Line segment between two points + +Distribution Patterns +--------------------- + +Geometries are generated using three distribution patterns: + +* **Uniform** - Geometries uniformly distributed in space (0-1000 range) +* **Clustered** - Geometries grouped in 5 specific regions to test overlap scenarios +* **Sparse** - Geometries widely spread (0-2000 range) with low overlap probability + +Implementation Details +====================== + + +Geometry Generation +------------------- + +Geometries are generated using ``AbstractInputGenerator`` subclasses: + +* ``PointInputGenerator`` - Generates POINT WKT strings +* ``PolygonInputGenerator`` - Generates POLYGON WKT strings +* ``LineStringInputGenerator`` - Generates LINESTRING WKT strings + +Each generator implements the ``generate(vector_size_t index)`` method to produce geometry strings based on the distribution pattern. + +**Uniform Distribution**:: + + x = random(0, 1000) + y = random(0, 1000) + POINT (x y) + +**Clustered Distribution**:: + + cluster = row % 5 // 5 clusters + centerX = cluster * 200 + 100 + centerY = cluster * 200 + 100 + x = centerX + random(-50, 50) + y = centerY + random(-50, 50) + POINT (x y) + +**Sparse Distribution**:: + + x = random(0, 2000) // Larger Range + y = random(0, 2000) + POINT (x y) + +Data Matching Strategy +---------------------- + +To ensure some matches occur during joins: + +* Build side copies ~30% of geometries from probe side +* 10% chance of empty build side to test edge cases + +Verification +------------ + +The fuzzer compares results from two equivalent plans: + +1. **SpatialJoin plan** - Using the specialized SpatialJoin operator +2. **NestedLoopJoin plan** - Using NestedLoopJoin with the same spatial predicate as a filter + +Results must match exactly, validating that SpatialJoin implements spatial predicates correctly. + +Key Differences from JoinFuzzer +================================ + +Join Conditions +--------------- + +Unlike regular joins with simple equality predicates:: + + // Regular join + probe.id = build.id + + // Spatial join + ST_Intersects(probe_geom, build_geom) + +Spatial joins use **function call expressions** as join conditions rather than simple column references. diff --git a/velox/docs/develop/timestamp.rst b/velox/docs/develop/timestamp.rst index 7d745cc21a9..139fd53fa93 100644 --- a/velox/docs/develop/timestamp.rst +++ b/velox/docs/develop/timestamp.rst @@ -139,6 +139,16 @@ generally more efficient, but std::chrono does not handle time zone offsets such as ``+09:00``. Timezone offsets are only supported in the API version that takes a timezone ID. +Timezone Database Lookup +------------------------ + +Velox uses the IANA Time Zone Database to handle timezone conversions. It looks for the database in the following order: + +1. The directory specified by the ``TZDIR`` environment variable. +2. ``/usr/share/zoneinfo/uclibc`` (on Linux). +3. ``/usr/share/zoneinfo`` (on Linux). +4. The path pointed to by ``/etc/localtime`` (on macOS). + Casts ----- diff --git a/velox/docs/develop/types.rst b/velox/docs/develop/types.rst index 98106d931bb..c948dbf136b 100644 --- a/velox/docs/develop/types.rst +++ b/velox/docs/develop/types.rst @@ -115,6 +115,9 @@ DATE INTEGER DECIMAL BIGINT if precision <= 18, HUGEINT if precision >= 19 INTERVAL DAY TO SECOND BIGINT INTERVAL YEAR TO MONTH INTEGER +TIME BIGINT +TIME_MICRO_UTC BIGINT +TIMESTAMP_UTC TIMESTAMP ====================== ====================================================== DECIMAL type carries additional `precision`, @@ -130,6 +133,12 @@ upto 38 precision, with a range of :math:`[-10^{38} + 1, +10^{38} - 1]`. All the three values, precision, scale, unscaled value are required to represent a decimal value. +TIME represents time in milliseconds since midnight, subject to session timezone interpretation. Thus min/max value can range from 0 to 23:59:59.999. +TIME_MICRO_UTC represents time in microseconds since midnight in UTC, not subject to session timezone adjustment. Thus min/max value can range from 00:00:00.000000 to 23:59:59.999999. +The TIME and TIME_MICRO_UTC types are backed by BIGINT physical type. + +TIMESTAMP represents a timestamp subject to session timezone interpretation. TIMESTAMP_UTC represents a timestamp in UTC, not subject to session timezone adjustment. Both types are backed by TIMESTAMP physical type. + Custom Types ~~~~~~~~~~~~ Most custom types can be represented as logical types and can be built by extending @@ -168,18 +177,35 @@ The table below shows the supported Presto types. Presto Type Physical Type ======================== ===================== HYPERLOGLOG VARBINARY +KHYPERLOGLOG VARBINARY +P4HYPERLOGLOG VARBINARY JSON VARCHAR TIMESTAMP WITH TIME ZONE BIGINT UUID HUGEINT IPADDRESS HUGEINT IPPREFIX ROW(HUGEINT,TINYINT) +BINGTILE BIGINT GEOMETRY VARBINARY +SPHERICALGEOGRAPHY VARBINARY +SETDIGEST VARBINARY TDIGEST VARBINARY QDIGEST VARBINARY BIGINT_ENUM BIGINT VARCHAR_ENUM VARCHAR +TIME WITH TIME ZONE BIGINT ======================== ===================== +KHYPERLOGLOG is a data sketch for estimating reidentifiability and joinability within a dataset. +Based on the `KHyperLogLog paper `_, +it maintains a map of K number of HyperLogLog structures, where each entry corresponds to a unique key from one column, +and the HLL estimates the cardinality of the associated unique identifiers from another column. +For storage and retrieval it may be cast to/from VARBINARY. + +P4HYPERLOGLOG is a data sketch for cardinality estimation that uses only the dense HyperLogLog +representation. Unlike standard HYPERLOGLOG which supports both sparse and dense formats, +P4HYPERLOGLOG always uses dense format. It may be cast to/from HYPERLOGLOG and to/from VARBINARY +for storage and retrieval. + TIMESTAMP WITH TIME ZONE represents a time point in milliseconds precision from UNIX epoch with timezone information. Its physical type is BIGINT. The high 52 bits of bigint store signed integer for milliseconds in UTC. @@ -199,7 +225,10 @@ IPPREFIX networks. IPPREFIX represents an IPv6 or IPv4 formatted IPv6 address along with a one byte prefix length. Its physical type is ROW(HUGEINT, TINYINT). The IPADDRESS is stored in the HUGEINT and is in the form defined in `RFC 4291#section-2.5.5.2 `_. -The prefix length is stored in the TINYINT. +The prefix length is stored in the TINYINT. Note that IPv6 prefix lengths go up +to 128, which overflows TINYINT (int8_t, max 127). Prefix length 128 is stored +as -128. Code that reads the prefix length must cast to uint8_t to recover the +correct unsigned value. The IP address stored is the canonical(smallest) IP address in the subnet range. This type can be used in IP subnet functions. @@ -213,6 +242,10 @@ As a result the IPPREFIX object stores *FFFF:FFFF::* and the length 32 for both IPPREFIX 'FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF/32' -- IPPREFIX 'FFFF:FFFF:0000:0000:0000:0000:0000:0000/32' IPPREFIX 'FFFF:FFFF:4455:6677:8899:AABB:CCDD:EEFF/32' -- IPPREFIX 'FFFF:FFFF:0000:0000:0000:0000:0000:0000/32' +SETDIGEST is a data sketch for estimating set cardinality and performing set operations +like intersection cardinality and Jaccard index. It combines HyperLogLog with MinHash. +SetDigests may be merged, and for storage and retrieval they may be cast to/from VARBINARY. + TDIGEST(DOUBLE) is a data sketch for estimating rank-based metrics. T-digests may be merged without losing precision, and for storage and retrieval they may be cast to/from VARBINARY. The T-digest accepts a parameter of type @@ -243,6 +276,23 @@ VarcharEnumParameter as the key. Casting is only permitted to and from VARCHAR type, and is case-sensitive. Casting between different enum types is not permitted. Comparison operations are only allowed between values of the same enum type. +TIME WITH TIME ZONE represents time from midnight in milliseconds precision at a particular timezone. +Its physical type is BIGINT. The high 52 bits of bigint store signed integer for milliseconds in UTC. +The lower 12 bits store the time zone offsets minutes. This allows the time to be converted at any point of +time without ambiguity of daylight savings time. Time zone offsets range from -14:00 hours to +14:00 hours. + +BINGTILE represents a `Bing tile `_. +It is a quadtree in the Web Mercator projection, where each tile is 256x256 pixels. Its physical type is BIGINT. + +GEOMETRY represents a geometry as defined in `Simple Feature Access `_. +Subtypes include Point, MultiPoint, LineString, MultiLineString, Polygon, MultiPolygon, and GeometryCollection. They +are often stored as `Well-Known Text `_ or +`Well-Known Binary `_. + +SPHERICALGEOGRAPHY represents a geometry on a spherical model of the Earth. It is internally represented the same +way as GEOMETRY, but only certain functions are supported. Moreover, these functions will return values in meters +as opposed to the units of the coordinate space. + Spark Types ~~~~~~~~~~~~ The `data types `_ in Spark have some semantic differences compared to those in @@ -262,6 +312,14 @@ key differences are listed below. ) AS t(ts); -- 2014-03-08 09:00:00.012345 +* Spark operates on the TIME_MICRO_UTC type for "microsecond" precision and timezone unawareness, + while Presto uses the standard TIME type. + Example:: + + SELECT cast('12:30:45.123456' as time) -- 12:30:45.123456 + +* Spark uses TIMESTAMP_UTC to support TimestampNTZType. TIMESTAMP_UTC is not subject to session timezone adjustment. + * In function comparisons, nested null values are handled as values. Example:: @@ -279,3 +337,233 @@ key differences are listed below. also not orderable, but it is comparable if both key and value types are comparable. The implication is that MAP type cannot be used as a join, group by or order by key in Spark. + +Type Coercion +~~~~~~~~~~~~~ +Type coercion is the implicit conversion of a value from one type to +another during query planning. It resolves function overloads and +special-form result types when arguments don't match a signature exactly. + +Coercion is a planning-time concern: by the time a Velox ``Task`` is +constructed, every implicit conversion is already a materialized ``Cast`` +node in the typed expression tree, and runtime evaluators do not consult +the coercer. + +Coercion rules live in ``TypeCoercer`` (``velox/type/TypeCoercer.h``). +``TypeCoercer`` is value-typed and immutable after construction. Velox ships +a default instance (``TypeCoercer::defaults()``) holding a conservative +built-in rule set used when no dialect coercer is provided. SQL dialects +ship their own complete instances -- for example, +``velox::functions::prestosql::typeCoercer()`` for the Presto dialect -- that +match the dialect's overload-resolution semantics. + +``TypeCoercer`` itself is frontend-agnostic: it's a plain value type that +the resolver APIs (``SignatureBinder``, +``resolveFunction*WithCoercions``, the special-form ``resolveTypeInt`` +helpers) accept as a defaulted tail parameter so existing callers +compile unchanged and use ``TypeCoercer::defaults()``. How a coercer +reaches the resolver APIs is the frontend's choice. Axiom, for example, +threads the dialect coercer through ``logical_plan::PlanBuilder``'s +``Context.coercer`` field, which the SQL parser sets when constructing +the plan builder; other frontends may pass a coercer directly into the +resolver APIs or wire it through their own planning context. + +Customization scope +^^^^^^^^^^^^^^^^^^^ + +What a dialect's ``TypeCoercer`` rule set controls: + +* **Primitives** (TINYINT, SMALLINT, INTEGER, BIGINT, REAL, DOUBLE, BOOLEAN, + VARCHAR, VARBINARY, DATE, TIMESTAMP, UNKNOWN): full control over which + source/target pairs are allowed and at what cost. Lower cost is preferred + during overload resolution. +* **DECIMAL**: customizable for source and target separately; DECIMAL -> + DECIMAL is not customizable. See the DECIMAL section below. +* **Container types** (ARRAY, MAP, ROW) and FUNCTION/OPAQUE: not + customizable directly. Coercibility is structural -- names and arities + must match, and children are recursed element-wise. A dialect controls + container behavior only indirectly via element-type rules. +* **Custom types** (e.g. JSON, TIMESTAMP WITH TIME ZONE, BINGTILE): not + customizable through a dialect's ``TypeCoercer`` rule set. Custom-type + coercions are registered via ``registerCastRules`` alongside + ``registerCustomType`` and live in the global ``CastRulesRegistry``, + which is shared across dialects. To keep callers from having to query + both registries, ``TypeCoercer::coerceTypeBase`` consults + ``CastRulesRegistry`` as a fallback after its own rule lookup -- so + custom-type coercions remain reachable through any ``TypeCoercer`` + instance even though the dialect can't override them. + +Cost magnitudes +^^^^^^^^^^^^^^^ + +Overload resolution sums per-argument coercion costs +(``Coercion::overallCost``) to compare candidate signatures. For sums to +be meaningful, every ``CoercionEntry.cost`` in a single ``TypeCoercer`` +instance must be in the same small magnitude -- today's defaults use +costs 1-9, one per source-type series. There is no hardcoded surcharge +added at lookup time: the dialect's rule cost is returned verbatim. + +DECIMAL +^^^^^^^ + +DECIMAL handling depends on which side is DECIMAL. Rule keys collapse on +the name ``DECIMAL`` regardless of (p, s) -- one rule per +``(sourceName, targetName)`` pair on each side. + +**Source DECIMAL** (e.g. ``DECIMAL(p, s) -> DOUBLE``). A dialect +registers one rule per ``(DECIMAL, target)`` where ``target`` is a +non-DECIMAL type (DECIMAL -> DECIMAL is not customizable; see below). +The source must be the canonical placeholder ``DECIMAL(1, 0)``; the rule +fires for any actual ``DECIMAL(p, s)`` source because the source's +precision/scale is not part of the lookup key. The rule resolves +directly to the target type at the rule's cost. + +**Target DECIMAL** (e.g. ``INTEGER -> DECIMAL(p, s)``). A dialect +registers one rule per ``(source, DECIMAL)``. The rule's stored target +is the minimum-width decimal that holds every value of the source +(e.g. ``INTEGER -> DECIMAL(10, 0)``). At lookup, the type system extends +the rule's fixed target to the caller's requested ``DECIMAL(p, s)`` via +``ShortDecimalType::isCoercibleTo`` / ``LongDecimalType::isCoercibleTo`` +(see widening rule below). Returns ``nullopt`` if the caller's target is +too narrow. + +Widening itself contributes 0 to the cost; the rule's stored cost is +returned verbatim. ``INT -> DECIMAL(10, 0)`` and ``INT -> DECIMAL(38, +18)`` both cost whatever the rule says (e.g. cost 2 in INTEGER's series). +This is a simple choice that works for current overload-resolution +cases; it may need to be revisited if a function is registered with +multiple concrete DECIMAL-target signatures, since they would all coerce +at the same cost and produce ambiguous resolution. + +**DECIMAL -> DECIMAL** (e.g. ``DECIMAL(10, 2) -> DECIMAL(20, 4)``) is +**not customizable** by dialects. ``TypeCoercer`` rejects rule entries +with both source and target DECIMAL at construction time, so attempting +to register one fails fast. Dialects that need non-standard +DECIMAL -> DECIMAL semantics must extend the type system, not +``TypeCoercer``. + +At lookup time, ``coerceTypeBase`` short-circuits for any two DECIMALs +regardless of (p, s) and returns ``Coercion{type: from, cost: 0}``. +Precision/scale reconciliation is therefore not done by +``coerceTypeBase`` itself; instead it happens via DECIMAL-specific paths +in two places: + +* ``LongDecimalType::commonSuperType`` inside ``leastCommonSuperType`` + computes the common ``(p, s)`` for plan-level operations (UNION, CASE + result type, etc.). +* ``SignatureBinder``'s integer-parameter binding handles function + signatures of the form ``DECIMAL(P, S)`` by binding ``P`` and ``S`` as + integer variables from the actual argument types. + +*DECIMAL widening rule.* ``DECIMAL(p1, s1)`` is coercible to +``DECIMAL(p2, s2)`` iff: + +* ``p1 - s1 <= p2 - s2`` -- the target has at least as many integer + digits as the source, and +* ``s1 <= s2`` -- the target has at least as much scale. + +Both conditions must hold; otherwise the widening fails. This rule is +used by Target DECIMAL above (extending a fixed-width target to a wider +caller-requested DECIMAL) and by ``LongDecimalType::commonSuperType`` +for DECIMAL -> DECIMAL reconciliation. + +``coerceTypeBase`` vs ``coercible`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +``coerceTypeBase(from, to)`` is a single-rule lookup. It does NOT recurse +into container children -- ``coerceTypeBase(ARRAY, ARRAY)`` +returns ``nullopt`` because no flat rule keys on +``("ARRAY", "ARRAY")``. ``coercible(from, to)`` is the structural +predicate: for primitives it delegates to ``coerceTypeBase``; for +containers it requires matching names and arities, then sums child +coercion costs. + +Default coercion rules +^^^^^^^^^^^^^^^^^^^^^^ + +``TypeCoercer::defaults()`` ships the rules below. In this and the +Presto-specific table that follows, allowed targets are listed in cost +order (cheapest first), and for DECIMAL targets the listed type is the +minimum-width decimal that holds every value of the source (lookup +widens to a wider DECIMAL when compatible -- see the DECIMAL section +above). + +============== ========================================================== +Source Allowed targets (in cost order, cheapest first) +============== ========================================================== +TINYINT SMALLINT, INTEGER, BIGINT, DECIMAL(3, 0), REAL, DOUBLE +SMALLINT INTEGER, BIGINT, DECIMAL(5, 0), REAL, DOUBLE +INTEGER BIGINT, DECIMAL(10, 0), REAL, DOUBLE +BIGINT DECIMAL(19, 0), DOUBLE +REAL DOUBLE +DECIMAL REAL, DOUBLE +DATE TIMESTAMP +UNKNOWN TINYINT, BOOLEAN, SMALLINT, INTEGER, BIGINT, REAL, DOUBLE, + VARCHAR, VARBINARY +============== ========================================================== + +Notable absences: + +* ``BIGINT -> REAL`` is not in the defaults (BIGINT has 64-bit integer + precision; REAL holds only ~7 decimal digits). Presto allows it; the + Presto dialect coercer adds it (see below). +* No reverse conversions (e.g. ``DOUBLE -> REAL``, ``BIGINT -> INTEGER``). +* No string conversions (e.g. ``INTEGER -> VARCHAR``). +* No conversions between unrelated families (e.g. ``BOOLEAN -> INTEGER``, + ``DATE -> BIGINT``). + +Presto-specific coercion rules +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +``velox::functions::prestosql::typeCoercer()`` (in +``velox/functions/prestosql/coercion/PrestoCoercions.{h,cpp}``) ships a +complete rule set independent from ``TypeCoercer::defaults()`` so that +dialect-specific changes don't silently shift when Velox defaults change. +Same table conventions as above. + +============== ========================================================== +Source Allowed targets (in cost order, cheapest first) +============== ========================================================== +TINYINT SMALLINT, INTEGER, BIGINT, DECIMAL(3, 0), REAL, DOUBLE +SMALLINT INTEGER, BIGINT, DECIMAL(5, 0), REAL, DOUBLE +INTEGER BIGINT, DECIMAL(10, 0), REAL, DOUBLE +BIGINT DECIMAL(19, 0), REAL, DOUBLE +REAL DOUBLE +DECIMAL REAL, DOUBLE +DATE TIMESTAMP +UNKNOWN TINYINT, BOOLEAN, SMALLINT, INTEGER, BIGINT, REAL, DOUBLE, + VARCHAR, VARBINARY +============== ========================================================== + +Differences from Velox's defaults: + +* ``BIGINT -> REAL`` is added. The BIGINT row becomes + ``DECIMAL(19, 0), REAL, DOUBLE``, mirroring the INTEGER row's + ordering. This makes ``divide(real, bigint)`` resolve to + ``divide(real, real)`` via ``BIGINT -> REAL`` (cost 2) instead of + ``divide(double, double)`` via ``REAL -> DOUBLE + BIGINT -> DOUBLE`` + (cost 1 + 3 = 4), matching Presto's overload resolution. + +Presto also allows the following implicit coercions: + +============================== ========================================== +Source Target +============================== ========================================== +TIMESTAMP TIMESTAMP WITH TIME ZONE +DATE TIMESTAMP WITH TIME ZONE +TIME TIME WITH TIME ZONE +============================== ========================================== + +These are not part of ``presto::typeCoercer()``'s rule set. From +Velox's perspective the targets above are *custom* types +(``TIMESTAMP WITH TIME ZONE``, ``TIME WITH TIME ZONE``) registered via +``registerCustomType``, so their coercion rules live in the global +``CastRulesRegistry`` (registered with ``implicitAllowed = true`` +alongside the type) rather than in the dialect's ``TypeCoercer``. +Lookup still finds them because ``coerceTypeBase`` consults +``CastRulesRegistry`` as a fallback. + +Casts to other Presto types backed by Velox custom types (JSON, +BINGTILE, IPADDRESS, IPPREFIX, UUID, BIGINT_ENUM, VARCHAR_ENUM, +P4HYPERLOGLOG) are explicit-only -- see +:doc:`/functions/presto/conversion`. diff --git a/velox/docs/ext/delta.py b/velox/docs/ext/delta.py new file mode 100644 index 00000000000..40a76426a2e --- /dev/null +++ b/velox/docs/ext/delta.py @@ -0,0 +1,773 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generating delta function link :delta:func:``""" + +from __future__ import annotations + +from function import ( + function_sig_re, + pairindextypes, + parse_arglist, + pseudo_parse_arglist, + parse_annotation, + ObjectEntry, + ModuleEntry, +) +from typing import Any, Iterable, Iterator, Tuple, cast + +from docutils import nodes +from docutils.nodes import Element, Node +from docutils.parsers.rst import directives +from sphinx import addnodes +from sphinx.addnodes import desc_signature, pending_xref +from sphinx.application import Sphinx +from sphinx.builders import Builder +from sphinx.directives import ObjectDescription +from sphinx.domains import Domain, Index, IndexEntry, ObjType +from sphinx.environment import BuildEnvironment +from sphinx.locale import _, __ +from sphinx.roles import XRefRole +from sphinx.util import logging +from sphinx.util.docfields import Field +from sphinx.util.nodes import ( + find_pending_xref_condition, + make_id, + make_refnode, +) +from sphinx.util.typing import OptionSpec + +logger = logging.getLogger(__name__) + +function_module = "delta" + + +class DeltaObject(ObjectDescription[Tuple[str, str]]): + """ + Description of a general Delta object. + + :cvar allow_nesting: Class is an object that allows for nested namespaces + :vartype allow_nesting: bool + """ + + option_spec: OptionSpec = { + "noindex": directives.flag, + "noindexentry": directives.flag, + "nocontentsentry": directives.flag, + "module": directives.unchanged, + "canonical": directives.unchanged, + "annotation": directives.unchanged, + } + + doc_field_types = [ + Field( + "returnvalue", + label=_("Returns"), + has_arg=False, + names=("returns", "return"), + ), + ] + + allow_nesting = False + + def get_signature_prefix(self, sig: str) -> list[nodes.Node]: + """May return a prefix to put before the object name in the + signature. + """ + return [] + + def needs_arglist(self) -> bool: + """May return true if an empty argument list is to be generated even if + the document contains none. + """ + return False + + def handle_signature(self, sig: str, signode: desc_signature) -> tuple[str, str]: + """Transform a Delta signature into RST nodes. + Return (fully qualified name of the thing, classname if any). + If inside a class, the current class name is handled intelligently: + * it is stripped from the displayed name if present + * it is added to the full name (return value) if not present + """ + m = function_sig_re.match(sig) + if m is None: + raise ValueError + prefix, name, arglist, retann = m.groups() + + # determine module and class name (if applicable), as well as full name + modname = self.options.get("module", self.env.ref_context.get("delta:module")) + classname = self.env.ref_context.get("delta:class") + if classname: + add_module = False + if prefix and (prefix == classname or prefix.startswith(classname + ".")): + fullname = prefix + name + # class name is given again in the signature + prefix = prefix[len(classname) :].lstrip(".") + elif prefix: + # class name is given in the signature, but different + # (shouldn't happen) + fullname = classname + "." + prefix + name + else: + # class name is not given in the signature + fullname = classname + "." + name + else: + add_module = True + if prefix: + classname = prefix.rstrip(".") + fullname = prefix + name + else: + classname = "" + fullname = name + + signode["module"] = modname + signode["class"] = classname + signode["fullname"] = fullname + + sig_prefix = self.get_signature_prefix(sig) + if sig_prefix: + if type(sig_prefix) is str: + raise TypeError( + "Python directive method get_signature_prefix()" + " must return a list of nodes." + f" Return value was '{sig_prefix}'." + ) + else: + signode += addnodes.desc_annotation(str(sig_prefix), "", *sig_prefix) + + if prefix: + signode += addnodes.desc_addname(prefix, prefix) + elif modname and add_module and self.env.config.add_module_names: + nodetext = modname + "." + signode += addnodes.desc_addname(nodetext, nodetext) + + signode += addnodes.desc_name(name, name) + if arglist: + try: + signode += parse_arglist(function_module, arglist, self.env) + except SyntaxError: + # fallback to parse arglist original parser. + # it supports to represent optional arguments (ex. "func(foo [, bar])") + pseudo_parse_arglist(signode, arglist) + except NotImplementedError as exc: + logger.warning( + "could not parse arglist (%r): %s", arglist, exc, location=signode + ) + pseudo_parse_arglist(signode, arglist) + else: + if self.needs_arglist(): + # for callables, add an empty parameter list + signode += addnodes.desc_parameterlist() + + if retann: + children = parse_annotation(function_module, retann, self.env) + signode += addnodes.desc_returns(retann, "", *children) + + anno = self.options.get("annotation") + if anno: + signode += addnodes.desc_annotation( + " " + anno, "", addnodes.desc_sig_space(), nodes.Text(anno) + ) + + return fullname, prefix + + def _object_hierarchy_parts(self, sig_node: desc_signature) -> tuple[str, ...]: + if "fullname" not in sig_node: + return () + modname = sig_node.get("module") + fullname = sig_node["fullname"] + + if modname: + return (modname, *fullname.split(".")) + else: + return tuple(fullname.split(".")) + + def get_index_text(self, modname: str, name: tuple[str, str]) -> str: + """Return the text for the index entry of the object.""" + raise NotImplementedError("must be implemented in subclasses") + + def add_target_and_index( + self, name_cls: tuple[str, str], sig: str, signode: desc_signature + ) -> None: + modname = self.options.get("module", self.env.ref_context.get("delta:module")) + fullname = (modname + "." if modname else "") + name_cls[0] + node_id = make_id(self.env, self.state.document, "", fullname) + signode["ids"].append(node_id) + self.state.document.note_explicit_target(signode) + + domain = cast(DeltaDomain, self.env.get_domain("delta")) + domain.note_object(fullname, self.objtype, node_id, location=signode) + + canonical_name = self.options.get("canonical") + if canonical_name: + domain.note_object( + canonical_name, self.objtype, node_id, aliased=True, location=signode + ) + + if "noindexentry" not in self.options: + indextext = self.get_index_text(modname, name_cls) + if indextext: + self.indexnode["entries"].append( + ("single", indextext, node_id, "", None) + ) + + def before_content(self) -> None: + """Handle object nesting before content + + For constructs that aren't nestable, the stack is bypassed, and instead + only the most recent object is tracked. This object prefix name will be + removed with :delta:meth:`after_content`. + """ + prefix = None + if self.names: + # fullname and name_prefix come from the `handle_signature` method. + # fullname represents the full object name that is constructed using + # object nesting and explicit prefixes. `name_prefix` is the + # explicit prefix given in a signature + (fullname, name_prefix) = self.names[-1] + if self.allow_nesting: + prefix = fullname + elif name_prefix: + prefix = name_prefix.strip(".") + if prefix: + self.env.ref_context["delta:class"] = prefix + if self.allow_nesting: + classes = self.env.ref_context.setdefault("delta:classes", []) + classes.append(prefix) + if "module" in self.options: + modules = self.env.ref_context.setdefault("delta:modules", []) + modules.append(self.env.ref_context.get("delta:module")) + self.env.ref_context["delta:module"] = self.options["module"] + + def after_content(self) -> None: + """Handle object de-nesting after content + + If this class is a nestable object, removing the last nested class prefix + ends further nesting in the object. + + If this class is not a nestable object, the list of classes should not + be altered as we didn't affect the nesting levels in + :delta:meth:`before_content`. + """ + classes = self.env.ref_context.setdefault("delta:classes", []) + if self.allow_nesting: + try: + classes.pop() + except IndexError: + pass + self.env.ref_context["delta:class"] = classes[-1] if len(classes) > 0 else None + if "module" in self.options: + modules = self.env.ref_context.setdefault("delta:modules", []) + if modules: + self.env.ref_context["delta:module"] = modules.pop() + else: + self.env.ref_context.pop("delta:module") + + def _toc_entry_name(self, sig_node: desc_signature) -> str: + if not sig_node.get("_toc_parts"): + return "" + + config = self.env.app.config + objtype = sig_node.parent.get("objtype") + if config.add_function_parentheses and objtype in {"function", "method"}: + parens = "()" + else: + parens = "" + *parents, name = sig_node["_toc_parts"] + if config.toc_object_entries_show_parents == "domain": + return sig_node.get("fullname", name) + parens + if config.toc_object_entries_show_parents == "hide": + return name + parens + if config.toc_object_entries_show_parents == "all": + return ".".join(parents + [name + parens]) + return "" + + +class DeltaFunction(DeltaObject): + """Description of a function.""" + + option_spec: OptionSpec = DeltaObject.option_spec.copy() + option_spec.update( + { + "async": directives.flag, + } + ) + + def get_signature_prefix(self, sig: str) -> list[nodes.Node]: + if "async" in self.options: + return [addnodes.desc_sig_keyword("", "async"), addnodes.desc_sig_space()] + else: + return [] + + def needs_arglist(self) -> bool: + return True + + def add_target_and_index( + self, name_cls: tuple[str, str], sig: str, signode: desc_signature + ) -> None: + super().add_target_and_index(name_cls, sig, signode) + if "noindexentry" not in self.options: + modname = self.options.get( + "module", self.env.ref_context.get("delta:module") + ) + node_id = signode["ids"][0] + + name, cls = name_cls + if modname: + text = _("%s() (in module %s)") % (name, modname) + self.indexnode["entries"].append(("single", text, node_id, "", None)) + else: + text = f"{pairindextypes['builtin']}; {name}()" + self.indexnode["entries"].append(("pair", text, node_id, "", None)) + + def get_index_text(self, modname: str, name_cls: tuple[str, str]) -> str | None: + # add index in own add_target_and_index() instead. + return None + + +class DeltaXRefRole(XRefRole): + def process_link( + self, + env: BuildEnvironment, + refnode: Element, + has_explicit_title: bool, + title: str, + target: str, + ) -> tuple[str, str]: + refnode["delta:module"] = env.ref_context.get("delta:module") + refnode["delta:class"] = env.ref_context.get("delta:class") + if not has_explicit_title: + title = title.lstrip(".") # only has a meaning for the target + target = target.lstrip("~") # only has a meaning for the title + # if the first character is a tilde, don't display the module/class + # parts of the contents + if title[0:1] == "~": + title = title[1:] + dot = title.rfind(".") + if dot != -1: + title = title[dot + 1 :] + # if the first character is a dot, search more specific namespaces first + # else search builtins first + if target[0:1] == ".": + target = target[1:] + refnode["refspecific"] = True + return title, target + + +class DeltaModuleIndex(Index): + """ + Index subclass to provide the Delta module index. + """ + + name = "modindex" + localname = _("Delta Module Index") + shortname = _("modules") + + def generate( + self, docnames: Iterable[str] | None = None + ) -> tuple[list[tuple[str, list[IndexEntry]]], bool]: + content: dict[str, list[IndexEntry]] = {} + # list of prefixes to ignore + ignores: list[str] = self.domain.env.config["modindex_common_prefix"] + ignores = sorted(ignores, key=len, reverse=True) + # list of all modules, sorted by module name + modules = sorted( + self.domain.data["modules"].items(), key=lambda x: x[0].lower() + ) + # sort out collapsible modules + prev_modname = "" + num_toplevels = 0 + for modname, (docname, node_id, synopsis, platforms, deprecated) in modules: + if docnames and docname not in docnames: + continue + + for ignore in ignores: + if modname.startswith(ignore): + modname = modname[len(ignore) :] + stripped = ignore + break + else: + stripped = "" + + # we stripped the whole module name? + if not modname: + modname, stripped = stripped, "" + + entries = content.setdefault(modname[0].lower(), []) + + package = modname.split(".")[0] + if package != modname: + # it's a submodule + if prev_modname == package: + # first submodule - make parent a group head + if entries: + last = entries[-1] + entries[-1] = IndexEntry( + last[0], 1, last[2], last[3], last[4], last[5], last[6] + ) + elif not prev_modname.startswith(package): + # submodule without parent in list, add dummy entry + entries.append( + IndexEntry(stripped + package, 1, "", "", "", "", "") + ) + subtype = 2 + else: + num_toplevels += 1 + subtype = 0 + + qualifier = _("Deprecated") if deprecated else "" + entries.append( + IndexEntry( + stripped + modname, + subtype, + docname, + node_id, + platforms, + qualifier, + synopsis, + ) + ) + prev_modname = modname + + # apply heuristics when to collapse modindex at page load: + # only collapse if number of toplevel modules is larger than + # number of submodules + collapse = len(modules) - num_toplevels < num_toplevels + + # sort by first letter + sorted_content = sorted(content.items()) + + return sorted_content, collapse + + +class DeltaDomain(Domain): + """Delta domain.""" + + name = "delta" + label = "Delta" + object_types: dict[str, ObjType] = { + "function": ObjType(_("function"), "func", "obj"), + } + + directives = { + "function": DeltaFunction, + } + roles = { + "func": DeltaXRefRole(fix_parens=True), + } + initial_data: dict[str, dict[str, tuple[Any]]] = { + "objects": {}, # fullname -> docname, objtype + "modules": {}, # modname -> docname, synopsis, platform, deprecated + } + indices = [ + DeltaModuleIndex, + ] + + @property + def objects(self) -> dict[str, ObjectEntry]: + return self.data.setdefault("objects", {}) # fullname -> ObjectEntry + + def note_object( + self, + name: str, + objtype: str, + node_id: str, + aliased: bool = False, + location: Any = None, + ) -> None: + """Note a delta object for cross reference. + + .. versionadded:: 2.1 + """ + if name in self.objects: + other = self.objects[name] + if other.aliased and aliased is False: + # The original definition found. Override it! + pass + elif other.aliased is False and aliased: + # The original definition is already registered. + return + else: + # duplicated + logger.warning( + __( + "duplicate object description of %s, " + "other instance in %s, use :noindex: for one of them" + ), + name, + other.docname, + location=location, + ) + self.objects[name] = ObjectEntry(self.env.docname, node_id, objtype, aliased) + + @property + def modules(self) -> dict[str, ModuleEntry]: + return self.data.setdefault("modules", {}) # modname -> ModuleEntry + + def note_module( + self, name: str, node_id: str, synopsis: str, platform: str, deprecated: bool + ) -> None: + """Note a delta module for cross reference. + + .. versionadded:: 2.1 + """ + self.modules[name] = ModuleEntry( + self.env.docname, node_id, synopsis, platform, deprecated + ) + + def clear_doc(self, docname: str) -> None: + for fullname, obj in list(self.objects.items()): + if obj.docname == docname: + del self.objects[fullname] + for modname, mod in list(self.modules.items()): + if mod.docname == docname: + del self.modules[modname] + + def merge_domaindata(self, docnames: list[str], otherdata: dict[str, Any]) -> None: + # XXX check duplicates? + for fullname, obj in otherdata["objects"].items(): + if obj.docname in docnames: + self.objects[fullname] = obj + for modname, mod in otherdata["modules"].items(): + if mod.docname in docnames: + self.modules[modname] = mod + + def find_obj( + self, + env: BuildEnvironment, + modname: str, + classname: str, + name: str, + type: str | None, + searchmode: int = 0, + ) -> list[tuple[str, ObjectEntry]]: + """Find a Delta object for "name", perhaps using the given module + and/or classname. Returns a list of (name, object entry) tuples. + """ + # skip parens + if name[-2:] == "()": + name = name[:-2] + + if not name: + return [] + + matches: list[tuple[str, ObjectEntry]] = [] + + newname = None + if searchmode == 1: + if type is None: + objtypes = list(self.object_types) + else: + objtypes = self.objtypes_for_role(type) + if objtypes is not None: + if modname and classname: + fullname = modname + "." + classname + "." + name + if ( + fullname in self.objects + and self.objects[fullname].objtype in objtypes + ): + newname = fullname + if not newname: + if ( + modname + and modname + "." + name in self.objects + and self.objects[modname + "." + name].objtype in objtypes + ): + newname = modname + "." + name + elif ( + name in self.objects and self.objects[name].objtype in objtypes + ): + newname = name + else: + # "fuzzy" searching mode + searchname = "." + name + matches = [ + (oname, self.objects[oname]) + for oname in self.objects + if oname.endswith(searchname) + and self.objects[oname].objtype in objtypes + ] + else: + # NOTE: searching for exact match, object type is not considered + if name in self.objects: + newname = name + elif type == "mod": + # only exact matches allowed for modules + return [] + elif classname and classname + "." + name in self.objects: + newname = classname + "." + name + elif modname and modname + "." + name in self.objects: + newname = modname + "." + name + elif ( + modname + and classname + and modname + "." + classname + "." + name in self.objects + ): + newname = modname + "." + classname + "." + name + if newname is not None: + matches.append((newname, self.objects[newname])) + return matches + + def resolve_xref( + self, + env: BuildEnvironment, + fromdocname: str, + builder: Builder, + type: str, + target: str, + node: pending_xref, + contnode: Element, + ) -> Element | None: + modname = node.get("delta:module") + clsname = node.get("delta:class") + searchmode = 1 if node.hasattr("refspecific") else 0 + matches = self.find_obj(env, modname, clsname, target, type, searchmode) + + if not matches and type == "attr": + # fallback to meth (for property; Sphinx-2.4.x) + # this ensures that `:attr:` role continues to refer to the old property entry + # that defined by ``method`` directive in old reST files. + matches = self.find_obj(env, modname, clsname, target, "meth", searchmode) + if not matches and type == "meth": + # fallback to attr (for property) + # this ensures that `:meth:` in the old reST files can refer to the property + # entry that defined by ``property`` directive. + # + # Note: _prop is a secret role only for internal look-up. + matches = self.find_obj(env, modname, clsname, target, "_prop", searchmode) + + if not matches: + return None + elif len(matches) > 1: + canonicals = [m for m in matches if not m[1].aliased] + if len(canonicals) == 1: + matches = canonicals + else: + logger.warning( + __("more than one target found for cross-reference %r: %s"), + target, + ", ".join(match[0] for match in matches), + type="ref", + subtype="python", + location=node, + ) + name, obj = matches[0] + + if obj[2] == "module": + return self._make_module_refnode(builder, fromdocname, name, contnode) + else: + # determine the content of the reference by conditions + content = find_pending_xref_condition(node, "resolved") + if content: + children = content.children + else: + # if not found, use contnode + children = [contnode] + + return make_refnode(builder, fromdocname, obj[0], obj[1], children, name) + + def resolve_any_xref( + self, + env: BuildEnvironment, + fromdocname: str, + builder: Builder, + target: str, + node: pending_xref, + contnode: Element, + ) -> list[tuple[str, Element]]: + modname = node.get("delta:module") + clsname = node.get("delta:class") + results: list[tuple[str, Element]] = [] + + # always search in "refspecific" mode with the :any: role + matches = self.find_obj(env, modname, clsname, target, None, 1) + multiple_matches = len(matches) > 1 + + for name, obj in matches: + if multiple_matches and obj.aliased: + # Skip duplicated matches + continue + + if obj[2] == "module": + results.append( + ( + "delta:mod", + self._make_module_refnode(builder, fromdocname, name, contnode), + ) + ) + else: + # determine the content of the reference by conditions + content = find_pending_xref_condition(node, "resolved") + if content: + children = content.children + else: + # if not found, use contnode + children = [contnode] + + results.append( + ( + "delta:" + self.role_for_objtype(obj[2]), + make_refnode( + builder, fromdocname, obj[0], obj[1], children, name + ), + ) + ) + return results + + def _make_module_refnode( + self, builder: Builder, fromdocname: str, name: str, contnode: Node + ) -> Element: + # get additional info for modules + module = self.modules[name] + title = name + if module.synopsis: + title += ": " + module.synopsis + if module.deprecated: + title += _(" (deprecated)") + if module.platform: + title += " (" + module.platform + ")" + return make_refnode( + builder, fromdocname, module.docname, module.node_id, contnode, title + ) + + def get_objects(self) -> Iterator[tuple[str, str, str, str, str, int]]: + for modname, mod in self.modules.items(): + yield (modname, modname, "module", mod.docname, mod.node_id, 0) + for refname, obj in self.objects.items(): + if obj.objtype != "module": # modules are already handled + if obj.aliased: + # aliased names are not full-text searchable. + yield (refname, refname, obj.objtype, obj.docname, obj.node_id, -1) + else: + yield (refname, refname, obj.objtype, obj.docname, obj.node_id, 1) + + def get_full_qualified_name(self, node: Element) -> str | None: + modname = node.get("delta:module") + clsname = node.get("delta:class") + target = node.get("reftarget") + if target is None: + return None + else: + return ".".join(filter(None, [modname, clsname, target])) + + +def setup(app: Sphinx) -> dict[str, Any]: + app.setup_extension("sphinx.directives") + app.add_domain(DeltaDomain) + + return { + "version": "builtin", + "env_version": 3, + "parallel_read_safe": True, + "parallel_write_safe": True, + } diff --git a/velox/docs/functions.rst b/velox/docs/functions.rst index 6164b8ad3de..ae2d4d2ef99 100644 --- a/velox/docs/functions.rst +++ b/velox/docs/functions.rst @@ -21,6 +21,8 @@ Presto Functions functions/presto/aggregate functions/presto/window functions/presto/hyperloglog + functions/presto/khyperloglog + functions/presto/setdigest functions/presto/tdigest functions/presto/qdigest functions/presto/geospatial @@ -70,132 +72,140 @@ for :doc:`all ` and :doc:`most used ================================================= ================================================= ================================================= == ================================================= == ================================================= Scalar Functions Aggregate Functions Window Functions ======================================================================================================================================================= == ================================================= == ================================================= - :func:`$internal$json_string_to_array_cast` :func:`gt` :func:`sequence` :func:`any_value` :func:`cume_dist` - :func:`$internal$json_string_to_map_cast` :func:`gte` :func:`sha1` :func:`approx_distinct` :func:`dense_rank` - :func:`$internal$json_string_to_row_cast` :func:`hamming_distance` :func:`sha256` :func:`approx_most_frequent` :func:`first_value` - :func:`$internal$split_to_map` :func:`hmac_md5` :func:`sha512` :func:`approx_percentile` :func:`lag` - :func:`abs` :func:`hmac_sha1` :func:`shuffle` :func:`approx_set` :func:`last_value` - :func:`acos` :func:`hmac_sha256` :func:`sign` :func:`arbitrary` :func:`lead` - :func:`all_keys_match` :func:`hmac_sha512` :func:`simplify_geometry` :func:`array_agg` :func:`nth_value` - :func:`all_match` :func:`hour` :func:`sin` :func:`avg` :func:`ntile` - :func:`any_keys_match` in :func:`slice` :func:`bitwise_and_agg` :func:`percent_rank` - :func:`any_match` :func:`infinity` :func:`split` :func:`bitwise_or_agg` :func:`rank` - :func:`any_values_match` :func:`inverse_beta_cdf` :func:`split_part` :func:`bitwise_xor_agg` :func:`row_number` - :func:`array_average` :func:`inverse_binomial_cdf` :func:`split_to_map` :func:`bool_and` - :func:`array_constructor` :func:`inverse_cauchy_cdf` :func:`split_to_multimap` :func:`bool_or` - :func:`array_cum_sum` :func:`inverse_chi_squared_cdf` :func:`spooky_hash_v2_32` :func:`checksum` - :func:`array_distinct` :func:`inverse_f_cdf` :func:`spooky_hash_v2_64` :func:`classification_fall_out` - :func:`array_duplicates` :func:`inverse_gamma_cdf` :func:`sqrt` :func:`classification_miss_rate` - :func:`array_except` :func:`inverse_laplace_cdf` :func:`st_area` :func:`classification_precision` - :func:`array_frequency` :func:`inverse_normal_cdf` :func:`st_asbinary` :func:`classification_recall` - :func:`array_has_duplicates` :func:`inverse_poisson_cdf` :func:`st_astext` :func:`classification_thresholds` - :func:`array_intersect` :func:`inverse_weibull_cdf` :func:`st_boundary` :func:`corr` - :func:`array_join` :func:`ip_prefix` :func:`st_buffer` :func:`count` - :func:`array_max` :func:`ip_prefix_collapse` :func:`st_centroid` :func:`count_if` - :func:`array_max_by` :func:`ip_prefix_subnets` :func:`st_contains` :func:`covar_pop` - :func:`array_min` :func:`ip_subnet_max` :func:`st_convexhull` :func:`covar_samp` - :func:`array_min_by` :func:`ip_subnet_min` :func:`st_coorddim` :func:`entropy` - :func:`array_normalize` :func:`ip_subnet_range` :func:`st_crosses` :func:`every` - :func:`array_position` :func:`is_finite` :func:`st_difference` :func:`geometric_mean` - :func:`array_remove` :func:`is_infinite` :func:`st_dimension` :func:`histogram` - :func:`array_sort` :func:`is_json_scalar` :func:`st_disjoint` :func:`kurtosis` - :func:`array_sort_desc` :func:`is_nan` :func:`st_distance` :func:`map_agg` - :func:`array_sum` :func:`is_null` :func:`st_endpoint` :func:`map_union` - :func:`array_sum_propagate_element_null` :func:`is_private_ip` :func:`st_envelope` :func:`map_union_sum` - :func:`array_top_n` :func:`is_subnet_of` :func:`st_envelopeaspts` :func:`max` - :func:`array_union` :func:`json_array_contains` :func:`st_equals` :func:`max_by` - :func:`arrays_overlap` :func:`json_array_get` :func:`st_exteriorring` :func:`max_data_size_for_stats` - :func:`asin` :func:`json_array_length` :func:`st_geometries` :func:`merge` - :func:`at_timezone` :func:`json_extract` :func:`st_geometryfromtext` :func:`min` - :func:`atan` :func:`json_extract_scalar` :func:`st_geometryn` :func:`min_by` - :func:`atan2` :func:`json_format` :func:`st_geometrytype` :func:`multimap_agg` - :func:`beta_cdf` :func:`json_parse` :func:`st_geomfrombinary` :func:`noisy_approx_distinct_sfm` - :func:`between` :func:`json_size` :func:`st_interiorringn` :func:`noisy_approx_set_sfm` - :func:`bing_tile` :func:`laplace_cdf` :func:`st_interiorrings` :func:`noisy_approx_set_sfm_from_index_and_zeros` - :func:`bing_tile_at` :func:`last_day_of_month` :func:`st_intersection` :func:`noisy_avg_gaussian` - :func:`bing_tile_children` :func:`least` :func:`st_intersects` :func:`noisy_count_gaussian` - :func:`bing_tile_coordinates` :func:`length` :func:`st_isclosed` :func:`noisy_count_if_gaussian` - :func:`bing_tile_parent` :func:`levenshtein_distance` :func:`st_isempty` :func:`noisy_sum_gaussian` - :func:`bing_tile_quadkey` :func:`like` :func:`st_isring` :func:`numeric_histogram` - :func:`bing_tile_zoom_level` :func:`line_interpolate_point` :func:`st_issimple` :func:`qdigest_agg` - :func:`bing_tiles_around` :func:`line_locate_point` :func:`st_isvalid` :func:`reduce_agg` - :func:`binomial_cdf` :func:`ln` :func:`st_length` :func:`regr_avgx` - :func:`bit_count` :func:`log10` :func:`st_numgeometries` :func:`regr_avgy` - :func:`bitwise_and` :func:`log2` :func:`st_numinteriorring` :func:`regr_count` - :func:`bitwise_arithmetic_shift_right` :func:`lower` :func:`st_numpoints` :func:`regr_intercept` - :func:`bitwise_left_shift` :func:`lpad` :func:`st_overlaps` :func:`regr_r2` - :func:`bitwise_logical_shift_right` :func:`lt` :func:`st_point` :func:`regr_slope` - :func:`bitwise_not` :func:`lte` :func:`st_pointn` :func:`regr_sxx` - :func:`bitwise_or` :func:`ltrim` :func:`st_points` :func:`regr_sxy` - :func:`bitwise_right_shift` :func:`map` :func:`st_polygon` :func:`regr_syy` - :func:`bitwise_right_shift_arithmetic` :func:`map_concat` :func:`st_relate` :func:`set_agg` - :func:`bitwise_shift_left` :func:`map_entries` :func:`st_startpoint` :func:`set_union` - :func:`bitwise_xor` :func:`map_filter` :func:`st_symdifference` :func:`skewness` - :func:`cardinality` :func:`map_from_entries` :func:`st_touches` :func:`stddev` - :func:`cauchy_cdf` :func:`map_key_exists` :func:`st_union` :func:`stddev_pop` - :func:`cbrt` :func:`map_keys` :func:`st_within` :func:`stddev_samp` - :func:`ceil` :func:`map_keys_by_top_n_values` :func:`st_x` :func:`sum` - :func:`ceiling` :func:`map_normalize` :func:`st_xmax` :func:`sum_data_size_for_stats` - :func:`chi_squared_cdf` :func:`map_remove_null_values` :func:`st_xmin` :func:`tdigest_agg` - :func:`chr` :func:`map_subset` :func:`st_y` :func:`var_pop` - :func:`clamp` :func:`map_top_n` :func:`st_ymax` :func:`var_samp` - :func:`codepoint` :func:`map_top_n_keys` :func:`st_ymin` :func:`variance` - :func:`combinations` :func:`map_top_n_values` :func:`starts_with` - :func:`combine_hash_internal` :func:`map_values` :func:`strpos` - :func:`concat` :func:`map_zip_with` :func:`strrpos` - :func:`construct_tdigest` :func:`md5` :func:`subscript` - :func:`contains` :func:`merge_sfm` :func:`substr` - :func:`cos` :func:`merge_tdigest` :func:`substring` - :func:`cosh` :func:`millisecond` :func:`tan` - :func:`cosine_similarity` :func:`minus` :func:`tanh` - :func:`crc32` :func:`minute` :func:`timezone_hour` - :func:`current_date` :func:`mod` :func:`timezone_minute` - :func:`date` :func:`month` :func:`to_base` - :func:`date_add` :func:`multimap_from_entries` :func:`to_base64` - :func:`date_diff` :func:`multiply` :func:`to_base64url` - :func:`date_format` :func:`murmur3_x64_128` :func:`to_big_endian_32` - :func:`date_parse` :func:`nan` :func:`to_big_endian_64` - :func:`date_trunc` :func:`negate` :func:`to_hex` - :func:`day` :func:`neq` :func:`to_ieee754_32` - :func:`day_of_month` :func:`ngrams` :func:`to_ieee754_64` - :func:`day_of_week` :func:`no_keys_match` :func:`to_iso8601` - :func:`day_of_year` :func:`no_values_match` :func:`to_milliseconds` - :func:`degrees` :func:`noisy_empty_approx_set_sfm` :func:`to_unixtime` - :func:`destructure_tdigest` :func:`none_match` :func:`to_utf8` - :func:`distinct_from` :func:`normal_cdf` :func:`trail` - :func:`divide` :func:`normalize` :func:`transform` - :func:`dot_product` not :func:`transform_keys` - :func:`dow` :func:`parse_datetime` :func:`transform_values` - :func:`doy` :func:`parse_duration` :func:`trim` - :func:`e` :func:`parse_presto_data_size` :func:`trim_array` - :func:`element_at` :func:`pi` :func:`trimmed_mean` - :func:`empty_approx_set` :func:`plus` :func:`truncate` - :func:`ends_with` :func:`poisson_cdf` :func:`typeof` - :func:`eq` :func:`pow` :func:`upper` - :func:`exp` :func:`power` :func:`url_decode` - :func:`f_cdf` :func:`quantile_at_value` :func:`url_encode` - :func:`fail` :func:`quantiles_at_values` :func:`url_extract_fragment` - :func:`filter` :func:`quarter` :func:`url_extract_host` - :func:`find_first` :func:`radians` :func:`url_extract_parameter` - :func:`find_first_index` :func:`rand` :func:`url_extract_path` - :func:`flatten` :func:`random` :func:`url_extract_port` - :func:`flatten_geometry_collections` :func:`reduce` :func:`url_extract_protocol` - :func:`floor` :func:`regexp_extract` :func:`url_extract_query` - :func:`format_datetime` :func:`regexp_extract_all` :func:`uuid` - :func:`from_base` :func:`regexp_like` :func:`value_at_quantile` - :func:`from_base64` :func:`regexp_replace` :func:`values_at_quantiles` - :func:`from_base64url` :func:`regexp_split` :func:`week` - :func:`from_big_endian_32` :func:`remove_nulls` :func:`week_of_year` - :func:`from_big_endian_64` :func:`repeat` :func:`weibull_cdf` - :func:`from_hex` :func:`replace` :func:`width_bucket` - :func:`from_ieee754_32` :func:`replace_first` :func:`wilson_interval_lower` - :func:`from_ieee754_64` :func:`reverse` :func:`wilson_interval_upper` - :func:`from_iso8601_date` :func:`round` :func:`word_stem` - :func:`from_iso8601_timestamp` :func:`rpad` :func:`xxhash64` - :func:`from_unixtime` :func:`rtrim` :func:`xxhash64_internal` - :func:`from_utf8` :func:`scale_qdigest` :func:`year` - :func:`gamma_cdf` :func:`scale_tdigest` :func:`year_of_week` - :func:`geometry_invalid_reason` :func:`second` :func:`yow` - :func:`geometry_nearest_points` :func:`secure_rand` :func:`zip` - :func:`greatest` :func:`secure_random` :func:`zip_with` + :func:`$internal$json_string_to_array_cast` :func:`geometry_to_dissolved_bing_tiles` :func:`secure_rand` :func:`any_value` :func:`cume_dist` + :func:`$internal$json_string_to_map_cast` :func:`geometry_union` :func:`secure_random` :func:`approx_distinct` :func:`dense_rank` + :func:`$internal$json_string_to_row_cast` :func:`great_circle_distance` :func:`sequence` :func:`approx_most_frequent` :func:`first_value` + :func:`$internal$split_to_map` :func:`greatest` :func:`sha1` :func:`approx_percentile` :func:`lag` + :func:`abs` :func:`gt` :func:`sha256` :func:`approx_set` :func:`last_value` + :func:`acos` :func:`gte` :func:`sha512` :func:`arbitrary` :func:`lead` + :func:`all_keys_match` :func:`hamming_distance` :func:`shuffle` :func:`array_agg` :func:`nth_value` + :func:`all_match` :func:`hmac_md5` :func:`sign` :func:`avg` :func:`ntile` + :func:`any_keys_match` :func:`hmac_sha1` :func:`simplify_geometry` :func:`bitwise_and_agg` :func:`percent_rank` + :func:`any_match` :func:`hmac_sha256` :func:`sin` :func:`bitwise_or_agg` :func:`rank` + :func:`any_values_match` :func:`hmac_sha512` :func:`slice` :func:`bitwise_xor_agg` :func:`row_number` + :func:`array_average` :func:`hour` :func:`split` :func:`bool_and` + :func:`array_constructor` in :func:`split_part` :func:`bool_or` + :func:`array_cum_sum` :func:`infinity` :func:`split_to_map` :func:`checksum` + :func:`array_distinct` :func:`inverse_beta_cdf` :func:`split_to_multimap` :func:`classification_fall_out` + :func:`array_duplicates` :func:`inverse_binomial_cdf` :func:`spooky_hash_v2_32` :func:`classification_miss_rate` + :func:`array_except` :func:`inverse_cauchy_cdf` :func:`spooky_hash_v2_64` :func:`classification_precision` + :func:`array_frequency` :func:`inverse_chi_squared_cdf` :func:`sqrt` :func:`classification_recall` + :func:`array_has_duplicates` :func:`inverse_f_cdf` :func:`st_area` :func:`classification_thresholds` + :func:`array_intersect` :func:`inverse_gamma_cdf` :func:`st_asbinary` :func:`corr` + :func:`array_join` :func:`inverse_laplace_cdf` :func:`st_astext` :func:`count` + :func:`array_max` :func:`inverse_normal_cdf` :func:`st_boundary` :func:`count_if` + :func:`array_max_by` :func:`inverse_poisson_cdf` :func:`st_buffer` :func:`covar_pop` + :func:`array_min` :func:`inverse_t_cdf` :func:`st_centroid` :func:`covar_samp` + :func:`array_min_by` :func:`inverse_weibull_cdf` :func:`st_contains` :func:`entropy` + :func:`array_normalize` :func:`ip_prefix` :func:`st_convexhull` :func:`every` + :func:`array_position` :func:`ip_prefix_collapse` :func:`st_coorddim` :func:`geometric_mean` + :func:`array_remove` :func:`ip_prefix_subnets` :func:`st_crosses` :func:`histogram` + :func:`array_sort` :func:`ip_subnet_max` :func:`st_difference` :func:`kurtosis` + :func:`array_sort_desc` :func:`ip_subnet_min` :func:`st_dimension` :func:`map_agg` + :func:`array_subset` :func:`ip_subnet_range` :func:`st_disjoint` :func:`map_union` + :func:`array_sum` :func:`is_finite` :func:`st_distance` :func:`map_union_sum` + :func:`array_sum_propagate_element_null` :func:`is_infinite` :func:`st_endpoint` :func:`max` + :func:`array_top_n` :func:`is_json_scalar` :func:`st_envelope` :func:`max_by` + :func:`array_union` :func:`is_nan` :func:`st_envelopeaspts` :func:`max_data_size_for_stats` + :func:`arrays_overlap` :func:`is_null` :func:`st_equals` :func:`merge` + :func:`asin` :func:`is_private_ip` :func:`st_exteriorring` :func:`min` + :func:`at_timezone` :func:`is_subnet_of` :func:`st_geometries` :func:`min_by` + :func:`atan` :func:`json_array_contains` :func:`st_geometryfromtext` :func:`multimap_agg` + :func:`atan2` :func:`json_array_get` :func:`st_geometryn` :func:`noisy_approx_distinct_sfm` + :func:`beta_cdf` :func:`json_array_length` :func:`st_geometrytype` :func:`noisy_approx_set_sfm` + :func:`between` :func:`json_extract` :func:`st_geomfrombinary` :func:`noisy_approx_set_sfm_from_index_and_zeros` + :func:`bing_tile` :func:`json_extract_scalar` :func:`st_interiorringn` :func:`noisy_avg_gaussian` + :func:`bing_tile_at` :func:`json_format` :func:`st_interiorrings` :func:`noisy_count_gaussian` + :func:`bing_tile_children` :func:`json_parse` :func:`st_intersection` :func:`noisy_count_if_gaussian` + :func:`bing_tile_coordinates` :func:`json_size` :func:`st_intersects` :func:`noisy_sum_gaussian` + :func:`bing_tile_parent` :func:`laplace_cdf` :func:`st_isclosed` :func:`numeric_histogram` + :func:`bing_tile_polygon` :func:`last_day_of_month` :func:`st_isempty` :func:`qdigest_agg` + :func:`bing_tile_quadkey` :func:`least` :func:`st_isring` :func:`reduce_agg` + :func:`bing_tile_zoom_level` :func:`length` :func:`st_issimple` :func:`regr_avgx` + :func:`bing_tiles_around` :func:`levenshtein_distance` :func:`st_isvalid` :func:`regr_avgy` + :func:`binomial_cdf` :func:`like` :func:`st_length` :func:`regr_count` + :func:`bit_count` :func:`line_interpolate_point` :func:`st_linefromtext` :func:`regr_intercept` + :func:`bit_length` :func:`line_locate_point` :func:`st_linestring` :func:`regr_r2` + :func:`bitwise_and` :func:`ln` :func:`st_multipoint` :func:`regr_slope` + :func:`bitwise_arithmetic_shift_right` :func:`localtime` :func:`st_numgeometries` :func:`regr_sxx` + :func:`bitwise_left_shift` :func:`log10` :func:`st_numinteriorring` :func:`regr_sxy` + :func:`bitwise_logical_shift_right` :func:`log2` :func:`st_numpoints` :func:`regr_syy` + :func:`bitwise_not` :func:`longest_common_prefix` :func:`st_overlaps` :func:`set_agg` + :func:`bitwise_or` :func:`lower` :func:`st_point` :func:`set_union` + :func:`bitwise_right_shift` :func:`lpad` :func:`st_pointn` :func:`skewness` + :func:`bitwise_right_shift_arithmetic` :func:`lt` :func:`st_points` :func:`stddev` + :func:`bitwise_shift_left` :func:`lte` :func:`st_polygon` :func:`stddev_pop` + :func:`bitwise_xor` :func:`ltrim` :func:`st_relate` :func:`stddev_samp` + :func:`cardinality` :func:`map` :func:`st_startpoint` :func:`sum` + :func:`cauchy_cdf` :func:`map_concat` :func:`st_symdifference` :func:`sum_data_size_for_stats` + :func:`cbrt` :func:`map_entries` :func:`st_touches` :func:`tdigest_agg` + :func:`ceil` :func:`map_filter` :func:`st_union` :func:`var_pop` + :func:`ceiling` :func:`map_from_entries` :func:`st_within` :func:`var_samp` + :func:`chi_squared_cdf` :func:`map_intersect` :func:`st_x` :func:`variance` + :func:`chr` :func:`map_key_exists` :func:`st_xmax` + :func:`clamp` :func:`map_keys` :func:`st_xmin` + :func:`codepoint` :func:`map_keys_by_top_n_values` :func:`st_y` + :func:`combinations` :func:`map_normalize` :func:`st_ymax` + :func:`combine_hash_internal` :func:`map_remove_null_values` :func:`st_ymin` + :func:`concat` :func:`map_subset` :func:`starts_with` + :func:`construct_tdigest` :func:`map_top_n` :func:`strpos` + :func:`contains` :func:`map_top_n_keys` :func:`strrpos` + :func:`cos` :func:`map_top_n_values` :func:`subscript` + :func:`cosh` :func:`map_values` :func:`substr` + :func:`cosine_similarity` :func:`map_zip_with` :func:`substring` + :func:`crc32` :func:`md5` :func:`t_cdf` + :func:`current_date` :func:`merge_hll` :func:`tan` + :func:`date` :func:`merge_sfm` :func:`tanh` + :func:`date_add` :func:`merge_tdigest` :func:`timezone_hour` + :func:`date_diff` :func:`millisecond` :func:`timezone_minute` + :func:`date_format` :func:`minus` :func:`to_base` + :func:`date_parse` :func:`minute` :func:`to_base64` + :func:`date_trunc` :func:`mod` :func:`to_base64url` + :func:`day` :func:`month` :func:`to_big_endian_32` + :func:`day_of_month` :func:`multimap_from_entries` :func:`to_big_endian_64` + :func:`day_of_week` :func:`multiply` :func:`to_hex` + :func:`day_of_year` :func:`murmur3_x64_128` :func:`to_ieee754_32` + :func:`degrees` :func:`nan` :func:`to_ieee754_64` + :func:`destructure_tdigest` :func:`negate` :func:`to_iso8601` + :func:`distinct_from` :func:`neq` :func:`to_milliseconds` + :func:`divide` :func:`ngrams` :func:`to_unixtime` + :func:`dot_product` :func:`no_keys_match` :func:`to_utf8` + :func:`dow` :func:`no_values_match` :func:`trail` + :func:`doy` :func:`noisy_empty_approx_set_sfm` :func:`transform` + :func:`e` :func:`none_match` :func:`transform_keys` + :func:`element_at` :func:`normal_cdf` :func:`transform_values` + :func:`empty_approx_set` :func:`normalize` :func:`trim` + :func:`ends_with` not :func:`trim_array` + :func:`enum_key` :func:`parse_datetime` :func:`trimmed_mean` + :func:`eq` :func:`parse_duration` :func:`truncate` + :func:`exp` :func:`parse_presto_data_size` :func:`typeof` + :func:`expand_envelope` :func:`pi` :func:`upper` + :func:`f_cdf` :func:`plus` :func:`url_decode` + :func:`fail` :func:`poisson_cdf` :func:`url_encode` + :func:`filter` :func:`pow` :func:`url_extract_fragment` + :func:`find_first` :func:`power` :func:`url_extract_host` + :func:`find_first_index` :func:`quantile_at_value` :func:`url_extract_parameter` + :func:`flatten` :func:`quantiles_at_values` :func:`url_extract_path` + :func:`flatten_geometry_collections` :func:`quarter` :func:`url_extract_port` + :func:`floor` :func:`radians` :func:`url_extract_protocol` + :func:`format_datetime` :func:`rand` :func:`url_extract_query` + :func:`from_base` :func:`random` :func:`uuid` + :func:`from_base32` :func:`reduce` :func:`value_at_quantile` + :func:`from_base64` :func:`regexp_extract` :func:`values_at_quantiles` + :func:`from_base64url` :func:`regexp_extract_all` :func:`week` + :func:`from_big_endian_32` :func:`regexp_like` :func:`week_of_year` + :func:`from_big_endian_64` :func:`regexp_replace` :func:`weibull_cdf` + :func:`from_hex` :func:`regexp_split` :func:`width_bucket` + :func:`from_ieee754_32` :func:`remap_keys` :func:`wilson_interval_lower` + :func:`from_ieee754_64` :func:`remove_nulls` :func:`wilson_interval_upper` + :func:`from_iso8601_date` :func:`repeat` :func:`word_stem` + :func:`from_iso8601_timestamp` :func:`replace` :func:`xxhash64` + :func:`from_unixtime` :func:`replace_first` :func:`xxhash64_internal` + :func:`from_utf8` :func:`reverse` :func:`year` + :func:`gamma_cdf` :func:`round` :func:`year_of_week` + :func:`geometry_as_geojson` :func:`rpad` :func:`yow` + :func:`geometry_from_geojson` :func:`rtrim` :func:`zip` + :func:`geometry_invalid_reason` :func:`scale_qdigest` :func:`zip_with` + :func:`geometry_nearest_points` :func:`scale_tdigest` + :func:`geometry_to_bing_tiles` :func:`second` ================================================= ================================================= ================================================= == ================================================= == ================================================= diff --git a/velox/docs/functions/delta/functions.rst b/velox/docs/functions/delta/functions.rst new file mode 100644 index 00000000000..e1b3de5673f --- /dev/null +++ b/velox/docs/functions/delta/functions.rst @@ -0,0 +1,13 @@ +******************** +Delta Lake Functions +******************** + +Here is a list of all scalar Delta Lake functions available in Velox. +Function names link to function description. + +These functions are used in deletion vector read. +Refer to `Delta Lake documentation `_ and `Delta Lake deletion vector blog `_ for details. + +.. delta:function:: bitmap_array_contains(bitmap_array: varbinary, input: bigint) -> bool + + Not implemented. diff --git a/velox/docs/functions/iceberg/functions.rst b/velox/docs/functions/iceberg/functions.rst index 768f79b43a5..a878f8c40ea 100644 --- a/velox/docs/functions/iceberg/functions.rst +++ b/velox/docs/functions/iceberg/functions.rst @@ -22,3 +22,58 @@ Refer to `Iceberg documenation date + + Returns the date. :: + + SELECT days(DATE '2017-12-01'); -- 2017-12-01 + SELECT days(TIMESTAMP '2017-12-01 10:12:55.038194'); -- 2017-12-01 + SELECT days(DATE '1969-12-31'); -- 1969-12-31 + +.. iceberg:function:: hours(input) -> integer + + Returns the number of hours since epoch (1970-01-01 00:00:00). Returns 0 for '1970-01-01 00:00:00' timestamps. + Returns negative value for timestamps before '1970-01-01 00:00:00'. :: + + SELECT hours(TIMESTAMP '2017-12-01 10:12:55.038194'); -- 420034 + SELECT hours(TIMESTAMP '1969-12-31 23:59:58.999999'); -- -1 + +.. iceberg:function:: months(input) -> integer + + Returns the number of months since epoch (1970-01-01). Returns 0 for '1970-01-01' date and timestamps. + Returns negative value for dates and timestamps before '1970-01-01'. :: + + SELECT months(DATE '2017-12-01'); -- 575 + SELECT months(TIMESTAMP '2017-12-01 10:12:55.038194'); -- 575 + SELECT months(DATE '1960-01-01'); -- -120 + +.. iceberg:function:: truncate(width, input) -> same type as input + + Returns the truncated value of the input based on the specified width. + For numeric values, truncate to the nearest lower multiple of ``width``, the truncate function is: input - (((input % width) + width) % width). + The ``width`` is used to truncate decimal values is applied using unscaled value to avoid additional (and potentially conflicting) parameters. + For string values, it truncates a valid UTF-8 string with no more than ``width`` code points. + In contrast to strings, binary values do not have an assumed encoding and are truncated to ``width`` bytes. + + Argument ``width`` must be a positive integer. + Supported types for ``input`` are: SHORTINT, TYNYINT, SMALLINT, INTEGER, BIGINT, DECIMAL, VARCHAR, VARBINARY. :: + + SELECT truncate(10, 11); -- 10 + SELECT truncate(10, -11); -- -20 + SELECT truncate(7, 22); -- 21 + SELECT truncate(0, 11); -- error: Reason: (0 vs. 0) Invalid truncate width\nExpression: width <= 0 + SELECT truncate(-3, 11); -- error: Reason: (-3 vs. 0) Invalid truncate width\nExpression: width <= 0 + SELECT truncate(4, 'iceberg'); -- 'iceb' + SELECT truncate(1, '测试'); -- 测 + SELECT truncate(6, '测试'); -- 测试 + SELECT truncate(6, cast('测试' as binary)); -- 测试_ + +.. iceberg:function:: years(input) -> integer + + Returns the number of years since epoch (1970-01-01). Returns 0 for '1970-01-01' date and timestamps. + Returns negative value for dates and timestamps before '1970-01-01'. :: + + SELECT years(DATE '2017-12-01'); -- 47 + SELECT years(TIMESTAMP '2017-12-01 10:12:55.038194'); -- 47 + SELECT years(DATE '1960-01-01'); -- -10 diff --git a/velox/docs/functions/presto/aggregate.rst b/velox/docs/functions/presto/aggregate.rst index 8212675c328..97b6d91b646 100644 --- a/velox/docs/functions/presto/aggregate.rst +++ b/velox/docs/functions/presto/aggregate.rst @@ -17,14 +17,14 @@ depending on the order of input values. General Aggregate Functions --------------------------- -.. function:: arbitrary(x) -> [same as x] - - Returns an arbitrary non-null value of ``x``, if one exists. - .. function:: any_value(x) -> [same as x] This is an alias for :func:`arbitrary`. +.. function:: arbitrary(x) -> [same as x] + + Returns an arbitrary non-null value of ``x``, if one exists. + .. function:: array_agg(x) -> array<[same as x]> Returns an array created from the input ``x`` elements. Ignores null @@ -89,37 +89,17 @@ General Aggregate Functions This is an alias for :func:`bool_and`. -.. function:: histogram(x) - - Returns a map containing the count of the number of times - each input value occurs. Supports integral, floating-point, - boolean, timestamp, and date input types. - .. function:: geometric_mean(bigint) -> double geometric_mean(double) -> double geometric_mean(real) -> real Returns the `geometric mean `_ of all input values. -.. function:: max_by(x, y) -> [same as x] - - Returns the value of ``x`` associated with the maximum value of ``y`` over all input values. - ``y`` must be an orderable type. - -.. function:: max_by(x, y, n) -> array([same as x]) - :noindex: - - Returns n values of ``x`` associated with the n largest values of ``y`` in descending order of ``y``. - -.. function:: min_by(x, y) -> [same as x] - - Returns the value of ``x`` associated with the minimum value of ``y`` over all input values. - ``y`` must be an orderable type. - -.. function:: min_by(x, y, n) -> array([same as x]) - :noindex: +.. function:: histogram(x) - Returns n values of ``x`` associated with the n smallest values of ``y`` in ascending order of ``y``. + Returns a map containing the count of the number of times + each input value occurs. Supports integral, floating-point, + boolean, timestamp, and date input types. .. function:: max(x) -> [same as x] @@ -138,6 +118,16 @@ General Aggregate Functions Nulls are not included in the output array. For REAL and DOUBLE types, NaN is considered greater than Infinity. +.. function:: max_by(x, y) -> [same as x] + + Returns the value of ``x`` associated with the maximum value of ``y`` over all input values. + ``y`` must be an orderable type. + +.. function:: max_by(x, y, n) -> array([same as x]) + :noindex: + + Returns n values of ``x`` associated with the n largest values of ``y`` in descending order of ``y``. + .. function:: min(x) -> [same as x] Returns the minimum value of all input values. @@ -155,6 +145,16 @@ General Aggregate Functions Nulls are not included in output array. For REAL and DOUBLE types, NaN is considered greater than Infinity. +.. function:: min_by(x, y) -> [same as x] + + Returns the value of ``x`` associated with the minimum value of ``y`` over all input values. + ``y`` must be an orderable type. + +.. function:: min_by(x, y, n) -> array([same as x]) + :noindex: + + Returns n values of ``x`` associated with the n smallest values of ``y`` in ascending order of ``y``. + .. function:: multimap_agg(K key, V value) -> map(K,array(V)) Returns a multimap created from the input ``key`` / ``value`` pairs. @@ -307,6 +307,33 @@ Map Aggregate Functions Returns the union of all the input maps summing the values of matching keys in all the maps. All null values in the original maps are coalesced to 0. +Array Aggregate Functions +------------------------- + +.. function:: vector_sum(array(T)) -> array(T) + + Returns the element-wise sum of all input arrays. Equivalent to + ``ARRAY[SUM(a[1]), SUM(a[2]), ...]``, with the same null-handling + semantics as :func:`sum`: null elements are skipped, and positions + where all input values are null produce null in the output. + All input arrays must have the same length; an error is raised if + arrays of different lengths are encountered. + Supported types for T are: TINYINT, SMALLINT, INTEGER, BIGINT, REAL + and DOUBLE. + For integer types, arithmetic overflow results in an error, + consistent with the behavior of :func:`sum`. For floating-point + types (REAL, DOUBLE), NaN values propagate through the sum and + overflow produces Infinity, following standard IEEE 754 semantics. + + This is useful when rows contain fixed-dimension vectors (e.g. + embedding vectors or feature arrays) and you need to compute a + component-wise sum across all rows:: + + SELECT vector_sum(embedding) FROM item_embeddings; + + -- With 3 rows: [1, 2, 3], [10, 20, 30], [100, 200, 300] + -- Returns: [111, 222, 333] + Approximate Aggregate Functions ------------------------------- @@ -420,6 +447,40 @@ __ https://www.cse.ust.hk/~raywong/comp5331/References/EfficientComputationOfFre As ``approx_percentile(x, w, percentages)``, but with a maximum rank error of ``accuracy``. +.. function:: numeric_histogram(buckets, value, weight) -> map + + Computes an approximate histogram with up to ``buckets`` number of buckets + for all ``value``\ s with a per-item weight of ``weight``. The keys of the + returned map are roughly the center of the bin, and the entry is the total + weight of the bin. The algorithm is based loosely on [BenHaimTomTov2010]_. + + ``buckets`` must be a ``bigint``. ``value`` and ``weight`` must be numeric. + :: + + SELECT numeric_histogram(3, v, 1.0) + FROM ( + VALUES (10), + (15), + (20), + (25), + (30) + ) AS t(v); + --{30.0->1.0, 22.5->2.0, 12.5->2.0} + +.. function:: numeric_histogram(buckets, value) -> map + + Computes an approximate histogram with up to ``buckets`` number of buckets + for all ``value``\ s. This function is equivalent to the variant of + :func:`!numeric_histogram` that takes a ``weight``, with a per-item weight of ``1``. + In this case, the total weight in the returned map is the count of items in the bin. + :: + + SELECT numeric_histogram(3, v) + FROM ( + VALUES (10.0), (15.0), (20.0), (25.0), (30.0) + ) AS t(v); + --{30.0->1.0, 22.5->2.0, 12.5->2.0} + Classification Metrics Aggregate Functions ------------------------------------------ @@ -703,10 +764,6 @@ Statistical Aggregate Functions Returns the sample standard deviation of all input values. -.. function:: variance(x) -> double - - This is an alias for :func:`var_samp`. - .. function:: var_pop(x) -> double Returns the population variance of all input values. @@ -715,6 +772,10 @@ Statistical Aggregate Functions Returns the sample variance of all input values. +.. function:: variance(x) -> double + + This is an alias for :func:`var_samp`. + Noisy Aggregate Functions ------------------------- diff --git a/velox/docs/functions/presto/array.rst b/velox/docs/functions/presto/array.rst index ffa394f6e4c..db3230cdd7a 100644 --- a/velox/docs/functions/presto/array.rst +++ b/velox/docs/functions/presto/array.rst @@ -7,7 +7,7 @@ Array Functions Returns whether all elements of an array match the given predicate. Returns true if all the elements match the predicate (a special case is when the array is empty); - Returns false if one or more elements don’t match; + Returns false if one or more elements don't match; Returns NULL if the predicate function returns NULL for one or more elements and true for all other elements. Throws an exception if the predicate fails for one or more elements and returns true or NULL for the rest. @@ -20,20 +20,12 @@ Array Functions Returns NULL if the predicate function returns NULL for one or more elements and false for all other elements. Throws an exception if the predicate fails for one or more elements and returns false or NULL for the rest. -.. function:: none_match(array(T), function(T, boolean)) → boolean - - Returns whether no elements of an array match the given predicate. - - Returns true if none of the elements matches the predicate (a special case is when the array is empty); - Returns false if one or more elements match; - Returns NULL if the predicate function returns NULL for one or more elements and false for all other elements. - Throws an exception if the predicate fails for one or more elements and returns false or NULL for the rest. - .. function:: array_average(array(double)) -> double Returns the average of all non-null elements of the array. If there are no non-null elements, returns null. .. function:: array_cum_sum(array(T)) -> array(T) + Returns the array whose elements are the cumulative sum of the input array, i.e. result[i] = input[1] + input[2] + … + input[i]. If there there is null elements in the array, the cumulative sum at and after the element is null. The following types are supported: int8_t, int16_t, int32_t, int64_t, int128_t, float, double, ShortDecimal, @@ -117,6 +109,13 @@ Array Functions SELECT array_max(ARRAY[{-1, -2, -3, nan()]); -- NaN SELECT array_max(ARRAY[{infinity(), nan()]); -- NaN +.. function:: array_max_by(array(T), function(T, U)) -> T() + + Applies the provided function to each element, and returns the element that gives the maximum value. + ``U`` can be any orderable type. :: + + SELECT array_max_by(ARRAY ['a', 'bbb', 'cc'], x -> LENGTH(x)) -- 'bbb' + .. function:: array_min(array(E)) -> E Returns the minimum value of input array. @@ -131,20 +130,16 @@ Array Functions SELECT array_min(ARRAY[{-1, -2, -3, nan()]); -- -1 SELECT array_min(ARRAY[{infinity(), nan()]); -- Infinity -.. function:: array_normalize(array(E), E) -> array(E) +.. function:: array_min_by(array(T), function(T, U)) -> T - Normalizes array ``x`` by dividing each element by the p-norm of the array. It is equivalent to ``TRANSFORM(array, v -> v / REDUCE(array, 0, (a, v) -> a + POW(ABS(v), p), a -> POW(a, 1 / p))``, but the reduce part is only executed once. Returns null if the array is null or there are null array elements. If ``p`` is 0, then the input array is returned. Only REAL and DOUBLE types are supported. + Applies the provided function to each element, and returns the element that gives the minimum value. + ``U`` can be any orderable type. :: -.. function:: arrays_overlap(x, y) -> boolean + SELECT array_min_by(ARRAY ['a', 'bbb', 'cc'], x -> LENGTH(x)) -- 'a' - Tests if arrays ``x`` and ``y`` have any non-null elements in common. - Returns null if there are no non-null elements in common but either array contains null. - For REAL and DOUBLE, NANs (Not-a-Number) are considered equal. - -.. function:: arrays_union(x, y) -> array +.. function:: array_normalize(array(E), E) -> array(E) - Returns an array of the elements in the union of x and y, without duplicates. - For REAL and DOUBLE, NANs (Not-a-Number) are considered equal. + Normalizes array ``x`` by dividing each element by the p-norm of the array. It is equivalent to ``TRANSFORM(array, v -> v / REDUCE(array, 0, (a, v) -> a + POW(ABS(v), p), a -> POW(a, 1 / p))``, but the reduce part is only executed once. Returns null if the array is null or there are null array elements. If ``p`` is 0, then the input array is returned. Only REAL and DOUBLE types are supported. .. function:: array_position(x, element) -> bigint @@ -209,7 +204,7 @@ Array Functions SELECT array_sort(ARRAY [ARRAY [1, 2], ARRAY [1, null]]); -- failed: Ordering nulls is not supported .. function:: array_sort_desc(array(T), function(T,U)) -> array(T) - :noindex: + :noindex: Returns the array sorted by values computed using specified lambda in descending order. U must be an orderable type. Null elements will be placed at the end of @@ -217,13 +212,57 @@ Array Functions nested nulls. Throws if deciding the order of elements would require comparing nested null values. :: - SELECT array_sort_desc(ARRAY ['cat', 'leopard', 'mouse'], x -> length(x)); -- ['leopard', 'mouse', 'cat'] + SELECT array_sort_desc(ARRAY ['cat', 'leopard', 'mouse'], x -> length(x)); -- ['leopard', 'mouse', 'cat'] + +.. function:: array_split_into_chunks(array(T), sz) -> array(array(T)) + + Returns an array of arrays splitting the input array into chunks of given + length. The last chunk will be shorter than the chunk length if the array's + length is not an integer multiple of the chunk length. Ignores null inputs, + but not elements. :: + + SELECT array_split_into_chunks(ARRAY [1, 2, 3, 4, 5], 2); -- [[1, 2], [3, 4], [5]] + SELECT array_split_into_chunks(ARRAY [1, 2, 3], 5); -- [[1, 2, 3]] + SELECT array_split_into_chunks(ARRAY ['a', 'b', 'c'], 2); -- [['a', 'b'], ['c']] + +.. function:: array_subset(array(T), array(int)) -> array(T) + + Returns an array containing elements from the input array at the specified 1-based indices. + Indices must be positive integers. Invalid indices (out of bounds, zero, or negative) are ignored. + Null elements at valid indices are preserved in the output. Duplicate indices result in duplicate elements in the output. + The output maintains the order of the indices array. :: + + SELECT array_subset(ARRAY[1, 2, 3, 4, 5], ARRAY[1, 3, 5]); -- [1, 3, 5] + SELECT array_subset(ARRAY['a', 'b', 'c'], ARRAY[3, 1, 2]); -- ['c', 'a', 'b'] + SELECT array_subset(ARRAY[1, NULL, 3], ARRAY[2]); -- [NULL] + SELECT array_subset(ARRAY[1, 2, 3], ARRAY[1, 1, 2]); -- [1, 1, 2] + SELECT array_subset(ARRAY[1, 2, 3], ARRAY[5, 0, -1]); -- [] .. function:: array_sum(array(T)) -> bigint/double Returns the sum of all non-null elements of the array. If there is no non-null elements, returns 0. The behaviour is similar to aggregation function sum(). T must be coercible to double. Returns bigint if T is coercible to bigint. Otherwise, returns double. +.. function:: array_top_n(array(T), int) -> array(T) + + Returns an array of the top ``n`` elements from a given ``array``, sorted according to its natural descending order. + If ``n`` is larger than the size of the given ``array``, the returned list will be the same size as the input instead of ``n``. :: + + SELECT array_top_n(ARRAY [1, 100, 2, 5, 3], 3); -- [100, 5, 3] + SELECT array_top_n(ARRAY [1, 100], 5); -- [100, 1] + SELECT array_top_n(ARRAY ['a', 'zzz', 'zz', 'b', 'g', 'f'], 3); -- ['zzz', 'zz', 'g'] + +.. function:: arrays_overlap(x, y) -> boolean + + Tests if arrays ``x`` and ``y`` have any non-null elements in common. + Returns null if there are no non-null elements in common but either array contains null. + For REAL and DOUBLE, NANs (Not-a-Number) are considered equal. + +.. function:: arrays_union(x, y) -> array + + Returns an array of the elements in the union of x and y, without duplicates. + For REAL and DOUBLE, NANs (Not-a-Number) are considered equal. + .. function:: cardinality(x) -> bigint Returns the cardinality (size) of the array ``x``. @@ -252,6 +291,32 @@ Array Functions SELECT contains(ARRAY[ARRAY[2, 3]], ARRAY[2, null]); -- failed: contains does not support arrays with elements that are null or contain null SELECT contains(ARRAY[ARRAY[2, null]], ARRAY[2, 1]); -- failed: contains does not support arrays with elements that are null or contain null +.. function:: dot_product(array(T), array(T)) -> bigint/double + + Computes the dot product of two arrays. The dot product is the sum of element-wise + products of corresponding elements. Both arrays must have the same length. + If either array is null, returns null. If arrays have different lengths, throws an error. + Null elements in arrays are treated as zero. + Returns bigint for integer arrays, double for floating-point arrays. + For empty integer arrays, returns 0. For empty floating-point arrays, returns NaN. :: + + SELECT dot_product(ARRAY[1, 2, 3], ARRAY[4, 5, 6]); -- 32 (1*4 + 2*5 + 3*6) + SELECT dot_product(ARRAY[1.0, 2.0], ARRAY[3.0, 4.0]); -- 11.0 (1.0*3.0 + 2.0*4.0) + SELECT dot_product(ARRAY[1, NULL, 3], ARRAY[4, 5, 6]); -- 22 (1*4 + 0*5 + 3*6) + SELECT dot_product(ARRAY[], ARRAY[]); -- 0 for integer arrays, NaN for floating-point arrays + SELECT dot_product(NULL, ARRAY[1, 2, 3]); -- NULL + +.. function:: dot_product(map(K, V), map(K, V)) -> bigint/double + + Computes the dot product of two maps. For maps, the dot product is computed by + multiplying values with matching keys and summing the results. Keys present in only + one map contribute zero to the result. If either map is null, returns null. + Null values in maps are treated as zero. + Returns bigint for integer value maps, double for floating-point value maps. :: + + SELECT dot_product(MAP(ARRAY[1, 2], ARRAY[10, 20]), MAP(ARRAY[1, 2], ARRAY[3, 4])); -- 110 (10*3 + 20*4) + SELECT dot_product(MAP(ARRAY['a', 'b'], ARRAY[1.0, 2.0]), MAP(ARRAY['a', 'c'], ARRAY[3.0, 4.0])); -- 3.0 (only 'a' matches) + .. function:: element_at(array(E), index) -> E Returns element of ``array`` at given ``index``. @@ -313,6 +378,31 @@ Array Functions Flattens an ``array(array(T))`` to an ``array(T)`` by concatenating the contained arrays. +.. function:: l2_norm(array(T)) -> double + + Returns the Euclidean norm (L2 norm) of an array of numeric values. + The L2 norm is calculated as the square root of the sum of squares of all elements: sqrt(sum(x^2)). + Returns 0.0 for empty arrays. Null elements are skipped in the calculation. + Supports integer and floating point types. :: + + SELECT l2_norm(ARRAY[3, 4]); -- 5.0 + SELECT l2_norm(ARRAY[1, 2, 2]); -- 3.0 + SELECT l2_norm(ARRAY[3.0, 4.0]); -- 5.0 + SELECT l2_norm(ARRAY[]); -- 0.0 + SELECT l2_norm(ARRAY[3, NULL, 4]); -- 5.0 + +.. function:: l2_norm(map(K, V)) -> double + + Returns the Euclidean norm (L2 norm) of the values in a map. + The L2 norm is calculated as the square root of the sum of squares of all values: sqrt(sum(v^2)). + Keys are ignored; only values are used in the calculation. + Returns 0.0 for empty maps. Null values are skipped in the calculation. + Supports maps with numeric value types (integer or floating point). :: + + SELECT l2_norm(MAP(ARRAY['a', 'b'], ARRAY[3, 4])); -- 5.0 + SELECT l2_norm(MAP(ARRAY[1, 2], ARRAY[3.0, 4.0])); -- 5.0 + SELECT l2_norm(MAP(ARRAY[], ARRAY[])); -- 0.0 + .. function:: ngrams(array(T), n) -> array(array(T)) Returns `n-grams `_ for the array. @@ -326,6 +416,15 @@ Array Functions SELECT ngrams(ARRAY[1, 2, 3, 4], 2); -- [[1, 2], [2, 3], [3, 4]] SELECT ngrams(ARRAY["foo", NULL, "bar"], 2); -- [["foo", NULL], [NULL, "bar"]] +.. function:: none_match(array(T), function(T, boolean)) → boolean + + Returns whether no elements of an array match the given predicate. + + Returns true if none of the elements matches the predicate (a special case is when the array is empty); + Returns false if one or more elements match; + Returns NULL if the predicate function returns NULL for one or more elements and false for all other elements. + Throws an exception if the predicate fails for one or more elements and returns false or NULL for the rest. + .. function:: reduce(array(T), initialState S, inputFunction(S,T,S), outputFunction(S,R)) -> R Returns a single value reduced from ``array``. ``inputFunction`` will @@ -348,6 +447,14 @@ Array Functions (s, x) -> CAST(ROW(x + s.sum, s.count + 1) AS ROW(sum DOUBLE, count INTEGER)), s -> IF(s.count = 0, NULL, s.sum / s.count)); +.. function:: remove_nulls(x) -> array + + Remove null values from an array ``array`` :: + + SELECT remove_nulls(ARRAY[1, NULL, 3, NULL]); -- [1, 3] + SELECT remove_nulls(ARRAY[true, false, NULL]); -- [true, false] + SELECT remove_nulls(ARRAY[ARRAY[1, 2], NULL, ARRAY[1, NULL, 3]]); -- [[1, 2], [1, null, 3]] + .. function:: repeat(element, count) -> array(E) Repeat ``element`` for ``count`` times. ``count`` cannot be negative and must be less than or equal to 10000. @@ -356,6 +463,16 @@ Array Functions Returns an array which has the reversed order of the input array. +.. function:: sequence(start, stop) -> array + + Generate a sequence of integers from start to stop, incrementing by 1 if start is less than or equal to stop, + otherwise -1. + +.. function:: sequence(start, stop, step) -> array + :noindex: + + Generate a sequence of integers from start to stop, incrementing by step. + .. function:: shuffle(array(E)) -> array(E) Generate a random permutation of the given ``array`` :: @@ -369,16 +486,6 @@ Array Functions Returns a subarray starting from index ``start``(or starting from the end if ``start`` is negative) with a length of ``length``. -.. function:: sequence(start, stop) -> array - - Generate a sequence of integers from start to stop, incrementing by 1 if start is less than or equal to stop, - otherwise -1. - -.. function:: sequence(start, stop, step) -> array - :noindex: - - Generate a sequence of integers from start to stop, incrementing by step. - .. function:: subscript(array(E), index) -> E Returns element of ``array`` at given ``index``. The index starts from one. @@ -396,6 +503,18 @@ Array Functions SELECT transform(ARRAY ['x', 'abc', 'z'], x -> x || '0'); -- ['x0', 'abc0', 'z0'] SELECT transform(ARRAY [ARRAY [1, NULL, 2], ARRAY[3, NULL]], a -> filter(a, x -> x IS NOT NULL)); -- [[1, 2], [3]] +.. function:: transform_with_index(array(T), function(T,bigint,U)) -> array(U) + + Returns an array that is the result of applying ``function`` to each element of ``array``. + The lambda function receives both the element and its 1-based index as arguments. + This is useful for transformations that need to know the position of each element:: + + SELECT transform_with_index(ARRAY [], (x, i) -> x + i); -- [] + SELECT transform_with_index(ARRAY [5, 6, 7], (x, i) -> x * i); -- [5, 12, 21] + SELECT transform_with_index(ARRAY ['a', 'b', 'c'], (x, i) -> concat(x, cast(i as varchar))); -- ['a1', 'b2', 'c3'] + SELECT transform_with_index(ARRAY [10, 20, 30], (x, i) -> i); -- [1, 2, 3] + SELECT transform_with_index(ARRAY [1, 2, 3], (x, i) -> if(i % 2 = 1, x, x * 2)); -- [1, 4, 3] + .. function:: trim_array(x, n) -> array Remove n elements from the end of ``array``:: @@ -404,14 +523,6 @@ Array Functions SELECT trim_array(ARRAY[1, 2, 3, 4], 2); -- [1, 2] SELECT trim_array(ARRAY[1, 2, 3, 4], 4); -- [] -.. function:: remove_nulls(x) -> array - - Remove null values from an array ``array`` :: - - SELECT remove_nulls(ARRAY[1, NULL, 3, NULL]); -- [1, 3] - SELECT remove_nulls(ARRAY[true, false, NULL]); -- [true, false] - SELECT remove_nulls(ARRAY[ARRAY[1, 2], NULL, ARRAY[1, NULL, 3]]); -- [[1, 2], [1, null, 3]] - .. function:: zip(array(T), array(U),..) -> array(row(T,U, ...)) Returns the merge of the given arrays, element-wise into a single array of rows. diff --git a/velox/docs/functions/presto/binary.rst b/velox/docs/functions/presto/binary.rst index 315ba71be2c..95860626277 100644 --- a/velox/docs/functions/presto/binary.rst +++ b/velox/docs/functions/presto/binary.rst @@ -6,6 +6,22 @@ Binary Functions Computes the crc32 checksum of ``binary``. +.. function:: fnv1_32(binary) -> integer + + Computes the FNV-1 32-bit hash of ``binary``. + +.. function:: fnv1_64(binary) -> bigint + + Computes the FNV-1 64-bit hash of ``binary``. + +.. function:: fnv1a_32(binary) -> integer + + Computes the FNV-1a 32-bit hash of ``binary``. + +.. function:: fnv1a_64(binary) -> bigint + + Computes the FNV-1a 64-bit hash of ``binary``. + .. function:: from_base64(string) -> varbinary Decodes a Base64-encoded ``string`` back into its original binary form. @@ -89,6 +105,11 @@ Binary Functions Computes the md5 hash of ``binary``. +.. function:: murmur3_x64_128(binary) -> varbinary + + Computes a 128-bit hash of ``binary`` that is equivalent to the 128-bit + MurmurHash3 algorithm, often called MurmurHash3_x64_128 or Murmur3F. + .. function:: rpad(binary, size, padbinary) -> varbinary :noindex: diff --git a/velox/docs/functions/presto/conversion.rst b/velox/docs/functions/presto/conversion.rst index d33e6c492d4..5d8f0ac3c6c 100644 --- a/velox/docs/functions/presto/conversion.rst +++ b/velox/docs/functions/presto/conversion.rst @@ -30,7 +30,7 @@ are supported if the conversion of their element types are supported. In additio supported conversions to/from JSON are listed in :doc:`json`. .. list-table:: - :widths: 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 + :widths: 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 :header-rows: 1 * - @@ -52,6 +52,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - ipprefix - tdigest - qdigest + - setdigest + - khyperloglog * - tinyint - Y - Y @@ -71,6 +73,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - - - + - + - * - smallint - Y - Y @@ -90,6 +94,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - - - + - + - * - integer - Y - Y @@ -109,6 +115,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - - - + - + - * - bigint - Y - Y @@ -128,6 +136,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - - - + - + - * - boolean - Y - Y @@ -147,6 +157,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - - - + - + - * - real - Y - Y @@ -166,6 +178,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - - - + - + - * - double - Y - Y @@ -185,6 +199,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - - - + - + - * - varchar - Y - Y @@ -204,6 +220,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - Y - - + - + - * - varbinary - - @@ -223,6 +241,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - Y - Y - Y + - Y + - Y * - timestamp - - @@ -242,6 +262,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - - - + - + - * - timestamp with time zone - - @@ -261,6 +283,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - - - + - + - * - date - - @@ -280,6 +304,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - - - + - + - * - interval day to second - - @@ -299,6 +325,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - - - + - + - * - decimal - Y - Y @@ -318,6 +346,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - - - + - + - * - ipaddress - - @@ -337,6 +367,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - Y - - + - + - * - ipprefix - - @@ -356,6 +388,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - Y - - + - + - * - tdigest - - @@ -375,6 +409,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - - - + - + - * - qdigest - - @@ -394,6 +430,50 @@ supported conversions to/from JSON are listed in :doc:`json`. - - - + - + - + * - setdigest + - + - + - + - + - + - + - + - + - Y + - + - + - + - + - + - + - + - + - + - + - + * - khyperloglog + - + - + - + - + - + - + - + - + - Y + - + - + - + - + - + - + - + - + - + - + - Cast to Integral Types ---------------------- @@ -870,6 +950,26 @@ This allows quantile digests to be stored and retrieved for later use. SELECT cast(qdigest_agg(cast(1.0 as double)) as varbinary); -- AHsUrkfheoQ/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAPA/AAAAAAAA8D8BAAAAAAAAAAAAAPA/AAAAAAAA8L8= +From SETDIGEST +^^^^^^^^^^^^^^ + +Returns the SetDigest as a varbinary string containing the serialized representation of the SetDigest data structure. +This allows SetDigests to be stored and retrieved for later use. + +:: + + SELECT cast(make_set_digest(1) as varbinary); + +From KHYPERLOGLOG +^^^^^^^^^^^^^^^^^ + +Returns the KHyperLogLog as a varbinary string containing the serialized representation of the KHyperLogLog data structure. +This allows KHyperLogLogs to be stored and retrieved for later use. + +:: + + SELECT cast(khyperloglog_agg(1, 123) as varbinary); + Cast to TIMESTAMP ----------------- @@ -1324,6 +1424,32 @@ This allows previously stored quantile digests to be restored for use. SELECT cast(stored_qdigest_binary as qdigest(real)); SELECT cast(stored_qdigest_binary as qdigest(double)); +Cast to SETDIGEST +----------------- + +From VARBINARY +^^^^^^^^^^^^^^ + +Returns a SetDigest reconstructed from the varbinary string containing the serialized representation. +This allows previously stored SetDigests to be restored for use. + +:: + + SELECT cast(stored_setdigest_binary as setdigest); + +Cast to KHYPERLOGLOG +-------------------- + +From VARBINARY +^^^^^^^^^^^^^^ + +Returns a KHyperLogLog reconstructed from the varbinary string containing the serialized representation. +This allows previously stored KHyperLogLogs to be restored for use. + +:: + + SELECT cast(stored_khyperloglog_binary as khyperloglog); + Cast to IPPREFIX ---------------- diff --git a/velox/docs/functions/presto/coverage.rst b/velox/docs/functions/presto/coverage.rst index aacb76f9463..6b2d88fcf22 100644 --- a/velox/docs/functions/presto/coverage.rst +++ b/velox/docs/functions/presto/coverage.rst @@ -62,12 +62,14 @@ Here is a list of all scalar, aggregate, and window functions from Presto, with table.coverage tr:nth-child(7) td:nth-child(9) {background-color: #6BA81E;} table.coverage tr:nth-child(8) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(8) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(8) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(8) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(8) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(8) td:nth-child(9) {background-color: #6BA81E;} table.coverage tr:nth-child(9) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(9) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(9) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(9) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(9) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(9) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(9) td:nth-child(9) {background-color: #6BA81E;} @@ -108,6 +110,7 @@ Here is a list of all scalar, aggregate, and window functions from Presto, with table.coverage tr:nth-child(15) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(15) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(15) td:nth-child(7) {background-color: #6BA81E;} + table.coverage tr:nth-child(16) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(16) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(16) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(16) td:nth-child(7) {background-color: #6BA81E;} @@ -117,9 +120,10 @@ Here is a list of all scalar, aggregate, and window functions from Presto, with table.coverage tr:nth-child(17) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(17) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(18) td:nth-child(1) {background-color: #6BA81E;} - table.coverage tr:nth-child(18) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(18) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(18) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(18) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(18) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(19) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(19) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(19) td:nth-child(3) {background-color: #6BA81E;} @@ -158,14 +162,16 @@ Here is a list of all scalar, aggregate, and window functions from Presto, with table.coverage tr:nth-child(25) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(25) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(25) td:nth-child(4) {background-color: #6BA81E;} - table.coverage tr:nth-child(25) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(25) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(26) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(26) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(26) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(26) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(27) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(27) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(27) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(27) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(27) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(27) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(28) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(28) td:nth-child(3) {background-color: #6BA81E;} @@ -175,7 +181,7 @@ Here is a list of all scalar, aggregate, and window functions from Presto, with table.coverage tr:nth-child(29) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(29) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(29) td:nth-child(4) {background-color: #6BA81E;} - table.coverage tr:nth-child(29) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(29) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(30) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(30) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(30) td:nth-child(4) {background-color: #6BA81E;} @@ -184,6 +190,8 @@ Here is a list of all scalar, aggregate, and window functions from Presto, with table.coverage tr:nth-child(31) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(31) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(31) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(31) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(31) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(32) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(32) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(32) td:nth-child(3) {background-color: #6BA81E;} @@ -193,6 +201,7 @@ Here is a list of all scalar, aggregate, and window functions from Presto, with table.coverage tr:nth-child(33) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(33) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(33) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(34) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(34) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(34) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(34) td:nth-child(5) {background-color: #6BA81E;} @@ -204,11 +213,13 @@ Here is a list of all scalar, aggregate, and window functions from Presto, with table.coverage tr:nth-child(36) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(36) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(36) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(36) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(37) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(37) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(37) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(37) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(37) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(37) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(38) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(38) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(38) td:nth-child(3) {background-color: #6BA81E;} @@ -227,6 +238,7 @@ Here is a list of all scalar, aggregate, and window functions from Presto, with table.coverage tr:nth-child(40) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(40) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(40) td:nth-child(7) {background-color: #6BA81E;} + table.coverage tr:nth-child(41) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(41) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(41) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(41) td:nth-child(4) {background-color: #6BA81E;} @@ -236,6 +248,7 @@ Here is a list of all scalar, aggregate, and window functions from Presto, with table.coverage tr:nth-child(42) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(42) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(42) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(42) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(42) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(43) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(43) td:nth-child(2) {background-color: #6BA81E;} @@ -248,8 +261,10 @@ Here is a list of all scalar, aggregate, and window functions from Presto, with table.coverage tr:nth-child(44) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(44) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(44) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(44) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(45) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(45) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(45) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(45) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(45) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(45) td:nth-child(7) {background-color: #6BA81E;} @@ -260,11 +275,12 @@ Here is a list of all scalar, aggregate, and window functions from Presto, with table.coverage tr:nth-child(46) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(46) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(47) td:nth-child(1) {background-color: #6BA81E;} - table.coverage tr:nth-child(47) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(47) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(47) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(47) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(47) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(48) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(48) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(48) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(48) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(48) td:nth-child(5) {background-color: #6BA81E;} @@ -277,21 +293,27 @@ Here is a list of all scalar, aggregate, and window functions from Presto, with table.coverage tr:nth-child(49) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(50) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(50) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(50) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(50) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(50) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(50) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(51) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(51) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(51) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(51) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(51) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(51) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(52) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(52) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(52) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(52) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(52) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(52) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(53) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(53) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(53) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(53) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(53) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(53) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(54) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(54) td:nth-child(3) {background-color: #6BA81E;} @@ -299,7 +321,6 @@ Here is a list of all scalar, aggregate, and window functions from Presto, with table.coverage tr:nth-child(54) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(54) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(55) td:nth-child(1) {background-color: #6BA81E;} - table.coverage tr:nth-child(55) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(55) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(55) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(55) td:nth-child(5) {background-color: #6BA81E;} @@ -311,29 +332,30 @@ Here is a list of all scalar, aggregate, and window functions from Presto, with table.coverage tr:nth-child(56) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(56) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(57) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(57) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(57) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(57) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(57) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(57) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(58) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(58) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(58) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(58) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(58) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(58) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(59) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(59) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(59) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(59) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(59) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(59) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(60) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(60) td:nth-child(2) {background-color: #6BA81E;} - table.coverage tr:nth-child(60) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(60) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(60) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(60) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(61) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(61) td:nth-child(2) {background-color: #6BA81E;} - table.coverage tr:nth-child(61) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(61) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(61) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(61) td:nth-child(7) {background-color: #6BA81E;} @@ -349,22 +371,25 @@ Here is a list of all scalar, aggregate, and window functions from Presto, with table.coverage tr:nth-child(63) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(63) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(63) td:nth-child(7) {background-color: #6BA81E;} + table.coverage tr:nth-child(64) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(64) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(64) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(64) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(64) td:nth-child(5) {background-color: #6BA81E;} - table.coverage tr:nth-child(65) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(64) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(65) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(65) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(65) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(65) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(65) td:nth-child(7) {background-color: #6BA81E;} + table.coverage tr:nth-child(66) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(66) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(66) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(66) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(66) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(66) td:nth-child(7) {background-color: #6BA81E;} - table.coverage tr:nth-child(67) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(67) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(67) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(67) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(67) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(68) td:nth-child(1) {background-color: #6BA81E;} @@ -391,6 +416,7 @@ Here is a list of all scalar, aggregate, and window functions from Presto, with table.coverage tr:nth-child(71) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(72) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(72) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(72) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(72) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(72) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(72) td:nth-child(7) {background-color: #6BA81E;} @@ -402,33 +428,44 @@ Here is a list of all scalar, aggregate, and window functions from Presto, with table.coverage tr:nth-child(73) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(74) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(74) td:nth-child(2) {background-color: #6BA81E;} - table.coverage tr:nth-child(74) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(74) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(74) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(74) td:nth-child(7) {background-color: #6BA81E;} + table.coverage tr:nth-child(75) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(75) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(75) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(75) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(75) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(75) td:nth-child(7) {background-color: #6BA81E;} + table.coverage tr:nth-child(76) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(76) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(76) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(76) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(76) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(76) td:nth-child(7) {background-color: #6BA81E;} + table.coverage tr:nth-child(77) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(77) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(77) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(77) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(77) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(77) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(78) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(78) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(78) td:nth-child(3) {background-color: #6BA81E;} - table.coverage tr:nth-child(78) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(78) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(78) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(79) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(79) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(79) td:nth-child(3) {background-color: #6BA81E;} - table.coverage tr:nth-child(79) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(79) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(80) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(80) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(80) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(80) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(81) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(81) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(81) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(81) td:nth-child(4) {background-color: #6BA81E;} .. table:: @@ -438,83 +475,85 @@ Here is a list of all scalar, aggregate, and window functions from Presto, with ======================================== ======================================== ======================================== ======================================== ======================================== == ======================================== == ======================================== Scalar Functions Aggregate Functions Window Functions ================================================================================================================================================================================================================ == ======================================== == ======================================== - :func:`abs` :func:`date_diff` :func:`ip_subnet_range` :func:`random` :func:`st_numgeometries` :func:`approx_distinct` :func:`cume_dist` - :func:`acos` :func:`date_format` :func:`is_finite` :func:`reduce` :func:`st_numinteriorring` :func:`approx_most_frequent` :func:`dense_rank` - :func:`all_match` :func:`date_parse` :func:`is_infinite` :func:`regexp_extract` :func:`st_numpoints` :func:`approx_percentile` :func:`first_value` - :func:`any_keys_match` :func:`date_trunc` :func:`is_json_scalar` :func:`regexp_extract_all` :func:`st_overlaps` :func:`approx_set` :func:`lag` - :func:`any_match` :func:`day` :func:`is_nan` :func:`regexp_like` :func:`st_point` :func:`arbitrary` :func:`last_value` - :func:`any_values_match` :func:`day_of_month` :func:`is_private_ip` :func:`regexp_replace` :func:`st_pointn` :func:`array_agg` :func:`lead` - :func:`array_average` :func:`day_of_week` :func:`is_subnet_of` :func:`regexp_split` :func:`st_points` :func:`avg` :func:`nth_value` - :func:`array_cum_sum` :func:`day_of_year` jaccard_index regress :func:`st_polygon` :func:`bitwise_and_agg` :func:`ntile` - :func:`array_distinct` :func:`degrees` :func:`json_array_contains` reidentification_potential :func:`st_relate` :func:`bitwise_or_agg` :func:`percent_rank` - :func:`array_duplicates` :func:`dow` :func:`json_array_get` :func:`remove_nulls` :func:`st_startpoint` :func:`bool_and` :func:`rank` - :func:`array_except` :func:`doy` :func:`json_array_length` render :func:`st_symdifference` :func:`bool_or` :func:`row_number` - :func:`array_frequency` :func:`e` :func:`json_extract` :func:`repeat` :func:`st_touches` :func:`checksum` - :func:`array_has_duplicates` :func:`element_at` :func:`json_extract_scalar` :func:`replace` :func:`st_union` :func:`classification_fall_out` - :func:`array_intersect` :func:`empty_approx_set` :func:`json_format` :func:`replace_first` :func:`st_within` :func:`classification_miss_rate` - :func:`array_join` :func:`ends_with` :func:`json_parse` :func:`reverse` :func:`st_x` :func:`classification_precision` - array_least_frequent enum_key :func:`json_size` rgb :func:`st_xmax` :func:`classification_recall` - :func:`array_max` :func:`exp` key_sampling_percent :func:`round` :func:`st_xmin` :func:`classification_thresholds` - :func:`array_max_by` expand_envelope :func:`laplace_cdf` :func:`rpad` :func:`st_y` convex_hull_agg - :func:`array_min` :func:`f_cdf` :func:`last_day_of_month` :func:`rtrim` :func:`st_ymax` :func:`corr` - :func:`array_min_by` features :func:`least` :func:`scale_qdigest` :func:`st_ymin` :func:`count` - :func:`array_normalize` :func:`filter` :func:`length` :func:`second` :func:`starts_with` :func:`count_if` - :func:`array_position` :func:`filter` :func:`levenshtein_distance` :func:`secure_rand` :func:`strpos` :func:`covar_pop` - :func:`array_remove` :func:`find_first` :func:`line_interpolate_point` :func:`secure_random` :func:`strrpos` :func:`covar_samp` - :func:`array_sort` :func:`find_first_index` :func:`line_locate_point` :func:`sequence` :func:`substr` differential_entropy - :func:`array_sort_desc` :func:`flatten` :func:`ln` :func:`sha1` :func:`tan` :func:`entropy` - array_split_into_chunks :func:`flatten_geometry_collections` localtime :func:`sha256` :func:`tanh` evaluate_classifier_predictions - :func:`array_sum` :func:`floor` localtimestamp :func:`sha512` tdigest_agg :func:`every` - :func:`array_top_n` fnv1_32 :func:`log10` :func:`shuffle` :func:`timezone_hour` :func:`geometric_mean` - :func:`array_union` fnv1_64 :func:`log2` :func:`sign` :func:`timezone_minute` geometry_union_agg - :func:`arrays_overlap` fnv1a_32 :func:`lower` :func:`simplify_geometry` :func:`to_base` :func:`histogram` - :func:`asin` fnv1a_64 :func:`lpad` :func:`sin` to_base32 khyperloglog_agg - :func:`atan` :func:`format_datetime` :func:`ltrim` sketch_kll_quantile :func:`to_base64` :func:`kurtosis` - :func:`atan2` :func:`from_base` :func:`map` sketch_kll_rank :func:`to_base64url` learn_classifier - bar from_base32 :func:`map_concat` :func:`slice` :func:`to_big_endian_32` learn_libsvm_classifier - :func:`beta_cdf` :func:`from_base64` :func:`map_entries` spatial_partitions :func:`to_big_endian_64` learn_libsvm_regressor - :func:`bing_tile` :func:`from_base64url` :func:`map_filter` :func:`split` to_geometry learn_regressor - :func:`bing_tile_at` :func:`from_big_endian_32` :func:`map_from_entries` :func:`split_part` :func:`to_hex` make_set_digest - :func:`bing_tile_children` :func:`from_big_endian_64` :func:`map_keys` :func:`split_to_map` :func:`to_ieee754_32` :func:`map_agg` - :func:`bing_tile_coordinates` :func:`from_hex` :func:`map_keys_by_top_n_values` :func:`split_to_multimap` :func:`to_ieee754_64` :func:`map_union` - :func:`bing_tile_parent` :func:`from_ieee754_32` :func:`map_normalize` :func:`spooky_hash_v2_32` :func:`to_iso8601` :func:`map_union_sum` - bing_tile_polygon :func:`from_ieee754_64` :func:`map_remove_null_values` :func:`spooky_hash_v2_64` :func:`to_milliseconds` :func:`max` - :func:`bing_tile_quadkey` :func:`from_iso8601_date` :func:`map_subset` :func:`sqrt` to_spherical_geography :func:`max_by` - :func:`bing_tile_zoom_level` :func:`from_iso8601_timestamp` :func:`map_top_n` :func:`st_area` :func:`to_unixtime` :func:`merge` - :func:`bing_tiles_around` :func:`from_unixtime` :func:`map_top_n_keys` :func:`st_asbinary` :func:`to_utf8` merge_set_digest - :func:`binomial_cdf` :func:`from_utf8` map_top_n_keys_by_value :func:`st_astext` :func:`trail` :func:`min` - :func:`bit_count` :func:`gamma_cdf` :func:`map_top_n_values` :func:`st_boundary` :func:`transform` :func:`min_by` - :func:`bitwise_and` geometry_as_geojson :func:`map_values` :func:`st_buffer` :func:`transform_keys` :func:`multimap_agg` - :func:`bitwise_arithmetic_shift_right` geometry_from_geojson :func:`map_zip_with` :func:`st_centroid` :func:`transform_values` :func:`noisy_avg_gaussian` - :func:`bitwise_left_shift` :func:`geometry_invalid_reason` :func:`md5` :func:`st_contains` :func:`trim` :func:`noisy_count_gaussian` - :func:`bitwise_logical_shift_right` :func:`geometry_nearest_points` merge_hll :func:`st_convexhull` :func:`trim_array` :func:`noisy_count_if_gaussian` - :func:`bitwise_not` geometry_to_bing_tiles merge_khll :func:`st_coorddim` :func:`truncate` :func:`noisy_sum_gaussian` - :func:`bitwise_or` geometry_to_dissolved_bing_tiles :func:`millisecond` :func:`st_crosses` :func:`typeof` :func:`numeric_histogram` - :func:`bitwise_right_shift` geometry_union :func:`minute` :func:`st_difference` uniqueness_distribution :func:`qdigest_agg` - :func:`bitwise_right_shift_arithmetic` great_circle_distance :func:`mod` :func:`st_dimension` :func:`upper` :func:`reduce_agg` - :func:`bitwise_shift_left` :func:`greatest` :func:`month` :func:`st_disjoint` :func:`url_decode` :func:`regr_avgx` - :func:`bitwise_xor` :func:`hamming_distance` :func:`multimap_from_entries` :func:`st_distance` :func:`url_encode` :func:`regr_avgy` - :func:`cardinality` hash_counts :func:`murmur3_x64_128` :func:`st_endpoint` :func:`url_extract_fragment` :func:`regr_count` - :func:`cauchy_cdf` :func:`hmac_md5` myanmar_font_encoding :func:`st_envelope` :func:`url_extract_host` :func:`regr_intercept` - :func:`cbrt` :func:`hmac_sha1` myanmar_normalize_unicode :func:`st_envelopeaspts` :func:`url_extract_parameter` :func:`regr_r2` - :func:`ceil` :func:`hmac_sha256` :func:`nan` :func:`st_equals` :func:`url_extract_path` :func:`regr_slope` - :func:`ceiling` :func:`hmac_sha512` :func:`ngrams` :func:`st_exteriorring` :func:`url_extract_port` :func:`regr_sxx` - :func:`chi_squared_cdf` :func:`hour` :func:`no_keys_match` :func:`st_geometries` :func:`url_extract_protocol` :func:`regr_sxy` - :func:`chr` :func:`infinity` :func:`no_values_match` :func:`st_geometryfromtext` :func:`url_extract_query` :func:`regr_syy` - classify intersection_cardinality :func:`none_match` :func:`st_geometryn` :func:`uuid` reservoir_sample - :func:`codepoint` :func:`inverse_beta_cdf` :func:`normal_cdf` :func:`st_geometrytype` :func:`value_at_quantile` :func:`set_agg` - color :func:`inverse_binomial_cdf` :func:`normalize` :func:`st_geomfrombinary` :func:`values_at_quantiles` :func:`set_union` - :func:`combinations` :func:`inverse_cauchy_cdf` now :func:`st_interiorringn` :func:`week` sketch_kll - :func:`concat` :func:`inverse_chi_squared_cdf` :func:`parse_datetime` :func:`st_interiorrings` :func:`week_of_year` sketch_kll_with_k - :func:`contains` :func:`inverse_f_cdf` :func:`parse_duration` :func:`st_intersection` :func:`weibull_cdf` :func:`skewness` - :func:`cos` :func:`inverse_gamma_cdf` :func:`parse_presto_data_size` :func:`st_intersects` :func:`width_bucket` spatial_partitioning - :func:`cosh` :func:`inverse_laplace_cdf` :func:`pi` :func:`st_isclosed` :func:`wilson_interval_lower` :func:`stddev` - :func:`cosine_similarity` :func:`inverse_normal_cdf` pinot_binary_decimal_to_double :func:`st_isempty` :func:`wilson_interval_upper` :func:`stddev_pop` - :func:`crc32` :func:`inverse_poisson_cdf` :func:`poisson_cdf` :func:`st_isring` :func:`word_stem` :func:`stddev_samp` - :func:`current_date` :func:`inverse_weibull_cdf` :func:`pow` :func:`st_issimple` :func:`xxhash64` :func:`sum` - current_time :func:`ip_prefix` :func:`power` :func:`st_isvalid` :func:`year` :func:`tdigest_agg` - current_timestamp :func:`ip_prefix_collapse` :func:`quantile_at_value` :func:`st_length` :func:`year_of_week` :func:`var_pop` - current_timezone :func:`ip_prefix_subnets` :func:`quarter` st_linefromtext :func:`yow` :func:`var_samp` - :func:`date` :func:`ip_subnet_max` :func:`radians` st_linestring :func:`zip` :func:`variance` - :func:`date_add` :func:`ip_subnet_min` :func:`rand` st_multipoint :func:`zip_with` + :func:`abs` :func:`date_format` :func:`ip_subnet_range` :func:`random` :func:`st_numpoints` :func:`approx_distinct` :func:`cume_dist` + :func:`acos` :func:`date_parse` :func:`is_finite` :func:`reduce` :func:`st_overlaps` :func:`approx_most_frequent` :func:`dense_rank` + :func:`all_match` :func:`date_trunc` :func:`is_infinite` :func:`regexp_extract` :func:`st_point` :func:`approx_percentile` :func:`first_value` + :func:`any_keys_match` :func:`day` :func:`is_json_scalar` :func:`regexp_extract_all` :func:`st_pointn` :func:`approx_set` :func:`lag` + :func:`any_match` :func:`day_of_month` :func:`is_nan` :func:`regexp_like` :func:`st_points` :func:`arbitrary` :func:`last_value` + :func:`any_values_match` :func:`day_of_week` :func:`is_private_ip` :func:`regexp_replace` :func:`st_polygon` :func:`array_agg` :func:`lead` + :func:`array_average` :func:`day_of_year` :func:`is_subnet_of` :func:`regexp_split` :func:`st_relate` :func:`avg` :func:`nth_value` + :func:`array_cum_sum` :func:`degrees` :func:`jaccard_index` regress :func:`st_startpoint` :func:`bitwise_and_agg` :func:`ntile` + :func:`array_distinct` :func:`dot_product` :func:`json_array_contains` :func:`reidentification_potential` :func:`st_symdifference` :func:`bitwise_or_agg` :func:`percent_rank` + :func:`array_duplicates` :func:`dow` :func:`json_array_get` :func:`remove_nulls` :func:`st_touches` :func:`bool_and` :func:`rank` + :func:`array_except` :func:`doy` :func:`json_array_length` render :func:`st_union` :func:`bool_or` :func:`row_number` + :func:`array_frequency` :func:`e` :func:`json_extract` :func:`repeat` :func:`st_within` :func:`checksum` + :func:`array_has_duplicates` :func:`element_at` :func:`json_extract_scalar` :func:`replace` :func:`st_x` :func:`classification_fall_out` + :func:`array_intersect` :func:`empty_approx_set` :func:`json_format` :func:`replace_first` :func:`st_xmax` :func:`classification_miss_rate` + :func:`array_join` :func:`ends_with` :func:`json_parse` :func:`reverse` :func:`st_xmin` :func:`classification_precision` + array_least_frequent :func:`enum_key` :func:`json_size` rgb :func:`st_y` :func:`classification_recall` + :func:`array_max` :func:`exp` key_sampling_percent :func:`round` :func:`st_ymax` :func:`classification_thresholds` + :func:`array_max_by` :func:`expand_envelope` l2_squared :func:`rpad` :func:`st_ymin` :func:`convex_hull_agg` + :func:`array_min` :func:`f_cdf` :func:`laplace_cdf` :func:`rtrim` :func:`starts_with` :func:`corr` + :func:`array_min_by` features :func:`last_day_of_month` :func:`scale_qdigest` :func:`strpos` :func:`count` + :func:`array_normalize` :func:`filter` :func:`least` :func:`second` :func:`strrpos` :func:`count_if` + :func:`array_position` :func:`filter` :func:`length` :func:`secure_rand` :func:`substr` :func:`covar_pop` + :func:`array_remove` :func:`find_first` :func:`levenshtein_distance` :func:`secure_random` :func:`tan` :func:`covar_samp` + :func:`array_sort` :func:`find_first_index` :func:`line_interpolate_point` :func:`sequence` :func:`tanh` differential_entropy + :func:`array_sort_desc` :func:`flatten` :func:`line_locate_point` :func:`sha1` tdigest_agg :func:`entropy` + array_split_into_chunks :func:`flatten_geometry_collections` :func:`ln` :func:`sha256` :func:`timezone_hour` evaluate_classifier_predictions + :func:`array_sum` :func:`floor` :func:`localtime` :func:`sha512` :func:`timezone_minute` :func:`every` + :func:`array_top_n` :func:`fnv1_32` :func:`localtimestamp` :func:`shuffle` :func:`to_base` :func:`geometric_mean` + :func:`array_union` :func:`fnv1_64` :func:`log10` :func:`sign` to_base32 geometry_union_agg + :func:`arrays_overlap` :func:`fnv1a_32` :func:`log2` :func:`simplify_geometry` :func:`to_base64` :func:`histogram` + :func:`asin` :func:`fnv1a_64` :func:`longest_common_prefix` :func:`sin` :func:`to_base64url` khyperloglog_agg + :func:`atan` :func:`format_datetime` :func:`lower` sketch_kll_quantile :func:`to_big_endian_32` :func:`kurtosis` + :func:`atan2` :func:`from_base` :func:`lpad` sketch_kll_rank :func:`to_big_endian_64` learn_classifier + bar :func:`from_base32` :func:`ltrim` :func:`slice` :func:`to_geometry` learn_libsvm_classifier + :func:`beta_cdf` :func:`from_base64` :func:`map` spatial_partitions :func:`to_hex` learn_libsvm_regressor + :func:`bing_tile` :func:`from_base64url` :func:`map_concat` :func:`split` :func:`to_ieee754_32` learn_regressor + :func:`bing_tile_at` :func:`from_big_endian_32` :func:`map_entries` :func:`split_part` :func:`to_ieee754_64` :func:`make_set_digest` + :func:`bing_tile_children` :func:`from_big_endian_64` :func:`map_filter` :func:`split_to_map` :func:`to_iso8601` :func:`map_agg` + :func:`bing_tile_coordinates` :func:`from_hex` :func:`map_from_entries` :func:`split_to_multimap` :func:`to_milliseconds` :func:`map_union` + :func:`bing_tile_parent` :func:`from_ieee754_32` :func:`map_keys` :func:`spooky_hash_v2_32` :func:`to_spherical_geography` :func:`map_union_sum` + :func:`bing_tile_polygon` :func:`from_ieee754_64` :func:`map_keys_by_top_n_values` :func:`spooky_hash_v2_64` :func:`to_unixtime` :func:`max` + :func:`bing_tile_quadkey` :func:`from_iso8601_date` :func:`map_normalize` :func:`sqrt` :func:`to_utf8` :func:`max_by` + :func:`bing_tile_zoom_level` :func:`from_iso8601_timestamp` :func:`map_remove_null_values` :func:`st_area` :func:`trail` :func:`merge` + :func:`bing_tiles_around` :func:`from_unixtime` :func:`map_subset` :func:`st_asbinary` :func:`transform` :func:`merge_set_digest` + :func:`binomial_cdf` :func:`from_utf8` :func:`map_top_n` :func:`st_astext` :func:`transform_keys` :func:`min` + :func:`bit_count` :func:`gamma_cdf` :func:`map_top_n_keys` :func:`st_boundary` :func:`transform_values` :func:`min_by` + :func:`bit_length` :func:`geometry_as_geojson` map_top_n_keys_by_value :func:`st_buffer` :func:`trim` :func:`multimap_agg` + :func:`bitwise_and` :func:`geometry_from_geojson` :func:`map_top_n_values` :func:`st_centroid` :func:`trim_array` :func:`noisy_avg_gaussian` + :func:`bitwise_arithmetic_shift_right` :func:`geometry_invalid_reason` :func:`map_values` :func:`st_contains` :func:`truncate` :func:`noisy_count_gaussian` + :func:`bitwise_left_shift` :func:`geometry_nearest_points` :func:`map_zip_with` :func:`st_convexhull` :func:`typeof` :func:`noisy_count_if_gaussian` + :func:`bitwise_logical_shift_right` :func:`geometry_to_bing_tiles` :func:`md5` :func:`st_coorddim` :func:`uniqueness_distribution` :func:`noisy_sum_gaussian` + :func:`bitwise_not` :func:`geometry_to_dissolved_bing_tiles` :func:`merge_hll` :func:`st_crosses` :func:`upper` :func:`numeric_histogram` + :func:`bitwise_or` :func:`geometry_union` :func:`merge_khll` :func:`st_difference` :func:`url_decode` :func:`qdigest_agg` + :func:`bitwise_right_shift` google_polyline_decode :func:`millisecond` :func:`st_dimension` :func:`url_encode` :func:`reduce_agg` + :func:`bitwise_right_shift_arithmetic` google_polyline_encode :func:`minute` :func:`st_disjoint` :func:`url_extract_fragment` :func:`regr_avgx` + :func:`bitwise_shift_left` :func:`great_circle_distance` :func:`mod` :func:`st_distance` :func:`url_extract_host` :func:`regr_avgy` + :func:`bitwise_xor` :func:`greatest` :func:`month` :func:`st_endpoint` :func:`url_extract_parameter` :func:`regr_count` + :func:`cardinality` :func:`hamming_distance` :func:`multimap_from_entries` :func:`st_envelope` :func:`url_extract_path` :func:`regr_intercept` + :func:`cauchy_cdf` :func:`hash_counts` :func:`murmur3_x64_128` :func:`st_envelopeaspts` :func:`url_extract_port` :func:`regr_r2` + :func:`cbrt` :func:`hmac_md5` myanmar_font_encoding :func:`st_equals` :func:`url_extract_protocol` :func:`regr_slope` + :func:`ceil` :func:`hmac_sha1` myanmar_normalize_unicode :func:`st_exteriorring` :func:`url_extract_query` :func:`regr_sxx` + :func:`ceiling` :func:`hmac_sha256` :func:`nan` :func:`st_geometries` :func:`uuid` :func:`regr_sxy` + :func:`chi_squared_cdf` :func:`hmac_sha512` :func:`ngrams` :func:`st_geometryfromtext` :func:`value_at_quantile` :func:`regr_syy` + :func:`chr` :func:`hour` :func:`no_keys_match` :func:`st_geometryn` :func:`values_at_quantiles` :func:`reservoir_sample` + classify :func:`infinity` :func:`no_values_match` :func:`st_geometrytype` :func:`week` :func:`set_agg` + :func:`codepoint` :func:`intersection_cardinality` :func:`none_match` :func:`st_geomfrombinary` :func:`week_of_year` :func:`set_union` + color :func:`inverse_beta_cdf` :func:`normal_cdf` :func:`st_interiorringn` :func:`weibull_cdf` sketch_kll + :func:`combinations` :func:`inverse_binomial_cdf` :func:`normalize` :func:`st_interiorrings` :func:`width_bucket` sketch_kll_with_k + :func:`concat` :func:`inverse_cauchy_cdf` :func:`now` :func:`st_intersection` :func:`wilson_interval_lower` :func:`skewness` + :func:`contains` :func:`inverse_chi_squared_cdf` :func:`parse_datetime` :func:`st_intersects` :func:`wilson_interval_upper` spatial_partitioning + :func:`cos` :func:`inverse_f_cdf` :func:`parse_duration` :func:`st_isclosed` :func:`word_stem` :func:`stddev` + :func:`cosh` :func:`inverse_gamma_cdf` :func:`parse_presto_data_size` :func:`st_isempty` :func:`xxhash64` :func:`stddev_pop` + :func:`cosine_similarity` :func:`inverse_laplace_cdf` :func:`pi` :func:`st_isring` :func:`year` :func:`stddev_samp` + :func:`crc32` :func:`inverse_normal_cdf` pinot_binary_decimal_to_double :func:`st_issimple` :func:`year_of_week` :func:`sum` + :func:`current_date` :func:`inverse_poisson_cdf` :func:`poisson_cdf` :func:`st_isvalid` :func:`yow` :func:`tdigest_agg` + :func:`current_time` :func:`inverse_weibull_cdf` :func:`pow` :func:`st_length` :func:`zip` :func:`var_pop` + :func:`current_timestamp` :func:`ip_prefix` :func:`power` :func:`st_linefromtext` :func:`zip_with` :func:`var_samp` + :func:`current_timezone` :func:`ip_prefix_collapse` :func:`quantile_at_value` :func:`st_linestring` :func:`variance` + :func:`date` :func:`ip_prefix_subnets` :func:`quarter` :func:`st_multipoint` + :func:`date_add` :func:`ip_subnet_max` :func:`radians` :func:`st_numgeometries` + :func:`date_diff` :func:`ip_subnet_min` :func:`rand` :func:`st_numinteriorring` ======================================== ======================================== ======================================== ======================================== ======================================== == ======================================== == ======================================== diff --git a/velox/docs/functions/presto/datetime.rst b/velox/docs/functions/presto/datetime.rst index 98e42dac552..5703b8baf10 100644 --- a/velox/docs/functions/presto/datetime.rst +++ b/velox/docs/functions/presto/datetime.rst @@ -144,7 +144,10 @@ Date and Time Functions .. function:: from_unixtime(unixtime) -> timestamp - Returns the UNIX timestamp ``unixtime`` as a timestamp. + Returns the UNIX timestamp ``unixtime`` as a timestamp. If the + :doc:`adjust_timestamp_to_session_timezone <../../configs>` property is set + to true, then the timestamp is adjusted to the time zone specified in + :doc:`session_timezone <../../configs>`. .. function:: from_unixtime(unixtime, string) -> timestamp with time zone :noindex: @@ -158,6 +161,10 @@ Date and Time Functions using ``hours`` and ``minutes`` for the time zone offset. The offset must be in [-14:00, 14:00] range. +.. function:: localtimestamp -> timestamp + + Returns the timestamp as of the start of the query. + .. function:: to_iso8601(x) -> varchar Formats ``x`` as an ISO 8601 string. Supported types for ``x`` are: @@ -177,6 +184,31 @@ Date and Time Functions Returns ``timestamp`` as a UNIX timestamp. +.. function:: current_time() -> time with time zone + + Returns the current time since midnight with the session timezoneReturns the current time since midnight with the session timezone, based on the query session start time. + +.. function:: current_timezone() -> varchar + + Returns the current session time zone as a varchar. + + Example:: + + SELECT current_timezone; -- Asia/Kolkata + +.. function:: current_timestamp() -> timestamp with time zone +.. function:: now() -> timestamp with time zone + + Returns the current timestamp with session time zone applied. + The timestamp is captured once at the start of query execution and remains + constant throughout the query. This matches the standard SQL behavior for + ``CURRENT_TIMESTAMP`` and ``NOW()``. + + Example:: + + SELECT current_timestamp; -- 2025-07-17 14:53:12.123 Asia/Kolkata + SELECT now(); -- 2025-07-17 14:53:12.123 Asia/Kolkata + Truncation Function ------------------- @@ -314,6 +346,10 @@ Specifier Description Formats ``x`` as a string using ``format``. ``x`` is a timestamp or a timestamp with time zone. +.. function:: date_parse(string, format) -> timestamp + + Parses ``string`` into a timestamp using ``format``. + Java Date Functions ------------------- diff --git a/velox/docs/functions/presto/decimal.rst b/velox/docs/functions/presto/decimal.rst index 539398e7124..875fc813e9b 100644 --- a/velox/docs/functions/presto/decimal.rst +++ b/velox/docs/functions/presto/decimal.rst @@ -235,6 +235,14 @@ Decimal Functions Returns absolute value of x (r = `|x|`). +.. function:: ceil(x: decimal(p, s)) -> r: decimal(pr, 0) + + Returns 'x' rounded up to the nearest integer. The scale of the result is 0. + The precision is calculated as: + :: + + pr = min(38, p - s + min(s, 1)) + .. function:: divide(x: decimal(p1, s1), y: decimal(p2, s2)) -> r: decimal(p, s) Returns the result of dividing x by y (r = x / y). diff --git a/velox/docs/functions/presto/geospatial.rst b/velox/docs/functions/presto/geospatial.rst index 256a6cb825f..052dfe05da9 100644 --- a/velox/docs/functions/presto/geospatial.rst +++ b/velox/docs/functions/presto/geospatial.rst @@ -73,6 +73,40 @@ Geometry Constructors Returns a geometry type polygon object from WKT representation. +.. function:: ST_LineFromText(wkt: varchar) -> linestring: Geometry + + Returns a geometry type linestring object from WKT representation. + An error is returned if the input WKT represents a valid non-LineString + geometry. Null input returns null output. + +.. function:: ST_LineString(points: array(Geometry)) -> linestring: Geometry + + Returns a LineString formed from an array of points. If there are fewer + than two non-empty points in the input array, an empty LineString will + be returned. Throws an exception if any element in the array is null or + empty or same as the previous one. The returned geometry may not be simple, + e.g. may self-intersect or may contain duplicate vertexes depending on the + input. + +.. function:: ST_MultiPoint(points: array(Geometry)) -> multipoint: Geometry + + Returns a MultiPoint geometry object formed from the specified points. + Return null if input array is empty. Throws an exception if any element + in the array is null or empty. The returned geometry may not be simple + and may contain duplicate points if input array has duplicates. + +.. function:: to_spherical_geography(input: Geometry) -> output: SphericalGeography + + Converts a ``Geometry`` object to a SphericalGeography object on the sphere + of the Earth’s radius. For each point of the input geometry, it verifies that + point.x is within [-180.0, 180.0] and point.y is within [-90.0, 90.0], + and uses them as (longitude, latitude) degrees to construct the shape + of the ``SphericalGeography`` result. + +.. function:: to_geometry(input: SphericalGeography) -> output: Geometry + + Converts a SphericalGeography object to a Geometry object. + Spatial Predicates ------------------ @@ -101,7 +135,9 @@ function you are using. .. function:: ST_Equals(geometry1: Geometry, geometry2: Geometry) -> boolean - Returns ``true`` if the given geometries represent the same geometry. + Returns ``true`` if the given geometries represent the same geometry + according to ISO SQL/MM semantics. Also returns ``true`` if both geometries are empty, + regardless of their geometry types. .. function:: ST_Intersects(geometry1: Geometry, geometry2: Geometry) -> boolean @@ -114,7 +150,7 @@ function you are using. Returns ``true`` if the given geometries share space, are of the same dimension, but are not completely contained by each other. -.. function:: ST_Relat(geometry1: Geometry, geometry2: Geometry, relation: varchar) -> boolean +.. function:: ST_Relate(geometry1: Geometry, geometry2: Geometry, relation: varchar) -> boolean Returns true if first geometry is spatially related to second geometry as described by the relation. The relation is a string like ``'"1*T***T**'``: @@ -178,6 +214,24 @@ Spatial Operations Empty geometries will return an empty polygon. Negative or NaN distances will return an error. Positive infinity distances may lead to undefined results. +.. function:: geometry_union(geometries: array(Geometry)) -> union: Geometry + + Returns a geometry that represents the point set union of the input geometries. + Performance of this function, in conjunction with array_agg() to first + aggregate the input geometries, may be better than geometry_union_agg(), + at the expense of higher memory utilization. Null elements in the input + array are ignored. Empty array input returns null. + +.. function:: geometry_union_agg(geometry: Geometry) -> union: Geometry + + Returns a geometry that represents the point set union of the aggregated + input geometries. Null geometries are ignored. Empty input returns null. + +.. function:: convex_hull_agg(geometry: Geometry) -> union: Geometry + + Returns a geometry that represents the convex hull of the points in the + aggregated input geometries. Null geometries are ignored. Empty input + returns null. Accessors --------- @@ -213,6 +267,12 @@ Accessors on a two dimensional plane (based on spatial ref) in projected units. Will return an error if the input geometry is not a LineString or MultiLineString. +.. function:: ST_Length(sphericalgeography: SphericalGeography) -> length: double + + Returns the length of a ``LineString`` or ``MultiLineString`` on a spherical model of the + Earth. This is equivalent to the sum of great-circle distances between adjacent points + on the ``LineString``. + .. function:: ST_PointN(linestring: Geometry, index: integer) -> point: geometry Returns the vertex of a LineString at a given index (indices start at 1). @@ -224,7 +284,7 @@ Accessors Returns an array of points in a geometry. Empty or null inputs return null. -.. function:: ST_NumPoints(geometry: Geometry) -> points: integer +.. function:: ST_NumPoints(geometry: Geometry) -> points: bigint Returns the number of points in a geometry. This is an extension to the SQL/MM ``ST_NumPoints`` function which only applies to @@ -251,6 +311,10 @@ Accessors reason. If the geometry is valid and simple (or ``NULL``), return ``NULL``. This function is relatively expensive. +.. function:: great_circle_distance(latitude1, longitude1, latitude2, longitude2) -> double + + Returns the great-circle distance between two points on Earth's surface in kilometers. + .. function:: ST_Area(geometry: Geometry) -> area: double Returns the 2D Euclidean area of ``geometry``. @@ -258,16 +322,36 @@ Accessors returns the sum of the areas of the individual geometries. Empty geometries return 0. +.. function:: ST_Area(sphericalgeography: SphericalGeography) -> area: double + + Returns the area of a polygon or multi-polygon in square meters using a spherical model for Earth. + .. function:: ST_Centroid(geometry: Geometry) -> geometry: Geometry Returns the point value that is the mathematical centroid of ``geometry``. - Empty geometry inputs result in empty output. + Empty geometry inputs result in null output. + +.. function:: ST_Centroid(SphericalGeography) -> Point + + Returns the point value that is the mathematical centroid of a spherical geometry. + Empty geometry inputs result in null output. + + It supports Points and MultiPoints as input and returns the three-dimensional + centroid projected onto the surface of the (spherical) Earth. + For example, MULTIPOINT (0 -45, 0 45, 30 0, -30 0) returns Point(0, 0). + Note: In the case that the three-dimensional centroid is at (0, 0, 0) + (e.g. MULTIPOINT (0 0, -180 0)), the spherical centroid is undefined and an + arbitrary point will be returned. .. function:: ST_Distance(geometry1: Geometry, geometry2: Geometry) -> distance: double Returns the 2-dimensional cartesian minimum distance (based on spatial ref) between two geometries in projected units. Empty geometries result in null output. +.. function:: ST_Distance(sphericalgeography1: SphericalGeography, sphericalgeography2: SphericalGeography) -> distance: double + + Returns the great-circle distance in meters between two SphericalGeography points. + .. function:: ST_GeometryType(geometry: Geometry) -> type: varchar Returns the type of the geometry. @@ -376,7 +460,7 @@ Accessors GEOMETRYCOLLECTION (POINT (0 0), GEOMETRYCOLLECTION (POINT (1 1))) -> [POINT (0 0), POINT (1 1)], GEOMETRYCOLLECTION EMPTY -> []. -.. function:: ST_NumInteriorRing(geometry: Geometry) -> output: integer +.. function:: ST_NumInteriorRing(geometry: Geometry) -> output: bigint Returns the cardinality of the collection of interior rings of a polygon. @@ -384,7 +468,7 @@ Accessors Returns the minimum convex geometry that encloses all input geometries. -.. function:: ST_CoordDim(geometry: Geometry) -> output: integer +.. function:: ST_CoordDim(geometry: Geometry) -> output: tinyint Return the coordinate dimension of the geometry. @@ -472,6 +556,16 @@ for more details. Creates a Bing tile object from a quadkey. An invalid quadkey will return a User Error. +.. function:: bing_tiles_around(latitude, longitude, zoom_level) -> array(BingTile) + + Returns a collection of Bing tiles that surround the point specified + by the latitude and longitude arguments at a given zoom level. + +.. function:: bing_tiles_around(latitude, longitude, zoom_level, radius_in_km) -> array(BingTile) + + Returns a minimum set of Bing tiles at specified zoom level that cover a circle of specified + radius in km around a specified (latitude, longitude) point. + .. function:: bing_tile_coordinates(tile: BingTile) -> coords: row(integer,integer) Returns the ``x``, ``y`` coordinates of a given Bing tile as ``row(x, y)``. @@ -503,6 +597,16 @@ for more details. childZoom is less than the tile's zoom. The order is deterministic but not specified. +.. function:: bing_tile_polygon(tile) -> Geometry + + Returns the polygon representation of a given Bing tile. + +.. function:: bing_tile_at(latitude, longitude, zoom_level) -> BingTile + + Returns a Bing tile at a given zoom level containing a point at a given latitude + and longitude. Latitude must be within ``[-85.05112878, 85.05112878]`` range. + Longitude must be within ``[-180, 180]`` range. Zoom levels from 1 to 23 are supported. + .. function:: bing_tile_quadkey() -> quadKey: varchar Returns the quadkey representing the provided bing tile. @@ -513,5 +617,98 @@ for more details. given zoom level. Empty inputs return an empty array, and null inputs return null. +.. function:: geometry_to_dissolved_bing_tiles(geometry: Geometry, max_zoom_level: tinyint) -> tile: array(BingTile) + + Returns the minimum set of Bing tiles that fully covers a given geometry at a + given zoom level, recursively dissolving full sets of children into parents. + This results in a smaller array of tiles of different zoom levels. + For example, if the non-dissolved covering is [“00”, “01”, “02”, “03”, “10”], + the dissolved covering would be [“0”, “10”]. Zoom levels from 0 to 23 are supported. + +S2 Cell Functions +----------------- + +`S2 Geometry `_ is a library for spherical geometry that +decomposes the Earth's surface into a hierarchy of cells. Unlike planar tiling +systems (e.g., Bing Tiles), S2 cells have near-uniform area across all latitudes. + +Each cell is identified by a 64-bit **cell ID** (stored as ``BIGINT``), which +encodes both the cell's position and level in the hierarchy. Cells are organized +in 31 levels (0–30), where level 0 cells are the largest (covering roughly 1/6 +of Earth's surface) and level 30 cells are the smallest (sub-centimeter). + +Cells can also be represented as compact hexadecimal **tokens** (e.g., +``'8085808b'``), which are shorter and human-readable. Use +``s2_cell_from_token`` and ``s2_cell_to_token`` to convert between the two +representations. + +All functions operate on cell IDs (``BIGINT``) rather than tokens because cell +IDs support direct integer comparison, efficient equi-joins and GROUP BY, and +compose without casting (e.g., ``s2_cell_contains(s2_cell_parent(id, 10), id)``). +Tokens are useful for human-readable output and interop with external systems +that use the token format. + + +.. function:: s2_cell_area_sq_km(cell_id: bigint) -> area: double + + Returns the area of the S2 cell in square kilometers. + Returns an error if the cell ID is invalid. + +.. function:: s2_cell_contains(parent_cell_id: bigint, child_cell_id: bigint) -> boolean + + Returns ``true`` if the first S2 cell contains the second. Containment is + hierarchical: a cell contains all of its descendants at finer levels. + Returns an error if either cell ID is invalid. + +.. function:: s2_cell_from_token(cell_token: varchar) -> cell_id: bigint + + Returns the 64-bit S2 cell ID for the given cell token. The + ``cell_token`` is a compact hexadecimal representation of the S2 cell. + Returns an error if the cell token is invalid. + +.. function:: s2_cell_level(cell_id: bigint) -> level: integer + + Returns the level of the S2 cell, from 0 (coarsest) to 30 (finest). + Returns an error if the cell ID is invalid. + +.. function:: s2_cell_parent(cell_id: bigint, level: integer) -> parent_id: bigint + + Returns the parent S2 cell ID at the given ``level``. If the cell is + already at or above the given level, returns the same cell ID. The + ``level`` must be in the ``[0, 30]`` range. Returns an error if the cell + ID is invalid or the level is out of range. + +.. function:: s2_cell_to_token(cell_id: bigint) -> cell_token: varchar + + Returns the compact hexadecimal token representation of the S2 cell. + Returns an error if the cell ID is invalid. + +.. function:: s2_cells(geometry: Geometry, level: integer) -> cell_ids: array(bigint) + + Returns the set of S2 cell IDs that cover the given geometry at a fixed + ``level``. All returned cells are at the same level. Supports Point, + LineString, Polygon, and their Multi variants. Empty geometries return an + empty array, null geometries return null. The ``level`` must be in the + ``[0, 30]`` range. + +.. function:: s2_cells(geometry: Geometry, min_level: integer, max_level: integer, max_cells: integer) -> cell_ids: array(bigint) + + Returns a compact set of S2 cell IDs at mixed levels that cover the given + geometry, similar to ``geometry_to_dissolved_bing_tiles``. The coverer uses + large cells (at ``min_level``) for interiors and small cells (up to + ``max_level``) for boundaries, targeting at most ``max_cells`` cells. This + is useful for compact spatial indexing of regions like cities or countries. + Both levels must be in the ``[0, 30]`` range with ``min_level <= max_level``, + and ``max_cells`` must be >= 1. Empty geometries return an empty array, + null geometries return null. + + Note: ``max_cells`` is a soft limit. Up to 6 cells may be returned + regardless of ``max_cells`` if the region intersects multiple cube faces + of the S2 projection. ``min_level`` takes priority over ``max_cells`` — + cells below ``min_level`` are never used even if this causes more cells + to be returned. If ``max_cells`` is less than 4, the covering area may be + significantly larger than the original region. A value of 8 or higher is + recommended for a reasonable approximation. + .. _OpenGIS Specifications: https://www.ogc.org/standards/ogcapi-features/ .. _SQL/MM Part 3: Spatial: https://www.iso.org/standard/31369.html diff --git a/velox/docs/functions/presto/hyperloglog.rst b/velox/docs/functions/presto/hyperloglog.rst index ecd8e6d384a..fd10739f5a5 100644 --- a/velox/docs/functions/presto/hyperloglog.rst +++ b/velox/docs/functions/presto/hyperloglog.rst @@ -70,3 +70,10 @@ Functions Returns the ``HyperLogLog`` of the aggregate union of the individual ``hll`` HyperLogLog structures. + +.. function:: merge_hll(array(HyperLogLog)) -> HyperLogLog + + Returns the ``HyperLogLog`` of the union of an array of ``HyperLogLog`` structures. + + * Returns ``NULL`` if the input array is ``NULL``, empty, or contains only ``NULL`` elements + * Ignores ``NULL`` elements and merges only valid ``HyperLogLog`` structures when the array contains a mix of ``NULL`` and non-null elements diff --git a/velox/docs/functions/presto/ipaddress.rst b/velox/docs/functions/presto/ipaddress.rst index f1c061566a2..3b558ee88cd 100644 --- a/velox/docs/functions/presto/ipaddress.rst +++ b/velox/docs/functions/presto/ipaddress.rst @@ -69,6 +69,35 @@ IP Functions SELECT IP_PREFIX_SUBNETS(IPPREFIX '192.168.1.0/24', 25); -- [{192.168.1.0/25}, {192.168.1.128/25}] SELECT IP_PREFIX_SUBNETS(IPPREFIX '2a03:2880:c000::/34', 36); -- [{2a03:2880:c000::/36}, {2a03:2880:d000::/36}, {2a03:2880:e000::/36}, {2a03:2880:f000::/36}] +.. function:: ip_version(ip_address) -> bigint + + Returns ``4`` if ``ip_address`` is an IPv4 address, ``6`` if it is an IPv6 address. + IPv4-mapped IPv6 addresses (e.g. ``::ffff:1.2.3.4``) are treated as IPv4. + ``ip_address`` is of type ``IPADDRESS``. :: + + SELECT ip_version(IPADDRESS '1.2.3.4'); -- 4 + SELECT ip_version(IPADDRESS '::ffff:1.2.3.4'); -- 4 + SELECT ip_version(IPADDRESS '64:ff9b::17'); -- 6 + SELECT ip_version(IPADDRESS '2001:db8::1'); -- 6 + +.. function:: ip_version(ip_prefix) -> bigint + + Returns ``4`` if ``ip_prefix`` contains an IPv4 address, ``6`` if it contains an IPv6 address. + IPv4-mapped IPv6 prefixes are treated as IPv4. + ``ip_prefix`` is of type ``IPPREFIX``. :: + + SELECT ip_version(IPPREFIX '1.2.3.4/24'); -- 4 + SELECT ip_version(IPPREFIX '64:ff9b::17/64'); -- 6 + +.. function:: ip_prefix_masklen(ip_prefix) -> bigint + + Returns the prefix length (mask length) of ``ip_prefix``. + The value is in the range [0, 32] for IPv4 and [0, 128] for IPv6. :: + + SELECT ip_prefix_masklen(IPPREFIX '1.2.3.4/24'); -- 24 + SELECT ip_prefix_masklen(IPPREFIX '64:ff9b::17/128'); -- 128 + SELECT ip_prefix_masklen(IPPREFIX '::/0'); -- 0 + .. function:: is_private_ip(ip_address) -> boolean Returns whether ``ip_address`` of type ``IPADDRESS`` is a private or reserved IP address diff --git a/velox/docs/functions/presto/khyperloglog.rst b/velox/docs/functions/presto/khyperloglog.rst new file mode 100644 index 00000000000..ba23b051a3a --- /dev/null +++ b/velox/docs/functions/presto/khyperloglog.rst @@ -0,0 +1,84 @@ +========================= +KHyperLogLog Functions +========================= + +KHyperLogLog is a data sketch for estimating reidentifiability and joinability within a dataset. +Based on the `KHyperLogLog paper `_, +it maintains a map of K number of HyperLogLog structures, where each entry corresponds to a unique key from one column, +and the HLL estimates the cardinality of the associated unique identifiers from another column. + +Data Structures +--------------- + +A KHyperLogLog is a data sketch which stores approximate cardinality information for key-value +associations. The Velox type for this data structure is called ``KHyperLogLog``. +For storage and retrieval, KHyperLogLog values may be cast to/from ``VARBINARY``. + +Serialization format is compatible with Presto's. + +Aggregate Functions +------------------- + +.. function:: khyperloglog_agg(x, uii) -> KHyperLogLog + + Returns the ``KHyperLogLog`` sketch which summarizes the association between + the key column ``x`` and the unique identifier column ``uii``. + The ``x`` parameter represents the key values and ``uii`` represents + the unique identifiers associated with each key. + +.. function:: merge(KHyperLogLog) -> KHyperLogLog + + Returns the ``KHyperLogLog`` of the aggregate union of the individual ``KHyperLogLog`` + structures. + +Scalar Functions +---------------- + +.. function:: cardinality(khll) -> bigint + + Returns the estimated total cardinality (number of unique keys) from the + ``KHyperLogLog`` sketch ``khll``. + +.. function:: intersection_cardinality(khll1, khll2) -> bigint + + Returns the estimated intersection cardinality between two ``KHyperLogLog`` sketches. + If both sketches are exact (small cardinality), returns the exact intersection count. + Otherwise, returns an approximation using the Jaccard index. + +.. function:: jaccard_index(khll1, khll2) -> double + + Returns the Jaccard index (similarity coefficient) between two ``KHyperLogLog`` sketches. + The Jaccard index is a value in [0, 1] where: + + * 1.0 means the sets are identical + * 0.0 means the sets are disjoint (no overlap) + +.. function:: merge_khll(array(KHyperLogLog)) -> KHyperLogLog + + Returns the ``KHyperLogLog`` of the union of an array of ``KHyperLogLog`` structures. + + * Returns ``NULL`` if the input array is ``NULL``, empty, or contains only ``NULL`` elements + * Ignores ``NULL`` elements and merges only valid ``KHyperLogLog`` structures when the array contains a mix of ``NULL`` and non-null elements + +.. function:: reidentification_potential(khll, threshold) -> double + + Returns the reidentification potential of the ``KHyperLogLog`` sketch ``khll`` + at the given ``threshold``. This measures the fraction of keys that have + cardinality at or below the threshold, which indicates how easily those + keys could be reidentified. + +.. function:: uniqueness_distribution(khll) -> map(bigint, double) + + Returns a histogram map representing the distribution of uniqueness values + in the ``KHyperLogLog`` sketch ``khll``. Each key in the map represents a + cardinality bucket, and the value represents the fraction of keys falling + into that bucket. The histogram size defaults to the minhash size of the + KHyperLogLog instance. + +.. function:: uniqueness_distribution(khll, histogramSize) -> map(bigint, double) + :noindex: + + Returns a histogram map representing the distribution of uniqueness values + in the ``KHyperLogLog`` sketch ``khll`` with the specified ``histogramSize``. + Each key in the map represents a cardinality bucket, and the value represents + the fraction of keys falling into that bucket. diff --git a/velox/docs/functions/presto/map.rst b/velox/docs/functions/presto/map.rst index c81d7eba948..e7cd9b160da 100644 --- a/velox/docs/functions/presto/map.rst +++ b/velox/docs/functions/presto/map.rst @@ -46,6 +46,29 @@ Map Functions See also :func:`map_agg` for creating a map as an aggregation. +.. function:: map_append(map(K,V), array(K), array(V)) -> map(K,V) + + Returns a map with new key-value pairs appended to the input map. The new keys are provided in the first array parameter and corresponding values in the second array parameter. + Keys and values arrays must have the same length. New keys must not already exist in the input map. Duplicate keys in the new keys array are not allowed. + Null keys are ignored. Null values are preserved in the output map. For REAL and DOUBLE, NaNs (Not-a-Number) are considered equal. :: + + SELECT map_append(MAP(ARRAY[1, 2], ARRAY[10, 20]), ARRAY[3, 4], ARRAY[30, 40]); -- {1 -> 10, 2 -> 20, 3 -> 30, 4 -> 40} + SELECT map_append(MAP(ARRAY['a', 'b'], ARRAY[1, 2]), ARRAY['c'], ARRAY[3]); -- {'a' -> 1, 'b' -> 2, 'c' -> 3} + SELECT map_append(MAP(ARRAY[1], ARRAY[10]), ARRAY[2, null, 3], ARRAY[20, 30, 40]); -- {1 -> 10, 2 -> 20, 3 -> 40} + SELECT map_append(MAP(ARRAY[1], ARRAY[10]), ARRAY[2, 3], ARRAY[null, 30]); -- {1 -> 10, 2 -> null, 3 -> 30} + SELECT map_append(MAP(ARRAY[1], ARRAY[10]), ARRAY[], ARRAY[]); -- {1 -> 10} + +.. function:: map_update(map(K,V), array(K), array(V)) -> map(K,V) + + Returns a map with values updated for the specified keys. If a key exists in the input map, its value is updated in place (preserving original order). If a key doesn't exist, it is added to the end of the map. + Keys and values arrays must have the same length. Duplicate keys in the keys array are not allowed. + Null keys are ignored. Null values are preserved in the output map. For REAL and DOUBLE, NaNs (Not-a-Number) are considered equal. :: + + SELECT map_update(MAP(ARRAY[1, 2, 3], ARRAY[10, 20, 30]), ARRAY[2, 4], ARRAY[200, 400]); -- {1 -> 10, 2 -> 200, 3 -> 30, 4 -> 400} + SELECT map_update(MAP(ARRAY['a', 'b'], ARRAY[1, 2]), ARRAY['a', 'c'], ARRAY[100, 300]); -- {'a' -> 100, 'b' -> 2, 'c' -> 300} + SELECT map_update(MAP(ARRAY[1], ARRAY[10]), ARRAY[1, 2], ARRAY[null, 20]); -- {1 -> null, 2 -> 20} + SELECT map_update(MAP(ARRAY[1, 2], ARRAY[10, 20]), ARRAY[], ARRAY[]); -- {1 -> 10, 2 -> 20} + .. function:: map_concat(map1(K,V), map2(K,V), ..., mapN(K,V)) -> map(K,V) Returns the union of all the given maps. If a key is found in multiple given maps, @@ -84,6 +107,40 @@ Map Functions SELECT map_normalize(map(array['a', 'b', 'c', 'd'], array[1, null, 4, 5])); -- {a=0.1, b=null, c=0.4, d=0.5} SELECT map_normalize(map(array['a', 'b', 'c'], array[1, 0, -1])); -- {a=Infinity, b=NaN, c=-Infinity} +.. function:: map_values_in_range(map(K,V), lower_bound, upper_bound) -> map(K,V) + + Returns a map containing only the entries from the input map whose values + fall within the specified range [lower_bound, upper_bound] (inclusive). + Entries with values less than lower_bound or greater than upper_bound are removed. + Entries with null values are preserved in the output. + V must be a numeric type (integer, bigint, real, or double). :: + + SELECT map_values_in_range(MAP(ARRAY[1, 2, 3, 4], ARRAY[10, 20, 30, 40]), 15, 35); -- {2 -> 20, 3 -> 30} + SELECT map_values_in_range(MAP(ARRAY['a', 'b', 'c'], ARRAY[1.5, 2.5, 3.5]), 2.0, 3.0); -- {b -> 2.5} + SELECT map_values_in_range(MAP(ARRAY[1, 2], ARRAY[null, 50]), 0, 100); -- {1 -> null, 2 -> 50} + SELECT map_values_in_range(MAP(ARRAY[1, 2, 3], ARRAY[5, 50, 500]), 10, 100); -- {2 -> 50} + +.. function:: map_values_all_match(x(K,V), function(V, boolean)) -> boolean + + Returns true if all values in the given map match the predicate and false otherwise. NULL if the predicate function returns NULL for one or more values and true for all other values. Equivalent to ``all_match(map_values(x), predicate)`` but avoids materializing the intermediate array. :: + + SELECT map_values_all_match(map(array['a', 'b', 'c'], array[1, 2, 3]), x -> x > 0); -- true + SELECT map_values_all_match(map(array['a', 'b', 'c'], array[1, 2, 3]), x -> x > 1); -- false + +.. function:: map_values_any_match(x(K,V), function(V, boolean)) -> boolean + + Returns true if one or more values in the given map match the predicate and false otherwise. NULL if the predicate function returns NULL for one or more values and false for all other values. Equivalent to ``any_match(map_values(x), predicate)`` but avoids materializing the intermediate array. :: + + SELECT map_values_any_match(map(array['a', 'b', 'c'], array[1, 2, 3]), x -> x = 1); -- true + SELECT map_values_any_match(map(array['a', 'b', 'c'], array[1, 2, 3]), x -> x = 5); -- false + +.. function:: map_values_none_match(x(K,V), function(V, boolean)) -> boolean + + Returns true if no values in the given map match the predicate and false otherwise. NULL if the predicate function returns NULL for one or more values and false for all other values. Equivalent to ``none_match(map_values(x), predicate)`` but avoids materializing the intermediate array. :: + + SELECT map_values_none_match(map(array['a', 'b', 'c'], array[1, 2, 3]), x -> x = 5); -- true + SELECT map_values_none_match(map(array['a', 'b', 'c'], array[1, 2, 3]), x -> x = 1); -- false + .. function:: map_remove_null_values(map(K,V)) -> map(K,V) Returns a map by removing all the keys in input map with null values. If input @@ -94,6 +151,16 @@ Map Functions SELECT map_remove_null_values(MAP(ARRAY[1, 2, 3], ARRAY[3, 4, NULL])); -- {1=3, 2=4} SELECT map_remove_null_values(NULL); -- NULL +.. function:: remap_keys(map(K,V), array(K), array(K)) -> map(K,V) + + Returns a map with keys remapped according to the oldKeys and newKeys arrays. + Unmapped keys remain unchanged. Values are preserved. Null keys are ignored. :: + + SELECT remap_keys(MAP(ARRAY[1, 2, 3], ARRAY[10, 20, 30]), ARRAY[1, 3], ARRAY[100, 300]); -- {100 -> 10, 2 -> 20, 300 -> 30} + SELECT remap_keys(MAP(ARRAY['a', 'b', 'c'], ARRAY[1, 2, 3]), ARRAY['a', 'c'], ARRAY['alpha', 'charlie']); -- {alpha -> 1, b -> 2, charlie -> 3} + SELECT remap_keys(MAP(ARRAY[1, 2, 3], ARRAY[10, null, 30]), ARRAY[1, 2], ARRAY[100, 200]); -- {100 -> 10, 200 -> null, 3 -> 30} + SELECT remap_keys(MAP(ARRAY[1, 2], ARRAY[10, 20]), ARRAY[], ARRAY[]); -- {1 -> 10, 2 -> 20} + .. function:: map_subset(map(K,V), array(k)) -> map(K,V) Constructs a map from those entries of ``map`` for which the key is in the array given @@ -102,8 +169,53 @@ Map Functions SELECT map_subset(MAP(ARRAY[1,2], ARRAY['a','b']), ARRAY[10]); -- {} SELECT map_subset(MAP(ARRAY[1,2], ARRAY['a','b']), ARRAY[1]); -- {1->'a'} SELECT map_subset(MAP(ARRAY[1,2], ARRAY['a','b']), ARRAY[1,3]); -- {1->'a'} - SELECT map_subset(MAP(ARRAY[1,2], ARRAY['a','b']), ARRAY[]); -- {} SELECT map_subset(MAP(ARRAY[], ARRAY[]), ARRAY[1,2]); -- {} + SELECT map_subset(MAP(ARRAY[], ARRAY[]), ARRAY[]); -- {} + +.. function:: map_subset_key_in_range(map(K,V), low_key, high_key) -> map(K,V) + + Returns a sub-map containing only the entries from the input map whose keys fall + within the inclusive range ``[low_key, high_key]``. Both bounds are inclusive. + If ``low_key > high_key``, returns an empty map. Entries with null values are + preserved. If the input map, ``low_key``, or ``high_key`` is ``NULL``, the + result is ``NULL``. ``K`` must be an orderable type. :: + + SELECT map_subset_key_in_range(MAP(ARRAY[1,2,3,4,5], ARRAY[10,20,30,40,50]), 2, 4); -- {2->20, 3->30, 4->40} + SELECT map_subset_key_in_range(MAP(ARRAY[7,10,14,20], ARRAY[70,100,140,200]), 7, 14); -- {7->70, 10->100, 14->140} + SELECT map_subset_key_in_range(MAP(ARRAY[1,2,3], ARRAY[10,20,30]), 5, 1); -- {} + SELECT map_subset_key_in_range(MAP(ARRAY[], ARRAY[]), 1, 10); -- {} + +.. function:: map_intersect(map(K,V), array(K)) -> map(K,V) + + Returns a map containing only the entries from the input map whose keys are present in the given array. + This function is equivalent to map_subset. Null keys in the array are ignored. + For keys containing REAL and DOUBLE, NANs (Not-a-Number) are considered equal. :: + + SELECT map_intersect(MAP(ARRAY[1,2,3], ARRAY['a','b','c']), ARRAY[1,3]); -- {1->'a', 3->'c'} + SELECT map_intersect(MAP(ARRAY[1,2], ARRAY['a','b']), ARRAY[10]); -- {} + SELECT map_intersect(MAP(ARRAY[1,2], ARRAY['a','b']), ARRAY[]); -- {} + SELECT map_intersect(MAP(ARRAY[], ARRAY[]), ARRAY[1,2]); -- {} + +.. function:: map_except(map(K,V), array(k)) -> map(K,V) + + Constructs a map from those entries of ``map`` for which the key is not in the array given. + For keys containing REAL and DOUBLE, NANs (Not-a-Number) are considered equal. :: + + SELECT map_except(MAP(ARRAY[1,2], ARRAY['a','b']), ARRAY[10]); -- {1->'a', 2->'b'} + SELECT map_except(MAP(ARRAY[1,2], ARRAY['a','b']), ARRAY[1]); -- {2->'b'} + SELECT map_except(MAP(ARRAY[1,2], ARRAY['a','b']), ARRAY[1,3]); -- {2->'b'} + SELECT map_except(MAP(ARRAY[1,2], ARRAY['a','b']), ARRAY[]); -- {1->'a', 2->'b'} + SELECT map_except(MAP(ARRAY[], ARRAY[]), ARRAY[1,2]); -- {} + +.. function:: map_keys_overlap(map(K,V), array(K)) -> boolean + + Returns true if any key in the map matches any element in the given array, false otherwise. + Returns false if either the map or array is empty. Null keys in the array are ignored. :: + + SELECT map_keys_overlap(MAP(ARRAY[1, 2, 3], ARRAY[10, 20, 30]), ARRAY[1, 5]); -- true + SELECT map_keys_overlap(MAP(ARRAY[1, 2, 3], ARRAY[10, 20, 30]), ARRAY[4, 5]); -- false + SELECT map_keys_overlap(MAP(ARRAY['a', 'b'], ARRAY[1, 2]), ARRAY['a']); -- true + SELECT map_keys_overlap(MAP(ARRAY[], ARRAY[]), ARRAY[1]); -- false .. function:: map_top_n(map(K,V), n) -> map(K, V) @@ -114,6 +226,24 @@ Map Functions SELECT map_top_n(map(ARRAY['a', 'b', 'c'], ARRAY[2, 3, 1]), 2) --- {'b' -> 3, 'a' -> 2} SELECT map_top_n(map(ARRAY['a', 'b', 'c'], ARRAY[NULL, 3, NULL]), 2) --- {'b' -> 3, 'c' -> NULL} +.. function:: map_trim_values(map(K, array(V)), n) -> map(K, array(V)) + + Trims the value arrays in a map to a specified maximum size. + This function is useful for optimizing memory usage and performance for large feature maps + where the value arrays may grow unbounded. + + Returns a map where each value array is trimmed to at most n elements. + If n is negative, returns the original map unchanged. + If n is 0, returns a map where all values are empty arrays. + If a value array has fewer than n elements, it is left unchanged. + Null elements in the arrays are preserved in the output. :: + + SELECT map_trim_values(MAP(ARRAY['a', 'b'], ARRAY[ARRAY[1, 2, 3], ARRAY[4, 5, 6, 7]]), 2); -- {a -> [1, 2], b -> [4, 5]} + SELECT map_trim_values(MAP(ARRAY['a'], ARRAY[ARRAY[1, 2]]), 5); -- {a -> [1, 2]} + SELECT map_trim_values(MAP(ARRAY['a'], ARRAY[ARRAY[1, NULL, 3]]), 2); -- {a -> [1, NULL]} + SELECT map_trim_values(MAP(ARRAY['a'], ARRAY[ARRAY[1, 2, 3]]), 0); -- {a -> []} + SELECT map_trim_values(MAP(ARRAY['a'], ARRAY[ARRAY[1, 2, 3]]), -1); -- {a -> [1, 2, 3]} + .. function:: map_keys_by_top_n_values(map(K,V), n) -> array(K) Returns an array of the top N keys from a map. Keeps only the top N elements by value. Keys are used to break ties with the max key being chosen. Both keys and values should be orderable. diff --git a/velox/docs/functions/presto/math.rst b/velox/docs/functions/presto/math.rst index b61a623aca6..91e3cc7e6a2 100644 --- a/velox/docs/functions/presto/math.rst +++ b/velox/docs/functions/presto/math.rst @@ -383,6 +383,10 @@ Probability Functions: cdf Compute the Poisson cdf with given lambda (mean) parameter: P(N <= value; lambda). The lambda parameter must be a positive real number (of type DOUBLE) and value must be a non-negative integer. +.. function:: t_cdf(df, value) -> double + + Compute the Student's t cdf with given degrees of freedom: P(N < value; df). + The degrees of freedom must be a positive real number and value must be a real value. .. function:: weibull_cdf(a, b, value) -> double @@ -455,6 +459,12 @@ Probability Functions: inverse_cdf probability (p): P(N < n). The df parameter must be positive real values. The probability p must lie on the interval [0, 1]. +.. function:: inverse_t_cdf(df, p) -> double + + Compute the inverse of the Student's t cdf with given degrees of freedom for the cumulative + probability (p): P(N < n). The degrees of freedom must be a positive real value. + The probability p must lie on the interval [0, 1]. + ==================================== Statistical Functions ==================================== diff --git a/velox/docs/functions/presto/qdigest.rst b/velox/docs/functions/presto/qdigest.rst index 9b7b00dabb0..ba11449e4dd 100644 --- a/velox/docs/functions/presto/qdigest.rst +++ b/velox/docs/functions/presto/qdigest.rst @@ -54,6 +54,11 @@ Functions must be a value greater than zero and less than one, and it must be constant for all input rows. +.. function:: scale_qdigest(qdigest(T), scale_factor) -> qdigest(T) + + Returns a ``qdigest`` whose distribution has been scaled by a factor + specified by ``scale_factor``. + .. function:: value_at_quantile(digest: qdigest, quantile: double) -> T Returns the approximate percentile values from the quantile digest ``digest`` given the ``quantile``. diff --git a/velox/docs/functions/presto/setdigest.rst b/velox/docs/functions/presto/setdigest.rst new file mode 100644 index 00000000000..cea62df12ee --- /dev/null +++ b/velox/docs/functions/presto/setdigest.rst @@ -0,0 +1,68 @@ +====================== +SetDigest Functions +====================== + +SetDigest is a data sketch for estimating set cardinality and performing set +operations like intersection cardinality and Jaccard index. It combines HyperLogLog +for cardinality estimation with MinHash for exact counting and intersection operations. + +SetDigests may be merged, and for storage and retrieval they may be cast to/from ``VARBINARY``. + +Data Structures +--------------- + +A SetDigest is a data sketch which stores approximate set membership and cardinality +information. The Velox type for this data structure is called ``SetDigest``. +SetDigests support two element types internally: + +* ``bigint`` - for integer values (all numeric types are converted to bigint) +* ``varchar`` - for string values + +When a SetDigest is exact (cardinality is less than the maximum hash limit), +operations like intersection cardinality return exact results. When the digest +becomes approximate (high cardinality), it uses HyperLogLog and MinHash estimation. + +Serialization format is compatible with Presto's. + +Aggregate Functions +------------------- + +.. function:: make_set_digest(x) -> SetDigest + + Returns the ``SetDigest`` sketch which summarizes the input data set of ``x``. + Supported input types include: ``boolean``, ``tinyint``, ``smallint``, ``integer``, + ``bigint``, ``real``, ``double``, ``date``, ``varchar``, and ``varbinary``. + +.. function:: merge_set_digest(SetDigest) -> SetDigest + + Returns the ``SetDigest`` of the aggregate union of the individual ``SetDigest`` + structures. + +Scalar Functions +---------------- + +.. function:: cardinality(setdigest) -> bigint + + Returns the estimated cardinality of the set represented by the ``SetDigest`` sketch. + If the digest is exact (low cardinality), returns the exact count. + Otherwise, returns an approximation using HyperLogLog. + +.. function:: intersection_cardinality(setdigest1, setdigest2) -> bigint + + Returns the estimated intersection cardinality between two ``SetDigest`` sketches. + + * If both digests are exact: returns the exact intersection count + * If either digest is approximate: returns an estimation using the Jaccard index + + The result is capped at the minimum cardinality of the two input digests to + ensure logical consistency. + +.. function:: jaccard_index(setdigest1, setdigest2) -> double + + Returns the Jaccard index (similarity coefficient) between two ``SetDigest`` sketches. + The Jaccard index is a value in [0, 1] where: + + * 1.0 means the sets are identical + * 0.0 means the sets are disjoint (no overlap) + + Uses MinHash estimation for efficient computation. diff --git a/velox/docs/functions/presto/string.rst b/velox/docs/functions/presto/string.rst index 7806bd646ee..4034ff6d652 100644 --- a/velox/docs/functions/presto/string.rst +++ b/velox/docs/functions/presto/string.rst @@ -19,6 +19,12 @@ String Functions some languages. Specifically, this will return incorrect results for Lithuanian, Turkish and Azeri. +.. function:: bit_length(string) -> integer + + Returns the bit length for the specified string column. + + SELECT bit_length('123'); -- 24 + .. function:: chr(n) -> varchar Returns the Unicode code point ``n`` as a single character string. @@ -60,6 +66,10 @@ String Functions i.e. the number of positions at which the corresponding characters are different. Note that the two strings must have the same length. +.. function:: jarowinkler_similarity(string1, string2) -> double + + Returns the Jaro-Winkler similarity of ``string1`` and ``string2``. + .. function:: length(string) -> bigint Returns the length of ``string`` in characters. @@ -69,6 +79,10 @@ String Functions Returns the Levenshtein edit distance of 2 strings. I.e. the minimum number of single-character edits (insertions, deletions or substitutions) needed to convert ``string_1`` to ``string_2``. +.. function:: longest_common_prefix(string1, string2) -> varchar + + Returns the longest common prefix between ``string1`` and ``string2`` + .. function:: lower(string) -> varchar Converts ``string`` to lowercase. @@ -93,7 +107,7 @@ String Functions SELECT ltrim('test', 't'); -- est SELECT ltrim('tetris', 'te'); -- ris -.. function:: replaceFirst(string, search, replace) -> varchar +.. function:: replace_first(string, search, replace) -> varchar Removes the first instances of ``search`` with ``replace`` in ``string``. @@ -358,3 +372,8 @@ Unicode Functions .. function:: to_utf8(string) -> varbinary Encodes ``string`` into a UTF-8 varbinary representation. + +.. function:: key_sampling_percent(varchar) -> double + + Generates a double value between 0.0 and 1.0 based on the hash of the given ``varchar``. + This function is useful for deterministic sampling of data. diff --git a/velox/docs/functions/presto/tdigest.rst b/velox/docs/functions/presto/tdigest.rst index b43be2f3de6..a115ce84be7 100644 --- a/velox/docs/functions/presto/tdigest.rst +++ b/velox/docs/functions/presto/tdigest.rst @@ -105,6 +105,55 @@ Functions Returns the mean of values between ``low_quantile`` and ``high_quantile`` (inclusive) from the T-digest ``digest``. Both quantile values must be between zero and one (inclusive), and ``low_quantile`` must be less than or equal to ``high_quantile``. +.. function:: winsorized_mean(digest: tdigest, low_quantile: double, high_quantile: double) -> double + + Returns the Winsorized mean from the T-digest ``digest``. Values below + ``low_quantile`` are replaced with the boundary value at that quantile. + Values above ``high_quantile`` are replaced with the boundary value at + that quantile. The mean is then computed on all values including the + replaced tails. Both quantile values must be between zero and one + (inclusive), and ``low_quantile`` must be less than or equal to + ``high_quantile``. + + Unlike :func:`trimmed_mean` which excludes tail values entirely, + ``winsorized_mean`` replaces them with the boundary values, keeping the + total count unchanged. + +.. function:: approx_winsorized_mean(x: double, low_quantile: double, high_quantile: double) -> double + + Returns the approximate Winsorized mean of all input values of ``x`` + using a T-digest sketch. This is a single-pass aggregate that replaces + values below ``low_quantile`` with the boundary value at that quantile, + and values above ``high_quantile`` with the boundary value at that + quantile, then computes the mean. + + ``low_quantile`` and ``high_quantile`` must be constants between zero and + one (inclusive), and ``low_quantile`` must be less than or equal to + ``high_quantile``. The default compression factor is ``100``. + + This function replaces the common two-pass pattern of computing percentile + thresholds with ``approx_percentile`` and then capping values with + ``LEAST``/``GREATEST`` before computing ``AVG``. + + Example:: + + -- Old two-pass pattern: + WITH t AS ( + SELECT APPROX_PERCENTILE(x, 0.99) AS thresh FROM data + ) + SELECT AVG(LEAST(data.x, t.thresh)) FROM data, t; + + -- New single-pass equivalent: + SELECT approx_winsorized_mean(x, 0.0, 0.99) FROM data; + +.. function:: approx_winsorized_mean(x: double, low_quantile: double, high_quantile: double, compression: double) -> double + :noindex: + + Like the above, but with a specified compression factor. ``compression`` + must be a positive constant. The default is ``100``, maximum is ``1000``, + and values lower than ``10`` are rounded to ``10``. Higher compression + means more accuracy at the cost of more memory. + .. function:: value_at_quantile(digest: tdigest, quantile: double) -> double Returns the approximate percentile value from the T-digest ``digest`` at the given ``quantile``. diff --git a/velox/docs/functions/spark/aggregate.rst b/velox/docs/functions/spark/aggregate.rst index a9ac2bd001c..b94fe339f73 100644 --- a/velox/docs/functions/spark/aggregate.rst +++ b/velox/docs/functions/spark/aggregate.rst @@ -10,17 +10,25 @@ General Aggregate Functions .. spark:function:: avg(x) -> double|decimal Returns the average (arithmetic mean) of all non-null input values. - When x is of type DECIMAL, the result type is DECIMAL, - and the intermediate results are varbinarys or (sum, count) pairs represented as row(decimal, bigint). + When ``x`` is of type DECIMAL(p, s), the result type is DECIMAL(p + 4, s + 4), + and the intermediate results are (sum, count) pairs represented as ROW(DECIMAL(p + 10, s), BIGINT). + The current implementation for DECIMAL matches Spark avg's default behavior with spark.sql.decimalOperations.allowPrecisionLoss=true. For all other input types, the result type is DOUBLE, - and the intermediate results are (sum, count) pairs represented as row(double, bigint). - When all inputs are nulls, the intermediate result is row(0, 0), + and the intermediate results are (sum, count) pairs represented as ROW(DOUBLE, BIGINT). + When all inputs are nulls, the intermediate result is ROW(0, 0), and the final result is null. .. spark:function:: bit_xor(x) -> bigint Returns the bitwise XOR of all non-null input values, or null if none. +.. spark:function:: bitmap_construct_agg(position) -> varbinary + + Builds a fixed-size 4096-byte (32768-bit) bitmap by setting bits at the + specified positions. Input positions must be BIGINT values in [0, 32767]. + Null inputs are ignored. Returns an all-zeros bitmap for empty or all-null + input (this is a non-nullable aggregate). + .. spark:function:: bloom_filter_agg(hash, estimatedNumItems, numBits) -> varbinary Creates bloom filter from input hashes and returns it serialized into VARBINARY. @@ -55,13 +63,23 @@ General Aggregate Functions .. spark:function:: collect_list(x) -> array<[same as x]> - Returns an array created from the input ``x`` elements. Ignores null - inputs, and returns an empty array when all inputs are null. + Returns an array created from the input ``x`` elements. By default, + ignores null inputs and returns an empty array when all inputs are null. + + When the configuration property ``spark.collect_list.ignore_nulls`` is set + to ``false``, null values are included in the output array (RESPECT NULLS + behavior). In this mode, an all-null input produces an array of nulls + instead of an empty array. + +.. spark:function:: collect_set(x [, ignoreNulls]) -> array<[same as x]> -.. spark:function:: collect_set(x) -> array<[same as x]> + Returns an array consisting of all unique values from the input ``x`` elements. + When ``ignoreNulls`` is ``true``, null inputs are excluded and an all-null + input returns an empty array. NaN values are considered distinct. - Returns an array consisting of all unique values from the input ``x`` elements excluding NULLs. - NaN values are considered distinct. Returns empty array if input is empty or all NULL. + When ``ignoreNulls`` is set to ``false`` (RESPECT NULLS), null values are + included in the result set. In this mode, an all-null input produces an + array containing a single null instead of an empty array. Example:: diff --git a/velox/docs/functions/spark/array.rst b/velox/docs/functions/spark/array.rst index 42dbfe54e21..0cf2b019a91 100644 --- a/velox/docs/functions/spark/array.rst +++ b/velox/docs/functions/spark/array.rst @@ -171,13 +171,35 @@ Array Functions .. spark:function:: array_sort(array(E)) -> array(E) Returns an array which has the sorted order of the input array(E). The elements of array(E) must - be orderable. Null elements will be placed at the end of the returned array. :: + be orderable. NULL and NaN elements will be placed at the end of the returned array, with NaN elements appearing before NULL elements for floating-point types. :: SELECT array_sort(array(1, 2, 3)); -- [1, 2, 3] SELECT array_sort(array(3, 2, 1)); -- [1, 2, 3] - SELECT array_sort(array(2, 1, NULL); -- [1, 2, NULL] + SELECT array_sort(array(2, 1, NULL)); -- [1, 2, NULL] SELECT array_sort(array(NULL, 1, NULL)); -- [1, NULL, NULL] SELECT array_sort(array(NULL, 2, 1)); -- [1, 2, NULL] + SELECT array_sort(array(4.0, NULL, float('nan'), 3.0)); -- [3.0, 4.0, NaN, NULL] + SELECT array_sort(array(array(), array(1, 3, NULL), array(NULL, 6), NULL, array(2, 1))); -- [[], [NULL, 6], [1, 3, NULL], [2, 1], NULL] + +.. spark:function:: array_sort(array(E), function(E,U)) -> array(E) + :noindex: + + Returns the array sorted by values computed using specified lambda in ascending order. ``U`` must be an orderable type. + NULL and NaN elements returned by the lambda function will be placed at the end of the returned array, with NaN elements appearing before NULL elements. + This function is not supported in Spark and is only used inside Velox for rewriting :spark:func:`array_sort(array(E), function(E,E,U)) -> array(E)` as :spark:func:`array_sort(array(E), function(E,U)) -> array(E)`. :: + +.. spark:function:: array_sort(array(E), function(E,E,U)) -> array(E) + :noindex: + + Returns the array sorted by values computed using specified lambda in ascending + order. ``U`` must be an orderable type. + The function attempts to analyze the lambda function and rewrite it into a simpler call that + specifies the sort-by expression (like :spark:func:`array_sort(array(E), function(E,U)) -> array(E)`). For example, ``(left, right) -> if(length(left) > length(right), 1, if(length(left) < length(right), -1, 0))`` will be rewritten to ``x -> length(x)``. If rewrite is not possible, a user error will be thrown. + If the rewritten function returns NULL, the corresponding element will be placed at the end the returned array. Please note that due to this rewrite optimization, the NULL handling logics between Spark and Velox differ. In Spark, the position of NULL element is determined by the comparison of NULL with other elements. :: + + SELECT array_sort(array('cat', 'leopard', 'mouse'), (left, right) -> if(length(left) > length(right), 1, if(length(left) < length(right), -1, 0))); -- ['cat', 'mouse', 'leopard'] + select array_sort(array("abcd123", "abcd", NULL, "abc"), (left, right) -> if(length(left) > length(right), 1, if(length(left) < length(right), -1, 0))); -- ["abc", "abcd", "abcd123", NULL] + select array_sort(array("abcd123", "abcd", NULL, "abc"), (left, right) -> if(length(left) > length(right), 1, if(length(left) = length(right), 0, -1))); -- ["abc", "abcd", "abcd123", NULL] different with Spark: ["abc", NULL, "abcd", "abcd123"] .. spark:function:: array_union(array(E) x, array(E) y) -> array(E) @@ -260,6 +282,21 @@ Array Functions Returns true if value matches at least one of the elements of the array. Supports BOOLEAN, REAL, DOUBLE, BIGINT, VARCHAR, TIMESTAMP, DATE, DECIMAL input types. +.. spark:function:: sequence(start, stop) -> array(T) + sequence(start, stop, step) -> array(T) + + Generates an array of elements from ``start`` to ``stop``, incrementing by + ``step`` (default 1 if ``start <= stop``, otherwise -1). + + Supports tinyint, smallint, integer, and bigint types for start, stop, and + step, preserving the input type in the result array. Also supports date + inputs with day-to-second or year-to-month interval step, and timestamp + inputs with interval step. :: + + SELECT sequence(1, 5); -- [1, 2, 3, 4, 5] + SELECT sequence(5, 1); -- [5, 4, 3, 2, 1] + SELECT sequence(1, 10, 3); -- [1, 4, 7, 10] + .. spark:function:: shuffle(array(E), seed) -> array(E) Generates a random permutation of the given ``array`` using a seed derived diff --git a/velox/docs/functions/spark/conversion.rst b/velox/docs/functions/spark/conversion.rst index 22a7da562c4..ae37c420097 100644 --- a/velox/docs/functions/spark/conversion.rst +++ b/velox/docs/functions/spark/conversion.rst @@ -80,9 +80,15 @@ Valid examples From strings ^^^^^^^^^^^^ +*(ANSI compliant)* + Casting a string to an integral type is allowed if the string represents a number within the range of result type. -Casting from strings that represent floating-point numbers truncates the decimal part of the input value. -Casting from invalid input values throws. +Casting from strings that represent floating-point numbers truncates the +decimal part of the input value when ANSI mode is disabled; throws an +error otherwise. + +Casting from other invalid strings returns NULL when ANSI mode is disabled; +throws an error otherwise. Valid examples @@ -91,28 +97,28 @@ Valid examples SELECT cast('12345' as bigint); -- 12345 SELECT cast('+1' as tinyint); -- 1 SELECT cast('-1' as tinyint); -- -1 - SELECT cast('12345.67' as bigint); -- 12345 - SELECT cast('1.2' as tinyint); -- 1 - SELECT cast('-1.8' as tinyint); -- -1 + SELECT cast('12345.67' as bigint); -- 12345 (ANSI OFF) / ERROR (ANSI ON) + SELECT cast('1.2' as tinyint); -- 1 (ANSI OFF) / ERROR (ANSI ON) + SELECT cast('-1.8' as tinyint); -- -1 (ANSI OFF) / ERROR (ANSI ON) SELECT cast('+1' as tinyint); -- 1 - SELECT cast('1.' as tinyint); -- 1 + SELECT cast('1.' as tinyint); -- 1 (ANSI OFF) / ERROR (ANSI ON) SELECT cast('-1' as tinyint); -- -1 - SELECT cast('-1.' as tinyint); -- -1 - SELECT cast('0.' as tinyint); -- 0 - SELECT cast('.' as tinyint); -- 0 - SELECT cast('-.' as tinyint); -- 0 + SELECT cast('-1.' as tinyint); -- -1 (ANSI OFF) / ERROR (ANSI ON) + SELECT cast('0.' as tinyint); -- 0 (ANSI OFF) / ERROR (ANSI ON) + SELECT cast('.' as tinyint); -- 0 (ANSI OFF) / ERROR (ANSI ON) + SELECT cast('-.' as tinyint); -- 0 (ANSI OFF) / ERROR (ANSI ON) Invalid examples :: - SELECT cast('1234567' as tinyint); -- NULL // Reason: Out of range - SELECT cast('1a' as tinyint); -- NULL // Invalid argument - SELECT cast('' as tinyint); -- NULL // Invalid argument - SELECT cast('1,234,567' as bigint); -- NULL // Invalid argument - SELECT cast('1'234'567' as bigint); -- NULL // Invalid argument - SELECT cast('nan' as bigint); -- NULL // Invalid argument - SELECT cast('infinity' as bigint); -- NULL // Invalid argument + SELECT cast('1234567' as tinyint); -- NULL (ANSI OFF) / ERROR (ANSI ON) + SELECT cast('1a' as tinyint); -- NULL (ANSI OFF) / ERROR (ANSI ON) + SELECT cast('' as tinyint); -- NULL (ANSI OFF) / ERROR (ANSI ON) + SELECT cast('1,234,567' as bigint); -- NULL (ANSI OFF) / ERROR (ANSI ON) + SELECT cast('1'234'567' as bigint); -- NULL (ANSI OFF) / ERROR (ANSI ON) + SELECT cast('nan' as bigint); -- NULL (ANSI OFF) / ERROR (ANSI ON) + SELECT cast('infinity' as bigint); -- NULL (ANSI OFF) / ERROR (ANSI ON) From decimal ^^^^^^^^^^^^ @@ -132,7 +138,7 @@ Valid examples SELECT cast(cast(2147483648.90 as DECIMAL(12, 2)) as bigint); -- 2147483648 From timestamp -^^^^^^^^^^^^^ +^^^^^^^^^^^^^^ Casting timestamp as integral types returns the number of seconds by converting timestamp as microseconds, dividing by the number of microseconds in a second, and then rounding down to the nearest second since the epoch (1970-01-01 00:00:00 UTC). @@ -155,8 +161,12 @@ Cast to Boolean From VARCHAR ^^^^^^^^^^^^ -The strings `t, f, y, n, 1, 0, yes, no, true, false` and their upper case equivalents are allowed to be casted to boolean. -Casting from other strings to boolean throws. +*(ANSI compliant)* + +The strings `t, f, y, n, 1, 0, yes, no, true, false` and their upper case +equivalents are allowed to be cast to boolean. +Casting from invalid strings throws an error when ANSI mode is enabled, +or returns NULL when ANSI mode is disabled. Valid examples @@ -177,16 +187,42 @@ Invalid examples :: - SELECT cast('1.7E308' as boolean); -- NULL // Invalid argument - SELECT cast('nan' as boolean); -- NULL // Invalid argument - SELECT cast('infinity' as boolean); -- NULL // Invalid argument - SELECT cast('12' as boolean); -- NULL // Invalid argument - SELECT cast('-1' as boolean); -- NULL // Invalid argument - SELECT cast('tr' as boolean); -- NULL // Invalid argument - SELECT cast('tru' as boolean); -- NULL // Invalid argument + SELECT cast('1.7E308' as boolean); -- NULL (ANSI OFF) / ERROR (ANSI ON) + SELECT cast('nan' as boolean); -- NULL (ANSI OFF) / ERROR (ANSI ON) + SELECT cast('infinity' as boolean); -- NULL (ANSI OFF) / ERROR (ANSI ON) + SELECT cast('12' as boolean); -- NULL (ANSI OFF) / ERROR (ANSI ON) + SELECT cast('-1' as boolean); -- NULL (ANSI OFF) / ERROR (ANSI ON) + SELECT cast('tr' as boolean); -- NULL (ANSI OFF) / ERROR (ANSI ON) + SELECT cast('tru' as boolean); -- NULL (ANSI OFF) / ERROR (ANSI ON) Cast to String -------------- +From DECIMAL +^^^^^^^^^^^^ + +*(ANSI compliant)* + +Casting a DECIMAL to STRING returns a plain decimal value. +The scale is preserved and trailing zeros are kept for normal (non-scientific) form. +When the absolute value is less than:math:`10^{-6}`, the result is formatted in scientific notation (e.g. ``1.23E-8``). + +The conversion always succeeds with identical results for both ANSI ON and OFF modes. + +Valid examples + +:: + + SELECT cast(cast(1.00 as decimal(10, 2)) as string); -- '1.00' + SELECT cast(cast(12.30 as decimal(10, 2)) as string); -- '12.30' + SELECT cast(cast(0.00000012 as decimal(10, 8)) as string); -- '0.00000012' + SELECT cast(cast(-1.00 as decimal(10, 2)) as string); -- '-1.00' + SELECT cast(cast(123456789.123456789 as decimal(18, 9)) as string); -- '123456789.123456789' + SELECT cast(cast(0.00 as decimal(5, 2)) as string); -- '0.00' + SELECT cast(cast(999.99 as decimal(5, 2)) as string); -- '999.99' + SELECT cast(cast(-0.01 as decimal(3, 2)) as string); -- '-0.01' + SELECT cast(cast(1 as decimal(38, 20)) as string); -- '1E-20' + SELECT cast(cast(0 as decimal(10, 7)) as string); -- '0E-7' + SELECT cast(cast(123 as decimal(38, 10)) as string); -- '1.23E-8' From TIMESTAMP ^^^^^^^^^^^^^^ @@ -214,6 +250,8 @@ Cast to Date From strings ^^^^^^^^^^^^ +*(ANSI compliant)* + All Spark supported patterns are allowed: * ``[+-](YYYY-MM-DD)`` @@ -230,7 +268,9 @@ For the last two patterns, the trailing ``*`` can represent none or any sequence * "1970-01-01 (BC)" All leading and trailing UTF8 white-spaces will be trimmed before cast. -Casting from invalid input values throws. + +When ANSI mode is enabled, casting from invalid input values throws an error. +When ANSI mode is disabled, casting from invalid input values returns NULL. Valid examples diff --git a/velox/docs/functions/spark/datetime.rst b/velox/docs/functions/spark/datetime.rst index d87f5819885..958179a2bf1 100644 --- a/velox/docs/functions/spark/datetime.rst +++ b/velox/docs/functions/spark/datetime.rst @@ -106,6 +106,15 @@ These functions support TIMESTAMP and DATE input types. SELECT datediff('2009-07-31', '2009-07-30'); -- 1 SELECT datediff('2009-07-30', '2009-07-31'); -- -1 +.. spark:function:: dayname(date) -> varchar + + Returns the three-letter abbreviated day name from the given date (Sun, Mon, Tue, Wed, Thu, Fri, Sat). :: + + SELECT dayname('2009-07-30'); -- 'Thu' + SELECT dayname('2023-08-20'); -- 'Sun' + SELECT dayname('2023-08-21'); -- 'Mon' + SELECT dayname('1582-10-15'); -- 'Fri' + .. spark:function:: dayofmonth(date) -> integer Returns the day of month of the date. :: @@ -244,6 +253,30 @@ These functions support TIMESTAMP and DATE input types. SELECT month('2009-07-30'); -- 7 +.. spark:function:: monthname(date) -> varchar + + Returns the three-letter abbreviated month name for the given ``date``. + Possible values: Jan, Feb, Mar, Apr, May, Jun, Jul, Aug, Sep, Oct, Nov, Dec. :: + + SELECT monthname('2008-02-20'); -- 'Feb' + SELECT monthname('2011-05-06'); -- 'May' + SELECT monthname('2023-08-20'); -- 'Aug' + SELECT monthname('1582-10-15'); -- 'Oct' + +.. spark:function:: months_between(timestamp1, timestamp2, roundOff) -> double + + Returns number of months between times ``timestamp1`` and ``timestamp2``. + If ``timestamp1`` is later than ``timestamp2``, the result is positive. + If ``timestamp1`` and ``timestamp2`` are on the same day of month, or both are the + last day of month, time of day will be ignored. Otherwise, the difference is calculated + based on 31 days per month, and rounded to 8 digits unless ``roundOff`` is false. :: + + SELECT months_between('1997-02-28 10:30:00', '1996-10-30', true); -- 3.94959677 + SELECT months_between('1997-02-28 10:30:00', '1996-10-30', false); -- 3.9495967741935485 + SELECT months_between('1997-02-28 10:30:00', '1996-03-31 11:00:00', true); -- 11.0 + SELECT months_between('1997-02-28 10:30:00', '1996-03-28 11:00:00', true); -- 11.0 + SELECT months_between('1997-02-21 10:30:00', '1996-03-21 11:00:00', true); -- 11.0 + .. spark:function:: next_day(startDate, dayOfWeek) -> date Returns the first date which is later than ``startDate`` and named as ``dayOfWeek``. @@ -273,7 +306,7 @@ These functions support TIMESTAMP and DATE input types. .. spark:function:: timestampadd(unit, value, timestamp) -> timestamp - Adds an interval ``value`` of type ``unit`` to ``timestamp``. + Adds an int or bigint interval ``value`` of type ``unit`` to ``timestamp``. Subtraction can be performed by using a negative ``value``. Throws exception if ``unit`` is invalid. ``unit`` is case insensitive and must be one of the following: @@ -322,12 +355,14 @@ These functions support TIMESTAMP and DATE input types. converts the number of seconds to a timestamp. For floating-point types (FLOAT, DOUBLE), the function scales the input to microseconds, truncates towards zero, and saturates the result to the minimum and maximum values allowed - in Spark.:: + in Spark. Returns NULL when ``x`` is NaN or Infinity. :: SELECT timestamp_seconds(1230219000); -- '2008-12-25 15:30:00' SELECT timestamp_seconds(1230219000.123); -- '2008-12-25 15:30:00.123' SELECT timestamp_seconds(double(1.1234567)); -- '1970-01-01 00:00:01.123456' + SELECT timestamp_seconds(double('inf')); -- NULL SELECT timestamp_seconds(float(3.4028235E+38)); -- '+294247-01-10 04:00:54.775807' + SELECT timestamp_seconds(float('nan')); -- NULL .. spark:function:: to_unix_timestamp(date) -> bigint :noindex: diff --git a/velox/docs/functions/spark/decimal.rst b/velox/docs/functions/spark/decimal.rst index d4b33dcbe08..1afd15d5e88 100644 --- a/velox/docs/functions/spark/decimal.rst +++ b/velox/docs/functions/spark/decimal.rst @@ -8,6 +8,14 @@ The result may exceed maximum allowed precision of 38. Second stage caps precision at 38 and either reduces the scale or not depending on allow-precision-loss flag. +The allow-precision-loss flag applies to both regular and checked (ANSI mode) arithmetic functions. +In Spark, there are no separate checked expression classes. The same expression (e.g., ``Add``) +handles both ANSI and non-ANSI behavior, controlled by an ``EvalMode`` flag. In Velox, the checked +variants are registered as separate functions (e.g., ``checked_add``, ``checked_subtract``) +to support the TRY evaluation mode (e.g., ``try(checked_add(...))`` returns NULL on overflow). + +Regular functions return NULL on overflow, while checked functions throw an error. + For example, addition of decimal(38, 7) and decimal(10, 0) requires precision of 39 and scale of 7. Since precision exceeds 38 it needs to be capped. When allow-precision-loss, precision is capped at 38 and scale is reduced by 1 to 6. When allow-precision-loss is false, precision is capped at 38 as well, but scale is kept at 7. @@ -38,7 +46,7 @@ The HiveQL behavior: https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf Additionally, the computation of decimal division adapts to the allow-precision-loss flag, -while the decimal addition, subtraction, and multiplication do not. +while the decimal addition, subtraction, multiplication and integer division do not. Addition and Subtraction ~~~~~~~~~~~~~~~~~~~~~~~~ @@ -74,6 +82,15 @@ When allow-precision-loss is false: p = wholeDigits + fractionalDigits s = fractionalDigits +Integer Division +~~~~~~~~~~~~~~~~ + +:: + + precision = p1 - s1 + s2 + p = precision == 0 ? 1 : min(38, precision) + s = 0 + Decimal Precision and Scale Adjustment -------------------------------------- @@ -112,6 +129,81 @@ Decimal division uses a different formula: Returns NULL when the actual result cannot be represented with the calculated decimal type. +Arithmetic Functions +-------------------- + +.. spark:function:: add(x: decimal(p1, s1), y: decimal(p2, s2)) -> r: decimal(p3, s3) + + Returns the result of adding ``x`` and ``y``. The result type is determined + by the precision and scale computation rules described above. + Returns NULL when the result overflows. + Corresponds to Spark's operator ``+`` with ``spark.sql.ansi.enabled`` set to false. :: + + SELECT CAST(1.1 as DECIMAL(3, 1)) + CAST(2.2 as DECIMAL(3, 1)); -- 3.3 + SELECT CAST('99999999999999999999999999999999999999' as DECIMAL(38, 0)) + CAST(1 as DECIMAL(38, 0)); -- NULL + +.. spark:function:: checked_add(x: decimal(p1, s1), y: decimal(p2, s2)) -> r: decimal(p3, s3) + + Returns the result of adding ``x`` and ``y``. The result type is determined + by the precision and scale computation rules described above. + Throws an error when the result overflows. + Corresponds to Spark's operator ``+`` with ``spark.sql.ansi.enabled`` set to true. + +.. spark:function:: checked_div(x: decimal(p1, s1), y: decimal(p2, s2)) -> bigint + + Performs integer division and returns the bigint result of dividing ``x`` by ``y``, truncating toward zero. + Truncation occurs if the result is within the result precision but exceeds the BIGINT range. + Division by zero or overflow results in an error. + Does not have ``allow-precision-loss`` variants because ``IntegralDivide`` always returns + ``LongType`` (result scale is 0), so precision loss is not applicable. + Corresponds to Spark's operator ``div`` with ``spark.sql.ansi.enabled`` set to true. + +.. spark:function:: checked_multiply(x: decimal(p1, s1), y: decimal(p2, s2)) -> r: decimal(p3, s3) + + Returns the result of multiplying ``x`` and ``y``. The result type is determined + by the precision and scale computation rules described above. + Throws an error when the result overflows. + Corresponds to Spark's operator ``*`` with ``spark.sql.ansi.enabled`` set to true. + +.. spark:function:: checked_subtract(x: decimal(p1, s1), y: decimal(p2, s2)) -> r: decimal(p3, s3) + + Returns the result of subtracting ``y`` from ``x``. The result type is determined + by the precision and scale computation rules described above. + Throws an error when the result overflows. + Corresponds to Spark's operator ``-`` with ``spark.sql.ansi.enabled`` set to true. + +.. spark:function:: div(x: decimal(p1, s1), y: decimal(p2, s2)) -> bigint + + Performs integer division and returns the bigint result of dividing ``x`` by ``y``, truncating toward zero. + Truncation occurs if the result is within the result precision but exceeds the BIGINT range. + Division by zero or overflow results in NULL. Does not respect the ``allow-precision-loss`` configuration. + Corresponds to Spark's operator ``div`` with ``spark.sql.ansi.enabled`` set to false. :: + + SELECT CAST(1 as DECIMAL(17, 3)) div CAST(2 as DECIMAL(17, 3)); -- 0 + SELECT CAST(21 as DECIMAL(20, 3)) div CAST(20 as DECIMAL(20, 2)); -- 1 + SELECT CAST(1 as DECIMAL(20, 3)) div CAST(0 as DECIMAL(20, 3)); -- NULL + SELECT CAST(99999999999999999999999999999999999 as DECIMAL(38, 1)) div CAST(0.001 as DECIMAL(7, 4)); -- 687399551400672280 // Result is truncated to int64_t. + +.. spark:function:: multiply(x: decimal(p1, s1), y: decimal(p2, s2)) -> r: decimal(p3, s3) + + Returns the result of multiplying ``x`` and ``y``. The result type is determined + by the precision and scale computation rules described above. + Returns NULL when the result overflows. + Corresponds to Spark's operator ``*`` with ``spark.sql.ansi.enabled`` set to false. :: + + SELECT CAST(1.1 as DECIMAL(3, 1)) * CAST(2.0 as DECIMAL(3, 1)); -- 2.20 + SELECT CAST('99999999999999999999999999999999999999' as DECIMAL(38, 0)) * CAST(10 as DECIMAL(38, 0)); -- NULL + +.. spark:function:: subtract(x: decimal(p1, s1), y: decimal(p2, s2)) -> r: decimal(p3, s3) + + Returns the result of subtracting ``y`` from ``x``. The result type is determined + by the precision and scale computation rules described above. + Returns NULL when the result overflows. + Corresponds to Spark's operator ``-`` with ``spark.sql.ansi.enabled`` set to false. :: + + SELECT CAST(1.1 as DECIMAL(3, 1)) - CAST(2.2 as DECIMAL(3, 1)); -- -1.1 + SELECT CAST('-99999999999999999999999999999999999999' as DECIMAL(38, 0)) - CAST(1 as DECIMAL(38, 0)); -- NULL + Decimal Functions ----------------- .. spark:function:: ceil(x: decimal(p, s)) -> r: decimal(pr, 0) diff --git a/velox/docs/functions/spark/json.rst b/velox/docs/functions/spark/json.rst index c5fef47a773..c7f37a7ffa5 100644 --- a/velox/docs/functions/spark/json.rst +++ b/velox/docs/functions/spark/json.rst @@ -59,10 +59,7 @@ JSON Functions .. spark:function:: get_json_object(jsonString, path) -> varchar Returns a json object, represented by VARCHAR, from ``jsonString`` by searching ``path``. - Valid ``path`` should start with '$' and then contain "[index]", "['field']" or ".field" - to define a JSON path. Here are some examples: "$.a" "$.a.b", "$[0]['a'].b". Returns - ``jsonString`` if ``path`` is "$". Returns NULL if ``jsonString`` or ``path`` is malformed. - Returns NULL if ``path`` does not exist. :: + Returns NULL if ``jsonString`` or ``path`` is malformed or ``path`` does not exist. :: SELECT get_json_object('{"a":"b"}', '$.a'); -- 'b' SELECT get_json_object('{"a":{"b":"c"}}', '$.a'); -- '{"b":"c"}' @@ -70,6 +67,12 @@ JSON Functions SELECT get_json_object('{"a"-3}'', '$.a'); -- NULL (malformed JSON string) SELECT get_json_object('{"a":3}'', '.a'); -- NULL (malformed JSON path) + Valid ``path`` syntax: + * Must start with '$'. + * Using "[index]", "['field']" or ".field" to navigate to the desired JSON object. + * Whitespace is allowed **after the dot** and **before the field name**, e.g., "$. field". + * Trailing whitespace after '$' is allowed, e.g., "$ ". + .. spark:function:: json_array_length(jsonString) -> integer Returns the number of elements in the outermost JSON array from ``jsonString``. diff --git a/velox/docs/functions/spark/map.rst b/velox/docs/functions/spark/map.rst index 85532a6a12a..667cd833c56 100644 --- a/velox/docs/functions/spark/map.rst +++ b/velox/docs/functions/spark/map.rst @@ -51,6 +51,16 @@ Map Functions SELECT map_from_arrays(array(1.0, 3.0), array('2', '4')); -- {1.0 -> 2, 3.0 -> 4} +.. spark:function:: map_from_entries(array(struct(K,V))) -> map(K,V) + + Returns a map created from the given array of entries. Throws exception if duplicate key or NULL + key is found. Returns NULL if NULL entry exists. :: + + SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'null'))); -- {1 -> 'a', 2 -> 'null'} + SELECT map_from_entries(array(struct(1, 'a'), null)); -- {null} + SELECT map_from_entries(array(struct(null, 'a'))); -- "map key cannot be null" + SELECT map_from_entries(array(struct(1, 'a'), struct(1, 'b'))); -- "Duplicate map keys (1) are not allowed" + .. spark:function:: map_keys(x(K,V)) -> array(K) Returns all the keys in the map ``x``. diff --git a/velox/docs/functions/spark/math.rst b/velox/docs/functions/spark/math.rst index 08a11a2b60e..a5f211b9bc6 100644 --- a/velox/docs/functions/spark/math.rst +++ b/velox/docs/functions/spark/math.rst @@ -2,12 +2,17 @@ Mathematical Functions ====================== -.. spark:function:: abs(x) -> [same as x] +.. spark:function:: abs(x) -> [same as x] (ANSI compliant) Returns the absolute value of ``x``. When ``x`` is negative minimum - value of integral type, returns the same value as ``x`` following - the behavior when Spark ANSI mode is disabled. - + value of integral type returns the same value as ``x`` following + the behavior when Spark ANSI mode is disabled and throws exception + when Spark ANSI mode is enabled. :: + + SELECT abs(-42); -- 42 + SELECT abs(3.14); -- 3.14 + SELECT abs(-128); -- 128 (with ANSI mode disabled) + SELECT abs(-128); -- Overflow exception (with ANSI mode enabled for TINYINT) .. spark:function:: acos(x) -> double Returns the inverse cosine (a.k.a. arc cosine) of ``x``. @@ -78,7 +83,14 @@ Mathematical Functions Returns the result of adding x to y. The types of x and y must be the same. For integral types, overflow results in an error. Corresponds to Spark's operator ``+`` with ``failOnError`` as true. -.. function:: checked_divide(x, y) -> [same as x] +.. function:: checked_div(x, y) -> bigint + + Returns the result of integer division of ``x`` by ``y``, truncating toward zero. + Supported types are integral types, ``x`` and ``y`` must have the same type. + Division by zero or overflow results in an error. This function operates in ANSI mode (error on invalid input). + Corresponds to Spark's operator ``div`` with ``spark.sql.ansi.enabled`` set to true. + +.. spark:function:: checked_divide(x, y) -> [same as x] Returns the results of dividing x by y. The types of x and y must be the same. Division by zero results in an error. Corresponds to Spark's operator ``/`` with ``failOnError`` as true. @@ -113,6 +125,16 @@ Mathematical Functions Converts angle x in radians to degrees. +.. spark:function:: div(x, y) -> bigint + + Returns the results of dividing x by y. Performs the integer division truncates toward zero. + Supported types are integral types, x and y must have the same type. + Division by zero or overflow results in null. :: + + SELECT 3 div 2; -- 1 + SELECT 1L div 2L; -- 0 + SELECT 3 div 0; -- NULL + .. spark:function:: divide(x, y) -> double Returns the results of dividing x by y. Performs floating point division. @@ -256,8 +278,8 @@ Mathematical Functions `spark.partition_id` to each thread (in a deterministic way) . ``seed`` must be constant. NULL ``seed`` is identical to zero ``seed``. :: - SELECT rand(0); -- 0.5488135024422883 - SELECT rand(NULL); -- 0.5488135024422883 + SELECT rand(0); -- 0.7604953758285915 + SELECT rand(NULL); -- 0.7604953758285915 .. spark:function:: random() -> double diff --git a/velox/docs/functions/spark/misc.rst b/velox/docs/functions/spark/misc.rst index 311586faa6f..445921ce442 100644 --- a/velox/docs/functions/spark/misc.rst +++ b/velox/docs/functions/spark/misc.rst @@ -2,6 +2,19 @@ Miscellaneous Functions ======================= +.. spark:function:: assert_not_null(value) -> value + + Returns the input ``value`` if it is not null. Throws an error if + ``value`` is null. Used to enforce NOT NULL column constraints during + table inserts. + +.. spark:function:: assert_not_null(value, errMsg) -> value + :noindex: + + A version of ``assert_not_null`` that uses a custom error message. + ``errMsg`` is a constant VARCHAR specifying the error message to throw + when ``value`` is null. + .. spark:function:: at_least_n_non_nulls(n, value1, value2, ..., valueN) -> bool Returns true if there are at least ``n`` non-null and non-NaN values, diff --git a/velox/docs/functions/spark/string.rst b/velox/docs/functions/spark/string.rst index 6e8c5f96de8..d39c86592c7 100644 --- a/velox/docs/functions/spark/string.rst +++ b/velox/docs/functions/spark/string.rst @@ -286,6 +286,27 @@ String Functions SELECT overlay('Spark SQL', 'tructured', 2, 4); -- "Structured SQL" SELECT overlay('Spark SQL', '_', -6, 3); -- "_Sql" +.. spark:function:: randstr(length, seed) -> varchar + + Returns a string of the specified ``length`` whose characters are chosen uniformly + at random from the following pool of characters: 0-9, a-z, A-Z. + Both ``length`` and ``seed`` must be non-null constants. + ``length`` must be a non-negative integer (SMALLINT or INT). + ``seed`` must be an integer (INT or BIGINT). + With the same ``seed`` and partition ID, the function produces a reproducible sequence + of outputs, though each row receives a different value from the sequence as the + internal generator advances. + The partition ID is retrieved from the ``spark_partition_id`` query configuration. + It's consistent with Spark's internal assignment for tasks. + Uses XORShift random number generator matching Spark's implementation. + Note: Spark's analyzer always provides a seed (either user-specified or + auto-generated), so only the seeded variant is implemented. + This function was added in Spark 4.0. :: + + SELECT randstr(5, 0); -- "ceV0P" (reproducible with seed) + SELECT randstr(10, 0); -- "ceV0PXaR2I" + SELECT randstr(0, 42); -- "" + .. spark:function:: read_side_padding(string, limit) -> varchar Right-pads the given string with spaces to the specified length ``limit``. @@ -458,6 +479,26 @@ String Functions SELECT substring_index('aaaaa', 'aa', 5); -- "aaaaa" SELECT substring_index('aaaaa', 'aa', -5); -- "aaaaa" +.. spark:function:: to_pretty_string(x) -> varchar + + Returns pretty string for ``x``. All scalar types are supported. + Adjusts the timestamp input to the given time zone if set through ``session_timezone`` config. + The result is different from that of casting ``x`` as string in the following aspects. + + - It prints null input as "NULL" rather than producing null output. + + - It prints binary values using the hex format. + + :: + + SELECT to_pretty_string(4); -- "4" + SELECT to_pretty_string(cast("1.0" as float)); -- "1.0" + SELECT to_pretty_string("spark"); -- "spark" + SELECT to_pretty_string(cast('abcdef' as binary)); -- "[61 62 63 64 65 66]" + SELECT to_pretty_string(null); -- "NULL" + SELECT to_pretty_string(cast(2347589 as timestamp)); -- "1970-01-28 12:06:29" + SELECT to_pretty_string(cast('2024-05-08' as date)); -- "2024-05-08" + .. spark:function:: translate(string, match, replace) -> varchar Returns a new translated string. It translates the character in ``string`` by a diff --git a/velox/docs/index.rst b/velox/docs/index.rst index 8fb3f14161b..f1d65f2065e 100644 --- a/velox/docs/index.rst +++ b/velox/docs/index.rst @@ -10,6 +10,7 @@ Velox Documentation functions spark_functions functions/iceberg/functions + functions/delta/functions configs monitoring bindings/python/index diff --git a/velox/docs/monitoring/metrics.rst b/velox/docs/monitoring/metrics.rst index 2764a1c46dd..f16bfe0530e 100644 --- a/velox/docs/monitoring/metrics.rst +++ b/velox/docs/monitoring/metrics.rst @@ -93,13 +93,16 @@ Task Execution 30 buckets. It is configured to report the latency at P50, P90, P99, and P100 percentiles. * - task_batch_process_time_ms - - Average + - Avg - Tracks the averaged task batch processing time. This only applies for sequential task execution mode. * - task_barrier_process_time_ms - Histogram - Tracks task barrier execution time in range of [0, 30s] with 30 buckets and each bucket with time window of 1s. We report P50, P90, P99, and P100. + * - task_splits_count + - Count + - The total number of splits received by all tasks. Memory Management ----------------- @@ -160,9 +163,6 @@ Memory Management * - task_memory_reclaim_wait_timeout_count - Count - The number of times that the task memory reclaim wait timeouts. - * - task_splits_count - - Count - - The total number of splits received by all tasks. * - memory_non_reclaimable_count - Count - The number of times that the memory reclaim fails because the operator is executing a @@ -226,11 +226,11 @@ Memory Management arbitration operation in range of [0, 600s] with 20 buckets. It is configured to report the latency at P50, P90, P99 and P100 percentiles. * - arbitrator_free_capacity_bytes - - Average + - Avg - The average of total free memory capacity which is managed by the memory arbitrator. * - arbitrator_free_reserved_capacity_bytes - - Average + - Avg - The average of free memory capacity reserved to ensure each query has the minimal required capacity to run. * - memory_pool_initial_capacity_bytes @@ -292,9 +292,12 @@ Cache - Avg - Max possible age of AsyncDataCache and SsdCache entries since the raw file was opened to load the cache. - * - memory_cache_num_entries + * - memory_cache_num_large_entries + - Avg + - Total number of large cache entries. + * - memory_cache_num_tiny_entries - Avg - - Total number of cache entries. + - Total number of tiny cache entries. * - memory_cache_num_empty_entries - Avg - Total number of cache entries that do not cache anything. @@ -416,9 +419,15 @@ Cache * - ssd_cache_write_ssd_errors - Sum - Total number of error while writing to SSD cache files. + * - ssd_cache_write_no_space_errors + - Sum + - Total number of errors due to SSD no space for writes. * - ssd_cache_write_ssd_dropped - Sum - Total number of writes dropped due to no cache space. + * - ssd_cache_write_exceed_entry_limit + - Sum + - Total number of writes dropped due to entry limit exceeded. * - ssd_cache_write_checkpoint_errors - Sum - Total number of errors while writing SSD checkpoint file. @@ -511,9 +520,9 @@ Spilling - The distribution of the amount of time spent on serializing rows for spilling in range of [0, 600s] with 20 buckets. It is configured to report the latency at P50, P90, P99, and P100 percentiles. - * - spill_disk_writes_count + * - spill_writes_count - Count - - The number of disk writes to spill rows. + - The number of Velox filesystem write calls to spill rows. * - spill_flush_time_ms - Histogram - The distribution of the amount of time spent on copy out serialized @@ -619,6 +628,9 @@ Index Join - The distribution of index lookup result bytes in range of [0, 128MB] with 128 buckets. It is configured to report the capacity at P50, P90, P99, and P100 percentiles. + * - index_lookup_error_result_count + - Count + - The number of results with error. Table Scan ---------- @@ -631,10 +643,10 @@ Table Scan - Type - Description * - table_scan_batch_process_time_ms - - Average + - Avg - Tracks the averaged table scan batch processing time in milliseconds. * - table_scan_batch_bytes - - Average + - Avg - Tracks the averaged table scan output batch size in bytes. with 512 buckets and reports P50, P90, P99, and P100 diff --git a/velox/docs/monitoring/stats.rst b/velox/docs/monitoring/stats.rst index f6823991773..aa14d81d78e 100644 --- a/velox/docs/monitoring/stats.rst +++ b/velox/docs/monitoring/stats.rst @@ -104,6 +104,11 @@ These stats are reported only by TableScan operator * - numRunningScanThreads - - The number of running table scan drivers. + * - fileFormat. + - + - The number of splits read for each file format (e.g. fileFormat.dwrf, + fileFormat.parquet, fileFormat.nimble). Reported per format encountered + during the query. TableWriter ----------- @@ -157,17 +162,34 @@ These stats are reported only by IndexLookupJoin operator * - Stats - Unit - Description + * - connectorIndexReadCpuNanos + - nanos + - CPU time spent reading index data from the index reader (e.g. stripe I/O, decoding). + * - connectorIndexReadWallNanos + - nanos + - Wall time spent reading index data from the index reader (e.g. stripe I/O, decoding). + * - connectorIndexSetupCpuNanos + - nanos + - CPU time spent initializing the index lookup (startLookup). + * - connectorIndexSetupWallNanos + - nanos + - Wall time spent initializing the index lookup (startLookup). * - connectorlookupWallNanos - nanos - - The end-to-end walltime in nanoseconds that the index connector do the lookup. + - End-to-end wall time for the index connector lookup (sum of setup, read, output, and filter). * - connectorlookupWaitWallNanos - nanos - The walltime in nanoseconds that the index connector wait for the lookup from remote storage. + * - connectorPostFilterCpuNanos + - nanos + - CPU time spent evaluating the remaining filter on index lookup results. + * - connectorPostFilterWallNanos + - nanos + - Wall time spent evaluating the remaining filter on index lookup results. * - connectorResultPrepareCpuNanos - nanos - - The cpu time in nanoseconds that the index connector process response from storages - client for followup processing by index join operator. + - CPU time spent projecting output columns from index reader results. * - clientlookupWaitWallNanos - nanos - The walltime in nanoseconds that the storage client wait for the lookup from remote storage. @@ -192,6 +214,9 @@ These stats are reported only by IndexLookupJoin operator * - clientNumLazyDecodedResultBatches - - The number of lazy decoded result batches returned from the storage client. + * - numIndexSplits + - + - The number of index splits provided for index lookup. Merge ----- @@ -342,6 +367,177 @@ These stats are reported only by connector data or index sources. * - Stats - Unit - Description + * - ioWaitWallNanos + - nanos + - Total time spent by query processing threads waiting for I/O operations + to complete. This includes waiting for synchronously issued I/O or for + in-progress read-ahead operations to finish. + * - storageReadWallNanos + - nanos + - Time spent waiting for direct remote storage reads (e.g., S3, HDFS). + This is a component of ioWaitWallNanos. + * - ssdCacheReadWallNanos + - nanos + - Time spent waiting for SSD cache reads. This is a component of + ioWaitWallNanos. + * - cacheWaitWallNanos + - nanos + - Time spent waiting for cache entries that are being loaded by another + thread (EXCLUSIVE state). This is a component of ioWaitWallNanos. + * - coalescedSsdLoadWallNanos + - nanos + - Time spent waiting for coalesced loads from SSD cache. This occurs when + multiple requests are combined into a single SSD read operation. + This is a component of ioWaitWallNanos. + * - coalescedStorageLoadWallNanos + - nanos + - Time spent waiting for coalesced loads from remote storage. This occurs + when multiple requests are combined into a single remote storage read. + This is a component of ioWaitWallNanos. * - totalRemainingFilterWallNanos - nanos - The total walltime in nanoseconds that the data or index connector do the remaining filtering. + * - totalRemainingFilterCpuNanos + - nanos + - The total CPU time in nanoseconds that the data or index connector do the remaining filtering. + * - numIndexReaderOutputRows + - + - The total number of output rows returned across all next() calls from the + index reader. This is the final row count after cluster index bounds + and ScanSpec filter pushdown. + * - numIndexFilterConversions + - + - The number of index columns that were converted from ScanSpec filters to + index bounds for index-based filtering (e.g., cluster index pruning in + Nimble). A value greater than zero indicates filters were successfully + converted to leverage file index structures for row pruning. + * - numIndexLookupReadSegments + - + - The total number of read segments across all stripes during index lookup. + A read segment is a contiguous row range within a stripe that needs to be + read. When filters are present, overlapping request ranges are split at + boundaries to enable per-request output tracking. Without filters, + overlapping ranges are merged to minimize I/O. + * - numIndexLookupRequests + - + - The number of index lookup requests submitted in startLookup(). Each + request corresponds to one set of index bounds and may match rows across + multiple stripes. + * - numIndexLookupStripes + - + - The total number of stripes that need to be read for all index lookup + requests. Within a single startLookup() call, a stripe shared by + multiple requests is counted once; across different startLookup() calls, + the same stripe is counted separately for each call. + * - numIndexMatchedRows + - + - The total number of rows matched by the cluster index across all stripes. + These are the rows identified as matching the lookup bounds within each + stripe, before any ScanSpec filter pushdown. Comparing with actual output + rows shows filter selectivity. + * - numIndexScannedRows + - + - The total number of rows in all loaded stripes during index lookup. + Measures the full stripe row count regardless of how many rows are + actually needed. Comparing with numIndexMatchedRows shows cluster index + selectivity within stripes. + * - numStripeLoads + - + - The number of times a stripe has been loaded during index lookup. This + metric helps track the I/O efficiency of index-based reads, where lower + values indicate better stripe reuse across lookups. + * - numIndexDistinctStripesLoaded + - + - The number of distinct stripes loaded across the lifetime of the index + reader. Comparing with numStripeLoads (which counts every load call) + reveals redundant loads of the same stripe. + * - indexStripeLoadWallNanos + - nanos + - Wall time spent loading stripes (or equivalent format-specific load + unit) during index lookup, summed across all stripe loads. + * - indexStripeLoadCpuNanos + - nanos + - CPU time spent loading stripes during index lookup. May undercount on + async/prefetch paths. + * - indexDataDecodeWallNanos + - nanos + - Wall time spent decoding column data from loaded stripes during index + lookup, summed across all read segments. + * - indexDataDecodeCpuNanos + - nanos + - CPU time spent decoding column data from loaded stripes during index + lookup. Same prefetch undercounting caveat as indexStripeLoadCpuNanos. + +FileBasedDataSource +------------------- +These stats are reported by the file-based connector data source (Hive connector). +Data stream IO stats use the stat names directly (e.g., ``storageReadBytes``). +Metadata IO stats (footer, stripe groups, index) use a ``metadata.`` prefix +(e.g., ``metadata.storageReadBytes``, ``metadata.ramReadBytes``). + +.. list-table:: + :widths: 50 25 50 + :header-rows: 1 + + * - Stats + - Unit + - Description + * - skippedSplits + - + - The number of splits skipped based on file statistics. + * - processedSplits + - + - The number of splits processed. + * - skippedSplitBytes + - bytes + - The total bytes in splits skipped based on file statistics. + * - skippedStrides + - + - The number of strides (row groups) skipped based on statistics. + * - processedStrides + - + - The number of strides (row groups) processed. + * - footerBufferOverread + - bytes + - The number of extra bytes read beyond the footer size due to buffer + over-reading. + * - numStripes + - + - The number of stripes read from the file. + * - flattenStringDictionaryValues + - + - The number of rows returned by the string dictionary reader that were + flattened instead of keeping dictionary encoding. + * - pageLoadTimeNs + - nanos + - The total time spent loading pages. + * - numPrefetch + - + - The number of prefetch operations issued. + * - prefetchBytes + - bytes + - The total bytes prefetched, including min and max per prefetch operation. + * - totalScanTime + - nanos + - The total wall time spent scanning the file. + * - overreadBytes + - bytes + - The total raw bytes over-read during I/O operations. + * - storageReadBytes + - bytes + - The total bytes read from remote storage, including min and max per read + operation. + * - numLocalRead + - + - The number of reads served from the local SSD cache. + * - localReadBytes + - bytes + - The total bytes read from the local SSD cache, including min and max per + read operation. + * - numRamRead + - + - The number of reads served from the in-memory (RAM) cache. + * - ramReadBytes + - bytes + - The total bytes read from the in-memory (RAM) cache, including min and + max per read operation. diff --git a/velox/docs/monthly-updates/may-2025.rst b/velox/docs/monthly-updates/may-2025.rst index fcacf090853..d049377a9a5 100644 --- a/velox/docs/monthly-updates/may-2025.rst +++ b/velox/docs/monthly-updates/may-2025.rst @@ -1,6 +1,6 @@ -************** +*************** May 2025 Update -************** +*************** This update was generated with the assistance of AI. While we strive for accuracy, please note that AI-generated content may not always be error-free. We encourage you to verify any information diff --git a/velox/duckdb/conversion/CMakeLists.txt b/velox/duckdb/conversion/CMakeLists.txt index 010f13f8f35..d9a8de03642 100644 --- a/velox/duckdb/conversion/CMakeLists.txt +++ b/velox/duckdb/conversion/CMakeLists.txt @@ -11,11 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -velox_add_library(velox_duckdb_conversion DuckConversion.cpp) +velox_add_library(velox_duckdb_conversion DuckConversion.cpp HEADERS DuckConversion.h) velox_link_libraries(velox_duckdb_conversion velox_core velox_vector duckdb_static) -velox_add_library(velox_duckdb_parser DuckParser.cpp) +velox_add_library(velox_duckdb_parser DuckParser.cpp HEADERS DuckParser.h) velox_link_libraries( velox_duckdb_parser diff --git a/velox/duckdb/conversion/DuckConversion.cpp b/velox/duckdb/conversion/DuckConversion.cpp index 10fc820c6b1..19dce42cd86 100644 --- a/velox/duckdb/conversion/DuckConversion.cpp +++ b/velox/duckdb/conversion/DuckConversion.cpp @@ -79,6 +79,10 @@ LogicalType fromVeloxType(const TypePtr& type) { if (type->isIntervalDayTime()) { return LogicalType::INTERVAL; } + if (type->isTime()) { + VELOX_DCHECK(type->equivalent(*TIME())); + return LogicalType::TIME; + } return LogicalType::BIGINT; case TypeKind::REAL: return LogicalType::FLOAT; @@ -138,6 +142,8 @@ TypePtr toVeloxType(LogicalType type, bool fileColumnNamesReadAsLowerCase) { return VARCHAR(); case LogicalTypeId::DATE: return DATE(); + case LogicalTypeId::TIME: + return TIME(); case LogicalTypeId::TIMESTAMP: return TIMESTAMP(); case LogicalTypeId::TIMESTAMP_TZ: { @@ -192,6 +198,9 @@ TypePtr toVeloxType(LogicalType type, bool fileColumnNamesReadAsLowerCase) { if (auto customType = getCustomType(name, {})) { return customType; } + if (name == "OPAQUE") { + return OPAQUE(); + } [[fallthrough]]; } default: diff --git a/velox/duckdb/conversion/DuckParser.cpp b/velox/duckdb/conversion/DuckParser.cpp index d65406f693e..89532695bd0 100644 --- a/velox/duckdb/conversion/DuckParser.cpp +++ b/velox/duckdb/conversion/DuckParser.cpp @@ -133,6 +133,16 @@ std::shared_ptr callExpr( std::vector params, std::optional alias, const ParseOptions& options) { + // DuckDB parser requires IF to have 3 arguments: condition, then-clause, and + // else-clause. For example, `IF(a > b, 10)` doesn't parse correctly and must + // be written as `IF(a > b, 10, null)`. Remove the redundant else-clause. + if (name == "if") { + if (params.back()->is(core::IExpr::Kind::kConstant) && + params.back()->as()->type()->isUnknown()) { + params.pop_back(); + } + } + return std::make_shared( toFullFunctionName(name, options.functionPrefix), std::move(params), @@ -259,6 +269,55 @@ std::shared_ptr tryParseInterval( INTERVAL_DAY_TIME(), Variant(value.value() * multiplier), alias); } +// DuckDB parses struct literals {'x': 1, 'y': 2} as struct_pack(1 AS x, 2 AS +// y) and ROW(1, 2) as row(1, 2). Folds into a ROW constant when all arguments +// are constants. Returns nullptr otherwise. +core::ExprPtr tryFoldRowConstant( + const std::vector& inputs, + const std::optional& alias) { + std::vector names; + std::vector types; + std::vector values; + names.reserve(inputs.size()); + types.reserve(inputs.size()); + values.reserve(inputs.size()); + for (const auto& input : inputs) { + auto* constant = input->as(); + if (!constant) { + return nullptr; + } + names.push_back(constant->alias().value_or("")); + types.push_back(constant->type()); + values.push_back(constant->value()); + } + return std::make_shared( + ROW(std::move(names), std::move(types)), + Variant::row(std::move(values)), + alias); +} + +// DuckDB parses [1, 2, 3] as list_value(1, 2, 3). Folds into an ARRAY constant +// when all arguments are constants. Returns nullptr otherwise. +core::ExprPtr tryFoldArrayConstant( + const std::vector& inputs, + const std::optional& alias) { + std::vector elements; + elements.reserve(inputs.size()); + TypePtr elementType = UNKNOWN(); + for (const auto& input : inputs) { + auto* constant = input->as(); + if (!constant) { + return nullptr; + } + elements.push_back(constant->value()); + if (!constant->value().isNull()) { + elementType = constant->type(); + } + } + return std::make_shared( + ARRAY(elementType), Variant::array(std::move(elements)), alias); +} + // Parse a function call (avg(a), func(1, b), etc). // Arithmetic operators also follow this path (a + b, a * b, etc). core::ExprPtr parseFunctionExpr( @@ -279,6 +338,20 @@ core::ExprPtr parseFunctionExpr( } } + if (func == "struct_pack" || func == "row") { + if (auto rowConstant = tryFoldRowConstant(params, getAlias(expr))) { + return rowConstant; + } + } + + // DuckDB parses [1, 2, 3] as list_value(1, 2, 3). Fold into an ARRAY + // constant when all arguments are constants. + if (func == "list_value") { + if (auto arrayConstant = tryFoldArrayConstant(params, getAlias(expr))) { + return arrayConstant; + } + } + // NOT LIKE function needs special handling as it maps to two functions // "not" and "like". if (func == "notlike") { @@ -327,10 +400,11 @@ core::ExprPtr parseConjunctionExpr( StringUtil::Lower(ExpressionTypeToOperator(expr.GetExpressionType())); if (conjExpr.children.size() < 2) { - throw std::invalid_argument(folly::sformat( - "Malformed conjunction expression " - "(expected at least 2 input columns, got {}).", - conjExpr.children.size())); + throw std::invalid_argument( + folly::sformat( + "Malformed conjunction expression " + "(expected at least 2 input columns, got {}).", + conjExpr.children.size())); } // DuckDB's parser returns conjunction involving multiple input in a flat @@ -380,6 +454,10 @@ core::ExprPtr parseOperatorExpr( if (auto constantExpr = dynamic_cast(child.get())) { auto& value = constantExpr->value; + if (value.type().id() == LogicalTypeId::INTEGER && + options.parseIntegerAsBigint) { + value = Value::BIGINT(value.GetValue()); + } if (options.parseDecimalAsDouble && value.type().id() == duckdb::LogicalTypeId::DECIMAL) { value = Value::DOUBLE(value.GetValue()); @@ -463,8 +541,9 @@ core::ExprPtr parseOperatorExpr( } if (options.parseInListAsArray) { - params.emplace_back(std::make_shared( - ARRAY(valueType), Variant::array(values), std::nullopt)); + params.emplace_back( + std::make_shared( + ARRAY(valueType), Variant::array(values), std::nullopt)); } auto inExpr = callExpr("in", std::move(params), getAlias(expr), options); // Translate COMPARE_NOT_IN into NOT(IN()). @@ -595,6 +674,23 @@ core::ExprPtr parseCastExpr( getAlias(expr)); } } + + // DuckDB parses DATE '...' and '...'::date as cast(varchar as DATE). + // Fold into a DATE constant. + if (targetType->isDate() && constant->type()->isVarchar()) { + const auto& value = constant->value().value(); + return std::make_shared( + DATE(), + Variant::create(DATE()->toDays(value)), + getAlias(expr)); + } + + // ROW(1, 2)::struct(x bigint, y bigint) — re-type the ROW constant with + // the target type (which carries field names). Child types must match. + if (targetType->isRow() && targetType->equivalent(*constant->type())) { + return std::make_shared( + targetType, constant->value(), getAlias(expr)); + } } const bool isTryCast = castExpr.try_cast; @@ -693,7 +789,8 @@ std::unique_ptr<::duckdb::ParsedExpression> parseSingleExpression( auto parsed = parseExpression(exprString); VELOX_CHECK_EQ( 1, parsed.size(), "Expected exactly one expression: {}.", exprString); - return std::move(parsed.front()); + auto result = std::move(parsed.front()); + return result; } } // namespace @@ -753,7 +850,7 @@ bool isNullsFirst( } } // namespace -OrderByClause parseOrderByExpr(const std::string& exprString) { +parse::OrderByClause parseOrderByExpr(const std::string& exprString) { ParserOptions options; ParseOptions parseOptions; options.preserve_identifier_case = false; @@ -775,49 +872,54 @@ OrderByClause parseOrderByExpr(const std::string& exprString) { .nullsFirst = nullsFirst}; } -AggregateExpr parseAggregateExpr( +core::AggregateCallExprPtr parseAggregateExpr( const std::string& exprString, const ParseOptions& options) { auto parsedExpr = parseSingleExpression(exprString); auto& functionExpr = dynamic_cast(*parsedExpr); - AggregateExpr aggregateExpr; - aggregateExpr.expr = parseExpr(*parsedExpr, options); - aggregateExpr.distinct = functionExpr.distinct; + auto callExpr = parseExpr(*parsedExpr, options); + std::vector orderBy; if (functionExpr.order_bys) { for (const auto& orderByNode : functionExpr.order_bys->orders) { - const bool ascending = isAscending(orderByNode.type, exprString); - const bool nullsFirst = isNullsFirst(orderByNode.null_order, exprString); - aggregateExpr.orderBy.emplace_back(OrderByClause{ - parseExpr(*orderByNode.expression, options), ascending, nullsFirst}); + orderBy.push_back( + {parseExpr(*orderByNode.expression, options), + isAscending(orderByNode.type, exprString), + isNullsFirst(orderByNode.null_order, exprString)}); } } + core::ExprPtr filter; if (functionExpr.filter) { - aggregateExpr.maskExpr = parseExpr(*functionExpr.filter, options); + filter = parseExpr(*functionExpr.filter, options); } - return aggregateExpr; + auto* call = callExpr->as(); + return std::make_shared( + call->name(), + call->inputs(), + functionExpr.distinct, + std::move(filter), + std::move(orderBy), + callExpr->alias()); } namespace { + +using WindowType = core::WindowCallExpr::WindowType; +using BoundType = core::WindowCallExpr::BoundType; + WindowType parseWindowType(const WindowExpression& expr) { - auto windowType = [&](const WindowBoundary& boundary) -> WindowType { - if (boundary == WindowBoundary::CURRENT_ROW_ROWS || + auto isRows = [](const WindowBoundary& boundary) { + return boundary == WindowBoundary::CURRENT_ROW_ROWS || boundary == WindowBoundary::EXPR_FOLLOWING_ROWS || - boundary == WindowBoundary::EXPR_PRECEDING_ROWS) { - return WindowType::kRows; - } - return WindowType::kRange; + boundary == WindowBoundary::EXPR_PRECEDING_ROWS; }; - auto startType = windowType(expr.start); - if (startType == WindowType::kRows) { - return startType; - } - return windowType(expr.end); + return (isRows(expr.start) || isRows(expr.end)) ? WindowType::kRows + : WindowType::kRange; } BoundType parseBoundType(WindowBoundary boundary) { @@ -841,29 +943,23 @@ BoundType parseBoundType(WindowBoundary boundary) { VELOX_UNREACHABLE(); } -} // namespace - -IExprWindowFunction parseWindowExpr( +core::WindowCallExprPtr buildWindowCallExpr( + ParsedExpression& parsedExpr, const std::string& windowString, const ParseOptions& options) { - auto parsedExpr = parseSingleExpression(windowString); - VELOX_CHECK( - parsedExpr->IsWindow(), - "Invalid window function expression: {}", - windowString); + auto& windowExpr = dynamic_cast(parsedExpr); - IExprWindowFunction windowIExpr; - auto& windowExpr = dynamic_cast(*parsedExpr); - for (int i = 0; i < windowExpr.partitions.size(); i++) { - windowIExpr.partitionBy.push_back( - parseExpr(*(windowExpr.partitions[i].get()), options)); + std::vector partitionKeys; + for (const auto& partition : windowExpr.partitions) { + partitionKeys.push_back(parseExpr(*partition, options)); } + std::vector orderByKeys; for (const auto& orderByNode : windowExpr.orders) { - const bool ascending = isAscending(orderByNode.type, windowString); - const bool nullsFirst = isNullsFirst(orderByNode.null_order, windowString); - windowIExpr.orderBy.emplace_back(OrderByClause{ - parseExpr(*orderByNode.expression, options), ascending, nullsFirst}); + orderByKeys.push_back( + {parseExpr(*orderByNode.expression, options), + isAscending(orderByNode.type, windowString), + isNullsFirst(orderByNode.null_order, windowString)}); } std::vector params; @@ -880,32 +976,58 @@ IExprWindowFunction parseWindowExpr( params.emplace_back(parseExpr(*windowExpr.default_expr, options)); } - auto func = normalizeFuncName(windowExpr.function_name); - windowIExpr.functionCall = - callExpr(func, std::move(params), getAlias(windowExpr), options); - - windowIExpr.ignoreNulls = windowExpr.ignore_nulls; - - windowIExpr.frame.type = parseWindowType(windowExpr); - windowIExpr.frame.startType = parseBoundType(windowExpr.start); + core::ExprPtr startValue; if (windowExpr.start_expr) { - windowIExpr.frame.startValue = - parseExpr(*windowExpr.start_expr.get(), options); + startValue = parseExpr(*windowExpr.start_expr, options); } - - windowIExpr.frame.endType = parseBoundType(windowExpr.end); + core::ExprPtr endValue; if (windowExpr.end_expr) { - windowIExpr.frame.endValue = parseExpr(*windowExpr.end_expr.get(), options); + endValue = parseExpr(*windowExpr.end_expr, options); } - return windowIExpr; + + auto endType = parseBoundType(windowExpr.end); + if (options.correctWindowFrameDefault && orderByKeys.empty() && + endType == core::WindowCallExpr::BoundType::kCurrentRow) { + endType = core::WindowCallExpr::BoundType::kUnboundedFollowing; + } + + return std::make_shared( + normalizeFuncName(windowExpr.function_name), + std::move(params), + std::move(partitionKeys), + std::move(orderByKeys), + core::WindowCallExpr::Frame{ + parseWindowType(windowExpr), + parseBoundType(windowExpr.start), + std::move(startValue), + endType, + std::move(endValue)}, + windowExpr.ignore_nulls, + getAlias(windowExpr)); } -std::string OrderByClause::toString() const { - return fmt::format( - "{} {} NULLS {}", - expr->toString(), - (ascending ? "ASC" : "DESC"), - (nullsFirst ? "FIRST" : "LAST")); +} // namespace + +core::WindowCallExprPtr parseWindowExpr( + const std::string& windowString, + const ParseOptions& options) { + auto parsedExpr = parseSingleExpression(windowString); + VELOX_CHECK( + parsedExpr->IsWindow(), + "Invalid window function expression: {}", + windowString); + + return buildWindowCallExpr(*parsedExpr, windowString, options); +} + +core::ExprPtr parseScalarOrWindowExpr( + const std::string& exprString, + const ParseOptions& options) { + auto parsedExpr = parseSingleExpression(exprString); + if (parsedExpr->IsWindow()) { + return buildWindowCallExpr(*parsedExpr, exprString, options); + } + return parseExpr(*parsedExpr, options); } } // namespace facebook::velox::duckdb diff --git a/velox/duckdb/conversion/DuckParser.h b/velox/duckdb/conversion/DuckParser.h index 413e9280561..f8e094109d0 100644 --- a/velox/duckdb/conversion/DuckParser.h +++ b/velox/duckdb/conversion/DuckParser.h @@ -16,18 +16,10 @@ #pragma once #include -#include "velox/parse/IExpr.h" +#include "velox/parse/SqlExpressionsParser.h" namespace facebook::velox::duckdb { -struct OrderByClause { - core::ExprPtr expr; - bool ascending; - bool nullsFirst; - - std::string toString() const; -}; - /// Hold parsing options. struct ParseOptions { // Retain legacy behavior by default. @@ -37,6 +29,12 @@ struct ParseOptions { // single array argument. bool parseInListAsArray = true; + // DuckDB defaults the window frame end bound to CURRENT ROW even when ORDER + // BY is absent. The SQL standard requires UNBOUNDED FOLLOWING in that case. + // When true, corrects this default. Cannot distinguish defaulted from + // explicit frames, so an explicit CURRENT ROW may be incorrectly overridden. + bool correctWindowFrameDefault = false; + /// SQL functions could be registered with different prefixes by the user. /// This parameter is the registered prefix of presto or spark functions, /// which helps generate the correct Velox expression. @@ -58,60 +56,30 @@ std::vector parseMultipleExpressions( const std::string& exprString, const ParseOptions& options); -struct AggregateExpr { - core::ExprPtr expr; - std::vector orderBy; - bool distinct{false}; - core::ExprPtr maskExpr{nullptr}; -}; - /// Parses aggregate function call expression with optional ORDER by clause. +/// Always returns an AggregateCallExpr. /// Examples: /// sum(a) /// sum(a) as s /// array_agg(x ORDER BY y DESC) -AggregateExpr parseAggregateExpr( +core::AggregateCallExprPtr parseAggregateExpr( const std::string& exprString, const ParseOptions& options); // Parses an ORDER BY clause using DuckDB's internal postgresql-based parser. // Uses ASC NULLS LAST as the default sort order. -OrderByClause parseOrderByExpr(const std::string& exprString); - -// Parses a WINDOW function SQL string using DuckDB's internal postgresql-based -// parser. Window Functions are executed by Velox Window PlanNodes and not the -// expression evaluation. So we cannot use an IExpr based API. The structures -// below capture all the metadata needed from the window function SQL string -// for usage in the WindowNode plan node. -enum class WindowType { kRows, kRange }; - -enum class BoundType { - kCurrentRow, - kUnboundedPreceding, - kUnboundedFollowing, - kPreceding, - kFollowing -}; +parse::OrderByClause parseOrderByExpr(const std::string& exprString); -struct IExprWindowFrame { - WindowType type; - BoundType startType; - core::ExprPtr startValue; - BoundType endType; - core::ExprPtr endValue; -}; - -struct IExprWindowFunction { - core::ExprPtr functionCall; - IExprWindowFrame frame; - bool ignoreNulls; - - std::vector partitionBy; - std::vector orderBy; -}; - -IExprWindowFunction parseWindowExpr( +/// Parses a WINDOW function SQL string. Returns a WindowCallExpr. +core::WindowCallExprPtr parseWindowExpr( const std::string& windowString, const ParseOptions& options); +/// Parses a SQL expression that can be either a scalar expression or a window +/// function. Returns a WindowCallExpr (kWindow kind) for window functions, +/// or a regular ExprPtr for scalar expressions. +core::ExprPtr parseScalarOrWindowExpr( + const std::string& exprString, + const ParseOptions& options); + } // namespace facebook::velox::duckdb diff --git a/velox/duckdb/conversion/tests/DuckConversionTest.cpp b/velox/duckdb/conversion/tests/DuckConversionTest.cpp index d73f222b818..c5d9e09db15 100644 --- a/velox/duckdb/conversion/tests/DuckConversionTest.cpp +++ b/velox/duckdb/conversion/tests/DuckConversionTest.cpp @@ -91,7 +91,6 @@ TEST(DuckConversionTest, duckValueToVariantUnsupported) { /// defined as static constexpr const causing a double definition only in the /// debug build. std::vector unsupported = { - ::duckdb::TransformStringToLogicalType("time"), ::duckdb::TransformStringToLogicalType("interval"), LogicalType::LIST({::duckdb::TransformStringToLogicalType("integer")}), LogicalType::STRUCT( @@ -123,6 +122,7 @@ TEST(DuckConversionTest, types) { testRoundTrip(TIMESTAMP()); testRoundTrip(DATE()); + testRoundTrip(TIME()); testRoundTrip(INTERVAL_DAY_TIME()); testRoundTrip(DECIMAL(22, 5)); @@ -157,7 +157,8 @@ TEST(DuckConversionTest, createTable) { DOUBLE()})); testCreateTable( - ROW({"a", "b", "c"}, {TIMESTAMP(), DATE(), INTERVAL_DAY_TIME()})); + ROW({"a", "b", "c", "d"}, + {TIMESTAMP(), DATE(), INTERVAL_DAY_TIME(), TIME()})); testCreateTable(ROW({"a", "b"}, {DECIMAL(7, 5), DECIMAL(30, 10)})); diff --git a/velox/duckdb/conversion/tests/DuckParserTest.cpp b/velox/duckdb/conversion/tests/DuckParserTest.cpp index 21168c4248d..dd73fe6c141 100644 --- a/velox/duckdb/conversion/tests/DuckParserTest.cpp +++ b/velox/duckdb/conversion/tests/DuckParserTest.cpp @@ -22,6 +22,7 @@ using namespace facebook::velox; using namespace facebook::velox::duckdb; +using namespace facebook::velox::parse; namespace { std::shared_ptr parseExpr(const std::string& exprString) { @@ -99,40 +100,9 @@ TEST(DuckParserTest, functions) { } namespace { -std::string toString(const std::vector& orderBy) { - std::stringstream out; - if (!orderBy.empty()) { - out << "ORDER BY "; - for (auto i = 0; i < orderBy.size(); ++i) { - if (i > 0) { - out << ", "; - } - out << orderBy[i].toString(); - } - } - - return out.str(); -} - std::string parseAgg(const std::string& expression) { ParseOptions options; - auto aggregateExpr = parseAggregateExpr(expression, options); - std::stringstream out; - out << aggregateExpr.expr->toString(); - - if (aggregateExpr.distinct) { - out << " DISTINCT"; - } - - if (!aggregateExpr.orderBy.empty()) { - out << " " << toString(aggregateExpr.orderBy); - } - - if (aggregateExpr.maskExpr != nullptr) { - out << " FILTER " << aggregateExpr.maskExpr->toString(); - } - - return out.str(); + return parseAggregateExpr(expression, options)->toString(); } } // namespace @@ -148,20 +118,20 @@ TEST(DuckParserTest, aggregates) { "array_agg(\"x\") ORDER BY \"y\" ASC NULLS FIRST", parseAgg("array_agg(x ORDER BY y NULLS FIRST)")); EXPECT_EQ( - "array_agg(\"x\") ORDER BY \"y\" ASC NULLS LAST, \"z\" ASC NULLS LAST", + "array_agg(\"x\") ORDER BY \"y\" ASC NULLS LAST,\"z\" ASC NULLS LAST", parseAgg("array_agg(x ORDER BY y, z)")); } TEST(DuckParserTest, aggregatesWithMasks) { EXPECT_EQ( - "array_agg(\"x\") FILTER \"m\"", + "array_agg(\"x\") FILTER(WHERE \"m\")", parseAgg("array_agg(x) filter (where m)")); } TEST(DuckParserTest, distinctAggregates) { - EXPECT_EQ("count(\"x\") DISTINCT", parseAgg("count(distinct x)")); - EXPECT_EQ("count(\"x\",\"y\") DISTINCT", parseAgg("count(distinct x, y)")); - EXPECT_EQ("sum(\"x\") DISTINCT", parseAgg("sum(distinct x)")); + EXPECT_EQ("count(DISTINCT \"x\")", parseAgg("count(distinct x)")); + EXPECT_EQ("count(DISTINCT \"x\",\"y\")", parseAgg("count(distinct x, y)")); + EXPECT_EQ("sum(DISTINCT \"x\")", parseAgg("sum(distinct x)")); } TEST(DuckParserTest, subscript) { @@ -414,8 +384,9 @@ TEST(DuckParserTest, cast) { "cast(\"str_col\" as INTERVAL DAY TO SECOND)", parseExpr("cast(str_col as interval day to second)")->toString()); - // Unsupported casts for now. - EXPECT_THROW(parseExpr("cast('2020-01-01' as TIME)"), std::runtime_error); + EXPECT_EQ( + "cast(2020-01-01 as TIME)", + parseExpr("cast('2020-01-01' as TIME)")->toString()); // Complex types. EXPECT_EQ( @@ -461,6 +432,8 @@ TEST(DuckParserTest, ifCase) { "if(\"a\",plus(\"b\",\"c\"),g(\"d\"))", parseExpr("if(a, b + c, g(d))")->toString()); + EXPECT_EQ("if(gt(\"a\",0),10)", parseExpr("if(a > 0, 10, null)")->toString()); + // CASE statements. EXPECT_EQ( "if(1,null,0)", @@ -569,92 +542,36 @@ TEST(DuckParserTest, orderBy) { } namespace { -const std::string windowTypeString(WindowType w) { - switch (w) { - case WindowType::kRange: - return "RANGE"; - case WindowType::kRows: - return "ROWS"; - } - VELOX_UNREACHABLE(); -} - -const std::string boundTypeString(BoundType b) { - switch (b) { - case BoundType::kUnboundedPreceding: - return "UNBOUNDED PRECEDING"; - case BoundType::kUnboundedFollowing: - return "UNBOUNDED FOLLOWING"; - case BoundType::kPreceding: - return "PRECEDING"; - case BoundType::kFollowing: - return "FOLLOWING"; - case BoundType::kCurrentRow: - return "CURRENT ROW"; - } - VELOX_UNREACHABLE(); -} const std::string parseWindow(const std::string& expr) { ParseOptions options; - auto windowExpr = parseWindowExpr(expr, options); - std::string concatPartitions; - int i = 0; - for (const auto& partition : windowExpr.partitionBy) { - concatPartitions += partition->toString(); - if (i > 0) { - concatPartitions += " , "; - } - i++; - } - auto partitionString = windowExpr.partitionBy.empty() - ? "" - : fmt::format("PARTITION BY {}", concatPartitions); - - auto orderByString = toString(windowExpr.orderBy); - - auto frameString = fmt::format( - "{} BETWEEN {}{} AND{} {}", - windowTypeString(windowExpr.frame.type), - (windowExpr.frame.startValue - ? windowExpr.frame.startValue->toString() + " " - : ""), - boundTypeString(windowExpr.frame.startType), - (windowExpr.frame.endValue ? " " + windowExpr.frame.endValue->toString() - : ""), - boundTypeString(windowExpr.frame.endType)); - - return fmt::format( - "{} OVER ({} {} {})", - windowExpr.functionCall->toString(), - partitionString, - orderByString, - frameString); + return parseWindowExpr(expr, options)->toString(); } } // namespace TEST(DuckParserTest, window) { EXPECT_EQ( - "row_number() AS c OVER (PARTITION BY \"a\" ORDER BY \"b\" ASC NULLS LAST" - " RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)", + "row_number() OVER (PARTITION BY \"a\" ORDER BY \"b\" ASC NULLS LAST " + "RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS c", parseWindow("row_number() over (partition by a order by b) as c")); EXPECT_EQ( - "row_number() AS a OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)", + "row_number() OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a", parseWindow("row_number() over () as a")); EXPECT_EQ( - "row_number() AS a OVER ( ORDER BY \"b\" ASC NULLS LAST " - "RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)", + "row_number() OVER (ORDER BY \"b\" ASC NULLS LAST " + "RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a", parseWindow("row_number() over (order by b) as a")); EXPECT_EQ( - "row_number() OVER (PARTITION BY \"a\" ROWS BETWEEN " - "UNBOUNDED PRECEDING AND CURRENT ROW)", + "row_number() OVER (PARTITION BY \"a\" " + "ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)", parseWindow( "row_number() over (partition by a rows between unbounded preceding and current row)")); EXPECT_EQ( "row_number() OVER (PARTITION BY \"a\" ORDER BY \"b\" ASC NULLS LAST " "ROWS BETWEEN plus(\"a\",10) PRECEDING AND 10 FOLLOWING)", - parseWindow("row_number() over (partition by a order by b " - "rows between a + 10 preceding and 10 following)")); + parseWindow( + "row_number() over (partition by a order by b " + "rows between a + 10 preceding and 10 following)")); EXPECT_EQ( "row_number() OVER (PARTITION BY \"a\" ORDER BY \"b\" DESC NULLS FIRST " "ROWS BETWEEN plus(\"a\",10) PRECEDING AND 10 FOLLOWING)", @@ -663,17 +580,17 @@ TEST(DuckParserTest, window) { "rows between a + 10 preceding and 10 following)")); EXPECT_EQ( - "lead(\"x\",\"y\",\"z\") OVER ( " + "lead(\"x\",\"y\",\"z\") OVER (" "RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)", parseWindow("lead(x, y, z) over ()")); EXPECT_EQ( - "lag(\"x\",3,\"z\") OVER ( " + "lag(\"x\",3,\"z\") OVER (" "RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)", parseWindow("lag(x, 3, z) over ()")); EXPECT_EQ( - "nth_value(\"x\",3) OVER ( " + "nth_value(\"x\",3) OVER (" "RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)", parseWindow("nth_value(x, 3) over ()")); } @@ -682,17 +599,33 @@ TEST(DuckParserTest, windowWithIntegerConstant) { ParseOptions options; options.parseIntegerAsBigint = false; auto windowExpr = parseWindowExpr("nth_value(x, 3) over ()", options); - auto func = - std::dynamic_pointer_cast(windowExpr.functionCall); - ASSERT_TRUE(func != nullptr) - << windowExpr.functionCall->toString() << " is not a call expr"; - EXPECT_EQ(func->inputs().size(), 2); - auto param = func->inputs()[1]; + ASSERT_TRUE(windowExpr->is(core::IExpr::Kind::kWindow)); + auto* windowCall = windowExpr->as(); + EXPECT_EQ(windowCall->inputs().size(), 2); + auto param = windowCall->inputs()[1]; auto constant = std::dynamic_pointer_cast(param); ASSERT_TRUE(constant != nullptr) << param->toString() << " is not a constant"; EXPECT_EQ(*constant->type(), *INTEGER()); } +TEST(DuckParserTest, parseScalarOrWindowExpr) { + ParseOptions options; + + // Scalar expression returns a plain ExprPtr. + auto scalar = parseScalarOrWindowExpr("a + b", options); + ASSERT_FALSE(scalar->is(core::IExpr::Kind::kWindow)); + EXPECT_EQ(scalar->toString(), "plus(\"a\",\"b\")"); + + // Window expression returns a WindowCallExpr. + auto window = + parseScalarOrWindowExpr("row_number() over (order by a)", options); + ASSERT_TRUE(window->is(core::IExpr::Kind::kWindow)); + auto* windowCall = window->as(); + EXPECT_EQ(windowCall->name(), "row_number"); + EXPECT_EQ(windowCall->orderByKeys().size(), 1); + EXPECT_TRUE(windowCall->orderByKeys()[0].ascending); +} + TEST(DuckParserTest, invalidExpression) { VELOX_ASSERT_THROW( parseExpr("func(a b)"), @@ -831,3 +764,128 @@ TEST(DuckParserTest, lambda) { parseExpr("filter(a, if (b > 0, x -> (x = 10), x -> (x = 20)))") ->toString()); } + +TEST(DuckParserTest, arrayLiteral) { + { + auto expr = parseExpr("ARRAY[1, 2, 3]"); + EXPECT_TRUE(expr->is(core::IExpr::Kind::kConstant)); + VELOX_EXPECT_EQ_TYPES( + expr->as()->type(), ARRAY(BIGINT())); + EXPECT_EQ("{1, 2, 3}", expr->toString()); + } + + { + auto expr = parseExpr("[1, 2, 3]"); + EXPECT_TRUE(expr->is(core::IExpr::Kind::kConstant)); + VELOX_EXPECT_EQ_TYPES( + expr->as()->type(), ARRAY(BIGINT())); + EXPECT_EQ("{1, 2, 3}", expr->toString()); + } + + // Array with null elements. + { + auto expr = parseExpr("[1, null, 3]"); + EXPECT_TRUE(expr->is(core::IExpr::Kind::kConstant)); + VELOX_EXPECT_EQ_TYPES( + expr->as()->type(), ARRAY(BIGINT())); + EXPECT_EQ("{1, null, 3}", expr->toString()); + } + + // All-null array. + { + auto expr = parseExpr("[null, null]"); + EXPECT_TRUE(expr->is(core::IExpr::Kind::kConstant)); + VELOX_EXPECT_EQ_TYPES( + expr->as()->type(), ARRAY(UNKNOWN())); + EXPECT_EQ("{null, null}", expr->toString()); + } + + // Empty array. + { + auto expr = parseExpr("[]"); + EXPECT_TRUE(expr->is(core::IExpr::Kind::kConstant)); + VELOX_EXPECT_EQ_TYPES( + expr->as()->type(), ARRAY(UNKNOWN())); + EXPECT_EQ("", expr->toString()); + } + + // Nested array. + { + auto expr = parseExpr("[[1, 2], [3, 4]]"); + EXPECT_TRUE(expr->is(core::IExpr::Kind::kConstant)); + VELOX_EXPECT_EQ_TYPES( + expr->as()->type(), ARRAY(ARRAY(BIGINT()))); + EXPECT_EQ("{{1, 2}, {3, 4}}", expr->toString()); + } + + // Non-constant argument stays a function call. + { + auto expr = parseExpr("[a, 1, 2]"); + EXPECT_TRUE(expr->is(core::IExpr::Kind::kCall)); + EXPECT_EQ("list_value(\"a\",1,2)", expr->toString()); + } +} + +TEST(DuckParserTest, structLiteral) { + // {'x': 1, 'y': 2} becomes a ROW constant with named fields. + { + auto expr = parseExpr("{'x': 1, 'y': 2}"); + EXPECT_TRUE(expr->is(core::IExpr::Kind::kConstant)); + VELOX_EXPECT_EQ_TYPES( + expr->as()->type(), + ROW({"x", "y"}, {BIGINT(), BIGINT()})); + EXPECT_EQ("{1, 2}", expr->toString()); + } + + // ROW(1, 2) becomes a ROW constant with unnamed fields. + { + auto expr = parseExpr("ROW(1, 2)"); + EXPECT_TRUE(expr->is(core::IExpr::Kind::kConstant)); + VELOX_EXPECT_EQ_TYPES( + expr->as()->type(), ROW({"", ""}, BIGINT())); + EXPECT_EQ("{1, 2}", expr->toString()); + } + + // ROW(a, 2) with a non-constant argument stays a function call. + { + auto expr = parseExpr("ROW(a, 2)"); + EXPECT_TRUE(expr->is(core::IExpr::Kind::kCall)); + EXPECT_EQ("row(\"a\",2)", expr->toString()); + } + + // ROW(1, 2)::struct(x bigint, y bigint) becomes a named ROW constant. + { + auto expr = parseExpr("ROW(1, 2)::struct(x bigint, y bigint)"); + EXPECT_TRUE(expr->is(core::IExpr::Kind::kConstant)); + VELOX_EXPECT_EQ_TYPES( + expr->as()->type(), ROW({"x", "y"}, BIGINT())); + EXPECT_EQ("{1, 2}", expr->toString()); + } + + // ROW(1, 2)::struct(x varchar, y bigint) with mismatched child types stays a + // cast. + { + auto expr = parseExpr("ROW(1, 2)::struct(x varchar, y bigint)"); + EXPECT_TRUE(expr->is(core::IExpr::Kind::kCast)); + EXPECT_EQ("cast({1, 2} as ROW)", expr->toString()); + } +} + +TEST(DuckParserTest, dateLiteral) { + for (const auto& sql : {"'1994-01-01'::date", "DATE '1994-01-01'"}) { + SCOPED_TRACE(sql); + auto expr = parseExpr(sql); + EXPECT_TRUE(expr->is(core::IExpr::Kind::kConstant)); + + auto* constant = expr->as(); + EXPECT_EQ(*constant->type(), *DATE()); + EXPECT_EQ(constant->value().value(), DATE()->toDays("1994-01-01")); + + EXPECT_EQ("1994-01-01", expr->toString()); + } + + // Invalid date string. + VELOX_ASSERT_THROW( + parseExpr("DATE 'not-a-date'"), + "Unable to parse date value: \"not-a-date\""); +} diff --git a/velox/dwio/catalog/fbhive/CMakeLists.txt b/velox/dwio/catalog/fbhive/CMakeLists.txt index 17a778f41c3..305cf427c3d 100644 --- a/velox/dwio/catalog/fbhive/CMakeLists.txt +++ b/velox/dwio/catalog/fbhive/CMakeLists.txt @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -velox_add_library(velox_dwio_catalog_fbhive FileUtils.cpp) +velox_add_library(velox_dwio_catalog_fbhive FileUtils.cpp HEADERS FileUtils.h) velox_link_libraries(velox_dwio_catalog_fbhive velox_exception fmt::fmt Folly::folly) if(${VELOX_BUILD_TESTING}) diff --git a/velox/dwio/catalog/fbhive/FileUtils.cpp b/velox/dwio/catalog/fbhive/FileUtils.cpp index 36d1b29d98e..621ee20bbbf 100644 --- a/velox/dwio/catalog/fbhive/FileUtils.cpp +++ b/velox/dwio/catalog/fbhive/FileUtils.cpp @@ -112,9 +112,10 @@ std::vector> extractPartitionKeyValues( std::vector tokens; folly::split('=', partitionPart, tokens); if (tokens.size() == 2) { - parsedParts.emplace_back(std::make_pair( - FileUtils::unescapePathName(tokens[0]), - FileUtils::unescapePathName(tokens[1]))); + parsedParts.emplace_back( + std::make_pair( + FileUtils::unescapePathName(tokens[0]), + FileUtils::unescapePathName(tokens[1]))); } }); } @@ -157,46 +158,29 @@ std::string FileUtils::unescapePathName(const std::string& data) { std::string FileUtils::makePartName( const std::vector>& entries, - bool partitionPathAsLowerCase) { - size_t size = 0; - size_t escapeCount = 0; - std::for_each(entries.begin(), entries.end(), [&](auto& pair) { - auto keySize = pair.first.size(); - VELOX_CHECK_GT(keySize, 0); - size += keySize; - escapeCount += countEscape(pair.first); - - auto valSize = pair.second.size(); - if (valSize == 0) { - size += kDefaultPartitionValue.size(); - } else { - size += valSize; - escapeCount += countEscape(pair.second); + bool partitionPathAsLowerCase, + bool useDefaultPartitionValue, + const EncodeFunction& encodeFunc) { + VELOX_CHECK(!entries.empty()); + std::ostringstream out; + + for (const auto& [key, value] : entries) { + VELOX_CHECK(!key.empty()); + if (out.tellp() > 0) { + out << '/'; } - }); - std::string ret; - ret.reserve(size + escapeCount * HEX_WIDTH + entries.size() - 1); - - std::for_each(entries.begin(), entries.end(), [&](auto& pair) { - if (ret.size() > 0) { - ret += "/"; - } - if (partitionPathAsLowerCase) { - ret += escapePathName(toLower(pair.first)); - } else { - ret += escapePathName(pair.first); - } + std::string keyToEncode = partitionPathAsLowerCase ? toLower(key) : key; + out << encodeFunc(keyToEncode) << '='; - ret += "="; - if (pair.second.size() == 0) { - ret += kDefaultPartitionValue; + if (value.empty() && useDefaultPartitionValue) { + out << kDefaultPartitionValue; } else { - ret += escapePathName(pair.second); + out << encodeFunc(value); } - }); + } - return ret; + return out.str(); } std::vector> FileUtils::parsePartKeyValues( diff --git a/velox/dwio/catalog/fbhive/FileUtils.h b/velox/dwio/catalog/fbhive/FileUtils.h index a8ca8bf07ef..519c274fc6f 100644 --- a/velox/dwio/catalog/fbhive/FileUtils.h +++ b/velox/dwio/catalog/fbhive/FileUtils.h @@ -29,6 +29,10 @@ namespace fbhive { class FileUtils { public: + /// Function type for encoding partition key/value strings. + /// Takes a string to encode and returns the encoded string. + using EncodeFunction = std::function; + /// Converts the path name to be hive metastore compliant, will do /// url-encoding when needed. static std::string escapePathName(const std::string& data); @@ -39,9 +43,19 @@ class FileUtils { /// Creates the partition directory path from the list of partition key/value /// pairs, will do url-encoding when needed. + /// @param entries Vector of (key, value) pairs for partition columns. Cannot + /// be empty. + /// @param partitionPathAsLowerCase Whether to convert keys to lowercase + /// @param useDefaultPartitionValue If true, empty values are replaced with + /// kDefaultPartitionValue. If false, empty values are encoded as-is. + /// Defaults to true for Hive compatibility. + /// @param encodeFunc Function to use for encoding keys and values. + /// Defaults to escapePathName. static std::string makePartName( const std::vector>& entries, - bool partitionPathAsLowerCase); + bool partitionPathAsLowerCase, + bool useDefaultPartitionValue = true, + const EncodeFunction& encodeFunc = escapePathName); /// Converts the hive-metastore-compliant path name back to the corresponding /// partition key/value pairs. diff --git a/velox/dwio/catalog/fbhive/test/FileUtilsTests.cpp b/velox/dwio/catalog/fbhive/test/FileUtilsTests.cpp index 042c5ba9308..579c7533ecd 100644 --- a/velox/dwio/catalog/fbhive/test/FileUtilsTests.cpp +++ b/velox/dwio/catalog/fbhive/test/FileUtilsTests.cpp @@ -19,10 +19,12 @@ #include "velox/common/base/Exceptions.h" #include "velox/dwio/catalog/fbhive/FileUtils.h" +namespace facebook::velox::dwio::catalog::fbhive { +namespace { + using namespace ::testing; -using namespace facebook::velox::dwio::catalog::fbhive; -TEST(FileUtilsTests, MakePartName) { +TEST(FileUtilsTests, makePartName) { std::vector> pairs{ {"ds", "2016-01-01"}, {"FOO", ""}, {"a\nb:c", "a#b=c"}}; ASSERT_EQ( @@ -31,9 +33,22 @@ TEST(FileUtilsTests, MakePartName) { ASSERT_EQ( FileUtils::makePartName(pairs, false), "ds=2016-01-01/FOO=__HIVE_DEFAULT_PARTITION__/a%0Ab%3Ac=a%23b%3Dc"); + ASSERT_THROW(FileUtils::makePartName({}, false), VeloxException); +} + +TEST(FileUtilsTests, makePartNameWithoutDefaultPartitionValue) { + std::vector> pairs{ + {"ds", "2016-01-01"}, {"FOO", ""}, {"a\nb:c", "a#b=c"}}; + // Test with useDefaultPartitionValue = false. + ASSERT_EQ( + FileUtils::makePartName(pairs, true, false), + "ds=2016-01-01/foo=/a%0Ab%3Ac=a%23b%3Dc"); + ASSERT_EQ( + FileUtils::makePartName(pairs, false, false), + "ds=2016-01-01/FOO=/a%0Ab%3Ac=a%23b%3Dc"); } -TEST(FileUtilsTests, ParsePartKeyValues) { +TEST(FileUtilsTests, parsePartKeyValues) { EXPECT_THROW( FileUtils::parsePartKeyValues("ds"), facebook::velox::VeloxRuntimeError); EXPECT_THROW( @@ -60,7 +75,7 @@ TEST(FileUtilsTests, ParsePartKeyValues) { std::make_pair("a\nb:c", "a#b=c/"))); } -TEST(FileUtilsTests, ExtractPartitionName) { +TEST(FileUtilsTests, extractPartitionName) { struct TestCase { public: TestCase(const std::string& filePath, const std::string& partitionName) @@ -88,3 +103,6 @@ TEST(FileUtilsTests, ExtractPartitionName) { FileUtils::extractPartitionName(testCase.filePath)); } } + +} // namespace +} // namespace facebook::velox::dwio::catalog::fbhive diff --git a/velox/dwio/common/Adaptor.h b/velox/dwio/common/Adaptor.h index 2fef0a96655..ae89074808f 100644 --- a/velox/dwio/common/Adaptor.h +++ b/velox/dwio/common/Adaptor.h @@ -27,7 +27,7 @@ #define DIAGNOSTIC_PUSH _Pragma("GCC diagnostic push") #define DIAGNOSTIC_POP _Pragma("GCC diagnostic pop") #else -#error("Unknown compiler") +#error ("Unknown compiler") #endif #define PRAGMA(TXT) _Pragma(#TXT) diff --git a/velox/dwio/common/Arena.h b/velox/dwio/common/Arena.h new file mode 100644 index 00000000000..9690ce727d3 --- /dev/null +++ b/velox/dwio/common/Arena.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace facebook::velox::dwio::common { + +/// Wrapper over protobuf's arena allocation. The API changes from +/// CreateMessage() to Create() in newer protobuf versions. +template +T* ArenaCreate(google::protobuf::Arena* arena, Args&&... args) { +#if GOOGLE_PROTOBUF_VERSION >= 5030000 + return google::protobuf::Arena::Create(arena, std::forward(args)...); +#else + return google::protobuf::Arena::CreateMessage( + arena, std::forward(args)...); +#endif +} + +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/BitPackDecoder.h b/velox/dwio/common/BitPackDecoder.h index 2aa785e2d29..5f31267026d 100644 --- a/velox/dwio/common/BitPackDecoder.h +++ b/velox/dwio/common/BitPackDecoder.h @@ -109,7 +109,7 @@ static inline uint32_t unpackNaive( uint64_t inputBufferLen, uint64_t numValues, uint8_t bitWidth, - T* FOLLY_NONNULL& result); + T * FOLLY_NONNULL & result); /// Unpack numValues number of input values from inputBuffer. The results /// will be written to result. numValues must be a multiple of 8. The @@ -122,7 +122,7 @@ inline void unpack( uint64_t inputBufferLen, uint64_t numValues, uint8_t bitWidth, - T* FOLLY_NONNULL& result) { + T * FOLLY_NONNULL & result) { unpackNaive(inputBits, inputBufferLen, numValues, bitWidth, result); } @@ -159,7 +159,7 @@ static inline uint32_t unpackNaive( uint64_t inputBufferLen, uint64_t numValues, uint8_t bitWidth, - T* FOLLY_NONNULL& result) { + T * FOLLY_NONNULL & result) { VELOX_CHECK(bitWidth >= 1 && bitWidth <= sizeof(T) * 8); VELOX_CHECK(inputBufferLen * 8 >= bitWidth * numValues); diff --git a/velox/dwio/common/BufferedInput.cpp b/velox/dwio/common/BufferedInput.cpp index 8791e418f49..ec745f0eb3e 100644 --- a/velox/dwio/common/BufferedInput.cpp +++ b/velox/dwio/common/BufferedInput.cpp @@ -27,6 +27,42 @@ using ::facebook::velox::common::Region; namespace facebook::velox::dwio::common { +CachedRegion::CachedRegion(cache::CachePin pin) : pin_(std::move(pin)) { + VELOX_CHECK(!pin_.empty(), "CachedRegion requires a non-empty cache pin"); + auto* entry = pin_.checkedEntry(); + VELOX_CHECK( + !entry->isExclusive(), + "CachedRegion requires a shared (non-exclusive) cache pin"); + size_ = entry->size(); + if (entry->hasContiguousData()) { + ranges_.push_back( + folly::Range(entry->contiguousData(), size_)); + } else { + auto& allocation = entry->nonContiguousData(); + ranges_.reserve(allocation.numRuns()); + uint64_t offset{0}; + for (int i = 0; i < allocation.numRuns() && offset < size_; ++i) { + auto run = allocation.runAt(i); + const uint64_t bytes = + run.numPages() * memory::AllocationTraits::kPageSize; + const uint64_t readSize = std::min(bytes, size_ - offset); + ranges_.emplace_back(run.data(), readSize); + offset += readSize; + } + } +} + +folly::IOBuf CachedRegion::toIOBuf() const { + VELOX_CHECK(!ranges_.empty()); + auto iobuf = + folly::IOBuf::wrapBufferAsValue(ranges_[0].data(), ranges_[0].size()); + for (size_t i = 1; i < ranges_.size(); ++i) { + iobuf.appendToChain( + folly::IOBuf::wrapBuffer(ranges_[i].data(), ranges_[i].size())); + } + return iobuf; +} + static_assert(std::is_move_constructible()); namespace { @@ -44,6 +80,14 @@ uint64_t BufferedInput::nextFetchSize() const { }); } +void BufferedInput::reset() { + regions_.clear(); + offsets_.clear(); + buffers_.clear(); + enqueuedToBufferOffset_.clear(); + allocPool_->clear(); +} + void BufferedInput::load(const LogType logType) { // no regions to load if (regions_.size() == 0) { @@ -82,14 +126,15 @@ void BufferedInput::readToBuffer( uint64_t offset, folly::Range allocated, const LogType logType) { - uint64_t usec = 0; + uint64_t storageReadTimeUs = 0; { - MicrosecondTimer timer(&usec); + MicrosecondTimer timer(&storageReadTimeUs); input_->read(allocated.data(), allocated.size(), offset, logType); } if (auto* stats = input_->getStats()) { stats->read().increment(allocated.size()); - stats->queryThreadIoLatency().increment(usec); + stats->queryThreadIoLatencyUs().increment(storageReadTimeUs); + stats->storageReadLatencyUs().increment(storageReadTimeUs); } } @@ -114,8 +159,9 @@ std::unique_ptr BufferedInput::enqueue( // help faster lookup using enqueuedToBufferOffset_ later. [region, this, i = regions_.size() - 1]() { auto result = readInternal(region.offset, region.length, i); - VELOX_CHECK( - std::get<1>(result) != MAX_UINT64, + VELOX_CHECK_NE( + std::get<1>(result), + MAX_UINT64, "Fail to read region offset={} length={}", region.offset, region.length); diff --git a/velox/dwio/common/BufferedInput.h b/velox/dwio/common/BufferedInput.h index 1f877b3fa8d..a9077a90452 100644 --- a/velox/dwio/common/BufferedInput.h +++ b/velox/dwio/common/BufferedInput.h @@ -16,6 +16,8 @@ #pragma once +#include "folly/io/IOBuf.h" +#include "velox/common/caching/AsyncDataCache.h" #include "velox/common/caching/ScanTracker.h" #include "velox/common/memory/AllocationPool.h" #include "velox/dwio/common/SeekableInputStream.h" @@ -26,6 +28,37 @@ DECLARE_bool(wsVRLoad); namespace facebook::velox::dwio::common { +/// Provides read-only access to cached data without copying. Holds a +/// shared-mode pin on the cache entry, keeping it alive while the caller +/// accesses the data buffers. +class CachedRegion { + public: + /// The pin must be non-empty and in shared (non-exclusive) mode. + explicit CachedRegion(cache::CachePin pin); + + uint64_t size() const { + return size_; + } + + /// Returns buffer ranges covering the cached data. For small entries the + /// result is a single contiguous range; for larger entries there may be + /// multiple non-contiguous ranges from the backing allocation. + const std::vector>& ranges() const { + return ranges_; + } + + /// Returns an IOBuf chain wrapping the cached data ranges without copying. + /// The returned IOBuf references memory owned by this CachedRegion, so + /// the caller must not outlive this object. + folly::IOBuf toIOBuf() const; + + private: + cache::CachePin pin_; + // Cached data size in bytes. + uint64_t size_{0}; + std::vector> ranges_; +}; + class BufferedInput { public: constexpr static uint64_t kMaxMergeDistance = 1024 * 1024 * 1.25; @@ -35,15 +68,19 @@ class BufferedInput { memory::MemoryPool& pool, const MetricsLogPtr& metricsLog = MetricsLog::voidLog(), IoStatistics* stats = nullptr, - filesystems::File::IoStats* fsStats = nullptr, + velox::IoStats* ioStats = nullptr, uint64_t maxMergeDistance = kMaxMergeDistance, - std::optional wsVRLoad = std::nullopt) + std::optional wsVRLoad = std::nullopt, + folly::F14FastMap fileReadOps = {}, + bool cacheable = false) : BufferedInput( std::make_shared( std::move(readFile), metricsLog, stats, - fsStats), + ioStats, + std::move(fileReadOps), + cacheable), pool, maxMergeDistance, wsVRLoad) {} @@ -79,9 +116,19 @@ class BufferedInput { velox::common::Region region, const StreamIdentifier* sid = nullptr); - /// Returns true if load synchronously. - virtual bool supportSyncLoad() const { - return true; + /// Preloads the entire file into memory for fast sub-region access. + /// Each subclass stores the preloaded data in its own native format. + /// For small files (<= filePreloadThreshold), this eliminates separate + /// footer and stripe data reads. + virtual void preload() { + enqueue({0, input_->getLength()}); + load(LogType::FILE); + preloaded_ = true; + } + + /// Returns true if the file has been preloaded. + virtual bool preloaded() const { + return preloaded_; } /// load all regions to be read in an optimized way (IO efficiency) @@ -149,8 +196,50 @@ class BufferedInput { return nullptr; } + /// Returns true if this BufferedInput has a backing cache (e.g., + /// AsyncDataCache). When true, callers may skip their own caching of raw + /// bytes since the BufferedInput will handle caching. + virtual bool hasCache() const { + return false; + } + + /// Offers pre-read data for a file region to the backing cache. Throws if + /// hasCache() is false. Override in subclasses with a backing cache. + virtual void cacheRegion( + uint64_t /*offset*/, + uint64_t /*length*/, + std::string_view /*data*/) { + VELOX_UNSUPPORTED("cacheRegion requires a backing cache"); + } + + /// Overload that copies from an IOBuf (possibly chained) into the cache + /// entry, avoiding the need to coalesce the IOBuf first. 'bufferOffset' is + /// the byte offset within the IOBuf chain where the region data starts. + /// Throws if hasCache() is false. + virtual void cacheRegion( + uint64_t /*offset*/, + uint64_t /*length*/, + const folly::IOBuf& /*buffer*/, + uint64_t /*bufferOffset*/) { + VELOX_UNSUPPORTED("cacheRegion requires a backing cache"); + } + + /// Finds a cached region at the given offset. Returns a CachedRegion holding + /// a shared-mode pin on the cache entry if found, or std::nullopt on cache + /// miss. Throws if hasCache() is false. + virtual std::optional findCachedRegion( + uint64_t /*offset*/) const { + VELOX_UNSUPPORTED("findCachedRegion requires a backing cache"); + } + virtual uint64_t nextFetchSize() const; + /// Resets the buffered input for reuse. This is used by index lookup which + /// reuses the same BufferedInput across different index lookups. For + /// instance, Nimble file format with cluster index supports index lookup and + /// needs to reset the buffered input state between lookups. + virtual void reset(); + protected: static int adjustedReadPct(const cache::TrackingData& trackingData) { // When this method is called, there is one more reference that is already @@ -255,6 +344,7 @@ class BufferedInput { velox::common::Region& first, const velox::common::Region& second); + bool preloaded_{false}; uint64_t maxMergeDistance_; std::optional wsVRLoad_; std::unique_ptr allocPool_; diff --git a/velox/dwio/common/CMakeLists.txt b/velox/dwio/common/CMakeLists.txt index 3a4976bd625..532f2d3bcf6 100644 --- a/velox/dwio/common/CMakeLists.txt +++ b/velox/dwio/common/CMakeLists.txt @@ -14,6 +14,7 @@ add_subdirectory(compression) add_subdirectory(encryption) add_subdirectory(exception) +add_subdirectory(wrap) if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) @@ -35,7 +36,7 @@ velox_add_library( DirectBufferedInput.cpp DirectDecoder.cpp DirectInputStream.cpp - DwioMetricsLog.cpp + MetricsLog.cpp ExecutorBarrier.cpp FileSink.cpp FlatMapHelper.cpp @@ -45,6 +46,7 @@ velox_add_library( MetadataFilter.cpp Options.cpp OutputStream.cpp + ParallelUnitLoader.cpp ParallelFor.cpp Range.cpp Reader.cpp @@ -58,11 +60,80 @@ velox_add_library( SelectiveStructColumnReader.cpp SortingWriter.cpp SortingWriter.h + StatisticsBuilder.cpp Throttler.cpp TypeUtils.cpp TypeWithId.cpp Writer.cpp WriterFactory.cpp + HEADERS + Adaptor.h + Arena.h + BitConcatenation.h + BitPackDecoder.h + BufferUtil.h + BufferedInput.h + CacheInputStream.h + CachedBufferedInput.h + ChainedBuffer.h + Closeable.h + ColumnLoader.h + ColumnSelector.h + ColumnVisitors.h + DataBuffer.h + DataBufferHolder.h + DecoderUtil.h + DirectBufferedInput.h + DirectDecoder.h + DirectInputStream.h + ErrorTolerance.h + ExecutorBarrier.h + FileMetadata.h + FileSink.h + FilterNode.h + FlatMapHelper.h + FlushPolicy.h + FormatData.h + InputStream.h + IntCodecCommon.h + IntDecoder.h + MeasureTime.h + MetadataFilter.h + MetricsLog.h + Mutation.h + OnDemandUnitLoader.h + Options.h + OutputStream.h + ParallelFor.h + ParallelUnitLoader.h + PositionProvider.h + RandGen.h + Range.h + Reader.h + ReaderFactory.h + Retry.h + ScanSpec.h + SeekableInputStream.h + SelectiveByteRleColumnReader.h + SelectiveColumnReader.h + SelectiveColumnReaderInternal.h + SelectiveFlatMapColumnReader.h + SelectiveFloatingPointColumnReader.h + SelectiveIntegerColumnReader.h + SelectiveRepeatedColumnReader.h + SelectiveStructColumnReader.h + Statistics.h + StatisticsBuilder.h + StreamIdentifier.h + StreamUtil.h + Throttler.h + TypeUtil.h + TypeUtils.h + TypeWithId.h + UnitLoader.h + UnitLoaderTools.h + Writer.h + WriterFactory.h ) velox_include_directories(velox_dwio_common PRIVATE ${Protobuf_INCLUDE_DIRS}) @@ -74,13 +145,13 @@ velox_link_libraries( velox_common_io velox_common_compression velox_common_config + velox_common_hyperloglog velox_dwio_common_encryption velox_dwio_common_exception velox_exception velox_expression velox_memory velox_type_tz - Boost::regex Folly::folly glog::glog protobuf::libprotobuf diff --git a/velox/dwio/common/CacheInputStream.cpp b/velox/dwio/common/CacheInputStream.cpp index dedda0c72f6..8474c2aac33 100644 --- a/velox/dwio/common/CacheInputStream.cpp +++ b/velox/dwio/common/CacheInputStream.cpp @@ -35,14 +35,14 @@ CacheInputStream::CacheInputStream( const Region& region, std::shared_ptr input, uint64_t fileNum, - bool noCacheRetention, + bool cacheable, std::shared_ptr tracker, TrackingId trackingId, uint64_t groupId, int32_t loadQuantum) : bufferedInput_(bufferedInput), cache_(bufferedInput_->cache()), - noCacheRetention_(noCacheRetention), + cacheable_(cacheable), region_(region), fileNum_(fileNum), tracker_(std::move(tracker)), @@ -53,12 +53,18 @@ CacheInputStream::CacheInputStream( input_(std::move(input)) {} CacheInputStream::~CacheInputStream() { + if (preloaded_) { + // Preloaded streams hold a shared pin copy. Just release it; eviction of + // the preloaded entry is handled by CachedBufferedInput destructor based + // on cacheable_ flag. + return; + } clearCachePin(); makeCacheEvictable(); } void CacheInputStream::makeCacheEvictable() { - if (!noCacheRetention_) { + if (preloaded_ || cacheable_) { return; } // Walks through the potential prefetch or access cache space of this cache @@ -97,7 +103,7 @@ bool CacheInputStream::Next(const void** buffer, int32_t* size) { } offsetInRun_ += *size; - if (prefetchPct_ < 100) { + if (!preloaded_ && prefetchPct_ < 100) { const auto offsetInQuantum = position_ % loadQuantum_; const auto nextQuantumOffset = position_ - offsetInQuantum + loadQuantum_; const auto prefetchThreshold = loadQuantum_ * prefetchPct_ / 100; @@ -142,8 +148,8 @@ bool CacheInputStream::SkipInt64(int64_t count) { return false; } -google::protobuf::int64 CacheInputStream::ByteCount() const { - return static_cast(position_); +int64_t CacheInputStream::ByteCount() const { + return static_cast(position_); } void CacheInputStream::seekToPosition(PositionProvider& seekPosition) { @@ -170,29 +176,6 @@ void CacheInputStream::setRemainingBytes(uint64_t remainingBytes) { window_ = Region{static_cast(position_), remainingBytes}; } -namespace { -std::vector> makeRanges( - cache::AsyncDataCacheEntry* entry, - size_t length) { - std::vector> buffers; - if (entry->tinyData() == nullptr) { - auto& allocation = entry->data(); - buffers.reserve(allocation.numRuns()); - uint64_t offsetInRuns = 0; - for (int i = 0; i < allocation.numRuns(); ++i) { - auto run = allocation.runAt(i); - uint64_t bytes = run.numPages() * memory::AllocationTraits::kPageSize; - uint64_t readSize = std::min(bytes, length - offsetInRuns); - buffers.push_back(folly::Range(run.data(), readSize)); - offsetInRuns += readSize; - } - } else { - buffers.push_back(folly::Range(entry->tinyData(), entry->size())); - } - return buffers; -} -} // namespace - void CacheInputStream::loadSync(const Region& region) { process::TraceContext trace("loadSync"); int64_t hitSize = region.length; @@ -211,11 +194,16 @@ void CacheInputStream::loadSync(const Region& region) { // the individual parts are hit. ioStats_->incRawBytesRead(hitSize); prefetchStarted_ = false; + + // TODO: add a maximum retry limit or timeout to prevent infinite loops under + // memory pressure or contention if findOrCreate() consistently returns empty + // pins. do { folly::SemiFuture cacheLoadWait(false); cache::RawFileCacheKey key{fileNum_, region.offset}; clearCachePin(); - pin_ = cache_->findOrCreate(key, region.length, &cacheLoadWait); + pin_ = cache_->findOrCreate( + key, region.length, /*contiguous=*/false, &cacheLoadWait); if (pin_.empty()) { VELOX_CHECK(cacheLoadWait.valid()); uint64_t waitUs{0}; @@ -225,7 +213,8 @@ void CacheInputStream::loadSync(const Region& region) { .via(&folly::QueuedImmediateExecutor::instance()) .wait(); } - ioStats_->queryThreadIoLatency().increment(waitUs); + ioStats_->queryThreadIoLatencyUs().increment(waitUs); + ioStats_->cacheWaitLatencyUs().increment(waitUs); continue; } @@ -245,24 +234,25 @@ void CacheInputStream::loadSync(const Region& region) { if (loadFromSsd(region, *entry)) { return; } - const auto ranges = makeRanges(entry, region.length); + const auto ranges = entry->dataRanges(region.length); uint64_t storageReadUs{0}; { MicrosecondTimer timer(&storageReadUs); input_->read(ranges, region.offset, LogType::FILE); } ioStats_->read().increment(region.length); - ioStats_->queryThreadIoLatency().increment(storageReadUs); - ioStats_->incTotalScanTime(storageReadUs * 1'000); - entry->setExclusiveToShared(!noCacheRetention_); + ioStats_->queryThreadIoLatencyUs().increment(storageReadUs); + ioStats_->storageReadLatencyUs().increment(storageReadUs); + ioStats_->incTotalScanTimeNs(storageReadUs * 1'000); + entry->setExclusiveToShared(cacheable_); } while (pin_.empty()); } void CacheInputStream::clearCachePin() { - if (pin_.empty()) { + if (preloaded_ || pin_.empty()) { return; } - if (noCacheRetention_) { + if (!cacheable_) { pin_.checkedEntry()->makeEvictable(); } pin_.clear(); @@ -318,7 +308,8 @@ bool CacheInputStream::loadFromSsd( VELOX_CHECK(pin_.empty()); pin_ = std::move(pins[0]); ioStats_->ssdRead().increment(region.length); - ioStats_->queryThreadIoLatency().increment(ssdLoadUs); + ioStats_->queryThreadIoLatencyUs().increment(ssdLoadUs); + ioStats_->ssdCacheReadLatencyUs().increment(ssdLoadUs); // Skip no-cache retention setting as data is loaded from ssd. entry.setExclusiveToShared(); return true; @@ -334,7 +325,9 @@ std::string CacheInputStream::ssdFileName() const { void CacheInputStream::loadPosition() { const auto offset = region_.offset; + if (pin_.empty()) { + VELOX_CHECK(!preloaded_, "Preloaded stream must always have a valid pin"); auto load = bufferedInput_->coalescedLoad(this); if (load != nullptr) { folly::SemiFuture waitFuture(false); @@ -342,16 +335,21 @@ void CacheInputStream::loadPosition() { { MicrosecondTimer timer(&loadUs); try { - if (!load->loadOrFuture(&waitFuture, !noCacheRetention_)) { + if (!load->loadOrFuture(&waitFuture, cacheable_)) { waitFuture.wait(); } } catch (const std::exception& e) { - // Log the error and continue. The error, if it persists, will be hit - // again in looking up the specific entry and thrown from there. + // Log the error and continue. The error, if it persists, will be + // hit again in looking up the specific entry and thrown from there. LOG(ERROR) << "IOERR: error in coalesced load " << e.what(); } } - ioStats_->queryThreadIoLatency().increment(loadUs); + ioStats_->queryThreadIoLatencyUs().increment(loadUs); + if (load->isSsdLoad()) { + ioStats_->coalescedSsdLoadLatencyUs().increment(loadUs); + } else { + ioStats_->coalescedStorageLoadLatencyUs().increment(loadUs); + } } const auto nextLoadRegion = nextQuantizedLoadRegion(position_); @@ -366,15 +364,16 @@ void CacheInputStream::loadPosition() { entry->offset() + entry->size() > positionInFile) { // The position is inside the range of 'entry'. const auto offsetInEntry = positionInFile - entry->offset(); - if (entry->data().numPages() == 0) { - run_ = reinterpret_cast(entry->tinyData()); + if (entry->hasContiguousData()) { + run_ = reinterpret_cast(entry->contiguousData()); runSize_ = entry->size(); offsetInRun_ = offsetInEntry; offsetOfRun_ = 0; } else { - entry->data().findRun(offsetInEntry, &runIndex_, &offsetInRun_); + entry->nonContiguousData().findRun( + offsetInEntry, &runIndex_, &offsetInRun_); offsetOfRun_ = offsetInEntry - offsetInRun_; - const auto run = entry->data().runAt(runIndex_); + const auto run = entry->nonContiguousData().runAt(runIndex_); run_ = run.data(); runSize_ = memory::AllocationTraits::pageBytes(run.numPages()); if (offsetOfRun_ + runSize_ > entry->size()) { @@ -382,6 +381,14 @@ void CacheInputStream::loadPosition() { } } } else { + // Position is out of range for the current entry. This cannot happen for + // preloaded entries since they cover the entire file. + VELOX_CHECK( + !preloaded_, + "Position {} out of range for preloaded entry [{}, {})", + positionInFile, + entry->offset(), + entry->offset() + entry->size()); clearCachePin(); loadPosition(); } @@ -394,7 +401,7 @@ velox::common::Region CacheInputStream::nextQuantizedLoadRegion( nextRegion.offset += (prevLoadedPosition / loadQuantum_) * loadQuantum_; // Set length to be the lesser of 'loadQuantum_' and distance to end of // 'region_' - nextRegion.length = std::min( + nextRegion.length = std::min( loadQuantum_, region_.length - (nextRegion.offset - region_.offset)); return nextRegion; } diff --git a/velox/dwio/common/CacheInputStream.h b/velox/dwio/common/CacheInputStream.h index 195bdbdd135..272901e1cc7 100644 --- a/velox/dwio/common/CacheInputStream.h +++ b/velox/dwio/common/CacheInputStream.h @@ -35,7 +35,7 @@ class CacheInputStream : public SeekableInputStream { const velox::common::Region& region, std::shared_ptr input, uint64_t fileNum, - bool noCacheRetention, + bool cacheable, std::shared_ptr tracker, cache::TrackingId trackingId, uint64_t groupId, @@ -51,7 +51,7 @@ class CacheInputStream : public SeekableInputStream { bool Next(const void** data, int* size) override; void BackUp(int count) override; bool SkipInt64(int64_t count) override; - google::protobuf::int64 ByteCount() const override; + int64_t ByteCount() const override; void seekToPosition(PositionProvider& position) override; std::string getName() const override; size_t positionSize() const override; @@ -72,15 +72,26 @@ class CacheInputStream : public SeekableInputStream { region_, input_, fileNum_, - noCacheRetention_, + cacheable_, tracker_, trackingId_, groupId_, loadQuantum_); copy->position_ = position_; + if (preloaded_) { + copy->setPreloadedPin(pin_); + } return copy; } + /// Sets the stream to serve data from a preloaded whole-file cache entry. + /// The pin is copied so the stream can outlive the CachedBufferedInput. When + /// set, the stream skips coalesced loading, prefetching, and eviction. + void setPreloadedPin(cache::CachePin pin) { + pin_ = std::move(pin); + preloaded_ = true; + } + /// Sets the stream to range over a window that starts at the current position /// and is 'remainingBytes' bytes in size. 'remainingBytes' must be <= /// 'region_.length - position_'. The stream cannot be used for reading @@ -98,8 +109,8 @@ class CacheInputStream : public SeekableInputStream { prefetchPct_ = pct; } - bool testingNoCacheRetention() const { - return noCacheRetention_; + bool testingCacheable() const { + return cacheable_; } private: @@ -120,7 +131,7 @@ class CacheInputStream : public SeekableInputStream { cache::AsyncDataCacheEntry& entry); // Invoked to clear the cache pin of the accessed cache entry and mark it as - // immediate evictable if 'noCacheRetention_' flag is set. + // immediate evictable if 'cacheable_' is false. void clearCachePin(); void makeCacheEvictable(); @@ -131,10 +142,10 @@ class CacheInputStream : public SeekableInputStream { CachedBufferedInput* const bufferedInput_; cache::AsyncDataCache* const cache_; - // True if a pin should be set to the lowest retention score after + // False if a pin should be set to the lowest retention score after // unpinning. This applies to sequential reads where second access // to the page is not expected. - const bool noCacheRetention_; + const bool cacheable_; // The region of 'input' 'this' ranges over. const velox::common::Region region_; const uint64_t fileNum_; @@ -152,15 +163,15 @@ class CacheInputStream : public SeekableInputStream { // Handle of cache entry. cache::CachePin pin_; - // Offset of current run from start of 'entry_->data()' + // Offset of current run from start of 'entry_->nonContiguousData()' uint64_t offsetOfRun_; - // Pointer to start of current run in 'entry->data()' or - // 'entry->tinyData()'. + // Pointer to start of current run in 'entry->nonContiguousData()' or + // 'entry->contiguousData()'. uint8_t* run_{nullptr}; // Position of stream relative to 'run_'. int offsetInRun_{0}; - // Index of run in 'entry_->data()' + // Index of run in 'entry_->nonContiguousData()' int runIndex_ = -1; // Number of valid bytes above 'run_'. uint32_t runSize_ = 0; @@ -175,6 +186,10 @@ class CacheInputStream : public SeekableInputStream { // Over 100 means no prefetch. int32_t prefetchPct_{200}; + // True if this stream serves data from a preloaded whole-file cache entry. + // When set, loading, prefetching, and eviction are all skipped. + bool preloaded_{false}; + // True if prefetch the next 'loadQuantum_' has been started. Cleared when // moving to the next load quantum. bool prefetchStarted_{false}; diff --git a/velox/dwio/common/CachedBufferedInput.cpp b/velox/dwio/common/CachedBufferedInput.cpp index 26b59d265f6..be6bffb03de 100644 --- a/velox/dwio/common/CachedBufferedInput.cpp +++ b/velox/dwio/common/CachedBufferedInput.cpp @@ -15,8 +15,11 @@ */ #include "velox/dwio/common/CachedBufferedInput.h" +#include "folly/io/Cursor.h" +#include "velox/common/Casts.h" #include "velox/common/memory/Allocation.h" #include "velox/common/process/TraceContext.h" +#include "velox/common/time/Timer.h" #include "velox/dwio/common/CacheInputStream.h" DECLARE_int32(cache_prefetch_min_pct); @@ -47,29 +50,38 @@ std::unique_ptr CachedBufferedInput::enqueue( id = TrackingId(sid->getId()); } VELOX_CHECK_LE(region.offset + region.length, fileSize_); - requests_.emplace_back( - RawFileCacheKey{fileNum_.id(), region.offset}, region.length, id); if (tracker_ != nullptr) { tracker_->recordReference(id, region.length, fileNum_.id(), groupId_.id()); } auto stream = std::make_unique( this, - ioStats_.get(), + ioStatistics_.get(), region, input_, fileNum_.id(), - options_.noCacheRetention(), + options_.cacheable(), tracker_, id, groupId_.id(), options_.loadQuantum()); - requests_.back().stream = stream.get(); + if (preloaded()) { + // Data is already in cache. Give the stream its own pin copy so it can + // outlive this CachedBufferedInput and skip all loading/prefetch logic. + stream->setPreloadedPin(preloadPin_); + } else { + requests_.emplace_back( + RawFileCacheKey{fileNum_.id(), region.offset}, region.length, id); + requests_.back().stream = stream.get(); + } return stream; } bool CachedBufferedInput::isBuffered(uint64_t /*offset*/, uint64_t /*length*/) const { - return false; + // When preloaded, the entire file content is already in cache, so any + // region within the file is considered buffered and can be served without + // additional I/O. + return preloaded(); } bool CachedBufferedInput::shouldPreload(int32_t numPages) { @@ -82,7 +94,7 @@ bool CachedBufferedInput::shouldPreload(int32_t numPages) { numPages += memory::AllocationTraits::numPages( std::min(request.size, options_.loadQuantum())); } - const auto cachePages = cache_->incrementCachedPages(0); + const auto cachePages = cache_->cachedPages(); auto* allocator = cache_->allocator(); const auto maxPages = memory::AllocationTraits::numPages(allocator->capacity()); @@ -127,10 +139,11 @@ std::vector makeRequestParts( std::vector parts; for (uint64_t offset = 0; offset < request.size; offset += loadQuantum) { const int32_t size = std::min(loadQuantum, request.size - offset); - extraRequests.push_back(std::make_unique( - RawFileCacheKey{request.key.fileNum, request.key.offset + offset}, - size, - request.trackingId)); + extraRequests.push_back( + std::make_unique( + RawFileCacheKey{request.key.fileNum, request.key.offset + offset}, + size, + request.trackingId)); parts.push_back(extraRequests.back().get()); parts.back()->coalesces = prefetch; if (prefetchOne) { @@ -160,6 +173,52 @@ bool lessThan(const CacheRequest* left, const CacheRequest* right) { } // namespace +void CachedBufferedInput::preload() { + VELOX_CHECK(preloadPin_.empty(), "preload() called more than once"); + VELOX_CHECK(requests_.empty(), "preload() must be called before enqueue()"); + cache::RawFileCacheKey key{fileNum_.id(), 0}; + folly::SemiFuture waitFuture(false); + do { + preloadPin_ = + cache_->findOrCreate(key, fileSize_, /*contiguous=*/false, &waitFuture); + if (preloadPin_.empty()) { + uint64_t waitUs{0}; + { + MicrosecondTimer timer(&waitUs); + std::move(waitFuture).wait(); + } + ioStatistics_->queryThreadIoLatencyUs().increment(waitUs); + ioStatistics_->cacheWaitLatencyUs().increment(waitUs); + } + } while (preloadPin_.empty()); + + auto* entry = preloadPin_.checkedEntry(); + if (!entry->getAndClearFirstUseFlag()) { + // Already loaded by another concurrent query. + ioStatistics_->ramHit().increment(fileSize_); + } + if (!entry->isExclusive()) { + // Cache hit — already loaded. + return; + } + + entry->setGroupId(groupId_.id()); + entry->setTrackingId( + cache::TrackingId(StreamIdentifier::sequentialFile().id_)); + auto ranges = entry->dataRanges(fileSize_); + uint64_t storageReadUs{0}; + { + MicrosecondTimer timer(&storageReadUs); + input_->read(ranges, 0, LogType::FILE); + } + ioStatistics_->read().increment(fileSize_); + ioStatistics_->incRawBytesRead(fileSize_); + ioStatistics_->queryThreadIoLatencyUs().increment(storageReadUs); + ioStatistics_->storageReadLatencyUs().increment(storageReadUs); + ioStatistics_->incTotalScanTimeNs(storageReadUs * 1'000); + entry->setExclusiveToShared(options_.cacheable()); +} + void CachedBufferedInput::load(const LogType /*unused*/) { // 'requests_ is cleared on exit. auto requests = std::move(requests_); @@ -193,8 +252,8 @@ void CachedBufferedInput::load(const LogType /*unused*/) { if (ssdFile != nullptr) { part->ssdPin = ssdFile->find(part->key); if (!part->ssdPin.empty() && part->ssdPin.run().size() < part->size) { - LOG(INFO) << "IOERR: Ignoring SSD shorter than requested: " - << part->ssdPin.run().size() << " vs " << part->size; + LOG(WARNING) << "Ignoring SSD shorter than requested: " + << part->ssdPin.run().size() << " vs " << part->size; part->ssdPin.clear(); } if (!part->ssdPin.empty()) { @@ -236,7 +295,7 @@ std::vector CachedBufferedInput::groupRequests( if (requests.empty() || (requests.size() < 2 && !prefetch)) { return {}; } - const int32_t maxDistance = kSsd ? 20000 : options_.maxCoalesceDistance(); + const int32_t maxDistance = kSsd ? 20'000 : options_.maxCoalesceDistance(); // Combine adjacent short reads. int64_t coalescedBytes = 0; @@ -278,14 +337,14 @@ class DwioCoalescedLoadBase : public cache::CoalescedLoad { public: DwioCoalescedLoadBase( cache::AsyncDataCache& cache, - std::shared_ptr ioStats, - std::shared_ptr fsStats, + std::shared_ptr ioStatistics, + std::shared_ptr ioStats, uint64_t groupId, std::vector requests) : CoalescedLoad(makeKeys(requests), makeSizes(requests)), cache_(cache), + ioStatistics_(std::move(ioStatistics)), ioStats_(std::move(ioStats)), - fsStats_(std::move(fsStats)), groupId_(groupId) { requests_.reserve(requests.size()); for (const auto& request : requests) { @@ -320,17 +379,17 @@ class DwioCoalescedLoadBase : public cache::CoalescedLoad { protected: void updateStats(const CoalesceIoStats& stats, bool prefetch, bool ssd) { - if (ioStats_ == nullptr) { + if (ioStatistics_ == nullptr) { return; } - ioStats_->incRawOverreadBytes(stats.extraBytes); + ioStatistics_->incRawOverreadBytes(stats.extraBytes); if (ssd) { - ioStats_->ssdRead().increment(stats.payloadBytes); + ioStatistics_->ssdRead().increment(stats.payloadBytes); } else { - ioStats_->read().increment(stats.payloadBytes); + ioStatistics_->read().increment(stats.payloadBytes); } if (prefetch) { - ioStats_->prefetch().increment(stats.payloadBytes); + ioStatistics_->prefetch().increment(stats.payloadBytes); } } @@ -355,8 +414,8 @@ class DwioCoalescedLoadBase : public cache::CoalescedLoad { cache::AsyncDataCache& cache_; std::vector requests_; - std::shared_ptr ioStats_; - std::shared_ptr fsStats_; + std::shared_ptr ioStatistics_; + std::shared_ptr ioStats_; const uint64_t groupId_; int64_t size_{0}; }; @@ -367,20 +426,24 @@ class DwioCoalescedLoad : public DwioCoalescedLoadBase { DwioCoalescedLoad( cache::AsyncDataCache& cache, std::shared_ptr input, - std::shared_ptr ioStats, - std::shared_ptr fsStats, + std::shared_ptr ioStatistics, + std::shared_ptr ioStats, uint64_t groupId, std::vector requests, int32_t maxCoalesceDistance) : DwioCoalescedLoadBase( cache, + std::move(ioStatistics), std::move(ioStats), - std::move(fsStats), groupId, std::move(requests)), input_(std::move(input)), maxCoalesceDistance_(maxCoalesceDistance) {} + bool isSsdLoad() const override { + return false; + } + std::vector loadData(bool prefetch) override { std::vector pins; pins.reserve(keys_.size()); @@ -421,17 +484,21 @@ class SsdLoad : public DwioCoalescedLoadBase { public: SsdLoad( cache::AsyncDataCache& cache, - std::shared_ptr ioStats, - std::shared_ptr fsStats, + std::shared_ptr ioStatistics, + std::shared_ptr ioStats, uint64_t groupId, std::vector requests) : DwioCoalescedLoadBase( cache, + std::move(ioStatistics), std::move(ioStats), - std::move(fsStats), groupId, std::move(requests)) {} + bool isSsdLoad() const override { + return true; + } + std::vector loadData(bool prefetch) override { std::vector ssdPins; std::vector pins; @@ -467,19 +534,19 @@ void CachedBufferedInput::readRegion( std::shared_ptr load; if (!requests[0]->ssdPin.empty()) { load = std::make_shared( - *cache_, ioStats_, fsStats_, groupId_.id(), requests); + *cache_, ioStatistics_, ioStats_, groupId_.id(), requests); } else { load = std::make_shared( *cache_, input_, + ioStatistics_, ioStats_, - fsStats_, groupId_.id(), requests, options_.maxCoalesceDistance()); } - allCoalescedLoads_.push_back(load); - coalescedLoads_.withWLock([&](auto& loads) { + coalescedLoads_.push_back(load); + streamToCoalescedLoad_.withWLock([&](auto& loads) { for (auto& request : requests) { loads[request->stream] = load; } @@ -490,52 +557,65 @@ void CachedBufferedInput::readRegions( const std::vector& requests, bool prefetch, const std::vector& groupEnds) { - int i = 0; - std::vector group; - for (auto end : groupEnds) { - while (i < end) { - group.push_back(requests[i++]); + if (requests.empty()) { + VELOX_CHECK(groupEnds.empty()); + return; + } + // Record the starting position so that we only submit the loads created by + // this call. Without this, non-prefetch loads or stale loads from previous + // cycles could be incorrectly submitted for async prefetching. + const int32_t startIndex = static_cast(coalescedLoads_.size()); + int32_t requestIdx{0}; + std::vector requestGroup; + for (auto groupEndIdx : groupEnds) { + while (requestIdx < groupEndIdx) { + requestGroup.push_back(requests[requestIdx++]); } - readRegion(group, prefetch); - group.clear(); + readRegion(requestGroup, prefetch); + requestGroup.clear(); } + if (prefetch && executor_) { - std::vector doneIndices; - for (auto i = 0; i < allCoalescedLoads_.size(); ++i) { - auto& load = allCoalescedLoads_[i]; + // Only submit the loads created by this call to the executor. + for (auto i = startIndex; i < coalescedLoads_.size(); ++i) { + auto& load = coalescedLoads_[i]; if (load->state() == CoalescedLoad::State::kPlanned) { executor_->add( - [pendingLoad = load, ssdSavable = !options_.noCacheRetention()]() { + [pendingLoad = load, ssdSavable = options_.cacheable()]() { process::TraceContext trace("Read Ahead"); pendingLoad->loadOrFuture(nullptr, ssdSavable); }); - } else { - doneIndices.push_back(i); } } // Remove the loads that were complete. There can be done loads if the same // CachedBufferedInput has multiple cycles of enqueues and loads. - for (int i = 0, j = 0, k = 0; i < allCoalescedLoads_.size(); ++i) { + std::vector doneIndices; + for (int32_t i = 0; i < startIndex; ++i) { + if (coalescedLoads_[i]->state() != CoalescedLoad::State::kPlanned) { + doneIndices.push_back(i); + } + } + for (int i = 0, j = 0, k = 0; i < coalescedLoads_.size(); ++i) { if (j < doneIndices.size() && doneIndices[j] == i) { ++j; } else { - allCoalescedLoads_[k++] = std::move(allCoalescedLoads_[i]); + coalescedLoads_[k++] = std::move(coalescedLoads_[i]); } } - allCoalescedLoads_.resize(allCoalescedLoads_.size() - doneIndices.size()); + coalescedLoads_.resize(coalescedLoads_.size() - doneIndices.size()); } } std::shared_ptr CachedBufferedInput::coalescedLoad( const SeekableInputStream* stream) { - return coalescedLoads_.withWLock( + return streamToCoalescedLoad_.withWLock( [&](auto& loads) -> std::shared_ptr { auto it = loads.find(stream); if (it == loads.end()) { return nullptr; } auto load = std::move(it->second); - auto* dwioLoad = static_cast(load.get()); + auto* dwioLoad = checkedPointerCast(load.get()); for (auto& request : dwioLoad->requests()) { loads.erase(request.stream); } @@ -543,22 +623,36 @@ std::shared_ptr CachedBufferedInput::coalescedLoad( }); } +void CachedBufferedInput::reset() { + BufferedInput::reset(); + for (auto& load : coalescedLoads_) { + load->cancel(); + } + coalescedLoads_.clear(); + streamToCoalescedLoad_.wlock()->clear(); + requests_.clear(); +} + std::unique_ptr CachedBufferedInput::read( uint64_t offset, uint64_t length, LogType /*logType*/) const { VELOX_CHECK_LE(offset + length, fileSize_); - return std::make_unique( + auto stream = std::make_unique( const_cast(this), - ioStats_.get(), + ioStatistics_.get(), Region{offset, length}, input_, fileNum_.id(), - options_.noCacheRetention(), + options_.cacheable(), nullptr, TrackingId(), 0, options_.loadQuantum()); + if (preloaded()) { + stream->setPreloadedPin(preloadPin_); + } + return stream; } bool CachedBufferedInput::prefetch(Region region) { @@ -574,4 +668,89 @@ bool CachedBufferedInput::prefetch(Region region) { return true; } +void CachedBufferedInput::cacheRegion( + uint64_t offset, + uint64_t length, + std::string_view data) { + VELOX_CHECK_EQ(data.size(), length); + auto iobuf = folly::IOBuf::wrapBufferAsValue(data.data(), data.size()); + cacheRegion(offset, length, iobuf, 0); +} + +void CachedBufferedInput::cacheRegion( + uint64_t offset, + uint64_t length, + const folly::IOBuf& buffer, + uint64_t bufferOffset) { + auto pin = + cache_->findOrCreate(RawFileCacheKey{fileNum_.id(), offset}, length); + // Empty pin means the cache is at capacity and cannot accept new entries. + // Non-exclusive means another thread already cached this region; skip the + // duplicate write. + if (pin.empty() || !pin.checkedEntry()->isExclusive()) { + return; + } + + folly::io::Cursor cursor(&buffer); + cursor.skip(bufferOffset); + VELOX_CHECK_GE( + cursor.totalLength(), + length, + "IOBuf has {} bytes after offset {}, need {}", + cursor.totalLength(), + bufferOffset, + length); + + auto* entry = pin.checkedEntry(); + if (entry->hasContiguousData()) { + cursor.pull(entry->contiguousData(), length); + } else { + auto& allocation = entry->nonContiguousData(); + uint64_t copyBytes = 0; + for (int i = 0; i < allocation.numRuns() && copyBytes < length; ++i) { + const auto run = allocation.runAt(i); + const uint64_t copySize = + std::min(run.numBytes(), length - copyBytes); + cursor.pull(run.data(), copySize); + copyBytes += copySize; + } + VELOX_CHECK_EQ(copyBytes, length); + } + + // Clear the first-use flag since this entry is being populated externally + // (not loaded on-demand). The first findCachedRegion access should count + // as a cache hit. + entry->getAndClearFirstUseFlag(); + entry->setExclusiveToShared(); +} + +std::optional CachedBufferedInput::findCachedRegion( + uint64_t offset) const { + const cache::RawFileCacheKey key{fileNum_.id(), offset}; + for (;;) { + folly::SemiFuture waitFuture(false); + auto result = cache_->find(key, &waitFuture); + if (!result.has_value()) { + return std::nullopt; + } + if (!result->empty()) { + auto* entry = result->checkedEntry(); + if (!entry->getAndClearFirstUseFlag()) { + ioStatistics_->ramHit().increment(entry->size()); + } + return CachedRegion{std::move(*result)}; + } + // Entry is exclusive — wait for it to become shared, then retry. + uint64_t waitUs{0}; + { + MicrosecondTimer timer(&waitUs); + std::move(waitFuture) + .via(&folly::QueuedImmediateExecutor::instance()) + .wait(); + } + ioStatistics_->queryThreadIoLatencyUs().increment(waitUs); + ioStatistics_->cacheWaitLatencyUs().increment(waitUs); + } +} + } // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/CachedBufferedInput.h b/velox/dwio/common/CachedBufferedInput.h index ddb08061c9b..bb246598e03 100644 --- a/velox/dwio/common/CachedBufferedInput.h +++ b/velox/dwio/common/CachedBufferedInput.h @@ -27,8 +27,6 @@ #include "velox/dwio/common/CacheInputStream.h" #include "velox/dwio/common/InputStream.h" -DECLARE_int32(cache_load_quantum); - namespace facebook::velox::dwio::common { struct CacheRequest { @@ -61,25 +59,31 @@ class CachedBufferedInput : public BufferedInput { cache::AsyncDataCache* cache, std::shared_ptr tracker, StringIdLease groupId, - std::shared_ptr ioStats, - std::shared_ptr fsStats, + std::shared_ptr ioStatistics, + std::shared_ptr ioStats, folly::Executor* executor, - const io::ReaderOptions& readerOptions) + const io::ReaderOptions& readerOptions, + folly::F14FastMap fileReadOps = {}) : BufferedInput( std::move(readFile), readerOptions.memoryPool(), metricsLog, + ioStatistics.get(), ioStats.get(), - fsStats.get()), + kMaxMergeDistance, + std::nullopt, + std::move(fileReadOps), + readerOptions.cacheable()), cache_(cache), fileNum_(std::move(fileNum)), tracker_(std::move(tracker)), groupId_(std::move(groupId)), + ioStatistics_(std::move(ioStatistics)), ioStats_(std::move(ioStats)), - fsStats_(std::move(fsStats)), executor_(executor), fileSize_(input_->getLength()), options_(readerOptions) { + VELOX_CHECK_NOT_NULL(cache_, "CachedBufferedInput requires a cache"); checkLoadQuantum(); } @@ -89,8 +93,8 @@ class CachedBufferedInput : public BufferedInput { cache::AsyncDataCache* cache, std::shared_ptr tracker, StringIdLease groupId, - std::shared_ptr ioStats, - std::shared_ptr fsStats, + std::shared_ptr ioStatistics, + std::shared_ptr ioStats, folly::Executor* executor, const io::ReaderOptions& readerOptions) : BufferedInput(std::move(input), readerOptions.memoryPool()), @@ -98,26 +102,32 @@ class CachedBufferedInput : public BufferedInput { fileNum_(std::move(fileNum)), tracker_(std::move(tracker)), groupId_(std::move(groupId)), + ioStatistics_(std::move(ioStatistics)), ioStats_(std::move(ioStats)), - fsStats_(std::move(fsStats)), executor_(executor), fileSize_(input_->getLength()), options_(readerOptions) { + VELOX_CHECK_NOT_NULL(cache_, "CachedBufferedInput requires a cache"); checkLoadQuantum(); } ~CachedBufferedInput() override { - for (auto& load : allCoalescedLoads_) { + for (auto& load : coalescedLoads_) { load->cancel(); } + if (!options_.cacheable() && !preloadPin_.empty()) { + preloadPin_.checkedEntry()->makeEvictable(); + } } std::unique_ptr enqueue( velox::common::Region region, const StreamIdentifier* sid) override; - bool supportSyncLoad() const override { - return false; + void preload() override; + + bool preloaded() const override { + return !preloadPin_.empty(); } void load(const LogType /*unused*/) override; @@ -151,16 +161,50 @@ class CachedBufferedInput : public BufferedInput { cache_, tracker_, groupId_, + ioStatistics_, ioStats_, - fsStats_, executor_, options_); } + /// Creates a clone that reads through the cache but marks entries as + /// immediately evictable on destruction (cacheable=false). Use when the + /// caller has its own caching layer (e.g., MetadataCache) and does not need + /// AsyncDataCache to retain the raw bytes. + std::unique_ptr cloneNonCacheable() const { + auto nonCacheableOptions = options_; + nonCacheableOptions.setCacheable(false); + return std::make_unique( + input_, + fileNum_, + cache_, + tracker_, + groupId_, + ioStatistics_, + ioStats_, + executor_, + nonCacheableOptions); + } + cache::AsyncDataCache* cache() const { return cache_; } + bool hasCache() const override { + return true; + } + + void cacheRegion(uint64_t offset, uint64_t length, std::string_view data) + override; + + void cacheRegion( + uint64_t offset, + uint64_t length, + const folly::IOBuf& buffer, + uint64_t bufferOffset) override; + + std::optional findCachedRegion(uint64_t offset) const override; + /// Returns the CoalescedLoad that contains the correlated loads for 'stream' /// or nullptr if none. Returns nullptr on all but first call for 'stream' /// since the load is to be triggered by the first access. @@ -175,6 +219,18 @@ class CachedBufferedInput : public BufferedInput { VELOX_NYI(); } + /// Resets the buffered input for reuse across different operations. + void reset() override; + + const std::vector>& + testingCoalescedLoads() const { + return coalescedLoads_; + } + + size_t testingStreamToCoalescedLoadSize() const { + return streamToCoalescedLoad_.rlock()->size(); + } + private: template std::vector groupRequests( @@ -213,8 +269,8 @@ class CachedBufferedInput : public BufferedInput { const StringIdLease fileNum_; const std::shared_ptr tracker_; const StringIdLease groupId_; - const std::shared_ptr ioStats_; - const std::shared_ptr fsStats_; + const std::shared_ptr ioStatistics_; + const std::shared_ptr ioStats_; folly::Executor* const executor_; const uint64_t fileSize_; const io::ReaderOptions options_; @@ -222,14 +278,18 @@ class CachedBufferedInput : public BufferedInput { // Regions that are candidates for loading. std::vector requests_; - // Coalesced loads spanning multiple cache entries in one IO. + // Map from stream to its coalesced load. folly::Synchronized>> - coalescedLoads_; + streamToCoalescedLoad_; + + // All distinct coalesced loads. + std::vector> coalescedLoads_; - // Distinct coalesced loads in 'coalescedLoads_'. - std::vector> allCoalescedLoads_; + // Holds the whole-file cache entry alive for the lifetime of this input. + // Set by preload(), used by CacheInputStream to serve sub-region reads. + cache::CachePin preloadPin_; }; } // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/ChainedBuffer.h b/velox/dwio/common/ChainedBuffer.h index 7b514d6b460..38ba8d7c269 100644 --- a/velox/dwio/common/ChainedBuffer.h +++ b/velox/dwio/common/ChainedBuffer.h @@ -20,10 +20,7 @@ #include "velox/common/base/GTestMacros.h" #include "velox/dwio/common/DataBuffer.h" -namespace facebook { -namespace velox { -namespace dwio { -namespace common { +namespace facebook::velox::dwio::common { namespace { @@ -228,7 +225,4 @@ class ChainedBuffer { VELOX_FRIEND_TEST(ChainedBufferTests, testClearAll); }; -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/ColumnLoader.cpp b/velox/dwio/common/ColumnLoader.cpp index 7d62008440b..c2d211d89e8 100644 --- a/velox/dwio/common/ColumnLoader.cpp +++ b/velox/dwio/common/ColumnLoader.cpp @@ -56,7 +56,7 @@ RowSet read( structReader->advanceFieldReader(fieldReader, offset); fieldReader->scanSpec()->setValueHook(hook); - fieldReader->read(offset, effectiveRows, incomingNulls); + fieldReader->readWithTiming(offset, effectiveRows, incomingNulls); if (fieldReader->fileType().type()->isRow() || fieldReader->scanSpec()->isFlatMapAsStruct()) { // 'fieldReader_' may itself produce LazyVectors. For this it must have its @@ -116,6 +116,19 @@ void ColumnLoader::loadInternal( } } +void TransformColumnLoader::loadInternal( + RowSet rows, + ValueHook* hook, + vector_size_t resultSize, + VectorPtr* result) { + process::TraceContext trace("TransformColumnLoader::loadInternal"); + VectorPtr fileResult; + ColumnLoader::loadInternal(rows, hook, resultSize, &fileResult); + if (fileResult) { + *result = transform_(fileResult, fieldReader_->memoryPool()); + } +} + void DeltaUpdateColumnLoader::loadInternal( RowSet rows, ValueHook* hook, diff --git a/velox/dwio/common/ColumnLoader.h b/velox/dwio/common/ColumnLoader.h index 50950b74e6e..07c0bca6fb2 100644 --- a/velox/dwio/common/ColumnLoader.h +++ b/velox/dwio/common/ColumnLoader.h @@ -16,7 +16,9 @@ #pragma once +#include "velox/dwio/common/ScanSpec.h" #include "velox/dwio/common/SelectiveStructColumnReader.h" +#include "velox/vector/LazyVector.h" namespace facebook::velox::dwio::common { @@ -30,11 +32,13 @@ class ColumnLoader : public VectorLoader { fieldReader_(fieldReader), version_(version) {} + virtual ~ColumnLoader() = default; + bool supportsHook() const override { return true; } - private: + protected: void loadInternal( RowSet rows, ValueHook* hook, @@ -49,6 +53,34 @@ class ColumnLoader : public VectorLoader { const uint64_t version_; }; +/// Wraps a ColumnLoader and applies a post-read transform when the lazy +/// vector is loaded. Used for mixed extraction transforms where the reader +/// produces the file type and the transform converts to the extraction +/// output type (e.g., MAP → ROW for MapKeys + Size extractions). +class TransformColumnLoader : public ColumnLoader { + public: + TransformColumnLoader( + SelectiveStructColumnReaderBase* structReader, + SelectiveColumnReader* fieldReader, + uint64_t version, + common::ScanSpec::VectorTransform transform) + : ColumnLoader(structReader, fieldReader, version), + transform_(std::move(transform)) {} + + bool supportsHook() const override { + return false; + } + + private: + void loadInternal( + RowSet rows, + ValueHook* hook, + vector_size_t resultSize, + VectorPtr* result) override; + + common::ScanSpec::VectorTransform transform_; +}; + class DeltaUpdateColumnLoader : public VectorLoader { public: DeltaUpdateColumnLoader( diff --git a/velox/dwio/common/ColumnSelector.h b/velox/dwio/common/ColumnSelector.h index 62408e55a21..9bef521d99a 100644 --- a/velox/dwio/common/ColumnSelector.h +++ b/velox/dwio/common/ColumnSelector.h @@ -388,8 +388,10 @@ class ColumnSelector { // expect a runtime_error rather than fault. // Do-Not change the message as expected by client in failure case if (!notFound.empty()) { - throw std::runtime_error(folly::to( - "Columns not found in hive table: ", folly::join(", ", notFound))); + throw std::runtime_error( + folly::to( + "Columns not found in hive table: ", + folly::join(", ", notFound))); } } diff --git a/velox/dwio/common/ColumnVisitors.h b/velox/dwio/common/ColumnVisitors.h index e4a507370bc..33d7a2d6e32 100644 --- a/velox/dwio/common/ColumnVisitors.h +++ b/velox/dwio/common/ColumnVisitors.h @@ -491,7 +491,7 @@ class ColumnVisitor { protected: const TFilter& filter_; - SelectiveColumnReader* reader_; + SelectiveColumnReader* const reader_; const bool allowNulls_; const vector_size_t* rows_; vector_size_t numRows_; @@ -928,8 +928,9 @@ class DictionaryColumnVisitor (simd::reinterpretBatch(cache) & xsimd::batch(1)) != xsimd::batch(0)); #else - auto unknowns = simd::toBitMask(xsimd::batch_bool( - simd::reinterpretBatch((cache & (kUnknown << 24)) << 1))); + auto unknowns = simd::toBitMask( + xsimd::batch_bool(simd::reinterpretBatch( + (cache & (kUnknown << 24)) << 1))); auto passed = simd::toBitMask( xsimd::batch_bool(simd::reinterpretBatch(cache))); #endif @@ -1175,7 +1176,7 @@ ColumnVisitor:: } auto result = DictionaryColumnVisitor( filter_, reader_, RowSet(rows_ + rowIndex_, numRows_), values_); - result.numValuesBias_ = numValuesBias_; + result.setNumValuesBias(numValuesBias_); return result; } @@ -1227,7 +1228,14 @@ class StringDictionaryColumnVisitor vector_size_t previous = isDense && TFilter::deterministic ? 0 : super::currentRow(); if constexpr (!DictSuper::hasFilter()) { - super::filterPassed(index); + // Hooks are LazyVector-level and have no access to the dictionary. + // They must receive decoded StringView values, not dictionary indices. + if constexpr (super::kHasHook) { + super::values_.addValue( + super::rowIndex_ + super::numValuesBias_, valueInDictionary(index)); + } else { + super::filterPassed(index); + } } else { // check the dictionary cache if (TFilter::deterministic && @@ -1329,8 +1337,9 @@ class StringDictionaryColumnVisitor (simd::reinterpretBatch(cache) & xsimd::batch(1)) != xsimd::batch(0)); #else - auto unknowns = simd::toBitMask(xsimd::batch_bool( - simd::reinterpretBatch((cache & (kUnknown << 24)) << 1))); + auto unknowns = simd::toBitMask( + xsimd::batch_bool(simd::reinterpretBatch( + (cache & (kUnknown << 24)) << 1))); auto passed = simd::toBitMask( xsimd::batch_bool(simd::reinterpretBatch(cache))); #endif @@ -1419,7 +1428,7 @@ class StringDictionaryColumnVisitor } } - folly::StringPiece valueInDictionary(int64_t index) { + StringView valueInDictionary(int64_t index) { auto stripeDictSize = DictSuper::state_.dictionary.numValues; if (index < stripeDictSize) { return reinterpret_cast( @@ -1576,7 +1585,7 @@ class StringColumnReadWithVisitorHelper { ExtractValues extractValues, F readWithVisitor) { readWithVisitor( - ColumnVisitor( + ColumnVisitor( *static_cast(filter), &reader_, rows_, diff --git a/velox/dwio/common/DataBuffer.h b/velox/dwio/common/DataBuffer.h index ba84cb45a21..1f2964c1a61 100644 --- a/velox/dwio/common/DataBuffer.h +++ b/velox/dwio/common/DataBuffer.h @@ -23,10 +23,7 @@ #include "velox/common/memory/Memory.h" #include "velox/dwio/common/exception/Exception.h" -namespace facebook { -namespace velox { -namespace dwio { -namespace common { +namespace facebook::velox::dwio::common { template >> class DataBuffer { @@ -34,8 +31,9 @@ class DataBuffer { explicit DataBuffer(velox::memory::MemoryPool& pool, uint64_t size = 0) : pool_(&pool), // Initial allocation uses calloc, to avoid memset. - buf_(reinterpret_cast( - pool_->allocateZeroFilled(1, sizeInBytes(size)))), + buf_( + reinterpret_cast( + pool_->allocateZeroFilled(1, sizeInBytes(size)))), size_(size), capacity_(size) { VELOX_CHECK(buf_ != nullptr || size_ == 0); @@ -233,7 +231,5 @@ class DataBuffer { // Maximum capacity of items of type T. uint64_t capacity_; }; -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook + +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/DataBufferHolder.cpp b/velox/dwio/common/DataBufferHolder.cpp index f6ff2757a35..0f5b246f062 100644 --- a/velox/dwio/common/DataBufferHolder.cpp +++ b/velox/dwio/common/DataBufferHolder.cpp @@ -21,7 +21,7 @@ using facebook::velox::common::testutil::TestValue; namespace facebook::velox::dwio::common { -void DataBufferHolder::take(const std::vector& buffers) { +void DataBufferHolder::take(const std::vector& buffers) { // compute size uint64_t totalSize = 0; for (auto& buf : buffers) { @@ -38,7 +38,7 @@ void DataBufferHolder::take(const std::vector& buffers) { auto* data = buf.data(); for (auto& buffer : buffers) { const auto size = buffer.size(); - ::memcpy(data, buffer.begin(), size); + ::memcpy(data, buffer.cbegin(), size); data += size; } // If possibly, write content of the data to output immediately. Otherwise, diff --git a/velox/dwio/common/DataBufferHolder.h b/velox/dwio/common/DataBufferHolder.h index 1c291f13328..91f4fe56007 100644 --- a/velox/dwio/common/DataBufferHolder.h +++ b/velox/dwio/common/DataBufferHolder.h @@ -47,14 +47,14 @@ class DataBufferHolder { /// Takes content of the incoming data buffer. It is the caller's /// responsibility to resize the buffer (if required). - void take(const std::vector& buffers); + void take(const std::vector& buffers); - void take(folly::StringPiece buffer) { - take(std::vector{buffer}); + void take(std::string_view buffer) { + take(std::vector{buffer}); } void take(const dwio::common::DataBuffer& buffer) { - take(folly::StringPiece{buffer.data(), buffer.size()}); + take(std::string_view{buffer.data(), buffer.size()}); } std::vector>& getBuffers() { diff --git a/velox/dwio/common/DecoderUtil.h b/velox/dwio/common/DecoderUtil.h index 9527915e099..b572058d9ee 100644 --- a/velox/dwio/common/DecoderUtil.h +++ b/velox/dwio/common/DecoderUtil.h @@ -202,11 +202,20 @@ void fixedWidthScan( int32_t numRowsInBuffer, int32_t rowOffset, const T* buffer) { - rowLoop( + if constexpr (!hasFilter && !hasHook && !scatter) { + if (isDense(&rows[rowIndex], numRowsInBuffer)) { + std::memcpy( + rawValues + numValues, + buffer + rows[rowIndex] - rowOffset, + sizeof(T) * numRowsInBuffer); + numValues += numRowsInBuffer; + return; + } + } + rowLoop( rows, rowIndex, rowIndex + numRowsInBuffer, - kStep, [&](int32_t rowIndex) { auto firstRow = rows[rowIndex]; if (!hasFilter) { @@ -223,7 +232,7 @@ void fixedWidthScan( kStep, rawValues); } else { - simd::memcpy( + FOLLY_BUILTIN_MEMCPY( rawValues + numValues, buffer + firstRow - rowOffset, sizeof(T) * kStep); diff --git a/velox/dwio/common/DirectBufferedInput.cpp b/velox/dwio/common/DirectBufferedInput.cpp index 1b5e4faafa5..58076fcb512 100644 --- a/velox/dwio/common/DirectBufferedInput.cpp +++ b/velox/dwio/common/DirectBufferedInput.cpp @@ -49,13 +49,12 @@ std::unique_ptr DirectBufferedInput::enqueue( id = TrackingId(sid->getId()); } VELOX_CHECK_LE(region.offset + region.length, fileSize_); - requests_.emplace_back(region, id); if (tracker_) { tracker_->recordReference(id, region.length, fileNum_.id(), groupId_.id()); } auto stream = std::make_unique( this, - ioStats_.get(), + ioStatistics_.get(), region, input_, fileNum_.id(), @@ -63,13 +62,21 @@ std::unique_ptr DirectBufferedInput::enqueue( id, groupId_.id(), options_.loadQuantum()); - requests_.back().stream = stream.get(); + if (!preloaded()) { + // Only track requests when not preloaded. Preloaded streams serve data + // directly from preloadData_ without going through coalesced loads. + requests_.emplace_back(region, id); + requests_.back().stream = stream.get(); + } return stream; } bool DirectBufferedInput::isBuffered(uint64_t /*offset*/, uint64_t /*length*/) const { - return false; + // When preloaded, the entire file content is already in memory, so any + // region within the file is considered buffered and can be served without + // additional I/O. + return preloaded(); } bool DirectBufferedInput::shouldPreload(int32_t numPages) { @@ -187,8 +194,8 @@ void DirectBufferedInput::readRegion( } auto load = std::make_shared( input_, + ioStatistics_, ioStats_, - fsStats_, groupId_.id(), requests, pool_, @@ -235,7 +242,7 @@ std::shared_ptr DirectBufferedInput::coalescedLoad( return streamToCoalescedLoad_.withWLock( [&](auto& loads) -> std::shared_ptr { auto it = loads.find(stream); - if (it == loads.end()) { + if (it == loads.cend()) { return nullptr; } auto load = std::move(it->second); @@ -244,39 +251,105 @@ std::shared_ptr DirectBufferedInput::coalescedLoad( }); } -std::unique_ptr DirectBufferedInput::read( - uint64_t offset, - uint64_t length, - LogType /*logType*/) const { - VELOX_CHECK_LE(offset + length, fileSize_); - return std::make_unique( - const_cast(this), - ioStats_.get(), - Region{offset, length}, - input_, - fileNum_.id(), - nullptr, - TrackingId(), - 0, - options_.loadQuantum()); +void DirectBufferedInput::reset() { + BufferedInput::reset(); + for (auto& load : coalescedLoads_) { + load->cancel(); + } + coalescedLoads_.clear(); + streamToCoalescedLoad_.wlock()->clear(); + requests_.clear(); } namespace { void appendRanges( - memory::Allocation& allocation, + const memory::Allocation& allocation, size_t length, std::vector>& buffers) { + VELOX_CHECK_LE( + length, + memory::AllocationTraits::pageBytes(allocation.numPages()), + "Length exceeds allocation size"); + buffers.reserve(buffers.size() + allocation.numRuns()); uint64_t offsetInRuns = 0; for (int i = 0; i < allocation.numRuns(); ++i) { + VELOX_CHECK_GE(length, offsetInRuns); auto run = allocation.runAt(i); const uint64_t bytes = memory::AllocationTraits::pageBytes(run.numPages()); const uint64_t readSize = std::min(bytes, length - offsetInRuns); - buffers.push_back(folly::Range(run.data(), readSize)); + buffers.emplace_back(run.data(), readSize); offsetInRuns += readSize; + if (offsetInRuns >= length) { + break; + } } } } // namespace +void DirectBufferedInput::preload() { + VELOX_CHECK(!preloadData_.has_value(), "preload() called more than once"); + VELOX_CHECK(requests_.empty(), "preload() must be called before enqueue()"); + preloadData_.emplace(); + preloadData_->size = fileSize_; + uint64_t storageReadUs{0}; + { + MicrosecondTimer timer(&storageReadUs); + if (fileSize_ <= kTinySize) { + preloadData_->tinyData.resize(fileSize_); + input_->read(preloadData_->tinyData.data(), fileSize_, 0, LogType::FILE); + } else { + const auto numPages = memory::AllocationTraits::numPages(fileSize_); + pool_->allocateNonContiguous(numPages, preloadData_->data); + std::vector> buffers; + appendRanges(preloadData_->data, fileSize_, buffers); + input_->read(buffers, 0, LogType::FILE); + } + } + ioStatistics_->read().increment(fileSize_); + ioStatistics_->incRawBytesRead(fileSize_); + ioStatistics_->queryThreadIoLatencyUs().increment(storageReadUs); + ioStatistics_->storageReadLatencyUs().increment(storageReadUs); + ioStatistics_->incTotalScanTimeNs(storageReadUs * 1'000); +} + +folly::Range DirectBufferedInput::preloadedData( + uint64_t offset, + uint64_t length) const { + VELOX_CHECK( + preloadData_.has_value(), "preloadedData() called without preload"); + VELOX_CHECK_LT(offset, preloadData_->size, "Offset exceeds preloaded size"); + const auto available = + std::min(length, preloadData_->size - offset); + if (preloadData_->data.numPages() == 0) { + return {preloadData_->tinyData.data() + offset, available}; + } + int32_t runIndex; + int32_t offsetInRun; + preloadData_->data.findRun(offset, &runIndex, &offsetInRun); + const auto run = preloadData_->data.runAt(runIndex); + const auto runBytes = memory::AllocationTraits::pageBytes(run.numPages()); + const auto contiguousBytes = + std::min(available, runBytes - offsetInRun); + return {run.data() + offsetInRun, contiguousBytes}; +} + +std::unique_ptr DirectBufferedInput::read( + uint64_t offset, + uint64_t length, + LogType /*logType*/) const { + VELOX_CHECK_LE(offset + length, fileSize_); + return std::make_unique( + const_cast(this), + ioStatistics_.get(), + Region{offset, length}, + input_, + fileNum_.id(), + nullptr, + TrackingId(), + 0, + options_.loadQuantum()); +} + std::vector DirectCoalescedLoad::loadData(bool prefetch) { std::vector> buffers; int64_t lastEnd = requests_[0].region.offset; @@ -286,10 +359,11 @@ std::vector DirectCoalescedLoad::loadData(bool prefetch) { for (auto& request : requests_) { const auto& region = request.region; if (region.offset > lastEnd) { - buffers.push_back(folly::Range( - nullptr, - reinterpret_cast( - static_cast(region.offset - lastEnd)))); + buffers.push_back( + folly::Range( + nullptr, + reinterpret_cast( + static_cast(region.offset - lastEnd)))); overread += buffers.back().size(); } @@ -321,13 +395,14 @@ std::vector DirectCoalescedLoad::loadData(bool prefetch) { input_->read(buffers, requests_[0].region.offset, LogType::FILE); } - ioStats_->read().increment(size + overread); - ioStats_->incRawBytesRead(size); - ioStats_->incTotalScanTime(usecs * 1'000); - ioStats_->queryThreadIoLatency().increment(usecs); - ioStats_->incRawOverreadBytes(overread); + ioStatistics_->read().increment(size + overread); + ioStatistics_->incRawBytesRead(size); + ioStatistics_->incTotalScanTimeNs(usecs * 1'000); + ioStatistics_->queryThreadIoLatencyUs().increment(usecs); + ioStatistics_->storageReadLatencyUs().increment(usecs); + ioStatistics_->incRawOverreadBytes(overread); if (prefetch) { - ioStats_->prefetch().increment(size + overread); + ioStatistics_->prefetch().increment(size + overread); } TestValue::adjust( "facebook::velox::cache::DirectCoalescedLoad::loadData", this); @@ -342,7 +417,7 @@ int32_t DirectCoalescedLoad::getData( requests_.begin(), requests_.end(), offset, [](auto& x, auto offset) { return x.region.offset < offset; }); - if (it == requests_.end() || it->region.offset != offset) { + if (it == requests_.cend() || it->region.offset != offset) { return 0; } data = std::move(it->data); diff --git a/velox/dwio/common/DirectBufferedInput.h b/velox/dwio/common/DirectBufferedInput.h index ea697c9e175..c23e104e92a 100644 --- a/velox/dwio/common/DirectBufferedInput.h +++ b/velox/dwio/common/DirectBufferedInput.h @@ -17,6 +17,7 @@ #pragma once #include +#include #include "velox/common/caching/AsyncDataCache.h" #include "velox/common/caching/FileGroupStats.h" @@ -57,23 +58,24 @@ class DirectCoalescedLoad : public cache::CoalescedLoad { public: DirectCoalescedLoad( std::shared_ptr input, - std::shared_ptr ioStats, - std::shared_ptr fsStats, + std::shared_ptr ioStatistics, + std::shared_ptr ioStats, uint64_t /* groupId */, const std::vector& requests, memory::MemoryPool* pool, int32_t loadQuantum) : CoalescedLoad({}, {}), + ioStatistics_(ioStatistics), ioStats_(ioStats), - fsStats_(fsStats), input_(std::move(input)), loadQuantum_(loadQuantum), pool_(pool) { VELOX_DCHECK_NOT_NULL(pool_); VELOX_DCHECK( - std::is_sorted(requests.begin(), requests.end(), [](auto* x, auto* y) { - return x->region.offset < y->region.offset; - })); + std::is_sorted( + requests.cbegin(), requests.cend(), [](auto* x, auto* y) { + return x->region.offset < y->region.offset; + })); requests_.reserve(requests.size()); for (auto i = 0; i < requests.size(); ++i) { requests_.push_back(std::move(*requests[i])); @@ -84,6 +86,12 @@ class DirectCoalescedLoad : public cache::CoalescedLoad { /// data is retrieved with getData(). std::vector loadData(bool prefetch) override; + /// Returns false since DirectCoalescedLoad reads from remote storage, not + /// SSD. + bool isSsdLoad() const override { + return false; + } + /// Returns the buffer for 'region' in either 'data' or 'tinyData'. 'region' /// must match a region given to DirectBufferedInput::enqueue(). int32_t @@ -102,8 +110,8 @@ class DirectCoalescedLoad : public cache::CoalescedLoad { } private: - const std::shared_ptr ioStats_; - const std::shared_ptr fsStats_; + const std::shared_ptr ioStatistics_; + const std::shared_ptr ioStats_; const std::shared_ptr input_; const int32_t loadQuantum_; memory::MemoryPool* const pool_; @@ -120,21 +128,25 @@ class DirectBufferedInput : public BufferedInput { StringIdLease fileNum, std::shared_ptr tracker, StringIdLease groupId, - std::shared_ptr ioStats, - std::shared_ptr fsStats, + std::shared_ptr ioStatistics, + std::shared_ptr ioStats, folly::Executor* executor, - const io::ReaderOptions& readerOptions) + const io::ReaderOptions& readerOptions, + folly::F14FastMap fileReadOps = {}) : BufferedInput( std::move(readFile), readerOptions.memoryPool(), metricsLog, + ioStatistics.get(), ioStats.get(), - fsStats.get()), + kMaxMergeDistance, + std::nullopt, + std::move(fileReadOps)), fileNum_(std::move(fileNum)), tracker_(std::move(tracker)), groupId_(std::move(groupId)), + ioStatistics_(std::move(ioStatistics)), ioStats_(std::move(ioStats)), - fsStats_(std::move(fsStats)), executor_(executor), fileSize_(input_->getLength()), options_(readerOptions) {} @@ -149,8 +161,10 @@ class DirectBufferedInput : public BufferedInput { velox::common::Region region, const StreamIdentifier* sid) override; - bool supportSyncLoad() const override { - return false; + void preload() override; + + bool preloaded() const override { + return preloadData_.has_value(); } void load(const LogType /*unused*/) override; @@ -176,8 +190,8 @@ class DirectBufferedInput : public BufferedInput { fileNum_, tracker_, groupId_, + ioStatistics_, ioStats_, - fsStats_, executor_, options_)); } @@ -186,6 +200,12 @@ class DirectBufferedInput : public BufferedInput { return pool_; } + /// Returns the contiguous byte range of preloaded data at 'offset' in the + /// file, up to 'length' bytes. Caller must call preload() first and ensure + /// 'offset' < file size. + folly::Range preloadedData(uint64_t offset, uint64_t length) + const; + /// Returns the CoalescedLoad that contains the correlated loads for /// 'stream' or nullptr if none. Returns nullptr on all but first /// call for 'stream' since the load is to be triggered by the first @@ -200,31 +220,60 @@ class DirectBufferedInput : public BufferedInput { return executor_; } + const std::vector>& + testingCoalescedLoads() const { + return coalescedLoads_; + } + + size_t testingStreamToCoalescedLoadSize() const { + return streamToCoalescedLoad_.rlock()->size(); + } + uint64_t nextFetchSize() const override { VELOX_NYI(); } - private: - /// Constructor used by clone(). + /// Resets the buffered input for reuse across different operations. + void reset() override; + + protected: + // The constructor and some member variables are protected to allow custom + // extended buffered inputs. + // Constructor used by clone(). DirectBufferedInput( std::shared_ptr input, StringIdLease fileNum, std::shared_ptr tracker, StringIdLease groupId, - std::shared_ptr ioStats, - std::shared_ptr fsStats, + std::shared_ptr ioStatistics, + std::shared_ptr ioStats, folly::Executor* executor, const io::ReaderOptions& readerOptions) : BufferedInput(std::move(input), readerOptions.memoryPool()), fileNum_(std::move(fileNum)), tracker_(std::move(tracker)), groupId_(std::move(groupId)), + ioStatistics_(std::move(ioStatistics)), ioStats_(std::move(ioStats)), - fsStats_(std::move(fsStats)), executor_(executor), fileSize_(input_->getLength()), options_(readerOptions) {} + // Regions that are candidates for loading. + const StringIdLease fileNum_; + const std::shared_ptr tracker_; + const StringIdLease groupId_; + const std::shared_ptr ioStatistics_; + const std::shared_ptr ioStats_; + folly::Executor* const executor_; + const uint64_t fileSize_; + const io::ReaderOptions options_; + std::vector requests_; + + // Distinct coalesced loads in 'coalescedLoads_'. + std::vector> coalescedLoads_; + + private: std::vector groupRequests( const std::vector& requests, bool prefetch) const; @@ -261,27 +310,21 @@ class DirectBufferedInput : public BufferedInput { } }; - const StringIdLease fileNum_; - const std::shared_ptr tracker_; - const StringIdLease groupId_; - const std::shared_ptr ioStats_; - const std::shared_ptr fsStats_; - folly::Executor* const executor_; - const uint64_t fileSize_; - - // Regions that are candidates for loading. - std::vector requests_; - // Coalesced loads spanning multiple streams in one IO. folly::Synchronized>> streamToCoalescedLoad_; - // Distinct coalesced loads in 'coalescedLoads_'. - std::vector> coalescedLoads_; - - io::ReaderOptions options_; + // Preloaded file data read in a single IO by preload(). Exactly one of + // 'tinyData' or 'data' is populated: 'tinyData' when the file size <= + // kTinySize, 'data' (non-contiguous allocation) otherwise. + struct PreloadData { + memory::Allocation data; + std::string tinyData; + uint64_t size; + }; + std::optional preloadData_; }; } // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/DirectInputStream.cpp b/velox/dwio/common/DirectInputStream.cpp index 68173b40f25..19bfb5ef73f 100644 --- a/velox/dwio/common/DirectInputStream.cpp +++ b/velox/dwio/common/DirectInputStream.cpp @@ -91,8 +91,8 @@ bool DirectInputStream::SkipInt64(int64_t count) { return false; } -google::protobuf::int64 DirectInputStream::ByteCount() const { - return static_cast(offsetInRegion_); +int64_t DirectInputStream::ByteCount() const { + return static_cast(offsetInRegion_); } void DirectInputStream::seekToPosition(PositionProvider& seekPosition) { @@ -153,12 +153,25 @@ void DirectInputStream::loadSync() { input_->read(ranges, loadedRegion_.offset, LogType::FILE); } ioStats_->read().increment(loadedRegion_.length); - ioStats_->queryThreadIoLatency().increment(usecs); - ioStats_->incTotalScanTime(usecs * 1'000); + ioStats_->queryThreadIoLatencyUs().increment(usecs); + ioStats_->storageReadLatencyUs().increment(usecs); + ioStats_->incTotalScanTimeNs(usecs * 1'000); } void DirectInputStream::loadPosition() { VELOX_CHECK_LT(offsetInRegion_, region_.length); + + // Fast path: serve from preloaded whole-file data. + if (bufferedInput_->preloaded()) { + const auto range = bufferedInput_->preloadedData( + region_.offset + offsetInRegion_, region_.length - offsetInRegion_); + run_ = reinterpret_cast(const_cast(range.data())); + runSize_ = range.size(); + offsetInRun_ = 0; + offsetOfRun_ = 0; + return; + } + if (!loaded_) { loaded_ = true; auto load = bufferedInput_->coalescedLoad(this); @@ -173,7 +186,9 @@ void DirectInputStream::loadPosition() { loadedRegion_.offset = region_.offset; loadedRegion_.length = load->getData(region_.offset, data_, tinyData_); } - ioStats_->queryThreadIoLatency().increment(loadUs); + ioStats_->queryThreadIoLatencyUs().increment(loadUs); + // DirectCoalescedLoad always reads from remote storage, not SSD. + ioStats_->coalescedStorageLoadLatencyUs().increment(loadUs); } else { // Standalone stream, not part of coalesced load. loadedRegion_.offset = 0; diff --git a/velox/dwio/common/DirectInputStream.h b/velox/dwio/common/DirectInputStream.h index 3d75b445956..b5bc4440066 100644 --- a/velox/dwio/common/DirectInputStream.h +++ b/velox/dwio/common/DirectInputStream.h @@ -44,7 +44,7 @@ class DirectInputStream : public SeekableInputStream { bool Next(const void** data, int* size) override; void BackUp(int count) override; bool SkipInt64(int64_t count) override; - google::protobuf::int64 ByteCount() const override; + int64_t ByteCount() const override; void seekToPosition(PositionProvider& position) override; std::string getName() const override; @@ -91,8 +91,8 @@ class DirectInputStream : public SeekableInputStream { // Contains the data if the range is too small for Allocation. std::string tinyData_; - // Pointer to start of current run in 'entry->data()' or - // 'entry->tinyData()'. + // Pointer to start of current run in 'entry->nonContiguousData()' or + // 'entry->contiguousData()'. uint8_t* run_{nullptr}; // Offset of current run from start of 'data_' diff --git a/velox/dwio/common/ErrorTolerance.h b/velox/dwio/common/ErrorTolerance.h index db186084cea..4f2aafd6111 100644 --- a/velox/dwio/common/ErrorTolerance.h +++ b/velox/dwio/common/ErrorTolerance.h @@ -18,10 +18,7 @@ #include "velox/dwio/common/exception/Exception.h" -namespace facebook { -namespace velox { -namespace dwio { -namespace common { +namespace facebook::velox::dwio::common { /** * Error tolerance level for readers @@ -63,7 +60,4 @@ struct ErrorTolerance { } }; -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/FileMetadata.h b/velox/dwio/common/FileMetadata.h new file mode 100644 index 00000000000..1beb77ffdb2 --- /dev/null +++ b/velox/dwio/common/FileMetadata.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace facebook::velox::dwio::common { + +/// File format specific metadata returned when a writer is closed. +/// Caller of Writer::close() can do further processing such as aggregate +/// row group statistics to file level statistics based on the metadata. +class FileMetadata { + public: + virtual ~FileMetadata() = default; +}; + +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/FileSink.cpp b/velox/dwio/common/FileSink.cpp index 6654dfb6407..fe69778d851 100644 --- a/velox/dwio/common/FileSink.cpp +++ b/velox/dwio/common/FileSink.cpp @@ -111,10 +111,13 @@ WriteFileSink::WriteFileSink( std::unique_ptr writeFile, std::string name, MetricsLogPtr metricLogger, - IoStatistics* stats) + IoStatistics* stats, + velox::IoStats* fileSystemStats) : FileSink( std::move(name), - {.metricLogger = std::move(metricLogger), .stats = stats}), + {.metricLogger = std::move(metricLogger), + .stats = stats, + .fileSystemStats = fileSystemStats}), writeFile_{std::move(writeFile)} { VELOX_CHECK_NOT_NULL(writeFile_); } diff --git a/velox/dwio/common/FileSink.h b/velox/dwio/common/FileSink.h index 914d06055f1..74273f2b9da 100644 --- a/velox/dwio/common/FileSink.h +++ b/velox/dwio/common/FileSink.h @@ -46,7 +46,8 @@ class FileSink : public Closeable { memory::MemoryPool* pool{nullptr}; MetricsLogPtr metricLogger{MetricsLog::voidLog()}; IoStatistics* stats{nullptr}; - filesystems::File::IoStats* fileSystemStats{nullptr}; + velox::IoStats* fileSystemStats{nullptr}; + std::unordered_map storageParameters{}; }; FileSink(std::string name, const Options& options) @@ -105,6 +106,10 @@ class FileSink : public Closeable { return stats_; } + velox::IoStats* getFileSystemStats() { + return fileSystemStats_; + } + protected: // General write wrapper with logging. All concrete subclasses gets logging // for free if they call a public method that goes through this method. @@ -119,7 +124,7 @@ class FileSink : public Closeable { memory::MemoryPool* const pool_; const MetricsLogPtr metricLogger_; IoStatistics* const stats_; - filesystems::File::IoStats* const fileSystemStats_; + velox::IoStats* const fileSystemStats_; uint64_t size_; }; @@ -131,7 +136,8 @@ class WriteFileSink final : public FileSink { std::unique_ptr writeFile, std::string name, MetricsLogPtr metricLogger = MetricsLog::voidLog(), - IoStatistics* stats = nullptr); + IoStatistics* stats = nullptr, + velox::IoStats* fileSystemStats = nullptr); ~WriteFileSink() override { destroy(); diff --git a/velox/dwio/common/FlatMapHelper.cpp b/velox/dwio/common/FlatMapHelper.cpp index 5d48887af47..e392bf60a65 100644 --- a/velox/dwio/common/FlatMapHelper.cpp +++ b/velox/dwio/common/FlatMapHelper.cpp @@ -21,11 +21,20 @@ namespace facebook::velox::dwio::common::flatmap { namespace detail { -void reset(VectorPtr& vector, vector_size_t size, bool hasNulls) { +void reset( + VectorPtr& vector, + VectorEncoding::Simple desiredEncoding, + vector_size_t size, + bool hasNulls) { if (!vector) { return; } + if (vector->encoding() != desiredEncoding) { + vector.reset(); + return; + } + if (vector.use_count() > 1) { vector.reset(); return; @@ -39,6 +48,9 @@ void reset(VectorPtr& vector, vector_size_t size, bool hasNulls) { } } vector->resize(size); + // Reside BaseVector::length_ as it will be updated in the subsequent copy() + // calls. + vector->BaseVector::resize(0); } void initializeStringVector( @@ -162,7 +174,7 @@ void initializeVectorImpl( } } - detail::reset(vector, size, hasNulls); + detail::reset(vector, VectorEncoding::Simple::ARRAY, size, hasNulls); VectorPtr origElementsVector; if (vector) { auto& arrayVector = dynamic_cast(*vector); @@ -226,7 +238,7 @@ void initializeMapVector( size = sizeOverride.value(); } - detail::reset(vector, size, hasNulls); + detail::reset(vector, VectorEncoding::Simple::MAP, size, hasNulls); VectorPtr origKeysVector; VectorPtr origValuesVector; if (vector) { @@ -298,7 +310,7 @@ void initializeVectorImpl( } } - detail::reset(vector, size, hasNulls); + detail::reset(vector, VectorEncoding::Simple::ROW, size, hasNulls); std::vector origChildren; if (vector) { auto& rowVector = dynamic_cast(*vector); @@ -359,8 +371,9 @@ vector_size_t copyNulls( vector_size_t nulls = 0; // it's assumed that initVector is called before calling this method to // properly allocate/clear nulls buffer. So we only need to check against - // target vector here. - target.resize(targetIndex + count, false); + // target vector here. We only call BaseVector::resize here to make sure + // BaseVector::size() is only updated. + target.BaseVector::resize(targetIndex + count, false); if (target.mayHaveNulls()) { auto tgtNulls = const_cast(target.rawNulls()); if (source.isConstantEncoding()) { @@ -485,7 +498,10 @@ vector_size_t copyOffsets( vector_size_t sourceIndex, vector_size_t count, vector_size_t& childOffset) { - target.resize(targetIndex + count); + // Its expected that initVector is called before calling this method so the + // offsets and sizes buffers are properly allocated. We only call + // BaseVector::resize here to make sure BaseVector::size() is only updated. + target.BaseVector::resize(targetIndex + count); auto tgtOffsets = const_cast(target.rawOffsets()); auto tgtSizes = const_cast(target.rawSizes()); auto srcSizes = source.rawSizes(); @@ -701,8 +717,9 @@ bool copyNull( vector_size_t sourceIndex) { // it's assumed that initVector is called before calling this method to // properly allocate/clear nulls buffer. So we only need to check against - // target vector here. - target.resize(targetIndex + 1, false); + // target vector here. We only call BaseVector::resize here to make sure + // BaseVector::size() is only updated. + target.BaseVector::resize(targetIndex + 1, false); if (target.mayHaveNulls()) { bool srcIsNull = (source.isConstantEncoding() || @@ -792,7 +809,10 @@ vector_size_t copyOffset( const T& source, vector_size_t sourceIndex, vector_size_t& childOffset) { - target.resize(targetIndex + 1); + // Its expected that initVector is called before calling this method so the + // offsets and sizes buffers are properly allocated. We only call + // BaseVector::resize here to make sure BaseVector::size() is only updated. + target.BaseVector::resize(targetIndex + 1); auto tgtSizes = const_cast(target.rawSizes()); childOffset = nextChildOffset(target, targetIndex); const_cast(target.rawOffsets())[targetIndex] = childOffset; diff --git a/velox/dwio/common/FlatMapHelper.h b/velox/dwio/common/FlatMapHelper.h index 05212cdd44e..404e7891a08 100644 --- a/velox/dwio/common/FlatMapHelper.h +++ b/velox/dwio/common/FlatMapHelper.h @@ -24,7 +24,11 @@ namespace facebook::velox::dwio::common::flatmap { namespace detail { // Reset vector with the desired size/hasNulls properties -void reset(VectorPtr& vector, vector_size_t size, bool hasNulls); +void reset( + VectorPtr& vector, + VectorEncoding::Simple desiredEncoding, + vector_size_t size, + bool hasNulls); // Reset vector smart pointer if any of the buffers is not single referenced. template @@ -63,7 +67,7 @@ void initializeFlatVector( vector_size_t size, bool hasNulls, std::vector&& stringBuffers = {}) { - detail::reset(vector, size, hasNulls); + detail::reset(vector, VectorEncoding::Simple::FLAT, size, hasNulls); if (vector) { auto& flatVector = dynamic_cast&>(*vector); detail::resetIfNotWritable(vector, flatVector.nulls(), flatVector.values()); @@ -232,14 +236,14 @@ KeyPredicate prepareKeyPredicate(std::string_view expression) { // You cannot mix allow key and reject key. VELOX_CHECK( modes.empty() || - std::all_of(modes.begin(), modes.end(), [&modes](const auto& v) { + std::all_of(modes.cbegin(), modes.cend(), [&modes](const auto& v) { return v == modes.front(); })); auto mode = modes.empty() ? KeyProjectionMode::ALLOW : modes.front(); return KeyPredicate( - mode, typename KeyPredicate::Lookup(keys.begin(), keys.end())); + mode, typename KeyPredicate::Lookup(keys.cbegin(), keys.cend())); } } // namespace facebook::velox::dwio::common::flatmap diff --git a/velox/dwio/common/FormatData.h b/velox/dwio/common/FormatData.h index 4e8a5548fb0..2a1a228251d 100644 --- a/velox/dwio/common/FormatData.h +++ b/velox/dwio/common/FormatData.h @@ -126,13 +126,20 @@ class FormatData { virtual bool parentNullsInLeaves() const { return false; } + + bool stringDecoderZeroCopy() const { + return stringDecoderZeroCopy_; + } + + protected: + bool stringDecoderZeroCopy_{false}; }; /// Base class for format-specific reader initialization arguments. class FormatParams { public: - explicit FormatParams(memory::MemoryPool& pool, ColumnReaderStatistics& stats) - : pool_(pool), stats_(stats) {} + FormatParams(memory::MemoryPool& pool, ColumnReaderStatistics& stats) + : pool_(&pool), stats_(&stats) {} virtual ~FormatParams() = default; @@ -143,16 +150,16 @@ class FormatParams { const velox::common::ScanSpec& scanSpec) = 0; memory::MemoryPool& pool() { - return pool_; + return *pool_; } ColumnReaderStatistics& runtimeStatistics() { - return stats_; + return *stats_; } private: - memory::MemoryPool& pool_; - ColumnReaderStatistics& stats_; + memory::MemoryPool* const pool_; + ColumnReaderStatistics* const stats_; }; } // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/InputStream.cpp b/velox/dwio/common/InputStream.cpp index cc20a5fc55e..fec1982a844 100644 --- a/velox/dwio/common/InputStream.cpp +++ b/velox/dwio/common/InputStream.cpp @@ -64,26 +64,29 @@ ReadFileInputStream::ReadFileInputStream( std::shared_ptr readFile, const MetricsLogPtr& metricsLog, IoStatistics* stats, - filesystems::File::IoStats* fsStats) - : InputStream(readFile->getName(), metricsLog, stats, fsStats), + velox::IoStats* ioStats, + folly::F14FastMap fileOpts, + bool cacheable) + : InputStream(readFile->getName(), metricsLog, stats, ioStats), + fileIoContext_(ioStats, std::move(fileOpts), nullptr, cacheable), readFile_(std::move(readFile)) {} void ReadFileInputStream::read( void* buf, uint64_t length, uint64_t offset, - MetricsLog::MetricsType purpose) { + MetricsLog::Type purpose) { VELOX_CHECK_NOT_NULL(buf); logRead(offset, length, purpose); uint64_t readTimeUs{0}; std::string_view readData; { MicrosecondTimer timer(&readTimeUs); - readData = readFile_->pread(offset, length, buf, fsStats_); + readData = readFile_->pread(offset, length, buf, fileIoContext_); } if (stats_) { stats_->incRawBytesRead(length); - stats_->incTotalScanTime(readTimeUs * 1'000); + stats_->incTotalScanTimeNs(readTimeUs * 1'000); } VELOX_CHECK_EQ( @@ -102,7 +105,7 @@ void ReadFileInputStream::read( LogType logType) { const int64_t bufferSize = totalBufferSize(buffers); logRead(offset, bufferSize, logType); - const auto size = readFile_->preadv(offset, buffers, fsStats_); + const auto size = readFile_->preadv(offset, buffers, fileIoContext_); VELOX_CHECK_EQ( size, bufferSize, @@ -119,7 +122,7 @@ folly::SemiFuture ReadFileInputStream::readAsync( LogType logType) { const int64_t bufferSize = totalBufferSize(buffers); logRead(offset, bufferSize, logType); - return readFile_->preadvAsync(offset, buffers, fsStats_); + return readFile_->preadvAsync(offset, buffers, fileIoContext_); } bool ReadFileInputStream::hasReadAsync() const { @@ -137,11 +140,12 @@ void ReadFileInputStream::vread( size_t(0), [&](size_t acc, const auto& r) { return acc + r.length; }); logRead(regions[0].offset, length, purpose); - auto readStartMicros = getCurrentTimeMicro(); - readFile_->preadv(regions, iobufs, fsStats_); + const auto readStartTimeUs = getCurrentTimeMicro(); + readFile_->preadv(regions, iobufs, fileIoContext_); if (stats_) { stats_->incRawBytesRead(length); - stats_->incTotalScanTime((getCurrentTimeMicro() - readStartMicros) * 1000); + stats_->incTotalScanTimeNs( + (getCurrentTimeMicro() - readStartTimeUs) * 1'000); } } diff --git a/velox/dwio/common/InputStream.h b/velox/dwio/common/InputStream.h index b0b6deb2c1f..a1c4f55edea 100644 --- a/velox/dwio/common/InputStream.h +++ b/velox/dwio/common/InputStream.h @@ -26,6 +26,7 @@ #include #include +#include #include "velox/common/file/File.h" #include "velox/common/file/Region.h" #include "velox/common/io/IoStatistics.h" @@ -43,11 +44,11 @@ class InputStream { const std::string& path, const MetricsLogPtr& metricsLog = MetricsLog::voidLog(), IoStatistics* stats = nullptr, - filesystems::File::IoStats* fsStats = nullptr) + velox::IoStats* ioStats = nullptr) : path_{path}, metricsLog_{metricsLog}, stats_(stats), - fsStats_(fsStats) {} + ioStats_(ioStats) {} virtual ~InputStream() = default; @@ -132,7 +133,7 @@ class InputStream { std::string path_; MetricsLogPtr metricsLog_; IoStatistics* stats_; - filesystems::File::IoStats* fsStats_; + velox::IoStats* ioStats_; }; /// An input stream that reads from an already opened ReadFile. @@ -143,7 +144,9 @@ class ReadFileInputStream final : public InputStream { std::shared_ptr, const MetricsLogPtr& metricsLog = MetricsLog::voidLog(), IoStatistics* stats = nullptr, - filesystems::File::IoStats* fsStats = nullptr); + velox::IoStats* ioStats = nullptr, + folly::F14FastMap fileReadOps = {}, + bool cacheable = false); ~ReadFileInputStream() override = default; @@ -179,6 +182,7 @@ class ReadFileInputStream final : public InputStream { } private: + FileIoContext fileIoContext_; std::shared_ptr readFile_; }; diff --git a/velox/dwio/common/IntDecoder.h b/velox/dwio/common/IntDecoder.h index 79f7bb8d89e..b106548a68d 100644 --- a/velox/dwio/common/IntDecoder.h +++ b/velox/dwio/common/IntDecoder.h @@ -225,13 +225,18 @@ FOLLY_ALWAYS_INLINE uint64_t IntDecoder::readVuLong() { if (LIKELY(bufferEnd_ - bufferStart_ >= folly::kMaxVarintLength64)) { const char* p = bufferStart_; uint64_t val; + + // Fast path for 1-byte varints (values 0-127), which are very common. + // This avoids the do-while loop overhead for the most frequent case. + int64_t b = *p++; + val = (b & 0x7f); + if (b >= 0) { + bufferStart_ = p; + return val; + } + + // Multi-byte varint path do { - int64_t b; - b = *p++; - val = (b & 0x7f); - if (UNLIKELY(b >= 0)) { - break; - } b = *p++; val |= (b & 0x7f) << 7; if (UNLIKELY(b >= 0)) { @@ -292,8 +297,8 @@ FOLLY_ALWAYS_INLINE uint64_t IntDecoder::readVuLong() { return val; } - int64_t result = 0; - int64_t offset = 0; + uint64_t result = 0; + uint64_t offset = 0; signed char ch; do { ch = readByte(); diff --git a/velox/dwio/common/MeasureTime.h b/velox/dwio/common/MeasureTime.h index c4eaf1281f3..1b7de305a03 100644 --- a/velox/dwio/common/MeasureTime.h +++ b/velox/dwio/common/MeasureTime.h @@ -20,10 +20,7 @@ #include #include -namespace facebook { -namespace velox { -namespace dwio { -namespace common { +namespace facebook::velox::dwio::common { class MeasureTime { public: @@ -61,7 +58,4 @@ inline std::optional measureTimeIfCallback( return std::nullopt; } -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/MetadataFilter.cpp b/velox/dwio/common/MetadataFilter.cpp index 374a2d86108..88e3a5481e2 100644 --- a/velox/dwio/common/MetadataFilter.cpp +++ b/velox/dwio/common/MetadataFilter.cpp @@ -18,6 +18,7 @@ #include #include "velox/dwio/common/ScanSpec.h" +#include "velox/expression/ExprConstants.h" #include "velox/expression/ExprToSubfieldFilter.h" namespace facebook::velox::common { @@ -67,86 +68,118 @@ class MetadataFilter::LeafNode : public Node { std::unique_ptr filter_; }; -struct MetadataFilter::AndNode : Node { +struct MetadataFilter::ConditionNode : Node { static std::unique_ptr create( - std::unique_ptr lhs, - std::unique_ptr rhs) { - if (!lhs) { - return rhs; - } - if (!rhs) { - return lhs; + bool conjuction, + std::vector> args); + + static std::unique_ptr fromExpression( + const std::vector& inputs, + core::ExpressionEvaluator* evaluator, + bool conjunction, + bool negated) { + conjunction = negated ? !conjunction : conjunction; + std::vector> args; + args.reserve(inputs.size()); + for (const auto& input : inputs) { + auto node = Node::fromExpression(*input, evaluator, negated); + if (node) { + args.push_back(std::move(node)); + } else if (!conjunction) { + return nullptr; + } } - return std::make_unique(std::move(lhs), std::move(rhs)); + return create(conjunction, std::move(args)); } - AndNode(std::unique_ptr lhs, std::unique_ptr rhs) - : lhs_(std::move(lhs)), rhs_(std::move(rhs)) {} + explicit ConditionNode(std::vector> args) + : args_{std::move(args)} {} - void addToScanSpec(ScanSpec& scanSpec) const override { - lhs_->addToScanSpec(scanSpec); - rhs_->addToScanSpec(scanSpec); - } - - uint64_t* eval(LeafResults& leafResults, int size) const override { - auto* l = lhs_->eval(leafResults, size); - auto* r = rhs_->eval(leafResults, size); - if (!l) { - return r; - } - if (!r) { - return l; + void addToScanSpec(ScanSpec& scanSpec) const final { + for (const auto& arg : args_) { + arg->addToScanSpec(scanSpec); } - bits::orBits(l, r, 0, size); - return l; } - std::string toString() const override { - return "and(" + lhs_->toString() + "," + rhs_->toString() + ")"; + protected: + std::string ToStringImpl(std::string_view prefix) const { + std::string result{prefix}; + for (size_t i = 0; i < args_.size(); ++i) { + if (i != 0) { + result += ","; + } + result += args_[i]->toString(); + } + result += ")"; + return result; } - private: - std::unique_ptr lhs_; - std::unique_ptr rhs_; + std::vector> args_; }; -struct MetadataFilter::OrNode : Node { - static std::unique_ptr create( - std::unique_ptr lhs, - std::unique_ptr rhs) { - if (!lhs || !rhs) { - return nullptr; +struct MetadataFilter::AndNode final : ConditionNode { + using ConditionNode::ConditionNode; + + uint64_t* eval(LeafResults& leafResults, int size) const final { + uint64_t* result = nullptr; + for (const auto& arg : args_) { + auto* a = arg->eval(leafResults, size); + if (!a) { + continue; + } + if (!result) { + result = a; + } else { + bits::orBits(result, a, 0, size); + } } - return std::make_unique(std::move(lhs), std::move(rhs)); + return result; } - OrNode(std::unique_ptr lhs, std::unique_ptr rhs) - : lhs_(std::move(lhs)), rhs_(std::move(rhs)) {} - - void addToScanSpec(ScanSpec& scanSpec) const override { - lhs_->addToScanSpec(scanSpec); - rhs_->addToScanSpec(scanSpec); + std::string toString() const final { + return ToStringImpl("and("); } +}; - uint64_t* eval(LeafResults& leafResults, int size) const override { - auto* l = lhs_->eval(leafResults, size); - auto* r = rhs_->eval(leafResults, size); - if (!l || !r) { - return nullptr; +struct MetadataFilter::OrNode final : ConditionNode { + using ConditionNode::ConditionNode; + + uint64_t* eval(LeafResults& leafResults, int size) const final { + uint64_t* result = nullptr; + for (const auto& arg : args_) { + auto* a = arg->eval(leafResults, size); + if (!a) { + return nullptr; + } + if (!result) { + result = a; + } else { + bits::andBits(result, a, 0, size); + } } - bits::andBits(l, r, 0, size); - return l; + return result; } - std::string toString() const override { - return "or(" + lhs_->toString() + "," + rhs_->toString() + ")"; + std::string toString() const final { + return ToStringImpl("or("); } - - private: - std::unique_ptr lhs_; - std::unique_ptr rhs_; }; +std::unique_ptr MetadataFilter::ConditionNode::create( + bool conjunction, + std::vector> args) { + if (args.empty()) { + return nullptr; + } + if (args.size() == 1) { + return std::move(args[0]); + } + if (conjunction) { + return std::make_unique(std::move(args)); + } + return std::make_unique(std::move(args)); +} + namespace { const core::CallTypedExpr* asCall(const core::ITypedExpr* expr) { @@ -163,29 +196,26 @@ std::unique_ptr MetadataFilter::Node::fromExpression( if (!call) { return nullptr; } - if (call->name() == "and") { - auto lhs = fromExpression(*call->inputs()[0], evaluator, negated); - auto rhs = fromExpression(*call->inputs()[1], evaluator, negated); - return negated ? OrNode::create(std::move(lhs), std::move(rhs)) - : AndNode::create(std::move(lhs), std::move(rhs)); + if (call->name() == expression::kAnd) { + return ConditionNode::fromExpression( + call->inputs(), evaluator, true, negated); } - if (call->name() == "or") { - auto lhs = fromExpression(*call->inputs()[0], evaluator, negated); - auto rhs = fromExpression(*call->inputs()[1], evaluator, negated); - return negated ? AndNode::create(std::move(lhs), std::move(rhs)) - : OrNode::create(std::move(lhs), std::move(rhs)); + if (call->name() == expression::kOr) { + return ConditionNode::fromExpression( + call->inputs(), evaluator, false, negated); } if (call->name() == "not") { return fromExpression(*call->inputs()[0], evaluator, !negated); } try { - Subfield subfield; - auto filter = + auto subfieldAndFilter = exec::ExprToSubfieldFilterParser::getInstance() - ->leafCallToSubfieldFilter(*call, subfield, evaluator, negated); - if (!filter) { + ->leafCallToSubfieldFilter(*call, evaluator, negated); + if (!subfieldAndFilter.has_value()) { return nullptr; } + + auto& [subfield, filter] = subfieldAndFilter.value(); VELOX_CHECK( subfield.valid(), "Invalid subfield from expression: {}", diff --git a/velox/dwio/common/MetadataFilter.h b/velox/dwio/common/MetadataFilter.h index 62b604b1440..d626bbdd967 100644 --- a/velox/dwio/common/MetadataFilter.h +++ b/velox/dwio/common/MetadataFilter.h @@ -50,6 +50,7 @@ class MetadataFilter { private: struct Node; + struct ConditionNode; struct AndNode; struct OrNode; diff --git a/velox/dwio/common/DwioMetricsLog.cpp b/velox/dwio/common/MetricsLog.cpp similarity index 59% rename from velox/dwio/common/DwioMetricsLog.cpp rename to velox/dwio/common/MetricsLog.cpp index d8e0323f9bd..a18a3f71073 100644 --- a/velox/dwio/common/DwioMetricsLog.cpp +++ b/velox/dwio/common/MetricsLog.cpp @@ -35,6 +35,42 @@ std::shared_ptr& metricsLogFactory() { } } // namespace +/* static */ std::shared_ptr MetricsLog::voidLog() { + static const MetricsLog kInstance{{}}; + return {std::shared_ptr{}, &kInstance}; +} + +/* static */ std::string MetricsLog::getMetricTypeName(Type type) { + switch (type) { + case Type::HEADER: + return "HEADER"; + case Type::FOOTER: + return "FOOTER"; + case Type::FILE: + return "FILE"; + case Type::STRIPE: + return "STRIPE"; + case Type::STRIPE_INDEX: + return "STRIPE_INDEX"; + case Type::STRIPE_FOOTER: + return "STRIPE_FOOTER"; + case Type::STREAM: + return "STREAM"; + case Type::STREAM_BUNDLE: + return "STREAM_BUNDLE"; + case Type::GROUP: + return "GROUP"; + case Type::GROUP_INDEX: + return "GROUP_INDEX"; + case Type::BLOCK: + return "BLOCK"; + case Type::TEST: + return "TEST"; + default: + VELOX_UNREACHABLE("Unknown MetricsLog type: {}", static_cast(type)); + } +} + void registerMetricsLogFactory(std::shared_ptr factory) { metricsLogFactory() = std::move(factory); } diff --git a/velox/dwio/common/MetricsLog.h b/velox/dwio/common/MetricsLog.h index 024117bd568..691ad356fb4 100644 --- a/velox/dwio/common/MetricsLog.h +++ b/velox/dwio/common/MetricsLog.h @@ -19,17 +19,15 @@ #include "velox/dwio/common/FilterNode.h" #include "velox/dwio/common/exception/Exception.h" -namespace facebook { -namespace velox { -namespace dwio { -namespace common { +namespace facebook::velox::dwio::common { class MetricsLog { public: static constexpr std::string_view LIB_VERSION_STRING{"1.1"}; - static constexpr folly::StringPiece WRITE_OPERATION{"WRITE"}; + static constexpr std::string_view WRITE_OPERATION{"WRITE"}; - enum class MetricsType { + /// Identifies the type of metadata being read or logged for metrics purposes. + enum class Type { HEADER, FOOTER, FILE, @@ -39,8 +37,10 @@ class MetricsLog { STREAM, STREAM_BUNDLE, GROUP, + GROUP_INDEX, BLOCK, - TEST + TEST, + METADATA }; virtual ~MetricsLog() = default; @@ -54,7 +54,7 @@ class MetricsLog { uint64_t footerSize, uint64_t readOffset, uint64_t readSize, - MetricsType type, + Type type, uint32_t numFileRead, uint32_t numStripeCache) const {} @@ -127,45 +127,17 @@ class MetricsLog { virtual void logFileClose(const FileCloseMetrics& /* metrics */) const {} - static std::shared_ptr voidLog() { - static const MetricsLog kInstance{{}}; - return {std::shared_ptr{}, &kInstance}; - } + static std::shared_ptr voidLog(); protected: MetricsLog(const std::string& file) : file_{file} {} - static std::string getMetricTypeName(MetricsType type) { - switch (type) { - case MetricsType::HEADER: - return "HEADER"; - case MetricsType::FOOTER: - return "FOOTER"; - case MetricsType::FILE: - return "FILE"; - case MetricsType::STRIPE: - return "STRIPE"; - case MetricsType::STRIPE_INDEX: - return "STRIPE_INDEX"; - case MetricsType::STRIPE_FOOTER: - return "STRIPE_FOOTER"; - case MetricsType::STREAM: - return "STREAM"; - case MetricsType::STREAM_BUNDLE: - return "STREAM_BUNDLE"; - case MetricsType::GROUP: - return "GROUP"; - case MetricsType::BLOCK: - return "BLOCK"; - case MetricsType::TEST: - return "TEST"; - } - } + static std::string getMetricTypeName(Type type); std::string file_; }; -using LogType = MetricsLog::MetricsType; +using LogType = MetricsLog::Type; using MetricsLogPtr = std::shared_ptr; class DwioMetricsLogFactory { @@ -178,7 +150,4 @@ void registerMetricsLogFactory(std::shared_ptr factory); DwioMetricsLogFactory& getMetricsLogFactory(); -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/OnDemandUnitLoader.cpp b/velox/dwio/common/OnDemandUnitLoader.cpp index d4ef4f0a5ef..1a21f5475c1 100644 --- a/velox/dwio/common/OnDemandUnitLoader.cpp +++ b/velox/dwio/common/OnDemandUnitLoader.cpp @@ -15,12 +15,12 @@ */ #include "velox/dwio/common/OnDemandUnitLoader.h" +#include "velox/common/time/Timer.h" #include #include "velox/common/base/Exceptions.h" #include "velox/dwio/common/MeasureTime.h" -#include "velox/dwio/common/UnitLoaderTools.h" using facebook::velox::dwio::common::measureTimeIfCallback; @@ -42,6 +42,7 @@ class OnDemandUnitLoader : public UnitLoader { LoadUnit& getLoadedUnit(uint32_t unit) override { VELOX_CHECK_LT(unit, loadUnits_.size(), "Unit out of range"); + processedUnits_.insert(unit); if (loadedUnit_.has_value()) { if (loadedUnit_.value() == unit) { return *loadUnits_[unit]; @@ -51,11 +52,14 @@ class OnDemandUnitLoader : public UnitLoader { loadedUnit_.reset(); } + uint64_t unitLoadNanos{0}; { + NanosecondTimer timer{&unitLoadNanos}; auto measure = measureTimeIfCallback(blockedOnIoCallback_); loadUnits_[unit]->load(); } loadedUnit_ = unit; + unitLoadNanos_ += unitLoadNanos; return *loadUnits_[unit]; } @@ -73,11 +77,25 @@ class OnDemandUnitLoader : public UnitLoader { rowOffsetInUnit, loadUnits_[unit]->getNumRows(), "Row out of range"); } + UnitLoaderStats stats() override { + UnitLoaderStats stats; + stats.addCounter("processedUnits", RuntimeCounter(processedUnits_.size())); + stats.addCounter( + "unitLoadNanos", + RuntimeCounter( + saturateCast(unitLoadNanos_), RuntimeCounter::Unit::kNanos)); + return stats; + } + private: const std::vector> loadUnits_; const std::function blockedOnIoCallback_; std::optional loadedUnit_; + + // Stats + std::unordered_set processedUnits_; + uint64_t unitLoadNanos_{0}; }; } // namespace diff --git a/velox/dwio/common/Options.cpp b/velox/dwio/common/Options.cpp index 33946f71159..26487686393 100644 --- a/velox/dwio/common/Options.cpp +++ b/velox/dwio/common/Options.cpp @@ -39,6 +39,10 @@ FileFormat toFileFormat(std::string_view s) { return FileFormat::ORC; } else if (s == "sst") { return FileFormat::SST; + } else if (s == "flux") { + return FileFormat::FLUX; + } else if (s == "avro") { + return FileFormat::AVRO; } return FileFormat::UNKNOWN; } @@ -65,6 +69,10 @@ std::string_view toString(FileFormat fmt) { return "orc"; case FileFormat::SST: return "sst"; + case FileFormat::FLUX: + return "flux"; + case FileFormat::AVRO: + return "avro"; default: return "unknown"; } diff --git a/velox/dwio/common/Options.h b/velox/dwio/common/Options.h index ef70cb72379..a16a1e30a62 100644 --- a/velox/dwio/common/Options.h +++ b/velox/dwio/common/Options.h @@ -17,7 +17,8 @@ #pragma once #include -#include +#include +#include #include #include "velox/common/base/RandomUtil.h" @@ -51,6 +52,8 @@ enum class FileFormat { NIMBLE = 8, ORC = 9, SST = 10, // rocksdb sst format + FLUX = 11, + AVRO = 12, }; FileFormat toFileFormat(std::string_view s); @@ -282,6 +285,22 @@ class RowReaderOptions { scanSpec_ = std::move(scanSpec); } + folly::Executor* ioExecutor() const { + return ioExecutor_; + } + + void setIOExecutor(folly::Executor* const ioExecutor) { + ioExecutor_ = ioExecutor; + } + + const size_t parallelUnitLoadCount() const { + return parallelUnitLoadCount_; + } + + void setParallelUnitLoadCount(size_t parallelUnitLoadCount) { + parallelUnitLoadCount_ = parallelUnitLoadCount; + } + const std::shared_ptr& metadataFilter() const { return metadataFilter_; } @@ -296,8 +315,8 @@ class RowReaderOptions { flatmapNodeIdsAsStruct) { VELOX_CHECK( std::all_of( - flatmapNodeIdsAsStruct.begin(), - flatmapNodeIdsAsStruct.end(), + flatmapNodeIdsAsStruct.cbegin(), + flatmapNodeIdsAsStruct.cend(), [](const auto& kv) { return !kv.second.empty(); }), "To use struct encoding for flatmap, keys to project must be specified"); flatmapNodeIdAsStruct_ = std::move(flatmapNodeIdsAsStruct); @@ -428,12 +447,64 @@ class RowReaderOptions { serdeParameters_ = std::move(serdeParameters); } + bool trackRowSize() const { + return trackRowSize_; + } + + void setTrackRowSize(bool trackRowSize) { + trackRowSize_ = trackRowSize; + } + + bool indexEnabled() const { + return indexEnabled_; + } + + /// Sets whether to use the cluster index for filter-based row pruning. + /// When enabled, filters from ScanSpec are converted to index bounds for + /// efficient row skipping based on the file's cluster index. + /// + /// NOTE: currently only supported by Nimble format. + void setIndexEnabled(bool enabled) { + indexEnabled_ = enabled; + } + + bool stringDecoderZeroCopy() const { + return stringDecoderZeroCopy_; + } + + void setStringDecoderZeroCopy(bool stringDecoderZeroCopy) { + stringDecoderZeroCopy_ = stringDecoderZeroCopy; + } + + bool nimblePreserveDictionaryEncoding() const { + return nimblePreserveDictionaryEncoding_; + } + + void setNimblePreserveDictionaryEncoding(bool value) { + nimblePreserveDictionaryEncoding_ = value; + } + + bool collectColumnCpuMetrics() const { + return collectColumnCpuMetrics_; + } + + RowReaderOptions& setCollectColumnCpuMetrics(bool collect) { + collectColumnCpuMetrics_ = collect; + return *this; + } + + // Legacy alias — remove after Nimble OSS bumps Velox. + RowReaderOptions& setCollectColumnStats(bool collect) { + return setCollectColumnCpuMetrics(collect); + } + private: uint64_t dataStart_; uint64_t dataLength_; bool preloadStripe_; bool projectSelectedType_; bool returnFlatVector_ = false; + size_t parallelUnitLoadCount_ = 0; ErrorTolerance errorTolerance_; std::shared_ptr selector_; RowTypePtr requestedType_; @@ -446,7 +517,8 @@ class RowReaderOptions { // Whether to generate FlatMapVectors when reading flat maps from the file. By // default, converts flat maps in the file to MapVectors. bool preserveFlatMapsInMemory_ = false; - + // Optional io executor to enable parallel unit loader. + folly::Executor* ioExecutor_{nullptr}; // Optional executors to enable internal reader parallelism. // 'decodingExecutor' allow parallelising the vector decoding process. // 'ioExecutor' enables parallelism when performing file system read @@ -485,20 +557,27 @@ class RowReaderOptions { TimestampPrecision timestampPrecision_ = TimestampPrecision::kMilliseconds; std::shared_ptr formatSpecificOptions_; + bool trackRowSize_{false}; + bool indexEnabled_{false}; + // Enables zero-copy string decoding in the Nimble selective reader, + // using the non-legacy encoding path. Controlled via session property. + bool stringDecoderZeroCopy_{false}; + // Controls whether dictionary-encoded Nimble string columns return + // DictionaryVector instead of FlatVector. Controlled via session property. + bool nimblePreserveDictionaryEncoding_{false}; + bool collectColumnCpuMetrics_{false}; }; /// Options for creating a Reader. class ReaderOptions : public io::ReaderOptions { public: - static constexpr uint64_t kDefaultFooterEstimatedSize = 1024 * 1024; // 1MB + static constexpr uint64_t kDefaultFooterSpeculativeIoSize = + 1024 * 1024; // 1MB static constexpr uint64_t kDefaultFilePreloadThreshold = 1024 * 1024 * 8; // 8MB explicit ReaderOptions(velox::memory::MemoryPool* pool) - : io::ReaderOptions(pool), - tailLocation_(std::numeric_limits::max()), - fileFormat_(FileFormat::UNKNOWN), - fileSchema_(nullptr) {} + : io::ReaderOptions(pool) {} /// Sets the format of the file, such as "rc" or "dwrf". The default is /// "dwrf". @@ -507,6 +586,13 @@ class ReaderOptions : public io::ReaderOptions { return *this; } + /// Sets the property bag. + ReaderOptions& setProperties( + std::unordered_map properties) { + properties_ = std::move(properties); + return *this; + } + /// Sets the current table schema of the file (a Type tree). This could be /// different from the actual schema in file if schema evolution happened. /// For "dwrf" format, a default schema is derived from the file. For "rc" @@ -535,8 +621,8 @@ class ReaderOptions : public io::ReaderOptions { return *this; } - ReaderOptions& setFooterEstimatedSize(uint64_t size) { - footerEstimatedSize_ = size; + ReaderOptions& setFooterSpeculativeIoSize(uint64_t size) { + footerSpeculativeIoSize_ = size; return *this; } @@ -555,11 +641,6 @@ class ReaderOptions : public io::ReaderOptions { return *this; } - ReaderOptions& setIOExecutor(std::shared_ptr executor) { - ioExecutor_ = std::move(executor); - return *this; - } - ReaderOptions& setSessionTimezone(const tz::TimeZone* sessionTimezone) { sessionTimezone_ = sessionTimezone; return *this; @@ -580,6 +661,11 @@ class ReaderOptions : public io::ReaderOptions { return fileFormat_; } + /// Gets the property bag. + const std::unordered_map& properties() const { + return properties_; + } + /// Gets the file schema. const std::shared_ptr& fileSchema() const { return fileSchema_; @@ -597,18 +683,14 @@ class ReaderOptions : public io::ReaderOptions { return decrypterFactory_; } - uint64_t footerEstimatedSize() const { - return footerEstimatedSize_; + uint64_t footerSpeculativeIoSize() const { + return footerSpeculativeIoSize_; } uint64_t filePreloadThreshold() const { return filePreloadThreshold_; } - const std::shared_ptr& ioExecutor() const { - return ioExecutor_; - } - const tz::TimeZone* sessionTimezone() const { return sessionTimezone_; } @@ -633,12 +715,12 @@ class ReaderOptions : public io::ReaderOptions { randomSkip_ = std::move(randomSkip); } - bool noCacheRetention() const { - return noCacheRetention_; + bool cacheable() const { + return cacheable_; } - void setNoCacheRetention(bool noCacheRetention) { - noCacheRetention_ = noCacheRetention; + void setCacheable(bool cacheable) { + cacheable_ = cacheable; } const std::shared_ptr& scanSpec() const { @@ -657,6 +739,52 @@ class ReaderOptions : public io::ReaderOptions { selectiveNimbleReaderEnabled_ = value; } + /// Whether to cache file metadata (footer, stripes, index) in the + /// process-wide AsyncDataCache. When enabled, the first reader performs a + /// speculative tail read and populates the cache; subsequent readers on the + /// same file initialize from the cache with zero additional IO. + bool fileMetadataCacheEnabled() const { + return fileMetadataCacheEnabled_; + } + + void setFileMetadataCacheEnabled(bool value) { + fileMetadataCacheEnabled_ = value; + } + + /// If true, pins parsed metadata objects (e.g., StripeGroup, IndexGroup) in + /// the reader's metadata cache with strong references so they are never + /// evicted. This avoids re-reading and re-parsing metadata on every stripe + /// access when weak-pointer cache entries would otherwise expire. + bool pinFileMetadata() const { + return pinFileMetadata_; + } + + void setPinFileMetadata(bool value) { + pinFileMetadata_ = value; + } + + /// Whether to load and initialize the cluster index during file open. + /// When true, the cluster index section is preloaded and the structured + /// ClusterIndex object is created. Default true. + bool loadClusterIndex() const { + return loadClusterIndex_; + } + + void setLoadClusterIndex(bool value) { + loadClusterIndex_ = value; + } + + /// Whether to load and initialize the chunk index during file open. + /// When true, the chunk index section is preloaded and the structured + /// ChunkIndex object is created. Default true. + bool loadChunkIndex() const { + return loadChunkIndex_; + } + + void setLoadChunkIndex(bool value) { + loadChunkIndex_ = value; + } + bool allowEmptyFile() const { return allowEmptyFile_; } @@ -665,23 +793,42 @@ class ReaderOptions : public io::ReaderOptions { allowEmptyFile_ = value; } + /// Allows reading INT32 physical type columns as a narrower integer type + /// (e.g., INT32 -> TINYINT/SMALLINT). Some Parquet writers store INT_8 and + /// INT_16 values as plain INT32 without a converted type annotation. When + /// enabled, the value is silently truncated on overflow. When disabled + /// (default), only annotated type-matching reads are allowed (e.g., + /// INT_8 -> TINYINT, INT_16 -> SMALLINT, INT_32 -> INTEGER). + bool allowInt32Narrowing() const { + return allowInt32Narrowing_; + } + + void setAllowInt32Narrowing(bool value) { + allowInt32Narrowing_ = value; + } + private: - uint64_t tailLocation_; - FileFormat fileFormat_; + uint64_t tailLocation_{std::numeric_limits::max()}; + FileFormat fileFormat_{FileFormat::UNKNOWN}; RowTypePtr fileSchema_; SerDeOptions serDeOptions_; + std::unordered_map properties_{}; std::shared_ptr decrypterFactory_; - uint64_t footerEstimatedSize_{kDefaultFooterEstimatedSize}; + uint64_t footerSpeculativeIoSize_{kDefaultFooterSpeculativeIoSize}; uint64_t filePreloadThreshold_{kDefaultFilePreloadThreshold}; bool fileColumnNamesReadAsLowerCase_{false}; bool useColumnNamesForColumnMapping_{false}; - std::shared_ptr ioExecutor_; std::shared_ptr randomSkip_; std::shared_ptr scanSpec_; const tz::TimeZone* sessionTimezone_{nullptr}; bool adjustTimestampToTimezone_{false}; bool selectiveNimbleReaderEnabled_{false}; + bool fileMetadataCacheEnabled_{false}; + bool pinFileMetadata_{false}; + bool loadClusterIndex_{true}; + bool loadChunkIndex_{true}; bool allowEmptyFile_{false}; + bool allowInt32Narrowing_{false}; }; struct WriterOptions { diff --git a/velox/dwio/common/OutputStream.h b/velox/dwio/common/OutputStream.h index 46e90410ab1..d106cd4c9f3 100644 --- a/velox/dwio/common/OutputStream.h +++ b/velox/dwio/common/OutputStream.h @@ -47,8 +47,8 @@ class BufferedOutputStream : public google::protobuf::io::ZeroCopyOutputStream { void BackUp(int32_t count) override; - google::protobuf::int64 ByteCount() const override { - return static_cast(size()); + int64_t ByteCount() const override { + return static_cast(size()); } bool WriteAliasedRaw(const void* /* unused */, int32_t /* unused */) @@ -116,7 +116,7 @@ class BufferedOutputStream : public google::protobuf::io::ZeroCopyOutputStream { void** buffer, int32_t* size, uint64_t headerSize, - const std::vector& bufferToFlush) { + const std::vector& bufferToFlush) { bufferHolder_.take(bufferToFlush); *buffer = buffer_.data() + headerSize; *size = static_cast(buffer_.size() - headerSize); diff --git a/velox/dwio/common/ParallelUnitLoader.cpp b/velox/dwio/common/ParallelUnitLoader.cpp new file mode 100644 index 00000000000..778fc3c380c --- /dev/null +++ b/velox/dwio/common/ParallelUnitLoader.cpp @@ -0,0 +1,197 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/dwio/common/ParallelUnitLoader.h" +#include +#include "velox/common/base/AsyncSource.h" +#include "velox/common/base/Exceptions.h" +#include "velox/common/testutil/TestValue.h" +#include "velox/common/time/Timer.h" + +namespace facebook::velox::dwio::common { + +class ParallelUnitLoader : public UnitLoader { + public: + /// Enables concurrent loading of multiple units (stripes, row groups, etc.) + /// using asynchronous I/O to improve throughput and reduce read latency. + /// + /// **Loading Strategy:** + /// - Initialization: Preloads up to `maxConcurrentLoads` units concurrently + /// - Access pattern: On each getLoadedUnit() call, ensures the requested unit + /// is loaded and triggers loading of subsequent units within the window + /// - Memory management: Unloads all previous units to control memory usage + /// + /// **Performance Characteristics:** + /// - Best suited for sequential access patterns + /// - Memory usage: O(maxConcurrentLoads * average_unit_size) + /// - I/O parallelism: Up to `maxConcurrentLoads` concurrent load operations + /// + /// **Parameters:** + /// @param units All units to be loaded + /// @param ioExecutor Thread pool for asynchronous unit loading operations + /// @param maxConcurrentLoads Maximum units to load concurrently (sliding + /// window size) + /// + /// **Example with maxConcurrentLoads=3:** + /// ``` + /// Units: [0,1,2,3,4,5,6,7,8,9] + /// Init: Load [0,1,2] concurrently + /// Get(0): Wait for unit 0, trigger load of units [0,1,2], unload none + /// Get(1): Wait for unit 1, trigger load of units [1,2,3], unload [0] + /// Get(2): Wait for unit 2, trigger load of units [2,3,4], unload [0,1] + /// ``` + ParallelUnitLoader( + std::vector> units, + folly::Executor* ioExecutor, + uint16_t maxConcurrentLoads) + : loadUnits_( + std::make_move_iterator(units.begin()), + std::make_move_iterator(units.end())), + ioExecutor_(ioExecutor), + maxConcurrentLoads_(maxConcurrentLoads) { + VELOX_CHECK_NOT_NULL(ioExecutor, "ParallelUnitLoader ioExecutor is null"); + VELOX_CHECK_GT( + maxConcurrentLoads_, + 0, + "ParallelUnitLoader maxConcurrentLoads should be larger than 0"); + asyncSources_.resize(loadUnits_.size()); + unitsLoaded_.resize(loadUnits_.size()); + } + + /// Destructor ensures all pending load operations are properly cancelled + /// and waited for to prevent resource leaks and dangling references. + ~ParallelUnitLoader() override { + for (auto& source : asyncSources_) { + if (source) { + source->cancel(); + } + } + } + + LoadUnit& getLoadedUnit(uint32_t unit) override { + VELOX_CHECK_LT(unit, loadUnits_.size(), "Unit out of range"); + + processedUnits_.insert(unit); + // Ensure sliding window of units [unit, unit + maxConcurrentLoads_) is + // loading + for (size_t i = unit; + i < loadUnits_.size() && i < unit + maxConcurrentLoads_; + ++i) { + if (!unitsLoaded_[i]) { + load(i); + } + } + + uint64_t unitLoadNanos{0}; + try { + NanosecondTimer timer{&unitLoadNanos}; + asyncSources_[unit]->move(); + } catch (const std::exception& e) { + VELOX_FAIL("Failed to load unit {}: {}", unit, e.what()); + } + waitForUnitReadyNanos_ += unitLoadNanos; + + // Unload the previous units + unloadUntil(unit); + + return *loadUnits_[unit]; + } + + void onRead(uint32_t unit, uint64_t rowOffsetInUnit, uint64_t /* rowCount */) + override { + VELOX_CHECK_LT(unit, loadUnits_.size(), "Unit out of range"); + VELOX_CHECK_LT( + rowOffsetInUnit, loadUnits_[unit]->getNumRows(), "Row out of range"); + } + + void onSeek(uint32_t unit, uint64_t rowOffsetInUnit) override { + VELOX_CHECK_LT(unit, loadUnits_.size(), "Unit out of range"); + VELOX_CHECK_LE( + rowOffsetInUnit, loadUnits_[unit]->getNumRows(), "Row out of range"); + } + + UnitLoaderStats stats() override { + UnitLoaderStats stats; + stats.addCounter("processedUnits", RuntimeCounter(processedUnits_.size())); + stats.addCounter( + "waitForUnitReadyNanos", + RuntimeCounter( + saturateCast(waitForUnitReadyNanos_), + RuntimeCounter::Unit::kNanos)); + return stats; + } + + private: + /// Submits the unit's load() to the I/O thread pool + void load(uint32_t unitIndex) { + VELOX_CHECK_LT(unitIndex, loadUnits_.size(), "Unit index out of bounds"); + VELOX_CHECK_NOT_NULL(ioExecutor_, "ParallelUnitLoader ioExecutor is null"); + VELOX_DCHECK(!loadUnits_.empty(), "loadUnits_ should not be empty"); + + // Capture shared_ptr by value to prevent use-after-free if + // ParallelUnitLoader is destroyed while async operation is running + auto unit = loadUnits_[unitIndex]; + auto asyncSource = std::make_shared>([unit] { + unit->load(); + return std::make_unique(); + }); + asyncSources_[unitIndex] = asyncSource; + ioExecutor_->add([asyncSource] { + velox::common::testutil::TestValue::adjust( + "facebook::velox::dwio::common::ParallelUnitLoader::load", + asyncSource.get()); + asyncSource->prepare(); + }); + unitsLoaded_[unitIndex] = true; + } + + /// Unloads all the units before 'unitIndex' + void unloadUntil(uint32_t unitIndex) { + for (size_t i = 0; i < unitIndex; ++i) { + if (unitsLoaded_[i]) { + loadUnits_[i]->unload(); + unitsLoaded_[i] = false; + } + } + } + + std::vector unitsLoaded_; + std::vector> loadUnits_; + std::vector>> asyncSources_; + folly::Executor* ioExecutor_; + size_t maxConcurrentLoads_; + + // Stats + std::unordered_set processedUnits_; + uint64_t waitForUnitReadyNanos_{0}; +}; + +std::unique_ptr ParallelUnitLoaderFactory::create( + std::vector> loadUnits, + uint64_t rowsToSkip) { + const auto totalRows = std::accumulate( + loadUnits.cbegin(), loadUnits.cend(), 0UL, [](uint64_t sum, auto& unit) { + return sum + unit->getNumRows(); + }); + VELOX_CHECK_LE( + rowsToSkip, + totalRows, + "Can only skip up to the past-the-end row of the file."); + return std::make_unique( + std::move(loadUnits), ioExecutor_, maxConcurrentLoads_); +} + +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/ParallelUnitLoader.h b/velox/dwio/common/ParallelUnitLoader.h new file mode 100644 index 00000000000..0ba89028326 --- /dev/null +++ b/velox/dwio/common/ParallelUnitLoader.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include "velox/dwio/common/UnitLoader.h" + +namespace facebook::velox::dwio::common { +class ParallelUnitLoaderFactory : public UnitLoaderFactory { + public: + ParallelUnitLoaderFactory( + folly::Executor* ioExecutor, + size_t maxConcurrentLoads) + : ioExecutor_(ioExecutor), maxConcurrentLoads_(maxConcurrentLoads) {} + + std::unique_ptr create( + std::vector> loadUnits, + uint64_t rowsToSkip) override; + + private: + folly::Executor* ioExecutor_; + size_t maxConcurrentLoads_; +}; + +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/PositionProvider.h b/velox/dwio/common/PositionProvider.h index 7be3bc7a160..c99655a87e7 100644 --- a/velox/dwio/common/PositionProvider.h +++ b/velox/dwio/common/PositionProvider.h @@ -23,7 +23,7 @@ namespace facebook::velox::dwio::common { class PositionProvider { public: explicit PositionProvider(const std::vector& positions) - : position_{positions.begin()}, end_{positions.end()} {} + : position_{positions.cbegin()}, end_{positions.cend()} {} uint64_t next(); diff --git a/velox/dwio/common/RandGen.h b/velox/dwio/common/RandGen.h index b83743bbcca..a1fc53e8dc7 100644 --- a/velox/dwio/common/RandGen.h +++ b/velox/dwio/common/RandGen.h @@ -19,10 +19,7 @@ #include #include -namespace facebook { -namespace velox { -namespace dwio { -namespace common { +namespace facebook::velox::dwio::common { class RandGen { public: @@ -68,7 +65,4 @@ class RandGen { std::uniform_int_distribution dist_; }; -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/Reader.cpp b/velox/dwio/common/Reader.cpp index 559dab5ae80..3d4508655a8 100644 --- a/velox/dwio/common/Reader.cpp +++ b/velox/dwio/common/Reader.cpp @@ -20,7 +20,7 @@ namespace facebook::velox::dwio::common { using namespace velox::common; -VectorPtr RowReader::projectColumns( +RowReader::ProjectColumnsResult RowReader::projectColumnsWithSelection( const VectorPtr& input, const ScanSpec& spec, const Mutation* mutation) { @@ -75,11 +75,23 @@ VectorPtr RowReader::projectColumns( auto rowType = ROW(std::move(names), std::move(types)); auto size = bits::countBits(passed.data(), 0, input->size()); if (size == 0) { - return RowVector::createEmpty(rowType, input->pool()); + // Empty output. Return null selectedRows when the input was already + // empty (no rows were dropped — identity mapping holds trivially). + // Otherwise return a zero-length selection buffer so callers can + // distinguish "all rows filtered out" from "identity mapping" by + // checking selectedRows == nullptr. + return { + RowVector::createEmpty(rowType, input->pool()), + input->size() == 0 ? nullptr : allocateIndices(0, input->pool())}; } + + // Preserve input nulls buffer + BufferPtr outputNulls = input->nulls(); + + BufferPtr selectedRows; if (size < input->size()) { - auto indices = allocateIndices(size, input->pool()); - auto* rawIndices = indices->asMutable(); + selectedRows = allocateIndices(size, input->pool()); + auto* rawIndices = selectedRows->asMutable(); vector_size_t j = 0; bits::forEachSetBit( passed.data(), 0, input->size(), [&](auto i) { rawIndices[j++] = i; }); @@ -89,11 +101,69 @@ VectorPtr RowReader::projectColumns( } child->disableMemo(); child = BaseVector::wrapInDictionary( - nullptr, indices, size, std::move(child)); + nullptr, selectedRows, size, std::move(child)); + } + + // Filter the nulls buffer to match the filtered rows + if (input->nulls()) { + outputNulls = AlignedBuffer::allocate(size, input->pool()); + auto* rawOutputNulls = outputNulls->asMutable(); + // Initialize all as not null (all bits set to 1) + memset(rawOutputNulls, 0xFF, bits::nbytes(size)); + + const auto* rawInputNulls = input->rawNulls(); + for (vector_size_t i = 0; i < size; ++i) { + if (bits::isBitNull(rawInputNulls, rawIndices[i])) { + bits::setNull(rawOutputNulls, i); + } + } + } + } + + // Apply post-read transforms for column extraction pushdown. + // This runs after filtering/dictionary wrapping so the vectors have the + // correct row count. If a child has more elements than the output size + // (e.g., text reader pre-allocates), slice it first. + bool hasTransforms = false; + for (auto& childSpec : spec.children()) { + if (!childSpec->projectOut() || !childSpec->hasTransform()) { + continue; + } + auto i = childSpec->channel(); + if (children[i]) { + auto child = children[i]; + if (child->size() > size) { + child = child->slice(0, size); + } + children[i] = childSpec->transform()(child, input->pool()); + hasTransforms = true; } } - return std::make_shared( - input->pool(), rowType, nullptr, size, std::move(children)); + + // Rebuild rowType if transforms changed any child types. + if (hasTransforms) { + auto rowNames = rowType->asRow().names(); + std::vector rowTypes; + rowTypes.reserve(numColumns); + for (column_index_t i = 0; i < numColumns; ++i) { + rowTypes.push_back( + children[i] ? children[i]->type() : rowType->childAt(i)); + } + rowType = + ROW(std::vector(rowNames.begin(), rowNames.end()), + std::move(rowTypes)); + } + + auto output = std::make_shared( + input->pool(), rowType, outputNulls, size, std::move(children)); + return {std::move(output), std::move(selectedRows)}; +} + +VectorPtr RowReader::projectColumns( + const VectorPtr& input, + const ScanSpec& spec, + const Mutation* mutation) { + return projectColumnsWithSelection(input, spec, mutation).output; } namespace { diff --git a/velox/dwio/common/Reader.h b/velox/dwio/common/Reader.h index 9dddfaeaca0..4fe4e7a5138 100644 --- a/velox/dwio/common/Reader.h +++ b/velox/dwio/common/Reader.h @@ -21,12 +21,16 @@ #include #include +#include + +#include "velox/connectors/Connector.h" #include "velox/dwio/common/InputStream.h" #include "velox/dwio/common/Mutation.h" #include "velox/dwio/common/Options.h" #include "velox/dwio/common/SelectiveColumnReader.h" #include "velox/dwio/common/Statistics.h" #include "velox/dwio/common/TypeWithId.h" +#include "velox/serializers/KeyEncoder.h" #include "velox/type/Type.h" #include "velox/vector/BaseVector.h" @@ -44,6 +48,16 @@ class RowReader { public: static constexpr int64_t kAtEnd = -1; + /// Runtime stat names. + /// Tracks the number of index columns that were converted from ScanSpec + /// filters to index bounds for index-based filtering (e.g., cluster index + /// pruning in Nimble). + static constexpr std::string_view kNumIndexFilterConversions = + "numIndexFilterConversions"; + + /// Tracks the number of times a stripe has been loaded during index lookup. + static constexpr std::string_view kNumStripeLoads = "numStripeLoads"; + virtual ~RowReader() = default; /** @@ -139,6 +153,34 @@ class RowReader { return std::nullopt; } + /** + * Result of projectColumnsWithSelection. 'output' is the projected + * RowVector. 'selectedRows' maps each output row back to its input row + * index: selectedRows[i] is the input row that produced output row i. + * 'selectedRows' is null when no rows were dropped — output rows are + * identity-aligned with the input. This includes the empty-input case + * (input->size() == 0), where the identity mapping holds trivially. + * When all rows are filtered out from a non-empty input, 'output' is + * empty and 'selectedRows' is a non-null zero-length buffer so callers + * can distinguish "filtered to empty" from "identity mapping". + */ + struct ProjectColumnsResult { + VectorPtr output; + BufferPtr selectedRows; + }; + + /** + * Like projectColumns, but also returns the input-row selection used to + * build the output. Callers that need to keep an external per-input-row + * structure (for example, an index reader's inputHits buffer) aligned + * with the filtered output can use 'selectedRows' to compact that + * structure without re-running filters. + */ + static ProjectColumnsResult projectColumnsWithSelection( + const VectorPtr& input, + const velox::common::ScanSpec& spec, + const Mutation* mutation); + /** * Helper function used by non-selective reader to project top level columns * according to the scan spec and mutations. @@ -157,6 +199,154 @@ class RowReader { VectorPtr& result); }; +/// Represents a row range within a stripe [startRow, endRow). +struct RowRange { + vector_size_t startRow{0}; // Inclusive + vector_size_t endRow{0}; // Exclusive + + RowRange() = default; + RowRange(vector_size_t _startRow, vector_size_t _endRow) + : startRow(_startRow), endRow(_endRow) {} + + /// Returns true if this row range is empty (no rows to read). + bool empty() const { + return startRow >= endRow; + } +}; + +/** + * Abstract index reader interface for index-based lookups. + * + * IndexReader provides a batch lookup API that takes a vector of index bounds + * and returns results via an iterator pattern. This interface is used by + * HiveIndexReader to perform efficient key-based lookups on indexed files. + * + * Usage pattern: + * 1. Call startLookup() with a vector of index bounds to start a new batch + * lookup + * 2. Call next() repeatedly to get results until it returns nullptr + * 3. Each next() call returns results for one or more request indices + * + * The implementation is responsible for: + * - Encoding index bounds into format-specific keys + * - Looking up stripes and row ranges + * - Managing stripe iteration and data reading + * - Returning results in the correct order + */ +class IndexReader { + public: + /// Runtime stat names for index reader. + + /// Tracks the number of index lookup requests submitted in startLookup(). + /// Each request corresponds to one set of index bounds and may match rows + /// across multiple stripes. + static constexpr std::string_view kNumIndexLookupRequests = + "numIndexLookupRequests"; + + /// Tracks the total number of stripes that need to be read for all requests. + /// Within a single startLookup() call, a stripe shared by multiple requests + /// is counted once; across different startLookup() calls, the same stripe is + /// counted separately for each call. + static constexpr std::string_view kNumIndexLookupStripes = + "numIndexLookupStripes"; + + /// Tracks the total number of rows in all loaded stripes. Measures the full + /// stripe row count regardless of how many rows are actually needed by + /// index lookups. Comparing with kNumIndexMatchedRows shows cluster index + /// selectivity within stripes. + static constexpr std::string_view kNumIndexScannedRows = + "numIndexScannedRows"; + + /// Tracks the total number of rows matched by the cluster index across all + /// stripes. These are the rows identified as matching the lookup bounds + /// within each stripe, before any ScanSpec filter pushdown. Comparing with + /// actual output rows shows filter selectivity. + static constexpr std::string_view kNumIndexMatchedRows = + "numIndexMatchedRows"; + + /// Tracks the total number of read segments across all stripes. A read + /// segment is a contiguous row range within a stripe that needs to be read. + /// When filters are present, overlapping request ranges are split at + /// boundaries to enable per-request output tracking. Without filters, + /// overlapping ranges are merged to minimize I/O. + static constexpr std::string_view kNumIndexLookupReadSegments = + "numIndexLookupReadSegments"; + + /// Wall time spent loading stripes (or equivalent format-specific load unit) + /// during index lookup, summed across all stripe loads. + static constexpr std::string_view kIndexStripeLoadWallNanos = + "indexStripeLoadWallNanos"; + + /// CPU time spent loading stripes during index lookup. May undercount on + /// async/prefetch paths; see IndexSource::lookupTiming() for details. + static constexpr std::string_view kIndexStripeLoadCpuNanos = + "indexStripeLoadCpuNanos"; + + /// Wall time spent decoding column data from loaded stripes during index + /// lookup, summed across all read segments. + static constexpr std::string_view kIndexDataDecodeWallNanos = + "indexDataDecodeWallNanos"; + + /// CPU time spent decoding column data from loaded stripes during index + /// lookup. Same prefetch caveat as kIndexStripeLoadCpuNanos. + static constexpr std::string_view kIndexDataDecodeCpuNanos = + "indexDataDecodeCpuNanos"; + + /// Number of distinct stripes loaded across the lifetime of this index + /// reader. Useful for spotting redundant loads when comparing against + /// numStripeLoads (which counts every load call). + static constexpr std::string_view kNumIndexDistinctStripesLoaded = + "numIndexDistinctStripesLoaded"; + + virtual ~IndexReader() = default; + + /// Returns runtime statistics accumulated by this index reader. + virtual folly::F14FastMap stats() const { + return {}; + } + + /// Options for controlling index reader behavior. + struct Options { + /// Maximum number of rows to read per index lookup request. + /// When set to non-zero, the index reader will stop fetching or truncate + /// stripes once the total row range (before filtering) reaches this limit. + /// 0 means no limit (default). + vector_size_t maxRowsPerRequest{0}; + }; + + /// Starts a new batch lookup with the given index bounds. + /// Each index bound in the vector represents a separate lookup request. + /// After calling startLookup(), call next() repeatedly to get results. + /// + /// @param indexBounds Index bounds for the lookup request. Contains + /// column names and lower/upper bound values. + /// @param options Options controlling index reader behavior (e.g., + /// maxRowsPerRequest). Defaults to no limit. + /// @throws if lookup is not supported by the implementation or if any + /// index bound is invalid. + virtual void startLookup( + const velox::serializer::IndexBounds& indexBounds, + const Options& options) = 0; + + /// Returns true if there are more results to fetch from the current lookup. + virtual bool hasNext() const = 0; + + /// Returns the next batch of results from the current lookup. + /// Results are returned in request order - all results for request N are + /// returned before any results for request N+1. + /// + /// The Result contains: + /// - inputHits: Buffer of request indices for each output row + /// - output: RowVector of matching data rows + /// + /// @param maxOutputRows Maximum number of output rows to return in this + /// batch. The actual number may be less if fewer rows are available. + /// @return Result containing inputHits and output rows, or nullptr if no + /// more results are available. + virtual std::unique_ptr next( + vector_size_t maxOutputRows) = 0; +}; + /** * Abstract reader class. * @@ -179,11 +369,18 @@ class Reader { */ virtual std::optional numberOfRows() const = 0; - /** - * Get statistics for a specified column. - * @param index column index - * @return column statisctics - */ + /// Returns file-level statistics for the column identified by 'index'. + /// + /// 'index' is a node ID in the type tree (TypeWithId::id()), not a top-level + /// column ordinal. Node 0 is the root ROW type; top-level columns start at 1 + /// for flat schemas. For nested types, IDs follow pre-order DFS numbering. + /// + /// Use typeWithId() to navigate the schema and obtain the correct ID: + /// auto& col = reader->typeWithId()->childByName("column_name"); + /// auto stats = reader->columnStatistics(col->id()); + /// + /// Returns nullptr if statistics are not available (e.g., non-leaf complex + /// types in Parquet, out-of-range index, or missing stats in the file). virtual std::unique_ptr columnStatistics( uint32_t index) const = 0; @@ -207,6 +404,17 @@ class Reader { virtual std::unique_ptr createRowReader( const RowReaderOptions& options = {}) const = 0; + /** + * Create index reader object for index-based lookups. + * @param options Row reader options describing the data to fetch + * @return Index reader for efficient key-based lookups + * @throws if index reading is not supported by the implementation + */ + virtual std::unique_ptr createIndexReader( + const RowReaderOptions& options = {}) const { + VELOX_UNSUPPORTED("Reader::createIndexReader() is not supported"); + } + static TypePtr updateColumnNames( const TypePtr& fileType, const TypePtr& tableType); diff --git a/velox/dwio/common/Retry.h b/velox/dwio/common/Retry.h index d2860cfb6b7..9ea087d4361 100644 --- a/velox/dwio/common/Retry.h +++ b/velox/dwio/common/Retry.h @@ -28,10 +28,7 @@ #include "velox/dwio/common/RandGen.h" #include "velox/dwio/common/exception/Exception.h" -namespace facebook { -namespace velox { -namespace dwio { -namespace common { +namespace facebook::velox::dwio::common { class retriable_error : public std::runtime_error { public: @@ -229,7 +226,4 @@ class RetryModule { } }; -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/ScanSpec.cpp b/velox/dwio/common/ScanSpec.cpp index 69669122215..b6cbd4ab516 100644 --- a/velox/dwio/common/ScanSpec.cpp +++ b/velox/dwio/common/ScanSpec.cpp @@ -104,8 +104,8 @@ uint64_t ScanSpec::newRead() { if (numReads_ == 0 || (!disableStatsBasedFilterReorder_ && !std::is_sorted( - children_.begin(), - children_.end(), + children_.cbegin(), + children_.cend(), [this]( const std::shared_ptr& left, const std::shared_ptr& right) { @@ -321,6 +321,9 @@ bool testStringFilter( bool testBoolFilter( const common::Filter* filter, dwio::common::BooleanColumnStatistics* boolStats) { + if (!boolStats) { + return true; + } const auto trueCount = boolStats->getTrueCount(); const auto falseCount = boolStats->getFalseCount(); if (trueCount.has_value() && falseCount.has_value()) { diff --git a/velox/dwio/common/ScanSpec.h b/velox/dwio/common/ScanSpec.h index fbcac3d3a59..8204011ccda 100644 --- a/velox/dwio/common/ScanSpec.h +++ b/velox/dwio/common/ScanSpec.h @@ -28,17 +28,16 @@ #include -namespace facebook { -namespace velox { +namespace facebook::velox { namespace dwio::common { class ColumnStatistics; } namespace common { -// Describes the filtering and value extraction for a -// SelectiveColumnReader. This is owned by the TableScan Operator and -// is passed to SelectiveColumnReaders at construction. This is -// mutable by readers to reflect filter order and other adaptations. +/// Describes the filtering and value extraction for a +/// SelectiveColumnReader. This is owned by the TableScan Operator and +/// is passed to SelectiveColumnReaders at construction. This is +/// mutable by readers to reflect filter order and other adaptations. class ScanSpec { public: enum class ColumnType : int8_t { @@ -47,7 +46,7 @@ class ScanSpec { kComposite, // A struct with all children not read from file }; - // Convert ColumnType to its string name representation. + /// Convert ColumnType to its string name representation. static std::string_view columnTypeString(ColumnType columnType); static constexpr column_index_t kNoChannel = ~0; @@ -57,15 +56,15 @@ class ScanSpec { explicit ScanSpec(const std::string& name) : fieldName_(name) {} - // Filter to apply. If 'this' corresponds to a struct/list/map, this - // can only be isNull or isNotNull, other filtering is given by - // 'children'. + /// Filter to apply. If 'this' corresponds to a struct/list/map, this + /// can only be isNull or isNotNull, other filtering is given by + /// 'children'. const common::Filter* filter() const { return filterDisabled_ ? nullptr : filter_.get(); } - // Sets 'filter_'. May be used at initialization or when adding a - // pushed down filter, e.g. top k cutoff. + /// Sets 'filter_'. May be used at initialization or when adding a + /// pushed down filter, e.g. top k cutoff. void setFilter(std::shared_ptr filter) { filter_ = std::move(filter); } @@ -96,8 +95,8 @@ class ScanSpec { return metadataFilters_[i].second; } - // Returns a constant vector if 'this' corresponds to a partitioning - // column or to a missing column. These change from split to split. + /// Returns a constant vector if 'this' corresponds to a partitioning + /// column or to a missing column. These change from split to split. VectorPtr constantValue() const { return constantValue_; } @@ -128,32 +127,34 @@ class ScanSpec { return columnType_ == ColumnType::kRegular && !isConstant(); } - // Name of the value in its container, i.e. field name in struct or - // string key in map. Not all fields of 'this' apply in list/map - // value cases but the overhead is manageable, the space taken is - // less than the Subfield path that will in any case exist for each - // separately named list/map element. + /// Name of the value in its container, i.e. field name in struct or + /// string key in map. Not all fields of 'this' apply in list/map + /// value cases but the overhead is manageable, the space taken is + /// less than the Subfield path that will in any case exist for each + /// separately named list/map element. const std::string& fieldName() const { return fieldName_; } - // Subscript if this refers to a member of a list or an - // integer-keyed map value. If this is a member in a row, this is - // the ordinal position in the row type. Subscript is mutable, for - // example the position of the reader in a struct's readers may vary - // between splits. Set to correspond to the position of 'fieldName' - // when first reading a struct. Not mutable if this refers to a - // list/map subscript. + /// Subscript if this refers to a member of a list or an + /// integer-keyed map value. If this is a member in a row, this is + /// the ordinal position in the row type. Subscript is mutable, for + /// example the position of the reader in a struct's readers may vary + /// between splits. Set to correspond to the position of 'fieldName' + /// when first reading a struct. Not mutable if this refers to a + /// list/map subscript. int64_t subscript() const { return subscript_; } void setSubscript(int64_t subscript) { - subscript_ = subscript; + if (subscript_ != subscript) { + subscript_ = subscript; + } } - // True if the value is returned from scan. A runtime pushdown of a filter - // function may cause this to become false at run time. + /// True if the value is returned from scan. A runtime pushdown of a filter + /// function may cause this to become false at run time. bool projectOut() const { return projectOut_; } @@ -166,8 +167,8 @@ class ScanSpec { return projectOut_ || deltaUpdate_; } - // Position in the RowVector returned by the top level scan. Applies - // only to children of the root struct where projectOut_ is true. + /// Position in the RowVector returned by the top level scan. Applies + /// only to children of the root struct where projectOut_ is true. column_index_t channel() const { return channel_; } @@ -180,31 +181,31 @@ class ScanSpec { return children_; } - // Returns 'children in a stable order. May be used for parallel - // construction and read-ahead of reader trees while the main user - // of 'this' is running. 'children_' may be reordered while running - // but the tree being constructed must see a single, unchanging - // order. + /// Returns 'children in a stable order. May be used for parallel + /// construction and read-ahead of reader trees while the main user + /// of 'this' is running. 'children_' may be reordered while running + /// but the tree being constructed must see a single, unchanging + /// order. const std::vector& stableChildren(); - // Returns a read sequence number. This can b used for tagging - // lazy vectors with a generation number so that we can check that - // the reader that made them has not advanced between the making and - // the loading of the lazy vector. This must be called if 'this' - // corresponds to a struct or flat map reader with pushdown. This - // may periodically do adaptation such as filter reordering. This - // will initialize the read order on first call and calling this at - // each level of struct is mandatory. + /// Returns a read sequence number. This can b used for tagging + /// lazy vectors with a generation number so that we can check that + /// the reader that made them has not advanced between the making and + /// the loading of the lazy vector. This must be called if 'this' + /// corresponds to a struct or flat map reader with pushdown. This + /// may periodically do adaptation such as filter reordering. This + /// will initialize the read order on first call and calling this at + /// each level of struct is mandatory. uint64_t newRead(); /// Returns the ScanSpec corresponding to 'name'. Creates it if needed without /// any intermediate level. ScanSpec* getOrCreateChild(const std::string& name); - // Returns the ScanSpec corresponding to 'subfield'. Creates it if - // needed, including any intermediate levels. This is used at - // TableScan initialization to create the ScanSpec tree that - // corresponds to the ColumnReader tree. + /// Returns the ScanSpec corresponding to 'subfield'. Creates it if + /// needed, including any intermediate levels. This is used at + /// TableScan initialization to create the ScanSpec tree that + /// corresponds to the ColumnReader tree. ScanSpec* getOrCreateChild(const Subfield& subfield); ScanSpec* childByName(const std::string& name) const { @@ -227,11 +228,11 @@ class ScanSpec { valueHook_ = valueHook; } - // Returns true if the corresponding reader only needs to reference the nulls - // stream. True if filter is is-null with or without value extraction or if - // filter is is-not-null and no value is extracted. Note that this does not - // apply to Nimble format leaf nodes, because nulls are mixed in the encoding - // with actual values. + /// Returns true if the corresponding reader only needs to reference the nulls + /// stream. True if filter is is-null with or without value extraction or if + /// filter is is-not-null and no value is extracted. Note that this does not + /// apply to Nimble format leaf nodes, because nulls are mixed in the encoding + /// with actual values. bool readsNullsOnly() const { if (auto* filter = this->filter()) { if (filter->kind() == FilterKind::kIsNull) { @@ -252,11 +253,11 @@ class ScanSpec { makeFlat_ = makeFlat; } - // True if this or a descendant has a filter that will affect the number of - // output rows. Note that filter on map keys and array indices is not - // counted, as they do not change the number of container output rows. - // - // This may change as a result of runtime adaptation. + /// True if this or a descendant has a filter that will affect the number of + /// output rows. Note that filter on map keys and array indices is not + /// counted, as they do not change the number of container output rows. + /// + /// This may change as a result of runtime adaptation. bool hasFilter() const; /// Similar as hasFilter() but also return true even there is a filter on @@ -271,8 +272,8 @@ class ScanSpec { /// filtered out. bool testNull() const; - // Resets cached values after this or children were updated, e.g. a new filter - // was added or existing filter was modified. + /// Resets cached values after this or children were updated, e.g. a new + /// filter was added or existing filter was modified. void resetCachedValues(bool doReorder) { hasFilter_.reset(); for (auto& child : children_) { @@ -283,49 +284,50 @@ class ScanSpec { } } - // Returns the child which produces values for 'channel'. Throws if not found. + /// Returns the child which produces values for 'channel'. Throws if not + /// found. ScanSpec& getChildByChannel(column_index_t channel); - // sets filter order and filters of 'this' from 'other'. Used when - // initializing a ScanSpec for a new split or stripe. This transfers - // dynamically acquired filters and adaptive filter order. 'other' - // should not be used after this. Different splits or stripes may - // have their own ScanSpec trees, so we only move the content, not - // the ScanSpec tree itself. + /// Sets filter order and filters of 'this' from 'other'. Used when + /// initializing a ScanSpec for a new split or stripe. This transfers + /// dynamically acquired filters and adaptive filter order. 'other' + /// should not be used after this. Different splits or stripes may + /// have their own ScanSpec trees, so we only move the content, not + /// the ScanSpec tree itself. void moveAdaptationFrom(ScanSpec& other); std::string toString() const; - // Add a field to this ScanSpec, with content projected out. + /// Add a field to this ScanSpec, with content projected out. ScanSpec* addField(const std::string& name, column_index_t channel); - // Add a field and its children recursively to this ScanSpec, all projected - // out. + /// Add a field and its children recursively to this ScanSpec, all projected + /// out. ScanSpec* addFieldRecursively( const std::string& name, const Type&, column_index_t channel); - // Add a field for map key. + /// Add a field for map key. ScanSpec* addMapKeyField(); - // Add a field for map key, along with its child recursively. + /// Add a field for map key, along with its child recursively. ScanSpec* addMapKeyFieldRecursively(const Type&); - // Add a field for map value. + /// Add a field for map value. ScanSpec* addMapValueField(); - // Add a field for map value, along with its child recursively. + /// Add a field for map value, along with its child recursively. ScanSpec* addMapValueFieldRecursively(const Type&); - // Add a field for array element. + /// Add a field for array element. ScanSpec* addArrayElementField(); - // Add a field for array element, along with its child recursively. + /// Add a field for array element, along with its child recursively. ScanSpec* addArrayElementFieldRecursively(const Type&); - // Add all child fields on the type recursively to this ScanSpec, all - // projected out. + /// Add all child fields on the type recursively to this ScanSpec, all + /// projected out. void addAllChildFields(const Type&); const std::vector& flatMapFeatureSelection() const { @@ -378,6 +380,86 @@ class ScanSpec { isFlatMapAsStruct_ = value; } + /// Extraction type for column extraction pushdown. When set on a map or + /// array ScanSpec, tells the column reader to skip reading unneeded + /// streams and produce a different output type. + enum class ExtractionType : uint8_t { + kNone, ///< No extraction — read normally. + kKeys, ///< Extract map keys as ARRAY(K). Skip reading values. + kValues, ///< Extract map values as ARRAY(V). Skip reading keys. + kSize, ///< Extract size as BIGINT. Skip reading keys and values. + kField, ///< Extract a single struct field. Output is the field's type. + }; + + void setExtractionType(ExtractionType type) { + extractionType_ = type; + } + + ExtractionType extractionType() const { + return extractionType_; + } + + void setExtractionFieldIndex(column_index_t index) { + extractionFieldIndex_ = index; + } + + column_index_t extractionFieldIndex() const { + return extractionFieldIndex_; + } + + /// Post-read transform applied to the column's vector after it is read + /// from the file. Used by column extraction pushdown to transform the + /// file-type vector (e.g., MAP) into the extraction output type (e.g., + /// ARRAY for MapKeys). The transform receives the read vector and a + /// memory pool, and returns the transformed vector. + /// + /// The transform is applied in these cases: + /// + /// 1. Delta update fallback: when a column has both extraction pushdown + /// and a delta update (e.g., MAP_CONCAT), the reader checks + /// deltaUpdate() and bypasses ExtractionType, producing the file + /// type. After the delta update modifies the vector, the full-chain + /// transform is applied by SelectiveStructColumnReaderBase::getValues(). + /// ExtractionType may be set but is ignored by the reader. + /// + /// 2. Multiple extractions: when multiple extraction chains target the + /// same column, the transform assembles a ROW from the individual + /// extraction results. ExtractionType is always kNone (not set for + /// multiple extractions to ensure text reader compatibility). + /// + /// 3. Text reader: RowReader::projectColumns() applies the transform + /// because the text reader does not use the selective reader + /// framework. ExtractionType is kNone (same as case 2) or set but + /// ignored (single extraction — text reader applies the full-chain + /// transform regardless). + /// + /// For selective readers (DWRF, Nimble) with a single extraction and + /// no delta update, the full chain is handled by ScanSpec pushdown + /// (ExtractionType + sub-spec pruning) and the transform is not + /// applied at read time. + using VectorTransform = + std::function; + + /// Set the post-read transform and the output type it produces. + /// The reader uses outputType to allocate the result vector, reads into + /// a temporary vector of the file type, then applies the transform. + void setTransform(VectorTransform transform, TypePtr outputType) { + transform_ = std::move(transform); + transformOutputType_ = std::move(outputType); + } + + const VectorTransform& transform() const { + return transform_; + } + + const TypePtr& transformOutputType() const { + return transformOutputType_; + } + + bool hasTransform() const { + return transform_ != nullptr; + } + /// Disable stats based filter reordering. void disableStatsBasedFilterReorder() { disableStatsBasedFilterReorder_ = true; @@ -407,29 +489,29 @@ class ScanSpec { // Number of times read is called on the corresponding reader. This // is used for setup on first use and to produce a read sequence // number for LazyVectors. - uint64_t numReads_ = 0; + uint64_t numReads_{0}; // Ordinal position of 'this' in its containing spec. For a struct // member this is the position of the reader in the child // readers. If this describes an operation on an array element or a // map with numeric key, this is the subscript as defined for array // or map. - int64_t subscript_ = -1; + int64_t subscript_{-1}; // Column name if this is a struct mamber. String key if this // describes an operation on a map value. std::string fieldName_; // Ordinal position of the extracted value in the containing // RowVector. Set only when this describes a struct member. - column_index_t channel_ = kNoChannel; + column_index_t channel_{kNoChannel}; VectorPtr constantValue_; - bool projectOut_ = false; + bool projectOut_{false}; - ColumnType columnType_ = ColumnType::kRegular; + ColumnType columnType_{ColumnType::kRegular}; // True if a string dictionary or flat map in this field should be // returned as flat. - bool makeFlat_ = false; + bool makeFlat_{false}; std::shared_ptr filter_; bool filterDisabled_ = false; dwio::common::DeltaColumnUpdater* deltaUpdate_ = nullptr; @@ -470,6 +552,18 @@ class ScanSpec { // This node represents a flat map column that need to be read as struct, // i.e. in table schema it is a MAP, but in result vector it is ROW. bool isFlatMapAsStruct_ = false; + + // Extraction type for map/array/struct column readers. + ExtractionType extractionType_{ExtractionType::kNone}; + + // Index of the field to extract when extractionType_ is kField. + column_index_t extractionFieldIndex_{0}; + + // Post-read transform for column extraction pushdown. + VectorTransform transform_; + + // Output type after the transform is applied. + TypePtr transformOutputType_; }; template @@ -501,8 +595,8 @@ void ScanSpec::visit(const Type& type, F&& f) { } } -// Returns false if no value from a range defined by stats can pass the -// filter. True, otherwise. +/// Returns false if no value from a range defined by stats can pass the +/// filter. True, otherwise. bool testFilter( const common::Filter* filter, dwio::common::ColumnStatistics* stats, @@ -510,8 +604,7 @@ bool testFilter( const TypePtr& type); } // namespace common -} // namespace velox -} // namespace facebook +} // namespace facebook::velox template <> struct fmt::formatter diff --git a/velox/dwio/common/SeekableInputStream.cpp b/velox/dwio/common/SeekableInputStream.cpp index 2f461551626..db3c7f4a502 100644 --- a/velox/dwio/common/SeekableInputStream.cpp +++ b/velox/dwio/common/SeekableInputStream.cpp @@ -163,8 +163,8 @@ bool SeekableArrayInputStream::SkipInt64(int64_t count) { return false; } -google::protobuf::int64 SeekableArrayInputStream::ByteCount() const { - return static_cast(position_); +int64_t SeekableArrayInputStream::ByteCount() const { + return static_cast(position_); } void SeekableArrayInputStream::seekToPosition(PositionProvider& position) { @@ -241,8 +241,8 @@ bool SeekableFileInputStream::SkipInt64(int64_t signedCount) { return position_ < length_; } -google::protobuf::int64 SeekableFileInputStream::ByteCount() const { - return static_cast(position_); +int64_t SeekableFileInputStream::ByteCount() const { + return static_cast(position_); } void SeekableFileInputStream::seekToPosition(PositionProvider& location) { diff --git a/velox/dwio/common/SeekableInputStream.h b/velox/dwio/common/SeekableInputStream.h index c53347a6204..f4e33008523 100644 --- a/velox/dwio/common/SeekableInputStream.h +++ b/velox/dwio/common/SeekableInputStream.h @@ -79,7 +79,7 @@ class SeekableArrayInputStream : public SeekableInputStream { virtual bool Next(const void** data, int32_t* size) override; virtual void BackUp(int32_t count) override; virtual bool SkipInt64(int64_t count) override; - virtual google::protobuf::int64 ByteCount() const override; + virtual int64_t ByteCount() const override; virtual void seekToPosition(PositionProvider& position) override; virtual std::string getName() const override; virtual size_t positionSize() const override; @@ -120,7 +120,7 @@ class SeekableFileInputStream : public SeekableInputStream { virtual bool Next(const void** data, int32_t* size) override; virtual void BackUp(int32_t count) override; virtual bool SkipInt64(int64_t count) override; - virtual google::protobuf::int64 ByteCount() const override; + virtual int64_t ByteCount() const override; virtual void seekToPosition(PositionProvider& position) override; virtual std::string getName() const override; virtual size_t positionSize() const override; diff --git a/velox/dwio/common/SelectiveColumnReader.cpp b/velox/dwio/common/SelectiveColumnReader.cpp index 17b2207d65a..f2fc5918c02 100644 --- a/velox/dwio/common/SelectiveColumnReader.cpp +++ b/velox/dwio/common/SelectiveColumnReader.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include "velox/common/time/CpuWallTimer.h" #include "velox/dwio/common/SelectiveColumnReaderInternal.h" namespace facebook::velox::dwio::common { @@ -46,17 +47,36 @@ SelectiveColumnReader::SelectiveColumnReader( std::shared_ptr fileType, dwio::common::FormatParams& params, velox::common::ScanSpec& scanSpec) - : memoryPool_(¶ms.pool()), + : pool_(¶ms.pool()), requestedType_(requestedType), fileType_(fileType), formatData_(params.toFormatData(fileType, scanSpec)), scanSpec_(&scanSpec), - outputRows_(memoryPool_), - valueRows_(memoryPool_), - outerNonNullRows_(memoryPool_), - innerNonNullRows_(memoryPool_) { - scanState_.rowsCopy = raw_vector(memoryPool_); - scanState_.filterCache = raw_vector(memoryPool_); + outputRows_(pool_), + valueRows_(pool_), + outerNonNullRows_(pool_), + innerNonNullRows_(pool_) { + scanState_.rowsCopy = raw_vector(pool_); + scanState_.filterCache = raw_vector(pool_); + // Initialize per-column metrics if collection is enabled. + if (params.runtimeStatistics().columnMetricsSet) { + columnMetrics_ = params.runtimeStatistics().columnMetricsSet->getOrCreate( + fileType_->id()); + } +} + +void SelectiveColumnReader::readWithTiming( + int64_t offset, + const RowSet& rows, + const uint64_t* incomingNulls) { + if (columnMetrics_ && fileType_->type()->isPrimitiveType()) { + DeltaCpuWallTimer timer([this](const CpuWallTiming& timing) { + columnMetrics_->decodeCPUTimeNanos.increment(timing.cpuNanos); + }); + read(offset, rows, incomingNulls); + } else { + read(offset, rows, incomingNulls); + } } void SelectiveColumnReader::filterRowGroups( @@ -130,8 +150,8 @@ void SelectiveColumnReader::prepareNulls( resultNulls_->capacity() >= bits::nbytes(numRows) + simd::kPadding) { resultNulls_->setSize(bits::nbytes(numRows)); } else { - resultNulls_ = AlignedBuffer::allocate( - numRows + (simd::kPadding * 8), memoryPool_); + resultNulls_ = + AlignedBuffer::allocate(numRows + (simd::kPadding * 8), pool_); rawResultNulls_ = resultNulls_->asMutable(); } anyNulls_ = false; @@ -151,7 +171,7 @@ const uint64_t* SelectiveColumnReader::shouldMoveNulls(const RowSet& rows) { if (!(resultNulls_ && resultNulls_->unique() && resultNulls_->capacity() >= rows.size() + simd::kPadding)) { resultNulls_ = AlignedBuffer::allocate( - rows.size() + (simd::kPadding * 8), memoryPool_); + rows.size() + (simd::kPadding * 8), pool_); rawResultNulls_ = resultNulls_->asMutable(); } moveFrom = nullsInReadRange_->as(); @@ -243,7 +263,19 @@ void SelectiveColumnReader::getIntValues( } break; case TypeKind::HUGEINT: - getFlatValues(rows, result, requestedType); + switch (valueSize_) { + case 16: + getFlatValues(rows, result, requestedType); + break; + case 8: + getFlatValues(rows, result, requestedType); + break; + case 4: + getFlatValues(rows, result, requestedType); + break; + default: + VELOX_FAIL("Unsupported value size: {}", valueSize_); + } break; case TypeKind::BIGINT: switch (valueSize_) { @@ -260,6 +292,17 @@ void SelectiveColumnReader::getIntValues( VELOX_FAIL("Unsupported value size: {}", valueSize_); } break; + case TypeKind::DOUBLE: + // Only Parquet INT32 (valueSize_==4) widens to DOUBLE. INT64->DOUBLE + // is rejected in convertType due to precision loss. + switch (valueSize_) { + case 4: + getFlatValues(rows, result, requestedType); + break; + default: + VELOX_FAIL("Unsupported value size: {}", valueSize_); + } + break; default: VELOX_FAIL( "Not a valid type for integer reader: {}", requestedType->toString()); @@ -344,8 +387,7 @@ void SelectiveColumnReader::getFlatValues( constexpr int32_t kWidth = xsimd::batch::size; VELOX_CHECK_EQ(valueSize_, sizeof(int8_t)); compactScalarValues(rows, isFinal); - auto boolValues = - AlignedBuffer::allocate(numValues_, memoryPool_, false); + auto boolValues = AlignedBuffer::allocate(numValues_, pool_, false); auto rawBytes = values_->as(); auto zero = xsimd::broadcast(0); if constexpr (kWidth == 32) { @@ -363,7 +405,7 @@ void SelectiveColumnReader::getFlatValues( } } *result = std::make_shared>( - memoryPool_, + pool_, type, resultNulls(), numValues_, @@ -410,11 +452,11 @@ void SelectiveColumnReader::compactScalarValues( values_->setSize(bits::nbytes(numValues_)); } -char* SelectiveColumnReader::copyStringValue(folly::StringPiece value) { +char* SelectiveColumnReader::copyStringValue(std::string_view value) { uint64_t size = value.size(); if (stringBuffers_.empty() || rawStringUsed_ + size > rawStringSize_) { auto bytes = std::max(size, kStringBufferSize); - BufferPtr buffer = AlignedBuffer::allocate(bytes, memoryPool_); + BufferPtr buffer = AlignedBuffer::allocate(bytes, pool_); // Use the preferred size instead of the requested one to improve memory // efficiency. buffer->setSize(buffer->capacity()); @@ -431,7 +473,7 @@ char* SelectiveColumnReader::copyStringValue(folly::StringPiece value) { return rawStringBuffer_ + start; } -void SelectiveColumnReader::addStringValue(folly::StringPiece value) { +void SelectiveColumnReader::addStringValue(std::string_view value) { auto copy = copyStringValue(value); reinterpret_cast(rawValues_)[numValues_++] = StringView(copy, value.size()); @@ -449,8 +491,11 @@ void SelectiveColumnReader::setNulls(BufferPtr resultNulls) { void SelectiveColumnReader::resetFilterCaches() { if (scanState_.filterCache.empty() && scanSpec_->hasFilter()) { - scanState_.filterCache.resize(std::max( - 1, scanState_.dictionary.numValues + scanState_.dictionary2.numValues)); + scanState_.filterCache.resize( + std::max( + 1, + scanState_.dictionary.numValues + + scanState_.dictionary2.numValues)); scanState_.updateRawState(); } if (!scanState_.filterCache.empty()) { @@ -476,7 +521,7 @@ void SelectiveColumnReader::addSkippedParentNulls( int64_t from, int64_t to, int32_t numNulls) { - auto rowsPerRowGroup = formatData_->rowsPerRowGroup(); + const auto rowsPerRowGroup = formatData_->rowsPerRowGroup(); if (rowsPerRowGroup.has_value() && from / rowsPerRowGroup.value() > parentNullsRecordedTo_ / rowsPerRowGroup.value()) { @@ -484,7 +529,7 @@ void SelectiveColumnReader::addSkippedParentNulls( parentNullsRecordedTo_ = from; numParentNulls_ = 0; } - if (parentNullsRecordedTo_) { + if (parentNullsRecordedTo_ > 0) { VELOX_CHECK_EQ(parentNullsRecordedTo_, from); } numParentNulls_ += numNulls; diff --git a/velox/dwio/common/SelectiveColumnReader.h b/velox/dwio/common/SelectiveColumnReader.h index 591208b1788..fe727c1ba37 100644 --- a/velox/dwio/common/SelectiveColumnReader.h +++ b/velox/dwio/common/SelectiveColumnReader.h @@ -27,6 +27,10 @@ namespace facebook::velox::dwio::common { +struct ColumnMetrics; + +using ScanSpec = velox::common::ScanSpec; + /// Generalized representation of a set of distinct values for dictionary /// encodings. struct DictionaryValues { @@ -128,6 +132,10 @@ struct ScanState { RawScanState rawState; }; +inline bool isDense(const RowSet& rows) { + return rows.empty() || rows.size() == rows.back() + 1; +} + class SelectiveColumnReader { public: static constexpr uint64_t kStringBufferSize = 16 * 1024; @@ -169,6 +177,14 @@ class SelectiveColumnReader { virtual void read(int64_t offset, const RowSet& rows, const uint64_t* incomingNulls) = 0; + /// Wraps read() to collect decode timing stats for leaf columns. + /// Only times primitive types to avoid double-counting in complex types + /// (struct/map/array) which recursively call children's readWithTiming. + void readWithTiming( + int64_t offset, + const RowSet& rows, + const uint64_t* incomingNulls); + virtual uint64_t skip(uint64_t numValues) { return formatData_->skip(numValues); } @@ -241,7 +257,7 @@ class SelectiveColumnReader { uint64_t* mutableNulls(int32_t size) { if (!resultNulls_->unique()) { resultNulls_ = AlignedBuffer::allocate( - numValues_ + size, memoryPool_, bits::kNotNull); + numValues_ + size, pool_, bits::kNotNull); rawResultNulls_ = resultNulls_->asMutable(); } if (resultNulls_->capacity() * 8 < numValues_ + size) { @@ -482,20 +498,27 @@ class SelectiveColumnReader { return false; } - StringView copyStringValueIfNeed(folly::StringPiece value) { - if (value.size() <= StringView::kInlineSize) { + StringView copyStringValueIfNeed(std::string_view value) { + if (value.size() <= StringView::kInlineSize || + formatData().stringDecoderZeroCopy()) { return StringView(value); } + auto* data = copyStringValue(value); return StringView(data, value.size()); } + void setStringBuffers(std::vector buffers) { + stringBuffers_ = std::move(buffers); + rawStringBuffer_ = nullptr; + } + virtual void setCurrentRowNumber(int64_t /*value*/) { VELOX_UNREACHABLE("Only struct reader supports this method"); } memory::MemoryPool* memoryPool() const { - return memoryPool_; + return pool_; } protected: @@ -581,11 +604,11 @@ class SelectiveColumnReader { // Checks consistency of nulls-related state. const uint64_t* shouldMoveNulls(const RowSet& rows); - void addStringValue(folly::StringPiece value); + void addStringValue(std::string_view value); // Copies 'value' to buffers owned by 'this' and returns the start of the // copy. - char* copyStringValue(folly::StringPiece value); + char* copyStringValue(std::string_view value); virtual bool hasDeletion() const { return false; @@ -620,7 +643,7 @@ class SelectiveColumnReader { return scanSpec_->hasFilter() || hasDeletion(); } - memory::MemoryPool* const memoryPool_; + memory::MemoryPool* const pool_; // The requested data type const TypePtr requestedType_; @@ -636,6 +659,10 @@ class SelectiveColumnReader { // run time based on adaptation. Owned by caller. velox::common::ScanSpec* const scanSpec_; + // Per-column metrics for timing stats. May be nullptr if collection is + // disabled. + ColumnMetrics* columnMetrics_{nullptr}; + // Row number after last read row, relative to the ORC stripe or Parquet // Rowgroup start. int64_t readOffset_ = 0; @@ -732,13 +759,14 @@ class SelectiveColumnReader { }; template <> -inline void SelectiveColumnReader::addValue(const folly::StringPiece value) { +inline void SelectiveColumnReader::addValue(const std::string_view value) { const uint64_t size = value.size(); - if (size <= StringView::kInlineSize) { + if (formatData().stringDecoderZeroCopy() || size <= StringView::kInlineSize) { reinterpret_cast(rawValues_)[numValues_++] = StringView(value.data(), size); return; } + if (rawStringBuffer_ && rawStringUsed_ + size <= rawStringSize_) { memcpy(rawStringBuffer_ + rawStringUsed_, value.data(), size); reinterpret_cast(rawValues_)[numValues_++] = @@ -766,7 +794,7 @@ struct NoHook final : public ValueHook { void addValue(vector_size_t /*row*/, double /*value*/) final {} - void addValue(vector_size_t /*row*/, folly::StringPiece /*value*/) final {} + void addValue(vector_size_t /*row*/, std::string_view /*value*/) final {} }; } // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/SelectiveColumnReaderInternal.h b/velox/dwio/common/SelectiveColumnReaderInternal.h index f491870bfff..a9033b8d10b 100644 --- a/velox/dwio/common/SelectiveColumnReaderInternal.h +++ b/velox/dwio/common/SelectiveColumnReaderInternal.h @@ -21,16 +21,22 @@ #include "velox/dwio/common/DirectDecoder.h" #include "velox/dwio/common/SelectiveColumnReader.h" #include "velox/dwio/common/TypeUtils.h" -#include "velox/exec/AggregationHook.h" #include "velox/type/Timestamp.h" +#include "velox/vector/AggregationHook.h" #include "velox/vector/ConstantVector.h" #include "velox/vector/DictionaryVector.h" #include "velox/vector/FlatVector.h" #include +#include namespace facebook::velox::dwio::common { +/// True for arithmetic types and extended integer types (int128_t, uint128_t). +template +inline constexpr bool kIsNumericScalar = std::is_arithmetic_v || + std::is_same_v || std::is_same_v; + template void SelectiveColumnReader::ensureValuesCapacity( vector_size_t numRows, @@ -40,8 +46,8 @@ void SelectiveColumnReader::ensureValuesCapacity( BaseVector::byteSize(numRows) + simd::kPadding) { return; } - auto newValues = AlignedBuffer::allocate( - numRows + simd::kPadding / sizeof(T), memoryPool_); + auto newValues = + AlignedBuffer::allocate(numRows + simd::kPadding / sizeof(T), pool_); if (preserveData) { std::memcpy( newValues->template asMutable(), rawValues_, values_->capacity()); @@ -98,6 +104,26 @@ void SelectiveColumnReader::getFlatValues( VectorPtr* result, const TypePtr& type, bool isFinal) { + static_assert( + std::is_trivially_copyable_v && std::is_trivially_copyable_v, + "T and TVector must be trivially copyable types"); + + // When T and TVector differ, both must be numeric scalars. This prevents + // accidental cross-domain copies such as Timestamp<->int64_t or + // StringView<->int128_t. Same-size cross-domain conversions (e.g., + // int32_t -> float) would be strict-aliasing violations in + // compactScalarValues; schema validation must reject them before reaching + // here. + if constexpr (!std::is_same_v) { + static_assert( + kIsNumericScalar && kIsNumericScalar, + "Cross-type getFlatValues requires both T and TVector to be numeric"); + static_assert( + sizeof(T) != sizeof(TVector) || + std::is_floating_point_v == std::is_floating_point_v, + "Same-size cross-domain conversions (e.g., int32 -> float) are not " + "supported"); + } VELOX_CHECK_NE(valueSize_, kNoValueSize); VELOX_CHECK(mayGetValues_); if (isFinal) { @@ -111,12 +137,12 @@ void SelectiveColumnReader::getFlatValues( } else { flatMapValueConstantNullValues_ = std::make_shared>( - memoryPool_, rows.size(), true, type, T()); + pool_, rows.size(), true, type, TVector()); } *result = flatMapValueConstantNullValues_; } else { *result = std::make_shared>( - memoryPool_, rows.size(), true, type, T()); + pool_, rows.size(), true, type, TVector()); } return; } @@ -138,7 +164,7 @@ void SelectiveColumnReader::getFlatValues( flat->setStringBuffers(std::move(stringBuffers_)); } else { flatMapValueFlatValues_ = std::make_shared>( - memoryPool_, + pool_, type, resultNulls(), numValues_, @@ -148,7 +174,7 @@ void SelectiveColumnReader::getFlatValues( *result = flatMapValueFlatValues_; } else { *result = std::make_shared>( - memoryPool_, + pool_, type, resultNulls(), numValues_, diff --git a/velox/dwio/common/SelectiveFlatMapColumnReader.cpp b/velox/dwio/common/SelectiveFlatMapColumnReader.cpp index 05bf1a7a991..15ce348fd7f 100644 --- a/velox/dwio/common/SelectiveFlatMapColumnReader.cpp +++ b/velox/dwio/common/SelectiveFlatMapColumnReader.cpp @@ -57,42 +57,16 @@ void SelectiveFlatMapColumnReader::getValues( auto* resultFlatMap = prepareResult(*result, keysVector_, rows.size()); setComplexNulls(rows, *result); - for (const auto& childSpec : scanSpec_->children()) { - VELOX_TRACE_HISTORY_PUSH("getValues %s", childSpec->fieldName().c_str()); - if (!childSpec->keepValues()) { - continue; - } - - VELOX_CHECK( - childSpec->readFromFile(), - "Flatmap children must always be read from file."); - - if (childSpec->subscript() == kConstantChildSpecSubscript) { - continue; - } - - const auto channel = childSpec->channel(); - const auto index = childSpec->subscript(); - auto& childResult = resultFlatMap->mapValuesAt(channel); - - VELOX_CHECK( - !childSpec->deltaUpdate(), - "Delta update not supported in flat map yet"); - VELOX_CHECK( - !childSpec->isConstant(), - "Flat map values cannot be constant in scanSpec."); - VELOX_CHECK_EQ( - childSpec->columnType(), - velox::common::ScanSpec::ColumnType::kRegular, - "Flat map only supports regular column types in scan spec."); - - children_[index]->getValues(rows, &childResult); - - for (size_t i = 0; i < children_.size(); ++i) { - const auto& inMap = inMapBuffer(i); - if (inMap) { - resultFlatMap->inMapsAt(i, true) = inMap; - } + // Loop over column readers + for (int i = 0; i < children_.size(); ++i) { + auto& child = children_[i]; + VectorPtr values; + child->getValues(rows, &values); + resultFlatMap->mapValuesAt(i) = values; + + const auto& inMap = inMapBuffer(i); + if (inMap) { + resultFlatMap->inMapsAt(i, true) = inMap; } } } diff --git a/velox/dwio/common/SelectiveFloatingPointColumnReader.h b/velox/dwio/common/SelectiveFloatingPointColumnReader.h index a096b7e34e1..1f3445918b6 100644 --- a/velox/dwio/common/SelectiveFloatingPointColumnReader.h +++ b/velox/dwio/common/SelectiveFloatingPointColumnReader.h @@ -35,9 +35,11 @@ class SelectiveFloatingPointColumnReader : public SelectiveColumnReader { params, scanSpec) {} - // Offers fast path only if data and result widths match. + // Offers a fast path only if data and result widths match. + static constexpr bool kHasBulkPath = std::is_same_v; + bool hasBulkPath() const override { - return std::is_same_v; + return kHasBulkPath; } template @@ -90,7 +92,7 @@ void SelectiveFloatingPointColumnReader::readHelper( TFilter, ExtractValues, isDense, - std::is_same_v>( + Reader::kHasBulkPath>( *static_cast(filter), this, rows, extractValues)); } diff --git a/velox/dwio/common/SelectiveIntegerColumnReader.h b/velox/dwio/common/SelectiveIntegerColumnReader.h index 679d13dcf5c..d9ba805167a 100644 --- a/velox/dwio/common/SelectiveIntegerColumnReader.h +++ b/velox/dwio/common/SelectiveIntegerColumnReader.h @@ -200,6 +200,13 @@ void SelectiveIntegerColumnReader::processFilter( velox::common::NegatedBigintValuesUsingBitmask, isDense>(filter, rows, extractValues); break; + case velox::common::FilterKind::kBigintValuesUsingBloomFilter: + static_cast(this) + ->template readHelper< + Reader, + velox::common::BigintValuesUsingBloomFilter, + isDense>(filter, rows, extractValues); + break; default: static_cast(this) ->template readHelper( diff --git a/velox/dwio/common/SelectiveRepeatedColumnReader.cpp b/velox/dwio/common/SelectiveRepeatedColumnReader.cpp index 5342e4591e1..b17098d612f 100644 --- a/velox/dwio/common/SelectiveRepeatedColumnReader.cpp +++ b/velox/dwio/common/SelectiveRepeatedColumnReader.cpp @@ -80,7 +80,7 @@ void prepareResult( void SelectiveRepeatedColumnReader::ensureAllLengthsBuffer(vector_size_t size) { if (!allLengthsHolder_ || allLengthsHolder_->capacity() < size * sizeof(vector_size_t)) { - allLengthsHolder_ = allocateIndices(size, memoryPool_); + allLengthsHolder_ = allocateIndices(size, pool_); allLengths_ = allLengthsHolder_->asMutable(); } } @@ -209,34 +209,69 @@ RowSet SelectiveRepeatedColumnReader::applyFilter(const RowSet& rows) { return outputRows_; } +void SelectiveRepeatedColumnReader::getExtractionSizeValues( + const RowSet& rows, + VectorPtr* result) { + VELOX_DCHECK_NOT_NULL(result); + FlatVector* flatResult = nullptr; + if (*result && result->get()->type()->isBigint()) { + flatResult = result->get()->asFlatVector(); + } + if (!flatResult || !flatResult->values()) { + *result = std::make_shared>( + pool_, + BIGINT(), + nullptr, + rows.size(), + AlignedBuffer::allocate(rows.size(), pool_), + std::vector{}); + flatResult = result->get()->asFlatVector(); + } else { + flatResult->resize(static_cast(rows.size())); + } + auto* sizesData = flatResult->mutableRawValues(); + auto* nulls = nullsInReadRange_ ? nullsInReadRange_->as() : nullptr; + for (vector_size_t i = 0; i < static_cast(rows.size()); ++i) { + sizesData[i] = + (nulls && bits::isBitNull(nulls, rows[i])) ? 0 : allLengths_[rows[i]]; + } + setComplexNulls(rows, *result); +} + SelectiveListColumnReader::SelectiveListColumnReader( const TypePtr& requestedType, const std::shared_ptr& fileType, FormatParams& params, velox::common::ScanSpec& scanSpec) : SelectiveRepeatedColumnReader(requestedType, params, scanSpec, fileType) { + VELOX_CHECK( + scanSpec.extractionType() == + velox::common::ScanSpec::ExtractionType::kNone || + scanSpec.extractionType() == + velox::common::ScanSpec::ExtractionType::kSize, + "Array column reader only supports kNone and kSize extraction, got: {}", + static_cast(scanSpec.extractionType())); } uint64_t SelectiveListColumnReader::skip(uint64_t numValues) { numValues = formatData_->skipNulls(numValues); - if (child_) { - std::array buffer; - uint64_t childElements = 0; - uint64_t lengthsRead = 0; - while (lengthsRead < numValues) { - uint64_t chunk = - std::min(numValues - lengthsRead, static_cast(kBufferSize)); - readLengths(buffer.data(), chunk, nullptr); - for (size_t i = 0; i < chunk; ++i) { - childElements += static_cast(buffer[i]); - } - lengthsRead += chunk; + std::array buffer{}; + uint64_t childElements = 0; + uint64_t lengthsRead = 0; + while (lengthsRead < numValues) { + uint64_t chunk = + std::min(numValues - lengthsRead, static_cast(kBufferSize)); + readLengths(buffer.data(), static_cast(chunk), nullptr); + for (size_t i = 0; i < chunk; ++i) { + childElements += static_cast(buffer[i]); } - child_->seekTo(child_->readOffset() + childElements, false); - childTargetReadOffset_ += childElements; - } else { - VELOX_FAIL("Repeated reader with no children"); + lengthsRead += chunk; + } + if (child_) { + child_->seekTo( + child_->readOffset() + static_cast(childElements), false); } + childTargetReadOffset_ += static_cast(childElements); return numValues; } @@ -245,15 +280,27 @@ void SelectiveListColumnReader::read( const RowSet& rows, const uint64_t* incomingNulls) { // Catch up if the child is behind the length stream. - child_->seekTo(childTargetReadOffset_, false); + if (child_) { + child_->seekTo(childTargetReadOffset_, false); + } prepareRead(offset, rows, incomingNulls); auto activeRows = applyFilter(rows); nestedRowsAllSelected_ = activeRows.size() == rows.back() + 1 && scanSpec_->maxArrayElementsCount() == std::numeric_limits::max(); makeNestedRowSet(activeRows, rows.back()); - if (child_ && !nestedRows_.empty()) { - child_->read(child_->readOffset(), nestedRows_, nullptr); + // When deltaUpdate is set, treat extractionType as kNone so all child + // streams are read. The extraction transform is applied after the + // delta update. + if (scanSpec_->extractionType() == + velox::common::ScanSpec::ExtractionType::kSize && + !scanSpec_->deltaUpdate()) { + // Size extraction: only need offsets/sizes, skip child stream. + if (child_ && !nestedRows_.empty()) { + child_->seekTo(child_->readOffset() + nestedRows_.back() + 1, false); + } + } else if (child_ && !nestedRows_.empty()) { + child_->readWithTiming(child_->readOffset(), nestedRows_, nullptr); nestedRowsAllSelected_ = nestedRowsAllSelected_ && nestedRows_.size() == child_->outputRows().size(); nestedRows_ = child_->outputRows(); @@ -266,7 +313,18 @@ void SelectiveListColumnReader::getValues( const RowSet& rows, VectorPtr* result) { VELOX_DCHECK_NOT_NULL(result); - prepareResult(*result, requestedType_, rows.size(), memoryPool_); + + // When deltaUpdate is set, treat extractionType as kNone so the reader + // produces the full array. The extraction transform is applied after + // the delta update. + if (scanSpec_->extractionType() == + velox::common::ScanSpec::ExtractionType::kSize && + !scanSpec_->deltaUpdate()) { + getExtractionSizeValues(rows, result); + return; + } + + prepareResult(*result, requestedType_, rows.size(), pool_); auto* resultArray = result->get()->asUnchecked(); makeOffsetsAndSizes(rows, *resultArray); setComplexNulls(rows, *result); @@ -277,48 +335,40 @@ void SelectiveListColumnReader::getValues( } } -SelectiveMapColumnReader::SelectiveMapColumnReader( - const TypePtr& requestedType, - const std::shared_ptr& fileType, - FormatParams& params, - velox::common::ScanSpec& scanSpec) - : SelectiveRepeatedColumnReader(requestedType, params, scanSpec, fileType) { -} - -uint64_t SelectiveMapColumnReader::skip(uint64_t numValues) { +uint64_t SelectiveMapColumnReaderBase::skip(uint64_t numValues) { numValues = formatData_->skipNulls(numValues); - if (keyReader_ || elementReader_) { - std::array buffer; - uint64_t childElements{0}; - uint64_t lengthsRead{0}; - while (lengthsRead < numValues) { - const uint64_t chunk = - std::min(numValues - lengthsRead, static_cast(kBufferSize)); - readLengths(buffer.data(), chunk, nullptr); - for (size_t i = 0; i < chunk; ++i) { - childElements += buffer[i]; - } - lengthsRead += chunk; - } - - if (keyReader_) { - keyReader_->seekTo(keyReader_->readOffset() + childElements, false); - } - if (elementReader_) { - elementReader_->seekTo( - elementReader_->readOffset() + childElements, false); + std::array buffer; + uint64_t childElements{0}; + uint64_t lengthsRead{0}; + while (lengthsRead < numValues) { + const uint64_t chunk = + std::min(numValues - lengthsRead, static_cast(kBufferSize)); + readLengths(buffer.data(), chunk, nullptr); + for (size_t i = 0; i < chunk; ++i) { + childElements += buffer[i]; } - childTargetReadOffset_ += childElements; - } else { - VELOX_FAIL("repeated reader with no children"); + lengthsRead += chunk; + } + if (keyReader_) { + keyReader_->seekTo(keyReader_->readOffset() + childElements, false); } + if (elementReader_) { + elementReader_->seekTo(elementReader_->readOffset() + childElements, false); + } + childTargetReadOffset_ += childElements; return numValues; } -void SelectiveMapColumnReader::read( +void SelectiveMapColumnReaderBase::read( int64_t offset, const RowSet& rows, const uint64_t* incomingNulls) { + // When deltaUpdate is set, treat extractionType as kNone so all streams + // are read. The extraction transform is applied after the delta update. + const auto extractionType = scanSpec_->deltaUpdate() + ? velox::common::ScanSpec::ExtractionType::kNone + : scanSpec_->extractionType(); + // Catch up if child readers are behind the length stream. if (keyReader_) { keyReader_->seekTo(childTargetReadOffset_, false); @@ -334,31 +384,167 @@ void SelectiveMapColumnReader::read( scanSpec_->maxArrayElementsCount(), std::numeric_limits::max()); makeNestedRowSet(activeRows, rows.back()); - if (keyReader_ && elementReader_ && !nestedRows_.empty()) { - keyReader_->read(keyReader_->readOffset(), nestedRows_, nullptr); - nestedRowsAllSelected_ = nestedRowsAllSelected_ && - nestedRows_.size() == keyReader_->outputRows().size(); - nestedRows_ = keyReader_->outputRows(); - if (!nestedRows_.empty()) { - elementReader_->read(elementReader_->readOffset(), nestedRows_, nullptr); + + if (extractionType == velox::common::ScanSpec::ExtractionType::kSize) { + // Size extraction: only need offsets/sizes, skip both key and value + // streams. Advance children past the nested rows without reading. + if (keyReader_ && !nestedRows_.empty()) { + keyReader_->seekTo( + keyReader_->readOffset() + nestedRows_.back() + 1, false); + } + if (elementReader_ && !nestedRows_.empty()) { + elementReader_->seekTo( + elementReader_->readOffset() + nestedRows_.back() + 1, false); + } + } else if (extractionType == velox::common::ScanSpec::ExtractionType::kKeys) { + // Keys extraction: read only keys, skip values. + if (keyReader_ && !nestedRows_.empty()) { + keyReader_->readWithTiming( + keyReader_->readOffset(), nestedRows_, nullptr); + nestedRowsAllSelected_ = nestedRowsAllSelected_ && + nestedRows_.size() == keyReader_->outputRows().size(); + nestedRows_ = keyReader_->outputRows(); + } + if (elementReader_ && !nestedRows_.empty()) { + elementReader_->seekTo( + elementReader_->readOffset() + nestedRows_.back() + 1, false); + } + } else if ( + extractionType == velox::common::ScanSpec::ExtractionType::kValues) { + // Values extraction: read only values, skip keys. + if (keyReader_ && !nestedRows_.empty()) { + keyReader_->seekTo( + keyReader_->readOffset() + nestedRows_.back() + 1, false); + } + if (elementReader_ && !nestedRows_.empty()) { + elementReader_->readWithTiming( + elementReader_->readOffset(), nestedRows_, nullptr); nestedRowsAllSelected_ = nestedRowsAllSelected_ && nestedRows_.size() == elementReader_->outputRows().size(); nestedRows_ = elementReader_->outputRows(); } + } else { + // Normal read: read both keys and values. + VELOX_CHECK_EQ( + static_cast(extractionType), + static_cast(velox::common::ScanSpec::ExtractionType::kNone)); + if (keyReader_ && elementReader_ && !nestedRows_.empty()) { + keyReader_->readWithTiming( + keyReader_->readOffset(), nestedRows_, nullptr); + nestedRowsAllSelected_ = nestedRowsAllSelected_ && + nestedRows_.size() == keyReader_->outputRows().size(); + nestedRows_ = keyReader_->outputRows(); + if (!nestedRows_.empty()) { + elementReader_->readWithTiming( + elementReader_->readOffset(), nestedRows_, nullptr); + nestedRowsAllSelected_ = nestedRowsAllSelected_ && + nestedRows_.size() == elementReader_->outputRows().size(); + nestedRows_ = elementReader_->outputRows(); + } + } } numValues_ = activeRows.size(); readOffset_ = offset + rows.back() + 1; } +SelectiveMapColumnReader::SelectiveMapColumnReader( + const TypePtr& requestedType, + const TypeWithIdPtr& fileType, + FormatParams& params, + ScanSpec& scanSpec) + : SelectiveMapColumnReaderBase(requestedType, params, scanSpec, fileType) { + VELOX_CHECK(!scanSpec_->isFlatMapAsStruct()); + // We should not need this anymore. Is there a safe way to find out if there + // is any prod usages that forget to set up the map children in scan spec? + // This should be only possible when user bypasses the connector interface and + // create file readers directly. + if (scanSpec_->children().empty()) { + scanSpec_->getOrCreateChild(ScanSpec::kMapKeysFieldName); + scanSpec_->getOrCreateChild(ScanSpec::kMapValuesFieldName); + } + scanSpec_->children()[0]->setProjectOut(true); + scanSpec_->children()[1]->setProjectOut(true); +} + +void SelectiveMapColumnReaderBase::getExtractionValues( + const RowSet& rows, + VectorPtr* result) { + VELOX_DCHECK_NOT_NULL(result); + const auto extractionType = scanSpec_->extractionType(); + VELOX_DCHECK_NE( + static_cast(extractionType), + static_cast(velox::common::ScanSpec::ExtractionType::kNone)); + + // When deltaUpdate is set, treat extractionType as kNone so the reader + // produces the full map. The extraction transform is applied after + // the delta update. + if (extractionType == velox::common::ScanSpec::ExtractionType::kSize && + !scanSpec_->deltaUpdate()) { + getExtractionSizeValues(rows, result); + return; + } + + // kKeys or kValues: compute offsets/sizes via a reusable MapVector, + // then read elements and construct the output ArrayVector. + prepareResult( + extractionOffsetsTemp_, + requestedType_, + static_cast(rows.size()), + pool_); + auto* tempMap = extractionOffsetsTemp_->asUnchecked(); + makeOffsetsAndSizes(rows, *tempMap); + setComplexNulls(rows, extractionOffsetsTemp_); + + // Extract elements from the existing result to reuse across batches. + VectorPtr elements; + if (*result && result->get()->encoding() == VectorEncoding::Simple::ARRAY) { + elements = result->get()->asUnchecked()->elements(); + } + if (extractionType == velox::common::ScanSpec::ExtractionType::kKeys) { + if (!nestedRows_.empty()) { + keyReader_->getValues(nestedRows_, &elements); + } + } else { + if (!nestedRows_.empty()) { + prepareStructResult(requestedType_->childAt(1), &elements); + elementReader_->getValues(nestedRows_, &elements); + } + } + auto elemType = elements + ? elements->type() + : requestedType_->childAt( + extractionType == velox::common::ScanSpec::ExtractionType::kKeys + ? 0 + : 1); + *result = std::make_shared( + pool_, + ARRAY(elemType), + tempMap->nulls(), + rows.size(), + tempMap->offsets(), + tempMap->sizes(), + elements); +} + void SelectiveMapColumnReader::getValues( const RowSet& rows, VectorPtr* result) { VELOX_DCHECK_NOT_NULL(result); - VELOX_CHECK( - !result->get() || result->get()->type()->isMap(), - "Expect MAP result vector, got {}", - result->get()->type()->toString()); - prepareResult(*result, requestedType_, rows.size(), memoryPool_); + const auto extractionType = scanSpec_->extractionType(); + + // When deltaUpdate is set, treat extractionType as kNone so the reader + // produces the full map. The extraction transform is applied after + // the delta update. + if (extractionType != velox::common::ScanSpec::ExtractionType::kNone && + !scanSpec_->deltaUpdate()) { + getExtractionValues(rows, result); + return; + } + + // Normal path: produce MapVector. If the result has a non-MAP type + // (e.g., from a previous extraction transform), prepareResult will + // replace it with a fresh MapVector. + prepareResult(*result, requestedType_, rows.size(), pool_); auto* resultMap = result->get()->asUnchecked(); makeOffsetsAndSizes(rows, *resultMap); setComplexNulls(rows, *result); @@ -374,4 +560,114 @@ void SelectiveMapColumnReader::getValues( } } +SelectiveMapAsStructColumnReader::SelectiveMapAsStructColumnReader( + const TypePtr& requestedType, + const TypeWithIdPtr& fileType, + FormatParams& params, + ScanSpec& scanSpec) + : SelectiveMapColumnReaderBase(requestedType, params, scanSpec, fileType) { + VELOX_CHECK(scanSpec_->isFlatMapAsStruct() && requestedType_->isMap()); + VELOX_CHECK_EQ( + static_cast(scanSpec.extractionType()), + static_cast(velox::common::ScanSpec::ExtractionType::kNone), + "Flat map as struct reader does not support extraction pushdown"); + mapScanSpec_.addMapKeyFieldRecursively(*requestedType_->childAt(0)); + mapScanSpec_.addMapValueFieldRecursively(*requestedType_->childAt(1)); + column_index_t maxChannel = 0; + for (auto& childSpec : scanSpec_->children()) { + auto field = folly::tryTo(childSpec->fieldName()); + VELOX_CHECK( + field.hasValue(), + "Fail to parse field name: {}", + childSpec->fieldName()); + keyToIndex_[*field] = childSpec->channel(); + maxChannel = std::max(maxChannel, childSpec->channel()); + } + copyRanges_.resize(maxChannel + 1); +} + +void SelectiveMapAsStructColumnReader::getValues( + const RowSet& rows, + VectorPtr* result) { + VELOX_CHECK_NOT_NULL(*result); + VELOX_CHECK( + result->get()->type()->isRow(), + "Expect ROW, got {}", + result->get()->type()->toString()); + BaseVector::prepareForReuse(*result, rows.size()); + auto* resultRow = result->get()->asChecked(); + setComplexNulls(rows, *result); + for (auto& child : resultRow->children()) { + bits::fillBits(child->mutableRawNulls(), 0, rows.size(), bits::kNull); + } + numValues_ = rows.size(); + if (nestedRows_.empty()) { + return; + } + keyReader_->getValues(nestedRows_, &mapKeys_); + prepareStructResult(requestedType_->childAt(1), &mapValues_); + elementReader_->getValues(nestedRows_, &mapValues_); + decodedKeys_.decode(*mapKeys_); + for (auto& ranges : copyRanges_) { + ranges.clear(); + } + switch (mapKeys_->type()->kind()) { + case TypeKind::TINYINT: + makeCopyRanges(rows); + break; + case TypeKind::SMALLINT: + makeCopyRanges(rows); + break; + case TypeKind::INTEGER: + makeCopyRanges(rows); + break; + case TypeKind::BIGINT: + makeCopyRanges(rows); + break; + default: + VELOX_UNSUPPORTED( + "Unsupported key type: {}", mapKeys_->type()->toString()); + } + for (column_index_t i = 0; i < resultRow->childrenSize(); ++i) { + resultRow->childAt(i)->copyRanges(mapValues_.get(), copyRanges_[i]); + } +} + +template +void SelectiveMapAsStructColumnReader::makeCopyRanges(const RowSet& rows) { + auto* nulls = nullsInReadRange_ ? nullsInReadRange_->as() : nullptr; + for (vector_size_t i = 0, + currentOffset = 0, + currentRow = 0, + nestedRowIndex = 0; + i < rows.size(); + ++i) { + const auto row = rows[i]; + if (nulls && bits::isBitNull(nulls, row)) { + anyNulls_ = true; + continue; + } + currentOffset += sumLengths(allLengths_, nulls, currentRow, row); + currentRow = row + 1; + nestedRowIndex = + advanceNestedRows(nestedRows_, nestedRowIndex, currentOffset); + currentOffset += allLengths_[row]; + const auto newNestedRowIndex = + advanceNestedRows(nestedRows_, nestedRowIndex, currentOffset); + for (auto j = nestedRowIndex; j < newNestedRowIndex; ++j) { + VELOX_CHECK(!decodedKeys_.isNullAt(j)); + auto it = keyToIndex_.find(decodedKeys_.valueAt(j)); + if (it == keyToIndex_.end()) { + continue; + } + copyRanges_[it->second].push_back({ + .sourceIndex = j, + .targetIndex = i, + .count = 1, + }); + } + nestedRowIndex = newNestedRowIndex; + } +} + } // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/SelectiveRepeatedColumnReader.h b/velox/dwio/common/SelectiveRepeatedColumnReader.h index f1792ff9511..985a240bb06 100644 --- a/velox/dwio/common/SelectiveRepeatedColumnReader.h +++ b/velox/dwio/common/SelectiveRepeatedColumnReader.h @@ -42,7 +42,7 @@ class SelectiveRepeatedColumnReader : public SelectiveColumnReader { velox::common::ScanSpec& scanSpec, std::shared_ptr type) : SelectiveColumnReader(requestedType, std::move(type), params, scanSpec), - nestedRowsHolder_(memoryPool_) {} + nestedRowsHolder_(pool_) {} /// Reads 'numLengths' next lengths into 'result'. If 'nulls' is /// non-null, each kNull bit signifies a null with a length of 0 to @@ -62,7 +62,7 @@ class SelectiveRepeatedColumnReader : public SelectiveColumnReader { /// Creates a struct if '*result' is empty and 'type' is a row. void prepareStructResult(const TypePtr& type, VectorPtr* result) { if (!*result && type->kind() == TypeKind::ROW) { - *result = BaseVector::create(type, 0, memoryPool_); + *result = BaseVector::create(type, 0, pool_); } } @@ -85,6 +85,10 @@ class SelectiveRepeatedColumnReader : public SelectiveColumnReader { return i; } + /// Produce a FlatVector of sizes directly from allLengths_ for + /// kSize extraction. Reuses the result vector across batches. + void getExtractionSizeValues(const RowSet& rows, VectorPtr* result); + void ensureAllLengthsBuffer(vector_size_t size); BufferPtr allLengthsHolder_; @@ -112,7 +116,9 @@ class SelectiveListColumnReader : public SelectiveRepeatedColumnReader { velox::common::ScanSpec& scanSpec); void resetFilterCaches() override { - child_->resetFilterCaches(); + if (child_) { + child_->resetFilterCaches(); + } } uint64_t skip(uint64_t numValues) override; @@ -128,17 +134,17 @@ class SelectiveListColumnReader : public SelectiveRepeatedColumnReader { std::unique_ptr child_; }; -class SelectiveMapColumnReader : public SelectiveRepeatedColumnReader { +class SelectiveMapColumnReaderBase : public SelectiveRepeatedColumnReader { public: - SelectiveMapColumnReader( - const TypePtr& requestedType, - const std::shared_ptr& fileType, - FormatParams& params, - velox::common::ScanSpec& scanSpec); + using SelectiveRepeatedColumnReader::SelectiveRepeatedColumnReader; void resetFilterCaches() override { - keyReader_->resetFilterCaches(); - elementReader_->resetFilterCaches(); + if (keyReader_) { + keyReader_->resetFilterCaches(); + } + if (elementReader_) { + elementReader_->resetFilterCaches(); + } } uint64_t skip(uint64_t numValues) override; @@ -146,11 +152,64 @@ class SelectiveMapColumnReader : public SelectiveRepeatedColumnReader { void read(int64_t offset, const RowSet& rows, const uint64_t* incomingNulls) override; - void getValues(const RowSet& rows, VectorPtr* result) override; + void seekToRowGroup(int64_t index) override { + SelectiveRepeatedColumnReader::seekToRowGroup(index); + if (keyReader_) { + keyReader_->seekToRowGroup(index); + keyReader_->setReadOffsetRecursive(0); + } + if (elementReader_) { + elementReader_->seekToRowGroup(index); + elementReader_->setReadOffsetRecursive(0); + } + childTargetReadOffset_ = 0; + } protected: + /// Handle extraction types (kSize, kKeys, kValues) in getValues. + void getExtractionValues(const RowSet& rows, VectorPtr* result); + std::unique_ptr keyReader_; std::unique_ptr elementReader_; + + // Reusable MapVector for computing offsets/sizes in kKeys/kValues extraction. + // Not needed for kSize (sizes computed directly from allLengths_). + VectorPtr extractionOffsetsTemp_; +}; + +class SelectiveMapColumnReader : public SelectiveMapColumnReaderBase { + public: + SelectiveMapColumnReader( + const TypePtr& requestedType, + const TypeWithIdPtr& fileType, + FormatParams& params, + ScanSpec& scanSpec); + + void getValues(const RowSet& rows, VectorPtr* result) override; +}; + +class SelectiveMapAsStructColumnReader : public SelectiveMapColumnReaderBase { + public: + SelectiveMapAsStructColumnReader( + const TypePtr& requestedType, + const TypeWithIdPtr& fileType, + FormatParams& params, + ScanSpec& scanSpec); + + void getValues(const RowSet& rows, VectorPtr* result) override; + + protected: + ScanSpec mapScanSpec_{""}; + + private: + template + void makeCopyRanges(const RowSet& rows); + + folly::F14FastMap keyToIndex_; + std::vector> copyRanges_; + VectorPtr mapKeys_; + VectorPtr mapValues_; + DecodedVector decodedKeys_; }; } // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/SelectiveStructColumnReader.cpp b/velox/dwio/common/SelectiveStructColumnReader.cpp index 7c0e9b208f9..e453fa03f07 100644 --- a/velox/dwio/common/SelectiveStructColumnReader.cpp +++ b/velox/dwio/common/SelectiveStructColumnReader.cpp @@ -407,7 +407,6 @@ void SelectiveStructColumnReaderBase::read( } const uint64_t* structNulls = nulls(); - // A struct reader may have a null/non-null filter if (scanSpec_->filter()) { const auto kind = scanSpec_->filter()->kind(); @@ -429,7 +428,6 @@ void SelectiveStructColumnReaderBase::read( VELOX_CHECK(!childSpecs.empty()); for (size_t i = 0; i < childSpecs.size(); ++i) { const auto& childSpec = childSpecs[i]; - VELOX_TRACE_HISTORY_PUSH("read %s", childSpec->fieldName().c_str()); if (childSpec->deltaUpdate()) { // Will make LazyVector. @@ -453,7 +451,7 @@ void SelectiveStructColumnReaderBase::read( auto* reader = children_.at(fieldIndex); if (reader->isTopLevel() && childSpec->projectOut() && !childSpec->hasFilter() && generateLazyChildren_) { - // Will make a LazyVector. + // Will make a LazyVector (with or without transform). continue; } @@ -463,7 +461,7 @@ void SelectiveStructColumnReaderBase::read( SelectivityTimer timer(childSpec->selectivity(), activeRows.size()); reader->resetInitTimeClocks(); - reader->read(offset, activeRows, structNulls); + reader->readWithTiming(offset, activeRows, structNulls); // Exclude initialization time. timer.subtract(reader->initTimeClocks()); @@ -475,7 +473,7 @@ void SelectiveStructColumnReaderBase::read( break; } } else { - reader->read(offset, activeRows, structNulls); + reader->readWithTiming(offset, activeRows, structNulls); } } @@ -531,12 +529,59 @@ bool SelectiveStructColumnReaderBase::isChildMissing( childSpec.channel() >= fileType_->size()); } +std::unique_ptr +SelectiveStructColumnReaderBase::makeColumnLoader(vector_size_t index) { + // Check if the child at this index has a transform with kNone extraction. + // If so, return a TransformColumnLoader to apply the transform lazily. + for (const auto& childSpec : scanSpec_->children()) { + if (childSpec->subscript() == index && childSpec->hasTransform() && + childSpec->extractionType() == + velox::common::ScanSpec::ExtractionType::kNone) { + return std::make_unique( + this, children_[index], numReads_, childSpec->transform()); + } + } + return std::make_unique( + this, children_[index], numReads_); +} + void SelectiveStructColumnReaderBase::getValues( const RowSet& rows, VectorPtr* result) { VELOX_CHECK(!scanSpec_->children().empty()); VELOX_CHECK_NOT_NULL( *result, "SelectiveStructColumnReaderBase expects a non-null result"); + + // When deltaUpdate is set, skip kField extraction so the reader produces + // the full struct. The extraction transform is applied after the delta + // update. + if (!isRoot_ && + scanSpec_->extractionType() == + velox::common::ScanSpec::ExtractionType::kField && + !scanSpec_->deltaUpdate()) { + auto fieldIdx = scanSpec_->extractionFieldIndex(); + for (const auto& childSpec : scanSpec_->children()) { + if (childSpec->channel() == fieldIdx && !childSpec->isConstant()) { + auto index = static_cast(childSpec->subscript()); + if (childSpec->hasFilter() || !children_[index]->isTopLevel() || + !generateLazyChildren_) { + children_[index]->getValues(rows, result); + } else { + // Lazy loading: create a LazyVector for the extracted field. + setOutputRowsForLazy(rows); + setLazyField( + makeColumnLoader(index), + children_[index]->requestedType(), + static_cast(rows.size()), + pool_, + *result); + } + return; + } + } + VELOX_UNREACHABLE(); + } + VELOX_CHECK( result->get()->type()->isRow(), "Struct reader expects a result of type ROW."); @@ -550,7 +595,6 @@ void SelectiveStructColumnReaderBase::getValues( setComplexNulls(rows, *result); for (const auto& childSpec : scanSpec_->children()) { - VELOX_TRACE_HISTORY_PUSH("getValues %s", childSpec->fieldName().c_str()); if (!childSpec->keepValues()) { continue; } @@ -574,9 +618,18 @@ void SelectiveStructColumnReaderBase::getValues( this, children_[index], numReads_), resultRow->type()->childAt(channel), rows.size(), - memoryPool_, + pool_, childResult); } + // If the column also has an extraction transform (e.g., MapKeys on a + // MAP_CONCAT delta-updated column), apply it after the delta update. + // The delta update modifies the column (e.g., MAP_CONCAT adds entries), + // and extraction should see the updated data. + if (childSpec->hasTransform() && childResult) { + // Force-load lazy vectors so the transform can process them. + childResult = BaseVector::loadedVectorShared(childResult); + childResult = childSpec->transform()(childResult, pool_); + } continue; } @@ -615,13 +668,16 @@ void SelectiveStructColumnReaderBase::getValues( // LazyVector result. setOutputRowsForLazy(rows); + // When the child has a transform (e.g., extraction pushdown), the lazy + // vector type is the transform's output type, not the file column type. + auto lazyType = + (childSpec->hasTransform() && childSpec->transformOutputType()) + ? childSpec->transformOutputType() + : resultRow->type()->childAt(channel); setLazyField( - std::make_unique(this, children_[index], numReads_), - resultRow->type()->childAt(channel), - rows.size(), - memoryPool_, - childResult); + makeColumnLoader(index), lazyType, rows.size(), pool_, childResult); } + resultRow->updateContainsLazyNotLoaded(); } diff --git a/velox/dwio/common/SelectiveStructColumnReader.h b/velox/dwio/common/SelectiveStructColumnReader.h index e867caa80ee..9d106a40284 100644 --- a/velox/dwio/common/SelectiveStructColumnReader.h +++ b/velox/dwio/common/SelectiveStructColumnReader.h @@ -20,6 +20,8 @@ namespace facebook::velox::dwio::common { +class ColumnLoader; + template class SelectiveFlatMapColumnReaderHelper; @@ -108,7 +110,7 @@ class SelectiveStructColumnReaderBase : public SelectiveColumnReader { // The subscript of childSpecs will be set to this value if the column is // constant (either explicitly or because it's missing). - static constexpr int32_t kConstantChildSpecSubscript = -1; + static constexpr int32_t kConstantChildSpecSubscript{-1}; SelectiveStructColumnReaderBase( const TypePtr& requestedType, @@ -122,7 +124,7 @@ class SelectiveStructColumnReaderBase : public SelectiveColumnReader { getExceptionContext().message(VeloxException::Type::kSystem)), isRoot_(isRoot), generateLazyChildren_(generateLazyChildren), - rows_(memoryPool_) {} + rows_(pool_) {} bool hasDeletion() const final { return hasDeletion_; @@ -161,6 +163,12 @@ class SelectiveStructColumnReaderBase : public SelectiveColumnReader { const int64_t offset, const int32_t rowsPerRowGroup); + virtual std::unique_ptr makeColumnLoader( + vector_size_t index); + + // Sequence number of output batch. Checked against ColumnLoaders + // created by 'this' to verify they are still valid at load. + uint64_t numReads_ = 0; std::vector children_; private: @@ -189,13 +197,9 @@ class SelectiveStructColumnReaderBase : public SelectiveColumnReader { // Dense set of rows to read in next(). raw_vector rows_; - // Sequence number of output batch. Checked against ColumnLoaders - // created by 'this' to verify they are still valid at load. - uint64_t numReads_ = 0; - int64_t lazyVectorReadOffset_; - int64_t currentRowNumber_ = -1; + int64_t currentRowNumber_{-1}; const Mutation* mutation_ = nullptr; @@ -245,7 +249,7 @@ class SelectiveFlatMapColumnReaderHelper { reader_.children_[i]->setIsFlatMapValue(true); } if (auto type = reader_.requestedType_->childAt(1); type->isRow()) { - childValues_ = BaseVector::create(type, 0, reader_.memoryPool_); + childValues_ = BaseVector::create(type, 0, reader_.pool_); } } @@ -261,8 +265,7 @@ class SelectiveFlatMapColumnReaderHelper { result->resize(size); } else { VLOG(1) << "Reallocating result MAP vector of size " << size; - result = - BaseVector::create(reader_.requestedType_, size, reader_.memoryPool_); + result = BaseVector::create(reader_.requestedType_, size, reader_.pool_); } return *result->asUnchecked(); } @@ -335,7 +338,7 @@ void SelectiveFlatMapColumnReaderHelper::read( reader_.advanceFieldReader(child, offset); } for (auto* child : reader_.children_) { - child->read(offset, activeRows, mapNulls); + child->readWithTiming(offset, activeRows, mapNulls); child->addParentNulls(offset, mapNulls, rows); } reader_.lazyVectorReadOffset_ = offset; @@ -498,7 +501,7 @@ void SelectiveFlatMapColumnReaderHelper::copyValues( } } if (strKeySize > 0) { - auto buf = AlignedBuffer::allocate(strKeySize, reader_.memoryPool_); + auto buf = AlignedBuffer::allocate(strKeySize, reader_.pool_); rawStrKeyBuffer = buf->template asMutable(); flatKeys->addStringBuffer(buf); strKeySize = 0; diff --git a/velox/dwio/common/SortingWriter.cpp b/velox/dwio/common/SortingWriter.cpp index d67efd6ec22..9e973106739 100644 --- a/velox/dwio/common/SortingWriter.cpp +++ b/velox/dwio/common/SortingWriter.cpp @@ -86,11 +86,11 @@ bool SortingWriter::finish() { return true; } -void SortingWriter::close() { +std::unique_ptr SortingWriter::close() { VELOX_CHECK(isFinishing()); setState(State::kClosed); VELOX_CHECK_NULL(sortBuffer_); - outputWriter_->close(); + return outputWriter_->close(); } void SortingWriter::abort() { @@ -114,10 +114,14 @@ uint64_t SortingWriter::reclaim( if (!isRunning() && !isFinishing()) { LOG(WARNING) << "Can't reclaim from a not running hive sort writer pool: " - << sortPool_->name() << ", state: " << state() - << "used memory: " << succinctBytes(sortPool_->usedBytes()) - << ", reserved memory: " - << succinctBytes(sortPool_->reservedBytes()); + << sortPool_->name() + << ", root pool: " << sortPool_->root()->name() + << ", state: " << state() + << ", used: " << succinctBytes(sortPool_->usedBytes()) + << ", reservation: " + << succinctBytes(sortPool_->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(sortPool_->root()->reservedBytes()); ++stats.numNonReclaimableAttempts; return 0; } diff --git a/velox/dwio/common/SortingWriter.h b/velox/dwio/common/SortingWriter.h index a136cff1238..c7a191a30a5 100644 --- a/velox/dwio/common/SortingWriter.h +++ b/velox/dwio/common/SortingWriter.h @@ -42,7 +42,9 @@ class SortingWriter : public Writer { /// be flushed. void flush() override; - void close() override; + /// Closes the writer. Returns file metadata, or null if no metadata is + /// available (e.g. for an empty file). + std::unique_ptr close() override; void abort() override; diff --git a/velox/dwio/common/Statistics.h b/velox/dwio/common/Statistics.h index 1c6965d6d71..ff56ea640e1 100644 --- a/velox/dwio/common/Statistics.h +++ b/velox/dwio/common/Statistics.h @@ -16,12 +16,21 @@ #pragma once +#include #include +#include #include +#include +#include +#include "velox/common/time/CpuWallTimer.h" +#include "velox/dwio/common/Options.h" +#include "velox/dwio/common/TypeWithId.h" +#include "velox/dwio/common/UnitLoader.h" #include "velox/common/base/Exceptions.h" #include "velox/common/base/RuntimeMetrics.h" -#include "velox/dwio/common/exception/Exception.h" +#include "velox/common/io/IoStatistics.h" +#include "velox/type/Type.h" namespace facebook::velox::dwio::common { @@ -135,6 +144,12 @@ class ColumnStatistics { numDistinct_ = count; } + /// Returns true if there are no non-null values (value count is known to be + /// zero). + bool isAllNull() const { + return valueCount_.has_value() && valueCount_.value() == 0; + } + /** * return string representation of this stats object */ @@ -493,10 +508,11 @@ class MapColumnStatistics : public virtual ColumnStatistics { values.reserve(entryStatistics_.size()); for (const auto& entry : entryStatistics_) { auto& stats = *entry.second; - values.push_back(fmt::format( - "{{ Key: {}, Stats: {},}}", - entry.first.toString(), - stats.toString())); + values.push_back( + fmt::format( + "{{ Key: {}, Stats: {},}}", + entry.first.toString(), + stats.toString())); } std::string repr; folly::join(",", values, repr); @@ -532,10 +548,192 @@ class Statistics { virtual uint32_t getNumberOfColumns() const = 0; }; +/// Runs 'func' and records decompression stats if 'counter' is non-null. +template +auto withDecompressStats(io::IoCounter* counter, F&& func) + -> std::enable_if_t, decltype(func())> { + if (counter) { + DeltaCpuWallTimer timer([counter](const CpuWallTiming& timing) { + counter->increment(timing.cpuNanos); + }); + return func(); + } + return func(); +} + +template +auto withDecompressStats(io::IoCounter* counter, F&& func) + -> std::enable_if_t> { + if (counter) { + DeltaCpuWallTimer timer([counter](const CpuWallTiming& timing) { + counter->increment(timing.cpuNanos); + }); + func(); + return; + } + func(); +} + +/// Per-column statistics counters. Wraps multiple IoCounter instances for +/// different types of measurements (decompression, encoding, etc.). +/// Can be used by any file format reader (DWRF, Nimble, Parquet, etc.). +struct ColumnMetrics { + explicit ColumnMetrics(TypeKind type = TypeKind::INVALID) : typeKind(type) {} + + TypeKind typeKind; + io::IoCounter decompressCPUTimeNanos; + io::IoCounter decodeCPUTimeNanos; + + /// Merges stats from another ColumnMetrics instance. + void merge(const ColumnMetrics& other) { + decompressCPUTimeNanos.merge(other.decompressCPUTimeNanos); + decodeCPUTimeNanos.merge(other.decodeCPUTimeNanos); + } +}; + +/// Thread-safe collection of per-column metrics keyed by nodeId. +/// Can be used by any file format reader (DWRF, Nimble, Parquet, etc.). +struct ColumnMetricsSet { + /// Gets or creates a ColumnMetrics for a column. Sets typeKind when creating. + ColumnMetrics* getOrCreate( + uint32_t nodeId, + TypeKind typeKind = TypeKind::INVALID) { + auto locked = map_.wlock(); + auto it = locked->find(nodeId); + if (it == locked->end()) { + it = locked->emplace(nodeId, std::make_unique(typeKind)) + .first; + } + return it->second.get(); + } + + /// Merges all column metrics from another ColumnMetricsSet instance. + void mergeFrom(const ColumnMetricsSet& other) { + auto srcLocked = other.map_.rlock(); + auto dstLocked = map_.wlock(); + for (const auto& [nodeId, srcStats] : *srcLocked) { + auto it = dstLocked->find(nodeId); + if (it == dstLocked->end()) { + it = + dstLocked->emplace(nodeId, std::make_unique()).first; + it->second->typeKind = srcStats->typeKind; + } + it->second->merge(*srcStats); + } + } + + /// Exports per-column metrics into the runtime metrics result map. + void toRuntimeMetrics( + std::unordered_map& result) const { + auto statsLocked = map_.rlock(); + for (const auto& [nodeId, stats] : *statsLocked) { + // Export decompression timing. + const auto& decompressCounter = stats->decompressCPUTimeNanos; + if (decompressCounter.count() > 0) { + result.emplace( + fmt::format( + "column_{}.{}.decompressCPUTimeNanos", + nodeId, + TypeKindName::toName(stats->typeKind)), + RuntimeMetric{ + saturateCast(decompressCounter.sum()), + decompressCounter.count(), + saturateCast(decompressCounter.min()), + saturateCast(decompressCounter.max()), + RuntimeCounter::Unit::kNanos}); + } + // Export decode timing. + const auto& decodeCounter = stats->decodeCPUTimeNanos; + if (decodeCounter.count() > 0) { + result.emplace( + fmt::format( + "column_{}.{}.decodeCPUTimeNanos", + nodeId, + TypeKindName::toName(stats->typeKind)), + RuntimeMetric{ + saturateCast(decodeCounter.sum()), + decodeCounter.count(), + saturateCast(decodeCounter.min()), + saturateCast(decodeCounter.max()), + RuntimeCounter::Unit::kNanos}); + } + } + } + + private: + folly::Synchronized< + folly::F14FastMap>> + map_; +}; + struct ColumnReaderStatistics { // Number of rows returned by string dictionary reader that is flattened // instead of keeping dictionary encoding. int64_t flattenStringDictionaryValues{0}; + + // Total time spent in loading pages, in nanoseconds. + io::IoCounter pageLoadTimeNs; + + // Per-column decompression metrics. Only populated when column stats + // collection is enabled. + std::optional columnMetricsSet; + + /// Initializes column stats collection for the given schema if enabled in + /// options. Recursively registers metrics for all columns in the type tree. + void initColumnStatsCollection( + const TypeWithId& schema, + const RowReaderOptions& options) { + if (!options.collectColumnCpuMetrics()) { + return; + } + columnMetricsSet.emplace(); + registerColumnMetricsImpl(schema); + } + + /// Merges all stats from another ColumnReaderStatistics instance. + void mergeFrom(const ColumnReaderStatistics& other) { + flattenStringDictionaryValues += other.flattenStringDictionaryValues; + pageLoadTimeNs.merge(other.pageLoadTimeNs); + if (other.columnMetricsSet) { + if (!columnMetricsSet) { + columnMetricsSet.emplace(); + } + columnMetricsSet->mergeFrom(*other.columnMetricsSet); + } + } + + /// Exports all metrics into the runtime metrics result map. + void toRuntimeMetrics( + std::unordered_map& result) const { + if (flattenStringDictionaryValues > 0) { + result.emplace( + "flattenStringDictionaryValues", + RuntimeMetric(flattenStringDictionaryValues)); + } + if (pageLoadTimeNs.sum() > 0) { + result.emplace( + "pageLoadTimeNs", + RuntimeMetric( + pageLoadTimeNs.sum(), + pageLoadTimeNs.count(), + pageLoadTimeNs.min(), + pageLoadTimeNs.max(), + RuntimeCounter::Unit::kNanos)); + } + if (columnMetricsSet) { + columnMetricsSet->toRuntimeMetrics(result); + } + } + + private: + void registerColumnMetricsImpl(const TypeWithId& node) { + columnMetricsSet->getOrCreate(node.id(), node.type()->kind()); + for (uint32_t i = 0; i < node.size(); ++i) { + if (const auto* child = node.childAt(i).get()) { + registerColumnMetricsImpl(*child); + } + } + } }; struct RuntimeStatistics { @@ -558,40 +756,40 @@ struct RuntimeStatistics { int64_t numStripes{0}; - ColumnReaderStatistics columnReaderStatistics; + UnitLoaderStats unitLoaderStats; + ColumnReaderStatistics columnReaderStats; - std::unordered_map toMap() { - std::unordered_map result; + std::unordered_map toRuntimeMetricMap() { + std::unordered_map result; + for (const auto& [name, metric] : unitLoaderStats.stats()) { + result.emplace(name, RuntimeMetric(metric.sum, metric.unit)); + } if (skippedSplits > 0) { - result.emplace("skippedSplits", RuntimeCounter(skippedSplits)); + result.emplace("skippedSplits", RuntimeMetric(skippedSplits)); } if (processedSplits > 0) { - result.emplace("processedSplits", RuntimeCounter(processedSplits)); + result.emplace("processedSplits", RuntimeMetric(processedSplits)); } if (skippedSplitBytes > 0) { result.emplace( "skippedSplitBytes", - RuntimeCounter(skippedSplitBytes, RuntimeCounter::Unit::kBytes)); + RuntimeMetric(skippedSplitBytes, RuntimeCounter::Unit::kBytes)); } if (skippedStrides > 0) { - result.emplace("skippedStrides", RuntimeCounter(skippedStrides)); + result.emplace("skippedStrides", RuntimeMetric(skippedStrides)); } if (processedStrides > 0) { - result.emplace("processedStrides", RuntimeCounter(processedStrides)); + result.emplace("processedStrides", RuntimeMetric(processedStrides)); } if (footerBufferOverread > 0) { result.emplace( "footerBufferOverread", - RuntimeCounter(footerBufferOverread, RuntimeCounter::Unit::kBytes)); + RuntimeMetric(footerBufferOverread, RuntimeCounter::Unit::kBytes)); } if (numStripes > 0) { - result.emplace("numStripes", RuntimeCounter(numStripes)); - } - if (columnReaderStatistics.flattenStringDictionaryValues > 0) { - result.emplace( - "flattenStringDictionaryValues", - RuntimeCounter(columnReaderStatistics.flattenStringDictionaryValues)); + result.emplace("numStripes", RuntimeMetric(numStripes)); } + columnReaderStats.toRuntimeMetrics(result); return result; } }; diff --git a/velox/dwio/common/StatisticsBuilder.cpp b/velox/dwio/common/StatisticsBuilder.cpp new file mode 100644 index 00000000000..9d0b41cc6d5 --- /dev/null +++ b/velox/dwio/common/StatisticsBuilder.cpp @@ -0,0 +1,450 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/dwio/common/StatisticsBuilder.h" + +namespace facebook::velox::dwio::stats { + +// Import column statistics types from dwio::common. +using common::BinaryColumnStatistics; +using common::BooleanColumnStatistics; +using common::ColumnStatistics; +using common::DoubleColumnStatistics; +using common::IntegerColumnStatistics; +using common::StringColumnStatistics; + +namespace { + +template +void addWithOverflowCheck(std::optional& to, T value, uint64_t count) { + if (to.has_value()) { + T result; + auto overflow = __builtin_mul_overflow(value, count, &result); + if (!overflow) { + overflow = __builtin_add_overflow(to.value(), result, &to.value()); + } + if (overflow) { + to.reset(); + } + } +} + +template +void mergeWithOverflowCheck( + std::optional& to, + const std::optional& from) { + if (to.has_value()) { + if (from.has_value()) { + auto overflow = + __builtin_add_overflow(to.value(), from.value(), &to.value()); + if (overflow) { + to.reset(); + } + } else { + to.reset(); + } + } +} + +template +void mergeCount(std::optional& to, const std::optional& from) { + if (to.has_value()) { + if (from.has_value()) { + to.value() += from.value(); + } else { + to.reset(); + } + } +} + +template +void mergeMin(std::optional& to, const std::optional& from) { + if (to.has_value()) { + if (!from.has_value()) { + to.reset(); + } else if (from.value() < to.value()) { + to = from; + } + } +} + +template +void mergeMax(std::optional& to, const std::optional& from) { + if (to.has_value()) { + if (!from.has_value()) { + to.reset(); + } else if (from.value() > to.value()) { + to = from; + } + } +} + +bool isValidLength(const std::optional& length) { + return length.has_value() && + length.value() <= std::numeric_limits::max(); +} + +bool shouldKeepString( + const std::optional& val, + uint32_t lengthLimit) { + return val.has_value() && val.value().size() <= lengthLimit; +} + +} // namespace + +std::unique_ptr StatisticsBuilder::build() const { + auto result = std::make_unique( + valueCount_, hasNull_, rawSize_, size_, estimateNumDistinct()); + + // For the base builder, there are no typed stats to add. + return result; +} + +void StatisticsBuilder::incrementSize(uint64_t size) { + if (LIKELY(size_.has_value())) { + addWithOverflowCheck(size_, size, /*count=*/1); + } +} + +void StatisticsBuilder::merge(const ColumnStatistics& other, bool ignoreSize) { + mergeCount(valueCount_, other.getNumberOfValues()); + + if (!hasNull_.has_value() || !hasNull_.value()) { + auto otherHasNull = other.hasNull(); + if (otherHasNull.has_value()) { + if (otherHasNull.value()) { + hasNull_ = true; + } + } else if (hasNull_.has_value()) { + hasNull_.reset(); + } + } + mergeCount(rawSize_, other.getRawSize()); + if (!ignoreSize) { + mergeCount(size_, other.getSize()); + } + if (hll_) { + auto* otherBuilder = dynamic_cast(&other); + VELOX_CHECK_NOT_NULL(otherBuilder); + VELOX_CHECK_NOT_NULL(otherBuilder->hll_); + hll_->mergeWith(*otherBuilder->hll_); + } +} + +std::unique_ptr StatisticsBuilder::create( + const Type& type, + const StatisticsBuilderOptions& options) { + switch (type.kind()) { + case TypeKind::BOOLEAN: + return std::make_unique(options); + case TypeKind::TINYINT: + case TypeKind::SMALLINT: + case TypeKind::INTEGER: + case TypeKind::BIGINT: + return std::make_unique(options); + case TypeKind::REAL: + case TypeKind::DOUBLE: + return std::make_unique(options); + case TypeKind::VARCHAR: + return std::make_unique(options); + case TypeKind::VARBINARY: + return std::make_unique(options); + default: + return std::make_unique(options); + } +} + +void StatisticsBuilder::createTree( + std::vector>& statBuilders, + const Type& type, + const StatisticsBuilderOptions& options) { + auto kind = type.kind(); + switch (kind) { + case TypeKind::BOOLEAN: + case TypeKind::TINYINT: + case TypeKind::SMALLINT: + case TypeKind::INTEGER: + case TypeKind::BIGINT: + case TypeKind::REAL: + case TypeKind::DOUBLE: + case TypeKind::VARCHAR: + case TypeKind::VARBINARY: + case TypeKind::TIMESTAMP: + statBuilders.push_back(StatisticsBuilder::create(type, options)); + break; + + case TypeKind::ARRAY: { + statBuilders.push_back(StatisticsBuilder::create(type, options)); + const auto& arrayType = dynamic_cast(type); + createTree(statBuilders, *arrayType.elementType(), options); + break; + } + + case TypeKind::MAP: { + statBuilders.push_back(StatisticsBuilder::create(type, options)); + const auto& mapType = dynamic_cast(type); + createTree(statBuilders, *mapType.keyType(), options); + createTree(statBuilders, *mapType.valueType(), options); + break; + } + + case TypeKind::ROW: { + statBuilders.push_back(StatisticsBuilder::create(type, options)); + const auto& rowType = dynamic_cast(type); + for (const auto& childType : rowType.children()) { + createTree(statBuilders, *childType, options); + } + break; + } + default: + VELOX_FAIL("Not supported type: {}", kind); + break; + } +} + +void BooleanStatisticsBuilder::addValues(bool value, uint64_t count) { + increaseValueCount(count); + if (trueCount_.has_value() && value) { + trueCount_.value() += count; + } +} + +void BooleanStatisticsBuilder::merge( + const ColumnStatistics& other, + bool ignoreSize) { + StatisticsBuilder::merge(other, ignoreSize); + auto stats = dynamic_cast(&other); + if (!stats) { + if (!other.isAllNull() && trueCount_.has_value()) { + trueCount_.reset(); + } + return; + } + mergeCount(trueCount_, stats->getTrueCount()); +} + +std::unique_ptr BooleanStatisticsBuilder::build() const { + auto trueCount = isAllNull() ? std::nullopt : trueCount_; + auto result = std::make_unique( + static_cast(*this), trueCount); + if (auto numDistinct = estimateNumDistinct()) { + result->setNumDistinct(*numDistinct); + } + return result; +} + +void IntegerStatisticsBuilder::addValues(int64_t value, uint64_t count) { + increaseValueCount(count); + if (min_.has_value() && value < min_.value()) { + min_ = value; + } + if (max_.has_value() && value > max_.value()) { + max_ = value; + } + addWithOverflowCheck(sum_, value, count); + addHash(value); +} + +void IntegerStatisticsBuilder::merge( + const ColumnStatistics& other, + bool ignoreSize) { + StatisticsBuilder::merge(other, ignoreSize); + auto stats = dynamic_cast(&other); + if (!stats) { + if (!other.isAllNull()) { + min_.reset(); + max_.reset(); + sum_.reset(); + } + return; + } + mergeMin(min_, stats->getMinimum()); + mergeMax(max_, stats->getMaximum()); + mergeWithOverflowCheck(sum_, stats->getSum()); +} + +std::unique_ptr IntegerStatisticsBuilder::build() const { + auto min = isAllNull() ? std::nullopt : min_; + auto max = isAllNull() ? std::nullopt : max_; + auto sum = isAllNull() ? std::nullopt : sum_; + auto result = std::make_unique( + static_cast(*this), min, max, sum); + if (auto numDistinct = estimateNumDistinct()) { + result->setNumDistinct(*numDistinct); + } + return result; +} + +void DoubleStatisticsBuilder::addValues(double value, uint64_t count) { + increaseValueCount(count); + if (std::isnan(value)) { + clear(); + return; + } + + if (min_.has_value() && value < min_.value()) { + min_ = value; + } + if (max_.has_value() && value > max_.value()) { + max_ = value; + } + addHash(value); + if (sum_.has_value()) { + for (uint64_t i = 0; i < count; ++i) { + sum_.value() += value; + } + if (std::isnan(sum_.value())) { + sum_.reset(); + } + } +} + +void DoubleStatisticsBuilder::merge( + const ColumnStatistics& other, + bool ignoreSize) { + StatisticsBuilder::merge(other, ignoreSize); + auto stats = dynamic_cast(&other); + if (!stats) { + if (!other.isAllNull()) { + clear(); + } + return; + } + mergeMin(min_, stats->getMinimum()); + mergeMax(max_, stats->getMaximum()); + mergeCount(sum_, stats->getSum()); + if (sum_.has_value() && std::isnan(sum_.value())) { + sum_.reset(); + } +} + +std::unique_ptr DoubleStatisticsBuilder::build() const { + auto min = isAllNull() ? std::nullopt : min_; + auto max = isAllNull() ? std::nullopt : max_; + auto sum = isAllNull() ? std::nullopt : sum_; + auto result = std::make_unique( + static_cast(*this), min, max, sum); + if (auto numDistinct = estimateNumDistinct()) { + result->setNumDistinct(*numDistinct); + } + return result; +} + +void StringStatisticsBuilder::addValues( + std::string_view value, + uint64_t count) { + auto isSelfEmpty = isAllNull(); + increaseValueCount(count); + if (isSelfEmpty) { + min_ = value; + max_ = value; + } else { + if (min_.has_value() && value < std::string_view{min_.value()}) { + min_ = value; + } + if (max_.has_value() && value > std::string_view{max_.value()}) { + max_ = value; + } + } + addHash(value); + + addWithOverflowCheck(length_, value.size(), count); +} + +void StringStatisticsBuilder::merge( + const ColumnStatistics& other, + bool ignoreSize) { + auto isSelfEmpty = isAllNull(); + StatisticsBuilder::merge(other, ignoreSize); + auto stats = dynamic_cast(&other); + if (!stats) { + if (!other.isAllNull()) { + min_.reset(); + max_.reset(); + length_.reset(); + } + return; + } + + if (other.isAllNull()) { + return; + } + + if (isSelfEmpty) { + min_ = stats->getMinimum(); + max_ = stats->getMaximum(); + } else { + mergeMin(min_, stats->getMinimum()); + mergeMax(max_, stats->getMaximum()); + } + + mergeWithOverflowCheck(length_, stats->getTotalLength()); +} + +std::unique_ptr StringStatisticsBuilder::build() const { + std::optional min; + std::optional max; + std::optional length; + if (!isAllNull()) { + if (shouldKeepString(min_, lengthLimit_)) { + min = min_; + } + if (shouldKeepString(max_, lengthLimit_)) { + max = max_; + } + if (isValidLength(length_)) { + length = length_.value(); + } + } + auto result = std::make_unique( + static_cast(*this), min, max, length); + if (auto numDistinct = estimateNumDistinct()) { + result->setNumDistinct(*numDistinct); + } + return result; +} + +void BinaryStatisticsBuilder::addValues(uint64_t length, uint64_t count) { + increaseValueCount(count); + addWithOverflowCheck(length_, length, count); +} + +void BinaryStatisticsBuilder::merge( + const ColumnStatistics& other, + bool ignoreSize) { + StatisticsBuilder::merge(other, ignoreSize); + auto stats = dynamic_cast(&other); + if (!stats) { + if (!other.isAllNull() && length_.has_value()) { + length_.reset(); + } + return; + } + mergeWithOverflowCheck(length_, stats->getTotalLength()); +} + +std::unique_ptr BinaryStatisticsBuilder::build() const { + auto length = + (!isAllNull() && isValidLength(length_)) ? length_ : std::nullopt; + auto result = std::make_unique( + static_cast(*this), length); + if (auto numDistinct = estimateNumDistinct()) { + result->setNumDistinct(*numDistinct); + } + return result; +} + +} // namespace facebook::velox::dwio::stats diff --git a/velox/dwio/common/StatisticsBuilder.h b/velox/dwio/common/StatisticsBuilder.h new file mode 100644 index 00000000000..61319f73364 --- /dev/null +++ b/velox/dwio/common/StatisticsBuilder.h @@ -0,0 +1,334 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/common/base/Exceptions.h" +#include "velox/common/hyperloglog/SparseHll.h" +#include "velox/dwio/common/Statistics.h" +#include "velox/type/Type.h" + +namespace facebook::velox::dwio::stats { + +// Import column statistics types from dwio::common. +using common::BinaryColumnStatistics; +using common::BooleanColumnStatistics; +using common::ColumnStatistics; +using common::DoubleColumnStatistics; +using common::IntegerColumnStatistics; +using common::StringColumnStatistics; + +/// Options for creating StatisticsBuilder instances. +struct StatisticsBuilderOptions { + explicit StatisticsBuilderOptions( + uint32_t stringLengthLimit, + std::optional initialSize = std::nullopt, + bool countDistincts = false, + HashStringAllocator* allocator = nullptr) + : stringLengthLimit{stringLengthLimit}, + initialSize{initialSize}, + countDistincts(countDistincts), + allocator(allocator) {} + + /// Maximum length of min/max string values to track. Strings longer than + /// this limit are dropped from statistics. + uint32_t stringLengthLimit; + + /// Initial value for the size statistic (total stream bytes). Nullopt means + /// size tracking is disabled until ensureSize() is called. + std::optional initialSize; + + /// Whether to count approximate distinct values using HyperLogLog. Requires + /// 'allocator' to be set. + bool countDistincts{false}; + + /// Allocator for HyperLogLog distinct counting. Required if 'countDistincts' + /// is true. + HashStringAllocator* allocator; + + /// Returns a copy with distinct counting disabled. + StatisticsBuilderOptions dropNumDistinct() const { + return StatisticsBuilderOptions(stringLengthLimit, initialSize); + } +}; + +/// Base class for stats builder. Stats builder is used in writer and file merge +/// to collect and merge stats. +/// It can also be used for gathering stats in ad hoc sampling. In this case it +/// may also count distinct values if enabled in 'options'. +class StatisticsBuilder : public virtual ColumnStatistics { + public: + explicit StatisticsBuilder(const StatisticsBuilderOptions& options) + : options_{options} { + VELOX_CHECK( + !options.countDistincts || options.allocator != nullptr, + "allocator is required when countDistincts is true"); + init(); + } + + ~StatisticsBuilder() override = default; + + void setHasNull() { + hasNull_ = true; + } + + void increaseValueCount(uint64_t count = 1) { + if (valueCount_.has_value()) { + valueCount_.value() += count; + } + } + + void increaseRawSize(uint64_t rawSize) { + if (rawSize_.has_value()) { + rawSize_.value() += rawSize; + } + } + + void clearRawSize() { + rawSize_.reset(); + } + + void ensureSize() { + if (!size_.has_value()) { + size_ = 0; + } + } + + void incrementSize(uint64_t size); + + template + void addHash(const T& data) { + if (hll_) { + hll_->insertHash(folly::hasher()(data)); + } + } + + int64_t cardinality() const { + VELOX_CHECK_NOT_NULL(hll_); + return hll_->cardinality(); + } + + /// Returns estimated number of distinct values if distinct counting is + /// enabled, or std::nullopt otherwise. + std::optional estimateNumDistinct() const { + if (hll_) { + return hll_->cardinality(); + } + return std::nullopt; + } + + /// Merges stats of same type. Used in writer to aggregate file level stats. + virtual void merge(const ColumnStatistics& other, bool ignoreSize = false); + + /// Resets to initial state. Used where row index entry level stats is + /// captured. + virtual void reset() { + init(); + } + + /// Builds a read-only ColumnStatistics snapshot. Typed stats (min/max/sum) + /// are omitted when isAllNull(). String min/max are omitted when they exceed + /// the string length limit. + virtual std::unique_ptr build() const; + + /// Creates a StatisticsBuilder for the given type. For MAP type, creates a + /// base StatisticsBuilder (not a MapStatisticsBuilder, which stays in DWRF). + static std::unique_ptr create( + const Type& type, + const StatisticsBuilderOptions& options); + + /// For the given type tree, creates a list of stat builders. + static void createTree( + std::vector>& statBuilders, + const Type& type, + const StatisticsBuilderOptions& options); + + private: + void init() { + valueCount_ = 0; + hasNull_ = false; + rawSize_ = 0; + size_ = options_.initialSize; + if (options_.countDistincts) { + hll_ = + std::make_shared>(options_.allocator); + } + } + + protected: + StatisticsBuilderOptions options_; + std::shared_ptr> hll_; +}; + +class BooleanStatisticsBuilder : public virtual StatisticsBuilder, + public BooleanColumnStatistics { + public: + explicit BooleanStatisticsBuilder(const StatisticsBuilderOptions& options) + : StatisticsBuilder{options.dropNumDistinct()} { + init(); + } + + ~BooleanStatisticsBuilder() override = default; + + void addValues(bool value, uint64_t count = 1); + + std::unique_ptr build() const override; + + void merge(const ColumnStatistics& other, bool ignoreSize = false) override; + + void reset() override { + StatisticsBuilder::reset(); + init(); + } + + private: + void init() { + trueCount_ = 0; + } +}; + +class IntegerStatisticsBuilder : public virtual StatisticsBuilder, + public IntegerColumnStatistics { + public: + explicit IntegerStatisticsBuilder(const StatisticsBuilderOptions& options) + : StatisticsBuilder{options} { + init(); + } + + ~IntegerStatisticsBuilder() override = default; + + void addValues(int64_t value, uint64_t count = 1); + + std::unique_ptr build() const override; + + void merge(const ColumnStatistics& other, bool ignoreSize = false) override; + + void reset() override { + StatisticsBuilder::reset(); + init(); + } + + private: + void init() { + min_ = std::numeric_limits::max(); + max_ = std::numeric_limits::min(); + sum_ = 0; + } +}; + +static_assert( + std::numeric_limits::has_infinity, + "infinity not defined"); + +class DoubleStatisticsBuilder : public virtual StatisticsBuilder, + public DoubleColumnStatistics { + public: + explicit DoubleStatisticsBuilder(const StatisticsBuilderOptions& options) + : StatisticsBuilder{options} { + init(); + } + + ~DoubleStatisticsBuilder() override = default; + + void addValues(double value, uint64_t count = 1); + + std::unique_ptr build() const override; + + void merge(const ColumnStatistics& other, bool ignoreSize = false) override; + + void reset() override { + StatisticsBuilder::reset(); + init(); + } + + private: + void init() { + min_ = std::numeric_limits::infinity(); + max_ = -std::numeric_limits::infinity(); + sum_ = 0; + } + + void clear() { + min_.reset(); + max_.reset(); + sum_.reset(); + } +}; + +class StringStatisticsBuilder : public virtual StatisticsBuilder, + public StringColumnStatistics { + public: + explicit StringStatisticsBuilder(const StatisticsBuilderOptions& options) + : StatisticsBuilder{options}, lengthLimit_{options.stringLengthLimit} { + init(); + } + + ~StringStatisticsBuilder() override = default; + + void addValues(std::string_view value, uint64_t count = 1); + + std::unique_ptr build() const override; + + void merge(const ColumnStatistics& other, bool ignoreSize = false) override; + + void reset() override { + StatisticsBuilder::reset(); + init(); + } + + protected: + uint32_t lengthLimit_; + + bool shouldKeep(const std::optional& val) const { + return val.has_value() && val.value().size() <= lengthLimit_; + } + + private: + void init() { + min_.reset(); + max_.reset(); + length_ = 0; + } +}; + +class BinaryStatisticsBuilder : public virtual StatisticsBuilder, + public BinaryColumnStatistics { + public: + explicit BinaryStatisticsBuilder(const StatisticsBuilderOptions& options) + : StatisticsBuilder{options.dropNumDistinct()} { + init(); + } + + ~BinaryStatisticsBuilder() override = default; + + void addValues(uint64_t length, uint64_t count = 1); + + std::unique_ptr build() const override; + + void merge(const ColumnStatistics& other, bool ignoreSize = false) override; + + void reset() override { + StatisticsBuilder::reset(); + init(); + } + + private: + void init() { + length_ = 0; + } +}; + +} // namespace facebook::velox::dwio::stats diff --git a/velox/dwio/common/StreamUtil.h b/velox/dwio/common/StreamUtil.h index 0a0a2b2b0e8..c2ac4e1bdc3 100644 --- a/velox/dwio/common/StreamUtil.h +++ b/velox/dwio/common/StreamUtil.h @@ -76,19 +76,19 @@ inline bool isDense(const T* values, int32_t size) { return (values[size - 1] - values[0] == size - 1); } -template +template void rowLoop( const int32_t* rows, int32_t begin, int32_t end, - int32_t step, Dense dense, Sparse sparse, SparseN sparseN) { + static_assert(bits::isPowerOfTwo(kStep)); int32_t i = begin; - auto firstPartial = (end - begin) & ~(step - 1); - for (; i < firstPartial; i += step) { - if (isDense(&rows[i], step)) { + const auto firstPartial = (end - begin) & ~(kStep - 1); + for (; i < begin + firstPartial; i += kStep) { + if (isDense(&rows[i], kStep)) { dense(i); } else { sparse(i); @@ -144,7 +144,7 @@ inline void readContiguous( // Returns the number of elements in rows that are < limit. inline int32_t numBelow(folly::Range rows, int32_t limit) { - return std::lower_bound(rows.begin(), rows.end(), limit) - rows.begin(); + return std::lower_bound(rows.cbegin(), rows.cend(), limit) - rows.cbegin(); } template @@ -159,15 +159,15 @@ inline void loopOverBuffers( int32_t rowOffset = initialRow; int32_t rowIndex = 0; while (rowIndex < rows.size()) { - auto available = (bufferEnd - bufferStart) / sizeof(T); - auto numRowsInBuffer = rows.back() - rowOffset < available + const auto available = (bufferEnd - bufferStart) / sizeof(T); + const auto numRowsInBuffer = rows.back() - rowOffset < available ? rows.size() - rowIndex : numBelow( folly::Range( &rows[rowIndex], rows.size() - rowIndex), rowOffset + available); - if (!numRowsInBuffer) { + if (numRowsInBuffer == 0) { skipBytes( (rows[rowIndex] - rowOffset) * sizeof(T), &input, diff --git a/velox/dwio/common/Throttler.h b/velox/dwio/common/Throttler.h index 05bce4ceb2d..7bd84406efb 100644 --- a/velox/dwio/common/Throttler.h +++ b/velox/dwio/common/Throttler.h @@ -18,6 +18,8 @@ #include #include +#include + #include "velox/common/caching/CachedFactory.h" #include "velox/common/caching/SimpleLRUCache.h" #include "velox/common/io/IoStatistics.h" diff --git a/velox/dwio/common/TypeUtils.cpp b/velox/dwio/common/TypeUtils.cpp index 29e22046196..b0220b9acdc 100644 --- a/velox/dwio/common/TypeUtils.cpp +++ b/velox/dwio/common/TypeUtils.cpp @@ -131,11 +131,13 @@ void checkTypeCompatibility( const FShouldRead& shouldRead, const std::function& exceptionMessageCreator) { if (shouldRead(to) && !isCompatible(from.kind(), kind(to))) { - VELOX_SCHEMA_MISMATCH_ERROR(fmt::format( - "{}, From Kind: {}, To Kind: {}", - exceptionMessageCreator ? exceptionMessageCreator() : "Schema mismatch", - mapTypeKindToName(from.kind()), - mapTypeKindToName(kind(to)))); + VELOX_SCHEMA_MISMATCH_ERROR( + fmt::format( + "{}, From Kind: {}, To Kind: {}", + exceptionMessageCreator ? exceptionMessageCreator() + : "Schema mismatch", + TypeKindName::toName(from.kind()), + TypeKindName::toName(kind(to)))); } if (recurse) { diff --git a/velox/dwio/common/TypeWithId.h b/velox/dwio/common/TypeWithId.h index a147cfe5066..80084dbfb53 100644 --- a/velox/dwio/common/TypeWithId.h +++ b/velox/dwio/common/TypeWithId.h @@ -97,4 +97,6 @@ class TypeWithId : public velox::Tree> { const std::vector> children_; }; +using TypeWithIdPtr = std::shared_ptr; + } // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/UnitLoader.h b/velox/dwio/common/UnitLoader.h index d3125dacc4b..d1fc54ab240 100644 --- a/velox/dwio/common/UnitLoader.h +++ b/velox/dwio/common/UnitLoader.h @@ -16,9 +16,13 @@ #pragma once +#include +#include #include #include #include +#include "velox/common/base/Exceptions.h" +#include "velox/common/base/RuntimeMetrics.h" namespace facebook::velox::dwio::common { @@ -39,6 +43,44 @@ class LoadUnit { virtual uint64_t getIoSize() = 0; }; +class UnitLoaderStats { + public: + UnitLoaderStats() = default; + + void addCounter(const std::string& name, RuntimeCounter counter) { + auto locked = stats_.wlock(); + auto it = locked->find(name); + if (it == locked->end()) { + auto [ptr, inserted] = locked->emplace(name, RuntimeMetric(counter.unit)); + VELOX_CHECK(inserted); + ptr->second.addValue(counter.value); + } else { + VELOX_CHECK_EQ(it->second.unit, counter.unit); + it->second.addValue(counter.value); + } + } + + void merge(const UnitLoaderStats& other) { + auto otherStats = other.stats(); + auto locked = stats_.wlock(); + for (const auto& [name, metric] : otherStats) { + auto it = locked->find(name); + if (it == locked->end()) { + locked->emplace(name, metric); + } else { + it->second.merge(metric); + } + } + } + + folly::F14FastMap stats() const { + return stats_.copy(); + } + + private: + folly::Synchronized> stats_; +}; + class UnitLoader { public: virtual ~UnitLoader() = default; @@ -56,6 +98,10 @@ class UnitLoader { /// Reader reports seek calling this method. The call must be done **before** /// getLoadedUnit for the new unit. virtual void onSeek(uint32_t unit, uint64_t rowOffsetInUnit) = 0; + + virtual UnitLoaderStats stats() { + return UnitLoaderStats(); + }; }; class UnitLoaderFactory { diff --git a/velox/dwio/common/Writer.cpp b/velox/dwio/common/Writer.cpp index 87951cad7c5..5124cfd5ed6 100644 --- a/velox/dwio/common/Writer.cpp +++ b/velox/dwio/common/Writer.cpp @@ -16,6 +16,8 @@ #include "velox/dwio/common/Writer.h" +#include "velox/common/base/Exceptions.h" + namespace facebook::velox::dwio::common { void Writer::checkStateTransition(State oldState, State newState) { diff --git a/velox/dwio/common/Writer.h b/velox/dwio/common/Writer.h index 774aafe4c9a..1950677a0bf 100644 --- a/velox/dwio/common/Writer.h +++ b/velox/dwio/common/Writer.h @@ -16,12 +16,12 @@ #pragma once -#include -#include -#include +#include #include -#include "velox/vector/ComplexVector.h" +#include "velox/common/base/Portability.h" +#include "velox/dwio/common/FileMetadata.h" +#include "velox/vector/BaseVector.h" namespace facebook::velox::dwio::common { @@ -70,10 +70,12 @@ class Writer { /// NOTE: this must be called before close(). virtual bool finish() = 0; - /// Invokes closes the writer. Data can no longer be written. + /// Closes the writer. Data can no longer be written. Returns format-specific + /// file metadata collected during write operations. The returned pointer can + /// be null if no metadata is available, such as for an empty data file. /// /// NOTE: this must be called after the last finish() which returns true. - virtual void close() = 0; + virtual std::unique_ptr close() = 0; /// Aborts the writing by closing the writer and dropping everything. /// Data can no longer be written. @@ -91,7 +93,7 @@ class Writer { /// Validates the state transition from 'oldState' to 'newState'. static void checkStateTransition(State oldState, State newState); - State state_{State::kInit}; + tsan_atomic state_{State::kInit}; }; FOLLY_ALWAYS_INLINE std::ostream& operator<<( diff --git a/velox/dwio/common/compression/CMakeLists.txt b/velox/dwio/common/compression/CMakeLists.txt index ed39366b3a5..913a8d061c9 100644 --- a/velox/dwio/common/compression/CMakeLists.txt +++ b/velox/dwio/common/compression/CMakeLists.txt @@ -17,6 +17,11 @@ velox_add_library( Compression.cpp PagedInputStream.cpp PagedOutputStream.cpp + HEADERS + Compression.h + CompressionBufferPool.h + PagedInputStream.h + PagedOutputStream.h ) velox_link_libraries( diff --git a/velox/dwio/common/compression/Compression.cpp b/velox/dwio/common/compression/Compression.cpp index 222385b3ce4..59e4232176c 100644 --- a/velox/dwio/common/compression/Compression.cpp +++ b/velox/dwio/common/compression/Compression.cpp @@ -16,7 +16,9 @@ #include "velox/dwio/common/compression/Compression.h" #include "velox/common/compression/LzoDecompressor.h" +#include "velox/common/time/CpuWallTimer.h" #include "velox/dwio/common/IntCodecCommon.h" +#include "velox/dwio/common/Statistics.h" #include "velox/dwio/common/compression/PagedInputStream.h" #include @@ -398,9 +400,6 @@ uint64_t Lz4Decompressor::decompressInternal( return static_cast(result); } -// NOTE: We do not keep `ZSTD_DCtx' around on purpose, because if we keep it -// around, in flat map column reader we have hundreds of thousands of -// decompressors at same time and causing OOM. class ZstdDecompressor : public Decompressor { public: explicit ZstdDecompressor( @@ -424,7 +423,10 @@ uint64_t ZstdDecompressor::decompress( uint64_t srcLength, char* dest, uint64_t destLength) { - auto ret = ZSTD_decompress(dest, destLength, src, srcLength); + // Reuse 'ZSTD_DCtx' per-thread to avoid repeated allocations. + thread_local std::unique_ptr ctx{ + ZSTD_createDCtx(), ZSTD_freeDCtx}; + auto ret = ZSTD_decompressDCtx(ctx.get(), dest, destLength, src, srcLength); DWIO_ENSURE( !ZSTD_isError(ret), "ZSTD returned an error: ", @@ -513,8 +515,9 @@ class ZlibDecompressionStream : public PagedInputStream, const std::string& streamDebugInfo, bool isGzip = false, bool useRawDecompression = false, - size_t compressedLength = 0) - : PagedInputStream{std::move(inStream), pool, streamDebugInfo, useRawDecompression, compressedLength}, + size_t compressedLength = 0, + io::IoCounter* decompressCounter = nullptr) + : PagedInputStream{std::move(inStream), pool, streamDebugInfo, useRawDecompression, compressedLength, decompressCounter}, ZlibDecompressor{blockSize, windowBits, streamDebugInfo, isGzip} {} ~ZlibDecompressionStream() override = default; @@ -555,6 +558,8 @@ bool ZlibDecompressionStream::readOrSkip(const void** data, int32_t* size) { *size = static_cast(availSize); outputBufferPtr_ = inputBufferPtr_ + availSize; outputBufferLength_ = 0; + inputBufferPtr_ += availSize; + remainingLength_ -= availSize; } else { DWIO_ENSURE_EQ( state_, @@ -566,43 +571,54 @@ bool ZlibDecompressionStream::readOrSkip(const void** data, int32_t* size) { prepareOutputBuffer( getDecompressedLength(inputBufferPtr_, availSize).first); - reset(); - zstream_.next_in = - reinterpret_cast(const_cast(inputBufferPtr_)); - zstream_.avail_in = folly::to(availSize); - outputBufferPtr_ = outputBuffer_->data(); - zstream_.next_out = - reinterpret_cast(const_cast(outputBufferPtr_)); - zstream_.avail_out = folly::to(blockSize_); - int32_t result; - do { - result = inflate( - &zstream_, availSize == remainingLength_ ? Z_FINISH : Z_SYNC_FLUSH); - switch (result) { - case Z_OK: - remainingLength_ -= availSize; - inputBufferPtr_ += availSize; + auto doDecompress = [&]() { + reset(); + int32_t result; + *size = 0; + do { + if (inputBufferPtr_ == inputBufferPtrEnd_) { readBuffer(true); - availSize = std::min( - static_cast(inputBufferPtrEnd_ - inputBufferPtr_), - remainingLength_); - zstream_.next_in = - reinterpret_cast(const_cast(inputBufferPtr_)); - zstream_.avail_in = static_cast(availSize); - break; - case Z_STREAM_END: - break; - default: - DWIO_RAISE( - "Error in ZlibDecompressionStream::Next in ", - getName(), - ". error: ", - result, - " Info: ", - ZlibDecompressor::streamDebugInfo_); - } - } while (result != Z_STREAM_END); - *size = static_cast(blockSize_ - zstream_.avail_out); + } + zstream_.next_in = + reinterpret_cast(const_cast(inputBufferPtr_)); + zstream_.avail_in = + static_cast(inputBufferPtrEnd_ - inputBufferPtr_); + + do { + // size_ of outputBuffer_ is not updated in inflate, so *size is used + // here to ensure enough capacity for the output data. + outputBuffer_->extend(*size); + outputBufferPtr_ = outputBuffer_->data(); + zstream_.next_out = reinterpret_cast( + const_cast(outputBufferPtr_ + *size)); + zstream_.avail_out = folly::to(blockSize_); + result = inflate(&zstream_, Z_SYNC_FLUSH); + // Result handling adapted from https://zlib.net/zlib_how.html + switch (result) { + case Z_NEED_DICT: + result = Z_DATA_ERROR; + [[fallthrough]]; + case Z_DATA_ERROR: + [[fallthrough]]; + case Z_MEM_ERROR: + [[fallthrough]]; + case Z_STREAM_ERROR: + DWIO_RAISE("Failed to inflate input data. error: ", result); + default: + *size += static_cast( + blockSize_ - static_cast(zstream_.avail_out)); + const size_t inputConsumed = + reinterpret_cast(zstream_.next_in) - + inputBufferPtr_; + remainingLength_ -= inputConsumed; + inputBufferPtr_ += inputConsumed; + } + } while (zstream_.avail_out == 0); + } while (result != Z_STREAM_END); + }; + + withDecompressStats(decompressCounter_, [&] { doDecompress(); }); + if (data) { *data = outputBufferPtr_; } @@ -610,8 +626,6 @@ bool ZlibDecompressionStream::readOrSkip(const void** data, int32_t* size) { outputBufferPtr_ += *size; } - inputBufferPtr_ += availSize; - remainingLength_ -= availSize; bytesReturned_ += *size; return true; } @@ -668,7 +682,8 @@ std::unique_ptr createDecompressor( const std::string& streamDebugInfo, const Decrypter* decrypter, bool useRawDecompression, - size_t compressedLength) { + size_t compressedLength, + io::IoCounter* decompressCounter) { std::unique_ptr decompressor; switch (static_cast(kind)) { case CompressionKind::CompressionKind_NONE: @@ -689,7 +704,8 @@ std::unique_ptr createDecompressor( streamDebugInfo, false, useRawDecompression, - compressedLength); + compressedLength, + decompressCounter); } decompressor = std::make_unique( blockSize, options.format.zlib.windowBits, streamDebugInfo, false); @@ -706,7 +722,8 @@ std::unique_ptr createDecompressor( streamDebugInfo, true, useRawDecompression, - compressedLength); + compressedLength, + decompressCounter); } decompressor = std::make_unique( blockSize, options.format.zlib.windowBits, streamDebugInfo, true); @@ -741,7 +758,8 @@ std::unique_ptr createDecompressor( decrypter, streamDebugInfo, useRawDecompression, - compressedLength); + compressedLength, + decompressCounter); } } // namespace facebook::velox::dwio::common::compression diff --git a/velox/dwio/common/compression/Compression.h b/velox/dwio/common/compression/Compression.h index 3d26b3af98a..00204918a4c 100644 --- a/velox/dwio/common/compression/Compression.h +++ b/velox/dwio/common/compression/Compression.h @@ -17,6 +17,7 @@ #pragma once #include "velox/common/compression/Compression.h" +#include "velox/common/io/IoStatistics.h" #include "velox/dwio/common/SeekableInputStream.h" #include "velox/dwio/common/encryption/Encryption.h" @@ -98,6 +99,7 @@ struct CompressionOptions { * @param options The compression options to use * @param useRawDecompression Specify whether to perform raw decompression * @param compressedLength The compressed block length for raw decompression + * @param decompressCounter Optional IoCounter for tracking decompression stats */ std::unique_ptr createDecompressor( facebook::velox::common::CompressionKind kind, @@ -108,7 +110,8 @@ std::unique_ptr createDecompressor( const std::string& streamDebugInfo, const dwio::common::encryption::Decrypter* decryptr = nullptr, bool useRawDecompression = false, - size_t compressedLength = 0); + size_t compressedLength = 0, + io::IoCounter* decompressCounter = nullptr); /** * Create a compressor for the given compression kind. diff --git a/velox/dwio/common/compression/PagedInputStream.cpp b/velox/dwio/common/compression/PagedInputStream.cpp index 08357298cf5..7891f3a3d58 100644 --- a/velox/dwio/common/compression/PagedInputStream.cpp +++ b/velox/dwio/common/compression/PagedInputStream.cpp @@ -16,6 +16,8 @@ #include "velox/dwio/common/compression/PagedInputStream.h" +#include "velox/dwio/common/Statistics.h" + namespace facebook::velox::dwio::common::compression { void PagedInputStream::prepareOutputBuffer(uint64_t uncompressedLength) { @@ -165,7 +167,7 @@ bool PagedInputStream::readOrSkip(const void** data, int32_t* size) { // perform decryption if (decrypter_) { decryptionBuffer_ = - decrypter_->decrypt(folly::StringPiece{input, remainingLength_}); + decrypter_->decrypt(std::string_view{input, remainingLength_}); input = reinterpret_cast(decryptionBuffer_->data()); remainingLength_ = decryptionBuffer_->length(); if (data) { @@ -186,11 +188,13 @@ bool PagedInputStream::readOrSkip(const void** data, int32_t* size) { outputBufferPtr_ = nullptr; } else { prepareOutputBuffer(decompressedLength); - outputBufferLength_ = decompressor_->decompress( - input, - remainingLength_, - outputBuffer_->data(), - outputBuffer_->capacity()); + outputBufferLength_ = withDecompressStats(decompressCounter_, [&] { + return decompressor_->decompress( + input, + remainingLength_, + outputBuffer_->data(), + outputBuffer_->capacity()); + }); if (data) { *data = outputBuffer_->data(); } diff --git a/velox/dwio/common/compression/PagedInputStream.h b/velox/dwio/common/compression/PagedInputStream.h index 15b1acd630a..7045e98483f 100644 --- a/velox/dwio/common/compression/PagedInputStream.h +++ b/velox/dwio/common/compression/PagedInputStream.h @@ -16,6 +16,7 @@ #pragma once +#include "velox/common/io/IoStatistics.h" #include "velox/dwio/common/SeekableInputStream.h" #include "velox/dwio/common/compression/Compression.h" @@ -30,13 +31,15 @@ class PagedInputStream : public dwio::common::SeekableInputStream { const dwio::common::encryption::Decrypter* decrypter, const std::string& streamDebugInfo, bool useRawDecompression = false, - size_t compressedLength = 0) + size_t compressedLength = 0, + io::IoCounter* decompressCounter = nullptr) : input_(std::move(inStream)), pool_(memPool), inputBuffer_(pool_), decompressor_{std::move(decompressor)}, decrypter_{decrypter}, - streamDebugInfo_{streamDebugInfo} { + streamDebugInfo_{streamDebugInfo}, + decompressCounter_{decompressCounter} { DWIO_ENSURE( decompressor_ || decrypter_, "one of decompressor or decryptor is required"); @@ -56,7 +59,7 @@ class PagedInputStream : public dwio::common::SeekableInputStream { // NOTE: This always returns true. bool SkipInt64(int64_t count) override; - google::protobuf::int64 ByteCount() const override { + int64_t ByteCount() const override { return bytesReturned_ + pendingSkip_; } @@ -87,13 +90,15 @@ class PagedInputStream : public dwio::common::SeekableInputStream { memory::MemoryPool& memPool, const std::string& streamDebugInfo, bool useRawDecompression = false, - size_t compressedLength = 0) + size_t compressedLength = 0, + io::IoCounter* decompressCounter = nullptr) : input_(std::move(inStream)), pool_(memPool), inputBuffer_(pool_), decompressor_{nullptr}, decrypter_{nullptr}, - streamDebugInfo_{streamDebugInfo} { + streamDebugInfo_{streamDebugInfo}, + decompressCounter_{decompressCounter} { DWIO_ENSURE( !useRawDecompression || compressedLength > 0, "For raw decompression, compressedLength should be greater than zero"); @@ -182,6 +187,11 @@ class PagedInputStream : public dwio::common::SeekableInputStream { // Stream Debug Info const std::string streamDebugInfo_; + + protected: + // Owned by ColumnReaderStatistics. Valid for the lifetime of this stream + // because ColumnReaderStatistics outlives all streams within a DwrfRowReader. + io::IoCounter* const decompressCounter_{nullptr}; }; } // namespace facebook::velox::dwio::common::compression diff --git a/velox/dwio/common/compression/PagedOutputStream.cpp b/velox/dwio/common/compression/PagedOutputStream.cpp index 18d993bf74f..2519893be4b 100644 --- a/velox/dwio/common/compression/PagedOutputStream.cpp +++ b/velox/dwio/common/compression/PagedOutputStream.cpp @@ -18,7 +18,7 @@ namespace facebook::velox::dwio::common::compression { -std::vector PagedOutputStream::createPage() { +std::vector PagedOutputStream::createPage() { auto origSize = buffer_.size(); VELOX_CHECK_GT(origSize, pageHeaderSize_); origSize -= pageHeaderSize_; @@ -34,15 +34,15 @@ std::vector PagedOutputStream::createPage() { origSize); } - folly::StringPiece compressed; + std::string_view compressed; if (compressedSize >= origSize) { // write orig writeHeader(buffer_.data(), origSize, true); - compressed = folly::StringPiece(buffer_.data(), origSize + pageHeaderSize_); + compressed = std::string_view(buffer_.data(), origSize + pageHeaderSize_); } else { // write compressed writeHeader(compressionBuffer_->data(), compressedSize, false); - compressed = folly::StringPiece( + compressed = std::string_view( compressionBuffer_->data(), compressedSize + pageHeaderSize_); } @@ -50,13 +50,13 @@ std::vector PagedOutputStream::createPage() { return {compressed}; } - encryptionBuffer_ = encryptor_->encrypt(folly::StringPiece( - compressed.begin() + pageHeaderSize_, compressed.end())); + encryptionBuffer_ = encryptor_->encrypt( + std::string_view(compressed.begin() + pageHeaderSize_, compressed.end())); updateSize( const_cast(compressed.begin()), encryptionBuffer_->length()); return { - folly::StringPiece(compressed.begin(), pageHeaderSize_), - folly::StringPiece( + std::string_view(compressed.begin(), pageHeaderSize_), + std::string_view( reinterpret_cast(encryptionBuffer_->data()), encryptionBuffer_->length())}; } diff --git a/velox/dwio/common/compression/PagedOutputStream.h b/velox/dwio/common/compression/PagedOutputStream.h index 498bba3c781..415d2961eca 100644 --- a/velox/dwio/common/compression/PagedOutputStream.h +++ b/velox/dwio/common/compression/PagedOutputStream.h @@ -64,8 +64,8 @@ class PagedOutputStream : public BufferedOutputStream { int32_t strideIndex = -1) const override; private: - // create page using compressor and encryptor - std::vector createPage(); + // Create page using compressor and encryptor. + std::vector createPage(); void writeHeader(char* buffer, size_t compressedSize, bool original); diff --git a/velox/dwio/common/encryption/CMakeLists.txt b/velox/dwio/common/encryption/CMakeLists.txt index 610ebb4684a..6cd1d01807d 100644 --- a/velox/dwio/common/encryption/CMakeLists.txt +++ b/velox/dwio/common/encryption/CMakeLists.txt @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -velox_add_library(velox_dwio_common_encryption Encryption.cpp) +velox_add_library(velox_dwio_common_encryption Encryption.cpp HEADERS Encryption.h TestProvider.h) velox_link_libraries(velox_dwio_common_encryption Folly::folly) diff --git a/velox/dwio/common/encryption/Encryption.cpp b/velox/dwio/common/encryption/Encryption.cpp index 2e7df9a4793..1db5b8c24de 100644 --- a/velox/dwio/common/encryption/Encryption.cpp +++ b/velox/dwio/common/encryption/Encryption.cpp @@ -16,19 +16,11 @@ #include "velox/dwio/common/encryption/Encryption.h" -namespace facebook { -namespace velox { -namespace dwio { -namespace common { -namespace encryption { +namespace facebook::velox::dwio::common::encryption { bool operator==(const EncryptionProperties& a, const EncryptionProperties& b) { return std::addressof(a) == std::addressof(b) || (typeid(a) == typeid(b) && a.equals(b)); } -} // namespace encryption -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio::common::encryption diff --git a/velox/dwio/common/encryption/Encryption.h b/velox/dwio/common/encryption/Encryption.h index db3c2196d95..798f67d3c9f 100644 --- a/velox/dwio/common/encryption/Encryption.h +++ b/velox/dwio/common/encryption/Encryption.h @@ -20,11 +20,7 @@ #include "folly/io/IOBuf.h" #include "velox/dwio/common/exception/Exception.h" -namespace facebook { -namespace velox { -namespace dwio { -namespace common { -namespace encryption { +namespace facebook::velox::dwio::common::encryption { enum class EncryptionProvider { Unknown = 0, CryptoService }; @@ -64,7 +60,7 @@ class Encrypter { virtual const std::string& getKey() const = 0; virtual std::unique_ptr encrypt( - folly::StringPiece input) const = 0; + std::string_view input) const = 0; virtual std::unique_ptr clone() const = 0; }; @@ -87,7 +83,7 @@ class Decrypter { virtual bool isKeyLoaded() const = 0; virtual std::unique_ptr decrypt( - folly::StringPiece input) const = 0; + std::string_view input) const = 0; virtual std::unique_ptr clone() const = 0; }; @@ -108,7 +104,7 @@ class DummyDecrypter : public Decrypter { } std::unique_ptr decrypt( - folly::StringPiece /* unused */) const override { + std::string_view /* unused */) const override { DWIO_RAISE("Failed to access encrypted data"); } @@ -127,8 +123,4 @@ class DummyDecrypterFactory : public DecrypterFactory { } }; -} // namespace encryption -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio::common::encryption diff --git a/velox/dwio/common/encryption/TestProvider.h b/velox/dwio/common/encryption/TestProvider.h index 7ef41de8e57..b50f727a1e9 100644 --- a/velox/dwio/common/encryption/TestProvider.h +++ b/velox/dwio/common/encryption/TestProvider.h @@ -20,12 +20,7 @@ #include "velox/dwio/common/encryption/Encryption.h" #include "velox/dwio/common/exception/Exception.h" -namespace facebook { -namespace velox { -namespace dwio { -namespace common { -namespace encryption { -namespace test { +namespace facebook::velox::dwio::common::encryption::test { class TestEncryption { public: @@ -37,18 +32,19 @@ class TestEncryption { return key_; } - std::unique_ptr encrypt(folly::StringPiece input) const { + std::unique_ptr encrypt(std::string_view input) const { ++count_; auto encoded = velox::encoding::Base64::encodeUrl(input); return folly::IOBuf::copyBuffer(key_ + encoded); } - std::unique_ptr decrypt(folly::StringPiece input) const { + std::unique_ptr decrypt(std::string_view input) const { ++count_; - std::string key{input.begin(), key_.size()}; + std::string key{input.cbegin(), key_.size()}; DWIO_ENSURE_EQ(key_, key); - auto decoded = velox::encoding::Base64::decodeUrl(folly::StringPiece{ - input.begin() + key_.size(), input.size() - key_.size()}); + auto decoded = velox::encoding::Base64::decodeUrl( + std::string_view{ + input.begin() + key_.size(), input.size() - key_.size()}); return folly::IOBuf::copyBuffer(decoded); } @@ -67,8 +63,7 @@ class TestEncrypter : public TestEncryption, public Encrypter { return TestEncryption::getKey(); } - std::unique_ptr encrypt( - folly::StringPiece input) const override { + std::unique_ptr encrypt(std::string_view input) const override { return TestEncryption::encrypt(input); } @@ -89,8 +84,7 @@ class TestDecrypter : public TestEncryption, public Decrypter { return !getKey().empty(); } - std::unique_ptr decrypt( - folly::StringPiece input) const override { + std::unique_ptr decrypt(std::string_view input) const override { return TestEncryption::decrypt(input); } @@ -144,9 +138,4 @@ class TestDecrypterFactory : public DecrypterFactory { } }; -} // namespace test -} // namespace encryption -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio::common::encryption::test diff --git a/velox/dwio/common/exception/CMakeLists.txt b/velox/dwio/common/exception/CMakeLists.txt index c30b1577c29..d981830ca69 100644 --- a/velox/dwio/common/exception/CMakeLists.txt +++ b/velox/dwio/common/exception/CMakeLists.txt @@ -12,6 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -velox_add_library(velox_dwio_common_exception Exception.cpp Exceptions.cpp) +velox_add_library( + velox_dwio_common_exception + Exception.cpp + Exceptions.cpp + HEADERS + Exception.h + Exceptions.h +) velox_link_libraries(velox_dwio_common_exception velox_exception Folly::folly glog::glog) diff --git a/velox/dwio/common/exception/Exception.cpp b/velox/dwio/common/exception/Exception.cpp index 97b0b78fd3c..e15f503d3ee 100644 --- a/velox/dwio/common/exception/Exception.cpp +++ b/velox/dwio/common/exception/Exception.cpp @@ -17,11 +17,7 @@ #include "velox/dwio/common/exception/Exception.h" #include "velox/common/base/Exceptions.h" -namespace facebook { -namespace velox { -namespace dwio { -namespace common { -namespace exception { +namespace facebook::velox::dwio::common::exception { std::unique_ptr& exceptionLogger() { static std::unique_ptr logger(nullptr); @@ -40,8 +36,4 @@ ExceptionLogger* getExceptionLogger() { return exceptionLogger().get(); } -} // namespace exception -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio::common::exception diff --git a/velox/dwio/common/exception/Exception.h b/velox/dwio/common/exception/Exception.h index 5e5e5b85f73..140df90e195 100644 --- a/velox/dwio/common/exception/Exception.h +++ b/velox/dwio/common/exception/Exception.h @@ -18,11 +18,8 @@ #include "velox/common/base/VeloxException.h" -namespace facebook { -namespace velox { -namespace dwio { -namespace common { -namespace exception { +namespace facebook::velox::dwio { +namespace common::exception { class ExceptionLogger { public: @@ -101,8 +98,7 @@ class LoggedException : public velox::VeloxException { } }; -} // namespace exception -} // namespace common +} // namespace common::exception #define DWIO_WARN_IF(e, ...) \ ({ \ @@ -256,6 +252,4 @@ containing information about the file, line, and function where it happened. "]: ", \ ##__VA_ARGS__); -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio diff --git a/velox/dwio/common/exception/Exceptions.cpp b/velox/dwio/common/exception/Exceptions.cpp index 04835dc5ac1..6c1bc5818e3 100644 --- a/velox/dwio/common/exception/Exceptions.cpp +++ b/velox/dwio/common/exception/Exceptions.cpp @@ -22,10 +22,7 @@ #include #include -namespace facebook { -namespace velox { -namespace dwio { -namespace common { +namespace facebook::velox::dwio::common { void verify_range(uint64_t v, uint64_t rangeMask) { auto mv = (v & rangeMask); @@ -67,7 +64,4 @@ std::string format_error_string(std::string fmt...) { return s; } -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/exception/Exceptions.h b/velox/dwio/common/exception/Exceptions.h index d6963468ea4..c241e5788f4 100644 --- a/velox/dwio/common/exception/Exceptions.h +++ b/velox/dwio/common/exception/Exceptions.h @@ -22,10 +22,7 @@ #include #include "velox/dwio/common/exception/Exception.h" -namespace facebook { -namespace velox { -namespace dwio { -namespace common { +namespace facebook::velox::dwio::common { class NotImplementedYet : public std::logic_error { public: @@ -117,7 +114,4 @@ using logic_error = exception_error; using runtime_error = exception_error; using EOF_error = exception_error; -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/tests/BufferedInputTest.cpp b/velox/dwio/common/tests/BufferedInputTest.cpp new file mode 100644 index 00000000000..820394cb350 --- /dev/null +++ b/velox/dwio/common/tests/BufferedInputTest.cpp @@ -0,0 +1,798 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/dwio/common/BufferedInput.h" + +#include +#include +#include +#include + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/caching/FileIds.h" +#include "velox/common/file/tests/TestUtils.h" +#include "velox/connectors/hive/BufferedInputBuilder.h" +#include "velox/dwio/common/DirectBufferedInput.h" +#include "velox/dwio/dwrf/test/TestReadFile.h" + +using namespace facebook::velox::dwio::common; +namespace cache = facebook::velox::cache; +using facebook::velox::StringIdLease; +using facebook::velox::common::Region; +using namespace facebook::velox::memory; +using namespace ::testing; + +namespace { + +class ReadFileMock : public ::facebook::velox::ReadFile { + public: + virtual ~ReadFileMock() override = default; + +// On Centos9 the gtest mock header doesn't initialize the +// buffer_ member in MatcherBase correctly - the default constructor only +// initializes one: /usr/include/gtest/gtest-matchers.h:302:33 resulting in +// error: +// '.testing::Matcher::.testing::internal::MatcherBase::buffer_' is used uninitialized +// [-Werror=uninitialized] +// 302 | : vtable_(other.vtable_), buffer_(other.buffer_) { +// Fix: https://github.com/google/googletest/pull/3797 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wuninitialized" + MOCK_METHOD( + std::string_view, + pread, + (uint64_t offset, + uint64_t length, + void* buf, + (const facebook::velox::FileIoContext&)context), + (const, override)); + + MOCK_METHOD(bool, shouldCoalesce, (), (const, override)); + MOCK_METHOD(uint64_t, size, (), (const, override)); + MOCK_METHOD(uint64_t, memoryUsage, (), (const, override)); + MOCK_METHOD(std::string, getName, (), (const, override)); + MOCK_METHOD(uint64_t, getNaturalReadSize, (), (const, override)); + MOCK_METHOD( + uint64_t, + preadv, + (folly::Range regions, + folly::Range iobufs, + (const facebook::velox::FileIoContext&)context), + (const, override)); +}; + +void expectPreads( + ReadFileMock& file, + std::string_view content, + std::vector reads) { + EXPECT_CALL(file, getName()).WillRepeatedly(Return("mock_name")); + EXPECT_CALL(file, size()).WillRepeatedly(Return(content.size())); + for (auto& read : reads) { + ASSERT_GE(content.size(), read.offset + read.length); + EXPECT_CALL(file, pread(read.offset, read.length, _, _)) + .Times(1) + .WillOnce( + [content]( + uint64_t offset, + uint64_t length, + void* buf, + const facebook::velox::FileIoContext& context) + -> std::string_view { + memcpy(buf, content.data() + offset, length); + return {content.data() + offset, length}; + }); + } +} + +void expectPreadvs( + ReadFileMock& file, + std::string_view content, + std::vector reads) { + EXPECT_CALL(file, getName()).WillRepeatedly(Return("mock_name")); + EXPECT_CALL(file, size()).WillRepeatedly(Return(content.size())); + EXPECT_CALL(file, preadv(_, _, _)) + .Times(1) + .WillOnce( + [content, reads]( + folly::Range regions, + folly::Range iobufs, + const facebook::velox::FileIoContext& context) -> uint64_t { + EXPECT_EQ(regions.size(), reads.size()); + uint64_t length = 0; + for (size_t i = 0; i < reads.size(); ++i) { + const auto& region = regions[i]; + const auto& read = reads[i]; + auto& iobuf = iobufs[i]; + length += region.length; + EXPECT_EQ(region.offset, read.offset); + EXPECT_EQ(region.length, read.length); + if (!read.label.empty()) { + EXPECT_EQ(read.label, region.label); + } + EXPECT_LE(region.offset + region.length, content.size()); + iobuf = folly::IOBuf( + folly::IOBuf::COPY_BUFFER, + content.data() + region.offset, + region.length); + } + + return length; + }); +} +#pragma GCC diagnostic pop + +std::optional getNext(SeekableInputStream& input) { + const void* buf = nullptr; + int32_t size; + if (input.Next(&buf, &size)) { + return std::string( + static_cast(buf), static_cast(size)); + } else { + return std::nullopt; + } +} + +class BufferedInputTest : public testing::Test { + protected: + static void SetUpTestCase() { + MemoryManager::testingSetInstance(MemoryManager::Options{}); + } + + const std::shared_ptr pool_ = memoryManager()->addLeafPool(); +}; + +TEST_F(BufferedInputTest, hasCache) { + auto readFile = + std::make_shared(std::string("test")); + BufferedInput input(readFile, *pool_); + // Base BufferedInput does not have cache. + EXPECT_FALSE(input.hasCache()); + + // Cache APIs throw when there is no backing cache. + VELOX_ASSERT_THROW( + input.cacheRegion(0, 4, std::string_view("test")), + "cacheRegion requires a backing cache"); + VELOX_ASSERT_THROW( + input.findCachedRegion(0), "findCachedRegion requires a backing cache"); +} + +TEST_F(BufferedInputTest, cachedRegion) { + auto dataCache = cache::AsyncDataCache::create(memoryManager()->allocator()); + + // Empty pin throws. + VELOX_ASSERT_THROW( + CachedRegion(cache::CachePin{}), + "CachedRegion requires a non-empty cache pin"); + + // Exclusive pin throws. + auto& ids = facebook::velox::fileIds(); + { + StringIdLease fileId(ids, "exclusiveTestFile"); + cache::RawFileCacheKey key{fileId.id(), 0}; + auto pin = dataCache->findOrCreate(key, 100); + ASSERT_FALSE(pin.empty()); + ASSERT_TRUE(pin.checkedEntry()->isExclusive()); + VELOX_ASSERT_THROW( + CachedRegion(std::move(pin)), + "CachedRegion requires a shared (non-exclusive) cache pin"); + } + + struct TestParam { + uint64_t entrySize; + bool expectTinyData; + + std::string debugString() const { + return fmt::format( + "entrySize {}, expectTinyData {}", entrySize, expectTinyData); + } + }; + std::vector testSettings = { + // Small entry uses tinyData (single contiguous range). + {100, true}, + // Entry just below tinyData boundary still uses tinyData. + {cache::AsyncDataCacheEntry::kTinyDataSize - 1, true}, + // Entry at tinyData boundary uses allocation. + {cache::AsyncDataCacheEntry::kTinyDataSize, false}, + // One page allocation (single run). + {AllocationTraits::kPageSize, false}, + // Large allocation (possibly multiple runs). + {128 << 10, false}, + }; + + for (size_t i = 0; i < testSettings.size(); ++i) { + const auto& testData = testSettings[i]; + SCOPED_TRACE(testData.debugString()); + + const auto entrySize = testData.entrySize; + std::string expected(entrySize, '\0'); + for (uint64_t j = 0; j < entrySize; ++j) { + expected[j] = static_cast('a' + (j % 26)); + } + + StringIdLease fileId(ids, fmt::format("cachedRegionTestFile_{}", i)); + cache::RawFileCacheKey key{fileId.id(), 0}; + auto pin = dataCache->findOrCreate(key, entrySize); + ASSERT_FALSE(pin.empty()); + auto* entry = pin.checkedEntry(); + ASSERT_TRUE(entry->isExclusive()); + + // Populate the entry with test data. + if (testData.expectTinyData) { + ASSERT_TRUE(entry->hasContiguousData()); + memcpy(entry->contiguousData(), expected.data(), entrySize); + } else { + auto& allocation = entry->nonContiguousData(); + ASSERT_GT(allocation.numRuns(), 0); + uint64_t offset = 0; + for (int i = 0; i < allocation.numRuns() && offset < entrySize; ++i) { + auto run = allocation.runAt(i); + const uint64_t bytes = run.numPages() * AllocationTraits::kPageSize; + const uint64_t copySize = std::min(bytes, entrySize - offset); + memcpy(run.data(), expected.data() + offset, copySize); + offset += copySize; + } + } + entry->setExclusiveToShared(); + + CachedRegion region(std::move(pin)); + EXPECT_EQ(region.size(), entrySize); + ASSERT_FALSE(region.ranges().empty()); + + if (testData.expectTinyData) { + EXPECT_EQ(region.ranges().size(), 1); + } + + // Verify content through ranges. + uint64_t verified = 0; + for (const auto& range : region.ranges()) { + EXPECT_EQ( + std::string_view(range.data(), range.size()), + std::string_view(expected.data() + verified, range.size())); + verified += range.size(); + } + EXPECT_EQ(verified, entrySize); + + // Verify toIOBuf produces identical content. + auto iobuf = region.toIOBuf(); + EXPECT_EQ(iobuf.computeChainDataLength(), entrySize); + iobuf.coalesce(); + EXPECT_EQ( + std::string_view( + reinterpret_cast(iobuf.data()), iobuf.length()), + expected); + } +} + +TEST_F(BufferedInputTest, zeroLengthStream) { + auto readFile = + std::make_shared(std::string()); + BufferedInput input(readFile, *pool_); + auto ret = input.enqueue({0, 0}); + EXPECT_EQ(input.nextFetchSize(), 0); + EXPECT_NE(ret, nullptr); + const void* buf = nullptr; + int32_t size = 1; + EXPECT_FALSE(ret->Next(&buf, &size)); + EXPECT_EQ(size, 0); +} + +TEST_F(BufferedInputTest, useRead) { + std::string content = "hello"; + auto readFileMock = std::make_shared(); + expectPreads(*readFileMock, content, {{0, 5}}); + // Use read + BufferedInput input( + readFileMock, + *pool_, + MetricsLog::voidLog(), + nullptr, + nullptr, + 10, + /* wsVRLoad = */ false); + auto ret = input.enqueue({0, 5}); + ASSERT_NE(ret, nullptr); + + EXPECT_EQ(input.nextFetchSize(), 5); + input.load(LogType::TEST); + + auto next = getNext(*ret); + ASSERT_TRUE(next.has_value()); + EXPECT_EQ(next.value(), content); +} + +TEST_F(BufferedInputTest, useVRead) { + std::string content = "hello"; + auto readFileMock = std::make_shared(); + expectPreadvs(*readFileMock, content, {{0, 5}}); + // Use vread + BufferedInput input( + readFileMock, + *pool_, + MetricsLog::voidLog(), + nullptr, + nullptr, + 10, + /* wsVRLoad = */ true); + auto ret = input.enqueue({0, 5}); + ASSERT_NE(ret, nullptr); + + EXPECT_EQ(input.nextFetchSize(), 5); + input.load(LogType::TEST); + + auto next = getNext(*ret); + ASSERT_TRUE(next.has_value()); + EXPECT_EQ(next.value(), content); +} + +TEST_F(BufferedInputTest, willMerge) { + std::string content = "hello world"; + auto readFileMock = std::make_shared(); + + // Will merge because the distance is 1 and max distance to merge is 10. + // Expect only one call. + expectPreads(*readFileMock, content, {{0, 11}}); + + BufferedInput input( + readFileMock, + *pool_, + MetricsLog::voidLog(), + nullptr, + nullptr, + 10, // Will merge if distance <= 10 + /* wsVRLoad = */ false); + + auto ret1 = input.enqueue({0, 5}); + auto ret2 = input.enqueue({6, 5}); + ASSERT_NE(ret1, nullptr); + ASSERT_NE(ret2, nullptr); + + EXPECT_EQ(input.nextFetchSize(), 10); + input.load(LogType::TEST); + + auto next1 = getNext(*ret1); + ASSERT_TRUE(next1.has_value()); + EXPECT_EQ(next1.value(), "hello"); + + auto next2 = getNext(*ret2); + ASSERT_TRUE(next2.has_value()); + EXPECT_EQ(next2.value(), "world"); +} + +TEST_F(BufferedInputTest, wontMerge) { + std::string content = "hello world"; // two spaces + auto readFileMock = std::make_shared(); + + // Won't merge because the distance is 2 and max distance to merge is 1. + // Expect two calls + expectPreads(*readFileMock, content, {{0, 5}, {7, 5}}); + + BufferedInput input( + readFileMock, + *pool_, + MetricsLog::voidLog(), + nullptr, + nullptr, + 1, // Will merge if distance <= 1 + /* wsVRLoad = */ false); + + auto ret1 = input.enqueue({0, 5}); + auto ret2 = input.enqueue({7, 5}); + ASSERT_NE(ret1, nullptr); + ASSERT_NE(ret2, nullptr); + + EXPECT_EQ(input.nextFetchSize(), 10); + input.load(LogType::TEST); + + auto next1 = getNext(*ret1); + ASSERT_TRUE(next1.has_value()); + EXPECT_EQ(next1.value(), "hello"); + + auto next2 = getNext(*ret2); + ASSERT_TRUE(next2.has_value()); + EXPECT_EQ(next2.value(), "world"); +} + +TEST_F(BufferedInputTest, readSorting) { + std::string content = "aaabbbcccdddeeefffggghhhiiijjjkkklllmmmnnnooopppqqq"; + std::vector regions = {{6, 3}, {24, 3}, {3, 3}, {0, 3}, {29, 3}}; + + auto readFileMock = std::make_shared(); + expectPreads(*readFileMock, content, {{0, 9}, {24, 3}, {29, 3}}); + BufferedInput input( + readFileMock, + *pool_, + MetricsLog::voidLog(), + nullptr, + nullptr, + 1, // Will merge if distance <= 1 + /* wsVRLoad = */ false); + + std::vector, std::string>> + result; + result.reserve(regions.size()); + int64_t bytesToRead = 0; + for (auto& region : regions) { + bytesToRead += region.length; + auto ret = input.enqueue(region); + ASSERT_NE(ret, nullptr); + result.push_back( + {std::move(ret), content.substr(region.offset, region.length)}); + } + + EXPECT_EQ(input.nextFetchSize(), bytesToRead); + input.load(LogType::TEST); + + for (auto& r : result) { + auto next = getNext(*r.first); + ASSERT_TRUE(next.has_value()); + EXPECT_EQ(next.value(), r.second); + } +} + +TEST_F(BufferedInputTest, VreadSorting) { + std::string content = "aaabbbcccdddeeefffggghhhiiijjjkkklllmmmnnnooopppqqq"; + std::vector regions = {{6, 3}, {24, 3}, {3, 3}, {0, 3}, {29, 3}}; + + auto readFileMock = std::make_shared(); + expectPreadvs( + *readFileMock, content, {{0, 3}, {3, 3}, {6, 3}, {24, 3}, {29, 3}}); + BufferedInput input( + readFileMock, + *pool_, + MetricsLog::voidLog(), + nullptr, + nullptr, + 1, // Will merge if distance <= 1 + /* wsVRLoad = */ true); + + std::vector, std::string>> + result; + result.reserve(regions.size()); + int64_t bytesToRead = 0; + for (auto& region : regions) { + bytesToRead += region.length; + auto ret = input.enqueue(region); + ASSERT_NE(ret, nullptr); + result.push_back( + {std::move(ret), content.substr(region.offset, region.length)}); + } + + EXPECT_EQ(input.nextFetchSize(), bytesToRead); + input.load(LogType::TEST); + + for (auto& r : result) { + auto next = getNext(*r.first); + ASSERT_TRUE(next.has_value()); + EXPECT_EQ(next.value(), r.second); + } +} + +TEST_F(BufferedInputTest, VreadSortingWithLabels) { + std::string content = "aaabbbcccdddeeefffggghhhiiijjjkkklllmmmnnnooopppqqq"; + std::vector l = {"a", "b", "c", "d", "e"}; + std::vector regions = { + {6, 3, l[2]}, {24, 3, l[3]}, {3, 3, l[1]}, {0, 3, l[0]}, {29, 3, l[4]}}; + + auto readFileMock = std::make_shared(); + expectPreadvs( + *readFileMock, + content, + {{0, 3, l[0]}, {3, 3, l[1]}, {6, 3, l[2]}, {24, 3, l[3]}, {29, 3, l[4]}}); + BufferedInput input( + readFileMock, + *pool_, + MetricsLog::voidLog(), + nullptr, + nullptr, + 1, // Will merge if distance <= 1 + /* wsVRLoad = */ true); + + std::vector, std::string>> + result; + result.reserve(regions.size()); + int64_t bytesToRead = 0; + for (auto& region : regions) { + bytesToRead += region.length; + auto ret = input.enqueue(region); + ASSERT_NE(ret, nullptr); + result.push_back( + {std::move(ret), content.substr(region.offset, region.length)}); + } + + EXPECT_EQ(input.nextFetchSize(), bytesToRead); + input.load(LogType::TEST); + + for (auto& r : result) { + auto next = getNext(*r.first); + ASSERT_TRUE(next.has_value()); + EXPECT_EQ(next.value(), r.second); + } +} + +TEST_F(BufferedInputTest, resetAfterAllStreamsConsumed) { + const std::string content = "aaabbbcccdddeee"; + auto readFileMock = std::make_shared(); + + // First round: enqueue and consume all streams. + expectPreads(*readFileMock, content, {{0, 6}}); + + BufferedInput input( + readFileMock, + *pool_, + MetricsLog::voidLog(), + nullptr, + nullptr, + 10, + /* wsVRLoad = */ false); + + const auto ret1 = input.enqueue({0, 3}); + const auto ret2 = input.enqueue({3, 3}); + ASSERT_NE(ret1, nullptr); + ASSERT_NE(ret2, nullptr); + + input.load(LogType::TEST); + + // Consume all streams. + auto next1 = getNext(*ret1); + ASSERT_TRUE(next1.has_value()); + EXPECT_EQ(next1.value(), "aaa"); + + const auto next2 = getNext(*ret2); + ASSERT_TRUE(next2.has_value()); + EXPECT_EQ(next2.value(), "bbb"); + + // Reset and enqueue new streams. + input.reset(); + EXPECT_EQ(input.nextFetchSize(), 0); + + // Second round: enqueue different regions after reset. + expectPreads(*readFileMock, content, {{6, 9}}); + + const auto ret3 = input.enqueue({6, 3}); + const auto ret4 = input.enqueue({9, 6}); + ASSERT_NE(ret3, nullptr); + ASSERT_NE(ret4, nullptr); + + EXPECT_EQ(input.nextFetchSize(), 9); + input.load(LogType::TEST); + + const auto next3 = getNext(*ret3); + ASSERT_TRUE(next3.has_value()); + EXPECT_EQ(next3.value(), "ccc"); + + const auto next4 = getNext(*ret4); + ASSERT_TRUE(next4.has_value()); + EXPECT_EQ(next4.value(), "dddeee"); +} + +TEST_F(BufferedInputTest, resetAfterPartialStreamsConsumed) { + const std::string content = "aaabbbcccdddeee"; + auto readFileMock = std::make_shared(); + + // First round: enqueue streams but only consume some. + expectPreads(*readFileMock, content, {{0, 9}}); + + BufferedInput input( + readFileMock, + *pool_, + MetricsLog::voidLog(), + nullptr, + nullptr, + 10, + /*wsVRLoad=*/false); + + const auto ret1 = input.enqueue({0, 3}); + const auto ret2 = input.enqueue({3, 3}); + const auto ret3 = input.enqueue({6, 3}); + ASSERT_NE(ret1, nullptr); + ASSERT_NE(ret2, nullptr); + ASSERT_NE(ret3, nullptr); + + input.load(LogType::TEST); + + // Only consume the first stream, leave ret2 and ret3 unconsumed. + auto next1 = getNext(*ret1); + ASSERT_TRUE(next1.has_value()); + EXPECT_EQ(next1.value(), "aaa"); + + // Reset without consuming all streams. + input.reset(); + EXPECT_EQ(input.nextFetchSize(), 0); + + // Second round: enqueue different regions after reset. + expectPreads(*readFileMock, content, {{9, 6}}); + + const auto ret4 = input.enqueue({9, 6}); + ASSERT_NE(ret4, nullptr); + + EXPECT_EQ(input.nextFetchSize(), 6); + input.load(LogType::TEST); + + const auto next4 = getNext(*ret4); + ASSERT_TRUE(next4.has_value()); + EXPECT_EQ(next4.value(), "dddeee"); +} + +TEST_F(BufferedInputTest, preload) { + std::string content = "hello world, this is preload test data!"; + auto readFile = + std::make_shared( + content); + + BufferedInput input( + readFile, + *pool_, + MetricsLog::voidLog(), + nullptr, + nullptr, + 10, + /* wsVRLoad = */ false); + + ASSERT_EQ(readFile->numReads(), 0); + EXPECT_FALSE(input.preloaded()); + + input.preload(); + EXPECT_TRUE(input.preloaded()); + + const auto readsAfterPreload = readFile->numReads(); + ASSERT_GT(readsAfterPreload, 0); + + // After preload, sub-region reads should be served from preloaded data. + auto stream1 = input.read(0, 5, LogType::FILE); + auto next1 = getNext(*stream1); + ASSERT_TRUE(next1.has_value()); + EXPECT_EQ(next1.value(), "hello"); + + auto stream2 = input.read(6, 5, LogType::FILE); + auto next2 = getNext(*stream2); + ASSERT_TRUE(next2.has_value()); + EXPECT_EQ(next2.value(), "world"); + + // No additional file reads after preload. + ASSERT_EQ(readFile->numReads(), readsAfterPreload); +} + +class CustomDirectBufferedInput + : public facebook::velox::dwio::common::DirectBufferedInput { + public: + CustomDirectBufferedInput( + std::shared_ptr readFile, + const facebook::velox::dwio::common::MetricsLogPtr& metricsLog, + facebook::velox::StringIdLease fileNum, + std::shared_ptr tracker, + facebook::velox::StringIdLease groupId, + std::shared_ptr ioStatistics, + std::shared_ptr ioStats, + folly::Executor* executor, + const facebook::velox::io::ReaderOptions& readerOptions, + folly::F14FastMap fileReadOps = {}) + : DirectBufferedInput( + std::move(readFile), + metricsLog, + std::move(fileNum), + std::move(tracker), + std::move(groupId), + std::move(ioStatistics), + std::move(ioStats), + executor, + readerOptions, + std::move(fileReadOps)) { + VELOX_NYI("Not implemented in CustomBufferedInputBuilder"); + } + + std::unique_ptr clone() + const override { + return std::unique_ptr( + new CustomDirectBufferedInput( + input_, + fileNum_, + tracker_, + groupId_, + ioStatistics_, + ioStats_, + executor_, + options_)); + } + + protected: + // Expose protected members to verify their accessibility. + using facebook::velox::dwio::common::DirectBufferedInput::coalescedLoads_; + using facebook::velox::dwio::common::DirectBufferedInput::requests_; + + private: + CustomDirectBufferedInput( + std::shared_ptr input, + facebook::velox::StringIdLease fileNum, + std::shared_ptr tracker, + facebook::velox::StringIdLease groupId, + std::shared_ptr ioStatistics, + std::shared_ptr ioStats, + folly::Executor* executor, + const facebook::velox::io::ReaderOptions& readerOptions) + : DirectBufferedInput( + std::move(input), + std::move(fileNum), + std::move(tracker), + std::move(groupId), + std::move(ioStatistics), + std::move(ioStats), + executor, + readerOptions) {} +}; + +class CustomBufferedInputBuilder + : public facebook::velox::connector::hive::BufferedInputBuilder { + public: + std::unique_ptr create( + const facebook::velox::FileHandle& fileHandle, + const facebook::velox::dwio::common::ReaderOptions& readerOpts, + const facebook::velox::connector::ConnectorQueryCtx* connectorQueryCtx, + std::shared_ptr ioStatistics, + std::shared_ptr ioStats, + folly::Executor* executor, + const folly::F14FastMap& fileReadOps = {}) + override { + auto file = std::make_shared(11, 100 << 20, ioStats); + auto tracker = std::make_shared( + "", nullptr, /*loadQuantum=*/8 << 20); + return std::make_unique( + std::move(file), + facebook::velox::dwio::common::MetricsLog::voidLog(), + fileHandle.uuid, + std::move(tracker), + fileHandle.groupId, + std::move(ioStatistics), + std::move(ioStats), + executor, + readerOpts, + fileReadOps); + } +}; + +class CustomBufferedInputTest : public testing::Test { + protected: + static void SetUpTestCase() { + MemoryManager::testingSetInstance(MemoryManager::Options{}); + facebook::velox::connector::hive::BufferedInputBuilder::registerBuilder( + std::make_shared()); + } + + const std::shared_ptr pool_ = memoryManager()->addLeafPool(); + std::shared_ptr dataIoStats_{ + std::make_shared()}; + std::shared_ptr metadataIoStats_{ + std::make_shared()}; +}; + +} // namespace + +TEST_F(CustomBufferedInputTest, basic) { + facebook::velox::FileHandle fileHandle; + facebook::velox::dwio::common::ReaderOptions readerOpts(pool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); + auto ioStatistics = std::make_shared(); + auto ioStats = std::make_shared(); + auto executor = std::make_unique(10, 10); + + VELOX_ASSERT_THROW( + facebook::velox::connector::hive::BufferedInputBuilder::getInstance() + ->create( + fileHandle, readerOpts, nullptr, ioStatistics, ioStats, nullptr), + "Not implemented in CustomBufferedInputBuilder"); +} diff --git a/velox/dwio/common/tests/CMakeLists.txt b/velox/dwio/common/tests/CMakeLists.txt index b65afce1367..ccf991c6daf 100644 --- a/velox/dwio/common/tests/CMakeLists.txt +++ b/velox/dwio/common/tests/CMakeLists.txt @@ -14,6 +14,7 @@ add_subdirectory(utils) add_library(velox_dwio_faulty_file_sink FaultyFileSink.cpp) +velox_add_test_headers(velox_dwio_faulty_file_sink FaultyFileSink.h) target_link_libraries(velox_dwio_faulty_file_sink velox_file_test_utils velox_dwio_common) # There is an issue with the VTT symbol for the InlineExecutor from folly when @@ -38,6 +39,7 @@ add_executable( DecoderUtilTest.cpp ExecutorBarrierTest.cpp OnDemandUnitLoaderTests.cpp + ParallelUnitLoaderTest.cpp LocalFileSinkTest.cpp MemorySinkTest.cpp LoggedExceptionTest.cpp @@ -48,21 +50,26 @@ add_executable( ReaderTest.cpp RetryTests.cpp ScanSpecTest.cpp + SelectiveColumnReaderTest.cpp SortingWriterTest.cpp - TestBufferedInput.cpp + StreamUtilTest.cpp + BufferedInputTest.cpp + CachedBufferedInputTest.cpp + DirectBufferedInputTest.cpp ThrottlerTest.cpp TypeTests.cpp UnitLoaderToolsTests.cpp WriterTest.cpp OptionsTests.cpp ) +velox_add_test_headers(velox_dwio_common_test UnitLoaderBaseTest.h) add_test(velox_dwio_common_test velox_dwio_common_test) target_link_libraries( velox_dwio_common_test velox_dwio_common_test_utils - velox_temp_path + velox_hive_connector + velox_test_util velox_vector_test_lib - Boost::regex velox_link_libs Folly::folly ${TEST_LINK_LIBS} @@ -97,7 +104,7 @@ if(VELOX_ENABLE_BENCHMARKS) ) if(VELOX_ENABLE_ARROW) - add_subdirectory(Lemire/FastPFor) + add_subdirectory(Lemire) add_executable(velox_dwio_common_bitpack_decoder_benchmark BitPackDecoderBenchmark.cpp) target_compile_options( diff --git a/velox/dwio/common/tests/CachedBufferedInputTest.cpp b/velox/dwio/common/tests/CachedBufferedInputTest.cpp new file mode 100644 index 00000000000..a1caeca3728 --- /dev/null +++ b/velox/dwio/common/tests/CachedBufferedInputTest.cpp @@ -0,0 +1,1733 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/dwio/common/CachedBufferedInput.h" + +#include + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/caching/FileIds.h" +#include "velox/common/file/tests/TestUtils.h" +#include "velox/common/io/IoStatistics.h" +#include "velox/common/io/Options.h" +#include "velox/common/memory/MallocAllocator.h" +#include "velox/common/testutil/TestValue.h" + +#include +#include +#include + +#include + +using namespace facebook::velox; +using namespace facebook::velox::dwio::common; +using namespace facebook::velox::cache; +using namespace facebook::velox::memory; +using facebook::velox::common::testutil::TestValue; + +namespace { + +std::optional getNext(SeekableInputStream& input) { + const void* buf = nullptr; + int32_t size; + if (input.Next(&buf, &size)) { + return std::string( + static_cast(buf), static_cast(size)); + } else { + return std::nullopt; + } +} + +class CachedBufferedInputTest : public testing::Test { + protected: + static void SetUpTestCase() { + TestValue::enable(); + MemoryManager::testingSetInstance(MemoryManager::Options{}); + } + + void SetUp() override { + executor_ = std::make_unique(10); + allocator_ = std::make_shared(MemoryAllocator::Options{ + .capacity = 512 << 20, .reservationByteLimit = 0}); + cache_ = AsyncDataCache::create(allocator_.get()); + tracker_ = std::make_shared( + "testTracker", nullptr, 256 << 10 /* 256KB */); + pool_ = memoryManager()->addLeafPool(); + } + + void TearDown() override { + executor_.reset(); + cache_->shutdown(); + cache_.reset(); + allocator_.reset(); + } + + const std::shared_ptr dataIoStats_{ + std::make_shared()}; + const std::shared_ptr metadataIoStats_{ + std::make_shared()}; + + std::unique_ptr executor_; + std::shared_ptr pool_; + std::shared_ptr allocator_; + std::shared_ptr cache_; + std::shared_ptr tracker_; +}; + +enum class CacheRegionApi { + kStringView, + kIOBuf, +}; + +std::string cacheRegionApiString(CacheRegionApi api) { + switch (api) { + case CacheRegionApi::kStringView: + return "StringView"; + case CacheRegionApi::kIOBuf: + return "IOBuf"; + } + VELOX_UNREACHABLE(); +} + +class CacheRegionTest : public CachedBufferedInputTest, + public testing::WithParamInterface { + protected: + // Calls cacheRegion using the API selected by the test parameter. + // For IOBuf mode, splits the source data into multiple chained buffers + // to exercise the multi-buffer path. + void cacheRegionWithApi( + CachedBufferedInput& input, + uint64_t offset, + uint64_t length, + const char* data) { + switch (GetParam()) { + case CacheRegionApi::kStringView: { + input.cacheRegion(offset, length, std::string_view(data, length)); + break; + } + case CacheRegionApi::kIOBuf: { + // Split data into 3 chained IOBufs to exercise multi-buffer cursor + // path. Also prepend a prefix to exercise non-zero bufferOffset. + constexpr uint64_t kPrefix = 37; + const std::string prefix(kPrefix, 'Z'); + const uint64_t chunk1 = length / 3; + const uint64_t chunk2 = length / 3; + const uint64_t chunk3 = length - chunk1 - chunk2; + + auto iobuf = folly::IOBuf::copyBuffer(prefix); + iobuf->appendToChain(folly::IOBuf::copyBuffer(data, chunk1)); + iobuf->appendToChain(folly::IOBuf::copyBuffer(data + chunk1, chunk2)); + iobuf->appendToChain( + folly::IOBuf::copyBuffer(data + chunk1 + chunk2, chunk3)); + ASSERT_EQ(iobuf->computeChainDataLength(), kPrefix + length); + input.cacheRegion(offset, length, *iobuf, kPrefix); + break; + } + } + } +}; + +TEST_F(CachedBufferedInputTest, reset) { + constexpr int32_t kContentSize = 4 << 20; // 4MB + std::string content; + content.resize(kContentSize); + for (int32_t i = 0; i < kContentSize; ++i) { + content[i] = static_cast('a' + (i % 26)); + } + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setLoadQuantum(1 << 20); + + auto& ids = fileIds(); + StringIdLease fileId(ids, "testFile"); + StringIdLease groupId(ids, "testGroup"); + + CachedBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + cache_.get(), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + // First round: enqueue and load streams. + constexpr int32_t kRegionSize = 8 << 10; // 8KB + auto stream1 = input.enqueue({0, kRegionSize}, nullptr); + auto stream2 = input.enqueue({kRegionSize, kRegionSize}, nullptr); + ASSERT_NE(stream1, nullptr); + ASSERT_NE(stream2, nullptr); + + // Verify cache is empty before load. + auto stats = cache_->refreshStats(); + EXPECT_EQ(stats.numEntries, 0); + + input.load(LogType::TEST); + + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 2); + EXPECT_EQ(input.testingCoalescedLoads().size(), 1); + + // Wait until cache has two entries after load. + while (cache_->refreshStats().numEntries != 2) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + stats = cache_->refreshStats(); + EXPECT_EQ(stats.numEntries, 2); + + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 2); + EXPECT_EQ(input.testingCoalescedLoads().size(), 1); + + // Consume streams. + auto next1 = getNext(*stream1); + ASSERT_TRUE(next1.has_value()); + EXPECT_EQ(next1.value(), content.substr(0, kRegionSize)); + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 0); + EXPECT_EQ(input.testingCoalescedLoads().size(), 1); + + stats = cache_->refreshStats(); + EXPECT_EQ(stats.numEntries, 2); + + auto next2 = getNext(*stream2); + ASSERT_TRUE(next2.has_value()); + EXPECT_EQ(next2.value(), content.substr(kRegionSize, kRegionSize)); + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 0); + EXPECT_EQ(input.testingCoalescedLoads().size(), 1); + + stream1.reset(); + stream2.reset(); + + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 0); + EXPECT_EQ(input.testingCoalescedLoads().size(), 1); + + stats = cache_->refreshStats(); + EXPECT_EQ(stats.sharedPinnedBytes, 0); + EXPECT_EQ(stats.exclusivePinnedBytes, 0); + EXPECT_EQ(stats.numEntries, 2); + + // Reset the input. + input.reset(); + + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 0); + EXPECT_EQ(input.testingCoalescedLoads().size(), 0); + + stats = cache_->refreshStats(); + EXPECT_EQ(stats.sharedPinnedBytes, 0); + EXPECT_EQ(stats.exclusivePinnedBytes, 0); + EXPECT_EQ(stats.numEntries, 2); + + // Second round: enqueue different regions after reset. + auto stream3 = input.enqueue({2 * kRegionSize, kRegionSize}, nullptr); + auto stream4 = input.enqueue({3 * kRegionSize, kRegionSize}, nullptr); + ASSERT_NE(stream3, nullptr); + ASSERT_NE(stream4, nullptr); + + stats = cache_->refreshStats(); + EXPECT_EQ(stats.sharedPinnedBytes, 0); + EXPECT_EQ(stats.exclusivePinnedBytes, 0); + EXPECT_EQ(stats.numEntries, 2); + + input.load(LogType::TEST); + + // Wait until cache has two entries after load. + while (cache_->refreshStats().numEntries != 4) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + stats = cache_->refreshStats(); + EXPECT_EQ(stats.numEntries, 4); + + auto next3 = getNext(*stream3); + ASSERT_TRUE(next3.has_value()); + EXPECT_EQ(next3.value(), content.substr(2 * kRegionSize, kRegionSize)); + + auto next4 = getNext(*stream4); + ASSERT_TRUE(next4.has_value()); + EXPECT_EQ(next4.value(), content.substr(3 * kRegionSize, kRegionSize)); + + // Reset the input. + input.reset(); + + stats = cache_->refreshStats(); + EXPECT_GT(stats.sharedPinnedBytes, 0); + EXPECT_EQ(stats.exclusivePinnedBytes, 0); + EXPECT_EQ(stats.numEntries, 4); + + stream3.reset(); + stream4.reset(); + + stats = cache_->refreshStats(); + EXPECT_EQ(stats.sharedPinnedBytes, 0); + EXPECT_EQ(stats.exclusivePinnedBytes, 0); + EXPECT_EQ(stats.numEntries, 4); +} + +TEST_F(CachedBufferedInputTest, readAfterReset) { + constexpr int32_t kContentSize = 4 << 20; // 4MB + std::string content; + content.resize(kContentSize); + for (int32_t i = 0; i < kContentSize; ++i) { + content[i] = static_cast('a' + (i % 26)); + } + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setLoadQuantum(1 << 20); + + auto& ids = fileIds(); + StringIdLease fileId(ids, "testFile"); + StringIdLease groupId(ids, "testGroup"); + + CachedBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + cache_.get(), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + constexpr int32_t kRegionSize = 8 << 10; // 8KB + + // Enqueue and load streams. + auto stream1 = input.enqueue({0, kRegionSize}, nullptr); + auto stream2 = input.enqueue({kRegionSize, kRegionSize}, nullptr); + ASSERT_NE(stream1, nullptr); + ASSERT_NE(stream2, nullptr); + + input.load(LogType::TEST); + + // Wait until cache has two entries after load. + while (cache_->refreshStats().numEntries != 2) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + auto stats = cache_->refreshStats(); + EXPECT_EQ(stats.numEntries, 2); + + // Reset the input before reading from streams. + input.reset(); + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 0); + EXPECT_EQ(input.testingCoalescedLoads().size(), 0); + + stats = cache_->refreshStats(); + EXPECT_EQ(stats.numEntries, 2); + + // Read from streams after reset - data should still be available from cache. + auto next1 = getNext(*stream1); + ASSERT_TRUE(next1.has_value()); + EXPECT_EQ(next1.value(), content.substr(0, kRegionSize)); + + auto next2 = getNext(*stream2); + ASSERT_TRUE(next2.has_value()); + EXPECT_EQ(next2.value(), content.substr(kRegionSize, kRegionSize)); + + EXPECT_EQ(stats.sharedPinnedBytes, 0); + EXPECT_EQ(stats.exclusivePinnedBytes, 0); + EXPECT_EQ(stats.numEntries, 2); +} + +DEBUG_ONLY_TEST_F(CachedBufferedInputTest, resetInputWithBeforeLoading) { + constexpr int32_t kContentSize = 4 << 20; // 4MB + std::string content; + content.resize(kContentSize); + for (int32_t i = 0; i < kContentSize; ++i) { + content[i] = static_cast('a' + (i % 26)); + } + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setLoadQuantum(1 << 20); + + auto& ids = fileIds(); + StringIdLease fileId(ids, "testFile"); + StringIdLease groupId(ids, "testGroup"); + + CachedBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + cache_.get(), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + constexpr int32_t kRegionSize = 8 << 10; // 8KB + + // Block the coalesced load to verify cache references are held. + folly::Baton<> loadStarted; + folly::Baton<> loadAllowed; + + SCOPED_TESTVALUE_SET( + "facebook::velox::cache::CoalescedLoad::loadOrFuture", + std::function( + [&](const CoalescedLoad* /*load*/) { + loadStarted.post(); + loadAllowed.wait(); + })); + + // Enqueue and load streams. + auto stream1 = input.enqueue({0, kRegionSize}, nullptr); + auto stream2 = input.enqueue({kRegionSize, kRegionSize}, nullptr); + ASSERT_NE(stream1, nullptr); + ASSERT_NE(stream2, nullptr); + + input.load(LogType::TEST); + + // Wait for the load to start (but it's blocked). + loadStarted.wait(); + + // Verify cache is still empty (load is pending). + auto stats = cache_->refreshStats(); + EXPECT_EQ(stats.numEntries, 0); + EXPECT_EQ(stats.numExclusive, 0); + + // Verify coalesced load references are held. + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 2); + EXPECT_EQ(input.testingCoalescedLoads().size(), 1); + + // Reset the input while load is pending. + input.reset(); + + stats = cache_->refreshStats(); + EXPECT_EQ(stats.numEntries, 0); + EXPECT_EQ(stats.numExclusive, 0); + + // After reset, internal tracking should be cleared. + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 0); + EXPECT_EQ(input.testingCoalescedLoads().size(), 0); + + // Allow the load to proceed but cancelled. + loadAllowed.post(); + + std::this_thread::sleep_for(std::chrono::seconds(1)); + + stats = cache_->refreshStats(); + EXPECT_EQ(stats.numEntries, 0); + EXPECT_EQ(stats.numExclusive, 0); + + // Read from streams - data should be available from cache. + auto next1 = getNext(*stream1); + ASSERT_TRUE(next1.has_value()); + EXPECT_EQ(next1.value(), content.substr(0, kRegionSize)); + + auto next2 = getNext(*stream2); + ASSERT_TRUE(next2.has_value()); + EXPECT_EQ(next2.value(), content.substr(kRegionSize, kRegionSize)); + + stats = cache_->refreshStats(); + EXPECT_EQ(stats.numEntries, 2); + EXPECT_EQ(stats.numExclusive, 0); +} + +DEBUG_ONLY_TEST_F(CachedBufferedInputTest, resetInputWithAfterLoading) { + constexpr int32_t kContentSize = 4 << 20; // 4MB + std::string content; + content.resize(kContentSize); + for (int32_t i = 0; i < kContentSize; ++i) { + content[i] = static_cast('a' + (i % 26)); + } + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setLoadQuantum(1 << 20); + + auto& ids = fileIds(); + StringIdLease fileId(ids, "testFile"); + StringIdLease groupId(ids, "testGroup"); + + CachedBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + cache_.get(), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + constexpr int32_t kRegionSize = 8 << 10; // 8KB + + // Block the coalesced load to verify cache references are held. + folly::Baton<> loadStarted; + folly::Baton<> loadAllowed; + + SCOPED_TESTVALUE_SET( + "facebook::velox::cache::CoalescedLoad::loadOrFuture::loading", + std::function( + [&](const CoalescedLoad* /*load*/) { + loadStarted.post(); + loadAllowed.wait(); + })); + + // Enqueue and load streams. + auto stream1 = input.enqueue({0, kRegionSize}, nullptr); + auto stream2 = input.enqueue({kRegionSize, kRegionSize}, nullptr); + ASSERT_NE(stream1, nullptr); + ASSERT_NE(stream2, nullptr); + + input.load(LogType::TEST); + + // Wait for the load to start (but it's blocked). + loadStarted.wait(); + + // Verify cache is still empty (load is pending). + auto stats = cache_->refreshStats(); + EXPECT_EQ(stats.numEntries, 0); + EXPECT_EQ(stats.numExclusive, 0); + + // Verify coalesced load references are held. + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 2); + EXPECT_EQ(input.testingCoalescedLoads().size(), 1); + + // Reset the input while load is pending. + input.reset(); + + stats = cache_->refreshStats(); + EXPECT_EQ(stats.numEntries, 0); + EXPECT_EQ(stats.numExclusive, 0); + + // After reset, internal tracking should be cleared. + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 0); + EXPECT_EQ(input.testingCoalescedLoads().size(), 0); + + // Allow the load to proceed without cancelling. + loadAllowed.post(); + + // Wait until cache has two entries after load. + while (cache_->refreshStats().numEntries != 2) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + + stats = cache_->refreshStats(); + EXPECT_EQ(stats.numEntries, 2); + EXPECT_EQ(stats.numExclusive, 0); + + // Read from streams - data should be available from cache. + auto next1 = getNext(*stream1); + ASSERT_TRUE(next1.has_value()); + EXPECT_EQ(next1.value(), content.substr(0, kRegionSize)); + + auto next2 = getNext(*stream2); + ASSERT_TRUE(next2.has_value()); + EXPECT_EQ(next2.value(), content.substr(kRegionSize, kRegionSize)); + + stats = cache_->refreshStats(); + EXPECT_EQ(stats.numEntries, 2); + EXPECT_EQ(stats.numExclusive, 0); +} + +TEST_F(CachedBufferedInputTest, hasCache) { + constexpr int32_t kContentSize = 1 << 20; + std::string content; + content.resize(kContentSize); + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto& ids = fileIds(); + StringIdLease fileId(ids, "testFile"); + StringIdLease groupId(ids, "testGroup"); + + CachedBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + cache_.get(), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + // CachedBufferedInput always has cache. + EXPECT_TRUE(input.hasCache()); +} + +TEST_P(CacheRegionTest, cacheAndFind) { + SCOPED_TRACE(cacheRegionApiString(GetParam())); + constexpr int32_t kContentSize = 1 << 20; // 1MB + std::string content; + content.resize(kContentSize); + for (int32_t i = 0; i < kContentSize; ++i) { + content[i] = static_cast('a' + (i % 26)); + } + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto& ids = fileIds(); + StringIdLease fileId(ids, "testFile_cacheAndFind"); + StringIdLease groupId(ids, "testGroup_cacheAndFind"); + + CachedBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + cache_.get(), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + // Initially no regions are cached. + EXPECT_FALSE(input.findCachedRegion(0).has_value()); + EXPECT_FALSE(input.findCachedRegion(1000).has_value()); + + // Cache a small region (fits in tinyData path). + constexpr uint64_t kSmallSize = 100; + cacheRegionWithApi(input, 0, kSmallSize, content.data()); + + EXPECT_TRUE(input.findCachedRegion(0).has_value()); + EXPECT_FALSE(input.findCachedRegion(1000).has_value()); + + // Verify cached data is readable via read(). + auto stream = input.read(0, kSmallSize, LogType::FILE); + auto next = getNext(*stream); + ASSERT_TRUE(next.has_value()); + EXPECT_EQ(next.value(), content.substr(0, kSmallSize)); + + // Cache a larger region. + constexpr uint64_t kLargeOffset = 1000; + constexpr uint64_t kLargeSize = 64 * 1024; // 64KB + cacheRegionWithApi( + input, kLargeOffset, kLargeSize, content.data() + kLargeOffset); + + EXPECT_TRUE(input.findCachedRegion(kLargeOffset).has_value()); + + // Verify large cached data is readable. + auto stream2 = input.read(kLargeOffset, kLargeSize, LogType::FILE); + std::string readBack; + readBack.resize(kLargeSize); + stream2->readFully(readBack.data(), kLargeSize); + EXPECT_EQ(readBack, content.substr(kLargeOffset, kLargeSize)); + + // Caching the same region again is a no-op (already cached). + cacheRegionWithApi(input, 0, kSmallSize, content.data()); + EXPECT_TRUE(input.findCachedRegion(0).has_value()); + + auto stats = cache_->refreshStats(); + EXPECT_EQ(stats.numEntries, 2); +} + +TEST_P(CacheRegionTest, smallEntry) { + SCOPED_TRACE(cacheRegionApiString(GetParam())); + constexpr int32_t kContentSize = 1 << 20; + std::string content; + content.resize(kContentSize); + for (int32_t i = 0; i < kContentSize; ++i) { + content[i] = static_cast('a' + (i % 26)); + } + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto& ids = fileIds(); + StringIdLease fileId(ids, "testFile_smallEntry"); + StringIdLease groupId(ids, "testGroup_smallEntry"); + + CachedBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + cache_.get(), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + constexpr uint64_t kOffset = 0; + constexpr uint64_t kSize = 100; + static_assert(kSize < AsyncDataCacheEntry::kTinyDataSize); + + EXPECT_FALSE(input.findCachedRegion(kOffset).has_value()); + + cacheRegionWithApi(input, kOffset, kSize, content.data() + kOffset); + + // First findCachedRegion after cacheRegion should count as a cache hit since + // cacheRegion clears the first-use flag (data was populated externally, not + // loaded on-demand). + EXPECT_EQ(dataIoStats_->ramHit().count(), 0); + auto result = input.findCachedRegion(kOffset); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result->size(), kSize); + EXPECT_EQ(dataIoStats_->ramHit().count(), 1); + EXPECT_EQ(dataIoStats_->ramHit().sum(), kSize); + + // Second findCachedRegion is also a cache hit. + auto result2 = input.findCachedRegion(kOffset); + ASSERT_TRUE(result2.has_value()); + EXPECT_EQ(dataIoStats_->ramHit().count(), 2); + EXPECT_EQ(dataIoStats_->ramHit().sum(), 2 * kSize); + + // Small entry should have exactly one contiguous range. + const auto& ranges = result->ranges(); + ASSERT_EQ(ranges.size(), 1); + EXPECT_EQ(ranges[0].size(), kSize); + EXPECT_EQ( + std::string_view(ranges[0].data(), ranges[0].size()), + std::string_view(content.data() + kOffset, kSize)); + + // toIOBuf on a small (single-range) entry produces a single IOBuf. + auto iobuf = result->toIOBuf(); + EXPECT_FALSE(iobuf.isChained()); + EXPECT_EQ(iobuf.length(), kSize); + EXPECT_EQ( + std::string_view( + reinterpret_cast(iobuf.data()), iobuf.length()), + std::string_view(content.data() + kOffset, kSize)); +} + +TEST_P(CacheRegionTest, largeEntry) { + SCOPED_TRACE(cacheRegionApiString(GetParam())); + constexpr int32_t kContentSize = 1 << 20; + std::string content; + content.resize(kContentSize); + for (int32_t i = 0; i < kContentSize; ++i) { + content[i] = static_cast('a' + (i % 26)); + } + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto& ids = fileIds(); + StringIdLease fileId(ids, "testFile_largeEntry"); + StringIdLease groupId(ids, "testGroup_largeEntry"); + + CachedBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + cache_.get(), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + constexpr uint64_t kOffset = 0; + constexpr uint64_t kSize = 64 * 1024; // 64KB + static_assert(kSize >= AsyncDataCacheEntry::kTinyDataSize); + + cacheRegionWithApi(input, kOffset, kSize, content.data() + kOffset); + + auto result = input.findCachedRegion(kOffset); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result->size(), kSize); + + // Large entry may have multiple ranges. Verify total size and content. + const auto& ranges = result->ranges(); + ASSERT_GT(ranges.size(), 0); + + uint64_t totalSize = 0; + std::string reassembled; + reassembled.reserve(kSize); + for (const auto& range : ranges) { + EXPECT_GT(range.size(), 0); + reassembled.append(range.data(), range.size()); + totalSize += range.size(); + } + EXPECT_EQ(totalSize, kSize); + EXPECT_EQ(reassembled, content.substr(kOffset, kSize)); + + // toIOBuf total length matches and content matches. + auto iobuf = result->toIOBuf(); + EXPECT_EQ(iobuf.computeChainDataLength(), kSize); + + std::string fromIobuf; + fromIobuf.reserve(kSize); + const auto* current = &iobuf; + do { + fromIobuf.append( + reinterpret_cast(current->data()), current->length()); + current = current->next(); + } while (current != &iobuf); + EXPECT_EQ(fromIobuf, content.substr(kOffset, kSize)); +} + +TEST_P(CacheRegionTest, miss) { + SCOPED_TRACE(cacheRegionApiString(GetParam())); + constexpr int32_t kContentSize = 1 << 20; + std::string content(kContentSize, 'x'); + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto& ids = fileIds(); + StringIdLease fileId(ids, "testFile_miss"); + StringIdLease groupId(ids, "testGroup_miss"); + + CachedBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + cache_.get(), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + EXPECT_FALSE(input.findCachedRegion(0).has_value()); + EXPECT_FALSE(input.findCachedRegion(12345).has_value()); +} + +TEST_P(CacheRegionTest, pinKeepsDataAlive) { + SCOPED_TRACE(cacheRegionApiString(GetParam())); + constexpr int32_t kContentSize = 1 << 20; + std::string content; + content.resize(kContentSize); + for (int32_t i = 0; i < kContentSize; ++i) { + content[i] = static_cast('a' + (i % 26)); + } + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto& ids = fileIds(); + StringIdLease fileId(ids, "testFile_pinAlive"); + StringIdLease groupId(ids, "testGroup_pinAlive"); + + CachedBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + cache_.get(), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + constexpr uint64_t kSize = 100; + cacheRegionWithApi(input, 0, kSize, content.data()); + + auto cached = input.findCachedRegion(0); + ASSERT_TRUE(cached.has_value()); + + auto stats = cache_->refreshStats(); + EXPECT_GT(stats.sharedPinnedBytes, 0); + + const auto& ranges = cached->ranges(); + ASSERT_EQ(ranges.size(), 1); + EXPECT_EQ( + std::string_view(ranges[0].data(), ranges[0].size()), + std::string_view(content.data(), kSize)); + + // Move into a new optional — pin transfers, data stays valid. + auto moved = std::move(cached); + ASSERT_TRUE(moved.has_value()); + EXPECT_EQ(moved->size(), kSize); + EXPECT_EQ(moved->ranges().size(), 1); + EXPECT_EQ( + std::string_view(moved->ranges()[0].data(), moved->ranges()[0].size()), + std::string_view(content.data(), kSize)); + + stats = cache_->refreshStats(); + EXPECT_GT(stats.sharedPinnedBytes, 0); + + // Dropping the moved CachedRegion releases the pin. + moved.reset(); + stats = cache_->refreshStats(); + EXPECT_EQ(stats.sharedPinnedBytes, 0); +} + +INSTANTIATE_TEST_SUITE_P( + CacheRegionTest, + CacheRegionTest, + testing::Values(CacheRegionApi::kStringView, CacheRegionApi::kIOBuf), + [](const testing::TestParamInfo& info) { + return cacheRegionApiString(info.param); + }); + +TEST_F(CachedBufferedInputTest, findCachedRegionExclusiveWithWait) { + constexpr int32_t kContentSize = 1 << 20; + std::string content; + content.resize(kContentSize); + for (int32_t i = 0; i < kContentSize; ++i) { + content[i] = static_cast('a' + (i % 26)); + } + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto& ids = fileIds(); + const std::string fileName = "testFile_findWait"; + StringIdLease fileId(ids, fileName); + StringIdLease groupId(ids, "testGroup_findWait"); + + CachedBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + cache_.get(), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + constexpr uint64_t kOffset = 0; + constexpr uint64_t kSize = 100; + + // Create an exclusive entry using the same file ID as the + // CachedBufferedInput. + StringIdLease sameFileId(ids, fileName); + RawFileCacheKey key{sameFileId.id(), kOffset}; + auto exclusivePin = cache_->findOrCreate(key, kSize); + ASSERT_FALSE(exclusivePin.empty()); + ASSERT_TRUE(exclusivePin.entry()->isExclusive()); + + // Verify latency counters are initially zero. + EXPECT_EQ(dataIoStats_->cacheWaitLatencyUs().count(), 0); + EXPECT_EQ(dataIoStats_->queryThreadIoLatencyUs().count(), 0); + + // Spawn a thread that calls findCachedRegion — it will block waiting for + // the exclusive entry to become shared. + std::atomic_bool findFinished{false}; + std::optional findResult; + std::thread findThread([&]() { + findResult = input.findCachedRegion(kOffset); + findFinished.store(true, std::memory_order_release); + }); + + // Give the thread time to enter the wait state. + std::this_thread::sleep_for(std::chrono::milliseconds(200)); + ASSERT_FALSE(findFinished.load(std::memory_order_acquire)); + + // Populate the entry and transition to shared. + auto* entry = exclusivePin.entry(); + ASSERT_LT(kSize, AsyncDataCacheEntry::kTinyDataSize); + ::memcpy(entry->contiguousData(), content.data() + kOffset, kSize); + entry->setExclusiveToShared(); + exclusivePin.clear(); + + // The findCachedRegion thread should now complete. + findThread.join(); + ASSERT_TRUE(findFinished.load(std::memory_order_acquire)); + + // Verify the latency counters were incremented. + EXPECT_EQ(dataIoStats_->cacheWaitLatencyUs().count(), 1); + EXPECT_GT(dataIoStats_->cacheWaitLatencyUs().sum(), 0); + EXPECT_EQ(dataIoStats_->queryThreadIoLatencyUs().count(), 1); + EXPECT_GT(dataIoStats_->queryThreadIoLatencyUs().sum(), 0); + + // Verify the returned CachedRegion has correct data. + ASSERT_TRUE(findResult.has_value()); + EXPECT_EQ(findResult->size(), kSize); + const auto& ranges = findResult->ranges(); + ASSERT_EQ(ranges.size(), 1); + EXPECT_EQ( + std::string_view(ranges[0].data(), ranges[0].size()), + std::string_view(content.data() + kOffset, kSize)); +} + +TEST_F(CachedBufferedInputTest, cacheRegionSkipsOngoingInsert) { + constexpr int32_t kContentSize = 1 << 20; + std::string content; + content.resize(kContentSize); + for (int32_t i = 0; i < kContentSize; ++i) { + content[i] = static_cast('a' + (i % 26)); + } + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto& ids = fileIds(); + const std::string fileName = "testFile_skipInsert"; + StringIdLease fileId(ids, fileName); + StringIdLease groupId(ids, "testGroup_skipInsert"); + + CachedBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + cache_.get(), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + constexpr uint64_t kOffset = 0; + constexpr uint64_t kSize = 100; + + // Simulate an ongoing insert by creating an exclusive entry directly. + StringIdLease sameFileId(ids, fileName); + RawFileCacheKey key{sameFileId.id(), kOffset}; + auto exclusivePin = cache_->findOrCreate(key, kSize); + ASSERT_FALSE(exclusivePin.empty()); + ASSERT_TRUE(exclusivePin.entry()->isExclusive()); + + // cacheRegion with different data should return immediately without blocking + // or overwriting (the entry is already exclusively held by another + // operation). + const std::string differentData(kSize, 'X'); + input.cacheRegion(kOffset, kSize, std::string_view(differentData)); + + // The entry should still be exclusive — cacheRegion gave up. + ASSERT_TRUE(exclusivePin.entry()->isExclusive()); + + // Complete the original insert with the real data. + auto* entry = exclusivePin.entry(); + ASSERT_LT(kSize, AsyncDataCacheEntry::kTinyDataSize); + ::memcpy(entry->contiguousData(), content.data() + kOffset, kSize); + entry->setExclusiveToShared(); + exclusivePin.clear(); + + // findCachedRegion should return the original data, not the 'X' data + // that cacheRegion tried to write. + auto cached = input.findCachedRegion(kOffset); + ASSERT_TRUE(cached.has_value()); + EXPECT_EQ(cached->size(), kSize); + const auto& ranges = cached->ranges(); + ASSERT_EQ(ranges.size(), 1); + EXPECT_EQ( + std::string_view(ranges[0].data(), ranges[0].size()), + std::string_view(content.data() + kOffset, kSize)); + // Verify it's NOT the data passed to cacheRegion. + EXPECT_NE( + std::string_view(ranges[0].data(), ranges[0].size()), + std::string_view(differentData)); +} + +TEST_F(CachedBufferedInputTest, preloadCalledTwice) { + std::string content(1024, 'x'); + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto& ids = fileIds(); + StringIdLease fileId(ids, "preloadTwice"); + StringIdLease groupId(ids, "preloadTwiceGroup"); + + CachedBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + cache_.get(), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + input.preload(); + ASSERT_TRUE(input.preloaded()); + VELOX_ASSERT_THROW(input.preload(), "preload() called more than once"); +} + +TEST_F(CachedBufferedInputTest, isBufferedWithPreload) { + std::string content(1024, 'x'); + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto& ids = fileIds(); + StringIdLease fileId(ids, "isBufferedPreload"); + StringIdLease groupId(ids, "isBufferedPreloadGroup"); + + CachedBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + cache_.get(), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + // Before preload, isBuffered should return false. + EXPECT_FALSE(input.isBuffered(0, 100)); + + input.preload(); + ASSERT_TRUE(input.preloaded()); + + // After preload, isBuffered should return true. + EXPECT_TRUE(input.isBuffered(0, 100)); + EXPECT_TRUE(input.isBuffered(500, 200)); +} + +TEST_F(CachedBufferedInputTest, enqueueSkipsRequestsWhenPreloaded) { + std::string content(1024, 'x'); + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto& ids = fileIds(); + StringIdLease fileId(ids, "enqueueSkipsRequests"); + StringIdLease groupId(ids, "enqueueSkipsRequestsGroup"); + + CachedBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + cache_.get(), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + ASSERT_EQ(readFile->numReads(), 0); + + input.preload(); + ASSERT_TRUE(input.preloaded()); + ASSERT_EQ(readFile->numReads(), 1); + + // After preload, enqueue should not add to requests or coalesced loads. + auto stream = input.enqueue(common::Region{0, 100}, nullptr); + ASSERT_NE(stream, nullptr); + + // Verify no coalesced loads were created (requests are skipped when + // preloaded). + EXPECT_EQ(input.testingCoalescedLoads().size(), 0); + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 0); + + // Stream should still be able to read data from preloaded content. + auto next = getNext(*stream); + ASSERT_TRUE(next.has_value()); + EXPECT_EQ(next.value(), content.substr(0, 100)); + + // No additional file reads - all served from preloaded cache. + ASSERT_EQ(readFile->numReads(), 1); +} + +TEST_F(CachedBufferedInputTest, readSetsPreloadedPinWhenPreloaded) { + constexpr int32_t kContentSize = 1 << 20; + std::string content; + content.resize(kContentSize); + for (int32_t i = 0; i < kContentSize; ++i) { + content[i] = static_cast('a' + (i % 26)); + } + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setLoadQuantum(1 << 20); + + auto& ids = fileIds(); + StringIdLease fileId(ids, "readSetsPreloadedPin"); + StringIdLease groupId(ids, "readSetsPreloadedPinGroup"); + + CachedBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + cache_.get(), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + ASSERT_EQ(readFile->numReads(), 0); + + input.preload(); + ASSERT_TRUE(input.preloaded()); + ASSERT_EQ(readFile->numReads(), 1); + + // Use read() to create a stream (not enqueue). + auto stream = input.read(0, 100, LogType::TEST); + ASSERT_NE(stream, nullptr); + + // Stream should be able to read data from preloaded content. + auto next = getNext(*stream); + ASSERT_TRUE(next.has_value()); + EXPECT_EQ(next.value(), content.substr(0, 100)); + + // No additional file reads - all served from preloaded cache. + ASSERT_EQ(readFile->numReads(), 1); +} + +TEST_F(CachedBufferedInputTest, preloadAfterEnqueue) { + std::string content(1024, 'x'); + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto& ids = fileIds(); + StringIdLease fileId(ids, "preloadAfterEnqueue"); + StringIdLease groupId(ids, "preloadAfterEnqueueGroup"); + + CachedBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + cache_.get(), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + input.enqueue({0, 100}, nullptr); + VELOX_ASSERT_THROW( + input.preload(), "preload() must be called before enqueue()"); +} + +TEST_F(CachedBufferedInputTest, preloadedStreamSkipsEviction) { + // When cacheable_=false, non-preloaded streams mark cache entries as + // immediately evictable on destruction (via clearCachePin and + // makeCacheEvictable). Preloaded streams must skip both because the + // preloaded cache entry is shared across all streams. + constexpr int32_t kContentSize = 1 << 20; + std::string content; + content.resize(kContentSize); + for (int32_t i = 0; i < kContentSize; ++i) { + content[i] = static_cast('a' + (i % 26)); + } + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setCacheable(false); + readerOptions.setLoadQuantum(1 << 20); + + auto& ids = fileIds(); + StringIdLease fileId(ids, "preloadEvictionTest"); + StringIdLease groupId(ids, "preloadEvictionGroup"); + + CachedBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + cache_.get(), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + ASSERT_EQ(readFile->numReads(), 0); + + input.preload(); + ASSERT_TRUE(input.preloaded()); + ASSERT_EQ(readFile->numReads(), 1); + + auto statsBefore = cache_->refreshStats(); + ASSERT_GT(statsBefore.sharedPinnedBytes, 0); + + // Create a stream, read from it, and destroy it. + { + auto stream = input.enqueue(common::Region{0, 100}, nullptr); + auto next = getNext(*stream); + ASSERT_TRUE(next.has_value()); + EXPECT_EQ(next.value(), content.substr(0, 100)); + } + + // The preloaded cache entry must still be pinned and accessible after stream + // destruction, even with cacheable_=false. + auto statsAfter = cache_->refreshStats(); + EXPECT_GT(statsAfter.sharedPinnedBytes, 0); + + // A new stream should still be able to read from the preloaded entry. + auto stream2 = input.enqueue(common::Region{0, 100}, nullptr); + auto next2 = getNext(*stream2); + ASSERT_TRUE(next2.has_value()); + EXPECT_EQ(next2.value(), content.substr(0, 100)); + + // No additional file reads — all served from preloaded cache. + ASSERT_EQ(readFile->numReads(), 1); +} + +TEST_F(CachedBufferedInputTest, preloadRespectsNotCacheable) { + // When cacheable_=false, the preloaded cache entry should be made evictable + // (lastUse=0, numUses=0) when CachedBufferedInput is destroyed. + constexpr int32_t kContentSize = 1 << 20; + std::string content; + content.resize(kContentSize); + for (int32_t i = 0; i < kContentSize; ++i) { + content[i] = static_cast('a' + (i % 26)); + } + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setCacheable(false); + readerOptions.setLoadQuantum(1 << 20); + + auto& ids = fileIds(); + StringIdLease fileId(ids, "preloadNotCacheableTest"); + StringIdLease groupId(ids, "preloadNotCacheableGroup"); + const auto fileNumId = fileId.id(); + + { + CachedBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + cache_.get(), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + input.preload(); + ASSERT_TRUE(input.preloaded()); + + auto stats = cache_->refreshStats(); + ASSERT_GT(stats.sharedPinnedBytes, 0); + } + + // After CachedBufferedInput destruction with cacheable_=false, the preloaded + // entry should be marked as immediately evictable via makeEvictable(). + cache::RawFileCacheKey key{fileNumId, 0}; + EXPECT_TRUE(cache_->testingIsEvictable(key)); +} + +TEST_F(CachedBufferedInputTest, preloadRespectsCacheable) { + // When cacheable_=true (default), the preloaded cache entry should not be + // marked as immediately evictable after CachedBufferedInput is destroyed. + constexpr int32_t kContentSize = 1 << 20; + std::string content; + content.resize(kContentSize); + for (int32_t i = 0; i < kContentSize; ++i) { + content[i] = static_cast('a' + (i % 26)); + } + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setCacheable(true); + readerOptions.setLoadQuantum(1 << 20); + + auto& ids = fileIds(); + StringIdLease fileId(ids, "preloadCacheableTest"); + StringIdLease groupId(ids, "preloadCacheableGroup"); + const auto fileNumId = fileId.id(); + + { + CachedBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + cache_.get(), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + input.preload(); + ASSERT_TRUE(input.preloaded()); + } + + // After CachedBufferedInput destruction with cacheable_=true, the cache + // entry should NOT be marked as immediately evictable. + cache::RawFileCacheKey key{fileNumId, 0}; + EXPECT_FALSE(cache_->testingIsEvictable(key)); +} + +TEST_F(CachedBufferedInputTest, preload) { + struct TestParam { + uint64_t fileSize; + std::string debugString() const { + return fmt::format("fileSize {}", fileSize); + } + }; + std::vector testSettings = { + // Small file (tinyData path in cache entry). + {100}, + // File at tinyData boundary. + {AsyncDataCacheEntry::kTinyDataSize}, + // Larger file (allocation path). + {1 << 20}, + }; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + std::string content; + content.resize(testData.fileSize); + for (uint64_t i = 0; i < testData.fileSize; ++i) { + content[i] = static_cast('a' + (i % 26)); + } + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setLoadQuantum(1 << 20); + + auto& ids = fileIds(); + StringIdLease fileId(ids, fmt::format("preloadTest_{}", testData.fileSize)); + StringIdLease groupId( + ids, fmt::format("preloadGroup_{}", testData.fileSize)); + + CachedBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + cache_.get(), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + ASSERT_EQ(readFile->numReads(), 0); + EXPECT_FALSE(input.preloaded()); + + auto statsBefore = cache_->refreshStats(); + const auto readCountBefore = dataIoStats_->read().count(); + const auto readSumBefore = dataIoStats_->read().sum(); + const auto rawBytesBefore = dataIoStats_->rawBytesRead(); + + input.preload(); + + EXPECT_TRUE(input.preloaded()); + + const auto readsAfterPreload = readFile->numReads(); + ASSERT_GT(readsAfterPreload, 0); + + // Cache should have one new entry for the whole file. + auto statsAfter = cache_->refreshStats(); + EXPECT_EQ(statsAfter.numEntries, statsBefore.numEntries + 1); + // preloadPin_ holds a shared pin. + EXPECT_GT(statsAfter.sharedPinnedBytes, 0); + // IO stats: one storage read of fileSize bytes. + EXPECT_EQ(dataIoStats_->read().count(), readCountBefore + 1); + EXPECT_EQ(dataIoStats_->read().sum(), readSumBefore + testData.fileSize); + EXPECT_EQ(dataIoStats_->rawBytesRead(), rawBytesBefore + testData.fileSize); + + // Enqueue sub-region streams and read from the preloaded cache entry. + const uint64_t regionSize = + std::min(testData.fileSize / 2, testData.fileSize); + auto stream1 = input.enqueue(common::Region{0, regionSize}, nullptr); + ASSERT_NE(stream1, nullptr); + + auto next1 = getNext(*stream1); + ASSERT_TRUE(next1.has_value()); + EXPECT_EQ(next1.value(), content.substr(0, regionSize)); + + if (testData.fileSize > regionSize) { + auto stream2 = input.enqueue( + common::Region{regionSize, testData.fileSize - regionSize}, nullptr); + ASSERT_NE(stream2, nullptr); + + auto next2 = getNext(*stream2); + ASSERT_TRUE(next2.has_value()); + EXPECT_EQ( + next2.value(), + content.substr(regionSize, testData.fileSize - regionSize)); + } + + // No additional file reads after preload. + ASSERT_EQ(readFile->numReads(), readsAfterPreload); + } +} + +TEST_F(CachedBufferedInputTest, preloadCacheSharing) { + constexpr int32_t kContentSize = 1 << 20; // 1MB + std::string content; + content.resize(kContentSize); + for (int32_t i = 0; i < kContentSize; ++i) { + content[i] = static_cast('a' + (i % 26)); + } + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setLoadQuantum(1 << 20); + + auto& ids = fileIds(); + const std::string fileName = "preloadSharingTest"; + StringIdLease fileId1(ids, fileName); + StringIdLease groupId1(ids, "preloadSharingGroup1"); + + CachedBufferedInput input1( + readFile, + MetricsLog::voidLog(), + std::move(fileId1), + cache_.get(), + tracker_, + std::move(groupId1), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + ASSERT_EQ(readFile->numReads(), 0); + + // First preload reads from storage. + input1.preload(); + EXPECT_TRUE(input1.preloaded()); + ASSERT_EQ(readFile->numReads(), 1); + + auto stats1 = cache_->refreshStats(); + auto readCount1 = dataIoStats_->read().count(); + + // Second CachedBufferedInput with the same file should hit the cache. + StringIdLease fileId2(ids, fileName); + StringIdLease groupId2(ids, "preloadSharingGroup2"); + + CachedBufferedInput input2( + readFile, + MetricsLog::voidLog(), + std::move(fileId2), + cache_.get(), + tracker_, + std::move(groupId2), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + input2.preload(); + EXPECT_TRUE(input2.preloaded()); + + // No additional file read — second preload hits the cache. + ASSERT_EQ(readFile->numReads(), 1); + + // No additional cache entry should be created. + auto stats2 = cache_->refreshStats(); + EXPECT_EQ(stats2.numEntries, stats1.numEntries); + + // No additional storage read should happen (cache hit). + EXPECT_EQ(dataIoStats_->read().count(), readCount1); + + // Both inputs should still be able to read data. + auto stream1 = input1.enqueue(common::Region{0, 100}, nullptr); + auto next1 = getNext(*stream1); + ASSERT_TRUE(next1.has_value()); + EXPECT_EQ(next1.value(), content.substr(0, 100)); + + auto stream2 = input2.enqueue(common::Region{0, 100}, nullptr); + auto next2 = getNext(*stream2); + ASSERT_TRUE(next2.has_value()); + EXPECT_EQ(next2.value(), content.substr(0, 100)); + + // No additional file reads — all served from cache. + ASSERT_EQ(readFile->numReads(), 1); +} + +TEST_F(CachedBufferedInputTest, prefetchScope) { + constexpr int32_t kContentSize = 32 << 20; + std::string content; + content.resize(kContentSize); + for (int32_t i = 0; i < kContentSize; ++i) { + content[i] = static_cast('a' + (i % 26)); + } + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + + auto& ids = fileIds(); + StringIdLease fileId(ids, "testFile"); + StringIdLease groupId(ids, "testGroup"); + + CachedBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + cache_.get(), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + // Enqueue non-prefetch requests (with stream IDs). + constexpr int32_t kNumRequests = 3; + constexpr uint64_t kRequestSize = 500; + std::vector> streamIds; + for (int32_t i = 0; i < kNumRequests; ++i) { + streamIds.push_back(std::make_unique(i)); + } + for (int32_t i = 0; i < kNumRequests; ++i) { + input.enqueue( + {static_cast(i * 1000), kRequestSize}, streamIds[i].get()); + } + input.load(LogType::TEST); + + // There should be exactly 1 non-prefetch CoalescedLoad in kPlanned state. + const auto& loadsBeforePrefetch = input.testingCoalescedLoads(); + ASSERT_EQ(loadsBeforePrefetch.size(), 1); + ASSERT_EQ(loadsBeforePrefetch[0]->state(), CoalescedLoad::State::kPlanned); + + // Enqueue prefetch requests (without stream IDs). + constexpr int32_t kMB = 1 << 20; + for (int32_t i = 0; i < kNumRequests; ++i) { + input.enqueue( + {static_cast(10 * kMB + i * 1000), kRequestSize}, nullptr); + } + input.load(LogType::TEST); + + // Wait for executor to complete all submitted tasks. + executor_->join(); + + const auto& loadsAfterPrefetch = input.testingCoalescedLoads(); + ASSERT_EQ(loadsAfterPrefetch.size(), 2); + EXPECT_EQ(loadsAfterPrefetch[0]->state(), CoalescedLoad::State::kPlanned) + << "Non-prefetch load should NOT be submitted to executor"; + EXPECT_EQ(loadsAfterPrefetch[1]->state(), CoalescedLoad::State::kLoaded) + << "Prefetch load should be submitted and completed by executor"; + + EXPECT_EQ(dataIoStats_->prefetch().sum(), kNumRequests * kRequestSize); +} + +TEST_F(CachedBufferedInputTest, cloneNonCacheable) { + constexpr int32_t kContentSize = 1 << 20; // 1MB + std::string content(kContentSize, 'x'); + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setLoadQuantum(1 << 20); + + auto& ids = fileIds(); + StringIdLease fileId(ids, "testFile"); + StringIdLease groupId(ids, "testGroup"); + + const auto fileNum = fileId.id(); + CachedBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + cache_.get(), + tracker_, + std::move(groupId), + dataIoStats_, + /*ioStats=*/nullptr, + executor_.get(), + readerOptions); + + ASSERT_TRUE(input.hasCache()); + + auto nonCacheable = input.cloneNonCacheable(); + ASSERT_TRUE(nonCacheable->hasCache()); + + // Read through the non-cacheable clone. + constexpr uint64_t kOffset = 0; + constexpr uint64_t kLength = 1'024; + StreamIdentifier sid(0); + auto stream = nonCacheable->enqueue({kOffset, kLength}, &sid); + nonCacheable->load(LogType::FILE); + + auto result = getNext(*stream); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result->size(), kLength); + + // The cache should have an entry from the non-cacheable read. + auto statsAfterLoad = cache_->refreshStats(); + EXPECT_GT(statsAfterLoad.numEntries, 0); + + // The entry should not be evictable while the non-cacheable clone is alive. + const cache::RawFileCacheKey cacheKey{fileNum, kOffset}; + EXPECT_FALSE(cache_->testingIsEvictable(cacheKey)); + + // Release the stream and destroy the non-cacheable input. Because + // cacheable=false, the destructor marks preloaded entries as immediately + // evictable (lastUse=0, numUses=0). + stream.reset(); + nonCacheable.reset(); + + // The entry is still in cache but now marked as immediately evictable. + auto statsAfterDestroy = cache_->refreshStats(); + EXPECT_EQ(statsAfterDestroy.numEntries, statsAfterLoad.numEntries); + EXPECT_TRUE(cache_->testingIsEvictable(cacheKey)); + + // Read the same region through a cacheable input — should still work. + StringIdLease fileId2(ids, "testFile2"); + StringIdLease groupId2(ids, "testGroup2"); + const auto fileNum2 = fileId2.id(); + CachedBufferedInput cacheableInput( + readFile, + MetricsLog::voidLog(), + std::move(fileId2), + cache_.get(), + tracker_, + std::move(groupId2), + dataIoStats_, + /*ioStats=*/nullptr, + executor_.get(), + readerOptions); + + auto stream2 = cacheableInput.enqueue({kOffset, kLength}, &sid); + cacheableInput.load(LogType::FILE); + auto result2 = getNext(*stream2); + ASSERT_TRUE(result2.has_value()); + EXPECT_EQ(result2->size(), kLength); + + // The cacheable input's entry should persist and not be marked evictable. + stream2.reset(); + const cache::RawFileCacheKey cacheKey2{fileNum2, kOffset}; + auto statsAfterCacheable = cache_->refreshStats(); + EXPECT_GT(statsAfterCacheable.numEntries, 0); + EXPECT_FALSE(cache_->testingIsEvictable(cacheKey2)); +} + +} // namespace diff --git a/velox/dwio/common/tests/ChainedBufferTests.cpp b/velox/dwio/common/tests/ChainedBufferTests.cpp index 0612f4341ff..43820bc0c77 100644 --- a/velox/dwio/common/tests/ChainedBufferTests.cpp +++ b/velox/dwio/common/tests/ChainedBufferTests.cpp @@ -23,10 +23,7 @@ using namespace ::testing; -namespace facebook { -namespace velox { -namespace dwio { -namespace common { +namespace facebook::velox::dwio::common { class ChainedBufferTests : public Test { protected: @@ -64,8 +61,9 @@ TEST_F(ChainedBufferTests, testCreate) { TEST_F(ChainedBufferTests, testReserve) { for (const uint32_t initialCapacityBytes : {0, 16}) { - SCOPED_TRACE(fmt::format( - "initialCapacityBytes ", succinctBytes(initialCapacityBytes))); + SCOPED_TRACE( + fmt::format( + "initialCapacityBytes ", succinctBytes(initialCapacityBytes))); ChainedBuffer buf{*pool_, initialCapacityBytes, 1024}; ASSERT_EQ(buf.capacity(), initialCapacityBytes); ASSERT_EQ(buf.size(), 0); @@ -252,8 +250,9 @@ TEST_F(ChainedBufferTests, testTrailingZeros) { TEST_F(ChainedBufferTests, testClearAll) { for (const uint32_t initialCapacityBytes : {0, 128}) { - SCOPED_TRACE(fmt::format( - "initialCapacityBytes ", succinctBytes(initialCapacityBytes))); + SCOPED_TRACE( + fmt::format( + "initialCapacityBytes ", succinctBytes(initialCapacityBytes))); ChainedBuffer buf{*pool_, initialCapacityBytes, 1024}; ASSERT_EQ(buf.capacity(), initialCapacityBytes); ASSERT_EQ(buf.size(), 0); @@ -321,7 +320,5 @@ TEST_F(ChainedBufferTests, testClearAll) { ASSERT_EQ(buf.pages_.size(), 9); } } -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook + +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/tests/ColumnReaderStatisticsTests.cpp b/velox/dwio/common/tests/ColumnReaderStatisticsTests.cpp new file mode 100644 index 00000000000..3e71fbcdf07 --- /dev/null +++ b/velox/dwio/common/tests/ColumnReaderStatisticsTests.cpp @@ -0,0 +1,403 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "velox/dwio/common/Statistics.h" +#include "velox/type/Type.h" + +using namespace facebook::velox::dwio::common; +using facebook::velox::RuntimeMetric; +using facebook::velox::TypeKind; + +TEST(IoCounterTest, BasicOperations) { + facebook::velox::io::IoCounter counter; + + EXPECT_EQ(counter.sum(), 0); + EXPECT_EQ(counter.count(), 0); + + counter.increment(5'000); + counter.increment(3'000); + + EXPECT_EQ(counter.sum(), 8'000); + EXPECT_EQ(counter.count(), 2); + EXPECT_EQ(counter.min(), 3'000); + EXPECT_EQ(counter.max(), 5'000); +} + +TEST(IoCounterTest, ConcurrentAccess) { + facebook::velox::io::IoCounter counter; + constexpr int kNumThreads = 4; + constexpr int kIterationsPerThread = 1'000; + + std::vector threads; + threads.reserve(kNumThreads); + for (int i = 0; i < kNumThreads; ++i) { + threads.emplace_back([&counter]() { + for (int j = 0; j < kIterationsPerThread; ++j) { + counter.increment(5); + } + }); + } + for (auto& t : threads) { + t.join(); + } + + EXPECT_EQ(counter.sum(), kNumThreads * kIterationsPerThread * 5); + EXPECT_EQ(counter.count(), kNumThreads * kIterationsPerThread); +} + +TEST(ColumnMetricsSetTest, GetOrCreate) { + ColumnMetricsSet metricsSet; + + auto* result = metricsSet.getOrCreate(1); + ASSERT_NE(result, nullptr); + + // Returns same instance for same nodeId. + EXPECT_EQ(metricsSet.getOrCreate(1), result); + + // Returns different instance for different nodeId. + auto* result2 = metricsSet.getOrCreate(2); + EXPECT_NE(result2, result); +} + +TEST(ColumnMetricsSetTest, GetOrCreateWithTypeKind) { + ColumnMetricsSet metricsSet; + + // Pass type when calling getOrCreate. + auto* result = metricsSet.getOrCreate(1, TypeKind::BIGINT); + ASSERT_NE(result, nullptr); + result->decompressCPUTimeNanos.increment(1'000); + + // Returns same instance for same nodeId. + auto* result2 = metricsSet.getOrCreate(1); + EXPECT_EQ(result2, result); + + // Different nodeId with different type. + auto* result3 = metricsSet.getOrCreate(2, TypeKind::VARCHAR); + EXPECT_NE(result3, result); + result3->decompressCPUTimeNanos.increment(2'000); + + // Verify types are used in toRuntimeMetrics. + std::unordered_map metrics; + metricsSet.toRuntimeMetrics(metrics); + EXPECT_EQ(metrics["column_1.BIGINT.decompressCPUTimeNanos"].sum, 1'000); + EXPECT_EQ(metrics["column_2.VARCHAR.decompressCPUTimeNanos"].sum, 2'000); +} + +TEST(ColumnMetricsSetTest, ToRuntimeMetrics) { + ColumnMetricsSet metricsSet; + + // Empty stats produces empty result. + std::unordered_map result; + metricsSet.toRuntimeMetrics(result); + EXPECT_TRUE(result.empty()); + + // Add timing data with type information. + auto* col1 = metricsSet.getOrCreate(1, TypeKind::BIGINT); + col1->decompressCPUTimeNanos.increment(5'000); + col1->decompressCPUTimeNanos.increment(3'000); + + auto* col2 = metricsSet.getOrCreate(42, TypeKind::VARCHAR); + col2->decompressCPUTimeNanos.increment(2'000); + + // Create a column with type but no data. + metricsSet.getOrCreate(99, TypeKind::DOUBLE); + + result.clear(); + metricsSet.toRuntimeMetrics(result); + + // RuntimeMetric has sum/count/min/max, metric name includes type. + EXPECT_EQ(result["column_1.BIGINT.decompressCPUTimeNanos"].sum, 8'000); + EXPECT_EQ(result["column_1.BIGINT.decompressCPUTimeNanos"].count, 2); + EXPECT_EQ(result["column_1.BIGINT.decompressCPUTimeNanos"].min, 3'000); + EXPECT_EQ(result["column_1.BIGINT.decompressCPUTimeNanos"].max, 5'000); + EXPECT_EQ(result["column_42.VARCHAR.decompressCPUTimeNanos"].sum, 2'000); + EXPECT_EQ(result["column_42.VARCHAR.decompressCPUTimeNanos"].count, 1); + + // Zero values are not included. + metricsSet.getOrCreate(99); + result.clear(); + metricsSet.toRuntimeMetrics(result); + EXPECT_EQ(result.count("column_99.DOUBLE.decompressCPUTimeNanos"), 0); +} + +TEST(ColumnMetricsSetTest, ToRuntimeMetricsWithInvalidType) { + ColumnMetricsSet metricsSet; + + // Add timing data without type information (INVALID type). + auto* col1 = metricsSet.getOrCreate(1); + col1->decompressCPUTimeNanos.increment(5'000); + + std::unordered_map result; + metricsSet.toRuntimeMetrics(result); + + // Should use INVALID as type name. + EXPECT_EQ(result["column_1.INVALID.decompressCPUTimeNanos"].sum, 5'000); +} + +TEST(ColumnMetricsSetTest, ToRuntimeMetricsWithDecodeTime) { + ColumnMetricsSet metricsSet; + + // Add both decompress and decode timing data. + auto* col1 = metricsSet.getOrCreate(1, TypeKind::BIGINT); + col1->decompressCPUTimeNanos.increment(5'000); + col1->decodeCPUTimeNanos.increment(10'000); + col1->decodeCPUTimeNanos.increment(8'000); + + auto* col2 = metricsSet.getOrCreate(2, TypeKind::VARCHAR); + col2->decodeCPUTimeNanos.increment(3'000); + + std::unordered_map result; + metricsSet.toRuntimeMetrics(result); + + // Column 1 has both decompress and decode metrics. + EXPECT_EQ(result["column_1.BIGINT.decompressCPUTimeNanos"].sum, 5'000); + EXPECT_EQ(result["column_1.BIGINT.decodeCPUTimeNanos"].sum, 18'000); + EXPECT_EQ(result["column_1.BIGINT.decodeCPUTimeNanos"].count, 2); + EXPECT_EQ(result["column_1.BIGINT.decodeCPUTimeNanos"].min, 8'000); + EXPECT_EQ(result["column_1.BIGINT.decodeCPUTimeNanos"].max, 10'000); + + // Column 2 has only decode metrics. + EXPECT_EQ(result.count("column_2.VARCHAR.decompressCPUTimeNanos"), 0); + EXPECT_EQ(result["column_2.VARCHAR.decodeCPUTimeNanos"].sum, 3'000); + EXPECT_EQ(result["column_2.VARCHAR.decodeCPUTimeNanos"].count, 1); +} + +TEST(RuntimeStatisticsTest, ToRuntimeMetricMap) { + RuntimeStatistics stats; + + // Empty stats produces empty result. + EXPECT_TRUE(stats.toRuntimeMetricMap().empty()); + + // Set various stats. + stats.skippedSplits = 5; + stats.processedSplits = 15; + stats.skippedStrides = 10; + stats.processedStrides = 30; + stats.numStripes = 4; + stats.columnReaderStats.flattenStringDictionaryValues = 1'000; + + // Add per-column stats with type. + stats.columnReaderStats.columnMetricsSet.emplace(); + auto* colMetrics = stats.columnReaderStats.columnMetricsSet->getOrCreate( + 1, TypeKind::BIGINT); + colMetrics->decompressCPUTimeNanos.increment(5'000); + colMetrics->decodeCPUTimeNanos.increment(12'000); + + auto result = stats.toRuntimeMetricMap(); + + EXPECT_EQ(result["skippedSplits"].sum, 5); + EXPECT_EQ(result["processedSplits"].sum, 15); + EXPECT_EQ(result["skippedStrides"].sum, 10); + EXPECT_EQ(result["processedStrides"].sum, 30); + EXPECT_EQ(result["numStripes"].sum, 4); + EXPECT_EQ(result["flattenStringDictionaryValues"].sum, 1'000); + EXPECT_EQ(result["column_1.BIGINT.decompressCPUTimeNanos"].sum, 5'000); + EXPECT_EQ(result["column_1.BIGINT.decompressCPUTimeNanos"].count, 1); + EXPECT_EQ(result["column_1.BIGINT.decodeCPUTimeNanos"].sum, 12'000); + EXPECT_EQ(result["column_1.BIGINT.decodeCPUTimeNanos"].count, 1); +} + +TEST(ColumnMetricsSetConcurrencyTest, ConcurrentGetOrCreate) { + ColumnMetricsSet metricsSet; + constexpr int kNumThreads = 4; + constexpr int kNumColumns = 10; + + // Pre-populate columns with types before concurrent access. + for (uint32_t colId = 0; colId < kNumColumns; ++colId) { + metricsSet.getOrCreate(colId, TypeKind::BIGINT); + } + + std::vector threads; + threads.reserve(kNumThreads); + for (int i = 0; i < kNumThreads; ++i) { + threads.emplace_back([&metricsSet]() { + for (uint32_t colId = 0; colId < kNumColumns; ++colId) { + auto* colMetrics = metricsSet.getOrCreate(colId); + colMetrics->decompressCPUTimeNanos.increment(100); + } + }); + } + for (auto& t : threads) { + t.join(); + } + + std::unordered_map result; + metricsSet.toRuntimeMetrics(result); + + for (uint32_t colId = 0; colId < kNumColumns; ++colId) { + auto key = fmt::format("column_{}.BIGINT.decompressCPUTimeNanos", colId); + EXPECT_EQ(result[key].sum, kNumThreads * 100); + EXPECT_EQ(result[key].count, kNumThreads); + } +} + +TEST(IoCounterTest, MergeStats) { + facebook::velox::io::IoCounter counter1; + counter1.increment(5'000); + counter1.increment(3'000); + + facebook::velox::io::IoCounter counter2; + counter2.increment(2'000); + + counter1.merge(counter2); + + EXPECT_EQ(counter1.sum(), 10'000); + EXPECT_EQ(counter1.count(), 3); +} + +TEST(ColumnMetricsSetTest, MergeFromWithOverlappingNodeIds) { + ColumnMetricsSet src; + auto* srcCol1 = src.getOrCreate(1, TypeKind::BIGINT); + srcCol1->decompressCPUTimeNanos.increment(5'000); + srcCol1->decompressCPUTimeNanos.increment(3'000); + srcCol1->decodeCPUTimeNanos.increment(10'000); + + auto* srcCol2 = src.getOrCreate(2, TypeKind::VARCHAR); + srcCol2->decompressCPUTimeNanos.increment(2'000); + srcCol2->decodeCPUTimeNanos.increment(4'000); + + ColumnMetricsSet dst; + auto* dstCol1 = dst.getOrCreate(1, TypeKind::BIGINT); + dstCol1->decompressCPUTimeNanos.increment(1'000); + dstCol1->decodeCPUTimeNanos.increment(6'000); + + dst.mergeFrom(src); + + std::unordered_map result; + dst.toRuntimeMetrics(result); + + EXPECT_EQ(result["column_1.BIGINT.decompressCPUTimeNanos"].sum, 9'000); + EXPECT_EQ(result["column_1.BIGINT.decompressCPUTimeNanos"].count, 3); + EXPECT_EQ(result["column_1.BIGINT.decodeCPUTimeNanos"].sum, 16'000); + EXPECT_EQ(result["column_1.BIGINT.decodeCPUTimeNanos"].count, 2); + EXPECT_EQ(result["column_2.VARCHAR.decompressCPUTimeNanos"].sum, 2'000); + EXPECT_EQ(result["column_2.VARCHAR.decompressCPUTimeNanos"].count, 1); + EXPECT_EQ(result["column_2.VARCHAR.decodeCPUTimeNanos"].sum, 4'000); + EXPECT_EQ(result["column_2.VARCHAR.decodeCPUTimeNanos"].count, 1); +} + +TEST(ColumnMetricsSetTest, MergeFromWithDisjointNodeIds) { + ColumnMetricsSet src; + auto* srcCol3 = src.getOrCreate(3, TypeKind::DOUBLE); + srcCol3->decompressCPUTimeNanos.increment(3'000); + + auto* srcCol4 = src.getOrCreate(4, TypeKind::BOOLEAN); + srcCol4->decompressCPUTimeNanos.increment(4'000); + + ColumnMetricsSet dst; + auto* dstCol1 = dst.getOrCreate(1, TypeKind::BIGINT); + dstCol1->decompressCPUTimeNanos.increment(1'000); + + auto* dstCol2 = dst.getOrCreate(2, TypeKind::VARCHAR); + dstCol2->decompressCPUTimeNanos.increment(2'000); + + dst.mergeFrom(src); + + std::unordered_map result; + dst.toRuntimeMetrics(result); + + EXPECT_EQ(result["column_1.BIGINT.decompressCPUTimeNanos"].sum, 1'000); + EXPECT_EQ(result["column_2.VARCHAR.decompressCPUTimeNanos"].sum, 2'000); + EXPECT_EQ(result["column_3.DOUBLE.decompressCPUTimeNanos"].sum, 3'000); + EXPECT_EQ(result["column_4.BOOLEAN.decompressCPUTimeNanos"].sum, 4'000); +} + +TEST(ColumnMetricsSetTest, MergeFromEmpty) { + ColumnMetricsSet nonEmpty; + auto* col = nonEmpty.getOrCreate(1, TypeKind::BIGINT); + col->decompressCPUTimeNanos.increment(5'000); + + ColumnMetricsSet empty; + + nonEmpty.mergeFrom(empty); + + std::unordered_map result; + nonEmpty.toRuntimeMetrics(result); + EXPECT_EQ(result["column_1.BIGINT.decompressCPUTimeNanos"].sum, 5'000); + + ColumnMetricsSet empty2; + empty2.mergeFrom(nonEmpty); + + result.clear(); + empty2.toRuntimeMetrics(result); + EXPECT_EQ(result["column_1.BIGINT.decompressCPUTimeNanos"].sum, 5'000); +} + +TEST(ColumnReaderStatisticsTest, MergeFromWithColumnMetrics) { + ColumnReaderStatistics src; + src.flattenStringDictionaryValues = 100; + src.columnMetricsSet.emplace(); + src.columnMetricsSet->getOrCreate(1, TypeKind::BIGINT) + ->decompressCPUTimeNanos.increment(1'000); + + // Merge into stats without columnMetricsSet - creates and populates it. + ColumnReaderStatistics dst; + dst.flattenStringDictionaryValues = 50; + dst.mergeFrom(src); + + EXPECT_EQ(dst.flattenStringDictionaryValues, 150); + ASSERT_TRUE(dst.columnMetricsSet.has_value()); + + std::unordered_map result; + dst.columnMetricsSet->toRuntimeMetrics(result); + EXPECT_EQ(result["column_1.BIGINT.decompressCPUTimeNanos"].sum, 1'000); +} + +TEST(ColumnReaderStatisticsTest, MergeFromBothWithColumnMetrics) { + ColumnReaderStatistics src; + src.flattenStringDictionaryValues = 100; + src.columnMetricsSet.emplace(); + src.columnMetricsSet->getOrCreate(1, TypeKind::BIGINT) + ->decompressCPUTimeNanos.increment(1'000); + + ColumnReaderStatistics dst; + dst.flattenStringDictionaryValues = 50; + dst.columnMetricsSet.emplace(); + dst.columnMetricsSet->getOrCreate(1, TypeKind::BIGINT) + ->decompressCPUTimeNanos.increment(2'000); + + dst.mergeFrom(src); + + EXPECT_EQ(dst.flattenStringDictionaryValues, 150); + ASSERT_TRUE(dst.columnMetricsSet.has_value()); + + std::unordered_map result; + dst.columnMetricsSet->toRuntimeMetrics(result); + EXPECT_EQ(result["column_1.BIGINT.decompressCPUTimeNanos"].sum, 3'000); +} + +TEST(ColumnReaderStatisticsTest, MergeFromWithoutColumnMetrics) { + ColumnReaderStatistics src; + src.flattenStringDictionaryValues = 100; + + ColumnReaderStatistics dst; + dst.flattenStringDictionaryValues = 50; + dst.columnMetricsSet.emplace(); + dst.columnMetricsSet->getOrCreate(1, TypeKind::BIGINT) + ->decompressCPUTimeNanos.increment(1'000); + + dst.mergeFrom(src); + + EXPECT_EQ(dst.flattenStringDictionaryValues, 150); + ASSERT_TRUE(dst.columnMetricsSet.has_value()); + + std::unordered_map result; + dst.columnMetricsSet->toRuntimeMetrics(result); + EXPECT_EQ(result["column_1.BIGINT.decompressCPUTimeNanos"].sum, 1'000); +} diff --git a/velox/dwio/common/tests/ColumnSelectorTests.cpp b/velox/dwio/common/tests/ColumnSelectorTests.cpp index f0176b3824d..b65411d0e6f 100644 --- a/velox/dwio/common/tests/ColumnSelectorTests.cpp +++ b/velox/dwio/common/tests/ColumnSelectorTests.cpp @@ -394,17 +394,19 @@ TEST(ColumnSelectorTests, testFlatMapKeyFilterAllowed) { } TEST(ColumnSelectorTests, testPartitionKeysMark) { - const auto schema = std::dynamic_pointer_cast( - HiveTypeParser().parse("struct<" - "id:bigint" - "memo:string" - "ds:string" - "key:string>")); - - const auto physicalSchema = std::dynamic_pointer_cast( - HiveTypeParser().parse("struct<" - "id:bigint" - "memo:string>")); + const auto schema = + std::dynamic_pointer_cast(HiveTypeParser().parse( + "struct<" + "id:bigint" + "memo:string" + "ds:string" + "key:string>")); + + const auto physicalSchema = + std::dynamic_pointer_cast(HiveTypeParser().parse( + "struct<" + "id:bigint" + "memo:string>")); // use schema and physical schema to initialize a column selector // without filtering @@ -456,14 +458,16 @@ TEST(ColumnSelectorTests, testPartitionKeysMark) { EXPECT_EQ(root->childAt(3)->getNode().expression, "gold"); // test apply to real data file disk schema - const auto schemaMore = std::dynamic_pointer_cast( - HiveTypeParser().parse("struct<" - "id:bigint" - "memo:string" - "extra:array>")); - const auto schemaLess = std::dynamic_pointer_cast( - HiveTypeParser().parse("struct<" - "id:bigint>")); + const auto schemaMore = + std::dynamic_pointer_cast(HiveTypeParser().parse( + "struct<" + "id:bigint" + "memo:string" + "extra:array>")); + const auto schemaLess = + std::dynamic_pointer_cast(HiveTypeParser().parse( + "struct<" + "id:bigint>")); auto csMore = ColumnSelector::apply(cs, schemaMore); LOG(INFO) << "CS filter size: " << cs->getProjection().size(); @@ -510,14 +514,15 @@ TEST(ColumnSelectorTests, testPartitionKeysMark) { } TEST(ColumnSelectorTests, testProjectionUnchangedWhenReadSetChanged) { - const auto schema = std::dynamic_pointer_cast( - HiveTypeParser().parse("struct<" - "id:bigint" - "values:array" - "tags:map" - "notes:struct" - "memo:string" - "extra:string>")); + const auto schema = + std::dynamic_pointer_cast(HiveTypeParser().parse( + "struct<" + "id:bigint" + "values:array" + "tags:map" + "notes:struct" + "memo:string" + "extra:string>")); ColumnSelector cs(schema, std::vector{"id", "values"}); cs.setRead(cs.findColumn("notes")); @@ -559,13 +564,14 @@ TEST(ColumnSelectorTests, testProjectionUnchangedWhenReadSetChanged) { } TEST(ColumnSelectorTests, testProjectOrder) { - const auto schema = std::dynamic_pointer_cast( - HiveTypeParser().parse("struct<" - "id:bigint" - "values:array" - "tags:map" - "notes:struct" - "memo:string>")); + const auto schema = + std::dynamic_pointer_cast(HiveTypeParser().parse( + "struct<" + "id:bigint" + "values:array" + "tags:map" + "notes:struct" + "memo:string>")); // test filter with names with order of tags, memo and id { @@ -642,14 +648,15 @@ TEST(ColumnSelectorTests, testProjectOrder) { } TEST(ColumnSelectorTests, testNonexistingColFilters) { - const auto schema = std::dynamic_pointer_cast( - HiveTypeParser().parse("struct<" - "id:bigint" - "values:array" - "tags:map" - "notes:struct" - "memo:string" - "extra:string>")); + const auto schema = + std::dynamic_pointer_cast(HiveTypeParser().parse( + "struct<" + "id:bigint" + "values:array" + "tags:map" + "notes:struct" + "memo:string" + "extra:string>")); EXPECT_THROW( ColumnSelector cs( @@ -659,15 +666,16 @@ TEST(ColumnSelectorTests, testNonexistingColFilters) { } TEST(TestColumnSelector, fileColumnNamesReadAsLowerCaseDuplicateColFilters) { - const auto schema = std::dynamic_pointer_cast( - HiveTypeParser().parse("struct<" - "id:bigint" - "id:bigint" - "values:array" - "tags:map" - "notes:struct" - "memo:string" - "extra:string>")); + const auto schema = + std::dynamic_pointer_cast(HiveTypeParser().parse( + "struct<" + "id:bigint" + "id:bigint" + "values:array" + "tags:map" + "notes:struct" + "memo:string" + "extra:string>")); EXPECT_THROW( ColumnSelector cs(schema, std::vector{"id"}, nullptr, true), diff --git a/velox/dwio/common/tests/DataBufferTests.cpp b/velox/dwio/common/tests/DataBufferTests.cpp index a6ddee1afab..c5e7b816cbc 100644 --- a/velox/dwio/common/tests/DataBufferTests.cpp +++ b/velox/dwio/common/tests/DataBufferTests.cpp @@ -21,10 +21,8 @@ #include "velox/common/memory/Memory.h" #include "velox/dwio/common/DataBuffer.h" -namespace facebook { -namespace velox { -namespace dwio { -namespace common { +namespace facebook::velox::dwio::common { + using namespace facebook::velox::memory; using namespace testing; using MemoryPool = facebook::velox::memory::MemoryPool; @@ -175,7 +173,5 @@ TEST_F(DataBufferTest, Move) { } ASSERT_EQ(0, pool_->usedBytes()); } -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook + +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/tests/DecoderUtilTest.cpp b/velox/dwio/common/tests/DecoderUtilTest.cpp index f142d2baf46..9e42cb22882 100644 --- a/velox/dwio/common/tests/DecoderUtilTest.cpp +++ b/velox/dwio/common/tests/DecoderUtilTest.cpp @@ -17,6 +17,7 @@ #include "velox/dwio/common/DecoderUtil.h" #include #include "velox/common/base/Nulls.h" +#include "velox/dwio/common/SelectiveColumnReader.h" #include "velox/type/Filter.h" #include @@ -161,17 +162,6 @@ TEST_F(DecoderUtilTest, nonNullsFromSparse) { } } -namespace facebook::velox::dwio::common { -// Excerpt from LazyVector.h. -struct NoHook { - void addValues( - const int32_t* /*rows*/, - const int32_t* /*values*/, - int32_t /*size*/) {} -}; - -} // namespace facebook::velox::dwio::common - TEST_F(DecoderUtilTest, processFixedWithRun) { // Tests processing consecutive batches of integers with processFixedWidthRun. constexpr int kSize = 100; @@ -233,3 +223,35 @@ TEST_F(DecoderUtilTest, processFixedWithRun) { } } } + +TEST_F(DecoderUtilTest, fixedWidthScanMemcpyFastPath) { + constexpr int kSize = 10; + int32_t rows[kSize]; + std::iota(std::begin(rows), std::end(rows), 0); + float expectedValues[kSize], actualValues[kSize]; + for (int i = 0; i < kSize; ++i) { + expectedValues[i] = std::sin(i); + actualValues[i] = NAN; + } + int32_t numValues = 0; + SeekableArrayInputStream input( + reinterpret_cast(expectedValues), sizeof(expectedValues)); + const char* bufferStart = nullptr; + const char* bufferEnd = nullptr; + NoHook noHook; + fixedWidthScan( + {rows, kSize}, + nullptr, + actualValues, + nullptr, + numValues, + input, + bufferStart, + bufferEnd, + common::AlwaysTrue(), + noHook); + for (int i = 0; i < kSize; ++i) { + ASSERT_EQ(actualValues[i], expectedValues[i]); + } + ASSERT_EQ(numValues, kSize); +} diff --git a/velox/dwio/common/tests/DirectBufferedInputTest.cpp b/velox/dwio/common/tests/DirectBufferedInputTest.cpp new file mode 100644 index 00000000000..a7a0f48591e --- /dev/null +++ b/velox/dwio/common/tests/DirectBufferedInputTest.cpp @@ -0,0 +1,918 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/dwio/common/DirectBufferedInput.h" + +#include +#include +#include +#include +#include + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/caching/FileIds.h" +#include "velox/common/file/tests/TestUtils.h" +#include "velox/common/io/IoStatistics.h" +#include "velox/common/io/Options.h" +#include "velox/common/testutil/TestValue.h" + +using namespace facebook::velox; +using namespace facebook::velox::dwio::common; +using namespace facebook::velox::cache; +using namespace facebook::velox::memory; +using facebook::velox::common::testutil::TestValue; + +namespace { + +std::optional getNext(SeekableInputStream& input) { + const void* buf = nullptr; + int32_t size; + if (input.Next(&buf, &size)) { + return std::string( + static_cast(buf), static_cast(size)); + } else { + return std::nullopt; + } +} + +class DirectBufferedInputTest : public testing::Test { + protected: + static void SetUpTestCase() { + TestValue::enable(); + MemoryManager::testingSetInstance(MemoryManager::Options{}); + } + + void SetUp() override { + executor_ = std::make_unique(10); + tracker_ = std::make_shared( + "testTracker", nullptr, 256 << 10 /* 256KB */); + rootPool_ = memoryManager()->addRootPool(); + pool_ = rootPool_->addLeafChild("DirectBufferedInputTest"); + } + + void TearDown() override { + executor_.reset(); + } + + const std::shared_ptr dataIoStats_{ + std::make_shared()}; + const std::shared_ptr metadataIoStats_{ + std::make_shared()}; + + std::unique_ptr executor_; + std::shared_ptr rootPool_; + std::shared_ptr pool_; + std::shared_ptr tracker_; +}; + +TEST_F(DirectBufferedInputTest, reset) { + constexpr int32_t kContentSize = 4 << 20; // 4MB + std::string content; + content.resize(kContentSize); + for (int32_t i = 0; i < kContentSize; ++i) { + content[i] = static_cast('a' + (i % 26)); + } + + // Test both tiny and non-tiny region sizes. + for (const bool tinyRegion : {false, true}) { + SCOPED_TRACE(fmt::format("tinyRegion: {}", tinyRegion)); + + const uint64_t regionSize = tinyRegion + ? DirectBufferedInput::kTinySize - 100 + : DirectBufferedInput::kTinySize + 1000; + + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setLoadQuantum(1 << 20); + + auto& ids = fileIds(); + StringIdLease fileId(ids, "testFile"); + StringIdLease groupId(ids, "testGroup"); + + DirectBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + // First round: enqueue and load streams. + auto stream1 = input.enqueue(common::Region{0, regionSize}, nullptr); + auto stream2 = + input.enqueue(common::Region{regionSize, regionSize}, nullptr); + ASSERT_NE(stream1, nullptr); + ASSERT_NE(stream2, nullptr); + + input.load(LogType::TEST); + + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 2); + EXPECT_EQ(input.testingCoalescedLoads().size(), 1); + + // Consume streams. + auto next1 = getNext(*stream1); + ASSERT_TRUE(next1.has_value()); + EXPECT_EQ(next1.value(), content.substr(0, regionSize)); + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 1); + EXPECT_EQ(input.testingCoalescedLoads().size(), 1); + + auto next2 = getNext(*stream2); + ASSERT_TRUE(next2.has_value()); + EXPECT_EQ(next2.value(), content.substr(regionSize, regionSize)); + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 0); + EXPECT_EQ(input.testingCoalescedLoads().size(), 1); + + if (tinyRegion) { + ASSERT_EQ(pool_->usedBytes(), 0); + } else { + ASSERT_GT(pool_->usedBytes(), 0); + } + + stream1.reset(); + stream2.reset(); + + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 0); + EXPECT_EQ(input.testingCoalescedLoads().size(), 1); + ASSERT_EQ(pool_->usedBytes(), 0); + + // Reset the input. + input.reset(); + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 0); + EXPECT_EQ(input.testingCoalescedLoads().size(), 0); + ASSERT_EQ(pool_->usedBytes(), 0); + + // Second round: enqueue different regions after reset. + auto stream3 = + input.enqueue(common::Region{2 * regionSize, regionSize}, nullptr); + auto stream4 = + input.enqueue(common::Region{3 * regionSize, regionSize}, nullptr); + ASSERT_NE(stream3, nullptr); + ASSERT_NE(stream4, nullptr); + + input.load(LogType::TEST); + + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 2); + EXPECT_EQ(input.testingCoalescedLoads().size(), 1); + + auto next3 = getNext(*stream3); + ASSERT_TRUE(next3.has_value()); + EXPECT_EQ(next3.value(), content.substr(2 * regionSize, regionSize)); + + auto next4 = getNext(*stream4); + ASSERT_TRUE(next4.has_value()); + EXPECT_EQ(next4.value(), content.substr(3 * regionSize, regionSize)); + + if (tinyRegion) { + ASSERT_EQ(pool_->usedBytes(), 0); + } else { + ASSERT_GT(pool_->usedBytes(), 0); + } + + // Reset the input. + input.reset(); + + if (tinyRegion) { + ASSERT_EQ(pool_->usedBytes(), 0); + } else { + ASSERT_GT(pool_->usedBytes(), 0); + } + + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 0); + EXPECT_EQ(input.testingCoalescedLoads().size(), 0); + + stream3.reset(); + stream4.reset(); + ASSERT_EQ(pool_->usedBytes(), 0); + } +} + +TEST_F(DirectBufferedInputTest, readAfterReset) { + constexpr int32_t kContentSize = 4 << 20; // 4MB + std::string content; + content.resize(kContentSize); + for (int32_t i = 0; i < kContentSize; ++i) { + content[i] = static_cast('a' + (i % 26)); + } + + // Test both tiny and non-tiny region sizes. + for (const bool tinyRegion : {false, true}) { + SCOPED_TRACE(fmt::format("tinyRegion: {}", tinyRegion)); + + const uint64_t regionSize = tinyRegion + ? DirectBufferedInput::kTinySize - 100 + : DirectBufferedInput::kTinySize + 1000; + + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setLoadQuantum(1 << 20); + + auto& ids = fileIds(); + StringIdLease fileId(ids, "testFile"); + StringIdLease groupId(ids, "testGroup"); + + DirectBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + // Enqueue and load streams. + auto stream1 = input.enqueue(common::Region{0, regionSize}, nullptr); + auto stream2 = + input.enqueue(common::Region{regionSize, regionSize}, nullptr); + ASSERT_NE(stream1, nullptr); + ASSERT_NE(stream2, nullptr); + + input.load(LogType::TEST); + + // Reset the input before reading from streams. + input.reset(); + + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 0); + EXPECT_EQ(input.testingCoalescedLoads().size(), 0); + + // Read from streams after reset - data should still be available. + auto next1 = getNext(*stream1); + ASSERT_TRUE(next1.has_value()); + EXPECT_EQ(next1.value(), content.substr(0, regionSize)); + if (tinyRegion) { + ASSERT_EQ(pool_->usedBytes(), 0); + } else { + ASSERT_GT(pool_->usedBytes(), 0); + } + + auto next2 = getNext(*stream2); + ASSERT_TRUE(next2.has_value()); + EXPECT_EQ(next2.value(), content.substr(regionSize, regionSize)); + if (tinyRegion) { + ASSERT_EQ(pool_->usedBytes(), 0); + } else { + ASSERT_GT(pool_->usedBytes(), 0); + } + + stream1.reset(); + stream2.reset(); + while (pool_->usedBytes() > 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } +} + +DEBUG_ONLY_TEST_F(DirectBufferedInputTest, resetInputWithBeforeLoading) { + constexpr int32_t kContentSize = 4 << 20; // 4MB + std::string content; + content.resize(kContentSize); + for (int32_t i = 0; i < kContentSize; ++i) { + content[i] = static_cast('a' + (i % 26)); + } + + // Test both tiny and non-tiny region sizes. + for (const bool tinyRegion : {false, true}) { + SCOPED_TRACE(fmt::format("tinyRegion: {}", tinyRegion)); + + const uint64_t regionSize = tinyRegion + ? DirectBufferedInput::kTinySize - 100 + : DirectBufferedInput::kTinySize + 1000; + + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setLoadQuantum(1 << 20); + + auto& ids = fileIds(); + StringIdLease fileId(ids, "testFile"); + StringIdLease groupId(ids, "testGroup"); + + DirectBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + // Block the coalesced load to verify references are held. + folly::Baton<> loadStarted; + folly::Baton<> loadAllowed; + + SCOPED_TESTVALUE_SET( + "facebook::velox::cache::CoalescedLoad::loadOrFuture", + std::function( + [&](const CoalescedLoad* /*load*/) { + loadStarted.post(); + loadAllowed.wait(); + })); + + // Enqueue and load streams. + auto stream1 = input.enqueue(common::Region{0, regionSize}, nullptr); + auto stream2 = + input.enqueue(common::Region{regionSize, regionSize}, nullptr); + ASSERT_NE(stream1, nullptr); + ASSERT_NE(stream2, nullptr); + + ASSERT_EQ(pool_->usedBytes(), 0); + input.load(LogType::TEST); + ASSERT_EQ(pool_->usedBytes(), 0); + + // Wait for the load to start (but it's blocked). + loadStarted.wait(); + ASSERT_EQ(pool_->usedBytes(), 0); + + // Verify coalesced load references are held. + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 2); + EXPECT_EQ(input.testingCoalescedLoads().size(), 1); + + // Reset the input while load is pending. + input.reset(); + ASSERT_EQ(pool_->usedBytes(), 0); + + // After reset, internal tracking should be cleared. + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 0); + EXPECT_EQ(input.testingCoalescedLoads().size(), 0); + + // Allow the load to proceed but cancelled. + loadAllowed.post(); + + std::this_thread::sleep_for(std::chrono::seconds(1)); + ASSERT_EQ(pool_->usedBytes(), 0); + + // Read from streams - data should be available after load completes. + auto next1 = getNext(*stream1); + ASSERT_TRUE(next1.has_value()); + EXPECT_EQ(next1.value(), content.substr(0, regionSize)); + if (tinyRegion) { + ASSERT_EQ(pool_->usedBytes(), 0); + } else { + ASSERT_GT(pool_->usedBytes(), 0); + } + + auto next2 = getNext(*stream2); + ASSERT_TRUE(next2.has_value()); + EXPECT_EQ(next2.value(), content.substr(regionSize, regionSize)); + if (tinyRegion) { + ASSERT_EQ(pool_->usedBytes(), 0); + } else { + ASSERT_GT(pool_->usedBytes(), 0); + } + stream1.reset(); + stream2.reset(); + ASSERT_EQ(pool_->usedBytes(), 0); + } +} + +DEBUG_ONLY_TEST_F(DirectBufferedInputTest, resetInputWithAfterLoading) { + constexpr int32_t kContentSize = 4 << 20; // 4MB + std::string content; + content.resize(kContentSize); + for (int32_t i = 0; i < kContentSize; ++i) { + content[i] = static_cast('a' + (i % 26)); + } + + // Test both tiny and non-tiny region sizes. + for (const bool tinyRegion : {false, true}) { + SCOPED_TRACE(fmt::format("tinyRegion: {}", tinyRegion)); + + const uint64_t regionSize = tinyRegion + ? DirectBufferedInput::kTinySize - 100 + : DirectBufferedInput::kTinySize + 1000; + + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setLoadQuantum(1 << 20); + + auto& ids = fileIds(); + StringIdLease fileId(ids, "testFile"); + StringIdLease groupId(ids, "testGroup"); + + DirectBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + // Block the coalesced load to verify references are held. + folly::Baton<> loadStarted; + folly::Baton<> loadAllowed; + + SCOPED_TESTVALUE_SET( + "facebook::velox::cache::CoalescedLoad::loadOrFuture::loading", + std::function( + [&](const CoalescedLoad* /*load*/) { + loadStarted.post(); + loadAllowed.wait(); + })); + + // Enqueue and load streams. + auto stream1 = input.enqueue(common::Region{0, regionSize}, nullptr); + auto stream2 = + input.enqueue(common::Region{regionSize, regionSize}, nullptr); + ASSERT_NE(stream1, nullptr); + ASSERT_NE(stream2, nullptr); + + ASSERT_EQ(pool_->usedBytes(), 0); + input.load(LogType::TEST); + ASSERT_EQ(pool_->usedBytes(), 0); + + // Wait for the load to start (but it's blocked). + loadStarted.wait(); + ASSERT_EQ(pool_->usedBytes(), 0); + + // Verify coalesced load references are held. + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 2); + EXPECT_EQ(input.testingCoalescedLoads().size(), 1); + + // Reset the input while load is pending. + input.reset(); + ASSERT_EQ(pool_->usedBytes(), 0); + + // After reset, internal tracking should be cleared. + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 0); + EXPECT_EQ(input.testingCoalescedLoads().size(), 0); + + // Allow the load to proceed without cancelling. + loadAllowed.post(); + std::this_thread::sleep_for(std::chrono::seconds(1)); + + // Read from streams - data should be available after load completes. + auto next1 = getNext(*stream1); + ASSERT_TRUE(next1.has_value()); + EXPECT_EQ(next1.value(), content.substr(0, regionSize)); + if (tinyRegion) { + ASSERT_EQ(pool_->usedBytes(), 0); + } else { + ASSERT_GT(pool_->usedBytes(), 0); + } + + auto next2 = getNext(*stream2); + ASSERT_TRUE(next2.has_value()); + EXPECT_EQ(next2.value(), content.substr(regionSize, regionSize)); + if (tinyRegion) { + ASSERT_EQ(pool_->usedBytes(), 0); + } else { + ASSERT_GT(pool_->usedBytes(), 0); + } + stream1.reset(); + stream2.reset(); + while (pool_->usedBytes() > 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } +} + +TEST_F(DirectBufferedInputTest, preloadCalledTwice) { + std::string content(1024, 'x'); + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto& ids = fileIds(); + StringIdLease fileId(ids, "preloadTwice"); + StringIdLease groupId(ids, "preloadTwiceGroup"); + + DirectBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + input.preload(); + ASSERT_TRUE(input.preloaded()); + VELOX_ASSERT_THROW(input.preload(), "preload() called more than once"); +} + +TEST_F(DirectBufferedInputTest, isBufferedWithPreload) { + std::string content(1024, 'x'); + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto& ids = fileIds(); + StringIdLease fileId(ids, "isBufferedPreload"); + StringIdLease groupId(ids, "isBufferedPreloadGroup"); + + DirectBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + // Before preload, isBuffered should return false. + EXPECT_FALSE(input.isBuffered(0, 100)); + + input.preload(); + ASSERT_TRUE(input.preloaded()); + + // After preload, isBuffered should return true. + EXPECT_TRUE(input.isBuffered(0, 100)); + EXPECT_TRUE(input.isBuffered(500, 200)); +} + +TEST_F(DirectBufferedInputTest, enqueueSkipsRequestsWhenPreloaded) { + std::string content(1024, 'x'); + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto& ids = fileIds(); + StringIdLease fileId(ids, "enqueueSkipsRequests"); + StringIdLease groupId(ids, "enqueueSkipsRequestsGroup"); + + DirectBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + ASSERT_EQ(readFile->numReads(), 0); + + input.preload(); + ASSERT_TRUE(input.preloaded()); + ASSERT_EQ(readFile->numReads(), 1); + + // After preload, enqueue should not add to requests or coalesced loads. + auto stream = input.enqueue(common::Region{0, 100}, nullptr); + ASSERT_NE(stream, nullptr); + + // Verify no coalesced loads were created (requests are skipped when + // preloaded). + EXPECT_EQ(input.testingCoalescedLoads().size(), 0); + EXPECT_EQ(input.testingStreamToCoalescedLoadSize(), 0); + + // Stream should still be able to read data from preloaded content. + auto next = getNext(*stream); + ASSERT_TRUE(next.has_value()); + EXPECT_EQ(next.value(), content.substr(0, 100)); + + // No additional file reads - all served from preloaded data. + ASSERT_EQ(readFile->numReads(), 1); +} + +TEST_F(DirectBufferedInputTest, preloadAfterEnqueue) { + std::string content(1024, 'x'); + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto& ids = fileIds(); + StringIdLease fileId(ids, "preloadAfterEnqueue"); + StringIdLease groupId(ids, "preloadAfterEnqueueGroup"); + + DirectBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + input.enqueue({0, 100}, nullptr); + VELOX_ASSERT_THROW( + input.preload(), "preload() must be called before enqueue()"); +} + +TEST_F(DirectBufferedInputTest, preloadedDataWithoutPreload) { + std::string content(1024, 'x'); + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto& ids = fileIds(); + StringIdLease fileId(ids, "preloadedDataNoPreload"); + StringIdLease groupId(ids, "preloadedDataNoPreloadGroup"); + + DirectBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + VELOX_ASSERT_THROW( + input.preloadedData(0, 100), "preloadedData() called without preload"); +} + +TEST_F(DirectBufferedInputTest, preloadedDataOffsetOutOfRange) { + std::string content(1024, 'x'); + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto& ids = fileIds(); + StringIdLease fileId(ids, "preloadedDataOOR"); + StringIdLease groupId(ids, "preloadedDataOORGroup"); + + DirectBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + input.preload(); + VELOX_ASSERT_THROW( + input.preloadedData(1024, 100), "Offset exceeds preloaded size"); +} + +TEST_F(DirectBufferedInputTest, preload) { + struct TestParam { + uint64_t fileSize; + std::string debugString() const { + return fmt::format("fileSize {}", fileSize); + } + }; + std::vector testSettings = { + // Tiny file (below kTinySize). + {DirectBufferedInput::kTinySize - 100}, + // Non-tiny file (above kTinySize). + {DirectBufferedInput::kTinySize + 1000}, + // Larger file. + {1 << 20}, + }; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + std::string content; + content.resize(testData.fileSize); + for (uint64_t i = 0; i < testData.fileSize; ++i) { + content[i] = static_cast('a' + (i % 26)); + } + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setLoadQuantum(1 << 20); + + auto& ids = fileIds(); + StringIdLease fileId(ids, fmt::format("preloadTest_{}", testData.fileSize)); + StringIdLease groupId( + ids, fmt::format("preloadGroup_{}", testData.fileSize)); + + DirectBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + ASSERT_EQ(readFile->numReads(), 0); + EXPECT_FALSE(input.preloaded()); + + input.preload(); + + EXPECT_TRUE(input.preloaded()); + + ASSERT_EQ(readFile->numReads(), 1); + + // Enqueue sub-region streams and read from preloaded data. + const uint64_t regionSize = + std::min(testData.fileSize / 2, testData.fileSize); + auto stream1 = input.enqueue(common::Region{0, regionSize}, nullptr); + ASSERT_NE(stream1, nullptr); + + auto next1 = getNext(*stream1); + ASSERT_TRUE(next1.has_value()); + EXPECT_EQ(next1.value(), content.substr(0, regionSize)); + + if (testData.fileSize > regionSize) { + auto stream2 = input.enqueue( + common::Region{regionSize, testData.fileSize - regionSize}, nullptr); + ASSERT_NE(stream2, nullptr); + + auto next2 = getNext(*stream2); + ASSERT_TRUE(next2.has_value()); + EXPECT_EQ( + next2.value(), + content.substr(regionSize, testData.fileSize - regionSize)); + } + + // No additional file reads after preload. + ASSERT_EQ(readFile->numReads(), 1); + } +} + +TEST_F(DirectBufferedInputTest, preloadedData) { + struct TestParam { + uint64_t fileSize; + uint64_t offset; + uint64_t length; + // Expected size of the returned range. For tiny files, this equals + // min(length, fileSize - offset). For large files, it may be smaller + // due to non-contiguous allocation run boundaries. + bool expectTinyPath; + std::string debugString() const { + return fmt::format( + "fileSize {}, offset {}, length {}, expectTinyPath {}", + fileSize, + offset, + length, + expectTinyPath); + } + }; + + std::vector testSettings = { + // Tiny file: read from beginning. + {1000, 0, 500, true}, + // Tiny file: read from middle. + {1000, 400, 200, true}, + // Tiny file: length exceeds remaining bytes. + {1000, 800, 500, true}, + // Tiny file: length zero. + {1000, 0, 0, true}, + // Tiny file: read last byte. + {1000, 999, 100, true}, + // Tiny file: read entire file. + {1000, 0, 1000, true}, + // Large file: read from beginning. + {1 << 20, 0, 4096, false}, + // Large file: read from middle. + {1 << 20, 100'000, 4096, false}, + // Large file: length exceeds remaining bytes. + {1 << 20, (1 << 20) - 100, 4096, false}, + // Large file: length zero. + {1 << 20, 0, 0, false}, + // Large file: read last byte. + {1 << 20, (1 << 20) - 1, 100, false}, + }; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + std::string content; + content.resize(testData.fileSize); + for (uint64_t i = 0; i < testData.fileSize; ++i) { + content[i] = static_cast('a' + (i % 26)); + } + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto& ids = fileIds(); + StringIdLease fileId( + ids, + fmt::format( + "preloadedData_{}_{}_{}", + testData.fileSize, + testData.offset, + testData.length)); + StringIdLease groupId( + ids, + fmt::format( + "preloadedDataGroup_{}_{}_{}", + testData.fileSize, + testData.offset, + testData.length)); + + DirectBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + ASSERT_EQ(readFile->numReads(), 0); + input.preload(); + ASSERT_EQ(readFile->numReads(), 1); + + const auto range = input.preloadedData(testData.offset, testData.length); + + // No additional file reads — preloadedData serves from memory. + ASSERT_EQ(readFile->numReads(), 1); + + const auto expectedAvailable = std::min( + testData.length, testData.fileSize - testData.offset); + + if (testData.length == 0) { + EXPECT_EQ(range.size(), 0); + continue; + } + + // For tiny files, the data is contiguous so we get all available bytes. + // For large files, we may get fewer bytes due to allocation run boundaries. + EXPECT_GT(range.size(), 0); + EXPECT_LE(range.size(), expectedAvailable); + if (testData.expectTinyPath) { + EXPECT_EQ(range.size(), expectedAvailable); + } + + // Verify data content matches original file. + EXPECT_EQ( + std::string_view(range.data(), range.size()), + std::string_view(content.data() + testData.offset, range.size())); + } +} + +TEST_F(DirectBufferedInputTest, hasCache) { + std::string content(1024, 'x'); + auto readFile = std::make_shared(content); + + io::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + + auto& ids = fileIds(); + StringIdLease fileId(ids, "testFile"); + StringIdLease groupId(ids, "testGroup"); + + DirectBufferedInput input( + readFile, + MetricsLog::voidLog(), + std::move(fileId), + tracker_, + std::move(groupId), + dataIoStats_, + nullptr, + executor_.get(), + readerOptions); + + ASSERT_FALSE(input.hasCache()); + VELOX_ASSERT_THROW( + input.cacheRegion(0, 10, std::string_view("0123456789")), + "cacheRegion requires a backing cache"); + VELOX_ASSERT_THROW( + input.findCachedRegion(0), "findCachedRegion requires a backing cache"); +} + +} // namespace diff --git a/velox/dwio/common/tests/Lemire/CMakeLists.txt b/velox/dwio/common/tests/Lemire/CMakeLists.txt new file mode 100644 index 00000000000..16d880cbc2f --- /dev/null +++ b/velox/dwio/common/tests/Lemire/CMakeLists.txt @@ -0,0 +1,17 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_subdirectory(FastPFor) + +velox_add_library(velox_lemire_bmipacking32 INTERFACE HEADERS bmipacking32.h) diff --git a/velox/dwio/common/tests/Lemire/FastPFor/CMakeLists.txt b/velox/dwio/common/tests/Lemire/FastPFor/CMakeLists.txt index d74b5b7af66..9b16a4ebf7d 100644 --- a/velox/dwio/common/tests/Lemire/FastPFor/CMakeLists.txt +++ b/velox/dwio/common/tests/Lemire/FastPFor/CMakeLists.txt @@ -12,5 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. add_library(velox_fastpforlib STATIC bitpacking.cpp) +velox_add_test_headers(velox_fastpforlib bitpacking.h bitpackinghelpers.h) target_include_directories(velox_fastpforlib PUBLIC $) diff --git a/velox/dwio/common/tests/LocalFileSinkTest.cpp b/velox/dwio/common/tests/LocalFileSinkTest.cpp index c052ce1e064..4eb7cc95433 100644 --- a/velox/dwio/common/tests/LocalFileSinkTest.cpp +++ b/velox/dwio/common/tests/LocalFileSinkTest.cpp @@ -17,15 +17,15 @@ #include "velox/common/base/Fs.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/FileSystems.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/dwio/common/FileSink.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include using namespace ::testing; -using namespace facebook::velox::exec::test; namespace facebook::velox::dwio::common { +using namespace facebook::velox::common::testutil; class LocalFileSinkTest : public testing::Test { protected: @@ -79,4 +79,74 @@ TEST_F(LocalFileSinkTest, existFileCheck) { "File exists"); } +TEST_F(LocalFileSinkTest, getIoStatisticsReturnsNullWhenNotProvided) { + LocalFileSink::registerFactory(); + auto root = TempDirectoryPath::create(); + auto filePath = fs::path(root->getPath()) / "test_stats_null.ext"; + + auto localFileSink = FileSink::create( + fmt::format("file:{}", filePath.string()), {.pool = pool_.get()}); + + EXPECT_EQ(localFileSink->getIoStatistics(), nullptr); + localFileSink->close(); +} + +TEST_F(LocalFileSinkTest, getIoStatisticsReturnsProvidedStats) { + LocalFileSink::registerFactory(); + auto root = TempDirectoryPath::create(); + auto filePath = fs::path(root->getPath()) / "test_stats_provided.ext"; + + IoStatistics ioStats; + auto localFileSink = FileSink::create( + fmt::format("file:{}", filePath.string()), + {.pool = pool_.get(), .stats = &ioStats}); + + EXPECT_EQ(localFileSink->getIoStatistics(), &ioStats); + localFileSink->close(); +} + +TEST_F(LocalFileSinkTest, getFileSystemStatsReturnsNullWhenNotProvided) { + LocalFileSink::registerFactory(); + auto root = TempDirectoryPath::create(); + auto filePath = fs::path(root->getPath()) / "test_fs_stats_null.ext"; + + auto localFileSink = FileSink::create( + fmt::format("file:{}", filePath.string()), {.pool = pool_.get()}); + + EXPECT_EQ(localFileSink->getFileSystemStats(), nullptr); + localFileSink->close(); +} + +TEST_F(LocalFileSinkTest, getFileSystemStatsReturnsProvidedStats) { + LocalFileSink::registerFactory(); + auto root = TempDirectoryPath::create(); + auto filePath = fs::path(root->getPath()) / "test_fs_stats_provided.ext"; + + velox::IoStats fileSystemStats; + auto localFileSink = FileSink::create( + fmt::format("file:{}", filePath.string()), + {.pool = pool_.get(), .fileSystemStats = &fileSystemStats}); + + EXPECT_EQ(localFileSink->getFileSystemStats(), &fileSystemStats); + localFileSink->close(); +} + +TEST_F(LocalFileSinkTest, getFileSystemStatsAndIoStatisticsBothProvided) { + LocalFileSink::registerFactory(); + auto root = TempDirectoryPath::create(); + auto filePath = fs::path(root->getPath()) / "test_both_stats.ext"; + + IoStatistics ioStats; + velox::IoStats fileSystemStats; + auto localFileSink = FileSink::create( + fmt::format("file:{}", filePath.string()), + {.pool = pool_.get(), + .stats = &ioStats, + .fileSystemStats = &fileSystemStats}); + + EXPECT_EQ(localFileSink->getIoStatistics(), &ioStats); + EXPECT_EQ(localFileSink->getFileSystemStats(), &fileSystemStats); + localFileSink->close(); +} + } // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/tests/LoggedExceptionTest.cpp b/velox/dwio/common/tests/LoggedExceptionTest.cpp index 6eea7e8b5c7..0f17e1f50a3 100644 --- a/velox/dwio/common/tests/LoggedExceptionTest.cpp +++ b/velox/dwio/common/tests/LoggedExceptionTest.cpp @@ -34,11 +34,12 @@ void testTraceCollectionSwitchControl(bool enabled) { try { throw LoggedException("Test error message"); } catch (VeloxException& e) { - SCOPED_TRACE(fmt::format( - "enabled: {}, user flag: {}, sys flag: {}", - enabled, - FLAGS_velox_exception_user_stacktrace_enabled, - FLAGS_velox_exception_system_stacktrace_enabled)); + SCOPED_TRACE( + fmt::format( + "enabled: {}, user flag: {}, sys flag: {}", + enabled, + FLAGS_velox_exception_user_stacktrace_enabled, + FLAGS_velox_exception_system_stacktrace_enabled)); ASSERT_TRUE(e.exceptionType() == VeloxException::Type::kSystem); ASSERT_EQ(enabled, e.stackTrace() != nullptr); } diff --git a/velox/dwio/common/tests/OnDemandUnitLoaderTests.cpp b/velox/dwio/common/tests/OnDemandUnitLoaderTests.cpp index 178ad21f9b2..245c7d6186d 100644 --- a/velox/dwio/common/tests/OnDemandUnitLoaderTests.cpp +++ b/velox/dwio/common/tests/OnDemandUnitLoaderTests.cpp @@ -17,9 +17,8 @@ #include #include -#include "velox/common/base/tests/GTestUtils.h" #include "velox/dwio/common/OnDemandUnitLoader.h" -#include "velox/dwio/common/UnitLoaderTools.h" +#include "velox/dwio/common/tests/UnitLoaderBaseTest.h" #include "velox/dwio/common/tests/utils/UnitLoaderTestTools.h" using namespace ::testing; @@ -31,6 +30,38 @@ using facebook::velox::dwio::common::test::getUnitsLoadedWithFalse; using facebook::velox::dwio::common::test::LoadUnitMock; using facebook::velox::dwio::common::test::ReaderMock; +class OnDemandUnitLoaderCommonTests + : public UnitLoaderBaseTest { + protected: + OnDemandUnitLoaderFactory createFactory() override { + return OnDemandUnitLoaderFactory(nullptr); + } +}; + +TEST_F(OnDemandUnitLoaderCommonTests, NoUnitButSkip) { + testNoUnitButSkip(); +} + +TEST_F(OnDemandUnitLoaderCommonTests, InitialSkip) { + testInitialSkip(); +} + +TEST_F(OnDemandUnitLoaderCommonTests, CanRequestUnitMultipleTimes) { + testCanRequestUnitMultipleTimes(); +} + +TEST_F(OnDemandUnitLoaderCommonTests, UnitOutOfRange) { + testUnitOutOfRange(); +} + +TEST_F(OnDemandUnitLoaderCommonTests, SeekOutOfRange) { + testSeekOutOfRange(); +} + +TEST_F(OnDemandUnitLoaderCommonTests, SeekOutOfRangeReaderError) { + testSeekOutOfRangeReaderError(); +} + TEST(OnDemandUnitLoaderTests, LoadsCorrectlyWithReader) { size_t blockedOnIoCount = 0; OnDemandUnitLoaderFactory factory([&](auto) { ++blockedOnIoCount; }); @@ -127,96 +158,3 @@ TEST(OnDemandUnitLoaderTests, CanSeek) { EXPECT_EQ(readerMock.unitsLoaded(), std::vector({true, false, false})); EXPECT_EQ(blockedOnIoCount, 4); } - -TEST(OnDemandUnitLoaderTests, SeekOutOfRangeReaderError) { - size_t blockedOnIoCount = 0; - OnDemandUnitLoaderFactory factory([&](auto) { ++blockedOnIoCount; }); - ReaderMock readerMock{{10, 20, 30}, {0, 0, 0}, factory, 0}; - EXPECT_EQ(readerMock.unitsLoaded(), std::vector({false, false, false})); - EXPECT_EQ(blockedOnIoCount, 0); - readerMock.seek(59); - - readerMock.seek(60); - - VELOX_ASSERT_THROW( - readerMock.seek(61), - "Can't seek to possition 61 in file. Must be up to 60."); -} - -TEST(OnDemandUnitLoaderTests, SeekOutOfRange) { - OnDemandUnitLoaderFactory factory(nullptr); - std::vector unitsLoaded(getUnitsLoadedWithFalse(1)); - std::vector> units; - units.push_back(std::make_unique(10, 0, unitsLoaded, 0)); - - auto unitLoader = factory.create(std::move(units), 0); - - unitLoader->onSeek(0, 10); - - VELOX_ASSERT_THROW(unitLoader->onSeek(0, 11), "Row out of range"); -} - -TEST(OnDemandUnitLoaderTests, UnitOutOfRange) { - OnDemandUnitLoaderFactory factory(nullptr); - std::vector unitsLoaded(getUnitsLoadedWithFalse(1)); - std::vector> units; - units.push_back(std::make_unique(10, 0, unitsLoaded, 0)); - - auto unitLoader = factory.create(std::move(units), 0); - unitLoader->getLoadedUnit(0); - - VELOX_ASSERT_THROW(unitLoader->getLoadedUnit(1), "Unit out of range"); -} - -TEST(OnDemandUnitLoaderTests, CanRequestUnitMultipleTimes) { - OnDemandUnitLoaderFactory factory(nullptr); - std::vector unitsLoaded(getUnitsLoadedWithFalse(1)); - std::vector> units; - units.push_back(std::make_unique(10, 0, unitsLoaded, 0)); - - auto unitLoader = factory.create(std::move(units), 0); - unitLoader->getLoadedUnit(0); - unitLoader->getLoadedUnit(0); - unitLoader->getLoadedUnit(0); -} - -TEST(OnDemandUnitLoaderTests, InitialSkip) { - auto getFactoryWithSkip = [](uint64_t skipToRow) { - auto factory = std::make_unique(nullptr); - std::vector unitsLoaded(getUnitsLoadedWithFalse(1)); - std::vector> units; - units.push_back(std::make_unique(10, 0, unitsLoaded, 0)); - units.push_back(std::make_unique(20, 0, unitsLoaded, 1)); - units.push_back(std::make_unique(30, 0, unitsLoaded, 2)); - factory->create(std::move(units), skipToRow); - }; - - EXPECT_NO_THROW(getFactoryWithSkip(0)); - EXPECT_NO_THROW(getFactoryWithSkip(1)); - EXPECT_NO_THROW(getFactoryWithSkip(9)); - EXPECT_NO_THROW(getFactoryWithSkip(10)); - EXPECT_NO_THROW(getFactoryWithSkip(11)); - EXPECT_NO_THROW(getFactoryWithSkip(29)); - EXPECT_NO_THROW(getFactoryWithSkip(30)); - EXPECT_NO_THROW(getFactoryWithSkip(31)); - EXPECT_NO_THROW(getFactoryWithSkip(59)); - EXPECT_NO_THROW(getFactoryWithSkip(60)); - VELOX_ASSERT_THROW( - getFactoryWithSkip(61), - "Can only skip up to the past-the-end row of the file."); - VELOX_ASSERT_THROW( - getFactoryWithSkip(100), - "Can only skip up to the past-the-end row of the file."); -} - -TEST(OnDemandUnitLoaderTests, NoUnitButSkip) { - OnDemandUnitLoaderFactory factory(nullptr); - std::vector> units; - - EXPECT_NO_THROW(factory.create(std::move(units), 0)); - - std::vector> units2; - VELOX_ASSERT_THROW( - factory.create(std::move(units2), 1), - "Can only skip up to the past-the-end row of the file."); -} diff --git a/velox/dwio/common/tests/OptionsTests.cpp b/velox/dwio/common/tests/OptionsTests.cpp index 82335821c52..427f086772e 100644 --- a/velox/dwio/common/tests/OptionsTests.cpp +++ b/velox/dwio/common/tests/OptionsTests.cpp @@ -25,6 +25,11 @@ TEST(OptionsTests, defaultRowNumberColumnInfoTest) { ASSERT_EQ(std::nullopt, rowReaderOptions.rowNumberColumnInfo()); } +TEST(OptionsTests, fluxFileFormatRoundTrip) { + ASSERT_EQ(FileFormat::FLUX, toFileFormat("flux")); + ASSERT_EQ("flux", toString(FileFormat::FLUX)); +} + TEST(OptionsTests, setRowNumberColumnInfoTest) { RowReaderOptions rowReaderOptions; RowNumberColumnInfo rowNumberColumnInfo; diff --git a/velox/dwio/common/tests/ParallelUnitLoaderTest.cpp b/velox/dwio/common/tests/ParallelUnitLoaderTest.cpp new file mode 100644 index 00000000000..690acd9fc11 --- /dev/null +++ b/velox/dwio/common/tests/ParallelUnitLoaderTest.cpp @@ -0,0 +1,141 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/dwio/common/ParallelUnitLoader.h" +#include "velox/dwio/common/OnDemandUnitLoader.h" +#include "velox/dwio/common/tests/UnitLoaderBaseTest.h" +#include "velox/dwio/common/tests/utils/UnitLoaderTestTools.h" + +#include +#include +#include + +using namespace facebook::velox::dwio::common; +using namespace facebook::velox::dwio::common::test; + +class ParallelUnitLoaderTest + : public UnitLoaderBaseTest { + protected: + ParallelUnitLoaderFactory createFactory() override { + return ParallelUnitLoaderFactory(ioExecutor_.get(), 2); + } + + std::unique_ptr ioExecutor_ = + std::make_unique(10); +}; + +TEST_F(ParallelUnitLoaderTest, NoUnitButSkip) { + testNoUnitButSkip(); +} + +TEST_F(ParallelUnitLoaderTest, InitialSkip) { + testInitialSkip(); +} + +TEST_F(ParallelUnitLoaderTest, CanRequestUnitMultipleTimes) { + testCanRequestUnitMultipleTimes(); +} + +TEST_F(ParallelUnitLoaderTest, UnitOutOfRange) { + testUnitOutOfRange(); +} + +TEST_F(ParallelUnitLoaderTest, SeekOutOfRange) { + testSeekOutOfRange(); +} + +TEST_F(ParallelUnitLoaderTest, SeekOutOfRangeReaderError) { + testSeekOutOfRangeReaderError(); +} + +TEST_F(ParallelUnitLoaderTest, LoadsCorrectlyWithReader) { + auto factory = createFactory(); + ReaderMock readerMock{{10, 20, 30}, {0, 0, 0}, factory, 0}; + + EXPECT_TRUE(readerMock.read(3)); // Unit: 0, rows: 0-2, load(0) + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + EXPECT_EQ(readerMock.unitsLoaded(), std::vector({true, true, false})); + + EXPECT_TRUE(readerMock.read(3)); // Unit: 0, rows: 3-5 + EXPECT_EQ(readerMock.unitsLoaded(), std::vector({true, true, false})); + + EXPECT_TRUE(readerMock.read(4)); // Unit: 0, rows: 6-9 + EXPECT_EQ(readerMock.unitsLoaded(), std::vector({true, true, false})); + + EXPECT_TRUE(readerMock.read(14)); // Unit: 1, rows: 0-13, unload(0), load(1) + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + EXPECT_EQ(readerMock.unitsLoaded(), std::vector({false, true, true})); + + // will only read 5 rows, no more rows in unit 1 + EXPECT_TRUE(readerMock.read(10)); // Unit: 1, rows: 14-19 + EXPECT_EQ(readerMock.unitsLoaded(), std::vector({false, true, true})); + + EXPECT_TRUE(readerMock.read(30)); // Unit: 2, rows: 0-29, unload(1), load(2) + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + EXPECT_EQ(readerMock.unitsLoaded(), std::vector({false, false, true})); + + EXPECT_FALSE(readerMock.read(30)); // No more data + EXPECT_EQ(readerMock.unitsLoaded(), std::vector({false, false, true})); +} + +// Performance comparison test +TEST_F(ParallelUnitLoaderTest, PerformanceComparison) { + std::vector rowsPerUnit = {100, 100, 100, 100, 100, 100, 100, 100}; + std::vector ioSizes = { + 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024}; + + // Measure ParallelUnitLoader performance + auto parallelStart = std::chrono::high_resolution_clock::now(); + { + auto factory = createFactory(); + ReaderMock reader(rowsPerUnit, ioSizes, factory, 0); + + for (size_t i = 0; i < rowsPerUnit.size(); ++i) { + uint64_t totalRowsRead = 0; + while (totalRowsRead < rowsPerUnit[i]) { + reader.read(25); + int nextRead = rowsPerUnit[i] - totalRowsRead; + totalRowsRead += std::min(25, nextRead); + } + } + } + auto parallelEnd = std::chrono::high_resolution_clock::now(); + + // Measure OnDemandUnitLoader performance + auto onDemandStart = std::chrono::high_resolution_clock::now(); + { + auto factory = std::make_shared(nullptr); + ReaderMock reader(rowsPerUnit, ioSizes, *factory, 0); + + for (size_t i = 0; i < rowsPerUnit.size(); ++i) { + uint64_t totalRowsRead = 0; + while (totalRowsRead < rowsPerUnit[i]) { + reader.read(25); + int nextRead = rowsPerUnit[i] - totalRowsRead; + totalRowsRead += std::min(25, nextRead); + } + } + } + auto onDemandEnd = std::chrono::high_resolution_clock::now(); + + auto parallelDuration = std::chrono::duration_cast( + parallelEnd - parallelStart); + auto onDemandDuration = std::chrono::duration_cast( + onDemandEnd - onDemandStart); + + // ParallelUnitLoader should be faster + EXPECT_GT(onDemandDuration.count(), parallelDuration.count()); +} diff --git a/velox/dwio/common/tests/ReadFileInputStreamTests.cpp b/velox/dwio/common/tests/ReadFileInputStreamTests.cpp index 273b523549f..1b600b6adaf 100644 --- a/velox/dwio/common/tests/ReadFileInputStreamTests.cpp +++ b/velox/dwio/common/tests/ReadFileInputStreamTests.cpp @@ -15,8 +15,8 @@ */ #include "velox/common/file/FileSystems.h" +#include "velox/common/testutil/TempFilePath.h" #include "velox/dwio/common/InputStream.h" -#include "velox/exec/tests/utils/TempFilePath.h" #include #include "folly/io/Cursor.h" @@ -24,6 +24,7 @@ #include "gtest/gtest.h" using namespace facebook::velox; +using namespace facebook::velox::common::testutil; using namespace facebook::velox::dwio::common; using facebook::velox::common::Region; @@ -35,7 +36,7 @@ class ReadFileInputStreamTest : public testing::Test { }; TEST_F(ReadFileInputStreamTest, LocalReadFile) { - auto tempFile = exec::test::TempFilePath::create(); + auto tempFile = TempFilePath::create(); const auto& filename = tempFile->getPath(); remove(filename.c_str()); { diff --git a/velox/dwio/common/tests/ReadVuLongBenchmark.cpp b/velox/dwio/common/tests/ReadVuLongBenchmark.cpp new file mode 100644 index 00000000000..5cc0b2aabb1 --- /dev/null +++ b/velox/dwio/common/tests/ReadVuLongBenchmark.cpp @@ -0,0 +1,319 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Benchmark comparing old vs new readVuLong implementation. +// The "old" version uses UNLIKELY for the first byte termination check. +// The "new" version (in production IntDecoder.h) uses LIKELY and early return. + +#include "folly/Benchmark.h" +#include "folly/Random.h" +#include "folly/Varint.h" +#include "folly/init/Init.h" +#include "velox/dwio/common/IntCodecCommon.h" + +using namespace facebook::velox::dwio::common; + +constexpr size_t kNumElements = 1000000; + +// Encapsulates benchmark state to avoid global mutable variables. +struct BenchmarkState { + std::vector bufferSmall; // Values 0-127 (1-byte varints) + std::vector bufferMedium; // Values 0-16383 (1-2 byte varints) + std::vector bufferMixed; // Random uint32 (1-5 byte varints) + + size_t lenSmall{0}; + size_t lenMedium{0}; + size_t lenMixed{0}; + + size_t numValuesSmall{0}; + size_t numValuesMedium{0}; + size_t numValuesMixed{0}; + + static BenchmarkState& instance() { + static BenchmarkState state; + return state; + } +}; + +// Helper to write a varint to buffer +size_t writeVulong(uint64_t val, char* buffer, size_t pos) { + while (true) { + if ((val & ~0x7f) == 0) { + buffer[pos++] = static_cast(val); + return pos; + } + buffer[pos++] = static_cast(0x80 | (val & BASE_128_MASK)); + val = (static_cast(val) >> 7); + } +} + +// OLD implementation: uses UNLIKELY for first byte, no early return +// This is the original code before the optimization. +uint64_t readVuLongOld(const char*& bufferStart, const char* bufferEnd) { + if (LIKELY(bufferEnd - bufferStart >= folly::kMaxVarintLength64)) { + const char* p = bufferStart; + uint64_t val; + do { + int64_t b; + b = *p++; + val = (b & 0x7f); + if (UNLIKELY(b >= 0)) { // OLD: UNLIKELY here + break; + } + b = *p++; + val |= (b & 0x7f) << 7; + if (UNLIKELY(b >= 0)) { + break; + } + b = *p++; + val |= (b & 0x7f) << 14; + if (UNLIKELY(b >= 0)) { + break; + } + b = *p++; + val |= (b & 0x7f) << 21; + if (UNLIKELY(b >= 0)) { + break; + } + b = *p++; + val |= (b & 0x7f) << 28; + if (UNLIKELY(b >= 0)) { + break; + } + b = *p++; + val |= (b & 0x7f) << 35; + if (UNLIKELY(b >= 0)) { + break; + } + b = *p++; + val |= (b & 0x7f) << 42; + if (UNLIKELY(b >= 0)) { + break; + } + b = *p++; + val |= (b & 0x7f) << 49; + if (UNLIKELY(b >= 0)) { + break; + } + b = *p++; + val |= (b & 0x7f) << 56; + if (UNLIKELY(b >= 0)) { + break; + } + b = *p++; + val |= (b & 0x01) << 63; + } while (false); + + bufferStart = p; + return val; + } + + // Slow path + uint64_t result = 0; + uint64_t offset = 0; + signed char ch; + do { + ch = *(bufferStart++); + result |= (ch & BASE_128_MASK) << offset; + offset += 7; + } while (ch < 0); + return result; +} + +// NEW implementation: uses LIKELY for first byte with early return +// This matches the optimized code in IntDecoder.h +uint64_t readVuLongNew(const char*& bufferStart, const char* bufferEnd) { + if (LIKELY(bufferEnd - bufferStart >= folly::kMaxVarintLength64)) { + const char* p = bufferStart; + uint64_t val; + + // Fast path for 1-byte varints (values 0-127), which are very common. + // This avoids the do-while loop overhead for the most frequent case. + int64_t b = *p++; + val = (b & 0x7f); + if (b >= 0) { // better without likely or unlikely here. + bufferStart = p; + return val; + } + + // Multi-byte varint path + do { + b = *p++; + val |= (b & 0x7f) << 7; + if (UNLIKELY(b >= 0)) { + break; + } + b = *p++; + val |= (b & 0x7f) << 14; + if (UNLIKELY(b >= 0)) { + break; + } + b = *p++; + val |= (b & 0x7f) << 21; + if (UNLIKELY(b >= 0)) { + break; + } + b = *p++; + val |= (b & 0x7f) << 28; + if (UNLIKELY(b >= 0)) { + break; + } + b = *p++; + val |= (b & 0x7f) << 35; + if (UNLIKELY(b >= 0)) { + break; + } + b = *p++; + val |= (b & 0x7f) << 42; + if (UNLIKELY(b >= 0)) { + break; + } + b = *p++; + val |= (b & 0x7f) << 49; + if (UNLIKELY(b >= 0)) { + break; + } + b = *p++; + val |= (b & 0x7f) << 56; + if (UNLIKELY(b >= 0)) { + break; + } + b = *p++; + val |= (b & 0x01) << 63; + } while (false); + + bufferStart = p; + return val; + } + + // Slow path + uint64_t result = 0; + uint64_t offset = 0; + signed char ch; + do { + ch = *(bufferStart++); + result |= (ch & BASE_128_MASK) << offset; + offset += 7; + } while (ch < 0); + return result; +} + +// Benchmarks for small values (0-127, all 1-byte varints) +// This is where the optimization should show the biggest improvement. +BENCHMARK(small_values_old) { + auto& state = BenchmarkState::instance(); + const char* p = state.bufferSmall.data(); + const char* end = p + state.lenSmall; + for (size_t i = 0; i < state.numValuesSmall; ++i) { + auto result = readVuLongOld(p, end); + folly::doNotOptimizeAway(result); + } +} + +BENCHMARK_RELATIVE(small_values_new) { + auto& state = BenchmarkState::instance(); + const char* p = state.bufferSmall.data(); + const char* end = p + state.lenSmall; + for (size_t i = 0; i < state.numValuesSmall; ++i) { + auto result = readVuLongNew(p, end); + folly::doNotOptimizeAway(result); + } +} + +BENCHMARK_DRAW_LINE(); + +// Benchmarks for medium values (0-16383, mix of 1-2 byte varints) +BENCHMARK(medium_values_old) { + auto& state = BenchmarkState::instance(); + const char* p = state.bufferMedium.data(); + const char* end = p + state.lenMedium; + for (size_t i = 0; i < state.numValuesMedium; ++i) { + auto result = readVuLongOld(p, end); + folly::doNotOptimizeAway(result); + } +} + +BENCHMARK_RELATIVE(medium_values_new) { + auto& state = BenchmarkState::instance(); + const char* p = state.bufferMedium.data(); + const char* end = p + state.lenMedium; + for (size_t i = 0; i < state.numValuesMedium; ++i) { + auto result = readVuLongNew(p, end); + folly::doNotOptimizeAway(result); + } +} + +BENCHMARK_DRAW_LINE(); + +// Benchmarks for mixed/random uint32 values (1-5 byte varints) +BENCHMARK(mixed_values_old) { + auto& state = BenchmarkState::instance(); + const char* p = state.bufferMixed.data(); + const char* end = p + state.lenMixed; + for (size_t i = 0; i < state.numValuesMixed; ++i) { + auto result = readVuLongOld(p, end); + folly::doNotOptimizeAway(result); + } +} + +BENCHMARK_RELATIVE(mixed_values_new) { + auto& state = BenchmarkState::instance(); + const char* p = state.bufferMixed.data(); + const char* end = p + state.lenMixed; + for (size_t i = 0; i < state.numValuesMixed; ++i) { + auto result = readVuLongNew(p, end); + folly::doNotOptimizeAway(result); + } +} + +int32_t main(int32_t argc, char* argv[]) { + folly::Init init{&argc, &argv}; + + auto& state = BenchmarkState::instance(); + + // Populate small values buffer (0-127, all 1-byte varints) + state.bufferSmall.resize(kNumElements); + size_t pos = 0; + state.numValuesSmall = 500000; + for (size_t i = 0; i < state.numValuesSmall; i++) { + uint64_t val = folly::Random::rand32() % 128; // 0-127 + pos = writeVulong(val, state.bufferSmall.data(), pos); + } + state.lenSmall = pos; + + // Populate medium values buffer (0-16383, mix of 1-2 byte varints) + state.bufferMedium.resize(kNumElements); + pos = 0; + state.numValuesMedium = 400000; + for (size_t i = 0; i < state.numValuesMedium; i++) { + uint64_t val = folly::Random::rand32() % 16384; // 0-16383 + pos = writeVulong(val, state.bufferMedium.data(), pos); + } + state.lenMedium = pos; + + // Populate mixed values buffer (random uint32, 1-5 byte varints) + state.bufferMixed.resize(kNumElements); + pos = 0; + state.numValuesMixed = 200000; + for (size_t i = 0; i < state.numValuesMixed; i++) { + uint64_t val = folly::Random::rand32(); + pos = writeVulong(val, state.bufferMixed.data(), pos); + } + state.lenMixed = pos; + + folly::runBenchmarks(); + return 0; +} diff --git a/velox/dwio/common/tests/ReaderTest.cpp b/velox/dwio/common/tests/ReaderTest.cpp index 4b277ff484a..400bd80df93 100644 --- a/velox/dwio/common/tests/ReaderTest.cpp +++ b/velox/dwio/common/tests/ReaderTest.cpp @@ -18,6 +18,7 @@ #include "velox/common/base/tests/GTestUtils.h" #include "velox/vector/tests/utils/VectorTestBase.h" +#include #include namespace facebook::velox::dwio::common { @@ -148,34 +149,395 @@ TEST_F(ReaderTest, projectColumnsMutation) { makeFlatVector({0, 1, 3, 4, 5, 6, 7, 8, 9}), }); test::assertEqualVectors(expected, actual); - random::setSeed(42); - random::RandomSkipTracker randomSkip(0.5); - mutation.randomSkip = &randomSkip; - actual = RowReader::projectColumns(input, spec, &mutation); - if constexpr (std::is_same_v) { -#if __APPLE__ - expected = makeRowVector({ - makeFlatVector({1, 5, 6, 7, 8, 9}), - }); -#else - expected = makeRowVector({ - makeFlatVector({3, 4, 7, 9}), - }); -#endif -#if FOLLY_HAVE_EXTRANDOM_SFMT19937 - } else if constexpr (std::is_same_v< - folly::detail::DefaultGenerator, - __gnu_cxx::sfmt19937>) { - expected = makeRowVector({ - makeFlatVector({0, 1, 3, 5, 6, 8}), - }); -#endif - } else { - expected = makeRowVector({ - makeFlatVector({1, 3, 5, 7}), - }); + + constexpr auto kNumRounds = 1U << 6; + + size_t numNonZero = 0; + size_t numNonMax = 0; + + // Test with random skip - use property-based testing instead of hardcoded + // outputs to avoid brittleness when folly::Random implementation changes. + std::mt19937 seeds; + for (size_t round = 0; round < kNumRounds; ++round) { + const auto seed = seeds(); + + random::setSeed(folly::to_narrow(seed)); + random::RandomSkipTracker randomSkip(0.5); + mutation.randomSkip = &randomSkip; + actual = RowReader::projectColumns(input, spec, &mutation); + + // Property 1: Result size should be less than input size (some rows + // skipped). With 0.5 sample rate and 9 eligible rows (excluding deleted row + // 2), we expect roughly 4-5 rows, but allow wider range for RNG variance. + EXPECT_GE(actual->size(), 0); + EXPECT_LE(actual->size(), kSize - 1); + + numNonZero += actual->size() > 0; + numNonMax += actual->size() < kSize - 1; + + // The result is a RowVector with one child column. Assume it. + auto res = actual->as()->childAt(0)->as>(); + std::vector vec; + vec.reserve(actual->size()); + for (vector_size_t i = 0; i < actual->size(); ++i) { + vec.push_back(res->valueAt(i)); + } + + // Property 2: All values in result must be from original input. + for (auto val : vec) { + // Each value must be in valid range + EXPECT_GE(val, 0); + EXPECT_LT(val, kSize); + // Deleted row should never appear + EXPECT_NE(val, 2); + } + + // Property 3: Values should be in ascending order (projectColumns preserves + // order). + EXPECT_TRUE(std::is_sorted(vec.begin(), vec.end())); + + // Property 4: No duplicate values (each input row appears at most once). + EXPECT_TRUE(std::adjacent_find(vec.begin(), vec.end()) == vec.end()); + + // Property 5: With a fixed seed, the result should be deterministic + // (same seed = same output, even if we don't know what that output is) + random::setSeed(folly::to_narrow(seed)); + random::RandomSkipTracker randomSkip2(0.5); + mutation.randomSkip = &randomSkip2; + auto actual2 = RowReader::projectColumns(input, spec, &mutation); + test::assertEqualVectors(actual, actual2); } - test::assertEqualVectors(expected, actual); + + EXPECT_NE(0, numNonZero); + EXPECT_NE(0, numNonMax); +} + +TEST_F(ReaderTest, rowRangeEmpty) { + // Empty when startRow >= endRow + EXPECT_TRUE((RowRange{0, 0}.empty())); + EXPECT_TRUE((RowRange{5, 5}.empty())); + EXPECT_TRUE((RowRange{10, 5}.empty())); + + // Not empty when startRow < endRow + EXPECT_FALSE((RowRange{0, 1}.empty())); + EXPECT_FALSE((RowRange{0, 10}.empty())); + EXPECT_FALSE((RowRange{5, 10}.empty())); +} + +// Test that projectColumns preserves top level nulls when the input RowVector +// has null rows. +TEST_F(ReaderTest, projectColumnsTopLevelNulls) { + constexpr int kSize = 10; + // All nulls + { + SCOPED_TRACE("All nulls"); + auto child = makeFlatVector(kSize, folly::identity); + auto input = makeRowVector({child}); + + auto nulls = AlignedBuffer::allocate(kSize, pool()); + auto* rawNulls = nulls->asMutable(); + // Set all bits to 0 (all null) + memset(rawNulls, 0, bits::nbytes(kSize)); + input->setNulls(nulls); + + ASSERT_EQ(BaseVector::countNulls(input->nulls(), kSize), kSize); + + common::ScanSpec spec(""); + spec.addAllChildFields(*input->type()); + + auto actual = RowReader::projectColumns(input, spec, nullptr); + + ASSERT_NE(actual->nulls(), nullptr); + EXPECT_EQ(BaseVector::countNulls(actual->nulls(), actual->size()), kSize); + + // Verify the output has the same size + EXPECT_EQ(actual->size(), kSize); + } + + // Partial nulls. + { + SCOPED_TRACE("Partial nulls"); + auto child = makeFlatVector(kSize, folly::identity); + auto input = makeRowVector({child}); + + auto nulls = AlignedBuffer::allocate(kSize, pool()); + auto* rawNulls = nulls->asMutable(); + + // Set rows 0, 2, 4, 6, 8 as null (even indices) + memset(rawNulls, 0xFF, bits::nbytes(kSize)); + for (int i = 0; i < kSize; i += 2) { + bits::setNull(rawNulls, i); + } + input->setNulls(nulls); + + ASSERT_EQ(BaseVector::countNulls(input->nulls(), kSize), 5); // 5 null rows + + common::ScanSpec spec(""); + spec.addAllChildFields(*input->type()); + + // Test without mutation (no filtering) + auto actual = RowReader::projectColumns(input, spec, nullptr); + + // Verify nulls are preserved + ASSERT_NE(actual->nulls(), nullptr); + EXPECT_EQ(BaseVector::countNulls(actual->nulls(), actual->size()), 5); + EXPECT_EQ(actual->size(), kSize); + + // Verify specific null positions + for (int i = 0; i < kSize; ++i) { + EXPECT_EQ(actual->isNullAt(i), (i % 2 == 0)) + << "Row " << i << " null status mismatch"; + } + } + + // Partial nulls with constant encoding child + { + SCOPED_TRACE("Constant encoding child"); + // Create a constant vector (all values are the same) + auto constantChild = + BaseVector::createConstant(INTEGER(), 42, kSize, pool()); + auto input = makeRowVector({constantChild}); + + auto nulls = AlignedBuffer::allocate(kSize, pool()); + auto* rawNulls = nulls->asMutable(); + + // Set rows 0, 2, 4, 6, 8 as null (even indices) + memset(rawNulls, 0xFF, bits::nbytes(kSize)); + for (int i = 0; i < kSize; i += 2) { + bits::setNull(rawNulls, i); + } + input->setNulls(nulls); + + ASSERT_EQ(BaseVector::countNulls(input->nulls(), kSize), 5); // 5 null rows + ASSERT_TRUE(constantChild->isConstantEncoding()); + + common::ScanSpec spec(""); + spec.addAllChildFields(*input->type()); + + auto actual = RowReader::projectColumns(input, spec, nullptr); + + // Verify nulls are preserved + ASSERT_NE(actual->nulls(), nullptr); + EXPECT_EQ(BaseVector::countNulls(actual->nulls(), actual->size()), 5); + EXPECT_EQ(actual->size(), kSize); + + // Verify specific null positions + for (int i = 0; i < kSize; ++i) { + EXPECT_EQ(actual->isNullAt(i), (i % 2 == 0)) + << "Row " << i << " null status mismatch"; + } + } + + // Partial nulls with dictionary encoding child + { + SCOPED_TRACE("Dictionary encoding child"); + // Create a dictionary vector wrapping a flat vector + auto baseValues = makeFlatVector({10, 20, 30, 40, 50}); + // Create indices that map to the base values + auto indices = makeIndices(kSize, [](auto i) { return i % 5; }); + auto dictionaryChild = + BaseVector::wrapInDictionary(nullptr, indices, kSize, baseValues); + auto input = makeRowVector({dictionaryChild}); + + auto nulls = AlignedBuffer::allocate(kSize, pool()); + auto* rawNulls = nulls->asMutable(); + + // Set rows 0, 2, 4, 6, 8 as null (even indices) + memset(rawNulls, 0xFF, bits::nbytes(kSize)); + for (int i = 0; i < kSize; i += 2) { + bits::setNull(rawNulls, i); + } + input->setNulls(nulls); + + ASSERT_EQ(BaseVector::countNulls(input->nulls(), kSize), 5); // 5 null rows + ASSERT_EQ(dictionaryChild->encoding(), VectorEncoding::Simple::DICTIONARY); + + common::ScanSpec spec(""); + spec.addAllChildFields(*input->type()); + + auto actual = RowReader::projectColumns(input, spec, nullptr); + + // Verify nulls are preserved + ASSERT_NE(actual->nulls(), nullptr); + EXPECT_EQ(BaseVector::countNulls(actual->nulls(), actual->size()), 5); + EXPECT_EQ(actual->size(), kSize); + + // Verify specific null positions + for (int i = 0; i < kSize; ++i) { + EXPECT_EQ(actual->isNullAt(i), (i % 2 == 0)) + << "Row " << i << " null status mismatch"; + } + } + + // Two columns: constant encoding and dictionary encoding + { + SCOPED_TRACE("Two columns: constant and dictionary encoding"); + // Create a constant vector + auto constantChild = + BaseVector::createConstant(INTEGER(), 42, kSize, pool()); + + // Create a dictionary vector + auto baseValues = makeFlatVector({100, 200, 300, 400, 500}); + auto indices = makeIndices(kSize, [](auto i) { return i % 5; }); + auto dictionaryChild = + BaseVector::wrapInDictionary(nullptr, indices, kSize, baseValues); + + auto input = makeRowVector({constantChild, dictionaryChild}); + + auto nulls = AlignedBuffer::allocate(kSize, pool()); + auto* rawNulls = nulls->asMutable(); + + // Set rows 0, 2, 4, 6, 8 as null (even indices) + memset(rawNulls, 0xFF, bits::nbytes(kSize)); + for (int i = 0; i < kSize; i += 2) { + bits::setNull(rawNulls, i); + } + input->setNulls(nulls); + + ASSERT_EQ(BaseVector::countNulls(input->nulls(), kSize), 5); + ASSERT_TRUE(constantChild->isConstantEncoding()); + ASSERT_EQ(dictionaryChild->encoding(), VectorEncoding::Simple::DICTIONARY); + + common::ScanSpec spec(""); + spec.addAllChildFields(*input->type()); + + auto actual = RowReader::projectColumns(input, spec, nullptr); + + // Verify nulls are preserved + ASSERT_NE(actual->nulls(), nullptr); + EXPECT_EQ(BaseVector::countNulls(actual->nulls(), actual->size()), 5); + EXPECT_EQ(actual->size(), kSize); + + // Verify both children exist + auto rowResult = actual->as(); + ASSERT_EQ(rowResult->childrenSize(), 2); + + // Verify specific null positions + for (int i = 0; i < kSize; ++i) { + EXPECT_EQ(actual->isNullAt(i), (i % 2 == 0)) + << "Row " << i << " null status mismatch"; + } + } +} + +// Test that projectColumns correctly filters row-level nulls with mutation +TEST_F(ReaderTest, projectColumnsFiltersRowNullsWithMutation) { + constexpr int kSize = 10; + + // Create a RowVector with some null rows + auto child = makeFlatVector(kSize, folly::identity); + auto input = makeRowVector({child}); + + // Set rows 0, 2, 4, 6, 8 as null (even indices) + auto nulls = AlignedBuffer::allocate(kSize, pool()); + auto* rawNulls = nulls->asMutable(); + memset(rawNulls, 0xFF, bits::nbytes(kSize)); + for (int i = 0; i < kSize; i += 2) { + bits::setNull(rawNulls, i); + } + input->setNulls(nulls); + + common::ScanSpec spec(""); + spec.addAllChildFields(*input->type()); + + // Delete rows 1 and 3 (odd indices, which are not null) + std::vector deleted(bits::nwords(kSize)); + bits::setBit(deleted.data(), 1); + bits::setBit(deleted.data(), 3); + Mutation mutation; + mutation.deletedRows = deleted.data(); + + auto actual = RowReader::projectColumns(input, spec, &mutation); + + // 8 rows should remain (10 - 2 deleted) + EXPECT_EQ(actual->size(), 8); + + // Verify nulls buffer exists + ASSERT_NE(actual->nulls(), nullptr); + + // After filtering, the remaining rows are: 0, 2, 4, 5, 6, 7, 8, 9 + // Of these, 0, 2, 4, 6, 8 were null in the original (indices 0, 1, 2, 4, 6) + // So in the output: positions 0, 1, 2, 4, 6 should be null + EXPECT_TRUE(actual->isNullAt(0)); // was row 0 + EXPECT_TRUE(actual->isNullAt(1)); // was row 2 + EXPECT_TRUE(actual->isNullAt(2)); // was row 4 + EXPECT_FALSE(actual->isNullAt(3)); // was row 5 + EXPECT_TRUE(actual->isNullAt(4)); // was row 6 + EXPECT_FALSE(actual->isNullAt(5)); // was row 7 + EXPECT_TRUE(actual->isNullAt(6)); // was row 8 + EXPECT_FALSE(actual->isNullAt(7)); // was row 9 +} + +TEST_F(ReaderTest, projectColumnsWithSelectionIdentity) { + // No filter: every input row passes, selectedRows is null to signal + // identity mapping with input. + constexpr int kSize = 5; + auto input = + makeRowVector({"c0"}, {makeFlatVector(kSize, folly::identity)}); + + common::ScanSpec spec(""); + spec.addAllChildFields(*input->type()); + + auto result = RowReader::projectColumnsWithSelection(input, spec, nullptr); + EXPECT_EQ(result.output->size(), kSize); + EXPECT_EQ(result.selectedRows, nullptr); +} + +TEST_F(ReaderTest, projectColumnsWithSelectionFiltered) { + // Filter keeps a subset; selectedRows must list the surviving input + // indices in order. + constexpr int kSize = 6; + auto input = + makeRowVector({"c0"}, {makeFlatVector(kSize, folly::identity)}); + + common::ScanSpec spec(""); + spec.addAllChildFields(*input->type()); + spec.childByName("c0")->setFilter( + common::createBigintValues({1, 3, 5}, false)); + + auto result = RowReader::projectColumnsWithSelection(input, spec, nullptr); + ASSERT_NE(result.selectedRows, nullptr); + ASSERT_EQ(result.output->size(), 3); + ASSERT_EQ( + result.selectedRows->size() / sizeof(vector_size_t), + result.output->size()); + const auto* indices = result.selectedRows->as(); + EXPECT_THAT( + std::vector(indices, indices + 3), + testing::ElementsAre(1, 3, 5)); +} + +TEST_F(ReaderTest, projectColumnsWithSelectionAllFiltered) { + // Filter rejects every row; output is empty and selectedRows is a + // zero-length buffer (non-null) so callers can distinguish "empty after + // filtering" from "identity mapping". + constexpr int kSize = 4; + auto input = + makeRowVector({"c0"}, {makeFlatVector(kSize, folly::identity)}); + + common::ScanSpec spec(""); + spec.addAllChildFields(*input->type()); + spec.childByName("c0")->setFilter(common::createBigintValues({99}, false)); + + auto result = RowReader::projectColumnsWithSelection(input, spec, nullptr); + EXPECT_EQ(result.output->size(), 0); + ASSERT_NE(result.selectedRows, nullptr); + EXPECT_EQ(result.selectedRows->size(), 0); +} + +TEST_F(ReaderTest, projectColumnsWithSelectionEmptyInput) { + // Empty input: no rows were dropped, so the identity contract holds and + // selectedRows must be null. Distinguishes "empty identity" from "all + // rows filtered out", which returns a non-null zero-length buffer. + auto input = + makeRowVector({"c0"}, {makeFlatVector(0, folly::identity)}); + + common::ScanSpec spec(""); + spec.addAllChildFields(*input->type()); + + auto result = RowReader::projectColumnsWithSelection(input, spec, nullptr); + EXPECT_EQ(result.output->size(), 0); + EXPECT_EQ(result.selectedRows, nullptr); } } // namespace diff --git a/velox/dwio/common/tests/ScanSpecTest.cpp b/velox/dwio/common/tests/ScanSpecTest.cpp index d61675878d2..96eb4f10dcb 100644 --- a/velox/dwio/common/tests/ScanSpecTest.cpp +++ b/velox/dwio/common/tests/ScanSpecTest.cpp @@ -54,6 +54,57 @@ TEST_F(ScanSpecTest, applyFilter) { VeloxRuntimeError); } +TEST_F(ScanSpecTest, setFilterResetsHasFilter) { + auto rowVector = makeRowVector({ + makeFlatVector(64, folly::identity), + makeFlatVector(64, folly::identity), + }); + + ScanSpec scanSpec(""); + scanSpec.addAllChildFields(*rowVector->type()); + + // Initially no filter, hasFilter should be false. + ASSERT_FALSE(scanSpec.hasFilter()); + ASSERT_FALSE(scanSpec.childByName("c0")->hasFilter()); + ASSERT_FALSE(scanSpec.childByName("c1")->hasFilter()); + + // Set a filter on c0, hasFilter should be true for c0 and root. + scanSpec.childByName("c0")->setFilter(createBigintValues({1, 2, 3}, false)); + ASSERT_FALSE(scanSpec.childByName("c0")->hasFilter()); + ASSERT_FALSE(scanSpec.hasFilter()); + // Root's hasFilter_ was cached as false, but setFilter should have reset it. + // After setting filter on child, root should report hasFilter as true. + scanSpec.resetCachedValues(false); + ASSERT_TRUE(scanSpec.hasFilter()); + ASSERT_TRUE(scanSpec.childByName("c0")->hasFilter()); + + // Set filter to nullptr, hasFilter should become false. + scanSpec.childByName("c0")->setFilter(nullptr); + ASSERT_TRUE(scanSpec.childByName("c0")->hasFilter()); + ASSERT_TRUE(scanSpec.hasFilter()); + scanSpec.resetCachedValues(false); + ASSERT_FALSE(scanSpec.childByName("c0")->hasFilter()); + ASSERT_FALSE(scanSpec.hasFilter()); + + // Set a new filter on c1, verify hasFilter updates correctly. + scanSpec.childByName("c1")->setFilter( + std::make_shared(10, 50, false)); + ASSERT_FALSE(scanSpec.childByName("c1")->hasFilter()); + ASSERT_FALSE(scanSpec.childByName("c0")->hasFilter()); + scanSpec.resetCachedValues(false); + ASSERT_FALSE(scanSpec.childByName("c0")->hasFilter()); + ASSERT_TRUE(scanSpec.childByName("c1")->hasFilter()); + ASSERT_TRUE(scanSpec.hasFilter()); + + // Replace filter on c1 with a different filter. + scanSpec.childByName("c1")->setFilter( + std::make_shared(20, 30, false)); + // hasFilter should still be true after replacing with another filter. + ASSERT_TRUE(scanSpec.childByName("c1")->hasFilter()); + ASSERT_FALSE(scanSpec.childByName("c0")->hasFilter()); + ASSERT_TRUE(scanSpec.hasFilter()); +} + class TypedScanSpecTest : public testing::TestWithParam, public test::VectorTestBase { protected: diff --git a/velox/dwio/common/tests/SelectiveColumnReaderTest.cpp b/velox/dwio/common/tests/SelectiveColumnReaderTest.cpp new file mode 100644 index 00000000000..d4a237d799c --- /dev/null +++ b/velox/dwio/common/tests/SelectiveColumnReaderTest.cpp @@ -0,0 +1,592 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/dwio/common/SelectiveColumnReader.h" +#include "velox/dwio/common/SelectiveColumnReaderInternal.h" + +#include +#include +#include +#include + +#include "velox/common/memory/MemoryPool.h" +#include "velox/dwio/common/FormatData.h" +#include "velox/dwio/common/TypeWithId.h" + +namespace facebook::velox::dwio::common { +namespace { + +TEST(IsDenseTest, empty) { + const RowSet rows(nullptr, nullptr); + EXPECT_TRUE(isDense(rows)); +} + +TEST(IsDenseTest, singleElement) { + const std::vector data{0}; + const RowSet rows(data.data(), data.size()); + EXPECT_TRUE(isDense(rows)); +} + +TEST(IsDenseTest, contiguousFromZero) { + const std::vector data{0, 1, 2, 3, 4}; + const RowSet rows(data.data(), data.size()); + EXPECT_TRUE(isDense(rows)); +} + +TEST(IsDenseTest, sparseRows) { + const std::vector data{0, 2, 4}; + const RowSet rows(data.data(), data.size()); + EXPECT_FALSE(isDense(rows)); +} + +TEST(IsDenseTest, startingFromNonZero) { + const std::vector data{1, 2, 3}; + const RowSet rows(data.data(), data.size()); + EXPECT_FALSE(isDense(rows)); +} + +TEST(IsDenseTest, singleNonZeroElement) { + const std::vector data{5}; + const RowSet rows(data.data(), data.size()); + EXPECT_FALSE(isDense(rows)); +} + +// Minimal FormatData stub for testing SelectiveColumnReader in isolation. +class StubFormatData : public FormatData { + public: + void readNulls( + vector_size_t /*numValues*/, + const uint64_t* /*incomingNulls*/, + BufferPtr& nulls, + bool /*nullsOnly*/) override { + nulls = nullptr; + } + uint64_t skipNulls(uint64_t numValues, bool /*nullsOnly*/) override { + return numValues; + } + uint64_t skip(uint64_t numValues) override { + return numValues; + } + bool hasNulls() const override { + return false; + } + dwio::common::PositionProvider seekToRowGroup(int64_t /*index*/) override { + static std::vector empty; + return dwio::common::PositionProvider(empty); + } + void filterRowGroups( + const velox::common::ScanSpec& /*scanSpec*/, + uint64_t /*rowsPerRowGroup*/, + const StatsContext& /*writerContext*/, + FilterRowGroupsResult& /*result*/) override {} +}; + +// Minimal FormatParams stub that produces a StubFormatData. +class StubFormatParams : public FormatParams { + public: + StubFormatParams(memory::MemoryPool& pool, ColumnReaderStatistics& stats) + : FormatParams(pool, stats) {} + + std::unique_ptr toFormatData( + const std::shared_ptr& /*type*/, + const velox::common::ScanSpec& /*scanSpec*/) override { + return std::make_unique(); + } +}; + +// Concrete subclass that exposes getFlatValues and internal state for testing. +class TestColumnReader : public SelectiveColumnReader { + public: + TestColumnReader( + const TypePtr& requestedType, + std::shared_ptr fileType, + FormatParams& params, + velox::common::ScanSpec& scanSpec) + : SelectiveColumnReader( + requestedType, + std::move(fileType), + params, + scanSpec) {} + + void read( + int64_t /*offset*/, + const RowSet& /*rows*/, + const uint64_t* /*incomingNulls*/) override {} + + void getValues(const RowSet& /*rows*/, VectorPtr* /*result*/) override {} + + /// Populate the internal values buffer with source data of type T. + /// Sets valueSize_, numValues_, mayGetValues_, and inputRows_. + /// @note rowNumbers must remain live until getFlatValues is called, as + /// inputRows_ stores a non-owning reference. + template + void setupValues( + const std::vector& data, + const std::vector& rowNumbers) { + VELOX_CHECK_EQ(data.size(), rowNumbers.size()); + const auto n = static_cast(data.size()); + ensureValuesCapacity(n); + std::memcpy(rawValues_, data.data(), n * sizeof(T)); + numValues_ = n; + valueSize_ = sizeof(T); + mayGetValues_ = true; + allNull_ = false; + inputRows_ = RowSet(rowNumbers.data(), rowNumbers.size()); + } + + /// Populate internal values and mark elements null according to the mask. + template + void setupValuesWithNulls( + const std::vector& data, + const std::vector& nulls, + const std::vector& rowNumbers) { + VELOX_CHECK_EQ(data.size(), nulls.size()); + setupValues(data, rowNumbers); + const auto n = static_cast(data.size()); + anyNulls_ = true; + resultNulls_ = AlignedBuffer::allocate( + n + simd::kPadding * 8, pool_, bits::kNotNull); + rawResultNulls_ = resultNulls_->asMutable(); + for (int32_t i = 0; i < n; ++i) { + bits::setBit(rawResultNulls_, i, !nulls[i]); + } + } + + using SelectiveColumnReader::getFlatValues; +}; + +class GetFlatValuesTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + } + + void SetUp() override { + pool_ = memory::memoryManager()->addLeafPool("GetFlatValuesTest"); + stats_ = std::make_unique(); + params_ = std::make_unique(*pool_, *stats_); + scanSpec_ = std::make_unique("test"); + scanSpec_->setProjectOut(true); + } + + std::unique_ptr makeReader( + const TypePtr& requestedType) const { + auto fileType = TypeWithId::create(requestedType); + return std::make_unique( + requestedType, std::move(fileType), *params_, *scanSpec_); + } + + /// Run getFlatValues and verify each element equals static_cast(src). + /// When sparseRows is provided, only those rows are selected from data. + template + void testConversion( + const TypePtr& type, + const std::vector& data, + const std::vector& sparseRows = {}) { + auto reader = makeReader(type); + std::vector allRows(data.size()); + std::iota(allRows.begin(), allRows.end(), 0); + reader->setupValues(data, allRows); + + const auto& selectedRows = sparseRows.empty() ? allRows : sparseRows; + const RowSet rows(selectedRows.data(), selectedRows.size()); + VectorPtr result; + reader->getFlatValues(rows, &result, type, true); + + auto* flat = result->as>(); + ASSERT_NE(flat, nullptr); + ASSERT_EQ(flat->size(), selectedRows.size()); + for (size_t i = 0; i < selectedRows.size(); ++i) { + EXPECT_EQ(flat->valueAt(i), static_cast(data[selectedRows[i]])) + << "Mismatch at index " << i; + } + } + + std::shared_ptr pool_; + std::unique_ptr stats_; + std::unique_ptr params_; + std::unique_ptr scanSpec_; +}; + +// Cross-domain upcast conversions (different sizes, valid via +// upcastScalarValues). + +TEST_F(GetFlatValuesTest, int16ToFloat) { + testConversion(REAL(), {1, -32'000, 0, 127, 32'767}); + testConversion(REAL(), {1, -32'000, 0, 127, 32'767}, {0, 3}); +} + +TEST_F(GetFlatValuesTest, int32ToDouble) { + testConversion( + DOUBLE(), {1, -999'999, 2'147'483'647, 0, -2'147'483'647}); + testConversion( + DOUBLE(), {1, -999'999, 2'147'483'647, 0, -2'147'483'647}, {1, 2, 4}); +} + +TEST_F(GetFlatValuesTest, int16ToDouble) { + testConversion(DOUBLE(), {0, 1, -1, 32'767, -32'768}); + testConversion( + DOUBLE(), {0, 1, -1, 32'767, -32'768}, {0, 4}); +} + +// Same-domain conversions as regression guards. + +TEST_F(GetFlatValuesTest, int32ToInt32) { + testConversion(INTEGER(), {0, 1, -1, 42, 2'147'483'647}); + testConversion( + INTEGER(), {0, 1, -1, 42, 2'147'483'647}, {2, 3}); +} + +TEST_F(GetFlatValuesTest, int16ToInt32) { + testConversion(INTEGER(), {0, 1, -1, 32'767, -32'768}); + testConversion( + INTEGER(), {0, 1, -1, 32'767, -32'768}, {1, 4}); +} + +TEST_F(GetFlatValuesTest, int32ToInt16) { + testConversion(SMALLINT(), {0, 1, -1, 127, -128}); + testConversion(SMALLINT(), {0, 1, -1, 127, -128}, {0, 2}); +} + +TEST_F(GetFlatValuesTest, int64ToInt16) { + testConversion(SMALLINT(), {0, 1, -1, 127, -128}); + testConversion(SMALLINT(), {0, 1, -1, 127, -128}, {3, 4}); +} + +TEST_F(GetFlatValuesTest, int32ToInt8) { + testConversion(TINYINT(), {0, 1, -1, 42, -100}); + testConversion(TINYINT(), {0, 1, -1, 42, -100}, {1, 3, 4}); +} + +TEST_F(GetFlatValuesTest, int16ToInt8) { + testConversion(TINYINT(), {0, 1, -1, 42, -100}); + testConversion(TINYINT(), {0, 1, -1, 42, -100}, {0, 4}); +} + +TEST_F(GetFlatValuesTest, int32ToInt64) { + testConversion( + BIGINT(), {0, 1, -1, 2'147'483'647, -2'147'483'648}); + testConversion( + BIGINT(), {0, 1, -1, 2'147'483'647, -2'147'483'648}, {0, 2, 3}); +} + +TEST_F(GetFlatValuesTest, int16ToInt64) { + testConversion(BIGINT(), {0, 1, -1, 32'767, -32'768}); + testConversion( + BIGINT(), {0, 1, -1, 32'767, -32'768}, {2, 4}); +} + +// HUGEINT widening conversions for Parquet type widening. + +TEST_F(GetFlatValuesTest, int128ToInt128) { + testConversion(DECIMAL(38, 0), {0, 1, -1, 1'000'000}); + testConversion( + DECIMAL(38, 0), {0, 1, -1, 1'000'000}, {1, 3}); +} + +TEST_F(GetFlatValuesTest, int64ToInt128) { + testConversion( + DECIMAL(38, 0), + {0, 1, -1, 9'223'372'036'854'775'807LL, -9'223'372'036'854'775'807LL}); + testConversion( + DECIMAL(38, 0), + {0, 1, -1, 9'223'372'036'854'775'807LL, -9'223'372'036'854'775'807LL}, + {0, 3, 4}); +} + +TEST_F(GetFlatValuesTest, int32ToInt128) { + testConversion( + DECIMAL(38, 0), {0, 1, -1, 2'147'483'647, -2'147'483'648}); + testConversion( + DECIMAL(38, 0), {0, 1, -1, 2'147'483'647, -2'147'483'648}, {1, 4}); +} + +TEST_F(GetFlatValuesTest, int16ToInt128) { + testConversion( + DECIMAL(38, 0), {0, 1, -1, 32'767, -32'768}); + testConversion( + DECIMAL(38, 0), {0, 1, -1, 32'767, -32'768}, {0, 2, 3}); +} + +// Regression guards for existing same-domain paths affected by template +// changes. + +TEST_F(GetFlatValuesTest, int64ToInt64) { + testConversion( + BIGINT(), + {0, 1, -1, 9'223'372'036'854'775'807LL, -9'223'372'036'854'775'807LL}); + testConversion( + BIGINT(), + {0, 1, -1, 9'223'372'036'854'775'807LL, -9'223'372'036'854'775'807LL}, + {0, 4}); +} + +TEST_F(GetFlatValuesTest, int64ToInt32) { + testConversion(INTEGER(), {0, 1, -1, 42, -100}); + testConversion(INTEGER(), {0, 1, -1, 42, -100}, {2, 3}); +} + +TEST_F(GetFlatValuesTest, int16ToInt16) { + testConversion(SMALLINT(), {0, 1, -1, 32'767, -32'768}); + testConversion( + SMALLINT(), {0, 1, -1, 32'767, -32'768}, {1, 2, 4}); +} + +// ByteRle conversions (int8_t source). + +TEST_F(GetFlatValuesTest, int8ToInt8) { + testConversion(TINYINT(), {0, 1, -1, 127, -128}); + testConversion(TINYINT(), {0, 1, -1, 127, -128}, {0, 2, 4}); +} + +TEST_F(GetFlatValuesTest, int8ToInt16) { + testConversion(SMALLINT(), {0, 1, -1, 127, -128}); + testConversion(SMALLINT(), {0, 1, -1, 127, -128}, {1, 3}); +} + +TEST_F(GetFlatValuesTest, int8ToInt32) { + testConversion(INTEGER(), {0, 1, -1, 127, -128}); + testConversion(INTEGER(), {0, 1, -1, 127, -128}, {2, 4}); +} + +TEST_F(GetFlatValuesTest, int8ToInt64) { + testConversion(BIGINT(), {0, 1, -1, 127, -128}); + testConversion(BIGINT(), {0, 1, -1, 127, -128}, {0, 3, 4}); +} + +// Unsigned integer conversions. + +TEST_F(GetFlatValuesTest, uint8ToUint8) { + testConversion(TINYINT(), {0, 1, 42, 255}); + testConversion(TINYINT(), {0, 1, 42, 255}, {1, 3}); +} + +TEST_F(GetFlatValuesTest, uint16ToUint16) { + testConversion(SMALLINT(), {0, 1, 42, 65'535}); + testConversion(SMALLINT(), {0, 1, 42, 65'535}, {0, 2}); +} + +TEST_F(GetFlatValuesTest, uint32ToUint8) { + testConversion(TINYINT(), {0, 1, 42, 200}); + testConversion(TINYINT(), {0, 1, 42, 200}, {2, 3}); +} + +TEST_F(GetFlatValuesTest, uint32ToUint16) { + testConversion(SMALLINT(), {0, 1, 42, 60'000}); + testConversion(SMALLINT(), {0, 1, 42, 60'000}, {0, 3}); +} + +TEST_F(GetFlatValuesTest, uint32ToUint32) { + testConversion(INTEGER(), {0, 1, 42, 4'294'967'295U}); + testConversion( + INTEGER(), {0, 1, 42, 4'294'967'295U}, {1, 2}); +} + +TEST_F(GetFlatValuesTest, uint32ToUint64) { + testConversion(BIGINT(), {0, 1, 42, 4'294'967'295U}); + testConversion( + BIGINT(), {0, 1, 42, 4'294'967'295U}, {0, 3}); +} + +TEST_F(GetFlatValuesTest, uint64ToUint64) { + testConversion(BIGINT(), {0, 1, 42, 1'000'000'000ULL}); + testConversion( + BIGINT(), {0, 1, 42, 1'000'000'000ULL}, {1, 3}); +} + +TEST_F(GetFlatValuesTest, uint64ToUint128) { + testConversion( + DECIMAL(38, 0), {0, 1, 42, 1'000'000'000ULL}); + testConversion( + DECIMAL(38, 0), {0, 1, 42, 1'000'000'000ULL}, {0, 2, 3}); +} + +TEST_F(GetFlatValuesTest, uint128ToUint128) { + testConversion(DECIMAL(38, 0), {0, 1, 42, 1'000'000}); + testConversion( + DECIMAL(38, 0), {0, 1, 42, 1'000'000}, {1, 2}); +} + +// Floating-point conversions. + +TEST_F(GetFlatValuesTest, floatToFloat) { + testConversion(REAL(), {0.0f, 1.5f, -3.14f, 1e10f}); + testConversion(REAL(), {0.0f, 1.5f, -3.14f, 1e10f}, {0, 2}); +} + +TEST_F(GetFlatValuesTest, floatToDouble) { + testConversion(DOUBLE(), {0.0f, 1.5f, -3.14f, 1e10f}); + testConversion(DOUBLE(), {0.0f, 1.5f, -3.14f, 1e10f}, {1, 3}); +} + +TEST_F(GetFlatValuesTest, doubleToDouble) { + testConversion(DOUBLE(), {0.0, 1.5, -3.14, 1e100}); + testConversion(DOUBLE(), {0.0, 1.5, -3.14, 1e100}, {0, 3}); +} + +TEST_F(GetFlatValuesTest, floatToDoubleSpecialValues) { + auto reader = makeReader(DOUBLE()); + std::vector data = { + std::numeric_limits::quiet_NaN(), + std::numeric_limits::infinity(), + -std::numeric_limits::infinity(), + -0.0f, + std::numeric_limits::denorm_min()}; + std::vector rowNums(data.size()); + std::iota(rowNums.begin(), rowNums.end(), 0); + reader->setupValues(data, rowNums); + + const RowSet rows(rowNums.data(), rowNums.size()); + VectorPtr result; + reader->getFlatValues(rows, &result, DOUBLE(), true); + + auto* flat = result->as>(); + ASSERT_NE(flat, nullptr); + ASSERT_EQ(flat->size(), data.size()); + EXPECT_TRUE(std::isnan(flat->valueAt(0))); + EXPECT_EQ(flat->valueAt(1), std::numeric_limits::infinity()); + EXPECT_EQ(flat->valueAt(2), -std::numeric_limits::infinity()); + EXPECT_DOUBLE_EQ(flat->valueAt(3), -0.0); + EXPECT_TRUE(std::signbit(flat->valueAt(3))); + EXPECT_DOUBLE_EQ( + flat->valueAt(4), + static_cast(std::numeric_limits::denorm_min())); +} + +// Boundary condition: single-element input. + +TEST_F(GetFlatValuesTest, singleElement) { + testConversion(BIGINT(), {42}); +} + +// Large-scale conversion: 1024 values with dense and sparse rows. + +TEST_F(GetFlatValuesTest, largeScaleInt32ToDouble) { + constexpr int kSize = 1024; + std::vector data(kSize); + for (int i = 0; i < kSize; ++i) { + data[i] = i * 7 - 3000; + } + // Dense. + testConversion(DOUBLE(), data); + // Sparse: every 3rd row. + std::vector sparse; + for (int i = 0; i < kSize; i += 3) { + sparse.push_back(i); + } + testConversion(DOUBLE(), data, sparse); +} + +TEST_F(GetFlatValuesTest, largeScaleInt64ToInt128) { + constexpr int kSize = 1024; + std::vector data(kSize); + for (int i = 0; i < kSize; ++i) { + data[i] = static_cast(i) * 123'456'789LL - 50'000'000'000LL; + } + testConversion(DECIMAL(38, 0), data); + // Sparse: odd indices. + std::vector sparse; + for (int i = 1; i < kSize; i += 2) { + sparse.push_back(i); + } + testConversion(DECIMAL(38, 0), data, sparse); +} + +TEST_F(GetFlatValuesTest, largeScaleInt32ToInt16) { + constexpr int kSize = 1024; + std::vector data(kSize); + for (int i = 0; i < kSize; ++i) { + data[i] = (i % 256) - 128; + } + testConversion(SMALLINT(), data); + // Sparse: every 5th row. + std::vector sparse; + for (int i = 0; i < kSize; i += 5) { + sparse.push_back(i); + } + testConversion(SMALLINT(), data, sparse); +} + +// Cross-domain upcast conversion with nulls. + +TEST_F(GetFlatValuesTest, int32ToDoubleWithNulls) { + auto reader = makeReader(DOUBLE()); + const std::vector data = {10, 20, 0, 40, 0}; + const std::vector nulls = {false, false, true, false, true}; + std::vector rowNums(data.size()); + std::iota(rowNums.begin(), rowNums.end(), 0); + reader->setupValuesWithNulls(data, nulls, rowNums); + + const RowSet rows(rowNums.data(), rowNums.size()); + VectorPtr result; + reader->getFlatValues(rows, &result, DOUBLE(), true); + + auto* flat = result->as>(); + ASSERT_NE(flat, nullptr); + ASSERT_EQ(flat->size(), data.size()); + EXPECT_DOUBLE_EQ(flat->valueAt(0), 10.0); + EXPECT_DOUBLE_EQ(flat->valueAt(1), 20.0); + EXPECT_TRUE(flat->isNullAt(2)); + EXPECT_DOUBLE_EQ(flat->valueAt(3), 40.0); + EXPECT_TRUE(flat->isNullAt(4)); +} + +// Sparse rows with nulls. + +TEST_F(GetFlatValuesTest, int32ToDoubleSparseRowsWithNulls) { + auto reader = makeReader(DOUBLE()); + const std::vector data = {10, 20, 0, 40, 0, 60}; + const std::vector nulls = {false, false, true, false, true, false}; + const std::vector allRows = {0, 1, 2, 3, 4, 5}; + reader->setupValuesWithNulls(data, nulls, allRows); + + // Select rows {0, 2, 3, 5} — includes null at index 2. + const std::vector sparseRows = {0, 2, 3, 5}; + const RowSet rows(sparseRows.data(), sparseRows.size()); + VectorPtr result; + reader->getFlatValues(rows, &result, DOUBLE(), true); + + auto* flat = result->as>(); + ASSERT_NE(flat, nullptr); + ASSERT_EQ(flat->size(), 4); + EXPECT_DOUBLE_EQ(flat->valueAt(0), 10.0); + EXPECT_TRUE(flat->isNullAt(1)); + EXPECT_DOUBLE_EQ(flat->valueAt(2), 40.0); + EXPECT_DOUBLE_EQ(flat->valueAt(3), 60.0); +} + +TEST_F(GetFlatValuesTest, int64ToInt128SparseRowsWithNulls) { + auto reader = makeReader(DECIMAL(38, 0)); + const std::vector data = {100, 0, 300, 0, 500}; + const std::vector nulls = {false, true, false, true, false}; + const std::vector allRows = {0, 1, 2, 3, 4}; + reader->setupValuesWithNulls(data, nulls, allRows); + + const std::vector sparseRows = {0, 1, 4}; + const RowSet rows(sparseRows.data(), sparseRows.size()); + VectorPtr result; + reader->getFlatValues(rows, &result, DECIMAL(38, 0), true); + + auto* flat = result->as>(); + ASSERT_NE(flat, nullptr); + ASSERT_EQ(flat->size(), 3); + EXPECT_EQ(flat->valueAt(0), 100); + EXPECT_TRUE(flat->isNullAt(1)); + EXPECT_EQ(flat->valueAt(2), 500); +} + +} // namespace +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/tests/SortingWriterTest.cpp b/velox/dwio/common/tests/SortingWriterTest.cpp index 4beadc61e1a..d13837641c8 100644 --- a/velox/dwio/common/tests/SortingWriterTest.cpp +++ b/velox/dwio/common/tests/SortingWriterTest.cpp @@ -43,8 +43,9 @@ class MockWriter : public Writer { void flush() override {} - void close() override { + std::unique_ptr close() override { setState(State::kClosed); + return nullptr; } void abort() override { diff --git a/velox/dwio/common/tests/StreamUtilTest.cpp b/velox/dwio/common/tests/StreamUtilTest.cpp new file mode 100644 index 00000000000..dc1e6b5abfa --- /dev/null +++ b/velox/dwio/common/tests/StreamUtilTest.cpp @@ -0,0 +1,60 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/dwio/common/StreamUtil.h" + +#include + +namespace facebook::velox::dwio::common { +namespace { + +template +void testRowLoop( + const int32_t* rows, + int32_t begin, + int32_t end, + const std::vector& expectedDense, + const std::vector& expectedSparse, + const std::vector>& expectedSparseN) { + std::vector actualDense, actualSparse; + std::vector> actualSparseN; + rowLoop( + rows, + begin, + end, + [&](auto i) { actualDense.push_back(i); }, + [&](auto i) { actualSparse.push_back(i); }, + [&](auto i, auto size) { actualSparseN.emplace_back(i, size); }); + ASSERT_EQ(actualDense, expectedDense); + ASSERT_EQ(actualSparse, expectedSparse); + ASSERT_EQ(actualSparseN, expectedSparseN); +} + +TEST(StreamUtilTest, rowLoop) { + const int32_t rows[] = { + 0, 1, 2, 3, 4, // Dense + 5, 7, 9, 11, 13, // Sparse + 14, 15, 16, 17, 18, // Dense + 19, 21, 23, 25, 27, // Sparse + }; + testRowLoop<2>(rows, 0, 6, {0, 2, 4}, {}, {}); + testRowLoop<4>(rows, 0, 6, {0}, {}, {{4, 2}}); + testRowLoop<4>(rows, 0, 10, {0}, {4}, {{8, 2}}); + testRowLoop<4>(rows, 10, 20, {10}, {14}, {{18, 2}}); +} + +} // namespace +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/tests/TestBufferedInput.cpp b/velox/dwio/common/tests/TestBufferedInput.cpp deleted file mode 100644 index 6fa5be8da00..00000000000 --- a/velox/dwio/common/tests/TestBufferedInput.cpp +++ /dev/null @@ -1,379 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "velox/dwio/common/BufferedInput.h" - -using namespace facebook::velox::dwio::common; -using facebook::velox::common::Region; -using namespace facebook::velox::memory; -using namespace ::testing; - -namespace { - -class ReadFileMock : public ::facebook::velox::ReadFile { - public: - virtual ~ReadFileMock() override = default; - - MOCK_METHOD( - std::string_view, - pread, - (uint64_t offset, - uint64_t length, - void* buf, - facebook::velox::filesystems::File::IoStats* stats), - (const, override)); - - MOCK_METHOD(bool, shouldCoalesce, (), (const, override)); - MOCK_METHOD(uint64_t, size, (), (const, override)); - MOCK_METHOD(uint64_t, memoryUsage, (), (const, override)); - MOCK_METHOD(std::string, getName, (), (const, override)); - MOCK_METHOD(uint64_t, getNaturalReadSize, (), (const, override)); - MOCK_METHOD( - uint64_t, - preadv, - (folly::Range regions, - folly::Range iobufs, - facebook::velox::filesystems::File::IoStats* stats), - (const, override)); -}; - -void expectPreads( - ReadFileMock& file, - std::string_view content, - std::vector reads) { - EXPECT_CALL(file, getName()).WillRepeatedly(Return("mock_name")); - EXPECT_CALL(file, size()).WillRepeatedly(Return(content.size())); - for (auto& read : reads) { - ASSERT_GE(content.size(), read.offset + read.length); - EXPECT_CALL(file, pread(read.offset, read.length, _, nullptr)) - .Times(1) - .WillOnce( - [content]( - uint64_t offset, - uint64_t length, - void* buf, - facebook::velox::filesystems::File::IoStats* stats) - -> std::string_view { - memcpy(buf, content.data() + offset, length); - return {content.data() + offset, length}; - }); - } -} - -void expectPreadvs( - ReadFileMock& file, - std::string_view content, - std::vector reads) { - EXPECT_CALL(file, getName()).WillRepeatedly(Return("mock_name")); - EXPECT_CALL(file, size()).WillRepeatedly(Return(content.size())); - EXPECT_CALL(file, preadv(_, _, nullptr)) - .Times(1) - .WillOnce( - [content, reads]( - folly::Range regions, - folly::Range iobufs, - facebook::velox::filesystems::File::IoStats* stats) -> uint64_t { - EXPECT_EQ(regions.size(), reads.size()); - uint64_t length = 0; - for (size_t i = 0; i < reads.size(); ++i) { - const auto& region = regions[i]; - const auto& read = reads[i]; - auto& iobuf = iobufs[i]; - length += region.length; - EXPECT_EQ(region.offset, read.offset); - EXPECT_EQ(region.length, read.length); - if (!read.label.empty()) { - EXPECT_EQ(read.label, region.label); - } - EXPECT_LE(region.offset + region.length, content.size()); - iobuf = folly::IOBuf( - folly::IOBuf::COPY_BUFFER, - content.data() + region.offset, - region.length); - } - - return length; - }); -} - -std::optional getNext(SeekableInputStream& input) { - const void* buf = nullptr; - int32_t size; - if (input.Next(&buf, &size)) { - return std::string( - static_cast(buf), static_cast(size)); - } else { - return std::nullopt; - } -} - -class TestBufferedInput : public testing::Test { - protected: - static void SetUpTestCase() { - MemoryManager::testingSetInstance(MemoryManager::Options{}); - } - - const std::shared_ptr pool_ = memoryManager()->addLeafPool(); -}; -} // namespace - -TEST_F(TestBufferedInput, ZeroLengthStream) { - auto readFile = - std::make_shared(std::string()); - BufferedInput input(readFile, *pool_); - auto ret = input.enqueue({0, 0}); - EXPECT_EQ(input.nextFetchSize(), 0); - EXPECT_NE(ret, nullptr); - const void* buf = nullptr; - int32_t size = 1; - EXPECT_FALSE(ret->Next(&buf, &size)); - EXPECT_EQ(size, 0); -} - -TEST_F(TestBufferedInput, UseRead) { - std::string content = "hello"; - auto readFileMock = std::make_shared(); - expectPreads(*readFileMock, content, {{0, 5}}); - // Use read - BufferedInput input( - readFileMock, - *pool_, - MetricsLog::voidLog(), - nullptr, - nullptr, - 10, - /* wsVRLoad = */ false); - auto ret = input.enqueue({0, 5}); - ASSERT_NE(ret, nullptr); - - EXPECT_EQ(input.nextFetchSize(), 5); - input.load(LogType::TEST); - - auto next = getNext(*ret); - ASSERT_TRUE(next.has_value()); - EXPECT_EQ(next.value(), content); -} - -TEST_F(TestBufferedInput, UseVRead) { - std::string content = "hello"; - auto readFileMock = std::make_shared(); - expectPreadvs(*readFileMock, content, {{0, 5}}); - // Use vread - BufferedInput input( - readFileMock, - *pool_, - MetricsLog::voidLog(), - nullptr, - nullptr, - 10, - /* wsVRLoad = */ true); - auto ret = input.enqueue({0, 5}); - ASSERT_NE(ret, nullptr); - - EXPECT_EQ(input.nextFetchSize(), 5); - input.load(LogType::TEST); - - auto next = getNext(*ret); - ASSERT_TRUE(next.has_value()); - EXPECT_EQ(next.value(), content); -} - -TEST_F(TestBufferedInput, WillMerge) { - std::string content = "hello world"; - auto readFileMock = std::make_shared(); - - // Will merge because the distance is 1 and max distance to merge is 10. - // Expect only one call. - expectPreads(*readFileMock, content, {{0, 11}}); - - BufferedInput input( - readFileMock, - *pool_, - MetricsLog::voidLog(), - nullptr, - nullptr, - 10, // Will merge if distance <= 10 - /* wsVRLoad = */ false); - - auto ret1 = input.enqueue({0, 5}); - auto ret2 = input.enqueue({6, 5}); - ASSERT_NE(ret1, nullptr); - ASSERT_NE(ret2, nullptr); - - EXPECT_EQ(input.nextFetchSize(), 10); - input.load(LogType::TEST); - - auto next1 = getNext(*ret1); - ASSERT_TRUE(next1.has_value()); - EXPECT_EQ(next1.value(), "hello"); - - auto next2 = getNext(*ret2); - ASSERT_TRUE(next2.has_value()); - EXPECT_EQ(next2.value(), "world"); -} - -TEST_F(TestBufferedInput, WontMerge) { - std::string content = "hello world"; // two spaces - auto readFileMock = std::make_shared(); - - // Won't merge because the distance is 2 and max distance to merge is 1. - // Expect two calls - expectPreads(*readFileMock, content, {{0, 5}, {7, 5}}); - - BufferedInput input( - readFileMock, - *pool_, - MetricsLog::voidLog(), - nullptr, - nullptr, - 1, // Will merge if distance <= 1 - /* wsVRLoad = */ false); - - auto ret1 = input.enqueue({0, 5}); - auto ret2 = input.enqueue({7, 5}); - ASSERT_NE(ret1, nullptr); - ASSERT_NE(ret2, nullptr); - - EXPECT_EQ(input.nextFetchSize(), 10); - input.load(LogType::TEST); - - auto next1 = getNext(*ret1); - ASSERT_TRUE(next1.has_value()); - EXPECT_EQ(next1.value(), "hello"); - - auto next2 = getNext(*ret2); - ASSERT_TRUE(next2.has_value()); - EXPECT_EQ(next2.value(), "world"); -} - -TEST_F(TestBufferedInput, ReadSorting) { - std::string content = "aaabbbcccdddeeefffggghhhiiijjjkkklllmmmnnnooopppqqq"; - std::vector regions = {{6, 3}, {24, 3}, {3, 3}, {0, 3}, {29, 3}}; - - auto readFileMock = std::make_shared(); - expectPreads(*readFileMock, content, {{0, 9}, {24, 3}, {29, 3}}); - BufferedInput input( - readFileMock, - *pool_, - MetricsLog::voidLog(), - nullptr, - nullptr, - 1, // Will merge if distance <= 1 - /* wsVRLoad = */ false); - - std::vector, std::string>> - result; - result.reserve(regions.size()); - int64_t bytesToRead = 0; - for (auto& region : regions) { - bytesToRead += region.length; - auto ret = input.enqueue(region); - ASSERT_NE(ret, nullptr); - result.push_back( - {std::move(ret), content.substr(region.offset, region.length)}); - } - - EXPECT_EQ(input.nextFetchSize(), bytesToRead); - input.load(LogType::TEST); - - for (auto& r : result) { - auto next = getNext(*r.first); - ASSERT_TRUE(next.has_value()); - EXPECT_EQ(next.value(), r.second); - } -} - -TEST_F(TestBufferedInput, VReadSorting) { - std::string content = "aaabbbcccdddeeefffggghhhiiijjjkkklllmmmnnnooopppqqq"; - std::vector regions = {{6, 3}, {24, 3}, {3, 3}, {0, 3}, {29, 3}}; - - auto readFileMock = std::make_shared(); - expectPreadvs( - *readFileMock, content, {{0, 3}, {3, 3}, {6, 3}, {24, 3}, {29, 3}}); - BufferedInput input( - readFileMock, - *pool_, - MetricsLog::voidLog(), - nullptr, - nullptr, - 1, // Will merge if distance <= 1 - /* wsVRLoad = */ true); - - std::vector, std::string>> - result; - result.reserve(regions.size()); - int64_t bytesToRead = 0; - for (auto& region : regions) { - bytesToRead += region.length; - auto ret = input.enqueue(region); - ASSERT_NE(ret, nullptr); - result.push_back( - {std::move(ret), content.substr(region.offset, region.length)}); - } - - EXPECT_EQ(input.nextFetchSize(), bytesToRead); - input.load(LogType::TEST); - - for (auto& r : result) { - auto next = getNext(*r.first); - ASSERT_TRUE(next.has_value()); - EXPECT_EQ(next.value(), r.second); - } -} - -TEST_F(TestBufferedInput, VReadSortingWithLabels) { - std::string content = "aaabbbcccdddeeefffggghhhiiijjjkkklllmmmnnnooopppqqq"; - std::vector l = {"a", "b", "c", "d", "e"}; - std::vector regions = { - {6, 3, l[2]}, {24, 3, l[3]}, {3, 3, l[1]}, {0, 3, l[0]}, {29, 3, l[4]}}; - - auto readFileMock = std::make_shared(); - expectPreadvs( - *readFileMock, - content, - {{0, 3, l[0]}, {3, 3, l[1]}, {6, 3, l[2]}, {24, 3, l[3]}, {29, 3, l[4]}}); - BufferedInput input( - readFileMock, - *pool_, - MetricsLog::voidLog(), - nullptr, - nullptr, - 1, // Will merge if distance <= 1 - /* wsVRLoad = */ true); - - std::vector, std::string>> - result; - result.reserve(regions.size()); - int64_t bytesToRead = 0; - for (auto& region : regions) { - bytesToRead += region.length; - auto ret = input.enqueue(region); - ASSERT_NE(ret, nullptr); - result.push_back( - {std::move(ret), content.substr(region.offset, region.length)}); - } - - EXPECT_EQ(input.nextFetchSize(), bytesToRead); - input.load(LogType::TEST); - - for (auto& r : result) { - auto next = getNext(*r.first); - ASSERT_TRUE(next.has_value()); - EXPECT_EQ(next.value(), r.second); - } -} diff --git a/velox/dwio/common/tests/ThrottlerTest.cpp b/velox/dwio/common/tests/ThrottlerTest.cpp index 773e8d0dd70..12ea0c86d75 100644 --- a/velox/dwio/common/tests/ThrottlerTest.cpp +++ b/velox/dwio/common/tests/ThrottlerTest.cpp @@ -110,14 +110,15 @@ TEST_F(ThrottlerTest, throttle) { SCOPED_TRACE(fmt::format("signal: {}", Throttler::signalTypeName(signal))); Throttler::testingReset(); - Throttler::init(Throttler::Config( - true, - minThrottleBackoffMs, - maxThrottleBackoffMs, - 2.0, - signal == Throttler::SignalType::kLocal ? 2 : 1'000, - signal == Throttler::SignalType::kGlobal ? 2 : 1'000, - signal == Throttler::SignalType::kNetwork ? 2 : 1'000)); + Throttler::init( + Throttler::Config( + true, + minThrottleBackoffMs, + maxThrottleBackoffMs, + 2.0, + signal == Throttler::SignalType::kLocal ? 2 : 1'000, + signal == Throttler::SignalType::kGlobal ? 2 : 1'000, + signal == Throttler::SignalType::kNetwork ? 2 : 1'000)); auto* instance = Throttler::instance(); const auto& stats = instance->stats(); ASSERT_EQ(stats.localThrottled, 0); @@ -186,16 +187,17 @@ TEST_F(ThrottlerTest, expire) { for (const auto signal : kSignalTypes) { SCOPED_TRACE(fmt::format("signal: {}", Throttler::signalTypeName(signal))); Throttler::testingReset(); - Throttler::init(Throttler::Config( - true, - minThrottleBackoffMs, - maxThrottleBackoffMs, - 2.0, - signal == Throttler::SignalType::kLocal ? 2 : 1'000, - signal == Throttler::SignalType::kGlobal ? 2 : 1'000, - signal == Throttler::SignalType::kNetwork ? 2 : 1'000, - 1'000, - 1'000)); + Throttler::init( + Throttler::Config( + true, + minThrottleBackoffMs, + maxThrottleBackoffMs, + 2.0, + signal == Throttler::SignalType::kLocal ? 2 : 1'000, + signal == Throttler::SignalType::kGlobal ? 2 : 1'000, + signal == Throttler::SignalType::kNetwork ? 2 : 1'000, + 1'000, + 1'000)); auto* instance = Throttler::instance(); const auto& stats = instance->stats(); ASSERT_EQ(stats.localThrottled, 0); @@ -228,8 +230,9 @@ TEST_F(ThrottlerTest, expire) { TEST_F(ThrottlerTest, differentLocals) { const uint64_t minThrottleBackoffMs = 1'000; const uint64_t maxThrottleBackoffMs = 2'000; - Throttler::init(Throttler::Config( - true, minThrottleBackoffMs, maxThrottleBackoffMs, 2.0, 2, 1'0000)); + Throttler::init( + Throttler::Config( + true, minThrottleBackoffMs, maxThrottleBackoffMs, 2.0, 2, 1'0000)); auto* instance = Throttler::instance(); const auto& stats = instance->stats(); ASSERT_EQ(stats.localThrottled, 0); @@ -320,8 +323,9 @@ TEST_F(ThrottlerTest, differentLocals) { TEST_F(ThrottlerTest, differentGlobals) { const uint64_t minThrottleBackoffMs = 1'000; const uint64_t maxThrottleBackoffMs = 2'000; - Throttler::init(Throttler::Config( - true, minThrottleBackoffMs, maxThrottleBackoffMs, 2.0, 1'0000, 2)); + Throttler::init( + Throttler::Config( + true, minThrottleBackoffMs, maxThrottleBackoffMs, 2.0, 1'0000, 2)); auto* instance = Throttler::instance(); const auto& stats = instance->stats(); ASSERT_EQ(stats.localThrottled, 0); @@ -409,14 +413,15 @@ TEST_F(ThrottlerTest, differentGlobals) { TEST_F(ThrottlerTest, differentNetworks) { const uint64_t minThrottleBackoffMs = 1'000; const uint64_t maxThrottleBackoffMs = 2'000; - Throttler::init(Throttler::Config( - true, - minThrottleBackoffMs, - maxThrottleBackoffMs, - 2.0, - 1'0000, - 1'0000, - 2)); + Throttler::init( + Throttler::Config( + true, + minThrottleBackoffMs, + maxThrottleBackoffMs, + 2.0, + 1'0000, + 1'0000, + 2)); auto* instance = Throttler::instance(); const auto& stats = instance->stats(); ASSERT_EQ(stats.localThrottled, 0); @@ -511,8 +516,9 @@ TEST_F(ThrottlerTest, maxOfGlobalAndLocal) { for (const bool localFirst : {false, true}) { SCOPED_TRACE(fmt::format("localFirst: {}", localFirst)); Throttler::testingReset(); - Throttler::init(Throttler::Config( - true, minThrottleBackoffMs, maxThrottleBackoffMs, 2.0, 2, 2)); + Throttler::init( + Throttler::Config( + true, minThrottleBackoffMs, maxThrottleBackoffMs, 2.0, 2, 2)); auto* instance = Throttler::instance(); const auto& stats = instance->stats(); ASSERT_EQ(stats.localThrottled, 0); @@ -627,16 +633,17 @@ TEST_F(ThrottlerTest, fuzz) { const uint32_t maxCacheEntries = 64; const uint32_t cacheTTLMs = 10; Throttler::testingReset(); - Throttler::init(Throttler::Config( - true, - minThrottleBackoffMs, - maxThrottleBackoffMs, - backoffScaleFactor, - minLocalThrottledSignals, - minGlobalThrottledSignals, - minNetworkThrottledSignals, - maxCacheEntries, - cacheTTLMs)); + Throttler::init( + Throttler::Config( + true, + minThrottleBackoffMs, + maxThrottleBackoffMs, + backoffScaleFactor, + minLocalThrottledSignals, + minGlobalThrottledSignals, + minNetworkThrottledSignals, + maxCacheEntries, + cacheTTLMs)); auto* instance = Throttler::instance(); const auto seed = getCurrentTimeMs(); diff --git a/velox/dwio/common/tests/UnitLoaderBaseTest.h b/velox/dwio/common/tests/UnitLoaderBaseTest.h new file mode 100644 index 00000000000..9faf9bc91b3 --- /dev/null +++ b/velox/dwio/common/tests/UnitLoaderBaseTest.h @@ -0,0 +1,140 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/dwio/common/UnitLoaderTools.h" +#include "velox/dwio/common/tests/utils/UnitLoaderTestTools.h" + +using facebook::velox::dwio::common::LoadUnit; +using facebook::velox::dwio::common::test::getUnitsLoadedWithFalse; +using facebook::velox::dwio::common::test::LoadUnitMock; +using facebook::velox::dwio::common::test::ReaderMock; + +/// Base test class template that provides common test functionality for +/// different UnitLoader implementations. This template class can be inherited +/// by specific test classes to get access to common test methods. Each derived +/// class should provide a createFactory() method that returns the appropriate +/// factory instance. +template +class UnitLoaderBaseTest : public ::testing::Test { + protected: + /// Factory method to create the appropriate UnitLoaderFactory instance. + /// This method should be implemented by derived classes. + virtual UnitLoaderFactoryType createFactory() = 0; + + /// Test that UnitLoader factory handles the case where no units exist but + /// skip is requested + void testNoUnitButSkip() { + UnitLoaderFactoryType factory = createFactory(); + std::vector> units; + + EXPECT_NO_THROW(factory.create(std::move(units), 0)); + + std::vector> units2; + VELOX_ASSERT_THROW( + factory.create(std::move(units2), 1), + "Can only skip up to the past-the-end row of the file."); + } + + /// Test that UnitLoader factory handles initial skip correctly for various + /// skip values + void testInitialSkip() { + auto getFactoryWithSkip = [this](uint64_t skipToRow) { + auto factory = createFactory(); + std::vector unitsLoaded(getUnitsLoadedWithFalse(3)); + std::vector> units; + units.push_back(std::make_unique(10, 0, unitsLoaded, 0)); + units.push_back(std::make_unique(20, 0, unitsLoaded, 1)); + units.push_back(std::make_unique(30, 0, unitsLoaded, 2)); + factory.create(std::move(units), skipToRow); + }; + + EXPECT_NO_THROW(getFactoryWithSkip(0)); + EXPECT_NO_THROW(getFactoryWithSkip(1)); + EXPECT_NO_THROW(getFactoryWithSkip(9)); + EXPECT_NO_THROW(getFactoryWithSkip(10)); + EXPECT_NO_THROW(getFactoryWithSkip(11)); + EXPECT_NO_THROW(getFactoryWithSkip(29)); + EXPECT_NO_THROW(getFactoryWithSkip(30)); + EXPECT_NO_THROW(getFactoryWithSkip(31)); + EXPECT_NO_THROW(getFactoryWithSkip(59)); + EXPECT_NO_THROW(getFactoryWithSkip(60)); + VELOX_ASSERT_THROW( + getFactoryWithSkip(61), + "Can only skip up to the past-the-end row of the file."); + VELOX_ASSERT_THROW( + getFactoryWithSkip(100), + "Can only skip up to the past-the-end row of the file."); + } + + /// Test that the same unit can be requested multiple times without issues + void testCanRequestUnitMultipleTimes() { + auto factory = createFactory(); + std::vector unitsLoaded(getUnitsLoadedWithFalse(1)); + std::vector> units; + units.push_back(std::make_unique(10, 0, unitsLoaded, 0)); + + auto unitLoader = factory.create(std::move(units), 0); + unitLoader->getLoadedUnit(0); + unitLoader->getLoadedUnit(0); + unitLoader->getLoadedUnit(0); + } + + /// Test that requesting a unit index out of range throws an exception + void testUnitOutOfRange() { + auto factory = createFactory(); + std::vector unitsLoaded(getUnitsLoadedWithFalse(1)); + std::vector> units; + units.push_back(std::make_unique(10, 0, unitsLoaded, 0)); + + auto unitLoader = factory.create(std::move(units), 0); + unitLoader->getLoadedUnit(0); + + VELOX_ASSERT_THROW(unitLoader->getLoadedUnit(1), "Unit out of range"); + } + + /// Test that seeking out of range throws an exception + void testSeekOutOfRange() { + auto factory = createFactory(); + std::vector unitsLoaded(getUnitsLoadedWithFalse(1)); + std::vector> units; + units.push_back(std::make_unique(10, 0, unitsLoaded, 0)); + + auto unitLoader = factory.create(std::move(units), 0); + + unitLoader->onSeek(0, 10); + + VELOX_ASSERT_THROW(unitLoader->onSeek(0, 11), "Row out of range"); + } + + /// Test that seeking out of range in ReaderMock throws appropriate exception + void testSeekOutOfRangeReaderError() { + auto factory = createFactory(); + ReaderMock readerMock{{10, 20, 30}, {0, 0, 0}, factory, 0}; + + readerMock.seek(59); + readerMock.seek(60); + + VELOX_ASSERT_THROW( + readerMock.seek(61), + "Can't seek to possition 61 in file. Must be up to 60."); + } +}; diff --git a/velox/dwio/common/tests/WriterTest.cpp b/velox/dwio/common/tests/WriterTest.cpp index 51e68de59d0..f1118927e30 100644 --- a/velox/dwio/common/tests/WriterTest.cpp +++ b/velox/dwio/common/tests/WriterTest.cpp @@ -15,13 +15,47 @@ */ #include "velox/dwio/common/Writer.h" + #include + #include "velox/common/base/tests/GTestUtils.h" using namespace ::testing; namespace facebook::velox::dwio::common { namespace { + +class MockWriter : public Writer { + public: + MockWriter() = default; + + void setStateForTest(State state) { + setState(state); + } + + void callCheckRunning() const { + checkRunning(); + } + + bool callIsRunning() const { + return isRunning(); + } + + void write(const VectorPtr& /*data*/) override {} + + void flush() override {} + + bool finish() override { + return true; + } + + void abort() override {} + + std::unique_ptr close() override { + return nullptr; + } +}; + TEST(WriterTest, stateString) { ASSERT_EQ(Writer::stateString(Writer::State::kInit), "INIT"); ASSERT_EQ(Writer::stateString(Writer::State::kRunning), "RUNNING"); @@ -31,5 +65,91 @@ TEST(WriterTest, stateString) { VELOX_ASSERT_THROW( Writer::stateString(static_cast(100)), "BAD STATE: 100"); } + +TEST(WriterTest, checkRunning) { + MockWriter writer; + VELOX_ASSERT_THROW(writer.callCheckRunning(), "Writer is not running: INIT"); + writer.setStateForTest(Writer::State::kRunning); + ASSERT_NO_THROW(writer.callCheckRunning()); + writer.setStateForTest(Writer::State::kClosed); + VELOX_ASSERT_THROW( + writer.callCheckRunning(), "Writer is not running: CLOSED"); +} + +TEST(WriterTest, stateTransitions) { + MockWriter writer; + ASSERT_EQ(writer.state(), Writer::State::kInit); + + // Valid transition: kInit -> kRunning + writer.setStateForTest(Writer::State::kRunning); + ASSERT_EQ(writer.state(), Writer::State::kRunning); + + // Valid transition: kRunning -> kFinishing + writer.setStateForTest(Writer::State::kFinishing); + ASSERT_EQ(writer.state(), Writer::State::kFinishing); + + // Valid transition: kFinishing -> kFinishing (reentry) + writer.setStateForTest(Writer::State::kFinishing); + ASSERT_EQ(writer.state(), Writer::State::kFinishing); + + // Valid transition: kFinishing -> kClosed + writer.setStateForTest(Writer::State::kClosed); + ASSERT_EQ(writer.state(), Writer::State::kClosed); +} + +TEST(WriterTest, invalidStateTransitions) { + { + MockWriter writer; + // Invalid: kInit -> kClosed + VELOX_ASSERT_THROW( + writer.setStateForTest(Writer::State::kClosed), + "Unexpected state transition from INIT to CLOSED"); + } + { + MockWriter writer; + // Invalid: kInit -> kFinishing + VELOX_ASSERT_THROW( + writer.setStateForTest(Writer::State::kFinishing), + "Unexpected state transition from INIT to FINISHING"); + } + { + MockWriter writer; + writer.setStateForTest(Writer::State::kRunning); + writer.setStateForTest(Writer::State::kClosed); + // Invalid: kClosed -> kRunning + VELOX_ASSERT_THROW( + writer.setStateForTest(Writer::State::kRunning), + "Unexpected state transition from CLOSED to RUNNING"); + } +} + +TEST(WriterTest, stateGetter) { + MockWriter writer; + ASSERT_EQ(writer.state(), Writer::State::kInit); + + writer.setStateForTest(Writer::State::kRunning); + ASSERT_EQ(writer.state(), Writer::State::kRunning); + + writer.setStateForTest(Writer::State::kFinishing); + ASSERT_EQ(writer.state(), Writer::State::kFinishing); + + writer.setStateForTest(Writer::State::kClosed); + ASSERT_EQ(writer.state(), Writer::State::kClosed); +} + +TEST(WriterTest, isRunning) { + MockWriter writer; + ASSERT_FALSE(writer.callIsRunning()); + + writer.setStateForTest(Writer::State::kRunning); + ASSERT_TRUE(writer.callIsRunning()); + + writer.setStateForTest(Writer::State::kFinishing); + ASSERT_FALSE(writer.callIsRunning()); + + writer.setStateForTest(Writer::State::kClosed); + ASSERT_FALSE(writer.callIsRunning()); +} + } // namespace } // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/tests/utils/BatchMaker.cpp b/velox/dwio/common/tests/utils/BatchMaker.cpp index 435e32c39ba..a422c52e00e 100644 --- a/velox/dwio/common/tests/utils/BatchMaker.cpp +++ b/velox/dwio/common/tests/utils/BatchMaker.cpp @@ -141,6 +141,18 @@ VectorPtr BatchMaker::createVector( MemoryPool& pool, std::mt19937& gen, std::function isNullAt) { + if (type->isTime()) { + const bool isMicros = type->equivalent(*TIME_MICRO_UTC()); + // TIME is milliseconds since midnight; TIME_MICRO_UTC is microseconds. + const int64_t maxValue = isMicros ? 86'400'000'000LL : 86'400'000LL; + return createScalar( + size, + gen, + [&gen, maxValue]() { return Random::rand64(0, maxValue, gen); }, + pool, + isNullAt, + type); + } return createScalar( size, gen, diff --git a/velox/dwio/common/tests/utils/CMakeLists.txt b/velox/dwio/common/tests/utils/CMakeLists.txt index bbbd5c2ebe5..4ed4335c238 100644 --- a/velox/dwio/common/tests/utils/CMakeLists.txt +++ b/velox/dwio/common/tests/utils/CMakeLists.txt @@ -21,6 +21,15 @@ add_library( UnitLoaderTestTools.cpp E2EFilterTestBase.cpp ) +velox_add_test_headers( + velox_dwio_common_test_utils + BatchMaker.h + DataFiles.h + DataSetBuilder.h + E2EFilterTestBase.h + FilterGenerator.h + UnitLoaderTestTools.h +) target_link_libraries( velox_dwio_common_test_utils @@ -44,3 +53,5 @@ target_link_libraries( if(CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9) target_link_libraries(velox_dwio_common_test_utils stdc++fs) endif() + +velox_add_library(velox_lib_test_map_builder INTERFACE HEADERS MapBuilder.h) diff --git a/velox/dwio/common/tests/utils/DataSetBuilder.cpp b/velox/dwio/common/tests/utils/DataSetBuilder.cpp index 85c508dd0a3..63c1fdc69e8 100644 --- a/velox/dwio/common/tests/utils/DataSetBuilder.cpp +++ b/velox/dwio/common/tests/utils/DataSetBuilder.cpp @@ -47,8 +47,9 @@ DataSetBuilder& DataSetBuilder::makeDataset( for (size_t i = 0; i < batchCount; ++i) { if (withRecursiveNulls) { - batches_->push_back(std::static_pointer_cast( - BatchMaker::createBatch(rowType, numRows, pool_, nullptr, i))); + batches_->push_back( + std::static_pointer_cast( + BatchMaker::createBatch(rowType, numRows, pool_, nullptr, i))); } else { batches_->push_back( std::static_pointer_cast(BatchMaker::createBatch( @@ -201,13 +202,13 @@ DataSetBuilder& DataSetBuilder::withStringDistributionForField( DataSetBuilder& DataSetBuilder::withUniqueStringsForField( const Subfield& field) { for (RowVectorPtr batch : *batches_) { - auto strings = + auto* strings = getChildBySubfield(batch.get(), field)->as>(); for (auto row = 0; row < strings->size(); ++row) { if (strings->isNullAt(row)) { continue; } - std::string value = strings->valueAt(row); + auto value = std::string(strings->valueAt(row)); value += fmt::format("{}", row); strings->set(row, StringView(value)); } @@ -282,7 +283,7 @@ DataSetBuilder& DataSetBuilder::makeMapStringValues( continue; } if (!keys->isNullAt(i) && i % 3 == 0) { - std::string str = keys->valueAt(i); + auto str = std::string(keys->valueAt(i)); str += "----123456789"; keys->set(i, StringView(str)); } @@ -304,7 +305,7 @@ DataSetBuilder& DataSetBuilder::makeMapStringValues( continue; } if (!values->isNullAt(i) && i % 3 == 0) { - std::string str = values->valueAt(i); + auto str = std::string(values->valueAt(i)); str += "----123456789"; values->set(i, StringView(str)); } diff --git a/velox/dwio/common/tests/utils/E2EFilterTestBase.cpp b/velox/dwio/common/tests/utils/E2EFilterTestBase.cpp index f7656fbd08c..d5d28c9a854 100644 --- a/velox/dwio/common/tests/utils/E2EFilterTestBase.cpp +++ b/velox/dwio/common/tests/utils/E2EFilterTestBase.cpp @@ -16,8 +16,12 @@ #include "velox/dwio/common/tests/utils/E2EFilterTestBase.h" +#include +#include + #include "velox/dwio/common/tests/utils/DataSetBuilder.h" #include "velox/expression/Expr.h" +#include "velox/expression/ExprConstants.h" #include "velox/expression/ExprToSubfieldFilter.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" #include "velox/parse/Expressions.h" @@ -44,10 +48,46 @@ using dwio::common::InMemoryReadFile; using dwio::common::MemorySink; using velox::common::Subfield; +// Helper to generate monotonic unique values for a given scalar type. +// Generates values starting from startValue, incrementing by 1 for each row. +// For ascending order, values go from startValue to startValue + numRows - 1. +// The globalRowOffset is used to compute the correct value when generating +// data across multiple batches. +template +VectorPtr generateStrictSortedVectorImpl( + size_t numRows, + size_t rowOffset, + size_t totalRows, + int64_t startValue, + bool ascending, + memory::MemoryPool* pool) { + auto flatVector = BaseVector::create>( + CppToType::create(), static_cast(numRows), pool); + for (vector_size_t i = 0; i < static_cast(numRows); ++i) { + const size_t nextRowOffset = rowOffset + i; + // For ascending: value increases with nextRowOffset + // For descending: value decreases (totalRows - 1 - nextRowOffset) + const size_t adjustedRowOffset = + ascending ? nextRowOffset : (totalRows - 1 - nextRowOffset); + if constexpr (std::is_same_v) { + // For bool, we can only have 2 unique values, so this type is limited. + flatVector->set(i, adjustedRowOffset % 2 == 1); + } else if constexpr (std::is_same_v) { + flatVector->set( + i, + Timestamp(startValue + static_cast(adjustedRowOffset), 0)); + } else { + flatVector->set(i, static_cast(startValue + adjustedRowOffset)); + } + } + return flatVector; +} + std::vector E2EFilterTestBase::makeDataset( std::function customize, bool forRowGroupSkip, - bool withRecursiveNulls) { + bool withRecursiveNulls, + const std::vector& indexColumns) { if (!dataSetBuilder_) { dataSetBuilder_ = std::make_unique(*leafPool_, 0); } @@ -64,9 +104,308 @@ std::vector E2EFilterTestBase::makeDataset( } std::vector batches = *dataSetBuilder_->build(); + + // Replace index columns with monotonically increasing unique values. + if (!indexColumns.empty()) { + batches = replaceIndexColumnsWithStrictlySortedData(batches, indexColumns); + } + return batches; } +std::vector +E2EFilterTestBase::replaceIndexColumnsWithStrictlySortedData( + const std::vector& batches, + const std::vector& indexColumns) { + if (batches.empty() || indexColumns.empty()) { + return batches; + } + + const auto& rowType = batches[0]->type()->asRow(); + + // Build index column indices. + std::vector indexColIndices; + for (const auto& colName : indexColumns) { + auto idx = rowType.getChildIdxIfExists(colName); + VELOX_CHECK( + idx.has_value(), "Index column '{}' not found in schema", colName); + indexColIndices.push_back(idx.value()); + } + + // Calculate total rows across all batches. + size_t totalRows = 0; + for (const auto& batch : batches) { + totalRows += batch->size(); + } + + std::vector result; + result.reserve(batches.size()); + size_t rowOffset = 0; + + for (const auto& batch : batches) { + const size_t batchSize = batch->size(); + std::vector newChildren; + newChildren.reserve(rowType.size()); + + for (size_t colIdx = 0; colIdx < rowType.size(); ++colIdx) { + const bool isIndexColumn = + std::find(indexColIndices.begin(), indexColIndices.end(), colIdx) != + indexColIndices.end(); + + if (isIndexColumn) { + // Generate monotonically increasing unique values for index columns. + newChildren.push_back(generateStrictSortedVector( + rowType.childAt(colIdx), + batchSize, + rowOffset, + /*startValue=*/0, + totalRows, + /*ascending=*/true, + leafPool_.get())); + } else { + // Keep the original data for non-index columns. + newChildren.push_back(batch->childAt(colIdx)); + } + } + + result.push_back( + std::make_shared( + leafPool_.get(), + batch->type(), + nullptr, + batchSize, + std::move(newChildren))); + rowOffset += batchSize; + } + + return result; +} + +VectorPtr E2EFilterTestBase::generateStrictSortedVector( + const TypePtr& type, + size_t size, + size_t globalRowOffset, + int64_t startValue, + size_t totalRows, + bool ascending, + memory::MemoryPool* pool) { + switch (type->kind()) { + case TypeKind::BOOLEAN: + return generateStrictSortedVectorImpl( + size, globalRowOffset, totalRows, startValue, ascending, pool); + case TypeKind::TINYINT: + return generateStrictSortedVectorImpl( + size, globalRowOffset, totalRows, startValue, ascending, pool); + case TypeKind::SMALLINT: + return generateStrictSortedVectorImpl( + size, globalRowOffset, totalRows, startValue, ascending, pool); + case TypeKind::INTEGER: + return generateStrictSortedVectorImpl( + size, globalRowOffset, totalRows, startValue, ascending, pool); + case TypeKind::BIGINT: + return generateStrictSortedVectorImpl( + size, globalRowOffset, totalRows, startValue, ascending, pool); + case TypeKind::REAL: + return generateStrictSortedVectorImpl( + size, globalRowOffset, totalRows, startValue, ascending, pool); + case TypeKind::DOUBLE: + return generateStrictSortedVectorImpl( + size, globalRowOffset, totalRows, startValue, ascending, pool); + case TypeKind::TIMESTAMP: + return generateStrictSortedVectorImpl( + size, globalRowOffset, totalRows, startValue, ascending, pool); + case TypeKind::VARCHAR: { + // Generate string keys like "key_00000000", "key_00000001", etc. + auto flatVector = BaseVector::create>( + VARCHAR(), static_cast(size), pool); + for (vector_size_t i = 0; i < static_cast(size); ++i) { + const size_t globalIdx = globalRowOffset + i; + const size_t adjustedIdx = + ascending ? globalIdx : (totalRows - 1 - globalIdx); + auto str = fmt::format("key_{:08d}", startValue + adjustedIdx); + flatVector->set(i, StringView(str)); + } + return flatVector; + } + case TypeKind::VARBINARY: + case TypeKind::HUGEINT: + case TypeKind::ARRAY: + case TypeKind::MAP: + case TypeKind::ROW: + case TypeKind::UNKNOWN: + case TypeKind::FUNCTION: + case TypeKind::OPAQUE: + case TypeKind::INVALID: + VELOX_UNREACHABLE( + "Unsupported type for generateStrictSortedVector: {}", + type->toString()); + } + VELOX_UNREACHABLE(); +} + +namespace { + +// Computes which unique key a given row belongs to when generating data with +// duplicates. Each unique key is repeated 1-3 times (randomly determined). +// +// The function simulates generating keys from the beginning (row 0) and counts +// how many unique keys we've passed to reach the given globalRow. +// +// Example with seed=42 (which generates duplicate counts: 2, 1, 2, 3, 1, ...): +// Row 0 -> Key 0 (first of 2 duplicates) +// Row 1 -> Key 0 (second of 2 duplicates) +// Row 2 -> Key 1 (first of 1 duplicate, i.e., no duplicate) +// Row 3 -> Key 2 (first of 2 duplicates) +// Row 4 -> Key 2 (second of 2 duplicates) +// Row 5 -> Key 3 (first of 3 duplicates) +// Row 6 -> Key 3 (second of 3 duplicates) +// Row 7 -> Key 3 (third of 3 duplicates) +// Row 8 -> Key 4 (first of 1 duplicate) +// ... +// +// @param globalRow The row index (0-based) to find the key for +// @param seed Random seed for reproducibility (default 42) +// @return The unique key index that this row maps to +size_t computeUniqueKeyIndex(size_t globalRow, uint32_t seed = 42) { + std::mt19937 rng(seed); + std::uniform_int_distribution dupDist(1, 3); + + size_t currentRow = 0; + size_t uniqueKeyIdx = 0; + while (currentRow <= globalRow) { + int duplicates = dupDist(rng); + if (currentRow + duplicates > globalRow) { + return uniqueKeyIdx; + } + currentRow += duplicates; + ++uniqueKeyIdx; + } + return uniqueKeyIdx; +} + +// Computes the largest unique key index that would be assigned to any row +// in a dataset of totalRows. +// +// This is needed for descending order: to generate descending keys, we need +// to know the maximum key index so we can reverse the sequence +// (maxKey - currentKey gives us the descending value). +// +// Example: For totalRows=9 with the sequence above: +// Row 8 (the last row) maps to Key 4 +// So computeMaxUniqueKeyIndex(9) returns 4 +// +// @param totalRows Total number of rows in the dataset +// @param seed Random seed (must match computeUniqueKeyIndex for consistency) +// @return The maximum unique key index +size_t computeMaxUniqueKeyIndex(size_t totalRows, uint32_t seed = 42) { + if (totalRows == 0) { + return 0; + } + return computeUniqueKeyIndex(totalRows - 1, seed); +} + +// Generates values with duplicates (1-3 copies per unique key). +template +VectorPtr generateSortedVectorImpl( + size_t numRows, + size_t rowOffset, + size_t totalRows, + int64_t startValue, + bool ascending, + memory::MemoryPool* pool) { + auto flatVector = BaseVector::create>( + CppToType::create(), static_cast(numRows), pool); + + // For descending order, we need to know the max unique key index + const size_t maxUniqueKeyIdx = + ascending ? 0 : computeMaxUniqueKeyIndex(totalRows); + + for (vector_size_t i = 0; i < static_cast(numRows); ++i) { + const size_t globalRow = rowOffset + i; + // Get the unique key index for this row position + const size_t uniqueKeyIdx = computeUniqueKeyIndex(globalRow); + // For descending order, reverse the key index relative to max + const size_t adjustedKeyIdx = + ascending ? uniqueKeyIdx : (maxUniqueKeyIdx - uniqueKeyIdx); + + if constexpr (std::is_same_v) { + flatVector->set(i, adjustedKeyIdx % 2 == 1); + } else if constexpr (std::is_same_v) { + flatVector->set( + i, Timestamp(startValue + static_cast(adjustedKeyIdx), 0)); + } else { + flatVector->set(i, static_cast(startValue + adjustedKeyIdx)); + } + } + return flatVector; +} + +} // namespace + +VectorPtr E2EFilterTestBase::generateSortedVector( + const TypePtr& type, + size_t size, + size_t globalRowOffset, + int64_t startValue, + size_t totalRows, + bool ascending, + memory::MemoryPool* pool) { + switch (type->kind()) { + case TypeKind::BOOLEAN: + return generateSortedVectorImpl( + size, globalRowOffset, totalRows, startValue, ascending, pool); + case TypeKind::TINYINT: + return generateSortedVectorImpl( + size, globalRowOffset, totalRows, startValue, ascending, pool); + case TypeKind::SMALLINT: + return generateSortedVectorImpl( + size, globalRowOffset, totalRows, startValue, ascending, pool); + case TypeKind::INTEGER: + return generateSortedVectorImpl( + size, globalRowOffset, totalRows, startValue, ascending, pool); + case TypeKind::BIGINT: + return generateSortedVectorImpl( + size, globalRowOffset, totalRows, startValue, ascending, pool); + case TypeKind::REAL: + return generateSortedVectorImpl( + size, globalRowOffset, totalRows, startValue, ascending, pool); + case TypeKind::DOUBLE: + return generateSortedVectorImpl( + size, globalRowOffset, totalRows, startValue, ascending, pool); + case TypeKind::TIMESTAMP: + return generateSortedVectorImpl( + size, globalRowOffset, totalRows, startValue, ascending, pool); + case TypeKind::VARCHAR: { + // Generate string keys with duplicates + auto flatVector = BaseVector::create>( + VARCHAR(), static_cast(size), pool); + const size_t maxUniqueKeyIdx = + ascending ? 0 : computeMaxUniqueKeyIndex(totalRows); + for (vector_size_t i = 0; i < static_cast(size); ++i) { + const size_t globalRow = globalRowOffset + i; + const size_t uniqueKeyIdx = computeUniqueKeyIndex(globalRow); + const size_t adjustedIdx = + ascending ? uniqueKeyIdx : (maxUniqueKeyIdx - uniqueKeyIdx); + auto str = fmt::format("key_{:08d}", startValue + adjustedIdx); + flatVector->set(i, StringView(str)); + } + return flatVector; + } + case TypeKind::VARBINARY: + case TypeKind::HUGEINT: + case TypeKind::ARRAY: + case TypeKind::MAP: + case TypeKind::ROW: + case TypeKind::UNKNOWN: + case TypeKind::FUNCTION: + case TypeKind::OPAQUE: + case TypeKind::INVALID: + VELOX_UNREACHABLE( + "Unsupported type for generateSortedVector: {}", type->toString()); + } + VELOX_UNREACHABLE(); +} + void E2EFilterTestBase::makeAllNulls(const std::string& fieldName) { dataSetBuilder_->withAllNullsForField(Subfield(fieldName)); } @@ -95,6 +434,8 @@ void E2EFilterTestBase::readWithoutFilter( uint64_t& time) { SCOPED_TRACE("Read without filter"); dwio::common::ReaderOptions readerOpts{leafPool_.get()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); dwio::common::RowReaderOptions rowReaderOpts; auto input = std::make_unique( std::make_shared(sinkData_), readerOpts.memoryPool()); @@ -149,6 +490,8 @@ void E2EFilterTestBase::readWithFilter( bool skipCheck) { SCOPED_TRACE("Read with filter"); dwio::common::ReaderOptions readerOpts{leafPool_.get()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); dwio::common::RowReaderOptions rowReaderOpts; auto input = std::make_unique( std::make_shared(sinkData_), readerOpts.memoryPool()); @@ -167,21 +510,21 @@ void E2EFilterTestBase::readWithFilter( auto resultBatch = BaseVector::create(rowType_, 1, leafPool_.get()); resetReadBatchSizes(); int32_t clearCnt = 0; - auto deletedRowsIter = mutationSpec.deletedRows.begin(); + auto deletedRowsIter = mutationSpec.deletedRows.cbegin(); while (true) { { MicrosecondTimer timer(&time); if (++clearCnt % 17 == 0) { rowReader->resetFilterCaches(); } - auto nextRowNumber = rowReader->nextRowNumber(); + const auto nextRowNumber = rowReader->nextRowNumber(); if (nextRowNumber == RowReader::kAtEnd) { break; } auto readSize = rowReader->nextReadSize(nextReadBatchSize()); std::vector isDeleted(bits::nwords(readSize)); bool haveDelete = false; - for (; deletedRowsIter != mutationSpec.deletedRows.end(); + for (; deletedRowsIter != mutationSpec.deletedRows.cend(); ++deletedRowsIter) { auto i = *deletedRowsIter; if (i < nextRowNumber) { @@ -294,11 +637,13 @@ void E2EFilterTestBase::testReadWithFilterLazy( void E2EFilterTestBase::testFilterSpecs( const std::vector& batches, const std::vector& filterSpecs) { + SCOPED_TRACE(FilterGenerator::specsToString(filterSpecs)); MutationSpec mutations; std::vector hitRows; auto filters = filterGenerator_->makeSubfieldFilters( filterSpecs, batches, &mutations, hitRows); auto spec = filterGenerator_->makeScanSpec(std::move(filters)); + SCOPED_TRACE(spec->toString()); uint64_t timeWithFilter = 0; readWithFilter(spec, mutations, batches, hitRows, timeWithFilter, false); @@ -405,18 +750,19 @@ void E2EFilterTestBase::testScenario( bool wrapInStruct, const std::vector& filterable, int32_t numCombinations, - bool withRecursiveNulls) { + bool withRecursiveNulls, + const std::vector& indexColumns) { rowType_ = DataSetBuilder::makeRowType(columns, wrapInStruct); filterGenerator_ = std::make_unique(rowType_, seed_); - - auto batches = makeDataset(customize, false, withRecursiveNulls); - writeToMemory(rowType_, batches, false); + auto batches = + makeDataset(customize, false, withRecursiveNulls, indexColumns); + writeToMemory(rowType_, batches, false, indexColumns); testNoRowGroupSkip(batches, filterable, numCombinations); testPruningWithFilter(batches, filterable); if (testRowGroupSkip_) { - batches = makeDataset(customize, true, withRecursiveNulls); - writeToMemory(rowType_, batches, true); + batches = makeDataset(customize, true, withRecursiveNulls, indexColumns); + writeToMemory(rowType_, batches, true, indexColumns); testRowGroupSkip(batches, filterable); } } @@ -484,6 +830,12 @@ void E2EFilterTestBase::testRunLengthDictionaryScenario( } } +namespace { +core::ExprPtr parseExpr(const std::string& text) { + return parse::DuckSqlExpressionsParser().parseExpr(text); +} +} // namespace + void E2EFilterTestBase::testMetadataFilterImpl( const std::vector& batches, common::Subfield filterField, @@ -491,15 +843,31 @@ void E2EFilterTestBase::testMetadataFilterImpl( core::ExpressionEvaluator* evaluator, const std::string& remainingFilter, std::function validationFilter) { - SCOPED_TRACE(fmt::format("remainingFilter={}", remainingFilter)); + SCOPED_TRACE(fmt::format("remainingFilter='{}'", remainingFilter)); + auto untypedExpr = parseExpr(remainingFilter); + auto typedExpr = core::Expressions::inferTypes( + untypedExpr, batches[0]->type(), leafPool_.get()); + testMetadataFilterImpl( + batches, + std::move(filterField), + std::move(filter), + evaluator, + std::move(typedExpr), + std::move(validationFilter)); +} + +void E2EFilterTestBase::testMetadataFilterImpl( + const std::vector& batches, + common::Subfield filterField, + std::unique_ptr filter, + core::ExpressionEvaluator* evaluator, + core::TypedExprPtr typedExpr, + std::function validationFilter) { auto spec = std::make_shared(""); if (filter) { spec->getOrCreateChild(std::move(filterField)) ->setFilter(std::move(filter)); } - auto untypedExpr = parse::parseExpr(remainingFilter, {}); - auto typedExpr = core::Expressions::inferTypes( - untypedExpr, batches[0]->type(), leafPool_.get()); auto metadataFilter = std::make_shared(*spec, *typedExpr, evaluator); auto specA = spec->getOrCreateChild(common::Subfield("a")); @@ -512,6 +880,8 @@ void E2EFilterTestBase::testMetadataFilterImpl( specC->setProjectOut(true); specC->setChannel(0); ReaderOptions readerOpts{leafPool_.get()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); RowReaderOptions rowReaderOpts; auto input = std::make_unique( std::make_shared(sinkData_), readerOpts.memoryPool()); @@ -580,12 +950,13 @@ void E2EFilterTestBase::testMetadataFilter() { nullptr, c->size(), std::vector({c})); - batches.push_back(std::make_shared( - leafPool_.get(), - ROW({{"a", a->type()}, {"b", b->type()}}), - nullptr, - a->size(), - std::vector({a, b}))); + batches.push_back( + std::make_shared( + leafPool_.get(), + ROW({{"a", a->type()}, {"b", b->type()}}), + nullptr, + a->size(), + std::vector({a, b}))); } writeToMemory(batches[0]->type(), batches, false); @@ -621,6 +992,56 @@ void E2EFilterTestBase::testMetadataFilter() { [](int64_t a, int64_t) { return !!(a == 2 || a == 3 || a == 5 || a == 7); }); + { + SCOPED_TRACE("remainingFilter='a == 1 or a == 3 or a == 8'"); + auto typedExpr1 = core::Expressions::inferTypes( + parseExpr("a == 1"), batches[0]->type(), leafPool_.get()); + auto typedExpr2 = core::Expressions::inferTypes( + parseExpr("a == 3"), batches[0]->type(), leafPool_.get()); + auto typedExpr3 = core::Expressions::inferTypes( + parseExpr("a == 8"), batches[0]->type(), leafPool_.get()); + + auto typedExpr = std::make_shared( + velox::BOOLEAN(), + std::vector{ + std::move(typedExpr1), + std::move(typedExpr2), + std::move(typedExpr3), + }, + expression::kOr); + testMetadataFilterImpl( + batches, + common::Subfield("a"), + nullptr, + &evaluator, + std::move(typedExpr), + [](int64_t a, int64_t) { return a == 1 || a == 3 || a == 8; }); + } + { + SCOPED_TRACE("remainingFilter='a >= 1 and a <= 100 and a == 8'"); + auto typedExpr1 = core::Expressions::inferTypes( + parseExpr("a >= 1"), batches[0]->type(), leafPool_.get()); + auto typedExpr2 = core::Expressions::inferTypes( + parseExpr("a <= 100"), batches[0]->type(), leafPool_.get()); + auto typedExpr3 = core::Expressions::inferTypes( + parseExpr("b.c != 8"), batches[0]->type(), leafPool_.get()); + + auto typedExpr = std::make_shared( + velox::BOOLEAN(), + std::vector{ + std::move(typedExpr1), + std::move(typedExpr2), + std::move(typedExpr3), + }, + expression::kAnd); + testMetadataFilterImpl( + batches, + common::Subfield("a"), + nullptr, + &evaluator, + std::move(typedExpr), + [](int64_t a, int64_t c) { return a >= 1 && a <= 100 && c != 8; }); + } { SCOPED_TRACE("Values not unique in row group"); @@ -645,7 +1066,7 @@ void E2EFilterTestBase::testMetadataFilter() { writeToMemory(batches[0]->type(), batches, false); auto spec = std::make_shared(""); spec->addAllChildFields(*batches[0]->type()); - auto untypedExpr = parse::parseExpr("a = 1 or b + c = 2", {}); + auto untypedExpr = parseExpr("a = 1 or b + c = 2"); auto typedExpr = core::Expressions::inferTypes( untypedExpr, batches[0]->type(), leafPool_.get()); auto metadataFilter = @@ -724,6 +1145,8 @@ void E2EFilterTestBase::testSubfieldsPruning() { specF->childByName(common::ScanSpec::kArrayElementsFieldName) ->setFilter(common::createBigintValues({0, 2, 4}, false)); ReaderOptions readerOpts{leafPool_.get()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); RowReaderOptions rowReaderOpts; auto input = std::make_unique( std::make_shared(sinkData_), readerOpts.memoryPool()); @@ -814,6 +1237,8 @@ void E2EFilterTestBase::testMutationCornerCases() { auto& rowType = batches[0]->type(); writeToMemory(rowType, batches, false); ReaderOptions readerOpts{leafPool_.get()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); auto input = std::make_unique( std::make_shared(sinkData_), readerOpts.memoryPool()); auto reader = makeReader(readerOpts, std::move(input)); diff --git a/velox/dwio/common/tests/utils/E2EFilterTestBase.h b/velox/dwio/common/tests/utils/E2EFilterTestBase.h index 16d30e36a7b..1f97b09a6d9 100644 --- a/velox/dwio/common/tests/utils/E2EFilterTestBase.h +++ b/velox/dwio/common/tests/utils/E2EFilterTestBase.h @@ -16,6 +16,7 @@ #pragma once +#include "velox/common/io/IoStatistics.h" #include "velox/common/testutil/RandomSeed.h" #include "velox/common/time/Timer.h" #include "velox/dwio/common/BufferedInput.h" @@ -64,7 +65,7 @@ class TestingHook : public ValueHook { } } - void addValue(vector_size_t row, folly::StringPiece value) override { + void addValue(vector_size_t row, std::string_view value) override { if constexpr (std::is_same_v) { result_->set(row, StringView(value)); } else { @@ -96,6 +97,45 @@ class OwnershipChecker { }; class E2EFilterTestBase : public testing::Test { + public: + // Generates a vector with monotonically increasing/decreasing unique values + // for the given scalar type. Supports TINYINT, SMALLINT, INTEGER, BIGINT, + // REAL, DOUBLE, TIMESTAMP, and VARCHAR. + // @param type The scalar type of the vector to generate. + // @param size The number of rows in this batch. + // @param globalRowOffset The starting row offset across all batches. + // @param startValue The starting value for the sequence. + // @param totalRows The total number of rows across all batches. + // @param ascending If true, values increase; if false, values decrease. + // @param pool The memory pool to allocate the vector from. + static VectorPtr generateStrictSortedVector( + const TypePtr& type, + size_t size, + size_t globalRowOffset, + int64_t startValue, + size_t totalRows, + bool ascending, + memory::MemoryPool* pool); + + // Generates a vector with monotonically increasing/decreasing values that + // may contain duplicates (non-strictly sorted). Each unique value is + // repeated 1-3 times randomly. + // @param type The scalar type of the vector to generate. + // @param size The number of rows in this batch. + // @param globalRowOffset The starting row offset across all batches. + // @param startValue The starting value for the sequence. + // @param totalRows The total number of rows across all batches. + // @param ascending If true, values increase; if false, values decrease. + // @param pool The memory pool to allocate the vector from. + static VectorPtr generateSortedVector( + const TypePtr& type, + size_t size, + size_t globalRowOffset, + int64_t startValue, + size_t totalRows, + bool ascending, + memory::MemoryPool* pool); + protected: static constexpr int32_t kRowsInGroup = 10'000; @@ -118,7 +158,15 @@ class E2EFilterTestBase : public testing::Test { std::vector makeDataset( std::function customize, bool forRowGroupSkip, - bool withRecursiveNulls); + bool withRecursiveNulls, + const std::vector& indexColumns = {}); + + // Replaces index columns with monotonically increasing unique values. + // This ensures the data has sorted, unique keys for index testing + // without needing to sort and deduplicate random data. + std::vector replaceIndexColumnsWithStrictlySortedData( + const std::vector& batches, + const std::vector& indexColumns); void makeAllNulls(const std::string& fieldName); @@ -208,7 +256,8 @@ class E2EFilterTestBase : public testing::Test { virtual void writeToMemory( const TypePtr& type, const std::vector& batches, - bool forRowGroupSkip) = 0; + bool forRowGroupSkip, + const std::vector& indexColumns = {}) = 0; virtual std::unique_ptr makeReader( const dwio::common::ReaderOptions& opts, @@ -315,7 +364,8 @@ class E2EFilterTestBase : public testing::Test { bool wrapInStruct, const std::vector& filterable, int32_t numCombinations, - bool withRecursiveNulls = true); + bool withRecursiveNulls = true, + const std::vector& indexColumns = {}); void testRunLengthDictionaryScenario( const std::string& columns, @@ -336,6 +386,14 @@ class E2EFilterTestBase : public testing::Test { const std::string& remainingFilter, std::function validationFilter); + void testMetadataFilterImpl( + const std::vector& batches, + common::Subfield filterField, + std::unique_ptr filter, + core::ExpressionEvaluator* evaluator, + core::TypedExprPtr typedExpr, + std::function validationFilter); + protected: void testMetadataFilter(); @@ -356,13 +414,17 @@ class E2EFilterTestBase : public testing::Test { } const size_t kBatchCount = 4; - // kBatchSize must be greater than 10000 for RowGroup skipping test + // kBatchSize must be greater than 10'000 for RowGroup skipping test const size_t kBatchSize = 25'000; std::unique_ptr dataSetBuilder_; std::unique_ptr filterGenerator_; std::shared_ptr rootPool_; std::shared_ptr leafPool_; + std::shared_ptr dataIoStats_ = + std::make_shared(); + std::shared_ptr metadataIoStats_ = + std::make_shared(); std::shared_ptr rowType_; std::string sinkData_; bool useVInts_ = true; diff --git a/velox/dwio/common/tests/utils/FilterGenerator.cpp b/velox/dwio/common/tests/utils/FilterGenerator.cpp index 52cd41fc2b0..420c718c7d7 100644 --- a/velox/dwio/common/tests/utils/FilterGenerator.cpp +++ b/velox/dwio/common/tests/utils/FilterGenerator.cpp @@ -319,14 +319,7 @@ std::string FilterGenerator::specsToString( out << ", "; } first = false; - out << spec.field; - if (spec.filterKind == FilterKind::kIsNull) { - out << " is null"; - } else if (spec.filterKind == FilterKind::kIsNotNull) { - out << " is not null"; - } else { - out << ":" << spec.selectPct << "," << spec.startPct << " "; - } + out << spec.toString(); } return out.str(); } @@ -663,7 +656,7 @@ void pruneRandomSubfield( break; case TypeKind::VARCHAR: case TypeKind::VARBINARY: - stringKeys.push_back( + stringKeys.emplace_back( keys->asUnchecked>()->valueAt(jj)); break; default: diff --git a/velox/dwio/common/tests/utils/FilterGenerator.h b/velox/dwio/common/tests/utils/FilterGenerator.h index 42368d20ffb..2688441b5a0 100644 --- a/velox/dwio/common/tests/utils/FilterGenerator.h +++ b/velox/dwio/common/tests/utils/FilterGenerator.h @@ -16,6 +16,7 @@ #pragma once +#include #include #include @@ -49,6 +50,17 @@ struct FilterSpec { isForRowGroupSkip(isForRowGroupSkip), allowNulls_(allowNulls) {} + std::string toString() const { + return fmt::format( + "FilterSpec(field={}, startPct={}, selectPct={}, filterKind={}, isForRowGroupSkip={}, allowNulls={})", + field, + startPct, + selectPct, + filterKind, + isForRowGroupSkip, + allowNulls_); + } + std::string field; float startPct = 50; float selectPct = 20; diff --git a/velox/dwio/common/tests/utils/UnitLoaderTestTools.cpp b/velox/dwio/common/tests/utils/UnitLoaderTestTools.cpp index e2ec87ae605..0342d0d530e 100644 --- a/velox/dwio/common/tests/utils/UnitLoaderTestTools.cpp +++ b/velox/dwio/common/tests/utils/UnitLoaderTestTools.cpp @@ -90,8 +90,9 @@ bool ReaderMock::loadUnit() { std::vector> ReaderMock::getUnits() { std::vector> units; for (size_t i = 0; i < rowsPerUnit_.size(); ++i) { - units.emplace_back(std::make_unique( - rowsPerUnit_[i], ioSizes_[i], unitsLoaded_, i)); + units.emplace_back( + std::make_unique( + rowsPerUnit_[i], ioSizes_[i], unitsLoaded_, i)); } return units; } diff --git a/velox/dwio/common/tests/utils/UnitLoaderTestTools.h b/velox/dwio/common/tests/utils/UnitLoaderTestTools.h index 9eae97f575c..6c36c10c865 100644 --- a/velox/dwio/common/tests/utils/UnitLoaderTestTools.h +++ b/velox/dwio/common/tests/utils/UnitLoaderTestTools.h @@ -32,16 +32,20 @@ class LoadUnitMock : public LoadUnit { uint64_t rowCount, uint64_t ioSize, std::vector& unitsLoaded, - size_t unitId) + size_t unitId, + std::chrono::milliseconds loadDelay = std::chrono::milliseconds(100)) : rowCount_{rowCount}, ioSize_{ioSize}, unitsLoaded_{unitsLoaded}, - unitId_{unitId} {} + unitId_{unitId}, + loadDelay_(loadDelay) {} ~LoadUnitMock() override = default; void load() override { VELOX_CHECK(!isLoaded()); + // Simulate loading time + std::this_thread::sleep_for(loadDelay_); unitsLoaded_[unitId_] = true; } @@ -67,6 +71,7 @@ class LoadUnitMock : public LoadUnit { uint64_t ioSize_; std::vector& unitsLoaded_; size_t unitId_; + std::chrono::milliseconds loadDelay_; }; class ReaderMock { @@ -82,7 +87,7 @@ class ReaderMock { void seek(uint64_t rowNumber); std::vector unitsLoaded() const { - return {unitsLoaded_.begin(), unitsLoaded_.end()}; + return {unitsLoaded_.cbegin(), unitsLoaded_.cend()}; } private: diff --git a/velox/dwio/common/wrap/CMakeLists.txt b/velox/dwio/common/wrap/CMakeLists.txt new file mode 100644 index 00000000000..1d519452e0b --- /dev/null +++ b/velox/dwio/common/wrap/CMakeLists.txt @@ -0,0 +1,17 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +velox_install_library_headers() + +# Wrapper headers — verify these are still needed and correctly included. +velox_add_library(velox_dwio_common_wrap INTERFACE HEADERS zero-copy-stream-wrapper.h) diff --git a/velox/dwio/dwrf/CMakeLists.txt b/velox/dwio/dwrf/CMakeLists.txt index 613f69ad2f0..1fe671836e1 100644 --- a/velox/dwio/dwrf/CMakeLists.txt +++ b/velox/dwio/dwrf/CMakeLists.txt @@ -22,3 +22,9 @@ elseif(${VELOX_BUILD_TEST_UTILS}) endif() add_subdirectory(utils) add_subdirectory(writer) + +velox_add_library(velox_dwio_dwrf_register_reader INTERFACE HEADERS RegisterDwrfReader.h) + +velox_add_library(velox_dwio_dwrf_register_writer INTERFACE HEADERS RegisterDwrfWriter.h) + +velox_install_library_headers() diff --git a/velox/dwio/dwrf/common/CMakeLists.txt b/velox/dwio/dwrf/common/CMakeLists.txt index 8a3dc1e393e..311b4ac51ac 100644 --- a/velox/dwio/dwrf/common/CMakeLists.txt +++ b/velox/dwio/dwrf/common/CMakeLists.txt @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +add_subdirectory(wrap) + velox_add_library( velox_dwio_dwrf_common ByteRLE.cpp @@ -27,6 +29,25 @@ velox_add_library( Statistics.cpp wrap/orc-proto-wrapper.cpp wrap/dwrf-proto-wrapper.cpp + HEADERS + ByteRLE.h + Checksum.h + Common.h + Compression.h + Config.h + DecoderUtil.h + Decryption.h + EncoderUtil.h + Encryption.h + EncryptionCommon.h + EncryptionSpecification.h + FileMetadata.h + FloatingPointDecoder.h + IntEncoder.h + NextVisitor.h + RLEv1.h + RLEv2.h + Statistics.h ) velox_link_libraries( diff --git a/velox/dwio/dwrf/common/Checksum.h b/velox/dwio/dwrf/common/Checksum.h index 6ca7d205d99..0f31dbe19d8 100644 --- a/velox/dwio/dwrf/common/Checksum.h +++ b/velox/dwio/dwrf/common/Checksum.h @@ -20,8 +20,7 @@ #include "velox/dwio/dwrf/common/wrap/dwrf-proto-wrapper.h" #include -#define XXH_INLINE_ALL -#include +#include "velox/common/base/XxHashInline.h" namespace facebook::velox::dwrf { diff --git a/velox/dwio/dwrf/common/Common.h b/velox/dwio/dwrf/common/Common.h index 783f8a0329d..ae467654c1f 100644 --- a/velox/dwio/dwrf/common/Common.h +++ b/velox/dwio/dwrf/common/Common.h @@ -29,11 +29,11 @@ namespace facebook::velox::dwrf { // Writer version -constexpr folly::StringPiece WRITER_NAME_KEY{"orc.writer.name"}; -constexpr folly::StringPiece WRITER_VERSION_KEY{"orc.writer.version"}; -constexpr folly::StringPiece WRITER_HOSTNAME_KEY{"orc.writer.host"}; -constexpr folly::StringPiece kDwioWriter{"dwio"}; -constexpr folly::StringPiece kPrestoWriter{"presto"}; +constexpr std::string_view kWriterNameKey{"orc.writer.name"}; +constexpr std::string_view kWriterVersionKey{"orc.writer.version"}; +constexpr std::string_view kWriterHostnameKey{"orc.writer.host"}; +constexpr std::string_view kDwioWriter{"dwio"}; +constexpr std::string_view kPrestoWriter{"presto"}; enum class DwrfFormat : uint8_t { kDwrf = 0, diff --git a/velox/dwio/dwrf/common/Compression.h b/velox/dwio/dwrf/common/Compression.h index 0be6a4e0e41..e09ca21a870 100644 --- a/velox/dwio/dwrf/common/Compression.h +++ b/velox/dwio/dwrf/common/Compression.h @@ -105,6 +105,7 @@ inline CompressionOptions getDwrfOrcDecompressionOptions( * @param input The input stream that is the underlying source * @param bufferSize The maximum size of the buffer * @param pool The memory pool + * @param decompressCounter Optional IoCounter for tracking decompression stats */ inline std::unique_ptr createDecompressor( facebook::velox::common::CompressionKind kind, @@ -112,16 +113,20 @@ inline std::unique_ptr createDecompressor( uint64_t bufferSize, memory::MemoryPool& pool, const std::string& streamDebugInfo, - const dwio::common::encryption::Decrypter* decryptr = nullptr) { + const dwio::common::encryption::Decrypter* decryptr = nullptr, + velox::io::IoCounter* decompressCounter = nullptr) { const CompressionOptions& options = getDwrfOrcDecompressionOptions(kind); - return createDecompressor( + return dwio::common::compression::createDecompressor( kind, std::move(input), bufferSize, pool, options, streamDebugInfo, - decryptr); + decryptr, + false, + 0, + decompressCounter); } } // namespace facebook::velox::dwrf diff --git a/velox/dwio/dwrf/common/Config.cpp b/velox/dwio/dwrf/common/Config.cpp index 8ce1ad5f215..b9c556d8a88 100644 --- a/velox/dwio/dwrf/common/Config.cpp +++ b/velox/dwio/dwrf/common/Config.cpp @@ -130,7 +130,7 @@ Config::Entry> Config::MAP_FLAT_COLS( [](const std::string& /* key */, const std::string& val) { std::vector result; if (!val.empty()) { - std::vector pieces; + std::vector pieces; folly::split(',', val, pieces, true); for (const auto& p : pieces) { const auto& trimmedCol = folly::trimWhitespace(p); @@ -182,7 +182,7 @@ Config::Entry>> Config::Entry Config::MAP_FLAT_MAX_KEYS( "orc.map.flat.max.keys", - 20000); + 30000); Config::Entry Config::MAX_DICTIONARY_SIZE( "hive.exec.orc.max.dictionary.size", diff --git a/velox/dwio/dwrf/common/FileMetadata.cpp b/velox/dwio/dwrf/common/FileMetadata.cpp index ccb9f6faa7f..2b3f2f80d49 100644 --- a/velox/dwio/dwrf/common/FileMetadata.cpp +++ b/velox/dwio/dwrf/common/FileMetadata.cpp @@ -37,6 +37,29 @@ CompressionKind orcCompressionToCompressionKind( } VELOX_FAIL("Unknown compression kind: {}", CompressionKind_Name(compression)); } + +static proto::orc::CompressionKind compressionKindToOrcCompression( + CompressionKind compressionKind) { + switch (compressionKind) { + case CompressionKind::CompressionKind_NONE: + return proto::orc::CompressionKind::NONE; + case CompressionKind::CompressionKind_ZLIB: + return proto::orc::CompressionKind::ZLIB; + case CompressionKind::CompressionKind_SNAPPY: + return proto::orc::CompressionKind::SNAPPY; + case CompressionKind::CompressionKind_LZO: + return proto::orc::CompressionKind::LZO; + case CompressionKind::CompressionKind_ZSTD: + return proto::orc::CompressionKind::ZSTD; + case CompressionKind::CompressionKind_LZ4: + return proto::orc::CompressionKind::LZ4; + case CompressionKind::CompressionKind_GZIP: + default: + VELOX_FAIL( + "Unknown compression kind: {}", + compressionKindToString(compressionKind)); + } +} } // namespace detail TypeKind TypeWrapper::kind() const { @@ -102,9 +125,10 @@ TypeKind TypeWrapper::kind() const { } case proto::orc::Type_Kind_CHAR: case proto::orc::Type_Kind_TIMESTAMP_INSTANT: - VELOX_FAIL(fmt::format( - "{} not supported yet.", - proto::orc::Type_Kind_Name(orcPtr()->kind()))); + VELOX_FAIL( + fmt::format( + "{} not supported yet.", + proto::orc::Type_Kind_Name(orcPtr()->kind()))); default: VELOX_FAIL("Unknown type kind: {}", Type_Kind_Name(orcPtr()->kind())); } @@ -116,4 +140,13 @@ common::CompressionKind PostScript::compression() const { : detail::orcCompressionToCompressionKind(orcPtr()->compression()); } +void PostScriptWriteWrapper::setCompression( + common::CompressionKind compressionKind) { + format_ == DwrfFormat::kDwrf + ? dwrfPtr()->set_compression( + static_cast(compressionKind)) + : orcPtr()->set_compression( + detail::compressionKindToOrcCompression(compressionKind)); +} + } // namespace facebook::velox::dwrf diff --git a/velox/dwio/dwrf/common/FileMetadata.h b/velox/dwio/dwrf/common/FileMetadata.h index 87e8c12719f..3665d4b9120 100644 --- a/velox/dwio/dwrf/common/FileMetadata.h +++ b/velox/dwio/dwrf/common/FileMetadata.h @@ -19,6 +19,7 @@ #include "velox/common/base/Exceptions.h" #include "velox/common/compression/Compression.h" +#include "velox/dwio/common/OutputStream.h" #include "velox/dwio/dwrf/common/Common.h" #include "velox/dwio/dwrf/common/wrap/dwrf-proto-wrapper.h" #include "velox/dwio/dwrf/common/wrap/orc-proto-wrapper.h" @@ -43,6 +44,24 @@ class ProtoWrapperBase { const void* const impl_; }; +class ProtoWriteWrapperBase { + protected: + ProtoWriteWrapperBase(DwrfFormat format, void* impl) + : format_{format}, impl_{impl} {} + + DwrfFormat format_; + void* impl_; + + public: + DwrfFormat format() const { + return format_; + } + + inline void* rawProtoPtr() const { + return impl_; + } +}; + /*** * PostScript that takes the ownership of proto::PostScript / *proto::orc::PostScript and provides access to the attributes @@ -93,6 +112,10 @@ class PostScript { : orcPtr()->footerlength(); } + uint64_t metadataLength() const { + return format_ == DwrfFormat::kDwrf ? 0 : orcPtr()->metadatalength(); + } + bool hasCompression() const { return format_ == DwrfFormat::kDwrf ? dwrfPtr()->has_compression() : orcPtr()->has_compression(); @@ -244,6 +267,46 @@ class StripeInformationWrapper : public ProtoWrapperBase { } }; +class ColumnEncodingKindWrapper : public ProtoWrapperBase { + public: + explicit ColumnEncodingKindWrapper(proto::ColumnEncoding_Kind* stream) + : ProtoWrapperBase(DwrfFormat::kDwrf, stream) {} + + explicit ColumnEncodingKindWrapper(proto::orc::ColumnEncoding_Kind* stream) + : ProtoWrapperBase(DwrfFormat::kOrc, stream) {} +}; + +class ColumnEncodingWrapper : public ProtoWrapperBase { + public: + explicit ColumnEncodingWrapper(const proto::ColumnEncoding* columnEncoding) + : ProtoWrapperBase(DwrfFormat::kDwrf, columnEncoding) {} + explicit ColumnEncodingWrapper( + const proto::orc::ColumnEncoding* columnEncoding) + : ProtoWrapperBase(DwrfFormat::kOrc, columnEncoding) {} + + void Clear() {} + + proto::ColumnEncoding_Kind kind() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->kind(); + } + + uint32_t node() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->node(); + } + + private: + // private helper with no format checking + inline const proto::ColumnEncoding* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + + inline const proto::orc::ColumnEncoding* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + class TypeWrapper : public ProtoWrapperBase { public: explicit TypeWrapper(const proto::Type* t) @@ -940,6 +1003,1027 @@ class StripeFooterWrapper : public ProtoWrapperBase { std::shared_ptr orcStripeFooter_ = nullptr; }; +class StripeInformationWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit StripeInformationWriteWrapper( + proto::StripeInformation* stripeInformation) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, stripeInformation) {} + + explicit StripeInformationWriteWrapper( + proto::orc::StripeInformation* stripeInformation) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, stripeInformation) {} + + uint64_t numberOfRows() { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->numberofrows() + : orcPtr()->numberofrows(); + } + + uint64_t rawDataSize() { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->rawdatasize() : 0; + } + + bool hasChecksum() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->has_checksum() : false; + } + + uint64_t checksum() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->checksum() : 0; + } + + void setNumberOfRows(uint64_t stripeRowCount) { + return format_ == DwrfFormat::kDwrf + ? dwrfPtr()->set_numberofrows(stripeRowCount) + : orcPtr()->set_numberofrows(stripeRowCount); + } + + void setRawDataSize(uint64_t rawDataSize) { + if (format_ == DwrfFormat::kDwrf) { + dwrfPtr()->set_rawdatasize(rawDataSize); + } + } + + void setChecksum(int64_t checksum) { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + dwrfPtr()->set_checksum(checksum); + } + + void setGroupSize(uint64_t groupSize) { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + dwrfPtr()->set_groupsize(groupSize); + } + + uint64_t groupSize() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->groupsize(); + } + + void setOffset(uint64_t offset) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_offset(offset) + : orcPtr()->set_offset(offset); + } + + uint64_t offset() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->offset() + : orcPtr()->offset(); + } + + void setIndexLength(uint64_t indexLength) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_indexlength(indexLength) + : orcPtr()->set_indexlength(indexLength); + } + + uint64_t indexLength() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->indexlength() + : orcPtr()->indexlength(); + } + + void setDataLength(uint64_t dataLength) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_datalength(dataLength) + : orcPtr()->set_datalength(dataLength); + } + + uint64_t dataLength() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->datalength() + : orcPtr()->datalength(); + } + + void setFooterLength(uint64_t footerLength) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_footerlength(footerLength) + : orcPtr()->set_footerlength(footerLength); + } + + uint64_t footerLength() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->footerlength() + : orcPtr()->footerlength(); + } + + std::string* addKeyMetadata() { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->add_keymetadata(); + } + + private: + // private helper with no format checking + inline proto::StripeInformation* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::StripeInformation* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class TypeKindWrapper : public ProtoWriteWrapperBase { + public: + explicit TypeKindWrapper(proto::Type_Kind* footer) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, footer) {} + + explicit TypeKindWrapper(proto::orc::Type_Kind* footer) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, footer) {} +}; + +class TypeWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit TypeWriteWrapper(proto::Type* footer) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, footer) {} + + explicit TypeWriteWrapper(proto::orc::Type* footer) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, footer) {} + + const proto::Type* getDwrfPtr() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return reinterpret_cast(rawProtoPtr()); + } + + const proto::orc::Type* getOrcPtr() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kOrc); + return reinterpret_cast(rawProtoPtr()); + } + + void setKind(TypeKindWrapper typeKindWrapper) { + format_ == DwrfFormat::kDwrf + ? dwrfPtr()->set_kind(*reinterpret_cast( + typeKindWrapper.rawProtoPtr())) + : orcPtr()->set_kind(*reinterpret_cast( + typeKindWrapper.rawProtoPtr())); + } + + void setScale(uint32_t scale) { + VELOX_CHECK_EQ(format_, DwrfFormat::kOrc); + orcPtr()->set_scale(scale); + } + + void setPrecision(uint32_t precision) { + VELOX_CHECK_EQ(format_, DwrfFormat::kOrc); + orcPtr()->set_precision(precision); + } + + void addFieldnames(const std::string& fieldName) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->add_fieldnames(fieldName) + : orcPtr()->add_fieldnames(fieldName); + } + + void addSubtypes(int fieldName) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->add_subtypes(fieldName) + : orcPtr()->add_subtypes(fieldName); + } + + private: + // private helper with no format checking + inline proto::Type* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::Type* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class UserMetadataItemWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit UserMetadataItemWriteWrapper( + proto::UserMetadataItem* userMetadataItem) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, userMetadataItem) {} + + explicit UserMetadataItemWriteWrapper( + proto::orc::UserMetadataItem* userMetadataItem) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, userMetadataItem) {} + + void setName(const std::string& name) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_name(name) + : orcPtr()->set_name(name); + } + + void setValue(const std::string& value) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_value(value) + : orcPtr()->set_value(value); + } + + private: + // private helper with no format checking + inline proto::UserMetadataItem* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::UserMetadataItem* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class BucketStatisticsWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit BucketStatisticsWriteWrapper( + proto::BucketStatistics* bucketStatistics) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, bucketStatistics) {} + + explicit BucketStatisticsWriteWrapper( + proto::orc::BucketStatistics* bucketStatistics) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, bucketStatistics) {} + + int countSize() { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->count_size() + : orcPtr()->count_size(); + } + + void addCount(uint64_t count) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->add_count(count) + : orcPtr()->add_count(count); + } + + private: + // private helper with no format checking + inline proto::BucketStatistics* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::BucketStatistics* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class IntegerStatisticsWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit IntegerStatisticsWriteWrapper( + proto::IntegerStatistics* integerStatistics) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, integerStatistics) {} + + explicit IntegerStatisticsWriteWrapper( + proto::orc::IntegerStatistics* integerStatistics) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, integerStatistics) {} + + void setSum(int64_t sum) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_sum(sum) + : orcPtr()->set_sum(sum); + } + + void setMinimum(int64_t minimum) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_minimum(minimum) + : orcPtr()->set_minimum(minimum); + } + + void setMaximum(int64_t maximum) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_maximum(maximum) + : orcPtr()->set_maximum(maximum); + } + + private: + // private helper with no format checking + inline proto::IntegerStatistics* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::IntegerStatistics* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class DoubleStatisticsWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit DoubleStatisticsWriteWrapper( + proto::DoubleStatistics* doubleStatistics) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, doubleStatistics) {} + + explicit DoubleStatisticsWriteWrapper( + proto::orc::DoubleStatistics* doubleStatistics) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, doubleStatistics) {} + + void setSum(double sum) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_sum(sum) + : orcPtr()->set_sum(sum); + } + + void setMinimum(double minimum) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_minimum(minimum) + : orcPtr()->set_minimum(minimum); + } + + void setMaximum(double maximum) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_maximum(maximum) + : orcPtr()->set_maximum(maximum); + } + + private: + // private helper with no format checking + inline proto::DoubleStatistics* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::DoubleStatistics* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class StringStatisticsWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit StringStatisticsWriteWrapper( + proto::StringStatistics* stringStatistics) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, stringStatistics) {} + + explicit StringStatisticsWriteWrapper( + proto::orc::StringStatistics* stringStatistics) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, stringStatistics) {} + + void setSum(uint64_t sum) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_sum(sum) + : orcPtr()->set_sum(sum); + } + + void setMinimum(std::string minimum) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_minimum(minimum) + : orcPtr()->set_minimum(minimum); + } + + void setMaximum(std::string maximum) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_maximum(maximum) + : orcPtr()->set_maximum(maximum); + } + + private: + // private helper with no format checking + inline proto::StringStatistics* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::StringStatistics* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class BinaryStatisticsWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit BinaryStatisticsWriteWrapper( + proto::BinaryStatistics* binaryStatistics) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, binaryStatistics) {} + + explicit BinaryStatisticsWriteWrapper( + proto::orc::BinaryStatistics* binaryStatistics) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, binaryStatistics) {} + + void setSum(uint64_t sum) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_sum(sum) + : orcPtr()->set_sum(sum); + } + + private: + // private helper with no format checking + inline proto::BinaryStatistics* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::BinaryStatistics* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class ColumnStatisticsWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit ColumnStatisticsWriteWrapper(proto::ColumnStatistics* footer) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, footer) {} + + explicit ColumnStatisticsWriteWrapper(proto::orc::ColumnStatistics* footer) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, footer) {} + + void setSize(uint64_t size) { + if (format_ == DwrfFormat::kDwrf) { + dwrfPtr()->set_size(size); + } + } + + void setHasNull(bool hasNull) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_hasnull(hasNull) + : orcPtr()->set_hasnull(hasNull); + } + + void setNumberOfValues(uint64_t numberOfValues) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_numberofvalues(numberOfValues) + : orcPtr()->set_numberofvalues(numberOfValues); + } + + void setRawSize(uint64_t rawSize) { + if (format_ == DwrfFormat::kDwrf) { + dwrfPtr()->set_rawsize(rawSize); + } + } + + uint64_t getRawSize() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->rawsize(); + } + + uint64_t getSize() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->size(); + } + + bool hasMapStatistics() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->has_mapstatistics(); + } + + proto::MapStatistics* mutableMapStatistics() { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->mutable_mapstatistics(); + } + + BinaryStatisticsWriteWrapper mutableBinaryStatistics() { + return format_ == DwrfFormat::kDwrf + ? BinaryStatisticsWriteWrapper(dwrfPtr()->mutable_binarystatistics()) + : BinaryStatisticsWriteWrapper(orcPtr()->mutable_binarystatistics()); + } + + StringStatisticsWriteWrapper mutableStringStatistics() { + return format_ == DwrfFormat::kDwrf + ? StringStatisticsWriteWrapper(dwrfPtr()->mutable_stringstatistics()) + : StringStatisticsWriteWrapper(orcPtr()->mutable_stringstatistics()); + } + + DoubleStatisticsWriteWrapper mutableDoubleStatistics() { + return format_ == DwrfFormat::kDwrf + ? DoubleStatisticsWriteWrapper(dwrfPtr()->mutable_doublestatistics()) + : DoubleStatisticsWriteWrapper(orcPtr()->mutable_doublestatistics()); + } + + IntegerStatisticsWriteWrapper mutableIntegerStatistics() { + return format_ == DwrfFormat::kDwrf + ? IntegerStatisticsWriteWrapper(dwrfPtr()->mutable_intstatistics()) + : IntegerStatisticsWriteWrapper(orcPtr()->mutable_intstatistics()); + } + + proto::orc::DateStatistics* mutableDateStatistics() { + VELOX_CHECK_EQ(format_, DwrfFormat::kOrc); + return orcPtr()->mutable_datestatistics(); + } + + proto::orc::TimestampStatistics* mutableTimestampStatistics() { + VELOX_CHECK_EQ(format_, DwrfFormat::kOrc); + return orcPtr()->mutable_timestampstatistics(); + } + + proto::orc::DecimalStatistics* mutableDecimalStatistics() { + VELOX_CHECK_EQ(format_, DwrfFormat::kOrc); + return orcPtr()->mutable_decimalstatistics(); + } + + BucketStatisticsWriteWrapper mutableBucketStatistics() { + return format_ == DwrfFormat::kDwrf + ? BucketStatisticsWriteWrapper(dwrfPtr()->mutable_bucketstatistics()) + : BucketStatisticsWriteWrapper(orcPtr()->mutable_bucketstatistics()); + } + + void reset(const proto::ColumnStatistics* dwrfStatistics) { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + VELOX_CHECK_NOT_NULL(dwrfStatistics); + dwrfPtr()->CopyFrom(*dwrfStatistics); + } + + private: + // private helper with no format checking + inline proto::ColumnStatistics* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::ColumnStatistics* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class FooterWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit FooterWriteWrapper(proto::Footer* footer) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, footer) {} + + explicit FooterWriteWrapper(proto::orc::Footer* footer) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, footer) {} + + const proto::Footer* getDwrfPtr() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return reinterpret_cast(rawProtoPtr()); + } + + proto::Footer* getMutableDwrfPtr() { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return reinterpret_cast(rawProtoPtr()); + } + + const proto::orc::Footer* getOrcPtr() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kOrc); + return reinterpret_cast(rawProtoPtr()); + } + + const StripeInformationWriteWrapper addStripes() const { + return format_ == DwrfFormat::kDwrf + ? StripeInformationWriteWrapper(dwrfPtr()->add_stripes()) + : StripeInformationWriteWrapper(orcPtr()->add_stripes()); + } + + void setHeaderLength(uint64_t headerLength) const { + return format_ == DwrfFormat::kDwrf + ? dwrfPtr()->set_headerlength(headerLength) + : orcPtr()->set_headerlength(headerLength); + } + + void setContentLength(uint64_t contentLength) const { + return format_ == DwrfFormat::kDwrf + ? dwrfPtr()->set_contentlength(contentLength) + : orcPtr()->set_contentlength(contentLength); + } + + void setRowIndexStride(uint32_t rowIndexStride) const { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_rowindexstride(rowIndexStride) + : orcPtr()->set_rowindexstride(rowIndexStride); + } + + void setNumberOfRows(uint64_t numberOfRows) const { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_numberofrows(numberOfRows) + : orcPtr()->set_numberofrows(numberOfRows); + } + + void setRawDataSize(uint64_t numberOfRows) const { + if (format_ == DwrfFormat::kDwrf) { + dwrfPtr()->set_rawdatasize(numberOfRows); + } + } + + void setWriter(uint32_t writer) const { + if (format_ == DwrfFormat::kOrc) { + orcPtr()->set_writer(writer); + } + } + + void setCheckSumAlgorithm(proto::ChecksumAlgorithm checksum) const { + if (format_ == DwrfFormat::kDwrf) { + dwrfPtr()->set_checksumalgorithm(checksum); + } + } + + void addStripeCacheOffsets(uint32_t stripeCacheOffsets) const { + if (format_ == DwrfFormat::kDwrf) { + dwrfPtr()->add_stripecacheoffsets(stripeCacheOffsets); + } else { + // + } + } + + TypeWriteWrapper addTypes() const { + return format_ == DwrfFormat::kDwrf + ? TypeWriteWrapper(dwrfPtr()->add_types()) + : TypeWriteWrapper(orcPtr()->add_types()); + } + + UserMetadataItemWriteWrapper addMetadata() const { + return format_ == DwrfFormat::kDwrf + ? UserMetadataItemWriteWrapper(dwrfPtr()->add_metadata()) + : UserMetadataItemWriteWrapper(orcPtr()->add_metadata()); + } + + ColumnStatisticsWriteWrapper addStatistics() const { + return format_ == DwrfFormat::kDwrf + ? ColumnStatisticsWriteWrapper(dwrfPtr()->add_statistics()) + : ColumnStatisticsWriteWrapper(orcPtr()->add_statistics()); + } + + const ::google::protobuf::RepeatedPtrField< + ::facebook::velox::dwrf::proto::ColumnStatistics>& + statistics() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->statistics(); + } + + int typesSize() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->types_size() + : orcPtr()->types_size(); + } + + int statisticsSize() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->statistics_size() + : orcPtr()->statistics_size(); + } + + uint64_t contentLength() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->contentlength() + : orcPtr()->contentlength(); + } + + uint64_t numberOfRows() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->numberofrows() + : orcPtr()->numberofrows(); + } + + // DWRF-specific fields + inline uint64_t rawDataSize() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->rawdatasize(); + } + + inline int stripesSize() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->stripes_size(); + } + + inline proto::Encryption* mutableEncryption() { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->mutable_encryption(); + } + + private: + // private helper with no format checking + inline proto::Footer* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::Footer* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class RowIndexEntryWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit RowIndexEntryWriteWrapper(proto::RowIndexEntry* rowIndexEntry) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, rowIndexEntry) {} + + explicit RowIndexEntryWriteWrapper(proto::orc::RowIndexEntry* rowIndexEntry) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, rowIndexEntry) {} + + ColumnStatisticsWriteWrapper mutableStatistics() { + return format_ == DwrfFormat::kDwrf + ? ColumnStatisticsWriteWrapper(dwrfPtr()->mutable_statistics()) + : ColumnStatisticsWriteWrapper(orcPtr()->mutable_statistics()); + } + + bool hasStatistics() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->has_statistics() + : orcPtr()->has_statistics(); + } + + void mutablePositions(int start, int num) { + return format_ == DwrfFormat::kDwrf + ? dwrfPtr()->mutable_positions()->ExtractSubrange(start, num, nullptr) + : orcPtr()->mutable_positions()->ExtractSubrange(start, num, nullptr); + } + + void addPositions(uint64_t pos) { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->add_positions(pos) + : orcPtr()->add_positions(pos); + } + + uint64_t positionsSize() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->positions_size() + : orcPtr()->positions_size(); + } + + const ::google::protobuf::RepeatedField positions() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->positions() + : orcPtr()->positions(); + } + + void clear() { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->Clear() + : orcPtr()->Clear(); + } + + private: + // private helper with no format checking + inline proto::RowIndexEntry* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::RowIndexEntry* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class RowIndexWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit RowIndexWriteWrapper(proto::RowIndex* rowIndex) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, rowIndex) {} + + explicit RowIndexWriteWrapper(proto::orc::RowIndex* rowIndex) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, rowIndex) {} + + void addEntry(std::unique_ptr& entry) { + if (format_ == DwrfFormat::kDwrf) { + auto e = reinterpret_cast(entry->rawProtoPtr()); + *dwrfPtr()->add_entry() = *e; + } else { + auto e = + reinterpret_cast(entry->rawProtoPtr()); + *orcPtr()->add_entry() = *e; + } + } + + int32_t entrySize() { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->entry_size() + : orcPtr()->entry_size(); + } + + RowIndexEntryWriteWrapper mutableEntry(int32_t index) { + return format_ == DwrfFormat::kDwrf + ? RowIndexEntryWriteWrapper(dwrfPtr()->mutable_entry(index)) + : RowIndexEntryWriteWrapper(orcPtr()->mutable_entry(index)); + } + + void SerializeToZeroCopyStream( + dwio::common::BufferedOutputStream* out) const { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->SerializeToZeroCopyStream(out) + : orcPtr()->SerializeToZeroCopyStream(out); + } + + void clear() { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->Clear() + : orcPtr()->Clear(); + } + + private: + // private helper with no format checking + inline proto::RowIndex* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::RowIndex* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class ColumnEncodingWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit ColumnEncodingWriteWrapper(proto::ColumnEncoding* stream) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, stream) {} + + explicit ColumnEncodingWriteWrapper(proto::orc::ColumnEncoding* stream) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, stream) {} + + void setKind(ColumnEncodingKindWrapper columnEncodingKindWrapper) { + format_ == DwrfFormat::kDwrf + ? dwrfPtr()->set_kind( + *reinterpret_cast( + columnEncodingKindWrapper.rawProtoPtr())) + : orcPtr()->set_kind( + *reinterpret_cast( + columnEncodingKindWrapper.rawProtoPtr())); + } + + void setDictionarySize(uint32_t dictionarySize) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_dictionarysize(dictionarySize) + : orcPtr()->set_dictionarysize(dictionarySize); + } + + void setNode(uint32_t node) { + if (format_ == DwrfFormat::kDwrf) { + dwrfPtr()->set_node(node); + } + } + + void setSequence(uint32_t sequence) { + if (format_ == DwrfFormat::kDwrf) { + dwrfPtr()->set_sequence(sequence); + } + } + + proto::KeyInfo* mutableKey() { + return dwrfPtr()->mutable_key(); + } + + void Clear() { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->Clear() : orcPtr()->Clear(); + } + + void reset(const proto::ColumnEncoding* dwrfEncoding) { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + VELOX_CHECK_NOT_NULL(dwrfEncoding); + dwrfPtr()->CopyFrom(*dwrfEncoding); + } + + private: + // private helper with no format checking + inline proto::ColumnEncoding* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::ColumnEncoding* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class StreamWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit StreamWriteWrapper(proto::Stream* stream) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, stream) {} + + explicit StreamWriteWrapper(proto::orc::Stream* stream) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, stream) {} + + void setOffset(uint64_t offset) { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + dwrfPtr()->set_offset(offset); + } + + void setKind(const StreamKind& kind) { + format_ == DwrfFormat::kDwrf + ? dwrfPtr()->set_kind(static_cast(kind)) + : orcPtr()->set_kind(static_cast(kind)); + } + + void setColumn(uint32_t column) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_column(column) + : orcPtr()->set_column(column); + } + + void setLength(uint64_t length) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_length(length) + : orcPtr()->set_length(length); + } + + void setNode(uint32_t node) { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + dwrfPtr()->set_node(node); + } + + void setSequence(uint32_t sequence) { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + dwrfPtr()->set_sequence(sequence); + } + + void setUseVints(bool useVints) { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + dwrfPtr()->set_usevints(useVints); + } + + private: + // private helper with no format checking + inline proto::Stream* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::Stream* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class StripeEncryptionGroupWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit StripeEncryptionGroupWriteWrapper( + proto::StripeEncryptionGroup* stripeFooter = nullptr) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, stripeFooter) {} + + // See https://orc.apache.org/specification/ORCv1/ + explicit StripeEncryptionGroupWriteWrapper( + proto::orc::StripeEncryptionVariant* stripeFooter) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, stripeFooter) {} + + void encoding( + std::vector& columnEncodingWrappers) const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + for (const proto::ColumnEncoding& encoding : dwrfPtr()->encoding()) { + auto ce = ColumnEncodingWrapper(&encoding); + columnEncodingWrappers.emplace_back(ce); + } + } + + ColumnEncodingWriteWrapper addEncoding() { + return format_ == DwrfFormat::kDwrf + ? ColumnEncodingWriteWrapper(dwrfPtr()->add_encoding()) + : ColumnEncodingWriteWrapper(orcPtr()->add_encoding()); + } + + StreamWriteWrapper addStreams() { + return format_ == DwrfFormat::kDwrf + ? StreamWriteWrapper(dwrfPtr()->add_streams()) + : StreamWriteWrapper(orcPtr()->add_streams()); + } + + void SerializeToZeroCopyStream( + dwio::common::BufferedOutputStream* output) const { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->SerializeToZeroCopyStream(output) + : orcPtr()->SerializeToZeroCopyStream(output); + } + + private: + // private helper with no format checking + inline proto::StripeEncryptionGroup* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::StripeEncryptionVariant* orcPtr() const { + return reinterpret_cast( + rawProtoPtr()); + } +}; + +class StripeFooterWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit StripeFooterWriteWrapper(proto::StripeFooter* stripeFooter) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, stripeFooter) {} + + explicit StripeFooterWriteWrapper(proto::orc::StripeFooter* stripeFooter) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, stripeFooter) {} + + void encoding( + std::vector& columnEncodingWrappers) const { + if (format_ == DwrfFormat::kDwrf) { + for (const proto::ColumnEncoding& encoding : dwrfPtr()->encoding()) { + auto ce = ColumnEncodingWrapper(&encoding); + columnEncodingWrappers.emplace_back(ce); + } + } else { + for (const proto::orc::ColumnEncoding& encoding : orcPtr()->columns()) { + auto ce = ColumnEncodingWrapper(&encoding); + columnEncodingWrappers.emplace_back(ce); + } + } + } + + void setWriterTimezone() const { + if (format_ == DwrfFormat::kOrc) { + // orcPtr()->set_writertimezone("Asia/Shanghai"); + } + } + + StreamWriteWrapper addStreams() { + return format_ == DwrfFormat::kDwrf + ? StreamWriteWrapper(dwrfPtr()->add_streams()) + : StreamWriteWrapper(orcPtr()->add_streams()); + } + + std::string* addEncryptionGroups() { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->add_encryptiongroups(); + } + + ColumnEncodingWriteWrapper addEncoding() { + return format_ == DwrfFormat::kDwrf + ? ColumnEncodingWriteWrapper(dwrfPtr()->add_encoding()) + : ColumnEncodingWriteWrapper(orcPtr()->add_columns()); + } + + void SerializeToZeroCopyStream( + dwio::common::BufferedOutputStream* output) const { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->SerializeToZeroCopyStream(output) + : orcPtr()->SerializeToZeroCopyStream(output); + } + + inline proto::StripeFooter* dwrfPtr() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return reinterpret_cast(rawProtoPtr()); + } + + private: + // private helper with no format checking + inline proto::orc::StripeFooter* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class PostScriptWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit PostScriptWriteWrapper(proto::PostScript* postScript) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, postScript) {} + + explicit PostScriptWriteWrapper(proto::orc::PostScript* postScript) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, postScript) {} + + void setWriterVersion(uint32_t writerVersion) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_writerversion(writerVersion) + : orcPtr()->set_writerversion(6); + } + + void addVersion(uint32_t version) { + if (format_ == DwrfFormat::kOrc) { + orcPtr()->add_version(version); + } + } + + void setFooterLength(uint64_t footerLength) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_footerlength(footerLength) + : orcPtr()->set_footerlength(footerLength); + } + + void setCompression(common::CompressionKind compressionKind); + + void setCompressionBlockSize(uint64_t compressionBlockSize) { + format_ == DwrfFormat::kDwrf + ? dwrfPtr()->set_compressionblocksize(compressionBlockSize) + : orcPtr()->set_compressionblocksize(compressionBlockSize); + } + + void setCacheMode(StripeCacheMode cacheMode) { + if (format_ == DwrfFormat::kDwrf) { + dwrfPtr()->set_cachemode(static_cast(cacheMode)); + } + } + + void setCacheSize(uint32_t cacheSize) { + if (format_ == DwrfFormat::kDwrf) { + dwrfPtr()->set_cachesize(cacheSize); + } + } + + void setMetaDataLength(uint64_t metaDataLength) { + if (format_ == DwrfFormat::kOrc) { + orcPtr()->set_metadatalength(metaDataLength); + } + } + + void SerializeToZeroCopyStream( + dwio::common::BufferedOutputStream* out) const { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->SerializeToZeroCopyStream(out) + : orcPtr()->SerializeToZeroCopyStream(out); + } + + private: + // private helper with no format checking + inline proto::PostScript* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::PostScript* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + } // namespace facebook::velox::dwrf template <> diff --git a/velox/dwio/dwrf/common/IntEncoder.h b/velox/dwio/dwrf/common/IntEncoder.h index 1051967bb39..bba9060a183 100644 --- a/velox/dwio/dwrf/common/IntEncoder.h +++ b/velox/dwio/dwrf/common/IntEncoder.h @@ -341,8 +341,9 @@ template case 57 ... 63: return writeVarint<1>(value, buffer); } - DWIO_RAISE(folly::sformat( - "Unexpected leading zeros {} for value {}", leadingZeros, value)); + DWIO_RAISE( + folly::sformat( + "Unexpected leading zeros {} for value {}", leadingZeros, value)); } template diff --git a/velox/dwio/dwrf/common/RLEv1.h b/velox/dwio/dwrf/common/RLEv1.h index 082b57c10e0..62ab5d444f1 100644 --- a/velox/dwio/dwrf/common/RLEv1.h +++ b/velox/dwio/dwrf/common/RLEv1.h @@ -417,7 +417,7 @@ class RleDecoderV1 : public dwio::common::IntDecoder { rows + rowIndex, std::min(remainingValues_, numRows - rowIndex)); const auto endOfRun = currentRow + remainingValues_; - const auto bound = std::lower_bound(range.begin(), range.end(), endOfRun); + const auto bound = std::lower_bound(range.cbegin(), range.cend(), endOfRun); return std::make_pair(bound - range.begin(), bound[-1] - currentRow + 1); } diff --git a/velox/dwio/dwrf/common/Statistics.h b/velox/dwio/dwrf/common/Statistics.h index 864fb08db37..1fb1e78c2b3 100644 --- a/velox/dwio/dwrf/common/Statistics.h +++ b/velox/dwio/dwrf/common/Statistics.h @@ -19,7 +19,6 @@ #include "velox/dwio/common/Statistics.h" #include "velox/dwio/dwrf/common/Common.h" #include "velox/dwio/dwrf/common/FileMetadata.h" -#include "velox/dwio/dwrf/common/wrap/dwrf-proto-wrapper.h" namespace facebook::velox::dwrf { diff --git a/velox/dwio/dwrf/common/wrap/CMakeLists.txt b/velox/dwio/dwrf/common/wrap/CMakeLists.txt new file mode 100644 index 00000000000..74e338b7507 --- /dev/null +++ b/velox/dwio/dwrf/common/wrap/CMakeLists.txt @@ -0,0 +1,24 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +velox_install_library_headers() + +# Wrapper headers — verify these are still needed and correctly included. +velox_add_library( + velox_dwio_dwrf_common_wrap + INTERFACE + HEADERS + coded-stream-wrapper.h + dwrf-proto-wrapper.h + orc-proto-wrapper.h +) diff --git a/velox/dwio/dwrf/reader/CMakeLists.txt b/velox/dwio/dwrf/reader/CMakeLists.txt index 8b6f04ec962..a5a38a0ab73 100644 --- a/velox/dwio/dwrf/reader/CMakeLists.txt +++ b/velox/dwio/dwrf/reader/CMakeLists.txt @@ -34,6 +34,32 @@ velox_add_library( StripeDictionaryCache.cpp StripeReaderBase.cpp StripeStream.cpp + HEADERS + BinaryStreamReader.h + ColumnReader.h + ConstantColumnReader.h + DwrfData.h + DwrfReader.h + EncodingContext.h + FlatMapColumnReader.h + ReaderBase.h + SelectiveByteRleColumnReader.h + SelectiveDecimalColumnReader.h + SelectiveDwrfReader.h + SelectiveFlatMapColumnReader.h + SelectiveFloatingPointColumnReader.h + SelectiveIntegerDictionaryColumnReader.h + SelectiveIntegerDirectColumnReader.h + SelectiveRepeatedColumnReader.h + SelectiveStringDictionaryColumnReader.h + SelectiveStringDirectColumnReader.h + SelectiveStructColumnReader.h + SelectiveTimestampColumnReader.h + StreamLabels.h + StripeDictionaryCache.h + StripeMetadataCache.h + StripeReaderBase.h + StripeStream.h ) velox_link_libraries( diff --git a/velox/dwio/dwrf/reader/ColumnReader.cpp b/velox/dwio/dwrf/reader/ColumnReader.cpp index 5699694ef7b..0e20372722c 100644 --- a/velox/dwio/dwrf/reader/ColumnReader.cpp +++ b/velox/dwio/dwrf/reader/ColumnReader.cpp @@ -1311,7 +1311,7 @@ void StringDictionaryColumnReader::loadStrideDictionary() { if (strideDictCount_ > 0) { // seek stride dictionary related streams std::vector pos( - positions.begin() + positionOffset_, positions.end()); + positions.cbegin() + positionOffset_, positions.cend()); dwio::common::PositionProvider pp(pos); strideDictStream_->seekToPosition(pp); strideDictLengthDecoder_->seekToRowGroup(pp); @@ -2461,8 +2461,9 @@ std::unique_ptr buildByteRleColumnReader( RleDecoderFactory::get(), std::move(flatMapContext)); default: - DWIO_RAISE(fmt::format( - "Unsupported upcast to typekind: {}", requestedType->toString())); + DWIO_RAISE( + fmt::format( + "Unsupported upcast to typekind: {}", requestedType->toString())); } } @@ -2502,9 +2503,10 @@ std::unique_ptr buildTypedIntegerColumnReader( numBytes, std::move(flatMapContext)); default: - DWIO_RAISE(fmt::format( - "Unsupported requested integral type: {}", - requestedType->toString())); + DWIO_RAISE( + fmt::format( + "Unsupported requested integral type: {}", + requestedType->toString())); } } diff --git a/velox/dwio/dwrf/reader/DwrfData.h b/velox/dwio/dwrf/reader/DwrfData.h index b4237fef8b9..04fa978a96e 100644 --- a/velox/dwio/dwrf/reader/DwrfData.h +++ b/velox/dwio/dwrf/reader/DwrfData.h @@ -100,7 +100,7 @@ class DwrfData : public dwio::common::FormatData { static std::vector toPositionsInner( const proto::RowIndexEntry& entry) { return std::vector( - entry.positions().begin(), entry.positions().end()); + entry.positions().cbegin(), entry.positions().cend()); } memory::MemoryPool& memoryPool_; diff --git a/velox/dwio/dwrf/reader/DwrfReader.cpp b/velox/dwio/dwrf/reader/DwrfReader.cpp index 4f0ca050d58..db2d33b7751 100644 --- a/velox/dwio/dwrf/reader/DwrfReader.cpp +++ b/velox/dwio/dwrf/reader/DwrfReader.cpp @@ -19,6 +19,7 @@ #include #include "velox/dwio/common/OnDemandUnitLoader.h" +#include "velox/dwio/common/ParallelUnitLoader.h" #include "velox/dwio/common/TypeUtils.h" #include "velox/dwio/common/exception/Exception.h" #include "velox/dwio/dwrf/reader/ColumnReader.h" @@ -38,24 +39,25 @@ using dwio::common::UnitLoaderFactory; class DwrfUnit : public LoadUnit { public: DwrfUnit( - const StripeReaderBase& stripeReaderBase, + std::shared_ptr readerBase, const StrideIndexProvider& strideIndexProvider, - dwio::common::ColumnReaderStatistics& columnReaderStatistics, + std::shared_ptr columnReaderStats, uint32_t stripeIndex, std::shared_ptr columnSelector, - const std::shared_ptr& projectedNodes, + std::shared_ptr projectedNodes, RowReaderOptions options, - const dwio::common::ColumnReaderOptions& columnReaderOptions) - : stripeReaderBase_{stripeReaderBase}, + dwio::common::ColumnReaderOptions columnReaderOptions) + : stripeReaderBase_{readerBase}, + memoryPool_(readerBase->memoryPool().shared_from_this()), strideIndexProvider_{strideIndexProvider}, - columnReaderStatistics_{&columnReaderStatistics}, + columnReaderStats_{std::move(columnReaderStats)}, stripeIndex_{stripeIndex}, columnSelector_{std::move(columnSelector)}, - projectedNodes_{projectedNodes}, + projectedNodes_{std::move(projectedNodes)}, options_{std::move(options)}, - columnReaderOptions_{columnReaderOptions}, + columnReaderOptions_{std::move(columnReaderOptions)}, stripeInfo_{ - stripeReaderBase.getReader().footer().stripes(stripeIndex_)} {} + stripeReaderBase_.getReader().footer().stripes(stripeIndex_)} {} ~DwrfUnit() override = default; @@ -85,14 +87,25 @@ class DwrfUnit : public LoadUnit { void loadDecoders(); // Immutables - const StripeReaderBase& stripeReaderBase_; + const StripeReaderBase stripeReaderBase_; + // Not used in DwrfUnit directly, it is to keep memory pool alive for + // readerBase + const std::shared_ptr memoryPool_; + + // SAFETY: This reference is safe despite DwrfUnit potentially outliving + // DwrfRowReader during async operations. The reference is only STORED (not + // dereferenced) during load() path. Actual dereferencing via + // getStrideIndex() only happens during synchronous data reading in + // ColumnReader::next(), where DwrfRowReader is guaranteed to be alive. const StrideIndexProvider& strideIndexProvider_; - dwio::common::ColumnReaderStatistics* const columnReaderStatistics_; + + const std::shared_ptr + columnReaderStats_; const uint32_t stripeIndex_; const std::shared_ptr columnSelector_; const std::shared_ptr projectedNodes_; const RowReaderOptions options_; - const dwio::common::ColumnReaderOptions& columnReaderOptions_; + const dwio::common::ColumnReaderOptions columnReaderOptions_; const StripeInformationWrapper stripeInfo_; // Mutables @@ -152,7 +165,8 @@ void DwrfUnit::ensureDecoders() { stripeInfo_.offset(), stripeInfo_.numberOfRows(), strideIndexProvider_, - stripeIndex_); + stripeIndex_, + columnReaderStats_.get()); auto* scanSpec = options_.scanSpec().get(); const auto& fileType = stripeReaderBase_.getReader().schemaWithId(); @@ -168,7 +182,7 @@ void DwrfUnit::ensureDecoders() { fileType, *stripeStreams_, streamLabels, - *columnReaderStatistics_, + *columnReaderStats_, scanSpec, flatMapContext, /*isRoot=*/true); @@ -227,6 +241,15 @@ void makeProjectedNodes( } } +const velox::common::ScanSpec* getChildScanSpec( + const velox::common::ScanSpec* scanSpec, + const TypeWrapper& nodeType, + int32_t childIdx) { + return scanSpec != nullptr && childIdx < nodeType.fieldNamesSize() + ? scanSpec->childByName(nodeType.fieldNames(childIdx)) + : nullptr; +} + } // namespace DwrfRowReader::DwrfRowReader( @@ -242,7 +265,11 @@ DwrfRowReader::DwrfRowReader( reader->schema()))}, decodingTimeCallback_{options_.decodingTimeCallback()}, strideIndex_{0}, + columnReaderStats_( + std::make_shared()), currentUnit_{nullptr} { + columnReaderStats_->initColumnStatsCollection( + *getReader().schemaWithId(), options_); const auto& fileFooter = getReader().footer(); const uint32_t numberOfStripes = fileFooter.stripesSize(); currentStripe_ = numberOfStripes; @@ -328,22 +355,30 @@ std::unique_ptr DwrfRowReader::getUnitLoader() { std::vector> loadUnits; loadUnits.reserve(stripeCeiling_ - firstStripe_); for (auto stripe = firstStripe_; stripe < stripeCeiling_; ++stripe) { - loadUnits.emplace_back(std::make_unique( - /*stripeReaderBase=*/*this, - /*strideIndexProvider=*/*this, - columnReaderStatistics_, - stripe, - columnSelector_, - projectedNodes_, - options_, - columnReaderOptions_)); + loadUnits.emplace_back( + std::make_unique( + /*readerBase=*/readerBaseShared(), + /*strideIndexProvider=*/*this, + columnReaderStats_, + stripe, + columnSelector_, + projectedNodes_, + options_, + columnReaderOptions_)); } std::shared_ptr unitLoaderFactory = options_.unitLoaderFactory(); if (!unitLoaderFactory) { - unitLoaderFactory = - std::make_shared( - options_.blockedOnIoCallback()); + if (loadUnits.size() > 1 && options_.parallelUnitLoadCount() > 1 && + options_.ioExecutor() != nullptr) { + unitLoaderFactory = + std::make_shared( + options_.ioExecutor(), options_.parallelUnitLoadCount()); + } else { + unitLoaderFactory = + std::make_shared( + options_.blockedOnIoCallback()); + } } return unitLoaderFactory->create(std::move(loadUnits), 0); } @@ -567,8 +602,10 @@ int64_t DwrfRowReader::nextRowNumber() { const auto skipRows = getReader().randomSkip()->nextSkip(); if (skipRows >= numStripeRows) { getReader().randomSkip()->consume(numStripeRows); - const auto numStrides = bits::divRoundUp(numStripeRows, strideSize); - skippedStrides_ += numStrides; + if (strideSize > 0) { + skippedStrides_ += static_cast( + bits::divRoundUp(numStripeRows, strideSize)); + } goto advanceToNextStripe; } } @@ -578,6 +615,9 @@ int64_t DwrfRowReader::nextRowNumber() { checkSkipStrides(strideSize); if (currentRowInStripe_ < rowsInCurrentStripe_) { + if (strideSize > 0 && currentRowInStripe_ % strideSize == 0) { + ++processedStrides_; + } nextRowNumber_ = firstRowOfStripe_[currentStripe_] + currentRowInStripe_; return *nextRowNumber_; } @@ -632,6 +672,8 @@ uint64_t DwrfRowReader::next( } else { previousRow_ = 0; } + // Collect unit loader stats at the end. + unitLoadStats_ = unitLoader_->stats(); return 0; } @@ -668,7 +710,6 @@ void DwrfRowReader::loadCurrentStripe() { const auto loadUnitIdx = currentStripe_ - firstStripe_; currentUnit_ = castDwrfUnit(&unitLoader_->getLoadedUnit(loadUnitIdx)); rowsInCurrentStripe_ = currentUnit_->getNumRows(); - ++processedStrides_; } size_t DwrfRowReader::estimatedReaderMemory() const { @@ -676,15 +717,17 @@ size_t DwrfRowReader::estimatedReaderMemory() const { return 2 * DwrfReader::getMemoryUse(getReader(), -1, *columnSelector_); } -bool DwrfRowReader::shouldReadNode(uint32_t nodeId) const { - if (columnSelector_) { - return columnSelector_->shouldReadNode(nodeId); - } - return projectedNodes_->contains(nodeId); +bool DwrfRowReader::shouldReadNode( + uint32_t nodeId, + const velox::common::ScanSpec* fieldScanSpec) const { + bool nodeIdSelected = (columnSelector_) + ? columnSelector_->shouldReadNode(nodeId) + : projectedNodes_->contains(nodeId); + return nodeIdSelected && + !(fieldScanSpec != nullptr && !fieldScanSpec->readFromFile()); } namespace { - template std::optional getStringOrBinaryColumnSize( const dwio::common::ColumnStatistics& stats) { @@ -703,6 +746,7 @@ std::optional getStringOrBinaryColumnSize( std::optional DwrfRowReader::estimatedRowSizeHelper( const FooterWrapper& fileFooter, const dwio::common::Statistics& stats, + const velox::common::ScanSpec* scanSpec, uint32_t nodeId) const { VELOX_CHECK_LT(nodeId, fileFooter.typesSize(), "Types missing in footer"); @@ -757,11 +801,16 @@ std::optional DwrfRowReader::estimatedRowSizeHelper( ? 0 : 2 * valueCount * sizeof(vector_size_t); for (int32_t i = 0; i < nodeType.subtypesSize(); ++i) { - if (!shouldReadNode(nodeType.subtypes(i))) { + if (!shouldReadNode( + nodeType.subtypes(i), + getChildScanSpec(scanSpec, nodeType, i))) { continue; } - const auto subtypeEstimate = - estimatedRowSizeHelper(fileFooter, stats, nodeType.subtypes(i)); + const auto subtypeEstimate = estimatedRowSizeHelper( + fileFooter, + stats, + getChildScanSpec(scanSpec, nodeType, i), + nodeType.subtypes(i)); if (subtypeEstimate.has_value()) { totalEstimate += subtypeEstimate.value(); } else { @@ -776,26 +825,40 @@ std::optional DwrfRowReader::estimatedRowSizeHelper( } std::optional DwrfRowReader::estimatedRowSize() const { + if (hasRowEstimate_) { + return estimatedRowSize_; + } + const auto& reader = getReader(); const auto& fileFooter = reader.footer(); + hasRowEstimate_ = true; + if (!fileFooter.hasNumberOfRows()) { - return std::nullopt; + estimatedRowSize_ = std::nullopt; + return estimatedRowSize_; } if (fileFooter.numberOfRows() < 1) { - return 0; + estimatedRowSize_ = 0; + return estimatedRowSize_; } // Estimate with projections. constexpr uint32_t ROOT_NODE_ID = 0; const auto stats = reader.statistics(); - const auto projectedSize = - estimatedRowSizeHelper(fileFooter, *stats, ROOT_NODE_ID); + const auto projectedSize = estimatedRowSizeHelper( + fileFooter, + *stats, + reader.readerOptions().scanSpec().get(), + ROOT_NODE_ID); if (projectedSize.has_value()) { - return projectedSize.value() / fileFooter.numberOfRows(); + estimatedRowSize_ = projectedSize.value() / fileFooter.numberOfRows(); + return estimatedRowSize_; } - return std::nullopt; + + estimatedRowSize_ = std::nullopt; + return estimatedRowSize_; } DwrfReader::DwrfReader( @@ -817,8 +880,9 @@ DwrfReader::DwrfReader( void DwrfReader::updateColumnNamesFromTableSchema() { const auto& tableSchema = readerBase_->readerOptions().fileSchema(); const auto& fileSchema = readerBase_->schema(); - readerBase_->setSchema(std::dynamic_pointer_cast( - updateColumnNames(fileSchema, tableSchema))); + readerBase_->setSchema( + std::dynamic_pointer_cast( + updateColumnNames(fileSchema, tableSchema))); } std::unique_ptr DwrfReader::getStripe( @@ -990,8 +1054,8 @@ uint64_t DwrfReader::getMemoryUse( // Do we need even more memory to read the footer or the metadata? const auto footerLength = readerBase.postScript().footerLength(); - if (memoryBytes < footerLength + readerBase.footerEstimatedSize()) { - memoryBytes = footerLength + readerBase.footerEstimatedSize(); + if (memoryBytes < footerLength + readerBase.footerSpeculativeIoSize()) { + memoryBytes = footerLength + readerBase.footerSpeculativeIoSize(); } // Account for firstRowOfStripe. diff --git a/velox/dwio/dwrf/reader/DwrfReader.h b/velox/dwio/dwrf/reader/DwrfReader.h index dcb38dbb5a1..6dbae5615df 100644 --- a/velox/dwio/dwrf/reader/DwrfReader.h +++ b/velox/dwio/dwrf/reader/DwrfReader.h @@ -111,8 +111,8 @@ class DwrfRowReader : public StrideIndexProvider, stats.processedStrides += processedStrides_; stats.footerBufferOverread += getReader().footerBufferOverread(); stats.numStripes += stripeCeiling_ - firstStripe_; - stats.columnReaderStatistics.flattenStringDictionaryValues += - columnReaderStatistics_.flattenStringDictionaryValues; + stats.columnReaderStats.mergeFrom(*columnReaderStats_); + stats.unitLoaderStats.merge(unitLoadStats_); } void resetFilterCaches() override; @@ -148,11 +148,14 @@ class DwrfRowReader : public StrideIndexProvider, } private: - bool shouldReadNode(uint32_t nodeId) const; + bool shouldReadNode( + uint32_t nodeId, + const velox::common::ScanSpec* fieldScanSpec) const; std::optional estimatedRowSizeHelper( const FooterWrapper& fileFooter, const dwio::common::Statistics& stats, + const velox::common::ScanSpec* scanSpec, uint32_t nodeId) const; bool emptyFile() const { @@ -210,17 +213,22 @@ class DwrfRowReader : public StrideIndexProvider, // Number of processed strides. int64_t processedStrides_{0}; + dwio::common::UnitLoaderStats unitLoadStats_; + // Set to true after clearing filter caches, i.e. adding a dynamic filter. // Causes filters to be re-evaluated against stride stats on next stride // instead of next stripe. bool recomputeStridesToSkip_{false}; - dwio::common::ColumnReaderStatistics columnReaderStatistics_; + std::shared_ptr columnReaderStats_; std::optional nextRowNumber_; std::unique_ptr unitLoader_; DwrfUnit* currentUnit_; + + mutable std::optional estimatedRowSize_; + mutable bool hasRowEstimate_{false}; }; class DwrfReader : public dwio::common::Reader { diff --git a/velox/dwio/dwrf/reader/FlatMapColumnReader.cpp b/velox/dwio/dwrf/reader/FlatMapColumnReader.cpp index 22f7220e853..d1a611755bf 100644 --- a/velox/dwio/dwrf/reader/FlatMapColumnReader.cpp +++ b/velox/dwio/dwrf/reader/FlatMapColumnReader.cpp @@ -163,12 +163,13 @@ std::vector>> getKeyNodesFiltered( .inMapDecoder = inMapDecoder.get(), .keySelectionCallback = nullptr}); - keyNodes.push_back(std::make_unique>( - std::move(valueReader), - std::move(inMapDecoder), - key, - sequence, - memoryPool)); + keyNodes.push_back( + std::make_unique>( + std::move(valueReader), + std::move(inMapDecoder), + key, + sequence, + memoryPool)); }); keySelectionStats.selectedKeys = keyNodes.size(); diff --git a/velox/dwio/dwrf/reader/ReaderBase.cpp b/velox/dwio/dwrf/reader/ReaderBase.cpp index c52dcf871a7..2462ed735d9 100644 --- a/velox/dwio/dwrf/reader/ReaderBase.cpp +++ b/velox/dwio/dwrf/reader/ReaderBase.cpp @@ -19,12 +19,14 @@ #include #include "velox/common/process/TraceContext.h" +#include "velox/dwio/common/Arena.h" #include "velox/dwio/common/Mutation.h" #include "velox/dwio/common/exception/Exception.h" #include "velox/functions/lib/string/StringImpl.h" namespace facebook::velox::dwrf { +using dwio::common::ArenaCreate; using dwio::common::ColumnStatistics; using dwio::common::FileFormat; using dwio::common::LogType; @@ -77,12 +79,6 @@ FooterStatisticsImpl::FooterStatisticsImpl( } } -ReaderBase::ReaderBase( - MemoryPool& pool, - std::unique_ptr input, - FileFormat fileFormat) - : ReaderBase(createReaderOptions(pool, fileFormat), std::move(input)) {} - namespace { template @@ -96,7 +92,7 @@ template std::unique_ptr parseFooter( dwio::common::SeekableInputStream* input, google::protobuf::Arena* arena) { - auto* impl = google::protobuf::Arena::CreateMessage(arena); + auto* impl = ArenaCreate(arena); VELOX_CHECK(impl->ParseFromZeroCopyStream(input)); return std::make_unique(impl); } @@ -115,15 +111,14 @@ ReaderBase::ReaderBase( DWIO_ENSURE(fileLength_ > 0, "ORC file is empty"); VELOX_CHECK_GE(fileLength_, 4, "File size too small"); - const auto preloadFile = fileLength_ <= options_.filePreloadThreshold(); - const int64_t footerBufSize = - std::min(fileLength_, options_.footerEstimatedSize()); - const uint64_t readSize = preloadFile ? fileLength_ : footerBufSize; - if (input_->supportSyncLoad()) { - input_->enqueue({fileLength_ - readSize, readSize, "footer"}); - input_->load(preloadFile ? LogType::FILE : LogType::FOOTER); + // Preload small files: one IO for entire file, serving all subsequent reads. + if (fileLength_ <= options_.filePreloadThreshold()) { + input_->preload(); } + const int64_t footerBufSize = + std::min(fileLength_, options_.footerSpeculativeIoSize()); + // TODO: read footer from spectrum auto footerBuffer = AlignedBuffer::allocate(footerBufSize, &options_.memoryPool()); @@ -170,11 +165,6 @@ ReaderBase::ReaderBase( "Corrupted File, invalid compression kind ", postScript_->compression()); - if (input_->supportSyncLoad() && (tailSize > readSize)) { - input_->enqueue({fileLength_ - tailSize, tailSize, "footer"}); - input_->load(LogType::FOOTER); - } - BufferPtr fullFooterBuffer; char* footerStart; if (footerOffset >= footerSize) { diff --git a/velox/dwio/dwrf/reader/ReaderBase.h b/velox/dwio/dwrf/reader/ReaderBase.h index 561d88e5684..8a90bdaf0d8 100644 --- a/velox/dwio/dwrf/reader/ReaderBase.h +++ b/velox/dwio/dwrf/reader/ReaderBase.h @@ -20,10 +20,10 @@ #include "velox/dwio/common/BufferedInput.h" #include "velox/dwio/common/Options.h" #include "velox/dwio/common/SeekableInputStream.h" +#include "velox/dwio/common/Statistics.h" #include "velox/dwio/common/TypeWithId.h" #include "velox/dwio/dwrf/common/Compression.h" #include "velox/dwio/dwrf/common/Decryption.h" -#include "velox/dwio/dwrf/common/FileMetadata.h" #include "velox/dwio/dwrf/common/Statistics.h" #include "velox/dwio/dwrf/reader/StripeMetadataCache.h" #include "velox/dwio/dwrf/utils/ProtoUtils.h" @@ -65,22 +65,22 @@ class ReaderBase { const dwio::common::ReaderOptions& options, std::unique_ptr input); - /// Creates reader base from buffered input. - /// It is kept here for backward compatibility with Meta's internal usage. - ReaderBase( - memory::MemoryPool& pool, - std::unique_ptr input, - dwio::common::FileFormat fileFormat = dwio::common::FileFormat::DWRF); - - /// Creates reader base from metadata. + /// Creates reader base from metadata (for testing). ReaderBase( memory::MemoryPool& pool, std::unique_ptr input, std::unique_ptr ps, const proto::Footer* footer, std::unique_ptr cache, - std::unique_ptr handler = nullptr) - : options_{dwio::common::ReaderOptions(&pool)}, + std::unique_ptr handler, + std::shared_ptr dataIoStats = nullptr, + std::shared_ptr metadataIoStats = nullptr) + : options_{[&] { + dwio::common::ReaderOptions opts(&pool); + opts.setDataIoStats(std::move(dataIoStats)); + opts.setMetadataIoStats(std::move(metadataIoStats)); + return opts; + }()}, input_{std::move(input)}, fileLength_{0}, postScript_{std::move(ps)}, @@ -160,8 +160,8 @@ class ReaderBase { return *handler_; } - uint64_t footerEstimatedSize() const { - return options_.footerEstimatedSize(); + uint64_t footerSpeculativeIoSize() const { + return options_.footerSpeculativeIoSize(); } uint64_t fileLength() const { @@ -199,7 +199,7 @@ class ReaderBase { const std::string& writerName() const { for (int32_t index = 0; index < footer_->metadataSize(); ++index) { auto entry = footer_->metadata(index); - if (entry.name() == WRITER_NAME_KEY) { + if (entry.name() == kWriterNameKey) { return entry.value(); } } @@ -216,14 +216,16 @@ class ReaderBase { std::unique_ptr createDecompressedStream( std::unique_ptr compressed, const std::string& streamDebugInfo, - const dwio::common::encryption::Decrypter* decrypter = nullptr) const { + const dwio::common::encryption::Decrypter* decrypter = nullptr, + velox::io::IoCounter* decompressCounter = nullptr) const { return createDecompressor( compressionKind(), std::move(compressed), compressionBlockSize(), options_.memoryPool(), streamDebugInfo, - decrypter); + decrypter, + decompressCounter); } template @@ -258,14 +260,6 @@ class ReaderBase { uint32_t index = 0, bool fileColumnNamesReadAsLowerCase = false); - static dwio::common::ReaderOptions createReaderOptions( - memory::MemoryPool& pool, - dwio::common::FileFormat fileFormat) { - dwio::common::ReaderOptions options(&pool); - options.setFileFormat(fileFormat); - return options; - } - const dwio::common::ReaderOptions options_; const std::unique_ptr input_; const uint64_t fileLength_; @@ -282,7 +276,7 @@ class ReaderBase { RowTypePtr schema_; // Lazily populated mutable std::shared_ptr schemaWithId_; - uint64_t psLength_; + uint64_t psLength_{}; }; } // namespace facebook::velox::dwrf diff --git a/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.cpp b/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.cpp index ec570ae05b7..09cdd6cd87e 100644 --- a/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.cpp +++ b/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.cpp @@ -20,10 +20,19 @@ namespace facebook::velox::dwrf { template SelectiveDecimalColumnReader::SelectiveDecimalColumnReader( + const TypePtr& requestedType, const std::shared_ptr& fileType, DwrfParams& params, common::ScanSpec& scanSpec) - : SelectiveColumnReader(fileType->type(), fileType, params, scanSpec) { + // Read using requestedType so that values are materialized at the + // table-schema scale rather than the file-footer scale. See the header + // comment for the Hive ORC DECIMAL(38, 18) footer behavior this works + // around. + : SelectiveColumnReader(requestedType, fileType, params, scanSpec) { + VELOX_CHECK( + requestedType_->isDecimal(), + "SelectiveDecimalColumnReader requires a decimal requestedType, got {}", + requestedType_->toString()); EncodingKey encodingKey{fileType_->id(), params.flatMapContext().sequence}; auto& stripe = params.stripeStreams(); if constexpr (std::is_same_v) { @@ -51,7 +60,7 @@ SelectiveDecimalColumnReader::SelectiveDecimalColumnReader( scaleDecoder_ = createRleDecoder( stripe.getStream(secondary, params.streamLabels().label(), true), version_, - *memoryPool_, + *pool_, stripe.getUseVInts(secondary), LONG_BYTE_SIZE); } @@ -75,16 +84,17 @@ void SelectiveDecimalColumnReader::seekToRowGroup(int64_t index) { template template -void SelectiveDecimalColumnReader::readHelper(RowSet rows) { - vector_size_t numRows = rows.back() + 1; +void SelectiveDecimalColumnReader::readHelper( + const common::Filter* filter, + RowSet rows) { ExtractToReader extractValues(this); - common::AlwaysTrue filter; + common::AlwaysTrue alwaysTrue; DirectRleColumnVisitor< int64_t, common::AlwaysTrue, decltype(extractValues), kDense> - visitor(filter, this, rows, extractValues); + visitor(alwaysTrue, this, rows, extractValues); // decode scale stream if (version_ == velox::dwrf::RleVersion_1) { @@ -94,7 +104,7 @@ void SelectiveDecimalColumnReader::readHelper(RowSet rows) { } // copy scales into scaleBuffer_ - ensureCapacity(scaleBuffer_, numValues_, memoryPool_); + ensureCapacity(scaleBuffer_, numValues_, pool_); scaleBuffer_->setSize(numValues_ * sizeof(int64_t)); memcpy( scaleBuffer_->asMutable(), @@ -104,14 +114,135 @@ void SelectiveDecimalColumnReader::readHelper(RowSet rows) { // reset numValues_ before reading values numValues_ = 0; valueSize_ = sizeof(DataT); + vector_size_t numRows = rows.back() + 1; ensureValuesCapacity(numRows); // decode value stream facebook::velox::dwio::common:: ColumnVisitor - valueVisitor(filter, this, rows, extractValues); + valueVisitor(alwaysTrue, this, rows, extractValues); decodeWithVisitor>(valueDecoder_.get(), valueVisitor); readOffset_ += numRows; + + // Fill decimals before applying filter. + fillDecimals(); + + // 'nullsInReadRange_' is the nulls for the entire read range, and if the row + // set is not dense, result nulls should be allocated, which represents the + // nulls for the selected rows before filtering. + const auto rawNulls = nullsInReadRange_ + ? (kDense ? nullsInReadRange_->as() : rawResultNulls_) + : nullptr; + // Process filter. + process(filter, rows, rawNulls); +} + +template +void SelectiveDecimalColumnReader::processNulls( + bool isNull, + const RowSet& rows, + const uint64_t* rawNulls) { + if (!rawNulls) { + return; + } + returnReaderNulls_ = false; + anyNulls_ = !isNull; + allNull_ = isNull; + + auto rawDecimal = values_->asMutable(); + auto rawScale = scaleBuffer_->asMutable(); + + vector_size_t idx = 0; + if (isNull) { + for (vector_size_t i = 0; i < numValues_; i++) { + if (bits::isBitNull(rawNulls, i)) { + bits::setNull(rawResultNulls_, idx); + addOutputRow(rows[i]); + idx++; + } + } + } else { + for (vector_size_t i = 0; i < numValues_; i++) { + if (!bits::isBitNull(rawNulls, i)) { + bits::setNull(rawResultNulls_, idx, false); + rawDecimal[idx] = rawDecimal[i]; + rawScale[idx] = rawScale[i]; + addOutputRow(rows[i]); + idx++; + } + } + } +} + +template +void SelectiveDecimalColumnReader::processFilter( + const common::Filter* filter, + const RowSet& rows, + const uint64_t* rawNulls) { + VELOX_CHECK_NOT_NULL(filter, "Filter must not be null."); + returnReaderNulls_ = false; + anyNulls_ = false; + allNull_ = true; + + vector_size_t idx = 0; + auto rawDecimal = values_->asMutable(); + for (vector_size_t i = 0; i < numValues_; i++) { + if (rawNulls && bits::isBitNull(rawNulls, i)) { + if (filter->testNull()) { + bits::setNull(rawResultNulls_, idx); + addOutputRow(rows[i]); + anyNulls_ = true; + idx++; + } + } else { + bool tested; + if constexpr (std::is_same_v) { + tested = filter->testInt64(rawDecimal[i]); + } else { + tested = filter->testInt128(rawDecimal[i]); + } + + if (tested) { + if (rawNulls) { + bits::setNull(rawResultNulls_, idx, false); + } + rawDecimal[idx] = rawDecimal[i]; + addOutputRow(rows[i]); + allNull_ = false; + idx++; + } + } + } +} + +template +void SelectiveDecimalColumnReader::process( + const common::Filter* filter, + const RowSet& rows, + const uint64_t* rawNulls) { + if (!filter) { + // No filter and "hasDeletion" is false so input rows will be + // reused. + return; + } + + switch (filter->kind()) { + case common::FilterKind::kIsNull: + processNulls(true, rows, rawNulls); + break; + case common::FilterKind::kIsNotNull: { + if (rawNulls) { + processNulls(false, rows, rawNulls); + } else { + for (vector_size_t i = 0; i < numValues_; i++) { + addOutputRow(rows[i]); + } + } + break; + } + default: + processFilter(filter, rows, rawNulls); + } } template @@ -119,14 +250,23 @@ void SelectiveDecimalColumnReader::read( int64_t offset, const RowSet& rows, const uint64_t* incomingNulls) { - VELOX_CHECK(!scanSpec_->filter()); VELOX_CHECK(!scanSpec_->valueHook()); prepareRead(offset, rows, incomingNulls); + if (!scanSpec_->keepValues() && scanSpec_->filter() && + (!resultNulls_ || !resultNulls_->unique() || + resultNulls_->capacity() * 8 < rows.size())) { + // Make sure a dedicated resultNulls_ is allocated with enough capacity as + // RleDecoder always assumes it is available and 'prepareRead' skips + // allocation when the column is not projected. + resultNulls_ = AlignedBuffer::allocate(rows.size(), pool_); + rawResultNulls_ = resultNulls_->asMutable(); + } + rawValues_ = values_->asMutable(); bool isDense = rows.back() == rows.size() - 1; if (isDense) { - readHelper(rows); + readHelper(scanSpec_->filter(), rows); } else { - readHelper(rows); + readHelper(scanSpec_->filter(), rows); } } @@ -134,16 +274,17 @@ template void SelectiveDecimalColumnReader::getValues( const RowSet& rows, VectorPtr* result) { + getIntValues(rows, requestedType_, result); +} + +template +void SelectiveDecimalColumnReader::fillDecimals() { auto nullsPtr = resultNulls() ? resultNulls()->template as() : nullptr; auto scales = scaleBuffer_->as(); auto values = values_->asMutable(); - DecimalUtil::fillDecimals( values, nullsPtr, values, scales, numValues_, scale_); - - rawValues_ = values_->asMutable(); - getIntValues(rows, requestedType_, result); } template class SelectiveDecimalColumnReader; diff --git a/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.h b/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.h index 67a82b051e3..a770fd2d9cd 100644 --- a/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.h +++ b/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.h @@ -28,7 +28,16 @@ using namespace dwio::common; template class SelectiveDecimalColumnReader : public SelectiveColumnReader { public: + // requestedType is the DECIMAL type to materialize values as. It must be a + // decimal type. Hive's ORC writer always records DECIMAL(38, 18) in the + // file footer regardless of the metastore-declared precision/scale; the + // per-row scale at which each value was actually written lives in the + // SECONDARY (a.k.a. NANO_DATA) stream. The reader uses + // requestedType.scale() (the table-schema scale) as the target scale and + // rescales each value from its per-row scale, so the output matches what + // table consumers expect even when the file footer scale differs. SelectiveDecimalColumnReader( + const TypePtr& requestedType, const std::shared_ptr& fileType, DwrfParams& params, common::ScanSpec& scanSpec); @@ -49,7 +58,24 @@ class SelectiveDecimalColumnReader : public SelectiveColumnReader { private: template - void readHelper(RowSet rows); + void readHelper(const common::Filter* filter, RowSet rows); + + // Process IsNull and IsNotNull filters. + void processNulls(bool isNull, const RowSet& rows, const uint64_t* rawNulls); + + // Process filters on decimal values. + void processFilter( + const common::Filter* filter, + const RowSet& rows, + const uint64_t* rawNulls); + + // Dispatch to the respective filter processing based on the filter type. + void process( + const common::Filter* filter, + const RowSet& rows, + const uint64_t* rawNulls); + + void fillDecimals(); std::unique_ptr> valueDecoder_; std::unique_ptr> scaleDecoder_; diff --git a/velox/dwio/dwrf/reader/SelectiveDwrfReader.cpp b/velox/dwio/dwrf/reader/SelectiveDwrfReader.cpp index 05c16970ad9..dd12ff7ff04 100644 --- a/velox/dwio/dwrf/reader/SelectiveDwrfReader.cpp +++ b/velox/dwio/dwrf/reader/SelectiveDwrfReader.cpp @@ -81,7 +81,7 @@ std::unique_ptr SelectiveDwrfReader::build( case TypeKind::BIGINT: if (fileType->type()->isDecimal()) { return std::make_unique>( - fileType, params, scanSpec); + requestedType, fileType, params, scanSpec); } else { return buildIntegerReader( requestedType, fileType, params, LONG_BYTE_SIZE, scanSpec); @@ -99,6 +99,10 @@ std::unique_ptr SelectiveDwrfReader::build( return createSelectiveFlatMapColumnReader( columnReaderOptions, requestedType, fileType, params, scanSpec); } + if (scanSpec.isFlatMapAsStruct()) { + return std::make_unique( + columnReaderOptions, requestedType, fileType, params, scanSpec); + } return std::make_unique( columnReaderOptions, requestedType, fileType, params, scanSpec); case TypeKind::REAL: @@ -146,13 +150,13 @@ std::unique_ptr SelectiveDwrfReader::build( case TypeKind::HUGEINT: if (fileType->type()->isDecimal()) { return std::make_unique>( - fileType, params, scanSpec); + requestedType, fileType, params, scanSpec); } [[fallthrough]]; default: VELOX_FAIL( "buildReader unhandled type: " + - mapTypeKindToName(fileType->type()->kind())); + std::string(TypeKindName::toName(fileType->type()->kind()))); } } diff --git a/velox/dwio/dwrf/reader/SelectiveFlatMapColumnReader.cpp b/velox/dwio/dwrf/reader/SelectiveFlatMapColumnReader.cpp index ec3dd9d90b1..57c931e9b46 100644 --- a/velox/dwio/dwrf/reader/SelectiveFlatMapColumnReader.cpp +++ b/velox/dwio/dwrf/reader/SelectiveFlatMapColumnReader.cpp @@ -23,7 +23,6 @@ #include "velox/vector/FlatMapVector.h" namespace facebook::velox::dwrf { - namespace { template @@ -41,7 +40,7 @@ inline dwio::common::flatmap::KeyValue extractKey( template std::string toString(const T& x) { if constexpr (std::is_same_v) { - return x; + return std::string(x); } else { return std::to_string(x); } @@ -207,15 +206,16 @@ class SelectiveFlatMapAsStructReader : public SelectiveStructColumnReaderBase { fileType, params, scanSpec), - keyNodes_(getKeyNodes( - columnReaderOptions, - requestedType, - fileType, - params, - scanSpec, - dwio::common::flatmap::FlatMapOutput::kStruct)) { + keyNodes_( + getKeyNodes( + columnReaderOptions, + requestedType, + fileType, + params, + scanSpec, + dwio::common::flatmap::FlatMapOutput::kStruct)) { VELOX_CHECK( - !keyNodes_.empty(), + !scanSpec.children().empty(), "For struct encoding, keys to project must be configured"); children_.resize(keyNodes_.size()); for (auto& childSpec : scanSpec.children()) { @@ -289,13 +289,14 @@ class SelectiveFlatMapReader fileType, params, scanSpec), - keyNodes_(getKeyNodes( - columnReaderOptions, - requestedType, - fileType, - params, - scanSpec, - dwio::common::flatmap::FlatMapOutput::kFlatMap)), + keyNodes_( + getKeyNodes( + columnReaderOptions, + requestedType, + fileType, + params, + scanSpec, + dwio::common::flatmap::FlatMapOutput::kFlatMap)), rowsPerRowGroup_(formatData_->rowsPerRowGroup().value()) { // Instantiate and populate distinct keys vector. keysVector_ = BaseVector::create( @@ -305,6 +306,14 @@ class SelectiveFlatMapReader auto rawKeys = keysVector_->values()->asMutable(); children_.resize(keyNodes_.size()); + // Invalidate subscripts from previous stripes. The shared ScanSpec + // accumulates children across stripes via getOrCreateChild(). Keys + // absent in the current stripe must not be accessed via children_, + // so mark them constant so read() skips them. + for (auto& child : scanSpec.children()) { + child->setSubscript(kConstantChildSpecSubscript); + } + for (int i = 0; i < keyNodes_.size(); ++i) { keyNodes_[i].reader->scanSpec()->setSubscript(i); children_[i] = keyNodes_[i].reader.get(); diff --git a/velox/dwio/dwrf/reader/SelectiveIntegerDictionaryColumnReader.cpp b/velox/dwio/dwrf/reader/SelectiveIntegerDictionaryColumnReader.cpp index c2a4edc130a..939f355b7c5 100644 --- a/velox/dwio/dwrf/reader/SelectiveIntegerDictionaryColumnReader.cpp +++ b/velox/dwio/dwrf/reader/SelectiveIntegerDictionaryColumnReader.cpp @@ -45,7 +45,7 @@ SelectiveIntegerDictionaryColumnReader::SelectiveIntegerDictionaryColumnReader( dataReader_ = createRleDecoder( stripe.getStream(si, params.streamLabels().label(), true), rleVersion_, - *memoryPool_, + *pool_, dataVInts, numBytes); @@ -95,7 +95,7 @@ void SelectiveIntegerDictionaryColumnReader::read( ? bits::countNonNulls(nullsInReadRange_->as(), 0, end) : end; dwio::common::ensureCapacity( - scanState_.inDictionary, bits::nwords(numFlags), memoryPool_); + scanState_.inDictionary, bits::nwords(numFlags), pool_); // The in dict buffer may have changed. If no change in // dictionary, the raw state will not be updated elsewhere. scanState_.rawState.inDictionary = scanState_.inDictionary->as(); diff --git a/velox/dwio/dwrf/reader/SelectiveRepeatedColumnReader.cpp b/velox/dwio/dwrf/reader/SelectiveRepeatedColumnReader.cpp index 6cbfd654a4a..12bd9c7b2a0 100644 --- a/velox/dwio/dwrf/reader/SelectiveRepeatedColumnReader.cpp +++ b/velox/dwio/dwrf/reader/SelectiveRepeatedColumnReader.cpp @@ -39,6 +39,30 @@ std::unique_ptr> makeLengthDecoder( lenVints, dwio::common::INT_BYTE_SIZE); } + +// Returns true if the MAP extraction type needs the key child reader. +// When deltaUpdate is set, all child readers are needed regardless of +// ExtractionType because delta updates (e.g., MAP_CONCAT) operate on the +// full map. +bool needsKeyReader( + common::ScanSpec::ExtractionType extractionType, + bool hasDeltaUpdate) { + return hasDeltaUpdate || + extractionType == common::ScanSpec::ExtractionType::kNone || + extractionType == common::ScanSpec::ExtractionType::kKeys; +} + +// Returns true if the MAP extraction type needs the value child reader. +// When deltaUpdate is set, all child readers are needed regardless of +// ExtractionType because delta updates (e.g., MAP_CONCAT) operate on the +// full map. +bool needsElementReader( + common::ScanSpec::ExtractionType extractionType, + bool hasDeltaUpdate) { + return hasDeltaUpdate || + extractionType == common::ScanSpec::ExtractionType::kNone || + extractionType == common::ScanSpec::ExtractionType::kValues; +} } // namespace FlatMapContext flatMapContextFromEncodingKey(const EncodingKey& encodingKey) { @@ -59,8 +83,7 @@ SelectiveListColumnReader::SelectiveListColumnReader( fileType, params, scanSpec), - length_(makeLengthDecoder(*fileType_, params, *memoryPool_)) { - VELOX_CHECK_EQ(fileType_->id(), fileType->id(), "working on the same node"); + length_(makeLengthDecoder(*fileType_, params, *pool_)) { EncodingKey encodingKey{fileType_->id(), params.flatMapContext().sequence}; auto& stripe = params.stripeStreams(); // count the number of selected sub-columns @@ -70,6 +93,15 @@ SelectiveListColumnReader::SelectiveListColumnReader( } scanSpec_->children()[0]->setProjectOut(true); + // For kSize extraction we only need the length stream, so skip creating + // the element reader entirely. This avoids registering its streams and + // reduces IO. When deltaUpdate is set, we still need the element reader + // because delta updates operate on the full array. + if (scanSpec.extractionType() == common::ScanSpec::ExtractionType::kSize && + !scanSpec.deltaUpdate()) { + return; + } + auto childParams = DwrfParams( stripe, params.streamLabels(), @@ -84,6 +116,53 @@ SelectiveListColumnReader::SelectiveListColumnReader( children_ = {child_.get()}; } +namespace { + +void makeMapChildrenReaders( + const dwio::common::TypeWithId& fileType, + const Type& requestedType, + DwrfParams& params, + const dwio::common::ColumnReaderOptions& columnReaderOptions, + const common::ScanSpec& scanSpec, + common::ScanSpec::ExtractionType extractionType, + bool hasDeltaUpdate, + std::unique_ptr& keyReader, + std::unique_ptr& elementReader) { + const EncodingKey encodingKey{ + fileType.id(), params.flatMapContext().sequence}; + auto& stripe = params.stripeStreams(); + // Skip creating child readers that extraction pushdown doesn't need. + // This avoids registering their streams and reduces IO. + if (needsKeyReader(extractionType, hasDeltaUpdate)) { + DwrfParams keyParams( + stripe, + params.streamLabels(), + params.runtimeStatistics(), + flatMapContextFromEncodingKey(encodingKey)); + keyReader = SelectiveDwrfReader::build( + columnReaderOptions, + requestedType.childAt(0), + fileType.childAt(0), + keyParams, + *scanSpec.children()[0]); + } + if (needsElementReader(extractionType, hasDeltaUpdate)) { + DwrfParams elementParams = DwrfParams( + stripe, + params.streamLabels(), + params.runtimeStatistics(), + flatMapContextFromEncodingKey(encodingKey)); + elementReader = SelectiveDwrfReader::build( + columnReaderOptions, + requestedType.childAt(1), + fileType.childAt(1), + elementParams, + *scanSpec.children()[1]); + } +} + +} // namespace + SelectiveMapColumnReader::SelectiveMapColumnReader( const dwio::common::ColumnReaderOptions& columnReaderOptions, const TypePtr& requestedType, @@ -95,43 +174,49 @@ SelectiveMapColumnReader::SelectiveMapColumnReader( fileType, params, scanSpec), - length_(makeLengthDecoder(*fileType_, params, *memoryPool_)) { - VELOX_CHECK_EQ(fileType_->id(), fileType->id(), "working on the same node"); - const EncodingKey encodingKey{ - fileType_->id(), params.flatMapContext().sequence}; - auto& stripe = params.stripeStreams(); - if (scanSpec_->children().empty()) { - scanSpec_->getOrCreateChild(common::ScanSpec::kMapKeysFieldName); - scanSpec_->getOrCreateChild(common::ScanSpec::kMapValuesFieldName); - } - scanSpec_->children()[0]->setProjectOut(true); - scanSpec_->children()[1]->setProjectOut(true); - - auto& keyType = requestedType_->childAt(0); - auto keyParams = DwrfParams( - stripe, - params.streamLabels(), - params.runtimeStatistics(), - flatMapContextFromEncodingKey(encodingKey)); - keyReader_ = SelectiveDwrfReader::build( + length_(makeLengthDecoder(*fileType_, params, *pool_)) { + makeMapChildrenReaders( + *fileType_, + *requestedType_, + params, columnReaderOptions, - keyType, - fileType_->childAt(0), - keyParams, - *scanSpec_->children()[0].get()); + *scanSpec_, + scanSpec.extractionType(), + scanSpec.deltaUpdate() != nullptr, + keyReader_, + elementReader_); + if (keyReader_) { + children_.push_back(keyReader_.get()); + } + if (elementReader_) { + children_.push_back(elementReader_.get()); + } +} - auto& valueType = requestedType_->childAt(1); - auto elementParams = DwrfParams( - stripe, - params.streamLabels(), - params.runtimeStatistics(), - flatMapContextFromEncodingKey(encodingKey)); - elementReader_ = SelectiveDwrfReader::build( +SelectiveMapAsStructColumnReader::SelectiveMapAsStructColumnReader( + const dwio::common::ColumnReaderOptions& columnReaderOptions, + const TypePtr& requestedType, + const std::shared_ptr& fileType, + DwrfParams& params, + common::ScanSpec& scanSpec) + : dwio::common::SelectiveMapAsStructColumnReader( + requestedType, + fileType, + params, + scanSpec), + length_(makeLengthDecoder(*fileType_, params, *pool_)) { + // MapAsStruct never uses extraction pushdown (asserted in base class), + // so always create both readers. + makeMapChildrenReaders( + *fileType_, + *requestedType_, + params, columnReaderOptions, - valueType, - fileType_->childAt(1), - elementParams, - *scanSpec_->children()[1]); + mapScanSpec_, + common::ScanSpec::ExtractionType::kNone, + /*hasDeltaUpdate=*/false, + keyReader_, + elementReader_); children_ = {keyReader_.get(), elementReader_.get()}; } diff --git a/velox/dwio/dwrf/reader/SelectiveRepeatedColumnReader.h b/velox/dwio/dwrf/reader/SelectiveRepeatedColumnReader.h index ad6e8575bdd..f86aae8d993 100644 --- a/velox/dwio/dwrf/reader/SelectiveRepeatedColumnReader.h +++ b/velox/dwio/dwrf/reader/SelectiveRepeatedColumnReader.h @@ -34,10 +34,6 @@ class SelectiveListColumnReader DwrfParams& params, common::ScanSpec& scanSpec); - void resetFilterCaches() override { - child_->resetFilterCaches(); - } - void seekToRowGroup(int64_t index) override { dwio::common::SelectiveListColumnReader::seekToRowGroup(index); auto positionsProvider = formatData_->seekToRowGroup(index); @@ -45,8 +41,10 @@ class SelectiveListColumnReader VELOX_CHECK(!positionsProvider.hasNext()); - child_->seekToRowGroup(index); - child_->setReadOffsetRecursive(0); + if (child_) { + child_->seekToRowGroup(index); + child_->setReadOffsetRecursive(0); + } childTargetReadOffset_ = 0; } @@ -68,24 +66,37 @@ class SelectiveMapColumnReader : public dwio::common::SelectiveMapColumnReader { DwrfParams& params, common::ScanSpec& scanSpec); - void resetFilterCaches() override { - keyReader_->resetFilterCaches(); - elementReader_->resetFilterCaches(); - } - void seekToRowGroup(int64_t index) override { dwio::common::SelectiveMapColumnReader::seekToRowGroup(index); auto positionsProvider = formatData_->seekToRowGroup(index); - length_->seekToRowGroup(positionsProvider); - VELOX_CHECK(!positionsProvider.hasNext()); + } - keyReader_->seekToRowGroup(index); - keyReader_->setReadOffsetRecursive(0); - elementReader_->seekToRowGroup(index); - elementReader_->setReadOffsetRecursive(0); - childTargetReadOffset_ = 0; + void readLengths(int32_t* lengths, int32_t numLengths, const uint64_t* nulls) + override { + length_->next(lengths, numLengths, nulls); + } + + private: + std::unique_ptr> length_; +}; + +class SelectiveMapAsStructColumnReader + : public dwio::common::SelectiveMapAsStructColumnReader { + public: + SelectiveMapAsStructColumnReader( + const dwio::common::ColumnReaderOptions& columnReaderOptions, + const TypePtr& requestedType, + const std::shared_ptr& fileType, + DwrfParams& params, + common::ScanSpec& scanSpec); + + void seekToRowGroup(int64_t index) override { + dwio::common::SelectiveMapAsStructColumnReader::seekToRowGroup(index); + auto positionsProvider = formatData_->seekToRowGroup(index); + length_->seekToRowGroup(positionsProvider); + VELOX_CHECK(!positionsProvider.hasNext()); } void readLengths(int32_t* lengths, int32_t numLengths, const uint64_t* nulls) diff --git a/velox/dwio/dwrf/reader/SelectiveStringDictionaryColumnReader.cpp b/velox/dwio/dwrf/reader/SelectiveStringDictionaryColumnReader.cpp index d88b8c24e40..cd0e6b0b15e 100644 --- a/velox/dwio/dwrf/reader/SelectiveStringDictionaryColumnReader.cpp +++ b/velox/dwio/dwrf/reader/SelectiveStringDictionaryColumnReader.cpp @@ -46,7 +46,7 @@ SelectiveStringDictionaryColumnReader::SelectiveStringDictionaryColumnReader( dictIndex_ = createRleDecoder( stripe.getStream(dataId, params.streamLabels().label(), true), version_, - *memoryPool_, + *pool_, dictVInts, dwio::common::INT_BYTE_SIZE); @@ -59,7 +59,7 @@ SelectiveStringDictionaryColumnReader::SelectiveStringDictionaryColumnReader( lengthDecoder_ = createRleDecoder( stripe.getStream(lenId, params.streamLabels().label(), false), version_, - *memoryPool_, + *pool_, lenVInts, dwio::common::INT_BYTE_SIZE); @@ -96,7 +96,7 @@ SelectiveStringDictionaryColumnReader::SelectiveStringDictionaryColumnReader( strideDictLengthDecoder_ = createRleDecoder( stripe.getStream(strideDictLenId, params.streamLabels().label(), true), version_, - *memoryPool_, + *pool_, strideLenVInt, dwio::common::INT_BYTE_SIZE); } @@ -118,7 +118,7 @@ void SelectiveStringDictionaryColumnReader::loadDictionary( DictionaryValues& values) { // read lengths from length reader dwio::common::ensureCapacity( - values.values, values.numValues, memoryPool_); + values.values, values.numValues, pool_); // The lengths are read in the low addresses of the string views array. auto* lengths = values.values->asMutable(); lengthDecoder.nextLengths(lengths, values.numValues); @@ -127,7 +127,7 @@ void SelectiveStringDictionaryColumnReader::loadDictionary( stringsBytes += lengths[i]; } // read bytes from underlying string - values.strings = AlignedBuffer::allocate(stringsBytes, memoryPool_); + values.strings = AlignedBuffer::allocate(stringsBytes, pool_); data.readFully(values.strings->asMutable(), stringsBytes); // fill the values with StringViews over the strings. 'strings' will // exist even if 'stringsBytes' is 0, which can happen if the only @@ -156,7 +156,7 @@ void SelectiveStringDictionaryColumnReader::loadStrideDictionary() { if (scanState_.dictionary2.numValues > 0) { // seek stride dictionary related streams std::vector pos( - positions.begin() + positionOffset_, positions.end()); + positions.cbegin() + positionOffset_, positions.cend()); PositionProvider pp(pos); strideDictStream_->seekToPosition(pp); strideDictLengthDecoder_->seekToRowGroup(pp); @@ -182,7 +182,7 @@ void SelectiveStringDictionaryColumnReader::makeDictionaryBaseVector() { if (scanState_.dictionary2.numValues) { BufferPtr values = AlignedBuffer::allocate( scanState_.dictionary.numValues + scanState_.dictionary2.numValues, - memoryPool_); + pool_); auto* valuesPtr = values->asMutable(); memcpy( valuesPtr, @@ -194,7 +194,7 @@ void SelectiveStringDictionaryColumnReader::makeDictionaryBaseVector() { scanState_.dictionary2.numValues * sizeof(StringView)); dictionaryValues_ = std::make_shared>( - memoryPool_, + pool_, fileType_->type(), BufferPtr(nullptr), // TODO nulls scanState_.dictionary.numValues + @@ -204,7 +204,7 @@ void SelectiveStringDictionaryColumnReader::makeDictionaryBaseVector() { scanState_.dictionary.strings, scanState_.dictionary2.strings}); } else { dictionaryValues_ = std::make_shared>( - memoryPool_, + pool_, fileType_->type(), BufferPtr(nullptr), // TODO nulls scanState_.dictionary.numValues /*length*/, @@ -230,7 +230,7 @@ void SelectiveStringDictionaryColumnReader::read( ? bits::countNonNulls(nullsInReadRange_->as(), 0, end) : end; dwio::common::ensureCapacity( - scanState_.inDictionary, bits::nwords(numFlags), memoryPool_); + scanState_.inDictionary, bits::nwords(numFlags), pool_); // The in dict buffer may have changed. If no change in // dictionary, the raw state will not be updated elsewhere. scanState_.rawState.inDictionary = scanState_.inDictionary->as(); @@ -253,7 +253,7 @@ void SelectiveStringDictionaryColumnReader::read( void SelectiveStringDictionaryColumnReader::makeFlat(VectorPtr* result) { auto* indices = reinterpret_cast(rawValues_); - auto values = AlignedBuffer::allocate(numValues_, memoryPool_); + auto values = AlignedBuffer::allocate(numValues_, pool_); auto* stringViews = values->asMutable(); std::vector stringBuffers; auto* stripeDict = scanState_.dictionary.values->as(); @@ -278,7 +278,7 @@ void SelectiveStringDictionaryColumnReader::makeFlat(VectorPtr* result) { } } *result = std::make_shared>( - memoryPool_, + pool_, requestedType(), std::move(nulls), numValues_, @@ -308,7 +308,7 @@ void SelectiveStringDictionaryColumnReader::getValues( makeDictionaryBaseVector(); } *result = std::make_shared>( - memoryPool_, resultNulls(), numValues_, dictionaryValues_, values_); + pool_, resultNulls(), numValues_, dictionaryValues_, values_); } void SelectiveStringDictionaryColumnReader::ensureInitialized() { diff --git a/velox/dwio/dwrf/reader/SelectiveStringDirectColumnReader.cpp b/velox/dwio/dwrf/reader/SelectiveStringDirectColumnReader.cpp index 657470e9267..2a39829acf8 100644 --- a/velox/dwio/dwrf/reader/SelectiveStringDirectColumnReader.cpp +++ b/velox/dwio/dwrf/reader/SelectiveStringDirectColumnReader.cpp @@ -39,7 +39,7 @@ SelectiveStringDirectColumnReader::SelectiveStringDirectColumnReader( lengthDecoder_ = createRleDecoder( stripe.getStream(lenId, params.streamLabels().label(), true), rleVersion, - *memoryPool_, + *pool_, lenVInts, dwio::common::INT_BYTE_SIZE); blobStream_ = stripe.getStream( @@ -54,7 +54,7 @@ SelectiveStringDirectColumnReader::SelectiveStringDirectColumnReader( uint64_t SelectiveStringDirectColumnReader::skip(uint64_t numValues) { numValues = SelectiveColumnReader::skip(numValues); - dwio::common::ensureCapacity(lengths_, numValues, memoryPool_); + dwio::common::ensureCapacity(lengths_, numValues, pool_); lengthDecoder_->nextLengths(lengths_->asMutable(), numValues); rawLengths_ = lengths_->as(); for (auto i = 0; i < numValues; ++i) { @@ -318,11 +318,10 @@ inline bool SelectiveStringDirectColumnReader::try8Consecutive( void SelectiveStringDirectColumnReader::extractSparse( const int32_t* rows, int32_t numRows) { - dwio::common::rowLoop( + dwio::common::rowLoop<8>( rows, 0, numRows, - 8, [&](int32_t row) { auto start = rangeSum(rawLengths_, 0, lengthIndex_, rows[row]); lengthIndex_ = rows[row]; @@ -362,20 +361,19 @@ void SelectiveStringDirectColumnReader::skipInDecode( lengthIndex_ += numValues; } -folly::StringPiece SelectiveStringDirectColumnReader::readValue( - int32_t length) { +std::string_view SelectiveStringDirectColumnReader::readValue(int32_t length) { skipBytes(bytesToSkip_, blobStream_.get(), bufferStart_, bufferEnd_); bytesToSkip_ = 0; // bufferStart_ may be null if length is 0 and this is the first string // we're reading. if (bufferEnd_ - bufferStart_ >= length) { bytesToSkip_ = length; - return folly::StringPiece(bufferStart_, length); + return std::string_view(bufferStart_, length); } tempString_.resize(length); readBytes( length, blobStream_.get(), tempString_.data(), bufferStart_, bufferEnd_); - return folly::StringPiece(tempString_); + return std::string_view(tempString_); } template @@ -474,13 +472,12 @@ void SelectiveStringDirectColumnReader::read( int64_t offset, const RowSet& rows, const uint64_t* incomingNulls) { - prepareRead(offset, rows, incomingNulls); + prepareRead(offset, rows, incomingNulls); auto numRows = rows.back() + 1; auto numNulls = nullsInReadRange_ ? BaseVector::countNulls(nullsInReadRange_, 0, numRows) : 0; - dwio::common::ensureCapacity( - lengths_, numRows - numNulls, memoryPool_); + dwio::common::ensureCapacity(lengths_, numRows - numNulls, pool_); lengthDecoder_->nextLengths( lengths_->asMutable(), numRows - numNulls); rawLengths_ = lengths_->as(); diff --git a/velox/dwio/dwrf/reader/SelectiveStringDirectColumnReader.h b/velox/dwio/dwrf/reader/SelectiveStringDirectColumnReader.h index 8da1e77401d..e0a2369ceb5 100644 --- a/velox/dwio/dwrf/reader/SelectiveStringDirectColumnReader.h +++ b/velox/dwio/dwrf/reader/SelectiveStringDirectColumnReader.h @@ -58,7 +58,7 @@ class SelectiveStringDirectColumnReader template void skipInDecode(int32_t numValues, int32_t current, const uint64_t* nulls); - folly::StringPiece readValue(int32_t length); + std::string_view readValue(int32_t length); template void decode(const uint64_t* nulls, Visitor visitor); diff --git a/velox/dwio/dwrf/reader/SelectiveStructColumnReader.cpp b/velox/dwio/dwrf/reader/SelectiveStructColumnReader.cpp index b1f47289505..364270e5391 100644 --- a/velox/dwio/dwrf/reader/SelectiveStructColumnReader.cpp +++ b/velox/dwio/dwrf/reader/SelectiveStructColumnReader.cpp @@ -86,12 +86,13 @@ SelectiveStructColumnReader::SelectiveStructColumnReader( .sequence = encodingKey.sequence(), .inMapDecoder = nullptr, .keySelectionCallback = nullptr}); - addChild(SelectiveDwrfReader::build( - columnReaderOptions, - childRequestedType, - childFileType, - childParams, - *childSpec)); + addChild( + SelectiveDwrfReader::build( + columnReaderOptions, + childRequestedType, + childFileType, + childParams, + *childSpec)); childSpec->setSubscript(children_.size() - 1); } } diff --git a/velox/dwio/dwrf/reader/SelectiveTimestampColumnReader.cpp b/velox/dwio/dwrf/reader/SelectiveTimestampColumnReader.cpp index 10f9be96a57..053c2432e56 100644 --- a/velox/dwio/dwrf/reader/SelectiveTimestampColumnReader.cpp +++ b/velox/dwio/dwrf/reader/SelectiveTimestampColumnReader.cpp @@ -41,7 +41,7 @@ SelectiveTimestampColumnReader::SelectiveTimestampColumnReader( seconds_ = createRleDecoder( stripe.getStream(data, params.streamLabels().label(), true), version_, - *memoryPool_, + *pool_, vints, LONG_BYTE_SIZE); auto nanoData = StripeStreamsUtil::getStreamForKind( @@ -53,7 +53,7 @@ SelectiveTimestampColumnReader::SelectiveTimestampColumnReader( nano_ = createRleDecoder( stripe.getStream(nanoData, params.streamLabels().label(), true), version_, - *memoryPool_, + *pool_, nanoVInts, LONG_BYTE_SIZE); } @@ -86,7 +86,7 @@ void SelectiveTimestampColumnReader::read( resultNulls_->capacity() * 8 < rows.size()) { // Make sure a dedicated resultNulls_ is allocated with enough capacity as // RleDecoder always assumes it is available. - resultNulls_ = AlignedBuffer::allocate(rows.size(), memoryPool_); + resultNulls_ = AlignedBuffer::allocate(rows.size(), pool_); rawResultNulls_ = resultNulls_->asMutable(); } bool isDense = rows.back() == rows.size() - 1; @@ -119,8 +119,7 @@ void SelectiveTimestampColumnReader::readHelper( // Save the seconds into their own buffer before reading nanos into // 'values_' - dwio::common::ensureCapacity( - secondsValues_, numValues_, memoryPool_); + dwio::common::ensureCapacity(secondsValues_, numValues_, pool_); secondsValues_->setSize(numValues_ * sizeof(int64_t)); memcpy( secondsValues_->asMutable(), @@ -141,7 +140,7 @@ void SelectiveTimestampColumnReader::readHelper( const auto rawNulls = nullsInReadRange_ ? (isDense ? nullsInReadRange_->as() : rawResultNulls_) : nullptr; - auto tsValues = AlignedBuffer::allocate(numValues_, memoryPool_); + auto tsValues = AlignedBuffer::allocate(numValues_, pool_); auto rawTs = tsValues->asMutable(); for (vector_size_t i = 0; i < numValues_; i++) { diff --git a/velox/dwio/dwrf/reader/StripeMetadataCache.h b/velox/dwio/dwrf/reader/StripeMetadataCache.h index 4a7a1125610..9cd5507932d 100644 --- a/velox/dwio/dwrf/reader/StripeMetadataCache.h +++ b/velox/dwio/dwrf/reader/StripeMetadataCache.h @@ -95,7 +95,7 @@ class StripeMetadataCache { std::vector offsets; offsets.reserve(footer.stripeCacheOffsetsSize()); const auto& from = footer.stripeCacheOffsets(); - offsets.assign(from.begin(), from.end()); + offsets.assign(from.cbegin(), from.cend()); return offsets; } diff --git a/velox/dwio/dwrf/reader/StripeReaderBase.cpp b/velox/dwio/dwrf/reader/StripeReaderBase.cpp index 44ba6aa81e8..9bc5549d923 100644 --- a/velox/dwio/dwrf/reader/StripeReaderBase.cpp +++ b/velox/dwio/dwrf/reader/StripeReaderBase.cpp @@ -16,8 +16,11 @@ #include "velox/dwio/dwrf/reader/StripeReaderBase.h" +#include "velox/dwio/common/Arena.h" + namespace facebook::velox::dwrf { +using dwio::common::ArenaCreate; using dwio::common::LogType; // preload is not considered or mutated if stripe has already been fetched. e.g. @@ -95,9 +98,7 @@ std::unique_ptr StripeReaderBase::fetchStripe( }; if (fileFooter.format() == DwrfFormat::kDwrf) { - auto* rawFooter = - google::protobuf::Arena::CreateMessage( - arena.get()); + auto* rawFooter = ArenaCreate(arena.get()); ProtoUtils::readProtoInto( reader_->createDecompressedStream( std::move(footerStream), streamDebugInfo), @@ -110,9 +111,7 @@ std::unique_ptr StripeReaderBase::fetchStripe( return createStripeMetadata(std::move(stripeFooter)); } else { - auto* rawFooter = - google::protobuf::Arena::CreateMessage( - arena.get()); + auto* rawFooter = ArenaCreate(arena.get()); ProtoUtils::readProtoInto( reader_->createDecompressedStream( std::move(footerStream), streamDebugInfo), diff --git a/velox/dwio/dwrf/reader/StripeStream.cpp b/velox/dwio/dwrf/reader/StripeStream.cpp index 4c9b3ec829a..eb7293b0955 100644 --- a/velox/dwio/dwrf/reader/StripeStream.cpp +++ b/velox/dwio/dwrf/reader/StripeStream.cpp @@ -343,10 +343,12 @@ std::unique_ptr StripeStreamsImpl::getStream( const auto streamDebugInfo = fmt::format("Stripe {} Stream {}", stripeIndex_, si.toString()); + return readState_->readerBase->createDecompressedStream( std::move(streamInput), streamDebugInfo, - getDecrypter(si.encodingKey().node())); + getDecrypter(si.encodingKey().node()), + getDecompressCounter(si.encodingKey().node())); } uint32_t StripeStreamsImpl::visitStreamsOfNode( diff --git a/velox/dwio/dwrf/reader/StripeStream.h b/velox/dwio/dwrf/reader/StripeStream.h index cccbf3af875..cba5fc0748e 100644 --- a/velox/dwio/dwrf/reader/StripeStream.h +++ b/velox/dwio/dwrf/reader/StripeStream.h @@ -20,6 +20,7 @@ #include "velox/dwio/common/ColumnSelector.h" #include "velox/dwio/common/Options.h" #include "velox/dwio/common/SeekableInputStream.h" +#include "velox/dwio/common/Statistics.h" #include "velox/dwio/dwrf/common/Common.h" #include "velox/dwio/dwrf/reader/StreamLabels.h" #include "velox/dwio/dwrf/reader/StripeDictionaryCache.h" @@ -229,7 +230,8 @@ class StripeStreamsImpl : public StripeStreamsBase { uint64_t stripeStart, int64_t stripeNumberOfRows, const StrideIndexProvider& provider, - uint32_t stripeIndex) + uint32_t stripeIndex, + dwio::common::ColumnReaderStatistics* columnReaderStats = nullptr) : StripeStreamsBase{&readState->readerBase->memoryPool()}, readState_(std::move(readState)), selector_{selector}, @@ -238,7 +240,8 @@ class StripeStreamsImpl : public StripeStreamsBase { stripeStart_{stripeStart}, stripeNumberOfRows_{stripeNumberOfRows}, provider_(provider), - stripeIndex_{stripeIndex} { + stripeIndex_{stripeIndex}, + columnReaderStats_{columnReaderStats} { loadStreams(); } @@ -362,6 +365,14 @@ class StripeStreamsImpl : public StripeStreamsBase { : nullptr; } + io::IoCounter* getDecompressCounter(uint32_t nodeId) const { + if (!columnReaderStats_ || !columnReaderStats_->columnMetricsSet) { + return nullptr; + } + auto* metrics = columnReaderStats_->columnMetricsSet->getOrCreate(nodeId); + return &metrics->decompressCPUTimeNanos; + } + void loadStreams(); const std::shared_ptr readState_; @@ -374,6 +385,7 @@ class StripeStreamsImpl : public StripeStreamsBase { const int64_t stripeNumberOfRows_; const StrideIndexProvider& provider_; const uint32_t stripeIndex_; + dwio::common::ColumnReaderStatistics* const columnReaderStats_{nullptr}; bool readPlanLoaded_{false}; diff --git a/velox/dwio/dwrf/test/CMakeLists.txt b/velox/dwio/dwrf/test/CMakeLists.txt index cbd29a9ac2f..c373de65143 100644 --- a/velox/dwio/dwrf/test/CMakeLists.txt +++ b/velox/dwio/dwrf/test/CMakeLists.txt @@ -70,7 +70,7 @@ target_link_libraries( ${TEST_LINK_LIBS} ) -add_executable(velox_dwio_dwrf_decompression_test TestDecompression.cpp) +add_executable(velox_dwio_dwrf_decompression_test DecompressionTest.cpp) add_test( NAME velox_dwio_dwrf_decompression_test COMMAND velox_dwio_dwrf_decompression_test @@ -87,7 +87,7 @@ target_link_libraries( ${TEST_LINK_LIBS} ) -add_executable(velox_dwio_dwrf_stripe_stream_test TestStripeStream.cpp) +add_executable(velox_dwio_dwrf_stripe_stream_test StripeStreamTest.cpp) add_test(velox_dwio_dwrf_stripe_stream_test velox_dwio_dwrf_stripe_stream_test) target_link_libraries( @@ -428,6 +428,7 @@ target_link_libraries( velox_dwio_dwrf_reader_test velox_dwrf_test_utils velox_vector_test_lib + velox_hive_connector velox_link_libs Folly::folly fmt::fmt @@ -618,7 +619,7 @@ target_link_libraries( velox_dwio_cache_test velox_common_io velox_link_libs - velox_temp_path + velox_test_util Folly::folly fmt::fmt lz4::lz4 @@ -626,3 +627,9 @@ target_link_libraries( ZLIB::ZLIB ${TEST_LINK_LIBS} ) + +velox_add_library(velox_test_column_statistics_base INTERFACE HEADERS ColumnStatisticsBase.h) + +velox_add_library(velox_test-read-file INTERFACE HEADERS TestReadFile.h) + +velox_add_library(velox_orc-test INTERFACE HEADERS OrcTest.h) diff --git a/velox/dwio/dwrf/test/CacheInputTest.cpp b/velox/dwio/dwrf/test/CacheInputTest.cpp index 79ca58b37ab..f06cef3f142 100644 --- a/velox/dwio/dwrf/test/CacheInputTest.cpp +++ b/velox/dwio/dwrf/test/CacheInputTest.cpp @@ -24,15 +24,16 @@ #include "velox/common/io/IoStatistics.h" #include "velox/common/io/Options.h" #include "velox/common/memory/MmapAllocator.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/dwio/common/CachedBufferedInput.h" #include "velox/dwio/dwrf/common/Common.h" #include "velox/dwio/dwrf/test/TestReadFile.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include #include using namespace facebook::velox; +using namespace facebook::velox::common::testutil; using namespace facebook::velox::dwio; using namespace facebook::velox::dwio::common; using namespace facebook::velox::cache; @@ -62,8 +63,7 @@ class CacheTest : public ::testing::Test { void SetUp() override { // executor_ = std::make_unique(10, 10); rng_.seed(1); - ioStats_ = std::make_shared(); - fsStats_ = std::make_shared(); + ioStats_ = std::make_shared(); filesystems::registerLocalFileSystem(); } @@ -99,7 +99,7 @@ class CacheTest : public ::testing::Test { std::unique_ptr ssd; if (ssdBytes > 0) { FLAGS_velox_ssd_odirect = false; - tempDirectory_ = exec::test::TempDirectoryPath::create(); + tempDirectory_ = TempDirectoryPath::create(); const SsdCache::Config config( fmt::format("{}/cache", tempDirectory_->getPath()), ssdBytes, @@ -110,19 +110,21 @@ class CacheTest : public ::testing::Test { checksumEnabled, checksumEnabled); ssd = std::make_unique(config); - ssdCacheHelper_ = std::make_unique(ssd.get()); + ssdCacheHelper_ = + std::make_unique(ssd.get()); groupStats_ = &ssd->groupStats(); } - memory::MmapAllocator::Options options; + memory::MemoryAllocator::Options options; options.capacity = maxBytes; allocator_ = std::make_shared(options); cache_ = AsyncDataCache::create(allocator_.get(), std::move(ssd)); asyncDataCacheHelper_ = - std::make_unique(cache_.get()); + std::make_unique(cache_.get()); cache_->setVerifyHook(checkEntry); for (auto i = 0; i < kMaxStreams; ++i) { - streamIds_.push_back(std::make_unique( - i, i, 0, dwrf::StreamKind_DATA)); + streamIds_.push_back( + std::make_unique( + i, i, 0, dwrf::StreamKind_DATA)); } streamStarts_.resize(kMaxStreams + 1); streamStarts_[0] = 0; @@ -149,13 +151,13 @@ class CacheTest : public ::testing::Test { static void checkEntry(const cache::AsyncDataCacheEntry& entry) { uint64_t seed = entry.key().fileNum.id(); - if (entry.tinyData()) { - checkData(entry.tinyData(), entry.offset(), entry.size(), seed); + if (entry.hasContiguousData()) { + checkData(entry.contiguousData(), entry.offset(), entry.size(), seed); } else { int64_t bytesLeft = entry.size(); auto runOffset = entry.offset(); - for (auto i = 0; i < entry.data().numRuns(); ++i) { - auto run = entry.data().runAt(i); + for (auto i = 0; i < entry.nonContiguousData().numRuns(); ++i) { + auto run = entry.nonContiguousData().runAt(i); checkData( run.data(), runOffset, @@ -202,7 +204,7 @@ class CacheTest : public ::testing::Test { fileIds_.push_back(groupId); // Creates an extremely large read file for test. auto stream = std::make_shared( - fileId.id(), 1UL << 63, std::make_shared()); + fileId.id(), 1UL << 63, std::make_shared()); pathToInput_[fileId.id()] = stream; return stream; } @@ -216,12 +218,14 @@ class CacheTest : public ::testing::Test { const StringIdLease& fileId, const StringIdLease& groupId, int64_t offset, - bool noCacheRetention, - const IoStatisticsPtr& ioStats, - const std::shared_ptr& fsStats) { + bool cacheable, + const IoStatisticsPtr& ioStatistics, + const std::shared_ptr& ioStats) { auto data = std::make_unique(); auto readOptions = io::ReaderOptions(pool_.get()); - readOptions.setNoCacheRetention(noCacheRetention); + readOptions.setDataIoStats(dataIoStats_); + readOptions.setMetadataIoStats(metadataIoStats_); + readOptions.setCacheable(cacheable); data->input = std::make_unique( readFile, MetricsLog::voidLog(), @@ -229,8 +233,8 @@ class CacheTest : public ::testing::Test { cache_.get(), tracker, groupId, + ioStatistics, ioStats, - fsStats, executor_.get(), readOptions); data->file = readFile.get(); @@ -345,9 +349,9 @@ class CacheTest : public ::testing::Test { int32_t readPctModulo, int32_t numStripes, int32_t stripeWindow, - bool noCacheRetention, - const IoStatisticsPtr& ioStats, - const std::shared_ptr& fsStats) { + bool cacheable, + const IoStatisticsPtr& ioStatistics, + const std::shared_ptr& ioStats) { auto tracker = std::make_shared( "testTracker", nullptr, @@ -375,9 +379,9 @@ class CacheTest : public ::testing::Test { fileId, groupId, prefetchStripeIndex * streamStarts_[kMaxStreams - 1], - noCacheRetention, - ioStats, - fsStats)); + cacheable, + ioStatistics, + ioStats)); if (stripes.back()->input->shouldPreload()) { stripes.back()->input->load(LogType::TEST); stripes.back()->prefetched = true; @@ -417,9 +421,9 @@ class CacheTest : public ::testing::Test { readPctModulo, numStripes, stripeWindow, - /*noCacheRetention=*/false, - ioStats_, - fsStats_); + /*cacheable=*/true, + dataIoStats_, + ioStats_); } } @@ -432,18 +436,22 @@ class CacheTest : public ::testing::Test { } } + const std::shared_ptr dataIoStats_{ + std::make_shared()}; + const std::shared_ptr metadataIoStats_{ + std::make_shared()}; + // Serializes 'pathToInput_' and 'fileIds_' in multithread test. std::mutex mutex_; std::vector fileIds_; folly::F14FastMap> pathToInput_; - std::shared_ptr tempDirectory_; + std::shared_ptr tempDirectory_; cache::FileGroupStats* groupStats_ = nullptr; std::shared_ptr allocator_; std::shared_ptr cache_; - std::unique_ptr asyncDataCacheHelper_; - std::unique_ptr ssdCacheHelper_; - std::shared_ptr ioStats_; - std::shared_ptr fsStats_; + std::unique_ptr asyncDataCacheHelper_; + std::unique_ptr ssdCacheHelper_; + std::shared_ptr ioStats_; std::unique_ptr executor_; std::shared_ptr pool_{ memory::memoryManager()->addLeafPool()}; @@ -481,10 +489,15 @@ TEST_F(CacheTest, window) { cache_.get(), tracker, groupId, + dataIoStats_, ioStats_, - fsStats_, executor_.get(), - io::ReaderOptions(pool_.get())); + [&] { + io::ReaderOptions opts(pool_.get()); + opts.setDataIoStats(dataIoStats_); + opts.setMetadataIoStats(metadataIoStats_); + return opts; + }()); auto begin = 4 * kMB; auto end = 17 * kMB; auto stream = input->read(begin, end - begin, LogType::TEST); @@ -511,7 +524,7 @@ TEST_F(CacheTest, window) { auto clone = cacheInput->clone(); clone->SkipInt64(100); clone->setRemainingBytes(kMB); - auto previousRead = ioStats_->rawBytesRead(); + auto previousRead = dataIoStats_->rawBytesRead(); EXPECT_TRUE(clone->Next(&buffer, &size)); // Half MB minus the 100 bytes skipped above should be left in the first load // quantum of 8MB. @@ -520,7 +533,7 @@ TEST_F(CacheTest, window) { EXPECT_EQ(kMB / 2 + 100, size); // There should be no more data in the window. EXPECT_FALSE(clone->Next(&buffer, &size)); - EXPECT_EQ(kMB, ioStats_->rawBytesRead() - previousRead); + EXPECT_EQ(kMB, dataIoStats_->rawBytesRead() - previousRead); } TEST_F(CacheTest, bufferedInput) { @@ -533,9 +546,9 @@ TEST_F(CacheTest, bufferedInput) { 10, 20, 4, - /*noCacheRetention=*/false, - ioStats_, - fsStats_); + /*cacheable=*/true, + dataIoStats_, + ioStats_); readLoop( "testfile", 30, @@ -543,9 +556,9 @@ TEST_F(CacheTest, bufferedInput) { 10, 20, 4, - /*noCacheRetention=*/false, - ioStats_, - fsStats_); + /*cacheable=*/true, + dataIoStats_, + ioStats_); readLoop( "testfile2", 30, @@ -553,9 +566,9 @@ TEST_F(CacheTest, bufferedInput) { 70, 20, 4, - /*noCacheRetention=*/false, - ioStats_, - fsStats_); + /*cacheable=*/true, + dataIoStats_, + ioStats_); } // Calibrates the data read for a densely and sparsely read stripe of test data. @@ -578,15 +591,15 @@ TEST_F(CacheTest, ssd) { 1, 1, 1, - /*noCacheRetention=*/false, - ioStats_, - fsStats_); + /*cacheable=*/true, + dataIoStats_, + ioStats_); // This is a cold read, so expect no hits. - EXPECT_EQ(0, ioStats_->ramHit().sum()); + EXPECT_EQ(0, dataIoStats_->ramHit().sum()); // Expect some extra reading from coalescing. - EXPECT_LT(0, ioStats_->rawOverreadBytes()); - auto fullStripeBytes = ioStats_->rawBytesRead(); - auto bytes = ioStats_->rawBytesRead(); + EXPECT_LT(0, dataIoStats_->rawOverreadBytes()); + auto fullStripeBytes = dataIoStats_->rawBytesRead(); + auto bytes = dataIoStats_->rawBytesRead(); cache_->clear(); // We read 10 stripes with some columns sparsely accessed. readLoop( @@ -596,13 +609,13 @@ TEST_F(CacheTest, ssd) { 10, 10, 1, - /*noCacheRetention=*/false, - ioStats_, - fsStats_); - auto sparseStripeBytes = (ioStats_->rawBytesRead() - bytes) / 10; + /*cacheable=*/true, + dataIoStats_, + ioStats_); + auto sparseStripeBytes = (dataIoStats_->rawBytesRead() - bytes) / 10; EXPECT_LT(sparseStripeBytes, fullStripeBytes / 4); // Expect the dense fraction of columns to have read ahead. - EXPECT_LT(400'000, ioStats_->prefetch().sum()); + EXPECT_LT(400'000, dataIoStats_->prefetch().sum()); constexpr int32_t kStripesPerFile = 10; auto bytesPerFile = fullStripeBytes * kStripesPerFile; @@ -623,13 +636,13 @@ TEST_F(CacheTest, ssd) { kStripesPerFile, 4); // Expect some hits from SSD. - EXPECT_LE(kSsdBytes / 8, ioStats_->ssdRead().sum()); + EXPECT_LE(kSsdBytes / 8, dataIoStats_->ssdRead().sum()); // We expec some prefetch but the quantity is nondeterminstic // because cases where the main thread reads the data ahead of // background reader does not count as prefetch even if prefetch was // issued. Also, the head of each file does not get prefetched // because each file has its own tracker. - EXPECT_LE(kSsdBytes / 8, ioStats_->prefetch().sum()); + EXPECT_LE(kSsdBytes / 8, dataIoStats_->prefetch().sum()); readFiles( "prefix1_", @@ -657,9 +670,9 @@ TEST_F(CacheTest, singleFileThreads) { 10, 20, 4, - /*noCacheRetention=*/false, - ioStats_, - fsStats_); + /*cacheable=*/true, + dataIoStats_, + ioStats_); })); } for (auto i = 0; i < numThreads; ++i) { @@ -675,16 +688,19 @@ TEST_F(CacheTest, ssdThreads) { stats.reserve(kNumThreads); std::vector threads; threads.reserve(kNumThreads); - std::vector> fsStats; - fsStats.reserve(kNumThreads); + std::vector> ioStatsVec; + ioStatsVec.reserve(kNumThreads); // We read 4 files on 8 threads. Threads 0 and 1 read file 0, 2 and 3 read // file 1 etc. Each tread reads its file 4 times. for (int i = 0; i < kNumThreads; ++i) { stats.push_back(std::make_shared()); - fsStats.push_back(std::make_shared()); - threads.push_back(std::thread( - [i, this, threadStats = stats.back(), fsStat = fsStats.back()]() { + ioStatsVec.push_back(std::make_shared()); + threads.push_back( + std::thread([i, + this, + threadStats = stats.back(), + ioStat = ioStatsVec.back()]() { for (auto counter = 0; counter < 4; ++counter) { readLoop( fmt::format("testfile{}", i / 2), @@ -693,9 +709,9 @@ TEST_F(CacheTest, ssdThreads) { 10, 20, 2, - /*noCacheRetention=*/false, + /*cacheable=*/true, threadStats, - fsStat); + ioStat); } })); } @@ -725,13 +741,15 @@ class FileWithReadAhead { const std::string& name, cache::AsyncDataCache* cache, IoStatisticsPtr stats, - std::shared_ptr fsStats, + std::shared_ptr ioStats, memory::MemoryPool& pool, folly::Executor* executor) : options_(&pool) { + options_.setDataIoStats(dataIoStats_); + options_.setMetadataIoStats(metadataIoStats_); fileId_ = std::make_unique(fileIds(), name); - file_ = std::make_shared(fileId_->id(), kFileSize, fsStats); - options_.setNoCacheRetention(true); + file_ = std::make_shared(fileId_->id(), kFileSize, ioStats); + options_.setCacheable(false); bufferedInput_ = std::make_unique( file_, MetricsLog::voidLog(), @@ -740,13 +758,13 @@ class FileWithReadAhead { nullptr, StringIdLease{}, stats, - fsStats, + ioStats, executor, options_); auto sequential = StreamIdentifier::sequentialFile(); stream_ = bufferedInput_->enqueue(Region{0, file_->size()}, &sequential); - VELOX_CHECK(reinterpret_cast(stream_.get()) - ->testingNoCacheRetention()); + VELOX_CHECK(!reinterpret_cast(stream_.get()) + ->testingCacheable()); // Trigger load of next 4MB after reading the first 2MB of the previous 4MB // quantum. reinterpret_cast(stream_.get())->setPrefetchPct(50); @@ -758,6 +776,11 @@ class FileWithReadAhead { } private: + std::shared_ptr dataIoStats_{ + std::make_shared()}; + std::shared_ptr metadataIoStats_{ + std::make_shared()}; + std::unique_ptr fileId_; std::unique_ptr bufferedInput_; std::unique_ptr stream_; @@ -778,8 +801,8 @@ TEST_F(CacheTest, readAhead) { stats.reserve(kNumThreads); std::vector threads; threads.reserve(kNumThreads); - std::vector> fsStats; - fsStats.reserve(kNumThreads); + std::vector> ioStatsVec; + ioStatsVec.reserve(kNumThreads); // We read kFilesPerThread on each thread. The files are read in parallel, // advancing each file in turn. Read-ahead is triggered when a fraction of the @@ -787,59 +810,68 @@ TEST_F(CacheTest, readAhead) { for (int threadIndex = 0; threadIndex < kNumThreads; ++threadIndex) { stats.push_back(std::make_shared()); - fsStats.push_back(std::make_shared()); - threads.push_back(std::thread([threadIndex, - this, - threadStats = stats.back(), - fsStat = fsStats.back()]() { - std::vector> files; - auto firstFileNumber = threadIndex * kFilesPerThread; - for (auto i = 0; i < kFilesPerThread; ++i) { - auto name = fmt::format("prefetch_{}", i + firstFileNumber); - files.push_back(std::make_unique( - name, cache_.get(), threadStats, fsStat, *pool_, executor_.get())); - } - std::vector totalRead(kFilesPerThread); - std::vector bytesLeft(kFilesPerThread); - for (auto counter = 0; counter < 100; ++counter) { - for (auto i = 0; i < kFilesPerThread; ++i) { - if (!files[i]) { - continue; // This set of files is finished. + ioStatsVec.push_back(std::make_shared()); + threads.push_back( + std::thread([threadIndex, + this, + threadStats = stats.back(), + ioStat = ioStatsVec.back()]() { + std::vector> files; + auto firstFileNumber = threadIndex * kFilesPerThread; + for (auto i = 0; i < kFilesPerThread; ++i) { + auto name = fmt::format("prefetch_{}", i + firstFileNumber); + files.push_back( + std::make_unique( + name, + cache_.get(), + threadStats, + ioStat, + *pool_, + executor_.get())); } - // Read from the next file. Different files advance at slightly - // different rates. - auto bytesNeeded = kMinRead + i * 1000; - while (bytesLeft[i] < bytesNeeded) { - const void* buffer; - int32_t size; - if (!files[i]->next(buffer, size)) { - // End of file. Check that a multiple of file size has been read. - EXPECT_EQ(0, totalRead[i] % FileWithReadAhead::kFileSize); - if (totalRead[i] >= 3 * FileWithReadAhead::kFileSize) { - files[i] = nullptr; - break; + std::vector totalRead(kFilesPerThread); + std::vector bytesLeft(kFilesPerThread); + for (auto counter = 0; counter < 100; ++counter) { + for (auto i = 0; i < kFilesPerThread; ++i) { + if (!files[i]) { + continue; // This set of files is finished. + } + // Read from the next file. Different files advance at slightly + // different rates. + auto bytesNeeded = kMinRead + i * 1000; + while (bytesLeft[i] < bytesNeeded) { + const void* buffer; + int32_t size; + if (!files[i]->next(buffer, size)) { + // End of file. Check that a multiple of file size has been + // read. + EXPECT_EQ(0, totalRead[i] % FileWithReadAhead::kFileSize); + if (totalRead[i] >= 3 * FileWithReadAhead::kFileSize) { + files[i] = nullptr; + break; + } + // Open a new file with a different unique name. + auto newName = fmt::format( + "prefetch_{}", + (static_cast(firstFileNumber) + i + i) * + 1000000000 + + totalRead[i]); + files[i] = std::make_unique( + newName, + cache_.get(), + threadStats, + ioStat, + *pool_, + executor_.get()); + continue; + } + totalRead[i] += size; + bytesLeft[i] += size; } - // Open a new file with a different unique name. - auto newName = fmt::format( - "prefetch_{}", - (static_cast(firstFileNumber) + i + i) * 1000000000 + - totalRead[i]); - files[i] = std::make_unique( - newName, - cache_.get(), - threadStats, - fsStat, - *pool_, - executor_.get()); - continue; + bytesLeft[i] -= bytesNeeded; } - totalRead[i] += size; - bytesLeft[i] += size; } - bytesLeft[i] -= bytesNeeded; - } - } - })); + })); } int64_t bytes = 0; int32_t count = 0; @@ -853,29 +885,29 @@ TEST_F(CacheTest, readAhead) { LOG(INFO) << count << " prefetches with total " << bytes << " bytes"; } -TEST_F(CacheTest, noCacheRetention) { +TEST_F(CacheTest, cacheable) { const int64_t cacheSize = 1LL << 30; struct { - bool noCacheRetention; + bool cacheable; bool hasSsdCache; int readPct; std::string debugString() const { return fmt::format( - "noCacheRetention {}, hasSsdCache {}, readPct {}", - noCacheRetention, + "cacheable {}, hasSsdCache {}, readPct {}", + cacheable, hasSsdCache, readPct); } } testSettings[] = { - {true, true, 100}, - {true, false, 100}, - {false, false, 100}, {false, true, 100}, - {true, true, 10}, + {false, false, 100}, {true, false, 100}, + {true, true, 100}, + {false, true, 10}, {false, false, 100}, - {false, true, 100}}; + {true, false, 100}, + {true, true, 100}}; for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); @@ -887,26 +919,26 @@ TEST_F(CacheTest, noCacheRetention) { // We read one stripe with all columns, readLoop( - "noCacheRetention", + "cacheable", 20, testData.readPct, 1, 5, 1, - testData.noCacheRetention, - ioStats_, - fsStats_); + testData.cacheable, + dataIoStats_, + ioStats_); // This is a cold read, so expect no hits. - ASSERT_EQ(ioStats_->ramHit().sum(), 0); + ASSERT_EQ(dataIoStats_->ramHit().sum(), 0); // Only one reference per column so there is no prefetch. - ASSERT_LT(0, ioStats_->prefetch().sum()); + ASSERT_LT(0, dataIoStats_->prefetch().sum()); // Expect some extra reading from coalescing. - ASSERT_LT(0, ioStats_->rawOverreadBytes()); - ASSERT_LT(0, ioStats_->rawBytesRead()); + ASSERT_LT(0, dataIoStats_->rawOverreadBytes()); + ASSERT_LT(0, dataIoStats_->rawBytesRead()); auto* ssdCache = cache_->ssdCache(); if (ssdCache != nullptr) { ssdCache->waitForWriteToFinish(); - if (testData.noCacheRetention) { + if (!testData.cacheable) { ASSERT_EQ(ssdCache->stats().entriesCached, 0); } else { ASSERT_GT(ssdCache->stats().entriesCached, 0); @@ -916,8 +948,9 @@ TEST_F(CacheTest, noCacheRetention) { const auto cacheEntries = asyncDataCacheHelper_->cacheEntries(); for (const auto& cacheEntry : cacheEntries) { const auto cacheEntryHelper = - std::make_unique(cacheEntry); - if (testData.noCacheRetention) { + std::make_unique( + cacheEntry); + if (!testData.cacheable) { ASSERT_EQ(cacheEntryHelper->accessStats().numUses, 0); ASSERT_EQ(cacheEntryHelper->accessStats().lastUse, 0); } else { @@ -937,6 +970,8 @@ TEST_F(CacheTest, loadQuotumTooLarge) { auto readFile = std::make_shared(fileId.id(), 10 << 20, nullptr); auto readOptions = io::ReaderOptions(pool_.get()); + readOptions.setDataIoStats(dataIoStats_); + readOptions.setMetadataIoStats(metadataIoStats_); readOptions.setLoadQuantum(9 << 20 /*9MB*/); VELOX_ASSERT_THROW( std::make_unique( @@ -971,10 +1006,15 @@ TEST_F(CacheTest, ssdReadVerification) { cache_.get(), tracker, groupId, + dataIoStats_, ioStats_, - fsStats_, executor_.get(), - io::ReaderOptions(pool_.get())); + [&] { + io::ReaderOptions opts(pool_.get()); + opts.setDataIoStats(dataIoStats_); + opts.setMetadataIoStats(metadataIoStats_); + return opts; + }()); const auto readData = [&](uint32_t numBytesRead) { const uint64_t kNumBytesPerRead = 4 << 20; @@ -999,9 +1039,15 @@ TEST_F(CacheTest, ssdReadVerification) { ASSERT_EQ(stats.numHit, 0); ASSERT_EQ(stats.ssdStats->entriesRead, 0); ASSERT_EQ(stats.ssdStats->readSsdCorruptions, 0); - ASSERT_GT(ioStats_->read().sum(), 0); - ASSERT_EQ(ioStats_->ramHit().sum(), 0); - ASSERT_EQ(ioStats_->ssdRead().sum(), 0); + ASSERT_GT(dataIoStats_->read().sum(), 0); + ASSERT_EQ(dataIoStats_->ramHit().sum(), 0); + ASSERT_EQ(dataIoStats_->ssdRead().sum(), 0); + // Cold read should have remote storage latency. + ASSERT_GT(dataIoStats_->storageReadLatencyUs().count(), 0); + // This test does not use coalesced loading for cold reads, so no coalesced + // latency is expected. + ASSERT_EQ(dataIoStats_->coalescedSsdLoadLatencyUs().count(), 0); + ASSERT_EQ(dataIoStats_->ssdCacheReadLatencyUs().count(), 0); // Read kSsdBytes of data. readData(kSsdBytes); @@ -1011,9 +1057,10 @@ TEST_F(CacheTest, ssdReadVerification) { ASSERT_GT(stats.numHit, 0); ASSERT_EQ(stats.ssdStats->entriesRead, 0); ASSERT_EQ(stats.ssdStats->readSsdCorruptions, 0); - ASSERT_GT(ioStats_->read().sum(), 0); - ASSERT_GT(ioStats_->ramHit().sum(), 0); - ASSERT_EQ(ioStats_->ssdRead().sum(), 0); + ASSERT_GT(dataIoStats_->read().sum(), 0); + ASSERT_GT(dataIoStats_->ramHit().sum(), 0); + ASSERT_EQ(dataIoStats_->ssdRead().sum(), 0); + ASSERT_EQ(dataIoStats_->ssdCacheReadLatencyUs().count(), 0); // Read kSsdBytes of data. readData(kSsdBytes); @@ -1023,9 +1070,10 @@ TEST_F(CacheTest, ssdReadVerification) { ASSERT_GT(stats.numHit, 0); ASSERT_GT(stats.ssdStats->entriesRead, 0); ASSERT_EQ(stats.ssdStats->readSsdCorruptions, 0); - ASSERT_GT(ioStats_->read().sum(), 0); - ASSERT_GT(ioStats_->ramHit().sum(), 0); - ASSERT_GT(ioStats_->ssdRead().sum(), 0); + ASSERT_GT(dataIoStats_->read().sum(), 0); + ASSERT_GT(dataIoStats_->ramHit().sum(), 0); + ASSERT_GT(dataIoStats_->ssdRead().sum(), 0); + ASSERT_GT(dataIoStats_->ssdCacheReadLatencyUs().count(), 0); // Corrupt SSD cache file. corruptSsdFile(fmt::format("{}/cache0", tempDirectory_->getPath())); @@ -1034,22 +1082,22 @@ TEST_F(CacheTest, ssdReadVerification) { // Record the baseline stats. const auto prevStats = cache_->refreshStats(); - const auto prevRead = ioStats_->read().sum(); - const auto prevRamHit = ioStats_->ramHit().sum(); - const auto prevSsdRead = ioStats_->ssdRead().sum(); + const auto prevRead = dataIoStats_->read().sum(); + const auto prevRamHit = dataIoStats_->ramHit().sum(); + const auto prevSsdRead = dataIoStats_->ssdRead().sum(); // Read from the corrupted cache. readData(kSsdBytes); waitForWrite(); stats = cache_->refreshStats(); // Expect all new reads to be recorded as corruptions. - ASSERT_GT(ioStats_->read().sum(), prevRead); + ASSERT_GT(dataIoStats_->read().sum(), prevRead); ASSERT_GT(stats.ssdStats->readSsdCorruptions, 0); ASSERT_EQ( stats.ssdStats->readSsdCorruptions, stats.ssdStats->entriesRead - prevStats.ssdStats->entriesRead); // Expect no new succeeded cache hits. ASSERT_EQ(stats.numHit, prevStats.numHit); - ASSERT_EQ(ioStats_->ramHit().sum(), prevRamHit); - ASSERT_EQ(ioStats_->ssdRead().sum(), prevSsdRead); + ASSERT_EQ(dataIoStats_->ramHit().sum(), prevRamHit); + ASSERT_EQ(dataIoStats_->ssdRead().sum(), prevSsdRead); } diff --git a/velox/dwio/dwrf/test/ColumnStatisticsBase.h b/velox/dwio/dwrf/test/ColumnStatisticsBase.h index 32b09aacfcf..e1264cd9f27 100644 --- a/velox/dwio/dwrf/test/ColumnStatisticsBase.h +++ b/velox/dwio/dwrf/test/ColumnStatisticsBase.h @@ -18,10 +18,14 @@ #include +#include "velox/dwio/common/Arena.h" #include "velox/dwio/common/Statistics.h" #include "velox/dwio/dwrf/writer/StatisticsBuilder.h" namespace facebook::velox::dwrf { + +using dwio::common::ArenaCreate; + class ColumnStatisticsBase { public: ColumnStatisticsBase() @@ -752,8 +756,7 @@ class ColumnStatisticsBase { if (format == DwrfFormat::kDwrf) { auto columnStatistics = - google::protobuf::Arena::CreateMessage( - arena_.get()); + ArenaCreate(arena_.get()); if (from == State::kFalse) { columnStatistics->set_hasnull(false); } else if (from == State::kTrue) { @@ -762,8 +765,8 @@ class ColumnStatisticsBase { target.merge(*buildColumnStatisticsFromProto( ColumnStatisticsWrapper(columnStatistics), context())); } else { - auto columnStatistics = google::protobuf::Arena::CreateMessage< - proto::orc::ColumnStatistics>(arena_.get()); + auto columnStatistics = + ArenaCreate(arena_.get()); if (from == State::kFalse) { columnStatistics->set_hasnull(false); } else if (from == State::kTrue) { diff --git a/velox/dwio/dwrf/test/ColumnWriterIndexTest.cpp b/velox/dwio/dwrf/test/ColumnWriterIndexTest.cpp index a8abecaed2f..abf23b57b67 100644 --- a/velox/dwio/dwrf/test/ColumnWriterIndexTest.cpp +++ b/velox/dwio/dwrf/test/ColumnWriterIndexTest.cpp @@ -371,9 +371,10 @@ class WriterEncodingIndexTest2 { } } proto::StripeFooter stripeFooter; + auto sfw = StripeFooterWriteWrapper(&stripeFooter); columnWriter->flush( - [&stripeFooter](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *stripeFooter.add_encoding(); + [&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); // Simulate continue writing to next stripe, so internally buffered data @@ -821,9 +822,10 @@ class IntegerColumnWriterDirectEncodingIndexTest : public testing::Test { // *all* streams EXPECT_CALL(*mockIndexBuilderPtr, add(0, -1)).Times(positionCount); proto::StripeFooter stripeFooter; + auto sfw = StripeFooterWriteWrapper(&stripeFooter); columnWriter->flush( - [&stripeFooter](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *stripeFooter.add_encoding(); + [&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); } else { for (size_t i = 0; i != pageCount; ++i) { @@ -847,9 +849,10 @@ class IntegerColumnWriterDirectEncodingIndexTest : public testing::Test { EXPECT_CALL(*mockIndexBuilderPtr, flush()); EXPECT_CALL(*mockIndexBuilderPtr, add(0, -1)).Times(positionCount); proto::StripeFooter stripeFooter; + auto sfw = StripeFooterWriteWrapper(&stripeFooter); columnWriter->flush( - [&stripeFooter](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *stripeFooter.add_encoding(); + [&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); } @@ -972,9 +975,10 @@ class StringColumnWriterDictionaryEncodingIndexTest : public testing::Test { // Recording PRESENT stream starting positions for the new stripe. EXPECT_CALL(*mockIndexBuilderPtr, add(0, -1)).Times(4); proto::StripeFooter stripeFooter; + auto sfw = StripeFooterWriteWrapper(&stripeFooter); columnWriter->flush( - [&stripeFooter](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *stripeFooter.add_encoding(); + [&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); // Simulate continue writing to next stripe, so internally buffered data @@ -1128,9 +1132,10 @@ class StringColumnWriterDirectEncodingIndexTest : public testing::Test { // *all* streams EXPECT_CALL(*mockIndexBuilderPtr, add(0, -1)).Times(positionCount); proto::StripeFooter stripeFooter; + auto sfw = StripeFooterWriteWrapper(&stripeFooter); columnWriter->flush( - [&stripeFooter](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *stripeFooter.add_encoding(); + [&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); } else { for (size_t i = 0; i != pageCount; ++i) { @@ -1154,9 +1159,10 @@ class StringColumnWriterDirectEncodingIndexTest : public testing::Test { EXPECT_CALL(*mockIndexBuilderPtr, flush()); EXPECT_CALL(*mockIndexBuilderPtr, add(0, -1)).Times(positionCount); proto::StripeFooter stripeFooter; + auto sfw = StripeFooterWriteWrapper(&stripeFooter); columnWriter->flush( - [&stripeFooter](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *stripeFooter.add_encoding(); + [&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); } @@ -1208,7 +1214,7 @@ class ListColumnWriterEncodingIndexTest : public testing::Test, public WriterEncodingIndexTest2 { public: ListColumnWriterEncodingIndexTest() - : WriterEncodingIndexTest2(ARRAY(REAL())){}; + : WriterEncodingIndexTest2(ARRAY(REAL())) {}; protected: static void SetUpTestCase() { diff --git a/velox/dwio/dwrf/test/ColumnWriterStatsTests.cpp b/velox/dwio/dwrf/test/ColumnWriterStatsTests.cpp index 8dad33a0d98..03cada99169 100644 --- a/velox/dwio/dwrf/test/ColumnWriterStatsTests.cpp +++ b/velox/dwio/dwrf/test/ColumnWriterStatsTests.cpp @@ -18,6 +18,7 @@ #include #include "velox/common/base/Nulls.h" +#include "velox/common/io/IoStatistics.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" #include "velox/dwio/dwrf/reader/DwrfReader.h" #include "velox/dwio/dwrf/writer/FlushPolicy.h" @@ -194,7 +195,9 @@ class ColumnWriterStatsTest : public ::testing::Test { std::make_shared(std::move(data)); auto input = std::make_unique(readFile, *leafPool_); - dwio::common::ReaderOptions readerOpts{leafPool_.get()}; + dwio::common::ReaderOptions readerOpts(leafPool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); RowReaderOptions rowReaderOpts; auto reader = std::make_unique(readerOpts, std::move(input)); return reader->createRowReader(rowReaderOpts); @@ -228,6 +231,10 @@ class ColumnWriterStatsTest : public ::testing::Test { std::shared_ptr rootPool_; std::shared_ptr leafPool_; + std::shared_ptr dataIoStats_ = + std::make_shared(); + std::shared_ptr metadataIoStats_ = + std::make_shared(); }; template diff --git a/velox/dwio/dwrf/test/ColumnWriterTest.cpp b/velox/dwio/dwrf/test/ColumnWriterTest.cpp index a1820322ddf..804c96f2103 100644 --- a/velox/dwio/dwrf/test/ColumnWriterTest.cpp +++ b/velox/dwio/dwrf/test/ColumnWriterTest.cpp @@ -213,7 +213,7 @@ VectorPtr populateBatch( auto valuesPtr = values->asMutableRange(); const size_t nulloptCount = - std::count(data.begin(), data.end(), std::nullopt); + std::count(data.cbegin(), data.cend(), std::nullopt); if (nulloptCount == 0) { size_t index = 0; for (auto val : data) { @@ -354,12 +354,13 @@ void testDataTypeWriter( for (auto stripeI = 0; stripeI < stripeCount; ++stripeI) { proto::StripeFooter sf; + auto sfw = StripeFooterWriteWrapper(&sf); for (auto strideI = 0; strideI < strideCount; ++strideI) { writer->write(batch, common::Ranges::of(0, size)); writer->createIndexEntry(); } - writer->flush([&sf](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *sf.add_encoding(); + writer->flush([&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); TestStripeStreams streams(context, sf, rowType, pool.get()); @@ -1055,6 +1056,7 @@ void testMapWriter( } proto::StripeFooter sf; + auto sfw = StripeFooterWriteWrapper(&sf); std::vector writtenBatches; // Write map/row @@ -1072,8 +1074,8 @@ void testMapWriter( writtenBatches.push_back(toWrite); } - writer->flush([&sf](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *sf.add_encoding(); + writer->flush([&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); auto validate = [&](bool returnFlatVector = false) { @@ -1199,6 +1201,7 @@ void testMapWriterRow( } proto::StripeFooter sf; + auto sfw = StripeFooterWriteWrapper(&sf); std::vector writtenBatches; // Write map/row @@ -1210,8 +1213,8 @@ void testMapWriterRow( writer->createIndexEntry(); writtenBatches.push_back(toWrite); - writer->flush([&sf](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *sf.add_encoding(); + writer->flush([&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); auto validate = [&](bool returnFlatVector = false) { @@ -1479,8 +1482,9 @@ void testFlatMapWriter( writer->createIndexEntry(); proto::StripeFooter sf; - writer->flush([&sf](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *sf.add_encoding(); + auto sfw = StripeFooterWriteWrapper(&sf); + writer->flush([&sfw](uint32_t /*unused*/) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); // Reading the vector out @@ -1503,6 +1507,93 @@ void testFlatMapWriter( } } +// Regression test for dangling StringView keys in FlatMapColumnWriter. +// +// FlatMapColumnWriter stores StringView keys in an F14NodeMap (valueWriters_). +// StringView is non-owning — it holds a pointer to external string data. +// When the input batch is released between writes, the stored StringViews +// become dangling. On a subsequent write that triggers F14 rehash, +// F14Table::rehashImpl recomputes hashes from the dangling pointers and +// hits assertion: hp.second == srcChunk->tag(srcI). +// +// This test reproduces the crash by: +// 1. Writing a batch with long string keys (>12 chars → external pointer) +// 2. Corrupting the backing buffer to simulate freed memory +// 3. Writing a second batch with enough new keys to trigger F14 rehash +// +// With the bug: crashes with SIGABRT in F14Table::rehashImpl. +// With the fix: passes (keys are properly owned). +TEST_F(ColumnWriterTest, TestFlatMapDanglingStringViewKeyOnRehash) { + const auto rowType = CppToType>>::create(); + const auto writerSchema = TypeWithId::create(rowType); + const auto writerDataTypeWithId = writerSchema->childAt(0); + + const auto config = std::make_shared(); + config->set(Config::FLATTEN_MAP, true); + config->set(Config::MAP_FLAT_COLS, {writerDataTypeWithId->column()}); + config->set(Config::MAP_FLAT_DISABLE_DICT_ENCODING, true); + config->set(Config::MAP_FLAT_MAX_KEYS, 100); + + WriterContext context{config, memory::memoryManager()->addRootPool()}; + context.initBuffer(); + const auto writer = BaseColumnWriter::create(context, *writerDataTypeWithId); + + using b = MapBuilder; + + // Batch 1: 10 rows, each with one key-value pair. + // Keys are >12 chars so StringView uses an external data pointer + // (not inline storage). We control the backing buffer so we can + // corrupt it after writing to reliably simulate use-after-free. + constexpr int kBatch1Size = 10; + constexpr int kKeySlotSize = 32; + auto keyDataBuf = std::make_unique(kBatch1Size * kKeySlotSize); + + { + b::rows rows; + for (int i = 0; i < kBatch1Size; ++i) { + auto key = fmt::format("very_long_string_key_{:04d}", i); + ASSERT_GT(key.size(), 12u); + memcpy(keyDataBuf.get() + i * kKeySlotSize, key.data(), key.size()); + StringView sv( + keyDataBuf.get() + i * kKeySlotSize, + static_cast(key.size())); + rows.emplace_back(b::row{b::pair{sv, static_cast(i)}}); + } + auto batch = b::create(*pool_, rows); + writer->write(batch, common::Ranges::of(0, batch->size())); + } + + // Corrupt batch 1's key data to simulate the input vector's string + // buffer being freed and reused. The writer's F14NodeMap still holds + // StringView keys pointing into this buffer. + memset(keyDataBuf.get(), 0xFF, kBatch1Size * kKeySlotSize); + + // Batch 2: 15 rows with new keys. Combined with batch 1's 10 keys, + // the total of 25 unique keys exceeds the F14NodeMap's initial + // capacity, triggering rehashImpl. During rehash, F14 recomputes + // hashes for batch 1's stored StringView keys — which now point to + // corrupted memory — producing different hashes and hitting the + // assertion: hp.second == srcChunk->tag(srcI). + { + std::vector keyStrings; + keyStrings.reserve(15); + for (int i = 0; i < 15; ++i) { + keyStrings.push_back(fmt::format("different_long_key_b2_{:04d}", i)); + } + b::rows rows; + for (int i = 0; i < 15; ++i) { + rows.emplace_back( + b::row{b::pair{ + StringView(keyStrings[i]), + static_cast(i + kBatch1Size)}}); + } + auto batch = b::create(*pool_, rows); + writer->write(batch, common::Ranges::of(0, batch->size())); + } + + writer->createIndexEntry(); +} + TEST_F(ColumnWriterTest, TestFlatMapKeyNotInAllBatches) { VectorMaker maker(pool_.get()); // Test the case where not all keys appear in all batches. @@ -1846,7 +1937,9 @@ std::unique_ptr getDwrfReader( MemoryPool& leafPool, const std::shared_ptr type, const VectorPtr& batch, - bool useFlatMap) { + bool useFlatMap, + std::shared_ptr dataIoStats, + std::shared_ptr metadataIoStats) { auto config = std::make_shared(); if (useFlatMap) { config->set(Config::FLATTEN_MAP, true); @@ -1872,6 +1965,8 @@ std::unique_ptr getDwrfReader( std::string data(sinkPtr->data(), sinkPtr->size()); dwio::common::ReaderOptions readerOpts{&leafPool}; + readerOpts.setDataIoStats(std::move(dataIoStats)); + readerOpts.setMetadataIoStats(std::move(metadataIoStats)); return std::make_unique( readerOpts, std::make_unique( @@ -1894,8 +1989,12 @@ void testMapWriterStats(const std::shared_ptr type) { auto rootPool = memory::memoryManager()->addRootPool(); auto leafPool = memory::memoryManager()->addLeafPool(); auto batch = BatchMaker::createBatch(type, 10, *leafPool); - auto mapReader = getDwrfReader(*rootPool, *leafPool, type, batch, false); - auto flatMapReader = getDwrfReader(*rootPool, *leafPool, type, batch, true); + auto dataIoStats = std::make_shared(); + auto metadataIoStats = std::make_shared(); + auto mapReader = getDwrfReader( + *rootPool, *leafPool, type, batch, false, dataIoStats, metadataIoStats); + auto flatMapReader = getDwrfReader( + *rootPool, *leafPool, type, batch, true, dataIoStats, metadataIoStats); ASSERT_EQ( mapReader->getFooter().statisticsSize(), flatMapReader->getFooter().statisticsSize()); @@ -2364,6 +2463,7 @@ struct IntegerColumnWriterTypedTestCase { for (size_t i = 0; i != flushCount; ++i) { proto::StripeFooter stripeFooter; + auto sfw = StripeFooterWriteWrapper(&stripeFooter); for (size_t j = 0; j != repetitionCount; ++j) { columnWriter->write(batch, common::Ranges::of(0, batch->size())); postProcess(*columnWriter, i, j); @@ -2371,8 +2471,8 @@ struct IntegerColumnWriterTypedTestCase { } // We only flush once per stripe. columnWriter->flush( - [&stripeFooter](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *stripeFooter.add_encoding(); + [&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); // Read and verify. @@ -3524,7 +3624,7 @@ struct StringColumnWriterTestCase { postProcess{postProcess}, repetitionCount{repetitionCount}, flushCount{flushCount}, - type{CppToType::create()} {} + type{CppToType::create()} {} virtual ~StringColumnWriterTestCase() = default; @@ -3598,6 +3698,7 @@ struct StringColumnWriterTestCase { for (size_t i = 0; i != flushCount; ++i) { proto::StripeFooter stripeFooter; + auto sfw = StripeFooterWriteWrapper(&stripeFooter); // Write Stride for (size_t j = 0; j != repetitionCount; ++j) { // TODO: break the batch into multiple strides. @@ -3608,8 +3709,8 @@ struct StringColumnWriterTestCase { // Flush when all strides are written (once per stripe). columnWriter->flush( - [&stripeFooter](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *stripeFooter.add_encoding(); + [&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); // Read and verify. @@ -4438,8 +4539,9 @@ TEST_F(ColumnWriterTest, IntDictWriterDirectValueOverflow) { writer->write(vector, common::Ranges::of(0, size)); writer->createIndexEntry(); proto::StripeFooter sf; - writer->flush([&sf](auto /* unused */) -> proto::ColumnEncoding& { - return *sf.add_encoding(); + auto sfw = StripeFooterWriteWrapper(&sf); + writer->flush([&sfw](auto /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); auto& enc = sf.encoding(0); ASSERT_EQ(enc.kind(), proto::ColumnEncoding_Kind_DICTIONARY); @@ -4483,8 +4585,9 @@ TEST_F(ColumnWriterTest, ShortDictWriterDictValueOverflow) { writer->write(vector, common::Ranges::of(0, size)); writer->createIndexEntry(); proto::StripeFooter sf; - writer->flush([&sf](auto /* unused */) -> proto::ColumnEncoding& { - return *sf.add_encoding(); + auto sfw = StripeFooterWriteWrapper(&sf); + writer->flush([&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); auto& enc = sf.encoding(0); ASSERT_EQ(enc.kind(), proto::ColumnEncoding_Kind_DICTIONARY); @@ -4524,8 +4627,9 @@ TEST_F(ColumnWriterTest, RemovePresentStream) { writer->write(vector, common::Ranges::of(0, size)); writer->createIndexEntry(); proto::StripeFooter sf; - writer->flush([&sf](auto /* unused */) -> proto::ColumnEncoding& { - return *sf.add_encoding(); + auto sfw = StripeFooterWriteWrapper(&sf); + writer->flush([&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); // get data stream @@ -4562,8 +4666,9 @@ TEST_F(ColumnWriterTest, ColumnIdInStream) { writer->write(vector, common::Ranges::of(0, size)); writer->createIndexEntry(); proto::StripeFooter sf; - writer->flush([&sf](auto /* unused */) -> proto::ColumnEncoding& { - return *sf.add_encoding(); + auto sfw = StripeFooterWriteWrapper(&sf); + writer->flush([&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); // get data stream @@ -4691,8 +4796,9 @@ struct DictColumnWriterTestCase { writer->createIndexEntry(); proto::StripeFooter sf; - writer->flush([&sf](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *sf.add_encoding(); + auto sfw = StripeFooterWriteWrapper(&sf); + writer->flush([&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); // Reading the vector out diff --git a/velox/dwio/dwrf/test/CompressionTest.cpp b/velox/dwio/dwrf/test/CompressionTest.cpp index b2230ff21e6..9343af57e6d 100644 --- a/velox/dwio/dwrf/test/CompressionTest.cpp +++ b/velox/dwio/dwrf/test/CompressionTest.cpp @@ -401,10 +401,11 @@ TEST_P(CompressionTest, getCompressionBufferOOM) { {true, true}, {true, false}, {false, true}, {false, false}}; for (const auto& testData : testSettings) { - SCOPED_TRACE(fmt::format( - "{} compression {}", - testData.debugString(), - compressionKindToString(kind_))); + SCOPED_TRACE( + fmt::format( + "{} compression {}", + testData.debugString(), + compressionKindToString(kind_))); auto config = std::make_shared(); config->set(Config::COMPRESSION, kind_); diff --git a/velox/dwio/dwrf/test/ConfigTests.cpp b/velox/dwio/dwrf/test/ConfigTests.cpp index 6475d4e89bd..7097f1e1e99 100644 --- a/velox/dwio/dwrf/test/ConfigTests.cpp +++ b/velox/dwio/dwrf/test/ConfigTests.cpp @@ -78,6 +78,10 @@ struct ConfigTestParams { std::vector expectedCols{}; // do we expect the spec to be valid }; +inline void PrintTo(const ConfigTestParams& param, std::ostream* os) { + *os << "cols:" << param.inputCols; +} + TEST(ConfigTests, writerOptionsDefaultConfig) { WriterOptions options; const facebook::velox::config::ConfigBase base({}); diff --git a/velox/dwio/dwrf/test/DataBufferHolderTests.cpp b/velox/dwio/dwrf/test/DataBufferHolderTests.cpp index 6b53cbfab22..d4b5f6e1238 100644 --- a/velox/dwio/dwrf/test/DataBufferHolderTests.cpp +++ b/velox/dwio/dwrf/test/DataBufferHolderTests.cpp @@ -36,9 +36,15 @@ TEST_F(DataBufferHolderTest, InputCheck) { VELOX_ASSERT_THROW((DataBufferHolder{*pool_, 1024, 2048}), ""); VELOX_ASSERT_THROW((DataBufferHolder{*pool_, 1024, 1024, 1.1f}), ""); - { DataBufferHolder holder{*pool_, 1024}; } - { DataBufferHolder holder{*pool_, 1024, 512}; } - { DataBufferHolder holder{*pool_, 1024, 512, 3.0f}; } + { + DataBufferHolder holder{*pool_, 1024}; + } + { + DataBufferHolder holder{*pool_, 1024, 512}; + } + { + DataBufferHolder holder{*pool_, 1024, 512, 3.0f}; + } } TEST_F(DataBufferHolderTest, TakeAndGetBuffer) { diff --git a/velox/dwio/dwrf/test/TestDecompression.cpp b/velox/dwio/dwrf/test/DecompressionTest.cpp similarity index 96% rename from velox/dwio/dwrf/test/TestDecompression.cpp rename to velox/dwio/dwrf/test/DecompressionTest.cpp index 3dcf5f68e46..7a4d6b28377 100644 --- a/velox/dwio/dwrf/test/TestDecompression.cpp +++ b/velox/dwio/dwrf/test/DecompressionTest.cpp @@ -336,7 +336,7 @@ TEST_F(DecompressionTest, testLzoSmall) { std::unique_ptr result = createTestDecompressor( CompressionKind_LZO, std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))), + new SeekableArrayInputStream(buffer, std::size(buffer))), 128 * 1024); const void* ptr; int32_t length; @@ -353,7 +353,7 @@ TEST_F(DecompressionTest, testLzoSmall) { TEST_F(DecompressionTest, testLzoLong) { // set up a framed lzo buffer with 100,000 'a' unsigned char buffer[482]; - bzero(buffer, VELOX_ARRAY_SIZE(buffer)); + bzero(buffer, std::size(buffer)); // header buffer[0] = 190; buffer[1] = 3; @@ -378,7 +378,7 @@ TEST_F(DecompressionTest, testLzoLong) { std::unique_ptr result = createTestDecompressor( CompressionKind_LZO, std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))), + new SeekableArrayInputStream(buffer, std::size(buffer))), 128 * 1024); const void* ptr; int32_t length; @@ -413,7 +413,7 @@ TEST_F(DecompressionTest, testLz4Small) { std::unique_ptr result = createTestDecompressor( CompressionKind_LZ4, std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))), + new SeekableArrayInputStream(buffer, std::size(buffer))), 128 * 1024); const void* ptr; int32_t length; @@ -430,7 +430,7 @@ TEST_F(DecompressionTest, testLz4Small) { TEST_F(DecompressionTest, testLz4Long) { // set up a framed lzo buffer with 100,000 'a' unsigned char buffer[406]; - memset(buffer, 255, VELOX_ARRAY_SIZE(buffer)); + memset(buffer, 255, std::size(buffer)); // header buffer[0] = 38; buffer[1] = 3; @@ -448,7 +448,7 @@ TEST_F(DecompressionTest, testLz4Long) { std::unique_ptr result = createTestDecompressor( CompressionKind_LZ4, std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))), + new SeekableArrayInputStream(buffer, std::size(buffer))), 128 * 1024); const void* ptr; int32_t length; @@ -465,7 +465,7 @@ TEST_F(DecompressionTest, testCreateZlib) { std::unique_ptr result = createTestDecompressor( CompressionKind_ZLIB, std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))), + new SeekableArrayInputStream(buffer, std::size(buffer))), 32768); EXPECT_EQ( "PagedInputStream StreamInfo (Test Decompression) input stream (SeekableArrayInputStream 0 of 8) State (0) remaining length (0)", @@ -497,7 +497,7 @@ TEST_F(DecompressionTest, testLiteralBlocks) { std::unique_ptr result = createTestDecompressor( CompressionKind_ZLIB, std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer), 5)), + new SeekableArrayInputStream(buffer, std::size(buffer), 5)), 5); EXPECT_EQ( "PagedInputStream StreamInfo (Test Decompression) input stream (SeekableArrayInputStream 0 of 23) State (0) remaining length (0)", @@ -539,7 +539,7 @@ TEST_F(DecompressionTest, testInflate) { std::unique_ptr result = createTestDecompressor( CompressionKind_ZLIB, std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))), + new SeekableArrayInputStream(buffer, std::size(buffer))), 1000); const void* ptr; int32_t length; @@ -552,6 +552,25 @@ TEST_F(DecompressionTest, testInflate) { } } +TEST_F(DecompressionTest, testSmallBufferInflate) { + const unsigned char buffer[] = { + 0xe, 0x0, 0x0, 0x63, 0x60, 0x64, 0x62, 0xc0, 0x8d, 0x0}; + const std::unique_ptr result = createTestDecompressor( + CompressionKind_ZLIB, + std::make_unique(buffer, std::size(buffer)), + 1 // blockSize 1 to test multiple inflate calls during decompression. + ); + const void* ptr; + int32_t length; + ASSERT_EQ(true, result->Next(&ptr, &length)); + ASSERT_EQ(30, length); + for (int32_t i = 0; i < 10; ++i) { + for (int32_t j = 0; j < 3; ++j) { + EXPECT_EQ(j, static_cast(ptr)[i * 3 + j]); + } + } +} + TEST_F(DecompressionTest, testInflateSequence) { const unsigned char buffer[] = {0xe, 0x0, 0x0, 0x63, 0x60, 0x64, 0x62, 0xc0, 0x8d, 0x0, 0xe, 0x0, 0x0, 0x63, @@ -559,7 +578,7 @@ TEST_F(DecompressionTest, testInflateSequence) { std::unique_ptr result = createTestDecompressor( CompressionKind_ZLIB, std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer), 3)), + new SeekableArrayInputStream(buffer, std::size(buffer), 3)), 1000); const void* ptr; int32_t length; @@ -594,7 +613,7 @@ TEST_F(DecompressionTest, testSkipZlib) { std::unique_ptr result = createTestDecompressor( CompressionKind_ZLIB, std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer), 5)), + new SeekableArrayInputStream(buffer, std::size(buffer), 5)), 5); const void* ptr; int32_t length; @@ -906,7 +925,7 @@ class TestSeek : public ::testing::Test { size_t arr[][2]{{0, 0}, {0, seekPos}, {offset1, seekPos}, {offset1, 0}}; char* input[]{input1, input1, input2, input2}; - for (size_t i = 0; i < VELOX_ARRAY_SIZE(arr); ++i) { + for (size_t i = 0; i < std::size(arr); ++i) { auto pos = arr[i]; std::vector list{pos[0], pos[1]}; PositionProvider pp(list); @@ -1046,8 +1065,8 @@ class TestingSeekableInputStream : public SeekableInputStream { return true; } - google::protobuf::int64 ByteCount() const override { - return position_; + int64_t ByteCount() const override { + return static_cast(position_); } void seekToPosition(PositionProvider& position) override { @@ -1097,7 +1116,7 @@ TEST_F(TestSeek, uncompressedLarge) { entry.getCompressed()[i] = static_cast(i); } written += runSize + kHeaderSize; - data.insert(data.end(), entry.data().begin(), entry.data().end()); + data.insert(data.end(), entry.data().cbegin(), entry.data().cend()); } auto stream = createTestDecompressor( CompressionKind_SNAPPY, diff --git a/velox/dwio/dwrf/test/DecryptionTests.cpp b/velox/dwio/dwrf/test/DecryptionTests.cpp index 6a19a41c484..7cff7a4acb2 100644 --- a/velox/dwio/dwrf/test/DecryptionTests.cpp +++ b/velox/dwio/dwrf/test/DecryptionTests.cpp @@ -32,7 +32,8 @@ TEST(Decryption, NotEncrypted) { HiveTypeParser parser; auto type = parser.parse("struct"); proto::Footer footer; - ProtoUtils::writeType(*type, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*type, footerWrapper); TestDecrypterFactory factory; auto handler = DecryptionHandler::create(footer, &factory); ASSERT_FALSE(handler->isEncrypted()); @@ -42,7 +43,8 @@ TEST(Decryption, NoKeyProvider) { HiveTypeParser parser; auto type = parser.parse("struct"); proto::Footer footer; - ProtoUtils::writeType(*type, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*type, footerWrapper); footer.mutable_encryption(); TestDecrypterFactory factory; ASSERT_THROW( @@ -53,7 +55,8 @@ TEST(Decryption, EmptyGroup) { HiveTypeParser parser; auto type = parser.parse("struct"); proto::Footer footer; - ProtoUtils::writeType(*type, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*type, footerWrapper); auto enc = footer.mutable_encryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); TestDecrypterFactory factory; @@ -65,7 +68,8 @@ TEST(Decryption, EmptyNodes) { HiveTypeParser parser; auto type = parser.parse("struct"); proto::Footer footer; - ProtoUtils::writeType(*type, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*type, footerWrapper); auto enc = footer.mutable_encryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); auto group = enc->add_encryptiongroups(); @@ -79,7 +83,8 @@ TEST(Decryption, StatsMismatch) { HiveTypeParser parser; auto type = parser.parse("struct"); proto::Footer footer; - ProtoUtils::writeType(*type, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*type, footerWrapper); auto enc = footer.mutable_encryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); auto group = enc->add_encryptiongroups(); @@ -96,7 +101,8 @@ TEST(Decryption, KeyExistenceMismatch) { HiveTypeParser parser; auto type = parser.parse("struct"); proto::Footer footer; - ProtoUtils::writeType(*type, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*type, footerWrapper); auto enc = footer.mutable_encryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); for (size_t i = 0; i < 2; ++i) { @@ -116,7 +122,8 @@ TEST(Decryption, ReuseStripeKey) { HiveTypeParser parser; auto type = parser.parse("struct"); proto::Footer footer; - ProtoUtils::writeType(*type, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*type, footerWrapper); auto enc = footer.mutable_encryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); auto group = enc->add_encryptiongroups(); @@ -135,7 +142,8 @@ TEST(Decryption, StripeKeyMismatch) { HiveTypeParser parser; auto type = parser.parse("struct"); proto::Footer footer; - ProtoUtils::writeType(*type, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*type, footerWrapper); auto enc = footer.mutable_encryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); auto group = enc->add_encryptiongroups(); @@ -153,7 +161,8 @@ TEST(Decryption, Basic) { HiveTypeParser parser; auto type = parser.parse("struct"); proto::Footer footer; - ProtoUtils::writeType(*type, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*type, footerWrapper); auto enc = footer.mutable_encryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); for (auto i = 0; i < 5; ++i) { @@ -183,7 +192,8 @@ TEST(Decryption, NestedType) { auto type = parser.parse( "struct>,c:struct,d:array>"); proto::Footer footer; - ProtoUtils::writeType(*type, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*type, footerWrapper); auto enc = footer.mutable_encryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); @@ -222,7 +232,8 @@ TEST(Decryption, RootNode) { HiveTypeParser parser; auto type = parser.parse("struct"); proto::Footer footer; - ProtoUtils::writeType(*type, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*type, footerWrapper); auto enc = footer.mutable_encryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); auto group = enc->add_encryptiongroups(); @@ -238,7 +249,8 @@ TEST(Decryption, GroupOverlap) { HiveTypeParser parser; auto type = parser.parse("struct>"); proto::Footer footer; - ProtoUtils::writeType(*type, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*type, footerWrapper); auto enc = footer.mutable_encryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); diff --git a/velox/dwio/dwrf/test/DirectBufferedInputTest.cpp b/velox/dwio/dwrf/test/DirectBufferedInputTest.cpp index e3ddd5b457d..c7d74c333a6 100644 --- a/velox/dwio/dwrf/test/DirectBufferedInputTest.cpp +++ b/velox/dwio/dwrf/test/DirectBufferedInputTest.cpp @@ -52,12 +52,15 @@ class DirectBufferedInputTest : public testing::Test { void SetUp() override { executor_ = std::make_unique(10, 10); - ioStats_ = std::make_shared(); + ioStatistics_ = std::make_shared(); + metadataIoStats_ = std::make_shared(); fileIoStats_ = std::make_shared(); - fsStats_ = std::make_shared(); + ioStats_ = std::make_shared(); tracker_ = std::make_shared("", nullptr, kLoadQuantum); - file_ = std::make_shared(11, 100 << 20, fsStats_); + file_ = std::make_shared(11, 100 << 20, ioStats_); opts_ = std::make_unique(pool_.get()); + opts_->setDataIoStats(ioStatistics_); + opts_->setMetadataIoStats(metadataIoStats_); opts_->setLoadQuantum(kLoadQuantum); } @@ -66,16 +69,15 @@ class DirectBufferedInputTest : public testing::Test { } std::unique_ptr makeInput( - const std::shared_ptr& - fsStats) { + const std::shared_ptr& ioStats) { return std::make_unique( file_, dwio::common::MetricsLog::voidLog(), StringIdLease{}, tracker_, StringIdLease{}, - ioStats_, - fsStats, + ioStatistics_, + ioStats, executor_.get(), *opts_); } @@ -85,10 +87,9 @@ class DirectBufferedInputTest : public testing::Test { void testLoads( std::vector regions, int32_t numIos, - const std::shared_ptr& - fsStats) { + const std::shared_ptr& ioStats) { auto previous = file_->numIos(); - auto input = makeInput(fsStats); + auto input = makeInput(ioStats); std::vector> streams; for (auto i = 0; i < regions.size(); ++i) { if (regions[i].length > 0) { @@ -135,9 +136,10 @@ class DirectBufferedInputTest : public testing::Test { std::unique_ptr opts_; std::shared_ptr file_; std::shared_ptr tracker_; - std::shared_ptr ioStats_; + std::shared_ptr ioStatistics_; + std::shared_ptr metadataIoStats_; std::shared_ptr fileIoStats_; - std::shared_ptr fsStats_; + std::shared_ptr ioStats_; std::unique_ptr executor_; std::shared_ptr pool_{ memory::memoryManager()->addLeafPool()}; @@ -154,7 +156,7 @@ TEST_F(DirectBufferedInputTest, basic) { {7004000, 2000000}, {20000000, 10000000}}, 4, - fsStats_); + ioStats_); // All but the last coalesce into one , the last is read in 2 parts. The // columns are now dense and coalesce goes up to 128MB if gaps are small @@ -166,55 +168,54 @@ TEST_F(DirectBufferedInputTest, basic) { {7004000, 2000000}, {20000000, 10000000}}, 3, - fsStats_); + ioStats_); // Mark the first 4 ranges as densely accessed. makeDense(4); // The first and first part of second coalesce. - testLoads({{100, 100}, {1000, 10000000}}, 2, fsStats_); + testLoads({{100, 100}, {1000, 10000000}}, 2, ioStats_); // The first is read in two parts, the tail of the first does not coalesce // with the second. - testLoads({{1000, 10000000}, {10001000, 1000}}, 3, fsStats_); + testLoads({{1000, 10000000}, {10001000, 1000}}, 3, ioStats_); // One large standalone read in 2 parts. - testLoads({{1000, 10000000}}, 2, fsStats_); + testLoads({{1000, 10000000}}, 2, ioStats_); // Small standalone read in 1 part. - testLoads({{100, 100}}, 1, fsStats_); + testLoads({{100, 100}}, 1, ioStats_); // Two small far apart - testLoads({{100, 100}, {1000000, 100}}, 2, fsStats_); + testLoads({{100, 100}, {1000000, 100}}, 2, ioStats_); // The two coalesce because the first fits within load quantum + max coalesce // distance. - testLoads({{1000, 8500000}, {8510000, 1000000}}, 1, fsStats_); + testLoads({{1000, 8500000}, {8510000, 1000000}}, 1, ioStats_); // The two coalesce because the first fits within load quantum + max coalesce // distance. The tail of the second does not coalesce. - testLoads({{1000, 8500000}, {8510000, 8400000}}, 2, fsStats_); + testLoads({{1000, 8500000}, {8510000, 8400000}}, 2, ioStats_); // The first reads in 2 parts and does not coalesce to the second, which reads // in one part. - testLoads({{1000, 9000000}, {9010000, 1000000}}, 3, fsStats_); + testLoads({{1000, 9000000}, {9010000, 1000000}}, 3, ioStats_); } TEST_F(DirectBufferedInputTest, noRedownloadCoalescedPrefetch) { - testLoads({{100, 100}, {201, 1, false}, {202, 100}}, 1, fsStats_); - testLoads({{100, 100}, {201, 1, true}, {202, 100}}, 1, fsStats_); + testLoads({{100, 100}, {201, 1, false}, {202, 100}}, 1, ioStats_); + testLoads({{100, 100}, {201, 1, true}, {202, 100}}, 1, ioStats_); } TEST_F(DirectBufferedInputTest, coalesedPrefetchOverlap) { testLoads( - {{100, 100}, {201, 1, false}, {201, 2, false}, {203, 100}}, 2, fsStats_); + {{100, 100}, {201, 1, false}, {201, 2, false}, {203, 100}}, 2, ioStats_); testLoads( - {{100, 100}, {201, 1, true}, {201, 2, true}, {203, 100}}, 2, fsStats_); + {{100, 100}, {201, 1, true}, {201, 2, true}, {203, 100}}, 2, ioStats_); } TEST_F(DirectBufferedInputTest, ioStatsLifeTimeTest) { for (size_t i = 0; i < 10; i++) { - auto stats = - std::make_shared(); + auto stats = std::make_shared(); // Induce a tiny sleep so that we're more likely to destruct the thread as // we're trying to bump ioStats inside the test file std::thread t([s = stats]() mutable { diff --git a/velox/dwio/dwrf/test/E2EFilterTest.cpp b/velox/dwio/dwrf/test/E2EFilterTest.cpp index 43b67e91e55..61d9735af26 100644 --- a/velox/dwio/dwrf/test/E2EFilterTest.cpp +++ b/velox/dwio/dwrf/test/E2EFilterTest.cpp @@ -63,7 +63,8 @@ class E2EFilterTest : public E2EFilterTestBase { void writeToMemory( const TypePtr& type, const std::vector& batches, - bool forRowGroupSkip = false) override { + bool forRowGroupSkip, + const std::vector& /*indexColumns*/ = {}) override { auto options = createWriterOptions(type); int32_t flushCounter = 0; // If we test row group skip, we have all the data in one stripe. For @@ -241,6 +242,62 @@ TEST_F(E2EFilterTest, floatAndDouble) { false); } +TEST_F(E2EFilterTest, DISABLED_shortDecimal) { + // ORC write functionality is not yet supported. Enable this test once it + // becomes available and set the file format to ORC at that time. + // options.format = DwrfFormat::kOrc; + const std::unordered_map types = { + {"shortdecimal_val:decimal(8, 5)", DECIMAL(8, 5)}, + {"shortdecimal_val:decimal(10, 5)", DECIMAL(10, 5)}, + {"shortdecimal_val:decimal(17, 5)", DECIMAL(17, 5)}}; + + for (const auto& pair : types) { + testWithTypes( + pair.first, + [&]() { + makeIntDistribution( + "shortdecimal_val", + 10, // min + 100, // max + 22, // repeats + 19, // rareFrequency + -999, // rareMin + 30000, // rareMax + true); + }, + false, + {"shortdecimal_val"}, + 20); + } +} + +TEST_F(E2EFilterTest, DISABLED_longDecimal) { + // ORC write functionality is not yet supported. Enable this test once it + // becomes available and set the file format to ORC at that time. + // options.format = DwrfFormat::kOrc; + const std::unordered_map types = { + {"longdecimal_val:decimal(30, 10)", DECIMAL(30, 10)}, + {"longdecimal_val:decimal(37, 15)", DECIMAL(37, 15)}}; + for (const auto& pair : types) { + testWithTypes( + pair.first, + [&]() { + makeIntDistribution( + "longdecimal_val", + 10, // min + 100, // max + 22, // repeats + 19, // rareFrequency + -999, // rareMin + 30000, // rareMax + true); + }, + false, + {"longdecimal_val"}, + 20); + } +} + TEST_F(E2EFilterTest, stringDirect) { testutil::TestValue::enable(); bool coverage[2][2]{}; @@ -446,6 +503,32 @@ TEST_F(E2EFilterTest, mutationCornerCases) { testMutationCornerCases(); } +// Verify processedStrides counts actual strides read, not stripes loaded. +// With multi-stride data, processedStrides + skippedStrides must equal the +// total number of strides across all stripes. +TEST_F(E2EFilterTest, processedStridesCount) { + rowType_ = test::DataSetBuilder::makeRowType("long_val:bigint", false); + filterGenerator_ = std::make_unique(rowType_, seed_); + auto customize = [&]() {}; + auto batches = makeDataset(customize, true, false); + writeToMemory(rowType_, batches, true); + std::vector filterable = {"long_val"}; + testRowGroupSkip(batches, filterable); + + const auto totalRows = kBatchCount * kBatchSize; + const auto totalStrides = (totalRows + kRowsInGroup - 1) / kRowsInGroup; + // With 4 batches x 25,000 rows in one stripe and kRowsInGroup=10,000, + // totalStrides is 10. The filter must skip some and process the rest. + // processedStrides > 1 proves we count per stride, not per stripe + // (the old bug gave processedStrides=1 for a single stripe). + EXPECT_EQ(10, totalStrides); + EXPECT_GT(runtimeStats_.skippedStrides, 0); + EXPECT_GT(runtimeStats_.processedStrides, 1); + EXPECT_EQ( + totalStrides, + runtimeStats_.skippedStrides + runtimeStats_.processedStrides); +} + // Define main so that gflags get processed. int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); diff --git a/velox/dwio/dwrf/test/E2EReaderTest.cpp b/velox/dwio/dwrf/test/E2EReaderTest.cpp index 6b696392910..05aaa4e84ef 100644 --- a/velox/dwio/dwrf/test/E2EReaderTest.cpp +++ b/velox/dwio/dwrf/test/E2EReaderTest.cpp @@ -20,6 +20,7 @@ #include "folly/Random.h" #include "folly/String.h" #include "velox/common/file/File.h" +#include "velox/common/io/IoStatistics.h" #include "velox/common/memory/Memory.h" #include "velox/dwio/common/FileSink.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" @@ -73,11 +74,11 @@ class ValueTypes { } auto begin() const { - return values_.begin(); + return values_.cbegin(); } auto end() const { - return values_.end(); + return values_.cend(); } const std::shared_ptr& decodingExecutor() const { @@ -103,6 +104,11 @@ class E2EReaderTest : public testing::TestWithParam { static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); } + + std::shared_ptr dataIoStats_{ + std::make_shared()}; + std::shared_ptr metadataIoStats_{ + std::make_shared()}; }; } // namespace @@ -174,7 +180,9 @@ TEST_P(E2EReaderTest, SharedDictionaryFlatmapReadAsStruct) { writer->close(); writer.reset(); - dwio::common::ReaderOptions readerOpts{pool.get()}; + dwio::common::ReaderOptions readerOpts(pool.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); auto bufferedInput = std::make_unique( std::make_shared(path), *pool); auto reader = DwrfReader::create(std::move(bufferedInput), readerOpts); @@ -244,127 +252,159 @@ TEST_P(E2EReaderTest, SharedDictionaryFlatmapReadAsStruct) { INSTANTIATE_TEST_SUITE_P( SingleTypesSerialMap, E2EReaderTest, - ValuesIn(std::vector{ - ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"tinyint"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"smallint"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"integer"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"bigint"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"string"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"array"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"array"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"array"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"array"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"array"})})); + ValuesIn( + std::vector{ + ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"tinyint"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"smallint"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"integer"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"bigint"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"string"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"array"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"array"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"array"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"array"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"array"})})); INSTANTIATE_TEST_SUITE_P( SingleTypesSerialStruct, E2EReaderTest, - ValuesIn(std::vector{ - ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"tinyint"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"smallint"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"integer"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"bigint"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"string"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"array"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"array"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"array"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"array"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"array"})})); + ValuesIn( + std::vector{ + ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"tinyint"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"smallint"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"integer"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"bigint"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"string"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"array"}), + ValueTypes( + Decoding::SERIAL, + FlatMapAs::STRUCT, + {"array"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"array"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"array"}), + ValueTypes( + Decoding::SERIAL, + FlatMapAs::STRUCT, + {"array"})})); INSTANTIATE_TEST_SUITE_P( AllTypesSerialMap, E2EReaderTest, - ValuesIn(std::vector{ValueTypes( - Decoding::SERIAL, - FlatMapAs::MAP, - {"tinyint", - "smallint", - "integer", - "bigint", - "string", - "array", - "array", - "array", - "array", - "array"})})); + ValuesIn( + std::vector{ValueTypes( + Decoding::SERIAL, + FlatMapAs::MAP, + {"tinyint", + "smallint", + "integer", + "bigint", + "string", + "array", + "array", + "array", + "array", + "array"})})); INSTANTIATE_TEST_SUITE_P( AllTypesSerialStruct, E2EReaderTest, - ValuesIn(std::vector{ValueTypes( - Decoding::SERIAL, - FlatMapAs::STRUCT, - {"tinyint", - "smallint", - "integer", - "bigint", - "string", - "array", - "array", - "array", - "array", - "array"})})); + ValuesIn( + std::vector{ValueTypes( + Decoding::SERIAL, + FlatMapAs::STRUCT, + {"tinyint", + "smallint", + "integer", + "bigint", + "string", + "array", + "array", + "array", + "array", + "array"})})); INSTANTIATE_TEST_SUITE_P( SingleTypesParallelMap, E2EReaderTest, - ValuesIn(std::vector{ - ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"tinyint"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"smallint"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"integer"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"bigint"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"string"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"array"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"array"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"array"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"array"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"array"})})); + ValuesIn( + std::vector{ + ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"tinyint"}), + ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"smallint"}), + ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"integer"}), + ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"bigint"}), + ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"string"}), + ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"array"}), + ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"array"}), + ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"array"}), + ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"array"}), + ValueTypes( + Decoding::PARALLEL, + FlatMapAs::MAP, + {"array"})})); INSTANTIATE_TEST_SUITE_P( SingleTypesParallelStruct, E2EReaderTest, - ValuesIn(std::vector{ - ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"tinyint"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"smallint"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"integer"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"bigint"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"string"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"array"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"array"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"array"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"array"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"array"})})); + ValuesIn( + std::vector{ + ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"tinyint"}), + ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"smallint"}), + ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"integer"}), + ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"bigint"}), + ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"string"}), + ValueTypes( + Decoding::PARALLEL, + FlatMapAs::STRUCT, + {"array"}), + ValueTypes( + Decoding::PARALLEL, + FlatMapAs::STRUCT, + {"array"}), + ValueTypes( + Decoding::PARALLEL, + FlatMapAs::STRUCT, + {"array"}), + ValueTypes( + Decoding::PARALLEL, + FlatMapAs::STRUCT, + {"array"}), + ValueTypes( + Decoding::PARALLEL, + FlatMapAs::STRUCT, + {"array"})})); INSTANTIATE_TEST_SUITE_P( AllTypesParallelMap, E2EReaderTest, - ValuesIn(std::vector{ValueTypes( - Decoding::PARALLEL, - FlatMapAs::MAP, - {"tinyint", - "smallint", - "integer", - "bigint", - "string", - "array", - "array", - "array", - "array", - "array"})})); + ValuesIn( + std::vector{ValueTypes( + Decoding::PARALLEL, + FlatMapAs::MAP, + {"tinyint", + "smallint", + "integer", + "bigint", + "string", + "array", + "array", + "array", + "array", + "array"})})); INSTANTIATE_TEST_SUITE_P( AllTypesParallelStruct, E2EReaderTest, - ValuesIn(std::vector{ValueTypes( - Decoding::PARALLEL, - FlatMapAs::STRUCT, - {"tinyint", - "smallint", - "integer", - "bigint", - "string", - "array", - "array", - "array", - "array", - "array"})})); + ValuesIn( + std::vector{ValueTypes( + Decoding::PARALLEL, + FlatMapAs::STRUCT, + {"tinyint", + "smallint", + "integer", + "bigint", + "string", + "array", + "array", + "array", + "array", + "array"})})); diff --git a/velox/dwio/dwrf/test/E2EWriterTest.cpp b/velox/dwio/dwrf/test/E2EWriterTest.cpp index 321dcbf4802..6ea0ca99849 100644 --- a/velox/dwio/dwrf/test/E2EWriterTest.cpp +++ b/velox/dwio/dwrf/test/E2EWriterTest.cpp @@ -18,6 +18,7 @@ #include #include "velox/common/base/SpillConfig.h" #include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/io/IoStatistics.h" #include "velox/common/memory/tests/SharedArbitratorTestUtil.h" #include "velox/common/testutil/TestValue.h" #include "velox/dwio/common/Options.h" @@ -30,6 +31,7 @@ #include "velox/dwio/dwrf/reader/DwrfReader.h" #include "velox/dwio/dwrf/test/OrcTest.h" #include "velox/dwio/dwrf/test/utils/E2EWriterTestUtil.h" +#include "velox/dwio/dwrf/writer/StatisticsBuilder.h" #include "velox/type/fbhive/HiveTypeParser.h" #include "velox/vector/fuzzer/VectorFuzzer.h" #include "velox/vector/tests/utils/VectorMaker.h" @@ -61,6 +63,11 @@ class E2EWriterTest : public testing::Test { leafPool_ = rootPool_->addLeafChild("leaf"); } + std::shared_ptr dataIoStats_ = + std::make_shared(); + std::shared_ptr metadataIoStats_ = + std::make_shared(); + static std::unique_ptr createReader( const MemorySink& sink, const dwio::common::ReaderOptions& opts) { @@ -103,7 +110,9 @@ class E2EWriterTest : public testing::Test { writer.close(); - dwio::common::ReaderOptions readerOpts{leafPool_.get()}; + dwio::common::ReaderOptions readerOpts(leafPool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); RowReaderOptions rowReaderOpts; auto reader = createReader(*sinkPtr, readerOpts); auto rowReader = reader->createRowReader(rowReaderOpts); @@ -165,7 +174,9 @@ class E2EWriterTest : public testing::Test { writer.close(); - dwio::common::ReaderOptions readerOpts{leafPool_.get()}; + dwio::common::ReaderOptions readerOpts(leafPool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); RowReaderOptions rowReaderOpts; auto reader = createReader(*sinkPtr, readerOpts); auto rowReader = reader->createRowReader(rowReaderOpts); @@ -205,7 +216,7 @@ class E2EWriterTest : public testing::Test { dwrf::EncodingKey seqEk(valueTypeId, sequence); const auto& keyInfo = stripeStreams.getEncoding(seqEk).key(); - auto key = dwrf::constructKey(keyInfo); + auto key = dwrf::MapStatisticsBuilder::constructKey(keyInfo); sequenceToKey.emplace(sequence, key); }); @@ -235,7 +246,8 @@ class E2EWriterTest : public testing::Test { const auto& entry = stats.mapStatistics().stats(i); ASSERT_TRUE(entry.stats().has_size()); EXPECT_EQ( - featureStreamSizes.at(dwrf::constructKey(entry.key())), + featureStreamSizes.at( + dwrf::MapStatisticsBuilder::constructKey(entry.key())), entry.stats().size()); } } @@ -261,7 +273,8 @@ class E2EWriterTest : public testing::Test { 0, 0, writerFlushThresholdSize, - "none"); + "none", + 0); } std::shared_ptr rootPool_; @@ -565,7 +578,9 @@ TEST_F(E2EWriterTest, PresentStreamIsSuppressedOnFlatMap) { config, dwrf::E2EWriterTestUtil::simpleFlushPolicyFactory(true)); - dwio::common::ReaderOptions readerOpts{leafPool_.get()}; + dwio::common::ReaderOptions readerOpts(leafPool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); RowReaderOptions rowReaderOpts; auto reader = createReader(*sinkPtr, readerOpts); auto rowReader = reader->createRowReader(rowReaderOpts); @@ -937,7 +952,9 @@ TEST_F(E2EWriterTest, PartialStride) { writer.write(batch); writer.close(); - dwio::common::ReaderOptions readerOpts{leafPool_.get()}; + dwio::common::ReaderOptions readerOpts(leafPool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); RowReaderOptions rowReaderOpts; auto reader = createReader(*sinkPtr, readerOpts); ASSERT_EQ( @@ -1137,7 +1154,9 @@ class E2EEncryptionTest : public E2EWriterTest { writer_->close(); // read it back for compare - dwio::common::ReaderOptions readerOpts{leafPool_.get()}; + dwio::common::ReaderOptions readerOpts(leafPool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); readerOpts.setDecrypterFactory(decrypterFactory); return createReader(*sink_, readerOpts); } @@ -2140,9 +2159,10 @@ DEBUG_ONLY_TEST_F(E2EWriterTest, memoryReclaimThreshold) { } const std::vector writerFlushThresholdSizes = {0, 1L << 30}; for (uint64_t writerFlushThresholdSize : writerFlushThresholdSizes) { - SCOPED_TRACE(fmt::format( - "writerFlushThresholdSize {}", - succinctBytes(writerFlushThresholdSize))); + SCOPED_TRACE( + fmt::format( + "writerFlushThresholdSize {}", + succinctBytes(writerFlushThresholdSize))); const common::SpillConfig spillConfig = getSpillConfig(10, 20, writerFlushThresholdSize); diff --git a/velox/dwio/dwrf/test/EncodingManagerTests.cpp b/velox/dwio/dwrf/test/EncodingManagerTests.cpp index 681e0908157..e53c276e29c 100644 --- a/velox/dwio/dwrf/test/EncodingManagerTests.cpp +++ b/velox/dwio/dwrf/test/EncodingManagerTests.cpp @@ -34,10 +34,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, -1, - footer.encoding().begin(), - footer.encoding().end()}; - EXPECT_EQ(footer.encoding().begin(), iter.current_); - EXPECT_EQ(footer.encoding().end(), iter.currentEnd_); + footer.encoding().cbegin(), + footer.encoding().cend()}; + EXPECT_EQ(footer.encoding().cbegin(), iter.current_); + EXPECT_EQ(footer.encoding().cend(), iter.currentEnd_); } { // A valid end iterator. @@ -45,10 +45,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, -1, - footer.encoding().end(), - footer.encoding().end()}; - EXPECT_EQ(footer.encoding().end(), iter.current_); - EXPECT_EQ(footer.encoding().end(), iter.currentEnd_); + footer.encoding().cend(), + footer.encoding().cend()}; + EXPECT_EQ(footer.encoding().cend(), iter.current_); + EXPECT_EQ(footer.encoding().cend(), iter.currentEnd_); } footer.add_encoding(); // footer [e] @@ -58,10 +58,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, -1, - footer.encoding().begin(), - footer.encoding().end()}; - EXPECT_EQ(footer.encoding().begin(), iter.current_); - EXPECT_EQ(footer.encoding().end(), iter.currentEnd_); + footer.encoding().cbegin(), + footer.encoding().cend()}; + EXPECT_EQ(footer.encoding().cbegin(), iter.current_); + EXPECT_EQ(footer.encoding().cend(), iter.currentEnd_); } { // A valid end iterator. @@ -69,10 +69,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, -1, - footer.encoding().end(), - footer.encoding().end()}; - EXPECT_EQ(footer.encoding().end(), iter.current_); - EXPECT_EQ(footer.encoding().end(), iter.currentEnd_); + footer.encoding().cend(), + footer.encoding().cend()}; + EXPECT_EQ(footer.encoding().cend(), iter.current_); + EXPECT_EQ(footer.encoding().cend(), iter.currentEnd_); } proto::StripeEncryptionGroup group1; proto::StripeEncryptionGroup group2; @@ -85,10 +85,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, -1, - footer.encoding().begin(), - footer.encoding().end()}; - EXPECT_EQ(footer.encoding().begin(), iter.current_); - EXPECT_EQ(footer.encoding().end(), iter.currentEnd_); + footer.encoding().cbegin(), + footer.encoding().cend()}; + EXPECT_EQ(footer.encoding().cbegin(), iter.current_); + EXPECT_EQ(footer.encoding().cend(), iter.currentEnd_); } { // A valid end iterator. @@ -96,10 +96,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, 1, - encryptionGroups.at(1).encoding().end(), - encryptionGroups.at(1).encoding().end()}; - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.current_); - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.currentEnd_); + encryptionGroups.at(1).encoding().cend(), + encryptionGroups.at(1).encoding().cend()}; + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.current_); + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.currentEnd_); } { // An adjusted end iterator. @@ -107,10 +107,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, -1, - footer.encoding().end(), - footer.encoding().end()}; - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.current_); - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.currentEnd_); + footer.encoding().cend(), + footer.encoding().cend()}; + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.current_); + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.currentEnd_); } encryptionGroups[1].add_encoding(); // footer [e] @@ -120,10 +120,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, -1, - footer.encoding().begin(), - footer.encoding().end()}; - EXPECT_EQ(footer.encoding().begin(), iter.current_); - EXPECT_EQ(footer.encoding().end(), iter.currentEnd_); + footer.encoding().cbegin(), + footer.encoding().cend()}; + EXPECT_EQ(footer.encoding().cbegin(), iter.current_); + EXPECT_EQ(footer.encoding().cend(), iter.currentEnd_); } { // An adjusted iterator. @@ -131,10 +131,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, -1, - footer.encoding().end(), - footer.encoding().end()}; - EXPECT_EQ(encryptionGroups.at(1).encoding().begin(), iter.current_); - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.currentEnd_); + footer.encoding().cend(), + footer.encoding().cend()}; + EXPECT_EQ(encryptionGroups.at(1).encoding().cbegin(), iter.current_); + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.currentEnd_); } { // An adjusted iterator. @@ -142,10 +142,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, 0, - encryptionGroups.at(0).encoding().begin(), - encryptionGroups.at(0).encoding().end()}; - EXPECT_EQ(encryptionGroups.at(1).encoding().begin(), iter.current_); - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.currentEnd_); + encryptionGroups.at(0).encoding().cbegin(), + encryptionGroups.at(0).encoding().cend()}; + EXPECT_EQ(encryptionGroups.at(1).encoding().cbegin(), iter.current_); + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.currentEnd_); } { // A valid end iterator. @@ -153,10 +153,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, 1, - encryptionGroups.at(1).encoding().end(), - encryptionGroups.at(1).encoding().end()}; - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.current_); - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.currentEnd_); + encryptionGroups.at(1).encoding().cend(), + encryptionGroups.at(1).encoding().cend()}; + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.current_); + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.currentEnd_); } footer.Clear(); // footer [] @@ -167,10 +167,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, -1, - footer.encoding().begin(), - footer.encoding().end()}; - EXPECT_EQ(encryptionGroups.at(1).encoding().begin(), iter.current_); - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.currentEnd_); + footer.encoding().cbegin(), + footer.encoding().cend()}; + EXPECT_EQ(encryptionGroups.at(1).encoding().cbegin(), iter.current_); + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.currentEnd_); } { // An adjusted iterator further back. @@ -178,10 +178,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, -1, - footer.encoding().end(), - footer.encoding().end()}; - EXPECT_EQ(encryptionGroups.at(1).encoding().begin(), iter.current_); - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.currentEnd_); + footer.encoding().cend(), + footer.encoding().cend()}; + EXPECT_EQ(encryptionGroups.at(1).encoding().cbegin(), iter.current_); + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.currentEnd_); } encryptionGroups.at(1).Clear(); // footer [] @@ -192,10 +192,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, 0, - encryptionGroups.at(0).encoding().end(), - encryptionGroups.at(0).encoding().end()}; - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.current_); - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.currentEnd_); + encryptionGroups.at(0).encoding().cend(), + encryptionGroups.at(0).encoding().cend()}; + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.current_); + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.currentEnd_); } { // An adjusted end iterator further back. @@ -203,10 +203,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, -1, - footer.encoding().end(), - footer.encoding().end()}; - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.current_); - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.currentEnd_); + footer.encoding().cend(), + footer.encoding().cend()}; + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.current_); + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.currentEnd_); } } @@ -227,15 +227,15 @@ TEST(TestEncodingIter, EncodingIterBeginAndEnd) { footer, encryptionGroups, -1, - footer.encoding().begin(), - footer.encoding().end()}; + footer.encoding().cbegin(), + footer.encoding().cend()}; EXPECT_EQ(begin, EncodingIter::begin(footer, encryptionGroups)); EncodingIter end{ footer, encryptionGroups, 1, - encryptionGroups.at(1).encoding().end(), - encryptionGroups.at(1).encoding().end()}; + encryptionGroups.at(1).encoding().cend(), + encryptionGroups.at(1).encoding().cend()}; EXPECT_EQ(end, EncodingIter::end(footer, encryptionGroups)); } diff --git a/velox/dwio/dwrf/test/FlushPolicyTest.cpp b/velox/dwio/dwrf/test/FlushPolicyTest.cpp index ed10f458aa2..b045326664e 100644 --- a/velox/dwio/dwrf/test/FlushPolicyTest.cpp +++ b/velox/dwio/dwrf/test/FlushPolicyTest.cpp @@ -51,8 +51,9 @@ TEST_F(DefaultFlushPolicyTest, StripeProgressTest) { testCase.stripeSizeThreshold, /*dictionarySizeThreshold=*/0}; EXPECT_EQ( testCase.shouldFlush, - policy.shouldFlush(dwio::common::StripeProgress{ - .stripeSizeEstimate = testCase.stripeSize})); + policy.shouldFlush( + dwio::common::StripeProgress{ + .stripeSizeEstimate = testCase.stripeSize})); } } @@ -115,8 +116,8 @@ TEST_F(DefaultFlushPolicyTest, AdditionalCriteriaTest) { .dictionarySize = 42, .decision = FlushDecision::SKIP}}; for (const auto& testCase : testCases) { - DefaultFlushPolicy policy{ - /*stripeSizeThreshold=*/1000, testCase.dictionarySizeThreshold}; + DefaultFlushPolicy policy{/*stripeSizeThreshold=*/1000, + testCase.dictionarySizeThreshold}; EXPECT_EQ( testCase.decision, policy.shouldFlushDictionary( diff --git a/velox/dwio/dwrf/test/IndexBuilderTests.cpp b/velox/dwio/dwrf/test/IndexBuilderTests.cpp index 66969d00790..ee5914e236a 100644 --- a/velox/dwio/dwrf/test/IndexBuilderTests.cpp +++ b/velox/dwio/dwrf/test/IndexBuilderTests.cpp @@ -25,17 +25,17 @@ namespace facebook::velox::dwrf { class IndexBuilderTest : public testing::Test { protected: - static const proto::RowIndexEntry& getEntry( + static const RowIndexEntryWriteWrapper getEntry( IndexBuilder& builder, size_t index) { - return *builder.getEntry(index); + return builder.getEntry(index); } static std::vector getPositions( IndexBuilder& builder, size_t index) { - auto& positions = builder.getEntry(index)->positions(); - return std::vector{positions.begin(), positions.end()}; + auto& positions = builder.getEntry(index).positions(); + return std::vector{positions.cbegin(), positions.cend()}; } StatisticsBuilderOptions options_{16}; @@ -45,7 +45,7 @@ TEST_F(IndexBuilderTest, Constructor) { IndexBuilder builder{nullptr}; EXPECT_EQ(1, builder.getEntrySize()); // Ensure a clean start. - EXPECT_EQ(0, getEntry(builder, 0).positions_size()); + EXPECT_EQ(0, getEntry(builder, 0).positionsSize()); } TEST_F(IndexBuilderTest, AddEntry) { @@ -65,8 +65,8 @@ TEST_F(IndexBuilderTest, AddEntry) { ASSERT_EQ(51, builder.getEntrySize()); for (size_t i = 0; i != 50; ++i) { // The newly added entries should be empty. - EXPECT_EQ(0, getEntry(builder, i + 1).positions_size()); - EXPECT_TRUE(getEntry(builder, i).has_statistics()); + EXPECT_EQ(0, getEntry(builder, i + 1).positionsSize()); + EXPECT_TRUE(getEntry(builder, i).hasStatistics()); } } @@ -94,9 +94,9 @@ TEST_F(IndexBuilderTest, Backfill) { IndexBuilder builder{nullptr}; StatisticsBuilder sb{options_}; builder.addEntry(sb); - ASSERT_EQ(0, getEntry(builder, 0).positions_size()); + ASSERT_EQ(0, getEntry(builder, 0).positionsSize()); builder.add(0uL); - ASSERT_EQ(0, getEntry(builder, 0).positions_size()); + ASSERT_EQ(0, getEntry(builder, 0).positionsSize()); ASSERT_THAT(getPositions(builder, 1), ElementsAreArray({0uL})); builder.add(42uL, 0); @@ -112,16 +112,16 @@ TEST_F(IndexBuilderTest, Backfill) { builder.addEntry(sb); } for (size_t i = 2; i != 7; ++i) { - ASSERT_EQ(0, getEntry(builder, i).positions_size()); + ASSERT_EQ(0, getEntry(builder, i).positionsSize()); } builder.add(144uL, 4); EXPECT_THAT(getPositions(builder, 0), ElementsAreArray({42uL})); EXPECT_THAT(getPositions(builder, 1), ElementsAreArray({0uL, 0uL, 7uL})); - ASSERT_EQ(0, getEntry(builder, 2).positions_size()); - ASSERT_EQ(0, getEntry(builder, 3).positions_size()); + ASSERT_EQ(0, getEntry(builder, 2).positionsSize()); + ASSERT_EQ(0, getEntry(builder, 3).positionsSize()); EXPECT_THAT(getPositions(builder, 4), ElementsAreArray({144uL})); - ASSERT_EQ(0, getEntry(builder, 5).positions_size()); - ASSERT_EQ(0, getEntry(builder, 6).positions_size()); + ASSERT_EQ(0, getEntry(builder, 5).positionsSize()); + ASSERT_EQ(0, getEntry(builder, 6).positionsSize()); } TEST_F(IndexBuilderTest, RemovePresentStreamPositions) { diff --git a/velox/dwio/dwrf/test/IntEncoderBenchmark.cpp b/velox/dwio/dwrf/test/IntEncoderBenchmark.cpp index 644ddab0c58..e79fa02bf75 100644 --- a/velox/dwio/dwrf/test/IntEncoderBenchmark.cpp +++ b/velox/dwio/dwrf/test/IntEncoderBenchmark.cpp @@ -117,8 +117,9 @@ FOLLY_ALWAYS_INLINE static int32_t findSetBitsNew(uint64_t value) { case 57 ... 63: return 1; } - DWIO_RAISE(folly::sformat( - "Unexpected leading zeros {} for value {}", leadingZeros, value)); + DWIO_RAISE( + folly::sformat( + "Unexpected leading zeros {} for value {}", leadingZeros, value)); } size_t iters = 2000; diff --git a/velox/dwio/dwrf/test/LayoutPlannerTests.cpp b/velox/dwio/dwrf/test/LayoutPlannerTests.cpp index a2de14e8e91..32cf1a5062b 100644 --- a/velox/dwio/dwrf/test/LayoutPlannerTests.cpp +++ b/velox/dwio/dwrf/test/LayoutPlannerTests.cpp @@ -128,12 +128,12 @@ TEST_F(LayoutPlannerTest, Basic) { uint32_t seq, proto::ColumnEncoding_Kind kind, std::optional key = std::nullopt) { - auto& encoding = encodingManager.addEncodingToFooter(node); - encoding.set_node(node); - encoding.set_sequence(seq); - encoding.set_kind(kind); + auto encoding = encodingManager.addEncodingToFooter(node); + encoding.setNode(node); + encoding.setSequence(seq); + encoding.setKind(ColumnEncodingKindWrapper(&kind)); if (key.has_value()) { - encoding.mutable_key()->set_intkey(*key); + encoding.mutableKey()->set_intkey(*key); } }; diff --git a/velox/dwio/dwrf/test/OrcTest.h b/velox/dwio/dwrf/test/OrcTest.h index 3bd87c45f07..6b0b3987ac2 100644 --- a/velox/dwio/dwrf/test/OrcTest.h +++ b/velox/dwio/dwrf/test/OrcTest.h @@ -28,8 +28,6 @@ namespace facebook::velox::dwrf { -#define VELOX_ARRAY_SIZE(array) (sizeof(array) / sizeof(*array)) - using MemoryPool = memory::MemoryPool; inline std::string getExampleFilePath(const std::string& fileName) { diff --git a/velox/dwio/dwrf/test/ReaderBaseTests.cpp b/velox/dwio/dwrf/test/ReaderBaseTests.cpp index 04f0685c252..08eb3897c41 100644 --- a/velox/dwio/dwrf/test/ReaderBaseTests.cpp +++ b/velox/dwio/dwrf/test/ReaderBaseTests.cpp @@ -16,7 +16,10 @@ #include #include + #include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/io/IoStatistics.h" +#include "velox/dwio/common/Arena.h" #include "velox/dwio/common/InputStream.h" #include "velox/dwio/common/encryption/TestProvider.h" #include "velox/dwio/common/exception/Exception.h" @@ -72,14 +75,14 @@ class EncryptedStatsTest : public Test { TestEncrypter encrypter; HiveTypeParser parser; auto type = parser.parse("struct,c:int,d:int>"); - auto footer = - google::protobuf::Arena::CreateMessage(&arena_); + auto footer = ArenaCreate(&arena_); + auto footerWrapper = FooterWriteWrapper(footer); // add empty stats to the file for (size_t i = 0; i < 7; ++i) { - footer->add_statistics()->set_numberofvalues(i); + footerWrapper.addStatistics().setNumberOfValues(i); } - ProtoUtils::writeType(*type, *footer); - auto enc = footer->mutable_encryption(); + ProtoUtils::writeType(*type, footerWrapper); + auto enc = footerWrapper.mutableEncryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); auto group = enc->add_encryptiongroups(); group->add_nodes(1); @@ -107,9 +110,11 @@ class EncryptedStatsTest : public Test { *readerPool_, std::make_unique(readFile, *readerPool_), std::make_unique(std::move(ps)), - footer, + footerWrapper.getDwrfPtr(), nullptr, - std::move(handler)); + std::move(handler), + dataIoStats_, + metadataIoStats_); } void clearKey(uint32_t groupIdx) { @@ -125,6 +130,10 @@ class EncryptedStatsTest : public Test { std::shared_ptr pool_; std::shared_ptr sinkPool_; std::shared_ptr readerPool_; + std::shared_ptr dataIoStats_ = + std::make_shared(); + std::shared_ptr metadataIoStats_ = + std::make_shared(); }; TEST_F(EncryptedStatsTest, statistics) { @@ -213,6 +222,10 @@ std::unique_ptr createCorruptedFileReader( auto readFile = std::make_shared( std::string(sink.data(), sink.size())); facebook::velox::dwio::common::ReaderOptions readerOpts{pool.get()}; + readerOpts.setDataIoStats( + std::make_shared()); + readerOpts.setMetadataIoStats( + std::make_shared()); return std::make_unique( readerOpts, std::make_unique(readFile, *pool)); } diff --git a/velox/dwio/dwrf/test/ReaderTest.cpp b/velox/dwio/dwrf/test/ReaderTest.cpp index 08b3bc87f10..6a34d8575fd 100644 --- a/velox/dwio/dwrf/test/ReaderTest.cpp +++ b/velox/dwio/dwrf/test/ReaderTest.cpp @@ -18,8 +18,13 @@ #include #include "folly/Random.h" #include "folly/executors/CPUThreadPoolExecutor.h" +#include "folly/executors/IOThreadPoolExecutor.h" #include "folly/lang/Assume.h" +#include "folly/synchronization/Baton.h" #include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/testutil/TestValue.h" +#include "velox/connectors/hive/ExtractionUtils.h" +#include "velox/connectors/hive/HiveConnectorUtil.h" #include "velox/dwio/common/ExecutorBarrier.h" #include "velox/dwio/common/FileSink.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" @@ -27,6 +32,7 @@ #include "velox/dwio/dwrf/reader/DwrfReader.h" #include "velox/dwio/dwrf/test/OrcTest.h" #include "velox/dwio/dwrf/test/utils/E2EWriterTestUtil.h" +#include "velox/dwio/dwrf/writer/Writer.h" #include "velox/type/fbhive/HiveTypeParser.h" #include "velox/vector/ComplexVector.h" #include "velox/vector/FlatVector.h" @@ -38,14 +44,16 @@ #include #include +#include "velox/common/io/IoStatistics.h" + +namespace facebook::velox::dwrf { +namespace { + using namespace ::testing; using namespace facebook::velox::dwio::common; using namespace facebook::velox::type::fbhive; -using namespace facebook::velox; -using namespace facebook::velox::dwrf; using namespace facebook::velox::test; -namespace { const std::string& getStructFile() { static const std::string structFile_ = getExampleFilePath("struct.orc"); return structFile_; @@ -74,12 +82,35 @@ const std::shared_ptr& getFlatmapSchema() { return schema_; } +std::vector makeSubfields( + const std::vector& paths) { + std::vector subfields; + subfields.reserve(paths.size()); + for (auto& path : paths) { + subfields.emplace_back(path); + } + return subfields; +} + +folly::F14FastMap> +groupSubfields(const std::vector& subfields) { + folly::F14FastMap> grouped; + for (auto& subfield : subfields) { + auto& name = + static_cast(*subfield.path()[0]) + .name(); + grouped[name].push_back(&subfield); + } + return grouped; +} + class TestReaderP : public testing::TestWithParam, public VectorTestBase { protected: static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + facebook::velox::common::testutil::TestValue::enable(); } folly::Executor* executor() { @@ -96,12 +127,19 @@ class TestReaderP private: std::unique_ptr executor_; + + protected: + std::shared_ptr dataIoStats_{ + std::make_shared()}; + std::shared_ptr metadataIoStats_{ + std::make_shared()}; }; class TestReader : public testing::Test, public VectorTestBase { protected: static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + facebook::velox::common::testutil::TestValue::enable(); } std::vector createBatches( @@ -114,9 +152,12 @@ class TestReader : public testing::Test, public VectorTestBase { } return batches; } -}; -} // namespace + std::shared_ptr dataIoStats_{ + std::make_shared()}; + std::shared_ptr metadataIoStats_{ + std::make_shared()}; +}; TEST_F(TestReader, testWriterVersions) { EXPECT_EQ("original", writerVersionToString(ORIGINAL)); @@ -238,13 +279,17 @@ void verifyFlatMapReading( const int32_t expectedBatchSize[], const int32_t numBatches, bool returnFlatVector, + const std::shared_ptr& dataIoStats, + const std::shared_ptr& metadataIoStats, const std::vector& expectedPrefetchRowSizes = {}, const std::vector& shouldTryPrefetch = {}) { dwio::common::ReaderOptions readerOpts{pool}; + readerOpts.setDataIoStats(dataIoStats); + readerOpts.setMetadataIoStats(metadataIoStats); /* If an extra sanity check is desired you can uncomment the 2 below lines and * re-run */ - // readerOpts.setFooterEstimatedSize(257); + // readerOpts.setFooterSpeculativeIoSize(257); // readerOpts.setFilePreloadThreshold(0); RowReaderOptions rowReaderOpts; @@ -354,6 +399,11 @@ class TestFlatMapReader : public TestWithParam, public VectorTestBase { static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); } + + std::shared_ptr dataIoStats_{ + std::make_shared()}; + std::shared_ptr metadataIoStats_{ + std::make_shared()}; }; TEST_P(TestFlatMapReader, testReadFlatMapEmptyMap) { @@ -361,6 +411,8 @@ TEST_P(TestFlatMapReader, testReadFlatMapEmptyMap) { auto returnFlatVector = GetParam(); dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); RowReaderOptions rowReaderOpts; rowReaderOpts.setReturnFlatVector(returnFlatVector); std::shared_ptr emptyFileType = @@ -391,6 +443,8 @@ TEST_P(TestFlatMapReader, testStringKeyLifeCycle) { VectorPtr batch; dwio::common::ReaderOptions readerOptions{pool()}; + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); { RowReaderOptions rowReaderOptions; @@ -454,7 +508,9 @@ TEST_P(TestFlatMapReader, testReadFlatMapSampleSmallSkips) { seeks.data(), expectedBatchSize.data(), expectedBatchSize.size(), - returnFlatVector); + returnFlatVector, + dataIoStats_, + metadataIoStats_); } TEST_P(TestFlatMapReader, testReadFlatMapSampleSmall) { @@ -469,7 +525,9 @@ TEST_P(TestFlatMapReader, testReadFlatMapSampleSmall) { seeks.data(), expectedBatchSize.data(), expectedBatchSize.size(), - returnFlatVector); + returnFlatVector, + dataIoStats_, + metadataIoStats_); } TEST_P(TestFlatMapReader, testReadFlatMapSampleLarge) { @@ -487,7 +545,9 @@ TEST_P(TestFlatMapReader, testReadFlatMapSampleLarge) { seeks.data(), expectedBatchSize.data(), expectedBatchSize.size(), - returnFlatVector); + returnFlatVector, + dataIoStats_, + metadataIoStats_); } VELOX_INSTANTIATE_TEST_SUITE_P( @@ -502,10 +562,17 @@ class TestFlatMapReaderFlatLayout static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); } + + std::shared_ptr dataIoStats_{ + std::make_shared()}; + std::shared_ptr metadataIoStats_{ + std::make_shared()}; }; TEST_P(TestFlatMapReaderFlatLayout, testCompare) { dwio::common::ReaderOptions readerOptions{pool()}; + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = DwrfReader::create( createFileBufferedInput(getFMSmallFile(), readerOptions.memoryPool()), readerOptions); @@ -539,6 +606,8 @@ TEST_F(TestReader, testReadFlatMapWithKeyFilters) { // batch size is set as 1000 in reading // file has schema: a int, b struct, c float dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); RowReaderOptions rowReaderOpts; // set map key filter for map1 we only need key=1, and map2 only key-1 auto cs = std::make_shared( @@ -592,6 +661,8 @@ TEST_F(TestReader, testReadFlatMapWithKeyRejectList) { // batch size is set as 1000 in reading // file has schema: a int, b struct, c float dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); RowReaderOptions rowReaderOpts; auto cs = std::make_shared( getFlatmapSchema(), std::vector{"map1#[\"!2\",\"!3\"]"}); @@ -647,6 +718,8 @@ TEST_F(TestReader, testStatsCallbackFiredWithFiltering) { }); dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); auto reader = DwrfReader::create( createFileBufferedInput(getFMSmallFile(), readerOpts.memoryPool()), @@ -685,6 +758,8 @@ TEST_F(TestReader, testBlockedIoCallbackFiredBlocking) { rowReaderOpts.setEagerFirstStripeLoad(false); dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); auto reader = DwrfReader::create( createFileBufferedInput(getFMLargeFile(), readerOpts.memoryPool()), @@ -729,6 +804,8 @@ TEST_F(TestReader, DISABLED_testBlockedIoCallbackFiredNonBlocking) { rowReaderOpts.setEagerFirstStripeLoad(false); dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); auto reader = DwrfReader::create( createFileBufferedInput(getFMLargeFile(), readerOpts.memoryPool()), @@ -778,6 +855,8 @@ TEST_F(TestReader, DISABLED_testBlockedIoCallbackFiredWithFirstStripeLoad) { rowReaderOpts.setEagerFirstStripeLoad(true); dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); auto reader = DwrfReader::create( createFileBufferedInput(getFMLargeFile(), readerOpts.memoryPool()), @@ -813,6 +892,8 @@ TEST_F(TestReader, DISABLED_testBlockedIoCallbackFiredWithFirstStripeLoad) { TEST_F(TestReader, testEstimatedSize) { dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); { auto reader = DwrfReader::create( createFileBufferedInput(getFMSmallFile(), readerOpts.memoryPool()), @@ -839,6 +920,64 @@ TEST_F(TestReader, testEstimatedSize) { } } +TEST_F(TestReader, testSubfieldEstimatedSize) { + dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); + std::shared_ptr schema = + std::dynamic_pointer_cast(HiveTypeParser().parse("struct<\ + a:int,\ + b:struct<\ + a:int,\ + b:float,\ + c:string>,\ + c:float>")); + + std::shared_ptr outputType = + std::dynamic_pointer_cast(HiveTypeParser().parse("struct<\ + a:int,\ + b:struct<\ + a:int,\ + b:float,\ + c:string>>")); + // estimation with subfield filtering + auto subfields = makeSubfields({"a", "b.b"}); + folly::F14FastMap> + subfieldsByName = groupSubfields(subfields); + auto scanSpec = velox::connector::hive::makeScanSpec( + outputType, subfieldsByName, {}, {}, schema, {}, {}, {}, true, pool()); + readerOpts.setScanSpec(scanSpec); + + auto reader = DwrfReader::create( + createFileBufferedInput(getStructFile(), readerOpts.memoryPool()), + readerOpts); + RowReaderOptions rowReaderOpts; + rowReaderOpts.setScanSpec(scanSpec); + + auto rowReader = reader->createRowReader(rowReaderOpts); + ASSERT_EQ(rowReader->estimatedRowSize(), 8); + + // estimation with full struct field selection + dwio::common::ReaderOptions readerOpts2{pool()}; + readerOpts2.setDataIoStats(dataIoStats_); + readerOpts2.setMetadataIoStats(metadataIoStats_); + auto subfields2 = makeSubfields({"a", "b"}); + folly::F14FastMap> + subfields2ByName = groupSubfields(subfields2); + auto scanSpec2 = velox::connector::hive::makeScanSpec( + outputType, subfields2ByName, {}, {}, schema, {}, {}, {}, true, pool()); + readerOpts2.setScanSpec(scanSpec2); + + auto reader2 = DwrfReader::create( + createFileBufferedInput(getStructFile(), readerOpts2.memoryPool()), + readerOpts2); + RowReaderOptions rowReaderOpts2; + rowReaderOpts2.setScanSpec(scanSpec2); + + auto rowReader2 = reader2->createRowReader(rowReaderOpts2); + ASSERT_EQ(rowReader2->estimatedRowSize(), 15); +} + TEST_F(TestReader, testStatsCallbackFiredWithoutFiltering) { RowReaderOptions rowReaderOpts; // Don't apply feature projection here @@ -858,6 +997,8 @@ TEST_F(TestReader, testStatsCallbackFiredWithoutFiltering) { }); dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); auto reader = DwrfReader::create( createFileBufferedInput(getFMSmallFile(), readerOpts.memoryPool()), @@ -941,8 +1082,12 @@ void verifyFlatmapStructEncoding( const std::string& filename, const std::vector& keysAsFields, const std::vector& keysToSelect, + const std::shared_ptr& dataIoStats, + const std::shared_ptr& metadataIoStats, size_t batchSize = 1000) { dwio::common::ReaderOptions readerOpts{pool}; + readerOpts.setDataIoStats(dataIoStats); + readerOpts.setMetadataIoStats(metadataIoStats); auto reader = DwrfReader::create( createFileBufferedInput(filename, readerOpts.memoryPool()), readerOpts); @@ -1009,7 +1154,9 @@ TEST_F(TestReader, testFlatmapAsStructSmall) { pool(), getFMSmallFile(), {1, 2, 3, 4, 5, -99999999 /* does not exist */}, - {} /* no key filtering */); + {} /* no key filtering */, + dataIoStats_, + metadataIoStats_); } TEST_F(TestReader, testFlatmapAsStructSmallEmptyInmap) { @@ -1018,6 +1165,8 @@ TEST_F(TestReader, testFlatmapAsStructSmallEmptyInmap) { getFMSmallFile(), {1, 2, 3, 4, 5, -99999999 /* does not exist */}, {} /* no key filtering */, + dataIoStats_, + metadataIoStats_, 2); } @@ -1026,7 +1175,9 @@ TEST_F(TestReader, testFlatmapAsStructLarge) { pool(), getFMSmallFile(), {1, 2, 3, 4, 5, -99999999 /* does not exist */}, - {} /* no key filtering */); + {} /* no key filtering */, + dataIoStats_, + metadataIoStats_); } TEST_F(TestReader, testFlatmapAsStructWithKeyProjection) { @@ -1034,7 +1185,9 @@ TEST_F(TestReader, testFlatmapAsStructWithKeyProjection) { pool(), getFMSmallFile(), {1, 2, 3, 4, 5, -99999999 /* does not exist */}, - {3, 5} /* select only these to read */); + {3, 5} /* select only these to read */, + dataIoStats_, + metadataIoStats_); } TEST_F(TestReader, testFlatmapAsStructRequiringKeyList) { @@ -1049,12 +1202,15 @@ TEST_F(TestReader, testFlatmapAsStructRequiringKeyList) { TEST_F(TestReader, testMismatchSchemaMoreFields) { // file has schema: a int, b struct, c float dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); RowReaderOptions rowReaderOpts; std::shared_ptr requestedType = std::dynamic_pointer_cast(HiveTypeParser().parse( "struct,c:float,d:string>")); - rowReaderOpts.select(std::make_shared( - requestedType, std::vector{1, 2, 3})); + rowReaderOpts.select( + std::make_shared( + requestedType, std::vector{1, 2, 3})); auto reader = DwrfReader::create( createFileBufferedInput(getStructFile(), readerOpts.memoryPool()), readerOpts); @@ -1094,12 +1250,15 @@ TEST_F(TestReader, testMismatchSchemaMoreFields) { TEST_F(TestReader, testMismatchSchemaFewerFields) { // file has schema: a int, b struct, c float dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); RowReaderOptions rowReaderOpts; std::shared_ptr requestedType = std::dynamic_pointer_cast(HiveTypeParser().parse( "struct>")); - rowReaderOpts.select(std::make_shared( - requestedType, std::vector{1})); + rowReaderOpts.select( + std::make_shared( + requestedType, std::vector{1})); auto reader = DwrfReader::create( createFileBufferedInput(getStructFile(), readerOpts.memoryPool()), readerOpts); @@ -1135,13 +1294,16 @@ TEST_F(TestReader, testMismatchSchemaFewerFields) { TEST_F(TestReader, testMismatchSchemaNestedMoreFields) { // file has schema: a int, b struct, c float dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); RowReaderOptions rowReaderOpts; std::shared_ptr requestedType = std::dynamic_pointer_cast(HiveTypeParser().parse( "struct,c:float>")); LOG(INFO) << requestedType->toString(); - rowReaderOpts.select(std::make_shared( - requestedType, std::vector{"b.b", "b.c", "b.d", "c"})); + rowReaderOpts.select( + std::make_shared( + requestedType, std::vector{"b.b", "b.c", "b.d", "c"})); auto reader = DwrfReader::create( createFileBufferedInput(getStructFile(), readerOpts.memoryPool()), readerOpts); @@ -1201,12 +1363,15 @@ TEST_F(TestReader, testMismatchSchemaNestedMoreFields) { TEST_F(TestReader, testMismatchSchemaNestedFewerFields) { // file has schema: a int, b struct, c float dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); RowReaderOptions rowReaderOpts; std::shared_ptr requestedType = std::dynamic_pointer_cast(HiveTypeParser().parse( "struct,c:float>")); - rowReaderOpts.select(std::make_shared( - requestedType, std::vector{"b.b", "c"})); + rowReaderOpts.select( + std::make_shared( + requestedType, std::vector{"b.b", "c"})); auto reader = DwrfReader::create( createFileBufferedInput(getStructFile(), readerOpts.memoryPool()), readerOpts); @@ -1258,12 +1423,15 @@ TEST_F(TestReader, testMismatchSchemaNestedFewerFields) { TEST_F(TestReader, testMismatchSchemaIncompatibleNotSelected) { // file has schema: a int, b struct, c float dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); RowReaderOptions rowReaderOpts; std::shared_ptr requestedType = std::dynamic_pointer_cast(HiveTypeParser().parse( "struct,c:int>")); - rowReaderOpts.select(std::make_shared( - requestedType, std::vector{"b.b"})); + rowReaderOpts.select( + std::make_shared( + requestedType, std::vector{"b.b"})); auto reader = DwrfReader::create( createFileBufferedInput(getStructFile(), readerOpts.memoryPool()), readerOpts); @@ -1339,6 +1507,8 @@ TEST_F(TestReader, testMismatchSchemaIncompatible) { TEST_F(TestReader, fileColumnNamesReadAsLowerCase) { // upper.orc holds one columns (Bool_Val: BOOLEAN, b: BIGINT) dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); readerOpts.setFileColumnNamesReadAsLowerCase(true); auto reader = DwrfReader::create( createFileBufferedInput( @@ -1353,6 +1523,8 @@ TEST_F(TestReader, fileColumnNamesReadAsLowerCaseComplexStruct) { // upper_complex.orc holds type // Cc:struct>>>> dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); readerOpts.setFileColumnNamesReadAsLowerCase(true); auto reader = DwrfReader::create( createFileBufferedInput( @@ -1391,8 +1563,10 @@ TEST_F(TestReader, fileColumnNamesReadAsLowerCaseComplexStruct) { TEST_F(TestReader, TestStripeSizeCallback) { dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); readerOpts.setFilePreloadThreshold(0); - readerOpts.setFooterEstimatedSize(17); + readerOpts.setFooterSpeculativeIoSize(17); RowReaderOptions rowReaderOpts; std::shared_ptr requestedType = std::dynamic_pointer_cast< @@ -1419,8 +1593,10 @@ TEST_F(TestReader, TestStripeSizeCallback) { TEST_F(TestReader, TestStripeSizeCallbackLimitsOneStripe) { dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); readerOpts.setFilePreloadThreshold(0); - readerOpts.setFooterEstimatedSize(17); + readerOpts.setFooterSpeculativeIoSize(17); RowReaderOptions rowReaderOpts; std::shared_ptr requestedType = std::dynamic_pointer_cast< @@ -1448,8 +1624,10 @@ TEST_F(TestReader, TestStripeSizeCallbackLimitsOneStripe) { TEST_F(TestReader, TestStripeSizeCallbackLimitsTwoStripe) { dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); readerOpts.setFilePreloadThreshold(0); - readerOpts.setFooterEstimatedSize(17); + readerOpts.setFooterSpeculativeIoSize(17); RowReaderOptions rowReaderOpts; std::shared_ptr requestedType = std::dynamic_pointer_cast< @@ -1740,6 +1918,8 @@ TEST_F(TestReader, testEmptyFile) { std::make_shared(std::move(data)), *pool()); dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); RowReaderOptions rowReaderOpts; auto rowReader = DwrfReader::create(std::move(input), readerOpts) @@ -1822,7 +2002,9 @@ void testBufferLifeCycle( const std::shared_ptr& config, std::mt19937& rng, size_t batchSize, - bool hasNull) { + bool hasNull, + const std::shared_ptr& dataIoStats, + const std::shared_ptr& metadataIoStats) { std::vector batches; std::function isNullAt = nullptr; if (hasNull) { @@ -1843,6 +2025,8 @@ void testBufferLifeCycle( std::make_shared(std::move(data)), *pool); dwio::common::ReaderOptions readerOpts{pool}; + readerOpts.setDataIoStats(dataIoStats); + readerOpts.setMetadataIoStats(metadataIoStats); RowReaderOptions rowReaderOpts; rowReaderOpts.setReturnFlatVector(true); auto reader = std::make_unique(readerOpts, std::move(input)); @@ -1880,7 +2064,9 @@ void testFlatmapAsMapFieldLifeCycle( const std::shared_ptr& config, std::mt19937& rng, size_t batchSize, - bool hasNull) { + bool hasNull, + const std::shared_ptr& dataIoStats, + const std::shared_ptr& metadataIoStats) { std::vector batches; std::function isNullAt = nullptr; if (hasNull) { @@ -1901,6 +2087,8 @@ void testFlatmapAsMapFieldLifeCycle( std::make_shared(std::move(data)), *pool); dwio::common::ReaderOptions readerOpts{pool}; + readerOpts.setDataIoStats(dataIoStats); + readerOpts.setMetadataIoStats(metadataIoStats); RowReaderOptions rowReaderOpts; rowReaderOpts.setReturnFlatVector(true); auto reader = std::make_unique(readerOpts, std::move(input)); @@ -1997,8 +2185,24 @@ TEST_F(TestReader, testBufferLifeCycle) { std::mt19937 rng{seed}; for (auto i = 0; i < 10; ++i) { - testBufferLifeCycle(pool(), schema, config, rng, batchSize, false); - testBufferLifeCycle(pool(), schema, config, rng, batchSize, true); + testBufferLifeCycle( + pool(), + schema, + config, + rng, + batchSize, + false, + dataIoStats_, + metadataIoStats_); + testBufferLifeCycle( + pool(), + schema, + config, + rng, + batchSize, + true, + dataIoStats_, + metadataIoStats_); } } @@ -2016,8 +2220,24 @@ TEST_F(TestReader, testFlatmapAsMapFieldLifeCycle) { LOG(INFO) << "seed: " << seed; std::mt19937 rng{seed}; - testFlatmapAsMapFieldLifeCycle(pool(), schema, config, rng, batchSize, false); - testFlatmapAsMapFieldLifeCycle(pool(), schema, config, rng, batchSize, true); + testFlatmapAsMapFieldLifeCycle( + pool(), + schema, + config, + rng, + batchSize, + false, + dataIoStats_, + metadataIoStats_); + testFlatmapAsMapFieldLifeCycle( + pool(), + schema, + config, + rng, + batchSize, + true, + dataIoStats_, + metadataIoStats_); } TEST_F(TestReader, testFooterWrapper) { @@ -2081,6 +2301,8 @@ std::pair, std::unique_ptr> createWriterReader( const std::vector& batches, memory::MemoryPool* pool, + const std::shared_ptr& dataIoStats, + const std::shared_ptr& metadataIoStats, const std::shared_ptr& config = std::make_shared(), std::function()> flushPolicy = @@ -2098,6 +2320,8 @@ createWriterReader( auto input = std::make_unique( std::make_shared(std::move(data)), *pool); dwio::common::ReaderOptions readerOpts(pool); + readerOpts.setDataIoStats(dataIoStats); + readerOpts.setMetadataIoStats(metadataIoStats); readerOpts.setFileFormat(FileFormat::DWRF); auto reader = DwrfReader::create(std::move(input), readerOpts); return std::make_pair(std::move(writer), std::move(reader)); @@ -2115,7 +2339,8 @@ TEST_F(TestReader, setRowNumberColumnInfo) { }; auto batches = createBatches(integerValues); auto schema = asRowType(batches[0]->type()); - auto [writer, reader] = createWriterReader(batches, pool()); + auto [writer, reader] = + createWriterReader(batches, pool(), dataIoStats_, metadataIoStats_); auto spec = std::make_shared(""); spec->addAllChildFields(*schema); @@ -2144,7 +2369,8 @@ TEST_F(TestReader, reuseRowNumberColumn) { std::vector> integerValues{{0, 1, 2, 3, 4}}; auto batches = createBatches(integerValues); auto schema = asRowType(batches[0]->type()); - auto [writer, reader] = createWriterReader(batches, pool()); + auto [writer, reader] = + createWriterReader(batches, pool(), dataIoStats_, metadataIoStats_); auto spec = std::make_shared(""); spec->addAllChildFields(*schema); @@ -2210,7 +2436,8 @@ TEST_F(TestReader, explicitRowNumberColumn) { {9, 10, 11, 12, 13, 14, 15}, }; auto batches = createBatches(integerValues); - auto [writer, reader] = createWriterReader(batches, pool()); + auto [writer, reader] = + createWriterReader(batches, pool(), dataIoStats_, metadataIoStats_); auto spec = std::make_shared(""); spec->addField("c0", 0); spec->addField("$row_number", 1) @@ -2247,7 +2474,8 @@ TEST_F(TestReader, failToReuseReaderNulls) { makeRowVector({"c"}, {makeFlatVector(11, folly::identity)}), }); auto schema = asRowType(data->type()); - auto [writer, reader] = createWriterReader({data}, pool()); + auto [writer, reader] = + createWriterReader({data}, pool(), dataIoStats_, metadataIoStats_); auto spec = std::make_shared(""); spec->addAllChildFields(*schema); spec->childByName("c0")->childByName("a")->setFilter( @@ -2269,24 +2497,25 @@ TEST_F(TestReader, failToReuseReaderNulls) { TEST_F(TestReader, readFlatMapsSomeEmpty) { // Test reading a flat map where the key filter means that some maps are // empty. - auto keys = makeFlatVector(std::vector{ - 1, - 2, - 3, - 4, - 5, - 6, // map 1 has more than just the selected keys. - 1, - 2, - 3, // map 2 has only selected keys. - 4, - 5, - 6, // map 3 has no selected keys. - 1, - 2, - 5, - 6 // map 4 has some selected keys. - }); + auto keys = makeFlatVector( + std::vector{ + 1, + 2, + 3, + 4, + 5, + 6, // map 1 has more than just the selected keys. + 1, + 2, + 3, // map 2 has only selected keys. + 4, + 5, + 6, // map 3 has no selected keys. + 1, + 2, + 5, + 6 // map 4 has some selected keys. + }); auto values = makeFlatVector(16, folly::identity); auto maps = makeMapVector(std::vector{0, 6, 9, 12, 16}, keys, values); @@ -2297,7 +2526,8 @@ TEST_F(TestReader, readFlatMapsSomeEmpty) { config->set(dwrf::Config::FLATTEN_MAP, true); config->set(dwrf::Config::MAP_FLAT_COLS, {0}); - auto [writer, reader] = createWriterReader({row}, pool(), config); + auto [writer, reader] = + createWriterReader({row}, pool(), dataIoStats_, metadataIoStats_, config); auto schema = asRowType(row->type()); auto spec = std::make_shared(""); @@ -2363,7 +2593,8 @@ TEST_F(TestReader, readFlatMapsWithNullMaps) { config->set(dwrf::Config::FLATTEN_MAP, true); config->set(dwrf::Config::MAP_FLAT_COLS, {0}); - auto [writer, reader] = createWriterReader({row}, pool(), config); + auto [writer, reader] = + createWriterReader({row}, pool(), dataIoStats_, metadataIoStats_, config); auto schema = asRowType(row->type()); auto spec = std::make_shared(""); @@ -2420,7 +2651,8 @@ TEST_F(TestReader, readFlatMapsAsFlatMaps) { config->set(dwrf::Config::FLATTEN_MAP, true); config->set(dwrf::Config::MAP_FLAT_COLS, {0}); - auto [writer, reader] = createWriterReader({input}, pool(), config); + auto [writer, reader] = createWriterReader( + {input}, pool(), dataIoStats_, metadataIoStats_, config); auto schema = asRowType(input->type()); auto spec = std::make_shared(""); @@ -2440,32 +2672,85 @@ TEST_F(TestReader, readFlatMapsAsFlatMaps) { assertEqualVectors(flatMap, resultMaps); }; - testRoundTrip(makeFlatMapVector({ - {}, - {{1, 1.9}, {2, 2.1}, {0, 3.12}}, - {{127, 0.12}}, - })); - - testRoundTrip(makeFlatMapVector({ - {{"a", "a1"}}, - {{"b", "b1"}}, - {{"c", "c1"}}, - {{"d", "d1"}}, - })); - - testRoundTrip(makeNullableFlatMapVector({ - {{{101, 1}, {102, 2}, {103, 3}}}, - {{{105, 0}, {106, 0}}}, - {std::nullopt}, - {{{101, 11}, {103, 13}, {105, std::nullopt}}}, - {{{101, 1}, {102, 2}, {103, 3}}}, - })); - - testRoundTrip(makeFlatMapVector( - {{{0, 0}, {1, 1}, {2, 2}, {3, 3}}, - {{0, 4}, {1, 5}, {2, 6}, {3, 7}}, - {{0, 8}, {1, 9}, {2, 10}, {3, 11}}, - {{0, 12}, {1, 13}, {2, 14}, {3, 15}}})); + testRoundTrip( + makeFlatMapVector({ + {}, + {{1, 1.9}, {2, 2.1}, {0, 3.12}}, + {{127, 0.12}}, + })); + + testRoundTrip( + makeFlatMapVector({ + {{"a", "a1"}}, + {{"b", "b1"}}, + {{"c", "c1"}}, + {{"d", "d1"}}, + })); + + testRoundTrip( + makeNullableFlatMapVector({ + {{{101, 1}, {102, 2}, {103, 3}}}, + {{{105, 0}, {106, 0}}}, + {std::nullopt}, + {{{101, 11}, {103, 13}, {105, std::nullopt}}}, + {{{101, 1}, {102, 2}, {103, 3}}}, + })); + + testRoundTrip( + makeFlatMapVector( + {{{0, 0}, {1, 1}, {2, 2}, {3, 3}}, + {{0, 4}, {1, 5}, {2, 6}, {3, 7}}, + {{0, 8}, {1, 9}, {2, 10}, {3, 11}}, + {{0, 12}, {1, 13}, {2, 14}, {3, 15}}})); +} + +// Regression test: reading a multi-stripe flatmap file with +// preserveFlatMapsInMemory=true used to crash when stripes had different +// key sets. The ScanSpec accumulated stale children across stripes, causing +// out-of-range access on the reader's children_ vector. +TEST_F(TestReader, readFlatMapMultiStripeDifferentKeys) { + // Stripe 1: keys {0, 1, 2, 3, 4} — 5 keys. + auto stripe1 = makeRowVector({makeMapVector( + {{{0, 1.0f}, {1, 2.0f}, {2, 3.0f}, {3, 4.0f}, {4, 5.0f}}, + {{0, 6.0f}, {1, 7.0f}, {2, 8.0f}, {3, 9.0f}, {4, 10.0f}}})}); + + // Stripe 2: keys {0, 1, 2} — fewer keys than stripe 1. + auto stripe2 = makeRowVector({makeMapVector( + {{{0, 11.0f}, {1, 12.0f}, {2, 13.0f}}, + {{0, 14.0f}, {1, 15.0f}, {2, 16.0f}}, + {{0, 17.0f}, {1, 18.0f}, {2, 19.0f}}})}); + + auto config = std::make_shared(); + config->set(dwrf::Config::FLATTEN_MAP, true); + config->set(dwrf::Config::MAP_FLAT_COLS, {0}); + + // simpleFlushPolicyFactory(true) produces one stripe per batch. + auto [writer, reader] = createWriterReader( + {stripe1, stripe2}, pool(), dataIoStats_, metadataIoStats_, config); + ASSERT_EQ(reader->getNumberOfStripes(), 2); + + auto schema = asRowType(stripe1->type()); + auto spec = std::make_shared(""); + spec->addAllChildFields(*schema); + + RowReaderOptions rowReaderOpts; + rowReaderOpts.setScanSpec(spec); + rowReaderOpts.setPreserveFlatMapsInMemory(true); + + auto rowReader = reader->createRowReader(rowReaderOpts); + VectorPtr batch = BaseVector::create(schema, 0, pool()); + + uint64_t totalRows = 0; + while (rowReader->next(100, batch) > 0) { + auto* rowVec = batch->as(); + // Trigger lazy loading — this is the pattern that used to crash. + for (column_index_t i = 0; i < rowVec->childrenSize(); ++i) { + auto& child = rowVec->childAt(i); + child = BaseVector::loadedVectorShared(child); + } + totalRows += batch->size(); + } + EXPECT_EQ(totalRows, 5); // 2 rows from stripe 1 + 3 rows from stripe 2. } TEST_F(TestReader, readStructWithWholeBatchFiltered) { @@ -2492,7 +2777,8 @@ TEST_F(TestReader, readStructWithWholeBatchFiltered) { std::make_shared(pool(), rowType, nulls, vectorSize, children); auto row = makeRowVector({"c0"}, {c0}); - auto [writer, reader] = createWriterReader({row}, pool()); + auto [writer, reader] = + createWriterReader({row}, pool(), dataIoStats_, metadataIoStats_); auto schema = asRowType(row->type()); auto spec = std::make_shared(""); @@ -2547,6 +2833,8 @@ TEST_F(TestReader, readStringDictionaryAsFlat) { auto [writer, reader] = createWriterReader( {batch}, pool(), + dataIoStats_, + metadataIoStats_, std::make_shared(), // The always true flush policy would disable dictionary encoding at least // for first batch. @@ -2566,9 +2854,10 @@ TEST_F(TestReader, readStringDictionaryAsFlat) { ASSERT_EQ(c0->valueVector()->size(), dictionary.size()); dwio::common::RuntimeStatistics stats; rowReader->updateRuntimeStats(stats); - ASSERT_EQ(stats.columnReaderStatistics.flattenStringDictionaryValues, 0); - spec->childByName("c0")->setFilter(std::make_unique( - std::vector{"aaaaaaaaaaaaaaaaaaaa"}, false)); + ASSERT_EQ(stats.columnReaderStats.flattenStringDictionaryValues, 0); + spec->childByName("c0")->setFilter( + std::make_unique( + std::vector{"aaaaaaaaaaaaaaaaaaaa"}, false)); spec->resetCachedValues(true); rowReader = reader->createRowReader(rowReaderOpts); ASSERT_EQ(rowReader->next(20, actual), 20); @@ -2576,7 +2865,7 @@ TEST_F(TestReader, readStringDictionaryAsFlat) { ASSERT_TRUE(actual->as()->childAt(0)->isFlatEncoding()); stats = {}; rowReader->updateRuntimeStats(stats); - ASSERT_EQ(stats.columnReaderStatistics.flattenStringDictionaryValues, 1); + ASSERT_EQ(stats.columnReaderStats.flattenStringDictionaryValues, 1); } // A primitive subfield is missing in file, and result is not reused. @@ -2587,7 +2876,8 @@ TEST_F(TestReader, missingSubfieldsNoResultReusing) { makeFlatVector(kSize, folly::identity), }), }); - auto [writer, reader] = createWriterReader({batch}, pool()); + auto [writer, reader] = + createWriterReader({batch}, pool(), dataIoStats_, metadataIoStats_); auto schema = ROW({{"c0", ROW({{"c0", BIGINT()}, {"c1", VARCHAR()}})}}); auto spec = std::make_shared(""); spec->addAllChildFields(*schema); @@ -2616,7 +2906,8 @@ TEST_F(TestReader, selectiveStringDirectFastPath) { makeFlatVector(17, [](auto i) { return i != 8; }), makeFlatVector(17, genStr), }); - auto [writer, reader] = createWriterReader({batch}, pool()); + auto [writer, reader] = + createWriterReader({batch}, pool(), dataIoStats_, metadataIoStats_); auto schema = asRowType(batch->type()); auto spec = std::make_shared(""); spec->addAllChildFields(*schema); @@ -2642,7 +2933,8 @@ TEST_F(TestReader, selectiveStringDirect) { makeFlatVector(17, [](auto i) { return i != 15; }), makeFlatVector(17, genStr), }); - auto [writer, reader] = createWriterReader({batch}, pool()); + auto [writer, reader] = + createWriterReader({batch}, pool(), dataIoStats_, metadataIoStats_); auto schema = asRowType(batch->type()); auto spec = std::make_shared(""); spec->addAllChildFields(*schema); @@ -2666,7 +2958,8 @@ TEST_F(TestReader, selectiveFlatMapFastPathAllInlinedStringKeys) { auto config = std::make_shared(); config->set(dwrf::Config::FLATTEN_MAP, true); config->set(dwrf::Config::MAP_FLAT_COLS, {0}); - auto [writer, reader] = createWriterReader({row}, pool(), config); + auto [writer, reader] = + createWriterReader({row}, pool(), dataIoStats_, metadataIoStats_, config); auto schema = asRowType(row->type()); auto spec = std::make_shared(""); spec->addAllChildFields(*schema); @@ -2685,6 +2978,8 @@ TEST_F(TestReader, skipLongString) { std::make_shared(getExampleFilePath("long_string.dwrf")), *pool()); dwio::common::ReaderOptions readerOpts(pool()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); readerOpts.setFileFormat(FileFormat::DWRF); auto reader = DwrfReader::create(std::move(input), readerOpts); auto spec = std::make_shared(""); @@ -2721,3 +3016,820 @@ TEST_F(TestReader, skipLongString) { validate(batch); } } + +TEST_F(TestReader, mapAsStruct) { + auto row = makeRowVector({ + makeMapVector({{{1, 4}, {2, 5}}, {{1, 6}, {3, 7}}}), + }); + auto [writer, reader] = + createWriterReader({row}, pool(), dataIoStats_, metadataIoStats_); + auto outType = ROW({"c0"}, {ROW({"3", "1"}, BIGINT())}); + auto spec = std::make_shared(""); + spec->addAllChildFields(*outType); + spec->childByName("c0")->setFlatMapAsStruct(true); + RowReaderOptions rowReaderOpts; + rowReaderOpts.setScanSpec(spec); + auto rowReader = reader->createRowReader(rowReaderOpts); + VectorPtr batch = BaseVector::create(outType, 0, pool()); + ASSERT_EQ(rowReader->next(10, batch), 2); + auto expected = makeRowVector({ + makeRowVector( + {"3", "1"}, + { + makeNullableFlatVector({std::nullopt, 7}), + makeFlatVector({4, 6}), + }), + }); + assertEqualVectors(expected, batch); +} + +TEST_F(TestReader, mapAsStructFilterAfterRead) { + auto row = makeRowVector({ + makeMapVector({{{1, 4}, {2, 5}}, {}, {{1, 6}, {3, 7}}}), + makeRowVector( + {makeConstant(0, 3)}, [](auto i) { return i == 0; }), + }); + auto [writer, reader] = + createWriterReader({row}, pool(), dataIoStats_, metadataIoStats_); + auto outType = + ROW({"c0", "c1"}, {ROW({"3", "1"}, BIGINT()), ROW({"c0"}, BIGINT())}); + auto spec = std::make_shared(""); + spec->addAllChildFields(*outType); + auto* c0Spec = spec->childByName("c0"); + c0Spec->setFlatMapAsStruct(true); + c0Spec->setFilter(std::make_shared()); + spec->childByName("c1")->setFilter(std::make_shared()); + RowReaderOptions rowReaderOpts; + rowReaderOpts.setScanSpec(spec); + auto rowReader = reader->createRowReader(rowReaderOpts); + VectorPtr batch = BaseVector::create(outType, 0, pool()); + ASSERT_EQ(rowReader->next(10, batch), 3); + auto expected = makeRowVector({ + makeRowVector( + {"3", "1"}, + { + makeNullableFlatVector({std::nullopt, 7}), + makeNullableFlatVector({std::nullopt, 6}), + }), + makeRowVector({makeConstant(0, 2)}), + }); + assertEqualVectors(expected, batch); +} + +TEST_F(TestReader, mapAsStructAllEmpty) { + auto row = makeRowVector({makeMapVector({{}, {}})}); + auto [writer, reader] = + createWriterReader({row}, pool(), dataIoStats_, metadataIoStats_); + auto outType = ROW({"c0"}, {ROW({"1"}, BIGINT())}); + auto spec = std::make_shared(""); + spec->addAllChildFields(*outType); + spec->childByName("c0")->setFlatMapAsStruct(true); + RowReaderOptions rowReaderOpts; + rowReaderOpts.setScanSpec(spec); + auto rowReader = reader->createRowReader(rowReaderOpts); + VectorPtr batch = BaseVector::create(outType, 0, pool()); + ASSERT_EQ(rowReader->next(10, batch), 2); + auto expected = makeRowVector({ + makeRowVector({"1"}, {makeNullConstant(TypeKind::BIGINT, 2)}), + }); + assertEqualVectors(expected, batch); +} + +// Verify DwrfRowReader can be destroyed while ParallelUnitLoader async load() +// are in progress. This regression test ensures that: +// 1. ParallelUnitLoader destructor doesn't wait for async load() operations +// 2. Async load() from DwrfUnit can still function after ParallelUnitLoader +// destruction and DwrfRowReader destruction, which means all dependencies in +// DwrfUnit remain valid (eg ReaderBase) +// +// If a future change adds an unsafe raw pointer to DwrfUnit's dependencies that +// would be freed by ParallelUnitLoader or DwrfRowReader's destruction, this +// test may crash due to use-after-free. +DEBUG_ONLY_TEST_F(TestReader, asyncLoadSurvivesReaderDestruction) { + const int kNumStripes = 2; + const int kRowsPerStripe = 100; + std::vector batches; + batches.reserve(kNumStripes); + for (int stripe = 0; stripe < kNumStripes; ++stripe) { + batches.push_back(makeRowVector({ + makeFlatVector( + kRowsPerStripe, + [stripe](auto row) { return stripe * kRowsPerStripe + row; }), + })); + } + + // Write the DWRF file - force each batch into its own stripe + auto config = std::make_shared(); + + auto sink = + std::make_unique(1 << 20, FileSink::Options{.pool = pool()}); + auto* sinkPtr = sink.get(); + auto writer = E2EWriterTestUtil::writeData( + std::move(sink), + asRowType(batches[0]->type()), + batches, + config, + // Force flush after each batch to create separate stripes + E2EWriterTestUtil::simpleFlushPolicyFactory(true)); + + std::string data(sinkPtr->data(), sinkPtr->size()); + auto input = std::make_unique( + std::make_shared(std::move(data)), *pool()); + + std::atomic asyncLoadsStarted{0}; + std::atomic asyncLoadsCompleted{0}; + folly::Baton<> readerDestroyed; + + SCOPED_TESTVALUE_SET( + "facebook::velox::dwio::common::ParallelUnitLoader::load", + std::function([&](void*) { + // Only block the second stripe (index 1) - let the first stripe load + // normally so rowReader->next() can complete + // fetch_add returns the value before increment: 0 for first, 1 for + // second, etc. + if (asyncLoadsStarted.fetch_add(1) == 1) { + // Block here until reader is destroyed + readerDestroyed.wait(); + } + asyncLoadsCompleted.fetch_add(1); + })); + + auto ioExecutor = std::make_shared(2); + + // Make sure ReaderOptions and DwrfRowReader are freed after {} scope + { + dwio::common::ReaderOptions readerOpts(pool()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); + readerOpts.setFileFormat(FileFormat::DWRF); + auto reader = DwrfReader::create(std::move(input), readerOpts); + + // Enable parallel unit load + RowReaderOptions rowReaderOpts; + rowReaderOpts.setParallelUnitLoadCount(2); + rowReaderOpts.setIOExecutor(ioExecutor.get()); + auto rowReader = reader->createRowReader(rowReaderOpts); + + VectorPtr batch; + rowReader->next(50, batch); // Read first stripe + + auto start = std::chrono::steady_clock::now(); + rowReader.reset(); + auto duration = std::chrono::steady_clock::now() - start; + // Verify destruction was fast (didn't wait for async operations) + EXPECT_LT(duration, std::chrono::seconds(1)) + << "Destruction should not wait for async loads"; + } + + // Now signal that reader is destroyed + readerDestroyed.post(); + + // Wait for async loads to complete + int maxWaitMs = 2000; + int waitedMs = 0; + while (asyncLoadsCompleted.load() < 2 && waitedMs < maxWaitMs) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + waitedMs += 100; + } + + // Verify that both async loads completed successfully after reader + // destruction This proves the fix works: async operations can complete even + // after DwrfRowReader is destroyed because LoadUnit is captured as shared_ptr + EXPECT_EQ(asyncLoadsCompleted.load(), 2) + << "Both async loads should complete even after reader destruction. " + << "If this fails, it means async operations are being cancelled or " + << "crashing after DwrfRowReader destruction, indicating unsafe pointers."; + + // Clean up + ioExecutor->join(); +} + +TEST_F(TestReader, extractionTransformMapKeys) { + // Write a MAP(VARCHAR, BIGINT) column, read with a ScanSpec transform + // that applies MapKeys extraction at the reader level. + auto keys = makeFlatVector({"a", "b", "c", "d"}); + auto values = makeFlatVector({1, 2, 3, 4}); + auto mapVector = makeMapVector({0, 2}, keys, values); + auto data = makeRowVector({"col"}, {mapVector}); + auto schema = asRowType(data->type()); + auto [writer, reader] = + createWriterReader({data}, pool(), dataIoStats_, metadataIoStats_); + + auto spec = std::make_shared(""); + spec->addAllChildFields(*schema); + + spec->childByName("col")->setExtractionType( + common::ScanSpec::ExtractionType::kKeys); + + RowReaderOptions rowReaderOpts; + rowReaderOpts.setScanSpec(spec); + auto rowReader = reader->createRowReader(rowReaderOpts); + + auto result = BaseVector::create(schema, 0, pool()); + ASSERT_EQ(rowReader->next(10, result), 2); + auto* row = result->as(); + auto* resultArray = row->childAt(0)->loadedVector()->as(); + ASSERT_EQ(resultArray->size(), 2); + ASSERT_EQ(resultArray->sizeAt(0), 2); + ASSERT_EQ(resultArray->sizeAt(1), 2); +} + +TEST_F(TestReader, extractionTransformSize) { + // Write a MAP(VARCHAR, BIGINT) column, read with a Size extraction. + auto keys = makeFlatVector({"a", "b", "c"}); + auto values = makeFlatVector({1, 2, 3}); + auto mapVector = makeMapVector({0, 1}, keys, values); + auto data = makeRowVector({"col"}, {mapVector}); + auto schema = asRowType(data->type()); + auto [writer, reader] = + createWriterReader({data}, pool(), dataIoStats_, metadataIoStats_); + + auto spec = std::make_shared(""); + spec->addAllChildFields(*schema); + + spec->childByName("col")->setExtractionType( + common::ScanSpec::ExtractionType::kSize); + + RowReaderOptions rowReaderOpts; + rowReaderOpts.setScanSpec(spec); + auto rowReader = reader->createRowReader(rowReaderOpts); + + auto result = BaseVector::create(schema, 0, pool()); + ASSERT_EQ(rowReader->next(10, result), 2); + auto* row = result->as(); + auto* sizes = row->childAt(0)->loadedVector()->as>(); + ASSERT_EQ(sizes->size(), 2); + ASSERT_EQ(sizes->valueAt(0), 1); + ASSERT_EQ(sizes->valueAt(1), 2); +} + +TEST_F(TestReader, extractionMapKeySizeWithSeek) { + // Repro for the kSize+MAP+no-deltaUpdate skip() bug. When kSize is + // configured on a MAP column without a delta update, neither keyReader_ + // nor elementReader_ is created. Forward seekToRow() after a prior read + // triggers SelectiveMapColumnReaderBase::skip() which currently fails + // because it requires at least one child reader to read the length stream. + constexpr int kNumRows = 100; + std::vector keyStrs(kNumRows * 2); + for (int i = 0; i < kNumRows * 2; ++i) { + keyStrs[i] = "k_" + std::to_string(i); + } + auto keys = makeFlatVector( + kNumRows * 2, [&](auto i) { return StringView(keyStrs[i]); }); + auto values = makeFlatVector(kNumRows * 2, folly::identity); + std::vector offsets(kNumRows); + for (int i = 0; i < kNumRows; ++i) { + offsets[i] = i * 2; + } + auto mapVector = makeMapVector(offsets, keys, values); + auto data = makeRowVector({"col"}, {mapVector}); + auto schema = asRowType(data->type()); + auto [writer, reader] = + createWriterReader({data}, pool(), dataIoStats_, metadataIoStats_); + + auto spec = std::make_shared(""); + spec->addAllChildFields(*schema); + spec->childByName("col")->setExtractionType( + common::ScanSpec::ExtractionType::kSize); + + RowReaderOptions rowReaderOpts; + rowReaderOpts.setScanSpec(spec); + auto rowReader = reader->createRowReader(rowReaderOpts); + + // Read & materialize the first 20 rows. + auto result = BaseVector::create(schema, 0, pool()); + ASSERT_EQ(rowReader->next(20, result), 20); + result->as()->childAt(0)->loadedVector(); + + // seekToRow advances the map reader from offset 20 to 50, calling skip() + // for the 30-row gap. Without the fix, this VELOX_FAILs at + // SelectiveMapColumnReaderBase::skip with "repeated reader with no + // children". + ASSERT_NO_THROW(dynamic_cast(rowReader.get())->seekToRow(50)); +} + +TEST_F(TestReader, extractionTransformMapValuesStructField) { + // MAP(VARCHAR, ROW(x: INT, y: INT)) with chain + // [MapValues, ArrayElements, StructField("x")] -> ARRAY(INT). + // The reader handles this natively via kValues on the map and kField on + // the values struct, so no post-read transform is needed. + auto keys = makeFlatVector({"a", "b", "c"}); + auto structValues = makeRowVector( + {"x", "y"}, + {makeFlatVector({10, 20, 30}), + makeFlatVector({100, 200, 300})}); + auto mapVector = makeMapVector({0, 2}, keys, structValues); + auto data = makeRowVector({"col"}, {mapVector}); + auto schema = asRowType(data->type()); + auto [writer, reader] = + createWriterReader({data}, pool(), dataIoStats_, metadataIoStats_); + + auto spec = std::make_shared(""); + spec->addAllChildFields(*schema); + + using connector::hive::applyExtractionChain; + using connector::hive::configureExtractionScanSpec; + using connector::hive::ExtractionPathElement; + using connector::hive::ExtractionPathElementPtr; + using connector::hive::ExtractionStep; + using connector::hive::NamedExtraction; + + // Configure extraction on the "col" spec with the sub-chain starting + // from the MAP type: [MapValues, ArrayElements, StructField("x")]. + auto mapType = schema->childAt(0); + std::vector colExtractions = { + {"out", + {ExtractionPathElement::simple(ExtractionStep::kMapValues), + ExtractionPathElement::simple(ExtractionStep::kArrayElements), + ExtractionPathElement::structField("x")}, + INTEGER()}}; + auto* colSpec = spec->childByName("col"); + configureExtractionScanSpec(mapType, colExtractions, *colSpec, pool()); + + // Set a full-chain transform for delta update fallback. + auto fullChain = colExtractions[0].chain; + colSpec->setTransform( + [fullChain](const VectorPtr& input, memory::MemoryPool* p) -> VectorPtr { + return applyExtractionChain(input, fullChain, p); + }, + ARRAY(INTEGER())); + + RowReaderOptions rowReaderOpts; + rowReaderOpts.setScanSpec(spec); + auto rowReader = reader->createRowReader(rowReaderOpts); + + auto result = BaseVector::create(schema, 0, pool()); + ASSERT_EQ(rowReader->next(10, result), 2); + auto* row = result->as(); + auto* resultArray = row->childAt(0)->loadedVector()->as(); + ASSERT_EQ(resultArray->size(), 2); + ASSERT_EQ(resultArray->sizeAt(0), 2); + ASSERT_EQ(resultArray->sizeAt(1), 1); + auto* elements = resultArray->elements()->as>(); + ASSERT_EQ(elements->valueAt(0), 10); + ASSERT_EQ(elements->valueAt(1), 20); + ASSERT_EQ(elements->valueAt(2), 30); +} + +TEST_F(TestReader, extractionSizeResultVectorReuse) { + // Verify the FlatVector result is reused across batches for + // Size extraction (no per-batch allocation). + // Write multiple batches to produce multiple reads. + constexpr int kNumRows = 200; + std::vector keyStrs(kNumRows * 2); + for (int i = 0; i < kNumRows * 2; ++i) { + keyStrs[i] = std::to_string(i); + } + auto keys = makeFlatVector( + kNumRows * 2, [&](auto i) { return StringView(keyStrs[i]); }); + auto values = makeFlatVector(kNumRows * 2, folly::identity); + // Each row has 2 map entries. + std::vector offsets(kNumRows); + for (int i = 0; i < kNumRows; ++i) { + offsets[i] = i * 2; + } + auto mapVector = makeMapVector(offsets, keys, values); + auto data = makeRowVector({"col"}, {mapVector}); + auto schema = asRowType(data->type()); + auto [writer, reader] = + createWriterReader({data}, pool(), dataIoStats_, metadataIoStats_); + + auto spec = std::make_shared(""); + spec->addAllChildFields(*schema); + spec->childByName("col")->setExtractionType( + common::ScanSpec::ExtractionType::kSize); + + RowReaderOptions rowReaderOpts; + rowReaderOpts.setScanSpec(spec); + auto rowReader = reader->createRowReader(rowReaderOpts); + + auto result = BaseVector::create(schema, 0, pool()); + + // Read first batch. + ASSERT_GT(rowReader->next(50, result), 0); + auto* row = result->as(); + auto* child = row->childAt(0)->loadedVector(); + ASSERT_TRUE(child->type()->isBigint()); + auto* firstBatchPtr = child; + + // Read second batch — the FlatVector should be the same object. + ASSERT_GT(rowReader->next(50, result), 0); + row = result->as(); + child = row->childAt(0)->loadedVector(); + ASSERT_EQ(child, firstBatchPtr) + << "FlatVector result should be reused across batches for Size extraction"; + + // Verify values are correct (each map has 2 entries). + auto* sizes = child->as>(); + for (int i = 0; i < sizes->size(); ++i) { + ASSERT_EQ(sizes->valueAt(i), 2); + } +} + +TEST_F(TestReader, extractionMapKeysMultipleBatches) { + // Verify MapKeys extraction produces correct results across multiple + // batches. + constexpr int kNumRows = 200; + std::vector keyStrs(kNumRows * 3); + for (int i = 0; i < kNumRows * 3; ++i) { + keyStrs[i] = std::to_string(i); + } + auto keys = makeFlatVector( + kNumRows * 3, [&](auto i) { return StringView(keyStrs[i]); }); + auto values = makeFlatVector(kNumRows * 3, folly::identity); + std::vector offsets(kNumRows); + for (int i = 0; i < kNumRows; ++i) { + offsets[i] = i * 3; + } + auto mapVector = makeMapVector(offsets, keys, values); + auto data = makeRowVector({"col"}, {mapVector}); + auto schema = asRowType(data->type()); + auto [writer, reader] = + createWriterReader({data}, pool(), dataIoStats_, metadataIoStats_); + + auto spec = std::make_shared(""); + spec->addAllChildFields(*schema); + spec->childByName("col")->setExtractionType( + common::ScanSpec::ExtractionType::kKeys); + + RowReaderOptions rowReaderOpts; + rowReaderOpts.setScanSpec(spec); + auto rowReader = reader->createRowReader(rowReaderOpts); + + auto result = BaseVector::create(schema, 0, pool()); + + // Read first batch. + ASSERT_GT(rowReader->next(50, result), 0); + auto* row = result->as(); + auto* child = row->childAt(0)->loadedVector(); + ASSERT_TRUE(child->type()->isArray()); + auto* arr = child->as(); + for (int i = 0; i < arr->size(); ++i) { + ASSERT_EQ(arr->sizeAt(i), 3); + } + + // Read second batch — verify correctness. + ASSERT_GT(rowReader->next(50, result), 0); + row = result->as(); + child = row->childAt(0)->loadedVector(); + arr = child->as(); + for (int i = 0; i < arr->size(); ++i) { + ASSERT_EQ(arr->sizeAt(i), 3); + } +} + +TEST_F(TestReader, extractionMapKeysIoReduction) { + // Verify that MapKeys extraction produces correct results on a large + // dataset. The extraction pushdown skips decoding the value stream + // (seekTo instead of readWithTiming) — this saves CPU but DWRF coalesced + // I/O reads the full stripe from storage regardless. Decode-level savings + // are validated by the extractionTransformMapKeys test which verifies the + // reader's seekTo behavior. + constexpr int kNumRows = 10'000; + std::vector keyStrs(kNumRows * 2); + for (int i = 0; i < kNumRows * 2; ++i) { + keyStrs[i] = "key_" + std::to_string(i); + } + auto keys = makeFlatVector( + kNumRows * 2, [&](auto i) { return StringView(keyStrs[i]); }); + auto values = makeFlatVector(kNumRows * 2, folly::identity); + std::vector offsets(kNumRows); + for (int i = 0; i < kNumRows; ++i) { + offsets[i] = i * 2; + } + auto mapVector = makeMapVector(offsets, keys, values); + auto data = makeRowVector({"col"}, {mapVector}); + auto schema = asRowType(data->type()); + auto [writer, reader] = + createWriterReader({data}, pool(), dataIoStats_, metadataIoStats_); + + auto spec = std::make_shared(""); + spec->addAllChildFields(*schema); + spec->childByName("col")->setExtractionType( + common::ScanSpec::ExtractionType::kKeys); + + RowReaderOptions rowReaderOpts; + rowReaderOpts.setScanSpec(spec); + auto rowReader = reader->createRowReader(rowReaderOpts); + + auto result = BaseVector::create(schema, 0, pool()); + uint64_t totalRows = 0; + while (auto batch = rowReader->next(1'000, result)) { + totalRows += batch; + auto* row = result->as(); + auto* arr = row->childAt(0)->loadedVector()->as(); + ASSERT_EQ(arr->size(), batch); + for (vector_size_t i = 0; i < arr->size(); ++i) { + ASSERT_EQ(arr->sizeAt(i), 2) << "Row " << i << " should have 2 keys"; + } + } + ASSERT_EQ(totalRows, kNumRows); + + // Validate IO reduction: MapKeys extraction should read fewer bytes than + // a full scan because the values stream is not requested. + auto sink = + std::make_unique(1 << 20, FileSink::Options{.pool = pool()}); + auto* sinkPtr = sink.get(); + auto writerObj = E2EWriterTestUtil::writeData( + std::move(sink), + schema, + {data}, + std::make_shared(), + E2EWriterTestUtil::simpleFlushPolicyFactory(true)); + std::string fileData(sinkPtr->data(), sinkPtr->size()); + + auto runWithSpec = + [&](const std::shared_ptr& scanSpec) -> uint64_t { + auto ioStats = std::make_shared(); + // Disable coalescing (maxMergeDistance=0) so each stream is read + // separately and rawBytesRead reflects actual stream-level I/O. + // Disable file preloading so small files don't bypass per-stream IO. + auto input = std::make_unique( + std::make_shared(fileData), + *pool(), + MetricsLog::voidLog(), + ioStats.get(), + /*ioStats=*/nullptr, + /*maxMergeDistance=*/0); + dwio::common::ReaderOptions readerOpts(pool()); + readerOpts.setDataIoStats(ioStats); + readerOpts.setMetadataIoStats(ioStats); + readerOpts.setFileFormat(FileFormat::DWRF); + readerOpts.setFilePreloadThreshold(0); + auto rdr = DwrfReader::create(std::move(input), readerOpts); + RowReaderOptions rro; + rro.setScanSpec(scanSpec); + auto rr = rdr->createRowReader(rro); + auto res = BaseVector::create(schema, 0, pool()); + while (rr->next(kNumRows, res) > 0) { + } + return ioStats->rawBytesRead(); + }; + + auto fullSpec = std::make_shared(""); + fullSpec->addAllChildFields(*schema); + auto fullBytes = runWithSpec(fullSpec); + + auto extSpec = std::make_shared(""); + extSpec->addAllChildFields(*schema); + extSpec->childByName("col")->setExtractionType( + common::ScanSpec::ExtractionType::kKeys); + auto extBytes = runWithSpec(extSpec); + + ASSERT_GT(fullBytes, 0); + ASSERT_LT(extBytes, fullBytes) + << "Extraction: " << extBytes << ", Full: " << fullBytes; +} + +TEST_F(TestReader, extractionNestedChainScanSpec) { + // Test that a nested extraction chain recursively configures ALL levels + // of the ScanSpec, so no post-read transform is needed. + // + // Schema: ROW(a: MAP(VARCHAR, ROW(x: INT, y: ARRAY(BIGINT))), b: INT) + // Chain: [StructField("a"), MapValues, ArrayElements, StructField("y"), Size] + // + // Expected ScanSpec at each level: + // ROOT ROW: "b" pruned (constant null), "a" not pruned + // "a" MAP: ExtractionType::kValues + // values ROW: "x" pruned (constant null), "y" not pruned + // "y" ARRAY: ExtractionType::kSize + + auto innerStructType = ROW({{"x", INTEGER()}, {"y", ARRAY(BIGINT())}}); + auto mapType = MAP(VARCHAR(), innerStructType); + auto schema = ROW({{"a", mapType}, {"b", INTEGER()}}); + + // Build the ScanSpec. + auto spec = std::make_shared(""); + spec->addAllChildFields(*schema); + + // Build extraction. + using connector::hive::configureExtractionScanSpec; + using connector::hive::ExtractionPathElement; + using connector::hive::ExtractionPathElementPtr; + using connector::hive::ExtractionStep; + using connector::hive::NamedExtraction; + + std::vector extractions = { + {"out", + {ExtractionPathElement::structField("a"), + ExtractionPathElement::simple(ExtractionStep::kMapValues), + ExtractionPathElement::simple(ExtractionStep::kArrayElements), + ExtractionPathElement::structField("y"), + ExtractionPathElement::simple(ExtractionStep::kSize)}, + ARRAY(BIGINT())}}; + + configureExtractionScanSpec(schema, extractions, *spec, pool()); + + // Level 1 (ROOT ROW): "b" pruned, "a" not pruned. + auto* bSpec = spec->childByName("b"); + ASSERT_NE(bSpec, nullptr); + ASSERT_TRUE(bSpec->isConstant()); + + auto* aSpec = spec->childByName("a"); + ASSERT_NE(aSpec, nullptr); + ASSERT_FALSE(aSpec->isConstant()); + + // Level 2 ("a" MAP): ExtractionType::kValues. + ASSERT_EQ(aSpec->extractionType(), common::ScanSpec::ExtractionType::kValues); + + // Level 3 (values ROW): "x" pruned, "y" not pruned. + auto* valuesSpec = aSpec->childByName(common::ScanSpec::kMapValuesFieldName); + ASSERT_NE(valuesSpec, nullptr); + + auto* xSpec = valuesSpec->childByName("x"); + ASSERT_NE(xSpec, nullptr); + ASSERT_TRUE(xSpec->isConstant()); + + auto* ySpec = valuesSpec->childByName("y"); + ASSERT_NE(ySpec, nullptr); + ASSERT_FALSE(ySpec->isConstant()); + + // Level 4 ("y" ARRAY): ExtractionType::kSize. + ASSERT_EQ(ySpec->extractionType(), common::ScanSpec::ExtractionType::kSize); + + // The values struct has kField extraction for "y" (the only needed field). + ASSERT_EQ( + valuesSpec->extractionType(), common::ScanSpec::ExtractionType::kField); + + // Verify kField targets the correct field index. + ASSERT_EQ(valuesSpec->extractionFieldIndex(), 1); + + // Write test data and verify the reader produces correct results. + // + // Each row i has map: {"k" -> {x: i*10, y: [0..i]}} + // So y has (i+1) elements. + constexpr int kNumRows = 10; + std::vector allKeys; + std::vector allX; + std::vector> allY; + for (int i = 0; i < kNumRows; ++i) { + allKeys.emplace_back("k"); + allX.push_back(i * 10); + std::vector yValues; + for (int j = 0; j <= i; ++j) { + yValues.push_back(i * 100 + j); + } + allY.push_back(std::move(yValues)); + } + + auto keys = makeFlatVector(allKeys); + auto xFlat = makeFlatVector(allX); + auto yArray = makeArrayVector(allY); + auto rowValues = makeRowVector({"x", "y"}, {xFlat, yArray}); + + // Each row has exactly 1 map entry. + std::vector mapOffsets(kNumRows); + std::iota(mapOffsets.begin(), mapOffsets.end(), 0); + auto map = makeMapVector(mapOffsets, keys, rowValues); + auto bCol = makeFlatVector(kNumRows, folly::identity); + auto batch = makeRowVector({"a", "b"}, {map, bCol}); + + auto [writer, reader] = + createWriterReader({batch}, pool(), dataIoStats_, metadataIoStats_); + + // Configure the ScanSpec for reading. In production, the extraction is + // applied to the column's spec (not the root), so we configure "a"'s + // spec directly with the sub-chain after StructField("a"). + using connector::hive::applyExtractionChain; + auto readSpec = std::make_shared(""); + readSpec->addAllChildFields(*schema); + auto* readASpec = readSpec->childByName("a"); + + // Sub-chain starting from the MAP type: [MapValues, AE, SF("y"), Size]. + std::vector aExtractions = { + {"out", + {ExtractionPathElement::simple(ExtractionStep::kMapValues), + ExtractionPathElement::simple(ExtractionStep::kArrayElements), + ExtractionPathElement::structField("y"), + ExtractionPathElement::simple(ExtractionStep::kSize)}, + ARRAY(BIGINT())}}; + configureExtractionScanSpec(mapType, aExtractions, *readASpec, pool()); + + // Prune "b" as constant null (mimic HiveDataSource behavior). + auto* readBSpec = readSpec->childByName("b"); + readBSpec->setConstantValue( + BaseVector::createNullConstant(INTEGER(), 1, pool())); + + // Set a full-chain transform (used as fallback for delta updates). + auto fullChain = aExtractions[0].chain; + readASpec->setTransform( + [fullChain](const VectorPtr& input, memory::MemoryPool* p) -> VectorPtr { + return applyExtractionChain(input, fullChain, p); + }, + BIGINT()); + + RowReaderOptions rowReaderOpts; + rowReaderOpts.setScanSpec(readSpec); + auto rowReader = reader->createRowReader(rowReaderOpts); + + auto result = BaseVector::create(schema, 0, pool()); + ASSERT_TRUE(rowReader->next(kNumRows, result)); + ASSERT_EQ(result->size(), kNumRows); + + // The reader with kValues extraction on "a" produces an ArrayVector. + // The values struct reader has kField extraction for "y", so it + // produces the "y" field directly (BIGINT from kSize). No remaining + // transform is needed. So elements are BIGINT. + auto* resultRow = result->as(); + ASSERT_NE(resultRow, nullptr); + + // "b" should be null (pruned). + auto* bResult = resultRow->childAt(1).get(); + ASSERT_TRUE(bResult->isConstantEncoding()); + + // "a" should be an ArrayVector (from kValues extraction). + auto* aResult = resultRow->childAt(0)->loadedVector()->as(); + ASSERT_NE(aResult, nullptr); + + // Each element is BIGINT (size of y array). + auto* elements = aResult->elements()->asFlatVector(); + ASSERT_NE(elements, nullptr); + for (int i = 0; i < kNumRows; ++i) { + ASSERT_EQ(aResult->sizeAt(i), 1) << "Row " << i << " should have 1 entry"; + // y array had (i+1) elements, so size should be (i+1). + ASSERT_EQ(elements->valueAt(aResult->offsetAt(i)), i + 1) + << "Incorrect size at row " << i; + } + + // Validate IO reduction: nested extraction should read fewer bytes than + // a full scan because "b" column, map keys, "x" field, and y array + // elements are all skipped. + constexpr int kLargeNumRows = 1'000; + std::vector largeKeys; + std::vector largeX; + std::vector> largeY; + for (int i = 0; i < kLargeNumRows; ++i) { + largeKeys.emplace_back("k"); + largeX.push_back(i * 10); + std::vector yVals(10); + std::iota(yVals.begin(), yVals.end(), i * 100); + largeY.push_back(std::move(yVals)); + } + auto largeKeysVec = makeFlatVector(largeKeys); + auto largeXVec = makeFlatVector(largeX); + auto largeYVec = makeArrayVector(largeY); + auto largeRowValues = makeRowVector({"x", "y"}, {largeXVec, largeYVec}); + std::vector largeMapOffsets(kLargeNumRows); + std::iota(largeMapOffsets.begin(), largeMapOffsets.end(), 0); + auto largeMap = makeMapVector(largeMapOffsets, largeKeysVec, largeRowValues); + auto largeBCol = makeFlatVector(kLargeNumRows, folly::identity); + auto largeBatch = makeRowVector({"a", "b"}, {largeMap, largeBCol}); + + auto largeSink = + std::make_unique(1 << 20, FileSink::Options{.pool = pool()}); + auto* largeSinkPtr = largeSink.get(); + auto largeWriter = E2EWriterTestUtil::writeData( + std::move(largeSink), + schema, + {largeBatch}, + std::make_shared(), + E2EWriterTestUtil::simpleFlushPolicyFactory(true)); + std::string largeFileData(largeSinkPtr->data(), largeSinkPtr->size()); + + auto runLargeWithSpec = + [&](const std::shared_ptr& scanSpec) -> uint64_t { + auto ioStats = std::make_shared(); + // Disable coalescing so each stream is read separately. + // Disable file preloading so small files don't bypass per-stream IO. + auto input = std::make_unique( + std::make_shared(largeFileData), + *pool(), + MetricsLog::voidLog(), + ioStats.get(), + /*ioStats=*/nullptr, + /*maxMergeDistance=*/0); + dwio::common::ReaderOptions readerOpts(pool()); + readerOpts.setDataIoStats(ioStats); + readerOpts.setMetadataIoStats(ioStats); + readerOpts.setFileFormat(FileFormat::DWRF); + readerOpts.setFilePreloadThreshold(0); + auto rdr = DwrfReader::create(std::move(input), readerOpts); + RowReaderOptions rro; + rro.setScanSpec(scanSpec); + auto rr = rdr->createRowReader(rro); + auto res = BaseVector::create(schema, 0, pool()); + while (rr->next(kLargeNumRows, res) > 0) { + } + return ioStats->rawBytesRead(); + }; + + // Full scan: read all columns. + auto fullSpec2 = std::make_shared(""); + fullSpec2->addAllChildFields(*schema); + auto fullBytes = runLargeWithSpec(fullSpec2); + + // Extraction scan with nested pushdown. + auto extSpec2 = std::make_shared(""); + extSpec2->addAllChildFields(*schema); + configureExtractionScanSpec(schema, extractions, *extSpec2, pool()); + extSpec2->childByName("b")->setConstantValue( + BaseVector::createNullConstant(INTEGER(), 1, pool())); + auto extBytes = runLargeWithSpec(extSpec2); + + ASSERT_GT(fullBytes, 0); + ASSERT_LT(extBytes, fullBytes) + << "Nested extraction: " << extBytes << ", Full: " << fullBytes; +} + +} // namespace +} // namespace facebook::velox::dwrf diff --git a/velox/dwio/dwrf/test/StripeReaderBaseTests.cpp b/velox/dwio/dwrf/test/StripeReaderBaseTests.cpp index 4091204ba96..09ab5503f19 100644 --- a/velox/dwio/dwrf/test/StripeReaderBaseTests.cpp +++ b/velox/dwio/dwrf/test/StripeReaderBaseTests.cpp @@ -17,6 +17,7 @@ #include #include +#include "velox/common/io/IoStatistics.h" #include "velox/dwio/common/encryption/TestProvider.h" #include "velox/dwio/dwrf/reader/StripeReaderBase.h" #include "velox/dwio/dwrf/utils/ProtoUtils.h" @@ -42,7 +43,8 @@ class StripeLoadKeysTest : public Test { HiveTypeParser parser; auto type = parser.parse("struct"); footer_ = std::make_unique(); - ProtoUtils::writeType(*type, *footer_); + auto footerWrapper = FooterWriteWrapper(footer_.get()); + ProtoUtils::writeType(*type, footerWrapper); auto enc = footer_->mutable_encryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); auto group = enc->add_encryptiongroups(); @@ -72,7 +74,9 @@ class StripeLoadKeysTest : public Test { nullptr, footer_.get(), nullptr, - std::move(handler)); + std::move(handler), + dataIoStats_, + metadataIoStats_); stripeReader_ = std::make_unique(reader_); } @@ -93,8 +97,8 @@ class StripeLoadKeysTest : public Test { handler_ = std::move(handler); - enc_ = const_cast( - std::addressof(dynamic_cast( + enc_ = const_cast(std::addressof( + dynamic_cast( handler_->getEncryptionProviderByIndex(0)))); } @@ -103,6 +107,10 @@ class StripeLoadKeysTest : public Test { std::unique_ptr stripeReader_; TestEncryption* enc_; std::shared_ptr pool_; + std::shared_ptr dataIoStats_ = + std::make_shared(); + std::shared_ptr metadataIoStats_ = + std::make_shared(); std::unique_ptr stripeFooter_; std::unique_ptr handler_; std::unique_ptr stripeInfo_; diff --git a/velox/dwio/dwrf/test/TestStripeStream.cpp b/velox/dwio/dwrf/test/StripeStreamTest.cpp similarity index 92% rename from velox/dwio/dwrf/test/TestStripeStream.cpp rename to velox/dwio/dwrf/test/StripeStreamTest.cpp index 2e9132aeb5f..72bebf35a0d 100644 --- a/velox/dwio/dwrf/test/TestStripeStream.cpp +++ b/velox/dwio/dwrf/test/StripeStreamTest.cpp @@ -14,9 +14,11 @@ * limitations under the License. */ +#include "velox/dwio/dwrf/reader/StripeStream.h" #include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/io/IoStatistics.h" +#include "velox/dwio/common/Arena.h" #include "velox/dwio/common/encryption/TestProvider.h" -#include "velox/dwio/dwrf/reader/StripeStream.h" #include "velox/dwio/dwrf/test/OrcTest.h" #include "velox/dwio/dwrf/utils/ProtoUtils.h" #include "velox/dwio/dwrf/writer/WriterBase.h" @@ -42,8 +44,7 @@ class RecordingInputStream : public facebook::velox::InMemoryReadFile { uint64_t offset, uint64_t length, void* buf, - facebook::velox::filesystems::File::IoStats* stats = - nullptr) const override { + const facebook::velox::FileIoContext& context = {}) const override { reads_.push_back({offset, length}); return {static_cast(buf), length}; } @@ -138,6 +139,10 @@ class StripeStreamTest : public testing::TestWithParam { MemoryManager::testingSetInstance(MemoryManager::Options{}); } std::shared_ptr pool_{memoryManager()->addLeafPool()}; + std::shared_ptr dataIoStats_ = + std::make_shared(); + std::shared_ptr metadataIoStats_ = + std::make_shared(); }; class StripeStreamFormatTypeTest : public testing::TestWithParam { @@ -146,6 +151,10 @@ class StripeStreamFormatTypeTest : public testing::TestWithParam { MemoryManager::testingSetInstance(MemoryManager::Options{}); } std::shared_ptr pool_{memoryManager()->addLeafPool()}; + std::shared_ptr dataIoStats_ = + std::make_shared(); + std::shared_ptr metadataIoStats_ = + std::make_shared(); DwrfFormat testParamDwrfFormat_ = GetParam(); }; @@ -158,10 +167,11 @@ INSTANTIATE_TEST_SUITE_P( TEST_P(StripeStreamFormatTypeTest, planReads) { google::protobuf::Arena arena; - auto footer = google::protobuf::Arena::CreateMessage(&arena); - footer->set_rowindexstride(100); + auto footer = ArenaCreate(&arena); + auto footerWrapper = FooterWriteWrapper(footer); + footerWrapper.setRowIndexStride(100); auto type = HiveTypeParser().parse("struct"); - ProtoUtils::writeType(*type, *footer); + ProtoUtils::writeType(*type, footerWrapper); auto is = std::make_unique(); auto isPtr = is.get(); auto readerBase = std::make_shared( @@ -176,7 +186,10 @@ TEST_P(StripeStreamFormatTypeTest, planReads) { true), std::make_unique(proto::PostScript{}), footer, - nullptr); + nullptr, + nullptr, + dataIoStats_, + metadataIoStats_); ColumnSelector cs{readerBase->schema(), std::vector{2}, true}; TestDecrypterFactory factory; @@ -247,10 +260,11 @@ TEST_P(StripeStreamFormatTypeTest, planReads) { TEST_F(StripeStreamTest, filterSequences) { google::protobuf::Arena arena; - auto footer = google::protobuf::Arena::CreateMessage(&arena); - footer->set_rowindexstride(100); + auto footer = ArenaCreate(&arena); + auto footerWrapper = FooterWriteWrapper(footer); + footerWrapper.setRowIndexStride(100); auto type = HiveTypeParser().parse("struct>"); - ProtoUtils::writeType(*type, *footer); + ProtoUtils::writeType(*type, footerWrapper); auto is = std::make_unique(); auto isPtr = is.get(); auto readerBase = std::make_shared( @@ -258,7 +272,10 @@ TEST_F(StripeStreamTest, filterSequences) { std::make_unique(std::move(is), *pool_), std::make_unique(proto::PostScript{}), footer, - nullptr); + nullptr, + nullptr, + dataIoStats_, + metadataIoStats_); // mock a filter that we only need one node and one sequence ColumnSelector cs{readerBase->schema(), std::vector{"a#[1]"}}; @@ -311,10 +328,11 @@ TEST_F(StripeStreamTest, filterSequences) { TEST_P(StripeStreamFormatTypeTest, zeroLength) { google::protobuf::Arena arena; - auto footer = google::protobuf::Arena::CreateMessage(&arena); - footer->set_rowindexstride(100); + auto footer = ArenaCreate(&arena); + auto footerWrapper = FooterWriteWrapper(footer); + footerWrapper.setRowIndexStride(100); auto type = HiveTypeParser().parse("struct"); - ProtoUtils::writeType(*type, *footer); + ProtoUtils::writeType(*type, footerWrapper); proto::PostScript ps; ps.set_compressionblocksize(1024); ps.set_compression(proto::CompressionKind::ZSTD); @@ -325,7 +343,10 @@ TEST_P(StripeStreamFormatTypeTest, zeroLength) { std::make_unique(std::move(is), *pool_), std::make_unique(std::move(ps)), footer, - nullptr); + nullptr, + nullptr, + dataIoStats_, + metadataIoStats_); TestDecrypterFactory factory; auto handler = DecryptionHandler::create(FooterWrapper(footer), &factory); @@ -437,12 +458,13 @@ TEST_P(StripeStreamFormatTypeTest, planReadsIndex) { index.SerializeToOstream(&buffer); // build footer - auto footer = google::protobuf::Arena::CreateMessage(&arena); - footer->set_rowindexstride(100); - footer->add_stripecacheoffsets(0); - footer->add_stripecacheoffsets(buffer.tellp()); + auto footer = ArenaCreate(&arena); + auto footerWrapper = FooterWriteWrapper(footer); + footerWrapper.setRowIndexStride(100); + footerWrapper.addStripeCacheOffsets(0); + footerWrapper.addStripeCacheOffsets(buffer.tellp()); auto type = HiveTypeParser().parse("struct"); - ProtoUtils::writeType(*type, *footer); + ProtoUtils::writeType(*type, footerWrapper); // build cache std::string str(buffer.str()); @@ -458,7 +480,10 @@ TEST_P(StripeStreamFormatTypeTest, planReadsIndex) { std::make_unique(std::move(is), *pool_), std::make_unique(std::move(ps)), footer, - std::move(cache)); + std::move(cache), + nullptr, + dataIoStats_, + metadataIoStats_); TestDecrypterFactory factory; auto handler = DecryptionHandler::create(FooterWrapper(footer), &factory); @@ -596,23 +621,24 @@ TEST_F(StripeStreamTest, readEncryptedStreams) { proto::PostScript ps; ps.set_compression(proto::CompressionKind::ZSTD); ps.set_compressionblocksize(256 * 1024); - auto footer = google::protobuf::Arena::CreateMessage(&arena); + auto footer = ArenaCreate(&arena); + auto footerWrapper = FooterWriteWrapper(footer); // a: not encrypted, projected // encryption group 1: b, c. projected b. // group 2: d. projected d. // group 3: e. not projected auto type = HiveTypeParser().parse("struct"); - ProtoUtils::writeType(*type, *footer); + ProtoUtils::writeType(*type, footerWrapper); - auto enc = footer->mutable_encryption(); + auto enc = footerWrapper.mutableEncryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); addEncryptionGroup(*enc, {2, 3}); addEncryptionGroup(*enc, {4}); addEncryptionGroup(*enc, {5}); - auto stripe = footer->add_stripes(); + auto stripe = footerWrapper.addStripes(); for (auto i = 0; i < 3; ++i) { - *stripe->add_keymetadata() = folly::to("key", i); + *stripe.addKeyMetadata() = folly::to("key", i); } TestDecrypterFactory factory; auto handler = DecryptionHandler::create(FooterWrapper(footer), &factory); @@ -643,7 +669,9 @@ TEST_F(StripeStreamTest, readEncryptedStreams) { std::make_unique(std::move(ps)), footer, nullptr, - std::move(handler)); + std::move(handler), + dataIoStats_, + metadataIoStats_); auto stripeMetadata = std::make_unique( &readerBase->bufferedInput(), std::move(stripeFooter), @@ -689,19 +717,20 @@ TEST_F(StripeStreamTest, schemaMismatch) { proto::PostScript ps; ps.set_compression(proto::CompressionKind::ZSTD); ps.set_compressionblocksize(256 * 1024); - auto footer = google::protobuf::Arena::CreateMessage(&arena); + auto footer = ArenaCreate(&arena); + auto footerWrapper = FooterWriteWrapper(footer); // a: not encrypted, has schema change // b: encrypted // c: not encrypted auto type = HiveTypeParser().parse("struct,b:int,c:int>"); - ProtoUtils::writeType(*type, *footer); + ProtoUtils::writeType(*type, footerWrapper); - auto enc = footer->mutable_encryption(); + auto enc = footerWrapper.mutableEncryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); addEncryptionGroup(*enc, {3}); - auto stripe = footer->add_stripes(); - *stripe->add_keymetadata() = "key"; + auto stripe = footerWrapper.addStripes(); + *stripe.addKeyMetadata() = "key"; TestDecrypterFactory factory; auto handler = DecryptionHandler::create(FooterWrapper(footer), &factory); TestEncrypter encrypter; @@ -726,7 +755,9 @@ TEST_F(StripeStreamTest, schemaMismatch) { std::make_unique(std::move(ps)), footer, nullptr, - std::move(handler)); + std::move(handler), + dataIoStats_, + metadataIoStats_); auto stripeMetadata = std::make_unique( &readerBase->bufferedInput(), std::move(stripeFooter), diff --git a/velox/dwio/dwrf/test/TestByteRle.cpp b/velox/dwio/dwrf/test/TestByteRle.cpp index ad19d341a50..94f8ecf9da6 100644 --- a/velox/dwio/dwrf/test/TestByteRle.cpp +++ b/velox/dwio/dwrf/test/TestByteRle.cpp @@ -39,9 +39,9 @@ std::unique_ptr createBooleanDecoder( TEST(ByteRle, simpleTest) { const unsigned char buffer[] = {0x61, 0x00, 0xfd, 0x44, 0x45, 0x46}; - std::unique_ptr rle = - createByteDecoder(std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer)))); + std::unique_ptr rle = createByteDecoder( + std::unique_ptr( + new SeekableArrayInputStream(buffer, std::size(buffer)))); std::vector data(103); rle->next(data.data(), data.size(), nullptr); @@ -54,9 +54,9 @@ TEST(ByteRle, simpleTest) { } TEST(ByteRle, nullTest) { - char buffer[258]; - uint64_t nulls[5]; - char result[266]; + char buffer[258] = {'\0'}; + uint64_t nulls[5] = {'\0'}; + char result[266] = {'\0'}; buffer[0] = -128; buffer[129] = -128; for (int32_t i = 0; i < 128; ++i) { @@ -66,8 +66,8 @@ TEST(ByteRle, nullTest) { for (int32_t i = 0; i < 266; ++i) { bits::setNull(nulls, i, i < 10); } - std::unique_ptr rle = - createByteDecoder(std::unique_ptr( + std::unique_ptr rle = createByteDecoder( + std::unique_ptr( new SeekableArrayInputStream(buffer, sizeof(buffer)))); rle->next(result, sizeof(result), nulls); for (size_t i = 0; i < sizeof(result); ++i) { @@ -93,9 +93,9 @@ TEST(ByteRle, literalCrossBuffer) { 0x09, 0x07, 0x10}; - std::unique_ptr rle = - createByteDecoder(std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer), 6))); + std::unique_ptr rle = createByteDecoder( + std::unique_ptr( + new SeekableArrayInputStream(buffer, std::size(buffer), 6))); std::vector data(20); rle->next(data.data(), data.size(), nullptr); @@ -109,9 +109,9 @@ TEST(ByteRle, literalCrossBuffer) { TEST(ByteRle, skipLiteralBufferUnderflowTest) { const unsigned char buffer[] = {0xf8, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7}; - std::unique_ptr rle = - createByteDecoder(std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer), 4))); + std::unique_ptr rle = createByteDecoder( + std::unique_ptr( + new SeekableArrayInputStream(buffer, std::size(buffer), 4))); std::vector data(8); rle->next(data.data(), 3, nullptr); EXPECT_EQ(0x0, data[0]); @@ -127,9 +127,9 @@ TEST(ByteRle, skipLiteralBufferUnderflowTest) { TEST(ByteRle, simpleRuns) { const unsigned char buffer[] = {0x0d, 0xff, 0x0d, 0xfe, 0x0d, 0xfd}; - std::unique_ptr rle = - createByteDecoder(std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer)))); + std::unique_ptr rle = createByteDecoder( + std::unique_ptr( + new SeekableArrayInputStream(buffer, std::size(buffer)))); std::vector data(16); for (size_t i = 0; i < 3; ++i) { rle->next(data.data(), data.size(), nullptr); @@ -145,9 +145,9 @@ TEST(ByteRle, splitHeader) { 0x00, 0x01, 0xe0, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20}; - std::unique_ptr rle = - createByteDecoder(std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer), 1))); + std::unique_ptr rle = createByteDecoder( + std::unique_ptr( + new SeekableArrayInputStream(buffer, std::size(buffer), 1))); std::vector data(35); rle->next(data.data(), data.size(), nullptr); for (size_t i = 0; i < 3; ++i) { @@ -179,9 +179,9 @@ TEST(ByteRle, splitRuns) { 0x0e, 0x0f, 0x10}; - std::unique_ptr rle = - createByteDecoder(std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer)))); + std::unique_ptr rle = createByteDecoder( + std::unique_ptr( + new SeekableArrayInputStream(buffer, std::size(buffer)))); std::vector data(5); for (size_t i = 0; i < 3; ++i) { rle->next(data.data(), data.size(), nullptr); @@ -227,9 +227,9 @@ TEST(ByteRle, testNulls) { 0x0f, 0x3d, 0xdc}; - std::unique_ptr rle = - createByteDecoder(std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer), 3))); + std::unique_ptr rle = createByteDecoder( + std::unique_ptr( + new SeekableArrayInputStream(buffer, std::size(buffer), 3))); std::vector data(16, -1); std::vector nulls(1); for (size_t i = 0; i < data.size(); ++i) { @@ -276,9 +276,9 @@ TEST(ByteRle, testAllNulls) { 0x0f, 0x3d, 0xdc}; - std::unique_ptr rle = - createByteDecoder(std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer)))); + std::unique_ptr rle = createByteDecoder( + std::unique_ptr( + new SeekableArrayInputStream(buffer, std::size(buffer)))); std::vector data(16, -1); std::vector allNull(1, bits::kNull64); std::vector noNull(1, bits::kNotNull64); @@ -413,7 +413,7 @@ TEST(ByteRle, testSkip) { 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, }; SeekableInputStream* const stream = - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer)); + new SeekableArrayInputStream(buffer, std::size(buffer)); std::unique_ptr rle = createByteDecoder(std::unique_ptr(stream)); std::vector data(1); @@ -570,7 +570,7 @@ TEST(ByteRle, testSeek) { 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, }; std::unique_ptr stream( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))); + new SeekableArrayInputStream(buffer, std::size(buffer))); const uint64_t fileLocs[] = { 0, 0, 0, 0, 0, 2, 2, 2, 2, 4, 4, 4, 4, 6, 6, 6, 6, 8, 8, 8, 8, 10, 10, 10, @@ -907,7 +907,7 @@ TEST(ByteRle, testSeek) { // Seek to end std::vector position; - position.push_back(VELOX_ARRAY_SIZE(buffer)); + position.push_back(std::size(buffer)); position.push_back(0); PositionProvider pp{position}; rle->seekToRowGroup(pp); @@ -916,7 +916,7 @@ TEST(ByteRle, testSeek) { // Seek to end + 1 position.clear(); - position.push_back(VELOX_ARRAY_SIZE(buffer)); + position.push_back(std::size(buffer)); position.push_back(1); PositionProvider pp2{position}; rle->seekToRowGroup(pp2); @@ -926,7 +926,7 @@ TEST(ByteRle, testSeek) { TEST(BooleanRle, simpleTest) { const unsigned char buffer[] = {0x61, 0xf0, 0xfd, 0x55, 0xAA, 0x55}; std::unique_ptr stream( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))); + new SeekableArrayInputStream(buffer, std::size(buffer))); std::unique_ptr rle = createBooleanDecoder(std::move(stream)); std::vector data(50); for (size_t i = 0; i < 16; ++i) { @@ -951,7 +951,7 @@ TEST(BooleanRle, runsTest) { const unsigned char buffer[] = { 0xf7, 0xff, 0x80, 0x3f, 0xe0, 0x0f, 0xf8, 0x03, 0xfe, 0x00}; std::unique_ptr stream( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))); + new SeekableArrayInputStream(buffer, std::size(buffer))); std::unique_ptr rle = createBooleanDecoder(std::move(stream)); std::vector data(72); rle->next(data.data(), data.size(), nullptr); @@ -973,7 +973,7 @@ TEST(BooleanRle, runsTestWithNull) { const unsigned char buffer[] = { 0xf7, 0xff, 0x80, 0x3f, 0xe0, 0x0f, 0xf8, 0x03, 0xfe, 0x00}; std::unique_ptr stream( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))); + new SeekableArrayInputStream(buffer, std::size(buffer))); std::unique_ptr rle = createBooleanDecoder(std::move(stream)); std::vector data(72); std::vector nulls(bits::nwords(data.size()), bits::kNotNull64); @@ -1089,7 +1089,7 @@ TEST(BooleanRle, skipTest) { 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71}; std::unique_ptr stream( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))); + new SeekableArrayInputStream(buffer, std::size(buffer))); std::unique_ptr rle = createBooleanDecoder(std::move(stream)); std::vector data(1); for (size_t i = 0; i < 16384; i += 5) { @@ -1200,7 +1200,7 @@ TEST(BooleanRle, skipTestWithNulls) { 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71}; std::unique_ptr stream( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))); + new SeekableArrayInputStream(buffer, std::size(buffer))); std::unique_ptr rle = createBooleanDecoder(std::move(stream)); raw_vector data; data.resize(3); @@ -1365,7 +1365,7 @@ TEST(BooleanRle, seekTest) { 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71}; std::unique_ptr stream( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))); + new SeekableArrayInputStream(buffer, std::size(buffer))); std::unique_ptr rle = createBooleanDecoder(std::move(stream)); // Read all 16384 values and validate them. @@ -1501,7 +1501,7 @@ TEST(BooleanRle, seekTestWithNulls) { 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71}; - auto* stream = new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer)); + auto* stream = new SeekableArrayInputStream(buffer, std::size(buffer)); auto rle = createBooleanDecoder(std::unique_ptr(stream)); ASSERT_EQ(stream->totalRead(), 0); auto lastTotalReadBytes = stream->totalRead(); @@ -1519,7 +1519,7 @@ TEST(BooleanRle, seekTestWithNulls) { EXPECT_EQ(0, bits::isBitSet(data.data(), i)) << "Output wrong at " << i; } rle->next(data.data(), data.size(), noNull.data()); - ASSERT_EQ(getNumReadBytes(), VELOX_ARRAY_SIZE(buffer)); + ASSERT_EQ(getNumReadBytes(), std::size(buffer)); for (size_t i = 0; i < data.size(); ++i) { EXPECT_EQ(i < 8192 ? i & 1 : (i / 3) & 1, bits::isBitSet(data.data(), i)) << "Output wrong at " << i; @@ -1573,7 +1573,7 @@ TEST(BooleanRle, seekBoolAndByteRLE) { 0xf9, 0xf0, 0xf0, 0xf7, 0x1c, 0x71, 0xc1, 0x80}; std::unique_ptr stream( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))); + new SeekableArrayInputStream(buffer, std::size(buffer))); std::unique_ptr rle = createBooleanDecoder(std::move(stream)); std::vector data(sizeof(num) / sizeof(char)); rle->next(data.data(), data.size(), nullptr); @@ -1596,7 +1596,7 @@ TEST(BooleanRle, seekBoolAndByteRLE) { TEST(BooleanRle, skipToEnd) { const unsigned char buffer[] = {0xfe, 0xff, 0xff}; std::unique_ptr stream( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))); + new SeekableArrayInputStream(buffer, std::size(buffer))); std::unique_ptr rle = createBooleanDecoder(std::move(stream)); char value[1]; rle->next(value, 1, nullptr); diff --git a/velox/dwio/dwrf/test/TestColumnReader.cpp b/velox/dwio/dwrf/test/TestColumnReader.cpp index da7b9f743f6..1ead2ca74b1 100644 --- a/velox/dwio/dwrf/test/TestColumnReader.cpp +++ b/velox/dwio/dwrf/test/TestColumnReader.cpp @@ -27,6 +27,7 @@ #include "velox/vector/ComplexVector.h" #include "velox/vector/DictionaryVector.h" #include "velox/vector/FlatVector.h" +#include "velox/vector/tests/utils/VectorTestBase.h" #include #include @@ -311,6 +312,10 @@ struct ReaderTestParams { } }; +inline void PrintTo(const ReaderTestParams& param, std::ostream* os) { + *os << param.toString(); +} + class TestColumnReader : public testing::TestWithParam, public ColumnReaderTestBase { protected: @@ -356,6 +361,55 @@ class TestColumnReader : public testing::TestWithParam, bool parallelDecoding() const override { return parallelDecoding_; } + + // Helper for SelectiveDecimalColumnReader tests that exercise schema + // mismatch between the file footer (e.g. Hive ORC's DECIMAL(38, 18)) + // and the requested table-schema type. + template + void verifyDecimalRequestedType( + const unsigned char (&dataBuffer)[kDataSize], + const unsigned char (&scaleBuffer)[kScaleSize], + const TypePtr& fileType, + const TypePtr& requestedType, + const std::vector& expectedValues) { + auto fileRowType = ROW("col_0", fileType); + auto requestedRowType = ROW("col_0", requestedType); + proto::ColumnEncoding directEncoding; + directEncoding.set_kind(proto::ColumnEncoding_Kind_DIRECT); + EXPECT_CALL(streams_, getEncodingProxy(_)) + .WillRepeatedly(Return(&directEncoding)); + + EXPECT_CALL( + streams_, getStreamProxy(_, proto::Stream_Kind_ROW_INDEX, false)) + .WillRepeatedly(Return(nullptr)); + EXPECT_CALL(streams_, getStreamProxy(_, proto::Stream_Kind_PRESENT, false)) + .WillRepeatedly(Return(nullptr)); + + EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) + .WillRepeatedly( + Return(new SeekableArrayInputStream(dataBuffer, kDataSize))); + EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_NANO_DATA, true)) + .WillRepeatedly( + Return(new SeekableArrayInputStream(scaleBuffer, kScaleSize))); + + auto scanSpec = std::make_unique("root"); + buildReader(requestedRowType, fileRowType, {}, scanSpec.get()); + VectorPtr batch = newBatch(requestedRowType); + skipAndRead(batch, /*readSize=*/expectedValues.size()); + + auto actual = getOnlyChild>(batch); + ASSERT_EQ(expectedValues.size(), batch->size()); + ASSERT_EQ(0, getNullCount(batch)); + ASSERT_EQ(0, getNullCount(actual)); + + auto* pool = &streams_.getMemoryPool(); + auto expected = BaseVector::create>( + requestedType, expectedValues.size(), pool); + for (vector_size_t i = 0; i < expectedValues.size(); ++i) { + expected->set(i, expectedValues[i]); + } + facebook::velox::test::assertEqualVectors(expected, actual); + } }; struct NonSelectiveReaderTestParams { @@ -370,6 +424,12 @@ struct NonSelectiveReaderTestParams { } }; +inline void PrintTo( + const NonSelectiveReaderTestParams& param, + std::ostream* os) { + *os << param.toString(); +} + // For test cases where SelectiveColumnReader does not have support. class TestNonSelectiveColumnReader : public testing::TestWithParam, @@ -423,6 +483,10 @@ struct SchemaMismatchTestParam { } }; +inline void PrintTo(const SchemaMismatchTestParam& param, std::ostream* os) { + *os << param.toString(); +} + class SchemaMismatchTest : public TestWithParam, public ColumnReaderTestBase { protected: @@ -517,14 +581,14 @@ TEST_P(TestColumnReader, testBooleanWithNulls) { // alternate 4 non-null and 4 null via [0xf0 for x in range(512 / 8)] const unsigned char buffer1[] = {0x3d, 0xf0}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // [0x0f for x in range(256 / 8)] const unsigned char buffer2[] = {0x1d, 0x0f}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -628,13 +692,13 @@ TEST_P(TestColumnReader, testBooleanSkipsWithNulls) { // alternate 4 non-null and 4 null via [0xf0 for x in range(512 / 8)] const unsigned char buffer1[] = {0x3d, 0xf0}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // [0x0f for x in range(128 / 8)] const unsigned char buffer2[] = {0x1d, 0x0f}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -683,8 +747,8 @@ TEST_P(TestColumnReader, testByteWithNulls) { // alternate 4 non-null and 4 null via [0xf0 for x in range(512 / 8)] const unsigned char buffer1[] = {0x3d, 0xf0}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // range(256) char buffer[258]; @@ -697,8 +761,8 @@ TEST_P(TestColumnReader, testByteWithNulls) { buffer[i + 2] = static_cast(i); } EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer, std::size(buffer)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -778,8 +842,8 @@ TEST_P(TestColumnReader, testByteSkipsWithNulls) { // alternate 4 non-null and 4 null via [0xf0 for x in range(512 / 8)] const unsigned char buffer1[] = {0x3d, 0xf0}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // range(256) char buffer[258]; @@ -792,8 +856,8 @@ TEST_P(TestColumnReader, testByteSkipsWithNulls) { buffer[i + 2] = static_cast(i); } EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer, std::size(buffer)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -856,7 +920,7 @@ TEST_P(TestColumnReader, testIntegerRLEv2) { int32_t expects_col0[] = {2110, 2120, 2130, 2140}; int32_t expects_col1[] = {11, 12, 13, 14}; int32_t expects_col2[] = {32, 34, 36, 38}; - int32_t size = VELOX_ARRAY_SIZE(col0); + int32_t size = std::size(col0); // set format streams_.setFormat(DwrfFormat::kOrc); @@ -884,8 +948,8 @@ TEST_P(TestColumnReader, testIntegerRLEv2) { // col_0's DATA stream EXPECT_CALL( streams_, getStreamOrcProxy(1, proto::orc::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer0, VELOX_ARRAY_SIZE(buffer0)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer0, std::size(buffer0)))); // col_1's DATA stream std::array data; std::vector v; @@ -903,8 +967,8 @@ TEST_P(TestColumnReader, testIntegerRLEv2) { // col_2's DATA stream EXPECT_CALL( streams_, getStreamOrcProxy(3, proto::orc::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = @@ -929,7 +993,7 @@ TEST_P(TestColumnReader, testIntegerRLEv2) { auto colBatch = getChild>(batch, 0); auto colBatch2 = getChild>(batch, 1); auto colBatch3 = getChild>(batch, 2); - ASSERT_EQ(VELOX_ARRAY_SIZE(expects_col0), colBatch->size()); + ASSERT_EQ(std::size(expects_col0), colBatch->size()); ASSERT_EQ(colBatch->size(), colBatch2->size()); ASSERT_EQ(colBatch2->size(), colBatch3->size()); for (size_t i = 0; i < batch->size(); ++i) { @@ -972,8 +1036,8 @@ TEST_P(TestColumnReader, testIntegerWithNulls) { .WillRepeatedly(Return(nullptr)); const unsigned char buffer1[] = {0x16, 0xf0}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); char buffer2[1024]; size_t size = writeRange(buffer2, 0, 100); @@ -1091,8 +1155,8 @@ TEST_P(TestColumnReader, testIntDictSkipWithNulls) { .WillRepeatedly(Return(nullptr)); const unsigned char buffer1[] = {0x16, 0xaa}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // even row points to dictionary. char buffer2[1024]; @@ -1108,8 +1172,8 @@ TEST_P(TestColumnReader, testIntDictSkipWithNulls) { const unsigned char buffer3[] = {0x0a, 0xaa}; EXPECT_CALL( streams_, getStreamProxy(1, proto::Stream_Kind_IN_DICTIONARY, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer3, VELOX_ARRAY_SIZE(buffer3)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer3, std::size(buffer3)))); EXPECT_CALL(streams_, genMockDictDataSetter(1, 0)) .WillRepeatedly(Return([&](BufferPtr& buffer, MemoryPool* pool) { @@ -1249,21 +1313,21 @@ TEST_P(StringReaderTests, testDictionaryWithNulls) { .WillRepeatedly(Return(nullptr)); const unsigned char buffer1[] = {0x19, 0xf0}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); const unsigned char buffer2[] = {0x2f, 0x00, 0x00, 0x2f, 0x00, 0x01}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); const unsigned char buffer3[] = {0x4f, 0x52, 0x43, 0x4f, 0x77, 0x65, 0x6e}; EXPECT_CALL( streams_, getStreamProxy(1, proto::Stream_Kind_DICTIONARY_DATA, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer3, VELOX_ARRAY_SIZE(buffer3)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer3, std::size(buffer3)))); const unsigned char buffer4[] = {0x02, 0x01, 0x03}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer4, VELOX_ARRAY_SIZE(buffer4)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer4, std::size(buffer4)))); TestStrideIndexProvider provider(10000); EXPECT_CALL(streams_, getStrideIndexProviderProxy()) @@ -1395,8 +1459,8 @@ TEST_P(StringReaderTests, testStringDictSkipNoNulls) { const unsigned char inDict[] = {0x0a, 0xaa}; EXPECT_CALL( streams_, getStreamProxy(1, proto::Stream_Kind_IN_DICTIONARY, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(inDict, VELOX_ARRAY_SIZE(inDict)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(inDict, std::size(inDict)))); auto indexData = index.SerializePartialAsString(); EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_ROW_INDEX, _)) @@ -1484,8 +1548,8 @@ TEST_P(StringReaderTests, testStringDictSkipWithNulls) { .WillRepeatedly(Return(nullptr)); const unsigned char present[] = {0x16, 0xaa}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(present, VELOX_ARRAY_SIZE(present)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(present, std::size(present)))); char data[1024]; data[0] = 0x9c; @@ -1565,8 +1629,8 @@ TEST_P(StringReaderTests, testStringDictSkipWithNulls) { const unsigned char inDict[] = {0x0a, 0xaa}; EXPECT_CALL( streams_, getStreamProxy(1, proto::Stream_Kind_IN_DICTIONARY, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(inDict, VELOX_ARRAY_SIZE(inDict)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(inDict, std::size(inDict)))); auto indexData = index.SerializePartialAsString(); EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_ROW_INDEX, _)) @@ -1637,18 +1701,18 @@ TEST_P(TestNonSelectiveColumnReader, testSubstructsWithNulls) { const unsigned char buffer1[] = {0x16, 0x0f}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); const unsigned char buffer2[] = {0x0a, 0x55}; EXPECT_CALL(streams_, getStreamProxy(2, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); const unsigned char buffer3[] = {0x04, 0xf0}; EXPECT_CALL(streams_, getStreamProxy(3, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer3, VELOX_ARRAY_SIZE(buffer3)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer3, std::size(buffer3)))); char buffer4[256]; size_t size = writeRange(buffer4, 0, 26); @@ -1723,13 +1787,13 @@ TEST_P(TestColumnReader, testSkipWithNulls) { const unsigned char buffer1[] = { 0x03, 0x00, 0xff, 0x3f, 0x08, 0xff, 0xff, 0xfc, 0x03, 0x00}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); EXPECT_CALL(streams_, getStreamProxy(2, _, _)) .WillRepeatedly(Return(nullptr)); EXPECT_CALL(streams_, getStreamProxy(2, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); TestStrideIndexProvider provider(10000); EXPECT_CALL(streams_, getStrideIndexProviderProxy()) @@ -1741,8 +1805,8 @@ TEST_P(TestColumnReader, testSkipWithNulls) { .WillRepeatedly(Return(new SeekableArrayInputStream(buffer2, size))); const unsigned char buffer3[] = {0x61, 0x01, 0x00}; EXPECT_CALL(streams_, getStreamProxy(2, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer3, VELOX_ARRAY_SIZE(buffer3)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer3, std::size(buffer3)))); // fill the dictionary with '00' to '99' char digits[200]; @@ -1754,12 +1818,12 @@ TEST_P(TestColumnReader, testSkipWithNulls) { } EXPECT_CALL( streams_, getStreamProxy(2, proto::Stream_Kind_DICTIONARY_DATA, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(digits, VELOX_ARRAY_SIZE(digits)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(digits, std::size(digits)))); const unsigned char buffer4[] = {0x61, 0x00, 0x02}; EXPECT_CALL(streams_, getStreamProxy(2, proto::Stream_Kind_LENGTH, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer4, VELOX_ARRAY_SIZE(buffer4)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer4, std::size(buffer4)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -1838,12 +1902,12 @@ TEST_P(StringReaderTests, testBinaryDirect) { } EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) .WillRepeatedly( - Return(new SeekableArrayInputStream(blob, VELOX_ARRAY_SIZE(blob)))); + Return(new SeekableArrayInputStream(blob, std::size(blob)))); const unsigned char buffer[] = {0x61, 0x00, 0x02}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer, std::size(buffer)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -1892,8 +1956,8 @@ TEST_P(StringReaderTests, testBinaryDirectWithNulls) { const unsigned char buffer1[] = {0x1d, 0xf0}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); char blob[256]; for (size_t i = 0; i < 8; ++i) { @@ -1904,12 +1968,12 @@ TEST_P(StringReaderTests, testBinaryDirectWithNulls) { } EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) .WillRepeatedly( - Return(new SeekableArrayInputStream(blob, VELOX_ARRAY_SIZE(blob)))); + Return(new SeekableArrayInputStream(blob, std::size(blob)))); const unsigned char buffer2[] = {0x7d, 0x00, 0x02}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -1969,12 +2033,12 @@ TEST_P(TestColumnReader, testShortBlobError) { char blob[100]; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) .WillRepeatedly( - Return(new SeekableArrayInputStream(blob, VELOX_ARRAY_SIZE(blob)))); + Return(new SeekableArrayInputStream(blob, std::size(blob)))); const unsigned char buffer1[] = {0x61, 0x00, 0x02}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -2020,13 +2084,13 @@ TEST_P(StringReaderTests, testStringDirectShortBuffer) { } } EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(blob, VELOX_ARRAY_SIZE(blob), 3))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(blob, std::size(blob), 3))); const unsigned char buffer1[] = {0x61, 0x00, 0x02}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -2075,8 +2139,8 @@ TEST_P(StringReaderTests, testStringDirectShortBufferWithNulls) { const unsigned char buffer1[] = {0x3d, 0xf0}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); char blob[512]; for (size_t i = 0; i < 16; ++i) { @@ -2086,13 +2150,13 @@ TEST_P(StringReaderTests, testStringDirectShortBufferWithNulls) { } } EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(blob, VELOX_ARRAY_SIZE(blob), 30))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(blob, std::size(blob), 30))); const unsigned char buffer2[] = {0x7d, 0x00, 0x02, 0x7d, 0x00, 0x02}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -2154,19 +2218,19 @@ TEST_P(StringReaderTests, testStringDirectNullAcrossWindow) { const unsigned char isNull[2] = {0xff, 0x7f}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(isNull, VELOX_ARRAY_SIZE(isNull)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(isNull, std::size(isNull)))); const char blob[] = "abcdefg"; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(blob, VELOX_ARRAY_SIZE(blob), 4))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(blob, std::size(blob), 4))); // [1] * 7 const unsigned char lenData[] = {0x04, 0x00, 0x01}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(lenData, VELOX_ARRAY_SIZE(lenData)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(lenData, std::size(lenData)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -2250,8 +2314,8 @@ TEST_P(StringReaderTests, testStringDirectSkip) { 0x01, 0x8a, 0x05, 0x7f, 0x01, 0x8c, 0x06, 0x7f, 0x01, 0x8e, 0x07, 0x7f, 0x01, 0x90, 0x08, 0x1b, 0x01, 0x92, 0x09}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -2298,8 +2362,8 @@ TEST_P(StringReaderTests, testStringDirectSkipWithNulls) { // alternate 4 non-null and 4 null via [0xf0 for x in range(2400 / 8)] const unsigned char buffer1[] = {0x7f, 0xf0, 0x7f, 0xf0, 0x25, 0xf0}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // sum(range(1200)) const size_t BLOB_SIZE = 719400; @@ -2323,8 +2387,8 @@ TEST_P(StringReaderTests, testStringDirectSkipWithNulls) { 0x01, 0x8a, 0x05, 0x7f, 0x01, 0x8c, 0x06, 0x7f, 0x01, 0x8e, 0x07, 0x7f, 0x01, 0x90, 0x08, 0x1b, 0x01, 0x92, 0x09}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -2386,8 +2450,8 @@ TEST_P(TestColumnReader, testList) { 0x00, 0x02}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // range(1200) char buffer2[8192]; @@ -2435,8 +2499,8 @@ TEST_P(TestNonSelectiveColumnReader, testListPropagateNulls) { // set getStream const unsigned char buffer[] = {0xff, 0x00}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer, std::size(buffer)))); EXPECT_CALL(streams_, getStreamProxy(2, proto::Stream_Kind_LENGTH, true)) .WillRepeatedly(Return(new SeekableArrayInputStream(buffer, 0))); @@ -2477,8 +2541,8 @@ TEST_P(TestNonSelectiveColumnReader, testListWithNulls) { // [0xaa for x in range(2048/8)] const unsigned char buffer1[] = {0x7f, 0xaa, 0x7b, 0xaa}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); EXPECT_CALL(streams_, getStreamProxy(2, proto::Stream_Kind_PRESENT, false)) .WillRepeatedly(Return(nullptr)); @@ -2493,8 +2557,8 @@ TEST_P(TestNonSelectiveColumnReader, testListWithNulls) { 0x00, 0x7f, 0x00, 0x00, 0x7f, 0x00, 0x03, 0x6e, 0x00, 0x03, 0xff, 0x13}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // range(2048) char buffer3[8192]; @@ -2636,8 +2700,8 @@ TEST_P(TestNonSelectiveColumnReader, testListSkipWithNulls) { // [0xaa for x in range(2048/8)] const unsigned char buffer1[] = {0x7f, 0xaa, 0x7b, 0xaa}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); EXPECT_CALL(streams_, getStreamProxy(2, proto::Stream_Kind_PRESENT, false)) .WillRepeatedly(Return(nullptr)); @@ -2652,8 +2716,8 @@ TEST_P(TestNonSelectiveColumnReader, testListSkipWithNulls) { 0x00, 0x7f, 0x00, 0x00, 0x7f, 0x00, 0x03, 0x6e, 0x00, 0x03, 0xff, 0x13}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // range(2048) char buffer3[8192]; @@ -2737,8 +2801,8 @@ TEST_P(TestNonSelectiveColumnReader, testListSkipWithNullsNoData) { // [0xaa for x in range(2048/8)] const unsigned char buffer1[] = {0x7f, 0xaa, 0x7b, 0xaa}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); EXPECT_CALL(streams_, getStreamProxy(2, proto::Stream_Kind_PRESENT, false)) .WillRepeatedly(Return(nullptr)); @@ -2753,8 +2817,8 @@ TEST_P(TestNonSelectiveColumnReader, testListSkipWithNullsNoData) { 0x00, 0x7f, 0x00, 0x00, 0x7f, 0x00, 0x03, 0x6e, 0x00, 0x03, 0xff, 0x13}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); EXPECT_CALL(streams_, getStreamProxy(2, proto::Stream_Kind_DATA, true)) .WillRepeatedly(Return(nullptr)); @@ -2815,8 +2879,8 @@ TEST_P(TestNonSelectiveColumnReader, testListWithAllNulls) { // set getStream const unsigned char buffer[] = {0xff, 0x00}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer, std::size(buffer)))); EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) .WillRepeatedly(Return(new SeekableArrayInputStream(buffer, 0))); @@ -2869,8 +2933,8 @@ TEST_P(TestColumnReader, testMap) { 0x00, 0x02}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // range(1200) char buffer2[8192]; @@ -2926,8 +2990,8 @@ TEST_P(TestNonSelectiveColumnReader, testMapWithNulls) { // [0xaa for x in range(2048/8)] const unsigned char buffer1[] = {0x7f, 0xaa, 0x7b, 0xaa}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); EXPECT_CALL(streams_, getStreamProxy(2, proto::Stream_Kind_PRESENT, false)) .WillRepeatedly(Return(nullptr)); @@ -2935,8 +2999,8 @@ TEST_P(TestNonSelectiveColumnReader, testMapWithNulls) { // [0x55 for x in range(2048/8)] const unsigned char buffer2[] = {0x7f, 0x55, 0x7b, 0x55}; EXPECT_CALL(streams_, getStreamProxy(3, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // [1 for x in range(260)] + // [4 for x in range(260)] + @@ -2948,8 +3012,8 @@ TEST_P(TestNonSelectiveColumnReader, testMapWithNulls) { 0x00, 0x7f, 0x00, 0x00, 0x7f, 0x00, 0x03, 0x6e, 0x00, 0x03, 0xff, 0x13}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer3, VELOX_ARRAY_SIZE(buffer3)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer3, std::size(buffer3)))); // range(2048) char buffer4[8192]; @@ -3131,8 +3195,8 @@ TEST_P(TestNonSelectiveColumnReader, testMapSkipWithNulls) { // [0xaa for x in range(2048/8)] const unsigned char buffer1[] = {0x7f, 0xaa, 0x7b, 0xaa}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // [1 for x in range(260)] + // [4 for x in range(260)] + @@ -3144,8 +3208,8 @@ TEST_P(TestNonSelectiveColumnReader, testMapSkipWithNulls) { 0x00, 0x7f, 0x00, 0x00, 0x7f, 0x00, 0x03, 0x6e, 0x00, 0x03, 0xff, 0x13}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // range(2048) char buffer3[8192]; @@ -3254,8 +3318,8 @@ TEST_P(TestNonSelectiveColumnReader, testMapSkipWithNullsNoData) { // [0xaa for x in range(2048/8)] const unsigned char buffer1[] = {0x7f, 0xaa, 0x7b, 0xaa}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // [1 for x in range(260)] + // [4 for x in range(260)] + @@ -3267,8 +3331,8 @@ TEST_P(TestNonSelectiveColumnReader, testMapSkipWithNullsNoData) { 0x00, 0x7f, 0x00, 0x00, 0x7f, 0x00, 0x03, 0x6e, 0x00, 0x03, 0xff, 0x13}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = HiveTypeParser().parse("struct>"); @@ -3322,8 +3386,8 @@ TEST_P(TestNonSelectiveColumnReader, testMapWithAllNulls) { const unsigned char buffer1[] = {0xff, 0x00}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) .WillRepeatedly(Return(new SeekableArrayInputStream(buffer1, 0))); @@ -3365,9 +3429,7 @@ TEST_P(TestColumnReader, testFloatBatchNotAligned) { EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) .WillRepeatedly(Return(new SeekableArrayInputStream( - byteValues, - VELOX_ARRAY_SIZE(byteValues), - VELOX_ARRAY_SIZE(byteValues) / 2))); + byteValues, std::size(byteValues), std::size(byteValues) / 2))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -3403,8 +3465,8 @@ TEST_P(TestColumnReader, testFloatWithNulls) { // 13 non-nulls followed by 19 nulls const unsigned char buffer1[] = {0xfc, 0xff, 0xf8, 0x0, 0x0}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); const float test_vals[] = { 1.0f, @@ -3427,8 +3489,8 @@ TEST_P(TestColumnReader, testFloatWithNulls) { 0x0, 0x80, 0xff, 0xff, 0xff, 0x7f, 0x7f, 0xff, 0xff, 0x7f, 0xff, 0x1, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x80}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -3472,8 +3534,8 @@ TEST_P(TestColumnReader, testFloatSkipWithNulls) { // 2 non-nulls, 2 nulls, 2 non-nulls, 2 nulls const unsigned char buffer1[] = {0xff, 0xcc}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // 1, 2.5, -100.125, 10000 const unsigned char buffer2[] = { @@ -3494,8 +3556,8 @@ TEST_P(TestColumnReader, testFloatSkipWithNulls) { 0x1c, 0x46}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -3563,8 +3625,8 @@ TEST_P(TestColumnReader, testDoubleWithNulls) { // 13 non-nulls followed by 19 nulls const unsigned char buffer1[] = {0xfc, 0xff, 0xf8, 0x0, 0x0}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); const double test_vals[] = { 1.0, @@ -3591,8 +3653,8 @@ TEST_P(TestColumnReader, testDoubleWithNulls) { 0xff, 0xff, 0xef, 0xff, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x80}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -3637,16 +3699,16 @@ TEST_P(TestColumnReader, testDoubleSkipWithNulls) { // 1 non-null, 5 nulls, 2 non-nulls const unsigned char buffer1[] = {0xff, 0x83}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // 1, 2, -2 const unsigned char buffer2[] = { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0x3f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xc0}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -3713,8 +3775,8 @@ TEST_P(TestColumnReader, testTimestampSkipWithNulls) { // 2 non-nulls, 2 nulls, 2 non-nulls, 2 nulls const unsigned char buffer1[] = {0xff, 0xcc}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); const unsigned char buffer2[] = { 0xfc, @@ -3735,13 +3797,13 @@ TEST_P(TestColumnReader, testTimestampSkipWithNulls) { 0xd4, 0x30}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); const unsigned char buffer3[] = {0x1, 0x8, 0x5e}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_NANO_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer3, VELOX_ARRAY_SIZE(buffer3)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer3, std::size(buffer3)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -3831,15 +3893,15 @@ TEST_P(TestColumnReader, testTimestamp) { 0xba, 0xa0, 0x1a, 0x9d, 0x88, 0xa6, 0x82, 0x1a, 0x9d, 0xba, 0x9c, 0xe4, 0x19, 0x9d, 0xee, 0xe1, 0xcd, 0x18}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); const unsigned char buffer2[] = { 0xf6, 0x00, 0xa8, 0xd1, 0xf9, 0xd6, 0x03, 0x00, 0x9e, 0x01, 0xec, 0x76, 0xf4, 0x76, 0xfc, 0x76, 0x84, 0x77, 0x8c, 0x77, 0xfd, 0x0b}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_NANO_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -3909,13 +3971,13 @@ TEST_P(TestColumnReader, testDecimal64) { } } EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return(new SeekableArrayInputStream( - numBuffer, VELOX_ARRAY_SIZE(numBuffer), 3))); + .WillRepeatedly(Return( + new SeekableArrayInputStream(numBuffer, std::size(numBuffer), 3))); // col_0's Secondary Stream const unsigned char buffer2[] = {0x3e, 0x00, 0x04}; // [0x02] * 65 EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_NANO_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -3948,19 +4010,19 @@ TEST_P(TestColumnReader, testDecimal64WithSkip) { const unsigned char presentBuffer[] = {0xfe, 0xff, 0x80}; // [0xff] EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) .WillRepeatedly(Return(new SeekableArrayInputStream( - presentBuffer, VELOX_ARRAY_SIZE(presentBuffer)))); + presentBuffer, std::size(presentBuffer)))); const unsigned char numBuffer[] = { 0xf8, 0xe8, 0xe2, 0xcf, 0xf4, 0xcb, 0xb6, 0xda, 0x0d, 0x86, 0xc1, 0xcc, 0xcd, 0x9e, 0xd5, 0xc5, 0x11, 0xb4, 0xf6, 0xfc, 0xf3, 0xb9, 0xba, 0x16, 0xca, 0xe7, 0xa3, 0xa6, 0xdf, 0x1c, 0xea, 0xad, 0xc0, 0xe5, 0x24, 0xf8, 0x94, 0x8c, 0x2f, 0x86, 0xa4, 0x3c, 0x94, 0x4d, 0x62}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return(new SeekableArrayInputStream( - numBuffer, VELOX_ARRAY_SIZE(numBuffer)))); + .WillRepeatedly(Return( + new SeekableArrayInputStream(numBuffer, std::size(numBuffer)))); const unsigned char buffer1[] = {0x06, 0x00, 0x14}; // [0x0a] * 9 EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_NANO_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -4004,7 +4066,7 @@ TEST_P(TestColumnReader, testDecimal128WithSkip) { const unsigned char presentBuffer[] = {0xfe, 0xff, 0xf8}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) .WillRepeatedly(Return(new SeekableArrayInputStream( - presentBuffer, VELOX_ARRAY_SIZE(presentBuffer)))); + presentBuffer, std::size(presentBuffer)))); const unsigned char numBuffer[] = { 0xf8, 0xe8, 0xe2, 0xcf, 0xf4, 0xcb, 0xb6, 0xda, 0x0d, 0x86, 0xc1, 0xcc, 0xcd, 0x9e, 0xd5, 0xc5, 0x11, 0xb4, 0xf6, 0xfc, 0xf3, 0xb9, 0xba, 0x16, @@ -4018,12 +4080,12 @@ TEST_P(TestColumnReader, testDecimal128WithSkip) { 0x93, 0xe8, 0xa3, 0xec, 0xd0, 0x96, 0xd4, 0xcc, 0xf6, 0xac, 0x02, }; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return(new SeekableArrayInputStream( - numBuffer, VELOX_ARRAY_SIZE(numBuffer)))); + .WillRepeatedly(Return( + new SeekableArrayInputStream(numBuffer, std::size(numBuffer)))); const unsigned char buffer1[] = {0x0a, 0x00, 0x4a}; // [0x02] * 13 EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_NANO_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -4065,6 +4127,71 @@ TEST_P(TestColumnReader, testDecimal128WithSkip) { DecimalUtil::toString(intBatch->valueAt(4), decimalType)); } +// Verify that when the file type doesn't match the metastore type, +// the metastore type wins and data is rescaled accordingly. + +// Integer column (per-row scale=0) stored in an ORC file whose footer +// declares DECIMAL(38, 18), but the metastore says DECIMAL(20, 0). The +// reader must produce the original integer values without rescaling. +TEST_P(TestColumnReader, longDecimalRequestedTypeScaleZero) { + const unsigned char dataBuffer[] = {0x02, 0x04, 0x06, 0x08}; + const unsigned char scaleBuffer[] = {0x01, 0x00, 0x00}; + verifyDecimalRequestedType( + dataBuffer, scaleBuffer, DECIMAL(38, 18), DECIMAL(20, 0), {1, 2, 3, 4}); +} + +// Per-row scale (5) already matches the requestedType scale (5); no +// rescaling expected. +TEST_P(TestColumnReader, longDecimalRequestedTypeScaleMatchesData) { + const unsigned char dataBuffer[] = {0x02, 0x04, 0x06, 0x08}; + const unsigned char scaleBuffer[] = {0x01, 0x00, 0x0A}; + verifyDecimalRequestedType( + dataBuffer, scaleBuffer, DECIMAL(38, 18), DECIMAL(25, 5), {1, 2, 3, 4}); +} + +// Per-row scale (3) is lower than the requestedType scale (5). The reader +// must upscale by multiplying by 10^(5-3) = 100. +TEST_P(TestColumnReader, longDecimalRequestedTypeUpscale) { + const unsigned char dataBuffer[] = {0x02, 0x04, 0x06, 0x08}; + const unsigned char scaleBuffer[] = {0x01, 0x00, 0x06}; + verifyDecimalRequestedType( + dataBuffer, + scaleBuffer, + DECIMAL(38, 18), + DECIMAL(25, 5), + {100, 200, 300, 400}); +} + +// Short decimal (BIGINT, precision<=18). File declares DECIMAL(12, 5), +// metastore says DECIMAL(10, 2). Reader must downscale by 10^(5-2)=1000. +TEST_P(TestColumnReader, shortDecimalRequestedTypeDownscale) { + const unsigned char dataBuffer[] = { + 0xD0, 0x0F, 0xA0, 0x1F, 0xF0, 0x2E, 0xC0, 0x3E}; + const unsigned char scaleBuffer[] = {0x01, 0x00, 0x0A}; + verifyDecimalRequestedType( + dataBuffer, scaleBuffer, DECIMAL(12, 5), DECIMAL(10, 2), {1, 2, 3, 4}); +} + +TEST_P(TestColumnReader, decimalRequestedTypeNonDecimalRejected) { + proto::ColumnEncoding directEncoding; + directEncoding.set_kind(proto::ColumnEncoding_Kind_DIRECT); + EXPECT_CALL(streams_, getEncodingProxy(_)) + .WillRepeatedly(Return(&directEncoding)); + EXPECT_CALL(streams_, getStreamProxy(_, proto::Stream_Kind_ROW_INDEX, false)) + .WillRepeatedly(Return(nullptr)); + EXPECT_CALL(streams_, getStreamProxy(_, proto::Stream_Kind_PRESENT, false)) + .WillRepeatedly(Return(nullptr)); + + auto fileType = ROW("col_0", DECIMAL(38, 18)); + auto scanSpec = std::make_unique("root"); + VELOX_ASSERT_THROW( + buildReader(ROW("col_0", BIGINT()), fileType, {}, scanSpec.get()), + "Schema mismatch, From Kind: HUGEINT, To Kind: BIGINT"); + VELOX_ASSERT_THROW( + buildReader(ROW("col_0", DOUBLE()), fileType, {}, scanSpec.get()), + "Schema mismatch, From Kind: HUGEINT, To Kind: DOUBLE"); +} + TEST_P(TestColumnReader, testLargeSkip) { // set getEncoding proto::ColumnEncoding directEncoding; @@ -4087,8 +4214,8 @@ TEST_P(TestColumnReader, testLargeSkip) { length[pos + 2] = 0x01; } EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(length, VELOX_ARRAY_SIZE(length)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(length, std::size(length)))); char data[1024 * 1024]; size_t size = writeRange(data, 0, 73200); diff --git a/velox/dwio/dwrf/test/TestDictionaryEncodingUtils.cpp b/velox/dwio/dwrf/test/TestDictionaryEncodingUtils.cpp index dc42b6941b1..0fb09b43fa4 100644 --- a/velox/dwio/dwrf/test/TestDictionaryEncodingUtils.cpp +++ b/velox/dwio/dwrf/test/TestDictionaryEncodingUtils.cpp @@ -22,7 +22,9 @@ using namespace testing; using namespace facebook::velox::memory; -namespace facebook::velox::dwrf { +namespace facebook::velox::dwrf::test { +namespace { + class DictionaryEncodingUtilsTest : public testing::Test { protected: static void SetUpTestCase() { @@ -36,7 +38,7 @@ TEST_F(DictionaryEncodingUtilsTest, StringGetSortedIndexLookupTable) { bool sort, std::function ordering, - const std::vector& addKeySequence, + const std::vector& addKeySequence, const std::vector& lookupTable) : sort{sort}, ordering{ordering}, @@ -45,7 +47,7 @@ TEST_F(DictionaryEncodingUtilsTest, StringGetSortedIndexLookupTable) { bool sort; std::function ordering; - std::vector addKeySequence; + std::vector addKeySequence; std::vector lookupTable; }; @@ -151,7 +153,7 @@ TEST_F(DictionaryEncodingUtilsTest, StringStrideDictOptimization) { bool sort, std::function ordering, - const std::vector& addKeySequence, + const std::vector& addKeySequence, const std::vector& lookupTable, const std::vector& inDict, size_t finalDictSize, @@ -166,7 +168,7 @@ TEST_F(DictionaryEncodingUtilsTest, StringStrideDictOptimization) { bool sort; std::function ordering; - std::vector addKeySequence; + std::vector addKeySequence; std::vector lookupTable; std::vector inDict; size_t finalDictSize; @@ -315,7 +317,7 @@ TEST_F(DictionaryEncodingUtilsTest, StringStrideDictOptimization) { dwio::common::DataBuffer strideDictSizes{ *pool, rowCount / kStrideSize + 1}; - std::vector expected{testCase.finalDictSize}; + std::vector expected{testCase.finalDictSize}; std::vector expectedSize(testCase.finalDictSize); for (size_t i = 0; i < testCase.lookupTable.size(); ++i) { if (testCase.inDict[i]) { @@ -368,4 +370,5 @@ TEST_F(DictionaryEncodingUtilsTest, StringStrideDictOptimization) { } } -} // namespace facebook::velox::dwrf +} // namespace +} // namespace facebook::velox::dwrf::test diff --git a/velox/dwio/dwrf/test/TestDwrfColumnStatistics.cpp b/velox/dwio/dwrf/test/TestDwrfColumnStatistics.cpp index cf13c6df826..a71ddc8a4bc 100644 --- a/velox/dwio/dwrf/test/TestDwrfColumnStatistics.cpp +++ b/velox/dwio/dwrf/test/TestDwrfColumnStatistics.cpp @@ -195,12 +195,12 @@ void checkEntries( for (const auto& entry : entries) { EXPECT_NE( std::find_if( - expectedEntries.begin(), - expectedEntries.end(), + expectedEntries.cbegin(), + expectedEntries.cend(), [&](const ColumnStatistics& expectedStats) { return expectedStats == entry; }), - expectedEntries.end()); + expectedEntries.cend()); } } @@ -490,11 +490,11 @@ TEST(MapStatisticsBuilderTest, mergeKeyStats) { statsBuilder.increaseRawSize(8); mapStatsBuilder.addValues(createKeyInfo(1), statsBuilder); - keyStats = dynamic_cast( + auto& keyStats1 = dynamic_cast( *mapStatsBuilder.getEntryStatistics().at(KeyInfo{1})); - ASSERT_EQ(2, keyStats.getNumberOfValues()); - ASSERT_TRUE(keyStats.getRawSize().has_value()); - ASSERT_EQ(8, keyStats.getRawSize().value()); - EXPECT_TRUE(keyStats.getSize().has_value()); - EXPECT_EQ(42, keyStats.getSize().value()); + ASSERT_EQ(2, keyStats1.getNumberOfValues()); + ASSERT_TRUE(keyStats1.getRawSize().has_value()); + ASSERT_EQ(8, keyStats1.getRawSize().value()); + EXPECT_TRUE(keyStats1.getSize().has_value()); + EXPECT_EQ(42, keyStats1.getSize().value()); } diff --git a/velox/dwio/dwrf/test/TestReadFile.h b/velox/dwio/dwrf/test/TestReadFile.h index 8501b231ed8..aca3b13cf2a 100644 --- a/velox/dwio/dwrf/test/TestReadFile.h +++ b/velox/dwio/dwrf/test/TestReadFile.h @@ -31,7 +31,7 @@ class TestReadFile : public velox::ReadFile { TestReadFile( uint64_t seed, uint64_t length, - std::shared_ptr ioStats) + std::shared_ptr ioStats) : seed_(seed), length_(length), ioStats_(std::move(ioStats)) {} uint64_t size() const override { @@ -42,15 +42,15 @@ class TestReadFile : public velox::ReadFile { uint64_t offset, uint64_t length, void* buffer, - filesystems::File::IoStats* stats = nullptr) const override { + const FileIoContext& context = {}) const override { const uint64_t content = offset + seed_; const uint64_t available = std::min(length_ - offset, length); int fill; for (fill = 0; fill < available; ++fill) { reinterpret_cast(buffer)[fill] = content + fill; } - if (stats) { - stats->addCounter( + if (context.ioStats) { + context.ioStats->addCounter( "read", RuntimeCounter(fill, RuntimeCounter::Unit::kBytes)); } return std::string_view(static_cast(buffer), fill); @@ -59,13 +59,12 @@ class TestReadFile : public velox::ReadFile { uint64_t preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats = nullptr) const override { - auto res = ReadFile::preadv(offset, buffers, stats); - if (stats) { - stats->addCounter( + const FileIoContext& context = {}) const override { + auto res = ReadFile::preadv(offset, buffers, context); + if (context.ioStats) { + context.ioStats->addCounter( "read", - RuntimeCounter( - static_cast(res), RuntimeCounter::Unit::kBytes)); + RuntimeCounter(saturateCast(res), RuntimeCounter::Unit::kBytes)); } ++numIos_; return res; @@ -103,7 +102,7 @@ class TestReadFile : public velox::ReadFile { private: const uint64_t seed_; const uint64_t length_; - std::shared_ptr ioStats_; + std::shared_ptr ioStats_; mutable std::atomic numIos_{0}; }; diff --git a/velox/dwio/dwrf/test/TestRle.cpp b/velox/dwio/dwrf/test/TestRle.cpp index b779dedff8e..d1fe6f2c813 100644 --- a/velox/dwio/dwrf/test/TestRle.cpp +++ b/velox/dwio/dwrf/test/TestRle.cpp @@ -88,7 +88,7 @@ TEST_F(RLEv2Test, basicDelta0) { } const unsigned char bytes[] = {0xc0, 0x13, 0x00, 0x02}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, count), 1); checkResults(values, decodeRLEv2(bytes, l, 3, count), 3); @@ -106,7 +106,7 @@ TEST_F(RLEv2Test, basicDelta1) { const unsigned char bytes[] = { 0xce, 0x04, 0xe7, 0x07, 0xc8, 0x01, 0x32, 0x19, 0x0f}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, values.size()), 1); checkResults(values, decodeRLEv2(bytes, l, 3, values.size()), 3); @@ -127,7 +127,7 @@ TEST_F(RLEv2Test, basicDelta2) { const unsigned char bytes[] = { 0xce, 0x04, 0xe7, 0x07, 0xc7, 0x01, 0x32, 0x19, 0x23}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, values.size()), 1); checkResults(values, decodeRLEv2(bytes, l, 3, values.size()), 3); @@ -148,7 +148,7 @@ TEST_F(RLEv2Test, basicDelta3) { const unsigned char bytes[] = { 0xce, 0x04, 0xe8, 0x07, 0xc7, 0x01, 0x32, 0x19, 0x0f}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, values.size()), 1); checkResults(values, decodeRLEv2(bytes, l, 3, values.size()), 3); @@ -169,7 +169,7 @@ TEST_F(RLEv2Test, basicDelta4) { const unsigned char bytes[] = { 0xce, 0x04, 0xe8, 0x07, 0xc8, 0x01, 0x32, 0x19, 0x23}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, values.size()), 1); checkResults(values, decodeRLEv2(bytes, l, 3, values.size()), 3); @@ -188,7 +188,7 @@ TEST_F(RLEv2Test, delta0Width) { createRleDecoder( std::unique_ptr( new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer))), + buffer, std::size(buffer))), RleVersion_2, *pool, true /* doesn't matter */, @@ -219,7 +219,7 @@ TEST_F(RLEv2Test, basicDelta0WithNulls) { } const unsigned char bytes[] = {0xc0, 0x13, 0x00, 0x02}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); const size_t count = values.size(); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, count, nulls), 1, nulls); @@ -243,7 +243,7 @@ TEST_F(RLEv2Test, shortRepeats) { const unsigned char bytes[] = {0x04, 0x00, 0x04, 0x02, 0x04, 0x04, 0x04, 0x06, 0x04, 0x08, 0x04, 0x0a, 0x04, 0x0c, 0x04, 0x0e, 0x04, 0x10, 0x04, 0x12}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, count), 1); checkResults(values, decodeRLEv2(bytes, l, 3, count), 3); @@ -266,7 +266,7 @@ TEST_F(RLEv2Test, multiByteShortRepeats) { 0x00, 0x00, 0x3c, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x3c, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, count), 1); checkResults(values, decodeRLEv2(bytes, l, 3, count), 3); @@ -280,7 +280,7 @@ TEST_F(RLEv2Test, 0to2Repeat1Direct) { std::unique_ptr> rle = createRleDecoder( std::unique_ptr( new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer))), + buffer, std::size(buffer))), RleVersion_2, *pool, true /* doesn't matter */, @@ -302,7 +302,7 @@ TEST_F(RLEv2Test, bitSize2Direct) { } const unsigned char bytes[] = {0x42, 0x13, 0x22, 0x22, 0x22, 0x22, 0x22}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, count), 1); checkResults(values, decodeRLEv2(bytes, l, 3, count), 3); @@ -320,7 +320,7 @@ TEST_F(RLEv2Test, bitSize4Direct) { const unsigned char bytes[] = { 0x46, 0x13, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, count), 1); @@ -360,7 +360,7 @@ TEST_F(RLEv2Test, multipleRunsDirect) { 0x04, 0x04, 0x04}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, values.size()), 1); @@ -382,7 +382,7 @@ TEST_F(RLEv2Test, largeNegativesDirect) { std::unique_ptr> rle = createRleDecoder( std::unique_ptr( new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer))), + buffer, std::size(buffer))), RleVersion_2, *pool, true /* doesn't matter */, @@ -408,7 +408,7 @@ TEST_F(RLEv2Test, overflowDirect) { 0x7e, 0x03, 0x7d, 0x45, 0x3c, 0x12, 0x41, 0x48, 0xf4, 0xbe, 0x7d, 0x45, 0x3c, 0x12, 0x41, 0x48, 0xf4, 0xae, 0x50, 0xce, 0xad, 0x2a, 0x30, 0x0e, 0xd2, 0x96, 0xfe, 0xd8, 0xd2, 0x38, 0x54, 0x6e, 0x3d, 0x81}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, values.size()), 1); checkResults(values, decodeRLEv2(bytes, l, 3, values.size()), 3); @@ -445,7 +445,7 @@ TEST_F(RLEv2Test, basicPatched0) { 0x5a, 0xfc, 0xe8}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, values.size()), 1); checkResults(values, decodeRLEv2(bytes, l, 3, values.size()), 3); @@ -484,7 +484,7 @@ TEST_F(RLEv2Test, basicPatched1) { 0xe0, 0x78, 0x00, 0x1c, 0x0f, 0x08, 0x06, 0x81, 0xc6, 0x90, 0x80, 0x68, 0x24, 0x1b, 0x0b, 0x26, 0x83, 0x21, 0x30, 0xe0, 0x98, 0x3c, 0x6f, 0x06, 0xb7, 0x03, 0x70}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, values.size()), 1); checkResults(values, decodeRLEv2(bytes, l, 3, values.size()), 3); @@ -544,7 +544,7 @@ TEST_F(RLEv2Test, mixedPatchedAndShortRepeats) { 0x00, 0x0c, 0x02, 0x08, 0x18, 0x00, 0x40, 0x00, 0x01, 0x00, 0x00, 0x08, 0x30, 0x33, 0x80, 0x00, 0x02, 0x0c, 0x10, 0x20, 0x20, 0x47, 0x80, 0x13, 0x4c}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, values.size()), 1); checkResults(values, decodeRLEv2(bytes, l, 3, values.size()), 3); @@ -579,7 +579,7 @@ TEST_F(RLEv2Test, basicDirectSeek) { 0x04, 0x04, 0x04}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); std::unique_ptr> rle = createRleDecoder( std::unique_ptr( @@ -688,7 +688,7 @@ TEST_F(RLEv1Test, simpleTest) { createRleDecoder( std::unique_ptr( new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer))), + buffer, std::size(buffer))), RleVersion_1, *pool, true, @@ -712,7 +712,7 @@ TEST_F(RLEv1Test, signedNullLiteralTest) { std::unique_ptr> rle = createRleDecoder( std::unique_ptr( new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer))), + buffer, std::size(buffer))), RleVersion_1, *pool, true, @@ -733,7 +733,7 @@ TEST_F(RLEv1Test, splitHeader) { createRleDecoder( std::unique_ptr( new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer), 4)), + buffer, std::size(buffer), 4)), RleVersion_1, *pool, true, @@ -751,8 +751,7 @@ TEST_F(RLEv1Test, splitRuns) { const unsigned char buffer[] = { 0x7d, 0x01, 0xff, 0x01, 0xfb, 0x01, 0x02, 0x03, 0x04, 0x05}; dwio::common::SeekableInputStream* const stream = - new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer)); + new dwio::common::SeekableArrayInputStream(buffer, std::size(buffer)); std::unique_ptr> rle = createRleDecoder( std::unique_ptr(stream), @@ -784,8 +783,7 @@ TEST_F(RLEv1Test, testSigned) { auto pool = memory::memoryManager()->addLeafPool(); const unsigned char buffer[] = {0x7f, 0xff, 0x20}; dwio::common::SeekableInputStream* const stream = - new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer)); + new dwio::common::SeekableArrayInputStream(buffer, std::size(buffer)); std::unique_ptr> rle = createRleDecoder( std::unique_ptr(stream), RleVersion_1, @@ -808,8 +806,7 @@ TEST_F(RLEv1Test, testNull) { auto pool = memory::memoryManager()->addLeafPool(); const unsigned char buffer[] = {0x75, 0x02, 0x00}; dwio::common::SeekableInputStream* const stream = - new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer)); + new dwio::common::SeekableArrayInputStream(buffer, std::size(buffer)); std::unique_ptr> rle = createRleDecoder( std::unique_ptr(stream), RleVersion_1, @@ -842,8 +839,7 @@ TEST_F(RLEv1Test, testAllNulls) { 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x3d, 0x00, 0x12}; dwio::common::SeekableInputStream* const stream = - new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer)); + new dwio::common::SeekableArrayInputStream(buffer, std::size(buffer)); std::unique_ptr> rle = createRleDecoder( std::unique_ptr(stream), @@ -1095,8 +1091,7 @@ TEST_F(RLEv1Test, skipTest) { 128, 228, 63, 128, 232, 63, 128, 236, 63, 128, 240, 63, 128, 244, 63, 128, 248, 63, 128, 252, 63}; dwio::common::SeekableInputStream* const stream = - new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer)); + new dwio::common::SeekableArrayInputStream(buffer, std::size(buffer)); std::unique_ptr> rle = createRleDecoder( std::unique_ptr(stream), RleVersion_1, @@ -1872,8 +1867,8 @@ TEST_F(RLEv1Test, seekTest) { 151, 12, 193, 190, 224, 143, 9, 129, 245, 133, 204, 8, 182, 209, 250, 178, 8, 148, 139, 144, 193, 11, 230, 182, 245, 164, 7, 149, 204, 161, 226, 14, 175, 229, 148, 166, 13, 148, 140, 189, 216, 3}; - auto* stream = new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer)); + auto* stream = + new dwio::common::SeekableArrayInputStream(buffer, std::size(buffer)); const long junk[] = { -1192035722, 1672896916, 1491444859, -1244121273, -791680696, 1681943525, -571055948, -1744759283, -998345856, 240559198, @@ -2991,7 +2986,7 @@ TEST_F(RLEv1Test, seekTest) { }; std::vector data(2048); rle->next(data.data(), data.size(), nullptr); - ASSERT_EQ(getNumReadBytes(), VELOX_ARRAY_SIZE(buffer)); + ASSERT_EQ(getNumReadBytes(), std::size(buffer)); for (size_t i = 0; i < data.size(); ++i) { if (i < 1024) { EXPECT_EQ(i / 4, data[i]) << "Wrong output at " << i; @@ -3023,7 +3018,7 @@ TEST_F(RLEv1Test, seekTest) { // Seek to end std::vector position; - position.push_back(VELOX_ARRAY_SIZE(buffer)); + position.push_back(std::size(buffer)); position.push_back(0); dwio::common::PositionProvider pp{position}; rle->seekToRowGroup(pp); @@ -3033,7 +3028,7 @@ TEST_F(RLEv1Test, seekTest) { // Seek to end + 1 position.clear(); - position.push_back(VELOX_ARRAY_SIZE(buffer)); + position.push_back(std::size(buffer)); position.push_back(1); dwio::common::PositionProvider pp2{position}; // Seek is fine (because it's lazy), but read should fail @@ -3049,7 +3044,7 @@ TEST_F(RLEv1Test, testLeadingNulls) { createRleDecoder( std::unique_ptr( new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer))), + buffer, std::size(buffer))), RleVersion_1, *pool, true, diff --git a/velox/dwio/dwrf/test/TestStringDictionaryEncoder.cpp b/velox/dwio/dwrf/test/TestStringDictionaryEncoder.cpp index 2dfd93797ec..f8c33d91ff6 100644 --- a/velox/dwio/dwrf/test/TestStringDictionaryEncoder.cpp +++ b/velox/dwio/dwrf/test/TestStringDictionaryEncoder.cpp @@ -20,8 +20,6 @@ DECLARE_bool(velox_enable_memory_usage_track_in_default_memory_pool); -using namespace facebook::velox::memory; - namespace facebook::velox::dwrf { class TestStringDictionaryEncoder : public ::testing::Test { @@ -35,10 +33,10 @@ class TestStringDictionaryEncoder : public ::testing::Test { TEST_F(TestStringDictionaryEncoder, AddKey) { struct TestCase { explicit TestCase( - const std::vector& addKeySequence, + const std::vector& addKeySequence, const std::vector& encodedSequence) : addKeySequence{addKeySequence}, encodedSequence{encodedSequence} {} - std::vector addKeySequence; + std::vector addKeySequence; std::vector encodedSequence; }; @@ -50,7 +48,7 @@ TEST_F(TestStringDictionaryEncoder, AddKey) { TestCase{{"doe", "sow", "sow", "doe", "sow"}, {0, 1, 1, 0, 1}}}; for (const auto& testCase : testCases) { - auto pool = memoryManager()->addLeafPool(); + auto pool = memory::memoryManager()->addLeafPool(); StringDictionaryEncoder stringDictEncoder{*pool, *pool}; std::vector actualEncodedSequence{}; for (const auto& key : testCase.addKeySequence) { @@ -63,14 +61,14 @@ TEST_F(TestStringDictionaryEncoder, AddKey) { TEST_F(TestStringDictionaryEncoder, GetIndex) { struct TestCase { explicit TestCase( - const std::vector& addKeySequence, - const std::vector& getIndexSequence, + const std::vector& addKeySequence, + const std::vector& getIndexSequence, const std::vector& encodedSequence) : addKeySequence{addKeySequence}, getIndexSequence{getIndexSequence}, encodedSequence{encodedSequence} {} - std::vector addKeySequence; - std::vector getIndexSequence; + std::vector addKeySequence; + std::vector getIndexSequence; std::vector encodedSequence; }; @@ -94,7 +92,7 @@ TEST_F(TestStringDictionaryEncoder, GetIndex) { {0, 3, 4, 2, 1, 3, 2, 4, 2, 0, 1, 0, 3}}}; for (const auto& testCase : testCases) { - auto pool = memoryManager()->addLeafPool(); + auto pool = memory::memoryManager()->addLeafPool(); StringDictionaryEncoder stringDictEncoder{*pool, *pool}; for (const auto& key : testCase.addKeySequence) { stringDictEncoder.addKey(key, 0); @@ -111,14 +109,14 @@ TEST_F(TestStringDictionaryEncoder, GetIndex) { TEST_F(TestStringDictionaryEncoder, GetCount) { struct TestCase { explicit TestCase( - const std::vector& addKeySequence, - const std::vector& getCountSequence, + const std::vector& addKeySequence, + const std::vector& getCountSequence, const std::vector& countSequence) : addKeySequence{addKeySequence}, getCountSequence{getCountSequence}, countSequence{countSequence} {} - std::vector addKeySequence; - std::vector getCountSequence; + std::vector addKeySequence; + std::vector getCountSequence; std::vector countSequence; }; @@ -143,7 +141,7 @@ TEST_F(TestStringDictionaryEncoder, GetCount) { {3, 2, 3, 3, 2}}}; for (const auto& testCase : testCases) { - auto pool = memoryManager()->addLeafPool(); + auto pool = memory::memoryManager()->addLeafPool(); StringDictionaryEncoder stringDictEncoder{*pool, *pool}; for (const auto& key : testCase.addKeySequence) { stringDictEncoder.addKey(key, 0); @@ -161,15 +159,14 @@ TEST_F(TestStringDictionaryEncoder, GetCount) { TEST_F(TestStringDictionaryEncoder, GetStride) { struct TestCase { explicit TestCase( - const std::vector>& - addKeySequence, - const std::vector& getStrideSequence, + const std::vector>& addKeySequence, + const std::vector& getStrideSequence, const std::vector& strideSequence) : addKeySequence{addKeySequence}, getStrideSequence{getStrideSequence}, strideSequence{strideSequence} {} - std::vector> addKeySequence; - std::vector getStrideSequence; + std::vector> addKeySequence; + std::vector getStrideSequence; std::vector strideSequence; }; @@ -197,7 +194,7 @@ TEST_F(TestStringDictionaryEncoder, GetStride) { {1, 1, 6, 3, 4}}}; for (const auto& testCase : testCases) { - auto pool = memoryManager()->addLeafPool(); + auto pool = memory::memoryManager()->addLeafPool(); StringDictionaryEncoder stringDictEncoder{*pool, *pool}; for (const auto& kv : testCase.addKeySequence) { stringDictEncoder.addKey(kv.first, kv.second); @@ -220,7 +217,7 @@ std::string genPaddedIntegerString(size_t integer, size_t length) { } TEST_F(TestStringDictionaryEncoder, Clear) { - auto pool = memoryManager()->addLeafPool(); + auto pool = memory::memoryManager()->addLeafPool(); StringDictionaryEncoder stringDictEncoder{*pool, *pool}; std::string baseString{"jjkkll"}; for (size_t i = 0; i != 2500; ++i) { @@ -242,7 +239,7 @@ TEST_F(TestStringDictionaryEncoder, Clear) { } TEST_F(TestStringDictionaryEncoder, MemBenchmark) { - auto pool = memoryManager()->addLeafPool(); + auto pool = memory::memoryManager()->addLeafPool(); StringDictionaryEncoder stringDictEncoder{*pool, *pool}; std::string baseString{"jjkkll"}; for (size_t i = 0; i != 10000; ++i) { @@ -253,15 +250,15 @@ TEST_F(TestStringDictionaryEncoder, MemBenchmark) { } TEST_F(TestStringDictionaryEncoder, Limit) { - auto pool = memoryManager()->addLeafPool(); + auto pool = memory::memoryManager()->addLeafPool(); StringDictionaryEncoder encoder{*pool, *pool}; - encoder.addKey(folly::StringPiece{"abc"}, 0); + encoder.addKey(std::string_view{"abc"}, 0); dwio::common::DataBuffer buf{*pool}; buf.resize(std::numeric_limits::max()); ASSERT_THROW( encoder.addKey( - folly::StringPiece{ + std::string_view{ buf.data(), std::numeric_limits::max() - 3}, 0), dwio::common::exception::LoggedException); diff --git a/velox/dwio/dwrf/test/WriterContextTest.cpp b/velox/dwio/dwrf/test/WriterContextTest.cpp index 30cbf19d720..cd67deddb3d 100644 --- a/velox/dwio/dwrf/test/WriterContextTest.cpp +++ b/velox/dwio/dwrf/test/WriterContextTest.cpp @@ -219,6 +219,44 @@ TEST_F(WriterContextTest, memory) { ASSERT_EQ(context.availableMemoryReservation(), 786368); } +TEST_F(WriterContextTest, memoryBudgetDefault) { + auto pool = memory::memoryManager()->addRootPool("memoryBudgetDefault"); + WriterContext context{std::make_shared(), pool}; + ASSERT_EQ(context.getMemoryBudget(), pool->maxCapacity()); +} + +TEST_F(WriterContextTest, memoryBudgetLessThanPoolCapacity) { + const int64_t poolCapacity = 1L << 30; + const int64_t budget = 256L << 20; + auto pool = memory::memoryManager()->addRootPool( + "memoryBudgetLessThanPoolCapacity", poolCapacity); + WriterContext context{ + std::make_shared(), + pool, + dwio::common::MetricsLog::voidLog(), + nullptr, + false, + nullptr, + budget}; + ASSERT_EQ(context.getMemoryBudget(), budget); +} + +TEST_F(WriterContextTest, memoryBudgetGreaterThanPoolCapacity) { + const int64_t poolCapacity = 256L << 20; + const int64_t budget = 1L << 30; + auto pool = memory::memoryManager()->addRootPool( + "memoryBudgetGreaterThanPoolCapacity", poolCapacity); + WriterContext context{ + std::make_shared(), + pool, + dwio::common::MetricsLog::voidLog(), + nullptr, + false, + nullptr, + budget}; + ASSERT_EQ(context.getMemoryBudget(), poolCapacity); +} + TEST_F(WriterContextTest, abort) { auto writerRoot = memory::memoryManager()->addRootPool( "abort", 1L << 30, exec::MemoryReclaimer::create()); diff --git a/velox/dwio/dwrf/test/WriterFlushTest.cpp b/velox/dwio/dwrf/test/WriterFlushTest.cpp index b8dfcac394f..f55d3bc4449 100644 --- a/velox/dwio/dwrf/test/WriterFlushTest.cpp +++ b/velox/dwio/dwrf/test/WriterFlushTest.cpp @@ -146,8 +146,9 @@ class MockMemoryPool : public velox::memory::MemoryPool { VELOX_UNSUPPORTED("allocateContiguous unsupported"); } - void freeContiguous(velox::memory::ContiguousAllocation& - /*unused*/) override { + void freeContiguous( + velox::memory::ContiguousAllocation& + /*unused*/) override { VELOX_UNSUPPORTED("freeContiguous unsupported"); } diff --git a/velox/dwio/dwrf/test/WriterTest.cpp b/velox/dwio/dwrf/test/WriterTest.cpp index d60683b53c3..e757d574e11 100644 --- a/velox/dwio/dwrf/test/WriterTest.cpp +++ b/velox/dwio/dwrf/test/WriterTest.cpp @@ -18,6 +18,7 @@ #include #include #include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/io/IoStatistics.h" #include "velox/dwio/dwrf/reader/ReaderBase.h" #include "velox/dwio/dwrf/writer/WriterBase.h" #include "velox/type/fbhive/HiveTypeParser.h" @@ -61,7 +62,9 @@ class WriterTest : public Test { std::string data(sinkPtr_->data(), sinkPtr_->size()); auto readFile = std::make_shared(std::move(data)); auto input = std::make_unique(std::move(readFile), *pool_); - dwio::common::ReaderOptions readerOpts{pool_.get()}; + dwio::common::ReaderOptions readerOpts(pool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); auto reader = std::make_unique(readerOpts, std::move(input)); reader->loadCache(); return reader; @@ -75,7 +78,7 @@ class WriterTest : public Test { return writer_->getFooter(); } - auto& addStripeInfo() { + StripeInformationWriteWrapper addStripeInfo() { return writer_->addStripeInfo(); } @@ -91,6 +94,10 @@ class WriterTest : public Test { std::shared_ptr pool_; MemorySink* sinkPtr_; std::unique_ptr writer_; + std::shared_ptr dataIoStats_{ + std::make_shared()}; + std::shared_ptr metadataIoStats_{ + std::make_shared()}; }; class SupportedCompressionTest @@ -168,7 +175,7 @@ TEST_P(AllWriterCompressionTest, compression) { folly::to(i), folly::to(i + 1)); } for (size_t i = 0; i < 4; ++i) { - getFooter().add_statistics(); + getFooter()->addStatistics(); } if (compressionKind_ == CompressionKind::CompressionKind_SNAPPY || @@ -230,7 +237,7 @@ TEST_P(SupportedCompressionTest, WriteFooter) { folly::to(i), folly::to(i + 1)); } for (size_t i = 0; i < 4; ++i) { - getFooter().add_statistics(); + getFooter()->addStatistics(); } writeFooter(*schema); writer.close(); @@ -262,11 +269,11 @@ TEST_P(SupportedCompressionTest, WriteFooter) { ASSERT_EQ(footer.metadataSize(), 5); for (size_t i = 0; i < 4; ++i) { auto item = footer.metadata(i); - if (item.name() == WRITER_NAME_KEY) { + if (item.name() == kWriterNameKey) { ASSERT_EQ(item.value(), kDwioWriter); - } else if (item.name() == WRITER_VERSION_KEY) { + } else if (item.name() == kWriterVersionKey) { ASSERT_EQ(item.value(), folly::to(reader->writerVersion())); - } else if (item.name() == WRITER_HOSTNAME_KEY) { + } else if (item.name() == kWriterHostnameKey) { ASSERT_EQ(item.value(), process::getHostName()); } else { ASSERT_EQ( @@ -306,9 +313,9 @@ TEST_P(SupportedCompressionTest, AddStripeInfo) { writerSink.addBuffer(*pool_, data.data(), data.size()); writerSink.setMode(WriterSink::Mode::None); - auto& ret = addStripeInfo(); - ASSERT_EQ(ret.numberofrows(), 101); - ASSERT_EQ(ret.rawdatasize(), 202); + auto ret = addStripeInfo(); + ASSERT_EQ(ret.numberOfRows(), 101); + ASSERT_EQ(ret.rawDataSize(), 202); ASSERT_EQ(ret.checksum(), 8963334039576633799); writer.close(); } @@ -326,14 +333,14 @@ TEST_P(SupportedCompressionTest, NoChecksum) { writerSink.addBuffer(*pool_, data.data(), data.size()); writerSink.setMode(WriterSink::Mode::None); - auto& ret = addStripeInfo(); - ASSERT_FALSE(ret.has_checksum()); + auto ret = addStripeInfo(); + ASSERT_FALSE(ret.hasChecksum()); std::string typeStr{"struct"}; HiveTypeParser parser; auto schema = parser.parse(typeStr); for (size_t i = 0; i < 4; ++i) { - getFooter().add_statistics(); + getFooter()->addStatistics(); } writeFooter(*schema); writer.close(); @@ -368,7 +375,7 @@ TEST_P(SupportedCompressionTest, NoCache) { HiveTypeParser parser; auto schema = parser.parse(typeStr); for (size_t i = 0; i < 4; ++i) { - getFooter().add_statistics(); + getFooter()->addStatistics(); } writeFooter(*schema); writer.close(); @@ -456,7 +463,20 @@ class MockFileSink : public dwio::common::FileSink { MOCK_METHOD(uint64_t, size, (), (const override)); MOCK_METHOD(bool, isBuffered, (), (const override)); +// On Centos9 the gtest mock header doesn't initialize the +// buffer_ member in MatcherBase correctly - the default constructor only +// initializes one: /usr/include/gtest/gtest-matchers.h:302:33 resulting in +// error: +// '.testing::Matcher::.testing::internal::MatcherBase::buffer_' is used uninitialized +// [-Werror=uninitialized] +// 302 | : vtable_(other.vtable_), buffer_(other.buffer_) { +// Fix: https://github.com/google/googletest/pull/3797 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wuninitialized" MOCK_METHOD(void, write, (std::vector>&)); +#pragma GCC diagnostic pop }; TEST_F(WriterTest, FlushWriterSinkUponClose) { diff --git a/velox/dwio/dwrf/test/utils/CMakeLists.txt b/velox/dwio/dwrf/test/utils/CMakeLists.txt index 617b9049077..f4cd63e7d71 100644 --- a/velox/dwio/dwrf/test/utils/CMakeLists.txt +++ b/velox/dwio/dwrf/test/utils/CMakeLists.txt @@ -13,6 +13,7 @@ # limitations under the License. add_library(velox_dwrf_test_utils E2EWriterTestUtil.cpp) +velox_add_test_headers(velox_dwrf_test_utils E2EWriterTestUtil.h) target_link_libraries( velox_dwrf_test_utils diff --git a/velox/dwio/dwrf/test/utils/E2EWriterTestUtil.cpp b/velox/dwio/dwrf/test/utils/E2EWriterTestUtil.cpp index e132023f415..b4098e3ffc7 100644 --- a/velox/dwio/dwrf/test/utils/E2EWriterTestUtil.cpp +++ b/velox/dwio/dwrf/test/utils/E2EWriterTestUtil.cpp @@ -17,6 +17,7 @@ #include "velox/dwio/dwrf/test/utils/E2EWriterTestUtil.h" #include +#include "velox/common/io/IoStatistics.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" #include "velox/dwio/dwrf/reader/DwrfReader.h" #include "velox/dwio/dwrf/writer/FlushPolicy.h" @@ -114,7 +115,11 @@ namespace facebook::velox::dwrf { std::string(sinkPtr->data(), sinkPtr->size())); auto input = std::make_unique(readFile, pool); - dwio::common::ReaderOptions readerOpts{&pool}; + auto dataIoStats = std::make_shared(); + auto metadataIoStats = std::make_shared(); + dwio::common::ReaderOptions readerOpts(&pool); + readerOpts.setDataIoStats(dataIoStats); + readerOpts.setMetadataIoStats(metadataIoStats); RowReaderOptions rowReaderOpts; auto reader = std::make_unique(readerOpts, std::move(input)); EXPECT_GE(numStripesUpper, reader->getNumberOfStripes()); diff --git a/velox/dwio/dwrf/utils/CMakeLists.txt b/velox/dwio/dwrf/utils/CMakeLists.txt index 74655a9b0b1..82685590d34 100644 --- a/velox/dwio/dwrf/utils/CMakeLists.txt +++ b/velox/dwio/dwrf/utils/CMakeLists.txt @@ -16,5 +16,12 @@ if(${VELOX_BUILD_TESTING}) add_subdirectory(test) endif() -velox_add_library(velox_dwio_dwrf_utils ProtoUtils.cpp BitIterator.h) +velox_add_library( + velox_dwio_dwrf_utils + ProtoUtils.cpp + BitIterator.h + HEADERS + BufferedWriter.h + ProtoUtils.h +) velox_link_libraries(velox_dwio_dwrf_utils velox_dwio_dwrf_common velox_type velox_memory) diff --git a/velox/dwio/dwrf/utils/ProtoUtils.cpp b/velox/dwio/dwrf/utils/ProtoUtils.cpp index 405d2e79ddf..8cfbf459440 100644 --- a/velox/dwio/dwrf/utils/ProtoUtils.cpp +++ b/velox/dwio/dwrf/utils/ProtoUtils.cpp @@ -51,31 +51,34 @@ CREATE_TYPE_TRAIT(ROW, STRUCT) void ProtoUtils::writeType( const Type& type, - proto::Footer& footer, - proto::Type* parent) { - auto self = footer.add_types(); + FooterWriteWrapper& footer, + TypeWriteWrapper* parent) { + auto self = footer.addTypes(); if (parent) { - parent->add_subtypes(footer.types_size() - 1); + parent->addSubtypes(footer.typesSize() - 1); } + auto kind = VELOX_STATIC_FIELD_DYNAMIC_DISPATCH(SchemaType, kind, type.kind()); - self->set_kind(kind); + auto typeKindWrapper = TypeKindWrapper(&kind); + self.setKind(typeKindWrapper); + switch (type.kind()) { case TypeKind::ROW: { auto& row = type.asRow(); for (size_t i = 0; i < row.size(); ++i) { - self->add_fieldnames(row.nameOf(i)); - writeType(*row.childAt(i), footer, self); + self.addFieldnames(row.nameOf(i)); + writeType(*row.childAt(i), footer, &self); } break; } case TypeKind::ARRAY: - writeType(*type.asArray().elementType(), footer, self); + writeType(*type.asArray().elementType(), footer, &self); break; case TypeKind::MAP: { auto& map = type.asMap(); - writeType(*map.keyType(), footer, self); - writeType(*map.valueType(), footer, self); + writeType(*map.keyType(), footer, &self); + writeType(*map.valueType(), footer, &self); break; } default: diff --git a/velox/dwio/dwrf/utils/ProtoUtils.h b/velox/dwio/dwrf/utils/ProtoUtils.h index bea1d100330..310aaf7eb18 100644 --- a/velox/dwio/dwrf/utils/ProtoUtils.h +++ b/velox/dwio/dwrf/utils/ProtoUtils.h @@ -17,6 +17,7 @@ #pragma once #include "velox/dwio/common/SeekableInputStream.h" +#include "velox/dwio/dwrf/common/FileMetadata.h" #include "velox/dwio/dwrf/common/wrap/dwrf-proto-wrapper.h" #include "velox/type/Type.h" @@ -26,8 +27,8 @@ class ProtoUtils final { public: static void writeType( const Type& type, - proto::Footer& footer, - proto::Type* parent = nullptr); + FooterWriteWrapper&, + TypeWriteWrapper* parent = nullptr); static std::shared_ptr fromFooter( const proto::Footer& footer, diff --git a/velox/dwio/dwrf/utils/test/ProtoUtilsTests.cpp b/velox/dwio/dwrf/utils/test/ProtoUtilsTests.cpp index 92c6d01b00a..f32700513d2 100644 --- a/velox/dwio/dwrf/utils/test/ProtoUtilsTests.cpp +++ b/velox/dwio/dwrf/utils/test/ProtoUtilsTests.cpp @@ -31,7 +31,8 @@ TEST(ProtoUtilsTests, AllTypes) { HiveTypeParser parser; auto schema = parser.parse(type); proto::Footer footer; - ProtoUtils::writeType(*schema, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*schema, footerWrapper); auto out = ProtoUtils::fromFooter(footer); auto str = HiveTypeSerializer::serialize(out); @@ -45,7 +46,8 @@ TEST(ProtoUtilsTests, Projection) { auto schema = parser.parse( "struct>"); proto::Footer footer; - ProtoUtils::writeType(*schema, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*schema, footerWrapper); auto type = ProtoUtils::fromFooter( footer, [](auto id) { return id != 2 && id != 5; }); diff --git a/velox/dwio/dwrf/writer/CMakeLists.txt b/velox/dwio/dwrf/writer/CMakeLists.txt index ec121ef6aae..3450419be89 100644 --- a/velox/dwio/dwrf/writer/CMakeLists.txt +++ b/velox/dwio/dwrf/writer/CMakeLists.txt @@ -25,6 +25,24 @@ velox_add_library( WriterBase.cpp WriterContext.cpp WriterSink.cpp + HEADERS + ColumnWriter.h + DictionaryEncodingUtils.h + EntropyEncodingSelector.h + FlatMapColumnWriter.h + FlushPolicy.h + IndexBuilder.h + IntegerDictionaryEncoder.h + LayoutPlanner.h + PhysicalSizeAggregator.h + RatioTracker.h + StatisticsBuilder.h + StatisticsBuilderUtils.h + StringDictionaryEncoder.h + Writer.h + WriterBase.h + WriterContext.h + WriterSink.h ) velox_link_libraries( diff --git a/velox/dwio/dwrf/writer/ColumnWriter.cpp b/velox/dwio/dwrf/writer/ColumnWriter.cpp index 2a4cf207796..2100d18fd8a 100644 --- a/velox/dwio/dwrf/writer/ColumnWriter.cpp +++ b/velox/dwio/dwrf/writer/ColumnWriter.cpp @@ -68,8 +68,9 @@ class ByteRleColumnWriter : public BaseColumnWriter { uint64_t write(const VectorPtr& slice, const common::Ranges& ranges) override; void flush( - std::function encodingFactory, - std::function encodingOverride) override { + std::function encodingFactory, + std::function encodingOverride) + override { BaseColumnWriter::flush(encodingFactory, encodingOverride); data_->flush(); } @@ -248,8 +249,9 @@ class IntegerColumnWriter : public BaseColumnWriter { } void flush( - std::function encodingFactory, - std::function encodingOverride) override { + std::function encodingFactory, + std::function encodingOverride) + override { tryAbandonDictionaries(false); initStreamWriters(useDictionaryEncoding_); @@ -287,12 +289,14 @@ class IntegerColumnWriter : public BaseColumnWriter { // FIXME: call base class set encoding first to deal with sequence and // whatnot. - void setEncoding(proto::ColumnEncoding& encoding) const override { + void setEncoding(ColumnEncodingWriteWrapper& encoding) const override { BaseColumnWriter::setEncoding(encoding); if (useDictionaryEncoding_) { - encoding.set_kind( - proto::ColumnEncoding_Kind::ColumnEncoding_Kind_DICTIONARY); - encoding.set_dictionarysize(finalDictionarySize_); + auto columnEncodingKind = + proto::ColumnEncoding_Kind::ColumnEncoding_Kind_DICTIONARY; + encoding.setKind(ColumnEncodingKindWrapper(&columnEncodingKind)); + + encoding.setDictionarySize(finalDictionarySize_); } } @@ -679,8 +683,9 @@ class TimestampColumnWriter : public BaseColumnWriter { uint64_t write(const VectorPtr& slice, const common::Ranges& ranges) override; void flush( - std::function encodingFactory, - std::function encodingOverride) override { + std::function encodingFactory, + std::function encodingOverride) + override { BaseColumnWriter::flush(encodingFactory, encodingOverride); seconds_->flush(); nanos_->flush(); @@ -788,8 +793,9 @@ class DecimalColumnWriter : public BaseColumnWriter { } void flush( - std::function encodingFactory, - std::function encodingOverride) override { + std::function encodingFactory, + std::function encodingOverride) + override { BaseColumnWriter::flush(encodingFactory, encodingOverride); unscaledValues_->flush(); scales_->flush(); @@ -954,8 +960,9 @@ class StringColumnWriter : public BaseColumnWriter { } void flush( - std::function encodingFactory, - std::function encodingOverride) override { + std::function encodingFactory, + std::function encodingOverride) + override { tryAbandonDictionaries(false); initStreamWriters(useDictionaryEncoding_); @@ -1000,12 +1007,13 @@ class StringColumnWriter : public BaseColumnWriter { // FIXME: call base class set encoding first to deal with sequence and // whatnot. - void setEncoding(proto::ColumnEncoding& encoding) const override { + void setEncoding(ColumnEncodingWriteWrapper& encoding) const override { BaseColumnWriter::setEncoding(encoding); if (useDictionaryEncoding_) { - encoding.set_kind( - proto::ColumnEncoding_Kind::ColumnEncoding_Kind_DICTIONARY); - encoding.set_dictionarysize(finalDictionarySize_); + auto columnEncodingKind = + proto::ColumnEncoding_Kind::ColumnEncoding_Kind_DICTIONARY; + encoding.setKind(ColumnEncodingKindWrapper(&columnEncodingKind)); + encoding.setDictionarySize(finalDictionarySize_); } } @@ -1230,10 +1238,12 @@ uint64_t StringColumnWriter::writeDict( size_t strideIndex = strideOffsets_.size() - 1; uint64_t rawSize = 0; auto processRow = [&](size_t pos) { - auto sp = decodedVector.valueAt(pos); - rows_.unsafeAppend(dictEncoder_.addKey(sp, strideIndex)); - statsBuilder.addValues(sp); - rawSize += sp.size(); + auto sv = decodedVector.valueAt(pos); + // TODO: Remove explicit std::string_view cast. + rows_.unsafeAppend(dictEncoder_.addKey(std::string_view(sv), strideIndex)); + // TODO: Remove explicit std::string_view cast. + statsBuilder.addValues(std::string_view(sv)); + rawSize += sv.size(); }; uint64_t nullCount = 0; @@ -1274,10 +1284,11 @@ uint64_t StringColumnWriter::writeDirect( uint64_t rawSize = 0; auto processRow = [&](size_t pos) { - auto sp = decodedVector.valueAt(pos); - auto size = sp.size(); - dataDirect_->write(sp.data(), size); - statsBuilder.addValues(sp); + auto sv = decodedVector.valueAt(pos); + auto size = sv.size(); + dataDirect_->write(sv.data(), size); + // TODO: Remove explicit std::string_view cast. + statsBuilder.addValues(std::string_view(sv)); rawSize += size; lengths.unsafeAppend(size); }; @@ -1481,8 +1492,9 @@ class FloatColumnWriter : public BaseColumnWriter { uint64_t write(const VectorPtr& slice, const common::Ranges& ranges) override; void flush( - std::function encodingFactory, - std::function encodingOverride) override { + std::function encodingFactory, + std::function encodingOverride) + override { BaseColumnWriter::flush(encodingFactory, encodingOverride); data_.flush(); } @@ -1612,8 +1624,9 @@ class BinaryColumnWriter : public BaseColumnWriter { uint64_t write(const VectorPtr& slice, const common::Ranges& ranges) override; void flush( - std::function encodingFactory, - std::function encodingOverride) override { + std::function encodingFactory, + std::function encodingOverride) + override { BaseColumnWriter::flush(encodingFactory, encodingOverride); data_.flush(); lengths_->flush(); @@ -1726,8 +1739,9 @@ class StructColumnWriter : public BaseColumnWriter { uint64_t write(const VectorPtr& slice, const common::Ranges& ranges) override; void flush( - std::function encodingFactory, - std::function encodingOverride) override { + std::function encodingFactory, + std::function encodingOverride) + override { BaseColumnWriter::flush(encodingFactory, encodingOverride); for (auto& c : children_) { c->flush(encodingFactory); @@ -1855,8 +1869,9 @@ class ListColumnWriter : public BaseColumnWriter { uint64_t write(const VectorPtr& slice, const common::Ranges& ranges) override; void flush( - std::function encodingFactory, - std::function encodingOverride) override { + std::function encodingFactory, + std::function encodingOverride) + override { BaseColumnWriter::flush(encodingFactory, encodingOverride); lengths_->flush(); children_.at(0)->flush(encodingFactory); @@ -1982,8 +1997,9 @@ class MapColumnWriter : public BaseColumnWriter { uint64_t write(const VectorPtr& slice, const common::Ranges& ranges) override; void flush( - std::function encodingFactory, - std::function encodingOverride) override { + std::function encodingFactory, + std::function encodingOverride) + override { BaseColumnWriter::flush(encodingFactory, encodingOverride); lengths_->flush(); children_.at(0)->flush(encodingFactory); @@ -2119,7 +2135,7 @@ std::unique_ptr BaseColumnWriter::create( "MAP_FLAT_COLS contains column {}, but the root type of this column is {}." " Column root types must be of type MAP", type.column(), - mapTypeKindToName(type.type()->kind())); + TypeKindName::toName(type.type()->kind())); } const auto structColumnKeys = context.getConfig(Config::MAP_FLAT_COLS_STRUCT_KEYS); @@ -2212,7 +2228,7 @@ std::unique_ptr BaseColumnWriter::create( } default: VELOX_FAIL( - "not supported yet: {}", mapTypeKindToName(type.type()->kind())); + "not supported yet: {}", TypeKindName::toName(type.type()->kind())); } } } // namespace facebook::velox::dwrf diff --git a/velox/dwio/dwrf/writer/ColumnWriter.h b/velox/dwio/dwrf/writer/ColumnWriter.h index 98c5691babb..9fde39f331a 100644 --- a/velox/dwio/dwrf/writer/ColumnWriter.h +++ b/velox/dwio/dwrf/writer/ColumnWriter.h @@ -45,12 +45,31 @@ class ColumnWriter { virtual void reset() = 0; virtual void flush( - std::function encodingFactory, - std::function encodingOverride = - [](auto& /* e */) {}) = 0; + std::function encodingFactory, + std::function encodingOverride = + [](auto /* e */) {}) { + VELOX_NYI(); + } + + virtual void flush( + std::function + encodingFactory, + std::function + encodingOverride = [](auto& /* e */) {}) { + VELOX_NYI(); + } virtual uint64_t writeFileStats( - std::function statsFactory) const = 0; + std::function statsFactory) + const { + VELOX_NYI(); + } + + virtual uint64_t writeFileStats( + std::function + statsFactory) const { + VELOX_NYI(); + } virtual bool tryAbandonDictionaries(bool force) = 0; @@ -61,11 +80,13 @@ class ColumnWriter { const uint32_t sequence) : id_{id}, sequence_{sequence}, context_{context} {} - virtual void setEncoding(proto::ColumnEncoding& encoding) const { - encoding.set_kind(proto::ColumnEncoding_Kind::ColumnEncoding_Kind_DIRECT); - encoding.set_dictionarysize(0); - encoding.set_node(id_); - encoding.set_sequence(sequence_); + virtual void setEncoding(ColumnEncodingWriteWrapper& columnEncoding) const { + auto columnEncodingKind = + proto::ColumnEncoding_Kind::ColumnEncoding_Kind_DIRECT; + columnEncoding.setKind(ColumnEncodingKindWrapper(&columnEncodingKind)); + columnEncoding.setDictionarySize(0); + columnEncoding.setNode(id_); + columnEncoding.setSequence(sequence_); } const uint32_t id_; @@ -99,9 +120,9 @@ class BaseColumnWriter : public ColumnWriter { } void flush( - std::function encodingFactory, - std::function encodingOverride = - [](auto& /* e */) {}) override { + std::function encodingFactory, + std::function encodingOverride = + [](auto /* e */) {}) override { if (!isRoot()) { present_->flush(); @@ -113,21 +134,22 @@ class BaseColumnWriter : public ColumnWriter { } } - auto& encoding = encodingFactory(id_); + auto encoding = encodingFactory(id_); setEncoding(encoding); encodingOverride(encoding); indexBuilder_->flush(); } - uint64_t writeFileStats(std::function - statsFactory) const override { - auto& stats = statsFactory(id_); + uint64_t writeFileStats( + std::function statsFactory) + const override { + auto stats = statsFactory(id_); fileStatsBuilder_->toProto(stats); const uint64_t size = context_.getPhysicalSizeAggregator(id_).getResult(); for (auto& child : children_) { child->writeFileStats(statsFactory); } - stats.set_size(size); + stats.setSize(size); return size; } @@ -170,7 +192,7 @@ class BaseColumnWriter : public ColumnWriter { createBooleanRleEncoder(newStream(StreamKind::StreamKind_PRESENT)); } const auto options = - StatisticsBuilderOptions::fromConfig(context.getConfigs()); + StatisticsBuilder::optionsFromConfig(context.getConfigs()); indexStatsBuilder_ = StatisticsBuilder::create(*type.type(), options); fileStatsBuilder_ = StatisticsBuilder::create(*type.type(), options); } diff --git a/velox/dwio/dwrf/writer/FlatMapColumnWriter.cpp b/velox/dwio/dwrf/writer/FlatMapColumnWriter.cpp index 0099d52f65e..20d88cd5157 100644 --- a/velox/dwio/dwrf/writer/FlatMapColumnWriter.cpp +++ b/velox/dwio/dwrf/writer/FlatMapColumnWriter.cpp @@ -21,7 +21,6 @@ #include "velox/vector/FlatMapVector.h" namespace facebook::velox::dwrf { - namespace { template @@ -62,7 +61,10 @@ FlatMapColumnWriter::FlatMapColumnWriter( valueType_{*type.childAt(1)}, maxKeyCount_{context_.getConfig(Config::MAP_FLAT_MAX_KEYS)}, collectMapStats_{context.getConfig(Config::MAP_STATISTICS)} { - auto options = StatisticsBuilderOptions::fromConfig(context.getConfigs()); + if constexpr (std::is_same_v) { + stringKeys_.reserve(maxKeyCount_); + } + auto options = StatisticsBuilder::optionsFromConfig(context.getConfigs()); keyFileStatsBuilder_ = std::unique_ptr::StatisticsBuilder>( dynamic_cast::StatisticsBuilder*>( @@ -84,15 +86,17 @@ FlatMapColumnWriter::FlatMapColumnWriter( template void FlatMapColumnWriter::setEncoding( - proto::ColumnEncoding& encoding) const { + ColumnEncodingWriteWrapper& encoding) const { BaseColumnWriter::setEncoding(encoding); - encoding.set_kind(proto::ColumnEncoding_Kind::ColumnEncoding_Kind_MAP_FLAT); + auto columnEncodingKind = + proto::ColumnEncoding_Kind::ColumnEncoding_Kind_MAP_FLAT; + encoding.setKind(ColumnEncodingKindWrapper(&columnEncodingKind)); } template void FlatMapColumnWriter::flush( - std::function encodingFactory, - std::function encodingOverride) { + std::function encodingFactory, + std::function encodingOverride) { BaseColumnWriter::flush(encodingFactory, encodingOverride); for (auto& pair : valueWriters_) { @@ -136,18 +140,18 @@ void FlatMapColumnWriter::createIndexEntry() { template uint64_t FlatMapColumnWriter::writeFileStats( - std::function statsFactory) const { - auto& stats = statsFactory(id_); + std::function statsFactory) const { + auto stats = statsFactory(id_); fileStatsBuilder_->toProto(stats); uint64_t size = context_.getPhysicalSizeAggregator(id_).getResult(); - auto& keyStats = statsFactory(keyType_.id()); + auto keyStats = statsFactory(keyType_.id()); keyFileStatsBuilder_->toProto(keyStats); auto keySize = context_.getPhysicalSizeAggregator(keyType_.id()).getResult(); - keyStats.set_size(keySize); + keyStats.setSize(keySize); valueFileStatsBuilder_->writeFileStats(statsFactory); - stats.set_size(size); + stats.setSize(size); return size; } @@ -173,6 +177,7 @@ void FlatMapColumnWriter::reset() { BaseColumnWriter::reset(); clearNodes(); valueWriters_.clear(); + stringKeys_.clear(); rowsInStrides_.clear(); rowsInCurrentStride_ = 0; totalRows_ = 0; @@ -200,15 +205,26 @@ ValueWriter& FlatMapColumnWriter::getValueWriter( } if (valueWriters_.size() >= maxKeyCount_) { - DWIO_RAISE(fmt::format( - "Too many map keys requested in (node {}, column {}). Allowed: {}", - id_, - type_.column(), - maxKeyCount_)); + DWIO_RAISE( + fmt::format( + "Too many map keys requested in (node {}, column {}). Allowed: {}", + id_, + type_.column(), + maxKeyCount_)); } auto keyInfo = getKeyInfo(key); + // For non-inline StringView keys (>12 chars), store an owned copy of the + // string data to prevent dangling pointers when input batches are released + // between writes. Inline StringViews are self-contained and safe as-is. + if constexpr (std::is_same_v) { + if (!key.isInline()) { + stringKeys_.emplace_back(key.data(), key.size()); + key = StringView(stringKeys_.back()); + } + } + it = valueWriters_ .emplace( std::piecewise_construct, @@ -252,7 +268,7 @@ uint32_t updateKeyStatistics( StringView value, uint64_t count) { auto size = value.size(); - keyStatsBuilder.addValues(folly::StringPiece{value.data(), size}, count); + keyStatsBuilder.addValues(std::string_view{value.data(), size}, count); return size * count; } diff --git a/velox/dwio/dwrf/writer/FlatMapColumnWriter.h b/velox/dwio/dwrf/writer/FlatMapColumnWriter.h index c6d706b8a4a..3394ab8a9bc 100644 --- a/velox/dwio/dwrf/writer/FlatMapColumnWriter.h +++ b/velox/dwio/dwrf/writer/FlatMapColumnWriter.h @@ -38,7 +38,7 @@ class ValueStatisticsBuilder { static std::unique_ptr create( WriterContext& context, const dwio::common::TypeWithId& root) { - auto options = StatisticsBuilderOptions::fromConfig(context.getConfigs()); + auto options = StatisticsBuilder::optionsFromConfig(context.getConfigs()); return create_(context, root, options); } @@ -53,14 +53,15 @@ class ValueStatisticsBuilder { } uint64_t writeFileStats( - std::function statsFactory) const { - auto& stats = statsFactory(id_); + std::function statsFactory) + const { + auto stats = statsFactory(id_); statisticsBuilder_->toProto(stats); uint64_t size = context_.getPhysicalSizeAggregator(id_).getResult(); for (int32_t i = 0; i < children_.size(); ++i) { children_[i]->writeFileStats(statsFactory); } - stats.set_size(size); + stats.setSize(size); return size; } @@ -209,10 +210,11 @@ class ValueWriter { columnWriter_->createIndexEntry(); } - void flush(std::function encodingFactory) { + void flush( + std::function encodingFactory) { inMap_->flush(); - columnWriter_->flush(encodingFactory, [&](auto& encoding) { - *encoding.mutable_key() = keyInfo_; + columnWriter_->flush(encodingFactory, [&](auto encoding) { + *encoding.mutableKey() = keyInfo_; }); } @@ -290,20 +292,22 @@ class FlatMapColumnWriter : public BaseColumnWriter { uint64_t write(const VectorPtr& slice, const common::Ranges& ranges) override; void flush( - std::function encodingFactory, - std::function encodingOverride) override; + std::function encodingFactory, + std::function encodingOverride) + override; void createIndexEntry() override; void reset() override; - uint64_t writeFileStats(std::function - statsFactory) const override; + uint64_t writeFileStats( + std::function statsFactory) + const override; private: using KeyType = typename TypeTraits::NativeType; - void setEncoding(proto::ColumnEncoding& encoding) const override; + void setEncoding(ColumnEncodingWriteWrapper& encoding) const override; ValueWriter& getValueWriter(KeyType key, uint32_t inMapSize); diff --git a/velox/dwio/dwrf/writer/IndexBuilder.h b/velox/dwio/dwrf/writer/IndexBuilder.h index 57582199fc5..88cbade7dc3 100644 --- a/velox/dwio/dwrf/writer/IndexBuilder.h +++ b/velox/dwio/dwrf/writer/IndexBuilder.h @@ -16,12 +16,14 @@ #pragma once +#include "velox/dwio/common/Arena.h" #include "velox/dwio/common/OutputStream.h" #include "velox/dwio/dwrf/common/wrap/dwrf-proto-wrapper.h" #include "velox/dwio/dwrf/writer/StatisticsBuilder.h" namespace facebook::velox::dwrf { +using dwio::common::ArenaCreate; using dwio::common::BufferedOutputStream; using dwio::common::PositionRecorder; @@ -35,41 +37,50 @@ constexpr int32_t PRESENT_STREAM_INDEX_ENTRIES_PAGED = class IndexBuilder : public PositionRecorder { public: - IndexBuilder(std::unique_ptr out) - : out_{std::move(out)} {} + IndexBuilder( + std::unique_ptr out, + dwio::common::FileFormat fileFormat = dwio::common::FileFormat::DWRF) + : out_{std::move(out)}, + arena_(std::make_unique()) { + auto rowIndex = ArenaCreate(arena_.get()); + auto rowIndexEntry = ArenaCreate(arena_.get()); + + index_ = std::make_unique(rowIndex); + entry_ = std::make_unique(rowIndexEntry); + } virtual ~IndexBuilder() = default; void add(uint64_t pos, int32_t index = -1) override { - getEntry(index)->add_positions(pos); + getEntry(index).addPositions(pos); } virtual void addEntry(const StatisticsBuilder& writer) { - auto* stats = entry_.mutable_statistics(); - writer.toProto(*stats); - *index_.add_entry() = entry_; - entry_.Clear(); + auto stats = entry_->mutableStatistics(); + writer.toProto(stats); + index_->addEntry(entry_); + entry_->clear(); } virtual size_t getEntrySize() const { - const int32_t size = index_.entry_size() + 1; + const int32_t size = index_->entrySize() + 1; VELOX_CHECK_GT(size, 0, "Invalid entry size or missing current entry."); return size; } virtual void flush() { // remove isPresent positions if none is null - index_.SerializeToZeroCopyStream(out_.get()); + index_->SerializeToZeroCopyStream(out_.get()); out_->flush(); - index_.Clear(); - entry_.Clear(); + index_->clear(); + entry_->clear(); } void capturePresentStreamOffset() { if (!presentStreamOffset_.has_value()) { - presentStreamOffset_ = entry_.positions_size(); + presentStreamOffset_ = entry_->positionsSize(); } else { - DWIO_ENSURE_EQ(presentStreamOffset_.value(), entry_.positions_size()); + DWIO_ENSURE_EQ(presentStreamOffset_.value(), entry_->positionsSize()); } } @@ -79,27 +90,28 @@ class IndexBuilder : public PositionRecorder { : PRESENT_STREAM_INDEX_ENTRIES_UNPAGED; // Only need to process entries that have been added to the row index - for (uint32_t i = 0; i < index_.entry_size(); ++i) { - index_.mutable_entry(i)->mutable_positions()->ExtractSubrange( - presentStreamOffset_.value(), streamCount, nullptr); + for (uint32_t i = 0; i < index_->entrySize(); ++i) { + index_->mutableEntry(i).mutablePositions( + presentStreamOffset_.value(), streamCount); } } private: - proto::RowIndexEntry* getEntry(int32_t index) { + RowIndexEntryWriteWrapper getEntry(int32_t index) { if (index < 0) { - return &entry_; - } else if (index < index_.entry_size()) { - return index_.mutable_entry(index); + return *entry_; + } else if (index < index_->entrySize()) { + return index_->mutableEntry(index); } else { - VELOX_CHECK_EQ(index, index_.entry_size()); - return &entry_; + VELOX_CHECK_EQ(index, index_->entrySize()); + return *entry_; } } const std::unique_ptr out_; - proto::RowIndex index_; - proto::RowIndexEntry entry_; + std::unique_ptr index_; + std::unique_ptr entry_; + std::unique_ptr arena_; std::optional presentStreamOffset_; friend class IndexBuilderTest; diff --git a/velox/dwio/dwrf/writer/LayoutPlanner.cpp b/velox/dwio/dwrf/writer/LayoutPlanner.cpp index 4fcc5668480..334097a435f 100644 --- a/velox/dwio/dwrf/writer/LayoutPlanner.cpp +++ b/velox/dwio/dwrf/writer/LayoutPlanner.cpp @@ -16,14 +16,19 @@ #include "velox/dwio/dwrf/writer/LayoutPlanner.h" +#include "velox/dwio/common/Arena.h" + namespace facebook::velox::dwrf { +using dwio::common::ArenaCreate; + StreamList getStreamList(WriterContext& context) { StreamList streams; streams.reserve(context.getStreamCount()); context.iterateUnSuppressedStreams([&](auto& pair) { - streams.push_back(std::make_pair( - std::addressof(pair.first), std::addressof(pair.second))); + streams.push_back( + std::make_pair( + std::addressof(pair.first), std::addressof(pair.second))); }); return streams; } @@ -124,49 +129,54 @@ EncodingIter::pointer EncodingIter::operator->() const { EncodingManager::EncodingManager( const encryption::EncryptionHandler& encryptionHandler) - : encryptionHandler_{encryptionHandler} { + : encryptionHandler_{encryptionHandler}, + arena_{std::make_unique()} { initEncryptionGroups(); + auto dwrfStripeFooter = ArenaCreate(arena_.get()); + footer_ = std::make_unique(dwrfStripeFooter); } -proto::ColumnEncoding& EncodingManager::addEncodingToFooter(uint32_t nodeId) { +ColumnEncodingWriteWrapper EncodingManager::addEncodingToFooter( + uint32_t nodeId) { if (encryptionHandler_.isEncrypted(nodeId)) { auto index = encryptionHandler_.getEncryptionGroupIndex(nodeId); - return *encryptionGroups_.at(index).add_encoding(); + return ColumnEncodingWriteWrapper( + encryptionGroups_.at(index).add_encoding()); } else { - return *footer_.add_encoding(); + return footer_->addEncoding(); } } -proto::Stream* EncodingManager::addStreamToFooter( +StreamWriteWrapper EncodingManager::addStreamToFooter( uint32_t nodeId, uint32_t& currentIndex) { if (encryptionHandler_.isEncrypted(nodeId)) { currentIndex = encryptionHandler_.getEncryptionGroupIndex(nodeId); - return encryptionGroups_.at(currentIndex).add_streams(); + return StreamWriteWrapper(encryptionGroups_.at(currentIndex).add_streams()); } else { currentIndex = std::numeric_limits::max(); - return footer_.add_streams(); + return footer_->addStreams(); } } std::string* EncodingManager::addEncryptionGroupToFooter() { - return footer_.add_encryptiongroups(); + return footer_->addEncryptionGroups(); } proto::StripeEncryptionGroup EncodingManager::getEncryptionGroup(uint32_t i) { return encryptionGroups_.at(i); } -const proto::StripeFooter& EncodingManager::getFooter() const { - return footer_; +const StripeFooterWriteWrapper& EncodingManager::getFooter() const { + return *footer_; } EncodingIter EncodingManager::begin() const { - return EncodingIter::begin(footer_, encryptionGroups_); + return EncodingIter::begin(*footer_->dwrfPtr(), encryptionGroups_); } EncodingIter EncodingManager::end() const { - return EncodingIter::end(footer_, encryptionGroups_); + return EncodingIter::end(*footer_->dwrfPtr(), encryptionGroups_); } void EncodingManager::initEncryptionGroups() { diff --git a/velox/dwio/dwrf/writer/LayoutPlanner.h b/velox/dwio/dwrf/writer/LayoutPlanner.h index 13dc965a652..ac22f9d736f 100644 --- a/velox/dwio/dwrf/writer/LayoutPlanner.h +++ b/velox/dwio/dwrf/writer/LayoutPlanner.h @@ -89,11 +89,11 @@ class EncodingManager : public EncodingContainer { const encryption::EncryptionHandler& encryptionHandler); virtual ~EncodingManager() override = default; - proto::ColumnEncoding& addEncodingToFooter(uint32_t nodeId); - proto::Stream* addStreamToFooter(uint32_t nodeId, uint32_t& currentIndex); + ColumnEncodingWriteWrapper addEncodingToFooter(uint32_t nodeId); + StreamWriteWrapper addStreamToFooter(uint32_t nodeId, uint32_t& currentIndex); std::string* addEncryptionGroupToFooter(); proto::StripeEncryptionGroup getEncryptionGroup(uint32_t i); - const proto::StripeFooter& getFooter() const; + const StripeFooterWriteWrapper& getFooter() const; EncodingIter begin() const override; EncodingIter end() const override; @@ -102,7 +102,8 @@ class EncodingManager : public EncodingContainer { void initEncryptionGroups(); const encryption::EncryptionHandler& encryptionHandler_; - proto::StripeFooter footer_; + std::unique_ptr footer_; + std::unique_ptr arena_; std::vector encryptionGroups_; }; diff --git a/velox/dwio/dwrf/writer/StatisticsBuilder.cpp b/velox/dwio/dwrf/writer/StatisticsBuilder.cpp index 3000230c115..20d6989d9ef 100644 --- a/velox/dwio/dwrf/writer/StatisticsBuilder.cpp +++ b/velox/dwio/dwrf/writer/StatisticsBuilder.cpp @@ -16,113 +16,52 @@ #include "velox/dwio/dwrf/writer/StatisticsBuilder.h" +#include "velox/dwio/common/Arena.h" + namespace facebook::velox::dwrf { +using dwio::common::ArenaCreate; + namespace { -static bool isValidLength(const std::optional& length) { +bool isValidLength(const std::optional& length) { return length.has_value() && length.value() <= std::numeric_limits::max(); } -template -static void mergeCount(std::optional& to, const std::optional& from) { - if (to.has_value()) { - if (from.has_value()) { - to.value() += from.value(); - } else { - to.reset(); - } +// Serializes base ColumnStatistics fields to proto. +void baseToProto( + const dwio::common::ColumnStatistics& builder, + ColumnStatisticsWriteWrapper& stats) { + if (builder.hasNull().has_value()) { + stats.setHasNull(builder.hasNull().value()); } -} - -template -static void mergeMin(std::optional& to, const std::optional& from) { - if (to.has_value()) { - if (!from.has_value()) { - to.reset(); - } else if (from.value() < to.value()) { - to = from; - } + if (builder.getNumberOfValues().has_value()) { + stats.setNumberOfValues(builder.getNumberOfValues().value()); } -} - -template -static void mergeMax(std::optional& to, const std::optional& from) { - if (to.has_value()) { - if (!from.has_value()) { - to.reset(); - } else if (from.value() > to.value()) { - to = from; - } + if (builder.getRawSize().has_value()) { + stats.setRawSize(builder.getRawSize().value()); + } + if (builder.getSize().has_value()) { + stats.setSize(builder.getSize().value()); } } } // namespace -void StatisticsBuilder::merge( - const dwio::common::ColumnStatistics& other, - bool ignoreSize) { - // Merge valueCount_ only if both sides have it. Otherwise, reset. - mergeCount(valueCount_, other.getNumberOfValues()); - - // Merge hasNull_. Follow below rule: - // self / other => result - // true / any => true - // unknown / true => true - // unknown / unknown or false => unknown - // false / unknown => unknown - // false / false => false - // false / true => true - if (!hasNull_.has_value() || !hasNull_.value()) { - auto otherHasNull = other.hasNull(); - if (otherHasNull.has_value()) { - if (otherHasNull.value()) { - // other is true, set to true - hasNull_ = true; - } - // when other is false, no change is needed - } else if (hasNull_.has_value()) { - // self value is false and other is unknown, set to unknown - hasNull_.reset(); - } - } - // Merge rawSize_ the way similar to valueCount_ - mergeCount(rawSize_, other.getRawSize()); - if (!ignoreSize) { - // Merge size - mergeCount(size_, other.getSize()); - } - if (hll_) { - auto* otherBuilder = dynamic_cast(&other); - VELOX_CHECK_NOT_NULL(otherBuilder); - VELOX_CHECK_NOT_NULL(otherBuilder->hll_); - hll_->mergeWith(*otherBuilder->hll_); - } -} - -void StatisticsBuilder::toProto(proto::ColumnStatistics& stats) const { - if (hasNull_.has_value()) { - stats.set_hasnull(hasNull_.value()); - } - if (valueCount_.has_value()) { - stats.set_numberofvalues(valueCount_.value()); - } - if (rawSize_.has_value()) { - stats.set_rawsize(rawSize_.value()); - } - if (size_.has_value()) { - stats.set_size(size_.value()); - } +void StatisticsBuilder::toProto(ColumnStatisticsWriteWrapper& stats) const { + baseToProto(*this, stats); } std::unique_ptr StatisticsBuilder::build() const { - proto::ColumnStatistics stats; + auto columnStatistics = ArenaCreate(arena_.get()); + auto stats = ColumnStatisticsWriteWrapper(columnStatistics); toProto(stats); + StatsContext context{WriterVersion_CURRENT}; - auto result = - buildColumnStatisticsFromProto(ColumnStatisticsWrapper(&stats), context); + auto result = buildColumnStatisticsFromProto( + ColumnStatisticsWrapper(columnStatistics), context); // We do not alter the proto since this is part of the file format // and the file format. The distinct count does not exist in the // file format but is added here for use in on demand sampling. @@ -208,204 +147,106 @@ void StatisticsBuilder::createTree( DWIO_RAISE("Not supported type: ", kind); break; } - return; -}; - -void BooleanStatisticsBuilder::merge( - const dwio::common::ColumnStatistics& other, - bool ignoreSize) { - StatisticsBuilder::merge(other, ignoreSize); - auto stats = - dynamic_cast(&other); - if (!stats) { - // We only care about the case when type specific stats is missing yet - // it has non-null values. - if (!isEmpty(other) && trueCount_.has_value()) { - trueCount_.reset(); - } - return; - } - - // Now the case when both sides have type specific stats - mergeCount(trueCount_, stats->getTrueCount()); } -void BooleanStatisticsBuilder::toProto(proto::ColumnStatistics& stats) const { - StatisticsBuilder::toProto(stats); - // Serialize type specific stats only if there is non-null values - if (!isEmpty(*this) && trueCount_.has_value()) { - auto bStats = stats.mutable_bucketstatistics(); - DWIO_ENSURE_EQ(bStats->count_size(), 0); - bStats->add_count(trueCount_.value()); +void BooleanStatisticsBuilder::toProto( + ColumnStatisticsWriteWrapper& stats) const { + baseToProto(*this, stats); + if (!isAllNull() && trueCount_.has_value()) { + auto bStats = stats.mutableBucketStatistics(); + DWIO_ENSURE_EQ(bStats.countSize(), 0); + bStats.addCount(trueCount_.value()); } } -void IntegerStatisticsBuilder::merge( - const dwio::common::ColumnStatistics& other, - bool ignoreSize) { - StatisticsBuilder::merge(other, ignoreSize); - auto stats = - dynamic_cast(&other); - if (!stats) { - // We only care about the case when type specific stats is missing yet - // it has non-null values. - if (!isEmpty(other)) { - min_.reset(); - max_.reset(); - sum_.reset(); - } - return; - } - - // Now the case when both sides have type specific stats - mergeMin(min_, stats->getMinimum()); - mergeMax(max_, stats->getMaximum()); - mergeWithOverflowCheck(sum_, stats->getSum()); -} - -void IntegerStatisticsBuilder::toProto(proto::ColumnStatistics& stats) const { - StatisticsBuilder::toProto(stats); - // Serialize type specific stats only if there is non-null values - if (!isEmpty(*this) && +void IntegerStatisticsBuilder::toProto( + ColumnStatisticsWriteWrapper& stats) const { + baseToProto(*this, stats); + if (!isAllNull() && (min_.has_value() || max_.has_value() || sum_.has_value())) { - auto iStats = stats.mutable_intstatistics(); + auto iStats = stats.mutableIntegerStatistics(); if (min_.has_value()) { - iStats->set_minimum(min_.value()); + iStats.setMinimum(min_.value()); } if (max_.has_value()) { - iStats->set_maximum(max_.value()); + iStats.setMaximum(max_.value()); } if (sum_.has_value()) { - iStats->set_sum(sum_.value()); - } - } -} - -void DoubleStatisticsBuilder::merge( - const dwio::common::ColumnStatistics& other, - bool ignoreSize) { - StatisticsBuilder::merge(other, ignoreSize); - auto stats = - dynamic_cast(&other); - if (!stats) { - // We only care about the case when type specific stats is missing yet - // it has non-null values. - if (!isEmpty(other)) { - clear(); + iStats.setSum(sum_.value()); } - return; - } - - // Now the case when both sides have type specific stats - mergeMin(min_, stats->getMinimum()); - mergeMax(max_, stats->getMaximum()); - mergeCount(sum_, stats->getSum()); - if (sum_.has_value() && std::isnan(sum_.value())) { - sum_.reset(); } } -void DoubleStatisticsBuilder::toProto(proto::ColumnStatistics& stats) const { - StatisticsBuilder::toProto(stats); - // Serialize type specific stats only if there is non-null values - if (!isEmpty(*this) && +void DoubleStatisticsBuilder::toProto( + ColumnStatisticsWriteWrapper& stats) const { + baseToProto(*this, stats); + if (!isAllNull() && (min_.has_value() || max_.has_value() || sum_.has_value())) { - auto dStats = stats.mutable_doublestatistics(); + auto dStats = stats.mutableDoubleStatistics(); if (min_.has_value()) { - dStats->set_minimum(min_.value()); + dStats.setMinimum(min_.value()); } if (max_.has_value()) { - dStats->set_maximum(max_.value()); + dStats.setMaximum(max_.value()); } if (sum_.has_value()) { - dStats->set_sum(sum_.value()); + dStats.setSum(sum_.value()); } } } -void StringStatisticsBuilder::merge( - const dwio::common::ColumnStatistics& other, - bool ignoreSize) { - // min_/max_ is not initialized with default that can be compared against - // easily. So we need to capture whether self is empty and handle - // differently. - auto isSelfEmpty = isEmpty(*this); - StatisticsBuilder::merge(other, ignoreSize); - auto stats = - dynamic_cast(&other); - if (!stats) { - // We only care about the case when type specific stats is missing yet - // it has non-null values. - if (!isEmpty(other)) { - min_.reset(); - max_.reset(); - length_.reset(); - } - return; - } - - // If the other stats is empty, there is nothing to merge at string stats - // level. - if (isEmpty(other)) { - return; - } - - if (isSelfEmpty) { - min_ = stats->getMinimum(); - max_ = stats->getMaximum(); - } else { - mergeMin(min_, stats->getMinimum()); - mergeMax(max_, stats->getMaximum()); - } - - mergeWithOverflowCheck(length_, stats->getTotalLength()); -} - -void StringStatisticsBuilder::toProto(proto::ColumnStatistics& stats) const { - StatisticsBuilder::toProto(stats); - // If string value is too long, drop it and fall back to basic stats - if (!isEmpty(*this) && +void StringStatisticsBuilder::toProto( + ColumnStatisticsWriteWrapper& stats) const { + baseToProto(*this, stats); + if (!isAllNull() && (shouldKeep(min_) || shouldKeep(max_) || isValidLength(length_))) { - auto dStats = stats.mutable_stringstatistics(); + auto dStats = stats.mutableStringStatistics(); if (isValidLength(length_)) { - dStats->set_sum(length_.value()); + dStats.setSum(length_.value()); } if (shouldKeep(min_)) { - dStats->set_minimum(min_.value()); + dStats.setMinimum(min_.value()); } if (shouldKeep(max_)) { - dStats->set_maximum(max_.value()); + dStats.setMaximum(max_.value()); } } } -void BinaryStatisticsBuilder::merge( - const dwio::common::ColumnStatistics& other, - bool ignoreSize) { - StatisticsBuilder::merge(other, ignoreSize); - auto stats = - dynamic_cast(&other); - if (!stats) { - // We only care about the case when type specific stats is missing yet - // it has non-null values. - if (!isEmpty(other) && length_.has_value()) { - length_.reset(); - } - return; +void BinaryStatisticsBuilder::toProto( + ColumnStatisticsWriteWrapper& stats) const { + baseToProto(*this, stats); + if (!isAllNull() && isValidLength(length_)) { + auto bStats = stats.mutableBinaryStatistics(); + bStats.setSum(length_.value()); } - - mergeWithOverflowCheck(length_, stats->getTotalLength()); } -void BinaryStatisticsBuilder::toProto(proto::ColumnStatistics& stats) const { - StatisticsBuilder::toProto(stats); - // Serialize type specific stats only if there is non-null values - if (!isEmpty(*this) && isValidLength(length_)) { - auto bStats = stats.mutable_binarystatistics(); - bStats->set_sum(length_.value()); +dwio::common::KeyInfo MapStatisticsBuilder::constructKey( + const dwrf::proto::KeyInfo& keyInfo) { + if (keyInfo.has_intkey()) { + return dwio::common::KeyInfo{keyInfo.intkey()}; + } else if (keyInfo.has_byteskey()) { + return dwio::common::KeyInfo{keyInfo.byteskey()}; } + VELOX_UNREACHABLE("Illegal null key info"); +} + +void MapStatisticsBuilder::addValues( + const dwrf::proto::KeyInfo& keyInfo, + const StatisticsBuilder& stats) { + auto& keyStats = getKeyStats(MapStatisticsBuilder::constructKey(keyInfo)); + keyStats.merge(stats, /*ignoreSize=*/true); +} + +void MapStatisticsBuilder::incrementSize( + const dwrf::proto::KeyInfo& keyInfo, + uint64_t size) { + auto& keyStats = getKeyStats(MapStatisticsBuilder::constructKey(keyInfo)); + keyStats.ensureSize(); + keyStats.incrementSize(size); } void MapStatisticsBuilder::merge( @@ -414,9 +255,7 @@ void MapStatisticsBuilder::merge( StatisticsBuilder::merge(other, ignoreSize); auto stats = dynamic_cast(&other); if (!stats) { - // We only care about the case when type specific stats is missing yet - // it has non-null values. - if (!isEmpty(other) && !entryStatistics_.empty()) { + if (!other.isAllNull() && !entryStatistics_.empty()) { entryStatistics_.clear(); } return; @@ -427,21 +266,20 @@ void MapStatisticsBuilder::merge( } } -void MapStatisticsBuilder::toProto(proto::ColumnStatistics& stats) const { +void MapStatisticsBuilder::toProto(ColumnStatisticsWriteWrapper& stats) const { StatisticsBuilder::toProto(stats); - if (!isEmpty(*this) && !entryStatistics_.empty()) { - auto mapStats = stats.mutable_mapstatistics(); + if (!isAllNull() && !entryStatistics_.empty()) { + auto mapStats = stats.mutableMapStatistics(); for (const auto& entry : entryStatistics_) { auto entryStatistics = mapStats->add_stats(); const auto& key = entry.first; - // Sets the corresponding key. Leave null keys null. if (key.intKey.has_value()) { entryStatistics->mutable_key()->set_intkey(key.intKey.value()); } else if (key.bytesKey.has_value()) { entryStatistics->mutable_key()->set_byteskey(key.bytesKey.value()); } - dynamic_cast(*entry.second) - .toProto(*entryStatistics->mutable_stats()); + auto c = ColumnStatisticsWriteWrapper(entryStatistics->mutable_stats()); + dynamic_cast(*entry.second).toProto(c); } } } diff --git a/velox/dwio/dwrf/writer/StatisticsBuilder.h b/velox/dwio/dwrf/writer/StatisticsBuilder.h index 3441a5ddf9a..7d278163455 100644 --- a/velox/dwio/dwrf/writer/StatisticsBuilder.h +++ b/velox/dwio/dwrf/writer/StatisticsBuilder.h @@ -16,425 +16,136 @@ #pragma once -#include -#include +#include "velox/dwio/common/StatisticsBuilder.h" #include "velox/dwio/dwrf/common/Config.h" #include "velox/dwio/dwrf/common/Statistics.h" #include "velox/dwio/dwrf/common/wrap/dwrf-proto-wrapper.h" -#include "velox/type/Type.h" namespace facebook::velox::dwrf { -namespace { -inline bool isEmpty(const dwio::common::ColumnStatistics& stats) { - auto valueCount = stats.getNumberOfValues(); - return valueCount.has_value() && valueCount.value() == 0; -} - -template -static void -addWithOverflowCheck(std::optional& to, T value, uint64_t count) { - if (to.has_value()) { - // check overflow. Value is only valid when not overflow - T result; - auto overflow = __builtin_mul_overflow(value, count, &result); - if (!overflow) { - overflow = __builtin_add_overflow(to.value(), result, &to.value()); - } - if (overflow) { - to.reset(); - } - } -} - -template -static void mergeWithOverflowCheck( - std::optional& to, - const std::optional& from) { - if (to.has_value()) { - if (from.has_value()) { - auto overflow = - __builtin_add_overflow(to.value(), from.value(), &to.value()); - if (overflow) { - to.reset(); - } - } else { - to.reset(); - } - } -} - -inline dwio::common::KeyInfo constructKey(const dwrf::proto::KeyInfo& keyInfo) { - if (keyInfo.has_intkey()) { - return dwio::common::KeyInfo{keyInfo.intkey()}; - } else if (keyInfo.has_byteskey()) { - return dwio::common::KeyInfo{keyInfo.byteskey()}; - } - VELOX_UNREACHABLE("Illegal null key info"); -} -} // namespace - -struct StatisticsBuilderOptions { - explicit StatisticsBuilderOptions( - uint32_t stringLengthLimit, - std::optional initialSize = std::nullopt, - bool countDistincts = false, - HashStringAllocator* allocator = nullptr) - : stringLengthLimit{stringLengthLimit}, - initialSize{initialSize}, - countDistincts(countDistincts), - allocator(allocator) {} - - uint32_t stringLengthLimit; - std::optional initialSize; - bool countDistincts{false}; - HashStringAllocator* allocator; - - StatisticsBuilderOptions withoutNumDistinct() const { - return StatisticsBuilderOptions(stringLengthLimit, initialSize); - } - - static StatisticsBuilderOptions fromConfig(const Config& config) { - return StatisticsBuilderOptions{config.get(Config::STRING_STATS_LIMIT)}; - } -}; +// Re-export common types into dwrf namespace for backward compatibility. +using dwio::stats::StatisticsBuilderOptions; -/* - * Base class for stats builder. Stats builder is used in writer and file merge - * to collect and merge stats. - * It can also be used for gathering stats in ad hoc sampling. In this case it - * may also count distinct values if enabled in 'options'. - */ -class StatisticsBuilder : public virtual dwio::common::ColumnStatistics { +/// DWRF-specific StatisticsBuilder that adds proto serialization and +/// proto-based build() on top of the common StatisticsBuilder. +class StatisticsBuilder : public virtual dwio::stats::StatisticsBuilder { public: - /// Constructs with 'options'. explicit StatisticsBuilder(const StatisticsBuilderOptions& options) - : options_{options} { - init(); - } + : dwio::stats::StatisticsBuilder(options), + arena_(std::make_unique()) {} ~StatisticsBuilder() override = default; - void setHasNull() { - hasNull_ = true; - } + /// Serializes statistics to a proto wrapper. + virtual void toProto(ColumnStatisticsWriteWrapper& stats) const; - void increaseValueCount(uint64_t count = 1) { - if (valueCount_.has_value()) { - valueCount_.value() += count; - } - } - - void increaseRawSize(uint64_t rawSize) { - if (rawSize_.has_value()) { - rawSize_.value() += rawSize; - } - } - - void clearRawSize() { - rawSize_.reset(); - } - - void ensureSize() { - if (!size_.has_value()) { - size_ = 0; - } - } - - void incrementSize(uint64_t size) { - if (LIKELY(size_.has_value())) { - addWithOverflowCheck(size_, size, /*count=*/1); - } - } - - template - void addHash(const T& data) { - if (hll_) { - hll_->insertHash(folly::hasher()(data)); - } - } - - int64_t cardinality() const { - VELOX_CHECK_NOT_NULL(hll_); - return hll_->cardinality(); - } - - /* - * Merge stats of same type. This is used in writer to aggregate file level - * stats. - */ - virtual void merge( - const dwio::common::ColumnStatistics& other, - bool ignoreSize = false); - - /* - * Reset. Used in the place where row index entry level stats in captured. - */ - virtual void reset() { - init(); - } - - /* - * Write stats to proto - */ - virtual void toProto(proto::ColumnStatistics& stats) const; - - std::unique_ptr build() const; + /// Builds a read-only ColumnStatistics by round-tripping through proto. + std::unique_ptr build() const override; + /// Creates a DWRF-specific StatisticsBuilder for the given type. For MAP + /// type, returns a MapStatisticsBuilder. static std::unique_ptr create( const Type& type, const StatisticsBuilderOptions& options); - // for the given type tree, create the a list of stat builders + /// For the given type tree, creates a list of DWRF stat builders. static void createTree( std::vector>& statBuilders, const Type& type, const StatisticsBuilderOptions& options); - private: - void init() { - valueCount_ = 0; - hasNull_ = false; - rawSize_ = 0; - size_ = options_.initialSize; - if (options_.countDistincts) { - hll_ = std::make_shared(options_.allocator); - } + /// Creates StatisticsBuilderOptions from a DWRF Config. + static StatisticsBuilderOptions optionsFromConfig(const Config& config) { + return StatisticsBuilderOptions{config.get(Config::STRING_STATS_LIMIT)}; } - protected: - StatisticsBuilderOptions options_; - std::shared_ptr hll_; + private: + std::unique_ptr arena_; }; class BooleanStatisticsBuilder : public StatisticsBuilder, - public dwio::common::BooleanColumnStatistics { + public dwio::stats::BooleanStatisticsBuilder { public: explicit BooleanStatisticsBuilder(const StatisticsBuilderOptions& options) - : StatisticsBuilder{options.withoutNumDistinct()} { - init(); - } + : dwio::stats::StatisticsBuilder{options.dropNumDistinct()}, + StatisticsBuilder{options.dropNumDistinct()}, + dwio::stats::BooleanStatisticsBuilder{options} {} ~BooleanStatisticsBuilder() override = default; - void addValues(bool value, uint64_t count = 1) { - increaseValueCount(count); - if (trueCount_.has_value() && value) { - trueCount_.value() += count; - } - } - - void merge( - const dwio::common::ColumnStatistics& other, - bool ignoreSize = false) override; - - void reset() override { - StatisticsBuilder::reset(); - init(); + std::unique_ptr build() const override { + return StatisticsBuilder::build(); } - void toProto(proto::ColumnStatistics& stats) const override; - - private: - void init() { - trueCount_ = 0; - } + void toProto(ColumnStatisticsWriteWrapper& stats) const override; }; class IntegerStatisticsBuilder : public StatisticsBuilder, - public dwio::common::IntegerColumnStatistics { + public dwio::stats::IntegerStatisticsBuilder { public: explicit IntegerStatisticsBuilder(const StatisticsBuilderOptions& options) - : StatisticsBuilder{options} { - init(); - } + : dwio::stats::StatisticsBuilder{options}, + StatisticsBuilder{options}, + dwio::stats::IntegerStatisticsBuilder{options} {} ~IntegerStatisticsBuilder() override = default; - void addValues(int64_t value, uint64_t count = 1) { - increaseValueCount(count); - if (min_.has_value() && value < min_.value()) { - min_ = value; - } - if (max_.has_value() && value > max_.value()) { - max_ = value; - } - addWithOverflowCheck(sum_, value, count); - addHash(value); - } - - void merge( - const dwio::common::ColumnStatistics& other, - bool ignoreSize = false) override; - - void reset() override { - StatisticsBuilder::reset(); - init(); + std::unique_ptr build() const override { + return StatisticsBuilder::build(); } - void toProto(proto::ColumnStatistics& stats) const override; - - private: - void init() { - min_ = std::numeric_limits::max(); - max_ = std::numeric_limits::min(); - sum_ = 0; - } + void toProto(ColumnStatisticsWriteWrapper& stats) const override; }; -static_assert( - std::numeric_limits::has_infinity, - "infinity not defined"); - class DoubleStatisticsBuilder : public StatisticsBuilder, - public dwio::common::DoubleColumnStatistics { + public dwio::stats::DoubleStatisticsBuilder { public: explicit DoubleStatisticsBuilder(const StatisticsBuilderOptions& options) - : StatisticsBuilder{options} { - init(); - } + : dwio::stats::StatisticsBuilder{options}, + StatisticsBuilder{options}, + dwio::stats::DoubleStatisticsBuilder{options} {} ~DoubleStatisticsBuilder() override = default; - void addValues(double value, uint64_t count = 1) { - increaseValueCount(count); - // min/max/sum is defined only when none of the values added is NaN - if (std::isnan(value)) { - clear(); - return; - } - - if (min_.has_value() && value < min_.value()) { - min_ = value; - } - if (max_.has_value() && value > max_.value()) { - max_ = value; - } - addHash(value); - // value * count sometimes is not same as adding values (count) times. So - // add in a loop - if (sum_.has_value()) { - for (uint64_t i = 0; i < count; ++i) { - sum_.value() += value; - } - if (std::isnan(sum_.value())) { - sum_.reset(); - } - } + std::unique_ptr build() const override { + return StatisticsBuilder::build(); } - void merge( - const dwio::common::ColumnStatistics& other, - bool ignoreSize = false) override; - - void reset() override { - StatisticsBuilder::reset(); - init(); - } - - void toProto(proto::ColumnStatistics& stats) const override; - - private: - void init() { - min_ = std::numeric_limits::infinity(); - max_ = -std::numeric_limits::infinity(); - sum_ = 0; - } - - void clear() { - min_.reset(); - max_.reset(); - sum_.reset(); - } + void toProto(ColumnStatisticsWriteWrapper& stats) const override; }; class StringStatisticsBuilder : public StatisticsBuilder, - public dwio::common::StringColumnStatistics { + public dwio::stats::StringStatisticsBuilder { public: explicit StringStatisticsBuilder(const StatisticsBuilderOptions& options) - : StatisticsBuilder{options}, lengthLimit_{options.stringLengthLimit} { - init(); - } + : dwio::stats::StatisticsBuilder{options}, + StatisticsBuilder{options}, + dwio::stats::StringStatisticsBuilder{options} {} ~StringStatisticsBuilder() override = default; - void addValues(folly::StringPiece value, uint64_t count = 1) { - // min_/max_ is not initialized with default that can be compared against - // easily. So we need to capture whether self is empty and handle - // differently. - auto isSelfEmpty = isEmpty(*this); - increaseValueCount(count); - if (isSelfEmpty) { - min_ = value; - max_ = value; - } else { - if (min_.has_value() && value < folly::StringPiece{min_.value()}) { - min_ = value; - } - if (max_.has_value() && value > folly::StringPiece{max_.value()}) { - max_ = value; - } - } - addHash(value); - - addWithOverflowCheck(length_, value.size(), count); + std::unique_ptr build() const override { + return StatisticsBuilder::build(); } - void merge( - const dwio::common::ColumnStatistics& other, - bool ignoreSize = false) override; - - void reset() override { - StatisticsBuilder::reset(); - init(); - } - - void toProto(proto::ColumnStatistics& stats) const override; - - private: - uint32_t lengthLimit_; - - void init() { - min_.reset(); - max_.reset(); - length_ = 0; - } - - bool shouldKeep(const std::optional& val) const { - return val.has_value() && val.value().size() <= lengthLimit_; - } + void toProto(ColumnStatisticsWriteWrapper& stats) const override; }; class BinaryStatisticsBuilder : public StatisticsBuilder, - public dwio::common::BinaryColumnStatistics { + public dwio::stats::BinaryStatisticsBuilder { public: explicit BinaryStatisticsBuilder(const StatisticsBuilderOptions& options) - : StatisticsBuilder{options.withoutNumDistinct()} { - init(); - } + : dwio::stats::StatisticsBuilder{options.dropNumDistinct()}, + StatisticsBuilder{options.dropNumDistinct()}, + dwio::stats::BinaryStatisticsBuilder{options} {} ~BinaryStatisticsBuilder() override = default; - void addValues(uint64_t length, uint64_t count = 1) { - increaseValueCount(count); - addWithOverflowCheck(length_, length, count); - } - - void merge( - const dwio::common::ColumnStatistics& other, - bool ignoreSize = false) override; - - void reset() override { - StatisticsBuilder::reset(); - init(); + std::unique_ptr build() const override { + return StatisticsBuilder::build(); } - void toProto(proto::ColumnStatistics& stats) const override; - - private: - void init() { - length_ = 0; - } + void toProto(ColumnStatisticsWriteWrapper& stats) const override; }; class MapStatisticsBuilder : public StatisticsBuilder, @@ -443,7 +154,8 @@ class MapStatisticsBuilder : public StatisticsBuilder, MapStatisticsBuilder( const Type& type, const StatisticsBuilderOptions& options) - : StatisticsBuilder{options}, + : dwio::stats::StatisticsBuilder{options}, + StatisticsBuilder{options}, valueType_{type.as().valueType()} { init(); hll_.reset(); @@ -453,20 +165,9 @@ class MapStatisticsBuilder : public StatisticsBuilder, void addValues( const dwrf::proto::KeyInfo& keyInfo, - const StatisticsBuilder& stats) { - // Since addValues is called once per key info per stride, - // it's ok to just construct the key struct per call. - auto& keyStats = getKeyStats(constructKey(keyInfo)); - keyStats.merge(stats, /*ignoreSize=*/true); - } + const StatisticsBuilder& stats); - void incrementSize(const dwrf::proto::KeyInfo& keyInfo, uint64_t size) { - // Since incrementSize is called once per key info per stripe, - // it's ok to just construct the key struct per call. - auto& keyStats = getKeyStats(constructKey(keyInfo)); - keyStats.ensureSize(); - keyStats.incrementSize(size); - } + void incrementSize(const dwrf::proto::KeyInfo& keyInfo, uint64_t size); void merge( const dwio::common::ColumnStatistics& other, @@ -477,7 +178,11 @@ class MapStatisticsBuilder : public StatisticsBuilder, init(); } - void toProto(proto::ColumnStatistics& stats) const override; + void toProto(ColumnStatisticsWriteWrapper& stats) const override; + + /// Converts a proto KeyInfo to a dwio::common::KeyInfo. + static dwio::common::KeyInfo constructKey( + const dwrf::proto::KeyInfo& keyInfo); private: void init() { diff --git a/velox/dwio/dwrf/writer/StatisticsBuilderUtils.cpp b/velox/dwio/dwrf/writer/StatisticsBuilderUtils.cpp index 3babe373db2..89dbcd1cfef 100644 --- a/velox/dwio/dwrf/writer/StatisticsBuilderUtils.cpp +++ b/velox/dwio/dwrf/writer/StatisticsBuilderUtils.cpp @@ -81,18 +81,18 @@ void StatisticsBuilderUtils::addValues( const VectorPtr& vector, const common::Ranges& ranges) { auto nulls = vector->rawNulls(); - auto data = vector->asFlatVector()->rawValues(); + auto* data = vector->asFlatVector()->rawValues(); if (vector->mayHaveNulls()) { for (auto& pos : ranges) { if (bits::isBitNull(nulls, pos)) { builder.setHasNull(); } else { - builder.addValues(folly::StringPiece{data[pos]}); + builder.addValues(std::string_view{data[pos]}); } } } else { for (auto& pos : ranges) { - builder.addValues(folly::StringPiece{data[pos]}); + builder.addValues(std::string_view{data[pos]}); } } } diff --git a/velox/dwio/dwrf/writer/StringDictionaryEncoder.h b/velox/dwio/dwrf/writer/StringDictionaryEncoder.h index 8ecb635f865..20abfaac858 100644 --- a/velox/dwio/dwrf/writer/StringDictionaryEncoder.h +++ b/velox/dwio/dwrf/writer/StringDictionaryEncoder.h @@ -30,22 +30,22 @@ namespace detail { // Each new string inserted into dictionary is assigned an incrementing id. // A set is maintained with all of the DictStringId created. Using -// Heterogeneous lookup techniques, incoming StringPiece is first looked for +// Heterogeneous lookup techniques, incoming string_view is first looked for // a match in the set. If no match exists a new id is generated and inserted // into the set. What is Heterogeneous lookup ? // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2018/p0919r1.html // Heterogeneous lookup is not available in standard CPP and proposed for CPP20. // Follys:F14* variant supports it, so leveraging folly for now. struct StringLookupKey { - StringLookupKey(folly::StringPiece sp, uint32_t index) - : sp{sp}, + StringLookupKey(std::string_view sv, uint32_t index) + : sv{sv}, index{index}, hash{folly::crc32c( - reinterpret_cast(sp.data()), - sp.size(), + reinterpret_cast(sv.data()), + sv.size(), 0 /* seed */)} {} - const folly::StringPiece sp; + const std::string_view sv; const uint32_t index; const uint32_t hash; }; @@ -117,9 +117,9 @@ class StringDictionaryEncoder { } uint32_t - addKey(folly::StringPiece sp, uint32_t strideIndex, uint32_t count = 1) { + addKey(std::string_view sv, uint32_t strideIndex, uint32_t count = 1) { auto newIndex = size(); - detail::StringLookupKey key{sp, newIndex}; + detail::StringLookupKey key{sv, newIndex}; auto result = keyIndex_.insert(key); if (!result.second) { auto index = result.first->getIndex(); @@ -130,12 +130,12 @@ class StringDictionaryEncoder { auto bytesCount = keyBytes_.size(); if (UNLIKELY( newIndex == std::numeric_limits::max() || - (std::numeric_limits::max() - bytesCount <= sp.size()))) { + (std::numeric_limits::max() - bytesCount <= sv.size()))) { DWIO_RAISE("exceeds dictionary size limit"); } // append keys - keyBytes_.extendAppend(bytesCount, sp.data(), sp.size()); + keyBytes_.extendAppend(bytesCount, sv.data(), sv.size()); keyOffsets_.append(keyBytes_.size()); hash_.append(key.hash); counts_.append(count); @@ -153,11 +153,11 @@ class StringDictionaryEncoder { return firstSeenStrideIndex_[index]; } - folly::StringPiece getKey(uint32_t index) const { + std::string_view getKey(uint32_t index) const { DCHECK(index < keyOffsets_.size() - 1); auto startOffset = keyOffsets_[index]; auto endOffset = keyOffsets_[index + 1]; - return folly::StringPiece{ + return std::string_view{ keyBytes_.data() + startOffset, endOffset - startOffset}; } @@ -178,8 +178,8 @@ class StringDictionaryEncoder { VELOX_FRIEND_TEST(TestStringDictionaryEncoder, Clear); // Intended for testing only. - uint32_t getIndex(folly::StringPiece sp) { - detail::StringLookupKey key{sp, 0}; + uint32_t getIndex(std::string_view sv) { + detail::StringLookupKey key{sv, 0}; auto result = keyIndex_.find(key); if (result != keyIndex_.end()) { return result->getIndex(); @@ -221,7 +221,7 @@ FOLLY_ALWAYS_INLINE bool DictStringIdEquality::operator()( FOLLY_ALWAYS_INLINE bool DictStringIdEquality::operator()( detail::StringLookupKey key, detail::DictStringId lhs) const { - return encoder_.getKey(lhs.getIndex()) == key.sp; + return encoder_.getKey(lhs.getIndex()) == key.sv; } FOLLY_ALWAYS_INLINE uint32_t diff --git a/velox/dwio/dwrf/writer/Writer.cpp b/velox/dwio/dwrf/writer/Writer.cpp index d6011c38f8d..0133522c7ac 100644 --- a/velox/dwio/dwrf/writer/Writer.cpp +++ b/velox/dwio/dwrf/writer/Writer.cpp @@ -58,10 +58,12 @@ uint64_t orcWriterMaxStripeSize( const config::ConfigBase& config, const config::ConfigBase& session) { return config::toCapacity( - session.get( - dwrf::Config::kOrcWriterMaxStripeSizeSession, - config.get( - dwrf::Config::kOrcWriterMaxStripeSize, "64MB")), + session + .getLegacyWithFallback( + dwrf::Config::kOrcWriterMaxStripeSizeSession, + config, + dwrf::Config::kOrcWriterMaxStripeSize) + .value_or("64MB"), config::CapacityUnit::BYTE); } @@ -69,68 +71,66 @@ uint64_t orcWriterMaxDictionaryMemory( const config::ConfigBase& config, const config::ConfigBase& session) { return config::toCapacity( - session.get( - dwrf::Config::kOrcWriterMaxDictionaryMemorySession, - config.get( - dwrf::Config::kOrcWriterMaxDictionaryMemory, "16MB")), + session + .getLegacyWithFallback( + dwrf::Config::kOrcWriterMaxDictionaryMemorySession, + config, + dwrf::Config::kOrcWriterMaxDictionaryMemory) + .value_or("16MB"), config::CapacityUnit::BYTE); } bool isOrcWriterIntegerDictionaryEncodingEnabled( const config::ConfigBase& config, const config::ConfigBase& session) { - return session.get( - dwrf::Config::kOrcWriterIntegerDictionaryEncodingEnabledSession, - config.get( - dwrf::Config::kOrcWriterIntegerDictionaryEncodingEnabled, true)); + return session + .getLegacyWithFallback( + dwrf::Config::kOrcWriterIntegerDictionaryEncodingEnabledSession, + config, + dwrf::Config::kOrcWriterIntegerDictionaryEncodingEnabled) + .value_or(true); } bool isOrcWriterStringDictionaryEncodingEnabled( const config::ConfigBase& config, const config::ConfigBase& session) { - return session.get( - dwrf::Config::kOrcWriterStringDictionaryEncodingEnabledSession, - config.get( - dwrf::Config::kOrcWriterStringDictionaryEncodingEnabled, true)); + return session + .getLegacyWithFallback( + dwrf::Config::kOrcWriterStringDictionaryEncodingEnabledSession, + config, + dwrf::Config::kOrcWriterStringDictionaryEncodingEnabled) + .value_or(true); } bool orcWriterLinearStripeSizeHeuristics( const config::ConfigBase& config, const config::ConfigBase& session) { - return session.get( - dwrf::Config::kOrcWriterLinearStripeSizeHeuristicsSession, - config.get( - dwrf::Config::kOrcWriterLinearStripeSizeHeuristics, true)); + return session + .getLegacyWithFallback( + dwrf::Config::kOrcWriterLinearStripeSizeHeuristicsSession, + config, + dwrf::Config::kOrcWriterLinearStripeSizeHeuristics) + .value_or(true); } uint64_t orcWriterMinCompressionSize( const config::ConfigBase& config, const config::ConfigBase& session) { - return session.get( - dwrf::Config::kOrcWriterMinCompressionSizeSession, - config.get(dwrf::Config::kOrcWriterMinCompressionSize, 1024)); + return session + .getLegacyWithFallback( + dwrf::Config::kOrcWriterMinCompressionSizeSession, + config, + dwrf::Config::kOrcWriterMinCompressionSize) + .value_or(1024); } std::optional orcWriterCompressionLevel( const config::ConfigBase& config, const config::ConfigBase& session) { - auto sessionProp = - session.get(dwrf::Config::kOrcWriterCompressionLevelSession); - - if (sessionProp.has_value()) { - return sessionProp.value(); - } - - auto configProp = - config.get(dwrf::Config::kOrcWriterCompressionLevel); - - if (configProp.has_value()) { - return configProp.value(); - } - - // Presto has a single config controlling this value, but different defaults - // depending on the compression kind. - return std::nullopt; + return session.getLegacyWithFallback( + dwrf::Config::kOrcWriterCompressionLevelSession, + config, + dwrf::Config::kOrcWriterCompressionLevel); } uint8_t orcWriterZLIBCompressionLevel( @@ -175,7 +175,8 @@ Writer::Writer( pool, options.sessionTimezone, options.adjustTimestampToTimezone, - std::move(handler)); + std::move(handler), + options.memoryBudget); auto& context = writerBase_->getContext(); VELOX_CHECK_EQ( context.getTotalMemoryUsage(), @@ -213,10 +214,11 @@ Writer::Writer( : Writer{ std::move(sink), options, - options.memoryPool->addAggregateChild(fmt::format( - "{}.dwrf.{}", - options.memoryPool->name(), - folly::to(folly::Random::rand64())))} {} + options.memoryPool->addAggregateChild( + fmt::format( + "{}.dwrf.{}", + options.memoryPool->name(), + folly::to(folly::Random::rand64())))} {} void Writer::setMemoryReclaimers( const std::shared_ptr& pool) { @@ -521,7 +523,7 @@ void Writer::flushStripe(bool close) { const auto& handler = context.getEncryptionHandler(); EncodingManager encodingManager{handler}; - writer_->flush([&](uint32_t nodeId) -> proto::ColumnEncoding& { + writer_->flush([&](uint32_t nodeId) -> ColumnEncodingWriteWrapper { return encodingManager.addEncodingToFooter(nodeId); }); @@ -546,7 +548,8 @@ void Writer::flushStripe(bool close) { const DataBufferHolder& out) { uint32_t currentIndex = 0; const auto nodeId = stream.encodingKey().node(); - proto::Stream* s = encodingManager.addStreamToFooter(nodeId, currentIndex); + StreamWriteWrapper s = + encodingManager.addStreamToFooter(nodeId, currentIndex); // set offset only when needed, ie. when offset of current stream cannot be // calculated based on offset and length of previous stream. In that case, @@ -554,19 +557,19 @@ void Writer::flushStripe(bool close) { // encryption group or neither are encrypted. So the logic is simplified to // check if group index are the same for current and previous stream if (offset > 0 && lastIndex != currentIndex) { - s->set_offset(offset); + s.setOffset(offset); } lastIndex = currentIndex; // Jolly/Presto readers can't read streams bigger than 2GB. writerBase_->validateStreamSize(stream, out.size()); - s->set_kind(static_cast(stream.kind())); - s->set_node(nodeId); - s->set_column(stream.column()); - s->set_sequence(stream.encodingKey().sequence()); - s->set_length(out.size()); - s->set_usevints(context.getConfig(Config::USE_VINTS)); + s.setKind(stream.kind()); + s.setNode(nodeId); + s.setColumn(stream.column()); + s.setSequence(stream.encodingKey().sequence()); + s.setLength(out.size()); + s.setUseVints(context.getConfig(Config::USE_VINTS)); offset += out.size(); context.recordPhysicalSize(stream, out.size()); @@ -619,19 +622,20 @@ void Writer::flushStripe(bool close) { VELOX_CHECK_EQ(footerOffset, stripeOffset + dataLength + indexLength); sink.setMode(WriterSink::Mode::Footer); - writerBase_->writeProto(encodingManager.getFooter()); + encodingManager.getFooter().setWriterTimezone(); + writerBase_->writeProto(&encodingManager.getFooter()); sink.setMode(WriterSink::Mode::None); - auto& stripe = writerBase_->addStripeInfo(); - stripe.set_offset(stripeOffset); - stripe.set_indexlength(indexLength); - stripe.set_datalength(dataLength); - stripe.set_footerlength(sink.size() - footerOffset); + auto stripe = writerBase_->addStripeInfo(); + stripe.setOffset(stripeOffset); + stripe.setIndexLength(indexLength); + stripe.setDataLength(dataLength); + stripe.setFooterLength(sink.size() - footerOffset); // set encryption key metadata if (handler.isEncrypted() && context.stripeIndex() == 0) { for (uint32_t i = 0; i < handler.getEncryptionGroupCount(); ++i) { - *stripe.add_keymetadata() = + *stripe.addKeyMetadata() = handler.getEncryptionProviderByIndex(i).getKey(); } } @@ -694,10 +698,10 @@ void Writer::flushInternal(bool close) { proto::Encryption* encryption = nullptr; // initialize encryption related metadata only when there is data written - if (handler.isEncrypted() && footer.stripes_size() > 0) { + if (handler.isEncrypted() && footer->stripesSize() > 0) { const auto count = handler.getEncryptionGroupCount(); stats.resize(count); - encryption = footer.mutable_encryption(); + encryption = footer->mutableEncryption(); encryption->set_keyprovider( encryption::toProto(handler.getKeyProviderType())); for (uint32_t i = 0; i < count; ++i) { @@ -708,25 +712,28 @@ void Writer::flushInternal(bool close) { std::optional lastRoot; std::unordered_map statsMap; - writer_->writeFileStats([&](uint32_t nodeId) -> proto::ColumnStatistics& { - auto entry = footer.add_statistics(); - if (!encryption || !handler.isEncrypted(nodeId)) { - return *entry; - } + writer_->writeFileStats( + [&](uint32_t nodeId) -> ColumnStatisticsWriteWrapper { + auto entry = footer->addStatistics(); + if (!encryption || !handler.isEncrypted(nodeId)) { + return entry; + } - auto root = handler.getEncryptionRoot(nodeId); - auto groupIndex = handler.getEncryptionGroupIndex(nodeId); - auto& group = stats.at(groupIndex); - if (!lastRoot || root != lastRoot.value()) { - // this is a new root, add to the footer, and use a new slot - group.emplace_back(); - encryption->mutable_encryptiongroups(groupIndex)->add_nodes(root); - } - lastRoot = root; - auto encryptedStats = group.back().add_statistics(); - statsMap[entry] = encryptedStats; - return *encryptedStats; - }); + auto root = handler.getEncryptionRoot(nodeId); + auto groupIndex = handler.getEncryptionGroupIndex(nodeId); + auto& group = stats.at(groupIndex); + if (!lastRoot || root != lastRoot.value()) { + // this is a new root, add to the footer, and use a new slot + group.emplace_back(); + encryption->mutable_encryptiongroups(groupIndex)->add_nodes(root); + } + lastRoot = root; + auto encryptedStats = group.back().add_statistics(); + auto cs = + reinterpret_cast(entry.rawProtoPtr()); + statsMap[cs] = encryptedStats; + return ColumnStatisticsWriteWrapper(encryptedStats); + }); #define COPY_STAT(from, to, stat) \ if (from->has_##stat()) { \ @@ -770,7 +777,7 @@ void Writer::flushInternal(bool close) { dwio::common::MetricsLog::FileCloseMetrics{ .writerVersion = writerVersionToString( context.getConfig(Config::WRITER_VERSION)), - .footerLength = footer.contentlength(), + .footerLength = footer->contentLength(), .fileSize = sink.size(), .cacheSize = sink.getCacheSize(), .numCacheBlocks = sink.getCacheOffsets().size() - 1, @@ -792,7 +799,7 @@ void Writer::flush() { flushInternal(false); } -void Writer::close() { +std::unique_ptr Writer::close() { checkRunning(); auto exitGuard = folly::makeGuard([this]() { flushPolicy_->onClose(); @@ -800,6 +807,7 @@ void Writer::close() { }); flushInternal(true); writerBase_->close(); + return std::make_unique(); } void Writer::abort() { @@ -853,13 +861,22 @@ uint64_t Writer::MemoryReclaimer::reclaim( LOG(WARNING) << "Can't reclaim from dwrf writer which is under non-reclaimable " "section: " - << pool->name(); + << pool->name() << ", root pool: " << pool->root()->name() + << ", used: " << succinctBytes(pool->usedBytes()) + << ", reservation: " << succinctBytes(pool->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool->root()->reservedBytes()); ++stats.numNonReclaimableAttempts; return 0; } if (!writer_->isRunning()) { LOG(WARNING) << "Can't reclaim from a not running dwrf writer: " - << pool->name() << ", state: " << writer_->state(); + << pool->name() << ", root pool: " << pool->root()->name() + << ", state: " << writer_->state() + << ", used: " << succinctBytes(pool->usedBytes()) + << ", reservation: " << succinctBytes(pool->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool->root()->reservedBytes()); ++stats.numNonReclaimableAttempts; return 0; } diff --git a/velox/dwio/dwrf/writer/Writer.h b/velox/dwio/dwrf/writer/Writer.h index 3864b57790d..847c0ac3045 100644 --- a/velox/dwio/dwrf/writer/Writer.h +++ b/velox/dwio/dwrf/writer/Writer.h @@ -19,6 +19,7 @@ #include #include +#include "velox/dwio/common/FileMetadata.h" #include "velox/dwio/common/Writer.h" #include "velox/dwio/common/WriterFactory.h" #include "velox/dwio/dwrf/common/Encryption.h" @@ -30,6 +31,9 @@ namespace facebook::velox::dwrf { +/// DWRF-specific file metadata wrapper. Currently a placeholder. +class DwrfFileMetadata : public dwio::common::FileMetadata {}; + struct WriterOptions : public dwio::common::WriterOptions { std::shared_ptr config = std::make_shared(); /// Changes the interface to stream list and encoding iter. @@ -60,10 +64,11 @@ class Writer : public dwio::common::Writer { : Writer{ std::move(sink), options, - parentPool.addAggregateChild(fmt::format( - "{}.dwrf_{}", - parentPool.name(), - folly::to(folly::Random::rand64())))} {} + parentPool.addAggregateChild( + fmt::format( + "{}.dwrf_{}", + parentPool.name(), + folly::to(folly::Random::rand64())))} {} Writer( std::unique_ptr sink, @@ -85,7 +90,7 @@ class Writer : public dwio::common::Writer { return true; } - virtual void close() override; + virtual std::unique_ptr close() override; virtual void abort() override; @@ -134,7 +139,7 @@ class Writer : public dwio::common::Writer { return writerBase_->getContext(); } - const proto::Footer& getFooter() const { + const std::unique_ptr& getFooter() const { return writerBase_->getFooter(); } diff --git a/velox/dwio/dwrf/writer/WriterBase.cpp b/velox/dwio/dwrf/writer/WriterBase.cpp index 6fd51477d04..dd182917661 100644 --- a/velox/dwio/dwrf/writer/WriterBase.cpp +++ b/velox/dwio/dwrf/writer/WriterBase.cpp @@ -22,8 +22,8 @@ namespace facebook::velox::dwrf { void WriterBase::writeFooter(const Type& type) { auto pos = writerSink_->size(); - footer_.set_headerlength(ORC_MAGIC_LEN); - footer_.set_contentlength(pos - ORC_MAGIC_LEN); + footer_->setHeaderLength(ORC_MAGIC_LEN); + footer_->setContentLength(pos - ORC_MAGIC_LEN); writerSink_->setMode(WriterSink::Mode::None); // write cache when available @@ -31,45 +31,46 @@ void WriterBase::writeFooter(const Type& type) { if (cacheSize > 0) { writerSink_->writeCache(); for (auto& i : writerSink_->getCacheOffsets()) { - footer_.add_stripecacheoffsets(i); + footer_->addStripeCacheOffsets(i); } pos = writerSink_->size(); } - ProtoUtils::writeType(type, footer_); - DWIO_ENSURE_EQ(footer_.types_size(), footer_.statistics_size()); + ProtoUtils::writeType(type, *footer_); + DWIO_ENSURE_EQ(footer_->typesSize(), footer_->statisticsSize()); auto writerVersion = static_cast(context_->getConfig(Config::WRITER_VERSION)); writeUserMetadata(writerVersion); - footer_.set_numberofrows(context_->fileRowCount()); - footer_.set_rowindexstride(context_->indexStride()); + footer_->setNumberOfRows(context_->fileRowCount()); + footer_->setRowIndexStride(context_->indexStride()); if (context_->fileRawSize() > 0 || context_->fileRowCount() == 0) { // ColumnTransformWriter, when rewriting presto written file does not have // rawSize. - footer_.set_rawdatasize(context_->fileRawSize()); + footer_->setRawDataSize(context_->fileRawSize()); } auto* checksum = writerSink_->getChecksum(); - footer_.set_checksumalgorithm( + footer_->setCheckSumAlgorithm( (checksum != nullptr) ? checksum->getType() : proto::ChecksumAlgorithm::NULL_); - writeProto(footer_); + writeProto(footer_->getDwrfPtr()); const auto footerLength = writerSink_->size() - pos; // write postscript pos = writerSink_->size(); - proto::PostScript ps; - ps.set_writerversion(writerVersion); - ps.set_footerlength(footerLength); - ps.set_compression( - static_cast(context_->compression())); + auto dwrfPostScript = ArenaCreate(arena_.get()); + std::unique_ptr ps = + std::make_unique(dwrfPostScript); + ps->setWriterVersion(writerVersion); + ps->setFooterLength(footerLength); + ps->setCompression(context_->compression()); if (context_->compression() != common::CompressionKind::CompressionKind_NONE) { - ps.set_compressionblocksize(context_->compressionBlockSize()); + ps->setCompressionBlockSize(context_->compressionBlockSize()); } - ps.set_cachemode( - static_cast(writerSink_->getCacheMode())); - ps.set_cachesize(cacheSize); + + ps->setCacheMode(writerSink_->getCacheMode()); + ps->setCacheSize(cacheSize); writeProto(ps, common::CompressionKind::CompressionKind_NONE); auto psLength = writerSink_->size() - pos; DWIO_ENSURE_LE(psLength, 0xff, "PostScript is too large: ", psLength); @@ -80,14 +81,14 @@ void WriterBase::writeFooter(const Type& type) { void WriterBase::writeUserMetadata(uint32_t writerVersion) { // add writer version - userMetadata_[std::string{WRITER_NAME_KEY}] = kDwioWriter; - userMetadata_[std::string{WRITER_VERSION_KEY}] = + userMetadata_[std::string{kWriterNameKey}] = kDwioWriter; + userMetadata_[std::string{kWriterVersionKey}] = folly::to(writerVersion); - userMetadata_[std::string{WRITER_HOSTNAME_KEY}] = process::getHostName(); + userMetadata_[std::string{kWriterHostnameKey}] = process::getHostName(); std::for_each(userMetadata_.begin(), userMetadata_.end(), [&](auto& pair) { - auto item = footer_.add_metadata(); - item->set_name(pair.first); - item->set_value(pair.second); + auto item = footer_->addMetadata(); + item.setName(pair.first); + item.setValue(pair.second); }); } diff --git a/velox/dwio/dwrf/writer/WriterBase.h b/velox/dwio/dwrf/writer/WriterBase.h index e79b9293db9..3c468a370ca 100644 --- a/velox/dwio/dwrf/writer/WriterBase.h +++ b/velox/dwio/dwrf/writer/WriterBase.h @@ -17,15 +17,19 @@ #pragma once #include "velox/common/base/GTestMacros.h" +#include "velox/dwio/common/Arena.h" #include "velox/dwio/dwrf/writer/WriterContext.h" #include "velox/dwio/dwrf/writer/WriterSink.h" namespace facebook::velox::dwrf { +using dwio::common::ArenaCreate; + class WriterBase { public: explicit WriterBase(std::unique_ptr sink) - : sink_{std::move(sink)} { + : sink_{std::move(sink)}, + arena_(std::make_unique()) { VELOX_CHECK_NOT_NULL(sink_); } @@ -76,18 +80,22 @@ class WriterBase { std::shared_ptr pool, const tz::TimeZone* sessionTimezone = nullptr, const bool adjustTimestampToTimezone = false, - std::unique_ptr handler = nullptr) { + std::unique_ptr handler = nullptr, + int64_t memoryBudget = std::numeric_limits::max()) { context_ = std::make_unique( config, std::move(pool), sink_->metricsLog(), sessionTimezone, adjustTimestampToTimezone, - std::move(handler)); + std::move(handler), + memoryBudget); writerSink_ = std::make_unique( *sink_, context_->getMemoryPool(MemoryUsageCategory::OUTPUT_STREAM), context_->getConfigs()); + auto dwrfFooter_ = ArenaCreate(arena_.get()); + footer_ = std::make_unique(dwrfFooter_); } void initBuffers(); @@ -107,7 +115,7 @@ class WriterBase { auto holder = context_->newDataBufferHolder(); auto stream = context_->newStream(kind, *holder); - t.SerializeToZeroCopyStream(stream.get()); + t->SerializeToZeroCopyStream(stream.get()); stream->flush(); writerSink_->addBuffers(*holder); @@ -131,23 +139,23 @@ class WriterBase { } } - proto::StripeInformation& addStripeInfo() { - auto stripe = footer_.add_stripes(); - stripe->set_numberofrows(context_->stripeRowCount()); + StripeInformationWriteWrapper addStripeInfo() { + auto stripe = footer_->addStripes(); + stripe.setNumberOfRows(context_->stripeRowCount()); if (context_->stripeRawSize() > 0 || context_->stripeRowCount() == 0) { // ColumnTransformWriter, when rewriting presto written // file does not have rawSize. - stripe->set_rawdatasize(context_->stripeRawSize()); + stripe.setRawDataSize(context_->stripeRawSize()); } auto* checksum = writerSink_->getChecksum(); if (checksum != nullptr) { - stripe->set_checksum(checksum->getDigest()); + stripe.setChecksum(checksum->getDigest()); } - return *stripe; + return stripe; } - proto::Footer& getFooter() { + std::unique_ptr& getFooter() { return footer_; } @@ -170,8 +178,10 @@ class WriterBase { std::unique_ptr context_; std::unique_ptr sink_; std::unique_ptr writerSink_; - proto::Footer footer_; + std::unique_ptr footer_; + proto::orc::Metadata metadata_; std::unordered_map userMetadata_; + std::unique_ptr arena_; friend class WriterTest; VELOX_FRIEND_TEST(WriterBaseTest, FlushWriterSinkUponClose); diff --git a/velox/dwio/dwrf/writer/WriterContext.cpp b/velox/dwio/dwrf/writer/WriterContext.cpp index ce02aacf292..84056c95a0e 100644 --- a/velox/dwio/dwrf/writer/WriterContext.cpp +++ b/velox/dwio/dwrf/writer/WriterContext.cpp @@ -29,9 +29,11 @@ WriterContext::WriterContext( const dwio::common::MetricsLogPtr& metricLogger, const tz::TimeZone* sessionTimezone, const bool adjustTimestampToTimezone, - std::unique_ptr handler) + std::unique_ptr handler, + int64_t memoryBudget) : config_{config}, pool_{std::move(pool)}, + memoryBudget_{memoryBudget}, dictionaryPool_{ pool_->addLeafChild(fmt::format("{}.dictionary", pool_->name()))}, outputStreamPool_{ @@ -69,7 +71,17 @@ WriterContext::WriterContext( } validateConfigs(); VLOG(2) << fmt::format( - "Compression config: {}", common::compressionKindToString(compression_)); + "DWRF WriterContext initialized: pool='{}', maxCapacity={}MB, " + "memoryBudget={}MB, effectiveBudget={}MB, " + "stripeSizeFlushThreshold={}MB, dictionarySizeFlushThreshold={}MB, " + "compression={}", + pool_->name(), + pool_->maxCapacity() / (1024 * 1024), + memoryBudget_ / (1024 * 1024), + getMemoryBudget() / (1024 * 1024), + stripeSizeFlushThreshold_ / (1024 * 1024), + dictionarySizeFlushThreshold_ / (1024 * 1024), + common::compressionKindToString(compression_)); } WriterContext::~WriterContext() { diff --git a/velox/dwio/dwrf/writer/WriterContext.h b/velox/dwio/dwrf/writer/WriterContext.h index 9ba444d5317..524ad0fd569 100644 --- a/velox/dwio/dwrf/writer/WriterContext.h +++ b/velox/dwio/dwrf/writer/WriterContext.h @@ -16,6 +16,7 @@ #pragma once +#include #include #include "velox/common/base/GTestMacros.h" #include "velox/common/time/CpuWallTimer.h" @@ -30,6 +31,7 @@ #include "velox/vector/DecodedVector.h" namespace facebook::velox::dwrf { + using dwio::common::BufferedOutputStream; using dwio::common::DataBufferHolder; using dwio::common::compression::CompressionBufferPool; @@ -45,21 +47,20 @@ class WriterContext : public CompressionBufferPool { dwio::common::MetricsLog::voidLog(), const tz::TimeZone* sessionTimezone = nullptr, const bool adjustTimestampToTimezone = false, - std::unique_ptr handler = nullptr); + std::unique_ptr handler = nullptr, + int64_t memoryBudget = std::numeric_limits::max()); ~WriterContext() override; bool hasStream(const DwrfStreamIdentifier& stream) const { - return streams_.find(stream) != streams_.end(); + return streams_.find(stream) != streams_.cend(); } const DataBufferHolder& getStream(const DwrfStreamIdentifier& stream) const { return streams_.at(stream); } - void addBuffer( - const DwrfStreamIdentifier& stream, - folly::StringPiece buffer) { + void addBuffer(const DwrfStreamIdentifier& stream, std::string_view buffer) { streams_.at(stream).take(buffer); } @@ -115,7 +116,7 @@ class WriterContext : public CompressionBufferPool { velox::memory::MemoryPool& dictionaryPool, velox::memory::MemoryPool& generalPool) { auto result = dictEncoders_.find(encodingKey); - if (result == dictEncoders_.end()) { + if (result == dictEncoders_.cend()) { auto emplaceResult = dictEncoders_.emplace( encodingKey, std::make_unique>( @@ -191,7 +192,7 @@ class WriterContext : public CompressionBufferPool { int64_t getTotalMemoryUsage() const; int64_t getMemoryBudget() const { - return pool_->maxCapacity(); + return std::min(memoryBudget_, pool_->maxCapacity()); } /// Returns the available memory reservations from all the memory pools. @@ -233,7 +234,7 @@ class WriterContext : public CompressionBufferPool { void removeAllIntDictionaryEncodersOnNode( std::function predicate) { auto iter = dictEncoders_.begin(); - while (iter != dictEncoders_.end()) { + while (iter != dictEncoders_.cend()) { if (predicate(iter->first.node())) { iter = dictEncoders_.erase(iter); } else { @@ -623,6 +624,7 @@ class WriterContext : public CompressionBufferPool { const std::shared_ptr config_; const std::shared_ptr pool_; + const int64_t memoryBudget_; const std::shared_ptr dictionaryPool_; const std::shared_ptr outputStreamPool_; const std::shared_ptr generalPool_; diff --git a/velox/dwio/orc/reader/CMakeLists.txt b/velox/dwio/orc/reader/CMakeLists.txt index 18a977afee3..d17c996e072 100644 --- a/velox/dwio/orc/reader/CMakeLists.txt +++ b/velox/dwio/orc/reader/CMakeLists.txt @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -velox_add_library(velox_dwio_orc_reader OrcReader.cpp) +velox_add_library(velox_dwio_orc_reader OrcReader.cpp HEADERS OrcReader.h) velox_link_libraries(velox_dwio_orc_reader velox_dwio_dwrf_reader) diff --git a/velox/dwio/orc/test/ReaderFilterTest.cpp b/velox/dwio/orc/test/ReaderFilterTest.cpp index 329286f0079..a817a091883 100644 --- a/velox/dwio/orc/test/ReaderFilterTest.cpp +++ b/velox/dwio/orc/test/ReaderFilterTest.cpp @@ -47,6 +47,11 @@ class OrcReaderFilterTestP static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); } + + std::shared_ptr dataIoStats_{ + std::make_shared()}; + std::shared_ptr metadataIoStats_{ + std::make_shared()}; }; INSTANTIATE_TEST_SUITE_P( @@ -218,6 +223,8 @@ TEST_P(OrcReaderFilterTestP, tests) { std::string fileName = "orc_all_type.orc"; dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); // To make DwrfReader reads ORC file, setFileFormat to FileFormat::ORC readerOpts.setFileFormat(dwio::common::FileFormat::ORC); diff --git a/velox/dwio/orc/test/ReaderTest.cpp b/velox/dwio/orc/test/ReaderTest.cpp index db06e778b3e..4345bd4931a 100644 --- a/velox/dwio/orc/test/ReaderTest.cpp +++ b/velox/dwio/orc/test/ReaderTest.cpp @@ -16,6 +16,7 @@ #include +#include "velox/common/io/IoStatistics.h" #include "velox/dwio/dwrf/common/Common.h" #include "velox/dwio/dwrf/reader/DwrfReader.h" #include "velox/dwio/dwrf/test/OrcTest.h" @@ -35,6 +36,11 @@ class OrcReaderTest : public testing::Test, public VectorTestBase { static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); } + + std::shared_ptr dataIoStats_ = + std::make_shared(); + std::shared_ptr metadataIoStats_ = + std::make_shared(); }; inline std::string getExamplesFilePath(const std::string& fileName) { @@ -47,6 +53,8 @@ TEST_F(OrcReaderTest, testOrcReaderSimple) { const std::string simpleTest( getExamplesFilePath("TestStringDictionary.testRowIndex.orc")); dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); // To make DwrfReader reads ORC file, setFileFormat to FileFormat::ORC readerOpts.setFileFormat(dwio::common::FileFormat::ORC); auto reader = DwrfReader::create( @@ -84,6 +92,8 @@ TEST_F(OrcReaderTest, testOrcReaderComplexTypes) { h:struct<\ i:array>>>>>")); dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); readerOpts.setFileFormat(dwio::common::FileFormat::ORC); auto reader = DwrfReader::create( createFileBufferedInput(icebergOrc, readerOpts.memoryPool()), readerOpts); @@ -103,6 +113,8 @@ TEST_F(OrcReaderTest, testOrcReaderComplexTypes) { TEST_F(OrcReaderTest, testOrcReaderVarchar) { const std::string varcharOrc(getExamplesFilePath("orc_index_int_string.orc")); dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); readerOpts.setFileFormat(dwio::common::FileFormat::ORC); auto reader = DwrfReader::create( createFileBufferedInput(varcharOrc, readerOpts.memoryPool()), readerOpts); @@ -134,6 +146,8 @@ TEST_F(OrcReaderTest, testOrcReaderDate) { const std::string dateOrc( getExamplesFilePath("TestOrcFile.testDate1900.orc")); dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); readerOpts.setFileFormat(dwio::common::FileFormat::ORC); auto reader = DwrfReader::create( createFileBufferedInput(dateOrc, readerOpts.memoryPool()), readerOpts); @@ -179,6 +193,8 @@ TEST_F(OrcReaderTest, testOrcReaderDate) { TEST_F(OrcReaderTest, testOrcReadAllType) { const std::string dateOrc(getExamplesFilePath("orc_all_type.orc")); dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); readerOpts.setFileFormat(dwio::common::FileFormat::ORC); auto reader = DwrfReader::create( createFileBufferedInput(dateOrc, readerOpts.memoryPool()), readerOpts); @@ -256,6 +272,8 @@ TEST_F(OrcReaderTest, testOrcRlev2) { spec->addAllChildFields(*schema); dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); readerOpts.setScanSpec(spec); readerOpts.setFileFormat(dwio::common::FileFormat::ORC); @@ -361,6 +379,11 @@ class OrcReaderTestP : public testing::TestWithParam, return test::getDataFilePath( "velox/dwio/orc/test", "examples/expected/" + GetParam().json); } + + std::shared_ptr dataIoStats_ = + std::make_shared(); + std::shared_ptr metadataIoStats_ = + std::make_shared(); }; TEST_P( @@ -368,6 +391,8 @@ TEST_P( DwrfReader_FetchesOrcMetadata_ExpectCorrectFooterAndMetadata) { const std::string dateOrc(getFilename()); dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); readerOpts.setFileFormat(dwio::common::FileFormat::ORC); auto reader = DwrfReader::create( createFileBufferedInput(dateOrc, readerOpts.memoryPool()), readerOpts); @@ -390,8 +415,8 @@ TEST_P( auto rowReader = reader->createRowReader(rowReaderOptions); for (std::map::const_iterator itr = - GetParam().userMeta.begin(); - itr != GetParam().userMeta.end(); + GetParam().userMeta.cbegin(); + itr != GetParam().userMeta.cend(); ++itr) { ASSERT_EQ(true, reader->hasMetadataValue(itr->first)); std::string val = reader->getMetadataValue(itr->first); @@ -409,6 +434,8 @@ TEST_P(OrcReaderTestP, DwrfRowReader_ReadAllColumnTypes_ExpectedRowDataRead) { const std::string dateOrc(getFilename()); dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); readerOpts.setFileFormat(dwio::common::FileFormat::ORC); readerOpts.setScanSpec(scanSpec); diff --git a/velox/dwio/parquet/CMakeLists.txt b/velox/dwio/parquet/CMakeLists.txt index c40f352d0bb..c6f4fb10ceb 100644 --- a/velox/dwio/parquet/CMakeLists.txt +++ b/velox/dwio/parquet/CMakeLists.txt @@ -23,8 +23,15 @@ if(VELOX_ENABLE_PARQUET) endif() endif() -velox_add_library(velox_dwio_parquet_reader RegisterParquetReader.cpp) +velox_add_library( + velox_dwio_parquet_reader + RegisterParquetReader.cpp + HEADERS + RegisterParquetReader.h + RegisterParquetWriter.h +) velox_add_library(velox_dwio_parquet_writer RegisterParquetWriter.cpp) +velox_add_library(velox_dwio_parquet_field_id INTERFACE ParquetFieldId.h) if(VELOX_ENABLE_PARQUET) velox_link_libraries(velox_dwio_parquet_reader velox_dwio_native_parquet_reader xsimd) diff --git a/velox/dwio/parquet/ParquetFieldId.h b/velox/dwio/parquet/ParquetFieldId.h new file mode 100644 index 00000000000..f440ab37730 --- /dev/null +++ b/velox/dwio/parquet/ParquetFieldId.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace facebook::velox::parquet { +/// Parquet field IDs during write operations. Each ID must be unique positive +/// number, do not need to be sequential. +/// Used to explicitly control field ID assignment in the Parquet schema. +struct ParquetFieldId { + int32_t fieldId; + std::vector children; +}; +} // namespace facebook::velox::parquet diff --git a/velox/dwio/parquet/common/BloomFilter.cpp b/velox/dwio/parquet/common/BloomFilter.cpp index 211b842cf9e..d8617a80728 100644 --- a/velox/dwio/parquet/common/BloomFilter.cpp +++ b/velox/dwio/parquet/common/BloomFilter.cpp @@ -73,31 +73,25 @@ void BlockSplitBloomFilter::init(const uint8_t* bitset, uint32_t numBytes) { } static void validateBloomFilterHeader(const thrift::BloomFilterHeader& header) { - std::stringstream error; if (!header.algorithm.__isset.BLOCK) { - error << "Unsupported Bloom filter algorithm: "; - error << header.algorithm; - VELOX_FAIL(error.str()); + VELOX_FAIL("Unsupported Bloom filter algorithm"); } if (!header.hash.__isset.XXHASH) { - error << "Unsupported Bloom filter hash: ", error << header.hash; - VELOX_FAIL(error.str()); + VELOX_FAIL("Unsupported Bloom filter hash"); } if (!header.compression.__isset.UNCOMPRESSED) { - error << "Unsupported Bloom filter compression: ", - error << header.compression; - VELOX_FAIL(error.str()); + VELOX_FAIL("Unsupported Bloom filter compression"); } if (header.numBytes <= 0 || static_cast(header.numBytes) > BloomFilter::kMaximumBloomFilterBytes) { - error << "Bloom filter size is incorrect: " << header.numBytes - << ". Must be in range (" << 0 << ", " - << BloomFilter::kMaximumBloomFilterBytes << "]."; - VELOX_FAIL(error.str()); + VELOX_FAIL( + "Bloom filter size is incorrect: {}. Must be in range (0, {}].", + header.numBytes, + BloomFilter::kMaximumBloomFilterBytes); } } @@ -188,9 +182,7 @@ void BlockSplitBloomFilter::writeTo( memBuffer->resetBuffer(); header.write(protocol.get()); } catch (std::exception& e) { - std::stringstream ss; - ss << "Couldn't serialize thrift: " << e.what() << "\n"; - VELOX_FAIL(ss.str()); + VELOX_FAIL("Couldn't serialize thrift: {}", e.what()); } uint8_t* outBuffer; uint32_t outLength; diff --git a/velox/dwio/parquet/common/CMakeLists.txt b/velox/dwio/parquet/common/CMakeLists.txt index 159067602e9..4f1256edd75 100644 --- a/velox/dwio/parquet/common/CMakeLists.txt +++ b/velox/dwio/parquet/common/CMakeLists.txt @@ -18,6 +18,15 @@ velox_add_library( XxHasher.cpp LevelComparison.cpp LevelConversion.cpp + HEADERS + BitStreamUtilsInternal.h + BloomFilter.h + Hasher.h + LevelComparison.h + LevelConversion.h + LevelConversionUtil.h + RleEncodingInternal.h + XxHasher.h ) velox_link_libraries( diff --git a/velox/dwio/parquet/common/XxHasher.cpp b/velox/dwio/parquet/common/XxHasher.cpp index df189a972b6..93ee12f1b60 100644 --- a/velox/dwio/parquet/common/XxHasher.cpp +++ b/velox/dwio/parquet/common/XxHasher.cpp @@ -18,8 +18,7 @@ #include "XxHasher.h" -#define XXH_INLINE_ALL -#include +#include "velox/common/base/XxHashInline.h" namespace facebook::velox::parquet { diff --git a/velox/dwio/parquet/reader/CMakeLists.txt b/velox/dwio/parquet/reader/CMakeLists.txt index 292285da9b5..d73c8a31c67 100644 --- a/velox/dwio/parquet/reader/CMakeLists.txt +++ b/velox/dwio/parquet/reader/CMakeLists.txt @@ -26,6 +26,30 @@ velox_add_library( StructColumnReader.cpp StringColumnReader.cpp SemanticVersion.cpp + HEADERS + BooleanColumnReader.h + BooleanDecoder.h + DeltaBpDecoder.h + DeltaByteArrayDecoder.h + FloatingPointColumnReader.h + IntegerColumnReader.h + Metadata.h + NestedStructureDecoder.h + PageReader.h + ParquetColumnReader.h + ParquetData.h + ParquetReader.h + ParquetStatsContext.h + ParquetTypeWithId.h + RepeatedColumnReader.h + RleBpDataDecoder.h + RleBpDecoder.h + SemanticVersion.h + StringColumnReader.h + StringDecoder.h + StructColumnReader.h + TimeColumnReader.h + TimestampColumnReader.h ) velox_link_libraries( diff --git a/velox/dwio/parquet/reader/DeltaBpDecoder.h b/velox/dwio/parquet/reader/DeltaBpDecoder.h index 59ec359ce43..43024cfad4a 100644 --- a/velox/dwio/parquet/reader/DeltaBpDecoder.h +++ b/velox/dwio/parquet/reader/DeltaBpDecoder.h @@ -16,8 +16,10 @@ #pragma once +#include #include "velox/common/base/BitUtil.h" #include "velox/common/base/Exceptions.h" +#include "velox/common/base/Nulls.h" namespace facebook::velox::parquet { @@ -127,8 +129,8 @@ class DeltaBpDecoder { VELOX_CHECK_EQ( valuesPerBlock_ % 128, 0, - "the number of values in a block must be multiple of 128, but it's " + - std::to_string(valuesPerBlock_)); + "the number of values in a block must be multiple of 128, but it's {}", + valuesPerBlock_); VELOX_CHECK_GT( miniBlocksPerBlock_, 0, "cannot have zero miniblock per block"); valuesPerMiniBlock_ = valuesPerBlock_ / miniBlocksPerBlock_; @@ -137,8 +139,8 @@ class DeltaBpDecoder { VELOX_CHECK_EQ( valuesPerMiniBlock_ % 32, 0, - "the number of values in a miniblock must be multiple of 32, but it's " + - std::to_string(valuesPerMiniBlock_)); + "the number of values in a miniblock must be multiple of 32, but it's {}", + valuesPerMiniBlock_); totalValuesRemaining_ = totalValueCount_; deltaBitWidths_.resize(miniBlocksPerBlock_); diff --git a/velox/dwio/parquet/reader/DeltaByteArrayDecoder.h b/velox/dwio/parquet/reader/DeltaByteArrayDecoder.h index 74e4d9001c7..b8ad2d18d48 100644 --- a/velox/dwio/parquet/reader/DeltaByteArrayDecoder.h +++ b/velox/dwio/parquet/reader/DeltaByteArrayDecoder.h @@ -17,60 +17,16 @@ #pragma once #include "velox/common/base/BitUtil.h" +#include "velox/common/base/Nulls.h" #include "velox/dwio/parquet/reader/DeltaBpDecoder.h" namespace facebook::velox::parquet { -// DeltaByteArrayDecoder is adapted from Apache Arrow: -// https://github.com/apache/arrow/blob/apache-arrow-15.0.0/cpp/src/parquet/encoding.cc#L2758-L2889 -class DeltaLengthByteArrayDecoder { +class DeltaByteArrayDecoderBase { public: - explicit DeltaLengthByteArrayDecoder(const char* start) { - lengthDecoder_ = std::make_unique(start); - decodeLengths(); - bufferStart_ = lengthDecoder_->bufferStart(); - } - - std::string_view readString() { - const int64_t length = bufferedLength_[lengthIdx_++]; - VELOX_CHECK_GE(length, 0, "negative string delta length"); - bufferStart_ += length; - return std::string_view(bufferStart_ - length, length); - } + virtual ~DeltaByteArrayDecoderBase() = default; - private: - void decodeLengths() { - int64_t numLength = lengthDecoder_->validValuesCount(); - bufferedLength_.resize(numLength); - lengthDecoder_->readValues(bufferedLength_.data(), numLength); - - lengthIdx_ = 0; - numValidValues_ = numLength; - } - - const char* bufferStart_; - std::unique_ptr lengthDecoder_; - int32_t numValidValues_{0}; - uint32_t lengthIdx_{0}; - std::vector bufferedLength_; -}; - -// DeltaByteArrayDecoder is adapted from Apache Arrow: -// https://github.com/apache/arrow/blob/apache-arrow-15.0.0/cpp/src/parquet/encoding.cc#L3301-L3545 -class DeltaByteArrayDecoder { - public: - explicit DeltaByteArrayDecoder(const char* start) { - prefixLenDecoder_ = std::make_unique(start); - int64_t numPrefix = prefixLenDecoder_->validValuesCount(); - bufferedPrefixLength_.resize(numPrefix); - prefixLenDecoder_->readValues( - bufferedPrefixLength_.data(), numPrefix); - prefixLenOffset_ = 0; - numValidValues_ = numPrefix; - - suffixDecoder_ = std::make_unique( - prefixLenDecoder_->bufferStart()); - } + virtual std::string_view readString() = 0; void skip(uint64_t numValues) { skip(numValues, 0, nullptr); @@ -120,8 +76,61 @@ class DeltaByteArrayDecoder { } } } +}; + +// DeltaByteArrayDecoder is adapted from Apache Arrow: +// https://github.com/apache/arrow/blob/apache-arrow-15.0.0/cpp/src/parquet/encoding.cc#L2758-L2889 +class DeltaLengthByteArrayDecoder : public DeltaByteArrayDecoderBase { + public: + explicit DeltaLengthByteArrayDecoder(const char* start) { + lengthDecoder_ = std::make_unique(start); + decodeLengths(); + bufferStart_ = lengthDecoder_->bufferStart(); + } + + std::string_view readString() override { + const int64_t length = bufferedLength_[lengthIdx_++]; + VELOX_CHECK_GE(length, 0, "negative string delta length"); + bufferStart_ += length; + return std::string_view(bufferStart_ - length, length); + } + + private: + void decodeLengths() { + int64_t numLength = lengthDecoder_->validValuesCount(); + bufferedLength_.resize(numLength); + lengthDecoder_->readValues( + bufferedLength_.data(), static_cast(numLength)); + + lengthIdx_ = 0; + numValidValues_ = static_cast(numLength); + } + + const char* bufferStart_; + std::unique_ptr lengthDecoder_; + int32_t numValidValues_{0}; + uint32_t lengthIdx_{0}; + std::vector bufferedLength_; +}; + +// DeltaByteArrayDecoder is adapted from Apache Arrow: +// https://github.com/apache/arrow/blob/apache-arrow-15.0.0/cpp/src/parquet/encoding.cc#L3301-L3545 +class DeltaByteArrayDecoder : public DeltaByteArrayDecoderBase { + public: + explicit DeltaByteArrayDecoder(const char* start) { + prefixLenDecoder_ = std::make_unique(start); + int64_t numPrefix = prefixLenDecoder_->validValuesCount(); + bufferedPrefixLength_.resize(numPrefix); + prefixLenDecoder_->readValues( + bufferedPrefixLength_.data(), static_cast(numPrefix)); + prefixLenOffset_ = 0; + numValidValues_ = static_cast(numPrefix); + + suffixDecoder_ = std::make_unique( + prefixLenDecoder_->bufferStart()); + } - std::string_view readString() { + std::string_view readString() override { auto suffix = suffixDecoder_->readString(); bool isFirstRun = (prefixLenOffset_ == 0); const int64_t prefixLength = bufferedPrefixLength_[prefixLenOffset_++]; diff --git a/velox/dwio/parquet/reader/FloatingPointColumnReader.h b/velox/dwio/parquet/reader/FloatingPointColumnReader.h index cac475c0ee9..1a3fc9c4c58 100644 --- a/velox/dwio/parquet/reader/FloatingPointColumnReader.h +++ b/velox/dwio/parquet/reader/FloatingPointColumnReader.h @@ -36,6 +36,13 @@ class FloatingPointColumnReader ParquetParams& params, common::ScanSpec& scanSpec); + // Parquet floating point reader always supports a bulk path + static constexpr bool kHasBulkPath = true; + + bool hasBulkPath() const override { + return kHasBulkPath; + } + void seekToRowGroup(int64_t index) override { base::seekToRowGroup(index); this->scanState().clear(); @@ -66,7 +73,16 @@ FloatingPointColumnReader::FloatingPointColumnReader( requestedType, std::move(fileType), params, - scanSpec) {} + scanSpec) { + VELOX_DCHECK( + (this->requestedType_->kind() == TypeKind::REAL && + std::is_same_v) || + (this->requestedType_->kind() == TypeKind::DOUBLE && + std::is_same_v), + "TRequested type mismatch: template parameter is {}, but requestedType is {}", + folly::demangle(typeid(TRequested)), + this->requestedType_->toString()); +} template uint64_t FloatingPointColumnReader::skip( diff --git a/velox/dwio/parquet/reader/IntegerColumnReader.h b/velox/dwio/parquet/reader/IntegerColumnReader.h index 8c2aa2b4df1..fe8afcf90ce 100644 --- a/velox/dwio/parquet/reader/IntegerColumnReader.h +++ b/velox/dwio/parquet/reader/IntegerColumnReader.h @@ -17,6 +17,10 @@ #pragma once #include "velox/dwio/common/SelectiveIntegerColumnReader.h" +#include "velox/dwio/parquet/reader/ParquetColumnReader.h" +#include "velox/dwio/parquet/reader/ParquetData.h" +#include "velox/dwio/parquet/reader/ParquetTypeWithId.h" +#include "velox/type/DecimalUtil.h" namespace facebook::velox::parquet { @@ -61,6 +65,9 @@ class IntegerColumnReader : public dwio::common::SelectiveIntegerColumnReader { getUnsignedIntValues(rows, requestedType_, result); } else { getIntValues(rows, requestedType_, result); + if (requestedType_->isDecimal() && !allNull_) { + rescaleDecimalValues(fileType, *result); + } } } @@ -82,6 +89,67 @@ class IntegerColumnReader : public dwio::common::SelectiveIntegerColumnReader { void readWithVisitor(const RowSet& rows, ColumnVisitor visitor) { formatData_->as().readWithVisitor(visitor); } + + private: + // Rescales integer or decimal values to match the requested decimal type. + // For INT->Decimal, fileScale is 0; for Decimal->Decimal, fileScale comes + // from the file's decimal type. + void rescaleDecimalValues( + const ParquetTypeWithId& fileType, + VectorPtr& result) { + int32_t requestedScale = getDecimalPrecisionScale(*requestedType_).second; + int32_t fileScale = fileType.type()->isDecimal() + ? getDecimalPrecisionScale(*fileType.type()).second + : 0; + int32_t scaleAdjust = requestedScale - fileScale; + VELOX_USER_CHECK_GE( + scaleAdjust, + 0, + "Parquet does not support scale narrowing: {}", + scaleAdjust); + VELOX_USER_CHECK_LE( + scaleAdjust, + LongDecimalType::kMaxPrecision, + "Scale adjustment exceeds max decimal precision: {}", + scaleAdjust); + + if (scaleAdjust > 0) { + if (requestedType_->isShortDecimal()) { + // Safe to cast: kPowersOfTen[scaleAdjust] fits in int64_t because + // scaleAdjust <= maxPrecision(18) and 10^18 < 2^63. + applyDecimalScaleMultiplier( + result, + static_cast(DecimalUtil::kPowersOfTen[scaleAdjust])); + } else { + applyDecimalScaleMultiplier( + result, DecimalUtil::kPowersOfTen[scaleAdjust]); + } + } + } + + /// Multiplies all non-null values in result by multiplier. + /// Overflow is impossible because convertType validates precInc >= scaleInc, + /// guaranteeing that originalValue * 10^scaleAdjust fits within the target + /// precision. + template + void applyDecimalScaleMultiplier(const VectorPtr& result, T multiplier) + const { + auto* flat = result->asUnchecked>(); + auto* rawValues = flat->mutableRawValues(); + const auto* rawNulls = flat->rawNulls(); + const auto size = flat->size(); + if (!rawNulls) { + for (vector_size_t i = 0; i < size; ++i) { + rawValues[i] *= multiplier; + } + } else { + for (vector_size_t i = 0; i < size; ++i) { + if (bits::isBitSet(rawNulls, i)) { + rawValues[i] *= multiplier; + } + } + } + } }; } // namespace facebook::velox::parquet diff --git a/velox/dwio/parquet/reader/Metadata.cpp b/velox/dwio/parquet/reader/Metadata.cpp index e801386e523..05f3de17519 100644 --- a/velox/dwio/parquet/reader/Metadata.cpp +++ b/velox/dwio/parquet/reader/Metadata.cpp @@ -26,6 +26,20 @@ inline const T load(const char* ptr) { return ret; } +inline std::optional decodeInt64Stat(const std::string& bytes) { + switch (bytes.size()) { + case sizeof(int64_t): + return load(bytes.data()); + case sizeof(int32_t): + // Parquet stores Time types as int32_t (for milliseconds), but Velox's + // Time type is always int64_t, so we need to load sizeof(int32_t) and + // cast to int64_t. + return static_cast(load(bytes.data())); + default: + return std::nullopt; + } +} + template inline std::optional getMin(const thrift::Statistics& columnChunkStats) { return columnChunkStats.__isset.min_value @@ -44,6 +58,24 @@ inline std::optional getMax(const thrift::Statistics& columnChunkStats) { : std::nullopt); } +template <> +inline std::optional getMin( + const thrift::Statistics& columnChunkStats) { + return columnChunkStats.__isset.min_value + ? decodeInt64Stat(columnChunkStats.min_value) + : (columnChunkStats.__isset.min ? decodeInt64Stat(columnChunkStats.min) + : std::nullopt); +} + +template <> +inline std::optional getMax( + const thrift::Statistics& columnChunkStats) { + return columnChunkStats.__isset.max_value + ? decodeInt64Stat(columnChunkStats.max_value) + : (columnChunkStats.__isset.max ? decodeInt64Stat(columnChunkStats.max) + : std::nullopt); +} + template <> inline std::optional getMin( const thrift::Statistics& columnChunkStats) { @@ -156,22 +188,16 @@ common::CompressionKind thriftCodecToCompressionKind( switch (codec) { case thrift::CompressionCodec::UNCOMPRESSED: return common::CompressionKind::CompressionKind_NONE; - break; case thrift::CompressionCodec::SNAPPY: return common::CompressionKind::CompressionKind_SNAPPY; - break; case thrift::CompressionCodec::GZIP: return common::CompressionKind::CompressionKind_GZIP; - break; case thrift::CompressionCodec::LZO: return common::CompressionKind::CompressionKind_LZO; - break; case thrift::CompressionCodec::LZ4: return common::CompressionKind::CompressionKind_LZ4; - break; case thrift::CompressionCodec::ZSTD: return common::CompressionKind::CompressionKind_ZSTD; - break; case thrift::CompressionCodec::LZ4_RAW: return common::CompressionKind::CompressionKind_LZ4; default: @@ -321,8 +347,9 @@ FileMetaDataPtr::FileMetaDataPtr(const void* metadata) : ptr_(metadata) {} FileMetaDataPtr::~FileMetaDataPtr() = default; RowGroupMetaDataPtr FileMetaDataPtr::rowGroup(int i) const { - return RowGroupMetaDataPtr(reinterpret_cast( - &thriftFileMetaDataPtr(ptr_)->row_groups[i])); + return RowGroupMetaDataPtr( + reinterpret_cast( + &thriftFileMetaDataPtr(ptr_)->row_groups[i])); } int64_t FileMetaDataPtr::numRows() const { @@ -356,7 +383,7 @@ std::string FileMetaDataPtr::keyValueMetadataValue( return thriftFileMetaDataPtr(ptr_)->key_value_metadata[i].value; } } - VELOX_FAIL(fmt::format("Input key {} is not in the key value metadata", key)); + VELOX_FAIL("Input key {} is not in the key value metadata", key); } std::string FileMetaDataPtr::createdBy() const { diff --git a/velox/dwio/parquet/reader/PageReader.cpp b/velox/dwio/parquet/reader/PageReader.cpp index 879a97a51f3..7202adde6fb 100644 --- a/velox/dwio/parquet/reader/PageReader.cpp +++ b/velox/dwio/parquet/reader/PageReader.cpp @@ -16,12 +16,15 @@ #include "velox/dwio/parquet/reader/PageReader.h" +#include +#include + #include "velox/common/testutil/TestValue.h" +#include "velox/common/time/Timer.h" #include "velox/dwio/common/BufferUtil.h" #include "velox/dwio/common/ColumnVisitors.h" #include "velox/dwio/parquet/common/LevelConversion.h" #include "velox/dwio/parquet/thrift/ThriftTransport.h" - #include "velox/vector/FlatVector.h" #include // @manual @@ -87,7 +90,12 @@ PageHeader PageReader::readPageHeader() { if (bufferEnd_ == bufferStart_) { const void* buffer; int32_t size; - inputStream_->Next(&buffer, &size); + uint64_t readUs{0}; + { + MicrosecondTimer timer(&readUs); + inputStream_->Next(&buffer, &size); + } + stats_.pageLoadTimeNs.increment(readUs * 1'000); bufferStart_ = reinterpret_cast(buffer); bufferEnd_ = bufferStart_ + size; } @@ -106,26 +114,31 @@ PageHeader PageReader::readPageHeader() { } const char* PageReader::readBytes(int32_t size, BufferPtr& copy) { - if (bufferEnd_ == bufferStart_) { - const void* buffer = nullptr; - int32_t bufferSize = 0; - if (!inputStream_->Next(&buffer, &bufferSize)) { - VELOX_FAIL("Read past end"); + uint64_t readUs{0}; + { + MicrosecondTimer timer(&readUs); + if (bufferEnd_ == bufferStart_) { + const void* buffer = nullptr; + int32_t bufferSize = 0; + if (!inputStream_->Next(&buffer, &bufferSize)) { + VELOX_FAIL("Read past end"); + } + bufferStart_ = reinterpret_cast(buffer); + bufferEnd_ = bufferStart_ + bufferSize; } - bufferStart_ = reinterpret_cast(buffer); - bufferEnd_ = bufferStart_ + bufferSize; - } - if (bufferEnd_ - bufferStart_ >= size) { - bufferStart_ += size; - return bufferStart_ - size; - } - dwio::common::ensureCapacity(copy, size, &pool_); - dwio::common::readBytes( - size, - inputStream_.get(), - copy->asMutable(), - bufferStart_, - bufferEnd_); + if (bufferEnd_ - bufferStart_ >= size) { + bufferStart_ += size; + return bufferStart_ - size; + } + dwio::common::ensureCapacity(copy, size, &pool_); + dwio::common::readBytes( + size, + inputStream_.get(), + copy->asMutable(), + bufferStart_, + bufferEnd_); + } + stats_.pageLoadTimeNs.increment(readUs * 1'000); return copy->as(); } @@ -133,27 +146,59 @@ const char* PageReader::decompressData( const char* pageData, uint32_t compressedSize, uint32_t uncompressedSize) { - std::unique_ptr inputStream = - std::make_unique( - pageData, compressedSize, 0); - auto streamDebugInfo = - fmt::format("Page Reader: Stream {}", inputStream_->getName()); - std::unique_ptr decompressedStream = - dwio::common::compression::createDecompressor( - codec_, - std::move(inputStream), - uncompressedSize, - pool_, - getParquetDecompressionOptions(codec_), - streamDebugInfo, - nullptr, - true, - compressedSize); - dwio::common::ensureCapacity( decompressedData_, uncompressedSize, &pool_); - decompressedStream->readFully( - decompressedData_->asMutable(), uncompressedSize); + auto* dest = decompressedData_->asMutable(); + + switch (codec_) { + case common::CompressionKind::CompressionKind_SNAPPY: { + size_t actualUncompressedSize; + VELOX_CHECK( + snappy::GetUncompressedLength( + pageData, compressedSize, &actualUncompressedSize), + "Snappy: failed to get uncompressed length from corrupt data"); + VELOX_CHECK_EQ(actualUncompressedSize, uncompressedSize); + VELOX_CHECK( + snappy::RawUncompress(pageData, compressedSize, dest), + "Snappy decompression failed"); + break; + } + case common::CompressionKind::CompressionKind_ZSTD: { + thread_local std::unique_ptr zstdCtx{ + ZSTD_createDCtx(), ZSTD_freeDCtx}; + VELOX_CHECK_NOT_NULL(zstdCtx); + const auto actualUncompressedSize = ZSTD_decompressDCtx( + zstdCtx.get(), dest, uncompressedSize, pageData, compressedSize); + VELOX_CHECK( + !ZSTD_isError(actualUncompressedSize), + "ZSTD decompression failed: {}", + ZSTD_getErrorName(actualUncompressedSize)); + VELOX_CHECK_EQ(actualUncompressedSize, uncompressedSize); + break; + } + default: { + // Fallback to stream-based decompression for other codecs (gzip, lz4, + // lzo). + std::unique_ptr inputStream = + std::make_unique( + pageData, compressedSize, 0); + auto streamDebugInfo = + fmt::format("Page Reader: Stream {}", inputStream_->getName()); + std::unique_ptr decompressedStream = + dwio::common::compression::createDecompressor( + codec_, + std::move(inputStream), + uncompressedSize, + pool_, + getParquetDecompressionOptions(codec_), + streamDebugInfo, + nullptr, + true, + compressedSize); + decompressedStream->readFully(dest, uncompressedSize); + break; + } + } return decompressedData_->as(); } @@ -222,18 +267,42 @@ void PageReader::prepareDataPageV1(const PageHeader& pageHeader, int64_t row) { pageHeader.compressed_page_size, pageHeader.uncompressed_page_size); auto pageEnd = pageData_ + pageHeader.uncompressed_page_size; + auto remainingBytes = pageHeader.uncompressed_page_size; if (maxRepeat_ > 0) { + VELOX_CHECK_GE( + remainingBytes, + sizeof(int32_t), + "Insufficient bytes for repetition level length (corrupt data page?)"); uint32_t repeatLength = readField(pageData_); + remainingBytes -= sizeof(int32_t); + VELOX_CHECK_LE( + repeatLength, + remainingBytes, + "Repetition level length {} exceeds remaining page size {} (corrupt data page?)", + repeatLength, + remainingBytes); repeatDecoder_ = std::make_unique( reinterpret_cast(pageData_), repeatLength, ::arrow::bit_util::NumRequiredBits(maxRepeat_)); pageData_ += repeatLength; + remainingBytes -= repeatLength; } if (maxDefine_ > 0) { + VELOX_CHECK_GE( + remainingBytes, + sizeof(uint32_t), + "Insufficient bytes for definition level length (corrupt data page?)"); auto defineLength = readField(pageData_); + remainingBytes -= sizeof(uint32_t); + VELOX_CHECK_LE( + defineLength, + remainingBytes, + "Definition level length {} exceeds remaining page size {} (corrupt data page?)", + defineLength, + remainingBytes); if (maxDefine_ == 1) { defineDecoder_ = std::make_unique( pageData_, @@ -278,6 +347,13 @@ void PageReader::prepareDataPageV2(const PageHeader& pageHeader, int64_t row) { pageHeader.data_page_header_v2.repetition_levels_byte_length; auto bytes = pageHeader.compressed_page_size; + VELOX_CHECK_LE( + static_cast(repeatLength) + defineLength, + bytes, + "Repetition and definition level lengths ({} + {}) exceed compressed page size {} (corrupt data page?)", + repeatLength, + defineLength, + bytes); pageData_ = readBytes(bytes, pageBuffer_); if (repeatLength) { @@ -368,12 +444,17 @@ void PageReader::prepareDictionary(const PageHeader& pageHeader) { if (pageData_) { memcpy(dictionary_.values->asMutable(), pageData_, numBytes); } else { - dwio::common::readBytes( - numBytes, - inputStream_.get(), - dictionary_.values->asMutable(), - bufferStart_, - bufferEnd_); + uint64_t readUs{0}; + { + MicrosecondTimer timer(&readUs); + dwio::common::readBytes( + numBytes, + inputStream_.get(), + dictionary_.values->asMutable(), + bufferStart_, + bufferEnd_); + } + stats_.pageLoadTimeNs.increment(readUs * 1'000); } if (type_->type()->isShortDecimal() && parquetType == thrift::Type::INT32) { @@ -403,12 +484,17 @@ void PageReader::prepareDictionary(const PageHeader& pageHeader) { if (pageData_) { memcpy(dictionary_.values->asMutable(), pageData_, numBytes); } else { - dwio::common::readBytes( - numBytes, - inputStream_.get(), - dictionary_.values->asMutable(), - bufferStart_, - bufferEnd_); + uint64_t readUs{0}; + { + MicrosecondTimer timer(&readUs); + dwio::common::readBytes( + numBytes, + inputStream_.get(), + dictionary_.values->asMutable(), + bufferStart_, + bufferEnd_); + } + stats_.pageLoadTimeNs.increment(readUs * 1'000); } // Expand the Parquet type length values to Velox type length. // We start from the end to allow in-place expansion. @@ -435,8 +521,13 @@ void PageReader::prepareDictionary(const PageHeader& pageHeader) { if (pageData_) { memcpy(strings, pageData_, numBytes); } else { - dwio::common::readBytes( - numBytes, inputStream_.get(), strings, bufferStart_, bufferEnd_); + uint64_t readUs{0}; + { + MicrosecondTimer timer(&readUs); + dwio::common::readBytes( + numBytes, inputStream_.get(), strings, bufferStart_, bufferEnd_); + } + stats_.pageLoadTimeNs.increment(readUs * 1'000); } auto header = strings; for (auto i = 0; i < dictionary_.numValues; ++i) { @@ -452,18 +543,24 @@ void PageReader::prepareDictionary(const PageHeader& pageHeader) { auto numParquetBytes = dictionary_.numValues * parquetTypeLength; auto veloxTypeLength = type_->type()->cppSizeInBytes(); auto numVeloxBytes = dictionary_.numValues * veloxTypeLength; + VELOX_CHECK_LE(numParquetBytes, numVeloxBytes); dictionary_.values = AlignedBuffer::allocate(numVeloxBytes, &pool_); auto data = dictionary_.values->asMutable(); // Read the data bytes. if (pageData_) { memcpy(data, pageData_, numParquetBytes); } else { - dwio::common::readBytes( - numParquetBytes, - inputStream_.get(), - data, - bufferStart_, - bufferEnd_); + uint64_t readUs{0}; + { + MicrosecondTimer timer(&readUs); + dwio::common::readBytes( + numParquetBytes, + inputStream_.get(), + data, + bufferStart_, + bufferEnd_); + } + stats_.pageLoadTimeNs.increment(readUs * 1'000); } if (type_->type()->isShortDecimal()) { // Parquet decimal values have a fixed typeLength_ and are in big-endian @@ -752,6 +849,13 @@ void PageReader::makeDecoder() { break; } [[fallthrough]]; + case Encoding::DELTA_LENGTH_BYTE_ARRAY: + if (parquetType == thrift::Type::BYTE_ARRAY) { + deltaLengthByteArrDecoder_ = + std::make_unique(pageData_); + break; + } + [[fallthrough]]; default: VELOX_UNSUPPORTED("Encoding not supported yet: {}", encoding_); } @@ -794,6 +898,8 @@ void PageReader::skip(int64_t numRows) { deltaBpDecoder_->skip(toSkip); } else if (deltaByteArrDecoder_) { deltaByteArrDecoder_->skip(toSkip); + } else if (deltaLengthByteArrDecoder_) { + deltaLengthByteArrDecoder_->skip(toSkip); } else if (rleBooleanDecoder_) { rleBooleanDecoder_->skip(toSkip); } else { diff --git a/velox/dwio/parquet/reader/PageReader.h b/velox/dwio/parquet/reader/PageReader.h index c377100428a..f0ad90dd1c0 100644 --- a/velox/dwio/parquet/reader/PageReader.h +++ b/velox/dwio/parquet/reader/PageReader.h @@ -42,6 +42,7 @@ class PageReader { ParquetTypeWithIdPtr fileType, common::CompressionKind codec, int64_t chunkSize, + dwio::common::ColumnReaderStatistics& stats, const tz::TimeZone* sessionTimezone) : pool_(pool), inputStream_(std::move(stream)), @@ -52,6 +53,7 @@ class PageReader { codec_(codec), chunkSize_(chunkSize), nullConcatenation_(pool_), + stats_(stats), sessionTimezone_(sessionTimezone) { type_->makeLevelInfo(leafInfo_); } @@ -62,15 +64,19 @@ class PageReader { memory::MemoryPool& pool, common::CompressionKind codec, int64_t chunkSize, - const tz::TimeZone* sessionTimezone = nullptr) + dwio::common::ColumnReaderStatistics& stats, + const tz::TimeZone* sessionTimezone = nullptr, + int32_t maxRepeat = 0, + int32_t maxDefine = 1) : pool_(pool), inputStream_(std::move(stream)), - maxRepeat_(0), - maxDefine_(1), + maxRepeat_(maxRepeat), + maxDefine_(maxDefine), isTopLevel_(maxRepeat_ == 0 && maxDefine_ <= 1), codec_(codec), chunkSize_(chunkSize), nullConcatenation_(pool_), + stats_(stats), sessionTimezone_(sessionTimezone) {} /// Advances 'numRows' top level rows. @@ -264,7 +270,7 @@ class PageReader { template < typename Visitor, typename std::enable_if< - !std::is_same_v && + !std::is_same_v && !std::is_same_v, int>::type = 0> void @@ -300,7 +306,7 @@ class PageReader { template < typename Visitor, typename std::enable_if< - std::is_same_v, + std::is_same_v, int>::type = 0> void callDecoder(const uint64_t* nulls, bool& nullsFromFastPath, Visitor visitor) { @@ -312,6 +318,9 @@ class PageReader { } else if (encoding_ == thrift::Encoding::DELTA_BYTE_ARRAY) { nullsFromFastPath = false; deltaByteArrDecoder_->readWithVisitor(nulls, visitor); + } else if (encoding_ == thrift::Encoding::DELTA_LENGTH_BYTE_ARRAY) { + nullsFromFastPath = false; + deltaLengthByteArrDecoder_->readWithVisitor(nulls, visitor); } else { nullsFromFastPath = false; stringDecoder_->readWithVisitor(nulls, visitor); @@ -322,6 +331,8 @@ class PageReader { dictionaryIdDecoder_->readWithVisitor(nullptr, dictVisitor); } else if (encoding_ == thrift::Encoding::DELTA_BYTE_ARRAY) { deltaByteArrDecoder_->readWithVisitor(nulls, visitor); + } else if (encoding_ == thrift::Encoding::DELTA_LENGTH_BYTE_ARRAY) { + deltaLengthByteArrDecoder_->readWithVisitor(nulls, visitor); } else { stringDecoder_->readWithVisitor(nulls, visitor); } @@ -359,14 +370,17 @@ class PageReader { // Returns the number of passed rows/values gathered by // 'reader'. Only numRows() is set for a filter-only case, only // numValues() is set for a non-filtered case. - template - static int32_t numRowsInReader( - const dwio::common::SelectiveColumnReader& reader) { + template + static int32_t numValuesRead( + const dwio::common::SelectiveColumnReader& reader, + const int32_t numPageRowsRead) { + if (hasHook) { + return numPageRowsRead; + } if (hasFilter) { return reader.numRows(); - } else { - return reader.numValues(); } + return reader.numValues(); } memory::MemoryPool& pool_; @@ -502,6 +516,8 @@ class PageReader { // Base values of dictionary when reading a string dictionary. VectorPtr dictionaryValues_; + dwio::common::ColumnReaderStatistics& stats_; + const tz::TimeZone* sessionTimezone_{nullptr}; // Decoders. Only one will be set at a time. @@ -511,6 +527,7 @@ class PageReader { std::unique_ptr booleanDecoder_; std::unique_ptr deltaBpDecoder_; std::unique_ptr deltaByteArrDecoder_; + std::unique_ptr deltaLengthByteArrDecoder_; std::unique_ptr rleBooleanDecoder_; // Add decoders for other encodings here. }; @@ -537,6 +554,10 @@ void PageReader::readWithVisitor(Visitor& visitor) { !std::is_same_v; constexpr bool filterOnly = std::is_same_v; + constexpr bool hasHook = Visitor::kHasHook; + static_assert( + !(hasFilter && hasHook), "hasFilter and hasHook cannot both be true"); + bool mayProduceNulls = !filterOnly && visitor.allowNulls(); auto rows = visitor.rows(); auto numRows = visitor.numRows(); @@ -544,11 +565,13 @@ void PageReader::readWithVisitor(Visitor& visitor) { startVisit(folly::Range(rows, numRows)); rowsCopy_ = &visitor.rowsCopy(); folly::Range pageRows; + int32_t numPageRowsRead = 0; const uint64_t* nulls = nullptr; bool isMultiPage = false; while (rowsForPage(reader, hasFilter, mayProduceNulls, pageRows, nulls)) { bool nullsFromFastPath = false; - int32_t numValuesBeforePage = numRowsInReader(reader); + const int32_t numValuesBeforePage = + numValuesRead(reader, numPageRowsRead); visitor.setNumValuesBias(numValuesBeforePage); visitor.setRows(pageRows); callDecoder(nulls, nullsFromFastPath, visitor); @@ -571,20 +594,23 @@ void PageReader::readWithVisitor(Visitor& visitor) { } if (!nulls) { nullConcatenation_.appendOnes( - numRowsInReader(reader) - numValuesBeforePage); + numValuesRead(reader, numPageRowsRead) - + numValuesBeforePage); } else if (reader.returnReaderNulls()) { // Nulls from decoding go directly to result. nullConcatenation_.append( reader.nullsInReadRange()->template as(), 0, - numRowsInReader(reader) - numValuesBeforePage); + numValuesRead(reader, numPageRowsRead) - + numValuesBeforePage); } else { // Add the nulls produced from the decoder to the result. auto firstNullIndex = nullsFromFastPath ? 0 : numValuesBeforePage; nullConcatenation_.append( reader.mutableNulls(0), firstNullIndex, - firstNullIndex + numRowsInReader(reader) - + firstNullIndex + + numValuesRead(reader, numPageRowsRead) - numValuesBeforePage); } } @@ -598,6 +624,7 @@ void PageReader::readWithVisitor(Visitor& visitor) { if (hasFilter && rowNumberBias_) { reader.offsetOutputRows(numValuesBeforePage, rowNumberBias_); } + numPageRowsRead += pageRows.size(); } if (isMultiPage) { reader.setNulls(mayProduceNulls ? nullConcatenation_.buffer() : nullptr); diff --git a/velox/dwio/parquet/reader/ParquetColumnReader.cpp b/velox/dwio/parquet/reader/ParquetColumnReader.cpp index 0b69f446280..b041d9f94be 100644 --- a/velox/dwio/parquet/reader/ParquetColumnReader.cpp +++ b/velox/dwio/parquet/reader/ParquetColumnReader.cpp @@ -27,6 +27,7 @@ #include "velox/dwio/parquet/reader/RepeatedColumnReader.h" #include "velox/dwio/parquet/reader/StringColumnReader.h" #include "velox/dwio/parquet/reader/StructColumnReader.h" +#include "velox/dwio/parquet/reader/TimeColumnReader.h" #include "velox/dwio/parquet/reader/TimestampColumnReader.h" #include "velox/dwio/parquet/thrift/ParquetThriftTypes.h" @@ -39,8 +40,20 @@ std::unique_ptr ParquetColumnReader::build( const std::shared_ptr& fileType, ParquetParams& params, common::ScanSpec& scanSpec) { + VELOX_CHECK_EQ( + static_cast(scanSpec.extractionType()), + static_cast(common::ScanSpec::ExtractionType::kNone), + "Parquet reader does not support extraction pushdown"); auto colName = scanSpec.fieldName(); + if (fileType->type()->isTime()) { + VELOX_CHECK( + fileType->type()->equivalent(*TIME()) || + fileType->type()->equivalent(*TIME_MICRO_UTC())); + return std::make_unique( + requestedType, fileType, params, scanSpec); + } + switch (fileType->type()->kind()) { case TypeKind::INTEGER: case TypeKind::BIGINT: @@ -51,8 +64,13 @@ std::unique_ptr ParquetColumnReader::build( requestedType, fileType, params, scanSpec); case TypeKind::REAL: - return std::make_unique>( - requestedType, fileType, params, scanSpec); + if (requestedType->kind() == TypeKind::REAL) { + return std::make_unique>( + requestedType, fileType, params, scanSpec); + } else { + return std::make_unique>( + requestedType, fileType, params, scanSpec); + } case TypeKind::DOUBLE: return std::make_unique>( requestedType, fileType, params, scanSpec); @@ -65,9 +83,11 @@ std::unique_ptr ParquetColumnReader::build( case TypeKind::VARCHAR: return std::make_unique(fileType, params, scanSpec); - case TypeKind::ARRAY: + case TypeKind::ARRAY: { + VELOX_CHECK(requestedType->isArray(), "Requested type must be array"); return std::make_unique( columnReaderOptions, requestedType, fileType, params, scanSpec); + } case TypeKind::MAP: return std::make_unique( @@ -97,7 +117,7 @@ std::unique_ptr ParquetColumnReader::build( default: VELOX_FAIL( "buildReader unhandled type: " + - mapTypeKindToName(fileType->type()->kind())); + std::string(TypeKindName::toName(fileType->type()->kind()))); } } diff --git a/velox/dwio/parquet/reader/ParquetData.cpp b/velox/dwio/parquet/reader/ParquetData.cpp index 788d04e3962..d91e07949c5 100644 --- a/velox/dwio/parquet/reader/ParquetData.cpp +++ b/velox/dwio/parquet/reader/ParquetData.cpp @@ -25,7 +25,7 @@ std::unique_ptr ParquetParams::toFormatData( const std::shared_ptr& type, const common::ScanSpec& /*scanSpec*/) { return std::make_unique( - type, metaData_, pool(), sessionTimezone_); + type, metaData_, pool(), runtimeStatistics(), sessionTimezone_); } void ParquetData::filterRowGroups( @@ -128,6 +128,7 @@ dwio::common::PositionProvider ParquetData::seekToRowGroup(int64_t index) { type_, metadata.compression(), metadata.totalCompressedSize(), + stats_, sessionTimezone_); return dwio::common::PositionProvider(empty); } @@ -137,7 +138,8 @@ std::pair ParquetData::getRowGroupRegion( auto rowGroup = fileMetaDataPtr_.rowGroup(index); VELOX_CHECK_GT(rowGroup.numColumns(), 0); - auto fileOffset = rowGroup.hasFileOffset() ? rowGroup.fileOffset() + auto fileOffset = (rowGroup.hasFileOffset() && rowGroup.fileOffset() != 0) + ? rowGroup.fileOffset() : rowGroup.columnChunk(0).hasDictionaryPageOffset() ? rowGroup.columnChunk(0).dictionaryPageOffset() : rowGroup.columnChunk(0).dataPageOffset(); diff --git a/velox/dwio/parquet/reader/ParquetData.h b/velox/dwio/parquet/reader/ParquetData.h index 1ea4a1e8c77..9926202491d 100644 --- a/velox/dwio/parquet/reader/ParquetData.h +++ b/velox/dwio/parquet/reader/ParquetData.h @@ -63,6 +63,7 @@ class ParquetData : public dwio::common::FormatData { const std::shared_ptr& type, const FileMetaDataPtr fileMetadataPtr, memory::MemoryPool& pool, + dwio::common::ColumnReaderStatistics& stats, const tz::TimeZone* sessionTimezone) : pool_(pool), type_(std::static_pointer_cast(type)), @@ -70,6 +71,7 @@ class ParquetData : public dwio::common::FormatData { maxDefine_(type_->maxDefine_), maxRepeat_(type_->maxRepeat_), rowsInRowGroup_(-1), + stats_(stats), sessionTimezone_(sessionTimezone) {} /// Prepares to read data for 'index'th row group. @@ -90,8 +92,9 @@ class ParquetData : public dwio::common::FormatData { return reader_.get(); } - // Reads null flags for 'numValues' next top level rows. The first 'numValues' - // bits of 'nulls' are set and the reader is advanced by numValues'. + // Reads null flags for 'numValues' next top level rows. The first + // 'numValues' bits of 'nulls' are set and the reader is advanced by + // numValues'. void readNullsOnly(int32_t numValues, BufferPtr& nulls) { reader_->readNullsOnly(numValues, nulls); } @@ -100,8 +103,9 @@ class ParquetData : public dwio::common::FormatData { return maxDefine_ > 0; } - /// Sets nulls to be returned by readNulls(). Nulls for non-leaf readers come - /// from leaf repdefs which are gathered before descending the reader tree. + /// Sets nulls to be returned by readNulls(). Nulls for non-leaf readers + /// come from leaf repdefs which are gathered before descending the reader + /// tree. void setNulls(BufferPtr& nulls, int32_t numValues) { if (nulls || numValues) { VELOX_CHECK_EQ(presetNullsConsumed_, presetNullsSize_); @@ -120,8 +124,8 @@ class ParquetData : public dwio::common::FormatData { const uint64_t* incomingNulls, BufferPtr& nulls, bool nullsOnly = false) override { - // If the query accesses only nulls, read the nulls from the pages in range. - // If nulls are preread, return those minus any skipped. + // If the query accesses only nulls, read the nulls from the pages in + // range. If nulls are preread, return those minus any skipped. if (presetNulls_) { VELOX_CHECK_LE(numValues, presetNullsSize_ - presetNullsConsumed_); if (!presetNullsConsumed_ && numValues == presetNullsSize_) { @@ -144,8 +148,8 @@ class ParquetData : public dwio::common::FormatData { readNullsOnly(numValues, nulls); return; } - // There are no column-level nulls in Parquet, only page-level ones, so this - // is always non-null. + // There are no column-level nulls in Parquet, only page-level ones, so + // this is always non-null. nulls = nullptr; } @@ -219,6 +223,7 @@ class ParquetData : public dwio::common::FormatData { const uint32_t maxDefine_; const uint32_t maxRepeat_; int64_t rowsInRowGroup_; + dwio::common::ColumnReaderStatistics& stats_; const tz::TimeZone* sessionTimezone_; std::unique_ptr reader_; diff --git a/velox/dwio/parquet/reader/ParquetReader.cpp b/velox/dwio/parquet/reader/ParquetReader.cpp index 955abc91b8a..005be63f15e 100644 --- a/velox/dwio/parquet/reader/ParquetReader.cpp +++ b/velox/dwio/parquet/reader/ParquetReader.cpp @@ -18,7 +18,9 @@ #include //@manual +#include "velox/dwio/common/StatisticsBuilder.h" #include "velox/dwio/parquet/reader/ParquetColumnReader.h" +#include "velox/dwio/parquet/reader/ParquetStatsContext.h" #include "velox/dwio/parquet/reader/StructColumnReader.h" #include "velox/dwio/parquet/thrift/ThriftTransport.h" #include "velox/functions/lib/string/StringImpl.h" @@ -27,16 +29,109 @@ namespace facebook::velox::parquet { namespace { +/// Finds the node with the given ID in the TypeWithId tree. Uses a full +/// traversal because Parquet's TypeWithId nodes all share the same maxId +/// (the global max schema element index), so the maxId-based pruning used +/// by ORC/DWRF does not work here. +const dwio::common::TypeWithId* findNode( + const dwio::common::TypeWithId& root, + uint32_t nodeId) { + if (root.id() == nodeId) { + return &root; + } + for (auto i = 0; i < root.size(); ++i) { + if (auto* result = findNode(*root.childAt(i), nodeId)) { + return result; + } + } + return nullptr; +} + bool isParquetReservedKeyword( std::string name, uint32_t parentSchemaIdx, uint32_t curSchemaIdx) { - return ((parentSchemaIdx == 0 && curSchemaIdx == 0) || name == "key_value" || - name == "key" || name == "value" || name == "list" || - name == "element" || name == "bag" || name == "array_element") + // We skip this for the top-level nodes. + return ((parentSchemaIdx == 0 && curSchemaIdx == 0) || + (parentSchemaIdx != 0 && + (name == "key_value" || name == "key" || name == "value" || + name == "list" || name == "element" || name == "bag" || + name == "array_element"))) ? true : false; } + +// An unannotated array in Parquet is a repeated field that is not explicitly +// marked as a LIST logical type. If current schema element is a repeated field +// and the requested type is an array, we treat the current schema element as an +// unannotated array, and returns true if the element type is compatible with +// the physical type. +bool isCompatible( + const TypePtr& requestedType, + bool isRepeated, + const std::function& isCompatibleFunc) { + return isCompatibleFunc(requestedType) || + (requestedType->isArray() && isRepeated && + isCompatibleFunc(requestedType->asArray().elementType())); +} + +// Checks if a decimal type has enough integer precision to hold all values +// of the given Parquet physical int type. +bool hasEnoughDecimalPrecision(const TypePtr& type, int32_t minIntegerDigits) { + if (!type->isDecimal()) { + return false; + } + auto [precision, scale] = getDecimalPrecisionScale(*type); + return (precision - scale) >= minIntegerDigits; +} + +// Checks if a type is compatible with an INT32 physical type. +// INT_8, INT_16, and INT_32 are all stored as Parquet INT32. +// 'minTypeKind' is the smallest Velox type that matches the file's +// converted type annotation (TINYINT for INT_8, SMALLINT for INT_16, +// INTEGER for INT_32 or unannotated INT32). +// For decimal targets, requires precision - scale >= 10. +// When 'allowNarrowing' is true, any integer type is accepted and the +// value is silently truncated on overflow. When false, only same-size +// or wider types are allowed. +bool isInt32Compatible( + const TypePtr& type, + TypeKind minTypeKind, + bool allowNarrowing) { + static_assert( + TypeKind::TINYINT < TypeKind::SMALLINT && + TypeKind::SMALLINT < TypeKind::INTEGER && + TypeKind::INTEGER < TypeKind::BIGINT, + "TypeKind enum ordering mismatch"); + + if (type->isDecimal()) { + return hasEnoughDecimalPrecision(type, 10); + } + + auto kind = type->kind(); + switch (kind) { + case TypeKind::TINYINT: + case TypeKind::SMALLINT: + case TypeKind::INTEGER: + case TypeKind::BIGINT: + return allowNarrowing || kind >= minTypeKind; + case TypeKind::DOUBLE: + return true; + default: + return false; + } +} + +// Checks whether the given type is compatible with a Parquet INT64 source. +// Accepts BIGINT identity mapping and Decimal targets with sufficient +// precision (precision - scale >= 20, covering the full INT64 range). +bool isInt64Compatible(const TypePtr& type) { + if (type->isDecimal()) { + return hasEnoughDecimalPrecision(type, 20); + } + return type->kind() == TypeKind::BIGINT; +} + } // namespace /// Metadata and options for reading Parquet. @@ -138,7 +233,7 @@ class ReaderBase { bool fileColumnNamesReadAsLowerCase); memory::MemoryPool& pool_; - const uint64_t footerEstimatedSize_; + const uint64_t footerSpeculativeIoSize_; const uint64_t filePreloadThreshold_; // Copy of options. Must be owned by 'this'. const dwio::common::ReaderOptions options_; @@ -159,7 +254,7 @@ ReaderBase::ReaderBase( std::unique_ptr input, const dwio::common::ReaderOptions& options) : pool_{options.memoryPool()}, - footerEstimatedSize_{options.footerEstimatedSize()}, + footerSpeculativeIoSize_{options.footerSpeculativeIoSize()}, filePreloadThreshold_{options.filePreloadThreshold()}, options_{options}, input_{std::move(input)}, @@ -174,8 +269,8 @@ ReaderBase::ReaderBase( void ReaderBase::loadFileMetaData() { bool preloadFile = - fileLength_ <= std::max(filePreloadThreshold_, footerEstimatedSize_); - uint64_t readSize = preloadFile ? fileLength_ : footerEstimatedSize_; + fileLength_ <= std::max(filePreloadThreshold_, footerSpeculativeIoSize_); + uint64_t readSize = preloadFile ? fileLength_ : footerSpeculativeIoSize_; std::unique_ptr stream; if (preloadFile) { @@ -308,8 +403,7 @@ std::unique_ptr ReaderBase::getParquetColumnInfo( name = functions::stringImpl::utf8StrToLowerCopy(name); } - if ((!options_.useColumnNamesForColumnMapping()) && - (options_.fileSchema() != nullptr)) { + if (!options_.useColumnNamesForColumnMapping() && options_.fileSchema()) { if (isParquetReservedKeyword(name, parentSchemaIdx, curSchemaIdx)) { columnNames.push_back(name); } @@ -337,21 +431,37 @@ std::unique_ptr ReaderBase::getParquetColumnInfo( TypePtr childRequestedType = nullptr; bool followChild = true; - if (requestedType && requestedType->isRow()) { - auto requestedRowType = - std::dynamic_pointer_cast(requestedType); - if (options_.useColumnNamesForColumnMapping()) { - auto fileTypeIdx = requestedRowType->getChildIdxIfExists(childName); - if (fileTypeIdx.has_value()) { - childRequestedType = requestedRowType->childAt(*fileTypeIdx); + + { + RowTypePtr requestedRowType = nullptr; + if (requestedType) { + if (requestedType->isRow()) { + requestedRowType = + std::dynamic_pointer_cast(requestedType); + } else if ( + requestedType->isArray() && isRepeated && + requestedType->asArray().elementType()->isRow()) { + // Handle the case of unannotated array of structs (repeated group + // without LIST annotation). + requestedRowType = std::dynamic_pointer_cast( + requestedType->asArray().elementType()); } - } else { - // Handle schema evolution. - if (i < requestedRowType->size()) { - columnNames.push_back(requestedRowType->nameOf(i)); - childRequestedType = requestedRowType->childAt(i); + } + + if (requestedRowType) { + if (options_.useColumnNamesForColumnMapping()) { + auto fileTypeIdx = requestedRowType->getChildIdxIfExists(childName); + if (fileTypeIdx.has_value()) { + childRequestedType = requestedRowType->childAt(*fileTypeIdx); + } } else { - followChild = false; + // Handle schema evolution. + if (i < requestedRowType->size()) { + columnNames.push_back(requestedRowType->nameOf(i)); + childRequestedType = requestedRowType->childAt(i); + } else { + followChild = false; + } } } } @@ -534,20 +644,21 @@ std::unique_ptr ReaderBase::getParquetColumnInfo( // In this legacy case, there is no middle layer between "array" // node and the children nodes. Below creates this dummy middle // layer to mimic the non-legacy case and fill the gap. - rowChildren.emplace_back(std::make_unique( - childrenRowType, - std::move(children), - curSchemaIdx, - maxSchemaElementIdx, - ParquetTypeWithId::kNonLeaf, - "dummy", - std::nullopt, - std::nullopt, - std::nullopt, - maxRepeat, - maxDefine, - isOptional, - isRepeated)); + rowChildren.emplace_back( + std::make_unique( + childrenRowType, + std::move(children), + curSchemaIdx, + maxSchemaElementIdx, + ParquetTypeWithId::kNonLeaf, + "dummy", + std::nullopt, + std::nullopt, + std::nullopt, + maxRepeat, + maxDefine, + isOptional, + isRepeated)); auto res = std::make_unique( TypeFactory::create(childrenRowType), std::move(rowChildren), @@ -598,20 +709,21 @@ std::unique_ptr ReaderBase::getParquetColumnInfo( // In this legacy case, there is no middle layer between "array" // node and the children nodes. Below creates this dummy middle // layer to mimic the non-legacy case and fill the gap. - rowChildren.emplace_back(std::make_unique( - childrenRowType, - std::move(children), - curSchemaIdx, - maxSchemaElementIdx, - ParquetTypeWithId::kNonLeaf, - "dummy", - std::nullopt, - std::nullopt, - std::nullopt, - maxRepeat, - maxDefine, - isOptional, - isRepeated)); + rowChildren.emplace_back( + std::make_unique( + childrenRowType, + std::move(children), + curSchemaIdx, + maxSchemaElementIdx, + ParquetTypeWithId::kNonLeaf, + "dummy", + std::nullopt, + std::nullopt, + std::nullopt, + maxRepeat, + maxDefine, + isOptional, + isRepeated)); return std::make_unique( TypeFactory::create(childrenRowType), std::move(rowChildren), @@ -720,8 +832,13 @@ TypePtr ReaderBase::convertType( schemaElement.__isset.type_length, "FIXED_LEN_BYTE_ARRAY requires length to be set"); - static std::string_view kTypeMappingErrorFmtStr = - "Converted type {} is not allowed for requested type {}"; + static constexpr const char* kTypeMappingErrorFmtStr = + "Converted type {} is not allowed for requested type {} for file column '{}'"; + + const bool isRepeated = schemaElement.__isset.repetition_type && + schemaElement.repetition_type == thrift::FieldRepetitionType::REPEATED; + const bool allowNarrowing = options_.allowInt32Narrowing(); + if (schemaElement.__isset.converted_type) { switch (schemaElement.converted_type) { case thrift::ConvertedType::INT_8: @@ -732,13 +849,18 @@ TypePtr ReaderBase::convertType( "{} converted type can only be set for value of thrift::Type::INT32", schemaElement.converted_type); VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::TINYINT || - requestedType->kind() == TypeKind::SMALLINT || - requestedType->kind() == TypeKind::INTEGER || - requestedType->kind() == TypeKind::BIGINT, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [&](const TypePtr& type) { + return isInt32Compatible( + type, TypeKind::TINYINT, allowNarrowing); + }), kTypeMappingErrorFmtStr, "TINYINT", - requestedType->toString()); + requestedType->toString(), + schemaElement.name); return TINYINT(); case thrift::ConvertedType::INT_16: @@ -749,12 +871,18 @@ TypePtr ReaderBase::convertType( "{} converted type can only be set for value of thrift::Type::INT32", schemaElement.converted_type); VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::SMALLINT || - requestedType->kind() == TypeKind::INTEGER || - requestedType->kind() == TypeKind::BIGINT, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [&](const TypePtr& type) { + return isInt32Compatible( + type, TypeKind::SMALLINT, allowNarrowing); + }), kTypeMappingErrorFmtStr, "SMALLINT", - requestedType->toString()); + requestedType->toString(), + schemaElement.name); return SMALLINT(); case thrift::ConvertedType::INT_32: @@ -765,11 +893,18 @@ TypePtr ReaderBase::convertType( "{} converted type can only be set for value of thrift::Type::INT32", schemaElement.converted_type); VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::INTEGER || - requestedType->kind() == TypeKind::BIGINT, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [&](const TypePtr& type) { + return isInt32Compatible( + type, TypeKind::INTEGER, allowNarrowing); + }), kTypeMappingErrorFmtStr, "INTEGER", - requestedType->toString()); + requestedType->toString(), + schemaElement.name); return INTEGER(); case thrift::ConvertedType::INT_64: @@ -777,13 +912,15 @@ TypePtr ReaderBase::convertType( VELOX_CHECK_EQ( schemaElement.type, thrift::Type::INT64, - "{} converted type can only be set for value of thrift::Type::INT32", + "{} converted type can only be set for value of thrift::Type::INT64", schemaElement.converted_type); VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::BIGINT, + !requestedType || + isCompatible(requestedType, isRepeated, isInt64Compatible), kTypeMappingErrorFmtStr, "BIGINT", - requestedType->toString()); + requestedType->toString(), + schemaElement.name); return BIGINT(); case thrift::ConvertedType::DATE: @@ -792,10 +929,15 @@ TypePtr ReaderBase::convertType( thrift::Type::INT32, "DATE converted type can only be set for value of thrift::Type::INT32"); VELOX_CHECK( - !requestedType || requestedType->isDate(), + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { return type->isDate(); }), kTypeMappingErrorFmtStr, "DATE", - requestedType->toString()); + requestedType->toString(), + schemaElement.name); return DATE(); case thrift::ConvertedType::TIMESTAMP_MICROS: @@ -805,10 +947,17 @@ TypePtr ReaderBase::convertType( thrift::Type::INT64, "TIMESTAMP_MICROS or TIMESTAMP_MILLIS converted type can only be set for value of thrift::Type::INT64"); VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::TIMESTAMP, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::TIMESTAMP; + }), kTypeMappingErrorFmtStr, "TIMESTAMP", - requestedType->toString()); + requestedType->toString(), + schemaElement.name); return TIMESTAMP(); case thrift::ConvertedType::DECIMAL: { @@ -817,39 +966,35 @@ TypePtr ReaderBase::convertType( "DECIMAL requires a length and scale specifier!"); const auto schemaElementPrecision = schemaElement.precision; const auto schemaElementScale = schemaElement.scale; - // A long decimal requested type cannot read a value of a short decimal. - // As a result, the mapping from short to long decimal is currently - // restricted. auto type = DECIMAL(schemaElementPrecision, schemaElementScale); if (requestedType) { VELOX_CHECK( - requestedType->isDecimal(), + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { return type->isDecimal(); }), kTypeMappingErrorFmtStr, "DECIMAL", - requestedType->toString()); - // Reading short decimals with a long decimal requested type is not - // yet possible. To allow for correct interpretation of the values, - // the scale of the file type and requested type must match while - // precision may be larger. - if (requestedType->isShortDecimal()) { - const auto& shortDecimalType = requestedType->asShortDecimal(); - VELOX_CHECK( - type->isShortDecimal() && - shortDecimalType.precision() >= schemaElementPrecision && - shortDecimalType.scale() == schemaElementScale, - kTypeMappingErrorFmtStr, - type->toString(), - requestedType->toString()); - } else { - const auto& longDecimalType = requestedType->asLongDecimal(); - VELOX_CHECK( - type->isLongDecimal() && - longDecimalType.precision() >= schemaElementPrecision && - longDecimalType.scale() == schemaElementScale, - kTypeMappingErrorFmtStr, - type->toString(), - requestedType->toString()); - } + requestedType->toString(), + schemaElement.name); + // Allow decimal widening: precision may be larger and scale may + // increase as long as precisionIncrease >= scaleIncrease. + // Short-to-long decimal crossing is handled by getDecimalValues + // via the upcast path. + VELOX_CHECK( + isCompatible( + requestedType, + isRepeated, + [&](const TypePtr& type) { + auto [precision, scale] = getDecimalPrecisionScale(*type); + auto precisionInc = precision - schemaElementPrecision; + auto scaleInc = scale - schemaElementScale; + return scaleInc >= 0 && precisionInc >= scaleInc; + }), + kTypeMappingErrorFmtStr, + type->toString(), + requestedType->toString(), + schemaElement.name); } return type; } @@ -859,10 +1004,17 @@ TypePtr ReaderBase::convertType( case thrift::Type::BYTE_ARRAY: case thrift::Type::FIXED_LEN_BYTE_ARRAY: VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::VARCHAR, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::VARCHAR; + }), kTypeMappingErrorFmtStr, "VARCHAR", - requestedType->toString()); + requestedType->toString(), + schemaElement.name); return VARCHAR(); default: VELOX_FAIL( @@ -874,17 +1026,58 @@ TypePtr ReaderBase::convertType( thrift::Type::BYTE_ARRAY, "ENUM converted type can only be set for value of thrift::Type::BYTE_ARRAY"); VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::VARCHAR, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::VARCHAR; + }), kTypeMappingErrorFmtStr, "VARCHAR", - requestedType->toString()); + requestedType->toString(), + schemaElement.name); return VARCHAR(); } + case thrift::ConvertedType::TIME_MILLIS: + VELOX_CHECK_EQ( + schemaElement.type, + thrift::Type::INT32, + "TIME_MILLIS converted type can only be set for value of thrift::Type::INT32"); + VELOX_CHECK( + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->equivalent(*TIME()); + }), + kTypeMappingErrorFmtStr, + "TIME", + requestedType->toString(), + schemaElement.name); + return TIME(); + + case thrift::ConvertedType::TIME_MICROS: { + VELOX_CHECK_EQ( + schemaElement.type, + thrift::Type::INT64, + "TIME_MICROS converted type can only be set for value of thrift::Type::INT64"); + const bool isCompatibleRequestedType = !requestedType || + isCompatible(requestedType, isRepeated, [](const TypePtr& type) { + return type->equivalent(*TIME_MICRO_UTC()); + }); + VELOX_CHECK( + isCompatibleRequestedType, + kTypeMappingErrorFmtStr, + "TIME MICRO UTC", + requestedType->toString()); + return TIME_MICRO_UTC(); + } + case thrift::ConvertedType::MAP: case thrift::ConvertedType::MAP_KEY_VALUE: case thrift::ConvertedType::LIST: - case thrift::ConvertedType::TIME_MILLIS: - case thrift::ConvertedType::TIME_MICROS: case thrift::ConvertedType::JSON: case thrift::ConvertedType::BSON: case thrift::ConvertedType::INTERVAL: @@ -897,68 +1090,120 @@ TypePtr ReaderBase::convertType( switch (schemaElement.type) { case thrift::Type::type::BOOLEAN: VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::BOOLEAN, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::BOOLEAN; + }), kTypeMappingErrorFmtStr, "BOOLEAN", - requestedType->toString()); + requestedType->toString(), + schemaElement.name); return BOOLEAN(); case thrift::Type::type::INT32: VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::INTEGER || - requestedType->kind() == TypeKind::BIGINT, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [&](const TypePtr& type) { + return isInt32Compatible( + type, TypeKind::INTEGER, allowNarrowing); + }), kTypeMappingErrorFmtStr, "INTEGER", - requestedType->toString()); + requestedType->toString(), + schemaElement.name); return INTEGER(); case thrift::Type::type::INT64: // For Int64 Timestamp in nano precision if (schemaElement.__isset.logicalType && schemaElement.logicalType.__isset.TIMESTAMP) { VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::TIMESTAMP, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::TIMESTAMP; + }), kTypeMappingErrorFmtStr, "TIMESTAMP", - requestedType->toString()); + requestedType->toString(), + schemaElement.name); return TIMESTAMP(); } VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::BIGINT, + !requestedType || + isCompatible(requestedType, isRepeated, isInt64Compatible), kTypeMappingErrorFmtStr, "BIGINT", - requestedType->toString()); + requestedType->toString(), + schemaElement.name); return BIGINT(); case thrift::Type::type::INT96: VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::TIMESTAMP, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::TIMESTAMP; + }), kTypeMappingErrorFmtStr, "TIMESTAMP", - requestedType->toString()); + requestedType->toString(), + schemaElement.name); return TIMESTAMP(); // INT96 only maps to a timestamp case thrift::Type::type::FLOAT: VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::REAL || - requestedType->kind() == TypeKind::DOUBLE, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::REAL || + type->kind() == TypeKind::DOUBLE; + }), kTypeMappingErrorFmtStr, "REAL", - requestedType->toString()); + requestedType->toString(), + schemaElement.name); return REAL(); case thrift::Type::type::DOUBLE: VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::DOUBLE, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::DOUBLE; + }), kTypeMappingErrorFmtStr, "DOUBLE", - requestedType->toString()); + requestedType->toString(), + schemaElement.name); return DOUBLE(); case thrift::Type::type::BYTE_ARRAY: case thrift::Type::type::FIXED_LEN_BYTE_ARRAY: - if (requestedType && requestedType->isVarchar()) { + if (requestedType && + isCompatible(requestedType, isRepeated, [](const TypePtr& type) { + return type->isVarchar(); + })) { return VARCHAR(); } else { VELOX_CHECK( - !requestedType || requestedType->isVarbinary(), + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { return type->isVarbinary(); }), kTypeMappingErrorFmtStr, "VARBINARY", - requestedType->toString()); + requestedType->toString(), + schemaElement.name); return VARBINARY(); } @@ -1099,7 +1344,8 @@ class ParquetRowReader::Impl { uint64_t rowNumber = 0; for (auto i = 0; i < rowGroups_.size(); i++) { VELOX_CHECK_GT(rowGroups_[i].columns.size(), 0); - auto fileOffset = rowGroups_[i].__isset.file_offset + auto fileOffset = + (rowGroups_[i].__isset.file_offset && rowGroups_[i].file_offset != 0) ? rowGroups_[i].file_offset : rowGroups_[i].columns[0].meta_data.__isset.dictionary_page_offset ? rowGroups_[i].columns[0].meta_data.dictionary_page_offset @@ -1178,14 +1424,21 @@ class ParquetRowReader::Impl { std::optional estimatedRowSize() const { auto index = nextRowGroupIdsIdx_ < 1 ? 0 : rowGroupIds_[nextRowGroupIdsIdx_ - 1]; - return readerBase_->rowGroupUncompressedSize( - index, *readerBase_->schemaWithId()) / + if (index == lastRowGroupWithRowEstimate_) { + return estimatedRowSize_; + } + estimatedRowSize_ = readerBase_->rowGroupUncompressedSize( + index, *readerBase_->schemaWithId()) / rowGroups_[index].num_rows; + lastRowGroupWithRowEstimate_ = index; + return estimatedRowSize_; } void updateRuntimeStats(dwio::common::RuntimeStatistics& stats) const { stats.skippedStrides += skippedStrides_; stats.processedStrides += rowGroupIds_.size(); + stats.columnReaderStats.pageLoadTimeNs.merge( + columnReaderStats_.pageLoadTimeNs); } void resetFilterCaches() { @@ -1237,6 +1490,9 @@ class ParquetRowReader::Impl { ParquetStatsContext parquetStatsContext_; dwio::common::ColumnReaderStatistics columnReaderStats_; + + mutable std::optional estimatedRowSize_; + mutable int32_t lastRowGroupWithRowEstimate_{-1}; }; ParquetRowReader::ParquetRowReader( @@ -1290,6 +1546,43 @@ std::optional ParquetReader::numberOfRows() const { return readerBase_->thriftFileMetaData().num_rows; } +std::unique_ptr ParquetReader::columnStatistics( + uint32_t index) const { + auto node = findNode(*readerBase_->schemaWithId(), index); + if (!node) { + return nullptr; + } + auto& parquetNode = static_cast(*node); + if (!parquetNode.isLeaf()) { + return nullptr; + } + + auto fileMetaData = readerBase_->fileMetaData(); + const auto numRowGroups = fileMetaData.numRowGroups(); + if (numRowGroups == 0) { + return nullptr; + } + + // Merge per-row-group statistics into file-level statistics. + dwio::stats::StatisticsBuilderOptions options{ + /*stringLengthLimit=*/std::numeric_limits::max()}; + auto builder = + dwio::stats::StatisticsBuilder::create(*parquetNode.type(), options); + + for (int i = 0; i < numRowGroups; ++i) { + auto rowGroup = fileMetaData.rowGroup(i); + auto columnChunk = rowGroup.columnChunk(parquetNode.column()); + if (!columnChunk.hasStatistics()) { + return nullptr; + } + auto rowGroupStats = + columnChunk.getColumnStatistics(parquetNode.type(), rowGroup.numRows()); + builder->merge(*rowGroupStats); + } + + return builder->build(); +} + const velox::RowTypePtr& ParquetReader::rowType() const { return readerBase_->schema(); } diff --git a/velox/dwio/parquet/reader/ParquetReader.h b/velox/dwio/parquet/reader/ParquetReader.h index de6d7a9966d..dbaa8414d92 100644 --- a/velox/dwio/parquet/reader/ParquetReader.h +++ b/velox/dwio/parquet/reader/ParquetReader.h @@ -19,7 +19,6 @@ #include "velox/dwio/common/Reader.h" #include "velox/dwio/common/ReaderFactory.h" #include "velox/dwio/parquet/reader/Metadata.h" -#include "velox/dwio/parquet/reader/ParquetStatsContext.h" namespace facebook::velox::dwio::common { @@ -92,9 +91,7 @@ class ParquetReader : public dwio::common::Reader { std::optional numberOfRows() const override; std::unique_ptr columnStatistics( - uint32_t index) const override { - return nullptr; - } + uint32_t index) const override; const velox::RowTypePtr& rowType() const override; diff --git a/velox/dwio/parquet/reader/ParquetTypeWithId.cpp b/velox/dwio/parquet/reader/ParquetTypeWithId.cpp index 1581fa63934..aab2e71fba7 100644 --- a/velox/dwio/parquet/reader/ParquetTypeWithId.cpp +++ b/velox/dwio/parquet/reader/ParquetTypeWithId.cpp @@ -53,23 +53,24 @@ ParquetTypeWithId::moveChildren() const&& { auto precision = parquetChild->precision_; auto scale = parquetChild->scale_; auto typeLength = parquetChild->typeLength_; - children.push_back(std::make_unique( - std::move(type), - std::move(*parquetChild).moveChildren(), - id, - maxId, - column, - std::move(name), - parquetType, - std::move(logicalType), - std::move(convertedType), - maxRepeat, - maxDefine, - isOptional, - isRepeated, - precision, - scale, - typeLength)); + children.push_back( + std::make_unique( + std::move(type), + std::move(*parquetChild).moveChildren(), + id, + maxId, + column, + std::move(name), + parquetType, + std::move(logicalType), + std::move(convertedType), + maxRepeat, + maxDefine, + isOptional, + isRepeated, + precision, + scale, + typeLength)); } return children; } diff --git a/velox/dwio/parquet/reader/RepeatedColumnReader.cpp b/velox/dwio/parquet/reader/RepeatedColumnReader.cpp index 8cd75156747..0d71733bf32 100644 --- a/velox/dwio/parquet/reader/RepeatedColumnReader.cpp +++ b/velox/dwio/parquet/reader/RepeatedColumnReader.cpp @@ -153,8 +153,6 @@ void MapColumnReader::seekToRowGroup(int64_t index) { BufferPtr noBuffer; formatData_->as().setNulls(noBuffer, 0); lengths_.setLengths(nullptr); - keyReader_->seekToRowGroup(index); - elementReader_->seekToRowGroup(index); } void MapColumnReader::skipUnreadLengths() { @@ -173,10 +171,10 @@ void MapColumnReader::setLengthsFromRepDefs(PageReader& pageReader) { auto repDefRange = pageReader.repDefRange(); int32_t numRepDefs = repDefRange.second - repDefRange.first; BufferPtr lengths = std::move(lengths_.lengths()); - dwio::common::ensureCapacity(lengths, numRepDefs, memoryPool_); + dwio::common::ensureCapacity(lengths, numRepDefs, pool_); memset(lengths->asMutable(), 0, lengths->size()); dwio::common::ensureCapacity( - nullsInReadRange_, bits::nwords(numRepDefs), memoryPool_); + nullsInReadRange_, bits::nwords(numRepDefs), pool_); auto numLists = pageReader.getLengthsAndNulls( LevelMode::kList, levelInfo_, @@ -282,10 +280,10 @@ void ListColumnReader::setLengthsFromRepDefs(PageReader& pageReader) { auto repDefRange = pageReader.repDefRange(); int32_t numRepDefs = repDefRange.second - repDefRange.first; BufferPtr lengths = std::move(lengths_.lengths()); - dwio::common::ensureCapacity(lengths, numRepDefs + 1, memoryPool_); + dwio::common::ensureCapacity(lengths, numRepDefs + 1, pool_); memset(lengths->asMutable(), 0, lengths->size()); dwio::common::ensureCapacity( - nullsInReadRange_, bits::nwords(numRepDefs + 1), memoryPool_); + nullsInReadRange_, bits::nwords(numRepDefs + 1), pool_); auto numLists = pageReader.getLengthsAndNulls( LevelMode::kList, levelInfo_, diff --git a/velox/dwio/parquet/reader/RleBpDecoder.h b/velox/dwio/parquet/reader/RleBpDecoder.h index ac07ed76ad0..5856de775e2 100644 --- a/velox/dwio/parquet/reader/RleBpDecoder.h +++ b/velox/dwio/parquet/reader/RleBpDecoder.h @@ -37,7 +37,7 @@ class RleBpDecoder { /// Decode @param numValues number of values and copy the decoded values into /// @param outputBuffer template - void next(T* FOLLY_NONNULL& outputBuffer, uint64_t numValues) { + void next(T * FOLLY_NONNULL & outputBuffer, uint64_t numValues) { while (numValues > 0) { if (numRemainingUnpackedValues_ > 0) { auto numValuesToRead = @@ -103,7 +103,7 @@ class RleBpDecoder { template inline void copyRemainingUnpackedValues( - T* FOLLY_NONNULL& outputBuffer, + T * FOLLY_NONNULL & outputBuffer, int8_t numValues) { VELOX_CHECK_LE(numValues, numRemainingUnpackedValues_); diff --git a/velox/dwio/parquet/reader/SemanticVersion.cpp b/velox/dwio/parquet/reader/SemanticVersion.cpp index fa7851cb7db..a7c1892f7cf 100644 --- a/velox/dwio/parquet/reader/SemanticVersion.cpp +++ b/velox/dwio/parquet/reader/SemanticVersion.cpp @@ -67,7 +67,7 @@ bool SemanticVersion::shouldIgnoreStatistics(thrift::Type::type type) const { if (this->application_ != "parquet-mr") { return false; } - static SemanticVersion threshold(1, 8, 1); + static SemanticVersion threshold(1, 8, 2); return *this < threshold; } diff --git a/velox/dwio/parquet/reader/StringColumnReader.cpp b/velox/dwio/parquet/reader/StringColumnReader.cpp index ac678b7f0a3..75a62914fd4 100644 --- a/velox/dwio/parquet/reader/StringColumnReader.cpp +++ b/velox/dwio/parquet/reader/StringColumnReader.cpp @@ -35,7 +35,7 @@ void StringColumnReader::read( int64_t offset, const RowSet& rows, const uint64_t* incomingNulls) { - prepareRead(offset, rows, incomingNulls); + prepareRead(offset, rows, incomingNulls); dwio::common::StringColumnReadWithVisitorHelper( *this, rows)([&](auto visitor) { formatData_->as().readWithVisitor(visitor); @@ -50,7 +50,7 @@ void StringColumnReader::getValues(const RowSet& rows, VectorPtr* result) { compactScalarValues(rows, false); *result = std::make_shared>( - memoryPool_, resultNulls(), numValues_, dictionaryValues, values_); + pool_, resultNulls(), numValues_, dictionaryValues, values_); return; } rawStringBuffer_ = nullptr; @@ -83,7 +83,7 @@ void StringColumnReader::dedictionarize() { } auto& view = dict->valueAt(indices[i]); numValues_ = i; - addStringValue(folly::StringPiece(view.data(), view.size())); + addStringValue(std::string_view(view.data(), view.size())); } numValues_ = numValues; } diff --git a/velox/dwio/parquet/reader/StringDecoder.h b/velox/dwio/parquet/reader/StringDecoder.h index 2bff5285e3c..03805b5659f 100644 --- a/velox/dwio/parquet/reader/StringDecoder.h +++ b/velox/dwio/parquet/reader/StringDecoder.h @@ -90,15 +90,15 @@ class StringDecoder { return *reinterpret_cast(buffer); } - folly::StringPiece readString() { + std::string_view readString() { auto length = lengthAt(bufferStart_); bufferStart_ += length + sizeof(int32_t); - return folly::StringPiece(bufferStart_ - length, length); + return std::string_view(bufferStart_ - length, length); } - folly::StringPiece readFixedString() { + std::string_view readFixedString() { bufferStart_ += fixedLength_; - return folly::StringPiece(bufferStart_ - fixedLength_, fixedLength_); + return std::string_view(bufferStart_ - fixedLength_, fixedLength_); } const char* bufferStart_; diff --git a/velox/dwio/parquet/reader/StructColumnReader.cpp b/velox/dwio/parquet/reader/StructColumnReader.cpp index 694f334c51a..fbfffbe585f 100644 --- a/velox/dwio/parquet/reader/StructColumnReader.cpp +++ b/velox/dwio/parquet/reader/StructColumnReader.cpp @@ -46,12 +46,13 @@ StructColumnReader::StructColumnReader( auto childFileType = fileType_->childByName(childSpec->fieldName()); auto childRequestedType = requestedType_->asRow().findChild(childSpec->fieldName()); - addChild(ParquetColumnReader::build( - columnReaderOptions, - childRequestedType, - childFileType, - params, - *childSpec)); + addChild( + ParquetColumnReader::build( + columnReaderOptions, + childRequestedType, + childFileType, + params, + *childSpec)); childSpecs[i]->setSubscript(children_.size() - 1); } @@ -185,7 +186,7 @@ void StructColumnReader::setNullsFromRepDefs(PageReader& pageReader) { auto repDefRange = pageReader.repDefRange(); int32_t numRepDefs = repDefRange.second - repDefRange.first; dwio::common::ensureCapacity( - nullsInReadRange_, bits::nwords(numRepDefs), memoryPool_); + nullsInReadRange_, bits::nwords(numRepDefs), pool_); auto numStructs = pageReader.getLengthsAndNulls( levelMode_, levelInfo_, diff --git a/velox/dwio/parquet/reader/TimeColumnReader.h b/velox/dwio/parquet/reader/TimeColumnReader.h new file mode 100644 index 00000000000..111f6eaa213 --- /dev/null +++ b/velox/dwio/parquet/reader/TimeColumnReader.h @@ -0,0 +1,77 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/dwio/parquet/reader/IntegerColumnReader.h" +#include "velox/dwio/parquet/reader/ParquetColumnReader.h" +#include "velox/dwio/parquet/thrift/ParquetThriftTypes.h" + +namespace facebook::velox::parquet { + +/// Column reader for Parquet TIME type. +/// Handles conversion from: +/// - Parquet TIME_MILLIS (INT32, milliseconds) to Velox TIME (BIGINT, +/// milliseconds). +/// - Parquet TIME_MICROS (INT64, microseconds) to Velox TIME MICRO UTC +/// (BIGINT, microseconds). +class TimeColumnReader : public IntegerColumnReader { + public: + TimeColumnReader( + const TypePtr& requestedType, + const std::shared_ptr& fileType, + ParquetParams& params, + common::ScanSpec& scanSpec) + : IntegerColumnReader(requestedType, fileType, params, scanSpec) { + const auto typeWithId = + std::static_pointer_cast(fileType_); + if (auto logicalType = typeWithId->logicalType_) { + VELOX_CHECK(logicalType->__isset.TIME); + const auto unit = logicalType->TIME.unit; + VELOX_CHECK( + unit.__isset.MILLIS || unit.__isset.MICROS, + "TIME precision other than milliseconds or microseconds is not supported"); + isMicros_ = unit.__isset.MICROS; + } else if (auto convertedType = typeWithId->convertedType_) { + VELOX_CHECK( + convertedType == thrift::ConvertedType::type::TIME_MILLIS || + convertedType == thrift::ConvertedType::type::TIME_MICROS, + "TIME converted type other than TIME_MILLIS or TIME_MICROS is not supported"); + isMicros_ = convertedType == thrift::ConvertedType::type::TIME_MICROS; + } else { + VELOX_NYI("Logical type and converted type are not provided for TIME."); + } + } + + void read( + int64_t offset, + const RowSet& rows, + const uint64_t* /*incomingNulls*/) override { + // Velox represents TIME as BIGINT (8 bytes) for both precisions. + // Parquet stores TIME_MILLIS as INT32 (4 bytes) and TIME_MICROS as + // INT64 (8 bytes). Dispatch on the physical width. + const int32_t physicalWidth = isMicros_ ? 8 : 4; + VELOX_WIDTH_DISPATCH(physicalWidth, prepareRead, offset, rows, nullptr); + readCommon(rows); + readOffset_ += rows.back() + 1; + } + + private: + // True for TIME_MICROS (INT64), false for TIME_MILLIS (INT32). + bool isMicros_{false}; +}; + +} // namespace facebook::velox::parquet diff --git a/velox/dwio/parquet/reader/TimestampColumnReader.h b/velox/dwio/parquet/reader/TimestampColumnReader.h index 909577ddd89..c32de51e106 100644 --- a/velox/dwio/parquet/reader/TimestampColumnReader.h +++ b/velox/dwio/parquet/reader/TimestampColumnReader.h @@ -182,13 +182,14 @@ class TimestampColumnReader : public IntegerColumnReader { filters.reserve(multiRange->filters().size()); for (const auto& filter : multiRange->filters()) { if (auto* range = dynamic_cast(filter.get())) { - filters.emplace_back(std::make_unique>( - range->lower(), - range->upper(), - range->nullAllowed(), - filePrecision_)); + filters.emplace_back( + std::make_unique>( + range->lower(), + range->upper(), + range->nullAllowed(), + filePrecision_)); } else { - filters.emplace_back(filter->clone(range->nullAllowed())); + filters.emplace_back(filter->clone(filter->nullAllowed())); } } auto newMultiRange = diff --git a/velox/dwio/parquet/tests/CMakeLists.txt b/velox/dwio/parquet/tests/CMakeLists.txt index dc09a85c4fd..de1431e8866 100644 --- a/velox/dwio/parquet/tests/CMakeLists.txt +++ b/velox/dwio/parquet/tests/CMakeLists.txt @@ -19,7 +19,7 @@ set( velox_exec_test_lib velox_dwio_parquet_reader velox_dwio_parquet_writer - velox_temp_path + velox_test_util GTest::gtest GTest::gtest_main GTest::gmock @@ -47,3 +47,5 @@ target_link_libraries( ) file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/examples DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) + +velox_add_library(velox_dwio_parquet_test_base INTERFACE HEADERS ParquetTestBase.h) diff --git a/velox/dwio/parquet/tests/ParquetTestBase.h b/velox/dwio/parquet/tests/ParquetTestBase.h index c867978a0a5..fd2cef91546 100644 --- a/velox/dwio/parquet/tests/ParquetTestBase.h +++ b/velox/dwio/parquet/tests/ParquetTestBase.h @@ -19,17 +19,20 @@ #include #include #include "velox/common/base/Fs.h" +#include "velox/common/file/File.h" +#include "velox/common/io/IoStatistics.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/dwio/common/FileSink.h" #include "velox/dwio/common/Reader.h" #include "velox/dwio/common/tests/utils/DataFiles.h" #include "velox/dwio/parquet/reader/PageReader.h" #include "velox/dwio/parquet/reader/ParquetReader.h" #include "velox/dwio/parquet/writer/Writer.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/vector/fuzzer/VectorFuzzer.h" #include "velox/vector/tests/utils/VectorTestBase.h" namespace facebook::velox::parquet { +using TempDirectoryPath = common::testutil::TempDirectoryPath; class ParquetTestBase : public testing::Test, public velox::test::VectorTestBase { @@ -42,7 +45,7 @@ class ParquetTestBase : public testing::Test, dwio::common::LocalFileSink::registerFactory(); rootPool_ = memory::memoryManager()->addRootPool("ParquetTests"); leafPool_ = rootPool_->addLeafChild("ParquetTests"); - tempPath_ = exec::test::TempDirectoryPath::create(); + tempPath_ = TempDirectoryPath::create(); } static RowTypePtr sampleSchema() { @@ -191,10 +194,55 @@ class ParquetTestBase : public testing::Test, "velox/dwio/parquet/tests/reader", "../examples/" + fileName); } + dwio::common::MemorySink* write( + const RowVectorPtr& data, + const WriterOptions& writerOptions) { + auto sink = std::make_unique( + 200 * 1024 * 1024, + dwio::common::FileSink::Options{.pool = leafPool_.get()}); + auto* sinkPtr = sink.get(); + auto writer = std::make_unique( + std::move(sink), writerOptions, data->rowType()); + writer->write(data); + writer->close(); + writers_.push_back(std::move(writer)); + return sinkPtr; + } + + dwio::common::MemorySink* write( + const RowVectorPtr& data, + std::unordered_map configFromFile = {}, + std::unordered_map sessionProperties = {}) { + parquet::WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); + auto connectorConfig = config::ConfigBase(std::move(configFromFile)); + auto connectorSessionProperties = + config::ConfigBase(std::move(sessionProperties)); + writerOptions.processConfigs(connectorConfig, connectorSessionProperties); + return write(data, writerOptions); + } + + std::unique_ptr createReaderInMemory( + const dwio::common::MemorySink& sink, + const dwio::common::ReaderOptions& opts) { + std::string data(sink.data(), sink.size()); + return std::make_unique( + std::make_unique( + std::make_shared(std::move(data)), + opts.memoryPool()), + opts); + } + static constexpr uint64_t kRowsInRowGroup = 10'000; static constexpr uint64_t kBytesInRowGroup = 128 * 1'024 * 1'024; std::shared_ptr rootPool_; std::shared_ptr leafPool_; - std::shared_ptr tempPath_; + std::shared_ptr dataIoStats_ = + std::make_shared(); + std::shared_ptr metadataIoStats_ = + std::make_shared(); + std::shared_ptr tempPath_; + // Stores writers created by write() helper to keep sinks alive for reading. + std::vector> writers_; }; } // namespace facebook::velox::parquet diff --git a/velox/dwio/parquet/tests/ParquetTpchTest.cpp b/velox/dwio/parquet/tests/ParquetTpchTest.cpp index 615476c54b6..1dd5b4ef80d 100644 --- a/velox/dwio/parquet/tests/ParquetTpchTest.cpp +++ b/velox/dwio/parquet/tests/ParquetTpchTest.cpp @@ -18,6 +18,8 @@ #include #include "velox/common/file/FileSystems.h" +#include "velox/common/testutil/TempDirectoryPath.h" +#include "velox/connectors/ConnectorRegistry.h" #include "velox/connectors/hive/HiveConnector.h" #include "velox/connectors/tpch/TpchConnector.h" #include "velox/dwio/parquet/RegisterParquetReader.h" @@ -25,7 +27,6 @@ #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/exec/tests/utils/TpchQueryBuilder.h" #include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" @@ -34,6 +35,7 @@ using namespace facebook::velox; using namespace facebook::velox::exec; using namespace facebook::velox::exec::test; +using namespace facebook::velox::common::testutil; class ParquetTpchTest : public testing::Test { protected: @@ -60,22 +62,24 @@ class ParquetTpchTest : public testing::Test { kHiveConnectorId, std::make_shared( std::unordered_map())); - connector::registerConnector(hiveConnector); + connector::ConnectorRegistry::global().insert( + hiveConnector->connectorId(), hiveConnector); connector::tpch::TpchConnectorFactory tpchFactory; auto tpchConnector = tpchFactory.newConnector( kTpchConnectorId, std::make_shared( std::unordered_map())); - connector::registerConnector(tpchConnector); + connector::ConnectorRegistry::global().insert( + tpchConnector->connectorId(), tpchConnector); saveTpchTablesAsParquet(); tpchBuilder_->initialize(tempDirectory_->getPath()); } static void TearDownTestSuite() { - connector::unregisterConnector(kHiveConnectorId); - connector::unregisterConnector(kTpchConnectorId); + connector::ConnectorRegistry::global().erase(kHiveConnectorId); + connector::ConnectorRegistry::global().erase(kTpchConnectorId); parquet::unregisterParquetReaderFactory(); parquet::unregisterParquetWriterFactory(); } @@ -94,8 +98,8 @@ class ParquetTpchTest : public testing::Test { auto plan = PlanBuilder() .tpchTableScan(table, std::move(columnNames), 0.01) .planNode(); - auto split = - exec::Split(std::make_shared( + auto split = exec::Split( + std::make_shared( kTpchConnectorId, /*cacheable=*/true, 1, 0)); auto rows = diff --git a/velox/dwio/parquet/tests/common/LevelConversionTest.cpp b/velox/dwio/parquet/tests/common/LevelConversionTest.cpp index e3ff3a7e6e9..f7274b7c974 100644 --- a/velox/dwio/parquet/tests/common/LevelConversionTest.cpp +++ b/velox/dwio/parquet/tests/common/LevelConversionTest.cpp @@ -160,38 +160,37 @@ MultiLevelTestData TriplyNestedList() { // [[[]], [[], [1, 2]], null, [[3]]], // null, // [] - return MultiLevelTestData{ - /*defLevels=*/std::vector{ - 2, - 7, - 6, - 7, - 5, - 3, // first row - 5, - 5, - 7, - 7, - 2, - 7, // second row - 0, // third row - 1}, - /*repLevels=*/ - std::vector{ - 0, - 1, - 3, - 3, - 2, - 1, // first row - 0, - 1, - 2, - 3, - 1, - 1, // second row - 0, - 0}}; + return MultiLevelTestData{/*defLevels=*/std::vector{ + 2, + 7, + 6, + 7, + 5, + 3, // first row + 5, + 5, + 7, + 7, + 2, + 7, // second row + 0, // third row + 1}, + /*repLevels=*/ + std::vector{ + 0, + 1, + 3, + 3, + 2, + 1, // first row + 0, + 1, + 2, + 3, + 1, + 1, // second row + 0, + 0}}; } template diff --git a/velox/dwio/parquet/tests/examples/nested_array_struct.parquet b/velox/dwio/parquet/tests/examples/nested_array_struct.parquet new file mode 100644 index 00000000000..41a43fa35d3 Binary files /dev/null and b/velox/dwio/parquet/tests/examples/nested_array_struct.parquet differ diff --git a/velox/dwio/parquet/tests/examples/proto_repeated_string.parquet b/velox/dwio/parquet/tests/examples/proto_repeated_string.parquet new file mode 100644 index 00000000000..8a7eea601d0 Binary files /dev/null and b/velox/dwio/parquet/tests/examples/proto_repeated_string.parquet differ diff --git a/velox/dwio/parquet/tests/examples/zero_offset_row_group.parquet b/velox/dwio/parquet/tests/examples/zero_offset_row_group.parquet new file mode 100644 index 00000000000..3b41425516e Binary files /dev/null and b/velox/dwio/parquet/tests/examples/zero_offset_row_group.parquet differ diff --git a/velox/dwio/parquet/tests/reader/CMakeLists.txt b/velox/dwio/parquet/tests/reader/CMakeLists.txt index 8fe4b353e87..55b4d51506f 100644 --- a/velox/dwio/parquet/tests/reader/CMakeLists.txt +++ b/velox/dwio/parquet/tests/reader/CMakeLists.txt @@ -39,6 +39,7 @@ target_link_libraries( ) add_library(velox_dwio_parquet_reader_benchmark_lib ParquetReaderBenchmark.cpp) +velox_add_test_headers(velox_dwio_parquet_reader_benchmark_lib ParquetReaderBenchmark.h) target_link_libraries( velox_dwio_parquet_reader_benchmark_lib velox_dwio_parquet_reader @@ -59,6 +60,7 @@ endif() add_executable( velox_dwio_parquet_reader_test ParquetReaderTest.cpp + ParquetReaderWideningTest.cpp ParquetReaderBenchmarkTest.cpp BloomFilterTest.cpp ) diff --git a/velox/dwio/parquet/tests/reader/E2EFilterTest.cpp b/velox/dwio/parquet/tests/reader/E2EFilterTest.cpp index d718c29b0c1..d323989918b 100644 --- a/velox/dwio/parquet/tests/reader/E2EFilterTest.cpp +++ b/velox/dwio/parquet/tests/reader/E2EFilterTest.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include "velox/common/io/IoStatistics.h" #include "velox/dwio/common/tests/utils/E2EFilterTestBase.h" #include "velox/dwio/parquet/reader/ParquetReader.h" #include "velox/dwio/parquet/reader/ParquetTypeWithId.h" @@ -59,7 +60,8 @@ class E2EFilterTest : public E2EFilterTestBase, void writeToMemory( const TypePtr& type, const std::vector& batches, - bool forRowGroupSkip = false) override { + bool forRowGroupSkip = false, + const std::vector& /*indexColumns*/ = {}) override { auto sink = std::make_unique( 200 * 1024 * 1024, FileSink::Options{.pool = leafPool_.get()}); auto* sinkPtr = sink.get(); @@ -90,6 +92,10 @@ class E2EFilterTest : public E2EFilterTestBase, return std::make_unique(std::move(input), opts); } + std::shared_ptr dataIoStats_ = + std::make_shared(); + std::shared_ptr metadataIoStats_ = + std::make_shared(); std::unique_ptr writer_; facebook::velox::parquet::WriterOptions options_; uint64_t rowsInRowGroup_ = 10'000; @@ -99,8 +105,9 @@ class E2EFilterTest : public E2EFilterTestBase, TEST_F(E2EFilterTest, writerMagic) { rowType_ = ROW({"c0"}, {INTEGER()}); std::vector batches; - batches.push_back(std::static_pointer_cast( - test::BatchMaker::createBatch(rowType_, 20000, *leafPool_, nullptr, 0))); + batches.push_back( + std::static_pointer_cast(test::BatchMaker::createBatch( + rowType_, 20000, *leafPool_, nullptr, 0))); writeToMemory(rowType_, batches, false); auto data = sinkData_.data(); auto size = sinkData_.size(); @@ -136,7 +143,7 @@ TEST_F(E2EFilterTest, integerDirect) { TEST_F(E2EFilterTest, integerDeltaBinaryPack) { options_.enableDictionary = false; options_.encoding = - facebook::velox::parquet::arrow::Encoding::DELTA_BINARY_PACKED; + facebook::velox::parquet::arrow::Encoding::kDeltaBinaryPacked; testWithTypes( "short_val:smallint," @@ -525,7 +532,24 @@ TEST_F(E2EFilterTest, stringDictionary) { TEST_F(E2EFilterTest, stringDeltaByteArray) { options_.enableDictionary = false; options_.encoding = - facebook::velox::parquet::arrow::Encoding::DELTA_BYTE_ARRAY; + facebook::velox::parquet::arrow::Encoding::kDeltaByteArray; + + testWithTypes( + "string_val:string," + "string_val_2:string", + [&]() { + makeStringUnique("string_val"); + makeStringUnique("string_val_2"); + }, + true, + {"string_val", "string_val_2"}, + 20); +} + +TEST_F(E2EFilterTest, stringDeltaLengthByteArray) { + options_.enableDictionary = false; + options_.encoding = + facebook::velox::parquet::arrow::Encoding::kDeltaLengthByteArray; testWithTypes( "string_val:string," @@ -657,11 +681,14 @@ TEST_F(E2EFilterTest, largeMetadata) { rowType_ = ROW({"c0"}, {INTEGER()}); std::vector batches; - batches.push_back(std::static_pointer_cast( - test::BatchMaker::createBatch(rowType_, 1000, *leafPool_, nullptr, 0))); + batches.push_back( + std::static_pointer_cast(test::BatchMaker::createBatch( + rowType_, 1000, *leafPool_, nullptr, 0))); writeToMemory(rowType_, batches, false); - dwio::common::ReaderOptions readerOpts{leafPool_.get()}; - readerOpts.setFooterEstimatedSize(1024); + dwio::common::ReaderOptions readerOpts(leafPool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); + readerOpts.setFooterSpeculativeIoSize(1024); readerOpts.setFilePreloadThreshold(1024 * 8); dwio::common::RowReaderOptions rowReaderOpts; auto input = std::make_unique( @@ -689,16 +716,113 @@ TEST_F(E2EFilterTest, date) { 20); } +TEST_F(E2EFilterTest, time) { + struct { + parquet::arrow::Encoding::type encoding; + bool enableDictionary; + bool keepNulls; + } testCases[] = { + {parquet::arrow::Encoding::kPlain, false, true}, + {parquet::arrow::Encoding::kPlain, true, true}, + {parquet::arrow::Encoding::kDeltaBinaryPacked, false, false}, + {parquet::arrow::Encoding::kDeltaBinaryPacked, false, true}, + }; + + for (const auto& testCase : testCases) { + options_.encoding = testCase.encoding; + bool enableDictionary = testCase.enableDictionary; + bool keepNulls = testCase.keepNulls; + SCOPED_TRACE( + fmt::format( + "Encoding: {}, Dictionary: {}, KeepNulls: {}", + static_cast(options_.encoding), + enableDictionary, + keepNulls)); + + options_.enableDictionary = enableDictionary; + options_.dataPageSize = 4 * 1024; + const int valMax = enableDictionary ? 1000 : 86399999; + + testWithTypes( + "time_val:time", + [&]() { + makeIntDistribution( + "time_val", + 0, // min + valMax, // max + 22, // repeats + 19, // rareFrequency + 0, // rareMin + valMax, // rareMax + keepNulls); // keepNulls + }, + false, + {"time_val"}, + 20); + } +} + +TEST_F(E2EFilterTest, timeMicros) { + struct { + parquet::arrow::Encoding::type encoding; + bool enableDictionary; + bool keepNulls; + } testCases[] = { + {parquet::arrow::Encoding::kPlain, false, true}, + {parquet::arrow::Encoding::kPlain, true, true}, + {parquet::arrow::Encoding::kDeltaBinaryPacked, false, false}, + {parquet::arrow::Encoding::kDeltaBinaryPacked, false, true}, + }; + + for (const auto& testCase : testCases) { + options_.encoding = testCase.encoding; + bool enableDictionary = testCase.enableDictionary; + bool keepNulls = testCase.keepNulls; + SCOPED_TRACE( + fmt::format( + "Encoding: {}, Dictionary: {}, KeepNulls: {}", + static_cast(options_.encoding), + enableDictionary, + keepNulls)); + + options_.enableDictionary = enableDictionary; + options_.dataPageSize = 4 * 1024; + // Microseconds since midnight up to 86,399,999,999 (one second short of + // 24 h). Use a smaller cap when forcing a dictionary so values are dense. + const int64_t valMax = enableDictionary ? 1'000 : 86'399'999'999LL; + + testWithTypes( + "time_val:time_micro_utc", + [&]() { + makeIntDistribution( + "time_val", + 0, // min + valMax, // max + 22, // repeats + 19, // rareFrequency + 0, // rareMin + valMax, // rareMax + keepNulls); // keepNulls + }, + false, + {"time_val"}, + 20); + } +} + TEST_F(E2EFilterTest, combineRowGroup) { rowsInRowGroup_ = 5; rowType_ = ROW({"c0"}, {INTEGER()}); std::vector batches; for (int i = 0; i < 5; i++) { - batches.push_back(std::static_pointer_cast( - test::BatchMaker::createBatch(rowType_, 1, *leafPool_, nullptr, 0))); + batches.push_back( + std::static_pointer_cast(test::BatchMaker::createBatch( + rowType_, 1, *leafPool_, nullptr, 0))); } writeToMemory(rowType_, batches, false); - dwio::common::ReaderOptions readerOpts{leafPool_.get()}; + dwio::common::ReaderOptions readerOpts(leafPool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); auto input = std::make_unique( std::make_shared(sinkData_), readerOpts.memoryPool()); auto reader = makeReader(readerOpts, std::move(input)); @@ -707,13 +831,115 @@ TEST_F(E2EFilterTest, combineRowGroup) { EXPECT_EQ(parquetReader.numberOfRows(), 5); } +// Reproduces the real-world scenario from the bug report. Parquet-mr 1.8.1 +// computed binary column min/max using signed byte ordering, which differs from +// the unsigned lexicographic (memcmp) ordering Velox uses. +// +// With signed byte ordering: 三星应用商店 < 360手机助手 < vivo预装 +// With memcmp byte ordering: 360手机助手 < vivo预装 < 三星应用商店 +// +// A row group containing {"三星应用商店", "vivo预装"} has memcmp-based stats +// min="vivo预装", max="三星应用商店". A filter for "360手机助手" falls below +// the memcmp min, so the row group would be incorrectly skipped — even though +// it should match under the signed ordering that parquet-mr 1.8.1 used to write +// the stats. +TEST_F(E2EFilterTest, parquetMRVersionStringStatsRowGroupFiltering) { + const std::string kSanXing = "三星应用商店"; + const std::string kVivo = "vivo预装"; + const std::string k360 = "360手机助手"; + + auto rowType = ROW({"s"}, {VARCHAR()}); + + auto writeAndGetStats = [&](const std::string& createdBy, + RuntimeStatistics& stats) { + options_.memoryPool = E2EFilterTestBase::rootPool_.get(); + options_.createdBy = createdBy; + // Flush after every 5 rows to create separate row groups. + options_.flushPolicyFactory = []() { + return std::make_unique( + /*rowsInRowGroup=*/5, + /*bytesInRowGroup=*/1'024 * 1'024, + []() { return false; }); + }; + + auto sink = std::make_unique( + 200 * 1024 * 1024, FileSink::Options{.pool = leafPool_.get()}); + auto* sinkPtr = sink.get(); + auto writer = + std::make_unique(std::move(sink), options_, rowType); + // Row group 1: contains the value we will filter for ("360手机助手"). + writer->write(makeRowVector( + {"s"}, + {makeFlatVector( + {k360, kSanXing, kVivo, k360, kSanXing})})); + // Row group 2: does not contain "360手机助手". + writer->write(makeRowVector( + {"s"}, + {makeFlatVector( + {kSanXing, kVivo, kSanXing, kVivo, kSanXing})})); + writer->close(); + + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto input = std::make_unique( + std::make_shared( + std::string(sinkPtr->data(), sinkPtr->size())), + readerOptions.memoryPool()); + auto reader = makeReader(readerOptions, std::move(input)); + auto& parquetReader = dynamic_cast(*reader); + EXPECT_EQ(parquetReader.fileMetaData().numRowGroups(), 2); + + auto scanSpec = std::make_shared(""); + scanSpec->addAllChildFields(*rowType); + // Equality filter: s = "360手机助手". + scanSpec->getOrCreateChild(Subfield("s")) + ->setFilter( + std::make_unique( + k360, false, false, k360, false, false, false)); + + RowReaderOptions rowReaderOpts; + rowReaderOpts.select( + std::make_shared(rowType, rowType->names())); + rowReaderOpts.setScanSpec(scanSpec); + auto rowReader = reader->createRowReader(rowReaderOpts); + + VectorPtr result = BaseVector::create(rowType, 1, leafPool_.get()); + uint64_t totalRows{0}; + while (rowReader->next(1'000, result)) { + totalRows += result->size(); + } + EXPECT_EQ(totalRows, 2); + + rowReader->updateRuntimeStats(stats); + }; + + // parquet-mr 1.8.2: stats are trusted. Under memcmp ordering, row group 1 + // has min="360手机助手" max="三星应用商店" which contains "360手机助手", so + // it is read. Row group 2 has min="vivo预装" max="三星应用商店" which does + // not contain "360手机助手" (it falls below memcmp min), so it is skipped. + RuntimeStatistics stats182; + writeAndGetStats("parquet-mr version 1.8.2", stats182); + EXPECT_EQ(stats182.skippedStrides, 1); + EXPECT_EQ(stats182.processedStrides, 1); + + // parquet-mr 1.8.1: stats are untrusted (signed byte ordering bug), so no + // row groups are skipped. Both row groups are scanned. + RuntimeStatistics stats181; + writeAndGetStats("parquet-mr version 1.8.1", stats181); + EXPECT_EQ(stats181.skippedStrides, 0); + EXPECT_EQ(stats181.processedStrides, 2); +} + TEST_F(E2EFilterTest, writeDecimalAsInteger) { auto rowVector = makeRowVector( {makeFlatVector({1, 2}, DECIMAL(8, 2)), makeFlatVector({1, 2}, DECIMAL(10, 2)), makeFlatVector({1, 2}, DECIMAL(19, 2))}); writeToMemory(rowVector->type(), {rowVector}, false); - dwio::common::ReaderOptions readerOpts{leafPool_.get()}; + dwio::common::ReaderOptions readerOpts(leafPool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); auto input = std::make_unique( std::make_shared(sinkData_), readerOpts.memoryPool()); auto reader = makeReader(readerOpts, std::move(input)); @@ -738,7 +964,9 @@ TEST_F(E2EFilterTest, configurableWriteSchema) { } writeToMemory(newType, batches, false); - dwio::common::ReaderOptions readerOpts{leafPool_.get()}; + dwio::common::ReaderOptions readerOpts(leafPool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); auto input = std::make_unique( std::make_shared(sinkData_), readerOpts.memoryPool()); auto reader = makeReader(readerOpts, std::move(input)); @@ -777,7 +1005,7 @@ TEST_F(E2EFilterTest, configurableWriteSchema) { TEST_F(E2EFilterTest, booleanRle) { options_.enableDictionary = false; - options_.encoding = facebook::velox::parquet::arrow::Encoding::RLE; + options_.encoding = facebook::velox::parquet::arrow::Encoding::kRle; options_.useParquetDataPageV2 = true; testWithTypes( diff --git a/velox/dwio/parquet/tests/reader/ParquetPageReaderTest.cpp b/velox/dwio/parquet/tests/reader/ParquetPageReaderTest.cpp index 5145dcfdc8c..bda28a60ee1 100644 --- a/velox/dwio/parquet/tests/reader/ParquetPageReaderTest.cpp +++ b/velox/dwio/parquet/tests/reader/ParquetPageReaderTest.cpp @@ -15,7 +15,12 @@ */ #include "velox/dwio/parquet/reader/PageReader.h" + +#include +#include +#include "velox/common/base/tests/GTestUtils.h" #include "velox/dwio/parquet/tests/ParquetTestBase.h" +#include "velox/dwio/parquet/thrift/ParquetThriftTypes.h" using namespace facebook::velox; using namespace facebook::velox::common; @@ -31,11 +36,13 @@ TEST_F(ParquetPageReaderTest, smallPage) { auto headerSize = file->getLength(); auto inputStream = std::make_unique( std::move(file), 0, headerSize, *leafPool_, LogType::TEST); + dwio::common::ColumnReaderStatistics stats; auto pageReader = std::make_unique( std::move(inputStream), *leafPool_, common::CompressionKind::CompressionKind_GZIP, - headerSize); + headerSize, + stats); auto header = pageReader->readPageHeader(); EXPECT_EQ(header.type, thrift::PageType::type::DATA_PAGE); EXPECT_EQ(header.uncompressed_page_size, 16950); @@ -50,6 +57,7 @@ TEST_F(ParquetPageReaderTest, smallPage) { auto maxValue = header.data_page_header.statistics.max_value; EXPECT_EQ(minValue, expectedMinValue); EXPECT_EQ(maxValue, expectedMaxValue); + EXPECT_GT(stats.pageLoadTimeNs.sum(), 0); } TEST_F(ParquetPageReaderTest, largePage) { @@ -59,11 +67,13 @@ TEST_F(ParquetPageReaderTest, largePage) { auto headerSize = file->getLength(); auto inputStream = std::make_unique( std::move(file), 0, headerSize, *leafPool_, LogType::TEST); + dwio::common::ColumnReaderStatistics stats; auto pageReader = std::make_unique( std::move(inputStream), *leafPool_, common::CompressionKind::CompressionKind_GZIP, - headerSize); + headerSize, + stats); auto header = pageReader->readPageHeader(); EXPECT_EQ(header.type, thrift::PageType::type::DATA_PAGE); @@ -79,6 +89,7 @@ TEST_F(ParquetPageReaderTest, largePage) { auto maxValue = header.data_page_header.statistics.max_value; EXPECT_EQ(minValue, expectedMinValue); EXPECT_EQ(maxValue, expectedMaxValue); + EXPECT_GT(stats.pageLoadTimeNs.sum(), 0); } TEST_F(ParquetPageReaderTest, corruptedPageHeader) { @@ -92,11 +103,13 @@ TEST_F(ParquetPageReaderTest, corruptedPageHeader) { // In the corrupted_page_header, the min_value length is set incorrectly on // purpose. This is to simulate the situation where the Parquet Page Header is // corrupted. And an error is expected to be thrown. + dwio::common::ColumnReaderStatistics stats; auto pageReader = std::make_unique( std::move(inputStream), *leafPool_, common::CompressionKind::CompressionKind_GZIP, - headerSize); + headerSize, + stats); EXPECT_THROW(pageReader->readPageHeader(), VeloxException); } @@ -108,3 +121,417 @@ TEST(CompressionOptionsTest, testCompressionOptions) { options.format.zlib.windowBits, dwio::common::compression::Compressor::PARQUET_ZLIB_WINDOW_BITS); } + +// Test that prepareDictionary rejects FIXED_LEN_BYTE_ARRAY dictionary pages +// where the Parquet type length exceeds the Velox type length. This guards +// against heap buffer overflow from malicious Parquet files that have a patched +// precision (e.g. decimal128 with typeLength=16 but precision lowered to make +// Velox choose int64_t with cppSizeInBytes=8). +TEST_F(ParquetPageReaderTest, fixedLenByteArrayDictOverflow) { + // Simulate the exploit: Parquet FIXED_LEN_BYTE_ARRAY with typeLength=16 + // but Velox sees SHORT_DECIMAL (precision=10) which is int64_t (8 bytes). + // Without the check, prepareDictionary would allocate numValues*8 bytes + // but memcpy numValues*16 bytes, causing a heap overflow. + constexpr int32_t kNumDictValues = 4; + constexpr int32_t kParquetTypeLength = 16; + constexpr int32_t kDictPageSize = kNumDictValues * kParquetTypeLength; + + // Create a DICTIONARY_PAGE header. + thrift::PageHeader dictHeader; + dictHeader.__set_type(thrift::PageType::DICTIONARY_PAGE); + dictHeader.__set_uncompressed_page_size(kDictPageSize); + dictHeader.__set_compressed_page_size(kDictPageSize); + thrift::DictionaryPageHeader dictPageHeader; + dictPageHeader.__set_num_values(kNumDictValues); + dictPageHeader.__set_encoding(thrift::Encoding::PLAIN); + dictHeader.__set_dictionary_page_header(dictPageHeader); + + auto transport = std::make_shared(); + apache::thrift::protocol::TCompactProtocolT< + apache::thrift::transport::TMemoryBuffer> + protocol(transport); + dictHeader.write(&protocol); + std::string dictHeaderBytes = transport->getBufferAsString(); + + // Dictionary page data (content doesn't matter, check fires before read). + std::string dictPageData(kDictPageSize, '\0'); + + // Create a DATA_PAGE header so seekToPage can find a data page after the + // dictionary page. + constexpr int32_t kDataPageSize = 8; + thrift::PageHeader dataHeader; + dataHeader.__set_type(thrift::PageType::DATA_PAGE); + dataHeader.__set_uncompressed_page_size(kDataPageSize); + dataHeader.__set_compressed_page_size(kDataPageSize); + thrift::DataPageHeader dataPageHeader; + dataPageHeader.__set_num_values(1); + dataPageHeader.__set_encoding(thrift::Encoding::RLE_DICTIONARY); + dataPageHeader.__set_definition_level_encoding(thrift::Encoding::RLE); + dataPageHeader.__set_repetition_level_encoding(thrift::Encoding::RLE); + dataHeader.__set_data_page_header(dataPageHeader); + + auto transport2 = + std::make_shared(); + apache::thrift::protocol::TCompactProtocolT< + apache::thrift::transport::TMemoryBuffer> + protocol2(transport2); + dataHeader.write(&protocol2); + std::string dataHeaderBytes = transport2->getBufferAsString(); + + std::string dataPageData(kDataPageSize, '\0'); + + // Combine: dict header + dict data + data header + data data. + std::string fullData = + dictHeaderBytes + dictPageData + dataHeaderBytes + dataPageData; + + auto inputStream = std::make_unique( + fullData.data(), fullData.size()); + + // Construct ParquetTypeWithId: SHORT_DECIMAL (precision=10, cppSizeInBytes=8) + // with parquetType=FIXED_LEN_BYTE_ARRAY and typeLength=16. + auto fileType = std::make_shared( + DECIMAL(10, 2), + std::vector>{}, + /*id=*/0, + /*maxId=*/0, + /*column=*/0, + "test_col", + thrift::Type::FIXED_LEN_BYTE_ARRAY, + std::nullopt, + std::nullopt, + /*maxRepeat=*/0, + /*maxDefine=*/1, + /*isOptional=*/true, + /*isRepeated=*/false, + /*precision=*/10, + /*scale=*/2, + /*typeLength=*/kParquetTypeLength); + + dwio::common::ColumnReaderStatistics stats; + auto pageReader = std::make_unique( + std::move(inputStream), + *leafPool_, + fileType, + common::CompressionKind::CompressionKind_NONE, + fullData.size(), + stats, + nullptr); + + // skip(1) triggers seekToPage() -> prepareDictionary(). + // The VELOX_CHECK_LE should fire because numParquetBytes (4*16=64) > + // numVeloxBytes (4*8=32). + VELOX_ASSERT_THROW(pageReader->skip(1), ""); +} + +namespace { + +// Helper to serialize a PageHeader using Thrift compact protocol. +std::string serializePageHeader(const thrift::PageHeader& header) { + auto transport = std::make_shared(); + apache::thrift::protocol::TCompactProtocolT< + apache::thrift::transport::TMemoryBuffer> + protocol(transport); + header.write(&protocol); + return transport->getBufferAsString(); +} + +// Helper to create a DATA_PAGE header with specified sizes. +thrift::PageHeader createDataPageV1Header( + int32_t uncompressedSize, + int32_t compressedSize, + int32_t numValues) { + thrift::PageHeader header; + header.__set_type(thrift::PageType::DATA_PAGE); + header.__set_uncompressed_page_size(uncompressedSize); + header.__set_compressed_page_size(compressedSize); + + thrift::DataPageHeader dataHeader; + dataHeader.__set_num_values(numValues); + dataHeader.__set_encoding(thrift::Encoding::PLAIN); + dataHeader.__set_definition_level_encoding(thrift::Encoding::RLE); + dataHeader.__set_repetition_level_encoding(thrift::Encoding::RLE); + header.__set_data_page_header(dataHeader); + + return header; +} + +// Helper to create a DATA_PAGE_V2 header with specified sizes. +thrift::PageHeader createDataPageV2Header( + int32_t uncompressedSize, + int32_t compressedSize, + int32_t numValues, + int32_t definitionLevelsByteLength, + int32_t repetitionLevelsByteLength) { + thrift::PageHeader header; + header.__set_type(thrift::PageType::DATA_PAGE_V2); + header.__set_uncompressed_page_size(uncompressedSize); + header.__set_compressed_page_size(compressedSize); + + thrift::DataPageHeaderV2 dataHeader; + dataHeader.__set_num_values(numValues); + dataHeader.__set_num_nulls(0); + dataHeader.__set_num_rows(numValues); + dataHeader.__set_encoding(thrift::Encoding::PLAIN); + dataHeader.__set_definition_levels_byte_length(definitionLevelsByteLength); + dataHeader.__set_repetition_levels_byte_length(repetitionLevelsByteLength); + dataHeader.__set_is_compressed(false); + header.__set_data_page_header_v2(dataHeader); + + return header; +} + +} // namespace + +// Test that prepareDataPageV1 rejects pages with defineLength exceeding page +// size. This guards against heap buffer overflow from corrupt Parquet files. +TEST_F(ParquetPageReaderTest, corruptDefineLengthV1) { + // Create a DATA_PAGE header with small page size. + constexpr int32_t kPageSize = 20; + auto pageHeader = createDataPageV1Header(kPageSize, kPageSize, 100); + std::string headerBytes = serializePageHeader(pageHeader); + + // Create corrupt page data where defineLength (first 4 bytes after + // decompression for maxDefine > 0) is huge. Since compression is NONE, the + // "decompressed" data is the raw page data. + std::string pageData(kPageSize, '\0'); + // Set defineLength to a huge value (0x7FFFFFF0) that exceeds page size. + uint32_t corruptDefineLength = 0x7FFFFFF0; + memcpy(pageData.data(), &corruptDefineLength, sizeof(uint32_t)); + + // Combine header and page data. + std::string fullData = headerBytes + pageData; + + // Create an input stream from the crafted data. + auto inputStream = std::make_unique( + fullData.data(), fullData.size()); + + dwio::common::ColumnReaderStatistics stats; + // Create PageReader with maxRepeat=0, maxDefine=1 (so defineLength is read) + // and no compression (so page data is used directly). + auto pageReader = std::make_unique( + std::move(inputStream), + *leafPool_, + common::CompressionKind::CompressionKind_NONE, + fullData.size(), + stats, + nullptr, + 0, + 1); + + // Calling skip(1) triggers seekToPage() which calls prepareDataPageV1(). + // The bounds check should throw when defineLength exceeds page size. + VELOX_ASSERT_THROW( + pageReader->skip(1), "Definition level length 2147483632 exceeds"); +} + +// Test that prepareDataPageV1 rejects pages with repeatLength exceeding page +// size. This guards against heap buffer overflow from corrupt Parquet files. +TEST_F(ParquetPageReaderTest, corruptRepeatLengthV1) { + // Create a DATA_PAGE header with small page size. + constexpr int32_t kPageSize = 20; + auto pageHeader = createDataPageV1Header(kPageSize, kPageSize, 100); + std::string headerBytes = serializePageHeader(pageHeader); + + // Create corrupt page data where repeatLength (first 4 bytes after + // decompression for maxRepeat > 0) is huge. Since compression is NONE, the + // "decompressed" data is the raw page data. + std::string pageData(kPageSize, '\0'); + // Set repeatLength to a huge value (0x7FFFFFF0) that exceeds page size. + uint32_t corruptRepeatLength = 0x7FFFFFF0; + memcpy(pageData.data(), &corruptRepeatLength, sizeof(uint32_t)); + + // Combine header and page data. + std::string fullData = headerBytes + pageData; + + // Create an input stream from the crafted data. + auto inputStream = std::make_unique( + fullData.data(), fullData.size()); + + dwio::common::ColumnReaderStatistics stats; + // Create PageReader with maxRepeat=1 (so repeatLength is read), maxDefine=0, + // and no compression (so page data is used directly). + auto pageReader = std::make_unique( + std::move(inputStream), + *leafPool_, + common::CompressionKind::CompressionKind_NONE, + fullData.size(), + stats, + nullptr, + 1, + 0); + + // Calling skip(1) triggers seekToPage() which calls prepareDataPageV1(). + // The bounds check should throw when repeatLength exceeds page size. + VELOX_ASSERT_THROW( + pageReader->skip(1), "Repetition level length 2147483632 exceeds"); +} + +// Test that prepareDataPageV2 rejects pages where repetition + definition +// level lengths exceed compressed page size. +TEST_F(ParquetPageReaderTest, corruptLevelLengthsV2) { + // Create a DATA_PAGE_V2 header with small page size but huge level lengths. + constexpr int32_t kPageSize = 20; + // Set level lengths that exceed the page size. + constexpr int32_t kCorruptRepeatLength = 0x7FFFFFF0; + constexpr int32_t kCorruptDefineLength = 100; + auto pageHeader = createDataPageV2Header( + kPageSize, kPageSize, 100, kCorruptDefineLength, kCorruptRepeatLength); + std::string headerBytes = serializePageHeader(pageHeader); + + // Create page data (content doesn't matter since validation should fail + // before reading it). + std::string pageData(kPageSize, '\0'); + + // Combine header and page data. + std::string fullData = headerBytes + pageData; + + // Create an input stream from the crafted data. + auto inputStream = std::make_unique( + fullData.data(), fullData.size()); + + dwio::common::ColumnReaderStatistics stats; + // maxRepeat and maxDefine don't affect V2 validation since the lengths + // are in the header, not the page data. + auto pageReader = std::make_unique( + std::move(inputStream), + *leafPool_, + common::CompressionKind::CompressionKind_NONE, + fullData.size(), + stats, + nullptr, + 1, + 1); + + // Calling skip(1) triggers seekToPage() which calls prepareDataPageV2(). + // The bounds check should throw when level lengths exceed page size. + VELOX_ASSERT_THROW( + pageReader->skip(1), + "Repetition and definition level lengths (2147483632 + 100) exceed"); +} + +// Test that prepareDataPageV1 rejects pages that are too small to contain +// the repetition level length field (4 bytes). +TEST_F(ParquetPageReaderTest, insufficientBytesForRepeatLengthV1) { + // Create a DATA_PAGE header with page size too small for the 4-byte + // repetition level length field. + constexpr int32_t kPageSize = 2; // Less than sizeof(int32_t) + auto pageHeader = createDataPageV1Header(kPageSize, kPageSize, 100); + std::string headerBytes = serializePageHeader(pageHeader); + + // Create minimal page data. + std::string pageData(kPageSize, '\0'); + + // Combine header and page data. + std::string fullData = headerBytes + pageData; + + // Create an input stream from the crafted data. + auto inputStream = std::make_unique( + fullData.data(), fullData.size()); + + dwio::common::ColumnReaderStatistics stats; + // PageReader with maxRepeat > 0 would try to read repeatLength but page + // is too small. + auto pageReader = std::make_unique( + std::move(inputStream), + *leafPool_, + common::CompressionKind::CompressionKind_NONE, + fullData.size(), + stats, + nullptr, + 1, + 0); + + // Calling skip(1) triggers seekToPage() which calls prepareDataPageV1(). + // The bounds check should throw when page is too small to hold repeatLength. + VELOX_ASSERT_THROW( + pageReader->skip(1), "Insufficient bytes for repetition level length"); +} + +// Test that prepareDataPageV1 rejects pages where there are insufficient bytes +// remaining for the definition level length field after reading repetition +// levels. +TEST_F(ParquetPageReaderTest, insufficientBytesForDefineLengthV1) { + // Create a DATA_PAGE header with page size that can hold repetition level + // length (4 bytes) plus a small repeatLength value, but not enough remaining + // for the definition level length field. + constexpr int32_t kPageSize = 6; // 4 bytes for repeatLength field + 2 bytes + auto pageHeader = createDataPageV1Header(kPageSize, kPageSize, 100); + std::string headerBytes = serializePageHeader(pageHeader); + + // Create page data with a small repeatLength that leaves insufficient bytes + // for defineLength. + std::string pageData(kPageSize, '\0'); + // Set repeatLength to 0 (valid), leaving only 2 bytes which is insufficient + // for the 4-byte defineLength field. + uint32_t repeatLength = 0; + memcpy(pageData.data(), &repeatLength, sizeof(uint32_t)); + + // Combine header and page data. + std::string fullData = headerBytes + pageData; + + // Create an input stream from the crafted data. + auto inputStream = std::make_unique( + fullData.data(), fullData.size()); + + dwio::common::ColumnReaderStatistics stats; + // PageReader with both maxRepeat > 0 and maxDefine > 0 would read + // repeatLength (0), advance past it, then try to read defineLength but + // there are insufficient bytes remaining. + auto pageReader = std::make_unique( + std::move(inputStream), + *leafPool_, + common::CompressionKind::CompressionKind_NONE, + fullData.size(), + stats, + nullptr, + 1, + 1); + + // Calling skip(1) triggers seekToPage() which calls prepareDataPageV1(). + // The bounds check should throw when there are insufficient bytes for + // defineLength. + VELOX_ASSERT_THROW( + pageReader->skip(1), "Insufficient bytes for definition level length"); +} + +// Test that prepareDataPageV2 rejects pages where only the repetition level +// length exceeds the page size (without combining with definition length). +TEST_F(ParquetPageReaderTest, corruptRepeatLengthOnlyV2) { + // Create a DATA_PAGE_V2 header where just the repetition level length + // exceeds the page size. + constexpr int32_t kPageSize = 20; + constexpr int32_t kCorruptRepeatLength = 0x7FFFFFF0; + constexpr int32_t kDefineLength = 0; // No definition levels + auto pageHeader = createDataPageV2Header( + kPageSize, kPageSize, 100, kDefineLength, kCorruptRepeatLength); + std::string headerBytes = serializePageHeader(pageHeader); + + // Create page data. + std::string pageData(kPageSize, '\0'); + + // Combine header and page data. + std::string fullData = headerBytes + pageData; + + // Create an input stream from the crafted data. + auto inputStream = std::make_unique( + fullData.data(), fullData.size()); + + dwio::common::ColumnReaderStatistics stats; + // maxRepeat and maxDefine don't affect V2 validation since the lengths + // are in the header, not the page data. + auto pageReader = std::make_unique( + std::move(inputStream), + *leafPool_, + common::CompressionKind::CompressionKind_NONE, + fullData.size(), + stats, + nullptr, + 1, + 0); + + // Calling skip(1) triggers seekToPage() which calls prepareDataPageV2(). + // The bounds check should throw when repeatLength exceeds page size. + VELOX_ASSERT_THROW( + pageReader->skip(1), + "Repetition and definition level lengths (2147483632 + 0) exceed"); +} diff --git a/velox/dwio/parquet/tests/reader/ParquetReaderBenchmark.cpp b/velox/dwio/parquet/tests/reader/ParquetReaderBenchmark.cpp index 976d021d87e..02cbe4a5129 100644 --- a/velox/dwio/parquet/tests/reader/ParquetReaderBenchmark.cpp +++ b/velox/dwio/parquet/tests/reader/ParquetReaderBenchmark.cpp @@ -96,7 +96,9 @@ std::shared_ptr ParquetReaderBenchmark::createScanSpec( std::unique_ptr ParquetReaderBenchmark::createReader( std::shared_ptr scanSpec, const RowTypePtr& rowType) { - dwio::common::ReaderOptions readerOpts{leafPool_.get()}; + dwio::common::ReaderOptions readerOpts(leafPool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); auto input = std::make_unique( std::make_shared(fileFolder_->getPath() + "/" + fileName_), readerOpts.memoryPool()); diff --git a/velox/dwio/parquet/tests/reader/ParquetReaderBenchmark.h b/velox/dwio/parquet/tests/reader/ParquetReaderBenchmark.h index 07f2913a89f..680a27b10ca 100644 --- a/velox/dwio/parquet/tests/reader/ParquetReaderBenchmark.h +++ b/velox/dwio/parquet/tests/reader/ParquetReaderBenchmark.h @@ -15,6 +15,8 @@ */ #pragma once +#include "velox/common/io/IoStatistics.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/dwio/common/FileSink.h" #include "velox/dwio/common/Options.h" #include "velox/dwio/common/Statistics.h" @@ -22,13 +24,14 @@ #include "velox/dwio/parquet/RegisterParquetReader.h" #include "velox/dwio/parquet/reader/ParquetReader.h" #include "velox/dwio/parquet/writer/Writer.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include #include namespace facebook::velox::parquet::test { +using TempDirectoryPath = common::testutil::TempDirectoryPath; + constexpr uint32_t kNumRowsPerBatch = 60000; constexpr uint32_t kNumBatches = 50; constexpr uint32_t kNumRowsPerRowGroup = 10000; @@ -104,13 +107,17 @@ class ParquetReaderBenchmark { private: const std::string fileName_ = "test.parquet"; - const std::shared_ptr - fileFolder_ = facebook::velox::exec::test::TempDirectoryPath::create(); + const std::shared_ptr fileFolder_ = + TempDirectoryPath::create(); const bool disableDictionary_; std::unique_ptr dataSetBuilder_; std::shared_ptr rootPool_; std::shared_ptr leafPool_; + std::shared_ptr dataIoStats_ = + std::make_shared(); + std::shared_ptr metadataIoStats_ = + std::make_shared(); std::unique_ptr writer_; facebook::velox::dwio::common::RuntimeStatistics runtimeStats_; }; diff --git a/velox/dwio/parquet/tests/reader/ParquetReaderTest.cpp b/velox/dwio/parquet/tests/reader/ParquetReaderTest.cpp index d336db9ac4f..00446d11e70 100644 --- a/velox/dwio/parquet/tests/reader/ParquetReaderTest.cpp +++ b/velox/dwio/parquet/tests/reader/ParquetReaderTest.cpp @@ -15,7 +15,11 @@ */ #include "velox/common/base/tests/GTestUtils.h" +#include "velox/dwio/common/Mutation.h" +#include "velox/dwio/parquet/reader/ParquetStatsContext.h" +#include "velox/dwio/parquet/reader/SemanticVersion.h" #include "velox/dwio/parquet/tests/ParquetTestBase.h" +#include "velox/dwio/parquet/thrift/ParquetThriftTypes.h" #include "velox/expression/ExprToSubfieldFilter.h" #include "velox/vector/tests/utils/VectorMaker.h" @@ -31,7 +35,9 @@ class ParquetReaderTest : public ParquetTestBase { const RowTypePtr& rowType) { const std::string sample(getExampleFilePath(fileName)); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOptions); RowReaderOptions rowReaderOpts; @@ -57,7 +63,9 @@ class ParquetReaderTest : public ParquetTestBase { FilterMap filters, const RowVectorPtr& expected) { const auto filePath(getExampleFilePath(fileName)); - dwio::common::ReaderOptions readerOpts{leafPool_.get()}; + dwio::common::ReaderOptions readerOpts(leafPool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); auto reader = createReader(filePath, readerOpts); assertReadWithReaderAndFilters( std::move(reader), fileName, fileSchema, std::move(filters), expected); @@ -72,7 +80,9 @@ TEST_F(ParquetReaderTest, parseSample) { // b: [1.0..20.0] const std::string sample(getExampleFilePath("sample.parquet")); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOptions); EXPECT_EQ(reader->numberOfRows(), 20ULL); @@ -121,7 +131,9 @@ TEST_F(ParquetReaderTest, parseEmptyNestedList) { const std::string sample( getExampleFilePath("parse_empty_nested_list.parquet")); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOptions); EXPECT_EQ(reader->numberOfRows(), 1000ULL); @@ -179,7 +191,9 @@ TEST_F(ParquetReaderTest, parseUnannotatedList) { // } const std::string sample(getExampleFilePath("unannotated_list.parquet")); - dwio::common::ReaderOptions readerOpts{leafPool_.get()}; + dwio::common::ReaderOptions readerOpts(leafPool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOpts); EXPECT_EQ(reader->numberOfRows(), 22ULL); @@ -234,7 +248,9 @@ TEST_F(ParquetReaderTest, parseUnannotatedMap) { const std::string filename("unnotated_map.parquet"); const std::string sample(getExampleFilePath(filename)); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOptions); auto type = reader->typeWithId(); @@ -273,7 +289,9 @@ TEST_F(ParquetReaderTest, parseLegacyListWithMultipleChildren) { const std::string filename("listmultiplechildren.parquet"); const std::string sample(getExampleFilePath(filename)); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOptions); auto type = reader->typeWithId(); @@ -332,7 +350,9 @@ TEST_F(ParquetReaderTest, parseArrayOfRowHiveReservedKeywords) { const std::string sample( getExampleFilePath("array_of_row_hive_reserved_keywords.parquet")); - dwio::common::ReaderOptions readerOpts{leafPool_.get()}; + dwio::common::ReaderOptions readerOpts(leafPool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOpts); EXPECT_EQ(reader->rowType()->toString(), expectedVeloxType); EXPECT_EQ(reader->numberOfRows(), 6ULL); @@ -383,7 +403,9 @@ TEST_F(ParquetReaderTest, parseArrayOfRowHiveReservedKeywords) { TEST_F(ParquetReaderTest, parseSampleRange1) { const std::string sample(getExampleFilePath("sample.parquet")); - dwio::common::ReaderOptions readerOpts{leafPool_.get()}; + dwio::common::ReaderOptions readerOpts(leafPool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOpts); auto rowReaderOpts = getReaderOpts(sampleSchema()); @@ -402,7 +424,9 @@ TEST_F(ParquetReaderTest, parseSampleRange1) { TEST_F(ParquetReaderTest, parseSampleRange2) { const std::string sample(getExampleFilePath("sample.parquet")); - dwio::common::ReaderOptions readerOpts{leafPool_.get()}; + dwio::common::ReaderOptions readerOpts(leafPool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOpts); auto rowReaderOpts = getReaderOpts(sampleSchema()); @@ -421,7 +445,9 @@ TEST_F(ParquetReaderTest, parseSampleRange2) { TEST_F(ParquetReaderTest, parseSampleEmptyRange) { const std::string sample(getExampleFilePath("sample.parquet")); - dwio::common::ReaderOptions readerOpts{leafPool_.get()}; + dwio::common::ReaderOptions readerOpts(leafPool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOpts); auto rowReaderOpts = getReaderOpts(sampleSchema()); @@ -439,7 +465,9 @@ TEST_F(ParquetReaderTest, parseReadAsLowerCase) { // 2 rows. const std::string upper(getExampleFilePath("upper.parquet")); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto outputRowType = ROW({"a", "b"}, {BIGINT(), BIGINT()}); readerOptions.setFileSchema(outputRowType); readerOptions.setFileColumnNamesReadAsLowerCase(true); @@ -475,7 +503,9 @@ TEST_F(ParquetReaderTest, parseRowMapArrayReadAsLowerCase) { // +-----------------------+ const std::string upper(getExampleFilePath("upper_complex.parquet")); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileColumnNamesReadAsLowerCase(true); auto reader = createReader(upper, readerOptions); @@ -518,7 +548,9 @@ TEST_F(ParquetReaderTest, parseEmpty) { // 0 rows. const std::string empty(getExampleFilePath("empty.parquet")); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = createReader(empty, readerOptions); EXPECT_EQ(reader->numberOfRows(), 0ULL); @@ -540,7 +572,9 @@ TEST_F(ParquetReaderTest, parseInt) { // bigint: [1000 .. 1009] const std::string sample(getExampleFilePath("int.parquet")); - dwio::common::ReaderOptions readerOpts{leafPool_.get()}; + dwio::common::ReaderOptions readerOpts(leafPool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOpts); EXPECT_EQ(reader->numberOfRows(), 10ULL); @@ -575,7 +609,9 @@ TEST_F(ParquetReaderTest, parseUnsignedInt1) { // uint64: [18446744073709551615, 2000000000000000000, 3000000000000000000] const std::string sample(getExampleFilePath("uint.parquet")); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOptions); EXPECT_EQ(reader->numberOfRows(), 3ULL); @@ -595,8 +631,9 @@ TEST_F(ParquetReaderTest, parseUnsignedInt1) { {TINYINT(), SMALLINT(), INTEGER(), BIGINT()}); RowReaderOptions rowReaderOpts; - rowReaderOpts.select(std::make_shared( - rowType, rowType->names())); + rowReaderOpts.select( + std::make_shared( + rowType, rowType->names())); rowReaderOpts.setScanSpec(makeScanSpec(rowType)); auto rowReader = reader->createRowReader(rowReaderOpts); @@ -678,7 +715,9 @@ TEST_F(ParquetReaderTest, parseDate) { // date: [1969-12-27 .. 1970-01-20] const std::string sample(getExampleFilePath("date.parquet")); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOptions); EXPECT_EQ(reader->numberOfRows(), 25ULL); @@ -705,7 +744,9 @@ TEST_F(ParquetReaderTest, parseRowMapArray) { // ARRAY(INTEGER)) c1) c) const std::string sample(getExampleFilePath("row_map_array.parquet")); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOptions); EXPECT_EQ(reader->numberOfRows(), 1ULL); @@ -738,7 +779,9 @@ TEST_F(ParquetReaderTest, parseRowMapArray) { TEST_F(ParquetReaderTest, projectNoColumns) { // This is the case for count(*). auto rowType = ROW({}, {}); - dwio::common::ReaderOptions readerOpts{leafPool_.get()}; + dwio::common::ReaderOptions readerOpts(leafPool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); auto reader = createReader(getExampleFilePath("sample.parquet"), readerOpts); RowReaderOptions rowReaderOpts; rowReaderOpts.setScanSpec(makeScanSpec(rowType)); @@ -762,7 +805,9 @@ TEST_F(ParquetReaderTest, parseIntDecimal) { // a: [11.11, 11.11, 22.22, 22.22, 33.33, 33.33] // b: [11.11, 11.11, 22.22, 22.22, 33.33, 33.33] auto rowType = ROW({"a", "b"}, {DECIMAL(7, 2), DECIMAL(14, 2)}); - dwio::common::ReaderOptions readerOpts{leafPool_.get()}; + dwio::common::ReaderOptions readerOpts(leafPool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); const std::string decimal_dict(getExampleFilePath("decimal_dict.parquet")); auto reader = createReader(decimal_dict, readerOpts); @@ -816,7 +861,9 @@ TEST_F(ParquetReaderTest, parseMapKeyValueAsMap) { const std::string sample(getExampleFilePath("map_key_value.parquet")); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOptions); EXPECT_EQ(reader->numberOfRows(), 1ULL); @@ -864,7 +911,9 @@ TEST_F(ParquetReaderTest, parseRowArrayTest) { const std::string sample( getExampleFilePath("proto-struct-with-array.parquet")); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOptions); EXPECT_EQ(reader->numberOfRows(), 1ULL); auto type = reader->typeWithId(); @@ -1149,7 +1198,9 @@ TEST_F(ParquetReaderTest, filterRowGroups) { // decimal_no_ColumnMetadata.parquet has one columns a: DECIMAL(9,1). It // doesn't have ColumnMetaData, and rowGroups_[0].columns[0].file_offset is 0. auto rowType = ROW({"_c0"}, {DECIMAL(9, 1)}); - dwio::common::ReaderOptions readerOpts{leafPool_.get()}; + dwio::common::ReaderOptions readerOpts(leafPool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); const std::string decimal_dict( getExampleFilePath("decimal_no_ColumnMetadata.parquet")); @@ -1161,11 +1212,42 @@ TEST_F(ParquetReaderTest, filterRowGroups) { EXPECT_EQ(reader->numberOfRows(), 10ULL); } +TEST_F(ParquetReaderTest, shouldIgnoreStatsForParquetMRVersions) { + SemanticVersion v181("parquet-mr", 1, 8, 1); + ParquetStatsContext ctx181{std::optional(v181)}; + EXPECT_TRUE(ctx181.shouldIgnoreStatistics(thrift::Type::BYTE_ARRAY)) + << "ParquetStatsContext(parquet-mr 1.8.1) should ignore string stats"; + + SemanticVersion v182("parquet-mr", 1, 8, 2); + ParquetStatsContext ctx182{std::optional(v182)}; + EXPECT_FALSE(ctx182.shouldIgnoreStatistics(thrift::Type::BYTE_ARRAY)) + << "ParquetStatsContext(parquet-mr 1.8.2) should not ignore string stats"; +} + +// This test is to verify filterRowGroups() doesn't fail if offset is 0 +TEST_F(ParquetReaderTest, filterRowGroupsWithZeroOffset) { + auto rowType = ROW({"IDX"}, {INTEGER()}); + dwio::common::ReaderOptions readerOpts(leafPool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); + const std::string zeroOffsetPath( + getExampleFilePath("zero_offset_row_group.parquet")); + + auto reader = createReader(zeroOffsetPath, readerOpts); + RowReaderOptions rowReaderOpts; + rowReaderOpts.setScanSpec(makeScanSpec(rowType)); + auto rowReader = reader->createRowReader(rowReaderOpts); + + EXPECT_EQ(reader->numberOfRows(), 1L); +} + TEST_F(ParquetReaderTest, parseLongTagged) { // This is a case for long with annonation read const std::string sample(getExampleFilePath("tagged_long.parquet")); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOptions); EXPECT_EQ(reader->numberOfRows(), 4ULL); @@ -1183,7 +1265,9 @@ TEST_F(ParquetReaderTest, preloadSmallFile) { auto file = std::make_shared(sample); auto input = std::make_unique(file, *leafPool_); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = std::make_unique(std::move(input), readerOptions); @@ -1196,7 +1280,7 @@ TEST_F(ParquetReaderTest, preloadSmallFile) { const auto fileSize = file->size(); ASSERT_TRUE( fileSize <= dwio::common::ReaderOptions::kDefaultFilePreloadThreshold || - fileSize <= dwio::common::ReaderOptions::kDefaultFooterEstimatedSize); + fileSize <= dwio::common::ReaderOptions::kDefaultFooterSpeculativeIoSize); // Check the whole file already loaded. ASSERT_EQ(file->bytesRead(), fileSize); @@ -1217,7 +1301,9 @@ TEST_F(ParquetReaderTest, prefetchRowGroups) { const std::string sample(getExampleFilePath("multiple_row_groups.parquet")); const int numRowGroups = 4; - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); // Disable preload of file. readerOptions.setFilePreloadThreshold(0); @@ -1270,7 +1356,9 @@ TEST_F(ParquetReaderTest, testEmptyRowGroups) { // empty_row_groups.parquet contains empty row groups const std::string sample(getExampleFilePath("empty_row_groups.parquet")); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOptions); EXPECT_EQ(reader->numberOfRows(), 5ULL); @@ -1296,7 +1384,9 @@ TEST_F(ParquetReaderTest, testEnumType) { // enum_type.parquet contains 1 column (ENUM) with 3 rows. const std::string sample(getExampleFilePath("enum_type.parquet")); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOptions); EXPECT_EQ(reader->numberOfRows(), 3ULL); @@ -1321,7 +1411,9 @@ TEST_F(ParquetReaderTest, readVarbinaryFromFLBA) { const std::string filename("varbinary_flba.parquet"); const std::string sample(getExampleFilePath(filename)); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOptions); auto type = reader->typeWithId(); @@ -1352,7 +1444,9 @@ TEST_F(ParquetReaderTest, readBinaryAsStringFromNation) { const std::string filename("nation.parquet"); const std::string sample(getExampleFilePath(filename)); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto outputRowType = ROW({"nationkey", "name", "regionkey", "comment"}, {BIGINT(), VARCHAR(), BIGINT(), VARCHAR()}); @@ -1383,7 +1477,9 @@ TEST_F(ParquetReaderTest, readComplexType) { const std::string filename("complex_with_varchar_varbinary.parquet"); const std::string sample(getExampleFilePath(filename)); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto outputRowType = ROW({"a", "b", "c", "d"}, {ARRAY(VARCHAR()), @@ -1430,7 +1526,9 @@ TEST_F(ParquetReaderTest, readFixedLenBinaryAsStringFromUuid) { const std::string filename("uuid.parquet"); const std::string sample(getExampleFilePath(filename)); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto outputRowType = ROW({"uuid_field"}, {VARCHAR()}); readerOptions.setFileSchema(outputRowType); @@ -1458,7 +1556,9 @@ TEST_F(ParquetReaderTest, readFixedLenBinaryAsStringFromUuid) { TEST_F(ParquetReaderTest, testV2PageWithZeroMaxDefRep) { const std::string sample(getExampleFilePath("v2_page.parquet")); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOptions); EXPECT_EQ(reader->numberOfRows(), 5ULL); @@ -1482,7 +1582,9 @@ TEST_F(ParquetReaderTest, testV2PageWithZeroMaxDefRep) { TEST_F(ParquetReaderTest, readComplexTypeWithV2Page) { const std::string sample(getExampleFilePath("complex_type_v2_page.parquet")); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOptions); EXPECT_EQ(reader->numberOfRows(), 1ULL); @@ -1522,7 +1624,9 @@ TEST_F(ParquetReaderTest, arrayOfMapOfIntKeyArrayValue) { "ROW>>>"; const std::string sample( getExampleFilePath("array_of_map_of_int_key_array_value.parquet")); - facebook::velox::dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + facebook::velox::dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOptions); EXPECT_EQ(reader->rowType()->toString(), expectedVeloxType); auto type = reader->typeWithId(); @@ -1555,7 +1659,9 @@ TEST_F(ParquetReaderTest, arrayOfMapOfIntKeyStructValue) { "ROW>>>"; const std::string sample( getExampleFilePath("array_of_map_of_int_key_struct_value.parquet")); - facebook::velox::dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + facebook::velox::dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOptions); EXPECT_EQ(reader->rowType()->toString(), expectedVeloxType); auto type = reader->typeWithId(); @@ -1589,7 +1695,9 @@ TEST_F(ParquetReaderTest, struct_of_array_of_array) { "ROW>,intarrayfield:ARRAY>>>"; const std::string sample( getExampleFilePath("struct_of_array_of_array.parquet")); - facebook::velox::dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + facebook::velox::dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOptions); auto type = reader->typeWithId(); EXPECT_EQ(type->size(), 1ULL); @@ -1653,7 +1761,9 @@ TEST_F(ParquetReaderTest, struct_of_array_of_array) { TEST_F(ParquetReaderTest, testLzoDataPage) { const std::string sample(getExampleFilePath("lzo.parquet")); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOptions); EXPECT_EQ(reader->numberOfRows(), 23'547ULL); @@ -1688,7 +1798,9 @@ TEST_F(ParquetReaderTest, testLzoDataPage) { TEST_F(ParquetReaderTest, testEmptyV2DataPage) { const std::string sample(getExampleFilePath("empty_v2datapage.parquet")); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = createReader(sample, readerOptions); EXPECT_EQ(reader->numberOfRows(), 30001ULL); @@ -1720,7 +1832,9 @@ TEST_F(ParquetReaderTest, parquet251) { TEST_F(ParquetReaderTest, fileColumnVarcharToMetadataColumnMismatchTest) { const std::string sample(getExampleFilePath("nation.parquet")); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto runVarcharColTest = [&](const TypePtr& requestedType) { // The type in the file is a BYTE_ARRAY resolving to VARCHAR. @@ -1738,7 +1852,7 @@ TEST_F(ParquetReaderTest, fileColumnVarcharToMetadataColumnMismatchTest) { VELOX_ASSERT_THROW( createReader(sample, readerOptions), fmt::format( - "Converted type VARCHAR is not allowed for requested type {}", + "Converted type VARCHAR is not allowed for requested type {} for file column 'name'", requestedType->toString())); }; @@ -1757,3 +1871,270 @@ TEST_F(ParquetReaderTest, fileColumnVarcharToMetadataColumnMismatchTest) { runVarcharColTest(type); } } + +TEST_F(ParquetReaderTest, readerWithSchema) { + // Create an in-memory writer. + auto sink = std::make_unique( + 1024 * 1024, dwio::common::FileSink::Options{.pool = leafPool_.get()}); + auto sinkPtr = sink.get(); + const auto data = makeRowVector( + {makeFlatVector({1}), + makeArrayVectorFromJson({"[4 ,5]"})}); + parquet::WriterOptions writerOptions; + writerOptions.memoryPool = leafPool_.get(); + + // key, element are Parquet reserved keywords. + // Ensure we handle them properly during the schema inference. + auto schema = ROW({"key", "element"}, {BIGINT(), ARRAY(INTEGER())}); + + auto writer = std::make_unique( + std::move(sink), writerOptions, rootPool_, schema); + writer->write(data); + writer->close(); + + // Create the reader. + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setFileSchema(schema); + std::string dataBuf(sinkPtr->data(), sinkPtr->size()); + auto file = std::make_shared(std::move(dataBuf)); + auto buffer = std::make_unique( + file, readerOptions.memoryPool()); + ParquetReader reader(std::move(buffer), readerOptions); + + EXPECT_EQ(reader.rowType()->toString(), schema->toString()); +} + +TEST_F(ParquetReaderTest, columnStatistics) { + auto data = makeRowVector( + {"a", "b", "c"}, + { + makeFlatVector({1, 2, 3, 4, 5}), + makeFlatVector({1.1, 2.2, 3.3, 4.4, 5.5}), + makeFlatVector({"aaa", "bbb", "ccc", "ddd", "eee"}), + }); + + auto* sink = write(data); + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto reader = createReaderInMemory(*sink, readerOptions); + const auto& schema = reader->typeWithId(); + + // Root ROW type — no stats for non-leaf. + EXPECT_EQ(reader->columnStatistics(schema->id()), nullptr); + + // Out of range. + EXPECT_EQ(reader->columnStatistics(schema->maxId() + 1), nullptr); + + // BIGINT column. + { + auto stats = reader->columnStatistics(schema->childByName("a")->id()); + ASSERT_NE(stats, nullptr); + EXPECT_EQ(stats->getNumberOfValues(), 5); + EXPECT_FALSE(stats->hasNull().value()); + auto* intStats = + dynamic_cast(stats.get()); + ASSERT_NE(intStats, nullptr); + EXPECT_EQ(intStats->getMinimum(), 1); + EXPECT_EQ(intStats->getMaximum(), 5); + } + + // DOUBLE column. + { + auto stats = reader->columnStatistics(schema->childByName("b")->id()); + ASSERT_NE(stats, nullptr); + EXPECT_EQ(stats->getNumberOfValues(), 5); + EXPECT_FALSE(stats->hasNull().value()); + auto* doubleStats = + dynamic_cast(stats.get()); + ASSERT_NE(doubleStats, nullptr); + EXPECT_EQ(doubleStats->getMinimum(), 1.1); + EXPECT_EQ(doubleStats->getMaximum(), 5.5); + } + + // VARCHAR column. + { + auto stats = reader->columnStatistics(schema->childByName("c")->id()); + ASSERT_NE(stats, nullptr); + EXPECT_EQ(stats->getNumberOfValues(), 5); + EXPECT_FALSE(stats->hasNull().value()); + auto* stringStats = + dynamic_cast(stats.get()); + ASSERT_NE(stringStats, nullptr); + EXPECT_EQ(stringStats->getMinimum(), "aaa"); + EXPECT_EQ(stringStats->getMaximum(), "eee"); + } +} + +TEST_F(ParquetReaderTest, columnStatisticsWithNulls) { + auto data = makeRowVector( + {"a"}, + { + makeNullableFlatVector( + {1, std::nullopt, 3, std::nullopt, 5}), + }); + + auto* sink = write(data); + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto reader = createReaderInMemory(*sink, readerOptions); + const auto& schema = reader->typeWithId(); + + auto stats = reader->columnStatistics(schema->childByName("a")->id()); + ASSERT_NE(stats, nullptr); + EXPECT_EQ(stats->getNumberOfValues(), 3); + EXPECT_TRUE(stats->hasNull().value()); + auto* intStats = + dynamic_cast(stats.get()); + ASSERT_NE(intStats, nullptr); + EXPECT_EQ(intStats->getMinimum(), 1); + EXPECT_EQ(intStats->getMaximum(), 5); +} + +TEST_F(ParquetReaderTest, columnStatisticsMultipleRowGroups) { + // Use a small flush size to force multiple row groups. + parquet::WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); + writerOptions.flushPolicyFactory = []() { + return std::make_unique( + /*rowsInRowGroup=*/5, + /*bytesInRowGroup=*/1'024 * 1'024, + []() { return false; }); + }; + + auto data = makeRowVector( + {"a"}, + { + makeFlatVector({10, 20, 30, 40, 50, 1, 2, 3, 4, 5}), + }); + + auto* sink = write(data, writerOptions); + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto reader = createReaderInMemory(*sink, readerOptions); + + // Verify we have multiple row groups. + ASSERT_GT(reader->fileMetaData().numRowGroups(), 1); + + const auto& schema = reader->typeWithId(); + auto stats = reader->columnStatistics(schema->childByName("a")->id()); + ASSERT_NE(stats, nullptr); + EXPECT_EQ(stats->getNumberOfValues(), 10); + EXPECT_FALSE(stats->hasNull().value()); + auto* intStats = + dynamic_cast(stats.get()); + ASSERT_NE(intStats, nullptr); + // Global min/max across all row groups. + EXPECT_EQ(intStats->getMinimum(), 1); + EXPECT_EQ(intStats->getMaximum(), 50); +} + +TEST_F(ParquetReaderTest, readTimeMillis) { + // Write TIME data using the parquet writer. + // The writer exports Velox TIME as Arrow time32 with milliseconds unit, + // which maps to Parquet TIME_MILLIS (INT32). + const auto rowType = ROW({"time_col"}, {TIME()}); + + auto data = makeRowVector( + rowType->names(), + {makeNullableFlatVector( + {1, 1'000, 3'600'000, std::nullopt, 43'200'000, 86'399'999}, + TIME())}); + auto sink = write(data); + + dwio::common::ReaderOptions readerOpts(leafPool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); + auto reader = createReaderInMemory(*sink, readerOpts); + + EXPECT_EQ(reader->numberOfRows(), 6ULL); + + auto type = reader->typeWithId(); + EXPECT_EQ(type->size(), 1ULL); + auto col0 = type->childAt(0); + EXPECT_TRUE(col0->type()->isTime()); + + auto rowReaderOpts = getReaderOpts(rowType); + rowReaderOpts.setScanSpec(makeScanSpec(rowType)); + auto rowReader = reader->createRowReader(rowReaderOpts); + + assertReadWithReaderAndExpected(rowType, *rowReader, data, *leafPool_); +} + +TEST_F(ParquetReaderTest, readTimeMicros) { + // Write TIME MICRO UTC data using the parquet writer. + // The writer exports Velox TIME MICRO UTC as Arrow time64 with microseconds + // unit, which maps to Parquet TIME_MICROS (INT64). + const auto rowType = ROW({"time_col"}, {TIME_MICRO_UTC()}); + + auto data = makeRowVector( + rowType->names(), + {makeNullableFlatVector( + {0, + 1, + 1'000, + 3'600'000'000, + std::nullopt, + 43'200'000'000, + 86'399'999'999}, + TIME_MICRO_UTC())}); + auto sink = write(data); + + dwio::common::ReaderOptions readerOpts{leafPool_.get()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); + auto reader = createReaderInMemory(*sink, readerOpts); + + EXPECT_EQ(reader->numberOfRows(), 7ULL); + + auto type = reader->typeWithId(); + EXPECT_EQ(type->size(), 1ULL); + auto col0 = type->childAt(0); + EXPECT_TRUE(col0->type()->isTime()); + EXPECT_TRUE(col0->type()->equivalent(*TIME_MICRO_UTC())); + + auto rowReaderOpts = getReaderOpts(rowType); + rowReaderOpts.setScanSpec(makeScanSpec(rowType)); + auto rowReader = reader->createRowReader(rowReaderOpts); + + assertReadWithReaderAndExpected(rowType, *rowReader, data, *leafPool_); +} + +TEST_F(ParquetReaderTest, readTimeWithMultipleColumns) { + const auto rowType = + ROW({"id", "time_col", "name"}, {INTEGER(), TIME(), VARCHAR()}); + + auto data = makeRowVector( + rowType->names(), + { + makeFlatVector({1, 2, 3, 4, 5}), + makeFlatVector( + {0, 1'000, 3'600'000, 43'200'000, 86'399'999}, TIME()), + makeFlatVector({"a", "b", "c", "d", "e"}), + }); + + auto sink = write(data); + + dwio::common::ReaderOptions readerOpts(leafPool_.get()); + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); + auto reader = createReaderInMemory(*sink, readerOpts); + + EXPECT_EQ(reader->numberOfRows(), 5ULL); + + auto type = reader->typeWithId(); + EXPECT_EQ(type->size(), 3ULL); + EXPECT_EQ(type->childAt(0)->type()->kind(), TypeKind::INTEGER); + EXPECT_TRUE(type->childAt(1)->type()->isTime()); + EXPECT_EQ(type->childAt(2)->type()->kind(), TypeKind::VARCHAR); + + auto rowReaderOpts = getReaderOpts(rowType); + rowReaderOpts.setScanSpec(makeScanSpec(rowType)); + auto rowReader = reader->createRowReader(rowReaderOpts); + + assertReadWithReaderAndExpected(rowType, *rowReader, data, *leafPool_); +} diff --git a/velox/dwio/parquet/tests/reader/ParquetReaderWideningTest.cpp b/velox/dwio/parquet/tests/reader/ParquetReaderWideningTest.cpp new file mode 100644 index 00000000000..557b24bb5df --- /dev/null +++ b/velox/dwio/parquet/tests/reader/ParquetReaderWideningTest.cpp @@ -0,0 +1,1512 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/dwio/common/Mutation.h" +#include "velox/dwio/parquet/tests/ParquetTestBase.h" +#include "velox/expression/ExprToSubfieldFilter.h" +#include "velox/vector/tests/utils/VectorMaker.h" + +using namespace facebook::velox; +using namespace facebook::velox::common; +using namespace facebook::velox::dwio::common; +using namespace facebook::velox::parquet; + +class ParquetReaderWideningTest : public ParquetTestBase { + public: + std::unique_ptr createWideningRowReader( + const RowVectorPtr& writeData, + const RowTypePtr& readSchema, + bool allowInt32Narrowing = false) { + auto* sink = write(writeData); + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setFileSchema(readSchema); + readerOptions.setAllowInt32Narrowing(allowInt32Narrowing); + auto reader = createReaderInMemory(*sink, readerOptions); + auto rowReaderOpts = getReaderOpts(readSchema); + rowReaderOpts.setScanSpec(makeScanSpec(readSchema)); + return reader->createRowReader(rowReaderOpts); + } + + /// Writes Parquet data with one schema and reads it back with a wider schema, + /// then verifies the result matches the expected output. + void assertWideningReads( + const RowVectorPtr& writeData, + const RowTypePtr& readSchema, + const RowVectorPtr& expected) { + auto rowReader = createWideningRowReader(writeData, readSchema); + assertReadWithReaderAndExpected( + readSchema, *rowReader, expected, *leafPool_); + } + + /// Writes Parquet data and reads it back with a narrower schema + /// (allowInt32Narrowing enabled), then verifies the result. + void assertNarrowingReads( + const RowVectorPtr& writeData, + const RowTypePtr& readSchema, + const RowVectorPtr& expected) { + auto rowReader = createWideningRowReader( + writeData, readSchema, /*allowInt32Narrowing=*/true); + assertReadWithReaderAndExpected( + readSchema, *rowReader, expected, *leafPool_); + } + + /// Writes Parquet data, reads with widening schema + filter, verifies result. + void assertWideningWithFilter( + const RowVectorPtr& writeData, + const RowTypePtr& readSchema, + std::unique_ptr filter, + const RowVectorPtr& expected, + bool allowInt32Narrowing = false) { + auto* sink = write(writeData); + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setFileSchema(readSchema); + readerOptions.setAllowInt32Narrowing(allowInt32Narrowing); + auto reader = createReaderInMemory(*sink, readerOptions); + auto rowReaderOpts = getReaderOpts(readSchema); + auto scanSpec = makeScanSpec(readSchema); + auto* child = scanSpec->getOrCreateChild(common::Subfield("col")); + child->setFilter(std::move(filter)); + rowReaderOpts.setScanSpec(scanSpec); + auto rowReader = reader->createRowReader(rowReaderOpts); + assertReadWithReaderAndExpected( + readSchema, *rowReader, expected, *leafPool_); + } + + /// Verifies that reading in-memory Parquet data with a mismatched schema + /// throws an exception whose message contains both the source type name and + /// "is not allowed for requested type". + void assertWideningThrows( + const RowVectorPtr& writeData, + const RowTypePtr& readSchema, + const std::string& sourceTypeName) { + auto* sink = write(writeData); + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setFileSchema(readSchema); + VELOX_ASSERT_THROW( + createReaderInMemory(*sink, readerOptions), + "Converted type " + sourceTypeName + + " is not allowed for requested type"); + } +}; + +// Comprehensive test matrix covering all combinations: +// - Nulls: No nulls, With nulls +// - Dictionary: Enabled, Disabled +// - Filter: None, IsNull, IsNotNull, Value filter +// - Density: Dense (no deletions), Non-dense (with deletions/mutations) + +enum class FloatToDoubleFilter { + kNone, + kIsNull, + kIsNotNull, + kGreaterThanOrEqual, // Value filter: greater than or equal to a threshold + kMultiRange, // MultiRange filter: a < X OR a > Y +}; + +struct FloatToDoubleSpec { + std::vector> values; + std::vector ids; + bool enableDictionary{true}; + FloatToDoubleFilter filter{FloatToDoubleFilter::kNone}; + std::optional filterValue; // Value for value-based filters + std::optional filterLowerValue; // Lower bound for MultiRange filter + std::optional filterUpperValue; // Upper bound for MultiRange filter + std::vector deletedRows; +}; + +struct FloatToDoubleTestParam { + bool hasNulls; + bool enableDictionary; + FloatToDoubleFilter filter; + bool isDense; + + std::string toString() const { + return fmt::format( + "Nulls_{}_Dict_{}_Filter_{}_Dense_{}", + hasNulls ? "Yes" : "No", + enableDictionary ? "Yes" : "No", + filterName(filter), + isDense ? "Yes" : "No"); + } + + static std::string filterName(FloatToDoubleFilter filter) { + switch (filter) { + case FloatToDoubleFilter::kNone: + return "None"; + case FloatToDoubleFilter::kIsNull: + return "IsNull"; + case FloatToDoubleFilter::kIsNotNull: + return "IsNotNull"; + case FloatToDoubleFilter::kGreaterThanOrEqual: + return "GreaterThanOrEqual"; + case FloatToDoubleFilter::kMultiRange: + return "MultiRange"; + default: + return "Unknown"; + } + } +}; + +class FloatToDoubleEvolutionTest + : public ParquetReaderWideningTest, + public testing::WithParamInterface { + public: + static std::vector getTestParams() { + std::vector params; + for (bool hasNulls : {false, true}) { + for (bool enableDictionary : {false, true}) { + // When hasNulls is false, only test kNone, kGreaterThanOrEqual, and + // kMultiRange filter (kIsNull would match nothing, kIsNotNull is + // equivalent to kNone) + std::vector filters; + if (hasNulls) { + filters = { + FloatToDoubleFilter::kNone, + FloatToDoubleFilter::kIsNull, + FloatToDoubleFilter::kIsNotNull, + FloatToDoubleFilter::kGreaterThanOrEqual, + FloatToDoubleFilter::kMultiRange}; + } else { + filters = { + FloatToDoubleFilter::kNone, + FloatToDoubleFilter::kGreaterThanOrEqual, + FloatToDoubleFilter::kMultiRange}; + } + + for (auto filter : filters) { + for (bool isDense : {true, false}) { + params.push_back({hasNulls, enableDictionary, filter, isDense}); + } + } + } + } + return params; + } + + void runFloatToDoubleScenario(const FloatToDoubleSpec& spec); +}; + +void FloatToDoubleEvolutionTest::runFloatToDoubleScenario( + const FloatToDoubleSpec& spec) { + ASSERT_EQ(spec.values.size(), spec.ids.size()); + const vector_size_t numRows = spec.ids.size(); + + auto floatVector = makeNullableFlatVector(spec.values); + auto idVector = + makeFlatVector(numRows, [&](auto row) { return spec.ids[row]; }); + + RowVectorPtr writeData = makeRowVector({floatVector, idVector}); + RowTypePtr writeSchema = ROW({"float_col", "id"}, {REAL(), BIGINT()}); + + auto sink = std::make_unique( + 1024 * 1024, dwio::common::FileSink::Options{.pool = leafPool_.get()}); + auto sinkPtr = sink.get(); + + parquet::WriterOptions writerOptions; + writerOptions.memoryPool = leafPool_.get(); + writerOptions.enableDictionary = spec.enableDictionary; + + auto writer = std::make_unique( + std::move(sink), writerOptions, rootPool_, writeSchema); + writer->write(writeData); + writer->close(); + + RowTypePtr readSchema = ROW({"float_col", "id"}, {DOUBLE(), BIGINT()}); + + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setFileSchema(readSchema); + + std::string dataBuf(sinkPtr->data(), sinkPtr->size()); + auto file = std::make_shared(std::move(dataBuf)); + auto buffer = std::make_unique( + file, readerOptions.memoryPool()); + auto reader = + std::make_unique(std::move(buffer), readerOptions); + + RowReaderOptions rowReaderOpts; + rowReaderOpts.select( + std::make_shared( + readSchema, readSchema->names())); + auto scanSpec = makeScanSpec(readSchema); + + // Apply IsNull or IsNotNull filter if specified + switch (spec.filter) { + case FloatToDoubleFilter::kNone: + break; + case FloatToDoubleFilter::kIsNull: { + auto* floatChild = + scanSpec->getOrCreateChild(common::Subfield("float_col")); + floatChild->setFilter(exec::isNull()); + break; + } + case FloatToDoubleFilter::kIsNotNull: { + auto* floatChild = + scanSpec->getOrCreateChild(common::Subfield("float_col")); + floatChild->setFilter(exec::isNotNull()); + break; + } + case FloatToDoubleFilter::kGreaterThanOrEqual: { + ASSERT_TRUE(spec.filterValue.has_value()); + auto* floatChild = + scanSpec->getOrCreateChild(common::Subfield("float_col")); + floatChild->setFilter( + exec::greaterThanOrEqualDouble(spec.filterValue.value())); + break; + } + case FloatToDoubleFilter::kMultiRange: { + ASSERT_TRUE(spec.filterLowerValue.has_value()); + ASSERT_TRUE(spec.filterUpperValue.has_value()); + auto* floatChild = + scanSpec->getOrCreateChild(common::Subfield("float_col")); + // Create a MultiRange filter: a < lower OR a > upper + floatChild->setFilter( + exec::orFilter( + exec::lessThanDouble(spec.filterLowerValue.value()), + exec::greaterThanDouble(spec.filterUpperValue.value()))); + break; + } + } + + rowReaderOpts.setScanSpec(scanSpec); + auto rowReader = reader->createRowReader(rowReaderOpts); + + std::vector deletedFlags(numRows, false); + for (auto index : spec.deletedRows) { + ASSERT_LT(index, numRows); + deletedFlags[index] = true; + } + + std::vector expectedIndices; + expectedIndices.reserve(numRows); + for (vector_size_t row = 0; row < numRows; ++row) { + if (deletedFlags[row]) { + continue; + } + + bool passes = false; + switch (spec.filter) { + case FloatToDoubleFilter::kNone: + passes = true; + break; + case FloatToDoubleFilter::kIsNull: + passes = !spec.values[row].has_value(); + break; + case FloatToDoubleFilter::kIsNotNull: + passes = spec.values[row].has_value(); + break; + case FloatToDoubleFilter::kGreaterThanOrEqual: + passes = spec.values[row].has_value() && + static_cast(*spec.values[row]) >= spec.filterValue.value(); + break; + case FloatToDoubleFilter::kMultiRange: + passes = spec.values[row].has_value() && + (static_cast(*spec.values[row]) < + spec.filterLowerValue.value() || + static_cast(*spec.values[row]) > + spec.filterUpperValue.value()); + break; + } + + if (passes) { + expectedIndices.push_back(row); + } + } + + std::vector> expectedDoubles(expectedIndices.size()); + for (size_t i = 0; i < expectedIndices.size(); ++i) { + const auto originalIndex = expectedIndices[i]; + if (!spec.values[originalIndex].has_value()) { + expectedDoubles[i] = std::nullopt; + } else { + expectedDoubles[i] = static_cast(*spec.values[originalIndex]); + } + } + + auto expectedFloat = makeNullableFlatVector(expectedDoubles); + auto expectedId = makeFlatVector( + expectedIndices.size(), + [&](auto row) { return spec.ids[expectedIndices[row]]; }); + RowVectorPtr expected = makeRowVector({expectedFloat, expectedId}); + + if (spec.deletedRows.empty() && spec.filter != FloatToDoubleFilter::kIsNull && + spec.filter != FloatToDoubleFilter::kIsNotNull && + spec.filter != FloatToDoubleFilter::kGreaterThanOrEqual && + spec.filter != FloatToDoubleFilter::kMultiRange) { + assertReadWithReaderAndExpected( + readSchema, *rowReader, expected, *leafPool_); + return; + } + + VectorPtr result = BaseVector::create(readSchema, 0, leafPool_.get()); + vector_size_t scanned = 0; + std::vector deleted(bits::nwords(numRows), 0); + if (spec.deletedRows.empty()) { + scanned = rowReader->next(numRows, result); + } else { + for (auto index : spec.deletedRows) { + bits::setBit(deleted.data(), index); + } + dwio::common::Mutation mutation; + mutation.deletedRows = deleted.data(); + scanned = rowReader->next(numRows, result, &mutation); + } + + EXPECT_GT(scanned, 0); + EXPECT_GE(scanned, expected->size()); + ASSERT_TRUE(result != nullptr); + auto rowVector = result->as(); + ASSERT_TRUE(rowVector != nullptr); + ASSERT_EQ(rowVector->size(), expected->size()); + assertEqualVectorPart(expected, result, 0); +} + +TEST_P(FloatToDoubleEvolutionTest, readFloatToDouble) { + const auto& param = GetParam(); + FloatToDoubleSpec spec; + constexpr vector_size_t kSize = 200; + spec.enableDictionary = param.enableDictionary; + spec.values.resize(kSize); + spec.ids.resize(kSize); + + for (vector_size_t row = 0; row < kSize; ++row) { + if (param.hasNulls && row % 5 == 0) { + spec.values[row] = std::nullopt; + } else { + // Use a value pattern that works for both dictionary and direct encoding + float val = + static_cast(row % 10) * 1.1f + static_cast(row) * 0.01f; + spec.values[row] = val; + } + spec.ids[row] = row; + } + + spec.filter = param.filter; + + // Set filter value for value-based filters + if (param.filter == FloatToDoubleFilter::kGreaterThanOrEqual) { + // Filter values greater than or equal to 5.0 (this should match + // approximately half the rows) + spec.filterValue = 5.0; + } else if (param.filter == FloatToDoubleFilter::kMultiRange) { + // Filter values < 3.0 OR > 7.0 + spec.filterLowerValue = 3.0; + spec.filterUpperValue = 7.0; + } + + if (!param.isDense) { + // Add some deleted rows scattered throughout + spec.deletedRows = {5, 20, 55, 99, 150, 199}; + } + + runFloatToDoubleScenario(spec); +} + +INSTANTIATE_TEST_SUITE_P( + FloatToDoubleEvolution, + FloatToDoubleEvolutionTest, + testing::ValuesIn(FloatToDoubleEvolutionTest::getTestParams()), + [](const testing::TestParamInfo& info) { + return info.param.toString(); + }); + +// Type widening tests: verify reading Parquet columns with a wider target +// type than the physical type stored in the file. + +TEST_F(ParquetReaderWideningTest, intToShortDecimalWidening) { + auto writeData = makeRowVector({makeFlatVector( + {0, 1, -1, 100, -100, 2'147'483'647, -2'147'483'648})}); + auto expected = makeRowVector({makeFlatVector( + {0, 100, -100, 10'000, -10'000, 214'748'364'700LL, -214'748'364'800LL}, + DECIMAL(12, 2))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(12, 2)}), expected); +} + +TEST_F(ParquetReaderWideningTest, smallintToShortDecimalWidening) { + auto writeData = makeRowVector( + {makeFlatVector({0, 1, -1, 100, 32'767, -32'768})}); + auto expected = makeRowVector({makeFlatVector( + {0, 100, -100, 10'000, 3'276'700, -3'276'800}, DECIMAL(12, 2))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(12, 2)}), expected); +} + +TEST_F(ParquetReaderWideningTest, tinyintToShortDecimalWidening) { + auto writeData = + makeRowVector({makeFlatVector({0, 1, -1, 100, 127, -128})}); + auto expected = makeRowVector({makeFlatVector( + {0, 100, -100, 10'000, 12'700, -12'800}, DECIMAL(12, 2))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(12, 2)}), expected); +} + +// Parquet stores TINYINT as INT32, so the minimum precision for decimal +// widening is precision-scale >= 10 (same as INT32). Test exact boundary. +TEST_F(ParquetReaderWideningTest, tinyintToDecimalMinPrecisionWidening) { + auto writeData = + makeRowVector({makeFlatVector({0, 1, -1, 127, -128})}); + auto expected = makeRowVector( + {makeFlatVector({0, 1, -1, 127, -128}, DECIMAL(10, 0))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(10, 0)}), expected); +} + +// Parquet stores SMALLINT as INT32, so decimal widening requires +// precision-scale >= 10 (same as INT32). Test exact boundary. +TEST_F(ParquetReaderWideningTest, smallintToDecimalMinPrecisionWidening) { + auto writeData = + makeRowVector({makeFlatVector({0, 1, -1, 32'767, -32'768})}); + auto expected = makeRowVector( + {makeFlatVector({0, 1, -1, 32'767, -32'768}, DECIMAL(10, 0))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(10, 0)}), expected); +} + +// Byte -> Long Decimal. Parquet stores TINYINT as INT32. +TEST_F(ParquetReaderWideningTest, tinyintToLongDecimalWidening) { + auto writeData = + makeRowVector({makeFlatVector({0, 1, -1, 127, -128})}); + auto expected = makeRowVector( + {makeFlatVector({0, 1, -1, 127, -128}, DECIMAL(20, 0))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(20, 0)}), expected); +} + +// Short -> Long Decimal. Parquet stores SMALLINT as INT32. +TEST_F(ParquetReaderWideningTest, smallintToLongDecimalWidening) { + auto writeData = + makeRowVector({makeFlatVector({0, 1, -1, 32'767, -32'768})}); + auto expected = makeRowVector( + {makeFlatVector({0, 1, -1, 32'767, -32'768}, DECIMAL(20, 0))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(20, 0)}), expected); +} + +TEST_F(ParquetReaderWideningTest, bigintToLongDecimalWidening) { + auto writeData = makeRowVector({makeFlatVector( + {0, 1, -1, 1'000'000'000'000LL, -1'000'000'000'000LL})}); + auto expected = makeRowVector({makeFlatVector( + {0, + 100'000, + -100'000, + static_cast(1'000'000'000'000LL) * 100'000, + static_cast(-1'000'000'000'000LL) * 100'000}, + DECIMAL(25, 5))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(25, 5)}), expected); +} + +TEST_F(ParquetReaderWideningTest, decimalToDecimalWidening) { + auto writeData = makeRowVector( + {makeFlatVector({1111, 2222, 3333, -4444, 0}, DECIMAL(7, 2))}); + // Each value v becomes v * 10^(4-2) = v * 100. + auto expected = makeRowVector({makeFlatVector( + {111'100, 222'200, 333'300, -444'400, 0}, DECIMAL(10, 4))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(10, 4)}), expected); +} + +TEST_F(ParquetReaderWideningTest, decimalToDecimalPrecisionOnlyWidening) { + auto writeData = makeRowVector( + {makeFlatVector({1111, 2222, 3333, -4444, 0}, DECIMAL(7, 2))}); + auto expected = makeRowVector( + {makeFlatVector({1111, 2222, 3333, -4444, 0}, DECIMAL(10, 2))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(10, 2)}), expected); +} + +// Decimal long -> long, precision-only widening (scale stays the same). +TEST_F(ParquetReaderWideningTest, decimalLongToLongPrecisionOnlyWidening) { + auto writeData = makeRowVector( + {makeFlatVector({1111, -2222, 0}, DECIMAL(20, 2))}); + auto expected = makeRowVector( + {makeFlatVector({1111, -2222, 0}, DECIMAL(22, 2))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(22, 2)}), expected); +} + +TEST_F(ParquetReaderWideningTest, decimalToDecimalWideningWithNulls) { + auto writeData = makeRowVector({makeNullableFlatVector( + {std::nullopt, 1111, std::nullopt, -4444, 0, std::nullopt}, + DECIMAL(7, 2))}); + // Each value v becomes v * 10^(4-2) = v * 100. + auto expected = makeRowVector({makeNullableFlatVector( + {std::nullopt, 111'100, std::nullopt, -444'400, 0, std::nullopt}, + DECIMAL(10, 4))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(10, 4)}), expected); +} + +TEST_F(ParquetReaderWideningTest, intToDecimalWideningWithNulls) { + auto writeData = makeRowVector({makeNullableFlatVector( + {std::nullopt, 42, std::nullopt, -7, 0, std::nullopt})}); + auto expected = makeRowVector({makeNullableFlatVector( + {std::nullopt, 4200, std::nullopt, -700, 0, std::nullopt}, + DECIMAL(12, 2))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(12, 2)}), expected); +} + +// All-null column: getDecimalValues must handle ConstantVector (all nulls) +// without crashing when scaleAdjust > 0. +TEST_F(ParquetReaderWideningTest, intToDecimalWideningAllNull) { + auto writeData = makeRowVector({makeNullableFlatVector( + {std::nullopt, std::nullopt, std::nullopt})}); + auto expected = makeRowVector({makeNullableFlatVector( + {std::nullopt, std::nullopt, std::nullopt}, DECIMAL(12, 2))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(12, 2)}), expected); +} + +TEST_F(ParquetReaderWideningTest, decimalToDecimalWideningAllNull) { + auto writeData = makeRowVector({makeNullableFlatVector( + {std::nullopt, std::nullopt, std::nullopt}, DECIMAL(7, 2))}); + auto expected = makeRowVector({makeNullableFlatVector( + {std::nullopt, std::nullopt, std::nullopt}, DECIMAL(10, 4))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(10, 4)}), expected); +} + +// INT32 -> SMALLINT narrowing. Parquet stores both as INT32; reading with +// a narrower type truncates to 16 bits via static_cast. +TEST_F(ParquetReaderWideningTest, intToSmallintNarrowing) { + auto writeData = + makeRowVector({makeFlatVector({0, 1, -1, 100, -100})}); + auto expected = + makeRowVector({makeFlatVector({0, 1, -1, 100, -100})}); + assertNarrowingReads(writeData, ROW({"col"}, {SMALLINT()}), expected); +} + +TEST_F(ParquetReaderWideningTest, smallintToTinyintNarrowing) { + auto writeData = + makeRowVector({makeFlatVector({0, 1, -1, 42, -42})}); + auto expected = makeRowVector({makeFlatVector({0, 1, -1, 42, -42})}); + assertNarrowingReads(writeData, ROW({"col"}, {TINYINT()}), expected); +} + +TEST_F(ParquetReaderWideningTest, shortToIntegerWidening) { + auto writeData = + makeRowVector({makeFlatVector({0, 1, -1, 32'767, -32'768})}); + auto expected = + makeRowVector({makeFlatVector({0, 1, -1, 32'767, -32'768})}); + assertWideningReads(writeData, ROW({"col"}, {INTEGER()}), expected); +} + +TEST_F(ParquetReaderWideningTest, intToBigintWidening) { + auto writeData = makeRowVector( + {makeFlatVector({0, 1, -1, 2'147'483'647, -2'147'483'648})}); + auto expected = makeRowVector( + {makeFlatVector({0, 1, -1, 2'147'483'647, -2'147'483'648})}); + assertWideningReads(writeData, ROW({"col"}, {BIGINT()}), expected); +} + +// INT32 -> SMALLINT overflow: 32768 truncates to -32768. +TEST_F(ParquetReaderWideningTest, intToSmallintOverflow) { + auto writeData = makeRowVector({makeFlatVector({32'768})}); + auto expected = makeRowVector({makeFlatVector({-32'768})}); + assertNarrowingReads(writeData, ROW({"col"}, {SMALLINT()}), expected); +} + +// INT16 -> TINYINT overflow: 128 truncates to -128. +TEST_F(ParquetReaderWideningTest, smallintToTinyintOverflow) { + auto writeData = makeRowVector({makeFlatVector({128})}); + auto expected = makeRowVector({makeFlatVector({-128})}); + assertNarrowingReads(writeData, ROW({"col"}, {TINYINT()}), expected); +} + +// INT32 -> DOUBLE works because sizeof(int32_t)=4 != sizeof(double)=8, +// so getFlatValues takes the upcastScalarValues path. +TEST_F(ParquetReaderWideningTest, intToDoubleWidening) { + auto writeData = makeRowVector({makeFlatVector( + {0, 1, -1, 100, -100, 2'147'483'647, -2'147'483'648})}); + auto expected = makeRowVector({makeFlatVector( + {0.0, 1.0, -1.0, 100.0, -100.0, 2'147'483'647.0, -2'147'483'648.0})}); + assertWideningReads(writeData, ROW({"col"}, {DOUBLE()}), expected); +} + +TEST_F(ParquetReaderWideningTest, intToDoubleWideningWithNulls) { + auto writeData = makeRowVector({makeNullableFlatVector( + {std::nullopt, 42, std::nullopt, -7, 0, std::nullopt})}); + auto expected = makeRowVector({makeNullableFlatVector( + {std::nullopt, 42.0, std::nullopt, -7.0, 0.0, std::nullopt})}); + assertWideningReads(writeData, ROW({"col"}, {DOUBLE()}), expected); +} + +// Byte/Short -> Double widening. Parquet stores both as INT32. +TEST_F(ParquetReaderWideningTest, tinyintToDoubleWidening) { + auto writeData = + makeRowVector({makeFlatVector({0, 1, -1, 127, -128})}); + auto expected = + makeRowVector({makeFlatVector({0.0, 1.0, -1.0, 127.0, -128.0})}); + assertWideningReads(writeData, ROW({"col"}, {DOUBLE()}), expected); +} + +TEST_F(ParquetReaderWideningTest, smallintToDoubleWidening) { + auto writeData = + makeRowVector({makeFlatVector({0, 1, -1, 32'767, -32'768})}); + auto expected = makeRowVector( + {makeFlatVector({0.0, 1.0, -1.0, 32'767.0, -32'768.0})}); + assertWideningReads(writeData, ROW({"col"}, {DOUBLE()}), expected); +} + +// INT -> Decimal with scale=0 (exact boundary: p-s=10 for INT32). +TEST_F(ParquetReaderWideningTest, intToDecimalScale0Widening) { + auto writeData = + makeRowVector({makeFlatVector({0, 42, -42, 2'147'483'647})}); + auto expected = makeRowVector( + {makeFlatVector({0, 42, -42, 2'147'483'647}, DECIMAL(10, 0))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(10, 0)}), expected); +} + +// INT -> Decimal with scale=1 (p-s=10, minimum boundary with nonzero scale). +TEST_F(ParquetReaderWideningTest, intToDecimalScale1Widening) { + auto writeData = makeRowVector({makeFlatVector({0, 5, -5, 100})}); + auto expected = makeRowVector( + {makeFlatVector({0, 50, -50, 1000}, DECIMAL(11, 1))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(11, 1)}), expected); +} + +// Byte -> Decimal with nonzero scale. Values multiplied by 10^scale. +TEST_F(ParquetReaderWideningTest, tinyintToDecimalWithScaleWidening) { + auto writeData = + makeRowVector({makeFlatVector({0, 1, -1, 127, -128})}); + auto expected = makeRowVector( + {makeFlatVector({0, 10, -10, 1'270, -1'280}, DECIMAL(11, 1))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(11, 1)}), expected); +} + +// Short -> Decimal with nonzero scale. Values multiplied by 10^scale. +TEST_F(ParquetReaderWideningTest, smallintToDecimalWithScaleWidening) { + auto writeData = + makeRowVector({makeFlatVector({0, 1, -1, 32'767, -32'768})}); + auto expected = makeRowVector({makeFlatVector( + {0, 10, -10, 327'670, -327'680}, DECIMAL(11, 1))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(11, 1)}), expected); +} + +// INT32 -> Long Decimal (crossing short/long boundary). +TEST_F(ParquetReaderWideningTest, intToLongDecimalWidening) { + auto writeData = makeRowVector( + {makeFlatVector({0, 1, -1, 2'147'483'647, -2'147'483'648})}); + auto expected = makeRowVector({makeFlatVector( + {0, 1, -1, 2'147'483'647, -2'147'483'648}, DECIMAL(20, 0))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(20, 0)}), expected); +} + +// BIGINT -> Decimal(38,0), maximum precision long decimal. +TEST_F(ParquetReaderWideningTest, bigintToMaxPrecisionDecimalWidening) { + auto writeData = makeRowVector( + {makeFlatVector({0, 1, -1, 9'223'372'036'854'775'807LL})}); + auto expected = makeRowVector({makeFlatVector( + {0, 1, -1, static_cast(9'223'372'036'854'775'807LL)}, + DECIMAL(38, 0))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(38, 0)}), expected); +} + +// BIGINT -> Decimal(21,1), INT64 with nonzero scale. +TEST_F(ParquetReaderWideningTest, bigintToDecimalWithScaleWidening) { + auto writeData = + makeRowVector({makeFlatVector({0, 1, -1, 999'999})}); + auto expected = makeRowVector( + {makeFlatVector({0, 10, -10, 9'999'990}, DECIMAL(21, 1))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(21, 1)}), expected); +} + +// Decimal short -> long decimal crossing: file stores int64 (short decimal), +// requested type is int128 (long decimal). getIntValues upcasts int64->int128. +TEST_F(ParquetReaderWideningTest, decimalShortToLongWidening) { + auto writeData = makeRowVector( + {makeFlatVector({1111, -2222, 0, 99999}, DECIMAL(5, 2))}); + auto expected = makeRowVector( + {makeFlatVector({1111, -2222, 0, 99999}, DECIMAL(20, 2))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(20, 2)}), expected); +} + +TEST_F(ParquetReaderWideningTest, decimalShortToLongWithScaleWidening) { + auto writeData = + makeRowVector({makeFlatVector({1111, -2222, 0}, DECIMAL(5, 2))}); + // v * 10^(4-2) = v * 100 + auto expected = makeRowVector( + {makeFlatVector({111'100, -222'200, 0}, DECIMAL(20, 4))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(20, 4)}), expected); +} + +// INT -> Decimal(38,0) max precision. +TEST_F(ParquetReaderWideningTest, smallintToMaxPrecisionDecimalWidening) { + auto writeData = makeRowVector({makeFlatVector({0, 1, -1, 32'767})}); + auto expected = makeRowVector( + {makeFlatVector({0, 1, -1, 32'767}, DECIMAL(38, 0))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(38, 0)}), expected); +} + +TEST_F(ParquetReaderWideningTest, intToMaxPrecisionDecimalWidening) { + auto writeData = + makeRowVector({makeFlatVector({0, 1, -1, 2'147'483'647})}); + auto expected = makeRowVector( + {makeFlatVector({0, 1, -1, 2'147'483'647}, DECIMAL(38, 0))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(38, 0)}), expected); +} + +// Decimal(5,2) -> (10,7) -- short to short with large scale increase. +TEST_F(ParquetReaderWideningTest, decimalWideningLargeScaleIncrease) { + auto writeData = + makeRowVector({makeFlatVector({1111, -2222, 0}, DECIMAL(5, 2))}); + // v * 10^(7-2) = v * 100000 + auto expected = makeRowVector({makeFlatVector( + {111'100'000, -222'200'000, 0}, DECIMAL(10, 7))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(10, 7)}), expected); +} + +// Decimal(5,2) -> (20,17) -- short to long with large scale increase. +TEST_F(ParquetReaderWideningTest, decimalShortToLongLargeScaleWidening) { + auto writeData = + makeRowVector({makeFlatVector({1111, -2222, 0}, DECIMAL(5, 2))}); + // v * 10^(17-2) = v * 10^15 + auto expected = makeRowVector({makeFlatVector( + {static_cast(1111) * 1'000'000'000'000'000LL, + static_cast(-2222) * 1'000'000'000'000'000LL, + 0}, + DECIMAL(20, 17))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(20, 17)}), expected); +} + +// Decimal(10,2) -> (12,4) -- short to short with precision and scale increase. +TEST_F(ParquetReaderWideningTest, decimalShortWideningPrecisionAndScale) { + auto writeData = makeRowVector( + {makeFlatVector({12345, -67890, 0}, DECIMAL(10, 2))}); + // v * 10^(4-2) = v * 100 + auto expected = makeRowVector( + {makeFlatVector({1'234'500, -6'789'000, 0}, DECIMAL(12, 4))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(12, 4)}), expected); +} + +// Decimal(10,2) -> (20,12) -- short to long with large scale increase. +TEST_F( + ParquetReaderWideningTest, + decimalShortToLongLargeScaleWideningHighPrecision) { + auto writeData = makeRowVector( + {makeFlatVector({12345, -67890, 0}, DECIMAL(10, 2))}); + // v * 10^(12-2) = v * 10^10 + auto expected = makeRowVector({makeFlatVector( + {static_cast(12345) * 10'000'000'000LL, + static_cast(-67890) * 10'000'000'000LL, + 0}, + DECIMAL(20, 12))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(20, 12)}), expected); +} + +// Decimal(20,2) -> (22,4) -- long to long with scale increase. +TEST_F(ParquetReaderWideningTest, decimalLongToLongWithScaleWidening) { + auto writeData = makeRowVector( + {makeFlatVector({12345, -67890, 0}, DECIMAL(20, 2))}); + // v * 10^(4-2) = v * 100 + auto expected = makeRowVector( + {makeFlatVector({1'234'500, -6'789'000, 0}, DECIMAL(22, 4))}); + assertWideningReads(writeData, ROW({"col"}, {DECIMAL(22, 4)}), expected); +} + +TEST_F(ParquetReaderWideningTest, typeWideningRejectionIncompatibleTypes) { + // INT32 -> FLOAT is not supported. FLOAT has only ~7 significant digits + // vs INT32's 10, which would cause silent precision loss. + assertWideningThrows( + makeRowVector({makeFlatVector({1, 2, 3})}), + ROW({"col"}, {REAL()}), + "INTEGER"); + + // BIGINT -> DOUBLE is not supported. + assertWideningThrows( + makeRowVector({makeFlatVector({1, 2, 3})}), + ROW({"col"}, {DOUBLE()}), + "BIGINT"); + + // BIGINT -> INTEGER/SMALLINT/TINYINT narrowing is not supported. + assertWideningThrows( + makeRowVector({makeFlatVector({1, 2, 3})}), + ROW({"col"}, {INTEGER()}), + "BIGINT"); + assertWideningThrows( + makeRowVector({makeFlatVector({1, 2, 3})}), + ROW({"col"}, {SMALLINT()}), + "BIGINT"); + assertWideningThrows( + makeRowVector({makeFlatVector({1, 2, 3})}), + ROW({"col"}, {TINYINT()}), + "BIGINT"); + + // DOUBLE -> FLOAT/INTEGER is not supported. + assertWideningThrows( + makeRowVector({makeFlatVector({1.0, 2.0, 3.0})}), + ROW({"col"}, {REAL()}), + "DOUBLE"); + assertWideningThrows( + makeRowVector({makeFlatVector({1.0, 2.0, 3.0})}), + ROW({"col"}, {INTEGER()}), + "DOUBLE"); + + // FLOAT -> INTEGER/BIGINT is not supported. + assertWideningThrows( + makeRowVector({makeFlatVector({1.0f, 2.0f, 3.0f})}), + ROW({"col"}, {INTEGER()}), + "REAL"); + assertWideningThrows( + makeRowVector({makeFlatVector({1.0f, 2.0f})}), + ROW({"col"}, {BIGINT()}), + "REAL"); + + // BIGINT -> FLOAT is not supported (precision loss). + assertWideningThrows( + makeRowVector({makeFlatVector({1, 2})}), + ROW({"col"}, {REAL()}), + "BIGINT"); +} + +TEST_F(ParquetReaderWideningTest, typeWideningRejectionIntDecimalPrecision) { + // INT32 -> DECIMAL(8,0). p-s=8 < 10, insufficient for INT32. + assertWideningThrows( + makeRowVector({makeFlatVector({1, 2, 3})}), + ROW({"col"}, {DECIMAL(8, 0)}), + "INTEGER"); + + // TINYINT -> DECIMAL(9,0). Parquet stores TINYINT as INT32, so the + // minimum precision is p-s >= 10. p-s=9 < 10. + assertWideningThrows( + makeRowVector({makeFlatVector({1, 2, 3})}), + ROW({"col"}, {DECIMAL(9, 0)}), + "TINYINT"); + + // SMALLINT -> DECIMAL(9,0). Same as TINYINT: stored as INT32, p-s=9 < 10. + assertWideningThrows( + makeRowVector({makeFlatVector({1, 2, 3})}), + ROW({"col"}, {DECIMAL(9, 0)}), + "SMALLINT"); + + // BIGINT -> DECIMAL(18,0). p-s=18 < 20, insufficient for INT64. + assertWideningThrows( + makeRowVector({makeFlatVector({1, 2, 3})}), + ROW({"col"}, {DECIMAL(18, 0)}), + "BIGINT"); + + // Exact boundary: INT32 -> DECIMAL(9,0), p-s=9 < 10. + assertWideningThrows( + makeRowVector({makeFlatVector({1, 2, 3})}), + ROW({"col"}, {DECIMAL(9, 0)}), + "INTEGER"); + + // Exact boundary: BIGINT -> DECIMAL(19,0), p-s=19 < 20. + assertWideningThrows( + makeRowVector({makeFlatVector({1, 2, 3})}), + ROW({"col"}, {DECIMAL(19, 0)}), + "BIGINT"); + + // Rejection with nonzero scale: p-s must still meet the threshold. + // TINYINT -> DECIMAL(3,1). p-s=2 < 10. + assertWideningThrows( + makeRowVector({makeFlatVector({1, 2})}), + ROW({"col"}, {DECIMAL(3, 1)}), + "TINYINT"); + + // INT32 -> DECIMAL(10,1). p-s=9 < 10. + assertWideningThrows( + makeRowVector({makeFlatVector({1, 2})}), + ROW({"col"}, {DECIMAL(10, 1)}), + "INTEGER"); + + // BIGINT -> DECIMAL(20,1). p-s=19 < 20. + assertWideningThrows( + makeRowVector({makeFlatVector({1, 2})}), + ROW({"col"}, {DECIMAL(20, 1)}), + "BIGINT"); + + // INT->Decimal with insufficient precision (various small precisions). + auto tinyintData = makeRowVector({makeFlatVector({1})}); + auto smallintData = makeRowVector({makeFlatVector({1})}); + auto intData = makeRowVector({makeFlatVector({1})}); + auto bigintData = makeRowVector({makeFlatVector({1})}); + assertWideningThrows(tinyintData, ROW({"col"}, {DECIMAL(1, 0)}), "TINYINT"); + assertWideningThrows(tinyintData, ROW({"col"}, {DECIMAL(2, 0)}), "TINYINT"); + assertWideningThrows(tinyintData, ROW({"col"}, {DECIMAL(3, 0)}), "TINYINT"); + assertWideningThrows(smallintData, ROW({"col"}, {DECIMAL(3, 0)}), "SMALLINT"); + assertWideningThrows(smallintData, ROW({"col"}, {DECIMAL(4, 0)}), "SMALLINT"); + assertWideningThrows(smallintData, ROW({"col"}, {DECIMAL(5, 0)}), "SMALLINT"); + assertWideningThrows(intData, ROW({"col"}, {DECIMAL(5, 0)}), "INTEGER"); + assertWideningThrows(bigintData, ROW({"col"}, {DECIMAL(10, 0)}), "BIGINT"); + + // INT->Decimal with nonzero scale, insufficient precision. + assertWideningThrows(tinyintData, ROW({"col"}, {DECIMAL(4, 1)}), "TINYINT"); + assertWideningThrows(smallintData, ROW({"col"}, {DECIMAL(6, 1)}), "SMALLINT"); + assertWideningThrows(smallintData, ROW({"col"}, {DECIMAL(5, 1)}), "SMALLINT"); +} + +TEST_F( + ParquetReaderWideningTest, + typeWideningRejectionDecimalPrecisionDecrease) { + // Decimal precision decrease. + assertWideningThrows( + makeRowVector({makeFlatVector({1111}, DECIMAL(10, 2))}), + ROW({"col"}, {DECIMAL(5, 2)}), + "DECIMAL(10, 2)"); + assertWideningThrows( + makeRowVector({makeFlatVector({1111}, DECIMAL(20, 2))}), + ROW({"col"}, {DECIMAL(5, 2)}), + "DECIMAL(20, 2)"); + assertWideningThrows( + makeRowVector({makeFlatVector({1111}, DECIMAL(12, 2))}), + ROW({"col"}, {DECIMAL(10, 2)}), + "DECIMAL(12, 2)"); + assertWideningThrows( + makeRowVector({makeFlatVector({1111}, DECIMAL(20, 2))}), + ROW({"col"}, {DECIMAL(10, 2)}), + "DECIMAL(20, 2)"); + assertWideningThrows( + makeRowVector({makeFlatVector({1111}, DECIMAL(22, 2))}), + ROW({"col"}, {DECIMAL(20, 2)}), + "DECIMAL(22, 2)"); + + // Decimal precision+scale decrease (precInc < 0). + assertWideningThrows( + makeRowVector({makeFlatVector({1111}, DECIMAL(10, 7))}), + ROW({"col"}, {DECIMAL(5, 2)}), + "DECIMAL(10, 7)"); + assertWideningThrows( + makeRowVector({makeFlatVector({1111}, DECIMAL(20, 17))}), + ROW({"col"}, {DECIMAL(5, 2)}), + "DECIMAL(20, 17)"); + assertWideningThrows( + makeRowVector({makeFlatVector({1111}, DECIMAL(12, 4))}), + ROW({"col"}, {DECIMAL(10, 2)}), + "DECIMAL(12, 4)"); + assertWideningThrows( + makeRowVector({makeFlatVector({1111}, DECIMAL(20, 17))}), + ROW({"col"}, {DECIMAL(10, 2)}), + "DECIMAL(20, 17)"); + assertWideningThrows( + makeRowVector({makeFlatVector({1111}, DECIMAL(22, 4))}), + ROW({"col"}, {DECIMAL(20, 2)}), + "DECIMAL(22, 4)"); +} + +TEST_F(ParquetReaderWideningTest, typeWideningRejectionDecimalScaleViolation) { + // DECIMAL(7,2) -> DECIMAL(8,5). scaleInc=3 > precInc=1. + assertWideningThrows( + makeRowVector({makeFlatVector({1111, 2222}, DECIMAL(7, 2))}), + ROW({"col"}, {DECIMAL(8, 5)}), + "DECIMAL(7, 2)"); + + // DECIMAL(7,2) -> DECIMAL(6,2). Precision decrease is not allowed. + assertWideningThrows( + makeRowVector({makeFlatVector({1111, 2222}, DECIMAL(7, 2))}), + ROW({"col"}, {DECIMAL(6, 2)}), + "DECIMAL(7, 2)"); + + // DECIMAL(7,4) -> DECIMAL(8,2). Scale narrowing (scaleInc < 0) is rejected. + assertWideningThrows( + makeRowVector({makeFlatVector({1111, 2222}, DECIMAL(7, 4))}), + ROW({"col"}, {DECIMAL(8, 2)}), + "DECIMAL(7, 4)"); + + // Scale decrease with precision increase (scaleInc < 0). + assertWideningThrows( + makeRowVector({makeFlatVector({1111}, DECIMAL(10, 6))}), + ROW({"col"}, {DECIMAL(12, 4)}), + "DECIMAL(10, 6)"); + assertWideningThrows( + makeRowVector({makeFlatVector({1111}, DECIMAL(20, 7))}), + ROW({"col"}, {DECIMAL(22, 5)}), + "DECIMAL(20, 7)"); + + // Precision decrease with scale increase. + assertWideningThrows( + makeRowVector({makeFlatVector({1111}, DECIMAL(12, 4))}), + ROW({"col"}, {DECIMAL(10, 6)}), + "DECIMAL(12, 4)"); + assertWideningThrows( + makeRowVector({makeFlatVector({1111}, DECIMAL(22, 5))}), + ROW({"col"}, {DECIMAL(20, 7)}), + "DECIMAL(22, 5)"); + + // scaleInc > precInc (not enough precision for the scale increase). + assertWideningThrows( + makeRowVector({makeFlatVector({1111}, DECIMAL(5, 2))}), + ROW({"col"}, {DECIMAL(6, 4)}), + "DECIMAL(5, 2)"); + assertWideningThrows( + makeRowVector({makeFlatVector({1111}, DECIMAL(10, 4))}), + ROW({"col"}, {DECIMAL(12, 7)}), + "DECIMAL(10, 4)"); + assertWideningThrows( + makeRowVector({makeFlatVector({1111}, DECIMAL(20, 5))}), + ROW({"col"}, {DECIMAL(22, 8)}), + "DECIMAL(20, 5)"); +} + +// Verify allowInt32Narrowing flag on ReaderOptions. +TEST_F(ParquetReaderWideningTest, allowInt32Narrowing) { + // Write INT32 data with values that exercise truncation edge cases. + auto data = makeRowVector( + {"c1"}, + {makeFlatVector( + {0, + 127, + 128, + 255, + 256, + 32767, + 32768, + 65535, + -1, + -128, + -129, + std::numeric_limits::min(), + std::numeric_limits::max()})}); + auto* sink = write(data); + + // Default: flag is false. + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + ASSERT_FALSE(readerOptions.allowInt32Narrowing()); + + // INT32->TINYINT narrowing rejected by default. + readerOptions.setFileSchema(ROW({"c1"}, {TINYINT()})); + VELOX_ASSERT_THROW( + createReaderInMemory(*sink, readerOptions), + "is not allowed for requested type"); + + // INT32->SMALLINT narrowing rejected by default. + readerOptions.setFileSchema(ROW({"c1"}, {SMALLINT()})); + VELOX_ASSERT_THROW( + createReaderInMemory(*sink, readerOptions), + "is not allowed for requested type"); + + // Annotated type-matching always works without the flag. + // INT_8 -> TINYINT: write as TINYINT (produces INT_8 annotation), read back. + { + auto readSchema = ROW({"c1"}, {TINYINT()}); + auto tinyData = + makeRowVector({"c1"}, {makeFlatVector({-128, -1, 0, 1, 127})}); + auto* tinySink = write(tinyData); + readerOptions.setFileSchema(readSchema); + auto reader = createReaderInMemory(*tinySink, readerOptions); + auto rowReaderOpts = getReaderOpts(readSchema); + rowReaderOpts.setScanSpec(makeScanSpec(readSchema)); + auto rowReader = reader->createRowReader(rowReaderOpts); + assertReadWithReaderAndExpected( + readSchema, *rowReader, tinyData, *leafPool_); + } + + // INT_16 -> SMALLINT: write as SMALLINT (produces INT_16 annotation), read + // back. + { + auto readSchema = ROW({"c1"}, {SMALLINT()}); + auto smallData = makeRowVector( + {"c1"}, {makeFlatVector({-32768, -1, 0, 1, 32767})}); + auto* smallSink = write(smallData); + readerOptions.setFileSchema(readSchema); + auto reader = createReaderInMemory(*smallSink, readerOptions); + auto rowReaderOpts = getReaderOpts(readSchema); + rowReaderOpts.setScanSpec(makeScanSpec(readSchema)); + auto rowReader = reader->createRowReader(rowReaderOpts); + assertReadWithReaderAndExpected( + readSchema, *rowReader, smallData, *leafPool_); + } + + // With flag enabled, narrowing is allowed with silent truncation. + readerOptions.setAllowInt32Narrowing(true); + + // INT32->TINYINT: values are truncated via static_cast. + { + auto readSchema = ROW({"c1"}, {TINYINT()}); + readerOptions.setFileSchema(readSchema); + auto reader = createReaderInMemory(*sink, readerOptions); + + auto rowReaderOpts = getReaderOpts(readSchema); + rowReaderOpts.setScanSpec(makeScanSpec(readSchema)); + auto rowReader = reader->createRowReader(rowReaderOpts); + + auto expected = makeRowVector( + {"c1"}, + {makeFlatVector( + {static_cast(0), + static_cast(127), + static_cast(128), // -128 + static_cast(255), // -1 + static_cast(256), // 0 + static_cast(32767), // -1 + static_cast(32768), // 0 + static_cast(65535), // -1 + static_cast(-1), // -1 + static_cast(-128), // -128 + static_cast(-129), // 127 + static_cast(std::numeric_limits::min()), + static_cast(std::numeric_limits::max())})}); + assertReadWithReaderAndExpected( + readSchema, *rowReader, expected, *leafPool_); + } + + // INT32->SMALLINT: values are truncated via static_cast. + { + auto readSchema = ROW({"c1"}, {SMALLINT()}); + readerOptions.setFileSchema(readSchema); + auto reader = createReaderInMemory(*sink, readerOptions); + + auto rowReaderOpts = getReaderOpts(readSchema); + rowReaderOpts.setScanSpec(makeScanSpec(readSchema)); + auto rowReader = reader->createRowReader(rowReaderOpts); + + auto expected = makeRowVector( + {"c1"}, + {makeFlatVector( + {static_cast(0), + static_cast(127), + static_cast(128), + static_cast(255), + static_cast(256), + static_cast(32767), + static_cast(32768), // -32768 + static_cast(65535), // -1 + static_cast(-1), // -1 + static_cast(-128), // -128 + static_cast(-129), // -129 + static_cast(std::numeric_limits::min()), + static_cast(std::numeric_limits::max())})}); + assertReadWithReaderAndExpected( + readSchema, *rowReader, expected, *leafPool_); + } +} + +// INT -> Integer widening + filter. +TEST_F(ParquetReaderWideningTest, tinyintToSmallintWideningWithFilter) { + auto writeData = + makeRowVector({"col"}, {makeFlatVector({-10, 0, 10, 50, 127})}); + auto expected = + makeRowVector({"col"}, {makeFlatVector({0, 10, 50, 127})}); + assertWideningWithFilter( + writeData, + ROW({"col"}, {SMALLINT()}), + exec::greaterThanOrEqual(0), + expected); +} + +TEST_F(ParquetReaderWideningTest, tinyintToIntegerWideningWithFilter) { + auto writeData = + makeRowVector({"col"}, {makeFlatVector({-10, 0, 10, 50, 127})}); + auto expected = + makeRowVector({"col"}, {makeFlatVector({0, 10, 50, 127})}); + assertWideningWithFilter( + writeData, + ROW({"col"}, {INTEGER()}), + exec::greaterThanOrEqual(0), + expected); +} + +TEST_F(ParquetReaderWideningTest, shortToIntegerWideningWithFilter) { + auto writeData = makeRowVector( + {"col"}, {makeFlatVector({-100, 0, 50, 100, 32'767})}); + auto expected = + makeRowVector({"col"}, {makeFlatVector({50, 100, 32'767})}); + assertWideningWithFilter( + writeData, + ROW({"col"}, {INTEGER()}), + exec::greaterThanOrEqual(50), + expected); +} + +TEST_F(ParquetReaderWideningTest, intToBigintWideningWithFilter) { + auto writeData = makeRowVector( + {"col"}, {makeFlatVector({-100, 0, 50, 100, 2'000'000})}); + auto expected = + makeRowVector({"col"}, {makeFlatVector({50, 100, 2'000'000})}); + assertWideningWithFilter( + writeData, + ROW({"col"}, {BIGINT()}), + exec::greaterThanOrEqual(50), + expected); +} + +// INT -> DOUBLE widening + filter. +TEST_F(ParquetReaderWideningTest, intToDoubleWideningWithFilter) { + auto writeData = + makeRowVector({"col"}, {makeFlatVector({-100, 0, 50, 100})}); + auto expected = + makeRowVector({"col"}, {makeFlatVector({50.0, 100.0})}); + // BigintRange filter. + assertWideningWithFilter( + writeData, + ROW({"col"}, {DOUBLE()}), + exec::greaterThanOrEqual(50), + expected); +} + +TEST_F(ParquetReaderWideningTest, tinyintToDoubleWideningWithFilter) { + auto writeData = + makeRowVector({"col"}, {makeFlatVector({-10, 0, 10, 50})}); + auto expected = + makeRowVector({"col"}, {makeFlatVector({0.0, 10.0, 50.0})}); + assertWideningWithFilter( + writeData, + ROW({"col"}, {DOUBLE()}), + exec::greaterThanOrEqual(0), + expected); +} + +// DoubleRange filter not yet supported for widened columns. See #16895. +TEST_F( + ParquetReaderWideningTest, + DISABLED_intToDoubleWideningWithDoubleRangeFilter) { + auto writeData = + makeRowVector({"col"}, {makeFlatVector({-100, 0, 50, 100})}); + assertWideningWithFilter( + writeData, + ROW({"col"}, {DOUBLE()}), + exec::greaterThanOrEqualDouble(50.0), + makeRowVector({"col"}, {makeFlatVector({50.0, 100.0})})); +} + +// INT -> Decimal widening + filter. +TEST_F(ParquetReaderWideningTest, intToDecimalWideningWithFilter) { + auto writeData = + makeRowVector({"col"}, {makeFlatVector({-100, 0, 50, 100})}); + // Scale 0. + assertWideningWithFilter( + writeData, + ROW({"col"}, {DECIMAL(10, 0)}), + exec::greaterThanOrEqual(50), + makeRowVector( + {"col"}, {makeFlatVector({50, 100}, DECIMAL(10, 0))})); +} + +// Scale > 0 not yet supported for filter pushdown with widening. See #16895. +TEST_F( + ParquetReaderWideningTest, + DISABLED_intToDecimalWithScaleWideningWithFilter) { + auto writeData = + makeRowVector({"col"}, {makeFlatVector({-100, 0, 50, 100})}); + assertWideningWithFilter( + writeData, + ROW({"col"}, {DECIMAL(12, 2)}), + exec::greaterThanOrEqual(5'000), + makeRowVector( + {"col"}, {makeFlatVector({5'000, 10'000}, DECIMAL(12, 2))})); +} + +// HugeintRange filter not yet supported for widened columns. See #16895. +TEST_F(ParquetReaderWideningTest, DISABLED_bigintToDecimalWideningWithFilter) { + auto writeData = + makeRowVector({"col"}, {makeFlatVector({-100, 0, 50, 100})}); + // BIGINT -> Decimal(25, 5). + assertWideningWithFilter( + writeData, + ROW({"col"}, {DECIMAL(25, 5)}), + exec::greaterThanOrEqualHugeint(int128_t(50) * 100'000), + makeRowVector( + {"col"}, + {makeFlatVector( + {int128_t(50) * 100'000, int128_t(100) * 100'000}, + DECIMAL(25, 5))})); + // BIGINT -> Decimal(38, 0). + assertWideningWithFilter( + writeData, + ROW({"col"}, {DECIMAL(38, 0)}), + exec::greaterThanOrEqualHugeint(int128_t(50)), + makeRowVector( + {"col"}, {makeFlatVector({50, 100}, DECIMAL(38, 0))})); + // BIGINT -> Decimal(21, 1). + assertWideningWithFilter( + writeData, + ROW({"col"}, {DECIMAL(21, 1)}), + exec::greaterThanOrEqualHugeint(int128_t(50) * 10), + makeRowVector( + {"col"}, + {makeFlatVector( + {int128_t(50) * 10, int128_t(100) * 10}, DECIMAL(21, 1))})); +} + +// Decimal -> Decimal (short->short) + filter. +TEST_F(ParquetReaderWideningTest, decimalShortToShortWideningWithFilter) { + auto writeData = makeRowVector( + {"col"}, + {makeFlatVector({-1'000, 0, 5'000, 10'000}, DECIMAL(7, 2))}); + // Scale unchanged: DECIMAL(7,2) -> DECIMAL(10,2). + assertWideningWithFilter( + writeData, + ROW({"col"}, {DECIMAL(10, 2)}), + exec::greaterThanOrEqual(5'000), + makeRowVector( + {"col"}, {makeFlatVector({5'000, 10'000}, DECIMAL(10, 2))})); +} + +// Scale changed: DECIMAL(7,2) -> DECIMAL(10,4). See #16895. +TEST_F( + ParquetReaderWideningTest, + DISABLED_decimalScaleChangeWideningWithFilter) { + auto writeData = makeRowVector( + {"col"}, + {makeFlatVector({-1'000, 0, 5'000, 10'000}, DECIMAL(7, 2))}); + assertWideningWithFilter( + writeData, + ROW({"col"}, {DECIMAL(10, 4)}), + exec::greaterThanOrEqual(500'000), + makeRowVector( + {"col"}, + {makeFlatVector({500'000, 1'000'000}, DECIMAL(10, 4))})); +} + +// Cases have different failure modes: same-scale fails on HugeintRange crash, +// scale-change fails on unscaled value mismatch. See #16895. +TEST_F( + ParquetReaderWideningTest, + DISABLED_decimalShortToLongWideningWithFilter) { + auto writeData = makeRowVector( + {"col"}, + {makeFlatVector({-1'000, 0, 5'000, 10'000}, DECIMAL(5, 2))}); + // Scale unchanged: DECIMAL(5,2) -> DECIMAL(20,2). + assertWideningWithFilter( + writeData, + ROW({"col"}, {DECIMAL(20, 2)}), + exec::greaterThanOrEqualHugeint(int128_t(5'000)), + makeRowVector( + {"col"}, + {makeFlatVector({5'000, 10'000}, DECIMAL(20, 2))})); + // Scale changed: DECIMAL(5,2) -> DECIMAL(20,4). + assertWideningWithFilter( + writeData, + ROW({"col"}, {DECIMAL(20, 4)}), + exec::greaterThanOrEqualHugeint(int128_t(500'000)), + makeRowVector( + {"col"}, + {makeFlatVector({500'000, 1'000'000}, DECIMAL(20, 4))})); + // Large scale increase: DECIMAL(5,2) -> DECIMAL(10,7). + assertWideningWithFilter( + writeData, + ROW({"col"}, {DECIMAL(10, 7)}), + exec::greaterThanOrEqual(500'000'000), + makeRowVector( + {"col"}, + {makeFlatVector( + {500'000'000, 1'000'000'000}, DECIMAL(10, 7))})); + // Precision and scale increase: DECIMAL(5,2) -> DECIMAL(12,4). + assertWideningWithFilter( + writeData, + ROW({"col"}, {DECIMAL(12, 4)}), + exec::greaterThanOrEqual(500'000), + makeRowVector( + {"col"}, + {makeFlatVector({500'000, 1'000'000}, DECIMAL(12, 4))})); +} + +// HugeintRange filter with scale change not yet supported. See #16895. +TEST_F( + ParquetReaderWideningTest, + DISABLED_decimalLongToLongWideningWithFilter) { + auto writeData = makeRowVector( + {"col"}, + {makeFlatVector({-1'000, 0, 5'000, 10'000}, DECIMAL(20, 2))}); + assertWideningWithFilter( + writeData, + ROW({"col"}, {DECIMAL(22, 4)}), + exec::greaterThanOrEqualHugeint(int128_t(500'000)), + makeRowVector( + {"col"}, + {makeFlatVector({500'000, 1'000'000}, DECIMAL(22, 4))})); +} + +// INT32 narrowing + filter. +TEST_F(ParquetReaderWideningTest, intNarrowingFilterBehavior) { + // INT32 -> TINYINT with filter x in [0, 127] (TINYINT range). + // INT32 value 200 fails filter (200 > 127), so it is filtered out. + auto writeData = + makeRowVector({"col"}, {makeFlatVector({-10, 0, 50, 127, 200})}); + auto expected = + makeRowVector({"col"}, {makeFlatVector({0, 50, 127})}); + assertWideningWithFilter( + writeData, + ROW({"col"}, {TINYINT()}), + std::make_unique(0, 127, /*nullAllowed=*/false), + expected, + /*allowInt32Narrowing=*/true); +} + +// Null filter tests. +TEST_F(ParquetReaderWideningTest, intToBigintWideningNullFilter) { + auto writeData = makeRowVector( + {"col"}, + {makeNullableFlatVector({0, std::nullopt, 50, std::nullopt})}); + assertWideningWithFilter( + writeData, + ROW({"col"}, {BIGINT()}), + exec::isNull(), + makeRowVector( + {"col"}, + {makeNullableFlatVector({std::nullopt, std::nullopt})})); + assertWideningWithFilter( + writeData, + ROW({"col"}, {BIGINT()}), + exec::isNotNull(), + makeRowVector({"col"}, {makeFlatVector({0, 50})})); +} + +TEST_F(ParquetReaderWideningTest, intToDoubleWideningNullFilter) { + auto writeData = makeRowVector( + {"col"}, + {makeNullableFlatVector({0, std::nullopt, 50, std::nullopt})}); + assertWideningWithFilter( + writeData, + ROW({"col"}, {DOUBLE()}), + exec::isNull(), + makeRowVector( + {"col"}, + {makeNullableFlatVector({std::nullopt, std::nullopt})})); + assertWideningWithFilter( + writeData, + ROW({"col"}, {DOUBLE()}), + exec::isNotNull(), + makeRowVector({"col"}, {makeFlatVector({0.0, 50.0})})); +} + +TEST_F(ParquetReaderWideningTest, intToDecimalScale0NullFilter) { + auto writeData = makeRowVector( + {"col"}, + {makeNullableFlatVector({0, std::nullopt, 50, std::nullopt})}); + assertWideningWithFilter( + writeData, + ROW({"col"}, {DECIMAL(10, 0)}), + exec::isNull(), + makeRowVector( + {"col"}, + {makeNullableFlatVector( + {std::nullopt, std::nullopt}, DECIMAL(10, 0))})); + assertWideningWithFilter( + writeData, + ROW({"col"}, {DECIMAL(10, 0)}), + exec::isNotNull(), + makeRowVector( + {"col"}, {makeFlatVector({0, 50}, DECIMAL(10, 0))})); +} + +TEST_F(ParquetReaderWideningTest, decimalPrecisionOnlyNullFilter) { + auto writeData = makeRowVector( + {"col"}, + {makeNullableFlatVector( + {1'000, std::nullopt, 5'000, std::nullopt}, DECIMAL(7, 2))}); + assertWideningWithFilter( + writeData, + ROW({"col"}, {DECIMAL(10, 2)}), + exec::isNull(), + makeRowVector( + {"col"}, + {makeNullableFlatVector( + {std::nullopt, std::nullopt}, DECIMAL(10, 2))})); + assertWideningWithFilter( + writeData, + ROW({"col"}, {DECIMAL(10, 2)}), + exec::isNotNull(), + makeRowVector( + {"col"}, {makeFlatVector({1'000, 5'000}, DECIMAL(10, 2))})); +} diff --git a/velox/dwio/parquet/tests/reader/ParquetTableScanTest.cpp b/velox/dwio/parquet/tests/reader/ParquetTableScanTest.cpp index 96c4a3a4051..786a1d2ddfc 100644 --- a/velox/dwio/parquet/tests/reader/ParquetTableScanTest.cpp +++ b/velox/dwio/parquet/tests/reader/ParquetTableScanTest.cpp @@ -17,6 +17,8 @@ #include #include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/io/IoStatistics.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/dwio/common/tests/utils/DataFiles.h" // @manual #include "velox/dwio/parquet/RegisterParquetReader.h" // @manual #include "velox/dwio/parquet/reader/PageReader.h" // @manual @@ -24,7 +26,6 @@ #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" // @manual #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/type/tests/SubfieldFiltersBuilder.h" #include "velox/type/tz/TimeZoneMap.h" @@ -37,6 +38,7 @@ using namespace facebook::velox::connector::hive; using namespace facebook::velox::exec::test; using namespace facebook::velox::parquet; using namespace facebook::velox::test; +using namespace facebook::velox::common::testutil; class ParquetTableScanTest : public HiveConnectorTestBase { protected: @@ -48,26 +50,29 @@ class ParquetTableScanTest : public HiveConnectorTestBase { } void assertSelect( + std::vector> splits, std::vector&& outputColumnNames, const std::string& sql) { auto rowType = getRowType(std::move(outputColumnNames)); auto plan = PlanBuilder().tableScan(rowType).planNode(); - assertQuery(plan, splits_, sql); + assertQuery(plan, splits, sql); } void assertSelectWithDataColumns( + std::vector> splits, std::vector&& outputColumnNames, const RowTypePtr& dataColumns, const std::string& sql) { auto rowType = getRowType(std::move(outputColumnNames)); auto plan = PlanBuilder().tableScan(rowType, {}, "", dataColumns).planNode(); - assertQuery(plan, splits_, sql); + assertQuery(plan, splits, sql); } void assertSelectWithAssignments( + std::vector> splits, std::vector&& outputColumnNames, const connector::ColumnHandleMap& assignments, const std::string& sql) { @@ -75,10 +80,11 @@ class ParquetTableScanTest : public HiveConnectorTestBase { auto plan = PlanBuilder() .tableScan(rowType, {}, "", nullptr, assignments) .planNode(); - assertQuery(plan, splits_, sql); + assertQuery(plan, splits, sql); } void assertSelectWithFilter( + std::vector> splits, std::vector&& outputColumnNames, const std::vector& subfieldFilters, const std::string& remainingFilter, @@ -99,12 +105,13 @@ class ParquetTableScanTest : public HiveConnectorTestBase { .connectorSessionProperty( kHiveConnectorId, HiveConfig::kReadTimestampUnitSession, - std::to_string(static_cast(timestampPrecision_))) - .splits(splits_) + std::to_string(static_cast(readTimestampPrecision_))) + .splits(splits) .assertResults(sql); } void assertSelectWithAgg( + std::vector> splits, std::vector&& outputColumnNames, const std::vector& aggregates, const std::vector& groupingKeys, @@ -116,10 +123,11 @@ class ParquetTableScanTest : public HiveConnectorTestBase { .singleAggregation(groupingKeys, aggregates) .planNode(); - assertQuery(plan, splits_, sql); + assertQuery(plan, splits, sql); } void assertSelectWithFilterAndAgg( + std::vector> splits, std::vector&& outputColumnNames, const std::vector& filters, const std::vector& aggregates, @@ -132,18 +140,19 @@ class ParquetTableScanTest : public HiveConnectorTestBase { .singleAggregation(groupingKeys, aggregates) .planNode(); - assertQuery(plan, splits_, sql); + assertQuery(plan, splits, sql); } void assertSelectWithTimezone( + std::vector> connectorSplits, std::vector&& outputColumnNames, const std::string& sql, const std::string& sessionTimezone) { auto rowType = getRowType(std::move(outputColumnNames)); auto plan = PlanBuilder().tableScan(rowType).planNode(); std::vector splits; - splits.reserve(splits_.size()); - for (const auto& connectorSplit : splits_) { + splits.reserve(connectorSplits.size()); + for (const auto& connectorSplit : connectorSplits) { splits.emplace_back(folly::copy(connectorSplit), -1); } @@ -153,24 +162,16 @@ class ParquetTableScanTest : public HiveConnectorTestBase { .assertResults(sql); } - void loadData( - const std::string& filePath, - RowTypePtr rowType, - RowVectorPtr data, - const std::optional< - std::unordered_map>>& - partitionKeys = std::nullopt, - const std::optional>& - infoColumns = std::nullopt) { - splits_ = {makeSplit(filePath, partitionKeys, infoColumns)}; + void loadData(RowTypePtr rowType, RowVectorPtr data) { rowType_ = rowType; createDuckDbTable({data}); } void loadDataWithRowType(const std::string& filePath, RowVectorPtr data) { - splits_ = {makeSplit(filePath)}; auto pool = facebook::velox::memory::memoryManager()->addLeafPool(); dwio::common::ReaderOptions readerOpts{pool.get()}; + readerOpts.setDataIoStats(dataIoStats_); + readerOpts.setMetadataIoStats(metadataIoStats_); auto reader = std::make_unique( std::make_unique( std::make_shared(filePath), readerOpts.memoryPool()), @@ -213,10 +214,6 @@ class ParquetTableScanTest : public HiveConnectorTestBase { rootPool_->addAggregateChild("ParquetTableScanTest.Writer"); options.memoryPool = childPool.get(); - if (options.parquetWriteTimestampUnit.has_value()) { - timestampPrecision_ = options.parquetWriteTimestampUnit.value(); - } - auto writer = std::make_unique( std::move(sink), options, asRowType(data[0]->type())); @@ -226,77 +223,102 @@ class ParquetTableScanTest : public HiveConnectorTestBase { writer->close(); } - void testTimestampRead(const WriterOptions& options) { - auto stringToTimestamp = [](std::string_view view) { - return util::fromTimestampString( - view.data(), - view.size(), - util::TimestampParseMode::kPrestoCast) - .thenOrThrow(folly::identity, [&](const Status& status) { - VELOX_USER_FAIL("{}", status.message()); - }); - }; - std::vector views = { - "2015-06-01 19:34:56.007", - "2015-06-02 19:34:56.12306", - "2001-02-03 03:34:06.056", - "1998-03-01 08:01:06.996669", - "2022-12-23 03:56:01", - "1980-01-24 00:23:07", - "1999-12-08 13:39:26.123456", - "2023-04-21 09:09:34.5", - "2000-09-12 22:36:29", - "2007-12-12 04:27:56.999", - }; - std::vector values; - values.reserve(views.size()); - for (auto view : views) { - values.emplace_back(stringToTimestamp(view)); - } - + void testTimestampRead( + const WriterOptions& options, + TimestampPrecision readTimestampPrecision) { + VELOX_CHECK(options.parquetWriteTimestampUnit.has_value()); + const auto [values, expectedValues] = timestampValues( + options.parquetWriteTimestampUnit.value(), readTimestampPrecision); auto vector = makeRowVector( {"t"}, { makeFlatVector(values), }); - auto schema = asRowType(vector->type()); auto file = TempFilePath::create(); writeToParquetFile(file->getPath(), {vector}, options); - loadData(file->getPath(), schema, vector); - - assertSelectWithFilter({"t"}, {}, "", "SELECT t from tmp"); + loadData( + asRowType(vector->type()), + makeRowVector( + {"t"}, + { + makeFlatVector(expectedValues), + })); + + readTimestampPrecision_ = readTimestampPrecision; + auto guard = folly::makeGuard( + [&] { readTimestampPrecision_ = kDefaultReadTimestampPrecision; }); + assertSelectWithFilter( + {makeSplit(file->getPath())}, {"t"}, {}, "", "SELECT t from tmp"); assertSelectWithFilter( + {makeSplit(file->getPath())}, {"t"}, {}, "t < TIMESTAMP '2000-09-12 22:36:29'", "SELECT t from tmp where t < TIMESTAMP '2000-09-12 22:36:29'"); assertSelectWithFilter( + {makeSplit(file->getPath())}, {"t"}, {}, "t <= TIMESTAMP '2000-09-12 22:36:29'", "SELECT t from tmp where t <= TIMESTAMP '2000-09-12 22:36:29'"); assertSelectWithFilter( + {makeSplit(file->getPath())}, {"t"}, {}, "t > TIMESTAMP '1980-01-24 00:23:07'", "SELECT t from tmp where t > TIMESTAMP '1980-01-24 00:23:07'"); assertSelectWithFilter( + {makeSplit(file->getPath())}, {"t"}, {}, "t >= TIMESTAMP '1980-01-24 00:23:07'", "SELECT t from tmp where t >= TIMESTAMP '1980-01-24 00:23:07'"); assertSelectWithFilter( + {makeSplit(file->getPath())}, {"t"}, {}, "t == TIMESTAMP '2022-12-23 03:56:01'", "SELECT t from tmp where t == TIMESTAMP '2022-12-23 03:56:01'"); assertSelectWithFilter( + {makeSplit(file->getPath())}, {"t"}, {}, "not(eq(t, TIMESTAMP '2000-09-12 22:36:29'))", "SELECT t from tmp where t != TIMESTAMP '2000-09-12 22:36:29'"); } + void testTimestampUtcRead( + const WriterOptions& options, + TimestampPrecision readTimestampPrecision) { + VELOX_CHECK(options.parquetWriteTimestampUnit.has_value()); + const auto [values, expectedValues] = timestampValues( + options.parquetWriteTimestampUnit.value(), readTimestampPrecision); + auto vector = makeRowVector( + {"t"}, + { + makeFlatVector(values, TIMESTAMP_UTC()), + }); + auto file = TempFilePath::create(); + writeToParquetFile(file->getPath(), {vector}, options); + + loadData( + ROW({"t"}, {TIMESTAMP_UTC()}), + makeRowVector( + {"t"}, + { + // Expect values are used for creating duckdb table, so keep + // using TIMESTAMP here. + makeFlatVector(expectedValues), + })); + + readTimestampPrecision_ = readTimestampPrecision; + auto guard = folly::makeGuard( + [&] { readTimestampPrecision_ = kDefaultReadTimestampPrecision; }); + + assertSelectWithFilter( + {makeSplit(file->getPath())}, {"t"}, {}, "", "SELECT t from tmp"); + } + private: RowTypePtr getRowType(std::vector&& outputColumnNames) const { std::vector types; @@ -307,14 +329,70 @@ class ParquetTableScanTest : public HiveConnectorTestBase { return ROW(std::move(outputColumnNames), std::move(types)); } + std::pair, std::vector> timestampValues( + TimestampPrecision writeTimestampPrecision, + TimestampPrecision readTimestampPrecision) { + auto stringToTimestamp = [](std::string_view view) { + return util::fromTimestampString( + view.data(), + view.size(), + util::TimestampParseMode::kPrestoCast) + .value(); + }; + std::vector views = { + "2015-06-01 19:34:56.007", + "2015-06-02 19:34:56.12306", + "2001-02-03 03:34:06.056", + "1998-03-01 08:01:06.996669", + "2022-12-23 03:56:01", + "1980-01-24 00:23:07", + "1999-12-08 13:39:26.123456", + "2023-04-21 09:09:34.5", + "2000-09-12 22:36:29", + "2007-12-12 04:27:56.999", + }; + + std::vector values; + std::vector expectedValues; + values.reserve(views.size()); + for (auto view : views) { + auto ts = stringToTimestamp(view); + values.emplace_back(ts); + if (readTimestampPrecision == TimestampPrecision::kMilliseconds) { + expectedValues.emplace_back(Timestamp::fromMillis(ts.toMillis())); + continue; + } + if (readTimestampPrecision == TimestampPrecision::kMicroseconds) { + if (writeTimestampPrecision == TimestampPrecision::kMilliseconds) { + expectedValues.emplace_back( + Timestamp::fromMicros(ts.toMillis() * 1'000)); + continue; + } + if (writeTimestampPrecision == TimestampPrecision::kMicroseconds) { + expectedValues.emplace_back(Timestamp::fromMicros(ts.toMicros())); + continue; + } + } + VELOX_NYI( + "Not implemented, read precision: {}, write precision: {}", + static_cast(readTimestampPrecision), + static_cast(writeTimestampPrecision)); + } + return {values, expectedValues}; + } + RowTypePtr rowType_; - std::vector> splits_; - TimestampPrecision timestampPrecision_ = TimestampPrecision::kMicroseconds; + const TimestampPrecision kDefaultReadTimestampPrecision = + TimestampPrecision::kMicroseconds; + TimestampPrecision readTimestampPrecision_ = kDefaultReadTimestampPrecision; + std::shared_ptr dataIoStats_ = + std::make_shared(); + std::shared_ptr metadataIoStats_ = + std::make_shared(); }; TEST_F(ParquetTableScanTest, basic) { loadData( - getExampleFilePath("sample.parquet"), ROW({"a", "b"}, {BIGINT(), DOUBLE()}), makeRowVector( {"a", "b"}, @@ -324,64 +402,111 @@ TEST_F(ParquetTableScanTest, basic) { })); // Plain select. - assertSelect({"a"}, "SELECT a FROM tmp"); - assertSelect({"b"}, "SELECT b FROM tmp"); - assertSelect({"a", "b"}, "SELECT a, b FROM tmp"); - assertSelect({"b", "a"}, "SELECT b, a FROM tmp"); + const auto filePath = getExampleFilePath("sample.parquet"); + assertSelect({makeSplit(filePath)}, {"a"}, "SELECT a FROM tmp"); + assertSelect({makeSplit(filePath)}, {"b"}, "SELECT b FROM tmp"); + assertSelect({makeSplit(filePath)}, {"a", "b"}, "SELECT a, b FROM tmp"); + assertSelect({makeSplit(filePath)}, {"b", "a"}, "SELECT b, a FROM tmp"); // With filters. - assertSelectWithFilter({"a"}, {"a < 3"}, "", "SELECT a FROM tmp WHERE a < 3"); assertSelectWithFilter( - {"a", "b"}, {"a < 3"}, "", "SELECT a, b FROM tmp WHERE a < 3"); + {makeSplit(filePath)}, + {"a"}, + {"a < 3"}, + "", + "SELECT a FROM tmp WHERE a < 3"); assertSelectWithFilter( - {"b", "a"}, {"a < 3"}, "", "SELECT b, a FROM tmp WHERE a < 3"); + {makeSplit(filePath)}, + {"a", "b"}, + {"a < 3"}, + "", + "SELECT a, b FROM tmp WHERE a < 3"); + assertSelectWithFilter( + {makeSplit(filePath)}, + {"b", "a"}, + {"a < 3"}, + "", + "SELECT b, a FROM tmp WHERE a < 3"); assertSelectWithFilter( - {"a", "b"}, {"a < 0"}, "", "SELECT a, b FROM tmp WHERE a < 0"); + {makeSplit(filePath)}, + {"a", "b"}, + {"a < 0"}, + "", + "SELECT a, b FROM tmp WHERE a < 0"); assertSelectWithFilter( - {"b"}, {"b < DOUBLE '2.0'"}, "", "SELECT b FROM tmp WHERE b < 2.0"); + {makeSplit(filePath)}, + {"b"}, + {"b < DOUBLE '2.0'"}, + "", + "SELECT b FROM tmp WHERE b < 2.0"); assertSelectWithFilter( + {makeSplit(filePath)}, {"a", "b"}, {"b >= DOUBLE '2.0'"}, "", "SELECT a, b FROM tmp WHERE b >= 2.0"); assertSelectWithFilter( + {makeSplit(filePath)}, {"b", "a"}, {"b <= DOUBLE '2.0'"}, "", "SELECT b, a FROM tmp WHERE b <= 2.0"); assertSelectWithFilter( + {makeSplit(filePath)}, {"a", "b"}, {"b < DOUBLE '0.0'"}, "", "SELECT a, b FROM tmp WHERE b < 0.0"); // With aggregations. - assertSelectWithAgg({"a"}, {"sum(a)"}, {}, "SELECT sum(a) FROM tmp"); - assertSelectWithAgg({"b"}, {"max(b)"}, {}, "SELECT max(b) FROM tmp"); assertSelectWithAgg( - {"a", "b"}, {"min(a)", "max(b)"}, {}, "SELECT min(a), max(b) FROM tmp"); + {makeSplit(filePath)}, {"a"}, {"sum(a)"}, {}, "SELECT sum(a) FROM tmp"); + assertSelectWithAgg( + {makeSplit(filePath)}, {"b"}, {"max(b)"}, {}, "SELECT max(b) FROM tmp"); + assertSelectWithAgg( + {makeSplit(filePath)}, + {"a", "b"}, + {"min(a)", "max(b)"}, + {}, + "SELECT min(a), max(b) FROM tmp"); assertSelectWithAgg( - {"b", "a"}, {"max(b)"}, {"a"}, "SELECT max(b), a FROM tmp GROUP BY a"); + {makeSplit(filePath)}, + {"b", "a"}, + {"max(b)"}, + {"a"}, + "SELECT max(b), a FROM tmp GROUP BY a"); assertSelectWithAgg( - {"a", "b"}, {"max(a)"}, {"b"}, "SELECT max(a), b FROM tmp GROUP BY b"); + {makeSplit(filePath)}, + {"a", "b"}, + {"max(a)"}, + {"b"}, + "SELECT max(a), b FROM tmp GROUP BY b"); // With filter and aggregation. assertSelectWithFilterAndAgg( - {"a"}, {"a < 3"}, {"sum(a)"}, {}, "SELECT sum(a) FROM tmp WHERE a < 3"); + {makeSplit(filePath)}, + {"a"}, + {"a < 3"}, + {"sum(a)"}, + {}, + "SELECT sum(a) FROM tmp WHERE a < 3"); assertSelectWithFilterAndAgg( + {makeSplit(filePath)}, {"a", "b"}, {"a < 3"}, {"sum(b)"}, {}, "SELECT sum(b) FROM tmp WHERE a < 3"); assertSelectWithFilterAndAgg( + {makeSplit(filePath)}, {"a", "b"}, {"a < 3"}, {"min(a)", "max(b)"}, {}, "SELECT min(a), max(b) FROM tmp WHERE a < 3"); assertSelectWithFilterAndAgg( + {makeSplit(filePath)}, {"b", "a"}, {"a < 3"}, {"max(b)"}, @@ -428,6 +553,39 @@ TEST_F(ParquetTableScanTest, aggregatePushdown) { assertEqualVectors(rows->childAt(1), valuesVector); } +TEST_F(ParquetTableScanTest, aggregatePushdownToSmallPages) { + const std::vector columnNames = {"a", "b", "c"}; + const auto expectedRowVector = makeRowVector( + {makeFlatVector({1, 2, 4}), + makeFlatVector({7, 9, 13})}); + const auto outputType = ROW(columnNames, {SMALLINT(), SMALLINT(), VARCHAR()}); + std::vector data; + for (auto row = 0; row < 10; ++row) { + data.emplace_back(makeRowVector( + columnNames, + { + makeFlatVector({static_cast(row % 5)}), + makeFlatVector({static_cast(row)}), + makeFlatVector({std::to_string(row)}), + })); + } + const auto filePath = TempFilePath::create(); + WriterOptions options; + options.dataPageSize = 1; + writeToParquetFile(filePath->getPath(), data, options); + const auto plan = + PlanBuilder(pool()) + .tableScan( + outputType, + {}, + "c <> '' AND a in (1::smallint, 2::smallint, 4::smallint)") + .singleAggregation({"a"}, {"sum(b) as s"}) + .planNode(); + AssertQueryBuilder(plan) + .split(makeSplit(filePath->getPath())) + .assertResults(expectedRowVector); +} + TEST_F(ParquetTableScanTest, countStar) { // sample.parquet holds two columns (a: BIGINT, b: DOUBLE) and // 20 rows. @@ -452,7 +610,6 @@ TEST_F(ParquetTableScanTest, decimalSubfieldFilter) { std::vector unscaledShortValues(20); std::iota(unscaledShortValues.begin(), unscaledShortValues.end(), 10001); loadData( - getExampleFilePath("decimal.parquet"), ROW({"a"}, {DECIMAL(5, 2)}), makeRowVector( {"a"}, @@ -460,17 +617,39 @@ TEST_F(ParquetTableScanTest, decimalSubfieldFilter) { makeFlatVector(unscaledShortValues, DECIMAL(5, 2)), })); + const auto filePath = getExampleFilePath("decimal.parquet"); assertSelectWithFilter( - {"a"}, {"a < 100.07"}, "", "SELECT a FROM tmp WHERE a < 100.07"); + {makeSplit(filePath)}, + {"a"}, + {"a < 100.07"}, + "", + "SELECT a FROM tmp WHERE a < 100.07"); assertSelectWithFilter( - {"a"}, {"a <= 100.07"}, "", "SELECT a FROM tmp WHERE a <= 100.07"); + {makeSplit(filePath)}, + {"a"}, + {"a <= 100.07"}, + "", + "SELECT a FROM tmp WHERE a <= 100.07"); assertSelectWithFilter( - {"a"}, {"a > 100.07"}, "", "SELECT a FROM tmp WHERE a > 100.07"); + {makeSplit(filePath)}, + {"a"}, + {"a > 100.07"}, + "", + "SELECT a FROM tmp WHERE a > 100.07"); assertSelectWithFilter( - {"a"}, {"a >= 100.07"}, "", "SELECT a FROM tmp WHERE a >= 100.07"); + {makeSplit(filePath)}, + {"a"}, + {"a >= 100.07"}, + "", + "SELECT a FROM tmp WHERE a >= 100.07"); assertSelectWithFilter( - {"a"}, {"a = 100.07"}, "", "SELECT a FROM tmp WHERE a = 100.07"); + {makeSplit(filePath)}, + {"a"}, + {"a = 100.07"}, + "", + "SELECT a FROM tmp WHERE a = 100.07"); assertSelectWithFilter( + {makeSplit(filePath)}, {"a"}, {"a BETWEEN 100.07 AND 100.12"}, "", @@ -478,11 +657,19 @@ TEST_F(ParquetTableScanTest, decimalSubfieldFilter) { VELOX_ASSERT_THROW( assertSelectWithFilter( - {"a"}, {"a < 1000.7"}, "", "SELECT a FROM tmp WHERE a < 1000.7"), + {makeSplit(filePath)}, + {"a"}, + {"a < 1000.7"}, + "", + "SELECT a FROM tmp WHERE a < 1000.7"), "Scalar function signature is not supported: lt(DECIMAL(5, 2), DECIMAL(5, 1))"); VELOX_ASSERT_THROW( assertSelectWithFilter( - {"a"}, {"a = 1000.7"}, "", "SELECT a FROM tmp WHERE a = 1000.7"), + {makeSplit(filePath)}, + {"a"}, + {"a = 1000.7"}, + "", + "SELECT a FROM tmp WHERE a = 1000.7"), "Scalar function signature is not supported: eq(DECIMAL(5, 2), DECIMAL(5, 1))"); } @@ -490,7 +677,6 @@ TEST_F(ParquetTableScanTest, map) { auto vector = makeMapVector({{{"name", "gluten"}}}); loadData( - getExampleFilePath("types.parquet"), ROW({"map"}, {MAP(VARCHAR(), VARCHAR())}), makeRowVector( {"map"}, @@ -498,40 +684,50 @@ TEST_F(ParquetTableScanTest, map) { vector, })); - assertSelectWithFilter({"map"}, {}, "", "SELECT map FROM tmp"); + assertSelectWithFilter( + {makeSplit(getExampleFilePath("types.parquet"))}, + {"map"}, + {}, + "", + "SELECT map FROM tmp"); } TEST_F(ParquetTableScanTest, nullMap) { - auto path = getExampleFilePath("null_map.parquet"); loadData( - path, ROW({"i", "c"}, {VARCHAR(), MAP(VARCHAR(), VARCHAR())}), makeRowVector( {"i", "c"}, {makeConstant("1", 1), makeNullableMapVector({std::nullopt})})); - assertSelectWithFilter({"i", "c"}, {}, "", "SELECT i, c FROM tmp"); + assertSelectWithFilter( + {makeSplit(getExampleFilePath("null_map.parquet"))}, + {"i", "c"}, + {}, + "", + "SELECT i, c FROM tmp"); } TEST_F(ParquetTableScanTest, singleRowStruct) { auto vector = makeArrayVector({{}}); loadData( - getExampleFilePath("single_row_struct.parquet"), ROW({"s"}, {ROW({"a", "b"}, {BIGINT(), BIGINT()})}), makeRowVector( {"s"}, { vector, })); - - assertSelectWithFilter({"s"}, {}, "", "SELECT (0, 1)"); + assertSelectWithFilter( + {makeSplit(getExampleFilePath("single_row_struct.parquet"))}, + {"s"}, + {}, + "", + "SELECT (0, 1)"); } TEST_F(ParquetTableScanTest, array) { auto vector = makeArrayVector({}); loadData( - getExampleFilePath("old_repeated_int.parquet"), ROW({"repeatedInt"}, {ARRAY(INTEGER())}), makeRowVector( {"repeatedInt"}, @@ -539,8 +735,79 @@ TEST_F(ParquetTableScanTest, array) { vector, })); + const auto filePath = getExampleFilePath("old_repeated_int.parquet"); assertSelectWithFilter( - {"repeatedInt"}, {}, "", "SELECT UNNEST(array[array[1,2,3]])"); + {makeSplit(filePath)}, {"repeatedInt"}, {}, "", "SELECT [1,2,3]"); + + // Set the requested type for unannotated array. + auto rowType = ROW({"repeatedInt"}, {ARRAY(INTEGER())}); + auto plan = PlanBuilder(pool_.get()) + .tableScan(rowType, {}, "", rowType, {}) + .planNode(); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .splits({makeSplit(filePath)}) + .assertResults("SELECT [1,2,3]"); + + // Throws when reading repeated values as scalar type. + rowType = ROW({"repeatedInt"}, {INTEGER()}); + plan = PlanBuilder(pool_.get()) + .tableScan(rowType, {}, "", rowType, {}) + .planNode(); + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan, duckDbQueryRunner_) + .splits({makeSplit(filePath)}) + .assertResults(""), + "Requested type must be array"); + + rowType = ROW({"mystring"}, {ARRAY(VARCHAR())}); + plan = PlanBuilder(pool_.get()) + .tableScan(rowType, {}, "", rowType, {}) + .planNode(); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .splits({makeSplit(getExampleFilePath("proto_repeated_string.parquet"))}) + .assertResults( + "SELECT UNNEST(array[array['hello', 'world'], array['good','bye'], array['one', 'two', 'three']])"); + + rowType = + ROW({"primitive", "myComplex"}, + {INTEGER(), + ARRAY( + ROW({"id", "repeatedMessage"}, + {INTEGER(), ARRAY(ROW({"someId"}, {INTEGER()}))}))}); + plan = PlanBuilder(pool_.get()) + .tableScan(rowType, {}, "", rowType, {}) + .planNode(); + + // Construct the expected vector. + auto someIdVector = makeArrayOfRowVector( + ROW({"someId"}, {INTEGER()}), + { + {variant::row({3})}, + {variant::row({6})}, + {variant::row({9})}, + }); + auto rowVector = makeRowVector( + {"id", "repeatedMessage"}, + { + makeFlatVector({1, 4, 7}), + someIdVector, + }); + auto expected = makeRowVector( + {"primitive", "myComplex"}, + { + makeFlatVector({2, 5, 8}), + makeArrayVector({0, 1, 2}, rowVector), + }); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .connectorSessionProperty( + kHiveConnectorId, + connector::hive::HiveConfig::kParquetUseColumnNamesSession, + "true") + .splits({makeSplit(getExampleFilePath("nested_array_struct.parquet"))}) + .assertResults(expected); } // Optional array with required elements. @@ -548,7 +815,6 @@ TEST_F(ParquetTableScanTest, optArrayReqEle) { auto vector = makeArrayVector({}); loadData( - getExampleFilePath("array_0.parquet"), ROW({"_1"}, {ARRAY(VARCHAR())}), makeRowVector( {"_1"}, @@ -557,6 +823,7 @@ TEST_F(ParquetTableScanTest, optArrayReqEle) { })); assertSelectWithFilter( + {makeSplit(getExampleFilePath("array_0.parquet"))}, {"_1"}, {}, "", @@ -568,7 +835,6 @@ TEST_F(ParquetTableScanTest, reqArrayReqEle) { auto vector = makeArrayVector({}); loadData( - getExampleFilePath("array_1.parquet"), ROW({"_1"}, {ARRAY(VARCHAR())}), makeRowVector( {"_1"}, @@ -577,6 +843,7 @@ TEST_F(ParquetTableScanTest, reqArrayReqEle) { })); assertSelectWithFilter( + {makeSplit(getExampleFilePath("array_1.parquet"))}, {"_1"}, {}, "", @@ -588,7 +855,6 @@ TEST_F(ParquetTableScanTest, reqArrayOptEle) { auto vector = makeArrayVector({}); loadData( - getExampleFilePath("array_2.parquet"), ROW({"_1"}, {ARRAY(VARCHAR())}), makeRowVector( {"_1"}, @@ -597,6 +863,7 @@ TEST_F(ParquetTableScanTest, reqArrayOptEle) { })); assertSelectWithFilter( + {makeSplit(getExampleFilePath("array_2.parquet"))}, {"_1"}, {}, "", @@ -615,6 +882,7 @@ TEST_F(ParquetTableScanTest, arrayOfArrayTest) { })); assertSelectWithFilter( + {makeSplit(getExampleFilePath("array_of_array1.parquet"))}, {"_1"}, {}, "", @@ -626,7 +894,6 @@ TEST_F(ParquetTableScanTest, reqArrayLegacy) { auto vector = makeArrayVector({}); loadData( - getExampleFilePath("array_3.parquet"), ROW({"element"}, {ARRAY(VARCHAR())}), makeRowVector( {"element"}, @@ -635,6 +902,7 @@ TEST_F(ParquetTableScanTest, reqArrayLegacy) { })); assertSelectWithFilter( + {makeSplit(getExampleFilePath("array_3.parquet"))}, {"element"}, {}, "", @@ -643,7 +911,6 @@ TEST_F(ParquetTableScanTest, reqArrayLegacy) { TEST_F(ParquetTableScanTest, filterOnNestedArray) { loadData( - getExampleFilePath("struct_of_array.parquet"), ROW({"struct"}, {ROW({"a0", "a1"}, {ARRAY(VARCHAR()), ARRAY(INTEGER())})}), makeRowVector( @@ -653,7 +920,11 @@ TEST_F(ParquetTableScanTest, filterOnNestedArray) { })); assertSelectWithFilter( - {"struct"}, {}, "struct.a0 is null", "SELECT ROW(NULL, NULL)"); + {makeSplit(getExampleFilePath("struct_of_array.parquet"))}, + {"struct"}, + {}, + "struct.a0 is null", + "SELECT ROW(NULL, NULL)"); } TEST_F(ParquetTableScanTest, readAsLowerCase) { @@ -669,16 +940,19 @@ TEST_F(ParquetTableScanTest, readAsLowerCase) { createDuckDbTable(vectors); auto plan = PlanBuilder().tableScan(ROW({"a"}, {BIGINT()})).planNode(); - auto split = makeSplit(filePath->getPath()); AssertQueryBuilder(plan, duckDbQueryRunner_) .connectorSessionProperty( kHiveConnectorId, connector::hive::HiveConfig::kFileColumnNamesReadAsLowerCaseSession, "true") - .split(split) + .split(makeSplit(filePath->getPath())) .assertResults("SELECT A FROM tmp"); + // Wait for all tasks to be deleted to avoid race condition between async IO + // preloading and TempFilePath destruction. + waitForAllTasksToBeDeleted(); + // Test reading table with non-ascii names. auto vectorsNonAsciiNames = {makeRowVector( {"Товары", "国Ⅵ", "\uFF21", "\uFF22"}, @@ -688,8 +962,9 @@ TEST_F(ParquetTableScanTest, readAsLowerCase) { makeFlatVector(20, [](auto row) { return row + 1; }), makeFlatVector(20, [](auto row) { return row + 1; }), })}; - filePath = TempFilePath::create(); - writeToParquetFile(filePath->getPath(), vectorsNonAsciiNames, options); + + auto filePath2 = TempFilePath::create(); + writeToParquetFile(filePath2->getPath(), vectorsNonAsciiNames, options); createDuckDbTable(vectorsNonAsciiNames); plan = PlanBuilder() @@ -697,15 +972,16 @@ TEST_F(ParquetTableScanTest, readAsLowerCase) { ROW({"товары", "国ⅵ", "\uFF41", "\uFF42"}, {BIGINT(), DOUBLE(), REAL(), INTEGER()})) .planNode(); - split = makeSplit(filePath->getPath()); AssertQueryBuilder(plan, duckDbQueryRunner_) .connectorSessionProperty( kHiveConnectorId, connector::hive::HiveConfig::kFileColumnNamesReadAsLowerCaseSession, "true") - .split(split) + .split(makeSplit(filePath2->getPath())) .assertResults("SELECT * FROM tmp"); + + waitForAllTasksToBeDeleted(); } TEST_F(ParquetTableScanTest, rowIndex) { @@ -713,7 +989,6 @@ TEST_F(ParquetTableScanTest, rowIndex) { // case 1: file not have `_tmp_metadata_row_index`, scan generate it for user. auto filePath = getExampleFilePath("sample.parquet"); loadData( - filePath, ROW({"a", "b", "_tmp_metadata_row_index", kPath}, {BIGINT(), DOUBLE(), BIGINT(), VARCHAR()}), makeRowVector( @@ -724,42 +999,62 @@ TEST_F(ParquetTableScanTest, rowIndex) { makeFlatVector(20, [](auto row) { return row; }), makeFlatVector( 20, [filePath](auto row) { return filePath; }), - }), - std::nullopt, - std::unordered_map{{kPath, filePath}}); + })); connector::ColumnHandleMap assignments; assignments["a"] = std::make_shared( "a", - connector::hive::HiveColumnHandle::ColumnType::kRegular, + connector::hive::FileColumnHandle::ColumnType::kRegular, BIGINT(), BIGINT()); assignments["b"] = std::make_shared( "b", - connector::hive::HiveColumnHandle::ColumnType::kRegular, + connector::hive::FileColumnHandle::ColumnType::kRegular, DOUBLE(), DOUBLE()); assignments[kPath] = synthesizedColumn(kPath, VARCHAR()); assignments["_tmp_metadata_row_index"] = std::make_shared( "_tmp_metadata_row_index", - connector::hive::HiveColumnHandle::ColumnType::kRowIndex, + connector::hive::FileColumnHandle::ColumnType::kRowIndex, BIGINT(), BIGINT()); - assertSelect({"a"}, "SELECT a FROM tmp"); + assertSelect( + {makeSplit( + filePath, + std::nullopt, + std::unordered_map{{kPath, filePath}})}, + {"a"}, + "SELECT a FROM tmp"); assertSelectWithAssignments( + {makeSplit( + filePath, + std::nullopt, + std::unordered_map{{kPath, filePath}})}, {"a", "_tmp_metadata_row_index"}, assignments, "SELECT a, _tmp_metadata_row_index FROM tmp"); assertSelectWithAssignments( + {makeSplit( + filePath, + std::nullopt, + std::unordered_map{{kPath, filePath}})}, {"_tmp_metadata_row_index", "a"}, assignments, "SELECT _tmp_metadata_row_index, a FROM tmp"); assertSelectWithAssignments( + {makeSplit( + filePath, + std::nullopt, + std::unordered_map{{kPath, filePath}})}, {"_tmp_metadata_row_index"}, assignments, "SELECT _tmp_metadata_row_index FROM tmp"); assertSelectWithAssignments( + {makeSplit( + filePath, + std::nullopt, + std::unordered_map{{kPath, filePath}})}, {kPath, "_tmp_metadata_row_index"}, assignments, fmt::format("SELECT {}, _tmp_metadata_row_index FROM tmp", kPath)); @@ -767,7 +1062,6 @@ TEST_F(ParquetTableScanTest, rowIndex) { // case 2: file has `_tmp_metadata_row_index` column, then use user data // insteads of generating it. loadData( - getExampleFilePath("sample_with_rowindex.parquet"), ROW({"a", "b", "_tmp_metadata_row_index"}, {BIGINT(), DOUBLE(), BIGINT()}), makeRowVector( @@ -778,8 +1072,10 @@ TEST_F(ParquetTableScanTest, rowIndex) { makeFlatVector(20, [](auto row) { return row + 1; }), })); - assertSelect({"a"}, "SELECT a FROM tmp"); + filePath = getExampleFilePath("sample_with_rowindex.parquet"); + assertSelect({makeSplit(filePath)}, {"a"}, "SELECT a FROM tmp"); assertSelect( + {makeSplit(filePath)}, {"a", "_tmp_metadata_row_index"}, "SELECT a, _tmp_metadata_row_index FROM tmp"); } @@ -799,27 +1095,30 @@ TEST_F(ParquetTableScanTest, rowIndex) { // VALUES (1, 1), (2, null),(3, null); TEST_F(ParquetTableScanTest, filterNullIcebergPartition) { loadData( - getExampleFilePath("icebergNullIcebergPartition.parquet"), ROW({"c0", "c1"}, {BIGINT(), BIGINT()}), makeRowVector( {"c0", "c1"}, { makeFlatVector(std::vector{2, 3}), makeNullableFlatVector({std::nullopt, std::nullopt}), - }), - std::unordered_map>{ - {"c1", std::nullopt}}); + })); std::shared_ptr c0 = makeColumnHandle( - "c0", BIGINT(), BIGINT(), {}, HiveColumnHandle::ColumnType::kRegular); + "c0", BIGINT(), BIGINT(), {}, FileColumnHandle::ColumnType::kRegular); std::shared_ptr c1 = makeColumnHandle( "c1", BIGINT(), BIGINT(), {}, - HiveColumnHandle::ColumnType::kPartitionKey); + FileColumnHandle::ColumnType::kPartitionKey); + const auto filePath = + getExampleFilePath("icebergNullIcebergPartition.parquet"); assertSelectWithFilter( + {makeSplit( + filePath, + std::unordered_map>{ + {"c1", std::nullopt}})}, {"c0", "c1"}, {"c1 IS NOT NULL"}, "", @@ -827,6 +1126,10 @@ TEST_F(ParquetTableScanTest, filterNullIcebergPartition) { connector::ColumnHandleMap{{"c0", c0}, {"c1", c1}}); assertSelectWithFilter( + {makeSplit( + filePath, + std::unordered_map>{ + {"c1", std::nullopt}})}, {"c0", "c1"}, {"c1 IS NULL"}, "", @@ -844,7 +1147,6 @@ TEST_F(ParquetTableScanTest, sessionTimezone) { // Read sample.parquet to verify if the sessionTimezone in the PageReader // meets expectations. loadData( - getExampleFilePath("sample.parquet"), ROW({"a", "b"}, {BIGINT(), DOUBLE()}), makeRowVector( {"a", "b"}, @@ -853,48 +1155,90 @@ TEST_F(ParquetTableScanTest, sessionTimezone) { makeFlatVector(20, [](auto row) { return row + 1; }), })); - assertSelectWithTimezone({"a"}, "SELECT a FROM tmp", "Asia/Shanghai"); + assertSelectWithTimezone( + {makeSplit(getExampleFilePath("sample.parquet"))}, + {"a"}, + "SELECT a FROM tmp", + "Asia/Shanghai"); } -TEST_F(ParquetTableScanTest, timestampInt64Dictionary) { +TEST_F(ParquetTableScanTest, timestampInt64DictionaryMicro) { WriterOptions options; options.writeInt96AsTimestamp = false; options.enableDictionary = true; options.parquetWriteTimestampUnit = TimestampPrecision::kMicroseconds; - testTimestampRead(options); + testTimestampRead(options, TimestampPrecision::kMicroseconds); + testTimestampRead(options, TimestampPrecision::kMilliseconds); +} + +TEST_F(ParquetTableScanTest, timestampInt64DictionaryMilli) { + WriterOptions options; + options.writeInt96AsTimestamp = false; + options.enableDictionary = true; + options.parquetWriteTimestampUnit = TimestampPrecision::kMilliseconds; + testTimestampRead(options, TimestampPrecision::kMicroseconds); + testTimestampRead(options, TimestampPrecision::kMilliseconds); } -TEST_F(ParquetTableScanTest, timestampInt64Plain) { +TEST_F(ParquetTableScanTest, timestampInt64PlainMicro) { WriterOptions options; options.writeInt96AsTimestamp = false; options.enableDictionary = false; options.parquetWriteTimestampUnit = TimestampPrecision::kMicroseconds; - testTimestampRead(options); + testTimestampRead(options, TimestampPrecision::kMicroseconds); + testTimestampRead(options, TimestampPrecision::kMilliseconds); } -TEST_F(ParquetTableScanTest, timestampInt96Dictionary) { +TEST_F(ParquetTableScanTest, timestampInt64PlainMilli) { + WriterOptions options; + options.writeInt96AsTimestamp = false; + options.enableDictionary = false; + options.parquetWriteTimestampUnit = TimestampPrecision::kMilliseconds; + testTimestampRead(options, TimestampPrecision::kMicroseconds); + testTimestampRead(options, TimestampPrecision::kMilliseconds); +} + +TEST_F(ParquetTableScanTest, timestampInt96DictionaryMicro) { WriterOptions options; options.writeInt96AsTimestamp = true; options.enableDictionary = true; options.parquetWriteTimestampUnit = TimestampPrecision::kMicroseconds; - testTimestampRead(options); + testTimestampRead(options, TimestampPrecision::kMicroseconds); + testTimestampRead(options, TimestampPrecision::kMilliseconds); +} + +TEST_F(ParquetTableScanTest, timestampInt96DictionaryMilli) { + WriterOptions options; + options.writeInt96AsTimestamp = true; + options.enableDictionary = true; + options.parquetWriteTimestampUnit = TimestampPrecision::kMilliseconds; + testTimestampRead(options, TimestampPrecision::kMicroseconds); + testTimestampRead(options, TimestampPrecision::kMilliseconds); } -TEST_F(ParquetTableScanTest, timestampInt96Plain) { +TEST_F(ParquetTableScanTest, timestampInt96PlainMicro) { WriterOptions options; options.writeInt96AsTimestamp = true; options.enableDictionary = false; options.parquetWriteTimestampUnit = TimestampPrecision::kMicroseconds; - testTimestampRead(options); + testTimestampRead(options, TimestampPrecision::kMicroseconds); + testTimestampRead(options, TimestampPrecision::kMilliseconds); +} + +TEST_F(ParquetTableScanTest, timestampInt96PlainMilli) { + WriterOptions options; + options.writeInt96AsTimestamp = true; + options.enableDictionary = false; + options.parquetWriteTimestampUnit = TimestampPrecision::kMilliseconds; + testTimestampRead(options, TimestampPrecision::kMicroseconds); + testTimestampRead(options, TimestampPrecision::kMilliseconds); } TEST_F(ParquetTableScanTest, timestampConvertedType) { auto stringToTimestamp = [](std::string_view view) { return util::fromTimestampString( view.data(), view.size(), util::TimestampParseMode::kPrestoCast) - .thenOrThrow(folly::identity, [&](const Status& status) { - VELOX_USER_FAIL("{}", status.message()); - }); + .value(); }; std::vector expected = { "1970-01-01 00:00:00.010", @@ -913,10 +1257,14 @@ TEST_F(ParquetTableScanTest, timestampConvertedType) { makeFlatVector(values), }); const auto schema = asRowType(vector->type()); - const auto path = getExampleFilePath("tmmillis_i64.parquet"); - loadData(path, schema, vector); + loadData(schema, vector); - assertSelectWithFilter({"time"}, {}, "", "SELECT time from tmp"); + assertSelectWithFilter( + {makeSplit(getExampleFilePath("tmmillis_i64.parquet"))}, + {"time"}, + {}, + "", + "SELECT time from tmp"); } TEST_F(ParquetTableScanTest, timestampPrecisionMicrosecond) { @@ -950,6 +1298,42 @@ TEST_F(ParquetTableScanTest, timestampPrecisionMicrosecond) { } } +TEST_F(ParquetTableScanTest, timestampUtcPlainMicro) { + parquet::WriterOptions options; + options.writeInt96AsTimestamp = false; + options.enableDictionary = false; + options.parquetWriteTimestampUnit = TimestampPrecision::kMicroseconds; + testTimestampUtcRead(options, TimestampPrecision::kMicroseconds); + testTimestampUtcRead(options, TimestampPrecision::kMilliseconds); +} + +TEST_F(ParquetTableScanTest, timestampUtcDictionaryMicro) { + parquet::WriterOptions options; + options.writeInt96AsTimestamp = false; + options.enableDictionary = true; + options.parquetWriteTimestampUnit = TimestampPrecision::kMicroseconds; + testTimestampUtcRead(options, TimestampPrecision::kMicroseconds); + testTimestampUtcRead(options, TimestampPrecision::kMilliseconds); +} + +TEST_F(ParquetTableScanTest, timestampUtcPlainMilli) { + parquet::WriterOptions options; + options.writeInt96AsTimestamp = false; + options.enableDictionary = false; + options.parquetWriteTimestampUnit = TimestampPrecision::kMilliseconds; + testTimestampUtcRead(options, TimestampPrecision::kMicroseconds); + testTimestampUtcRead(options, TimestampPrecision::kMilliseconds); +} + +TEST_F(ParquetTableScanTest, timestampUtcDictionaryMilli) { + parquet::WriterOptions options; + options.writeInt96AsTimestamp = false; + options.enableDictionary = true; + options.parquetWriteTimestampUnit = TimestampPrecision::kMilliseconds; + testTimestampUtcRead(options, TimestampPrecision::kMicroseconds); + testTimestampUtcRead(options, TimestampPrecision::kMilliseconds); +} + TEST_F(ParquetTableScanTest, testColumnNotExists) { auto rowType = ROW({"a", "b", "not_exists", "not_exists_array", "not_exists_map"}, @@ -963,7 +1347,6 @@ TEST_F(ParquetTableScanTest, testColumnNotExists) { // optional double b; // } loadData( - getExampleFilePath("sample.parquet"), rowType, makeRowVector( {"a", "b"}, @@ -973,6 +1356,7 @@ TEST_F(ParquetTableScanTest, testColumnNotExists) { })); assertSelectWithDataColumns( + {makeSplit(getExampleFilePath("sample.parquet"))}, {"a", "b", "not_exists", "not_exists_array", "not_exists_map"}, rowType, "SELECT a, b, NULL, NULL, NULL FROM tmp"); @@ -998,8 +1382,8 @@ TEST_F(ParquetTableScanTest, schemaMatchWithComplexTypes) { {"p", "m", "a"}, {primitiveVector, mapVector, arrayVector}); // columns in data file - const std::shared_ptr dataFileFolder = - exec::test::TempDirectoryPath::create(); + const std::shared_ptr dataFileFolder = + TempDirectoryPath::create(); auto filePath = dataFileFolder->getPath() + "/" + "nested_data.parquet"; WriterOptions options; options.writeInt96AsTimestamp = false; @@ -1022,8 +1406,8 @@ TEST_F(ParquetTableScanTest, schemaMatchWithComplexTypes) { .project({"p1", "m1[0].aa1", "m1[1].bb1", "a1[1].aa1", "a1[2].bb1"}) .planNode(); - auto split = makeSplit(filePath); - auto result = AssertQueryBuilder(op).split(split).copyResults(pool()); + auto result = + AssertQueryBuilder(op).split(makeSplit(filePath)).copyResults(pool()); ASSERT_EQ(result->size(), kSize); auto rows = result->as(); @@ -1049,7 +1433,7 @@ TEST_F(ParquetTableScanTest, schemaMatchWithComplexTypes) { kHiveConnectorId, connector::hive::HiveConfig::kParquetUseColumnNamesSession, "true") - .split(split) + .split(makeSplit(filePath)) .copyResults(pool()); rows = result->as(); // check for rest of the selected columns @@ -1074,8 +1458,8 @@ TEST_F(ParquetTableScanTest, schemaMatch) { {makeFlatVector(kSize, [](auto row) { return row; }), makeFlatVector(kSize, [](auto row) { return row * 4; })}); - const std::shared_ptr dataFileFolder = - exec::test::TempDirectoryPath::create(); + const std::shared_ptr dataFileFolder = + TempDirectoryPath::create(); auto filePath = dataFileFolder->getPath() + "/" + "data.parquet"; WriterOptions options; options.writeInt96AsTimestamp = false; @@ -1089,8 +1473,8 @@ TEST_F(ParquetTableScanTest, schemaMatch) { .endTableScan() .planNode(); - auto split = makeSplit(filePath); - auto result = AssertQueryBuilder(op).split(split).copyResults(pool()); + auto result = + AssertQueryBuilder(op).split(makeSplit(filePath)).copyResults(pool()); auto rows = result->as(); assertEqualVectors(rows->childAt(0), dataFileVectors->childAt(0)); @@ -1106,7 +1490,7 @@ TEST_F(ParquetTableScanTest, schemaMatch) { .endTableScan() .planNode(); EXPECT_THROW( - AssertQueryBuilder(op).split(split).copyResults(pool()), + AssertQueryBuilder(op).split(makeSplit(filePath)).copyResults(pool()), VeloxRuntimeError); // Now run query with column mapping using names, now c2 columns will match in @@ -1123,7 +1507,7 @@ TEST_F(ParquetTableScanTest, schemaMatch) { kHiveConnectorId, connector::hive::HiveConfig::kParquetUseColumnNamesSession, "true") - .split(split) + .split(makeSplit(filePath)) .copyResults(pool()); rows = result->as(); @@ -1143,7 +1527,7 @@ TEST_F(ParquetTableScanTest, schemaMatch) { .planNode(); EXPECT_THROW( - AssertQueryBuilder(op).split(split).copyResults(pool()), + AssertQueryBuilder(op).split(makeSplit(filePath)).copyResults(pool()), VeloxRuntimeError); // Schema evolution remove column. @@ -1156,7 +1540,8 @@ TEST_F(ParquetTableScanTest, schemaMatch) { .project({"c1"}) .planNode(); - result = AssertQueryBuilder(op).split(split).copyResults(pool()); + result = + AssertQueryBuilder(op).split(makeSplit(filePath)).copyResults(pool()); rows = result->as(); assertEqualVectors(rows->childAt(0), dataFileVectors->childAt(0)); @@ -1170,7 +1555,8 @@ TEST_F(ParquetTableScanTest, schemaMatch) { .project({"c1", "c2", "c3"}) .planNode(); - result = AssertQueryBuilder(op).split(split).copyResults(pool()); + result = + AssertQueryBuilder(op).split(makeSplit(filePath)).copyResults(pool()); rows = result->as(); assertEqualVectors(rows->childAt(0), dataFileVectors->childAt(0)); assertEqualVectors(rows->childAt(1), dataFileVectors->childAt(1)); @@ -1183,17 +1569,17 @@ TEST_F(ParquetTableScanTest, deltaByteArray) { createDuckDbTable("expected", {expected}); auto vector = makeFlatVector({{}}); - loadData( - getExampleFilePath("delta_byte_array.parquet"), - ROW({"a"}, {VARCHAR()}), - makeRowVector({"a"}, {vector})); - assertSelect({"a"}, "SELECT a from expected"); + loadData(ROW({"a"}, {VARCHAR()}), makeRowVector({"a"}, {vector})); + assertSelect( + {makeSplit(getExampleFilePath("delta_byte_array.parquet"))}, + {"a"}, + "SELECT a from expected"); } TEST_F(ParquetTableScanTest, booleanRle) { WriterOptions options; options.enableDictionary = false; - options.encoding = facebook::velox::parquet::arrow::Encoding::RLE; + options.encoding = facebook::velox::parquet::arrow::Encoding::kRle; options.useParquetDataPageV2 = true; auto allTrue = [](vector_size_t row) -> bool { return true; }; @@ -1218,30 +1604,30 @@ TEST_F(ParquetTableScanTest, booleanRle) { auto schema = asRowType(vector->type()); auto file = TempFilePath::create(); writeToParquetFile(file->getPath(), {vector}, options); - loadData(file->getPath(), schema, vector); + loadData(schema, vector); std::shared_ptr c0 = makeColumnHandle( - "c0", BOOLEAN(), BOOLEAN(), {}, HiveColumnHandle::ColumnType::kRegular); + "c0", BOOLEAN(), BOOLEAN(), {}, FileColumnHandle::ColumnType::kRegular); std::shared_ptr c1 = makeColumnHandle( - "c1", BOOLEAN(), BOOLEAN(), {}, HiveColumnHandle::ColumnType::kRegular); + "c1", BOOLEAN(), BOOLEAN(), {}, FileColumnHandle::ColumnType::kRegular); std::shared_ptr c2 = makeColumnHandle( - "c2", BOOLEAN(), BOOLEAN(), {}, HiveColumnHandle::ColumnType::kRegular); + "c2", BOOLEAN(), BOOLEAN(), {}, FileColumnHandle::ColumnType::kRegular); std::shared_ptr c3 = makeColumnHandle( - "c3", BOOLEAN(), BOOLEAN(), {}, HiveColumnHandle::ColumnType::kRegular); + "c3", BOOLEAN(), BOOLEAN(), {}, FileColumnHandle::ColumnType::kRegular); std::shared_ptr c4 = makeColumnHandle( - "c4", BOOLEAN(), BOOLEAN(), {}, HiveColumnHandle::ColumnType::kRegular); + "c4", BOOLEAN(), BOOLEAN(), {}, FileColumnHandle::ColumnType::kRegular); - assertSelect({"c0"}, "SELECT c0 FROM tmp"); - assertSelect({"c1"}, "SELECT c1 FROM tmp"); - assertSelect({"c2"}, "SELECT c2 FROM tmp"); - assertSelect({"c3"}, "SELECT c3 FROM tmp"); - assertSelect({"c4"}, "SELECT c4 FROM tmp"); + assertSelect({makeSplit(file->getPath())}, {"c0"}, "SELECT c0 FROM tmp"); + assertSelect({makeSplit(file->getPath())}, {"c1"}, "SELECT c1 FROM tmp"); + assertSelect({makeSplit(file->getPath())}, {"c2"}, "SELECT c2 FROM tmp"); + assertSelect({makeSplit(file->getPath())}, {"c3"}, "SELECT c3 FROM tmp"); + assertSelect({makeSplit(file->getPath())}, {"c4"}, "SELECT c4 FROM tmp"); } TEST_F(ParquetTableScanTest, singleBooleanRle) { WriterOptions options; options.enableDictionary = false; - options.encoding = facebook::velox::parquet::arrow::Encoding::RLE; + options.encoding = facebook::velox::parquet::arrow::Encoding::kRle; options.useParquetDataPageV2 = true; auto vector = makeRowVector( @@ -1254,18 +1640,18 @@ TEST_F(ParquetTableScanTest, singleBooleanRle) { auto schema = asRowType(vector->type()); auto file = TempFilePath::create(); writeToParquetFile(file->getPath(), {vector}, options); - loadData(file->getPath(), schema, vector); + loadData(schema, vector); std::shared_ptr c0 = makeColumnHandle( - "c0", BOOLEAN(), BOOLEAN(), {}, HiveColumnHandle::ColumnType::kRegular); + "c0", BOOLEAN(), BOOLEAN(), {}, FileColumnHandle::ColumnType::kRegular); std::shared_ptr c1 = makeColumnHandle( - "c1", BOOLEAN(), BOOLEAN(), {}, HiveColumnHandle::ColumnType::kRegular); + "c1", BOOLEAN(), BOOLEAN(), {}, FileColumnHandle::ColumnType::kRegular); std::shared_ptr c2 = makeColumnHandle( - "c2", BOOLEAN(), BOOLEAN(), {}, HiveColumnHandle::ColumnType::kRegular); + "c2", BOOLEAN(), BOOLEAN(), {}, FileColumnHandle::ColumnType::kRegular); - assertSelect({"c0"}, "SELECT c0 FROM tmp"); - assertSelect({"c1"}, "SELECT c1 FROM tmp"); - assertSelect({"c2"}, "SELECT c2 FROM tmp"); + assertSelect({makeSplit(file->getPath())}, {"c0"}, "SELECT c0 FROM tmp"); + assertSelect({makeSplit(file->getPath())}, {"c1"}, "SELECT c1 FROM tmp"); + assertSelect({makeSplit(file->getPath())}, {"c2"}, "SELECT c2 FROM tmp"); } TEST_F(ParquetTableScanTest, intToBigintRead) { @@ -1276,8 +1662,8 @@ TEST_F(ParquetTableScanTest, intToBigintRead) { RowVectorPtr bigintDataFileVectors = makeRowVector( {"c1"}, {makeFlatVector(kSize, [](auto row) { return row; })}); - const std::shared_ptr dataFileFolder = - exec::test::TempDirectoryPath::create(); + const std::shared_ptr dataFileFolder = + TempDirectoryPath::create(); auto filePath = dataFileFolder->getPath() + "/" + "data.parquet"; WriterOptions options; options.writeInt96AsTimestamp = false; @@ -1298,6 +1684,116 @@ TEST_F(ParquetTableScanTest, intToBigintRead) { assertEqualVectors(bigintDataFileVectors->childAt(0), rows->childAt(0)); } +TEST_F(ParquetTableScanTest, intNarrowingRejectedByDefault) { + // Narrowing conversions are rejected when allowInt32Narrowing is false + // (default). Each case writes a wider integer column and reads as a narrower + // type, expecting an exception from convertType. + auto assertNarrowingThrows = [&](const VectorPtr& sourceVector, + const TypePtr& targetType) { + auto vectors = makeRowVector({"c1"}, {sourceVector}); + auto dataFile = TempFilePath::create(); + writeToParquetFile(dataFile->getPath(), {vectors}, WriterOptions{}); + auto rowType = ROW({"c1"}, {targetType}); + auto op = PlanBuilder() + .startTableScan() + .outputType(rowType) + .dataColumns(rowType) + .endTableScan() + .planNode(); + VELOX_ASSERT_THROW( + AssertQueryBuilder(op) + .split(makeSplit(dataFile->getPath())) + .copyResults(pool()), + "is not allowed for requested type"); + }; + + assertNarrowingThrows(makeFlatVector({1, 2}), TINYINT()); + assertNarrowingThrows(makeFlatVector({1, 2}), SMALLINT()); + assertNarrowingThrows(makeFlatVector({1, 2}), TINYINT()); +} + +TEST_F(ParquetTableScanTest, intReadWithNarrowerType) { + // Reading a wider integer as a narrower one causes unchecked truncation and + // two's complement reinterpretation, resulting in values INT_MAX becoming -1. + // Only INT32 physical type narrowing is supported (INT_16 -> TINYINT, + // INT_32 -> SMALLINT/TINYINT). INT64 -> INT32 is not allowed. + RowVectorPtr intVectors = makeRowVector( + {"c1", "c2", "c3"}, + { + makeFlatVector( + {123, + std::numeric_limits::max(), + std::numeric_limits::min(), + std::numeric_limits::max(), + std::numeric_limits::min()}), + makeFlatVector( + {123, + std::numeric_limits::max(), + std::numeric_limits::min(), + std::numeric_limits::max(), + std::numeric_limits::min()}), + makeFlatVector( + {123, + std::numeric_limits::max(), + std::numeric_limits::min(), + std::numeric_limits::max(), + std::numeric_limits::min()}), + }); + + RowVectorPtr smallerIntVectors = makeRowVector( + {"c1", "c2", "c3"}, + { + makeFlatVector({ + 123, + std::numeric_limits::max(), + std::numeric_limits::min(), + -1, + 0, + }), + makeFlatVector({ + 123, + std::numeric_limits::max(), + std::numeric_limits::min(), + -1, + 0, + }), + makeFlatVector({ + 123, + std::numeric_limits::max(), + std::numeric_limits::min(), + -1, + 0, + }), + }); + + auto dataFile = TempFilePath::create(); + WriterOptions options; + writeToParquetFile(dataFile->getPath(), {intVectors}, options); + + auto rowType = ROW({"c1", "c2", "c3"}, {TINYINT(), SMALLINT(), TINYINT()}); + auto op = PlanBuilder() + .startTableScan() + .outputType(rowType) + .dataColumns(rowType) + .endTableScan() + .planNode(); + + auto split = makeSplit(dataFile->getPath()); + auto result = + AssertQueryBuilder(op) + .connectorSessionProperty( + kHiveConnectorId, + connector::hive::HiveConfig::kAllowInt32NarrowingSession, + "true") + .split(split) + .copyResults(pool()); + auto rows = result->as(); + + assertEqualVectors(smallerIntVectors->childAt(0), rows->childAt(0)); + assertEqualVectors(smallerIntVectors->childAt(1), rows->childAt(1)); + assertEqualVectors(smallerIntVectors->childAt(2), rows->childAt(2)); +} + TEST_F(ParquetTableScanTest, shortAndLongDecimalReadWithLargerPrecision) { // decimal.parquet holds two columns (a: DECIMAL(5, 2), b: DECIMAL(20, 5)) and // 20 rows (10 rows per group). Data is in plain uncompressed format: @@ -1324,8 +1820,8 @@ TEST_F(ParquetTableScanTest, shortAndLongDecimalReadWithLargerPrecision) { {makeFlatVector(unscaledShortValues, DECIMAL(8, 2)), makeFlatVector(longDecimalValues, DECIMAL(22, 5))}); - const std::shared_ptr dataFileFolder = - exec::test::TempDirectoryPath::create(); + const std::shared_ptr dataFileFolder = + TempDirectoryPath::create(); auto filePath = getExampleFilePath("decimal.parquet"); auto rowType = ROW({"c1", "c2"}, {DECIMAL(8, 2), DECIMAL(22, 5)}); @@ -1344,6 +1840,200 @@ TEST_F(ParquetTableScanTest, shortAndLongDecimalReadWithLargerPrecision) { assertEqualVectors(expectedDecimalVectors->childAt(1), rows->childAt(1)); } +TEST_F(ParquetTableScanTest, inFilter) { + auto vectors = {makeRowVector( + {"name"}, + { + makeNullableFlatVector( + {"mary", "martin", "lucy", "alex", std::nullopt, "mary", "dan"}), + })}; + auto filePath = TempFilePath::create(); + WriterOptions options; + writeToParquetFile(filePath->getPath(), vectors, options); + createDuckDbTable(vectors); + + // Test in. + auto plan = PlanBuilder(pool_.get()) + .tableScan( + ROW({"name"}, {VARCHAR()}), + {"name in ('alex', 'leo', 'mary', null, 'victor')"}, + "") + .planNode(); + AssertQueryBuilder(plan, duckDbQueryRunner_) + .split(makeSplit(filePath->getPath())) + .assertResults( + "SELECT name FROM tmp where name in ('alex', 'leo', 'mary', null, 'victor')"); + + // Test not in. + plan = PlanBuilder(pool_.get()) + .tableScan( + ROW({"name"}, {VARCHAR()}), + {"name not in ('alex', 'leo', 'mary', null, 'victor')"}, + "") + .planNode(); + AssertQueryBuilder(plan, duckDbQueryRunner_) + .split(makeSplit(filePath->getPath())) + .assertResults( + "SELECT name FROM tmp where name not in ('alex', 'leo', 'mary', null, 'victor')"); +} + +TEST_F(ParquetTableScanTest, reusedLazyVectors) { + const std::vector columnNames = {"a", "b"}; + std::vector data; + for (auto row = 0; row < 10; ++row) { + data.emplace_back(makeRowVector( + columnNames, + { + makeFlatVector({static_cast(row % 5)}), + makeFlatVector({static_cast(row)}), + })); + } + const auto expectedRowVector = makeRowVector( + {makeFlatVector({0, 1, 2, 3, 4}), + makeFlatVector({5, 7, 9, 11, 13}), + makeFlatVector({5, 7, 9, 11, 13})}); + + const auto filePath = TempFilePath::create(); + WriterOptions options; + writeToParquetFile(filePath->getPath(), data, options); + + const auto plan = PlanBuilder() + .tableScan(ROW(columnNames, {BIGINT(), BIGINT()})) + .project({"a as c1", "b as c2", "b as c3"}) + .singleAggregation({"c1"}, {"sum(c2)", "sum(c3)"}) + .planNode(); + AssertQueryBuilder(plan) + .split(makeSplit(filePath->getPath())) + .assertResults(expectedRowVector); +} + +// Verify that entire Parquet files are pruned based on file-level column +// statistics when the filter eliminates all data in the file. +TEST_F(ParquetTableScanTest, statsBasedFileSkipping) { + WriterOptions options; + std::vector filePaths; + std::vector dataVectors; + const vector_size_t numRows = 100; + filePaths.push_back(TempFilePath::create()->getPath()); + dataVectors.push_back(makeRowVector( + {"c0", "c1", "c2"}, + { + makeFlatVector(numRows, [](auto row) { return row; }), + makeFlatVector( + numRows, [](auto row) { return static_cast(row); }), + makeFlatVector( + numRows, + [](auto row) { + static std::vector values = {"a", "b", "c", "d"}; + return values[row % values.size()]; + }), + })); + // File 1: integers [0, 99], doubles [0.0, 99.0], strings ["a".."d"]. + writeToParquetFile(filePaths.back(), dataVectors, options); + + filePaths.push_back(TempFilePath::create()->getPath()); + dataVectors.push_back(makeRowVector( + {"c0", "c1", "c2"}, + { + makeFlatVector(numRows, [](auto row) { return row + 200; }), + makeFlatVector(numRows, [](auto row) { return row + 200; }), + makeFlatVector( + numRows, + [](auto row) { + static std::vector values = {"p", "q", "r", "s"}; + return values[row % values.size()]; + }), + })); + // File 2: integers [200, 299], doubles [200.0, 299.0], strings ["p".."s"]. + writeToParquetFile(filePaths.back(), {dataVectors.back()}, options); + + createDuckDbTable(dataVectors); + + auto makeSplits = [&]() { + std::vector> splits; + for (const auto& path : filePaths) { + splits.push_back(makeSplit(path)); + } + return splits; + }; + + auto testFileSkipping = [&](const std::string& filter, + int32_t expectedSkipped, + int32_t expectedProcessed) { + SCOPED_TRACE(filter); + auto plan = PlanBuilder(pool_.get()) + .tableScan(dataVectors.back()->rowType(), {filter}) + .planNode(); + auto task = AssertQueryBuilder(plan, duckDbQueryRunner_) + .splits(makeSplits()) + .assertResults("SELECT * FROM tmp WHERE " + filter); + auto stats = + task->taskStats().pipelineStats[0].operatorStats[0].runtimeStats; + EXPECT_EQ(stats["skippedSplits"].sum, expectedSkipped); + EXPECT_EQ(stats["processedSplits"].sum, expectedProcessed); + }; + + // Neither file has values > 1000, both files skipped. + testFileSkipping("c0 > 1000", 2, 0); + // Neither file has values < 0, both files skipped. + testFileSkipping("c0 < 0", 2, 0); + // Low-range file (max=99) is skipped, high-range file is read. + testFileSkipping("c0 >= 200", 1, 1); + // High-range file (min=200) is skipped, low-range file is read. + testFileSkipping("c0 <= 99", 1, 1); + // Double column: both files skipped. + testFileSkipping("c1 > 500.0", 2, 0); + // String column: low-range has ["a".."d"], high-range has ["p".."s"], both + // skipped. + testFileSkipping("c2 = 'z'", 2, 0); + // Matches both files, no files skipped. + testFileSkipping("c0 >= 0", 0, 2); +} + +TEST_F(ParquetTableScanTest, fileFormatRuntimeStats) { + auto rowType = ROW({"a", "b"}, {BIGINT(), DOUBLE()}); + auto vector = makeRowVector( + {"a", "b"}, + { + makeFlatVector(100, [](auto row) { return row; }), + makeFlatVector(100, [](auto row) { return row * 1.5; }), + }); + std::vector vectors = {vector}; + + // Write one Parquet file and one DWRF file. + auto parquetFile = TempFilePath::create(); + WriterOptions parquetOptions; + parquetOptions.memoryPool = rootPool_.get(); + writeToParquetFile(parquetFile->getPath(), vectors, parquetOptions); + + auto dwrfFile = TempFilePath::create(); + writeToFile(dwrfFile->getPath(), vectors); + + // DuckDB reference table with data from both files. + std::vector allVectors = {vector, vector}; + createDuckDbTable(allVectors); + + auto parquetSplit = makeHiveConnectorSplits( + parquetFile->getPath(), 1, dwio::common::FileFormat::PARQUET); + auto dwrfSplit = makeHiveConnectorSplits( + dwrfFile->getPath(), 1, dwio::common::FileFormat::DWRF); + + auto plan = PlanBuilder().tableScan(asRowType(vectors[0]->type())).planNode(); + + auto task = AssertQueryBuilder(plan, duckDbQueryRunner_) + .splits({parquetSplit[0], dwrfSplit[0]}) + .assertResults("SELECT * FROM tmp"); + + auto stats = task->taskStats().pipelineStats[0].operatorStats[0].runtimeStats; + ASSERT_EQ(stats.count("fileFormat.parquet"), 1); + ASSERT_EQ(stats.at("fileFormat.parquet").sum, 1); + ASSERT_EQ(stats.count("fileFormat.dwrf"), 1); + ASSERT_EQ(stats.at("fileFormat.dwrf").sum, 1); + + task.reset(); + waitForAllTasksToBeDeleted(); +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); folly::Init init{&argc, &argv, false}; diff --git a/velox/dwio/parquet/tests/writer/CMakeLists.txt b/velox/dwio/parquet/tests/writer/CMakeLists.txt index a6b9daec340..6948dbfcf70 100644 --- a/velox/dwio/parquet/tests/writer/CMakeLists.txt +++ b/velox/dwio/parquet/tests/writer/CMakeLists.txt @@ -26,7 +26,6 @@ target_link_libraries( velox_dwio_common_test_utils velox_vector_fuzzer velox_caching - Boost::regex velox_link_libs Folly::folly ${TEST_LINK_LIBS} @@ -34,7 +33,7 @@ target_link_libraries( fmt::fmt ) -add_executable(velox_parquet_writer_test ParquetWriterTest.cpp) +add_executable(velox_parquet_writer_test ParquetWriterFieldIdTest.cpp ParquetWriterTest.cpp) add_test( NAME velox_parquet_writer_test @@ -44,12 +43,12 @@ add_test( target_link_libraries( velox_parquet_writer_test + velox_dwio_arrow_parquet_writer_test_lib velox_dwio_parquet_writer velox_dwio_parquet_reader velox_dwio_common_test_utils velox_caching velox_link_libs - Boost::regex Folly::folly ${TEST_LINK_LIBS} GTest::gtest diff --git a/velox/dwio/parquet/tests/writer/ParquetWriterFieldIdTest.cpp b/velox/dwio/parquet/tests/writer/ParquetWriterFieldIdTest.cpp new file mode 100644 index 00000000000..097dd14740c --- /dev/null +++ b/velox/dwio/parquet/tests/writer/ParquetWriterFieldIdTest.cpp @@ -0,0 +1,126 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/dwio/parquet/writer/arrow/tests/TestUtil.h" + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/dwio/parquet/tests/ParquetTestBase.h" +#include "velox/dwio/parquet/writer/arrow/Schema.h" +#include "velox/dwio/parquet/writer/arrow/tests/FileReader.h" + +namespace { + +using namespace facebook::velox; +using namespace facebook::velox::dwio::common; +using namespace facebook::velox::parquet; + +class ParquetWriterFieldIdTest : public ParquetTestBase, + public ::testing::WithParamInterface {}; + +TEST_P(ParquetWriterFieldIdTest, fieldIds) { + auto schema = + ROW({"p", "s", "a", "m"}, + {BIGINT(), + ROW({"x", "y"}, {INTEGER(), VARCHAR()}), + ARRAY(INTEGER()), + MAP(VARCHAR(), INTEGER())}); + constexpr int32_t kRows = 10; + auto data = makeRowVector( + {"p", "s", "a", "m"}, + {makeFlatVector(kRows, [](auto row) { return row; }), + makeRowVector( + {"x", "y"}, + {makeFlatVector(kRows, [](auto row) { return row; }), + makeFlatVector(kRows, [](auto) { return "z"; })}), + makeArrayVectorFromJson(std::vector(kRows, "[3]")), + makeMapVectorFromJson( + std::vector(kRows, R"({"k": 4})"))}); + + parquet::WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); + + if (GetParam()) { + // Provide Parquet field IDs aligned with the Velox schema tree. + // p -> 10. + // s -> 20, children: x -> 21, y -> 22. + // a -> 30, list element -> 31. + // m -> 40, children: key -> 41, value -> 42. + writerOptions.parquetFieldIds = { + ParquetFieldId{10, {}}, + ParquetFieldId{20, {ParquetFieldId{21, {}}, ParquetFieldId{22, {}}}}, + ParquetFieldId{30, {ParquetFieldId{31, {}}}}, + ParquetFieldId{40, {ParquetFieldId{41, {}}, ParquetFieldId{42, {}}}}, + }; + } + + auto* sinkPtr = write(data, writerOptions); + + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto parquetReader = createReaderInMemory(*sinkPtr, readerOptions); + EXPECT_EQ(parquetReader->numberOfRows(), kRows); + auto veloxRowType = parquetReader->rowType(); + EXPECT_EQ(*veloxRowType, *schema); + + std::string_view sinkData(sinkPtr->data(), sinkPtr->size()); + auto arrowBufferReader = std::make_shared<::arrow::io::BufferReader>( + std::make_shared<::arrow::Buffer>( + reinterpret_cast(sinkData.data()), sinkData.size())); + + auto fileReader = parquet::arrow::ParquetFileReader::open(arrowBufferReader); + auto metadata = fileReader->metadata(); + auto* descr = metadata->schema(); + auto* root = descr->groupNode(); + + ASSERT_EQ(root->fieldCount(), 4); + + auto exp = [&](int32_t expectedFieldId) { + return GetParam() ? expectedFieldId : -1; + }; + + // Top-level field IDs. + EXPECT_EQ(root->field(0)->fieldId(), exp(10)); + EXPECT_EQ(root->field(1)->fieldId(), exp(20)); + EXPECT_EQ(root->field(2)->fieldId(), exp(30)); + EXPECT_EQ(root->field(3)->fieldId(), exp(40)); + + using GroupNode = parquet::arrow::schema::GroupNode; + auto* s = static_cast(root->field(1).get()); + EXPECT_EQ(s->field(0)->fieldId(), exp(21)); + EXPECT_EQ(s->field(1)->fieldId(), exp(22)); + + auto* a = static_cast(root->field(2).get()); + // LIST logical group has one repeated child (the array entries); dive once + // more to the element. + auto* listEntries = a->field(0).get(); + auto* listGroup = static_cast(listEntries); + auto* element = listGroup->field(0).get(); + EXPECT_EQ(element->fieldId(), exp(31)); + + auto* m = static_cast(root->field(3).get()); + auto* keyValue = m->field(0).get(); + auto* keyValueGroup = static_cast(keyValue); + EXPECT_EQ(keyValueGroup->field(0)->fieldId(), exp(41)); + EXPECT_EQ(keyValueGroup->field(1)->fieldId(), exp(42)); +} + +VELOX_INSTANTIATE_TEST_SUITE_P( + ParquetWriterFieldIdTest, + ParquetWriterFieldIdTest, + ::testing::Values(false, true)); + +} // namespace diff --git a/velox/dwio/parquet/tests/writer/ParquetWriterTest.cpp b/velox/dwio/parquet/tests/writer/ParquetWriterTest.cpp index d22cbba6389..7d61caf4778 100644 --- a/velox/dwio/parquet/tests/writer/ParquetWriterTest.cpp +++ b/velox/dwio/parquet/tests/writer/ParquetWriterTest.cpp @@ -19,18 +19,20 @@ #include "velox/dwio/parquet/writer/arrow/tests/TestUtil.h" #include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/common/testutil/TestValue.h" +#include "velox/connectors/ConnectorRegistry.h" #include "velox/connectors/hive/HiveConnector.h" // @manual #include "velox/core/QueryCtx.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" #include "velox/dwio/parquet/RegisterParquetWriter.h" // @manual #include "velox/dwio/parquet/reader/PageReader.h" #include "velox/dwio/parquet/tests/ParquetTestBase.h" +#include "velox/dwio/parquet/writer/WriterConfig.h" #include "velox/exec/Cursor.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/exec/tests/utils/QueryAssertions.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" namespace { @@ -39,6 +41,7 @@ using namespace facebook::velox::common; using namespace facebook::velox::dwio::common; using namespace facebook::velox::exec::test; using namespace facebook::velox::parquet; +using namespace facebook::velox::common::testutil; class ParquetWriterTest : public ParquetTestBase { protected: @@ -51,10 +54,15 @@ class ParquetWriterTest : public ParquetTestBase { kHiveConnectorId, std::make_shared( std::unordered_map())); - connector::registerConnector(hiveConnector); + connector::ConnectorRegistry::global().insert( + hiveConnector->connectorId(), hiveConnector); parquet::registerParquetWriterFactory(); } + void TearDown() override { + writers_.clear(); + } + std::unique_ptr createRowReaderWithSchema( const std::unique_ptr reader, const RowTypePtr& rowType) { @@ -63,20 +71,52 @@ class ParquetWriterTest : public ParquetTestBase { rowReaderOpts.setScanSpec(scanSpec); auto rowReader = reader->createRowReader(rowReaderOpts); return rowReader; - }; + } - std::unique_ptr createReaderInMemory( - const dwio::common::MemorySink& sink, - const dwio::common::ReaderOptions& opts) { - std::string data(sink.data(), sink.size()); - return std::make_unique( - std::make_unique( - std::make_shared(std::move(data)), - opts.memoryPool()), - opts); - }; + RowVectorPtr makeSmallintTestData(int64_t rows) { + auto data = makeRowVector({ + makeFlatVector(rows, [](auto row) { return row + 1; }), + }); + return data; + } + + RowVectorPtr makeTimestampTestData(int64_t rows) { + auto data = makeRowVector({makeFlatVector( + rows, [](auto row) { return Timestamp(row, row); })}); + return data; + } + + thrift::PageHeader readPageHeader( + MemorySink* sinkPtr, + int64_t offsetFromDataPage) { + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto reader = createReaderInMemory(*sinkPtr, readerOptions); + + auto colChunkPtr = reader->fileMetaData().rowGroup(0).columnChunk(0); + std::string_view sinkData(sinkPtr->data(), sinkPtr->size()); + + auto readFile = std::make_shared(sinkData); + auto file = std::make_shared(std::move(readFile)); + + auto inputStream = std::make_unique( + std::move(file), + colChunkPtr.dataPageOffset() + offsetFromDataPage, + 150, + *leafPool_, + LogType::TEST); + auto pageReader = std::make_unique( + std::move(inputStream), + *leafPool_, + colChunkPtr.compression(), + colChunkPtr.totalCompressedSize(), + stats); + return pageReader->readPageHeader(); + } inline static const std::string kHiveConnectorId = "test-hive"; + dwio::common::ColumnReaderStatistics stats; }; class ArrowMemoryPool final : public ::arrow::MemoryPool { @@ -147,76 +187,24 @@ std::vector params = { }; TEST_F(ParquetWriterTest, dictionaryEncodingWithDictionaryPageSize) { - const auto schema = ROW({"c0"}, {SMALLINT()}); constexpr int64_t kRows = 10'000; - const auto data = makeRowVector({ - makeFlatVector(kRows, [](auto row) { return row + 1; }), - }); + const auto data = makeSmallintTestData(kRows); // Write Parquet test data, then read and return the DataPage // (thrift::PageType::type) used. const auto testEnableDictionaryAndDictionaryPageSizeToGetPageHeader = [&](std::unordered_map configFromFile, std::unordered_map sessionProperties, - bool isFirstPageOrSecondPage) { - // Create an in-memory writer. - auto sink = std::make_unique( - 200 * 1024 * 1024, - dwio::common::FileSink::Options{.pool = leafPool_.get()}); - auto sinkPtr = sink.get(); - parquet::WriterOptions writerOptions; - writerOptions.memoryPool = leafPool_.get(); - - auto connectorConfig = config::ConfigBase(std::move(configFromFile)); - auto connectorSessionProperties = - config::ConfigBase(std::move(sessionProperties)); - - writerOptions.processConfigs( - connectorConfig, connectorSessionProperties); - auto writer = std::make_unique( - std::move(sink), writerOptions, rootPool_, schema); - writer->write(data); - writer->close(); - - // Read to identify DataPage used. - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; - auto reader = createReaderInMemory(*sinkPtr, readerOptions); - - auto colChunkPtr = reader->fileMetaData().rowGroup(0).columnChunk(0); - std::string_view sinkData(sinkPtr->data(), sinkPtr->size()); - - auto readFile = std::make_shared(sinkData); - auto file = std::make_shared(std::move(readFile)); - - if (isFirstPageOrSecondPage) { - auto inputStream = std::make_unique( - std::move(file), - colChunkPtr.dataPageOffset(), - 150, - *leafPool_, - LogType::TEST); - auto pageReader = std::make_unique( - std::move(inputStream), - *leafPool_, - colChunkPtr.compression(), - colChunkPtr.totalCompressedSize()); - return pageReader->readPageHeader(); + bool isFirstPage) { + auto* sinkPtr = write( + data, std::move(configFromFile), std::move(sessionProperties)); + if (isFirstPage) { + return readPageHeader(sinkPtr, 0); } constexpr int64_t kFirstDataPageCompressedSize = 1291; constexpr int64_t kFirstDataPageHeaderSize = 48; - auto inputStream = std::make_unique( - std::move(file), - colChunkPtr.dataPageOffset() + kFirstDataPageCompressedSize + - kFirstDataPageHeaderSize, - 150, - *leafPool_, - LogType::TEST); - auto pageReader = std::make_unique( - std::move(inputStream), - *leafPool_, - colChunkPtr.compression(), - colChunkPtr.totalCompressedSize()); - return pageReader->readPageHeader(); + return readPageHeader( + sinkPtr, kFirstDataPageCompressedSize + kFirstDataPageHeaderSize); }; // Test default config (i.e., no explicit config) @@ -247,13 +235,16 @@ TEST_F(ParquetWriterTest, dictionaryEncodingWithDictionaryPageSize) { // page size limit, the default is 1MB (same as data page default size) then // there will be only one data page contains all data encoded with dictionary const std::unordered_map normalConfigFromFile = { - {parquet::WriterOptions::kParquetHiveConnectorEnableDictionary, "true"}, - {parquet::WriterOptions::kParquetHiveConnectorDictionaryPageSizeLimit, + {config::ConfigBase::toConfigKey( + parquet::WriterConfig::kParquetSessionEnableDictionary), + "true"}, + {config::ConfigBase::toConfigKey( + parquet::WriterConfig::kParquetSessionDictionaryPageSizeLimit), "1B"}, }; const std::unordered_map normalSessionProperties = { - {parquet::WriterOptions::kParquetSessionEnableDictionary, "true"}, - {parquet::WriterOptions::kParquetSessionDictionaryPageSizeLimit, "1B"}, + {parquet::WriterConfig::kParquetSessionEnableDictionary, "true"}, + {parquet::WriterConfig::kParquetSessionDictionaryPageSizeLimit, "1B"}, }; // Here we are reading the second data page. If we don't set the dictionary @@ -273,12 +264,13 @@ TEST_F(ParquetWriterTest, dictionaryEncodingWithDictionaryPageSize) { const std::string invalidEnableDictionaryValue{"NaB"}; const std::unordered_map incorrectEnableDictionaryConfigFromFile = { - {parquet::WriterOptions::kParquetHiveConnectorEnableDictionary, + {config::ConfigBase::toConfigKey( + parquet::WriterConfig::kParquetSessionEnableDictionary), invalidEnableDictionaryValue}, }; const std::unordered_map incorrectEnableDictionarySessionProperties = { - {parquet::WriterOptions::kParquetSessionEnableDictionary, + {parquet::WriterConfig::kParquetSessionEnableDictionary, invalidEnableDictionaryValue}, }; @@ -297,12 +289,13 @@ TEST_F(ParquetWriterTest, dictionaryEncodingWithDictionaryPageSize) { const std::string invalidDictionaryPageSizeValue{"NaN"}; const std::unordered_map incorrectDictionaryPageSizeConfigFromFile = { - {parquet::WriterOptions::kParquetHiveConnectorDictionaryPageSizeLimit, + {config::ConfigBase::toConfigKey( + parquet::WriterConfig::kParquetSessionDictionaryPageSizeLimit), invalidDictionaryPageSizeValue}, }; const std::unordered_map incorrectDictionaryPageSizeSessionProperties = { - {parquet::WriterOptions::kParquetSessionDictionaryPageSizeLimit, + {parquet::WriterConfig::kParquetSessionDictionaryPageSizeLimit, invalidDictionaryPageSizeValue}, }; @@ -317,70 +310,30 @@ TEST_F(ParquetWriterTest, dictionaryEncodingWithDictionaryPageSize) { } TEST_F(ParquetWriterTest, dictionaryEncodingOff) { - const auto schema = ROW({"c0"}, {SMALLINT()}); constexpr int64_t kRows = 10'000; - const auto data = makeRowVector({ - makeFlatVector(kRows, [](auto row) { return row + 1; }), - }); + const auto data = makeSmallintTestData(kRows); // Write Parquet test data, then read and return the DataPage // (thrift::PageType::type) used. const auto testEnableDictionaryAndDictionaryPageSizeToGetPageHeader = [&](std::unordered_map configFromFile, std::unordered_map sessionProperties) { - // Create an in-memory writer. - auto sink = std::make_unique( - 200 * 1024 * 1024, - dwio::common::FileSink::Options{.pool = leafPool_.get()}); - auto sinkPtr = sink.get(); - parquet::WriterOptions writerOptions; - writerOptions.memoryPool = leafPool_.get(); - - auto connectorConfig = config::ConfigBase(std::move(configFromFile)); - auto connectorSessionProperties = - config::ConfigBase(std::move(sessionProperties)); - - writerOptions.processConfigs( - connectorConfig, connectorSessionProperties); - auto writer = std::make_unique( - std::move(sink), writerOptions, rootPool_, schema); - writer->write(data); - writer->close(); - - // Read to identify DataPage used. - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; - auto reader = createReaderInMemory(*sinkPtr, readerOptions); - - auto colChunkPtr = reader->fileMetaData().rowGroup(0).columnChunk(0); - std::string_view sinkData(sinkPtr->data(), sinkPtr->size()); - - auto readFile = std::make_shared(sinkData); - auto file = std::make_shared(std::move(readFile)); - - auto inputStream = std::make_unique( - std::move(file), - colChunkPtr.dataPageOffset(), - 150, - *leafPool_, - LogType::TEST); - auto pageReader = std::make_unique( - std::move(inputStream), - *leafPool_, - colChunkPtr.compression(), - colChunkPtr.totalCompressedSize()); - return pageReader->readPageHeader(); + auto* sinkPtr = write( + data, std::move(configFromFile), std::move(sessionProperties)); + return readPageHeader(sinkPtr, 0); }; // Test only dictionary off without dictionary page size configured const std::unordered_map withoutPageSizeConfigFromFile = { - {parquet::WriterOptions::kParquetHiveConnectorEnableDictionary, + {config::ConfigBase::toConfigKey( + parquet::WriterConfig::kParquetSessionEnableDictionary), "false"}, }; const std::unordered_map withoutPageSizeSessionProperties = { - {parquet::WriterOptions::kParquetSessionEnableDictionary, "false"}, + {parquet::WriterConfig::kParquetSessionEnableDictionary, "false"}, }; const auto withoutPageSizeHeader = @@ -401,16 +354,17 @@ TEST_F(ParquetWriterTest, dictionaryEncodingOff) { const std::unordered_map withPageSizeConfigFromFile = { - {parquet::WriterOptions::kParquetHiveConnectorEnableDictionary, + {config::ConfigBase::toConfigKey( + parquet::WriterConfig::kParquetSessionEnableDictionary), "false"}, - {parquet::WriterOptions::kParquetHiveConnectorDictionaryPageSizeLimit, + {config::ConfigBase::toConfigKey( + parquet::WriterConfig::kParquetSessionDictionaryPageSizeLimit), "1B"}, }; const std::unordered_map withPageSizeSessionProperties = { - {parquet::WriterOptions::kParquetSessionEnableDictionary, "false"}, - {parquet::WriterOptions::kParquetSessionDictionaryPageSizeLimit, - "1B"}, + {parquet::WriterConfig::kParquetSessionEnableDictionary, "false"}, + {parquet::WriterConfig::kParquetSessionDictionaryPageSizeLimit, "1B"}, }; const auto withPageSizeHeader = @@ -446,27 +400,20 @@ TEST_F(ParquetWriterTest, compression) { makeFlatVector(kRows, [](auto row) { return row - 25; }), }); - // Create an in-memory writer - auto sink = std::make_unique( - 200 * 1024 * 1024, - dwio::common::FileSink::Options{.pool = leafPool_.get()}); - auto sinkPtr = sink.get(); - facebook::velox::parquet::WriterOptions writerOptions; - writerOptions.memoryPool = leafPool_.get(); + parquet::WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); writerOptions.compressionKind = CompressionKind::CompressionKind_SNAPPY; const auto& fieldNames = schema->names(); - for (int i = 0; i < params.size(); i++) { writerOptions.columnCompressionsMap[fieldNames[i]] = params[i]; } - auto writer = std::make_unique( - std::move(sink), writerOptions, rootPool_, schema); - writer->write(data); - writer->close(); + auto* sinkPtr = write(data, writerOptions); - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto reader = createReaderInMemory(*sinkPtr, readerOptions); ASSERT_EQ(reader->numberOfRows(), kRows); @@ -481,61 +428,20 @@ TEST_F(ParquetWriterTest, compression) { auto rowReader = createRowReaderWithSchema(std::move(reader), schema); assertReadWithReaderAndExpected(schema, *rowReader, data, *leafPool_); -}; +} TEST_F(ParquetWriterTest, testPageSizeAndBatchSizeConfiguration) { - const auto schema = ROW({"c0"}, {SMALLINT()}); constexpr int64_t kRows = 10'000; - const auto data = makeRowVector({ - makeFlatVector(kRows, [](auto row) { return row + 1; }), - }); + const auto data = makeSmallintTestData(kRows); // Write Parquet test data, then read and return the DataPage // (thrift::PageType::type) used. const auto testPageSizeAndBatchSizeToGetPageHeader = [&](std::unordered_map configFromFile, std::unordered_map sessionProperties) { - // Create an in-memory writer. - auto sink = std::make_unique( - 200 * 1024 * 1024, - dwio::common::FileSink::Options{.pool = leafPool_.get()}); - auto sinkPtr = sink.get(); - parquet::WriterOptions writerOptions; - writerOptions.memoryPool = leafPool_.get(); - - auto connectorConfig = config::ConfigBase(std::move(configFromFile)); - auto connectorSessionProperties = - config::ConfigBase(std::move(sessionProperties)); - - writerOptions.processConfigs( - connectorConfig, connectorSessionProperties); - auto writer = std::make_unique( - std::move(sink), writerOptions, rootPool_, schema); - writer->write(data); - writer->close(); - - // Read to identify DataPage used. - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; - auto reader = createReaderInMemory(*sinkPtr, readerOptions); - - auto colChunkPtr = reader->fileMetaData().rowGroup(0).columnChunk(0); - std::string_view sinkData(sinkPtr->data(), sinkPtr->size()); - - auto readFile = std::make_shared(sinkData); - auto file = std::make_shared(std::move(readFile)); - - auto inputStream = std::make_unique( - std::move(file), - colChunkPtr.dataPageOffset(), - 150, - *leafPool_, - LogType::TEST); - auto pageReader = std::make_unique( - std::move(inputStream), - *leafPool_, - colChunkPtr.compression(), - colChunkPtr.totalCompressedSize()); - return pageReader->readPageHeader(); + auto* sinkPtr = write( + data, std::move(configFromFile), std::move(sessionProperties)); + return readPageHeader(sinkPtr, 0); }; // Test default config (i.e., no explicit config) @@ -564,12 +470,16 @@ TEST_F(ParquetWriterTest, testPageSizeAndBatchSizeConfiguration) { // of values in each page can be divided by 97, it means the batch size is // applied (default is 1024) const std::unordered_map normalConfigFromFile = { - {parquet::WriterOptions::kParquetHiveConnectorWritePageSize, "2KB"}, - {parquet::WriterOptions::kParquetHiveConnectorWriteBatchSize, "97"}, + {config::ConfigBase::toConfigKey( + parquet::WriterConfig::kParquetSessionWritePageSize), + "2KB"}, + {config::ConfigBase::toConfigKey( + parquet::WriterConfig::kParquetSessionWriteBatchSize), + "97"}, }; const std::unordered_map normalSessionProperties = { - {parquet::WriterOptions::kParquetSessionWritePageSize, "2KB"}, - {parquet::WriterOptions::kParquetSessionWriteBatchSize, "97"}, + {parquet::WriterConfig::kParquetSessionWritePageSize, "2KB"}, + {parquet::WriterConfig::kParquetSessionWriteBatchSize, "97"}, }; const auto normalHeader = testPageSizeAndBatchSizeToGetPageHeader( normalConfigFromFile, normalSessionProperties); @@ -588,12 +498,13 @@ TEST_F(ParquetWriterTest, testPageSizeAndBatchSizeConfiguration) { const std::string invalidPageSizeAndBatchSizeValue{"NaN"}; const std::unordered_map incorrectPageSizeConfigFromFile = { - {parquet::WriterOptions::kParquetHiveConnectorWritePageSize, + {config::ConfigBase::toConfigKey( + parquet::WriterConfig::kParquetSessionWritePageSize), invalidPageSizeAndBatchSizeValue}, }; const std::unordered_map incorrectPageSizeSessionPropertiesFromFile = { - {parquet::WriterOptions::kParquetSessionWritePageSize, + {parquet::WriterConfig::kParquetSessionWritePageSize, invalidPageSizeAndBatchSizeValue}, }; @@ -609,12 +520,13 @@ TEST_F(ParquetWriterTest, testPageSizeAndBatchSizeConfiguration) { const std::unordered_map incorrectBatchSizeConfigFromFile = { - {parquet::WriterOptions::kParquetHiveConnectorWriteBatchSize, + {config::ConfigBase::toConfigKey( + parquet::WriterConfig::kParquetSessionWriteBatchSize), invalidPageSizeAndBatchSizeValue}, }; const std::unordered_map incorrectBatchSizeSessionPropertiesFromFile = { - {parquet::WriterOptions::kParquetSessionWriteBatchSize, + {parquet::WriterConfig::kParquetSessionWriteBatchSize, invalidPageSizeAndBatchSizeValue}, }; @@ -629,7 +541,6 @@ TEST_F(ParquetWriterTest, testPageSizeAndBatchSizeConfiguration) { } TEST_F(ParquetWriterTest, toggleDataPageVersion) { - auto schema = ROW({"c0"}, {INTEGER()}); const int64_t kRows = 1; const auto data = makeRowVector({ makeFlatVector(kRows, [](auto row) { return 987; }), @@ -640,50 +551,9 @@ TEST_F(ParquetWriterTest, toggleDataPageVersion) { const auto testDataPageVersion = [&](std::unordered_map configFromFile, std::unordered_map sessionProperties) { - // Create an in-memory writer. - auto sink = std::make_unique( - 200 * 1024 * 1024, - dwio::common::FileSink::Options{.pool = leafPool_.get()}); - auto sinkPtr = sink.get(); - parquet::WriterOptions writerOptions; - writerOptions.memoryPool = leafPool_.get(); - - // Simulate setting of Hive config & connector session properties, then - // write test data. - auto connectorConfig = config::ConfigBase(std::move(configFromFile)); - auto connectorSessionProperties = - config::ConfigBase(std::move(sessionProperties)); - - writerOptions.processConfigs( - connectorConfig, connectorSessionProperties); - auto writer = std::make_unique( - std::move(sink), writerOptions, rootPool_, schema); - writer->write(data); - writer->close(); - - // Read to identify DataPage used. - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; - auto reader = createReaderInMemory(*sinkPtr, readerOptions); - - auto colChunkPtr = reader->fileMetaData().rowGroup(0).columnChunk(0); - std::string_view sinkData(sinkPtr->data(), sinkPtr->size()); - - auto readFile = std::make_shared(sinkData); - auto file = std::make_shared(std::move(readFile)); - - auto inputStream = std::make_unique( - std::move(file), - colChunkPtr.dataPageOffset(), - 150, - *leafPool_, - LogType::TEST); - auto pageReader = std::make_unique( - std::move(inputStream), - *leafPool_, - colChunkPtr.compression(), - colChunkPtr.totalCompressedSize()); - - return pageReader->readPageHeader().type; + auto* sinkPtr = write( + data, std::move(configFromFile), std::move(sessionProperties)); + return readPageHeader(sinkPtr, 0).type; }; // Test default behavior - DataPage should be V1. @@ -691,7 +561,9 @@ TEST_F(ParquetWriterTest, toggleDataPageVersion) { // Simulate setting DataPage version to V2 via Hive config from file. std::unordered_map configFromFile = { - {parquet::WriterOptions::kParquetHiveConnectorDataPageVersion, "V2"}}; + {config::ConfigBase::toConfigKey( + parquet::WriterConfig::kParquetSessionDataPageVersion), + "V2"}}; ASSERT_EQ( testDataPageVersion(configFromFile, {}), @@ -699,7 +571,9 @@ TEST_F(ParquetWriterTest, toggleDataPageVersion) { // Simulate setting DataPage version to V1 via Hive config from file. configFromFile = { - {parquet::WriterOptions::kParquetHiveConnectorDataPageVersion, "V1"}}; + {config::ConfigBase::toConfigKey( + parquet::WriterConfig::kParquetSessionDataPageVersion), + "V1"}}; ASSERT_EQ( testDataPageVersion(configFromFile, {}), @@ -707,7 +581,7 @@ TEST_F(ParquetWriterTest, toggleDataPageVersion) { // Simulate setting DataPage version to V2 via connector session property. std::unordered_map sessionProperties = { - {parquet::WriterOptions::kParquetSessionDataPageVersion, "V2"}}; + {parquet::WriterConfig::kParquetSessionDataPageVersion, "V2"}}; ASSERT_EQ( testDataPageVersion({}, sessionProperties), @@ -715,7 +589,7 @@ TEST_F(ParquetWriterTest, toggleDataPageVersion) { // Simulate setting DataPage version to V1 via connector session property. sessionProperties = { - {parquet::WriterOptions::kParquetSessionDataPageVersion, "V1"}}; + {parquet::WriterConfig::kParquetSessionDataPageVersion, "V1"}}; ASSERT_EQ( testDataPageVersion({}, sessionProperties), @@ -725,9 +599,11 @@ TEST_F(ParquetWriterTest, toggleDataPageVersion) { // and to V2 via Hive config from file. Session property should take // precedence. sessionProperties = { - {parquet::WriterOptions::kParquetSessionDataPageVersion, "V1"}}; + {parquet::WriterConfig::kParquetSessionDataPageVersion, "V1"}}; configFromFile = { - {parquet::WriterOptions::kParquetHiveConnectorDataPageVersion, "V2"}}; + {config::ConfigBase::toConfigKey( + parquet::WriterConfig::kParquetSessionDataPageVersion), + "V2"}}; ASSERT_EQ( testDataPageVersion({}, sessionProperties), @@ -737,9 +613,11 @@ TEST_F(ParquetWriterTest, toggleDataPageVersion) { // and to V1 via Hive config from file. Session property should take // precedence. sessionProperties = { - {parquet::WriterOptions::kParquetSessionDataPageVersion, "V2"}}; + {parquet::WriterConfig::kParquetSessionDataPageVersion, "V2"}}; configFromFile = { - {parquet::WriterOptions::kParquetHiveConnectorDataPageVersion, "V1"}}; + {config::ConfigBase::toConfigKey( + parquet::WriterConfig::kParquetSessionDataPageVersion), + "V1"}}; ASSERT_EQ( testDataPageVersion({}, sessionProperties), @@ -758,22 +636,14 @@ DEBUG_ONLY_TEST_F(ParquetWriterTest, unitFromWriterOptions) { ASSERT_EQ(tsType->timezone(), "America/Los_Angeles"); }))); - const auto data = makeRowVector({makeFlatVector( - 10'000, [](auto row) { return Timestamp(row, row); })}); + const auto data = makeTimestampTestData(10'000); parquet::WriterOptions writerOptions; - writerOptions.memoryPool = leafPool_.get(); + writerOptions.memoryPool = rootPool_.get(); writerOptions.parquetWriteTimestampUnit = TimestampPrecision::kMicroseconds; writerOptions.parquetWriteTimestampTimeZone = "America/Los_Angeles"; - // Create an in-memory writer. - auto sink = std::make_unique( - 200 * 1024 * 1024, - dwio::common::FileSink::Options{.pool = leafPool_.get()}); - auto writer = std::make_unique( - std::move(sink), writerOptions, rootPool_, ROW({"c0"}, {TIMESTAMP()})); - writer->write(data); - writer->close(); -}; + write(data, writerOptions); +} DEBUG_ONLY_TEST_F(ParquetWriterTest, parquetWriteTimestampTimeZoneWithDefault) { SCOPED_TESTVALUE_SET( @@ -787,42 +657,28 @@ DEBUG_ONLY_TEST_F(ParquetWriterTest, parquetWriteTimestampTimeZoneWithDefault) { ASSERT_EQ(tsType->timezone(), ""); }))); - const auto data = makeRowVector({makeFlatVector( - 10'000, [](auto row) { return Timestamp(row, row); })}); + const auto data = makeTimestampTestData(10'000); parquet::WriterOptions writerOptions; - writerOptions.memoryPool = leafPool_.get(); + writerOptions.memoryPool = rootPool_.get(); writerOptions.parquetWriteTimestampUnit = TimestampPrecision::kMicroseconds; - // Create an in-memory writer. - auto sink = std::make_unique( - 200 * 1024 * 1024, - dwio::common::FileSink::Options{.pool = leafPool_.get()}); - auto writer = std::make_unique( - std::move(sink), writerOptions, rootPool_, ROW({"c0"}, {TIMESTAMP()})); - writer->write(data); - writer->close(); -}; + write(data, writerOptions); +} TEST_F(ParquetWriterTest, parquetWriteWithArrowMemoryPool) { - const auto data = makeRowVector({makeFlatVector( - 10'000, [](auto row) { return Timestamp(row, row); })}); + const auto data = makeTimestampTestData(10'000); parquet::WriterOptions writerOptions; - writerOptions.memoryPool = leafPool_.get(); + writerOptions.memoryPool = rootPool_.get(); writerOptions.arrowMemoryPool = std::make_shared(); - // Create an in-memory writer. - auto sink = std::make_unique( - 200 * 1024 * 1024, - dwio::common::FileSink::Options{.pool = leafPool_.get()}); - auto writer = std::make_unique( - std::move(sink), writerOptions, rootPool_, ROW({"c0"}, {TIMESTAMP()})); - writer->write(data); - writer->close(); -}; + write(data, writerOptions); +} TEST_F(ParquetWriterTest, updateWriterOptionsFromHiveConfig) { std::unordered_map configFromFile = { - {parquet::WriterOptions::kParquetSessionWriteTimestampUnit, "3"}}; + {config::ConfigBase::toConfigKey( + parquet::WriterConfig::kParquetSessionWriteTimestampUnit), + "3"}}; const config::ConfigBase connectorConfig(std::move(configFromFile)); const config::ConfigBase connectorSessionProperties({}); @@ -906,20 +762,6 @@ TEST_F(ParquetWriterTest, dictionaryEncodedVector) { return wrappedVectors; }; - const auto writeToFile = [this](const RowVectorPtr& data) { - parquet::WriterOptions writerOptions; - writerOptions.memoryPool = leafPool_.get(); - - // Create an in-memory writer. - auto sink = std::make_unique( - 200 * 1024 * 1024, - dwio::common::FileSink::Options{.pool = leafPool_.get()}); - auto writer = std::make_unique( - std::move(sink), writerOptions, rootPool_, asRowType(data->type())); - writer->write(data); - writer->close(); - }; - // Dictionary encoded vectors with complex type. const auto size = 10'000; auto wrappedVectors = wrapDictionaryVectors({ @@ -934,7 +776,8 @@ TEST_F(ParquetWriterTest, dictionaryEncodedVector) { *leafPool_), }); - writeToFile(makeRowVector(wrappedVectors)); + auto data = makeRowVector(wrappedVectors); + write(data); // Dictionary encoded constant vector of scalar type. const auto constantVector = makeConstant(static_cast(123'456), size); @@ -944,8 +787,38 @@ TEST_F(ParquetWriterTest, dictionaryEncodedVector) { VELOX_CHECK_NOT_NULL(wrappedVector->valueVector()); EXPECT_FALSE(wrappedVector->wrappedVector()->isFlatEncoding()); - writeToFile(makeRowVector({wrappedVector})); -}; + data = makeRowVector({wrappedVector}); + write(data); +} + +TEST_F(ParquetWriterTest, allNulls) { + auto schema = ROW({"c0"}, {INTEGER()}); + const int64_t kRows = 4096; + // Create a column with all elements being null. + auto nulls = makeNulls(kRows, [](auto /*row*/) { return true; }); + auto flatVector = std::make_shared>( + pool_.get(), + schema->childAt(0), + nulls, + kRows, + /*values=*/nullptr, + std::vector()); + auto data = std::make_shared( + pool_.get(), schema, nullptr, kRows, std::vector{flatVector}); + + auto* sinkPtr = write(data); + + dwio::common::ReaderOptions readerOptions(leafPool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + auto reader = createReaderInMemory(*sinkPtr, readerOptions); + + ASSERT_EQ(reader->numberOfRows(), kRows); + ASSERT_EQ(*reader->rowType(), *schema); + + auto rowReader = createRowReaderWithSchema(std::move(reader), schema); + assertReadWithReaderAndExpected(schema, *rowReader, data, *leafPool_); +} } // namespace diff --git a/velox/dwio/parquet/thrift/CMakeLists.txt b/velox/dwio/parquet/thrift/CMakeLists.txt index ee6fe26e8c8..57aed73c01a 100644 --- a/velox/dwio/parquet/thrift/CMakeLists.txt +++ b/velox/dwio/parquet/thrift/CMakeLists.txt @@ -12,5 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -velox_add_library(velox_dwio_parquet_thrift ParquetThriftTypes.cpp) +velox_add_library( + velox_dwio_parquet_thrift + ParquetThriftTypes.cpp + HEADERS + ParquetThriftTypes.h + ThriftTransport.h +) velox_link_libraries(velox_dwio_parquet_thrift arrow thrift Boost::headers fmt::fmt) diff --git a/velox/dwio/parquet/writer/CMakeLists.txt b/velox/dwio/parquet/writer/CMakeLists.txt index 37040ec54a3..c131f59ea36 100644 --- a/velox/dwio/parquet/writer/CMakeLists.txt +++ b/velox/dwio/parquet/writer/CMakeLists.txt @@ -14,7 +14,7 @@ add_subdirectory(arrow) -velox_add_library(velox_dwio_arrow_parquet_writer Writer.cpp) +velox_add_library(velox_dwio_arrow_parquet_writer Writer.cpp HEADERS Writer.h) velox_link_libraries( velox_dwio_arrow_parquet_writer @@ -22,6 +22,9 @@ velox_link_libraries( velox_dwio_arrow_parquet_writer_util_lib velox_dwio_common velox_arrow_bridge + velox_exec arrow fmt::fmt ) + +velox_add_library(velox_writer_config INTERFACE HEADERS WriterConfig.h) diff --git a/velox/dwio/parquet/writer/Writer.cpp b/velox/dwio/parquet/writer/Writer.cpp index 2444fa247cc..8cf29529238 100644 --- a/velox/dwio/parquet/writer/Writer.cpp +++ b/velox/dwio/parquet/writer/Writer.cpp @@ -15,6 +15,10 @@ */ #include "velox/dwio/parquet/writer/Writer.h" + +#include +#include + #include #include #include @@ -22,6 +26,7 @@ #include "velox/common/config/Config.h" #include "velox/common/testutil/TestValue.h" #include "velox/core/QueryConfig.h" +#include "velox/dwio/parquet/writer/arrow/ArrowSchema.h" #include "velox/dwio/parquet/writer/arrow/Properties.h" #include "velox/dwio/parquet/writer/arrow/Writer.h" #include "velox/exec/MemoryReclaimer.h" @@ -133,13 +138,13 @@ std::shared_ptr getArrowParquetWriterOptions( WriterProperties::Builder* properties = &builder; if (options.enableDictionary.value_or( facebook::velox::parquet::arrow::DEFAULT_IS_DICTIONARY_ENABLED)) { - properties = properties->enable_dictionary(); - properties = properties->dictionary_pagesize_limit( + properties = properties->enableDictionary(); + properties = properties->dictionaryPagesizeLimit( options.dictionaryPageSizeLimit.value_or( facebook::velox::parquet::arrow:: DEFAULT_DICTIONARY_PAGE_SIZE_LIMIT)); } else { - properties = properties->disable_dictionary(); + properties = properties->disableDictionary(); } properties = properties->compression(getArrowParquetCompression( options.compressionKind.value_or(common::CompressionKind_NONE))); @@ -149,30 +154,30 @@ std::shared_ptr getArrowParquetWriterOptions( getArrowParquetCompression(columnCompressionValues.second)); } properties = properties->encoding(options.encoding); - properties = properties->data_pagesize(options.dataPageSize.value_or( + properties = properties->dataPagesize(options.dataPageSize.value_or( facebook::velox::parquet::arrow::kDefaultDataPageSize)); - properties = properties->write_batch_size(options.batchSize.value_or( + properties = properties->writeBatchSize(options.batchSize.value_or( facebook::velox::parquet::arrow::DEFAULT_WRITE_BATCH_SIZE)); - properties = properties->max_row_group_length( + properties = properties->maxRowGroupLength( static_cast(flushPolicy->rowsInRowGroup())); - properties = properties->codec_options(options.codecOptions); - properties = properties->enable_store_decimal_as_integer(); + properties = properties->codecOptions(options.codecOptions); + properties = properties->enableStoreDecimalAsInteger(); if (options.useParquetDataPageV2.value_or(false)) { - properties = - properties->data_page_version(arrow::ParquetDataPageVersion::V2); + properties = properties->dataPageVersion(arrow::ParquetDataPageVersion::V2); } else { - properties = - properties->data_page_version(arrow::ParquetDataPageVersion::V1); + properties = properties->dataPageVersion(arrow::ParquetDataPageVersion::V1); } if (options.createdBy.has_value()) { - properties = properties->created_by(options.createdBy.value()); + properties = properties->createdBy(options.createdBy.value()); } return properties->build(); } -void validateSchemaRecursive(const RowTypePtr& schema) { - // Check the schema's field names is not empty and unique. - VELOX_USER_CHECK_NOT_NULL(schema, "Field schema must not be empty."); +void validateSchemaRecursive( + const RowTypePtr& schema, + const std::vector& parquetFieldIds) { + // Check the schema's field names are not empty and unique. + VELOX_USER_CHECK_NOT_NULL(schema, "Schema must not be empty."); const auto& fieldNames = schema->names(); folly::F14FastSet uniqueNames; @@ -185,137 +190,163 @@ void validateSchemaRecursive(const RowTypePtr& schema) { name); } + if (!parquetFieldIds.empty()) { + VELOX_USER_CHECK_EQ(parquetFieldIds.size(), schema->size()); + } + for (auto i = 0; i < schema->size(); ++i) { - if (auto childSchema = - std::dynamic_pointer_cast(schema->childAt(i))) { - validateSchemaRecursive(childSchema); + const auto& childType = schema->childAt(i); + const auto& childFieldIds = + parquetFieldIds.empty() ? parquetFieldIds : parquetFieldIds[i].children; + + if (childType->isRow()) { + validateSchemaRecursive( + std::dynamic_pointer_cast(childType), childFieldIds); + } else if (childType->isArray()) { + if (!parquetFieldIds.empty()) { + VELOX_USER_CHECK_EQ(parquetFieldIds[i].children.size(), 1); + } + const auto& elementType = childType->asArray().elementType(); + if (elementType->isRow()) { + validateSchemaRecursive( + std::dynamic_pointer_cast(elementType), + childFieldIds.empty() ? childFieldIds : childFieldIds[0].children); + } + } else if (childType->isMap()) { + if (!parquetFieldIds.empty()) { + VELOX_USER_CHECK_EQ(parquetFieldIds[i].children.size(), 2); + } + const auto& mapType = childType->asMap(); + if (mapType.keyType()->isRow()) { + validateSchemaRecursive( + std::dynamic_pointer_cast(mapType.keyType()), + childFieldIds.empty() ? childFieldIds : childFieldIds[0].children); + } + if (mapType.valueType()->isRow()) { + validateSchemaRecursive( + std::dynamic_pointer_cast(mapType.valueType()), + childFieldIds.empty() ? childFieldIds : childFieldIds[1].children); + } } } } -std::shared_ptr<::arrow::Field> updateFieldNameRecursive( +std::shared_ptr<::arrow::Field> updateFieldNameAndIdRecursive( const std::shared_ptr<::arrow::Field>& field, const Type& type, + const ParquetFieldId* fieldId, const std::string& name = "") { + auto newField = name.empty() ? field : field->WithName(name); + + if (fieldId) { + newField = + newField->WithMetadata(arrow::arrow::fieldIdMetadata(fieldId->fieldId)); + } + if (type.isRow()) { auto& rowType = type.asRow(); - auto newField = field->WithName(name); auto structType = std::dynamic_pointer_cast<::arrow::StructType>(newField->type()); auto childrenSize = rowType.size(); + VELOX_CHECK(!fieldId || childrenSize <= fieldId->children.size()); std::vector> newFields; newFields.reserve(childrenSize); - for (auto i = 0; i < childrenSize; i++) { - newFields.push_back(updateFieldNameRecursive( - structType->fields()[i], *rowType.childAt(i), rowType.nameOf(i))); + for (auto i = 0; i < childrenSize; ++i) { + const auto* childSetting = fieldId ? &fieldId->children.at(i) : nullptr; + newFields.push_back(updateFieldNameAndIdRecursive( + structType->fields()[i], + *rowType.childAt(i), + childSetting, + rowType.nameOf(i))); } - return newField->WithType(::arrow::struct_(newFields)); + newField = newField->WithType(::arrow::struct_(newFields)); } else if (type.isArray()) { - auto newField = field->WithName(name); auto listType = std::dynamic_pointer_cast<::arrow::BaseListType>(newField->type()); auto elementType = type.asArray().elementType(); auto elementField = listType->value_field(); - return newField->WithType( - ::arrow::list(updateFieldNameRecursive(elementField, *elementType))); + const auto* childSetting = fieldId ? &fieldId->children.at(0) : nullptr; + auto updatedElementField = + updateFieldNameAndIdRecursive(elementField, *elementType, childSetting); + newField = newField->WithType(::arrow::list(updatedElementField)); } else if (type.isMap()) { auto mapType = type.asMap(); - auto newField = field->WithName(name); auto arrowMapType = std::dynamic_pointer_cast<::arrow::MapType>(newField->type()); - auto newKeyField = - updateFieldNameRecursive(arrowMapType->key_field(), *mapType.keyType()); - auto newValueField = updateFieldNameRecursive( - arrowMapType->item_field(), *mapType.valueType()); - return newField->WithType( - ::arrow::map(newKeyField->type(), newValueField->type())); - } else if (name != "") { - return field->WithName(name); - } else { - return field; + const auto* keySetting = fieldId ? &fieldId->children.at(0) : nullptr; + const auto* valueSetting = fieldId ? &fieldId->children.at(1) : nullptr; + auto newKeyField = updateFieldNameAndIdRecursive( + arrowMapType->key_field(), *mapType.keyType(), keySetting); + auto newValueField = updateFieldNameAndIdRecursive( + arrowMapType->item_field(), *mapType.valueType(), valueSetting); + newField = newField->WithType( + std::make_shared<::arrow::MapType>(newKeyField, newValueField)); } -} -std::optional getTimestampUnit( - const config::ConfigBase& config, - const char* configKey) { - if (const auto unit = config.get(configKey)) { - VELOX_CHECK( - unit == 3 /*milli*/ || unit == 6 /*micro*/ || unit == 9 /*nano*/, - "Invalid timestamp unit: {}", - unit.value()); - return std::optional(static_cast(unit.value())); - } - return std::nullopt; + return newField; } -std::optional getTimestampTimeZone( - const config::ConfigBase& config, - const char* configKey) { - if (const auto timezone = config.get(configKey)) { - return timezone.value(); +std::optional toTimestampPrecision( + std::optional unit) { + if (!unit) { + return std::nullopt; } - return std::nullopt; + VELOX_CHECK( + *unit == 3 /*milli*/ || *unit == 6 /*micro*/ || *unit == 9 /*nano*/, + "Invalid timestamp unit: {}", + *unit); + return static_cast(*unit); } -std::optional isParquetEnableDictionary( - const config::ConfigBase& config, - const char* configKey) { - try { - if (const auto enableDictionary = config.get(configKey)) { - return enableDictionary.value(); - } - } catch (const folly::ConversionError& e) { - VELOX_USER_FAIL( - "Invalid parquet writer enable dictionary option: {}", e.what()); - } - return std::nullopt; +// Converts a string to TimestampPrecision. Accepts numeric values "3" (milli), +// "6" (micro), or "9" (nano). +TimestampPrecision stringToTimestampPrecision(const std::string& value) { + return toTimestampPrecision(std::optional{folly::to(value)}).value(); } -std::optional getParquetDataPageVersion( - const config::ConfigBase& config, - const char* configKey) { - if (const auto version = config.get(configKey)) { - if (version == "V1") { - return false; - } else if (version == "V2") { - return true; - } else { - VELOX_FAIL("Unsupported parquet datapage version {}", version.value()); - } +std::optional isParquetV2(std::optional version) { + if (!version) { + return std::nullopt; + } + if (version == "V1") { + return false; } - return std::nullopt; + if (version == "V2") { + return true; + } + VELOX_FAIL("Unsupported parquet datapage version {}", *version); } -std::optional getParquetPageSize( - const config::ConfigBase& config, - const char* configKey) { - if (const auto pageSize = config.get(configKey)) { - return config::toCapacity(pageSize.value(), config::CapacityUnit::BYTE); +std::optional toParquetPageSize(std::optional pageSize) { + if (!pageSize) { + return std::nullopt; } - return std::nullopt; + return config::toCapacity(*pageSize, config::CapacityUnit::BYTE); } -std::optional getParquetBatchSize( - const config::ConfigBase& config, - const char* configKey) { +std::optional toParquetEnableDictionary( + std::optional enableDictionary) { + if (!enableDictionary) { + return std::nullopt; + } try { - if (const auto batchSize = config.get(configKey)) { - return batchSize.value(); - } - } catch (const folly::ConversionError& e) { - VELOX_USER_FAIL("Invalid parquet writer batch size: {}", e.what()); + return folly::to(*enableDictionary); + } catch (const std::exception& e) { + VELOX_USER_FAIL( + "Invalid parquet writer enable dictionary option: {}", e.what()); } - return std::nullopt; } -std::optional getParquetCreatedBy( - const config::ConfigBase& config, - const char* configKey) { - if (config.get(configKey).has_value()) { - return config.get(configKey).value(); +std::optional toParquetBatchSize( + std::optional batchSize) { + if (!batchSize) { + return std::nullopt; + } + try { + return folly::to(*batchSize); + } catch (const std::exception& e) { + VELOX_USER_FAIL("Invalid parquet writer batch size: {}", e.what()); } - return std::nullopt; } } // namespace @@ -327,13 +358,14 @@ Writer::Writer( RowTypePtr schema) : pool_(std::move(pool)), generalPool_{pool_->addLeafChild(".general")}, - stream_(std::make_shared( - std::move(sink), - *generalPool_, - options.bufferGrowRatio)), + stream_( + std::make_shared( + std::move(sink), + *generalPool_, + options.bufferGrowRatio)), arrowContext_(std::make_shared()), schema_(std::move(schema)) { - validateSchemaRecursive(schema_); + validateSchemaRecursive(schema_, options.parquetFieldIds); if (options.flushPolicyFactory) { castUniquePointer(options.flushPolicyFactory(), flushPolicy_); @@ -351,6 +383,7 @@ Writer::Writer( setMemoryReclaimers(); writeInt96AsTimestamp_ = options.writeInt96AsTimestamp; arrowMemoryPool_ = options.arrowMemoryPool; + parquetFieldIds_ = std::move(options.parquetFieldIds); } Writer::Writer( @@ -360,9 +393,10 @@ Writer::Writer( : Writer{ std::move(sink), options, - options.memoryPool->addAggregateChild(fmt::format( - "writer_node_{}", - folly::to(folly::Random::rand64()))), + options.memoryPool->addAggregateChild( + fmt::format( + "writer_node_{}", + folly::to(folly::Random::rand64()))), std::move(schema)} {} void Writer::flush() { @@ -370,12 +404,12 @@ void Writer::flush() { if (!arrowContext_->writer) { ArrowWriterProperties::Builder builder; if (writeInt96AsTimestamp_) { - builder.enable_deprecated_int96_timestamps(); + builder.enableDeprecatedInt96Timestamps(); } auto arrowProperties = builder.build(); PARQUET_ASSIGN_OR_THROW( arrowContext_->writer, - FileWriter::Open( + FileWriter::open( *arrowContext_->schema.get(), arrowMemoryPool_.get(), stream_, @@ -397,7 +431,7 @@ void Writer::flush() { arrowContext_->schema, std::move(chunks), static_cast(arrowContext_->stagingRows)); - PARQUET_THROW_NOT_OK(arrowContext_->writer->WriteTable( + PARQUET_THROW_NOT_OK(arrowContext_->writer->writeTable( *table, static_cast(flushPolicy_->rowsInRowGroup()))); PARQUET_THROW_NOT_OK(stream_->Flush()); for (auto& chunk : arrowContext_->stagingChunks) { @@ -445,9 +479,15 @@ void Writer::write(const VectorPtr& data) { "facebook::velox::parquet::Writer::write", arrowSchema.get()); std::vector> newFields; auto childSize = schema_->size(); + if (!parquetFieldIds_.empty()) { + VELOX_CHECK(childSize == parquetFieldIds_.size()); + } for (auto i = 0; i < childSize; i++) { - newFields.push_back(updateFieldNameRecursive( - arrowSchema->fields()[i], *schema_->childAt(i), schema_->nameOf(i))); + newFields.push_back(updateFieldNameAndIdRecursive( + arrowSchema->fields()[i], + *schema_->childAt(i), + !parquetFieldIds_.empty() ? &parquetFieldIds_.at(i) : nullptr, + schema_->nameOf(i))); } PARQUET_ASSIGN_OR_THROW( @@ -478,24 +518,30 @@ void Writer::write(const VectorPtr& data) { } bool Writer::isCodecAvailable(common::CompressionKind compression) { - return arrow::util::Codec::IsAvailable( + return arrow::util::Codec::isAvailable( getArrowParquetCompression(compression)); } void Writer::newRowGroup(int32_t numRows) { - PARQUET_THROW_NOT_OK(arrowContext_->writer->NewRowGroup(numRows)); + PARQUET_THROW_NOT_OK(arrowContext_->writer->newRowGroup(numRows)); } -void Writer::close() { +std::unique_ptr Writer::close() { flush(); + std::unique_ptr parquetFileMetadata; if (arrowContext_->writer) { - PARQUET_THROW_NOT_OK(arrowContext_->writer->Close()); + PARQUET_THROW_NOT_OK(arrowContext_->writer->close()); + parquetFileMetadata = std::make_unique( + arrowContext_->writer->metadata()); arrowContext_->writer.reset(); } + PARQUET_THROW_NOT_OK(stream_->Close()); arrowContext_->stagingChunks.clear(); + + return parquetFileMetadata; } void Writer::abort() { @@ -560,11 +606,29 @@ void WriterOptions::processConfigs( VELOX_CHECK_NOT_NULL( parquetWriterOptions, "Expected a Parquet WriterOptions object."); + // Check serdeParameters for timestamp settings first (highest priority). + auto serdeTimestampUnitIt = + serdeParameters.find(WriterConfig::kParquetSerdeTimestampUnit); + if (serdeTimestampUnitIt != serdeParameters.end()) { + parquetWriteTimestampUnit = + stringToTimestampPrecision(serdeTimestampUnitIt->second); + } + + auto serdeTimestampTimezoneIt = + serdeParameters.find(WriterConfig::kParquetSerdeTimestampTimezone); + if (serdeTimestampTimezoneIt != serdeParameters.end()) { + // Empty string means no timezone conversion (nullopt). + if (serdeTimestampTimezoneIt->second.empty()) { + parquetWriteTimestampTimeZone = std::nullopt; + } else { + parquetWriteTimestampTimeZone = serdeTimestampTimezoneIt->second; + } + } + if (!parquetWriteTimestampUnit) { parquetWriteTimestampUnit = - getTimestampUnit(session, kParquetSessionWriteTimestampUnit).has_value() - ? getTimestampUnit(session, kParquetSessionWriteTimestampUnit) - : getTimestampUnit(connectorConfig, kParquetSessionWriteTimestampUnit); + toTimestampPrecision(session.getWithFallback( + WriterConfig::kParquetSessionWriteTimestampUnit, connectorConfig)); } if (!parquetWriteTimestampTimeZone) { parquetWriteTimestampTimeZone = parquetWriterOptions->sessionTimezoneName; @@ -572,50 +636,56 @@ void WriterOptions::processConfigs( if (!enableDictionary) { enableDictionary = - isParquetEnableDictionary(session, kParquetSessionEnableDictionary) - .has_value() - ? isParquetEnableDictionary(session, kParquetSessionEnableDictionary) - : isParquetEnableDictionary( - connectorConfig, kParquetHiveConnectorEnableDictionary); + toParquetEnableDictionary(session.getWithFallback( + WriterConfig::kParquetSessionEnableDictionary, connectorConfig)); } if (!dictionaryPageSizeLimit) { dictionaryPageSizeLimit = - getParquetPageSize(session, kParquetSessionDictionaryPageSizeLimit) - .has_value() - ? getParquetPageSize(session, kParquetSessionDictionaryPageSizeLimit) - : getParquetPageSize( - connectorConfig, kParquetHiveConnectorDictionaryPageSizeLimit); + toParquetPageSize(session.getWithFallback( + WriterConfig::kParquetSessionDictionaryPageSizeLimit, + connectorConfig)); } if (!useParquetDataPageV2) { - useParquetDataPageV2 = - getParquetDataPageVersion(session, kParquetSessionDataPageVersion) - .has_value() - ? getParquetDataPageVersion(session, kParquetSessionDataPageVersion) - : getParquetDataPageVersion( - connectorConfig, kParquetHiveConnectorDataPageVersion); + useParquetDataPageV2 = isParquetV2(session.getWithFallback( + WriterConfig::kParquetSessionDataPageVersion, connectorConfig)); } if (!dataPageSize) { - dataPageSize = - getParquetPageSize(session, kParquetSessionWritePageSize).has_value() - ? getParquetPageSize(session, kParquetSessionWritePageSize) - : getParquetPageSize( - connectorConfig, kParquetHiveConnectorWritePageSize); + dataPageSize = toParquetPageSize(session.getWithFallback( + WriterConfig::kParquetSessionWritePageSize, connectorConfig)); } if (!batchSize) { - batchSize = - getParquetBatchSize(session, kParquetSessionWriteBatchSize).has_value() - ? getParquetBatchSize(session, kParquetSessionWriteBatchSize) - : getParquetBatchSize( - connectorConfig, kParquetHiveConnectorWriteBatchSize); + batchSize = toParquetBatchSize(session.getWithFallback( + WriterConfig::kParquetSessionWriteBatchSize, connectorConfig)); } if (!createdBy) { - createdBy = - getParquetCreatedBy(connectorConfig, kParquetHiveConnectorCreatedBy); + createdBy = session.getWithFallback( + WriterConfig::kParquetHiveConnectorCreatedBy, connectorConfig); + } + + // Parquet only updates ioStats_->rawBytesWritten() when a row group is + // flushed. With the default flush policy (1M rows / 128MB), small + // maxTargetFileBytes_ would never trigger rotation because rawBytesWritten() + // stays at 0 while data is buffered. To honor maxTargetFileBytes_, cap the + // row group byte threshold so we flush earlier and rawBytesWritten() grows + // during writes. + auto maxTargetFileSize = + toParquetPageSize(session.getWithFallback( + WriterConfig::kParquetSessionMaxTargetFileSize, connectorConfig)); + if (maxTargetFileSize.has_value()) { + if (!flushPolicyFactory) { + auto bytesInRowGroup = std::min( + DefaultFlushPolicy::kDefaultBytesInRowGroup, + maxTargetFileSize.value()); + flushPolicyFactory = [bytesInRowGroup]() { + return std::make_unique( + DefaultFlushPolicy::kDefaultRowsInGroup, bytesInRowGroup); + }; + } } } diff --git a/velox/dwio/parquet/writer/Writer.h b/velox/dwio/parquet/writer/Writer.h index 7deca7ee670..3046424a6f9 100644 --- a/velox/dwio/parquet/writer/Writer.h +++ b/velox/dwio/parquet/writer/Writer.h @@ -20,11 +20,15 @@ #include "velox/common/compression/Compression.h" #include "velox/common/config/Config.h" #include "velox/dwio/common/DataBuffer.h" +#include "velox/dwio/common/FileMetadata.h" #include "velox/dwio/common/FileSink.h" #include "velox/dwio/common/FlushPolicy.h" #include "velox/dwio/common/Options.h" #include "velox/dwio/common/Writer.h" #include "velox/dwio/common/WriterFactory.h" +#include "velox/dwio/parquet/ParquetFieldId.h" +#include "velox/dwio/parquet/writer/WriterConfig.h" +#include "velox/dwio/parquet/writer/arrow/Metadata.h" #include "velox/dwio/parquet/writer/arrow/Types.h" #include "velox/dwio/parquet/writer/arrow/util/Compression.h" #include "velox/vector/ComplexVector.h" @@ -38,13 +42,32 @@ class ArrowDataBufferSink; struct ArrowContext; +/// Parquet-specific file metadata wrapper. Provides access to the underlying +/// arrow::FileMetaData. +class ParquetFileMetadata : public dwio::common::FileMetadata { + public: + explicit ParquetFileMetadata(std::shared_ptr metadata) + : metadata_(std::move(metadata)) {} + + std::shared_ptr arrowMetadata() const { + return metadata_; + } + + private: + std::shared_ptr metadata_; +}; + class DefaultFlushPolicy : public dwio::common::FlushPolicy { public: DefaultFlushPolicy() - : rowsInRowGroup_(1'024 * 1'024), bytesInRowGroup_(128 * 1'024 * 1'024) {} + : rowsInRowGroup_(kDefaultRowsInGroup), + bytesInRowGroup_(kDefaultBytesInRowGroup) {} DefaultFlushPolicy(uint64_t rowsInRowGroup, int64_t bytesInRowGroup) : rowsInRowGroup_(rowsInRowGroup), bytesInRowGroup_(bytesInRowGroup) {} + static constexpr uint64_t kDefaultRowsInGroup{1'024 * 1'024}; + static constexpr int64_t kDefaultBytesInRowGroup{128 * 1'024 * 1'024}; + bool shouldFlush( const dwio::common::StripeProgress& stripeProgress) override { return stripeProgress.stripeRowCount >= rowsInRowGroup_ || @@ -94,7 +117,7 @@ struct WriterOptions : public dwio::common::WriterOptions { // folly/FBVector(https://github.com/facebook/folly/blob/main/folly/docs/FBVector.md#memory-handling). double bufferGrowRatio = 1.5; - arrow::Encoding::type encoding = arrow::Encoding::PLAIN; + arrow::Encoding::type encoding = arrow::Encoding::kPlain; std::shared_ptr codecOptions; std::unordered_map @@ -116,36 +139,11 @@ struct WriterOptions : public dwio::common::WriterOptions { std::shared_ptr arrowMemoryPool; - // Parsing session and hive configs. - - // This isn't a typo; session and hive connector config names are different - // ('_' vs '-'). - static constexpr const char* kParquetSessionWriteTimestampUnit = - "hive.parquet.writer.timestamp_unit"; - static constexpr const char* kParquetHiveConnectorWriteTimestampUnit = - "hive.parquet.writer.timestamp-unit"; - static constexpr const char* kParquetSessionEnableDictionary = - "hive.parquet.writer.enable_dictionary"; - static constexpr const char* kParquetHiveConnectorEnableDictionary = - "hive.parquet.writer.enable-dictionary"; - static constexpr const char* kParquetSessionDictionaryPageSizeLimit = - "hive.parquet.writer.dictionary_page_size_limit"; - static constexpr const char* kParquetHiveConnectorDictionaryPageSizeLimit = - "hive.parquet.writer.dictionary-page-size-limit"; - static constexpr const char* kParquetSessionDataPageVersion = - "hive.parquet.writer.datapage_version"; - static constexpr const char* kParquetHiveConnectorDataPageVersion = - "hive.parquet.writer.datapage-version"; - static constexpr const char* kParquetSessionWritePageSize = - "hive.parquet.writer.page_size"; - static constexpr const char* kParquetHiveConnectorWritePageSize = - "hive.parquet.writer.page-size"; - static constexpr const char* kParquetSessionWriteBatchSize = - "hive.parquet.writer.batch_size"; - static constexpr const char* kParquetHiveConnectorWriteBatchSize = - "hive.parquet.writer.batch-size"; - static constexpr const char* kParquetHiveConnectorCreatedBy = - "hive.parquet.writer.created-by"; + /// Optional field IDs to assign to columns in the Parquet schema. + /// If provided, the writer will use these IDs for the schema fields. + /// If not provided, the field_id will be -1. + /// The structure should match the schema hierarchy with nested children. + std::vector parquetFieldIds; // Process hive connector and session configs. void processConfigs( @@ -188,10 +186,11 @@ class Writer : public dwio::common::Writer { return true; } - // Closes 'this', After close, data can no longer be added and the completed + // Closes 'this'. After close, data can no longer be added and the completed // Parquet file is flushed into 'sink' provided at construction. 'sink' stays - // live until destruction of 'this'. - void close() override; + // live until destruction of 'this'. Returns file metadata, or null if no + // metadata is available (e.g. for an empty file). + std::unique_ptr close() override; void abort() override; @@ -214,6 +213,8 @@ class Writer : public dwio::common::Writer { std::shared_ptr arrowContext_; + std::vector parquetFieldIds_; + std::unique_ptr flushPolicy_; const RowTypePtr schema_; diff --git a/velox/dwio/parquet/writer/WriterConfig.h b/velox/dwio/parquet/writer/WriterConfig.h new file mode 100644 index 00000000000..ee8788f74b9 --- /dev/null +++ b/velox/dwio/parquet/writer/WriterConfig.h @@ -0,0 +1,75 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace facebook::velox::parquet { + +/// Config constants for the Parquet writer. +/// +/// IMPORTANT: These constants are kept in a separate header rather than in +/// Writer.h because Gluten's WholeStageResultIterator.cc needs access to these +/// configuration constants but cannot include Writer.h due to Arrow header +/// conflicts. This separation allows external code to reference these constants +/// without pulling in Arrow dependencies. +struct WriterConfig { + // Parsing session and hive configs. + + // This isn't a typo; session and hive connector config names are different + // ('_' vs '-'). + static constexpr const char* kParquetSessionWriteTimestampUnit = + "hive.parquet.writer.timestamp_unit"; + static constexpr const char* kParquetHiveConnectorWriteTimestampUnit = + "hive.parquet.writer.timestamp-unit"; + static constexpr const char* kParquetSessionEnableDictionary = + "hive.parquet.writer.enable_dictionary"; + static constexpr const char* kParquetHiveConnectorEnableDictionary = + "hive.parquet.writer.enable-dictionary"; + static constexpr const char* kParquetSessionDictionaryPageSizeLimit = + "hive.parquet.writer.dictionary_page_size_limit"; + static constexpr const char* kParquetHiveConnectorDictionaryPageSizeLimit = + "hive.parquet.writer.dictionary-page-size-limit"; + static constexpr const char* kParquetSessionDataPageVersion = + "hive.parquet.writer.datapage_version"; + static constexpr const char* kParquetHiveConnectorDataPageVersion = + "hive.parquet.writer.datapage-version"; + static constexpr const char* kParquetSessionWritePageSize = + "hive.parquet.writer.page_size"; + static constexpr const char* kParquetHiveConnectorWritePageSize = + "hive.parquet.writer.page-size"; + static constexpr const char* kParquetSessionWriteBatchSize = + "hive.parquet.writer.batch_size"; + static constexpr const char* kParquetHiveConnectorWriteBatchSize = + "hive.parquet.writer.batch-size"; + static constexpr const char* kParquetHiveConnectorCreatedBy = + "hive.parquet.writer.created-by"; + + // Use the same property name from HiveConfig::kMaxTargetFileSize. + static constexpr const char* kParquetConnectorMaxTargetFileSize = + "max-target-file-size"; + static constexpr const char* kParquetSessionMaxTargetFileSize = + "max_target_file_size"; + // Serde parameter keys for timestamp settings. These can be set via + // serdeParameters map to override the default timestamp behavior. + // The timezone key accepts a timezone string or empty string to disable + // timezone conversion. + static constexpr const char* kParquetSerdeTimestampUnit = + "parquet.writer.timestamp.unit"; + static constexpr const char* kParquetSerdeTimestampTimezone = + "parquet.writer.timestamp.timezone"; +}; + +} // namespace facebook::velox::parquet diff --git a/velox/dwio/parquet/writer/arrow/ArrowSchema.cpp b/velox/dwio/parquet/writer/arrow/ArrowSchema.cpp index de5a2382198..a514c33a5da 100644 --- a/velox/dwio/parquet/writer/arrow/ArrowSchema.cpp +++ b/velox/dwio/parquet/writer/arrow/ArrowSchema.cpp @@ -59,35 +59,35 @@ using schema::PrimitiveNode; using ParquetType = Type; -// ---------------------------------------------------------------------- -// Parquet to Arrow schema conversion +// ----------------------------------------------------------------------. +// Parquet to Arrow schema conversion. namespace { /// Increments levels according to the cardinality of node. -void IncrementLevels(LevelInfo& current_levels, const schema::Node& node) { - if (node.is_repeated()) { - current_levels.IncrementRepeated(); +void incrementLevels(LevelInfo& currentLevels, const schema::Node& node) { + if (node.isRepeated()) { + currentLevels.IncrementRepeated(); return; } - if (node.is_optional()) { - current_levels.IncrementOptional(); + if (node.isOptional()) { + currentLevels.IncrementOptional(); return; } } -/// Like std::string_view::ends_with in C++20 -inline bool EndsWith(std::string_view s, std::string_view suffix) { +/// Like std::string_view::ends_with in C++20. +inline bool endsWith(std::string_view s, std::string_view suffix) { return s.length() >= suffix.length() && (s.empty() || s.substr(s.length() - suffix.length()) == suffix); } namespace detail { template -struct can_to_chars : public std::false_type {}; +struct CanToChars : public std::false_type {}; template -struct can_to_chars< +struct CanToChars< T, std::void_t(), @@ -101,9 +101,9 @@ struct can_to_chars< /// This is useful as some C++ libraries do not implement all specified /// overloads for std::to_chars. template -inline constexpr bool have_to_chars = detail::can_to_chars::value; +inline constexpr bool haveToChars = detail::CanToChars::value; -/// \brief An ergonomic wrapper around std::to_chars, returning a std::string +/// \brief An ergonomic wrapper around std::to_chars, returning a std::string. /// /// For most inputs, the std::string result will not incur any heap allocation /// thanks to small string optimization. @@ -111,10 +111,10 @@ inline constexpr bool have_to_chars = detail::can_to_chars::value; /// Compared to std::to_string, this function gives locale-agnostic results /// and might also be faster. template -std::string ToChars(T value, Args&&... args) { - if constexpr (!have_to_chars) { - // Some C++ standard libraries do not yet implement std::to_chars for all - // types, in which case we have to fallback to std::string. +std::string toChars(T value, Args&&... args) { + if constexpr (!haveToChars) { + // Some C++ standard libraries do not yet implement std::to_chars for all. + // Types, in which case we have to fallback to std::string. return std::to_string(value); } else { // According to various sources, the GNU libstdc++ and Microsoft's C++ STL @@ -134,86 +134,86 @@ std::string ToChars(T value, Args&&... args) { } } -Repetition::type RepetitionFromNullable(bool is_nullable) { - return is_nullable ? Repetition::OPTIONAL : Repetition::REQUIRED; +Repetition::type repetitionFromNullable(bool isNullable) { + return isNullable ? Repetition::kOptional : Repetition::kRequired; } -Status FieldToNode( +Status fieldToNode( const std::string& name, const std::shared_ptr& field, const WriterProperties& properties, - const ArrowWriterProperties& arrow_properties, + const ArrowWriterProperties& arrowProperties, NodePtr* out); -Status ListToNode( +Status listToNode( const std::shared_ptr<::arrow::BaseListType>& type, const std::string& name, bool nullable, - int field_id, + int fieldId, const WriterProperties& properties, - const ArrowWriterProperties& arrow_properties, + const ArrowWriterProperties& arrowProperties, NodePtr* out) { NodePtr element; - std::string value_name = arrow_properties.compliant_nested_types() + std::string valueName = arrowProperties.compliantNestedTypes() ? "element" : type->value_field()->name(); - RETURN_NOT_OK(FieldToNode( - value_name, type->value_field(), properties, arrow_properties, &element)); + RETURN_NOT_OK(fieldToNode( + valueName, type->value_field(), properties, arrowProperties, &element)); - NodePtr list = GroupNode::Make("list", Repetition::REPEATED, {element}); - *out = GroupNode::Make( + NodePtr List = GroupNode::make("list", Repetition::kRepeated, {element}); + *out = GroupNode::make( name, - RepetitionFromNullable(nullable), - {list}, - LogicalType::List(), - field_id); + repetitionFromNullable(nullable), + {List}, + LogicalType::list(), + fieldId); return Status::OK(); } -Status MapToNode( +Status mapToNode( const std::shared_ptr<::arrow::MapType>& type, const std::string& name, bool nullable, - int field_id, + int fieldId, const WriterProperties& properties, - const ArrowWriterProperties& arrow_properties, + const ArrowWriterProperties& arrowProperties, NodePtr* out) { // TODO: Should we offer a non-compliant mode that forwards the type names? - NodePtr key_node; - RETURN_NOT_OK(FieldToNode( - "key", type->key_field(), properties, arrow_properties, &key_node)); + NodePtr keyNode; + RETURN_NOT_OK(fieldToNode( + "key", type->key_field(), properties, arrowProperties, &keyNode)); - NodePtr value_node; - RETURN_NOT_OK(FieldToNode( - "value", type->item_field(), properties, arrow_properties, &value_node)); + NodePtr valueNode; + RETURN_NOT_OK(fieldToNode( + "value", type->item_field(), properties, arrowProperties, &valueNode)); - NodePtr key_value = GroupNode::Make( - "key_value", Repetition::REPEATED, {key_node, value_node}); - *out = GroupNode::Make( + NodePtr keyValue = + GroupNode::make("key_value", Repetition::kRepeated, {keyNode, valueNode}); + *out = GroupNode::make( name, - RepetitionFromNullable(nullable), - {key_value}, - LogicalType::Map(), - field_id); + repetitionFromNullable(nullable), + {keyValue}, + LogicalType::map(), + fieldId); return Status::OK(); } -Status StructToNode( +Status structToNode( const std::shared_ptr<::arrow::StructType>& type, const std::string& name, bool nullable, - int field_id, + int fieldId, const WriterProperties& properties, - const ArrowWriterProperties& arrow_properties, + const ArrowWriterProperties& arrowProperties, NodePtr* out) { std::vector children(type->num_fields()); if (type->num_fields() != 0) { for (int i = 0; i < type->num_fields(); i++) { - RETURN_NOT_OK(FieldToNode( + RETURN_NOT_OK(fieldToNode( type->field(i)->name(), type->field(i), properties, - arrow_properties, + arrowProperties, &children[i])); } } else { @@ -227,71 +227,65 @@ Status StructToNode( "Consider adding a dummy child field."); } - *out = GroupNode::Make( - name, RepetitionFromNullable(nullable), children, nullptr, field_id); + *out = GroupNode::make( + name, repetitionFromNullable(nullable), children, nullptr, fieldId); return Status::OK(); } static std::shared_ptr -TimestampLogicalTypeFromArrowTimestamp( - const ::arrow::TimestampType& timestamp_type, - ::arrow::TimeUnit::type time_unit) { - const bool utc = !(timestamp_type.timezone().empty()); +timestampLogicalTypeFromArrowTimestamp( + const ::arrow::TimestampType& timestampType, + ::arrow::TimeUnit::type timeUnit) { + const bool utc = !(timestampType.timezone().empty()); // ARROW-5878(wesm): for forward compatibility reasons, and because // there's no other way to signal to old readers that values are // timestamps, we force the ConvertedType field to be set to the // corresponding TIMESTAMP_* value. This does cause some ambiguity // as Parquet readers have not been consistent about the // interpretation of TIMESTAMP_* values as being UTC-normalized. - switch (time_unit) { + switch (timeUnit) { case ::arrow::TimeUnit::MILLI: - return LogicalType::Timestamp( - utc, - LogicalType::TimeUnit::MILLIS, - /*is_from_converted_type=*/false, - /*force_set_converted_type=*/true); + return LogicalType::timestamp( + utc, LogicalType::TimeUnit::kMillis, false, true); case ::arrow::TimeUnit::MICRO: - return LogicalType::Timestamp( - utc, - LogicalType::TimeUnit::MICROS, - /*is_from_converted_type=*/false, - /*force_set_converted_type=*/true); + return LogicalType::timestamp( + utc, LogicalType::TimeUnit::kMicros, false, true); case ::arrow::TimeUnit::NANO: - return LogicalType::Timestamp(utc, LogicalType::TimeUnit::NANOS); + return LogicalType::timestamp(utc, LogicalType::TimeUnit::kNanos); case ::arrow::TimeUnit::SECOND: // No equivalent parquet logical type. break; } - return LogicalType::None(); + return LogicalType::none(); } -static Status GetTimestampMetadata( +static Status getTimestampMetadata( const ::arrow::TimestampType& type, const WriterProperties& properties, - const ArrowWriterProperties& arrow_properties, - ParquetType::type* physical_type, - std::shared_ptr* logical_type) { - const bool coerce = arrow_properties.coerce_timestamps_enabled(); - const auto target_unit = - coerce ? arrow_properties.coerce_timestamps_unit() : type.unit(); + const ArrowWriterProperties& arrowProperties, + ParquetType::type* physicalType, + std::shared_ptr* logicalType) { + const bool coerce = arrowProperties.coerceTimestampsEnabled(); + const auto targetUnit = + coerce ? arrowProperties.coerceTimestampsUnit() : type.unit(); const auto version = properties.version(); // The user is explicitly asking for Impala int96 encoding, there is no // logical type. - if (arrow_properties.support_deprecated_int96_timestamps()) { - *physical_type = ParquetType::INT96; + if (arrowProperties.supportDeprecatedInt96Timestamps()) { + *physicalType = ParquetType::kInt96; return Status::OK(); } - *physical_type = ParquetType::INT64; - *logical_type = TimestampLogicalTypeFromArrowTimestamp(type, target_unit); + *physicalType = ParquetType::kInt64; + *logicalType = timestampLogicalTypeFromArrowTimestamp(type, targetUnit); // The user is explicitly asking for timestamp data to be converted to the // specified units (target_unit). if (coerce) { if (version == ParquetVersion::PARQUET_1_0 || version == ParquetVersion::PARQUET_2_4) { - switch (target_unit) { + switch (targetUnit) { case ::arrow::TimeUnit::MILLI: case ::arrow::TimeUnit::MICRO: break; @@ -299,12 +293,12 @@ static Status GetTimestampMetadata( case ::arrow::TimeUnit::SECOND: return Status::NotImplemented( "For Parquet version ", - ParquetVersionToString(version), + parquetVersionToString(version), ", can only coerce Arrow timestamps to " "milliseconds or microseconds"); } } else { - switch (target_unit) { + switch (targetUnit) { case ::arrow::TimeUnit::MILLI: case ::arrow::TimeUnit::MICRO: case ::arrow::TimeUnit::NANO: @@ -312,7 +306,7 @@ static Status GetTimestampMetadata( case ::arrow::TimeUnit::SECOND: return Status::NotImplemented( "For Parquet version ", - ParquetVersionToString(version), + parquetVersionToString(version), ", can only coerce Arrow timestamps to " "milliseconds, microseconds, or nanoseconds"); } @@ -320,24 +314,25 @@ static Status GetTimestampMetadata( return Status::OK(); } - // The user implicitly wants timestamp data to retain its original time units, - // however the ConvertedType field used to indicate logical types for Parquet - // version <= 2.4 fields does not allow for nanosecond time units and so - // nanoseconds must be coerced to microseconds. + // The user implicitly wants timestamp data to retain its original time + // units. However, the ConvertedType field used to indicate logical types for + // Parquet version <= 2.4 fields does not allow for nanosecond time units and + // so nanoseconds must be coerced to microseconds. if ((version == ParquetVersion::PARQUET_1_0 || version == ParquetVersion::PARQUET_2_4) && type.unit() == ::arrow::TimeUnit::NANO) { - *logical_type = - TimestampLogicalTypeFromArrowTimestamp(type, ::arrow::TimeUnit::MICRO); + *logicalType = + timestampLogicalTypeFromArrowTimestamp(type, ::arrow::TimeUnit::MICRO); return Status::OK(); } - // The user implicitly wants timestamp data to retain its original time units, - // however the Arrow seconds time unit can not be represented (annotated) in - // any version of Parquet and so must be coerced to milliseconds. + // The user implicitly wants timestamp data to retain its original time + // units. However, the Arrow seconds time unit can not be represented + // (annotated) in any version of Parquet and so must be coerced to + // milliseconds. if (type.unit() == ::arrow::TimeUnit::SECOND) { - *logical_type = - TimestampLogicalTypeFromArrowTimestamp(type, ::arrow::TimeUnit::MILLI); + *logicalType = + timestampLogicalTypeFromArrowTimestamp(type, ::arrow::TimeUnit::MILLI); return Status::OK(); } @@ -346,15 +341,7 @@ static Status GetTimestampMetadata( static constexpr char FIELD_ID_KEY[] = "PARQUET:field_id"; -std::shared_ptr<::arrow::KeyValueMetadata> FieldIdMetadata(int field_id) { - if (field_id >= 0) { - return ::arrow::key_value_metadata({FIELD_ID_KEY}, {ToChars(field_id)}); - } else { - return nullptr; - } -} - -int FieldIdFromMetadata( +int fieldIdFromMetadata( const std::shared_ptr& metadata) { if (!metadata) { return -1; @@ -363,31 +350,31 @@ int FieldIdFromMetadata( if (key < 0) { return -1; } - std::string field_id_str = metadata->value(key); - int field_id; + std::string fieldIdStr = metadata->value(key); + int fieldId; if (::arrow::internal::ParseValue<::arrow::Int32Type>( - field_id_str.c_str(), field_id_str.length(), &field_id)) { - if (field_id < 0) { - // Thrift should convert any negative value to null but normalize to -1 - // here in case we later check this in logic. + fieldIdStr.c_str(), fieldIdStr.length(), &fieldId)) { + if (fieldId < 0) { + // Thrift should convert any negative value to null but normalize to -1. + // Here in case we later check this in logic. return -1; } - return field_id; + return fieldId; } else { return -1; } } -Status FieldToNode( +Status fieldToNode( const std::string& name, const std::shared_ptr& field, const WriterProperties& properties, - const ArrowWriterProperties& arrow_properties, + const ArrowWriterProperties& arrowProperties, NodePtr* out) { - std::shared_ptr logical_type = LogicalType::None(); + std::shared_ptr logicalType = LogicalType::none(); ParquetType::type type; - Repetition::type repetition = RepetitionFromNullable(field->nullable()); - int field_id = FieldIdFromMetadata(field->metadata()); + Repetition::type repetition = repetitionFromNullable(field->nullable()); + int fieldId = fieldIdFromMetadata(field->metadata()); int length = -1; int precision = -1; @@ -395,186 +382,181 @@ Status FieldToNode( switch (field->type()->id()) { case ArrowTypeId::NA: { - type = ParquetType::INT32; - logical_type = LogicalType::Null(); - if (repetition != Repetition::OPTIONAL) { + type = ParquetType::kInt32; + logicalType = LogicalType::nullType(); + if (repetition != Repetition::kOptional) { return Status::Invalid("NullType Arrow field must be nullable"); } } break; case ArrowTypeId::BOOL: - type = ParquetType::BOOLEAN; + type = ParquetType::kBoolean; break; case ArrowTypeId::UINT8: - type = ParquetType::INT32; - logical_type = LogicalType::Int(8, false); + type = ParquetType::kInt32; + logicalType = LogicalType::intType(8, false); break; case ArrowTypeId::INT8: - type = ParquetType::INT32; - logical_type = LogicalType::Int(8, true); + type = ParquetType::kInt32; + logicalType = LogicalType::intType(8, true); break; case ArrowTypeId::UINT16: - type = ParquetType::INT32; - logical_type = LogicalType::Int(16, false); + type = ParquetType::kInt32; + logicalType = LogicalType::intType(16, false); break; case ArrowTypeId::INT16: - type = ParquetType::INT32; - logical_type = LogicalType::Int(16, true); + type = ParquetType::kInt32; + logicalType = LogicalType::intType(16, true); break; case ArrowTypeId::UINT32: if (properties.version() == ParquetVersion::PARQUET_1_0) { - type = ParquetType::INT64; + type = ParquetType::kInt64; } else { - type = ParquetType::INT32; - logical_type = LogicalType::Int(32, false); + type = ParquetType::kInt32; + logicalType = LogicalType::intType(32, false); } break; case ArrowTypeId::INT32: - type = ParquetType::INT32; + type = ParquetType::kInt32; break; case ArrowTypeId::UINT64: - type = ParquetType::INT64; - logical_type = LogicalType::Int(64, false); + type = ParquetType::kInt64; + logicalType = LogicalType::intType(64, false); break; case ArrowTypeId::INT64: - type = ParquetType::INT64; + type = ParquetType::kInt64; break; case ArrowTypeId::FLOAT: - type = ParquetType::FLOAT; + type = ParquetType::kFloat; break; case ArrowTypeId::DOUBLE: - type = ParquetType::DOUBLE; + type = ParquetType::kDouble; break; case ArrowTypeId::LARGE_STRING: case ArrowTypeId::STRING: - type = ParquetType::BYTE_ARRAY; - logical_type = LogicalType::String(); + type = ParquetType::kByteArray; + logicalType = LogicalType::string(); break; case ArrowTypeId::LARGE_BINARY: case ArrowTypeId::BINARY: - type = ParquetType::BYTE_ARRAY; + type = ParquetType::kByteArray; break; case ArrowTypeId::FIXED_SIZE_BINARY: { - type = ParquetType::FIXED_LEN_BYTE_ARRAY; - const auto& fixed_size_binary_type = + type = ParquetType::kFixedLenByteArray; + const auto& fixedSizeBinaryType = static_cast(*field->type()); - length = fixed_size_binary_type.byte_width(); + length = fixedSizeBinaryType.byte_width(); } break; case ArrowTypeId::DECIMAL128: case ArrowTypeId::DECIMAL256: { - const auto& decimal_type = + const auto& decimalType = static_cast(*field->type()); - precision = decimal_type.precision(); - scale = decimal_type.scale(); - if (properties.store_decimal_as_integer() && 1 <= precision && + precision = decimalType.precision(); + scale = decimalType.scale(); + if (properties.storeDecimalAsInteger() && 1 <= precision && precision <= 18) { - type = precision <= 9 ? ParquetType ::INT32 : ParquetType ::INT64; + type = precision <= 9 ? ParquetType::kInt32 : ParquetType::kInt64; } else { - type = ParquetType::FIXED_LEN_BYTE_ARRAY; + type = ParquetType::kFixedLenByteArray; length = DecimalType::DecimalSize(precision); } PARQUET_CATCH_NOT_OK( - logical_type = LogicalType::Decimal(precision, scale)); + logicalType = LogicalType::decimal(precision, scale)); } break; case ArrowTypeId::DATE32: - type = ParquetType::INT32; - logical_type = LogicalType::Date(); + type = ParquetType::kInt32; + logicalType = LogicalType::date(); break; case ArrowTypeId::DATE64: - type = ParquetType::INT32; - logical_type = LogicalType::Date(); + type = ParquetType::kInt32; + logicalType = LogicalType::date(); break; case ArrowTypeId::TIMESTAMP: - RETURN_NOT_OK(GetTimestampMetadata( + RETURN_NOT_OK(getTimestampMetadata( static_cast<::arrow::TimestampType&>(*field->type()), properties, - arrow_properties, + arrowProperties, &type, - &logical_type)); + &logicalType)); break; case ArrowTypeId::TIME32: - type = ParquetType::INT32; - logical_type = LogicalType::Time( - /*is_adjusted_to_utc=*/true, LogicalType::TimeUnit::MILLIS); + type = ParquetType::kInt32; + logicalType = LogicalType::time(true, LogicalType::TimeUnit::kMillis); break; case ArrowTypeId::TIME64: { - type = ParquetType::INT64; - auto time_type = static_cast<::arrow::Time64Type*>(field->type().get()); - if (time_type->unit() == ::arrow::TimeUnit::NANO) { - logical_type = LogicalType::Time( - /*is_adjusted_to_utc=*/true, LogicalType::TimeUnit::NANOS); + type = ParquetType::kInt64; + auto timeType = static_cast<::arrow::Time64Type*>(field->type().get()); + if (timeType->unit() == ::arrow::TimeUnit::NANO) { + logicalType = LogicalType::time(true, LogicalType::TimeUnit::kNanos); } else { - logical_type = LogicalType::Time( - /*is_adjusted_to_utc=*/true, LogicalType::TimeUnit::MICROS); + logicalType = LogicalType::time(true, LogicalType::TimeUnit::kMicros); } } break; case ArrowTypeId::DURATION: - type = ParquetType::INT64; + type = ParquetType::kInt64; break; case ArrowTypeId::STRUCT: { - auto struct_type = + auto structType = std::static_pointer_cast<::arrow::StructType>(field->type()); - return StructToNode( - struct_type, + return structToNode( + structType, name, field->nullable(), - field_id, + fieldId, properties, - arrow_properties, + arrowProperties, out); } case ArrowTypeId::FIXED_SIZE_LIST: case ArrowTypeId::LARGE_LIST: case ArrowTypeId::LIST: { - auto list_type = + auto listType = std::static_pointer_cast<::arrow::BaseListType>(field->type()); - return ListToNode( - list_type, + return listToNode( + listType, name, field->nullable(), - field_id, + fieldId, properties, - arrow_properties, + arrowProperties, out); } case ArrowTypeId::DICTIONARY: { // Parquet has no Dictionary type, dictionary-encoded is handled on // the encoding, not the schema level. - const ::arrow::DictionaryType& dict_type = + const ::arrow::DictionaryType& dictType = static_cast(*field->type()); - std::shared_ptr<::arrow::Field> unpacked_field = ::arrow::field( - name, dict_type.value_type(), field->nullable(), field->metadata()); - return FieldToNode( - name, unpacked_field, properties, arrow_properties, out); + std::shared_ptr<::arrow::Field> unpackedField = ::arrow::field( + name, dictType.value_type(), field->nullable(), field->metadata()); + return fieldToNode(name, unpackedField, properties, arrowProperties, out); } case ArrowTypeId::EXTENSION: { - auto ext_type = + auto extType = std::static_pointer_cast<::arrow::ExtensionType>(field->type()); - std::shared_ptr<::arrow::Field> storage_field = ::arrow::field( - name, ext_type->storage_type(), field->nullable(), field->metadata()); - return FieldToNode( - name, storage_field, properties, arrow_properties, out); + std::shared_ptr<::arrow::Field> storageField = ::arrow::field( + name, extType->storage_type(), field->nullable(), field->metadata()); + return fieldToNode(name, storageField, properties, arrowProperties, out); } case ArrowTypeId::MAP: { - auto map_type = std::static_pointer_cast<::arrow::MapType>(field->type()); - return MapToNode( - map_type, + auto mapType = std::static_pointer_cast<::arrow::MapType>(field->type()); + return mapToNode( + mapType, name, field->nullable(), - field_id, + fieldId, properties, - arrow_properties, + arrowProperties, out); } default: { - // TODO: DENSE_UNION, SPARE_UNION, JSON_SCALAR, DECIMAL_TEXT, VARCHAR + // TODO: DENSE_UNION, SPARE_UNION, JSON_SCALAR, DECIMAL_TEXT, VARCHAR. return Status::NotImplemented( "Unhandled type for Arrow to Parquet schema conversion: ", field->type()->ToString()); } } - PARQUET_CATCH_NOT_OK(*out = PrimitiveNode::Make(name, repetition, logical_type, type, length, field_id)); + PARQUET_CATCH_NOT_OK(*out = PrimitiveNode::make(name, repetition, logicalType, type, length, fieldId)); return Status::OK(); } @@ -584,165 +566,164 @@ struct SchemaTreeContext { ArrowReaderProperties properties; const SchemaDescriptor* schema; - void LinkParent(const SchemaField* child, const SchemaField* parent) { - manifest->child_to_parent[child] = parent; + void linkParent(const SchemaField* child, const SchemaField* parent) { + manifest->childToParent[child] = parent; } - void RecordLeaf(const SchemaField* leaf) { - manifest->column_index_to_field[leaf->column_index] = leaf; + void recordLeaf(const SchemaField* leaf) { + manifest->columnIndexToField[leaf->columnIndex] = leaf; } }; -bool IsDictionaryReadSupported(const ArrowType& type) { - // Only supported currently for BYTE_ARRAY types +bool isDictionaryReadSupported(const ArrowType& type) { + // Only supported currently for BYTE_ARRAY types. return type.id() == ::arrow::Type::BINARY || type.id() == ::arrow::Type::STRING; } -// ---------------------------------------------------------------------- -// Schema logic +// ----------------------------------------------------------------------. +// Schema logic. -::arrow::Result> GetTypeForNode( - int column_index, - const schema::PrimitiveNode& primitive_node, +::arrow::Result> getTypeForNode( + int columnIndex, + const schema::PrimitiveNode& primitiveNode, SchemaTreeContext* ctx) { ARROW_ASSIGN_OR_RAISE( - std::shared_ptr storage_type, - GetArrowType( - primitive_node, ctx->properties.coerce_int96_timestamp_unit())); - if (ctx->properties.read_dictionary(column_index) && - IsDictionaryReadSupported(*storage_type)) { - return ::arrow::dictionary(::arrow::int32(), storage_type); + std::shared_ptr storageType, + getArrowType(primitiveNode, ctx->properties.coerceInt96TimestampUnit())); + if (ctx->properties.readDictionary(columnIndex) && + isDictionaryReadSupported(*storageType)) { + return ::arrow::dictionary(::arrow::int32(), storageType); } - return storage_type; + return storageType; } -Status NodeToSchemaField( +Status nodeToSchemaField( const Node& node, - LevelInfo current_levels, + LevelInfo currentLevels, SchemaTreeContext* ctx, const SchemaField* parent, SchemaField* out); -Status GroupToSchemaField( - const GroupNode& node, - LevelInfo current_levels, +Status groupToSchemaField( + const GroupNode& groupNode, + LevelInfo currentLevels, SchemaTreeContext* ctx, const SchemaField* parent, SchemaField* out); -Status PopulateLeaf( - int column_index, +Status populateLeaf( + int columnIndex, const std::shared_ptr& field, - LevelInfo current_levels, + LevelInfo currentLevels, SchemaTreeContext* ctx, const SchemaField* parent, SchemaField* out) { out->field = field; - out->column_index = column_index; - out->level_info = current_levels; - ctx->RecordLeaf(out); - ctx->LinkParent(out, parent); + out->columnIndex = columnIndex; + out->levelInfo = currentLevels; + ctx->recordLeaf(out); + ctx->linkParent(out, parent); return Status::OK(); } // Special case mentioned in the format spec: -// If the name is array or ends in _tuple, this should be a list of struct +// If the name is array or ends in _tuple, this should be a list of struct, // even for single child elements. -bool HasStructListName(const GroupNode& node) { - ::std::string_view name{node.name()}; - return name == "array" || EndsWith(name, "_tuple"); +bool hasStructListName(const GroupNode& groupNode) { + ::std::string_view name{groupNode.name()}; + return name == "array" || endsWith(name, "_tuple"); } -Status GroupToStruct( - const GroupNode& node, - LevelInfo current_levels, +Status groupToStruct( + const GroupNode& groupNode, + LevelInfo currentLevels, SchemaTreeContext* ctx, const SchemaField* parent, SchemaField* out) { - std::vector> arrow_fields; - out->children.resize(node.field_count()); + std::vector> arrowFields; + out->children.resize(groupNode.fieldCount()); // All level increments for the node are expected to happen by callers. // This is required because repeated elements need to have their own // SchemaField. - for (int i = 0; i < node.field_count(); i++) { - RETURN_NOT_OK(NodeToSchemaField( - *node.field(i), current_levels, ctx, out, &out->children[i])); - arrow_fields.push_back(out->children[i].field); + for (int i = 0; i < groupNode.fieldCount(); i++) { + RETURN_NOT_OK(nodeToSchemaField( + *groupNode.field(i), currentLevels, ctx, out, &out->children[i])); + arrowFields.push_back(out->children[i].field); } - auto struct_type = ::arrow::struct_(arrow_fields); + auto structType = ::arrow::struct_(arrowFields); out->field = ::arrow::field( - node.name(), - struct_type, - node.is_optional(), - FieldIdMetadata(node.field_id())); - out->level_info = current_levels; + groupNode.name(), + structType, + groupNode.isOptional(), + fieldIdMetadata(groupNode.fieldId())); + out->levelInfo = currentLevels; return Status::OK(); } -Status ListToSchemaField( +Status listToSchemaField( const GroupNode& group, - LevelInfo current_levels, + LevelInfo currentLevels, SchemaTreeContext* ctx, const SchemaField* parent, SchemaField* out); -Status MapToSchemaField( +Status mapToSchemaField( const GroupNode& group, - LevelInfo current_levels, + LevelInfo currentLevels, SchemaTreeContext* ctx, const SchemaField* parent, SchemaField* out) { - if (group.field_count() != 1) { + if (group.fieldCount() != 1) { return Status::Invalid("MAP-annotated groups must have a single child."); } - if (group.is_repeated()) { + if (group.isRepeated()) { return Status::Invalid("MAP-annotated groups must not be repeated."); } - const Node& key_value_node = *group.field(0); + const Node& keyValueNode = *group.field(0); - if (!key_value_node.is_repeated()) { + if (!keyValueNode.isRepeated()) { return Status::Invalid( "Non-repeated key value in a MAP-annotated group are not supported."); } - if (!key_value_node.is_group()) { + if (!keyValueNode.isGroup()) { return Status::Invalid("Key-value node must be a group."); } - const GroupNode& key_value = checked_cast(key_value_node); - if (key_value.field_count() != 1 && key_value.field_count() != 2) { + const GroupNode& keyValue = checked_cast(keyValueNode); + if (keyValue.fieldCount() != 1 && keyValue.fieldCount() != 2) { return Status::Invalid( "Key-value map node must have 1 or 2 child elements. Found: ", - key_value.field_count()); + keyValue.fieldCount()); } - const Node& key_node = *key_value.field(0); - if (!key_node.is_required()) { + const Node& keyNode = *keyValue.field(0); + if (!keyNode.isRequired()) { return Status::Invalid("Map keys must be annotated as required."); } - // Arrow doesn't support 1 column maps (i.e. Sets). The options are to either - // make the values column nullable, or process the map as a list. We choose - // the latter as it is simpler. - if (key_value.field_count() == 1) { - return ListToSchemaField(group, current_levels, ctx, parent, out); + // Arrow doesn't support 1 column maps (i.e. Sets). The options are to + // either make the values column nullable, or process the map as a list. We + // choose the latter as it is simpler. + if (keyValue.fieldCount() == 1) { + return listToSchemaField(group, currentLevels, ctx, parent, out); } - IncrementLevels(current_levels, group); - int16_t repeated_ancestor_def_level = current_levels.IncrementRepeated(); + incrementLevels(currentLevels, group); + int16_t repeatedAncestorDefLevel = currentLevels.IncrementRepeated(); out->children.resize(1); - SchemaField* key_value_field = &out->children[0]; + SchemaField* keyValueField = &out->children[0]; - key_value_field->children.resize(2); - SchemaField* key_field = &key_value_field->children[0]; - SchemaField* value_field = &key_value_field->children[1]; + keyValueField->children.resize(2); + SchemaField* keyField = &keyValueField->children[0]; + SchemaField* valueField = &keyValueField->children[1]; - ctx->LinkParent(out, parent); - ctx->LinkParent(key_value_field, out); - ctx->LinkParent(key_field, key_value_field); - ctx->LinkParent(value_field, key_value_field); + ctx->linkParent(out, parent); + ctx->linkParent(keyValueField, out); + ctx->linkParent(keyField, keyValueField); + ctx->linkParent(valueField, keyValueField); // required/optional group name=whatever { // repeated group name=key_values{ @@ -752,60 +733,60 @@ Status MapToSchemaField( // } // - RETURN_NOT_OK(NodeToSchemaField( - *key_value.field(0), current_levels, ctx, key_value_field, key_field)); - RETURN_NOT_OK(NodeToSchemaField( - *key_value.field(1), current_levels, ctx, key_value_field, value_field)); + RETURN_NOT_OK(nodeToSchemaField( + *keyValue.field(0), currentLevels, ctx, keyValueField, keyField)); + RETURN_NOT_OK(nodeToSchemaField( + *keyValue.field(1), currentLevels, ctx, keyValueField, valueField)); - key_value_field->field = ::arrow::field( + keyValueField->field = ::arrow::field( group.name(), - ::arrow::struct_({key_field->field, value_field->field}), - /*nullable=*/false, - FieldIdMetadata(key_value.field_id())); - key_value_field->level_info = current_levels; + ::arrow::struct_({keyField->field, valueField->field}), + false, + fieldIdMetadata(keyValue.fieldId())); + keyValueField->levelInfo = currentLevels; out->field = ::arrow::field( group.name(), - std::make_shared<::arrow::MapType>(key_value_field->field), - group.is_optional(), - FieldIdMetadata(group.field_id())); - out->level_info = current_levels; - // At this point current levels contains the def level for this list, - // we need to reset to the prior parent. - out->level_info.repeatedAncestorDefLevel = repeated_ancestor_def_level; + std::make_shared<::arrow::MapType>(keyValueField->field), + group.isOptional(), + fieldIdMetadata(group.fieldId())); + out->levelInfo = currentLevels; + // At this point current levels contains the def level for this list. + // We need to reset to the prior parent. + out->levelInfo.repeatedAncestorDefLevel = repeatedAncestorDefLevel; return Status::OK(); } -Status ListToSchemaField( +Status listToSchemaField( const GroupNode& group, - LevelInfo current_levels, + LevelInfo currentLevels, SchemaTreeContext* ctx, const SchemaField* parent, SchemaField* out) { - if (group.field_count() != 1) { + if (group.fieldCount() != 1) { return Status::Invalid("LIST-annotated groups must have a single child."); } - if (group.is_repeated()) { + if (group.isRepeated()) { return Status::Invalid("LIST-annotated groups must not be repeated."); } - IncrementLevels(current_levels, group); + incrementLevels(currentLevels, group); - out->children.resize(group.field_count()); - SchemaField* child_field = &out->children[0]; + out->children.resize(group.fieldCount()); + SchemaField* childField = &out->children[0]; - ctx->LinkParent(out, parent); - ctx->LinkParent(child_field, out); + ctx->linkParent(out, parent); + ctx->linkParent(childField, out); - const Node& list_node = *group.field(0); + const Node& listNode = *group.field(0); - if (!list_node.is_repeated()) { + if (!listNode.isRepeated()) { return Status::Invalid( "Non-repeated nodes in a LIST-annotated group are not supported."); } - int16_t repeated_ancestor_def_level = current_levels.IncrementRepeated(); - if (list_node.is_group()) { - // Resolve 3-level encoding + int16_t repeatedAncestorDefLevel = currentLevels.IncrementRepeated(); + if (listNode.isGroup()) { + // Resolve 3-level encoding. // // required/optional group name=whatever { // repeated group name=list { @@ -813,9 +794,9 @@ Status ListToSchemaField( // } // } // - // yields list ?nullable + // Yields list ?nullable. // - // We distinguish the special case that we have + // We distinguish the special case that we have. // // required/optional group name=whatever { // repeated group name=array or $SOMETHING_tuple { @@ -824,160 +805,152 @@ Status ListToSchemaField( // } // // In this latter case, the inner type of the list should be a struct - // rather than a primitive value + // rather than a primitive value. // - // yields list not null> ?nullable - const auto& list_group = static_cast(list_node); + // Yields list not null> ?nullable. + const auto& listGroup = static_cast(listNode); // Special case mentioned in the format spec: - // If the name is array or ends in _tuple, this should be a list of struct - // even for single child elements. - if (list_group.field_count() == 1 && !HasStructListName(list_group)) { - // List of primitive type - RETURN_NOT_OK(NodeToSchemaField( - *list_group.field(0), current_levels, ctx, out, child_field)); + // If the name is array or ends in _tuple, this should be a list of + // struct, even for single child elements. + if (listGroup.fieldCount() == 1 && !hasStructListName(listGroup)) { + // List of primitive type. + RETURN_NOT_OK(nodeToSchemaField( + *listGroup.field(0), currentLevels, ctx, out, childField)); } else { RETURN_NOT_OK( - GroupToStruct(list_group, current_levels, ctx, out, child_field)); + groupToStruct(listGroup, currentLevels, ctx, out, childField)); } } else { - // Two-level list encoding + // Two-level list encoding. // // required/optional group LIST { // repeated TYPE; // } - const auto& primitive_node = static_cast(list_node); - int column_index = ctx->schema->GetColumnIndex(primitive_node); + const auto& primitiveNode = static_cast(listNode); + int columnIndex = ctx->schema->getColumnIndex(primitiveNode); ARROW_ASSIGN_OR_RAISE( std::shared_ptr type, - GetTypeForNode(column_index, primitive_node, ctx)); - auto item_field = ::arrow::field( - list_node.name(), - type, - /*nullable=*/false, - FieldIdMetadata(list_node.field_id())); - RETURN_NOT_OK(PopulateLeaf( - column_index, item_field, current_levels, ctx, out, child_field)); + getTypeForNode(columnIndex, primitiveNode, ctx)); + auto itemField = ::arrow::field( + listNode.name(), type, false, fieldIdMetadata(listNode.fieldId())); + RETURN_NOT_OK(populateLeaf( + columnIndex, itemField, currentLevels, ctx, out, childField)); } out->field = ::arrow::field( group.name(), - ::arrow::list(child_field->field), - group.is_optional(), - FieldIdMetadata(group.field_id())); - out->level_info = current_levels; - // At this point current levels contains the def level for this list, - // we need to reset to the prior parent. - out->level_info.repeatedAncestorDefLevel = repeated_ancestor_def_level; + ::arrow::list(childField->field), + group.isOptional(), + fieldIdMetadata(group.fieldId())); + out->levelInfo = currentLevels; + // At this point current levels contains the def level for this list. + // We need to reset to the prior parent. + out->levelInfo.repeatedAncestorDefLevel = repeatedAncestorDefLevel; return Status::OK(); } -Status GroupToSchemaField( - const GroupNode& node, - LevelInfo current_levels, +Status groupToSchemaField( + const GroupNode& groupNode, + LevelInfo currentLevels, SchemaTreeContext* ctx, const SchemaField* parent, SchemaField* out) { - if (node.logical_type()->is_list()) { - return ListToSchemaField(node, current_levels, ctx, parent, out); - } else if (node.logical_type()->is_map()) { - return MapToSchemaField(node, current_levels, ctx, parent, out); + if (groupNode.logicalType()->isList()) { + return listToSchemaField(groupNode, currentLevels, ctx, parent, out); + } else if (groupNode.logicalType()->isMap()) { + return mapToSchemaField(groupNode, currentLevels, ctx, parent, out); } std::shared_ptr type; - if (node.is_repeated()) { - // Simple repeated struct + if (groupNode.isRepeated()) { + // Simple repeated struct. // // repeated group $NAME { - // r/o TYPE[0] f0 - // r/o TYPE[1] f1 + // R/o TYPE[0] f0. + // R/o TYPE[1] f1. // } out->children.resize(1); - int16_t repeated_ancestor_def_level = current_levels.IncrementRepeated(); + int16_t repeatedAncestorDefLevel = currentLevels.IncrementRepeated(); RETURN_NOT_OK( - GroupToStruct(node, current_levels, ctx, out, &out->children[0])); + groupToStruct(groupNode, currentLevels, ctx, out, &out->children[0])); out->field = ::arrow::field( - node.name(), + groupNode.name(), ::arrow::list(out->children[0].field), - /*nullable=*/false, - FieldIdMetadata(node.field_id())); - - ctx->LinkParent(&out->children[0], out); - out->level_info = current_levels; - // At this point current_levels contains this list as the def level, we need - // to use the previous ancestor of this list. - out->level_info.repeatedAncestorDefLevel = repeated_ancestor_def_level; + false, + fieldIdMetadata(groupNode.fieldId())); + + ctx->linkParent(&out->children[0], out); + out->levelInfo = currentLevels; + // At this point current_levels contains this list as the def level, we + // need to use the previous ancestor of this list. + out->levelInfo.repeatedAncestorDefLevel = repeatedAncestorDefLevel; return Status::OK(); } else { - IncrementLevels(current_levels, node); - return GroupToStruct(node, current_levels, ctx, parent, out); + incrementLevels(currentLevels, groupNode); + return groupToStruct(groupNode, currentLevels, ctx, parent, out); } } -Status NodeToSchemaField( +Status nodeToSchemaField( const Node& node, - LevelInfo current_levels, + LevelInfo currentLevels, SchemaTreeContext* ctx, const SchemaField* parent, SchemaField* out) { // Workhorse function for converting a Parquet schema node to an Arrow // type. Handles different conventions for nested data. - ctx->LinkParent(out, parent); + ctx->linkParent(out, parent); - // Now, walk the schema and create a ColumnDescriptor for each leaf node - if (node.is_group()) { - // A nested field, but we don't know what kind yet - return GroupToSchemaField( - static_cast(node), current_levels, ctx, parent, out); + // Now, walk the schema and create a ColumnDescriptor for each leaf node. + if (node.isGroup()) { + // A nested field, but we don't know what kind yet. + return groupToSchemaField( + static_cast(node), currentLevels, ctx, parent, out); } else { - // Either a normal flat primitive type, or a list type encoded with 1-level - // list encoding. Note that the 3-level encoding is the form recommended by - // the parquet specification, but technically we can have either + // Either a normal flat primitive type, or a list type encoded with 1-level. + // List encoding. Note that the 3-level encoding is the form recommended by + // the Parquet specification, but technically we can have either. // - // required/optional $TYPE $FIELD_NAME + // Required/optional $TYPE $FIELD_NAME. // - // or + // Or. // - // repeated $TYPE $FIELD_NAME - const auto& primitive_node = static_cast(node); - int column_index = ctx->schema->GetColumnIndex(primitive_node); + // Repeated $TYPE $FIELD_NAME. + const auto& primitiveNode = static_cast(node); + int columnIndex = ctx->schema->getColumnIndex(primitiveNode); ARROW_ASSIGN_OR_RAISE( std::shared_ptr type, - GetTypeForNode(column_index, primitive_node, ctx)); - if (node.is_repeated()) { + getTypeForNode(columnIndex, primitiveNode, ctx)); + if (node.isRepeated()) { // One-level list encoding, e.g. // a: repeated int32; - int16_t repeated_ancestor_def_level = current_levels.IncrementRepeated(); + int16_t repeatedAncestorDefLevel = currentLevels.IncrementRepeated(); out->children.resize(1); - auto child_field = ::arrow::field(node.name(), type, /*nullable=*/false); - RETURN_NOT_OK(PopulateLeaf( - column_index, - child_field, - current_levels, - ctx, - out, - &out->children[0])); + auto childField = ::arrow::field(node.name(), type, false); + RETURN_NOT_OK(populateLeaf( + columnIndex, childField, currentLevels, ctx, out, &out->children[0])); out->field = ::arrow::field( node.name(), - ::arrow::list(child_field), - /*nullable=*/false, - FieldIdMetadata(node.field_id())); - out->level_info = current_levels; + ::arrow::list(childField), + false, + fieldIdMetadata(node.fieldId())); + out->levelInfo = currentLevels; // At this point current_levels has consider this list the ancestor so // restore the actual ancestor. - out->level_info.repeatedAncestorDefLevel = repeated_ancestor_def_level; + out->levelInfo.repeatedAncestorDefLevel = repeatedAncestorDefLevel; return Status::OK(); } else { - IncrementLevels(current_levels, node); - // A normal (required/optional) primitive node - return PopulateLeaf( - column_index, + incrementLevels(currentLevels, node); + // A normal (required/optional) primitive node. + return populateLeaf( + columnIndex, ::arrow::field( node.name(), type, - node.is_optional(), - FieldIdMetadata(node.field_id())), - current_levels, + node.isOptional(), + fieldIdMetadata(node.fieldId())), + currentLevels, ctx, parent, out); @@ -985,93 +958,93 @@ Status NodeToSchemaField( } } -// Get the original Arrow schema, as serialized in the Parquet metadata -Status GetOriginSchema( +// Get the original Arrow schema, as serialized in the Parquet metadata. +Status getOriginSchema( const std::shared_ptr& metadata, - std::shared_ptr* clean_metadata, + std::shared_ptr* cleanMetadata, std::shared_ptr<::arrow::Schema>* out) { if (metadata == nullptr) { *out = nullptr; - *clean_metadata = nullptr; + *cleanMetadata = nullptr; return Status::OK(); } static const std::string kArrowSchemaKey = "ARROW:schema"; - int schema_index = metadata->FindKey(kArrowSchemaKey); - if (schema_index == -1) { + int schemaIndex = metadata->FindKey(kArrowSchemaKey); + if (schemaIndex == -1) { *out = nullptr; - *clean_metadata = metadata; + *cleanMetadata = metadata; return Status::OK(); } // The original Arrow schema was serialized using the store_schema option. // We deserialize it here and use it to inform read options such as // dictionary-encoded fields. - auto decoded = ::arrow::util::base64_decode(metadata->value(schema_index)); - auto schema_buf = std::make_shared(decoded); + auto decoded = ::arrow::util::base64_decode(metadata->value(schemaIndex)); + auto schemaBuf = std::make_shared(decoded); - ::arrow::ipc::DictionaryMemo dict_memo; - ::arrow::io::BufferReader input(schema_buf); + ::arrow::ipc::DictionaryMemo dictMemo; + ::arrow::io::BufferReader input(schemaBuf); - ARROW_ASSIGN_OR_RAISE(*out, ::arrow::ipc::ReadSchema(&input, &dict_memo)); + ARROW_ASSIGN_OR_RAISE(*out, ::arrow::ipc::ReadSchema(&input, &dictMemo)); if (metadata->size() > 1) { - // Copy the metadata without the schema key - auto new_metadata = ::arrow::key_value_metadata({}, {}); - new_metadata->reserve(metadata->size() - 1); + // Copy the metadata without the schema key. + auto newMetadata = ::arrow::key_value_metadata({}, {}); + newMetadata->reserve(metadata->size() - 1); for (int64_t i = 0; i < metadata->size(); ++i) { - if (i == schema_index) + if (i == schemaIndex) continue; - new_metadata->Append(metadata->key(i), metadata->value(i)); + newMetadata->Append(metadata->key(i), metadata->value(i)); } - *clean_metadata = new_metadata; + *cleanMetadata = newMetadata; } else { - // No other keys, let metadata be null - *clean_metadata = nullptr; + // No other keys, let metadata be null. + *cleanMetadata = nullptr; } return Status::OK(); } -// Restore original Arrow field information that was serialized as Parquet +// Restore original Arrow field information that was serialized as Parquet. // metadata but that is not necessarily present in the field reconstituted from // Parquet data (for example, Parquet timestamp types doesn't carry timezone // information). -Result ApplyOriginalMetadata( - const Field& origin_field, +Result applyOriginalMetadata( + const Field& originField, SchemaField* inferred); -std::function(FieldVector)> GetNestedFactory( - const ArrowType& origin_type, - const ArrowType& inferred_type) { - switch (inferred_type.id()) { +std::function(FieldVector)> getNestedFactory( + const ArrowType& originType, + const ArrowType& inferredType) { + switch (inferredType.id()) { case ::arrow::Type::STRUCT: - if (origin_type.id() == ::arrow::Type::STRUCT) { + if (originType.id() == ::arrow::Type::STRUCT) { return [](FieldVector fields) { return ::arrow::struct_(std::move(fields)); }; } break; case ::arrow::Type::LIST: - if (origin_type.id() == ::arrow::Type::LIST) { + if (originType.id() == ::arrow::Type::LIST) { return [](FieldVector fields) { VELOX_DCHECK_EQ(fields.size(), 1); return ::arrow::list(std::move(fields[0])); }; } - if (origin_type.id() == ::arrow::Type::LARGE_LIST) { + if (originType.id() == ::arrow::Type::LARGE_LIST) { return [](FieldVector fields) { VELOX_DCHECK_EQ(fields.size(), 1); return ::arrow::large_list(std::move(fields[0])); }; } - if (origin_type.id() == ::arrow::Type::FIXED_SIZE_LIST) { - const auto list_size = - checked_cast(origin_type) + if (originType.id() == ::arrow::Type::FIXED_SIZE_LIST) { + const auto listSize = + checked_cast(originType) .list_size(); - return [list_size](FieldVector fields) { + return [listSize](FieldVector fields) { VELOX_DCHECK_EQ(fields.size(), 1); - return ::arrow::fixed_size_list(std::move(fields[0]), list_size); + return ::arrow::fixed_size_list(std::move(fields[0]), listSize); }; } break; @@ -1081,140 +1054,140 @@ std::function(FieldVector)> GetNestedFactory( return {}; } -Result ApplyOriginalStorageMetadata( - const Field& origin_field, +Result applyOriginalStorageMetadata( + const Field& originField, SchemaField* inferred) { bool modified = false; - auto& origin_type = origin_field.type(); - auto& inferred_type = inferred->field->type(); + auto& originType = originField.type(); + auto& inferredType = inferred->field->type(); - const int num_children = inferred_type->num_fields(); + const int numChildren = inferredType->num_fields(); - if (num_children > 0 && origin_type->num_fields() == num_children) { - VELOX_DCHECK_EQ(static_cast(inferred->children.size()), num_children); - const auto factory = GetNestedFactory(*origin_type, *inferred_type); + if (numChildren > 0 && originType->num_fields() == numChildren) { + VELOX_DCHECK_EQ(static_cast(inferred->children.size()), numChildren); + const auto factory = getNestedFactory(*originType, *inferredType); if (factory) { // The type may be modified (e.g. LargeList) while the children stay the - // same - modified |= origin_type->id() != inferred_type->id(); + // same. + modified |= originType->id() != inferredType->id(); - // Apply original metadata recursively to children - for (int i = 0; i < inferred_type->num_fields(); ++i) { + // Apply original metadata recursively to children. + for (int i = 0; i < inferredType->num_fields(); ++i) { ARROW_ASSIGN_OR_RAISE( - const bool child_modified, - ApplyOriginalMetadata( - *origin_type->field(i), &inferred->children[i])); - modified |= child_modified; + const bool childModified, + applyOriginalMetadata( + *originType->field(i), &inferred->children[i])); + modified |= childModified; } if (modified) { - // Recreate this field using the modified child fields - ::arrow::FieldVector modified_children(inferred_type->num_fields()); - for (int i = 0; i < inferred_type->num_fields(); ++i) { - modified_children[i] = inferred->children[i].field; + // Recreate this field using the modified child fields. + ::arrow::FieldVector modifiedChildren(inferredType->num_fields()); + for (int i = 0; i < inferredType->num_fields(); ++i) { + modifiedChildren[i] = inferred->children[i].field; } inferred->field = - inferred->field->WithType(factory(std::move(modified_children))); + inferred->field->WithType(factory(std::move(modifiedChildren))); } } } - if (origin_type->id() == ::arrow::Type::TIMESTAMP && - inferred_type->id() == ::arrow::Type::TIMESTAMP) { - // Restore time zone, if any - const auto& ts_type = - checked_cast(*inferred_type); - const auto& ts_origin_type = - checked_cast(*origin_type); + if (originType->id() == ::arrow::Type::TIMESTAMP && + inferredType->id() == ::arrow::Type::TIMESTAMP) { + // Restore time zone, if any. + const auto& tsType = + checked_cast(*inferredType); + const auto& tsOriginType = + checked_cast(*originType); // If the data is tz-aware, then set the original time zone, since Parquet - // has no native storage for timezones - if (ts_type.timezone() == "UTC" && !ts_origin_type.timezone().empty()) { - if (ts_type.unit() == ts_origin_type.unit()) { - inferred->field = inferred->field->WithType(origin_type); + // has no native storage for timezones. + if (tsType.timezone() == "UTC" && !tsOriginType.timezone().empty()) { + if (tsType.unit() == tsOriginType.unit()) { + inferred->field = inferred->field->WithType(originType); } else { - auto ts_type_new = - ::arrow::timestamp(ts_type.unit(), ts_origin_type.timezone()); - inferred->field = inferred->field->WithType(ts_type_new); + auto tsTypeNew = + ::arrow::timestamp(tsType.unit(), tsOriginType.timezone()); + inferred->field = inferred->field->WithType(tsTypeNew); } } modified = true; } - if (origin_type->id() == ::arrow::Type::DURATION && - inferred_type->id() == ::arrow::Type::INT64) { + if (originType->id() == ::arrow::Type::DURATION && + inferredType->id() == ::arrow::Type::INT64) { // Read back int64 arrays as duration. - inferred->field = inferred->field->WithType(origin_type); + inferred->field = inferred->field->WithType(originType); modified = true; } - if (origin_type->id() == ::arrow::Type::DICTIONARY && - inferred_type->id() != ::arrow::Type::DICTIONARY && - IsDictionaryReadSupported(*inferred_type)) { + if (originType->id() == ::arrow::Type::DICTIONARY && + inferredType->id() != ::arrow::Type::DICTIONARY && + isDictionaryReadSupported(*inferredType)) { // Direct dictionary reads are only supported for a couple primitive types, // so no need to recurse on value types. - const auto& dict_origin_type = - checked_cast(*origin_type); - inferred->field = inferred->field->WithType(::arrow::dictionary( - ::arrow::int32(), inferred_type, dict_origin_type.ordered())); + const auto& dictOriginType = + checked_cast(*originType); + inferred->field = inferred->field->WithType( + ::arrow::dictionary( + ::arrow::int32(), inferredType, dictOriginType.ordered())); modified = true; } - if ((origin_type->id() == ::arrow::Type::LARGE_BINARY && - inferred_type->id() == ::arrow::Type::BINARY) || - (origin_type->id() == ::arrow::Type::LARGE_STRING && - inferred_type->id() == ::arrow::Type::STRING)) { + if ((originType->id() == ::arrow::Type::LARGE_BINARY && + inferredType->id() == ::arrow::Type::BINARY) || + (originType->id() == ::arrow::Type::LARGE_STRING && + inferredType->id() == ::arrow::Type::STRING)) { // Read back binary-like arrays with the intended offset width. - inferred->field = inferred->field->WithType(origin_type); + inferred->field = inferred->field->WithType(originType); modified = true; } - if (origin_type->id() == ::arrow::Type::DECIMAL256 && - inferred_type->id() == ::arrow::Type::DECIMAL128) { - inferred->field = inferred->field->WithType(origin_type); + if (originType->id() == ::arrow::Type::DECIMAL256 && + inferredType->id() == ::arrow::Type::DECIMAL128) { + inferred->field = inferred->field->WithType(originType); modified = true; } - // Restore field metadata - std::shared_ptr field_metadata = - origin_field.metadata(); - if (field_metadata != nullptr) { + // Restore field metadata. + std::shared_ptr fieldMetadata = + originField.metadata(); + if (fieldMetadata != nullptr) { if (inferred->field->metadata()) { - // Prefer the metadata keys (like field_id) from the current metadata - field_metadata = field_metadata->Merge(*inferred->field->metadata()); + // Prefer the metadata keys (like field_id) from the current metadata. + fieldMetadata = fieldMetadata->Merge(*inferred->field->metadata()); } - inferred->field = inferred->field->WithMetadata(field_metadata); + inferred->field = inferred->field->WithMetadata(fieldMetadata); modified = true; } return modified; } -Result ApplyOriginalMetadata( - const Field& origin_field, +Result applyOriginalMetadata( + const Field& originField, SchemaField* inferred) { bool modified = false; - auto& origin_type = origin_field.type(); + auto& originType = originField.type(); - if (origin_type->id() == ::arrow::Type::EXTENSION) { - const auto& ex_type = - checked_cast(*origin_type); - auto origin_storage_field = origin_field.WithType(ex_type.storage_type()); + if (originType->id() == ::arrow::Type::EXTENSION) { + const auto& exType = + checked_cast(*originType); + auto originStorageField = originField.WithType(exType.storage_type()); - // Apply metadata recursively to storage type - RETURN_NOT_OK( - ApplyOriginalStorageMetadata(*origin_storage_field, inferred)); + // Apply metadata recursively to storage type. + RETURN_NOT_OK(applyOriginalStorageMetadata(*originStorageField, inferred)); // Restore extension type, if the storage type is the same as inferred - // from the Parquet type - if (ex_type.storage_type()->Equals(*inferred->field->type())) { - inferred->field = inferred->field->WithType(origin_type); + // from the Parquet type. + if (exType.storage_type()->Equals(*inferred->field->type())) { + inferred->field = inferred->field->WithType(originType); } modified = true; } else { ARROW_ASSIGN_OR_RAISE( - modified, ApplyOriginalStorageMetadata(origin_field, inferred)); + modified, applyOriginalStorageMetadata(originField, inferred)); } return modified; @@ -1222,79 +1195,87 @@ Result ApplyOriginalMetadata( } // namespace -Status FieldToNode( +std::shared_ptr<::arrow::KeyValueMetadata> fieldIdMetadata(int fieldId) { + if (fieldId >= 0) { + return ::arrow::key_value_metadata({FIELD_ID_KEY}, {toChars(fieldId)}); + } else { + return nullptr; + } +} + +Status fieldToNode( const std::shared_ptr& field, const WriterProperties& properties, - const ArrowWriterProperties& arrow_properties, + const ArrowWriterProperties& arrowProperties, NodePtr* out) { - return FieldToNode(field->name(), field, properties, arrow_properties, out); + return fieldToNode(field->name(), field, properties, arrowProperties, out); } -Status ToParquetSchema( - const ::arrow::Schema* arrow_schema, +Status toParquetSchema( + const ::arrow::Schema* arrowSchema, const WriterProperties& properties, - const ArrowWriterProperties& arrow_properties, + const ArrowWriterProperties& arrowProperties, std::shared_ptr* out) { - std::vector nodes(arrow_schema->num_fields()); - for (int i = 0; i < arrow_schema->num_fields(); i++) { - RETURN_NOT_OK(FieldToNode( - arrow_schema->field(i), properties, arrow_properties, &nodes[i])); + std::vector nodes(arrowSchema->num_fields()); + for (int i = 0; i < arrowSchema->num_fields(); i++) { + RETURN_NOT_OK(fieldToNode( + arrowSchema->field(i), properties, arrowProperties, &nodes[i])); } - NodePtr schema = GroupNode::Make("schema", Repetition::REQUIRED, nodes); + NodePtr schema = GroupNode::make("schema", Repetition::kRequired, nodes); *out = std::make_shared(); - PARQUET_CATCH_NOT_OK((*out)->Init(schema)); + PARQUET_CATCH_NOT_OK((*out)->init(schema)); return Status::OK(); } -Status ToParquetSchema( - const ::arrow::Schema* arrow_schema, +Status toParquetSchema( + const ::arrow::Schema* arrowSchema, const WriterProperties& properties, std::shared_ptr* out) { - return ToParquetSchema( - arrow_schema, properties, *default_arrow_writer_properties(), out); + return toParquetSchema( + arrowSchema, properties, *defaultArrowWriterProperties(), out); } -Status FromParquetSchema( +Status fromParquetSchema( const SchemaDescriptor* schema, const ArrowReaderProperties& properties, - const std::shared_ptr& key_value_metadata, + const std::shared_ptr& keyValueMetadata, std::shared_ptr<::arrow::Schema>* out) { SchemaManifest manifest; RETURN_NOT_OK( - SchemaManifest::Make(schema, key_value_metadata, properties, &manifest)); - std::vector> fields(manifest.schema_fields.size()); + SchemaManifest::make(schema, keyValueMetadata, properties, &manifest)); + std::vector> fields(manifest.schemaFields.size()); for (int i = 0; i < static_cast(fields.size()); i++) { - const auto& schema_field = manifest.schema_fields[i]; - fields[i] = schema_field.field; + const auto& schemaField = manifest.schemaFields[i]; + fields[i] = schemaField.field; } - if (manifest.origin_schema) { + if (manifest.originSchema) { // ARROW-8980: If the ARROW:schema was in the input metadata, then - // manifest.origin_schema will have it scrubbed out - *out = ::arrow::schema(fields, manifest.origin_schema->metadata()); + // manifest.originSchema will have it scrubbed out. + *out = ::arrow::schema(fields, manifest.originSchema->metadata()); } else { - *out = ::arrow::schema(fields, key_value_metadata); + *out = ::arrow::schema(fields, keyValueMetadata); } return Status::OK(); } -Status FromParquetSchema( - const SchemaDescriptor* parquet_schema, +Status fromParquetSchema( + const SchemaDescriptor* parquetSchema, const ArrowReaderProperties& properties, std::shared_ptr<::arrow::Schema>* out) { - return FromParquetSchema(parquet_schema, properties, nullptr, out); + return fromParquetSchema(parquetSchema, properties, nullptr, out); } -Status FromParquetSchema( - const SchemaDescriptor* parquet_schema, +Status fromParquetSchema( + const SchemaDescriptor* parquetSchema, std::shared_ptr<::arrow::Schema>* out) { ArrowReaderProperties properties; - return FromParquetSchema(parquet_schema, properties, nullptr, out); + return fromParquetSchema(parquetSchema, properties, nullptr, out); } -Status SchemaManifest::Make( +Status SchemaManifest::make( const SchemaDescriptor* schema, const std::shared_ptr& metadata, const ArrowReaderProperties& properties, @@ -1303,38 +1284,34 @@ Status SchemaManifest::Make( ctx.manifest = manifest; ctx.properties = properties; ctx.schema = schema; - const GroupNode& schema_node = *schema->group_node(); + const GroupNode& schemaNode = *schema->groupNode(); manifest->descr = schema; - manifest->schema_fields.resize(schema_node.field_count()); - - // Try to deserialize original Arrow schema - RETURN_NOT_OK(GetOriginSchema( - metadata, &manifest->schema_metadata, &manifest->origin_schema)); - // Ignore original schema if it's not compatible with the Parquet schema - if (manifest->origin_schema != nullptr && - manifest->origin_schema->num_fields() != schema_node.field_count()) { - manifest->origin_schema = nullptr; + manifest->schemaFields.resize(schemaNode.fieldCount()); + + // Try to deserialize original Arrow schema. + RETURN_NOT_OK(getOriginSchema( + metadata, &manifest->schemaMetadata, &manifest->originSchema)); + // Ignore original schema if it's not compatible with the Parquet schema. + if (manifest->originSchema != nullptr && + manifest->originSchema->num_fields() != schemaNode.fieldCount()) { + manifest->originSchema = nullptr; } - for (int i = 0; i < static_cast(schema_node.field_count()); ++i) { - SchemaField* out_field = &manifest->schema_fields[i]; - RETURN_NOT_OK(NodeToSchemaField( - *schema_node.field(i), - LevelInfo(), - &ctx, - /*parent=*/nullptr, - out_field)); + for (int i = 0; i < static_cast(schemaNode.fieldCount()); ++i) { + SchemaField* outField = &manifest->schemaFields[i]; + RETURN_NOT_OK(nodeToSchemaField( + *schemaNode.field(i), LevelInfo(), &ctx, nullptr, outField)); - // TODO(wesm): as follow up to ARROW-3246, we should really pass the origin + // TODO(wesm): As follow up to ARROW-3246, we should really pass the origin // schema (if any) through all functions in the schema reconstruction, but // I'm being lazy and just setting dictionary fields at the top level for - // now - if (manifest->origin_schema == nullptr) { + // now. + if (manifest->originSchema == nullptr) { continue; } - auto& origin_field = manifest->origin_schema->field(i); - RETURN_NOT_OK(ApplyOriginalMetadata(*origin_field, out_field)); + auto& originField = manifest->originSchema->field(i); + RETURN_NOT_OK(applyOriginalMetadata(*originField, outField)); } return Status::OK(); } diff --git a/velox/dwio/parquet/writer/arrow/ArrowSchema.h b/velox/dwio/parquet/writer/arrow/ArrowSchema.h index 8302bc1cdb1..6a49edce39c 100644 --- a/velox/dwio/parquet/writer/arrow/ArrowSchema.h +++ b/velox/dwio/parquet/writer/arrow/ArrowSchema.h @@ -47,153 +47,155 @@ namespace arrow { /// @{ PARQUET_EXPORT -::arrow::Status FieldToNode( +::arrow::Status fieldToNode( const std::shared_ptr<::arrow::Field>& field, const WriterProperties& properties, - const ArrowWriterProperties& arrow_properties, + const ArrowWriterProperties& arrowProperties, schema::NodePtr* out); PARQUET_EXPORT -::arrow::Status ToParquetSchema( - const ::arrow::Schema* arrow_schema, +::arrow::Status toParquetSchema( + const ::arrow::Schema* arrowSchema, const WriterProperties& properties, - const ArrowWriterProperties& arrow_properties, + const ArrowWriterProperties& arrowProperties, std::shared_ptr* out); PARQUET_EXPORT -::arrow::Status ToParquetSchema( - const ::arrow::Schema* arrow_schema, +::arrow::Status toParquetSchema( + const ::arrow::Schema* arrowSchema, const WriterProperties& properties, std::shared_ptr* out); /// @} /// \defgroup parquet-to-arrow-schema-conversion Functions to convert a Parquet -/// schema into an Arrow schema. +/// schema into an arrow schema. /// /// @{ PARQUET_EXPORT -::arrow::Status FromParquetSchema( - const SchemaDescriptor* parquet_schema, +::arrow::Status fromParquetSchema( + const SchemaDescriptor* parquetSchema, const ArrowReaderProperties& properties, - const std::shared_ptr& key_value_metadata, + const std::shared_ptr& keyValueMetadata, std::shared_ptr<::arrow::Schema>* out); PARQUET_EXPORT -::arrow::Status FromParquetSchema( - const SchemaDescriptor* parquet_schema, +::arrow::Status fromParquetSchema( + const SchemaDescriptor* parquetSchema, const ArrowReaderProperties& properties, std::shared_ptr<::arrow::Schema>* out); PARQUET_EXPORT -::arrow::Status FromParquetSchema( - const SchemaDescriptor* parquet_schema, +::arrow::Status fromParquetSchema( + const SchemaDescriptor* parquetSchema, std::shared_ptr<::arrow::Schema>* out); /// @} -/// \brief Bridge between an arrow::Field and parquet column indices. +/// \brief Bridge between an arrow::Field and Parquet column indices. struct PARQUET_EXPORT SchemaField { std::shared_ptr<::arrow::Field> field; std::vector children; - // Only set for leaf nodes - int column_index = -1; + // Only set for leaf nodes. + int columnIndex = -1; - LevelInfo level_info; + LevelInfo levelInfo; - bool is_leaf() const { - return column_index != -1; + bool isLeaf() const { + return columnIndex != -1; } }; /// \brief Bridge between a parquet Schema and an arrow Schema. /// -/// Expose parquet columns as a tree structure. Useful traverse and link -/// between arrow's Schema and parquet's Schema. +/// Expose Parquet columns as a tree structure. Useful to traverse and link +/// between Arrow and Parquet schemas. struct PARQUET_EXPORT SchemaManifest { - static ::arrow::Status Make( + static ::arrow::Status make( const SchemaDescriptor* schema, const std::shared_ptr& metadata, const ArrowReaderProperties& properties, SchemaManifest* manifest); const SchemaDescriptor* descr; - std::shared_ptr<::arrow::Schema> origin_schema; - std::shared_ptr schema_metadata; - std::vector schema_fields; + std::shared_ptr<::arrow::Schema> originSchema; + std::shared_ptr schemaMetadata; + std::vector schemaFields; - std::unordered_map column_index_to_field; - std::unordered_map child_to_parent; + std::unordered_map columnIndexToField; + std::unordered_map childToParent; - ::arrow::Status GetColumnField(int column_index, const SchemaField** out) + ::arrow::Status getColumnField(int columnIndex, const SchemaField** out) const { - auto it = column_index_to_field.find(column_index); - if (it == column_index_to_field.end()) { + auto it = columnIndexToField.find(columnIndex); + if (it == columnIndexToField.end()) { return ::arrow::Status::KeyError( "Column index ", - column_index, + columnIndex, " not found in schema manifest, may be malformed"); } *out = it->second; return ::arrow::Status::OK(); } - const SchemaField* GetParent(const SchemaField* field) const { - // Returns nullptr also if not found - auto it = child_to_parent.find(field); - if (it == child_to_parent.end()) { + const SchemaField* getParent(const SchemaField* field) const { + // Returns nullptr also if not found. + auto it = childToParent.find(field); + if (it == childToParent.end()) { return NULLPTR; } return it->second; } /// Coalesce a list of field indices (relative to the equivalent - /// arrow::Schema) which correspond to the column root (first node below the - /// parquet schema's root group) of each leaf referenced in column_indices. + /// Arrow schema) which correspond to the column root (first node below the + /// Parquet schema's root group) of each leaf referenced in columnIndices. /// - /// For example, for leaves `a.b.c`, `a.b.d.e`, and `i.j.k` - /// (column_indices=[0,1,3]) the roots are `a` and `i` (return=[0,2]). + /// For example, for leaves `a.b.c`, `a.b.d.e`, and `i.j.k`. + /// (Column_indices=[0,1,3]) the roots are `a` and `i` (return=[0,2]). /// - /// root - /// -- a <------ - /// -- -- b | | - /// -- -- -- c | - /// -- -- -- d | - /// -- -- -- -- e - /// -- f - /// -- -- g - /// -- -- -- h - /// -- i <--- - /// -- -- j | - /// -- -- -- k - ::arrow::Result> GetFieldIndices( - const std::vector& column_indices) const { - const schema::GroupNode* group = descr->group_node(); - std::unordered_set already_added; + /// Root. + /// -- A <------. + /// -- -- B | |. + /// -- -- -- C |. + /// -- -- -- D |. + /// -- -- -- -- E. + /// -- F. + /// -- -- G. + /// -- -- -- H. + /// -- I <---. + /// -- -- J |. + /// -- -- -- K. + ::arrow::Result> getFieldIndices( + const std::vector& columnIndices) const { + const schema::GroupNode* group = descr->groupNode(); + std::unordered_set alreadyAdded; std::vector out; - for (int column_idx : column_indices) { - if (column_idx < 0 || column_idx >= descr->num_columns()) { + for (int columnIdx : columnIndices) { + if (columnIdx < 0 || columnIdx >= descr->numColumns()) { return ::arrow::Status::IndexError( - "Column index ", column_idx, " is not valid"); + "Column index ", columnIdx, " is not valid"); } - auto field_node = descr->GetColumnRoot(column_idx); - auto field_idx = group->FieldIndex(*field_node); - if (field_idx == -1) { + auto fieldNode = descr->getColumnRoot(columnIdx); + auto fieldIdx = group->fieldIndex(*fieldNode); + if (fieldIdx == -1) { return ::arrow::Status::IndexError( - "Column index ", column_idx, " is not valid"); + "Column index ", columnIdx, " is not valid"); } - if (already_added.insert(field_idx).second) { - out.push_back(field_idx); + if (alreadyAdded.insert(fieldIdx).second) { + out.push_back(fieldIdx); } } return out; } }; +std::shared_ptr<::arrow::KeyValueMetadata> fieldIdMetadata(int32_t fieldId); + } // namespace arrow } // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/ArrowSchemaInternal.cpp b/velox/dwio/parquet/writer/arrow/ArrowSchemaInternal.cpp index 246df606ce3..8681462bfc7 100644 --- a/velox/dwio/parquet/writer/arrow/ArrowSchemaInternal.cpp +++ b/velox/dwio/parquet/writer/arrow/ArrowSchemaInternal.cpp @@ -30,217 +30,216 @@ using ::arrow::Result; using ::arrow::Status; using ::arrow::internal::checked_cast; -Result> MakeArrowDecimal( - const LogicalType& logical_type) { - const auto& decimal = checked_cast(logical_type); +Result> makeArrowDecimal( + const LogicalType& logicalType) { + const auto& decimal = checked_cast(logicalType); if (decimal.precision() <= ::arrow::Decimal128Type::kMaxPrecision) { return ::arrow::Decimal128Type::Make(decimal.precision(), decimal.scale()); } return ::arrow::Decimal256Type::Make(decimal.precision(), decimal.scale()); } -Result> MakeArrowInt( - const LogicalType& logical_type) { - const auto& integer = checked_cast(logical_type); - switch (integer.bit_width()) { +Result> makeArrowInt( + const LogicalType& logicalType) { + const auto& integer = checked_cast(logicalType); + switch (integer.bitWidth()) { case 8: - return integer.is_signed() ? ::arrow::int8() : ::arrow::uint8(); + return integer.isSigned() ? ::arrow::int8() : ::arrow::uint8(); case 16: - return integer.is_signed() ? ::arrow::int16() : ::arrow::uint16(); + return integer.isSigned() ? ::arrow::int16() : ::arrow::uint16(); case 32: - return integer.is_signed() ? ::arrow::int32() : ::arrow::uint32(); + return integer.isSigned() ? ::arrow::int32() : ::arrow::uint32(); default: return Status::TypeError( - logical_type.ToString(), " can not annotate physical type Int32"); + logicalType.toString(), " can not annotate physical type Int32"); } } -Result> MakeArrowInt64( - const LogicalType& logical_type) { - const auto& integer = checked_cast(logical_type); - switch (integer.bit_width()) { +Result> makeArrowInt64( + const LogicalType& logicalType) { + const auto& integer = checked_cast(logicalType); + switch (integer.bitWidth()) { case 64: - return integer.is_signed() ? ::arrow::int64() : ::arrow::uint64(); + return integer.isSigned() ? ::arrow::int64() : ::arrow::uint64(); default: return Status::TypeError( - logical_type.ToString(), " can not annotate physical type Int64"); + logicalType.toString(), " can not annotate physical type Int64"); } } -Result> MakeArrowTime32( - const LogicalType& logical_type) { - const auto& time = checked_cast(logical_type); - switch (time.time_unit()) { - case LogicalType::TimeUnit::MILLIS: +Result> makeArrowTime32( + const LogicalType& logicalType) { + const auto& time = checked_cast(logicalType); + switch (time.timeUnit()) { + case LogicalType::TimeUnit::kMillis: return ::arrow::time32(::arrow::TimeUnit::MILLI); default: return Status::TypeError( - logical_type.ToString(), " can not annotate physical type Time32"); + logicalType.toString(), " can not annotate physical type Time32"); } } -Result> MakeArrowTime64( - const LogicalType& logical_type) { - const auto& time = checked_cast(logical_type); - switch (time.time_unit()) { - case LogicalType::TimeUnit::MICROS: +Result> makeArrowTime64( + const LogicalType& logicalType) { + const auto& time = checked_cast(logicalType); + switch (time.timeUnit()) { + case LogicalType::TimeUnit::kMicros: return ::arrow::time64(::arrow::TimeUnit::MICRO); - case LogicalType::TimeUnit::NANOS: + case LogicalType::TimeUnit::kNanos: return ::arrow::time64(::arrow::TimeUnit::NANO); default: return Status::TypeError( - logical_type.ToString(), " can not annotate physical type Time64"); + logicalType.toString(), " can not annotate physical type Time64"); } } -Result> MakeArrowTimestamp( - const LogicalType& logical_type) { +Result> makeArrowTimestamp( + const LogicalType& logicalType) { const auto& timestamp = - checked_cast(logical_type); - const bool utc_normalized = timestamp.is_from_converted_type() - ? false - : timestamp.is_adjusted_to_utc(); - static const char* utc_timezone = "UTC"; - switch (timestamp.time_unit()) { - case LogicalType::TimeUnit::MILLIS: + checked_cast(logicalType); + const bool utcNormalized = + timestamp.isFromConvertedType() ? false : timestamp.isAdjustedToUtc(); + static const char* utcTimezone = "UTC"; + switch (timestamp.timeUnit()) { + case LogicalType::TimeUnit::kMillis: return ( - utc_normalized - ? ::arrow::timestamp(::arrow::TimeUnit::MILLI, utc_timezone) + utcNormalized + ? ::arrow::timestamp(::arrow::TimeUnit::MILLI, utcTimezone) : ::arrow::timestamp(::arrow::TimeUnit::MILLI)); - case LogicalType::TimeUnit::MICROS: + case LogicalType::TimeUnit::kMicros: return ( - utc_normalized - ? ::arrow::timestamp(::arrow::TimeUnit::MICRO, utc_timezone) + utcNormalized + ? ::arrow::timestamp(::arrow::TimeUnit::MICRO, utcTimezone) : ::arrow::timestamp(::arrow::TimeUnit::MICRO)); - case LogicalType::TimeUnit::NANOS: + case LogicalType::TimeUnit::kNanos: return ( - utc_normalized - ? ::arrow::timestamp(::arrow::TimeUnit::NANO, utc_timezone) + utcNormalized + ? ::arrow::timestamp(::arrow::TimeUnit::NANO, utcTimezone) : ::arrow::timestamp(::arrow::TimeUnit::NANO)); default: return Status::TypeError( "Unrecognized time unit in timestamp logical_type: ", - logical_type.ToString()); + logicalType.toString()); } } -Result> FromByteArray( - const LogicalType& logical_type) { - switch (logical_type.type()) { - case LogicalType::Type::STRING: +Result> fromByteArray( + const LogicalType& logicalType) { + switch (logicalType.type()) { + case LogicalType::Type::kString: return ::arrow::utf8(); - case LogicalType::Type::DECIMAL: - return MakeArrowDecimal(logical_type); - case LogicalType::Type::NONE: - case LogicalType::Type::ENUM: - case LogicalType::Type::JSON: - case LogicalType::Type::BSON: + case LogicalType::Type::kDecimal: + return makeArrowDecimal(logicalType); + case LogicalType::Type::kNone: + case LogicalType::Type::kEnum: + case LogicalType::Type::kJson: + case LogicalType::Type::kBson: return ::arrow::binary(); default: return Status::NotImplemented( "Unhandled logical logical_type ", - logical_type.ToString(), + logicalType.toString(), " for binary array"); } } -Result> FromFLBA( - const LogicalType& logical_type, - int32_t physical_length) { - switch (logical_type.type()) { - case LogicalType::Type::DECIMAL: - return MakeArrowDecimal(logical_type); - case LogicalType::Type::NONE: - case LogicalType::Type::INTERVAL: - case LogicalType::Type::UUID: - return ::arrow::fixed_size_binary(physical_length); +Result> fromFLBA( + const LogicalType& logicalType, + int32_t physicalLength) { + switch (logicalType.type()) { + case LogicalType::Type::kDecimal: + return makeArrowDecimal(logicalType); + case LogicalType::Type::kNone: + case LogicalType::Type::kInterval: + case LogicalType::Type::kUuid: + return ::arrow::fixed_size_binary(physicalLength); default: return Status::NotImplemented( "Unhandled logical logical_type ", - logical_type.ToString(), + logicalType.toString(), " for fixed-length binary array"); } } -::arrow::Result> FromInt32( - const LogicalType& logical_type) { - switch (logical_type.type()) { - case LogicalType::Type::INT: - return MakeArrowInt(logical_type); - case LogicalType::Type::DATE: +::arrow::Result> fromInt32( + const LogicalType& logicalType) { + switch (logicalType.type()) { + case LogicalType::Type::kInt: + return makeArrowInt(logicalType); + case LogicalType::Type::kDate: return ::arrow::date32(); - case LogicalType::Type::TIME: - return MakeArrowTime32(logical_type); - case LogicalType::Type::DECIMAL: - return MakeArrowDecimal(logical_type); - case LogicalType::Type::NONE: + case LogicalType::Type::kTime: + return makeArrowTime32(logicalType); + case LogicalType::Type::kDecimal: + return makeArrowDecimal(logicalType); + case LogicalType::Type::kNone: return ::arrow::int32(); default: return Status::NotImplemented( - "Unhandled logical type ", logical_type.ToString(), " for INT32"); + "Unhandled logical type ", logicalType.toString(), " for INT32"); } } -Result> FromInt64(const LogicalType& logical_type) { - switch (logical_type.type()) { - case LogicalType::Type::INT: - return MakeArrowInt64(logical_type); - case LogicalType::Type::DECIMAL: - return MakeArrowDecimal(logical_type); - case LogicalType::Type::TIMESTAMP: - return MakeArrowTimestamp(logical_type); - case LogicalType::Type::TIME: - return MakeArrowTime64(logical_type); - case LogicalType::Type::NONE: +Result> fromInt64(const LogicalType& logicalType) { + switch (logicalType.type()) { + case LogicalType::Type::kInt: + return makeArrowInt64(logicalType); + case LogicalType::Type::kDecimal: + return makeArrowDecimal(logicalType); + case LogicalType::Type::kTimestamp: + return makeArrowTimestamp(logicalType); + case LogicalType::Type::kTime: + return makeArrowTime64(logicalType); + case LogicalType::Type::kNone: return ::arrow::int64(); default: return Status::NotImplemented( - "Unhandled logical type ", logical_type.ToString(), " for INT64"); + "Unhandled logical type ", logicalType.toString(), " for INT64"); } } -Result> GetArrowType( - Type::type physical_type, - const LogicalType& logical_type, - int type_length, - const ::arrow::TimeUnit::type int96_arrow_time_unit) { - if (logical_type.is_invalid() || logical_type.is_null()) { +Result> getArrowType( + Type::type physicalType, + const LogicalType& logicalType, + int typeLength, + const ::arrow::TimeUnit::type int96ArrowTimeUnit) { + if (logicalType.isInvalid() || logicalType.isNull()) { return ::arrow::null(); } - switch (physical_type) { - case ParquetType::BOOLEAN: + switch (physicalType) { + case ParquetType::kBoolean: return ::arrow::boolean(); - case ParquetType::INT32: - return FromInt32(logical_type); - case ParquetType::INT64: - return FromInt64(logical_type); - case ParquetType::INT96: - return ::arrow::timestamp(int96_arrow_time_unit); - case ParquetType::FLOAT: + case ParquetType::kInt32: + return fromInt32(logicalType); + case ParquetType::kInt64: + return fromInt64(logicalType); + case ParquetType::kInt96: + return ::arrow::timestamp(int96ArrowTimeUnit); + case ParquetType::kFloat: return ::arrow::float32(); - case ParquetType::DOUBLE: + case ParquetType::kDouble: return ::arrow::float64(); - case ParquetType::BYTE_ARRAY: - return FromByteArray(logical_type); - case ParquetType::FIXED_LEN_BYTE_ARRAY: - return FromFLBA(logical_type, type_length); + case ParquetType::kByteArray: + return fromByteArray(logicalType); + case ParquetType::kFixedLenByteArray: + return fromFLBA(logicalType, typeLength); default: { - // PARQUET-1565: This can occur if the file is corrupt + // PARQUET-1565: This can occur if the file is corrupt. return Status::IOError( - "Invalid physical column type: ", TypeToString(physical_type)); + "Invalid physical column type: ", typeToString(physicalType)); } } } -Result> GetArrowType( +Result> getArrowType( const schema::PrimitiveNode& primitive, - const ::arrow::TimeUnit::type int96_arrow_time_unit) { - return GetArrowType( - primitive.physical_type(), - *primitive.logical_type(), - primitive.type_length(), - int96_arrow_time_unit); + const ::arrow::TimeUnit::type int96ArrowTimeUnit) { + return getArrowType( + primitive.physicalType(), + *primitive.logicalType(), + primitive.typeLength(), + int96ArrowTimeUnit); } } // namespace facebook::velox::parquet::arrow::arrow diff --git a/velox/dwio/parquet/writer/arrow/ArrowSchemaInternal.h b/velox/dwio/parquet/writer/arrow/ArrowSchemaInternal.h index 8b7b443db53..ebe29843a54 100644 --- a/velox/dwio/parquet/writer/arrow/ArrowSchemaInternal.h +++ b/velox/dwio/parquet/writer/arrow/ArrowSchemaInternal.h @@ -29,29 +29,29 @@ namespace facebook::velox::parquet::arrow::arrow { using ::arrow::Result; -Result> FromByteArray( - const LogicalType& logical_type); -Result> FromFLBA( - const LogicalType& logical_type, - int32_t physical_length); -Result> FromInt32( - const LogicalType& logical_type); -Result> FromInt64( - const LogicalType& logical_type); - -Result> GetArrowType( - Type::type physical_type, - const LogicalType& logical_type, - int type_length); - -Result> GetArrowType( - Type::type physical_type, - const LogicalType& logical_type, - int type_length, - ::arrow::TimeUnit::type int96_arrow_time_unit = ::arrow::TimeUnit::NANO); - -Result> GetArrowType( +Result> fromByteArray( + const LogicalType& logicalType); +Result> fromFLBA( + const LogicalType& logicalType, + int32_t physicalLength); +Result> fromInt32( + const LogicalType& logicalType); +Result> fromInt64( + const LogicalType& logicalType); + +Result> getArrowType( + Type::type physicalType, + const LogicalType& logicalType, + int typeLength); + +Result> getArrowType( + Type::type physicalType, + const LogicalType& logicalType, + int typeLength, + ::arrow::TimeUnit::type int96ArrowTimeUnit = ::arrow::TimeUnit::NANO); + +Result> getArrowType( const schema::PrimitiveNode& primitive, - ::arrow::TimeUnit::type int96_arrow_time_unit = ::arrow::TimeUnit::NANO); + ::arrow::TimeUnit::type int96ArrowTimeUnit = ::arrow::TimeUnit::NANO); } // namespace facebook::velox::parquet::arrow::arrow diff --git a/velox/dwio/parquet/writer/arrow/CMakeLists.txt b/velox/dwio/parquet/writer/arrow/CMakeLists.txt index b70c5f25e9b..b00b2f8a662 100644 --- a/velox/dwio/parquet/writer/arrow/CMakeLists.txt +++ b/velox/dwio/parquet/writer/arrow/CMakeLists.txt @@ -37,8 +37,33 @@ velox_add_library( Properties.cpp Schema.cpp Statistics.cpp + StringTruncation.cpp Types.cpp Writer.cpp + HEADERS + ArrowSchema.h + ArrowSchemaInternal.h + ColumnPage.h + ColumnWriter.h + Encoding.h + Encryption.h + EncryptionInternal.h + Exception.h + FileDecryptorInternal.h + FileEncryptorInternal.h + FileWriter.h + Metadata.h + PageIndex.h + PathInternal.h + Platform.h + Properties.h + Schema.h + SchemaInternal.h + Statistics.h + StringTruncation.h + ThriftInternal.h + Types.h + Writer.h ) velox_link_libraries( diff --git a/velox/dwio/parquet/writer/arrow/ColumnPage.h b/velox/dwio/parquet/writer/arrow/ColumnPage.h index f6c20de01c8..df514dcc8f6 100644 --- a/velox/dwio/parquet/writer/arrow/ColumnPage.h +++ b/velox/dwio/parquet/writer/arrow/ColumnPage.h @@ -35,11 +35,11 @@ namespace facebook::velox::parquet::arrow { // TODO: Parallel processing is not yet safe because of memory-ownership // semantics (the PageReader may or may not own the memory referenced by a -// page) +// page). // // TODO(wesm): In the future Parquet implementations may store the crc code // in facebook::velox::parquet::thrift::PageHeader. parquet-mr currently does -// not, so we also skip it here, both on the read and write path +// not, so we also skip it here, both on the read and write path. class Page { public: Page(const std::shared_ptr<::arrow::Buffer>& buffer, PageType::type type) @@ -53,12 +53,12 @@ class Page { return buffer_; } - // @returns: a pointer to the page's data + // @returns: A pointer to the page's data. const uint8_t* data() const { return buffer_->data(); } - // @returns: the total size in bytes of the page's data buffer + // @returns: The total size in bytes of the page's data buffer. int32_t size() const { return static_cast(buffer_->size()); } @@ -68,26 +68,26 @@ class Page { PageType::type type_; }; -/// \brief Base type for DataPageV1 and DataPageV2 including common attributes +/// \brief Base type for DataPageV1 and DataPageV2 including common attributes. class DataPage : public Page { public: - int32_t num_values() const { - return num_values_; + int32_t numValues() const { + return numValues_; } Encoding::type encoding() const { return encoding_; } - int64_t uncompressed_size() const { - return uncompressed_size_; + int64_t uncompressedSize() const { + return uncompressedSize_; } const EncodedStatistics& statistics() const { return statistics_; } /// Return the row ordinal within the row group to the first row in the data - /// page. Currently it is only present from data pages created by ColumnWriter - /// in order to collect page index. - std::optional first_row_index() const { - return first_row_index_; + /// page. Currently it is only present from data pages created by + /// ColumnWriter in order to collect page index. + std::optional firstRowIndex() const { + return firstRowIndex_; } virtual ~DataPage() = default; @@ -96,145 +96,145 @@ class DataPage : public Page { DataPage( PageType::type type, const std::shared_ptr<::arrow::Buffer>& buffer, - int32_t num_values, + int32_t numValues, Encoding::type encoding, - int64_t uncompressed_size, + int64_t uncompressedSize, const EncodedStatistics& statistics = EncodedStatistics(), - std::optional first_row_index = std::nullopt) + std::optional firstRowIndex = std::nullopt) : Page(buffer, type), - num_values_(num_values), + numValues_(numValues), encoding_(encoding), - uncompressed_size_(uncompressed_size), + uncompressedSize_(uncompressedSize), statistics_(statistics), - first_row_index_(std::move(first_row_index)) {} + firstRowIndex_(std::move(firstRowIndex)) {} - int32_t num_values_; + int32_t numValues_; Encoding::type encoding_; - int64_t uncompressed_size_; + int64_t uncompressedSize_; EncodedStatistics statistics_; /// Row ordinal within the row group to the first row in the data page. - std::optional first_row_index_; + std::optional firstRowIndex_; }; class DataPageV1 : public DataPage { public: DataPageV1( const std::shared_ptr<::arrow::Buffer>& buffer, - int32_t num_values, + int32_t numValues, Encoding::type encoding, - Encoding::type definition_level_encoding, - Encoding::type repetition_level_encoding, - int64_t uncompressed_size, + Encoding::type definitionLevelEncoding, + Encoding::type repetitionLevelEncoding, + int64_t uncompressedSize, const EncodedStatistics& statistics = EncodedStatistics(), - std::optional first_row_index = std::nullopt) + std::optional firstRowIndex = std::nullopt) : DataPage( - PageType::DATA_PAGE, + PageType::kDataPage, buffer, - num_values, + numValues, encoding, - uncompressed_size, + uncompressedSize, statistics, - std::move(first_row_index)), - definition_level_encoding_(definition_level_encoding), - repetition_level_encoding_(repetition_level_encoding) {} + std::move(firstRowIndex)), + definitionLevelEncoding_(definitionLevelEncoding), + repetitionLevelEncoding_(repetitionLevelEncoding) {} - Encoding::type repetition_level_encoding() const { - return repetition_level_encoding_; + Encoding::type repetitionLevelEncoding() const { + return repetitionLevelEncoding_; } - Encoding::type definition_level_encoding() const { - return definition_level_encoding_; + Encoding::type definitionLevelEncoding() const { + return definitionLevelEncoding_; } private: - Encoding::type definition_level_encoding_; - Encoding::type repetition_level_encoding_; + Encoding::type definitionLevelEncoding_; + Encoding::type repetitionLevelEncoding_; }; class DataPageV2 : public DataPage { public: DataPageV2( const std::shared_ptr<::arrow::Buffer>& buffer, - int32_t num_values, - int32_t num_nulls, - int32_t num_rows, + int32_t numValues, + int32_t numNulls, + int32_t numRows, Encoding::type encoding, - int32_t definition_levels_byte_length, - int32_t repetition_levels_byte_length, - int64_t uncompressed_size, - bool is_compressed = false, + int32_t definitionLevelsByteLength, + int32_t repetitionLevelsByteLength, + int64_t uncompressedSize, + bool isCompressed = false, const EncodedStatistics& statistics = EncodedStatistics(), - std::optional first_row_index = std::nullopt) + std::optional firstRowIndex = std::nullopt) : DataPage( - PageType::DATA_PAGE_V2, + PageType::kDataPageV2, buffer, - num_values, + numValues, encoding, - uncompressed_size, + uncompressedSize, statistics, - std::move(first_row_index)), - num_nulls_(num_nulls), - num_rows_(num_rows), - definition_levels_byte_length_(definition_levels_byte_length), - repetition_levels_byte_length_(repetition_levels_byte_length), - is_compressed_(is_compressed) {} + std::move(firstRowIndex)), + numNulls_(numNulls), + numRows_(numRows), + definitionLevelsByteLength_(definitionLevelsByteLength), + repetitionLevelsByteLength_(repetitionLevelsByteLength), + isCompressed_(isCompressed) {} - int32_t num_nulls() const { - return num_nulls_; + int32_t numNulls() const { + return numNulls_; } - int32_t num_rows() const { - return num_rows_; + int32_t numRows() const { + return numRows_; } - int32_t definition_levels_byte_length() const { - return definition_levels_byte_length_; + int32_t definitionLevelsByteLength() const { + return definitionLevelsByteLength_; } - int32_t repetition_levels_byte_length() const { - return repetition_levels_byte_length_; + int32_t repetitionLevelsByteLength() const { + return repetitionLevelsByteLength_; } - bool is_compressed() const { - return is_compressed_; + bool isCompressed() const { + return isCompressed_; } private: - int32_t num_nulls_; - int32_t num_rows_; - int32_t definition_levels_byte_length_; - int32_t repetition_levels_byte_length_; - bool is_compressed_; + int32_t numNulls_; + int32_t numRows_; + int32_t definitionLevelsByteLength_; + int32_t repetitionLevelsByteLength_; + bool isCompressed_; }; class DictionaryPage : public Page { public: DictionaryPage( const std::shared_ptr<::arrow::Buffer>& buffer, - int32_t num_values, + int32_t numValues, Encoding::type encoding, - bool is_sorted = false) - : Page(buffer, PageType::DICTIONARY_PAGE), - num_values_(num_values), + bool isSorted = false) + : Page(buffer, PageType::kDictionaryPage), + numValues_(numValues), encoding_(encoding), - is_sorted_(is_sorted) {} + isSorted_(isSorted) {} - int32_t num_values() const { - return num_values_; + int32_t numValues() const { + return numValues_; } Encoding::type encoding() const { return encoding_; } - bool is_sorted() const { - return is_sorted_; + bool isSorted() const { + return isSorted_; } private: - int32_t num_values_; + int32_t numValues_; Encoding::type encoding_; - bool is_sorted_; + bool isSorted_; }; } // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/ColumnWriter.cpp b/velox/dwio/parquet/writer/arrow/ColumnWriter.cpp index d69c8ac6a12..1ba58979697 100644 --- a/velox/dwio/parquet/writer/arrow/ColumnWriter.cpp +++ b/velox/dwio/parquet/writer/arrow/ColumnWriter.cpp @@ -69,7 +69,7 @@ using arrow::internal::checked_cast; using arrow::internal::checked_pointer_cast; namespace arrow { -fmt::underlying_t format_as(Type::type type) { +fmt::underlying_t formatAs(Type::type type) { return fmt::underlying(type); } }; // namespace arrow @@ -79,14 +79,14 @@ using util::CodecOptions; namespace { -// Visitor that exracts the value buffer from a FlatArray at a given offset. +// Visitor that extracts the value buffer from a FlatArray at a given offset. struct ValueBufferSlicer { template - ::arrow::enable_if_base_binary Visit( + ::arrow::enable_if_base_binary visit( const T& array, std::shared_ptr* buffer) { auto data = array.data(); - *buffer = SliceBuffer( + *buffer = ::arrow::SliceBuffer( data->buffers[1], data->offset * sizeof(typename T::offset_type), data->length * sizeof(typename T::offset_type)); @@ -94,11 +94,11 @@ struct ValueBufferSlicer { } template - ::arrow::enable_if_fixed_size_binary Visit( + ::arrow::enable_if_fixed_size_binary visit( const T& array, std::shared_ptr* buffer) { auto data = array.data(); - *buffer = SliceBuffer( + *buffer = ::arrow::SliceBuffer( data->buffers[1], data->offset * array.byte_width(), data->length * array.byte_width()); @@ -110,9 +110,9 @@ struct ValueBufferSlicer { ::arrow::has_c_type::value && !std::is_same::value, Status> - Visit(const T& array, std::shared_ptr* buffer) { + visit(const T& array, std::shared_ptr* buffer) { auto data = array.data(); - *buffer = SliceBuffer( + *buffer = ::arrow::SliceBuffer( data->buffers[1], ::arrow::TypeTraits::bytes_required( data->offset), @@ -121,12 +121,12 @@ struct ValueBufferSlicer { return Status::OK(); } - Status Visit( + Status visit( const ::arrow::BooleanArray& array, std::shared_ptr* buffer) { auto data = array.data(); if (::arrow::bit_util::IsMultipleOf8(data->offset)) { - *buffer = SliceBuffer( + *buffer = ::arrow::SliceBuffer( data->buffers[1], ::arrow::bit_util::BytesForBits(data->offset), ::arrow::bit_util::BytesForBits(data->length)); @@ -139,7 +139,7 @@ struct ValueBufferSlicer { return Status::OK(); } #define NOT_IMPLEMENTED_VISIT(ArrowTypePrefix) \ - Status Visit( \ + Status visit( \ const ::arrow::ArrowTypePrefix##Array& array, \ std::shared_ptr* buffer) { \ return Status::NotImplemented( \ @@ -165,25 +165,25 @@ struct ValueBufferSlicer { MemoryPool* pool_; }; -LevelInfo ComputeLevelInfo(const ColumnDescriptor* descr) { - LevelInfo level_info; - level_info.defLevel = descr->max_definition_level(); - level_info.repLevel = descr->max_repetition_level(); +LevelInfo computeLevelInfo(const ColumnDescriptor* descr) { + LevelInfo levelInfo; + levelInfo.defLevel = descr->maxDefinitionLevel(); + levelInfo.repLevel = descr->maxRepetitionLevel(); - int16_t min_spaced_def_level = descr->max_definition_level(); - const schema::Node* node = descr->schema_node().get(); - while (node != nullptr && !node->is_repeated()) { - if (node->is_optional()) { - min_spaced_def_level--; + int16_t minSpacedDefLevel = descr->maxDefinitionLevel(); + const schema::Node* node = descr->schemaNode().get(); + while (node != nullptr && !node->isRepeated()) { + if (node->isOptional()) { + minSpacedDefLevel--; } node = node->parent(); } - level_info.repeatedAncestorDefLevel = min_spaced_def_level; - return level_info; + levelInfo.repeatedAncestorDefLevel = minSpacedDefLevel; + return levelInfo; } template -inline const T* AddIfNotNull(const T* base, int64_t offset) { +inline const T* addIfNotNull(const T* base, int64_t offset) { if (base != nullptr) { return base + offset; } @@ -195,23 +195,23 @@ inline const T* AddIfNotNull(const T* base, int64_t offset) { LevelEncoder::LevelEncoder() {} LevelEncoder::~LevelEncoder() {} -void LevelEncoder::Init( +void LevelEncoder::init( Encoding::type encoding, - int16_t max_level, - int num_buffered_values, + int16_t maxLevel, + int numBufferedValues, uint8_t* data, - int data_size) { - bit_width_ = ::arrow::bit_util::Log2(max_level + 1); + int dataSize) { + bitWidth_ = ::arrow::bit_util::Log2(maxLevel + 1); encoding_ = encoding; switch (encoding) { - case Encoding::RLE: { - rle_encoder_ = std::make_unique(data, data_size, bit_width_); + case Encoding::kRle: { + rleEncoder_ = std::make_unique(data, dataSize, bitWidth_); break; } - case Encoding::BIT_PACKED: { - int num_bytes = static_cast( - ::arrow::bit_util::BytesForBits(num_buffered_values * bit_width_)); - bit_packed_encoder_ = std::make_unique(data, num_bytes); + case Encoding::kBitPacked: { + int numBytes = static_cast( + ::arrow::bit_util::BytesForBits(numBufferedValues * bitWidth_)); + bitPackedEncoder_ = std::make_unique(data, numBytes); break; } default: @@ -219,60 +219,60 @@ void LevelEncoder::Init( } } -int LevelEncoder::MaxBufferSize( +int LevelEncoder::maxBufferSize( Encoding::type encoding, - int16_t max_level, - int num_buffered_values) { - int bit_width = ::arrow::bit_util::Log2(max_level + 1); - int num_bytes = 0; + int16_t maxLevel, + int numBufferedValues) { + int bitWidth = ::arrow::bit_util::Log2(maxLevel + 1); + int numBytes = 0; switch (encoding) { - case Encoding::RLE: { + case Encoding::kRle: { // TODO: Due to the way we currently check if the buffer is full enough, // we need to have MinBufferSize as head room. - num_bytes = RleEncoder::MaxBufferSize(bit_width, num_buffered_values) + - RleEncoder::MinBufferSize(bit_width); + numBytes = RleEncoder::MaxBufferSize(bitWidth, numBufferedValues) + + RleEncoder::MinBufferSize(bitWidth); break; } - case Encoding::BIT_PACKED: { - num_bytes = static_cast( - ::arrow::bit_util::BytesForBits(num_buffered_values * bit_width)); + case Encoding::kBitPacked: { + numBytes = static_cast( + ::arrow::bit_util::BytesForBits(numBufferedValues * bitWidth)); break; } default: throw ParquetException("Unknown encoding type for levels."); } - return num_bytes; + return numBytes; } -int LevelEncoder::Encode(int batch_size, const int16_t* levels) { - int num_encoded = 0; - if (!rle_encoder_ && !bit_packed_encoder_) { +int LevelEncoder::encode(int batchSize, const int16_t* levels) { + int numEncoded = 0; + if (!rleEncoder_ && !bitPackedEncoder_) { throw ParquetException("Level encoders are not initialized."); } - if (encoding_ == Encoding::RLE) { - for (int i = 0; i < batch_size; ++i) { - if (!rle_encoder_->Put(*(levels + i))) { + if (encoding_ == Encoding::kRle) { + for (int i = 0; i < batchSize; ++i) { + if (!rleEncoder_->Put(*(levels + i))) { break; } - ++num_encoded; + ++numEncoded; } - rle_encoder_->Flush(); - rle_length_ = rle_encoder_->len(); + rleEncoder_->Flush(); + rleLength_ = rleEncoder_->len(); } else { - for (int i = 0; i < batch_size; ++i) { - if (!bit_packed_encoder_->PutValue(*(levels + i), bit_width_)) { + for (int i = 0; i < batchSize; ++i) { + if (!bitPackedEncoder_->PutValue(*(levels + i), bitWidth_)) { break; } - ++num_encoded; + ++numEncoded; } - bit_packed_encoder_->Flush(); + bitPackedEncoder_->Flush(); } - return num_encoded; + return numEncoded; } -// ---------------------------------------------------------------------- -// PageWriter implementation +// ----------------------------------------------------------------------. +// PageWriter implementation. // This subclass delimits pages appearing in a serialized stream, each preceded // by a serialized Thrift facebook::velox::parquet::thrift::PageHeader @@ -283,611 +283,608 @@ class SerializedPageWriter : public PageWriter { std::shared_ptr sink, Compression::type codec, ColumnChunkMetaDataBuilder* metadata, - int16_t row_group_ordinal, - int16_t column_chunk_ordinal, - bool use_page_checksum_verification, + int16_t rowGroupOrdinal, + int16_t columnChunkOrdinal, + bool usePageChecksumVerification, MemoryPool* pool = ::arrow::default_memory_pool(), - std::shared_ptr meta_encryptor = nullptr, - std::shared_ptr data_encryptor = nullptr, - ColumnIndexBuilder* column_index_builder = nullptr, - OffsetIndexBuilder* offset_index_builder = nullptr, - const CodecOptions& codec_options = CodecOptions{}) + std::shared_ptr metaEncryptor = nullptr, + std::shared_ptr dataEncryptor = nullptr, + ColumnIndexBuilder* columnIndexBuilder = nullptr, + OffsetIndexBuilder* offsetIndexBuilder = nullptr, + const CodecOptions& codecOptions = CodecOptions{}) : sink_(std::move(sink)), metadata_(metadata), pool_(pool), - num_values_(0), - dictionary_page_offset_(0), - data_page_offset_(0), - total_uncompressed_size_(0), - total_compressed_size_(0), - page_ordinal_(0), - row_group_ordinal_(row_group_ordinal), - column_ordinal_(column_chunk_ordinal), - page_checksum_verification_(use_page_checksum_verification), - meta_encryptor_(std::move(meta_encryptor)), - data_encryptor_(std::move(data_encryptor)), - encryption_buffer_(AllocateBuffer(pool, 0)), - column_index_builder_(column_index_builder), - offset_index_builder_(offset_index_builder) { - if (data_encryptor_ != nullptr || meta_encryptor_ != nullptr) { - InitEncryption(); - } - compressor_ = GetCodec(codec, codec_options); - thrift_serializer_ = std::make_unique(); - } - - int64_t WriteDictionaryPage(const DictionaryPage& page) override { - int64_t uncompressed_size = page.size(); - std::shared_ptr compressed_data; - if (has_compressor()) { + numValues_(0), + dictionaryPageOffset_(0), + dataPageOffset_(0), + totalUncompressedSize_(0), + totalCompressedSize_(0), + pageOrdinal_(0), + rowGroupOrdinal_(rowGroupOrdinal), + columnOrdinal_(columnChunkOrdinal), + pageChecksumVerification_(usePageChecksumVerification), + metaEncryptor_(std::move(metaEncryptor)), + dataEncryptor_(std::move(dataEncryptor)), + encryptionBuffer_(allocateBuffer(pool, 0)), + columnIndexBuilder_(columnIndexBuilder), + offsetIndexBuilder_(offsetIndexBuilder) { + if (dataEncryptor_ != nullptr || metaEncryptor_ != nullptr) { + initEncryption(); + } + compressor_ = getCodec(codec, codecOptions); + thriftSerializer_ = std::make_unique(); + } + + int64_t writeDictionaryPage(const DictionaryPage& page) override { + int64_t uncompressedSize = page.size(); + std::shared_ptr compressedData; + if (hasCompressor()) { auto buffer = std::static_pointer_cast( - AllocateBuffer(pool_, uncompressed_size)); - Compress(*(page.buffer().get()), buffer.get()); - compressed_data = std::static_pointer_cast(buffer); + allocateBuffer(pool_, uncompressedSize)); + compress(*(page.buffer().get()), buffer.get()); + compressedData = std::static_pointer_cast(buffer); } else { - compressed_data = page.buffer(); + compressedData = page.buffer(); } - facebook::velox::parquet::thrift::DictionaryPageHeader dict_page_header; - dict_page_header.__set_num_values(page.num_values()); - dict_page_header.__set_encoding(ToThrift(page.encoding())); - dict_page_header.__set_is_sorted(page.is_sorted()); + facebook::velox::parquet::thrift::DictionaryPageHeader dictPageHeader; + dictPageHeader.__set_num_values(page.numValues()); + dictPageHeader.__set_encoding(toThrift(page.encoding())); + dictPageHeader.__set_is_sorted(page.isSorted()); - const uint8_t* output_data_buffer = compressed_data->data(); - int32_t output_data_len = static_cast(compressed_data->size()); + const uint8_t* outputDataBuffer = compressedData->data(); + int32_t outputDataLen = static_cast(compressedData->size()); - if (data_encryptor_.get()) { - UpdateEncryption(encryption::kDictionaryPage); - PARQUET_THROW_NOT_OK(encryption_buffer_->Resize( - data_encryptor_->CiphertextSizeDelta() + output_data_len, false)); - output_data_len = data_encryptor_->Encrypt( - compressed_data->data(), - output_data_len, - encryption_buffer_->mutable_data()); - output_data_buffer = encryption_buffer_->data(); + if (dataEncryptor_.get()) { + updateEncryption(encryption::kDictionaryPage); + PARQUET_THROW_NOT_OK(encryptionBuffer_->Resize( + dataEncryptor_->ciphertextSizeDelta() + outputDataLen, false)); + outputDataLen = dataEncryptor_->encrypt( + compressedData->data(), + outputDataLen, + encryptionBuffer_->mutable_data()); + outputDataBuffer = encryptionBuffer_->data(); } - facebook::velox::parquet::thrift::PageHeader page_header; - page_header.__set_type( + facebook::velox::parquet::thrift::PageHeader pageHeader; + pageHeader.__set_type( facebook::velox::parquet::thrift::PageType::DICTIONARY_PAGE); - page_header.__set_uncompressed_page_size( - static_cast(uncompressed_size)); - page_header.__set_compressed_page_size( - static_cast(output_data_len)); - page_header.__set_dictionary_page_header(dict_page_header); - if (page_checksum_verification_) { + pageHeader.__set_uncompressed_page_size( + static_cast(uncompressedSize)); + pageHeader.__set_compressed_page_size(static_cast(outputDataLen)); + pageHeader.__set_dictionary_page_header(dictPageHeader); + if (pageChecksumVerification_) { uint32_t crc32 = - internal::crc32(/* prev */ 0, output_data_buffer, output_data_len); - page_header.__set_crc(static_cast(crc32)); + internal::crc32(/* prev */ 0, outputDataBuffer, outputDataLen); + pageHeader.__set_crc(static_cast(crc32)); } - PARQUET_ASSIGN_OR_THROW(int64_t start_pos, sink_->Tell()); - if (dictionary_page_offset_ == 0) { - dictionary_page_offset_ = start_pos; + PARQUET_ASSIGN_OR_THROW(int64_t startPos, sink_->Tell()); + if (dictionaryPageOffset_ == 0) { + dictionaryPageOffset_ = startPos; } - if (meta_encryptor_) { - UpdateEncryption(encryption::kDictionaryPageHeader); + if (metaEncryptor_) { + updateEncryption(encryption::kDictionaryPageHeader); } - const int64_t header_size = thrift_serializer_->Serialize( - &page_header, sink_.get(), meta_encryptor_); + const int64_t headerSize = + thriftSerializer_->serialize(&pageHeader, sink_.get(), metaEncryptor_); - PARQUET_THROW_NOT_OK(sink_->Write(output_data_buffer, output_data_len)); + PARQUET_THROW_NOT_OK(sink_->Write(outputDataBuffer, outputDataLen)); - total_uncompressed_size_ += uncompressed_size + header_size; - total_compressed_size_ += output_data_len + header_size; - ++dict_encoding_stats_[page.encoding()]; - return uncompressed_size + header_size; + totalUncompressedSize_ += uncompressedSize + headerSize; + totalCompressedSize_ += outputDataLen + headerSize; + ++dictEncodingStats_[page.encoding()]; + return uncompressedSize + headerSize; } - void Close(bool has_dictionary, bool fallback) override { - if (meta_encryptor_ != nullptr) { - UpdateEncryption(encryption::kColumnMetaData); + void close(bool hasDictionary, bool fallback) override { + if (metaEncryptor_ != nullptr) { + updateEncryption(encryption::kColumnMetaData); } // Serialized page writer does not need to adjust page offsets. - FinishPageIndexes(/*final_position=*/0); + finishPageIndexes(0); - // index_page_offset = -1 since they are not supported - metadata_->Finish( - num_values_, - dictionary_page_offset_, + // Index_page_offset = -1 since they are not supported. + metadata_->finish( + numValues_, + dictionaryPageOffset_, -1, - data_page_offset_, - total_compressed_size_, - total_uncompressed_size_, - has_dictionary, + dataPageOffset_, + totalCompressedSize_, + totalUncompressedSize_, + hasDictionary, fallback, - dict_encoding_stats_, - data_encoding_stats_, - meta_encryptor_); - // Write metadata at end of column chunk - metadata_->WriteTo(sink_.get()); + dictEncodingStats_, + dataEncodingStats_, + metaEncryptor_); + // Write metadata at end of column chunk. + metadata_->writeTo(sink_.get()); } /** * Compress a buffer. */ - void Compress(const Buffer& src_buffer, ResizableBuffer* dest_buffer) - override { + void compress(const Buffer& srcBuffer, ResizableBuffer* destBuffer) override { VELOX_DCHECK_NOT_NULL(compressor_); - // Compress the data - int64_t max_compressed_size = - compressor_->MaxCompressedLen(src_buffer.size(), src_buffer.data()); + // Compress the data. + int64_t maxCompressedSize = + compressor_->maxCompressedLen(srcBuffer.size(), srcBuffer.data()); - // Use Arrow::Buffer::shrink_to_fit = false - // underlying buffer only keeps growing. Resize to a smaller size does not + // Use Arrow::Buffer::shrink_to_fit = false. + // Underlying buffer only keeps growing. Resize to a smaller size does not // reallocate. - PARQUET_THROW_NOT_OK(dest_buffer->Resize(max_compressed_size, false)); + PARQUET_THROW_NOT_OK(destBuffer->Resize(maxCompressedSize, false)); PARQUET_ASSIGN_OR_THROW( - int64_t compressed_size, - compressor_->Compress( - src_buffer.size(), - src_buffer.data(), - max_compressed_size, - dest_buffer->mutable_data())); - PARQUET_THROW_NOT_OK(dest_buffer->Resize(compressed_size, false)); - } - - int64_t WriteDataPage(const DataPage& page) override { - const int64_t uncompressed_size = page.uncompressed_size(); - std::shared_ptr compressed_data = page.buffer(); - const uint8_t* output_data_buffer = compressed_data->data(); - int32_t output_data_len = static_cast(compressed_data->size()); - - if (data_encryptor_.get()) { - PARQUET_THROW_NOT_OK(encryption_buffer_->Resize( - data_encryptor_->CiphertextSizeDelta() + output_data_len, false)); - UpdateEncryption(encryption::kDataPage); - output_data_len = data_encryptor_->Encrypt( - compressed_data->data(), - output_data_len, - encryption_buffer_->mutable_data()); - output_data_buffer = encryption_buffer_->data(); - } - - facebook::velox::parquet::thrift::PageHeader page_header; - page_header.__set_uncompressed_page_size( - static_cast(uncompressed_size)); - page_header.__set_compressed_page_size( - static_cast(output_data_len)); - - if (page_checksum_verification_) { + int64_t compressedSize, + compressor_->compress( + srcBuffer.size(), + srcBuffer.data(), + maxCompressedSize, + destBuffer->mutable_data())); + PARQUET_THROW_NOT_OK(destBuffer->Resize(compressedSize, false)); + } + + int64_t writeDataPage(const DataPage& page) override { + const int64_t uncompressedSize = page.uncompressedSize(); + std::shared_ptr compressedData = page.buffer(); + const uint8_t* outputDataBuffer = compressedData->data(); + int32_t outputDataLen = static_cast(compressedData->size()); + + if (dataEncryptor_.get()) { + PARQUET_THROW_NOT_OK(encryptionBuffer_->Resize( + dataEncryptor_->ciphertextSizeDelta() + outputDataLen, false)); + updateEncryption(encryption::kDataPage); + outputDataLen = dataEncryptor_->encrypt( + compressedData->data(), + outputDataLen, + encryptionBuffer_->mutable_data()); + outputDataBuffer = encryptionBuffer_->data(); + } + + facebook::velox::parquet::thrift::PageHeader pageHeader; + pageHeader.__set_uncompressed_page_size( + static_cast(uncompressedSize)); + pageHeader.__set_compressed_page_size(static_cast(outputDataLen)); + + if (pageChecksumVerification_) { uint32_t crc32 = - internal::crc32(/* prev */ 0, output_data_buffer, output_data_len); - page_header.__set_crc(static_cast(crc32)); + internal::crc32(/* prev */ 0, outputDataBuffer, outputDataLen); + pageHeader.__set_crc(static_cast(crc32)); } - if (page.type() == PageType::DATA_PAGE) { - const DataPageV1& v1_page = checked_cast(page); - SetDataPageHeader(page_header, v1_page); - } else if (page.type() == PageType::DATA_PAGE_V2) { - const DataPageV2& v2_page = checked_cast(page); - SetDataPageV2Header(page_header, v2_page); + if (page.type() == PageType::kDataPage) { + const DataPageV1& v1Page = checked_cast(page); + setDataPageHeader(pageHeader, v1Page); + } else if (page.type() == PageType::kDataPageV2) { + const DataPageV2& v2Page = checked_cast(page); + setDataPageV2Header(pageHeader, v2Page); } else { throw ParquetException("Unexpected page type"); } - PARQUET_ASSIGN_OR_THROW(int64_t start_pos, sink_->Tell()); - if (page_ordinal_ == 0) { - data_page_offset_ = start_pos; + PARQUET_ASSIGN_OR_THROW(int64_t startPos, sink_->Tell()); + if (pageOrdinal_ == 0) { + dataPageOffset_ = startPos; } - if (meta_encryptor_) { - UpdateEncryption(encryption::kDataPageHeader); + if (metaEncryptor_) { + updateEncryption(encryption::kDataPageHeader); } - const int64_t header_size = thrift_serializer_->Serialize( - &page_header, sink_.get(), meta_encryptor_); - PARQUET_THROW_NOT_OK(sink_->Write(output_data_buffer, output_data_len)); + const int64_t headerSize = + thriftSerializer_->serialize(&pageHeader, sink_.get(), metaEncryptor_); + PARQUET_THROW_NOT_OK(sink_->Write(outputDataBuffer, outputDataLen)); - /// Collect page index - if (column_index_builder_ != nullptr) { - column_index_builder_->AddPage(page.statistics()); + /// Collect page index. + if (columnIndexBuilder_ != nullptr) { + columnIndexBuilder_->addPage(page.statistics()); } - if (offset_index_builder_ != nullptr) { - const int64_t compressed_size = output_data_len + header_size; - if (compressed_size > std::numeric_limits::max()) { + if (offsetIndexBuilder_ != nullptr) { + const int64_t compressedSize = outputDataLen + headerSize; + if (compressedSize > std::numeric_limits::max()) { throw ParquetException("Compressed page size overflows to INT32_MAX."); } - if (!page.first_row_index().has_value()) { + if (!page.firstRowIndex().has_value()) { throw ParquetException("First row index is not set in data page."); } - /// start_pos is a relative offset in the buffered mode. It should be - /// adjusted via OffsetIndexBuilder::Finish() after BufferedPageWriter + /// startPos is a relative offset in the buffered mode. It should be + /// adjusted via OffsetIndexBuilder::finish() after BufferedPageWriter /// has flushed all data pages. - offset_index_builder_->AddPage( - start_pos, - static_cast(compressed_size), - *page.first_row_index()); + offsetIndexBuilder_->addPage( + startPos, + static_cast(compressedSize), + *page.firstRowIndex()); } - total_uncompressed_size_ += uncompressed_size + header_size; - total_compressed_size_ += output_data_len + header_size; - num_values_ += page.num_values(); - ++data_encoding_stats_[page.encoding()]; - ++page_ordinal_; - return uncompressed_size + header_size; + totalUncompressedSize_ += uncompressedSize + headerSize; + totalCompressedSize_ += outputDataLen + headerSize; + numValues_ += page.numValues(); + ++dataEncodingStats_[page.encoding()]; + ++pageOrdinal_; + return uncompressedSize + headerSize; } - void SetDataPageHeader( - facebook::velox::parquet::thrift::PageHeader& page_header, + void setDataPageHeader( + facebook::velox::parquet::thrift::PageHeader& pageHeader, const DataPageV1& page) { - facebook::velox::parquet::thrift::DataPageHeader data_page_header; - data_page_header.__set_num_values(page.num_values()); - data_page_header.__set_encoding(ToThrift(page.encoding())); - data_page_header.__set_definition_level_encoding( - ToThrift(page.definition_level_encoding())); - data_page_header.__set_repetition_level_encoding( - ToThrift(page.repetition_level_encoding())); + facebook::velox::parquet::thrift::DataPageHeader dataPageHeader; + dataPageHeader.__set_num_values(page.numValues()); + dataPageHeader.__set_encoding(toThrift(page.encoding())); + dataPageHeader.__set_definition_level_encoding( + toThrift(page.definitionLevelEncoding())); + dataPageHeader.__set_repetition_level_encoding( + toThrift(page.repetitionLevelEncoding())); // Write page statistics only when page index is not enabled. - if (column_index_builder_ == nullptr) { - data_page_header.__set_statistics(ToThrift(page.statistics())); + if (columnIndexBuilder_ == nullptr) { + dataPageHeader.__set_statistics(toThrift(page.statistics())); } - page_header.__set_type( + pageHeader.__set_type( facebook::velox::parquet::thrift::PageType::DATA_PAGE); - page_header.__set_data_page_header(data_page_header); + pageHeader.__set_data_page_header(dataPageHeader); } - void SetDataPageV2Header( - facebook::velox::parquet::thrift::PageHeader& page_header, + void setDataPageV2Header( + facebook::velox::parquet::thrift::PageHeader& pageHeader, const DataPageV2& page) { - facebook::velox::parquet::thrift::DataPageHeaderV2 data_page_header; - data_page_header.__set_num_values(page.num_values()); - data_page_header.__set_num_nulls(page.num_nulls()); - data_page_header.__set_num_rows(page.num_rows()); - data_page_header.__set_encoding(ToThrift(page.encoding())); + facebook::velox::parquet::thrift::DataPageHeaderV2 dataPageHeader; + dataPageHeader.__set_num_values(page.numValues()); + dataPageHeader.__set_num_nulls(page.numNulls()); + dataPageHeader.__set_num_rows(page.numRows()); + dataPageHeader.__set_encoding(toThrift(page.encoding())); - data_page_header.__set_definition_levels_byte_length( - page.definition_levels_byte_length()); - data_page_header.__set_repetition_levels_byte_length( - page.repetition_levels_byte_length()); + dataPageHeader.__set_definition_levels_byte_length( + page.definitionLevelsByteLength()); + dataPageHeader.__set_repetition_levels_byte_length( + page.repetitionLevelsByteLength()); - data_page_header.__set_is_compressed(page.is_compressed()); + dataPageHeader.__set_is_compressed(page.isCompressed()); // Write page statistics only when page index is not enabled. - if (column_index_builder_ == nullptr) { - data_page_header.__set_statistics(ToThrift(page.statistics())); + if (columnIndexBuilder_ == nullptr) { + dataPageHeader.__set_statistics(toThrift(page.statistics())); } - page_header.__set_type( + pageHeader.__set_type( facebook::velox::parquet::thrift::PageType::DATA_PAGE_V2); - page_header.__set_data_page_header_v2(data_page_header); + pageHeader.__set_data_page_header_v2(dataPageHeader); } /// \brief Finish page index builders and update the stream offset to adjust /// page offsets. - void FinishPageIndexes(int64_t final_position) { - if (column_index_builder_ != nullptr) { - column_index_builder_->Finish(); + void finishPageIndexes(int64_t finalPosition) { + if (columnIndexBuilder_ != nullptr) { + columnIndexBuilder_->finish(); } - if (offset_index_builder_ != nullptr) { - offset_index_builder_->Finish(final_position); + if (offsetIndexBuilder_ != nullptr) { + offsetIndexBuilder_->finish(finalPosition); } } - bool has_compressor() override { + bool hasCompressor() override { return (compressor_ != nullptr); } - int64_t num_values() { - return num_values_; + int64_t numValues() { + return numValues_; } - int64_t dictionary_page_offset() { - return dictionary_page_offset_; + int64_t dictionaryPageOffset() { + return dictionaryPageOffset_; } - int64_t data_page_offset() { - return data_page_offset_; + int64_t dataPageOffset() { + return dataPageOffset_; } - int64_t total_compressed_size() { - return total_compressed_size_; + int64_t totalCompressedSize() { + return totalCompressedSize_; } - int64_t total_uncompressed_size() { - return total_uncompressed_size_; + int64_t totalUncompressedSize() { + return totalUncompressedSize_; } - int64_t total_compressed_bytes_written() const override { - return total_compressed_size_; + int64_t totalCompressedBytesWritten() const override { + return totalCompressedSize_; } - bool page_checksum_verification() { - return page_checksum_verification_; + bool pageChecksumVerification() { + return pageChecksumVerification_; } private: - // To allow UpdateEncryption on Close + // To allow updateEncryption on close. friend class BufferedPageWriter; - void InitEncryption() { + void initEncryption() { // Prepare the AAD for quick update later. - if (data_encryptor_ != nullptr) { - data_page_aad_ = encryption::CreateModuleAad( - data_encryptor_->file_aad(), + if (dataEncryptor_ != nullptr) { + dataPageAad_ = encryption::createModuleAad( + dataEncryptor_->fileAad(), encryption::kDataPage, - row_group_ordinal_, - column_ordinal_, + rowGroupOrdinal_, + columnOrdinal_, kNonPageOrdinal); } - if (meta_encryptor_ != nullptr) { - data_page_header_aad_ = encryption::CreateModuleAad( - meta_encryptor_->file_aad(), + if (metaEncryptor_ != nullptr) { + dataPageHeaderAad_ = encryption::createModuleAad( + metaEncryptor_->fileAad(), encryption::kDataPageHeader, - row_group_ordinal_, - column_ordinal_, + rowGroupOrdinal_, + columnOrdinal_, kNonPageOrdinal); } } - void UpdateEncryption(int8_t module_type) { - switch (module_type) { + void updateEncryption(int8_t moduleType) { + switch (moduleType) { case encryption::kColumnMetaData: { - meta_encryptor_->UpdateAad(encryption::CreateModuleAad( - meta_encryptor_->file_aad(), - module_type, - row_group_ordinal_, - column_ordinal_, - kNonPageOrdinal)); + metaEncryptor_->updateAad( + encryption::createModuleAad( + metaEncryptor_->fileAad(), + moduleType, + rowGroupOrdinal_, + columnOrdinal_, + kNonPageOrdinal)); break; } case encryption::kDataPage: { - encryption::QuickUpdatePageAad(page_ordinal_, &data_page_aad_); - data_encryptor_->UpdateAad(data_page_aad_); + encryption::quickUpdatePageAad(pageOrdinal_, &dataPageAad_); + dataEncryptor_->updateAad(dataPageAad_); break; } case encryption::kDataPageHeader: { - encryption::QuickUpdatePageAad(page_ordinal_, &data_page_header_aad_); - meta_encryptor_->UpdateAad(data_page_header_aad_); + encryption::quickUpdatePageAad(pageOrdinal_, &dataPageHeaderAad_); + metaEncryptor_->updateAad(dataPageHeaderAad_); break; } case encryption::kDictionaryPageHeader: { - meta_encryptor_->UpdateAad(encryption::CreateModuleAad( - meta_encryptor_->file_aad(), - module_type, - row_group_ordinal_, - column_ordinal_, - kNonPageOrdinal)); + metaEncryptor_->updateAad( + encryption::createModuleAad( + metaEncryptor_->fileAad(), + moduleType, + rowGroupOrdinal_, + columnOrdinal_, + kNonPageOrdinal)); break; } case encryption::kDictionaryPage: { - data_encryptor_->UpdateAad(encryption::CreateModuleAad( - data_encryptor_->file_aad(), - module_type, - row_group_ordinal_, - column_ordinal_, - kNonPageOrdinal)); + dataEncryptor_->updateAad( + encryption::createModuleAad( + dataEncryptor_->fileAad(), + moduleType, + rowGroupOrdinal_, + columnOrdinal_, + kNonPageOrdinal)); break; } default: - throw ParquetException("Unknown module type in UpdateEncryption"); + throw ParquetException("Unknown module type in updateEncryption"); } } std::shared_ptr sink_; ColumnChunkMetaDataBuilder* metadata_; MemoryPool* pool_; - int64_t num_values_; - int64_t dictionary_page_offset_; - int64_t data_page_offset_; - // The uncompressed page size the page writer has already - // written. - int64_t total_uncompressed_size_; - // The compressed page size the page writer has already - // written. - // If the column is UNCOMPRESSED, the size would be - // equal to `total_uncompressed_size_`. - int64_t total_compressed_size_; - int32_t page_ordinal_; - int16_t row_group_ordinal_; - int16_t column_ordinal_; - bool page_checksum_verification_; - - std::unique_ptr thrift_serializer_; + int64_t numValues_; + int64_t dictionaryPageOffset_; + int64_t dataPageOffset_; + // The uncompressed page size the page writer has already written. + int64_t totalUncompressedSize_; + // The compressed page size the page writer has already written. + // If the column is UNCOMPRESSED, the size would be equal to + // totalUncompressedSize_. + int64_t totalCompressedSize_; + int32_t pageOrdinal_; + int16_t rowGroupOrdinal_; + int16_t columnOrdinal_; + bool pageChecksumVerification_; + + std::unique_ptr thriftSerializer_; // Compression codec to use. std::unique_ptr compressor_; - std::string data_page_aad_; - std::string data_page_header_aad_; + std::string dataPageAad_; + std::string dataPageHeaderAad_; - std::shared_ptr meta_encryptor_; - std::shared_ptr data_encryptor_; + std::shared_ptr metaEncryptor_; + std::shared_ptr dataEncryptor_; - std::shared_ptr encryption_buffer_; + std::shared_ptr encryptionBuffer_; - std::map dict_encoding_stats_; - std::map data_encoding_stats_; + std::map dictEncodingStats_; + std::map dataEncodingStats_; - ColumnIndexBuilder* column_index_builder_; - OffsetIndexBuilder* offset_index_builder_; + ColumnIndexBuilder* columnIndexBuilder_; + OffsetIndexBuilder* offsetIndexBuilder_; }; -// This implementation of the PageWriter writes to the final sink on Close . +// This implementation of the PageWriter writes to the final sink on close. class BufferedPageWriter : public PageWriter { public: BufferedPageWriter( std::shared_ptr sink, Compression::type codec, ColumnChunkMetaDataBuilder* metadata, - int16_t row_group_ordinal, - int16_t current_column_ordinal, - bool use_page_checksum_verification, + int16_t rowGroupOrdinal, + int16_t currentColumnOrdinal, + bool usePageChecksumVerification, MemoryPool* pool = ::arrow::default_memory_pool(), - std::shared_ptr meta_encryptor = nullptr, - std::shared_ptr data_encryptor = nullptr, - ColumnIndexBuilder* column_index_builder = nullptr, - OffsetIndexBuilder* offset_index_builder = nullptr, - const CodecOptions& codec_options = CodecOptions{}) - : final_sink_(std::move(sink)), + std::shared_ptr metaEncryptor = nullptr, + std::shared_ptr dataEncryptor = nullptr, + ColumnIndexBuilder* columnIndexBuilder = nullptr, + OffsetIndexBuilder* offsetIndexBuilder = nullptr, + const CodecOptions& codecOptions = CodecOptions{}) + : finalSink_(std::move(sink)), metadata_(metadata), - has_dictionary_pages_(false) { - in_memory_sink_ = CreateOutputStream(pool); + hasDictionaryPages_(false) { + inMemorySink_ = createOutputStream(pool); pager_ = std::make_unique( - in_memory_sink_, + inMemorySink_, codec, metadata, - row_group_ordinal, - current_column_ordinal, - use_page_checksum_verification, + rowGroupOrdinal, + currentColumnOrdinal, + usePageChecksumVerification, pool, - std::move(meta_encryptor), - std::move(data_encryptor), - column_index_builder, - offset_index_builder, - codec_options); + std::move(metaEncryptor), + std::move(dataEncryptor), + columnIndexBuilder, + offsetIndexBuilder, + codecOptions); } - int64_t WriteDictionaryPage(const DictionaryPage& page) override { - has_dictionary_pages_ = true; - return pager_->WriteDictionaryPage(page); + int64_t writeDictionaryPage(const DictionaryPage& page) override { + hasDictionaryPages_ = true; + return pager_->writeDictionaryPage(page); } - void Close(bool has_dictionary, bool fallback) override { - if (pager_->meta_encryptor_ != nullptr) { - pager_->UpdateEncryption(encryption::kColumnMetaData); + void close(bool hasDictionary, bool fallback) override { + if (pager_->metaEncryptor_ != nullptr) { + pager_->updateEncryption(encryption::kColumnMetaData); } - // index_page_offset = -1 since they are not supported - PARQUET_ASSIGN_OR_THROW(int64_t final_position, final_sink_->Tell()); - // dictionary page offset should be 0 iff there are no dictionary pages - auto dictionary_page_offset = has_dictionary_pages_ - ? pager_->dictionary_page_offset() + final_position + // Index_page_offset = -1 since they are not supported. + PARQUET_ASSIGN_OR_THROW(int64_t finalPosition, finalSink_->Tell()); + // Dictionary page offset should be 0 iff there are no dictionary pages. + auto dictionaryPageOffset = hasDictionaryPages_ + ? pager_->dictionaryPageOffset() + finalPosition : 0; - metadata_->Finish( - pager_->num_values(), - dictionary_page_offset, + metadata_->finish( + pager_->numValues(), + dictionaryPageOffset, -1, - pager_->data_page_offset() + final_position, - pager_->total_compressed_size(), - pager_->total_uncompressed_size(), - has_dictionary, + pager_->dataPageOffset() + finalPosition, + pager_->totalCompressedSize(), + pager_->totalUncompressedSize(), + hasDictionary, fallback, - pager_->dict_encoding_stats_, - pager_->data_encoding_stats_, - pager_->meta_encryptor_); + pager_->dictEncodingStats_, + pager_->dataEncodingStats_, + pager_->metaEncryptor_); - // Write metadata at end of column chunk - metadata_->WriteTo(in_memory_sink_.get()); + // Write metadata at end of column chunk. + metadata_->writeTo(inMemorySink_.get()); // Buffered page writer needs to adjust page offsets. - pager_->FinishPageIndexes(final_position); + pager_->finishPageIndexes(finalPosition); - // flush everything to the serialized sink - PARQUET_ASSIGN_OR_THROW(auto buffer, in_memory_sink_->Finish()); - PARQUET_THROW_NOT_OK(final_sink_->Write(buffer)); + // Flush everything to the serialized sink. + PARQUET_ASSIGN_OR_THROW(auto buffer, inMemorySink_->Finish()); + PARQUET_THROW_NOT_OK(finalSink_->Write(buffer)); } - int64_t WriteDataPage(const DataPage& page) override { - return pager_->WriteDataPage(page); + int64_t writeDataPage(const DataPage& page) override { + return pager_->writeDataPage(page); } - void Compress(const Buffer& src_buffer, ResizableBuffer* dest_buffer) - override { - pager_->Compress(src_buffer, dest_buffer); + void compress(const Buffer& srcBuffer, ResizableBuffer* destBuffer) override { + pager_->compress(srcBuffer, destBuffer); } - bool has_compressor() override { - return pager_->has_compressor(); + bool hasCompressor() override { + return pager_->hasCompressor(); } - int64_t total_compressed_bytes_written() const override { - return pager_->total_compressed_bytes_written(); + int64_t totalCompressedBytesWritten() const override { + return pager_->totalCompressedBytesWritten(); } private: - std::shared_ptr final_sink_; + std::shared_ptr finalSink_; ColumnChunkMetaDataBuilder* metadata_; - std::shared_ptr<::arrow::io::BufferOutputStream> in_memory_sink_; + std::shared_ptr<::arrow::io::BufferOutputStream> inMemorySink_; std::unique_ptr pager_; - bool has_dictionary_pages_; + bool hasDictionaryPages_; }; -std::unique_ptr PageWriter::Open( +std::unique_ptr PageWriter::open( std::shared_ptr sink, Compression::type codec, ColumnChunkMetaDataBuilder* metadata, - int16_t row_group_ordinal, - int16_t column_chunk_ordinal, + int16_t rowGroupOrdinal, + int16_t columnChunkOrdinal, MemoryPool* pool, - bool buffered_row_group, - std::shared_ptr meta_encryptor, - std::shared_ptr data_encryptor, - bool page_write_checksum_enabled, - ColumnIndexBuilder* column_index_builder, - OffsetIndexBuilder* offset_index_builder, - const CodecOptions& codec_options) { - if (buffered_row_group) { + bool bufferedRowGroup, + std::shared_ptr metaEncryptor, + std::shared_ptr dataEncryptor, + bool pageWriteChecksumEnabled, + ColumnIndexBuilder* columnIndexBuilder, + OffsetIndexBuilder* offsetIndexBuilder, + const CodecOptions& codecOptions) { + if (bufferedRowGroup) { return std::unique_ptr(new BufferedPageWriter( std::move(sink), codec, metadata, - row_group_ordinal, - column_chunk_ordinal, - page_write_checksum_enabled, + rowGroupOrdinal, + columnChunkOrdinal, + pageWriteChecksumEnabled, pool, - std::move(meta_encryptor), - std::move(data_encryptor), - column_index_builder, - offset_index_builder, - codec_options)); + std::move(metaEncryptor), + std::move(dataEncryptor), + columnIndexBuilder, + offsetIndexBuilder, + codecOptions)); } else { return std::unique_ptr(new SerializedPageWriter( std::move(sink), codec, metadata, - row_group_ordinal, - column_chunk_ordinal, - page_write_checksum_enabled, + rowGroupOrdinal, + columnChunkOrdinal, + pageWriteChecksumEnabled, pool, - std::move(meta_encryptor), - std::move(data_encryptor), - column_index_builder, - offset_index_builder, - codec_options)); + std::move(metaEncryptor), + std::move(dataEncryptor), + columnIndexBuilder, + offsetIndexBuilder, + codecOptions)); } } -std::unique_ptr PageWriter::Open( +std::unique_ptr PageWriter::open( std::shared_ptr sink, Compression::type codec, - int compression_level, + int compressionLevel, ColumnChunkMetaDataBuilder* metadata, - int16_t row_group_ordinal, - int16_t column_chunk_ordinal, + int16_t rowGroupOrdinal, + int16_t columnChunkOrdinal, MemoryPool* pool, - bool buffered_row_group, - std::shared_ptr meta_encryptor, - std::shared_ptr data_encryptor, - bool page_write_checksum_enabled, - ColumnIndexBuilder* column_index_builder, - OffsetIndexBuilder* offset_index_builder) { - return PageWriter::Open( + bool bufferedRowGroup, + std::shared_ptr metaEncryptor, + std::shared_ptr dataEncryptor, + bool pageWriteChecksumEnabled, + ColumnIndexBuilder* columnIndexBuilder, + OffsetIndexBuilder* offsetIndexBuilder) { + return PageWriter::open( sink, codec, metadata, - row_group_ordinal, - column_chunk_ordinal, + rowGroupOrdinal, + columnChunkOrdinal, pool, - buffered_row_group, - meta_encryptor, - data_encryptor, - page_write_checksum_enabled, - column_index_builder, - offset_index_builder, - CodecOptions{compression_level}); + bufferedRowGroup, + metaEncryptor, + dataEncryptor, + pageWriteChecksumEnabled, + columnIndexBuilder, + offsetIndexBuilder, + CodecOptions{compressionLevel}); } -// ---------------------------------------------------------------------- -// ColumnWriter +// ----------------------------------------------------------------------. +// ColumnWriter. -const std::shared_ptr& default_writer_properties() { - static std::shared_ptr default_writer_properties = +const std::shared_ptr& defaultWriterProperties() { + static std::shared_ptr defaultWriterProperties = WriterProperties::Builder().build(); - return default_writer_properties; + return defaultWriterProperties; } class ColumnWriterImpl { @@ -895,118 +892,118 @@ class ColumnWriterImpl { ColumnWriterImpl( ColumnChunkMetaDataBuilder* metadata, std::unique_ptr pager, - const bool use_dictionary, + const bool useDictionary, Encoding::type encoding, const WriterProperties* properties) : metadata_(metadata), descr_(metadata->descr()), - level_info_(ComputeLevelInfo(metadata->descr())), + levelInfo_(computeLevelInfo(metadata->descr())), pager_(std::move(pager)), - has_dictionary_(use_dictionary), + hasDictionary_(useDictionary), encoding_(encoding), properties_(properties), - allocator_(properties->memory_pool()), - num_buffered_values_(0), - num_buffered_encoded_values_(0), - num_buffered_nulls_(0), - num_buffered_rows_(0), - rows_written_(0), - total_bytes_written_(0), - total_compressed_bytes_(0), + allocator_(properties->memoryPool()), + numBufferedValues_(0), + numBufferedEncodedValues_(0), + numBufferedNulls_(0), + numBufferedRows_(0), + rowsWritten_(0), + totalBytesWritten_(0), + totalCompressedBytes_(0), closed_(false), fallback_(false), - definition_levels_sink_(allocator_), - repetition_levels_sink_(allocator_) { - definition_levels_rle_ = std::static_pointer_cast( - AllocateBuffer(allocator_, 0)); - repetition_levels_rle_ = std::static_pointer_cast( - AllocateBuffer(allocator_, 0)); - uncompressed_data_ = std::static_pointer_cast( - AllocateBuffer(allocator_, 0)); + definitionLevelsSink_(allocator_), + repetitionLevelsSink_(allocator_) { + definitionLevelsRle_ = std::static_pointer_cast( + allocateBuffer(allocator_, 0)); + repetitionLevelsRle_ = std::static_pointer_cast( + allocateBuffer(allocator_, 0)); + uncompressedData_ = std::static_pointer_cast( + allocateBuffer(allocator_, 0)); - if (pager_->has_compressor()) { - compressor_temp_buffer_ = std::static_pointer_cast( - AllocateBuffer(allocator_, 0)); + if (pager_->hasCompressor()) { + compressorTempBuffer_ = std::static_pointer_cast( + allocateBuffer(allocator_, 0)); } } virtual ~ColumnWriterImpl() = default; - int64_t Close(); + int64_t close(); protected: - virtual std::shared_ptr GetValuesBuffer() = 0; + virtual std::shared_ptr getValuesBuffer() = 0; - // Serializes Dictionary Page if enabled - virtual void WriteDictionaryPage() = 0; + // Serializes Dictionary Page if enabled. + virtual void writeDictionaryPage() = 0; - // Plain-encoded statistics of the current page - virtual EncodedStatistics GetPageStatistics() = 0; + // Plain-encoded statistics of the current page. + virtual EncodedStatistics getPageStatistics() = 0; - // Plain-encoded statistics of the whole chunk - virtual EncodedStatistics GetChunkStatistics() = 0; + // Plain-encoded statistics of the whole chunk. + virtual EncodedStatistics getChunkStatistics() = 0; - // Merges page statistics into chunk statistics, then resets the values - virtual void ResetPageStatistics() = 0; + // Merges page statistics into chunk statistics, then resets the values. + virtual void resetPageStatistics() = 0; - // Adds Data Pages to an in memory buffer in dictionary encoding mode - // Serializes the Data Pages in other encoding modes - void AddDataPage(); + // Adds Data Pages to an in memory buffer in dictionary encoding mode. + // Serializes the Data Pages in other encoding modes. + void addDataPage(); - void BuildDataPageV1( - int64_t definition_levels_rle_size, - int64_t repetition_levels_rle_size, - int64_t uncompressed_size, + void buildDataPageV1( + int64_t definitionLevelsRleSize, + int64_t repetitionLevelsRleSize, + int64_t uncompressedSize, const std::shared_ptr& values); - void BuildDataPageV2( - int64_t definition_levels_rle_size, - int64_t repetition_levels_rle_size, - int64_t uncompressed_size, + void buildDataPageV2( + int64_t definitionLevelsRleSize, + int64_t repetitionLevelsRleSize, + int64_t uncompressedSize, const std::shared_ptr& values); - // Serializes Data Pages - void WriteDataPage(const DataPage& page) { - total_bytes_written_ += pager_->WriteDataPage(page); + // Serializes Data Pages. + void writeDataPage(const DataPage& page) { + totalBytesWritten_ += pager_->writeDataPage(page); } - // Write multiple definition levels - void WriteDefinitionLevels(int64_t num_levels, const int16_t* levels) { + // Write multiple definition levels. + void writeDefinitionLevels(int64_t numLevels, const int16_t* levels) { VELOX_DCHECK(!closed_); PARQUET_THROW_NOT_OK( - definition_levels_sink_.Append(levels, sizeof(int16_t) * num_levels)); + definitionLevelsSink_.Append(levels, sizeof(int16_t) * numLevels)); } - // Write multiple repetition levels - void WriteRepetitionLevels(int64_t num_levels, const int16_t* levels) { + // Write multiple repetition levels. + void writeRepetitionLevels(int64_t numLevels, const int16_t* levels) { VELOX_DCHECK(!closed_); PARQUET_THROW_NOT_OK( - repetition_levels_sink_.Append(levels, sizeof(int16_t) * num_levels)); + repetitionLevelsSink_.Append(levels, sizeof(int16_t) * numLevels)); } - // RLE encode the src_buffer into dest_buffer and return the encoded size - int64_t RleEncodeLevels( - const void* src_buffer, - ResizableBuffer* dest_buffer, - int16_t max_level, - bool include_length_prefix = true); + // RLE encode the src_buffer into dest_buffer and return the encoded size. + int64_t rleEncodeLevels( + const void* srcBuffer, + ResizableBuffer* destBuffer, + int16_t maxLevel, + bool includeLengthPrefix = true); - // Serialize the buffered Data Pages - void FlushBufferedDataPages(); + // Serialize the buffered Data Pages. + void flushBufferedDataPages(); ColumnChunkMetaDataBuilder* metadata_; const ColumnDescriptor* descr_; - // scratch buffer if validity bits need to be recalculated. - std::shared_ptr bits_buffer_; - const LevelInfo level_info_; + // Scratch buffer if validity bits need to be recalculated. + std::shared_ptr bitsBuffer_; + const LevelInfo levelInfo_; std::unique_ptr pager_; - bool has_dictionary_; + bool hasDictionary_; Encoding::type encoding_; const WriterProperties* properties_; - LevelEncoder level_encoder_; + LevelEncoder levelEncoder_; MemoryPool* allocator_; @@ -1016,1160 +1013,1131 @@ class ColumnWriterImpl { // values. For repeated or optional values, there may be fewer data values // than levels, and this tells you how many encoded levels there are in that // case. - int64_t num_buffered_values_; + int64_t numBufferedValues_; // The total number of stored values in the data page. For repeated or - // optional values, this number may be lower than num_buffered_values_. - int64_t num_buffered_encoded_values_; + // optional values, this number may be lower than numBufferedValues_. + int64_t numBufferedEncodedValues_; // The total number of nulls stored in the data page. - int64_t num_buffered_nulls_; + int64_t numBufferedNulls_; // Total number of rows buffered in the data page. - int64_t num_buffered_rows_; + int64_t numBufferedRows_; - // Total number of rows written with this ColumnWriter - int64_t rows_written_; + // Total number of rows written with this ColumnWriter. + int64_t rowsWritten_; - // Records the total number of uncompressed bytes written by the serializer - int64_t total_bytes_written_; + // Records the total number of uncompressed bytes written by the serializer. + int64_t totalBytesWritten_; - // Records the current number of compressed bytes in a column - // These bytes are unwritten to `pager_` yet - int64_t total_compressed_bytes_; + // Records the current number of compressed bytes in a column. + // These bytes are unwritten to `pager_` yet. + int64_t totalCompressedBytes_; - // Flag to check if the Writer has been closed + // Flag to check if the Writer has been closed. bool closed_; - // Flag to infer if dictionary encoding has fallen back to PLAIN + // Flag to infer if dictionary encoding has fallen back to PLAIN. bool fallback_; - ::arrow::BufferBuilder definition_levels_sink_; - ::arrow::BufferBuilder repetition_levels_sink_; + ::arrow::BufferBuilder definitionLevelsSink_; + ::arrow::BufferBuilder repetitionLevelsSink_; - std::shared_ptr definition_levels_rle_; - std::shared_ptr repetition_levels_rle_; + std::shared_ptr definitionLevelsRle_; + std::shared_ptr repetitionLevelsRle_; - std::shared_ptr uncompressed_data_; - std::shared_ptr compressor_temp_buffer_; + std::shared_ptr uncompressedData_; + std::shared_ptr compressorTempBuffer_; - std::vector> data_pages_; + std::vector> dataPages_; private: - void InitSinks() { - definition_levels_sink_.Rewind(0); - repetition_levels_sink_.Rewind(0); + void initSinks() { + definitionLevelsSink_.Rewind(0); + repetitionLevelsSink_.Rewind(0); } - // Concatenate the encoded levels and values into one buffer - void ConcatenateBuffers( - int64_t definition_levels_rle_size, - int64_t repetition_levels_rle_size, + // Concatenate the encoded levels and values into one buffer. + void concatenateBuffers( + int64_t definitionLevelsRleSize, + int64_t repetitionLevelsRleSize, const std::shared_ptr& values, uint8_t* combined) { - memcpy( - combined, repetition_levels_rle_->data(), repetition_levels_rle_size); - combined += repetition_levels_rle_size; - memcpy( - combined, definition_levels_rle_->data(), definition_levels_rle_size); - combined += definition_levels_rle_size; + memcpy(combined, repetitionLevelsRle_->data(), repetitionLevelsRleSize); + combined += repetitionLevelsRleSize; + memcpy(combined, definitionLevelsRle_->data(), definitionLevelsRleSize); + combined += definitionLevelsRleSize; memcpy(combined, values->data(), values->size()); } }; -// return the size of the encoded buffer -int64_t ColumnWriterImpl::RleEncodeLevels( - const void* src_buffer, - ResizableBuffer* dest_buffer, - int16_t max_level, - bool include_length_prefix) { +// Return the size of the encoded buffer. +int64_t ColumnWriterImpl::rleEncodeLevels( + const void* srcBuffer, + ResizableBuffer* destBuffer, + int16_t maxLevel, + bool includeLengthPrefix) { // V1 DataPage includes the length of the RLE level as a prefix. - int32_t prefix_size = include_length_prefix ? sizeof(int32_t) : 0; - - // TODO: This only works with due to some RLE specifics - int64_t rle_size = - LevelEncoder::MaxBufferSize( - Encoding::RLE, max_level, static_cast(num_buffered_values_)) + - prefix_size; - - // Use Arrow::Buffer::shrink_to_fit = false - // underlying buffer only keeps growing. Resize to a smaller size does not - // reallocate. - PARQUET_THROW_NOT_OK(dest_buffer->Resize(rle_size, false)); - - level_encoder_.Init( - Encoding::RLE, - max_level, - static_cast(num_buffered_values_), - dest_buffer->mutable_data() + prefix_size, - static_cast(dest_buffer->size() - prefix_size)); - VELOX_DEBUG_ONLY int encoded = level_encoder_.Encode( - static_cast(num_buffered_values_), - reinterpret_cast(src_buffer)); - VELOX_DCHECK_EQ(encoded, num_buffered_values_); - - if (include_length_prefix) { - reinterpret_cast(dest_buffer->mutable_data())[0] = - level_encoder_.len(); - } - - return level_encoder_.len() + prefix_size; + int32_t prefixSize = includeLengthPrefix ? sizeof(int32_t) : 0; + + // TODO: This only works due to some RLE specifics. + int64_t rleSize = + LevelEncoder::maxBufferSize( + Encoding::kRle, maxLevel, static_cast(numBufferedValues_)) + + prefixSize; + + // Use Arrow::Buffer::shrink_to_fit = false. + // Underlying buffer only keeps growing. Resize to a smaller size does not. + // Reallocate. + PARQUET_THROW_NOT_OK(destBuffer->Resize(rleSize, false)); + + levelEncoder_.init( + Encoding::kRle, + maxLevel, + static_cast(numBufferedValues_), + destBuffer->mutable_data() + prefixSize, + static_cast(destBuffer->size() - prefixSize)); + VELOX_DEBUG_ONLY int encoded = levelEncoder_.encode( + static_cast(numBufferedValues_), + reinterpret_cast(srcBuffer)); + VELOX_DCHECK_EQ(encoded, numBufferedValues_); + + if (includeLengthPrefix) { + reinterpret_cast(destBuffer->mutable_data())[0] = + levelEncoder_.len(); + } + + return levelEncoder_.len() + prefixSize; } -void ColumnWriterImpl::AddDataPage() { - int64_t definition_levels_rle_size = 0; - int64_t repetition_levels_rle_size = 0; +void ColumnWriterImpl::addDataPage() { + int64_t definitionLevelsRleSize = 0; + int64_t repetitionLevelsRleSize = 0; - std::shared_ptr values = GetValuesBuffer(); - bool is_v1_data_page = - properties_->data_page_version() == ParquetDataPageVersion::V1; + std::shared_ptr values = getValuesBuffer(); + bool isV1DataPage = + properties_->dataPageVersion() == ParquetDataPageVersion::V1; - if (descr_->max_definition_level() > 0) { - definition_levels_rle_size = RleEncodeLevels( - definition_levels_sink_.data(), - definition_levels_rle_.get(), - descr_->max_definition_level(), - /*include_length_prefix=*/is_v1_data_page); + if (descr_->maxDefinitionLevel() > 0) { + definitionLevelsRleSize = rleEncodeLevels( + definitionLevelsSink_.data(), + definitionLevelsRle_.get(), + descr_->maxDefinitionLevel(), + isV1DataPage); } - if (descr_->max_repetition_level() > 0) { - repetition_levels_rle_size = RleEncodeLevels( - repetition_levels_sink_.data(), - repetition_levels_rle_.get(), - descr_->max_repetition_level(), - /*include_length_prefix=*/is_v1_data_page); + if (descr_->maxRepetitionLevel() > 0) { + repetitionLevelsRleSize = rleEncodeLevels( + repetitionLevelsSink_.data(), + repetitionLevelsRle_.get(), + descr_->maxRepetitionLevel(), + isV1DataPage); } - int64_t uncompressed_size = - definition_levels_rle_size + repetition_levels_rle_size + values->size(); + int64_t uncompressedSize = + definitionLevelsRleSize + repetitionLevelsRleSize + values->size(); - if (is_v1_data_page) { - BuildDataPageV1( - definition_levels_rle_size, - repetition_levels_rle_size, - uncompressed_size, + if (isV1DataPage) { + buildDataPageV1( + definitionLevelsRleSize, + repetitionLevelsRleSize, + uncompressedSize, values); } else { - BuildDataPageV2( - definition_levels_rle_size, - repetition_levels_rle_size, - uncompressed_size, + buildDataPageV2( + definitionLevelsRleSize, + repetitionLevelsRleSize, + uncompressedSize, values); } // Re-initialize the sinks for next Page. - InitSinks(); - num_buffered_values_ = 0; - num_buffered_encoded_values_ = 0; - num_buffered_rows_ = 0; - num_buffered_nulls_ = 0; + initSinks(); + numBufferedValues_ = 0; + numBufferedEncodedValues_ = 0; + numBufferedRows_ = 0; + numBufferedNulls_ = 0; } -void ColumnWriterImpl::BuildDataPageV1( - int64_t definition_levels_rle_size, - int64_t repetition_levels_rle_size, - int64_t uncompressed_size, +void ColumnWriterImpl::buildDataPageV1( + int64_t definitionLevelsRleSize, + int64_t repetitionLevelsRleSize, + int64_t uncompressedSize, const std::shared_ptr& values) { - // Use Arrow::Buffer::shrink_to_fit = false - // underlying buffer only keeps growing. Resize to a smaller size does not - // reallocate. - PARQUET_THROW_NOT_OK(uncompressed_data_->Resize(uncompressed_size, false)); - ConcatenateBuffers( - definition_levels_rle_size, - repetition_levels_rle_size, + // Use Arrow::Buffer::shrink_to_fit = false. + // Underlying buffer only keeps growing. Resize to a smaller size does not. + // Reallocate. + PARQUET_THROW_NOT_OK(uncompressedData_->Resize(uncompressedSize, false)); + concatenateBuffers( + definitionLevelsRleSize, + repetitionLevelsRleSize, values, - uncompressed_data_->mutable_data()); - - EncodedStatistics page_stats = GetPageStatistics(); - page_stats.ApplyStatSizeLimits( - properties_->max_statistics_size(descr_->path())); - page_stats.set_is_signed(SortOrder::SIGNED == descr_->sort_order()); - ResetPageStatistics(); - - std::shared_ptr compressed_data; - if (pager_->has_compressor()) { - pager_->Compress( - *(uncompressed_data_.get()), compressor_temp_buffer_.get()); - compressed_data = compressor_temp_buffer_; + uncompressedData_->mutable_data()); + + EncodedStatistics pageStats = getPageStatistics(); + pageStats.applyStatSizeLimits(properties_->maxStatisticsSize(descr_->path())); + pageStats.setIsSigned(SortOrder::kSigned == descr_->sortOrder()); + resetPageStatistics(); + + std::shared_ptr compressedData; + if (pager_->hasCompressor()) { + pager_->compress(*(uncompressedData_.get()), compressorTempBuffer_.get()); + compressedData = compressorTempBuffer_; } else { - compressed_data = uncompressed_data_; + compressedData = uncompressedData_; } - int32_t num_values = static_cast(num_buffered_values_); - int64_t first_row_index = rows_written_ - num_buffered_rows_; + int32_t numValues = static_cast(numBufferedValues_); + int64_t firstRowIndex = rowsWritten_ - numBufferedRows_; - // Write the page to OutputStream eagerly if there is no dictionary or - // if dictionary encoding has fallen back to PLAIN - if (has_dictionary_ && + // Write the page to OutputStream eagerly if there is no dictionary or. + // If dictionary encoding has fallen back to PLAIN. + if (hasDictionary_ && !fallback_) { // Save pages until end of dictionary encoding PARQUET_ASSIGN_OR_THROW( - auto compressed_data_copy, - compressed_data->CopySlice(0, compressed_data->size(), allocator_)); - std::unique_ptr page_ptr = std::make_unique( - compressed_data_copy, - num_values, + auto compressedDataCopy, + compressedData->CopySlice(0, compressedData->size(), allocator_)); + std::unique_ptr pagePtr = std::make_unique( + compressedDataCopy, + numValues, encoding_, - Encoding::RLE, - Encoding::RLE, - uncompressed_size, - page_stats, - first_row_index); - total_compressed_bytes_ += - page_ptr->size() + sizeof(facebook::velox::parquet::thrift::PageHeader); - - data_pages_.push_back(std::move(page_ptr)); + Encoding::kRle, + Encoding::kRle, + uncompressedSize, + pageStats, + firstRowIndex); + totalCompressedBytes_ += + pagePtr->size() + sizeof(facebook::velox::parquet::thrift::PageHeader); + + dataPages_.push_back(std::move(pagePtr)); } else { // Eagerly write pages DataPageV1 page( - compressed_data, - num_values, + compressedData, + numValues, encoding_, - Encoding::RLE, - Encoding::RLE, - uncompressed_size, - page_stats, - first_row_index); - WriteDataPage(page); + Encoding::kRle, + Encoding::kRle, + uncompressedSize, + pageStats, + firstRowIndex); + writeDataPage(page); } } -void ColumnWriterImpl::BuildDataPageV2( - int64_t definition_levels_rle_size, - int64_t repetition_levels_rle_size, - int64_t uncompressed_size, +void ColumnWriterImpl::buildDataPageV2( + int64_t definitionLevelsRleSize, + int64_t repetitionLevelsRleSize, + int64_t uncompressedSize, const std::shared_ptr& values) { // Compress the values if needed. Repetition and definition levels are // uncompressed in V2. - std::shared_ptr compressed_values; - if (pager_->has_compressor()) { - pager_->Compress(*values, compressor_temp_buffer_.get()); - compressed_values = compressor_temp_buffer_; + std::shared_ptr compressedValues; + if (pager_->hasCompressor()) { + pager_->compress(*values, compressorTempBuffer_.get()); + compressedValues = compressorTempBuffer_; } else { - compressed_values = values; + compressedValues = values; } - // Concatenate uncompressed levels and the possibly compressed values - int64_t combined_size = definition_levels_rle_size + - repetition_levels_rle_size + compressed_values->size(); + // Concatenate uncompressed levels and the possibly compressed values. + int64_t combinedSize = definitionLevelsRleSize + repetitionLevelsRleSize + + compressedValues->size(); std::shared_ptr combined = - AllocateBuffer(allocator_, combined_size); + allocateBuffer(allocator_, combinedSize); - ConcatenateBuffers( - definition_levels_rle_size, - repetition_levels_rle_size, - compressed_values, + concatenateBuffers( + definitionLevelsRleSize, + repetitionLevelsRleSize, + compressedValues, combined->mutable_data()); - EncodedStatistics page_stats = GetPageStatistics(); - page_stats.ApplyStatSizeLimits( - properties_->max_statistics_size(descr_->path())); - page_stats.set_is_signed(SortOrder::SIGNED == descr_->sort_order()); - ResetPageStatistics(); - - int32_t num_values = static_cast(num_buffered_values_); - int32_t null_count = static_cast(num_buffered_nulls_); - int32_t num_rows = static_cast(num_buffered_rows_); - int32_t def_levels_byte_length = - static_cast(definition_levels_rle_size); - int32_t rep_levels_byte_length = - static_cast(repetition_levels_rle_size); - int64_t first_row_index = rows_written_ - num_buffered_rows_; - - // page_stats.null_count is not set when page_statistics_ is nullptr. It is + EncodedStatistics pageStats = getPageStatistics(); + pageStats.applyStatSizeLimits(properties_->maxStatisticsSize(descr_->path())); + pageStats.setIsSigned(SortOrder::kSigned == descr_->sortOrder()); + resetPageStatistics(); + + int32_t numValues = static_cast(numBufferedValues_); + int32_t nullCount = static_cast(numBufferedNulls_); + int32_t numRows = static_cast(numBufferedRows_); + int32_t defLevelsByteLength = static_cast(definitionLevelsRleSize); + int32_t repLevelsByteLength = static_cast(repetitionLevelsRleSize); + int64_t firstRowIndex = rowsWritten_ - numBufferedRows_; + + // pageStats.null_count is not set when page_statistics_ is nullptr. It is // only used here for safety check. - VELOX_DCHECK( - !page_stats.has_null_count || page_stats.null_count == null_count); + VELOX_DCHECK(!pageStats.hasNullCount || pageStats.nullCount == nullCount); - // Write the page to OutputStream eagerly if there is no dictionary or - // if dictionary encoding has fallen back to PLAIN - if (has_dictionary_ && + // Write the page to OutputStream eagerly if there is no dictionary or. + // If dictionary encoding has fallen back to PLAIN. + if (hasDictionary_ && !fallback_) { // Save pages until end of dictionary encoding PARQUET_ASSIGN_OR_THROW( - auto data_copy, combined->CopySlice(0, combined->size(), allocator_)); - std::unique_ptr page_ptr = std::make_unique( + auto dataCopy, combined->CopySlice(0, combined->size(), allocator_)); + std::unique_ptr pagePtr = std::make_unique( combined, - num_values, - null_count, - num_rows, + numValues, + nullCount, + numRows, encoding_, - def_levels_byte_length, - rep_levels_byte_length, - uncompressed_size, - pager_->has_compressor(), - page_stats, - first_row_index); - total_compressed_bytes_ += - page_ptr->size() + sizeof(facebook::velox::parquet::thrift::PageHeader); - data_pages_.push_back(std::move(page_ptr)); + defLevelsByteLength, + repLevelsByteLength, + uncompressedSize, + pager_->hasCompressor(), + pageStats, + firstRowIndex); + totalCompressedBytes_ += + pagePtr->size() + sizeof(facebook::velox::parquet::thrift::PageHeader); + dataPages_.push_back(std::move(pagePtr)); } else { DataPageV2 page( combined, - num_values, - null_count, - num_rows, + numValues, + nullCount, + numRows, encoding_, - def_levels_byte_length, - rep_levels_byte_length, - uncompressed_size, - pager_->has_compressor(), - page_stats, - first_row_index); - WriteDataPage(page); + defLevelsByteLength, + repLevelsByteLength, + uncompressedSize, + pager_->hasCompressor(), + pageStats, + firstRowIndex); + writeDataPage(page); } } -int64_t ColumnWriterImpl::Close() { +int64_t ColumnWriterImpl::close() { if (!closed_) { closed_ = true; - if (has_dictionary_ && !fallback_) { - WriteDictionaryPage(); + if (hasDictionary_ && !fallback_) { + writeDictionaryPage(); } - FlushBufferedDataPages(); + flushBufferedDataPages(); - EncodedStatistics chunk_statistics = GetChunkStatistics(); - chunk_statistics.ApplyStatSizeLimits( - properties_->max_statistics_size(descr_->path())); - chunk_statistics.set_is_signed(SortOrder::SIGNED == descr_->sort_order()); + EncodedStatistics chunkStatistics = getChunkStatistics(); + chunkStatistics.applyStatSizeLimits( + properties_->maxStatisticsSize(descr_->path())); + chunkStatistics.setIsSigned(SortOrder::kSigned == descr_->sortOrder()); - // Write stats only if the column has at least one row written - if (rows_written_ > 0 && chunk_statistics.is_set()) { - metadata_->SetStatistics(chunk_statistics); + // Write stats only if the column has at least one row written. + if (rowsWritten_ > 0 && chunkStatistics.isSet()) { + metadata_->setStatistics(chunkStatistics); } - pager_->Close(has_dictionary_, fallback_); + pager_->close(hasDictionary_, fallback_); } - return total_bytes_written_; + return totalBytesWritten_; } -void ColumnWriterImpl::FlushBufferedDataPages() { - // Write all outstanding data to a new page - if (num_buffered_values_ > 0) { - AddDataPage(); +void ColumnWriterImpl::flushBufferedDataPages() { + // Write all outstanding data to a new page. + if (numBufferedValues_ > 0) { + addDataPage(); } - for (const auto& page_ptr : data_pages_) { - WriteDataPage(*page_ptr); + for (const auto& pagePtr : dataPages_) { + writeDataPage(*pagePtr); } - data_pages_.clear(); - total_compressed_bytes_ = 0; + dataPages_.clear(); + totalCompressedBytes_ = 0; } -// ---------------------------------------------------------------------- -// TypedColumnWriter +// ----------------------------------------------------------------------. +// TypedColumnWriter. template -inline void DoInBatches(int64_t total, int64_t batch_size, Action&& action) { - int64_t num_batches = static_cast(total / batch_size); - for (int round = 0; round < num_batches; round++) { - action(round * batch_size, batch_size, /*check_page_size=*/true); +inline void doInBatches(int64_t total, int64_t batchSize, Action&& action) { + int64_t numBatches = static_cast(total / batchSize); + for (int round = 0; round < numBatches; round++) { + action(round * batchSize, batchSize, true); } - // Write the remaining values - if (total % batch_size > 0) { - action( - num_batches * batch_size, total % batch_size, /*check_page_size=*/true); + // Write the remaining values. + if (total % batchSize > 0) { + action(numBatches * batchSize, total % batchSize, true); } } template -inline void DoInBatches( - const int16_t* def_levels, - const int16_t* rep_levels, - int64_t num_levels, - int64_t batch_size, +inline void doInBatches( + const int16_t* defLevels, + const int16_t* repLevels, + int64_t numLevels, + int64_t batchSize, Action&& action, - bool pages_change_on_record_boundaries) { - if (!pages_change_on_record_boundaries || !rep_levels) { - // If rep_levels is null, then we are writing a non-repeated column. + bool pagesChangeOnRecordBoundaries) { + if (!pagesChangeOnRecordBoundaries || !repLevels) { + // If repLevels is null, then we are writing a non-repeated column. // In this case, every record contains only one level. - return DoInBatches(num_levels, batch_size, std::forward(action)); + return doInBatches(numLevels, batchSize, std::forward(action)); } int64_t offset = 0; - while (offset < num_levels) { - int64_t end_offset = std::min(offset + batch_size, num_levels); + while (offset < numLevels) { + int64_t endOffset = std::min(offset + batchSize, numLevels); - // Find next record boundary (i.e. ref_level = 0) - while (end_offset < num_levels && rep_levels[end_offset] != 0) { - end_offset++; + // Find next record boundary (i.e. repLevel = 0). + while (endOffset < numLevels && repLevels[endOffset] != 0) { + endOffset++; } - if (end_offset < num_levels) { - // This is not the last chunk of batch and end_offset is a record + if (endOffset < numLevels) { + // This is not the last chunk of batch and endOffset is a record // boundary. It is a good chance to check the page size. - action(offset, end_offset - offset, /*check_page_size=*/true); + action(offset, endOffset - offset, true); } else { - VELOX_DCHECK_EQ(end_offset, num_levels); - // This is the last chunk of batch, and we do not know whether end_offset + VELOX_DCHECK_EQ(endOffset, numLevels); + // This is the last chunk of batch, and we do not know whether endOffset // is a record boundary. Find the offset to beginning of last record in // this chunk, so we can check page size. - int64_t last_record_begin_offset = num_levels - 1; - while (last_record_begin_offset >= offset && - rep_levels[last_record_begin_offset] != 0) { - last_record_begin_offset--; + int64_t lastRecordBeginOffset = numLevels - 1; + while (lastRecordBeginOffset >= offset && + repLevels[lastRecordBeginOffset] != 0) { + lastRecordBeginOffset--; } - if (offset < last_record_begin_offset) { + if (offset < lastRecordBeginOffset) { // We have found the beginning of last record and can check page size. - action( - offset, - last_record_begin_offset - offset, - /*check_page_size=*/true); - offset = last_record_begin_offset; + action(offset, lastRecordBeginOffset - offset, true); + offset = lastRecordBeginOffset; } // There is no record boundary in this chunk and cannot check page size. - action(offset, end_offset - offset, /*check_page_size=*/false); + action(offset, endOffset - offset, false); } - offset = end_offset; + offset = endOffset; } } -bool DictionaryDirectWriteSupported(const ::arrow::Array& array) { - VELOX_DCHECK_EQ(array.type_id(), ::arrow::Type::DICTIONARY); - const ::arrow::DictionaryType& dict_type = +bool dictionaryDirectWriteSupported(const ::arrow::Array& array) { + VELOX_DCHECK_EQ( + static_cast(array.type_id()), + static_cast(::arrow::Type::DICTIONARY)); + const ::arrow::DictionaryType& dictType = static_cast(*array.type()); - return ::arrow::is_base_binary_like(dict_type.value_type()->id()); + return ::arrow::is_base_binary_like(dictType.value_type()->id()); } -Status ConvertDictionaryToDense( +Status convertDictionaryToDense( const ::arrow::Array& array, MemoryPool* pool, std::shared_ptr<::arrow::Array>* out) { - const ::arrow::DictionaryType& dict_type = + const ::arrow::DictionaryType& dictType = static_cast(*array.type()); ::arrow::compute::ExecContext ctx(pool); ARROW_ASSIGN_OR_RAISE( - Datum cast_output, + Datum castOutput, ::arrow::compute::Cast( array.data(), - dict_type.value_type(), - ::arrow::compute::CastOptions(), + dictType.value_type(), + ::arrow::compute::CastOptions::Safe(), &ctx)); - *out = cast_output.make_array(); + *out = castOutput.make_array(); return Status::OK(); } -static inline bool IsDictionaryEncoding(Encoding::type encoding) { - return encoding == Encoding::PLAIN_DICTIONARY; +static inline bool isDictionaryEncoding(Encoding::type encoding) { + return encoding == Encoding::kPlainDictionary; } template class TypedColumnWriterImpl : public ColumnWriterImpl, public TypedColumnWriter { public: - using T = typename DType::c_type; + using T = typename DType::CType; TypedColumnWriterImpl( ColumnChunkMetaDataBuilder* metadata, std::unique_ptr pager, - const bool use_dictionary, + const bool useDictionary, Encoding::type encoding, const WriterProperties* properties) : ColumnWriterImpl( metadata, std::move(pager), - use_dictionary, + useDictionary, encoding, properties) { - current_encoder_ = MakeEncoder( - DType::type_num, + currentEncoder_ = makeEncoder( + DType::typeNum, encoding, - use_dictionary, + useDictionary, descr_, - properties->memory_pool()); + properties->memoryPool()); // We have to dynamic_cast as some compilers don't want to static_cast // through virtual inheritance. - current_value_encoder_ = - dynamic_cast*>(current_encoder_.get()); + currentValueEncoder_ = + dynamic_cast*>(currentEncoder_.get()); - // Will be null if not using dictionary, but that's ok - current_dict_encoder_ = - dynamic_cast*>(current_encoder_.get()); + // Will be null if not using dictionary, but that's ok. + currentDictEncoder_ = + dynamic_cast*>(currentEncoder_.get()); - if (properties->statistics_enabled(descr_->path()) && - (SortOrder::UNKNOWN != descr_->sort_order())) { - page_statistics_ = MakeStatistics(descr_, allocator_); - chunk_statistics_ = MakeStatistics(descr_, allocator_); + if (properties->statisticsEnabled(descr_->path()) && + (SortOrder::kUnknown != descr_->sortOrder())) { + pageStatistics_ = makeStatistics(descr_, allocator_); + chunkStatistics_ = makeStatistics(descr_, allocator_); } - pages_change_on_record_boundaries_ = - properties->data_page_version() == ParquetDataPageVersion::V2 || - properties->page_index_enabled(descr_->path()); + pagesChangeOnRecordBoundaries_ = + properties->dataPageVersion() == ParquetDataPageVersion::V2 || + properties->pageIndexEnabled(descr_->path()); } - int64_t Close() override { - return ColumnWriterImpl::Close(); + int64_t close() override { + return ColumnWriterImpl::close(); } - int64_t WriteBatch( - int64_t num_values, - const int16_t* def_levels, - const int16_t* rep_levels, + int64_t writeBatch( + int64_t numValues, + const int16_t* defLevels, + const int16_t* repLevels, const T* values) override { - // We check for DataPage limits only after we have inserted the values. If a - // user writes a large number of values, the DataPage size can be much above - // the limit. The purpose of this chunking is to bound this. Even if a user - // writes large number of values, the chunking will ensure the AddDataPage() - // is called at a reasonable pagesize limit - int64_t value_offset = 0; - - auto WriteChunk = [&](int64_t offset, int64_t batch_size, bool check_page) { - int64_t values_to_write = WriteLevels( - batch_size, - AddIfNotNull(def_levels, offset), - AddIfNotNull(rep_levels, offset)); - - // PARQUET-780 - if (values_to_write > 0) { + // We check for DataPage limits only after we have inserted the values. If + // a user writes a large number of values, the DataPage size can be much + // above the limit. The purpose of this chunking is to bound this. Even if + // a user writes large number of values, the chunking will ensure the + // addDataPage() is called at a reasonable pagesize limit. + int64_t valueOffset = 0; + + auto writeChunk = [&](int64_t offset, int64_t batchSize, bool checkPage) { + int64_t valuesToWrite = writeLevels( + batchSize, + addIfNotNull(defLevels, offset), + addIfNotNull(repLevels, offset)); + + // PARQUET-780. + if (valuesToWrite > 0) { VELOX_DCHECK_NOT_NULL(values); } - const int64_t num_nulls = batch_size - values_to_write; - WriteValues( - AddIfNotNull(values, value_offset), values_to_write, num_nulls); - CommitWriteAndCheckPageLimit( - batch_size, values_to_write, num_nulls, check_page); - value_offset += values_to_write; + const int64_t numNulls = batchSize - valuesToWrite; + writeValues(addIfNotNull(values, valueOffset), valuesToWrite, numNulls); + commitWriteAndCheckPageLimit( + batchSize, valuesToWrite, numNulls, checkPage); + valueOffset += valuesToWrite; // Dictionary size checked separately from data page size since we - // circumvent this check when writing ::arrow::DictionaryArray directly - CheckDictionarySizeLimit(); + // circumvent this check when writing ::arrow::DictionaryArray directly. + checkDictionarySizeLimit(); }; - DoInBatches( - def_levels, - rep_levels, - num_values, - properties_->write_batch_size(), - WriteChunk, - pages_change_on_record_boundaries()); - return value_offset; - } - - void WriteBatchSpaced( - int64_t num_values, - const int16_t* def_levels, - const int16_t* rep_levels, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + doInBatches( + defLevels, + repLevels, + numValues, + properties_->writeBatchSize(), + writeChunk, + pagesChangeOnRecordBoundaries()); + return valueOffset; + } + + void writeBatchSpaced( + int64_t numValues, + const int16_t* defLevels, + const int16_t* repLevels, + const uint8_t* validBits, + int64_t validBitsOffset, const T* values) override { - // Like WriteBatch, but for spaced values - int64_t value_offset = 0; - auto WriteChunk = [&](int64_t offset, int64_t batch_size, bool check_page) { - int64_t batch_num_values = 0; - int64_t batch_num_spaced_values = 0; - int64_t null_count; - MaybeCalculateValidityBits( - AddIfNotNull(def_levels, offset), - batch_size, - &batch_num_values, - &batch_num_spaced_values, - &null_count); - - WriteLevelsSpaced( - batch_size, - AddIfNotNull(def_levels, offset), - AddIfNotNull(rep_levels, offset)); - if (bits_buffer_ != nullptr) { - WriteValuesSpaced( - AddIfNotNull(values, value_offset), - batch_num_values, - batch_num_spaced_values, - bits_buffer_->data(), - /*valid_bits_offset=*/0, - /*num_levels=*/batch_size, - null_count); + // Like WriteBatch, but for spaced values. + int64_t valueOffset = 0; + auto writeChunk = [&](int64_t offset, int64_t batchSize, bool checkPage) { + int64_t batchNumValues = 0; + int64_t batchNumSpacedValues = 0; + int64_t nullCount; + maybeCalculateValidityBits( + addIfNotNull(defLevels, offset), + batchSize, + &batchNumValues, + &batchNumSpacedValues, + &nullCount); + + writeLevelsSpaced( + batchSize, + addIfNotNull(defLevels, offset), + addIfNotNull(repLevels, offset)); + if (bitsBuffer_ != nullptr) { + writeValuesSpaced( + addIfNotNull(values, valueOffset), + batchNumValues, + batchNumSpacedValues, + bitsBuffer_->data(), + 0, + batchSize, + nullCount); } else { - WriteValuesSpaced( - AddIfNotNull(values, value_offset), - batch_num_values, - batch_num_spaced_values, - valid_bits, - valid_bits_offset + value_offset, - /*num_levels=*/batch_size, - null_count); + writeValuesSpaced( + addIfNotNull(values, valueOffset), + batchNumValues, + batchNumSpacedValues, + validBits, + validBitsOffset + valueOffset, + batchSize, + nullCount); } - CommitWriteAndCheckPageLimit( - batch_size, batch_num_spaced_values, null_count, check_page); - value_offset += batch_num_spaced_values; + commitWriteAndCheckPageLimit( + batchSize, batchNumSpacedValues, nullCount, checkPage); + valueOffset += batchNumSpacedValues; // Dictionary size checked separately from data page size since we - // circumvent this check when writing ::arrow::DictionaryArray directly - CheckDictionarySizeLimit(); + // circumvent this check when writing ::arrow::DictionaryArray directly. + checkDictionarySizeLimit(); }; - DoInBatches( - def_levels, - rep_levels, - num_values, - properties_->write_batch_size(), - WriteChunk, - pages_change_on_record_boundaries()); - } - - Status WriteArrow( - const int16_t* def_levels, - const int16_t* rep_levels, - int64_t num_levels, - const ::arrow::Array& leaf_array, + doInBatches( + defLevels, + repLevels, + numValues, + properties_->writeBatchSize(), + writeChunk, + pagesChangeOnRecordBoundaries()); + } + + Status writeArrow( + const int16_t* defLevels, + const int16_t* repLevels, + int64_t numLevels, + const ::arrow::Array& leafArray, ArrowWriteContext* ctx, - bool leaf_field_nullable) override { + bool leafFieldNullable) override { BEGIN_PARQUET_CATCH_EXCEPTIONS - // Leaf nulls are canonical when there is only a single null element after a - // list and it is at the leaf. - bool single_nullable_element = - (level_info_.defLevel == level_info_.repeatedAncestorDefLevel + 1) && - leaf_field_nullable; - bool maybe_parent_nulls = - level_info_.HasNullableValues() && !single_nullable_element; - if (maybe_parent_nulls) { + // Leaf nulls are canonical when there is only a single null element after + // a list and it is at the leaf. + bool singleNullableElement = + (levelInfo_.defLevel == levelInfo_.repeatedAncestorDefLevel + 1) && + leafFieldNullable; + bool maybeParentNulls = + levelInfo_.HasNullableValues() && !singleNullableElement; + if (maybeParentNulls) { ARROW_ASSIGN_OR_RAISE( - bits_buffer_, + bitsBuffer_, ::arrow::AllocateResizableBuffer( - ::arrow::bit_util::BytesForBits(properties_->write_batch_size()), - ctx->memory_pool)); - bits_buffer_->ZeroPadding(); + ::arrow::bit_util::BytesForBits(properties_->writeBatchSize()), + ctx->memoryPool)); + bitsBuffer_->ZeroPadding(); } - if (leaf_array.type()->id() == ::arrow::Type::DICTIONARY) { - return WriteArrowDictionary( - def_levels, - rep_levels, - num_levels, - leaf_array, - ctx, - maybe_parent_nulls); + if (leafArray.type()->id() == ::arrow::Type::DICTIONARY) { + return writeArrowDictionary( + defLevels, repLevels, numLevels, leafArray, ctx, maybeParentNulls); } else { - return WriteArrowDense( - def_levels, - rep_levels, - num_levels, - leaf_array, - ctx, - maybe_parent_nulls); + return writeArrowDense( + defLevels, repLevels, numLevels, leafArray, ctx, maybeParentNulls); } END_PARQUET_CATCH_EXCEPTIONS } - int64_t EstimatedBufferedValueBytes() const override { - return current_encoder_->EstimatedDataEncodedSize(); + int64_t estimatedBufferedValueBytes() const override { + return currentEncoder_->estimatedDataEncodedSize(); } protected: - std::shared_ptr GetValuesBuffer() override { - return current_encoder_->FlushValues(); + std::shared_ptr getValuesBuffer() override { + return currentEncoder_->flushValues(); } // Internal function to handle direct writing of ::arrow::DictionaryArray, // since the standard logic concerning dictionary size limits and fallback to - // plain encoding is circumvented - Status WriteArrowDictionary( - const int16_t* def_levels, - const int16_t* rep_levels, - int64_t num_levels, + // plain encoding is circumvented. + Status writeArrowDictionary( + const int16_t* defLevels, + const int16_t* repLevels, + int64_t numLevels, const ::arrow::Array& array, ArrowWriteContext* context, - bool maybe_parent_nulls); + bool maybeParentNulls); - Status WriteArrowDense( - const int16_t* def_levels, - const int16_t* rep_levels, - int64_t num_levels, + Status writeArrowDense( + const int16_t* defLevels, + const int16_t* repLevels, + int64_t numLevels, const ::arrow::Array& array, ArrowWriteContext* context, - bool maybe_parent_nulls); + bool maybeParentNulls); - void WriteDictionaryPage() override { - VELOX_DCHECK(current_dict_encoder_); - std::shared_ptr buffer = AllocateBuffer( - properties_->memory_pool(), current_dict_encoder_->dict_encoded_size()); - current_dict_encoder_->WriteDict(buffer->mutable_data()); + void writeDictionaryPage() override { + VELOX_DCHECK(currentDictEncoder_); + std::shared_ptr buffer = allocateBuffer( + properties_->memoryPool(), currentDictEncoder_->dictEncodedSize()); + currentDictEncoder_->writeDict(buffer->mutable_data()); DictionaryPage page( buffer, - current_dict_encoder_->num_entries(), - properties_->dictionary_page_encoding()); - total_bytes_written_ += pager_->WriteDictionaryPage(page); + currentDictEncoder_->numEntries(), + properties_->dictionaryPageEncoding()); + totalBytesWritten_ += pager_->writeDictionaryPage(page); } - EncodedStatistics GetPageStatistics() override { + EncodedStatistics getPageStatistics() override { EncodedStatistics result; - if (page_statistics_) - result = page_statistics_->Encode(); + if (pageStatistics_) + result = pageStatistics_->encode(); return result; } - EncodedStatistics GetChunkStatistics() override { + EncodedStatistics getChunkStatistics() override { EncodedStatistics result; - if (chunk_statistics_) - result = chunk_statistics_->Encode(); + if (chunkStatistics_) + result = chunkStatistics_->encode(); return result; } - void ResetPageStatistics() override { - if (chunk_statistics_ != nullptr) { - chunk_statistics_->Merge(*page_statistics_); - page_statistics_->Reset(); + void resetPageStatistics() override { + if (chunkStatistics_ != nullptr) { + chunkStatistics_->merge(*pageStatistics_); + pageStatistics_->reset(); } } Type::type type() const override { - return descr_->physical_type(); + return descr_->physicalType(); } const ColumnDescriptor* descr() const override { return descr_; } - int64_t rows_written() const override { - return rows_written_; + int64_t rowsWritten() const override { + return rowsWritten_; } - int64_t total_compressed_bytes() const override { - return total_compressed_bytes_; + int64_t totalCompressedBytes() const override { + return totalCompressedBytes_; } - int64_t total_bytes_written() const override { - return total_bytes_written_; + int64_t totalBytesWritten() const override { + return totalBytesWritten_; } - int64_t total_compressed_bytes_written() const override { - return pager_->total_compressed_bytes_written(); + int64_t totalCompressedBytesWritten() const override { + return pager_->totalCompressedBytesWritten(); } const WriterProperties* properties() override { return properties_; } - bool pages_change_on_record_boundaries() const { - return pages_change_on_record_boundaries_; + bool pagesChangeOnRecordBoundaries() const { + return pagesChangeOnRecordBoundaries_; } private: using ValueEncoderType = typename EncodingTraits::Encoder; using TypedStats = TypedStatistics; - std::unique_ptr current_encoder_; - // Downcasted observers of current_encoder_. + std::unique_ptr currentEncoder_; + // Downcasted observers of currentEncoder_. // The downcast is performed once as opposed to at every use since // dynamic_cast is so expensive, and static_cast is not available due // to virtual inheritance. - ValueEncoderType* current_value_encoder_; - DictEncoder* current_dict_encoder_; - std::shared_ptr page_statistics_; - std::shared_ptr chunk_statistics_; - bool pages_change_on_record_boundaries_; + ValueEncoderType* currentValueEncoder_; + DictEncoder* currentDictEncoder_; + std::shared_ptr pageStatistics_; + std::shared_ptr chunkStatistics_; + bool pagesChangeOnRecordBoundaries_; // If writing a sequence of ::arrow::DictionaryArray to the writer, we keep - // the dictionary passed to DictEncoder::PutDictionary so we can check + // the dictionary passed to DictEncoder::putDictionary so we can check // subsequent array chunks to see either if materialization is required (in - // which case we call back to the dense write path) - std::shared_ptr<::arrow::Array> preserved_dictionary_; - - int64_t WriteLevels( - int64_t num_values, - const int16_t* def_levels, - const int16_t* rep_levels) { - int64_t values_to_write = 0; - // If the field is required and non-repeated, there are no definition levels - if (descr_->max_definition_level() > 0) { - for (int64_t i = 0; i < num_values; ++i) { - if (def_levels[i] == descr_->max_definition_level()) { - ++values_to_write; + // which case we call back to the dense write path). + std::shared_ptr<::arrow::Array> preservedDictionary_; + + int64_t writeLevels( + int64_t numValues, + const int16_t* defLevels, + const int16_t* repLevels) { + int64_t valuesToWrite = 0; + // If the field is required and non-repeated, there are no definition + // levels. + if (descr_->maxDefinitionLevel() > 0) { + for (int64_t i = 0; i < numValues; ++i) { + if (defLevels[i] == descr_->maxDefinitionLevel()) { + ++valuesToWrite; } } - WriteDefinitionLevels(num_values, def_levels); + writeDefinitionLevels(numValues, defLevels); } else { - // Required field, write all values - values_to_write = num_values; - } - - // Not present for non-repeated fields - if (descr_->max_repetition_level() > 0) { - // A row could include more than one value - // Count the occasions where we start a new row - for (int64_t i = 0; i < num_values; ++i) { - if (rep_levels[i] == 0) { - rows_written_++; - num_buffered_rows_++; + // Required field, write all values. + valuesToWrite = numValues; + } + + // Not present for non-repeated fields. + if (descr_->maxRepetitionLevel() > 0) { + // A row could include more than one value. + // Count the occasions where we start a new row. + for (int64_t i = 0; i < numValues; ++i) { + if (repLevels[i] == 0) { + rowsWritten_++; + numBufferedRows_++; } } - WriteRepetitionLevels(num_values, rep_levels); + writeRepetitionLevels(numValues, repLevels); } else { - // Each value is exactly one row - rows_written_ += num_values; - num_buffered_rows_ += num_values; + // Each value is exactly one row. + rowsWritten_ += numValues; + numBufferedRows_ += numValues; } - return values_to_write; + return valuesToWrite; } // This method will always update the three output parameters, - // out_values_to_write, out_spaced_values_to_write and null_count. + // outValuesToWrite, outSpacedValuesToWrite and nullCount. // Additionally it will update the validity bitmap if required (i.e. if at // least one level of nullable structs directly precede the leaf node). - void MaybeCalculateValidityBits( - const int16_t* def_levels, - int64_t batch_size, - int64_t* out_values_to_write, - int64_t* out_spaced_values_to_write, - int64_t* null_count) { - if (bits_buffer_ == nullptr) { - if (level_info_.defLevel == 0) { + void maybeCalculateValidityBits( + const int16_t* defLevels, + int64_t batchSize, + int64_t* outValuesToWrite, + int64_t* outSpacedValuesToWrite, + int64_t* nullCount) { + if (bitsBuffer_ == nullptr) { + if (levelInfo_.defLevel == 0) { // In this case def levels should be null and we only // need to output counts which will always be equal to - // the batch size passed in (max def_level == 0 indicates + // the batch size passed in (max defLevel == 0 indicates // there cannot be repeated or null fields). - VELOX_DCHECK_NULL(def_levels); - *out_values_to_write = batch_size; - *out_spaced_values_to_write = batch_size; - *null_count = 0; + VELOX_DCHECK_NULL(defLevels); + *outValuesToWrite = batchSize; + *outSpacedValuesToWrite = batchSize; + *nullCount = 0; } else { - for (int x = 0; x < batch_size; x++) { - *out_values_to_write += def_levels[x] == level_info_.defLevel ? 1 : 0; - *out_spaced_values_to_write += - def_levels[x] >= level_info_.repeatedAncestorDefLevel ? 1 : 0; + for (int x = 0; x < batchSize; x++) { + *outValuesToWrite += defLevels[x] == levelInfo_.defLevel ? 1 : 0; + *outSpacedValuesToWrite += + defLevels[x] >= levelInfo_.repeatedAncestorDefLevel ? 1 : 0; } - *null_count = batch_size - *out_values_to_write; + *nullCount = batchSize - *outValuesToWrite; } return; } // Shrink to fit possible causes another allocation, and would only be // necessary on the last batch. - int64_t new_bitmap_size = ::arrow::bit_util::BytesForBits(batch_size); - if (new_bitmap_size != bits_buffer_->size()) { - PARQUET_THROW_NOT_OK( - bits_buffer_->Resize(new_bitmap_size, /*shrink_to_fit=*/false)); - bits_buffer_->ZeroPadding(); + int64_t newBitmapSize = ::arrow::bit_util::BytesForBits(batchSize); + if (newBitmapSize != bitsBuffer_->size()) { + PARQUET_THROW_NOT_OK(bitsBuffer_->Resize(newBitmapSize, false)); + bitsBuffer_->ZeroPadding(); } ValidityBitmapInputOutput io; - io.validBits = bits_buffer_->mutable_data(); - io.valuesReadUpperBound = batch_size; - DefLevelsToBitmap(def_levels, batch_size, level_info_, &io); - *out_values_to_write = io.valuesRead - io.nullCount; - *out_spaced_values_to_write = io.valuesRead; - *null_count = io.nullCount; + io.validBits = bitsBuffer_->mutable_data(); + io.valuesReadUpperBound = batchSize; + DefLevelsToBitmap(defLevels, batchSize, levelInfo_, &io); + *outValuesToWrite = io.valuesRead - io.nullCount; + *outSpacedValuesToWrite = io.valuesRead; + *nullCount = io.nullCount; } - Result> MaybeReplaceValidity( + Result> maybeReplaceValidity( std::shared_ptr array, - int64_t new_null_count, - ::arrow::MemoryPool* memory_pool) { - if (bits_buffer_ == nullptr) { + int64_t newNullCount, + ::arrow::MemoryPool* memoryPool) { + if (bitsBuffer_ == nullptr) { return array; } std::vector> buffers = array->data()->buffers; if (buffers.empty()) { return array; } - buffers[0] = bits_buffer_; + buffers[0] = bitsBuffer_; // Should be a leaf array. VELOX_DCHECK_GT(buffers.size(), 1); - ValueBufferSlicer slicer{memory_pool}; + ValueBufferSlicer slicer{memoryPool}; if (array->data()->offset > 0) { - RETURN_NOT_OK(util::VisitArrayInline(*array, &slicer, &buffers[1])); - } - return ::arrow::MakeArray(std::make_shared( - array->type(), array->length(), std::move(buffers), new_null_count)); - } - - void WriteLevelsSpaced( - int64_t num_levels, - const int16_t* def_levels, - const int16_t* rep_levels) { - // If the field is required and non-repeated, there are no definition levels - if (descr_->max_definition_level() > 0) { - WriteDefinitionLevels(num_levels, def_levels); - } - // Not present for non-repeated fields - if (descr_->max_repetition_level() > 0) { - // A row could include more than one value - // Count the occasions where we start a new row - for (int64_t i = 0; i < num_levels; ++i) { - if (rep_levels[i] == 0) { - rows_written_++; - num_buffered_rows_++; + RETURN_NOT_OK(util::visitArrayInline(*array, &slicer, &buffers[1])); + } + return ::arrow::MakeArray( + std::make_shared( + array->type(), array->length(), std::move(buffers), newNullCount)); + } + + void writeLevelsSpaced( + int64_t numLevels, + const int16_t* defLevels, + const int16_t* repLevels) { + // If the field is required and non-repeated, there are no definition + // levels. + if (descr_->maxDefinitionLevel() > 0) { + writeDefinitionLevels(numLevels, defLevels); + } + // Not present for non-repeated fields. + if (descr_->maxRepetitionLevel() > 0) { + // A row could include more than one value. + // Count the occasions where we start a new row. + for (int64_t i = 0; i < numLevels; ++i) { + if (repLevels[i] == 0) { + rowsWritten_++; + numBufferedRows_++; } } - WriteRepetitionLevels(num_levels, rep_levels); + writeRepetitionLevels(numLevels, repLevels); } else { - // Each value is exactly one row - rows_written_ += num_levels; - num_buffered_rows_ += num_levels; + // Each value is exactly one row. + rowsWritten_ += numLevels; + numBufferedRows_ += numLevels; } } - void CommitWriteAndCheckPageLimit( - int64_t num_levels, - int64_t num_values, - int64_t num_nulls, - bool check_page_size) { - num_buffered_values_ += num_levels; - num_buffered_encoded_values_ += num_values; - num_buffered_nulls_ += num_nulls; + void commitWriteAndCheckPageLimit( + int64_t numLevels, + int64_t numValues, + int64_t numNulls, + bool checkPageSize) { + numBufferedValues_ += numLevels; + numBufferedEncodedValues_ += numValues; + numBufferedNulls_ += numNulls; - if (check_page_size && - current_encoder_->EstimatedDataEncodedSize() >= - properties_->data_pagesize()) { - AddDataPage(); + if (checkPageSize && + currentEncoder_->estimatedDataEncodedSize() >= + properties_->dataPagesize()) { + addDataPage(); } } - void FallbackToPlainEncoding() { - if (IsDictionaryEncoding(current_encoder_->encoding())) { - WriteDictionaryPage(); - // Serialize the buffered Dictionary Indices - FlushBufferedDataPages(); + void fallbackToPlainEncoding() { + if (isDictionaryEncoding(currentEncoder_->encoding())) { + writeDictionaryPage(); + // Serialize the buffered dictionary indices. + flushBufferedDataPages(); fallback_ = true; - // Only PLAIN encoding is supported for fallback in V1 - current_encoder_ = MakeEncoder( - DType::type_num, - Encoding::PLAIN, + // Only PLAIN encoding is supported for fallback in V1. + currentEncoder_ = makeEncoder( + DType::typeNum, + Encoding::kPlain, false, descr_, - properties_->memory_pool()); - current_value_encoder_ = - dynamic_cast(current_encoder_.get()); - current_dict_encoder_ = nullptr; // not using dict - encoding_ = Encoding::PLAIN; + properties_->memoryPool()); + currentValueEncoder_ = + dynamic_cast(currentEncoder_.get()); + currentDictEncoder_ = nullptr; // not using dict + encoding_ = Encoding::kPlain; } } - // Checks if the Dictionary Page size limit is reached - // If the limit is reached, the Dictionary and Data Pages are serialized - // The encoding is switched to PLAIN + // Checks if the Dictionary Page size limit is reached. + // If the limit is reached, the Dictionary and Data Pages are serialized. + // The encoding is switched to PLAIN. // // Only one Dictionary Page is written. // Fallback to PLAIN if dictionary page limit is reached. - void CheckDictionarySizeLimit() { - if (!has_dictionary_ || fallback_) { + void checkDictionarySizeLimit() { + if (!hasDictionary_ || fallback_) { // Either not using dictionary encoding, or we have already fallen back - // to PLAIN encoding because the size threshold was reached + // to PLAIN encoding because the size threshold was reached. return; } - if (current_dict_encoder_->dict_encoded_size() >= - properties_->dictionary_pagesize_limit()) { - FallbackToPlainEncoding(); + if (currentDictEncoder_->dictEncodedSize() >= + properties_->dictionaryPagesizeLimit()) { + fallbackToPlainEncoding(); } } - void WriteValues(const T* values, int64_t num_values, int64_t num_nulls) { - current_value_encoder_->Put(values, static_cast(num_values)); - if (page_statistics_ != nullptr) { - page_statistics_->Update(values, num_values, num_nulls); + void writeValues(const T* values, int64_t numValues, int64_t numNulls) { + currentValueEncoder_->put(values, static_cast(numValues)); + if (pageStatistics_ != nullptr) { + pageStatistics_->update(values, numValues, numNulls); } } /// \brief Write values with spaces and update page statistics accordingly. /// - /// \param values input buffer of values to write, including spaces. - /// \param num_values number of non-null values in the values buffer. - /// \param num_spaced_values length of values buffer, including spaces and - /// does not - /// count some nulls from ancestor (e.g. empty lists). - /// \param valid_bits validity bitmap of values buffer, which does not include - /// some - /// nulls from ancestor (e.g. empty lists). - /// \param valid_bits_offset offset to valid_bits bitmap. - /// \param num_levels number of levels to write, including nulls from values - /// buffer - /// and nulls from ancestor (e.g. empty lists). - /// \param num_nulls number of nulls in the values buffer as well as nulls - /// from the - /// ancestor (e.g. empty lists). - void WriteValuesSpaced( + /// @param values input buffer of values to write, including spaces. + /// @param numValues number of non-null values in the values buffer. + /// @param numSpacedValues length of values buffer, including spaces and + /// does not count some nulls from ancestor (e.g. empty lists). + /// @param validBits validity bitmap of values buffer, which does not + /// include some nulls from ancestor (e.g. empty lists). + /// @param validBitsOffset offset to validBits bitmap. + /// @param numLevels number of levels to write, including nulls from values + /// buffer and nulls from ancestor (e.g. empty lists). + /// @param numNulls number of nulls in the values buffer as well as nulls + /// from the ancestor (e.g. empty lists). + void writeValuesSpaced( const T* values, - int64_t num_values, - int64_t num_spaced_values, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - int64_t num_levels, - int64_t num_nulls) { - if (num_values != num_spaced_values) { - current_value_encoder_->PutSpaced( + int64_t numValues, + int64_t numSpacedValues, + const uint8_t* validBits, + int64_t validBitsOffset, + int64_t numLevels, + int64_t numNulls) { + if (numValues != numSpacedValues) { + currentValueEncoder_->putSpaced( values, - static_cast(num_spaced_values), - valid_bits, - valid_bits_offset); + static_cast(numSpacedValues), + validBits, + validBitsOffset); } else { - current_value_encoder_->Put(values, static_cast(num_values)); + currentValueEncoder_->put(values, static_cast(numValues)); } - if (page_statistics_ != nullptr) { - page_statistics_->UpdateSpaced( + if (pageStatistics_ != nullptr) { + pageStatistics_->updateSpaced( values, - valid_bits, - valid_bits_offset, - num_spaced_values, - num_values, - num_nulls); + validBits, + validBitsOffset, + numSpacedValues, + numValues, + numNulls); } } }; template -Status TypedColumnWriterImpl::WriteArrowDictionary( - const int16_t* def_levels, - const int16_t* rep_levels, - int64_t num_levels, +Status TypedColumnWriterImpl::writeArrowDictionary( + const int16_t* defLevels, + const int16_t* repLevels, + int64_t numLevels, const ::arrow::Array& array, ArrowWriteContext* ctx, - bool maybe_parent_nulls) { + bool maybeParentNulls) { // If this is the first time writing a DictionaryArray, then there's // a few possible paths to take: // - // - If dictionary encoding is not enabled, convert to densely - // encoded and call WriteArrow - // - Dictionary encoding enabled + // - If dictionary encoding is not enabled, convert to densely. + // Encoded and call WriteArrow. + // - Dictionary encoding enabled. // - If this is the first time this is called, then we call - // PutDictionary into the encoder and then PutIndices on each + // putDictionary into the encoder and then putIndices on each // chunk. We store the dictionary that was written in - // preserved_dictionary_ so that subsequent calls to this method - // can make sure the dictionary has not changed + // preservedDictionary_ so that subsequent calls to this method + // can make sure the dictionary has not changed. // - On subsequent calls, we have to check whether the dictionary // has changed. If it has, then we trigger the varying // dictionary path and materialize each chunk and then call - // WriteArrow with that - auto WriteDense = [&] { - std::shared_ptr<::arrow::Array> dense_array; - RETURN_NOT_OK(ConvertDictionaryToDense( - array, properties_->memory_pool(), &dense_array)); - return WriteArrowDense( - def_levels, - rep_levels, - num_levels, - *dense_array, - ctx, - maybe_parent_nulls); + // writeArrow with that. + auto writeDense = [&] { + std::shared_ptr<::arrow::Array> denseArray; + RETURN_NOT_OK(convertDictionaryToDense( + array, properties_->memoryPool(), &denseArray)); + return writeArrowDense( + defLevels, repLevels, numLevels, *denseArray, ctx, maybeParentNulls); }; - if (!IsDictionaryEncoding(current_encoder_->encoding()) || - !DictionaryDirectWriteSupported(array)) { + if (!isDictionaryEncoding(currentEncoder_->encoding()) || + !dictionaryDirectWriteSupported(array)) { // No longer dictionary-encoding for whatever reason, maybe we never were - // or we decided to stop. Note that WriteArrow can be invoked multiple + // or we decided to stop. Note that writeArrow can be invoked multiple // times with both dense and dictionary-encoded versions of the same data // without a problem. Any dense data will be hashed to indices until the // dictionary page limit is reached, at which everything (dictionary and - // dense) will fall back to plain encoding - return WriteDense(); + // dense) will fall back to plain encoding. + return writeDense(); } - auto dict_encoder = dynamic_cast*>(current_encoder_.get()); + auto dictEncoder = dynamic_cast*>(currentEncoder_.get()); const auto& data = checked_cast(array); std::shared_ptr<::arrow::Array> dictionary = data.dictionary(); std::shared_ptr<::arrow::Array> indices = data.indices(); - auto update_stats = [&](int64_t num_chunk_levels, - const std::shared_ptr& chunk_indices) { - // TODO(PARQUET-2068) This approach may make two copies. First, a copy of + auto updateStats = [&](int64_t numChunkLevels, + const std::shared_ptr& chunkIndices) { + // TODO(PARQUET-2068) This approach may make two copies. First, a copy of // the indices array to a (hopefully smaller) referenced indices array. - // Second, a copy of the values array to a (probably not smaller) referenced - // values array. + // Second, a copy of the values array to a (probably not smaller) + // referenced values array. // - // Once the MinMax kernel supports all data types we should use that kernel - // instead as it does not make any copies. - ::arrow::compute::ExecContext exec_ctx(ctx->memory_pool); - exec_ctx.set_use_threads(false); + // Once the MinMax kernel supports all data types we should use that kernel. + // Instead as it does not make any copies. + ::arrow::compute::ExecContext execCtx(ctx->memoryPool); + execCtx.set_use_threads(false); - std::shared_ptr<::arrow::Array> referenced_dictionary; + std::shared_ptr<::arrow::Array> referencedDictionary; PARQUET_ASSIGN_OR_THROW( - ::arrow::Datum referenced_indices, - ::arrow::compute::Unique(*chunk_indices, &exec_ctx)); + ::arrow::Datum referencedIndices, + ::arrow::compute::Unique(*chunkIndices, &execCtx)); - // On first run, we might be able to re-use the existing dictionary - if (referenced_indices.length() == dictionary->length()) { - referenced_dictionary = dictionary; + // On first run, we might be able to re-use the existing dictionary. + if (referencedIndices.length() == dictionary->length()) { + referencedDictionary = dictionary; } else { PARQUET_ASSIGN_OR_THROW( - ::arrow::Datum referenced_dictionary_datum, + ::arrow::Datum referencedDictionaryDatum, ::arrow::compute::Take( dictionary, - referenced_indices, - ::arrow::compute::TakeOptions(/*boundscheck=*/false), - &exec_ctx)); - referenced_dictionary = referenced_dictionary_datum.make_array(); + referencedIndices, + ::arrow::compute::TakeOptions::NoBoundsCheck(), + &execCtx)); + referencedDictionary = referencedDictionaryDatum.make_array(); } - int64_t non_null_count = - chunk_indices->length() - chunk_indices->null_count(); - page_statistics_->IncrementNullCount(num_chunk_levels - non_null_count); - page_statistics_->IncrementNumValues(non_null_count); - page_statistics_->Update(*referenced_dictionary, /*update_counts=*/false); - }; - - int64_t value_offset = 0; - auto WriteIndicesChunk = [&](int64_t offset, - int64_t batch_size, - bool check_page) { - int64_t batch_num_values = 0; - int64_t batch_num_spaced_values = 0; - int64_t null_count = ::arrow::kUnknownNullCount; - // Bits is not null for nullable values. At this point in the code we can't - // determine if the leaf array has the same null values as any parents it - // might have had so we need to recompute it from def levels. - MaybeCalculateValidityBits( - AddIfNotNull(def_levels, offset), - batch_size, - &batch_num_values, - &batch_num_spaced_values, - &null_count); - WriteLevelsSpaced( - batch_size, - AddIfNotNull(def_levels, offset), - AddIfNotNull(rep_levels, offset)); - std::shared_ptr writeable_indices = - indices->Slice(value_offset, batch_num_spaced_values); - if (page_statistics_) { - update_stats(/*num_chunk_levels=*/batch_size, writeable_indices); - } - PARQUET_ASSIGN_OR_THROW( - writeable_indices, - MaybeReplaceValidity(writeable_indices, null_count, ctx->memory_pool)); - dict_encoder->PutIndices(*writeable_indices); - CommitWriteAndCheckPageLimit( - batch_size, batch_num_values, null_count, check_page); - value_offset += batch_num_spaced_values; + int64_t nonNullCount = chunkIndices->length() - chunkIndices->null_count(); + pageStatistics_->incrementNullCount(numChunkLevels - nonNullCount); + pageStatistics_->incrementNumValues(nonNullCount); + pageStatistics_->update(*referencedDictionary, false); }; - // Handle seeing dictionary for the first time - if (!preserved_dictionary_) { - // It's a new dictionary. Call PutDictionary and keep track of it - PARQUET_CATCH_NOT_OK(dict_encoder->PutDictionary(*dictionary)); - - // If there were duplicate value in the dictionary, the encoder's memo table - // will be out of sync with the indices in the Arrow array. - // The easiest solution for this uncommon case is to fallback to plain + int64_t valueOffset = 0; + auto writeIndicesChunk = + [&](int64_t offset, int64_t batchSize, bool checkPage) { + int64_t batchNumValues = 0; + int64_t batchNumSpacedValues = 0; + int64_t nullCount = ::arrow::kUnknownNullCount; + // Bits is not null for nullable values. At this point in the code we + // can't. Determine if the leaf array has the same null values as any + // parents it. Might have had so we need to recompute it from def + // levels. + maybeCalculateValidityBits( + addIfNotNull(defLevels, offset), + batchSize, + &batchNumValues, + &batchNumSpacedValues, + &nullCount); + writeLevelsSpaced( + batchSize, + addIfNotNull(defLevels, offset), + addIfNotNull(repLevels, offset)); + std::shared_ptr writeableIndices = + indices->Slice(valueOffset, batchNumSpacedValues); + if (pageStatistics_) { + updateStats(batchSize, writeableIndices); + } + PARQUET_ASSIGN_OR_THROW( + writeableIndices, + maybeReplaceValidity(writeableIndices, nullCount, ctx->memoryPool)); + dictEncoder->putIndices(*writeableIndices); + commitWriteAndCheckPageLimit( + batchSize, batchNumValues, nullCount, checkPage); + valueOffset += batchNumSpacedValues; + }; + + // Handle seeing dictionary for the first time. + if (!preservedDictionary_) { + // It's a new dictionary. Call PutDictionary and keep track of it. + PARQUET_CATCH_NOT_OK(dictEncoder->putDictionary(*dictionary)); + + // If there were duplicate value in the dictionary, the encoder's memo + // table will be out of sync with the indices in the Arrow array. The + // easiest solution for this uncommon case is to fallback to plain // encoding. - if (dict_encoder->num_entries() != dictionary->length()) { - PARQUET_CATCH_NOT_OK(FallbackToPlainEncoding()); - return WriteDense(); + if (dictEncoder->numEntries() != dictionary->length()) { + PARQUET_CATCH_NOT_OK(fallbackToPlainEncoding()); + return writeDense(); } - preserved_dictionary_ = dictionary; - } else if (!dictionary->Equals(*preserved_dictionary_)) { - // Dictionary has changed - PARQUET_CATCH_NOT_OK(FallbackToPlainEncoding()); - return WriteDense(); + preservedDictionary_ = dictionary; + } else if (!dictionary->Equals(*preservedDictionary_)) { + // Dictionary has changed. + PARQUET_CATCH_NOT_OK(fallbackToPlainEncoding()); + return writeDense(); } - PARQUET_CATCH_NOT_OK(DoInBatches( - def_levels, - rep_levels, - num_levels, - properties_->write_batch_size(), - WriteIndicesChunk, - pages_change_on_record_boundaries())); + PARQUET_CATCH_NOT_OK(doInBatches( + defLevels, + repLevels, + numLevels, + properties_->writeBatchSize(), + writeIndicesChunk, + pagesChangeOnRecordBoundaries())); return Status::OK(); } -// ---------------------------------------------------------------------- -// Direct Arrow write path +// ----------------------------------------------------------------------. +// Direct Arrow write path. template struct SerializeFunctor { using ArrowCType = typename ArrowType::c_type; using ArrayType = typename ::arrow::TypeTraits::ArrayType; - using ParquetCType = typename ParquetType::c_type; + using ParquetCType = typename ParquetType::CType; Status - Serialize(const ArrayType& array, ArrowWriteContext*, ParquetCType* out) { + serialize(const ArrayType& array, ArrowWriteContext*, ParquetCType* out) { const ArrowCType* input = array.raw_values(); if (array.null_count() > 0) { for (int i = 0; i < array.length(); i++) { @@ -2183,34 +2151,34 @@ struct SerializeFunctor { }; template -Status WriteArrowSerialize( +Status writeArrowSerialize( const ::arrow::Array& array, - int64_t num_levels, - const int16_t* def_levels, - const int16_t* rep_levels, + int64_t numLevels, + const int16_t* defLevels, + const int16_t* repLevels, ArrowWriteContext* ctx, TypedColumnWriter* writer, - bool maybe_parent_nulls) { - using ParquetCType = typename ParquetType::c_type; + bool maybeParentNulls) { + using ParquetCType = typename ParquetType::CType; using ArrayType = typename ::arrow::TypeTraits::ArrayType; ParquetCType* buffer = nullptr; PARQUET_THROW_NOT_OK( - ctx->GetScratchData(array.length(), &buffer)); + ctx->getScratchData(array.length(), &buffer)); SerializeFunctor functor; RETURN_NOT_OK( - functor.Serialize(checked_cast(array), ctx, buffer)); - bool no_nulls = writer->descr()->schema_node()->is_required() || - (array.null_count() == 0); - if (!maybe_parent_nulls && no_nulls) { + functor.serialize(checked_cast(array), ctx, buffer)); + bool noNulls = + writer->descr()->schemaNode()->isRequired() || (array.null_count() == 0); + if (!maybeParentNulls && noNulls) { PARQUET_CATCH_NOT_OK( - writer->WriteBatch(num_levels, def_levels, rep_levels, buffer)); + writer->writeBatch(numLevels, defLevels, repLevels, buffer)); } else { - PARQUET_CATCH_NOT_OK(writer->WriteBatchSpaced( - num_levels, - def_levels, - rep_levels, + PARQUET_CATCH_NOT_OK(writer->writeBatchSpaced( + numLevels, + defLevels, + repLevels, array.null_bitmap_data(), array.offset(), buffer)); @@ -2219,34 +2187,34 @@ Status WriteArrowSerialize( } template -Status WriteArrowZeroCopy( +Status writeArrowZeroCopy( const ::arrow::Array& array, - int64_t num_levels, - const int16_t* def_levels, - const int16_t* rep_levels, + int64_t numLevels, + const int16_t* defLevels, + const int16_t* repLevels, ArrowWriteContext* ctx, TypedColumnWriter* writer, - bool maybe_parent_nulls) { - using T = typename ParquetType::c_type; + bool maybeParentNulls) { + using T = typename ParquetType::CType; const auto& data = static_cast(array); const T* values = nullptr; - // The values buffer may be null if the array is empty (ARROW-2744) + // The values buffer may be null if the array is empty (ARROW-2744). if (data.values() != nullptr) { values = reinterpret_cast(data.values()->data()) + data.offset(); } else { VELOX_DCHECK_EQ(data.length(), 0); } - bool no_nulls = writer->descr()->schema_node()->is_required() || - (array.null_count() == 0); + bool noNulls = + writer->descr()->schemaNode()->isRequired() || (array.null_count() == 0); - if (!maybe_parent_nulls && no_nulls) { + if (!maybeParentNulls && noNulls) { PARQUET_CATCH_NOT_OK( - writer->WriteBatch(num_levels, def_levels, rep_levels, values)); + writer->writeBatch(numLevels, defLevels, repLevels, values)); } else { - PARQUET_CATCH_NOT_OK(writer->WriteBatchSpaced( - num_levels, - def_levels, - rep_levels, + PARQUET_CATCH_NOT_OK(writer->writeBatchSpaced( + numLevels, + defLevels, + repLevels, data.null_bitmap_data(), data.offset(), values)); @@ -2254,41 +2222,29 @@ Status WriteArrowZeroCopy( return Status::OK(); } -#define WRITE_SERIALIZE_CASE(ArrowEnum, ArrowType, ParquetType) \ - case ::arrow::Type::ArrowEnum: \ - return WriteArrowSerialize( \ - array, \ - num_levels, \ - def_levels, \ - rep_levels, \ - ctx, \ - this, \ - maybe_parent_nulls); - -#define WRITE_ZERO_COPY_CASE(ArrowEnum, ArrowType, ParquetType) \ - case ::arrow::Type::ArrowEnum: \ - return WriteArrowZeroCopy( \ - array, \ - num_levels, \ - def_levels, \ - rep_levels, \ - ctx, \ - this, \ - maybe_parent_nulls); +#define WRITE_SERIALIZE_CASE(Arrowenum, ArrowType, ParquetType) \ + case ::arrow::Type::Arrowenum: \ + return writeArrowSerialize( \ + array, numLevels, defLevels, repLevels, ctx, this, maybeParentNulls); + +#define WRITE_ZERO_COPY_CASE(Arrowenum, ArrowType, ParquetType) \ + case ::arrow::Type::Arrowenum: \ + return writeArrowZeroCopy( \ + array, numLevels, defLevels, repLevels, ctx, this, maybeParentNulls); #define ARROW_UNSUPPORTED() \ std::stringstream ss; \ ss << "Arrow type " << array.type()->ToString() \ - << " cannot be written to Parquet type " << descr_->ToString(); \ + << " cannot be written to Parquet type " << descr_->toString(); \ return Status::Invalid(ss.str()); -// ---------------------------------------------------------------------- -// Write Arrow to BooleanType +// ----------------------------------------------------------------------. +// Write Arrow to BooleanType. template <> struct SerializeFunctor { Status - Serialize(const ::arrow::BooleanArray& data, ArrowWriteContext*, bool* out) { + serialize(const ::arrow::BooleanArray& data, ArrowWriteContext*, bool* out) { for (int i = 0; i < data.length(); i++) { *out++ = data.Value(i); } @@ -2297,26 +2253,26 @@ struct SerializeFunctor { }; template <> -Status TypedColumnWriterImpl::WriteArrowDense( - const int16_t* def_levels, - const int16_t* rep_levels, - int64_t num_levels, +Status TypedColumnWriterImpl::writeArrowDense( + const int16_t* defLevels, + const int16_t* repLevels, + int64_t numLevels, const ::arrow::Array& array, ArrowWriteContext* ctx, - bool maybe_parent_nulls) { + bool maybeParentNulls) { if (array.type_id() != ::arrow::Type::BOOL) { ARROW_UNSUPPORTED(); } - return WriteArrowSerialize( - array, num_levels, def_levels, rep_levels, ctx, this, maybe_parent_nulls); + return writeArrowSerialize( + array, numLevels, defLevels, repLevels, ctx, this, maybeParentNulls); } -// ---------------------------------------------------------------------- -// Write Arrow types to INT32 +// ----------------------------------------------------------------------. +// Write Arrow types to INT32. template <> struct SerializeFunctor { - Status Serialize( + Status serialize( const ::arrow::Date64Array& array, ArrowWriteContext*, int32_t* out) { @@ -2335,20 +2291,20 @@ struct SerializeFunctor< ::arrow::enable_if_t< ::arrow::is_decimal_type::value&& ::arrow::internal:: IsOneOf::value>> { - using value_type = typename ParquetType::c_type; + using ValueType = typename ParquetType::CType; - Status Serialize( + Status serialize( const typename ::arrow::TypeTraits::ArrayType& array, ArrowWriteContext* ctx, - value_type* out) { + ValueType* out) { if (array.null_count() == 0) { for (int64_t i = 0; i < array.length(); i++) { - out[i] = TransferValue(array.Value(i)); + out[i] = transferValue(array.Value(i)); } } else { for (int64_t i = 0; i < array.length(); i++) { out[i] = array.IsValid(i) - ? TransferValue(array.Value(i)) + ? transferValue(array.Value(i)) : 0; } } @@ -2356,20 +2312,20 @@ struct SerializeFunctor< return Status::OK(); } - template - value_type TransferValue(const uint8_t* in) const { + template + ValueType transferValue(const uint8_t* in) const { static_assert( - byte_width == 16 || byte_width == 32, + byteWidth == 16 || byteWidth == 32, "only 16 and 32 byte Decimals supported"); - value_type value = 0; - if constexpr (byte_width == 16) { - ::arrow::Decimal128 decimal_value(in); - PARQUET_THROW_NOT_OK(decimal_value.ToInteger(&value)); + ValueType value = 0; + if constexpr (byteWidth == 16) { + ::arrow::Decimal128 decimalValue(in); + PARQUET_ASSIGN_OR_THROW(value, decimalValue.ToInteger()); } else { - ::arrow::Decimal256 decimal_value(in); + ::arrow::Decimal256 decimalValue(in); // Decimal256 does not provide ToInteger, but we are sure it fits in the // target integer type. - value = static_cast(decimal_value.low_bits()); + value = static_cast(decimalValue.low_bits()); } return value; } @@ -2377,7 +2333,7 @@ struct SerializeFunctor< template <> struct SerializeFunctor { - Status Serialize( + Status serialize( const ::arrow::Time32Array& array, ArrowWriteContext*, int32_t* out) { @@ -2395,17 +2351,17 @@ struct SerializeFunctor { }; template <> -Status TypedColumnWriterImpl::WriteArrowDense( - const int16_t* def_levels, - const int16_t* rep_levels, - int64_t num_levels, +Status TypedColumnWriterImpl::writeArrowDense( + const int16_t* defLevels, + const int16_t* repLevels, + int64_t numLevels, const ::arrow::Array& array, ArrowWriteContext* ctx, - bool maybe_parent_nulls) { + bool maybeParentNulls) { switch (array.type()->id()) { case ::arrow::Type::NA: { PARQUET_CATCH_NOT_OK( - WriteBatch(num_levels, def_levels, rep_levels, nullptr)); + writeBatch(numLevels, defLevels, repLevels, nullptr)); } break; WRITE_SERIALIZE_CASE(INT8, Int8Type, Int32Type) WRITE_SERIALIZE_CASE(UINT8, UInt8Type, Int32Type) @@ -2424,16 +2380,16 @@ Status TypedColumnWriterImpl::WriteArrowDense( return Status::OK(); } -// ---------------------------------------------------------------------- -// Write Arrow to Int64 and Int96 +// ----------------------------------------------------------------------. +// Write Arrow to Int64 and Int96. -#define INT96_CONVERT_LOOP(ConversionFunction) \ +#define INT96_CONVERT_LOOP(conversionFunction) \ for (int64_t i = 0; i < array.length(); i++) \ - ConversionFunction(input[i], &out[i]); + conversionFunction(input[i], &out[i]); template <> struct SerializeFunctor { - Status Serialize( + Status serialize( const ::arrow::TimestampArray& array, ArrowWriteContext*, Int96* out) { @@ -2442,16 +2398,16 @@ struct SerializeFunctor { static_cast(*array.type()); switch (type.unit()) { case ::arrow::TimeUnit::NANO: - INT96_CONVERT_LOOP(internal::NanosecondsToImpalaTimestamp); + INT96_CONVERT_LOOP(internal::nanosecondsToImpalaTimestamp); break; case ::arrow::TimeUnit::MICRO: - INT96_CONVERT_LOOP(internal::MicrosecondsToImpalaTimestamp); + INT96_CONVERT_LOOP(internal::microsecondsToImpalaTimestamp); break; case ::arrow::TimeUnit::MILLI: - INT96_CONVERT_LOOP(internal::MillisecondsToImpalaTimestamp); + INT96_CONVERT_LOOP(internal::millisecondsToImpalaTimestamp); break; case ::arrow::TimeUnit::SECOND: - INT96_CONVERT_LOOP(internal::SecondsToImpalaTimestamp); + INT96_CONVERT_LOOP(internal::secondsToImpalaTimestamp); break; } return Status::OK(); @@ -2463,22 +2419,22 @@ struct SerializeFunctor { #define COERCE_MULTIPLY +1 static std::pair kTimestampCoercionFactors[4][4] = { - // from seconds ... + // From seconds ... {{COERCE_INVALID, 0}, // ... to seconds {COERCE_MULTIPLY, 1000}, // ... to millis {COERCE_MULTIPLY, 1000000}, // ... to micros {COERCE_MULTIPLY, INT64_C(1000000000)}}, // ... to nanos - // from millis ... + // From millis ... {{COERCE_INVALID, 0}, {COERCE_MULTIPLY, 1}, {COERCE_MULTIPLY, 1000}, {COERCE_MULTIPLY, 1000000}}, - // from micros ... + // From micros ... {{COERCE_INVALID, 0}, {COERCE_DIVIDE, 1000}, {COERCE_MULTIPLY, 1}, {COERCE_MULTIPLY, 1000}}, - // from nanos ... + // From nanos ... {{COERCE_INVALID, 0}, {COERCE_DIVIDE, 1000000}, {COERCE_DIVIDE, 1000}, @@ -2486,29 +2442,29 @@ static std::pair kTimestampCoercionFactors[4][4] = { template <> struct SerializeFunctor { - Status Serialize( + Status serialize( const ::arrow::TimestampArray& array, ArrowWriteContext* ctx, int64_t* out) { - const auto& source_type = + const auto& sourceType = static_cast(*array.type()); - auto source_unit = source_type.unit(); + auto sourceUnit = sourceType.unit(); const int64_t* values = array.raw_values(); - ::arrow::TimeUnit::type target_unit = - ctx->properties->coerce_timestamps_unit(); - auto target_type = ::arrow::timestamp(target_unit); - bool truncation_allowed = ctx->properties->truncated_timestamps_allowed(); + ::arrow::TimeUnit::type targetUnit = + ctx->properties->coerceTimestampsUnit(); + auto targetType = ::arrow::timestamp(targetUnit); + bool truncationAllowed = ctx->properties->truncatedTimestampsAllowed(); - auto DivideBy = [&](const int64_t factor) { + auto divideBy = [&](const int64_t factor) { for (int64_t i = 0; i < array.length(); i++) { - if (!truncation_allowed && array.IsValid(i) && + if (!truncationAllowed && array.IsValid(i) && (values[i] % factor != 0)) { return Status::Invalid( "Casting from ", - source_type.ToString(), + sourceType.ToString(), " to ", - target_type->ToString(), + targetType->ToString(), " would lose data: ", values[i]); } @@ -2517,7 +2473,7 @@ struct SerializeFunctor { return Status::OK(); }; - auto MultiplyBy = [&](const int64_t factor) { + auto multiplyBy = [&](const int64_t factor) { for (int64_t i = 0; i < array.length(); i++) { out[i] = values[i] * factor; } @@ -2525,13 +2481,13 @@ struct SerializeFunctor { }; const auto& coercion = - kTimestampCoercionFactors[static_cast(source_unit)] - [static_cast(target_unit)]; + kTimestampCoercionFactors[static_cast(sourceUnit)] + [static_cast(targetUnit)]; - // .first -> coercion operation; .second -> scale factor + // first -> coercion operation; second -> scale factor. VELOX_DCHECK_NE(coercion.first, COERCE_INVALID); - return coercion.first == COERCE_DIVIDE ? DivideBy(coercion.second) - : MultiplyBy(coercion.second); + return coercion.first == COERCE_DIVIDE ? divideBy(coercion.second) + : multiplyBy(coercion.second); } }; @@ -2539,98 +2495,85 @@ struct SerializeFunctor { #undef COERCE_INVALID #undef COERCE_MULTIPLY -Status WriteTimestamps( +Status writeTimestamps( const ::arrow::Array& values, - int64_t num_levels, - const int16_t* def_levels, - const int16_t* rep_levels, + int64_t numLevels, + const int16_t* defLevels, + const int16_t* repLevels, ArrowWriteContext* ctx, TypedColumnWriter* writer, - bool maybe_parent_nulls) { - const auto& source_type = + bool maybeParentNulls) { + const auto& sourceType = static_cast(*values.type()); - auto WriteCoerce = [&](const ArrowWriterProperties* properties) { - ArrowWriteContext temp_ctx = *ctx; - temp_ctx.properties = properties; - return WriteArrowSerialize( + auto writeCoerce = [&](const ArrowWriterProperties* properties) { + ArrowWriteContext tempCtx = *ctx; + tempCtx.properties = properties; + return writeArrowSerialize( values, - num_levels, - def_levels, - rep_levels, - &temp_ctx, + numLevels, + defLevels, + repLevels, + &tempCtx, writer, - maybe_parent_nulls); + maybeParentNulls); }; const ParquetVersion::type version = writer->properties()->version(); - if (ctx->properties->coerce_timestamps_enabled()) { - // User explicitly requested coercion to specific unit - if (source_type.unit() == ctx->properties->coerce_timestamps_unit()) { - // No data conversion necessary - return WriteArrowZeroCopy( + if (ctx->properties->coerceTimestampsEnabled()) { + // User explicitly requested coercion to specific unit. + if (sourceType.unit() == ctx->properties->coerceTimestampsUnit()) { + // No data conversion necessary. + return writeArrowZeroCopy( values, - num_levels, - def_levels, - rep_levels, + numLevels, + defLevels, + repLevels, ctx, writer, - maybe_parent_nulls); + maybeParentNulls); } else { - return WriteCoerce(ctx->properties); + return writeCoerce(ctx->properties); } } else if ( (version == ParquetVersion::PARQUET_1_0 || version == ParquetVersion::PARQUET_2_4) && - source_type.unit() == ::arrow::TimeUnit::NANO) { - // Absent superseding user instructions, when writing Parquet version <= 2.4 - // files, timestamps in nanoseconds are coerced to microseconds + sourceType.unit() == ::arrow::TimeUnit::NANO) { + // Absent superseding user instructions, when writing Parquet version + // Files, timestamps in nanoseconds are coerced to microseconds. std::shared_ptr properties = (ArrowWriterProperties::Builder()) - .coerce_timestamps(::arrow::TimeUnit::MICRO) - ->disallow_truncated_timestamps() + .coerceTimestamps(::arrow::TimeUnit::MICRO) + ->disallowTruncatedTimestamps() ->build(); - return WriteCoerce(properties.get()); - } else if (source_type.unit() == ::arrow::TimeUnit::SECOND) { - // Absent superseding user instructions, timestamps in seconds are coerced - // to milliseconds + return writeCoerce(properties.get()); + } else if (sourceType.unit() == ::arrow::TimeUnit::SECOND) { + // To milliseconds. std::shared_ptr properties = (ArrowWriterProperties::Builder()) - .coerce_timestamps(::arrow::TimeUnit::MILLI) + .coerceTimestamps(::arrow::TimeUnit::MILLI) ->build(); - return WriteCoerce(properties.get()); + return writeCoerce(properties.get()); } else { - // No data conversion necessary - return WriteArrowZeroCopy( - values, - num_levels, - def_levels, - rep_levels, - ctx, - writer, - maybe_parent_nulls); + // No data conversion necessary. + return writeArrowZeroCopy( + values, numLevels, defLevels, repLevels, ctx, writer, maybeParentNulls); } } template <> -Status TypedColumnWriterImpl::WriteArrowDense( - const int16_t* def_levels, - const int16_t* rep_levels, - int64_t num_levels, +Status TypedColumnWriterImpl::writeArrowDense( + const int16_t* defLevels, + const int16_t* repLevels, + int64_t numLevels, const ::arrow::Array& array, ArrowWriteContext* ctx, - bool maybe_parent_nulls) { + bool maybeParentNulls) { switch (array.type()->id()) { case ::arrow::Type::TIMESTAMP: - return WriteTimestamps( - array, - num_levels, - def_levels, - rep_levels, - ctx, - this, - maybe_parent_nulls); + return writeTimestamps( + array, numLevels, defLevels, repLevels, ctx, this, maybeParentNulls); WRITE_ZERO_COPY_CASE(INT64, Int64Type, Int64Type) WRITE_SERIALIZE_CASE(UINT32, UInt32Type, Int64Type) WRITE_SERIALIZE_CASE(UINT64, UInt64Type, Int64Type) @@ -2644,116 +2587,115 @@ Status TypedColumnWriterImpl::WriteArrowDense( } template <> -Status TypedColumnWriterImpl::WriteArrowDense( - const int16_t* def_levels, - const int16_t* rep_levels, - int64_t num_levels, +Status TypedColumnWriterImpl::writeArrowDense( + const int16_t* defLevels, + const int16_t* repLevels, + int64_t numLevels, const ::arrow::Array& array, ArrowWriteContext* ctx, - bool maybe_parent_nulls) { + bool maybeParentNulls) { if (array.type_id() != ::arrow::Type::TIMESTAMP) { ARROW_UNSUPPORTED(); } - return WriteArrowSerialize( - array, num_levels, def_levels, rep_levels, ctx, this, maybe_parent_nulls); + return writeArrowSerialize( + array, numLevels, defLevels, repLevels, ctx, this, maybeParentNulls); } -// ---------------------------------------------------------------------- -// Floating point types +// ----------------------------------------------------------------------. +// Floating point types. template <> -Status TypedColumnWriterImpl::WriteArrowDense( - const int16_t* def_levels, - const int16_t* rep_levels, - int64_t num_levels, +Status TypedColumnWriterImpl::writeArrowDense( + const int16_t* defLevels, + const int16_t* repLevels, + int64_t numLevels, const ::arrow::Array& array, ArrowWriteContext* ctx, - bool maybe_parent_nulls) { + bool maybeParentNulls) { if (array.type_id() != ::arrow::Type::FLOAT) { ARROW_UNSUPPORTED(); } - return WriteArrowZeroCopy( - array, num_levels, def_levels, rep_levels, ctx, this, maybe_parent_nulls); + return writeArrowZeroCopy( + array, numLevels, defLevels, repLevels, ctx, this, maybeParentNulls); } template <> -Status TypedColumnWriterImpl::WriteArrowDense( - const int16_t* def_levels, - const int16_t* rep_levels, - int64_t num_levels, +Status TypedColumnWriterImpl::writeArrowDense( + const int16_t* defLevels, + const int16_t* repLevels, + int64_t numLevels, const ::arrow::Array& array, ArrowWriteContext* ctx, - bool maybe_parent_nulls) { + bool maybeParentNulls) { if (array.type_id() != ::arrow::Type::DOUBLE) { ARROW_UNSUPPORTED(); } - return WriteArrowZeroCopy( - array, num_levels, def_levels, rep_levels, ctx, this, maybe_parent_nulls); + return writeArrowZeroCopy( + array, numLevels, defLevels, repLevels, ctx, this, maybeParentNulls); } -// ---------------------------------------------------------------------- -// Write Arrow to BYTE_ARRAY +// ----------------------------------------------------------------------. +// Write Arrow to BYTE_ARRAY. template <> -Status TypedColumnWriterImpl::WriteArrowDense( - const int16_t* def_levels, - const int16_t* rep_levels, - int64_t num_levels, +Status TypedColumnWriterImpl::writeArrowDense( + const int16_t* defLevels, + const int16_t* repLevels, + int64_t numLevels, const ::arrow::Array& array, ArrowWriteContext* ctx, - bool maybe_parent_nulls) { + bool maybeParentNulls) { if (!::arrow::is_base_binary_like(array.type()->id())) { ARROW_UNSUPPORTED(); } - int64_t value_offset = 0; - auto WriteChunk = [&](int64_t offset, int64_t batch_size, bool check_page) { - int64_t batch_num_values = 0; - int64_t batch_num_spaced_values = 0; - int64_t null_count = 0; - - MaybeCalculateValidityBits( - AddIfNotNull(def_levels, offset), - batch_size, - &batch_num_values, - &batch_num_spaced_values, - &null_count); - WriteLevelsSpaced( - batch_size, - AddIfNotNull(def_levels, offset), - AddIfNotNull(rep_levels, offset)); - std::shared_ptr data_slice = - array.Slice(value_offset, batch_num_spaced_values); + int64_t valueOffset = 0; + auto writeChunk = [&](int64_t offset, int64_t batchSize, bool checkPage) { + int64_t batchNumValues = 0; + int64_t batchNumSpacedValues = 0; + int64_t nullCount = 0; + + maybeCalculateValidityBits( + addIfNotNull(defLevels, offset), + batchSize, + &batchNumValues, + &batchNumSpacedValues, + &nullCount); + writeLevelsSpaced( + batchSize, + addIfNotNull(defLevels, offset), + addIfNotNull(repLevels, offset)); + std::shared_ptr dataSlice = + array.Slice(valueOffset, batchNumSpacedValues); PARQUET_ASSIGN_OR_THROW( - data_slice, - MaybeReplaceValidity(data_slice, null_count, ctx->memory_pool)); + dataSlice, maybeReplaceValidity(dataSlice, nullCount, ctx->memoryPool)); - current_encoder_->Put(*data_slice); + currentEncoder_->put(*dataSlice); // Null values in ancestors count as nulls. - const int64_t non_null = data_slice->length() - data_slice->null_count(); - if (page_statistics_ != nullptr) { - page_statistics_->Update(*data_slice, /*update_counts=*/false); - page_statistics_->IncrementNullCount(batch_size - non_null); - page_statistics_->IncrementNumValues(non_null); - } - CommitWriteAndCheckPageLimit( - batch_size, batch_num_values, batch_size - non_null, check_page); - CheckDictionarySizeLimit(); - value_offset += batch_num_spaced_values; + const int64_t nonNull = dataSlice->length() - dataSlice->null_count(); + if (pageStatistics_ != nullptr) { + pageStatistics_->update(*dataSlice, false); + pageStatistics_->incrementNullCount(batchSize - nonNull); + pageStatistics_->incrementNumValues(nonNull); + } + commitWriteAndCheckPageLimit( + batchSize, batchNumValues, batchSize - nonNull, checkPage); + checkDictionarySizeLimit(); + valueOffset += batchNumSpacedValues; }; - PARQUET_CATCH_NOT_OK(DoInBatches( - def_levels, - rep_levels, - num_levels, - properties_->write_batch_size(), - WriteChunk, - pages_change_on_record_boundaries())); + PARQUET_CATCH_NOT_OK(doInBatches( + defLevels, + repLevels, + numLevels, + properties_->writeBatchSize(), + writeChunk, + pagesChangeOnRecordBoundaries())); return Status::OK(); } -// ---------------------------------------------------------------------- -// Write Arrow to FIXED_LEN_BYTE_ARRAY +// ----------------------------------------------------------------------. +// Write Arrow to FIXED_LEN_BYTE_ARRAY. template struct SerializeFunctor< @@ -2762,13 +2704,13 @@ struct SerializeFunctor< ::arrow::enable_if_t< ::arrow::is_fixed_size_binary_type::value && !::arrow::is_decimal_type::value>> { - Status Serialize( + Status serialize( const ::arrow::FixedSizeBinaryArray& array, ArrowWriteContext*, FLBA* out) { if (array.null_count() == 0) { - // no nulls, just dump the data - // todo(advancedxy): use a writeBatch to avoid this step + // No nulls, just dump the data. + // Todo(advancedxy): use a writeBatch to avoid this step. for (int64_t i = 0; i < array.length(); i++) { out[i] = FixedLenByteArray(array.GetValue(i)); } @@ -2783,8 +2725,8 @@ struct SerializeFunctor< } }; -// ---------------------------------------------------------------------- -// Write Arrow to Decimal128 +// ----------------------------------------------------------------------. +// Write Arrow to Decimal128. // Requires a custom serializer because decimal in parquet are in big-endian // format. Thus, a temporary local buffer is required. @@ -2796,22 +2738,22 @@ struct SerializeFunctor< ::arrow::is_decimal_type::value && !::arrow::internal::IsOneOf:: value>> { - Status Serialize( + Status serialize( const typename ::arrow::TypeTraits::ArrayType& array, ArrowWriteContext* ctx, FLBA* out) { - AllocateScratch(array, ctx); - auto offset = Offset(array); + allocateScratch(array, ctx); + auto decimalOffsetValue = decimalOffset(array); if (array.null_count() == 0) { for (int64_t i = 0; i < array.length(); i++) { - out[i] = FixDecimalEndianess( - array.GetValue(i), offset); + out[i] = fixDecimalEndianess( + array.GetValue(i), decimalOffsetValue); } } else { for (int64_t i = 0; i < array.length(); i++) { - out[i] = array.IsValid(i) ? FixDecimalEndianess( - array.GetValue(i), offset) + out[i] = array.IsValid(i) ? fixDecimalEndianess( + array.GetValue(i), decimalOffsetValue) : FixedLenByteArray(); } } @@ -2819,57 +2761,56 @@ struct SerializeFunctor< return Status::OK(); } - // Parquet's Decimal are stored with FixedLength values where the length is - // proportional to the precision. Arrow's Decimal are always stored with 16/32 - // bytes. Thus the internal FLBA pointer must be adjusted by the offset - // calculated here. - int32_t Offset(const Array& array) { - auto decimal_type = - checked_pointer_cast<::arrow::DecimalType>(array.type()); - return decimal_type->byte_width() - - ::arrow::DecimalType::DecimalSize(decimal_type->precision()); + // Parquet's decimals are stored with FixedLength values where the + // length is proportional to the precision. Arrow's Decimal are always stored + // with 16 or 32 bytes. Thus the internal FLBA pointer must be adjusted by the + // offset calculated here. + int32_t decimalOffset(const Array& array) { + auto decimalType = checked_pointer_cast<::arrow::DecimalType>(array.type()); + return decimalType->byte_width() - + ::arrow::DecimalType::DecimalSize(decimalType->precision()); } - void AllocateScratch( + void allocateScratch( const typename ::arrow::TypeTraits::ArrayType& array, ArrowWriteContext* ctx) { - int64_t non_null_count = array.length() - array.null_count(); - int64_t size = non_null_count * ArrowType::kByteWidth; - scratch_buffer = AllocateBuffer(ctx->memory_pool, size); - scratch = reinterpret_cast(scratch_buffer->mutable_data()); + int64_t nonNullCount = array.length() - array.null_count(); + int64_t size = nonNullCount * ArrowType::kByteWidth; + scratchBuffer = allocateBuffer(ctx->memoryPool, size); + scratch = reinterpret_cast(scratchBuffer->mutable_data()); } - template - FixedLenByteArray FixDecimalEndianess(const uint8_t* in, int64_t offset) { - const auto* u64_in = reinterpret_cast(in); + template + FixedLenByteArray fixDecimalEndianess(const uint8_t* in, int64_t offset) { + const auto* u64In = reinterpret_cast(in); auto out = reinterpret_cast(scratch) + offset; static_assert( - byte_width == 16 || byte_width == 32, + byteWidth == 16 || byteWidth == 32, "only 16 and 32 byte Decimals supported"); - if (byte_width == 32) { - *scratch++ = ::arrow::bit_util::ToBigEndian(u64_in[3]); - *scratch++ = ::arrow::bit_util::ToBigEndian(u64_in[2]); - *scratch++ = ::arrow::bit_util::ToBigEndian(u64_in[1]); - *scratch++ = ::arrow::bit_util::ToBigEndian(u64_in[0]); + if (byteWidth == 32) { + *scratch++ = ::arrow::bit_util::ToBigEndian(u64In[3]); + *scratch++ = ::arrow::bit_util::ToBigEndian(u64In[2]); + *scratch++ = ::arrow::bit_util::ToBigEndian(u64In[1]); + *scratch++ = ::arrow::bit_util::ToBigEndian(u64In[0]); } else { - *scratch++ = ::arrow::bit_util::ToBigEndian(u64_in[1]); - *scratch++ = ::arrow::bit_util::ToBigEndian(u64_in[0]); + *scratch++ = ::arrow::bit_util::ToBigEndian(u64In[1]); + *scratch++ = ::arrow::bit_util::ToBigEndian(u64In[0]); } return FixedLenByteArray(out); } - std::shared_ptr scratch_buffer; + std::shared_ptr scratchBuffer; int64_t* scratch; }; template <> -Status TypedColumnWriterImpl::WriteArrowDense( - const int16_t* def_levels, - const int16_t* rep_levels, - int64_t num_levels, +Status TypedColumnWriterImpl::writeArrowDense( + const int16_t* defLevels, + const int16_t* repLevels, + int64_t numLevels, const ::arrow::Array& array, ArrowWriteContext* ctx, - bool maybe_parent_nulls) { + bool maybeParentNulls) { switch (array.type()->id()) { WRITE_SERIALIZE_CASE(FIXED_SIZE_BINARY, FixedSizeBinaryType, FLBAType) WRITE_SERIALIZE_CASE(DECIMAL128, Decimal128Type, FLBAType) @@ -2880,60 +2821,60 @@ Status TypedColumnWriterImpl::WriteArrowDense( return Status::OK(); } -// ---------------------------------------------------------------------- -// Dynamic column writer constructor +// ----------------------------------------------------------------------. +// Dynamic column writer constructor. -std::shared_ptr ColumnWriter::Make( +std::shared_ptr ColumnWriter::make( ColumnChunkMetaDataBuilder* metadata, std::unique_ptr pager, const WriterProperties* properties) { const ColumnDescriptor* descr = metadata->descr(); - const bool use_dictionary = properties->dictionary_enabled(descr->path()) && - descr->physical_type() != Type::BOOLEAN; + const bool useDictionary = properties->dictionaryEnabled(descr->path()) && + descr->physicalType() != Type::kBoolean; Encoding::type encoding = properties->encoding(descr->path()); - if (encoding == Encoding::UNKNOWN) { + if (encoding == Encoding::kUnknown) { // TODO: Arrow uses RLE by default for boolean columns. Since Velox can't // read RLEs yet, we disable this check. Re-enable once Velox's native // reader supports RLE. - // encoding = (descr->physical_type() == Type::BOOLEAN && + // Encoding = (descr->physical_type() == Type::kBoolean &&. // properties->version() != ParquetVersion::PARQUET_1_0) - // ? Encoding::RLE + // ? Encoding::RLE. // : Encoding::PLAIN; - encoding = Encoding::PLAIN; + encoding = Encoding::kPlain; } - if (use_dictionary) { - encoding = properties->dictionary_index_encoding(); + if (useDictionary) { + encoding = properties->dictionaryIndexEncoding(); } - switch (descr->physical_type()) { - case Type::BOOLEAN: + switch (descr->physicalType()) { + case Type::kBoolean: return std::make_shared>( - metadata, std::move(pager), use_dictionary, encoding, properties); - case Type::INT32: + metadata, std::move(pager), useDictionary, encoding, properties); + case Type::kInt32: return std::make_shared>( - metadata, std::move(pager), use_dictionary, encoding, properties); - case Type::INT64: + metadata, std::move(pager), useDictionary, encoding, properties); + case Type::kInt64: return std::make_shared>( - metadata, std::move(pager), use_dictionary, encoding, properties); - case Type::INT96: + metadata, std::move(pager), useDictionary, encoding, properties); + case Type::kInt96: return std::make_shared>( - metadata, std::move(pager), use_dictionary, encoding, properties); - case Type::FLOAT: + metadata, std::move(pager), useDictionary, encoding, properties); + case Type::kFloat: return std::make_shared>( - metadata, std::move(pager), use_dictionary, encoding, properties); - case Type::DOUBLE: + metadata, std::move(pager), useDictionary, encoding, properties); + case Type::kDouble: return std::make_shared>( - metadata, std::move(pager), use_dictionary, encoding, properties); - case Type::BYTE_ARRAY: + metadata, std::move(pager), useDictionary, encoding, properties); + case Type::kByteArray: return std::make_shared>( - metadata, std::move(pager), use_dictionary, encoding, properties); - case Type::FIXED_LEN_BYTE_ARRAY: + metadata, std::move(pager), useDictionary, encoding, properties); + case Type::kFixedLenByteArray: return std::make_shared>( - metadata, std::move(pager), use_dictionary, encoding, properties); + metadata, std::move(pager), useDictionary, encoding, properties); default: ParquetException::NYI("type reader not implemented"); } - // Unreachable code, but suppress compiler warning + // Unreachable code, but suppress compiler warning. return std::shared_ptr(nullptr); } diff --git a/velox/dwio/parquet/writer/arrow/ColumnWriter.h b/velox/dwio/parquet/writer/arrow/ColumnWriter.h index 90a9b96ab65..52699ca2969 100644 --- a/velox/dwio/parquet/writer/arrow/ColumnWriter.h +++ b/velox/dwio/parquet/writer/arrow/ColumnWriter.h @@ -60,218 +60,218 @@ class PARQUET_EXPORT LevelEncoder { LevelEncoder(); ~LevelEncoder(); - static int MaxBufferSize( + static int maxBufferSize( Encoding::type encoding, - int16_t max_level, - int num_buffered_values); + int16_t maxLevel, + int numBufferedValues); // Initialize the LevelEncoder. - void Init( + void init( Encoding::type encoding, - int16_t max_level, - int num_buffered_values, + int16_t maxLevel, + int numBufferedValues, uint8_t* data, - int data_size); + int dataSize); // Encodes a batch of levels from an array and returns the number of levels - // encoded - int Encode(int batch_size, const int16_t* levels); + // encoded. + int encode(int batchSize, const int16_t* levels); int32_t len() { - if (encoding_ != Encoding::RLE) { + if (encoding_ != Encoding::kRle) { throw ParquetException("Only implemented for RLE encoding"); } - return rle_length_; + return rleLength_; } private: - int bit_width_; - int rle_length_; + int bitWidth_; + int rleLength_; Encoding::type encoding_; - std::unique_ptr rle_encoder_; - std::unique_ptr bit_packed_encoder_; + std::unique_ptr rleEncoder_; + std::unique_ptr bitPackedEncoder_; }; class PARQUET_EXPORT PageWriter { public: virtual ~PageWriter() {} - static std::unique_ptr Open( + static std::unique_ptr open( std::shared_ptr sink, Compression::type codec, ColumnChunkMetaDataBuilder* metadata, - int16_t row_group_ordinal = -1, - int16_t column_chunk_ordinal = -1, + int16_t rowGroupOrdinal = -1, + int16_t columnChunkOrdinal = -1, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool(), - bool buffered_row_group = false, - std::shared_ptr header_encryptor = NULLPTR, - std::shared_ptr data_encryptor = NULLPTR, - bool page_write_checksum_enabled = false, - // column_index_builder MUST outlive the PageWriter - ColumnIndexBuilder* column_index_builder = NULLPTR, - // offset_index_builder MUST outlive the PageWriter - OffsetIndexBuilder* offset_index_builder = NULLPTR, - const util::CodecOptions& codec_options = util::CodecOptions{}); + bool bufferedRowGroup = false, + std::shared_ptr headerEncryptor = NULLPTR, + std::shared_ptr dataEncryptor = NULLPTR, + bool pageWriteChecksumEnabled = false, + // columnIndexBuilder must outlive the PageWriter. + ColumnIndexBuilder* columnIndexBuilder = NULLPTR, + // offsetIndexBuilder must outlive the PageWriter. + OffsetIndexBuilder* offsetIndexBuilder = NULLPTR, + const util::CodecOptions& codecOptions = util::CodecOptions{}); // TODO: remove this and port to new signature. // ARROW_DEPRECATED( - // "Deprecated in 13.0.0. Use CodecOptions-taking overload instead.") - static std::unique_ptr Open( + // "Deprecated in 13.0.0. Use codecOptions-taking overload instead.") + static std::unique_ptr open( std::shared_ptr sink, Compression::type codec, - int compression_level, + int compressionLevel, ColumnChunkMetaDataBuilder* metadata, - int16_t row_group_ordinal = -1, - int16_t column_chunk_ordinal = -1, + int16_t rowGroupOrdinal = -1, + int16_t columnChunkOrdinal = -1, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool(), - bool buffered_row_group = false, - std::shared_ptr header_encryptor = NULLPTR, - std::shared_ptr data_encryptor = NULLPTR, - bool page_write_checksum_enabled = false, - // column_index_builder MUST outlive the PageWriter - ColumnIndexBuilder* column_index_builder = NULLPTR, - // offset_index_builder MUST outlive the PageWriter - OffsetIndexBuilder* offset_index_builder = NULLPTR); - - // The Column Writer decides if dictionary encoding is used if set and + bool bufferedRowGroup = false, + std::shared_ptr headerEncryptor = NULLPTR, + std::shared_ptr dataEncryptor = NULLPTR, + bool pageWriteChecksumEnabled = false, + // columnIndexBuilder must outlive the PageWriter. + ColumnIndexBuilder* columnIndexBuilder = NULLPTR, + // offsetIndexBuilder must outlive the PageWriter. + OffsetIndexBuilder* offsetIndexBuilder = NULLPTR); + + // The column writer decides if dictionary encoding is used, if set, and // if the dictionary encoding has fallen back to default encoding on reaching - // dictionary page limit - virtual void Close(bool has_dictionary, bool fallback) = 0; + // dictionary page limit. + virtual void close(bool hasDictionary, bool fallback) = 0; - // Return the number of uncompressed bytes written (including header size) - virtual int64_t WriteDataPage(const DataPage& page) = 0; + // Return the number of uncompressed bytes written (including header size). + virtual int64_t writeDataPage(const DataPage& page) = 0; - // Return the number of uncompressed bytes written (including header size) - virtual int64_t WriteDictionaryPage(const DictionaryPage& page) = 0; + // Return the number of uncompressed bytes written (including header size). + virtual int64_t writeDictionaryPage(const DictionaryPage& page) = 0; /// \brief The total number of bytes written as serialized data and /// dictionary pages to the sink so far. - virtual int64_t total_compressed_bytes_written() const = 0; + virtual int64_t totalCompressedBytesWritten() const = 0; - virtual bool has_compressor() = 0; + virtual bool hasCompressor() = 0; - virtual void Compress( - const ::arrow::Buffer& src_buffer, - ::arrow::ResizableBuffer* dest_buffer) = 0; + virtual void compress( + const ::arrow::Buffer& srcBuffer, + ::arrow::ResizableBuffer* destBuffer) = 0; }; class PARQUET_EXPORT ColumnWriter { public: virtual ~ColumnWriter() = default; - static std::shared_ptr Make( + static std::shared_ptr make( ColumnChunkMetaDataBuilder*, std::unique_ptr, const WriterProperties* properties); /// \brief Closes the ColumnWriter, commits any buffered values to pages. - /// \return Total size of the column in bytes - virtual int64_t Close() = 0; + /// \return Total size of the column in bytes. + virtual int64_t close() = 0; - /// \brief The physical Parquet type of the column + /// \brief The physical Parquet type of the column. virtual Type::type type() const = 0; - /// \brief The schema for the column + /// \brief The schema for the column. virtual const ColumnDescriptor* descr() const = 0; - /// \brief The number of rows written so far - virtual int64_t rows_written() const = 0; + /// \brief The number of rows written so far. + virtual int64_t rowsWritten() const = 0; /// \brief The total size of the compressed pages + page headers. Values - /// are still buffered and not written to a pager yet + /// are still buffered and not written to a pager yet. /// - /// So in un-buffered mode, it always returns 0 - virtual int64_t total_compressed_bytes() const = 0; + /// So in unbuffered mode, it always returns 0. + virtual int64_t totalCompressedBytes() const = 0; /// \brief The total number of bytes written as serialized data and - /// dictionary pages to the ColumnChunk so far + /// dictionary pages to the ColumnChunk so far. /// These bytes are uncompressed bytes. - virtual int64_t total_bytes_written() const = 0; + virtual int64_t totalBytesWritten() const = 0; /// \brief The total number of bytes written as serialized data and /// dictionary pages to the ColumnChunk so far. /// If the column is uncompressed, the value would be equal to - /// total_bytes_written(). - virtual int64_t total_compressed_bytes_written() const = 0; + /// totalBytesWritten(). + virtual int64_t totalCompressedBytesWritten() const = 0; - /// \brief The file-level writer properties + /// \brief The file-level writer properties. virtual const WriterProperties* properties() = 0; /// \brief Write Apache Arrow columnar data directly to ColumnWriter. Returns /// error status if the array data type is not compatible with the concrete /// writer type. /// - /// leaf_array is always a primitive (possibly dictionary encoded type). - /// Leaf_field_nullable indicates whether the leaf array is considered + /// leafArray is always a primitive (possibly dictionary encoded type). + /// leafFieldNullable indicates whether the leaf array is considered /// nullable according to its schema in a Table or its parent array. - virtual ::arrow::Status WriteArrow( - const int16_t* def_levels, - const int16_t* rep_levels, - int64_t num_levels, - const ::arrow::Array& leaf_array, + virtual ::arrow::Status writeArrow( + const int16_t* defLevels, + const int16_t* repLevels, + int64_t numLevels, + const ::arrow::Array& leafArray, ArrowWriteContext* ctx, - bool leaf_field_nullable) = 0; + bool leafFieldNullable) = 0; }; // API to write values to a single column. This is the main client facing API. template class TypedColumnWriter : public ColumnWriter { public: - using T = typename DType::c_type; + using T = typename DType::CType; // Write a batch of repetition levels, definition levels, and values to the // column. - // `num_values` is the number of logical leaf values. - // `def_levels` (resp. `rep_levels`) can be null if the column's max + // 'numValues' is the number of logical leaf values. + // `defLevels` (resp. `repLevels`) can be null if the column's max // definition level (resp. max repetition level) is 0. If not null, each of - // `def_levels` and `rep_levels` must have at least `num_values`. + // `defLevels` and `repLevels` must have at least `numValues`. // // The number of physical values written (taken from `values`) is returned. - // It can be smaller than `num_values` is there are some undefined values. - virtual int64_t WriteBatch( - int64_t num_values, - const int16_t* def_levels, - const int16_t* rep_levels, + // It can be smaller than `numValues` if there are some undefined values. + virtual int64_t writeBatch( + int64_t numValues, + const int16_t* defLevels, + const int16_t* repLevels, const T* values) = 0; /// Write a batch of repetition levels, definition levels, and values to the /// column. /// - /// In comparison to WriteBatch the length of repetition and definition levels - /// is the same as of the number of values read for max_definition_level == 1. - /// In the case of max_definition_level > 1, the repetition and definition - /// levels are larger than the values but the values include the null entries - /// with definition_level == (max_definition_level - 1). Thus we have to - /// differentiate in the parameters of this function if the input has the - /// length of num_values or the _number of rows in the lowest nesting level_. + /// In comparison to writeBatch() the length of repetition and definition + /// levels is the same as of the number of values read for + /// maxDefinitionLevel == 1. In the case of maxDefinitionLevel > 1, the + /// repetition and definition levels are larger than the values but the values + /// include the null entries with definitionLevel == (maxDefinitionLevel - + /// 1). Thus we have to differentiate in the parameters of this function if + /// the input has the length of numValues or the _number of rows in the lowest + /// nesting level. /// /// In the case that the most inner node in the Parquet is required, the /// _number of rows in the lowest nesting level_ is equal to the number of /// non-null values. If the inner-most schema node is optional, the _number of /// rows in the lowest nesting level_ also includes all values with - /// definition_level == (max_definition_level - 1). + /// definitionLevel == (maxDefinitionLevel - 1). /// - /// @param num_values number of levels to write. - /// @param def_levels The Parquet definition levels, length is num_values - /// @param rep_levels The Parquet repetition levels, length is num_values - /// @param valid_bits Bitmap that indicates if the row is null on the lowest - /// nesting - /// level. The length is number of rows in the lowest nesting level. - /// @param valid_bits_offset The offset in bits of the valid_bits where the + /// @param numValues Number of levels to write. + /// @param defLevels The Parquet definition levels, length is numValues. + /// @param repLevels The Parquet repetition levels, length is numValues. + /// @param validBits Bitmap that indicates if the row is null on the lowest + /// nesting level. The length is number of rows in the lowest nesting level. + /// @param validBitsOffset The offset in bits of the validBits where the /// first relevant bit resides. /// @param values The values in the lowest nested level including /// spacing for nulls on the lowest levels; input has the length /// of the number of rows on the lowest nesting level. - virtual void WriteBatchSpaced( - int64_t num_values, - const int16_t* def_levels, - const int16_t* rep_levels, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + virtual void writeBatchSpaced( + int64_t numValues, + const int16_t* defLevels, + const int16_t* repLevels, + const uint8_t* validBits, + int64_t validBitsOffset, const T* values) = 0; // Estimated size of the values that are not written to a page yet - virtual int64_t EstimatedBufferedValueBytes() const = 0; + virtual int64_t estimatedBufferedValueBytes() const = 0; }; using BoolWriter = TypedColumnWriter; @@ -285,57 +285,57 @@ using FixedLenByteArrayWriter = TypedColumnWriter; namespace internal { -// Timestamp conversion constants +// Timestamp conversion constants. constexpr int64_t kJulianEpochOffsetDays = INT64_C(2440588); template -inline void ArrowTimestampToImpalaTimestamp( - const int64_t time, - Int96* impala_timestamp) { - int64_t julian_days = (time / UnitPerDay) + kJulianEpochOffsetDays; - (*impala_timestamp).value[2] = (uint32_t)julian_days; - - int64_t last_day_units = time % UnitPerDay; - auto last_day_nanos = last_day_units * NanosecondsPerUnit; - // impala_timestamp will be unaligned every other entry so do memcpy instead +inline void arrowTimestampToImpalaTimestamp( + const int64_t Time, + Int96* impalaTimestamp) { + int64_t julianDays = (Time / UnitPerDay) + kJulianEpochOffsetDays; + (*impalaTimestamp).value[2] = (uint32_t)julianDays; + + int64_t lastDayUnits = Time % UnitPerDay; + auto lastDayNanos = lastDayUnits * NanosecondsPerUnit; + // impalaTimestamp will be unaligned every other entry so do memcpy instead // of assign and reinterpret cast to avoid undefined behavior. - std::memcpy(impala_timestamp, &last_day_nanos, sizeof(int64_t)); + std::memcpy(impalaTimestamp, &lastDayNanos, sizeof(int64_t)); } constexpr int64_t kSecondsInNanos = INT64_C(1000000000); -inline void SecondsToImpalaTimestamp( +inline void secondsToImpalaTimestamp( const int64_t seconds, - Int96* impala_timestamp) { - ArrowTimestampToImpalaTimestamp( - seconds, impala_timestamp); + Int96* impalaTimestamp) { + arrowTimestampToImpalaTimestamp( + seconds, impalaTimestamp); } constexpr int64_t kMillisecondsInNanos = kSecondsInNanos / INT64_C(1000); -inline void MillisecondsToImpalaTimestamp( +inline void millisecondsToImpalaTimestamp( const int64_t milliseconds, - Int96* impala_timestamp) { - ArrowTimestampToImpalaTimestamp( - milliseconds, impala_timestamp); + Int96* impalaTimestamp) { + arrowTimestampToImpalaTimestamp( + milliseconds, impalaTimestamp); } constexpr int64_t kMicrosecondsInNanos = kMillisecondsInNanos / INT64_C(1000); -inline void MicrosecondsToImpalaTimestamp( +inline void microsecondsToImpalaTimestamp( const int64_t microseconds, - Int96* impala_timestamp) { - ArrowTimestampToImpalaTimestamp( - microseconds, impala_timestamp); + Int96* impalaTimestamp) { + arrowTimestampToImpalaTimestamp( + microseconds, impalaTimestamp); } constexpr int64_t kNanosecondsInNanos = INT64_C(1); -inline void NanosecondsToImpalaTimestamp( +inline void nanosecondsToImpalaTimestamp( const int64_t nanoseconds, - Int96* impala_timestamp) { - ArrowTimestampToImpalaTimestamp( - nanoseconds, impala_timestamp); + Int96* impalaTimestamp) { + arrowTimestampToImpalaTimestamp( + nanoseconds, impalaTimestamp); } } // namespace internal diff --git a/velox/dwio/parquet/writer/arrow/Encoding.cpp b/velox/dwio/parquet/writer/arrow/Encoding.cpp index 996d6455d06..f907b3e86be 100644 --- a/velox/dwio/parquet/writer/arrow/Encoding.cpp +++ b/velox/dwio/parquet/writer/arrow/Encoding.cpp @@ -40,7 +40,6 @@ #include "arrow/util/bitmap_ops.h" #include "arrow/util/bitmap_writer.h" #include "arrow/util/checked_cast.h" -#include "arrow/util/logging.h" #include "arrow/util/ubsan.h" #include "arrow/visit_data_inline.h" @@ -59,9 +58,9 @@ using ::arrow::MemoryPool; using ::arrow::ResizableBuffer; using arrow::Status; using arrow::VisitNullBitmapInline; -using arrow::internal::AddWithOverflow; +using arrow::internal::addWithOverflow; using arrow::internal::checked_cast; -using arrow::internal::MultiplyWithOverflow; +using arrow::internal::multiplyWithOverflow; using arrow::internal::SubtractWithOverflow; using std::string_view; @@ -80,14 +79,14 @@ inline std::enable_if_t, T> SafeLoadAs( } template -inline std::enable_if_t, T> SafeLoad( +inline std::enable_if_t, T> safeLoad( const T* unaligned) { std::remove_const_t ret; std::memcpy(&ret, unaligned, sizeof(T)); return ret; } -std::shared_ptr AllocateBuffer( +std::shared_ptr allocateBuffer( MemoryPool* pool, int64_t size) { PARQUET_ASSIGN_OR_THROW( @@ -108,93 +107,91 @@ class EncoderImpl : virtual public Encoder { : descr_(descr), encoding_(encoding), pool_(pool), - type_length_(descr ? descr->type_length() : -1) {} + typeLength_(descr ? descr->typeLength() : -1) {} Encoding::type encoding() const override { return encoding_; } - MemoryPool* memory_pool() const override { + MemoryPool* memoryPool() const override { return pool_; } protected: - // For accessing type-specific metadata, like FIXED_LEN_BYTE_ARRAY + // For accessing type-specific metadata, like FIXED_LEN_BYTE_ARRAY. const ColumnDescriptor* descr_; const Encoding::type encoding_; MemoryPool* pool_; - /// Type length from descr - int type_length_; + /// Type length from descr. + int typeLength_; }; -// ---------------------------------------------------------------------- -// Plain encoder implementation +// ----------------------------------------------------------------------. +// Plain encoder implementation. template class PlainEncoder : public EncoderImpl, virtual public TypedEncoder { public: - using T = typename DType::c_type; + using T = typename DType::CType; explicit PlainEncoder(const ColumnDescriptor* descr, MemoryPool* pool) - : EncoderImpl(descr, Encoding::PLAIN, pool), sink_(pool) {} + : EncoderImpl(descr, Encoding::kPlain, pool), sink_(pool) {} - int64_t EstimatedDataEncodedSize() override { + int64_t estimatedDataEncodedSize() override { return sink_.length(); } - std::shared_ptr<::arrow::Buffer> FlushValues() override { + std::shared_ptr<::arrow::Buffer> flushValues() override { std::shared_ptr buffer; PARQUET_THROW_NOT_OK(sink_.Finish(&buffer)); return buffer; } - using TypedEncoder::Put; + using TypedEncoder::put; - void Put(const T* buffer, int num_values) override; + void put(const T* buffer, int numValues) override; - void Put(const ::arrow::Array& values) override; + void put(const ::arrow::Array& values) override; - void PutSpaced( + void putSpaced( const T* src, - int num_values, - const uint8_t* valid_bits, - int64_t valid_bits_offset) override { - if (valid_bits != NULLPTR) { - PARQUET_ASSIGN_OR_THROW( - auto buffer, - ::arrow::AllocateBuffer(num_values * sizeof(T), this->memory_pool())); + int numValues, + const uint8_t* validBits, + int64_t validBitsOffset) override { + if (validBits != NULLPTR) { + auto buffer = allocateBuffer(this->memoryPool(), numValues * sizeof(T)); T* data = reinterpret_cast(buffer->mutable_data()); - int num_valid_values = ::arrow::util::internal::SpacedCompress( - src, num_values, valid_bits, valid_bits_offset, data); - Put(data, num_valid_values); + int numValidValues = ::arrow::util::internal::SpacedCompress( + src, numValues, validBits, validBitsOffset, data); + put(data, numValidValues); } else { - Put(src, num_values); + put(src, numValues); } } - void UnsafePutByteArray(const void* data, uint32_t length) { + void unsafePutByteArray(const void* data, uint32_t length) { VELOX_DCHECK(length == 0 || data != nullptr, "Value ptr cannot be NULL"); sink_.UnsafeAppend(&length, sizeof(uint32_t)); sink_.UnsafeAppend(data, static_cast(length)); } - void Put(const ByteArray& val) { - // Write the result to the output stream + void put(const ByteArray& val) { + // Write the result to the output stream. const int64_t increment = static_cast(val.len + sizeof(uint32_t)); if (ARROW_PREDICT_FALSE(sink_.length() + increment > sink_.capacity())) { PARQUET_THROW_NOT_OK(sink_.Reserve(increment)); } - UnsafePutByteArray(val.ptr, val.len); + unsafePutByteArray(val.ptr, val.len); } protected: template - void PutBinaryArray(const ArrayType& array) { - const int64_t total_bytes = + void putBinaryArray(const ArrayType& array) { + const int64_t totalBytes = array.value_offset(array.length()) - array.value_offset(0); PARQUET_THROW_NOT_OK( - sink_.Reserve(total_bytes + array.length() * sizeof(uint32_t))); + sink_.Reserve(totalBytes + array.length() * sizeof(uint32_t))); PARQUET_THROW_NOT_OK( ::arrow::VisitArraySpanInline( @@ -204,7 +201,7 @@ class PlainEncoder : public EncoderImpl, virtual public TypedEncoder { return Status::Invalid( "Parquet cannot store strings with size 2GB or more"); } - UnsafePutByteArray( + unsafePutByteArray( view.data(), static_cast(view.size())); return Status::OK(); }, @@ -215,125 +212,125 @@ class PlainEncoder : public EncoderImpl, virtual public TypedEncoder { }; template -void PlainEncoder::Put(const T* buffer, int num_values) { - if (num_values > 0) { - PARQUET_THROW_NOT_OK(sink_.Append(buffer, num_values * sizeof(T))); +void PlainEncoder::put(const T* buffer, int numValues) { + if (numValues > 0) { + PARQUET_THROW_NOT_OK(sink_.Append(buffer, numValues * sizeof(T))); } } template <> -inline void PlainEncoder::Put( +inline void PlainEncoder::put( const ByteArray* src, - int num_values) { - for (int i = 0; i < num_values; ++i) { - Put(src[i]); + int numValues) { + for (int i = 0; i < numValues; ++i) { + put(src[i]); } } template -void DirectPutImpl(const ::arrow::Array& values, ::arrow::BufferBuilder* sink) { +void directPutImpl(const ::arrow::Array& values, ::arrow::BufferBuilder* sink) { if (values.type_id() != ArrayType::TypeClass::type_id) { - std::string type_name = ArrayType::TypeClass::type_name(); + std::string typeName = ArrayType::TypeClass::type_name(); throw ParquetException( - "direct put to " + type_name + " from " + values.type()->ToString() + + "direct put to " + typeName + " from " + values.type()->ToString() + " not supported"); } - using value_type = typename ArrayType::value_type; - constexpr auto value_size = sizeof(value_type); - auto raw_values = checked_cast(values).raw_values(); + using ValueType = typename ArrayType::value_type; + constexpr auto valueSize = sizeof(ValueType); + auto rawValues = checked_cast(values).raw_values(); if (values.null_count() == 0) { - // no nulls, just dump the data - PARQUET_THROW_NOT_OK( - sink->Append(raw_values, values.length() * value_size)); + // No nulls, just dump the data. + PARQUET_THROW_NOT_OK(sink->Append(rawValues, values.length() * valueSize)); } else { PARQUET_THROW_NOT_OK( - sink->Reserve((values.length() - values.null_count()) * value_size)); + sink->Reserve((values.length() - values.null_count()) * valueSize)); for (int64_t i = 0; i < values.length(); i++) { if (values.IsValid(i)) { - sink->UnsafeAppend(&raw_values[i], value_size); + sink->UnsafeAppend(&rawValues[i], valueSize); } } } } template <> -void PlainEncoder::Put(const ::arrow::Array& values) { - DirectPutImpl<::arrow::Int32Array>(values, &sink_); +void PlainEncoder::put(const ::arrow::Array& values) { + directPutImpl<::arrow::Int32Array>(values, &sink_); } template <> -void PlainEncoder::Put(const ::arrow::Array& values) { - DirectPutImpl<::arrow::Int64Array>(values, &sink_); +void PlainEncoder::put(const ::arrow::Array& values) { + directPutImpl<::arrow::Int64Array>(values, &sink_); } template <> -void PlainEncoder::Put(const ::arrow::Array& values) { +void PlainEncoder::put(const ::arrow::Array& values) { ParquetException::NYI("direct put to Int96"); } template <> -void PlainEncoder::Put(const ::arrow::Array& values) { - DirectPutImpl<::arrow::FloatArray>(values, &sink_); +void PlainEncoder::put(const ::arrow::Array& values) { + directPutImpl<::arrow::FloatArray>(values, &sink_); } template <> -void PlainEncoder::Put(const ::arrow::Array& values) { - DirectPutImpl<::arrow::DoubleArray>(values, &sink_); +void PlainEncoder::put(const ::arrow::Array& values) { + directPutImpl<::arrow::DoubleArray>(values, &sink_); } template -void PlainEncoder::Put(const ::arrow::Array& values) { +void PlainEncoder::put(const ::arrow::Array& values) { ParquetException::NYI("direct put of " + values.type()->ToString()); } -void AssertBaseBinary(const ::arrow::Array& values) { +void assertBaseBinary(const ::arrow::Array& values) { if (!::arrow::is_base_binary_like(values.type_id())) { throw ParquetException("Only BaseBinaryArray and subclasses supported"); } } template <> -inline void PlainEncoder::Put(const ::arrow::Array& values) { - AssertBaseBinary(values); +inline void PlainEncoder::put(const ::arrow::Array& values) { + assertBaseBinary(values); if (::arrow::is_binary_like(values.type_id())) { - PutBinaryArray(checked_cast(values)); + putBinaryArray(checked_cast(values)); } else { VELOX_DCHECK(::arrow::is_large_binary_like(values.type_id())); - PutBinaryArray(checked_cast(values)); + putBinaryArray(checked_cast(values)); } } -void AssertFixedSizeBinary(const ::arrow::Array& values, int type_length) { +void assertFixedSizeBinary(const ::arrow::Array& values, int typeLength) { if (values.type_id() != ::arrow::Type::FIXED_SIZE_BINARY && - values.type_id() != ::arrow::Type::DECIMAL) { + values.type_id() != ::arrow::Type::DECIMAL128 && + values.type_id() != ::arrow::Type::DECIMAL256) { throw ParquetException( "Only FixedSizeBinaryArray and subclasses supported"); } if (checked_cast(*values.type()) - .byte_width() != type_length) { + .byte_width() != typeLength) { throw ParquetException( "Size mismatch: " + values.type()->ToString() + " should have been " + - std::to_string(type_length) + " wide"); + std::to_string(typeLength) + " wide"); } } template <> -inline void PlainEncoder::Put(const ::arrow::Array& values) { - AssertFixedSizeBinary(values, descr_->type_length()); +inline void PlainEncoder::put(const ::arrow::Array& values) { + assertFixedSizeBinary(values, descr_->typeLength()); const auto& data = checked_cast(values); if (data.null_count() == 0) { - // no nulls, just dump the data + // No nulls, just dump the data. PARQUET_THROW_NOT_OK( sink_.Append(data.raw_values(), data.length() * data.byte_width())); } else { - const int64_t total_bytes = data.length() * data.byte_width() - + const int64_t totalBytes = data.length() * data.byte_width() - data.null_count() * data.byte_width(); - PARQUET_THROW_NOT_OK(sink_.Reserve(total_bytes)); + PARQUET_THROW_NOT_OK(sink_.Reserve(totalBytes)); for (int64_t i = 0; i < data.length(); i++) { if (data.IsValid(i)) { sink_.UnsafeAppend(data.Value(i), data.byte_width()); @@ -343,16 +340,16 @@ inline void PlainEncoder::Put(const ::arrow::Array& values) { } template <> -inline void PlainEncoder::Put( +inline void PlainEncoder::put( const FixedLenByteArray* src, - int num_values) { - if (descr_->type_length() == 0) { + int numValues) { + if (descr_->typeLength() == 0) { return; } - for (int i = 0; i < num_values; ++i) { - // Write the result to the output stream + for (int i = 0; i < numValues; ++i) { + // Write the result to the output stream. VELOX_DCHECK(src[i].ptr != nullptr, "Value ptr cannot be NULL"); - PARQUET_THROW_NOT_OK(sink_.Append(src[i].ptr, descr_->type_length())); + PARQUET_THROW_NOT_OK(sink_.Append(src[i].ptr, descr_->typeLength())); } } @@ -361,34 +358,32 @@ class PlainEncoder : public EncoderImpl, virtual public BooleanEncoder { public: explicit PlainEncoder(const ColumnDescriptor* descr, MemoryPool* pool) - : EncoderImpl(descr, Encoding::PLAIN, pool), sink_(pool) {} + : EncoderImpl(descr, Encoding::kPlain, pool), sink_(pool) {} - int64_t EstimatedDataEncodedSize() override; - std::shared_ptr<::arrow::Buffer> FlushValues() override; + int64_t estimatedDataEncodedSize() override; + std::shared_ptr<::arrow::Buffer> flushValues() override; - void Put(const bool* src, int num_values) override; + void put(const bool* src, int numValues) override; - void Put(const std::vector& src, int num_values) override; + void put(const std::vector& src, int numValues) override; - void PutSpaced( + void putSpaced( const bool* src, - int num_values, - const uint8_t* valid_bits, - int64_t valid_bits_offset) override { - if (valid_bits != NULLPTR) { - PARQUET_ASSIGN_OR_THROW( - auto buffer, - ::arrow::AllocateBuffer(num_values * sizeof(T), this->memory_pool())); + int numValues, + const uint8_t* validBits, + int64_t validBitsOffset) override { + if (validBits != NULLPTR) { + auto buffer = allocateBuffer(this->memoryPool(), numValues * sizeof(T)); T* data = reinterpret_cast(buffer->mutable_data()); - int num_valid_values = ::arrow::util::internal::SpacedCompress( - src, num_values, valid_bits, valid_bits_offset, data); - Put(data, num_valid_values); + int numValidValues = ::arrow::util::internal::SpacedCompress( + src, numValues, validBits, validBitsOffset, data); + put(data, numValidValues); } else { - Put(src, num_values); + put(src, numValues); } } - void Put(const ::arrow::Array& values) override { + void put(const ::arrow::Array& values) override { if (values.type_id() != ::arrow::Type::BOOL) { throw ParquetException( "direct put to boolean from " + values.type()->ToString() + @@ -397,7 +392,7 @@ class PlainEncoder : public EncoderImpl, const auto& data = checked_cast(values); if (data.null_count() == 0) { - // no nulls, just dump the data + // No nulls, just dump the data. PARQUET_THROW_NOT_OK(sink_.Reserve(data.length())); sink_.UnsafeAppend( data.data()->GetValues(1, 0), data.offset(), data.length()); @@ -415,46 +410,46 @@ class PlainEncoder : public EncoderImpl, ::arrow::TypedBufferBuilder sink_; template - void PutImpl(const SequenceType& src, int num_values); + void putImpl(const SequenceType& src, int numValues); }; template -void PlainEncoder::PutImpl( +void PlainEncoder::putImpl( const SequenceType& src, - int num_values) { - PARQUET_THROW_NOT_OK(sink_.Reserve(num_values)); - for (int i = 0; i < num_values; ++i) { + int numValues) { + PARQUET_THROW_NOT_OK(sink_.Reserve(numValues)); + for (int i = 0; i < numValues; ++i) { sink_.UnsafeAppend(src[i]); } } -int64_t PlainEncoder::EstimatedDataEncodedSize() { +int64_t PlainEncoder::estimatedDataEncodedSize() { return ::arrow::bit_util::BytesForBits(sink_.length()); } -std::shared_ptr<::arrow::Buffer> PlainEncoder::FlushValues() { +std::shared_ptr<::arrow::Buffer> PlainEncoder::flushValues() { std::shared_ptr buffer; PARQUET_THROW_NOT_OK(sink_.Finish(&buffer)); return buffer; } -void PlainEncoder::Put(const bool* src, int num_values) { - PutImpl(src, num_values); +void PlainEncoder::put(const bool* src, int numValues) { + putImpl(src, numValues); } -void PlainEncoder::Put( +void PlainEncoder::put( const std::vector& src, - int num_values) { - PutImpl(src, num_values); + int numValues) { + putImpl(src, numValues); } -// ---------------------------------------------------------------------- -// DictEncoder implementations +// ----------------------------------------------------------------------. +// DictEncoder implementations. template struct DictEncoderTraits { - using c_type = typename DType::c_type; - using MemoTableType = arrow::internal::ScalarMemoTable; + using CType = typename DType::CType; + using MemoTableType = arrow::internal::ScalarMemoTable; }; template <> @@ -469,16 +464,16 @@ struct DictEncoderTraits { arrow::internal::BinaryMemoTable<::arrow::BinaryBuilder>; }; -// Initially 1024 elements +// Initially 1024 elements. static constexpr int32_t kInitialHashTableSize = 1 << 10; -int RlePreserveBufferSize(int num_values, int bit_width) { +int rlePreserveBufferSize(int numValues, int bitWidth) { // Note: because of the way RleEncoder::CheckBufferFull() - // is called, we have to reserve an extra "RleEncoder::MinBufferSize" + // is called, we have to Reserve an extra "RleEncoder::MinBufferSize" // bytes. These extra bytes won't be used but not reserving them // would cause the encoder to fail. - return RleEncoder::MaxBufferSize(bit_width, num_values) + - RleEncoder::MinBufferSize(bit_width); + return RleEncoder::MaxBufferSize(bitWidth, numValues) + + RleEncoder::MinBufferSize(bitWidth); } /// See the dictionary encoding section of @@ -493,103 +488,103 @@ class DictEncoderImpl : public EncoderImpl, virtual public DictEncoder { using MemoTableType = typename DictEncoderTraits::MemoTableType; public: - typedef typename DType::c_type T; + typedef typename DType::CType T; /// In data page, the bit width used to encode the entry /// ids stored as 1 byte (max bit width = 32). constexpr static int32_t kDataPageBitWidthBytes = 1; explicit DictEncoderImpl(const ColumnDescriptor* desc, MemoryPool* pool) - : EncoderImpl(desc, Encoding::PLAIN_DICTIONARY, pool), - buffered_indices_(::arrow::stl::allocator(pool)), - dict_encoded_size_(0), - memo_table_(pool, kInitialHashTableSize) {} + : EncoderImpl(desc, Encoding::kPlainDictionary, pool), + bufferedIndices_(::arrow::stl::allocator(pool)), + dictEncodedSize_(0), + memoTable_(pool, kInitialHashTableSize) {} ~DictEncoderImpl() = default; - int dict_encoded_size() const override { - return dict_encoded_size_; + int dictEncodedSize() const override { + return dictEncodedSize_; } - int WriteIndices(uint8_t* buffer, int buffer_len) override { - // Write bit width in first byte - *buffer = static_cast(bit_width()); + int writeIndices(uint8_t* buffer, int bufferLen) override { + // Write bit width in first byte. + *buffer = static_cast(bitWidth()); ++buffer; - --buffer_len; + --bufferLen; - RleEncoder encoder(buffer, buffer_len, bit_width()); + RleEncoder encoder(buffer, bufferLen, bitWidth()); - for (int32_t index : buffered_indices_) { + for (int32_t index : bufferedIndices_) { if (ARROW_PREDICT_FALSE(!encoder.Put(index))) return -1; } encoder.Flush(); - ClearIndices(); + clearIndices(); return kDataPageBitWidthBytes + encoder.len(); } - void set_type_length(int type_length) { - this->type_length_ = type_length; + void setTypeLength(int typeLength) { + this->typeLength_ = typeLength; } /// Returns a conservative estimate of the number of bytes needed to encode /// the buffered indices. Used to size the buffer passed to WriteIndices(). - int64_t EstimatedDataEncodedSize() override { + int64_t estimatedDataEncodedSize() override { return kDataPageBitWidthBytes + - RlePreserveBufferSize( - static_cast(buffered_indices_.size()), bit_width()); + rlePreserveBufferSize( + static_cast(bufferedIndices_.size()), bitWidth()); } /// The minimum bit width required to encode the currently buffered indices. - int bit_width() const override { - if (ARROW_PREDICT_FALSE(num_entries() == 0)) + int bitWidth() const override { + if (ARROW_PREDICT_FALSE(numEntries() == 0)) return 0; - if (ARROW_PREDICT_FALSE(num_entries() == 1)) + if (ARROW_PREDICT_FALSE(numEntries() == 1)) return 1; - return ::arrow::bit_util::Log2(num_entries()); + return ::arrow::bit_util::Log2(numEntries()); } /// Encode value. Note that this does not actually write any data, just /// buffers the value's index to be written later. - inline void Put(const T& value); + inline void put(const T& value); - // Not implemented for other data types - inline void PutByteArray(const void* ptr, int32_t length); + // Not implemented for other data types. + inline void putByteArray(const void* ptr, int32_t length); - void Put(const T* src, int num_values) override { - for (int32_t i = 0; i < num_values; i++) { - Put(SafeLoad(src + i)); + void put(const T* src, int numValues) override { + for (int32_t i = 0; i < numValues; i++) { + put(safeLoad(src + i)); } } - void PutSpaced( + void putSpaced( const T* src, - int num_values, - const uint8_t* valid_bits, - int64_t valid_bits_offset) override { + int numValues, + const uint8_t* validBits, + int64_t validBitsOffset) override { ::arrow::internal::VisitSetBitRunsVoid( - valid_bits, - valid_bits_offset, - num_values, + validBits, + validBitsOffset, + numValues, [&](int64_t position, int64_t length) { for (int64_t i = 0; i < length; i++) { - Put(SafeLoad(src + i + position)); + put(safeLoad(src + i + position)); } }); } - using TypedEncoder::Put; + using TypedEncoder::put; - void Put(const ::arrow::Array& values) override; - void PutDictionary(const ::arrow::Array& values) override; + void put(const ::arrow::Array& values) override; + void putDictionary(const ::arrow::Array& values) override; template - void PutIndicesTyped(const ::arrow::Array& data) { + void putIndicesTyped(const ::arrow::Array& data) { auto values = data.data()->GetValues(1); - size_t buffer_position = buffered_indices_.size(); - buffered_indices_.resize( - buffer_position + + size_t bufferPosition = bufferedIndices_.size(); + bufferedIndices_.resize( + bufferPosition + static_cast(data.length() - data.null_count())); ::arrow::internal::VisitSetBitRunsVoid( data.null_bitmap_data(), @@ -597,60 +592,60 @@ class DictEncoderImpl : public EncoderImpl, virtual public DictEncoder { data.length(), [&](int64_t position, int64_t length) { for (int64_t i = 0; i < length; ++i) { - buffered_indices_[buffer_position++] = + bufferedIndices_[bufferPosition++] = static_cast(values[i + position]); } }); } - void PutIndices(const ::arrow::Array& data) override { + void putIndices(const ::arrow::Array& data) override { switch (data.type()->id()) { case ::arrow::Type::UINT8: case ::arrow::Type::INT8: - return PutIndicesTyped<::arrow::UInt8Type>(data); + return putIndicesTyped<::arrow::UInt8Type>(data); case ::arrow::Type::UINT16: case ::arrow::Type::INT16: - return PutIndicesTyped<::arrow::UInt16Type>(data); + return putIndicesTyped<::arrow::UInt16Type>(data); case ::arrow::Type::UINT32: case ::arrow::Type::INT32: - return PutIndicesTyped<::arrow::UInt32Type>(data); + return putIndicesTyped<::arrow::UInt32Type>(data); case ::arrow::Type::UINT64: case ::arrow::Type::INT64: - return PutIndicesTyped<::arrow::UInt64Type>(data); + return putIndicesTyped<::arrow::UInt64Type>(data); default: throw ParquetException("Passed non-integer array to PutIndices"); } } - std::shared_ptr<::arrow::Buffer> FlushValues() override { + std::shared_ptr<::arrow::Buffer> flushValues() override { std::shared_ptr buffer = - AllocateBuffer(this->pool_, EstimatedDataEncodedSize()); - int result_size = WriteIndices( - buffer->mutable_data(), static_cast(EstimatedDataEncodedSize())); - PARQUET_THROW_NOT_OK(buffer->Resize(result_size, false)); + allocateBuffer(this->pool_, estimatedDataEncodedSize()); + int resultSize = writeIndices( + buffer->mutable_data(), static_cast(estimatedDataEncodedSize())); + PARQUET_THROW_NOT_OK(buffer->Resize(resultSize, false)); return std::move(buffer); } /// Writes out the encoded dictionary to buffer. buffer must be preallocated /// to dict_encoded_size() bytes. - void WriteDict(uint8_t* buffer) const override; + void writeDict(uint8_t* buffer) const override; /// The number of entries in the dictionary. - int num_entries() const override { - return memo_table_.size(); + int numEntries() const override { + return memoTable_.size(); } private: /// Clears all the indices (but leaves the dictionary). - void ClearIndices() { - buffered_indices_.clear(); + void clearIndices() { + bufferedIndices_.clear(); } /// Indices that have not yet be written out by WriteIndices(). - ArrowPoolVector buffered_indices_; + ArrowPoolVector bufferedIndices_; template - void PutBinaryArray(const ArrayType& array) { + void putBinaryArray(const ArrayType& array) { PARQUET_THROW_NOT_OK( ::arrow::VisitArraySpanInline( *array.data(), @@ -659,14 +654,14 @@ class DictEncoderImpl : public EncoderImpl, virtual public DictEncoder { return Status::Invalid( "Parquet cannot store strings with size 2GB or more"); } - PutByteArray(view.data(), static_cast(view.size())); + putByteArray(view.data(), static_cast(view.size())); return Status::OK(); }, []() { return Status::OK(); })); } template - void PutBinaryDictionaryArray(const ArrayType& array) { + void putBinaryDictionaryArray(const ArrayType& array) { VELOX_DCHECK_EQ(array.null_count(), 0); for (int64_t i = 0; i < array.length(); i++) { auto v = array.GetView(i); @@ -674,31 +669,31 @@ class DictEncoderImpl : public EncoderImpl, virtual public DictEncoder { throw ParquetException( "Parquet cannot store strings with size 2GB or more"); } - dict_encoded_size_ += static_cast(v.size() + sizeof(uint32_t)); - int32_t unused_memo_index; - PARQUET_THROW_NOT_OK(memo_table_.GetOrInsert( - v.data(), static_cast(v.size()), &unused_memo_index)); + dictEncodedSize_ += static_cast(v.size() + sizeof(uint32_t)); + int32_t unusedMemoIndex; + PARQUET_THROW_NOT_OK(memoTable_.getOrInsert( + v.data(), static_cast(v.size()), &unusedMemoIndex)); } } /// The number of bytes needed to encode the dictionary. - int dict_encoded_size_; + int dictEncodedSize_; - MemoTableType memo_table_; + MemoTableType memoTable_; }; template -void DictEncoderImpl::WriteDict(uint8_t* buffer) const { - // For primitive types, only a memcpy +void DictEncoderImpl::writeDict(uint8_t* buffer) const { + // For primitive types, only a memcpy. VELOX_DCHECK_EQ( - static_cast(dict_encoded_size_), sizeof(T) * memo_table_.size()); - memo_table_.CopyValues(0 /* start_pos */, reinterpret_cast(buffer)); + static_cast(dictEncodedSize_), sizeof(T) * memoTable_.size()); + memoTable_.copyValues(0 /* start_pos */, reinterpret_cast(buffer)); } -// ByteArray and FLBA already have the dictionary encoded in their data heaps +// ByteArray and FLBA already have the dictionary encoded in their data heaps. template <> -void DictEncoderImpl::WriteDict(uint8_t* buffer) const { - memo_table_.VisitValues(0, [&buffer](std::string_view v) { +void DictEncoderImpl::writeDict(uint8_t* buffer) const { + memoTable_.visitValues(0, [&buffer](std::string_view v) { uint32_t len = static_cast(v.length()); memcpy(buffer, &len, sizeof(len)); buffer += sizeof(len); @@ -708,231 +703,231 @@ void DictEncoderImpl::WriteDict(uint8_t* buffer) const { } template <> -void DictEncoderImpl::WriteDict(uint8_t* buffer) const { - memo_table_.VisitValues(0, [&](std::string_view v) { - VELOX_DCHECK_EQ(v.length(), static_cast(type_length_)); - memcpy(buffer, v.data(), type_length_); - buffer += type_length_; +void DictEncoderImpl::writeDict(uint8_t* buffer) const { + memoTable_.visitValues(0, [&](std::string_view v) { + VELOX_DCHECK_EQ(v.length(), static_cast(typeLength_)); + memcpy(buffer, v.data(), typeLength_); + buffer += typeLength_; }); } template -inline void DictEncoderImpl::Put(const T& v) { - // Put() implementation for primitive types - auto on_found = [](int32_t memo_index) {}; - auto on_not_found = [this](int32_t memo_index) { - dict_encoded_size_ += static_cast(sizeof(T)); +inline void DictEncoderImpl::put(const T& v) { + // Put() implementation for primitive types. + auto onFound = [](int32_t memoIndex) {}; + auto onNotFound = [this](int32_t memoIndex) { + dictEncodedSize_ += static_cast(sizeof(T)); }; - int32_t memo_index; + int32_t memoIndex; PARQUET_THROW_NOT_OK( - memo_table_.GetOrInsert(v, on_found, on_not_found, &memo_index)); - buffered_indices_.push_back(memo_index); + memoTable_.getOrInsert(v, onFound, onNotFound, &memoIndex)); + bufferedIndices_.push_back(memoIndex); } template -inline void DictEncoderImpl::PutByteArray( +inline void DictEncoderImpl::putByteArray( const void* ptr, int32_t length) { VELOX_DCHECK(false); } template <> -inline void DictEncoderImpl::PutByteArray( +inline void DictEncoderImpl::putByteArray( const void* ptr, int32_t length) { static const uint8_t empty[] = {0}; - auto on_found = [](int32_t memo_index) {}; - auto on_not_found = [&](int32_t memo_index) { - dict_encoded_size_ += static_cast(length + sizeof(uint32_t)); + auto onFound = [](int32_t memoIndex) {}; + auto onNotFound = [&](int32_t memoIndex) { + dictEncodedSize_ += static_cast(length + sizeof(uint32_t)); }; VELOX_DCHECK(ptr != nullptr || length == 0); ptr = (ptr != nullptr) ? ptr : empty; - int32_t memo_index; - PARQUET_THROW_NOT_OK(memo_table_.GetOrInsert( - ptr, length, on_found, on_not_found, &memo_index)); - buffered_indices_.push_back(memo_index); + int32_t memoIndex; + PARQUET_THROW_NOT_OK( + memoTable_.getOrInsert(ptr, length, onFound, onNotFound, &memoIndex)); + bufferedIndices_.push_back(memoIndex); } template <> -inline void DictEncoderImpl::Put(const ByteArray& val) { - return PutByteArray(val.ptr, static_cast(val.len)); +inline void DictEncoderImpl::put(const ByteArray& val) { + return putByteArray(val.ptr, static_cast(val.len)); } template <> -inline void DictEncoderImpl::Put(const FixedLenByteArray& v) { +inline void DictEncoderImpl::put(const FixedLenByteArray& v) { static const uint8_t empty[] = {0}; - auto on_found = [](int32_t memo_index) {}; - auto on_not_found = [this](int32_t memo_index) { - dict_encoded_size_ += type_length_; + auto onFound = [](int32_t memoIndex) {}; + auto onNotFound = [this](int32_t memoIndex) { + dictEncodedSize_ += typeLength_; }; - VELOX_DCHECK(v.ptr != nullptr || type_length_ == 0); + VELOX_DCHECK(v.ptr != nullptr || typeLength_ == 0); const void* ptr = (v.ptr != nullptr) ? v.ptr : empty; - int32_t memo_index; - PARQUET_THROW_NOT_OK(memo_table_.GetOrInsert( - ptr, type_length_, on_found, on_not_found, &memo_index)); - buffered_indices_.push_back(memo_index); + int32_t memoIndex; + PARQUET_THROW_NOT_OK(memoTable_.getOrInsert( + ptr, typeLength_, onFound, onNotFound, &memoIndex)); + bufferedIndices_.push_back(memoIndex); } template <> -void DictEncoderImpl::Put(const ::arrow::Array& values) { +void DictEncoderImpl::put(const ::arrow::Array& values) { ParquetException::NYI("Direct put to Int96"); } template <> -void DictEncoderImpl::PutDictionary(const ::arrow::Array& values) { +void DictEncoderImpl::putDictionary(const ::arrow::Array& values) { ParquetException::NYI("Direct put to Int96"); } template -void DictEncoderImpl::Put(const ::arrow::Array& values) { +void DictEncoderImpl::put(const ::arrow::Array& values) { using ArrayType = - typename ::arrow::CTypeTraits::ArrayType; + typename ::arrow::CTypeTraits::ArrayType; const auto& data = checked_cast(values); if (data.null_count() == 0) { - // no nulls, just dump the data + // No nulls, just dump the data. for (int64_t i = 0; i < data.length(); i++) { - Put(data.Value(i)); + put(data.Value(i)); } } else { for (int64_t i = 0; i < data.length(); i++) { if (data.IsValid(i)) { - Put(data.Value(i)); + put(data.Value(i)); } } } } template <> -void DictEncoderImpl::Put(const ::arrow::Array& values) { - AssertFixedSizeBinary(values, type_length_); +void DictEncoderImpl::put(const ::arrow::Array& values) { + assertFixedSizeBinary(values, typeLength_); const auto& data = checked_cast(values); if (data.null_count() == 0) { - // no nulls, just dump the data + // No nulls, just dump the data. for (int64_t i = 0; i < data.length(); i++) { - Put(FixedLenByteArray(data.Value(i))); + put(FixedLenByteArray(data.Value(i))); } } else { - std::vector empty(type_length_, 0); + std::vector empty(typeLength_, 0); for (int64_t i = 0; i < data.length(); i++) { if (data.IsValid(i)) { - Put(FixedLenByteArray(data.Value(i))); + put(FixedLenByteArray(data.Value(i))); } } } } template <> -void DictEncoderImpl::Put(const ::arrow::Array& values) { - AssertBaseBinary(values); +void DictEncoderImpl::put(const ::arrow::Array& values) { + assertBaseBinary(values); if (::arrow::is_binary_like(values.type_id())) { - PutBinaryArray(checked_cast(values)); + putBinaryArray(checked_cast(values)); } else { VELOX_DCHECK(::arrow::is_large_binary_like(values.type_id())); - PutBinaryArray(checked_cast(values)); + putBinaryArray(checked_cast(values)); } } template -void AssertCanPutDictionary( +void assertCanPutDictionary( DictEncoderImpl* encoder, const ::arrow::Array& dict) { if (dict.null_count() > 0) { throw ParquetException("Inserted dictionary cannot cannot contain nulls"); } - if (encoder->num_entries() > 0) { + if (encoder->numEntries() > 0) { throw ParquetException( "Can only call PutDictionary on an empty DictEncoder"); } } template -void DictEncoderImpl::PutDictionary(const ::arrow::Array& values) { - AssertCanPutDictionary(this, values); +void DictEncoderImpl::putDictionary(const ::arrow::Array& values) { + assertCanPutDictionary(this, values); using ArrayType = - typename ::arrow::CTypeTraits::ArrayType; + typename ::arrow::CTypeTraits::ArrayType; const auto& data = checked_cast(values); - dict_encoded_size_ += - static_cast(sizeof(typename DType::c_type) * data.length()); + dictEncodedSize_ += + static_cast(sizeof(typename DType::CType) * data.length()); for (int64_t i = 0; i < data.length(); i++) { - int32_t unused_memo_index; + int32_t unusedMemoIndex; PARQUET_THROW_NOT_OK( - memo_table_.GetOrInsert(data.Value(i), &unused_memo_index)); + memoTable_.getOrInsert(data.Value(i), &unusedMemoIndex)); } } template <> -void DictEncoderImpl::PutDictionary(const ::arrow::Array& values) { - AssertFixedSizeBinary(values, type_length_); - AssertCanPutDictionary(this, values); +void DictEncoderImpl::putDictionary(const ::arrow::Array& values) { + assertFixedSizeBinary(values, typeLength_); + assertCanPutDictionary(this, values); const auto& data = checked_cast(values); - dict_encoded_size_ += static_cast(type_length_ * data.length()); + dictEncodedSize_ += static_cast(typeLength_ * data.length()); for (int64_t i = 0; i < data.length(); i++) { - int32_t unused_memo_index; - PARQUET_THROW_NOT_OK(memo_table_.GetOrInsert( - data.Value(i), type_length_, &unused_memo_index)); + int32_t unusedMemoIndex; + PARQUET_THROW_NOT_OK( + memoTable_.getOrInsert(data.Value(i), typeLength_, &unusedMemoIndex)); } } template <> -void DictEncoderImpl::PutDictionary( +void DictEncoderImpl::putDictionary( const ::arrow::Array& values) { - AssertBaseBinary(values); - AssertCanPutDictionary(this, values); + assertBaseBinary(values); + assertCanPutDictionary(this, values); if (::arrow::is_binary_like(values.type_id())) { - PutBinaryDictionaryArray(checked_cast(values)); + putBinaryDictionaryArray(checked_cast(values)); } else { VELOX_DCHECK(::arrow::is_large_binary_like(values.type_id())); - PutBinaryDictionaryArray( + putBinaryDictionaryArray( checked_cast(values)); } } -// ---------------------------------------------------------------------- -// ByteStreamSplitEncoder implementations +// ----------------------------------------------------------------------. +// ByteStreamSplitEncoder implementations. template class ByteStreamSplitEncoder : public EncoderImpl, virtual public TypedEncoder { public: - using T = typename DType::c_type; - using TypedEncoder::Put; + using T = typename DType::CType; + using TypedEncoder::put; explicit ByteStreamSplitEncoder( const ColumnDescriptor* descr, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); - int64_t EstimatedDataEncodedSize() override; - std::shared_ptr<::arrow::Buffer> FlushValues() override; + int64_t estimatedDataEncodedSize() override; + std::shared_ptr<::arrow::Buffer> flushValues() override; - void Put(const T* buffer, int num_values) override; - void Put(const ::arrow::Array& values) override; - void PutSpaced( + void put(const T* buffer, int numValues) override; + void put(const ::arrow::Array& values) override; + void putSpaced( const T* src, - int num_values, - const uint8_t* valid_bits, - int64_t valid_bits_offset) override; + int numValues, + const uint8_t* validBits, + int64_t validBitsOffset) override; protected: template - void PutImpl(const ::arrow::Array& values) { + void putImpl(const ::arrow::Array& values) { if (values.type_id() != ArrowType::type_id) { throw ParquetException( std::string() + "direct put to " + ArrowType::type_name() + " from " + values.type()->ToString() + " not supported"); } const auto& data = *values.data(); - PutSpaced( + putSpaced( data.GetValues(1), static_cast(data.length), data.GetValues(0, 0), @@ -940,82 +935,79 @@ class ByteStreamSplitEncoder : public EncoderImpl, } ::arrow::BufferBuilder sink_; - int64_t num_values_in_buffer_; + int64_t numValuesInBuffer_; }; template ByteStreamSplitEncoder::ByteStreamSplitEncoder( const ColumnDescriptor* descr, ::arrow::MemoryPool* pool) - : EncoderImpl(descr, Encoding::BYTE_STREAM_SPLIT, pool), + : EncoderImpl(descr, Encoding::kByteStreamSplit, pool), sink_{pool}, - num_values_in_buffer_{0} {} + numValuesInBuffer_{0} {} template -int64_t ByteStreamSplitEncoder::EstimatedDataEncodedSize() { +int64_t ByteStreamSplitEncoder::estimatedDataEncodedSize() { return sink_.length(); } template -std::shared_ptr<::arrow::Buffer> ByteStreamSplitEncoder::FlushValues() { - std::shared_ptr output_buffer = - AllocateBuffer(this->memory_pool(), EstimatedDataEncodedSize()); - uint8_t* output_buffer_raw = output_buffer->mutable_data(); - const uint8_t* raw_values = sink_.data(); - ByteStreamSplitEncode( - raw_values, num_values_in_buffer_, output_buffer_raw); +std::shared_ptr<::arrow::Buffer> ByteStreamSplitEncoder::flushValues() { + std::shared_ptr outputBuffer = + allocateBuffer(this->memoryPool(), estimatedDataEncodedSize()); + uint8_t* outputBufferRaw = outputBuffer->mutable_data(); + const uint8_t* rawValues = sink_.data(); + byteStreamSplitEncode(rawValues, numValuesInBuffer_, outputBufferRaw); sink_.Reset(); - num_values_in_buffer_ = 0; - return std::move(output_buffer); + numValuesInBuffer_ = 0; + return std::move(outputBuffer); } template -void ByteStreamSplitEncoder::Put(const T* buffer, int num_values) { - if (num_values > 0) { - PARQUET_THROW_NOT_OK(sink_.Append(buffer, num_values * sizeof(T))); - num_values_in_buffer_ += num_values; +void ByteStreamSplitEncoder::put(const T* buffer, int numValues) { + if (numValues > 0) { + PARQUET_THROW_NOT_OK(sink_.Append(buffer, numValues * sizeof(T))); + numValuesInBuffer_ += numValues; } } template <> -void ByteStreamSplitEncoder::Put(const ::arrow::Array& values) { - PutImpl<::arrow::FloatType>(values); +void ByteStreamSplitEncoder::put(const ::arrow::Array& values) { + putImpl<::arrow::FloatType>(values); } template <> -void ByteStreamSplitEncoder::Put(const ::arrow::Array& values) { - PutImpl<::arrow::DoubleType>(values); +void ByteStreamSplitEncoder::put(const ::arrow::Array& values) { + putImpl<::arrow::DoubleType>(values); } template -void ByteStreamSplitEncoder::PutSpaced( +void ByteStreamSplitEncoder::putSpaced( const T* src, - int num_values, - const uint8_t* valid_bits, - int64_t valid_bits_offset) { - if (valid_bits != NULLPTR) { - PARQUET_ASSIGN_OR_THROW( - auto buffer, - ::arrow::AllocateBuffer(num_values * sizeof(T), this->memory_pool())); + int numValues, + const uint8_t* validBits, + int64_t validBitsOffset) { + if (validBits != NULLPTR) { + auto buffer = allocateBuffer(this->memoryPool(), numValues * sizeof(T)); T* data = reinterpret_cast(buffer->mutable_data()); - int num_valid_values = ::arrow::util::internal::SpacedCompress( - src, num_values, valid_bits, valid_bits_offset, data); - Put(data, num_valid_values); + int numValidValues = ::arrow::util::internal::SpacedCompress( + src, numValues, validBits, validBitsOffset, data); + put(data, numValidValues); } else { - Put(src, num_values); + put(src, numValues); } } class DecoderImpl : virtual public Decoder { public: - void SetData(int num_values, const uint8_t* data, int len) override { - num_values_ = num_values; + void setData(int numValues, const uint8_t* data, int len) override { + numValues_ = numValues; data_ = data; len_ = len; } - int values_left() const override { - return num_values_; + int valuesLeft() const override { + return numValues_; } Encoding::type encoding() const override { return encoding_; @@ -1025,343 +1017,340 @@ class DecoderImpl : virtual public Decoder { explicit DecoderImpl(const ColumnDescriptor* descr, Encoding::type encoding) : descr_(descr), encoding_(encoding), - num_values_(0), + numValues_(0), data_(NULLPTR), len_(0) {} - // For accessing type-specific metadata, like FIXED_LEN_BYTE_ARRAY + // For accessing type-specific metadata, like FIXED_LEN_BYTE_ARRAY. const ColumnDescriptor* descr_; const Encoding::type encoding_; - int num_values_; + int numValues_; const uint8_t* data_; int len_; - int type_length_; + int typeLength_; }; template class PlainDecoder : public DecoderImpl, virtual public TypedDecoder { public: - using T = typename DType::c_type; + using T = typename DType::CType; explicit PlainDecoder(const ColumnDescriptor* descr); - int Decode(T* buffer, int max_values) override; - - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::Accumulator* builder) override; - - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) override; + int decode(T* buffer, int maxValues) override; + + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::Accumulator* Builder) override; + + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) override; }; template <> -inline int PlainDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::Accumulator* builder) { +inline int PlainDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::Accumulator* Builder) { ParquetException::NYI("DecodeArrow not supported for Int96"); } template <> -inline int PlainDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) { +inline int PlainDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) { ParquetException::NYI("DecodeArrow not supported for Int96"); } template <> -inline int PlainDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) { +inline int PlainDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) { ParquetException::NYI("dictionaries of BooleanType"); } template -int PlainDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::Accumulator* builder) { - using value_type = typename DType::c_type; +int PlainDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::Accumulator* Builder) { + using ValueType = typename DType::CType; - constexpr int value_size = static_cast(sizeof(value_type)); - int values_decoded = num_values - null_count; - if (ARROW_PREDICT_FALSE(len_ < value_size * values_decoded)) { - ParquetException::EofException(); + constexpr int valueSize = static_cast(sizeof(ValueType)); + int valuesDecoded = numValues - nullCount; + if (ARROW_PREDICT_FALSE(len_ < valueSize * valuesDecoded)) { + ParquetException::eofException(); } - PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); + PARQUET_THROW_NOT_OK(Builder->Reserve(numValues)); VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { - builder->UnsafeAppend(SafeLoadAs(data_)); - data_ += sizeof(value_type); + Builder->UnsafeAppend(SafeLoadAs(data_)); + data_ += sizeof(ValueType); }, - [&]() { builder->UnsafeAppendNull(); }); + [&]() { Builder->UnsafeAppendNull(); }); - num_values_ -= values_decoded; - len_ -= sizeof(value_type) * values_decoded; - return values_decoded; + numValues_ -= valuesDecoded; + len_ -= sizeof(ValueType) * valuesDecoded; + return valuesDecoded; } template -int PlainDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) { - using value_type = typename DType::c_type; +int PlainDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) { + using ValueType = typename DType::CType; - constexpr int value_size = static_cast(sizeof(value_type)); - int values_decoded = num_values - null_count; - if (ARROW_PREDICT_FALSE(len_ < value_size * values_decoded)) { - ParquetException::EofException(); + constexpr int valueSize = static_cast(sizeof(ValueType)); + int valuesDecoded = numValues - nullCount; + if (ARROW_PREDICT_FALSE(len_ < valueSize * valuesDecoded)) { + ParquetException::eofException(); } - PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); + PARQUET_THROW_NOT_OK(Builder->Reserve(numValues)); VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { - PARQUET_THROW_NOT_OK(builder->Append(SafeLoadAs(data_))); - data_ += sizeof(value_type); + PARQUET_THROW_NOT_OK(Builder->Append(SafeLoadAs(data_))); + data_ += sizeof(ValueType); }, - [&]() { PARQUET_THROW_NOT_OK(builder->AppendNull()); }); + [&]() { PARQUET_THROW_NOT_OK(Builder->AppendNull()); }); - num_values_ -= values_decoded; - len_ -= sizeof(value_type) * values_decoded; - return values_decoded; + numValues_ -= valuesDecoded; + len_ -= sizeof(ValueType) * valuesDecoded; + return valuesDecoded; } -// Decode routine templated on C++ type rather than type enum +// Decode routine templated on C++ type rather than type enum. template -inline int DecodePlain( +inline int decodePlain( const uint8_t* data, - int64_t data_size, - int num_values, - int type_length, + int64_t dataSize, + int numValues, + int typeLength, T* out) { - int64_t bytes_to_decode = num_values * static_cast(sizeof(T)); - if (bytes_to_decode > data_size || bytes_to_decode > INT_MAX) { - ParquetException::EofException(); + int64_t bytesToDecode = numValues * static_cast(sizeof(T)); + if (bytesToDecode > dataSize || bytesToDecode > INT_MAX) { + ParquetException::eofException(); } - // If bytes_to_decode == 0, data could be null - if (bytes_to_decode > 0) { - memcpy(out, data, bytes_to_decode); + // If bytes_to_decode == 0, data could be null. + if (bytesToDecode > 0) { + memcpy(out, data, bytesToDecode); } - return static_cast(bytes_to_decode); + return static_cast(bytesToDecode); } template PlainDecoder::PlainDecoder(const ColumnDescriptor* descr) - : DecoderImpl(descr, Encoding::PLAIN) { - if (descr_ && descr_->physical_type() == Type::FIXED_LEN_BYTE_ARRAY) { - type_length_ = descr_->type_length(); + : DecoderImpl(descr, Encoding::kPlain) { + if (descr_ && descr_->physicalType() == Type::kFixedLenByteArray) { + typeLength_ = descr_->typeLength(); } else { - type_length_ = -1; + typeLength_ = -1; } } -// Template specialization for BYTE_ARRAY. The written values do not own their -// own data. +// Template specialization for BYTE_ARRAY. The written values do not own their. +// Own data. static inline int64_t -ReadByteArray(const uint8_t* data, int64_t data_size, ByteArray* out) { - if (ARROW_PREDICT_FALSE(data_size < 4)) { - ParquetException::EofException(); +readByteArray(const uint8_t* data, int64_t dataSize, ByteArray* out) { + if (ARROW_PREDICT_FALSE(dataSize < 4)) { + ParquetException::eofException(); } const int32_t len = SafeLoadAs(data); if (len < 0) { throw ParquetException("Invalid BYTE_ARRAY value"); } - const int64_t consumed_length = static_cast(len) + 4; - if (ARROW_PREDICT_FALSE(data_size < consumed_length)) { - ParquetException::EofException(); + const int64_t consumedLength = static_cast(len) + 4; + if (ARROW_PREDICT_FALSE(dataSize < consumedLength)) { + ParquetException::eofException(); } *out = ByteArray{static_cast(len), data + 4}; - return consumed_length; + return consumedLength; } template <> -inline int DecodePlain( +inline int decodePlain( const uint8_t* data, - int64_t data_size, - int num_values, - int type_length, + int64_t dataSize, + int numValues, + int typeLength, ByteArray* out) { - int bytes_decoded = 0; - for (int i = 0; i < num_values; ++i) { - const auto increment = ReadByteArray(data, data_size, out + i); - if (ARROW_PREDICT_FALSE(increment > INT_MAX - bytes_decoded)) { + int bytesDecoded = 0; + for (int i = 0; i < numValues; ++i) { + const auto increment = readByteArray(data, dataSize, out + i); + if (ARROW_PREDICT_FALSE(increment > INT_MAX - bytesDecoded)) { throw ParquetException("BYTE_ARRAY chunk too large"); } data += increment; - data_size -= increment; - bytes_decoded += static_cast(increment); + dataSize -= increment; + bytesDecoded += static_cast(increment); } - return bytes_decoded; + return bytesDecoded; } -// Template specialization for FIXED_LEN_BYTE_ARRAY. The written values do not -// own their own data. +// Template specialization for FIXED_LEN_BYTE_ARRAY. The written values do not. +// Own their own data. template <> -inline int DecodePlain( +inline int decodePlain( const uint8_t* data, - int64_t data_size, - int num_values, - int type_length, + int64_t dataSize, + int numValues, + int typeLength, FixedLenByteArray* out) { - int64_t bytes_to_decode = static_cast(type_length) * num_values; - if (bytes_to_decode > data_size || bytes_to_decode > INT_MAX) { - ParquetException::EofException(); + int64_t bytesToDecode = static_cast(typeLength) * numValues; + if (bytesToDecode > dataSize || bytesToDecode > INT_MAX) { + ParquetException::eofException(); } - for (int i = 0; i < num_values; ++i) { + for (int i = 0; i < numValues; ++i) { out[i].ptr = data; - data += type_length; - data_size -= type_length; + data += typeLength; + dataSize -= typeLength; } - return static_cast(bytes_to_decode); + return static_cast(bytesToDecode); } template -int PlainDecoder::Decode(T* buffer, int max_values) { - max_values = std::min(max_values, num_values_); - int bytes_consumed = - DecodePlain(data_, len_, max_values, type_length_, buffer); - data_ += bytes_consumed; - len_ -= bytes_consumed; - num_values_ -= max_values; - return max_values; +int PlainDecoder::decode(T* buffer, int maxValues) { + maxValues = std::min(maxValues, numValues_); + int bytesConsumed = + decodePlain(data_, len_, maxValues, typeLength_, buffer); + data_ += bytesConsumed; + len_ -= bytesConsumed; + numValues_ -= maxValues; + return maxValues; } class PlainBooleanDecoder : public DecoderImpl, virtual public BooleanDecoder { public: explicit PlainBooleanDecoder(const ColumnDescriptor* descr); - void SetData(int num_values, const uint8_t* data, int len) override; - - // Two flavors of bool decoding - int Decode(uint8_t* buffer, int max_values) override; - int Decode(bool* buffer, int max_values) override; - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + void setData(int numValues, const uint8_t* data, int len) override; + + // Two flavors of bool decoding. + int decode(uint8_t* buffer, int maxValues) override; + int decode(bool* buffer, int maxValues) override; + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::Accumulator* out) override; - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::DictAccumulator* out) override; private: - std::unique_ptr bit_reader_; + std::unique_ptr bitReader_; }; PlainBooleanDecoder::PlainBooleanDecoder(const ColumnDescriptor* descr) - : DecoderImpl(descr, Encoding::PLAIN) {} + : DecoderImpl(descr, Encoding::kPlain) {} -void PlainBooleanDecoder::SetData( - int num_values, - const uint8_t* data, - int len) { - num_values_ = num_values; - bit_reader_ = std::make_unique(data, len); +void PlainBooleanDecoder::setData(int numValues, const uint8_t* data, int len) { + numValues_ = numValues; + bitReader_ = std::make_unique(data, len); } -int PlainBooleanDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::Accumulator* builder) { - int values_decoded = num_values - null_count; - if (ARROW_PREDICT_FALSE(num_values_ < values_decoded)) { - ParquetException::EofException(); +int PlainBooleanDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::Accumulator* Builder) { + int valuesDecoded = numValues - nullCount; + if (ARROW_PREDICT_FALSE(numValues_ < valuesDecoded)) { + ParquetException::eofException(); } - PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); + PARQUET_THROW_NOT_OK(Builder->Reserve(numValues)); VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { bool value; - ARROW_IGNORE_EXPR(bit_reader_->GetValue(1, &value)); - builder->UnsafeAppend(value); + ((void)(bitReader_->GetValue(1, &value))); + Builder->UnsafeAppend(value); }, - [&]() { builder->UnsafeAppendNull(); }); + [&]() { Builder->UnsafeAppendNull(); }); - num_values_ -= values_decoded; - return values_decoded; + numValues_ -= valuesDecoded; + return valuesDecoded; } -inline int PlainBooleanDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) { +inline int PlainBooleanDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) { ParquetException::NYI("dictionaries of BooleanType"); } -int PlainBooleanDecoder::Decode(uint8_t* buffer, int max_values) { - max_values = std::min(max_values, num_values_); +int PlainBooleanDecoder::decode(uint8_t* buffer, int maxValues) { + maxValues = std::min(maxValues, numValues_); bool val; - ::arrow::internal::BitmapWriter bit_writer(buffer, 0, max_values); - for (int i = 0; i < max_values; ++i) { - if (!bit_reader_->GetValue(1, &val)) { - ParquetException::EofException(); + ::arrow::internal::BitmapWriter bitWriter(buffer, 0, maxValues); + for (int i = 0; i < maxValues; ++i) { + if (!bitReader_->GetValue(1, &val)) { + ParquetException::eofException(); } if (val) { - bit_writer.Set(); + bitWriter.Set(); } - bit_writer.Next(); + bitWriter.Next(); } - bit_writer.Finish(); - num_values_ -= max_values; - return max_values; + bitWriter.Finish(); + numValues_ -= maxValues; + return maxValues; } -int PlainBooleanDecoder::Decode(bool* buffer, int max_values) { - max_values = std::min(max_values, num_values_); - if (bit_reader_->GetBatch(1, buffer, max_values) != max_values) { - ParquetException::EofException(); +int PlainBooleanDecoder::decode(bool* buffer, int maxValues) { + maxValues = std::min(maxValues, numValues_); + if (bitReader_->GetBatch(1, buffer, maxValues) != maxValues) { + ParquetException::eofException(); } - num_values_ -= max_values; - return max_values; + numValues_ -= maxValues; + return maxValues; } -// A helper class to abstract away differences between +// A helper class to abstract away differences between. // EncodingTraits::Accumulator for ByteArrayType and FLBAType. template struct ArrowBinaryHelper; @@ -1372,75 +1361,77 @@ struct ArrowBinaryHelper { ArrowBinaryHelper(Accumulator* acc, int64_t length) : acc_(acc), - entries_remaining_(length), - chunk_space_remaining_( - ::arrow::kBinaryMemoryLimit - acc_->builder->value_data_length()) {} - - Status Prepare(std::optional estimated_data_length = {}) { - RETURN_NOT_OK(acc_->builder->Reserve(entries_remaining_)); - if (estimated_data_length.has_value()) { - RETURN_NOT_OK(acc_->builder->ReserveData(std::min( - *estimated_data_length, ::arrow::kBinaryMemoryLimit))); + entriesRemaining_(length), + chunkSpaceRemaining_( + ::arrow::kBinaryMemoryLimit - acc_->Builder->value_data_length()) {} + + Status prepare(std::optional estimatedDataLength = {}) { + RETURN_NOT_OK(acc_->Builder->Reserve(entriesRemaining_)); + if (estimatedDataLength.has_value()) { + RETURN_NOT_OK(acc_->Builder->ReserveData( + std::min( + *estimatedDataLength, ::arrow::kBinaryMemoryLimit))); } return Status::OK(); } - Status PrepareNextInput( - int64_t next_value_length, - std::optional estimated_remaining_data_length = {}) { - if (ARROW_PREDICT_FALSE(!CanFit(next_value_length))) { - // This element would exceed the capacity of a chunk - RETURN_NOT_OK(PushChunk()); - RETURN_NOT_OK(acc_->builder->Reserve(entries_remaining_)); - if (estimated_remaining_data_length.has_value()) { - RETURN_NOT_OK(acc_->builder->ReserveData(std::min( - *estimated_remaining_data_length, chunk_space_remaining_))); + Status prepareNextInput( + int64_t nextValueLength, + std::optional estimatedRemainingDataLength = {}) { + if (ARROW_PREDICT_FALSE(!canFit(nextValueLength))) { + // This element would exceed the capacity of a chunk. + RETURN_NOT_OK(pushChunk()); + RETURN_NOT_OK(acc_->Builder->Reserve(entriesRemaining_)); + if (estimatedRemainingDataLength.has_value()) { + RETURN_NOT_OK(acc_->Builder->ReserveData( + std::min( + *estimatedRemainingDataLength, chunkSpaceRemaining_))); } } return Status::OK(); } void UnsafeAppend(const uint8_t* data, int32_t length) { - VELOX_DCHECK(CanFit(length)); - VELOX_DCHECK_GT(entries_remaining_, 0); - chunk_space_remaining_ -= length; - --entries_remaining_; - acc_->builder->UnsafeAppend(data, length); + VELOX_DCHECK(canFit(length)); + VELOX_DCHECK_GT(entriesRemaining_, 0); + chunkSpaceRemaining_ -= length; + --entriesRemaining_; + acc_->Builder->UnsafeAppend(data, length); } Status Append(const uint8_t* data, int32_t length) { - VELOX_DCHECK(CanFit(length)); - VELOX_DCHECK_GT(entries_remaining_, 0); - chunk_space_remaining_ -= length; - --entries_remaining_; - return acc_->builder->Append(data, length); + VELOX_DCHECK(canFit(length)); + VELOX_DCHECK_GT(entriesRemaining_, 0); + chunkSpaceRemaining_ -= length; + --entriesRemaining_; + return acc_->Builder->Append(data, length); } void UnsafeAppendNull() { - --entries_remaining_; - acc_->builder->UnsafeAppendNull(); + --entriesRemaining_; + acc_->Builder->UnsafeAppendNull(); } Status AppendNull() { - --entries_remaining_; - return acc_->builder->AppendNull(); + --entriesRemaining_; + return acc_->Builder->AppendNull(); } private: - Status PushChunk() { - ARROW_ASSIGN_OR_RAISE(auto chunk, acc_->builder->Finish()); + Status pushChunk() { + ARROW_ASSIGN_OR_RAISE(auto chunk, acc_->Builder->Finish()); acc_->chunks.push_back(std::move(chunk)); - chunk_space_remaining_ = ::arrow::kBinaryMemoryLimit; + chunkSpaceRemaining_ = ::arrow::kBinaryMemoryLimit; return Status::OK(); } - bool CanFit(int64_t length) const { - return length <= chunk_space_remaining_; + bool canFit(int64_t length) const { + return length <= chunkSpaceRemaining_; } Accumulator* acc_; - int64_t entries_remaining_; - int64_t chunk_space_remaining_; + int64_t entriesRemaining_; + int64_t chunkSpaceRemaining_; }; template <> @@ -1448,204 +1439,199 @@ struct ArrowBinaryHelper { using Accumulator = typename EncodingTraits::Accumulator; ArrowBinaryHelper(Accumulator* acc, int64_t length) - : acc_(acc), entries_remaining_(length) {} + : acc_(acc), entriesRemaining_(length) {} - Status Prepare(std::optional estimated_data_length = {}) { - return acc_->Reserve(entries_remaining_); + Status prepare(std::optional estimatedDataLength = {}) { + return acc_->Reserve(entriesRemaining_); } - Status PrepareNextInput( - int64_t next_value_length, - std::optional estimated_remaining_data_length = {}) { + Status prepareNextInput( + int64_t nextValueLength, + std::optional estimatedRemainingDataLength = {}) { return Status::OK(); } void UnsafeAppend(const uint8_t* data, int32_t length) { - VELOX_DCHECK_GT(entries_remaining_, 0); - --entries_remaining_; + VELOX_DCHECK_GT(entriesRemaining_, 0); + --entriesRemaining_; acc_->UnsafeAppend(data); } Status Append(const uint8_t* data, int32_t length) { - VELOX_DCHECK_GT(entries_remaining_, 0); - --entries_remaining_; + VELOX_DCHECK_GT(entriesRemaining_, 0); + --entriesRemaining_; return acc_->Append(data); } void UnsafeAppendNull() { - --entries_remaining_; + --entriesRemaining_; acc_->UnsafeAppendNull(); } Status AppendNull() { - --entries_remaining_; + --entriesRemaining_; return acc_->AppendNull(); } private: Accumulator* acc_; - int64_t entries_remaining_; + int64_t entriesRemaining_; }; template <> -inline int PlainDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::Accumulator* builder) { +inline int PlainDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::Accumulator* Builder) { ParquetException::NYI(); } template <> -inline int PlainDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) { +inline int PlainDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) { ParquetException::NYI(); } template <> -inline int PlainDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::Accumulator* builder) { - int values_decoded = num_values - null_count; - if (ARROW_PREDICT_FALSE(len_ < descr_->type_length() * values_decoded)) { - ParquetException::EofException(); +inline int PlainDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::Accumulator* Builder) { + int valuesDecoded = numValues - nullCount; + if (ARROW_PREDICT_FALSE(len_ < descr_->typeLength() * valuesDecoded)) { + ParquetException::eofException(); } - PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); + PARQUET_THROW_NOT_OK(Builder->Reserve(numValues)); VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { - builder->UnsafeAppend(data_); - data_ += descr_->type_length(); + Builder->UnsafeAppend(data_); + data_ += descr_->typeLength(); }, - [&]() { builder->UnsafeAppendNull(); }); + [&]() { Builder->UnsafeAppendNull(); }); - num_values_ -= values_decoded; - len_ -= descr_->type_length() * values_decoded; - return values_decoded; + numValues_ -= valuesDecoded; + len_ -= descr_->typeLength() * valuesDecoded; + return valuesDecoded; } template <> -inline int PlainDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) { - int values_decoded = num_values - null_count; - if (ARROW_PREDICT_FALSE(len_ < descr_->type_length() * values_decoded)) { - ParquetException::EofException(); +inline int PlainDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) { + int valuesDecoded = numValues - nullCount; + if (ARROW_PREDICT_FALSE(len_ < descr_->typeLength() * valuesDecoded)) { + ParquetException::eofException(); } - PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); + PARQUET_THROW_NOT_OK(Builder->Reserve(numValues)); VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { - PARQUET_THROW_NOT_OK(builder->Append(data_)); - data_ += descr_->type_length(); + PARQUET_THROW_NOT_OK(Builder->Append(data_)); + data_ += descr_->typeLength(); }, - [&]() { PARQUET_THROW_NOT_OK(builder->AppendNull()); }); + [&]() { PARQUET_THROW_NOT_OK(Builder->AppendNull()); }); - num_values_ -= values_decoded; - len_ -= descr_->type_length() * values_decoded; - return values_decoded; + numValues_ -= valuesDecoded; + len_ -= descr_->typeLength() * valuesDecoded; + return valuesDecoded; } class PlainByteArrayDecoder : public PlainDecoder, virtual public ByteArrayDecoder { public: using Base = PlainDecoder; - using Base::DecodeSpaced; + using Base::decodeSpaced; using Base::PlainDecoder; - // ---------------------------------------------------------------------- - // Dictionary read paths + // ----------------------------------------------------------------------. + // Dictionary read paths. - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - ::arrow::BinaryDictionary32Builder* builder) override { + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + ::arrow::BinaryDictionary32Builder* Builder) override { int result = 0; - PARQUET_THROW_NOT_OK(DecodeArrow( - num_values, - null_count, - valid_bits, - valid_bits_offset, - builder, - &result)); + PARQUET_THROW_NOT_OK(decodeArrow( + numValues, nullCount, validBits, validBitsOffset, Builder, &result)); return result; } - // ---------------------------------------------------------------------- - // Optimized dense binary read paths + // ----------------------------------------------------------------------. + // Optimized dense binary read paths. - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::Accumulator* out) override { int result = 0; - PARQUET_THROW_NOT_OK(DecodeArrowDense( - num_values, null_count, valid_bits, valid_bits_offset, out, &result)); + PARQUET_THROW_NOT_OK(decodeArrowDense( + numValues, nullCount, validBits, validBitsOffset, out, &result)); return result; } private: - Status DecodeArrowDense( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + Status decodeArrowDense( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::Accumulator* out, - int* out_values_decoded) { - ArrowBinaryHelper helper(out, num_values); - int values_decoded = 0; + int* outValuesDecoded) { + ArrowBinaryHelper helper(out, numValues); + int valuesDecoded = 0; - RETURN_NOT_OK(helper.Prepare(len_)); + RETURN_NOT_OK(helper.prepare(len_)); int i = 0; RETURN_NOT_OK(VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { if (ARROW_PREDICT_FALSE(len_ < 4)) { - ParquetException::EofException(); + ParquetException::eofException(); } - auto value_len = SafeLoadAs(data_); - if (ARROW_PREDICT_FALSE(value_len < 0 || value_len > INT32_MAX - 4)) { + auto valueLen = SafeLoadAs(data_); + if (ARROW_PREDICT_FALSE(valueLen < 0 || valueLen > INT32_MAX - 4)) { return Status::Invalid( - "Invalid or corrupted value_len '", value_len, "'"); + "Invalid or corrupted value_len '", valueLen, "'"); } - auto increment = value_len + 4; + auto increment = valueLen + 4; if (ARROW_PREDICT_FALSE(len_ < increment)) { - ParquetException::EofException(); + ParquetException::eofException(); } - RETURN_NOT_OK(helper.PrepareNextInput(value_len, len_)); - helper.UnsafeAppend(data_ + 4, value_len); + RETURN_NOT_OK(helper.prepareNextInput(valueLen, len_)); + helper.UnsafeAppend(data_ + 4, valueLen); data_ += increment; len_ -= increment; - ++values_decoded; + ++valuesDecoded; ++i; return Status::OK(); }, @@ -1655,50 +1641,50 @@ class PlainByteArrayDecoder : public PlainDecoder, return Status::OK(); })); - num_values_ -= values_decoded; - *out_values_decoded = values_decoded; + numValues_ -= valuesDecoded; + *outValuesDecoded = valuesDecoded; return Status::OK(); } template - Status DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - BuilderType* builder, - int* out_values_decoded) { - RETURN_NOT_OK(builder->Reserve(num_values)); - int values_decoded = 0; + Status decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + BuilderType* Builder, + int* outValuesDecoded) { + RETURN_NOT_OK(Builder->Reserve(numValues)); + int valuesDecoded = 0; RETURN_NOT_OK(VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { if (ARROW_PREDICT_FALSE(len_ < 4)) { - ParquetException::EofException(); + ParquetException::eofException(); } - auto value_len = SafeLoadAs(data_); - if (ARROW_PREDICT_FALSE(value_len < 0 || value_len > INT32_MAX - 4)) { + auto valueLen = SafeLoadAs(data_); + if (ARROW_PREDICT_FALSE(valueLen < 0 || valueLen > INT32_MAX - 4)) { return Status::Invalid( - "Invalid or corrupted value_len '", value_len, "'"); + "Invalid or corrupted value_len '", valueLen, "'"); } - auto increment = value_len + 4; + auto increment = valueLen + 4; if (ARROW_PREDICT_FALSE(len_ < increment)) { - ParquetException::EofException(); + ParquetException::eofException(); } - RETURN_NOT_OK(builder->Append(data_ + 4, value_len)); + RETURN_NOT_OK(Builder->Append(data_ + 4, valueLen)); data_ += increment; len_ -= increment; - ++values_decoded; + ++valuesDecoded; return Status::OK(); }, - [&]() { return builder->AppendNull(); })); + [&]() { return Builder->AppendNull(); })); - num_values_ -= values_decoded; - *out_values_decoded = values_decoded; + numValues_ -= valuesDecoded; + *outValuesDecoded = valuesDecoded; return Status::OK(); } }; @@ -1710,486 +1696,474 @@ class PlainFLBADecoder : public PlainDecoder, using Base::PlainDecoder; }; -// ---------------------------------------------------------------------- -// Dictionary encoding and decoding +// ----------------------------------------------------------------------. +// Dictionary encoding and decoding. template class DictDecoderImpl : public DecoderImpl, virtual public DictDecoder { public: - typedef typename Type::c_type T; + typedef typename Type::CType T; - // Initializes the dictionary with values from 'dictionary'. The data in - // dictionary is not guaranteed to persist in memory after this call so the - // dictionary decoder needs to copy the data out if necessary. + // Initializes the dictionary with values from 'dictionary'. The data in. + // Dictionary is not guaranteed to persist in memory after this call so the. + // Dictionary decoder needs to copy the data out if necessary. explicit DictDecoderImpl( const ColumnDescriptor* descr, MemoryPool* pool = ::arrow::default_memory_pool()) - : DecoderImpl(descr, Encoding::RLE_DICTIONARY), - dictionary_(AllocateBuffer(pool, 0)), - dictionary_length_(0), - byte_array_data_(AllocateBuffer(pool, 0)), - byte_array_offsets_(AllocateBuffer(pool, 0)), - indices_scratch_space_(AllocateBuffer(pool, 0)) {} - - // Perform type-specific initialization - void SetDict(TypedDecoder* dictionary) override; - - void SetData(int num_values, const uint8_t* data, int len) override { - num_values_ = num_values; + : DecoderImpl(descr, Encoding::kRleDictionary), + dictionary_(allocateBuffer(pool, 0)), + dictionaryLength_(0), + byteArrayData_(allocateBuffer(pool, 0)), + byteArrayOffsets_(allocateBuffer(pool, 0)), + indicesScratchSpace_(allocateBuffer(pool, 0)) {} + + // Perform type-specific initialization. + void setDict(TypedDecoder* dictionary) override; + + void setData(int numValues, const uint8_t* data, int len) override { + numValues_ = numValues; if (len == 0) { - // Initialize dummy decoder to avoid crashes later on - idx_decoder_ = RleDecoder(data, len, /*bitWidth=*/1); + // Initialize dummy decoder to avoid crashes later on. + idxDecoder_ = RleDecoder(data, len, 1); return; } - uint8_t bit_width = *data; - if (ARROW_PREDICT_FALSE(bit_width > 32)) { + uint8_t bitWidth = *data; + if (ARROW_PREDICT_FALSE(bitWidth > 32)) { throw ParquetException( - "Invalid or corrupted bit_width " + std::to_string(bit_width) + + "Invalid or corrupted bit_width " + std::to_string(bitWidth) + ". Maximum allowed is 32."); } - idx_decoder_ = RleDecoder(++data, --len, bit_width); + idxDecoder_ = RleDecoder(++data, --len, bitWidth); } - int Decode(T* buffer, int num_values) override { - num_values = std::min(num_values, num_values_); - int decoded_values = idx_decoder_.GetBatchWithDict( + int decode(T* buffer, int numValues) override { + numValues = std::min(numValues, numValues_); + int decodedValues = idxDecoder_.GetBatchWithDict( reinterpret_cast(dictionary_->data()), - dictionary_length_, + dictionaryLength_, buffer, - num_values); - if (decoded_values != num_values) { - ParquetException::EofException(); + numValues); + if (decodedValues != numValues) { + ParquetException::eofException(); } - num_values_ -= num_values; - return num_values; + numValues_ -= numValues; + return numValues; } - int DecodeSpaced( + int decodeSpaced( T* buffer, - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset) override { - num_values = std::min(num_values, num_values_); - if (num_values != - idx_decoder_.GetBatchWithDictSpaced( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset) override { + numValues = std::min(numValues, numValues_); + if (numValues != + idxDecoder_.GetBatchWithDictSpaced( reinterpret_cast(dictionary_->data()), - dictionary_length_, + dictionaryLength_, buffer, - num_values, - null_count, - valid_bits, - valid_bits_offset)) { - ParquetException::EofException(); + numValues, + nullCount, + validBits, + validBitsOffset)) { + ParquetException::eofException(); } - num_values_ -= num_values; - return num_values; + numValues_ -= numValues; + return numValues; } - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::Accumulator* out) override; - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::DictAccumulator* out) override; - void InsertDictionary(::arrow::ArrayBuilder* builder) override; - - int DecodeIndicesSpaced( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - ::arrow::ArrayBuilder* builder) override { - if (num_values > 0) { - // TODO(wesm): Refactor to batch reads for improved memory use. It is not - // trivial because the null_count is relative to the entire bitmap - PARQUET_THROW_NOT_OK(indices_scratch_space_->TypedResize( - num_values, /*shrink_to_fit=*/false)); + void insertDictionary(::arrow::ArrayBuilder* Builder) override; + + int decodeIndicesSpaced( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + ::arrow::ArrayBuilder* Builder) override { + if (numValues > 0) { + // TODO(wesm): Refactor to batch reads for improved memory use. It is not. + // Trivial because the null_count is relative to the entire bitmap. + PARQUET_THROW_NOT_OK( + indicesScratchSpace_->TypedResize(numValues, false)); } - auto indices_buffer = - reinterpret_cast(indices_scratch_space_->mutable_data()); + auto indicesBuffer = + reinterpret_cast(indicesScratchSpace_->mutable_data()); - if (num_values != - idx_decoder_.GetBatchSpaced( - num_values, - null_count, - valid_bits, - valid_bits_offset, - indices_buffer)) { - ParquetException::EofException(); + if (numValues != + idxDecoder_.GetBatchSpaced( + numValues, nullCount, validBits, validBitsOffset, indicesBuffer)) { + ParquetException::eofException(); } - // XXX(wesm): Cannot append "valid bits" directly to the builder - std::vector valid_bytes(num_values, 0); + // XXX(wesm): Cannot Append "valid bits" directly to the builder. + std::vector validBytes(numValues, 0); int64_t i = 0; VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, - [&]() { valid_bytes[i++] = 1; }, + validBits, + validBitsOffset, + numValues, + nullCount, + [&]() { validBytes[i++] = 1; }, [&]() { ++i; }); - auto binary_builder = - checked_cast<::arrow::BinaryDictionary32Builder*>(builder); - PARQUET_THROW_NOT_OK(binary_builder->AppendIndices( - indices_buffer, num_values, valid_bytes.data())); - num_values_ -= num_values - null_count; - return num_values - null_count; - } - - int DecodeIndices(int num_values, ::arrow::ArrayBuilder* builder) override { - num_values = std::min(num_values, num_values_); - if (num_values > 0) { - // TODO(wesm): Refactor to batch reads for improved memory use. This is - // relatively simple here because we don't have to do any bookkeeping of - // nulls - PARQUET_THROW_NOT_OK(indices_scratch_space_->TypedResize( - num_values, /*shrink_to_fit=*/false)); + auto binaryBuilder = + checked_cast<::arrow::BinaryDictionary32Builder*>(Builder); + PARQUET_THROW_NOT_OK(binaryBuilder->AppendIndices( + indicesBuffer, numValues, validBytes.data())); + numValues_ -= numValues - nullCount; + return numValues - nullCount; + } + + int decodeIndices(int numValues, ::arrow::ArrayBuilder* Builder) override { + numValues = std::min(numValues, numValues_); + if (numValues > 0) { + // TODO(wesm): Refactor to batch reads for improved memory use. This is. + // Relatively simple here because we don't have to do any bookkeeping of. + // Nulls. + PARQUET_THROW_NOT_OK( + indicesScratchSpace_->TypedResize(numValues, false)); } - auto indices_buffer = - reinterpret_cast(indices_scratch_space_->mutable_data()); - if (num_values != idx_decoder_.GetBatch(indices_buffer, num_values)) { - ParquetException::EofException(); + auto indicesBuffer = + reinterpret_cast(indicesScratchSpace_->mutable_data()); + if (numValues != idxDecoder_.GetBatch(indicesBuffer, numValues)) { + ParquetException::eofException(); } - auto binary_builder = - checked_cast<::arrow::BinaryDictionary32Builder*>(builder); + auto binaryBuilder = + checked_cast<::arrow::BinaryDictionary32Builder*>(Builder); PARQUET_THROW_NOT_OK( - binary_builder->AppendIndices(indices_buffer, num_values)); - num_values_ -= num_values; - return num_values; + binaryBuilder->AppendIndices(indicesBuffer, numValues)); + numValues_ -= numValues; + return numValues; } - int DecodeIndices(int num_values, int32_t* indices) override { - if (num_values != idx_decoder_.GetBatch(indices, num_values)) { - ParquetException::EofException(); + int decodeIndices(int numValues, int32_t* indices) override { + if (numValues != idxDecoder_.GetBatch(indices, numValues)) { + ParquetException::eofException(); } - num_values_ -= num_values; - return num_values; + numValues_ -= numValues; + return numValues; } - void GetDictionary(const T** dictionary, int32_t* dictionary_length) - override { - *dictionary_length = dictionary_length_; + void getDictionary(const T** dictionary, int32_t* dictionaryLength) override { + *dictionaryLength = dictionaryLength_; *dictionary = reinterpret_cast(dictionary_->mutable_data()); } protected: - Status IndexInBounds(int32_t index) { - if (ARROW_PREDICT_TRUE(0 <= index && index < dictionary_length_)) { + Status indexInBounds(int32_t index) { + if (ARROW_PREDICT_TRUE(0 <= index && index < dictionaryLength_)) { return Status::OK(); } return Status::Invalid("Index not in dictionary bounds"); } - inline void DecodeDict(TypedDecoder* dictionary) { - dictionary_length_ = static_cast(dictionary->values_left()); - PARQUET_THROW_NOT_OK(dictionary_->Resize( - dictionary_length_ * sizeof(T), - /*shrink_to_fit=*/false)); - dictionary->Decode( - reinterpret_cast(dictionary_->mutable_data()), dictionary_length_); + inline void decodeDict(TypedDecoder* dictionary) { + dictionaryLength_ = static_cast(dictionary->valuesLeft()); + PARQUET_THROW_NOT_OK( + dictionary_->Resize(dictionaryLength_ * sizeof(T), false)); + dictionary->decode( + reinterpret_cast(dictionary_->mutable_data()), dictionaryLength_); } // Only one is set. std::shared_ptr dictionary_; - int32_t dictionary_length_; + int32_t dictionaryLength_; - // Data that contains the byte array data (byte_array_dictionary_ just has the - // pointers). - std::shared_ptr byte_array_data_; + // Data that contains the byte array data (byte_array_dictionary_ just has + // the. Pointers). + std::shared_ptr byteArrayData_; - // Arrow-style byte offsets for each dictionary value. We maintain two - // representations of the dictionary, one as ByteArray* for non-Arrow - // consumers and this one for Arrow consumers. Since dictionaries are - // generally pretty small to begin with this doesn't mean too much extra - // memory use in most cases - std::shared_ptr byte_array_offsets_; + // Arrow-style byte offsets for each dictionary value. We maintain two. + // Representations of the dictionary, one as ByteArray* for non-Arrow. + // Consumers and this one for Arrow consumers. Since dictionaries are. + // Generally pretty small to begin with this doesn't mean too much extra. + // Memory use in most cases. + std::shared_ptr byteArrayOffsets_; - // Reusable buffer for decoding dictionary indices to be appended to a - // BinaryDictionary32Builder - std::shared_ptr indices_scratch_space_; + // Reusable buffer for decoding dictionary indices to be appended to a. + // BinaryDictionary32Builder. + std::shared_ptr indicesScratchSpace_; - RleDecoder idx_decoder_; + RleDecoder idxDecoder_; }; template -void DictDecoderImpl::SetDict(TypedDecoder* dictionary) { - DecodeDict(dictionary); +void DictDecoderImpl::setDict(TypedDecoder* dictionary) { + decodeDict(dictionary); } template <> -void DictDecoderImpl::SetDict( +void DictDecoderImpl::setDict( TypedDecoder* dictionary) { ParquetException::NYI( "Dictionary encoding is not implemented for boolean values"); } template <> -void DictDecoderImpl::SetDict( +void DictDecoderImpl::setDict( TypedDecoder* dictionary) { - DecodeDict(dictionary); + decodeDict(dictionary); - auto dict_values = reinterpret_cast(dictionary_->mutable_data()); + auto dictValues = reinterpret_cast(dictionary_->mutable_data()); - int total_size = 0; - for (int i = 0; i < dictionary_length_; ++i) { - total_size += dict_values[i].len; + int totalSize = 0; + for (int i = 0; i < dictionaryLength_; ++i) { + totalSize += dictValues[i].len; } - PARQUET_THROW_NOT_OK(byte_array_data_->Resize( - total_size, - /*shrink_to_fit=*/false)); - PARQUET_THROW_NOT_OK(byte_array_offsets_->Resize( - (dictionary_length_ + 1) * sizeof(int32_t), - /*shrink_to_fit=*/false)); + PARQUET_THROW_NOT_OK(byteArrayData_->Resize(totalSize, false)); + PARQUET_THROW_NOT_OK(byteArrayOffsets_->Resize( + (dictionaryLength_ + 1) * sizeof(int32_t), false)); int32_t offset = 0; - uint8_t* bytes_data = byte_array_data_->mutable_data(); - int32_t* bytes_offsets = - reinterpret_cast(byte_array_offsets_->mutable_data()); - for (int i = 0; i < dictionary_length_; ++i) { - memcpy(bytes_data + offset, dict_values[i].ptr, dict_values[i].len); - bytes_offsets[i] = offset; - dict_values[i].ptr = bytes_data + offset; - offset += dict_values[i].len; - } - bytes_offsets[dictionary_length_] = offset; + uint8_t* bytesData = byteArrayData_->mutable_data(); + int32_t* bytesOffsets = + reinterpret_cast(byteArrayOffsets_->mutable_data()); + for (int i = 0; i < dictionaryLength_; ++i) { + memcpy(bytesData + offset, dictValues[i].ptr, dictValues[i].len); + bytesOffsets[i] = offset; + dictValues[i].ptr = bytesData + offset; + offset += dictValues[i].len; + } + bytesOffsets[dictionaryLength_] = offset; } template <> -inline void DictDecoderImpl::SetDict( +inline void DictDecoderImpl::setDict( TypedDecoder* dictionary) { - DecodeDict(dictionary); + decodeDict(dictionary); - auto dict_values = reinterpret_cast(dictionary_->mutable_data()); + auto dictValues = reinterpret_cast(dictionary_->mutable_data()); - int fixed_len = descr_->type_length(); - int total_size = dictionary_length_ * fixed_len; + int fixedLen = descr_->typeLength(); + int totalSize = dictionaryLength_ * fixedLen; - PARQUET_THROW_NOT_OK(byte_array_data_->Resize( - total_size, - /*shrink_to_fit=*/false)); - uint8_t* bytes_data = byte_array_data_->mutable_data(); - for (int32_t i = 0, offset = 0; i < dictionary_length_; - ++i, offset += fixed_len) { - memcpy(bytes_data + offset, dict_values[i].ptr, fixed_len); - dict_values[i].ptr = bytes_data + offset; + PARQUET_THROW_NOT_OK(byteArrayData_->Resize(totalSize, false)); + uint8_t* bytesData = byteArrayData_->mutable_data(); + for (int32_t i = 0, offset = 0; i < dictionaryLength_; + ++i, offset += fixedLen) { + memcpy(bytesData + offset, dictValues[i].ptr, fixedLen); + dictValues[i].ptr = bytesData + offset; } } template <> -inline int DictDecoderImpl::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::Accumulator* builder) { +inline int DictDecoderImpl::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::Accumulator* Builder) { ParquetException::NYI("DecodeArrow to Int96Type"); } template <> -inline int DictDecoderImpl::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) { +inline int DictDecoderImpl::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) { ParquetException::NYI("DecodeArrow to Int96Type"); } template <> -inline int DictDecoderImpl::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::Accumulator* builder) { +inline int DictDecoderImpl::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::Accumulator* Builder) { ParquetException::NYI("DecodeArrow implemented elsewhere"); } template <> -inline int DictDecoderImpl::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) { +inline int DictDecoderImpl::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) { ParquetException::NYI("DecodeArrow implemented elsewhere"); } template -int DictDecoderImpl::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) { - PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); +int DictDecoderImpl::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) { + PARQUET_THROW_NOT_OK(Builder->Reserve(numValues)); - auto dict_values = - reinterpret_cast(dictionary_->data()); + auto dictValues = + reinterpret_cast(dictionary_->data()); VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { int32_t index; - if (ARROW_PREDICT_FALSE(!idx_decoder_.Get(&index))) { + if (ARROW_PREDICT_FALSE(!idxDecoder_.Get(&index))) { throw ParquetException(""); } - PARQUET_THROW_NOT_OK(IndexInBounds(index)); - PARQUET_THROW_NOT_OK(builder->Append(dict_values[index])); + PARQUET_THROW_NOT_OK(indexInBounds(index)); + PARQUET_THROW_NOT_OK(Builder->Append(dictValues[index])); }, - [&]() { PARQUET_THROW_NOT_OK(builder->AppendNull()); }); + [&]() { PARQUET_THROW_NOT_OK(Builder->AppendNull()); }); - return num_values - null_count; + return numValues - nullCount; } template <> -int DictDecoderImpl::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) { +int DictDecoderImpl::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) { ParquetException::NYI("No dictionary encoding for BooleanType"); } template <> -inline int DictDecoderImpl::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::Accumulator* builder) { - if (builder->byte_width() != descr_->type_length()) { +inline int DictDecoderImpl::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::Accumulator* Builder) { + if (Builder->byte_width() != descr_->typeLength()) { throw ParquetException( "Byte width mismatch: builder was " + - std::to_string(builder->byte_width()) + " but decoder was " + - std::to_string(descr_->type_length())); + std::to_string(Builder->byte_width()) + " but decoder was " + + std::to_string(descr_->typeLength())); } - PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); + PARQUET_THROW_NOT_OK(Builder->Reserve(numValues)); - auto dict_values = reinterpret_cast(dictionary_->data()); + auto dictValues = reinterpret_cast(dictionary_->data()); VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { int32_t index; - if (ARROW_PREDICT_FALSE(!idx_decoder_.Get(&index))) { + if (ARROW_PREDICT_FALSE(!idxDecoder_.Get(&index))) { throw ParquetException(""); } - PARQUET_THROW_NOT_OK(IndexInBounds(index)); - builder->UnsafeAppend(dict_values[index].ptr); + PARQUET_THROW_NOT_OK(indexInBounds(index)); + Builder->UnsafeAppend(dictValues[index].ptr); }, - [&]() { builder->UnsafeAppendNull(); }); + [&]() { Builder->UnsafeAppendNull(); }); - return num_values - null_count; + return numValues - nullCount; } template <> -int DictDecoderImpl::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) { - auto value_type = - checked_cast(*builder->type()) +int DictDecoderImpl::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) { + auto valueType = + checked_cast(*Builder->type()) .value_type(); - auto byte_width = - checked_cast(*value_type) - .byte_width(); - if (byte_width != descr_->type_length()) { + auto byteWidth = checked_cast(*valueType) + .byte_width(); + if (byteWidth != descr_->typeLength()) { throw ParquetException( - "Byte width mismatch: builder was " + std::to_string(byte_width) + - " but decoder was " + std::to_string(descr_->type_length())); + "Byte width mismatch: builder was " + std::to_string(byteWidth) + + " but decoder was " + std::to_string(descr_->typeLength())); } - PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); + PARQUET_THROW_NOT_OK(Builder->Reserve(numValues)); - auto dict_values = reinterpret_cast(dictionary_->data()); + auto dictValues = reinterpret_cast(dictionary_->data()); VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { int32_t index; - if (ARROW_PREDICT_FALSE(!idx_decoder_.Get(&index))) { + if (ARROW_PREDICT_FALSE(!idxDecoder_.Get(&index))) { throw ParquetException(""); } - PARQUET_THROW_NOT_OK(IndexInBounds(index)); - PARQUET_THROW_NOT_OK(builder->Append(dict_values[index].ptr)); + PARQUET_THROW_NOT_OK(indexInBounds(index)); + PARQUET_THROW_NOT_OK(Builder->Append(dictValues[index].ptr)); }, - [&]() { PARQUET_THROW_NOT_OK(builder->AppendNull()); }); + [&]() { PARQUET_THROW_NOT_OK(Builder->AppendNull()); }); - return num_values - null_count; + return numValues - nullCount; } template -int DictDecoderImpl::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::Accumulator* builder) { - PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); +int DictDecoderImpl::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::Accumulator* Builder) { + PARQUET_THROW_NOT_OK(Builder->Reserve(numValues)); - using value_type = typename Type::c_type; - auto dict_values = reinterpret_cast(dictionary_->data()); + using ValueType = typename Type::CType; + auto dictValues = reinterpret_cast(dictionary_->data()); VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { int32_t index; - if (ARROW_PREDICT_FALSE(!idx_decoder_.Get(&index))) { + if (ARROW_PREDICT_FALSE(!idxDecoder_.Get(&index))) { throw ParquetException(""); } - PARQUET_THROW_NOT_OK(IndexInBounds(index)); - builder->UnsafeAppend(dict_values[index]); + PARQUET_THROW_NOT_OK(indexInBounds(index)); + Builder->UnsafeAppend(dictValues[index]); }, - [&]() { builder->UnsafeAppendNull(); }); + [&]() { Builder->UnsafeAppendNull(); }); - return num_values - null_count; + return numValues - nullCount; } template -void DictDecoderImpl::InsertDictionary(::arrow::ArrayBuilder* builder) { +void DictDecoderImpl::insertDictionary(::arrow::ArrayBuilder* Builder) { ParquetException::NYI( "InsertDictionary only implemented for BYTE_ARRAY types"); } template <> -void DictDecoderImpl::InsertDictionary( - ::arrow::ArrayBuilder* builder) { - auto binary_builder = - checked_cast<::arrow::BinaryDictionary32Builder*>(builder); +void DictDecoderImpl::insertDictionary( + ::arrow::ArrayBuilder* Builder) { + auto binaryBuilder = + checked_cast<::arrow::BinaryDictionary32Builder*>(Builder); - // Make a BinaryArray referencing the internal dictionary data + // Make a BinaryArray referencing the internal dictionary data. auto arr = std::make_shared<::arrow::BinaryArray>( - dictionary_length_, byte_array_offsets_, byte_array_data_); - PARQUET_THROW_NOT_OK(binary_builder->InsertMemoValues(*arr)); + dictionaryLength_, byteArrayOffsets_, byteArrayData_); + PARQUET_THROW_NOT_OK(binaryBuilder->InsertMemoValues(*arr)); } class DictByteArrayDecoderImpl : public DictDecoderImpl, @@ -2198,497 +2172,489 @@ class DictByteArrayDecoderImpl : public DictDecoderImpl, using BASE = DictDecoderImpl; using BASE::DictDecoderImpl; - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - ::arrow::BinaryDictionary32Builder* builder) override { + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + ::arrow::BinaryDictionary32Builder* Builder) override { int result = 0; - if (null_count == 0) { - PARQUET_THROW_NOT_OK(DecodeArrowNonNull(num_values, builder, &result)); + if (nullCount == 0) { + PARQUET_THROW_NOT_OK(decodeArrowNonNull(numValues, Builder, &result)); } else { - PARQUET_THROW_NOT_OK(DecodeArrow( - num_values, - null_count, - valid_bits, - valid_bits_offset, - builder, - &result)); + PARQUET_THROW_NOT_OK(decodeArrow( + numValues, nullCount, validBits, validBitsOffset, Builder, &result)); } return result; } - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::Accumulator* out) override { int result = 0; - if (null_count == 0) { - PARQUET_THROW_NOT_OK(DecodeArrowDenseNonNull(num_values, out, &result)); + if (nullCount == 0) { + PARQUET_THROW_NOT_OK(decodeArrowDenseNonNull(numValues, out, &result)); } else { - PARQUET_THROW_NOT_OK(DecodeArrowDense( - num_values, null_count, valid_bits, valid_bits_offset, out, &result)); + PARQUET_THROW_NOT_OK(decodeArrowDense( + numValues, nullCount, validBits, validBitsOffset, out, &result)); } return result; } private: - Status DecodeArrowDense( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + Status decodeArrowDense( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::Accumulator* out, - int* out_num_values) { + int* outNumValues) { constexpr int32_t kBufferSize = 1024; int32_t indices[kBufferSize]; - ArrowBinaryHelper helper(out, num_values); - RETURN_NOT_OK(helper.Prepare()); - - auto dict_values = reinterpret_cast(dictionary_->data()); - int values_decoded = 0; - int num_indices = 0; - int pos_indices = 0; - - auto visit_valid = [&](int64_t position) -> Status { - if (num_indices == pos_indices) { - // Refill indices buffer - const auto batch_size = std::min( - kBufferSize, num_values - null_count - values_decoded); - num_indices = idx_decoder_.GetBatch(indices, batch_size); - if (ARROW_PREDICT_FALSE(num_indices < 1)) { - return Status::Invalid("Invalid number of indices: ", num_indices); + ArrowBinaryHelper helper(out, numValues); + RETURN_NOT_OK(helper.prepare()); + + auto dictValues = reinterpret_cast(dictionary_->data()); + int valuesDecoded = 0; + int numIndices = 0; + int posIndices = 0; + + auto visitValid = [&](int64_t position) -> Status { + if (numIndices == posIndices) { + // Refill indices buffer. + const auto batchSize = std::min( + kBufferSize, numValues - nullCount - valuesDecoded); + numIndices = idxDecoder_.GetBatch(indices, batchSize); + if (ARROW_PREDICT_FALSE(numIndices < 1)) { + return Status::Invalid("Invalid number of indices: ", numIndices); } - pos_indices = 0; + posIndices = 0; } - const auto index = indices[pos_indices++]; - RETURN_NOT_OK(IndexInBounds(index)); - const auto& val = dict_values[index]; - RETURN_NOT_OK(helper.PrepareNextInput(val.len)); + const auto index = indices[posIndices++]; + RETURN_NOT_OK(indexInBounds(index)); + const auto& val = dictValues[index]; + RETURN_NOT_OK(helper.prepareNextInput(val.len)); RETURN_NOT_OK(helper.Append(val.ptr, static_cast(val.len))); - ++values_decoded; + ++valuesDecoded; return Status::OK(); }; - auto visit_null = [&]() -> Status { + auto visitNull = [&]() -> Status { RETURN_NOT_OK(helper.AppendNull()); return Status::OK(); }; - ::arrow::internal::BitBlockCounter bit_blocks( - valid_bits, valid_bits_offset, num_values); + ::arrow::internal::BitBlockCounter bitBlocks( + validBits, validBitsOffset, numValues); int64_t position = 0; - while (position < num_values) { - const auto block = bit_blocks.NextWord(); + while (position < numValues) { + const auto block = bitBlocks.NextWord(); if (block.AllSet()) { for (int64_t i = 0; i < block.length; ++i, ++position) { - ARROW_RETURN_NOT_OK(visit_valid(position)); + ARROW_RETURN_NOT_OK(visitValid(position)); } } else if (block.NoneSet()) { for (int64_t i = 0; i < block.length; ++i, ++position) { - ARROW_RETURN_NOT_OK(visit_null()); + ARROW_RETURN_NOT_OK(visitNull()); } } else { for (int64_t i = 0; i < block.length; ++i, ++position) { if (::arrow::bit_util::GetBit( - valid_bits, valid_bits_offset + position)) { - ARROW_RETURN_NOT_OK(visit_valid(position)); + validBits, validBitsOffset + position)) { + ARROW_RETURN_NOT_OK(visitValid(position)); } else { - ARROW_RETURN_NOT_OK(visit_null()); + ARROW_RETURN_NOT_OK(visitNull()); } } } } - *out_num_values = values_decoded; + *outNumValues = valuesDecoded; return Status::OK(); } - Status DecodeArrowDenseNonNull( - int num_values, + Status decodeArrowDenseNonNull( + int numValues, typename EncodingTraits::Accumulator* out, - int* out_num_values) { + int* outNumValues) { constexpr int32_t kBufferSize = 2048; int32_t indices[kBufferSize]; - int values_decoded = 0; + int valuesDecoded = 0; - ArrowBinaryHelper helper(out, num_values); - RETURN_NOT_OK(helper.Prepare(len_)); + ArrowBinaryHelper helper(out, numValues); + RETURN_NOT_OK(helper.prepare(len_)); - auto dict_values = reinterpret_cast(dictionary_->data()); + auto dictValues = reinterpret_cast(dictionary_->data()); - while (values_decoded < num_values) { - const int32_t batch_size = - std::min(kBufferSize, num_values - values_decoded); - const int num_indices = idx_decoder_.GetBatch(indices, batch_size); - if (num_indices == 0) - ParquetException::EofException(); - for (int i = 0; i < num_indices; ++i) { + while (valuesDecoded < numValues) { + const int32_t batchSize = + std::min(kBufferSize, numValues - valuesDecoded); + const int numIndices = idxDecoder_.GetBatch(indices, batchSize); + if (numIndices == 0) + ParquetException::eofException(); + for (int i = 0; i < numIndices; ++i) { auto idx = indices[i]; - RETURN_NOT_OK(IndexInBounds(idx)); - const auto& val = dict_values[idx]; - RETURN_NOT_OK(helper.PrepareNextInput(val.len)); + RETURN_NOT_OK(indexInBounds(idx)); + const auto& val = dictValues[idx]; + RETURN_NOT_OK(helper.prepareNextInput(val.len)); RETURN_NOT_OK(helper.Append(val.ptr, static_cast(val.len))); } - values_decoded += num_indices; + valuesDecoded += numIndices; } - *out_num_values = values_decoded; + *outNumValues = valuesDecoded; return Status::OK(); } template - Status DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - BuilderType* builder, - int* out_num_values) { + Status decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + BuilderType* Builder, + int* outNumValues) { constexpr int32_t kBufferSize = 1024; int32_t indices[kBufferSize]; - RETURN_NOT_OK(builder->Reserve(num_values)); - ::arrow::internal::BitmapReader bit_reader( - valid_bits, valid_bits_offset, num_values); + RETURN_NOT_OK(Builder->Reserve(numValues)); + ::arrow::internal::BitmapReader bitReader( + validBits, validBitsOffset, numValues); - auto dict_values = reinterpret_cast(dictionary_->data()); + auto dictValues = reinterpret_cast(dictionary_->data()); - int values_decoded = 0; - int num_appended = 0; - while (num_appended < num_values) { - bool is_valid = bit_reader.IsSet(); - bit_reader.Next(); + int valuesDecoded = 0; + int numAppended = 0; + while (numAppended < numValues) { + bool isValid = bitReader.IsSet(); + bitReader.Next(); - if (is_valid) { - int32_t batch_size = std::min( - kBufferSize, num_values - num_appended - null_count); - int num_indices = idx_decoder_.GetBatch(indices, batch_size); + if (isValid) { + int32_t batchSize = + std::min(kBufferSize, numValues - numAppended - nullCount); + int numIndices = idxDecoder_.GetBatch(indices, batchSize); int i = 0; while (true) { - // Consume all indices - if (is_valid) { + // Consume all indices. + if (isValid) { auto idx = indices[i]; - RETURN_NOT_OK(IndexInBounds(idx)); - const auto& val = dict_values[idx]; - RETURN_NOT_OK(builder->Append(val.ptr, val.len)); + RETURN_NOT_OK(indexInBounds(idx)); + const auto& val = dictValues[idx]; + RETURN_NOT_OK(Builder->Append(val.ptr, val.len)); ++i; - ++values_decoded; + ++valuesDecoded; } else { - RETURN_NOT_OK(builder->AppendNull()); - --null_count; + RETURN_NOT_OK(Builder->AppendNull()); + --nullCount; } - ++num_appended; - if (i == num_indices) { - // Do not advance the bit_reader if we have fulfilled the decode - // request + ++numAppended; + if (i == numIndices) { + // Do not advance the bit_reader if we have fulfilled the decode. + // Request. break; } - is_valid = bit_reader.IsSet(); - bit_reader.Next(); + isValid = bitReader.IsSet(); + bitReader.Next(); } } else { - RETURN_NOT_OK(builder->AppendNull()); - --null_count; - ++num_appended; + RETURN_NOT_OK(Builder->AppendNull()); + --nullCount; + ++numAppended; } } - *out_num_values = values_decoded; + *outNumValues = valuesDecoded; return Status::OK(); } template - Status DecodeArrowNonNull( - int num_values, - BuilderType* builder, - int* out_num_values) { + Status + decodeArrowNonNull(int numValues, BuilderType* Builder, int* outNumValues) { constexpr int32_t kBufferSize = 2048; int32_t indices[kBufferSize]; - RETURN_NOT_OK(builder->Reserve(num_values)); + RETURN_NOT_OK(Builder->Reserve(numValues)); - auto dict_values = reinterpret_cast(dictionary_->data()); + auto dictValues = reinterpret_cast(dictionary_->data()); - int values_decoded = 0; - while (values_decoded < num_values) { - int32_t batch_size = - std::min(kBufferSize, num_values - values_decoded); - int num_indices = idx_decoder_.GetBatch(indices, batch_size); - if (num_indices == 0) - ParquetException::EofException(); - for (int i = 0; i < num_indices; ++i) { + int valuesDecoded = 0; + while (valuesDecoded < numValues) { + int32_t batchSize = + std::min(kBufferSize, numValues - valuesDecoded); + int numIndices = idxDecoder_.GetBatch(indices, batchSize); + if (numIndices == 0) + ParquetException::eofException(); + for (int i = 0; i < numIndices; ++i) { auto idx = indices[i]; - RETURN_NOT_OK(IndexInBounds(idx)); - const auto& val = dict_values[idx]; - RETURN_NOT_OK(builder->Append(val.ptr, val.len)); + RETURN_NOT_OK(indexInBounds(idx)); + const auto& val = dictValues[idx]; + RETURN_NOT_OK(Builder->Append(val.ptr, val.len)); } - values_decoded += num_indices; + valuesDecoded += numIndices; } - *out_num_values = values_decoded; + *outNumValues = valuesDecoded; return Status::OK(); } }; -// ---------------------------------------------------------------------- -// DeltaBitPackEncoder +// ----------------------------------------------------------------------. +// DeltaBitPackEncoder. -/// DeltaBitPackEncoder is an encoder for the DeltaBinary Packing format -/// as per the parquet spec. See: +/// DeltaBitPackEncoder is an encoder for the DeltaBinary Packing format. +/// As per the parquet spec. See: /// https://github.com/apache/parquet-format/blob/master/Encodings.md#delta-encoding-delta_binary_packed--5 /// -/// Consists of a header followed by blocks of delta encoded values binary -/// packed. +/// Consists of a header followed by blocks of delta encoded values binary. +/// Packed. /// -/// Format -/// [header] [block 1] [block 2] ... [block N] +/// Format. +/// [Header] [block 1] [block 2] ... [block N]. /// -/// Header -/// [block size] [number of mini blocks per block] [total value count] [first -/// value] +/// Header. +/// [Block size] [number of mini blocks per block] [total value count] +/// [first. Value]. /// -/// Block -/// [min delta] [list of bitwidths of the mini blocks] [miniblocks] +/// Block. +/// [Min delta] [list of bitwidths of the mini blocks] [miniblocks]. /// -/// Sets aside bytes at the start of the internal buffer where the header will -/// be written, and only writes the header when FlushValues is called before -/// returning it. +/// Sets aside bytes at the start of the internal buffer where the header will. +/// Be written, and only writes the header when FlushValues is called before. +/// Returning it. /// /// To encode a block, we will: /// -/// 1. Compute the differences between consecutive elements. For the first -/// element in the block, use the last element in the previous block or, in the -/// case of the first block, use the first value of the whole sequence, stored -/// in the header. +/// 1. Compute the differences between consecutive elements. For the first. +/// Element in the block, use the last element in the previous block or, in the. +/// Case of the first block, use the first value of the whole sequence, stored. +/// In the header. /// /// 2. Compute the frame of reference (the minimum of the deltas in the block). -/// Subtract this min delta from all deltas in the block. This guarantees that -/// all values are non-negative. +/// Subtract this min delta from all deltas in the block. This guarantees that. +/// All values are non-negative. /// -/// 3. Encode the frame of reference (min delta) as a zigzag ULEB128 int -/// followed by the bit widths of the mini blocks and the delta values (minus -/// the min delta) bit packed per mini block. +/// 3. Encode the frame of reference (min delta) as a zigzag ULEB128 int. +/// Followed by the bit widths of the mini blocks and the delta values (minus. +/// The min delta) bit packed per mini block. /// /// Supports only INT32 and INT64. template class DeltaBitPackEncoder : public EncoderImpl, virtual public TypedEncoder { - // Maximum possible header size + // Maximum possible header size. static constexpr uint32_t kMaxPageHeaderWriterSize = 32; static constexpr uint32_t kValuesPerBlock = - std::is_same_v ? 128 : 256; + std::is_same_v ? 128 : 256; static constexpr uint32_t kMiniBlocksPerBlock = 4; public: - using T = typename DType::c_type; + using T = typename DType::CType; using UT = std::make_unsigned_t; - using TypedEncoder::Put; + using TypedEncoder::put; explicit DeltaBitPackEncoder( const ColumnDescriptor* descr, MemoryPool* pool, - const uint32_t values_per_block = kValuesPerBlock, - const uint32_t mini_blocks_per_block = kMiniBlocksPerBlock) - : EncoderImpl(descr, Encoding::DELTA_BINARY_PACKED, pool), - values_per_block_(values_per_block), - mini_blocks_per_block_(mini_blocks_per_block), - values_per_mini_block_(values_per_block / mini_blocks_per_block), - deltas_(values_per_block, ::arrow::stl::allocator(pool)), - bits_buffer_(AllocateBuffer( + const uint32_t valuesPerBlock = kValuesPerBlock, + const uint32_t miniBlocksPerBlock = kMiniBlocksPerBlock) + : EncoderImpl(descr, Encoding::kDeltaBinaryPacked, pool), + valuesPerBlock_(valuesPerBlock), + miniBlocksPerBlock_(miniBlocksPerBlock), + valuesPerMiniBlock_(valuesPerBlock / miniBlocksPerBlock), + deltas_(valuesPerBlock, ::arrow::stl::allocator(pool)), + bitsBuffer_(allocateBuffer( pool, - (kMiniBlocksPerBlock + values_per_block) * sizeof(T))), + (kMiniBlocksPerBlock + valuesPerBlock) * sizeof(T))), sink_(pool), - bit_writer_( - bits_buffer_->mutable_data(), - static_cast(bits_buffer_->size())) { - if (values_per_block_ % 128 != 0) { + bitWriter_( + bitsBuffer_->mutable_data(), + static_cast(bitsBuffer_->size())) { + if (valuesPerBlock_ % 128 != 0) { throw ParquetException( "the number of values in a block must be multiple of 128, but it's " + - std::to_string(values_per_block_)); + std::to_string(valuesPerBlock_)); } - if (values_per_mini_block_ % 32 != 0) { + if (valuesPerMiniBlock_ % 32 != 0) { throw ParquetException( "the number of values in a miniblock must be multiple of 32, but it's " + - std::to_string(values_per_mini_block_)); + std::to_string(valuesPerMiniBlock_)); } - if (values_per_block % mini_blocks_per_block != 0) { + if (valuesPerBlock % miniBlocksPerBlock != 0) { throw ParquetException( "the number of values per block % number of miniblocks per block must be 0, " "but it's " + - std::to_string(values_per_block % mini_blocks_per_block)); + std::to_string(valuesPerBlock % miniBlocksPerBlock)); } - // Reserve enough space at the beginning of the buffer for largest possible - // header. + // Reserve enough space at the beginning of the buffer for largest possible. + // Header. PARQUET_THROW_NOT_OK(sink_.Advance(kMaxPageHeaderWriterSize)); } - std::shared_ptr<::arrow::Buffer> FlushValues() override; + std::shared_ptr<::arrow::Buffer> flushValues() override; - int64_t EstimatedDataEncodedSize() override { + int64_t estimatedDataEncodedSize() override { return sink_.length(); } - void Put(const ::arrow::Array& values) override; + void put(const ::arrow::Array& values) override; - void Put(const T* buffer, int num_values) override; + void put(const T* buffer, int numValues) override; - void PutSpaced( + void putSpaced( const T* src, - int num_values, - const uint8_t* valid_bits, - int64_t valid_bits_offset) override; + int numValues, + const uint8_t* validBits, + int64_t validBitsOffset) override; - void FlushBlock(); + void flushBlock(); private: - const uint32_t values_per_block_; - const uint32_t mini_blocks_per_block_; - const uint32_t values_per_mini_block_; - uint32_t values_current_block_{0}; - uint32_t total_value_count_{0}; - UT first_value_{0}; - UT current_value_{0}; + const uint32_t valuesPerBlock_; + const uint32_t miniBlocksPerBlock_; + const uint32_t valuesPerMiniBlock_; + uint32_t valuesCurrentBlock_{0}; + uint32_t totalValueCount_{0}; + UT firstValue_{0}; + UT currentValue_{0}; ArrowPoolVector deltas_; - std::shared_ptr bits_buffer_; + std::shared_ptr bitsBuffer_; ::arrow::BufferBuilder sink_; - BitWriter bit_writer_; + BitWriter bitWriter_; }; template -void DeltaBitPackEncoder::Put(const T* src, int num_values) { - if (num_values == 0) { +void DeltaBitPackEncoder::put(const T* src, int numValues) { + if (numValues == 0) { return; } int idx = 0; - if (total_value_count_ == 0) { - current_value_ = src[0]; - first_value_ = current_value_; + if (totalValueCount_ == 0) { + currentValue_ = src[0]; + firstValue_ = currentValue_; idx = 1; } - total_value_count_ += num_values; + totalValueCount_ += numValues; - while (idx < num_values) { + while (idx < numValues) { UT value = static_cast(src[idx]); - // Calculate deltas. The possible overflow is handled by use of unsigned - // integers making subtraction operations well-defined and correct even in - // case of overflow. Encoded integers will wrap back around on decoding. See - // http://en.wikipedia.org/wiki/Modular_arithmetic#Integers_modulo_n - deltas_[values_current_block_] = value - current_value_; - current_value_ = value; + // Calculate deltas. The possible overflow is handled by use of unsigned. + // Integers making subtraction operations well-defined and correct even in. + // Case of overflow. Encoded integers will wrap back around on decoding. + // See. http://en.wikipedia.org/wiki/Modular_arithmetic#Integers_modulo_n + deltas_[valuesCurrentBlock_] = value - currentValue_; + currentValue_ = value; idx++; - values_current_block_++; - if (values_current_block_ == values_per_block_) { - FlushBlock(); + valuesCurrentBlock_++; + if (valuesCurrentBlock_ == valuesPerBlock_) { + flushBlock(); } } } template -void DeltaBitPackEncoder::FlushBlock() { - if (values_current_block_ == 0) { +void DeltaBitPackEncoder::flushBlock() { + if (valuesCurrentBlock_ == 0) { return; } - const UT min_delta = *std::min_element( - deltas_.begin(), deltas_.begin() + values_current_block_); - bit_writer_.PutZigZagVlqInt(static_cast(min_delta)); + const UT minDelta = + *std::min_element(deltas_.begin(), deltas_.begin() + valuesCurrentBlock_); + bitWriter_.PutZigZagVlqInt(static_cast(minDelta)); - // Call to GetNextBytePtr reserves mini_blocks_per_block_ bytes of space to - // write bit widths of miniblocks as they become known during the encoding. - uint8_t* bit_width_data = bit_writer_.GetNextBytePtr(mini_blocks_per_block_); - VELOX_DCHECK(bit_width_data != nullptr); + // Call to GetNextBytePtr reserves mini_blocks_per_block_ bytes of space to. + // Write bit widths of miniblocks as they become known during the encoding. + uint8_t* bitWidthData = bitWriter_.GetNextBytePtr(miniBlocksPerBlock_); + VELOX_DCHECK(bitWidthData != nullptr); - const uint32_t num_miniblocks = static_cast(std::ceil( - static_cast(values_current_block_) / - static_cast(values_per_mini_block_))); - for (uint32_t i = 0; i < num_miniblocks; i++) { - const uint32_t values_current_mini_block = - std::min(values_per_mini_block_, values_current_block_); + const uint32_t numMiniblocks = static_cast(std::ceil( + static_cast(valuesCurrentBlock_) / + static_cast(valuesPerMiniBlock_))); + for (uint32_t i = 0; i < numMiniblocks; i++) { + const uint32_t valuesCurrentMiniBlock = + std::min(valuesPerMiniBlock_, valuesCurrentBlock_); - const uint32_t start = i * values_per_mini_block_; - const UT max_delta = *std::max_element( + const uint32_t start = i * valuesPerMiniBlock_; + const UT maxDelta = *std::max_element( deltas_.begin() + start, - deltas_.begin() + start + values_current_mini_block); + deltas_.begin() + start + valuesCurrentMiniBlock); - // The minimum number of bits required to write any of values in deltas_ - // vector. See overflow comment above. - const auto bit_width = bit_width_data[i] = - ::arrow::bit_util::NumRequiredBits(max_delta - min_delta); + // The minimum number of bits required to write any of values in deltas_. + // Vector. See overflow comment above. + const auto bitWidth = bitWidthData[i] = + ::arrow::bit_util::NumRequiredBits(maxDelta - minDelta); - for (uint32_t j = start; j < start + values_current_mini_block; j++) { + for (uint32_t j = start; j < start + valuesCurrentMiniBlock; j++) { // See overflow comment above. - const UT value = deltas_[j] - min_delta; - bit_writer_.PutValue(value, bit_width); + const UT value = deltas_[j] - minDelta; + bitWriter_.PutValue(value, bitWidth); } - // If there are not enough values to fill the last mini block, we pad the - // mini block with zeroes so that its length is the number of values in a - // full mini block multiplied by the bit width. - for (uint32_t j = values_current_mini_block; j < values_per_mini_block_; - j++) { - bit_writer_.PutValue(0, bit_width); + // If there are not enough values to fill the last mini block, we pad the. + // Mini block with zeroes so that its length is the number of values in a. + // Full mini block multiplied by the bit width. + for (uint32_t j = valuesCurrentMiniBlock; j < valuesPerMiniBlock_; j++) { + bitWriter_.PutValue(0, bitWidth); } - values_current_block_ -= values_current_mini_block; + valuesCurrentBlock_ -= valuesCurrentMiniBlock; } - // If, in the last block, less than - // miniblocks are needed to store the values, the bytes storing the bit widths - // of the unneeded miniblocks are still present, their value should be zero, - // but readers must accept arbitrary values as well. - for (uint32_t i = num_miniblocks; i < mini_blocks_per_block_; i++) { - bit_width_data[i] = 0; + // If, in the last block, less than . + // Miniblocks are needed to store the values, the bytes storing the bit + // widths. Of the unneeded miniblocks are still present, their value should be + // zero,. But readers must accept arbitrary values as well. + for (uint32_t i = numMiniblocks; i < miniBlocksPerBlock_; i++) { + bitWidthData[i] = 0; } - VELOX_DCHECK_EQ(values_current_block_, 0); + VELOX_DCHECK_EQ(valuesCurrentBlock_, 0); - bit_writer_.Flush(); + bitWriter_.Flush(); PARQUET_THROW_NOT_OK( - sink_.Append(bit_writer_.buffer(), bit_writer_.bytesWritten())); - bit_writer_.Clear(); + sink_.Append(bitWriter_.buffer(), bitWriter_.bytesWritten())); + bitWriter_.Clear(); } template -std::shared_ptr<::arrow::Buffer> DeltaBitPackEncoder::FlushValues() { - if (values_current_block_ > 0) { - FlushBlock(); - } - PARQUET_ASSIGN_OR_THROW(auto buffer, sink_.Finish(/*shrink_to_fit=*/true)); - - uint8_t header_buffer_[kMaxPageHeaderWriterSize] = {}; - BitWriter header_writer(header_buffer_, sizeof(header_buffer_)); - if (!header_writer.PutVlqInt(values_per_block_) || - !header_writer.PutVlqInt(mini_blocks_per_block_) || - !header_writer.PutVlqInt(total_value_count_) || - !header_writer.PutZigZagVlqInt(static_cast(first_value_))) { +std::shared_ptr<::arrow::Buffer> DeltaBitPackEncoder::flushValues() { + if (valuesCurrentBlock_ > 0) { + flushBlock(); + } + PARQUET_ASSIGN_OR_THROW(auto buffer, sink_.Finish(true)); + + uint8_t headerBuffer_[kMaxPageHeaderWriterSize] = {}; + BitWriter headerWriter(headerBuffer_, sizeof(headerBuffer_)); + if (!headerWriter.PutVlqInt(valuesPerBlock_) || + !headerWriter.PutVlqInt(miniBlocksPerBlock_) || + !headerWriter.PutVlqInt(totalValueCount_) || + !headerWriter.PutZigZagVlqInt(static_cast(firstValue_))) { throw ParquetException("header writing error"); } - header_writer.Flush(); + headerWriter.Flush(); - // We reserved enough space at the beginning of the buffer for largest - // possible header and data was written immediately after. We now write the - // header data immediately before the end of reserved space. - const size_t offset_bytes = - kMaxPageHeaderWriterSize - header_writer.bytesWritten(); + // We reserved enough space at the beginning of the buffer for largest. + // Possible header and data was written immediately after. We now write the. + // Header data immediately before the end of reserved space. + const size_t offsetBytes = + kMaxPageHeaderWriterSize - headerWriter.bytesWritten(); std::memcpy( - buffer->mutable_data() + offset_bytes, - header_buffer_, - header_writer.bytesWritten()); - - // Reset counter of cached values - total_value_count_ = 0; - // Reserve enough space at the beginning of the buffer for largest possible - // header. + buffer->mutable_data() + offsetBytes, + headerBuffer_, + headerWriter.bytesWritten()); + + // Reset counter of cached values. + totalValueCount_ = 0; + // Reserve enough space at the beginning of the buffer for largest possible. + // Header. PARQUET_THROW_NOT_OK(sink_.Advance(kMaxPageHeaderWriterSize)); // Excess bytes at the beginning are sliced off and ignored. - return SliceBuffer(buffer, offset_bytes); + return ::arrow::SliceBuffer(buffer, offsetBytes); } template <> -void DeltaBitPackEncoder::Put(const ::arrow::Array& values) { +void DeltaBitPackEncoder::put(const ::arrow::Array& values) { const ::arrow::ArrayData& data = *values.data(); if (values.type_id() != ::arrow::Type::INT32) { throw ParquetException( @@ -2700,9 +2666,9 @@ void DeltaBitPackEncoder::Put(const ::arrow::Array& values) { } if (values.null_count() == 0) { - Put(data.GetValues(1), static_cast(data.length)); + put(data.GetValues(1), static_cast(data.length)); } else { - PutSpaced( + putSpaced( data.GetValues(1), static_cast(data.length), data.GetValues(0, 0), @@ -2711,7 +2677,7 @@ void DeltaBitPackEncoder::Put(const ::arrow::Array& values) { } template <> -void DeltaBitPackEncoder::Put(const ::arrow::Array& values) { +void DeltaBitPackEncoder::put(const ::arrow::Array& values) { const ::arrow::ArrayData& data = *values.data(); if (values.type_id() != ::arrow::Type::INT64) { throw ParquetException( @@ -2722,9 +2688,9 @@ void DeltaBitPackEncoder::Put(const ::arrow::Array& values) { "Array cannot be longer than ", std::numeric_limits::max()); } if (values.null_count() == 0) { - Put(data.GetValues(1), static_cast(data.length)); + put(data.GetValues(1), static_cast(data.length)); } else { - PutSpaced( + putSpaced( data.GetValues(1), static_cast(data.length), data.GetValues(0, 0), @@ -2733,280 +2699,277 @@ void DeltaBitPackEncoder::Put(const ::arrow::Array& values) { } template -void DeltaBitPackEncoder::PutSpaced( +void DeltaBitPackEncoder::putSpaced( const T* src, - int num_values, - const uint8_t* valid_bits, - int64_t valid_bits_offset) { - if (valid_bits != NULLPTR) { - PARQUET_ASSIGN_OR_THROW( - auto buffer, - ::arrow::AllocateBuffer(num_values * sizeof(T), this->memory_pool())); + int numValues, + const uint8_t* validBits, + int64_t validBitsOffset) { + if (validBits != NULLPTR) { + auto buffer = allocateBuffer(this->memoryPool(), numValues * sizeof(T)); T* data = reinterpret_cast(buffer->mutable_data()); - int num_valid_values = ::arrow::util::internal::SpacedCompress( - src, num_values, valid_bits, valid_bits_offset, data); - Put(data, num_valid_values); + int numValidValues = ::arrow::util::internal::SpacedCompress( + src, numValues, validBits, validBitsOffset, data); + put(data, numValidValues); } else { - Put(src, num_values); + put(src, numValues); } } -// ---------------------------------------------------------------------- -// DeltaBitPackDecoder +// ----------------------------------------------------------------------. +// DeltaBitPackDecoder. template class DeltaBitPackDecoder : public DecoderImpl, virtual public TypedDecoder { public: - typedef typename DType::c_type T; + typedef typename DType::CType T; using UT = std::make_unsigned_t; explicit DeltaBitPackDecoder( const ColumnDescriptor* descr, MemoryPool* pool = ::arrow::default_memory_pool()) - : DecoderImpl(descr, Encoding::DELTA_BINARY_PACKED), pool_(pool) { - if (DType::type_num != Type::INT32 && DType::type_num != Type::INT64) { + : DecoderImpl(descr, Encoding::kDeltaBinaryPacked), pool_(pool) { + if (DType::typeNum != Type::kInt32 && DType::typeNum != Type::kInt64) { throw ParquetException( "Delta bit pack encoding should only be for integer data."); } } - void SetData(int num_values, const uint8_t* data, int len) override { - // num_values is equal to page's num_values, including null values in this - // page - this->num_values_ = num_values; + void setData(int numValues, const uint8_t* data, int len) override { + // Num_values is equal to page's num_values, including null values in this. + // Page. + this->numValues_ = numValues; decoder_ = std::make_shared(data, len); - InitHeader(); + initHeader(); } - // Set BitReader which is already initialized by DeltaLengthByteArrayDecoder - // or DeltaByteArrayDecoder - void SetDecoder(int num_values, std::shared_ptr decoder) { - this->num_values_ = num_values; + // Set BitReader which is already initialized by DeltaLengthByteArrayDecoder. + // Or DeltaByteArrayDecoder. + void setDecoder(int numValues, std::shared_ptr decoder) { + this->numValues_ = numValues; decoder_ = std::move(decoder); - InitHeader(); + initHeader(); } - int ValidValuesCount() { - // total_values_remaining_ in header ignores of null values - return static_cast(total_values_remaining_); + int validValuesCount() { + // Total_values_remaining_ in header ignores of null values. + return static_cast(totalValuesRemaining_); } - int Decode(T* buffer, int max_values) override { - return GetInternal(buffer, max_values); + int decode(T* buffer, int maxValues) override { + return getInternal(buffer, maxValues); } - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::Accumulator* out) override { - if (null_count != 0) { + if (nullCount != 0) { // TODO(ARROW-34660): implement DecodeArrow with null slots. ParquetException::NYI("Delta bit pack DecodeArrow with null slots"); } - std::vector values(num_values); - int decoded_count = GetInternal(values.data(), num_values); - PARQUET_THROW_NOT_OK(out->AppendValues(values.data(), decoded_count)); - return decoded_count; + std::vector values(numValues); + int decodedCount = getInternal(values.data(), numValues); + PARQUET_THROW_NOT_OK(out->AppendValues(values.data(), decodedCount)); + return decodedCount; } - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::DictAccumulator* out) override { - if (null_count != 0) { + if (nullCount != 0) { // TODO(ARROW-34660): implement DecodeArrow with null slots. ParquetException::NYI("Delta bit pack DecodeArrow with null slots"); } - std::vector values(num_values); - int decoded_count = GetInternal(values.data(), num_values); - PARQUET_THROW_NOT_OK(out->Reserve(decoded_count)); - for (int i = 0; i < decoded_count; ++i) { + std::vector values(numValues); + int decodedCount = getInternal(values.data(), numValues); + PARQUET_THROW_NOT_OK(out->Reserve(decodedCount)); + for (int i = 0; i < decodedCount; ++i) { PARQUET_THROW_NOT_OK(out->Append(values[i])); } - return decoded_count; + return decodedCount; } private: static constexpr int kMaxDeltaBitWidth = static_cast(sizeof(T) * 8); - void InitHeader() { - if (!decoder_->GetVlqInt(&values_per_block_) || - !decoder_->GetVlqInt(&mini_blocks_per_block_) || - !decoder_->GetVlqInt(&total_value_count_) || - !decoder_->GetZigZagVlqInt(&last_value_)) { - ParquetException::EofException("InitHeader EOF"); + void initHeader() { + if (!decoder_->GetVlqInt(&valuesPerBlock_) || + !decoder_->GetVlqInt(&miniBlocksPerBlock_) || + !decoder_->GetVlqInt(&totalValueCount_) || + !decoder_->GetZigZagVlqInt(&lastValue_)) { + ParquetException::eofException("InitHeader EOF"); } - if (values_per_block_ == 0) { + if (valuesPerBlock_ == 0) { throw ParquetException("cannot have zero value per block"); } - if (values_per_block_ % 128 != 0) { + if (valuesPerBlock_ % 128 != 0) { throw ParquetException( "the number of values in a block must be multiple of 128, but it's " + - std::to_string(values_per_block_)); + std::to_string(valuesPerBlock_)); } - if (mini_blocks_per_block_ == 0) { + if (miniBlocksPerBlock_ == 0) { throw ParquetException("cannot have zero miniblock per block"); } - values_per_mini_block_ = values_per_block_ / mini_blocks_per_block_; - if (values_per_mini_block_ == 0) { + valuesPerMiniBlock_ = valuesPerBlock_ / miniBlocksPerBlock_; + if (valuesPerMiniBlock_ == 0) { throw ParquetException("cannot have zero value per miniblock"); } - if (values_per_mini_block_ % 32 != 0) { + if (valuesPerMiniBlock_ % 32 != 0) { throw ParquetException( "the number of values in a miniblock must be multiple of 32, but it's " + - std::to_string(values_per_mini_block_)); + std::to_string(valuesPerMiniBlock_)); } - total_values_remaining_ = total_value_count_; - if (delta_bit_widths_ == nullptr) { - delta_bit_widths_ = AllocateBuffer(pool_, mini_blocks_per_block_); + totalValuesRemaining_ = totalValueCount_; + if (deltaBitWidths_ == nullptr) { + deltaBitWidths_ = allocateBuffer(pool_, miniBlocksPerBlock_); } else { - PARQUET_THROW_NOT_OK(delta_bit_widths_->Resize( - mini_blocks_per_block_, /*shrink_to_fit*/ false)); + PARQUET_THROW_NOT_OK(deltaBitWidths_->Resize( + miniBlocksPerBlock_, /*shrink_to_fit*/ false)); } - first_block_initialized_ = false; - values_remaining_current_mini_block_ = 0; + firstBlockInitialized_ = false; + valuesRemainingCurrentMiniBlock_ = 0; } - void InitBlock() { - VELOX_DCHECK_GT(total_values_remaining_, 0, "InitBlock called at EOF"); + void initBlock() { + VELOX_DCHECK_GT(totalValuesRemaining_, 0, "InitBlock called at EOF"); - if (!decoder_->GetZigZagVlqInt(&min_delta_)) - ParquetException::EofException("InitBlock EOF"); + if (!decoder_->GetZigZagVlqInt(&minDelta_)) + ParquetException::eofException("InitBlock EOF"); - // read the bitwidth of each miniblock - uint8_t* bit_width_data = delta_bit_widths_->mutable_data(); - for (uint32_t i = 0; i < mini_blocks_per_block_; ++i) { - if (!decoder_->GetAligned(1, bit_width_data + i)) { - ParquetException::EofException("Decode bit-width EOF"); + // Read the bitwidth of each miniblock. + uint8_t* bitWidthData = deltaBitWidths_->mutable_data(); + for (uint32_t i = 0; i < miniBlocksPerBlock_; ++i) { + if (!decoder_->GetAligned(1, bitWidthData + i)) { + ParquetException::eofException("Decode bit-width EOF"); } - // Note that non-conformant bitwidth entries are allowed by the Parquet - // spec for extraneous miniblocks in the last block (GH-14923), so we - // check the bitwidths when actually using them (see InitMiniBlock()). + // Note that non-conformant bitwidth entries are allowed by the Parquet. + // Spec for extraneous miniblocks in the last block (GH-14923), so we. + // Check the bitwidths when actually using them (see InitMiniBlock()). } - mini_block_idx_ = 0; - first_block_initialized_ = true; - InitMiniBlock(bit_width_data[0]); + miniBlockIdx_ = 0; + firstBlockInitialized_ = true; + initMiniBlock(bitWidthData[0]); } - void InitMiniBlock(int bit_width) { - if (ARROW_PREDICT_FALSE(bit_width > kMaxDeltaBitWidth)) { + void initMiniBlock(int bitWidth) { + if (ARROW_PREDICT_FALSE(bitWidth > kMaxDeltaBitWidth)) { throw ParquetException("delta bit width larger than integer bit width"); } - delta_bit_width_ = bit_width; - values_remaining_current_mini_block_ = values_per_mini_block_; + deltaBitWidth_ = bitWidth; + valuesRemainingCurrentMiniBlock_ = valuesPerMiniBlock_; } - int GetInternal(T* buffer, int max_values) { - max_values = static_cast( - std::min(max_values, total_values_remaining_)); - if (max_values == 0) { + int getInternal(T* buffer, int maxValues) { + maxValues = + static_cast(std::min(maxValues, totalValuesRemaining_)); + if (maxValues == 0) { return 0; } int i = 0; - if (ARROW_PREDICT_FALSE(!first_block_initialized_)) { - // This is the first time we decode this data page, first output the - // last value and initialize the first block. - buffer[i++] = last_value_; - if (ARROW_PREDICT_FALSE(i == max_values)) { + if (ARROW_PREDICT_FALSE(!firstBlockInitialized_)) { + // This is the first time we decode this data page, first output the. + // Last value and initialize the first block. + buffer[i++] = lastValue_; + if (ARROW_PREDICT_FALSE(i == maxValues)) { // When i reaches max_values here we have two different possibilities: - // 1. total_value_count_ == 1, which means that the page may have only - // one value (encoded in the header), and we should not initialize - // any block, nor should we skip any padding bits below. - // 2. total_value_count_ != 1, which means we should initialize the - // incoming block for subsequent reads. - if (total_value_count_ != 1) { - InitBlock(); + // 1. Total_value_count_ == 1, which means that the page may have only. + // One value (encoded in the header), and we should not initialize. + // Any block, nor should we skip any padding bits below. + // 2. Total_value_count_ != 1, which means we should initialize the. + // Incoming block for subsequent reads. + if (totalValueCount_ != 1) { + initBlock(); } - total_values_remaining_ -= max_values; - this->num_values_ -= max_values; - return max_values; + totalValuesRemaining_ -= maxValues; + this->numValues_ -= maxValues; + return maxValues; } - InitBlock(); + initBlock(); } - VELOX_DCHECK(first_block_initialized_); - while (i < max_values) { - // Ensure we have an initialized mini-block - if (ARROW_PREDICT_FALSE(values_remaining_current_mini_block_ == 0)) { - ++mini_block_idx_; - if (mini_block_idx_ < mini_blocks_per_block_) { - InitMiniBlock(delta_bit_widths_->data()[mini_block_idx_]); + VELOX_DCHECK(firstBlockInitialized_); + while (i < maxValues) { + // Ensure we have an initialized mini-block. + if (ARROW_PREDICT_FALSE(valuesRemainingCurrentMiniBlock_ == 0)) { + ++miniBlockIdx_; + if (miniBlockIdx_ < miniBlocksPerBlock_) { + initMiniBlock(deltaBitWidths_->data()[miniBlockIdx_]); } else { - InitBlock(); + initBlock(); } } - int values_decode = std::min( - values_remaining_current_mini_block_, - static_cast(max_values - i)); - if (decoder_->GetBatch(delta_bit_width_, buffer + i, values_decode) != - values_decode) { - ParquetException::EofException(); + int valuesDecode = std::min( + valuesRemainingCurrentMiniBlock_, + static_cast(maxValues - i)); + if (decoder_->GetBatch(deltaBitWidth_, buffer + i, valuesDecode) != + valuesDecode) { + ParquetException::eofException(); } - for (int j = 0; j < values_decode; ++j) { - // Addition between min_delta, packed int and last_value should be - // treated as unsigned addition. Overflow is as expected. - buffer[i + j] = static_cast(min_delta_) + - static_cast(buffer[i + j]) + static_cast(last_value_); - last_value_ = buffer[i + j]; + for (int j = 0; j < valuesDecode; ++j) { + // Addition between min_delta, packed int and last_value should be. + // Treated as unsigned addition. Overflow is as expected. + buffer[i + j] = static_cast(minDelta_) + + static_cast(buffer[i + j]) + static_cast(lastValue_); + lastValue_ = buffer[i + j]; } - values_remaining_current_mini_block_ -= values_decode; - i += values_decode; + valuesRemainingCurrentMiniBlock_ -= valuesDecode; + i += valuesDecode; } - total_values_remaining_ -= max_values; - this->num_values_ -= max_values; - - if (ARROW_PREDICT_FALSE(total_values_remaining_ == 0)) { - uint32_t padding_bits = - values_remaining_current_mini_block_ * delta_bit_width_; - // skip the padding bits - if (!decoder_->Advance(padding_bits)) { - ParquetException::EofException(); + totalValuesRemaining_ -= maxValues; + this->numValues_ -= maxValues; + + if (ARROW_PREDICT_FALSE(totalValuesRemaining_ == 0)) { + uint32_t paddingBits = valuesRemainingCurrentMiniBlock_ * deltaBitWidth_; + // Skip the padding bits. + if (!decoder_->Advance(paddingBits)) { + ParquetException::eofException(); } - values_remaining_current_mini_block_ = 0; + valuesRemainingCurrentMiniBlock_ = 0; } - return max_values; + return maxValues; } MemoryPool* pool_; std::shared_ptr decoder_; - uint32_t values_per_block_; - uint32_t mini_blocks_per_block_; - uint32_t values_per_mini_block_; - uint32_t total_value_count_; - - uint32_t total_values_remaining_; - // Remaining values in current mini block. If the current block is the last - // mini block, values_remaining_current_mini_block_ may greater than - // total_values_remaining_. - uint32_t values_remaining_current_mini_block_; - - // If the page doesn't contain any block, `first_block_initialized_` will - // always be false. Otherwise, it will be true when first block initialized. - bool first_block_initialized_; - T min_delta_; - uint32_t mini_block_idx_; - std::shared_ptr delta_bit_widths_; - int delta_bit_width_; - - T last_value_; + uint32_t valuesPerBlock_; + uint32_t miniBlocksPerBlock_; + uint32_t valuesPerMiniBlock_; + uint32_t totalValueCount_; + + uint32_t totalValuesRemaining_; + // Remaining values in current mini block. If the current block is the last. + // Mini block, values_remaining_current_mini_block_ may greater than. + // Total_values_remaining_. + uint32_t valuesRemainingCurrentMiniBlock_; + + // If the page doesn't contain any block, `first_block_initialized_` will. + // Always be false. Otherwise, it will be true when first block initialized. + bool firstBlockInitialized_; + T minDelta_; + uint32_t miniBlockIdx_; + std::shared_ptr deltaBitWidths_; + int deltaBitWidth_; + + T lastValue_; }; -// ---------------------------------------------------------------------- -// DELTA_LENGTH_BYTE_ARRAY +// ----------------------------------------------------------------------. +// DELTA_LENGTH_BYTE_ARRAY. -// ---------------------------------------------------------------------- -// DeltaLengthByteArrayEncoder +// ----------------------------------------------------------------------. +// DeltaLengthByteArrayEncoder. template class DeltaLengthByteArrayEncoder : public EncoderImpl, @@ -3017,33 +2980,33 @@ class DeltaLengthByteArrayEncoder : public EncoderImpl, MemoryPool* pool) : EncoderImpl( descr, - Encoding::DELTA_LENGTH_BYTE_ARRAY, + Encoding::kDeltaLengthByteArray, pool = ::arrow::default_memory_pool()), sink_(pool), - length_encoder_(nullptr, pool), - encoded_size_{0} {} + lengthEncoder_(nullptr, pool), + encodedSize_{0} {} - std::shared_ptr<::arrow::Buffer> FlushValues() override; + std::shared_ptr<::arrow::Buffer> flushValues() override; - int64_t EstimatedDataEncodedSize() override { - return encoded_size_ + length_encoder_.EstimatedDataEncodedSize(); + int64_t estimatedDataEncodedSize() override { + return encodedSize_ + lengthEncoder_.estimatedDataEncodedSize(); } - using TypedEncoder::Put; + using TypedEncoder::put; - void Put(const ::arrow::Array& values) override; + void put(const ::arrow::Array& values) override; - void Put(const T* buffer, int num_values) override; + void put(const T* buffer, int numValues) override; - void PutSpaced( + void putSpaced( const T* src, - int num_values, - const uint8_t* valid_bits, - int64_t valid_bits_offset) override; + int numValues, + const uint8_t* validBits, + int64_t validBitsOffset) override; protected: template - void PutBinaryArray(const ArrayType& array) { + void putBinaryArray(const ArrayType& array) { PARQUET_THROW_NOT_OK( ::arrow::VisitArraySpanInline( *array.data(), @@ -3052,7 +3015,7 @@ class DeltaLengthByteArrayEncoder : public EncoderImpl, return Status::Invalid( "Parquet cannot store strings with size 2GB or more"); } - length_encoder_.Put({static_cast(view.length())}, 1); + lengthEncoder_.put({static_cast(view.length())}, 1); PARQUET_THROW_NOT_OK(sink_.Append(view.data(), view.length())); return Status::OK(); }, @@ -3060,91 +3023,89 @@ class DeltaLengthByteArrayEncoder : public EncoderImpl, } ::arrow::BufferBuilder sink_; - DeltaBitPackEncoder length_encoder_; - uint32_t encoded_size_; + DeltaBitPackEncoder lengthEncoder_; + uint32_t encodedSize_; }; template -void DeltaLengthByteArrayEncoder::Put(const ::arrow::Array& values) { - AssertBaseBinary(values); +void DeltaLengthByteArrayEncoder::put(const ::arrow::Array& values) { + assertBaseBinary(values); if (::arrow::is_binary_like(values.type_id())) { - PutBinaryArray(checked_cast(values)); + putBinaryArray(checked_cast(values)); } else { - PutBinaryArray(checked_cast(values)); + putBinaryArray(checked_cast(values)); } } template -void DeltaLengthByteArrayEncoder::Put(const T* src, int num_values) { - if (num_values == 0) { +void DeltaLengthByteArrayEncoder::put(const T* src, int numValues) { + if (numValues == 0) { return; } constexpr int kBatchSize = 256; std::array lengths; - uint32_t total_increment_size = 0; - for (int idx = 0; idx < num_values; idx += kBatchSize) { - const int batch_size = std::min(kBatchSize, num_values - idx); - for (int j = 0; j < batch_size; ++j) { + uint32_t totalIncrementSize = 0; + for (int idx = 0; idx < numValues; idx += kBatchSize) { + const int batchSize = std::min(kBatchSize, numValues - idx); + for (int j = 0; j < batchSize; ++j) { const int32_t len = src[idx + j].len; - if (AddWithOverflow(total_increment_size, len, &total_increment_size)) { + if (addWithOverflow(totalIncrementSize, len, &totalIncrementSize)) { throw ParquetException("excess expansion in DELTA_LENGTH_BYTE_ARRAY"); } lengths[j] = len; } - length_encoder_.Put(lengths.data(), batch_size); + lengthEncoder_.put(lengths.data(), batchSize); } - if (AddWithOverflow(encoded_size_, total_increment_size, &encoded_size_)) { + if (addWithOverflow(encodedSize_, totalIncrementSize, &encodedSize_)) { throw ParquetException("excess expansion in DELTA_LENGTH_BYTE_ARRAY"); } - PARQUET_THROW_NOT_OK(sink_.Reserve(total_increment_size)); - for (int idx = 0; idx < num_values; idx++) { + PARQUET_THROW_NOT_OK(sink_.Reserve(totalIncrementSize)); + for (int idx = 0; idx < numValues; idx++) { sink_.UnsafeAppend(src[idx].ptr, src[idx].len); } } template -void DeltaLengthByteArrayEncoder::PutSpaced( +void DeltaLengthByteArrayEncoder::putSpaced( const T* src, - int num_values, - const uint8_t* valid_bits, - int64_t valid_bits_offset) { - if (valid_bits != NULLPTR) { - PARQUET_ASSIGN_OR_THROW( - auto buffer, - ::arrow::AllocateBuffer(num_values * sizeof(T), this->memory_pool())); + int numValues, + const uint8_t* validBits, + int64_t validBitsOffset) { + if (validBits != NULLPTR) { + auto buffer = allocateBuffer(this->memoryPool(), numValues * sizeof(T)); T* data = reinterpret_cast(buffer->mutable_data()); - int num_valid_values = ::arrow::util::internal::SpacedCompress( - src, num_values, valid_bits, valid_bits_offset, data); - Put(data, num_valid_values); + int numValidValues = ::arrow::util::internal::SpacedCompress( + src, numValues, validBits, validBitsOffset, data); + put(data, numValidValues); } else { - Put(src, num_values); + put(src, numValues); } } template std::shared_ptr<::arrow::Buffer> -DeltaLengthByteArrayEncoder::FlushValues() { - std::shared_ptr encoded_lengths = length_encoder_.FlushValues(); +DeltaLengthByteArrayEncoder::flushValues() { + std::shared_ptr encodedLengths = lengthEncoder_.flushValues(); std::shared_ptr data; PARQUET_THROW_NOT_OK(sink_.Finish(&data)); sink_.Reset(); - PARQUET_THROW_NOT_OK(sink_.Resize(encoded_lengths->size() + data->size())); + PARQUET_THROW_NOT_OK(sink_.Resize(encodedLengths->size() + data->size())); PARQUET_THROW_NOT_OK( - sink_.Append(encoded_lengths->data(), encoded_lengths->size())); + sink_.Append(encodedLengths->data(), encodedLengths->size())); PARQUET_THROW_NOT_OK(sink_.Append(data->data(), data->size())); std::shared_ptr buffer; PARQUET_THROW_NOT_OK(sink_.Finish(&buffer, true)); - encoded_size_ = 0; + encodedSize_ = 0; return buffer; } -// ---------------------------------------------------------------------- -// DeltaLengthByteArrayDecoder +// ----------------------------------------------------------------------. +// DeltaLengthByteArrayDecoder. class DeltaLengthByteArrayDecoder : public DecoderImpl, virtual public TypedDecoder { @@ -3152,154 +3113,151 @@ class DeltaLengthByteArrayDecoder : public DecoderImpl, explicit DeltaLengthByteArrayDecoder( const ColumnDescriptor* descr, MemoryPool* pool = ::arrow::default_memory_pool()) - : DecoderImpl(descr, Encoding::DELTA_LENGTH_BYTE_ARRAY), - len_decoder_(nullptr, pool), - buffered_length_(AllocateBuffer(pool, 0)) {} + : DecoderImpl(descr, Encoding::kDeltaLengthByteArray), + lenDecoder_(nullptr, pool), + bufferedLength_(allocateBuffer(pool, 0)) {} - void SetData(int num_values, const uint8_t* data, int len) override { - DecoderImpl::SetData(num_values, data, len); + void setData(int numValues, const uint8_t* data, int len) override { + DecoderImpl::setData(numValues, data, len); decoder_ = std::make_shared(data, len); - DecodeLengths(); + decodeLengths(); } - int Decode(ByteArray* buffer, int max_values) override { - // Decode up to `max_values` strings into an internal buffer - // and reference them into `buffer`. - max_values = std::min(max_values, num_valid_values_); - VELOX_DCHECK_GE(max_values, 0); - if (max_values == 0) { + int decode(ByteArray* buffer, int maxValues) override { + // Decode up to `max_values` strings into an internal buffer. + // And reference them into `buffer`. + maxValues = std::min(maxValues, numValidValues_); + VELOX_DCHECK_GE(maxValues, 0); + if (maxValues == 0) { return 0; } - int32_t data_size = 0; - const int32_t* length_ptr = - reinterpret_cast(buffered_length_->data()) + - length_idx_; - int bytes_offset = len_ - decoder_->bytesLeft(); - for (int i = 0; i < max_values; ++i) { - int32_t len = length_ptr[i]; + int32_t dataSize = 0; + const int32_t* lengthPtr = + reinterpret_cast(bufferedLength_->data()) + lengthIdx_; + int bytesOffset = len_ - decoder_->bytesLeft(); + for (int i = 0; i < maxValues; ++i) { + int32_t len = lengthPtr[i]; if (ARROW_PREDICT_FALSE(len < 0)) { throw ParquetException("negative string delta length"); } buffer[i].len = len; - if (AddWithOverflow(data_size, len, &data_size)) { + if (addWithOverflow(dataSize, len, &dataSize)) { throw ParquetException("excess expansion in DELTA_(LENGTH_)BYTE_ARRAY"); } } - length_idx_ += max_values; + lengthIdx_ += maxValues; if (ARROW_PREDICT_FALSE( - !decoder_->Advance(8 * static_cast(data_size)))) { - ParquetException::EofException(); + !decoder_->Advance(8 * static_cast(dataSize)))) { + ParquetException::eofException(); } - const uint8_t* data_ptr = data_ + bytes_offset; - for (int i = 0; i < max_values; ++i) { - buffer[i].ptr = data_ptr; - data_ptr += buffer[i].len; + const uint8_t* dataPtr = data_ + bytesOffset; + for (int i = 0; i < maxValues; ++i) { + buffer[i].ptr = dataPtr; + dataPtr += buffer[i].len; } - this->num_values_ -= max_values; - num_valid_values_ -= max_values; - return max_values; + this->numValues_ -= maxValues; + numValidValues_ -= maxValues; + return maxValues; } - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::Accumulator* out) override { int result = 0; - PARQUET_THROW_NOT_OK(DecodeArrowDense( - num_values, null_count, valid_bits, valid_bits_offset, out, &result)); + PARQUET_THROW_NOT_OK(decodeArrowDense( + numValues, nullCount, validBits, validBitsOffset, out, &result)); return result; } - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::DictAccumulator* out) override { ParquetException::NYI( "DecodeArrow of DictAccumulator for DeltaLengthByteArrayDecoder"); } private: - // Decode all the encoded lengths. The decoder_ will be at the start of the - // encoded data after that. - void DecodeLengths() { - len_decoder_.SetDecoder(num_values_, decoder_); - - // get the number of encoded lengths - int num_length = len_decoder_.ValidValuesCount(); - PARQUET_THROW_NOT_OK( - buffered_length_->Resize(num_length * sizeof(int32_t))); - - // call len_decoder_.Decode to decode all the lengths. - // all the lengths are buffered in buffered_length_. - VELOX_DEBUG_ONLY int ret = len_decoder_.Decode( - reinterpret_cast(buffered_length_->mutable_data()), - num_length); - VELOX_DCHECK_EQ(ret, num_length); - length_idx_ = 0; - num_valid_values_ = num_length; - } - - Status DecodeArrowDense( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + // Decode all the encoded lengths. The decoder_ will be at the start of the. + // Encoded data after that. + void decodeLengths() { + lenDecoder_.setDecoder(numValues_, decoder_); + + // Get the number of encoded lengths. + int numLength = lenDecoder_.validValuesCount(); + PARQUET_THROW_NOT_OK(bufferedLength_->Resize(numLength * sizeof(int32_t))); + + // Call len_decoder_.Decode to decode all the lengths. + // All the lengths are buffered in buffered_length_. + VELOX_DEBUG_ONLY int ret = lenDecoder_.decode( + reinterpret_cast(bufferedLength_->mutable_data()), numLength); + VELOX_DCHECK_EQ(ret, numLength); + lengthIdx_ = 0; + numValidValues_ = numLength; + } + + Status decodeArrowDense( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::Accumulator* out, - int* out_num_values) { - ArrowBinaryHelper helper(out, num_values); - RETURN_NOT_OK(helper.Prepare()); + int* outNumValues) { + ArrowBinaryHelper helper(out, numValues); + RETURN_NOT_OK(helper.prepare()); - std::vector values(num_values - null_count); - const int num_valid_values = Decode(values.data(), num_values - null_count); - if (ARROW_PREDICT_FALSE(num_values - null_count != num_valid_values)) { + std::vector values(numValues - nullCount); + const int numValidValues = decode(values.data(), numValues - nullCount); + if (ARROW_PREDICT_FALSE(numValues - nullCount != numValidValues)) { throw ParquetException( "Expected to decode ", - num_values - null_count, + numValues - nullCount, " values, but decoded ", - num_valid_values, + numValidValues, " values."); } - auto values_ptr = values.data(); - int value_idx = 0; + auto valuesPtr = values.data(); + int valueIdx = 0; RETURN_NOT_OK(VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { - const auto& val = values_ptr[value_idx]; - RETURN_NOT_OK(helper.PrepareNextInput(val.len)); + const auto& val = valuesPtr[valueIdx]; + RETURN_NOT_OK(helper.prepareNextInput(val.len)); RETURN_NOT_OK(helper.Append(val.ptr, static_cast(val.len))); - ++value_idx; + ++valueIdx; return Status::OK(); }, [&]() { RETURN_NOT_OK(helper.AppendNull()); - --null_count; + --nullCount; return Status::OK(); })); - VELOX_DCHECK_EQ(null_count, 0); - *out_num_values = num_valid_values; + VELOX_DCHECK_EQ(nullCount, 0); + *outNumValues = numValidValues; return Status::OK(); } std::shared_ptr decoder_; - DeltaBitPackDecoder len_decoder_; - int num_valid_values_{0}; - uint32_t length_idx_{0}; - std::shared_ptr buffered_length_; + DeltaBitPackDecoder lenDecoder_; + int numValidValues_{0}; + uint32_t lengthIdx_{0}; + std::shared_ptr bufferedLength_; }; -// ---------------------------------------------------------------------- -// RLE_BOOLEAN_ENCODER +// ----------------------------------------------------------------------. +// RLE_BOOLEAN_ENCODER. class RleBooleanEncoder final : public EncoderImpl, virtual public BooleanEncoder { @@ -3307,116 +3265,115 @@ class RleBooleanEncoder final : public EncoderImpl, explicit RleBooleanEncoder( const ColumnDescriptor* descr, ::arrow::MemoryPool* pool) - : EncoderImpl(descr, Encoding::RLE, pool), - buffered_append_values_(::arrow::stl::allocator(pool)) {} + : EncoderImpl(descr, Encoding::kRle, pool), + bufferedAppendValues_(::arrow::stl::allocator(pool)) {} - int64_t EstimatedDataEncodedSize() override { - return kRleLengthInBytes + MaxRleBufferSize(); + int64_t estimatedDataEncodedSize() override { + return kRleLengthInBytes + maxRleBufferSize(); } - std::shared_ptr FlushValues() override; + std::shared_ptr flushValues() override; - void Put(const T* buffer, int num_values) override; - void Put(const ::arrow::Array& values) override { + void put(const T* buffer, int numValues) override; + void put(const ::arrow::Array& values) override { if (values.type_id() != ::arrow::Type::BOOL) { throw ParquetException( "RleBooleanEncoder expects BooleanArray, got ", values.type()->ToString()); } - const auto& boolean_array = + const auto& booleanArray = checked_cast(values); if (values.null_count() == 0) { - for (int i = 0; i < boolean_array.length(); ++i) { - // null_count == 0, so just call Value directly is ok. - buffered_append_values_.push_back(boolean_array.Value(i)); + for (int i = 0; i < booleanArray.length(); ++i) { + // Null_count == 0, so just call Value directly is ok. + bufferedAppendValues_.push_back(booleanArray.Value(i)); } } else { - PARQUET_THROW_NOT_OK(::arrow::VisitArraySpanInline<::arrow::BooleanType>( - *boolean_array.data(), - [&](bool value) { - buffered_append_values_.push_back(value); - return Status::OK(); - }, - []() { return Status::OK(); })); + PARQUET_THROW_NOT_OK( + ::arrow::VisitArraySpanInline<::arrow::BooleanType>( + *booleanArray.data(), + [&](bool value) { + bufferedAppendValues_.push_back(value); + return Status::OK(); + }, + []() { return Status::OK(); })); } } - void PutSpaced( + void putSpaced( const T* src, - int num_values, - const uint8_t* valid_bits, - int64_t valid_bits_offset) override { - if (valid_bits != NULLPTR) { - PARQUET_ASSIGN_OR_THROW( - auto buffer, - ::arrow::AllocateBuffer(num_values * sizeof(T), this->memory_pool())); + int numValues, + const uint8_t* validBits, + int64_t validBitsOffset) override { + if (validBits != NULLPTR) { + auto buffer = allocateBuffer(this->memoryPool(), numValues * sizeof(T)); T* data = reinterpret_cast(buffer->mutable_data()); - int num_valid_values = ::arrow::util::internal::SpacedCompress( - src, num_values, valid_bits, valid_bits_offset, data); - Put(data, num_valid_values); + int numValidValues = ::arrow::util::internal::SpacedCompress( + src, numValues, validBits, validBitsOffset, data); + put(data, numValidValues); } else { - Put(src, num_values); + put(src, numValues); } } - void Put(const std::vector& src, int num_values) override; + void put(const std::vector& src, int numValues) override; protected: template - void PutImpl(const SequenceType& src, int num_values); + void putImpl(const SequenceType& src, int numValues); - int MaxRleBufferSize() const noexcept { - return RlePreserveBufferSize( - static_cast(buffered_append_values_.size()), kBitWidth); + int maxRleBufferSize() const noexcept { + return rlePreserveBufferSize( + static_cast(bufferedAppendValues_.size()), kBitWidth); } constexpr static int32_t kBitWidth = 1; - /// 4 bytes in little-endian, which indicates the length. + /// 4 Bytes in little-endian, which indicates the length. constexpr static int32_t kRleLengthInBytes = 4; - // std::vector in C++ is tricky, because it's a bitmap. - // Here RleBooleanEncoder will only append values into it, and - // dump values into Buffer, so using it here is ok. - ArrowPoolVector buffered_append_values_; + // Std::vector in C++ is tricky, because it's a bitmap. + // Here RleBooleanEncoder will only Append values into it, and. + // Dump values into Buffer, so using it here is ok. + ArrowPoolVector bufferedAppendValues_; }; -void RleBooleanEncoder::Put(const bool* src, int num_values) { - PutImpl(src, num_values); +void RleBooleanEncoder::put(const bool* src, int numValues) { + putImpl(src, numValues); } -void RleBooleanEncoder::Put(const std::vector& src, int num_values) { - PutImpl(src, num_values); +void RleBooleanEncoder::put(const std::vector& src, int numValues) { + putImpl(src, numValues); } template -void RleBooleanEncoder::PutImpl(const SequenceType& src, int num_values) { - for (int i = 0; i < num_values; ++i) { - buffered_append_values_.push_back(src[i]); +void RleBooleanEncoder::putImpl(const SequenceType& src, int numValues) { + for (int i = 0; i < numValues; ++i) { + bufferedAppendValues_.push_back(src[i]); } } -std::shared_ptr RleBooleanEncoder::FlushValues() { - int rle_buffer_size_max = MaxRleBufferSize(); +std::shared_ptr RleBooleanEncoder::flushValues() { + int rleBufferSizeMax = maxRleBufferSize(); std::shared_ptr buffer = - AllocateBuffer(this->pool_, rle_buffer_size_max + kRleLengthInBytes); + allocateBuffer(this->pool_, rleBufferSizeMax + kRleLengthInBytes); RleEncoder encoder( buffer->mutable_data() + kRleLengthInBytes, - rle_buffer_size_max, + rleBufferSizeMax, /*bit_width*/ kBitWidth); - for (bool value : buffered_append_values_) { + for (bool value : bufferedAppendValues_) { encoder.Put(value ? 1 : 0); } encoder.Flush(); ::arrow::util::SafeStore( buffer->mutable_data(), ::arrow::bit_util::ToLittleEndian(encoder.len())); PARQUET_THROW_NOT_OK(buffer->Resize(kRleLengthInBytes + encoder.len())); - buffered_append_values_.clear(); + bufferedAppendValues_.clear(); return buffer; } -// ---------------------------------------------------------------------- -// RLE_BOOLEAN_DECODER +// ----------------------------------------------------------------------. +// RLE_BOOLEAN_DECODER. // TODO - Commented out as arrow/util/endian.h needs to be updated first. /* @@ -3433,7 +3390,7 @@ class RleBooleanDecoder : public DecoderImpl, virtual public BooleanDecoder { throw ParquetException("Received invalid length : " + std::to_string(len) + " (corrupt data page?)"); } - // Load the first 4 bytes in little-endian, which indicates the length + // Load the first 4 bytes in little-endian, which indicates the length. num_bytes = ::arrow::bit_util::FromLittleEndian(SafeLoadAs(data)); if (num_bytes < 0 || num_bytes > static_cast(len - 4)) { throw ParquetException("Received invalid number of bytes : " + @@ -3501,19 +3458,19 @@ override { if (null_count != 0) { }; */ -// ---------------------------------------------------------------------- -// DELTA_BYTE_ARRAY +// ----------------------------------------------------------------------. +// DELTA_BYTE_ARRAY. -/// Delta Byte Array encoding also known as incremental encoding or front -/// compression: for each element in a sequence of strings, store the prefix -/// length of the previous entry plus the suffix. +/// Delta Byte Array encoding also known as incremental encoding or front. +/// Compression: for each element in a sequence of strings, store the prefix. +/// Length of the previous entry plus the suffix. /// -/// This is stored as a sequence of delta-encoded prefix lengths -/// (DELTA_BINARY_PACKED), followed by the suffixes encoded as delta length byte -/// arrays (DELTA_LENGTH_BYTE_ARRAY). +/// This is stored as a sequence of delta-encoded prefix lengths. +/// (DELTA_BINARY_PACKED), followed by the suffixes encoded as delta length +/// byte. Arrays (DELTA_LENGTH_BYTE_ARRAY). -// ---------------------------------------------------------------------- -// DeltaByteArrayEncoder +// ----------------------------------------------------------------------. +// DeltaByteArrayEncoder. template class DeltaByteArrayEncoder : public EncoderImpl, @@ -3521,106 +3478,105 @@ class DeltaByteArrayEncoder : public EncoderImpl, static constexpr std::string_view kEmpty = ""; public: - using T = typename DType::c_type; + using T = typename DType::CType; explicit DeltaByteArrayEncoder( const ColumnDescriptor* descr, MemoryPool* pool = ::arrow::default_memory_pool()) - : EncoderImpl(descr, Encoding::DELTA_BYTE_ARRAY, pool), + : EncoderImpl(descr, Encoding::kDeltaByteArray, pool), sink_(pool), - prefix_length_encoder_(/*descr=*/nullptr, pool), - suffix_encoder_(descr, pool), - last_value_(""), + prefixLengthEncoder_(nullptr, pool), + suffixEncoder_(descr, pool), + lastValue_(""), empty_( static_cast(kEmpty.size()), reinterpret_cast(kEmpty.data())) {} - std::shared_ptr FlushValues() override; + std::shared_ptr flushValues() override; - int64_t EstimatedDataEncodedSize() override { - return prefix_length_encoder_.EstimatedDataEncodedSize() + - suffix_encoder_.EstimatedDataEncodedSize(); + int64_t estimatedDataEncodedSize() override { + return prefixLengthEncoder_.estimatedDataEncodedSize() + + suffixEncoder_.estimatedDataEncodedSize(); } - using TypedEncoder::Put; + using TypedEncoder::put; - void Put(const ::arrow::Array& values) override; + void put(const ::arrow::Array& values) override; - void Put(const T* buffer, int num_values) override; + void put(const T* buffer, int numValues) override; - void PutSpaced( + void putSpaced( const T* src, - int num_values, - const uint8_t* valid_bits, - int64_t valid_bits_offset) override { - if (valid_bits != nullptr) { + int numValues, + const uint8_t* validBits, + int64_t validBitsOffset) override { + if (validBits != nullptr) { if (buffer_ == nullptr) { PARQUET_ASSIGN_OR_THROW( buffer_, ::arrow::AllocateResizableBuffer( - num_values * sizeof(T), this->memory_pool())); + numValues * sizeof(T), this->memoryPool())); } else { - PARQUET_THROW_NOT_OK(buffer_->Resize(num_values * sizeof(T), false)); + PARQUET_THROW_NOT_OK(buffer_->Resize(numValues * sizeof(T), false)); } T* data = reinterpret_cast(buffer_->mutable_data()); - int num_valid_values = ::arrow::util::internal::SpacedCompress( - src, num_values, valid_bits, valid_bits_offset, data); - Put(data, num_valid_values); + int numValidValues = ::arrow::util::internal::SpacedCompress( + src, numValues, validBits, validBitsOffset, data); + put(data, numValidValues); } else { - Put(src, num_values); + put(src, numValues); } } protected: template - void PutInternal(const T* src, int num_values, const VisitorType visitor) { - if (num_values == 0) { + void putInternal(const T* src, int numValues, const VisitorType visitor) { + if (numValues == 0) { return; } - std::string_view last_value_view = last_value_; + std::string_view lastValueView = lastValue_; constexpr int kBatchSize = 256; - std::array prefix_lengths; + std::array prefixLengths; std::array suffixes; - for (int i = 0; i < num_values; i += kBatchSize) { - const int batch_size = std::min(kBatchSize, num_values - i); + for (int i = 0; i < numValues; i += kBatchSize) { + const int batchSize = std::min(kBatchSize, numValues - i); - for (int j = 0; j < batch_size; ++j) { + for (int j = 0; j < batchSize; ++j) { const int idx = i + j; const auto view = visitor[idx]; const auto len = static_cast(view.length()); - uint32_t common_prefix_length = 0; - const uint32_t maximum_common_prefix_length = - std::min(len, static_cast(last_value_view.length())); - while (common_prefix_length < maximum_common_prefix_length) { - if (last_value_view[common_prefix_length] != - view[common_prefix_length]) { + uint32_t commonPrefixLength = 0; + const uint32_t maximumCommonPrefixLength = + std::min(len, static_cast(lastValueView.length())); + while (commonPrefixLength < maximumCommonPrefixLength) { + if (lastValueView[commonPrefixLength] != view[commonPrefixLength]) { break; } - common_prefix_length++; + commonPrefixLength++; } - last_value_view = view; - prefix_lengths[j] = common_prefix_length; - const uint32_t suffix_length = len - common_prefix_length; - const uint8_t* suffix_ptr = src[idx].ptr + common_prefix_length; + lastValueView = view; + prefixLengths[j] = commonPrefixLength; + const uint32_t suffixLength = len - commonPrefixLength; + const uint8_t* suffixPtr = src[idx].ptr + commonPrefixLength; // Convert to ByteArray, so it can be passed to the suffix_encoder_. - const ByteArray suffix(suffix_length, suffix_ptr); + const ByteArray suffix(suffixLength, suffixPtr); suffixes[j] = suffix; } - suffix_encoder_.Put(suffixes.data(), batch_size); - prefix_length_encoder_.Put(prefix_lengths.data(), batch_size); + suffixEncoder_.put(suffixes.data(), batchSize); + prefixLengthEncoder_.put(prefixLengths.data(), batchSize); } - last_value_ = last_value_view; + lastValue_ = lastValueView; } template - void PutBinaryArray(const ArrayType& array) { - auto previous_len = static_cast(last_value_.length()); - std::string_view last_value_view = last_value_; + void putBinaryArray(const ArrayType& array) { + auto previousLen = static_cast(lastValue_.length()); + std::string_view lastValueView = lastValue_; PARQUET_THROW_NOT_OK( ::arrow::VisitArraySpanInline( @@ -3632,44 +3588,44 @@ class DeltaByteArrayEncoder : public EncoderImpl, } const ByteArray src{std::string_view(view.data(), view.size())}; - uint32_t common_prefix_length = 0; + uint32_t commonPrefixLength = 0; const uint32_t len = src.len; - const uint32_t maximum_common_prefix_length = - std::min(previous_len, len); - while (common_prefix_length < maximum_common_prefix_length) { - if (last_value_view[common_prefix_length] != - view[common_prefix_length]) { + const uint32_t maximumCommonPrefixLength = + std::min(previousLen, len); + while (commonPrefixLength < maximumCommonPrefixLength) { + if (lastValueView[commonPrefixLength] != + view[commonPrefixLength]) { break; } - common_prefix_length++; + commonPrefixLength++; } - previous_len = len; - prefix_length_encoder_.Put( - {static_cast(common_prefix_length)}, 1); - - last_value_view = std::string_view(view.data(), view.size()); - const auto suffix_length = - static_cast(len - common_prefix_length); - if (suffix_length == 0) { - suffix_encoder_.Put(&empty_, 1); + previousLen = len; + prefixLengthEncoder_.put( + {static_cast(commonPrefixLength)}, 1); + + lastValueView = std::string_view(view.data(), view.size()); + const auto suffixLength = + static_cast(len - commonPrefixLength); + if (suffixLength == 0) { + suffixEncoder_.put(&empty_, 1); return Status::OK(); } - const uint8_t* suffix_ptr = src.ptr + common_prefix_length; - // Convert to ByteArray, so it can be passed to the - // suffix_encoder_. - const ByteArray suffix(suffix_length, suffix_ptr); - suffix_encoder_.Put(&suffix, 1); + const uint8_t* suffixPtr = src.ptr + commonPrefixLength; + // Convert to ByteArray, so it can be passed to the. + // Suffix_encoder_. + const ByteArray suffix(suffixLength, suffixPtr); + suffixEncoder_.put(&suffix, 1); return Status::OK(); }, []() { return Status::OK(); })); - last_value_ = last_value_view; + lastValue_ = lastValueView; } ::arrow::BufferBuilder sink_; - DeltaBitPackEncoder prefix_length_encoder_; - DeltaLengthByteArrayEncoder suffix_encoder_; - std::string last_value_; + DeltaBitPackEncoder prefixLengthEncoder_; + DeltaLengthByteArrayEncoder suffixEncoder_; + std::string lastValue_; const ByteArray empty_; std::unique_ptr buffer_; }; @@ -3692,232 +3648,232 @@ struct ByteArrayVisitor { struct FLBAVisitor { const FLBA* src; - const uint32_t type_length; + const uint32_t typeLength; std::string_view operator[](int i) const { return std::string_view{ - reinterpret_cast(src[i].ptr), type_length}; + reinterpret_cast(src[i].ptr), typeLength}; } uint32_t len(int i) const { - return type_length; + return typeLength; } }; template <> -void DeltaByteArrayEncoder::Put( +void DeltaByteArrayEncoder::put( const ByteArray* src, - int num_values) { + int numValues) { auto visitor = ByteArrayVisitor{src}; - PutInternal(src, num_values, visitor); + putInternal(src, numValues, visitor); } template <> -void DeltaByteArrayEncoder::Put(const FLBA* src, int num_values) { - auto visitor = FLBAVisitor{src, static_cast(descr_->type_length())}; - PutInternal(src, num_values, visitor); +void DeltaByteArrayEncoder::put(const FLBA* src, int numValues) { + auto visitor = FLBAVisitor{src, static_cast(descr_->typeLength())}; + putInternal(src, numValues, visitor); } template -void DeltaByteArrayEncoder::Put(const ::arrow::Array& values) { +void DeltaByteArrayEncoder::put(const ::arrow::Array& values) { if (::arrow::is_binary_like(values.type_id())) { - PutBinaryArray(checked_cast(values)); + putBinaryArray(checked_cast(values)); } else if (::arrow::is_large_binary_like(values.type_id())) { - PutBinaryArray(checked_cast(values)); + putBinaryArray(checked_cast(values)); } else if (::arrow::is_fixed_size_binary(values.type_id())) { - PutBinaryArray(checked_cast(values)); + putBinaryArray(checked_cast(values)); } else { throw ParquetException("Only BaseBinaryArray and subclasses supported"); } } template -std::shared_ptr DeltaByteArrayEncoder::FlushValues() { - PARQUET_THROW_NOT_OK(sink_.Resize(EstimatedDataEncodedSize(), false)); +std::shared_ptr DeltaByteArrayEncoder::flushValues() { + PARQUET_THROW_NOT_OK(sink_.Resize(estimatedDataEncodedSize(), false)); - std::shared_ptr prefix_lengths = prefix_length_encoder_.FlushValues(); + std::shared_ptr prefixLengths = prefixLengthEncoder_.flushValues(); PARQUET_THROW_NOT_OK( - sink_.Append(prefix_lengths->data(), prefix_lengths->size())); + sink_.Append(prefixLengths->data(), prefixLengths->size())); - std::shared_ptr suffixes = suffix_encoder_.FlushValues(); + std::shared_ptr suffixes = suffixEncoder_.flushValues(); PARQUET_THROW_NOT_OK(sink_.Append(suffixes->data(), suffixes->size())); std::shared_ptr buffer; PARQUET_THROW_NOT_OK(sink_.Finish(&buffer, true)); - last_value_.clear(); + lastValue_.clear(); return buffer; } -// ---------------------------------------------------------------------- -// DeltaByteArrayDecoder +// ----------------------------------------------------------------------. +// DeltaByteArrayDecoder. template class DeltaByteArrayDecoderImpl : public DecoderImpl, virtual public TypedDecoder { - using T = typename DType::c_type; + using T = typename DType::CType; public: explicit DeltaByteArrayDecoderImpl( const ColumnDescriptor* descr, MemoryPool* pool = ::arrow::default_memory_pool()) - : DecoderImpl(descr, Encoding::DELTA_BYTE_ARRAY), + : DecoderImpl(descr, Encoding::kDeltaByteArray), pool_(pool), - prefix_len_decoder_(nullptr, pool), - suffix_decoder_(nullptr, pool), - last_value_in_previous_page_(""), - buffered_prefix_length_(AllocateBuffer(pool, 0)), - buffered_data_(AllocateBuffer(pool, 0)) {} - - void SetData(int num_values, const uint8_t* data, int len) override { - num_values_ = num_values; + prefixLenDecoder_(nullptr, pool), + suffixDecoder_(nullptr, pool), + lastValueInPreviousPage_(""), + bufferedPrefixLength_(allocateBuffer(pool, 0)), + bufferedData_(allocateBuffer(pool, 0)) {} + + void setData(int numValues, const uint8_t* data, int len) override { + numValues_ = numValues; decoder_ = std::make_shared(data, len); - prefix_len_decoder_.SetDecoder(num_values, decoder_); + prefixLenDecoder_.setDecoder(numValues, decoder_); - // get the number of encoded prefix lengths - int num_prefix = prefix_len_decoder_.ValidValuesCount(); - // call prefix_len_decoder_.Decode to decode all the prefix lengths. - // all the prefix lengths are buffered in buffered_prefix_length_. + // Get the number of encoded prefix lengths. + int numPrefix = prefixLenDecoder_.validValuesCount(); + // Call prefix_len_decoder_.Decode to decode all the prefix lengths. + // All the prefix lengths are buffered in buffered_prefix_length_. PARQUET_THROW_NOT_OK( - buffered_prefix_length_->Resize(num_prefix * sizeof(int32_t))); - VELOX_DEBUG_ONLY int ret = prefix_len_decoder_.Decode( - reinterpret_cast(buffered_prefix_length_->mutable_data()), - num_prefix); - VELOX_DCHECK_EQ(ret, num_prefix); - prefix_len_offset_ = 0; - num_valid_values_ = num_prefix; - - int bytes_left = decoder_->bytesLeft(); + bufferedPrefixLength_->Resize(numPrefix * sizeof(int32_t))); + VELOX_DEBUG_ONLY int ret = prefixLenDecoder_.decode( + reinterpret_cast(bufferedPrefixLength_->mutable_data()), + numPrefix); + VELOX_DCHECK_EQ(ret, numPrefix); + prefixLenOffset_ = 0; + numValidValues_ = numPrefix; + + int bytesLeft = decoder_->bytesLeft(); // If len < bytes_left, prefix_len_decoder.Decode will throw exception. - VELOX_DCHECK_GE(len, bytes_left); - int suffix_begins = len - bytes_left; - // at this time, the decoder_ will be at the start of the encoded suffix - // data. - suffix_decoder_.SetData(num_values, data + suffix_begins, bytes_left); - - // TODO: read corrupted files written with bug(PARQUET-246). last_value_ - // should be set to last_value_in_previous_page_ when decoding a new + VELOX_DCHECK_GE(len, bytesLeft); + int suffixBegins = len - bytesLeft; + // At this time, the decoder_ will be at the start of the encoded suffix. + // Data. + suffixDecoder_.setData(numValues, data + suffixBegins, bytesLeft); + + // TODO: read corrupted files written with bug(PARQUET-246). last_value_. + // Should be set to last_value_in_previous_page_ when decoding a new. // page(except the first page) - last_value_ = ""; + lastValue_ = ""; } - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::Accumulator* out) override { int result = 0; - PARQUET_THROW_NOT_OK(DecodeArrowDense( - num_values, null_count, valid_bits, valid_bits_offset, out, &result)); + PARQUET_THROW_NOT_OK(decodeArrowDense( + numValues, nullCount, validBits, validBitsOffset, out, &result)); return result; } - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) override { + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) override { ParquetException::NYI( "DecodeArrow of DictAccumulator for DeltaByteArrayDecoder"); } protected: - int GetInternal(ByteArray* buffer, int max_values) { - // Decode up to `max_values` strings into an internal buffer - // and reference them into `buffer`. - max_values = std::min(max_values, num_valid_values_); - if (max_values == 0) { - return max_values; + int getInternal(ByteArray* buffer, int maxValues) { + // Decode up to `max_values` strings into an internal buffer. + // And reference them into `buffer`. + maxValues = std::min(maxValues, numValidValues_); + if (maxValues == 0) { + return maxValues; } - int suffix_read = suffix_decoder_.Decode(buffer, max_values); - if (ARROW_PREDICT_FALSE(suffix_read != max_values)) { - ParquetException::EofException( - "Read " + std::to_string(suffix_read) + ", expecting " + - std::to_string(max_values) + " from suffix decoder"); + int suffixRead = suffixDecoder_.decode(buffer, maxValues); + if (ARROW_PREDICT_FALSE(suffixRead != maxValues)) { + ParquetException::eofException( + "Read " + std::to_string(suffixRead) + ", expecting " + + std::to_string(maxValues) + " from suffix decoder"); } - int64_t data_size = 0; - const int32_t* prefix_len_ptr = - reinterpret_cast(buffered_prefix_length_->data()) + - prefix_len_offset_; - for (int i = 0; i < max_values; ++i) { - if (ARROW_PREDICT_FALSE(prefix_len_ptr[i] < 0)) { + int64_t dataSize = 0; + const int32_t* prefixLenPtr = + reinterpret_cast(bufferedPrefixLength_->data()) + + prefixLenOffset_; + for (int i = 0; i < maxValues; ++i) { + if (ARROW_PREDICT_FALSE(prefixLenPtr[i] < 0)) { throw ParquetException("negative prefix length in DELTA_BYTE_ARRAY"); } if (ARROW_PREDICT_FALSE( - AddWithOverflow(data_size, prefix_len_ptr[i], &data_size) || - AddWithOverflow(data_size, buffer[i].len, &data_size))) { + addWithOverflow(dataSize, prefixLenPtr[i], &dataSize) || + addWithOverflow(dataSize, buffer[i].len, &dataSize))) { throw ParquetException("excess expansion in DELTA_BYTE_ARRAY"); } } - PARQUET_THROW_NOT_OK(buffered_data_->Resize(data_size)); + PARQUET_THROW_NOT_OK(bufferedData_->Resize(dataSize)); - string_view prefix{last_value_}; - uint8_t* data_ptr = buffered_data_->mutable_data(); - for (int i = 0; i < max_values; ++i) { + string_view prefix{lastValue_}; + uint8_t* dataPtr = bufferedData_->mutable_data(); + for (int i = 0; i < maxValues; ++i) { if (ARROW_PREDICT_FALSE( - static_cast(prefix_len_ptr[i]) > prefix.length())) { + static_cast(prefixLenPtr[i]) > prefix.length())) { throw ParquetException("prefix length too large in DELTA_BYTE_ARRAY"); } - memcpy(data_ptr, prefix.data(), prefix_len_ptr[i]); - // buffer[i] currently points to the string suffix - memcpy(data_ptr + prefix_len_ptr[i], buffer[i].ptr, buffer[i].len); - buffer[i].ptr = data_ptr; - buffer[i].len += prefix_len_ptr[i]; - data_ptr += buffer[i].len; + memcpy(dataPtr, prefix.data(), prefixLenPtr[i]); + // Buffer[i] currently points to the string suffix. + memcpy(dataPtr + prefixLenPtr[i], buffer[i].ptr, buffer[i].len); + buffer[i].ptr = dataPtr; + buffer[i].len += prefixLenPtr[i]; + dataPtr += buffer[i].len; prefix = std::string_view{buffer[i]}; } - prefix_len_offset_ += max_values; - this->num_values_ -= max_values; - num_valid_values_ -= max_values; - last_value_ = std::string{prefix}; + prefixLenOffset_ += maxValues; + this->numValues_ -= maxValues; + numValidValues_ -= maxValues; + lastValue_ = std::string{prefix}; - if (num_valid_values_ == 0) { - last_value_in_previous_page_ = last_value_; + if (numValidValues_ == 0) { + lastValueInPreviousPage_ = lastValue_; } - return max_values; + return maxValues; } - Status DecodeArrowDense( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + Status decodeArrowDense( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::Accumulator* out, - int* out_num_values) { - ArrowBinaryHelper helper(out, num_values); - RETURN_NOT_OK(helper.Prepare()); + int* outNumValues) { + ArrowBinaryHelper helper(out, numValues); + RETURN_NOT_OK(helper.prepare()); - std::vector values(num_values); - const int num_valid_values = - GetInternal(values.data(), num_values - null_count); - VELOX_DCHECK_EQ(num_values - null_count, num_valid_values); + std::vector values(numValues); + const int numValidValues = + getInternal(values.data(), numValues - nullCount); + VELOX_DCHECK_EQ(numValues - nullCount, numValidValues); - auto values_ptr = reinterpret_cast(values.data()); - int value_idx = 0; + auto valuesPtr = reinterpret_cast(values.data()); + int valueIdx = 0; RETURN_NOT_OK(VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { - const auto& val = values_ptr[value_idx]; - RETURN_NOT_OK(helper.PrepareNextInput(val.len)); + const auto& val = valuesPtr[valueIdx]; + RETURN_NOT_OK(helper.prepareNextInput(val.len)); RETURN_NOT_OK(helper.Append(val.ptr, static_cast(val.len))); - ++value_idx; + ++valueIdx; return Status::OK(); }, [&]() { RETURN_NOT_OK(helper.AppendNull()); - --null_count; + --nullCount; return Status::OK(); })); - VELOX_DCHECK_EQ(null_count, 0); - *out_num_values = num_valid_values; + VELOX_DCHECK_EQ(nullCount, 0); + *outNumValues = numValidValues; return Status::OK(); } @@ -3925,15 +3881,15 @@ class DeltaByteArrayDecoderImpl : public DecoderImpl, private: std::shared_ptr decoder_; - DeltaBitPackDecoder prefix_len_decoder_; - DeltaLengthByteArrayDecoder suffix_decoder_; - std::string last_value_; - // string buffer for last value in previous page - std::string last_value_in_previous_page_; - int num_valid_values_{0}; - uint32_t prefix_len_offset_{0}; - std::shared_ptr buffered_prefix_length_; - std::shared_ptr buffered_data_; + DeltaBitPackDecoder prefixLenDecoder_; + DeltaLengthByteArrayDecoder suffixDecoder_; + std::string lastValue_; + // String buffer for last value in previous page. + std::string lastValueInPreviousPage_; + int numValidValues_{0}; + uint32_t prefixLenOffset_{0}; + std::shared_ptr bufferedPrefixLength_; + std::shared_ptr bufferedData_; }; class DeltaByteArrayDecoder : public DeltaByteArrayDecoderImpl { @@ -3941,8 +3897,8 @@ class DeltaByteArrayDecoder : public DeltaByteArrayDecoderImpl { using Base = DeltaByteArrayDecoderImpl; using Base::DeltaByteArrayDecoderImpl; - int Decode(ByteArray* buffer, int max_values) override { - return GetInternal(buffer, max_values); + int decode(ByteArray* buffer, int maxValues) override { + return getInternal(buffer, maxValues); } }; @@ -3953,62 +3909,62 @@ class DeltaByteArrayFLBADecoder : public DeltaByteArrayDecoderImpl, using Base::DeltaByteArrayDecoderImpl; using Base::pool_; - int Decode(FixedLenByteArray* buffer, int max_values) override { + int decode(FixedLenByteArray* buffer, int maxValues) override { // GetInternal currently only support ByteArray. - std::vector decode_byte_array(max_values); - const int decoded_values_size = - GetInternal(decode_byte_array.data(), max_values); - const uint32_t type_length = descr_->type_length(); + std::vector decodeByteArray(maxValues); + const int decodedValuesSize = + getInternal(decodeByteArray.data(), maxValues); + const uint32_t typeLength = descr_->typeLength(); - for (int i = 0; i < decoded_values_size; i++) { - if (ARROW_PREDICT_FALSE(decode_byte_array[i].len != type_length)) { + for (int i = 0; i < decodedValuesSize; i++) { + if (ARROW_PREDICT_FALSE(decodeByteArray[i].len != typeLength)) { throw ParquetException("Fixed length byte array length mismatch"); } - buffer[i].ptr = decode_byte_array[i].ptr; + buffer[i].ptr = decodeByteArray[i].ptr; } - return decoded_values_size; + return decodedValuesSize; } }; -// ---------------------------------------------------------------------- -// BYTE_STREAM_SPLIT +// ----------------------------------------------------------------------. +// BYTE_STREAM_SPLIT. template class ByteStreamSplitDecoder : public DecoderImpl, virtual public TypedDecoder { public: - using T = typename DType::c_type; + using T = typename DType::CType; explicit ByteStreamSplitDecoder(const ColumnDescriptor* descr); - int Decode(T* buffer, int max_values) override; + int decode(T* buffer, int maxValues) override; - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::Accumulator* builder) override; + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::Accumulator* Builder) override; - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) override; + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) override; - void SetData(int num_values, const uint8_t* data, int len) override; + void setData(int numValues, const uint8_t* data, int len) override; - T* EnsureDecodeBuffer(int64_t min_values) { - const int64_t size = sizeof(T) * min_values; - if (!decode_buffer_ || decode_buffer_->size() < size) { - PARQUET_ASSIGN_OR_THROW(decode_buffer_, ::arrow::AllocateBuffer(size)); + T* ensureDecodeBuffer(int64_t minValues) { + const int64_t size = sizeof(T) * minValues; + if (!decodeBuffer_ || decodeBuffer_->size() < size) { + decodeBuffer_ = allocateBuffer(::arrow::default_memory_pool(), size); } - return reinterpret_cast(decode_buffer_->mutable_data()); + return reinterpret_cast(decodeBuffer_->mutable_data()); } private: - int num_values_in_buffer_{0}; - std::shared_ptr decode_buffer_; + int numValuesInBuffer_{0}; + std::shared_ptr decodeBuffer_; static constexpr size_t kNumStreams = sizeof(T); }; @@ -4016,14 +3972,14 @@ class ByteStreamSplitDecoder : public DecoderImpl, template ByteStreamSplitDecoder::ByteStreamSplitDecoder( const ColumnDescriptor* descr) - : DecoderImpl(descr, Encoding::BYTE_STREAM_SPLIT) {} + : DecoderImpl(descr, Encoding::kByteStreamSplit) {} template -void ByteStreamSplitDecoder::SetData( - int num_values, +void ByteStreamSplitDecoder::setData( + int numValues, const uint8_t* data, int len) { - if (num_values * static_cast(sizeof(T)) < len) { + if (numValues * static_cast(sizeof(T)) < len) { throw ParquetException( "Data size too large for number of values (padding in byte stream split data " "page?)"); @@ -4031,194 +3987,193 @@ void ByteStreamSplitDecoder::SetData( if (len % sizeof(T) != 0) { throw ParquetException( "ByteStreamSplit data size " + std::to_string(len) + - " not aligned with type " + TypeToString(DType::type_num)); + " not aligned with type " + typeToString(DType::typeNum)); } - num_values = len / sizeof(T); - DecoderImpl::SetData(num_values, data, len); - num_values_in_buffer_ = num_values_; + numValues = len / sizeof(T); + DecoderImpl::setData(numValues, data, len); + numValuesInBuffer_ = numValues_; } template -int ByteStreamSplitDecoder::Decode(T* buffer, int max_values) { - const int values_to_decode = std::min(num_values_, max_values); - const int num_decoded_previously = num_values_in_buffer_ - num_values_; - const uint8_t* data = data_ + num_decoded_previously; - - ByteStreamSplitDecode( - data, values_to_decode, num_values_in_buffer_, buffer); - num_values_ -= values_to_decode; - len_ -= sizeof(T) * values_to_decode; - return values_to_decode; +int ByteStreamSplitDecoder::decode(T* buffer, int maxValues) { + const int valuesToDecode = std::min(numValues_, maxValues); + const int numDecodedPreviously = numValuesInBuffer_ - numValues_; + const uint8_t* data = data_ + numDecodedPreviously; + + byteStreamSplitDecode(data, valuesToDecode, numValuesInBuffer_, buffer); + numValues_ -= valuesToDecode; + len_ -= sizeof(T) * valuesToDecode; + return valuesToDecode; } template -int ByteStreamSplitDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::Accumulator* builder) { - constexpr int value_size = static_cast(kNumStreams); - int values_decoded = num_values - null_count; - if (ARROW_PREDICT_FALSE(len_ < value_size * values_decoded)) { - ParquetException::EofException(); - } - - PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); - - const int num_decoded_previously = num_values_in_buffer_ - num_values_; - const uint8_t* data = data_ + num_decoded_previously; +int ByteStreamSplitDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::Accumulator* Builder) { + constexpr int valueSize = static_cast(kNumStreams); + int valuesDecoded = numValues - nullCount; + if (ARROW_PREDICT_FALSE(len_ < valueSize * valuesDecoded)) { + ParquetException::eofException(); + } + + PARQUET_THROW_NOT_OK(Builder->Reserve(numValues)); + + const int numDecodedPreviously = numValuesInBuffer_ - numValues_; + const uint8_t* data = data_ + numDecodedPreviously; int offset = 0; #if defined(ARROW_HAVE_SIMD_SPLIT) - // Use fast decoding into intermediate buffer. This will also decode - // some null values, but it's fast enough that we don't care. - T* decode_out = EnsureDecodeBuffer(values_decoded); - ::arrow::util::internal::ByteStreamSplitDecode( - data, values_decoded, num_values_in_buffer_, decode_out); - - // XXX If null_count is 0, we could even append in bulk or decode directly - // into builder + // Use fast decoding into intermediate buffer. This will also decode. + // Some null values, but it's fast enough that we don't care. + T* decodeOut = ensureDecodeBuffer(valuesDecoded); + ::arrow::util::internal::byte_stream_split_decode( + data, valuesDecoded, numValuesInBuffer_, decodeOut); + + // XXX If null_count is 0, we could even Append in bulk or decode directly. + // Into builder. VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { - builder->UnsafeAppend(decode_out[offset]); + Builder->UnsafeAppend(decodeOut[offset]); ++offset; }, - [&]() { builder->UnsafeAppendNull(); }); + [&]() { Builder->UnsafeAppendNull(); }); #else VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { - uint8_t gathered_byte_data[kNumStreams]; + uint8_t gatheredByteData[kNumStreams]; for (size_t b = 0; b < kNumStreams; ++b) { - const size_t byte_index = b * num_values_in_buffer_ + offset; - gathered_byte_data[b] = data[byte_index]; + const size_t byteIndex = b * numValuesInBuffer_ + offset; + gatheredByteData[b] = data[byteIndex]; } - builder->UnsafeAppend(SafeLoadAs(&gathered_byte_data[0])); + Builder->UnsafeAppend(SafeLoadAs(&gatheredByteData[0])); ++offset; }, - [&]() { builder->UnsafeAppendNull(); }); + [&]() { Builder->UnsafeAppendNull(); }); #endif - num_values_ -= values_decoded; - len_ -= sizeof(T) * values_decoded; - return values_decoded; + numValues_ -= valuesDecoded; + len_ -= sizeof(T) * valuesDecoded; + return valuesDecoded; } template -int ByteStreamSplitDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) { +int ByteStreamSplitDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) { ParquetException::NYI("DecodeArrow for ByteStreamSplitDecoder"); } } // namespace -// ---------------------------------------------------------------------- -// Encoder and decoder factory functions +// ----------------------------------------------------------------------. +// Encoder and decoder factory functions. -std::unique_ptr MakeEncoder( - Type::type type_num, +std::unique_ptr makeEncoder( + Type::type typeNum, Encoding::type encoding, - bool use_dictionary, + bool useDictionary, const ColumnDescriptor* descr, MemoryPool* pool) { - if (use_dictionary) { - switch (type_num) { - case Type::INT32: + if (useDictionary) { + switch (typeNum) { + case Type::kInt32: return std::make_unique>(descr, pool); - case Type::INT64: + case Type::kInt64: return std::make_unique>(descr, pool); - case Type::INT96: + case Type::kInt96: return std::make_unique>(descr, pool); - case Type::FLOAT: + case Type::kFloat: return std::make_unique>(descr, pool); - case Type::DOUBLE: + case Type::kDouble: return std::make_unique>(descr, pool); - case Type::BYTE_ARRAY: + case Type::kByteArray: return std::make_unique>(descr, pool); - case Type::FIXED_LEN_BYTE_ARRAY: + case Type::kFixedLenByteArray: return std::make_unique>(descr, pool); default: VELOX_DCHECK(false, "Encoder not implemented"); break; } - } else if (encoding == Encoding::PLAIN) { - switch (type_num) { - case Type::BOOLEAN: + } else if (encoding == Encoding::kPlain) { + switch (typeNum) { + case Type::kBoolean: return std::make_unique>(descr, pool); - case Type::INT32: + case Type::kInt32: return std::make_unique>(descr, pool); - case Type::INT64: + case Type::kInt64: return std::make_unique>(descr, pool); - case Type::INT96: + case Type::kInt96: return std::make_unique>(descr, pool); - case Type::FLOAT: + case Type::kFloat: return std::make_unique>(descr, pool); - case Type::DOUBLE: + case Type::kDouble: return std::make_unique>(descr, pool); - case Type::BYTE_ARRAY: + case Type::kByteArray: return std::make_unique>(descr, pool); - case Type::FIXED_LEN_BYTE_ARRAY: + case Type::kFixedLenByteArray: return std::make_unique>(descr, pool); default: VELOX_DCHECK(false, "Encoder not implemented"); break; } - } else if (encoding == Encoding::BYTE_STREAM_SPLIT) { - switch (type_num) { - case Type::FLOAT: + } else if (encoding == Encoding::kByteStreamSplit) { + switch (typeNum) { + case Type::kFloat: return std::make_unique>(descr, pool); - case Type::DOUBLE: + case Type::kDouble: return std::make_unique>( descr, pool); default: throw ParquetException( "BYTE_STREAM_SPLIT only supports FLOAT and DOUBLE"); } - } else if (encoding == Encoding::DELTA_BINARY_PACKED) { - switch (type_num) { - case Type::INT32: + } else if (encoding == Encoding::kDeltaBinaryPacked) { + switch (typeNum) { + case Type::kInt32: return std::make_unique>(descr, pool); - case Type::INT64: + case Type::kInt64: return std::make_unique>(descr, pool); default: throw ParquetException( "DELTA_BINARY_PACKED encoder only supports INT32 and INT64"); } - } else if (encoding == Encoding::DELTA_LENGTH_BYTE_ARRAY) { - switch (type_num) { - case Type::BYTE_ARRAY: + } else if (encoding == Encoding::kDeltaLengthByteArray) { + switch (typeNum) { + case Type::kByteArray: return std::make_unique>( descr, pool); default: throw ParquetException( "DELTA_LENGTH_BYTE_ARRAY only supports BYTE_ARRAY"); } - } else if (encoding == Encoding::RLE) { - switch (type_num) { - case Type::BOOLEAN: + } else if (encoding == Encoding::kRle) { + switch (typeNum) { + case Type::kBoolean: return std::make_unique(descr, pool); default: throw ParquetException("RLE only supports BOOLEAN"); } - } else if (encoding == Encoding::DELTA_BYTE_ARRAY) { - switch (type_num) { - case Type::BYTE_ARRAY: + } else if (encoding == Encoding::kDeltaByteArray) { + switch (typeNum) { + case Type::kByteArray: return std::make_unique>( descr, pool); - case Type::FIXED_LEN_BYTE_ARRAY: + case Type::kFixedLenByteArray: return std::make_unique>(descr, pool); default: throw ParquetException( @@ -4231,69 +4186,69 @@ std::unique_ptr MakeEncoder( return nullptr; } -std::unique_ptr MakeDecoder( - Type::type type_num, +std::unique_ptr makeDecoder( + Type::type typeNum, Encoding::type encoding, const ColumnDescriptor* descr, ::arrow::MemoryPool* pool) { - if (encoding == Encoding::PLAIN) { - switch (type_num) { - case Type::BOOLEAN: + if (encoding == Encoding::kPlain) { + switch (typeNum) { + case Type::kBoolean: return std::make_unique(descr); - case Type::INT32: + case Type::kInt32: return std::make_unique>(descr); - case Type::INT64: + case Type::kInt64: return std::make_unique>(descr); - case Type::INT96: + case Type::kInt96: return std::make_unique>(descr); - case Type::FLOAT: + case Type::kFloat: return std::make_unique>(descr); - case Type::DOUBLE: + case Type::kDouble: return std::make_unique>(descr); - case Type::BYTE_ARRAY: + case Type::kByteArray: return std::make_unique(descr); - case Type::FIXED_LEN_BYTE_ARRAY: + case Type::kFixedLenByteArray: return std::make_unique(descr); default: break; } - } else if (encoding == Encoding::BYTE_STREAM_SPLIT) { - switch (type_num) { - case Type::FLOAT: + } else if (encoding == Encoding::kByteStreamSplit) { + switch (typeNum) { + case Type::kFloat: return std::make_unique>(descr); - case Type::DOUBLE: + case Type::kDouble: return std::make_unique>(descr); default: throw ParquetException( "BYTE_STREAM_SPLIT only supports FLOAT and DOUBLE"); } - } else if (encoding == Encoding::DELTA_BINARY_PACKED) { - switch (type_num) { - case Type::INT32: + } else if (encoding == Encoding::kDeltaBinaryPacked) { + switch (typeNum) { + case Type::kInt32: return std::make_unique>(descr, pool); - case Type::INT64: + case Type::kInt64: return std::make_unique>(descr, pool); default: throw ParquetException( "DELTA_BINARY_PACKED decoder only supports INT32 and INT64"); } - } else if (encoding == Encoding::DELTA_BYTE_ARRAY) { - switch (type_num) { - case Type::BYTE_ARRAY: + } else if (encoding == Encoding::kDeltaByteArray) { + switch (typeNum) { + case Type::kByteArray: return std::make_unique(descr, pool); - case Type::FIXED_LEN_BYTE_ARRAY: + case Type::kFixedLenByteArray: return std::make_unique(descr, pool); default: throw ParquetException( "DELTA_BYTE_ARRAY only supports BYTE_ARRAY and FIXED_LEN_BYTE_ARRAY"); } - } else if (encoding == Encoding::DELTA_LENGTH_BYTE_ARRAY) { - if (type_num == Type::BYTE_ARRAY) { + } else if (encoding == Encoding::kDeltaLengthByteArray) { + if (typeNum == Type::kByteArray) { return std::make_unique(descr, pool); } throw ParquetException("DELTA_LENGTH_BYTE_ARRAY only supports BYTE_ARRAY"); - } else if (encoding == Encoding::RLE) { - if (type_num == Type::BOOLEAN) { + } else if (encoding == Encoding::kRle) { + if (typeNum == Type::kBoolean) { throw ParquetException("RleBooleanDecoder has been disabled."); // return std::make_unique(descr); } @@ -4306,27 +4261,27 @@ std::unique_ptr MakeDecoder( } namespace detail { -std::unique_ptr MakeDictDecoder( - Type::type type_num, +std::unique_ptr makeDictDecoder( + Type::type typeNum, const ColumnDescriptor* descr, MemoryPool* pool) { - switch (type_num) { - case Type::BOOLEAN: + switch (typeNum) { + case Type::kBoolean: ParquetException::NYI( "Dictionary encoding not implemented for boolean type"); - case Type::INT32: + case Type::kInt32: return std::make_unique>(descr, pool); - case Type::INT64: + case Type::kInt64: return std::make_unique>(descr, pool); - case Type::INT96: + case Type::kInt96: return std::make_unique>(descr, pool); - case Type::FLOAT: + case Type::kFloat: return std::make_unique>(descr, pool); - case Type::DOUBLE: + case Type::kDouble: return std::make_unique>(descr, pool); - case Type::BYTE_ARRAY: + case Type::kByteArray: return std::make_unique(descr, pool); - case Type::FIXED_LEN_BYTE_ARRAY: + case Type::kFixedLenByteArray: return std::make_unique>(descr, pool); default: break; diff --git a/velox/dwio/parquet/writer/arrow/Encoding.h b/velox/dwio/parquet/writer/arrow/Encoding.h index b57ee2b68e3..0100136ecf5 100644 --- a/velox/dwio/parquet/writer/arrow/Encoding.h +++ b/velox/dwio/parquet/writer/arrow/Encoding.h @@ -144,9 +144,9 @@ struct EncodingTraits { using ArrowType = ::arrow::BinaryType; /// \brief Internal helper class for decoding BYTE_ARRAY data where we can - /// overflow the capacity of a single arrow::BinaryArray + /// overflow the capacity of a single arrow::BinaryArray. struct Accumulator { - std::unique_ptr<::arrow::BinaryBuilder> builder; + std::unique_ptr<::arrow::BinaryBuilder> Builder; std::vector> chunks; }; using DictAccumulator = ::arrow::Dictionary32Builder<::arrow::BinaryType>; @@ -165,18 +165,18 @@ struct EncodingTraits { class ColumnDescriptor; -// Untyped base for all encoders +// Untyped base for all encoders. class Encoder { public: virtual ~Encoder() = default; - virtual int64_t EstimatedDataEncodedSize() = 0; - virtual std::shared_ptr<::arrow::Buffer> FlushValues() = 0; + virtual int64_t estimatedDataEncodedSize() = 0; + virtual std::shared_ptr<::arrow::Buffer> flushValues() = 0; virtual Encoding::type encoding() const = 0; - virtual void Put(const ::arrow::Array& values) = 0; + virtual void put(const ::arrow::Array& values) = 0; - virtual ::arrow::MemoryPool* memory_pool() const = 0; + virtual ::arrow::MemoryPool* memoryPool() const = 0; }; // Base class for value encoders. Since encoders may or not have state (e.g., @@ -186,74 +186,74 @@ class Encoder { template class TypedEncoder : virtual public Encoder { public: - typedef typename DType::c_type T; + typedef typename DType::CType T; - using Encoder::Put; + using Encoder::put; - virtual void Put(const T* src, int num_values) = 0; + virtual void put(const T* src, int numValues) = 0; - virtual void Put(const std::vector& src, int num_values = -1); + virtual void put(const std::vector& src, int numValues = -1); - virtual void PutSpaced( + virtual void putSpaced( const T* src, - int num_values, - const uint8_t* valid_bits, - int64_t valid_bits_offset) = 0; + int numValues, + const uint8_t* validBits, + int64_t validBitsOffset) = 0; }; template -void TypedEncoder::Put(const std::vector& src, int num_values) { - if (num_values == -1) { - num_values = static_cast(src.size()); +void TypedEncoder::put(const std::vector& src, int numValues) { + if (numValues == -1) { + numValues = static_cast(src.size()); } - Put(src.data(), num_values); + put(src.data(), numValues); } template <> -inline void TypedEncoder::Put( +inline void TypedEncoder::put( const std::vector& src, - int num_values) { + int numValues) { // NOTE(wesm): This stub is here only to satisfy the compiler; it is - // overridden later with the actual implementation + // overridden later with the actual implementation. } -// Base class for dictionary encoders +// Base class for dictionary encoders. template class DictEncoder : virtual public TypedEncoder { public: /// Writes out any buffered indices to buffer preceded by the bit width of /// this data. Returns the number of bytes written. If the supplied buffer is - /// not big enough, returns -1. buffer must be preallocated with buffer_len + /// not big enough, returns -1. Buffer must be preallocated with buffer_len /// bytes. Use EstimatedDataEncodedSize() to size buffer. - virtual int WriteIndices(uint8_t* buffer, int buffer_len) = 0; + virtual int writeIndices(uint8_t* buffer, int bufferLen) = 0; - virtual int dict_encoded_size() const = 0; + virtual int dictEncodedSize() const = 0; - virtual int bit_width() const = 0; + virtual int bitWidth() const = 0; - /// Writes out the encoded dictionary to buffer. buffer must be preallocated + /// Writes out the encoded dictionary to buffer. Buffer must be preallocated /// to dict_encoded_size() bytes. - virtual void WriteDict(uint8_t* buffer) const = 0; + virtual void writeDict(uint8_t* buffer) const = 0; - virtual int num_entries() const = 0; + virtual int numEntries() const = 0; /// \brief EXPERIMENTAL: Append dictionary indices into the encoder. It is /// assumed (without any boundschecking) that the indices reference - /// pre-existing dictionary values - /// \param[in] indices the dictionary index values. Only Int32Array currently - /// supported - virtual void PutIndices(const ::arrow::Array& indices) = 0; + /// pre-existing dictionary values. + /// \param[in] indices The dictionary index values. Only Int32Array currently + /// supported. + virtual void putIndices(const ::arrow::Array& indices) = 0; /// \brief EXPERIMENTAL: Append dictionary into encoder, inserting indices /// separately. Currently throws exception if the current dictionary memo is - /// non-empty - /// \param[in] values the dictionary values. Only valid for certain - /// Parquet/Arrow type combinations, like BYTE_ARRAY/BinaryArray - virtual void PutDictionary(const ::arrow::Array& values) = 0; + /// non-empty. + /// \param[in] values The dictionary values. Only valid for certain + /// Parquet/Arrow type combinations, like BYTE_ARRAY/BinaryArray. + virtual void putDictionary(const ::arrow::Array& values) = 0; }; -// ---------------------------------------------------------------------- -// Value decoding +// ----------------------------------------------------------------------. +// Value decoding. class Decoder { public: @@ -261,145 +261,145 @@ class Decoder { // Sets the data for a new page. This will be called multiple times on the // same decoder and should reset all internal state. - virtual void SetData(int num_values, const uint8_t* data, int len) = 0; + virtual void setData(int numValues, const uint8_t* data, int len) = 0; - // Returns the number of values left (for the last call to SetData()). This is - // the number of values left in this page. - virtual int values_left() const = 0; + // Returns the number of values left (for the last call to SetData()). This + // is the number of values left in this page. + virtual int valuesLeft() const = 0; virtual Encoding::type encoding() const = 0; }; template class TypedDecoder : virtual public Decoder { public: - using T = typename DType::c_type; + using T = typename DType::CType; - /// \brief Decode values into a buffer + /// \brief Decode values into a buffer. /// /// Subclasses may override the more specialized Decode methods below. /// - /// \param[in] buffer destination for decoded values - /// \param[in] max_values maximum number of values to decode + /// \param[in] buffer Destination for decoded values. + /// \param[in] max_values Maximum number of values to decode. /// \return The number of values decoded. Should be identical to max_values /// except at the end of the current data page. - virtual int Decode(T* buffer, int max_values) = 0; + virtual int decode(T* buffer, int maxValues) = 0; /// \brief Decode the values in this data page but leave spaces for null /// entries. /// - /// \param[in] buffer destination for decoded values - /// \param[in] num_values size of the def_levels and buffer arrays including - /// the number of null slots \param[in] null_count number of null slots - /// \param[in] valid_bits bitmap data indicating position of valid slots - /// \param[in] valid_bits_offset offset into valid_bits + /// \param[in] buffer Destination for decoded values. + /// \param[in] num_values Size of the def_levels and buffer arrays including + /// the number of null slots \param[in] null_count Number of null slots. + /// \param[in] valid_bits Bitmap data indicating position of valid slots. + /// \param[in] valid_bits_offset Offset into valid_bits. /// \return The number of values decoded, including nulls. - virtual int DecodeSpaced( + virtual int decodeSpaced( T* buffer, - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset) { - if (null_count > 0) { - int values_to_read = num_values - null_count; - int values_read = Decode(buffer, values_to_read); - if (values_read != values_to_read) { + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset) { + if (nullCount > 0) { + int valuesToRead = numValues - nullCount; + int valuesRead = decode(buffer, valuesToRead); + if (valuesRead != valuesToRead) { throw ParquetException( "Number of values / definition_levels read did not match"); } return ::arrow::util::internal::SpacedExpand( - buffer, num_values, null_count, valid_bits, valid_bits_offset); + buffer, numValues, nullCount, validBits, validBitsOffset); } else { - return Decode(buffer, num_values); + return decode(buffer, numValues); } } - /// \brief Decode into an ArrayBuilder or other accumulator + /// \brief Decode into an ArrayBuilder or other accumulator. /// - /// This function assumes the definition levels were already decoded - /// as a validity bitmap in the given `valid_bits`. `null_count` + /// This function assumes the definition levels were already decoded. + /// As a validity bitmap in the given `valid_bits`. `null_count` /// is the number of 0s in `valid_bits`. /// As a space optimization, it is allowed for `valid_bits` to be null /// if `null_count` is zero. /// - /// \return number of values decoded - virtual int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + /// \return Number of values decoded. + virtual int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::Accumulator* out) = 0; - /// \brief Decode into an ArrayBuilder or other accumulator ignoring nulls + /// \brief Decode into an ArrayBuilder or other accumulator ignoring nulls. /// - /// \return number of values decoded - int DecodeArrowNonNull( - int num_values, + /// \return Number of values decoded. + int decodeArrowNonNull( + int numValues, typename EncodingTraits::Accumulator* out) { - return DecodeArrow(num_values, 0, /*valid_bits=*/NULLPTR, 0, out); + return decodeArrow(numValues, 0, NULLPTR, 0, out); } - /// \brief Decode into a DictionaryBuilder + /// \brief Decode into a DictionaryBuilder. /// - /// This function assumes the definition levels were already decoded - /// as a validity bitmap in the given `valid_bits`. `null_count` + /// This function assumes the definition levels were already decoded. + /// As a validity bitmap in the given `valid_bits`. `null_count` /// is the number of 0s in `valid_bits`. /// As a space optimization, it is allowed for `valid_bits` to be null /// if `null_count` is zero. /// - /// \return number of values decoded - virtual int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) = 0; - - /// \brief Decode into a DictionaryBuilder ignoring nulls + /// \return Number of values decoded. + virtual int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) = 0; + + /// \brief Decode into a DictionaryBuilder ignoring nulls. /// - /// \return number of values decoded - int DecodeArrowNonNull( - int num_values, - typename EncodingTraits::DictAccumulator* builder) { - return DecodeArrow(num_values, 0, /*valid_bits=*/NULLPTR, 0, builder); + /// \return Number of values decoded. + int decodeArrowNonNull( + int numValues, + typename EncodingTraits::DictAccumulator* Builder) { + return decodeArrow(numValues, 0, NULLPTR, 0, Builder); } }; template class DictDecoder : virtual public TypedDecoder { public: - using T = typename DType::c_type; + using T = typename DType::CType; - virtual void SetDict(TypedDecoder* dictionary) = 0; + virtual void setDict(TypedDecoder* dictionary) = 0; /// \brief Insert dictionary values into the Arrow dictionary builder's memo, - /// but do not append any indices - virtual void InsertDictionary(::arrow::ArrayBuilder* builder) = 0; + /// but do not append any indices. + virtual void insertDictionary(::arrow::ArrayBuilder* Builder) = 0; /// \brief Decode only dictionary indices and append to dictionary /// builder. The builder must have had the dictionary from this decoder /// inserted already. /// /// \warning Remember to reset the builder each time the dict decoder is - /// initialized with a new dictionary page - virtual int DecodeIndicesSpaced( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - ::arrow::ArrayBuilder* builder) = 0; - - /// \brief Decode only dictionary indices (no nulls) + /// initialized with a new dictionary page. + virtual int decodeIndicesSpaced( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + ::arrow::ArrayBuilder* Builder) = 0; + + /// \brief Decode only dictionary indices (no nulls). /// /// \warning Remember to reset the builder each time the dict decoder is - /// initialized with a new dictionary page - virtual int DecodeIndices(int num_values, ::arrow::ArrayBuilder* builder) = 0; + /// initialized with a new dictionary page. + virtual int decodeIndices(int numValues, ::arrow::ArrayBuilder* Builder) = 0; /// \brief Decode only dictionary indices (no nulls). Same as above - /// DecodeIndices but target is an array instead of a builder. + /// decodeIndices but target is an array instead of a builder. /// - /// \note API EXPERIMENTAL - virtual int DecodeIndices(int num_values, int32_t* indices) = 0; + /// \note API EXPERIMENTAL. + virtual int decodeIndices(int numValues, int32_t* indices) = 0; /// \brief Get dictionary. The reader will call this API when it encounters a /// new dictionary. @@ -408,62 +408,62 @@ class DictDecoder : virtual public TypedDecoder { /// owned by the decoder and is destroyed when the decoder is destroyed. /// @param[out] dictionary_length The dictionary length. /// - /// \note API EXPERIMENTAL - virtual void GetDictionary( + /// \note API EXPERIMENTAL. + virtual void getDictionary( const T** dictionary, - int32_t* dictionary_length) = 0; + int32_t* dictionaryLength) = 0; }; -// ---------------------------------------------------------------------- -// TypedEncoder specializations, traits, and factory functions +// ----------------------------------------------------------------------. +// TypedEncoder specializations, traits, and factory functions. class BooleanDecoder : virtual public TypedDecoder { public: - using TypedDecoder::Decode; + using TypedDecoder::decode; - /// \brief Decode and bit-pack values into a buffer + /// \brief Decode and bit-pack values into a buffer. /// - /// \param[in] buffer destination for decoded values + /// \param[in] buffer Destination for decoded values. /// This buffer will contain bit-packed values. - /// \param[in] max_values max values to decode. + /// \param[in] max_values Max values to decode. /// \return The number of values decoded. Should be identical to max_values /// except at the end of the current data page. - virtual int Decode(uint8_t* buffer, int max_values) = 0; + virtual int decode(uint8_t* buffer, int maxValues) = 0; }; class FLBADecoder : virtual public TypedDecoder { public: - using TypedDecoder::DecodeSpaced; + using TypedDecoder::decodeSpaced; // TODO(wesm): As possible follow-up to PARQUET-1508, we should examine if // there is value in adding specialized read methods for // FIXED_LEN_BYTE_ARRAY. If only Decimal data can occur with this data type - // then perhaps not + // then perhaps not. }; PARQUET_EXPORT -std::unique_ptr MakeEncoder( - Type::type type_num, +std::unique_ptr makeEncoder( + Type::type typeNum, Encoding::type encoding, - bool use_dictionary = false, + bool useDictionary = false, const ColumnDescriptor* descr = NULLPTR, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); template -std::unique_ptr::Encoder> MakeTypedEncoder( +std::unique_ptr::Encoder> makeTypedEncoder( Encoding::type encoding, - bool use_dictionary = false, + bool useDictionary = false, const ColumnDescriptor* descr = NULLPTR, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()) { using OutType = typename EncodingTraits::Encoder; std::unique_ptr base = - MakeEncoder(DType::type_num, encoding, use_dictionary, descr, pool); + makeEncoder(DType::typeNum, encoding, useDictionary, descr, pool); return std::unique_ptr(dynamic_cast(base.release())); } PARQUET_EXPORT -std::unique_ptr MakeDecoder( - Type::type type_num, +std::unique_ptr makeDecoder( + Type::type typeNum, Encoding::type encoding, const ColumnDescriptor* descr = NULLPTR, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); @@ -471,30 +471,30 @@ std::unique_ptr MakeDecoder( namespace detail { PARQUET_EXPORT -std::unique_ptr MakeDictDecoder( - Type::type type_num, +std::unique_ptr makeDictDecoder( + Type::type typeNum, const ColumnDescriptor* descr, ::arrow::MemoryPool* pool); } // namespace detail template -std::unique_ptr> MakeDictDecoder( +std::unique_ptr> makeDictDecoder( const ColumnDescriptor* descr = NULLPTR, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()) { using OutType = DictDecoder; - auto decoder = detail::MakeDictDecoder(DType::type_num, descr, pool); + auto decoder = detail::makeDictDecoder(DType::typeNum, descr, pool); return std::unique_ptr(dynamic_cast(decoder.release())); } template -std::unique_ptr::Decoder> MakeTypedDecoder( +std::unique_ptr::Decoder> makeTypedDecoder( Encoding::type encoding, const ColumnDescriptor* descr = NULLPTR, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()) { using OutType = typename EncodingTraits::Decoder; std::unique_ptr base = - MakeDecoder(DType::type_num, encoding, descr, pool); + makeDecoder(DType::typeNum, encoding, descr, pool); return std::unique_ptr(dynamic_cast(base.release())); } diff --git a/velox/dwio/parquet/writer/arrow/Encryption.cpp b/velox/dwio/parquet/writer/arrow/Encryption.cpp index 2aaa06f680e..984ec118416 100644 --- a/velox/dwio/parquet/writer/arrow/Encryption.cpp +++ b/velox/dwio/parquet/writer/arrow/Encryption.cpp @@ -30,159 +30,159 @@ namespace facebook::velox::parquet::arrow { -// integer key retriever -void IntegerKeyIdRetriever::PutKey(uint32_t key_id, const std::string& key) { - key_map_.insert({key_id, key}); +// Integer key retriever. +void IntegerKeyIdRetriever::putKey(uint32_t keyId, const std::string& key) { + keyMap_.insert({keyId, key}); } -std::string IntegerKeyIdRetriever::GetKey(const std::string& key_metadata) { - uint32_t key_id; - memcpy(reinterpret_cast(&key_id), key_metadata.c_str(), 4); +std::string IntegerKeyIdRetriever::getKey(const std::string& keyMetadata) { + uint32_t keyId; + memcpy(reinterpret_cast(&keyId), keyMetadata.c_str(), 4); - return key_map_.at(key_id); + return keyMap_.at(keyId); } -// string key retriever -void StringKeyIdRetriever::PutKey( - const std::string& key_id, +// String key retriever. +void StringKeyIdRetriever::putKey( + const std::string& keyId, const std::string& key) { - key_map_.insert({key_id, key}); + keyMap_.insert({keyId, key}); } -std::string StringKeyIdRetriever::GetKey(const std::string& key_id) { - return key_map_.at(key_id); +std::string StringKeyIdRetriever::getKey(const std::string& keyId) { + return keyMap_.at(keyId); } ColumnEncryptionProperties::Builder* ColumnEncryptionProperties::Builder::key( - std::string column_key) { - if (column_key.empty()) + std::string columnKey) { + if (columnKey.empty()) return this; VELOX_DCHECK(key_.empty()); - key_ = column_key; + key_ = columnKey; return this; } ColumnEncryptionProperties::Builder* -ColumnEncryptionProperties::Builder::key_metadata( - const std::string& key_metadata) { - VELOX_DCHECK(!key_metadata.empty()); - VELOX_DCHECK(key_metadata_.empty()); - key_metadata_ = key_metadata; +ColumnEncryptionProperties::Builder::keyMetadata( + const std::string& keyMetadata) { + VELOX_DCHECK(!keyMetadata.empty()); + VELOX_DCHECK(keyMetadata_.empty()); + keyMetadata_ = keyMetadata; return this; } -ColumnEncryptionProperties::Builder* -ColumnEncryptionProperties::Builder::key_id(const std::string& key_id) { - // key_id is expected to be in UTF8 encoding +ColumnEncryptionProperties::Builder* ColumnEncryptionProperties::Builder::keyId( + const std::string& keyId) { + // Key_id is expected to be in UTF8 encoding. ::arrow::util::InitializeUTF8(); - const uint8_t* data = reinterpret_cast(key_id.c_str()); - if (!::arrow::util::ValidateUTF8(data, key_id.size())) { + const uint8_t* data = reinterpret_cast(keyId.c_str()); + if (!::arrow::util::ValidateUTF8(data, keyId.size())) { throw ParquetException("key id should be in UTF8 encoding"); } - VELOX_DCHECK(!key_id.empty()); - this->key_metadata(key_id); + VELOX_DCHECK(!keyId.empty()); + this->keyMetadata(keyId); return this; } FileDecryptionProperties::Builder* -FileDecryptionProperties::Builder::column_keys( - const ColumnPathToDecryptionPropertiesMap& column_decryption_properties) { - if (column_decryption_properties.size() == 0) +FileDecryptionProperties::Builder::columnKeys( + const ColumnPathToDecryptionPropertiesMap& ColumnDecryptionProperties) { + if (ColumnDecryptionProperties.size() == 0) return this; - if (column_decryption_properties_.size() != 0) + if (columnDecryptionProperties_.size() != 0) throw ParquetException("Column properties already set"); - for (const auto& element : column_decryption_properties) { - if (element.second->is_utilized()) { + for (const auto& element : ColumnDecryptionProperties) { + if (element.second->isUtilized()) { throw ParquetException("Column properties utilized in another file"); } - element.second->set_utilized(); + element.second->setUtilized(); } - column_decryption_properties_ = column_decryption_properties; + columnDecryptionProperties_ = ColumnDecryptionProperties; return this; } -void FileDecryptionProperties::WipeOutDecryptionKeys() { - footer_key_.clear(); +void FileDecryptionProperties::wipeOutDecryptionKeys() { + footerKey_.clear(); - for (const auto& element : column_decryption_properties_) { - element.second->WipeOutDecryptionKey(); + for (const auto& element : columnDecryptionProperties_) { + element.second->wipeOutDecryptionKey(); } } -bool FileDecryptionProperties::is_utilized() { - if (footer_key_.empty() && column_decryption_properties_.size() == 0 && - aad_prefix_.empty()) +bool FileDecryptionProperties::isUtilized() { + if (footerKey_.empty() && columnDecryptionProperties_.size() == 0 && + aadPrefix_.empty()) return false; return utilized_; } -std::shared_ptr FileDecryptionProperties::DeepClone( - std::string new_aad_prefix) { - std::string footer_key_copy = footer_key_; - ColumnPathToDecryptionPropertiesMap column_decryption_properties_map_copy; +std::shared_ptr FileDecryptionProperties::deepClone( + std::string newAadPrefix) { + std::string footerKeyCopy = footerKey_; + ColumnPathToDecryptionPropertiesMap columnDecryptionPropertiesMapCopy; - for (const auto& element : column_decryption_properties_) { - column_decryption_properties_map_copy.insert( - {element.second->column_path(), element.second->DeepClone()}); + for (const auto& element : columnDecryptionProperties_) { + columnDecryptionPropertiesMapCopy.insert( + {element.second->columnPath(), element.second->deepClone()}); } - if (new_aad_prefix.empty()) - new_aad_prefix = aad_prefix_; + if (newAadPrefix.empty()) + newAadPrefix = aadPrefix_; return std::shared_ptr(new FileDecryptionProperties( - footer_key_copy, - key_retriever_, - check_plaintext_footer_integrity_, - new_aad_prefix, - aad_prefix_verifier_, - column_decryption_properties_map_copy, - plaintext_files_allowed_)); -} - -FileDecryptionProperties::Builder* -FileDecryptionProperties::Builder::footer_key(const std::string footer_key) { - if (footer_key.empty()) { + footerKeyCopy, + keyRetriever_, + checkPlaintextFooterIntegrity_, + newAadPrefix, + aadPrefixVerifier_, + columnDecryptionPropertiesMapCopy, + plaintextFilesAllowed_)); +} + +FileDecryptionProperties::Builder* FileDecryptionProperties::Builder::footerKey( + const std::string footerKey) { + if (footerKey.empty()) { return this; } - VELOX_DCHECK(footer_key_.empty()); - footer_key_ = footer_key; + VELOX_DCHECK(footerKey_.empty()); + footerKey_ = footerKey; return this; } FileDecryptionProperties::Builder* -FileDecryptionProperties::Builder::key_retriever( - const std::shared_ptr& key_retriever) { - if (key_retriever == nullptr) +FileDecryptionProperties::Builder::keyRetriever( + const std::shared_ptr& keyRetriever) { + if (keyRetriever == nullptr) return this; - VELOX_DCHECK_NULL(key_retriever_); - key_retriever_ = key_retriever; + VELOX_DCHECK_NULL(keyRetriever_); + keyRetriever_ = keyRetriever; return this; } -FileDecryptionProperties::Builder* -FileDecryptionProperties::Builder::aad_prefix(const std::string& aad_prefix) { - if (aad_prefix.empty()) { +FileDecryptionProperties::Builder* FileDecryptionProperties::Builder::aadPrefix( + const std::string& aadPrefix) { + if (aadPrefix.empty()) { return this; } - VELOX_DCHECK(aad_prefix_.empty()); - aad_prefix_ = aad_prefix; + VELOX_DCHECK(aadPrefix_.empty()); + aadPrefix_ = aadPrefix; return this; } FileDecryptionProperties::Builder* -FileDecryptionProperties::Builder::aad_prefix_verifier( - std::shared_ptr aad_prefix_verifier) { - if (aad_prefix_verifier == nullptr) +FileDecryptionProperties::Builder::aadPrefixVerifier( + std::shared_ptr aadPrefixVerifier) { + if (aadPrefixVerifier == nullptr) return this; - VELOX_DCHECK_NULL(aad_prefix_verifier_); - aad_prefix_verifier_ = std::move(aad_prefix_verifier); + VELOX_DCHECK_NULL(aadPrefixVerifier_); + aadPrefixVerifier_ = std::move(aadPrefixVerifier); return this; } @@ -199,112 +199,112 @@ ColumnDecryptionProperties::Builder* ColumnDecryptionProperties::Builder::key( std::shared_ptr ColumnDecryptionProperties::Builder::build() { return std::shared_ptr( - new ColumnDecryptionProperties(column_path_, key_)); + new ColumnDecryptionProperties(columnPath_, key_)); } -void ColumnDecryptionProperties::WipeOutDecryptionKey() { +void ColumnDecryptionProperties::wipeOutDecryptionKey() { key_.clear(); } std::shared_ptr -ColumnDecryptionProperties::DeepClone() { - std::string key_copy = key_; +ColumnDecryptionProperties::deepClone() { + std::string keyCopy = key_; return std::shared_ptr( - new ColumnDecryptionProperties(column_path_, key_copy)); + new ColumnDecryptionProperties(columnPath_, keyCopy)); } FileEncryptionProperties::Builder* -FileEncryptionProperties::Builder::footer_key_metadata( - const std::string& footer_key_metadata) { - if (footer_key_metadata.empty()) +FileEncryptionProperties::Builder::footerKeyMetadata( + const std::string& footerKeyMetadata) { + if (footerKeyMetadata.empty()) return this; - VELOX_DCHECK(footer_key_metadata_.empty()); - footer_key_metadata_ = footer_key_metadata; + VELOX_DCHECK(footerKeyMetadata_.empty()); + footerKeyMetadata_ = footerKeyMetadata; return this; } FileEncryptionProperties::Builder* -FileEncryptionProperties::Builder::encrypted_columns( - const ColumnPathToEncryptionPropertiesMap& encrypted_columns) { - if (encrypted_columns.size() == 0) +FileEncryptionProperties::Builder::encryptedColumns( + const ColumnPathToEncryptionPropertiesMap& encryptedColumns) { + if (encryptedColumns.size() == 0) return this; - if (encrypted_columns_.size() != 0) + if (encryptedColumns_.size() != 0) throw ParquetException("Column properties already set"); - for (const auto& element : encrypted_columns) { - if (element.second->is_utilized()) { + for (const auto& element : encryptedColumns) { + if (element.second->isUtilized()) { throw ParquetException("Column properties utilized in another file"); } - element.second->set_utilized(); + element.second->setUtilized(); } - encrypted_columns_ = encrypted_columns; + encryptedColumns_ = encryptedColumns; return this; } -void FileEncryptionProperties::WipeOutEncryptionKeys() { - footer_key_.clear(); - for (const auto& element : encrypted_columns_) { - element.second->WipeOutEncryptionKey(); +void FileEncryptionProperties::wipeOutEncryptionKeys() { + footerKey_.clear(); + for (const auto& element : encryptedColumns_) { + element.second->wipeOutEncryptionKey(); } } -std::shared_ptr FileEncryptionProperties::DeepClone( - std::string new_aad_prefix) { - std::string footer_key_copy = footer_key_; - ColumnPathToEncryptionPropertiesMap encrypted_columns_map_copy; +std::shared_ptr FileEncryptionProperties::deepClone( + std::string newAadPrefix) { + std::string footerKeyCopy = footerKey_; + ColumnPathToEncryptionPropertiesMap encryptedColumnsMapCopy; - for (const auto& element : encrypted_columns_) { - encrypted_columns_map_copy.insert( - {element.second->column_path(), element.second->DeepClone()}); + for (const auto& element : encryptedColumns_) { + encryptedColumnsMapCopy.insert( + {element.second->columnPath(), element.second->deepClone()}); } - if (new_aad_prefix.empty()) - new_aad_prefix = aad_prefix_; + if (newAadPrefix.empty()) + newAadPrefix = aadPrefix_; return std::shared_ptr(new FileEncryptionProperties( algorithm_.algorithm, - footer_key_copy, - footer_key_metadata_, - encrypted_footer_, - new_aad_prefix, - store_aad_prefix_in_file_, - encrypted_columns_map_copy)); + footerKeyCopy, + footerKeyMetadata_, + encryptedFooter_, + newAadPrefix, + storeAadPrefixInFile_, + encryptedColumnsMapCopy)); } -FileEncryptionProperties::Builder* -FileEncryptionProperties::Builder::aad_prefix(const std::string& aad_prefix) { - if (aad_prefix.empty()) +FileEncryptionProperties::Builder* FileEncryptionProperties::Builder::aadPrefix( + const std::string& aadPrefix) { + if (aadPrefix.empty()) return this; - VELOX_DCHECK(aad_prefix_.empty()); - aad_prefix_ = aad_prefix; - store_aad_prefix_in_file_ = true; + VELOX_DCHECK(aadPrefix_.empty()); + aadPrefix_ = aadPrefix; + storeAadPrefixInFile_ = true; return this; } FileEncryptionProperties::Builder* -FileEncryptionProperties::Builder::disable_aad_prefix_storage() { - VELOX_DCHECK(!aad_prefix_.empty()); +FileEncryptionProperties::Builder::disableAadPrefixStorage() { + VELOX_DCHECK(!aadPrefix_.empty()); - store_aad_prefix_in_file_ = false; + storeAadPrefixInFile_ = false; return this; } ColumnEncryptionProperties::ColumnEncryptionProperties( bool encrypted, - const std::string& column_path, + const std::string& ColumnPath, const std::string& key, - const std::string& key_metadata) - : column_path_(column_path) { - // column encryption properties object (with a column key) can be used for - // writing only one file. Upon completion of file writing, the encryption keys - // in the properties will be wiped out (set to 0 in memory). + const std::string& keyMetadata) + : columnPath_(ColumnPath) { + // Column encryption properties object (with a column key) can be used for. + // Writing only one file. Upon completion of file writing, the encryption + // keys. In the properties will be wiped out (set to 0 in memory). utilized_ = false; - VELOX_DCHECK(!column_path.empty()); + VELOX_DCHECK(!ColumnPath.empty()); if (!encrypted) { - VELOX_DCHECK(key.empty() && key_metadata.empty()); + VELOX_DCHECK(key.empty() && keyMetadata.empty()); } if (!key.empty()) { @@ -312,22 +312,22 @@ ColumnEncryptionProperties::ColumnEncryptionProperties( key.length() == 16 || key.length() == 24 || key.length() == 32); } - encrypted_with_footer_key_ = (encrypted && key.empty()); - if (encrypted_with_footer_key_) { - VELOX_DCHECK(key_metadata.empty()); + encryptedWithFooterKey_ = (encrypted && key.empty()); + if (encryptedWithFooterKey_) { + VELOX_DCHECK(keyMetadata.empty()); } encrypted_ = encrypted; - key_metadata_ = key_metadata; + keyMetadata_ = keyMetadata; key_ = key; } ColumnDecryptionProperties::ColumnDecryptionProperties( - const std::string& column_path, + const std::string& ColumnPath, const std::string& key) - : column_path_(column_path) { + : columnPath_(ColumnPath) { utilized_ = false; - VELOX_DCHECK(!column_path.empty()); + VELOX_DCHECK(!ColumnPath.empty()); if (!key.empty()) { VELOX_DCHECK( @@ -337,74 +337,74 @@ ColumnDecryptionProperties::ColumnDecryptionProperties( key_ = key; } -std::string FileDecryptionProperties::column_key( - const std::string& column_path) const { - if (column_decryption_properties_.find(column_path) != - column_decryption_properties_.end()) { - auto column_prop = column_decryption_properties_.at(column_path); - if (column_prop != nullptr) { - return column_prop->key(); +std::string FileDecryptionProperties::columnKey( + const std::string& ColumnPath) const { + if (columnDecryptionProperties_.find(ColumnPath) != + columnDecryptionProperties_.end()) { + auto columnProp = columnDecryptionProperties_.at(ColumnPath); + if (columnProp != nullptr) { + return columnProp->key(); } } - return empty_string_; + return emptyString_; } FileDecryptionProperties::FileDecryptionProperties( - const std::string& footer_key, - std::shared_ptr key_retriever, - bool check_plaintext_footer_integrity, - const std::string& aad_prefix, - std::shared_ptr aad_prefix_verifier, - const ColumnPathToDecryptionPropertiesMap& column_decryption_properties, - bool plaintext_files_allowed) { + const std::string& footerKey, + std::shared_ptr keyRetriever, + bool checkPlaintextFooterIntegrity, + const std::string& aadPrefix, + std::shared_ptr aadPrefixVerifier, + const ColumnPathToDecryptionPropertiesMap& ColumnDecryptionProperties, + bool plaintextFilesAllowed) { VELOX_DCHECK( - !footer_key.empty() || nullptr != key_retriever || - 0 != column_decryption_properties.size()); + !footerKey.empty() || nullptr != keyRetriever || + 0 != ColumnDecryptionProperties.size()); - if (!footer_key.empty()) { + if (!footerKey.empty()) { VELOX_DCHECK( - footer_key.length() == 16 || footer_key.length() == 24 || - footer_key.length() == 32); + footerKey.length() == 16 || footerKey.length() == 24 || + footerKey.length() == 32); } - if (footer_key.empty() && check_plaintext_footer_integrity) { - VELOX_DCHECK_NOT_NULL(key_retriever); + if (footerKey.empty() && checkPlaintextFooterIntegrity) { + VELOX_DCHECK_NOT_NULL(keyRetriever); } - aad_prefix_verifier_ = std::move(aad_prefix_verifier); - footer_key_ = footer_key; - check_plaintext_footer_integrity_ = check_plaintext_footer_integrity; - key_retriever_ = std::move(key_retriever); - aad_prefix_ = aad_prefix; - column_decryption_properties_ = column_decryption_properties; - plaintext_files_allowed_ = plaintext_files_allowed; + aadPrefixVerifier_ = std::move(aadPrefixVerifier); + footerKey_ = footerKey; + checkPlaintextFooterIntegrity_ = checkPlaintextFooterIntegrity; + keyRetriever_ = std::move(keyRetriever); + aadPrefix_ = aadPrefix; + columnDecryptionProperties_ = ColumnDecryptionProperties; + plaintextFilesAllowed_ = plaintextFilesAllowed; utilized_ = false; } FileEncryptionProperties::Builder* -FileEncryptionProperties::Builder::footer_key_id(const std::string& key_id) { - // key_id is expected to be in UTF8 encoding +FileEncryptionProperties::Builder::footerKeyId(const std::string& keyId) { + // Key_id is expected to be in UTF8 encoding. ::arrow::util::InitializeUTF8(); - const uint8_t* data = reinterpret_cast(key_id.c_str()); - if (!::arrow::util::ValidateUTF8(data, key_id.size())) { + const uint8_t* data = reinterpret_cast(keyId.c_str()); + if (!::arrow::util::ValidateUTF8(data, keyId.size())) { throw ParquetException("footer key id should be in UTF8 encoding"); } - if (key_id.empty()) { + if (keyId.empty()) { return this; } - return footer_key_metadata(key_id); + return footerKeyMetadata(keyId); } std::shared_ptr -FileEncryptionProperties::column_encryption_properties( - const std::string& column_path) { - if (encrypted_columns_.size() == 0) { +FileEncryptionProperties::columnEncryptionProperties( + const std::string& columnPath) { + if (encryptedColumns_.empty()) { auto builder = - std::make_shared(column_path); + std::make_shared(columnPath); return builder->build(); } - if (encrypted_columns_.find(column_path) != encrypted_columns_.end()) { - return encrypted_columns_[column_path]; + if (encryptedColumns_.find(columnPath) != encryptedColumns_.end()) { + return encryptedColumns_[columnPath]; } return nullptr; @@ -412,47 +412,47 @@ FileEncryptionProperties::column_encryption_properties( FileEncryptionProperties::FileEncryptionProperties( ParquetCipher::type cipher, - const std::string& footer_key, - const std::string& footer_key_metadata, - bool encrypted_footer, - const std::string& aad_prefix, - bool store_aad_prefix_in_file, - const ColumnPathToEncryptionPropertiesMap& encrypted_columns) - : footer_key_(footer_key), - footer_key_metadata_(footer_key_metadata), - encrypted_footer_(encrypted_footer), - aad_prefix_(aad_prefix), - store_aad_prefix_in_file_(store_aad_prefix_in_file), - encrypted_columns_(encrypted_columns) { - // file encryption properties object can be used for writing only one file. - // Upon completion of file writing, the encryption keys in the properties will - // be wiped out (set to 0 in memory). + const std::string& footerKey, + const std::string& footerKeyMetadata, + bool encryptedFooter, + const std::string& aadPrefix, + bool storeAadPrefixInFile, + const ColumnPathToEncryptionPropertiesMap& encryptedColumns) + : footerKey_(footerKey), + footerKeyMetadata_(footerKeyMetadata), + encryptedFooter_(encryptedFooter), + aadPrefix_(aadPrefix), + storeAadPrefixInFile_(storeAadPrefixInFile), + encryptedColumns_(encryptedColumns) { + // File encryption properties object can be used for writing only one file. + // Upon completion of file writing, the encryption keys in the properties + // will. Be wiped out (set to 0 in memory). utilized_ = false; - VELOX_DCHECK(!footer_key.empty()); - // footer_key must be either 16, 24 or 32 bytes. + VELOX_DCHECK(!footerKey.empty()); + // Footer_key must be either 16, 24 or 32 bytes. VELOX_DCHECK( - footer_key.length() == 16 || footer_key.length() == 24 || - footer_key.length() == 32); + footerKey.length() == 16 || footerKey.length() == 24 || + footerKey.length() == 32); - uint8_t aad_file_unique[kAadFileUniqueLength]; - encryption::RandBytes(aad_file_unique, kAadFileUniqueLength); - std::string aad_file_unique_str( - reinterpret_cast(aad_file_unique), kAadFileUniqueLength); + uint8_t aadFileUnique[kAadFileUniqueLength]; + encryption::randBytes(aadFileUnique, kAadFileUniqueLength); + std::string aadFileUniqueStr( + reinterpret_cast(aadFileUnique), kAadFileUniqueLength); - bool supply_aad_prefix = false; - if (aad_prefix.empty()) { - file_aad_ = aad_file_unique_str; + bool supplyAadPrefix = false; + if (aadPrefix.empty()) { + fileAad_ = aadFileUniqueStr; } else { - file_aad_ = aad_prefix + aad_file_unique_str; - if (!store_aad_prefix_in_file) - supply_aad_prefix = true; + fileAad_ = aadPrefix + aadFileUniqueStr; + if (!storeAadPrefixInFile) + supplyAadPrefix = true; } algorithm_.algorithm = cipher; - algorithm_.aad.aad_file_unique = aad_file_unique_str; - algorithm_.aad.supply_aad_prefix = supply_aad_prefix; - if (!aad_prefix.empty() && store_aad_prefix_in_file) { - algorithm_.aad.aad_prefix = aad_prefix; + algorithm_.aad.aadFileUnique = aadFileUniqueStr; + algorithm_.aad.supplyAadPrefix = supplyAadPrefix; + if (!aadPrefix.empty() && storeAadPrefixInFile) { + algorithm_.aad.aadPrefix = aadPrefix; } } diff --git a/velox/dwio/parquet/writer/arrow/Encryption.h b/velox/dwio/parquet/writer/arrow/Encryption.h index df310589a43..1979f189fea 100644 --- a/velox/dwio/parquet/writer/arrow/Encryption.h +++ b/velox/dwio/parquet/writer/arrow/Encryption.h @@ -30,7 +30,7 @@ namespace facebook::velox::parquet::arrow { static constexpr ParquetCipher::type kDefaultEncryptionAlgorithm = - ParquetCipher::AES_GCM_V1; + ParquetCipher::kAesGcmV1; static constexpr int32_t kMaximalAadMetadataLength = 256; static constexpr bool kDefaultEncryptedFooter = true; static constexpr bool kDefaultCheckSignature = true; @@ -47,28 +47,28 @@ using ColumnPathToEncryptionPropertiesMap = class PARQUET_EXPORT DecryptionKeyRetriever { public: - virtual std::string GetKey(const std::string& key_metadata) = 0; + virtual std::string getKey(const std::string& keyMetadata) = 0; virtual ~DecryptionKeyRetriever() {} }; -/// Simple integer key retriever +/// Simple integer key retriever. class PARQUET_EXPORT IntegerKeyIdRetriever : public DecryptionKeyRetriever { public: - void PutKey(uint32_t key_id, const std::string& key); - std::string GetKey(const std::string& key_metadata) override; + void putKey(uint32_t keyId, const std::string& key); + std::string getKey(const std::string& keyMetadata) override; private: - std::map key_map_; + std::map keyMap_; }; -// Simple string key retriever +// Simple string key retriever. class PARQUET_EXPORT StringKeyIdRetriever : public DecryptionKeyRetriever { public: - void PutKey(const std::string& key_id, const std::string& key); - std::string GetKey(const std::string& key_metadata) override; + void putKey(const std::string& keyId, const std::string& key); + std::string getKey(const std::string& keyMetadata) override; private: - std::map key_map_; + std::map keyMap_; }; class PARQUET_EXPORT HiddenColumnException : public ParquetException { @@ -100,66 +100,66 @@ class PARQUET_EXPORT ColumnEncryptionProperties { /// Convenience builder for encrypted columns. explicit Builder(const std::shared_ptr& path) - : Builder(path->ToDotString(), true) {} + : Builder(path->toDotString(), true) {} /// Set a column-specific key. /// If key is not set on an encrypted column, the column will /// be encrypted with the footer key. - /// keyBytes Key length must be either 16, 24 or 32 bytes. + /// KeyBytes Key length must be either 16, 24 or 32 bytes. /// The key is cloned, and will be wiped out (array values set to 0) upon - /// completion of file writing. Caller is responsible for wiping out the + /// completion of file writing. Caller is responsible for wiping out the. /// input key array. - Builder* key(std::string column_key); + Builder* key(std::string columnKey); /// Set a key retrieval metadata. - /// use either key_metadata() or key_id(), not both - Builder* key_metadata(const std::string& key_metadata); + /// Use either key_metadata() or key_id(), not both. + Builder* keyMetadata(const std::string& keyMetadata); /// A convenience function to set key metadata using a string id. /// Set a key retrieval metadata (converted from String). - /// use either key_metadata() or key_id(), not both + /// Use either key_metadata() or key_id(), not both. /// key_id will be converted to metadata (UTF-8 array). - Builder* key_id(const std::string& key_id); + Builder* keyId(const std::string& keyId); std::shared_ptr build() { return std::shared_ptr( new ColumnEncryptionProperties( - encrypted_, column_path_, key_, key_metadata_)); + encrypted_, columnPath_, key_, keyMetadata_)); } private: - const std::string column_path_; + const std::string columnPath_; bool encrypted_; std::string key_; - std::string key_metadata_; + std::string keyMetadata_; Builder(const std::string path, bool encrypted) - : column_path_(path), encrypted_(encrypted) {} + : columnPath_(path), encrypted_(encrypted) {} }; - std::string column_path() const { - return column_path_; + std::string columnPath() const { + return columnPath_; } - bool is_encrypted() const { + bool isEncrypted() const { return encrypted_; } - bool is_encrypted_with_footer_key() const { - return encrypted_with_footer_key_; + bool isEncryptedWithFooterKey() const { + return encryptedWithFooterKey_; } std::string key() const { return key_; } - std::string key_metadata() const { - return key_metadata_; + std::string keyMetadata() const { + return keyMetadata_; } /// Upon completion of file writing, the encryption key /// will be wiped out. - void WipeOutEncryptionKey() { + void wipeOutEncryptionKey() { key_.clear(); } - bool is_utilized() { + bool isUtilized() { if (key_.empty()) return false; // can re-use column properties without encryption keys return utilized_; @@ -169,15 +169,15 @@ class PARQUET_EXPORT ColumnEncryptionProperties { /// Mark ColumnEncryptionProperties as utilized once it is used in /// FileEncryptionProperties as the encryption key will be wiped out upon /// completion of file writing. - void set_utilized() { + void setUtilized() { utilized_ = true; } - std::shared_ptr DeepClone() { - std::string key_copy = key_; + std::shared_ptr deepClone() { + std::string keyCopy = key_; return std::shared_ptr( new ColumnEncryptionProperties( - encrypted_, column_path_, key_copy, key_metadata_)); + encrypted_, columnPath_, keyCopy, keyMetadata_)); } ColumnEncryptionProperties() = default; @@ -185,38 +185,38 @@ class PARQUET_EXPORT ColumnEncryptionProperties { ColumnEncryptionProperties(ColumnEncryptionProperties&& other) = default; private: - const std::string column_path_; + const std::string columnPath_; bool encrypted_; - bool encrypted_with_footer_key_; + bool encryptedWithFooterKey_; std::string key_; - std::string key_metadata_; + std::string keyMetadata_; bool utilized_; explicit ColumnEncryptionProperties( bool encrypted, - const std::string& column_path, + const std::string& columnPath, const std::string& key, - const std::string& key_metadata); + const std::string& keyMetadata); }; class PARQUET_EXPORT ColumnDecryptionProperties { public: class PARQUET_EXPORT Builder { public: - explicit Builder(const std::string& name) : column_path_(name) {} + explicit Builder(const std::string& name) : columnPath_(name) {} explicit Builder(const std::shared_ptr& path) - : Builder(path->ToDotString()) {} + : Builder(path->toDotString()) {} /// Set an explicit column key. If applied on a file that contains /// key metadata for this column the metadata will be ignored, - /// the column will be decrypted with this key. - /// key length must be either 16, 24 or 32 bytes. + /// and the column will be decrypted with this key. + /// Key length must be either 16, 24 or 32 bytes. Builder* key(const std::string& key); std::shared_ptr build(); private: - const std::string column_path_; + const std::string columnPath_; std::string key_; }; @@ -224,13 +224,13 @@ class PARQUET_EXPORT ColumnDecryptionProperties { ColumnDecryptionProperties(const ColumnDecryptionProperties& other) = default; ColumnDecryptionProperties(ColumnDecryptionProperties&& other) = default; - std::string column_path() const { - return column_path_; + std::string columnPath() const { + return columnPath_; } std::string key() const { return key_; } - bool is_utilized() { + bool isUtilized() { return utilized_; } @@ -238,26 +238,26 @@ class PARQUET_EXPORT ColumnDecryptionProperties { /// Mark ColumnDecryptionProperties as utilized once it is used in /// FileDecryptionProperties as the encryption key will be wiped out upon /// completion of file reading. - void set_utilized() { + void setUtilized() { utilized_ = true; } /// Upon completion of file reading, the encryption key /// will be wiped out. - void WipeOutDecryptionKey(); + void wipeOutDecryptionKey(); - std::shared_ptr DeepClone(); + std::shared_ptr deepClone(); private: - const std::string column_path_; + const std::string columnPath_; std::string key_; bool utilized_; - /// This class is only required for setting explicit column decryption keys - - /// to override key retriever (or to provide keys when key metadata and/or + /// This class is only required for setting explicit column decryption keys -. + /// To override key retriever (or to provide keys when key metadata and/or. /// key retriever are not available) explicit ColumnDecryptionProperties( - const std::string& column_path, + const std::string& columnPath, const std::string& key); }; @@ -268,7 +268,7 @@ class PARQUET_EXPORT AADPrefixVerifier { /// Throws exception if an AAD prefix is wrong. /// In a data set, AAD Prefixes should be collected, /// and then checked for missing files. - virtual void Verify(const std::string& aad_prefix) = 0; + virtual void verify(const std::string& aadPrefix) = 0; virtual ~AADPrefixVerifier() {} }; @@ -277,8 +277,8 @@ class PARQUET_EXPORT FileDecryptionProperties { class PARQUET_EXPORT Builder { public: Builder() { - check_plaintext_footer_integrity_ = kDefaultCheckSignature; - plaintext_files_allowed_ = kDefaultAllowPlaintextFiles; + checkPlaintextFooterIntegrity_ = kDefaultCheckSignature; + plaintextFilesAllowed_ = kDefaultAllowPlaintextFiles; } /// Set an explicit footer key. If applied on a file that contains @@ -287,297 +287,297 @@ class PARQUET_EXPORT FileDecryptionProperties { /// If explicit key is not set, footer key will be fetched from /// key retriever. /// With explicit keys or AAD prefix, new encryption properties object must - /// be created for each encrypted file. Explicit encryption keys (footer and - /// column) are cloned. Upon completion of file reading, the cloned - /// encryption keys in the properties will be wiped out (array values set to - /// 0). Caller is responsible for wiping out the input key array. param + /// be created for each encrypted file. Explicit encryption keys (footer + /// and column) are cloned. Upon completion of file reading, the cloned + /// encryption keys in the properties will be wiped out (array values set + /// to 0). Caller is responsible for wiping out the input key array. param /// footerKey Key length must be either 16, 24 or 32 bytes. - Builder* footer_key(const std::string footer_key); + Builder* footerKey(const std::string footerKey); /// Set explicit column keys (decryption properties). /// Its also possible to set a key retriever on this property object. - /// Upon file decryption, availability of explicit keys is checked before + /// Upon file decryption, availability of explicit keys is checked before. /// invocation of the retriever callback. /// If an explicit key is available for a footer or a column, /// its key metadata will be ignored. - Builder* column_keys(const ColumnPathToDecryptionPropertiesMap& - column_decryption_properties); + Builder* columnKeys( + const ColumnPathToDecryptionPropertiesMap& columnDecryptionProperties); - /// Set a key retriever callback. Its also possible to + /// Set a key retrieval callback. It is also possible to /// set explicit footer or column keys on this file property object. - /// Upon file decryption, availability of explicit keys is checked before + /// Upon file decryption, availability of explicit keys is checked before. /// invocation of the retriever callback. /// If an explicit key is available for a footer or a column, /// its key metadata will be ignored. - Builder* key_retriever( - const std::shared_ptr& key_retriever); + Builder* keyRetriever( + const std::shared_ptr& keyRetriever); /// Skip integrity verification of plaintext footers. - /// If not called, integrity of plaintext footers will be checked in + /// If not called, integrity of plaintext footers will be checked in. /// runtime, and an exception will be thrown in the following situations: - /// - footer signing key is not available - /// (not passed, or not found by key retriever) - /// - footer content and signature don't match - Builder* disable_footer_signature_verification() { - check_plaintext_footer_integrity_ = false; + /// - Footer signing key is not available. + /// (not passed, or not found by key retriever). + /// - Footer content and signature don't match. + Builder* disableFooterSignatureVerification() { + checkPlaintextFooterIntegrity_ = false; return this; } /// Explicitly supply the file AAD prefix. - /// A must when a prefix is used for file encryption, but not stored in - /// file. If AAD prefix is stored in file, it will be compared to the - /// explicitly supplied value and an exception will be thrown if they + /// This is mandatory when a prefix is used for file encryption, but not + /// stored in file. If AAD prefix is stored in file, it will be compared to + /// the explicitly supplied value and an exception will be thrown if they /// differ. - Builder* aad_prefix(const std::string& aad_prefix); + Builder* aadPrefix(const std::string& aadPrefix); /// Set callback for verification of AAD Prefixes stored in file. - Builder* aad_prefix_verifier( - std::shared_ptr aad_prefix_verifier); + Builder* aadPrefixVerifier( + std::shared_ptr aadPrefixVerifier); /// By default, reading plaintext (unencrypted) files is not - /// allowed when using a decryptor + /// allowed when using a decryptor. /// - in order to detect files that were not encrypted by mistake. /// However, the default behavior can be overridden by calling this method. - /// The caller should use then a different method to ensure encryption + /// The caller should then use a different method to ensure encryption /// of files with sensitive data. - Builder* plaintext_files_allowed() { - plaintext_files_allowed_ = true; + Builder* plaintextFilesAllowed() { + plaintextFilesAllowed_ = true; return this; } std::shared_ptr build() { return std::shared_ptr( new FileDecryptionProperties( - footer_key_, - key_retriever_, - check_plaintext_footer_integrity_, - aad_prefix_, - aad_prefix_verifier_, - column_decryption_properties_, - plaintext_files_allowed_)); + footerKey_, + keyRetriever_, + checkPlaintextFooterIntegrity_, + aadPrefix_, + aadPrefixVerifier_, + columnDecryptionProperties_, + plaintextFilesAllowed_)); } private: - std::string footer_key_; - std::string aad_prefix_; - std::shared_ptr aad_prefix_verifier_; - ColumnPathToDecryptionPropertiesMap column_decryption_properties_; - - std::shared_ptr key_retriever_; - bool check_plaintext_footer_integrity_; - bool plaintext_files_allowed_; + std::string footerKey_; + std::string aadPrefix_; + std::shared_ptr aadPrefixVerifier_; + ColumnPathToDecryptionPropertiesMap columnDecryptionProperties_; + + std::shared_ptr keyRetriever_; + bool checkPlaintextFooterIntegrity_; + bool plaintextFilesAllowed_; }; - std::string column_key(const std::string& column_path) const; + std::string columnKey(const std::string& columnPath) const; - std::string footer_key() const { - return footer_key_; + std::string footerKey() const { + return footerKey_; } - std::string aad_prefix() const { - return aad_prefix_; + std::string aadPrefix() const { + return aadPrefix_; } - const std::shared_ptr& key_retriever() const { - return key_retriever_; + const std::shared_ptr& keyRetriever() const { + return keyRetriever_; } - bool check_plaintext_footer_integrity() const { - return check_plaintext_footer_integrity_; + bool checkPlaintextFooterIntegrity() const { + return checkPlaintextFooterIntegrity_; } - bool plaintext_files_allowed() const { - return plaintext_files_allowed_; + bool plaintextFilesAllowed() const { + return plaintextFilesAllowed_; } - const std::shared_ptr& aad_prefix_verifier() const { - return aad_prefix_verifier_; + const std::shared_ptr& aadPrefixVerifier() const { + return aadPrefixVerifier_; } /// Upon completion of file reading, the encryption keys in the properties /// will be wiped out (array values set to 0). - void WipeOutDecryptionKeys(); + void wipeOutDecryptionKeys(); - bool is_utilized(); + bool isUtilized(); /// FileDecryptionProperties object can be used for reading one file only. /// Mark FileDecryptionProperties as utilized once it is used to read a file /// as the encryption keys will be wiped out upon completion of file reading. - void set_utilized() { + void setUtilized() { utilized_ = true; } - /// FileDecryptionProperties object can be used for reading one file only. + /// FileDecryptionProperties object can be used for reading one file only /// (unless this object keeps the keyRetrieval callback only, and no explicit /// keys or aadPrefix). /// At the end, keys are wiped out in the memory. - /// This method allows to clone identical properties for another file, + /// This method allows cloning identical properties for another file, /// with an option to update the aadPrefix (if newAadPrefix is null, - /// aadPrefix will be cloned too) - std::shared_ptr DeepClone( - std::string new_aad_prefix = ""); + /// aadPrefix will be cloned too). + std::shared_ptr deepClone( + std::string newAadPrefix = ""); private: - std::string footer_key_; - std::string aad_prefix_; - std::shared_ptr aad_prefix_verifier_; + std::string footerKey_; + std::string aadPrefix_; + std::shared_ptr aadPrefixVerifier_; - const std::string empty_string_ = ""; - ColumnPathToDecryptionPropertiesMap column_decryption_properties_; + const std::string emptyString_ = ""; + ColumnPathToDecryptionPropertiesMap columnDecryptionProperties_; - std::shared_ptr key_retriever_; - bool check_plaintext_footer_integrity_; - bool plaintext_files_allowed_; + std::shared_ptr keyRetriever_; + bool checkPlaintextFooterIntegrity_; + bool plaintextFilesAllowed_; bool utilized_; FileDecryptionProperties( - const std::string& footer_key, - std::shared_ptr key_retriever, - bool check_plaintext_footer_integrity, - const std::string& aad_prefix, - std::shared_ptr aad_prefix_verifier, - const ColumnPathToDecryptionPropertiesMap& column_decryption_properties, - bool plaintext_files_allowed); + const std::string& footerKey, + std::shared_ptr keyRetriever, + bool checkPlaintextFooterIntegrity, + const std::string& aadPrefix, + std::shared_ptr aadPrefixVerifier, + const ColumnPathToDecryptionPropertiesMap& columnDecryptionProperties, + bool plaintextFilesAllowed); }; class PARQUET_EXPORT FileEncryptionProperties { public: class PARQUET_EXPORT Builder { public: - explicit Builder(const std::string& footer_key) - : parquet_cipher_(kDefaultEncryptionAlgorithm), - encrypted_footer_(kDefaultEncryptedFooter) { - footer_key_ = footer_key; - store_aad_prefix_in_file_ = false; + explicit Builder(const std::string& footerKey) + : parquetCipher_(kDefaultEncryptionAlgorithm), + encryptedFooter_(kDefaultEncryptedFooter) { + footerKey_ = footerKey; + storeAadPrefixInFile_ = false; } /// Create files with plaintext footer. - /// If not called, the files will be created with encrypted footer + /// If not called, the files will be created with encrypted footer. /// (default). - Builder* set_plaintext_footer() { - encrypted_footer_ = false; + Builder* setPlaintextFooter() { + encryptedFooter_ = false; return this; } /// Set encryption algorithm. /// If not called, files will be encrypted with AES_GCM_V1 (default). - Builder* algorithm(ParquetCipher::type parquet_cipher) { - parquet_cipher_ = parquet_cipher; + Builder* algorithm(ParquetCipher::type parquetCipher) { + parquetCipher_ = parquetCipher; return this; } /// Set a key retrieval metadata (converted from String). - /// use either footer_key_metadata or footer_key_id, not both. - Builder* footer_key_id(const std::string& key_id); + /// Use either footer_key_metadata or footer_key_id, not both. + Builder* footerKeyId(const std::string& keyId); /// Set a key retrieval metadata. - /// use either footer_key_metadata or footer_key_id, not both. - Builder* footer_key_metadata(const std::string& footer_key_metadata); + /// Use either footer_key_metadata or footer_key_id, not both. + Builder* footerKeyMetadata(const std::string& footerKeyMetadata); /// Set the file AAD Prefix. - Builder* aad_prefix(const std::string& aad_prefix); + Builder* aadPrefix(const std::string& aadPrefix); /// Skip storing AAD Prefix in file. /// If not called, and if AAD Prefix is set, it will be stored. - Builder* disable_aad_prefix_storage(); + Builder* disableAadPrefixStorage(); - /// Set the list of encrypted columns and their properties (keys etc). + /// Set the list of encrypted columns and their properties (keys, etc.). /// If not called, all columns will be encrypted with the footer key. /// If called, the file columns not in the list will be left unencrypted. - Builder* encrypted_columns( - const ColumnPathToEncryptionPropertiesMap& encrypted_columns); + Builder* encryptedColumns( + const ColumnPathToEncryptionPropertiesMap& encryptedColumns); std::shared_ptr build() { return std::shared_ptr( new FileEncryptionProperties( - parquet_cipher_, - footer_key_, - footer_key_metadata_, - encrypted_footer_, - aad_prefix_, - store_aad_prefix_in_file_, - encrypted_columns_)); + parquetCipher_, + footerKey_, + footerKeyMetadata_, + encryptedFooter_, + aadPrefix_, + storeAadPrefixInFile_, + encryptedColumns_)); } private: - ParquetCipher::type parquet_cipher_; - bool encrypted_footer_; - std::string footer_key_; - std::string footer_key_metadata_; - - std::string aad_prefix_; - bool store_aad_prefix_in_file_; - ColumnPathToEncryptionPropertiesMap encrypted_columns_; + ParquetCipher::type parquetCipher_; + bool encryptedFooter_; + std::string footerKey_; + std::string footerKeyMetadata_; + + std::string aadPrefix_; + bool storeAadPrefixInFile_; + ColumnPathToEncryptionPropertiesMap encryptedColumns_; }; - bool encrypted_footer() const { - return encrypted_footer_; + bool encryptedFooter() const { + return encryptedFooter_; } EncryptionAlgorithm algorithm() const { return algorithm_; } - std::string footer_key() const { - return footer_key_; + std::string footerKey() const { + return footerKey_; } - std::string footer_key_metadata() const { - return footer_key_metadata_; + std::string footerKeyMetadata() const { + return footerKeyMetadata_; } - std::string file_aad() const { - return file_aad_; + std::string fileAad() const { + return fileAad_; } - std::shared_ptr column_encryption_properties( - const std::string& column_path); + std::shared_ptr columnEncryptionProperties( + const std::string& columnPath); - bool is_utilized() const { + bool isUtilized() const { return utilized_; } /// FileEncryptionProperties object can be used for writing one file only. /// Mark FileEncryptionProperties as utilized once it is used to write a file /// as the encryption keys will be wiped out upon completion of file writing. - void set_utilized() { + void setUtilized() { utilized_ = true; } /// Upon completion of file writing, the encryption keys /// will be wiped out (array values set to 0). - void WipeOutEncryptionKeys(); + void wipeOutEncryptionKeys(); - /// FileEncryptionProperties object can be used for writing one file only. + /// FileEncryptionProperties object can be used for writing one file only /// (at the end, keys are wiped out in the memory). - /// This method allows to clone identical properties for another file, + /// This method allows cloning identical properties for another file, /// with an option to update the aadPrefix (if newAadPrefix is null, - /// aadPrefix will be cloned too) - std::shared_ptr DeepClone( - std::string new_aad_prefix = ""); + /// aadPrefix will be cloned too). + std::shared_ptr deepClone( + std::string newAadPrefix = ""); - ColumnPathToEncryptionPropertiesMap encrypted_columns() const { - return encrypted_columns_; + ColumnPathToEncryptionPropertiesMap encryptedColumns() const { + return encryptedColumns_; } private: EncryptionAlgorithm algorithm_; - std::string footer_key_; - std::string footer_key_metadata_; - bool encrypted_footer_; - std::string file_aad_; - std::string aad_prefix_; + std::string footerKey_; + std::string footerKeyMetadata_; + bool encryptedFooter_; + std::string fileAad_; + std::string aadPrefix_; bool utilized_; - bool store_aad_prefix_in_file_; - ColumnPathToEncryptionPropertiesMap encrypted_columns_; + bool storeAadPrefixInFile_; + ColumnPathToEncryptionPropertiesMap encryptedColumns_; FileEncryptionProperties( ParquetCipher::type cipher, - const std::string& footer_key, - const std::string& footer_key_metadata, - bool encrypted_footer, - const std::string& aad_prefix, - bool store_aad_prefix_in_file, - const ColumnPathToEncryptionPropertiesMap& encrypted_columns); + const std::string& footerKey, + const std::string& footerKeyMetadata, + bool encryptedFooter, + const std::string& aadPrefix, + bool storeAadPrefixInFile, + const ColumnPathToEncryptionPropertiesMap& encryptedColumns); }; } // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/EncryptionInternal.cpp b/velox/dwio/parquet/writer/arrow/EncryptionInternal.cpp index 7e988988069..81b4ed26e39 100644 --- a/velox/dwio/parquet/writer/arrow/EncryptionInternal.cpp +++ b/velox/dwio/parquet/writer/arrow/EncryptionInternal.cpp @@ -54,10 +54,10 @@ constexpr int kBufferSizeLength = 4; class AesEncryptor::AesEncryptorImpl { public: explicit AesEncryptorImpl( - ParquetCipher::type alg_id, - int key_len, + ParquetCipher::type algId, + int keyLen, bool metadata, - bool write_length); + bool writeLength); ~AesEncryptorImpl() { if (nullptr != ctx_) { @@ -66,179 +66,172 @@ class AesEncryptor::AesEncryptorImpl { } } - int Encrypt( + int encrypt( const uint8_t* plaintext, - int plaintext_len, + int plaintextLen, const uint8_t* key, - int key_len, + int keyLen, const uint8_t* aad, - int aad_len, + int aadLen, uint8_t* ciphertext); - int SignedFooterEncrypt( + int signedFooterEncrypt( const uint8_t* footer, - int footer_len, + int footerLen, const uint8_t* key, - int key_len, + int keyLen, const uint8_t* aad, - int aad_len, + int aadLen, const uint8_t* nonce, - uint8_t* encrypted_footer); - void WipeOut() { + uint8_t* encryptedFooter); + void wipeOut() { if (nullptr != ctx_) { EVP_CIPHER_CTX_free(ctx_); ctx_ = nullptr; } } - int ciphertext_size_delta() { - return ciphertext_size_delta_; + int ciphertextSizeDelta() { + return ciphertextSizeDelta_; } private: EVP_CIPHER_CTX* ctx_; - int aes_mode_; - int key_length_; - int ciphertext_size_delta_; - int length_buffer_length_; + int aesMode_; + int keyLength_; + int ciphertextSizeDelta_; + int lengthBufferLength_; - int GcmEncrypt( + int gcmEncrypt( const uint8_t* plaintext, - int plaintext_len, + int plaintextLen, const uint8_t* key, - int key_len, + int keyLen, const uint8_t* nonce, const uint8_t* aad, - int aad_len, + int aadLen, uint8_t* ciphertext); - int CtrEncrypt( + int ctrEncrypt( const uint8_t* plaintext, - int plaintext_len, + int plaintextLen, const uint8_t* key, - int key_len, + int keyLen, const uint8_t* nonce, uint8_t* ciphertext); }; AesEncryptor::AesEncryptorImpl::AesEncryptorImpl( - ParquetCipher::type alg_id, - int key_len, + ParquetCipher::type algId, + int keyLen, bool metadata, - bool write_length) { + bool writeLength) { ctx_ = nullptr; - length_buffer_length_ = write_length ? kBufferSizeLength : 0; - ciphertext_size_delta_ = length_buffer_length_ + kNonceLength; - if (metadata || (ParquetCipher::AES_GCM_V1 == alg_id)) { - aes_mode_ = kGcmMode; - ciphertext_size_delta_ += kGcmTagLength; + lengthBufferLength_ = writeLength ? kBufferSizeLength : 0; + ciphertextSizeDelta_ = lengthBufferLength_ + kNonceLength; + if (metadata || (ParquetCipher::kAesGcmV1 == algId)) { + aesMode_ = kGcmMode; + ciphertextSizeDelta_ += kGcmTagLength; } else { - aes_mode_ = kCtrMode; + aesMode_ = kCtrMode; } - if (16 != key_len && 24 != key_len && 32 != key_len) { + if (16 != keyLen && 24 != keyLen && 32 != keyLen) { std::stringstream ss; - ss << "Wrong key length: " << key_len; + ss << "Wrong key length: " << keyLen; throw ParquetException(ss.str()); } - key_length_ = key_len; + keyLength_ = keyLen; ctx_ = EVP_CIPHER_CTX_new(); if (nullptr == ctx_) { throw ParquetException("Couldn't init cipher context"); } - if (kGcmMode == aes_mode_) { - // Init AES-GCM with specified key length - if (16 == key_len) { + if (kGcmMode == aesMode_) { + // Init AES-GCM with specified key length. + if (16 == keyLen) { ENCRYPT_INIT(ctx_, EVP_aes_128_gcm()); - } else if (24 == key_len) { + } else if (24 == keyLen) { ENCRYPT_INIT(ctx_, EVP_aes_192_gcm()); - } else if (32 == key_len) { + } else if (32 == keyLen) { ENCRYPT_INIT(ctx_, EVP_aes_256_gcm()); } } else { - // Init AES-CTR with specified key length - if (16 == key_len) { + // Init AES-CTR with specified key length. + if (16 == keyLen) { ENCRYPT_INIT(ctx_, EVP_aes_128_ctr()); - } else if (24 == key_len) { + } else if (24 == keyLen) { ENCRYPT_INIT(ctx_, EVP_aes_192_ctr()); - } else if (32 == key_len) { + } else if (32 == keyLen) { ENCRYPT_INIT(ctx_, EVP_aes_256_ctr()); } } } -int AesEncryptor::AesEncryptorImpl::SignedFooterEncrypt( +int AesEncryptor::AesEncryptorImpl::signedFooterEncrypt( const uint8_t* footer, - int footer_len, + int footerLen, const uint8_t* key, - int key_len, + int keyLen, const uint8_t* aad, - int aad_len, + int aadLen, const uint8_t* nonce, - uint8_t* encrypted_footer) { - if (key_length_ != key_len) { + uint8_t* encryptedFooter) { + if (keyLength_ != keyLen) { std::stringstream ss; - ss << "Wrong key length " << key_len << ". Should be " << key_length_; + ss << "Wrong key length " << keyLen << ". Should be " << keyLength_; throw ParquetException(ss.str()); } - if (kGcmMode != aes_mode_) { + if (kGcmMode != aesMode_) { throw ParquetException("Must use AES GCM (metadata) encryptor"); } - return GcmEncrypt( - footer, footer_len, key, key_len, nonce, aad, aad_len, encrypted_footer); + return gcmEncrypt( + footer, footerLen, key, keyLen, nonce, aad, aadLen, encryptedFooter); } -int AesEncryptor::AesEncryptorImpl::Encrypt( +int AesEncryptor::AesEncryptorImpl::encrypt( const uint8_t* plaintext, - int plaintext_len, + int plaintextLen, const uint8_t* key, - int key_len, + int keyLen, const uint8_t* aad, - int aad_len, + int aadLen, uint8_t* ciphertext) { - if (key_length_ != key_len) { + if (keyLength_ != keyLen) { std::stringstream ss; - ss << "Wrong key length " << key_len << ". Should be " << key_length_; + ss << "Wrong key length " << keyLen << ". Should be " << keyLength_; throw ParquetException(ss.str()); } uint8_t nonce[kNonceLength]; memset(nonce, 0, kNonceLength); - // Random nonce - RAND_bytes(nonce, sizeof(nonce)); - - if (kGcmMode == aes_mode_) { - return GcmEncrypt( - plaintext, - plaintext_len, - key, - key_len, - nonce, - aad, - aad_len, - ciphertext); - } - - return CtrEncrypt(plaintext, plaintext_len, key, key_len, nonce, ciphertext); + // Random nonce. + randBytes(nonce, sizeof(nonce)); + + if (kGcmMode == aesMode_) { + return gcmEncrypt( + plaintext, plaintextLen, key, keyLen, nonce, aad, aadLen, ciphertext); + } + + return ctrEncrypt(plaintext, plaintextLen, key, keyLen, nonce, ciphertext); } -int AesEncryptor::AesEncryptorImpl::GcmEncrypt( +int AesEncryptor::AesEncryptorImpl::gcmEncrypt( const uint8_t* plaintext, - int plaintext_len, + int plaintextLen, const uint8_t* key, - int key_len, + int keyLen, const uint8_t* nonce, const uint8_t* aad, - int aad_len, + int aadLen, uint8_t* ciphertext) { int len = 0; - int ciphertext_len; + int ciphertextLen; uint8_t tag[kGcmTagLength]; memset(tag, 0, kGcmTagLength); @@ -248,170 +241,166 @@ int AesEncryptor::AesEncryptorImpl::GcmEncrypt( throw ParquetException("Couldn't set key and nonce"); } - // Setting additional authenticated data + // Setting additional authenticated data. if ((nullptr != aad) && - (1 != EVP_EncryptUpdate(ctx_, nullptr, &len, aad, aad_len))) { + (1 != EVP_EncryptUpdate(ctx_, nullptr, &len, aad, aadLen))) { throw ParquetException("Couldn't set AAD"); } - // Encryption + // Encryption. if (1 != EVP_EncryptUpdate( ctx_, - ciphertext + length_buffer_length_ + kNonceLength, + ciphertext + lengthBufferLength_ + kNonceLength, &len, plaintext, - plaintext_len)) { + plaintextLen)) { throw ParquetException("Failed encryption update"); } - ciphertext_len = len; + ciphertextLen = len; - // Finalization + // Finalization. if (1 != EVP_EncryptFinal_ex( - ctx_, - ciphertext + length_buffer_length_ + kNonceLength + len, - &len)) { + ctx_, ciphertext + lengthBufferLength_ + kNonceLength + len, &len)) { throw ParquetException("Failed encryption finalization"); } - ciphertext_len += len; + ciphertextLen += len; - // Getting the tag + // Getting the tag. if (1 != EVP_CIPHER_CTX_ctrl(ctx_, EVP_CTRL_GCM_GET_TAG, kGcmTagLength, tag)) { throw ParquetException("Couldn't get AES-GCM tag"); } - // Copying the buffer size, nonce and tag to ciphertext - uint32_t buffer_size = kNonceLength + ciphertext_len + kGcmTagLength; - if (length_buffer_length_ > 0) { - ciphertext[3] = static_cast(0xff & (buffer_size >> 24)); - ciphertext[2] = static_cast(0xff & (buffer_size >> 16)); - ciphertext[1] = static_cast(0xff & (buffer_size >> 8)); - ciphertext[0] = static_cast(0xff & (buffer_size)); + // Copying the buffer size, nonce and tag to ciphertext. + uint32_t bufferSize = kNonceLength + ciphertextLen + kGcmTagLength; + if (lengthBufferLength_ > 0) { + ciphertext[3] = static_cast(0xff & (bufferSize >> 24)); + ciphertext[2] = static_cast(0xff & (bufferSize >> 16)); + ciphertext[1] = static_cast(0xff & (bufferSize >> 8)); + ciphertext[0] = static_cast(0xff & (bufferSize)); } - std::copy(nonce, nonce + kNonceLength, ciphertext + length_buffer_length_); + std::copy(nonce, nonce + kNonceLength, ciphertext + lengthBufferLength_); std::copy( tag, tag + kGcmTagLength, - ciphertext + length_buffer_length_ + kNonceLength + ciphertext_len); + ciphertext + lengthBufferLength_ + kNonceLength + ciphertextLen); - return length_buffer_length_ + buffer_size; + return lengthBufferLength_ + bufferSize; } -int AesEncryptor::AesEncryptorImpl::CtrEncrypt( +int AesEncryptor::AesEncryptorImpl::ctrEncrypt( const uint8_t* plaintext, - int plaintext_len, + int plaintextLen, const uint8_t* key, - int key_len, + int keyLen, const uint8_t* nonce, uint8_t* ciphertext) { int len = 0; - int ciphertext_len; + int ciphertextLen; - // Parquet CTR IVs are comprised of a 12-byte nonce and a 4-byte initial - // counter field. - // The first 31 bits of the initial counter field are set to 0, the last bit - // is set to 1. + // Parquet CTR IVs are comprised of a 12-byte nonce and a 4-byte initial. + // Counter field. + // The first 31 bits of the initial counter field are set to 0, the last bit. + // Is set to 1. uint8_t iv[kCtrIvLength]; memset(iv, 0, kCtrIvLength); std::copy(nonce, nonce + kNonceLength, iv); iv[kCtrIvLength - 1] = 1; - // Setting key and IV + // Setting key and IV. if (1 != EVP_EncryptInit_ex(ctx_, nullptr, nullptr, key, iv)) { throw ParquetException("Couldn't set key and IV"); } - // Encryption + // Encryption. if (1 != EVP_EncryptUpdate( ctx_, - ciphertext + length_buffer_length_ + kNonceLength, + ciphertext + lengthBufferLength_ + kNonceLength, &len, plaintext, - plaintext_len)) { + plaintextLen)) { throw ParquetException("Failed encryption update"); } - ciphertext_len = len; + ciphertextLen = len; - // Finalization + // Finalization. if (1 != EVP_EncryptFinal_ex( - ctx_, - ciphertext + length_buffer_length_ + kNonceLength + len, - &len)) { + ctx_, ciphertext + lengthBufferLength_ + kNonceLength + len, &len)) { throw ParquetException("Failed encryption finalization"); } - ciphertext_len += len; + ciphertextLen += len; - // Copying the buffer size and nonce to ciphertext - uint32_t buffer_size = kNonceLength + ciphertext_len; - if (length_buffer_length_ > 0) { - ciphertext[3] = static_cast(0xff & (buffer_size >> 24)); - ciphertext[2] = static_cast(0xff & (buffer_size >> 16)); - ciphertext[1] = static_cast(0xff & (buffer_size >> 8)); - ciphertext[0] = static_cast(0xff & (buffer_size)); + // Copying the buffer size and nonce to ciphertext. + uint32_t bufferSize = kNonceLength + ciphertextLen; + if (lengthBufferLength_ > 0) { + ciphertext[3] = static_cast(0xff & (bufferSize >> 24)); + ciphertext[2] = static_cast(0xff & (bufferSize >> 16)); + ciphertext[1] = static_cast(0xff & (bufferSize >> 8)); + ciphertext[0] = static_cast(0xff & (bufferSize)); } - std::copy(nonce, nonce + kNonceLength, ciphertext + length_buffer_length_); + std::copy(nonce, nonce + kNonceLength, ciphertext + lengthBufferLength_); - return length_buffer_length_ + buffer_size; + return lengthBufferLength_ + bufferSize; } AesEncryptor::~AesEncryptor() {} -int AesEncryptor::SignedFooterEncrypt( +int AesEncryptor::signedFooterEncrypt( const uint8_t* footer, - int footer_len, + int footerLen, const uint8_t* key, - int key_len, + int keyLen, const uint8_t* aad, - int aad_len, + int aadLen, const uint8_t* nonce, - uint8_t* encrypted_footer) { - return impl_->SignedFooterEncrypt( - footer, footer_len, key, key_len, aad, aad_len, nonce, encrypted_footer); + uint8_t* encryptedFooter) { + return impl_->signedFooterEncrypt( + footer, footerLen, key, keyLen, aad, aadLen, nonce, encryptedFooter); } -void AesEncryptor::WipeOut() { - impl_->WipeOut(); +void AesEncryptor::wipeOut() { + impl_->wipeOut(); } -int AesEncryptor::CiphertextSizeDelta() { - return impl_->ciphertext_size_delta(); +int AesEncryptor::ciphertextSizeDelta() { + return impl_->ciphertextSizeDelta(); } -int AesEncryptor::Encrypt( +int AesEncryptor::encrypt( const uint8_t* plaintext, - int plaintext_len, + int plaintextLen, const uint8_t* key, - int key_len, + int keyLen, const uint8_t* aad, - int aad_len, + int aadLen, uint8_t* ciphertext) { - return impl_->Encrypt( - plaintext, plaintext_len, key, key_len, aad, aad_len, ciphertext); + return impl_->encrypt( + plaintext, plaintextLen, key, keyLen, aad, aadLen, ciphertext); } AesEncryptor::AesEncryptor( - ParquetCipher::type alg_id, - int key_len, + ParquetCipher::type algId, + int keyLen, bool metadata, - bool write_length) + bool writeLength) : impl_{std::unique_ptr( - new AesEncryptorImpl(alg_id, key_len, metadata, write_length))} {} + new AesEncryptorImpl(algId, keyLen, metadata, writeLength))} {} class AesDecryptor::AesDecryptorImpl { public: explicit AesDecryptorImpl( - ParquetCipher::type alg_id, - int key_len, + ParquetCipher::type algId, + int keyLen, bool metadata, - bool contains_length); + bool containsLength); ~AesDecryptorImpl() { if (nullptr != ctx_) { @@ -420,345 +409,344 @@ class AesDecryptor::AesDecryptorImpl { } } - int Decrypt( + int decrypt( const uint8_t* ciphertext, - int ciphertext_len, + int ciphertextLen, const uint8_t* key, - int key_len, + int keyLen, const uint8_t* aad, - int aad_len, + int aadLen, uint8_t* plaintext); - void WipeOut() { + void wipeOut() { if (nullptr != ctx_) { EVP_CIPHER_CTX_free(ctx_); ctx_ = nullptr; } } - int ciphertext_size_delta() { - return ciphertext_size_delta_; + int ciphertextSizeDelta() { + return ciphertextSizeDelta_; } private: EVP_CIPHER_CTX* ctx_; - int aes_mode_; - int key_length_; - int ciphertext_size_delta_; - int length_buffer_length_; - int GcmDecrypt( + int aesMode_; + int keyLength_; + int ciphertextSizeDelta_; + int lengthBufferLength_; + int gcmDecrypt( const uint8_t* ciphertext, - int ciphertext_len, + int ciphertextLen, const uint8_t* key, - int key_len, + int keyLen, const uint8_t* aad, - int aad_len, + int aadLen, uint8_t* plaintext); - int CtrDecrypt( + int ctrDecrypt( const uint8_t* ciphertext, - int ciphertext_len, + int ciphertextLen, const uint8_t* key, - int key_len, + int keyLen, uint8_t* plaintext); }; -int AesDecryptor::Decrypt( +int AesDecryptor::decrypt( const uint8_t* plaintext, - int plaintext_len, + int plaintextLen, const uint8_t* key, - int key_len, + int keyLen, const uint8_t* aad, - int aad_len, + int aadLen, uint8_t* ciphertext) { - return impl_->Decrypt( - plaintext, plaintext_len, key, key_len, aad, aad_len, ciphertext); + return impl_->decrypt( + plaintext, plaintextLen, key, keyLen, aad, aadLen, ciphertext); } -void AesDecryptor::WipeOut() { - impl_->WipeOut(); +void AesDecryptor::wipeOut() { + impl_->wipeOut(); } AesDecryptor::~AesDecryptor() {} AesDecryptor::AesDecryptorImpl::AesDecryptorImpl( - ParquetCipher::type alg_id, - int key_len, + ParquetCipher::type algId, + int keyLen, bool metadata, - bool contains_length) { + bool containsLength) { ctx_ = nullptr; - length_buffer_length_ = contains_length ? kBufferSizeLength : 0; - ciphertext_size_delta_ = length_buffer_length_ + kNonceLength; - if (metadata || (ParquetCipher::AES_GCM_V1 == alg_id)) { - aes_mode_ = kGcmMode; - ciphertext_size_delta_ += kGcmTagLength; + lengthBufferLength_ = containsLength ? kBufferSizeLength : 0; + ciphertextSizeDelta_ = lengthBufferLength_ + kNonceLength; + if (metadata || (ParquetCipher::kAesGcmV1 == algId)) { + aesMode_ = kGcmMode; + ciphertextSizeDelta_ += kGcmTagLength; } else { - aes_mode_ = kCtrMode; + aesMode_ = kCtrMode; } - if (16 != key_len && 24 != key_len && 32 != key_len) { + if (16 != keyLen && 24 != keyLen && 32 != keyLen) { std::stringstream ss; - ss << "Wrong key length: " << key_len; + ss << "Wrong key length: " << keyLen; throw ParquetException(ss.str()); } - key_length_ = key_len; + keyLength_ = keyLen; ctx_ = EVP_CIPHER_CTX_new(); if (nullptr == ctx_) { throw ParquetException("Couldn't init cipher context"); } - if (kGcmMode == aes_mode_) { - // Init AES-GCM with specified key length - if (16 == key_len) { + if (kGcmMode == aesMode_) { + // Init AES-GCM with specified key length. + if (16 == keyLen) { DECRYPT_INIT(ctx_, EVP_aes_128_gcm()); - } else if (24 == key_len) { + } else if (24 == keyLen) { DECRYPT_INIT(ctx_, EVP_aes_192_gcm()); - } else if (32 == key_len) { + } else if (32 == keyLen) { DECRYPT_INIT(ctx_, EVP_aes_256_gcm()); } } else { - // Init AES-CTR with specified key length - if (16 == key_len) { + // Init AES-CTR with specified key length. + if (16 == keyLen) { DECRYPT_INIT(ctx_, EVP_aes_128_ctr()); - } else if (24 == key_len) { + } else if (24 == keyLen) { DECRYPT_INIT(ctx_, EVP_aes_192_ctr()); - } else if (32 == key_len) { + } else if (32 == keyLen) { DECRYPT_INIT(ctx_, EVP_aes_256_ctr()); } } } -AesEncryptor* AesEncryptor::Make( - ParquetCipher::type alg_id, - int key_len, +AesEncryptor* AesEncryptor::make( + ParquetCipher::type algId, + int keyLen, bool metadata, - std::vector* all_encryptors) { - return Make(alg_id, key_len, metadata, true /*write_length*/, all_encryptors); + std::vector* allEncryptors) { + return make(algId, keyLen, metadata, true /*write_length*/, allEncryptors); } -AesEncryptor* AesEncryptor::Make( - ParquetCipher::type alg_id, - int key_len, +AesEncryptor* AesEncryptor::make( + ParquetCipher::type algId, + int keyLen, bool metadata, - bool write_length, - std::vector* all_encryptors) { - if (ParquetCipher::AES_GCM_V1 != alg_id && - ParquetCipher::AES_GCM_CTR_V1 != alg_id) { + bool writeLength, + std::vector* allEncryptors) { + if (ParquetCipher::kAesGcmV1 != algId && + ParquetCipher::kAesGcmCtrV1 != algId) { std::stringstream ss; - ss << "Crypto algorithm " << alg_id << " is not supported"; + ss << "Crypto algorithm " << algId << " is not supported"; throw ParquetException(ss.str()); } - AesEncryptor* encryptor = - new AesEncryptor(alg_id, key_len, metadata, write_length); - if (all_encryptors != nullptr) - all_encryptors->push_back(encryptor); - return encryptor; + AesEncryptor* Encryptor = + new AesEncryptor(algId, keyLen, metadata, writeLength); + if (allEncryptors != nullptr) + allEncryptors->push_back(Encryptor); + return Encryptor; } AesDecryptor::AesDecryptor( - ParquetCipher::type alg_id, - int key_len, + ParquetCipher::type algId, + int keyLen, bool metadata, - bool contains_length) + bool containsLength) : impl_{std::unique_ptr( - new AesDecryptorImpl(alg_id, key_len, metadata, contains_length))} {} + new AesDecryptorImpl(algId, keyLen, metadata, containsLength))} {} -std::shared_ptr AesDecryptor::Make( - ParquetCipher::type alg_id, - int key_len, +std::shared_ptr AesDecryptor::make( + ParquetCipher::type algId, + int keyLen, bool metadata, - std::vector>* all_decryptors) { - if (ParquetCipher::AES_GCM_V1 != alg_id && - ParquetCipher::AES_GCM_CTR_V1 != alg_id) { + std::vector>* allDecryptors) { + if (ParquetCipher::kAesGcmV1 != algId && + ParquetCipher::kAesGcmCtrV1 != algId) { std::stringstream ss; - ss << "Crypto algorithm " << alg_id << " is not supported"; + ss << "Crypto algorithm " << algId << " is not supported"; throw ParquetException(ss.str()); } - auto decryptor = std::make_shared(alg_id, key_len, metadata); - if (all_decryptors != nullptr) { - all_decryptors->push_back(decryptor); + auto Decryptor = std::make_shared(algId, keyLen, metadata); + if (allDecryptors != nullptr) { + allDecryptors->push_back(Decryptor); } - return decryptor; + return Decryptor; } -int AesDecryptor::CiphertextSizeDelta() { - return impl_->ciphertext_size_delta(); +int AesDecryptor::ciphertextSizeDelta() { + return impl_->ciphertextSizeDelta(); } -int AesDecryptor::AesDecryptorImpl::GcmDecrypt( +int AesDecryptor::AesDecryptorImpl::gcmDecrypt( const uint8_t* ciphertext, - int ciphertext_len, + int ciphertextLen, const uint8_t* key, - int key_len, + int keyLen, const uint8_t* aad, - int aad_len, + int aadLen, uint8_t* plaintext) { int len = 0; - int plaintext_len; + int plaintextLen; uint8_t tag[kGcmTagLength]; memset(tag, 0, kGcmTagLength); uint8_t nonce[kNonceLength]; memset(nonce, 0, kNonceLength); - if (length_buffer_length_ > 0) { - // Extract ciphertext length - uint32_t written_ciphertext_len = ((ciphertext[3] & 0xff) << 24) | + if (lengthBufferLength_ > 0) { + // Extract ciphertext length. + uint32_t writtenCiphertextLen = ((ciphertext[3] & 0xff) << 24) | ((ciphertext[2] & 0xff) << 16) | ((ciphertext[1] & 0xff) << 8) | ((ciphertext[0] & 0xff)); - if (ciphertext_len > 0 && - ciphertext_len != (written_ciphertext_len + length_buffer_length_)) { + if (ciphertextLen > 0 && + ciphertextLen != (writtenCiphertextLen + lengthBufferLength_)) { throw ParquetException("Wrong ciphertext length"); } - ciphertext_len = written_ciphertext_len + length_buffer_length_; + ciphertextLen = writtenCiphertextLen + lengthBufferLength_; } else { - if (ciphertext_len == 0) { + if (ciphertextLen == 0) { throw ParquetException("Zero ciphertext length"); } } - // Extracting IV and tag + // Extracting IV and tag. std::copy( - ciphertext + length_buffer_length_, - ciphertext + length_buffer_length_ + kNonceLength, + ciphertext + lengthBufferLength_, + ciphertext + lengthBufferLength_ + kNonceLength, nonce); std::copy( - ciphertext + ciphertext_len - kGcmTagLength, - ciphertext + ciphertext_len, + ciphertext + ciphertextLen - kGcmTagLength, + ciphertext + ciphertextLen, tag); - // Setting key and IV + // Setting key and IV. if (1 != EVP_DecryptInit_ex(ctx_, nullptr, nullptr, key, nonce)) { throw ParquetException("Couldn't set key and IV"); } - // Setting additional authenticated data + // Setting additional authenticated data. if ((nullptr != aad) && - (1 != EVP_DecryptUpdate(ctx_, nullptr, &len, aad, aad_len))) { + (1 != EVP_DecryptUpdate(ctx_, nullptr, &len, aad, aadLen))) { throw ParquetException("Couldn't set AAD"); } - // Decryption + // Decryption. if (!EVP_DecryptUpdate( ctx_, plaintext, &len, - ciphertext + length_buffer_length_ + kNonceLength, - ciphertext_len - length_buffer_length_ - kNonceLength - - kGcmTagLength)) { + ciphertext + lengthBufferLength_ + kNonceLength, + ciphertextLen - lengthBufferLength_ - kNonceLength - kGcmTagLength)) { throw ParquetException("Failed decryption update"); } - plaintext_len = len; + plaintextLen = len; // Checking the tag (authentication) if (!EVP_CIPHER_CTX_ctrl(ctx_, EVP_CTRL_GCM_SET_TAG, kGcmTagLength, tag)) { throw ParquetException("Failed authentication"); } - // Finalization + // Finalization. if (1 != EVP_DecryptFinal_ex(ctx_, plaintext + len, &len)) { throw ParquetException("Failed decryption finalization"); } - plaintext_len += len; - return plaintext_len; + plaintextLen += len; + return plaintextLen; } -int AesDecryptor::AesDecryptorImpl::CtrDecrypt( +int AesDecryptor::AesDecryptorImpl::ctrDecrypt( const uint8_t* ciphertext, - int ciphertext_len, + int ciphertextLen, const uint8_t* key, - int key_len, + int keyLen, uint8_t* plaintext) { int len = 0; - int plaintext_len; + int plaintextLen; uint8_t iv[kCtrIvLength]; memset(iv, 0, kCtrIvLength); - if (length_buffer_length_ > 0) { - // Extract ciphertext length - uint32_t written_ciphertext_len = ((ciphertext[3] & 0xff) << 24) | + if (lengthBufferLength_ > 0) { + // Extract ciphertext length. + uint32_t writtenCiphertextLen = ((ciphertext[3] & 0xff) << 24) | ((ciphertext[2] & 0xff) << 16) | ((ciphertext[1] & 0xff) << 8) | ((ciphertext[0] & 0xff)); - if (ciphertext_len > 0 && - ciphertext_len != (written_ciphertext_len + length_buffer_length_)) { + if (ciphertextLen > 0 && + ciphertextLen != (writtenCiphertextLen + lengthBufferLength_)) { throw ParquetException("Wrong ciphertext length"); } - ciphertext_len = written_ciphertext_len; + ciphertextLen = writtenCiphertextLen; } else { - if (ciphertext_len == 0) { + if (ciphertextLen == 0) { throw ParquetException("Zero ciphertext length"); } } - // Extracting nonce + // Extracting nonce. std::copy( - ciphertext + length_buffer_length_, - ciphertext + length_buffer_length_ + kNonceLength, + ciphertext + lengthBufferLength_, + ciphertext + lengthBufferLength_ + kNonceLength, iv); - // Parquet CTR IVs are comprised of a 12-byte nonce and a 4-byte initial - // counter field. - // The first 31 bits of the initial counter field are set to 0, the last bit - // is set to 1. + // Parquet CTR IVs are comprised of a 12-byte nonce and a 4-byte initial. + // Counter field. + // The first 31 bits of the initial counter field are set to 0, the last bit. + // Is set to 1. iv[kCtrIvLength - 1] = 1; - // Setting key and IV + // Setting key and IV. if (1 != EVP_DecryptInit_ex(ctx_, nullptr, nullptr, key, iv)) { throw ParquetException("Couldn't set key and IV"); } - // Decryption + // Decryption. if (!EVP_DecryptUpdate( ctx_, plaintext, &len, - ciphertext + length_buffer_length_ + kNonceLength, - ciphertext_len - kNonceLength)) { + ciphertext + lengthBufferLength_ + kNonceLength, + ciphertextLen - kNonceLength)) { throw ParquetException("Failed decryption update"); } - plaintext_len = len; + plaintextLen = len; - // Finalization + // Finalization. if (1 != EVP_DecryptFinal_ex(ctx_, plaintext + len, &len)) { throw ParquetException("Failed decryption finalization"); } - plaintext_len += len; - return plaintext_len; + plaintextLen += len; + return plaintextLen; } -int AesDecryptor::AesDecryptorImpl::Decrypt( +int AesDecryptor::AesDecryptorImpl::decrypt( const uint8_t* ciphertext, - int ciphertext_len, + int ciphertextLen, const uint8_t* key, - int key_len, + int keyLen, const uint8_t* aad, - int aad_len, + int aadLen, uint8_t* plaintext) { - if (key_length_ != key_len) { + if (keyLength_ != keyLen) { std::stringstream ss; - ss << "Wrong key length " << key_len << ". Should be " << key_length_; + ss << "Wrong key length " << keyLen << ". Should be " << keyLength_; throw ParquetException(ss.str()); } - if (kGcmMode == aes_mode_) { - return GcmDecrypt( - ciphertext, ciphertext_len, key, key_len, aad, aad_len, plaintext); + if (kGcmMode == aesMode_) { + return gcmDecrypt( + ciphertext, ciphertextLen, key, keyLen, aad, aadLen, plaintext); } - return CtrDecrypt(ciphertext, ciphertext_len, key, key_len, plaintext); + return ctrDecrypt(ciphertext, ciphertextLen, key, keyLen, plaintext); } -static std::string ShortToBytesLe(int16_t input) { +static std::string shortToBytesLe(int16_t input) { int8_t output[2]; memset(output, 0, 2); uint16_t in = static_cast(input); @@ -768,66 +756,68 @@ static std::string ShortToBytesLe(int16_t input) { return std::string(reinterpret_cast(output), 2); } -static void CheckPageOrdinal(int32_t page_ordinal) { - if (ARROW_PREDICT_FALSE(page_ordinal > std::numeric_limits::max())) { +static void checkPageOrdinal(int32_t pageOrdinal) { + if (ARROW_PREDICT_FALSE(pageOrdinal > std::numeric_limits::max())) { throw ParquetException( "Encrypted Parquet files can't have more than " + std::to_string(std::numeric_limits::max()) + - " pages per chunk: got " + std::to_string(page_ordinal)); + " pages per chunk: got " + std::to_string(pageOrdinal)); } } -std::string CreateModuleAad( - const std::string& file_aad, - int8_t module_type, - int16_t row_group_ordinal, - int16_t column_ordinal, - int32_t page_ordinal) { - CheckPageOrdinal(page_ordinal); - const int16_t page_ordinal_short = static_cast(page_ordinal); - int8_t type_ordinal_bytes[1]; - type_ordinal_bytes[0] = module_type; - std::string type_ordinal_bytes_str( - reinterpret_cast(type_ordinal_bytes), 1); - if (kFooter == module_type) { - std::string result = file_aad + type_ordinal_bytes_str; +std::string createModuleAad( + const std::string& fileAad, + int8_t moduleType, + int16_t rowGroupOrdinal, + int16_t columnOrdinal, + int32_t pageOrdinal) { + checkPageOrdinal(pageOrdinal); + const int16_t pageOrdinalShort = static_cast(pageOrdinal); + int8_t typeOrdinalBytes[1]; + typeOrdinalBytes[0] = moduleType; + std::string typeOrdinalBytesStr( + reinterpret_cast(typeOrdinalBytes), 1); + if (kFooter == moduleType) { + std::string result = fileAad + typeOrdinalBytesStr; return result; } - std::string row_group_ordinal_bytes = ShortToBytesLe(row_group_ordinal); - std::string column_ordinal_bytes = ShortToBytesLe(column_ordinal); - if (kDataPage != module_type && kDataPageHeader != module_type) { + std::string rowGroupOrdinalBytes = shortToBytesLe(rowGroupOrdinal); + std::string columnOrdinalBytes = shortToBytesLe(columnOrdinal); + if (kDataPage != moduleType && kDataPageHeader != moduleType) { std::ostringstream out; - out << file_aad << type_ordinal_bytes_str << row_group_ordinal_bytes - << column_ordinal_bytes; + out << fileAad << typeOrdinalBytesStr << rowGroupOrdinalBytes + << columnOrdinalBytes; return out.str(); } - std::string page_ordinal_bytes = ShortToBytesLe(page_ordinal_short); + std::string pageOrdinalBytes = shortToBytesLe(pageOrdinalShort); std::ostringstream out; - out << file_aad << type_ordinal_bytes_str << row_group_ordinal_bytes - << column_ordinal_bytes << page_ordinal_bytes; + out << fileAad << typeOrdinalBytesStr << rowGroupOrdinalBytes + << columnOrdinalBytes << pageOrdinalBytes; return out.str(); } -std::string CreateFooterAad(const std::string& aad_prefix_bytes) { - return CreateModuleAad( - aad_prefix_bytes, +std::string createFooterAad(const std::string& aadPrefixBytes) { + return createModuleAad( + aadPrefixBytes, kFooter, static_cast(-1), static_cast(-1), static_cast(-1)); } -// Update last two bytes with new page ordinal (instead of creating new page AAD -// from scratch) -void QuickUpdatePageAad(int32_t new_page_ordinal, std::string* AAD) { - CheckPageOrdinal(new_page_ordinal); - const std::string page_ordinal_bytes = - ShortToBytesLe(static_cast(new_page_ordinal)); - std::memcpy(AAD->data() + AAD->length() - 2, page_ordinal_bytes.data(), 2); +// Update last two bytes with new page ordinal (instead of creating new page +// AAD. from scratch) +void quickUpdatePageAad(int32_t newPageOrdinal, std::string* AAD) { + checkPageOrdinal(newPageOrdinal); + const std::string pageOrdinalBytes = + shortToBytesLe(static_cast(newPageOrdinal)); + std::memcpy(AAD->data() + AAD->length() - 2, pageOrdinalBytes.data(), 2); } -void RandBytes(unsigned char* buf, int num) { - RAND_bytes(buf, num); +void randBytes(unsigned char* buf, int num) { + if (RAND_bytes(buf, num) != 1) { + throw ParquetException("Failed to generate random bytes"); + } } } // namespace facebook::velox::parquet::arrow::encryption diff --git a/velox/dwio/parquet/writer/arrow/EncryptionInternal.h b/velox/dwio/parquet/writer/arrow/EncryptionInternal.h index 1d554e5345c..bcb5068479c 100644 --- a/velox/dwio/parquet/writer/arrow/EncryptionInternal.h +++ b/velox/dwio/parquet/writer/arrow/EncryptionInternal.h @@ -32,7 +32,7 @@ namespace facebook::velox::parquet::arrow::encryption { constexpr int kGcmTagLength = 16; constexpr int kNonceLength = 12; -// Module types +// Module types. constexpr int8_t kFooter = 0; constexpr int8_t kColumnMetaData = 1; constexpr int8_t kDataPage = 2; @@ -46,58 +46,58 @@ constexpr int8_t kOffsetIndex = 7; class AesEncryptor { public: /// Can serve one key length only. Possible values: 16, 24, 32 bytes. - /// If write_length is true, prepend ciphertext length to the ciphertext + /// If write_length is true, prepend ciphertext length to the ciphertext. explicit AesEncryptor( - ParquetCipher::type alg_id, - int key_len, + ParquetCipher::type algId, + int keyLen, bool metadata, - bool write_length = true); + bool writeLength = true); - static AesEncryptor* Make( - ParquetCipher::type alg_id, - int key_len, + static AesEncryptor* make( + ParquetCipher::type algId, + int keyLen, bool metadata, - std::vector* all_encryptors); + std::vector* allEncryptors); - static AesEncryptor* Make( - ParquetCipher::type alg_id, - int key_len, + static AesEncryptor* make( + ParquetCipher::type algId, + int keyLen, bool metadata, - bool write_length, - std::vector* all_encryptors); + bool writeLength, + std::vector* allEncryptors); ~AesEncryptor(); /// Size difference between plaintext and ciphertext, for this cipher. - int CiphertextSizeDelta(); + int ciphertextSizeDelta(); - /// Encrypts plaintext with the key and aad. Key length is passed only for - /// validation. If different from value in constructor, exception will be - /// thrown. - int Encrypt( + /// Encrypts plaintext with the key and aad. Key length is passed only for. + /// Validation. If different from value in constructor, exception will be. + /// Thrown. + int encrypt( const uint8_t* plaintext, - int plaintext_len, + int plaintextLen, const uint8_t* key, - int key_len, + int keyLen, const uint8_t* aad, - int aad_len, + int aadLen, uint8_t* ciphertext); /// Encrypts plaintext footer, in order to compute footer signature (tag). - int SignedFooterEncrypt( + int signedFooterEncrypt( const uint8_t* footer, - int footer_len, + int footerLen, const uint8_t* key, - int key_len, + int keyLen, const uint8_t* aad, - int aad_len, + int aadLen, const uint8_t* nonce, - uint8_t* encrypted_footer); + uint8_t* encryptedFooter); - void WipeOut(); + void wipeOut(); private: - // PIMPL Idiom + // PIMPL Idiom. class AesEncryptorImpl; std::unique_ptr impl_; }; @@ -106,65 +106,65 @@ class AesEncryptor { class AesDecryptor { public: /// Can serve one key length only. Possible values: 16, 24, 32 bytes. - /// If contains_length is true, expect ciphertext length prepended to the - /// ciphertext + /// If contains_length is true, expect ciphertext length prepended to the. + /// Ciphertext. explicit AesDecryptor( - ParquetCipher::type alg_id, - int key_len, + ParquetCipher::type algId, + int keyLen, bool metadata, - bool contains_length = true); + bool containsLength = true); - /// \brief Factory function to create an AesDecryptor + /// \brief Factory function to create an AesDecryptor. /// - /// \param alg_id the encryption algorithm to use + /// \param alg_id the encryption algorithm to use. /// \param key_len key length. Possible values: 16, 24, 32 bytes. - /// \param metadata if true then this is a metadata decryptor - /// \param all_decryptors A weak reference to all decryptors that need to be - /// wiped out when decryption is finished \return shared pointer to a new - /// AesDecryptor - static std::shared_ptr Make( - ParquetCipher::type alg_id, - int key_len, + /// \param metadata if true then this is a metadata decryptor. + /// \param all_decryptors A weak reference to all decryptors that need to be. + /// Wiped out when decryption is finished \return shared pointer to a new. + /// AesDecryptor. + static std::shared_ptr make( + ParquetCipher::type algId, + int keyLen, bool metadata, - std::vector>* all_decryptors); + std::vector>* allDecryptors); ~AesDecryptor(); - void WipeOut(); + void wipeOut(); /// Size difference between plaintext and ciphertext, for this cipher. - int CiphertextSizeDelta(); + int ciphertextSizeDelta(); - /// Decrypts ciphertext with the key and aad. Key length is passed only for - /// validation. If different from value in constructor, exception will be - /// thrown. - int Decrypt( + /// Decrypts ciphertext with the key and aad. Key length is passed only for. + /// Validation. If different from value in constructor, exception will be. + /// Thrown. + int decrypt( const uint8_t* ciphertext, - int ciphertext_len, + int ciphertextLen, const uint8_t* key, - int key_len, + int keyLen, const uint8_t* aad, - int aad_len, + int aadLen, uint8_t* plaintext); private: - // PIMPL Idiom + // PIMPL Idiom. class AesDecryptorImpl; std::unique_ptr impl_; }; -std::string CreateModuleAad( - const std::string& file_aad, - int8_t module_type, - int16_t row_group_ordinal, - int16_t column_ordinal, - int32_t page_ordinal); +std::string createModuleAad( + const std::string& fileAad, + int8_t moduleType, + int16_t rowGroupOrdinal, + int16_t columnOrdinal, + int32_t pageOrdinal); -std::string CreateFooterAad(const std::string& aad_prefix_bytes); +std::string createFooterAad(const std::string& aadPrefixBytes); -// Update last two bytes of page (or page header) module AAD -void QuickUpdatePageAad(int32_t new_page_ordinal, std::string* AAD); +// Update last two bytes of page (or page header) module AAD. +void quickUpdatePageAad(int32_t newPageOrdinal, std::string* AAD); -// Wraps OpenSSL RAND_bytes function -void RandBytes(unsigned char* buf, int num); +// Wraps OpenSSL RAND_bytes function. +void randBytes(unsigned char* buf, int num); } // namespace facebook::velox::parquet::arrow::encryption diff --git a/velox/dwio/parquet/writer/arrow/Exception.h b/velox/dwio/parquet/writer/arrow/Exception.h index 927df340746..caeed90f08b 100644 --- a/velox/dwio/parquet/writer/arrow/Exception.h +++ b/velox/dwio/parquet/writer/arrow/Exception.h @@ -27,12 +27,12 @@ #include "arrow/util/string_builder.h" #include "velox/dwio/parquet/writer/arrow/Platform.h" -// PARQUET-1085 +// PARQUET-1085. #if !defined(ARROW_UNUSED) #define ARROW_UNUSED(x) UNUSED(x) #endif -// Parquet exception to Arrow Status +// Parquet exception to Arrow Status. #define BEGIN_PARQUET_CATCH_EXCEPTIONS try { #define END_PARQUET_CATCH_EXCEPTIONS \ @@ -44,51 +44,50 @@ return ::arrow::Status::IOError(e.what()); \ } -// clang-format off +// clang-format off. -#define PARQUET_CATCH_NOT_OK(s) \ - BEGIN_PARQUET_CATCH_EXCEPTIONS \ - (s); \ +#define PARQUET_CATCH_NOT_OK(s) \ + BEGIN_PARQUET_CATCH_EXCEPTIONS(s); \ END_PARQUET_CATCH_EXCEPTIONS -// clang-format on +// clang-format on. #define PARQUET_CATCH_AND_RETURN(s) \ BEGIN_PARQUET_CATCH_EXCEPTIONS \ return (s); \ END_PARQUET_CATCH_EXCEPTIONS -// Arrow Status to Parquet exception +// Arrow Status to Parquet exception. -#define PARQUET_IGNORE_NOT_OK(s) \ - do { \ - ::arrow::Status _s = ::arrow::internal::GenericToStatus(s); \ - ARROW_UNUSED(_s); \ +#define PARQUET_IGNORE_NOT_OK(s) \ + do { \ + ::arrow::Status S = ::arrow::internal::GenericToStatus(s); \ + ARROW_UNUSED(S); \ } while (0) #define PARQUET_THROW_NOT_OK(s) \ do { \ - ::arrow::Status _s = ::arrow::internal::GenericToStatus(s); \ - if (!_s.ok()) { \ + ::arrow::Status S = ::arrow::internal::GenericToStatus(s); \ + if (!S.ok()) { \ throw ::facebook::velox::parquet::arrow::ParquetStatusException( \ - std::move(_s)); \ + std::move(S)); \ } \ } while (0) -#define PARQUET_ASSIGN_OR_THROW_IMPL(status_name, lhs, rexpr) \ - auto status_name = (rexpr); \ - PARQUET_THROW_NOT_OK(status_name.status()); \ - lhs = std::move(status_name).ValueOrDie(); +#define PARQUET_ASSIGN_OR_THROW_IMPL(statusName, lhs, rexpr) \ + auto statusName = (rexpr); \ + PARQUET_THROW_NOT_OK(statusName.status()); \ + lhs = std::move(statusName).ValueOrDie(); #define PARQUET_ASSIGN_OR_THROW(lhs, rexpr) \ PARQUET_ASSIGN_OR_THROW_IMPL( \ - ARROW_ASSIGN_OR_RAISE_NAME(_error_or_value, __COUNTER__), lhs, rexpr); + ARROW_ASSIGN_OR_RAISE_NAME(ErrorOrValue, __COUNTER__), lhs, rexpr); namespace facebook::velox::parquet::arrow { class ParquetException : public std::exception { public: - PARQUET_NORETURN static void EofException(const std::string& msg = "") { + PARQUET_NORETURN static void eofException(const std::string& msg = "") { static std::string prefix = "Unexpected end of stream"; if (msg.empty()) { throw ParquetException(prefix); @@ -154,13 +153,14 @@ class ParquetInvalidOrCorruptedFileException : public ParquetStatusException { int>::type = 0, typename... Args> explicit ParquetInvalidOrCorruptedFileException(Arg arg, Args&&... args) - : ParquetStatusException(::arrow::Status::Invalid( - std::forward(arg), - std::forward(args)...)) {} + : ParquetStatusException( + ::arrow::Status::Invalid( + std::forward(arg), + std::forward(args)...)) {} }; template -void ThrowNotOk(StatusReturnBlock&& b) { +void throwNotOk(StatusReturnBlock&& b) { PARQUET_THROW_NOT_OK(b()); } diff --git a/velox/dwio/parquet/writer/arrow/FileDecryptorInternal.cpp b/velox/dwio/parquet/writer/arrow/FileDecryptorInternal.cpp index 6161a1e4474..ad578784229 100644 --- a/velox/dwio/parquet/writer/arrow/FileDecryptorInternal.cpp +++ b/velox/dwio/parquet/writer/arrow/FileDecryptorInternal.cpp @@ -22,30 +22,30 @@ namespace facebook::velox::parquet::arrow { -// Decryptor +// Decryptor. Decryptor::Decryptor( - std::shared_ptr aes_decryptor, + std::shared_ptr aesDecryptor, const std::string& key, - const std::string& file_aad, + const std::string& fileAad, const std::string& aad, ::arrow::MemoryPool* pool) - : aes_decryptor_(aes_decryptor), + : aesDecryptor_(aesDecryptor), key_(key), - file_aad_(file_aad), + fileAad_(fileAad), aad_(aad), pool_(pool) {} -int Decryptor::CiphertextSizeDelta() { - return aes_decryptor_->CiphertextSizeDelta(); +int Decryptor::ciphertextSizeDelta() { + return aesDecryptor_->ciphertextSizeDelta(); } -int Decryptor::Decrypt( +int Decryptor::decrypt( const uint8_t* ciphertext, - int ciphertext_len, + int ciphertextLen, uint8_t* plaintext) { - return aes_decryptor_->Decrypt( + return aesDecryptor_->decrypt( ciphertext, - ciphertext_len, + ciphertextLen, str2bytes(key_), static_cast(key_.size()), str2bytes(aad_), @@ -53,190 +53,190 @@ int Decryptor::Decrypt( plaintext); } -// InternalFileDecryptor +// InternalFileDecryptor. InternalFileDecryptor::InternalFileDecryptor( FileDecryptionProperties* properties, - const std::string& file_aad, + const std::string& fileAad, ParquetCipher::type algorithm, - const std::string& footer_key_metadata, + const std::string& footerKeyMetadata, ::arrow::MemoryPool* pool) : properties_(properties), - file_aad_(file_aad), + fileAad_(fileAad), algorithm_(algorithm), - footer_key_metadata_(footer_key_metadata), + footerKeyMetadata_(footerKeyMetadata), pool_(pool) { - if (properties_->is_utilized()) { + if (properties_->isUtilized()) { throw ParquetException( "Re-using decryption properties with explicit keys for another file"); } - properties_->set_utilized(); + properties_->setUtilized(); } -void InternalFileDecryptor::WipeOutDecryptionKeys() { - properties_->WipeOutDecryptionKeys(); - for (auto const& i : all_decryptors_) { - if (auto aes_decryptor = i.lock()) { - aes_decryptor->WipeOut(); +void InternalFileDecryptor::wipeOutDecryptionKeys() { + properties_->wipeOutDecryptionKeys(); + for (auto const& i : allDecryptors_) { + if (auto aesDecryptor = i.lock()) { + aesDecryptor->wipeOut(); } } } -std::string InternalFileDecryptor::GetFooterKey() { - std::string footer_key = properties_->footer_key(); - // ignore footer key metadata if footer key is explicitly set via API - if (footer_key.empty()) { - if (footer_key_metadata_.empty()) +std::string InternalFileDecryptor::getFooterKey() { + std::string footerKey = properties_->footerKey(); + // Ignore footer key metadata if footer key is explicitly set via API. + if (footerKey.empty()) { + if (footerKeyMetadata_.empty()) throw ParquetException("No footer key or key metadata"); - if (properties_->key_retriever() == nullptr) + if (properties_->keyRetriever() == nullptr) throw ParquetException("No footer key or key retriever"); try { - footer_key = properties_->key_retriever()->GetKey(footer_key_metadata_); + footerKey = properties_->keyRetriever()->getKey(footerKeyMetadata_); } catch (KeyAccessDeniedException& e) { std::stringstream ss; ss << "Footer key: access denied " << e.what() << "\n"; throw ParquetException(ss.str()); } } - if (footer_key.empty()) { + if (footerKey.empty()) { throw ParquetException( "Footer key unavailable. Could not verify " "plaintext footer metadata"); } - return footer_key; + return footerKey; } -std::shared_ptr InternalFileDecryptor::GetFooterDecryptor() { - std::string aad = encryption::CreateFooterAad(file_aad_); - return GetFooterDecryptor(aad, true); +std::shared_ptr InternalFileDecryptor::getFooterDecryptor() { + std::string aad = encryption::createFooterAad(fileAad_); + return getFooterDecryptor(aad, true); } std::shared_ptr -InternalFileDecryptor::GetFooterDecryptorForColumnMeta(const std::string& aad) { - return GetFooterDecryptor(aad, true); +InternalFileDecryptor::getFooterDecryptorForColumnMeta(const std::string& aad) { + return getFooterDecryptor(aad, true); } std::shared_ptr -InternalFileDecryptor::GetFooterDecryptorForColumnData(const std::string& aad) { - return GetFooterDecryptor(aad, false); +InternalFileDecryptor::getFooterDecryptorForColumnData(const std::string& aad) { + return getFooterDecryptor(aad, false); } -std::shared_ptr InternalFileDecryptor::GetFooterDecryptor( +std::shared_ptr InternalFileDecryptor::getFooterDecryptor( const std::string& aad, bool metadata) { if (metadata) { - if (footer_metadata_decryptor_ != nullptr) - return footer_metadata_decryptor_; + if (footerMetadataDecryptor_ != nullptr) + return footerMetadataDecryptor_; } else { - if (footer_data_decryptor_ != nullptr) - return footer_data_decryptor_; + if (footerDataDecryptor_ != nullptr) + return footerDataDecryptor_; } - std::string footer_key = properties_->footer_key(); - if (footer_key.empty()) { - if (footer_key_metadata_.empty()) + std::string footerKey = properties_->footerKey(); + if (footerKey.empty()) { + if (footerKeyMetadata_.empty()) throw ParquetException("No footer key or key metadata"); - if (properties_->key_retriever() == nullptr) + if (properties_->keyRetriever() == nullptr) throw ParquetException("No footer key or key retriever"); try { - footer_key = properties_->key_retriever()->GetKey(footer_key_metadata_); + footerKey = properties_->keyRetriever()->getKey(footerKeyMetadata_); } catch (KeyAccessDeniedException& e) { std::stringstream ss; ss << "Footer key: access denied " << e.what() << "\n"; throw ParquetException(ss.str()); } } - if (footer_key.empty()) { + if (footerKey.empty()) { throw ParquetException( "Invalid footer encryption key. " "Could not parse footer metadata"); } - // Create both data and metadata decryptors to avoid redundant retrieval of - // key from the key_retriever. - int key_len = static_cast(footer_key.size()); - auto aes_metadata_decryptor = encryption::AesDecryptor::Make( - algorithm_, key_len, /*metadata=*/true, &all_decryptors_); - auto aes_data_decryptor = encryption::AesDecryptor::Make( - algorithm_, key_len, /*metadata=*/false, &all_decryptors_); + // Create both data and metadata decryptors to avoid redundant retrieval of. + // Key from the key_retriever. + int keyLen = static_cast(footerKey.size()); + auto aesMetadataDecryptor = + encryption::AesDecryptor::make(algorithm_, keyLen, true, &allDecryptors_); + auto aesDataDecryptor = encryption::AesDecryptor::make( + algorithm_, keyLen, false, &allDecryptors_); - footer_metadata_decryptor_ = std::make_shared( - aes_metadata_decryptor, footer_key, file_aad_, aad, pool_); - footer_data_decryptor_ = std::make_shared( - aes_data_decryptor, footer_key, file_aad_, aad, pool_); + footerMetadataDecryptor_ = std::make_shared( + aesMetadataDecryptor, footerKey, fileAad_, aad, pool_); + footerDataDecryptor_ = std::make_shared( + aesDataDecryptor, footerKey, fileAad_, aad, pool_); if (metadata) - return footer_metadata_decryptor_; - return footer_data_decryptor_; + return footerMetadataDecryptor_; + return footerDataDecryptor_; } -std::shared_ptr InternalFileDecryptor::GetColumnMetaDecryptor( - const std::string& column_path, - const std::string& column_key_metadata, +std::shared_ptr InternalFileDecryptor::getColumnMetaDecryptor( + const std::string& ColumnPath, + const std::string& columnKeyMetadata, const std::string& aad) { - return GetColumnDecryptor(column_path, column_key_metadata, aad, true); + return getColumnDecryptor(ColumnPath, columnKeyMetadata, aad, true); } -std::shared_ptr InternalFileDecryptor::GetColumnDataDecryptor( - const std::string& column_path, - const std::string& column_key_metadata, +std::shared_ptr InternalFileDecryptor::getColumnDataDecryptor( + const std::string& ColumnPath, + const std::string& columnKeyMetadata, const std::string& aad) { - return GetColumnDecryptor(column_path, column_key_metadata, aad, false); + return getColumnDecryptor(ColumnPath, columnKeyMetadata, aad, false); } -std::shared_ptr InternalFileDecryptor::GetColumnDecryptor( - const std::string& column_path, - const std::string& column_key_metadata, +std::shared_ptr InternalFileDecryptor::getColumnDecryptor( + const std::string& ColumnPath, + const std::string& columnKeyMetadata, const std::string& aad, bool metadata) { - std::string column_key; - // first look if we already got the decryptor from before + std::string columnKey; + // First look if we already got the decryptor from before. if (metadata) { - if (column_metadata_map_.find(column_path) != column_metadata_map_.end()) { - auto res(column_metadata_map_.at(column_path)); - res->UpdateAad(aad); + if (columnMetadataMap_.find(ColumnPath) != columnMetadataMap_.end()) { + auto res(columnMetadataMap_.at(ColumnPath)); + res->updateAad(aad); return res; } } else { - if (column_data_map_.find(column_path) != column_data_map_.end()) { - auto res(column_data_map_.at(column_path)); - res->UpdateAad(aad); + if (columnDataMap_.find(ColumnPath) != columnDataMap_.end()) { + auto res(columnDataMap_.at(ColumnPath)); + res->updateAad(aad); return res; } } - column_key = properties_->column_key(column_path); + columnKey = properties_->columnKey(ColumnPath); // No explicit column key given via API. Retrieve via key metadata. - if (column_key.empty() && !column_key_metadata.empty() && - properties_->key_retriever() != nullptr) { + if (columnKey.empty() && !columnKeyMetadata.empty() && + properties_->keyRetriever() != nullptr) { try { - column_key = properties_->key_retriever()->GetKey(column_key_metadata); + columnKey = properties_->keyRetriever()->getKey(columnKeyMetadata); } catch (KeyAccessDeniedException& e) { std::stringstream ss; - ss << "HiddenColumnException, path=" + column_path + " " << e.what() + ss << "HiddenColumnException, path=" + ColumnPath + " " << e.what() << "\n"; throw HiddenColumnException(ss.str()); } } - if (column_key.empty()) { - throw HiddenColumnException("HiddenColumnException, path=" + column_path); + if (columnKey.empty()) { + throw HiddenColumnException("HiddenColumnException, path=" + ColumnPath); } - // Create both data and metadata decryptors to avoid redundant retrieval of - // key using the key_retriever. - int key_len = static_cast(column_key.size()); - auto aes_metadata_decryptor = encryption::AesDecryptor::Make( - algorithm_, key_len, /*metadata=*/true, &all_decryptors_); - auto aes_data_decryptor = encryption::AesDecryptor::Make( - algorithm_, key_len, /*metadata=*/false, &all_decryptors_); + // Create both data and metadata decryptors to avoid redundant retrieval of. + // Key using the key_retriever. + int keyLen = static_cast(columnKey.size()); + auto aesMetadataDecryptor = + encryption::AesDecryptor::make(algorithm_, keyLen, true, &allDecryptors_); + auto aesDataDecryptor = encryption::AesDecryptor::make( + algorithm_, keyLen, false, &allDecryptors_); - column_metadata_map_[column_path] = std::make_shared( - aes_metadata_decryptor, column_key, file_aad_, aad, pool_); - column_data_map_[column_path] = std::make_shared( - aes_data_decryptor, column_key, file_aad_, aad, pool_); + columnMetadataMap_[ColumnPath] = std::make_shared( + aesMetadataDecryptor, columnKey, fileAad_, aad, pool_); + columnDataMap_[ColumnPath] = std::make_shared( + aesDataDecryptor, columnKey, fileAad_, aad, pool_); if (metadata) - return column_metadata_map_[column_path]; - return column_data_map_[column_path]; + return columnMetadataMap_[ColumnPath]; + return columnDataMap_[ColumnPath]; } } // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/FileDecryptorInternal.h b/velox/dwio/parquet/writer/arrow/FileDecryptorInternal.h index b2e2de5e7cd..e34e5ba5b17 100644 --- a/velox/dwio/parquet/writer/arrow/FileDecryptorInternal.h +++ b/velox/dwio/parquet/writer/arrow/FileDecryptorInternal.h @@ -37,32 +37,29 @@ class FileDecryptionProperties; class PARQUET_EXPORT Decryptor { public: Decryptor( - std::shared_ptr decryptor, + std::shared_ptr Decryptor, const std::string& key, - const std::string& file_aad, + const std::string& fileAad, const std::string& aad, ::arrow::MemoryPool* pool); - const std::string& file_aad() const { - return file_aad_; + const std::string& fileAad() const { + return fileAad_; } - void UpdateAad(const std::string& aad) { + void updateAad(const std::string& aad) { aad_ = aad; } ::arrow::MemoryPool* pool() { return pool_; } - int CiphertextSizeDelta(); - int Decrypt( - const uint8_t* ciphertext, - int ciphertext_len, - uint8_t* plaintext); + int ciphertextSizeDelta(); + int decrypt(const uint8_t* ciphertext, int ciphertextLen, uint8_t* plaintext); private: - std::shared_ptr aes_decryptor_; + std::shared_ptr aesDecryptor_; std::string key_; - std::string file_aad_; + std::string fileAad_; std::string aad_; ::arrow::MemoryPool* pool_; }; @@ -71,72 +68,72 @@ class InternalFileDecryptor { public: explicit InternalFileDecryptor( FileDecryptionProperties* properties, - const std::string& file_aad, + const std::string& fileAad, ParquetCipher::type algorithm, - const std::string& footer_key_metadata, + const std::string& footerKeyMetadata, ::arrow::MemoryPool* pool); - std::string& file_aad() { - return file_aad_; + std::string& fileAad() { + return fileAad_; } - std::string GetFooterKey(); + std::string getFooterKey(); ParquetCipher::type algorithm() { return algorithm_; } - std::string& footer_key_metadata() { - return footer_key_metadata_; + std::string& footerKeyMetadata() { + return footerKeyMetadata_; } FileDecryptionProperties* properties() { return properties_; } - void WipeOutDecryptionKeys(); + void wipeOutDecryptionKeys(); ::arrow::MemoryPool* pool() { return pool_; } - std::shared_ptr GetFooterDecryptor(); - std::shared_ptr GetFooterDecryptorForColumnMeta( + std::shared_ptr getFooterDecryptor(); + std::shared_ptr getFooterDecryptorForColumnMeta( const std::string& aad = ""); - std::shared_ptr GetFooterDecryptorForColumnData( + std::shared_ptr getFooterDecryptorForColumnData( const std::string& aad = ""); - std::shared_ptr GetColumnMetaDecryptor( - const std::string& column_path, - const std::string& column_key_metadata, + std::shared_ptr getColumnMetaDecryptor( + const std::string& ColumnPath, + const std::string& columnKeyMetadata, const std::string& aad = ""); - std::shared_ptr GetColumnDataDecryptor( - const std::string& column_path, - const std::string& column_key_metadata, + std::shared_ptr getColumnDataDecryptor( + const std::string& ColumnPath, + const std::string& columnKeyMetadata, const std::string& aad = ""); private: FileDecryptionProperties* properties_; - // Concatenation of aad_prefix (if exists) and aad_file_unique - std::string file_aad_; - std::map> column_data_map_; - std::map> column_metadata_map_; + // Concatenation of aad_prefix (if exists) and aad_file_unique. + std::string fileAad_; + std::map> columnDataMap_; + std::map> columnMetadataMap_; - std::shared_ptr footer_metadata_decryptor_; - std::shared_ptr footer_data_decryptor_; + std::shared_ptr footerMetadataDecryptor_; + std::shared_ptr footerDataDecryptor_; ParquetCipher::type algorithm_; - std::string footer_key_metadata_; - // A weak reference to all decryptors that need to be wiped out when - // decryption is finished - std::vector> all_decryptors_; + std::string footerKeyMetadata_; + // A weak reference to all decryptors that need to be wiped out when. + // Decryption is finished. + std::vector> allDecryptors_; ::arrow::MemoryPool* pool_; - std::shared_ptr GetFooterDecryptor( + std::shared_ptr getFooterDecryptor( const std::string& aad, bool metadata); - std::shared_ptr GetColumnDecryptor( - const std::string& column_path, - const std::string& column_key_metadata, + std::shared_ptr getColumnDecryptor( + const std::string& ColumnPath, + const std::string& columnKeyMetadata, const std::string& aad, bool metadata = false); }; diff --git a/velox/dwio/parquet/writer/arrow/FileEncryptorInternal.cpp b/velox/dwio/parquet/writer/arrow/FileEncryptorInternal.cpp index 6192156c7b2..74c4c2fbd7f 100644 --- a/velox/dwio/parquet/writer/arrow/FileEncryptorInternal.cpp +++ b/velox/dwio/parquet/writer/arrow/FileEncryptorInternal.cpp @@ -22,30 +22,30 @@ namespace facebook::velox::parquet::arrow { -// Encryptor +// Encryptor. Encryptor::Encryptor( - encryption::AesEncryptor* aes_encryptor, + encryption::AesEncryptor* aesEncryptor, const std::string& key, - const std::string& file_aad, + const std::string& fileAad, const std::string& aad, ::arrow::MemoryPool* pool) - : aes_encryptor_(aes_encryptor), + : aesEncryptor_(aesEncryptor), key_(key), - file_aad_(file_aad), + fileAad_(fileAad), aad_(aad), pool_(pool) {} -int Encryptor::CiphertextSizeDelta() { - return aes_encryptor_->CiphertextSizeDelta(); +int Encryptor::ciphertextSizeDelta() { + return aesEncryptor_->ciphertextSizeDelta(); } -int Encryptor::Encrypt( +int Encryptor::encrypt( const uint8_t* plaintext, - int plaintext_len, + int plaintextLen, uint8_t* ciphertext) { - return aes_encryptor_->Encrypt( + return aesEncryptor_->encrypt( plaintext, - plaintext_len, + plaintextLen, str2bytes(key_), static_cast(key_.size()), str2bytes(aad_), @@ -53,141 +53,138 @@ int Encryptor::Encrypt( ciphertext); } -// InternalFileEncryptor +// InternalFileEncryptor. InternalFileEncryptor::InternalFileEncryptor( FileEncryptionProperties* properties, ::arrow::MemoryPool* pool) : properties_(properties), pool_(pool) { - if (properties_->is_utilized()) { + if (properties_->isUtilized()) { throw ParquetException("Re-using encryption properties for another file"); } - properties_->set_utilized(); + properties_->setUtilized(); } -void InternalFileEncryptor::WipeOutEncryptionKeys() { - properties_->WipeOutEncryptionKeys(); +void InternalFileEncryptor::wipeOutEncryptionKeys() { + properties_->wipeOutEncryptionKeys(); - for (auto const& i : all_encryptors_) { - i->WipeOut(); + for (auto const& i : allEncryptors_) { + i->wipeOut(); } } -std::shared_ptr InternalFileEncryptor::GetFooterEncryptor() { - if (footer_encryptor_ != nullptr) { - return footer_encryptor_; +std::shared_ptr InternalFileEncryptor::getFooterEncryptor() { + if (footerEncryptor_ != nullptr) { + return footerEncryptor_; } ParquetCipher::type algorithm = properties_->algorithm().algorithm; - std::string footer_aad = encryption::CreateFooterAad(properties_->file_aad()); - std::string footer_key = properties_->footer_key(); - auto aes_encryptor = GetMetaAesEncryptor(algorithm, footer_key.size()); - footer_encryptor_ = std::make_shared( - aes_encryptor, footer_key, properties_->file_aad(), footer_aad, pool_); - return footer_encryptor_; + std::string footerAad = encryption::createFooterAad(properties_->fileAad()); + std::string footerKey = properties_->footerKey(); + auto aesEncryptor = getMetaAesEncryptor(algorithm, footerKey.size()); + footerEncryptor_ = std::make_shared( + aesEncryptor, footerKey, properties_->fileAad(), footerAad, pool_); + return footerEncryptor_; } -std::shared_ptr InternalFileEncryptor::GetFooterSigningEncryptor() { - if (footer_signing_encryptor_ != nullptr) { - return footer_signing_encryptor_; +std::shared_ptr InternalFileEncryptor::getFooterSigningEncryptor() { + if (footerSigningEncryptor_ != nullptr) { + return footerSigningEncryptor_; } ParquetCipher::type algorithm = properties_->algorithm().algorithm; - std::string footer_aad = encryption::CreateFooterAad(properties_->file_aad()); - std::string footer_signing_key = properties_->footer_key(); - auto aes_encryptor = - GetMetaAesEncryptor(algorithm, footer_signing_key.size()); - footer_signing_encryptor_ = std::make_shared( - aes_encryptor, - footer_signing_key, - properties_->file_aad(), - footer_aad, - pool_); - return footer_signing_encryptor_; + std::string footerAad = encryption::createFooterAad(properties_->fileAad()); + std::string footerSigningKey = properties_->footerKey(); + auto aesEncryptor = getMetaAesEncryptor(algorithm, footerSigningKey.size()); + footerSigningEncryptor_ = std::make_shared( + aesEncryptor, footerSigningKey, properties_->fileAad(), footerAad, pool_); + return footerSigningEncryptor_; } -std::shared_ptr InternalFileEncryptor::GetColumnMetaEncryptor( - const std::string& column_path) { - return GetColumnEncryptor(column_path, true); +std::shared_ptr InternalFileEncryptor::getColumnMetaEncryptor( + const std::string& columnPath) { + return getColumnEncryptor(columnPath, true); } -std::shared_ptr InternalFileEncryptor::GetColumnDataEncryptor( - const std::string& column_path) { - return GetColumnEncryptor(column_path, false); +std::shared_ptr InternalFileEncryptor::getColumnDataEncryptor( + const std::string& columnPath) { + return getColumnEncryptor(columnPath, false); } std::shared_ptr -InternalFileEncryptor::InternalFileEncryptor::GetColumnEncryptor( - const std::string& column_path, +InternalFileEncryptor::InternalFileEncryptor::getColumnEncryptor( + const std::string& columnPath, bool metadata) { - // first look if we already got the encryptor from before + // First look if we already got the encryptor from before. if (metadata) { - if (column_metadata_map_.find(column_path) != column_metadata_map_.end()) { - return column_metadata_map_.at(column_path); + if (columnMetadataMap_.find(columnPath) != columnMetadataMap_.end()) { + return columnMetadataMap_.at(columnPath); } } else { - if (column_data_map_.find(column_path) != column_data_map_.end()) { - return column_data_map_.at(column_path); + if (columnDataMap_.find(columnPath) != columnDataMap_.end()) { + return columnDataMap_.at(columnPath); } } - auto column_prop = properties_->column_encryption_properties(column_path); - if (column_prop == nullptr) { + auto columnProp = properties_->columnEncryptionProperties(columnPath); + if (columnProp == nullptr) { return nullptr; } std::string key; - if (column_prop->is_encrypted_with_footer_key()) { - key = properties_->footer_key(); + if (columnProp->isEncryptedWithFooterKey()) { + key = properties_->footerKey(); } else { - key = column_prop->key(); + key = columnProp->key(); } ParquetCipher::type algorithm = properties_->algorithm().algorithm; - auto aes_encryptor = metadata ? GetMetaAesEncryptor(algorithm, key.size()) - : GetDataAesEncryptor(algorithm, key.size()); + auto aesEncryptor = metadata ? getMetaAesEncryptor(algorithm, key.size()) + : getDataAesEncryptor(algorithm, key.size()); - std::string file_aad = properties_->file_aad(); + std::string fileAad = properties_->fileAad(); std::shared_ptr encryptor = - std::make_shared(aes_encryptor, key, file_aad, "", pool_); + std::make_shared(aesEncryptor, key, fileAad, "", pool_); if (metadata) - column_metadata_map_[column_path] = encryptor; + columnMetadataMap_[columnPath] = encryptor; else - column_data_map_[column_path] = encryptor; + columnDataMap_[columnPath] = encryptor; return encryptor; } -int InternalFileEncryptor::MapKeyLenToEncryptorArrayIndex(int key_len) { - if (key_len == 16) +int InternalFileEncryptor::mapKeyLenToEncryptorArrayIndex(int keyLen) { + if (keyLen == 16) return 0; - else if (key_len == 24) + else if (keyLen == 24) return 1; - else if (key_len == 32) + else if (keyLen == 32) return 2; throw ParquetException("encryption key must be 16, 24 or 32 bytes in length"); } -encryption::AesEncryptor* InternalFileEncryptor::GetMetaAesEncryptor( +encryption::AesEncryptor* InternalFileEncryptor::getMetaAesEncryptor( ParquetCipher::type algorithm, - size_t key_size) { - int key_len = static_cast(key_size); - int index = MapKeyLenToEncryptorArrayIndex(key_len); - if (meta_encryptor_[index] == nullptr) { - meta_encryptor_[index].reset(encryption::AesEncryptor::Make( - algorithm, key_len, true, &all_encryptors_)); + size_t keySize) { + int keyLen = static_cast(keySize); + int index = mapKeyLenToEncryptorArrayIndex(keyLen); + if (metaEncryptor_[index] == nullptr) { + metaEncryptor_[index].reset( + encryption::AesEncryptor::make( + algorithm, keyLen, true, &allEncryptors_)); } - return meta_encryptor_[index].get(); + return metaEncryptor_[index].get(); } -encryption::AesEncryptor* InternalFileEncryptor::GetDataAesEncryptor( +encryption::AesEncryptor* InternalFileEncryptor::getDataAesEncryptor( ParquetCipher::type algorithm, - size_t key_size) { - int key_len = static_cast(key_size); - int index = MapKeyLenToEncryptorArrayIndex(key_len); - if (data_encryptor_[index] == nullptr) { - data_encryptor_[index].reset(encryption::AesEncryptor::Make( - algorithm, key_len, false, &all_encryptors_)); + size_t keySize) { + int keyLen = static_cast(keySize); + int index = mapKeyLenToEncryptorArrayIndex(keyLen); + if (dataEncryptor_[index] == nullptr) { + dataEncryptor_[index].reset( + encryption::AesEncryptor::make( + algorithm, keyLen, false, &allEncryptors_)); } - return data_encryptor_[index].get(); + return dataEncryptor_[index].get(); } } // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/FileEncryptorInternal.h b/velox/dwio/parquet/writer/arrow/FileEncryptorInternal.h index 39f6be4b0e1..eac13b5d2c4 100644 --- a/velox/dwio/parquet/writer/arrow/FileEncryptorInternal.h +++ b/velox/dwio/parquet/writer/arrow/FileEncryptorInternal.h @@ -38,44 +38,44 @@ class ColumnEncryptionProperties; class PARQUET_EXPORT Encryptor { public: Encryptor( - encryption::AesEncryptor* aes_encryptor, + encryption::AesEncryptor* aesEncryptor, const std::string& key, - const std::string& file_aad, + const std::string& fileAad, const std::string& aad, ::arrow::MemoryPool* pool); - const std::string& file_aad() { - return file_aad_; + const std::string& fileAad() { + return fileAad_; } - void UpdateAad(const std::string& aad) { + void updateAad(const std::string& aad) { aad_ = aad; } ::arrow::MemoryPool* pool() { return pool_; } - int CiphertextSizeDelta(); - int Encrypt(const uint8_t* plaintext, int plaintext_len, uint8_t* ciphertext); + int ciphertextSizeDelta(); + int encrypt(const uint8_t* plaintext, int plaintextLen, uint8_t* ciphertext); - bool EncryptColumnMetaData( - bool encrypted_footer, + bool encryptColumnMetaData( + bool encryptedFooter, const std::shared_ptr& - column_encryption_properties) { - // if column is not encrypted then do not encrypt the column metadata - if (!column_encryption_properties || - !column_encryption_properties->is_encrypted()) + ColumnEncryptionProperties) { + // If column is not encrypted then do not encrypt the column metadata. + if (!ColumnEncryptionProperties || + !ColumnEncryptionProperties->isEncrypted()) return false; - // if plaintext footer then encrypt the column metadata - if (!encrypted_footer) + // If plaintext footer then encrypt the column metadata. + if (!encryptedFooter) return true; - // if column is not encrypted with footer key then encrypt the column - // metadata - return !column_encryption_properties->is_encrypted_with_footer_key(); + // If column is not encrypted with footer key then encrypt the column. + // Metadata. + return !ColumnEncryptionProperties->isEncryptedWithFooterKey(); } private: - encryption::AesEncryptor* aes_encryptor_; + encryption::AesEncryptor* aesEncryptor_; std::string key_; - std::string file_aad_; + std::string fileAad_; std::string aad_; ::arrow::MemoryPool* pool_; }; @@ -86,44 +86,44 @@ class InternalFileEncryptor { FileEncryptionProperties* properties, ::arrow::MemoryPool* pool); - std::shared_ptr GetFooterEncryptor(); - std::shared_ptr GetFooterSigningEncryptor(); - std::shared_ptr GetColumnMetaEncryptor( - const std::string& column_path); - std::shared_ptr GetColumnDataEncryptor( - const std::string& column_path); - void WipeOutEncryptionKeys(); + std::shared_ptr getFooterEncryptor(); + std::shared_ptr getFooterSigningEncryptor(); + std::shared_ptr getColumnMetaEncryptor( + const std::string& ColumnPath); + std::shared_ptr getColumnDataEncryptor( + const std::string& ColumnPath); + void wipeOutEncryptionKeys(); private: FileEncryptionProperties* properties_; - std::map> column_data_map_; - std::map> column_metadata_map_; + std::map> columnDataMap_; + std::map> columnMetadataMap_; - std::shared_ptr footer_signing_encryptor_; - std::shared_ptr footer_encryptor_; + std::shared_ptr footerSigningEncryptor_; + std::shared_ptr footerEncryptor_; - std::vector all_encryptors_; + std::vector allEncryptors_; - // Key must be 16, 24 or 32 bytes in length. Thus there could be up to three - // types of meta_encryptors and data_encryptors. - std::unique_ptr meta_encryptor_[3]; - std::unique_ptr data_encryptor_[3]; + // Key must be 16, 24 or 32 bytes in length. Thus there could be up to three. + // Types of meta_encryptors and data_encryptors. + std::unique_ptr metaEncryptor_[3]; + std::unique_ptr dataEncryptor_[3]; ::arrow::MemoryPool* pool_; - std::shared_ptr GetColumnEncryptor( - const std::string& column_path, + std::shared_ptr getColumnEncryptor( + const std::string& ColumnPath, bool metadata); - encryption::AesEncryptor* GetMetaAesEncryptor( + encryption::AesEncryptor* getMetaAesEncryptor( ParquetCipher::type algorithm, - size_t key_len); - encryption::AesEncryptor* GetDataAesEncryptor( + size_t keyLen); + encryption::AesEncryptor* getDataAesEncryptor( ParquetCipher::type algorithm, - size_t key_len); + size_t keyLen); - int MapKeyLenToEncryptorArrayIndex(int key_len); + int mapKeyLenToEncryptorArrayIndex(int keyLen); }; } // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/FileWriter.cpp b/velox/dwio/parquet/writer/arrow/FileWriter.cpp index ea89c4838eb..3e2e49e6e8f 100644 --- a/velox/dwio/parquet/writer/arrow/FileWriter.cpp +++ b/velox/dwio/parquet/writer/arrow/FileWriter.cpp @@ -41,256 +41,254 @@ namespace facebook::velox::parquet::arrow { using schema::GroupNode; -// ---------------------------------------------------------------------- -// RowGroupWriter public API +// ----------------------------------------------------------------------. +// RowGroupWriter public API. RowGroupWriter::RowGroupWriter(std::unique_ptr contents) : contents_(std::move(contents)) {} -void RowGroupWriter::Close() { +void RowGroupWriter::close() { if (contents_) { - contents_->Close(); + contents_->close(); } } -ColumnWriter* RowGroupWriter::NextColumn() { - return contents_->NextColumn(); +ColumnWriter* RowGroupWriter::nextColumn() { + return contents_->nextColumn(); } ColumnWriter* RowGroupWriter::column(int i) { return contents_->column(i); } -int64_t RowGroupWriter::total_compressed_bytes() const { - return contents_->total_compressed_bytes(); +int64_t RowGroupWriter::totalCompressedBytes() const { + return contents_->totalCompressedBytes(); } -int64_t RowGroupWriter::total_bytes_written() const { - return contents_->total_bytes_written(); +int64_t RowGroupWriter::totalBytesWritten() const { + return contents_->totalBytesWritten(); } -int64_t RowGroupWriter::total_compressed_bytes_written() const { - return contents_->total_compressed_bytes_written(); +int64_t RowGroupWriter::totalCompressedBytesWritten() const { + return contents_->totalCompressedBytesWritten(); } bool RowGroupWriter::buffered() const { return contents_->buffered(); } -int RowGroupWriter::current_column() { - return contents_->current_column(); +int RowGroupWriter::currentColumn() { + return contents_->currentColumn(); } -int RowGroupWriter::num_columns() const { - return contents_->num_columns(); +int RowGroupWriter::numColumns() const { + return contents_->numColumns(); } -int64_t RowGroupWriter::num_rows() const { - return contents_->num_rows(); +int64_t RowGroupWriter::numRows() const { + return contents_->numRows(); } -inline void ThrowRowsMisMatchError(int col, int64_t prev, int64_t curr) { +inline void throwRowsMisMatchError(int col, int64_t prev, int64_t curr) { std::stringstream ss; ss << "Column " << col << " had " << curr << " while previous column had " << prev; throw ParquetException(ss.str()); } -// ---------------------------------------------------------------------- -// RowGroupSerializer +// ----------------------------------------------------------------------. +// RowGroupSerializer. -// RowGroupWriter::Contents implementation for the Parquet file specification +// RowGroupWriter::Contents implementation for the Parquet file specification. class RowGroupSerializer : public RowGroupWriter::Contents { public: RowGroupSerializer( std::shared_ptr sink, RowGroupMetaDataBuilder* metadata, - int16_t row_group_ordinal, + int16_t rowGroupOrdinal, const WriterProperties* properties, - bool buffered_row_group = false, - InternalFileEncryptor* file_encryptor = nullptr, - PageIndexBuilder* page_index_builder = nullptr) + bool bufferedRowGroup = false, + InternalFileEncryptor* fileEncryptor = nullptr, + PageIndexBuilder* pageIndexBuilder = nullptr) : sink_(std::move(sink)), metadata_(metadata), properties_(properties), - total_bytes_written_(0), - total_compressed_bytes_written_(0), + totalBytesWritten_(0), + totalCompressedBytesWritten_(0), closed_(false), - row_group_ordinal_(row_group_ordinal), - next_column_index_(0), - num_rows_(0), - buffered_row_group_(buffered_row_group), - file_encryptor_(file_encryptor), - page_index_builder_(page_index_builder) { - if (buffered_row_group) { - InitColumns(); + rowGroupOrdinal_(rowGroupOrdinal), + nextColumnIndex_(0), + numRows_(0), + bufferedRowGroup_(bufferedRowGroup), + fileEncryptor_(fileEncryptor), + pageIndexBuilder_(pageIndexBuilder) { + if (bufferedRowGroup) { + initColumns(); } else { - column_writers_.push_back(nullptr); + columnWriters_.push_back(nullptr); } } - int num_columns() const override { - return metadata_->num_columns(); + int numColumns() const override { + return metadata_->numColumns(); } - int64_t num_rows() const override { - CheckRowsWritten(); - // CheckRowsWritten ensures num_rows_ is set correctly - return num_rows_; + int64_t numRows() const override { + checkRowsWritten(); + // checkRowsWritten() ensures numRows_ is set correctly. + return numRows_; } - ColumnWriter* NextColumn() override { - if (buffered_row_group_) { + ColumnWriter* nextColumn() override { + if (bufferedRowGroup_) { throw ParquetException( - "NextColumn() is not supported when a RowGroup is written by size"); + "nextColumn() is not supported when a RowGroup is written by size"); } - if (column_writers_[0]) { - CheckRowsWritten(); + if (columnWriters_[0]) { + checkRowsWritten(); } - // Throws an error if more columns are being written - auto col_meta = metadata_->NextColumnChunk(); + // Throws an error if more columns are being written. + auto colMeta = metadata_->nextColumnChunk(); - if (column_writers_[0]) { - total_bytes_written_ += column_writers_[0]->Close(); - total_compressed_bytes_written_ += - column_writers_[0]->total_compressed_bytes_written(); + if (columnWriters_[0]) { + totalBytesWritten_ += columnWriters_[0]->close(); + totalCompressedBytesWritten_ += + columnWriters_[0]->totalCompressedBytesWritten(); } - const int32_t column_ordinal = next_column_index_++; - const auto& path = col_meta->descr()->path(); - auto meta_encryptor = file_encryptor_ - ? file_encryptor_->GetColumnMetaEncryptor(path->ToDotString()) + const int32_t columnOrdinal = nextColumnIndex_++; + const auto& path = colMeta->descr()->path(); + auto metaEncryptor = fileEncryptor_ + ? fileEncryptor_->getColumnMetaEncryptor(path->toDotString()) : nullptr; - auto data_encryptor = file_encryptor_ - ? file_encryptor_->GetColumnDataEncryptor(path->ToDotString()) + auto dataEncryptor = fileEncryptor_ + ? fileEncryptor_->getColumnDataEncryptor(path->toDotString()) : nullptr; - auto ci_builder = page_index_builder_ && - properties_->page_index_enabled(path) && - properties_->statistics_enabled(path) - ? page_index_builder_->GetColumnIndexBuilder(column_ordinal) + auto ciBuilder = pageIndexBuilder_ && properties_->pageIndexEnabled(path) && + properties_->statisticsEnabled(path) + ? pageIndexBuilder_->getColumnIndexBuilder(columnOrdinal) : nullptr; - auto oi_builder = - page_index_builder_ && properties_->page_index_enabled(path) - ? page_index_builder_->GetOffsetIndexBuilder(column_ordinal) + auto oiBuilder = pageIndexBuilder_ && properties_->pageIndexEnabled(path) + ? pageIndexBuilder_->getOffsetIndexBuilder(columnOrdinal) : nullptr; - auto codec_options = properties_->codec_options(path) - ? properties_->codec_options(path).get() + auto codecOptions = properties_->codecOptions(path) + ? properties_->codecOptions(path).get() : nullptr; std::unique_ptr pager; - if (!codec_options) { - pager = PageWriter::Open( + if (!codecOptions) { + pager = PageWriter::open( sink_, properties_->compression(path), - col_meta, - row_group_ordinal_, - static_cast(column_ordinal), - properties_->memory_pool(), + colMeta, + rowGroupOrdinal_, + static_cast(columnOrdinal), + properties_->memoryPool(), false, - meta_encryptor, - data_encryptor, - properties_->page_checksum_enabled(), - ci_builder, - oi_builder, + metaEncryptor, + dataEncryptor, + properties_->pageChecksumEnabled(), + ciBuilder, + oiBuilder, CodecOptions()); } else { - pager = PageWriter::Open( + pager = PageWriter::open( sink_, properties_->compression(path), - col_meta, - row_group_ordinal_, - static_cast(column_ordinal), - properties_->memory_pool(), + colMeta, + rowGroupOrdinal_, + static_cast(columnOrdinal), + properties_->memoryPool(), false, - meta_encryptor, - data_encryptor, - properties_->page_checksum_enabled(), - ci_builder, - oi_builder, - *codec_options); + metaEncryptor, + dataEncryptor, + properties_->pageChecksumEnabled(), + ciBuilder, + oiBuilder, + *codecOptions); } - column_writers_[0] = - ColumnWriter::Make(col_meta, std::move(pager), properties_); - return column_writers_[0].get(); + columnWriters_[0] = + ColumnWriter::make(colMeta, std::move(pager), properties_); + return columnWriters_[0].get(); } ColumnWriter* column(int i) override { - if (!buffered_row_group_) { + if (!bufferedRowGroup_) { throw ParquetException( "column() is only supported when a BufferedRowGroup is being written"); } - if (i >= 0 && i < static_cast(column_writers_.size())) { - return column_writers_[i].get(); + if (i >= 0 && i < static_cast(columnWriters_.size())) { + return columnWriters_[i].get(); } return nullptr; } - int current_column() const override { - return metadata_->current_column(); + int currentColumn() const override { + return metadata_->currentColumn(); } - int64_t total_compressed_bytes() const override { - int64_t total_compressed_bytes = 0; - for (size_t i = 0; i < column_writers_.size(); i++) { - if (column_writers_[i]) { - total_compressed_bytes += column_writers_[i]->total_compressed_bytes(); + int64_t totalCompressedBytes() const override { + int64_t totalCompressedBytes = 0; + for (size_t i = 0; i < columnWriters_.size(); i++) { + if (columnWriters_[i]) { + totalCompressedBytes += columnWriters_[i]->totalCompressedBytes(); } } - return total_compressed_bytes; + return totalCompressedBytes; } - int64_t total_bytes_written() const override { + int64_t totalBytesWritten() const override { if (closed_) { - return total_bytes_written_; + return totalBytesWritten_; } - int64_t total_bytes_written = 0; - for (size_t i = 0; i < column_writers_.size(); i++) { - if (column_writers_[i]) { - total_bytes_written += column_writers_[i]->total_bytes_written(); + int64_t totalBytesWritten = 0; + for (size_t i = 0; i < columnWriters_.size(); i++) { + if (columnWriters_[i]) { + totalBytesWritten += columnWriters_[i]->totalBytesWritten(); } } - return total_bytes_written; + return totalBytesWritten; } - int64_t total_compressed_bytes_written() const override { + int64_t totalCompressedBytesWritten() const override { if (closed_) { - return total_compressed_bytes_written_; + return totalCompressedBytesWritten_; } - int64_t total_compressed_bytes_written = 0; - for (size_t i = 0; i < column_writers_.size(); i++) { - if (column_writers_[i]) { - total_compressed_bytes_written += - column_writers_[i]->total_compressed_bytes_written(); + int64_t totalCompressedBytesWritten = 0; + for (size_t i = 0; i < columnWriters_.size(); i++) { + if (columnWriters_[i]) { + totalCompressedBytesWritten += + columnWriters_[i]->totalCompressedBytesWritten(); } } - return total_compressed_bytes_written; + return totalCompressedBytesWritten; } bool buffered() const override { - return buffered_row_group_; + return bufferedRowGroup_; } - void Close() override { + void close() override { if (!closed_) { closed_ = true; - CheckRowsWritten(); - - // Avoid invalid state if ColumnWriter::Close() throws internally. - auto column_writers = std::move(column_writers_); - for (size_t i = 0; i < column_writers.size(); i++) { - if (column_writers[i]) { - total_bytes_written_ += column_writers[i]->Close(); - total_compressed_bytes_written_ += - column_writers[i]->total_compressed_bytes_written(); + checkRowsWritten(); + + // Avoid invalid state if ColumnWriter::close() throws internally. + auto columnWriters = std::move(columnWriters_); + for (size_t i = 0; i < columnWriters.size(); i++) { + if (columnWriters[i]) { + totalBytesWritten_ += columnWriters[i]->close(); + totalCompressedBytesWritten_ += + columnWriters[i]->totalCompressedBytesWritten(); } } - // Ensures all columns have been written - metadata_->set_num_rows(num_rows_); - metadata_->Finish(total_bytes_written_, row_group_ordinal_); + // Ensures all columns have been written. + metadata_->setNumRows(numRows_); + metadata_->finish(totalBytesWritten_, rowGroupOrdinal_); } } @@ -298,212 +296,204 @@ class RowGroupSerializer : public RowGroupWriter::Contents { std::shared_ptr sink_; mutable RowGroupMetaDataBuilder* metadata_; const WriterProperties* properties_; - int64_t total_bytes_written_; - int64_t total_compressed_bytes_written_; + int64_t totalBytesWritten_; + int64_t totalCompressedBytesWritten_; bool closed_; - int16_t row_group_ordinal_; - int next_column_index_; - mutable int64_t num_rows_; - bool buffered_row_group_; - InternalFileEncryptor* file_encryptor_; - PageIndexBuilder* page_index_builder_; - - void CheckRowsWritten() const { - // verify when only one column is written at a time - if (!buffered_row_group_ && column_writers_.size() > 0 && - column_writers_[0]) { - int64_t current_col_rows = column_writers_[0]->rows_written(); - if (num_rows_ == 0) { - num_rows_ = current_col_rows; - } else if (num_rows_ != current_col_rows) { - ThrowRowsMisMatchError(next_column_index_, current_col_rows, num_rows_); + int16_t rowGroupOrdinal_; + int nextColumnIndex_; + mutable int64_t numRows_; + bool bufferedRowGroup_; + InternalFileEncryptor* fileEncryptor_; + PageIndexBuilder* pageIndexBuilder_; + + void checkRowsWritten() const { + // Verify when only one column is written at a time. + if (!bufferedRowGroup_ && columnWriters_.size() > 0 && columnWriters_[0]) { + int64_t currentColRows = columnWriters_[0]->rowsWritten(); + if (numRows_ == 0) { + numRows_ = currentColRows; + } else if (numRows_ != currentColRows) { + throwRowsMisMatchError(nextColumnIndex_, currentColRows, numRows_); } - } else if ( - buffered_row_group_ && - column_writers_.size() > 0) { // when - // buffered_row_group - // = true - VELOX_DCHECK_NOT_NULL(column_writers_[0]); - int64_t current_col_rows = column_writers_[0]->rows_written(); - for (int i = 1; i < static_cast(column_writers_.size()); i++) { - VELOX_DCHECK_NOT_NULL(column_writers_[i]); - int64_t current_col_rows_i = column_writers_[i]->rows_written(); - if (current_col_rows != current_col_rows_i) { - ThrowRowsMisMatchError(i, current_col_rows_i, current_col_rows); + } else if (bufferedRowGroup_ && columnWriters_.size() > 0) { + // When bufferedRowGroup = true. + VELOX_DCHECK_NOT_NULL(columnWriters_[0]); + int64_t currentColRows = columnWriters_[0]->rowsWritten(); + for (int i = 1; i < static_cast(columnWriters_.size()); i++) { + VELOX_DCHECK_NOT_NULL(columnWriters_[i]); + int64_t currentColRowsI = columnWriters_[i]->rowsWritten(); + if (currentColRows != currentColRowsI) { + throwRowsMisMatchError(i, currentColRowsI, currentColRows); } } - num_rows_ = current_col_rows; + numRows_ = currentColRows; } } - void InitColumns() { - for (int i = 0; i < num_columns(); i++) { - auto col_meta = metadata_->NextColumnChunk(); - const auto& path = col_meta->descr()->path(); - const int32_t column_ordinal = next_column_index_++; - auto meta_encryptor = file_encryptor_ - ? file_encryptor_->GetColumnMetaEncryptor(path->ToDotString()) - : nullptr; - auto data_encryptor = file_encryptor_ - ? file_encryptor_->GetColumnDataEncryptor(path->ToDotString()) + void initColumns() { + for (int i = 0; i < numColumns(); i++) { + auto colMeta = metadata_->nextColumnChunk(); + const auto& path = colMeta->descr()->path(); + const int32_t columnOrdinal = nextColumnIndex_++; + auto metaEncryptor = fileEncryptor_ + ? fileEncryptor_->getColumnMetaEncryptor(path->toDotString()) : nullptr; - auto ci_builder = - page_index_builder_ && properties_->page_index_enabled(path) - ? page_index_builder_->GetColumnIndexBuilder(column_ordinal) + auto dataEncryptor = fileEncryptor_ + ? fileEncryptor_->getColumnDataEncryptor(path->toDotString()) : nullptr; - auto oi_builder = - page_index_builder_ && properties_->page_index_enabled(path) - ? page_index_builder_->GetOffsetIndexBuilder(column_ordinal) + auto ciBuilder = pageIndexBuilder_ && properties_->pageIndexEnabled(path) + ? pageIndexBuilder_->getColumnIndexBuilder(columnOrdinal) : nullptr; - auto codec_options = properties_->codec_options(path) - ? (properties_->codec_options(path)).get() + auto oiBuilder = pageIndexBuilder_ && properties_->pageIndexEnabled(path) + ? pageIndexBuilder_->getOffsetIndexBuilder(columnOrdinal) : nullptr; + auto codecOptions = properties_->codecOptions(path); std::unique_ptr pager; - if (!codec_options) { - pager = PageWriter::Open( + if (!codecOptions) { + pager = PageWriter::open( sink_, properties_->compression(path), - col_meta, - row_group_ordinal_, - static_cast(column_ordinal), - properties_->memory_pool(), - buffered_row_group_, - meta_encryptor, - data_encryptor, - properties_->page_checksum_enabled(), - ci_builder, - oi_builder, + colMeta, + rowGroupOrdinal_, + static_cast(columnOrdinal), + properties_->memoryPool(), + bufferedRowGroup_, + metaEncryptor, + dataEncryptor, + properties_->pageChecksumEnabled(), + ciBuilder, + oiBuilder, CodecOptions()); } else { - pager = PageWriter::Open( + pager = PageWriter::open( sink_, properties_->compression(path), - col_meta, - row_group_ordinal_, - static_cast(column_ordinal), - properties_->memory_pool(), - buffered_row_group_, - meta_encryptor, - data_encryptor, - properties_->page_checksum_enabled(), - ci_builder, - oi_builder, - *codec_options); + colMeta, + rowGroupOrdinal_, + static_cast(columnOrdinal), + properties_->memoryPool(), + bufferedRowGroup_, + metaEncryptor, + dataEncryptor, + properties_->pageChecksumEnabled(), + ciBuilder, + oiBuilder, + *codecOptions); } - column_writers_.push_back( - ColumnWriter::Make(col_meta, std::move(pager), properties_)); + columnWriters_.push_back( + ColumnWriter::make(colMeta, std::move(pager), properties_)); } } - std::vector> column_writers_; + std::vector> columnWriters_; }; // ---------------------------------------------------------------------- -// FileSerializer +// FileSerializer. // An implementation of ParquetFileWriter::Contents that deals with the Parquet -// file structure, Thrift serialization, and other internal matters +// file structure, Thrift serialization, and other internal matters. class FileSerializer : public ParquetFileWriter::Contents { public: - static std::unique_ptr Open( + static std::unique_ptr open( std::shared_ptr sink, std::shared_ptr schema, std::shared_ptr properties, - std::shared_ptr key_value_metadata) { + std::shared_ptr keyValueMetadata) { std::unique_ptr result(new FileSerializer( std::move(sink), std::move(schema), std::move(properties), - std::move(key_value_metadata))); + std::move(keyValueMetadata))); return result; } - void Close() override { - if (is_open_) { - // If any functions here raise an exception, we set is_open_ to be false - // so that this does not get called again (possibly causing segfault) - is_open_ = false; - if (row_group_writer_) { - num_rows_ += row_group_writer_->num_rows(); - row_group_writer_->Close(); + void close() override { + if (isOpen_) { + // If any functions here raise an exception, we set isOpen_ to be false + // so that this does not get called again (possibly causing segfault). + isOpen_ = false; + if (rowGroupWriter_) { + numRows_ += rowGroupWriter_->numRows(); + rowGroupWriter_->close(); } - row_group_writer_.reset(); + rowGroupWriter_.reset(); - WritePageIndex(); + writePageIndex(); - // Write magic bytes and metadata - auto file_encryption_properties = - properties_->file_encryption_properties(); + // Write magic bytes and metadata. + auto fileEncryptionProperties = properties_->fileEncryptionProperties(); - if (file_encryption_properties == nullptr) { // Non encrypted file. - file_metadata_ = metadata_->Finish(key_value_metadata_); - WriteFileMetaData(*file_metadata_, sink_.get()); - } else { // Encrypted file - CloseEncryptedFile(file_encryption_properties); + if (fileEncryptionProperties == nullptr) { // Non encrypted file. + fileMetadata_ = metadata_->finish(keyValueMetadata_); + writeFileMetaData(*fileMetadata_, sink_.get()); + } else { // Encrypted file. + closeEncryptedFile(fileEncryptionProperties); } } } - int num_columns() const override { - return schema_.num_columns(); + int numColumns() const override { + return schema_.numColumns(); } - int num_row_groups() const override { - return num_row_groups_; + int numRowGroups() const override { + return numRowGroups_; } - int64_t num_rows() const override { - return num_rows_; + int64_t numRows() const override { + return numRows_; } const std::shared_ptr& properties() const override { return properties_; } - RowGroupWriter* AppendRowGroup(bool buffered_row_group) { - if (row_group_writer_) { - row_group_writer_->Close(); + RowGroupWriter* appendRowGroup(bool bufferedRowGroup) { + if (rowGroupWriter_) { + rowGroupWriter_->close(); } - num_row_groups_++; - auto rg_metadata = metadata_->AppendRowGroup(); - if (page_index_builder_) { - page_index_builder_->AppendRowGroup(); + numRowGroups_++; + auto rgMetadata = metadata_->appendRowGroup(); + if (pageIndexBuilder_) { + pageIndexBuilder_->appendRowGroup(); } std::unique_ptr contents(new RowGroupSerializer( sink_, - rg_metadata, - static_cast(num_row_groups_ - 1), + rgMetadata, + static_cast(numRowGroups_ - 1), properties_.get(), - buffered_row_group, - file_encryptor_.get(), - page_index_builder_.get())); - row_group_writer_ = std::make_unique(std::move(contents)); - return row_group_writer_.get(); + bufferedRowGroup, + fileEncryptor_.get(), + pageIndexBuilder_.get())); + rowGroupWriter_ = std::make_unique(std::move(contents)); + return rowGroupWriter_.get(); } - RowGroupWriter* AppendRowGroup() override { - return AppendRowGroup(false); + RowGroupWriter* appendRowGroup() override { + return appendRowGroup(false); } - RowGroupWriter* AppendBufferedRowGroup() override { - return AppendRowGroup(true); + RowGroupWriter* appendBufferedRowGroup() override { + return appendRowGroup(true); } - void AddKeyValueMetadata(const std::shared_ptr& - key_value_metadata) override { - if (key_value_metadata_ == nullptr) { - key_value_metadata_ = key_value_metadata; - } else if (key_value_metadata != nullptr) { - key_value_metadata_ = key_value_metadata_->Merge(*key_value_metadata); + void addKeyValueMetadata( + const std::shared_ptr& keyValueMetadata) + override { + if (keyValueMetadata_ == nullptr) { + keyValueMetadata_ = keyValueMetadata; + } else if (keyValueMetadata != nullptr) { + keyValueMetadata_ = keyValueMetadata_->Merge(*keyValueMetadata); } } ~FileSerializer() override { try { - FileSerializer::Close(); + FileSerializer::close(); } catch (...) { } } @@ -513,105 +503,103 @@ class FileSerializer : public ParquetFileWriter::Contents { std::shared_ptr sink, std::shared_ptr schema, std::shared_ptr properties, - std::shared_ptr key_value_metadata) + std::shared_ptr keyValueMetadata) : ParquetFileWriter::Contents( std::move(schema), - std::move(key_value_metadata)), + std::move(keyValueMetadata)), sink_(std::move(sink)), - is_open_(true), + isOpen_(true), properties_(std::move(properties)), - num_row_groups_(0), - num_rows_(0), - metadata_(FileMetaDataBuilder::Make(&schema_, properties_)) { + numRowGroups_(0), + numRows_(0), + metadata_(FileMetaDataBuilder::make(&schema_, properties_)) { PARQUET_ASSIGN_OR_THROW(int64_t position, sink_->Tell()); if (position == 0) { - StartFile(); + startFile(); } else { throw ParquetException("Appending to file not implemented."); } } - void CloseEncryptedFile( - FileEncryptionProperties* file_encryption_properties) { - // Encrypted file with encrypted footer - if (file_encryption_properties->encrypted_footer()) { - // encrypted footer - file_metadata_ = metadata_->Finish(key_value_metadata_); + void closeEncryptedFile(FileEncryptionProperties* fileEncryptionProperties) { + // Encrypted file with encrypted footer. + if (fileEncryptionProperties->encryptedFooter()) { + // Encrypted footer. + fileMetadata_ = metadata_->finish(keyValueMetadata_); PARQUET_ASSIGN_OR_THROW(int64_t position, sink_->Tell()); - uint64_t metadata_start = static_cast(position); - auto crypto_metadata = metadata_->GetCryptoMetaData(); - WriteFileCryptoMetaData(*crypto_metadata, sink_.get()); + uint64_t metadataStart = static_cast(position); + auto cryptoMetadata = metadata_->getCryptoMetaData(); + writeFileCryptoMetaData(*cryptoMetadata, sink_.get()); - auto footer_encryptor = file_encryptor_->GetFooterEncryptor(); - WriteEncryptedFileMetadata( - *file_metadata_, sink_.get(), footer_encryptor, true); + auto footerEncryptor = fileEncryptor_->getFooterEncryptor(); + writeEncryptedFileMetadata( + *fileMetadata_, sink_.get(), footerEncryptor, true); PARQUET_ASSIGN_OR_THROW(position, sink_->Tell()); - uint32_t footer_and_crypto_len = - static_cast(position - metadata_start); + uint32_t footerAndCryptoLen = + static_cast(position - metadataStart); PARQUET_THROW_NOT_OK( - sink_->Write(reinterpret_cast(&footer_and_crypto_len), 4)); + sink_->Write(reinterpret_cast(&footerAndCryptoLen), 4)); PARQUET_THROW_NOT_OK(sink_->Write(kParquetEMagic, 4)); } else { // Encrypted file with plaintext footer - file_metadata_ = metadata_->Finish(key_value_metadata_); - auto footer_signing_encryptor = - file_encryptor_->GetFooterSigningEncryptor(); - WriteEncryptedFileMetadata( - *file_metadata_, sink_.get(), footer_signing_encryptor, false); + fileMetadata_ = metadata_->finish(keyValueMetadata_); + auto footerSigningEncryptor = fileEncryptor_->getFooterSigningEncryptor(); + writeEncryptedFileMetadata( + *fileMetadata_, sink_.get(), footerSigningEncryptor, false); } - if (file_encryptor_) { - file_encryptor_->WipeOutEncryptionKeys(); + if (fileEncryptor_) { + fileEncryptor_->wipeOutEncryptionKeys(); } } - void WritePageIndex() { - if (page_index_builder_ != nullptr) { - if (properties_->file_encryption_properties()) { + void writePageIndex() { + if (pageIndexBuilder_ != nullptr) { + if (properties_->fileEncryptionProperties()) { throw ParquetException("Encryption is not supported with page index"); } // Serialize page index after all row groups have been written and report // location to the file metadata. - PageIndexLocation page_index_location; - page_index_builder_->Finish(); - page_index_builder_->WriteTo(sink_.get(), &page_index_location); - metadata_->SetPageIndexLocation(page_index_location); + PageIndexLocation pageIndexLocation; + pageIndexBuilder_->finish(); + pageIndexBuilder_->writeTo(sink_.get(), &pageIndexLocation); + metadata_->setPageIndexLocation(pageIndexLocation); } } std::shared_ptr sink_; - bool is_open_; + bool isOpen_; const std::shared_ptr properties_; - int num_row_groups_; - int64_t num_rows_; + int numRowGroups_; + int64_t numRows_; std::unique_ptr metadata_; - // Only one of the row group writers is active at a time - std::unique_ptr row_group_writer_; - std::unique_ptr page_index_builder_; - std::unique_ptr file_encryptor_; - - void StartFile() { - auto file_encryption_properties = properties_->file_encryption_properties(); - if (file_encryption_properties == nullptr) { - // Unencrypted parquet files always start with PAR1 + // Only one of the row group writers is active at a time. + std::unique_ptr rowGroupWriter_; + std::unique_ptr pageIndexBuilder_; + std::unique_ptr fileEncryptor_; + + void startFile() { + auto fileEncryptionProperties = properties_->fileEncryptionProperties(); + if (fileEncryptionProperties == nullptr) { + // Unencrypted parquet files always start with PAR1. PARQUET_THROW_NOT_OK(sink_->Write(kParquetMagic, 4)); } else { // Check that all columns in columnEncryptionProperties exist in the // schema. - auto encrypted_columns = file_encryption_properties->encrypted_columns(); - // if columnEncryptionProperties is empty, every column in file schema + auto encryptedColumns = fileEncryptionProperties->encryptedColumns(); + // If columnEncryptionProperties is empty, every column in file schema // will be encrypted with footer key. - if (encrypted_columns.size() != 0) { - std::vector column_path_vec; + if (encryptedColumns.size() != 0) { + std::vector columnPathVec; // First, save all column paths in schema. - for (int i = 0; i < num_columns(); i++) { - column_path_vec.push_back(schema_.Column(i)->path()->ToDotString()); + for (int i = 0; i < numColumns(); i++) { + columnPathVec.push_back(schema_.column(i)->path()->toDotString()); } // Check if column exists in schema. - for (const auto& elem : encrypted_columns) { - auto it = std::find( - column_path_vec.begin(), column_path_vec.end(), elem.first); - if (it == column_path_vec.end()) { + for (const auto& elem : encryptedColumns) { + auto it = + std::find(columnPathVec.begin(), columnPathVec.end(), elem.first); + if (it == columnPathVec.end()) { std::stringstream ss; ss << "Encrypted column " + elem.first + " not in file schema"; throw ParquetException(ss.str()); @@ -619,9 +607,9 @@ class FileSerializer : public ParquetFileWriter::Contents { } } - file_encryptor_ = std::make_unique( - file_encryption_properties, properties_->memory_pool()); - if (file_encryption_properties->encrypted_footer()) { + fileEncryptor_ = std::make_unique( + fileEncryptionProperties, properties_->memoryPool()); + if (fileEncryptionProperties->encryptedFooter()) { PARQUET_THROW_NOT_OK(sink_->Write(kParquetEMagic, 4)); } else { // Encrypted file with plaintext footer mode. @@ -629,88 +617,88 @@ class FileSerializer : public ParquetFileWriter::Contents { } } - if (properties_->page_index_enabled()) { - page_index_builder_ = PageIndexBuilder::Make(&schema_); + if (properties_->pageIndexEnabled()) { + pageIndexBuilder_ = PageIndexBuilder::make(&schema_); } } }; // ---------------------------------------------------------------------- -// ParquetFileWriter public API +// ParquetFileWriter public API. ParquetFileWriter::ParquetFileWriter() {} ParquetFileWriter::~ParquetFileWriter() { try { - Close(); + close(); } catch (...) { } } -std::unique_ptr ParquetFileWriter::Open( +std::unique_ptr ParquetFileWriter::open( std::shared_ptr<::arrow::io::OutputStream> sink, std::shared_ptr schema, std::shared_ptr properties, - std::shared_ptr key_value_metadata) { - auto contents = FileSerializer::Open( + std::shared_ptr keyValueMetadata) { + auto contents = FileSerializer::open( std::move(sink), std::move(schema), std::move(properties), - std::move(key_value_metadata)); + std::move(keyValueMetadata)); std::unique_ptr result(new ParquetFileWriter()); - result->Open(std::move(contents)); + result->open(std::move(contents)); return result; } -void WriteFileMetaData( - const FileMetaData& file_metadata, +void writeFileMetaData( + const FileMetaData& fileMetadata, ArrowOutputStream* sink) { - // Write MetaData + // Write metadata. PARQUET_ASSIGN_OR_THROW(int64_t position, sink->Tell()); - uint32_t metadata_len = static_cast(position); + uint32_t metadataLen = static_cast(position); - file_metadata.WriteTo(sink); + fileMetadata.writeTo(sink); PARQUET_ASSIGN_OR_THROW(position, sink->Tell()); - metadata_len = static_cast(position) - metadata_len; + metadataLen = static_cast(position) - metadataLen; - // Write Footer + // Write Footer. PARQUET_THROW_NOT_OK( - sink->Write(reinterpret_cast(&metadata_len), 4)); + sink->Write(reinterpret_cast(&metadataLen), 4)); PARQUET_THROW_NOT_OK(sink->Write(kParquetMagic, 4)); } -void WriteMetaDataFile( - const FileMetaData& file_metadata, +void writeMetaDataFile( + const FileMetaData& fileMetadata, ArrowOutputStream* sink) { PARQUET_THROW_NOT_OK(sink->Write(kParquetMagic, 4)); - return WriteFileMetaData(file_metadata, sink); + return writeFileMetaData(fileMetadata, sink); } -void WriteEncryptedFileMetadata( - const FileMetaData& file_metadata, +void writeEncryptedFileMetadata( + const FileMetaData& fileMetadata, ArrowOutputStream* sink, const std::shared_ptr& encryptor, - bool encrypt_footer) { - if (encrypt_footer) { // Encrypted file with encrypted footer - // encrypt and write to sink - file_metadata.WriteTo(sink, encryptor); + bool encryptFooter) { + if (encryptFooter) { // Encrypted file with encrypted footer. + // Encrypt and write to sink. + fileMetadata.writeTo(sink, encryptor); } else { // Encrypted file with plaintext footer mode. PARQUET_ASSIGN_OR_THROW(int64_t position, sink->Tell()); - uint32_t metadata_len = static_cast(position); - file_metadata.WriteTo(sink, encryptor); + uint32_t metadataLen = static_cast(position); + fileMetadata.writeTo(sink, encryptor); PARQUET_ASSIGN_OR_THROW(position, sink->Tell()); - metadata_len = static_cast(position) - metadata_len; + metadataLen = static_cast(position) - metadataLen; PARQUET_THROW_NOT_OK( - sink->Write(reinterpret_cast(&metadata_len), 4)); + sink->Write(reinterpret_cast(&metadataLen), 4)); PARQUET_THROW_NOT_OK(sink->Write(kParquetMagic, 4)); } } -void WriteFileCryptoMetaData( - const FileCryptoMetaData& crypto_metadata, +void writeFileCryptoMetaData( + const FileCryptoMetaData& cryptoMetadata, ArrowOutputStream* sink) { - crypto_metadata.WriteTo(sink); + cryptoMetadata.writeTo(sink); } const SchemaDescriptor* ParquetFileWriter::schema() const { @@ -718,59 +706,59 @@ const SchemaDescriptor* ParquetFileWriter::schema() const { } const ColumnDescriptor* ParquetFileWriter::descr(int i) const { - return contents_->schema()->Column(i); + return contents_->schema()->column(i); } -int ParquetFileWriter::num_columns() const { - return contents_->num_columns(); +int ParquetFileWriter::numColumns() const { + return contents_->numColumns(); } -int64_t ParquetFileWriter::num_rows() const { - return contents_->num_rows(); +int64_t ParquetFileWriter::numRows() const { + return contents_->numRows(); } -int ParquetFileWriter::num_row_groups() const { - return contents_->num_row_groups(); +int ParquetFileWriter::numRowGroups() const { + return contents_->numRowGroups(); } const std::shared_ptr& -ParquetFileWriter::key_value_metadata() const { - return contents_->key_value_metadata(); +ParquetFileWriter::keyValueMetadata() const { + return contents_->keyValueMetadata(); } const std::shared_ptr ParquetFileWriter::metadata() const { - return file_metadata_; + return fileMetadata_; } -void ParquetFileWriter::Open( +void ParquetFileWriter::open( std::unique_ptr contents) { contents_ = std::move(contents); } -void ParquetFileWriter::Close() { +void ParquetFileWriter::close() { if (contents_) { - contents_->Close(); - file_metadata_ = contents_->metadata(); + contents_->close(); + fileMetadata_ = contents_->metadata(); contents_.reset(); } } -RowGroupWriter* ParquetFileWriter::AppendRowGroup() { - return contents_->AppendRowGroup(); +RowGroupWriter* ParquetFileWriter::appendRowGroup() { + return contents_->appendRowGroup(); } -RowGroupWriter* ParquetFileWriter::AppendBufferedRowGroup() { - return contents_->AppendBufferedRowGroup(); +RowGroupWriter* ParquetFileWriter::appendBufferedRowGroup() { + return contents_->appendBufferedRowGroup(); } -RowGroupWriter* ParquetFileWriter::AppendRowGroup(int64_t num_rows) { - return AppendRowGroup(); +RowGroupWriter* ParquetFileWriter::appendRowGroup(int64_t numRows) { + return appendRowGroup(); } -void ParquetFileWriter::AddKeyValueMetadata( - const std::shared_ptr& key_value_metadata) { +void ParquetFileWriter::addKeyValueMetadata( + const std::shared_ptr& keyValueMetadata) { if (contents_) { - contents_->AddKeyValueMetadata(key_value_metadata); + contents_->addKeyValueMetadata(keyValueMetadata); } else { throw ParquetException("Cannot add key-value metadata to closed file"); } diff --git a/velox/dwio/parquet/writer/arrow/FileWriter.h b/velox/dwio/parquet/writer/arrow/FileWriter.h index 27a2eafda03..5e466929769 100644 --- a/velox/dwio/parquet/writer/arrow/FileWriter.h +++ b/velox/dwio/parquet/writer/arrow/FileWriter.h @@ -31,34 +31,34 @@ namespace facebook::velox::parquet::arrow { class ColumnWriter; -// FIXME: copied from reader-internal.cc +// FIXME: copied from reader-internal.cc. static constexpr uint8_t kParquetMagic[4] = {'P', 'A', 'R', '1'}; static constexpr uint8_t kParquetEMagic[4] = {'P', 'A', 'R', 'E'}; class PARQUET_EXPORT RowGroupWriter { public: // Forward declare a virtual class 'Contents' to aid dependency injection and - // more easily create test fixtures An implementation of the Contents class is - // defined in the .cc file + // more easily create test fixtures. An implementation of the Contents class + // is defined in the .cpp file. struct Contents { virtual ~Contents() = default; - virtual int num_columns() const = 0; - virtual int64_t num_rows() const = 0; + virtual int numColumns() const = 0; + virtual int64_t numRows() const = 0; - // to be used only with ParquetFileWriter::AppendRowGroup - virtual ColumnWriter* NextColumn() = 0; - // to be used only with ParquetFileWriter::AppendBufferedRowGroup + // To be used only with ParquetFileWriter::AppendRowGroup. + virtual ColumnWriter* nextColumn() = 0; + // To be used only with ParquetFileWriter::AppendBufferedRowGroup. virtual ColumnWriter* column(int i) = 0; - virtual int current_column() const = 0; - virtual void Close() = 0; + virtual int currentColumn() const = 0; + virtual void close() = 0; - /// \brief total uncompressed bytes written by the page writer - virtual int64_t total_bytes_written() const = 0; - /// \brief total bytes still compressed but not written by the page writer - virtual int64_t total_compressed_bytes() const = 0; - /// \brief total compressed bytes written by the page writer - virtual int64_t total_compressed_bytes_written() const = 0; + /// \brief Total uncompressed bytes written by the page writer + virtual int64_t totalBytesWritten() const = 0; + /// \brief Total bytes still compressed but not written by the page writer + virtual int64_t totalCompressedBytes() const = 0; + /// \brief Total compressed bytes written by the page writer + virtual int64_t totalCompressedBytesWritten() const = 0; virtual bool buffered() const = 0; }; @@ -67,208 +67,207 @@ class PARQUET_EXPORT RowGroupWriter { /// Construct a ColumnWriter for the indicated row group-relative column. /// - /// To be used only with ParquetFileWriter::AppendRowGroup + /// To be used only with ParquetFileWriter::appendRowGroup(). /// Ownership is solely within the RowGroupWriter. The ColumnWriter is only - /// valid until the next call to NextColumn or Close. As the contents are + /// valid until the next call to nextColumn() or close(). As the contents are /// directly written to the sink, once a new column is started, the contents /// of the previous one cannot be modified anymore. - ColumnWriter* NextColumn(); - /// Index of currently written column. Equal to -1 if NextColumn() + ColumnWriter* nextColumn(); + /// Index of currently written column. Equal to -1 if nextColumn() /// has not been called yet. - int current_column(); - void Close(); + int currentColumn(); + void close(); - int num_columns() const; + int numColumns() const; /// Construct a ColumnWriter for the indicated row group column. /// - /// To be used only with ParquetFileWriter::AppendBufferedRowGroup + /// To be used only with ParquetFileWriter::appendBufferedRowGroup(). /// Ownership is solely within the RowGroupWriter. The ColumnWriter is - /// valid until Close. The contents are buffered in memory and written to sink - /// on Close + /// valid until close(). The contents are buffered in memory and written to + /// sink on close(). ColumnWriter* column(int i); /** * Number of rows that shall be written as part of this RowGroup. */ - int64_t num_rows() const; + int64_t numRows() const; - /// \brief total uncompressed bytes written by the page writer - int64_t total_bytes_written() const; - /// \brief total bytes still compressed but not written by the page writer. + /// \brief Total uncompressed bytes written by the page writer. + int64_t totalBytesWritten() const; + /// \brief Total bytes still compressed but not written by the page writer. /// It will always return 0 from the SerializedPageWriter. - int64_t total_compressed_bytes() const; - /// \brief total compressed bytes written by the page writer - int64_t total_compressed_bytes_written() const; - + int64_t totalCompressedBytes() const; + /// \brief Total compressed bytes written by the page writer + int64_t totalCompressedBytesWritten() const; /// Returns whether the current RowGroupWriter is in the buffered mode and is - /// created by calling ParquetFileWriter::AppendBufferedRowGroup. + /// created by calling ParquetFileWriter::appendBufferedRowGroup(). bool buffered() const; private: - // Holds a pointer to an instance of Contents implementation + // Holds a pointer to an instance of Contents implementation. std::unique_ptr contents_; }; PARQUET_EXPORT -void WriteFileMetaData( - const FileMetaData& file_metadata, +void writeFileMetaData( + const FileMetaData& fileMetadata, ::arrow::io::OutputStream* sink); PARQUET_EXPORT -void WriteMetaDataFile( - const FileMetaData& file_metadata, +void writeMetaDataFile( + const FileMetaData& fileMetadata, ::arrow::io::OutputStream* sink); PARQUET_EXPORT -void WriteEncryptedFileMetadata( - const FileMetaData& file_metadata, +void writeEncryptedFileMetadata( + const FileMetaData& fileMetadata, ArrowOutputStream* sink, const std::shared_ptr& encryptor, - bool encrypt_footer); + bool encryptFooter); PARQUET_EXPORT -void WriteEncryptedFileMetadata( - const FileMetaData& file_metadata, +void writeEncryptedFileMetadata( + const FileMetaData& fileMetadata, ::arrow::io::OutputStream* sink, const std::shared_ptr& encryptor = NULLPTR, - bool encrypt_footer = false); + bool encryptFooter = false); PARQUET_EXPORT -void WriteFileCryptoMetaData( - const FileCryptoMetaData& crypto_metadata, +void writeFileCryptoMetaData( + const FileCryptoMetaData& cryptoMetadata, ::arrow::io::OutputStream* sink); class PARQUET_EXPORT ParquetFileWriter { public: // Forward declare a virtual class 'Contents' to aid dependency injection and - // more easily create test fixtures An implementation of the Contents class is - // defined in the .cc file + // more easily create test fixtures. An implementation of the Contents class + // is defined in the .cpp file. struct Contents { Contents( std::shared_ptr schema, - std::shared_ptr key_value_metadata) - : schema_(), key_value_metadata_(std::move(key_value_metadata)) { - schema_.Init(std::move(schema)); + std::shared_ptr keyValueMetadata) + : schema_(), keyValueMetadata_(std::move(keyValueMetadata)) { + schema_.init(std::move(schema)); } virtual ~Contents() {} - // Perform any cleanup associated with the file contents - virtual void Close() = 0; + // Perform any cleanup associated with the file contents. + virtual void close() = 0; - /// \note Deprecated since 1.3.0 - RowGroupWriter* AppendRowGroup(int64_t num_rows); + /// \note Deprecated since 1.3.0. + RowGroupWriter* appendRowGroup(int64_t numRows); - virtual RowGroupWriter* AppendRowGroup() = 0; - virtual RowGroupWriter* AppendBufferedRowGroup() = 0; + virtual RowGroupWriter* appendRowGroup() = 0; + virtual RowGroupWriter* appendBufferedRowGroup() = 0; - virtual int64_t num_rows() const = 0; - virtual int num_columns() const = 0; - virtual int num_row_groups() const = 0; + virtual int64_t numRows() const = 0; + virtual int numColumns() const = 0; + virtual int numRowGroups() const = 0; virtual const std::shared_ptr& properties() const = 0; - const std::shared_ptr& key_value_metadata() const { - return key_value_metadata_; + const std::shared_ptr& keyValueMetadata() const { + return keyValueMetadata_; } - virtual void AddKeyValueMetadata( - const std::shared_ptr& key_value_metadata) = 0; + virtual void addKeyValueMetadata( + const std::shared_ptr& keyValueMetadata) = 0; - // Return const-pointer to make it clear that this object is not to be - // copied + // Return const pointer to make it clear that this object is not to be + // copied. const SchemaDescriptor* schema() const { return &schema_; } SchemaDescriptor schema_; - /// This should be the only place this is stored. Everything else is a const - /// reference - std::shared_ptr key_value_metadata_; + /// This should be the only place this is stored. Everything else is a + /// const reference. + std::shared_ptr keyValueMetadata_; const std::shared_ptr& metadata() const { - return file_metadata_; + return fileMetadata_; } - std::shared_ptr file_metadata_; + std::shared_ptr fileMetadata_; }; ParquetFileWriter(); ~ParquetFileWriter(); - static std::unique_ptr Open( + static std::unique_ptr open( std::shared_ptr<::arrow::io::OutputStream> sink, std::shared_ptr schema, - std::shared_ptr properties = - default_writer_properties(), - std::shared_ptr key_value_metadata = NULLPTR); + std::shared_ptr properties = defaultWriterProperties(), + std::shared_ptr keyValueMetadata = NULLPTR); - void Open(std::unique_ptr contents); - void Close(); + void open(std::unique_ptr contents); + void close(); // Construct a RowGroupWriter for the indicated number of rows. // // Ownership is solely within the ParquetFileWriter. The RowGroupWriter is - // only valid until the next call to AppendRowGroup or AppendBufferedRowGroup - // or Close. - // @param num_rows The number of rows that are stored in the new RowGroup + // only valid until the next call to appendRowGroup() or + // appendBufferedRowGroup() or close(). + // @param numRows The number of rows that are stored in the new RowGroup. // - // \deprecated Since 1.3.0 - RowGroupWriter* AppendRowGroup(int64_t num_rows); + // \deprecated Since 1.3.0. + RowGroupWriter* appendRowGroup(int64_t numRows); /// Construct a RowGroupWriter with an arbitrary number of rows. /// /// Ownership is solely within the ParquetFileWriter. The RowGroupWriter is - /// only valid until the next call to AppendRowGroup or AppendBufferedRowGroup - /// or Close. - RowGroupWriter* AppendRowGroup(); + /// only valid until the next call to appendRowGroup() or + /// appendBufferedRowGroup() or close(). + RowGroupWriter* appendRowGroup(); /// Construct a RowGroupWriter that buffers all the values until the RowGroup - /// is ready. Use this if you want to write a RowGroup based on a certain size + /// is ready. Use this if you want to write a RowGroup based on a certain + /// size. /// /// Ownership is solely within the ParquetFileWriter. The RowGroupWriter is - /// only valid until the next call to AppendRowGroup or AppendBufferedRowGroup - /// or Close. - RowGroupWriter* AppendBufferedRowGroup(); + /// only valid until the next call to appendRowGroup() or + /// appendBufferedRowGroup() or close(). + RowGroupWriter* appendBufferedRowGroup(); /// \brief Add key-value metadata to the file. - /// \param[in] key_value_metadata the metadata to add. + /// \param[in] keyValueMetadata The metadata to add. /// \note This will overwrite any existing metadata with the same key. - /// \throw ParquetException if Close() has been called. - void AddKeyValueMetadata( - const std::shared_ptr& key_value_metadata); + /// \throw ParquetException if close() has been called. + void addKeyValueMetadata( + const std::shared_ptr& keyValueMetadata); /// Number of columns. /// - /// This number is fixed during the lifetime of the writer as it is determined - /// via the schema. - int num_columns() const; + /// This number is fixed during the lifetime of the writer as it is + /// determined via the schema. + int numColumns() const; /// Number of rows in the yet started RowGroups. /// /// Changes on the addition of a new RowGroup. - int64_t num_rows() const; + int64_t numRows() const; /// Number of started RowGroups. - int num_row_groups() const; + int numRowGroups() const; /// Configuration passed to the writer, e.g. the used Parquet format version. const std::shared_ptr& properties() const; - /// Returns the file schema descriptor + /// Returns the file schema descriptor. const SchemaDescriptor* schema() const; - /// Returns a column descriptor in schema + /// Returns a column descriptor in schema. const ColumnDescriptor* descr(int i) const; - /// Returns the file custom metadata - const std::shared_ptr& key_value_metadata() const; + /// Returns the file custom metadata. + const std::shared_ptr& keyValueMetadata() const; - /// Returns the file metadata, only available after calling Close(). + /// Returns the file metadata, only available after calling close(). const std::shared_ptr metadata() const; private: - // Holds a pointer to an instance of Contents implementation + // Holds a pointer to an instance of Contents implementation. std::unique_ptr contents_; - std::shared_ptr file_metadata_; + std::shared_ptr fileMetadata_; }; } // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/Metadata.cpp b/velox/dwio/parquet/writer/arrow/Metadata.cpp index ef978a01a74..b62d9ecea85 100644 --- a/velox/dwio/parquet/writer/arrow/Metadata.cpp +++ b/velox/dwio/parquet/writer/arrow/Metadata.cpp @@ -62,15 +62,15 @@ const ApplicationVersion& ApplicationVersion::PARQUET_MR_FIXED_STATS_VERSION() { const ApplicationVersion& ApplicationVersion::PARQUET_CPP_10353_FIXED_VERSION() { - // parquet-cpp versions released prior to Arrow 3.0 would write DataPageV2 - // pages with is_compressed==0 but still write compressed data. (See: - // ARROW-10353). Parquet 1.5.1 had this problem, and after that we switched to - // the application name "parquet-cpp-arrow", so this version is fake. + // Parquet-cpp versions released prior to Arrow 3.0 would write DataPageV2. + // Pages with is_compressed==0 but still write compressed data. (See: + // ARROW-10353). Parquet 1.5.1 had this problem, and after that we switched + // to. The application name "parquet-cpp-arrow", so this version is fake. static ApplicationVersion version("parquet-cpp", 2, 0, 0); return version; } -std::string ParquetVersionToString(ParquetVersion::type ver) { +std::string parquetVersionToString(ParquetVersion::type ver) { switch (ver) { case ParquetVersion::PARQUET_1_0: return "1.0"; @@ -84,294 +84,297 @@ std::string ParquetVersionToString(ParquetVersion::type ver) { return "2.6"; } - // This should be unreachable + // This should be unreachable. return "UNKNOWN"; } template -static std::shared_ptr MakeTypedColumnStats( +static std::shared_ptr makeTypedColumnStats( const facebook::velox::parquet::thrift::ColumnMetaData& metadata, const ColumnDescriptor* descr) { - // If ColumnOrder is defined, return max_value and min_value - if (descr->column_order().get_order() == ColumnOrder::TYPE_DEFINED_ORDER) { - return MakeStatistics( + // If ColumnOrder is defined, return max_value and min_value. + const auto& stats = metadata.statistics; + if (descr->columnOrder().order() == ColumnOrder::kTypeDefinedOrder) { + return makeStatistics( descr, - metadata.statistics.min_value, - metadata.statistics.max_value, - metadata.num_values - metadata.statistics.null_count, - metadata.statistics.null_count, - metadata.statistics.distinct_count, - metadata.statistics.__isset.max_value || - metadata.statistics.__isset.min_value, - metadata.statistics.__isset.null_count, - metadata.statistics.__isset.distinct_count); - } - // Default behavior - return MakeStatistics( + stats.min_value, + stats.max_value, + metadata.num_values - stats.null_count, + stats.null_count, + stats.distinct_count, + stats.__isset.max_value || stats.__isset.min_value, + stats.__isset.null_count, + stats.__isset.distinct_count, + false, + 0); + } + // Default behavior. + return makeStatistics( descr, - metadata.statistics.min, - metadata.statistics.max, - metadata.num_values - metadata.statistics.null_count, - metadata.statistics.null_count, - metadata.statistics.distinct_count, - metadata.statistics.__isset.max || metadata.statistics.__isset.min, - metadata.statistics.__isset.null_count, - metadata.statistics.__isset.distinct_count); -} - -std::shared_ptr MakeColumnStats( + stats.min, + stats.max, + metadata.num_values - stats.null_count, + stats.null_count, + stats.distinct_count, + stats.__isset.max || stats.__isset.min, + stats.__isset.null_count, + stats.__isset.distinct_count, + false, + 0); +} + +std::shared_ptr makeColumnStats( const facebook::velox::parquet::thrift::ColumnMetaData& meta_data, const ColumnDescriptor* descr) { switch (static_cast(meta_data.type)) { - case Type::BOOLEAN: - return MakeTypedColumnStats(meta_data, descr); - case Type::INT32: - return MakeTypedColumnStats(meta_data, descr); - case Type::INT64: - return MakeTypedColumnStats(meta_data, descr); - case Type::INT96: - return MakeTypedColumnStats(meta_data, descr); - case Type::DOUBLE: - return MakeTypedColumnStats(meta_data, descr); - case Type::FLOAT: - return MakeTypedColumnStats(meta_data, descr); - case Type::BYTE_ARRAY: - return MakeTypedColumnStats(meta_data, descr); - case Type::FIXED_LEN_BYTE_ARRAY: - return MakeTypedColumnStats(meta_data, descr); - case Type::UNDEFINED: + case Type::kBoolean: + return makeTypedColumnStats(meta_data, descr); + case Type::kInt32: + return makeTypedColumnStats(meta_data, descr); + case Type::kInt64: + return makeTypedColumnStats(meta_data, descr); + case Type::kInt96: + return makeTypedColumnStats(meta_data, descr); + case Type::kDouble: + return makeTypedColumnStats(meta_data, descr); + case Type::kFloat: + return makeTypedColumnStats(meta_data, descr); + case Type::kByteArray: + return makeTypedColumnStats(meta_data, descr); + case Type::kFixedLenByteArray: + return makeTypedColumnStats(meta_data, descr); + case Type::kUndefined: break; } throw ParquetException( - "Can't decode page statistics for selected column type"); + "Can't decode page statistics for selected column type."); } -// MetaData Accessor +// MetaData Accessor. -// ColumnCryptoMetaData +// ColumnCryptoMetaData. class ColumnCryptoMetaData::ColumnCryptoMetaDataImpl { public: explicit ColumnCryptoMetaDataImpl( const facebook::velox::parquet::thrift::ColumnCryptoMetaData* - crypto_metadata) - : crypto_metadata_(crypto_metadata) {} + cryptoMetadata) + : cryptoMetadata_(cryptoMetadata) {} - bool encrypted_with_footer_key() const { - return crypto_metadata_->__isset.ENCRYPTION_WITH_FOOTER_KEY; + bool encryptedWithFooterKey() const { + return cryptoMetadata_->__isset.ENCRYPTION_WITH_FOOTER_KEY; } - bool encrypted_with_column_key() const { - return crypto_metadata_->__isset.ENCRYPTION_WITH_COLUMN_KEY; + bool encryptedWithColumnKey() const { + return cryptoMetadata_->__isset.ENCRYPTION_WITH_COLUMN_KEY; } - std::shared_ptr path_in_schema() const { + std::shared_ptr pathInSchema() const { return std::make_shared( - crypto_metadata_->ENCRYPTION_WITH_COLUMN_KEY.path_in_schema); + cryptoMetadata_->ENCRYPTION_WITH_COLUMN_KEY.path_in_schema); } - const std::string& key_metadata() const { - return crypto_metadata_->ENCRYPTION_WITH_COLUMN_KEY.key_metadata; + const std::string& keyMetadata() const { + return cryptoMetadata_->ENCRYPTION_WITH_COLUMN_KEY.key_metadata; } private: - const facebook::velox::parquet::thrift::ColumnCryptoMetaData* - crypto_metadata_; + const facebook::velox::parquet::thrift::ColumnCryptoMetaData* cryptoMetadata_; }; -std::unique_ptr ColumnCryptoMetaData::Make( +std::unique_ptr ColumnCryptoMetaData::make( const uint8_t* metadata) { return std::unique_ptr( new ColumnCryptoMetaData(metadata)); } ColumnCryptoMetaData::ColumnCryptoMetaData(const uint8_t* metadata) - : impl_(std::make_unique( - reinterpret_cast< - const facebook::velox::parquet::thrift::ColumnCryptoMetaData*>( - metadata))) {} + : impl_( + std::make_unique( + reinterpret_cast(metadata))) {} ColumnCryptoMetaData::~ColumnCryptoMetaData() = default; -std::shared_ptr ColumnCryptoMetaData::path_in_schema() - const { - return impl_->path_in_schema(); +std::shared_ptr ColumnCryptoMetaData::pathInSchema() const { + return impl_->pathInSchema(); } -bool ColumnCryptoMetaData::encrypted_with_footer_key() const { - return impl_->encrypted_with_footer_key(); +bool ColumnCryptoMetaData::encryptedWithFooterKey() const { + return impl_->encryptedWithFooterKey(); } -const std::string& ColumnCryptoMetaData::key_metadata() const { - return impl_->key_metadata(); +const std::string& ColumnCryptoMetaData::keyMetadata() const { + return impl_->keyMetadata(); } -// ColumnChunk metadata +// ColumnChunk metadata. class ColumnChunkMetaData::ColumnChunkMetaDataImpl { public: explicit ColumnChunkMetaDataImpl( const facebook::velox::parquet::thrift::ColumnChunk* column, const ColumnDescriptor* descr, - int16_t row_group_ordinal, - int16_t column_ordinal, + int16_t rowGroupOrdinal, + int16_t columnOrdinal, const ReaderProperties& properties, - const ApplicationVersion* writer_version, - std::shared_ptr file_decryptor) + const ApplicationVersion* writerVersion, + std::shared_ptr fileDecryptor) : column_(column), descr_(descr), properties_(properties), - writer_version_(writer_version) { - column_metadata_ = &column->meta_data; + writerVersion_(writerVersion) { + columnMetadata_ = &column->meta_data; if (column->__isset.crypto_metadata) { // column metadata is encrypted facebook::velox::parquet::thrift::ColumnCryptoMetaData ccmd = column->crypto_metadata; if (ccmd.__isset.ENCRYPTION_WITH_COLUMN_KEY) { - if (file_decryptor != nullptr && - file_decryptor->properties() != nullptr) { - // should decrypt metadata + if (fileDecryptor != nullptr && + fileDecryptor->properties() != nullptr) { + // Should decrypt metadata. std::shared_ptr path = std::make_shared( ccmd.ENCRYPTION_WITH_COLUMN_KEY.path_in_schema); std::string key_metadata = ccmd.ENCRYPTION_WITH_COLUMN_KEY.key_metadata; - std::string aad_column_metadata = encryption::CreateModuleAad( - file_decryptor->file_aad(), + std::string aadColumnMetadata = encryption::createModuleAad( + fileDecryptor->fileAad(), encryption::kColumnMetaData, - row_group_ordinal, - column_ordinal, + rowGroupOrdinal, + columnOrdinal, static_cast(-1)); - auto decryptor = file_decryptor->GetColumnMetaDecryptor( - path->ToDotString(), key_metadata, aad_column_metadata); + auto Decryptor = fileDecryptor->getColumnMetaDecryptor( + path->toDotString(), key_metadata, aadColumnMetadata); auto len = static_cast(column->encrypted_column_metadata.size()); ThriftDeserializer deserializer(properties_); - deserializer.DeserializeMessage( + deserializer.deserializeMessage( reinterpret_cast( column->encrypted_column_metadata.c_str()), &len, - &decrypted_metadata_, - decryptor); - column_metadata_ = &decrypted_metadata_; + &decryptedMetadata_, + Decryptor); + columnMetadata_ = &decryptedMetadata_; } else { throw ParquetException( "Cannot decrypt ColumnMetadata." - " FileDecryption is not setup correctly"); + " FileDecryption is not setup correctly."); } } } - for (const auto& encoding : column_metadata_->encodings) { - encodings_.push_back(LoadEnumSafe(&encoding)); + for (const auto& encoding : columnMetadata_->encodings) { + encodings_.push_back(loadenumSafe(&encoding)); } - for (const auto& encoding_stats : column_metadata_->encoding_stats) { - encoding_stats_.push_back( - {LoadEnumSafe(&encoding_stats.page_type), - LoadEnumSafe(&encoding_stats.encoding), - encoding_stats.count}); + for (const auto& encodingStats : columnMetadata_->encoding_stats) { + encodingStats_.push_back( + {loadenumSafe(&encodingStats.page_type), + loadenumSafe(&encodingStats.encoding), + encodingStats.count}); } - possible_stats_ = nullptr; + possibleStats_ = nullptr; } - bool Equals(const ColumnChunkMetaDataImpl& other) const { - return *column_metadata_ == *other.column_metadata_; + bool equals(const ColumnChunkMetaDataImpl& other) const { + return *columnMetadata_ == *other.columnMetadata_; } - // column chunk - inline int64_t file_offset() const { + // Column chunk. + inline int64_t fileOffset() const { return column_->file_offset; } - inline const std::string& file_path() const { + inline const std::string& filePath() const { return column_->file_path; } inline Type::type type() const { - return LoadEnumSafe(&column_metadata_->type); + return loadenumSafe(&columnMetadata_->type); } - inline int64_t num_values() const { - return column_metadata_->num_values; + inline int64_t numValues() const { + return columnMetadata_->num_values; } - std::shared_ptr path_in_schema() { + std::shared_ptr pathInSchema() { return std::make_shared( - column_metadata_->path_in_schema); - } - - // Check if statistics are set and are valid - // 1) Must be set in the metadata - // 2) Statistics must not be corrupted - inline bool is_stats_set() const { - VELOX_DCHECK_NOT_NULL(writer_version_); - // If the column statistics don't exist or column sort order is unknown - // we cannot use the column stats - if (!column_metadata_->__isset.statistics || - descr_->sort_order() == SortOrder::UNKNOWN) { + columnMetadata_->path_in_schema); + } + + // Check if statistics are set and are valid. + // 1) Must be set in the metadata. + // 2) Statistics must not be corrupted. + inline bool isStatsSet() const { + VELOX_DCHECK_NOT_NULL(writerVersion_); + // If the column statistics don't exist or column sort order is unknown, + // we cannot use the column stats. + if (!columnMetadata_->__isset.statistics || + descr_->sortOrder() == SortOrder::kUnknown) { return false; } - if (possible_stats_ == nullptr) { - possible_stats_ = MakeColumnStats(*column_metadata_, descr_); + if (possibleStats_ == nullptr) { + possibleStats_ = makeColumnStats(*columnMetadata_, descr_); } - EncodedStatistics encodedStatistics = possible_stats_->Encode(); - return writer_version_->HasCorrectStatistics( - type(), encodedStatistics, descr_->sort_order()); + EncodedStatistics encodedStats = possibleStats_->encode(); + return writerVersion_->hasCorrectStatistics( + type(), encodedStats, descr_->sortOrder()); } - inline std::shared_ptr statistics() const { - return is_stats_set() ? possible_stats_ : nullptr; + inline std::shared_ptr<::facebook::velox::parquet::arrow::Statistics> + statistics() const { + return isStatsSet() ? possibleStats_ : nullptr; } inline Compression::type compression() const { - return LoadEnumSafe(&column_metadata_->codec); + return loadenumSafe(&columnMetadata_->codec); } const std::vector& encodings() const { return encodings_; } - const std::vector& encoding_stats() const { - return encoding_stats_; + const std::vector& encodingStats() const { + return encodingStats_; } - inline std::optional bloom_filter_offset() const { - if (column_metadata_->__isset.bloom_filter_offset) { - return column_metadata_->bloom_filter_offset; + inline std::optional bloomFilterOffset() const { + if (columnMetadata_->__isset.bloom_filter_offset) { + return columnMetadata_->bloom_filter_offset; } return std::nullopt; } - inline bool has_dictionary_page() const { - return column_metadata_->__isset.dictionary_page_offset; + inline bool hasDictionaryPage() const { + return columnMetadata_->__isset.dictionary_page_offset; } - inline int64_t dictionary_page_offset() const { - return column_metadata_->dictionary_page_offset; + inline int64_t dictionaryPageOffset() const { + return columnMetadata_->dictionary_page_offset; } - inline int64_t data_page_offset() const { - return column_metadata_->data_page_offset; + inline int64_t dataPageOffset() const { + return columnMetadata_->data_page_offset; } - inline bool has_index_page() const { - return column_metadata_->__isset.index_page_offset; + inline bool hasIndexPage() const { + return columnMetadata_->__isset.index_page_offset; } - inline int64_t index_page_offset() const { - return column_metadata_->index_page_offset; + inline int64_t indexPageOffset() const { + return columnMetadata_->index_page_offset; } - inline int64_t total_compressed_size() const { - return column_metadata_->total_compressed_size; + inline int64_t totalCompressedSize() const { + return columnMetadata_->total_compressed_size; } - inline int64_t total_uncompressed_size() const { - return column_metadata_->total_uncompressed_size; + inline int64_t totalUncompressedSize() const { + return columnMetadata_->total_uncompressed_size; } - inline std::unique_ptr crypto_metadata() const { + inline std::unique_ptr cryptoMetadata() const { if (column_->__isset.crypto_metadata) { - return ColumnCryptoMetaData::Make( + return ColumnCryptoMetaData::make( reinterpret_cast(&column_->crypto_metadata)); } else { return nullptr; } } - std::optional GetColumnIndexLocation() const { + std::optional getColumnIndexLocation() const { if (column_->__isset.column_index_offset && column_->__isset.column_index_length) { return IndexLocation{ @@ -380,7 +383,7 @@ class ColumnChunkMetaData::ColumnChunkMetaDataImpl { return std::nullopt; } - std::optional GetOffsetIndexLocation() const { + std::optional getOffsetIndexLocation() const { if (column_->__isset.offset_index_offset && column_->__isset.offset_index_length) { return IndexLocation{ @@ -389,521 +392,529 @@ class ColumnChunkMetaData::ColumnChunkMetaDataImpl { return std::nullopt; } + inline int32_t fieldId() const { + return descr_->schemaNode()->fieldId(); + } + private: - mutable std::shared_ptr possible_stats_; + mutable std::shared_ptr<::facebook::velox::parquet::arrow::Statistics> + possibleStats_; std::vector encodings_; - std::vector encoding_stats_; + std::vector encodingStats_; const facebook::velox::parquet::thrift::ColumnChunk* column_; - const facebook::velox::parquet::thrift::ColumnMetaData* column_metadata_; - facebook::velox::parquet::thrift::ColumnMetaData decrypted_metadata_; + const facebook::velox::parquet::thrift::ColumnMetaData* columnMetadata_; + facebook::velox::parquet::thrift::ColumnMetaData decryptedMetadata_; const ColumnDescriptor* descr_; const ReaderProperties properties_; - const ApplicationVersion* writer_version_; + const ApplicationVersion* writerVersion_; }; -std::unique_ptr ColumnChunkMetaData::Make( +std::unique_ptr ColumnChunkMetaData::make( const void* metadata, const ColumnDescriptor* descr, const ReaderProperties& properties, - const ApplicationVersion* writer_version, - int16_t row_group_ordinal, - int16_t column_ordinal, - std::shared_ptr file_decryptor) { + const ApplicationVersion* writerVersion, + int16_t rowGroupOrdinal, + int16_t columnOrdinal, + std::shared_ptr fileDecryptor) { return std::unique_ptr(new ColumnChunkMetaData( metadata, descr, - row_group_ordinal, - column_ordinal, + rowGroupOrdinal, + columnOrdinal, properties, - writer_version, - std::move(file_decryptor))); + writerVersion, + std::move(fileDecryptor))); } -std::unique_ptr ColumnChunkMetaData::Make( +std::unique_ptr ColumnChunkMetaData::make( const void* metadata, const ColumnDescriptor* descr, - const ApplicationVersion* writer_version, - int16_t row_group_ordinal, - int16_t column_ordinal, - std::shared_ptr file_decryptor) { + const ApplicationVersion* writerVersion, + int16_t rowGroupOrdinal, + int16_t columnOrdinal, + std::shared_ptr fileDecryptor) { return std::unique_ptr(new ColumnChunkMetaData( metadata, descr, - row_group_ordinal, - column_ordinal, - default_reader_properties(), - writer_version, - std::move(file_decryptor))); + rowGroupOrdinal, + columnOrdinal, + defaultReaderProperties(), + writerVersion, + std::move(fileDecryptor))); } ColumnChunkMetaData::ColumnChunkMetaData( const void* metadata, const ColumnDescriptor* descr, - int16_t row_group_ordinal, - int16_t column_ordinal, + int16_t rowGroupOrdinal, + int16_t columnOrdinal, const ReaderProperties& properties, - const ApplicationVersion* writer_version, - std::shared_ptr file_decryptor) + const ApplicationVersion* writerVersion, + std::shared_ptr fileDecryptor) : impl_{new ColumnChunkMetaDataImpl( reinterpret_cast< const facebook::velox::parquet::thrift::ColumnChunk*>(metadata), descr, - row_group_ordinal, - column_ordinal, + rowGroupOrdinal, + columnOrdinal, properties, - writer_version, - std::move(file_decryptor))} {} + writerVersion, + std::move(fileDecryptor))} {} ColumnChunkMetaData::~ColumnChunkMetaData() = default; -// column chunk -int64_t ColumnChunkMetaData::file_offset() const { - return impl_->file_offset(); +// Column chunk. +int64_t ColumnChunkMetaData::fileOffset() const { + return impl_->fileOffset(); } -const std::string& ColumnChunkMetaData::file_path() const { - return impl_->file_path(); +const std::string& ColumnChunkMetaData::filePath() const { + return impl_->filePath(); } Type::type ColumnChunkMetaData::type() const { return impl_->type(); } -int64_t ColumnChunkMetaData::num_values() const { - return impl_->num_values(); +int64_t ColumnChunkMetaData::numValues() const { + return impl_->numValues(); } -std::shared_ptr ColumnChunkMetaData::path_in_schema() - const { - return impl_->path_in_schema(); +std::shared_ptr ColumnChunkMetaData::pathInSchema() const { + return impl_->pathInSchema(); } std::shared_ptr ColumnChunkMetaData::statistics() const { return impl_->statistics(); } -bool ColumnChunkMetaData::is_stats_set() const { - return impl_->is_stats_set(); +bool ColumnChunkMetaData::isStatsSet() const { + return impl_->isStatsSet(); } -std::optional ColumnChunkMetaData::bloom_filter_offset() const { - return impl_->bloom_filter_offset(); +std::optional ColumnChunkMetaData::bloomFilterOffset() const { + return impl_->bloomFilterOffset(); } -bool ColumnChunkMetaData::has_dictionary_page() const { - return impl_->has_dictionary_page(); +bool ColumnChunkMetaData::hasDictionaryPage() const { + return impl_->hasDictionaryPage(); } -int64_t ColumnChunkMetaData::dictionary_page_offset() const { - return impl_->dictionary_page_offset(); +int64_t ColumnChunkMetaData::dictionaryPageOffset() const { + return impl_->dictionaryPageOffset(); } -int64_t ColumnChunkMetaData::data_page_offset() const { - return impl_->data_page_offset(); +int64_t ColumnChunkMetaData::dataPageOffset() const { + return impl_->dataPageOffset(); } -bool ColumnChunkMetaData::has_index_page() const { - return impl_->has_index_page(); +bool ColumnChunkMetaData::hasIndexPage() const { + return impl_->hasIndexPage(); } -int64_t ColumnChunkMetaData::index_page_offset() const { - return impl_->index_page_offset(); +int64_t ColumnChunkMetaData::indexPageOffset() const { + return impl_->indexPageOffset(); } Compression::type ColumnChunkMetaData::compression() const { return impl_->compression(); } -bool ColumnChunkMetaData::can_decompress() const { - return util::Codec::IsAvailable(compression()); +bool ColumnChunkMetaData::canDecompress() const { + return util::Codec::isAvailable(compression()); } const std::vector& ColumnChunkMetaData::encodings() const { return impl_->encodings(); } -const std::vector& ColumnChunkMetaData::encoding_stats() +const std::vector& ColumnChunkMetaData::encodingStats() const { - return impl_->encoding_stats(); + return impl_->encodingStats(); } -int64_t ColumnChunkMetaData::total_uncompressed_size() const { - return impl_->total_uncompressed_size(); +int64_t ColumnChunkMetaData::totalUncompressedSize() const { + return impl_->totalUncompressedSize(); } -int64_t ColumnChunkMetaData::total_compressed_size() const { - return impl_->total_compressed_size(); +int64_t ColumnChunkMetaData::totalCompressedSize() const { + return impl_->totalCompressedSize(); } -std::unique_ptr ColumnChunkMetaData::crypto_metadata() +int32_t ColumnChunkMetaData::fieldId() const { + return impl_->fieldId(); +} + +std::unique_ptr ColumnChunkMetaData::cryptoMetadata() const { - return impl_->crypto_metadata(); + return impl_->cryptoMetadata(); } -std::optional ColumnChunkMetaData::GetColumnIndexLocation() +std::optional ColumnChunkMetaData::getColumnIndexLocation() const { - return impl_->GetColumnIndexLocation(); + return impl_->getColumnIndexLocation(); } -std::optional ColumnChunkMetaData::GetOffsetIndexLocation() +std::optional ColumnChunkMetaData::getOffsetIndexLocation() const { - return impl_->GetOffsetIndexLocation(); + return impl_->getOffsetIndexLocation(); } -bool ColumnChunkMetaData::Equals(const ColumnChunkMetaData& other) const { - return impl_->Equals(*other.impl_); +bool ColumnChunkMetaData::equals(const ColumnChunkMetaData& other) const { + return impl_->equals(*other.impl_); } -// row-group metadata +// Row-group metadata. class RowGroupMetaData::RowGroupMetaDataImpl { public: explicit RowGroupMetaDataImpl( - const facebook::velox::parquet::thrift::RowGroup* row_group, + const facebook::velox::parquet::thrift::RowGroup* rowGroup, const SchemaDescriptor* schema, const ReaderProperties& properties, - const ApplicationVersion* writer_version, - std::shared_ptr file_decryptor) - : row_group_(row_group), + const ApplicationVersion* writerVersion, + std::shared_ptr fileDecryptor) + : rowGroup_(rowGroup), schema_(schema), properties_(properties), - writer_version_(writer_version), - file_decryptor_(std::move(file_decryptor)) { + writerVersion_(writerVersion), + fileDecryptor_(std::move(fileDecryptor)) { if (ARROW_PREDICT_FALSE( - row_group_->columns.size() > + rowGroup_->columns.size() > static_cast(std::numeric_limits::max()))) { throw ParquetException( - "Row group had too many columns: ", row_group_->columns.size()); + "Row group had too many columns: ", rowGroup_->columns.size()); } } - bool Equals(const RowGroupMetaDataImpl& other) const { - return *row_group_ == *other.row_group_; + bool equals(const RowGroupMetaDataImpl& other) const { + return *rowGroup_ == *other.rowGroup_; } - inline int num_columns() const { - return static_cast(row_group_->columns.size()); + inline int numColumns() const { + return static_cast(rowGroup_->columns.size()); } - inline int64_t num_rows() const { - return row_group_->num_rows; + inline int64_t numRows() const { + return rowGroup_->num_rows; } - inline int64_t total_byte_size() const { - return row_group_->total_byte_size; + inline int64_t totalByteSize() const { + return rowGroup_->total_byte_size; } - inline int64_t total_compressed_size() const { - return row_group_->total_compressed_size; + inline int64_t totalCompressedSize() const { + return rowGroup_->total_compressed_size; } - inline int64_t file_offset() const { - return row_group_->file_offset; + inline int64_t fileOffset() const { + return rowGroup_->file_offset; } inline const SchemaDescriptor* schema() const { return schema_; } - std::unique_ptr ColumnChunk(int i) { - if (i >= 0 && i < num_columns()) { - return ColumnChunkMetaData::Make( - &row_group_->columns[i], - schema_->Column(i), + std::unique_ptr columnChunk(int i) { + if (i >= 0 && i < numColumns()) { + return ColumnChunkMetaData::make( + &rowGroup_->columns[i], + schema_->column(i), properties_, - writer_version_, - row_group_->ordinal, + writerVersion_, + rowGroup_->ordinal, i, - file_decryptor_); + fileDecryptor_); } throw ParquetException( "The file only has ", - num_columns(), + numColumns(), " columns, requested metadata for column: ", i); } - std::vector sorting_columns() const { - std::vector sorting_columns; - if (!row_group_->__isset.sorting_columns) { - return sorting_columns; + std::vector sortingColumns() const { + std::vector sortingColumns; + if (!rowGroup_->__isset.sorting_columns) { + return sortingColumns; } - sorting_columns.resize(row_group_->sorting_columns.size()); - for (size_t i = 0; i < sorting_columns.size(); ++i) { - sorting_columns[i] = FromThrift(row_group_->sorting_columns[i]); + sortingColumns.resize(rowGroup_->sorting_columns.size()); + for (size_t i = 0; i < sortingColumns.size(); ++i) { + sortingColumns[i] = fromThrift(rowGroup_->sorting_columns[i]); } - return sorting_columns; + return sortingColumns; } private: - const facebook::velox::parquet::thrift::RowGroup* row_group_; + const facebook::velox::parquet::thrift::RowGroup* rowGroup_; const SchemaDescriptor* schema_; const ReaderProperties properties_; - const ApplicationVersion* writer_version_; - std::shared_ptr file_decryptor_; + const ApplicationVersion* writerVersion_; + std::shared_ptr fileDecryptor_; }; -std::unique_ptr RowGroupMetaData::Make( +std::unique_ptr RowGroupMetaData::make( const void* metadata, const SchemaDescriptor* schema, - const ApplicationVersion* writer_version, - std::shared_ptr file_decryptor) { + const ApplicationVersion* writerVersion, + std::shared_ptr fileDecryptor) { return std::unique_ptr(new RowGroupMetaData( metadata, schema, - default_reader_properties(), - writer_version, - std::move(file_decryptor))); + defaultReaderProperties(), + writerVersion, + std::move(fileDecryptor))); } -std::unique_ptr RowGroupMetaData::Make( +std::unique_ptr RowGroupMetaData::make( const void* metadata, const SchemaDescriptor* schema, const ReaderProperties& properties, - const ApplicationVersion* writer_version, - std::shared_ptr file_decryptor) { + const ApplicationVersion* writerVersion, + std::shared_ptr fileDecryptor) { return std::unique_ptr(new RowGroupMetaData( - metadata, schema, properties, writer_version, std::move(file_decryptor))); + metadata, schema, properties, writerVersion, std::move(fileDecryptor))); } RowGroupMetaData::RowGroupMetaData( const void* metadata, const SchemaDescriptor* schema, const ReaderProperties& properties, - const ApplicationVersion* writer_version, - std::shared_ptr file_decryptor) + const ApplicationVersion* writerVersion, + std::shared_ptr fileDecryptor) : impl_{new RowGroupMetaDataImpl( reinterpret_cast( metadata), schema, properties, - writer_version, - std::move(file_decryptor))} {} + writerVersion, + std::move(fileDecryptor))} {} RowGroupMetaData::~RowGroupMetaData() = default; -bool RowGroupMetaData::Equals(const RowGroupMetaData& other) const { - return impl_->Equals(*other.impl_); +bool RowGroupMetaData::equals(const RowGroupMetaData& other) const { + return impl_->equals(*other.impl_); } -int RowGroupMetaData::num_columns() const { - return impl_->num_columns(); +int RowGroupMetaData::numColumns() const { + return impl_->numColumns(); } -int64_t RowGroupMetaData::num_rows() const { - return impl_->num_rows(); +int64_t RowGroupMetaData::numRows() const { + return impl_->numRows(); } -int64_t RowGroupMetaData::total_byte_size() const { - return impl_->total_byte_size(); +int64_t RowGroupMetaData::totalByteSize() const { + return impl_->totalByteSize(); } -int64_t RowGroupMetaData::total_compressed_size() const { - return impl_->total_compressed_size(); +int64_t RowGroupMetaData::totalCompressedSize() const { + return impl_->totalCompressedSize(); } -int64_t RowGroupMetaData::file_offset() const { - return impl_->file_offset(); +int64_t RowGroupMetaData::fileOffset() const { + return impl_->fileOffset(); } const SchemaDescriptor* RowGroupMetaData::schema() const { return impl_->schema(); } -std::unique_ptr RowGroupMetaData::ColumnChunk( +std::unique_ptr RowGroupMetaData::columnChunk( int i) const { - return impl_->ColumnChunk(i); + return impl_->columnChunk(i); } -bool RowGroupMetaData::can_decompress() const { - int n_columns = num_columns(); - for (int i = 0; i < n_columns; i++) { - if (!ColumnChunk(i)->can_decompress()) { +bool RowGroupMetaData::canDecompress() const { + int nColumns = numColumns(); + for (int i = 0; i < nColumns; i++) { + if (!columnChunk(i)->canDecompress()) { return false; } } return true; } -std::vector RowGroupMetaData::sorting_columns() const { - return impl_->sorting_columns(); +std::vector RowGroupMetaData::sortingColumns() const { + return impl_->sortingColumns(); } -// file metadata +// File metadata. class FileMetaData::FileMetaDataImpl { public: FileMetaDataImpl() = default; explicit FileMetaDataImpl( const void* metadata, - uint32_t* metadata_len, + uint32_t* metadataLen, ReaderProperties properties, - std::shared_ptr file_decryptor = nullptr) + std::shared_ptr fileDecryptor = nullptr) : properties_(std::move(properties)), - file_decryptor_(std::move(file_decryptor)) { + fileDecryptor_(std::move(fileDecryptor)) { metadata_ = std::make_unique(); - auto footer_decryptor = file_decryptor_ != nullptr - ? file_decryptor_->GetFooterDecryptor() + auto footerDecryptor = fileDecryptor_ != nullptr + ? fileDecryptor_->getFooterDecryptor() : nullptr; ThriftDeserializer deserializer(properties_); - deserializer.DeserializeMessage( + deserializer.deserializeMessage( reinterpret_cast(metadata), - metadata_len, + metadataLen, metadata_.get(), - footer_decryptor); - metadata_len_ = *metadata_len; + footerDecryptor); + metadataLen_ = *metadataLen; if (metadata_->__isset.created_by) { - writer_version_ = ApplicationVersion(metadata_->created_by); + writerVersion_ = ApplicationVersion(metadata_->created_by); } else { - writer_version_ = ApplicationVersion("unknown 0.0.0"); + writerVersion_ = ApplicationVersion("unknown 0.0.0"); } - InitSchema(); - InitColumnOrders(); - InitKeyValueMetadata(); + initSchema(); + initColumnOrders(); + initKeyValueMetadata(); } - bool VerifySignature(const void* signature) { - // verify decryption properties are set - if (file_decryptor_ == nullptr) { + bool verifySignature(const void* signature) { + // Verify decryption properties are set. + if (fileDecryptor_ == nullptr) { throw ParquetException( - "Decryption not set properly. cannot verify signature"); + "Decryption not set properly. Cannot verify signature."); } - // serialize the footer - uint8_t* serialized_data; - uint32_t serialized_len = metadata_len_; + // Serialize the footer. + uint8_t* serializedData; + uint32_t serializedLen = metadataLen_; ThriftSerializer serializer; - serializer.SerializeToBuffer( - metadata_.get(), &serialized_len, &serialized_data); + serializer.serializeToBuffer( + metadata_.get(), &serializedLen, &serializedData); - // encrypt with nonce + // Encrypt with nonce. auto nonce = const_cast(reinterpret_cast(signature)); auto tag = const_cast(reinterpret_cast(signature)) + encryption::kNonceLength; - std::string key = file_decryptor_->GetFooterKey(); - std::string aad = encryption::CreateFooterAad(file_decryptor_->file_aad()); + std::string key = fileDecryptor_->getFooterKey(); + std::string aad = encryption::createFooterAad(fileDecryptor_->fileAad()); - auto aes_encryptor = encryption::AesEncryptor::Make( - file_decryptor_->algorithm(), + auto aesEncryptor = encryption::AesEncryptor::make( + fileDecryptor_->algorithm(), static_cast(key.size()), true, false /*write_length*/, nullptr); - std::shared_ptr encrypted_buffer = - std::static_pointer_cast(AllocateBuffer( - file_decryptor_->pool(), - aes_encryptor->CiphertextSizeDelta() + serialized_len)); - uint32_t encrypted_len = aes_encryptor->SignedFooterEncrypt( - serialized_data, - serialized_len, + std::shared_ptr encryptedBuffer = + std::static_pointer_cast(allocateBuffer( + fileDecryptor_->pool(), + aesEncryptor->ciphertextSizeDelta() + serializedLen)); + uint32_t encryptedLen = aesEncryptor->signedFooterEncrypt( + serializedData, + serializedLen, str2bytes(key), static_cast(key.size()), str2bytes(aad), static_cast(aad.size()), nonce, - encrypted_buffer->mutable_data()); - // Delete AES encryptor object. It was created only to verify the footer - // signature. - aes_encryptor->WipeOut(); - delete aes_encryptor; + encryptedBuffer->mutable_data()); + // Delete AES encryptor object. It was created only to verify the footer. + // Signature. + aesEncryptor->wipeOut(); + delete aesEncryptor; return 0 == - memcmp(encrypted_buffer->data() + encrypted_len - + memcmp(encryptedBuffer->data() + encryptedLen - encryption::kGcmTagLength, tag, encryption::kGcmTagLength); } inline uint32_t size() const { - return metadata_len_; + return metadataLen_; } - inline int num_columns() const { - return schema_.num_columns(); + inline int numColumns() const { + return schema_.numColumns(); } - inline int64_t num_rows() const { + inline int64_t numRows() const { return metadata_->num_rows; } - inline int num_row_groups() const { + inline int numRowGroups() const { return static_cast(metadata_->row_groups.size()); } inline int32_t version() const { return metadata_->version; } - inline const std::string& created_by() const { + inline const std::string& createdBy() const { return metadata_->created_by; } - inline int num_schema_elements() const { + inline int numSchemaElements() const { return static_cast(metadata_->schema.size()); } - inline bool is_encryption_algorithm_set() const { + inline bool isEncryptionAlgorithmSet() const { return metadata_->__isset.encryption_algorithm; } - inline EncryptionAlgorithm encryption_algorithm() { - return FromThrift(metadata_->encryption_algorithm); + inline EncryptionAlgorithm encryptionAlgorithm() { + return fromThrift(metadata_->encryption_algorithm); } - inline const std::string& footer_signing_key_metadata() { + inline const std::string& footerSigningKeyMetadata() { return metadata_->footer_signing_key_metadata; } - const ApplicationVersion& writer_version() const { - return writer_version_; + const ApplicationVersion& writerVersion() const { + return writerVersion_; } - void WriteTo( + void writeTo( ::arrow::io::OutputStream* dst, const std::shared_ptr& encryptor) const { ThriftSerializer serializer; - // Only in encrypted files with plaintext footers the - // encryption_algorithm is set in footer - if (is_encryption_algorithm_set()) { - uint8_t* serialized_data; - uint32_t serialized_len; - serializer.SerializeToBuffer( - metadata_.get(), &serialized_len, &serialized_data); - - // encrypt the footer key - std::vector encrypted_data( - encryptor->CiphertextSizeDelta() + serialized_len); - unsigned encrypted_len = encryptor->Encrypt( - serialized_data, serialized_len, encrypted_data.data()); - - // write unencrypted footer - PARQUET_THROW_NOT_OK(dst->Write(serialized_data, serialized_len)); + // Only in encrypted files with plaintext footers the. + // Encryption_algorithm is set in footer. + if (isEncryptionAlgorithmSet()) { + uint8_t* serializedData; + uint32_t serializedLen; + serializer.serializeToBuffer( + metadata_.get(), &serializedLen, &serializedData); + + // Encrypt the footer key. + std::vector encryptedData( + encryptor->ciphertextSizeDelta() + serializedLen); + unsigned encryptedLen = encryptor->encrypt( + serializedData, serializedLen, encryptedData.data()); + + // Write unencrypted footer. + PARQUET_THROW_NOT_OK(dst->Write(serializedData, serializedLen)); // Write signature (nonce and tag) PARQUET_THROW_NOT_OK( - dst->Write(encrypted_data.data() + 4, encryption::kNonceLength)); + dst->Write(encryptedData.data() + 4, encryption::kNonceLength)); PARQUET_THROW_NOT_OK(dst->Write( - encrypted_data.data() + encrypted_len - encryption::kGcmTagLength, + encryptedData.data() + encryptedLen - encryption::kGcmTagLength, encryption::kGcmTagLength)); } else { // either plaintext file (when encryptor is null) - // or encrypted file with encrypted footer - serializer.Serialize(metadata_.get(), dst, encryptor); + // Or encrypted file with encrypted footer. + serializer.serialize(metadata_.get(), dst, encryptor); } } - std::unique_ptr RowGroup(int i) { - if (!(i >= 0 && i < num_row_groups())) { + std::unique_ptr rowGroup(int i) { + if (!(i >= 0 && i < numRowGroups())) { std::stringstream ss; - ss << "The file only has " << num_row_groups() + ss << "The file only has " << numRowGroups() << " row groups, requested metadata for row group: " << i; throw ParquetException(ss.str()); } - return RowGroupMetaData::Make( + return RowGroupMetaData::make( &metadata_->row_groups[i], &schema_, properties_, - &writer_version_, - file_decryptor_); + &writerVersion_, + fileDecryptor_); } - bool Equals(const FileMetaDataImpl& other) const { + bool equals(const FileMetaDataImpl& other) const { return *metadata_ == *other.metadata_; } @@ -911,61 +922,60 @@ class FileMetaData::FileMetaDataImpl { return &schema_; } - const std::shared_ptr& key_value_metadata() const { - return key_value_metadata_; + const std::shared_ptr& keyValueMetadata() const { + return keyValueMetadata_; } - void set_file_path(const std::string& path) { - for (facebook::velox::parquet::thrift::RowGroup& row_group : + void setFilePath(const std::string& path) { + for (facebook::velox::parquet::thrift::RowGroup& rowGroup : metadata_->row_groups) { for (facebook::velox::parquet::thrift::ColumnChunk& chunk : - row_group.columns) { + rowGroup.columns) { chunk.__set_file_path(path); } } } - facebook::velox::parquet::thrift::RowGroup& row_group(int i) { - if (!(i >= 0 && i < num_row_groups())) { + facebook::velox::parquet::thrift::RowGroup& thriftRowGroup(int i) { + if (!(i >= 0 && i < numRowGroups())) { std::stringstream ss; - ss << "The file only has " << num_row_groups() + ss << "The file only has " << numRowGroups() << " row groups, requested metadata for row group: " << i; throw ParquetException(ss.str()); } return metadata_->row_groups[i]; } - void AppendRowGroups(const std::unique_ptr& other) { - std::ostringstream diff_output; - if (!schema()->Equals(*other->schema(), &diff_output)) { - auto msg = - "AppendRowGroups requires equal schemas.\n" + diff_output.str(); + void appendRowGroups(const std::unique_ptr& other) { + std::ostringstream diffOutput; + if (!schema()->equals(*other->schema(), &diffOutput)) { + auto msg = "AppendRowGroups requires equal schemas.\n" + diffOutput.str(); throw ParquetException(msg); } - // ARROW-13654: `other` may point to self, be careful not to enter an - // infinite loop - const int n = other->num_row_groups(); - // ARROW-16613: do not use reserve() as that may suppress overallocation - // and incur O(n²) behavior on repeated calls to AppendRowGroups(). - // (see https://en.cppreference.com/w/cpp/container/vector/reserve - // about inappropriate uses of reserve()). + // ARROW-13654: `other` may point to self, be careful not to enter an. + // Infinite loop. + const int n = other->numRowGroups(); + // ARROW-16613: do not use reserve() as that may suppress overallocation. + // And incur O(n²) behavior on repeated calls to AppendRowGroups(). + // (See https://en.cppreference.com/w/cpp/container/vector/reserve. + // About inappropriate uses of reserve()). const auto start = metadata_->row_groups.size(); metadata_->row_groups.resize(start + n); for (int i = 0; i < n; i++) { - metadata_->row_groups[start + i] = other->row_group(i); + metadata_->row_groups[start + i] = other->thriftRowGroup(i); metadata_->num_rows += metadata_->row_groups[start + i].num_rows; } } - std::shared_ptr Subset(const std::vector& row_groups) { - for (int i : row_groups) { - if (i < num_row_groups()) + std::shared_ptr subset(const std::vector& rowGroups) { + for (int i : rowGroups) { + if (i < numRowGroups()) continue; throw ParquetException( "The file only has ", - num_row_groups(), + numRowGroups(), " row groups, but requested a subset including row group: ", i); } @@ -979,11 +989,11 @@ class FileMetaData::FileMetaDataImpl { metadata->version = metadata_->version; metadata->schema = metadata_->schema; - metadata->row_groups.resize(row_groups.size()); + metadata->row_groups.resize(rowGroups.size()); int i = 0; - for (int selected_index : row_groups) { - metadata->num_rows += row_group(selected_index).num_rows; - metadata->row_groups[i++] = row_group(selected_index); + for (int selectedIndex : rowGroups) { + metadata->num_rows += thriftRowGroup(selectedIndex).num_rows; + metadata->row_groups[i++] = thriftRowGroup(selectedIndex); } metadata->key_value_metadata = metadata_->key_value_metadata; @@ -995,56 +1005,75 @@ class FileMetaData::FileMetaDataImpl { metadata->__isset = metadata_->__isset; out->impl_->schema_ = schema_; - out->impl_->writer_version_ = writer_version_; - out->impl_->key_value_metadata_ = key_value_metadata_; - out->impl_->file_decryptor_ = file_decryptor_; + out->impl_->writerVersion_ = writerVersion_; + out->impl_->keyValueMetadata_ = keyValueMetadata_; + out->impl_->fileDecryptor_ = fileDecryptor_; return out; } - void set_file_decryptor( - std::shared_ptr file_decryptor) { - file_decryptor_ = file_decryptor; + void setFileDecryptor(std::shared_ptr fileDecryptor) { + fileDecryptor_ = fileDecryptor; + } + + // Set NaN counts from the builder (called during Finish) + // This stores total NaN counts per field ID across all row groups. + void setNaNCounts( + std::unordered_map> nan_counts) { + fieldNanCounts_ = std::move(nan_counts); + } + + // Get total NaN count for a specific field ID across all row groups. + std::pair getNaNCount(int32_t fieldId) const { + auto it = fieldNanCounts_.find(fieldId); + if (it != fieldNanCounts_.end()) { + return it->second; + } + return {0, false}; } private: friend FileMetaDataBuilder; - uint32_t metadata_len_ = 0; + uint32_t metadataLen_ = 0; std::unique_ptr metadata_; SchemaDescriptor schema_; - ApplicationVersion writer_version_; - std::shared_ptr key_value_metadata_; + ApplicationVersion writerVersion_; + std::shared_ptr keyValueMetadata_; const ReaderProperties properties_; - std::shared_ptr file_decryptor_; + std::shared_ptr fileDecryptor_; + // Total NaN counts per field ID across all row groups: field_id -> + // (nan_count, has_nan_count). + std::unordered_map> fieldNanCounts_; - void InitSchema() { + void initSchema() { if (metadata_->schema.empty()) { throw ParquetException("Empty file schema (no root)"); } - schema_.Init(schema::Unflatten( - &metadata_->schema[0], static_cast(metadata_->schema.size()))); + schema_.init( + schema::unflatten( + &metadata_->schema[0], static_cast(metadata_->schema.size()))); } - void InitColumnOrders() { - // update ColumnOrder - std::vector column_orders; + void initColumnOrders() { + // Update ColumnOrder. + std::vector columnOrders; if (metadata_->__isset.column_orders) { - column_orders.reserve(metadata_->column_orders.size()); - for (auto column_order : metadata_->column_orders) { - if (column_order.__isset.TYPE_ORDER) { - column_orders.push_back(ColumnOrder::type_defined_); + columnOrders.reserve(metadata_->column_orders.size()); + for (auto columnOrder : metadata_->column_orders) { + if (columnOrder.__isset.TYPE_ORDER) { + columnOrders.push_back(ColumnOrder::typeDefined_); } else { - column_orders.push_back(ColumnOrder::undefined_); + columnOrders.push_back(ColumnOrder::undefined_); } } } else { - column_orders.resize(schema_.num_columns(), ColumnOrder::undefined_); + columnOrders.resize(schema_.numColumns(), ColumnOrder::undefined_); } - schema_.updateColumnOrders(column_orders); + schema_.updateColumnOrders(columnOrders); } - void InitKeyValueMetadata() { + void initKeyValueMetadata() { std::shared_ptr metadata = nullptr; if (metadata_->__isset.key_value_metadata) { metadata = std::make_shared(); @@ -1052,96 +1081,96 @@ class FileMetaData::FileMetaDataImpl { metadata->Append(it.key, it.value); } } - key_value_metadata_ = std::move(metadata); + keyValueMetadata_ = std::move(metadata); } }; -std::shared_ptr FileMetaData::Make( +std::shared_ptr FileMetaData::make( const void* metadata, - uint32_t* metadata_len, + uint32_t* metadataLen, const ReaderProperties& properties, - std::shared_ptr file_decryptor) { - // This FileMetaData ctor is private, not compatible with std::make_shared + std::shared_ptr fileDecryptor) { + // This FileMetaData ctor is private, not compatible with std::make_shared. return std::shared_ptr(new FileMetaData( - metadata, metadata_len, properties, std::move(file_decryptor))); + metadata, metadataLen, properties, std::move(fileDecryptor))); } -std::shared_ptr FileMetaData::Make( +std::shared_ptr FileMetaData::make( const void* metadata, - uint32_t* metadata_len, - std::shared_ptr file_decryptor) { + uint32_t* metadataLen, + std::shared_ptr fileDecryptor) { return std::shared_ptr(new FileMetaData( - metadata, metadata_len, default_reader_properties(), file_decryptor)); + metadata, metadataLen, defaultReaderProperties(), fileDecryptor)); } FileMetaData::FileMetaData( const void* metadata, - uint32_t* metadata_len, + uint32_t* metadataLen, const ReaderProperties& properties, - std::shared_ptr file_decryptor) + std::shared_ptr fileDecryptor) : impl_(new FileMetaDataImpl( metadata, - metadata_len, + metadataLen, properties, - file_decryptor)) {} + fileDecryptor)) {} FileMetaData::FileMetaData() : impl_(new FileMetaDataImpl()) {} FileMetaData::~FileMetaData() = default; -bool FileMetaData::Equals(const FileMetaData& other) const { - return impl_->Equals(*other.impl_); +bool FileMetaData::equals(const FileMetaData& other) const { + return impl_->equals(*other.impl_); } -std::unique_ptr FileMetaData::RowGroup(int i) const { - return impl_->RowGroup(i); +std::unique_ptr FileMetaData::rowGroup(int i) const { + return impl_->rowGroup(i); } -bool FileMetaData::VerifySignature(const void* signature) { - return impl_->VerifySignature(signature); +bool FileMetaData::verifySignature(const void* signature) { + return impl_->verifySignature(signature); } uint32_t FileMetaData::size() const { return impl_->size(); } -int FileMetaData::num_columns() const { - return impl_->num_columns(); +int FileMetaData::numColumns() const { + return impl_->numColumns(); } -int64_t FileMetaData::num_rows() const { - return impl_->num_rows(); +int64_t FileMetaData::numRows() const { + return impl_->numRows(); } -int FileMetaData::num_row_groups() const { - return impl_->num_row_groups(); +int FileMetaData::numRowGroups() const { + return impl_->numRowGroups(); } -bool FileMetaData::can_decompress() const { - int n_row_groups = num_row_groups(); - for (int i = 0; i < n_row_groups; i++) { - if (!RowGroup(i)->can_decompress()) { +bool FileMetaData::canDecompress() const { + int nRowGroups = numRowGroups(); + for (int i = 0; i < nRowGroups; i++) { + if (!rowGroup(i)->canDecompress()) { return false; } } return true; } -bool FileMetaData::is_encryption_algorithm_set() const { - return impl_->is_encryption_algorithm_set(); +bool FileMetaData::isEncryptionAlgorithmSet() const { + return impl_->isEncryptionAlgorithmSet(); } -EncryptionAlgorithm FileMetaData::encryption_algorithm() const { - return impl_->encryption_algorithm(); +EncryptionAlgorithm FileMetaData::encryptionAlgorithm() const { + return impl_->encryptionAlgorithm(); } -const std::string& FileMetaData::footer_signing_key_metadata() const { - return impl_->footer_signing_key_metadata(); +const std::string& FileMetaData::footerSigningKeyMetadata() const { + return impl_->footerSigningKeyMetadata(); } -void FileMetaData::set_file_decryptor( - std::shared_ptr file_decryptor) { - impl_->set_file_decryptor(file_decryptor); +void FileMetaData::setFileDecryptor( + std::shared_ptr fileDecryptor) { + impl_->setFileDecryptor(fileDecryptor); } ParquetVersion::type FileMetaData::version() const { @@ -1151,50 +1180,54 @@ ParquetVersion::type FileMetaData::version() const { case 2: return ParquetVersion::PARQUET_2_LATEST; default: - // Improperly set version, assuming Parquet 1.0 + // Improperly set version, assuming Parquet 1.0. break; } return ParquetVersion::PARQUET_1_0; } -const ApplicationVersion& FileMetaData::writer_version() const { - return impl_->writer_version(); +const ApplicationVersion& FileMetaData::writerVersion() const { + return impl_->writerVersion(); } -const std::string& FileMetaData::created_by() const { - return impl_->created_by(); +const std::string& FileMetaData::createdBy() const { + return impl_->createdBy(); } -int FileMetaData::num_schema_elements() const { - return impl_->num_schema_elements(); +int FileMetaData::numSchemaElements() const { + return impl_->numSchemaElements(); } const SchemaDescriptor* FileMetaData::schema() const { return impl_->schema(); } -const std::shared_ptr& -FileMetaData::key_value_metadata() const { - return impl_->key_value_metadata(); +const std::shared_ptr& FileMetaData::keyValueMetadata() + const { + return impl_->keyValueMetadata(); } -void FileMetaData::set_file_path(const std::string& path) { - impl_->set_file_path(path); +void FileMetaData::setFilePath(const std::string& path) { + impl_->setFilePath(path); } -void FileMetaData::AppendRowGroups(const FileMetaData& other) { - impl_->AppendRowGroups(other.impl_); +void FileMetaData::appendRowGroups(const FileMetaData& other) { + impl_->appendRowGroups(other.impl_); } -std::shared_ptr FileMetaData::Subset( - const std::vector& row_groups) const { - return impl_->Subset(row_groups); +std::shared_ptr FileMetaData::subset( + const std::vector& rowGroups) const { + return impl_->subset(rowGroups); } -void FileMetaData::WriteTo( +std::pair FileMetaData::getNaNCount(int32_t fieldId) const { + return impl_->getNaNCount(fieldId); +} + +void FileMetaData::writeTo( ::arrow::io::OutputStream* dst, const std::shared_ptr& encryptor) const { - return impl_->WriteTo(dst, encryptor); + return impl_->writeTo(dst, encryptor); } class FileCryptoMetaData::FileCryptoMetaDataImpl { @@ -1203,55 +1236,55 @@ class FileCryptoMetaData::FileCryptoMetaDataImpl { explicit FileCryptoMetaDataImpl( const uint8_t* metadata, - uint32_t* metadata_len, + uint32_t* metadataLen, const ReaderProperties& properties) { ThriftDeserializer deserializer(properties); - deserializer.DeserializeMessage(metadata, metadata_len, &metadata_); - metadata_len_ = *metadata_len; + deserializer.deserializeMessage(metadata, metadataLen, &metadata_); + metadataLen_ = *metadataLen; } - EncryptionAlgorithm encryption_algorithm() const { - return FromThrift(metadata_.encryption_algorithm); + EncryptionAlgorithm encryptionAlgorithm() const { + return fromThrift(metadata_.encryption_algorithm); } - const std::string& key_metadata() const { + const std::string& keyMetadata() const { return metadata_.key_metadata; } - void WriteTo(::arrow::io::OutputStream* dst) const { + void writeTo(::arrow::io::OutputStream* dst) const { ThriftSerializer serializer; - serializer.Serialize(&metadata_, dst); + serializer.serialize(&metadata_, dst); } private: friend FileMetaDataBuilder; facebook::velox::parquet::thrift::FileCryptoMetaData metadata_; - uint32_t metadata_len_; + uint32_t metadataLen_; }; -EncryptionAlgorithm FileCryptoMetaData::encryption_algorithm() const { - return impl_->encryption_algorithm(); +EncryptionAlgorithm FileCryptoMetaData::encryptionAlgorithm() const { + return impl_->encryptionAlgorithm(); } -const std::string& FileCryptoMetaData::key_metadata() const { - return impl_->key_metadata(); +const std::string& FileCryptoMetaData::keyMetadata() const { + return impl_->keyMetadata(); } -std::shared_ptr FileCryptoMetaData::Make( - const uint8_t* serialized_metadata, - uint32_t* metadata_len, +std::shared_ptr FileCryptoMetaData::make( + const uint8_t* serializedMetadata, + uint32_t* metadataLen, const ReaderProperties& properties) { return std::shared_ptr( - new FileCryptoMetaData(serialized_metadata, metadata_len, properties)); + new FileCryptoMetaData(serializedMetadata, metadataLen, properties)); } FileCryptoMetaData::FileCryptoMetaData( - const uint8_t* serialized_metadata, - uint32_t* metadata_len, + const uint8_t* serializedMetadata, + uint32_t* metadataLen, const ReaderProperties& properties) : impl_(new FileCryptoMetaDataImpl( - serialized_metadata, - metadata_len, + serializedMetadata, + metadataLen, properties)) {} FileCryptoMetaData::FileCryptoMetaData() @@ -1259,18 +1292,18 @@ FileCryptoMetaData::FileCryptoMetaData() FileCryptoMetaData::~FileCryptoMetaData() = default; -void FileCryptoMetaData::WriteTo(::arrow::io::OutputStream* dst) const { - impl_->WriteTo(dst); +void FileCryptoMetaData::writeTo(::arrow::io::OutputStream* dst) const { + impl_->writeTo(dst); } -std::string FileMetaData::SerializeToString() const { - // We need to pass in an initial size. Since it will automatically - // increase the buffer size to hold the metadata, we just leave it 0. +std::string FileMetaData::serializeToString() const { + // We need to pass in an initial size. Since it will automatically. + // Increase the buffer size to hold the metadata, we just leave it 0. PARQUET_ASSIGN_OR_THROW( auto serializer, ::arrow::io::BufferOutputStream::Create(0)); - WriteTo(serializer.get()); - PARQUET_ASSIGN_OR_THROW(auto metadata_buffer, serializer->Finish()); - return metadata_buffer->ToString(); + writeTo(serializer.get()); + PARQUET_ASSIGN_OR_THROW(auto metadataBuffer, serializer->Finish()); + return metadataBuffer->ToString(); } ApplicationVersion::ApplicationVersion( @@ -1282,609 +1315,618 @@ ApplicationVersion::ApplicationVersion( version{major, minor, patch, "", "", ""} {} namespace { -// Parse the application version format and set parsed values to +// Parse the application version format and set parsed values to. // ApplicationVersion. // -// The application version format must be compatible parquet-mr's -// one. See also: -// * https://github.com/apache/parquet-mr/blob/master/parquet-common/src/main/java/org/apache/parquet/VersionParser.java -// * https://github.com/apache/parquet-mr/blob/master/parquet-common/src/main/java/org/apache/parquet/SemanticVersion.java +// The application version format must be compatible parquet-mr's. +// One. See also: +// * Https://github.com/apache/parquet-mr/blob/master/parquet-common/src/main/java/org/apache/parquet/VersionParser.java. +// * Https://github.com/apache/parquet-mr/blob/master/parquet-common/src/main/java/org/apache/parquet/SemanticVersion.java. // // The application version format: -// "${APPLICATION_NAME}" -// "${APPLICATION_NAME} version ${VERSION}" -// "${APPLICATION_NAME} version ${VERSION} (build ${BUILD_NAME})" +// "${APPLICATION_NAME}". +// "${APPLICATION_NAME} version ${VERSION}". +// "${APPLICATION_NAME} version ${VERSION} (build ${BUILD_NAME})". // // Eg: -// parquet-cpp -// parquet-cpp version 1.5.0ab-xyz5.5.0+cd +// Parquet-cpp. +// Parquet-cpp version 1.5.0ab-xyz5.5.0+cd. // parquet-cpp version 1.5.0ab-xyz5.5.0+cd (build abcd) // // The VERSION format: -// "${MAJOR}" -// "${MAJOR}.${MINOR}" -// "${MAJOR}.${MINOR}.${PATCH}" -// "${MAJOR}.${MINOR}.${PATCH}${UNKNOWN}" -// "${MAJOR}.${MINOR}.${PATCH}${UNKNOWN}-${PRE_RELEASE}" -// "${MAJOR}.${MINOR}.${PATCH}${UNKNOWN}-${PRE_RELEASE}+${BUILD_INFO}" -// "${MAJOR}.${MINOR}.${PATCH}${UNKNOWN}+${BUILD_INFO}" -// "${MAJOR}.${MINOR}.${PATCH}-${PRE_RELEASE}" -// "${MAJOR}.${MINOR}.${PATCH}-${PRE_RELEASE}+${BUILD_INFO}" -// "${MAJOR}.${MINOR}.${PATCH}+${BUILD_INFO}" +// "${MAJOR}". +// "${MAJOR}.${MINOR}". +// "${MAJOR}.${MINOR}.${PATCH}". +// "${MAJOR}.${MINOR}.${PATCH}${UNKNOWN}". +// "${MAJOR}.${MINOR}.${PATCH}${UNKNOWN}-${PRE_RELEASE}". +// "${MAJOR}.${MINOR}.${PATCH}${UNKNOWN}-${PRE_RELEASE}+${BUILD_INFO}". +// "${MAJOR}.${MINOR}.${PATCH}${UNKNOWN}+${BUILD_INFO}". +// "${MAJOR}.${MINOR}.${PATCH}-${PRE_RELEASE}". +// "${MAJOR}.${MINOR}.${PATCH}-${PRE_RELEASE}+${BUILD_INFO}". +// "${MAJOR}.${MINOR}.${PATCH}+${BUILD_INFO}". // // Eg: -// 1 -// 1.5 -// 1.5.0 -// 1.5.0ab -// 1.5.0ab-cdh5.5.0 -// 1.5.0ab-cdh5.5.0+cd -// 1.5.0ab+cd -// 1.5.0-cdh5.5.0 -// 1.5.0-cdh5.5.0+cd -// 1.5.0+cd +// 1. +// 1.5. +// 1.5.0. +// 1.5.0Ab. +// 1.5.0Ab-cdh5.5.0. +// 1.5.0Ab-cdh5.5.0+cd. +// 1.5.0Ab+cd. +// 1.5.0-Cdh5.5.0. +// 1.5.0-Cdh5.5.0+cd. +// 1.5.0+Cd. class ApplicationVersionParser { public: ApplicationVersionParser( - const std::string& created_by, - ApplicationVersion& application_version) - : created_by_(created_by), - application_version_(application_version), + const std::string& createdBy, + ApplicationVersion& ApplicationVersion) + : createdBy_(createdBy), + ApplicationVersion_(ApplicationVersion), spaces_(" \t\v\r\n\f"), digits_("0123456789") {} - void Parse() { - application_version_.application_ = "unknown"; - application_version_.version = {0, 0, 0, "", "", ""}; + void parse() { + ApplicationVersion_.application_ = "unknown"; + ApplicationVersion_.version = {0, 0, 0, "", "", ""}; - if (!ParseApplicationName()) { + if (!parseApplicationName()) { return; } - if (!ParseVersion()) { + if (!parseVersion()) { return; } - if (!ParseBuildName()) { + if (!parseBuildName()) { return; } } private: - bool IsSpace(const std::string& string, const size_t& offset) { - auto target = ::std::string_view(string).substr(offset, 1); + bool isSpace(const std::string& text, const size_t& offset) { + auto target = ::std::string_view(text).substr(offset, 1); return target.find_first_of(spaces_) != ::std::string_view::npos; } - void RemovePrecedingSpaces( - const std::string& string, + void removePrecedingSpaces( + const std::string& text, size_t& start, const size_t& end) { - while (start < end && IsSpace(string, start)) { + while (start < end && isSpace(text, start)) { ++start; } } - void RemoveTrailingSpaces( - const std::string& string, + void removeTrailingSpaces( + const std::string& text, const size_t& start, size_t& end) { - while (start < (end - 1) && (end - 1) < string.size() && - IsSpace(string, end - 1)) { + while (start < (end - 1) && (end - 1) < text.size() && + isSpace(text, end - 1)) { --end; } } - bool ParseApplicationName() { - std::string version_mark(" version "); - auto version_mark_position = created_by_.find(version_mark); - size_t application_name_end; + bool parseApplicationName() { + std::string versionMark(" version "); + auto versionMarkPosition = createdBy_.find(versionMark); + size_t applicationNameEnd; // No VERSION and BUILD_NAME. - if (version_mark_position == std::string::npos) { - version_start_ = std::string::npos; - application_name_end = created_by_.size(); + if (versionMarkPosition == std::string::npos) { + versionStart_ = std::string::npos; + applicationNameEnd = createdBy_.size(); } else { - version_start_ = version_mark_position + version_mark.size(); - application_name_end = version_mark_position; + versionStart_ = versionMarkPosition + versionMark.size(); + applicationNameEnd = versionMarkPosition; } - size_t application_name_start = 0; - RemovePrecedingSpaces( - created_by_, application_name_start, application_name_end); - RemoveTrailingSpaces( - created_by_, application_name_start, application_name_end); - application_version_.application_ = created_by_.substr( - application_name_start, application_name_end - application_name_start); + size_t applicationNameStart = 0; + removePrecedingSpaces(createdBy_, applicationNameStart, applicationNameEnd); + removeTrailingSpaces(createdBy_, applicationNameStart, applicationNameEnd); + ApplicationVersion_.application_ = createdBy_.substr( + applicationNameStart, applicationNameEnd - applicationNameStart); return true; } - bool ParseVersion() { + bool parseVersion() { // No VERSION. - if (version_start_ == std::string::npos) { + if (versionStart_ == std::string::npos) { return false; } - RemovePrecedingSpaces(created_by_, version_start_, created_by_.size()); - version_end_ = created_by_.find(" (", version_start_); + removePrecedingSpaces(createdBy_, versionStart_, createdBy_.size()); + versionEnd_ = createdBy_.find(" (", versionStart_); // No BUILD_NAME. - if (version_end_ == std::string::npos) { - version_end_ = created_by_.size(); + if (versionEnd_ == std::string::npos) { + versionEnd_ = createdBy_.size(); } - RemoveTrailingSpaces(created_by_, version_start_, version_end_); + removeTrailingSpaces(createdBy_, versionStart_, versionEnd_); // No VERSION. - if (version_start_ == version_end_) { + if (versionStart_ == versionEnd_) { return false; } - version_string_ = - created_by_.substr(version_start_, version_end_ - version_start_); + versionString_ = + createdBy_.substr(versionStart_, versionEnd_ - versionStart_); - if (!ParseVersionMajor()) { + if (!parseVersionMajor()) { return false; } - if (!ParseVersionMinor()) { + if (!parseVersionMinor()) { return false; } - if (!ParseVersionPatch()) { + if (!parseVersionPatch()) { return false; } - if (!ParseVersionUnknown()) { + if (!parseVersionUnknown()) { return false; } - if (!ParseVersionPreRelease()) { + if (!parseVersionPreRelease()) { return false; } - if (!ParseVersionBuildInfo()) { + if (!parseVersionBuildInfo()) { return false; } return true; } - bool ParseVersionMajor() { - size_t version_major_start = 0; - auto version_major_end = version_string_.find_first_not_of(digits_); + bool parseVersionMajor() { + size_t versionMajorStart = 0; + auto versionMajorEnd = versionString_.find_first_not_of(digits_); // MAJOR only. - if (version_major_end == std::string::npos) { - version_major_end = version_string_.size(); - version_parsing_position_ = version_major_end; + if (versionMajorEnd == std::string::npos) { + versionMajorEnd = versionString_.size(); + versionParsingPosition_ = versionMajorEnd; } else { // No ".". - if (version_string_[version_major_end] != '.') { + if (versionString_[versionMajorEnd] != '.') { return false; } // No MAJOR. - if (version_major_end == version_major_start) { + if (versionMajorEnd == versionMajorStart) { return false; } - version_parsing_position_ = version_major_end + 1; // +1 is for '.'. + versionParsingPosition_ = versionMajorEnd + 1; // +1 is for '.'. } - auto version_major_string = version_string_.substr( - version_major_start, version_major_end - version_major_start); - application_version_.version.major = atoi(version_major_string.c_str()); + auto versionMajorString = versionString_.substr( + versionMajorStart, versionMajorEnd - versionMajorStart); + ApplicationVersion_.version.major = atoi(versionMajorString.c_str()); return true; } - bool ParseVersionMinor() { - auto version_minor_start = version_parsing_position_; - auto version_minor_end = - version_string_.find_first_not_of(digits_, version_minor_start); + bool parseVersionMinor() { + auto versionMinorStart = versionParsingPosition_; + auto versionMinorEnd = + versionString_.find_first_not_of(digits_, versionMinorStart); // MAJOR.MINOR only. - if (version_minor_end == std::string::npos) { - version_minor_end = version_string_.size(); - version_parsing_position_ = version_minor_end; + if (versionMinorEnd == std::string::npos) { + versionMinorEnd = versionString_.size(); + versionParsingPosition_ = versionMinorEnd; } else { // No ".". - if (version_string_[version_minor_end] != '.') { + if (versionString_[versionMinorEnd] != '.') { return false; } // No MINOR. - if (version_minor_end == version_minor_start) { + if (versionMinorEnd == versionMinorStart) { return false; } - version_parsing_position_ = version_minor_end + 1; // +1 is for '.'. + versionParsingPosition_ = versionMinorEnd + 1; // +1 is for '.'. } - auto version_minor_string = version_string_.substr( - version_minor_start, version_minor_end - version_minor_start); - application_version_.version.minor = atoi(version_minor_string.c_str()); + auto versionMinorString = versionString_.substr( + versionMinorStart, versionMinorEnd - versionMinorStart); + ApplicationVersion_.version.minor = atoi(versionMinorString.c_str()); return true; } - bool ParseVersionPatch() { - auto version_patch_start = version_parsing_position_; - auto version_patch_end = - version_string_.find_first_not_of(digits_, version_patch_start); + bool parseVersionPatch() { + auto versionPatchStart = versionParsingPosition_; + auto versionPatchEnd = + versionString_.find_first_not_of(digits_, versionPatchStart); // No UNKNOWN, PRE_RELEASE and BUILD_INFO. - if (version_patch_end == std::string::npos) { - version_patch_end = version_string_.size(); + if (versionPatchEnd == std::string::npos) { + versionPatchEnd = versionString_.size(); } // No PATCH. - if (version_patch_end == version_patch_start) { + if (versionPatchEnd == versionPatchStart) { return false; } - auto version_patch_string = version_string_.substr( - version_patch_start, version_patch_end - version_patch_start); - application_version_.version.patch = atoi(version_patch_string.c_str()); - version_parsing_position_ = version_patch_end; + auto versionPatchString = versionString_.substr( + versionPatchStart, versionPatchEnd - versionPatchStart); + ApplicationVersion_.version.patch = atoi(versionPatchString.c_str()); + versionParsingPosition_ = versionPatchEnd; return true; } - bool ParseVersionUnknown() { + bool parseVersionUnknown() { // No UNKNOWN. - if (version_parsing_position_ == version_string_.size()) { + if (versionParsingPosition_ == versionString_.size()) { return true; } - auto version_unknown_start = version_parsing_position_; - auto version_unknown_end = - version_string_.find_first_of("-+", version_unknown_start); - // No PRE_RELEASE and BUILD_INFO - if (version_unknown_end == std::string::npos) { - version_unknown_end = version_string_.size(); + auto versionUnknownStart = versionParsingPosition_; + auto versionUnknownEnd = + versionString_.find_first_of("-+", versionUnknownStart); + // No PRE_RELEASE and BUILD_INFO. + if (versionUnknownEnd == std::string::npos) { + versionUnknownEnd = versionString_.size(); } - application_version_.version.unknown = version_string_.substr( - version_unknown_start, version_unknown_end - version_unknown_start); - version_parsing_position_ = version_unknown_end; + ApplicationVersion_.version.unknown = versionString_.substr( + versionUnknownStart, versionUnknownEnd - versionUnknownStart); + versionParsingPosition_ = versionUnknownEnd; return true; } - bool ParseVersionPreRelease() { + bool parseVersionPreRelease() { // No PRE_RELEASE. - if (version_parsing_position_ == version_string_.size() || - version_string_[version_parsing_position_] != '-') { + if (versionParsingPosition_ == versionString_.size() || + versionString_[versionParsingPosition_] != '-') { return true; } - auto version_pre_release_start = - version_parsing_position_ + 1; // +1 is for '-'. - auto version_pre_release_end = - version_string_.find_first_of("+", version_pre_release_start); - // No BUILD_INFO - if (version_pre_release_end == std::string::npos) { - version_pre_release_end = version_string_.size(); + auto versionPreReleaseStart = versionParsingPosition_ + 1; // +1 is for '-'. + auto versionPreReleaseEnd = + versionString_.find_first_of("+", versionPreReleaseStart); + // No BUILD_INFO. + if (versionPreReleaseEnd == std::string::npos) { + versionPreReleaseEnd = versionString_.size(); } - application_version_.version.pre_release = version_string_.substr( - version_pre_release_start, - version_pre_release_end - version_pre_release_start); - version_parsing_position_ = version_pre_release_end; + ApplicationVersion_.version.preRelease = versionString_.substr( + versionPreReleaseStart, versionPreReleaseEnd - versionPreReleaseStart); + versionParsingPosition_ = versionPreReleaseEnd; return true; } - bool ParseVersionBuildInfo() { + bool parseVersionBuildInfo() { // No BUILD_INFO. - if (version_parsing_position_ == version_string_.size() || - version_string_[version_parsing_position_] != '+') { + if (versionParsingPosition_ == versionString_.size() || + versionString_[versionParsingPosition_] != '+') { return true; } - auto version_build_info_start = - version_parsing_position_ + 1; // +1 is for '+'. - application_version_.version.build_info = - version_string_.substr(version_build_info_start); + auto versionBuildInfoStart = versionParsingPosition_ + 1; // +1 is for '+'. + ApplicationVersion_.version.buildInfo = + versionString_.substr(versionBuildInfoStart); return true; } - bool ParseBuildName() { - std::string build_mark(" (build "); - auto build_mark_position = created_by_.find(build_mark, version_end_); + bool parseBuildName() { + std::string buildMark(" (build "); + auto buildMarkPosition = createdBy_.find(buildMark, versionEnd_); // No BUILD_NAME. - if (build_mark_position == std::string::npos) { + if (buildMarkPosition == std::string::npos) { return false; } - auto build_name_start = build_mark_position + build_mark.size(); - RemovePrecedingSpaces(created_by_, build_name_start, created_by_.size()); - auto build_name_end = created_by_.find_first_of(")", build_name_start); + auto buildNameStart = buildMarkPosition + buildMark.size(); + removePrecedingSpaces(createdBy_, buildNameStart, createdBy_.size()); + auto buildNameEnd = createdBy_.find_first_of(")", buildNameStart); // No end ")". - if (build_name_end == std::string::npos) { + if (buildNameEnd == std::string::npos) { return false; } - RemoveTrailingSpaces(created_by_, build_name_start, build_name_end); - application_version_.build_ = - created_by_.substr(build_name_start, build_name_end - build_name_start); + removeTrailingSpaces(createdBy_, buildNameStart, buildNameEnd); + ApplicationVersion_.build_ = + createdBy_.substr(buildNameStart, buildNameEnd - buildNameStart); return true; } - const std::string& created_by_; - ApplicationVersion& application_version_; + const std::string& createdBy_; + ApplicationVersion& ApplicationVersion_; // For parsing. std::string spaces_; std::string digits_; - size_t version_parsing_position_; - size_t version_start_; - size_t version_end_; - std::string version_string_; + size_t versionParsingPosition_; + size_t versionStart_; + size_t versionEnd_; + std::string versionString_; }; } // namespace -ApplicationVersion::ApplicationVersion(const std::string& created_by) { - ApplicationVersionParser parser(created_by, *this); - parser.Parse(); +ApplicationVersion::ApplicationVersion(const std::string& createdBy) { + ApplicationVersionParser parser(createdBy, *this); + parser.parse(); } -bool ApplicationVersion::VersionLt( - const ApplicationVersion& other_version) const { - if (application_ != other_version.application_) +bool ApplicationVersion::versionLt( + const ApplicationVersion& otherVersion) const { + if (application_ != otherVersion.application_) return false; - if (version.major < other_version.version.major) + if (version.major < otherVersion.version.major) return true; - if (version.major > other_version.version.major) + if (version.major > otherVersion.version.major) return false; - VELOX_DCHECK_EQ(version.major, other_version.version.major); - if (version.minor < other_version.version.minor) + VELOX_DCHECK_EQ(version.major, otherVersion.version.major); + if (version.minor < otherVersion.version.minor) return true; - if (version.minor > other_version.version.minor) + if (version.minor > otherVersion.version.minor) return false; - VELOX_DCHECK_EQ(version.minor, other_version.version.minor); - return version.patch < other_version.version.patch; + VELOX_DCHECK_EQ(version.minor, otherVersion.version.minor); + return version.patch < otherVersion.version.patch; } -bool ApplicationVersion::VersionEq( - const ApplicationVersion& other_version) const { - return application_ == other_version.application_ && - version.major == other_version.version.major && - version.minor == other_version.version.minor && - version.patch == other_version.version.patch; +bool ApplicationVersion::versionEq( + const ApplicationVersion& otherVersion) const { + return application_ == otherVersion.application_ && + version.major == otherVersion.version.major && + version.minor == otherVersion.version.minor && + version.patch == otherVersion.version.patch; } // Reference: -// parquet-mr/parquet-column/src/main/java/org/apache/parquet/CorruptStatistics.java -// PARQUET-686 has more discussion on statistics -bool ApplicationVersion::HasCorrectStatistics( - Type::type col_type, +// Parquet-mr/parquet-column/src/main/java/org/apache/parquet/CorruptStatistics.java. +// PARQUET-686 has more discussion on statistics. +bool ApplicationVersion::hasCorrectStatistics( + Type::type colType, EncodedStatistics& statistics, - SortOrder::type sort_order) const { - // parquet-cpp version 1.3.0 and parquet-mr 1.10.0 onwards stats are computed - // correctly for all types + SortOrder::type sortOrder) const { + // Parquet-cpp version 1.3.0 and parquet-mr 1.10.0 onwards stats are computed + // correctly for all types. if ((application_ == "parquet-cpp" && - VersionLt(PARQUET_CPP_FIXED_STATS_VERSION())) || + versionLt(PARQUET_CPP_FIXED_STATS_VERSION())) || (application_ == "parquet-mr" && - VersionLt(PARQUET_MR_FIXED_STATS_VERSION()))) { - // Only SIGNED are valid unless max and min are the same + versionLt(PARQUET_MR_FIXED_STATS_VERSION()))) { + // Only SIGNED are valid unless max and min are the same. // (in which case the sort order does not matter) - bool max_equals_min = statistics.has_min && statistics.has_max + bool maxEqualsMin = statistics.hasMin && statistics.hasMax ? statistics.min() == statistics.max() : false; - if (SortOrder::SIGNED != sort_order && !max_equals_min) { + if (SortOrder::kSigned != sortOrder && !maxEqualsMin) { return false; } - // Statistics of other types are OK - if (col_type != Type::FIXED_LEN_BYTE_ARRAY && - col_type != Type::BYTE_ARRAY) { + // Statistics of other types are OK. + if (colType != Type::kFixedLenByteArray && colType != Type::kByteArray) { return true; } } - // created_by is not populated, which could have been caused by - // parquet-mr during the same time as PARQUET-251, see PARQUET-297 + // Created_by is not populated, which could have been caused by + // Parquet-mr during the same time as PARQUET-251, see PARQUET-297. if (application_ == "unknown") { return true; } - // Unknown sort order has incorrect stats - if (SortOrder::UNKNOWN == sort_order) { + // Unknown sort order has incorrect stats. + if (SortOrder::kUnknown == sortOrder) { return false; } - // PARQUET-251 - if (VersionLt(PARQUET_251_FIXED_VERSION())) { + // PARQUET-251. + if (versionLt(PARQUET_251_FIXED_VERSION())) { return false; } return true; } -// MetaData Builders -// row-group metadata +// MetaData Builders. +// Row-group metadata. class ColumnChunkMetaDataBuilder::ColumnChunkMetaDataBuilderImpl { public: explicit ColumnChunkMetaDataBuilderImpl( std::shared_ptr props, const ColumnDescriptor* column) - : owned_column_chunk_(new facebook::velox::parquet::thrift::ColumnChunk), + : ownedColumnChunk_(new facebook::velox::parquet::thrift::ColumnChunk), properties_(std::move(props)), column_(column) { - Init(owned_column_chunk_.get()); + init(ownedColumnChunk_.get()); } explicit ColumnChunkMetaDataBuilderImpl( std::shared_ptr props, const ColumnDescriptor* column, - facebook::velox::parquet::thrift::ColumnChunk* column_chunk) + facebook::velox::parquet::thrift::ColumnChunk* columnChunk) : properties_(std::move(props)), column_(column) { - Init(column_chunk); + init(columnChunk); } - const void* contents() const { - return column_chunk_; + const void* Contents() const { + return columnChunk_; } - // column chunk - void set_file_path(const std::string& val) { - column_chunk_->__set_file_path(val); + // Column chunk. + void setFilePath(const std::string& val) { + columnChunk_->__set_file_path(val); + } + + // Column metadata. + void setStatistics(const EncodedStatistics& val) { + columnChunk_->meta_data.__set_statistics(toThrift(val)); + // Store NaN count separately since it's not written to the parquet file. + if (val.hasNanCount) { + nanCount_ = val.nanCount; + hasNanCount_ = true; + } } - // column metadata - void SetStatistics(const EncodedStatistics& val) { - column_chunk_->meta_data.__set_statistics(ToThrift(val)); + int64_t nanCount() const { + return nanCount_; } - void Finish( + bool hasNanCount() const { + return hasNanCount_; + } + + void finish( int64_t num_values, int64_t dictionary_page_offset, int64_t index_page_offset, int64_t data_page_offset, - int64_t compressed_size, - int64_t uncompressed_size, - bool has_dictionary, - bool dictionary_fallback, - const std::map& dict_encoding_stats, - const std::map& data_encoding_stats, + int64_t compressedSize, + int64_t uncompressedSize, + bool hasDictionary, + bool dictionaryFallback, + const std::map& dictEncodingStats, + const std::map& dataEncodingStats, const std::shared_ptr& encryptor) { if (dictionary_page_offset > 0) { - column_chunk_->meta_data.__set_dictionary_page_offset( + columnChunk_->meta_data.__set_dictionary_page_offset( dictionary_page_offset); - column_chunk_->__set_file_offset( - dictionary_page_offset + compressed_size); + columnChunk_->__set_file_offset(dictionary_page_offset + compressedSize); } else { - column_chunk_->__set_file_offset(data_page_offset + compressed_size); + columnChunk_->__set_file_offset(data_page_offset + compressedSize); } - column_chunk_->__isset.meta_data = true; - column_chunk_->meta_data.__set_num_values(num_values); + columnChunk_->__isset.meta_data = true; + columnChunk_->meta_data.__set_num_values(num_values); if (index_page_offset >= 0) { - column_chunk_->meta_data.__set_index_page_offset(index_page_offset); + columnChunk_->meta_data.__set_index_page_offset(index_page_offset); } - column_chunk_->meta_data.__set_data_page_offset(data_page_offset); - column_chunk_->meta_data.__set_total_uncompressed_size(uncompressed_size); - column_chunk_->meta_data.__set_total_compressed_size(compressed_size); + columnChunk_->meta_data.__set_data_page_offset(data_page_offset); + columnChunk_->meta_data.__set_total_uncompressed_size(uncompressedSize); + columnChunk_->meta_data.__set_total_compressed_size(compressedSize); std::vector - thrift_encodings; + thriftEncodings; std::vector - thrift_encoding_stats; - auto add_encoding = - [&thrift_encodings]( + thriftEncodingStats; + auto addEncoding = + [&thriftEncodings]( facebook::velox::parquet::thrift::Encoding::type value) { auto it = std::find( - thrift_encodings.begin(), thrift_encodings.end(), value); - if (it == thrift_encodings.end()) { - thrift_encodings.push_back(value); + thriftEncodings.cbegin(), thriftEncodings.cend(), value); + if (it == thriftEncodings.cend()) { + thriftEncodings.push_back(value); } }; - // Add dictionary page encoding stats - if (has_dictionary) { - for (const auto& entry : dict_encoding_stats) { - facebook::velox::parquet::thrift::PageEncodingStats dict_enc_stat; - dict_enc_stat.__set_page_type( + // Add dictionary page encoding stats. + if (hasDictionary) { + for (const auto& entry : dictEncodingStats) { + facebook::velox::parquet::thrift::PageEncodingStats dictEncStat; + dictEncStat.__set_page_type( facebook::velox::parquet::thrift::PageType::DICTIONARY_PAGE); - // Dictionary Encoding would be PLAIN_DICTIONARY in v1 and + // Dictionary encoding would be PLAIN_DICTIONARY in v1 and // PLAIN in v2. - facebook::velox::parquet::thrift::Encoding::type dict_encoding = - ToThrift(entry.first); - dict_enc_stat.__set_encoding(dict_encoding); - dict_enc_stat.__set_count(entry.second); - thrift_encoding_stats.push_back(dict_enc_stat); - add_encoding(dict_encoding); + facebook::velox::parquet::thrift::Encoding::type dictEncoding = + toThrift(entry.first); + dictEncStat.__set_encoding(dictEncoding); + dictEncStat.__set_count(entry.second); + thriftEncodingStats.push_back(dictEncStat); + addEncoding(dictEncoding); } } // Always add encoding for RL/DL. - // BIT_PACKED is supported in `LevelEncoder`, but would only be used - // in benchmark and testing. + // BIT_PACKED is supported in `LevelEncoder`, but would only be used. + // In benchmark and testing. // And for now, we always add RLE even if there are no levels at all, // while parquet-mr is more fine-grained. - add_encoding(facebook::velox::parquet::thrift::Encoding::RLE); - // Add data page encoding stats - for (const auto& entry : data_encoding_stats) { - facebook::velox::parquet::thrift::PageEncodingStats data_enc_stat; - data_enc_stat.__set_page_type( + addEncoding(facebook::velox::parquet::thrift::Encoding::RLE); + // Add data page encoding stats. + for (const auto& entry : dataEncodingStats) { + facebook::velox::parquet::thrift::PageEncodingStats dataEncStat; + dataEncStat.__set_page_type( facebook::velox::parquet::thrift::PageType::DATA_PAGE); - facebook::velox::parquet::thrift::Encoding::type data_encoding = - ToThrift(entry.first); - data_enc_stat.__set_encoding(data_encoding); - data_enc_stat.__set_count(entry.second); - thrift_encoding_stats.push_back(data_enc_stat); - add_encoding(data_encoding); - } - column_chunk_->meta_data.__set_encodings(thrift_encodings); - column_chunk_->meta_data.__set_encoding_stats(thrift_encoding_stats); - - const auto& encrypt_md = properties_->column_encryption_properties( - column_->path()->ToDotString()); - // column is encrypted - if (encrypt_md != nullptr && encrypt_md->is_encrypted()) { - column_chunk_->__isset.crypto_metadata = true; + facebook::velox::parquet::thrift::Encoding::type dataEncoding = + toThrift(entry.first); + dataEncStat.__set_encoding(dataEncoding); + dataEncStat.__set_count(entry.second); + thriftEncodingStats.push_back(dataEncStat); + addEncoding(dataEncoding); + } + columnChunk_->meta_data.__set_encodings(thriftEncodings); + columnChunk_->meta_data.__set_encoding_stats(thriftEncodingStats); + + const auto& encryptMd = + properties_->columnEncryptionProperties(column_->path()->toDotString()); + // Column is encrypted. + if (encryptMd != nullptr && encryptMd->isEncrypted()) { + columnChunk_->__isset.crypto_metadata = true; facebook::velox::parquet::thrift::ColumnCryptoMetaData ccmd; - if (encrypt_md->is_encrypted_with_footer_key()) { - // encrypted with footer key + if (encryptMd->isEncryptedWithFooterKey()) { + // Encrypted with footer key. ccmd.__isset.ENCRYPTION_WITH_FOOTER_KEY = true; ccmd.__set_ENCRYPTION_WITH_FOOTER_KEY( facebook::velox::parquet::thrift::EncryptionWithFooterKey()); } else { // encrypted with column key facebook::velox::parquet::thrift::EncryptionWithColumnKey eck; - eck.__set_key_metadata(encrypt_md->key_metadata()); - eck.__set_path_in_schema(column_->path()->ToDotVector()); + eck.__set_key_metadata(encryptMd->keyMetadata()); + eck.__set_path_in_schema(column_->path()->toDotVector()); ccmd.__isset.ENCRYPTION_WITH_COLUMN_KEY = true; ccmd.__set_ENCRYPTION_WITH_COLUMN_KEY(eck); } - column_chunk_->__set_crypto_metadata(ccmd); + columnChunk_->__set_crypto_metadata(ccmd); - bool encrypted_footer = - properties_->file_encryption_properties()->encrypted_footer(); - bool encrypt_metadata = - !encrypted_footer || !encrypt_md->is_encrypted_with_footer_key(); - if (encrypt_metadata) { + bool encryptedFooter = + properties_->fileEncryptionProperties()->encryptedFooter(); + bool encryptMetadata = + !encryptedFooter || !encryptMd->isEncryptedWithFooterKey(); + if (encryptMetadata) { ThriftSerializer serializer; - // Serialize and encrypt ColumnMetadata separately + // Serialize and encrypt ColumnMetadata separately. // Thrift-serialize the ColumnMetaData structure, // encrypt it with the column key, and write to - // encrypted_column_metadata - uint8_t* serialized_data; - uint32_t serialized_len; + // encrypted_column_metadata. + uint8_t* serializedData; + uint32_t serializedLen; - serializer.SerializeToBuffer( - &column_chunk_->meta_data, &serialized_len, &serialized_data); + serializer.serializeToBuffer( + &columnChunk_->meta_data, &serializedLen, &serializedData); - std::vector encrypted_data( - encryptor->CiphertextSizeDelta() + serialized_len); - unsigned encrypted_len = encryptor->Encrypt( - serialized_data, serialized_len, encrypted_data.data()); + std::vector encryptedData( + encryptor->ciphertextSizeDelta() + serializedLen); + unsigned encryptedLen = encryptor->encrypt( + serializedData, serializedLen, encryptedData.data()); const char* temp = const_cast( - reinterpret_cast(encrypted_data.data())); - std::string encrypted_column_metadata(temp, encrypted_len); - column_chunk_->__set_encrypted_column_metadata( + reinterpret_cast(encryptedData.data())); + std::string encrypted_column_metadata(temp, encryptedLen); + columnChunk_->__set_encrypted_column_metadata( encrypted_column_metadata); - if (encrypted_footer) { - column_chunk_->__isset.meta_data = false; + if (encryptedFooter) { + columnChunk_->__isset.meta_data = false; } else { - // Keep redacted metadata version for old readers - column_chunk_->__isset.meta_data = true; - column_chunk_->meta_data.__isset.statistics = false; - column_chunk_->meta_data.__isset.encoding_stats = false; + // Keep redacted metadata version for old readers. + columnChunk_->__isset.meta_data = true; + columnChunk_->meta_data.__isset.statistics = false; + columnChunk_->meta_data.__isset.encoding_stats = false; } } } } - void WriteTo(::arrow::io::OutputStream* sink) { + void writeTo(::arrow::io::OutputStream* sink) { ThriftSerializer serializer; - serializer.Serialize(column_chunk_, sink); + serializer.serialize(columnChunk_, sink); } const ColumnDescriptor* descr() const { return column_; } - int64_t total_compressed_size() const { - return column_chunk_->meta_data.total_compressed_size; + int64_t totalCompressedSize() const { + return columnChunk_->meta_data.total_compressed_size; } private: - void Init(facebook::velox::parquet::thrift::ColumnChunk* column_chunk) { - column_chunk_ = column_chunk; + void init(facebook::velox::parquet::thrift::ColumnChunk* columnChunk) { + columnChunk_ = columnChunk; - column_chunk_->meta_data.__set_type(ToThrift(column_->physical_type())); - column_chunk_->meta_data.__set_path_in_schema( - column_->path()->ToDotVector()); - column_chunk_->meta_data.__set_codec( - ToThrift(properties_->compression(column_->path()))); + columnChunk_->meta_data.__set_type(toThrift(column_->physicalType())); + columnChunk_->meta_data.__set_path_in_schema( + column_->path()->toDotVector()); + columnChunk_->meta_data.__set_codec( + toThrift(properties_->compression(column_->path()))); } - facebook::velox::parquet::thrift::ColumnChunk* column_chunk_; + facebook::velox::parquet::thrift::ColumnChunk* columnChunk_; std::unique_ptr - owned_column_chunk_; + ownedColumnChunk_; const std::shared_ptr properties_; const ColumnDescriptor* column_; + // NaN count is stored separately since it's not written to the parquet file. + int64_t nanCount_ = 0; + bool hasNanCount_ = false; }; -std::unique_ptr ColumnChunkMetaDataBuilder::Make( +std::unique_ptr ColumnChunkMetaDataBuilder::make( std::shared_ptr props, const ColumnDescriptor* column, - void* contents) { + void* Contents) { return std::unique_ptr( - new ColumnChunkMetaDataBuilder(std::move(props), column, contents)); + new ColumnChunkMetaDataBuilder(std::move(props), column, Contents)); } -std::unique_ptr ColumnChunkMetaDataBuilder::Make( +std::unique_ptr ColumnChunkMetaDataBuilder::make( std::shared_ptr props, const ColumnDescriptor* column) { return std::unique_ptr( @@ -1900,65 +1942,73 @@ ColumnChunkMetaDataBuilder::ColumnChunkMetaDataBuilder( ColumnChunkMetaDataBuilder::ColumnChunkMetaDataBuilder( std::shared_ptr props, const ColumnDescriptor* column, - void* contents) + void* Contents) : impl_{std::unique_ptr( new ColumnChunkMetaDataBuilderImpl( std::move(props), column, reinterpret_cast( - contents)))} {} + Contents)))} {} ColumnChunkMetaDataBuilder::~ColumnChunkMetaDataBuilder() = default; -const void* ColumnChunkMetaDataBuilder::contents() const { - return impl_->contents(); +const void* ColumnChunkMetaDataBuilder::Contents() const { + return impl_->Contents(); } -void ColumnChunkMetaDataBuilder::set_file_path(const std::string& path) { - impl_->set_file_path(path); +void ColumnChunkMetaDataBuilder::setFilePath(const std::string& path) { + impl_->setFilePath(path); } -void ColumnChunkMetaDataBuilder::Finish( +void ColumnChunkMetaDataBuilder::finish( int64_t num_values, int64_t dictionary_page_offset, int64_t index_page_offset, int64_t data_page_offset, - int64_t compressed_size, - int64_t uncompressed_size, - bool has_dictionary, - bool dictionary_fallback, - const std::map& dict_encoding_stats, - const std::map& data_encoding_stats, + int64_t compressedSize, + int64_t uncompressedSize, + bool hasDictionary, + bool dictionaryFallback, + const std::map& dictEncodingStats, + const std::map& dataEncodingStats, const std::shared_ptr& encryptor) { - impl_->Finish( + impl_->finish( num_values, dictionary_page_offset, index_page_offset, data_page_offset, - compressed_size, - uncompressed_size, - has_dictionary, - dictionary_fallback, - dict_encoding_stats, - data_encoding_stats, + compressedSize, + uncompressedSize, + hasDictionary, + dictionaryFallback, + dictEncodingStats, + dataEncodingStats, encryptor); } -void ColumnChunkMetaDataBuilder::WriteTo(::arrow::io::OutputStream* sink) { - impl_->WriteTo(sink); +void ColumnChunkMetaDataBuilder::writeTo(::arrow::io::OutputStream* sink) { + impl_->writeTo(sink); } const ColumnDescriptor* ColumnChunkMetaDataBuilder::descr() const { return impl_->descr(); } -void ColumnChunkMetaDataBuilder::SetStatistics( +void ColumnChunkMetaDataBuilder::setStatistics( const EncodedStatistics& result) { - impl_->SetStatistics(result); + impl_->setStatistics(result); +} + +int64_t ColumnChunkMetaDataBuilder::totalCompressedSize() const { + return impl_->totalCompressedSize(); +} + +int64_t ColumnChunkMetaDataBuilder::nanCount() const { + return impl_->nanCount(); } -int64_t ColumnChunkMetaDataBuilder::total_compressed_size() const { - return impl_->total_compressed_size(); +bool ColumnChunkMetaDataBuilder::hasNanCount() const { + return impl_->hasNanCount(); } class RowGroupMetaDataBuilder::RowGroupMetaDataBuilderImpl { @@ -1966,313 +2016,334 @@ class RowGroupMetaDataBuilder::RowGroupMetaDataBuilderImpl { explicit RowGroupMetaDataBuilderImpl( std::shared_ptr props, const SchemaDescriptor* schema, - void* contents) - : properties_(std::move(props)), schema_(schema), next_column_(0) { - row_group_ = - reinterpret_cast(contents); - InitializeColumns(schema->num_columns()); + void* Contents) + : properties_(std::move(props)), schema_(schema), nextColumn_(0) { + rowGroup_ = + reinterpret_cast(Contents); + initializeColumns(schema->numColumns()); } - ColumnChunkMetaDataBuilder* NextColumnChunk() { - if (!(next_column_ < num_columns())) { + ColumnChunkMetaDataBuilder* nextColumnChunk() { + if (!(nextColumn_ < numColumns())) { std::stringstream ss; - ss << "The schema only has " << num_columns() - << " columns, requested metadata for column: " << next_column_; + ss << "The schema only has " << numColumns() + << " columns, requested metadata for column: " << nextColumn_; throw ParquetException(ss.str()); } - auto column = schema_->Column(next_column_); - auto column_builder = ColumnChunkMetaDataBuilder::Make( - properties_, column, &row_group_->columns[next_column_++]); - auto column_builder_ptr = column_builder.get(); - column_builders_.push_back(std::move(column_builder)); - return column_builder_ptr; + auto column = schema_->column(nextColumn_); + auto columnBuilder = ColumnChunkMetaDataBuilder::make( + properties_, column, &rowGroup_->columns[nextColumn_++]); + auto columnBuilderPtr = columnBuilder.get(); + columnBuilders_.push_back(std::move(columnBuilder)); + return columnBuilderPtr; } - int current_column() { - return next_column_ - 1; + int currentColumn() { + return nextColumn_ - 1; } - void Finish(int64_t total_bytes_written, int16_t row_group_ordinal) { - if (!(next_column_ == schema_->num_columns())) { + void finish(int64_t totalBytesWritten, int16_t rowGroupOrdinal) { + if (!(nextColumn_ == schema_->numColumns())) { std::stringstream ss; - ss << "Only " << next_column_ - 1 << " out of " << schema_->num_columns() + ss << "Only " << nextColumn_ - 1 << " out of " << schema_->numColumns() << " columns are initialized"; throw ParquetException(ss.str()); } - int64_t file_offset = 0; + int64_t fileOffset = 0; int64_t total_compressed_size = 0; - for (int i = 0; i < schema_->num_columns(); i++) { - if (!(row_group_->columns[i].file_offset >= 0)) { + for (int i = 0; i < schema_->numColumns(); i++) { + if (!(rowGroup_->columns[i].file_offset >= 0)) { std::stringstream ss; ss << "Column " << i << " is not complete."; throw ParquetException(ss.str()); } if (i == 0) { - const facebook::velox::parquet::thrift::ColumnMetaData& first_col = - row_group_->columns[0].meta_data; + const facebook::velox::parquet::thrift::ColumnMetaData& firstCol = + rowGroup_->columns[0].meta_data; // As per spec, file_offset for the row group points to the first // dictionary or data page of the column. - if (first_col.__isset.dictionary_page_offset && - first_col.dictionary_page_offset > 0) { - file_offset = first_col.dictionary_page_offset; + if (firstCol.__isset.dictionary_page_offset && + firstCol.dictionary_page_offset > 0) { + fileOffset = firstCol.dictionary_page_offset; } else { - file_offset = first_col.data_page_offset; + fileOffset = firstCol.data_page_offset; } } - // sometimes column metadata is encrypted and not available to read, - // so we must get total_compressed_size from column builder - total_compressed_size += column_builders_[i]->total_compressed_size(); + // Sometimes column metadata is encrypted and not available to read, + // so we must get total_compressed_size from column builder. + total_compressed_size += columnBuilders_[i]->totalCompressedSize(); } - const auto& sorting_columns = properties_->sorting_columns(); - if (!sorting_columns.empty()) { + const auto& sortingColumns = properties_->sortingColumns(); + if (!sortingColumns.empty()) { std::vector - thrift_sorting_columns(sorting_columns.size()); - for (size_t i = 0; i < sorting_columns.size(); ++i) { - thrift_sorting_columns[i] = ToThrift(sorting_columns[i]); + thriftSortingColumns(sortingColumns.size()); + for (size_t i = 0; i < sortingColumns.size(); ++i) { + thriftSortingColumns[i] = toThrift(sortingColumns[i]); } - row_group_->__set_sorting_columns(std::move(thrift_sorting_columns)); + rowGroup_->__set_sorting_columns(std::move(thriftSortingColumns)); } - row_group_->__set_file_offset(file_offset); - row_group_->__set_total_compressed_size(total_compressed_size); - row_group_->__set_total_byte_size(total_bytes_written); - row_group_->__set_ordinal(row_group_ordinal); + rowGroup_->__set_file_offset(fileOffset); + rowGroup_->__set_total_compressed_size(total_compressed_size); + rowGroup_->__set_total_byte_size(totalBytesWritten); + rowGroup_->__set_ordinal(rowGroupOrdinal); } - void set_num_rows(int64_t num_rows) { - row_group_->num_rows = num_rows; + void setNumRows(int64_t numRows) { + rowGroup_->num_rows = numRows; } - int num_columns() { - return static_cast(row_group_->columns.size()); + int numColumns() { + return static_cast(rowGroup_->columns.size()); } - int64_t num_rows() { - return row_group_->num_rows; + int64_t numRows() { + return rowGroup_->num_rows; + } + + // Returns a map of field_id -> (nan_count, has_nan_count). + std::unordered_map> nanCounts() const { + std::unordered_map> result; + for (const auto& builder : columnBuilders_) { + int32_t field_id = builder->descr()->schemaNode()->fieldId(); + result[field_id] = {builder->nanCount(), builder->hasNanCount()}; + } + return result; } private: - void InitializeColumns(int ncols) { - row_group_->columns.resize(ncols); + void initializeColumns(int ncols) { + rowGroup_->columns.resize(ncols); } - facebook::velox::parquet::thrift::RowGroup* row_group_; + facebook::velox::parquet::thrift::RowGroup* rowGroup_; const std::shared_ptr properties_; const SchemaDescriptor* schema_; - std::vector> column_builders_; - int next_column_; + std::vector> columnBuilders_; + int nextColumn_; }; -std::unique_ptr RowGroupMetaDataBuilder::Make( +std::unique_ptr RowGroupMetaDataBuilder::make( std::shared_ptr props, const SchemaDescriptor* schema_, - void* contents) { + void* Contents) { return std::unique_ptr( - new RowGroupMetaDataBuilder(std::move(props), schema_, contents)); + new RowGroupMetaDataBuilder(std::move(props), schema_, Contents)); } RowGroupMetaDataBuilder::RowGroupMetaDataBuilder( std::shared_ptr props, const SchemaDescriptor* schema_, - void* contents) + void* Contents) : impl_{new RowGroupMetaDataBuilderImpl( std::move(props), schema_, - contents)} {} + Contents)} {} RowGroupMetaDataBuilder::~RowGroupMetaDataBuilder() = default; -ColumnChunkMetaDataBuilder* RowGroupMetaDataBuilder::NextColumnChunk() { - return impl_->NextColumnChunk(); +ColumnChunkMetaDataBuilder* RowGroupMetaDataBuilder::nextColumnChunk() { + return impl_->nextColumnChunk(); +} + +int RowGroupMetaDataBuilder::currentColumn() const { + return impl_->currentColumn(); } -int RowGroupMetaDataBuilder::current_column() const { - return impl_->current_column(); +int RowGroupMetaDataBuilder::numColumns() { + return impl_->numColumns(); } -int RowGroupMetaDataBuilder::num_columns() { - return impl_->num_columns(); +int64_t RowGroupMetaDataBuilder::numRows() { + return impl_->numRows(); } -int64_t RowGroupMetaDataBuilder::num_rows() { - return impl_->num_rows(); +void RowGroupMetaDataBuilder::setNumRows(int64_t numRows) { + impl_->setNumRows(numRows); } -void RowGroupMetaDataBuilder::set_num_rows(int64_t num_rows) { - impl_->set_num_rows(num_rows); +void RowGroupMetaDataBuilder::finish( + int64_t totalBytesWritten, + int16_t rowGroupOrdinal) { + impl_->finish(totalBytesWritten, rowGroupOrdinal); } -void RowGroupMetaDataBuilder::Finish( - int64_t total_bytes_written, - int16_t row_group_ordinal) { - impl_->Finish(total_bytes_written, row_group_ordinal); +std::unordered_map> +RowGroupMetaDataBuilder::nanCounts() const { + return impl_->nanCounts(); } -// file metadata +// File metadata. class FileMetaDataBuilder::FileMetaDataBuilderImpl { public: explicit FileMetaDataBuilderImpl( const SchemaDescriptor* schema, std::shared_ptr props, - std::shared_ptr key_value_metadata) + std::shared_ptr keyValueMetadata) : metadata_(new facebook::velox::parquet::thrift::FileMetaData()), properties_(std::move(props)), schema_(schema), - key_value_metadata_(std::move(key_value_metadata)) { - if (properties_->file_encryption_properties() != nullptr && - properties_->file_encryption_properties()->encrypted_footer()) { + keyValueMetadata_(std::move(keyValueMetadata)) { + if (properties_->fileEncryptionProperties() != nullptr && + properties_->fileEncryptionProperties()->encryptedFooter()) { crypto_metadata_.reset( new facebook::velox::parquet::thrift::FileCryptoMetaData()); } } - RowGroupMetaDataBuilder* AppendRowGroup() { - row_groups_.emplace_back(); - current_row_group_builder_ = RowGroupMetaDataBuilder::Make( - properties_, schema_, &row_groups_.back()); - return current_row_group_builder_.get(); - } - - void SetPageIndexLocation(const PageIndexLocation& location) { - auto set_index_location = [this]( - size_t row_group_ordinal, - const PageIndexLocation::FileIndexLocation& - file_index_location, - bool column_index) { - auto& row_group_metadata = this->row_groups_.at(row_group_ordinal); - auto iter = file_index_location.find(row_group_ordinal); - if (iter != file_index_location.cend()) { - const auto& row_group_index_location = iter->second; - for (size_t i = 0; i < row_group_index_location.size(); ++i) { - if (i >= row_group_metadata.columns.size()) { + RowGroupMetaDataBuilder* appendRowGroup() { + // Accumulate NaN counts from the previous row group before creating a new + // one. + accumulateNaNCountsFromCurrentRowGroup(); + rowGroups_.emplace_back(); + currentRowGroupBuilder_ = + RowGroupMetaDataBuilder::make(properties_, schema_, &rowGroups_.back()); + return currentRowGroupBuilder_.get(); + } + + void setPageIndexLocation(const PageIndexLocation& location) { + auto setIndexLocation = [this]( + size_t rowGroupOrdinal, + const PageIndexLocation::FileIndexLocation& + fileIndexLocation, + bool columnIndex) { + auto& rowGroupMetadata = this->rowGroups_.at(rowGroupOrdinal); + auto iter = fileIndexLocation.find(rowGroupOrdinal); + if (iter != fileIndexLocation.cend()) { + const auto& rowGroupIndexLocation = iter->second; + for (size_t i = 0; i < rowGroupIndexLocation.size(); ++i) { + if (i >= rowGroupMetadata.columns.size()) { throw ParquetException( "Cannot find metadata for column ordinal ", i); } - auto& column_metadata = row_group_metadata.columns.at(i); - const auto& index_location = row_group_index_location.at(i); - if (index_location.has_value()) { - if (column_index) { - column_metadata.__set_column_index_offset(index_location->offset); - column_metadata.__set_column_index_length(index_location->length); + auto& columnMetadata = rowGroupMetadata.columns.at(i); + const auto& indexLocation = rowGroupIndexLocation.at(i); + if (indexLocation.has_value()) { + if (columnIndex) { + columnMetadata.__set_column_index_offset(indexLocation->offset); + columnMetadata.__set_column_index_length(indexLocation->length); } else { - column_metadata.__set_offset_index_offset(index_location->offset); - column_metadata.__set_offset_index_length(index_location->length); + columnMetadata.__set_offset_index_offset(indexLocation->offset); + columnMetadata.__set_offset_index_length(indexLocation->length); } } } } }; - for (size_t i = 0; i < row_groups_.size(); ++i) { - set_index_location(i, location.column_index_location, true); - set_index_location(i, location.offset_index_location, false); + for (size_t i = 0; i < rowGroups_.size(); ++i) { + setIndexLocation(i, location.columnIndexLocation, true); + setIndexLocation(i, location.offsetIndexLocation, false); } } - std::unique_ptr Finish( - const std::shared_ptr& key_value_metadata) { - int64_t total_rows = 0; - for (auto row_group : row_groups_) { - total_rows += row_group.num_rows; + std::unique_ptr finish( + const std::shared_ptr& keyValueMetadata) { + // Accumulate NaN counts from the last row group. + accumulateNaNCountsFromCurrentRowGroup(); + + int64_t totalRows = 0; + for (auto rowGroup : rowGroups_) { + totalRows += rowGroup.num_rows; } - metadata_->__set_num_rows(total_rows); - metadata_->__set_row_groups(row_groups_); + metadata_->__set_num_rows(totalRows); + metadata_->__set_row_groups(rowGroups_); - if (key_value_metadata_ || key_value_metadata) { - if (!key_value_metadata_) { - key_value_metadata_ = key_value_metadata; - } else if (key_value_metadata) { - key_value_metadata_ = key_value_metadata_->Merge(*key_value_metadata); + if (keyValueMetadata_ || keyValueMetadata) { + if (!keyValueMetadata_) { + keyValueMetadata_ = keyValueMetadata; + } else if (keyValueMetadata) { + keyValueMetadata_ = keyValueMetadata_->Merge(*keyValueMetadata); } metadata_->key_value_metadata.clear(); - metadata_->key_value_metadata.reserve(key_value_metadata_->size()); - for (int64_t i = 0; i < key_value_metadata_->size(); ++i) { - facebook::velox::parquet::thrift::KeyValue kv_pair; - kv_pair.__set_key(key_value_metadata_->key(i)); - kv_pair.__set_value(key_value_metadata_->value(i)); - metadata_->key_value_metadata.push_back(kv_pair); + metadata_->key_value_metadata.reserve(keyValueMetadata_->size()); + for (int64_t i = 0; i < keyValueMetadata_->size(); ++i) { + facebook::velox::parquet::thrift::KeyValue kvPair; + kvPair.__set_key(keyValueMetadata_->key(i)); + kvPair.__set_value(keyValueMetadata_->value(i)); + metadata_->key_value_metadata.push_back(kvPair); } metadata_->__isset.key_value_metadata = true; } - int32_t file_version = 0; + int32_t fileVersion = 0; switch (properties_->version()) { case ParquetVersion::PARQUET_1_0: - file_version = 1; + fileVersion = 1; break; default: - file_version = 2; + fileVersion = 2; break; } - metadata_->__set_version(file_version); - metadata_->__set_created_by(properties_->created_by()); - - // Users cannot set the `ColumnOrder` since we do not have user defined sort - // order in the spec yet. We always default to `TYPE_DEFINED_ORDER`. We can - // expose it in the API once we have user defined sort orders in the Parquet - // format. TypeDefinedOrder implies choose SortOrder based on - // ConvertedType/PhysicalType - facebook::velox::parquet::thrift::TypeDefinedOrder type_defined_order; - facebook::velox::parquet::thrift::ColumnOrder column_order; - column_order.__set_TYPE_ORDER(type_defined_order); - column_order.__isset.TYPE_ORDER = true; - metadata_->column_orders.resize(schema_->num_columns(), column_order); + metadata_->__set_version(fileVersion); + metadata_->__set_created_by(properties_->createdBy()); + + // Users cannot set the `ColumnOrder` since we do not have user-defined + // sort order in the spec yet. We always default to `TYPE_DEFINED_ORDER`. + // We can expose it in the API once we have user-defined sort orders in the + // Parquet format. TypeDefinedOrder implies choose SortOrder based on + // convertedType/physicalType. + facebook::velox::parquet::thrift::TypeDefinedOrder typeDefinedOrder; + facebook::velox::parquet::thrift::ColumnOrder columnOrder; + columnOrder.__set_TYPE_ORDER(typeDefinedOrder); + columnOrder.__isset.TYPE_ORDER = true; + metadata_->column_orders.resize(schema_->numColumns(), columnOrder); metadata_->__isset.column_orders = true; - // if plaintext footer, set footer signing algorithm - auto file_encryption_properties = properties_->file_encryption_properties(); - if (file_encryption_properties && - !file_encryption_properties->encrypted_footer()) { - EncryptionAlgorithm signing_algorithm; - EncryptionAlgorithm algo = file_encryption_properties->algorithm(); - signing_algorithm.aad.aad_file_unique = algo.aad.aad_file_unique; - signing_algorithm.aad.supply_aad_prefix = algo.aad.supply_aad_prefix; - if (!algo.aad.supply_aad_prefix) { - signing_algorithm.aad.aad_prefix = algo.aad.aad_prefix; + // If plaintext footer, set footer signing algorithm. + auto fileEncryptionProperties = properties_->fileEncryptionProperties(); + if (fileEncryptionProperties && + !fileEncryptionProperties->encryptedFooter()) { + EncryptionAlgorithm signingAlgorithm; + EncryptionAlgorithm algo = fileEncryptionProperties->algorithm(); + signingAlgorithm.aad.aadFileUnique = algo.aad.aadFileUnique; + signingAlgorithm.aad.supplyAadPrefix = algo.aad.supplyAadPrefix; + if (!algo.aad.supplyAadPrefix) { + signingAlgorithm.aad.aadPrefix = algo.aad.aadPrefix; } - signing_algorithm.algorithm = ParquetCipher::AES_GCM_V1; - - metadata_->__set_encryption_algorithm(ToThrift(signing_algorithm)); - const std::string& footer_signing_key_metadata = - file_encryption_properties->footer_key_metadata(); - if (footer_signing_key_metadata.size() > 0) { - metadata_->__set_footer_signing_key_metadata( - footer_signing_key_metadata); + signingAlgorithm.algorithm = ParquetCipher::kAesGcmV1; + + metadata_->__set_encryption_algorithm(toThrift(signingAlgorithm)); + const std::string& footerSigningKeyMetadata = + fileEncryptionProperties->footerKeyMetadata(); + if (footerSigningKeyMetadata.size() > 0) { + metadata_->__set_footer_signing_key_metadata(footerSigningKeyMetadata); } } - ToParquet( - static_cast(schema_->schema_root().get()), + toParquet( + static_cast(schema_->schemaRoot().get()), &metadata_->schema); - auto file_meta_data = std::unique_ptr(new FileMetaData()); - file_meta_data->impl_->metadata_ = std::move(metadata_); - file_meta_data->impl_->InitSchema(); - file_meta_data->impl_->InitKeyValueMetadata(); - return file_meta_data; + auto fileMetaData = std::unique_ptr(new FileMetaData()); + fileMetaData->impl_->metadata_ = std::move(metadata_); + fileMetaData->impl_->initSchema(); + fileMetaData->impl_->initKeyValueMetadata(); + // Pass total NaN counts per field ID to FileMetaData. + fileMetaData->impl_->setNaNCounts(std::move(fieldNanCounts_)); + return fileMetaData; } - std::unique_ptr BuildFileCryptoMetaData() { + std::unique_ptr buildFileCryptoMetaData() { if (crypto_metadata_ == nullptr) { return nullptr; } - auto file_encryption_properties = properties_->file_encryption_properties(); + auto fileEncryptionProperties = properties_->fileEncryptionProperties(); crypto_metadata_->__set_encryption_algorithm( - ToThrift(file_encryption_properties->algorithm())); - std::string key_metadata = - file_encryption_properties->footer_key_metadata(); + toThrift(fileEncryptionProperties->algorithm())); + std::string keyMetadata = fileEncryptionProperties->footerKeyMetadata(); - if (!key_metadata.empty()) { - crypto_metadata_->__set_key_metadata(key_metadata); + if (!keyMetadata.empty()) { + crypto_metadata_->__set_key_metadata(keyMetadata); } - std::unique_ptr file_crypto_metadata( + std::unique_ptr fileCryptoMetadata( new FileCryptoMetaData()); - file_crypto_metadata->impl_->metadata_ = std::move(*crypto_metadata_); - return file_crypto_metadata; + fileCryptoMetadata->impl_->metadata_ = std::move(*crypto_metadata_); + return fileCryptoMetadata; } protected: @@ -2281,23 +2352,42 @@ class FileMetaDataBuilder::FileMetaDataBuilderImpl { crypto_metadata_; private: + // Helper to accumulate NaN counts from the current row group builder. + void accumulateNaNCountsFromCurrentRowGroup() { + if (!currentRowGroupBuilder_) { + return; + } + auto rgNaNCounts = currentRowGroupBuilder_->nanCounts(); + // Accumulate NaN counts from this row group (keyed by field ID). + for (const auto& [fieldId, countPair] : rgNaNCounts) { + const auto& [count, has_count] = countPair; + if (has_count) { + fieldNanCounts_[fieldId].first += count; + fieldNanCounts_[fieldId].second = true; + } + } + } + const std::shared_ptr properties_; - std::vector row_groups_; + std::vector rowGroups_; - std::unique_ptr current_row_group_builder_; + std::unique_ptr currentRowGroupBuilder_; const SchemaDescriptor* schema_; - std::shared_ptr key_value_metadata_; + std::shared_ptr keyValueMetadata_; + // Total NaN counts per field ID across all row groups: field_id -> + // (nan_count, has_nan_count). + std::unordered_map> fieldNanCounts_; }; -std::unique_ptr FileMetaDataBuilder::Make( +std::unique_ptr FileMetaDataBuilder::make( const SchemaDescriptor* schema, std::shared_ptr props, - std::shared_ptr key_value_metadata) { + std::shared_ptr keyValueMetadata) { return std::unique_ptr(new FileMetaDataBuilder( - schema, std::move(props), std::move(key_value_metadata))); + schema, std::move(props), std::move(keyValueMetadata))); } -std::unique_ptr FileMetaDataBuilder::Make( +std::unique_ptr FileMetaDataBuilder::make( const SchemaDescriptor* schema, std::shared_ptr props) { return std::unique_ptr( @@ -2307,31 +2397,31 @@ std::unique_ptr FileMetaDataBuilder::Make( FileMetaDataBuilder::FileMetaDataBuilder( const SchemaDescriptor* schema, std::shared_ptr props, - std::shared_ptr key_value_metadata) + std::shared_ptr keyValueMetadata) : impl_{ std::unique_ptr(new FileMetaDataBuilderImpl( schema, std::move(props), - std::move(key_value_metadata)))} {} + std::move(keyValueMetadata)))} {} FileMetaDataBuilder::~FileMetaDataBuilder() = default; -RowGroupMetaDataBuilder* FileMetaDataBuilder::AppendRowGroup() { - return impl_->AppendRowGroup(); +RowGroupMetaDataBuilder* FileMetaDataBuilder::appendRowGroup() { + return impl_->appendRowGroup(); } -void FileMetaDataBuilder::SetPageIndexLocation( +void FileMetaDataBuilder::setPageIndexLocation( const PageIndexLocation& location) { - impl_->SetPageIndexLocation(location); + impl_->setPageIndexLocation(location); } -std::unique_ptr FileMetaDataBuilder::Finish( - const std::shared_ptr& key_value_metadata) { - return impl_->Finish(key_value_metadata); +std::unique_ptr FileMetaDataBuilder::finish( + const std::shared_ptr& keyValueMetadata) { + return impl_->finish(keyValueMetadata); } -std::unique_ptr FileMetaDataBuilder::GetCryptoMetaData() { - return impl_->BuildFileCryptoMetaData(); +std::unique_ptr FileMetaDataBuilder::getCryptoMetaData() { + return impl_->buildFileCryptoMetaData(); } } // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/Metadata.h b/velox/dwio/parquet/writer/arrow/Metadata.h index c69ee5a03d4..c6104a2a794 100644 --- a/velox/dwio/parquet/writer/arrow/Metadata.h +++ b/velox/dwio/parquet/writer/arrow/Metadata.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -54,16 +55,16 @@ using KeyValueMetadata = ::arrow::KeyValueMetadata; class PARQUET_EXPORT ApplicationVersion { public: - // Known Versions with Issues + // Known versions with issues. static const ApplicationVersion& PARQUET_251_FIXED_VERSION(); static const ApplicationVersion& PARQUET_816_FIXED_VERSION(); static const ApplicationVersion& PARQUET_CPP_FIXED_STATS_VERSION(); static const ApplicationVersion& PARQUET_MR_FIXED_STATS_VERSION(); static const ApplicationVersion& PARQUET_CPP_10353_FIXED_VERSION(); - // Application that wrote the file. e.g. "IMPALA" + // Application that wrote the file, e.g., "IMPALA". std::string application_; - // Build name + // Build name. std::string build_; // Version of the application that wrote the file, expressed as @@ -76,37 +77,37 @@ class PARQUET_EXPORT ApplicationVersion { int minor; int patch; std::string unknown; - std::string pre_release; - std::string build_info; + std::string preRelease; + std::string buildInfo; } version; ApplicationVersion() = default; - explicit ApplicationVersion(const std::string& created_by); + explicit ApplicationVersion(const std::string& createdBy); ApplicationVersion(std::string application, int major, int minor, int patch); - // Returns true if version is strictly less than other_version - bool VersionLt(const ApplicationVersion& other_version) const; + // Returns true if version is strictly less than otherVersion. + bool versionLt(const ApplicationVersion& otherVersion) const; - // Returns true if version is strictly equal with other_version - bool VersionEq(const ApplicationVersion& other_version) const; + // Returns true if version is strictly equal with otherVersion. + bool versionEq(const ApplicationVersion& otherVersion) const; - // Checks if the Version has the correct statistics for a given column - bool HasCorrectStatistics( + // Checks if the Version has the correct statistics for a given column. + bool hasCorrectStatistics( Type::type primitive, EncodedStatistics& statistics, - SortOrder::type sort_order = SortOrder::SIGNED) const; + SortOrder::type sortOrder = SortOrder::kSigned) const; }; class PARQUET_EXPORT ColumnCryptoMetaData { public: - static std::unique_ptr Make(const uint8_t* metadata); + static std::unique_ptr make(const uint8_t* metadata); ~ColumnCryptoMetaData(); - bool Equals(const ColumnCryptoMetaData& other) const; + bool equals(const ColumnCryptoMetaData& other) const; - std::shared_ptr path_in_schema() const; - bool encrypted_with_footer_key() const; - const std::string& key_metadata() const; + std::shared_ptr pathInSchema() const; + bool encryptedWithFooterKey() const; + const std::string& keyMetadata() const; private: explicit ColumnCryptoMetaData(const uint8_t* metadata); @@ -115,18 +116,18 @@ class PARQUET_EXPORT ColumnCryptoMetaData { std::unique_ptr impl_; }; -/// \brief Public struct for Thrift PageEncodingStats in ColumnChunkMetaData +/// \brief Public struct for Thrift PageEncodingStats in ColumnChunkMetaData. struct PageEncodingStats { - PageType::type page_type; + PageType::type pageType; Encoding::type encoding; int32_t count; }; /// \brief Public struct for location to page index in ColumnChunkMetaData. struct IndexLocation { - /// File offset of the given index, in bytes + /// File offset of the given index, in bytes. int64_t offset; - /// Length of the given index, in bytes + /// Length of the given index, in bytes. int32_t length; }; @@ -134,73 +135,74 @@ struct IndexLocation { /// facebook::velox::parquet::thrift::ColumnChunkMetaData. class PARQUET_EXPORT ColumnChunkMetaData { public: - // API convenience to get a MetaData accessor + // API convenience to get a MetaData accessor. ARROW_DEPRECATED("Use the ReaderProperties-taking overload") - static std::unique_ptr Make( + static std::unique_ptr make( const void* metadata, const ColumnDescriptor* descr, - const ApplicationVersion* writer_version, - int16_t row_group_ordinal = -1, - int16_t column_ordinal = -1, - std::shared_ptr file_decryptor = NULLPTR); + const ApplicationVersion* writerVersion, + int16_t rowGroupOrdinal = -1, + int16_t columnOrdinal = -1, + std::shared_ptr fileDecryptor = NULLPTR); - static std::unique_ptr Make( + static std::unique_ptr make( const void* metadata, const ColumnDescriptor* descr, - const ReaderProperties& properties = default_reader_properties(), - const ApplicationVersion* writer_version = NULLPTR, - int16_t row_group_ordinal = -1, - int16_t column_ordinal = -1, - std::shared_ptr file_decryptor = NULLPTR); + const ReaderProperties& properties = defaultReaderProperties(), + const ApplicationVersion* writerVersion = NULLPTR, + int16_t rowGroupOrdinal = -1, + int16_t columnOrdinal = -1, + std::shared_ptr fileDecryptor = NULLPTR); ~ColumnChunkMetaData(); - bool Equals(const ColumnChunkMetaData& other) const; + bool equals(const ColumnChunkMetaData& other) const; - // column chunk - int64_t file_offset() const; + // Column chunk. + int64_t fileOffset() const; - // parameter is only used when a dataset is spread across multiple files - const std::string& file_path() const; + // Parameter is only used when a dataset is spread across multiple files. + const std::string& filePath() const; - // column metadata - bool is_metadata_set() const; + // Column metadata. + bool isMetadataSet() const; Type::type type() const; - int64_t num_values() const; - std::shared_ptr path_in_schema() const; - bool is_stats_set() const; + int64_t numValues() const; + std::shared_ptr pathInSchema() const; + bool isStatsSet() const; std::shared_ptr statistics() const; Compression::type compression() const; // Indicate if the ColumnChunk compression is supported by the current - // compiled parquet library. - bool can_decompress() const; + // compiled Parquet library. + bool canDecompress() const; const std::vector& encodings() const; - const std::vector& encoding_stats() const; - std::optional bloom_filter_offset() const; - bool has_dictionary_page() const; - int64_t dictionary_page_offset() const; - int64_t data_page_offset() const; - bool has_index_page() const; - int64_t index_page_offset() const; - int64_t total_compressed_size() const; - int64_t total_uncompressed_size() const; - std::unique_ptr crypto_metadata() const; - std::optional GetColumnIndexLocation() const; - std::optional GetOffsetIndexLocation() const; + const std::vector& encodingStats() const; + std::optional bloomFilterOffset() const; + bool hasDictionaryPage() const; + int64_t dictionaryPageOffset() const; + int64_t dataPageOffset() const; + bool hasIndexPage() const; + int64_t indexPageOffset() const; + int64_t totalCompressedSize() const; + int64_t totalUncompressedSize() const; + int32_t fieldId() const; + std::unique_ptr cryptoMetadata() const; + std::optional getColumnIndexLocation() const; + std::optional getOffsetIndexLocation() const; private: explicit ColumnChunkMetaData( const void* metadata, const ColumnDescriptor* descr, - int16_t row_group_ordinal, - int16_t column_ordinal, + int16_t rowGroupOrdinal, + int16_t columnOrdinal, const ReaderProperties& properties, - const ApplicationVersion* writer_version = NULLPTR, - std::shared_ptr file_decryptor = NULLPTR); - // PIMPL Idiom + const ApplicationVersion* writerVersion = NULLPTR, + std::shared_ptr fileDecryptor = NULLPTR); + // PIMPL Idiom. class ColumnChunkMetaDataImpl; std::unique_ptr impl_; }; @@ -210,73 +212,73 @@ class PARQUET_EXPORT ColumnChunkMetaData { class PARQUET_EXPORT RowGroupMetaData { public: ARROW_DEPRECATED("Use the ReaderProperties-taking overload") - static std::unique_ptr Make( + static std::unique_ptr make( const void* metadata, const SchemaDescriptor* schema, - const ApplicationVersion* writer_version, - std::shared_ptr file_decryptor = NULLPTR); + const ApplicationVersion* writerVersion, + std::shared_ptr fileDecryptor = NULLPTR); /// \brief Create a RowGroupMetaData from a serialized thrift message. - static std::unique_ptr Make( + static std::unique_ptr make( const void* metadata, const SchemaDescriptor* schema, - const ReaderProperties& properties = default_reader_properties(), - const ApplicationVersion* writer_version = NULLPTR, - std::shared_ptr file_decryptor = NULLPTR); + const ReaderProperties& properties = defaultReaderProperties(), + const ApplicationVersion* writerVersion = NULLPTR, + std::shared_ptr fileDecryptor = NULLPTR); ~RowGroupMetaData(); - bool Equals(const RowGroupMetaData& other) const; + bool equals(const RowGroupMetaData& other) const; /// \brief The number of columns in this row group. The order must match the /// parent's column ordering. - int num_columns() const; + int numColumns() const; /// \brief Return the ColumnChunkMetaData of the corresponding column ordinal. /// - /// WARNING, the returned object references memory location in it's parent + /// WARNING: The returned object references memory location in its parent /// (RowGroupMetaData) object. Hence, the parent must outlive the returned /// object. /// - /// \param[in] index of the ColumnChunkMetaData to retrieve. + /// \param[in] index Index of the ColumnChunkMetaData to retrieve. /// /// \throws ParquetException if the index is out of bound. - std::unique_ptr ColumnChunk(int index) const; + std::unique_ptr columnChunk(int index) const; /// \brief Number of rows in this row group. - int64_t num_rows() const; + int64_t numRows() const; /// \brief Total byte size of all the uncompressed column data in this row /// group. - int64_t total_byte_size() const; + int64_t totalByteSize() const; /// \brief Total byte size of all the compressed (and potentially encrypted) /// column data in this row group. /// /// This information is optional and may be 0 if omitted. - int64_t total_compressed_size() const; + int64_t totalCompressedSize() const; /// \brief Byte offset from beginning of file to first page (data or - /// dictionary) in this row group + /// dictionary) in this row group. /// /// The file_offset field that this method exposes is optional. This method /// will return 0 if that field is not set to a meaningful value. - int64_t file_offset() const; - // Return const-pointer to make it clear that this object is not to be copied + int64_t fileOffset() const; + // Return const pointer to make it clear that this object is not to be copied. const SchemaDescriptor* schema() const; // Indicate if all of the RowGroup's ColumnChunks can be decompressed. - bool can_decompress() const; + bool canDecompress() const; // Sorting columns of the row group if any. - std::vector sorting_columns() const; + std::vector sortingColumns() const; private: explicit RowGroupMetaData( const void* metadata, const SchemaDescriptor* schema, const ReaderProperties& properties, - const ApplicationVersion* writer_version = NULLPTR, - std::shared_ptr file_decryptor = NULLPTR); - // PIMPL Idiom + const ApplicationVersion* writerVersion = NULLPTR, + std::shared_ptr fileDecryptor = NULLPTR); + // PIMPL Idiom. class RowGroupMetaDataImpl; std::unique_ptr impl_; }; @@ -288,74 +290,74 @@ class FileMetaDataBuilder; class PARQUET_EXPORT FileMetaData { public: ARROW_DEPRECATED("Use the ReaderProperties-taking overload") - static std::shared_ptr Make( - const void* serialized_metadata, - uint32_t* inout_metadata_len, - std::shared_ptr file_decryptor); + static std::shared_ptr make( + const void* serializedMetadata, + uint32_t* inoutMetadataLen, + std::shared_ptr fileDecryptor); /// \brief Create a FileMetaData from a serialized thrift message. - static std::shared_ptr Make( - const void* serialized_metadata, - uint32_t* inout_metadata_len, - const ReaderProperties& properties = default_reader_properties(), - std::shared_ptr file_decryptor = NULLPTR); + static std::shared_ptr make( + const void* serializedMetadata, + uint32_t* inoutMetadataLen, + const ReaderProperties& properties = defaultReaderProperties(), + std::shared_ptr fileDecryptor = NULLPTR); ~FileMetaData(); - bool Equals(const FileMetaData& other) const; + bool equals(const FileMetaData& other) const; - /// \brief The number of parquet "leaf" columns. + /// \brief The number of Parquet "leaf" columns. /// /// Parquet thrift definition requires that nested schema elements are /// flattened. This method returns the number of columns in the flattened /// version. - /// For instance, if the schema looks like this : - /// 0 foo.bar - /// foo.bar.baz 0 - /// foo.bar.baz2 1 - /// foo.qux 2 - /// 1 foo2 3 - /// 2 foo3 4 + /// For instance, if the schema looks like this: + /// 0 Foo.bar + /// Foo.bar.baz 0 + /// Foo.bar.baz2 1 + /// Foo.qux 2 + /// 1 Foo2 3 + /// 2 Foo3 4 /// This method will return 5, because there are 5 "leaf" fields (so 5 - /// flattened fields) - int num_columns() const; + /// flattened fields). + int numColumns() const; /// \brief The number of flattened schema elements. /// /// Parquet thrift definition requires that nested schema elements are /// flattened. This method returns the total number of elements in the /// flattened list. - int num_schema_elements() const; + int numSchemaElements() const; /// \brief The total number of rows. - int64_t num_rows() const; + int64_t numRows() const; /// \brief The number of row groups in the file. - int num_row_groups() const; + int numRowGroups() const; /// \brief Return the RowGroupMetaData of the corresponding row group ordinal. /// - /// WARNING, the returned object references memory location in it's parent + /// WARNING: The returned object references memory location in its parent /// (FileMetaData) object. Hence, the parent must outlive the returned object. /// - /// \param[in] index of the RowGroup to retrieve. + /// \param[in] index Index of the RowGroup to retrieve. /// /// \throws ParquetException if the index is out of bound. - std::unique_ptr RowGroup(int index) const; + std::unique_ptr rowGroup(int index) const; - /// \brief Return the "version" of the file + /// \brief Return the "version" of the file. /// - /// WARNING: The value returned by this method is unreliable as 1) the Parquet - /// file metadata stores the version as a single integer and 2) some producers - /// are known to always write a hardcoded value. Therefore, you cannot use - /// this value to know which features are used in the file. + /// WARNING: The value returned by this method is unreliable as 1) the + /// Parquet file metadata stores the version as a single integer and 2) some + /// producers are known to always write a hardcoded value. Therefore, you + /// cannot use this value to know which features are used in the file. ParquetVersion::type version() const; /// \brief Return the application's user-agent string of the writer. - const std::string& created_by() const; + const std::string& createdBy() const; /// \brief Return the application's version of the writer. - const ApplicationVersion& writer_version() const; + const ApplicationVersion& writerVersion() const; /// \brief Size of the original thrift encoded metadata footer. uint32_t size() const; @@ -364,71 +366,75 @@ class PARQUET_EXPORT FileMetaData { /// decompressed. /// /// This will return false if any of the RowGroup's page is compressed with a - /// compression format which is not compiled in the current parquet library. - bool can_decompress() const; + /// compression format which is not compiled in the current Parquet library. + bool canDecompress() const; - bool is_encryption_algorithm_set() const; - EncryptionAlgorithm encryption_algorithm() const; - const std::string& footer_signing_key_metadata() const; + bool isEncryptionAlgorithmSet() const; + EncryptionAlgorithm encryptionAlgorithm() const; + const std::string& footerSigningKeyMetadata() const; /// \brief Verify signature of FileMetaData when file is encrypted but footer /// is not encrypted (plaintext footer). - bool VerifySignature(const void* signature); + bool verifySignature(const void* signature); - void WriteTo( + void writeTo( ::arrow::io::OutputStream* dst, const std::shared_ptr& encryptor = NULLPTR) const; /// \brief Return Thrift-serialized representation of the metadata as a - /// string - std::string SerializeToString() const; + /// string. + std::string serializeToString() const; - // Return const-pointer to make it clear that this object is not to be copied + // Return const pointer to make it clear that this object is not to be copied. const SchemaDescriptor* schema() const; - const std::shared_ptr& key_value_metadata() const; + const std::shared_ptr& keyValueMetadata() const; /// \brief Set a path to all ColumnChunk for all RowGroups. /// - /// Commonly used by systems (Dask, Spark) who generates an metadata-only - /// parquet file. The path is usually relative to said index file. + /// Commonly used by systems (Dask, Spark) that generate a metadata-only + /// Parquet file. The path is usually relative to said index file. /// - /// \param[in] path to set. - void set_file_path(const std::string& path); + /// \param[in] path Path to set. + void setFilePath(const std::string& path); /// \brief Merge row groups from another metadata file into this one. /// /// The schema of the input FileMetaData must be equal to the /// schema of this object. /// - /// This is used by systems who creates an aggregate metadata-only file by + /// This is used by systems that create an aggregate metadata-only file by /// concatenating the row groups of multiple files. This newly created /// metadata file acts as an index of all available row groups. /// - /// \param[in] other FileMetaData to merge the row groups from. + /// \param[in] other Other FileMetaData to merge the row groups from. /// /// \throws ParquetException if schemas are not equal. - void AppendRowGroups(const FileMetaData& other); + void appendRowGroups(const FileMetaData& other); - /// \brief Return a FileMetaData containing a subset of the row groups in this - /// FileMetaData. - std::shared_ptr Subset( - const std::vector& row_groups) const; + /// \brief Return a FileMetaData containing a subset of the row groups in + /// this FileMetaData. + std::shared_ptr subset(const std::vector& rowGroups) const; + + /// \brief Get total NaN count for a specific field ID across all row groups. + /// Returns a pair of (nanCount, hasNanCount). + /// NaN counts are collected during writing but not written to the Parquet + /// file. + std::pair getNaNCount(int32_t fieldId) const; private: friend FileMetaDataBuilder; friend class SerializedFile; explicit FileMetaData( - const void* serialized_metadata, - uint32_t* metadata_len, + const void* serializedMetadata, + uint32_t* metadataLen, const ReaderProperties& properties, - std::shared_ptr file_decryptor = NULLPTR); + std::shared_ptr fileDecryptor = NULLPTR); - void set_file_decryptor( - std::shared_ptr file_decryptor); + void setFileDecryptor(std::shared_ptr fileDecryptor); - // PIMPL Idiom + // PIMPL Idiom. FileMetaData(); class FileMetaDataImpl; std::unique_ptr impl_; @@ -436,75 +442,81 @@ class PARQUET_EXPORT FileMetaData { class PARQUET_EXPORT FileCryptoMetaData { public: - // API convenience to get a MetaData accessor - static std::shared_ptr Make( - const uint8_t* serialized_metadata, - uint32_t* metadata_len, - const ReaderProperties& properties = default_reader_properties()); + // API convenience to get a MetaData accessor. + static std::shared_ptr make( + const uint8_t* serializedMetadata, + uint32_t* metadataLen, + const ReaderProperties& properties = defaultReaderProperties()); ~FileCryptoMetaData(); - EncryptionAlgorithm encryption_algorithm() const; - const std::string& key_metadata() const; + EncryptionAlgorithm encryptionAlgorithm() const; + const std::string& keyMetadata() const; - void WriteTo(::arrow::io::OutputStream* dst) const; + void writeTo(::arrow::io::OutputStream* dst) const; private: friend FileMetaDataBuilder; FileCryptoMetaData( - const uint8_t* serialized_metadata, - uint32_t* metadata_len, + const uint8_t* serializedMetadata, + uint32_t* metadataLen, const ReaderProperties& properties); - // PIMPL Idiom + // PIMPL Idiom. FileCryptoMetaData(); class FileCryptoMetaDataImpl; std::unique_ptr impl_; }; -// Builder API +// Builder API. class PARQUET_EXPORT ColumnChunkMetaDataBuilder { public: - // API convenience to get a MetaData reader - static std::unique_ptr Make( + // API convenience to get a MetaData reader. + static std::unique_ptr make( std::shared_ptr props, const ColumnDescriptor* column); - static std::unique_ptr Make( + static std::unique_ptr make( std::shared_ptr props, const ColumnDescriptor* column, - void* contents); + void* Contents); ~ColumnChunkMetaDataBuilder(); - // column chunk - // Used when a dataset is spread across multiple files - void set_file_path(const std::string& path); - // column metadata - void SetStatistics(const EncodedStatistics& stats); - // get the column descriptor + // Column chunk. + // Used when a dataset is spread across multiple files. + void setFilePath(const std::string& path); + // Column metadata. + void setStatistics(const EncodedStatistics& stats); + // Get the column descriptor. const ColumnDescriptor* descr() const; - int64_t total_compressed_size() const; - // commit the metadata - - void Finish( - int64_t num_values, - int64_t dictionary_page_offset, - int64_t index_page_offset, - int64_t data_page_offset, - int64_t compressed_size, - int64_t uncompressed_size, - bool has_dictionary, - bool dictionary_fallback, - const std::map& dict_encoding_stats_, - const std::map& data_encoding_stats_, + int64_t totalCompressedSize() const; + + // NaN count accessors - NaN counts are collected during writing but not + // written to the parquet file. + int64_t nanCount() const; + + bool hasNanCount() const; + + // Commit the metadata. + void finish( + int64_t numValues, + int64_t dictionaryPageOffset, + int64_t indexPageOffset, + int64_t dataPageOffset, + int64_t compressedSize, + int64_t uncompressedSize, + bool hasDictionary, + bool dictionaryFallback, + const std::map& dictEncodingStats, + const std::map& dataEncodingStats, const std::shared_ptr& encryptor = NULLPTR); - // The metadata contents, suitable for passing to ColumnChunkMetaData::Make - const void* contents() const; + // The metadata contents, suitable for passing to ColumnChunkMetaData::Make. + const void* Contents() const; - // For writing metadata at end of column chunk - void WriteTo(::arrow::io::OutputStream* sink); + // For writing metadata at end of column chunk. + void writeTo(::arrow::io::OutputStream* sink); private: explicit ColumnChunkMetaDataBuilder( @@ -513,97 +525,101 @@ class PARQUET_EXPORT ColumnChunkMetaDataBuilder { explicit ColumnChunkMetaDataBuilder( std::shared_ptr props, const ColumnDescriptor* column, - void* contents); - // PIMPL Idiom + void* Contents); + // PIMPL Idiom. class ColumnChunkMetaDataBuilderImpl; std::unique_ptr impl_; }; class PARQUET_EXPORT RowGroupMetaDataBuilder { public: - // API convenience to get a MetaData reader - static std::unique_ptr Make( + // API convenience to get a MetaData reader. + static std::unique_ptr make( std::shared_ptr props, const SchemaDescriptor* schema_, - void* contents); + void* Contents); ~RowGroupMetaDataBuilder(); - ColumnChunkMetaDataBuilder* NextColumnChunk(); - int num_columns(); - int64_t num_rows(); - int current_column() const; + ColumnChunkMetaDataBuilder* nextColumnChunk(); + int numColumns(); + int64_t numRows(); + int currentColumn() const; + + void setNumRows(int64_t numRows); - void set_num_rows(int64_t num_rows); + // Get NaN counts for all columns in current row group. + // Returns a map of field_id -> (nan_count, has_nan_count). + std::unordered_map> nanCounts() const; - // commit the metadata - void Finish(int64_t total_bytes_written, int16_t row_group_ordinal = -1); + // Commit the metadata. + void finish(int64_t totalBytesWritten, int16_t rowGroupOrdinal = -1); private: explicit RowGroupMetaDataBuilder( std::shared_ptr props, const SchemaDescriptor* schema_, - void* contents); - // PIMPL Idiom + void* Contents); + // PIMPL Idiom. class RowGroupMetaDataBuilderImpl; std::unique_ptr impl_; }; -/// \brief Public struct for location to all page indexes in a parquet file. +/// \brief Public struct for location to all page indexes in a Parquet file. struct PageIndexLocation { /// Alias type of page index location of a row group. The index location /// is located by column ordinal. If the column does not have the page index, /// its value is set to std::nullopt. using RowGroupIndexLocation = std::vector>; - /// Alias type of page index location of a parquet file. The index location + /// Alias type of page index location of a Parquet file. The index location /// is located by the row group ordinal. using FileIndexLocation = std::map; - /// Row group column index locations which uses row group ordinal as the key. - FileIndexLocation column_index_location; - /// Row group offset index locations which uses row group ordinal as the key. - FileIndexLocation offset_index_location; + /// Row group column index locations which use row group ordinal as the key. + FileIndexLocation columnIndexLocation; + /// Row group offset index locations which use row group ordinal as the key. + FileIndexLocation offsetIndexLocation; }; class PARQUET_EXPORT FileMetaDataBuilder { public: ARROW_DEPRECATED( "Deprecated in 12.0.0. Use overload without KeyValueMetadata instead.") - static std::unique_ptr Make( + static std::unique_ptr make( const SchemaDescriptor* schema, std::shared_ptr props, - std::shared_ptr key_value_metadata); + std::shared_ptr keyValueMetadata); - // API convenience to get a MetaData builder - static std::unique_ptr Make( + // API convenience to get a MetaData builder. + static std::unique_ptr make( const SchemaDescriptor* schema, std::shared_ptr props); ~FileMetaDataBuilder(); - // The prior RowGroupMetaDataBuilder (if any) is destroyed - RowGroupMetaDataBuilder* AppendRowGroup(); + // The prior RowGroupMetaDataBuilder (if any) is destroyed. + RowGroupMetaDataBuilder* appendRowGroup(); - // Update location to all page indexes in the parquet file - void SetPageIndexLocation(const PageIndexLocation& location); + // Update location to all page indexes in the Parquet file. + void setPageIndexLocation(const PageIndexLocation& location); - // Complete the Thrift structure - std::unique_ptr Finish( - const std::shared_ptr& key_value_metadata = + // Complete the Thrift structure. + std::unique_ptr finish( + const std::shared_ptr& keyValueMetadata = NULLPTR); - // crypto metadata - std::unique_ptr GetCryptoMetaData(); + // Crypto metadata. + std::unique_ptr getCryptoMetaData(); private: explicit FileMetaDataBuilder( const SchemaDescriptor* schema, std::shared_ptr props, - std::shared_ptr key_value_metadata = NULLPTR); - // PIMPL Idiom + std::shared_ptr keyValueMetadata = NULLPTR); + // PIMPL Idiom. class FileMetaDataBuilderImpl; std::unique_ptr impl_; }; -PARQUET_EXPORT std::string ParquetVersionToString(ParquetVersion::type ver); +PARQUET_EXPORT std::string parquetVersionToString(ParquetVersion::type ver); } // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/PageIndex.cpp b/velox/dwio/parquet/writer/arrow/PageIndex.cpp index 465d6e1afec..e07538e83ba 100644 --- a/velox/dwio/parquet/writer/arrow/PageIndex.cpp +++ b/velox/dwio/parquet/writer/arrow/PageIndex.cpp @@ -36,55 +36,54 @@ namespace facebook::velox::parquet::arrow { namespace { template -void Decode( +void decode( std::unique_ptr::Decoder>& decoder, const std::string& input, - std::vector* output, - size_t output_index) { - if (ARROW_PREDICT_FALSE(output_index >= output->size())) { + std::vector* output, + size_t outputIndex) { + if (ARROW_PREDICT_FALSE(outputIndex >= output->size())) { throw ParquetException("Index out of bound"); } - decoder->SetData( - /*num_values=*/1, + decoder->setData( + 1, reinterpret_cast(input.c_str()), static_cast(input.size())); - const auto num_values = - decoder->Decode(&output->at(output_index), /*max_values=*/1); - if (ARROW_PREDICT_FALSE(num_values != 1)) { + const auto numValues = decoder->decode(&output->at(outputIndex), 1); + if (ARROW_PREDICT_FALSE(numValues != 1)) { throw ParquetException("Could not decode statistics value"); } } template <> -void Decode( +void decode( std::unique_ptr& decoder, const std::string& input, std::vector* output, - size_t output_index) { - if (ARROW_PREDICT_FALSE(output_index >= output->size())) { + size_t outputIndex) { + if (ARROW_PREDICT_FALSE(outputIndex >= output->size())) { throw ParquetException("Index out of bound"); } bool value; - decoder->SetData( - /*num_values=*/1, + decoder->setData( + 1, reinterpret_cast(input.c_str()), static_cast(input.size())); - const auto num_values = decoder->Decode(&value, /*max_values=*/1); - if (ARROW_PREDICT_FALSE(num_values != 1)) { + const auto numValues = decoder->decode(&value, 1); + if (ARROW_PREDICT_FALSE(numValues != 1)) { throw ParquetException("Could not decode statistics value"); } - output->at(output_index) = value; + output->at(outputIndex) = value; } template <> -void Decode( +void decode( std::unique_ptr&, const std::string& input, std::vector* output, - size_t output_index) { - if (ARROW_PREDICT_FALSE(output_index >= output->size())) { + size_t outputIndex) { + if (ARROW_PREDICT_FALSE(outputIndex >= output->size())) { throw ParquetException("Index out of bound"); } @@ -94,273 +93,265 @@ void Decode( throw ParquetException("Invalid encoded byte array length"); } - output->at(output_index) = { - /*len=*/static_cast(input.size()), - /*ptr=*/reinterpret_cast(input.data())}; + output->at(outputIndex) = { + static_cast(input.size()), + reinterpret_cast(input.data())}; } template class TypedColumnIndexImpl : public TypedColumnIndex { public: - using T = typename DType::c_type; + using T = typename DType::CType; TypedColumnIndexImpl( const ColumnDescriptor& descr, - facebook::velox::parquet::thrift::ColumnIndex column_index) - : column_index_(std::move(column_index)) { + facebook::velox::parquet::thrift::ColumnIndex columnIndex) + : columnIndex_(std::move(columnIndex)) { // Make sure the number of pages is valid and it does not overflow to // int32_t. - const size_t num_pages = column_index_.null_pages.size(); - if (num_pages >= static_cast(std::numeric_limits::max()) || - column_index_.min_values.size() != num_pages || - column_index_.max_values.size() != num_pages || - (column_index_.__isset.null_counts && - column_index_.null_counts.size() != num_pages)) { + const size_t numPages = columnIndex_.null_pages.size(); + if (numPages >= static_cast(std::numeric_limits::max()) || + columnIndex_.min_values.size() != numPages || + columnIndex_.max_values.size() != numPages || + (columnIndex_.__isset.null_counts && + columnIndex_.null_counts.size() != numPages)) { throw ParquetException("Invalid column index"); } - const size_t num_non_null_pages = static_cast(std::accumulate( - column_index_.null_pages.cbegin(), - column_index_.null_pages.cend(), + const size_t numNonNullPages = static_cast(std::accumulate( + columnIndex_.null_pages.cbegin(), + columnIndex_.null_pages.cend(), 0, - [](int32_t num_non_null_pages, bool null_page) { - return num_non_null_pages + (null_page ? 0 : 1); + [](int32_t numNonNullPages, bool nullPage) { + return numNonNullPages + (nullPage ? 0 : 1); })); - VELOX_DCHECK_LE(num_non_null_pages, num_pages); + VELOX_DCHECK_LE(numNonNullPages, numPages); // Allocate slots for decoded values. - min_values_.resize(num_pages); - max_values_.resize(num_pages); - non_null_page_indices_.reserve(num_non_null_pages); + minValues_.resize(numPages); + maxValues_.resize(numPages); + nonNullPageIndices_.reserve(numNonNullPages); // Decode min and max values according to the physical type. - // Note that null page are skipped. - auto plain_decoder = MakeTypedDecoder(Encoding::PLAIN, &descr); - for (size_t i = 0; i < num_pages; ++i) { - if (!column_index_.null_pages[i]) { - // The check on `num_pages` has guaranteed the cast below is safe. - non_null_page_indices_.emplace_back(static_cast(i)); - Decode( - plain_decoder, column_index_.min_values[i], &min_values_, i); - Decode( - plain_decoder, column_index_.max_values[i], &max_values_, i); + // Note that null pages are skipped. + auto plainDecoder = makeTypedDecoder(Encoding::kPlain, &descr); + for (size_t i = 0; i < numPages; ++i) { + if (!columnIndex_.null_pages[i]) { + // The check on `numPages` has guaranteed the cast below is safe. + nonNullPageIndices_.emplace_back(static_cast(i)); + decode(plainDecoder, columnIndex_.min_values[i], &minValues_, i); + decode(plainDecoder, columnIndex_.max_values[i], &maxValues_, i); } } - VELOX_DCHECK_EQ(num_non_null_pages, non_null_page_indices_.size()); + VELOX_DCHECK_EQ(numNonNullPages, nonNullPageIndices_.size()); } - const std::vector& null_pages() const override { - return column_index_.null_pages; + const std::vector& nullPages() const override { + return columnIndex_.null_pages; } - const std::vector& encoded_min_values() const override { - return column_index_.min_values; + const std::vector& encodedMinValues() const override { + return columnIndex_.min_values; } - const std::vector& encoded_max_values() const override { - return column_index_.max_values; + const std::vector& encodedMaxValues() const override { + return columnIndex_.max_values; } - BoundaryOrder::type boundary_order() const override { - return LoadEnumSafe(&column_index_.boundary_order); + BoundaryOrder::type boundaryOrder() const override { + return loadenumSafe(&columnIndex_.boundary_order); } - bool has_null_counts() const override { - return column_index_.__isset.null_counts; + bool hasNullCounts() const override { + return columnIndex_.__isset.null_counts; } - const std::vector& null_counts() const override { - return column_index_.null_counts; + const std::vector& nullCounts() const override { + return columnIndex_.null_counts; } - const std::vector& non_null_page_indices() const override { - return non_null_page_indices_; + const std::vector& nonNullPageIndices() const override { + return nonNullPageIndices_; } - const std::vector& min_values() const override { - return min_values_; + const std::vector& minValues() const override { + return minValues_; } - const std::vector& max_values() const override { - return max_values_; + const std::vector& maxValues() const override { + return maxValues_; } private: - /// Wrapped thrift column index. - const facebook::velox::parquet::thrift::ColumnIndex column_index_; + /// Wrapped Thrift column index. + const facebook::velox::parquet::thrift::ColumnIndex columnIndex_; /// Decoded typed min/max values. Undefined for null pages. - std::vector min_values_; - std::vector max_values_; + std::vector minValues_; + std::vector maxValues_; /// A list of page indices for non-null pages. - std::vector non_null_page_indices_; + std::vector nonNullPageIndices_; }; class OffsetIndexImpl : public OffsetIndex { public: explicit OffsetIndexImpl( - const facebook::velox::parquet::thrift::OffsetIndex& offset_index) { - page_locations_.reserve(offset_index.page_locations.size()); - for (const auto& page_location : offset_index.page_locations) { - page_locations_.emplace_back(PageLocation{ - page_location.offset, - page_location.compressed_page_size, - page_location.first_row_index}); + const facebook::velox::parquet::thrift::OffsetIndex& offsetIndex) { + pageLocations_.reserve(offsetIndex.page_locations.size()); + for (const auto& pageLocation : offsetIndex.page_locations) { + pageLocations_.emplace_back( + PageLocation{ + pageLocation.offset, + pageLocation.compressed_page_size, + pageLocation.first_row_index}); } } - const std::vector& page_locations() const override { - return page_locations_; + const std::vector& pageLocations() const override { + return pageLocations_; } private: - std::vector page_locations_; + std::vector pageLocations_; }; class RowGroupPageIndexReaderImpl : public RowGroupPageIndexReader { public: RowGroupPageIndexReaderImpl( ::arrow::io::RandomAccessFile* input, - std::shared_ptr row_group_metadata, + std::shared_ptr rowGroupMetadata, const ReaderProperties& properties, - int32_t row_group_ordinal, - const RowGroupIndexReadRange& index_read_range, - std::shared_ptr file_decryptor) + int32_t rowGroupOrdinal, + const RowGroupIndexReadRange& indexReadRange, + std::shared_ptr fileDecryptor) : input_(input), - row_group_metadata_(std::move(row_group_metadata)), + rowGroupMetadata_(std::move(rowGroupMetadata)), properties_(properties), - row_group_ordinal_(row_group_ordinal), - index_read_range_(index_read_range), - file_decryptor_(std::move(file_decryptor)) {} + rowGroupOrdinal_(rowGroupOrdinal), + indexReadRange_(indexReadRange), + fileDecryptor_(std::move(fileDecryptor)) {} /// Read column index of a column chunk. - std::shared_ptr GetColumnIndex(int32_t i) override { - if (i < 0 || i >= row_group_metadata_->num_columns()) { + std::shared_ptr getColumnIndex(int32_t i) override { + if (i < 0 || i >= rowGroupMetadata_->numColumns()) { throw ParquetException("Invalid column index at column ordinal ", i); } - auto col_chunk = row_group_metadata_->ColumnChunk(i); - std::unique_ptr crypto_metadata = - col_chunk->crypto_metadata(); - if (crypto_metadata != nullptr) { + auto colChunk = rowGroupMetadata_->columnChunk(i); + std::unique_ptr cryptoMetadata = + colChunk->cryptoMetadata(); + if (cryptoMetadata != nullptr) { ParquetException::NYI("Cannot read encrypted column index yet"); } - auto column_index_location = col_chunk->GetColumnIndexLocation(); - if (!column_index_location.has_value()) { + auto columnIndexLocation = colChunk->getColumnIndexLocation(); + if (!columnIndexLocation.has_value()) { return nullptr; } - CheckReadRangeOrThrow( - *column_index_location, - index_read_range_.column_index, - row_group_ordinal_); + checkReadRangeOrThrow( + *columnIndexLocation, indexReadRange_.columnIndex, rowGroupOrdinal_); - if (column_index_buffer_ == nullptr) { + if (columnIndexBuffer_ == nullptr) { PARQUET_ASSIGN_OR_THROW( - column_index_buffer_, + columnIndexBuffer_, input_->ReadAt( - index_read_range_.column_index->offset, - index_read_range_.column_index->length)); + indexReadRange_.columnIndex->offset, + indexReadRange_.columnIndex->length)); } - int64_t buffer_offset = - column_index_location->offset - index_read_range_.column_index->offset; - // ColumnIndex::Make() requires the type of serialized thrift message to be - // uint32_t - uint32_t length = static_cast(column_index_location->length); - auto descr = row_group_metadata_->schema()->Column(i); - return ColumnIndex::Make( - *descr, - column_index_buffer_->data() + buffer_offset, - length, - properties_); + int64_t bufferoffset = + columnIndexLocation->offset - indexReadRange_.columnIndex->offset; + // ColumnIndex::make() requires the type of serialized Thrift message to be + // uint32_t. + uint32_t length = static_cast(columnIndexLocation->length); + auto descr = rowGroupMetadata_->schema()->column(i); + return ColumnIndex::make( + *descr, columnIndexBuffer_->data() + bufferoffset, length, properties_); } /// Read offset index of a column chunk. - std::shared_ptr GetOffsetIndex(int32_t i) override { - if (i < 0 || i >= row_group_metadata_->num_columns()) { + std::shared_ptr getOffsetIndex(int32_t i) override { + if (i < 0 || i >= rowGroupMetadata_->numColumns()) { throw ParquetException("Invalid offset index at column ordinal ", i); } - auto col_chunk = row_group_metadata_->ColumnChunk(i); - std::unique_ptr crypto_metadata = - col_chunk->crypto_metadata(); - if (crypto_metadata != nullptr) { + auto colChunk = rowGroupMetadata_->columnChunk(i); + std::unique_ptr cryptoMetadata = + colChunk->cryptoMetadata(); + if (cryptoMetadata != nullptr) { ParquetException::NYI("Cannot read encrypted offset index yet"); } - auto offset_index_location = col_chunk->GetOffsetIndexLocation(); - if (!offset_index_location.has_value()) { + auto offsetIndexLocation = colChunk->getOffsetIndexLocation(); + if (!offsetIndexLocation.has_value()) { return nullptr; } - CheckReadRangeOrThrow( - *offset_index_location, - index_read_range_.offset_index, - row_group_ordinal_); + checkReadRangeOrThrow( + *offsetIndexLocation, indexReadRange_.offsetIndex, rowGroupOrdinal_); - if (offset_index_buffer_ == nullptr) { + if (offsetIndexBuffer_ == nullptr) { PARQUET_ASSIGN_OR_THROW( - offset_index_buffer_, + offsetIndexBuffer_, input_->ReadAt( - index_read_range_.offset_index->offset, - index_read_range_.offset_index->length)); + indexReadRange_.offsetIndex->offset, + indexReadRange_.offsetIndex->length)); } - int64_t buffer_offset = - offset_index_location->offset - index_read_range_.offset_index->offset; - // OffsetIndex::Make() requires the type of serialized thrift message to be - // uint32_t - uint32_t length = static_cast(offset_index_location->length); - return OffsetIndex::Make( - offset_index_buffer_->data() + buffer_offset, length, properties_); + int64_t bufferoffset = + offsetIndexLocation->offset - indexReadRange_.offsetIndex->offset; + // OffsetIndex::make() requires the type of serialized Thrift message to be + // uint32_t. + uint32_t length = static_cast(offsetIndexLocation->length); + return OffsetIndex::make( + offsetIndexBuffer_->data() + bufferoffset, length, properties_); } private: - static void CheckReadRangeOrThrow( - const IndexLocation& index_location, - const std::optional<::arrow::io::ReadRange>& index_read_range, - int32_t row_group_ordinal) { - if (!index_read_range.has_value()) { + static void checkReadRangeOrThrow( + const IndexLocation& indexLocation, + const std::optional<::arrow::io::ReadRange>& indexReadRange, + int32_t rowGroupOrdinal) { + if (!indexReadRange.has_value()) { throw ParquetException( "Missing page index read range of row group ", - row_group_ordinal, + rowGroupOrdinal, ", it may not exist or has not been requested"); } /// The coalesced read range is invalid. - if (index_read_range->offset < 0 || index_read_range->length <= 0) { + if (indexReadRange->offset < 0 || indexReadRange->length <= 0) { throw ParquetException( "Invalid page index read range: offset ", - index_read_range->offset, + indexReadRange->offset, " length ", - index_read_range->length); + indexReadRange->length); } /// The location to page index itself is corrupted. - if (index_location.offset < 0 || index_location.length <= 0) { + if (indexLocation.offset < 0 || indexLocation.length <= 0) { throw ParquetException( "Invalid page index location: offset ", - index_location.offset, + indexLocation.offset, " length ", - index_location.length); + indexLocation.length); } /// Page index location must be within the range of the read range. - if (index_location.offset < index_read_range->offset || - index_location.offset + index_location.length > - index_read_range->offset + index_read_range->length) { + if (indexLocation.offset < indexReadRange->offset || + indexLocation.offset + indexLocation.length > + indexReadRange->offset + indexReadRange->length) { throw ParquetException( "Page index location [offset:", - index_location.offset, + indexLocation.offset, ",length:", - index_location.length, + indexLocation.length, "] is out of range from previous WillNeed request [offset:", - index_read_range->offset, + indexReadRange->offset, ",length:", - index_read_range->length, + indexReadRange->length, "], row group: ", - row_group_ordinal); + rowGroupOrdinal); } } @@ -369,107 +360,107 @@ class RowGroupPageIndexReaderImpl : public RowGroupPageIndexReader { ::arrow::io::RandomAccessFile* input_; /// The row group metadata to get column chunk metadata. - std::shared_ptr row_group_metadata_; + std::shared_ptr rowGroupMetadata_; /// Reader properties used to deserialize thrift object. const ReaderProperties& properties_; /// The ordinal of the row group in the file. - int32_t row_group_ordinal_; + int32_t rowGroupOrdinal_; /// File offsets and sizes of the page Index of all column chunks in the row /// group. - RowGroupIndexReadRange index_read_range_; + RowGroupIndexReadRange indexReadRange_; /// File-level decryptor. - std::shared_ptr file_decryptor_; + std::shared_ptr fileDecryptor_; /// Buffer to hold the raw bytes of the page index. /// Will be set lazily when the corresponding page index is accessed for the - /// 1st time. - std::shared_ptr<::arrow::Buffer> column_index_buffer_; - std::shared_ptr<::arrow::Buffer> offset_index_buffer_; + /// first time. + std::shared_ptr<::arrow::Buffer> columnIndexBuffer_; + std::shared_ptr<::arrow::Buffer> offsetIndexBuffer_; }; class PageIndexReaderImpl : public PageIndexReader { public: PageIndexReaderImpl( ::arrow::io::RandomAccessFile* input, - std::shared_ptr file_metadata, + std::shared_ptr fileMetadata, const ReaderProperties& properties, - std::shared_ptr file_decryptor) + std::shared_ptr fileDecryptor) : input_(input), - file_metadata_(std::move(file_metadata)), + fileMetadata_(std::move(fileMetadata)), properties_(properties), - file_decryptor_(std::move(file_decryptor)) {} + fileDecryptor_(std::move(fileDecryptor)) {} - std::shared_ptr RowGroup(int i) override { - if (i < 0 || i >= file_metadata_->num_row_groups()) { + std::shared_ptr rowGroup(int i) override { + if (i < 0 || i >= fileMetadata_->numRowGroups()) { throw ParquetException("Invalid row group ordinal: ", i); } - auto row_group_metadata = file_metadata_->RowGroup(i); + auto rowGroupMetadata = fileMetadata_->rowGroup(i); // Find the read range of the page index of the row group if provided by - // WillNeed() - RowGroupIndexReadRange index_read_range; - auto iter = index_read_ranges_.find(i); - if (iter != index_read_ranges_.cend()) { + // WillNeed(). + RowGroupIndexReadRange indexReadRange; + auto iter = indexReadRanges_.find(i); + if (iter != indexReadRanges_.cend()) { /// This row group has been requested by WillNeed(). Only column index /// and/or offset index of requested columns can be read. - index_read_range = iter->second; + indexReadRange = iter->second; } else { - /// If the row group has not been requested by WillNeed(), by default both - /// column index and offset index of all column chunks for the row group - /// can be read. - index_read_range = PageIndexReader::DeterminePageIndexRangesInRowGroup( - *row_group_metadata, {}); + /// If the row group has not been requested by WillNeed(), by default + /// both column index and offset index of all column chunks for the row + /// group can be read. + indexReadRange = PageIndexReader::determinePageIndexRangesInRowGroup( + *rowGroupMetadata, {}); } - if (index_read_range.column_index.has_value() || - index_read_range.offset_index.has_value()) { + if (indexReadRange.columnIndex.has_value() || + indexReadRange.offsetIndex.has_value()) { return std::make_shared( input_, - std::move(row_group_metadata), + std::move(rowGroupMetadata), properties_, i, - index_read_range, - file_decryptor_); + indexReadRange, + fileDecryptor_); } /// The row group does not has page index or has not been requested by - /// WillNeed(). Simply returns nullptr. + /// willNeed(). Simply returns nullptr. return nullptr; } - void WillNeed( - const std::vector& row_group_indices, - const std::vector& column_indices, + void willNeed( + const std::vector& rowGroupIndices, + const std::vector& columnIndices, const PageIndexSelection& selection) override { - std::vector<::arrow::io::ReadRange> read_ranges; - for (int32_t row_group_ordinal : row_group_indices) { - auto read_range = PageIndexReader::DeterminePageIndexRangesInRowGroup( - *file_metadata_->RowGroup(row_group_ordinal), column_indices); - if (selection.column_index && read_range.column_index.has_value()) { - read_ranges.push_back(*read_range.column_index); + std::vector<::arrow::io::ReadRange> readRanges; + for (int32_t rowGroupOrdinal : rowGroupIndices) { + auto readRange = PageIndexReader::determinePageIndexRangesInRowGroup( + *fileMetadata_->rowGroup(rowGroupOrdinal), columnIndices); + if (selection.columnIndex && readRange.columnIndex.has_value()) { + readRanges.push_back(*readRange.columnIndex); } else { // Mark the column index as not requested. - read_range.column_index = std::nullopt; + readRange.columnIndex = std::nullopt; } - if (selection.offset_index && read_range.offset_index.has_value()) { - read_ranges.push_back(*read_range.offset_index); + if (selection.offsetIndex && readRange.offsetIndex.has_value()) { + readRanges.push_back(*readRange.offsetIndex); } else { // Mark the offset index as not requested. - read_range.offset_index = std::nullopt; + readRange.offsetIndex = std::nullopt; } - index_read_ranges_.emplace(row_group_ordinal, std::move(read_range)); + indexReadRanges_.emplace(rowGroupOrdinal, std::move(readRange)); } - PARQUET_THROW_NOT_OK(input_->WillNeed(read_ranges)); + PARQUET_THROW_NOT_OK(input_->WillNeed(readRanges)); } - void WillNotNeed(const std::vector& row_group_indices) override { - for (int32_t row_group_ordinal : row_group_indices) { - index_read_ranges_.erase(row_group_ordinal); + void willNotNeed(const std::vector& rowGroupIndices) override { + for (int32_t rowGroupOrdinal : rowGroupIndices) { + indexReadRanges_.erase(rowGroupOrdinal); } } @@ -478,17 +469,17 @@ class PageIndexReaderImpl : public PageIndexReader { ::arrow::io::RandomAccessFile* input_; /// The file metadata to get row group metadata. - std::shared_ptr file_metadata_; + std::shared_ptr fileMetadata_; /// Reader properties used to deserialize thrift object. const ReaderProperties& properties_; /// File-level decrypter. - std::shared_ptr file_decryptor_; + std::shared_ptr fileDecryptor_; - /// Coalesced read ranges of page index of row groups that have been suggested - /// by WillNeed(). Key is the row group ordinal. - std::unordered_map index_read_ranges_; + /// Coalesced read ranges of page index of row groups that have been + /// suggested by WillNeed(). Key is the row group ordinal. + std::unordered_map indexReadRanges_; }; /// \brief Internal state of page index builder. @@ -506,38 +497,38 @@ enum class BuilderState { template class ColumnIndexBuilderImpl final : public ColumnIndexBuilder { public: - using T = typename DType::c_type; + using T = typename DType::CType; explicit ColumnIndexBuilderImpl(const ColumnDescriptor* descr) : descr_(descr) { - /// Initialize the null_counts vector as set. Invalid null_counts vector - /// from any page will invalidate the null_counts vector of the column + /// Initialize the nullCounts vector as set. Invalid nullCounts vector + /// from any page will invalidate the nullCounts vector of the column /// index. - column_index_.__isset.null_counts = true; - column_index_.boundary_order = + columnIndex_.__isset.null_counts = true; + columnIndex_.boundary_order = facebook::velox::parquet::thrift::BoundaryOrder::UNORDERED; } - void AddPage(const EncodedStatistics& stats) override { + void addPage(const EncodedStatistics& stats) override { if (state_ == BuilderState::kFinished) { throw ParquetException("Cannot add page to finished ColumnIndexBuilder."); } else if (state_ == BuilderState::kDiscarded) { - /// The offset index is discarded. Do nothing. + /// The column index is discarded. Do nothing. return; } state_ = BuilderState::kStarted; - if (stats.all_null_value) { - column_index_.null_pages.emplace_back(true); - column_index_.min_values.emplace_back(""); - column_index_.max_values.emplace_back(""); - } else if (stats.has_min && stats.has_max) { - const size_t page_ordinal = column_index_.null_pages.size(); - non_null_page_indices_.emplace_back(page_ordinal); - column_index_.min_values.emplace_back(stats.min()); - column_index_.max_values.emplace_back(stats.max()); - column_index_.null_pages.emplace_back(false); + if (stats.allNullValue) { + columnIndex_.null_pages.emplace_back(true); + columnIndex_.min_values.emplace_back(""); + columnIndex_.max_values.emplace_back(""); + } else if (stats.hasMin && stats.hasMax) { + const size_t pageOrdinal = columnIndex_.null_pages.size(); + nonNullPageIndices_.emplace_back(pageOrdinal); + columnIndex_.min_values.emplace_back(stats.min()); + columnIndex_.max_values.emplace_back(stats.max()); + columnIndex_.null_pages.emplace_back(false); } else { /// This is a non-null page but it lacks of meaningful min/max values. /// Discard the column index. @@ -545,15 +536,15 @@ class ColumnIndexBuilderImpl final : public ColumnIndexBuilder { return; } - if (column_index_.__isset.null_counts && stats.has_null_count) { - column_index_.null_counts.emplace_back(stats.null_count); + if (columnIndex_.__isset.null_counts && stats.hasNullCount) { + columnIndex_.null_counts.emplace_back(stats.nullCount); } else { - column_index_.__isset.null_counts = false; - column_index_.null_counts.clear(); + columnIndex_.__isset.null_counts = false; + columnIndex_.null_counts.clear(); } } - void Finish() override { + void finish() override { switch (state_) { case BuilderState::kCreated: { /// No page is added. Discard the column index. @@ -572,93 +563,93 @@ class ColumnIndexBuilderImpl final : public ColumnIndexBuilder { state_ = BuilderState::kFinished; /// Clear null_counts vector because at least one page does not provide it. - if (!column_index_.__isset.null_counts) { - column_index_.null_counts.clear(); + if (!columnIndex_.__isset.null_counts) { + columnIndex_.null_counts.clear(); } /// Decode min/max values according to the data type. - const size_t non_null_page_count = non_null_page_indices_.size(); - std::vector min_values, max_values; - min_values.resize(non_null_page_count); - max_values.resize(non_null_page_count); - auto decoder = MakeTypedDecoder(Encoding::PLAIN, descr_); - for (size_t i = 0; i < non_null_page_count; ++i) { - auto page_ordinal = non_null_page_indices_.at(i); - Decode( - decoder, column_index_.min_values.at(page_ordinal), &min_values, i); - Decode( - decoder, column_index_.max_values.at(page_ordinal), &max_values, i); + const size_t nonNullPageCount = nonNullPageIndices_.size(); + std::vector minValues, maxValues; + minValues.resize(nonNullPageCount); + maxValues.resize(nonNullPageCount); + auto decoder = makeTypedDecoder(Encoding::kPlain, descr_); + for (size_t i = 0; i < nonNullPageCount; ++i) { + auto pageOrdinal = nonNullPageIndices_.at(i); + decode( + decoder, columnIndex_.min_values.at(pageOrdinal), &minValues, i); + decode( + decoder, columnIndex_.max_values.at(pageOrdinal), &maxValues, i); } /// Decide the boundary order from decoded min/max values. - auto boundary_order = DetermineBoundaryOrder(min_values, max_values); - column_index_.__set_boundary_order(ToThrift(boundary_order)); + auto boundaryOrder = determineBoundaryOrder(minValues, maxValues); + columnIndex_.__set_boundary_order(toThrift(boundaryOrder)); } - void WriteTo(::arrow::io::OutputStream* sink) const override { + void writeTo(::arrow::io::OutputStream* sink) const override { if (state_ == BuilderState::kFinished) { - ThriftSerializer{}.Serialize(&column_index_, sink); + ThriftSerializer{}.serialize(&columnIndex_, sink); } } - std::unique_ptr Build() const override { + std::unique_ptr build() const override { if (state_ == BuilderState::kFinished) { return std::make_unique>( - *descr_, column_index_); + *descr_, columnIndex_); } return nullptr; } private: - BoundaryOrder::type DetermineBoundaryOrder( - const std::vector& min_values, - const std::vector& max_values) const { - VELOX_DCHECK_EQ(min_values.size(), max_values.size()); - if (min_values.empty()) { - return BoundaryOrder::Unordered; + BoundaryOrder::type determineBoundaryOrder( + const std::vector& minValues, + const std::vector& maxValues) const { + VELOX_DCHECK_EQ(minValues.size(), maxValues.size()); + if (minValues.empty()) { + return BoundaryOrder::kUnordered; } std::shared_ptr> comparator; try { - comparator = MakeComparator(descr_); + comparator = makeComparator(descr_); } catch (const ParquetException&) { /// Simply return unordered for unsupported comparator. - return BoundaryOrder::Unordered; + return BoundaryOrder::kUnordered; } - /// Check if both min_values and max_values are in ascending order. - bool is_ascending = true; - for (size_t i = 1; i < min_values.size(); ++i) { - if (comparator->Compare(min_values[i], min_values[i - 1]) || - comparator->Compare(max_values[i], max_values[i - 1])) { - is_ascending = false; + /// Check if both minValues and maxValues are in ascending order. + bool isAscending = true; + for (size_t i = 1; i < minValues.size(); ++i) { + if (comparator->compare(minValues[i], minValues[i - 1]) || + comparator->compare(maxValues[i], maxValues[i - 1])) { + isAscending = false; break; } } - if (is_ascending) { - return BoundaryOrder::Ascending; + if (isAscending) { + return BoundaryOrder::kAscending; } - /// Check if both min_values and max_values are in descending order. - bool is_descending = true; - for (size_t i = 1; i < min_values.size(); ++i) { - if (comparator->Compare(min_values[i - 1], min_values[i]) || - comparator->Compare(max_values[i - 1], max_values[i])) { - is_descending = false; + /// Check if both minValues and maxValues are in descending order. + bool isDescending = true; + for (size_t i = 1; i < minValues.size(); ++i) { + if (comparator->compare(minValues[i - 1], minValues[i]) || + comparator->compare(maxValues[i - 1], maxValues[i])) { + isDescending = false; break; } } - if (is_descending) { - return BoundaryOrder::Descending; + if (isDescending) { + return BoundaryOrder::kDescending; } /// Neither ascending nor descending is detected. - return BoundaryOrder::Unordered; + return BoundaryOrder::kUnordered; } const ColumnDescriptor* descr_; - facebook::velox::parquet::thrift::ColumnIndex column_index_; - std::vector non_null_page_indices_; + facebook::velox::parquet::thrift::ColumnIndex columnIndex_; + std::vector nonNullPageIndices_; BuilderState state_ = BuilderState::kCreated; }; @@ -666,10 +657,10 @@ class OffsetIndexBuilderImpl final : public OffsetIndexBuilder { public: OffsetIndexBuilderImpl() = default; - void AddPage( + void addPage( int64_t offset, - int32_t compressed_page_size, - int64_t first_row_index) override { + int32_t compressedPageSize, + int64_t firstRowIndex) override { if (state_ == BuilderState::kFinished) { throw ParquetException("Cannot add page to finished OffsetIndexBuilder."); } else if (state_ == BuilderState::kDiscarded) { @@ -679,14 +670,14 @@ class OffsetIndexBuilderImpl final : public OffsetIndexBuilder { state_ = BuilderState::kStarted; - facebook::velox::parquet::thrift::PageLocation page_location; - page_location.__set_offset(offset); - page_location.__set_compressed_page_size(compressed_page_size); - page_location.__set_first_row_index(first_row_index); - offset_index_.page_locations.emplace_back(std::move(page_location)); + facebook::velox::parquet::thrift::PageLocation pageLocation; + pageLocation.__set_offset(offset); + pageLocation.__set_compressed_page_size(compressedPageSize); + pageLocation.__set_first_row_index(firstRowIndex); + offsetIndex_.page_locations.emplace_back(std::move(pageLocation)); } - void Finish(int64_t final_position) override { + void finish(int64_t finalPosition) override { switch (state_) { case BuilderState::kCreated: { /// No pages are added. Simply discard the offset index. @@ -694,10 +685,10 @@ class OffsetIndexBuilderImpl final : public OffsetIndexBuilder { break; } case BuilderState::kStarted: { - /// Adjust page offsets according the final position. - if (final_position > 0) { - for (auto& page_location : offset_index_.page_locations) { - page_location.__set_offset(page_location.offset + final_position); + /// Adjust page offsets according to the final position. + if (finalPosition > 0) { + for (auto& pageLocation : offsetIndex_.page_locations) { + pageLocation.__set_offset(pageLocation.offset + finalPosition); } } state_ = BuilderState::kFinished; @@ -709,21 +700,21 @@ class OffsetIndexBuilderImpl final : public OffsetIndexBuilder { } } - void WriteTo(::arrow::io::OutputStream* sink) const override { + void writeTo(::arrow::io::OutputStream* sink) const override { if (state_ == BuilderState::kFinished) { - ThriftSerializer{}.Serialize(&offset_index_, sink); + ThriftSerializer{}.serialize(&offsetIndex_, sink); } } - std::unique_ptr Build() const override { + std::unique_ptr build() const override { if (state_ == BuilderState::kFinished) { - return std::make_unique(offset_index_); + return std::make_unique(offsetIndex_); } return nullptr; } private: - facebook::velox::parquet::thrift::OffsetIndex offset_index_; + facebook::velox::parquet::thrift::OffsetIndex offsetIndex_; BuilderState state_ = BuilderState::kCreated; }; @@ -732,118 +723,112 @@ class PageIndexBuilderImpl final : public PageIndexBuilder { explicit PageIndexBuilderImpl(const SchemaDescriptor* schema) : schema_(schema) {} - void AppendRowGroup() override { + void appendRowGroup() override { if (finished_) { throw ParquetException( - "Cannot call AppendRowGroup() to finished PageIndexBuilder."); + "Cannot call appendRowGroup() to finished PageIndexBuilder."); } // Append new builders of next row group. - const auto num_columns = static_cast(schema_->num_columns()); - column_index_builders_.emplace_back(); - offset_index_builders_.emplace_back(); - column_index_builders_.back().resize(num_columns); - offset_index_builders_.back().resize(num_columns); - - VELOX_DCHECK_EQ( - column_index_builders_.size(), offset_index_builders_.size()); - VELOX_DCHECK_EQ(column_index_builders_.back().size(), num_columns); - VELOX_DCHECK_EQ(offset_index_builders_.back().size(), num_columns); + const auto numColumns = static_cast(schema_->numColumns()); + columnIndexBuilders_.emplace_back(); + offsetIndexBuilders_.emplace_back(); + columnIndexBuilders_.back().resize(numColumns); + offsetIndexBuilders_.back().resize(numColumns); + + VELOX_DCHECK_EQ(columnIndexBuilders_.size(), offsetIndexBuilders_.size()); + VELOX_DCHECK_EQ(columnIndexBuilders_.back().size(), numColumns); + VELOX_DCHECK_EQ(offsetIndexBuilders_.back().size(), numColumns); } - ColumnIndexBuilder* GetColumnIndexBuilder(int32_t i) override { - CheckState(i); + ColumnIndexBuilder* getColumnIndexBuilder(int32_t i) override { + checkState(i); std::unique_ptr& builder = - column_index_builders_.back()[i]; + columnIndexBuilders_.back()[i]; if (builder == nullptr) { - builder = ColumnIndexBuilder::Make(schema_->Column(i)); + builder = ColumnIndexBuilder::make(schema_->column(i)); } return builder.get(); } - OffsetIndexBuilder* GetOffsetIndexBuilder(int32_t i) override { - CheckState(i); + OffsetIndexBuilder* getOffsetIndexBuilder(int32_t i) override { + checkState(i); std::unique_ptr& builder = - offset_index_builders_.back()[i]; + offsetIndexBuilders_.back()[i]; if (builder == nullptr) { - builder = OffsetIndexBuilder::Make(); + builder = OffsetIndexBuilder::make(); } return builder.get(); } - void Finish() override { + void finish() override { finished_ = true; } - void WriteTo(::arrow::io::OutputStream* sink, PageIndexLocation* location) + void writeTo(::arrow::io::OutputStream* sink, PageIndexLocation* location) const override { if (!finished_) { throw ParquetException( - "Cannot call WriteTo() to unfinished PageIndexBuilder."); + "Cannot call writeTo() to unfinished PageIndexBuilder."); } - location->column_index_location.clear(); - location->offset_index_location.clear(); + location->columnIndexLocation.clear(); + location->offsetIndexLocation.clear(); /// Serialize column index ordered by row group ordinal and then column /// ordinal. - SerializeIndex( - column_index_builders_, sink, &location->column_index_location); + serializeIndex(columnIndexBuilders_, sink, &location->columnIndexLocation); /// Serialize offset index ordered by row group ordinal and then column /// ordinal. - SerializeIndex( - offset_index_builders_, sink, &location->offset_index_location); + serializeIndex(offsetIndexBuilders_, sink, &location->offsetIndexLocation); } private: - /// Make sure column ordinal is not out of bound and the builder is in good + /// Make sure column ordinal is not out of bounds and the builder is in good /// state. - void CheckState(int32_t column_ordinal) const { + void checkState(int32_t columnOrdinal) const { if (finished_) { throw ParquetException("PageIndexBuilder is already finished."); } - if (column_ordinal < 0 || column_ordinal >= schema_->num_columns()) { - throw ParquetException("Invalid column ordinal: ", column_ordinal); + if (columnOrdinal < 0 || columnOrdinal >= schema_->numColumns()) { + throw ParquetException("Invalid column ordinal: ", columnOrdinal); } - if (offset_index_builders_.empty() || column_index_builders_.empty()) { + if (offsetIndexBuilders_.empty() || columnIndexBuilders_.empty()) { throw ParquetException("No row group appended to PageIndexBuilder."); } } template - void SerializeIndex( + void serializeIndex( const std::vector>>& - page_index_builders, + pageIndexBuilders, ::arrow::io::OutputStream* sink, std::map>>* location) const { - const auto num_columns = static_cast(schema_->num_columns()); + const auto numColumns = static_cast(schema_->numColumns()); /// Serialize the same kind of page index row group by row group. - for (size_t row_group = 0; row_group < page_index_builders.size(); - ++row_group) { - const auto& row_group_page_index_builders = - page_index_builders[row_group]; - VELOX_DCHECK_EQ(row_group_page_index_builders.size(), num_columns); + for (size_t rowGroup = 0; rowGroup < pageIndexBuilders.size(); ++rowGroup) { + const auto& rowGroupPageIndexBuilders = pageIndexBuilders[rowGroup]; + VELOX_DCHECK_EQ(rowGroupPageIndexBuilders.size(), numColumns); - bool has_valid_index = false; + bool hasValidIndex = false; std::vector> locations( - num_columns, std::nullopt); - - /// In the same row group, serialize the same kind of page index column by - /// column. - for (size_t column = 0; column < num_columns; ++column) { - const auto& column_page_index_builder = - row_group_page_index_builders[column]; - if (column_page_index_builder != nullptr) { + numColumns, std::nullopt); + + /// In the same row group, serialize the same kind of page index column + /// by column. + for (size_t column = 0; column < numColumns; ++column) { + const auto& columnPageIndexBuilder = rowGroupPageIndexBuilders[column]; + if (columnPageIndexBuilder != nullptr) { /// Try serializing the page index. - PARQUET_ASSIGN_OR_THROW(int64_t pos_before_write, sink->Tell()); - column_page_index_builder->WriteTo(sink); - PARQUET_ASSIGN_OR_THROW(int64_t pos_after_write, sink->Tell()); - int64_t len = pos_after_write - pos_before_write; + PARQUET_ASSIGN_OR_THROW(int64_t posBeforeWrite, sink->Tell()); + columnPageIndexBuilder->writeTo(sink); + PARQUET_ASSIGN_OR_THROW(int64_t posAfterWrite, sink->Tell()); + int64_t len = posAfterWrite - posBeforeWrite; - /// The page index is not serialized and skip reporting its location + /// The page index is not serialized and skip reporting its location. if (len == 0) { continue; } @@ -851,180 +836,180 @@ class PageIndexBuilderImpl final : public PageIndexBuilder { if (len > std::numeric_limits::max()) { throw ParquetException("Page index size overflows to INT32_MAX"); } - locations[column] = {pos_before_write, static_cast(len)}; - has_valid_index = true; + locations[column] = {posBeforeWrite, static_cast(len)}; + hasValidIndex = true; } } - if (has_valid_index) { - location->emplace(row_group, std::move(locations)); + if (hasValidIndex) { + location->emplace(rowGroup, std::move(locations)); } } } const SchemaDescriptor* schema_; std::vector>> - column_index_builders_; + columnIndexBuilders_; std::vector>> - offset_index_builders_; + offsetIndexBuilders_; bool finished_ = false; }; } // namespace -RowGroupIndexReadRange PageIndexReader::DeterminePageIndexRangesInRowGroup( - const RowGroupMetaData& row_group_metadata, +RowGroupIndexReadRange PageIndexReader::determinePageIndexRangesInRowGroup( + const RowGroupMetaData& rowGroupMetadata, const std::vector& columns) { - int64_t ci_start = std::numeric_limits::max(); - int64_t oi_start = std::numeric_limits::max(); - int64_t ci_end = -1; - int64_t oi_end = -1; - - auto merge_range = [](const std::optional& index_location, - int64_t* start, - int64_t* end) { - if (index_location.has_value()) { - int64_t index_end = 0; - if (index_location->offset < 0 || index_location->length <= 0 || - ::arrow::internal::AddWithOverflow( - index_location->offset, index_location->length, &index_end)) { + int64_t ciStart = std::numeric_limits::max(); + int64_t oiStart = std::numeric_limits::max(); + int64_t ciEnd = -1; + int64_t oiEnd = -1; + + auto mergeRange = [](const std::optional& indexLocation, + int64_t* start, + int64_t* end) { + if (indexLocation.has_value()) { + int64_t indexEnd = 0; + if (indexLocation->offset < 0 || indexLocation->length <= 0 || + ::arrow::internal::addWithOverflow( + indexLocation->offset, indexLocation->length, &indexEnd)) { throw ParquetException( "Invalid page index location: offset ", - index_location->offset, + indexLocation->offset, " length ", - index_location->length); + indexLocation->length); } - *start = std::min(*start, index_location->offset); - *end = std::max(*end, index_end); + *start = std::min(*start, indexLocation->offset); + *end = std::max(*end, indexEnd); } }; if (columns.empty()) { - for (int32_t i = 0; i < row_group_metadata.num_columns(); ++i) { - auto col_chunk = row_group_metadata.ColumnChunk(i); - merge_range(col_chunk->GetColumnIndexLocation(), &ci_start, &ci_end); - merge_range(col_chunk->GetOffsetIndexLocation(), &oi_start, &oi_end); + for (int32_t i = 0; i < rowGroupMetadata.numColumns(); ++i) { + auto colChunk = rowGroupMetadata.columnChunk(i); + mergeRange(colChunk->getColumnIndexLocation(), &ciStart, &ciEnd); + mergeRange(colChunk->getOffsetIndexLocation(), &oiStart, &oiEnd); } } else { for (int32_t i : columns) { - if (i < 0 || i >= row_group_metadata.num_columns()) { + if (i < 0 || i >= rowGroupMetadata.numColumns()) { throw ParquetException("Invalid column ordinal ", i); } - auto col_chunk = row_group_metadata.ColumnChunk(i); - merge_range(col_chunk->GetColumnIndexLocation(), &ci_start, &ci_end); - merge_range(col_chunk->GetOffsetIndexLocation(), &oi_start, &oi_end); + auto colChunk = rowGroupMetadata.columnChunk(i); + mergeRange(colChunk->getColumnIndexLocation(), &ciStart, &ciEnd); + mergeRange(colChunk->getOffsetIndexLocation(), &oiStart, &oiEnd); } } - RowGroupIndexReadRange read_range; - if (ci_end != -1) { - read_range.column_index = {ci_start, ci_end - ci_start}; + RowGroupIndexReadRange readRange; + if (ciEnd != -1) { + readRange.columnIndex = {ciStart, ciEnd - ciStart}; } - if (oi_end != -1) { - read_range.offset_index = {oi_start, oi_end - oi_start}; + if (oiEnd != -1) { + readRange.offsetIndex = {oiStart, oiEnd - oiStart}; } - return read_range; + return readRange; } // ---------------------------------------------------------------------- -// Public factory functions +// Public factory functions. -std::unique_ptr ColumnIndex::Make( +std::unique_ptr ColumnIndex::make( const ColumnDescriptor& descr, - const void* serialized_index, - uint32_t index_len, + const void* serializedIndex, + uint32_t indexLen, const ReaderProperties& properties) { - facebook::velox::parquet::thrift::ColumnIndex column_index; + facebook::velox::parquet::thrift::ColumnIndex columnIndex; ThriftDeserializer deserializer(properties); - deserializer.DeserializeMessage( - reinterpret_cast(serialized_index), - &index_len, - &column_index); - switch (descr.physical_type()) { - case Type::BOOLEAN: + deserializer.deserializeMessage( + reinterpret_cast(serializedIndex), + &indexLen, + &columnIndex); + switch (descr.physicalType()) { + case Type::kBoolean: return std::make_unique>( - descr, std::move(column_index)); - case Type::INT32: + descr, std::move(columnIndex)); + case Type::kInt32: return std::make_unique>( - descr, std::move(column_index)); - case Type::INT64: + descr, std::move(columnIndex)); + case Type::kInt64: return std::make_unique>( - descr, std::move(column_index)); - case Type::INT96: + descr, std::move(columnIndex)); + case Type::kInt96: return std::make_unique>( - descr, std::move(column_index)); - case Type::FLOAT: + descr, std::move(columnIndex)); + case Type::kFloat: return std::make_unique>( - descr, std::move(column_index)); - case Type::DOUBLE: + descr, std::move(columnIndex)); + case Type::kDouble: return std::make_unique>( - descr, std::move(column_index)); - case Type::BYTE_ARRAY: + descr, std::move(columnIndex)); + case Type::kByteArray: return std::make_unique>( - descr, std::move(column_index)); - case Type::FIXED_LEN_BYTE_ARRAY: + descr, std::move(columnIndex)); + case Type::kFixedLenByteArray: return std::make_unique>( - descr, std::move(column_index)); - case Type::UNDEFINED: + descr, std::move(columnIndex)); + case Type::kUndefined: return nullptr; } ::arrow::Unreachable("Cannot make ColumnIndex of an unknown type"); return nullptr; } -std::unique_ptr OffsetIndex::Make( - const void* serialized_index, - uint32_t index_len, +std::unique_ptr OffsetIndex::make( + const void* serializedIndex, + uint32_t indexLen, const ReaderProperties& properties) { - facebook::velox::parquet::thrift::OffsetIndex offset_index; + facebook::velox::parquet::thrift::OffsetIndex offsetIndex; ThriftDeserializer deserializer(properties); - deserializer.DeserializeMessage( - reinterpret_cast(serialized_index), - &index_len, - &offset_index); - return std::make_unique(offset_index); + deserializer.deserializeMessage( + reinterpret_cast(serializedIndex), + &indexLen, + &offsetIndex); + return std::make_unique(offsetIndex); } -std::shared_ptr PageIndexReader::Make( +std::shared_ptr PageIndexReader::make( ::arrow::io::RandomAccessFile* input, - std::shared_ptr file_metadata, + std::shared_ptr fileMetadata, const ReaderProperties& properties, - std::shared_ptr file_decryptor) { + std::shared_ptr fileDecryptor) { return std::make_shared( - input, std::move(file_metadata), properties, std::move(file_decryptor)); + input, std::move(fileMetadata), properties, std::move(fileDecryptor)); } -std::unique_ptr ColumnIndexBuilder::Make( +std::unique_ptr ColumnIndexBuilder::make( const ColumnDescriptor* descr) { - switch (descr->physical_type()) { - case Type::BOOLEAN: + switch (descr->physicalType()) { + case Type::kBoolean: return std::make_unique>(descr); - case Type::INT32: + case Type::kInt32: return std::make_unique>(descr); - case Type::INT64: + case Type::kInt64: return std::make_unique>(descr); - case Type::INT96: + case Type::kInt96: return std::make_unique>(descr); - case Type::FLOAT: + case Type::kFloat: return std::make_unique>(descr); - case Type::DOUBLE: + case Type::kDouble: return std::make_unique>(descr); - case Type::BYTE_ARRAY: + case Type::kByteArray: return std::make_unique>(descr); - case Type::FIXED_LEN_BYTE_ARRAY: + case Type::kFixedLenByteArray: return std::make_unique>(descr); - case Type::UNDEFINED: + case Type::kUndefined: return nullptr; } ::arrow::Unreachable("Cannot make ColumnIndexBuilder of an unknown type"); return nullptr; } -std::unique_ptr OffsetIndexBuilder::Make() { +std::unique_ptr OffsetIndexBuilder::make() { return std::make_unique(); } -std::unique_ptr PageIndexBuilder::Make( +std::unique_ptr PageIndexBuilder::make( const SchemaDescriptor* schema) { return std::make_unique(schema); } @@ -1032,8 +1017,8 @@ std::unique_ptr PageIndexBuilder::Make( std::ostream& operator<<( std::ostream& out, const PageIndexSelection& selection) { - out << "PageIndexSelection{column_index = " << selection.column_index - << ", offset_index = " << selection.offset_index << "}"; + out << "PageIndexSelection{column_index = " << selection.columnIndex + << ", offset_index = " << selection.offsetIndex << "}"; return out; } diff --git a/velox/dwio/parquet/writer/arrow/PageIndex.h b/velox/dwio/parquet/writer/arrow/PageIndex.h index a246b34da9a..b0833075e51 100644 --- a/velox/dwio/parquet/writer/arrow/PageIndex.h +++ b/velox/dwio/parquet/writer/arrow/PageIndex.h @@ -41,72 +41,71 @@ class SchemaDescriptor; class PARQUET_EXPORT ColumnIndex { public: /// \brief Create a ColumnIndex from a serialized thrift message. - static std::unique_ptr Make( + static std::unique_ptr make( const ColumnDescriptor& descr, - const void* serialized_index, - uint32_t index_len, + const void* serializedIndex, + uint32_t indexLen, const ReaderProperties& properties); virtual ~ColumnIndex() = default; - /// \brief A bitmap with a bit set for each data page that has only null - /// values. + /// Values. /// /// The length of this vector is equal to the number of data pages in the /// column. - virtual const std::vector& null_pages() const = 0; + virtual const std::vector& nullPages() const = 0; /// \brief A vector of encoded lower bounds for each data page in this column. /// - /// `null_pages` should be inspected first, as only pages with non-null values - /// may have their lower bounds populated. - virtual const std::vector& encoded_min_values() const = 0; + /// `nullPages` should be inspected first, as only pages with non-null + /// values may have their lower bounds populated. + virtual const std::vector& encodedMinValues() const = 0; /// \brief A vector of encoded upper bounds for each data page in this column. /// - /// `null_pages` should be inspected first, as only pages with non-null values - /// may have their upper bounds populated. - virtual const std::vector& encoded_max_values() const = 0; + /// `nullPages` should be inspected first, as only pages with non-null + /// values may have their upper bounds populated. + virtual const std::vector& encodedMaxValues() const = 0; /// \brief The ordering of lower and upper bounds. /// /// The boundary order applies across all lower bounds, and all upper bounds, /// respectively. However, the order between lower bounds and upper bounds /// cannot be derived from this. - virtual BoundaryOrder::type boundary_order() const = 0; + virtual BoundaryOrder::type boundaryOrder() const = 0; /// \brief Whether per-page null count information is available. - virtual bool has_null_counts() const = 0; + virtual bool hasNullCounts() const = 0; /// \brief An optional vector with the number of null values in each data /// page. /// - /// `has_null_counts` should be called first to determine if this information + /// `hasNullCounts` should be called first to determine if this information /// is available. - virtual const std::vector& null_counts() const = 0; + virtual const std::vector& nullCounts() const = 0; /// \brief A vector of page indices for non-null pages. - virtual const std::vector& non_null_page_indices() const = 0; + virtual const std::vector& nonNullPageIndices() const = 0; }; /// \brief Typed implementation of ColumnIndex. template class PARQUET_EXPORT TypedColumnIndex : public ColumnIndex { public: - using T = typename DType::c_type; + using T = typename DType::CType; /// \brief A vector of lower bounds for each data page in this column. /// - /// This is like `encoded_min_values`, but with the values decoded according - /// to the column's physical type. `min_values` and `max_values` can be used - /// together with `boundary_order` in order to prune some data pages when + /// This is like `encodedMinValues`, but with the values decoded according + /// to the column's physical type. `minValues` and `maxValues` can be used + /// together with `boundaryOrder` in order to prune some data pages when /// searching for specific values. - virtual const std::vector& min_values() const = 0; + virtual const std::vector& minValues() const = 0; /// \brief A vector of upper bounds for each data page in this column. /// - /// Just like `min_values`, but for upper bounds instead of lower bounds. - virtual const std::vector& max_values() const = 0; + /// Just like `minValues`, but for upper bounds instead of lower bounds. + virtual const std::vector& maxValues() const = 0; }; using BoolColumnIndex = TypedColumnIndex; @@ -123,9 +122,9 @@ struct PARQUET_EXPORT PageLocation { /// File offset of the data page. int64_t offset; /// Total compressed size of the data page and header. - int32_t compressed_page_size; + int32_t compressedPageSize; /// Row id of the first row in the page within the row group. - int64_t first_row_index; + int64_t firstRowIndex; }; /// \brief OffsetIndex is a proxy around @@ -133,15 +132,15 @@ struct PARQUET_EXPORT PageLocation { class PARQUET_EXPORT OffsetIndex { public: /// \brief Create a OffsetIndex from a serialized thrift message. - static std::unique_ptr Make( - const void* serialized_index, - uint32_t index_len, + static std::unique_ptr make( + const void* serializedIndex, + uint32_t indexLen, const ReaderProperties& properties); virtual ~OffsetIndex() = default; /// \brief A vector of locations for each data page in this column. - virtual const std::vector& page_locations() const = 0; + virtual const std::vector& pageLocations() const = 0; }; /// \brief Interface for reading the page index for a Parquet row group. @@ -151,24 +150,24 @@ class PARQUET_EXPORT RowGroupPageIndexReader { /// \brief Read column index of a column chunk. /// - /// \param[in] i column ordinal of the column chunk. - /// \returns column index of the column or nullptr if it does not exist. + /// \param[in] i Column ordinal of the column chunk. + /// \returns Column index of the column or nullptr if it does not exist. /// \throws ParquetException if the index is out of bound. - virtual std::shared_ptr GetColumnIndex(int32_t i) = 0; + virtual std::shared_ptr getColumnIndex(int32_t i) = 0; /// \brief Read offset index of a column chunk. /// - /// \param[in] i column ordinal of the column chunk. - /// \returns offset index of the column or nullptr if it does not exist. + /// \param[in] i Column ordinal of the column chunk. + /// \returns Offset index of the column or nullptr if it does not exist. /// \throws ParquetException if the index is out of bound. - virtual std::shared_ptr GetOffsetIndex(int32_t i) = 0; + virtual std::shared_ptr getOffsetIndex(int32_t i) = 0; }; struct PageIndexSelection { /// Specifies whether to read the column index. - bool column_index = false; + bool columnIndex = false; /// Specifies whether to read the offset index. - bool offset_index = false; + bool offsetIndex = false; }; PARQUET_EXPORT @@ -178,11 +177,11 @@ struct RowGroupIndexReadRange { /// Base start and total size of column index of all column chunks in a row /// group. If none of the column chunks have column index, it is set to /// std::nullopt. - std::optional<::arrow::io::ReadRange> column_index = std::nullopt; + std::optional<::arrow::io::ReadRange> columnIndex = std::nullopt; /// Base start and total size of offset index of all column chunks in a row /// group. If none of the column chunks have offset index, it is set to /// std::nullopt. - std::optional<::arrow::io::ReadRange> offset_index = std::nullopt; + std::optional<::arrow::io::ReadRange> offsetIndex = std::nullopt; }; /// \brief Interface for reading the page index for a Parquet file. @@ -191,26 +190,25 @@ class PARQUET_EXPORT PageIndexReader { virtual ~PageIndexReader() = default; /// \brief Create a PageIndexReader instance. - /// \returns a PageIndexReader instance. - /// WARNING: The returned PageIndexReader references to all the input + /// \returns A PageIndexReader instance. + /// WARNING: The returned PageIndexReader references all the input /// parameters, so it must not outlive all of the input parameters. Usually /// these input parameters come from the same ParquetFileReader object, so it /// must not outlive the reader that creates this PageIndexReader. - static std::shared_ptr Make( + static std::shared_ptr make( ::arrow::io::RandomAccessFile* input, - std::shared_ptr file_metadata, + std::shared_ptr fileMetadata, const ReaderProperties& properties, - std::shared_ptr file_decryptor = NULLPTR); + std::shared_ptr fileDecryptor = NULLPTR); /// \brief Get the page index reader of a specific row group. - /// \param[in] i row group ordinal to get page index reader. - /// \returns RowGroupPageIndexReader of the specified row group. A nullptr may - /// or may - /// not be returned if the page index for the row group is - /// unavailable. It is the caller's responsibility to check the - /// return value of follow-up calls to the RowGroupPageIndexReader. + /// \param[in] i Row group ordinal to get page index reader. + /// \returns RowGroupPageIndexReader of the specified row group. A nullptr + /// may or may not be returned if the page index for the row group is + /// unavailable. It is the caller's responsibility to check the + /// return value of follow-up calls to the RowGroupPageIndexReader. /// \throws ParquetException if the index is out of bound. - virtual std::shared_ptr RowGroup(int i) = 0; + virtual std::shared_ptr rowGroup(int i) = 0; /// \brief Advise the reader which part of page index will be read later. /// @@ -218,66 +216,61 @@ class PARQUET_EXPORT PageIndexReader { /// may be read later to get better performance. /// /// The contract of this function is as below: - /// 1) If WillNeed() has not been called for a specific row group and the page - /// index - /// exists, follow-up calls to get column index or offset index of all - /// columns in this row group SHOULD NOT FAIL, but the performance may not - /// be optimal. - /// 2) If WillNeed() has been called for a specific row group, follow-up calls - /// to get - /// page index are limited to columns and index type requested by - /// WillNeed(). So it MAY FAIL if columns that are not requested by - /// WillNeed() are requested. - /// 3) Later calls to WillNeed() MAY OVERRIDE previous calls of same row - /// groups. For example, 1) If WillNeed() is not called for row group 0, then - /// follow-up calls to read - /// column index and/or offset index of all columns of row group 0 should - /// not fail if its page index exists. - /// 2) If WillNeed() is called for columns 0 and 1 for row group 0, then - /// follow-up - /// call to read page index of column 2 for row group 0 MAY FAIL even if - /// its page index exists. - /// 3) If WillNeed() is called for row group 0 with offset index only, then - /// follow-up call to read column index of row group 0 MAY FAIL even if - /// the column index of this column exists. - /// 4) If WillNeed() is called for columns 0 and 1 for row group 0, then later - /// call to WillNeed() for columns 1 and 2 for row group 0. The later one - /// overrides previous call and only columns 1 and 2 of row group 0 are - /// allowed to access. + /// 1) If willNeed() has not been called for a specific row group and the + /// page index exists, follow-up calls to get column index or offset index of + /// all columns in this row group SHOULD NOT FAIL, but the performance may not + /// be optimal. + /// 2) If willNeed() has been called for a specific row group, follow-up + /// calls to get page index are limited to columns and index type requested by + /// willNeed(). So it MAY FAIL if columns that are not requested by + /// willNeed() are requested. + /// 3) Later calls to willNeed() MAY OVERRIDE previous calls of same row + /// groups. For example, 1) if willNeed() is not called for row group 0, then + /// follow-up calls to read column index and/or offset index of all columns of + /// row group 0 should not fail if its page index exists. + /// 2) If willNeed() is called for columns 0 and 1 for row group 0, then + /// follow-up call to read page index of column 2 for row group 0 MAY FAIL + /// even if its page index exists. 3) If willNeed() is called for row group 0 + /// with offset index only, then follow-up call to read column index of row + /// group 0 MAY FAIL even if the column index of this column exists. 4) If + /// willNeed() is called for columns 0 and 1 for row group 0, then later call + /// to willNeed() for columns 1 and 2 for row group 0. The later one overrides + /// previous call and only columns 1 and 2 of row group 0 are allowed to + /// access. /// - /// \param[in] row_group_indices list of row group ordinal to read page index - /// later. \param[in] column_indices list of column ordinal to read page index - /// later. If it is - /// empty, it means all columns in the row group will be read. - /// \param[in] selection which kind of page index is required later. - virtual void WillNeed( - const std::vector& row_group_indices, - const std::vector& column_indices, + /// \param[in] rowGroupIndices List of row group ordinal to read page + /// index later. \param[in] columnIndices List of column ordinal to + /// read page index later. If it is empty, it means all columns in the + /// row group will be read. + /// \param[in] selection Which kind of page index is required later. + virtual void willNeed( + const std::vector& rowGroupIndices, + const std::vector& columnIndices, const PageIndexSelection& selection) = 0; - /// \brief Advise the reader page index of these row groups will not be read + /// \brief Advise the reader page index of these row groups will not be read. /// any more. /// /// The PageIndexReader implementation has the opportunity to cancel any /// prefetch or release resource that are related to these row groups. /// - /// \param[in] row_group_indices list of row group ordinal that whose page - /// index will not be accessed any more. - virtual void WillNotNeed(const std::vector& row_group_indices) = 0; + /// \param[in] rowGroupIndices List of row group ordinal whose page index + /// will not be accessed any more. + virtual void willNotNeed(const std::vector& rowGroupIndices) = 0; /// \brief Determine the column index and offset index ranges for the given /// row group. /// - /// \param[in] row_group_metadata row group metadata to get column chunk - /// metadata. \param[in] columns list of column ordinals to get page index. If - /// the list is empty, - /// it means all columns in the row group. + /// \param[in] rowGroupMetadata Row group metadata to get column chunk + /// metadata. + /// \param[in] columns List of column ordinals to get page index. + /// If the list is empty, it means all columns in the row group. /// \returns RowGroupIndexReadRange of the specified row group. Throws /// ParquetException /// if the selected column ordinal is out of bound or metadata of /// page index is corrupted. - static RowGroupIndexReadRange DeterminePageIndexRangesInRowGroup( - const RowGroupMetaData& row_group_metadata, + static RowGroupIndexReadRange determinePageIndexRangesInRowGroup( + const RowGroupMetaData& rowGroupMetadata, const std::vector& columns); }; @@ -286,7 +279,7 @@ class PARQUET_EXPORT PageIndexReader { class PARQUET_EXPORT ColumnIndexBuilder { public: /// \brief API convenience to create a ColumnIndexBuilder. - static std::unique_ptr Make( + static std::unique_ptr make( const ColumnDescriptor* descr); virtual ~ColumnIndexBuilder() = default; @@ -297,50 +290,50 @@ class PARQUET_EXPORT ColumnIndexBuilder { /// not update statistics any more. /// /// \param stats Page statistics in the encoded form. - virtual void AddPage(const EncodedStatistics& stats) = 0; + virtual void addPage(const EncodedStatistics& stats) = 0; /// \brief Complete the column index. /// - /// Once called, AddPage() can no longer be called. - /// WriteTo() and Build() can only called after Finish() has been called. - virtual void Finish() = 0; + /// Once called, addPage() can no longer be called. + /// writeTo() and build() can only called after finish() has been called. + virtual void finish() = 0; /// \brief Serialize the column index thrift message. /// /// If the ColumnIndexBuilder has seen any corrupted statistics, it will /// not write any data to the sink. /// - /// \param[out] sink output stream to write the serialized message. - virtual void WriteTo(::arrow::io::OutputStream* sink) const = 0; + /// \param[out] sink Output stream to write the serialized message. + virtual void writeTo(::arrow::io::OutputStream* sink) const = 0; /// \brief Create a ColumnIndex directly. /// /// \return If the ColumnIndexBuilder has seen any corrupted statistics, it /// simply returns nullptr. Otherwise the column index is built and returned. - virtual std::unique_ptr Build() const = 0; + virtual std::unique_ptr build() const = 0; }; /// \brief Interface for collecting offset index of data pages in a column /// chunk. class PARQUET_EXPORT OffsetIndexBuilder { public: - /// \brief API convenience to create a OffsetIndexBuilder. - static std::unique_ptr Make(); + /// \brief API convenience to create an OffsetIndexBuilder. + static std::unique_ptr make(); virtual ~OffsetIndexBuilder() = default; /// \brief Add page location of a data page. - virtual void AddPage( + virtual void addPage( int64_t offset, - int32_t compressed_page_size, - int64_t first_row_index) = 0; + int32_t compressedPageSize, + int64_t firstRowIndex) = 0; /// \brief Add page location of a data page. - void AddPage(const PageLocation& page_location) { - AddPage( - page_location.offset, - page_location.compressed_page_size, - page_location.first_row_index); + void addPage(const PageLocation& pageLocation) { + addPage( + pageLocation.offset, + pageLocation.compressedPageSize, + pageLocation.firstRowIndex); } /// \brief Complete the offset index. @@ -349,46 +342,46 @@ class PARQUET_EXPORT OffsetIndexBuilder { /// sink and the OffsetIndexBuilder has only collected the relative offset /// which requires adjustment once they are flushed to the file. /// - /// \param final_position Final stream offset to add for page offset + /// \param finalPosition Final stream offset to add for page offset /// adjustment. - virtual void Finish(int64_t final_position) = 0; + virtual void finish(int64_t finalPosition) = 0; /// \brief Serialize the offset index thrift message. /// - /// \param[out] sink output stream to write the serialized message. - virtual void WriteTo(::arrow::io::OutputStream* sink) const = 0; + /// \param[out] sink Output stream to write the serialized message. + virtual void writeTo(::arrow::io::OutputStream* sink) const = 0; /// \brief Create an OffsetIndex directly. - virtual std::unique_ptr Build() const = 0; + virtual std::unique_ptr build() const = 0; }; -/// \brief Interface for collecting page index of a parquet file. +/// \brief Interface for collecting page index of a Parquet file. class PARQUET_EXPORT PageIndexBuilder { public: /// \brief API convenience to create a PageIndexBuilder. - static std::unique_ptr Make(const SchemaDescriptor* schema); + static std::unique_ptr make(const SchemaDescriptor* schema); virtual ~PageIndexBuilder() = default; /// \brief Start a new row group. - virtual void AppendRowGroup() = 0; + virtual void appendRowGroup() = 0; /// \brief Get the ColumnIndexBuilder from column ordinal. /// /// \param i Column ordinal. - /// \return ColumnIndexBuilder for the column and its memory ownership belongs - /// to the PageIndexBuilder. - virtual ColumnIndexBuilder* GetColumnIndexBuilder(int32_t i) = 0; + /// \return ColumnIndexBuilder for the column and its memory ownership + /// belongs to the PageIndexBuilder. + virtual ColumnIndexBuilder* getColumnIndexBuilder(int32_t i) = 0; /// \brief Get the OffsetIndexBuilder from column ordinal. /// /// \param i Column ordinal. - /// \return OffsetIndexBuilder for the column and its memory ownership belongs - /// to the PageIndexBuilder. - virtual OffsetIndexBuilder* GetOffsetIndexBuilder(int32_t i) = 0; + /// \return OffsetIndexBuilder for the column and its memory ownership + /// belongs to the PageIndexBuilder. + virtual OffsetIndexBuilder* getOffsetIndexBuilder(int32_t i) = 0; /// \brief Complete the page index builder and no more write is allowed. - virtual void Finish() = 0; + virtual void finish() = 0; /// \brief Serialize the page index thrift message. /// @@ -397,7 +390,7 @@ class PARQUET_EXPORT PageIndexBuilder { /// /// \param[out] sink The output stream to write the page index. /// \param[out] location The location of all page index to the start of sink. - virtual void WriteTo( + virtual void writeTo( ::arrow::io::OutputStream* sink, PageIndexLocation* location) const = 0; }; diff --git a/velox/dwio/parquet/writer/arrow/PathInternal.cpp b/velox/dwio/parquet/writer/arrow/PathInternal.cpp index 885a8254edd..07541223c60 100644 --- a/velox/dwio/parquet/writer/arrow/PathInternal.cpp +++ b/velox/dwio/parquet/writer/arrow/PathInternal.cpp @@ -18,77 +18,77 @@ // Overview. // -// The strategy used for this code for repetition/definition -// is to dissect the top level array into a list of paths -// from the top level array to the final primitive (possibly -// dictionary encoded array). It then evaluates each one of -// those paths to produce results for the callback iteratively. +// The strategy used for this code for repetition/definition. +// Is to dissect the top level array into a list of paths. +// From the top level array to the final primitive (possibly. +// Dictionary encoded array). It then evaluates each one of. +// Those paths to produce results for the callback iteratively. // -// This approach was taken to reduce the aggregate memory required if we were -// to build all def/rep levels in parallel as apart of a tree traversal. It -// also allows for straightforward parallelization at the path level if that is -// desired in the future. +// This approach was taken to reduce the aggregate memory required if we were. +// To build all def/rep levels in parallel as apart of a tree traversal. It. +// Also allows for straightforward parallelization at the path level if that is. +// Desired in the future. // -// The main downside to this approach is it duplicates effort for nodes -// that share common ancestors. This can be mitigated to some degree -// by adding in optimizations that detect leaf arrays that share -// the same common list ancestor and reuse the repetition levels -// from the first leaf encountered (only definition levels greater -// the list ancestor need to be re-evaluated. This is left for future -// work. +// The main downside to this approach is it duplicates effort for nodes. +// That share common ancestors. This can be mitigated to some degree. +// By adding in optimizations that detect leaf arrays that share. +// The same common list ancestor and reuse the repetition levels. +// From the first leaf encountered (only definition levels greater. +// The list ancestor need to be re-evaluated. This is left for future. +// Work. // // Algorithm. // // As mentioned above this code dissects arrays into constituent parts: -// nullability data, and list offset data. It tries to optimize for -// some special cases, where it is known ahead of time that a step -// can be skipped (e.g. a nullable array happens to have all of its -// values) or batch filled (a nullable array has all null values). -// One further optimization that is not implemented but could be done -// in the future is special handling for nested list arrays that -// have some intermediate data which indicates the final array contains only -// nulls. +// Nullability data, and list offset data. It tries to optimize for. +// Some special cases, where it is known ahead of time that a step. +// Can be skipped (e.g. a nullable array happens to have all of its. +// Values) or batch filled (a nullable array has all null values). +// One further optimization that is not implemented but could be done. +// In the future is special handling for nested list arrays that. +// Have some intermediate data which indicates the final array contains only. +// Nulls. // -// In general, the algorithm attempts to batch work at each node as much -// as possible. For nullability nodes this means finding runs of null -// values and batch filling those interspersed with finding runs of non-null -// values to process in batch at the next column. +// In general, the algorithm attempts to batch work at each node as much. +// As possible. For nullability nodes this means finding runs of null. +// Values and batch filling those interspersed with finding runs of non-null. +// Values to process in batch at the next column. // -// Similarly, list runs of empty lists are all processed in one batch -// followed by either: -// - A single list entry for non-terminal lists (i.e. the upper part of a +// Similarly, list runs of empty lists are all processed in one batch. +// Followed by either: +// - A single list entry for non-terminal lists (i.e. the upper part of a. // nested list) -// - Runs of non-empty lists for the terminal list (i.e. the lowest part of a -// nested list). +// - Runs of non-empty lists for the terminal list (i.e. the lowest part of +// a. Nested list). // // This makes use of the following observations. -// 1. Null values at any node on the path are terminal (repetition and -// definition -// level can be set directly when a Null value is encountered). +// 1. Null values at any node on the path are terminal (repetition and. +// Definition. +// Level can be set directly when a Null value is encountered). // 2. Empty lists share this eager termination property with Null values. -// 3. In order to keep repetition/definition level populated the algorithm is -// lazy -// in assigning repetition levels. The algorithm tracks whether it is -// currently in the middle of a list by comparing the lengths of -// repetition/definition levels. If it is currently in the middle of a list -// the the number of repetition levels populated will be greater than -// definition levels (the start of a List requires adding the first -// element). If there are equal numbers of definition and repetition levels -// populated this indicates a list is waiting to be started and the next -// list encountered will have its repetition level signify the beginning of -// the list. +// 3. In order to keep repetition/definition level populated the algorithm is. +// Lazy. +// In assigning repetition levels. The algorithm tracks whether it is. +// Currently in the middle of a list by comparing the lengths of. +// Repetition/definition levels. If it is currently in the middle of a list. +// The the number of repetition levels populated will be greater than. +// Definition levels (the start of a List requires adding the first. +// Element). If there are equal numbers of definition and repetition levels. +// Populated this indicates a list is waiting to be started and the next. +// List encountered will have its repetition level signify the beginning of. +// The list. // // Other implementation notes. // -// This code hasn't been benchmarked (or assembly analyzed) but did the -// following as optimizations (yes premature optimization is the root of all -// evil). -// - This code does not use recursion, instead it constructs its own stack -// and manages -// updating elements accordingly. +// This code hasn't been benchmarked (or assembly analyzed) but did the. +// Following as optimizations (yes premature optimization is the root of +// all. Evil). +// - This code does not use recursion, instead it constructs its own stack. +// And manages. +// Updating elements accordingly. // - It tries to avoid using Status for common return states. -// - Avoids virtual dispatch in favor of if/else statements on a set of well -// known classes. +// - Avoids virtual dispatch in favor of if/else statements on a set of +// well. Known classes. #include "velox/dwio/parquet/writer/arrow/PathInternal.h" @@ -127,9 +127,8 @@ using ::arrow::TypedBufferBuilder; constexpr static int16_t kLevelNotSet = -1; /// \brief Simple result of a iterating over a column to determine values. -enum IterationResult { - /// Processing is done at this node. Move back up the path - /// to continue processing. +enum IterationResult { /// Processing is done at this node. Move back up the + /// path. To continue processing. kDone = -1, /// Move down towards the leaf for processing. kNext = 1, @@ -137,339 +136,335 @@ enum IterationResult { kError = 2 }; -#define RETURN_IF_ERROR(iteration_result) \ - do { \ - if (ARROW_PREDICT_FALSE(iteration_result == kError)) { \ - return iteration_result; \ - } \ +#define RETURN_IF_ERROR(iterationResult) \ + do { \ + if (ARROW_PREDICT_FALSE(iterationResult == kError)) { \ + return iterationResult; \ + } \ } while (false) -int64_t LazyNullCount(const Array& array) { +int64_t lazyNullCount(const Array& array) { return array.data()->null_count.load(); } -bool LazyNoNulls(const Array& array) { - int64_t null_count = LazyNullCount(array); - return null_count == 0 || - // kUnkownNullCount comparison is needed to account - // for null arrays. - (null_count == ::arrow::kUnknownNullCount && +bool lazyNoNulls(const Array& array) { + int64_t nullCount = lazyNullCount(array); + return nullCount == 0 || + // KUnkownNullCount comparison is needed to account. + // For null arrays. + (nullCount == ::arrow::kUnknownNullCount && array.null_bitmap_data() == nullptr); } struct PathWriteContext { PathWriteContext( ::arrow::MemoryPool* pool, - std::shared_ptr<::arrow::ResizableBuffer> def_levels_buffer) - : rep_levels(pool), def_levels(std::move(def_levels_buffer), pool) {} - IterationResult ReserveDefLevels(int64_t elements) { - last_status = def_levels.Reserve(elements); - if (ARROW_PREDICT_TRUE(last_status.ok())) { + std::shared_ptr<::arrow::ResizableBuffer> defLevelsBuffer) + : repLevels(pool), defLevels(std::move(defLevelsBuffer), pool) {} + IterationResult reserveDefLevels(int64_t elements) { + lastStatus = defLevels.Reserve(elements); + if (ARROW_PREDICT_TRUE(lastStatus.ok())) { return kDone; } return kError; } - IterationResult AppendDefLevel(int16_t def_level) { - last_status = def_levels.Append(def_level); - if (ARROW_PREDICT_TRUE(last_status.ok())) { + IterationResult appendDefLevel(int16_t defLevel) { + lastStatus = defLevels.Append(defLevel); + if (ARROW_PREDICT_TRUE(lastStatus.ok())) { return kDone; } return kError; } - IterationResult AppendDefLevels(int64_t count, int16_t def_level) { - last_status = def_levels.Append(count, def_level); - if (ARROW_PREDICT_TRUE(last_status.ok())) { + IterationResult appendDefLevels(int64_t count, int16_t defLevel) { + lastStatus = defLevels.Append(count, defLevel); + if (ARROW_PREDICT_TRUE(lastStatus.ok())) { return kDone; } return kError; } - void UnsafeAppendDefLevel(int16_t def_level) { - def_levels.UnsafeAppend(def_level); + void unsafeAppendDefLevel(int16_t defLevel) { + defLevels.UnsafeAppend(defLevel); } - IterationResult AppendRepLevel(int16_t rep_level) { - last_status = rep_levels.Append(rep_level); + IterationResult appendRepLevel(int16_t repLevel) { + lastStatus = repLevels.Append(repLevel); - if (ARROW_PREDICT_TRUE(last_status.ok())) { + if (ARROW_PREDICT_TRUE(lastStatus.ok())) { return kDone; } return kError; } - IterationResult AppendRepLevels(int64_t count, int16_t rep_level) { - last_status = rep_levels.Append(count, rep_level); - if (ARROW_PREDICT_TRUE(last_status.ok())) { + IterationResult appendRepLevels(int64_t count, int16_t repLevel) { + lastStatus = repLevels.Append(count, repLevel); + if (ARROW_PREDICT_TRUE(lastStatus.ok())) { return kDone; } return kError; } - bool EqualRepDefLevelsLengths() const { - return rep_levels.length() == def_levels.length(); + bool equalRepDefLevelsLengths() const { + return repLevels.length() == defLevels.length(); } - // Incorporates |range| into visited elements. If the |range| is contiguous - // with the last range, extend the last range, otherwise add |range| - // separately to the list. - void RecordPostListVisit(const ElementRange& range) { - if (!visited_elements.empty() && - range.start == visited_elements.back().end) { - visited_elements.back().end = range.end; + // Incorporates |range| into visited elements. If the |range| is contiguous. + // With the last range, extend the last range, otherwise add |range|. + // Separately to the list. + void recordPostListVisit(const ElementRange& range) { + if (!visitedElements.empty() && range.start == visitedElements.back().end) { + visitedElements.back().end = range.end; return; } - visited_elements.push_back(range); + visitedElements.push_back(range); } - Status last_status; - TypedBufferBuilder rep_levels; - TypedBufferBuilder def_levels; - std::vector visited_elements; + Status lastStatus; + TypedBufferBuilder repLevels; + TypedBufferBuilder defLevels; + std::vector visitedElements; }; IterationResult -FillRepLevels(int64_t count, int16_t rep_level, PathWriteContext* context) { - if (rep_level == kLevelNotSet) { +fillRepLevels(int64_t count, int16_t repLevel, PathWriteContext* context) { + if (repLevel == kLevelNotSet) { return kDone; } - int64_t fill_count = count; - // This condition occurs (rep and dep levels equals), in one of - // in a few cases: + int64_t fillCount = count; + // This condition occurs (rep and dep levels equals), in one of. + // In a few cases: // 1. Before any list is encountered. - // 2. After rep-level has been filled in due to null/empty - // values above it. + // 2. After rep-level has been filled in due to null/empty. + // Values above it. // 3. After finishing a list. - if (!context->EqualRepDefLevelsLengths()) { - fill_count--; + if (!context->equalRepDefLevelsLengths()) { + fillCount--; } - return context->AppendRepLevels(fill_count, rep_level); + return context->appendRepLevels(fillCount, repLevel); } -// A node for handling an array that is discovered to have all -// null elements. It is referred to as a TerminalNode because -// traversal of nodes will not continue it when generating -// rep/def levels. However, there could be many nested children -// elements beyond it in the Array that is being processed. +// A node for handling an array that is discovered to have all. +// Null elements. It is referred to as a TerminalNode because. +// Traversal of nodes will not continue it when generating. +// Rep/def levels. However, there could be many nested children. +// Elements beyond it in the Array that is being processed. class AllNullsTerminalNode { public: explicit AllNullsTerminalNode( - int16_t def_level, - int16_t rep_level = kLevelNotSet) - : def_level_(def_level), rep_level_(rep_level) {} - void SetRepLevelIfNull(int16_t rep_level) { - rep_level_ = rep_level; + int16_t defLevel, + int16_t repLevel = kLevelNotSet) + : defLevel_(defLevel), repLevel_(repLevel) {} + void setRepLevelIfNull(int16_t repLevel) { + repLevel_ = repLevel; } - IterationResult Run(const ElementRange& range, PathWriteContext* context) { - int64_t size = range.Size(); - RETURN_IF_ERROR(FillRepLevels(size, rep_level_, context)); - return context->AppendDefLevels(size, def_level_); + IterationResult run(const ElementRange& range, PathWriteContext* context) { + int64_t size = range.size(); + RETURN_IF_ERROR(fillRepLevels(size, repLevel_, context)); + return context->appendDefLevels(size, defLevel_); } private: - int16_t def_level_; - int16_t rep_level_; + int16_t defLevel_; + int16_t repLevel_; }; -// Handles the case where all remaining arrays until the leaf have no nulls -// (and are not interrupted by lists). Unlike AllNullsTerminalNode this is -// always the last node in a path. We don't need an analogue to the -// AllNullsTerminalNode because if all values are present at an intermediate -// array no node is added for it (the def-level for the next nullable node is -// incremented). +// Handles the case where all remaining arrays until the leaf have no nulls. +// (And are not interrupted by lists). Unlike AllNullsTerminalNode this is. +// Always the last node in a path. We don't need an analogue to the. +// AllNullsTerminalNode because if all values are present at an intermediate. +// Array no node is added for it (the def-level for the next nullable node is. +// Incremented). struct AllPresentTerminalNode { - IterationResult Run(const ElementRange& range, PathWriteContext* context) { - return context->AppendDefLevels(range.end - range.start, def_level); - // No need to worry about rep levels, because this state should - // only be applicable for after all list/repeated values - // have been evaluated in the path. + IterationResult run(const ElementRange& range, PathWriteContext* context) { + return context->appendDefLevels(range.end - range.start, defLevel); + // No need to worry about rep levels, because this state should. + // Only be applicable for after all list/repeated values. + // Have been evaluated in the path. } - int16_t def_level; + int16_t defLevel; }; -/// Node for handling the case when the leaf-array is nullable -/// and contains null elements. +/// Node for handling the case when the leaf-array is nullable. +/// And contains null elements. struct NullableTerminalNode { NullableTerminalNode() = default; NullableTerminalNode( const uint8_t* bitmap, - int64_t element_offset, - int16_t def_level_if_present) + int64_t elementOffset, + int16_t defLevelIfPresent) : bitmap_(bitmap), - element_offset_(element_offset), - def_level_if_present_(def_level_if_present), - def_level_if_null_(def_level_if_present - 1) {} + elementOffset_(elementOffset), + defLevelIfPresent_(defLevelIfPresent), + defLevelIfNull_(defLevelIfPresent - 1) {} - IterationResult Run(const ElementRange& range, PathWriteContext* context) { - int64_t elements = range.Size(); - RETURN_IF_ERROR(context->ReserveDefLevels(elements)); + IterationResult run(const ElementRange& range, PathWriteContext* context) { + int64_t elements = range.size(); + RETURN_IF_ERROR(context->reserveDefLevels(elements)); VELOX_DCHECK_GT(elements, 0); - auto bit_visitor = [&](bool is_set) { - context->UnsafeAppendDefLevel( - is_set ? def_level_if_present_ : def_level_if_null_); + auto bitVisitor = [&](bool isSet) { + context->unsafeAppendDefLevel( + isSet ? defLevelIfPresent_ : defLevelIfNull_); }; if (elements > 16) { // 16 guarantees at least one unrolled loop. ::arrow::internal::VisitBitsUnrolled( - bitmap_, range.start + element_offset_, elements, bit_visitor); + bitmap_, range.start + elementOffset_, elements, bitVisitor); } else { ::arrow::internal::VisitBits( - bitmap_, range.start + element_offset_, elements, bit_visitor); + bitmap_, range.start + elementOffset_, elements, bitVisitor); } return kDone; } const uint8_t* bitmap_; - int64_t element_offset_; - int16_t def_level_if_present_; - int16_t def_level_if_null_; + int64_t elementOffset_; + int16_t defLevelIfPresent_; + int16_t defLevelIfNull_; }; -// List nodes handle populating rep_level for Arrow Lists and def-level for -// empty lists. Nullability (both list and children) is handled by other Nodes. -// By construction all list nodes will be intermediate nodes (they will always -// be followed by at least one other node). +// List nodes handle populating rep_level for Arrow Lists and def-level for. +// Empty lists. Nullability (both list and children) is handled by other Nodes. +// By construction all list nodes will be intermediate nodes (they will always. +// Be followed by at least one other node). // // Type parameters: -// |RangeSelector| - A strategy for determine the the range of the child node -// to process. -// this varies depending on the type of list (int32_t* offsets, int64_t* -// offsets of fixed. +// |RangeSelector| - A strategy for determine the the range of the child +// node. To process. +// This varies depending on the type of list (int32_t* offsets, int64_t*. +// Offsets of fixed. template class ListPathNode { public: - ListPathNode( - RangeSelector selector, - int16_t rep_lev, - int16_t def_level_if_empty) + ListPathNode(RangeSelector selector, int16_t repLev, int16_t defLevelIfEmpty) : selector_(std::move(selector)), - prev_rep_level_(rep_lev - 1), - rep_level_(rep_lev), - def_level_if_empty_(def_level_if_empty) {} + prevRepLevel_(repLev - 1), + repLevel_(repLev), + defLevelIfEmpty_(defLevelIfEmpty) {} - int16_t rep_level() const { - return rep_level_; + int16_t repLevel() const { + return repLevel_; } - IterationResult Run( + IterationResult run( ElementRange* range, - ElementRange* child_range, + ElementRange* childRange, PathWriteContext* context) { - if (range->Empty()) { + if (range->empty()) { return kDone; } // Find the first non-empty list (skipping a run of empties). - int64_t empty_elements = 0; + int64_t emptyElements = 0; do { // Retrieve the range of elements that this list contains. - *child_range = selector_.GetRange(range->start); - if (!child_range->Empty()) { + *childRange = selector_.getRange(range->start); + if (!childRange->empty()) { break; } - ++empty_elements; + ++emptyElements; ++range->start; - } while (!range->Empty()); + } while (!range->empty()); // Post condition: // * range is either empty (we are done processing at this node) - // or start corresponds a non-empty list. - // * If range is non-empty child_range contains - // the bounds of non-empty list. + // Or start corresponds a non-empty list. + // * If range is non-empty child_range contains. + // The bounds of non-empty list. // Handle any skipped over empty lists. - if (empty_elements > 0) { - RETURN_IF_ERROR(FillRepLevels(empty_elements, prev_rep_level_, context)); + if (emptyElements > 0) { + RETURN_IF_ERROR(fillRepLevels(emptyElements, prevRepLevel_, context)); RETURN_IF_ERROR( - context->AppendDefLevels(empty_elements, def_level_if_empty_)); + context->appendDefLevels(emptyElements, defLevelIfEmpty_)); } - // Start of a new list. Note that for nested lists adding the element - // here effectively suppresses this code until we either encounter null - // elements or empty lists between here and the innermost list (since - // we make the rep levels repetition and definition levels unequal). - // Similarly when we are backtracking up the stack the repetition and - // definition levels are again equal so if we encounter an intermediate list - // with more elements this will detect it as a new list. - if (context->EqualRepDefLevelsLengths() && !range->Empty()) { - RETURN_IF_ERROR(context->AppendRepLevel(prev_rep_level_)); + // Start of a new list. Note that for nested lists adding the element. + // Here effectively suppresses this code until we either encounter null. + // Elements or empty lists between here and the innermost list (since. + // We make the rep levels repetition and definition levels unequal). + // Similarly when we are backtracking up the stack the repetition and. + // Definition levels are again equal so if we encounter an intermediate + // list. With more elements this will detect it as a new list. + if (context->equalRepDefLevelsLengths() && !range->empty()) { + RETURN_IF_ERROR(context->appendRepLevel(prevRepLevel_)); } - if (range->Empty()) { + if (range->empty()) { return kDone; } ++range->start; - if (is_last_) { - // If this is the last repeated node, we can extend try - // to extend the child range as wide as possible before - // continuing to the next node. - return FillForLast(range, child_range, context); + if (isLast_) { + // If this is the last repeated node, we can extend try. + // To extend the child range as wide as possible before. + // Continuing to the next node. + return fillForLast(range, childRange, context); } return kNext; } - void SetLast() { - is_last_ = true; + void setLast() { + isLast_ = true; } private: - IterationResult FillForLast( + IterationResult fillForLast( ElementRange* range, - ElementRange* child_range, + ElementRange* childRange, PathWriteContext* context) { // First fill int the remainder of the list. - RETURN_IF_ERROR(FillRepLevels(child_range->Size(), rep_level_, context)); + RETURN_IF_ERROR(fillRepLevels(childRange->size(), repLevel_, context)); // Once we've reached this point the following preconditions should hold: // 1. There are no more repeated path nodes to deal with. - // 2. All elements in |range| represent contiguous elements in the - // child array (Null values would have shortened the range to ensure - // all remaining list elements are present (though they may be empty - // lists)). - // 3. No element of range spans a parent list (intermediate - // list nodes only handle one list entry at a time). + // 2. All elements in |range| represent contiguous elements in the. + // Child array (Null values would have shortened the range to ensure. + // All remaining list elements are present (though they may be empty. + // Lists)). + // 3. No element of range spans a parent list (intermediate. + // List nodes only handle one list entry at a time). // - // Given these preconditions it should be safe to fill runs on non-empty - // lists here and expand the range in the child node accordingly. - - while (!range->Empty()) { - ElementRange size_check = selector_.GetRange(range->start); - if (size_check.Empty()) { - // The empty range will need to be handled after we pass down the - // accumulated range because it affects def_level placement and we need - // to get the children def_levels entered first. + // Given these preconditions it should be safe to fill runs on non-empty. + // Lists here and expand the range in the child node accordingly. + + while (!range->empty()) { + ElementRange sizeCheck = selector_.getRange(range->start); + if (sizeCheck.empty()) { + // The empty range will need to be handled after we pass down the. + // Accumulated range because it affects def_level placement and we need. + // To get the children def_levels entered first. break; } - // This is the start of a new list. We can be sure it only applies - // to the previous list (and doesn't jump to the start of any list - // further up in nesting due to the constraints mentioned at the start - // of the function). - RETURN_IF_ERROR(context->AppendRepLevel(prev_rep_level_)); + // This is the start of a new list. We can be sure it only applies. + // To the previous list (and doesn't jump to the start of any list. + // Further up in nesting due to the constraints mentioned at the start. + // Of the function). + RETURN_IF_ERROR(context->appendRepLevel(prevRepLevel_)); RETURN_IF_ERROR( - context->AppendRepLevels(size_check.Size() - 1, rep_level_)); - VELOX_DCHECK_EQ(size_check.start, child_range->end); - child_range->end = size_check.end; + context->appendRepLevels(sizeCheck.size() - 1, repLevel_)); + VELOX_DCHECK_EQ(sizeCheck.start, childRange->end); + childRange->end = sizeCheck.end; ++range->start; } - // Do book-keeping to track the elements of the arrays that are actually - // visited beyond this point. This is necessary to identify "gaps" in - // values that should not be processed (written out to parquet). - context->RecordPostListVisit(*child_range); + // Do book-keeping to track the elements of the arrays that are actually. + // Visited beyond this point. This is necessary to identify "gaps" in. + // Values that should not be processed (written out to parquet). + context->recordPostListVisit(*childRange); return kNext; } RangeSelector selector_; - int16_t prev_rep_level_; - int16_t rep_level_; - int16_t def_level_if_empty_; - bool is_last_ = false; + int16_t prevRepLevel_; + int16_t repLevel_; + int16_t defLevelIfEmpty_; + bool isLast_ = false; }; template struct VarRangeSelector { - ElementRange GetRange(int64_t index) const { + ElementRange getRange(int64_t index) const { return ElementRange{offsets[index], offsets[index + 1]}; } @@ -478,76 +473,76 @@ struct VarRangeSelector { }; struct FixedSizedRangeSelector { - ElementRange GetRange(int64_t index) const { - int64_t start = index * list_size; - return ElementRange{start, start + list_size}; + ElementRange getRange(int64_t index) const { + int64_t start = index * listSize; + return ElementRange{start, start + listSize}; } - int list_size; + int listSize; }; // An intermediate node that handles null values. class NullableNode { public: NullableNode( - const uint8_t* null_bitmap, - int64_t entry_offset, - int16_t def_level_if_null, - int16_t rep_level_if_null = kLevelNotSet) - : null_bitmap_(null_bitmap), - entry_offset_(entry_offset), - valid_bits_reader_(MakeReader(ElementRange{0, 0})), - def_level_if_null_(def_level_if_null), - rep_level_if_null_(rep_level_if_null), - new_range_(true) {} - - void SetRepLevelIfNull(int16_t rep_level) { - rep_level_if_null_ = rep_level; - } - - ::arrow::internal::BitRunReader MakeReader(const ElementRange& range) { + const uint8_t* nullBitmap, + int64_t entryOffset, + int16_t defLevelIfNull, + int16_t repLevelIfNull = kLevelNotSet) + : nullBitmap_(nullBitmap), + entryOffset_(entryOffset), + validBitsReader_(makeReader(ElementRange{0, 0})), + defLevelIfNull_(defLevelIfNull), + repLevelIfNull_(repLevelIfNull), + newRange_(true) {} + + void setRepLevelIfNull(int16_t repLevel) { + repLevelIfNull_ = repLevel; + } + + ::arrow::internal::BitRunReader makeReader(const ElementRange& range) { return ::arrow::internal::BitRunReader( - null_bitmap_, entry_offset_ + range.start, range.Size()); + nullBitmap_, entryOffset_ + range.start, range.size()); } - IterationResult Run( + IterationResult run( ElementRange* range, - ElementRange* child_range, + ElementRange* childRange, PathWriteContext* context) { - if (new_range_) { + if (newRange_) { // Reset the reader each time we are starting fresh on a range. - // We can't rely on continuity because nulls above can - // cause discontinuities. - valid_bits_reader_ = MakeReader(*range); + // We can't rely on continuity because nulls above can. + // Cause discontinuities. + validBitsReader_ = makeReader(*range); } - child_range->start = range->start; - ::arrow::internal::BitRun run = valid_bits_reader_.NextRun(); + childRange->start = range->start; + ::arrow::internal::BitRun run = validBitsReader_.NextRun(); if (!run.set) { range->start += run.length; - RETURN_IF_ERROR(FillRepLevels(run.length, rep_level_if_null_, context)); - RETURN_IF_ERROR(context->AppendDefLevels(run.length, def_level_if_null_)); - run = valid_bits_reader_.NextRun(); + RETURN_IF_ERROR(fillRepLevels(run.length, repLevelIfNull_, context)); + RETURN_IF_ERROR(context->appendDefLevels(run.length, defLevelIfNull_)); + run = validBitsReader_.NextRun(); } - if (range->Empty()) { - new_range_ = true; + if (range->empty()) { + newRange_ = true; return kDone; } - child_range->end = child_range->start = range->start; - child_range->end += run.length; + childRange->end = childRange->start = range->start; + childRange->end += run.length; - VELOX_DCHECK(!child_range->Empty()); - range->start += child_range->Size(); - new_range_ = false; + VELOX_DCHECK(!childRange->empty()); + range->start += childRange->size(); + newRange_ = false; return kNext; } - const uint8_t* null_bitmap_; - int64_t entry_offset_; - ::arrow::internal::BitRunReader valid_bits_reader_; - int16_t def_level_if_null_; - int16_t rep_level_if_null_; + const uint8_t* nullBitmap_; + int64_t entryOffset_; + ::arrow::internal::BitRunReader validBitsReader_; + int16_t defLevelIfNull_; + int16_t repLevelIfNull_; // Whether the next invocation will be a new range. - bool new_range_ = true; + bool newRange_ = true; }; using ListNode = ListPathNode>; @@ -569,181 +564,180 @@ struct PathInfo { AllNullsTerminalNode>; std::vector path; - std::shared_ptr primitive_array; - int16_t max_def_level = 0; - int16_t max_rep_level = 0; - bool has_dictionary = false; - bool leaf_is_nullable = false; + std::shared_ptr primitiveArray; + int16_t maxDefLevel = 0; + int16_t maxRepLevel = 0; + bool hasDictionary = false; + bool leafIsNullable = false; }; /// Contains logic for writing a single leaf node to parquet. /// This tracks the path from root to leaf. /// -/// |writer| will be called after all of the definition/repetition -/// values have been calculated for root_range with the calculated -/// values. It is intended to abstract the complexity of writing -/// the levels and values to parquet. -Status WritePath( - ElementRange root_range, - PathInfo* path_info, - ArrowWriteContext* arrow_context, +/// |Writer| will be called after all of the definition/repetition. +/// Values have been calculated for root_range with the calculated. +/// Values. It is intended to abstract the complexity of writing. +/// The levels and values to parquet. +Status writePath( + ElementRange rootRange, + PathInfo* pathInfo, + ArrowWriteContext* arrowContext, MultipathLevelBuilder::CallbackFunction writer) { - std::vector stack(path_info->path.size()); - MultipathLevelBuilderResult builder_result; - builder_result.leaf_array = path_info->primitive_array; - builder_result.leaf_is_nullable = path_info->leaf_is_nullable; - - if (path_info->max_def_level == 0) { - // This case only occurs when there are no nullable or repeated - // columns in the path from the root to leaf. - int64_t leaf_length = builder_result.leaf_array->length(); - builder_result.def_rep_level_count = leaf_length; - builder_result.post_list_visited_elements.push_back({0, leaf_length}); - return writer(builder_result); - } - stack[0] = root_range; - RETURN_NOT_OK(arrow_context->def_levels_buffer->Resize( - /*new_size=*/0, /*shrink_to_fit*/ false)); + std::vector stack(pathInfo->path.size()); + MultipathLevelBuilderResult builderResult; + builderResult.leafArray = pathInfo->primitiveArray; + builderResult.leafIsNullable = pathInfo->leafIsNullable; + + if (pathInfo->maxDefLevel == 0) { + // This case only occurs when there are no nullable or repeated. + // Columns in the path from the root to leaf. + int64_t leafLength = builderResult.leafArray->length(); + builderResult.defRepLevelCount = leafLength; + builderResult.postListVisitedElements.push_back({0, leafLength}); + return writer(builderResult); + } + stack[0] = rootRange; + RETURN_NOT_OK( + arrowContext->defLevelsBuffer->Resize(0, /*shrink_to_fit*/ false)); PathWriteContext context( - arrow_context->memory_pool, arrow_context->def_levels_buffer); - // We should need at least this many entries so reserve the space ahead of - // time. - RETURN_NOT_OK(context.def_levels.Reserve(root_range.Size())); - if (path_info->max_rep_level > 0) { - RETURN_NOT_OK(context.rep_levels.Reserve(root_range.Size())); - } - - auto stack_base = &stack[0]; - auto stack_position = stack_base; - // This is the main loop for calculated rep/def levels. The nodes - // in the path implement a chain-of-responsibility like pattern - // where each node can add some number of repetition/definition - // levels to PathWriteContext and also delegate to the next node + arrowContext->memoryPool, arrowContext->defLevelsBuffer); + // We should need at least this many entries so reserve the space ahead of. + // Time. + RETURN_NOT_OK(context.defLevels.Reserve(rootRange.size())); + if (pathInfo->maxRepLevel > 0) { + RETURN_NOT_OK(context.repLevels.Reserve(rootRange.size())); + } + + auto stackBase = &stack[0]; + auto stackPosition = stackBase; + // This is the main loop for calculated rep/def levels. The nodes. + // In the path implement a chain-of-responsibility like pattern. + // Where each node can add some number of repetition/definition. + // Levels to PathWriteContext and also delegate to the next node. // in the path to add values. The values are added through each Run(...) - // call and the choice to delegate to the next node (or return to the - // previous node) is communicated by the return value of Run(...). - // The loop terminates after the first node indicates all values in - // |root_range| are processed. - while (stack_position >= stack_base) { - PathInfo::Node& node = path_info->path[stack_position - stack_base]; + // Call and the choice to delegate to the next node (or return to the. + // Previous node) is communicated by the return value of Run(...). + // The loop terminates after the first node indicates all values in. + // |Root_range| are processed. + while (stackPosition >= stackBase) { + PathInfo::Node& Node = pathInfo->path[stackPosition - stackBase]; struct { - IterationResult operator()(NullableNode& node) { - return node.Run(stack_position, stack_position + 1, context); + IterationResult operator()(NullableNode& Node) { + return Node.run(stackPosition, stackPosition + 1, context); } - IterationResult operator()(ListNode& node) { - return node.Run(stack_position, stack_position + 1, context); + IterationResult operator()(ListNode& Node) { + return Node.run(stackPosition, stackPosition + 1, context); } - IterationResult operator()(NullableTerminalNode& node) { - return node.Run(*stack_position, context); + IterationResult operator()(NullableTerminalNode& Node) { + return Node.run(*stackPosition, context); } - IterationResult operator()(FixedSizeListNode& node) { - return node.Run(stack_position, stack_position + 1, context); + IterationResult operator()(FixedSizeListNode& Node) { + return Node.run(stackPosition, stackPosition + 1, context); } - IterationResult operator()(AllPresentTerminalNode& node) { - return node.Run(*stack_position, context); + IterationResult operator()(AllPresentTerminalNode& Node) { + return Node.run(*stackPosition, context); } - IterationResult operator()(AllNullsTerminalNode& node) { - return node.Run(*stack_position, context); + IterationResult operator()(AllNullsTerminalNode& Node) { + return Node.run(*stackPosition, context); } - IterationResult operator()(LargeListNode& node) { - return node.Run(stack_position, stack_position + 1, context); + IterationResult operator()(LargeListNode& Node) { + return Node.run(stackPosition, stackPosition + 1, context); } - ElementRange* stack_position; + ElementRange* stackPosition; PathWriteContext* context; - } visitor = {stack_position, &context}; + } visitor = {stackPosition, &context}; - IterationResult result = std::visit(visitor, node); + IterationResult result = std::visit(visitor, Node); if (ARROW_PREDICT_FALSE(result == kError)) { - VELOX_DCHECK(!context.last_status.ok()); - return context.last_status; + VELOX_DCHECK(!context.lastStatus.ok()); + return context.lastStatus; } - stack_position += static_cast(result); - } - RETURN_NOT_OK(context.last_status); - builder_result.def_rep_level_count = context.def_levels.length(); - - if (context.rep_levels.length() > 0) { - // This case only occurs when there was a repeated element that needs to be - // processed. - builder_result.rep_levels = context.rep_levels.data(); - std::swap( - builder_result.post_list_visited_elements, context.visited_elements); - // If it is possible when processing lists that all lists where empty. In - // this case no elements would have been added to - // post_list_visited_elements. By added an empty element we avoid special - // casing in downstream consumers. - if (builder_result.post_list_visited_elements.empty()) { - builder_result.post_list_visited_elements.push_back({0, 0}); + stackPosition += static_cast(result); + } + RETURN_NOT_OK(context.lastStatus); + builderResult.defRepLevelCount = context.defLevels.length(); + + if (context.repLevels.length() > 0) { + // This case only occurs when there was a repeated element that needs to be. + // Processed. + builderResult.repLevels = context.repLevels.data(); + std::swap(builderResult.postListVisitedElements, context.visitedElements); + // If it is possible when processing lists that all lists where empty. In. + // This case no elements would have been added to. + // Post_list_visited_elements. By added an empty element we avoid special. + // Casing in downstream consumers. + if (builderResult.postListVisitedElements.empty()) { + builderResult.postListVisitedElements.push_back({0, 0}); } } else { - builder_result.post_list_visited_elements.push_back( - {0, builder_result.leaf_array->length()}); - builder_result.rep_levels = nullptr; + builderResult.postListVisitedElements.push_back( + {0, builderResult.leafArray->length()}); + builderResult.repLevels = nullptr; } - builder_result.def_levels = context.def_levels.data(); - return writer(builder_result); + builderResult.defLevels = context.defLevels.data(); + return writer(builderResult); } struct FixupVisitor { - int max_rep_level = -1; - int16_t rep_level_if_null = kLevelNotSet; + int maxRepLevel = -1; + int16_t repLevelIfNull = kLevelNotSet; template - void HandleListNode(T& arg) { - if (arg.rep_level() == max_rep_level) { - arg.SetLast(); - // after the last list node we don't need to fill - // rep levels on null. - rep_level_if_null = kLevelNotSet; + void handleListNode(T& arg) { + if (arg.repLevel() == maxRepLevel) { + arg.setLast(); + // After the last list node we don't need to fill. + // Rep levels on null. + repLevelIfNull = kLevelNotSet; } else { - rep_level_if_null = arg.rep_level(); + repLevelIfNull = arg.repLevel(); } } - void operator()(ListNode& node) { - HandleListNode(node); + void operator()(ListNode& Node) { + handleListNode(Node); } - void operator()(LargeListNode& node) { - HandleListNode(node); + void operator()(LargeListNode& Node) { + handleListNode(Node); } - void operator()(FixedSizeListNode& node) { - HandleListNode(node); + void operator()(FixedSizeListNode& Node) { + handleListNode(Node); } // For non-list intermediate nodes. template - void HandleIntermediateNode(T& arg) { - if (rep_level_if_null != kLevelNotSet) { - arg.SetRepLevelIfNull(rep_level_if_null); + void handleIntermediateNode(T& arg) { + if (repLevelIfNull != kLevelNotSet) { + arg.setRepLevelIfNull(repLevelIfNull); } } void operator()(NullableNode& arg) { - HandleIntermediateNode(arg); + handleIntermediateNode(arg); } void operator()(AllNullsTerminalNode& arg) { - // Even though no processing happens past this point we - // still need to adjust it if a list occurred after an - // all null array. - HandleIntermediateNode(arg); + // Even though no processing happens past this point we. + // Still need to adjust it if a list occurred after an. + // All null array. + handleIntermediateNode(arg); } void operator()(NullableTerminalNode&) {} void operator()(AllPresentTerminalNode&) {} }; -PathInfo Fixup(PathInfo info) { - // We only need to fixup the path if there were repeated - // elements on it. - if (info.max_rep_level == 0) { +PathInfo fixup(PathInfo info) { + // We only need to fixup the path if there were repeated. + // Elements on it. + if (info.maxRepLevel == 0) { return info; } FixupVisitor visitor; - visitor.max_rep_level = info.max_rep_level; - if (visitor.max_rep_level > 0) { - visitor.rep_level_if_null = 0; + visitor.maxRepLevel = info.maxRepLevel; + if (visitor.maxRepLevel > 0) { + visitor.repLevelIfNull = 0; } for (size_t x = 0; x < info.path.size(); x++) { std::visit(visitor, info.path[x]); @@ -753,34 +747,33 @@ PathInfo Fixup(PathInfo info) { class PathBuilder { public: - explicit PathBuilder(bool start_nullable) - : nullable_in_parent_(start_nullable) {} + explicit PathBuilder(bool startNullable) : nullableInParent_(startNullable) {} template - void AddTerminalInfo(const T& array) { - info_.leaf_is_nullable = nullable_in_parent_; - if (nullable_in_parent_) { - info_.max_def_level++; + void addTerminalInfo(const T& array) { + info_.leafIsNullable = nullableInParent_; + if (nullableInParent_) { + info_.maxDefLevel++; } - // We don't use null_count() because if the null_count isn't known - // and the array does in fact contain nulls, we will end up - // traversing the null bitmap twice (once here and once when calculating - // rep/def levels). - if (LazyNoNulls(array)) { - info_.path.emplace_back(AllPresentTerminalNode{info_.max_def_level}); - } else if (LazyNullCount(array) == array.length()) { - info_.path.emplace_back(AllNullsTerminalNode(info_.max_def_level - 1)); + // We don't use null_count() because if the null_count isn't known. + // And the array does in fact contain nulls, we will end up. + // Traversing the null bitmap twice (once here and once when calculating. + // Rep/def levels). + if (lazyNoNulls(array)) { + info_.path.emplace_back(AllPresentTerminalNode{info_.maxDefLevel}); + } else if (lazyNullCount(array) == array.length()) { + info_.path.emplace_back(AllNullsTerminalNode(info_.maxDefLevel - 1)); } else { info_.path.emplace_back(NullableTerminalNode( - array.null_bitmap_data(), array.offset(), info_.max_def_level)); + array.null_bitmap_data(), array.offset(), info_.maxDefLevel)); } - info_.primitive_array = std::make_shared(array.data()); - paths_.push_back(Fixup(info_)); + info_.primitiveArray = std::make_shared(array.data()); + paths_.push_back(fixup(info_)); } template ::arrow::enable_if_t::value, Status> Visit(const T& array) { - AddTerminalInfo(array); + addTerminalInfo(array); return Status::OK(); } @@ -790,23 +783,23 @@ class PathBuilder { std::is_same<::arrow::LargeListArray, T>::value, Status> Visit(const T& array) { - MaybeAddNullable(array); + maybeAddNullable(array); // Increment necessary due to empty lists. - info_.max_def_level++; - info_.max_rep_level++; - // raw_value_offsets() accounts for any slice offset. - ListPathNode> node( + info_.maxDefLevel++; + info_.maxRepLevel++; + // Raw_value_offsets() accounts for any slice offset. + ListPathNode> Node( VarRangeSelector{array.raw_value_offsets()}, - info_.max_rep_level, - info_.max_def_level - 1); - info_.path.emplace_back(std::move(node)); - nullable_in_parent_ = array.list_type()->value_field()->nullable(); + info_.maxRepLevel, + info_.maxDefLevel - 1); + info_.path.emplace_back(std::move(Node)); + nullableInParent_ = array.list_type()->value_field()->nullable(); return VisitInline(*array.values()); } Status Visit(const ::arrow::DictionaryArray& array) { - // Only currently handle DictionaryArray where the dictionary is a - // primitive type + // Only currently handle DictionaryArray where the dictionary is a. + // Primitive type. if (array.dict_type()->value_type()->num_fields() > 0) { return Status::NotImplemented( "Writing DictionaryArray with nested dictionary " @@ -817,35 +810,35 @@ class PathBuilder { "Writing DictionaryArray with null encoded in dictionary " "type not yet supported"); } - AddTerminalInfo(array); + addTerminalInfo(array); return Status::OK(); } - void MaybeAddNullable(const Array& array) { - if (!nullable_in_parent_) { + void maybeAddNullable(const Array& array) { + if (!nullableInParent_) { return; } - info_.max_def_level++; - // We don't use null_count() because if the null_count isn't known - // and the array does in fact contain nulls, we will end up - // traversing the null bitmap twice (once here and once when calculating - // rep/def levels). Because this isn't terminal this might not be - // the right decision for structs that share the same nullable - // parents. - if (LazyNoNulls(array)) { - // Don't add anything because there won't be any point checking - // null values for the array. There will always be at least - // one more array to handle nullability. + info_.maxDefLevel++; + // We don't use null_count() because if the null_count isn't known. + // And the array does in fact contain nulls, we will end up. + // Traversing the null bitmap twice (once here and once when calculating. + // Rep/def levels). Because this isn't terminal this might not be. + // The right decision for structs that share the same nullable. + // Parents. + if (lazyNoNulls(array)) { + // Don't add anything because there won't be any point checking. + // Null values for the array. There will always be at least. + // One more array to handle nullability. return; } - if (LazyNullCount(array) == array.length()) { - info_.path.emplace_back(AllNullsTerminalNode(info_.max_def_level - 1)); + if (lazyNullCount(array) == array.length()) { + info_.path.emplace_back(AllNullsTerminalNode(info_.maxDefLevel - 1)); return; } info_.path.emplace_back(NullableNode( array.null_bitmap_data(), array.offset(), - /* def_level_if_null = */ info_.max_def_level - 1)); + /* def_level_if_null = */ info_.maxDefLevel - 1)); } Status VisitInline(const Array& array); @@ -855,29 +848,29 @@ class PathBuilder { } Status Visit(const ::arrow::StructArray& array) { - MaybeAddNullable(array); - PathInfo info_backup = info_; + maybeAddNullable(array); + PathInfo infoBackup = info_; for (int x = 0; x < array.num_fields(); x++) { - nullable_in_parent_ = array.type()->field(x)->nullable(); + nullableInParent_ = array.type()->field(x)->nullable(); RETURN_NOT_OK(VisitInline(*array.field(x))); - info_ = info_backup; + info_ = infoBackup; } return Status::OK(); } Status Visit(const ::arrow::FixedSizeListArray& array) { - MaybeAddNullable(array); - int32_t list_size = array.list_type()->list_size(); - // Technically we could encode fixed size lists with two level encodings - // but since we always use 3 level encoding we increment def levels as - // well. - info_.max_def_level++; - info_.max_rep_level++; + maybeAddNullable(array); + int32_t listSize = array.list_type()->list_size(); + // Technically we could encode fixed size lists with two level encodings. + // But since we always use 3 level encoding we increment def levels as. + // Well. + info_.maxDefLevel++; + info_.maxRepLevel++; info_.path.emplace_back(FixedSizeListNode( - FixedSizedRangeSelector{list_size}, - info_.max_rep_level, - info_.max_def_level)); - nullable_in_parent_ = array.list_type()->value_field()->nullable(); + FixedSizedRangeSelector{listSize}, + info_.maxRepLevel, + info_.maxDefLevel)); + nullableInParent_ = array.list_type()->value_field()->nullable(); if (array.offset() > 0) { return VisitInline(*array.values()->Slice(array.value_offset(0))); } @@ -888,10 +881,10 @@ class PathBuilder { return VisitInline(*array.storage()); } -#define NOT_IMPLEMENTED_VISIT(ArrowTypePrefix) \ - Status Visit(const ::arrow::ArrowTypePrefix##Array& array) { \ - return Status::NotImplemented("Level generation for " #ArrowTypePrefix \ - " not supported yet"); \ +#define NOT_IMPLEMENTED_VISIT(ArrowTypePrefix) \ + Status Visit(const ::arrow::ArrowTypePrefix##Array& array) { \ + return Status::NotImplemented( \ + "Level generation for " #ArrowTypePrefix " not supported yet"); \ } // Types not yet supported in Parquet. @@ -908,7 +901,7 @@ class PathBuilder { private: PathInfo info_; std::vector paths_; - bool nullable_in_parent_; + bool nullableInParent_; }; Status PathBuilder::VisitInline(const Array& array) { @@ -922,65 +915,65 @@ class MultipathLevelBuilderImpl : public MultipathLevelBuilder { public: MultipathLevelBuilderImpl( std::shared_ptr<::arrow::ArrayData> data, - std::unique_ptr path_builder) - : root_range_{0, data->length}, + std::unique_ptr pathBuilder) + : rootRange_{0, data->length}, data_(std::move(data)), - path_builder_(std::move(path_builder)) {} + pathBuilder_(std::move(pathBuilder)) {} - int GetLeafCount() const override { - return static_cast(path_builder_->paths().size()); + int getLeafCount() const override { + return static_cast(pathBuilder_->paths().size()); } - ::arrow::Status Write( - int leaf_index, + ::arrow::Status write( + int leafIndex, ArrowWriteContext* context, - CallbackFunction write_leaf_callback) override { - if (ARROW_PREDICT_FALSE(leaf_index < 0 || leaf_index >= GetLeafCount())) { + CallbackFunction writeLeafCallback) override { + if (ARROW_PREDICT_FALSE(leafIndex < 0 || leafIndex >= getLeafCount())) { return Status::Invalid( "Column index out of bounds (got ", - leaf_index, + leafIndex, ", should be " "between 0 and ", - GetLeafCount(), + getLeafCount(), ")"); } - return WritePath( - root_range_, - &path_builder_->paths()[leaf_index], + return writePath( + rootRange_, + &pathBuilder_->paths()[leafIndex], context, - std::move(write_leaf_callback)); + std::move(writeLeafCallback)); } private: - ElementRange root_range_; + ElementRange rootRange_; // Reference holder to ensure the data stays valid. std::shared_ptr<::arrow::ArrayData> data_; - std::unique_ptr path_builder_; + std::unique_ptr pathBuilder_; }; -// static +// Static. ::arrow::Result> -MultipathLevelBuilder::Make( +MultipathLevelBuilder::make( const ::arrow::Array& array, - bool array_field_nullable) { - auto constructor = std::make_unique(array_field_nullable); - RETURN_NOT_OK(VisitArrayInline(array, constructor.get())); + bool arrayFieldNullable) { + auto constructor = std::make_unique(arrayFieldNullable); + RETURN_NOT_OK(::arrow::VisitArrayInline(array, constructor.get())); return std::make_unique( array.data(), std::move(constructor)); } -// static -Status MultipathLevelBuilder::Write( +// Static. +Status MultipathLevelBuilder::write( const Array& array, - bool array_field_nullable, + bool arrayFieldNullable, ArrowWriteContext* context, MultipathLevelBuilder::CallbackFunction callback) { ARROW_ASSIGN_OR_RAISE( - std::unique_ptr builder, - MultipathLevelBuilder::Make(array, array_field_nullable)); - for (int leaf_idx = 0; leaf_idx < builder->GetLeafCount(); leaf_idx++) { - RETURN_NOT_OK(builder->Write(leaf_idx, context, callback)); + std::unique_ptr Builder, + MultipathLevelBuilder::make(array, arrayFieldNullable)); + for (int leafIdx = 0; leafIdx < Builder->getLeafCount(); leafIdx++) { + RETURN_NOT_OK(Builder->write(leafIdx, context, callback)); } return Status::OK(); } diff --git a/velox/dwio/parquet/writer/arrow/PathInternal.h b/velox/dwio/parquet/writer/arrow/PathInternal.h index ccf69cea5eb..55e467ddf9a 100644 --- a/velox/dwio/parquet/writer/arrow/PathInternal.h +++ b/velox/dwio/parquet/writer/arrow/PathInternal.h @@ -40,11 +40,11 @@ struct ArrowWriteContext; namespace arrow { -// This files contain internal implementation details and should not be -// considered part of the public API. +// This files contain internal implementation details and should not be. +// Considered part of the public API. -// The MultipathLevelBuilder is intended to fully support all Arrow nested types -// that map to parquet types (i.e. Everything but Unions). +// The MultipathLevelBuilder is intended to fully support all Arrow nested +// types. That map to parquet types (i.e. Everything but Unions). // /// \brief Half open range of elements in an array. @@ -54,111 +54,111 @@ struct ElementRange { /// Upper bound of range (exclusive) int64_t end; - bool Empty() const { + bool empty() const { return start == end; } - int64_t Size() const { + int64_t size() const { return end - start; } }; -/// \brief Result for a single leaf array when running the builder on the -/// its root. +/// \brief Result for a single leaf array when running the builder on the. +/// Its root. struct MultipathLevelBuilderResult { - /// \brief The Array containing only the values to write (after all nesting - /// has been processed. + /// \brief The Array containing only the values to write (after all nesting. + /// Has been processed. /// - /// No additional processing is done on this array (it is copied as is when - /// visited via a DFS). - std::shared_ptr<::arrow::Array> leaf_array; + /// No additional processing is done on this array (it is copied as is when. + /// Visited via a DFS). + std::shared_ptr<::arrow::Array> leafArray; /// \brief Might be null. - const int16_t* def_levels = nullptr; + const int16_t* defLevels = nullptr; /// \brief Might be null. - const int16_t* rep_levels = nullptr; + const int16_t* repLevels = nullptr; /// \brief Number of items (int16_t) contained in def/rep_levels when present. - int64_t def_rep_level_count = 0; + int64_t defRepLevelCount = 0; - /// \brief Contains element ranges of the required visiting on the - /// descendants of the final list ancestor for any leaf node. + /// \brief Contains element ranges of the required visiting on the. + /// Descendants of the final list ancestor for any leaf node. /// - /// The algorithm will attempt to consolidate visited ranges into - /// the smallest number possible. + /// The algorithm will attempt to consolidate visited ranges into. + /// The smallest number possible. /// - /// This data is necessary to pass along because after producing - /// def-rep levels for each leaf array it is impossible to determine - /// which values have to be sent to parquet when a null list value - /// in a nullable ListArray is non-empty. + /// This data is necessary to pass along because after producing. + /// Def-rep levels for each leaf array it is impossible to determine. + /// Which values have to be sent to parquet when a null list value. + /// In a nullable ListArray is non-empty. /// - /// This allows for the parquet writing to determine which values ultimately - /// needs to be written. - std::vector post_list_visited_elements; + /// This allows for the parquet writing to determine which values ultimately. + /// Needs to be written. + std::vector postListVisitedElements; /// Whether the leaf array is nullable. - bool leaf_is_nullable; + bool leafIsNullable; }; -/// \brief Logic for being able to write out nesting (rep/def level) data that -/// is needed for writing to parquet. +/// \brief Logic for being able to write out nesting (rep/def level) data that. +/// Is needed for writing to parquet. class PARQUET_EXPORT MultipathLevelBuilder { public: - /// \brief A callback function that will receive results from the call to - /// Write(...) below. The MultipathLevelBuilderResult passed in will - /// only remain valid for the function call (i.e. storing it and relying - /// for its data to be consistent afterwards will result in undefined - /// behavior. + /// \brief A callback function that will receive results from the call to. + /// Write(...) below. The MultipathLevelBuilderResult passed in will. + /// Only remain valid for the function call (i.e. storing it and relying. + /// For its data to be consistent afterwards will result in undefined. + /// Behavior. using CallbackFunction = std::function<::arrow::Status(const MultipathLevelBuilderResult&)>; /// \brief Determine rep/def level information for the array. /// - /// The callback will be invoked for each leaf Array that is a - /// descendant of array. Each leaf array is processed in a depth - /// first traversal-order. + /// The callback will be invoked for each leaf Array that is a. + /// Descendant of array. Each leaf array is processed in a depth. + /// First traversal-order. /// /// \param[in] array The array to process. - /// \param[in] array_field_nullable Whether the algorithm should consider - /// the the array column as nullable (as determined by its type's parent - /// field). + /// \param[in] array_field_nullable Whether the algorithm should consider. + /// The the array column as nullable (as determined by its type's parent. + /// Field). /// \param[in, out] context for use when allocating memory, etc. /// \param[out] write_leaf_callback Callback to receive results. /// There will be one call to the write_leaf_callback for each leaf node. - static ::arrow::Status Write( + static ::arrow::Status write( const ::arrow::Array& array, - bool array_field_nullable, + bool arrayFieldNullable, ArrowWriteContext* context, - CallbackFunction write_leaf_callback); + CallbackFunction writeLeafCallback); /// \brief Construct a new instance of the builder. /// /// \param[in] array The array to process. - /// \param[in] array_field_nullable Whether the algorithm should consider - /// the the array column as nullable (as determined by its type's parent - /// field). - static ::arrow::Result> Make( + /// \param[in] array_field_nullable Whether the algorithm should consider. + /// The the array column as nullable (as determined by its type's parent. + /// Field). + static ::arrow::Result> make( const ::arrow::Array& array, - bool array_field_nullable); + bool arrayFieldNullable); virtual ~MultipathLevelBuilder() = default; - /// \brief Returns the number of leaf columns that need to be written - /// to Parquet. - virtual int GetLeafCount() const = 0; + /// \brief Returns the number of leaf columns that need to be written. + /// To Parquet. + virtual int getLeafCount() const = 0; - /// \brief Calls write_leaf_callback with the MultipathLevelBuilderResult - /// corresponding to |leaf_index|. + /// \brief Calls write_leaf_callback with the MultipathLevelBuilderResult. + /// Corresponding to |leaf_index|. /// - /// \param[in] leaf_index The index of the leaf column to write. Must be in - /// the range [0, GetLeafCount()]. \param[in, out] context for use when - /// allocating memory, etc. \param[out] write_leaf_callback Callback to - /// receive the result. - virtual ::arrow::Status Write( - int leaf_index, + /// \param[in] leaf_index The index of the leaf column to write. Must be in. + /// The range [0, GetLeafCount()]. \param[in, out] context for use when. + /// Allocating memory, etc. \param[out] write_leaf_callback Callback to. + /// Receive the result. + virtual ::arrow::Status write( + int leafIndex, ArrowWriteContext* context, - CallbackFunction write_leaf_callback) = 0; + CallbackFunction writeLeafCallback) = 0; }; } // namespace arrow diff --git a/velox/dwio/parquet/writer/arrow/Platform.cpp b/velox/dwio/parquet/writer/arrow/Platform.cpp index 66f27575d5d..c3d3e758a6e 100644 --- a/velox/dwio/parquet/writer/arrow/Platform.cpp +++ b/velox/dwio/parquet/writer/arrow/Platform.cpp @@ -27,7 +27,7 @@ namespace facebook::velox::parquet::arrow { -std::shared_ptr<::arrow::io::BufferOutputStream> CreateOutputStream( +std::shared_ptr<::arrow::io::BufferOutputStream> createOutputStream( MemoryPool* pool) { PARQUET_ASSIGN_OR_THROW( auto stream, @@ -35,7 +35,7 @@ std::shared_ptr<::arrow::io::BufferOutputStream> CreateOutputStream( return stream; } -std::shared_ptr AllocateBuffer( +std::shared_ptr allocateBuffer( MemoryPool* pool, int64_t size) { PARQUET_ASSIGN_OR_THROW( diff --git a/velox/dwio/parquet/writer/arrow/Platform.h b/velox/dwio/parquet/writer/arrow/Platform.h index 467d2364bf7..f3ac272ca80 100644 --- a/velox/dwio/parquet/writer/arrow/Platform.h +++ b/velox/dwio/parquet/writer/arrow/Platform.h @@ -31,14 +31,14 @@ #if defined(_MSC_VER) #pragma warning(push) -// Disable warning for STL types usage in DLL interface +// Disable warning for STL types usage in DLL interface. // https://web.archive.org/web/20130317015847/http://connect.microsoft.com/VisualStudio/feedback/details/696593/vc-10-vs-2010-basic-string-exports #pragma warning(disable : 4275 4251) -// Disable diamond inheritance warnings +// Disable diamond inheritance warnings. #pragma warning(disable : 4250) -// Disable macro redefinition warnings +// Disable macro redefinition warnings. #pragma warning(disable : 4005) -// Disable extern before exported template warnings +// Disable extern before exported template warnings. #pragma warning(disable : 4910) #else #pragma GCC diagnostic ignored "-Wattributes" @@ -79,7 +79,7 @@ #define PARQUET_DEPRECATED ARROW_DEPRECATED // If ARROW_VALGRIND set when compiling unit tests, also define -// PARQUET_VALGRIND +// PARQUET_VALGRIND. #ifdef ARROW_VALGRIND #define PARQUET_VALGRIND #endif @@ -100,11 +100,11 @@ constexpr int64_t kDefaultOutputStreamSize = 1024; constexpr int16_t kNonPageOrdinal = static_cast(-1); PARQUET_EXPORT -std::shared_ptr<::arrow::io::BufferOutputStream> CreateOutputStream( +std::shared_ptr<::arrow::io::BufferOutputStream> createOutputStream( ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); PARQUET_EXPORT -std::shared_ptr AllocateBuffer( +std::shared_ptr allocateBuffer( ::arrow::MemoryPool* pool = ::arrow::default_memory_pool(), int64_t size = 0); diff --git a/velox/dwio/parquet/writer/arrow/Properties.cpp b/velox/dwio/parquet/writer/arrow/Properties.cpp index c7588e79973..dc8e49bf63d 100644 --- a/velox/dwio/parquet/writer/arrow/Properties.cpp +++ b/velox/dwio/parquet/writer/arrow/Properties.cpp @@ -27,32 +27,32 @@ namespace facebook::velox::parquet::arrow { -ReaderProperties default_reader_properties() { - static ReaderProperties default_reader_properties; - return default_reader_properties; +ReaderProperties defaultReaderProperties() { + static ReaderProperties defaultReaderProperties; + return defaultReaderProperties; } -std::shared_ptr ReaderProperties::GetStream( +std::shared_ptr ReaderProperties::getStream( std::shared_ptr source, int64_t start, - int64_t num_bytes) { - if (buffered_stream_enabled_) { - // ARROW-6180 / PARQUET-1636 Create isolated reader that references segment - // of source + int64_t numBytes) { + if (bufferedStreamEnabled_) { + // ARROW-6180 / PARQUET-1636 Create isolated reader that references segment. + // Of source. PARQUET_ASSIGN_OR_THROW( - std::shared_ptr<::arrow::io::InputStream> safe_stream, - ::arrow::io::RandomAccessFile::GetStream(source, start, num_bytes)); + std::shared_ptr<::arrow::io::InputStream> safeStream, + ::arrow::io::RandomAccessFile::GetStream(source, start, numBytes)); PARQUET_ASSIGN_OR_THROW( auto stream, ::arrow::io::BufferedInputStream::Create( - buffer_size_, pool_, safe_stream, num_bytes)); + bufferSize_, pool_, safeStream, numBytes)); return std::move(stream); } else { - PARQUET_ASSIGN_OR_THROW(auto data, source->ReadAt(start, num_bytes)); + PARQUET_ASSIGN_OR_THROW(auto data, source->ReadAt(start, numBytes)); - if (data->size() != num_bytes) { + if (data->size() != numBytes) { std::stringstream ss; - ss << "Tried reading " << num_bytes << " bytes starting at position " + ss << "Tried reading " << numBytes << " bytes starting at position " << start << " from file but only got " << data->size(); throw ParquetException(ss.str()); } @@ -65,15 +65,15 @@ ::arrow::internal::Executor* ArrowWriterProperties::executor() const { : ::arrow::internal::GetCpuThreadPool(); } -ArrowReaderProperties default_arrow_reader_properties() { - static ArrowReaderProperties default_reader_props; - return default_reader_props; +ArrowReaderProperties defaultArrowReaderProperties() { + static ArrowReaderProperties defaultReaderProps; + return defaultReaderProps; } -std::shared_ptr default_arrow_writer_properties() { - static std::shared_ptr default_writer_properties = +std::shared_ptr defaultArrowWriterProperties() { + static std::shared_ptr defaultWriterProperties = ArrowWriterProperties::Builder().build(); - return default_writer_properties; + return defaultWriterProperties; } } // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/Properties.h b/velox/dwio/parquet/writer/arrow/Properties.h index 815e2c7d6f5..60cd17ff426 100644 --- a/velox/dwio/parquet/writer/arrow/Properties.h +++ b/velox/dwio/parquet/writer/arrow/Properties.h @@ -34,7 +34,7 @@ #include "velox/dwio/parquet/writer/arrow/Types.h" #include "velox/dwio/parquet/writer/arrow/util/Compression.h" -// Define the parquet created by version. +// Define the Parquet created by version. #define CREATED_BY_VERSION "parquet-cpp-velox" // Velox has no versioning yet. Set default 0.0.0. #define VELOX_VERSION "0.0.0" @@ -43,7 +43,7 @@ namespace facebook::velox::parquet::arrow { using facebook::velox::parquet::arrow::util::CodecOptions; -/// \brief Feature selection when writing Parquet files +/// \brief Feature selection when writing Parquet files. /// /// `ParquetVersion::type` governs which data types are allowed and how they /// are represented. For example, uint32_t data will be written differently @@ -55,21 +55,20 @@ using facebook::velox::parquet::arrow::util::CodecOptions; /// ArrowWriterProperties. struct ParquetVersion { enum type : int { - /// Enable only pre-2.2 Parquet format features when writing + /// Enable only pre-2.2 Parquet format features when writing. /// /// This setting is useful for maximum compatibility with legacy readers. - /// Note that logical types may still be emitted, as long they have a - /// corresponding converted type. + /// Corresponds to a converted type. PARQUET_1_0, - /// DEPRECATED: Enable Parquet format 2.6 features + /// DEPRECATED: Enable Parquet format 2.6 features. /// /// This misleadingly named enum value is roughly similar to PARQUET_2_6. PARQUET_2_0 ARROW_DEPRECATED_ENUM_VALUE( "use PARQUET_2_4 or PARQUET_2_6 " "for fine-grained feature selection"), - /// Enable Parquet format 2.4 and earlier features when writing + /// Enable Parquet format 2.4 and earlier features when writing. /// /// This enables UINT32 as well as logical types which don't have /// a corresponding converted type. @@ -77,23 +76,21 @@ struct ParquetVersion { /// Note: Parquet format 2.4.0 was released in October 2017. PARQUET_2_4, - /// Enable Parquet format 2.6 and earlier features when writing + /// Enable Parquet format 2.6 and earlier features when writing. /// - /// This enables the NANOS time unit in addition to the PARQUET_2_4 - /// features. + /// units in addition to the PARQUET_2_4 features. /// /// Note: Parquet format 2.6.0 was released in September 2018. PARQUET_2_6, - /// Enable latest Parquet format 2.x features + /// Enable latest Parquet format 2.x features. /// - /// This value is equal to the greatest 2.x version supported by - /// this library. + /// The version supported by this library. PARQUET_2_LATEST = PARQUET_2_6 }; }; -/// Controls serialization format of data pages. parquet-format v2.0.0 +/// Controls serialization format of data pages. Parquet-format v2.0.0 /// introduced a new data page metadata type DataPageV2 and serialized page /// structure (for example, encoded levels are no longer compressed). Prior to /// the completion of PARQUET-457 in 2020, this library did not implement @@ -117,96 +114,96 @@ class PARQUET_EXPORT ReaderProperties { explicit ReaderProperties(MemoryPool* pool = ::arrow::default_memory_pool()) : pool_(pool) {} - MemoryPool* memory_pool() const { + MemoryPool* memoryPool() const { return pool_; } - std::shared_ptr GetStream( + std::shared_ptr getStream( std::shared_ptr source, int64_t start, - int64_t num_bytes); + int64_t numBytes); /// Buffered stream reading allows the user to control the memory usage of - /// parquet readers. This ensure that all `RandomAccessFile::ReadAt` calls are - /// wrapped in a buffered reader that uses a fix sized buffer (of size - /// `buffer_size()`) instead of the full size of the ReadAt. + /// Parquet readers. This ensure that all `RandomAccessFile::ReadAt` calls + /// are wrapped in a buffered reader that uses a fix sized buffer (of size + /// `bufferSize()`) instead of the full size of the ReadAt. /// - /// The primary reason for this control knobs is for resource control and not + /// The primary reason for this control knob is for resource control and not /// performance. - bool is_buffered_stream_enabled() const { - return buffered_stream_enabled_; + bool isBufferedStreamEnabled() const { + return bufferedStreamEnabled_; } /// Enable buffered stream reading. - void enable_buffered_stream() { - buffered_stream_enabled_ = true; + void enableBufferedStream() { + bufferedStreamEnabled_ = true; } /// Disable buffered stream reading. - void disable_buffered_stream() { - buffered_stream_enabled_ = false; + void disableBufferedStream() { + bufferedStreamEnabled_ = false; } /// Return the size of the buffered stream buffer. - int64_t buffer_size() const { - return buffer_size_; + int64_t bufferSize() const { + return bufferSize_; } /// Set the size of the buffered stream buffer in bytes. - void set_buffer_size(int64_t size) { - buffer_size_ = size; + void setBufferSize(int64_t size) { + bufferSize_ = size; } /// \brief Return the size limit on thrift strings. /// /// This limit helps prevent space and time bombs in files, but may need to /// be increased in order to read files with especially large headers. - int32_t thrift_string_size_limit() const { - return thrift_string_size_limit_; + int32_t thriftStringSizeLimit() const { + return thriftStringSizeLimit_; } /// Set the size limit on thrift strings. - void set_thrift_string_size_limit(int32_t size) { - thrift_string_size_limit_ = size; + void setThriftStringSizeLimit(int32_t size) { + thriftStringSizeLimit_ = size; } /// \brief Return the size limit on thrift containers. /// /// This limit helps prevent space and time bombs in files, but may need to /// be increased in order to read files with especially large headers. - int32_t thrift_container_size_limit() const { - return thrift_container_size_limit_; + int32_t thriftContainerSizeLimit() const { + return thriftContainerSizeLimit_; } /// Set the size limit on thrift containers. - void set_thrift_container_size_limit(int32_t size) { - thrift_container_size_limit_ = size; + void setThriftContainerSizeLimit(int32_t size) { + thriftContainerSizeLimit_ = size; } /// Set the decryption properties. - void file_decryption_properties( + void setFileDecryptionProperties( std::shared_ptr decryption) { - file_decryption_properties_ = std::move(decryption); + fileDecryptionProperties_ = std::move(decryption); } /// Return the decryption properties. - const std::shared_ptr& file_decryption_properties() + const std::shared_ptr& fileDecryptionProperties() const { - return file_decryption_properties_; + return fileDecryptionProperties_; } - bool page_checksum_verification() const { - return page_checksum_verification_; + bool pageChecksumVerification() const { + return pageChecksumVerification_; } - void set_page_checksum_verification(bool check_crc) { - page_checksum_verification_ = check_crc; + void setPageChecksumVerification(bool checkCrc) { + pageChecksumVerification_ = checkCrc; } private: MemoryPool* pool_; - int64_t buffer_size_ = kDefaultBufferSize; - int32_t thrift_string_size_limit_ = kDefaultThriftStringSizeLimit; - int32_t thrift_container_size_limit_ = kDefaultThriftContainerSizeLimit; - bool buffered_stream_enabled_ = false; - bool page_checksum_verification_ = false; - std::shared_ptr file_decryption_properties_; + int64_t bufferSize_ = kDefaultBufferSize; + int32_t thriftStringSizeLimit_ = kDefaultThriftStringSizeLimit; + int32_t thriftContainerSizeLimit_ = kDefaultThriftContainerSizeLimit; + bool bufferedStreamEnabled_ = false; + bool pageChecksumVerification_ = false; + std::shared_ptr fileDecryptionProperties_; }; -ReaderProperties PARQUET_EXPORT default_reader_properties(); +ReaderProperties PARQUET_EXPORT defaultReaderProperties(); static constexpr int64_t kDefaultDataPageSize = 1024 * 1024; static constexpr bool DEFAULT_IS_DICTIONARY_ENABLED = true; @@ -216,7 +213,7 @@ static constexpr int64_t DEFAULT_WRITE_BATCH_SIZE = 1024; static constexpr int64_t DEFAULT_MAX_ROW_GROUP_LENGTH = 1024 * 1024; static constexpr bool DEFAULT_ARE_STATISTICS_ENABLED = true; static constexpr int64_t DEFAULT_MAX_STATISTICS_SIZE = 4096; -static constexpr Encoding::type DEFAULT_ENCODING = Encoding::UNKNOWN; +static constexpr Encoding::type DEFAULT_ENCODING = Encoding::kUnknown; static const char DEFAULT_CREATED_BY[] = CREATED_BY_VERSION; static constexpr Compression::type DEFAULT_COMPRESSION_TYPE = Compression::UNCOMPRESSED; @@ -227,50 +224,50 @@ class PARQUET_EXPORT ColumnProperties { ColumnProperties( Encoding::type encoding = DEFAULT_ENCODING, Compression::type codec = DEFAULT_COMPRESSION_TYPE, - bool dictionary_enabled = DEFAULT_IS_DICTIONARY_ENABLED, - bool statistics_enabled = DEFAULT_ARE_STATISTICS_ENABLED, - size_t max_stats_size = DEFAULT_MAX_STATISTICS_SIZE, - bool page_index_enabled = DEFAULT_IS_PAGE_INDEX_ENABLED) + bool dictionaryEnabled = DEFAULT_IS_DICTIONARY_ENABLED, + bool statisticsEnabled = DEFAULT_ARE_STATISTICS_ENABLED, + size_t maxStatsSize = DEFAULT_MAX_STATISTICS_SIZE, + bool pageIndexEnabled = DEFAULT_IS_PAGE_INDEX_ENABLED) : encoding_(encoding), codec_(codec), - dictionary_enabled_(dictionary_enabled), - statistics_enabled_(statistics_enabled), - max_stats_size_(max_stats_size), - page_index_enabled_(page_index_enabled) {} + dictionaryEnabled_(dictionaryEnabled), + statisticsEnabled_(statisticsEnabled), + maxStatsSize_(maxStatsSize), + pageIndexEnabled_(pageIndexEnabled) {} - void set_encoding(Encoding::type encoding) { + void setEncoding(Encoding::type encoding) { encoding_ = encoding; } - void set_compression(Compression::type codec) { + void setCompression(Compression::type codec) { codec_ = codec; } - void set_dictionary_enabled(bool dictionary_enabled) { - dictionary_enabled_ = dictionary_enabled; + void setDictionaryEnabled(bool dictionaryEnabled) { + dictionaryEnabled_ = dictionaryEnabled; } - void set_statistics_enabled(bool statistics_enabled) { - statistics_enabled_ = statistics_enabled; + void setStatisticsEnabled(bool statisticsEnabled) { + statisticsEnabled_ = statisticsEnabled; } - void set_max_statistics_size(size_t max_stats_size) { - max_stats_size_ = max_stats_size; + void setMaxStatisticsSize(size_t maxStatsSize) { + maxStatsSize_ = maxStatsSize; } - void set_compression_level(int compression_level) { - if (!codec_options_) { - codec_options_ = std::make_shared(); + void setCompressionLevel(int compressionLevel) { + if (!codecOptions_) { + codecOptions_ = std::make_shared(); } - codec_options_->compression_level = compression_level; + codecOptions_->compressionLevel = compressionLevel; } - void set_codec_options(const std::shared_ptr& codec_options) { - codec_options_ = codec_options; + void setCodecOptions(const std::shared_ptr& codecOptions) { + codecOptions_ = codecOptions; } - void set_page_index_enabled(bool page_index_enabled) { - page_index_enabled_ = page_index_enabled; + void setPageIndexEnabled(bool pageIndexEnabled) { + pageIndexEnabled_ = pageIndexEnabled; } Encoding::type encoding() const { @@ -281,38 +278,38 @@ class PARQUET_EXPORT ColumnProperties { return codec_; } - bool dictionary_enabled() const { - return dictionary_enabled_; + bool dictionaryEnabled() const { + return dictionaryEnabled_; } - bool statistics_enabled() const { - return statistics_enabled_; + bool statisticsEnabled() const { + return statisticsEnabled_; } - size_t max_statistics_size() const { - return max_stats_size_; + size_t maxStatisticsSize() const { + return maxStatsSize_; } - int compression_level() const { - return codec_options_->compression_level; + int compressionLevel() const { + return codecOptions_->compressionLevel; } - const std::shared_ptr& codec_options() const { - return codec_options_; + const std::shared_ptr& codecOptions() const { + return codecOptions_; } - bool page_index_enabled() const { - return page_index_enabled_; + bool pageIndexEnabled() const { + return pageIndexEnabled_; } private: Encoding::type encoding_; Compression::type codec_; - bool dictionary_enabled_; - bool statistics_enabled_; - size_t max_stats_size_; - std::shared_ptr codec_options_; - bool page_index_enabled_; + bool dictionaryEnabled_; + bool statisticsEnabled_; + size_t maxStatsSize_; + std::shared_ptr codecOptions_; + bool pageIndexEnabled_; }; class PARQUET_EXPORT WriterProperties { @@ -321,95 +318,94 @@ class PARQUET_EXPORT WriterProperties { public: Builder() : pool_(::arrow::default_memory_pool()), - dictionary_pagesize_limit_(DEFAULT_DICTIONARY_PAGE_SIZE_LIMIT), - write_batch_size_(DEFAULT_WRITE_BATCH_SIZE), - max_row_group_length_(DEFAULT_MAX_ROW_GROUP_LENGTH), + dictionaryPagesizeLimit_(DEFAULT_DICTIONARY_PAGE_SIZE_LIMIT), + writeBatchSize_(DEFAULT_WRITE_BATCH_SIZE), + maxRowGroupLength_(DEFAULT_MAX_ROW_GROUP_LENGTH), pagesize_(kDefaultDataPageSize), version_(ParquetVersion::PARQUET_2_6), - data_page_version_(ParquetDataPageVersion::V1), - created_by_( + dataPageVersion_(ParquetDataPageVersion::V1), + createdBy_( DEFAULT_CREATED_BY + std::string(" version ") + VELOX_VERSION), - store_decimal_as_integer_(false), - page_checksum_enabled_(false) {} + storeDecimalAsInteger_(false), + pageChecksumEnabled_(false) {} virtual ~Builder() {} /// Specify the memory pool for the writer. Default default_memory_pool. - Builder* memory_pool(MemoryPool* pool) { + Builder* memoryPool(MemoryPool* pool) { pool_ = pool; return this; } /// Enable dictionary encoding in general for all columns. Default enabled. - Builder* enable_dictionary() { - default_column_properties_.set_dictionary_enabled(true); + Builder* enableDictionary() { + defaultColumnProperties_.setDictionaryEnabled(true); return this; } /// Disable dictionary encoding in general for all columns. Default enabled. - Builder* disable_dictionary() { - default_column_properties_.set_dictionary_enabled(false); + Builder* disableDictionary() { + defaultColumnProperties_.setDictionaryEnabled(false); return this; } /// Enable dictionary encoding for column specified by `path`. Default /// enabled. - Builder* enable_dictionary(const std::string& path) { - dictionary_enabled_[path] = true; + Builder* enableDictionary(const std::string& path) { + dictionaryEnabled_[path] = true; return this; } /// Enable dictionary encoding for column specified by `path`. Default /// enabled. - Builder* enable_dictionary( - const std::shared_ptr& path) { - return this->enable_dictionary(path->ToDotString()); + Builder* enableDictionary(const std::shared_ptr& path) { + return this->enableDictionary(path->toDotString()); } /// Disable dictionary encoding for column specified by `path`. Default /// enabled. - Builder* disable_dictionary(const std::string& path) { - dictionary_enabled_[path] = false; + Builder* disableDictionary(const std::string& path) { + dictionaryEnabled_[path] = false; return this; } /// Disable dictionary encoding for column specified by `path`. Default /// enabled. - Builder* disable_dictionary( + Builder* disableDictionary( const std::shared_ptr& path) { - return this->disable_dictionary(path->ToDotString()); + return this->disableDictionary(path->toDotString()); } /// Specify the dictionary page size limit per row group. Default 1MB. - Builder* dictionary_pagesize_limit(int64_t dictionary_psize_limit) { - dictionary_pagesize_limit_ = dictionary_psize_limit; + Builder* dictionaryPagesizeLimit(int64_t dictionaryPsizeLimit) { + dictionaryPagesizeLimit_ = dictionaryPsizeLimit; return this; } /// Specify the write batch size while writing batches of Arrow values into /// Parquet. Default 1024. - Builder* write_batch_size(int64_t write_batch_size) { - write_batch_size_ = write_batch_size; + Builder* writeBatchSize(int64_t writeBatchSize) { + writeBatchSize_ = writeBatchSize; return this; } /// Specify the max number of rows to put in a single row group. /// Default 1Mi rows. - Builder* max_row_group_length(int64_t max_row_group_length) { - max_row_group_length_ = max_row_group_length; + Builder* maxRowGroupLength(int64_t maxRowGroupLength) { + maxRowGroupLength_ = maxRowGroupLength; return this; } /// Specify the data page size. /// Default 1MB. - Builder* data_pagesize(int64_t pg_size) { - pagesize_ = pg_size; + Builder* dataPagesize(int64_t pgSize) { + pagesize_ = pgSize; return this; } /// Specify the data page version. /// Default V1. - Builder* data_page_version(ParquetDataPageVersion data_page_version) { - data_page_version_ = data_page_version; + Builder* dataPageVersion(ParquetDataPageVersion dataPageVersion) { + dataPageVersion_ = dataPageVersion; return this; } @@ -420,75 +416,75 @@ class PARQUET_EXPORT WriterProperties { return this; } - Builder* created_by(const std::string& created_by) { - created_by_ = created_by; + Builder* createdBy(const std::string& createdBy) { + createdBy_ = createdBy; return this; } - Builder* enable_page_checksum() { - page_checksum_enabled_ = true; + Builder* enablePageChecksum() { + pageChecksumEnabled_ = true; return this; } - Builder* disable_page_checksum() { - page_checksum_enabled_ = false; + Builder* disablePageChecksum() { + pageChecksumEnabled_ = false; return this; } - /// \brief Define the encoding that is used when we don't utilise dictionary - /// encoding. + /// \brief Define the encoding that is used when we don't utilise + /// dictionary encoding. // - /// This either apply if dictionary encoding is disabled or if we fallback - /// as the dictionary grew too large. - Builder* encoding(Encoding::type encoding_type) { - if (encoding_type == Encoding::PLAIN_DICTIONARY || - encoding_type == Encoding::RLE_DICTIONARY) { + /// This either applies if dictionary encoding is disabled or if we + /// fallback because the dictionary grew too large. + Builder* encoding(Encoding::type encodingType) { + if (encodingType == Encoding::kPlainDictionary || + encodingType == Encoding::kRleDictionary) { throw ParquetException( "Can't use dictionary encoding as fallback encoding"); } - default_column_properties_.set_encoding(encoding_type); + defaultColumnProperties_.setEncoding(encodingType); return this; } - /// \brief Define the encoding that is used when we don't utilise dictionary - /// encoding. + /// \brief Define the encoding that is used when we don't utilise + /// dictionary encoding. // - /// This either apply if dictionary encoding is disabled or if we fallback - /// as the dictionary grew too large. - Builder* encoding(const std::string& path, Encoding::type encoding_type) { - if (encoding_type == Encoding::PLAIN_DICTIONARY || - encoding_type == Encoding::RLE_DICTIONARY) { + /// This either applies if dictionary encoding is disabled or if we + /// fallback because the dictionary grew too large. + Builder* encoding(const std::string& path, Encoding::type encodingType) { + if (encodingType == Encoding::kPlainDictionary || + encodingType == Encoding::kRleDictionary) { throw ParquetException( "Can't use dictionary encoding as fallback encoding"); } - encodings_[path] = encoding_type; + encodings_[path] = encodingType; return this; } - /// \brief Define the encoding that is used when we don't utilise dictionary - /// encoding. + /// \brief Define the encoding that is used when we don't utilise + /// dictionary encoding. // - /// This either apply if dictionary encoding is disabled or if we fallback - /// as the dictionary grew too large. + /// This either applies if dictionary encoding is disabled or if we + /// fallback because the dictionary grew too large. Builder* encoding( const std::shared_ptr& path, - Encoding::type encoding_type) { - return this->encoding(path->ToDotString(), encoding_type); + Encoding::type encodingType) { + return this->encoding(path->toDotString(), encodingType); } /// Specify compression codec in general for all columns. /// Default UNCOMPRESSED. Builder* compression(Compression::type codec) { - default_column_properties_.set_compression(codec); + defaultColumnProperties_.setCompression(codec); return this; } /// Specify max statistics size to store min max value. /// Default 4KB. - Builder* max_statistics_size(size_t max_stats_sz) { - default_column_properties_.set_max_statistics_size(max_stats_sz); + Builder* maxStatisticsSize(size_t maxStatsSz) { + defaultColumnProperties_.setMaxStatisticsSize(maxStatsSz); return this; } @@ -504,126 +500,125 @@ class PARQUET_EXPORT WriterProperties { Builder* compression( const std::shared_ptr& path, Compression::type codec) { - return this->compression(path->ToDotString(), codec); + return this->compression(path->toDotString(), codec); } - /// \brief Specify the default compression level for the compressor in - /// every column. In case a column does not have an explicitly specified - /// compression level, the default one would be used. + /// \brief Specify the default compression level for the compressor + /// in every column. In case a column does not have an explicitly + /// specified compression level, the default one would be used. /// - /// The provided compression level is compressor specific. The user would - /// have to familiarize oneself with the available levels for the selected - /// compressor. If the compressor does not allow for selecting different - /// compression levels, calling this function would not have any effect. - /// Parquet and Arrow do not validate the passed compression level. If no - /// level is selected by the user or if the special - /// std::numeric_limits::min() value is passed, then Arrow selects the - /// compression level. + /// The provided compression level is compressor specific. The user + /// would have to familiarize oneself with the available levels for + /// the selected compressor. If the compressor does not allow for + /// selecting different compression levels, calling this function + /// would not have any effect. Parquet and Arrow do not validate the + /// passed compression level. If no level is selected by the user or + /// if the special std::numeric_limits::min() value is passed, + /// then Arrow selects the compression level. /// - /// If other compressor-specific options need to be set in addition to the - /// compression level, use the codec_options method. - Builder* compression_level(int compression_level) { - default_column_properties_.set_compression_level(compression_level); + /// If other compressor-specific options need to be set in addition + /// to the compression level, use the codec_options method. + Builder* compressionLevel(int compressionLevel) { + defaultColumnProperties_.setCompressionLevel(compressionLevel); return this; } - /// \brief Specify a compression level for the compressor for the column - /// described by path. + /// \brief Specify a compression level for the compressor for the + /// column described by path. /// - /// The provided compression level is compressor specific. The user would - /// have to familiarize oneself with the available levels for the selected - /// compressor. If the compressor does not allow for selecting different - /// compression levels, calling this function would not have any effect. - /// Parquet and Arrow do not validate the passed compression level. If no - /// level is selected by the user or if the special - /// std::numeric_limits::min() value is passed, then Arrow selects the - /// compression level. - Builder* compression_level(const std::string& path, int compression_level) { - if (!codec_options_[path]) { - codec_options_[path] = std::make_shared(); + /// The provided compression level is compressor specific. The user + /// would have to familiarize oneself with the available levels for + /// the selected compressor. If the compressor does not allow for + /// selecting different compression levels, calling this function + /// would not have any effect. Parquet and Arrow do not validate the + /// passed compression level. If no level is selected by the user or + /// if the special std::numeric_limits::min() value is passed, + /// then Arrow selects the compression level. + Builder* compressionLevel(const std::string& path, int compressionLevel) { + if (!codecOptions_[path]) { + codecOptions_[path] = std::make_shared(); } - codec_options_[path]->compression_level = compression_level; + codecOptions_[path]->compressionLevel = compressionLevel; return this; } - /// \brief Specify a compression level for the compressor for the column - /// described by path. + /// \brief Specify a compression level for the compressor for the + /// column described by path. /// - /// The provided compression level is compressor specific. The user would - /// have to familiarize oneself with the available levels for the selected - /// compressor. If the compressor does not allow for selecting different - /// compression levels, calling this function would not have any effect. - /// Parquet and Arrow do not validate the passed compression level. If no - /// level is selected by the user or if the special - /// std::numeric_limits::min() value is passed, then Arrow selects the - /// compression level. - Builder* compression_level( + /// The provided compression level is compressor specific. The user + /// would have to familiarize oneself with the available levels for + /// the selected compressor. If the compressor does not allow for + /// selecting different compression levels, calling this function + /// would not have any effect. Parquet and Arrow do not validate the + /// passed compression level. If no level is selected by the user or + /// if the special std::numeric_limits::min() value is passed, + /// then Arrow selects the compression level. + Builder* compressionLevel( const std::shared_ptr& path, - int compression_level) { - return this->compression_level(path->ToDotString(), compression_level); + int compressionLevel) { + return this->compressionLevel(path->toDotString(), compressionLevel); } /// \brief Specify the default codec options for the compressor in /// every column. /// - /// The codec options allow configuring the compression level as well - /// as other codec-specific options. - Builder* codec_options(const std::shared_ptr& codec_options) { - default_column_properties_.set_codec_options(codec_options); + /// The codec options allow configuring the compression level as + /// well as other codec-specific options. + Builder* codecOptions(const std::shared_ptr& codecOptions) { + defaultColumnProperties_.setCodecOptions(codecOptions); return this; } - /// \brief Specify the codec options for the compressor for the column - /// described by path. - Builder* codec_options( + /// \brief Specify the codec options for the compressor for the + /// column described by path. + Builder* codecOptions( const std::string& path, - const std::shared_ptr& codec_options) { - codec_options_[path] = codec_options; + const std::shared_ptr& codecOptions) { + codecOptions_[path] = codecOptions; return this; } - /// \brief Specify the codec options for the compressor for the column - /// described by path. - Builder* codec_options( + /// \brief Specify the codec options for the compressor for the + /// column described by path. + Builder* codecOptions( const std::shared_ptr& path, - const std::shared_ptr& codec_options) { - return this->codec_options(path->ToDotString(), codec_options); + const std::shared_ptr& codecOptions) { + return this->codecOptions(path->toDotString(), codecOptions); } /// Define the file encryption properties. /// Default NULL. Builder* encryption( - std::shared_ptr file_encryption_properties) { - file_encryption_properties_ = std::move(file_encryption_properties); + std::shared_ptr fileEncryptionProperties) { + fileEncryptionProperties_ = std::move(fileEncryptionProperties); return this; } /// Enable statistics in general. /// Default enabled. - Builder* enable_statistics() { - default_column_properties_.set_statistics_enabled(true); + Builder* enableStatistics() { + defaultColumnProperties_.setStatisticsEnabled(true); return this; } /// Disable statistics in general. /// Default enabled. - Builder* disable_statistics() { - default_column_properties_.set_statistics_enabled(false); + Builder* disableStatistics() { + defaultColumnProperties_.setStatisticsEnabled(false); return this; } /// Enable statistics for the column specified by `path`. /// Default enabled. - Builder* enable_statistics(const std::string& path) { - statistics_enabled_[path] = true; + Builder* enableStatistics(const std::string& path) { + statisticsEnabled_[path] = true; return this; } /// Enable statistics for the column specified by `path`. /// Default enabled. - Builder* enable_statistics( - const std::shared_ptr& path) { - return this->enable_statistics(path->ToDotString()); + Builder* enableStatistics(const std::shared_ptr& path) { + return this->enableStatistics(path->toDotString()); } /// Define the sorting columns. @@ -632,307 +627,312 @@ class PARQUET_EXPORT WriterProperties { /// If sorting columns are set, user should ensure that records /// are sorted by sorting columns. Otherwise, the storing data /// will be inconsistent with sorting_columns metadata. - Builder* set_sorting_columns(std::vector sorting_columns) { - sorting_columns_ = std::move(sorting_columns); + Builder* setSortingColumns(std::vector sortingColumns) { + sortingColumns_ = std::move(sortingColumns); return this; } /// Disable statistics for the column specified by `path`. /// Default enabled. - Builder* disable_statistics(const std::string& path) { - statistics_enabled_[path] = false; + Builder* disableStatistics(const std::string& path) { + statisticsEnabled_[path] = false; return this; } /// Disable statistics for the column specified by `path`. /// Default enabled. - Builder* disable_statistics( + Builder* disableStatistics( const std::shared_ptr& path) { - return this->disable_statistics(path->ToDotString()); + return this->disableStatistics(path->toDotString()); } - /// Allow decimals with 1 <= precision <= 18 to be stored as integers. + /// Allow decimals with 1 <= precision <= 18 to be stored as + /// integers. /// - /// In Parquet, DECIMAL can be stored in any of the following physical - /// types: - /// - int32: for 1 <= precision <= 9. - /// - int64: for 10 <= precision <= 18. - /// - fixed_len_byte_array: precision is limited by the array size. - /// Length n can store <= floor(log_10(2^(8*n - 1) - 1)) base-10 digits. - /// - binary: precision is unlimited. The minimum number of bytes to store + /// In Parquet, DECIMAL can be stored in any of the following + /// physical types: + /// - Int32: For 1 <= precision <= 9. + /// - Int64: For 10 <= precision <= 18. + /// - Fixed_len_byte_array: Precision is limited by the array size. + /// Length n can store <= floor(log_10(2^(8*n - 1) - 1)) base-10 + /// digits. + /// - Binary: Precision is unlimited. The minimum number of bytes to + /// store /// the unscaled value is used. /// - /// By default, this is DISABLED and all decimal types annotate - /// fixed_len_byte_array. + /// By default, this is DISABLED and all decimal types annotate. + /// Fixed_len_byte_array. /// - /// When enabled, the C++ writer will use following physical types to store - /// decimals: - /// - int32: for 1 <= precision <= 9. - /// - int64: for 10 <= precision <= 18. - /// - fixed_len_byte_array: for precision > 18. + /// When enabled, the C++ writer will use following physical types + /// to store decimals: + /// - Int32: For 1 <= precision <= 9. + /// - Int64: For 10 <= precision <= 18. + /// - Fixed_len_byte_array: For precision > 18. /// - /// As a consequence, decimal columns stored in integer types are more - /// compact. - Builder* enable_store_decimal_as_integer() { - store_decimal_as_integer_ = true; + /// As a consequence, decimal columns stored in integer types are + /// more compact. + Builder* enableStoreDecimalAsInteger() { + storeDecimalAsInteger_ = true; return this; } - /// Disable decimal logical type with 1 <= precision <= 18 to be stored as - /// integer physical type. + /// Disable decimal logical type with 1 <= precision <= 18 to be + /// stored as integer physical type. /// /// Default disabled. - Builder* disable_store_decimal_as_integer() { - store_decimal_as_integer_ = false; + Builder* disableStoreDecimalAsInteger() { + storeDecimalAsInteger_ = false; return this; } - /// Enable writing page index in general for all columns. Default disabled. + /// Enable writing page index in general for all columns. Default + /// disabled. /// - /// Writing statistics to the page index disables the old method of writing - /// statistics to each data page header. - /// The page index makes filtering more efficient than the page header, as - /// it gathers all the statistics for a Parquet file in a single place, + /// Writing statistics to the page index disables the old method of + /// writing statistics to each data page header. The page index + /// makes filtering more efficient than the page header, as it + /// gathers all the statistics for a Parquet file in a single place, /// avoiding scattered I/O. /// /// Please check the link below for more details: /// https://github.com/apache/parquet-format/blob/master/PageIndex.md - Builder* enable_write_page_index() { - default_column_properties_.set_page_index_enabled(true); + Builder* enableWritePageIndex() { + defaultColumnProperties_.setPageIndexEnabled(true); return this; } - /// Disable writing page index in general for all columns. Default disabled. - Builder* disable_write_page_index() { - default_column_properties_.set_page_index_enabled(false); + /// Disable writing page index in general for all columns. Default + /// disabled. + Builder* disableWritePageIndex() { + defaultColumnProperties_.setPageIndexEnabled(false); return this; } - /// Enable writing page index for column specified by `path`. Default - /// disabled. - Builder* enable_write_page_index(const std::string& path) { - page_index_enabled_[path] = true; + /// Enable writing page index for column specified by `path`. + /// Default disabled. + Builder* enableWritePageIndex(const std::string& path) { + pageIndexEnabled_[path] = true; return this; } - /// Enable writing page index for column specified by `path`. Default - /// disabled. - Builder* enable_write_page_index( + /// Enable writing page index for column specified by `path`. + /// Default disabled. + Builder* enableWritePageIndex( const std::shared_ptr& path) { - return this->enable_write_page_index(path->ToDotString()); + return this->enableWritePageIndex(path->toDotString()); } - /// Disable writing page index for column specified by `path`. Default - /// disabled. - Builder* disable_write_page_index(const std::string& path) { - page_index_enabled_[path] = false; + /// Disable writing page index for column specified by `path`. + /// Default disabled. + Builder* disableWritePageIndex(const std::string& path) { + pageIndexEnabled_[path] = false; return this; } - /// Disable writing page index for column specified by `path`. Default - /// disabled. - Builder* disable_write_page_index( + /// Disable writing page index for column specified by `path`. + /// Default disabled. + Builder* disableWritePageIndex( const std::shared_ptr& path) { - return this->disable_write_page_index(path->ToDotString()); + return this->disableWritePageIndex(path->toDotString()); } /// \brief Build the WriterProperties with the builder parameters. /// \return The WriterProperties defined by the builder. std::shared_ptr build() { - std::unordered_map column_properties; + std::unordered_map columnProperties; auto get = [&](const std::string& key) -> ColumnProperties& { - auto it = column_properties.find(key); - if (it == column_properties.end()) - return column_properties[key] = default_column_properties_; + auto it = columnProperties.find(key); + if (it == columnProperties.end()) + return columnProperties[key] = defaultColumnProperties_; else return it->second; }; for (const auto& item : encodings_) - get(item.first).set_encoding(item.second); + get(item.first).setEncoding(item.second); for (const auto& item : codecs_) - get(item.first).set_compression(item.second); - for (const auto& item : codec_options_) - get(item.first).set_codec_options(item.second); - for (const auto& item : dictionary_enabled_) - get(item.first).set_dictionary_enabled(item.second); - for (const auto& item : statistics_enabled_) - get(item.first).set_statistics_enabled(item.second); - for (const auto& item : page_index_enabled_) - get(item.first).set_page_index_enabled(item.second); + get(item.first).setCompression(item.second); + for (const auto& item : codecOptions_) + get(item.first).setCodecOptions(item.second); + for (const auto& item : dictionaryEnabled_) + get(item.first).setDictionaryEnabled(item.second); + for (const auto& item : statisticsEnabled_) + get(item.first).setStatisticsEnabled(item.second); + for (const auto& item : pageIndexEnabled_) + get(item.first).setPageIndexEnabled(item.second); return std::shared_ptr(new WriterProperties( pool_, - dictionary_pagesize_limit_, - write_batch_size_, - max_row_group_length_, + dictionaryPagesizeLimit_, + writeBatchSize_, + maxRowGroupLength_, pagesize_, version_, - created_by_, - page_checksum_enabled_, - std::move(file_encryption_properties_), - default_column_properties_, - column_properties, - data_page_version_, - store_decimal_as_integer_, - std::move(sorting_columns_))); + createdBy_, + pageChecksumEnabled_, + std::move(fileEncryptionProperties_), + defaultColumnProperties_, + columnProperties, + dataPageVersion_, + storeDecimalAsInteger_, + std::move(sortingColumns_))); } private: MemoryPool* pool_; - int64_t dictionary_pagesize_limit_; - int64_t write_batch_size_; - int64_t max_row_group_length_; + int64_t dictionaryPagesizeLimit_; + int64_t writeBatchSize_; + int64_t maxRowGroupLength_; int64_t pagesize_; ParquetVersion::type version_; - ParquetDataPageVersion data_page_version_; - std::string created_by_; - bool store_decimal_as_integer_; - bool page_checksum_enabled_; + ParquetDataPageVersion dataPageVersion_; + std::string createdBy_; + bool storeDecimalAsInteger_; + bool pageChecksumEnabled_; - std::shared_ptr file_encryption_properties_; + std::shared_ptr fileEncryptionProperties_; // If empty, there is no sorting columns. - std::vector sorting_columns_; + std::vector sortingColumns_; - // Settings used for each column unless overridden in any of the maps below - ColumnProperties default_column_properties_; + // Settings used for each column unless overridden in any of the + // maps below. + ColumnProperties defaultColumnProperties_; std::unordered_map encodings_; std::unordered_map codecs_; std::unordered_map> - codec_options_; - std::unordered_map dictionary_enabled_; - std::unordered_map statistics_enabled_; - std::unordered_map page_index_enabled_; + codecOptions_; + std::unordered_map dictionaryEnabled_; + std::unordered_map statisticsEnabled_; + std::unordered_map pageIndexEnabled_; }; - inline MemoryPool* memory_pool() const { + inline MemoryPool* memoryPool() const { return pool_; } - inline int64_t dictionary_pagesize_limit() const { - return dictionary_pagesize_limit_; + inline int64_t dictionaryPagesizeLimit() const { + return dictionaryPagesizeLimit_; } - inline int64_t write_batch_size() const { - return write_batch_size_; + inline int64_t writeBatchSize() const { + return writeBatchSize_; } - inline int64_t max_row_group_length() const { - return max_row_group_length_; + inline int64_t maxRowGroupLength() const { + return maxRowGroupLength_; } - inline int64_t data_pagesize() const { + inline int64_t dataPagesize() const { return pagesize_; } - inline ParquetDataPageVersion data_page_version() const { - return parquet_data_page_version_; + inline ParquetDataPageVersion dataPageVersion() const { + return parquetDataPageVersion_; } inline ParquetVersion::type version() const { - return parquet_version_; + return parquetVersion_; } - inline std::string created_by() const { - return parquet_created_by_; + inline std::string createdBy() const { + return parquetCreatedBy_; } - inline bool store_decimal_as_integer() const { - return store_decimal_as_integer_; + inline bool storeDecimalAsInteger() const { + return storeDecimalAsInteger_; } - inline bool page_checksum_enabled() const { - return page_checksum_enabled_; + inline bool pageChecksumEnabled() const { + return pageChecksumEnabled_; } - inline Encoding::type dictionary_index_encoding() const { - if (parquet_version_ == ParquetVersion::PARQUET_1_0) { - return Encoding::PLAIN_DICTIONARY; + inline Encoding::type dictionaryIndexEncoding() const { + if (parquetVersion_ == ParquetVersion::PARQUET_1_0) { + return Encoding::kPlainDictionary; } else { - return Encoding::RLE_DICTIONARY; + return Encoding::kRleDictionary; } } - inline Encoding::type dictionary_page_encoding() const { - if (parquet_version_ == ParquetVersion::PARQUET_1_0) { - return Encoding::PLAIN_DICTIONARY; + inline Encoding::type dictionaryPageEncoding() const { + if (parquetVersion_ == ParquetVersion::PARQUET_1_0) { + return Encoding::kPlainDictionary; } else { - return Encoding::PLAIN; + return Encoding::kPlain; } } - const ColumnProperties& column_properties( + const ColumnProperties& columnProperties( const std::shared_ptr& path) const { - auto it = column_properties_.find(path->ToDotString()); - if (it != column_properties_.end()) + auto it = columnProperties_.find(path->toDotString()); + if (it != columnProperties_.end()) return it->second; - return default_column_properties_; + return defaultColumnProperties_; } Encoding::type encoding( const std::shared_ptr& path) const { - return column_properties(path).encoding(); + return columnProperties(path).encoding(); } Compression::type compression( const std::shared_ptr& path) const { - return column_properties(path).compression(); + return columnProperties(path).compression(); } - int compression_level(const std::shared_ptr& path) const { - return column_properties(path).compression_level(); + int compressionLevel(const std::shared_ptr& path) const { + return columnProperties(path).compressionLevel(); } - const std::shared_ptr codec_options( + const std::shared_ptr codecOptions( const std::shared_ptr& path) const { - return column_properties(path).codec_options(); + return columnProperties(path).codecOptions(); } - bool dictionary_enabled( + bool dictionaryEnabled( const std::shared_ptr& path) const { - return column_properties(path).dictionary_enabled(); + return columnProperties(path).dictionaryEnabled(); } - const std::vector& sorting_columns() const { - return sorting_columns_; + const std::vector& sortingColumns() const { + return sortingColumns_; } - bool statistics_enabled( + bool statisticsEnabled( const std::shared_ptr& path) const { - return column_properties(path).statistics_enabled(); + return columnProperties(path).statisticsEnabled(); } - size_t max_statistics_size( + size_t maxStatisticsSize( const std::shared_ptr& path) const { - return column_properties(path).max_statistics_size(); + return columnProperties(path).maxStatisticsSize(); } - bool page_index_enabled( - const std::shared_ptr& path) const { - return column_properties(path).page_index_enabled(); + bool pageIndexEnabled(const std::shared_ptr& path) const { + return columnProperties(path).pageIndexEnabled(); } - bool page_index_enabled() const { - if (default_column_properties_.page_index_enabled()) { + bool pageIndexEnabled() const { + if (defaultColumnProperties_.pageIndexEnabled()) { return true; } - for (const auto& item : column_properties_) { - if (item.second.page_index_enabled()) { + for (const auto& item : columnProperties_) { + if (item.second.pageIndexEnabled()) { return true; } } return false; } - inline FileEncryptionProperties* file_encryption_properties() const { - return file_encryption_properties_.get(); + inline FileEncryptionProperties* fileEncryptionProperties() const { + return fileEncryptionProperties_.get(); } - std::shared_ptr column_encryption_properties( + std::shared_ptr columnEncryptionProperties( const std::string& path) const { - if (file_encryption_properties_) { - return file_encryption_properties_->column_encryption_properties(path); + if (fileEncryptionProperties_) { + return fileEncryptionProperties_->columnEncryptionProperties(path); } else { return NULLPTR; } @@ -941,103 +941,103 @@ class PARQUET_EXPORT WriterProperties { private: explicit WriterProperties( MemoryPool* pool, - int64_t dictionary_pagesize_limit, - int64_t write_batch_size, - int64_t max_row_group_length, + int64_t dictionaryPagesizeLimit, + int64_t writeBatchSize, + int64_t maxRowGroupLength, int64_t pagesize, ParquetVersion::type version, - const std::string& created_by, - bool page_write_checksum_enabled, - std::shared_ptr file_encryption_properties, - const ColumnProperties& default_column_properties, - const std::unordered_map& - column_properties, - ParquetDataPageVersion data_page_version, - bool store_short_decimal_as_integer, - std::vector sorting_columns) + const std::string& createdBy, + bool pageWriteChecksumEnabled, + std::shared_ptr fileEncryptionProperties, + const ColumnProperties& defaultColumnProperties, + const std::unordered_map& columnProperties, + ParquetDataPageVersion dataPageVersion, + bool storeShortDecimalAsInteger, + std::vector sortingColumns) : pool_(pool), - dictionary_pagesize_limit_(dictionary_pagesize_limit), - write_batch_size_(write_batch_size), - max_row_group_length_(max_row_group_length), + dictionaryPagesizeLimit_(dictionaryPagesizeLimit), + writeBatchSize_(writeBatchSize), + maxRowGroupLength_(maxRowGroupLength), pagesize_(pagesize), - parquet_data_page_version_(data_page_version), - parquet_version_(version), - parquet_created_by_(created_by), - store_decimal_as_integer_(store_short_decimal_as_integer), - page_checksum_enabled_(page_write_checksum_enabled), - file_encryption_properties_(file_encryption_properties), - sorting_columns_(std::move(sorting_columns)), - default_column_properties_(default_column_properties), - column_properties_(column_properties) {} + parquetDataPageVersion_(dataPageVersion), + parquetVersion_(version), + parquetCreatedBy_(createdBy), + storeDecimalAsInteger_(storeShortDecimalAsInteger), + pageChecksumEnabled_(pageWriteChecksumEnabled), + fileEncryptionProperties_(std::move(fileEncryptionProperties)), + sortingColumns_(std::move(sortingColumns)), + defaultColumnProperties_(defaultColumnProperties), + columnProperties_(columnProperties) {} MemoryPool* pool_; - int64_t dictionary_pagesize_limit_; - int64_t write_batch_size_; - int64_t max_row_group_length_; + int64_t dictionaryPagesizeLimit_; + int64_t writeBatchSize_; + int64_t maxRowGroupLength_; int64_t pagesize_; - ParquetDataPageVersion parquet_data_page_version_; - ParquetVersion::type parquet_version_; - std::string parquet_created_by_; - bool store_decimal_as_integer_; - bool page_checksum_enabled_; + ParquetDataPageVersion parquetDataPageVersion_; + ParquetVersion::type parquetVersion_; + std::string parquetCreatedBy_; + bool storeDecimalAsInteger_; + bool pageChecksumEnabled_; - std::shared_ptr file_encryption_properties_; + std::shared_ptr fileEncryptionProperties_; - std::vector sorting_columns_; + std::vector sortingColumns_; - ColumnProperties default_column_properties_; - std::unordered_map column_properties_; + ColumnProperties defaultColumnProperties_; + std::unordered_map columnProperties_; }; PARQUET_EXPORT const std::shared_ptr& -default_writer_properties(); +defaultWriterProperties(); -// ---------------------------------------------------------------------- -// Properties specific to Apache Arrow columnar read and write +// ----------------------------------------------------------------------. +// Properties specific to Apache Arrow columnar read and write. static constexpr bool kArrowDefaultUseThreads = false; -// Default number of rows to read when using ::arrow::RecordBatchReader +// Default number of rows to read when using ::arrow::RecordBatchReader. static constexpr int64_t kArrowDefaultBatchSize = 64 * 1024; /// EXPERIMENTAL: Properties for configuring FileReader behavior. class PARQUET_EXPORT ArrowReaderProperties { public: - explicit ArrowReaderProperties(bool use_threads = kArrowDefaultUseThreads) - : use_threads_(use_threads), - read_dict_indices_(), - batch_size_(kArrowDefaultBatchSize), - pre_buffer_(false), - cache_options_(::arrow::io::CacheOptions::Defaults()), - coerce_int96_timestamp_unit_(::arrow::TimeUnit::NANO) {} - - /// \brief Set whether to use the IO thread pool to parse columns in parallel. + explicit ArrowReaderProperties(bool useThreads = kArrowDefaultUseThreads) + : useThreads_(useThreads), + readDictIndices_(), + batchSize_(kArrowDefaultBatchSize), + preBuffer_(false), + cacheOptions_(::arrow::io::CacheOptions::Defaults()), + coerceInt96TimestampUnit_(::arrow::TimeUnit::NANO) {} + + /// \brief Set whether to use the IO thread pool to parse columns in + /// parallel. /// /// Default is false. - void set_use_threads(bool use_threads) { - use_threads_ = use_threads; + void setUseThreads(bool useThreads) { + useThreads_ = useThreads; } /// Return whether will use multiple threads. - bool use_threads() const { - return use_threads_; + bool useThreads() const { + return useThreads_; } - /// \brief Set whether to read a particular column as dictionary encoded. + /// \brief Set whether to read a particular column as dictionary + /// encoded. /// - /// If the file metadata contains a serialized Arrow schema, then ... - //// - /// This is only supported for columns with a Parquet physical type of + /// If the file metadata contains a serialized Arrow schema, then this + /// is only supported for columns with a Parquet physical type of /// BYTE_ARRAY, such as string or binary types. - void set_read_dictionary(int column_index, bool read_dict) { - if (read_dict) { - read_dict_indices_.insert(column_index); + void setReadDictionary(int columnIndex, bool readDict) { + if (readDict) { + readDictIndices_.insert(columnIndex); } else { - read_dict_indices_.erase(column_index); + readDictIndices_.erase(columnIndex); } } /// Return whether the column at the index will be read as dictionary. - bool read_dictionary(int column_index) const { - if (read_dict_indices_.find(column_index) != read_dict_indices_.end()) { + bool readDictionary(int columnIndex) const { + if (readDictIndices_.find(columnIndex) != readDictIndices_.end()) { return true; } else { return false; @@ -1048,71 +1048,71 @@ class PARQUET_EXPORT ArrowReaderProperties { /// /// Will only be fewer rows when there are no more rows in the file. /// Note that some APIs such as ReadTable may ignore this setting. - void set_batch_size(int64_t batch_size) { - batch_size_ = batch_size; + void setBatchSize(int64_t batchSize) { + batchSize_ = batchSize; } /// Return the batch size in rows. /// /// Note that some APIs such as ReadTable may ignore this setting. - int64_t batch_size() const { - return batch_size_; + int64_t batchSize() const { + return batchSize_; } /// Enable read coalescing (default false). /// - /// When enabled, the Arrow reader will pre-buffer necessary regions - /// of the file in-memory. This is intended to improve performance on - /// high-latency filesystems (e.g. Amazon S3). - void set_pre_buffer(bool pre_buffer) { - pre_buffer_ = pre_buffer; + /// When enabled, the Arrow reader will pre-buffer necessary regions. + /// Of the file in-memory. This is intended to improve performance on. + /// High-latency filesystems (e.g. Amazon S3). + void setPreBuffer(bool preBuffer) { + preBuffer_ = preBuffer; } /// Return whether read coalescing is enabled. - bool pre_buffer() const { - return pre_buffer_; + bool preBuffer() const { + return preBuffer_; } - /// Set options for read coalescing. This can be used to tune the - /// implementation for characteristics of different filesystems. - void set_cache_options(::arrow::io::CacheOptions options) { - cache_options_ = options; + /// Set options for read coalescing. This can be used to tune the. + /// Implementation for characteristics of different filesystems. + void setCacheOptions(::arrow::io::CacheOptions options) { + cacheOptions_ = options; } /// Return the options for read coalescing. - const ::arrow::io::CacheOptions& cache_options() const { - return cache_options_; + const ::arrow::io::CacheOptions& cacheOptions() const { + return cacheOptions_; } /// Set execution context for read coalescing. - void set_io_context(const ::arrow::io::IOContext& ctx) { - io_context_ = ctx; + void setIoContext(const ::arrow::io::IOContext& ctx) { + ioContext_ = ctx; } /// Return the execution context used for read coalescing. - const ::arrow::io::IOContext& io_context() const { - return io_context_; + const ::arrow::io::IOContext& ioContext() const { + return ioContext_; } - /// Set timestamp unit to use for deprecated INT96-encoded timestamps - /// (default is NANO). - void set_coerce_int96_timestamp_unit(::arrow::TimeUnit::type unit) { - coerce_int96_timestamp_unit_ = unit; + /// Set timestamp unit to use for deprecated INT96-encoded timestamps. + /// (Default is NANO). + void setCoerceInt96TimestampUnit(::arrow::TimeUnit::type unit) { + coerceInt96TimestampUnit_ = unit; } - ::arrow::TimeUnit::type coerce_int96_timestamp_unit() const { - return coerce_int96_timestamp_unit_; + ::arrow::TimeUnit::type coerceInt96TimestampUnit() const { + return coerceInt96TimestampUnit_; } private: - bool use_threads_; - std::unordered_set read_dict_indices_; - int64_t batch_size_; - bool pre_buffer_; - ::arrow::io::IOContext io_context_; - ::arrow::io::CacheOptions cache_options_; - ::arrow::TimeUnit::type coerce_int96_timestamp_unit_; + bool useThreads_; + std::unordered_set readDictIndices_; + int64_t batchSize_; + bool preBuffer_; + ::arrow::io::IOContext ioContext_; + ::arrow::io::CacheOptions cacheOptions_; + ::arrow::TimeUnit::type coerceInt96TimestampUnit_; }; -/// EXPERIMENTAL: Constructs the default ArrowReaderProperties +/// EXPERIMENTAL: Constructs the default ArrowReaderProperties. PARQUET_EXPORT -ArrowReaderProperties default_arrow_reader_properties(); +ArrowReaderProperties defaultArrowReaderProperties(); class PARQUET_EXPORT ArrowWriterProperties { public: @@ -1123,106 +1123,110 @@ class PARQUET_EXPORT ArrowWriterProperties { class Builder { public: Builder() - : write_timestamps_as_int96_(false), - coerce_timestamps_enabled_(false), - coerce_timestamps_unit_(::arrow::TimeUnit::SECOND), - truncated_timestamps_allowed_(false), - store_schema_(false), - compliant_nested_types_(true), - engine_version_(V2), - use_threads_(kArrowDefaultUseThreads), + : writeTimestampsAsInt96_(false), + coerceTimestampsEnabled_(false), + coerceTimestampsUnit_(::arrow::TimeUnit::SECOND), + truncatedTimestampsAllowed_(false), + storeSchema_(false), + compliantNestedTypes_(true), + engineVersion_(V2), + useThreads_(kArrowDefaultUseThreads), executor_(NULLPTR) {} virtual ~Builder() = default; - /// \brief Disable writing legacy int96 timestamps (default disabled). - Builder* disable_deprecated_int96_timestamps() { - write_timestamps_as_int96_ = false; + /// \brief Disable writing legacy int96 timestamps (default + /// disabled). + Builder* disableDeprecatedInt96Timestamps() { + writeTimestampsAsInt96_ = false; return this; } - /// \brief Enable writing legacy int96 timestamps (default disabled). + /// \brief Enable writing legacy int96 timestamps (default + /// disabled). /// - /// May be turned on to write timestamps compatible with older Parquet - /// writers. This takes precedent over coerce_timestamps. - Builder* enable_deprecated_int96_timestamps() { - write_timestamps_as_int96_ = true; + /// May be turned on to write timestamps compatible with older + /// Parquet writers. This takes precedent over coerceTimestamps. + Builder* enableDeprecatedInt96Timestamps() { + writeTimestampsAsInt96_ = true; return this; } /// \brief Coerce all timestamps to the specified time unit. /// \param unit time unit to truncate to. - /// For Parquet versions 1.0 and 2.4, nanoseconds are casted to - /// microseconds. - Builder* coerce_timestamps(::arrow::TimeUnit::type unit) { - coerce_timestamps_enabled_ = true; - coerce_timestamps_unit_ = unit; + /// For Parquet versions 1.0 and 2.4, nanoseconds are casted to. + /// Microseconds. + Builder* coerceTimestamps(::arrow::TimeUnit::type unit) { + coerceTimestampsEnabled_ = true; + coerceTimestampsUnit_ = unit; return this; } /// \brief Allow loss of data when truncating timestamps. /// /// This is disallowed by default and an error will be returned. - Builder* allow_truncated_timestamps() { - truncated_timestamps_allowed_ = true; + Builder* allowTruncatedTimestamps() { + truncatedTimestampsAllowed_ = true; return this; } - /// \brief Disallow loss of data when truncating timestamps (default). - Builder* disallow_truncated_timestamps() { - truncated_timestamps_allowed_ = false; + /// \brief Disallow loss of data when truncating timestamps + /// (default). + Builder* disallowTruncatedTimestamps() { + truncatedTimestampsAllowed_ = false; return this; } - /// \brief EXPERIMENTAL: Write binary serialized Arrow schema to the file, - /// to enable certain read options (like "read_dictionary") to be set - /// automatically - Builder* store_schema() { - store_schema_ = true; + /// \brief EXPERIMENTAL: Write binary serialized Arrow schema to the + /// file, to enable certain read options (like "read_dictionary") to + /// be set automatically. + Builder* storeSchema() { + storeSchema_ = true; return this; } - /// \brief When enabled, will not preserve Arrow field names for list types. + /// \brief When enabled, will not preserve Arrow field names for + /// list types. /// - /// Instead of using the field names Arrow uses for the values array of - /// list types (default "item"), will use "element", as is specified in - /// the Parquet spec. + /// Instead of using the field names Arrow uses for the values array + /// of. List types (default "item"), will use "element", as is + /// specified in. The Parquet spec. /// /// This is enabled by default. - Builder* enable_compliant_nested_types() { - compliant_nested_types_ = true; + Builder* enableCompliantNestedTypes() { + compliantNestedTypes_ = true; return this; } /// Preserve Arrow list field name. - Builder* disable_compliant_nested_types() { - compliant_nested_types_ = false; + Builder* disableCompliantNestedTypes() { + compliantNestedTypes_ = false; return this; } /// Set the version of the Parquet writer engine. - Builder* set_engine_version(EngineVersion version) { - engine_version_ = version; + Builder* setEngineVersion(EngineVersion version) { + engineVersion_ = version; return this; } - /// \brief Set whether to use multiple threads to write columns - /// in parallel in the buffered row group mode. + /// \brief Set whether to use multiple threads to write columns. + /// In parallel in the buffered row group mode. /// - /// WARNING: If writing multiple files in parallel in the same - /// executor, deadlock may occur if use_threads is true. Please - /// disable it in this case. + /// WARNING: If writing multiple files in parallel in the same. + /// Executor, deadlock may occur if use_threads is true. Please. + /// Disable it in this case. /// /// Default is false. - Builder* set_use_threads(bool use_threads) { - use_threads_ = use_threads; + Builder* setUseThreads(bool useThreads) { + useThreads_ = useThreads; return this; } - /// \brief Set the executor to write columns in parallel in the - /// buffered row group mode. + /// \brief Set the executor to write columns in parallel in the. + /// Buffered row group mode. /// /// Default is nullptr and the default cpu executor will be used. - Builder* set_executor(::arrow::internal::Executor* executor) { + Builder* setExecutor(::arrow::internal::Executor* executor) { executor_ = executor; return this; } @@ -1230,72 +1234,74 @@ class PARQUET_EXPORT ArrowWriterProperties { /// Create the final properties. std::shared_ptr build() { return std::shared_ptr(new ArrowWriterProperties( - write_timestamps_as_int96_, - coerce_timestamps_enabled_, - coerce_timestamps_unit_, - truncated_timestamps_allowed_, - store_schema_, - compliant_nested_types_, - engine_version_, - use_threads_, + writeTimestampsAsInt96_, + coerceTimestampsEnabled_, + coerceTimestampsUnit_, + truncatedTimestampsAllowed_, + storeSchema_, + compliantNestedTypes_, + engineVersion_, + useThreads_, executor_)); } private: - bool write_timestamps_as_int96_; + bool writeTimestampsAsInt96_; - bool coerce_timestamps_enabled_; - ::arrow::TimeUnit::type coerce_timestamps_unit_; - bool truncated_timestamps_allowed_; + bool coerceTimestampsEnabled_; + ::arrow::TimeUnit::type coerceTimestampsUnit_; + bool truncatedTimestampsAllowed_; - bool store_schema_; - bool compliant_nested_types_; - EngineVersion engine_version_; + bool storeSchema_; + bool compliantNestedTypes_; + EngineVersion engineVersion_; - bool use_threads_; + bool useThreads_; ::arrow::internal::Executor* executor_; }; - bool support_deprecated_int96_timestamps() const { - return write_timestamps_as_int96_; + bool supportDeprecatedInt96Timestamps() const { + return writeTimestampsAsInt96_; } - bool coerce_timestamps_enabled() const { - return coerce_timestamps_enabled_; + bool coerceTimestampsEnabled() const { + return coerceTimestampsEnabled_; } - ::arrow::TimeUnit::type coerce_timestamps_unit() const { - return coerce_timestamps_unit_; + ::arrow::TimeUnit::type coerceTimestampsUnit() const { + return coerceTimestampsUnit_; } - bool truncated_timestamps_allowed() const { - return truncated_timestamps_allowed_; + bool truncatedTimestampsAllowed() const { + return truncatedTimestampsAllowed_; } - bool store_schema() const { - return store_schema_; + bool storeSchema() const { + return storeSchema_; } - /// \brief Enable nested type naming according to the parquet specification. + /// \brief Enable nested type naming according to the parquet + /// specification. /// - /// Older versions of arrow wrote out field names for nested lists based on - /// the name of the field. According to the parquet specification they should - /// always be "element". - bool compliant_nested_types() const { - return compliant_nested_types_; + /// Older versions of arrow wrote out field names for nested lists + /// based on the name of the field. According to the Parquet + /// specification they should always be "element". + bool compliantNestedTypes() const { + return compliantNestedTypes_; } - /// \brief The underlying engine version to use when writing Arrow data. + /// \brief The underlying engine version to use when writing Arrow + /// data. /// - /// V2 is currently the latest V1 is considered deprecated but left in - /// place in case there are bugs detected in V2. - EngineVersion engine_version() const { - return engine_version_; + /// V2 is currently the latest V1 is considered deprecated but left + /// in. Place in case there are bugs detected in V2. + EngineVersion engineVersion() const { + return engineVersion_; } - /// \brief Returns whether the writer will use multiple threads - /// to write columns in parallel in the buffered row group mode. - bool use_threads() const { - return use_threads_; + /// \brief Returns whether the writer will use multiple threads. + /// To write columns in parallel in the buffered row group mode. + bool useThreads() const { + return useThreads_; } /// \brief Returns the executor used to write columns in parallel. @@ -1303,65 +1309,64 @@ class PARQUET_EXPORT ArrowWriterProperties { private: explicit ArrowWriterProperties( - bool write_nanos_as_int96, - bool coerce_timestamps_enabled, - ::arrow::TimeUnit::type coerce_timestamps_unit, - bool truncated_timestamps_allowed, - bool store_schema, - bool compliant_nested_types, - EngineVersion engine_version, - bool use_threads, + bool writeNanosAsInt96, + bool coerceTimestampsEnabled, + ::arrow::TimeUnit::type coerceTimestampsUnit, + bool truncatedTimestampsAllowed, + bool storeSchema, + bool compliantNestedTypes, + EngineVersion engineVersion, + bool useThreads, ::arrow::internal::Executor* executor) - : write_timestamps_as_int96_(write_nanos_as_int96), - coerce_timestamps_enabled_(coerce_timestamps_enabled), - coerce_timestamps_unit_(coerce_timestamps_unit), - truncated_timestamps_allowed_(truncated_timestamps_allowed), - store_schema_(store_schema), - compliant_nested_types_(compliant_nested_types), - engine_version_(engine_version), - use_threads_(use_threads), + : writeTimestampsAsInt96_(writeNanosAsInt96), + coerceTimestampsEnabled_(coerceTimestampsEnabled), + coerceTimestampsUnit_(coerceTimestampsUnit), + truncatedTimestampsAllowed_(truncatedTimestampsAllowed), + storeSchema_(storeSchema), + compliantNestedTypes_(compliantNestedTypes), + engineVersion_(engineVersion), + useThreads_(useThreads), executor_(executor) {} - const bool write_timestamps_as_int96_; - const bool coerce_timestamps_enabled_; - const ::arrow::TimeUnit::type coerce_timestamps_unit_; - const bool truncated_timestamps_allowed_; - const bool store_schema_; - const bool compliant_nested_types_; - const EngineVersion engine_version_; - const bool use_threads_; + const bool writeTimestampsAsInt96_; + const bool coerceTimestampsEnabled_; + const ::arrow::TimeUnit::type coerceTimestampsUnit_; + const bool truncatedTimestampsAllowed_; + const bool storeSchema_; + const bool compliantNestedTypes_; + const EngineVersion engineVersion_; + const bool useThreads_; ::arrow::internal::Executor* executor_; }; -/// \brief State object used for writing Arrow data directly to a Parquet -/// column chunk. API possibly not stable +/// \brief State object used for writing Arrow data directly to a +/// Parquet. Column chunk. API possibly not stable. struct ArrowWriteContext { - ArrowWriteContext(MemoryPool* memory_pool, ArrowWriterProperties* properties) - : memory_pool(memory_pool), + ArrowWriteContext(MemoryPool* memoryPool, ArrowWriterProperties* properties) + : memoryPool(memoryPool), properties(properties), - data_buffer(AllocateBuffer(memory_pool)), - def_levels_buffer(AllocateBuffer(memory_pool)) {} + dataBuffer(allocateBuffer(memoryPool)), + defLevelsBuffer(allocateBuffer(memoryPool)) {} template - ::arrow::Status GetScratchData(const int64_t num_values, T** out) { - ARROW_RETURN_NOT_OK( - this->data_buffer->Resize(num_values * sizeof(T), false)); - *out = reinterpret_cast(this->data_buffer->mutable_data()); + ::arrow::Status getScratchData(const int64_t numValues, T** out) { + ARROW_RETURN_NOT_OK(this->dataBuffer->Resize(numValues * sizeof(T), false)); + *out = reinterpret_cast(this->dataBuffer->mutable_data()); return ::arrow::Status::OK(); } - MemoryPool* memory_pool; + MemoryPool* memoryPool; const ArrowWriterProperties* properties; - // Buffer used for storing the data of an array converted to the physical type - // as expected by parquet-cpp. - std::shared_ptr data_buffer; + // Buffer used for storing the data of an array converted to the + // physical type. As expected by parquet-cpp. + std::shared_ptr dataBuffer; - // We use the shared ownership of this buffer - std::shared_ptr def_levels_buffer; + // We use the shared ownership of this buffer. + std::shared_ptr defLevelsBuffer; }; PARQUET_EXPORT -std::shared_ptr default_arrow_writer_properties(); +std::shared_ptr defaultArrowWriterProperties(); } // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/Schema.cpp b/velox/dwio/parquet/writer/arrow/Schema.cpp index 1280f8a05c6..3a431bb5c70 100644 --- a/velox/dwio/parquet/writer/arrow/Schema.cpp +++ b/velox/dwio/parquet/writer/arrow/Schema.cpp @@ -34,19 +34,18 @@ using facebook::velox::parquet::thrift::SchemaElement; namespace facebook::velox::parquet::arrow { namespace { -void ThrowInvalidLogicalType(const LogicalType& logical_type) { +void throwInvalidLogicalType(const LogicalType& logicalType) { std::stringstream ss; - ss << "Invalid logical type: " << logical_type.ToString(); + ss << "Invalid logical type: " << logicalType.toString(); throw ParquetException(ss.str()); } -void CheckColumnBounds(int column_index, size_t max_columns) { +void checkColumnBounds(int columnIndex, size_t maxColumns) { if (ARROW_PREDICT_FALSE( - column_index < 0 || - static_cast(column_index) >= max_columns)) { + columnIndex < 0 || static_cast(columnIndex) >= maxColumns)) { std::stringstream ss; - ss << "Invalid Column Index: " << column_index - << " Num columns: " << max_columns; + ss << "Invalid Column Index: " << columnIndex + << " Num columns: " << maxColumns; throw ParquetException(ss.str()); } } @@ -55,10 +54,10 @@ void CheckColumnBounds(int column_index, size_t max_columns) { namespace schema { -// ---------------------------------------------------------------------- -// ColumnPath +// ----------------------------------------------------------------------. +// ColumnPath. -std::shared_ptr ColumnPath::FromDotString( +std::shared_ptr ColumnPath::fromDotString( const std::string& dotstring) { std::stringstream ss(dotstring); std::string item; @@ -69,33 +68,33 @@ std::shared_ptr ColumnPath::FromDotString( return std::make_shared(std::move(path)); } -std::shared_ptr ColumnPath::FromNode(const Node& node) { - // Build the path in reverse order as we traverse the nodes to the top +std::shared_ptr ColumnPath::fromNode(const Node& node) { + // Build the path in reverse order as we traverse the nodes to the top. std::vector rpath_; const Node* cursor = &node; - // The schema node is not part of the ColumnPath + // The schema node is not part of the ColumnPath. while (cursor->parent()) { rpath_.push_back(cursor->name()); cursor = cursor->parent(); } - // Build ColumnPath in correct order + // Build ColumnPath in correct order. std::vector path(rpath_.crbegin(), rpath_.crend()); return std::make_shared(std::move(path)); } std::shared_ptr ColumnPath::extend( - const std::string& node_name) const { + const std::string& nodeName) const { std::vector path; path.reserve(path_.size() + 1); path.resize(path_.size() + 1); std::copy(path_.cbegin(), path_.cend(), path.begin()); - path[path_.size()] = node_name; + path[path_.size()] = nodeName; return std::make_shared(std::move(path)); } -std::string ColumnPath::ToDotString() const { +std::string ColumnPath::toDotString() const { std::stringstream ss; for (auto it = path_.cbegin(); it != path_.cend(); ++it) { if (it != path_.cbegin()) { @@ -106,70 +105,69 @@ std::string ColumnPath::ToDotString() const { return ss.str(); } -const std::vector& ColumnPath::ToDotVector() const { +const std::vector& ColumnPath::toDotVector() const { return path_; } -// ---------------------------------------------------------------------- -// Base node +// ----------------------------------------------------------------------. +// Base node. const std::shared_ptr Node::path() const { // TODO(itaiin): Cache the result, or more precisely, cache ->ToDotString() - // since it is being used to access the leaf nodes - return ColumnPath::FromNode(*this); + // Since it is being used to access the leaf nodes. + return ColumnPath::fromNode(*this); } -bool Node::EqualsInternal(const Node* other) const { +bool Node::equalsInternal(const Node* other) const { return type_ == other->type_ && name_ == other->name_ && repetition_ == other->repetition_ && - converted_type_ == other->converted_type_ && - field_id_ == other->field_id() && - logical_type_->Equals(*(other->logical_type())); + convertedType_ == other->convertedType_ && fieldId_ == other->fieldId() && + logicalType_->equals(*(other->logicalType())); } -void Node::SetParent(const Node* parent) { +void Node::setParent(const Node* parent) { parent_ = parent; } -// ---------------------------------------------------------------------- -// Primitive node +// ----------------------------------------------------------------------. +// Primitive node. PrimitiveNode::PrimitiveNode( const std::string& name, Repetition::type repetition, Type::type type, - ConvertedType::type converted_type, + ConvertedType::type convertedType, int length, int precision, int scale, int id) - : Node(Node::PRIMITIVE, name, repetition, converted_type, id), - physical_type_(type), - type_length_(length) { + : Node(Node::kPrimitive, name, repetition, convertedType, id), + physicalType_(type), + typeLength_(length) { std::stringstream ss; - // PARQUET-842: In an earlier revision, decimal_metadata_.isset was being - // set to true, but Impala will raise an incompatible metadata in such cases - memset(&decimal_metadata_, 0, sizeof(decimal_metadata_)); + // PARQUET-842: In an earlier revision, decimalMetadata_.isset was being + // set to true, but Impala will raise an incompatible metadata in such cases. + memset(&decimalMetadata_, 0, sizeof(decimalMetadata_)); - // Check if the physical and logical types match - // Mapping referred from Apache parquet-mr as on 2016-02-22 - switch (converted_type) { - case ConvertedType::NONE: - // Logical type not set + // Check if the physical and logical types match. + // Mapping referred from Apache parquet-mr as on 2016-02-22. + switch (convertedType) { + case ConvertedType::kNone: + // Logical type not set. break; - case ConvertedType::UTF8: - case ConvertedType::JSON: - case ConvertedType::BSON: - if (type != Type::BYTE_ARRAY) { - ss << ConvertedTypeToString(converted_type); + case ConvertedType::kUtf8: + case ConvertedType::kJson: + case ConvertedType::kBson: + if (type != Type::kByteArray) { + ss << convertedTypeToString(convertedType); ss << " can only annotate BYTE_ARRAY fields"; throw ParquetException(ss.str()); } break; - case ConvertedType::DECIMAL: - if ((type != Type::INT32) && (type != Type::INT64) && - (type != Type::BYTE_ARRAY) && (type != Type::FIXED_LEN_BYTE_ARRAY)) { + case ConvertedType::kDecimal: + if ((type != Type::kInt32) && (type != Type::kInt64) && + (type != Type::kByteArray) && (type != Type::kFixedLenByteArray)) { ss << "DECIMAL can only annotate INT32, INT64, BYTE_ARRAY, and FIXED"; throw ParquetException(ss.str()); } @@ -188,174 +186,173 @@ PrimitiveNode::PrimitiveNode( ss << " cannot be greater than precision " << precision; throw ParquetException(ss.str()); } - decimal_metadata_.isset = true; - decimal_metadata_.precision = precision; - decimal_metadata_.scale = scale; + decimalMetadata_.isset = true; + decimalMetadata_.precision = precision; + decimalMetadata_.scale = scale; break; - case ConvertedType::DATE: - case ConvertedType::TIME_MILLIS: - case ConvertedType::UINT_8: - case ConvertedType::UINT_16: - case ConvertedType::UINT_32: - case ConvertedType::INT_8: - case ConvertedType::INT_16: - case ConvertedType::INT_32: - if (type != Type::INT32) { - ss << ConvertedTypeToString(converted_type); + case ConvertedType::kDate: + case ConvertedType::kTimeMillis: + case ConvertedType::kUint8: + case ConvertedType::kUint16: + case ConvertedType::kUint32: + case ConvertedType::kInt8: + case ConvertedType::kInt16: + case ConvertedType::kInt32: + if (type != Type::kInt32) { + ss << convertedTypeToString(convertedType); ss << " can only annotate INT32"; throw ParquetException(ss.str()); } break; - case ConvertedType::TIME_MICROS: - case ConvertedType::TIMESTAMP_MILLIS: - case ConvertedType::TIMESTAMP_MICROS: - case ConvertedType::UINT_64: - case ConvertedType::INT_64: - if (type != Type::INT64) { - ss << ConvertedTypeToString(converted_type); + case ConvertedType::kTimeMicros: + case ConvertedType::kTimestampMillis: + case ConvertedType::kTimestampMicros: + case ConvertedType::kUint64: + case ConvertedType::kInt64: + if (type != Type::kInt64) { + ss << convertedTypeToString(convertedType); ss << " can only annotate INT64"; throw ParquetException(ss.str()); } break; - case ConvertedType::INTERVAL: - if ((type != Type::FIXED_LEN_BYTE_ARRAY) || (length != 12)) { + case ConvertedType::kInterval: + if ((type != Type::kFixedLenByteArray) || (length != 12)) { ss << "INTERVAL can only annotate FIXED_LEN_BYTE_ARRAY(12)"; throw ParquetException(ss.str()); } break; - case ConvertedType::ENUM: - if (type != Type::BYTE_ARRAY) { + case ConvertedType::kEnum: + if (type != Type::kByteArray) { ss << "ENUM can only annotate BYTE_ARRAY fields"; throw ParquetException(ss.str()); } break; - case ConvertedType::NA: - // NA can annotate any type + case ConvertedType::kNa: + // NA can annotate any type. break; default: - ss << ConvertedTypeToString(converted_type); + ss << convertedTypeToString(convertedType); ss << " cannot be applied to a primitive type"; throw ParquetException(ss.str()); } - // For forward compatibility, create an equivalent logical type - logical_type_ = - LogicalType::FromConvertedType(converted_type_, decimal_metadata_); - if (!(logical_type_ && !logical_type_->is_nested() && - logical_type_->is_compatible(converted_type_, decimal_metadata_))) { - ThrowInvalidLogicalType(*logical_type_); + // For forward compatibility, create an equivalent logical type. + logicalType_ = + LogicalType::fromConvertedType(convertedType_, decimalMetadata_); + if (!(logicalType_ && !logicalType_->isNested() && + logicalType_->isCompatible(convertedType_, decimalMetadata_))) { + throwInvalidLogicalType(*logicalType_); } - if (type == Type::FIXED_LEN_BYTE_ARRAY) { + if (type == Type::kFixedLenByteArray) { if (length <= 0) { ss << "Invalid FIXED_LEN_BYTE_ARRAY length: " << length; throw ParquetException(ss.str()); } - type_length_ = length; + typeLength_ = length; } } PrimitiveNode::PrimitiveNode( const std::string& name, Repetition::type repetition, - std::shared_ptr logical_type, - Type::type physical_type, - int physical_length, + std::shared_ptr logicalType, + Type::type physicalType, + int physicalLength, int id) - : Node(Node::PRIMITIVE, name, repetition, std::move(logical_type), id), - physical_type_(physical_type), - type_length_(physical_length) { + : Node(Node::kPrimitive, name, repetition, std::move(logicalType), id), + physicalType_(physicalType), + typeLength_(physicalLength) { std::stringstream error; - if (logical_type_) { - // Check for logical type <=> node type consistency - if (!logical_type_->is_nested()) { - // Check for logical type <=> physical type consistency - if (logical_type_->is_applicable(physical_type, physical_length)) { + if (logicalType_) { + // Check for logical type <=> node type consistency. + if (!logicalType_->isNested()) { + // Check for logical type <=> physical type consistency. + if (logicalType_->isApplicable(physicalType, physicalLength)) { // For backward compatibility, assign equivalent legacy - // converted type (if possible) - converted_type_ = logical_type_->ToConvertedType(&decimal_metadata_); + // converted type (if possible). + convertedType_ = logicalType_->toConvertedType(&decimalMetadata_); } else { - error << logical_type_->ToString(); + error << logicalType_->toString(); error << " can not be applied to primitive type "; - error << TypeToString(physical_type); + error << typeToString(physicalType); throw ParquetException(error.str()); } } else { error << "Nested logical type "; - error << logical_type_->ToString(); + error << logicalType_->toString(); error << " can not be applied to non-group node"; throw ParquetException(error.str()); } } else { - logical_type_ = NoLogicalType::Make(); - converted_type_ = logical_type_->ToConvertedType(&decimal_metadata_); + logicalType_ = NoLogicalType::make(); + convertedType_ = logicalType_->toConvertedType(&decimalMetadata_); } - if (!(logical_type_ && !logical_type_->is_nested() && - logical_type_->is_compatible(converted_type_, decimal_metadata_))) { - ThrowInvalidLogicalType(*logical_type_); + if (!(logicalType_ && !logicalType_->isNested() && + logicalType_->isCompatible(convertedType_, decimalMetadata_))) { + throwInvalidLogicalType(*logicalType_); } - if (physical_type == Type::FIXED_LEN_BYTE_ARRAY) { - if (physical_length <= 0) { - error << "Invalid FIXED_LEN_BYTE_ARRAY length: " << physical_length; + if (physicalType == Type::kFixedLenByteArray) { + if (physicalLength <= 0) { + error << "Invalid FIXED_LEN_BYTE_ARRAY length: " << physicalLength; throw ParquetException(error.str()); } } } -bool PrimitiveNode::EqualsInternal(const PrimitiveNode* other) const { - bool is_equal = true; - if (physical_type_ != other->physical_type_) { +bool PrimitiveNode::equalsInternal(const PrimitiveNode* other) const { + bool isEqual = true; + if (physicalType_ != other->physicalType_) { return false; } - if (converted_type_ == ConvertedType::DECIMAL) { - is_equal &= - (decimal_metadata_.precision == other->decimal_metadata_.precision) && - (decimal_metadata_.scale == other->decimal_metadata_.scale); + if (convertedType_ == ConvertedType::kDecimal) { + isEqual &= + (decimalMetadata_.precision == other->decimalMetadata_.precision) && + (decimalMetadata_.scale == other->decimalMetadata_.scale); } - if (physical_type_ == Type::FIXED_LEN_BYTE_ARRAY) { - is_equal &= (type_length_ == other->type_length_); + if (physicalType_ == Type::kFixedLenByteArray) { + isEqual &= (typeLength_ == other->typeLength_); } - return is_equal; + return isEqual; } -bool PrimitiveNode::Equals(const Node* other) const { - if (!Node::EqualsInternal(other)) { +bool PrimitiveNode::equals(const Node* other) const { + if (!Node::equalsInternal(other)) { return false; } - return EqualsInternal(static_cast(other)); + return equalsInternal(static_cast(other)); } -void PrimitiveNode::Visit(Node::Visitor* visitor) { - visitor->Visit(this); +void PrimitiveNode::visit(Node::Visitor* visitor) { + visitor->visit(this); } -void PrimitiveNode::VisitConst(Node::ConstVisitor* visitor) const { - visitor->Visit(this); +void PrimitiveNode::visitConst(Node::ConstVisitor* visitor) const { + visitor->visit(this); } -// ---------------------------------------------------------------------- -// Group node +// ----------------------------------------------------------------------. +// Group node. GroupNode::GroupNode( const std::string& name, Repetition::type repetition, const NodeVector& fields, - ConvertedType::type converted_type, + ConvertedType::type convertedType, int id) - : Node(Node::GROUP, name, repetition, converted_type, id), fields_(fields) { - // For forward compatibility, create an equivalent logical type - logical_type_ = LogicalType::FromConvertedType(converted_type_); - if (!(logical_type_ && - (logical_type_->is_nested() || logical_type_->is_none()) && - logical_type_->is_compatible(converted_type_))) { - ThrowInvalidLogicalType(*logical_type_); + : Node(Node::kGroup, name, repetition, convertedType, id), fields_(fields) { + // For forward compatibility, create an equivalent logical type. + logicalType_ = LogicalType::fromConvertedType(convertedType_); + if (!(logicalType_ && (logicalType_->isNested() || logicalType_->isNone()) && + logicalType_->isCompatible(convertedType_))) { + throwInvalidLogicalType(*logicalType_); } - field_name_to_idx_.clear(); - auto field_idx = 0; + fieldNameToIdx_.clear(); + auto fieldIdx = 0; for (NodePtr& field : fields_) { - field->SetParent(this); - field_name_to_idx_.emplace(field->name(), field_idx++); + field->setParent(this); + fieldNameToIdx_.emplace(field->name(), fieldIdx++); } } @@ -363,74 +360,73 @@ GroupNode::GroupNode( const std::string& name, Repetition::type repetition, const NodeVector& fields, - std::shared_ptr logical_type, + std::shared_ptr logicalType, int id) - : Node(Node::GROUP, name, repetition, std::move(logical_type), id), + : Node(Node::kGroup, name, repetition, std::move(logicalType), id), fields_(fields) { - if (logical_type_) { - // Check for logical type <=> node type consistency - if (logical_type_->is_nested()) { - // For backward compatibility, assign equivalent legacy converted type (if - // possible) - converted_type_ = logical_type_->ToConvertedType(nullptr); + if (logicalType_) { + // Check for logical type <=> node type consistency. + if (logicalType_->isNested()) { + // For backward compatibility, assign equivalent legacy converted type + // (if possible). + convertedType_ = logicalType_->toConvertedType(nullptr); } else { std::stringstream error; error << "Logical type "; - error << logical_type_->ToString(); + error << logicalType_->toString(); error << " can not be applied to group node"; throw ParquetException(error.str()); } } else { - logical_type_ = NoLogicalType::Make(); - converted_type_ = logical_type_->ToConvertedType(nullptr); + logicalType_ = NoLogicalType::make(); + convertedType_ = logicalType_->toConvertedType(nullptr); } - if (!(logical_type_ && - (logical_type_->is_nested() || logical_type_->is_none()) && - logical_type_->is_compatible(converted_type_))) { - ThrowInvalidLogicalType(*logical_type_); + if (!(logicalType_ && (logicalType_->isNested() || logicalType_->isNone()) && + logicalType_->isCompatible(convertedType_))) { + throwInvalidLogicalType(*logicalType_); } - field_name_to_idx_.clear(); - auto field_idx = 0; + fieldNameToIdx_.clear(); + auto fieldIdx = 0; for (NodePtr& field : fields_) { - field->SetParent(this); - field_name_to_idx_.emplace(field->name(), field_idx++); + field->setParent(this); + fieldNameToIdx_.emplace(field->name(), fieldIdx++); } } -bool GroupNode::EqualsInternal(const GroupNode* other) const { +bool GroupNode::equalsInternal(const GroupNode* other) const { if (this == other) { return true; } - if (this->field_count() != other->field_count()) { + if (this->fieldCount() != other->fieldCount()) { return false; } - for (int i = 0; i < this->field_count(); ++i) { - if (!this->field(i)->Equals(other->field(i).get())) { + for (int i = 0; i < this->fieldCount(); ++i) { + if (!this->field(i)->equals(other->field(i).get())) { return false; } } return true; } -bool GroupNode::Equals(const Node* other) const { - if (!Node::EqualsInternal(other)) { +bool GroupNode::equals(const Node* other) const { + if (!Node::equalsInternal(other)) { return false; } - return EqualsInternal(static_cast(other)); + return equalsInternal(static_cast(other)); } -int GroupNode::FieldIndex(const std::string& name) const { - auto search = field_name_to_idx_.find(name); - if (search == field_name_to_idx_.end()) { - // Not found +int GroupNode::fieldIndex(const std::string& name) const { + auto search = fieldNameToIdx_.find(name); + if (search == fieldNameToIdx_.end()) { + // Not found. return -1; } return search->second; } -int GroupNode::FieldIndex(const Node& node) const { - auto search = field_name_to_idx_.equal_range(node.name()); +int GroupNode::fieldIndex(const Node& node) const { + auto search = fieldNameToIdx_.equal_range(node.name()); for (auto it = search.first; it != search.second; ++it) { const int idx = it->second; if (&node == field(idx).get()) { @@ -440,223 +436,222 @@ int GroupNode::FieldIndex(const Node& node) const { return -1; } -void GroupNode::Visit(Node::Visitor* visitor) { - visitor->Visit(this); +void GroupNode::visit(Node::Visitor* visitor) { + visitor->visit(this); } -void GroupNode::VisitConst(Node::ConstVisitor* visitor) const { - visitor->Visit(this); +void GroupNode::visitConst(Node::ConstVisitor* visitor) const { + visitor->visit(this); } -// ---------------------------------------------------------------------- -// Node construction from Parquet metadata +// ----------------------------------------------------------------------. +// Node construction from Parquet metadata. -std::unique_ptr GroupNode::FromParquet( - const void* opaque_element, +std::unique_ptr GroupNode::fromParquet( + const void* opaqueElement, NodeVector fields) { const facebook::velox::parquet::thrift::SchemaElement* element = static_cast( - opaque_element); + opaqueElement); - int field_id = -1; + int fieldId = -1; if (element->__isset.field_id) { - field_id = element->field_id; + fieldId = element->field_id; } - std::unique_ptr group_node; + std::unique_ptr groupNode; if (element->__isset.logicalType) { - // updated writer with logical type present - group_node = std::unique_ptr(new GroupNode( + // Updated writer with logical type present. + groupNode = std::unique_ptr(new GroupNode( element->name, - LoadEnumSafe(&element->repetition_type), + loadenumSafe(&element->repetition_type), fields, - LogicalType::FromThrift(element->logicalType), - field_id)); + LogicalType::fromThrift(element->logicalType), + fieldId)); } else { - group_node = std::unique_ptr(new GroupNode( + groupNode = std::unique_ptr(new GroupNode( element->name, - LoadEnumSafe(&element->repetition_type), + loadenumSafe(&element->repetition_type), fields, (element->__isset.converted_type - ? LoadEnumSafe(&element->converted_type) - : ConvertedType::NONE), - field_id)); + ? loadenumSafe(&element->converted_type) + : ConvertedType::kNone), + fieldId)); } - return std::unique_ptr(group_node.release()); + return std::unique_ptr(groupNode.release()); } -std::unique_ptr PrimitiveNode::FromParquet(const void* opaque_element) { +std::unique_ptr PrimitiveNode::fromParquet(const void* opaqueElement) { const facebook::velox::parquet::thrift::SchemaElement* element = static_cast( - opaque_element); + opaqueElement); - int field_id = -1; + int fieldId = -1; if (element->__isset.field_id) { - field_id = element->field_id; + fieldId = element->field_id; } - std::unique_ptr primitive_node; + std::unique_ptr primitiveNode; if (element->__isset.logicalType) { - // updated writer with logical type present - primitive_node = std::unique_ptr(new PrimitiveNode( + // Updated writer with logical type present. + primitiveNode = std::unique_ptr(new PrimitiveNode( element->name, - LoadEnumSafe(&element->repetition_type), - LogicalType::FromThrift(element->logicalType), - LoadEnumSafe(&element->type), + loadenumSafe(&element->repetition_type), + LogicalType::fromThrift(element->logicalType), + loadenumSafe(&element->type), element->type_length, - field_id)); + fieldId)); } else if (element->__isset.converted_type) { - // legacy writer with converted type present - primitive_node = std::unique_ptr(new PrimitiveNode( + // Legacy writer with converted type present. + primitiveNode = std::unique_ptr(new PrimitiveNode( element->name, - LoadEnumSafe(&element->repetition_type), - LoadEnumSafe(&element->type), - LoadEnumSafe(&element->converted_type), + loadenumSafe(&element->repetition_type), + loadenumSafe(&element->type), + loadenumSafe(&element->converted_type), element->type_length, element->precision, element->scale, - field_id)); + fieldId)); } else { - // logical type not present - primitive_node = std::unique_ptr(new PrimitiveNode( + // Logical type not present. + primitiveNode = std::unique_ptr(new PrimitiveNode( element->name, - LoadEnumSafe(&element->repetition_type), - NoLogicalType::Make(), - LoadEnumSafe(&element->type), + loadenumSafe(&element->repetition_type), + NoLogicalType::make(), + loadenumSafe(&element->type), element->type_length, - field_id)); + fieldId)); } - // Return as unique_ptr to the base type - return std::unique_ptr(primitive_node.release()); + // Return as unique_ptr to the base type. + return std::unique_ptr(primitiveNode.release()); } -bool GroupNode::HasRepeatedFields() const { - for (int i = 0; i < this->field_count(); ++i) { +bool GroupNode::hasRepeatedFields() const { + for (int i = 0; i < this->fieldCount(); ++i) { auto field = this->field(i); - if (field->repetition() == Repetition::REPEATED) { + if (field->repetition() == Repetition::kRepeated) { return true; } - if (field->is_group()) { + if (field->isGroup()) { const auto& group = static_cast(*field); - return group.HasRepeatedFields(); + return group.hasRepeatedFields(); } } return false; } -void GroupNode::ToParquet(void* opaque_element) const { +void GroupNode::toParquet(void* opaqueElement) const { facebook::velox::parquet::thrift::SchemaElement* element = static_cast( - opaque_element); + opaqueElement); element->__set_name(name_); - element->__set_num_children(field_count()); - element->__set_repetition_type(ToThrift(repetition_)); - if (converted_type_ != ConvertedType::NONE) { - element->__set_converted_type(ToThrift(converted_type_)); + element->__set_num_children(fieldCount()); + element->__set_repetition_type(toThrift(repetition_)); + if (convertedType_ != ConvertedType::kNone) { + element->__set_converted_type(toThrift(convertedType_)); } - if (field_id_ >= 0) { - element->__set_field_id(field_id_); + if (fieldId_ >= 0) { + element->__set_field_id(fieldId_); } - if (logical_type_ && logical_type_->is_serialized()) { - element->__set_logicalType(logical_type_->ToThrift()); + if (logicalType_ && logicalType_->isSerialized()) { + element->__set_logicalType(logicalType_->toThrift()); } return; } -void PrimitiveNode::ToParquet(void* opaque_element) const { +void PrimitiveNode::toParquet(void* opaqueElement) const { facebook::velox::parquet::thrift::SchemaElement* element = static_cast( - opaque_element); + opaqueElement); element->__set_name(name_); - element->__set_repetition_type(ToThrift(repetition_)); - if (converted_type_ != ConvertedType::NONE) { - if (converted_type_ != ConvertedType::NA) { - element->__set_converted_type(ToThrift(converted_type_)); + element->__set_repetition_type(toThrift(repetition_)); + if (convertedType_ != ConvertedType::kNone) { + if (convertedType_ != ConvertedType::kNa) { + element->__set_converted_type(toThrift(convertedType_)); } else { - // ConvertedType::NA is an unreleased, obsolete synonym for - // LogicalType::Null. Never emit it (see PARQUET-1990 for discussion). - if (!logical_type_ || !logical_type_->is_null()) { + // ConvertedType::kNa is an unreleased, obsolete synonym for. + // LogicalType::nullType. Never emit it (see PARQUET-1990 for discussion). + if (!logicalType_ || !logicalType_->isNull()) { throw ParquetException( - "ConvertedType::NA is obsolete, please use LogicalType::Null instead"); + "ConvertedType::kNa is obsolete, please use LogicalType::nullType instead"); } } } - if (field_id_ >= 0) { - element->__set_field_id(field_id_); + if (fieldId_ >= 0) { + element->__set_field_id(fieldId_); } - if (logical_type_ && logical_type_->is_serialized() && - // TODO(tpboudreau): remove the following conjunct to enable serialization - // of IntervalTypes after parquet.thrift recognizes them - !logical_type_->is_interval()) { - element->__set_logicalType(logical_type_->ToThrift()); + if (logicalType_ && logicalType_->isSerialized() && + // TODO(tpboudreau): remove the following conjunct to enable + // serialization. Of IntervalTypes after parquet.thrift recognizes them. + !logicalType_->isInterval()) { + element->__set_logicalType(logicalType_->toThrift()); } - element->__set_type(ToThrift(physical_type_)); - if (physical_type_ == Type::FIXED_LEN_BYTE_ARRAY) { - element->__set_type_length(type_length_); + element->__set_type(toThrift(physicalType_)); + if (physicalType_ == Type::kFixedLenByteArray) { + element->__set_type_length(typeLength_); } - if (decimal_metadata_.isset) { - element->__set_precision(decimal_metadata_.precision); - element->__set_scale(decimal_metadata_.scale); + if (decimalMetadata_.isset) { + element->__set_precision(decimalMetadata_.precision); + element->__set_scale(decimalMetadata_.scale); } return; } -// ---------------------------------------------------------------------- -// Schema converters +// ----------------------------------------------------------------------. +// Schema converters. -std::unique_ptr Unflatten( +std::unique_ptr unflatten( const facebook::velox::parquet::thrift::SchemaElement* elements, int length) { if (elements[0].num_children == 0) { if (length == 1) { - // Degenerate case of Parquet file with no columns - return GroupNode::FromParquet(elements, {}); + // Degenerate case of Parquet file with no columns. + return GroupNode::fromParquet(elements, {}); } else { throw ParquetException( "Parquet schema had multiple nodes but root had no children"); } } - // We don't check that the root node is repeated since this is not - // consistently set by implementations + // We don't check that the root node is repeated since this is not. + // Consistently set by implementations. int pos = 0; - std::function()> NextNode = [&]() { + std::function()> nextNode = [&]() { if (pos == length) { throw ParquetException("Malformed schema: not enough elements"); } const SchemaElement& element = elements[pos++]; - const void* opaque_element = static_cast(&element); + const void* opaqueElement = static_cast(&element); if (element.num_children == 0 && element.__isset.type) { - // Leaf (primitive) node: always has a type - return PrimitiveNode::FromParquet(opaque_element); + // Leaf (primitive) node: always has a type. + return PrimitiveNode::fromParquet(opaqueElement); } else { // Group node (may have 0 children, but cannot have a type) NodeVector fields; for (int i = 0; i < element.num_children; ++i) { - std::unique_ptr field = NextNode(); - fields.push_back(NodePtr(field.release())); + fields.emplace_back(nextNode()); } - return GroupNode::FromParquet(opaque_element, std::move(fields)); + return GroupNode::fromParquet(opaqueElement, std::move(fields)); } }; - return NextNode(); + return nextNode(); } -std::shared_ptr FromParquet( +std::shared_ptr fromParquet( const std::vector& schema) { if (schema.empty()) { throw ParquetException("Empty file schema (no root)"); } std::unique_ptr root = - Unflatten(&schema[0], static_cast(schema.size())); + unflatten(&schema[0], static_cast(schema.size())); std::shared_ptr descr = std::make_shared(); - descr->Init( + descr->init( std::shared_ptr(static_cast(root.release()))); return descr; } @@ -667,15 +662,15 @@ class SchemaVisitor : public Node::ConstVisitor { std::vector* elements) : elements_(elements) {} - void Visit(const Node* node) override { + void visit(const Node* node) override { facebook::velox::parquet::thrift::SchemaElement element; - node->ToParquet(&element); + node->toParquet(&element); elements_->push_back(element); - if (node->is_group()) { - const GroupNode* group_node = static_cast(node); - for (int i = 0; i < group_node->field_count(); ++i) { - group_node->field(i)->VisitConst(this); + if (node->isGroup()) { + const GroupNode* groupNode = static_cast(node); + for (int i = 0; i < groupNode->fieldCount(); ++i) { + groupNode->field(i)->visitConst(this); } } } @@ -684,25 +679,25 @@ class SchemaVisitor : public Node::ConstVisitor { std::vector* elements_; }; -void ToParquet( +void toParquet( const GroupNode* schema, std::vector* out) { SchemaVisitor visitor(out); - schema->VisitConst(&visitor); + schema->visitConst(&visitor); } -// ---------------------------------------------------------------------- -// Schema printing +// ----------------------------------------------------------------------. +// Schema printing. -static void PrintRepLevel(Repetition::type repetition, std::ostream& stream) { +static void printRepLevel(Repetition::type repetition, std::ostream& stream) { switch (repetition) { - case Repetition::REQUIRED: + case Repetition::kRequired: stream << "required"; break; - case Repetition::OPTIONAL: + case Repetition::kOptional: stream << "optional"; break; - case Repetition::REPEATED: + case Repetition::kRepeated: stream << "repeated"; break; default: @@ -710,113 +705,113 @@ static void PrintRepLevel(Repetition::type repetition, std::ostream& stream) { } } -static void PrintType(const PrimitiveNode* node, std::ostream& stream) { - switch (node->physical_type()) { - case Type::BOOLEAN: +static void printType(const PrimitiveNode* Node, std::ostream& stream) { + switch (Node->physicalType()) { + case Type::kBoolean: stream << "boolean"; break; - case Type::INT32: + case Type::kInt32: stream << "int32"; break; - case Type::INT64: + case Type::kInt64: stream << "int64"; break; - case Type::INT96: + case Type::kInt96: stream << "int96"; break; - case Type::FLOAT: + case Type::kFloat: stream << "float"; break; - case Type::DOUBLE: + case Type::kDouble: stream << "double"; break; - case Type::BYTE_ARRAY: + case Type::kByteArray: stream << "binary"; break; - case Type::FIXED_LEN_BYTE_ARRAY: - stream << "fixed_len_byte_array(" << node->type_length() << ")"; + case Type::kFixedLenByteArray: + stream << "fixed_len_byte_array(" << Node->typeLength() << ")"; break; default: break; } } -static void PrintConvertedType( - const PrimitiveNode* node, +static void printConvertedType( + const PrimitiveNode* Node, std::ostream& stream) { - auto lt = node->converted_type(); - auto la = node->logical_type(); - if (la && la->is_valid() && !la->is_none()) { - stream << " (" << la->ToString() << ")"; - } else if (lt == ConvertedType::DECIMAL) { - stream << " (" << ConvertedTypeToString(lt) << "(" - << node->decimal_metadata().precision << "," - << node->decimal_metadata().scale << "))"; - } else if (lt != ConvertedType::NONE) { - stream << " (" << ConvertedTypeToString(lt) << ")"; + auto lt = Node->convertedType(); + auto la = Node->logicalType(); + if (la && la->isValid() && !la->isNone()) { + stream << " (" << la->toString() << ")"; + } else if (lt == ConvertedType::kDecimal) { + stream << " (" << convertedTypeToString(lt) << "(" + << Node->decimalMetadata().precision << "," + << Node->decimalMetadata().scale << "))"; + } else if (lt != ConvertedType::kNone) { + stream << " (" << convertedTypeToString(lt) << ")"; } } struct SchemaPrinter : public Node::ConstVisitor { - explicit SchemaPrinter(std::ostream& stream, int indent_width) - : stream_(stream), indent_(0), indent_width_(2) {} + explicit SchemaPrinter(std::ostream& stream, int indentWidth) + : stream_(stream), indent_(0), indentWidth_(2) {} - void Indent() { + void indent() { if (indent_ > 0) { std::string spaces(indent_, ' '); stream_ << spaces; } } - void Visit(const Node* node) { - Indent(); - if (node->is_group()) { - Visit(static_cast(node)); + void visit(const Node* Node) { + indent(); + if (Node->isGroup()) { + visit(static_cast(Node)); } else { - // Primitive - Visit(static_cast(node)); + // Primitive. + visit(static_cast(Node)); } } - void Visit(const PrimitiveNode* node) { - PrintRepLevel(node->repetition(), stream_); + void visit(const PrimitiveNode* Node) { + printRepLevel(Node->repetition(), stream_); stream_ << " "; - PrintType(node, stream_); - stream_ << " field_id=" << node->field_id() << " " << node->name(); - PrintConvertedType(node, stream_); + printType(Node, stream_); + stream_ << " field_id=" << Node->fieldId() << " " << Node->name(); + printConvertedType(Node, stream_); stream_ << ";" << std::endl; } - void Visit(const GroupNode* node) { - PrintRepLevel(node->repetition(), stream_); - stream_ << " group " << "field_id=" << node->field_id() << " " - << node->name(); - auto lt = node->converted_type(); - auto la = node->logical_type(); - if (la && la->is_valid() && !la->is_none()) { - stream_ << " (" << la->ToString() << ")"; - } else if (lt != ConvertedType::NONE) { - stream_ << " (" << ConvertedTypeToString(lt) << ")"; + void visit(const GroupNode* Node) { + printRepLevel(Node->repetition(), stream_); + stream_ << " group " << "field_id=" << Node->fieldId() << " " + << Node->name(); + auto lt = Node->convertedType(); + auto la = Node->logicalType(); + if (la && la->isValid() && !la->isNone()) { + stream_ << " (" << la->toString() << ")"; + } else if (lt != ConvertedType::kNone) { + stream_ << " (" << convertedTypeToString(lt) << ")"; } stream_ << " {" << std::endl; - indent_ += indent_width_; - for (int i = 0; i < node->field_count(); ++i) { - node->field(i)->VisitConst(this); + indent_ += indentWidth_; + for (int i = 0; i < Node->fieldCount(); ++i) { + Node->field(i)->visitConst(this); } - indent_ -= indent_width_; - Indent(); + indent_ -= indentWidth_; + indent(); stream_ << "}" << std::endl; } std::ostream& stream_; int indent_; - int indent_width_; + int indentWidth_; }; -void PrintSchema(const Node* schema, std::ostream& stream, int indent_width) { - SchemaPrinter printer(stream, indent_width); - printer.Visit(schema); +void printSchema(const Node* schema, std::ostream& stream, int indentWidth) { + SchemaPrinter printer(stream, indentWidth); + printer.visit(schema); } } // namespace schema @@ -827,74 +822,74 @@ using schema::Node; using schema::NodePtr; using schema::PrimitiveNode; -void SchemaDescriptor::Init(std::unique_ptr schema) { - Init(NodePtr(schema.release())); +void SchemaDescriptor::init(std::unique_ptr schema) { + init(NodePtr(schema.release())); } class SchemaUpdater : public Node::Visitor { public: - explicit SchemaUpdater(const std::vector& column_orders) - : column_orders_(column_orders), leaf_count_(0) {} - - void Visit(Node* node) override { - if (node->is_group()) { - GroupNode* group_node = static_cast(node); - for (int i = 0; i < group_node->field_count(); ++i) { - group_node->field(i)->Visit(this); + explicit SchemaUpdater(const std::vector& columnOrders) + : columnOrders_(columnOrders), leafCount_(0) {} + + void visit(Node* node) override { + if (node->isGroup()) { + GroupNode* groupNode = static_cast(node); + for (int i = 0; i < groupNode->fieldCount(); ++i) { + groupNode->field(i)->visit(this); } } else { // leaf node - PrimitiveNode* leaf_node = static_cast(node); - leaf_node->SetColumnOrder(column_orders_[leaf_count_++]); + PrimitiveNode* leafNode = static_cast(node); + leafNode->setColumnOrder(columnOrders_[leafCount_++]); } } private: - const std::vector& column_orders_; - int leaf_count_; + const std::vector& columnOrders_; + int leafCount_; }; void SchemaDescriptor::updateColumnOrders( - const std::vector& column_orders) { - if (static_cast(column_orders.size()) != num_columns()) { + const std::vector& columnOrders) { + if (static_cast(columnOrders.size()) != numColumns()) { throw ParquetException("Malformed schema: not enough ColumnOrder values"); } - SchemaUpdater visitor(column_orders); - const_cast(group_node_)->Visit(&visitor); + SchemaUpdater visitor(columnOrders); + const_cast(groupNode_)->visit(&visitor); } -void SchemaDescriptor::Init(NodePtr schema) { +void SchemaDescriptor::init(NodePtr schema) { schema_ = std::move(schema); - if (!schema_->is_group()) { + if (!schema_->isGroup()) { throw ParquetException("Must initialize with a schema group"); } - group_node_ = static_cast(schema_.get()); + groupNode_ = static_cast(schema_.get()); leaves_.clear(); - for (int i = 0; i < group_node_->field_count(); ++i) { - BuildTree(group_node_->field(i), 0, 0, group_node_->field(i)); + for (int i = 0; i < groupNode_->fieldCount(); ++i) { + buildTree(groupNode_->field(i), 0, 0, groupNode_->field(i)); } } -bool SchemaDescriptor::Equals( +bool SchemaDescriptor::equals( const SchemaDescriptor& other, - std::ostream* diff_output) const { - if (this->num_columns() != other.num_columns()) { - if (diff_output != nullptr) { - *diff_output << "This schema has " << this->num_columns() - << " columns, other has " << other.num_columns(); + std::ostream* diffOutput) const { + if (this->numColumns() != other.numColumns()) { + if (diffOutput != nullptr) { + *diffOutput << "This schema has " << this->numColumns() + << " columns, other has " << other.numColumns(); } return false; } - for (int i = 0; i < this->num_columns(); ++i) { - if (!this->Column(i)->Equals(*other.Column(i))) { - if (diff_output != nullptr) { - *diff_output << "The two columns with index " << i << " differ." - << std::endl - << this->Column(i)->ToString() << std::endl - << other.Column(i)->ToString() << std::endl; + for (int i = 0; i < this->numColumns(); ++i) { + if (!this->column(i)->equals(*other.column(i))) { + if (diffOutput != nullptr) { + *diffOutput << "The two columns with index " << i << " differ." + << std::endl + << this->column(i)->toString() << std::endl + << other.column(i)->toString() << std::endl; } return false; } @@ -903,149 +898,147 @@ bool SchemaDescriptor::Equals( return true; } -void SchemaDescriptor::BuildTree( - const NodePtr& node, - int16_t max_def_level, - int16_t max_rep_level, +void SchemaDescriptor::buildTree( + const NodePtr& Node, + int16_t maxDefLevel, + int16_t maxRepLevel, const NodePtr& base) { - if (node->is_optional()) { - ++max_def_level; - } else if (node->is_repeated()) { + if (Node->isOptional()) { + ++maxDefLevel; + } else if (Node->isRepeated()) { // Repeated fields add a definition level. This is used to distinguish // between an empty list and a list with an item in it. - ++max_rep_level; - ++max_def_level; + ++maxRepLevel; + ++maxDefLevel; } - // Now, walk the schema and create a ColumnDescriptor for each leaf node - if (node->is_group()) { - const GroupNode* group = static_cast(node.get()); - for (int i = 0; i < group->field_count(); ++i) { - BuildTree(group->field(i), max_def_level, max_rep_level, base); + // Now, walk the schema and create a ColumnDescriptor for each leaf node. + if (Node->isGroup()) { + const GroupNode* group = static_cast(Node.get()); + for (int i = 0; i < group->fieldCount(); ++i) { + buildTree(group->field(i), maxDefLevel, maxRepLevel, base); } } else { - node_to_leaf_index_[static_cast(node.get())] = + nodeToLeafIndex_[static_cast(Node.get())] = static_cast(leaves_.size()); - // Primitive node, append to leaves - leaves_.push_back( - ColumnDescriptor(node, max_def_level, max_rep_level, this)); - leaf_to_base_.emplace(static_cast(leaves_.size()) - 1, base); - leaf_to_idx_.emplace( - node->path()->ToDotString(), static_cast(leaves_.size()) - 1); + // Primitive node, append to leaves. + leaves_.emplace_back(Node, maxDefLevel, maxRepLevel, this); + leafToBase_.emplace(static_cast(leaves_.size()) - 1, base); + leafToIdx_.emplace( + Node->path()->toDotString(), static_cast(leaves_.size()) - 1); } } -int SchemaDescriptor::GetColumnIndex(const PrimitiveNode& node) const { - auto it = node_to_leaf_index_.find(&node); - if (it == node_to_leaf_index_.end()) { +int SchemaDescriptor::getColumnIndex(const PrimitiveNode& Node) const { + auto it = nodeToLeafIndex_.find(&Node); + if (it == nodeToLeafIndex_.end()) { return -1; } return it->second; } ColumnDescriptor::ColumnDescriptor( - schema::NodePtr node, - int16_t max_definition_level, - int16_t max_repetition_level, - const SchemaDescriptor* schema_descr) - : node_(std::move(node)), - max_definition_level_(max_definition_level), - max_repetition_level_(max_repetition_level) { - if (!node_->is_primitive()) { + schema::NodePtr Node, + int16_t maxDefinitionLevel, + int16_t maxRepetitionLevel, + const SchemaDescriptor* schemaDescr) + : node_(std::move(Node)), + maxDefinitionLevel_(maxDefinitionLevel), + maxRepetitionLevel_(maxRepetitionLevel) { + if (!node_->isPrimitive()) { throw ParquetException("Must be a primitive type"); } - primitive_node_ = static_cast(node_.get()); + primitiveNode_ = static_cast(node_.get()); } -bool ColumnDescriptor::Equals(const ColumnDescriptor& other) const { - return primitive_node_->Equals(other.primitive_node_) && - max_repetition_level() == other.max_repetition_level() && - max_definition_level() == other.max_definition_level(); +bool ColumnDescriptor::equals(const ColumnDescriptor& other) const { + return primitiveNode_->equals(other.primitiveNode_) && + maxRepetitionLevel() == other.maxRepetitionLevel() && + maxDefinitionLevel() == other.maxDefinitionLevel(); } -const ColumnDescriptor* SchemaDescriptor::Column(int i) const { - CheckColumnBounds(i, leaves_.size()); +const ColumnDescriptor* SchemaDescriptor::column(int i) const { + checkColumnBounds(i, leaves_.size()); return &leaves_[i]; } -int SchemaDescriptor::ColumnIndex(const std::string& node_path) const { - auto search = leaf_to_idx_.find(node_path); - if (search == leaf_to_idx_.end()) { - // Not found +int SchemaDescriptor::columnIndex(const std::string& nodePath) const { + auto search = leafToIdx_.find(nodePath); + if (search == leafToIdx_.end()) { + // Not found. return -1; } return search->second; } -int SchemaDescriptor::ColumnIndex(const Node& node) const { - auto search = leaf_to_idx_.equal_range(node.path()->ToDotString()); +int SchemaDescriptor::columnIndex(const Node& node) const { + auto search = leafToIdx_.equal_range(node.path()->toDotString()); for (auto it = search.first; it != search.second; ++it) { const int idx = it->second; - if (&node == Column(idx)->schema_node().get()) { + if (&node == column(idx)->schemaNode().get()) { return idx; } } return -1; } -const schema::Node* SchemaDescriptor::GetColumnRoot(int i) const { - CheckColumnBounds(i, leaves_.size()); - return leaf_to_base_.find(i)->second.get(); +const schema::Node* SchemaDescriptor::getColumnRoot(int i) const { + checkColumnBounds(i, leaves_.size()); + return leafToBase_.find(i)->second.get(); } -bool SchemaDescriptor::HasRepeatedFields() const { - return group_node_->HasRepeatedFields(); +bool SchemaDescriptor::hasRepeatedFields() const { + return groupNode_->hasRepeatedFields(); } -std::string SchemaDescriptor::ToString() const { +std::string SchemaDescriptor::toString() const { std::ostringstream ss; - PrintSchema(schema_.get(), ss); + printSchema(schema_.get(), ss); return ss.str(); } -std::string ColumnDescriptor::ToString() const { +std::string ColumnDescriptor::toString() const { std::ostringstream ss; ss << "column descriptor = {" << std::endl << " name: " << name() << "," << std::endl - << " path: " << path()->ToDotString() << "," << std::endl - << " physical_type: " << TypeToString(physical_type()) << "," << std::endl - << " converted_type: " << ConvertedTypeToString(converted_type()) << "," + << " path: " << path()->toDotString() << "," << std::endl + << " physical_type: " << typeToString(physicalType()) << "," << std::endl + << " converted_type: " << convertedTypeToString(convertedType()) << "," << std::endl - << " logical_type: " << logical_type()->ToString() << "," << std::endl - << " max_definition_level: " << max_definition_level() << "," << std::endl - << " max_repetition_level: " << max_repetition_level() << "," - << std::endl; - - if (physical_type() == - ::facebook::velox::parquet::arrow::Type::FIXED_LEN_BYTE_ARRAY) { - ss << " length: " << type_length() << "," << std::endl; + << " logical_type: " << logicalType()->toString() << "," << std::endl + << " max_definition_level: " << maxDefinitionLevel() << "," << std::endl + << " max_repetition_level: " << maxRepetitionLevel() << "," << std::endl; + + if (physicalType() == + ::facebook::velox::parquet::arrow::Type::kFixedLenByteArray) { + ss << " length: " << typeLength() << "," << std::endl; } - if (converted_type() == - ::facebook::velox::parquet::arrow::ConvertedType::DECIMAL) { - ss << " precision: " << type_precision() << "," << std::endl - << " scale: " << type_scale() << "," << std::endl; + if (convertedType() == + ::facebook::velox::parquet::arrow::ConvertedType::kDecimal) { + ss << " precision: " << typePrecision() << "," << std::endl + << " scale: " << typeScale() << "," << std::endl; } ss << "}"; return ss.str(); } -int ColumnDescriptor::type_scale() const { - return primitive_node_->decimal_metadata().scale; +int ColumnDescriptor::typeScale() const { + return primitiveNode_->decimalMetadata().scale; } -int ColumnDescriptor::type_precision() const { - return primitive_node_->decimal_metadata().precision; +int ColumnDescriptor::typePrecision() const { + return primitiveNode_->decimalMetadata().precision; } -int ColumnDescriptor::type_length() const { - return primitive_node_->type_length(); +int ColumnDescriptor::typeLength() const { + return primitiveNode_->typeLength(); } const std::shared_ptr ColumnDescriptor::path() const { - return primitive_node_->path(); + return primitiveNode_->path(); } } // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/Schema.h b/velox/dwio/parquet/writer/arrow/Schema.h index e71493d751d..d9f77840024 100644 --- a/velox/dwio/parquet/writer/arrow/Schema.h +++ b/velox/dwio/parquet/writer/arrow/Schema.h @@ -16,8 +16,8 @@ // Adapted from Apache Arrow. -// This module contains the logical parquet-cpp types (independent of Thrift -// structures), schema nodes, and related type tools +// This module contains the logical Parquet-cpp types (independent of Thrift +// structures), schema nodes, and related type tools. #pragma once @@ -42,23 +42,23 @@ class Node; // List encodings: using the terminology from Impala to define different styles // of representing logical lists (a.k.a. ARRAY types) in Parquet schemas. Since -// the converted type named in the Parquet metadata is ConvertedType::LIST we +// the converted type named in the Parquet metadata is ConvertedType::kList we // use that terminology here. It also helps distinguish from the *_ARRAY // primitive types. // -// One-level encoding: Only allows required lists with required cells -// repeated value_type name +// One-level encoding: Only allows required lists with required cells. +// Repeated value_type name. // -// Two-level encoding: Enables optional lists with only required cells -// group list -// repeated value_type item +// Two-level encoding: Enables optional lists with only required cells. +// group list. +// Repeated value_type item. // -// Three-level encoding: Enables optional lists with optional cells -// group bag -// repeated group list -// value_type item +// Three-level encoding: Enables optional lists with optional cells. +// group bag. +// Repeated group list. +// value_type item. // -// 2- and 1-level encoding are respectively equivalent to 3-level encoding with +// 2- And 1-level encoding are respectively equivalent to 3-level encoding with // the non-repeated nodes set to required. // // The "official" encoding recommended in the Parquet spec is the 3-level, and @@ -67,13 +67,13 @@ class Node; // "in the wild" we need to be able to interpret the associated definition // levels in the context of the actual encoding used in the file. // -// NB: Some Parquet writers may not set ConvertedType::LIST on the repeated +// NB: Some Parquet writers may not set ConvertedType::kList on the repeated // SchemaElement, which could make things challenging if we are trying to infer // that a sequence of nodes semantically represents an array according to one // of these encodings (versus a struct containing an array). We should refuse // the temptation to guess, as they say. struct ListEncoding { - enum type { ONE_LEVEL, TWO_LEVEL, THREE_LEVEL }; + enum type { kOneLevel, kTwoLevel, kThreeLevel }; }; class PARQUET_EXPORT ColumnPath { @@ -83,53 +83,53 @@ class PARQUET_EXPORT ColumnPath { explicit ColumnPath(std::vector&& path) : path_(std::move(path)) {} - static std::shared_ptr FromDotString( + static std::shared_ptr fromDotString( const std::string& dotstring); - static std::shared_ptr FromNode(const Node& node); + static std::shared_ptr fromNode(const Node& Node); - std::shared_ptr extend(const std::string& node_name) const; - std::string ToDotString() const; - const std::vector& ToDotVector() const; + std::shared_ptr extend(const std::string& nodeName) const; + std::string toDotString() const; + const std::vector& toDotVector() const; protected: std::vector path_; }; // Base class for logical schema types. A type has a name, repetition level, -// and optionally a logical type (ConvertedType in Parquet metadata parlance) +// and optionally a logical type (ConvertedType in Parquet metadata parlance).. class PARQUET_EXPORT Node { public: - enum type { PRIMITIVE, GROUP }; + enum type { kPrimitive, kGroup }; virtual ~Node() {} - bool is_primitive() const { - return type_ == Node::PRIMITIVE; + bool isPrimitive() const { + return type_ == Node::kPrimitive; } - bool is_group() const { - return type_ == Node::GROUP; + bool isGroup() const { + return type_ == Node::kGroup; } - bool is_optional() const { - return repetition_ == Repetition::OPTIONAL; + bool isOptional() const { + return repetition_ == Repetition::kOptional; } - bool is_repeated() const { - return repetition_ == Repetition::REPEATED; + bool isRepeated() const { + return repetition_ == Repetition::kRepeated; } - bool is_required() const { - return repetition_ == Repetition::REQUIRED; + bool isRequired() const { + return repetition_ == Repetition::kRequired; } - virtual bool Equals(const Node* other) const = 0; + virtual bool equals(const Node* other) const = 0; const std::string& name() const { return name_; } - Node::type node_type() const { + Node::type nodeType() const { return type_; } @@ -137,19 +137,19 @@ class PARQUET_EXPORT Node { return repetition_; } - ConvertedType::type converted_type() const { - return converted_type_; + ConvertedType::type convertedType() const { + return convertedType_; } - const std::shared_ptr& logical_type() const { - return logical_type_; + const std::shared_ptr& logicalType() const { + return logicalType_; } - /// \brief The field_id value for the serialized SchemaElement. If the - /// field_id is less than 0 (e.g. -1), it will not be set when serialized to + /// \brief The fieldId value for the serialized SchemaElement. If the + /// fieldId is less than 0 (e.g. -1), it will not be set when serialized to /// Thrift. - int field_id() const { - return field_id_; + int fieldId() const { + return fieldId_; } const Node* parent() const { @@ -158,24 +158,24 @@ class PARQUET_EXPORT Node { const std::shared_ptr path() const; - virtual void ToParquet(void* element) const = 0; + virtual void toParquet(void* element) const = 0; - // Node::Visitor abstract class for walking schemas with the visitor pattern + // Node::Visitor abstract class for walking schemas with the visitor pattern. class Visitor { public: virtual ~Visitor() {} - virtual void Visit(Node* node) = 0; + virtual void visit(Node* Node) = 0; }; class ConstVisitor { public: virtual ~ConstVisitor() {} - virtual void Visit(const Node* node) = 0; + virtual void visit(const Node* Node) = 0; }; - virtual void Visit(Visitor* visitor) = 0; - virtual void VisitConst(ConstVisitor* visitor) const = 0; + virtual void visit(Visitor* visitor) = 0; + virtual void visitConst(ConstVisitor* visitor) const = 0; protected: friend class GroupNode; @@ -184,188 +184,188 @@ class PARQUET_EXPORT Node { Node::type type, const std::string& name, Repetition::type repetition, - ConvertedType::type converted_type = ConvertedType::NONE, - int field_id = -1) + ConvertedType::type convertedType = ConvertedType::kNone, + int fieldId = -1) : type_(type), name_(name), repetition_(repetition), - converted_type_(converted_type), - field_id_(field_id), + convertedType_(convertedType), + fieldId_(fieldId), parent_(NULLPTR) {} Node( Node::type type, const std::string& name, Repetition::type repetition, - std::shared_ptr logical_type, - int field_id = -1) + std::shared_ptr logicalType, + int fieldId = -1) : type_(type), name_(name), repetition_(repetition), - logical_type_(std::move(logical_type)), - field_id_(field_id), + logicalType_(std::move(logicalType)), + fieldId_(fieldId), parent_(NULLPTR) {} Node::type type_; std::string name_; Repetition::type repetition_; - ConvertedType::type converted_type_; - std::shared_ptr logical_type_; - int field_id_; + ConvertedType::type convertedType_; + std::shared_ptr logicalType_; + int fieldId_; // Nodes should not be shared, they have a single parent. const Node* parent_; - bool EqualsInternal(const Node* other) const; - void SetParent(const Node* p_parent); + bool equalsInternal(const Node* other) const; + void setParent(const Node* pParent); private: PARQUET_DISALLOW_COPY_AND_ASSIGN(Node); }; -// Save our breath all over the place with these typedefs +// Save our breath all over the place with these typedefs. using NodePtr = std::shared_ptr; using NodeVector = std::vector; // A type that is one of the primitive Parquet storage types. In addition to // the other type metadata (name, repetition level, logical type), also has the // physical storage type and their type-specific metadata (byte width, decimal -// parameters) +// parameters). class PARQUET_EXPORT PrimitiveNode : public Node { public: - static std::unique_ptr FromParquet(const void* opaque_element); + static std::unique_ptr fromParquet(const void* opaqueElement); - // A field_id -1 (or any negative value) will be serialized as null in Thrift - static inline NodePtr Make( + // A field_id -1 (or any negative value) will be serialized as null in Thrift. + static inline NodePtr make( const std::string& name, Repetition::type repetition, Type::type type, - ConvertedType::type converted_type = ConvertedType::NONE, + ConvertedType::type convertedType = ConvertedType::kNone, int length = -1, int precision = -1, int scale = -1, - int field_id = -1) { + int fieldId = -1) { return NodePtr(new PrimitiveNode( name, repetition, type, - converted_type, + convertedType, length, precision, scale, - field_id)); + fieldId)); } - // If no logical type, pass LogicalType::None() or nullptr - // A field_id -1 (or any negative value) will be serialized as null in Thrift - static inline NodePtr Make( + // If no logical type, pass LogicalType::None() or nullptr. + // A field_id -1 (or any negative value) will be serialized as null in Thrift. + static inline NodePtr make( const std::string& name, Repetition::type repetition, - std::shared_ptr logical_type, - Type::type primitive_type, - int primitive_length = -1, - int field_id = -1) { + std::shared_ptr logicalType, + Type::type primitiveType, + int primitiveLength = -1, + int fieldId = -1) { return NodePtr(new PrimitiveNode( name, repetition, - std::move(logical_type), - primitive_type, - primitive_length, - field_id)); + std::move(logicalType), + primitiveType, + primitiveLength, + fieldId)); } - bool Equals(const Node* other) const override; + bool equals(const Node* other) const override; - Type::type physical_type() const { - return physical_type_; + Type::type physicalType() const { + return physicalType_; } - ColumnOrder column_order() const { - return column_order_; + ColumnOrder columnOrder() const { + return columnOrder_; } - void SetColumnOrder(ColumnOrder column_order) { - column_order_ = column_order; + void setColumnOrder(ColumnOrder columnOrder) { + columnOrder_ = columnOrder; } - int32_t type_length() const { - return type_length_; + int32_t typeLength() const { + return typeLength_; } - const DecimalMetadata& decimal_metadata() const { - return decimal_metadata_; + const DecimalMetadata& decimalMetadata() const { + return decimalMetadata_; } - void ToParquet(void* element) const override; - void Visit(Visitor* visitor) override; - void VisitConst(ConstVisitor* visitor) const override; + void toParquet(void* element) const override; + void visit(Visitor* visitor) override; + void visitConst(ConstVisitor* visitor) const override; private: PrimitiveNode( const std::string& name, Repetition::type repetition, Type::type type, - ConvertedType::type converted_type = ConvertedType::NONE, + ConvertedType::type convertedType = ConvertedType::kNone, int length = -1, int precision = -1, int scale = -1, - int field_id = -1); + int fieldId = -1); PrimitiveNode( const std::string& name, Repetition::type repetition, - std::shared_ptr logical_type, - Type::type primitive_type, - int primitive_length = -1, - int field_id = -1); + std::shared_ptr logicalType, + Type::type primitiveType, + int primitiveLength = -1, + int fieldId = -1); - Type::type physical_type_; - int32_t type_length_; - DecimalMetadata decimal_metadata_; - ColumnOrder column_order_; + Type::type physicalType_; + int32_t typeLength_; + DecimalMetadata decimalMetadata_; + ColumnOrder columnOrder_; - // For FIXED_LEN_BYTE_ARRAY - void SetTypeLength(int32_t length) { - type_length_ = length; + // For FIXED_LEN_BYTE_ARRAY. + void setTypeLength(int32_t length) { + typeLength_ = length; } - bool EqualsInternal(const PrimitiveNode* other) const; + bool equalsInternal(const PrimitiveNode* other) const; FRIEND_TEST(TestPrimitiveNode, Attrs); - FRIEND_TEST(TestPrimitiveNode, Equals); + FRIEND_TEST(TestPrimitiveNode, equals); FRIEND_TEST(TestPrimitiveNode, PhysicalLogicalMapping); - FRIEND_TEST(TestPrimitiveNode, FromParquet); + FRIEND_TEST(TestPrimitiveNode, fromParquet); }; class PARQUET_EXPORT GroupNode : public Node { public: - static std::unique_ptr FromParquet( - const void* opaque_element, + static std::unique_ptr fromParquet( + const void* opaqueElement, NodeVector fields = {}); - // A field_id -1 (or any negative value) will be serialized as null in Thrift - static inline NodePtr Make( + // A field_id -1 (or any negative value) will be serialized as null in Thrift. + static inline NodePtr make( const std::string& name, Repetition::type repetition, const NodeVector& fields, - ConvertedType::type converted_type = ConvertedType::NONE, - int field_id = -1) { + ConvertedType::type convertedType = ConvertedType::kNone, + int fieldId = -1) { return NodePtr( - new GroupNode(name, repetition, fields, converted_type, field_id)); + new GroupNode(name, repetition, fields, convertedType, fieldId)); } - // If no logical type, pass nullptr - // A field_id -1 (or any negative value) will be serialized as null in Thrift - static inline NodePtr Make( + // If no logical type, pass nullptr. + // A field_id -1 (or any negative value) will be serialized as null in Thrift. + static inline NodePtr make( const std::string& name, Repetition::type repetition, const NodeVector& fields, - std::shared_ptr logical_type, - int field_id = -1) { + std::shared_ptr logicalType, + int fieldId = -1) { return NodePtr( - new GroupNode(name, repetition, fields, logical_type, field_id)); + new GroupNode(name, repetition, fields, logicalType, fieldId)); } - bool Equals(const Node* other) const override; + bool equals(const Node* other) const override; const NodePtr& field(int i) const { return fields_[i]; @@ -373,80 +373,80 @@ class PARQUET_EXPORT GroupNode : public Node { // Get the index of a field by its name, or negative value if not found. // If several fields share the same name, it is unspecified which one // is returned. - int FieldIndex(const std::string& name) const; + int fieldIndex(const std::string& name) const; // Get the index of a field by its node, or negative value if not found. - int FieldIndex(const Node& node) const; + int fieldIndex(const Node& node) const; - int field_count() const { + int fieldCount() const { return static_cast(fields_.size()); } - void ToParquet(void* element) const override; - void Visit(Visitor* visitor) override; - void VisitConst(ConstVisitor* visitor) const override; + void toParquet(void* element) const override; + void visit(Visitor* visitor) override; + void visitConst(ConstVisitor* visitor) const override; /// \brief Return true if this node or any child node has REPEATED repetition - /// type - bool HasRepeatedFields() const; + /// type. + bool hasRepeatedFields() const; private: GroupNode( const std::string& name, Repetition::type repetition, const NodeVector& fields, - ConvertedType::type converted_type = ConvertedType::NONE, - int field_id = -1); + ConvertedType::type convertedType = ConvertedType::kNone, + int fieldId = -1); GroupNode( const std::string& name, Repetition::type repetition, const NodeVector& fields, - std::shared_ptr logical_type, - int field_id = -1); + std::shared_ptr logicalType, + int fieldId = -1); NodeVector fields_; - bool EqualsInternal(const GroupNode* other) const; + bool equalsInternal(const GroupNode* other) const; - // Mapping between field name to the field index - std::unordered_multimap field_name_to_idx_; + // Mapping between field name to the field index. + std::unordered_multimap fieldNameToIdx_; FRIEND_TEST(TestGroupNode, Attrs); - FRIEND_TEST(TestGroupNode, Equals); - FRIEND_TEST(TestGroupNode, FieldIndex); + FRIEND_TEST(TestGroupNode, equals); + FRIEND_TEST(TestGroupNode, fieldIndex); FRIEND_TEST(TestGroupNode, FieldIndexDuplicateName); }; -// ---------------------------------------------------------------------- -// Convenience primitive type factory functions - -#define PRIMITIVE_FACTORY(FuncName, TYPE) \ - static inline NodePtr FuncName( \ - const std::string& name, \ - Repetition::type repetition = Repetition::OPTIONAL, \ - int field_id = -1) { \ - return PrimitiveNode::Make( \ - name, \ - repetition, \ - Type::TYPE, \ - ConvertedType::NONE, \ - /*length=*/-1, \ - /*precision=*/-1, \ - /*scale=*/-1, \ - field_id); \ - } - -PRIMITIVE_FACTORY(Boolean, BOOLEAN) -PRIMITIVE_FACTORY(Int32, INT32) -PRIMITIVE_FACTORY(Int64, INT64) -PRIMITIVE_FACTORY(Int96, INT96) -PRIMITIVE_FACTORY(Float, FLOAT) -PRIMITIVE_FACTORY(Double, DOUBLE) -PRIMITIVE_FACTORY(ByteArray, BYTE_ARRAY) - -void PARQUET_EXPORT PrintSchema( +// ----------------------------------------------------------------------. +// Convenience primitive type factory functions. + +#define PRIMITIVE_FACTORY(funcName, TYPE) \ + static inline NodePtr funcName( \ + const std::string& name, \ + Repetition::type repetition = Repetition::kOptional, \ + int fieldId = -1) { \ + return PrimitiveNode::make( \ + name, \ + repetition, \ + Type::TYPE, \ + ConvertedType::kNone, \ + -1, \ + -1, \ + -1, \ + fieldId); \ + } + +PRIMITIVE_FACTORY(boolean, kBoolean) +PRIMITIVE_FACTORY(int32, kInt32) +PRIMITIVE_FACTORY(int64, kInt64) +PRIMITIVE_FACTORY(int96, kInt96) +PRIMITIVE_FACTORY(floatType, kFloat) +PRIMITIVE_FACTORY(doubleType, kDouble) +PRIMITIVE_FACTORY(byteArray, kByteArray) + +void PARQUET_EXPORT printSchema( const schema::Node* schema, std::ostream& stream, - int indent_width = 2); + int indentWidth = 2); } // namespace schema @@ -458,166 +458,165 @@ void PARQUET_EXPORT PrintSchema( class PARQUET_EXPORT ColumnDescriptor { public: ColumnDescriptor( - schema::NodePtr node, - int16_t max_definition_level, - int16_t max_repetition_level, - const SchemaDescriptor* schema_descr = NULLPTR); + schema::NodePtr Node, + int16_t maxDefinitionLevel, + int16_t maxRepetitionLevel, + const SchemaDescriptor* schemaDescr = NULLPTR); - bool Equals(const ColumnDescriptor& other) const; + bool equals(const ColumnDescriptor& other) const; - int16_t max_definition_level() const { - return max_definition_level_; + int16_t maxDefinitionLevel() const { + return maxDefinitionLevel_; } - int16_t max_repetition_level() const { - return max_repetition_level_; + int16_t maxRepetitionLevel() const { + return maxRepetitionLevel_; } - Type::type physical_type() const { - return primitive_node_->physical_type(); + Type::type physicalType() const { + return primitiveNode_->physicalType(); } - ConvertedType::type converted_type() const { - return primitive_node_->converted_type(); + ConvertedType::type convertedType() const { + return primitiveNode_->convertedType(); } - const std::shared_ptr& logical_type() const { - return primitive_node_->logical_type(); + const std::shared_ptr& logicalType() const { + return primitiveNode_->logicalType(); } - ColumnOrder column_order() const { - return primitive_node_->column_order(); + ColumnOrder columnOrder() const { + return primitiveNode_->columnOrder(); } - SortOrder::type sort_order() const { - auto la = logical_type(); - auto pt = physical_type(); - return la ? GetSortOrder(la, pt) : GetSortOrder(converted_type(), pt); + SortOrder::type sortOrder() const { + auto la = logicalType(); + auto pt = physicalType(); + return la ? getSortOrder(la, pt) : getSortOrder(convertedType(), pt); } const std::string& name() const { - return primitive_node_->name(); + return primitiveNode_->name(); } const std::shared_ptr path() const; - const schema::NodePtr& schema_node() const { + const schema::NodePtr& schemaNode() const { return node_; } - std::string ToString() const; + std::string toString() const; - int type_length() const; + int typeLength() const; - int type_precision() const; + int typePrecision() const; - int type_scale() const; + int typeScale() const; private: schema::NodePtr node_; - const schema::PrimitiveNode* primitive_node_; + const schema::PrimitiveNode* primitiveNode_; - int16_t max_definition_level_; - int16_t max_repetition_level_; + int16_t maxDefinitionLevel_; + int16_t maxRepetitionLevel_; }; // Container for the converted Parquet schema with a computed information from -// the schema analysis needed for file reading +// the schema analysis needed for file reading. // -// * Column index to Node -// * Max repetition / definition levels for each primitive node +// * Column index to Node. +// * Max repetition / definition levels for each primitive node. // // The ColumnDescriptor objects produced by this class can be used to assist in // the reconstruction of fully materialized data structures from the -// repetition-definition level encoding of nested data +// repetition-definition level encoding of nested data. // -// TODO(wesm): this object can be recomputed from a Schema +// TODO(wesm): This object can be recomputed from a Schema. class PARQUET_EXPORT SchemaDescriptor { public: SchemaDescriptor() {} ~SchemaDescriptor() {} - // Analyze the schema - void Init(std::unique_ptr schema); - void Init(schema::NodePtr schema); + // Analyze the schema. + void init(std::unique_ptr schema); + void init(schema::NodePtr schema); - const ColumnDescriptor* Column(int i) const; + const ColumnDescriptor* column(int i) const; // Get the index of a column by its dotstring path, or negative value if not // found. If several columns share the same dotstring path, it is unspecified // which one is returned. - int ColumnIndex(const std::string& node_path) const; + int columnIndex(const std::string& nodePath) const; // Get the index of a column by its node, or negative value if not found. - int ColumnIndex(const schema::Node& node) const; + int columnIndex(const schema::Node& node) const; - bool Equals( - const SchemaDescriptor& other, - std::ostream* diff_output = NULLPTR) const; + bool equals(const SchemaDescriptor& other, std::ostream* diffOutput = NULLPTR) + const; - // The number of physical columns appearing in the file - int num_columns() const { + // The number of physical columns appearing in the file. + int numColumns() const { return static_cast(leaves_.size()); } - const schema::NodePtr& schema_root() const { + const schema::NodePtr& schemaRoot() const { return schema_; } - const schema::GroupNode* group_node() const { - return group_node_; + const schema::GroupNode* groupNode() const { + return groupNode_; } - // Returns the root (child of the schema root) node of the leaf(column) node - const schema::Node* GetColumnRoot(int i) const; + // Returns the root (child of the schema root) node of the leaf (column) node. + const schema::Node* getColumnRoot(int i) const; const std::string& name() const { - return group_node_->name(); + return groupNode_->name(); } - std::string ToString() const; + std::string toString() const; - void updateColumnOrders(const std::vector& column_orders); + void updateColumnOrders(const std::vector& columnOrders); /// \brief Return column index corresponding to a particular - /// PrimitiveNode. Returns -1 if not found - int GetColumnIndex(const schema::PrimitiveNode& node) const; + /// PrimitiveNode. Returns -1 if not found. + int getColumnIndex(const schema::PrimitiveNode& node) const; - /// \brief Return true if any field or their children have REPEATED repetition - /// type - bool HasRepeatedFields() const; + /// \brief Return true if any field or their children have REPEATED + /// repetition type. + bool hasRepeatedFields() const; private: friend class ColumnDescriptor; - // Root Node + // Root Node. schema::NodePtr schema_; - // Root Node - const schema::GroupNode* group_node_; + // Root Node. + const schema::GroupNode* groupNode_; - void BuildTree( - const schema::NodePtr& node, - int16_t max_def_level, - int16_t max_rep_level, + void buildTree( + const schema::NodePtr& Node, + int16_t maxDefLevel, + int16_t maxRepLevel, const schema::NodePtr& base); - // Result of leaf node / tree analysis + // Result of leaf node / tree analysis. std::vector leaves_; - std::unordered_map node_to_leaf_index_; + std::unordered_map nodeToLeafIndex_; // Mapping between leaf nodes and root group of leaf (first node - // below the schema's root group) + // below the schema's root group). // - // For example, the leaf `a.b.c.d` would have a link back to `a` + // For example, the leaf `a.b.c.d` would have a link back to `a`. // - // -- a <------ - // -- -- b | - // -- -- -- c | - // -- -- -- -- d - std::unordered_map leaf_to_base_; - - // Mapping between ColumnPath DotString to the leaf index - std::unordered_multimap leaf_to_idx_; + // -- A <------. + // -- -- B |. + // -- -- -- C |. + // -- -- -- -- D. + std::unordered_map leafToBase_; + + // Mapping between ColumnPath DotString to the leaf index. + std::unordered_multimap leafToIdx_; }; } // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/SchemaInternal.h b/velox/dwio/parquet/writer/arrow/SchemaInternal.h index 7b8e9e5c073..85eae7b7e87 100644 --- a/velox/dwio/parquet/writer/arrow/SchemaInternal.h +++ b/velox/dwio/parquet/writer/arrow/SchemaInternal.h @@ -16,7 +16,7 @@ // Adapted from Apache Arrow. -// Non-public Thrift schema serialization utilities +// Non-public Thrift schema serialization utilities. #pragma once @@ -35,23 +35,23 @@ class SchemaElement; namespace schema { -// ---------------------------------------------------------------------- -// Conversion from Parquet Thrift metadata +// ----------------------------------------------------------------------. +// Conversion from Parquet Thrift metadata. PARQUET_EXPORT -std::shared_ptr FromParquet( +std::shared_ptr fromParquet( const std::vector& schema); PARQUET_EXPORT -std::unique_ptr Unflatten( +std::unique_ptr unflatten( const facebook::velox::parquet::thrift::SchemaElement* elements, int length); -// ---------------------------------------------------------------------- -// Conversion to Parquet Thrift metadata +// ----------------------------------------------------------------------. +// Conversion to Parquet Thrift metadata. PARQUET_EXPORT -void ToParquet( +void toParquet( const GroupNode* schema, std::vector* out); diff --git a/velox/dwio/parquet/writer/arrow/Statistics.cpp b/velox/dwio/parquet/writer/arrow/Statistics.cpp index 757ac36b62e..3d88fc41d33 100644 --- a/velox/dwio/parquet/writer/arrow/Statistics.cpp +++ b/velox/dwio/parquet/writer/arrow/Statistics.cpp @@ -39,6 +39,10 @@ #include "velox/dwio/parquet/writer/arrow/Exception.h" #include "velox/dwio/parquet/writer/arrow/Platform.h" #include "velox/dwio/parquet/writer/arrow/Schema.h" +#include "velox/dwio/parquet/writer/arrow/StringTruncation.h" + +#include "velox/type/DecimalUtil.h" +#include "velox/type/HugeInt.h" using arrow::default_memory_pool; using arrow::MemoryPool; @@ -53,21 +57,21 @@ inline std::enable_if_t< std::is_trivially_copyable_v && std::is_trivially_copyable_v && sizeof(T) == sizeof(U), U> -SafeCopy(T value) { +safeCopy(T value) { std::remove_const_t ret; std::memcpy(&ret, &value, sizeof(T)); return ret; } template -inline std::enable_if_t, T> SafeLoad( +inline std::enable_if_t, T> safeLoad( const T* unaligned) { std::remove_const_t ret; std::memcpy(&ret, unaligned, sizeof(T)); return ret; } -std::shared_ptr AllocateBuffer( +std::shared_ptr allocateBuffer( MemoryPool* pool, int64_t size) { PARQUET_ASSIGN_OR_THROW( @@ -76,94 +80,99 @@ std::shared_ptr AllocateBuffer( } // ---------------------------------------------------------------------- -// Comparator implementations +// Comparator implementations. -constexpr int value_length(int value_length, const ByteArray& value) { +constexpr int valueLength(int valueLength, const ByteArray& value) { return value.len; } -constexpr int value_length(int type_length, const FLBA& value) { - return type_length; +constexpr int valueLength(int typeLength, const FLBA& value) { + return typeLength; } -template +template struct CompareHelper { - using T = typename DType::c_type; + using T = typename DType::CType; static_assert( !std::is_unsigned::value || std::is_same::value, "T is an unsigned numeric"); - constexpr static T DefaultMin() { + constexpr static T defaultMin() { + if constexpr (std::is_floating_point_v) { + return std::numeric_limits::infinity(); + } return std::numeric_limits::max(); } - constexpr static T DefaultMax() { - return std::numeric_limits::lowest(); + constexpr static T defaultMax() { + if constexpr (std::is_floating_point_v) { + return -std::numeric_limits::infinity(); + } + return std::numeric_limits::min(); } // MSVC17 fix, isnan is not overloaded for IntegralType as per C++11 // standard requirements. template - static ::arrow::enable_if_t::value, T> Coalesce( + static ::arrow::enable_if_t::value, T> coalesce( T val, T fallback) { return std::isnan(val) ? fallback : val; } template - static ::arrow::enable_if_t::value, T> Coalesce( + static ::arrow::enable_if_t::value, T> coalesce( T val, T fallback) { return val; } - static inline bool Compare(int type_length, const T& a, const T& b) { + static inline bool compare(int typeLength, const T& a, const T& b) { return a < b; } - static T Min(int type_length, T a, T b) { + static T min(int typeLength, T a, T b) { return a < b ? a : b; } - static T Max(int type_length, T a, T b) { + static T max(int typeLength, T a, T b) { return a < b ? b : a; } }; template struct UnsignedCompareHelperBase { - using T = typename DType::c_type; + using T = typename DType::CType; using UCType = typename std::make_unsigned::type; static_assert(!std::is_same::value, "T is unsigned"); static_assert(sizeof(T) == sizeof(UCType), "T and UCType not the same size"); - // NOTE: according to the C++ spec, unsigned-to-signed conversion is + // NOTE: According to the C++ spec, unsigned-to-signed conversion is // implementation-defined if the original value does not fit in the signed - // type (i.e., two's complement cannot be assumed even on mainstream machines, - // because the compiler may decide otherwise). Hence the use of `SafeCopy` - // below for deterministic bit-casting. - // (see "Integer conversions" in - // https://en.cppreference.com/w/cpp/language/implicit_conversion) + // type (i.e., two's complement cannot be assumed even on mainstream + // machines, because the compiler may decide otherwise). Hence the use of + // `safeCopy` below for deterministic bit-casting. (See "Integer conversions" + // in https://en.cppreference.com/w/cpp/language/implicit_conversion). - static const T DefaultMin() { - return SafeCopy(std::numeric_limits::max()); + static const T defaultMin() { + return safeCopy(std::numeric_limits::max()); } - static const T DefaultMax() { + static const T defaultMax() { return 0; } - static T Coalesce(T val, T fallback) { + static T coalesce(T val, T fallback) { return val; } - static bool Compare(int type_length, T a, T b) { - return SafeCopy(a) < SafeCopy(b); + static bool compare(int typeLength, T a, T b) { + return safeCopy(a) < safeCopy(b); } - static T Min(int type_length, T a, T b) { - return Compare(type_length, a, b) ? a : b; + static T min(int typeLength, T a, T b) { + return compare(typeLength, a, b) ? a : b; } - static T Max(int type_length, T a, T b) { - return Compare(type_length, a, b) ? b : a; + static T max(int typeLength, T a, T b) { + return compare(typeLength, a, b) ? b : a; } }; @@ -175,208 +184,207 @@ template <> struct CompareHelper : public UnsignedCompareHelperBase {}; -template -struct CompareHelper { - using T = typename Int96Type::c_type; - using msb_type = - typename std::conditional::type; +template +struct CompareHelper { + using T = typename Int96Type::CType; + using MsbType = typename std::conditional::type; - static T DefaultMin() { - uint32_t kMsbMax = SafeCopy(std::numeric_limits::max()); + static T defaultMin() { + uint32_t kMsbMax = safeCopy(std::numeric_limits::max()); uint32_t kMax = std::numeric_limits::max(); return {kMax, kMax, kMsbMax}; } - static T DefaultMax() { - uint32_t kMsbMin = SafeCopy(std::numeric_limits::min()); + static T defaultMax() { + uint32_t kMsbMin = safeCopy(std::numeric_limits::min()); uint32_t kMin = std::numeric_limits::min(); return {kMin, kMin, kMsbMin}; } - static T Coalesce(T val, T fallback) { + static T coalesce(T val, T fallback) { return val; } - static inline bool Compare(int type_length, const T& a, const T& b) { + static inline bool compare(int typeLength, const T& a, const T& b) { if (a.value[2] != b.value[2]) { - // Only the MSB bit is by Signed comparison. For little-endian, this is + // Only the MSB bit is by signed comparison. For little-endian, this is // the last bit of Int96 type. - return SafeCopy(a.value[2]) < SafeCopy(b.value[2]); + return safeCopy(a.value[2]) < safeCopy(b.value[2]); } else if (a.value[1] != b.value[1]) { return (a.value[1] < b.value[1]); } return (a.value[0] < b.value[0]); } - static T Min(int type_length, const T& a, const T& b) { - return Compare(0, a, b) ? a : b; + static T min(int typeLength, const T& a, const T& b) { + return compare(0, a, b) ? a : b; } - static T Max(int type_length, const T& a, const T& b) { - return Compare(0, a, b) ? b : a; + static T max(int typeLength, const T& a, const T& b) { + return compare(0, a, b) ? b : a; } }; -template +template struct BinaryLikeComparer {}; template -struct BinaryLikeComparer { - static bool Compare(int type_length, const T& a, const T& b) { - int a_length = value_length(type_length, a); - int b_length = value_length(type_length, b); +struct BinaryLikeComparer { + static bool compare(int typeLength, const T& a, const T& b) { + int aLength = valueLength(typeLength, a); + int bLength = valueLength(typeLength, b); // Unsigned comparison is used for non-numeric types so straight - // lexicographic comparison makes sense. (a.ptr is always unsigned).... + // lexicographic comparison makes sense (a.ptr is always unsigned). return std::lexicographical_compare( - a.ptr, a.ptr + a_length, b.ptr, b.ptr + b_length); + a.ptr, a.ptr + aLength, b.ptr, b.ptr + bLength); } }; template -struct BinaryLikeComparer { - static bool Compare(int type_length, const T& a, const T& b) { +struct BinaryLikeComparer { + static bool compare(int typeLength, const T& a, const T& b) { // Is signed is used for integers encoded as big-endian twos - // complement integers. (e.g. decimals). - int a_length = value_length(type_length, a); - int b_length = value_length(type_length, b); + // complement integers (e.g., decimals). + int aLength = valueLength(typeLength, a); + int bLength = valueLength(typeLength, b); // At least of the lengths is zero. - if (a_length == 0 || b_length == 0) { - return a_length == 0 && b_length > 0; + if (aLength == 0 || bLength == 0) { + return aLength == 0 && bLength > 0; } - int8_t first_a = *a.ptr; - int8_t first_b = *b.ptr; + int8_t firstA = *a.ptr; + int8_t firstB = *b.ptr; // We can short circuit for different signed numbers or // for equal length bytes arrays that have different first bytes. // The equality requirement is necessary for sign extension cases. - // 0xFF10 should be equal to 0x10 (due to big endian sign extension). - if ((0x80 & first_a) != (0x80 & first_b) || - (a_length == b_length && first_a != first_b)) { - return first_a < first_b; + // 0xff10 should be equal to 0x10 (due to big-endian sign extension). + if ((0x80 & firstA) != (0x80 & firstB) || + (aLength == bLength && firstA != firstB)) { + return firstA < firstB; } // When the lengths are unequal and the numbers are of the same // sign we need to do comparison by sign extending the shorter // value first, and once we get to equal sized arrays, lexicographical // unsigned comparison of everything but the first byte is sufficient. - const uint8_t* a_start = a.ptr; - const uint8_t* b_start = b.ptr; - if (a_length != b_length) { - const uint8_t* lead_start = nullptr; - const uint8_t* lead_end = nullptr; - if (a_length > b_length) { - int lead_length = a_length - b_length; - lead_start = a.ptr; - lead_end = a.ptr + lead_length; - a_start += lead_length; + const uint8_t* aStart = a.ptr; + const uint8_t* bStart = b.ptr; + if (aLength != bLength) { + const uint8_t* leadStart = nullptr; + const uint8_t* leadEnd = nullptr; + if (aLength > bLength) { + int leadLength = aLength - bLength; + leadStart = a.ptr; + leadEnd = a.ptr + leadLength; + aStart += leadLength; } else { - VELOX_DCHECK_LT(a_length, b_length); - int lead_length = b_length - a_length; - lead_start = b.ptr; - lead_end = b.ptr + lead_length; - b_start += lead_length; + VELOX_DCHECK_LT(aLength, bLength); + int leadLength = bLength - aLength; + leadStart = b.ptr; + leadEnd = b.ptr + leadLength; + bStart += leadLength; } // Compare extra bytes to the sign extension of the first // byte of the other number. - uint8_t extension = first_a < 0 ? 0xFF : 0; - bool not_equal = - std::any_of(lead_start, lead_end, [extension](uint8_t a) { - return extension != a; - }); - if (not_equal) { + uint8_t extension = firstA < 0 ? 0xFF : 0; + bool notEqual = std::any_of(leadStart, leadEnd, [extension](uint8_t a) { + return extension != a; + }); + if (notEqual) { // Since sign extension are extrema values for unsigned bytes: // // Four cases exist: - // negative values: - // b is the longer value. - // b must be the lesser value: return false - // else: - // a must be the lesser value: return true + // Negative values: + // B is the longer value. + // B must be the lesser value: return false. + // Else: + // A must be the lesser value: return true. // - // positive values: - // b is the longer value. - // values in b must be greater than a: return true - // else: - // values in a must be greater than b: return false - bool negative_values = first_a < 0; - bool b_longer = a_length < b_length; - return negative_values != b_longer; + // Positive values: + // B is the longer value. + // Values in b must be greater than a: return true. + // Else: + // Values in a must be greater than b: return false. + bool negativeValues = firstA < 0; + bool bLonger = aLength < bLength; + return negativeValues != bLonger; } } else { - a_start++; - b_start++; + aStart++; + bStart++; } return std::lexicographical_compare( - a_start, a.ptr + a_length, b_start, b.ptr + b_length); + aStart, a.ptr + aLength, bStart, b.ptr + bLength); } }; -template +template struct BinaryLikeCompareHelperBase { - using T = typename DType::c_type; + using T = typename DType::CType; - static T DefaultMin() { + static T defaultMin() { return {}; } - static T DefaultMax() { + static T defaultMax() { return {}; } - static T Coalesce(T val, T fallback) { + static T coalesce(T val, T fallback) { return val; } - static inline bool Compare(int type_length, const T& a, const T& b) { - return BinaryLikeComparer::Compare(type_length, a, b); + static inline bool compare(int typeLength, const T& a, const T& b) { + return BinaryLikeComparer::compare(typeLength, a, b); } - static T Min(int type_length, const T& a, const T& b) { + static T min(int typeLength, const T& a, const T& b) { if (a.ptr == nullptr) return b; if (b.ptr == nullptr) return a; - return Compare(type_length, a, b) ? a : b; + return compare(typeLength, a, b) ? a : b; } - static T Max(int type_length, const T& a, const T& b) { + static T max(int typeLength, const T& a, const T& b) { if (a.ptr == nullptr) return b; if (b.ptr == nullptr) return a; - return Compare(type_length, a, b) ? b : a; + return compare(typeLength, a, b) ? b : a; } }; -template -struct CompareHelper - : public BinaryLikeCompareHelperBase {}; +template +struct CompareHelper + : public BinaryLikeCompareHelperBase {}; -template -struct CompareHelper - : public BinaryLikeCompareHelperBase {}; +template +struct CompareHelper + : public BinaryLikeCompareHelperBase {}; using ::std::optional; template ::arrow::enable_if_t::value, optional>> -CleanStatistic(std::pair min_max) { - return min_max; +cleanStatistic(std::pair minMax) { + return minMax; } // In case of floating point types, the following rules are applied (as per // upstream parquet-mr): // - If any of min/max is NaN, return nothing. -// - If min is 0.0f, replace with -0.0f -// - If max is -0.0f, replace with 0.0f +// - If min is infinity and max is -infinity, return nothing. +// - If min is 0.0f, replace with -0.0f. +// - If max is -0.0f, replace with 0.0f. template ::arrow:: enable_if_t::value, optional>> - CleanStatistic(std::pair min_max) { - T min = min_max.first; - T max = min_max.second; + cleanStatistic(std::pair minMax) { + T min = minMax.first; + T max = minMax.second; // Ignore if one of the value is nan. if (std::isnan(min) || std::isnan(max)) { return ::std::nullopt; } - if (min == std::numeric_limits::max() && - max == std::numeric_limits::lowest()) { + if (min == std::numeric_limits::infinity() && + max == -std::numeric_limits::infinity()) { return ::std::nullopt; } @@ -393,141 +401,139 @@ ::arrow:: return {{min, max}}; } -optional> CleanStatistic(std::pair min_max) { - if (min_max.first.ptr == nullptr || min_max.second.ptr == nullptr) { +optional> cleanStatistic(std::pair minMax) { + if (minMax.first.ptr == nullptr || minMax.second.ptr == nullptr) { return ::std::nullopt; } - return min_max; + return minMax; } -optional> CleanStatistic( - std::pair min_max) { - if (min_max.first.ptr == nullptr || min_max.second.ptr == nullptr) { +optional> cleanStatistic( + std::pair minMax) { + if (minMax.first.ptr == nullptr || minMax.second.ptr == nullptr) { return ::std::nullopt; } - return min_max; + return minMax; } -template +template class TypedComparatorImpl : virtual public TypedComparator { public: - using T = typename DType::c_type; - using Helper = CompareHelper; + using T = typename DType::CType; + using Helper = CompareHelper; - explicit TypedComparatorImpl(int type_length = -1) - : type_length_(type_length) {} + explicit TypedComparatorImpl(int typeLength = -1) : typeLength_(typeLength) {} - bool CompareInline(const T& a, const T& b) const { - return Helper::Compare(type_length_, a, b); + bool compareInline(const T& a, const T& b) const { + return Helper::compare(typeLength_, a, b); } - bool Compare(const T& a, const T& b) override { - return CompareInline(a, b); + bool compare(const T& a, const T& b) override { + return compareInline(a, b); } - std::pair GetMinMax(const T* values, int64_t length) override { + std::pair getMinMax(const T* values, int64_t length) override { VELOX_DCHECK_GT(length, 0); - T min = Helper::DefaultMin(); - T max = Helper::DefaultMax(); + T min = Helper::defaultMin(); + T max = Helper::defaultMax(); for (int64_t i = 0; i < length; i++) { - const auto val = SafeLoad(values + i); - min = Helper::Min( - type_length_, min, Helper::Coalesce(val, Helper::DefaultMin())); - max = Helper::Max( - type_length_, max, Helper::Coalesce(val, Helper::DefaultMax())); + const auto val = safeLoad(values + i); + min = Helper::min( + typeLength_, min, Helper::coalesce(val, Helper::defaultMin())); + max = Helper::max( + typeLength_, max, Helper::coalesce(val, Helper::defaultMax())); } return {min, max}; } - std::pair GetMinMaxSpaced( + std::pair getMinMaxSpaced( const T* values, int64_t length, - const uint8_t* valid_bits, - int64_t valid_bits_offset) override { + const uint8_t* validBits, + int64_t validBitsOffset) override { VELOX_DCHECK_GT(length, 0); - T min = Helper::DefaultMin(); - T max = Helper::DefaultMax(); + T min = Helper::defaultMin(); + T max = Helper::defaultMax(); ::arrow::internal::VisitSetBitRunsVoid( - valid_bits, - valid_bits_offset, + validBits, + validBitsOffset, length, [&](int64_t position, int64_t length) { for (int64_t i = 0; i < length; i++) { - const auto val = SafeLoad(values + i + position); - min = Helper::Min( - type_length_, min, Helper::Coalesce(val, Helper::DefaultMin())); - max = Helper::Max( - type_length_, max, Helper::Coalesce(val, Helper::DefaultMax())); + const auto val = safeLoad(values + i + position); + min = Helper::min( + typeLength_, min, Helper::coalesce(val, Helper::defaultMin())); + max = Helper::max( + typeLength_, max, Helper::coalesce(val, Helper::defaultMax())); } }); return {min, max}; } - std::pair GetMinMax(const ::arrow::Array& values) override; + std::pair getMinMax(const ::arrow::Array& values) override; private: - int type_length_; + int typeLength_; }; -// ARROW-11675: A hand-written version of GetMinMax(), to work around +// ARROW-11675: A hand-written version of getMinMax(), to work around // what looks like a MSVC code generation bug. -// This does not seem to be required for GetMinMaxSpaced(). +// This does not seem to be required for getMinMaxSpaced(). template <> -std::pair -TypedComparatorImpl::GetMinMax( +std::pair TypedComparatorImpl::getMinMax( const int32_t* values, int64_t length) { VELOX_DCHECK_GT(length, 0); - const uint32_t* unsigned_values = reinterpret_cast(values); + const uint32_t* unsignedValues = reinterpret_cast(values); uint32_t min = std::numeric_limits::max(); uint32_t max = std::numeric_limits::lowest(); for (int64_t i = 0; i < length; i++) { - const auto val = unsigned_values[i]; + const auto val = unsignedValues[i]; min = std::min(min, val); max = std::max(max, val); } - return {SafeCopy(min), SafeCopy(max)}; + return {safeCopy(min), safeCopy(max)}; } -template -std::pair -TypedComparatorImpl::GetMinMax(const ::arrow::Array& values) { +template +std::pair +TypedComparatorImpl::getMinMax(const ::arrow::Array& values) { ParquetException::NYI(values.type()->ToString()); } -template -std::pair GetMinMaxBinaryHelper( - const TypedComparatorImpl& comparator, +template +std::pair getMinMaxBinaryHelper( + const TypedComparatorImpl& Comparator, const ::arrow::Array& values) { - using Helper = CompareHelper; + using Helper = CompareHelper; - ByteArray min = Helper::DefaultMin(); - ByteArray max = Helper::DefaultMax(); - constexpr int type_length = -1; + ByteArray min = Helper::defaultMin(); + ByteArray max = Helper::defaultMax(); + constexpr int typeLength = -1; - const auto valid_func = [&](std::string_view val) { + const auto validFunc = [&](std::string_view val) { ByteArray ba{std::string_view(val.data(), val.size())}; - min = Helper::Min(type_length, ba, min); - max = Helper::Max(type_length, ba, max); + min = Helper::min(typeLength, ba, min); + max = Helper::max(typeLength, ba, max); }; - const auto null_func = [&]() {}; + const auto nullFunc = [&]() {}; if (::arrow::is_binary_like(values.type_id())) { ::arrow::VisitArraySpanInline<::arrow::BinaryType>( - *values.data(), std::move(valid_func), std::move(null_func)); + *values.data(), std::move(validFunc), std::move(nullFunc)); } else { VELOX_DCHECK(::arrow::is_large_binary_like(values.type_id())); ::arrow::VisitArraySpanInline<::arrow::LargeBinaryType>( - *values.data(), std::move(valid_func), std::move(null_func)); + *values.data(), std::move(validFunc), std::move(nullFunc)); } return {min, max}; @@ -535,181 +541,215 @@ std::pair GetMinMaxBinaryHelper( template <> std::pair -TypedComparatorImpl::GetMinMax( +TypedComparatorImpl::getMinMax( const ::arrow::Array& values) { - return GetMinMaxBinaryHelper(*this, values); + return getMinMaxBinaryHelper(*this, values); } template <> std::pair -TypedComparatorImpl::GetMinMax( +TypedComparatorImpl::getMinMax( const ::arrow::Array& values) { - return GetMinMaxBinaryHelper(*this, values); + return getMinMaxBinaryHelper(*this, values); +} + +template +std::string encodeDecimalToBigEndian(T value) { + uint8_t buffer[sizeof(T)]; + if constexpr (std::is_same_v) { + *reinterpret_cast(buffer) = ::arrow::bit_util::ToBigEndian(value); + } else if constexpr (std::is_same_v) { + *reinterpret_cast(buffer) = DecimalUtil::bigEndian(value); + } + return std::string(reinterpret_cast(buffer), sizeof(T)); } template class TypedStatisticsImpl : public TypedStatistics { public: - using T = typename DType::c_type; + using T = typename DType::CType; - // Create an empty stats. + // Create an empty statistics. TypedStatisticsImpl(const ColumnDescriptor* descr, MemoryPool* pool) : descr_(descr), pool_(pool), - min_buffer_(AllocateBuffer(pool_, 0)), - max_buffer_(AllocateBuffer(pool_, 0)) { - auto comp = Comparator::Make(descr); + minBuffer_(allocateBuffer(pool_, 0)), + maxBuffer_(allocateBuffer(pool_, 0)) { + auto comp = Comparator::make(descr); comparator_ = std::static_pointer_cast>(comp); - TypedStatisticsImpl::Reset(); + TypedStatisticsImpl::reset(); } // Create stats from provided values. TypedStatisticsImpl( const T& min, const T& max, - int64_t num_values, - int64_t null_count, - int64_t distinct_count) + int64_t numValues, + int64_t nullCount, + int64_t distinctCount) : pool_(default_memory_pool()), - min_buffer_(AllocateBuffer(pool_, 0)), - max_buffer_(AllocateBuffer(pool_, 0)) { - TypedStatisticsImpl::IncrementNumValues(num_values); - TypedStatisticsImpl::IncrementNullCount(null_count); - SetDistinctCount(distinct_count); + minBuffer_(allocateBuffer(pool_, 0)), + maxBuffer_(allocateBuffer(pool_, 0)) { + TypedStatisticsImpl::incrementNumValues(numValues); + TypedStatisticsImpl::incrementNullCount(nullCount); + setDistinctCount(distinctCount); - Copy(min, &min_, min_buffer_.get()); - Copy(max, &max_, max_buffer_.get()); - has_min_max_ = true; + copy(min, &min_, minBuffer_.get()); + copy(max, &max_, maxBuffer_.get()); + hasMinMax_ = true; } // Create stats from a thrift Statistics object. TypedStatisticsImpl( const ColumnDescriptor* descr, - const std::string& encoded_min, - const std::string& encoded_max, - int64_t num_values, - int64_t null_count, - int64_t distinct_count, - bool has_min_max, - bool has_null_count, - bool has_distinct_count, + const std::string& encodedMin, + const std::string& encodedMax, + int64_t numValues, + int64_t nullCount, + int64_t distinctCount, + bool hasMinMax, + bool hasNullCount, + bool hasDistinctCount, + bool hasNaNCount, + int64_t nanCount, MemoryPool* pool) : TypedStatisticsImpl(descr, pool) { - TypedStatisticsImpl::IncrementNumValues(num_values); - if (has_null_count) { - TypedStatisticsImpl::IncrementNullCount(null_count); + TypedStatisticsImpl::incrementNumValues(numValues); + if (hasNullCount) { + TypedStatisticsImpl::incrementNullCount(nullCount); } else { - has_null_count_ = false; + hasNullCount_ = false; } - if (has_distinct_count) { - SetDistinctCount(distinct_count); + if (hasDistinctCount) { + setDistinctCount(distinctCount); } else { - has_distinct_count_ = false; + hasDistinctCount_ = false; } - if (!encoded_min.empty()) { - PlainDecode(encoded_min, &min_); + if (hasNaNCount) { + incrementNaNValues(nanCount); + } else { + hasNanCount_ = false; + } + + if (!encodedMin.empty()) { + plainDecode(encodedMin, &min_); } - if (!encoded_max.empty()) { - PlainDecode(encoded_max, &max_); + if (!encodedMax.empty()) { + plainDecode(encodedMax, &max_); } - has_min_max_ = has_min_max; + hasMinMax_ = hasMinMax; } - bool HasDistinctCount() const override { - return has_distinct_count_; + bool hasDistinctCount() const override { + return hasDistinctCount_; }; - bool HasMinMax() const override { - return has_min_max_; + bool hasMinMax() const override { + return hasMinMax_; } - bool HasNullCount() const override { - return has_null_count_; + bool hasNullCount() const override { + return hasNullCount_; }; - void IncrementNullCount(int64_t n) override { - statistics_.null_count += n; - has_null_count_ = true; + bool hasNaNCount() const override { + return hasNanCount_; } - void IncrementNumValues(int64_t n) override { - num_values_ += n; + void incrementNullCount(int64_t n) override { + statistics_.nullCount += n; + hasNullCount_ = true; } - bool Equals(const Statistics& raw_other) const override { - if (physical_type() != raw_other.physical_type()) + void incrementNumValues(int64_t n) override { + numValues_ += n; + } + + void incrementNaNValues(int64_t n) override { + if (n > 0) { + nanCount_ += n; + hasNanCount_ = true; + } + } + + bool equals(const Statistics& rawOther) const override { + if (physicalType() != rawOther.physicalType()) return false; - const auto& other = checked_cast(raw_other); + const auto& other = checked_cast(rawOther); - if (has_min_max_ != other.has_min_max_) + if (hasMinMax_ != other.hasMinMax_) return false; - if (has_min_max_) { - if (!MinMaxEqual(other)) + if (hasMinMax_) { + if (!minMaxEqual(other)) return false; } - return null_count() == other.null_count() && - distinct_count() == other.distinct_count() && - num_values() == other.num_values(); + return nullCount() == other.nullCount() && + distinctCount() == other.distinctCount() && + numValues() == other.numValues(); } - bool MinMaxEqual(const TypedStatisticsImpl& other) const; + bool minMaxEqual(const TypedStatisticsImpl& other) const; - void Reset() override { - ResetCounts(); - ResetHasFlags(); + void reset() override { + resetCounts(); + resetHasFlags(); } - void SetMinMax(const T& arg_min, const T& arg_max) override { - SetMinMaxPair({arg_min, arg_max}); + void setMinMax(const T& argMin, const T& argMax) override { + setMinMaxPair({argMin, argMax}); } - void Merge(const TypedStatistics& other) override { - this->num_values_ += other.num_values(); - // null_count is always valid when merging page statistics into + void merge(const TypedStatistics& other) override { + this->numValues_ += other.numValues(); + // nullCount is always valid when merging page statistics into // column chunk statistics. - if (other.HasNullCount()) { - this->statistics_.null_count += other.null_count(); + if (other.hasNullCount()) { + this->statistics_.nullCount += other.nullCount(); } else { - this->has_null_count_ = false; + this->hasNullCount_ = false; + } + if (other.hasNaNCount()) { + this->nanCount_ += other.nanCount(); + this->hasNanCount_ = true; } - if (has_distinct_count_ && other.HasDistinctCount() && - (distinct_count() == 0 || other.distinct_count() == 0)) { + if (hasDistinctCount_ && other.hasDistinctCount() && + (distinctCount() == 0 || other.distinctCount() == 0)) { // We can merge distinct counts if either side is zero. - statistics_.distinct_count = - std::max(statistics_.distinct_count, other.distinct_count()); + statistics_.distinctCount = + std::max(statistics_.distinctCount, other.distinctCount()); } else { - // Otherwise clear has_distinct_count_ as distinct count cannot be merged. - this->has_distinct_count_ = false; + // Otherwise clear hasDistinctCount_ as distinct count cannot be merged. + this->hasDistinctCount_ = false; } // Do not clear min/max here if the other side does not provide // min/max which may happen when other is an empty stats or all // its values are null and/or NaN. - if (other.HasMinMax()) { - SetMinMax(other.min(), other.max()); + if (other.hasMinMax()) { + setMinMax(other.min(), other.max()); } } - void Update(const T* values, int64_t num_values, int64_t null_count) override; - void UpdateSpaced( + void update(const T* values, int64_t numValues, int64_t nullCount) override; + void updateSpaced( const T* values, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - int64_t num_spaced_values, - int64_t num_values, - int64_t null_count) override; - - void Update(const ::arrow::Array& values, bool update_counts) override { - if (update_counts) { - IncrementNullCount(values.null_count()); - IncrementNumValues(values.length() - values.null_count()); + const uint8_t* validBits, + int64_t validBitsOffset, + int64_t numSpacedValues, + int64_t numValues, + int64_t nullCount) override; + + void update(const ::arrow::Array& values, bool updateCounts) override { + if (updateCounts) { + incrementNullCount(values.null_count()); + incrementNumValues(values.length() - values.null_count()); } if (values.null_count() == values.length()) { return; } - SetMinMaxPair(comparator_->GetMinMax(values)); + setMinMaxPair(comparator_->getMinMax(values)); } const T& min() const override { @@ -720,158 +760,296 @@ class TypedStatisticsImpl : public TypedStatistics { return max_; } - Type::type physical_type() const override { - return descr_->physical_type(); + Type::type physicalType() const override { + return descr_->physicalType(); } const ColumnDescriptor* descr() const override { return descr_; } - std::string EncodeMin() const override { + std::string encodeMin() const override { std::string s; - if (HasMinMax()) - this->PlainEncode(min_, &s); + if (hasMinMax()) + this->plainEncode(min_, &s); return s; } - std::string EncodeMax() const override { + std::string encodeMax() const override { std::string s; - if (HasMinMax()) - this->PlainEncode(max_, &s); + if (hasMinMax()) + this->plainEncode(max_, &s); return s; } - EncodedStatistics Encode() override { + std::string icebergLowerBoundInclusive(int32_t truncateTo) const override { + if constexpr (std::is_same_v) { + if (descr_->logicalType()->isDecimal()) { + return encodeDecimalToBigEndian(min_); + } + } + if constexpr (std::is_same_v) { + return encodeDecimalToBigEndian(min_); + } + if constexpr (std::is_same_v) { + // STRING columns truncate by UTF-8 code points; BINARY/VARBINARY + // columns truncate by raw bytes (mirrors the upper-bound dispatch). + const std::string_view minView(min_); + const auto truncatedMin = descr_->logicalType()->isString() + ? truncateUtf8(minView, truncateTo) + : minView.substr( + 0, + std::min( + minView.size(), + static_cast(std::max(truncateTo, 0)))); + std::string s; + this->plainEncode( + ByteArray( + truncatedMin.size(), + reinterpret_cast(truncatedMin.data())), + &s); + return s; + } + return encodeMin(); + } + + std::optional icebergUpperBoundExclusive( + int32_t truncateTo) const override { + if constexpr (std::is_same_v) { + if (descr_->logicalType()->isDecimal()) { + return encodeDecimalToBigEndian(max_); + } + } + if constexpr (std::is_same_v) { + return encodeDecimalToBigEndian(max_); + } + if constexpr (std::is_same_v) { + // For ByteArray, we need to determine if this is UTF-8 text (STRING) + // or raw binary data (BINARY/VARBINARY). The Parquet logical type tells + // us this. + const bool isUtf8String = descr_->logicalType()->isString(); + + std::optional truncatedMax; + + if (isUtf8String) { + // Use UTF-8 string logic for STRING type + truncatedMax = roundUpUtf8(std::string_view(max_), truncateTo); + } else { + // Use binary byte logic for BINARY type (VARBINARY) + // Implementation follows Apache Iceberg's + // BinaryUtil.truncateBinaryMax() + truncatedMax = roundUpBinary(std::string_view(max_), truncateTo); + } + + if (!truncatedMax.has_value()) { + return std::nullopt; + } + std::string s; + this->plainEncode( + ByteArray( + truncatedMax->size(), + reinterpret_cast(truncatedMax->data())), + &s); + return s; + } + return encodeMax(); + } + + EncodedStatistics encode() override { EncodedStatistics s; - if (HasMinMax()) { - s.set_min(this->EncodeMin()); - s.set_max(this->EncodeMax()); + if (hasMinMax()) { + s.setMin(this->encodeMin()); + s.setMax(this->encodeMax()); } - if (HasNullCount()) { - s.set_null_count(this->null_count()); - // num_values_ is reliable and it means number of non-null values. - s.all_null_value = num_values_ == 0; + if (hasNullCount()) { + s.setNullCount(this->nullCount()); + // numValues_ is reliable and it means the number of non-null values. + s.allNullValue = numValues_ == 0; } - if (HasDistinctCount()) { - s.set_distinct_count(this->distinct_count()); + if (hasDistinctCount()) { + s.setDistinctCount(this->distinctCount()); + } + if (hasNanCount_) { + s.set_nan_count(nanCount_); } return s; } - int64_t null_count() const override { - return statistics_.null_count; + int64_t nullCount() const override { + return statistics_.nullCount; + } + int64_t distinctCount() const override { + return statistics_.distinctCount; + } + int64_t numValues() const override { + return numValues_; + } + + int64_t nanCount() const override { + return nanCount_; } - int64_t distinct_count() const override { - return statistics_.distinct_count; + + bool maxGreaterThan(const Statistics& other) const override { + const auto* typedOther = + dynamic_cast*>(&other); + return comparator_->compare(max_, typedOther->max_) ? false : true; } - int64_t num_values() const override { - return num_values_; + + bool minLessThan(const Statistics& other) const override { + const auto* typedOther = + dynamic_cast*>(&other); + return comparator_->compare(min_, typedOther->min_) ? true : false; } private: const ColumnDescriptor* descr_; - bool has_min_max_ = false; - bool has_null_count_ = false; - bool has_distinct_count_ = false; + bool hasMinMax_ = false; + bool hasNullCount_ = false; + bool hasDistinctCount_ = false; + bool hasNanCount_ = false; T min_; T max_; ::arrow::MemoryPool* pool_; // Number of non-null values. - // Please note that num_values_ is reliable when has_null_count_ is set. - // When has_null_count_ is not set, e.g. a page statistics created from - // a statistics thrift message which doesn't have the optional null_count, - // `num_values_` may include null values. - int64_t num_values_ = 0; + // Please note that numValues_ is reliable when hasNullCount_ is set. + // When hasNullCount_ is not set, e.g., a page statistics created from + // a statistics thrift message which doesn't have the optional nullCount, + // `numValues_` may include null values. + int64_t numValues_ = 0; + // NaN count is tracked separately since it's not written to the parquet file. + int64_t nanCount_ = 0; EncodedStatistics statistics_; std::shared_ptr> comparator_; - std::shared_ptr min_buffer_, max_buffer_; + std::shared_ptr minBuffer_, maxBuffer_; - void PlainEncode(const T& src, std::string* dst) const; - void PlainDecode(const std::string& src, T* dst) const; + void plainEncode(const T& src, std::string* dst) const; + void plainDecode(const std::string& src, T* dst) const; - void Copy(const T& src, T* dst, ResizableBuffer*) { + void copy(const T& src, T* dst, ResizableBuffer*) { *dst = src; } - void SetDistinctCount(int64_t n) { - // distinct count can only be "set", and cannot be incremented. - statistics_.distinct_count = n; - has_distinct_count_ = true; + void setDistinctCount(int64_t n) { + // Distinct count can only be "set", and cannot be incremented. + statistics_.distinctCount = n; + hasDistinctCount_ = true; } - void ResetCounts() { - this->statistics_.null_count = 0; - this->statistics_.distinct_count = 0; - this->num_values_ = 0; + void resetCounts() { + this->statistics_.nullCount = 0; + this->statistics_.distinctCount = 0; + this->numValues_ = 0; + this->nanCount_ = 0; } - void ResetHasFlags() { - // has_min_max_ will only be set when it meets any valid value. - this->has_min_max_ = false; - // has_distinct_count_ will only be set once SetDistinctCount() - // is called because distinct count calculation is not cheap and + void resetHasFlags() { + // hasMinMax_ will only be set when it meets any valid value. + this->hasMinMax_ = false; + // hasDistinctCount_ will only be set once setDistinctCount() + // is called because distinct count calculation is not cheap and. // disabled by default. - this->has_distinct_count_ = false; + this->hasDistinctCount_ = false; // Null count calculation is cheap and enabled by default. - this->has_null_count_ = true; + this->hasNullCount_ = true; + this->hasNanCount_ = false; } - void SetMinMaxPair(std::pair min_max) { - // CleanStatistic can return a nullopt in case of erroneous values, e.g. NaN - auto maybe_min_max = CleanStatistic(min_max); - if (!maybe_min_max) + void setMinMaxPair(std::pair minMax) { + // CleanStatistic can return a nullopt in case of erroneous values, e.g. + // NaN. + auto maybeMinMax = cleanStatistic(minMax); + if (!maybeMinMax) return; - auto min = maybe_min_max.value().first; - auto max = maybe_min_max.value().second; + auto min = maybeMinMax.value().first; + auto max = maybeMinMax.value().second; - if (!has_min_max_) { - has_min_max_ = true; - Copy(min, &min_, min_buffer_.get()); - Copy(max, &max_, max_buffer_.get()); + if (!hasMinMax_) { + hasMinMax_ = true; + copy(min, &min_, minBuffer_.get()); + copy(max, &max_, maxBuffer_.get()); } else { - Copy( - comparator_->Compare(min_, min) ? min_ : min, + copy( + comparator_->compare(min_, min) ? min_ : min, &min_, - min_buffer_.get()); - Copy( - comparator_->Compare(max_, max) ? max : max_, + minBuffer_.get()); + copy( + comparator_->compare(max_, max) ? max : max_, &max_, - max_buffer_.get()); + maxBuffer_.get()); + } + } + + int64_t countNaN(const T* values, int64_t length) { + if constexpr (!std::is_floating_point_v) { + return 0; + } else { + int64_t count = 0; + for (auto i = 0; i < length; i++) { + const auto val = safeLoad(values + i); + if (std::isnan(val)) { + count++; + } + } + return count; + } + } + + int64_t countNaNSpaced( + const T* values, + int64_t length, + const uint8_t* validBits, + int64_t validBitsOffset) { + if constexpr (!std::is_floating_point_v) { + return 0; + } else { + int64_t count = 0; + ::arrow::internal::VisitSetBitRunsVoid( + validBits, + validBitsOffset, + length, + [&](int64_t position, int64_t runLength) { + for (auto i = 0; i < runLength; i++) { + const auto val = safeLoad(values + i + position); + if (std::isnan(val)) { + count++; + } + } + }); + return count; } } }; template <> -inline bool TypedStatisticsImpl::MinMaxEqual( +inline bool TypedStatisticsImpl::minMaxEqual( const TypedStatisticsImpl& other) const { - uint32_t len = descr_->type_length(); + uint32_t len = descr_->typeLength(); return std::memcmp(min_.ptr, other.min_.ptr, len) == 0 && std::memcmp(max_.ptr, other.max_.ptr, len) == 0; } template -bool TypedStatisticsImpl::MinMaxEqual( +bool TypedStatisticsImpl::minMaxEqual( const TypedStatisticsImpl& other) const { return min_ == other.min_ && max_ == other.max_; } template <> -inline void TypedStatisticsImpl::Copy( +inline void TypedStatisticsImpl::copy( const FLBA& src, FLBA* dst, ResizableBuffer* buffer) { if (dst->ptr == src.ptr) return; - uint32_t len = descr_->type_length(); + uint32_t len = descr_->typeLength(); PARQUET_THROW_NOT_OK(buffer->Resize(len, false)); std::memcpy(buffer->mutable_data(), src.ptr, len); *dst = FLBA(buffer->data()); } template <> -inline void TypedStatisticsImpl::Copy( +inline void TypedStatisticsImpl::copy( const ByteArray& src, ByteArray* dst, ResizableBuffer* buffer) { @@ -883,71 +1061,75 @@ inline void TypedStatisticsImpl::Copy( } template -void TypedStatisticsImpl::Update( +void TypedStatisticsImpl::update( const T* values, - int64_t num_values, - int64_t null_count) { - VELOX_DCHECK_GE(num_values, 0); - VELOX_DCHECK_GE(null_count, 0); + int64_t numValues, + int64_t nullCount) { + VELOX_DCHECK_GE(numValues, 0); + VELOX_DCHECK_GE(nullCount, 0); - IncrementNullCount(null_count); - IncrementNumValues(num_values); + incrementNullCount(nullCount); + incrementNumValues(numValues); - if (num_values == 0) + if (numValues == 0) return; - SetMinMaxPair(comparator_->GetMinMax(values, num_values)); + setMinMaxPair(comparator_->getMinMax(values, numValues)); + incrementNaNValues(countNaN(values, numValues)); } template -void TypedStatisticsImpl::UpdateSpaced( +void TypedStatisticsImpl::updateSpaced( const T* values, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - int64_t num_spaced_values, - int64_t num_values, - int64_t null_count) { - VELOX_DCHECK_GE(num_values, 0); - VELOX_DCHECK_GE(null_count, 0); - - IncrementNullCount(null_count); - IncrementNumValues(num_values); - - if (num_values == 0) + const uint8_t* validBits, + int64_t validBitsOffset, + int64_t numSpacedValues, + int64_t numValues, + int64_t nullCount) { + VELOX_DCHECK_GE(numValues, 0); + VELOX_DCHECK_GE(nullCount, 0); + + incrementNullCount(nullCount); + incrementNumValues(numValues); + + if (numValues == 0) return; - SetMinMaxPair(comparator_->GetMinMaxSpaced( - values, num_spaced_values, valid_bits, valid_bits_offset)); + setMinMaxPair(comparator_->getMinMaxSpaced( + values, numSpacedValues, validBits, validBitsOffset)); + incrementNaNValues( + countNaNSpaced(values, numSpacedValues, validBits, validBitsOffset)); } template -void TypedStatisticsImpl::PlainEncode(const T& src, std::string* dst) +void TypedStatisticsImpl::plainEncode(const T& src, std::string* dst) const { - auto encoder = MakeTypedEncoder(Encoding::PLAIN, false, descr_, pool_); - encoder->Put(&src, 1); - auto buffer = encoder->FlushValues(); + auto encoder = + makeTypedEncoder(Encoding::kPlain, false, descr_, pool_); + encoder->put(&src, 1); + auto buffer = encoder->flushValues(); auto ptr = reinterpret_cast(buffer->data()); dst->assign(ptr, buffer->size()); } template -void TypedStatisticsImpl::PlainDecode(const std::string& src, T* dst) +void TypedStatisticsImpl::plainDecode(const std::string& src, T* dst) const { - auto decoder = MakeTypedDecoder(Encoding::PLAIN, descr_); - decoder->SetData( + auto decoder = makeTypedDecoder(Encoding::kPlain, descr_); + decoder->setData( 1, reinterpret_cast(src.c_str()), static_cast(src.size())); - decoder->Decode(dst, 1); + decoder->decode(dst, 1); } template <> -void TypedStatisticsImpl::PlainEncode( +void TypedStatisticsImpl::plainEncode( const T& src, std::string* dst) const { dst->assign(reinterpret_cast(src.ptr), src.len); } template <> -void TypedStatisticsImpl::PlainDecode( +void TypedStatisticsImpl::plainDecode( const std::string& src, T* dst) const { dst->len = static_cast(src.size()); @@ -957,47 +1139,47 @@ void TypedStatisticsImpl::PlainDecode( } // namespace // ---------------------------------------------------------------------- -// Public factory functions - -std::shared_ptr Comparator::Make( - Type::type physical_type, - SortOrder::type sort_order, - int type_length) { - if (SortOrder::SIGNED == sort_order) { - switch (physical_type) { - case Type::BOOLEAN: +// Public factory functions. + +std::shared_ptr Comparator::make( + Type::type physicalType, + SortOrder::type sortOrder, + int typeLength) { + if (SortOrder::kSigned == sortOrder) { + switch (physicalType) { + case Type::kBoolean: return std::make_shared>(); - case Type::INT32: + case Type::kInt32: return std::make_shared>(); - case Type::INT64: + case Type::kInt64: return std::make_shared>(); - case Type::INT96: + case Type::kInt96: return std::make_shared>(); - case Type::FLOAT: + case Type::kFloat: return std::make_shared>(); - case Type::DOUBLE: + case Type::kDouble: return std::make_shared>(); - case Type::BYTE_ARRAY: + case Type::kByteArray: return std::make_shared>(); - case Type::FIXED_LEN_BYTE_ARRAY: + case Type::kFixedLenByteArray: return std::make_shared>( - type_length); + typeLength); default: ParquetException::NYI("Signed Compare not implemented"); } - } else if (SortOrder::UNSIGNED == sort_order) { - switch (physical_type) { - case Type::INT32: + } else if (SortOrder::kUnsigned == sortOrder) { + switch (physicalType) { + case Type::kInt32: return std::make_shared>(); - case Type::INT64: + case Type::kInt64: return std::make_shared>(); - case Type::INT96: + case Type::kInt96: return std::make_shared>(); - case Type::BYTE_ARRAY: + case Type::kByteArray: return std::make_shared>(); - case Type::FIXED_LEN_BYTE_ARRAY: + case Type::kFixedLenByteArray: return std::make_shared>( - type_length); + typeLength); default: ParquetException::NYI("Unsigned Compare not implemented"); } @@ -1007,58 +1189,57 @@ std::shared_ptr Comparator::Make( return nullptr; } -std::shared_ptr Comparator::Make(const ColumnDescriptor* descr) { - return Make( - descr->physical_type(), descr->sort_order(), descr->type_length()); +std::shared_ptr Comparator::make(const ColumnDescriptor* descr) { + return make(descr->physicalType(), descr->sortOrder(), descr->typeLength()); } -std::shared_ptr Statistics::Make( +std::shared_ptr Statistics::make( const ColumnDescriptor* descr, ::arrow::MemoryPool* pool) { - switch (descr->physical_type()) { - case Type::BOOLEAN: + switch (descr->physicalType()) { + case Type::kBoolean: return std::make_shared>(descr, pool); - case Type::INT32: + case Type::kInt32: return std::make_shared>(descr, pool); - case Type::INT64: + case Type::kInt64: return std::make_shared>(descr, pool); - case Type::FLOAT: + case Type::kFloat: return std::make_shared>(descr, pool); - case Type::DOUBLE: + case Type::kDouble: return std::make_shared>(descr, pool); - case Type::BYTE_ARRAY: + case Type::kByteArray: return std::make_shared>(descr, pool); - case Type::FIXED_LEN_BYTE_ARRAY: + case Type::kFixedLenByteArray: return std::make_shared>(descr, pool); default: ParquetException::NYI("Statistics not implemented"); } } -std::shared_ptr Statistics::Make( - Type::type physical_type, +std::shared_ptr Statistics::make( + Type::type physicalType, const void* min, const void* max, - int64_t num_values, - int64_t null_count, - int64_t distinct_count) { -#define MAKE_STATS(CAP_TYPE, KLASS) \ - case Type::CAP_TYPE: \ - return std::make_shared>( \ - *reinterpret_cast(min), \ - *reinterpret_cast(max), \ - num_values, \ - null_count, \ - distinct_count) - - switch (physical_type) { - MAKE_STATS(BOOLEAN, BooleanType); - MAKE_STATS(INT32, Int32Type); - MAKE_STATS(INT64, Int64Type); - MAKE_STATS(FLOAT, FloatType); - MAKE_STATS(DOUBLE, DoubleType); - MAKE_STATS(BYTE_ARRAY, ByteArrayType); - MAKE_STATS(FIXED_LEN_BYTE_ARRAY, FLBAType); + int64_t numValues, + int64_t nullCount, + int64_t distinctCount) { +#define MAKE_STATS(CAP_TYPE, KLASS) \ + case Type::CAP_TYPE: \ + return std::make_shared>( \ + *reinterpret_cast(min), \ + *reinterpret_cast(max), \ + numValues, \ + nullCount, \ + distinctCount) + + switch (physicalType) { + MAKE_STATS(kBoolean, BooleanType); + MAKE_STATS(kInt32, Int32Type); + MAKE_STATS(kInt64, Int64Type); + MAKE_STATS(kFloat, FloatType); + MAKE_STATS(kDouble, DoubleType); + MAKE_STATS(kByteArray, ByteArrayType); + MAKE_STATS(kFixedLenByteArray, FLBAType); default: break; } @@ -1067,58 +1248,64 @@ std::shared_ptr Statistics::Make( return nullptr; } -std::shared_ptr Statistics::Make( +std::shared_ptr Statistics::make( const ColumnDescriptor* descr, - const EncodedStatistics* encoded_stats, - int64_t num_values, + const EncodedStatistics* encodedStats, + int64_t numValues, ::arrow::MemoryPool* pool) { - VELOX_DCHECK(encoded_stats != nullptr); - return Make( + VELOX_DCHECK(encodedStats != nullptr); + return make( descr, - encoded_stats->min(), - encoded_stats->max(), - num_values, - encoded_stats->null_count, - encoded_stats->distinct_count, - encoded_stats->has_min && encoded_stats->has_max, - encoded_stats->has_null_count, - encoded_stats->has_distinct_count, + encodedStats->min(), + encodedStats->max(), + numValues, + encodedStats->nullCount, + encodedStats->distinctCount, + encodedStats->hasMin && encodedStats->hasMax, + encodedStats->hasNullCount, + encodedStats->hasDistinctCount, + encodedStats->hasNanCount, + encodedStats->nanCount, pool); } -std::shared_ptr Statistics::Make( +std::shared_ptr Statistics::make( const ColumnDescriptor* descr, - const std::string& encoded_min, - const std::string& encoded_max, - int64_t num_values, - int64_t null_count, - int64_t distinct_count, - bool has_min_max, - bool has_null_count, - bool has_distinct_count, + const std::string& encodedMin, + const std::string& encodedMax, + int64_t numValues, + int64_t nullCount, + int64_t distinctCount, + bool hasMinMax, + bool hasNullCount, + bool hasDistinctCount, + bool hasNaNCount, + int64_t nanCount, ::arrow::MemoryPool* pool) { #define MAKE_STATS(CAP_TYPE, KLASS) \ case Type::CAP_TYPE: \ return std::make_shared>( \ descr, \ - encoded_min, \ - encoded_max, \ - num_values, \ - null_count, \ - distinct_count, \ - has_min_max, \ - has_null_count, \ - has_distinct_count, \ + encodedMin, \ + encodedMax, \ + numValues, \ + nullCount, \ + distinctCount, \ + hasMinMax, \ + hasNullCount, \ + hasDistinctCount, \ + hasNaNCount, \ + nanCount, \ pool) - switch (descr->physical_type()) { - MAKE_STATS(BOOLEAN, BooleanType); - MAKE_STATS(INT32, Int32Type); - MAKE_STATS(INT64, Int64Type); - MAKE_STATS(FLOAT, FloatType); - MAKE_STATS(DOUBLE, DoubleType); - MAKE_STATS(BYTE_ARRAY, ByteArrayType); - MAKE_STATS(FIXED_LEN_BYTE_ARRAY, FLBAType); + switch (descr->physicalType()) { + MAKE_STATS(kBoolean, BooleanType); + MAKE_STATS(kInt32, Int32Type); + MAKE_STATS(kInt64, Int64Type); + MAKE_STATS(kFloat, FloatType); + MAKE_STATS(kDouble, DoubleType); + MAKE_STATS(kByteArray, ByteArrayType); + MAKE_STATS(kFixedLenByteArray, FLBAType); default: break; } diff --git a/velox/dwio/parquet/writer/arrow/Statistics.h b/velox/dwio/parquet/writer/arrow/Statistics.h index 6abf66b0b20..64295487359 100644 --- a/velox/dwio/parquet/writer/arrow/Statistics.h +++ b/velox/dwio/parquet/writer/arrow/Statistics.h @@ -40,30 +40,28 @@ namespace facebook::velox::parquet::arrow { class ColumnDescriptor; // ---------------------------------------------------------------------- -// Value comparator interfaces +// Value Comparator interfaces. -/// \brief Base class for value comparators. Generally used with -/// TypedComparator +/// \brief Base class for value Comparators. Generally used with +/// TypedComparator. class PARQUET_EXPORT Comparator { public: virtual ~Comparator() {} - /// \brief Create a comparator explicitly from physical type and - /// sort order - /// \param[in] physical_type the physical type for the typed - /// comparator - /// \param[in] sort_order either SortOrder::SIGNED or - /// SortOrder::UNSIGNED - /// \param[in] type_length for FIXED_LEN_BYTE_ARRAY only - static std::shared_ptr Make( - Type::type physical_type, - SortOrder::type sort_order, - int type_length = -1); - - /// \brief Create typed comparator inferring default sort order from - /// ColumnDescriptor - /// \param[in] descr the Parquet column schema - static std::shared_ptr Make(const ColumnDescriptor* descr); + /// \brief Create a Comparator explicitly from physical type and + /// sort order. + /// \param[in] physicalType Physical type for the typed + /// Comparator. + /// \param[in] sortOrder Either SortOrder::kSigned or + /// SortOrder::kUnsigned. + /// \param[in] typeLength For FIXED_LEN_BYTE_ARRAY only. + static std::shared_ptr + make(Type::type physicalType, SortOrder::type sortOrder, int typeLength = -1); + + /// \brief Create typed Comparator inferring default sort order from + /// ColumnDescriptor. + /// \param[in] descr the Parquet column schema. + static std::shared_ptr make(const ColumnDescriptor* descr); }; /// \brief Interface for comparison of physical types according to the @@ -71,54 +69,54 @@ class PARQUET_EXPORT Comparator { template class TypedComparator : public Comparator { public: - using T = typename DType::c_type; + using T = typename DType::CType; /// \brief Scalar comparison of two elements, return true if first - /// is strictly less than the second - virtual bool Compare(const T& a, const T& b) = 0; + /// is strictly less than the second. + virtual bool compare(const T& a, const T& b) = 0; /// \brief Compute maximum and minimum elements in a batch of - /// elements without any nulls - virtual std::pair GetMinMax(const T* values, int64_t length) = 0; + /// elements without any nulls. + virtual std::pair getMinMax(const T* values, int64_t length) = 0; /// \brief Compute minimum and maximum elements from an Arrow array. Only /// valid for certain Parquet Type / Arrow Type combinations, like BYTE_ARRAY - /// / arrow::BinaryArray - virtual std::pair GetMinMax(const ::arrow::Array& values) = 0; + /// / Arrow::BinaryArray. + virtual std::pair getMinMax(const ::arrow::Array& values) = 0; /// \brief Compute maximum and minimum elements in a batch of /// elements with accompanying bitmap indicating which elements are - /// included (bit set) and excluded (bit not set) + /// included (bit set) and excluded (bit not set). /// - /// \param[in] values the sequence of values - /// \param[in] length the length of the sequence - /// \param[in] valid_bits a bitmap indicating which elements are - /// included (1) or excluded (0) - /// \param[in] valid_bits_offset the bit offset into the bitmap of - /// the first element in the sequence - virtual std::pair GetMinMaxSpaced( + /// \param[in] values The sequence of values. + /// \param[in] length The length of the sequence. + /// \param[in] validBits A bitmap indicating which elements are + /// included (1) or excluded (0). + /// \param[in] validBitsOffset The bit offset into the bitmap of + /// the first element in the sequence. + virtual std::pair getMinMaxSpaced( const T* values, int64_t length, - const uint8_t* valid_bits, - int64_t valid_bits_offset) = 0; + const uint8_t* validBits, + int64_t validBitsOffset) = 0; }; -/// \brief Typed version of Comparator::Make +/// \brief Typed version of Comparator::Make. template -std::shared_ptr> MakeComparator( - Type::type physical_type, - SortOrder::type sort_order, - int type_length = -1) { +std::shared_ptr> makeComparator( + Type::type physicalType, + SortOrder::type sortOrder, + int typeLength = -1) { return std::static_pointer_cast>( - Comparator::Make(physical_type, sort_order, type_length)); + Comparator::make(physicalType, sortOrder, typeLength)); } -/// \brief Typed version of Comparator::Make +/// \brief Typed version of Comparator::Make. template -std::shared_ptr> MakeComparator( +std::shared_ptr> makeComparator( const ColumnDescriptor* descr) { return std::static_pointer_cast>( - Comparator::Make(descr)); + Comparator::make(descr)); } // ---------------------------------------------------------------------- @@ -127,7 +125,7 @@ std::shared_ptr> MakeComparator( /// and read from Parquet serialized metadata. class PARQUET_EXPORT EncodedStatistics { std::string max_, min_; - bool is_signed_ = false; + bool isSigned_ = false; public: EncodedStatistics() = default; @@ -139,236 +137,300 @@ class PARQUET_EXPORT EncodedStatistics { return min_; } - int64_t null_count = 0; - int64_t distinct_count = 0; + int64_t nullCount = 0; + int64_t distinctCount = 0; + int64_t nanCount = 0; - bool has_min = false; - bool has_max = false; - bool has_null_count = false; - bool has_distinct_count = false; + bool hasMin = false; + bool hasMax = false; + bool hasNullCount = false; + bool hasDistinctCount = false; + bool hasNanCount = false; // When all values in the statistics are null, it is set to true. // Otherwise, at least one value is not null, or we are not sure at all. // Page index requires this information to decide whether a data page // is a null page or not. - bool all_null_value = false; + bool allNullValue = false; - // From parquet-mr + // From parquet-mr. // Don't write stats larger than the max size rather than truncating. The // rationale is that some engines may use the minimum value in the page as // the true minimum for aggregations and there is no way to mark that a // value has been truncated and is a lower bound and not in the page. - void ApplyStatSizeLimits(size_t length) { + void applyStatSizeLimits(size_t length) { if (max_.length() > length) { - has_max = false; + hasMax = false; max_.clear(); } if (min_.length() > length) { - has_min = false; + hasMin = false; min_.clear(); } } - bool is_set() const { - return has_min || has_max || has_null_count || has_distinct_count; + bool isSet() const { + return hasMin || hasMax || hasNullCount || hasDistinctCount; } - bool is_signed() const { - return is_signed_; + bool isSigned() const { + return isSigned_; } - void set_is_signed(bool is_signed) { - is_signed_ = is_signed; + void setIsSigned(bool isSigned) { + isSigned_ = isSigned; } - EncodedStatistics& set_max(std::string value) { + EncodedStatistics& setMax(std::string value) { max_ = std::move(value); - has_max = true; + hasMax = true; return *this; } - EncodedStatistics& set_min(std::string value) { + EncodedStatistics& setMin(std::string value) { min_ = std::move(value); - has_min = true; + hasMin = true; return *this; } - EncodedStatistics& set_null_count(int64_t value) { - null_count = value; - has_null_count = true; + EncodedStatistics& setNullCount(int64_t value) { + nullCount = value; + hasNullCount = true; return *this; } - EncodedStatistics& set_distinct_count(int64_t value) { - distinct_count = value; - has_distinct_count = true; + EncodedStatistics& setDistinctCount(int64_t value) { + distinctCount = value; + hasDistinctCount = true; + return *this; + } + + EncodedStatistics& set_nan_count(int64_t value) { + nanCount = value; + hasNanCount = true; return *this; } }; -/// \brief Base type for computing column statistics while writing a file +/// \brief Base type for computing column statistics while writing a file. class PARQUET_EXPORT Statistics { public: virtual ~Statistics() {} /// \brief Create a new statistics instance given a column schema - /// definition - /// \param[in] descr the column schema - /// \param[in] pool a memory pool to use for any memory allocations, optional - static std::shared_ptr Make( + /// definition. + /// \param[in] descr The column schema. + /// \param[in] pool A memory pool to use for any memory allocations, optional. + static std::shared_ptr make( const ColumnDescriptor* descr, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); /// \brief Create a new statistics instance given a column schema - /// definition and pre-existing state - /// \param[in] descr the column schema - /// \param[in] encoded_min the encoded minimum value - /// \param[in] encoded_max the encoded maximum value - /// \param[in] num_values total number of values - /// \param[in] null_count number of null values - /// \param[in] distinct_count number of distinct values - /// \param[in] has_min_max whether the min/max statistics are set - /// \param[in] has_null_count whether the null_count statistics are set - /// \param[in] has_distinct_count whether the distinct_count statistics are - /// set \param[in] pool a memory pool to use for any memory allocations, - /// optional - static std::shared_ptr Make( + /// definition and pre-existing state. + /// \param[in] descr The column schema. + /// \param[in] encodedMin The encoded minimum value. + /// \param[in] encodedMax The encoded maximum value. + /// \param[in] numValues Total number of values. + /// \param[in] nullCount Number of null values. + /// \param[in] distinctCount Number of distinct values. + /// \param[in] hasMinMax Whether the min/max statistics are set. + /// \param[in] hasNullCount Whether the nullCount statistics are set. + /// \param[in] hasDistinctCount Whether the distinctCount statistics are set. + /// \param[in] pool A memory pool to use for any memory allocations, + /// optional. + static std::shared_ptr make( const ColumnDescriptor* descr, - const std::string& encoded_min, - const std::string& encoded_max, - int64_t num_values, - int64_t null_count, - int64_t distinct_count, - bool has_min_max, - bool has_null_count, - bool has_distinct_count, + const std::string& encodedMin, + const std::string& encodedMax, + int64_t numValues, + int64_t nullCount, + int64_t distinctCount, + bool hasMinMax, + bool hasNullCount, + bool hasDistinctCount, + bool hasNaNCount, + int64_t nanCount, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); // Helper function to convert EncodedStatistics to Statistics. - // EncodedStatistics does not contain number of non-null values, and it can be - // passed using the num_values parameter. - static std::shared_ptr Make( + // EncodedStatistics does not contain number of non-null values, and it can + // be passed using the numValues parameter. + static std::shared_ptr make( const ColumnDescriptor* descr, - const EncodedStatistics* encoded_statistics, - int64_t num_values = -1, + const EncodedStatistics* encodedStats, + int64_t numValues = -1, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); - /// \brief Return true if the count of null values is set - virtual bool HasNullCount() const = 0; + /// \brief Return true if the count of null values is set. + virtual bool hasNullCount() const = 0; + + /// \brief The number of null values, may not be set. + virtual int64_t nullCount() const = 0; - /// \brief The number of null values, may not be set - virtual int64_t null_count() const = 0; + /// \brief Return true if the count of distinct values is set. + virtual bool hasDistinctCount() const = 0; - /// \brief Return true if the count of distinct values is set - virtual bool HasDistinctCount() const = 0; + /// \brief The number of distinct values, may not be set. + virtual int64_t distinctCount() const = 0; - /// \brief The number of distinct values, may not be set - virtual int64_t distinct_count() const = 0; + /// \brief The number of non-null values in the column. + virtual int64_t numValues() const = 0; - /// \brief The number of non-null values in the column - virtual int64_t num_values() const = 0; + /// \brief Return true if the count of nan values is set. + virtual bool hasNaNCount() const = 0; + + /// \brief The number of NaN values, may not be set. + virtual int64_t nanCount() const = 0; /// \brief Return true if the min and max statistics are set. Obtain - /// with TypedStatistics::min and max - virtual bool HasMinMax() const = 0; + /// with TypedStatistics::min and max. + virtual bool hasMinMax() const = 0; + + /// \brief Reset state of object to initial (no data observed) state. + virtual void reset() = 0; + + /// \brief Plain-encoded minimum value. + virtual std::string encodeMin() const = 0; - /// \brief Reset state of object to initial (no data observed) state - virtual void Reset() = 0; + /// \brief Plain-encoded maximum value. + virtual std::string encodeMax() const = 0; - /// \brief Plain-encoded minimum value - virtual std::string EncodeMin() const = 0; + /// \brief Encoded lower bound value compatible with Iceberg. + /// + /// Returns an encoded value guaranteed to be <= the actual minimum value. + /// For string types, truncates to at most \p truncateTo Unicode code + /// points. For decimal types, encodes the value in big-endian format as + /// required by Iceberg's single-value serialization specification. For + /// all other data types, uses the same plain encoding as Parquet. + /// (Returns the exact encoded minimum value). + /// + /// @param truncateTo Maximum number of Unicode code points for string + /// types. + virtual std::string icebergLowerBoundInclusive(int32_t truncateTo) const = 0; - /// \brief Plain-encoded maximum value - virtual std::string EncodeMax() const = 0; + /// \brief Encoded upper bound value compatible with Iceberg. + /// + /// Returns an encoded value guaranteed to be >= the actual maximum value. + /// For string types: + /// - If the maximum value has <= \p truncateTo Unicode code points, + /// returns the exact encoded maximum value (inclusive upper bound). + /// - If the maximum value has > \p truncateTo Unicode code points, + /// truncates to \p truncateTo code points and increments the last code point + /// to produce an exclusive upper bound that is greater than the maximum + /// value. + /// - Returns std::nullopt if no valid upper bound can be computed (e.g., + /// all code points in the truncated portion are at the maximum Unicode + /// value U+10FFFF). This allows distinguishing between "upper bound is + /// empty string" and "no valid upper bound exists". + /// For decimal types, encodes the value in big-endian format as required + /// by Iceberg's single-value serialization specification. For all other + /// data types, uses the same plain encoding as Parquet. (Returns the + /// exact encoded maximum value). + /// + /// @param truncateTo Maximum number of Unicode code points for string + /// types. + /// @return Encoded upper bound value, or std::nullopt if no valid upper + /// bound can be computed. + virtual std::optional icebergUpperBoundExclusive( + int32_t truncateTo) const = 0; - /// \brief The finalized encoded form of the statistics for transport - virtual EncodedStatistics Encode() = 0; + /// \brief The finalized encoded form of the statistics for transport. + virtual EncodedStatistics encode() = 0; - /// \brief The physical type of the column schema - virtual Type::type physical_type() const = 0; + /// \brief The physical type of the column schema. + virtual Type::type physicalType() const = 0; - /// \brief The full type descriptor from the column schema + /// \brief The full type descriptor from the column schema. virtual const ColumnDescriptor* descr() const = 0; - /// \brief Check two Statistics for equality - virtual bool Equals(const Statistics& other) const = 0; + /// \brief Check two Statistics for equality. + virtual bool equals(const Statistics& other) const = 0; + + /// \brief Return true if this object's max is greater than the other's + /// max. + /// \param[in] other The Statistics object to compare against. + virtual bool maxGreaterThan(const Statistics& other) const = 0; + + /// \brief Return true if this object's min is less than the other's min. + /// \param[in] other The Statistics object to compare against. + virtual bool minLessThan(const Statistics& other) const = 0; protected: - static std::shared_ptr Make( - Type::type physical_type, + static std::shared_ptr make( + Type::type physicalType, const void* min, const void* max, - int64_t num_values, - int64_t null_count, - int64_t distinct_count); + int64_t numValues, + int64_t nullCount, + int64_t distinctCount); }; -/// \brief A typed implementation of Statistics +/// \brief A typed implementation of Statistics. template class TypedStatistics : public Statistics { public: - using T = typename DType::c_type; + using T = typename DType::CType; - /// \brief The current minimum value + /// \brief The current minimum value. virtual const T& min() const = 0; - /// \brief The current maximum value + /// \brief The current maximum value. virtual const T& max() const = 0; - /// \brief Update state with state of another Statistics object - virtual void Merge(const TypedStatistics& other) = 0; + /// \brief Update state with state of another Statistics object. + virtual void merge(const TypedStatistics& other) = 0; - /// \brief Batch statistics update + /// \brief Batch statistics update. virtual void - Update(const T* values, int64_t num_values, int64_t null_count) = 0; - - /// \brief Batch statistics update with supplied validity bitmap - /// \param[in] values pointer to column values - /// \param[in] valid_bits Pointer to bitmap representing if values are - /// non-null. \param[in] valid_bits_offset Offset offset into valid_bits where - /// the slice of - /// data begins. - /// \param[in] num_spaced_values The length of values in values/valid_bits to - /// inspect - /// when calculating statistics. This can be - /// smaller than num_values+null_count as - /// null_count can include nulls from parents - /// while num_spaced_values does not. - /// \param[in] num_values Number of values that are not null. - /// \param[in] null_count Number of values that are null. - virtual void UpdateSpaced( + update(const T* values, int64_t numValues, int64_t nullCount) = 0; + + /// \brief Batch statistics update with supplied validity bitmap. + /// \param[in] values Pointer to column values. + /// \param[in] validBits Pointer to bitmap representing if values are + /// non-null. \param[in] validBitsOffset Offset offset into validBits + /// where the slice of data begins. + /// \param[in] numSpacedValues The length of values in values/validBits to + /// inspect when calculating statistics. This can be smaller than + /// numValues + nullCount as nullCount can include nulls from parents + /// while numSpacedValues does not. + /// \param[in] numValues Number of values that are not null. + /// \param[in] nullCount Number of values that are null. + virtual void updateSpaced( const T* values, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - int64_t num_spaced_values, - int64_t num_values, - int64_t null_count) = 0; + const uint8_t* validBits, + int64_t validBitsOffset, + int64_t numSpacedValues, + int64_t numValues, + int64_t nullCount) = 0; /// \brief EXPERIMENTAL: Update statistics with an Arrow array without /// conversion to a primitive Parquet C type. Only implemented for certain /// Parquet type / Arrow type combinations like BYTE_ARRAY / - /// arrow::BinaryArray + /// Arrow::BinaryArray. /// - /// If update_counts is true then the null_count and num_values will be - /// updated based on the null_count of values. Set to false if these are + /// If updateCounts is true then the nullCount and numValues will be + /// updated based on the nullCount of values. Set to false if these are /// updated elsewhere (e.g. when updating a dictionary where the counts are - /// taken from the indices and not the values) - virtual void Update( + /// taken from the indices and not the values). + virtual void update( const ::arrow::Array& values, - bool update_counts = true) = 0; + bool updateCounts = true) = 0; + + /// \brief Set min and max values to particular values. + virtual void setMinMax(const T& min, const T& max) = 0; - /// \brief Set min and max values to particular values - virtual void SetMinMax(const T& min, const T& max) = 0; + /// \brief Increments the null count directly. + /// Use Update to extract the null count from data. Use this if you + /// determine the null count through some other means (e.g. dictionary arrays + /// where the null count is determined from the indices). + virtual void incrementNullCount(int64_t n) = 0; - /// \brief Increments the null count directly - /// Use Update to extract the null count from data. Use this if you determine - /// the null count through some other means (e.g. dictionary arrays where the - /// null count is determined from the indices) - virtual void IncrementNullCount(int64_t n) = 0; + /// \brief Increments the number of values directly. + /// The same note on IncrementNullCount applies here. + virtual void incrementNumValues(int64_t n) = 0; - /// \brief Increments the number of values directly - /// The same note on IncrementNullCount applies here - virtual void IncrementNumValues(int64_t n) = 0; + /// \brief Increments the NaN count directly. + virtual void incrementNaNValues(int64_t n) = 0; }; using BoolStatistics = TypedStatistics; @@ -379,55 +441,59 @@ using DoubleStatistics = TypedStatistics; using ByteArrayStatistics = TypedStatistics; using FLBAStatistics = TypedStatistics; -/// \brief Typed version of Statistics::Make +/// \brief Typed version of Statistics::Make. template -std::shared_ptr> MakeStatistics( +std::shared_ptr> makeStatistics( const ColumnDescriptor* descr, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()) { return std::static_pointer_cast>( - Statistics::Make(descr, pool)); + Statistics::make(descr, pool)); } -/// \brief Create Statistics initialized to a particular state -/// \param[in] min the minimum value -/// \param[in] max the minimum value -/// \param[in] num_values number of values -/// \param[in] null_count number of null values -/// \param[in] distinct_count number of distinct values +/// \brief Create Statistics initialized to a particular state. +/// \param[in] min The minimum value. +/// \param[in] max The maximum value. +/// \param[in] numValues number of values. +/// \param[in] nullCount Number of null values. +/// \param[in] distinctCount Number of distinct values. template -std::shared_ptr> MakeStatistics( - const typename DType::c_type& min, - const typename DType::c_type& max, - int64_t num_values, - int64_t null_count, - int64_t distinct_count) { - return std::static_pointer_cast>(Statistics::Make( - DType::type_num, &min, &max, num_values, null_count, distinct_count)); +std::shared_ptr> makeStatistics( + const typename DType::CType& min, + const typename DType::CType& max, + int64_t numValues, + int64_t nullCount, + int64_t distinctCount) { + return std::static_pointer_cast>(Statistics::make( + DType::typeNum, &min, &max, numValues, nullCount, distinctCount)); } -/// \brief Typed version of Statistics::Make +/// \brief Typed version of Statistics::Make. template -std::shared_ptr> MakeStatistics( +std::shared_ptr> makeStatistics( const ColumnDescriptor* descr, - const std::string& encoded_min, - const std::string& encoded_max, - int64_t num_values, - int64_t null_count, - int64_t distinct_count, - bool has_min_max, - bool has_null_count, - bool has_distinct_count, + const std::string& encodedMin, + const std::string& encodedMax, + int64_t numValues, + int64_t nullCount, + int64_t distinctCount, + bool hasMinMax, + bool hasNullCount, + bool hasDistinctCount, + bool hasNaNCount, + int64_t nanCount, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()) { - return std::static_pointer_cast>(Statistics::Make( + return std::static_pointer_cast>(Statistics::make( descr, - encoded_min, - encoded_max, - num_values, - null_count, - distinct_count, - has_min_max, - has_null_count, - has_distinct_count, + encodedMin, + encodedMax, + numValues, + nullCount, + distinctCount, + hasMinMax, + hasNullCount, + hasDistinctCount, + hasNaNCount, + nanCount, pool)); } diff --git a/velox/dwio/parquet/writer/arrow/StringTruncation.cpp b/velox/dwio/parquet/writer/arrow/StringTruncation.cpp new file mode 100644 index 00000000000..a7b4bad6bcb --- /dev/null +++ b/velox/dwio/parquet/writer/arrow/StringTruncation.cpp @@ -0,0 +1,180 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/dwio/parquet/writer/arrow/StringTruncation.h" + +#include +#include +#include +#include +#include +#include + +#include "velox/functions/lib/string/StringCore.h" +#include "velox/functions/lib/string/StringImpl.h" + +namespace facebook::velox::parquet::arrow { + +// Import necessary functions from stringImpl namespace +using facebook::velox::functions::stringCore::isAscii; +using facebook::velox::functions::stringImpl::cappedByteLength; + +namespace { + +// Increments a Unicode code point to the next valid Unicode scalar value. +// Returns 0 if overflow (input is max code point). +FOLLY_ALWAYS_INLINE int32_t incrementCodePoint(int32_t codePoint) { + static constexpr int32_t kMaxCodePoint = 0x10FFFF; + static constexpr int32_t kMinSurrogate = 0xD800; + static constexpr int32_t kMaxSurrogate = 0xDFFF; + if (codePoint == (kMinSurrogate - 1)) { + // Skip the surrogate range. + return kMaxSurrogate + 1; + } else if (codePoint == kMaxCodePoint) { + return 0; + } + return codePoint + 1; +} + +// ASCII fast-path for roundUp. +FOLLY_ALWAYS_INLINE std::optional roundUpAscii( + std::string_view input, + int32_t numCodePoints) { + const size_t truncatedLength = + std::min(input.size(), static_cast(numCodePoints)); + + if (truncatedLength == input.size()) { + return std::string(input); + } + + if (truncatedLength == 0) { + return std::nullopt; + } + + for (int32_t i = truncatedLength - 1; i >= 0; --i) { + const auto byte = static_cast(input[i]); + if (byte < 0x7F) { + std::string result(input.data(), i); + result.push_back(static_cast(byte + 1)); + return result; + } + } + + // All bytes are 0x7F (DEL character), no valid upper bound. + return std::nullopt; +} + +// Unicode path for roundUp. +FOLLY_ALWAYS_INLINE std::optional roundUpUnicode( + std::string_view input, + int32_t numCodePoints) { + const auto truncatedLength = cappedByteLength(input, numCodePoints); + + if (truncatedLength == input.size()) { + return std::string(input); + } + + if (truncatedLength == 0) { + return std::nullopt; + } + + const char* data = input.data(); + const char* truncatedEnd = data + truncatedLength; + + // Collect the byte offset of each code point. + std::vector codePointOffsets; + codePointOffsets.reserve(numCodePoints); + const char* current = data; + while (current < truncatedEnd) { + codePointOffsets.push_back(current - data); + int32_t charLength; + utf8proc_codepoint(current, truncatedEnd, charLength); + current += charLength; + } + + // Try incrementing from the last code point backwards. + for (int32_t i = codePointOffsets.size() - 1; i >= 0; --i) { + const char* pos = data + codePointOffsets[i]; + int32_t charLength; + const auto codePoint = utf8proc_codepoint(pos, truncatedEnd, charLength); + const auto nextCodePoint = incrementCodePoint(codePoint); + if (nextCodePoint != 0) { + std::string result(data, codePointOffsets[i]); + char buffer[4]; + const auto bytesWritten = utf8proc_encode_char( + nextCodePoint, reinterpret_cast(buffer)); + result.append(buffer, bytesWritten); + return result; + } + } + + // No valid upper bound can be found. + return std::nullopt; +} + +} // namespace + +std::string_view truncateUtf8(std::string_view input, int32_t numCodePoints) { + if (isAscii(input.data(), input.size())) { + return std::string_view( + input.data(), std::min(input.size(), (size_t)numCodePoints)); + } + const auto truncatedLength = cappedByteLength(input, numCodePoints); + return std::string_view(input.data(), truncatedLength); +} + +std::optional roundUpUtf8( + std::string_view input, + int32_t numCodePoints) { + if (isAscii(input.data(), input.size())) { + return roundUpAscii(input, numCodePoints); + } + return roundUpUnicode(input, numCodePoints); +} + +std::optional roundUpBinary( + std::string_view input, + int32_t truncateLength) { + if (truncateLength <= 0) { + return std::nullopt; + } + + const size_t length = static_cast(truncateLength); + if (input.size() <= length) { + return std::string(input); + } + + // Create a mutable copy of the truncated input. + std::string result(input.data(), length); + + // Try incrementing bytes from the end. + for (size_t i = length; i-- > 0;) { + unsigned char byte = static_cast(result[i]); + + if (byte != 0xFF) { // Can increment without overflow. + result[i] = static_cast(byte + 1); + // Truncate to i + 1 bytes (remove trailing bytes after increment point). + result.resize(i + 1); + return result; + } + // If byte == 0xFF, it will overflow, continue to previous byte. + } + + // All bytes were 0xFF and overflowed - no valid upper bound. + return std::nullopt; +} + +} // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/StringTruncation.h b/velox/dwio/parquet/writer/arrow/StringTruncation.h new file mode 100644 index 00000000000..2a0fcb35ec4 --- /dev/null +++ b/velox/dwio/parquet/writer/arrow/StringTruncation.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace facebook::velox::parquet::arrow { + +/// Truncates a UTF-8 encoded string to at most 'numCodePoints' Unicode code +/// points. Returns a string_view pointing to the truncated portion of the +/// input string. This is used for computing lower bound statistics, +/// as the truncated string is guaranteed to be less than or equal to the +/// original string in lexicographic order. +/// +/// @param input The UTF-8 encoded input string. +/// @param numCodePoints Maximum number of Unicode code points to retain. +/// @return A string_view of the truncated string. +std::string_view truncateUtf8(std::string_view input, int32_t numCodePoints); + +/// Rounds up a UTF-8 encoded string to produce an exclusive upper bound. +/// The result is guaranteed to be greater than any string that shares the +/// same prefix up to 'numCodePoints' code points. This is used for computing +/// upper bound statistics. +/// +/// The function behaves as follows: +/// - If the string has fewer than or equal to 'numCodePoints' code points, +/// returns the original string unchanged. +/// - Otherwise, truncates to 'numCodePoints' code points and increments +/// code points from the last to the first, returning immediately on the +/// first successful increment. +/// - If no code point can be incremented (e.g., all are at max value +/// U+10FFFF), returns std::nullopt. +/// +/// @param input The UTF-8 encoded input string. +/// @param numCodePoints Maximum number of Unicode code points to retain. +/// @return A new string containing the rounded-up result, or std::nullopt if +/// no valid upper bound can be computed. +std::optional roundUpUtf8( + std::string_view input, + int32_t numCodePoints); + +/// Computes an upper bound for binary data by truncating to a specified length +/// and incrementing the last byte that is not 0xFF. +/// +/// This function is used for computing upper bounds on binary statistics +/// (e.g., for Parquet file metadata). It follows the algorithm described in +/// Apache Iceberg's BinaryUtil.truncateBinaryMax(). +/// +/// The algorithm: +/// 1. If the input is shorter than or equal to truncateLength, return it as-is. +/// 2. Otherwise, truncate to truncateLength bytes. +/// 3. Starting from the last byte, find the first byte that is not 0xFF. +/// 4. Increment that byte and truncate everything after it. +/// 5. If all bytes are 0xFF, return std::nullopt (no valid upper bound). +/// +/// @param input The binary data as a string_view. +/// @param truncateLength Maximum number of bytes to retain before incrementing. +/// @return An optional string containing the upper bound, or std::nullopt if +/// no valid upper bound exists (e.g., all bytes are 0xFF). +std::optional roundUpBinary( + std::string_view input, + int32_t truncateLength); + +} // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/ThriftInternal.h b/velox/dwio/parquet/writer/arrow/ThriftInternal.h index a6383474d23..2ce0e151741 100644 --- a/velox/dwio/parquet/writer/arrow/ThriftInternal.h +++ b/velox/dwio/parquet/writer/arrow/ThriftInternal.h @@ -46,38 +46,38 @@ namespace facebook::velox::parquet::arrow { -// ---------------------------------------------------------------------- -// Convert Thrift enums to Parquet enums +// ----------------------------------------------------------------------. +// Convert Thrift enums to Parquet enums. // Unsafe enum converters (input is not checked for validity) -static inline Type::type FromThriftUnsafe( +static inline Type::type fromThriftUnsafe( facebook::velox::parquet::thrift::Type::type type) { return static_cast(type); } -static inline ConvertedType::type FromThriftUnsafe( +static inline ConvertedType::type fromThriftUnsafe( facebook::velox::parquet::thrift::ConvertedType::type type) { - // item 0 is NONE + // Item 0 is NONE. return static_cast(static_cast(type) + 1); } -static inline Repetition::type FromThriftUnsafe( +static inline Repetition::type fromThriftUnsafe( facebook::velox::parquet::thrift::FieldRepetitionType::type type) { return static_cast(type); } -static inline Encoding::type FromThriftUnsafe( +static inline Encoding::type fromThriftUnsafe( facebook::velox::parquet::thrift::Encoding::type type) { return static_cast(type); } -static inline PageType::type FromThriftUnsafe( +static inline PageType::type fromThriftUnsafe( facebook::velox::parquet::thrift::PageType::type type) { return static_cast(type); } -static inline Compression::type FromThriftUnsafe( +static inline Compression::type fromThriftUnsafe( facebook::velox::parquet::thrift::CompressionCodec::type type) { switch (type) { case facebook::velox::parquet::thrift::CompressionCodec::UNCOMPRESSED: @@ -102,7 +102,7 @@ static inline Compression::type FromThriftUnsafe( } } -static inline BoundaryOrder::type FromThriftUnsafe( +static inline BoundaryOrder::type fromThriftUnsafe( facebook::velox::parquet::thrift::BoundaryOrder::type type) { return static_cast(type); } @@ -110,131 +110,131 @@ static inline BoundaryOrder::type FromThriftUnsafe( namespace internal { template -struct ThriftEnumTypeTraits {}; +struct ThriftenumTypeTraits {}; template <> -struct ThriftEnumTypeTraits<::facebook::velox::parquet::thrift::Type::type> { - using ParquetEnum = Type; +struct ThriftenumTypeTraits<::facebook::velox::parquet::thrift::Type::type> { + using Parquetenum = Type; }; template <> -struct ThriftEnumTypeTraits< +struct ThriftenumTypeTraits< ::facebook::velox::parquet::thrift::ConvertedType::type> { - using ParquetEnum = ConvertedType; + using Parquetenum = ConvertedType; }; template <> -struct ThriftEnumTypeTraits< +struct ThriftenumTypeTraits< ::facebook::velox::parquet::thrift::FieldRepetitionType::type> { - using ParquetEnum = Repetition; + using Parquetenum = Repetition; }; template <> -struct ThriftEnumTypeTraits< +struct ThriftenumTypeTraits< ::facebook::velox::parquet::thrift::Encoding::type> { - using ParquetEnum = Encoding; + using Parquetenum = Encoding; }; template <> -struct ThriftEnumTypeTraits< +struct ThriftenumTypeTraits< ::facebook::velox::parquet::thrift::PageType::type> { - using ParquetEnum = PageType; + using Parquetenum = PageType; }; template <> -struct ThriftEnumTypeTraits< +struct ThriftenumTypeTraits< ::facebook::velox::parquet::thrift::BoundaryOrder::type> { - using ParquetEnum = BoundaryOrder; + using Parquetenum = BoundaryOrder; }; -// If the parquet file is corrupted it is possible the enum value decoded -// will not be in the range of defined values, which is undefined behaviour. +// If the parquet file is corrupted it is possible the enum value decoded. +// Will not be in the range of defined values, which is undefined behaviour. // This facility prevents this by loading the value as the underlying type // and checking to make sure it is in range. template < - typename EnumType, - typename EnumTypeRaw = typename std::underlying_type::type> -inline static EnumTypeRaw LoadEnumRaw(const EnumType* in) { - EnumTypeRaw raw_value; + typename enumType, + typename enumTypeRaw = typename std::underlying_type::type> +inline static enumTypeRaw loadenumRaw(const enumType* in) { + enumTypeRaw rawValue; // Use memcpy(), as a regular cast would be undefined behaviour on invalid - // values - memcpy(&raw_value, in, sizeof(EnumType)); - return raw_value; + // values. + memcpy(&rawValue, in, sizeof(enumType)); + return rawValue; } template struct SafeLoader { - using ApiTypeEnum = typename ApiType::type; - using ApiTypeRawEnum = typename std::underlying_type::type; + using ApiTypeenum = typename ApiType::type; + using ApiTypeRawenum = typename std::underlying_type::type; template - inline static ApiTypeRawEnum LoadRaw(const ThriftType* in) { + inline static ApiTypeRawenum loadRaw(const ThriftType* in) { static_assert( - sizeof(ApiTypeEnum) == sizeof(ThriftType), + sizeof(ApiTypeenum) == sizeof(ThriftType), "parquet type should always be the same size as thrift type"); - return static_cast(LoadEnumRaw(in)); + return static_cast(loadenumRaw(in)); } template - inline static ApiTypeEnum LoadChecked( + inline static ApiTypeenum loadChecked( const typename std::enable_if::type* in) { - auto raw_value = LoadRaw(in); + auto rawValue = loadRaw(in); if (ARROW_PREDICT_FALSE( - raw_value >= static_cast(ApiType::UNDEFINED))) { - return ApiType::UNDEFINED; + rawValue >= static_cast(ApiType::kUndefined))) { + return ApiType::kUndefined; } - return FromThriftUnsafe(static_cast(raw_value)); + return fromThriftUnsafe(static_cast(rawValue)); } template - inline static ApiTypeEnum LoadChecked( + inline static ApiTypeenum loadChecked( const typename std::enable_if::type* in) { - auto raw_value = LoadRaw(in); + auto rawValue = loadRaw(in); if (ARROW_PREDICT_FALSE( - raw_value >= static_cast(ApiType::UNDEFINED) || - raw_value < 0)) { - return ApiType::UNDEFINED; + rawValue >= static_cast(ApiType::kUndefined) || + rawValue < 0)) { + return ApiType::kUndefined; } - return FromThriftUnsafe(static_cast(raw_value)); + return fromThriftUnsafe(static_cast(rawValue)); } template - inline static ApiTypeEnum Load(const ThriftType* in) { - return LoadChecked::value>(in); + inline static ApiTypeenum load(const ThriftType* in) { + return loadChecked::value>(in); } }; } // namespace internal -// Safe enum loader: will check for invalid enum value before converting +// Safe enum loader: will check for invalid enum value before converting. template < typename ThriftType, - typename ParquetEnum = - typename internal::ThriftEnumTypeTraits::ParquetEnum> -inline typename ParquetEnum::type LoadEnumSafe(const ThriftType* in) { - return internal::SafeLoader::Load(in); + typename Parquetenum = + typename internal::ThriftenumTypeTraits::Parquetenum> +inline typename Parquetenum::type loadenumSafe(const ThriftType* in) { + return internal::SafeLoader::load(in); } -inline typename Compression::type LoadEnumSafe( +inline typename Compression::type loadenumSafe( const facebook::velox::parquet::thrift::CompressionCodec::type* in) { - const auto raw_value = internal::LoadEnumRaw(in); + const auto rawValue = internal::loadenumRaw(in); // Check bounds manually, as Compression::type doesn't have the same values // as facebook::velox::parquet::thrift::CompressionCodec. - const auto min_value = static_cast( + const auto minValue = static_cast( facebook::velox::parquet::thrift::CompressionCodec::UNCOMPRESSED); - const auto max_value = static_cast( + const auto maxValue = static_cast( facebook::velox::parquet::thrift::CompressionCodec::LZ4_RAW); - if (raw_value < min_value || raw_value > max_value) { + if (rawValue < minValue || rawValue > maxValue) { return Compression::UNCOMPRESSED; } - return FromThriftUnsafe(*in); + return fromThriftUnsafe(*in); } -// Safe non-enum converters +// Safe non-enum converters. -static inline AadMetadata FromThrift( +static inline AadMetadata fromThrift( facebook::velox::parquet::thrift::AesGcmV1 aesGcmV1) { return AadMetadata{ aesGcmV1.aad_prefix, @@ -242,7 +242,7 @@ static inline AadMetadata FromThrift( aesGcmV1.supply_aad_prefix}; } -static inline AadMetadata FromThrift( +static inline AadMetadata fromThrift( facebook::velox::parquet::thrift::AesGcmCtrV1 aesGcmCtrV1) { return AadMetadata{ aesGcmCtrV1.aad_prefix, @@ -250,67 +250,68 @@ static inline AadMetadata FromThrift( aesGcmCtrV1.supply_aad_prefix}; } -static inline EncryptionAlgorithm FromThrift( +static inline EncryptionAlgorithm fromThrift( facebook::velox::parquet::thrift::EncryptionAlgorithm encryption) { - EncryptionAlgorithm encryption_algorithm; + EncryptionAlgorithm encryptionAlgorithm; if (encryption.__isset.AES_GCM_V1) { - encryption_algorithm.algorithm = ParquetCipher::AES_GCM_V1; - encryption_algorithm.aad = FromThrift(encryption.AES_GCM_V1); + encryptionAlgorithm.algorithm = ParquetCipher::kAesGcmV1; + encryptionAlgorithm.aad = fromThrift(encryption.AES_GCM_V1); } else if (encryption.__isset.AES_GCM_CTR_V1) { - encryption_algorithm.algorithm = ParquetCipher::AES_GCM_CTR_V1; - encryption_algorithm.aad = FromThrift(encryption.AES_GCM_CTR_V1); + encryptionAlgorithm.algorithm = ParquetCipher::kAesGcmCtrV1; + encryptionAlgorithm.aad = fromThrift(encryption.AES_GCM_CTR_V1); } else { throw ParquetException("Unsupported algorithm"); } - return encryption_algorithm; + return encryptionAlgorithm; } -static inline SortingColumn FromThrift( - facebook::velox::parquet::thrift::SortingColumn thrift_sorting_column) { - SortingColumn sorting_column; - sorting_column.column_idx = thrift_sorting_column.column_idx; - sorting_column.nulls_first = thrift_sorting_column.nulls_first; - sorting_column.descending = thrift_sorting_column.descending; - return sorting_column; +static inline SortingColumn fromThrift( + facebook::velox::parquet::thrift::SortingColumn thriftSortingColumn) { + SortingColumn sortingColumn; + sortingColumn.columnIdx = thriftSortingColumn.column_idx; + sortingColumn.nullsFirst = thriftSortingColumn.nulls_first; + sortingColumn.descending = thriftSortingColumn.descending; + return sortingColumn; } -// ---------------------------------------------------------------------- -// Convert Thrift enums from Parquet enums +// ----------------------------------------------------------------------. +// Convert Thrift enums from Parquet enums. -static inline facebook::velox::parquet::thrift::Type::type ToThrift( +static inline facebook::velox::parquet::thrift::Type::type toThrift( Type::type type) { return static_cast(type); } -static fmt::underlying_t format_as( +static fmt::underlying_t formatAs( ConvertedType::type type) { return fmt::underlying(type); } -static inline facebook::velox::parquet::thrift::ConvertedType::type ToThrift( +static inline facebook::velox::parquet::thrift::ConvertedType::type toThrift( ConvertedType::type type) { - // item 0 is NONE - VELOX_DCHECK_NE(type, ConvertedType::NONE); + // Item 0 is NONE. + const int typeValue = static_cast(type); + VELOX_DCHECK_NE(typeValue, static_cast(ConvertedType::kNone)); // it is forbidden to emit "NA" (PARQUET-1990) - VELOX_DCHECK_NE(type, ConvertedType::NA); - VELOX_DCHECK_NE(type, ConvertedType::UNDEFINED); + VELOX_DCHECK_NE(typeValue, static_cast(ConvertedType::kNa)); + VELOX_DCHECK_NE(typeValue, static_cast(ConvertedType::kUndefined)); return static_cast( - static_cast(type) - 1); + typeValue - 1); } static inline facebook::velox::parquet::thrift::FieldRepetitionType::type -ToThrift(Repetition::type type) { +toThrift(Repetition::type type) { return static_cast< facebook::velox::parquet::thrift::FieldRepetitionType::type>(type); } -static inline facebook::velox::parquet::thrift::Encoding::type ToThrift( +static inline facebook::velox::parquet::thrift::Encoding::type toThrift( Encoding::type type) { return static_cast(type); } -static inline facebook::velox::parquet::thrift::CompressionCodec::type ToThrift( +static inline facebook::velox::parquet::thrift::CompressionCodec::type toThrift( Compression::type type) { switch (type) { case Compression::UNCOMPRESSED: @@ -326,7 +327,7 @@ static inline facebook::velox::parquet::thrift::CompressionCodec::type ToThrift( case Compression::LZ4: return facebook::velox::parquet::thrift::CompressionCodec::LZ4_RAW; case Compression::LZ4_HADOOP: - // Deprecated "LZ4" Parquet compression has Hadoop-specific framing + // Deprecated "LZ4" Parquet compression has Hadoop-specific framing. return facebook::velox::parquet::thrift::CompressionCodec::LZ4; case Compression::ZSTD: return facebook::velox::parquet::thrift::CompressionCodec::ZSTD; @@ -336,12 +337,12 @@ static inline facebook::velox::parquet::thrift::CompressionCodec::type ToThrift( } } -static inline facebook::velox::parquet::thrift::BoundaryOrder::type ToThrift( +static inline facebook::velox::parquet::thrift::BoundaryOrder::type toThrift( BoundaryOrder::type type) { switch (type) { - case BoundaryOrder::Unordered: - case BoundaryOrder::Ascending: - case BoundaryOrder::Descending: + case BoundaryOrder::kUnordered: + case BoundaryOrder::kAscending: + case BoundaryOrder::kDescending: return static_cast( type); default: @@ -350,82 +351,82 @@ static inline facebook::velox::parquet::thrift::BoundaryOrder::type ToThrift( } } -static inline facebook::velox::parquet::thrift::SortingColumn ToThrift( - SortingColumn sorting_column) { - facebook::velox::parquet::thrift::SortingColumn thrift_sorting_column; - thrift_sorting_column.column_idx = sorting_column.column_idx; - thrift_sorting_column.descending = sorting_column.descending; - thrift_sorting_column.nulls_first = sorting_column.nulls_first; - return thrift_sorting_column; +static inline facebook::velox::parquet::thrift::SortingColumn toThrift( + SortingColumn sortingColumn) { + facebook::velox::parquet::thrift::SortingColumn thriftSortingColumn; + thriftSortingColumn.column_idx = sortingColumn.columnIdx; + thriftSortingColumn.descending = sortingColumn.descending; + thriftSortingColumn.nulls_first = sortingColumn.nullsFirst; + return thriftSortingColumn; } -static inline facebook::velox::parquet::thrift::Statistics ToThrift( +static inline facebook::velox::parquet::thrift::Statistics toThrift( const EncodedStatistics& stats) { - facebook::velox::parquet::thrift::Statistics statistics; - if (stats.has_min) { - statistics.__set_min_value(stats.min()); + facebook::velox::parquet::thrift::Statistics Statistics; + if (stats.hasMin) { + Statistics.__set_min_value(stats.min()); // If the order is SIGNED, then the old min value must be set too. - // This for backward compatibility - if (stats.is_signed()) { - statistics.__set_min(stats.min()); + // This for backward compatibility. + if (stats.isSigned()) { + Statistics.__set_min(stats.min()); } } - if (stats.has_max) { - statistics.__set_max_value(stats.max()); + if (stats.hasMax) { + Statistics.__set_max_value(stats.max()); // If the order is SIGNED, then the old max value must be set too. - // This for backward compatibility - if (stats.is_signed()) { - statistics.__set_max(stats.max()); + // This for backward compatibility. + if (stats.isSigned()) { + Statistics.__set_max(stats.max()); } } - if (stats.has_null_count) { - statistics.__set_null_count(stats.null_count); + if (stats.hasNullCount) { + Statistics.__set_null_count(stats.nullCount); } - if (stats.has_distinct_count) { - statistics.__set_distinct_count(stats.distinct_count); + if (stats.hasDistinctCount) { + Statistics.__set_distinct_count(stats.distinctCount); } - return statistics; + return Statistics; } -static inline facebook::velox::parquet::thrift::AesGcmV1 ToAesGcmV1Thrift( +static inline facebook::velox::parquet::thrift::AesGcmV1 toAesGcmV1Thrift( AadMetadata aad) { facebook::velox::parquet::thrift::AesGcmV1 aesGcmV1; - // aad_file_unique is always set - aesGcmV1.__set_aad_file_unique(aad.aad_file_unique); - aesGcmV1.__set_supply_aad_prefix(aad.supply_aad_prefix); - if (!aad.aad_prefix.empty()) { - aesGcmV1.__set_aad_prefix(aad.aad_prefix); + // Aad_file_unique is always set. + aesGcmV1.__set_aad_file_unique(aad.aadFileUnique); + aesGcmV1.__set_supply_aad_prefix(aad.supplyAadPrefix); + if (!aad.aadPrefix.empty()) { + aesGcmV1.__set_aad_prefix(aad.aadPrefix); } return aesGcmV1; } -static inline facebook::velox::parquet::thrift::AesGcmCtrV1 ToAesGcmCtrV1Thrift( +static inline facebook::velox::parquet::thrift::AesGcmCtrV1 toAesGcmCtrV1Thrift( AadMetadata aad) { facebook::velox::parquet::thrift::AesGcmCtrV1 aesGcmCtrV1; - // aad_file_unique is always set - aesGcmCtrV1.__set_aad_file_unique(aad.aad_file_unique); - aesGcmCtrV1.__set_supply_aad_prefix(aad.supply_aad_prefix); - if (!aad.aad_prefix.empty()) { - aesGcmCtrV1.__set_aad_prefix(aad.aad_prefix); + // Aad_file_unique is always set. + aesGcmCtrV1.__set_aad_file_unique(aad.aadFileUnique); + aesGcmCtrV1.__set_supply_aad_prefix(aad.supplyAadPrefix); + if (!aad.aadPrefix.empty()) { + aesGcmCtrV1.__set_aad_prefix(aad.aadPrefix); } return aesGcmCtrV1; } -static inline facebook::velox::parquet::thrift::EncryptionAlgorithm ToThrift( +static inline facebook::velox::parquet::thrift::EncryptionAlgorithm toThrift( EncryptionAlgorithm encryption) { - facebook::velox::parquet::thrift::EncryptionAlgorithm encryption_algorithm; - if (encryption.algorithm == ParquetCipher::AES_GCM_V1) { - encryption_algorithm.__set_AES_GCM_V1(ToAesGcmV1Thrift(encryption.aad)); + facebook::velox::parquet::thrift::EncryptionAlgorithm encryptionAlgorithm; + if (encryption.algorithm == ParquetCipher::kAesGcmV1) { + encryptionAlgorithm.__set_AES_GCM_V1(toAesGcmV1Thrift(encryption.aad)); } else { - encryption_algorithm.__set_AES_GCM_CTR_V1( - ToAesGcmCtrV1Thrift(encryption.aad)); + encryptionAlgorithm.__set_AES_GCM_CTR_V1( + toAesGcmCtrV1Thrift(encryption.aad)); } - return encryption_algorithm; + return encryptionAlgorithm; } -// ---------------------------------------------------------------------- -// Thrift struct serialization / deserialization utilities +// ----------------------------------------------------------------------. +// Thrift struct serialization / deserialization utilities. using ThriftBuffer = apache::thrift::transport::TMemoryBuffer; @@ -433,43 +434,43 @@ class ThriftDeserializer { public: explicit ThriftDeserializer(const ReaderProperties& properties) : ThriftDeserializer( - properties.thrift_string_size_limit(), - properties.thrift_container_size_limit()) {} + properties.thriftStringSizeLimit(), + properties.thriftContainerSizeLimit()) {} - ThriftDeserializer(int32_t string_size_limit, int32_t container_size_limit) - : string_size_limit_(string_size_limit), - container_size_limit_(container_size_limit) {} + ThriftDeserializer(int32_t stringSizeLimit, int32_t containerSizeLimit) + : stringSizeLimit_(stringSizeLimit), + containerSizeLimit_(containerSizeLimit) {} // Deserialize a thrift message from buf/len. buf/len must at least contain // all the bytes needed to store the thrift message. On return, len will be // set to the actual length of the header. template - void DeserializeMessage( + void deserializeMessage( const uint8_t* buf, uint32_t* len, - T* deserialized_msg, - const std::shared_ptr& decryptor = NULLPTR) { - if (decryptor == NULLPTR) { - // thrift message is not encrypted - DeserializeUnencryptedMessage(buf, len, deserialized_msg); + T* deserializedMsg, + const std::shared_ptr& Decryptor = NULLPTR) { + if (Decryptor == NULLPTR) { + // Thrift message is not encrypted. + deserializeUnencryptedMessage(buf, len, deserializedMsg); } else { - // thrift message is encrypted + // Thrift message is encrypted. uint32_t clen; clen = *len; - // decrypt - auto decrypted_buffer = - std::static_pointer_cast(AllocateBuffer( - decryptor->pool(), - static_cast(clen - decryptor->CiphertextSizeDelta()))); - const uint8_t* cipher_buf = buf; - uint32_t decrypted_buffer_len = - decryptor->Decrypt(cipher_buf, 0, decrypted_buffer->mutable_data()); - if (decrypted_buffer_len <= 0) { + // Decrypt. + auto decryptedBuffer = + std::static_pointer_cast(allocateBuffer( + Decryptor->pool(), + static_cast(clen - Decryptor->ciphertextSizeDelta()))); + const uint8_t* cipherBuf = buf; + uint32_t decryptedBufferLen = + Decryptor->decrypt(cipherBuf, 0, decryptedBuffer->mutable_data()); + if (decryptedBufferLen <= 0) { throw ParquetException("Couldn't decrypt buffer\n"); } - *len = decrypted_buffer_len + decryptor->CiphertextSizeDelta(); - DeserializeUnencryptedMessage( - decrypted_buffer->data(), &decrypted_buffer_len, deserialized_msg); + *len = decryptedBufferLen + Decryptor->ciphertextSizeDelta(); + deserializeUnencryptedMessage( + decryptedBuffer->data(), &decryptedBufferLen, deserializedMsg); } } @@ -477,7 +478,7 @@ class ThriftDeserializer { // On Thrift 0.14.0+, we want to use TConfiguration to raise the max message // size limit (ARROW-13655). If we wanted to protect against huge messages, // we could do it ourselves since we know the message size up front. - std::shared_ptr CreateReadOnlyMemoryBuffer( + std::shared_ptr createReadOnlyMemoryBuffer( uint8_t* buf, uint32_t len) { #if PARQUET_THRIFT_VERSION_MAJOR > 0 || PARQUET_THRIFT_VERSION_MINOR >= 14 @@ -491,84 +492,84 @@ class ThriftDeserializer { } template - void DeserializeUnencryptedMessage( + void deserializeUnencryptedMessage( const uint8_t* buf, uint32_t* len, - T* deserialized_msg) { + T* deserializedMsg) { // Deserialize msg bytes into c++ thrift msg using memory transport. - auto tmem_transport = - CreateReadOnlyMemoryBuffer(const_cast(buf), *len); + auto tmemTransport = + createReadOnlyMemoryBuffer(const_cast(buf), *len); apache::thrift::protocol::TCompactProtocolFactoryT - tproto_factory; - // Protect against CPU and memory bombs - tproto_factory.setStringSizeLimit(string_size_limit_); - tproto_factory.setContainerSizeLimit(container_size_limit_); - auto tproto = tproto_factory.getProtocol(tmem_transport); + tprotoFactory; + // Protect against CPU and memory bombs. + tprotoFactory.setStringSizeLimit(stringSizeLimit_); + tprotoFactory.setContainerSizeLimit(containerSizeLimit_); + auto tproto = tprotoFactory.getProtocol(tmemTransport); try { - deserialized_msg->read(tproto.get()); + deserializedMsg->read(tproto.get()); } catch (std::exception& e) { std::stringstream ss; ss << "Couldn't deserialize thrift: " << e.what() << "\n"; throw ParquetException(ss.str()); } - uint32_t bytes_left = tmem_transport->available_read(); - *len = *len - bytes_left; + uint32_t bytesLeft = tmemTransport->available_read(); + *len = *len - bytesLeft; } - const int32_t string_size_limit_; - const int32_t container_size_limit_; + const int32_t stringSizeLimit_; + const int32_t containerSizeLimit_; }; /// Utility class to serialize thrift objects to a binary format. This object /// should be reused if possible to reuse the underlying memory. -/// Note: thrift will encode NULLs into the serialized buffer so it is not valid -/// to treat it as a string. +/// Note: thrift will encode NULLs into the serialized buffer so it is not +/// valid. To treat it as a string. class ThriftSerializer { public: - explicit ThriftSerializer(int initial_buffer_size = 1024) - : mem_buffer_(new ThriftBuffer(initial_buffer_size)) { + explicit ThriftSerializer(int initialBufferSize = 1024) + : memBuffer_(std::make_shared(initialBufferSize)) { apache::thrift::protocol::TCompactProtocolFactoryT factory; - protocol_ = factory.getProtocol(mem_buffer_); + protocol_ = factory.getProtocol(memBuffer_); } /// Serialize obj into a memory buffer. The result is returned in buffer/len. /// The memory returned is owned by this object and will be invalid when /// another object is serialized. template - void SerializeToBuffer(const T* obj, uint32_t* len, uint8_t** buffer) { - SerializeObject(obj); - mem_buffer_->getBuffer(buffer, len); + void serializeToBuffer(const T* obj, uint32_t* len, uint8_t** buffer) { + serializeObject(obj); + memBuffer_->getBuffer(buffer, len); } template - void SerializeToString(const T* obj, std::string* result) { - SerializeObject(obj); - *result = mem_buffer_->getBufferAsString(); + void serializeToString(const T* obj, std::string* result) { + serializeObject(obj); + *result = memBuffer_->getBufferAsString(); } template - int64_t Serialize( + int64_t serialize( const T* obj, ArrowOutputStream* out, - const std::shared_ptr& encryptor = NULLPTR) { - uint8_t* out_buffer; - uint32_t out_length; - SerializeToBuffer(obj, &out_length, &out_buffer); - - // obj is not encrypted - if (encryptor == NULLPTR) { - PARQUET_THROW_NOT_OK(out->Write(out_buffer, out_length)); - return static_cast(out_length); + const std::shared_ptr& Encryptor = NULLPTR) { + uint8_t* outBuffer; + uint32_t outLength; + serializeToBuffer(obj, &outLength, &outBuffer); + + // Obj is not encrypted. + if (Encryptor == NULLPTR) { + PARQUET_THROW_NOT_OK(out->Write(outBuffer, outLength)); + return static_cast(outLength); } else { // obj is encrypted - return SerializeEncryptedObj(out, out_buffer, out_length, encryptor); + return serializeEncryptedObj(out, outBuffer, outLength, Encryptor); } } private: template - void SerializeObject(const T* obj) { + void serializeObject(const T* obj) { try { - mem_buffer_->resetBuffer(); + memBuffer_->resetBuffer(); obj->write(protocol_.get()); } catch (std::exception& e) { std::stringstream ss; @@ -577,24 +578,24 @@ class ThriftSerializer { } } - int64_t SerializeEncryptedObj( + int64_t serializeEncryptedObj( ArrowOutputStream* out, - uint8_t* out_buffer, - uint32_t out_length, - const std::shared_ptr& encryptor) { - auto cipher_buffer = - std::static_pointer_cast(AllocateBuffer( - encryptor->pool(), + uint8_t* outBuffer, + uint32_t outLength, + const std::shared_ptr& Encryptor) { + auto cipherBuffer = + std::static_pointer_cast(allocateBuffer( + Encryptor->pool(), static_cast( - encryptor->CiphertextSizeDelta() + out_length))); - int cipher_buffer_len = encryptor->Encrypt( - out_buffer, out_length, cipher_buffer->mutable_data()); + Encryptor->ciphertextSizeDelta() + outLength))); + int cipherBufferLen = + Encryptor->encrypt(outBuffer, outLength, cipherBuffer->mutable_data()); - PARQUET_THROW_NOT_OK(out->Write(cipher_buffer->data(), cipher_buffer_len)); - return static_cast(cipher_buffer_len); + PARQUET_THROW_NOT_OK(out->Write(cipherBuffer->data(), cipherBufferLen)); + return static_cast(cipherBufferLen); } - std::shared_ptr mem_buffer_; + std::shared_ptr memBuffer_; std::shared_ptr protocol_; }; diff --git a/velox/dwio/parquet/writer/arrow/Types.cpp b/velox/dwio/parquet/writer/arrow/Types.cpp index 929f664f309..4dc5f465883 100644 --- a/velox/dwio/parquet/writer/arrow/Types.cpp +++ b/velox/dwio/parquet/writer/arrow/Types.cpp @@ -32,12 +32,12 @@ using arrow::internal::checked_cast; namespace facebook::velox::parquet::arrow { -fmt::underlying_t format_as( - LogicalType::TimeUnit::unit unit) { +fmt::underlying_t formatAs( + LogicalType::TimeUnit::Unit unit) { return fmt::underlying(unit); } -bool IsCodecSupported(Compression::type codec) { +bool isCodecSupported(Compression::type codec) { switch (codec) { case Compression::UNCOMPRESSED: case Compression::SNAPPY: @@ -52,13 +52,13 @@ bool IsCodecSupported(Compression::type codec) { } } -std::unique_ptr GetCodec(Compression::type codec) { - return GetCodec(codec, util::CodecOptions()); +std::unique_ptr getCodec(Compression::type codec) { + return getCodec(codec, util::CodecOptions()); } -std::unique_ptr GetCodec( +std::unique_ptr getCodec( Compression::type codec, - const util::CodecOptions& codec_options) { + const util::CodecOptions& codecOptions) { std::unique_ptr result; if (codec == Compression::LZO) { throw ParquetException( @@ -66,479 +66,473 @@ std::unique_ptr GetCodec( "general, it is currently not supported by the C++ implementation."); } - if (!IsCodecSupported(codec)) { + if (!isCodecSupported(codec)) { std::stringstream ss; - ss << "Codec type " << util::Codec::GetCodecAsString(codec) + ss << "Codec type " << util::Codec::getCodecAsString(codec) << " not supported in Parquet format"; throw ParquetException(ss.str()); } - PARQUET_ASSIGN_OR_THROW(result, util::Codec::Create(codec, codec_options)); + PARQUET_ASSIGN_OR_THROW(result, util::Codec::create(codec, codecOptions)); return result; } -// use compression level to create Codec -std::unique_ptr GetCodec( +// Use compression level to create Codec. +std::unique_ptr getCodec( Compression::type codec, - int compression_level) { - return GetCodec(codec, util::CodecOptions{compression_level}); + int compressionLevel) { + return getCodec(codec, util::CodecOptions{compressionLevel}); } -bool PageCanUseChecksum(PageType::type pageType) { +bool pageCanUseChecksum(PageType::type pageType) { switch (pageType) { - case PageType::type::DATA_PAGE: - case PageType::type::DATA_PAGE_V2: - case PageType::type::DICTIONARY_PAGE: + case PageType::type::kDataPage: + case PageType::type::kDataPageV2: + case PageType::type::kDictionaryPage: return true; default: return false; } } -std::string FormatStatValue(Type::type parquet_type, ::std::string_view val) { +std::string formatStatValue(Type::type parquetType, ::std::string_view val) { std::stringstream result; const char* bytes = val.data(); - switch (parquet_type) { - case Type::BOOLEAN: + switch (parquetType) { + case Type::kBoolean: result << reinterpret_cast(bytes)[0]; break; - case Type::INT32: + case Type::kInt32: result << reinterpret_cast(bytes)[0]; break; - case Type::INT64: + case Type::kInt64: result << reinterpret_cast(bytes)[0]; break; - case Type::DOUBLE: + case Type::kDouble: result << reinterpret_cast(bytes)[0]; break; - case Type::FLOAT: + case Type::kFloat: result << reinterpret_cast(bytes)[0]; break; - case Type::INT96: { - auto const i32_val = reinterpret_cast(bytes); - result << i32_val[0] << " " << i32_val[1] << " " << i32_val[2]; + case Type::kInt96: { + auto const i32Val = reinterpret_cast(bytes); + result << i32Val[0] << " " << i32Val[1] << " " << i32Val[2]; break; } - case Type::BYTE_ARRAY: { + case Type::kByteArray: { return std::string(val); } - case Type::FIXED_LEN_BYTE_ARRAY: { + case Type::kFixedLenByteArray: { return std::string(val); } - case Type::UNDEFINED: + case Type::kUndefined: default: break; } return result.str(); } -std::string EncodingToString(Encoding::type t) { +std::string encodingToString(Encoding::type t) { switch (t) { - case Encoding::PLAIN: + case Encoding::kPlain: return "PLAIN"; - case Encoding::PLAIN_DICTIONARY: + case Encoding::kPlainDictionary: return "PLAIN_DICTIONARY"; - case Encoding::RLE: + case Encoding::kRle: return "RLE"; - case Encoding::BIT_PACKED: + case Encoding::kBitPacked: return "BIT_PACKED"; - case Encoding::DELTA_BINARY_PACKED: + case Encoding::kDeltaBinaryPacked: return "DELTA_BINARY_PACKED"; - case Encoding::DELTA_LENGTH_BYTE_ARRAY: + case Encoding::kDeltaLengthByteArray: return "DELTA_LENGTH_BYTE_ARRAY"; - case Encoding::DELTA_BYTE_ARRAY: + case Encoding::kDeltaByteArray: return "DELTA_BYTE_ARRAY"; - case Encoding::RLE_DICTIONARY: + case Encoding::kRleDictionary: return "RLE_DICTIONARY"; - case Encoding::BYTE_STREAM_SPLIT: + case Encoding::kByteStreamSplit: return "BYTE_STREAM_SPLIT"; default: return "UNKNOWN"; } } -std::string TypeToString(Type::type t) { +std::string typeToString(Type::type t) { switch (t) { - case Type::BOOLEAN: + case Type::kBoolean: return "BOOLEAN"; - case Type::INT32: + case Type::kInt32: return "INT32"; - case Type::INT64: + case Type::kInt64: return "INT64"; - case Type::INT96: + case Type::kInt96: return "INT96"; - case Type::FLOAT: + case Type::kFloat: return "FLOAT"; - case Type::DOUBLE: + case Type::kDouble: return "DOUBLE"; - case Type::BYTE_ARRAY: + case Type::kByteArray: return "BYTE_ARRAY"; - case Type::FIXED_LEN_BYTE_ARRAY: + case Type::kFixedLenByteArray: return "FIXED_LEN_BYTE_ARRAY"; - case Type::UNDEFINED: + case Type::kUndefined: default: return "UNKNOWN"; } } -std::string ConvertedTypeToString(ConvertedType::type t) { +std::string convertedTypeToString(ConvertedType::type t) { switch (t) { - case ConvertedType::NONE: + case ConvertedType::kNone: return "NONE"; - case ConvertedType::UTF8: + case ConvertedType::kUtf8: return "UTF8"; - case ConvertedType::MAP: + case ConvertedType::kMap: return "MAP"; - case ConvertedType::MAP_KEY_VALUE: + case ConvertedType::kMapKeyValue: return "MAP_KEY_VALUE"; - case ConvertedType::LIST: + case ConvertedType::kList: return "LIST"; - case ConvertedType::ENUM: + case ConvertedType::kEnum: return "ENUM"; - case ConvertedType::DECIMAL: + case ConvertedType::kDecimal: return "DECIMAL"; - case ConvertedType::DATE: + case ConvertedType::kDate: return "DATE"; - case ConvertedType::TIME_MILLIS: + case ConvertedType::kTimeMillis: return "TIME_MILLIS"; - case ConvertedType::TIME_MICROS: + case ConvertedType::kTimeMicros: return "TIME_MICROS"; - case ConvertedType::TIMESTAMP_MILLIS: + case ConvertedType::kTimestampMillis: return "TIMESTAMP_MILLIS"; - case ConvertedType::TIMESTAMP_MICROS: + case ConvertedType::kTimestampMicros: return "TIMESTAMP_MICROS"; - case ConvertedType::UINT_8: + case ConvertedType::kUint8: return "UINT_8"; - case ConvertedType::UINT_16: + case ConvertedType::kUint16: return "UINT_16"; - case ConvertedType::UINT_32: + case ConvertedType::kUint32: return "UINT_32"; - case ConvertedType::UINT_64: + case ConvertedType::kUint64: return "UINT_64"; - case ConvertedType::INT_8: + case ConvertedType::kInt8: return "INT_8"; - case ConvertedType::INT_16: + case ConvertedType::kInt16: return "INT_16"; - case ConvertedType::INT_32: + case ConvertedType::kInt32: return "INT_32"; - case ConvertedType::INT_64: + case ConvertedType::kInt64: return "INT_64"; - case ConvertedType::JSON: + case ConvertedType::kJson: return "JSON"; - case ConvertedType::BSON: + case ConvertedType::kBson: return "BSON"; - case ConvertedType::INTERVAL: + case ConvertedType::kInterval: return "INTERVAL"; - case ConvertedType::UNDEFINED: + case ConvertedType::kUndefined: default: return "UNKNOWN"; } } -int GetTypeByteSize(Type::type parquet_type) { - switch (parquet_type) { - case Type::BOOLEAN: - return type_traits::value_byte_size; - case Type::INT32: - return type_traits::value_byte_size; - case Type::INT64: - return type_traits::value_byte_size; - case Type::INT96: - return type_traits::value_byte_size; - case Type::DOUBLE: - return type_traits::value_byte_size; - case Type::FLOAT: - return type_traits::value_byte_size; - case Type::BYTE_ARRAY: - return type_traits::value_byte_size; - case Type::FIXED_LEN_BYTE_ARRAY: - return type_traits::value_byte_size; - case Type::UNDEFINED: +int getTypeByteSize(Type::type parquetType) { + switch (parquetType) { + case Type::kBoolean: + return TypeTraits::valueByteSize; + case Type::kInt32: + return TypeTraits::valueByteSize; + case Type::kInt64: + return TypeTraits::valueByteSize; + case Type::kInt96: + return TypeTraits::valueByteSize; + case Type::kDouble: + return TypeTraits::valueByteSize; + case Type::kFloat: + return TypeTraits::valueByteSize; + case Type::kByteArray: + return TypeTraits::valueByteSize; + case Type::kFixedLenByteArray: + return TypeTraits::valueByteSize; + case Type::kUndefined: default: return 0; } return 0; } -// Return the Sort Order of the Parquet Physical Types -SortOrder::type DefaultSortOrder(Type::type primitive) { +// Return the Sort Order of the Parquet Physical Types. +SortOrder::type defaultSortOrder(Type::type primitive) { switch (primitive) { - case Type::BOOLEAN: - case Type::INT32: - case Type::INT64: - case Type::FLOAT: - case Type::DOUBLE: - return SortOrder::SIGNED; - case Type::BYTE_ARRAY: - case Type::FIXED_LEN_BYTE_ARRAY: - return SortOrder::UNSIGNED; - case Type::INT96: - case Type::UNDEFINED: - return SortOrder::UNKNOWN; + case Type::kBoolean: + case Type::kInt32: + case Type::kInt64: + case Type::kFloat: + case Type::kDouble: + return SortOrder::kSigned; + case Type::kByteArray: + case Type::kFixedLenByteArray: + return SortOrder::kUnsigned; + case Type::kInt96: + case Type::kUndefined: + return SortOrder::kUnknown; } - return SortOrder::UNKNOWN; + return SortOrder::kUnknown; } -// Return the SortOrder of the Parquet Types using Logical or Physical Types -SortOrder::type GetSortOrder( +// Return the SortOrder of the Parquet Types using Logical or Physical Types. +SortOrder::type getSortOrder( ConvertedType::type converted, Type::type primitive) { - if (converted == ConvertedType::NONE) - return DefaultSortOrder(primitive); + if (converted == ConvertedType::kNone) + return defaultSortOrder(primitive); switch (converted) { - case ConvertedType::INT_8: - case ConvertedType::INT_16: - case ConvertedType::INT_32: - case ConvertedType::INT_64: - case ConvertedType::DATE: - case ConvertedType::TIME_MICROS: - case ConvertedType::TIME_MILLIS: - case ConvertedType::TIMESTAMP_MICROS: - case ConvertedType::TIMESTAMP_MILLIS: - return SortOrder::SIGNED; - case ConvertedType::UINT_8: - case ConvertedType::UINT_16: - case ConvertedType::UINT_32: - case ConvertedType::UINT_64: - case ConvertedType::ENUM: - case ConvertedType::UTF8: - case ConvertedType::BSON: - case ConvertedType::JSON: - return SortOrder::UNSIGNED; - case ConvertedType::DECIMAL: - case ConvertedType::LIST: - case ConvertedType::MAP: - case ConvertedType::MAP_KEY_VALUE: - case ConvertedType::INTERVAL: - case ConvertedType::NONE: // required instead of default - case ConvertedType::NA: // required instead of default - case ConvertedType::UNDEFINED: - return SortOrder::UNKNOWN; + case ConvertedType::kInt8: + case ConvertedType::kInt16: + case ConvertedType::kInt32: + case ConvertedType::kInt64: + case ConvertedType::kDate: + case ConvertedType::kTimeMicros: + case ConvertedType::kTimeMillis: + case ConvertedType::kTimestampMicros: + case ConvertedType::kTimestampMillis: + return SortOrder::kSigned; + case ConvertedType::kUint8: + case ConvertedType::kUint16: + case ConvertedType::kUint32: + case ConvertedType::kUint64: + case ConvertedType::kEnum: + case ConvertedType::kUtf8: + case ConvertedType::kBson: + case ConvertedType::kJson: + return SortOrder::kUnsigned; + case ConvertedType::kDecimal: + case ConvertedType::kList: + case ConvertedType::kMap: + case ConvertedType::kMapKeyValue: + case ConvertedType::kInterval: + case ConvertedType::kNone: // required instead of default + case ConvertedType::kNa: // required instead of default + case ConvertedType::kUndefined: + return SortOrder::kUnknown; } - return SortOrder::UNKNOWN; + return SortOrder::kUnknown; } -SortOrder::type GetSortOrder( - const std::shared_ptr& logical_type, +SortOrder::type getSortOrder( + const std::shared_ptr& logicalType, Type::type primitive) { - SortOrder::type o = SortOrder::UNKNOWN; - if (logical_type && logical_type->is_valid()) { + SortOrder::type o = SortOrder::kUnknown; + if (logicalType && logicalType->isValid()) { o = - (logical_type->is_none() ? DefaultSortOrder(primitive) - : logical_type->sort_order()); + (logicalType->isNone() ? defaultSortOrder(primitive) + : logicalType->sortOrder()); } return o; } -ColumnOrder ColumnOrder::undefined_ = ColumnOrder(ColumnOrder::UNDEFINED); -ColumnOrder ColumnOrder::type_defined_ = - ColumnOrder(ColumnOrder::TYPE_DEFINED_ORDER); - -// Static methods for LogicalType class - -std::shared_ptr LogicalType::FromConvertedType( - const ConvertedType::type converted_type, - const schema::DecimalMetadata converted_decimal_metadata) { - switch (converted_type) { - case ConvertedType::UTF8: - return StringLogicalType::Make(); - case ConvertedType::MAP_KEY_VALUE: - case ConvertedType::MAP: - return MapLogicalType::Make(); - case ConvertedType::LIST: - return ListLogicalType::Make(); - case ConvertedType::ENUM: - return EnumLogicalType::Make(); - case ConvertedType::DECIMAL: - return DecimalLogicalType::Make( - converted_decimal_metadata.precision, - converted_decimal_metadata.scale); - case ConvertedType::DATE: - return DateLogicalType::Make(); - case ConvertedType::TIME_MILLIS: - return TimeLogicalType::Make(true, LogicalType::TimeUnit::MILLIS); - case ConvertedType::TIME_MICROS: - return TimeLogicalType::Make(true, LogicalType::TimeUnit::MICROS); - case ConvertedType::TIMESTAMP_MILLIS: - return TimestampLogicalType::Make( - true, - LogicalType::TimeUnit::MILLIS, - /*is_from_converted_type=*/true, - /*force_set_converted_type=*/false); - case ConvertedType::TIMESTAMP_MICROS: - return TimestampLogicalType::Make( - true, - LogicalType::TimeUnit::MICROS, - /*is_from_converted_type=*/true, - /*force_set_converted_type=*/false); - case ConvertedType::INTERVAL: - return IntervalLogicalType::Make(); - case ConvertedType::INT_8: - return IntLogicalType::Make(8, true); - case ConvertedType::INT_16: - return IntLogicalType::Make(16, true); - case ConvertedType::INT_32: - return IntLogicalType::Make(32, true); - case ConvertedType::INT_64: - return IntLogicalType::Make(64, true); - case ConvertedType::UINT_8: - return IntLogicalType::Make(8, false); - case ConvertedType::UINT_16: - return IntLogicalType::Make(16, false); - case ConvertedType::UINT_32: - return IntLogicalType::Make(32, false); - case ConvertedType::UINT_64: - return IntLogicalType::Make(64, false); - case ConvertedType::JSON: - return JSONLogicalType::Make(); - case ConvertedType::BSON: - return BSONLogicalType::Make(); - case ConvertedType::NA: - return NullLogicalType::Make(); - case ConvertedType::NONE: - return NoLogicalType::Make(); - case ConvertedType::UNDEFINED: - return UndefinedLogicalType::Make(); +ColumnOrder ColumnOrder::undefined_ = ColumnOrder(ColumnOrder::kUndefined); +ColumnOrder ColumnOrder::typeDefined_ = + ColumnOrder(ColumnOrder::kTypeDefinedOrder); + +// Static methods for LogicalType class. + +std::shared_ptr LogicalType::fromConvertedType( + const ConvertedType::type convertedType, + const schema::DecimalMetadata convertedDecimalMetadata) { + switch (convertedType) { + case ConvertedType::kUtf8: + return StringLogicalType::make(); + case ConvertedType::kMapKeyValue: + case ConvertedType::kMap: + return MapLogicalType::make(); + case ConvertedType::kList: + return ListLogicalType::make(); + case ConvertedType::kEnum: + return EnumLogicalType::make(); + case ConvertedType::kDecimal: + return DecimalLogicalType::make( + convertedDecimalMetadata.precision, convertedDecimalMetadata.scale); + case ConvertedType::kDate: + return DateLogicalType::make(); + case ConvertedType::kTimeMillis: + return TimeLogicalType::make(true, LogicalType::TimeUnit::kMillis); + case ConvertedType::kTimeMicros: + return TimeLogicalType::make(true, LogicalType::TimeUnit::kMicros); + case ConvertedType::kTimestampMillis: + return TimestampLogicalType::make( + true, LogicalType::TimeUnit::kMillis, true, false); + case ConvertedType::kTimestampMicros: + return TimestampLogicalType::make( + true, LogicalType::TimeUnit::kMicros, true, false); + case ConvertedType::kInterval: + return IntervalLogicalType::make(); + case ConvertedType::kInt8: + return IntLogicalType::make(8, true); + case ConvertedType::kInt16: + return IntLogicalType::make(16, true); + case ConvertedType::kInt32: + return IntLogicalType::make(32, true); + case ConvertedType::kInt64: + return IntLogicalType::make(64, true); + case ConvertedType::kUint8: + return IntLogicalType::make(8, false); + case ConvertedType::kUint16: + return IntLogicalType::make(16, false); + case ConvertedType::kUint32: + return IntLogicalType::make(32, false); + case ConvertedType::kUint64: + return IntLogicalType::make(64, false); + case ConvertedType::kJson: + return JsonLogicalType::make(); + case ConvertedType::kBson: + return BsonLogicalType::make(); + case ConvertedType::kNa: + return NullLogicalType::make(); + case ConvertedType::kNone: + return NoLogicalType::make(); + case ConvertedType::kUndefined: + return UndefinedLogicalType::make(); } - return UndefinedLogicalType::Make(); + return UndefinedLogicalType::make(); } -std::shared_ptr LogicalType::FromThrift( +std::shared_ptr LogicalType::fromThrift( const facebook::velox::parquet::thrift::LogicalType& type) { if (type.__isset.STRING) { - return StringLogicalType::Make(); + return StringLogicalType::make(); } else if (type.__isset.MAP) { - return MapLogicalType::Make(); + return MapLogicalType::make(); } else if (type.__isset.LIST) { - return ListLogicalType::Make(); + return ListLogicalType::make(); } else if (type.__isset.ENUM) { - return EnumLogicalType::Make(); + return EnumLogicalType::make(); } else if (type.__isset.DECIMAL) { - return DecimalLogicalType::Make(type.DECIMAL.precision, type.DECIMAL.scale); + return DecimalLogicalType::make(type.DECIMAL.precision, type.DECIMAL.scale); } else if (type.__isset.DATE) { - return DateLogicalType::Make(); + return DateLogicalType::make(); } else if (type.__isset.TIME) { - LogicalType::TimeUnit::unit unit; + LogicalType::TimeUnit::Unit unit; if (type.TIME.unit.__isset.MILLIS) { - unit = LogicalType::TimeUnit::MILLIS; + unit = LogicalType::TimeUnit::kMillis; } else if (type.TIME.unit.__isset.MICROS) { - unit = LogicalType::TimeUnit::MICROS; + unit = LogicalType::TimeUnit::kMicros; } else if (type.TIME.unit.__isset.NANOS) { - unit = LogicalType::TimeUnit::NANOS; + unit = LogicalType::TimeUnit::kNanos; } else { - unit = LogicalType::TimeUnit::UNKNOWN; + unit = LogicalType::TimeUnit::kUnknown; } - return TimeLogicalType::Make(type.TIME.isAdjustedToUTC, unit); + return TimeLogicalType::make(type.TIME.isAdjustedToUTC, unit); } else if (type.__isset.TIMESTAMP) { - LogicalType::TimeUnit::unit unit; + LogicalType::TimeUnit::Unit unit; if (type.TIMESTAMP.unit.__isset.MILLIS) { - unit = LogicalType::TimeUnit::MILLIS; + unit = LogicalType::TimeUnit::kMillis; } else if (type.TIMESTAMP.unit.__isset.MICROS) { - unit = LogicalType::TimeUnit::MICROS; + unit = LogicalType::TimeUnit::kMicros; } else if (type.TIMESTAMP.unit.__isset.NANOS) { - unit = LogicalType::TimeUnit::NANOS; + unit = LogicalType::TimeUnit::kNanos; } else { - unit = LogicalType::TimeUnit::UNKNOWN; + unit = LogicalType::TimeUnit::kUnknown; } - return TimestampLogicalType::Make(type.TIMESTAMP.isAdjustedToUTC, unit); + return TimestampLogicalType::make(type.TIMESTAMP.isAdjustedToUTC, unit); // TODO(tpboudreau): activate the commented code after parquet.thrift - // recognizes IntervalType as a LogicalType - //} else if (type.__isset.INTERVAL) { - // return IntervalLogicalType::Make(); + // recognizes IntervalType as a LogicalType. + // } else if (type.__isset.INTERVAL) { + // return IntervalLogicalType::make(); } else if (type.__isset.INTEGER) { - return IntLogicalType::Make( + return IntLogicalType::make( static_cast(type.INTEGER.bitWidth), type.INTEGER.isSigned); } else if (type.__isset.UNKNOWN) { - return NullLogicalType::Make(); + return NullLogicalType::make(); } else if (type.__isset.JSON) { - return JSONLogicalType::Make(); + return JsonLogicalType::make(); } else if (type.__isset.BSON) { - return BSONLogicalType::Make(); + return BsonLogicalType::make(); } else if (type.__isset.UUID) { - return UUIDLogicalType::Make(); + return UuidLogicalType::make(); } else { throw ParquetException( "Metadata contains Thrift LogicalType that is not recognized"); } } -std::shared_ptr LogicalType::String() { - return StringLogicalType::Make(); +std::shared_ptr LogicalType::string() { + return StringLogicalType::make(); } -std::shared_ptr LogicalType::Map() { - return MapLogicalType::Make(); +std::shared_ptr LogicalType::map() { + return MapLogicalType::make(); } -std::shared_ptr LogicalType::List() { - return ListLogicalType::Make(); +std::shared_ptr LogicalType::list() { + return ListLogicalType::make(); } -std::shared_ptr LogicalType::Enum() { - return EnumLogicalType::Make(); +std::shared_ptr LogicalType::enumType() { + return EnumLogicalType::make(); } -std::shared_ptr LogicalType::Decimal( +std::shared_ptr LogicalType::decimal( int32_t precision, int32_t scale) { - return DecimalLogicalType::Make(precision, scale); + return DecimalLogicalType::make(precision, scale); } -std::shared_ptr LogicalType::Date() { - return DateLogicalType::Make(); +std::shared_ptr LogicalType::date() { + return DateLogicalType::make(); } -std::shared_ptr LogicalType::Time( - bool is_adjusted_to_utc, - LogicalType::TimeUnit::unit time_unit) { - VELOX_DCHECK_NE(time_unit, LogicalType::TimeUnit::UNKNOWN); - return TimeLogicalType::Make(is_adjusted_to_utc, time_unit); +std::shared_ptr LogicalType::time( + bool isAdjustedToUtc, + LogicalType::TimeUnit::Unit timeUnit) { + VELOX_DCHECK_NE( + static_cast(timeUnit), + static_cast(LogicalType::TimeUnit::kUnknown)); + return TimeLogicalType::make(isAdjustedToUtc, timeUnit); } -std::shared_ptr LogicalType::Timestamp( - bool is_adjusted_to_utc, - LogicalType::TimeUnit::unit time_unit, - bool is_from_converted_type, - bool force_set_converted_type) { - VELOX_DCHECK_NE(time_unit, LogicalType::TimeUnit::UNKNOWN); - return TimestampLogicalType::Make( - is_adjusted_to_utc, - time_unit, - is_from_converted_type, - force_set_converted_type); +std::shared_ptr LogicalType::timestamp( + bool isAdjustedToUtc, + LogicalType::TimeUnit::Unit timeUnit, + bool isFromConvertedType, + bool forceSetConvertedType) { + VELOX_DCHECK_NE( + static_cast(timeUnit), + static_cast(LogicalType::TimeUnit::kUnknown)); + return TimestampLogicalType::make( + isAdjustedToUtc, timeUnit, isFromConvertedType, forceSetConvertedType); } -std::shared_ptr LogicalType::Interval() { - return IntervalLogicalType::Make(); +std::shared_ptr LogicalType::interval() { + return IntervalLogicalType::make(); } -std::shared_ptr LogicalType::Int( - int bit_width, - bool is_signed) { +std::shared_ptr LogicalType::intType( + int bitWidth, + bool isSigned) { VELOX_DCHECK( - bit_width == 64 || bit_width == 32 || bit_width == 16 || bit_width == 8); - return IntLogicalType::Make(bit_width, is_signed); + bitWidth == 64 || bitWidth == 32 || bitWidth == 16 || bitWidth == 8); + return IntLogicalType::make(bitWidth, isSigned); } -std::shared_ptr LogicalType::Null() { - return NullLogicalType::Make(); +std::shared_ptr LogicalType::nullType() { + return NullLogicalType::make(); } -std::shared_ptr LogicalType::JSON() { - return JSONLogicalType::Make(); +std::shared_ptr LogicalType::json() { + return JsonLogicalType::make(); } -std::shared_ptr LogicalType::BSON() { - return BSONLogicalType::Make(); +std::shared_ptr LogicalType::bson() { + return BsonLogicalType::make(); } -std::shared_ptr LogicalType::UUID() { - return UUIDLogicalType::Make(); +std::shared_ptr LogicalType::uuid() { + return UuidLogicalType::make(); } -std::shared_ptr LogicalType::None() { - return NoLogicalType::Make(); +std::shared_ptr LogicalType::none() { + return NoLogicalType::make(); } /* @@ -554,44 +548,44 @@ std::shared_ptr LogicalType::None() { * overridden. */ -// LogicalTypeImpl base class +// LogicalTypeImpl base class. class LogicalType::Impl { public: - virtual bool is_applicable( - parquet::Type::type primitive_type, - int32_t primitive_length = -1) const = 0; + virtual bool isApplicable( + parquet::Type::type primitiveType, + int32_t primitiveLength = -1) const = 0; - virtual bool is_compatible( - ConvertedType::type converted_type, - schema::DecimalMetadata converted_decimal_metadata = {false, -1, -1}) + virtual bool isCompatible( + ConvertedType::type convertedType, + schema::DecimalMetadata convertedDecimalMetadata = {false, -1, -1}) const = 0; - virtual ConvertedType::type ToConvertedType( - schema::DecimalMetadata* out_decimal_metadata) const = 0; + virtual ConvertedType::type toConvertedType( + schema::DecimalMetadata* outDecimalMetadata) const = 0; - virtual std::string ToString() const = 0; + virtual std::string toString() const = 0; - virtual bool is_serialized() const { + virtual bool isSerialized() const { return !( - type_ == LogicalType::Type::NONE || - type_ == LogicalType::Type::UNDEFINED); + type_ == LogicalType::Type::kNone || + type_ == LogicalType::Type::kUndefined); } - virtual std::string ToJSON() const { + virtual std::string toJson() const { std::stringstream json; - json << R"({"Type": ")" << ToString() << R"("})"; + json << R"({"Type": ")" << toString() << R"("})"; return json.str(); } - virtual facebook::velox::parquet::thrift::LogicalType ToThrift() const { - // logical types inheriting this method should never be serialized + virtual facebook::velox::parquet::thrift::LogicalType toThrift() const { + // Logical types inheriting this method should never be serialized. std::stringstream ss; - ss << "Logical type " << ToString() << " should not be serialized"; + ss << "Logical type " << toString() << " should not be serialized"; throw ParquetException(ss.str()); } - virtual bool Equals(const LogicalType& other) const { + virtual bool equals(const LogicalType& other) const { return other.type() == type_; } @@ -599,7 +593,7 @@ class LogicalType::Impl { return type_; } - SortOrder::type sort_order() const { + SortOrder::type sortOrder() const { return order_; } @@ -628,9 +622,9 @@ class LogicalType::Impl { class Interval; class Int; class Null; - class JSON; - class BSON; - class UUID; + class Json; + class Bson; + class Uuid; class No; class Undefined; @@ -639,202 +633,203 @@ class LogicalType::Impl { Impl() = default; private: - LogicalType::Type::type type_ = LogicalType::Type::UNDEFINED; - SortOrder::type order_ = SortOrder::UNKNOWN; + LogicalType::Type::type type_ = LogicalType::Type::kUndefined; + SortOrder::type order_ = SortOrder::kUnknown; }; -// Special methods for public LogicalType class +// Special methods for public LogicalType class. LogicalType::LogicalType() = default; LogicalType::~LogicalType() noexcept = default; -// Delegating methods for public LogicalType class +// Delegating methods for public LogicalType class. -bool LogicalType::is_applicable( - parquet::Type::type primitive_type, - int32_t primitive_length) const { - return impl_->is_applicable(primitive_type, primitive_length); +bool LogicalType::isApplicable( + parquet::Type::type primitiveType, + int32_t primitiveLength) const { + return impl_->isApplicable(primitiveType, primitiveLength); } -bool LogicalType::is_compatible( - ConvertedType::type converted_type, - schema::DecimalMetadata converted_decimal_metadata) const { - return impl_->is_compatible(converted_type, converted_decimal_metadata); +bool LogicalType::isCompatible( + ConvertedType::type convertedType, + schema::DecimalMetadata convertedDecimalMetadata) const { + return impl_->isCompatible(convertedType, convertedDecimalMetadata); } -ConvertedType::type LogicalType::ToConvertedType( - schema::DecimalMetadata* out_decimal_metadata) const { - return impl_->ToConvertedType(out_decimal_metadata); +ConvertedType::type LogicalType::toConvertedType( + schema::DecimalMetadata* outDecimalMetadata) const { + return impl_->toConvertedType(outDecimalMetadata); } -std::string LogicalType::ToString() const { - return impl_->ToString(); +std::string LogicalType::toString() const { + return impl_->toString(); } -std::string LogicalType::ToJSON() const { - return impl_->ToJSON(); +std::string LogicalType::toJson() const { + return impl_->toJson(); } -facebook::velox::parquet::thrift::LogicalType LogicalType::ToThrift() const { - return impl_->ToThrift(); +facebook::velox::parquet::thrift::LogicalType LogicalType::toThrift() const { + return impl_->toThrift(); } -bool LogicalType::Equals(const LogicalType& other) const { - return impl_->Equals(other); +bool LogicalType::equals(const LogicalType& other) const { + return impl_->equals(other); } LogicalType::Type::type LogicalType::type() const { return impl_->type(); } -SortOrder::type LogicalType::sort_order() const { - return impl_->sort_order(); +SortOrder::type LogicalType::sortOrder() const { + return impl_->sortOrder(); } -// Type checks for public LogicalType class +// Type checks for public LogicalType class. -bool LogicalType::is_string() const { - return impl_->type() == LogicalType::Type::STRING; +bool LogicalType::isString() const { + return impl_->type() == LogicalType::Type::kString; } -bool LogicalType::is_map() const { - return impl_->type() == LogicalType::Type::MAP; +bool LogicalType::isMap() const { + return impl_->type() == LogicalType::Type::kMap; } -bool LogicalType::is_list() const { - return impl_->type() == LogicalType::Type::LIST; +bool LogicalType::isList() const { + return impl_->type() == LogicalType::Type::kList; } -bool LogicalType::is_enum() const { - return impl_->type() == LogicalType::Type::ENUM; +bool LogicalType::isEnum() const { + return impl_->type() == LogicalType::Type::kEnum; } -bool LogicalType::is_decimal() const { - return impl_->type() == LogicalType::Type::DECIMAL; +bool LogicalType::isDecimal() const { + return impl_->type() == LogicalType::Type::kDecimal; } -bool LogicalType::is_date() const { - return impl_->type() == LogicalType::Type::DATE; +bool LogicalType::isDate() const { + return impl_->type() == LogicalType::Type::kDate; } -bool LogicalType::is_time() const { - return impl_->type() == LogicalType::Type::TIME; +bool LogicalType::isTime() const { + return impl_->type() == LogicalType::Type::kTime; } -bool LogicalType::is_timestamp() const { - return impl_->type() == LogicalType::Type::TIMESTAMP; +bool LogicalType::isTimestamp() const { + return impl_->type() == LogicalType::Type::kTimestamp; } -bool LogicalType::is_interval() const { - return impl_->type() == LogicalType::Type::INTERVAL; +bool LogicalType::isInterval() const { + return impl_->type() == LogicalType::Type::kInterval; } -bool LogicalType::is_int() const { - return impl_->type() == LogicalType::Type::INT; +bool LogicalType::isInt() const { + return impl_->type() == LogicalType::Type::kInt; } -bool LogicalType::is_null() const { - return impl_->type() == LogicalType::Type::NIL; +bool LogicalType::isNull() const { + return impl_->type() == LogicalType::Type::kNil; } -bool LogicalType::is_JSON() const { - return impl_->type() == LogicalType::Type::JSON; +bool LogicalType::isJson() const { + return impl_->type() == LogicalType::Type::kJson; } -bool LogicalType::is_BSON() const { - return impl_->type() == LogicalType::Type::BSON; +bool LogicalType::isBson() const { + return impl_->type() == LogicalType::Type::kBson; } -bool LogicalType::is_UUID() const { - return impl_->type() == LogicalType::Type::UUID; +bool LogicalType::isUuid() const { + return impl_->type() == LogicalType::Type::kUuid; } -bool LogicalType::is_none() const { - return impl_->type() == LogicalType::Type::NONE; +bool LogicalType::isNone() const { + return impl_->type() == LogicalType::Type::kNone; } -bool LogicalType::is_valid() const { - return impl_->type() != LogicalType::Type::UNDEFINED; +bool LogicalType::isValid() const { + return impl_->type() != LogicalType::Type::kUndefined; } -bool LogicalType::is_invalid() const { - return !is_valid(); +bool LogicalType::isInvalid() const { + return !isValid(); } -bool LogicalType::is_nested() const { - return (impl_->type() == LogicalType::Type::LIST) || - (impl_->type() == LogicalType::Type::MAP); +bool LogicalType::isNested() const { + return (impl_->type() == LogicalType::Type::kList) || + (impl_->type() == LogicalType::Type::kMap); } -bool LogicalType::is_nonnested() const { - return !is_nested(); +bool LogicalType::isNonnested() const { + return !isNested(); } -bool LogicalType::is_serialized() const { - return impl_->is_serialized(); +bool LogicalType::isSerialized() const { + return impl_->isSerialized(); } -// LogicalTypeImpl intermediate "compatibility" classes +// LogicalTypeImpl intermediate "compatibility" classes. class LogicalType::Impl::Compatible : public virtual LogicalType::Impl { protected: Compatible() = default; }; -#define set_decimal_metadata(m___, i___, p___, s___) \ - { \ - if (m___) { \ - (m___)->isset = (i___); \ - (m___)->scale = (s___); \ - (m___)->precision = (p___); \ - } \ +#define setDecimalMetadata(m___, i___, p___, s___) \ + { \ + if (m___) { \ + (m___)->isset = (i___); \ + (m___)->scale = (s___); \ + (m___)->precision = (p___); \ + } \ } -#define reset_decimal_metadata(m___) \ - { set_decimal_metadata(m___, false, -1, -1); } +#define resetDecimalMetadata(m___) \ + { \ + setDecimalMetadata(m___, false, -1, -1); \ + } -// For logical types that always translate to the same converted type +// For logical types that always translate to the same converted type. class LogicalType::Impl::SimpleCompatible : public virtual LogicalType::Impl::Compatible { public: - bool is_compatible( - ConvertedType::type converted_type, - schema::DecimalMetadata converted_decimal_metadata) const override { - return (converted_type == converted_type_) && - !converted_decimal_metadata.isset; + bool isCompatible( + ConvertedType::type convertedType, + schema::DecimalMetadata convertedDecimalMetadata) const override { + return (convertedType == convertedType_) && !convertedDecimalMetadata.isset; } - ConvertedType::type ToConvertedType( - schema::DecimalMetadata* out_decimal_metadata) const override { - reset_decimal_metadata(out_decimal_metadata); - return converted_type_; + ConvertedType::type toConvertedType( + schema::DecimalMetadata* outDecimalMetadata) const override { + resetDecimalMetadata(outDecimalMetadata); + return convertedType_; } protected: - explicit SimpleCompatible(ConvertedType::type c) : converted_type_(c) {} + explicit SimpleCompatible(ConvertedType::type c) : convertedType_(c) {} private: - ConvertedType::type converted_type_ = ConvertedType::NA; + ConvertedType::type convertedType_ = ConvertedType::kNa; }; -// For logical types that have no corresponding converted type +// For logical types that have no corresponding converted type. class LogicalType::Impl::Incompatible : public virtual LogicalType::Impl { public: - bool is_compatible( - ConvertedType::type converted_type, - schema::DecimalMetadata converted_decimal_metadata) const override { - return (converted_type == ConvertedType::NONE || - converted_type == ConvertedType::NA) && - !converted_decimal_metadata.isset; + bool isCompatible( + ConvertedType::type convertedType, + schema::DecimalMetadata convertedDecimalMetadata) const override { + return (convertedType == ConvertedType::kNone || + convertedType == ConvertedType::kNa) && + !convertedDecimalMetadata.isset; } - ConvertedType::type ToConvertedType( - schema::DecimalMetadata* out_decimal_metadata) const override { - reset_decimal_metadata(out_decimal_metadata); - return ConvertedType::NONE; + ConvertedType::type toConvertedType( + schema::DecimalMetadata* outDecimalMetadata) const override { + resetDecimalMetadata(outDecimalMetadata); + return ConvertedType::kNone; } protected: Incompatible() = default; }; -// LogicalTypeImpl intermediate "applicability" classes +// LogicalTypeImpl intermediate "applicability" classes. class LogicalType::Impl::Applicable : public virtual LogicalType::Impl { protected: Applicable() = default; }; -// For logical types that can apply only to a single -// physical type +// For logical types that can apply only to a single. +// Physical type. class LogicalType::Impl::SimpleApplicable : public virtual LogicalType::Impl::Applicable { public: - bool is_applicable( - parquet::Type::type primitive_type, - int32_t primitive_length = -1) const override { - return primitive_type == type_; + bool isApplicable( + parquet::Type::type primitiveType, + int32_t primitiveLength = -1) const override { + return primitiveType == type_; } protected: @@ -844,15 +839,15 @@ class LogicalType::Impl::SimpleApplicable parquet::Type::type type_; }; -// For logical types that can apply only to a particular -// physical type and physical length combination +// For logical types that can apply only to a particular. +// Physical type and physical length combination. class LogicalType::Impl::TypeLengthApplicable : public virtual LogicalType::Impl::Applicable { public: - bool is_applicable( - parquet::Type::type primitive_type, - int32_t primitive_length = -1) const override { - return primitive_type == type_ && primitive_length == length_; + bool isApplicable( + parquet::Type::type primitiveType, + int32_t primitiveLength = -1) const override { + return primitiveType == type_ && primitiveLength == length_; } protected: @@ -864,13 +859,13 @@ class LogicalType::Impl::TypeLengthApplicable int32_t length_; }; -// For logical types that can apply to any physical type +// For logical types that can apply to any physical type. class LogicalType::Impl::UniversalApplicable : public virtual LogicalType::Impl::Applicable { public: - bool is_applicable( - parquet::Type::type primitive_type, - int32_t primitive_length = -1) const override { + bool isApplicable( + parquet::Type::type primitiveType, + int32_t primitiveLength = -1) const override { return true; } @@ -878,13 +873,13 @@ class LogicalType::Impl::UniversalApplicable UniversalApplicable() = default; }; -// For logical types that can never apply to any primitive -// physical type +// For logical types that can never apply to any primitive. +// Physical type. class LogicalType::Impl::Inapplicable : public virtual LogicalType::Impl { public: - bool is_applicable( - parquet::Type::type primitive_type, - int32_t primitive_length = -1) const override { + bool isApplicable( + parquet::Type::type primitiveType, + int32_t primitiveLength = -1) const override { return false; } @@ -892,15 +887,15 @@ class LogicalType::Impl::Inapplicable : public virtual LogicalType::Impl { Inapplicable() = default; }; -// LogicalType implementation final classes +// LogicalType implementation final classes. #define OVERRIDE_TOSTRING(n___) \ - std::string ToString() const override { \ + std::string toString() const override { \ return #n___; \ } #define OVERRIDE_TOTHRIFT(t___, s___) \ - facebook::velox::parquet::thrift::LogicalType ToThrift() const override { \ + facebook::velox::parquet::thrift::LogicalType toThrift() const override { \ facebook::velox::parquet::thrift::LogicalType type; \ facebook::velox::parquet::thrift::t___ subtype; \ type.__set_##s___(subtype); \ @@ -918,20 +913,20 @@ class LogicalType::Impl::String final private: String() - : LogicalType::Impl(LogicalType::Type::STRING, SortOrder::UNSIGNED), - LogicalType::Impl::SimpleCompatible(ConvertedType::UTF8), - LogicalType::Impl::SimpleApplicable(parquet::Type::BYTE_ARRAY) {} + : LogicalType::Impl(LogicalType::Type::kString, SortOrder::kUnsigned), + LogicalType::Impl::SimpleCompatible(ConvertedType::kUtf8), + LogicalType::Impl::SimpleApplicable(parquet::Type::kByteArray) {} }; -// Each public logical type class's Make() creation method instantiates a -// corresponding LogicalType::Impl::* object and installs that implementation in -// the logical type it returns. +// Each public logical type class's Make() creation method instantiates a. +// Corresponding LogicalType::Impl::* object and installs that implementation +// in. The logical type it returns. #define GENERATE_MAKE(a___) \ - std::shared_ptr a___##LogicalType::Make() { \ - auto* logical_type = new a___##LogicalType(); \ - logical_type->impl_.reset(new LogicalType::Impl::a___()); \ - return std::shared_ptr(logical_type); \ + std::shared_ptr a___##LogicalType::make() { \ + auto* logicalType = new a___##LogicalType(); \ + logicalType->impl_.reset(new LogicalType::Impl::a___()); \ + return std::shared_ptr(logicalType); \ } GENERATE_MAKE(String) @@ -941,12 +936,12 @@ class LogicalType::Impl::Map final : public LogicalType::Impl::SimpleCompatible, public: friend class MapLogicalType; - bool is_compatible( - ConvertedType::type converted_type, - schema::DecimalMetadata converted_decimal_metadata) const override { - return (converted_type == ConvertedType::MAP || - converted_type == ConvertedType::MAP_KEY_VALUE) && - !converted_decimal_metadata.isset; + bool isCompatible( + ConvertedType::type convertedType, + schema::DecimalMetadata convertedDecimalMetadata) const override { + return (convertedType == ConvertedType::kMap || + convertedType == ConvertedType::kMapKeyValue) && + !convertedDecimalMetadata.isset; } OVERRIDE_TOSTRING(Map) @@ -954,8 +949,8 @@ class LogicalType::Impl::Map final : public LogicalType::Impl::SimpleCompatible, private: Map() - : LogicalType::Impl(LogicalType::Type::MAP, SortOrder::UNKNOWN), - LogicalType::Impl::SimpleCompatible(ConvertedType::MAP) {} + : LogicalType::Impl(LogicalType::Type::kMap, SortOrder::kUnknown), + LogicalType::Impl::SimpleCompatible(ConvertedType::kMap) {} }; GENERATE_MAKE(Map) @@ -971,8 +966,8 @@ class LogicalType::Impl::List final private: List() - : LogicalType::Impl(LogicalType::Type::LIST, SortOrder::UNKNOWN), - LogicalType::Impl::SimpleCompatible(ConvertedType::LIST) {} + : LogicalType::Impl(LogicalType::Type::kList, SortOrder::kUnknown), + LogicalType::Impl::SimpleCompatible(ConvertedType::kList) {} }; GENERATE_MAKE(List) @@ -988,34 +983,34 @@ class LogicalType::Impl::Enum final private: Enum() - : LogicalType::Impl(LogicalType::Type::ENUM, SortOrder::UNSIGNED), - LogicalType::Impl::SimpleCompatible(ConvertedType::ENUM), - LogicalType::Impl::SimpleApplicable(parquet::Type::BYTE_ARRAY) {} + : LogicalType::Impl(LogicalType::Type::kEnum, SortOrder::kUnsigned), + LogicalType::Impl::SimpleCompatible(ConvertedType::kEnum), + LogicalType::Impl::SimpleApplicable(parquet::Type::kByteArray) {} }; GENERATE_MAKE(Enum) // The parameterized logical types (currently Decimal, Time, Timestamp, and Int) -// generally can't reuse the simple method implementations available in the base -// and intermediate classes and must (re)implement them all +// Generally can't reuse the simple method implementations available in the +// base. And intermediate classes and must (re)implement them all. class LogicalType::Impl::Decimal final : public LogicalType::Impl::Compatible, public LogicalType::Impl::Applicable { public: friend class DecimalLogicalType; - bool is_applicable( - parquet::Type::type primitive_type, - int32_t primitive_length = -1) const override; - bool is_compatible( - ConvertedType::type converted_type, - schema::DecimalMetadata converted_decimal_metadata) const override; - ConvertedType::type ToConvertedType( - schema::DecimalMetadata* out_decimal_metadata) const override; - std::string ToString() const override; - std::string ToJSON() const override; - facebook::velox::parquet::thrift::LogicalType ToThrift() const override; - bool Equals(const LogicalType& other) const override; + bool isApplicable( + parquet::Type::type primitiveType, + int32_t primitiveLength = -1) const override; + bool isCompatible( + ConvertedType::type convertedType, + schema::DecimalMetadata convertedDecimalMetadata) const override; + ConvertedType::type toConvertedType( + schema::DecimalMetadata* outDecimalMetadata) const override; + std::string toString() const override; + std::string toJson() const override; + facebook::velox::parquet::thrift::LogicalType toThrift() const override; + bool equals(const LogicalType& other) const override; int32_t precision() const { return precision_; @@ -1026,38 +1021,38 @@ class LogicalType::Impl::Decimal final : public LogicalType::Impl::Compatible, private: Decimal(int32_t p, int32_t s) - : LogicalType::Impl(LogicalType::Type::DECIMAL, SortOrder::SIGNED), + : LogicalType::Impl(LogicalType::Type::kDecimal, SortOrder::kSigned), precision_(p), scale_(s) {} int32_t precision_ = -1; int32_t scale_ = -1; }; -bool LogicalType::Impl::Decimal::is_applicable( - parquet::Type::type primitive_type, - int32_t primitive_length) const { +bool LogicalType::Impl::Decimal::isApplicable( + parquet::Type::type primitiveType, + int32_t primitiveLength) const { bool ok = false; - switch (primitive_type) { - case parquet::Type::INT32: { + switch (primitiveType) { + case parquet::Type::kInt32: { ok = (1 <= precision_) && (precision_ <= 9); } break; - case parquet::Type::INT64: { + case parquet::Type::kInt64: { ok = (1 <= precision_) && (precision_ <= 18); if (precision_ < 10) { - // FIXME(tpb): warn that INT32 could be used + // FIXME(tpb): warn that INT32 could be used. } } break; - case parquet::Type::FIXED_LEN_BYTE_ARRAY: { - // If the primitive length is larger than this we will overflow int32 when - // calculating precision. - if (primitive_length <= 0 || primitive_length > 891723282) { + case parquet::Type::kFixedLenByteArray: { + // If the primitive length is larger than this we will overflow int32 + // when. Calculating precision. + if (primitiveLength <= 0 || primitiveLength > 891723282) { ok = false; break; } ok = precision_ <= static_cast(std::floor( - std::log10(2) * ((8.0 * primitive_length) - 1.0))); + std::log10(2) * ((8.0 * primitiveLength) - 1.0))); } break; - case parquet::Type::BYTE_ARRAY: { + case parquet::Type::kByteArray: { ok = true; } break; default: { @@ -1066,28 +1061,28 @@ bool LogicalType::Impl::Decimal::is_applicable( return ok; } -bool LogicalType::Impl::Decimal::is_compatible( - ConvertedType::type converted_type, - schema::DecimalMetadata converted_decimal_metadata) const { - return converted_type == ConvertedType::DECIMAL && - (converted_decimal_metadata.isset && - converted_decimal_metadata.scale == scale_ && - converted_decimal_metadata.precision == precision_); +bool LogicalType::Impl::Decimal::isCompatible( + ConvertedType::type convertedType, + schema::DecimalMetadata convertedDecimalMetadata) const { + return convertedType == ConvertedType::kDecimal && + (convertedDecimalMetadata.isset && + convertedDecimalMetadata.scale == scale_ && + convertedDecimalMetadata.precision == precision_); } -ConvertedType::type LogicalType::Impl::Decimal::ToConvertedType( - schema::DecimalMetadata* out_decimal_metadata) const { - set_decimal_metadata(out_decimal_metadata, true, precision_, scale_); - return ConvertedType::DECIMAL; +ConvertedType::type LogicalType::Impl::Decimal::toConvertedType( + schema::DecimalMetadata* outDecimalMetadata) const { + setDecimalMetadata(outDecimalMetadata, true, precision_, scale_); + return ConvertedType::kDecimal; } -std::string LogicalType::Impl::Decimal::ToString() const { +std::string LogicalType::Impl::Decimal::toString() const { std::stringstream type; type << "Decimal(precision=" << precision_ << ", scale=" << scale_ << ")"; return type.str(); } -std::string LogicalType::Impl::Decimal::ToJSON() const { +std::string LogicalType::Impl::Decimal::toJson() const { std::stringstream json; json << R"({"Type": "Decimal", "precision": )" << precision_ << R"(, "scale": )" << scale_ << "}"; @@ -1095,27 +1090,27 @@ std::string LogicalType::Impl::Decimal::ToJSON() const { } facebook::velox::parquet::thrift::LogicalType -LogicalType::Impl::Decimal::ToThrift() const { +LogicalType::Impl::Decimal::toThrift() const { facebook::velox::parquet::thrift::LogicalType type; - facebook::velox::parquet::thrift::DecimalType decimal_type; - decimal_type.__set_precision(precision_); - decimal_type.__set_scale(scale_); - type.__set_DECIMAL(decimal_type); + facebook::velox::parquet::thrift::DecimalType decimalType; + decimalType.__set_precision(precision_); + decimalType.__set_scale(scale_); + type.__set_DECIMAL(decimalType); return type; } -bool LogicalType::Impl::Decimal::Equals(const LogicalType& other) const { +bool LogicalType::Impl::Decimal::equals(const LogicalType& other) const { bool eq = false; - if (other.is_decimal()) { - const auto& other_decimal = checked_cast(other); + if (other.isDecimal()) { + const auto& otherDecimal = checked_cast(other); eq = - (precision_ == other_decimal.precision() && - scale_ == other_decimal.scale()); + (precision_ == otherDecimal.precision() && + scale_ == otherDecimal.scale()); } return eq; } -std::shared_ptr DecimalLogicalType::Make( +std::shared_ptr DecimalLogicalType::make( int32_t precision, int32_t scale) { if (precision < 1) { @@ -1127,9 +1122,9 @@ std::shared_ptr DecimalLogicalType::Make( "Scale must be a non-negative integer that does not exceed precision for " "Decimal logical type"); } - auto* logical_type = new DecimalLogicalType(); - logical_type->impl_.reset(new LogicalType::Impl::Decimal(precision, scale)); - return std::shared_ptr(logical_type); + auto* logicalType = new DecimalLogicalType(); + logicalType->impl_.reset(new LogicalType::Impl::Decimal(precision, scale)); + return std::shared_ptr(logicalType); } int32_t DecimalLogicalType::precision() const { @@ -1151,164 +1146,165 @@ class LogicalType::Impl::Date final private: Date() - : LogicalType::Impl(LogicalType::Type::DATE, SortOrder::SIGNED), - LogicalType::Impl::SimpleCompatible(ConvertedType::DATE), - LogicalType::Impl::SimpleApplicable(parquet::Type::INT32) {} + : LogicalType::Impl(LogicalType::Type::kDate, SortOrder::kSigned), + LogicalType::Impl::SimpleCompatible(ConvertedType::kDate), + LogicalType::Impl::SimpleApplicable(parquet::Type::kInt32) {} }; GENERATE_MAKE(Date) -#define time_unit_string(u___) \ - ((u___) == LogicalType::TimeUnit::MILLIS \ - ? "milliseconds" \ - : ((u___) == LogicalType::TimeUnit::MICROS \ - ? "microseconds" \ - : ((u___) == LogicalType::TimeUnit::NANOS ? "nanoseconds" \ - : "unknown"))) +#define timeUnitString(u___) \ + ((u___) == LogicalType::TimeUnit::kMillis \ + ? "milliseconds" \ + : ((u___) == LogicalType::TimeUnit::kMicros \ + ? "microseconds" \ + : ((u___) == LogicalType::TimeUnit::kNanos ? "nanoseconds" \ + : "unknown"))) class LogicalType::Impl::Time final : public LogicalType::Impl::Compatible, public LogicalType::Impl::Applicable { public: friend class TimeLogicalType; - bool is_applicable( - parquet::Type::type primitive_type, - int32_t primitive_length = -1) const override; - bool is_compatible( - ConvertedType::type converted_type, - schema::DecimalMetadata converted_decimal_metadata) const override; - ConvertedType::type ToConvertedType( - schema::DecimalMetadata* out_decimal_metadata) const override; - std::string ToString() const override; - std::string ToJSON() const override; - facebook::velox::parquet::thrift::LogicalType ToThrift() const override; - bool Equals(const LogicalType& other) const override; - - bool is_adjusted_to_utc() const { + bool isApplicable( + parquet::Type::type primitiveType, + int32_t primitiveLength = -1) const override; + bool isCompatible( + ConvertedType::type convertedType, + schema::DecimalMetadata convertedDecimalMetadata) const override; + ConvertedType::type toConvertedType( + schema::DecimalMetadata* outDecimalMetadata) const override; + std::string toString() const override; + std::string toJson() const override; + facebook::velox::parquet::thrift::LogicalType toThrift() const override; + bool equals(const LogicalType& other) const override; + + bool isAdjustedToUtc() const { return adjusted_; } - LogicalType::TimeUnit::unit time_unit() const { + LogicalType::TimeUnit::Unit timeUnit() const { return unit_; } private: - Time(bool a, LogicalType::TimeUnit::unit u) - : LogicalType::Impl(LogicalType::Type::TIME, SortOrder::SIGNED), + Time(bool a, LogicalType::TimeUnit::Unit u) + : LogicalType::Impl(LogicalType::Type::kTime, SortOrder::kSigned), adjusted_(a), unit_(u) {} bool adjusted_ = false; - LogicalType::TimeUnit::unit unit_; + LogicalType::TimeUnit::Unit unit_; }; -bool LogicalType::Impl::Time::is_applicable( - parquet::Type::type primitive_type, - int32_t primitive_length) const { - return (primitive_type == parquet::Type::INT32 && - unit_ == LogicalType::TimeUnit::MILLIS) || - (primitive_type == parquet::Type::INT64 && - (unit_ == LogicalType::TimeUnit::MICROS || - unit_ == LogicalType::TimeUnit::NANOS)); +bool LogicalType::Impl::Time::isApplicable( + parquet::Type::type primitiveType, + int32_t primitiveLength) const { + return (primitiveType == parquet::Type::kInt32 && + unit_ == LogicalType::TimeUnit::kMillis) || + (primitiveType == parquet::Type::kInt64 && + (unit_ == LogicalType::TimeUnit::kMicros || + unit_ == LogicalType::TimeUnit::kNanos)); } -bool LogicalType::Impl::Time::is_compatible( - ConvertedType::type converted_type, - schema::DecimalMetadata converted_decimal_metadata) const { - if (converted_decimal_metadata.isset) { +bool LogicalType::Impl::Time::isCompatible( + ConvertedType::type convertedType, + schema::DecimalMetadata convertedDecimalMetadata) const { + if (convertedDecimalMetadata.isset) { return false; - } else if (adjusted_ && unit_ == LogicalType::TimeUnit::MILLIS) { - return converted_type == ConvertedType::TIME_MILLIS; - } else if (adjusted_ && unit_ == LogicalType::TimeUnit::MICROS) { - return converted_type == ConvertedType::TIME_MICROS; + } else if (adjusted_ && unit_ == LogicalType::TimeUnit::kMillis) { + return convertedType == ConvertedType::kTimeMillis; + } else if (adjusted_ && unit_ == LogicalType::TimeUnit::kMicros) { + return convertedType == ConvertedType::kTimeMicros; } else { - return (converted_type == ConvertedType::NONE) || - (converted_type == ConvertedType::NA); + return (convertedType == ConvertedType::kNone) || + (convertedType == ConvertedType::kNa); } } -ConvertedType::type LogicalType::Impl::Time::ToConvertedType( - schema::DecimalMetadata* out_decimal_metadata) const { - reset_decimal_metadata(out_decimal_metadata); +ConvertedType::type LogicalType::Impl::Time::toConvertedType( + schema::DecimalMetadata* outDecimalMetadata) const { + resetDecimalMetadata(outDecimalMetadata); if (adjusted_) { - if (unit_ == LogicalType::TimeUnit::MILLIS) { - return ConvertedType::TIME_MILLIS; - } else if (unit_ == LogicalType::TimeUnit::MICROS) { - return ConvertedType::TIME_MICROS; + if (unit_ == LogicalType::TimeUnit::kMillis) { + return ConvertedType::kTimeMillis; + } else if (unit_ == LogicalType::TimeUnit::kMicros) { + return ConvertedType::kTimeMicros; } } - return ConvertedType::NONE; + return ConvertedType::kNone; } -std::string LogicalType::Impl::Time::ToString() const { +std::string LogicalType::Impl::Time::toString() const { std::stringstream type; type << "Time(isAdjustedToUTC=" << std::boolalpha << adjusted_ - << ", timeUnit=" << time_unit_string(unit_) << ")"; + << ", timeUnit=" << timeUnitString(unit_) << ")"; return type.str(); } -std::string LogicalType::Impl::Time::ToJSON() const { +std::string LogicalType::Impl::Time::toJson() const { std::stringstream json; json << R"({"Type": "Time", "isAdjustedToUTC": )" << std::boolalpha - << adjusted_ << R"(, "timeUnit": ")" << time_unit_string(unit_) - << R"("})"; + << adjusted_ << R"(, "timeUnit": ")" << timeUnitString(unit_) << R"("})"; return json.str(); } facebook::velox::parquet::thrift::LogicalType -LogicalType::Impl::Time::ToThrift() const { +LogicalType::Impl::Time::toThrift() const { facebook::velox::parquet::thrift::LogicalType type; - facebook::velox::parquet::thrift::TimeType time_type; - facebook::velox::parquet::thrift::TimeUnit time_unit; - VELOX_DCHECK_NE(unit_, LogicalType::TimeUnit::UNKNOWN); - if (unit_ == LogicalType::TimeUnit::MILLIS) { + facebook::velox::parquet::thrift::TimeType timeType; + facebook::velox::parquet::thrift::TimeUnit timeUnit; + VELOX_DCHECK_NE( + static_cast(unit_), + static_cast(LogicalType::TimeUnit::kUnknown)); + if (unit_ == LogicalType::TimeUnit::kMillis) { facebook::velox::parquet::thrift::MilliSeconds millis; - time_unit.__set_MILLIS(millis); - } else if (unit_ == LogicalType::TimeUnit::MICROS) { + timeUnit.__set_MILLIS(millis); + } else if (unit_ == LogicalType::TimeUnit::kMicros) { facebook::velox::parquet::thrift::MicroSeconds micros; - time_unit.__set_MICROS(micros); - } else if (unit_ == LogicalType::TimeUnit::NANOS) { + timeUnit.__set_MICROS(micros); + } else if (unit_ == LogicalType::TimeUnit::kNanos) { facebook::velox::parquet::thrift::NanoSeconds nanos; - time_unit.__set_NANOS(nanos); + timeUnit.__set_NANOS(nanos); } - time_type.__set_isAdjustedToUTC(adjusted_); - time_type.__set_unit(time_unit); - type.__set_TIME(time_type); + timeType.__set_isAdjustedToUTC(adjusted_); + timeType.__set_unit(timeUnit); + type.__set_TIME(timeType); return type; } -bool LogicalType::Impl::Time::Equals(const LogicalType& other) const { +bool LogicalType::Impl::Time::equals(const LogicalType& other) const { bool eq = false; - if (other.is_time()) { - const auto& other_time = checked_cast(other); + if (other.isTime()) { + const auto& otherTime = checked_cast(other); eq = - (adjusted_ == other_time.is_adjusted_to_utc() && - unit_ == other_time.time_unit()); + (adjusted_ == otherTime.isAdjustedToUtc() && + unit_ == otherTime.timeUnit()); } return eq; } -std::shared_ptr TimeLogicalType::Make( - bool is_adjusted_to_utc, - LogicalType::TimeUnit::unit time_unit) { - if (time_unit == LogicalType::TimeUnit::MILLIS || - time_unit == LogicalType::TimeUnit::MICROS || - time_unit == LogicalType::TimeUnit::NANOS) { - auto* logical_type = new TimeLogicalType(); - logical_type->impl_.reset( - new LogicalType::Impl::Time(is_adjusted_to_utc, time_unit)); - return std::shared_ptr(logical_type); +std::shared_ptr TimeLogicalType::make( + bool isAdjustedToUtc, + LogicalType::TimeUnit::Unit timeUnit) { + if (timeUnit == LogicalType::TimeUnit::kMillis || + timeUnit == LogicalType::TimeUnit::kMicros || + timeUnit == LogicalType::TimeUnit::kNanos) { + auto* logicalType = new TimeLogicalType(); + logicalType->impl_.reset( + new LogicalType::Impl::Time(isAdjustedToUtc, timeUnit)); + return std::shared_ptr(logicalType); } else { throw ParquetException( "TimeUnit must be one of MILLIS, MICROS, or NANOS for Time logical type"); } } -bool TimeLogicalType::is_adjusted_to_utc() const { +bool TimeLogicalType::isAdjustedToUtc() const { return (dynamic_cast(*impl_)) - .is_adjusted_to_utc(); + .isAdjustedToUtc(); } -LogicalType::TimeUnit::unit TimeLogicalType::time_unit() const { - return (dynamic_cast(*impl_)).time_unit(); +LogicalType::TimeUnit::Unit TimeLogicalType::timeUnit() const { + return (dynamic_cast(*impl_)).timeUnit(); } class LogicalType::Impl::Timestamp final @@ -1317,183 +1313,180 @@ class LogicalType::Impl::Timestamp final public: friend class TimestampLogicalType; - bool is_serialized() const override; - bool is_compatible( - ConvertedType::type converted_type, - schema::DecimalMetadata converted_decimal_metadata) const override; - ConvertedType::type ToConvertedType( - schema::DecimalMetadata* out_decimal_metadata) const override; - std::string ToString() const override; - std::string ToJSON() const override; - facebook::velox::parquet::thrift::LogicalType ToThrift() const override; - bool Equals(const LogicalType& other) const override; - - bool is_adjusted_to_utc() const { + bool isSerialized() const override; + bool isCompatible( + ConvertedType::type convertedType, + schema::DecimalMetadata convertedDecimalMetadata) const override; + ConvertedType::type toConvertedType( + schema::DecimalMetadata* outDecimalMetadata) const override; + std::string toString() const override; + std::string toJson() const override; + facebook::velox::parquet::thrift::LogicalType toThrift() const override; + bool equals(const LogicalType& other) const override; + + bool isAdjustedToUtc() const { return adjusted_; } - LogicalType::TimeUnit::unit time_unit() const { + LogicalType::TimeUnit::Unit timeUnit() const { return unit_; } - bool is_from_converted_type() const { - return is_from_converted_type_; + bool isFromConvertedType() const { + return isFromConvertedType_; } - bool force_set_converted_type() const { - return force_set_converted_type_; + bool forceSetConvertedType() const { + return forceSetConvertedType_; } private: Timestamp( bool adjusted, - LogicalType::TimeUnit::unit unit, - bool is_from_converted_type, - bool force_set_converted_type) - : LogicalType::Impl(LogicalType::Type::TIMESTAMP, SortOrder::SIGNED), - LogicalType::Impl::SimpleApplicable(parquet::Type::INT64), + LogicalType::TimeUnit::Unit Unit, + bool isFromConvertedType, + bool forceSetConvertedType) + : LogicalType::Impl(LogicalType::Type::kTimestamp, SortOrder::kSigned), + LogicalType::Impl::SimpleApplicable(parquet::Type::kInt64), adjusted_(adjusted), - unit_(unit), - is_from_converted_type_(is_from_converted_type), - force_set_converted_type_(force_set_converted_type) {} + unit_(Unit), + isFromConvertedType_(isFromConvertedType), + forceSetConvertedType_(forceSetConvertedType) {} bool adjusted_ = false; - LogicalType::TimeUnit::unit unit_; - bool is_from_converted_type_ = false; - bool force_set_converted_type_ = false; + LogicalType::TimeUnit::Unit unit_; + bool isFromConvertedType_ = false; + bool forceSetConvertedType_ = false; }; -bool LogicalType::Impl::Timestamp::is_serialized() const { - return !is_from_converted_type_; +bool LogicalType::Impl::Timestamp::isSerialized() const { + return !isFromConvertedType_; } -bool LogicalType::Impl::Timestamp::is_compatible( - ConvertedType::type converted_type, - schema::DecimalMetadata converted_decimal_metadata) const { - if (converted_decimal_metadata.isset) { +bool LogicalType::Impl::Timestamp::isCompatible( + ConvertedType::type convertedType, + schema::DecimalMetadata convertedDecimalMetadata) const { + if (convertedDecimalMetadata.isset) { return false; - } else if (unit_ == LogicalType::TimeUnit::MILLIS) { - if (adjusted_ || force_set_converted_type_) { - return converted_type == ConvertedType::TIMESTAMP_MILLIS; + } else if (unit_ == LogicalType::TimeUnit::kMillis) { + if (adjusted_ || forceSetConvertedType_) { + return convertedType == ConvertedType::kTimestampMillis; } else { - return (converted_type == ConvertedType::NONE) || - (converted_type == ConvertedType::NA); + return (convertedType == ConvertedType::kNone) || + (convertedType == ConvertedType::kNa); } - } else if (unit_ == LogicalType::TimeUnit::MICROS) { - if (adjusted_ || force_set_converted_type_) { - return converted_type == ConvertedType::TIMESTAMP_MICROS; + } else if (unit_ == LogicalType::TimeUnit::kMicros) { + if (adjusted_ || forceSetConvertedType_) { + return convertedType == ConvertedType::kTimestampMicros; } else { - return (converted_type == ConvertedType::NONE) || - (converted_type == ConvertedType::NA); + return (convertedType == ConvertedType::kNone) || + (convertedType == ConvertedType::kNa); } } else { - return (converted_type == ConvertedType::NONE) || - (converted_type == ConvertedType::NA); + return (convertedType == ConvertedType::kNone) || + (convertedType == ConvertedType::kNa); } } -ConvertedType::type LogicalType::Impl::Timestamp::ToConvertedType( - schema::DecimalMetadata* out_decimal_metadata) const { - reset_decimal_metadata(out_decimal_metadata); - if (adjusted_ || force_set_converted_type_) { - if (unit_ == LogicalType::TimeUnit::MILLIS) { - return ConvertedType::TIMESTAMP_MILLIS; - } else if (unit_ == LogicalType::TimeUnit::MICROS) { - return ConvertedType::TIMESTAMP_MICROS; +ConvertedType::type LogicalType::Impl::Timestamp::toConvertedType( + schema::DecimalMetadata* outDecimalMetadata) const { + resetDecimalMetadata(outDecimalMetadata); + if (adjusted_ || forceSetConvertedType_) { + if (unit_ == LogicalType::TimeUnit::kMillis) { + return ConvertedType::kTimestampMillis; + } else if (unit_ == LogicalType::TimeUnit::kMicros) { + return ConvertedType::kTimestampMicros; } } - return ConvertedType::NONE; + return ConvertedType::kNone; } -std::string LogicalType::Impl::Timestamp::ToString() const { +std::string LogicalType::Impl::Timestamp::toString() const { std::stringstream type; type << "Timestamp(isAdjustedToUTC=" << std::boolalpha << adjusted_ - << ", timeUnit=" << time_unit_string(unit_) - << ", is_from_converted_type=" << is_from_converted_type_ - << ", force_set_converted_type=" << force_set_converted_type_ << ")"; + << ", timeUnit=" << timeUnitString(unit_) + << ", is_from_converted_type=" << isFromConvertedType_ + << ", force_set_converted_type=" << forceSetConvertedType_ << ")"; return type.str(); } -std::string LogicalType::Impl::Timestamp::ToJSON() const { +std::string LogicalType::Impl::Timestamp::toJson() const { std::stringstream json; json << R"({"Type": "Timestamp", "isAdjustedToUTC": )" << std::boolalpha - << adjusted_ << R"(, "timeUnit": ")" << time_unit_string(unit_) << R"(")" - << R"(, "is_from_converted_type": )" << is_from_converted_type_ - << R"(, "force_set_converted_type": )" << force_set_converted_type_ - << R"(})"; + << adjusted_ << R"(, "timeUnit": ")" << timeUnitString(unit_) << R"(")" + << R"(, "isFromConvertedType": )" << isFromConvertedType_ + << R"(, "forceSetConvertedType": )" << forceSetConvertedType_ << R"(})"; return json.str(); } facebook::velox::parquet::thrift::LogicalType -LogicalType::Impl::Timestamp::ToThrift() const { +LogicalType::Impl::Timestamp::toThrift() const { facebook::velox::parquet::thrift::LogicalType type; - facebook::velox::parquet::thrift::TimestampType timestamp_type; - facebook::velox::parquet::thrift::TimeUnit time_unit; - VELOX_DCHECK_NE(unit_, LogicalType::TimeUnit::UNKNOWN); - if (unit_ == LogicalType::TimeUnit::MILLIS) { + facebook::velox::parquet::thrift::TimestampType timestampType; + facebook::velox::parquet::thrift::TimeUnit timeUnit; + VELOX_DCHECK_NE( + static_cast(unit_), + static_cast(LogicalType::TimeUnit::kUnknown)); + if (unit_ == LogicalType::TimeUnit::kMillis) { facebook::velox::parquet::thrift::MilliSeconds millis; - time_unit.__set_MILLIS(millis); - } else if (unit_ == LogicalType::TimeUnit::MICROS) { + timeUnit.__set_MILLIS(millis); + } else if (unit_ == LogicalType::TimeUnit::kMicros) { facebook::velox::parquet::thrift::MicroSeconds micros; - time_unit.__set_MICROS(micros); - } else if (unit_ == LogicalType::TimeUnit::NANOS) { + timeUnit.__set_MICROS(micros); + } else if (unit_ == LogicalType::TimeUnit::kNanos) { facebook::velox::parquet::thrift::NanoSeconds nanos; - time_unit.__set_NANOS(nanos); + timeUnit.__set_NANOS(nanos); } - timestamp_type.__set_isAdjustedToUTC(adjusted_); - timestamp_type.__set_unit(time_unit); - type.__set_TIMESTAMP(timestamp_type); + timestampType.__set_isAdjustedToUTC(adjusted_); + timestampType.__set_unit(timeUnit); + type.__set_TIMESTAMP(timestampType); return type; } -bool LogicalType::Impl::Timestamp::Equals(const LogicalType& other) const { +bool LogicalType::Impl::Timestamp::equals(const LogicalType& other) const { bool eq = false; - if (other.is_timestamp()) { - const auto& other_timestamp = + if (other.isTimestamp()) { + const auto& otherTimestamp = checked_cast(other); eq = - (adjusted_ == other_timestamp.is_adjusted_to_utc() && - unit_ == other_timestamp.time_unit()); + (adjusted_ == otherTimestamp.isAdjustedToUtc() && + unit_ == otherTimestamp.timeUnit()); } return eq; } -std::shared_ptr TimestampLogicalType::Make( - bool is_adjusted_to_utc, - LogicalType::TimeUnit::unit time_unit, - bool is_from_converted_type, - bool force_set_converted_type) { - if (time_unit == LogicalType::TimeUnit::MILLIS || - time_unit == LogicalType::TimeUnit::MICROS || - time_unit == LogicalType::TimeUnit::NANOS) { - auto* logical_type = new TimestampLogicalType(); - logical_type->impl_.reset(new LogicalType::Impl::Timestamp( - is_adjusted_to_utc, - time_unit, - is_from_converted_type, - force_set_converted_type)); - return std::shared_ptr(logical_type); +std::shared_ptr TimestampLogicalType::make( + bool isAdjustedToUtc, + LogicalType::TimeUnit::Unit timeUnit, + bool isFromConvertedType, + bool forceSetConvertedType) { + if (timeUnit == LogicalType::TimeUnit::kMillis || + timeUnit == LogicalType::TimeUnit::kMicros || + timeUnit == LogicalType::TimeUnit::kNanos) { + auto* logicalType = new TimestampLogicalType(); + logicalType->impl_.reset(new LogicalType::Impl::Timestamp( + isAdjustedToUtc, timeUnit, isFromConvertedType, forceSetConvertedType)); + return std::shared_ptr(logicalType); } else { throw ParquetException( "TimeUnit must be one of MILLIS, MICROS, or NANOS for Timestamp logical type"); } } -bool TimestampLogicalType::is_adjusted_to_utc() const { +bool TimestampLogicalType::isAdjustedToUtc() const { return (dynamic_cast(*impl_)) - .is_adjusted_to_utc(); + .isAdjustedToUtc(); } -LogicalType::TimeUnit::unit TimestampLogicalType::time_unit() const { - return (dynamic_cast(*impl_)) - .time_unit(); +LogicalType::TimeUnit::Unit TimestampLogicalType::timeUnit() const { + return (dynamic_cast(*impl_)).timeUnit(); } -bool TimestampLogicalType::is_from_converted_type() const { +bool TimestampLogicalType::isFromConvertedType() const { return (dynamic_cast(*impl_)) - .is_from_converted_type(); + .isFromConvertedType(); } -bool TimestampLogicalType::force_set_converted_type() const { +bool TimestampLogicalType::forceSetConvertedType() const { return (dynamic_cast(*impl_)) - .force_set_converted_type(); + .forceSetConvertedType(); } class LogicalType::Impl::Interval final @@ -1503,16 +1496,16 @@ class LogicalType::Impl::Interval final friend class IntervalLogicalType; OVERRIDE_TOSTRING(Interval) - // TODO(tpboudreau): uncomment the following line to enable serialization - // after parquet.thrift recognizes IntervalType as a ConvertedType + // TODO(tpboudreau): uncomment the following line to enable serialization. + // After parquet.thrift recognizes IntervalType as a ConvertedType. // OVERRIDE_TOTHRIFT(IntervalType, INTERVAL) private: Interval() - : LogicalType::Impl(LogicalType::Type::INTERVAL, SortOrder::UNKNOWN), - LogicalType::Impl::SimpleCompatible(ConvertedType::INTERVAL), + : LogicalType::Impl(LogicalType::Type::kInterval, SortOrder::kUnknown), + LogicalType::Impl::SimpleCompatible(ConvertedType::kInterval), LogicalType::Impl::TypeLengthApplicable( - parquet::Type::FIXED_LEN_BYTE_ARRAY, + parquet::Type::kFixedLenByteArray, 12) {} }; @@ -1523,152 +1516,152 @@ class LogicalType::Impl::Int final : public LogicalType::Impl::Compatible, public: friend class IntLogicalType; - bool is_applicable( - parquet::Type::type primitive_type, - int32_t primitive_length = -1) const override; - bool is_compatible( - ConvertedType::type converted_type, - schema::DecimalMetadata converted_decimal_metadata) const override; - ConvertedType::type ToConvertedType( - schema::DecimalMetadata* out_decimal_metadata) const override; - std::string ToString() const override; - std::string ToJSON() const override; - facebook::velox::parquet::thrift::LogicalType ToThrift() const override; - bool Equals(const LogicalType& other) const override; - - int bit_width() const { + bool isApplicable( + parquet::Type::type primitiveType, + int32_t primitiveLength = -1) const override; + bool isCompatible( + ConvertedType::type convertedType, + schema::DecimalMetadata convertedDecimalMetadata) const override; + ConvertedType::type toConvertedType( + schema::DecimalMetadata* outDecimalMetadata) const override; + std::string toString() const override; + std::string toJson() const override; + facebook::velox::parquet::thrift::LogicalType toThrift() const override; + bool equals(const LogicalType& other) const override; + + int bitWidth() const { return width_; } - bool is_signed() const { + bool isSigned() const { return signed_; } private: - Int(int w, bool s) + Int(int width, bool isSigned) : LogicalType::Impl( - LogicalType::Type::INT, - (s ? SortOrder::SIGNED : SortOrder::UNSIGNED)), - width_(w), - signed_(s) {} + LogicalType::Type::kInt, + (isSigned ? SortOrder::kSigned : SortOrder::kUnsigned)), + width_(width), + signed_(isSigned) {} int width_ = 0; bool signed_ = false; }; -bool LogicalType::Impl::Int::is_applicable( - parquet::Type::type primitive_type, - int32_t primitive_length) const { - return (primitive_type == parquet::Type::INT32 && width_ <= 32) || - (primitive_type == parquet::Type::INT64 && width_ == 64); +bool LogicalType::Impl::Int::isApplicable( + parquet::Type::type primitiveType, + int32_t primitiveLength) const { + return (primitiveType == parquet::Type::kInt32 && width_ <= 32) || + (primitiveType == parquet::Type::kInt64 && width_ == 64); } -bool LogicalType::Impl::Int::is_compatible( - ConvertedType::type converted_type, - schema::DecimalMetadata converted_decimal_metadata) const { - if (converted_decimal_metadata.isset) { +bool LogicalType::Impl::Int::isCompatible( + ConvertedType::type convertedType, + schema::DecimalMetadata convertedDecimalMetadata) const { + if (convertedDecimalMetadata.isset) { return false; } else if (signed_ && width_ == 8) { - return converted_type == ConvertedType::INT_8; + return convertedType == ConvertedType::kInt8; } else if (signed_ && width_ == 16) { - return converted_type == ConvertedType::INT_16; + return convertedType == ConvertedType::kInt16; } else if (signed_ && width_ == 32) { - return converted_type == ConvertedType::INT_32; + return convertedType == ConvertedType::kInt32; } else if (signed_ && width_ == 64) { - return converted_type == ConvertedType::INT_64; + return convertedType == ConvertedType::kInt64; } else if (!signed_ && width_ == 8) { - return converted_type == ConvertedType::UINT_8; + return convertedType == ConvertedType::kUint8; } else if (!signed_ && width_ == 16) { - return converted_type == ConvertedType::UINT_16; + return convertedType == ConvertedType::kUint16; } else if (!signed_ && width_ == 32) { - return converted_type == ConvertedType::UINT_32; + return convertedType == ConvertedType::kUint32; } else if (!signed_ && width_ == 64) { - return converted_type == ConvertedType::UINT_64; + return convertedType == ConvertedType::kUint64; } else { return false; } } -ConvertedType::type LogicalType::Impl::Int::ToConvertedType( - schema::DecimalMetadata* out_decimal_metadata) const { - reset_decimal_metadata(out_decimal_metadata); +ConvertedType::type LogicalType::Impl::Int::toConvertedType( + schema::DecimalMetadata* outDecimalMetadata) const { + resetDecimalMetadata(outDecimalMetadata); if (signed_) { switch (width_) { case 8: - return ConvertedType::INT_8; + return ConvertedType::kInt8; case 16: - return ConvertedType::INT_16; + return ConvertedType::kInt16; case 32: - return ConvertedType::INT_32; + return ConvertedType::kInt32; case 64: - return ConvertedType::INT_64; + return ConvertedType::kInt64; } } else { // unsigned switch (width_) { case 8: - return ConvertedType::UINT_8; + return ConvertedType::kUint8; case 16: - return ConvertedType::UINT_16; + return ConvertedType::kUint16; case 32: - return ConvertedType::UINT_32; + return ConvertedType::kUint32; case 64: - return ConvertedType::UINT_64; + return ConvertedType::kUint64; } } - return ConvertedType::NONE; + return ConvertedType::kNone; } -std::string LogicalType::Impl::Int::ToString() const { +std::string LogicalType::Impl::Int::toString() const { std::stringstream type; type << "Int(bitWidth=" << width_ << ", isSigned=" << std::boolalpha << signed_ << ")"; return type.str(); } -std::string LogicalType::Impl::Int::ToJSON() const { +std::string LogicalType::Impl::Int::toJson() const { std::stringstream json; - json << R"({"Type": "Int", "bitWidth": )" << width_ << R"(, "isSigned": )" + json << R"({"Type": "int", "bitWidth": )" << width_ << R"(, "isSigned": )" << std::boolalpha << signed_ << "}"; return json.str(); } -facebook::velox::parquet::thrift::LogicalType LogicalType::Impl::Int::ToThrift() +facebook::velox::parquet::thrift::LogicalType LogicalType::Impl::Int::toThrift() const { facebook::velox::parquet::thrift::LogicalType type; - facebook::velox::parquet::thrift::IntType int_type; + facebook::velox::parquet::thrift::IntType intType; VELOX_DCHECK(width_ == 64 || width_ == 32 || width_ == 16 || width_ == 8); - int_type.__set_bitWidth(static_cast(width_)); - int_type.__set_isSigned(signed_); - type.__set_INTEGER(int_type); + intType.__set_bitWidth(static_cast(width_)); + intType.__set_isSigned(signed_); + type.__set_INTEGER(intType); return type; } -bool LogicalType::Impl::Int::Equals(const LogicalType& other) const { +bool LogicalType::Impl::Int::equals(const LogicalType& other) const { bool eq = false; - if (other.is_int()) { - const auto& other_int = checked_cast(other); - eq = (width_ == other_int.bit_width() && signed_ == other_int.is_signed()); + if (other.isInt()) { + const auto& otherInt = checked_cast(other); + eq = (width_ == otherInt.bitWidth() && signed_ == otherInt.isSigned()); } return eq; } -std::shared_ptr IntLogicalType::Make( - int bit_width, - bool is_signed) { - if (bit_width == 8 || bit_width == 16 || bit_width == 32 || bit_width == 64) { - auto* logical_type = new IntLogicalType(); - logical_type->impl_.reset(new LogicalType::Impl::Int(bit_width, is_signed)); - return std::shared_ptr(logical_type); +std::shared_ptr IntLogicalType::make( + int bitWidth, + bool isSigned) { + if (bitWidth == 8 || bitWidth == 16 || bitWidth == 32 || bitWidth == 64) { + auto* logicalType = new IntLogicalType(); + logicalType->impl_.reset(new LogicalType::Impl::Int(bitWidth, isSigned)); + return std::shared_ptr(logicalType); } else { throw ParquetException( "Bit width must be exactly 8, 16, 32, or 64 for Int logical type"); } } -int IntLogicalType::bit_width() const { - return (dynamic_cast(*impl_)).bit_width(); +int IntLogicalType::bitWidth() const { + return (dynamic_cast(*impl_)).bitWidth(); } -bool IntLogicalType::is_signed() const { - return (dynamic_cast(*impl_)).is_signed(); +bool IntLogicalType::isSigned() const { + return (dynamic_cast(*impl_)).isSigned(); } class LogicalType::Impl::Null final @@ -1681,65 +1674,65 @@ class LogicalType::Impl::Null final OVERRIDE_TOTHRIFT(NullType, UNKNOWN) private: - Null() : LogicalType::Impl(LogicalType::Type::NIL, SortOrder::UNKNOWN) {} + Null() : LogicalType::Impl(LogicalType::Type::kNil, SortOrder::kUnknown) {} }; GENERATE_MAKE(Null) -class LogicalType::Impl::JSON final +class LogicalType::Impl::Json final : public LogicalType::Impl::SimpleCompatible, public LogicalType::Impl::SimpleApplicable { public: - friend class JSONLogicalType; + friend class JsonLogicalType; OVERRIDE_TOSTRING(JSON) OVERRIDE_TOTHRIFT(JsonType, JSON) private: - JSON() - : LogicalType::Impl(LogicalType::Type::JSON, SortOrder::UNSIGNED), - LogicalType::Impl::SimpleCompatible(ConvertedType::JSON), - LogicalType::Impl::SimpleApplicable(parquet::Type::BYTE_ARRAY) {} + Json() + : LogicalType::Impl(LogicalType::Type::kJson, SortOrder::kUnsigned), + LogicalType::Impl::SimpleCompatible(ConvertedType::kJson), + LogicalType::Impl::SimpleApplicable(parquet::Type::kByteArray) {} }; -GENERATE_MAKE(JSON) +GENERATE_MAKE(Json) -class LogicalType::Impl::BSON final +class LogicalType::Impl::Bson final : public LogicalType::Impl::SimpleCompatible, public LogicalType::Impl::SimpleApplicable { public: - friend class BSONLogicalType; + friend class BsonLogicalType; OVERRIDE_TOSTRING(BSON) OVERRIDE_TOTHRIFT(BsonType, BSON) private: - BSON() - : LogicalType::Impl(LogicalType::Type::BSON, SortOrder::UNSIGNED), - LogicalType::Impl::SimpleCompatible(ConvertedType::BSON), - LogicalType::Impl::SimpleApplicable(parquet::Type::BYTE_ARRAY) {} + Bson() + : LogicalType::Impl(LogicalType::Type::kBson, SortOrder::kUnsigned), + LogicalType::Impl::SimpleCompatible(ConvertedType::kBson), + LogicalType::Impl::SimpleApplicable(parquet::Type::kByteArray) {} }; -GENERATE_MAKE(BSON) +GENERATE_MAKE(Bson) -class LogicalType::Impl::UUID final +class LogicalType::Impl::Uuid final : public LogicalType::Impl::Incompatible, public LogicalType::Impl::TypeLengthApplicable { public: - friend class UUIDLogicalType; + friend class UuidLogicalType; OVERRIDE_TOSTRING(UUID) OVERRIDE_TOTHRIFT(UUIDType, UUID) private: - UUID() - : LogicalType::Impl(LogicalType::Type::UUID, SortOrder::UNSIGNED), + Uuid() + : LogicalType::Impl(LogicalType::Type::kUuid, SortOrder::kUnsigned), LogicalType::Impl::TypeLengthApplicable( - parquet::Type::FIXED_LEN_BYTE_ARRAY, + parquet::Type::kFixedLenByteArray, 16) {} }; -GENERATE_MAKE(UUID) +GENERATE_MAKE(Uuid) class LogicalType::Impl::No final : public LogicalType::Impl::SimpleCompatible, @@ -1751,8 +1744,8 @@ class LogicalType::Impl::No final private: No() - : LogicalType::Impl(LogicalType::Type::NONE, SortOrder::UNKNOWN), - LogicalType::Impl::SimpleCompatible(ConvertedType::NONE) {} + : LogicalType::Impl(LogicalType::Type::kNone, SortOrder::kUnknown), + LogicalType::Impl::SimpleCompatible(ConvertedType::kNone) {} }; GENERATE_MAKE(No) @@ -1767,8 +1760,8 @@ class LogicalType::Impl::Undefined final private: Undefined() - : LogicalType::Impl(LogicalType::Type::UNDEFINED, SortOrder::UNKNOWN), - LogicalType::Impl::SimpleCompatible(ConvertedType::UNDEFINED) {} + : LogicalType::Impl(LogicalType::Type::kUndefined, SortOrder::kUnknown), + LogicalType::Impl::SimpleCompatible(ConvertedType::kUndefined) {} }; GENERATE_MAKE(Undefined) diff --git a/velox/dwio/parquet/writer/arrow/Types.h b/velox/dwio/parquet/writer/arrow/Types.h index 24727b1c05c..0bf8274a542 100644 --- a/velox/dwio/parquet/writer/arrow/Types.h +++ b/velox/dwio/parquet/writer/arrow/Types.h @@ -40,8 +40,8 @@ class Codec; namespace facebook::velox::parquet::arrow { -// ---------------------------------------------------------------------- -// Metadata enums to match Thrift metadata +// ----------------------------------------------------------------------. +// Metadata enums to match Thrift metadata. // // The reason we maintain our own enums is to avoid transitive dependency on // the compiled Thrift headers (and thus thrift/Thrift.h) for users of the @@ -51,21 +51,21 @@ namespace facebook::velox::parquet::arrow { // // We can also add special values like NONE to distinguish between metadata // values being set and not set. As an example consider ConvertedType and -// CompressionCodec +// CompressionCodec. -// Mirrors parquet::Type +// Mirrors parquet::Type. struct Type { enum type { - BOOLEAN = 0, - INT32 = 1, - INT64 = 2, - INT96 = 3, - FLOAT = 4, - DOUBLE = 5, - BYTE_ARRAY = 6, - FIXED_LEN_BYTE_ARRAY = 7, + kBoolean = 0, + kInt32 = 1, + kInt64 = 2, + kInt96 = 3, + kFloat = 4, + kDouble = 5, + kByteArray = 6, + kFixedLenByteArray = 7, // Should always be last element. - UNDEFINED = 8 + kUndefined = 8 }; }; @@ -74,69 +74,70 @@ namespace parquet { using Type = facebook::velox::parquet::arrow::Type; } -// Mirrors parquet::ConvertedType +// Mirrors parquet::ConvertedType. struct ConvertedType { enum type { - NONE, // Not a real converted type, but means no converted type is specified - UTF8, - MAP, - MAP_KEY_VALUE, - LIST, - ENUM, - DECIMAL, - DATE, - TIME_MILLIS, - TIME_MICROS, - TIMESTAMP_MILLIS, - TIMESTAMP_MICROS, - UINT_8, - UINT_16, - UINT_32, - UINT_64, - INT_8, - INT_16, - INT_32, - INT_64, - JSON, - BSON, - INTERVAL, + kNone, // Not a real converted type, but means no converted type is + // specified + kUtf8, + kMap, + kMapKeyValue, + kList, + kEnum, + kDecimal, + kDate, + kTimeMillis, + kTimeMicros, + kTimestampMillis, + kTimestampMicros, + kUint8, + kUint16, + kUint32, + kUint64, + kInt8, + kInt16, + kInt32, + kInt64, + kJson, + kBson, + kInterval, // DEPRECATED INVALID ConvertedType for all-null data. - // Only useful for reading legacy files written out by interim Parquet C++ - // releases. For writing, always emit LogicalType::Null instead. See - // PARQUET-1990. - NA = 25, - UNDEFINED = 26 // Not a real converted type; should always be last element + // Only useful for reading legacy files written out by interim Parquet + // C++ releases. For writing, always emit LogicalType::nullType instead. + // See PARQUET-1990. + kNa = 25, + kUndefined = 26 // Not a real converted type; should always be last element }; }; -// forward declaration +// Forward declaration. namespace format { class LogicalType; } -// Mirrors parquet::FieldRepetitionType +// Mirrors parquet::FieldRepetitionType. struct Repetition { enum type { - REQUIRED = 0, - OPTIONAL = 1, - REPEATED = 2, - /*Always last*/ UNDEFINED = 3 + kRequired = 0, + kOptional = 1, + kRepeated = 2, + /*Always last*/ kUndefined = 3 }; }; // Reference: -// parquet-mr/parquet-hadoop/src/main/java/org/apache/parquet/ -// format/converter/ParquetMetadataConverter.java -// Sort order for page and column statistics. Types are associated with sort -// orders (e.g., UTF8 columns should use UNSIGNED) and column stats are -// aggregated using a sort order. As of parquet-format version 2.3.1, the -// order used to aggregate stats is always SIGNED and is not stored in the +// Parquet-mr/parquet-hadoop/src/main/java/org/apache/parquet/. +// Format/converter/ParquetMetadataConverter.java. +// Sort order for page and column statistics. Types are associated with sort. +// Orders (e.g., UTF8 columns should use UNSIGNED) and column stats are. +// Aggregated using a sort order. As of parquet-format version 2.3.1, the. +// Order used to aggregate stats is always SIGNED and is not stored in the. // Parquet file. These stats are discarded for types that need unsigned. // See PARQUET-686. struct SortOrder { - enum type { SIGNED, UNSIGNED, UNKNOWN }; + enum type { kSigned, kUnsigned, kUnknown }; }; namespace schema { @@ -154,149 +155,147 @@ class PARQUET_EXPORT LogicalType { public: struct Type { enum type { - UNDEFINED = 0, // Not a real logical type - STRING = 1, - MAP, - LIST, - ENUM, - DECIMAL, - DATE, - TIME, - TIMESTAMP, - INTERVAL, - INT, - NIL, // Thrift NullType: annotates data that is always null - JSON, - BSON, - UUID, - NONE // Not a real logical type; should always be last element + kUndefined = 0, // Not a real logical type + kString = 1, + kMap, + kList, + kEnum, + kDecimal, + kDate, + kTime, + kTimestamp, + kInterval, + kInt, + kNil, // Thrift NullType: annotates data that is always null + kJson, + kBson, + kUuid, + kNone // Not a real logical type; should always be last element }; }; struct TimeUnit { - enum unit { UNKNOWN = 0, MILLIS = 1, MICROS, NANOS }; + enum Unit { kUnknown = 0, kMillis = 1, kMicros, kNanos }; }; /// \brief If possible, return a logical type equivalent to the given legacy /// converted type (and decimal metadata if applicable). - static std::shared_ptr FromConvertedType( - const ConvertedType::type converted_type, - const schema::DecimalMetadata converted_decimal_metadata = { - false, - -1, - -1}); + static std::shared_ptr fromConvertedType( + const ConvertedType::type convertedType, + const schema::DecimalMetadata convertedDecimalMetadata = {false, -1, -1}); /// \brief Return the logical type represented by the Thrift intermediary /// object. - static std::shared_ptr FromThrift( - const facebook::velox::parquet::thrift::LogicalType& thrift_logical_type); + static std::shared_ptr fromThrift( + const facebook::velox::parquet::thrift::LogicalType& thriftLogicalType); /// \brief Return the explicitly requested logical type. - static std::shared_ptr String(); - static std::shared_ptr Map(); - static std::shared_ptr List(); - static std::shared_ptr Enum(); - static std::shared_ptr Decimal( + static std::shared_ptr string(); + static std::shared_ptr map(); + static std::shared_ptr list(); + static std::shared_ptr enumType(); + static std::shared_ptr decimal( int32_t precision, int32_t scale = 0); - static std::shared_ptr Date(); - static std::shared_ptr Time( - bool is_adjusted_to_utc, - LogicalType::TimeUnit::unit time_unit); - - /// \brief Create a Timestamp logical type - /// \param[in] is_adjusted_to_utc set true if the data is UTC-normalized - /// \param[in] time_unit the resolution of the timestamp - /// \param[in] is_from_converted_type if true, the timestamp was generated - /// by translating a legacy converted type of TIMESTAMP_MILLIS or + static std::shared_ptr date(); + static std::shared_ptr time( + bool isAdjustedToUtc, + LogicalType::TimeUnit::Unit timeUnit); + + /// \brief Create a Timestamp logical type. + /// \param[in] is_adjusted_to_utc Set true if the data is UTC-normalized. + /// \param[in] time_unit The resolution of the timestamp. + /// \param[in] is_from_converted_type If true, the timestamp was generated. + /// By translating a legacy converted type of TIMESTAMP_MILLIS or /// TIMESTAMP_MICROS. Default is false. - /// \param[in] force_set_converted_type if true, always set the - /// legacy ConvertedType TIMESTAMP_MICROS and TIMESTAMP_MILLIS - /// metadata. Default is false - static std::shared_ptr Timestamp( - bool is_adjusted_to_utc, - LogicalType::TimeUnit::unit time_unit, - bool is_from_converted_type = false, - bool force_set_converted_type = false); - - static std::shared_ptr Interval(); - static std::shared_ptr Int(int bit_width, bool is_signed); - - /// \brief Create a logical type for data that's always null + /// \param[in] force_set_converted_type If true, always set the + /// legacy ConvertedType TIMESTAMP_MICROS and TIMESTAMP_MILLIS. + /// metadata. Default is false. + static std::shared_ptr timestamp( + bool isAdjustedToUtc, + LogicalType::TimeUnit::Unit timeUnit, + bool isFromConvertedType = false, + bool forceSetConvertedType = false); + + static std::shared_ptr interval(); + static std::shared_ptr intType( + int bitWidth, + bool isSigned); + + /// \brief Create a logical type for data that's always null. /// /// Any physical type can be annotated with this logical type. - static std::shared_ptr Null(); + static std::shared_ptr nullType(); - static std::shared_ptr JSON(); - static std::shared_ptr BSON(); - static std::shared_ptr UUID(); + static std::shared_ptr json(); + static std::shared_ptr bson(); + static std::shared_ptr uuid(); - /// \brief Create a placeholder for when no logical type is specified - static std::shared_ptr None(); + /// \brief Create a placeholder for when no logical type is specified. + static std::shared_ptr none(); - /// \brief Return true if this logical type is consistent with the given - /// underlying physical type. - bool is_applicable( - parquet::Type::type primitive_type, - int32_t primitive_length = -1) const; + /// \brief Return true if this logical type is consistent with the given. + /// Underlying physical type. + bool isApplicable( + parquet::Type::type primitiveType, + int32_t primitiveLength = -1) const; /// \brief Return true if this logical type is equivalent to the given legacy /// converted type (and decimal metadata if applicable). - bool is_compatible( - ConvertedType::type converted_type, - schema::DecimalMetadata converted_decimal_metadata = {false, -1, -1}) - const; + bool isCompatible( + ConvertedType::type convertedType, + schema::DecimalMetadata convertedDecimalMetadata = {false, -1, -1}) const; - /// \brief If possible, return the legacy converted type (and decimal metadata - /// if applicable) equivalent to this logical type. - ConvertedType::type ToConvertedType( - schema::DecimalMetadata* out_decimal_metadata) const; + /// \brief If possible, return the legacy converted type (and decimal + /// metadata if applicable) equivalent to this logical type. + ConvertedType::type toConvertedType( + schema::DecimalMetadata* outDecimalMetadata) const; /// \brief Return a printable representation of this logical type. - std::string ToString() const; + std::string toString() const; /// \brief Return a JSON representation of this logical type. - std::string ToJSON() const; + std::string toJson() const; /// \brief Return a serializable Thrift object for this logical type. - facebook::velox::parquet::thrift::LogicalType ToThrift() const; + facebook::velox::parquet::thrift::LogicalType toThrift() const; - /// \brief Return true if the given logical type is equivalent to this logical - /// type. - bool Equals(const LogicalType& other) const; + /// \brief Return true if the given logical type is equivalent to this + /// logical type. + bool equals(const LogicalType& other) const; /// \brief Return the enumerated type of this logical type. LogicalType::Type::type type() const; /// \brief Return the appropriate sort order for this logical type. - SortOrder::type sort_order() const; + SortOrder::type sortOrder() const; // Type checks ... - bool is_string() const; - bool is_map() const; - bool is_list() const; - bool is_enum() const; - bool is_decimal() const; - bool is_date() const; - bool is_time() const; - bool is_timestamp() const; - bool is_interval() const; - bool is_int() const; - bool is_null() const; - bool is_JSON() const; - bool is_BSON() const; - bool is_UUID() const; - bool is_none() const; + bool isString() const; + bool isMap() const; + bool isList() const; + bool isEnum() const; + bool isDecimal() const; + bool isDate() const; + bool isTime() const; + bool isTimestamp() const; + bool isInterval() const; + bool isInt() const; + bool isNull() const; + bool isJson() const; + bool isBson() const; + bool isUuid() const; + bool isNone() const; /// \brief Return true if this logical type is of a known type. - bool is_valid() const; - bool is_invalid() const; - /// \brief Return true if this logical type is suitable for a schema + bool isValid() const; + bool isInvalid() const; + /// \brief Return true if this logical type is suitable for a schema. /// GroupNode. - bool is_nested() const; - bool is_nonnested() const; - /// \brief Return true if this logical type is included in the Thrift output - /// for its node. - bool is_serialized() const; + bool isNested() const; + bool isNonnested() const; + /// \brief Return true if this logical type is included in the Thrift output. + /// For its node. + bool isSerialized() const; LogicalType(const LogicalType&) = delete; LogicalType& operator=(const LogicalType&) = delete; @@ -312,7 +311,7 @@ class PARQUET_EXPORT LogicalType { /// \brief Allowed for physical type BYTE_ARRAY, must be encoded as UTF-8. class PARQUET_EXPORT StringLogicalType : public LogicalType { public: - static std::shared_ptr Make(); + static std::shared_ptr make(); private: StringLogicalType() = default; @@ -321,7 +320,7 @@ class PARQUET_EXPORT StringLogicalType : public LogicalType { /// \brief Allowed for group nodes only. class PARQUET_EXPORT MapLogicalType : public LogicalType { public: - static std::shared_ptr Make(); + static std::shared_ptr make(); private: MapLogicalType() = default; @@ -330,7 +329,7 @@ class PARQUET_EXPORT MapLogicalType : public LogicalType { /// \brief Allowed for group nodes only. class PARQUET_EXPORT ListLogicalType : public LogicalType { public: - static std::shared_ptr Make(); + static std::shared_ptr make(); private: ListLogicalType() = default; @@ -339,17 +338,17 @@ class PARQUET_EXPORT ListLogicalType : public LogicalType { /// \brief Allowed for physical type BYTE_ARRAY, must be encoded as UTF-8. class PARQUET_EXPORT EnumLogicalType : public LogicalType { public: - static std::shared_ptr Make(); + static std::shared_ptr make(); private: EnumLogicalType() = default; }; -/// \brief Allowed for physical type INT32, INT64, FIXED_LEN_BYTE_ARRAY, or +/// \brief Allowed for physical type INT32, INT64, FIXED_LEN_BYTE_ARRAY, or. /// BYTE_ARRAY, depending on the precision. class PARQUET_EXPORT DecimalLogicalType : public LogicalType { public: - static std::shared_ptr Make( + static std::shared_ptr make( int32_t precision, int32_t scale = 0); int32_t precision() const; @@ -362,21 +361,21 @@ class PARQUET_EXPORT DecimalLogicalType : public LogicalType { /// \brief Allowed for physical type INT32. class PARQUET_EXPORT DateLogicalType : public LogicalType { public: - static std::shared_ptr Make(); + static std::shared_ptr make(); private: DateLogicalType() = default; }; -/// \brief Allowed for physical type INT32 (for MILLIS) or INT64 (for MICROS and -/// NANOS). +/// \brief Allowed for physical type INT32 (for MILLIS) or INT64 (for MICROS +/// and NANOS). class PARQUET_EXPORT TimeLogicalType : public LogicalType { public: - static std::shared_ptr Make( - bool is_adjusted_to_utc, - LogicalType::TimeUnit::unit time_unit); - bool is_adjusted_to_utc() const; - LogicalType::TimeUnit::unit time_unit() const; + static std::shared_ptr make( + bool isAdjustedToUtc, + LogicalType::TimeUnit::Unit timeUnit); + bool isAdjustedToUtc() const; + LogicalType::TimeUnit::Unit timeUnit() const; private: TimeLogicalType() = default; @@ -385,41 +384,41 @@ class PARQUET_EXPORT TimeLogicalType : public LogicalType { /// \brief Allowed for physical type INT64. class PARQUET_EXPORT TimestampLogicalType : public LogicalType { public: - static std::shared_ptr Make( - bool is_adjusted_to_utc, - LogicalType::TimeUnit::unit time_unit, - bool is_from_converted_type = false, - bool force_set_converted_type = false); - bool is_adjusted_to_utc() const; - LogicalType::TimeUnit::unit time_unit() const; + static std::shared_ptr make( + bool isAdjustedToUtc, + LogicalType::TimeUnit::Unit timeUnit, + bool isFromConvertedType = false, + bool forceSetConvertedType = false); + bool isAdjustedToUtc() const; + LogicalType::TimeUnit::Unit timeUnit() const; - /// \brief If true, will not set LogicalType in Thrift metadata - bool is_from_converted_type() const; + /// \brief If true, will not set LogicalType in Thrift metadata. + bool isFromConvertedType() const; - /// \brief If true, will set ConvertedType for micros and millis - /// resolution in legacy ConvertedType Thrift metadata - bool force_set_converted_type() const; + /// \brief If true, will set ConvertedType for micros and millis. + /// Resolution in legacy ConvertedType Thrift metadata. + bool forceSetConvertedType() const; private: TimestampLogicalType() = default; }; -/// \brief Allowed for physical type FIXED_LEN_BYTE_ARRAY with length 12 +/// \brief Allowed for physical type FIXED_LEN_BYTE_ARRAY with length 12. class PARQUET_EXPORT IntervalLogicalType : public LogicalType { public: - static std::shared_ptr Make(); + static std::shared_ptr make(); private: IntervalLogicalType() = default; }; -/// \brief Allowed for physical type INT32 (for bit widths 8, 16, and 32) and +/// \brief Allowed for physical type INT32 (for bit widths 8, 16, and 32) and. /// INT64 (for bit width 64). class PARQUET_EXPORT IntLogicalType : public LogicalType { public: - static std::shared_ptr Make(int bit_width, bool is_signed); - int bit_width() const; - bool is_signed() const; + static std::shared_ptr make(int bitWidth, bool isSigned); + int bitWidth() const; + bool isSigned() const; private: IntLogicalType() = default; @@ -428,110 +427,110 @@ class PARQUET_EXPORT IntLogicalType : public LogicalType { /// \brief Allowed for any physical type. class PARQUET_EXPORT NullLogicalType : public LogicalType { public: - static std::shared_ptr Make(); + static std::shared_ptr make(); private: NullLogicalType() = default; }; /// \brief Allowed for physical type BYTE_ARRAY. -class PARQUET_EXPORT JSONLogicalType : public LogicalType { +class PARQUET_EXPORT JsonLogicalType : public LogicalType { public: - static std::shared_ptr Make(); + static std::shared_ptr make(); private: - JSONLogicalType() = default; + JsonLogicalType() = default; }; /// \brief Allowed for physical type BYTE_ARRAY. -class PARQUET_EXPORT BSONLogicalType : public LogicalType { +class PARQUET_EXPORT BsonLogicalType : public LogicalType { public: - static std::shared_ptr Make(); + static std::shared_ptr make(); private: - BSONLogicalType() = default; + BsonLogicalType() = default; }; -/// \brief Allowed for physical type FIXED_LEN_BYTE_ARRAY with length 16, -/// must encode raw UUID bytes. -class PARQUET_EXPORT UUIDLogicalType : public LogicalType { +/// \brief Allowed for physical type FIXED_LEN_BYTE_ARRAY with length 16,. +/// Must encode raw UUID bytes. +class PARQUET_EXPORT UuidLogicalType : public LogicalType { public: - static std::shared_ptr Make(); + static std::shared_ptr make(); private: - UUIDLogicalType() = default; + UuidLogicalType() = default; }; /// \brief Allowed for any physical type. class PARQUET_EXPORT NoLogicalType : public LogicalType { public: - static std::shared_ptr Make(); + static std::shared_ptr make(); private: NoLogicalType() = default; }; -// Internal API, for unrecognized logical types +// Internal API, for unrecognized logical types. class PARQUET_EXPORT UndefinedLogicalType : public LogicalType { public: - static std::shared_ptr Make(); + static std::shared_ptr make(); private: UndefinedLogicalType() = default; }; -// Data encodings. Mirrors parquet::Encoding +// Data encodings. Mirrors parquet::Encoding. struct Encoding { enum type { - PLAIN = 0, - PLAIN_DICTIONARY = 2, - RLE = 3, - BIT_PACKED = 4, - DELTA_BINARY_PACKED = 5, - DELTA_LENGTH_BYTE_ARRAY = 6, - DELTA_BYTE_ARRAY = 7, - RLE_DICTIONARY = 8, - BYTE_STREAM_SPLIT = 9, - // Should always be last element (except UNKNOWN) - UNDEFINED = 10, - UNKNOWN = 999 + kPlain = 0, + kPlainDictionary = 2, + kRle = 3, + kBitPacked = 4, + kDeltaBinaryPacked = 5, + kDeltaLengthByteArray = 6, + kDeltaByteArray = 7, + kRleDictionary = 8, + kByteStreamSplit = 9, + // Should always be last element (except UNKNOWN). + kUndefined = 10, + kUnknown = 999 }; }; -// Exposed data encodings. It is the encoding of the data read from the file, -// rather than the encoding of the data in the file. E.g., the data encoded as -// RLE_DICTIONARY in the file can be read as dictionary indices by RLE -// decoding, in which case the data read from the file is DICTIONARY encoded. +// Exposed data encodings. It is the encoding of the data read from the file,. +// Rather than the encoding of the data in the file. E.g., the data encoded as. +// RLE_DICTIONARY in the file can be read as dictionary indices by RLE. +// Decoding, in which case the data read from the file is DICTIONARY encoded. enum class ExposedEncoding { - NO_ENCODING = 0, // data is not encoded, i.e. already decoded during reading - DICTIONARY = 1 + kNoEncoding = 0, // data is not encoded, i.e. already decoded during reading + kDictionary = 1 }; -/// \brief Return true if Parquet supports indicated compression type +/// \brief Return true if Parquet supports indicated compression type. PARQUET_EXPORT -bool IsCodecSupported(Compression::type codec); +bool isCodecSupported(Compression::type codec); PARQUET_EXPORT -std::unique_ptr GetCodec(Compression::type codec); +std::unique_ptr getCodec(Compression::type codec); PARQUET_EXPORT -std::unique_ptr GetCodec( +std::unique_ptr getCodec( Compression::type codec, - const util::CodecOptions& codec_options); + const util::CodecOptions& codecOptions); PARQUET_EXPORT -std::unique_ptr GetCodec( +std::unique_ptr getCodec( Compression::type codec, - int compression_level); + int compressionLevel); struct ParquetCipher { - enum type { AES_GCM_V1 = 0, AES_GCM_CTR_V1 = 1 }; + enum type { kAesGcmV1 = 0, kAesGcmCtrV1 = 1 }; }; struct AadMetadata { - std::string aad_prefix; - std::string aad_file_unique; - bool supply_aad_prefix; + std::string aadPrefix; + std::string aadFileUnique; + bool supplyAadPrefix; }; struct EncryptionAlgorithm { @@ -539,71 +538,69 @@ struct EncryptionAlgorithm { AadMetadata aad; }; -// parquet::PageType +// Parquet::PageType. struct PageType { enum type { - DATA_PAGE, - INDEX_PAGE, - DICTIONARY_PAGE, - DATA_PAGE_V2, - // Should always be last element - UNDEFINED + kDataPage, + kIndexPage, + kDictionaryPage, + kDataPageV2, + // Should always be last element. + kUndefined }; }; -bool PageCanUseChecksum(PageType::type pageType); +bool pageCanUseChecksum(PageType::type pageType); class ColumnOrder { public: - enum type { UNDEFINED, TYPE_DEFINED_ORDER }; - explicit ColumnOrder(ColumnOrder::type column_order) - : column_order_(column_order) {} - // Default to Type Defined Order - ColumnOrder() : column_order_(type::TYPE_DEFINED_ORDER) {} - ColumnOrder::type get_order() { - return column_order_; + enum type { kUndefined, kTypeDefinedOrder }; + explicit ColumnOrder(ColumnOrder::type order) : columnOrder_(order) {} + // Default to Type Defined Order. + ColumnOrder() : columnOrder_(type::kTypeDefinedOrder) {} + ColumnOrder::type order() const { + return columnOrder_; } static ColumnOrder undefined_; - static ColumnOrder type_defined_; + static ColumnOrder typeDefined_; private: - ColumnOrder::type column_order_; + ColumnOrder::type columnOrder_; }; -/// \brief BoundaryOrder is a proxy around -/// facebook::velox::parquet::thrift::BoundaryOrder. +/// \brief BoundaryOrder is a proxy around. +/// Facebook::velox::parquet::thrift::BoundaryOrder. struct BoundaryOrder { enum type { - Unordered = 0, - Ascending = 1, - Descending = 2, - // Should always be last element - UNDEFINED = 3 + kUnordered = 0, + kAscending = 1, + kDescending = 2, + // Should always be last element. + kUndefined = 3 }; }; -/// \brief SortingColumn is a proxy around -/// facebook::velox::parquet::thrift::SortingColumn. +/// \brief SortingColumn is a proxy around. +/// Facebook::velox::parquet::thrift::SortingColumn. struct PARQUET_EXPORT SortingColumn { // The column index (in this row group) - int32_t column_idx; + int32_t columnIdx; // If true, indicates this column is sorted in descending order. bool descending; - // If true, nulls will come before non-null values, otherwise, nulls go at the - // end. - bool nulls_first; + // If true, nulls will come before non-null values, otherwise, nulls go at + // the. End. + bool nullsFirst; }; inline bool operator==(const SortingColumn& left, const SortingColumn& right) { - return left.nulls_first == right.nulls_first && - left.descending == right.descending && - left.column_idx == right.column_idx; + return left.nullsFirst == right.nullsFirst && + left.descending == right.descending && left.columnIdx == right.columnIdx; } -// ---------------------------------------------------------------------- +// ----------------------------------------------------------------------. struct ByteArray { ByteArray() : len(0), ptr(NULLPTR) {} @@ -637,11 +634,11 @@ using FLBA = FixedLenByteArray; // Julian day at unix epoch. // -// The Julian Day Number (JDN) is the integer assigned to a whole solar day in -// the Julian day count starting from noon Universal time, with Julian day -// number 0 assigned to the day starting at noon on Monday, January 1, 4713 BC, -// proleptic Julian calendar (November 24, 4714 BC, in the proleptic Gregorian -// calendar), +// The Julian Day Number (JDN) is the integer assigned to a whole solar day in. +// The Julian day count starting from noon Universal time, with Julian day. +// Number 0 assigned to the day starting at noon on Monday, January 1, 4713 BC,. +// Proleptic Julian calendar (November 24, 4714 BC, in the proleptic Gregorian. +// Calendar),. constexpr int64_t kJulianToUnixEpochDays = INT64_C(2440588); constexpr int64_t kSecondsPerDay = INT64_C(60 * 60 * 24); constexpr int64_t kMillisecondsPerDay = kSecondsPerDay * INT64_C(1000); @@ -657,24 +654,24 @@ inline bool operator==(const Int96& left, const Int96& right) { return std::equal(left.value, left.value + 3, right.value); } -static inline std::string ByteArrayToString(const ByteArray& a) { +static inline std::string byteArrayToString(const ByteArray& a) { return std::string(reinterpret_cast(a.ptr), a.len); } -static inline void Int96SetNanoSeconds(Int96& i96, int64_t nanoseconds) { +static inline void int96SetNanoSeconds(Int96& i96, int64_t nanoseconds) { std::memcpy(&i96.value, &nanoseconds, sizeof(nanoseconds)); } struct DecodedInt96 { - uint64_t days_since_epoch; + uint64_t daysSinceEpoch; uint64_t nanoseconds; }; -static inline DecodedInt96 DecodeInt96Timestamp(const Int96& i96) { - // We do the computations in the unsigned domain to avoid unsigned behaviour - // on overflow. +static inline DecodedInt96 decodeInt96Timestamp(const Int96& i96) { + // We do the computations in the unsigned domain to avoid unsigned behaviour. + // On overflow. DecodedInt96 result; - result.days_since_epoch = + result.daysSinceEpoch = i96.value[2] - static_cast(kJulianToUnixEpochDays); result.nanoseconds = 0; @@ -682,40 +679,40 @@ static inline DecodedInt96 DecodeInt96Timestamp(const Int96& i96) { return result; } -static inline int64_t Int96GetNanoSeconds(const Int96& i96) { - const auto decoded = DecodeInt96Timestamp(i96); +static inline int64_t int96GetNanoSeconds(const Int96& i96) { + const auto decoded = decodeInt96Timestamp(i96); return static_cast( - decoded.days_since_epoch * kNanosecondsPerDay + decoded.nanoseconds); + decoded.daysSinceEpoch * kNanosecondsPerDay + decoded.nanoseconds); } -static inline int64_t Int96GetMicroSeconds(const Int96& i96) { - const auto decoded = DecodeInt96Timestamp(i96); +static inline int64_t int96GetMicroSeconds(const Int96& i96) { + const auto decoded = decodeInt96Timestamp(i96); uint64_t microseconds = decoded.nanoseconds / static_cast(1000); return static_cast( - decoded.days_since_epoch * kMicrosecondsPerDay + microseconds); + decoded.daysSinceEpoch * kMicrosecondsPerDay + microseconds); } -static inline int64_t Int96GetMilliSeconds(const Int96& i96) { - const auto decoded = DecodeInt96Timestamp(i96); +static inline int64_t int96GetMilliSeconds(const Int96& i96) { + const auto decoded = decodeInt96Timestamp(i96); uint64_t milliseconds = decoded.nanoseconds / static_cast(1000000); return static_cast( - decoded.days_since_epoch * kMillisecondsPerDay + milliseconds); + decoded.daysSinceEpoch * kMillisecondsPerDay + milliseconds); } -static inline int64_t Int96GetSeconds(const Int96& i96) { - const auto decoded = DecodeInt96Timestamp(i96); +static inline int64_t int96GetSeconds(const Int96& i96) { + const auto decoded = decodeInt96Timestamp(i96); uint64_t seconds = decoded.nanoseconds / static_cast(1000000000); return static_cast( - decoded.days_since_epoch * kSecondsPerDay + seconds); + decoded.daysSinceEpoch * kSecondsPerDay + seconds); } -static inline std::string Int96ToString(const Int96& a) { +static inline std::string int96ToString(const Int96& a) { std::ostringstream result; std::copy(a.value, a.value + 3, std::ostream_iterator(result, " ")); return result.str(); } -static inline std::string FixedLenByteArrayToString( +static inline std::string fixedLenByteArrayToString( const FixedLenByteArray& a, int len) { std::ostringstream result; @@ -724,115 +721,115 @@ static inline std::string FixedLenByteArrayToString( } template -struct type_traits {}; +struct TypeTraits {}; template <> -struct type_traits { - using value_type = bool; +struct TypeTraits { + using ValueType = bool; - static constexpr int value_byte_size = 1; - static constexpr const char* printf_code = "d"; + static constexpr int valueByteSize = 1; + static constexpr const char* printfCode = "d"; }; template <> -struct type_traits { - using value_type = int32_t; +struct TypeTraits { + using ValueType = int32_t; - static constexpr int value_byte_size = 4; - static constexpr const char* printf_code = "d"; + static constexpr int valueByteSize = 4; + static constexpr const char* printfCode = "d"; }; template <> -struct type_traits { - using value_type = int64_t; +struct TypeTraits { + using ValueType = int64_t; - static constexpr int value_byte_size = 8; - static constexpr const char* printf_code = + static constexpr int valueByteSize = 8; + static constexpr const char* printfCode = (sizeof(long) == 64) ? "ld" : "lld"; // NOLINT: runtime/int }; template <> -struct type_traits { - using value_type = Int96; +struct TypeTraits { + using ValueType = Int96; - static constexpr int value_byte_size = 12; - static constexpr const char* printf_code = "s"; + static constexpr int valueByteSize = 12; + static constexpr const char* printfCode = "s"; }; template <> -struct type_traits { - using value_type = float; +struct TypeTraits { + using ValueType = float; - static constexpr int value_byte_size = 4; - static constexpr const char* printf_code = "f"; + static constexpr int valueByteSize = 4; + static constexpr const char* printfCode = "f"; }; template <> -struct type_traits { - using value_type = double; +struct TypeTraits { + using ValueType = double; - static constexpr int value_byte_size = 8; - static constexpr const char* printf_code = "lf"; + static constexpr int valueByteSize = 8; + static constexpr const char* printfCode = "lf"; }; template <> -struct type_traits { - using value_type = ByteArray; +struct TypeTraits { + using ValueType = ByteArray; - static constexpr int value_byte_size = sizeof(ByteArray); - static constexpr const char* printf_code = "s"; + static constexpr int valueByteSize = sizeof(ByteArray); + static constexpr const char* printfCode = "s"; }; template <> -struct type_traits { - using value_type = FixedLenByteArray; +struct TypeTraits { + using ValueType = FixedLenByteArray; - static constexpr int value_byte_size = sizeof(FixedLenByteArray); - static constexpr const char* printf_code = "s"; + static constexpr int valueByteSize = sizeof(FixedLenByteArray); + static constexpr const char* printfCode = "s"; }; template struct PhysicalType { - using c_type = typename type_traits::value_type; - static constexpr Type::type type_num = TYPE; + using CType = typename TypeTraits::ValueType; + static constexpr Type::type typeNum = TYPE; }; -using BooleanType = PhysicalType; -using Int32Type = PhysicalType; -using Int64Type = PhysicalType; -using Int96Type = PhysicalType; -using FloatType = PhysicalType; -using DoubleType = PhysicalType; -using ByteArrayType = PhysicalType; -using FLBAType = PhysicalType; +using BooleanType = PhysicalType; +using Int32Type = PhysicalType; +using Int64Type = PhysicalType; +using Int96Type = PhysicalType; +using FloatType = PhysicalType; +using DoubleType = PhysicalType; +using ByteArrayType = PhysicalType; +using FLBAType = PhysicalType; template -inline std::string format_fwf(int width) { +inline std::string formatFwf(int width) { std::stringstream ss; - ss << "%-" << width << type_traits::printf_code; + ss << "%-" << width << TypeTraits::printfCode; return ss.str(); } -PARQUET_EXPORT std::string EncodingToString(Encoding::type t); +PARQUET_EXPORT std::string encodingToString(Encoding::type t); -PARQUET_EXPORT std::string ConvertedTypeToString(ConvertedType::type t); +PARQUET_EXPORT std::string convertedTypeToString(ConvertedType::type t); -PARQUET_EXPORT std::string TypeToString(Type::type t); +PARQUET_EXPORT std::string typeToString(Type::type t); -PARQUET_EXPORT std::string FormatStatValue( - Type::type parquet_type, +PARQUET_EXPORT std::string formatStatValue( + Type::type parquetType, ::std::string_view val); -PARQUET_EXPORT int GetTypeByteSize(Type::type t); +PARQUET_EXPORT int getTypeByteSize(Type::type t); -PARQUET_EXPORT SortOrder::type DefaultSortOrder(Type::type primitive); +PARQUET_EXPORT SortOrder::type defaultSortOrder(Type::type primitive); -PARQUET_EXPORT SortOrder::type GetSortOrder( +PARQUET_EXPORT SortOrder::type getSortOrder( ConvertedType::type converted, Type::type primitive); -PARQUET_EXPORT SortOrder::type GetSortOrder( - const std::shared_ptr& logical_type, +PARQUET_EXPORT SortOrder::type getSortOrder( + const std::shared_ptr& logicalType, Type::type primitive); } // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/Writer.cpp b/velox/dwio/parquet/writer/arrow/Writer.cpp index e6572bb764d..1d078d8bb7b 100644 --- a/velox/dwio/parquet/writer/arrow/Writer.cpp +++ b/velox/dwio/parquet/writer/arrow/Writer.cpp @@ -75,233 +75,232 @@ using schema::GroupNode; namespace { -int CalculateLeafCount(const DataType* type) { +int calculateLeafCount(const DataType* type) { if (type->id() == ::arrow::Type::EXTENSION) { type = checked_cast(*type).storage_type().get(); } - // Note num_fields() can be 0 for an empty struct type + // Note numFields() can be 0 for an empty struct type. if (!::arrow::is_nested(type->id())) { // Primitive type. return 1; } - int num_leaves = 0; + int numLeaves = 0; for (const auto& field : type->fields()) { - num_leaves += CalculateLeafCount(field->type().get()); + numLeaves += calculateLeafCount(field->type().get()); } - return num_leaves; + return numLeaves; } // Determines if the |schema_field|'s root ancestor is nullable. -bool HasNullableRoot( - const SchemaManifest& schema_manifest, - const SchemaField* schema_field) { - VELOX_DCHECK_NOT_NULL(schema_field); - const SchemaField* current_field = schema_field; - bool nullable = schema_field->field->nullable(); - while (current_field != nullptr) { - nullable = current_field->field->nullable(); - current_field = schema_manifest.GetParent(current_field); +bool hasNullableRoot( + const SchemaManifest& schemaManifest, + const SchemaField* schemaField) { + VELOX_DCHECK_NOT_NULL(schemaField); + const SchemaField* currentField = schemaField; + bool nullable = schemaField->field->nullable(); + while (currentField != nullptr) { + nullable = currentField->field->nullable(); + currentField = schemaManifest.getParent(currentField); } return nullable; } -// Manages writing nested parquet columns with support for all nested types -// supported by parquet. +// Manages writing nested parquet columns with support for all nested types. +// Supported by Parquet. class ArrowColumnWriterV2 { public: - // Constructs a new object (use Make() method below to construct from - // A ChunkedArray). - // level_builders should contain one MultipathLevelBuilder per chunk of the - // Arrow-column to write. + // Constructs a new object (use make() method below to construct from + // a ChunkedArray). + // LevelBuilders should contain one MultipathLevelBuilder per chunk of the + // arrow-column to write. ArrowColumnWriterV2( - std::vector> level_builders, - int start_leaf_column_index, - int leaf_count, - RowGroupWriter* row_group_writer) - : level_builders_(std::move(level_builders)), - start_leaf_column_index_(start_leaf_column_index), - leaf_count_(leaf_count), - row_group_writer_(row_group_writer) {} - - // Writes out all leaf parquet columns to the RowGroupWriter that this + std::vector> levelBuilders, + int startLeafColumnIndex, + int leafCount, + RowGroupWriter* rowGroupWriter) + : levelBuilders_(std::move(levelBuilders)), + startLeafColumnIndex_(startLeafColumnIndex), + leafCount_(leafCount), + rowGroupWriter_(rowGroupWriter) {} + + // Writes out all leaf Parquet columns to the rowGroupWriter that this // object was constructed with. Each leaf column is written fully before // the next column is written (i.e. no buffering is assumed). // // Columns are written in DFS order. - Status Write(ArrowWriteContext* ctx) { - for (int leaf_idx = 0; leaf_idx < leaf_count_; leaf_idx++) { - ColumnWriter* column_writer; - if (row_group_writer_->buffered()) { - const int column_index = start_leaf_column_index_ + leaf_idx; + Status write(ArrowWriteContext* ctx) { + for (int leafIdx = 0; leafIdx < leafCount_; leafIdx++) { + ColumnWriter* columnWriter; + if (rowGroupWriter_->buffered()) { + const int columnIndex = startLeafColumnIndex_ + leafIdx; PARQUET_CATCH_NOT_OK( - column_writer = row_group_writer_->column(column_index)); + columnWriter = rowGroupWriter_->column(columnIndex)); } else { - PARQUET_CATCH_NOT_OK(column_writer = row_group_writer_->NextColumn()); + PARQUET_CATCH_NOT_OK(columnWriter = rowGroupWriter_->nextColumn()); } - for (auto& level_builder : level_builders_) { - RETURN_NOT_OK(level_builder->Write( - leaf_idx, ctx, [&](const MultipathLevelBuilderResult& result) { - size_t visited_component_size = - result.post_list_visited_elements.size(); - VELOX_DCHECK_GT(visited_component_size, 0); - if (visited_component_size != 1) { + for (auto& levelBuilder : levelBuilders_) { + RETURN_NOT_OK(levelBuilder->write( + leafIdx, ctx, [&](const MultipathLevelBuilderResult& result) { + size_t visitedComponentSize = + result.postListVisitedElements.size(); + VELOX_DCHECK_GT(visitedComponentSize, 0); + if (visitedComponentSize != 1) { return Status::NotImplemented( "Lists with non-zero length null components are not supported"); } - const ElementRange& range = result.post_list_visited_elements[0]; - std::shared_ptr values_array = - result.leaf_array->Slice(range.start, range.Size()); - - return column_writer->WriteArrow( - result.def_levels, - result.rep_levels, - result.def_rep_level_count, - *values_array, + const ElementRange& range = result.postListVisitedElements[0]; + std::shared_ptr valuesArray = + result.leafArray->Slice(range.start, range.size()); + + return columnWriter->writeArrow( + result.defLevels, + result.repLevels, + result.defRepLevelCount, + *valuesArray, ctx, - result.leaf_is_nullable); + result.leafIsNullable); })); } - if (!row_group_writer_->buffered()) { - PARQUET_CATCH_NOT_OK(column_writer->Close()); + if (!rowGroupWriter_->buffered()) { + PARQUET_CATCH_NOT_OK(columnWriter->close()); } } return Status::OK(); } - // Make a new object by converting each chunk in |data| to a // MultipathLevelBuilder. // // It is necessary to create a new builder per array because the // MultipathlevelBuilder extracts the data necessary for writing each leaf - // column at construction time. (it optimizes based on null count) and with + // column at construction time (it optimizes based on null count) and with // slicing via |offset| ephemeral chunks are created which need to be tracked - // across each leaf column-write. This decision could potentially be revisited - // if we wanted to use "buffered" RowGroupWriters (we could construct each - // builder on demand in that case). - static ::arrow::Result> Make( + // across each leaf column-write. This decision could potentially be + // revisited if we wanted to use "buffered" RowGroupWriters (we could + // construct each builder on demand in that case). + static ::arrow::Result> make( const ChunkedArray& data, int64_t offset, const int64_t size, - const SchemaManifest& schema_manifest, - RowGroupWriter* row_group_writer, - int start_leaf_column_index = -1) { - int64_t absolute_position = 0; - int chunk_index = 0; - int64_t chunk_offset = 0; + const SchemaManifest& schemaManifest, + RowGroupWriter* rowGroupWriter, + int startLeafColumnIndex = -1) { + int64_t absolutePosition = 0; + int chunkIndex = 0; + int64_t chunkOffset = 0; if (data.length() == 0) { return std::make_unique( std::vector>{}, - start_leaf_column_index, - CalculateLeafCount(data.type().get()), - row_group_writer); + startLeafColumnIndex, + calculateLeafCount(data.type().get()), + rowGroupWriter); } - while (chunk_index < data.num_chunks() && absolute_position < offset) { - const int64_t chunk_length = data.chunk(chunk_index)->length(); - if (absolute_position + chunk_length > offset) { + while (chunkIndex < data.num_chunks() && absolutePosition < offset) { + const int64_t chunkLength = data.chunk(chunkIndex)->length(); + if (absolutePosition + chunkLength > offset) { // Relative offset into the chunk to reach the desired start offset for - // writing - chunk_offset = offset - absolute_position; + // writing. + chunkOffset = offset - absolutePosition; break; } else { - ++chunk_index; - absolute_position += chunk_length; + ++chunkIndex; + absolutePosition += chunkLength; } } - if (absolute_position >= data.length()) { + if (absolutePosition >= data.length()) { return Status::Invalid( "Cannot write data at offset past end of chunked array"); } - int64_t values_written = 0; + int64_t valuesWritten = 0; std::vector> builders; - const int leaf_count = CalculateLeafCount(data.type().get()); - bool is_nullable = false; + const int leafCount = calculateLeafCount(data.type().get()); + bool isNullable = false; - int column_index = 0; - if (row_group_writer->buffered()) { - column_index = start_leaf_column_index; + int columnIndex = 0; + if (rowGroupWriter->buffered()) { + columnIndex = startLeafColumnIndex; } else { - // The row_group_writer hasn't been advanced yet so add 1 to the current + // The rowGroupWriter hasn't been advanced yet so add 1 to the current // which is the one this instance will start writing for. - column_index = row_group_writer->current_column() + 1; + columnIndex = rowGroupWriter->currentColumn() + 1; } - for (int leaf_offset = 0; leaf_offset < leaf_count; ++leaf_offset) { - const SchemaField* schema_field = nullptr; - RETURN_NOT_OK(schema_manifest.GetColumnField( - column_index + leaf_offset, &schema_field)); - bool nullable_root = HasNullableRoot(schema_manifest, schema_field); - if (leaf_offset == 0) { - is_nullable = nullable_root; + for (int leafOffset = 0; leafOffset < leafCount; ++leafOffset) { + const SchemaField* schemaField = nullptr; + RETURN_NOT_OK(schemaManifest.getColumnField( + columnIndex + leafOffset, &schemaField)); + bool nullableRoot = hasNullableRoot(schemaManifest, schemaField); + if (leafOffset == 0) { + isNullable = nullableRoot; } // Don't validate common ancestry for all leafs if not in debug. #ifndef NDEBUG break; #else - if (is_nullable != nullable_root) { + if (isNullable != nullableRoot) { return Status::UnknownError( "Unexpected mismatched nullability between column index", - column_index + leaf_offset, + columnIndex + leafOffset, " and ", - column_index); + columnIndex); } #endif } - while (values_written < size) { - const Array& chunk = *data.chunk(chunk_index); - const int64_t available_values = chunk.length() - chunk_offset; - const int64_t chunk_write_size = - std::min(size - values_written, available_values); + while (valuesWritten < size) { + const Array& chunk = *data.chunk(chunkIndex); + const int64_t availableValues = chunk.length() - chunkOffset; + const int64_t chunkWriteSize = + std::min(size - valuesWritten, availableValues); // The chunk offset here will be 0 except for possibly the first chunk - // because of the advancing logic above - std::shared_ptr array_to_write = - chunk.Slice(chunk_offset, chunk_write_size); + // because of the advancing logic above. + std::shared_ptr arrayToWrite = + chunk.Slice(chunkOffset, chunkWriteSize); - if (array_to_write->length() > 0) { + if (arrayToWrite->length() > 0) { ARROW_ASSIGN_OR_RAISE( std::unique_ptr builder, - MultipathLevelBuilder::Make(*array_to_write, is_nullable)); - if (leaf_count != builder->GetLeafCount()) { + MultipathLevelBuilder::make(*arrayToWrite, isNullable)); + if (leafCount != builder->getLeafCount()) { return Status::UnknownError( "data type leaf_count != builder_leaf_count", - leaf_count, + leafCount, " ", - builder->GetLeafCount()); + builder->getLeafCount()); } builders.emplace_back(std::move(builder)); } - if (chunk_write_size == available_values) { - chunk_offset = 0; - ++chunk_index; + if (chunkWriteSize == availableValues) { + chunkOffset = 0; + ++chunkIndex; } - values_written += chunk_write_size; + valuesWritten += chunkWriteSize; } return std::make_unique( - std::move(builders), column_index, leaf_count, row_group_writer); + std::move(builders), columnIndex, leafCount, rowGroupWriter); } - int leaf_count() const { - return leaf_count_; + int leafCount() const { + return leafCount_; } private: // One builder per column-chunk. - std::vector> level_builders_; - int start_leaf_column_index_; - int leaf_count_; - RowGroupWriter* row_group_writer_; + std::vector> levelBuilders_; + int startLeafColumnIndex_; + int leafCount_; + RowGroupWriter* rowGroupWriter_; }; } // namespace // ---------------------------------------------------------------------- -// FileWriter implementation +// FileWriter implementation. class FileWriterImpl : public FileWriter { public: @@ -309,94 +308,93 @@ class FileWriterImpl : public FileWriter { std::shared_ptr<::arrow::Schema> schema, MemoryPool* pool, std::unique_ptr writer, - std::shared_ptr arrow_properties) + std::shared_ptr arrowProperties) : schema_(std::move(schema)), writer_(std::move(writer)), - row_group_writer_(nullptr), - column_write_context_(pool, arrow_properties.get()), - arrow_properties_(std::move(arrow_properties)), + rowGroupWriter_(nullptr), + columnWriteContext_(pool, arrowProperties.get()), + arrowProperties_(std::move(arrowProperties)), closed_(false) { - if (arrow_properties_->use_threads()) { - parallel_column_write_contexts_.reserve(schema_->num_fields()); + if (arrowProperties_->useThreads()) { + parallelColumnWriteContexts_.reserve(schema_->num_fields()); for (int i = 0; i < schema_->num_fields(); ++i) { // Explicitly create each ArrowWriteContext object to avoid - // unintentional call of the copy constructor. Otherwise, the buffers in - // the type of sharad_ptr will be shared among all contexts. - parallel_column_write_contexts_.emplace_back( - pool, arrow_properties_.get()); + // unintentional call of the copy constructor. Otherwise, the buffers + // in the type of shared_ptr will be shared among all contexts. + parallelColumnWriteContexts_.emplace_back(pool, arrowProperties_.get()); } } } - Status Init() { - return SchemaManifest::Make( + Status init() { + return SchemaManifest::make( writer_->schema(), - /*schema_metadata=*/nullptr, - default_arrow_reader_properties(), - &schema_manifest_); + nullptr, + defaultArrowReaderProperties(), + &schemaManifest_); } - Status NewRowGroup(int64_t chunk_size) override { - if (row_group_writer_ != nullptr) { - PARQUET_CATCH_NOT_OK(row_group_writer_->Close()); + Status newRowGroup(int64_t chunkSize) override { + if (rowGroupWriter_ != nullptr) { + PARQUET_CATCH_NOT_OK(rowGroupWriter_->close()); } - PARQUET_CATCH_NOT_OK(row_group_writer_ = writer_->AppendRowGroup()); + PARQUET_CATCH_NOT_OK(rowGroupWriter_ = writer_->appendRowGroup()); return Status::OK(); } - Status Close() override { + Status close() override { if (!closed_) { - // Make idempotent + // Make idempotent. closed_ = true; - if (row_group_writer_ != nullptr) { - PARQUET_CATCH_NOT_OK(row_group_writer_->Close()); + if (rowGroupWriter_ != nullptr) { + PARQUET_CATCH_NOT_OK(rowGroupWriter_->close()); } - PARQUET_CATCH_NOT_OK(writer_->Close()); + PARQUET_CATCH_NOT_OK(writer_->close()); } return Status::OK(); } - Status WriteColumnChunk(const Array& data) override { + Status writeColumnChunk(const Array& data) override { // A bit awkward here since cannot instantiate ChunkedArray from const - // Array& + // Array&. auto chunk = ::arrow::MakeArray(data.data()); - auto chunked_array = std::make_shared<::arrow::ChunkedArray>(chunk); - return WriteColumnChunk(chunked_array, 0, data.length()); + auto chunkedArray = std::make_shared<::arrow::ChunkedArray>(chunk); + return writeColumnChunk(chunkedArray, 0, data.length()); } - Status WriteColumnChunk( + Status writeColumnChunk( const std::shared_ptr& data, int64_t offset, int64_t size) override { - if (arrow_properties_->engine_version() == ArrowWriterProperties::V2 || - arrow_properties_->engine_version() == ArrowWriterProperties::V1) { - if (row_group_writer_->buffered()) { + if (arrowProperties_->engineVersion() == ArrowWriterProperties::V2 || + arrowProperties_->engineVersion() == ArrowWriterProperties::V1) { + if (rowGroupWriter_->buffered()) { return Status::Invalid( "Cannot write column chunk into the buffered row group."); } ARROW_ASSIGN_OR_RAISE( std::unique_ptr writer, - ArrowColumnWriterV2::Make( - *data, offset, size, schema_manifest_, row_group_writer_)); - return writer->Write(&column_write_context_); + ArrowColumnWriterV2::make( + *data, offset, size, schemaManifest_, rowGroupWriter_)); + return writer->write(&columnWriteContext_); } return Status::NotImplemented("Unknown engine version."); } - Status WriteColumnChunk( + Status writeColumnChunk( const std::shared_ptr<::arrow::ChunkedArray>& data) override { - return WriteColumnChunk(data, 0, data->length()); + return writeColumnChunk(data, 0, data->length()); } std::shared_ptr<::arrow::Schema> schema() const override { return schema_; } - Status WriteTable(const Table& table, int64_t chunk_size) override { + Status writeTable(const Table& table, int64_t chunkSize) override { RETURN_NOT_OK(table.Validate()); - if (chunk_size <= 0 && table.num_rows() > 0) { + if (chunkSize <= 0 && table.num_rows() > 0) { return Status::Invalid("chunk size per row_group must be greater than 0"); } else if (!table.schema()->Equals(*schema_, false)) { return Status::Invalid( @@ -405,87 +403,86 @@ class FileWriterImpl : public FileWriter { "' this:'", schema_->ToString(), "'"); - } else if (chunk_size > this->properties().max_row_group_length()) { - chunk_size = this->properties().max_row_group_length(); + } else if (chunkSize > this->properties().maxRowGroupLength()) { + chunkSize = this->properties().maxRowGroupLength(); } - auto WriteRowGroup = [&](int64_t offset, int64_t size) { - RETURN_NOT_OK(NewRowGroup(size)); + auto writeRowGroup = [&](int64_t offset, int64_t size) { + RETURN_NOT_OK(newRowGroup(size)); for (int i = 0; i < table.num_columns(); i++) { - RETURN_NOT_OK(WriteColumnChunk(table.column(i), offset, size)); + RETURN_NOT_OK(writeColumnChunk(table.column(i), offset, size)); } return Status::OK(); }; if (table.num_rows() == 0) { - // Append a row group with 0 rows - RETURN_NOT_OK_ELSE(WriteRowGroup(0, 0), PARQUET_IGNORE_NOT_OK(Close())); + // Append a row group with 0 rows. + RETURN_NOT_OK_ELSE(writeRowGroup(0, 0), PARQUET_IGNORE_NOT_OK(close())); return Status::OK(); } - for (int chunk = 0; chunk * chunk_size < table.num_rows(); chunk++) { - int64_t offset = chunk * chunk_size; + for (int chunk = 0; chunk * chunkSize < table.num_rows(); chunk++) { + int64_t offset = chunk * chunkSize; RETURN_NOT_OK_ELSE( - WriteRowGroup( - offset, std::min(chunk_size, table.num_rows() - offset)), - PARQUET_IGNORE_NOT_OK(Close())); + writeRowGroup(offset, std::min(chunkSize, table.num_rows() - offset)), + PARQUET_IGNORE_NOT_OK(close())); } return Status::OK(); } - Status NewBufferedRowGroup() override { - if (row_group_writer_ != nullptr) { - PARQUET_CATCH_NOT_OK(row_group_writer_->Close()); + Status newBufferedRowGroup() override { + if (rowGroupWriter_ != nullptr) { + PARQUET_CATCH_NOT_OK(rowGroupWriter_->close()); } - PARQUET_CATCH_NOT_OK(row_group_writer_ = writer_->AppendBufferedRowGroup()); + PARQUET_CATCH_NOT_OK(rowGroupWriter_ = writer_->appendBufferedRowGroup()); return Status::OK(); } - Status WriteRecordBatch(const RecordBatch& batch) override { + Status writeRecordBatch(const RecordBatch& batch) override { if (batch.num_rows() == 0) { return Status::OK(); } // Max number of rows allowed in a row group. - const int64_t max_row_group_length = - this->properties().max_row_group_length(); + const int64_t maxRowGroupLength = this->properties().maxRowGroupLength(); - if (row_group_writer_ == nullptr || !row_group_writer_->buffered() || - row_group_writer_->num_rows() >= max_row_group_length) { - RETURN_NOT_OK(NewBufferedRowGroup()); + if (rowGroupWriter_ == nullptr || !rowGroupWriter_->buffered() || + rowGroupWriter_->numRows() >= maxRowGroupLength) { + RETURN_NOT_OK(newBufferedRowGroup()); } - auto WriteBatch = [&](int64_t offset, int64_t size) { + auto writeBatch = [&](int64_t offset, int64_t size) { std::vector> writers; - int column_index_start = 0; + int columnIndexStart = 0; for (int i = 0; i < batch.num_columns(); i++) { - ChunkedArray chunked_array{batch.column(i)}; + ChunkedArray chunkedArray{batch.column(i)}; ARROW_ASSIGN_OR_RAISE( std::unique_ptr writer, - ArrowColumnWriterV2::Make( - chunked_array, + ArrowColumnWriterV2::make( + chunkedArray, offset, size, - schema_manifest_, - row_group_writer_, - column_index_start)); - column_index_start += writer->leaf_count(); - if (arrow_properties_->use_threads()) { + schemaManifest_, + rowGroupWriter_, + columnIndexStart)); + columnIndexStart += writer->leafCount(); + if (arrowProperties_->useThreads()) { writers.emplace_back(std::move(writer)); } else { - RETURN_NOT_OK(writer->Write(&column_write_context_)); + RETURN_NOT_OK(writer->write(&columnWriteContext_)); } } - if (arrow_properties_->use_threads()) { - VELOX_DCHECK_EQ(parallel_column_write_contexts_.size(), writers.size()); - RETURN_NOT_OK(::arrow::internal::ParallelFor( - static_cast(writers.size()), - [&](int i) { - return writers[i]->Write(¶llel_column_write_contexts_[i]); - }, - arrow_properties_->executor())); + if (arrowProperties_->useThreads()) { + VELOX_DCHECK_EQ(parallelColumnWriteContexts_.size(), writers.size()); + RETURN_NOT_OK( + ::arrow::internal::ParallelFor( + static_cast(writers.size()), + [&](int i) { + return writers[i]->write(¶llelColumnWriteContexts_[i]); + }, + arrowProperties_->executor())); } return Status::OK(); @@ -493,15 +490,15 @@ class FileWriterImpl : public FileWriter { int64_t offset = 0; while (offset < batch.num_rows()) { - const int64_t batch_size = std::min( - max_row_group_length - row_group_writer_->num_rows(), + const int64_t batchSize = std::min( + maxRowGroupLength - rowGroupWriter_->numRows(), batch.num_rows() - offset); - RETURN_NOT_OK(WriteBatch(offset, batch_size)); - offset += batch_size; + RETURN_NOT_OK(writeBatch(offset, batchSize)); + offset += batchSize; // Flush current row group if it is full. - if (row_group_writer_->num_rows() >= max_row_group_length) { - RETURN_NOT_OK(NewBufferedRowGroup()); + if (rowGroupWriter_->numRows() >= maxRowGroupLength) { + RETURN_NOT_OK(newBufferedRowGroup()); } } @@ -512,8 +509,8 @@ class FileWriterImpl : public FileWriter { return *writer_->properties(); } - ::arrow::MemoryPool* memory_pool() const override { - return column_write_context_.memory_pool; + ::arrow::MemoryPool* memoryPool() const override { + return columnWriteContext_.memoryPool; } const std::shared_ptr metadata() const override { @@ -525,36 +522,36 @@ class FileWriterImpl : public FileWriter { std::shared_ptr<::arrow::Schema> schema_; - SchemaManifest schema_manifest_; + SchemaManifest schemaManifest_; std::unique_ptr writer_; - RowGroupWriter* row_group_writer_; - ArrowWriteContext column_write_context_; - std::shared_ptr arrow_properties_; + RowGroupWriter* rowGroupWriter_; + ArrowWriteContext columnWriteContext_; + std::shared_ptr arrowProperties_; bool closed_; - /// If arrow_properties_.use_threads() is true, the vector size is equal to + /// If arrowProperties_->useThreads() is true, the vector size is equal to /// schema_->num_fields() to make it thread-safe. Otherwise, the vector is - /// empty and column_write_context_ above is shared by all columns. - std::vector parallel_column_write_contexts_; + /// empty and columnWriteContext_ above is shared by all columns. + std::vector parallelColumnWriteContexts_; }; FileWriter::~FileWriter() {} -Status FileWriter::Make( +Status FileWriter::make( ::arrow::MemoryPool* pool, std::unique_ptr writer, std::shared_ptr<::arrow::Schema> schema, - std::shared_ptr arrow_properties, + std::shared_ptr arrowProperties, std::unique_ptr* out) { std::unique_ptr impl(new FileWriterImpl( - std::move(schema), pool, std::move(writer), std::move(arrow_properties))); - RETURN_NOT_OK(impl->Init()); + std::move(schema), pool, std::move(writer), std::move(arrowProperties))); + RETURN_NOT_OK(impl->init()); *out = std::move(impl); return Status::OK(); } -Status FileWriter::Open( +Status FileWriter::open( const ::arrow::Schema& schema, ::arrow::MemoryPool* pool, std::shared_ptr<::arrow::io::OutputStream> sink, @@ -562,21 +559,21 @@ Status FileWriter::Open( std::unique_ptr* writer) { ARROW_ASSIGN_OR_RAISE( *writer, - Open( + open( std::move(schema), pool, std::move(sink), std::move(properties), - default_arrow_writer_properties())); + defaultArrowWriterProperties())); return Status::OK(); } -Status GetSchemaMetadata( +Status getSchemaMetadata( const ::arrow::Schema& schema, ::arrow::MemoryPool* pool, const ArrowWriterProperties& properties, std::shared_ptr* out) { - if (!properties.store_schema()) { + if (!properties.storeSchema()) { *out = nullptr; return Status::OK(); } @@ -593,102 +590,102 @@ Status GetSchemaMetadata( std::shared_ptr serialized, ::arrow::ipc::SerializeSchema(schema, pool)); - // The serialized schema is not UTF-8, which is required for Thrift - std::string schema_as_string = serialized->ToString(); - std::string schema_base64 = ::arrow::util::base64_encode(schema_as_string); - result->Append(kArrowSchemaKey, schema_base64); + // The serialized schema is not UTF-8, which is required for Thrift. + std::string schemaAsString = serialized->ToString(); + std::string schemaBase64 = ::arrow::util::base64_encode(schemaAsString); + result->Append(kArrowSchemaKey, schemaBase64); *out = result; return Status::OK(); } -Status FileWriter::Open( +Status FileWriter::open( const ::arrow::Schema& schema, ::arrow::MemoryPool* pool, std::shared_ptr<::arrow::io::OutputStream> sink, std::shared_ptr properties, - std::shared_ptr arrow_properties, + std::shared_ptr arrowProperties, std::unique_ptr* writer) { ARROW_ASSIGN_OR_RAISE( *writer, - Open( + open( std::move(schema), pool, std::move(sink), std::move(properties), - arrow_properties)); + arrowProperties)); return Status::OK(); } -Result> FileWriter::Open( +Result> FileWriter::open( const ::arrow::Schema& schema, ::arrow::MemoryPool* pool, std::shared_ptr<::arrow::io::OutputStream> sink, std::shared_ptr properties, - std::shared_ptr arrow_properties) { - std::shared_ptr parquet_schema; - RETURN_NOT_OK(ToParquetSchema( - &schema, *properties, *arrow_properties, &parquet_schema)); + std::shared_ptr arrowProperties) { + std::shared_ptr parquetSchema; + RETURN_NOT_OK( + toParquetSchema(&schema, *properties, *arrowProperties, &parquetSchema)); - auto schema_node = - std::static_pointer_cast(parquet_schema->schema_root()); + auto schemaNode = + std::static_pointer_cast(parquetSchema->schemaRoot()); std::shared_ptr metadata; - RETURN_NOT_OK(GetSchemaMetadata(schema, pool, *arrow_properties, &metadata)); + RETURN_NOT_OK(getSchemaMetadata(schema, pool, *arrowProperties, &metadata)); - std::unique_ptr base_writer; + std::unique_ptr baseWriter; PARQUET_CATCH_NOT_OK( - base_writer = ParquetFileWriter::Open( + baseWriter = ParquetFileWriter::open( std::move(sink), - schema_node, + schemaNode, std::move(properties), std::move(metadata))); std::unique_ptr writer; - auto schema_ptr = std::make_shared<::arrow::Schema>(schema); - RETURN_NOT_OK(Make( + auto schemaPtr = std::make_shared<::arrow::Schema>(schema); + RETURN_NOT_OK(make( pool, - std::move(base_writer), - std::move(schema_ptr), - std::move(arrow_properties), + std::move(baseWriter), + std::move(schemaPtr), + std::move(arrowProperties), &writer)); return writer; } -Status WriteFileMetaData( - const FileMetaData& file_metadata, +Status writeFileMetaData( + const FileMetaData& fileMetadata, ::arrow::io::OutputStream* sink) { - PARQUET_CATCH_NOT_OK(::facebook::velox::parquet::arrow::WriteFileMetaData( - file_metadata, sink)); + PARQUET_CATCH_NOT_OK( + ::facebook::velox::parquet::arrow::writeFileMetaData(fileMetadata, sink)); return Status::OK(); } -Status WriteMetaDataFile( - const FileMetaData& file_metadata, +Status writeMetaDataFile( + const FileMetaData& fileMetadata, ::arrow::io::OutputStream* sink) { - PARQUET_CATCH_NOT_OK(::facebook::velox::parquet::arrow::WriteMetaDataFile( - file_metadata, sink)); + PARQUET_CATCH_NOT_OK( + ::facebook::velox::parquet::arrow::writeMetaDataFile(fileMetadata, sink)); return Status::OK(); } -Status WriteTable( +Status writeTable( const ::arrow::Table& table, ::arrow::MemoryPool* pool, std::shared_ptr<::arrow::io::OutputStream> sink, - int64_t chunk_size, + int64_t chunkSize, std::shared_ptr properties, - std::shared_ptr arrow_properties) { + std::shared_ptr arrowProperties) { std::unique_ptr writer; ARROW_ASSIGN_OR_RAISE( writer, - FileWriter::Open( + FileWriter::open( *table.schema(), pool, std::move(sink), std::move(properties), - std::move(arrow_properties))); - RETURN_NOT_OK(writer->WriteTable(table, chunk_size)); - return writer->Close(); + std::move(arrowProperties))); + RETURN_NOT_OK(writer->writeTable(table, chunkSize)); + return writer->close(); } } // namespace facebook::velox::parquet::arrow::arrow diff --git a/velox/dwio/parquet/writer/arrow/Writer.h b/velox/dwio/parquet/writer/arrow/Writer.h index 441b4096acb..840d11cbc4e 100644 --- a/velox/dwio/parquet/writer/arrow/Writer.h +++ b/velox/dwio/parquet/writer/arrow/Writer.h @@ -41,47 +41,47 @@ class ParquetFileWriter; namespace arrow { -/// \brief Iterative FileWriter class +/// \brief Iterative FileWriter class. /// /// For basic usage, can write a Table at a time, creating one or more row /// groups per write call. /// -/// For advanced usage, can write column-by-column: Start a new RowGroup or -/// Chunk with NewRowGroup, then write column-by-column the whole column chunk. +/// For advanced usage, can write column-by-column: Start a new row group or +/// chunk with newRowGroup(), then write column-by-column the whole column +/// chunk. /// /// If PARQUET:field_id is present as a metadata key on a field, and the /// corresponding value is a nonnegative integer, then it will be used as the -/// field_id in the parquet file. +/// field_id in the Parquet file. class PARQUET_EXPORT FileWriter { public: - static ::arrow::Status Make( + static ::arrow::Status make( MemoryPool* pool, std::unique_ptr writer, std::shared_ptr<::arrow::Schema> schema, - std::shared_ptr arrow_properties, + std::shared_ptr arrowProperties, std::unique_ptr* out); /// \brief Try to create an Arrow to Parquet file writer. /// - /// \param schema schema of data that will be passed. - /// \param pool memory pool to use. - /// \param sink output stream to write Parquet data. - /// \param properties general Parquet writer properties. + /// \param schema Schema of data that will be passed. + /// \param pool Memory pool to use. + /// \param sink Output stream to write Parquet data. + /// \param properties General Parquet writer properties. /// \param arrow_properties Arrow-specific writer properties. /// - /// \since 11.0.0 - static ::arrow::Result> Open( + /// \since 11.0.0. + static ::arrow::Result> open( const ::arrow::Schema& schema, MemoryPool* pool, std::shared_ptr<::arrow::io::OutputStream> sink, - std::shared_ptr properties = - default_writer_properties(), - std::shared_ptr arrow_properties = - default_arrow_writer_properties()); + std::shared_ptr properties = defaultWriterProperties(), + std::shared_ptr arrowProperties = + defaultArrowWriterProperties()); ARROW_DEPRECATED( "Deprecated in 11.0.0. Use Result-returning variants instead.") - static ::arrow::Status Open( + static ::arrow::Status open( const ::arrow::Schema& schema, MemoryPool* pool, std::shared_ptr<::arrow::io::OutputStream> sink, @@ -89,12 +89,12 @@ class PARQUET_EXPORT FileWriter { std::unique_ptr* writer); ARROW_DEPRECATED( "Deprecated in 11.0.0. Use Result-returning variants instead.") - static ::arrow::Status Open( + static ::arrow::Status open( const ::arrow::Schema& schema, MemoryPool* pool, std::shared_ptr<::arrow::io::OutputStream> sink, std::shared_ptr properties, - std::shared_ptr arrow_properties, + std::shared_ptr arrowProperties, std::unique_ptr* writer); /// Return the Arrow schema to be written to. @@ -103,74 +103,74 @@ class PARQUET_EXPORT FileWriter { /// \brief Write a Table to Parquet. /// /// \param table Arrow table to write. - /// \param chunk_size maximum number of rows to write per row group. - virtual ::arrow::Status WriteTable( + /// \param chunk_size Maximum number of rows to write per row group. + virtual ::arrow::Status writeTable( const ::arrow::Table& table, - int64_t chunk_size = DEFAULT_MAX_ROW_GROUP_LENGTH) = 0; + int64_t chunkSize = DEFAULT_MAX_ROW_GROUP_LENGTH) = 0; /// \brief Start a new row group. /// /// Returns an error if not all columns have been written. /// - /// \param chunk_size the number of rows in the next row group. - virtual ::arrow::Status NewRowGroup(int64_t chunk_size) = 0; + /// \param chunk_size The number of rows in the next row group. + virtual ::arrow::Status newRowGroup(int64_t chunkSize) = 0; /// \brief Write ColumnChunk in row group using an array. - virtual ::arrow::Status WriteColumnChunk(const ::arrow::Array& data) = 0; + virtual ::arrow::Status writeColumnChunk(const ::arrow::Array& data) = 0; - /// \brief Write ColumnChunk in row group using slice of a ChunkedArray - virtual ::arrow::Status WriteColumnChunk( + /// \brief Write ColumnChunk in row group using slice of a ChunkedArray. + virtual ::arrow::Status writeColumnChunk( const std::shared_ptr<::arrow::ChunkedArray>& data, int64_t offset, int64_t size) = 0; - /// \brief Write ColumnChunk in a row group using a ChunkedArray - virtual ::arrow::Status WriteColumnChunk( + /// \brief Write ColumnChunk in a row group using a ChunkedArray. + virtual ::arrow::Status writeColumnChunk( const std::shared_ptr<::arrow::ChunkedArray>& data) = 0; /// \brief Start a new buffered row group. /// /// Returns an error if not all columns have been written. - virtual ::arrow::Status NewBufferedRowGroup() = 0; + virtual ::arrow::Status newBufferedRowGroup() = 0; /// \brief Write a RecordBatch into the buffered row group. /// - /// Multiple RecordBatches can be written into the same row group - /// through this method. + /// Multiple RecordBatches can be written into the same row group through this + /// method. /// - /// WriterProperties.max_row_group_length() is respected and a new + /// WriterProperties.maxRowGroupLength() is respected and a new /// row group will be created if the current row group exceeds the /// limit. /// - /// Batches get flushed to the output stream once NewBufferedRowGroup() - /// or Close() is called. + /// Batches get flushed to the output stream once newBufferedRowGroup() + /// or close() is called. /// /// WARNING: If you are writing multiple files in parallel in the same - /// executor, deadlock may occur if ArrowWriterProperties::use_threads - /// is set to true to write columns in parallel. Please disable use_threads + /// executor, deadlock may occur if ArrowWriterProperties::useThreads + /// is set to true to write columns in parallel. Please disable useThreads /// option in this case. - virtual ::arrow::Status WriteRecordBatch( + virtual ::arrow::Status writeRecordBatch( const ::arrow::RecordBatch& batch) = 0; /// \brief Write the footer and close the file. - virtual ::arrow::Status Close() = 0; + virtual ::arrow::Status close() = 0; virtual ~FileWriter(); - virtual MemoryPool* memory_pool() const = 0; - /// \brief Return the file metadata, only available after calling Close(). + virtual MemoryPool* memoryPool() const = 0; + /// \brief Return the file metadata, only available after calling close(). virtual const std::shared_ptr metadata() const = 0; }; -/// \brief Write Parquet file metadata only to indicated Arrow OutputStream +/// \brief Write Parquet file metadata only to indicated Arrow OutputStream. PARQUET_EXPORT -::arrow::Status WriteFileMetaData( - const FileMetaData& file_metadata, +::arrow::Status writeFileMetaData( + const FileMetaData& fileMetadata, ::arrow::io::OutputStream* sink); -/// \brief Write metadata-only Parquet file to indicated Arrow OutputStream +/// \brief Write metadata-only Parquet file to indicated Arrow OutputStream. PARQUET_EXPORT -::arrow::Status WriteMetaDataFile( - const FileMetaData& file_metadata, +::arrow::Status writeMetaDataFile( + const FileMetaData& fileMetadata, ::arrow::io::OutputStream* sink); /// \brief Write a Table to Parquet. @@ -178,20 +178,20 @@ ::arrow::Status WriteMetaDataFile( /// This writes one table in a single shot. To write a Parquet file with /// multiple tables iteratively, see parquet::arrow::FileWriter. /// -/// \param table Table to write. -/// \param pool memory pool to use. -/// \param sink output stream to write Parquet data. -/// \param chunk_size maximum number of rows to write per row group. -/// \param properties general Parquet writer properties. +/// \param table Arrow table to write. +/// \param pool Memory pool to use. +/// \param sink Output stream to write Parquet data. +/// \param chunk_size Maximum number of rows to write per row group. +/// \param properties General Parquet writer properties. /// \param arrow_properties Arrow-specific writer properties. -::arrow::Status PARQUET_EXPORT WriteTable( +::arrow::Status PARQUET_EXPORT writeTable( const ::arrow::Table& table, MemoryPool* pool, std::shared_ptr<::arrow::io::OutputStream> sink, - int64_t chunk_size = DEFAULT_MAX_ROW_GROUP_LENGTH, - std::shared_ptr properties = default_writer_properties(), - std::shared_ptr arrow_properties = - default_arrow_writer_properties()); + int64_t chunkSize = DEFAULT_MAX_ROW_GROUP_LENGTH, + std::shared_ptr properties = defaultWriterProperties(), + std::shared_ptr arrowProperties = + defaultArrowWriterProperties()); } // namespace arrow } // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/tests/BloomFilter.cpp b/velox/dwio/parquet/writer/arrow/tests/BloomFilter.cpp index 866aa8ada61..b60ef1f10ae 100644 --- a/velox/dwio/parquet/writer/arrow/tests/BloomFilter.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/BloomFilter.cpp @@ -33,50 +33,49 @@ namespace facebook::velox::parquet::arrow { BlockSplitBloomFilter::BlockSplitBloomFilter(::arrow::MemoryPool* pool) : pool_(pool), - hash_strategy_(HashStrategy::XXHASH), + hashStrategy_(HashStrategy::XXHASH), algorithm_(Algorithm::BLOCK), - compression_strategy_(CompressionStrategy::UNCOMPRESSED) {} + compressionStrategy_(CompressionStrategy::UNCOMPRESSED) {} -void BlockSplitBloomFilter::Init(uint32_t num_bytes) { - if (num_bytes < kMinimumBloomFilterBytes) { - num_bytes = kMinimumBloomFilterBytes; +void BlockSplitBloomFilter::init(uint32_t numBytes) { + if (numBytes < kMinimumBloomFilterBytes) { + numBytes = kMinimumBloomFilterBytes; } // Get next power of 2 if it is not power of 2. - if ((num_bytes & (num_bytes - 1)) != 0) { - num_bytes = static_cast(::arrow::bit_util::NextPower2(num_bytes)); + if ((numBytes & (numBytes - 1)) != 0) { + numBytes = static_cast(::arrow::bit_util::NextPower2(numBytes)); } - if (num_bytes > kMaximumBloomFilterBytes) { - num_bytes = kMaximumBloomFilterBytes; + if (numBytes > kMaximumBloomFilterBytes) { + numBytes = kMaximumBloomFilterBytes; } - num_bytes_ = num_bytes; - PARQUET_ASSIGN_OR_THROW(data_, ::arrow::AllocateBuffer(num_bytes_, pool_)); - memset(data_->mutable_data(), 0, num_bytes_); + numBytes_ = numBytes; + PARQUET_ASSIGN_OR_THROW(data_, ::arrow::AllocateBuffer(numBytes_, pool_)); + memset(data_->mutable_data(), 0, numBytes_); this->hasher_ = std::make_unique(); } -void BlockSplitBloomFilter::Init(const uint8_t* bitset, uint32_t num_bytes) { +void BlockSplitBloomFilter::init(const uint8_t* bitset, uint32_t numBytes) { VELOX_DCHECK_NOT_NULL(bitset); - if (num_bytes < kMinimumBloomFilterBytes || - num_bytes > kMaximumBloomFilterBytes || - (num_bytes & (num_bytes - 1)) != 0) { + if (numBytes < kMinimumBloomFilterBytes || + numBytes > kMaximumBloomFilterBytes || (numBytes & (numBytes - 1)) != 0) { throw ParquetException("Given length of bitset is illegal"); } - num_bytes_ = num_bytes; - PARQUET_ASSIGN_OR_THROW(data_, ::arrow::AllocateBuffer(num_bytes_, pool_)); - memcpy(data_->mutable_data(), bitset, num_bytes_); + numBytes_ = numBytes; + PARQUET_ASSIGN_OR_THROW(data_, ::arrow::AllocateBuffer(numBytes_, pool_)); + memcpy(data_->mutable_data(), bitset, numBytes_); this->hasher_ = std::make_unique(); } static constexpr uint32_t kBloomFilterHeaderSizeGuess = 256; -static ::arrow::Status ValidateBloomFilterHeader( +static ::arrow::Status validateBloomFilterHeader( const facebook::velox::parquet::thrift::BloomFilterHeader& header) { if (!header.algorithm.__isset.BLOCK) { return ::arrow::Status::Invalid( @@ -106,71 +105,69 @@ static ::arrow::Status ValidateBloomFilterHeader( return ::arrow::Status::OK(); } -BlockSplitBloomFilter BlockSplitBloomFilter::Deserialize( +BlockSplitBloomFilter BlockSplitBloomFilter::deserialize( const ReaderProperties& properties, ArrowInputStream* input) { - // NOTE: we don't know the bloom filter header size upfront, and we can't rely - // on InputStream::Peek() which isn't always implemented. Therefore, we must - // first Read() with an upper bound estimate of the header size, then once we - // know the bloom filter data size, we can Read() the exact number of - // remaining data bytes. + // NOTE: we don't know the bloom filter header size upfront, and we can't + // rely. On InputStream::Peek() which isn't always implemented. Therefore, we + // must. First Read() with an upper bound estimate of the header size, then + // once we. Know the bloom filter data size, we can Read() the exact number + // of. Remaining data bytes. ThriftDeserializer deserializer(properties); facebook::velox::parquet::thrift::BloomFilterHeader header; - // Read and deserialize bloom filter header + // Read and deserialize bloom filter header. PARQUET_ASSIGN_OR_THROW( - auto header_buf, input->Read(kBloomFilterHeaderSizeGuess)); - // This gets used, then set by DeserializeThriftMsg - uint32_t header_size = static_cast(header_buf->size()); + auto headerBuf, input->Read(kBloomFilterHeaderSizeGuess)); + // This gets used, then set by DeserializeThriftMsg. + uint32_t headerSize = static_cast(headerBuf->size()); try { - deserializer.DeserializeMessage( - reinterpret_cast(header_buf->data()), - &header_size, + deserializer.deserializeMessage( + reinterpret_cast(headerBuf->data()), + &headerSize, &header); - VELOX_DCHECK_LE(header_size, header_buf->size()); + VELOX_DCHECK_LE(headerSize, headerBuf->size()); } catch (std::exception& e) { std::stringstream ss; ss << "Deserializing bloom filter header failed.\n" << e.what(); throw ParquetException(ss.str()); } - PARQUET_THROW_NOT_OK(ValidateBloomFilterHeader(header)); + PARQUET_THROW_NOT_OK(validateBloomFilterHeader(header)); - const int32_t bloom_filter_size = header.numBytes; - if (bloom_filter_size + header_size <= header_buf->size()) { - // The bloom filter data is entirely contained in the buffer we just read - // => just return it. - BlockSplitBloomFilter bloom_filter(properties.memory_pool()); - bloom_filter.Init(header_buf->data() + header_size, bloom_filter_size); - return bloom_filter; + const int32_t bloomFilterSize = header.numBytes; + if (bloomFilterSize + headerSize <= headerBuf->size()) { + // The bloom filter data is entirely contained in the buffer we just read. + // => Just return it. + BlockSplitBloomFilter bloomFilter(properties.memoryPool()); + bloomFilter.init(headerBuf->data() + headerSize, bloomFilterSize); + return bloomFilter; } - // We have read a part of the bloom filter already, copy it to the target - // buffer and read the remaining part from the InputStream. - auto buffer = AllocateBuffer(properties.memory_pool(), bloom_filter_size); + // We have read a part of the bloom filter already, copy it to the target. + // Buffer and read the remaining part from the InputStream. + auto buffer = allocateBuffer(properties.memoryPool(), bloomFilterSize); - const auto bloom_filter_bytes_in_header = header_buf->size() - header_size; - if (bloom_filter_bytes_in_header > 0) { + const auto bloomFilterBytesInHeader = headerBuf->size() - headerSize; + if (bloomFilterBytesInHeader > 0) { std::memcpy( buffer->mutable_data(), - header_buf->data() + header_size, - bloom_filter_bytes_in_header); + headerBuf->data() + headerSize, + bloomFilterBytesInHeader); } - const auto required_read_size = - bloom_filter_size - bloom_filter_bytes_in_header; + const auto requiredReadSize = bloomFilterSize - bloomFilterBytesInHeader; PARQUET_ASSIGN_OR_THROW( - auto read_size, + auto readSize, input->Read( - required_read_size, - buffer->mutable_data() + bloom_filter_bytes_in_header)); - if (ARROW_PREDICT_FALSE(read_size < required_read_size)) { + requiredReadSize, buffer->mutable_data() + bloomFilterBytesInHeader)); + if (ARROW_PREDICT_FALSE(readSize < requiredReadSize)) { throw ParquetException("Bloom Filter read failed: not enough data"); } - BlockSplitBloomFilter bloom_filter(properties.memory_pool()); - bloom_filter.Init(buffer->data(), bloom_filter_size); - return bloom_filter; + BlockSplitBloomFilter bloomFilter(properties.memoryPool()); + bloomFilter.init(buffer->data(), bloomFilterSize); + return bloomFilter; } -void BlockSplitBloomFilter::WriteTo(ArrowOutputStream* sink) const { +void BlockSplitBloomFilter::writeTo(ArrowOutputStream* sink) const { VELOX_DCHECK_NOT_NULL(sink); facebook::velox::parquet::thrift::BloomFilterHeader header; @@ -180,29 +177,29 @@ void BlockSplitBloomFilter::WriteTo(ArrowOutputStream* sink) const { } header.algorithm.__set_BLOCK( facebook::velox::parquet::thrift::SplitBlockAlgorithm()); - if (ARROW_PREDICT_FALSE(hash_strategy_ != HashStrategy::XXHASH)) { + if (ARROW_PREDICT_FALSE(hashStrategy_ != HashStrategy::XXHASH)) { throw ParquetException( "BloomFilter does not support Hash other than XXHASH"); } header.hash.__set_XXHASH(facebook::velox::parquet::thrift::XxHash()); if (ARROW_PREDICT_FALSE( - compression_strategy_ != CompressionStrategy::UNCOMPRESSED)) { + compressionStrategy_ != CompressionStrategy::UNCOMPRESSED)) { throw ParquetException( "BloomFilter does not support Compression other than UNCOMPRESSED"); } header.compression.__set_UNCOMPRESSED( facebook::velox::parquet::thrift::Uncompressed()); - header.__set_numBytes(num_bytes_); + header.__set_numBytes(numBytes_); ThriftSerializer serializer; - serializer.Serialize(&header, sink); + serializer.serialize(&header, sink); - PARQUET_THROW_NOT_OK(sink->Write(data_->data(), num_bytes_)); + PARQUET_THROW_NOT_OK(sink->Write(data_->data(), numBytes_)); } -bool BlockSplitBloomFilter::FindHash(uint64_t hash) const { - const uint32_t bucket_index = static_cast( - ((hash >> 32) * (num_bytes_ / kBytesPerFilterBlock)) >> 32); +bool BlockSplitBloomFilter::findHash(uint64_t hash) const { + const uint32_t bucketIndex = static_cast( + ((hash >> 32) * (numBytes_ / kBytesPerFilterBlock)) >> 32); const uint32_t key = static_cast(hash); const uint32_t* bitset32 = reinterpret_cast(data_->data()); @@ -210,35 +207,35 @@ bool BlockSplitBloomFilter::FindHash(uint64_t hash) const { // Calculate mask for key in the given bitset. const uint32_t mask = UINT32_C(0x1) << ((key * SALT[i]) >> 27); if (ARROW_PREDICT_FALSE( - 0 == (bitset32[kBitsSetPerBlock * bucket_index + i] & mask))) { + 0 == (bitset32[kBitsSetPerBlock * bucketIndex + i] & mask))) { return false; } } return true; } -void BlockSplitBloomFilter::InsertHashImpl(uint64_t hash) { - const uint32_t bucket_index = static_cast( - ((hash >> 32) * (num_bytes_ / kBytesPerFilterBlock)) >> 32); +void BlockSplitBloomFilter::insertHashImpl(uint64_t hash) { + const uint32_t bucketIndex = static_cast( + ((hash >> 32) * (numBytes_ / kBytesPerFilterBlock)) >> 32); const uint32_t key = static_cast(hash); uint32_t* bitset32 = reinterpret_cast(data_->mutable_data()); for (int i = 0; i < kBitsSetPerBlock; i++) { // Calculate mask for key in the given bitset. const uint32_t mask = UINT32_C(0x1) << ((key * SALT[i]) >> 27); - bitset32[bucket_index * kBitsSetPerBlock + i] |= mask; + bitset32[bucketIndex * kBitsSetPerBlock + i] |= mask; } } -void BlockSplitBloomFilter::InsertHash(uint64_t hash) { - InsertHashImpl(hash); +void BlockSplitBloomFilter::insertHash(uint64_t hash) { + insertHashImpl(hash); } -void BlockSplitBloomFilter::InsertHashes( +void BlockSplitBloomFilter::insertHashes( const uint64_t* hashes, - int num_values) { - for (int i = 0; i < num_values; ++i) { - InsertHashImpl(hashes[i]); + int numValues) { + for (int i = 0; i < numValues; ++i) { + insertHashImpl(hashes[i]); } } diff --git a/velox/dwio/parquet/writer/arrow/tests/BloomFilter.h b/velox/dwio/parquet/writer/arrow/tests/BloomFilter.h index 4637921fefd..2f36cc7ae01 100644 --- a/velox/dwio/parquet/writer/arrow/tests/BloomFilter.h +++ b/velox/dwio/parquet/writer/arrow/tests/BloomFilter.h @@ -32,153 +32,154 @@ namespace facebook::velox::parquet::arrow { -// A Bloom filter is a compact structure to indicate whether an item is not in a -// set or probably in a set. The Bloom filter usually consists of a bit set that -// represents a set of elements, a hash strategy and a Bloom filter algorithm. +// A Bloom filter is a compact structure to indicate whether an item is not in +// a. Set or probably in a set. The Bloom filter usually consists of a bit set +// that. Represents a set of elements, a hash strategy and a Bloom filter +// algorithm. class PARQUET_EXPORT BloomFilter { public: - // Maximum Bloom filter size, it sets to HDFS default block size 128MB + // Maximum Bloom filter size, it sets to HDFS default block size 128MB. // This value will be reconsidered when implementing Bloom filter producer. static constexpr uint32_t kMaximumBloomFilterBytes = 128 * 1024 * 1024; /// Determine whether an element exist in set or not. /// /// @param hash the element to contain. - /// @return false if value is definitely not in set, and true means PROBABLY - /// in set. - virtual bool FindHash(uint64_t hash) const = 0; + /// @return false if value is definitely not in set, and true means PROBABLY. + /// In set. + virtual bool findHash(uint64_t hash) const = 0; /// Insert element to set represented by Bloom filter bitset. /// @param hash the hash of value to insert into Bloom filter. - virtual void InsertHash(uint64_t hash) = 0; + virtual void insertHash(uint64_t hash) = 0; /// Insert elements to set represented by Bloom filter bitset. /// @param hashes the hash values to insert into Bloom filter. /// @param num_values the number of hash values to insert. - virtual void InsertHashes(const uint64_t* hashes, int num_values) = 0; + virtual void insertHashes(const uint64_t* hashes, int numValues) = 0; - /// Write this Bloom filter to an output stream. A Bloom filter structure - /// should include bitset length, hash strategy, algorithm, and bitset. + /// Write this Bloom filter to an output stream. A Bloom filter structure. + /// Should include bitset length, hash strategy, algorithm, and bitset. /// - /// @param sink the output stream to write - virtual void WriteTo(ArrowOutputStream* sink) const = 0; + /// @param sink the output stream to write. + virtual void writeTo(ArrowOutputStream* sink) const = 0; - /// Get the number of bytes of bitset - virtual uint32_t GetBitsetSize() const = 0; + /// Get the number of bytes of bitset. + virtual uint32_t getBitsetSize() const = 0; /// Compute hash for 32 bits value by using its plain encoding result. /// /// @param value the value to hash. /// @return hash result. - virtual uint64_t Hash(int32_t value) const = 0; + virtual uint64_t hash(int32_t value) const = 0; /// Compute hash for 64 bits value by using its plain encoding result. /// /// @param value the value to hash. /// @return hash result. - virtual uint64_t Hash(int64_t value) const = 0; + virtual uint64_t hash(int64_t value) const = 0; /// Compute hash for float value by using its plain encoding result. /// /// @param value the value to hash. /// @return hash result. - virtual uint64_t Hash(float value) const = 0; + virtual uint64_t hash(float value) const = 0; /// Compute hash for double value by using its plain encoding result. /// /// @param value the value to hash. /// @return hash result. - virtual uint64_t Hash(double value) const = 0; + virtual uint64_t hash(double value) const = 0; /// Compute hash for Int96 value by using its plain encoding result. /// /// @param value the value to hash. /// @return hash result. - virtual uint64_t Hash(const Int96* value) const = 0; + virtual uint64_t hash(const Int96* value) const = 0; /// Compute hash for ByteArray value by using its plain encoding result. /// /// @param value the value to hash. /// @return hash result. - virtual uint64_t Hash(const ByteArray* value) const = 0; + virtual uint64_t hash(const ByteArray* value) const = 0; - /// Compute hash for fixed byte array value by using its plain encoding - /// result. + /// Compute hash for fixed byte array value by using its plain encoding. + /// Result. /// /// @param value the value address. /// @param len the value length. /// @return hash result. - virtual uint64_t Hash(const FLBA* value, uint32_t len) const = 0; + virtual uint64_t hash(const FLBA* value, uint32_t len) const = 0; - /// Batch compute hashes for 32 bits values by using its plain encoding - /// result. + /// Batch compute hashes for 32 bits values by using its plain encoding. + /// Result. /// /// @param values values a pointer to the values to hash. /// @param num_values the number of values to hash. - /// @param hashes a pointer to the output hash values, its length should be - /// equal to num_values. - virtual void Hashes(const int32_t* values, int num_values, uint64_t* hashes) + /// @param hashes a pointer to the output hash values, its length should be. + /// Equal to num_values. + virtual void hashes(const int32_t* values, int numValues, uint64_t* hashes) const = 0; - /// Batch compute hashes for 64 bits values by using its plain encoding - /// result. + /// Batch compute hashes for 64 bits values by using its plain encoding. + /// Result. /// /// @param values values a pointer to the values to hash. /// @param num_values the number of values to hash. - /// @param hashes a pointer to the output hash values, its length should be - /// equal to num_values. - virtual void Hashes(const int64_t* values, int num_values, uint64_t* hashes) + /// @param hashes a pointer to the output hash values, its length should be. + /// Equal to num_values. + virtual void hashes(const int64_t* values, int numValues, uint64_t* hashes) const = 0; /// Batch compute hashes for float values by using its plain encoding result. /// /// @param values values a pointer to the values to hash. /// @param num_values the number of values to hash. - /// @param hashes a pointer to the output hash values, its length should be - /// equal to num_values. - virtual void Hashes(const float* values, int num_values, uint64_t* hashes) + /// @param hashes a pointer to the output hash values, its length should be. + /// Equal to num_values. + virtual void hashes(const float* values, int numValues, uint64_t* hashes) const = 0; /// Batch compute hashes for double values by using its plain encoding result. /// /// @param values values a pointer to the values to hash. /// @param num_values the number of values to hash. - /// @param hashes a pointer to the output hash values, its length should be - /// equal to num_values. - virtual void Hashes(const double* values, int num_values, uint64_t* hashes) + /// @param hashes a pointer to the output hash values, its length should be. + /// Equal to num_values. + virtual void hashes(const double* values, int numValues, uint64_t* hashes) const = 0; /// Batch compute hashes for Int96 values by using its plain encoding result. /// /// @param values values a pointer to the values to hash. /// @param num_values the number of values to hash. - /// @param hashes a pointer to the output hash values, its length should be - /// equal to num_values. - virtual void Hashes(const Int96* values, int num_values, uint64_t* hashes) + /// @param hashes a pointer to the output hash values, its length should be. + /// Equal to num_values. + virtual void hashes(const Int96* values, int numValues, uint64_t* hashes) const = 0; - /// Batch compute hashes for ByteArray values by using its plain encoding - /// result. + /// Batch compute hashes for ByteArray values by using its plain encoding. + /// Result. /// /// @param values values a pointer to the values to hash. /// @param num_values the number of values to hash. - /// @param hashes a pointer to the output hash values, its length should be - /// equal to num_values. - virtual void Hashes(const ByteArray* values, int num_values, uint64_t* hashes) + /// @param hashes a pointer to the output hash values, its length should be. + /// Equal to num_values. + virtual void hashes(const ByteArray* values, int numValues, uint64_t* hashes) const = 0; - /// Batch compute hashes for fixed byte array values by using its plain - /// encoding result. + /// Batch compute hashes for fixed byte array values by using its plain. + /// Encoding result. /// /// @param values values a pointer to the values to hash. /// @param type_len the value length. /// @param num_values the number of values to hash. - /// @param hashes a pointer to the output hash values, its length should be - /// equal to num_values. - virtual void Hashes( + /// @param hashes a pointer to the output hash values, its length should be. + /// Equal to num_values. + virtual void hashes( const FLBA* values, - uint32_t type_len, - int num_values, + uint32_t typeLen, + int numValues, uint64_t* hashes) const = 0; virtual ~BloomFilter() = default; @@ -193,13 +194,13 @@ class PARQUET_EXPORT BloomFilter { enum class CompressionStrategy : uint32_t { UNCOMPRESSED = 0 }; }; -/// The BlockSplitBloomFilter is implemented using block-based Bloom filters -/// from Putze et al.'s "Cache-,Hash- and Space-Efficient Bloom filters". The -/// basic idea is to hash the item to a tiny Bloom filter which size fit a -/// single cache line or smaller. +/// The BlockSplitBloomFilter is implemented using block-based Bloom filters. +/// From Putze et al.'s "Cache-,Hash- and Space-Efficient Bloom filters". The. +/// Basic idea is to hash the item to a tiny Bloom filter which size fit a. +/// Single cache line or smaller. /// -/// This implementation sets 8 bits in each tiny Bloom filter. Each tiny Bloom -/// filter is 32 bytes to take advantage of 32-byte SIMD instructions. +/// This implementation sets 8 bits in each tiny Bloom filter. Each tiny Bloom. +/// Filter is 32 bytes to take advantage of 32-byte SIMD instructions. class PARQUET_EXPORT BlockSplitBloomFilter : public BloomFilter { public: /// The constructor of BlockSplitBloomFilter. It uses XXH64 as hash function. @@ -208,170 +209,171 @@ class PARQUET_EXPORT BlockSplitBloomFilter : public BloomFilter { explicit BlockSplitBloomFilter( ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); - /// Initialize the BlockSplitBloomFilter. The range of num_bytes should be - /// within [kMinimumBloomFilterBytes, kMaximumBloomFilterBytes], it will be - /// rounded up/down to lower/upper bound if num_bytes is out of range and also - /// will be rounded up to a power of 2. + /// Initialize the BlockSplitBloomFilter. The range of num_bytes should be. + /// Within [kMinimumBloomFilterBytes, kMaximumBloomFilterBytes], it will be. + /// Rounded up/down to lower/upper bound if num_bytes is out of range and + /// also. Will be rounded up to a power of 2. /// /// @param num_bytes The number of bytes to store Bloom filter bitset. - void Init(uint32_t num_bytes); + void init(uint32_t numBytes); - /// Initialize the BlockSplitBloomFilter. It copies the bitset as underlying - /// bitset because the given bitset may not satisfy the 32-byte alignment - /// requirement which may lead to segfault when performing SIMD instructions. - /// It is the caller's responsibility to free the bitset passed in. This is - /// used when reconstructing a Bloom filter from a parquet file. + /// Initialize the BlockSplitBloomFilter. It copies the bitset as underlying. + /// Bitset because the given bitset may not satisfy the 32-byte alignment. + /// Requirement which may lead to segfault when performing SIMD instructions. + /// It is the caller's responsibility to free the bitset passed in. This is. + /// Used when reconstructing a Bloom filter from a parquet file. /// /// @param bitset The given bitset to initialize the Bloom filter. /// @param num_bytes The number of bytes of given bitset. - void Init(const uint8_t* bitset, uint32_t num_bytes); + void init(const uint8_t* bitset, uint32_t numBytes); /// Minimum Bloom filter size, it sets to 32 bytes to fit a tiny Bloom filter. static constexpr uint32_t kMinimumBloomFilterBytes = 32; - /// Calculate optimal size according to the number of distinct values and - /// false positive probability. + /// Calculate optimal size according to the number of distinct values and. + /// False positive probability. /// /// @param ndv The number of distinct values. /// @param fpp The false positive probability. - /// @return it always return a value between kMinimumBloomFilterBytes and - /// kMaximumBloomFilterBytes, and the return value is always a power of 2 - static uint32_t OptimalNumOfBytes(uint32_t ndv, double fpp) { - uint32_t optimal_num_of_bits = OptimalNumOfBits(ndv, fpp); - VELOX_DCHECK(::arrow::bit_util::IsMultipleOf8(optimal_num_of_bits)); - return optimal_num_of_bits >> 3; + /// @return it always return a value between kMinimumBloomFilterBytes and. + /// KMaximumBloomFilterBytes, and the return value is always a power of 2. + static uint32_t optimalNumOfBytes(uint32_t ndv, double fpp) { + uint32_t optimalNumBits = optimalNumOfBits(ndv, fpp); + VELOX_DCHECK(::arrow::bit_util::IsMultipleOf8(optimalNumBits)); + return optimalNumBits >> 3; } - /// Calculate optimal size according to the number of distinct values and - /// false positive probability. + /// Calculate optimal size according to the number of distinct values and. + /// False positive probability. /// /// @param ndv The number of distinct values. /// @param fpp The false positive probability. - /// @return it always return a value between kMinimumBloomFilterBytes * 8 and - /// kMaximumBloomFilterBytes * 8, and the return value is always a power of 16 - static uint32_t OptimalNumOfBits(uint32_t ndv, double fpp) { + /// @return it always return a value between kMinimumBloomFilterBytes * 8 and. + /// KMaximumBloomFilterBytes * 8, and the return value is always a power + /// of 16. + static uint32_t optimalNumOfBits(uint32_t ndv, double fpp) { VELOX_DCHECK(fpp > 0.0 && fpp < 1.0); const double m = -8.0 * ndv / log(1 - pow(fpp, 1.0 / 8)); - uint32_t num_bits; + uint32_t numBits; // Handle overflow. if (m < 0 || m > kMaximumBloomFilterBytes << 3) { - num_bits = static_cast(kMaximumBloomFilterBytes << 3); + numBits = static_cast(kMaximumBloomFilterBytes << 3); } else { - num_bits = static_cast(m); + numBits = static_cast(m); } - // Round up to lower bound - if (num_bits < kMinimumBloomFilterBytes << 3) { - num_bits = kMinimumBloomFilterBytes << 3; + // Round up to lower bound. + if (numBits < kMinimumBloomFilterBytes << 3) { + numBits = kMinimumBloomFilterBytes << 3; } // Get next power of 2 if bits is not power of 2. - if ((num_bits & (num_bits - 1)) != 0) { - num_bits = static_cast(::arrow::bit_util::NextPower2(num_bits)); + if ((numBits & (numBits - 1)) != 0) { + numBits = static_cast(::arrow::bit_util::NextPower2(numBits)); } - // Round down to upper bound - if (num_bits > kMaximumBloomFilterBytes << 3) { - num_bits = kMaximumBloomFilterBytes << 3; + // Round down to upper bound. + if (numBits > kMaximumBloomFilterBytes << 3) { + numBits = kMaximumBloomFilterBytes << 3; } - return num_bits; + return numBits; } - bool FindHash(uint64_t hash) const override; - void InsertHash(uint64_t hash) override; - void InsertHashes(const uint64_t* hashes, int num_values) override; - void WriteTo(ArrowOutputStream* sink) const override; - uint32_t GetBitsetSize() const override { - return num_bytes_; + bool findHash(uint64_t hash) const override; + void insertHash(uint64_t hash) override; + void insertHashes(const uint64_t* hashes, int numValues) override; + void writeTo(ArrowOutputStream* sink) const override; + uint32_t getBitsetSize() const override { + return numBytes_; } - uint64_t Hash(int32_t value) const override { - return hasher_->Hash(value); + uint64_t hash(int32_t value) const override { + return hasher_->hash(value); } - uint64_t Hash(int64_t value) const override { - return hasher_->Hash(value); + uint64_t hash(int64_t value) const override { + return hasher_->hash(value); } - uint64_t Hash(float value) const override { - return hasher_->Hash(value); + uint64_t hash(float value) const override { + return hasher_->hash(value); } - uint64_t Hash(double value) const override { - return hasher_->Hash(value); + uint64_t hash(double value) const override { + return hasher_->hash(value); } - uint64_t Hash(const Int96* value) const override { - return hasher_->Hash(value); + uint64_t hash(const Int96* value) const override { + return hasher_->hash(value); } - uint64_t Hash(const ByteArray* value) const override { - return hasher_->Hash(value); + uint64_t hash(const ByteArray* value) const override { + return hasher_->hash(value); } - uint64_t Hash(const FLBA* value, uint32_t len) const override { - return hasher_->Hash(value, len); + uint64_t hash(const FLBA* value, uint32_t len) const override { + return hasher_->hash(value, len); } - void Hashes(const int32_t* values, int num_values, uint64_t* hashes) + void hashes(const int32_t* values, int numValues, uint64_t* hashes) const override { - hasher_->Hashes(values, num_values, hashes); + hasher_->hashes(values, numValues, hashes); } - void Hashes(const int64_t* values, int num_values, uint64_t* hashes) + void hashes(const int64_t* values, int numValues, uint64_t* hashes) const override { - hasher_->Hashes(values, num_values, hashes); + hasher_->hashes(values, numValues, hashes); } - void Hashes(const float* values, int num_values, uint64_t* hashes) + void hashes(const float* values, int numValues, uint64_t* hashes) const override { - hasher_->Hashes(values, num_values, hashes); + hasher_->hashes(values, numValues, hashes); } - void Hashes(const double* values, int num_values, uint64_t* hashes) + void hashes(const double* values, int numValues, uint64_t* hashes) const override { - hasher_->Hashes(values, num_values, hashes); + hasher_->hashes(values, numValues, hashes); } - void Hashes(const Int96* values, int num_values, uint64_t* hashes) + void hashes(const Int96* values, int numValues, uint64_t* hashes) const override { - hasher_->Hashes(values, num_values, hashes); + hasher_->hashes(values, numValues, hashes); } - void Hashes(const ByteArray* values, int num_values, uint64_t* hashes) + void hashes(const ByteArray* values, int numValues, uint64_t* hashes) const override { - hasher_->Hashes(values, num_values, hashes); + hasher_->hashes(values, numValues, hashes); } - void Hashes( + void hashes( const FLBA* values, - uint32_t type_len, - int num_values, + uint32_t typeLen, + int numValues, uint64_t* hashes) const override { - hasher_->Hashes(values, type_len, num_values, hashes); + hasher_->hashes(values, typeLen, numValues, hashes); } - uint64_t Hash(const int32_t* value) const { - return hasher_->Hash(*value); + uint64_t hash(const int32_t* value) const { + return hasher_->hash(*value); } - uint64_t Hash(const int64_t* value) const { - return hasher_->Hash(*value); + uint64_t hash(const int64_t* value) const { + return hasher_->hash(*value); } - uint64_t Hash(const float* value) const { - return hasher_->Hash(*value); + uint64_t hash(const float* value) const { + return hasher_->hash(*value); } - uint64_t Hash(const double* value) const { - return hasher_->Hash(*value); + uint64_t hash(const double* value) const { + return hasher_->hash(*value); } - /// Deserialize the Bloom filter from an input stream. It is used when - /// reconstructing a Bloom filter from a parquet filter. + /// Deserialize the Bloom filter from an input stream. It is used when. + /// Reconstructing a Bloom filter from a parquet filter. /// /// @param properties The parquet reader properties. - /// @param input_stream The input stream from which to construct the Bloom - /// filter. + /// @param input_stream The input stream from which to construct the Bloom. + /// Filter. /// @return The BlockSplitBloomFilter. - static BlockSplitBloomFilter Deserialize( + static BlockSplitBloomFilter deserialize( const ReaderProperties& properties, - ArrowInputStream* input_stream); + ArrowInputStream* inputStream); private: - inline void InsertHashImpl(uint64_t hash); + inline void insertHashImpl(uint64_t hash); // Bytes in a tiny Bloom filter block. static constexpr int kBytesPerFilterBlock = 32; - // The number of bits to be set in each tiny Bloom filter + // The number of bits to be set in each tiny Bloom filter. static constexpr int kBitsSetPerBlock = 8; // A mask structure used to set bits in each tiny Bloom filter. @@ -379,8 +381,8 @@ class PARQUET_EXPORT BlockSplitBloomFilter : public BloomFilter { uint32_t item[kBitsSetPerBlock]; }; - // The block-based algorithm needs eight odd SALT values to calculate eight - // indexes of bit to set, one bit in each 32-bit word. + // The block-based algorithm needs eight odd SALT values to calculate eight. + // Indexes of bit to set, one bit in each 32-bit word. static constexpr uint32_t SALT[kBitsSetPerBlock] = { 0x47b6137bU, 0x44974d91U, @@ -391,23 +393,23 @@ class PARQUET_EXPORT BlockSplitBloomFilter : public BloomFilter { 0x9efc4947U, 0x5c6bfb31U}; - // Memory pool to allocate aligned buffer for bitset + // Memory pool to allocate aligned buffer for bitset. ::arrow::MemoryPool* pool_; // The underlying buffer of bitset. std::shared_ptr data_; // The number of bytes of Bloom filter bitset. - uint32_t num_bytes_; + uint32_t numBytes_; // Hash strategy used in this Bloom filter. - HashStrategy hash_strategy_; + HashStrategy hashStrategy_; // Algorithm used in this Bloom filter. Algorithm algorithm_; // Compression used in this Bloom filter. - CompressionStrategy compression_strategy_; + CompressionStrategy compressionStrategy_; // The hash pointer points to actual hash class used. std::unique_ptr hasher_; diff --git a/velox/dwio/parquet/writer/arrow/tests/BloomFilterReader.cpp b/velox/dwio/parquet/writer/arrow/tests/BloomFilterReader.cpp index 152ebf32d77..88c15ed6cbd 100644 --- a/velox/dwio/parquet/writer/arrow/tests/BloomFilterReader.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/BloomFilterReader.cpp @@ -27,76 +27,76 @@ class RowGroupBloomFilterReaderImpl final : public RowGroupBloomFilterReader { public: RowGroupBloomFilterReaderImpl( std::shared_ptr<::arrow::io::RandomAccessFile> input, - std::shared_ptr row_group_metadata, + std::shared_ptr rowGroupMetadata, const ReaderProperties& properties) : input_(std::move(input)), - row_group_metadata_(std::move(row_group_metadata)), + rowGroupMetadata_(std::move(rowGroupMetadata)), properties_(properties) {} - std::unique_ptr GetColumnBloomFilter(int i) override; + std::unique_ptr getColumnBloomFilter(int i) override; private: /// The input stream that can perform random access read. std::shared_ptr<::arrow::io::RandomAccessFile> input_; /// The row group metadata to get column chunk metadata. - std::shared_ptr row_group_metadata_; + std::shared_ptr rowGroupMetadata_; /// Reader properties used to deserialize thrift object. const ReaderProperties& properties_; }; std::unique_ptr -RowGroupBloomFilterReaderImpl::GetColumnBloomFilter(int i) { - if (i < 0 || i >= row_group_metadata_->num_columns()) { +RowGroupBloomFilterReaderImpl::getColumnBloomFilter(int i) { + if (i < 0 || i >= rowGroupMetadata_->numColumns()) { throw ParquetException("Invalid column index at column ordinal ", i); } - auto col_chunk = row_group_metadata_->ColumnChunk(i); - std::unique_ptr crypto_metadata = - col_chunk->crypto_metadata(); - if (crypto_metadata != nullptr) { + auto colChunk = rowGroupMetadata_->columnChunk(i); + std::unique_ptr cryptoMetadata = + colChunk->cryptoMetadata(); + if (cryptoMetadata != nullptr) { ParquetException::NYI("Cannot read encrypted bloom filter yet"); } - auto bloom_filter_offset = col_chunk->bloom_filter_offset(); - if (!bloom_filter_offset.has_value()) { + auto bloomFilterOffset = colChunk->bloomFilterOffset(); + if (!bloomFilterOffset.has_value()) { return nullptr; } - PARQUET_ASSIGN_OR_THROW(auto file_size, input_->GetSize()); - if (file_size <= *bloom_filter_offset) { + PARQUET_ASSIGN_OR_THROW(auto fileSize, input_->GetSize()); + if (fileSize <= *bloomFilterOffset) { throw ParquetException("file size less or equal than bloom offset"); } auto stream = ::arrow::io::RandomAccessFile::GetStream( - input_, *bloom_filter_offset, file_size - *bloom_filter_offset); - auto bloom_filter = - BlockSplitBloomFilter::Deserialize(properties_, stream->get()); - return std::make_unique(std::move(bloom_filter)); + input_, *bloomFilterOffset, fileSize - *bloomFilterOffset); + auto bloomFilter = + BlockSplitBloomFilter::deserialize(properties_, stream->get()); + return std::make_unique(std::move(bloomFilter)); } class BloomFilterReaderImpl final : public BloomFilterReader { public: BloomFilterReaderImpl( std::shared_ptr<::arrow::io::RandomAccessFile> input, - std::shared_ptr file_metadata, + std::shared_ptr fileMetadata, const ReaderProperties& properties, - std::shared_ptr file_decryptor) + std::shared_ptr fileDecryptor) : input_(std::move(input)), - file_metadata_(std::move(file_metadata)), + fileMetadata_(std::move(fileMetadata)), properties_(properties) { - if (file_decryptor != nullptr) { + if (fileDecryptor != nullptr) { ParquetException::NYI("BloomFilter decryption is not yet supported"); } } - std::shared_ptr RowGroup(int i) { - if (i < 0 || i >= file_metadata_->num_row_groups()) { + std::shared_ptr rowGroup(int i) { + if (i < 0 || i >= fileMetadata_->numRowGroups()) { throw ParquetException("Invalid row group ordinal: ", i); } - auto row_group_metadata = file_metadata_->RowGroup(i); + auto rowGroupMetadata = fileMetadata_->rowGroup(i); return std::make_shared( - input_, std::move(row_group_metadata), properties_); + input_, std::move(rowGroupMetadata), properties_); } private: @@ -104,19 +104,19 @@ class BloomFilterReaderImpl final : public BloomFilterReader { std::shared_ptr<::arrow::io::RandomAccessFile> input_; /// The file metadata to get row group metadata. - std::shared_ptr file_metadata_; + std::shared_ptr fileMetadata_; /// Reader properties used to deserialize thrift object. const ReaderProperties& properties_; }; -std::unique_ptr BloomFilterReader::Make( +std::unique_ptr BloomFilterReader::make( std::shared_ptr<::arrow::io::RandomAccessFile> input, - std::shared_ptr file_metadata, + std::shared_ptr fileMetadata, const ReaderProperties& properties, - std::shared_ptr file_decryptor) { + std::shared_ptr fileDecryptor) { return std::make_unique( - std::move(input), file_metadata, properties, std::move(file_decryptor)); + std::move(input), fileMetadata, properties, std::move(fileDecryptor)); } } // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/tests/BloomFilterReader.h b/velox/dwio/parquet/writer/arrow/tests/BloomFilterReader.h index ac32d66a688..c6089cc9f04 100644 --- a/velox/dwio/parquet/writer/arrow/tests/BloomFilterReader.h +++ b/velox/dwio/parquet/writer/arrow/tests/BloomFilterReader.h @@ -35,9 +35,9 @@ class PARQUET_EXPORT RowGroupBloomFilterReader { /// /// \param[in] i column ordinal of the column chunk. /// \returns bloom filter of the column or nullptr if it does not exist. - /// \throws ParquetException if the index is out of bound, or read bloom - /// filter failed. - virtual std::unique_ptr GetColumnBloomFilter(int i) = 0; + /// \throws ParquetException if the index is out of bound, or read bloom. + /// Filter failed. + virtual std::unique_ptr getColumnBloomFilter(int i) = 0; }; /// \brief Interface for reading the bloom filter for a Parquet file. @@ -47,25 +47,25 @@ class PARQUET_EXPORT BloomFilterReader { /// \brief Create a BloomFilterReader instance. /// \returns a BloomFilterReader instance. - /// WARNING: The returned BloomFilterReader references to all the input - /// parameters, so it must not outlive all of the input parameters. Usually - /// these input parameters come from the same ParquetFileReader object, so it - /// must not outlive the reader that creates this BloomFilterReader. - static std::unique_ptr Make( + /// WARNING: The returned BloomFilterReader references to all the input. + /// Parameters, so it must not outlive all of the input parameters. Usually. + /// These input parameters come from the same ParquetFileReader object, so it. + /// Must not outlive the reader that creates this BloomFilterReader. + static std::unique_ptr make( std::shared_ptr<::arrow::io::RandomAccessFile> input, - std::shared_ptr file_metadata, + std::shared_ptr fileMetadata, const ReaderProperties& properties, - std::shared_ptr file_decryptor = NULLPTR); + std::shared_ptr fileDecryptor = NULLPTR); /// \brief Get the bloom filter reader of a specific row group. /// \param[in] i row group ordinal to get bloom filter reader. - /// \returns RowGroupBloomFilterReader of the specified row group. A nullptr - /// may or may - /// not be returned if the bloom filter for the row group is - /// unavailable. It is the caller's responsibility to check the - /// return value of follow-up calls to the RowGroupBloomFilterReader. + /// \returns RowGroupBloomFilterReader of the specified row group. A nullptr. + /// May or may. + /// Not be returned if the bloom filter for the row group is. + /// Unavailable. It is the caller's responsibility to check the. + /// Return value of follow-up calls to the RowGroupBloomFilterReader. /// \throws ParquetException if the index is out of bound. - virtual std::shared_ptr RowGroup(int i) = 0; + virtual std::shared_ptr rowGroup(int i) = 0; }; } // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/tests/BloomFilterTest.cpp b/velox/dwio/parquet/writer/arrow/tests/BloomFilterTest.cpp index 644fda5a678..158023f24c5 100644 --- a/velox/dwio/parquet/writer/arrow/tests/BloomFilterTest.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/BloomFilterTest.cpp @@ -43,21 +43,21 @@ namespace facebook::velox::parquet::arrow { namespace test { TEST(ConstructorTest, TestBloomFilter) { - BlockSplitBloomFilter bloom_filter; - EXPECT_NO_THROW(bloom_filter.Init(1000)); + BlockSplitBloomFilter BloomFilter; + EXPECT_NO_THROW(BloomFilter.init(1000)); - // It throws because the length cannot be zero + // It throws because the length cannot be zero. std::unique_ptr bitset1(new uint8_t[1024]()); - EXPECT_THROW(bloom_filter.Init(bitset1.get(), 0), ParquetException); + EXPECT_THROW(BloomFilter.init(bitset1.get(), 0), ParquetException); - // It throws because the number of bytes of Bloom filter bitset must be a - // power of 2. + // It throws because the number of bytes of Bloom filter bitset must be a. + // Power of 2. std::unique_ptr bitset2(new uint8_t[1024]()); - EXPECT_THROW(bloom_filter.Init(bitset2.get(), 1023), ParquetException); + EXPECT_THROW(BloomFilter.init(bitset2.get(), 1023), ParquetException); } -// The BasicTest is used to test basic operations including InsertHash, FindHash -// and serializing and de-serializing. +// The BasicTest is used to test basic operations including InsertHash, +// FindHash. And serializing and de-serializing. TEST(BasicTest, TestBloomFilter) { const std::vector kBloomFilterSizes = { 32, 64, 128, 256, 512, 1024, 2048}; @@ -68,72 +68,72 @@ TEST(BasicTest, TestBloomFilter) { const std::vector kNegativeIntLookups = { 0, 11, 12, 13, -2, -3, 43, 1 << 27, 1 << 28}; - for (const auto bloom_filter_bytes : kBloomFilterSizes) { - BlockSplitBloomFilter bloom_filter; - bloom_filter.Init(bloom_filter_bytes); + for (const auto bloomFilterBytes : kBloomFilterSizes) { + BlockSplitBloomFilter BloomFilter; + BloomFilter.init(bloomFilterBytes); - // Empty bloom filter deterministically returns false + // Empty bloom filter deterministically returns false. for (const auto v : kIntInserts) { - EXPECT_FALSE(bloom_filter.FindHash(bloom_filter.Hash(v))); + EXPECT_FALSE(BloomFilter.findHash(BloomFilter.hash(v))); } for (const auto v : kFloatInserts) { - EXPECT_FALSE(bloom_filter.FindHash(bloom_filter.Hash(v))); + EXPECT_FALSE(BloomFilter.findHash(BloomFilter.hash(v))); } - // Insert all values + // Insert all values. for (const auto v : kIntInserts) { - bloom_filter.InsertHash(bloom_filter.Hash(v)); + BloomFilter.insertHash(BloomFilter.hash(v)); } for (const auto v : kFloatInserts) { - bloom_filter.InsertHash(bloom_filter.Hash(v)); + BloomFilter.insertHash(BloomFilter.hash(v)); } - // They should always lookup successfully + // They should always lookup successfully. for (const auto v : kIntInserts) { - EXPECT_TRUE(bloom_filter.FindHash(bloom_filter.Hash(v))); + EXPECT_TRUE(BloomFilter.findHash(BloomFilter.hash(v))); } for (const auto v : kFloatInserts) { - EXPECT_TRUE(bloom_filter.FindHash(bloom_filter.Hash(v))); + EXPECT_TRUE(BloomFilter.findHash(BloomFilter.hash(v))); } - // Values not inserted in the filter should only rarely lookup successfully - int false_positives = 0; + // Values not inserted in the filter should only rarely lookup successfully. + int falsePositives = 0; for (const auto v : kNegativeIntLookups) { - false_positives += bloom_filter.FindHash(bloom_filter.Hash(v)); + falsePositives += BloomFilter.findHash(BloomFilter.hash(v)); } // (this is a crude check, see FPPTest below for a more rigorous formula) - EXPECT_LE(false_positives, 2); + EXPECT_LE(falsePositives, 2); - // Serialize Bloom filter to memory output stream - auto sink = CreateOutputStream(); - bloom_filter.WriteTo(sink.get()); + // Serialize Bloom filter to memory output stream. + auto sink = createOutputStream(); + BloomFilter.writeTo(sink.get()); - // Deserialize Bloom filter from memory + // Deserialize Bloom filter from memory. ASSERT_OK_AND_ASSIGN(auto buffer, sink->Finish()); ::arrow::io::BufferReader source(buffer); - ReaderProperties reader_properties; - BlockSplitBloomFilter de_bloom = - BlockSplitBloomFilter::Deserialize(reader_properties, &source); + ReaderProperties ReaderProperties; + BlockSplitBloomFilter deBloom = + BlockSplitBloomFilter::deserialize(ReaderProperties, &source); - // Lookup previously inserted values + // Lookup previously inserted values. for (const auto v : kIntInserts) { - EXPECT_TRUE(de_bloom.FindHash(de_bloom.Hash(v))); + EXPECT_TRUE(deBloom.findHash(deBloom.hash(v))); } for (const auto v : kFloatInserts) { - EXPECT_TRUE(de_bloom.FindHash(de_bloom.Hash(v))); + EXPECT_TRUE(deBloom.findHash(deBloom.hash(v))); } - false_positives = 0; + falsePositives = 0; for (const auto v : kNegativeIntLookups) { - false_positives += de_bloom.FindHash(de_bloom.Hash(v)); + falsePositives += deBloom.findHash(deBloom.hash(v)); } - EXPECT_LE(false_positives, 2); + EXPECT_LE(falsePositives, 2); } } // Helper function to generate random string. -std::string GetRandomString(uint32_t length) { - // Character set used to generate random string +std::string getRandomString(uint32_t length) { + // Character set used to generate random string. const std::string charset = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; @@ -152,51 +152,50 @@ TEST(FPPTest, TestBloomFilter) { // It counts the number of times FindHash returns true. int exist = 0; - // Total count of elements that will be used + // Total count of elements that will be used. #ifdef PARQUET_VALGRIND - const int total_count = 5000; + const int totalCount = 5000; #else - const int total_count = 100000; + const int totalCount = 100000; #endif - // Bloom filter fpp parameter + // Bloom filter fpp parameter. const double fpp = 0.01; std::vector members; - BlockSplitBloomFilter bloom_filter; - bloom_filter.Init(BlockSplitBloomFilter::OptimalNumOfBytes(total_count, fpp)); - - // Insert elements into the Bloom filter - for (int i = 0; i < total_count; i++) { - // Insert random string which length is 8 - std::string tmp = GetRandomString(8); - const ByteArray byte_array( - 8, reinterpret_cast(tmp.c_str())); + BlockSplitBloomFilter BloomFilter; + BloomFilter.init(BlockSplitBloomFilter::optimalNumOfBytes(totalCount, fpp)); + + // Insert elements into the Bloom filter. + for (int i = 0; i < totalCount; i++) { + // Insert random string which length is 8. + std::string tmp = getRandomString(8); + const ByteArray ByteArray(8, reinterpret_cast(tmp.c_str())); members.push_back(tmp); - bloom_filter.InsertHash(bloom_filter.Hash(&byte_array)); + BloomFilter.insertHash(BloomFilter.hash(&ByteArray)); } - for (int i = 0; i < total_count; i++) { - const ByteArray byte_array1( + for (int i = 0; i < totalCount; i++) { + const ByteArray byteArray1( 8, reinterpret_cast(members[i].c_str())); - ASSERT_TRUE(bloom_filter.FindHash(bloom_filter.Hash(&byte_array1))); - std::string tmp = GetRandomString(7); - const ByteArray byte_array2( + ASSERT_TRUE(BloomFilter.findHash(BloomFilter.hash(&byteArray1))); + std::string tmp = getRandomString(7); + const ByteArray byteArray2( 7, reinterpret_cast(tmp.c_str())); - if (bloom_filter.FindHash(bloom_filter.Hash(&byte_array2))) { + if (BloomFilter.findHash(BloomFilter.hash(&byteArray2))) { exist++; } } // The exist should be probably less than 1000 according default FPP 0.01. - EXPECT_LT(exist, total_count * fpp); + EXPECT_LT(exist, totalCount * fpp); } -// The CompatibilityTest is used to test cross compatibility with parquet-mr, it -// reads the Bloom filter binary generated by the Bloom filter class in the -// parquet-mr project and tests whether the values inserted before could be -// filtered or not. +// The CompatibilityTest is used to test cross compatibility with parquet-mr, +// it. Reads the Bloom filter binary generated by the Bloom filter class in the. +// Parquet-mr project and tests whether the values inserted before could be. +// Filtered or not. // TODO: disabled as it requires Arrow parquet data dir. // The Bloom filter binary is generated by three steps in from Parquet-mr. @@ -231,11 +230,11 @@ uint8_t*>(test_string[i].c_str())); EXPECT_TRUE(bloom_filter1.FindHash(bloom_filter1.Hash(&tmp))); } - // The following is used to check whether the new created Bloom filter in + // The following is used to check whether the new created Bloom filter in. parquet-cpp is - // byte-for-byte identical to file at bloom_data_path which is created from + // Byte-for-byte identical to file at bloom_data_path which is created from. parquet-mr - // with same inserted hashes. + // With same inserted hashes. BlockSplitBloomFilter bloom_filter2; bloom_filter2.Init(bloom_filter1.GetBitsetSize()); for (int i = 0; i < 4; i++) { @@ -245,7 +244,7 @@ uint8_t*>(test_string[i].c_str())); bloom_filter2.InsertHash(bloom_filter2.Hash(&byte_array)); } - // Serialize Bloom filter to memory output stream + // Serialize Bloom filter to memory output stream. auto sink = CreateOutputStream(); bloom_filter2.WriteTo(sink.get()); PARQUET_ASSIGN_OR_THROW(auto buffer1, sink->Finish()); @@ -258,18 +257,18 @@ uint8_t*>(test_string[i].c_str())); } */ -// OptimalValueTest is used to test whether OptimalNumOfBits returns expected -// numbers according to formula: +// OptimalValueTest is used to test whether OptimalNumOfBits returns expected. +// Numbers according to formula: // num_of_bits = -8.0 * ndv / log(1 - pow(fpp, 1.0 / 8.0)) -// where ndv is the number of distinct values and fpp is the false positive -// probability. Also it is used to test whether OptimalNumOfBits returns value -// between [MINIMUM_BLOOM_FILTER_SIZE, MAXIMUM_BLOOM_FILTER_SIZE]. +// Where ndv is the number of distinct values and fpp is the false positive. +// Probability. Also it is used to test whether OptimalNumOfBits returns value. +// Between [MINIMUM_BLOOM_FILTER_SIZE, MAXIMUM_BLOOM_FILTER_SIZE]. TEST(OptimalValueTest, TestBloomFilter) { auto testOptimalNumEstimation = [](uint32_t ndv, double fpp, - uint32_t num_bits) { - EXPECT_EQ(BlockSplitBloomFilter::OptimalNumOfBits(ndv, fpp), num_bits); - EXPECT_EQ(BlockSplitBloomFilter::OptimalNumOfBytes(ndv, fpp), num_bits / 8); + uint32_t numBits) { + EXPECT_EQ(BlockSplitBloomFilter::optimalNumOfBits(ndv, fpp), numBits); + EXPECT_EQ(BlockSplitBloomFilter::optimalNumOfBytes(ndv, fpp), numBits / 8); }; testOptimalNumEstimation(256, 0.01, UINT32_C(4096)); @@ -292,7 +291,7 @@ TEST(OptimalValueTest, TestBloomFilter) { testOptimalNumEstimation(700, 0.05, UINT32_C(8192)); testOptimalNumEstimation(1500, 0.05, UINT32_C(16384)); - // Boundary check + // Boundary check. testOptimalNumEstimation( 4, 0.01, BlockSplitBloomFilter::kMinimumBloomFilterBytes * 8); testOptimalNumEstimation( @@ -308,8 +307,8 @@ TEST(OptimalValueTest, TestBloomFilter) { BlockSplitBloomFilter::kMaximumBloomFilterBytes * 8); } -// The test below is plainly copied from parquet-mr and serves as a basic sanity -// check of our XXH64 wrapper. +// The test below is plainly copied from parquet-mr and serves as a basic +// sanity. Check of our XXH64 wrapper. const int64_t HASHES_OF_LOOPING_BYTES_WITH_SEED_0[32] = { -1205034819632174695L, -1642502924627794072L, 5216751715308240086L, -1889335612763511331L, -13835840860730338L, -2521325055659080948L, @@ -351,33 +350,32 @@ TEST(XxHashTest, TestBloomFilter) { uint8_t bytes[kNumValues] = {}; for (int i = 0; i < kNumValues; i++) { - ByteArray byte_array(i, bytes); + ByteArray ByteArray(i, bytes); bytes[i] = i; - auto hasher_seed_0 = std::make_unique(); + auto hasherSeed0 = std::make_unique(); EXPECT_EQ( - HASHES_OF_LOOPING_BYTES_WITH_SEED_0[i], - hasher_seed_0->Hash(&byte_array)) + HASHES_OF_LOOPING_BYTES_WITH_SEED_0[i], hasherSeed0->hash(&ByteArray)) << "Hash with seed 0 Error: " << i; } } -// Same as TestBloomFilter but using Batch interface +// Same as TestBloomFilter but using Batch interface. TEST(XxHashTest, TestBloomFilterHashes) { constexpr int kNumValues = 32; uint8_t bytes[kNumValues] = {}; - std::vector byte_array_vector; + std::vector byteArrayVector; for (int i = 0; i < kNumValues; i++) { bytes[i] = i; - byte_array_vector.emplace_back(i, bytes); + byteArrayVector.emplace_back(i, bytes); } - auto hasher_seed_0 = std::make_unique(); + auto hasherSeed0 = std::make_unique(); std::vector hashes; hashes.resize(kNumValues); - hasher_seed_0->Hashes( - byte_array_vector.data(), - static_cast(byte_array_vector.size()), + hasherSeed0->hashes( + byteArrayVector.data(), + static_cast(byteArrayVector.size()), hashes.data()); for (int i = 0; i < kNumValues; i++) { EXPECT_EQ(HASHES_OF_LOOPING_BYTES_WITH_SEED_0[i], hashes[i]) @@ -391,17 +389,17 @@ class TestBatchBloomFilter : public testing::Test { constexpr static int kTestDataSize = 64; // GenerateTestData with size 64. - std::vector GenerateTestData(); + std::vector generateTestData(); - // The Lifetime owner for Test data + // The Lifetime owner for Test data. std::vector members; }; template -std::vector -TestBatchBloomFilter::GenerateTestData() { - std::vector values(kTestDataSize); - GenerateData(kTestDataSize, values.data(), &members); +std::vector +TestBatchBloomFilter::generateTestData() { + std::vector values(kTestDataSize); + generateData(kTestDataSize, values.data(), &members); return values; } @@ -418,58 +416,58 @@ using BloomFilterTestTypes = ::testing::Types< TYPED_TEST_SUITE(TestBatchBloomFilter, BloomFilterTestTypes); TYPED_TEST(TestBatchBloomFilter, Basic) { - using Type = typename TypeParam::c_type; - std::vector test_data = TestFixture::GenerateTestData(); - BlockSplitBloomFilter batch_insert_filter; + using Type = typename TypeParam::CType; + std::vector testData = TestFixture::generateTestData(); + BlockSplitBloomFilter batchInsertFilter; BlockSplitBloomFilter filter; - // Bloom filter fpp parameter + // Bloom filter fpp parameter. const double fpp = 0.05; - filter.Init(BlockSplitBloomFilter::OptimalNumOfBytes( - TestFixture::kTestDataSize, fpp)); - batch_insert_filter.Init(BlockSplitBloomFilter::OptimalNumOfBytes( - TestFixture::kTestDataSize, fpp)); + filter.init( + BlockSplitBloomFilter::optimalNumOfBytes( + TestFixture::kTestDataSize, fpp)); + batchInsertFilter.init( + BlockSplitBloomFilter::optimalNumOfBytes( + TestFixture::kTestDataSize, fpp)); std::vector hashes; - for (const Type& value : test_data) { + for (const Type& value : testData) { uint64_t hash = 0; if constexpr (std::is_same_v) { - hash = filter.Hash(&value, kGenerateDataFLBALength); + hash = filter.hash(&value, kGenerateDataFLBALength); } else { - hash = filter.Hash(&value); + hash = filter.hash(&value); } hashes.push_back(hash); } - std::vector batch_hashes(test_data.size()); + std::vector batchHashes(testData.size()); if constexpr (std::is_same_v) { - batch_insert_filter.Hashes( - test_data.data(), + batchInsertFilter.hashes( + testData.data(), kGenerateDataFLBALength, - static_cast(test_data.size()), - batch_hashes.data()); + static_cast(testData.size()), + batchHashes.data()); } else { - batch_insert_filter.Hashes( - test_data.data(), - static_cast(test_data.size()), - batch_hashes.data()); + batchInsertFilter.hashes( + testData.data(), static_cast(testData.size()), batchHashes.data()); } - EXPECT_EQ(hashes, batch_hashes); + EXPECT_EQ(hashes, batchHashes); std::shared_ptr buffer; - std::shared_ptr batch_insert_buffer; + std::shared_ptr batchInsertBuffer; { - auto sink = CreateOutputStream(); - filter.WriteTo(sink.get()); + auto sink = createOutputStream(); + filter.writeTo(sink.get()); ASSERT_OK_AND_ASSIGN(buffer, sink->Finish()); } { - auto sink = CreateOutputStream(); - batch_insert_filter.WriteTo(sink.get()); - ASSERT_OK_AND_ASSIGN(batch_insert_buffer, sink->Finish()); + auto sink = createOutputStream(); + batchInsertFilter.writeTo(sink.get()); + ASSERT_OK_AND_ASSIGN(batchInsertBuffer, sink->Finish()); } - AssertBufferEqual(*buffer, *batch_insert_buffer); + ::arrow::AssertBufferEqual(*buffer, *batchInsertBuffer); } } // namespace test diff --git a/velox/dwio/parquet/writer/arrow/tests/CMakeLists.txt b/velox/dwio/parquet/writer/arrow/tests/CMakeLists.txt index 5906a75516d..19e3d0e3b92 100644 --- a/velox/dwio/parquet/writer/arrow/tests/CMakeLists.txt +++ b/velox/dwio/parquet/writer/arrow/tests/CMakeLists.txt @@ -25,6 +25,7 @@ add_executable( PropertiesTest.cpp SchemaTest.cpp StatisticsTest.cpp + StringTruncationTest.cpp TypesTest.cpp ) @@ -39,7 +40,7 @@ target_link_libraries( arrow arrow_testing velox_dwio_native_parquet_reader - velox_temp_path + velox_test_util ) add_library( @@ -52,6 +53,17 @@ add_library( TestUtil.cpp XxHasher.cpp ) +velox_add_test_headers( + velox_dwio_arrow_parquet_writer_test_lib + BloomFilter.h + BloomFilterReader.h + ColumnReader.h + ColumnScanner.h + FileReader.h + Hasher.h + TestUtil.h + XxHasher.h +) target_link_libraries( velox_dwio_arrow_parquet_writer_test_lib diff --git a/velox/dwio/parquet/writer/arrow/tests/ColumnReader.cpp b/velox/dwio/parquet/writer/arrow/tests/ColumnReader.cpp index 30dd2def721..916f93b76e4 100644 --- a/velox/dwio/parquet/writer/arrow/tests/ColumnReader.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/ColumnReader.cpp @@ -62,100 +62,99 @@ namespace bit_util = arrow::bit_util; namespace facebook::velox::parquet::arrow { -fmt::underlying_t format_as(Type::type type) { +fmt::underlying_t formatAs(Type::type type) { return fmt::underlying(type); } namespace { -// The minimum number of repetition/definition levels to decode at a time, for -// better vectorized performance when doing many smaller record reads +// The minimum number of repetition/definition levels to decode at a time, for. +// Better vectorized performance when doing many smaller record reads. constexpr int64_t kMinLevelBatchSize = 1024; // Batch size for reading and throwing away values during skip. // Both RecordReader and the ColumnReader use this for skipping. constexpr int64_t kSkipScratchBatchSize = 1024; -inline bool HasSpacedValues(const ColumnDescriptor* descr) { - if (descr->max_repetition_level() > 0) { - // repeated+flat case - return !descr->schema_node()->is_required(); +inline bool hasSpacedValues(const ColumnDescriptor* descr) { + if (descr->maxRepetitionLevel() > 0) { + // Repeated+flat case. + return !descr->schemaNode()->isRequired(); } else { - // non-repeated+nested case - // Find if a node forces nulls in the lowest level along the hierarchy - const schema::Node* node = descr->schema_node().get(); - while (node) { - if (node->is_optional()) { + // Non-repeated+nested case. + // Find if a node forces nulls in the lowest level along the hierarchy. + const schema::Node* Node = descr->schemaNode().get(); + while (Node) { + if (Node->isOptional()) { return true; } - node = node->parent(); + Node = Node->parent(); } return false; } } // Throws exception if number_decoded does not match expected. -inline void CheckNumberDecoded(int64_t number_decoded, int64_t expected) { - if (ARROW_PREDICT_FALSE(number_decoded != expected)) { - ParquetException::EofException( - "Decoded values " + std::to_string(number_decoded) + +inline void checkNumberDecoded(int64_t numberDecoded, int64_t expected) { + if (ARROW_PREDICT_FALSE(numberDecoded != expected)) { + ParquetException::eofException( + "Decoded values " + std::to_string(numberDecoded) + " does not match expected " + std::to_string(expected)); } } } // namespace -LevelDecoder::LevelDecoder() : num_values_remaining_(0) {} +LevelDecoder::LevelDecoder() : numValuesRemaining_(0) {} LevelDecoder::~LevelDecoder() {} -int LevelDecoder::SetData( +int LevelDecoder::setData( Encoding::type encoding, - int16_t max_level, - int num_buffered_values, + int16_t maxLevel, + int numBufferedValues, const uint8_t* data, - int32_t data_size) { - max_level_ = max_level; - int32_t num_bytes = 0; + int32_t dataSize) { + maxLevel_ = maxLevel; + int32_t numBytes = 0; encoding_ = encoding; - num_values_remaining_ = num_buffered_values; - bit_width_ = ::arrow::bit_util::Log2(max_level + 1); + numValuesRemaining_ = numBufferedValues; + bitWidth_ = ::arrow::bit_util::Log2(maxLevel + 1); switch (encoding) { - case Encoding::RLE: { - if (data_size < 4) { + case Encoding::kRle: { + if (dataSize < 4) { throw ParquetException("Received invalid levels (corrupt data page?)"); } - num_bytes = ::arrow::util::SafeLoadAs(data); - if (num_bytes < 0 || num_bytes > data_size - 4) { + numBytes = ::arrow::util::SafeLoadAs(data); + if (numBytes < 0 || numBytes > dataSize - 4) { throw ParquetException( "Received invalid number of bytes (corrupt data page?)"); } - const uint8_t* decoder_data = data + 4; - if (!rle_decoder_) { - rle_decoder_ = - std::make_unique(decoder_data, num_bytes, bit_width_); + const uint8_t* decoderData = data + 4; + if (!rleDecoder_) { + rleDecoder_ = + std::make_unique(decoderData, numBytes, bitWidth_); } else { - rle_decoder_->Reset(decoder_data, num_bytes, bit_width_); + rleDecoder_->Reset(decoderData, numBytes, bitWidth_); } - return 4 + num_bytes; + return 4 + numBytes; } - case Encoding::BIT_PACKED: { - int num_bits = 0; - if (MultiplyWithOverflow(num_buffered_values, bit_width_, &num_bits)) { + case Encoding::kBitPacked: { + int32_t numBits = 0; + if (MultiplyWithOverflow(numBufferedValues, bitWidth_, &numBits)) { throw ParquetException( "Number of buffered values too large (corrupt data page?)"); } - num_bytes = - static_cast(::arrow::bit_util::BytesForBits(num_bits)); - if (num_bytes < 0 || num_bytes > data_size - 4) { + numBytes = static_cast(::arrow::bit_util::BytesForBits(numBits)); + if (numBytes < 0 || numBytes > dataSize - 4) { throw ParquetException( "Received invalid number of bytes (corrupt data page?)"); } - if (!bit_packed_decoder_) { - bit_packed_decoder_ = std::make_unique(data, num_bytes); + if (!bitPackedDecoder_) { + bitPackedDecoder_ = std::make_unique(data, numBytes); } else { - bit_packed_decoder_->Reset(data, num_bytes); + bitPackedDecoder_->Reset(data, numBytes); } - return num_bytes; + return numBytes; } default: throw ParquetException("Unknown encoding type for levels."); @@ -163,265 +162,263 @@ int LevelDecoder::SetData( return -1; } -void LevelDecoder::SetDataV2( - int32_t num_bytes, - int16_t max_level, - int num_buffered_values, +void LevelDecoder::setDataV2( + int32_t numBytes, + int16_t maxLevel, + int numBufferedValues, const uint8_t* data) { - max_level_ = max_level; - // Repetition and definition levels always uses RLE encoding - // in the DataPageV2 format. - if (num_bytes < 0) { + maxLevel_ = maxLevel; + // Repetition and definition levels always uses RLE encoding. + // In the DataPageV2 format. + if (numBytes < 0) { throw ParquetException("Invalid page header (corrupt data page?)"); } - encoding_ = Encoding::RLE; - num_values_remaining_ = num_buffered_values; - bit_width_ = ::arrow::bit_util::Log2(max_level + 1); + encoding_ = Encoding::kRle; + numValuesRemaining_ = numBufferedValues; + bitWidth_ = ::arrow::bit_util::Log2(maxLevel + 1); - if (!rle_decoder_) { - rle_decoder_ = std::make_unique(data, num_bytes, bit_width_); + if (!rleDecoder_) { + rleDecoder_ = std::make_unique(data, numBytes, bitWidth_); } else { - rle_decoder_->Reset(data, num_bytes, bit_width_); + rleDecoder_->Reset(data, numBytes, bitWidth_); } } -int LevelDecoder::Decode(int batch_size, int16_t* levels) { - int num_decoded = 0; +int LevelDecoder::decode(int batchSize, int16_t* levels) { + int numDecoded = 0; - int num_values = std::min(num_values_remaining_, batch_size); - if (encoding_ == Encoding::RLE) { - num_decoded = rle_decoder_->GetBatch(levels, num_values); + int numValues = std::min(numValuesRemaining_, batchSize); + if (encoding_ == Encoding::kRle) { + numDecoded = rleDecoder_->GetBatch(levels, numValues); } else { - num_decoded = bit_packed_decoder_->GetBatch(bit_width_, levels, num_values); + numDecoded = bitPackedDecoder_->GetBatch(bitWidth_, levels, numValues); } - if (num_decoded > 0) { - MinMax min_max = FindMinMax(levels, num_decoded); - if (ARROW_PREDICT_FALSE(min_max.min < 0 || min_max.max > max_level_)) { + if (numDecoded > 0) { + MinMax minMax = FindMinMax(levels, numDecoded); + if (ARROW_PREDICT_FALSE(minMax.min < 0 || minMax.max > maxLevel_)) { std::stringstream ss; - ss << "Malformed levels. min: " << min_max.min << " max: " << min_max.max - << " out of range. Max Level: " << max_level_; + ss << "Malformed levels. min: " << minMax.min << " max: " << minMax.max + << " out of range. Max Level: " << maxLevel_; throw ParquetException(ss.str()); } } - num_values_remaining_ -= num_decoded; - return num_decoded; + numValuesRemaining_ -= numDecoded; + return numDecoded; } namespace { -// Extracts encoded statistics from V1 and V2 data page headers +// Extracts encoded statistics from V1 and V2 data page headers. template -EncodedStatistics ExtractStatsFromHeader(const H& header) { - EncodedStatistics page_statistics; +EncodedStatistics extractStatsFromHeader(const H& header) { + EncodedStatistics pageStatistics; if (!header.__isset.statistics) { - return page_statistics; + return pageStatistics; } const facebook::velox::parquet::thrift::Statistics& stats = header.statistics; - // Use the new V2 min-max statistics over the former one if it is filled + // Use the new V2 min-max statistics over the former one if it is filled. if (stats.__isset.max_value || stats.__isset.min_value) { // TODO: check if the column_order is TYPE_DEFINED_ORDER. if (stats.__isset.max_value) { - page_statistics.set_max(stats.max_value); + pageStatistics.setMax(stats.max_value); } if (stats.__isset.min_value) { - page_statistics.set_min(stats.min_value); + pageStatistics.setMin(stats.min_value); } } else if (stats.__isset.max || stats.__isset.min) { // TODO: check created_by to see if it is corrupted for some types. // TODO: check if the sort_order is SIGNED. if (stats.__isset.max) { - page_statistics.set_max(stats.max); + pageStatistics.setMax(stats.max); } if (stats.__isset.min) { - page_statistics.set_min(stats.min); + pageStatistics.setMin(stats.min); } } if (stats.__isset.null_count) { - page_statistics.set_null_count(stats.null_count); + pageStatistics.setNullCount(stats.null_count); } if (stats.__isset.distinct_count) { - page_statistics.set_distinct_count(stats.distinct_count); + pageStatistics.setDistinctCount(stats.distinct_count); } - return page_statistics; + return pageStatistics; } -void CheckNumValuesInHeader(int num_values) { - if (num_values < 0) { +void checkNumValuesInHeader(int numValues) { + if (numValues < 0) { throw ParquetException("Invalid page header (negative number of values)"); } } -// ---------------------------------------------------------------------- -// SerializedPageReader deserializes Thrift metadata and pages that have been -// assembled in a serialized stream for storing in a Parquet files +// ----------------------------------------------------------------------. +// SerializedPageReader deserializes Thrift metadata and pages that have been. +// Assembled in a serialized stream for storing in a Parquet files. -// This subclass delimits pages appearing in a serialized stream, each preceded -// by a serialized Thrift facebook::velox::parquet::thrift::PageHeader -// indicating the type of each page and the page metadata. +// This subclass delimits pages appearing in a serialized stream, each preceded. +// By a serialized Thrift facebook::velox::parquet::thrift::PageHeader. +// Indicating the type of each page and the page metadata. class SerializedPageReader : public PageReader { public: SerializedPageReader( std::shared_ptr stream, - int64_t total_num_values, - Compression::type codec, + int64_t totalNumValues, + Compression::type Codec, const ReaderProperties& properties, - const CryptoContext* crypto_ctx, - bool always_compressed) + const CryptoContext* cryptoCtx, + bool alwaysCompressed) : properties_(properties), stream_(std::move(stream)), - decompression_buffer_(AllocateBuffer(properties_.memory_pool(), 0)), - page_ordinal_(0), - seen_num_values_(0), - total_num_values_(total_num_values), - decryption_buffer_(AllocateBuffer(properties_.memory_pool(), 0)) { - if (crypto_ctx != nullptr) { - crypto_ctx_ = *crypto_ctx; - InitDecryption(); + decompressionBuffer_(allocateBuffer(properties_.memoryPool(), 0)), + pageOrdinal_(0), + seenNumValues_(0), + totalNumValues_(totalNumValues), + decryptionBuffer_(allocateBuffer(properties_.memoryPool(), 0)) { + if (cryptoCtx != nullptr) { + cryptoCtx_ = *cryptoCtx; + initDecryption(); } - max_page_header_size_ = kDefaultMaxPageHeaderSize; - decompressor_ = GetCodec(codec); - always_compressed_ = always_compressed; + maxPageHeaderSize_ = kDefaultMaxPageHeaderSize; + decompressor_ = getCodec(Codec); + alwaysCompressed_ = alwaysCompressed; } - // Implement the PageReader interface + // Implement the PageReader interface. // - // The returned Page contains references that aren't guaranteed to live - // beyond the next call to NextPage(). SerializedPageReader reuses the - // decryption and decompression buffers internally, so if NextPage() is - // called then the content of previous page might be invalidated. - std::shared_ptr NextPage() override; + // The returned Page contains references that aren't guaranteed to live. + // Beyond the next call to NextPage(). SerializedPageReader reuses the. + // Decryption and decompression buffers internally, so if NextPage() is. + // Called then the content of previous page might be invalidated. + std::shared_ptr nextPage() override; - void set_max_page_header_size(uint32_t size) override { - max_page_header_size_ = size; + void setMaxPageHeaderSize(uint32_t size) override { + maxPageHeaderSize_ = size; } private: - void UpdateDecryption( - const std::shared_ptr& decryptor, - int8_t module_type, - std::string* page_aad); + void updateDecryption( + const std::shared_ptr& Decryptor, + int8_t moduleType, + std::string* pageAad); - void InitDecryption(); + void initDecryption(); - std::shared_ptr DecompressIfNeeded( - std::shared_ptr page_buffer, - int compressed_len, - int uncompressed_len, - int levels_byte_len = 0); + std::shared_ptr decompressIfNeeded( + std::shared_ptr pageBuffer, + int compressedLen, + int uncompressedLen, + int levelsByteLen = 0); - // Returns true for non-data pages, and if we should skip based on - // data_page_filter_. Performs basic checks on values in the page header. + // Returns true for non-data pages, and if we should skip based on. + // Data_page_filter_. Performs basic checks on values in the page header. // Fills in data_page_statistics. - bool ShouldSkipPage(EncodedStatistics* data_page_statistics); + bool shouldSkipPage(EncodedStatistics* dataPageStatistics); const ReaderProperties properties_; std::shared_ptr stream_; - facebook::velox::parquet::thrift::PageHeader current_page_header_; - std::shared_ptr current_page_; + facebook::velox::parquet::thrift::PageHeader currentPageHeader_; + std::shared_ptr currentPage_; // Compression codec to use. std::unique_ptr decompressor_; - std::shared_ptr decompression_buffer_; + std::shared_ptr decompressionBuffer_; - bool always_compressed_; + bool alwaysCompressed_; - // The fields below are used for calculation of AAD (additional authenticated - // data) suffix which is part of the Parquet Modular Encryption. The AAD - // suffix for a parquet module is built internally by concatenating different - // parts some of which include the row group ordinal, column ordinal and page - // ordinal. Please refer to the encryption specification for more details: + // The fields below are used for calculation of AAD (additional authenticated. + // Data) suffix which is part of the Parquet Modular Encryption. The AAD. + // Suffix for a parquet module is built internally by concatenating different. + // Parts some of which include the row group ordinal, column ordinal and page. + // Ordinal. Please refer to the encryption specification for more details: // https://github.com/apache/parquet-format/blob/encryption/Encryption.md#44-additional-authenticated-data - // The ordinal fields in the context below are used for AAD suffix - // calculation. - CryptoContext crypto_ctx_; - int32_t page_ordinal_; // page ordinal does not count the dictionary page + // The ordinal fields in the context below are used for AAD suffix. + // Calculation. + CryptoContext cryptoCtx_; + int32_t pageOrdinal_; // page ordinal does not count the dictionary page - // Maximum allowed page size - uint32_t max_page_header_size_; + // Maximum allowed page size. + uint32_t maxPageHeaderSize_; - // Number of values read in data pages so far - int64_t seen_num_values_; + // Number of values read in data pages so far. + int64_t seenNumValues_; - // Number of values in all the data pages - int64_t total_num_values_; + // Number of values in all the data pages. + int64_t totalNumValues_; - // data_page_aad_ and data_page_header_aad_ contain the AAD for data page and - // data page header in a single column respectively. While calculating AAD for - // different pages in a single column the pages AAD is updated by only the - // page ordinal. - std::string data_page_aad_; - std::string data_page_header_aad_; - // Encryption - std::shared_ptr decryption_buffer_; + // Data_page_aad_ and data_page_header_aad_ contain the AAD for data page and. + // Data page header in a single column respectively. While calculating AAD + // for. Different pages in a single column the pages AAD is updated by only + // the. Page ordinal. + std::string dataPageAad_; + std::string dataPageHeaderAad_; + // Encryption. + std::shared_ptr decryptionBuffer_; }; -void SerializedPageReader::InitDecryption() { +void SerializedPageReader::initDecryption() { // Prepare the AAD for quick update later. - if (crypto_ctx_.data_decryptor != nullptr) { - VELOX_DCHECK(!crypto_ctx_.data_decryptor->file_aad().empty()); - data_page_aad_ = encryption::CreateModuleAad( - crypto_ctx_.data_decryptor->file_aad(), + if (cryptoCtx_.dataDecryptor != nullptr) { + VELOX_DCHECK(!cryptoCtx_.dataDecryptor->fileAad().empty()); + dataPageAad_ = encryption::createModuleAad( + cryptoCtx_.dataDecryptor->fileAad(), encryption::kDataPage, - crypto_ctx_.row_group_ordinal, - crypto_ctx_.column_ordinal, + cryptoCtx_.rowGroupOrdinal, + cryptoCtx_.columnOrdinal, kNonPageOrdinal); } - if (crypto_ctx_.meta_decryptor != nullptr) { - VELOX_DCHECK(!crypto_ctx_.meta_decryptor->file_aad().empty()); - data_page_header_aad_ = encryption::CreateModuleAad( - crypto_ctx_.meta_decryptor->file_aad(), + if (cryptoCtx_.metaDecryptor != nullptr) { + VELOX_DCHECK(!cryptoCtx_.metaDecryptor->fileAad().empty()); + dataPageHeaderAad_ = encryption::createModuleAad( + cryptoCtx_.metaDecryptor->fileAad(), encryption::kDataPageHeader, - crypto_ctx_.row_group_ordinal, - crypto_ctx_.column_ordinal, + cryptoCtx_.rowGroupOrdinal, + cryptoCtx_.columnOrdinal, kNonPageOrdinal); } } -void SerializedPageReader::UpdateDecryption( - const std::shared_ptr& decryptor, - int8_t module_type, - std::string* page_aad) { - VELOX_DCHECK_NOT_NULL(decryptor); - if (crypto_ctx_.start_decrypt_with_dictionary_page) { - std::string aad = encryption::CreateModuleAad( - decryptor->file_aad(), - module_type, - crypto_ctx_.row_group_ordinal, - crypto_ctx_.column_ordinal, +void SerializedPageReader::updateDecryption( + const std::shared_ptr& Decryptor, + int8_t moduleType, + std::string* pageAad) { + VELOX_DCHECK_NOT_NULL(Decryptor); + if (cryptoCtx_.startDecryptWithDictionaryPage) { + std::string aad = encryption::createModuleAad( + Decryptor->fileAad(), + moduleType, + cryptoCtx_.rowGroupOrdinal, + cryptoCtx_.columnOrdinal, kNonPageOrdinal); - decryptor->UpdateAad(aad); + Decryptor->updateAad(aad); } else { - encryption::QuickUpdatePageAad(page_ordinal_, page_aad); - decryptor->UpdateAad(*page_aad); + encryption::quickUpdatePageAad(pageOrdinal_, pageAad); + Decryptor->updateAad(*pageAad); } } -bool SerializedPageReader::ShouldSkipPage( - EncodedStatistics* data_page_statistics) { - const PageType::type page_type = LoadEnumSafe(¤t_page_header_.type); - if (page_type == PageType::DATA_PAGE) { +bool SerializedPageReader::shouldSkipPage( + EncodedStatistics* dataPageStatistics) { + const PageType::type pageType = loadenumSafe(¤tPageHeader_.type); + if (pageType == PageType::kDataPage) { const facebook::velox::parquet::thrift::DataPageHeader& header = - current_page_header_.data_page_header; - CheckNumValuesInHeader(header.num_values); - *data_page_statistics = ExtractStatsFromHeader(header); - seen_num_values_ += header.num_values; - if (data_page_filter_) { - const EncodedStatistics* filter_statistics = - data_page_statistics->is_set() ? data_page_statistics : nullptr; - DataPageStats data_page_stats( - filter_statistics, - header.num_values, - /*num_rows=*/std::nullopt); - if (data_page_filter_(data_page_stats)) { + currentPageHeader_.data_page_header; + checkNumValuesInHeader(header.num_values); + *dataPageStatistics = extractStatsFromHeader(header); + seenNumValues_ += header.num_values; + if (dataPageFilter_) { + const EncodedStatistics* filterStatistics = + dataPageStatistics->isSet() ? dataPageStatistics : nullptr; + DataPageStats dataPageStats( + filterStatistics, header.num_values, std::nullopt); + if (dataPageFilter_(dataPageStats)) { return true; } } - } else if (page_type == PageType::DATA_PAGE_V2) { + } else if (pageType == PageType::kDataPageV2) { const facebook::velox::parquet::thrift::DataPageHeaderV2& header = - current_page_header_.data_page_header_v2; - CheckNumValuesInHeader(header.num_values); + currentPageHeader_.data_page_header_v2; + checkNumValuesInHeader(header.num_values); if (header.num_rows < 0) { throw ParquetException("Invalid page header (negative number of rows)"); } @@ -430,204 +427,200 @@ bool SerializedPageReader::ShouldSkipPage( throw ParquetException( "Invalid page header (negative levels byte length)"); } - *data_page_statistics = ExtractStatsFromHeader(header); - seen_num_values_ += header.num_values; - if (data_page_filter_) { - const EncodedStatistics* filter_statistics = - data_page_statistics->is_set() ? data_page_statistics : nullptr; - DataPageStats data_page_stats( - filter_statistics, header.num_values, header.num_rows); - if (data_page_filter_(data_page_stats)) { + *dataPageStatistics = extractStatsFromHeader(header); + seenNumValues_ += header.num_values; + if (dataPageFilter_) { + const EncodedStatistics* filterStatistics = + dataPageStatistics->isSet() ? dataPageStatistics : nullptr; + DataPageStats dataPageStats( + filterStatistics, header.num_values, header.num_rows); + if (dataPageFilter_(dataPageStats)) { return true; } } - } else if (page_type == PageType::DICTIONARY_PAGE) { - const facebook::velox::parquet::thrift::DictionaryPageHeader& dict_header = - current_page_header_.dictionary_page_header; - CheckNumValuesInHeader(dict_header.num_values); + } else if (pageType == PageType::kDictionaryPage) { + const facebook::velox::parquet::thrift::DictionaryPageHeader& dictHeader = + currentPageHeader_.dictionary_page_header; + checkNumValuesInHeader(dictHeader.num_values); } else { - // We don't know what this page type is. We're allowed to skip non-data - // pages. + // We don't know what this page type is. We're allowed to skip non-data. + // Pages. return true; } return false; } -std::shared_ptr SerializedPageReader::NextPage() { +std::shared_ptr SerializedPageReader::nextPage() { ThriftDeserializer deserializer(properties_); - // Loop here because there may be unhandled page types that we skip until - // finding a page that we do know what to do with - while (seen_num_values_ < total_num_values_) { - uint32_t header_size = 0; - uint32_t allowed_page_size = kDefaultPageHeaderSize; + // Loop here because there may be unhandled page types that we skip until. + // Finding a page that we do know what to do with. + while (seenNumValues_ < totalNumValues_) { + uint32_t headerSize = 0; + uint32_t allowedPageSize = kDefaultPageHeaderSize; - // Page headers can be very large because of page statistics - // We try to deserialize a larger buffer progressively - // until a maximum allowed header limit + // Page headers can be very large because of page statistics. + // We try to deserialize a larger buffer progressively. + // Until a maximum allowed header limit. while (true) { - PARQUET_ASSIGN_OR_THROW(auto view, stream_->Peek(allowed_page_size)); + PARQUET_ASSIGN_OR_THROW(auto view, stream_->Peek(allowedPageSize)); if (view.size() == 0) { return std::shared_ptr(nullptr); } - // This gets used, then set by DeserializeThriftMsg - header_size = static_cast(view.size()); + // This gets used, then set by DeserializeThriftMsg. + headerSize = static_cast(view.size()); try { - if (crypto_ctx_.meta_decryptor != nullptr) { - UpdateDecryption( - crypto_ctx_.meta_decryptor, + if (cryptoCtx_.metaDecryptor != nullptr) { + updateDecryption( + cryptoCtx_.metaDecryptor, encryption::kDictionaryPageHeader, - &data_page_header_aad_); + &dataPageHeaderAad_); } // Reset current page header to avoid unclearing the __isset flag. - current_page_header_ = facebook::velox::parquet::thrift::PageHeader(); - deserializer.DeserializeMessage( + currentPageHeader_ = facebook::velox::parquet::thrift::PageHeader(); + deserializer.deserializeMessage( reinterpret_cast(view.data()), - &header_size, - ¤t_page_header_, - crypto_ctx_.meta_decryptor); + &headerSize, + ¤tPageHeader_, + cryptoCtx_.metaDecryptor); break; } catch (std::exception& e) { - // Failed to deserialize. Double the allowed page header size and try - // again + // Failed to deserialize. Double the allowed page header size and try. + // Again. std::stringstream ss; ss << e.what(); - allowed_page_size *= 2; - if (allowed_page_size > max_page_header_size_) { + allowedPageSize *= 2; + if (allowedPageSize > maxPageHeaderSize_) { ss << "Deserializing page header failed.\n"; throw ParquetException(ss.str()); } } } - // Advance the stream offset - PARQUET_THROW_NOT_OK(stream_->Advance(header_size)); + // Advance the stream offset. + PARQUET_THROW_NOT_OK(stream_->Advance(headerSize)); - int compressed_len = current_page_header_.compressed_page_size; - int uncompressed_len = current_page_header_.uncompressed_page_size; - if (compressed_len < 0 || uncompressed_len < 0) { + int compressedLen = currentPageHeader_.compressed_page_size; + int uncompressedLen = currentPageHeader_.uncompressed_page_size; + if (compressedLen < 0 || uncompressedLen < 0) { throw ParquetException("Invalid page header"); } - EncodedStatistics data_page_statistics; - if (ShouldSkipPage(&data_page_statistics)) { - PARQUET_THROW_NOT_OK(stream_->Advance(compressed_len)); + EncodedStatistics dataPageStatistics; + if (shouldSkipPage(&dataPageStatistics)) { + PARQUET_THROW_NOT_OK(stream_->Advance(compressedLen)); continue; } - if (crypto_ctx_.data_decryptor != nullptr) { - UpdateDecryption( - crypto_ctx_.data_decryptor, - encryption::kDictionaryPage, - &data_page_aad_); + if (cryptoCtx_.dataDecryptor != nullptr) { + updateDecryption( + cryptoCtx_.dataDecryptor, encryption::kDictionaryPage, &dataPageAad_); } // Read the compressed data page. - PARQUET_ASSIGN_OR_THROW(auto page_buffer, stream_->Read(compressed_len)); - if (page_buffer->size() != compressed_len) { + PARQUET_ASSIGN_OR_THROW(auto pageBuffer, stream_->Read(compressedLen)); + if (pageBuffer->size() != compressedLen) { std::stringstream ss; - ss << "Page was smaller (" << page_buffer->size() << ") than expected (" - << compressed_len << ")"; - ParquetException::EofException(ss.str()); + ss << "Page was smaller (" << pageBuffer->size() << ") than expected (" + << compressedLen << ")"; + ParquetException::eofException(ss.str()); } - const PageType::type page_type = LoadEnumSafe(¤t_page_header_.type); + const PageType::type pageType = loadenumSafe(¤tPageHeader_.type); - if (properties_.page_checksum_verification() && - current_page_header_.__isset.crc && PageCanUseChecksum(page_type)) { - // verify crc + if (properties_.pageChecksumVerification() && + currentPageHeader_.__isset.crc && pageCanUseChecksum(pageType)) { + // Verify crc. uint32_t checksum = ::arrow::internal::crc32( - /* prev */ 0, page_buffer->data(), compressed_len); - if (static_cast(checksum) != current_page_header_.crc) { + /* prev */ 0, pageBuffer->data(), compressedLen); + if (static_cast(checksum) != currentPageHeader_.crc) { throw ParquetException( "could not verify page integrity, CRC checksum verification failed for " "page_ordinal " + - std::to_string(page_ordinal_)); + std::to_string(pageOrdinal_)); } } - // Decrypt it if we need to - if (crypto_ctx_.data_decryptor != nullptr) { - PARQUET_THROW_NOT_OK(decryption_buffer_->Resize( - compressed_len - crypto_ctx_.data_decryptor->CiphertextSizeDelta(), - /*shrink_to_fit=*/false)); - compressed_len = crypto_ctx_.data_decryptor->Decrypt( - page_buffer->data(), - compressed_len, - decryption_buffer_->mutable_data()); + // Decrypt it if we need to. + if (cryptoCtx_.dataDecryptor != nullptr) { + PARQUET_THROW_NOT_OK(decryptionBuffer_->Resize( + compressedLen - cryptoCtx_.dataDecryptor->ciphertextSizeDelta(), + false)); + compressedLen = cryptoCtx_.dataDecryptor->decrypt( + pageBuffer->data(), compressedLen, decryptionBuffer_->mutable_data()); - page_buffer = decryption_buffer_; + pageBuffer = decryptionBuffer_; } - if (page_type == PageType::DICTIONARY_PAGE) { - crypto_ctx_.start_decrypt_with_dictionary_page = false; - const facebook::velox::parquet::thrift::DictionaryPageHeader& - dict_header = current_page_header_.dictionary_page_header; - bool is_sorted = - dict_header.__isset.is_sorted ? dict_header.is_sorted : false; + if (pageType == PageType::kDictionaryPage) { + cryptoCtx_.startDecryptWithDictionaryPage = false; + const facebook::velox::parquet::thrift::DictionaryPageHeader& dictHeader = + currentPageHeader_.dictionary_page_header; + bool isSorted = + dictHeader.__isset.is_sorted ? dictHeader.is_sorted : false; - page_buffer = DecompressIfNeeded( - std::move(page_buffer), compressed_len, uncompressed_len); + pageBuffer = decompressIfNeeded( + std::move(pageBuffer), compressedLen, uncompressedLen); return std::make_shared( - page_buffer, - dict_header.num_values, - LoadEnumSafe(&dict_header.encoding), - is_sorted); - } else if (page_type == PageType::DATA_PAGE) { - ++page_ordinal_; + pageBuffer, + dictHeader.num_values, + loadenumSafe(&dictHeader.encoding), + isSorted); + } else if (pageType == PageType::kDataPage) { + ++pageOrdinal_; const facebook::velox::parquet::thrift::DataPageHeader& header = - current_page_header_.data_page_header; - page_buffer = DecompressIfNeeded( - std::move(page_buffer), compressed_len, uncompressed_len); + currentPageHeader_.data_page_header; + pageBuffer = decompressIfNeeded( + std::move(pageBuffer), compressedLen, uncompressedLen); return std::make_shared( - page_buffer, + pageBuffer, header.num_values, - LoadEnumSafe(&header.encoding), - LoadEnumSafe(&header.definition_level_encoding), - LoadEnumSafe(&header.repetition_level_encoding), - uncompressed_len, - data_page_statistics); - } else if (page_type == PageType::DATA_PAGE_V2) { - ++page_ordinal_; + loadenumSafe(&header.encoding), + loadenumSafe(&header.definition_level_encoding), + loadenumSafe(&header.repetition_level_encoding), + uncompressedLen, + dataPageStatistics); + } else if (pageType == PageType::kDataPageV2) { + ++pageOrdinal_; const facebook::velox::parquet::thrift::DataPageHeaderV2& header = - current_page_header_.data_page_header_v2; + currentPageHeader_.data_page_header_v2; // Arrow prior to 3.0.0 set is_compressed to false but still compressed. - bool is_compressed = + bool isCompressed = (header.__isset.is_compressed ? header.is_compressed : false) || - always_compressed_; + alwaysCompressed_; - // Uncompress if needed - int levels_byte_len; + // Uncompress if needed. + int levelsByteLen; if (AddWithOverflow( header.definition_levels_byte_length, header.repetition_levels_byte_length, - &levels_byte_len)) { + &levelsByteLen)) { throw ParquetException("Levels size too large (corrupt file?)"); } - // DecompressIfNeeded doesn't take `is_compressed` into account as - // it's page type-agnostic. - if (is_compressed) { - page_buffer = DecompressIfNeeded( - std::move(page_buffer), - compressed_len, - uncompressed_len, - levels_byte_len); + // DecompressIfNeeded doesn't take `is_compressed` into account as. + // It's page type-agnostic. + if (isCompressed) { + pageBuffer = decompressIfNeeded( + std::move(pageBuffer), + compressedLen, + uncompressedLen, + levelsByteLen); } return std::make_shared( - page_buffer, + pageBuffer, header.num_values, header.num_nulls, header.num_rows, - LoadEnumSafe(&header.encoding), + loadenumSafe(&header.encoding), header.definition_levels_byte_length, header.repetition_levels_byte_length, - uncompressed_len, - is_compressed, - data_page_statistics); + uncompressedLen, + isCompressed, + dataPageStatistics); } else { throw ParquetException( "Internal error, we have already skipped non-data pages in ShouldSkipPage()"); @@ -636,204 +629,200 @@ std::shared_ptr SerializedPageReader::NextPage() { return std::shared_ptr(nullptr); } -std::shared_ptr SerializedPageReader::DecompressIfNeeded( - std::shared_ptr page_buffer, - int compressed_len, - int uncompressed_len, - int levels_byte_len) { +std::shared_ptr SerializedPageReader::decompressIfNeeded( + std::shared_ptr pageBuffer, + int compressedLen, + int uncompressedLen, + int levelsByteLen) { if (decompressor_ == nullptr) { - return page_buffer; + return pageBuffer; } - if (compressed_len < levels_byte_len || uncompressed_len < levels_byte_len) { + if (compressedLen < levelsByteLen || uncompressedLen < levelsByteLen) { throw ParquetException("Invalid page header"); } // Grow the uncompressed buffer if we need to. - PARQUET_THROW_NOT_OK( - decompression_buffer_->Resize(uncompressed_len, /*shrink_to_fit=*/false)); + PARQUET_THROW_NOT_OK(decompressionBuffer_->Resize(uncompressedLen, false)); - if (levels_byte_len > 0) { - // First copy the levels as-is - uint8_t* decompressed = decompression_buffer_->mutable_data(); - memcpy(decompressed, page_buffer->data(), levels_byte_len); + if (levelsByteLen > 0) { + // First copy the levels as-is. + uint8_t* decompressed = decompressionBuffer_->mutable_data(); + memcpy(decompressed, pageBuffer->data(), levelsByteLen); } - // Decompress the values - PARQUET_THROW_NOT_OK(decompressor_->Decompress( - compressed_len - levels_byte_len, - page_buffer->data() + levels_byte_len, - uncompressed_len - levels_byte_len, - decompression_buffer_->mutable_data() + levels_byte_len)); + // Decompress the values. + PARQUET_THROW_NOT_OK(decompressor_->decompress( + compressedLen - levelsByteLen, + pageBuffer->data() + levelsByteLen, + uncompressedLen - levelsByteLen, + decompressionBuffer_->mutable_data() + levelsByteLen)); - return decompression_buffer_; + return decompressionBuffer_; } } // namespace -std::unique_ptr PageReader::Open( +std::unique_ptr PageReader::open( std::shared_ptr stream, - int64_t total_num_values, - Compression::type codec, + int64_t totalNumValues, + Compression::type Codec, const ReaderProperties& properties, - bool always_compressed, + bool alwaysCompressed, const CryptoContext* ctx) { return std::unique_ptr(new SerializedPageReader( std::move(stream), - total_num_values, - codec, + totalNumValues, + Codec, properties, ctx, - always_compressed)); + alwaysCompressed)); } -std::unique_ptr PageReader::Open( +std::unique_ptr PageReader::open( std::shared_ptr stream, - int64_t total_num_values, - Compression::type codec, - bool always_compressed, + int64_t totalNumValues, + Compression::type Codec, + bool alwaysCompressed, ::arrow::MemoryPool* pool, const CryptoContext* ctx) { return std::unique_ptr(new SerializedPageReader( std::move(stream), - total_num_values, - codec, + totalNumValues, + Codec, ReaderProperties(pool), ctx, - always_compressed)); + alwaysCompressed)); } namespace { -// ---------------------------------------------------------------------- -// Impl base class for TypedColumnReader and RecordReader +// ----------------------------------------------------------------------. +// Impl base class for TypedColumnReader and RecordReader. -// PLAIN_DICTIONARY is deprecated but used to be used as a dictionary index -// encoding. -static bool IsDictionaryIndexEncoding(const Encoding::type& e) { - return e == Encoding::RLE_DICTIONARY || e == Encoding::PLAIN_DICTIONARY; +// PLAIN_DICTIONARY is deprecated but used to be used as a dictionary index. +// Encoding. +static bool isDictionaryIndexEncoding(const Encoding::type& e) { + return e == Encoding::kRleDictionary || e == Encoding::kPlainDictionary; } template class ColumnReaderImplBase { public: - using T = typename DType::c_type; + using T = typename DType::CType; ColumnReaderImplBase(const ColumnDescriptor* descr, ::arrow::MemoryPool* pool) : descr_(descr), - max_def_level_(descr->max_definition_level()), - max_rep_level_(descr->max_repetition_level()), - num_buffered_values_(0), - num_decoded_values_(0), + maxDefLevel_(descr->maxDefinitionLevel()), + maxRepLevel_(descr->maxRepetitionLevel()), + numBufferedValues_(0), + numDecodedValues_(0), pool_(pool), - current_decoder_(nullptr), - current_encoding_(Encoding::UNKNOWN) {} + currentDecoder_(nullptr), + currentEncoding_(Encoding::kUnknown) {} virtual ~ColumnReaderImplBase() = default; protected: - // Read up to batch_size values from the current data page into the - // pre-allocated memory T* + // Read up to batch_size values from the current data page into the. + // Pre-allocated memory T*. // - // @returns: the number of values read into the out buffer - int64_t ReadValues(int64_t batch_size, T* out) { - int64_t num_decoded = - current_decoder_->Decode(out, static_cast(batch_size)); - return num_decoded; + // @returns: the number of values read into the out buffer. + int64_t readValues(int64_t batchSize, T* out) { + int64_t numDecoded = + currentDecoder_->decode(out, static_cast(batchSize)); + return numDecoded; } - // Read up to batch_size values from the current data page into the - // pre-allocated memory T*, leaving spaces for null entries according - // to the def_levels. + // Read up to batch_size values from the current data page into the. + // Pre-allocated memory T*, leaving spaces for null entries according. + // To the def_levels. // - // @returns: the number of values read into the out buffer - int64_t ReadValuesSpaced( - int64_t batch_size, + // @returns: the number of values read into the out buffer. + int64_t readValuesSpaced( + int64_t batchSize, T* out, - int64_t null_count, - uint8_t* valid_bits, - int64_t valid_bits_offset) { - return current_decoder_->DecodeSpaced( + int64_t nullCount, + uint8_t* validBits, + int64_t validBitsOffset) { + return currentDecoder_->decodeSpaced( out, - static_cast(batch_size), - static_cast(null_count), - valid_bits, - valid_bits_offset); + static_cast(batchSize), + static_cast(nullCount), + validBits, + validBitsOffset); } - // Read multiple definition levels into preallocated memory + // Read multiple definition levels into preallocated memory. // - // Returns the number of decoded definition levels - int64_t ReadDefinitionLevels(int64_t batch_size, int16_t* levels) { - if (max_def_level_ == 0) { + // Returns the number of decoded definition levels. + int64_t readDefinitionLevels(int64_t batchSize, int16_t* levels) { + if (maxDefLevel_ == 0) { return 0; } - return definition_level_decoder_.Decode( - static_cast(batch_size), levels); + return definitionLevelDecoder_.decode(static_cast(batchSize), levels); } - bool HasNextInternal() { - // Either there is no data page available yet, or the data page has been - // exhausted - if (num_buffered_values_ == 0 || - num_decoded_values_ == num_buffered_values_) { - if (!ReadNewPage() || num_buffered_values_ == 0) { + bool hasNextInternal() { + // Either there is no data page available yet, or the data page has been. + // Exhausted. + if (numBufferedValues_ == 0 || numDecodedValues_ == numBufferedValues_) { + if (!readNewPage() || numBufferedValues_ == 0) { return false; } } return true; } - // Read multiple repetition levels into preallocated memory - // Returns the number of decoded repetition levels - int64_t ReadRepetitionLevels(int64_t batch_size, int16_t* levels) { - if (max_rep_level_ == 0) { + // Read multiple repetition levels into preallocated memory. + // Returns the number of decoded repetition levels. + int64_t readRepetitionLevels(int64_t batchSize, int16_t* levels) { + if (maxRepLevel_ == 0) { return 0; } - return repetition_level_decoder_.Decode( - static_cast(batch_size), levels); + return repetitionLevelDecoder_.decode(static_cast(batchSize), levels); } - // Advance to the next data page - bool ReadNewPage() { + // Advance to the next data page. + bool readNewPage() { // Loop until we find the next data page. while (true) { - current_page_ = pager_->NextPage(); - if (!current_page_) { - // EOS + currentPage_ = pager_->nextPage(); + if (!currentPage_) { + // EOS. return false; } - if (current_page_->type() == PageType::DICTIONARY_PAGE) { - ConfigureDictionary( - static_cast(current_page_.get())); + if (currentPage_->type() == PageType::kDictionaryPage) { + configureDictionary( + static_cast(currentPage_.get())); continue; - } else if (current_page_->type() == PageType::DATA_PAGE) { - const auto page = std::static_pointer_cast(current_page_); - const int64_t levels_byte_size = InitializeLevelDecoders( + } else if (currentPage_->type() == PageType::kDataPage) { + const auto page = std::static_pointer_cast(currentPage_); + const int64_t levelsByteSize = initializeLevelDecoders( *page, - page->repetition_level_encoding(), - page->definition_level_encoding()); - InitializeDataDecoder(*page, levels_byte_size); + page->repetitionLevelEncoding(), + page->definitionLevelEncoding()); + initializeDataDecoder(*page, levelsByteSize); return true; - } else if (current_page_->type() == PageType::DATA_PAGE_V2) { - const auto page = std::static_pointer_cast(current_page_); - int64_t levels_byte_size = InitializeLevelDecodersV2(*page); - InitializeDataDecoder(*page, levels_byte_size); + } else if (currentPage_->type() == PageType::kDataPageV2) { + const auto page = std::static_pointer_cast(currentPage_); + int64_t levelsByteSize = initializeLevelDecodersV2(*page); + initializeDataDecoder(*page, levelsByteSize); return true; } else { - // We don't know what this page type is. We're allowed to skip non-data - // pages. + // We don't know what this page type is. We're allowed to skip non-data. + // Pages. continue; } } return true; } - void ConfigureDictionary(const DictionaryPage* page) { + void configureDictionary(const DictionaryPage* page) { int encoding = static_cast(page->encoding()); - if (page->encoding() == Encoding::PLAIN_DICTIONARY || - page->encoding() == Encoding::PLAIN) { - encoding = static_cast(Encoding::RLE_DICTIONARY); + if (page->encoding() == Encoding::kPlainDictionary || + page->encoding() == Encoding::kPlain) { + encoding = static_cast(Encoding::kRleDictionary); } auto it = decoders_.find(encoding); @@ -841,20 +830,20 @@ class ColumnReaderImplBase { throw ParquetException("Column cannot have more than one dictionary."); } - if (page->encoding() == Encoding::PLAIN_DICTIONARY || - page->encoding() == Encoding::PLAIN) { - auto dictionary = MakeTypedDecoder(Encoding::PLAIN, descr_); - dictionary->SetData(page->num_values(), page->data(), page->size()); + if (page->encoding() == Encoding::kPlainDictionary || + page->encoding() == Encoding::kPlain) { + auto dictionary = makeTypedDecoder(Encoding::kPlain, descr_); + dictionary->setData(page->numValues(), page->data(), page->size()); - // The dictionary is fully decoded during DictionaryDecoder::Init, so the - // DictionaryPage buffer is no longer required after this step + // The dictionary is fully decoded during DictionaryDecoder::Init, so the. + // DictionaryPage buffer is no longer required after this step. // - // TODO(wesm): investigate whether this all-or-nothing decoding of the - // dictionary makes sense and whether performance can be improved + // TODO(wesm): investigate whether this all-or-nothing decoding of the. + // Dictionary makes sense and whether performance can be improved. std::unique_ptr> decoder = - MakeDictDecoder(descr_, pool_); - decoder->SetDict(dictionary.get()); + makeDictDecoder(descr_, pool_); + decoder->setDict(dictionary.get()); decoders_[encoding] = std::unique_ptr( dynamic_cast(decoder.release())); } else { @@ -862,164 +851,164 @@ class ColumnReaderImplBase { "only plain dictionary encoding has been implemented"); } - new_dictionary_ = true; - current_decoder_ = decoders_[encoding].get(); - VELOX_DCHECK(current_decoder_); + newDictionary_ = true; + currentDecoder_ = decoders_[encoding].get(); + VELOX_DCHECK(currentDecoder_); } // Initialize repetition and definition level decoders on the next data page. - // If the data page includes repetition and definition levels, we - // initialize the level decoders and return the number of encoded level bytes. + // If the data page includes repetition and definition levels, we. + // Initialize the level decoders and return the number of encoded level bytes. // The return value helps determine the number of bytes in the encoded data. - int64_t InitializeLevelDecoders( + int64_t initializeLevelDecoders( const DataPage& page, - Encoding::type repetition_level_encoding, - Encoding::type definition_level_encoding) { + Encoding::type repetitionLevelEncoding, + Encoding::type definitionLevelEncoding) { // Read a data page. - num_buffered_values_ = page.num_values(); + numBufferedValues_ = page.numValues(); - // Have not decoded any values from the data page yet - num_decoded_values_ = 0; + // Have not decoded any values from the data page yet. + numDecodedValues_ = 0; const uint8_t* buffer = page.data(); - int32_t levels_byte_size = 0; - int32_t max_size = page.size(); + int32_t levelsByteSize = 0; + int32_t maxSize = page.size(); // Data page Layout: Repetition Levels - Definition Levels - encoded values. // Levels are encoded as rle or bit-packed. - // Init repetition levels - if (max_rep_level_ > 0) { - int32_t rep_levels_bytes = repetition_level_decoder_.SetData( - repetition_level_encoding, - max_rep_level_, - static_cast(num_buffered_values_), + // Init repetition levels. + if (maxRepLevel_ > 0) { + int32_t repLevelsBytes = repetitionLevelDecoder_.setData( + repetitionLevelEncoding, + maxRepLevel_, + static_cast(numBufferedValues_), buffer, - max_size); - buffer += rep_levels_bytes; - levels_byte_size += rep_levels_bytes; - max_size -= rep_levels_bytes; - } - // TODO figure a way to set max_def_level_ to 0 - // if the initial value is invalid - - // Init definition levels - if (max_def_level_ > 0) { - int32_t def_levels_bytes = definition_level_decoder_.SetData( - definition_level_encoding, - max_def_level_, - static_cast(num_buffered_values_), + maxSize); + buffer += repLevelsBytes; + levelsByteSize += repLevelsBytes; + maxSize -= repLevelsBytes; + } + // TODO figure a way to set max_def_level_ to 0. + // If the initial value is invalid. + + // Init definition levels. + if (maxDefLevel_ > 0) { + int32_t defLevelsBytes = definitionLevelDecoder_.setData( + definitionLevelEncoding, + maxDefLevel_, + static_cast(numBufferedValues_), buffer, - max_size); - levels_byte_size += def_levels_bytes; - max_size -= def_levels_bytes; + maxSize); + levelsByteSize += defLevelsBytes; + maxSize -= defLevelsBytes; } - return levels_byte_size; + return levelsByteSize; } - int64_t InitializeLevelDecodersV2(const DataPageV2& page) { + int64_t initializeLevelDecodersV2(const DataPageV2& page) { // Read a data page. - num_buffered_values_ = page.num_values(); + numBufferedValues_ = page.numValues(); - // Have not decoded any values from the data page yet - num_decoded_values_ = 0; + // Have not decoded any values from the data page yet. + numDecodedValues_ = 0; const uint8_t* buffer = page.data(); - const int64_t total_levels_length = - static_cast(page.repetition_levels_byte_length()) + - page.definition_levels_byte_length(); + const int64_t totalLevelsLength = + static_cast(page.repetitionLevelsByteLength()) + + page.definitionLevelsByteLength(); - if (total_levels_length > page.size()) { + if (totalLevelsLength > page.size()) { throw ParquetException( "Data page too small for levels (corrupt header?)"); } - if (max_rep_level_ > 0) { - repetition_level_decoder_.SetDataV2( - page.repetition_levels_byte_length(), - max_rep_level_, - static_cast(num_buffered_values_), + if (maxRepLevel_ > 0) { + repetitionLevelDecoder_.setDataV2( + page.repetitionLevelsByteLength(), + maxRepLevel_, + static_cast(numBufferedValues_), buffer); } - // ARROW-17453: Even if max_rep_level_ is 0, there may still be - // repetition level bytes written and/or reported in the header by + // ARROW-17453: Even if max_rep_level_ is 0, there may still be. + // Repetition level bytes written and/or reported in the header by. // some writers (e.g. Athena) - buffer += page.repetition_levels_byte_length(); + buffer += page.repetitionLevelsByteLength(); - if (max_def_level_ > 0) { - definition_level_decoder_.SetDataV2( - page.definition_levels_byte_length(), - max_def_level_, - static_cast(num_buffered_values_), + if (maxDefLevel_ > 0) { + definitionLevelDecoder_.setDataV2( + page.definitionLevelsByteLength(), + maxDefLevel_, + static_cast(numBufferedValues_), buffer); } - return total_levels_length; + return totalLevelsLength; } - // Get a decoder object for this page or create a new decoder if this is the - // first page with this encoding. - void InitializeDataDecoder(const DataPage& page, int64_t levels_byte_size) { - const uint8_t* buffer = page.data() + levels_byte_size; - const int64_t data_size = page.size() - levels_byte_size; + // Get a decoder object for this page or create a new decoder if this is the. + // First page with this encoding. + void initializeDataDecoder(const DataPage& page, int64_t levelsByteSize) { + const uint8_t* buffer = page.data() + levelsByteSize; + const int64_t dataSize = page.size() - levelsByteSize; - if (data_size < 0) { + if (dataSize < 0) { throw ParquetException("Page smaller than size of encoded levels"); } Encoding::type encoding = page.encoding(); - if (IsDictionaryIndexEncoding(encoding)) { - encoding = Encoding::RLE_DICTIONARY; + if (isDictionaryIndexEncoding(encoding)) { + encoding = Encoding::kRleDictionary; } auto it = decoders_.find(static_cast(encoding)); if (it != decoders_.end()) { VELOX_DCHECK_NOT_NULL(it->second.get()); - current_decoder_ = it->second.get(); + currentDecoder_ = it->second.get(); } else { switch (encoding) { - case Encoding::PLAIN: { - auto decoder = MakeTypedDecoder(Encoding::PLAIN, descr_); - current_decoder_ = decoder.get(); + case Encoding::kPlain: { + auto decoder = makeTypedDecoder(Encoding::kPlain, descr_); + currentDecoder_ = decoder.get(); decoders_[static_cast(encoding)] = std::move(decoder); break; } - case Encoding::BYTE_STREAM_SPLIT: { + case Encoding::kByteStreamSplit: { auto decoder = - MakeTypedDecoder(Encoding::BYTE_STREAM_SPLIT, descr_); - current_decoder_ = decoder.get(); + makeTypedDecoder(Encoding::kByteStreamSplit, descr_); + currentDecoder_ = decoder.get(); decoders_[static_cast(encoding)] = std::move(decoder); break; } - case Encoding::RLE: { - auto decoder = MakeTypedDecoder(Encoding::RLE, descr_); - current_decoder_ = decoder.get(); + case Encoding::kRle: { + auto decoder = makeTypedDecoder(Encoding::kRle, descr_); + currentDecoder_ = decoder.get(); decoders_[static_cast(encoding)] = std::move(decoder); break; } - case Encoding::RLE_DICTIONARY: + case Encoding::kRleDictionary: throw ParquetException("Dictionary page must be before data page."); - case Encoding::DELTA_BINARY_PACKED: { + case Encoding::kDeltaBinaryPacked: { auto decoder = - MakeTypedDecoder(Encoding::DELTA_BINARY_PACKED, descr_); - current_decoder_ = decoder.get(); + makeTypedDecoder(Encoding::kDeltaBinaryPacked, descr_); + currentDecoder_ = decoder.get(); decoders_[static_cast(encoding)] = std::move(decoder); break; } - case Encoding::DELTA_BYTE_ARRAY: { + case Encoding::kDeltaByteArray: { auto decoder = - MakeTypedDecoder(Encoding::DELTA_BYTE_ARRAY, descr_); - current_decoder_ = decoder.get(); + makeTypedDecoder(Encoding::kDeltaByteArray, descr_); + currentDecoder_ = decoder.get(); decoders_[static_cast(encoding)] = std::move(decoder); break; } - case Encoding::DELTA_LENGTH_BYTE_ARRAY: { - auto decoder = MakeTypedDecoder( - Encoding::DELTA_LENGTH_BYTE_ARRAY, descr_); - current_decoder_ = decoder.get(); + case Encoding::kDeltaLengthByteArray: { + auto decoder = + makeTypedDecoder(Encoding::kDeltaLengthByteArray, descr_); + currentDecoder_ = decoder.get(); decoders_[static_cast(encoding)] = std::move(decoder); break; } @@ -1028,73 +1017,73 @@ class ColumnReaderImplBase { throw ParquetException("Unknown encoding type."); } } - current_encoding_ = encoding; - current_decoder_->SetData( - static_cast(num_buffered_values_), + currentEncoding_ = encoding; + currentDecoder_->setData( + static_cast(numBufferedValues_), buffer, - static_cast(data_size)); + static_cast(dataSize)); } - int64_t available_values_current_page() const { - return num_buffered_values_ - num_decoded_values_; + int64_t availableValuesCurrentPage() const { + return numBufferedValues_ - numDecodedValues_; } const ColumnDescriptor* descr_; - const int16_t max_def_level_; - const int16_t max_rep_level_; + const int16_t maxDefLevel_; + const int16_t maxRepLevel_; std::unique_ptr pager_; - std::shared_ptr current_page_; + std::shared_ptr currentPage_; - // Not set if full schema for this field has no optional or repeated elements - LevelDecoder definition_level_decoder_; + // Not set if full schema for this field has no optional or repeated elements. + LevelDecoder definitionLevelDecoder_; // Not set for flat schemas. - LevelDecoder repetition_level_decoder_; + LevelDecoder repetitionLevelDecoder_; - // The total number of values stored in the data page. This is the maximum of - // the number of encoded definition levels or encoded values. For - // non-repeated, required columns, this is equal to the number of encoded - // values. For repeated or optional values, there may be fewer data values - // than levels, and this tells you how many encoded levels there are in that - // case. - int64_t num_buffered_values_; + // The total number of values stored in the data page. This is the maximum of. + // The number of encoded definition levels or encoded values. For. + // Non-repeated, required columns, this is equal to the number of encoded. + // Values. For repeated or optional values, there may be fewer data values. + // Than levels, and this tells you how many encoded levels there are in that. + // Case. + int64_t numBufferedValues_; - // The number of values from the current data page that have been decoded - // into memory - int64_t num_decoded_values_; + // The number of values from the current data page that have been decoded. + // Into memory. + int64_t numDecodedValues_; ::arrow::MemoryPool* pool_; using DecoderType = TypedDecoder; - DecoderType* current_decoder_; - Encoding::type current_encoding_; + DecoderType* currentDecoder_; + Encoding::type currentEncoding_; - /// Flag to signal when a new dictionary has been set, for the benefit of - /// DictionaryRecordReader - bool new_dictionary_; + /// Flag to signal when a new dictionary has been set, for the benefit of. + /// DictionaryRecordReader. + bool newDictionary_; - // The exposed encoding - ExposedEncoding exposed_encoding_ = ExposedEncoding::NO_ENCODING; + // The exposed encoding. + ExposedEncoding exposedEncoding_ = ExposedEncoding::kNoEncoding; - // Map of encoding type to the respective decoder object. For example, a - // column chunk's data pages may include both dictionary-encoded and - // plain-encoded data. + // Map of encoding type to the respective decoder object. For example, a. + // Column chunk's data pages may include both dictionary-encoded and. + // Plain-encoded data. std::unordered_map> decoders_; - void ConsumeBufferedValues(int64_t num_values) { - num_decoded_values_ += num_values; + void consumeBufferedValues(int64_t numValues) { + numDecodedValues_ += numValues; } }; -// ---------------------------------------------------------------------- -// TypedColumnReader implementations +// ----------------------------------------------------------------------. +// TypedColumnReader implementations. template class TypedColumnReaderImpl : public TypedColumnReader, public ColumnReaderImplBase { public: - using T = typename DType::c_type; + using T = typename DType::CType; TypedColumnReaderImpl( const ColumnDescriptor* descr, @@ -1104,114 +1093,114 @@ class TypedColumnReaderImpl : public TypedColumnReader, this->pager_ = std::move(pager); } - bool HasNext() override { - return this->HasNextInternal(); + bool hasNext() override { + return this->hasNextInternal(); } - int64_t ReadBatch( - int64_t batch_size, - int16_t* def_levels, - int16_t* rep_levels, + int64_t readBatch( + int64_t batchSize, + int16_t* defLevels, + int16_t* repLevels, T* values, - int64_t* values_read) override; + int64_t* valuesRead) override; - int64_t ReadBatchSpaced( - int64_t batch_size, - int16_t* def_levels, - int16_t* rep_levels, + int64_t readBatchSpaced( + int64_t batchSize, + int16_t* defLevels, + int16_t* repLevels, T* values, - uint8_t* valid_bits, - int64_t valid_bits_offset, - int64_t* levels_read, - int64_t* values_read, - int64_t* null_count) override; + uint8_t* validBits, + int64_t validBitsOffset, + int64_t* levelsRead, + int64_t* valuesRead, + int64_t* nullCount) override; - int64_t Skip(int64_t num_values_to_skip) override; + int64_t skip(int64_t numValuesToSkip) override; Type::type type() const override { - return this->descr_->physical_type(); + return this->descr_->physicalType(); } const ColumnDescriptor* descr() const override { return this->descr_; } - ExposedEncoding GetExposedEncoding() override { - return this->exposed_encoding_; + ExposedEncoding getExposedEncoding() override { + return this->exposedEncoding_; }; - int64_t ReadBatchWithDictionary( - int64_t batch_size, - int16_t* def_levels, - int16_t* rep_levels, + int64_t readBatchWithDictionary( + int64_t batchSize, + int16_t* defLevels, + int16_t* repLevels, int32_t* indices, - int64_t* indices_read, + int64_t* indicesRead, const T** dict, - int32_t* dict_len) override; + int32_t* dictLen) override; protected: - void SetExposedEncoding(ExposedEncoding encoding) override { - this->exposed_encoding_ = encoding; + void setExposedEncoding(ExposedEncoding encoding) override { + this->exposedEncoding_ = encoding; } - // Allocate enough scratch space to accommodate skipping 16-bit levels or any - // value type. - void InitScratchForSkip(); + // Allocate enough scratch space to accommodate skipping 16-bit levels or any. + // Value type. + void initScratchForSkip(); - // Scratch space for reading and throwing away rep/def levels and values when - // skipping. - std::shared_ptr scratch_for_skip_; + // Scratch space for reading and throwing away rep/def levels and values when. + // Skipping. + std::shared_ptr scratchForSkip_; private: - // Read dictionary indices. Similar to ReadValues but decode data to - // dictionary indices. This function is called only by + // Read dictionary indices. Similar to ReadValues but decode data to. + // Dictionary indices. This function is called only by. // ReadBatchWithDictionary(). - int64_t ReadDictionaryIndices(int64_t indices_to_read, int32_t* indices) { - auto decoder = dynamic_cast*>(this->current_decoder_); - return decoder->DecodeIndices(static_cast(indices_to_read), indices); + int64_t readDictionaryIndices(int64_t indicesToRead, int32_t* indices) { + auto decoder = dynamic_cast*>(this->currentDecoder_); + return decoder->decodeIndices(static_cast(indicesToRead), indices); } - // Get dictionary. The dictionary should have been set by SetDict(). The - // dictionary is owned by the internal decoder and is destroyed when the - // reader is destroyed. This function is called only by + // Get dictionary. The dictionary should have been set by SetDict(). The. + // Dictionary is owned by the internal decoder and is destroyed when the. + // Reader is destroyed. This function is called only by. // ReadBatchWithDictionary() after dictionary is configured. - void GetDictionary(const T** dictionary, int32_t* dictionary_length) { - auto decoder = dynamic_cast*>(this->current_decoder_); - decoder->GetDictionary(dictionary, dictionary_length); - } - - // Read definition and repetition levels. Also return the number of definition - // levels and number of values to read. This function is called before reading - // values. - void ReadLevels( - int64_t batch_size, - int16_t* def_levels, - int16_t* rep_levels, - int64_t* num_def_levels, - int64_t* values_to_read) { - batch_size = std::min( - batch_size, this->num_buffered_values_ - this->num_decoded_values_); - - // If the field is required and non-repeated, there are no definition levels - if (this->max_def_level_ > 0 && def_levels != nullptr) { - *num_def_levels = this->ReadDefinitionLevels(batch_size, def_levels); - // TODO(wesm): this tallying of values-to-decode can be performed with - // better cache-efficiency if fused with the level decoding. - for (int64_t i = 0; i < *num_def_levels; ++i) { - if (def_levels[i] == this->max_def_level_) { - ++(*values_to_read); + void getDictionary(const T** dictionary, int32_t* dictionaryLength) { + auto decoder = dynamic_cast*>(this->currentDecoder_); + decoder->getDictionary(dictionary, dictionaryLength); + } + + // Read definition and repetition levels. Also return the number of + // definition. Levels and number of values to read. This function is called + // before reading. Values. + void readLevels( + int64_t batchSize, + int16_t* defLevels, + int16_t* repLevels, + int64_t* numDefLevels, + int64_t* valuesToRead) { + batchSize = + std::min(batchSize, this->numBufferedValues_ - this->numDecodedValues_); + + // If the field is required and non-repeated, there are no definition + // levels. + if (this->maxDefLevel_ > 0 && defLevels != nullptr) { + *numDefLevels = this->readDefinitionLevels(batchSize, defLevels); + // TODO(wesm): this tallying of values-to-decode can be performed with. + // Better cache-efficiency if fused with the level decoding. + for (int64_t i = 0; i < *numDefLevels; ++i) { + if (defLevels[i] == this->maxDefLevel_) { + ++(*valuesToRead); } } } else { - // Required field, read all values - *values_to_read = batch_size; + // Required field, read all values. + *valuesToRead = batchSize; } - // Not present for non-repeated fields - if (this->max_rep_level_ > 0 && rep_levels != nullptr) { - int64_t num_rep_levels = - this->ReadRepetitionLevels(batch_size, rep_levels); - if (def_levels != nullptr && *num_def_levels != num_rep_levels) { + // Not present for non-repeated fields. + if (this->maxRepLevel_ > 0 && repLevels != nullptr) { + int64_t numRepLevels = this->readRepetitionLevels(batchSize, repLevels); + if (defLevels != nullptr && *numDefLevels != numRepLevels) { throw ParquetException( "Number of decoded rep / def levels did not match"); } @@ -1220,277 +1209,267 @@ class TypedColumnReaderImpl : public TypedColumnReader, }; template -int64_t TypedColumnReaderImpl::ReadBatchWithDictionary( - int64_t batch_size, - int16_t* def_levels, - int16_t* rep_levels, +int64_t TypedColumnReaderImpl::readBatchWithDictionary( + int64_t batchSize, + int16_t* defLevels, + int16_t* repLevels, int32_t* indices, - int64_t* indices_read, + int64_t* indicesRead, const T** dict, - int32_t* dict_len) { - bool has_dict_output = dict != nullptr && dict_len != nullptr; + int32_t* dictLen) { + bool hasDictOutput = dict != nullptr && dictLen != nullptr; // Similar logic as ReadValues to get pages. - if (!HasNext()) { - *indices_read = 0; - if (has_dict_output) { + if (!hasNext()) { + *indicesRead = 0; + if (hasDictOutput) { *dict = nullptr; - *dict_len = 0; + *dictLen = 0; } return 0; } // Verify the current data page is dictionary encoded. - if (this->current_encoding_ != Encoding::RLE_DICTIONARY) { + if (this->currentEncoding_ != Encoding::kRleDictionary) { std::stringstream ss; ss << "Data page is not dictionary encoded. Encoding: " - << EncodingToString(this->current_encoding_); + << encodingToString(this->currentEncoding_); throw ParquetException(ss.str()); } // Get dictionary pointer and length. - if (has_dict_output) { - GetDictionary(dict, dict_len); + if (hasDictOutput) { + getDictionary(dict, dictLen); } // Similar logic as ReadValues to get def levels and rep levels. - int64_t num_def_levels = 0; - int64_t indices_to_read = 0; - ReadLevels( - batch_size, def_levels, rep_levels, &num_def_levels, &indices_to_read); + int64_t numDefLevels = 0; + int64_t indicesToRead = 0; + readLevels(batchSize, defLevels, repLevels, &numDefLevels, &indicesToRead); // Read dictionary indices. - *indices_read = ReadDictionaryIndices(indices_to_read, indices); - int64_t total_indices = std::max(num_def_levels, *indices_read); + *indicesRead = readDictionaryIndices(indicesToRead, indices); + int64_t totalIndices = std::max(numDefLevels, *indicesRead); // Some callers use a batch size of 0 just to get the dictionary. - int64_t expected_values = std::min( - batch_size, this->num_buffered_values_ - this->num_decoded_values_); - if (total_indices == 0 && expected_values > 0) { + int64_t expectedValues = + std::min(batchSize, this->numBufferedValues_ - this->numDecodedValues_); + if (totalIndices == 0 && expectedValues > 0) { std::stringstream ss; - ss << "Read 0 values, expected " << expected_values; - ParquetException::EofException(ss.str()); + ss << "Read 0 values, expected " << expectedValues; + ParquetException::eofException(ss.str()); } - this->ConsumeBufferedValues(total_indices); + this->consumeBufferedValues(totalIndices); - return total_indices; + return totalIndices; } template -int64_t TypedColumnReaderImpl::ReadBatch( - int64_t batch_size, - int16_t* def_levels, - int16_t* rep_levels, +int64_t TypedColumnReaderImpl::readBatch( + int64_t batchSize, + int16_t* defLevels, + int16_t* repLevels, T* values, - int64_t* values_read) { - // HasNext invokes ReadNewPage - if (!HasNext()) { - *values_read = 0; + int64_t* valuesRead) { + // HasNext invokes ReadNewPage. + if (!hasNext()) { + *valuesRead = 0; return 0; } - // TODO(wesm): keep reading data pages until batch_size is reached, or the - // row group is finished - int64_t num_def_levels = 0; - int64_t values_to_read = 0; - ReadLevels( - batch_size, def_levels, rep_levels, &num_def_levels, &values_to_read); + // TODO(wesm): keep reading data pages until batch_size is reached, or the. + // Row group is finished. + int64_t numDefLevels = 0; + int64_t valuesToRead = 0; + readLevels(batchSize, defLevels, repLevels, &numDefLevels, &valuesToRead); - *values_read = this->ReadValues(values_to_read, values); - int64_t total_values = std::max(num_def_levels, *values_read); - int64_t expected_values = std::min( - batch_size, this->num_buffered_values_ - this->num_decoded_values_); - if (total_values == 0 && expected_values > 0) { + *valuesRead = this->readValues(valuesToRead, values); + int64_t totalValues = std::max(numDefLevels, *valuesRead); + int64_t expectedValues = + std::min(batchSize, this->numBufferedValues_ - this->numDecodedValues_); + if (totalValues == 0 && expectedValues > 0) { std::stringstream ss; - ss << "Read 0 values, expected " << expected_values; - ParquetException::EofException(ss.str()); + ss << "Read 0 values, expected " << expectedValues; + ParquetException::eofException(ss.str()); } - this->ConsumeBufferedValues(total_values); + this->consumeBufferedValues(totalValues); - return total_values; + return totalValues; } template -int64_t TypedColumnReaderImpl::ReadBatchSpaced( - int64_t batch_size, - int16_t* def_levels, - int16_t* rep_levels, +int64_t TypedColumnReaderImpl::readBatchSpaced( + int64_t batchSize, + int16_t* defLevels, + int16_t* repLevels, T* values, - uint8_t* valid_bits, - int64_t valid_bits_offset, - int64_t* levels_read, - int64_t* values_read, - int64_t* null_count_out) { - // HasNext invokes ReadNewPage - if (!HasNext()) { - *levels_read = 0; - *values_read = 0; - *null_count_out = 0; + uint8_t* validBits, + int64_t validBitsOffset, + int64_t* levelsRead, + int64_t* valuesRead, + int64_t* nullCountOut) { + // HasNext invokes ReadNewPage. + if (!hasNext()) { + *levelsRead = 0; + *valuesRead = 0; + *nullCountOut = 0; return 0; } - int64_t total_values; - // TODO(wesm): keep reading data pages until batch_size is reached, or the - // row group is finished - batch_size = std::min( - batch_size, this->num_buffered_values_ - this->num_decoded_values_); + int64_t totalValues; + // TODO(wesm): keep reading data pages until batch_size is reached, or the. + // Row group is finished. + batchSize = + std::min(batchSize, this->numBufferedValues_ - this->numDecodedValues_); - // If the field is required and non-repeated, there are no definition levels - if (this->max_def_level_ > 0) { - int64_t num_def_levels = this->ReadDefinitionLevels(batch_size, def_levels); + // If the field is required and non-repeated, there are no definition levels. + if (this->maxDefLevel_ > 0) { + int64_t numDefLevels = this->readDefinitionLevels(batchSize, defLevels); - // Not present for non-repeated fields - if (this->max_rep_level_ > 0) { - int64_t num_rep_levels = - this->ReadRepetitionLevels(batch_size, rep_levels); - if (num_def_levels != num_rep_levels) { + // Not present for non-repeated fields. + if (this->maxRepLevel_ > 0) { + int64_t numRepLevels = this->readRepetitionLevels(batchSize, repLevels); + if (numDefLevels != numRepLevels) { throw ParquetException( "Number of decoded rep / def levels did not match"); } } - const bool has_spaced_values = HasSpacedValues(this->descr_); - int64_t null_count = 0; - if (!has_spaced_values) { - int values_to_read = 0; - for (int64_t i = 0; i < num_def_levels; ++i) { - if (def_levels[i] == this->max_def_level_) { - ++values_to_read; + const bool hasSpacedValuesFlag = hasSpacedValues(this->descr_); + int64_t nullCount = 0; + if (!hasSpacedValuesFlag) { + int valuesToRead = 0; + for (int64_t i = 0; i < numDefLevels; ++i) { + if (defLevels[i] == this->maxDefLevel_) { + ++valuesToRead; } } - total_values = this->ReadValues(values_to_read, values); + totalValues = this->readValues(valuesToRead, values); ::arrow::bit_util::SetBitsTo( - valid_bits, - valid_bits_offset, - /*length=*/total_values, - /*bits_are_set=*/true); - *values_read = total_values; + validBits, validBitsOffset, totalValues, true); + *valuesRead = totalValues; } else { LevelInfo info; - info.repeatedAncestorDefLevel = this->max_def_level_ - 1; - info.defLevel = this->max_def_level_; - info.repLevel = this->max_rep_level_; - ValidityBitmapInputOutput validity_io; - validity_io.valuesReadUpperBound = num_def_levels; - validity_io.validBits = valid_bits; - validity_io.validBitsOffset = valid_bits_offset; - validity_io.nullCount = null_count; - validity_io.valuesRead = *values_read; - - DefLevelsToBitmap(def_levels, num_def_levels, info, &validity_io); - null_count = validity_io.nullCount; - *values_read = validity_io.valuesRead; - - total_values = this->ReadValuesSpaced( - *values_read, + info.repeatedAncestorDefLevel = this->maxDefLevel_ - 1; + info.defLevel = this->maxDefLevel_; + info.repLevel = this->maxRepLevel_; + ValidityBitmapInputOutput validityIo; + validityIo.valuesReadUpperBound = numDefLevels; + validityIo.validBits = validBits; + validityIo.validBitsOffset = validBitsOffset; + validityIo.nullCount = nullCount; + validityIo.valuesRead = *valuesRead; + + DefLevelsToBitmap(defLevels, numDefLevels, info, &validityIo); + nullCount = validityIo.nullCount; + *valuesRead = validityIo.valuesRead; + + totalValues = this->readValuesSpaced( + *valuesRead, values, - static_cast(null_count), - valid_bits, - valid_bits_offset); + static_cast(nullCount), + validBits, + validBitsOffset); } - *levels_read = num_def_levels; - *null_count_out = null_count; + *levelsRead = numDefLevels; + *nullCountOut = nullCount; } else { - // Required field, read all values - total_values = this->ReadValues(batch_size, values); - ::arrow::bit_util::SetBitsTo( - valid_bits, - valid_bits_offset, - /*length=*/total_values, - /*bits_are_set=*/true); - *null_count_out = 0; - *values_read = total_values; - *levels_read = total_values; - } - - this->ConsumeBufferedValues(*levels_read); - return total_values; + // Required field, read all values. + totalValues = this->readValues(batchSize, values); + ::arrow::bit_util::SetBitsTo(validBits, validBitsOffset, totalValues, true); + *nullCountOut = 0; + *valuesRead = totalValues; + *levelsRead = totalValues; + } + + this->consumeBufferedValues(*levelsRead); + return totalValues; } template -void TypedColumnReaderImpl::InitScratchForSkip() { - if (this->scratch_for_skip_ == nullptr) { - int value_size = type_traits::value_byte_size; - this->scratch_for_skip_ = AllocateBuffer( +void TypedColumnReaderImpl::initScratchForSkip() { + if (this->scratchForSkip_ == nullptr) { + int valueSize = TypeTraits::valueByteSize; + this->scratchForSkip_ = allocateBuffer( this->pool_, - kSkipScratchBatchSize * std::max(sizeof(int16_t), value_size)); + kSkipScratchBatchSize * std::max(sizeof(int16_t), valueSize)); } } template -int64_t TypedColumnReaderImpl::Skip(int64_t num_values_to_skip) { - int64_t values_to_skip = num_values_to_skip; +int64_t TypedColumnReaderImpl::skip(int64_t numValuesToSkip) { + int64_t valuesToSkip = numValuesToSkip; // Optimization: Do not call HasNext() when values_to_skip == 0. - while (values_to_skip > 0 && HasNext()) { - // If the number of values to skip is more than the number of undecoded - // values, skip the Page. - const int64_t available_values = this->available_values_current_page(); - if (values_to_skip >= available_values) { - values_to_skip -= available_values; - this->ConsumeBufferedValues(available_values); + while (valuesToSkip > 0 && hasNext()) { + // If the number of values to skip is more than the number of undecoded. + // Values, skip the Page. + const int64_t availableValues = this->availableValuesCurrentPage(); + if (valuesToSkip >= availableValues) { + valuesToSkip -= availableValues; + this->consumeBufferedValues(availableValues); } else { - // We need to read this Page - // Jump to the right offset in the Page - int64_t values_read = 0; - InitScratchForSkip(); - VELOX_DCHECK_NOT_NULL(this->scratch_for_skip_); + // We need to read this Page. + // Jump to the right offset in the Page. + int64_t valuesRead = 0; + initScratchForSkip(); + VELOX_DCHECK_NOT_NULL(this->scratchForSkip_); do { - int64_t batch_size = std::min(kSkipScratchBatchSize, values_to_skip); - values_read = ReadBatch( - static_cast(batch_size), - reinterpret_cast(this->scratch_for_skip_->mutable_data()), - reinterpret_cast(this->scratch_for_skip_->mutable_data()), - reinterpret_cast(this->scratch_for_skip_->mutable_data()), - &values_read); - values_to_skip -= values_read; - } while (values_read > 0 && values_to_skip > 0); - } - } - return num_values_to_skip - values_to_skip; + int64_t batchSize = std::min(kSkipScratchBatchSize, valuesToSkip); + valuesRead = readBatch( + static_cast(batchSize), + reinterpret_cast(this->scratchForSkip_->mutable_data()), + reinterpret_cast(this->scratchForSkip_->mutable_data()), + reinterpret_cast(this->scratchForSkip_->mutable_data()), + &valuesRead); + valuesToSkip -= valuesRead; + } while (valuesRead > 0 && valuesToSkip > 0); + } + } + return numValuesToSkip - valuesToSkip; } } // namespace -// ---------------------------------------------------------------------- -// Dynamic column reader constructor +// ----------------------------------------------------------------------. +// Dynamic column reader constructor. -std::shared_ptr ColumnReader::Make( +std::shared_ptr ColumnReader::make( const ColumnDescriptor* descr, std::unique_ptr pager, MemoryPool* pool) { - switch (descr->physical_type()) { - case Type::BOOLEAN: + switch (descr->physicalType()) { + case Type::kBoolean: return std::make_shared>( descr, std::move(pager), pool); - case Type::INT32: + case Type::kInt32: return std::make_shared>( descr, std::move(pager), pool); - case Type::INT64: + case Type::kInt64: return std::make_shared>( descr, std::move(pager), pool); - case Type::INT96: + case Type::kInt96: return std::make_shared>( descr, std::move(pager), pool); - case Type::FLOAT: + case Type::kFloat: return std::make_shared>( descr, std::move(pager), pool); - case Type::DOUBLE: + case Type::kDouble: return std::make_shared>( descr, std::move(pager), pool); - case Type::BYTE_ARRAY: + case Type::kByteArray: return std::make_shared>( descr, std::move(pager), pool); - case Type::FIXED_LEN_BYTE_ARRAY: + case Type::kFixedLenByteArray: return std::make_shared>( descr, std::move(pager), pool); default: ParquetException::NYI("type reader not implemented"); } - // Unreachable code, but suppress compiler warning + // Unreachable code, but suppress compiler warning. return std::shared_ptr(nullptr); } -// ---------------------------------------------------------------------- -// RecordReader +// ----------------------------------------------------------------------. +// RecordReader. namespace internal { @@ -1500,526 +1479,515 @@ template class TypedRecordReader : public TypedColumnReaderImpl, virtual public RecordReader { public: - using T = typename DType::c_type; + using T = typename DType::CType; using BASE = TypedColumnReaderImpl; TypedRecordReader( const ColumnDescriptor* descr, - LevelInfo leaf_info, + LevelInfo leafInfo, MemoryPool* pool, - bool read_dense_for_nullable) + bool readDenseForNullable) // Pager must be set using SetPageReader. : BASE(descr, /* pager = */ nullptr, pool) { - leaf_info_ = leaf_info; - nullable_values_ = leaf_info_.HasNullableValues(); - at_record_start_ = true; - values_written_ = 0; - null_count_ = 0; - values_capacity_ = 0; - levels_written_ = 0; - levels_position_ = 0; - levels_capacity_ = 0; - read_dense_for_nullable_ = read_dense_for_nullable; - uses_values_ = !(descr->physical_type() == Type::BYTE_ARRAY); - - if (uses_values_) { - values_ = AllocateBuffer(pool); - } - valid_bits_ = AllocateBuffer(pool); - def_levels_ = AllocateBuffer(pool); - rep_levels_ = AllocateBuffer(pool); - TypedRecordReader::Reset(); - } - - // Compute the values capacity in bytes for the given number of elements - int64_t bytes_for_values(int64_t nitems) const { - int64_t type_size = GetTypeByteSize(this->descr_->physical_type()); - int64_t bytes_for_values = -1; - if (MultiplyWithOverflow(nitems, type_size, &bytes_for_values)) { + leafInfo_ = leafInfo; + nullableValues_ = leafInfo_.HasNullableValues(); + atRecordStart_ = true; + valuesWritten_ = 0; + nullCount_ = 0; + valuesCapacity_ = 0; + levelsWritten_ = 0; + levelsPosition_ = 0; + levelsCapacity_ = 0; + readDenseForNullable_ = readDenseForNullable; + usesValues_ = !(descr->physicalType() == Type::kByteArray); + + if (usesValues_) { + values_ = allocateBuffer(pool); + } + validBits_ = allocateBuffer(pool); + defLevels_ = allocateBuffer(pool); + repLevels_ = allocateBuffer(pool); + TypedRecordReader::reset(); + } + + // Compute the values capacity in bytes for the given number of elements. + int64_t bytesForValues(int64_t nitems) const { + int64_t typeSize = getTypeByteSize(this->descr_->physicalType()); + int64_t bytesForValues = -1; + if (MultiplyWithOverflow(nitems, typeSize, &bytesForValues)) { throw ParquetException("Total size of items too large"); } - return bytes_for_values; + return bytesForValues; } - int64_t ReadRecords(int64_t num_records) override { - if (num_records == 0) + int64_t readRecords(int64_t numRecords) override { + if (numRecords == 0) return 0; - // Delimit records, then read values at the end - int64_t records_read = 0; + // Delimit records, then read values at the end. + int64_t recordsRead = 0; - if (has_values_to_process()) { - records_read += ReadRecordData(num_records); + if (hasValuesToProcess()) { + recordsRead += readRecordData(numRecords); } - int64_t level_batch_size = - std::max(kMinLevelBatchSize, num_records); + int64_t levelBatchSize = std::max(kMinLevelBatchSize, numRecords); - // If we are in the middle of a record, we continue until reaching the - // desired number of records or the end of the current record if we've found - // enough records - while (!at_record_start_ || records_read < num_records) { + // If we are in the middle of a record, we continue until reaching the. + // Desired number of records or the end of the current record if we've + // found. Enough records. + while (!atRecordStart_ || recordsRead < numRecords) { // Is there more data to read in this row group? - if (!this->HasNextInternal()) { - if (!at_record_start_) { - // We ended the row group while inside a record that we haven't seen - // the end of yet. So increment the record count for the last record - // in the row group - ++records_read; - at_record_start_ = true; + if (!this->hasNextInternal()) { + if (!atRecordStart_) { + // We ended the row group while inside a record that we haven't seen. + // The end of yet. So increment the record count for the last record. + // In the row group. + ++recordsRead; + atRecordStart_ = true; } break; } - /// We perform multiple batch reads until we either exhaust the row group - /// or observe the desired number of records - int64_t batch_size = - std::min(level_batch_size, this->available_values_current_page()); + /// We perform multiple batch reads until we either exhaust the row group. + /// Or observe the desired number of records. + int64_t batchSize = + std::min(levelBatchSize, this->availableValuesCurrentPage()); - // No more data in column - if (batch_size == 0) { + // No more data in column. + if (batchSize == 0) { break; } - if (this->max_def_level_ > 0) { - ReserveLevels(batch_size); + if (this->maxDefLevel_ > 0) { + reserveLevels(batchSize); - int16_t* def_levels = this->def_levels() + levels_written_; - int16_t* rep_levels = this->rep_levels() + levels_written_; + int16_t* defLevels = this->defLevels() + levelsWritten_; + int16_t* repLevels = this->repLevels() + levelsWritten_; - // Not present for non-repeated fields - int64_t levels_read = 0; - if (this->max_rep_level_ > 0) { - levels_read = this->ReadDefinitionLevels(batch_size, def_levels); - if (this->ReadRepetitionLevels(batch_size, rep_levels) != - levels_read) { + // Not present for non-repeated fields. + int64_t levelsRead = 0; + if (this->maxRepLevel_ > 0) { + levelsRead = this->readDefinitionLevels(batchSize, defLevels); + if (this->readRepetitionLevels(batchSize, repLevels) != levelsRead) { throw ParquetException( "Number of decoded rep / def levels did not match"); } - } else if (this->max_def_level_ > 0) { - levels_read = this->ReadDefinitionLevels(batch_size, def_levels); + } else if (this->maxDefLevel_ > 0) { + levelsRead = this->readDefinitionLevels(batchSize, defLevels); } - // Exhausted column chunk - if (levels_read == 0) { + // Exhausted column chunk. + if (levelsRead == 0) { break; } - levels_written_ += levels_read; - records_read += ReadRecordData(num_records - records_read); + levelsWritten_ += levelsRead; + recordsRead += readRecordData(numRecords - recordsRead); } else { - // No repetition or definition levels - batch_size = std::min(num_records - records_read, batch_size); - records_read += ReadRecordData(batch_size); + // No repetition or definition levels. + batchSize = std::min(numRecords - recordsRead, batchSize); + recordsRead += readRecordData(batchSize); } } - return records_read; + return recordsRead; } // Throw away levels from start_levels_position to levels_position_. - // Will update levels_position_, levels_written_, and levels_capacity_ - // accordingly and move the levels to left to fill in the gap. + // Will update levels_position_, levels_written_, and levels_capacity_. + // Accordingly and move the levels to left to fill in the gap. // It will resize the buffer without releasing the memory allocation. - void ThrowAwayLevels(int64_t start_levels_position) { - VELOX_DCHECK_LE(levels_position_, levels_written_); - VELOX_DCHECK_LE(start_levels_position, levels_position_); - VELOX_DCHECK_GT(this->max_def_level_, 0); - VELOX_DCHECK_NOT_NULL(def_levels_); + void throwAwayLevels(int64_t startLevelsPosition) { + VELOX_DCHECK_LE(levelsPosition_, levelsWritten_); + VELOX_DCHECK_LE(startLevelsPosition, levelsPosition_); + VELOX_DCHECK_GT(this->maxDefLevel_, 0); + VELOX_DCHECK_NOT_NULL(defLevels_); - int64_t gap = levels_position_ - start_levels_position; + int64_t gap = levelsPosition_ - startLevelsPosition; if (gap == 0) return; - int64_t levels_remaining = levels_written_ - gap; + int64_t levelsRemaining = levelsWritten_ - gap; - auto left_shift = [&](::arrow::ResizableBuffer* buffer) { + auto leftShift = [&](::arrow::ResizableBuffer* buffer) { int16_t* data = reinterpret_cast(buffer->mutable_data()); std::copy( - data + levels_position_, - data + levels_written_, - data + start_levels_position); - PARQUET_THROW_NOT_OK(buffer->Resize( - levels_remaining * sizeof(int16_t), - /*shrink_to_fit=*/false)); + data + levelsPosition_, + data + levelsWritten_, + data + startLevelsPosition); + PARQUET_THROW_NOT_OK( + buffer->Resize(levelsRemaining * sizeof(int16_t), false)); }; - left_shift(def_levels_.get()); + leftShift(defLevels_.get()); - if (this->max_rep_level_ > 0) { - VELOX_DCHECK_NOT_NULL(rep_levels_); - left_shift(rep_levels_.get()); + if (this->maxRepLevel_ > 0) { + VELOX_DCHECK_NOT_NULL(repLevels_); + leftShift(repLevels_.get()); } - levels_written_ -= gap; - levels_position_ -= gap; - levels_capacity_ -= gap; + levelsWritten_ -= gap; + levelsPosition_ -= gap; + levelsCapacity_ -= gap; } - // Skip records that we have in our buffer. This function is only for - // non-repeated fields. - int64_t SkipRecordsInBufferNonRepeated(int64_t num_records) { - VELOX_DCHECK_EQ(this->max_rep_level_, 0); - if (!this->has_values_to_process() || num_records == 0) + // Skip records that we have in our buffer. This function is only for. + // Non-repeated fields. + int64_t skipRecordsInBufferNonRepeated(int64_t numRecords) { + VELOX_DCHECK_EQ(this->maxRepLevel_, 0); + if (!this->hasValuesToProcess() || numRecords == 0) return 0; - int64_t remaining_records = levels_written_ - levels_position_; - int64_t skipped_records = std::min(num_records, remaining_records); - int64_t start_levels_position = levels_position_; + int64_t remainingRecords = levelsWritten_ - levelsPosition_; + int64_t skippedRecords = std::min(numRecords, remainingRecords); + int64_t startLevelsPosition = levelsPosition_; // Since there is no repetition, number of levels equals number of records. - levels_position_ += skipped_records; + levelsPosition_ += skippedRecords; - // We skipped the levels by incrementing 'levels_position_'. For values - // we do not have a buffer, so we need to read them and throw them away. + // We skipped the levels by incrementing 'levels_position_'. For values. + // We do not have a buffer, so we need to read them and throw them away. // First we need to figure out how many present/not-null values there are. - std::shared_ptr<::arrow::ResizableBuffer> valid_bits; - valid_bits = AllocateBuffer(this->pool_); - PARQUET_THROW_NOT_OK(valid_bits->Resize( - ::arrow::bit_util::BytesForBits(skipped_records), - /*shrink_to_fit=*/true)); - ValidityBitmapInputOutput validity_io; - validity_io.valuesReadUpperBound = skipped_records; - validity_io.validBits = valid_bits->mutable_data(); - validity_io.validBitsOffset = 0; + std::shared_ptr<::arrow::ResizableBuffer> validBits; + validBits = allocateBuffer(this->pool_); + PARQUET_THROW_NOT_OK(validBits->Resize( + ::arrow::bit_util::BytesForBits(skippedRecords), true)); + ValidityBitmapInputOutput validityIo; + validityIo.valuesReadUpperBound = skippedRecords; + validityIo.validBits = validBits->mutable_data(); + validityIo.validBitsOffset = 0; DefLevelsToBitmap( - def_levels() + start_levels_position, - skipped_records, - this->leaf_info_, - &validity_io); - int64_t values_to_read = validity_io.valuesRead - validity_io.nullCount; - - // Now that we have figured out number of values to read, we do not need - // these levels anymore. We will remove these values from the buffer. - // This requires shifting the levels in the buffer to left. So this will - // update levels_position_ and levels_written_. - ThrowAwayLevels(start_levels_position); - // For values, we do not have them in buffer, so we will read them and - // throw them away. - ReadAndThrowAwayValues(values_to_read); + defLevels() + startLevelsPosition, + skippedRecords, + this->leafInfo_, + &validityIo); + int64_t valuesToRead = validityIo.valuesRead - validityIo.nullCount; + + // Now that we have figured out number of values to read, we do not need. + // These levels anymore. We will remove these values from the buffer. + // This requires shifting the levels in the buffer to left. So this will. + // Update levels_position_ and levels_written_. + throwAwayLevels(startLevelsPosition); + // For values, we do not have them in buffer, so we will read them and. + // Throw them away. + readAndThrowAwayValues(valuesToRead); // Mark the levels as read in the underlying column reader. - this->ConsumeBufferedValues(skipped_records); + this->consumeBufferedValues(skippedRecords); - return skipped_records; + return skippedRecords; } - // Attempts to skip num_records from the buffer. Will throw away levels - // and corresponding values for the records it skipped and consumes them from - // the underlying decoder. Will advance levels_position_ and update - // at_record_start_. + // Attempts to skip num_records from the buffer. Will throw away levels. + // And corresponding values for the records it skipped and consumes them from. + // The underlying decoder. Will advance levels_position_ and update. + // At_record_start_. // Returns how many records were skipped. - int64_t DelimitAndSkipRecordsInBuffer(int64_t num_records) { - if (num_records == 0) + int64_t delimitAndSkipRecordsInBuffer(int64_t numRecords) { + if (numRecords == 0) return 0; - // Look at the buffered levels, delimit them based on - // (rep_level == 0), report back how many records are in there, and - // fill in how many not-null values (def_level == max_def_level_). + // Look at the buffered levels, delimit them based on. + // (Rep_level == 0), report back how many records are in there, and. + // Fill in how many not-null values (def_level == max_def_level_). // DelimitRecords updates levels_position_. - int64_t start_levels_position = levels_position_; - int64_t values_seen = 0; - int64_t skipped_records = DelimitRecords(num_records, &values_seen); - ReadAndThrowAwayValues(values_seen); + int64_t startLevelsPosition = levelsPosition_; + int64_t valuesSeen = 0; + int64_t skippedRecords = delimitRecords(numRecords, &valuesSeen); + readAndThrowAwayValues(valuesSeen); // Mark those levels and values as consumed in the underlying page. - // This must be done before we throw away levels since it updates - // levels_position_ and levels_written_. - this->ConsumeBufferedValues(levels_position_ - start_levels_position); + // This must be done before we throw away levels since it updates. + // Levels_position_ and levels_written_. + this->consumeBufferedValues(levelsPosition_ - startLevelsPosition); // Updated levels_position_ and levels_written_. - ThrowAwayLevels(start_levels_position); - return skipped_records; + throwAwayLevels(startLevelsPosition); + return skippedRecords; } - // Skip records for repeated fields. For repeated fields, we are technically - // reading and throwing away the levels and values since we do not know the - // record boundaries in advance. Keep filling the buffer and skipping until we - // reach the desired number of records or we run out of values in the column - // chunk. Returns number of skipped records. - int64_t SkipRecordsRepeated(int64_t num_records) { - VELOX_DCHECK_GT(this->max_rep_level_, 0); - int64_t skipped_records = 0; + // Skip records for repeated fields. For repeated fields, we are technically. + // Reading and throwing away the levels and values since we do not know the. + // Record boundaries in advance. Keep filling the buffer and skipping until + // we. Reach the desired number of records or we run out of values in the + // column. Chunk. Returns number of skipped records. + int64_t skipRecordsRepeated(int64_t numRecords) { + VELOX_DCHECK_GT(this->maxRepLevel_, 0); + int64_t skippedRecords = 0; // First consume what is in the buffer. - if (levels_position_ < levels_written_) { + if (levelsPosition_ < levelsWritten_) { // This updates at_record_start_. - skipped_records = DelimitAndSkipRecordsInBuffer(num_records); + skippedRecords = delimitAndSkipRecordsInBuffer(numRecords); } - int64_t level_batch_size = - std::max(kMinLevelBatchSize, num_records - skipped_records); + int64_t levelBatchSize = + std::max(kMinLevelBatchSize, numRecords - skippedRecords); - // If 'at_record_start_' is false, but (skipped_records == num_records), it - // means that for the last record that was counted, we have not seen all - // of its values yet. - while (!at_record_start_ || skipped_records < num_records) { + // If 'at_record_start_' is false, but (skipped_records == num_records), it. + // Means that for the last record that was counted, we have not seen all. + // Of its values yet. + while (!atRecordStart_ || skippedRecords < numRecords) { // Is there more data to read in this row group? // HasNextInternal() will advance to the next page if necessary. - if (!this->HasNextInternal()) { - if (!at_record_start_) { - // We ended the row group while inside a record that we haven't seen - // the end of yet. So increment the record count for the last record - // in the row group - ++skipped_records; - at_record_start_ = true; + if (!this->hasNextInternal()) { + if (!atRecordStart_) { + // We ended the row group while inside a record that we haven't seen. + // The end of yet. So increment the record count for the last record. + // In the row group. + ++skippedRecords; + atRecordStart_ = true; } break; } // Read some more levels. - int64_t batch_size = - std::min(level_batch_size, this->available_values_current_page()); + int64_t batchSize = + std::min(levelBatchSize, this->availableValuesCurrentPage()); // No more data in column. This must be an empty page. - // If we had exhausted the last page, HasNextInternal() must have advanced - // to the next page. So there must be available values to process. - if (batch_size == 0) { + // If we had exhausted the last page, HasNextInternal() must have + // advanced. To the next page. So there must be available values to + // process. + if (batchSize == 0) { break; } - // For skipping we will read the levels and append them to the end - // of the def_levels and rep_levels just like for read. - ReserveLevels(batch_size); + // For skipping we will read the levels and append them to the end. + // Of the def_levels and rep_levels just like for read. + reserveLevels(batchSize); - int16_t* def_levels = this->def_levels() + levels_written_; - int16_t* rep_levels = this->rep_levels() + levels_written_; + int16_t* defLevels = this->defLevels() + levelsWritten_; + int16_t* repLevels = this->repLevels() + levelsWritten_; - int64_t levels_read = 0; - levels_read = this->ReadDefinitionLevels(batch_size, def_levels); - if (this->ReadRepetitionLevels(batch_size, rep_levels) != levels_read) { + int64_t levelsRead = 0; + levelsRead = this->readDefinitionLevels(batchSize, defLevels); + if (this->readRepetitionLevels(batchSize, repLevels) != levelsRead) { throw ParquetException( "Number of decoded rep / def levels did not match"); } - levels_written_ += levels_read; - int64_t remaining_records = num_records - skipped_records; + levelsWritten_ += levelsRead; + int64_t remainingRecords = numRecords - skippedRecords; // This updates at_record_start_. - skipped_records += DelimitAndSkipRecordsInBuffer(remaining_records); + skippedRecords += delimitAndSkipRecordsInBuffer(remainingRecords); } - return skipped_records; + return skippedRecords; } // Read 'num_values' values and throw them away. // Throws an error if it could not read 'num_values'. - void ReadAndThrowAwayValues(int64_t num_values) { - int64_t values_left = num_values; - int64_t values_read = 0; - - // Allocate enough scratch space to accommodate 16-bit levels or any - // value type - this->InitScratchForSkip(); - VELOX_DCHECK_NOT_NULL(this->scratch_for_skip_); + void readAndThrowAwayValues(int64_t numValues) { + int64_t valuesLeft = numValues; + int64_t valuesRead = 0; + + // Allocate enough scratch space to accommodate 16-bit levels or any. + // Value type. + this->initScratchForSkip(); + VELOX_DCHECK_NOT_NULL(this->scratchForSkip_); do { - int64_t batch_size = - std::min(kSkipScratchBatchSize, values_left); - values_read = this->ReadValues( - batch_size, - reinterpret_cast(this->scratch_for_skip_->mutable_data())); - values_left -= values_read; - } while (values_read > 0 && values_left > 0); - if (values_left > 0) { + int64_t batchSize = std::min(kSkipScratchBatchSize, valuesLeft); + valuesRead = this->readValues( + batchSize, + reinterpret_cast(this->scratchForSkip_->mutable_data())); + valuesLeft -= valuesRead; + } while (valuesRead > 0 && valuesLeft > 0); + if (valuesLeft > 0) { std::stringstream ss; - ss << "Could not read and throw away " << num_values << " values"; + ss << "Could not read and throw away " << numValues << " values"; throw ParquetException(ss.str()); } } - int64_t SkipRecords(int64_t num_records) override { - if (num_records == 0) + int64_t skipRecords(int64_t numRecords) override { + if (numRecords == 0) return 0; - // Top level required field. Number of records equals to number of levels, - // and there is not read-ahead for levels. - if (this->max_rep_level_ == 0 && this->max_def_level_ == 0) { - return this->Skip(num_records); + // Top level required field. Number of records equals to number of levels,. + // And there is not read-ahead for levels. + if (this->maxRepLevel_ == 0 && this->maxDefLevel_ == 0) { + return this->skip(numRecords); } - int64_t skipped_records = 0; - if (this->max_rep_level_ == 0) { + int64_t skippedRecords = 0; + if (this->maxRepLevel_ == 0) { // Non-repeated optional field. // First consume whatever is in the buffer. - skipped_records = SkipRecordsInBufferNonRepeated(num_records); + skippedRecords = skipRecordsInBufferNonRepeated(numRecords); - VELOX_DCHECK_LE(skipped_records, num_records); + VELOX_DCHECK_LE(skippedRecords, numRecords); - // For records that we have not buffered, we will use the column - // reader's Skip to do the remaining Skip. Since the field is not - // repeated number of levels to skip is the same as number of records - // to skip. - skipped_records += this->Skip(num_records - skipped_records); + // For records that we have not buffered, we will use the column. + // Reader's Skip to do the remaining Skip. Since the field is not. + // Repeated number of levels to skip is the same as number of records. + // To skip. + skippedRecords += this->skip(numRecords - skippedRecords); } else { - skipped_records += this->SkipRecordsRepeated(num_records); + skippedRecords += this->skipRecordsRepeated(numRecords); } - return skipped_records; + return skippedRecords; } - // We may outwardly have the appearance of having exhausted a column chunk - // when in fact we are in the middle of processing the last batch - bool has_values_to_process() const { - return levels_position_ < levels_written_; + // We may outwardly have the appearance of having exhausted a column chunk. + // When in fact we are in the middle of processing the last batch. + bool hasValuesToProcess() const { + return levelsPosition_ < levelsWritten_; } - std::shared_ptr ReleaseValues() override { - if (uses_values_) { + std::shared_ptr releaseValues() override { + if (usesValues_) { auto result = values_; - PARQUET_THROW_NOT_OK(result->Resize( - bytes_for_values(values_written_), /*shrink_to_fit=*/true)); - values_ = AllocateBuffer(this->pool_); - values_capacity_ = 0; + PARQUET_THROW_NOT_OK( + result->Resize(bytesForValues(valuesWritten_), true)); + values_ = allocateBuffer(this->pool_); + valuesCapacity_ = 0; return result; } else { return nullptr; } } - std::shared_ptr ReleaseIsValid() override { - if (nullable_values()) { - auto result = valid_bits_; + std::shared_ptr releaseIsValid() override { + if (nullableValues()) { + auto result = validBits_; PARQUET_THROW_NOT_OK(result->Resize( - ::arrow::bit_util::BytesForBits(values_written_), - /*shrink_to_fit=*/true)); - valid_bits_ = AllocateBuffer(this->pool_); + ::arrow::bit_util::BytesForBits(valuesWritten_), true)); + validBits_ = allocateBuffer(this->pool_); return result; } else { return nullptr; } } - // Process written repetition/definition levels to reach the end of - // records. Only used for repeated fields. - // Process no more levels than necessary to delimit the indicated - // number of logical records. Updates internal state of RecordReader + // Process written repetition/definition levels to reach the end of. + // Records. Only used for repeated fields. + // Process no more levels than necessary to delimit the indicated. + // Number of logical records. Updates internal state of RecordReader. // - // \return Number of records delimited - int64_t DelimitRecords(int64_t num_records, int64_t* values_seen) { - int64_t values_to_read = 0; - int64_t records_read = 0; - - const int16_t* def_levels = this->def_levels() + levels_position_; - const int16_t* rep_levels = this->rep_levels() + levels_position_; - - VELOX_DCHECK_GT(this->max_rep_level_, 0); - - // Count logical records and number of values to read - while (levels_position_ < levels_written_) { - const int16_t rep_level = *rep_levels++; - if (rep_level == 0) { - // If at_record_start_ is true, we are seeing the start of a record - // for the second time, such as after repeated calls to - // DelimitRecords. In this case we must continue until we find - // another record start or exhausting the ColumnChunk - if (!at_record_start_) { + // \return Number of records delimited. + int64_t delimitRecords(int64_t numRecords, int64_t* valuesSeen) { + int64_t valuesToRead = 0; + int64_t recordsRead = 0; + + const int16_t* defLevels = this->defLevels() + levelsPosition_; + const int16_t* repLevels = this->repLevels() + levelsPosition_; + + VELOX_DCHECK_GT(this->maxRepLevel_, 0); + + // Count logical records and number of values to read. + while (levelsPosition_ < levelsWritten_) { + const int16_t repLevel = *repLevels++; + if (repLevel == 0) { + // If at_record_start_ is true, we are seeing the start of a record. + // For the second time, such as after repeated calls to. + // DelimitRecords. In this case we must continue until we find. + // Another record start or exhausting the ColumnChunk. + if (!atRecordStart_) { // We've reached the end of a record; increment the record count. - ++records_read; - if (records_read == num_records) { - // We've found the number of records we were looking for. Set - // at_record_start_ to true and break - at_record_start_ = true; + ++recordsRead; + if (recordsRead == numRecords) { + // We've found the number of records we were looking for. Set. + // At_record_start_ to true and break. + atRecordStart_ = true; break; } } } - // We have decided to consume the level at this position; therefore we - // must advance until we find another record boundary - at_record_start_ = false; + // We have decided to consume the level at this position; therefore we. + // Must advance until we find another record boundary. + atRecordStart_ = false; - const int16_t def_level = *def_levels++; - if (def_level == this->max_def_level_) { - ++values_to_read; + const int16_t defLevel = *defLevels++; + if (defLevel == this->maxDefLevel_) { + ++valuesToRead; } - ++levels_position_; + ++levelsPosition_; } - *values_seen = values_to_read; - return records_read; + *valuesSeen = valuesToRead; + return recordsRead; } - void Reserve(int64_t capacity) override { - ReserveLevels(capacity); - ReserveValues(capacity); + void reserve(int64_t capacity) override { + reserveLevels(capacity); + reserveValues(capacity); } - int64_t UpdateCapacity(int64_t capacity, int64_t size, int64_t extra_size) { - if (extra_size < 0) { + int64_t updateCapacity(int64_t capacity, int64_t size, int64_t extraSize) { + if (extraSize < 0) { throw ParquetException("Negative size (corrupt file?)"); } - int64_t target_size = -1; - if (AddWithOverflow(size, extra_size, &target_size)) { + int64_t targetSize = -1; + if (AddWithOverflow(size, extraSize, &targetSize)) { throw ParquetException("Allocation size too large (corrupt file?)"); } - if (target_size >= (1LL << 62)) { + if (targetSize >= (1LL << 62)) { throw ParquetException("Allocation size too large (corrupt file?)"); } - if (capacity >= target_size) { + if (capacity >= targetSize) { return capacity; } - return ::arrow::bit_util::NextPower2(target_size); + return ::arrow::bit_util::NextPower2(targetSize); } - void ReserveLevels(int64_t extra_levels) { - if (this->max_def_level_ > 0) { - const int64_t new_levels_capacity = - UpdateCapacity(levels_capacity_, levels_written_, extra_levels); - if (new_levels_capacity > levels_capacity_) { + void reserveLevels(int64_t extraLevels) { + if (this->maxDefLevel_ > 0) { + const int64_t newLevelsCapacity = + updateCapacity(levelsCapacity_, levelsWritten_, extraLevels); + if (newLevelsCapacity > levelsCapacity_) { constexpr auto kItemSize = static_cast(sizeof(int16_t)); - int64_t capacity_in_bytes = -1; + int64_t capacityInBytes = -1; if (MultiplyWithOverflow( - new_levels_capacity, kItemSize, &capacity_in_bytes)) { + newLevelsCapacity, kItemSize, &capacityInBytes)) { throw ParquetException("Allocation size too large (corrupt file?)"); } - PARQUET_THROW_NOT_OK( - def_levels_->Resize(capacity_in_bytes, /*shrink_to_fit=*/false)); - if (this->max_rep_level_ > 0) { - PARQUET_THROW_NOT_OK( - rep_levels_->Resize(capacity_in_bytes, /*shrink_to_fit=*/false)); + PARQUET_THROW_NOT_OK(defLevels_->Resize(capacityInBytes, false)); + if (this->maxRepLevel_ > 0) { + PARQUET_THROW_NOT_OK(repLevels_->Resize(capacityInBytes, false)); } - levels_capacity_ = new_levels_capacity; + levelsCapacity_ = newLevelsCapacity; } } } - void ReserveValues(int64_t extra_values) { - const int64_t new_values_capacity = - UpdateCapacity(values_capacity_, values_written_, extra_values); - if (new_values_capacity > values_capacity_) { - // XXX(wesm): A hack to avoid memory allocation when reading directly - // into builder classes - if (uses_values_) { - PARQUET_THROW_NOT_OK(values_->Resize( - bytes_for_values(new_values_capacity), - /*shrink_to_fit=*/false)); - } - values_capacity_ = new_values_capacity; - } - if (nullable_values() && !read_dense_for_nullable_) { - int64_t valid_bytes_new = - ::arrow::bit_util::BytesForBits(values_capacity_); - if (valid_bits_->size() < valid_bytes_new) { - int64_t valid_bytes_old = - ::arrow::bit_util::BytesForBits(values_written_); + void reserveValues(int64_t extraValues) { + const int64_t newValuesCapacity = + updateCapacity(valuesCapacity_, valuesWritten_, extraValues); + if (newValuesCapacity > valuesCapacity_) { + // XXX(wesm): A hack to avoid memory allocation when reading directly. + // Into builder classes. + if (usesValues_) { PARQUET_THROW_NOT_OK( - valid_bits_->Resize(valid_bytes_new, /*shrink_to_fit=*/false)); + values_->Resize(bytesForValues(newValuesCapacity), false)); + } + valuesCapacity_ = newValuesCapacity; + } + if (nullableValues() && !readDenseForNullable_) { + int64_t validBytesNew = ::arrow::bit_util::BytesForBits(valuesCapacity_); + if (validBits_->size() < validBytesNew) { + int64_t validBytesOld = ::arrow::bit_util::BytesForBits(valuesWritten_); + PARQUET_THROW_NOT_OK(validBits_->Resize(validBytesNew, false)); - // Avoid valgrind warnings + // Avoid valgrind warnings. memset( - valid_bits_->mutable_data() + valid_bytes_old, + validBits_->mutable_data() + validBytesOld, 0, - valid_bytes_new - valid_bytes_old); + validBytesNew - validBytesOld); } } } - void Reset() override { - ResetValues(); + void reset() override { + resetValues(); - if (levels_written_ > 0) { + if (levelsWritten_ > 0) { // Throw away levels from 0 to levels_position_. - ThrowAwayLevels(0); + throwAwayLevels(0); } - // Call Finish on the binary builders to reset them + // Call Finish on the binary builders to reset them. } - void SetPageReader(std::unique_ptr reader) override { - at_record_start_ = true; + void setPageReader(std::unique_ptr reader) override { + atRecordStart_ = true; this->pager_ = std::move(reader); - ResetDecoders(); + resetDecoders(); } - bool HasMoreData() const override { + bool hasMoreData() const override { return this->pager_ != nullptr; } @@ -2027,239 +1995,237 @@ class TypedRecordReader : public TypedColumnReaderImpl, return this->descr_; } - // Dictionary decoders must be reset when advancing row groups - void ResetDecoders() { + // Dictionary decoders must be reset when advancing row groups. + void resetDecoders() { this->decoders_.clear(); } - virtual void ReadValuesSpaced(int64_t values_with_nulls, int64_t null_count) { - uint8_t* valid_bits = valid_bits_->mutable_data(); - const int64_t valid_bits_offset = values_written_; - - int64_t num_decoded = this->current_decoder_->DecodeSpaced( - ValuesHead(), - static_cast(values_with_nulls), - static_cast(null_count), - valid_bits, - valid_bits_offset); - CheckNumberDecoded(num_decoded, values_with_nulls); - } - - virtual void ReadValuesDense(int64_t values_to_read) { - int64_t num_decoded = this->current_decoder_->Decode( - ValuesHead(), static_cast(values_to_read)); - CheckNumberDecoded(num_decoded, values_to_read); - } - - // Reads repeated records and returns number of records read. Fills in - // values_to_read and null_count. - int64_t ReadRepeatedRecords( - int64_t num_records, - int64_t* values_to_read, - int64_t* null_count) { - const int64_t start_levels_position = levels_position_; - // Note that repeated records may be required or nullable. If they have - // an optional parent in the path, they will be nullable, otherwise, - // they are required. We use leaf_info_->HasNullableValues() that looks - // at repeated_ancestor_def_level to determine if it is required or - // nullable. Even if they are required, we may have to read ahead and - // delimit the records to get the right number of values and they will - // have associated levels. - int64_t records_read = DelimitRecords(num_records, values_to_read); - if (!nullable_values() || read_dense_for_nullable_) { - ReadValuesDense(*values_to_read); - // null_count is always 0 for required. - VELOX_DCHECK_EQ(*null_count, 0); + virtual void readValuesSpaced(int64_t valuesWithNulls, int64_t nullCount) { + uint8_t* validBits = validBits_->mutable_data(); + const int64_t validBitsOffset = valuesWritten_; + + int64_t numDecoded = this->currentDecoder_->decodeSpaced( + valuesHead(), + static_cast(valuesWithNulls), + static_cast(nullCount), + validBits, + validBitsOffset); + checkNumberDecoded(numDecoded, valuesWithNulls); + } + + virtual void readValuesDense(int64_t valuesToRead) { + int64_t numDecoded = this->currentDecoder_->decode( + valuesHead(), static_cast(valuesToRead)); + checkNumberDecoded(numDecoded, valuesToRead); + } + + // Reads repeated records and returns number of records read. Fills in. + // Values_to_read and null_count. + int64_t readRepeatedRecords( + int64_t numRecords, + int64_t* valuesToRead, + int64_t* nullCount) { + const int64_t startLevelsPosition = levelsPosition_; + // Note that repeated records may be required or nullable. If they have. + // An optional parent in the path, they will be nullable, otherwise,. + // They are required. We use leaf_info_->HasNullableValues() that looks. + // At repeated_ancestor_def_level to determine if it is required or. + // Nullable. Even if they are required, we may have to read ahead and. + // Delimit the records to get the right number of values and they will. + // Have associated levels. + int64_t recordsRead = delimitRecords(numRecords, valuesToRead); + if (!nullableValues() || readDenseForNullable_) { + readValuesDense(*valuesToRead); + // Null_count is always 0 for required. + VELOX_DCHECK_EQ(*nullCount, 0); } else { - ReadSpacedForOptionalOrRepeated( - start_levels_position, values_to_read, null_count); - } - return records_read; - } - - // Reads optional records and returns number of records read. Fills in - // values_to_read and null_count. - int64_t ReadOptionalRecords( - int64_t num_records, - int64_t* values_to_read, - int64_t* null_count) { - const int64_t start_levels_position = levels_position_; - // No repetition levels, skip delimiting logic. Each level represents a - // null or not null entry - int64_t records_read = - std::min(levels_written_ - levels_position_, num_records); + readSpacedForOptionalOrRepeated( + startLevelsPosition, valuesToRead, nullCount); + } + return recordsRead; + } + + // Reads optional records and returns number of records read. Fills in. + // Values_to_read and null_count. + int64_t readOptionalRecords( + int64_t numRecords, + int64_t* valuesToRead, + int64_t* nullCount) { + const int64_t startLevelsPosition = levelsPosition_; + // No repetition levels, skip delimiting logic. Each level represents a. + // Null or not null entry. + int64_t recordsRead = + std::min(levelsWritten_ - levelsPosition_, numRecords); // This is advanced by DelimitRecords for the repeated field case above. - levels_position_ += records_read; + levelsPosition_ += recordsRead; // Optional fields are always nullable. - if (read_dense_for_nullable_) { - ReadDenseForOptional(start_levels_position, values_to_read); - // We don't need to update null_count when reading dense. It should be - // already set to 0. - VELOX_DCHECK_EQ(*null_count, 0); + if (readDenseForNullable_) { + readDenseForOptional(startLevelsPosition, valuesToRead); + // We don't need to update null_count when reading dense. It should be. + // Already set to 0. + VELOX_DCHECK_EQ(*nullCount, 0); } else { - ReadSpacedForOptionalOrRepeated( - start_levels_position, values_to_read, null_count); + readSpacedForOptionalOrRepeated( + startLevelsPosition, valuesToRead, nullCount); } - return records_read; + return recordsRead; } - // Reads required records and returns number of records read. Fills in - // values_to_read. - int64_t ReadRequiredRecords(int64_t num_records, int64_t* values_to_read) { - *values_to_read = num_records; - ReadValuesDense(*values_to_read); - return num_records; + // Reads required records and returns number of records read. Fills in. + // Values_to_read. + int64_t readRequiredRecords(int64_t numRecords, int64_t* valuesToRead) { + *valuesToRead = numRecords; + readValuesDense(*valuesToRead); + return numRecords; } - // Reads dense for optional records. First it figures out how many values to - // read. - void ReadDenseForOptional( - int64_t start_levels_position, - int64_t* values_to_read) { - // levels_position_ must already be incremented based on number of records - // read. - VELOX_DCHECK_GE(levels_position_, start_levels_position); + // Reads dense for optional records. First it figures out how many values to. + // Read. + void readDenseForOptional( + int64_t startLevelsPosition, + int64_t* valuesToRead) { + // Levels_position_ must already be incremented based on number of records. + // Read. + VELOX_DCHECK_GE(levelsPosition_, startLevelsPosition); // When reading dense we need to figure out number of values to read. - const int16_t* def_levels = this->def_levels(); - for (int64_t i = start_levels_position; i < levels_position_; ++i) { - if (def_levels[i] == this->max_def_level_) { - ++(*values_to_read); + const int16_t* defLevels = this->defLevels(); + for (int64_t i = startLevelsPosition; i < levelsPosition_; ++i) { + if (defLevels[i] == this->maxDefLevel_) { + ++(*valuesToRead); } } - ReadValuesDense(*values_to_read); + readValuesDense(*valuesToRead); } // Reads spaced for optional or repeated fields. - void ReadSpacedForOptionalOrRepeated( - int64_t start_levels_position, - int64_t* values_to_read, - int64_t* null_count) { - // levels_position_ must already be incremented based on number of records - // read. - VELOX_DCHECK_GE(levels_position_, start_levels_position); - ValidityBitmapInputOutput validity_io; - validity_io.valuesReadUpperBound = levels_position_ - start_levels_position; - validity_io.validBits = valid_bits_->mutable_data(); - validity_io.validBitsOffset = values_written_; + void readSpacedForOptionalOrRepeated( + int64_t startLevelsPosition, + int64_t* valuesToRead, + int64_t* nullCount) { + // Levels_position_ must already be incremented based on number of records. + // Read. + VELOX_DCHECK_GE(levelsPosition_, startLevelsPosition); + ValidityBitmapInputOutput validityIo; + validityIo.valuesReadUpperBound = levelsPosition_ - startLevelsPosition; + validityIo.validBits = validBits_->mutable_data(); + validityIo.validBitsOffset = valuesWritten_; DefLevelsToBitmap( - def_levels() + start_levels_position, - levels_position_ - start_levels_position, - leaf_info_, - &validity_io); - *values_to_read = validity_io.valuesRead - validity_io.nullCount; - *null_count = validity_io.nullCount; - VELOX_DCHECK_GE(*values_to_read, 0); - VELOX_DCHECK_GE(*null_count, 0); - ReadValuesSpaced(validity_io.valuesRead, *null_count); + defLevels() + startLevelsPosition, + levelsPosition_ - startLevelsPosition, + leafInfo_, + &validityIo); + *valuesToRead = validityIo.valuesRead - validityIo.nullCount; + *nullCount = validityIo.nullCount; + VELOX_DCHECK_GE(*valuesToRead, 0); + VELOX_DCHECK_GE(*nullCount, 0); + readValuesSpaced(validityIo.valuesRead, *nullCount); } // Return number of logical records read. // Updates levels_position_, values_written_, and null_count_. - int64_t ReadRecordData(int64_t num_records) { - // Conservative upper bound - const int64_t possible_num_values = - std::max(num_records, levels_written_ - levels_position_); - ReserveValues(possible_num_values); - - const int64_t start_levels_position = levels_position_; - - // To be updated by the function calls below for each of the repetition - // types. - int64_t records_read = 0; - int64_t values_to_read = 0; - int64_t null_count = 0; - if (this->max_rep_level_ > 0) { + int64_t readRecordData(int64_t numRecords) { + // Conservative upper bound. + const int64_t possibleNumValues = + std::max(numRecords, levelsWritten_ - levelsPosition_); + reserveValues(possibleNumValues); + + const int64_t startLevelsPosition = levelsPosition_; + + // To be updated by the function calls below for each of the repetition. + // Types. + int64_t recordsRead = 0; + int64_t valuesToRead = 0; + int64_t nullCount = 0; + if (this->maxRepLevel_ > 0) { // Repeated fields may be nullable or not. // This call updates levels_position_. - records_read = - ReadRepeatedRecords(num_records, &values_to_read, &null_count); - } else if (this->max_def_level_ > 0) { + recordsRead = readRepeatedRecords(numRecords, &valuesToRead, &nullCount); + } else if (this->maxDefLevel_ > 0) { // Non-repeated optional values are always nullable. // This call updates levels_position_. - VELOX_DCHECK(nullable_values()); - records_read = - ReadOptionalRecords(num_records, &values_to_read, &null_count); + VELOX_DCHECK(nullableValues()); + recordsRead = readOptionalRecords(numRecords, &valuesToRead, &nullCount); } else { - VELOX_DCHECK(!nullable_values()); - records_read = ReadRequiredRecords(num_records, &values_to_read); + VELOX_DCHECK(!nullableValues()); + recordsRead = readRequiredRecords(numRecords, &valuesToRead); // We don't need to update null_count, since it is 0. } - VELOX_DCHECK_GE(records_read, 0); - VELOX_DCHECK_GE(values_to_read, 0); - VELOX_DCHECK_GE(null_count, 0); + VELOX_DCHECK_GE(recordsRead, 0); + VELOX_DCHECK_GE(valuesToRead, 0); + VELOX_DCHECK_GE(nullCount, 0); - if (read_dense_for_nullable_) { - values_written_ += values_to_read; - VELOX_DCHECK_EQ(null_count, 0); + if (readDenseForNullable_) { + valuesWritten_ += valuesToRead; + VELOX_DCHECK_EQ(nullCount, 0); } else { - values_written_ += values_to_read + null_count; - null_count_ += null_count; + valuesWritten_ += valuesToRead + nullCount; + nullCount_ += nullCount; } - // Total values, including null spaces, if any - if (this->max_def_level_ > 0) { - // Optional, repeated, or some mix thereof - this->ConsumeBufferedValues(levels_position_ - start_levels_position); + // Total values, including null spaces, if any. + if (this->maxDefLevel_ > 0) { + // Optional, repeated, or some mix thereof. + this->consumeBufferedValues(levelsPosition_ - startLevelsPosition); } else { - // Flat, non-repeated - this->ConsumeBufferedValues(values_to_read); + // Flat, non-repeated. + this->consumeBufferedValues(valuesToRead); } - return records_read; + return recordsRead; } - void DebugPrintState() override { - const int16_t* def_levels = this->def_levels(); - const int16_t* rep_levels = this->rep_levels(); - const int64_t total_levels_read = levels_position_; + void debugPrintState() override { + const int16_t* defLevels = this->defLevels(); + const int16_t* repLevels = this->repLevels(); + const int64_t totalLevelsRead = levelsPosition_; const T* vals = reinterpret_cast(this->values()); - if (leaf_info_.defLevel > 0) { + if (leafInfo_.defLevel > 0) { std::cout << "def levels: "; - for (int64_t i = 0; i < total_levels_read; ++i) { - std::cout << def_levels[i] << " "; + for (int64_t i = 0; i < totalLevelsRead; ++i) { + std::cout << defLevels[i] << " "; } std::cout << std::endl; } - if (leaf_info_.repLevel > 0) { + if (leafInfo_.repLevel > 0) { std::cout << "rep levels: "; - for (int64_t i = 0; i < total_levels_read; ++i) { - std::cout << rep_levels[i] << " "; + for (int64_t i = 0; i < totalLevelsRead; ++i) { + std::cout << repLevels[i] << " "; } std::cout << std::endl; } std::cout << "values: "; - for (int64_t i = 0; i < this->values_written(); ++i) { + for (int64_t i = 0; i < this->valuesWritten(); ++i) { std::cout << vals[i] << " "; } std::cout << std::endl; } - void ResetValues() { - if (values_written_ > 0) { - // Resize to 0, but do not shrink to fit - if (uses_values_) { - PARQUET_THROW_NOT_OK(values_->Resize(0, /*shrink_to_fit=*/false)); + void resetValues() { + if (valuesWritten_ > 0) { + // Resize to 0, but do not shrink to fit. + if (usesValues_) { + PARQUET_THROW_NOT_OK(values_->Resize(0, false)); } - PARQUET_THROW_NOT_OK(valid_bits_->Resize(0, /*shrink_to_fit=*/false)); - values_written_ = 0; - values_capacity_ = 0; - null_count_ = 0; + PARQUET_THROW_NOT_OK(validBits_->Resize(0, false)); + valuesWritten_ = 0; + valuesCapacity_ = 0; + nullCount_ = 0; } } protected: template - T* ValuesHead() { - return reinterpret_cast(values_->mutable_data()) + values_written_; + T* valuesHead() { + return reinterpret_cast(values_->mutable_data()) + valuesWritten_; } - LevelInfo leaf_info_; + LevelInfo leafInfo_; }; class FLBARecordReader : public TypedRecordReader, @@ -2267,62 +2233,64 @@ class FLBARecordReader : public TypedRecordReader, public: FLBARecordReader( const ColumnDescriptor* descr, - LevelInfo leaf_info, + LevelInfo leafInfo, ::arrow::MemoryPool* pool, - bool read_dense_for_nullable) + bool readDenseForNullable) : TypedRecordReader( descr, - leaf_info, + leafInfo, pool, - read_dense_for_nullable), + readDenseForNullable), builder_(nullptr) { - VELOX_DCHECK_EQ(descr_->physical_type(), Type::FIXED_LEN_BYTE_ARRAY); - int byte_width = descr_->type_length(); + VELOX_DCHECK_EQ( + static_cast(descr_->physicalType()), + static_cast(Type::kFixedLenByteArray)); + int byteWidth = descr_->typeLength(); std::shared_ptr<::arrow::DataType> type = - ::arrow::fixed_size_binary(byte_width); + ::arrow::fixed_size_binary(byteWidth); builder_ = std::make_unique<::arrow::FixedSizeBinaryBuilder>(type, this->pool_); } - ::arrow::ArrayVector GetBuilderChunks() override { + ::arrow::ArrayVector getBuilderChunks() override { std::shared_ptr<::arrow::Array> chunk; PARQUET_THROW_NOT_OK(builder_->Finish(&chunk)); return ::arrow::ArrayVector({chunk}); } - void ReadValuesDense(int64_t values_to_read) override { - auto values = ValuesHead(); - int64_t num_decoded = this->current_decoder_->Decode( - values, static_cast(values_to_read)); - CheckNumberDecoded(num_decoded, values_to_read); + void readValuesDense(int64_t valuesToRead) override { + auto values = valuesHead(); + int64_t numDecoded = + this->currentDecoder_->decode(values, static_cast(valuesToRead)); + checkNumberDecoded(numDecoded, valuesToRead); - for (int64_t i = 0; i < num_decoded; i++) { + for (int64_t i = 0; i < numDecoded; i++) { PARQUET_THROW_NOT_OK(builder_->Append(values[i].ptr)); } - ResetValues(); + resetValues(); } - void ReadValuesSpaced(int64_t values_to_read, int64_t null_count) override { - uint8_t* valid_bits = valid_bits_->mutable_data(); - const int64_t valid_bits_offset = values_written_; - auto values = ValuesHead(); + void readValuesSpaced(int64_t valuesToRead, int64_t nullCount) override { + uint8_t* validBits = validBits_->mutable_data(); + const int64_t validBitsOffset = valuesWritten_; + auto values = valuesHead(); - int64_t num_decoded = this->current_decoder_->DecodeSpaced( + int64_t numDecoded = this->currentDecoder_->decodeSpaced( values, - static_cast(values_to_read), - static_cast(null_count), - valid_bits, - valid_bits_offset); - VELOX_DCHECK_EQ(num_decoded, values_to_read); - - for (int64_t i = 0; i < num_decoded; i++) { - if (::arrow::bit_util::GetBit(valid_bits, valid_bits_offset + i)) { + static_cast(valuesToRead), + static_cast(nullCount), + validBits, + validBitsOffset); + VELOX_DCHECK_EQ(numDecoded, valuesToRead); + + for (int64_t i = 0; i < numDecoded; i++) { + if (::arrow::bit_util::GetBit(validBits, validBitsOffset + i)) { PARQUET_THROW_NOT_OK(builder_->Append(values[i].ptr)); } else { PARQUET_THROW_NOT_OK(builder_->AppendNull()); } } - ResetValues(); + resetValues(); } private: @@ -2334,49 +2302,51 @@ class ByteArrayChunkedRecordReader : public TypedRecordReader, public: ByteArrayChunkedRecordReader( const ColumnDescriptor* descr, - LevelInfo leaf_info, + LevelInfo leafInfo, ::arrow::MemoryPool* pool, - bool read_dense_for_nullable) + bool readDenseForNullable) : TypedRecordReader( descr, - leaf_info, + leafInfo, pool, - read_dense_for_nullable) { - VELOX_DCHECK_EQ(descr_->physical_type(), Type::BYTE_ARRAY); - accumulator_.builder = std::make_unique<::arrow::BinaryBuilder>(pool); + readDenseForNullable) { + VELOX_DCHECK_EQ( + static_cast(descr_->physicalType()), + static_cast(Type::kByteArray)); + accumulator_.Builder = std::make_unique<::arrow::BinaryBuilder>(pool); } - ::arrow::ArrayVector GetBuilderChunks() override { + ::arrow::ArrayVector getBuilderChunks() override { ::arrow::ArrayVector result = accumulator_.chunks; - if (result.size() == 0 || accumulator_.builder->length() > 0) { - std::shared_ptr<::arrow::Array> last_chunk; - PARQUET_THROW_NOT_OK(accumulator_.builder->Finish(&last_chunk)); - result.push_back(std::move(last_chunk)); + if (result.size() == 0 || accumulator_.Builder->length() > 0) { + std::shared_ptr<::arrow::Array> lastChunk; + PARQUET_THROW_NOT_OK(accumulator_.Builder->Finish(&lastChunk)); + result.push_back(std::move(lastChunk)); } accumulator_.chunks = {}; return result; } - void ReadValuesDense(int64_t values_to_read) override { - int64_t num_decoded = this->current_decoder_->DecodeArrowNonNull( - static_cast(values_to_read), &accumulator_); - CheckNumberDecoded(num_decoded, values_to_read); - ResetValues(); + void readValuesDense(int64_t valuesToRead) override { + int64_t numDecoded = this->currentDecoder_->decodeArrowNonNull( + static_cast(valuesToRead), &accumulator_); + checkNumberDecoded(numDecoded, valuesToRead); + resetValues(); } - void ReadValuesSpaced(int64_t values_to_read, int64_t null_count) override { - int64_t num_decoded = this->current_decoder_->DecodeArrow( - static_cast(values_to_read), - static_cast(null_count), - valid_bits_->mutable_data(), - values_written_, + void readValuesSpaced(int64_t valuesToRead, int64_t nullCount) override { + int64_t numDecoded = this->currentDecoder_->decodeArrow( + static_cast(valuesToRead), + static_cast(nullCount), + validBits_->mutable_data(), + valuesWritten_, &accumulator_); - CheckNumberDecoded(num_decoded, values_to_read - null_count); - ResetValues(); + checkNumberDecoded(numDecoded, valuesToRead - nullCount); + resetValues(); } private: - // Helper data structure for accumulating builder chunks + // Helper data structure for accumulating builder chunks. typename EncodingTraits::Accumulator accumulator_; }; @@ -2385,166 +2355,166 @@ class ByteArrayDictionaryRecordReader : public TypedRecordReader, public: ByteArrayDictionaryRecordReader( const ColumnDescriptor* descr, - LevelInfo leaf_info, + LevelInfo leafInfo, ::arrow::MemoryPool* pool, - bool read_dense_for_nullable) + bool readDenseForNullable) : TypedRecordReader( descr, - leaf_info, + leafInfo, pool, - read_dense_for_nullable), + readDenseForNullable), builder_(pool) { - this->read_dictionary_ = true; + this->readDictionary_ = true; } - std::shared_ptr<::arrow::ChunkedArray> GetResult() override { - FlushBuilder(); + std::shared_ptr<::arrow::ChunkedArray> getResult() override { + flushBuilder(); std::vector> result; - std::swap(result, result_chunks_); + std::swap(result, resultChunks_); return std::make_shared<::arrow::ChunkedArray>( std::move(result), builder_.type()); } - void FlushBuilder() { + void flushBuilder() { if (builder_.length() > 0) { std::shared_ptr<::arrow::Array> chunk; PARQUET_THROW_NOT_OK(builder_.Finish(&chunk)); - result_chunks_.emplace_back(std::move(chunk)); + resultChunks_.emplace_back(std::move(chunk)); - // Also clears the dictionary memo table + // Also clears the dictionary memo table. builder_.Reset(); } } - void MaybeWriteNewDictionary() { - if (this->new_dictionary_) { - /// If there is a new dictionary, we may need to flush the builder, then - /// insert the new dictionary values - FlushBuilder(); + void maybeWriteNewDictionary() { + if (this->newDictionary_) { + /// If there is a new dictionary, we may need to flush the builder, then. + /// Insert the new dictionary values. + flushBuilder(); builder_.ResetFull(); - auto decoder = dynamic_cast(this->current_decoder_); - decoder->InsertDictionary(&builder_); - this->new_dictionary_ = false; + auto decoder = dynamic_cast(this->currentDecoder_); + decoder->insertDictionary(&builder_); + this->newDictionary_ = false; } } - void ReadValuesDense(int64_t values_to_read) override { - int64_t num_decoded = 0; - if (current_encoding_ == Encoding::RLE_DICTIONARY) { - MaybeWriteNewDictionary(); - auto decoder = dynamic_cast(this->current_decoder_); - num_decoded = - decoder->DecodeIndices(static_cast(values_to_read), &builder_); + void readValuesDense(int64_t valuesToRead) override { + int64_t numDecoded = 0; + if (currentEncoding_ == Encoding::kRleDictionary) { + maybeWriteNewDictionary(); + auto decoder = dynamic_cast(this->currentDecoder_); + numDecoded = + decoder->decodeIndices(static_cast(valuesToRead), &builder_); } else { - num_decoded = this->current_decoder_->DecodeArrowNonNull( - static_cast(values_to_read), &builder_); + numDecoded = this->currentDecoder_->decodeArrowNonNull( + static_cast(valuesToRead), &builder_); - /// Flush values since they have been copied into the builder - ResetValues(); + /// Flush values since they have been copied into the builder. + resetValues(); } - CheckNumberDecoded(num_decoded, values_to_read); + checkNumberDecoded(numDecoded, valuesToRead); } - void ReadValuesSpaced(int64_t values_to_read, int64_t null_count) override { - VELOX_DEBUG_ONLY int64_t num_decoded = 0; - if (current_encoding_ == Encoding::RLE_DICTIONARY) { - MaybeWriteNewDictionary(); - auto decoder = dynamic_cast(this->current_decoder_); - num_decoded = decoder->DecodeIndicesSpaced( - static_cast(values_to_read), - static_cast(null_count), - valid_bits_->mutable_data(), - values_written_, + void readValuesSpaced(int64_t valuesToRead, int64_t nullCount) override { + VELOX_DEBUG_ONLY int64_t numDecoded = 0; + if (currentEncoding_ == Encoding::kRleDictionary) { + maybeWriteNewDictionary(); + auto decoder = dynamic_cast(this->currentDecoder_); + numDecoded = decoder->decodeIndicesSpaced( + static_cast(valuesToRead), + static_cast(nullCount), + validBits_->mutable_data(), + valuesWritten_, &builder_); } else { - num_decoded = this->current_decoder_->DecodeArrow( - static_cast(values_to_read), - static_cast(null_count), - valid_bits_->mutable_data(), - values_written_, + numDecoded = this->currentDecoder_->decodeArrow( + static_cast(valuesToRead), + static_cast(nullCount), + validBits_->mutable_data(), + valuesWritten_, &builder_); - /// Flush values since they have been copied into the builder - ResetValues(); + /// Flush values since they have been copied into the builder. + resetValues(); } - VELOX_DCHECK_EQ(num_decoded, values_to_read - null_count); + VELOX_DCHECK_EQ(numDecoded, valuesToRead - nullCount); } private: using BinaryDictDecoder = DictDecoder; ::arrow::BinaryDictionary32Builder builder_; - std::vector> result_chunks_; + std::vector> resultChunks_; }; -// TODO(wesm): Implement these to some satisfaction +// TODO(wesm): Implement these to some satisfaction. template <> -void TypedRecordReader::DebugPrintState() {} +void TypedRecordReader::debugPrintState() {} template <> -void TypedRecordReader::DebugPrintState() {} +void TypedRecordReader::debugPrintState() {} template <> -void TypedRecordReader::DebugPrintState() {} +void TypedRecordReader::debugPrintState() {} -std::shared_ptr MakeByteArrayRecordReader( +std::shared_ptr makeByteArrayRecordReader( const ColumnDescriptor* descr, - LevelInfo leaf_info, + LevelInfo leafInfo, ::arrow::MemoryPool* pool, - bool read_dictionary, - bool read_dense_for_nullable) { - if (read_dictionary) { + bool readDictionary, + bool readDenseForNullable) { + if (readDictionary) { return std::make_shared( - descr, leaf_info, pool, read_dense_for_nullable); + descr, leafInfo, pool, readDenseForNullable); } else { return std::make_shared( - descr, leaf_info, pool, read_dense_for_nullable); + descr, leafInfo, pool, readDenseForNullable); } } } // namespace -std::shared_ptr RecordReader::Make( +std::shared_ptr RecordReader::make( const ColumnDescriptor* descr, - LevelInfo leaf_info, + LevelInfo leafInfo, MemoryPool* pool, - bool read_dictionary, - bool read_dense_for_nullable) { - switch (descr->physical_type()) { - case Type::BOOLEAN: + bool readDictionary, + bool readDenseForNullable) { + switch (descr->physicalType()) { + case Type::kBoolean: return std::make_shared>( - descr, leaf_info, pool, read_dense_for_nullable); - case Type::INT32: + descr, leafInfo, pool, readDenseForNullable); + case Type::kInt32: return std::make_shared>( - descr, leaf_info, pool, read_dense_for_nullable); - case Type::INT64: + descr, leafInfo, pool, readDenseForNullable); + case Type::kInt64: return std::make_shared>( - descr, leaf_info, pool, read_dense_for_nullable); - case Type::INT96: + descr, leafInfo, pool, readDenseForNullable); + case Type::kInt96: return std::make_shared>( - descr, leaf_info, pool, read_dense_for_nullable); - case Type::FLOAT: + descr, leafInfo, pool, readDenseForNullable); + case Type::kFloat: return std::make_shared>( - descr, leaf_info, pool, read_dense_for_nullable); - case Type::DOUBLE: + descr, leafInfo, pool, readDenseForNullable); + case Type::kDouble: return std::make_shared>( - descr, leaf_info, pool, read_dense_for_nullable); - case Type::BYTE_ARRAY: { - return MakeByteArrayRecordReader( - descr, leaf_info, pool, read_dictionary, read_dense_for_nullable); + descr, leafInfo, pool, readDenseForNullable); + case Type::kByteArray: { + return makeByteArrayRecordReader( + descr, leafInfo, pool, readDictionary, readDenseForNullable); } - case Type::FIXED_LEN_BYTE_ARRAY: + case Type::kFixedLenByteArray: return std::make_shared( - descr, leaf_info, pool, read_dense_for_nullable); + descr, leafInfo, pool, readDenseForNullable); default: { - // PARQUET-1481: This can occur if the file is corrupt + // PARQUET-1481: This can occur if the file is corrupt. std::stringstream ss; ss << "Invalid physical column type: " - << static_cast(descr->physical_type()); + << static_cast(descr->physicalType()); throw ParquetException(ss.str()); } } - // Unreachable code, but suppress compiler warning + // Unreachable code, but suppress compiler warning. return nullptr; } diff --git a/velox/dwio/parquet/writer/arrow/tests/ColumnReader.h b/velox/dwio/parquet/writer/arrow/tests/ColumnReader.h index 6936aa618e5..07faa918e27 100644 --- a/velox/dwio/parquet/writer/arrow/tests/ColumnReader.h +++ b/velox/dwio/parquet/writer/arrow/tests/ColumnReader.h @@ -51,32 +51,32 @@ namespace facebook::velox::parquet::arrow { class Decryptor; class Page; -// 16 MB is the default maximum page header size +// 16 MB is the default maximum page header size. static constexpr uint32_t kDefaultMaxPageHeaderSize = 16 * 1024 * 1024; -// 16 KB is the default expected page header size +// 16 KB is the default expected page header size. static constexpr uint32_t kDefaultPageHeaderSize = 16 * 1024; -// \brief DataPageStats stores encoded statistics and number of values/rows for -// a page. +// \brief DataPageStats stores encoded statistics and number of values/rows for. +// A page. struct PARQUET_EXPORT DataPageStats { DataPageStats( - const EncodedStatistics* encoded_statistics, - int32_t num_values, - std::optional num_rows) - : encoded_statistics(encoded_statistics), - num_values(num_values), - num_rows(num_rows) {} + const EncodedStatistics* encodedStatistics, + int32_t numValues, + std::optional numRows) + : encodedStatistics(encodedStatistics), + numValues(numValues), + numRows(numRows) {} // Encoded statistics extracted from the page header. // Nullptr if there are no statistics in the page header. - const EncodedStatistics* encoded_statistics; + const EncodedStatistics* encodedStatistics; // Number of values stored in the page. Filled for both V1 and V2 data pages. - // For repeated fields, this can be greater than number of rows. For - // non-repeated fields, this will be the same as the number of rows. - int32_t num_values; + // For repeated fields, this can be greater than number of rows. For. + // Non-repeated fields, this will be the same as the number of rows. + int32_t numValues; // Number of rows stored in the page. std::nullopt if not available. - std::optional num_rows; + std::optional numRows; }; class PARQUET_EXPORT LevelDecoder { @@ -84,456 +84,456 @@ class PARQUET_EXPORT LevelDecoder { LevelDecoder(); ~LevelDecoder(); - // Initialize the LevelDecoder state with new data - // and return the number of bytes consumed - int SetData( + // Initialize the LevelDecoder state with new data. + // And return the number of bytes consumed. + int setData( Encoding::type encoding, - int16_t max_level, - int num_buffered_values, + int16_t maxLevel, + int numBufferedValues, const uint8_t* data, - int32_t data_size); + int32_t dataSize); - void SetDataV2( - int32_t num_bytes, - int16_t max_level, - int num_buffered_values, + void setDataV2( + int32_t numBytes, + int16_t maxLevel, + int numBufferedValues, const uint8_t* data); - // Decodes a batch of levels into an array and returns the number of levels - // decoded - int Decode(int batch_size, int16_t* levels); + // Decodes a batch of levels into an array and returns the number of levels. + // Decoded. + int decode(int batchSize, int16_t* levels); private: - int bit_width_; - int num_values_remaining_; + int bitWidth_; + int numValuesRemaining_; Encoding::type encoding_; - std::unique_ptr rle_decoder_; - std::unique_ptr bit_packed_decoder_; - int16_t max_level_; + std::unique_ptr rleDecoder_; + std::unique_ptr bitPackedDecoder_; + int16_t maxLevel_; }; struct CryptoContext { CryptoContext( - bool start_with_dictionary_page, - int16_t rg_ordinal, - int16_t col_ordinal, + bool startWithDictionaryPage, + int16_t rgOrdinal, + int16_t colOrdinal, std::shared_ptr meta, std::shared_ptr data) - : start_decrypt_with_dictionary_page(start_with_dictionary_page), - row_group_ordinal(rg_ordinal), - column_ordinal(col_ordinal), - meta_decryptor(std::move(meta)), - data_decryptor(std::move(data)) {} + : startDecryptWithDictionaryPage(startWithDictionaryPage), + rowGroupOrdinal(rgOrdinal), + columnOrdinal(colOrdinal), + metaDecryptor(std::move(meta)), + dataDecryptor(std::move(data)) {} CryptoContext() {} - bool start_decrypt_with_dictionary_page = false; - int16_t row_group_ordinal = -1; - int16_t column_ordinal = -1; - std::shared_ptr meta_decryptor; - std::shared_ptr data_decryptor; + bool startDecryptWithDictionaryPage = false; + int16_t rowGroupOrdinal = -1; + int16_t columnOrdinal = -1; + std::shared_ptr metaDecryptor; + std::shared_ptr dataDecryptor; }; -// Abstract page iterator interface. This way, we can feed column pages to the -// ColumnReader through whatever mechanism we choose +// Abstract page iterator interface. This way, we can feed column pages to the. +// ColumnReader through whatever mechanism we choose. class PARQUET_EXPORT PageReader { using DataPageFilter = std::function; public: virtual ~PageReader() = default; - static std::unique_ptr Open( + static std::unique_ptr open( std::shared_ptr stream, - int64_t total_num_values, - Compression::type codec, - bool always_compressed = false, + int64_t totalNumValues, + Compression::type Codec, + bool alwaysCompressed = false, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool(), const CryptoContext* ctx = NULLPTR); - static std::unique_ptr Open( + static std::unique_ptr open( std::shared_ptr stream, - int64_t total_num_values, - Compression::type codec, + int64_t totalNumValues, + Compression::type Codec, const ReaderProperties& properties, - bool always_compressed = false, + bool alwaysCompressed = false, const CryptoContext* ctx = NULLPTR); - // If data_page_filter is present (not null), NextPage() will call the - // callback function exactly once per page in the order the pages appear in - // the column. If the callback function returns true the page will be - // skipped. The callback will be called only if the page type is DATA_PAGE or + // If data_page_filter is present (not null), NextPage() will call the. + // Callback function exactly once per page in the order the pages appear in. + // The column. If the callback function returns true the page will be. + // Skipped. The callback will be called only if the page type is DATA_PAGE or. // DATA_PAGE_V2. Dictionary pages will not be skipped. - // Caller is responsible for checking that statistics are correct using + // Caller is responsible for checking that statistics are correct using. // ApplicationVersion::HasCorrectStatistics(). - // \note API EXPERIMENTAL - void set_data_page_filter(DataPageFilter data_page_filter) { - data_page_filter_ = std::move(data_page_filter); + // \note API EXPERIMENTAL. + void setDataPageFilter(DataPageFilter dataPageFilter) { + dataPageFilter_ = std::move(dataPageFilter); } - // @returns: shared_ptr(nullptr) on EOS, std::shared_ptr - // containing new Page otherwise + // @returns: shared_ptr(nullptr) on EOS, std::shared_ptr. + // Containing new Page otherwise. // - // The returned Page may contain references that aren't guaranteed to live - // beyond the next call to NextPage(). - virtual std::shared_ptr NextPage() = 0; + // The returned Page may contain references that aren't guaranteed to live. + // Beyond the next call to NextPage(). + virtual std::shared_ptr nextPage() = 0; - virtual void set_max_page_header_size(uint32_t size) = 0; + virtual void setMaxPageHeaderSize(uint32_t size) = 0; protected: // Callback that decides if we should skip a page or not. - DataPageFilter data_page_filter_; + DataPageFilter dataPageFilter_; }; class PARQUET_EXPORT ColumnReader { public: virtual ~ColumnReader() = default; - static std::shared_ptr Make( + static std::shared_ptr make( const ColumnDescriptor* descr, std::unique_ptr pager, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); // Returns true if there are still values in this column. - virtual bool HasNext() = 0; + virtual bool hasNext() = 0; virtual Type::type type() const = 0; virtual const ColumnDescriptor* descr() const = 0; - // Get the encoding that can be exposed by this reader. If it returns - // dictionary encoding, then ReadBatchWithDictionary can be used to read data. + // Get the encoding that can be exposed by this reader. If it returns. + // Dictionary encoding, then ReadBatchWithDictionary can be used to read data. // - // \note API EXPERIMENTAL - virtual ExposedEncoding GetExposedEncoding() = 0; + // \note API EXPERIMENTAL. + virtual ExposedEncoding getExposedEncoding() = 0; protected: friend class RowGroupReader; // Set the encoding that can be exposed by this reader. // - // \note API EXPERIMENTAL - virtual void SetExposedEncoding(ExposedEncoding encoding) = 0; + // \note API EXPERIMENTAL. + virtual void setExposedEncoding(ExposedEncoding encoding) = 0; }; // API to read values from a single column. This is a main client facing API. template class TypedColumnReader : public ColumnReader { public: - typedef typename DType::c_type T; + typedef typename DType::CType T; - // Read a batch of repetition levels, definition levels, and values from the - // column. + // Read a batch of repetition levels, definition levels, and values from the. + // Column. // - // Since null values are not stored in the values, the number of values read - // may be less than the number of repetition and definition levels. With - // nested data this is almost certainly true. + // Since null values are not stored in the values, the number of values read. + // May be less than the number of repetition and definition levels. With. + // Nested data this is almost certainly true. // // Set def_levels or rep_levels to nullptr if you want to skip reading them. - // This is only safe if you know through some other source that there are no - // undefined values. + // This is only safe if you know through some other source that there are no. + // Undefined values. // - // To fully exhaust a row group, you must read batches until the number of - // values read reaches the number of stored values according to the metadata. + // To fully exhaust a row group, you must read batches until the number of. + // Values read reaches the number of stored values according to the metadata. // - // This API is the same for both V1 and V2 of the DataPage + // This API is the same for both V1 and V2 of the DataPage. // - // @returns: actual number of levels read (see values_read for number of + // @returns: actual number of levels read (see values_read for number of. // values read) - virtual int64_t ReadBatch( - int64_t batch_size, - int16_t* def_levels, - int16_t* rep_levels, + virtual int64_t readBatch( + int64_t batchSize, + int16_t* defLevels, + int16_t* repLevels, T* values, - int64_t* values_read) = 0; + int64_t* valuesRead) = 0; - /// Read a batch of repetition levels, definition levels, and values from the - /// column and leave spaces for null entries on the lowest level in the values - /// buffer. + /// Read a batch of repetition levels, definition levels, and values from the. + /// Column and leave spaces for null entries on the lowest level in the + /// values. Buffer. /// - /// In comparison to ReadBatch the length of repetition and definition levels - /// is the same as of the number of values read for max_definition_level == 1. - /// In the case of max_definition_level > 1, the repetition and definition - /// levels are larger than the values but the values include the null entries - /// with definition_level == (max_definition_level - 1). + /// In comparison to ReadBatch the length of repetition and definition levels. + /// Is the same as of the number of values read for max_definition_level == 1. + /// In the case of max_definition_level > 1, the repetition and definition. + /// Levels are larger than the values but the values include the null entries. + /// With definition_level == (max_definition_level - 1). /// - /// To fully exhaust a row group, you must read batches until the number of - /// values read reaches the number of stored values according to the metadata. + /// To fully exhaust a row group, you must read batches until the number of. + /// Values read reaches the number of stored values according to the metadata. /// - /// @param batch_size the number of levels to read - /// @param[out] def_levels The Parquet definition levels, output has - /// the length levels_read. - /// @param[out] rep_levels The Parquet repetition levels, output has - /// the length levels_read. - /// @param[out] values The values in the lowest nested level including - /// spacing for nulls on the lowest levels; output has the length - /// values_read. - /// @param[out] valid_bits Memory allocated for a bitmap that indicates if - /// the row is null or on the maximum definition level. For performance - /// reasons the underlying buffer should be able to store 1 bit more than - /// required. If this requires an additional byte, this byte is only read - /// but never written to. - /// @param valid_bits_offset The offset in bits of the valid_bits where the - /// first relevant bit resides. - /// @param[out] levels_read The number of repetition/definition levels that - /// were read. - /// @param[out] values_read The number of values read, this includes all - /// non-null entries as well as all null-entries on the lowest level + /// @param batch_size the number of levels to read. + /// @param[out] def_levels The Parquet definition levels, output has. + /// The length levels_read. + /// @param[out] rep_levels The Parquet repetition levels, output has. + /// The length levels_read. + /// @param[out] values The values in the lowest nested level including. + /// Spacing for nulls on the lowest levels; output has the length. + /// Values_read. + /// @param[out] valid_bits Memory allocated for a bitmap that indicates if. + /// The row is null or on the maximum definition level. For performance. + /// Reasons the underlying buffer should be able to store 1 bit more than. + /// Required. If this requires an additional byte, this byte is only read. + /// But never written to. + /// @param valid_bits_offset The offset in bits of the valid_bits where the. + /// First relevant bit resides. + /// @param[out] levels_read The number of repetition/definition levels that. + /// Were read. + /// @param[out] values_read The number of values read, this includes all. + /// Non-null entries as well as all null-entries on the lowest level. /// (i.e. definition_level == max_definition_level - 1) /// @param[out] null_count The number of nulls on the lowest levels. /// (i.e. (values_read - null_count) is total number of non-null entries) /// - /// \deprecated Since 4.0.0 + /// \deprecated Since 4.0.0. ARROW_DEPRECATED( "Doesn't handle nesting correctly and unused outside of unit tests.") - virtual int64_t ReadBatchSpaced( - int64_t batch_size, - int16_t* def_levels, - int16_t* rep_levels, + virtual int64_t readBatchSpaced( + int64_t batchSize, + int16_t* defLevels, + int16_t* repLevels, T* values, - uint8_t* valid_bits, - int64_t valid_bits_offset, - int64_t* levels_read, - int64_t* values_read, - int64_t* null_count) = 0; - - // Skip reading values. This method will work for both repeated and - // non-repeated fields. Note that this method is skipping values and not - // records. This distinction is important for repeated fields, meaning that - // we are not skipping over the values to the next record. For example, - // consider the following two consecutive records containing one repeated - // field: - // {[1, 2, 3]}, {[4, 5]}. If we Skip(2), our next read value will be 3, which - // is inside the first record. + uint8_t* validBits, + int64_t validBitsOffset, + int64_t* levelsRead, + int64_t* valuesRead, + int64_t* nullCount) = 0; + + // Skip reading values. This method will work for both repeated and. + // Non-repeated fields. Note that this method is skipping values and not. + // Records. This distinction is important for repeated fields, meaning that. + // We are not skipping over the values to the next record. For example,. + // Consider the following two consecutive records containing one repeated. + // Field: + // {[1, 2, 3]}, {[4, 5]}. If we Skip(2), our next read value will be 3, which. + // Is inside the first record. // Returns the number of values skipped. - virtual int64_t Skip(int64_t num_values_to_skip) = 0; + virtual int64_t skip(int64_t numValuesToSkip) = 0; - // Read a batch of repetition levels, definition levels, and indices from the - // column. And read the dictionary if a dictionary page is encountered during - // reading pages. This API is similar to ReadBatch(), with ability to read - // dictionary and indices. It is only valid to call this method when the - // reader can expose dictionary encoding. (i.e., the reader's + // Read a batch of repetition levels, definition levels, and indices from the. + // Column. And read the dictionary if a dictionary page is encountered during. + // Reading pages. This API is similar to ReadBatch(), with ability to read. + // Dictionary and indices. It is only valid to call this method when the. + // Reader can expose dictionary encoding. (i.e., the reader's. // GetExposedEncoding() returns DICTIONARY). // - // The dictionary is read along with the data page. When there's no data page, - // the dictionary won't be returned. + // The dictionary is read along with the data page. When there's no data + // page,. The dictionary won't be returned. // - // @param batch_size The batch size to read + // @param batch_size The batch size to read. // @param[out] def_levels The Parquet definition levels. // @param[out] rep_levels The Parquet repetition levels. // @param[out] indices The dictionary indices. // @param[out] indices_read The number of indices read. - // @param[out] dict The pointer to dictionary values. It will return nullptr - // if there's no data page. Each column chunk only has one dictionary page. - // The dictionary is owned by the reader, so the caller is responsible for - // copying the dictionary values before the reader gets destroyed. - // @param[out] dict_len The dictionary length. It will return 0 if there's no - // data page. - // @returns: actual number of levels read (see indices_read for number of - // indices read + // @param[out] dict The pointer to dictionary values. It will return nullptr. + // If there's no data page. Each column chunk only has one dictionary page. + // The dictionary is owned by the reader, so the caller is responsible for. + // Copying the dictionary values before the reader gets destroyed. + // @param[out] dict_len The dictionary length. It will return 0 if there's no. + // Data page. + // @returns: actual number of levels read (see indices_read for number of. + // Indices read. // - // \note API EXPERIMENTAL - virtual int64_t ReadBatchWithDictionary( - int64_t batch_size, - int16_t* def_levels, - int16_t* rep_levels, + // \note API EXPERIMENTAL. + virtual int64_t readBatchWithDictionary( + int64_t batchSize, + int16_t* defLevels, + int16_t* repLevels, int32_t* indices, - int64_t* indices_read, + int64_t* indicesRead, const T** dict, - int32_t* dict_len) = 0; + int32_t* dictLen) = 0; }; namespace internal { -/// \brief Stateful column reader that delimits semantic records for both flat -/// and nested columns +/// \brief Stateful column reader that delimits semantic records for both flat. +/// And nested columns. /// -/// \note API EXPERIMENTAL -/// \since 1.3.0 +/// \note API EXPERIMENTAL. +/// \since 1.3.0. class PARQUET_EXPORT RecordReader { public: /// \brief Creates a record reader. - /// @param descr Column descriptor - /// @param leaf_info Level info, used to determine if a column is nullable or - /// not - /// @param pool Memory pool to use for buffering values and rep/def levels - /// @param read_dictionary True if reading directly as Arrow - /// dictionary-encoded - /// @param read_dense_for_nullable True if reading dense and not leaving space - /// for null values - static std::shared_ptr Make( + /// @param descr Column descriptor. + /// @param leaf_info Level info, used to determine if a column is nullable or. + /// Not. + /// @param pool Memory pool to use for buffering values and rep/def levels. + /// @param read_dictionary True if reading directly as Arrow. + /// Dictionary-encoded. + /// @param read_dense_for_nullable True if reading dense and not leaving + /// space. For null values. + static std::shared_ptr make( const ColumnDescriptor* descr, - LevelInfo leaf_info, + LevelInfo leafInfo, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool(), - bool read_dictionary = false, - bool read_dense_for_nullable = false); + bool readDictionary = false, + bool readDenseForNullable = false); virtual ~RecordReader() = default; - /// \brief Attempt to read indicated number of records from column chunk - /// Note that for repeated fields, a record may have more than one value - /// and all of them are read. If read_dense_for_nullable() it will - /// not leave any space for null values. Otherwise, it will read spaced. - /// \return number of records read - virtual int64_t ReadRecords(int64_t num_records) = 0; + /// \brief Attempt to read indicated number of records from column chunk. + /// Note that for repeated fields, a record may have more than one value. + /// And all of them are read. If read_dense_for_nullable() it will. + /// Not leave any space for null values. Otherwise, it will read spaced. + /// \return number of records read. + virtual int64_t readRecords(int64_t numRecords) = 0; /// \brief Attempt to skip indicated number of records from column chunk. - /// Note that for repeated fields, a record may have more than one value - /// and all of them are skipped. - /// \return number of records skipped - virtual int64_t SkipRecords(int64_t num_records) = 0; + /// Note that for repeated fields, a record may have more than one value. + /// And all of them are skipped. + /// \return number of records skipped. + virtual int64_t skipRecords(int64_t numRecords) = 0; - /// \brief Pre-allocate space for data. Results in better flat read - /// performance - virtual void Reserve(int64_t num_values) = 0; + /// \brief Pre-allocate space for data. Results in better flat read. + /// Performance. + virtual void reserve(int64_t numValues) = 0; - /// \brief Clear consumed values and repetition/definition levels as the - /// result of calling ReadRecords + /// \brief Clear consumed values and repetition/definition levels as the. + /// Result of calling ReadRecords. /// For FLBA and ByteArray types, call GetBuilderChunks() to reset them. - virtual void Reset() = 0; + virtual void reset() = 0; - /// \brief Transfer filled values buffer to caller. A new one will be - /// allocated in subsequent ReadRecords calls - virtual std::shared_ptr ReleaseValues() = 0; + /// \brief Transfer filled values buffer to caller. A new one will be. + /// Allocated in subsequent ReadRecords calls. + virtual std::shared_ptr releaseValues() = 0; - /// \brief Transfer filled validity bitmap buffer to caller. A new one will - /// be allocated in subsequent ReadRecords calls - virtual std::shared_ptr ReleaseIsValid() = 0; + /// \brief Transfer filled validity bitmap buffer to caller. A new one will. + /// Be allocated in subsequent ReadRecords calls. + virtual std::shared_ptr releaseIsValid() = 0; - /// \brief Return true if the record reader has more internal data yet to - /// process - virtual bool HasMoreData() const = 0; + /// \brief Return true if the record reader has more internal data yet to. + /// Process. + virtual bool hasMoreData() const = 0; - /// \brief Advance record reader to the next row group. Must be set before - /// any records could be read/skipped. - /// \param[in] reader obtained from RowGroupReader::GetColumnPageReader - virtual void SetPageReader(std::unique_ptr reader) = 0; + /// \brief Advance record reader to the next row group. Must be set before. + /// Any records could be read/skipped. + /// \param[in] reader obtained from RowGroupReader::GetColumnPageReader. + virtual void setPageReader(std::unique_ptr reader) = 0; /// \brief Returns the underlying column reader's descriptor. virtual const ColumnDescriptor* descr() const = 0; - virtual void DebugPrintState() = 0; + virtual void debugPrintState() = 0; - /// \brief Decoded definition levels - int16_t* def_levels() const { - return reinterpret_cast(def_levels_->mutable_data()); + /// \brief Decoded definition levels. + int16_t* defLevels() const { + return reinterpret_cast(defLevels_->mutable_data()); } - /// \brief Decoded repetition levels - int16_t* rep_levels() const { - return reinterpret_cast(rep_levels_->mutable_data()); + /// \brief Decoded repetition levels. + int16_t* repLevels() const { + return reinterpret_cast(repLevels_->mutable_data()); } - /// \brief Decoded values, including nulls, if any - /// FLBA and ByteArray types do not use this array and read into their own - /// builders. + /// \brief Decoded values, including nulls, if any. + /// FLBA and ByteArray types do not use this array and read into their own. + /// Builders. uint8_t* values() const { return values_->mutable_data(); } /// \brief Number of values written, including space left for nulls if any. - /// If this Reader was constructed with read_dense_for_nullable(), there is no - /// space for nulls and null_count() will be 0. There is no - /// read-ahead/buffering for values. For FLBA and ByteArray types this value - /// reflects the values written with the last ReadRecords call since those - /// readers will reset the values after each call. - int64_t values_written() const { - return values_written_; + /// If this Reader was constructed with read_dense_for_nullable(), there is + /// no. Space for nulls and null_count() will be 0. There is no. + /// Read-ahead/buffering for values. For FLBA and ByteArray types this value. + /// Reflects the values written with the last ReadRecords call since those. + /// Readers will reset the values after each call. + int64_t valuesWritten() const { + return valuesWritten_; } - /// \brief Number of definition / repetition levels (from those that have - /// been decoded) that have been consumed inside the reader. - int64_t levels_position() const { - return levels_position_; + /// \brief Number of definition / repetition levels (from those that have. + /// Been decoded) that have been consumed inside the reader. + int64_t levelsPosition() const { + return levelsPosition_; } - /// \brief Number of definition / repetition levels that have been written - /// internally in the reader. This may be larger than values_written() because - /// for repeated fields we need to look at the levels in advance to figure out - /// the record boundaries. - int64_t levels_written() const { - return levels_written_; + /// \brief Number of definition / repetition levels that have been written. + /// Internally in the reader. This may be larger than values_written() + /// because. For repeated fields we need to look at the levels in advance to + /// figure out. The record boundaries. + int64_t levelsWritten() const { + return levelsWritten_; } - /// \brief Number of nulls in the leaf that we have read so far into the - /// values vector. This is only valid when !read_dense_for_nullable(). When - /// read_dense_for_nullable() it will always be 0. - int64_t null_count() const { - return null_count_; + /// \brief Number of nulls in the leaf that we have read so far into the. + /// Values vector. This is only valid when !read_dense_for_nullable(). When. + /// Read_dense_for_nullable() it will always be 0. + int64_t nullCount() const { + return nullCount_; } - /// \brief True if the leaf values are nullable - bool nullable_values() const { - return nullable_values_; + /// \brief True if the leaf values are nullable. + bool nullableValues() const { + return nullableValues_; } - /// \brief True if reading directly as Arrow dictionary-encoded - bool read_dictionary() const { - return read_dictionary_; + /// \brief True if reading directly as Arrow dictionary-encoded. + bool readDictionary() const { + return readDictionary_; } /// \brief True if reading dense for nullable columns. - bool read_dense_for_nullable() const { - return read_dense_for_nullable_; + bool readDenseForNullable() const { + return readDenseForNullable_; } protected: - /// \brief Indicates if we can have nullable values. Note that repeated fields - /// may or may not be nullable. - bool nullable_values_; + /// \brief Indicates if we can have nullable values. Note that repeated + /// fields. May or may not be nullable. + bool nullableValues_; - bool at_record_start_; - int64_t records_read_; + bool atRecordStart_; + int64_t recordsRead_; - /// \brief Stores values. These values are populated based on each ReadRecords - /// call. No extra values are buffered for the next call. SkipRecords will not - /// add any value to this buffer. + /// \brief Stores values. These values are populated based on each + /// ReadRecords. Call. No extra values are buffered for the next call. + /// SkipRecords will not. Add any value to this buffer. std::shared_ptr<::arrow::ResizableBuffer> values_; - /// \brief False for BYTE_ARRAY, in which case we don't allocate the values - /// buffer and we directly read into builder classes. - bool uses_values_; + /// \brief False for BYTE_ARRAY, in which case we don't allocate the values. + /// Buffer and we directly read into builder classes. + bool usesValues_; /// \brief Values that we have read into 'values_' + 'null_count_'. - int64_t values_written_; - int64_t values_capacity_; - int64_t null_count_; - - /// \brief Each bit corresponds to one element in 'values_' and specifies if - /// it is null or not null. Not set if read_dense_for_nullable_ is true. - std::shared_ptr<::arrow::ResizableBuffer> valid_bits_; - - /// \brief Buffer for definition levels. May contain more levels than - /// is actually read. This is because we read levels ahead to - /// figure out record boundaries for repeated fields. - /// For flat required fields, 'def_levels_' and 'rep_levels_' are not - /// populated. For non-repeated fields 'rep_levels_' is not populated. - /// 'def_levels_' and 'rep_levels_' must be of the same size if present. - std::shared_ptr<::arrow::ResizableBuffer> def_levels_; - /// \brief Buffer for repetition levels. Only populated for repeated - /// fields. - std::shared_ptr<::arrow::ResizableBuffer> rep_levels_; - - /// \brief Number of definition / repetition levels that have been written - /// internally in the reader. This may be larger than values_written() since - /// for repeated fields we need to look at the levels in advance to figure out - /// the record boundaries. - int64_t levels_written_; + int64_t valuesWritten_; + int64_t valuesCapacity_; + int64_t nullCount_; + + /// \brief Each bit corresponds to one element in 'values_' and specifies if. + /// It is null or not null. Not set if read_dense_for_nullable_ is true. + std::shared_ptr<::arrow::ResizableBuffer> validBits_; + + /// \brief Buffer for definition levels. May contain more levels than. + /// Is actually read. This is because we read levels ahead to. + /// Figure out record boundaries for repeated fields. + /// For flat required fields, 'def_levels_' and 'rep_levels_' are not. + /// Populated. For non-repeated fields 'rep_levels_' is not populated. + /// 'Def_levels_' and 'rep_levels_' must be of the same size if present. + std::shared_ptr<::arrow::ResizableBuffer> defLevels_; + /// \brief Buffer for repetition levels. Only populated for repeated. + /// Fields. + std::shared_ptr<::arrow::ResizableBuffer> repLevels_; + + /// \brief Number of definition / repetition levels that have been written. + /// Internally in the reader. This may be larger than values_written() since. + /// For repeated fields we need to look at the levels in advance to figure + /// out. The record boundaries. + int64_t levelsWritten_; /// \brief Position of the next level that should be consumed. - int64_t levels_position_; - int64_t levels_capacity_; + int64_t levelsPosition_; + int64_t levelsCapacity_; - bool read_dictionary_ = false; - // If true, we will not leave any space for the null values in the values_ - // vector. - bool read_dense_for_nullable_ = false; + bool readDictionary_ = false; + // If true, we will not leave any space for the null values in the values_. + // Vector. + bool readDenseForNullable_ = false; }; class BinaryRecordReader : virtual public RecordReader { public: - virtual std::vector> GetBuilderChunks() = 0; + virtual std::vector> getBuilderChunks() = 0; }; -/// \brief Read records directly to dictionary-encoded Arrow form (int32 -/// indices). Only valid for BYTE_ARRAY columns +/// \brief Read records directly to dictionary-encoded Arrow form (int32. +/// Indices). Only valid for BYTE_ARRAY columns. class DictionaryRecordReader : virtual public RecordReader { public: - virtual std::shared_ptr<::arrow::ChunkedArray> GetResult() = 0; + virtual std::shared_ptr<::arrow::ChunkedArray> getResult() = 0; }; } // namespace internal diff --git a/velox/dwio/parquet/writer/arrow/tests/ColumnReaderTest.cpp b/velox/dwio/parquet/writer/arrow/tests/ColumnReaderTest.cpp index 0deb1540244..f4ff80e472e 100644 --- a/velox/dwio/parquet/writer/arrow/tests/ColumnReaderTest.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/ColumnReaderTest.cpp @@ -50,32 +50,32 @@ using testing::ElementsAre; namespace test { template -static inline bool vector_equal_with_def_levels( +static inline bool vectorEqualWithDefLevels( const std::vector& left, - const std::vector& def_levels, - int16_t max_def_levels, - int16_t max_rep_levels, + const std::vector& defLevels, + int16_t maxDefLevels, + int16_t maxRepLevels, const std::vector& right) { - size_t i_left = 0; - size_t i_right = 0; - for (size_t i = 0; i < def_levels.size(); i++) { - if (def_levels[i] == max_def_levels) { - // Compare - if (left[i_left] != right[i_right]) { - std::cerr << "index " << i << " left was " << left[i_left] + size_t iLeft = 0; + size_t iRight = 0; + for (size_t i = 0; i < defLevels.size(); i++) { + if (defLevels[i] == maxDefLevels) { + // Compare. + if (left[iLeft] != right[iRight]) { + std::cerr << "index " << i << " left was " << left[iLeft] << " right was " << right[i] << std::endl; return false; } - i_left++; - i_right++; - } else if (def_levels[i] == (max_def_levels - 1)) { - // Null entry on the lowest nested level - i_right++; - } else if (def_levels[i] < (max_def_levels - 1)) { - // Null entry on a higher nesting level, only supported for non-repeating - // data - if (max_rep_levels == 0) { - i_right++; + iLeft++; + iRight++; + } else if (defLevels[i] == (maxDefLevels - 1)) { + // Null entry on the lowest nested level. + iRight++; + } else if (defLevels[i] < (maxDefLevels - 1)) { + // Null entry on a higher nesting level, only supported for non-repeating. + // Data. + if (maxRepLevels == 0) { + iRight++; } } } @@ -85,684 +85,670 @@ static inline bool vector_equal_with_def_levels( class TestPrimitiveReader : public ::testing::Test { public: - void InitReader(const ColumnDescriptor* d) { + void initReader(const ColumnDescriptor* d) { auto pager = std::make_unique(pages_); - reader_ = ColumnReader::Make(d, std::move(pager)); + reader_ = ColumnReader::make(d, std::move(pager)); } - void CheckResults() { - std::vector vresult(num_values_, -1); - std::vector dresult(num_levels_, -1); - std::vector rresult(num_levels_, -1); - int64_t values_read = 0; - int total_values_read = 0; - int batch_actual = 0; + void checkResults() { + std::vector vresult(numValues_, -1); + std::vector dresult(numLevels_, -1); + std::vector rresult(numLevels_, -1); + int64_t valuesRead = 0; + int totalValuesRead = 0; + int batchActual = 0; Int32Reader* reader = static_cast(reader_.get()); - int32_t batch_size = 8; + int32_t batchSize = 8; int batch = 0; - // This will cover both the cases + // This will cover both the cases. // 1) batch_size < page_size (multiple ReadBatch from a single page) // 2) batch_size > page_size (BatchRead limits to a single page) do { - batch = static_cast(reader->ReadBatch( - batch_size, - &dresult[0] + batch_actual, - &rresult[0] + batch_actual, - &vresult[0] + total_values_read, - &values_read)); - total_values_read += static_cast(values_read); - batch_actual += batch; - batch_size = std::min(1 << 24, std::max(batch_size * 2, 4096)); + batch = static_cast(reader->readBatch( + batchSize, + &dresult[0] + batchActual, + &rresult[0] + batchActual, + &vresult[0] + totalValuesRead, + &valuesRead)); + totalValuesRead += static_cast(valuesRead); + batchActual += batch; + batchSize = std::min(1 << 24, std::max(batchSize * 2, 4096)); } while (batch > 0); - ASSERT_EQ(num_levels_, batch_actual); - ASSERT_EQ(num_values_, total_values_read); - ASSERT_TRUE(vector_equal(values_, vresult)); - if (max_def_level_ > 0) { - ASSERT_TRUE(vector_equal(def_levels_, dresult)); + ASSERT_EQ(numLevels_, batchActual); + ASSERT_EQ(numValues_, totalValuesRead); + ASSERT_TRUE(vectorEqual(values_, vresult)); + if (maxDefLevel_ > 0) { + ASSERT_TRUE(vectorEqual(defLevels_, dresult)); } - if (max_rep_level_ > 0) { - ASSERT_TRUE(vector_equal(rep_levels_, rresult)); + if (maxRepLevel_ > 0) { + ASSERT_TRUE(vectorEqual(repLevels_, rresult)); } - // catch improper writes at EOS - batch_actual = static_cast( - reader->ReadBatch(5, nullptr, nullptr, nullptr, &values_read)); - ASSERT_EQ(0, batch_actual); - ASSERT_EQ(0, values_read); + // Catch improper writes at EOS. + batchActual = static_cast( + reader->readBatch(5, nullptr, nullptr, nullptr, &valuesRead)); + ASSERT_EQ(0, batchActual); + ASSERT_EQ(0, valuesRead); } - void CheckResultsSpaced() { - std::vector vresult(num_levels_, -1); - std::vector dresult(num_levels_, -1); - std::vector rresult(num_levels_, -1); - std::vector valid_bits(num_levels_, 255); - int total_values_read = 0; - int batch_actual = 0; - int levels_actual = 0; - int64_t null_count = -1; - int64_t levels_read = 0; - int64_t values_read; + void checkResultsSpaced() { + std::vector vresult(numLevels_, -1); + std::vector dresult(numLevels_, -1); + std::vector rresult(numLevels_, -1); + std::vector validBits(numLevels_, 255); + int totalValuesRead = 0; + int batchActual = 0; + int levelsActual = 0; + int64_t nullCount = -1; + int64_t levelsRead = 0; + int64_t valuesRead; Int32Reader* reader = static_cast(reader_.get()); - int32_t batch_size = 8; + int32_t batchSize = 8; int batch = 0; - // This will cover both the cases + // This will cover both the cases. // 1) batch_size < page_size (multiple ReadBatch from a single page) // 2) batch_size > page_size (BatchRead limits to a single page) do { ARROW_SUPPRESS_DEPRECATION_WARNING - batch = static_cast(reader->ReadBatchSpaced( - batch_size, - dresult.data() + levels_actual, - rresult.data() + levels_actual, - vresult.data() + batch_actual, - valid_bits.data() + batch_actual, + batch = static_cast(reader->readBatchSpaced( + batchSize, + dresult.data() + levelsActual, + rresult.data() + levelsActual, + vresult.data() + batchActual, + validBits.data() + batchActual, 0, - &levels_read, - &values_read, - &null_count)); + &levelsRead, + &valuesRead, + &nullCount)); ARROW_UNSUPPRESS_DEPRECATION_WARNING - total_values_read += batch - static_cast(null_count); - batch_actual += batch; - levels_actual += static_cast(levels_read); - batch_size = std::min(1 << 24, std::max(batch_size * 2, 4096)); - } while ((batch > 0) || (levels_read > 0)); - - ASSERT_EQ(num_levels_, levels_actual); - ASSERT_EQ(num_values_, total_values_read); - if (max_def_level_ > 0) { - ASSERT_TRUE(vector_equal(def_levels_, dresult)); - ASSERT_TRUE(vector_equal_with_def_levels( - values_, dresult, max_def_level_, max_rep_level_, vresult)); + totalValuesRead += batch - static_cast(nullCount); + batchActual += batch; + levelsActual += static_cast(levelsRead); + batchSize = std::min(1 << 24, std::max(batchSize * 2, 4096)); + } while ((batch > 0) || (levelsRead > 0)); + + ASSERT_EQ(numLevels_, levelsActual); + ASSERT_EQ(numValues_, totalValuesRead); + if (maxDefLevel_ > 0) { + ASSERT_TRUE(vectorEqual(defLevels_, dresult)); + ASSERT_TRUE(vectorEqualWithDefLevels( + values_, dresult, maxDefLevel_, maxRepLevel_, vresult)); } else { - ASSERT_TRUE(vector_equal(values_, vresult)); + ASSERT_TRUE(vectorEqual(values_, vresult)); } - if (max_rep_level_ > 0) { - ASSERT_TRUE(vector_equal(rep_levels_, rresult)); + if (maxRepLevel_ > 0) { + ASSERT_TRUE(vectorEqual(repLevels_, rresult)); } - // catch improper writes at EOS + // Catch improper writes at EOS. ARROW_SUPPRESS_DEPRECATION_WARNING - batch_actual = static_cast(reader->ReadBatchSpaced( + batchActual = static_cast(reader->readBatchSpaced( 5, nullptr, nullptr, nullptr, - valid_bits.data(), + validBits.data(), 0, - &levels_read, - &values_read, - &null_count)); + &levelsRead, + &valuesRead, + &nullCount)); ARROW_UNSUPPRESS_DEPRECATION_WARNING - ASSERT_EQ(0, batch_actual); - ASSERT_EQ(0, null_count); + ASSERT_EQ(0, batchActual); + ASSERT_EQ(0, nullCount); } - void Clear() { + void clear() { values_.clear(); - def_levels_.clear(); - rep_levels_.clear(); + defLevels_.clear(); + repLevels_.clear(); pages_.clear(); reader_.reset(); } void - ExecutePlain(int num_pages, int levels_per_page, const ColumnDescriptor* d) { - num_values_ = MakePages( + executePlain(int numPages, int levelsPerPage, const ColumnDescriptor* d) { + numValues_ = makePages( d, - num_pages, - levels_per_page, - def_levels_, - rep_levels_, + numPages, + levelsPerPage, + defLevels_, + repLevels_, values_, - data_buffer_, + dataBuffer_, pages_, - Encoding::PLAIN); - num_levels_ = num_pages * levels_per_page; - InitReader(d); - CheckResults(); - Clear(); + Encoding::kPlain); + numLevels_ = numPages * levelsPerPage; + initReader(d); + checkResults(); + clear(); - num_values_ = MakePages( + numValues_ = makePages( d, - num_pages, - levels_per_page, - def_levels_, - rep_levels_, + numPages, + levelsPerPage, + defLevels_, + repLevels_, values_, - data_buffer_, + dataBuffer_, pages_, - Encoding::PLAIN); - num_levels_ = num_pages * levels_per_page; - InitReader(d); - CheckResultsSpaced(); - Clear(); + Encoding::kPlain); + numLevels_ = numPages * levelsPerPage; + initReader(d); + checkResultsSpaced(); + clear(); } - void - ExecuteDict(int num_pages, int levels_per_page, const ColumnDescriptor* d) { - num_values_ = MakePages( + void executeDict(int numPages, int levelsPerPage, const ColumnDescriptor* d) { + numValues_ = makePages( d, - num_pages, - levels_per_page, - def_levels_, - rep_levels_, + numPages, + levelsPerPage, + defLevels_, + repLevels_, values_, - data_buffer_, + dataBuffer_, pages_, - Encoding::RLE_DICTIONARY); - num_levels_ = num_pages * levels_per_page; - InitReader(d); - CheckResults(); - Clear(); + Encoding::kRleDictionary); + numLevels_ = numPages * levelsPerPage; + initReader(d); + checkResults(); + clear(); - num_values_ = MakePages( + numValues_ = makePages( d, - num_pages, - levels_per_page, - def_levels_, - rep_levels_, + numPages, + levelsPerPage, + defLevels_, + repLevels_, values_, - data_buffer_, + dataBuffer_, pages_, - Encoding::RLE_DICTIONARY); - num_levels_ = num_pages * levels_per_page; - InitReader(d); - CheckResultsSpaced(); - Clear(); + Encoding::kRleDictionary); + numLevels_ = numPages * levelsPerPage; + initReader(d); + checkResultsSpaced(); + clear(); } protected: - int num_levels_; - int num_values_; - int16_t max_def_level_; - int16_t max_rep_level_; + int numLevels_; + int numValues_; + int16_t maxDefLevel_; + int16_t maxRepLevel_; std::vector> pages_; std::shared_ptr reader_; std::vector values_; - std::vector def_levels_; - std::vector rep_levels_; - std::vector data_buffer_; // For BA and FLBA + std::vector defLevels_; + std::vector repLevels_; + std::vector dataBuffer_; // For BA and FLBA }; TEST_F(TestPrimitiveReader, TestInt32FlatRequired) { - int levels_per_page = 100; - int num_pages = 50; - max_def_level_ = 0; - max_rep_level_ = 0; - NodePtr type = schema::Int32("a", Repetition::REQUIRED); - const ColumnDescriptor descr(type, max_def_level_, max_rep_level_); - ASSERT_NO_FATAL_FAILURE(ExecutePlain(num_pages, levels_per_page, &descr)); - ASSERT_NO_FATAL_FAILURE(ExecuteDict(num_pages, levels_per_page, &descr)); + int levelsPerPage = 100; + int numPages = 50; + maxDefLevel_ = 0; + maxRepLevel_ = 0; + NodePtr type = schema::int32("a", Repetition::kRequired); + const ColumnDescriptor descr(type, maxDefLevel_, maxRepLevel_); + ASSERT_NO_FATAL_FAILURE(executePlain(numPages, levelsPerPage, &descr)); + ASSERT_NO_FATAL_FAILURE(executeDict(numPages, levelsPerPage, &descr)); } TEST_F(TestPrimitiveReader, TestInt32FlatOptional) { - int levels_per_page = 100; - int num_pages = 50; - max_def_level_ = 4; - max_rep_level_ = 0; - NodePtr type = schema::Int32("b", Repetition::OPTIONAL); - const ColumnDescriptor descr(type, max_def_level_, max_rep_level_); - ASSERT_NO_FATAL_FAILURE(ExecutePlain(num_pages, levels_per_page, &descr)); - ASSERT_NO_FATAL_FAILURE(ExecuteDict(num_pages, levels_per_page, &descr)); + int levelsPerPage = 100; + int numPages = 50; + maxDefLevel_ = 4; + maxRepLevel_ = 0; + NodePtr type = schema::int32("b", Repetition::kOptional); + const ColumnDescriptor descr(type, maxDefLevel_, maxRepLevel_); + ASSERT_NO_FATAL_FAILURE(executePlain(numPages, levelsPerPage, &descr)); + ASSERT_NO_FATAL_FAILURE(executeDict(numPages, levelsPerPage, &descr)); } TEST_F(TestPrimitiveReader, TestInt32FlatRepeated) { - int levels_per_page = 100; - int num_pages = 50; - max_def_level_ = 4; - max_rep_level_ = 2; - NodePtr type = schema::Int32("c", Repetition::REPEATED); - const ColumnDescriptor descr(type, max_def_level_, max_rep_level_); - ASSERT_NO_FATAL_FAILURE(ExecutePlain(num_pages, levels_per_page, &descr)); - ASSERT_NO_FATAL_FAILURE(ExecuteDict(num_pages, levels_per_page, &descr)); + int levelsPerPage = 100; + int numPages = 50; + maxDefLevel_ = 4; + maxRepLevel_ = 2; + NodePtr type = schema::int32("c", Repetition::kRepeated); + const ColumnDescriptor descr(type, maxDefLevel_, maxRepLevel_); + ASSERT_NO_FATAL_FAILURE(executePlain(numPages, levelsPerPage, &descr)); + ASSERT_NO_FATAL_FAILURE(executeDict(numPages, levelsPerPage, &descr)); } // Tests skipping around page boundaries. TEST_F(TestPrimitiveReader, TestSkipAroundPageBoundries) { - int levels_per_page = 100; - int num_pages = 7; - max_def_level_ = 0; - max_rep_level_ = 0; - NodePtr type = schema::Int32("b", Repetition::REQUIRED); - const ColumnDescriptor descr(type, max_def_level_, max_rep_level_); - MakePages( + int levelsPerPage = 100; + int numPages = 7; + maxDefLevel_ = 0; + maxRepLevel_ = 0; + NodePtr type = schema::int32("b", Repetition::kRequired); + const ColumnDescriptor descr(type, maxDefLevel_, maxRepLevel_); + makePages( &descr, - num_pages, - levels_per_page, - def_levels_, - rep_levels_, + numPages, + levelsPerPage, + defLevels_, + repLevels_, values_, - data_buffer_, + dataBuffer_, pages_, - Encoding::PLAIN); - InitReader(&descr); - std::vector vresult(levels_per_page / 2, -1); - std::vector dresult(levels_per_page / 2, -1); - std::vector rresult(levels_per_page / 2, -1); + Encoding::kPlain); + initReader(&descr); + std::vector vresult(levelsPerPage / 2, -1); + std::vector dresult(levelsPerPage / 2, -1); + std::vector rresult(levelsPerPage / 2, -1); Int32Reader* reader = static_cast(reader_.get()); - int64_t values_read = 0; + int64_t valuesRead = 0; // 1) skip_size > page_size (multiple pages skipped) - // Skip first 2 pages - int64_t levels_skipped = reader->Skip(2 * levels_per_page); - ASSERT_EQ(2 * levels_per_page, levels_skipped); - // Read half a page - reader->ReadBatch( - levels_per_page / 2, + // Skip first 2 pages. + int64_t levelsSkipped = reader->skip(2 * levelsPerPage); + ASSERT_EQ(2 * levelsPerPage, levelsSkipped); + // Read half a page. + reader->readBatch( + levelsPerPage / 2, dresult.data(), rresult.data(), vresult.data(), - &values_read); - std::vector sub_values( - values_.begin() + 2 * levels_per_page, - values_.begin() + static_cast(2.5 * levels_per_page)); - ASSERT_TRUE(vector_equal(sub_values, vresult)); + &valuesRead); + std::vector subValues( + values_.begin() + 2 * levelsPerPage, + values_.begin() + static_cast(2.5 * levelsPerPage)); + ASSERT_TRUE(vectorEqual(subValues, vresult)); // 2) skip_size == page_size (skip across two pages from page 2.5 to 3.5) - levels_skipped = reader->Skip(levels_per_page); - ASSERT_EQ(levels_per_page, levels_skipped); + levelsSkipped = reader->skip(levelsPerPage); + ASSERT_EQ(levelsPerPage, levelsSkipped); // Read half a page (page 3.5 to 4) - reader->ReadBatch( - levels_per_page / 2, + reader->readBatch( + levelsPerPage / 2, dresult.data(), rresult.data(), vresult.data(), - &values_read); - sub_values.clear(); - sub_values.insert( - sub_values.end(), - values_.begin() + static_cast(3.5 * levels_per_page), - values_.begin() + 4 * levels_per_page); - ASSERT_TRUE(vector_equal(sub_values, vresult)); + &valuesRead); + subValues.clear(); + subValues.insert( + subValues.end(), + values_.begin() + static_cast(3.5 * levelsPerPage), + values_.begin() + 4 * levelsPerPage); + ASSERT_TRUE(vectorEqual(subValues, vresult)); // 3) skip_size == page_size (skip page 4 from start of the page to the end) - levels_skipped = reader->Skip(levels_per_page); - ASSERT_EQ(levels_per_page, levels_skipped); + levelsSkipped = reader->skip(levelsPerPage); + ASSERT_EQ(levelsPerPage, levelsSkipped); // Read half a page (page 5 to 5.5) - reader->ReadBatch( - levels_per_page / 2, + reader->readBatch( + levelsPerPage / 2, dresult.data(), rresult.data(), vresult.data(), - &values_read); - sub_values.clear(); - sub_values.insert( - sub_values.end(), - values_.begin() + static_cast(5.0 * levels_per_page), - values_.begin() + static_cast(5.5 * levels_per_page)); - ASSERT_TRUE(vector_equal(sub_values, vresult)); + &valuesRead); + subValues.clear(); + subValues.insert( + subValues.end(), + values_.begin() + static_cast(5.0 * levelsPerPage), + values_.begin() + static_cast(5.5 * levelsPerPage)); + ASSERT_TRUE(vectorEqual(subValues, vresult)); // 4) skip_size < page_size (skip limited to a single page) // Skip half a page (page 5.5 to 6) - levels_skipped = reader->Skip(levels_per_page / 2); - ASSERT_EQ(0.5 * levels_per_page, levels_skipped); + levelsSkipped = reader->skip(levelsPerPage / 2); + ASSERT_EQ(0.5 * levelsPerPage, levelsSkipped); // Read half a page (6 to 6.5) - reader->ReadBatch( - levels_per_page / 2, + reader->readBatch( + levelsPerPage / 2, dresult.data(), rresult.data(), vresult.data(), - &values_read); - sub_values.clear(); - sub_values.insert( - sub_values.end(), - values_.begin() + static_cast(6.0 * levels_per_page), - values_.begin() + static_cast(6.5 * levels_per_page)); - ASSERT_TRUE(vector_equal(sub_values, vresult)); - - // 5) skip_size = 0 - levels_skipped = reader->Skip(0); - ASSERT_EQ(0, levels_skipped); + &valuesRead); + subValues.clear(); + subValues.insert( + subValues.end(), + values_.begin() + static_cast(6.0 * levelsPerPage), + values_.begin() + static_cast(6.5 * levelsPerPage)); + ASSERT_TRUE(vectorEqual(subValues, vresult)); + + // 5) Skip_size = 0. + levelsSkipped = reader->skip(0); + ASSERT_EQ(0, levelsSkipped); // 6) Skip past the end page. - levels_skipped = reader->Skip(levels_per_page / 2 + 10); - ASSERT_EQ(levels_per_page / 2, levels_skipped); + levelsSkipped = reader->skip(levelsPerPage / 2 + 10); + ASSERT_EQ(levelsPerPage / 2, levelsSkipped); values_.clear(); - def_levels_.clear(); - rep_levels_.clear(); + defLevels_.clear(); + repLevels_.clear(); pages_.clear(); reader_.reset(); } -// Skip with repeated field. This test makes it clear that we are skipping -// values and not records. +// Skip with repeated field. This test makes it clear that we are skipping. +// Values and not records. TEST_F(TestPrimitiveReader, TestSkipRepeatedField) { // Example schema: message M { repeated int32 b = 1 } - max_def_level_ = 1; - max_rep_level_ = 1; - NodePtr type = schema::Int32("b", Repetition::REPEATED); - const ColumnDescriptor descr(type, max_def_level_, max_rep_level_); + maxDefLevel_ = 1; + maxRepLevel_ = 1; + NodePtr type = schema::int32("b", Repetition::kRepeated); + const ColumnDescriptor descr(type, maxDefLevel_, maxRepLevel_); // Example rows: {}, {[10, 10]}, {[20, 20, 20]} std::vector values = {10, 10, 20, 20, 20}; - std::vector def_levels = {0, 1, 1, 1, 1, 1}; - std::vector rep_levels = {0, 0, 1, 0, 1, 1}; - num_values_ = static_cast(def_levels.size()); - std::shared_ptr page = MakeDataPage( + std::vector defLevels = {0, 1, 1, 1, 1, 1}; + std::vector repLevels = {0, 0, 1, 0, 1, 1}; + numValues_ = static_cast(defLevels.size()); + std::shared_ptr page = makeDataPage( &descr, values, - num_values_, - Encoding::PLAIN, - /*indices=*/{}, - /*indices_size=*/0, - def_levels, - max_def_level_, - rep_levels, - max_rep_level_); + numValues_, + Encoding::kPlain, + {}, + 0, + defLevels, + maxDefLevel_, + repLevels, + maxRepLevel_); pages_.push_back(std::move(page)); - InitReader(&descr); + initReader(&descr); Int32Reader* reader = static_cast(reader_.get()); // Vecotrs to hold read values, definition levels, and repetition levels. - std::vector read_vals(4, -1); - std::vector read_defs(4, -1); - std::vector read_reps(4, -1); + std::vector readVals(4, -1); + std::vector readDefs(4, -1); + std::vector readReps(4, -1); // Skip two levels. - int64_t levels_skipped = reader->Skip(2); - ASSERT_EQ(2, levels_skipped); - - int64_t num_read_values = 0; - // Read the next set of values - reader->ReadBatch( - 10, - read_defs.data(), - read_reps.data(), - read_vals.data(), - &num_read_values); - ASSERT_EQ(num_read_values, 4); + int64_t levelsSkipped = reader->skip(2); + ASSERT_EQ(2, levelsSkipped); + + int64_t numReadValues = 0; + // Read the next set of values. + reader->readBatch( + 10, readDefs.data(), readReps.data(), readVals.data(), &numReadValues); + ASSERT_EQ(numReadValues, 4); // Note that we end up in the record with {[10, 10]} - ASSERT_TRUE(vector_equal({10, 20, 20, 20}, read_vals)); - ASSERT_TRUE(vector_equal({1, 1, 1, 1}, read_defs)); - ASSERT_TRUE(vector_equal({1, 0, 1, 1}, read_reps)); - - // No values remain in data page - levels_skipped = reader->Skip(2); - ASSERT_EQ(0, levels_skipped); - reader->ReadBatch( - 10, - read_defs.data(), - read_reps.data(), - read_vals.data(), - &num_read_values); - ASSERT_EQ(num_read_values, 0); + ASSERT_TRUE(vectorEqual({10, 20, 20, 20}, readVals)); + ASSERT_TRUE(vectorEqual({1, 1, 1, 1}, readDefs)); + ASSERT_TRUE(vectorEqual({1, 0, 1, 1}, readReps)); + + // No values remain in data page. + levelsSkipped = reader->skip(2); + ASSERT_EQ(0, levelsSkipped); + reader->readBatch( + 10, readDefs.data(), readReps.data(), readVals.data(), &numReadValues); + ASSERT_EQ(numReadValues, 0); } // Page claims to have two values but only 1 is present. TEST_F(TestPrimitiveReader, TestReadValuesMissing) { - max_def_level_ = 1; - max_rep_level_ = 0; - constexpr int batch_size = 1; + maxDefLevel_ = 1; + maxRepLevel_ = 0; + constexpr int batchSize = 1; std::vector values(1, false); - std::vector input_def_levels(1, 1); - NodePtr type = schema::Boolean("a", Repetition::OPTIONAL); - const ColumnDescriptor descr(type, max_def_level_, max_rep_level_); + std::vector inputDefLevels(1, 1); + NodePtr type = schema::boolean("a", Repetition::kOptional); + const ColumnDescriptor descr(type, maxDefLevel_, maxRepLevel_); - // The data page falls back to plain encoding - std::shared_ptr dummy = AllocateBuffer(); - std::shared_ptr data_page = MakeDataPage( + // The data page falls back to plain encoding. + std::shared_ptr dummy = allocateBuffer(); + std::shared_ptr dataPage = makeDataPage( &descr, values, - /*num_vals=*/2, - Encoding::PLAIN, - /*indices=*/{}, - /*indices_size=*/0, - /*def_levels=*/input_def_levels, - max_def_level_, - /*rep_levels=*/{}, - /*max_rep_level=*/0); - pages_.push_back(data_page); - InitReader(&descr); + 2, + Encoding::kPlain, + {}, + 0, + inputDefLevels, + maxDefLevel_, + {}, + 0); + pages_.push_back(dataPage); + initReader(&descr); auto reader = static_cast(reader_.get()); - ASSERT_TRUE(reader->HasNext()); - std::vector def_levels(batch_size, 0); - std::vector rep_levels(batch_size, 0); - bool values_out[batch_size]; - int64_t values_read; + ASSERT_TRUE(reader->hasNext()); + std::vector defLevels(batchSize, 0); + std::vector repLevels(batchSize, 0); + bool valuesOut[batchSize]; + int64_t valuesRead; EXPECT_EQ( 1, - reader->ReadBatch( - batch_size, - def_levels.data(), - rep_levels.data(), - values_out, - &values_read)); + reader->readBatch( + batchSize, + defLevels.data(), + repLevels.data(), + valuesOut, + &valuesRead)); ASSERT_THROW( - reader->ReadBatch( - batch_size, - def_levels.data(), - rep_levels.data(), - values_out, - &values_read), + reader->readBatch( + batchSize, + defLevels.data(), + repLevels.data(), + valuesOut, + &valuesRead), ParquetException); } -// Repetition level byte length reported in Page but Max Repetition level -// is zero for the column. +// Repetition level byte length reported in Page but Max Repetition level. +// Is zero for the column. TEST_F(TestPrimitiveReader, TestRepetitionLvlBytesWithMaxRepetitionZero) { - constexpr int batch_size = 4; - max_def_level_ = 1; - max_rep_level_ = 0; - NodePtr type = schema::Int32("a", Repetition::OPTIONAL); - const ColumnDescriptor descr(type, max_def_level_, max_rep_level_); - // Bytes here came from the example parquet file in ARROW-17453's int32 - // column which was delta bit-packed. The key part is the first three - // bytes: the page header reports 1 byte for repetition levels even - // though the max rep level is 0. If that byte isn't skipped then - // we get def levels of [1, 1, 0, 0] instead of the correct [1, 1, 1, 0]. - const std::vector page_data{0x3, 0x3, 0x7, 0x80, 0x1, 0x4, 0x3, - 0x18, 0x1, 0x2, 0x0, 0x0, 0x0, 0xc, - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}; - - std::shared_ptr data_page = std::make_shared( - Buffer::Wrap(page_data.data(), page_data.size()), + constexpr int batchSize = 4; + maxDefLevel_ = 1; + maxRepLevel_ = 0; + NodePtr type = schema::int32("a", Repetition::kOptional); + const ColumnDescriptor descr(type, maxDefLevel_, maxRepLevel_); + // Bytes here came from the example parquet file in ARROW-17453's int32. + // Column which was delta bit-packed. The key part is the first three. + // Bytes: the page header reports 1 byte for repetition levels even. + // Though the max rep level is 0. If that byte isn't skipped then. + // We get def levels of [1, 1, 0, 0] instead of the correct [1, 1, 1, 0]. + const std::vector pageData{0x3, 0x3, 0x7, 0x80, 0x1, 0x4, 0x3, + 0x18, 0x1, 0x2, 0x0, 0x0, 0x0, 0xc, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}; + + std::shared_ptr dataPage = std::make_shared( + Buffer::Wrap(pageData.data(), pageData.size()), 4, 1, 4, - Encoding::DELTA_BINARY_PACKED, + Encoding::kDeltaBinaryPacked, 2, 1, 21); - pages_.push_back(data_page); - InitReader(&descr); + pages_.push_back(dataPage); + initReader(&descr); auto reader = static_cast(reader_.get()); - int16_t def_levels_out[batch_size]; - int32_t values[batch_size]; - int64_t values_read; - ASSERT_TRUE(reader->HasNext()); + int16_t defLevelsOut[batchSize]; + int32_t values[batchSize]; + int64_t valuesRead; + ASSERT_TRUE(reader->hasNext()); EXPECT_EQ( 4, - reader->ReadBatch( - batch_size, - def_levels_out, - /*replevels=*/nullptr, - values, - &values_read)); - EXPECT_EQ(3, values_read); + reader->readBatch(batchSize, defLevelsOut, nullptr, values, &valuesRead)); + EXPECT_EQ(3, valuesRead); } // Page claims to have two values but only 1 is present. TEST_F(TestPrimitiveReader, TestReadValuesMissingWithDictionary) { - constexpr int batch_size = 1; - max_def_level_ = 1; - max_rep_level_ = 0; - NodePtr type = schema::Int32("a", Repetition::OPTIONAL); - const ColumnDescriptor descr(type, max_def_level_, max_rep_level_); - std::shared_ptr dummy = AllocateBuffer(); - - std::shared_ptr dict_page = - std::make_shared(dummy, 0, Encoding::PLAIN); - std::vector input_def_levels(1, 0); - std::shared_ptr data_page = MakeDataPage( + constexpr int batchSize = 1; + maxDefLevel_ = 1; + maxRepLevel_ = 0; + NodePtr type = schema::int32("a", Repetition::kOptional); + const ColumnDescriptor descr(type, maxDefLevel_, maxRepLevel_); + std::shared_ptr dummy = allocateBuffer(); + + std::shared_ptr dictPage = + std::make_shared(dummy, 0, Encoding::kPlain); + std::vector inputDefLevels(1, 0); + std::shared_ptr dataPage = makeDataPage( &descr, {}, - /*num_vals=*/2, - Encoding::RLE_DICTIONARY, - /*indices=*/{}, - /*indices_size=*/0, - /*def_levels=*/input_def_levels, - max_def_level_, - /*rep_levels=*/{}, - /*max_rep_level=*/0); - pages_.push_back(dict_page); - pages_.push_back(data_page); - InitReader(&descr); + 2, + Encoding::kRleDictionary, + {}, + 0, + inputDefLevels, + maxDefLevel_, + {}, + 0); + pages_.push_back(dictPage); + pages_.push_back(dataPage); + initReader(&descr); auto reader = static_cast(reader_.get()); const ByteArray* dict = nullptr; - int32_t dict_len = 0; - int64_t indices_read = 0; - int32_t indices[batch_size]; - int16_t def_levels_out[batch_size]; - ASSERT_TRUE(reader->HasNext()); + int32_t dictLen = 0; + int64_t indicesRead = 0; + int32_t indices[batchSize]; + int16_t defLevelsOut[batchSize]; + ASSERT_TRUE(reader->hasNext()); EXPECT_EQ( 1, - reader->ReadBatchWithDictionary( - batch_size, - def_levels_out, - /*rep_levels=*/nullptr, + reader->readBatchWithDictionary( + batchSize, + defLevelsOut, + nullptr, indices, - &indices_read, + &indicesRead, &dict, - &dict_len)); + &dictLen)); ASSERT_THROW( - reader->ReadBatchWithDictionary( - batch_size, - def_levels_out, - /*rep_levels=*/nullptr, + reader->readBatchWithDictionary( + batchSize, + defLevelsOut, + nullptr, indices, - &indices_read, + &indicesRead, &dict, - &dict_len), + &dictLen), ParquetException); } TEST_F(TestPrimitiveReader, TestDictionaryEncodedPages) { - max_def_level_ = 0; - max_rep_level_ = 0; - NodePtr type = schema::Int32("a", Repetition::REQUIRED); - const ColumnDescriptor descr(type, max_def_level_, max_rep_level_); - std::shared_ptr dummy = AllocateBuffer(); - - std::shared_ptr dict_page = - std::make_shared(dummy, 0, Encoding::PLAIN); - std::shared_ptr data_page = MakeDataPage( - &descr, {}, 0, Encoding::RLE_DICTIONARY, {}, 0, {}, 0, {}, 0); - pages_.push_back(dict_page); - pages_.push_back(data_page); - InitReader(&descr); - // Tests Dict : PLAIN, Data : RLE_DICTIONARY - ASSERT_NO_THROW(reader_->HasNext()); + maxDefLevel_ = 0; + maxRepLevel_ = 0; + NodePtr type = schema::int32("a", Repetition::kRequired); + const ColumnDescriptor descr(type, maxDefLevel_, maxRepLevel_); + std::shared_ptr dummy = allocateBuffer(); + + std::shared_ptr dictPage = + std::make_shared(dummy, 0, Encoding::kPlain); + std::shared_ptr dataPage = makeDataPage( + &descr, {}, 0, Encoding::kRleDictionary, {}, 0, {}, 0, {}, 0); + pages_.push_back(dictPage); + pages_.push_back(dataPage); + initReader(&descr); + // Tests Dict : PLAIN, Data : RLE_DICTIONARY. + ASSERT_NO_THROW(reader_->hasNext()); pages_.clear(); - dict_page = - std::make_shared(dummy, 0, Encoding::PLAIN_DICTIONARY); - data_page = MakeDataPage( - &descr, {}, 0, Encoding::PLAIN_DICTIONARY, {}, 0, {}, 0, {}, 0); - pages_.push_back(dict_page); - pages_.push_back(data_page); - InitReader(&descr); - // Tests Dict : PLAIN_DICTIONARY, Data : PLAIN_DICTIONARY - ASSERT_NO_THROW(reader_->HasNext()); + dictPage = + std::make_shared(dummy, 0, Encoding::kPlainDictionary); + dataPage = makeDataPage( + &descr, {}, 0, Encoding::kPlainDictionary, {}, 0, {}, 0, {}, 0); + pages_.push_back(dictPage); + pages_.push_back(dataPage); + initReader(&descr); + // Tests Dict : PLAIN_DICTIONARY, Data : PLAIN_DICTIONARY. + ASSERT_NO_THROW(reader_->hasNext()); pages_.clear(); - data_page = MakeDataPage( - &descr, {}, 0, Encoding::RLE_DICTIONARY, {}, 0, {}, 0, {}, 0); - pages_.push_back(data_page); - InitReader(&descr); - // Tests dictionary page must occur before data page - ASSERT_THROW(reader_->HasNext(), ParquetException); + dataPage = makeDataPage( + &descr, {}, 0, Encoding::kRleDictionary, {}, 0, {}, 0, {}, 0); + pages_.push_back(dataPage); + initReader(&descr); + // Tests dictionary page must occur before data page. + ASSERT_THROW(reader_->hasNext(), ParquetException); pages_.clear(); - dict_page = - std::make_shared(dummy, 0, Encoding::DELTA_BYTE_ARRAY); - pages_.push_back(dict_page); - InitReader(&descr); - // Tests only RLE_DICTIONARY is supported - ASSERT_THROW(reader_->HasNext(), ParquetException); + dictPage = + std::make_shared(dummy, 0, Encoding::kDeltaByteArray); + pages_.push_back(dictPage); + initReader(&descr); + // Tests only RLE_DICTIONARY is supported. + ASSERT_THROW(reader_->hasNext(), ParquetException); pages_.clear(); - std::shared_ptr dict_page1 = - std::make_shared(dummy, 0, Encoding::PLAIN_DICTIONARY); - std::shared_ptr dict_page2 = - std::make_shared(dummy, 0, Encoding::PLAIN); - pages_.push_back(dict_page1); - pages_.push_back(dict_page2); - InitReader(&descr); - // Column cannot have more than one dictionary - ASSERT_THROW(reader_->HasNext(), ParquetException); + std::shared_ptr dictPage1 = + std::make_shared(dummy, 0, Encoding::kPlainDictionary); + std::shared_ptr dictPage2 = + std::make_shared(dummy, 0, Encoding::kPlain); + pages_.push_back(dictPage1); + pages_.push_back(dictPage2); + initReader(&descr); + // Column cannot have more than one dictionary. + ASSERT_THROW(reader_->hasNext(), ParquetException); pages_.clear(); - data_page = MakeDataPage( - &descr, {}, 0, Encoding::DELTA_BYTE_ARRAY, {}, 0, {}, 0, {}, 0); - pages_.push_back(data_page); - InitReader(&descr); - // unsupported encoding - ASSERT_THROW(reader_->HasNext(), ParquetException); + dataPage = makeDataPage( + &descr, {}, 0, Encoding::kDeltaByteArray, {}, 0, {}, 0, {}, 0); + pages_.push_back(dataPage); + initReader(&descr); + // Unsupported encoding. + ASSERT_THROW(reader_->hasNext(), ParquetException); pages_.clear(); } TEST_F(TestPrimitiveReader, TestDictionaryEncodedPagesWithExposeEncoding) { - max_def_level_ = 0; - max_rep_level_ = 0; - int levels_per_page = 100; - int num_pages = 5; - std::vector def_levels; - std::vector rep_levels; + maxDefLevel_ = 0; + maxRepLevel_ = 0; + int levelsPerPage = 100; + int numPages = 5; + std::vector defLevels; + std::vector repLevels; std::vector values; std::vector buffer; - NodePtr type = schema::ByteArray("a", Repetition::REQUIRED); - const ColumnDescriptor descr(type, max_def_level_, max_rep_level_); + NodePtr type = schema::byteArray("a", Repetition::kRequired); + const ColumnDescriptor descr(type, maxDefLevel_, maxRepLevel_); - // Fully dictionary encoded - MakePages( + // Fully dictionary encoded. + makePages( &descr, - num_pages, - levels_per_page, - def_levels, - rep_levels, + numPages, + levelsPerPage, + defLevels, + repLevels, values, buffer, pages_, - Encoding::RLE_DICTIONARY); - InitReader(&descr); + Encoding::kRleDictionary); + initReader(&descr); auto reader = static_cast(reader_.get()); const ByteArray* dict = nullptr; - int32_t dict_len = 0; - int64_t total_indices = 0; - int64_t indices_read = 0; - int64_t value_size = values.size(); - auto indices = std::make_unique(value_size); - while (total_indices < value_size && reader->HasNext()) { - const ByteArray* tmp_dict = nullptr; - int32_t tmp_dict_len = 0; - EXPECT_NO_THROW(reader->ReadBatchWithDictionary( - value_size, - /*def_levels=*/nullptr, - /*rep_levels=*/nullptr, - indices.get() + total_indices, - &indices_read, - &tmp_dict, - &tmp_dict_len)); - if (tmp_dict != nullptr) { - // Dictionary is read along with data - EXPECT_GT(indices_read, 0); - dict = tmp_dict; - dict_len = tmp_dict_len; + int32_t dictLen = 0; + int64_t totalIndices = 0; + int64_t indicesRead = 0; + int64_t valueSize = values.size(); + auto indices = std::make_unique(valueSize); + while (totalIndices < valueSize && reader->hasNext()) { + const ByteArray* tmpDict = nullptr; + int32_t tmpDictLen = 0; + EXPECT_NO_THROW(reader->readBatchWithDictionary( + valueSize, + nullptr, + nullptr, + indices.get() + totalIndices, + &indicesRead, + &tmpDict, + &tmpDictLen)); + if (tmpDict != nullptr) { + // Dictionary is read along with data. + EXPECT_GT(indicesRead, 0); + dict = tmpDict; + dictLen = tmpDictLen; } else { - // Dictionary is not read when there's no data - EXPECT_EQ(indices_read, 0); + // Dictionary is not read when there's no data. + EXPECT_EQ(indicesRead, 0); } - total_indices += indices_read; + totalIndices += indicesRead; } - EXPECT_EQ(total_indices, value_size); - for (int64_t i = 0; i < total_indices; ++i) { - EXPECT_LT(indices[i], dict_len); + EXPECT_EQ(totalIndices, valueSize); + for (int64_t i = 0; i < totalIndices; ++i) { + EXPECT_LT(indices[i], dictLen); EXPECT_EQ(dict[indices[i]].len, values[i].len); EXPECT_EQ(memcmp(dict[indices[i]].ptr, values[i].ptr, values[i].len), 0); } @@ -770,68 +756,68 @@ TEST_F(TestPrimitiveReader, TestDictionaryEncodedPagesWithExposeEncoding) { } TEST_F(TestPrimitiveReader, TestNonDictionaryEncodedPagesWithExposeEncoding) { - max_def_level_ = 0; - max_rep_level_ = 0; - int64_t value_size = 100; - std::vector values(value_size, 0); - NodePtr type = schema::Int32("a", Repetition::REQUIRED); - const ColumnDescriptor descr(type, max_def_level_, max_rep_level_); - - // The data page falls back to plain encoding - std::shared_ptr dummy = AllocateBuffer(); - std::shared_ptr dict_page = - std::make_shared(dummy, 0, Encoding::PLAIN); - std::shared_ptr data_page = MakeDataPage( + maxDefLevel_ = 0; + maxRepLevel_ = 0; + int64_t valueSize = 100; + std::vector values(valueSize, 0); + NodePtr type = schema::int32("a", Repetition::kRequired); + const ColumnDescriptor descr(type, maxDefLevel_, maxRepLevel_); + + // The data page falls back to plain encoding. + std::shared_ptr dummy = allocateBuffer(); + std::shared_ptr dictPage = + std::make_shared(dummy, 0, Encoding::kPlain); + std::shared_ptr dataPage = makeDataPage( &descr, values, - static_cast(value_size), - Encoding::PLAIN, - /*indices=*/{}, - /*indices_size=*/0, - /*def_levels=*/{}, - /*max_def_level=*/0, - /*rep_levels=*/{}, - /*max_rep_level=*/0); - pages_.push_back(dict_page); - pages_.push_back(data_page); - InitReader(&descr); + static_cast(valueSize), + Encoding::kPlain, + {}, + 0, + {}, + 0, + {}, + 0); + pages_.push_back(dictPage); + pages_.push_back(dataPage); + initReader(&descr); auto reader = static_cast(reader_.get()); const ByteArray* dict = nullptr; - int32_t dict_len = 0; - int64_t indices_read = 0; - auto indices = std::make_unique(value_size); - // Dictionary cannot be exposed when it's not fully dictionary encoded + int32_t dictLen = 0; + int64_t indicesRead = 0; + auto indices = std::make_unique(valueSize); + // Dictionary cannot be exposed when it's not fully dictionary encoded. EXPECT_THROW( - reader->ReadBatchWithDictionary( - value_size, - /*def_levels=*/nullptr, - /*rep_levels=*/nullptr, + reader->readBatchWithDictionary( + valueSize, + nullptr, + nullptr, indices.get(), - &indices_read, + &indicesRead, &dict, - &dict_len), + &dictLen), ParquetException); pages_.clear(); } namespace { -LevelInfo ComputeLevelInfo(const ColumnDescriptor* descr) { - LevelInfo level_info; - level_info.defLevel = descr->max_definition_level(); - level_info.repLevel = descr->max_repetition_level(); +LevelInfo computeLevelInfo(const ColumnDescriptor* descr) { + LevelInfo levelInfo; + levelInfo.defLevel = descr->maxDefinitionLevel(); + levelInfo.repLevel = descr->maxRepetitionLevel(); - int16_t min_spaced_def_level = descr->max_definition_level(); - const schema::Node* node = descr->schema_node().get(); - while (node != nullptr && !node->is_repeated()) { - if (node->is_optional()) { - min_spaced_def_level--; + int16_t minSpacedDefLevel = descr->maxDefinitionLevel(); + const schema::Node* Node = descr->schemaNode().get(); + while (Node != nullptr && !Node->isRepeated()) { + if (Node->isOptional()) { + minSpacedDefLevel--; } - node = node->parent(); + Node = Node->parent(); } - level_info.repeatedAncestorDefLevel = min_spaced_def_level; - return level_info; + levelInfo.repeatedAncestorDefLevel = minSpacedDefLevel; + return levelInfo; } } // namespace @@ -842,1277 +828,995 @@ class RecordReaderPrimitiveTypeTest public: const int32_t kNullValue = -1; - void Init(NodePtr column) { - NodePtr root = GroupNode::Make("root", Repetition::REQUIRED, {column}); - schema_descriptor_.Init(root); - descr_ = schema_descriptor_.Column(0); - record_reader_ = internal::RecordReader::Make( + void init(NodePtr column) { + NodePtr root = GroupNode::make("root", Repetition::kRequired, {column}); + schemaDescriptor_.init(root); + descr_ = schemaDescriptor_.column(0); + recordReader_ = internal::RecordReader::make( descr_, - ComputeLevelInfo(descr_), + computeLevelInfo(descr_), ::arrow::default_memory_pool(), - /*read_dictionary=*/false, + false, GetParam()); } - void CheckReadValues( - std::vector expected_values, - std::vector expected_defs, - std::vector expected_reps) { - const auto read_values = - reinterpret_cast(record_reader_->values()); - std::vector read_vals( - read_values, read_values + record_reader_->values_written()); - ASSERT_EQ(read_vals.size(), expected_values.size()); - for (size_t i = 0; i < expected_values.size(); ++i) { - if (expected_values[i] != kNullValue) { - ASSERT_EQ(expected_values[i], read_values[i]); + void checkReadValues( + std::vector expectedValues, + std::vector expectedDefs, + std::vector expectedReps) { + const auto readValues = + reinterpret_cast(recordReader_->values()); + std::vector readVals( + readValues, readValues + recordReader_->valuesWritten()); + ASSERT_EQ(readVals.size(), expectedValues.size()); + for (size_t i = 0; i < expectedValues.size(); ++i) { + if (expectedValues[i] != kNullValue) { + ASSERT_EQ(expectedValues[i], readValues[i]); } } - if (!descr_->schema_node()->is_required()) { - std::vector read_defs( - record_reader_->def_levels(), - record_reader_->def_levels() + record_reader_->levels_position()); - ASSERT_TRUE(vector_equal(expected_defs, read_defs)); + if (!descr_->schemaNode()->isRequired()) { + std::vector readDefs( + recordReader_->defLevels(), + recordReader_->defLevels() + recordReader_->levelsPosition()); + ASSERT_TRUE(vectorEqual(expectedDefs, readDefs)); } - if (descr_->schema_node()->is_repeated()) { - std::vector read_reps( - record_reader_->rep_levels(), - record_reader_->rep_levels() + record_reader_->levels_position()); - ASSERT_TRUE(vector_equal(expected_reps, read_reps)); + if (descr_->schemaNode()->isRepeated()) { + std::vector readReps( + recordReader_->repLevels(), + recordReader_->repLevels() + recordReader_->levelsPosition()); + ASSERT_TRUE(vectorEqual(expectedReps, readReps)); } } - void CheckState( - int64_t values_written, - int64_t null_count, - int64_t levels_written, - int64_t levels_position) { - ASSERT_EQ(record_reader_->values_written(), values_written); - ASSERT_EQ(record_reader_->null_count(), null_count); - ASSERT_EQ(record_reader_->levels_written(), levels_written); - ASSERT_EQ(record_reader_->levels_position(), levels_position); + void checkState( + int64_t valuesWritten, + int64_t nullCount, + int64_t levelsWritten, + int64_t levelsPosition) { + ASSERT_EQ(recordReader_->valuesWritten(), valuesWritten); + ASSERT_EQ(recordReader_->nullCount(), nullCount); + ASSERT_EQ(recordReader_->levelsWritten(), levelsWritten); + ASSERT_EQ(recordReader_->levelsPosition(), levelsPosition); } protected: - SchemaDescriptor schema_descriptor_; - std::shared_ptr record_reader_; + SchemaDescriptor schemaDescriptor_; + std::shared_ptr recordReader_; const ColumnDescriptor* descr_; }; -// Tests reading a required field. The expected results are the same for -// reading dense and spaced. +// Tests reading a required field. The expected results are the same for. +// Reading dense and spaced. TEST_P(RecordReaderPrimitiveTypeTest, ReadRequired) { - Init(schema::Int32("b", Repetition::REQUIRED)); + init(schema::int32("b", Repetition::kRequired)); // Records look like: {10, 20, 20, 30, 30, 30} std::vector> pages; std::vector values = {10, 20, 20, 30, 30, 30}; - std::vector def_levels = {}; - std::vector rep_levels = {}; + std::vector defLevels = {}; + std::vector repLevels = {}; - std::shared_ptr page = MakeDataPage( + std::shared_ptr page = makeDataPage( descr_, values, - /*num_vals=*/static_cast(def_levels.size()), - Encoding::PLAIN, - /*indices=*/{}, - /*indices_size=*/0, - def_levels, - descr_->max_definition_level(), - rep_levels, - descr_->max_repetition_level()); + static_cast(defLevels.size()), + Encoding::kPlain, + {}, + 0, + defLevels, + descr_->maxDefinitionLevel(), + repLevels, + descr_->maxRepetitionLevel()); pages.push_back(std::move(page)); auto pager = std::make_unique(pages); - record_reader_->SetPageReader(std::move(pager)); - - // Read [10] - int64_t records_read = record_reader_->ReadRecords(/*num_records=*/1); - ASSERT_EQ(records_read, 1); - CheckState( - /*values_written=*/1, - /*null_count=*/0, - /*levels_written=*/0, - /*levels_position=*/0); - CheckReadValues( - /*expected_values=*/{10}, - /*expected_defs=*/{}, - /*expected_reps=*/{}); - record_reader_->Reset(); - CheckState( - /*values_written=*/0, - /*null_count=*/0, - /*levels_written=*/0, - /*levels_position=*/0); - - // Read 20, 20, 30, 30, 30 - records_read = record_reader_->ReadRecords(/*num_records=*/10); - ASSERT_EQ(records_read, 5); - CheckState( - /*values_written=*/5, - /*null_count=*/0, - /*levels_written=*/0, - /*levels_position=*/0); - CheckReadValues( - /*expected_values=*/{20, 20, 30, 30, 30}, - /*expected_defs=*/{}, - /*expected_reps=*/{}); - record_reader_->Reset(); - CheckState( - /*values_written=*/0, - /*null_count=*/0, - /*levels_written=*/0, - /*levels_position=*/0); + recordReader_->setPageReader(std::move(pager)); + + // Read [10]. + int64_t recordsRead = recordReader_->readRecords(1); + ASSERT_EQ(recordsRead, 1); + checkState(1, 0, 0, 0); + checkReadValues({10}, {}, {}); + recordReader_->reset(); + checkState(0, 0, 0, 0); + + // Read 20, 20, 30, 30, 30. + recordsRead = recordReader_->readRecords(10); + ASSERT_EQ(recordsRead, 5); + checkState(5, 0, 0, 0); + checkReadValues({20, 20, 30, 30, 30}, {}, {}); + recordReader_->reset(); + checkState(0, 0, 0, 0); } // Tests reading an optional field. -// Use a max definition field > 1 to test both cases where parent is present or -// parent is missing. +// Use a max definition field > 1 to test both cases where parent is present or. +// Parent is missing. TEST_P(RecordReaderPrimitiveTypeTest, ReadOptional) { - NodePtr column = GroupNode::Make( + NodePtr column = GroupNode::make( "a", - Repetition::OPTIONAL, - {PrimitiveNode::Make( - "element", Repetition::OPTIONAL, ParquetType::INT32)}); - Init(column); + Repetition::kOptional, + {PrimitiveNode::make( + "element", Repetition::kOptional, ParquetType::kInt32)}); + init(column); // Records look like: {10, null, 20, 20, null, 30, 30, 30, null} std::vector> pages; std::vector values = {10, 20, 20, 30, 30, 30}; - std::vector def_levels = {2, 0, 2, 2, 1, 2, 2, 2, 0}; + std::vector defLevels = {2, 0, 2, 2, 1, 2, 2, 2, 0}; - std::shared_ptr page = MakeDataPage( + std::shared_ptr page = makeDataPage( descr_, values, - /*num_vals=*/static_cast(def_levels.size()), - Encoding::PLAIN, - /*indices=*/{}, - /*indices_size=*/0, - def_levels, - descr_->max_definition_level(), - /*rep_levels=*/{}, - descr_->max_repetition_level()); + static_cast(defLevels.size()), + Encoding::kPlain, + {}, + 0, + defLevels, + descr_->maxDefinitionLevel(), + {}, + descr_->maxRepetitionLevel()); pages.push_back(std::move(page)); auto pager = std::make_unique(pages); - record_reader_->SetPageReader(std::move(pager)); - - // Read 10, null - int64_t records_read = record_reader_->ReadRecords(/*num_records=*/2); - ASSERT_EQ(records_read, 2); - if (GetParam() == /*read_dense_for_nullable=*/true) { - CheckState( - /*values_written=*/1, - /*null_count=*/0, - /*levels_written=*/9, - /*levels_position=*/2); - CheckReadValues( - /*expected_values=*/{10}, - /*expected_defs=*/{2, 0}, - /*expected_reps=*/{}); + recordReader_->setPageReader(std::move(pager)); + + // Read 10, null. + int64_t recordsRead = recordReader_->readRecords(2); + ASSERT_EQ(recordsRead, 2); + if (GetParam() == true) { + checkState(1, 0, 9, 2); + checkReadValues({10}, {2, 0}, {}); } else { - CheckState( - /*values_written=*/2, - /*null_count=*/1, - /*levels_written=*/9, - /*levels_position=*/2); - CheckReadValues( - /*expected_values=*/{10, kNullValue}, - /*expected_defs=*/{2, 0}, - /*expected_reps=*/{}); + checkState(2, 1, 9, 2); + checkReadValues({10, kNullValue}, {2, 0}, {}); } - record_reader_->Reset(); - CheckState( - /*values_written=*/0, - /*null_count=*/0, - /*levels_written=*/7, - /*levels_position=*/0); - - // Read 20, 20, null (parent present), 30, 30, 30 - records_read = record_reader_->ReadRecords(/*num_records=*/6); - ASSERT_EQ(records_read, 6); - if (GetParam() == /*read_dense_for_nullable=*/true) { - CheckState( - /*values_written=*/5, - /*null_count=*/0, - /*levels_written=*/7, - /*levels_position=*/6); - CheckReadValues( - /*expected_values=*/{20, 20, 30, 30, 30}, - /*expected_defs=*/{2, 2, 1, 2, 2, 2}, - /*expected_reps=*/{}); + recordReader_->reset(); + checkState(0, 0, 7, 0); + + // Read 20, 20, null (parent present), 30, 30, 30. + recordsRead = recordReader_->readRecords(6); + ASSERT_EQ(recordsRead, 6); + if (GetParam() == true) { + checkState(5, 0, 7, 6); + checkReadValues({20, 20, 30, 30, 30}, {2, 2, 1, 2, 2, 2}, {}); } else { - CheckState( - /*values_written=*/6, - /*null_count=*/1, - /*levels_written=*/7, - /*levels_position=*/6); - CheckReadValues( - /*expected_values=*/{20, 20, kNullValue, 30, 30, 30}, - /*expected_defs=*/{2, 2, 1, 2, 2, 2}, - /*expected_reps=*/{}); + checkState(6, 1, 7, 6); + checkReadValues({20, 20, kNullValue, 30, 30, 30}, {2, 2, 1, 2, 2, 2}, {}); } - record_reader_->Reset(); - CheckState( - /*values_written=*/0, - /*null_count=*/0, - /*levels_written=*/1, - /*levels_position=*/0); + recordReader_->reset(); + checkState(0, 0, 1, 0); // Read the last null value and read past the end. - records_read = record_reader_->ReadRecords(/*num_records=*/3); - ASSERT_EQ(records_read, 1); - if (GetParam() == /*read_dense_for_nullable=*/true) { - CheckState( - /*values_written=*/0, - /*null_count=*/0, - /*levels_written=*/1, - /*levels_position=*/1); - CheckReadValues( - /*expected_values=*/{}, - /*expected_defs=*/{0}, - /*expected_reps=*/{}); + recordsRead = recordReader_->readRecords(3); + ASSERT_EQ(recordsRead, 1); + if (GetParam() == true) { + checkState(0, 0, 1, 1); + checkReadValues({}, {0}, {}); } else { - CheckState( - /*values_written=*/1, - /*null_count=*/1, - /*levels_written=*/1, - /*levels_position=*/1); - CheckReadValues( - /*expected_values=*/{kNullValue}, - /*expected_defs=*/{0}, - /*expected_reps=*/{}); + checkState(1, 1, 1, 1); + checkReadValues({kNullValue}, {0}, {}); } - record_reader_->Reset(); - CheckState( - /*values_written=*/0, - /*null_count=*/0, - /*levels_written=*/0, - /*levels_position=*/0); + recordReader_->reset(); + checkState(0, 0, 0, 0); } -// Tests reading a required repeated field. The results are the same for reading -// dense or spaced. +// Tests reading a required repeated field. The results are the same for +// reading. Dense or spaced. TEST_P(RecordReaderPrimitiveTypeTest, ReadRequiredRepeated) { - NodePtr column = GroupNode::Make( + NodePtr column = GroupNode::make( "p", - Repetition::REQUIRED, - {GroupNode::Make( + Repetition::kRequired, + {GroupNode::make( "list", - Repetition::REPEATED, - {PrimitiveNode::Make( - "element", Repetition::REQUIRED, ParquetType::INT32)})}); - Init(column); + Repetition::kRepeated, + {PrimitiveNode::make( + "element", Repetition::kRequired, ParquetType::kInt32)})}); + init(column); // Records look like: {[10], [20, 20], [30, 30, 30]} std::vector> pages; std::vector values = {10, 20, 20, 30, 30, 30}; - std::vector def_levels = {1, 1, 1, 1, 1, 1}; - std::vector rep_levels = {0, 0, 1, 0, 1, 1}; + std::vector defLevels = {1, 1, 1, 1, 1, 1}; + std::vector repLevels = {0, 0, 1, 0, 1, 1}; - std::shared_ptr page = MakeDataPage( + std::shared_ptr page = makeDataPage( descr_, values, - /*num_vals=*/static_cast(def_levels.size()), - Encoding::PLAIN, - /*indices=*/{}, - /*indices_size=*/0, - def_levels, - descr_->max_definition_level(), - rep_levels, - descr_->max_repetition_level()); + static_cast(defLevels.size()), + Encoding::kPlain, + {}, + 0, + defLevels, + descr_->maxDefinitionLevel(), + repLevels, + descr_->maxRepetitionLevel()); pages.push_back(std::move(page)); auto pager = std::make_unique(pages); - record_reader_->SetPageReader(std::move(pager)); - - // Read [10] - int64_t records_read = record_reader_->ReadRecords(/*num_records=*/1); - ASSERT_EQ(records_read, 1); - CheckState( - /*values_written=*/1, - /*null_count=*/0, - /*levels_written=*/6, - /*levels_position=*/1); - CheckReadValues( - /*expected_values=*/{10}, - /*expected_defs=*/{1}, - /*expected_reps=*/{0}); - record_reader_->Reset(); - CheckState( - /*values_written=*/0, - /*null_count=*/0, - /*levels_written=*/5, - /*levels_position=*/0); - - // Read [20, 20], [30, 30, 30] - records_read = record_reader_->ReadRecords(/*num_records=*/3); - ASSERT_EQ(records_read, 2); - CheckState( - /*values_written=*/5, - /*null_count=*/0, - /*levels_written=*/5, - /*levels_position=*/5); - CheckReadValues( - /*expected_values=*/{20, 20, 30, 30, 30}, - /*expected_defs=*/{1, 1, 1, 1, 1}, - /*expected_reps=*/{0, 1, 0, 1, 1}); - record_reader_->Reset(); - CheckState( - /*values_written=*/0, - /*null_count=*/0, - /*levels_written=*/0, - /*levels_position=*/0); + recordReader_->setPageReader(std::move(pager)); + + // Read [10]. + int64_t recordsRead = recordReader_->readRecords(1); + ASSERT_EQ(recordsRead, 1); + checkState(1, 0, 6, 1); + checkReadValues({10}, {1}, {0}); + recordReader_->reset(); + checkState(0, 0, 5, 0); + + // Read [20, 20], [30, 30, 30]. + recordsRead = recordReader_->readRecords(3); + ASSERT_EQ(recordsRead, 2); + checkState(5, 0, 5, 5); + checkReadValues({20, 20, 30, 30, 30}, {1, 1, 1, 1, 1}, {0, 1, 0, 1, 1}); + recordReader_->reset(); + checkState(0, 0, 0, 0); } -// Tests reading a nullable repeated field. Tests reading null values at -// differnet levels and reading an empty list. +// Tests reading a nullable repeated field. Tests reading null values at. +// Differnet levels and reading an empty list. TEST_P(RecordReaderPrimitiveTypeTest, ReadNullableRepeated) { - NodePtr column = GroupNode::Make( + NodePtr column = GroupNode::make( "p", - Repetition::OPTIONAL, - {GroupNode::Make( + Repetition::kOptional, + {GroupNode::make( "list", - Repetition::REPEATED, - {PrimitiveNode::Make( - "element", Repetition::OPTIONAL, ParquetType::INT32)})}); - Init(column); + Repetition::kRepeated, + {PrimitiveNode::make( + "element", Repetition::kOptional, ParquetType::kInt32)})}); + init(column); // Records look like: {[10], null, [20, 20], [], [30, 30, null, 30]} - // Some explanation regarding the behavior. When reading spaced, for an empty - // list or for a top-level null, we do not leave a space and we do not count - // it towards null_count. For a leaf-level null, we leave a space for it and - // we count it towards null_count. When reading dense, null_count is always 0, - // and we do not leave any space for values. + // Some explanation regarding the behavior. When reading spaced, for an empty. + // List or for a top-level null, we do not leave a space and we do not count. + // It towards null_count. For a leaf-level null, we leave a space for it and. + // We count it towards null_count. When reading dense, null_count is always + // 0,. And we do not leave any space for values. std::vector> pages; std::vector values = {10, 20, 20, 30, 30, 30}; - std::vector def_levels = {3, 0, 3, 3, 1, 3, 3, 2, 3}; - std::vector rep_levels = {0, 0, 0, 1, 0, 0, 1, 1, 1}; + std::vector defLevels = {3, 0, 3, 3, 1, 3, 3, 2, 3}; + std::vector repLevels = {0, 0, 0, 1, 0, 0, 1, 1, 1}; - std::shared_ptr page = MakeDataPage( + std::shared_ptr page = makeDataPage( descr_, values, - /*num_vals=*/static_cast(def_levels.size()), - Encoding::PLAIN, - /*indices=*/{}, - /*indices_size=*/0, - def_levels, - descr_->max_definition_level(), - rep_levels, - descr_->max_repetition_level()); + static_cast(defLevels.size()), + Encoding::kPlain, + {}, + 0, + defLevels, + descr_->maxDefinitionLevel(), + repLevels, + descr_->maxRepetitionLevel()); pages.push_back(std::move(page)); auto pager = std::make_unique(pages); - record_reader_->SetPageReader(std::move(pager)); + recordReader_->setPageReader(std::move(pager)); // Test reading 0 records. - int64_t records_read = record_reader_->ReadRecords(/*num_records=*/0); - ASSERT_EQ(records_read, 0); + int64_t recordsRead = recordReader_->readRecords(0); + ASSERT_EQ(recordsRead, 0); // Test the descr() accessor. - ASSERT_EQ(record_reader_->descr()->max_definition_level(), 3); + ASSERT_EQ(recordReader_->descr()->maxDefinitionLevel(), 3); - // Read [10], null + // Read [10], null. // We do not read this null for both reading dense and spaced. - records_read = record_reader_->ReadRecords(/*num_records=*/2); - ASSERT_EQ(records_read, 2); - if (GetParam() == /*read_dense_for_nullable=*/true) { - CheckState( - /*values_written=*/1, - /*null_count=*/0, - /*levels_written=*/9, - /*levels_position=*/2); - CheckReadValues( - /*expected_values=*/{10}, - /*expected_defs=*/{3, 0}, - /*expected_reps=*/{0, 0}); + recordsRead = recordReader_->readRecords(2); + ASSERT_EQ(recordsRead, 2); + if (GetParam() == true) { + checkState(1, 0, 9, 2); + checkReadValues({10}, {3, 0}, {0, 0}); } else { - CheckState( - /*values_written=*/1, - /*null_count=*/0, - /*levels_written=*/9, - /*levels_position=*/2); - CheckReadValues( - /*expected_values=*/{10}, - /*expected_defs=*/{3, 0}, - /*expected_reps=*/{0, 0}); + checkState(1, 0, 9, 2); + checkReadValues({10}, {3, 0}, {0, 0}); } - record_reader_->Reset(); - CheckState( - /*values_written=*/0, - /*null_count=*/0, - /*levels_written=*/7, - /*levels_position=*/0); - - // Read [20, 20], [] - // We do not read any value for this, it will be counted towards null count - // when reading spaced. - records_read = record_reader_->ReadRecords(/*num_records=*/2); - ASSERT_EQ(records_read, 2); - if (GetParam() == /*read_dense_for_nullable=*/true) { - CheckState( - /*values_written=*/2, - /*null_count=*/0, - /*levels_written=*/7, - /*levels_position=*/3); - CheckReadValues( - /*expected_values=*/{20, 20}, - /*expected_defs=*/{3, 3, 1}, - /*expected_reps=*/{0, 1, 0}); + recordReader_->reset(); + checkState(0, 0, 7, 0); + + // Read [20, 20], []. + // We do not read any value for this, it will be counted towards null count. + // When reading spaced. + recordsRead = recordReader_->readRecords(2); + ASSERT_EQ(recordsRead, 2); + if (GetParam() == true) { + checkState(2, 0, 7, 3); + checkReadValues({20, 20}, {3, 3, 1}, {0, 1, 0}); } else { - CheckState( - /*values_written=*/2, - /*null_count=*/0, - /*levels_written=*/7, - /*levels_position=*/3); - CheckReadValues( - /*expected_values=*/{20, 20}, - /*expected_defs=*/{3, 3, 1}, - /*expected_reps=*/{0, 1, 0}); + checkState(2, 0, 7, 3); + checkReadValues({20, 20}, {3, 3, 1}, {0, 1, 0}); } - record_reader_->Reset(); - CheckState( - /*values_written=*/0, - /*null_count=*/0, - /*levels_written=*/4, - /*levels_position=*/0); + recordReader_->reset(); + checkState(0, 0, 4, 0); // Test reading 0 records. - records_read = record_reader_->ReadRecords(/*num_records=*/0); - ASSERT_EQ(records_read, 0); + recordsRead = recordReader_->readRecords(0); + ASSERT_EQ(recordsRead, 0); // Read the last record. - records_read = record_reader_->ReadRecords(/*num_records=*/1); - ASSERT_EQ(records_read, 1); - if (GetParam() == /*read_dense_for_nullable=*/true) { - CheckState( - /*values_written=*/3, - /*null_count=*/0, - /*levels_written=*/4, - /*levels_position=*/4); - CheckReadValues( - /*expected_values=*/{30, 30, 30}, - /*expected_defs=*/{3, 3, 2, 3}, - /*expected_reps=*/{0, 1, 1, 1}); + recordsRead = recordReader_->readRecords(1); + ASSERT_EQ(recordsRead, 1); + if (GetParam() == true) { + checkState(3, 0, 4, 4); + checkReadValues({30, 30, 30}, {3, 3, 2, 3}, {0, 1, 1, 1}); } else { - CheckState( - /*values_written=*/4, - /*null_count=*/1, - /*levels_written=*/4, - /*levels_position=*/4); - CheckReadValues( - /*expected_values=*/{30, 30, kNullValue, 30}, - /*expected_defs=*/{3, 3, 2, 3}, - /*expected_reps=*/{0, 1, 1, 1}); + checkState(4, 1, 4, 4); + checkReadValues({30, 30, kNullValue, 30}, {3, 3, 2, 3}, {0, 1, 1, 1}); } - record_reader_->Reset(); - CheckState( - /*values_written=*/0, - /*null_count=*/0, - /*levels_written=*/0, - /*levels_position=*/0); + recordReader_->reset(); + checkState(0, 0, 0, 0); } // Test that we can skip required top level field. TEST_P(RecordReaderPrimitiveTypeTest, SkipRequiredTopLevel) { - Init(schema::Int32("b", Repetition::REQUIRED)); + init(schema::int32("b", Repetition::kRequired)); std::vector> pages; std::vector values = {10, 20, 20, 30, 30, 30}; - std::shared_ptr page = MakeDataPage( + std::shared_ptr page = makeDataPage( descr_, values, - /*num_vals=*/static_cast(values.size()), - Encoding::PLAIN, - /*indices=*/{}, - /*indices_size=*/0, - /*def_levels=*/{}, - descr_->max_definition_level(), - /*rep_levels=*/{}, - descr_->max_repetition_level()); + static_cast(values.size()), + Encoding::kPlain, + {}, + 0, + {}, + descr_->maxDefinitionLevel(), + {}, + descr_->maxRepetitionLevel()); pages.push_back(std::move(page)); auto pager = std::make_unique(pages); - record_reader_->SetPageReader(std::move(pager)); - - int64_t records_skipped = record_reader_->SkipRecords(/*num_records=*/3); - ASSERT_EQ(records_skipped, 3); - CheckState( - /*values_written=*/0, - /*null_count=*/0, - /*levels_written=*/0, - /*levels_position=*/0); - - int64_t records_read = record_reader_->ReadRecords(/*num_records=*/2); - ASSERT_EQ(records_read, 2); - CheckState( - /*values_written=*/2, - /*null_count=*/0, - /*levels_written=*/0, - /*levels_position=*/0); - CheckReadValues( - /*expected_values=*/{30, 30}, - /*expected_defs=*/{}, - /*expected_reps=*/{}); - record_reader_->Reset(); - CheckState( - /*values_written=*/0, - /*null_count=*/0, - /*levels_written=*/0, - /*levels_position=*/0); + recordReader_->setPageReader(std::move(pager)); + + int64_t recordsSkipped = recordReader_->skipRecords(3); + ASSERT_EQ(recordsSkipped, 3); + checkState(0, 0, 0, 0); + + int64_t recordsRead = recordReader_->readRecords(2); + ASSERT_EQ(recordsRead, 2); + checkState(2, 0, 0, 0); + checkReadValues({30, 30}, {}, {}); + recordReader_->reset(); + checkState(0, 0, 0, 0); } // Skip an optional field. Intentionally included some null values. TEST_P(RecordReaderPrimitiveTypeTest, SkipOptional) { - Init(schema::Int32("b", Repetition::OPTIONAL)); + init(schema::int32("b", Repetition::kOptional)); // Records look like {null, 10, 20, 30, null, 40, 50, 60} std::vector> pages; std::vector values = {10, 20, 30, 40, 50, 60}; - std::vector def_levels = {0, 1, 1, 0, 1, 1, 1, 1}; + std::vector defLevels = {0, 1, 1, 0, 1, 1, 1, 1}; - std::shared_ptr page = MakeDataPage( + std::shared_ptr page = makeDataPage( descr_, values, - /*num_vals=*/static_cast(values.size()), - Encoding::PLAIN, - /*indices=*/{}, - /*indices_size=*/0, - def_levels, - descr_->max_definition_level(), - /*rep_levels=*/{}, - descr_->max_repetition_level()); + static_cast(values.size()), + Encoding::kPlain, + {}, + 0, + defLevels, + descr_->maxDefinitionLevel(), + {}, + descr_->maxRepetitionLevel()); pages.push_back(std::move(page)); auto pager = std::make_unique(pages); - record_reader_->SetPageReader(std::move(pager)); + recordReader_->setPageReader(std::move(pager)); { // Skip {null, 10} // This also tests when we start with a Skip. - int64_t records_skipped = record_reader_->SkipRecords(/*num_records=*/2); - ASSERT_EQ(records_skipped, 2); - CheckState( - /*values_written=*/0, - /*null_count=*/0, - /*levels_written=*/0, - /*levels_position=*/0); + int64_t recordsSkipped = recordReader_->skipRecords(2); + ASSERT_EQ(recordsSkipped, 2); + checkState(0, 0, 0, 0); } { // Read 3 records: {20, null, 30} - int64_t records_read = record_reader_->ReadRecords(/*num_records=*/3); - - ASSERT_EQ(records_read, 3); - if (GetParam() == /*read_dense_for_nullable=*/true) { - // We had skipped 2 of the levels above. So there is only 6 left in total - // to read, and we read 3 of them here. - CheckState( - /*values_written=*/2, - /*null_count=*/0, - /*levels_written=*/6, - /*levels_position=*/3); - CheckReadValues( - /*expected_values=*/{20, 30}, - /*expected_defs=*/{1, 0, 1}, - /*expected_reps=*/{}); + int64_t recordsRead = recordReader_->readRecords(3); + + ASSERT_EQ(recordsRead, 3); + if (GetParam() == true) { + // We had skipped 2 of the levels above. So there is only 6 left in total. + // To read, and we read 3 of them here. + checkState(2, 0, 6, 3); + checkReadValues({20, 30}, {1, 0, 1}, {}); } else { - CheckState( - /*values_written=*/3, - /*null_count=*/1, - /*levels_written=*/6, - /*levels_position=*/3); - CheckReadValues( - /*expected_values=*/{20, kNullValue, 30}, - /*expected_defs=*/{1, 0, 1}, - /*expected_reps=*/{}); + checkState(3, 1, 6, 3); + checkReadValues({20, kNullValue, 30}, {1, 0, 1}, {}); } - record_reader_->Reset(); + recordReader_->reset(); } { // Skip {40, 50}. - int64_t records_skipped = record_reader_->SkipRecords(/*num_records=*/2); - ASSERT_EQ(records_skipped, 2); - CheckState( - /*values_written=*/0, - /*null_count=*/0, - /*levels_written=*/1, - /*levels_position=*/0); - CheckReadValues( - /*expected_values=*/{}, - /*expected_defs=*/{}, - /*expected_reps=*/{}); + int64_t recordsSkipped = recordReader_->skipRecords(2); + ASSERT_EQ(recordsSkipped, 2); + checkState(0, 0, 1, 0); + checkReadValues({}, {}, {}); // Try reset after a Skip. - record_reader_->Reset(); - CheckState( - /*values_written=*/0, - /*null_count=*/0, - /*levels_written=*/1, - /*levels_position=*/0); + recordReader_->reset(); + checkState(0, 0, 1, 0); } { // Read to the end of the column. Read {60} - // This test checks that ReadAndThrowAwayValues works, since if it - // does not we would read the wrong values. - int64_t records_read = record_reader_->ReadRecords(/*num_records=*/1); - - ASSERT_EQ(records_read, 1); - CheckState( - /*values_written=*/1, - /*null_count=*/0, - /*levels_written=*/1, - /*levels_position=*/1); - CheckReadValues( - /*expected_values=*/{60}, - /*expected_defs=*/{1}, - /*expected_reps=*/{}); + // This test checks that ReadAndThrowAwayValues works, since if it. + // Does not we would read the wrong values. + int64_t recordsRead = recordReader_->readRecords(1); + + ASSERT_EQ(recordsRead, 1); + checkState(1, 0, 1, 1); + checkReadValues({60}, {1}, {}); } // We have exhausted all the records. - ASSERT_EQ(record_reader_->ReadRecords(/*num_records=*/3), 0); - ASSERT_EQ(record_reader_->SkipRecords(/*num_records=*/3), 0); + ASSERT_EQ(recordReader_->readRecords(3), 0); + ASSERT_EQ(recordReader_->skipRecords(3), 0); } // Test skipping for repeated fields. TEST_P(RecordReaderPrimitiveTypeTest, SkipRepeated) { - Init(schema::Int32("b", Repetition::REPEATED)); + init(schema::int32("b", Repetition::kRepeated)); // Records look like {null, [20, 20, 20], null, [30, 30], [40]} std::vector> pages; std::vector values = {20, 20, 20, 30, 30, 40}; - std::vector def_levels = {0, 1, 1, 1, 0, 1, 1, 1}; - std::vector rep_levels = {0, 0, 1, 1, 0, 0, 1, 0}; + std::vector defLevels = {0, 1, 1, 1, 0, 1, 1, 1}; + std::vector repLevels = {0, 0, 1, 1, 0, 0, 1, 0}; - std::shared_ptr page = MakeDataPage( + std::shared_ptr page = makeDataPage( descr_, values, - /*num_vals=*/static_cast(values.size()), - Encoding::PLAIN, - /*indices=*/{}, - /*indices_size=*/0, - def_levels, - descr_->max_definition_level(), - rep_levels, - descr_->max_repetition_level()); + static_cast(values.size()), + Encoding::kPlain, + {}, + 0, + defLevels, + descr_->maxDefinitionLevel(), + repLevels, + descr_->maxRepetitionLevel()); pages.push_back(std::move(page)); auto pager = std::make_unique(pages); - record_reader_->SetPageReader(std::move(pager)); + recordReader_->setPageReader(std::move(pager)); { // Skip 0 records. - int64_t records_skipped = record_reader_->SkipRecords(/*num_records=*/0); - ASSERT_EQ(records_skipped, 0); + int64_t recordsSkipped = recordReader_->skipRecords(0); + ASSERT_EQ(recordsSkipped, 0); } { // This should skip the first null record. - int64_t records_skipped = record_reader_->SkipRecords(/*num_records=*/1); - ASSERT_EQ(records_skipped, 1); - ASSERT_EQ(record_reader_->values_written(), 0); - ASSERT_EQ(record_reader_->null_count(), 0); - // For repeated fields, we need to read the levels to find the record - // boundaries and skip. So some levels are read, however, the skipped + int64_t recordsSkipped = recordReader_->skipRecords(1); + ASSERT_EQ(recordsSkipped, 1); + ASSERT_EQ(recordReader_->valuesWritten(), 0); + ASSERT_EQ(recordReader_->nullCount(), 0); + // For repeated fields, we need to read the levels to find the record. + // Boundaries and skip. So some levels are read, however, the skipped. // level should not be there after the skip. That's why levels_position() - // is 0. - CheckState( - /*values_written=*/0, - /*null_count=*/0, - /*levels_written=*/7, - /*levels_position=*/0); - CheckReadValues( - /*expected_values=*/{}, - /*expected_defs=*/{}, - /*expected_reps=*/{}); + // Is 0. + checkState(0, 0, 7, 0); + checkReadValues({}, {}, {}); } { - // Read [20, 20, 20] - int64_t records_read = record_reader_->ReadRecords(/*num_records=*/1); - ASSERT_EQ(records_read, 1); - CheckState( - /*values_written=*/3, - /*null_count=*/0, - /*levels_written=*/7, - /*levels_position=*/3); - CheckReadValues( - /*expected_values=*/{20, 20, 20}, - /*expected_defs=*/{1, 1, 1}, - /*expected_reps=*/{0, 1, 1}); + // Read [20, 20, 20]. + int64_t recordsRead = recordReader_->readRecords(1); + ASSERT_EQ(recordsRead, 1); + checkState(3, 0, 7, 3); + checkReadValues({20, 20, 20}, {1, 1, 1}, {0, 1, 1}); } { // Skip 0 records. - int64_t records_skipped = record_reader_->SkipRecords(/*num_records=*/0); - ASSERT_EQ(records_skipped, 0); + int64_t recordsSkipped = recordReader_->skipRecords(0); + ASSERT_EQ(recordsSkipped, 0); } { - // Skip the null record and also skip [30, 30] - int64_t records_skipped = record_reader_->SkipRecords(/*num_records=*/2); - ASSERT_EQ(records_skipped, 2); + // Skip the null record and also skip [30, 30]. + int64_t recordsSkipped = recordReader_->skipRecords(2); + ASSERT_EQ(recordsSkipped, 2); // We remove the skipped levels from the buffer. - CheckState( - /*values_written=*/3, - /*null_count=*/0, - /*levels_written=*/4, - /*levels_position=*/3); - CheckReadValues( - /*expected_values=*/{20, 20, 20}, - /*expected_defs=*/{1, 1, 1}, - /*expected_reps=*/{0, 1, 1}); + checkState(3, 0, 4, 3); + checkReadValues({20, 20, 20}, {1, 1, 1}, {0, 1, 1}); } { - // Read [40] - int64_t records_read = record_reader_->ReadRecords(/*num_records=*/1); - ASSERT_EQ(records_read, 1); - CheckState( - /*values_written=*/4, - /*null_count=*/0, - /*levels_written=*/4, - /*levels_position=*/4); - CheckReadValues( - /*expected_values=*/{20, 20, 20, 40}, - /*expected_defs=*/{1, 1, 1, 1}, - /*expected_reps=*/{0, 1, 1, 0}); + // Read [40]. + int64_t recordsRead = recordReader_->readRecords(1); + ASSERT_EQ(recordsRead, 1); + checkState(4, 0, 4, 4); + checkReadValues({20, 20, 20, 40}, {1, 1, 1, 1}, {0, 1, 1, 0}); } } -// Tests that for repeated fields, we first consume what is in the buffer -// before reading more levels. +// Tests that for repeated fields, we first consume what is in the buffer. +// Before reading more levels. TEST_P(RecordReaderPrimitiveTypeTest, SkipRepeatedConsumeBufferFirst) { - Init(schema::Int32("b", Repetition::REPEATED)); + init(schema::int32("b", Repetition::kRepeated)); std::vector> pages; std::vector values(2048, 10); - std::vector def_levels(2048, 1); - std::vector rep_levels(2048, 0); + std::vector defLevels(2048, 1); + std::vector repLevels(2048, 0); - std::shared_ptr page = MakeDataPage( + std::shared_ptr page = makeDataPage( descr_, values, - /*num_vals=*/static_cast(values.size()), - Encoding::PLAIN, - /*indices=*/{}, - /*indices_size=*/0, - def_levels, - descr_->max_definition_level(), - rep_levels, - descr_->max_repetition_level()); + static_cast(values.size()), + Encoding::kPlain, + {}, + 0, + defLevels, + descr_->maxDefinitionLevel(), + repLevels, + descr_->maxRepetitionLevel()); pages.push_back(std::move(page)); auto pager = std::make_unique(pages); - record_reader_->SetPageReader(std::move(pager)); + recordReader_->setPageReader(std::move(pager)); { - // Read 1000 records. We will read 1024 levels because that is the minimum - // number of levels to read. - int64_t records_read = record_reader_->ReadRecords(/*num_records=*/1000); - ASSERT_EQ(records_read, 1000); - CheckState( - /*values_written=*/1000, - /*null_count=*/0, - /*levels_written=*/1024, - /*levels_position=*/1000); - std::vector expected_values(1000, 10); - std::vector expected_def_levels(1000, 1); - std::vector expected_rep_levels(1000, 0); - CheckReadValues(expected_values, expected_def_levels, expected_rep_levels); + // Read 1000 records. We will read 1024 levels because that is the minimum. + // Number of levels to read. + int64_t recordsRead = recordReader_->readRecords(1000); + ASSERT_EQ(recordsRead, 1000); + checkState(1000, 0, 1024, 1000); + std::vector expectedValues(1000, 10); + std::vector expectedDefLevels(1000, 1); + std::vector expectedRepLevels(1000, 0); + checkReadValues(expectedValues, expectedDefLevels, expectedRepLevels); // Reset removes the already consumed values and levels. - record_reader_->Reset(); + recordReader_->reset(); } { // Skip 12 records. Since we already have 24 in the buffer, we should not be - // reading any more levels into the buffer, we will just consume 12 of it. - int64_t records_skipped = record_reader_->SkipRecords(/*num_records=*/12); - ASSERT_EQ(records_skipped, 12); - CheckState( - /*values_written=*/0, - /*null_count=*/0, - /*levels_written=*/12, - /*levels_position=*/0); + // Reading any more levels into the buffer, we will just consume 12 of it. + int64_t recordsSkipped = recordReader_->skipRecords(12); + ASSERT_EQ(recordsSkipped, 12); + checkState(0, 0, 12, 0); // Everthing is empty because we reset the reader before this skip. - CheckReadValues( - /*expected_values=*/{}, - /*expected_def_levels=*/{}, - /*expected_rep_levels=*/{}); + checkReadValues({}, {}, {}); } } // Test reading when one record spans multiple pages for a repeated field. TEST_P(RecordReaderPrimitiveTypeTest, ReadPartialRecord) { - Init(schema::Int32("b", Repetition::REPEATED)); + init(schema::int32("b", Repetition::kRepeated)); std::vector> pages; // Page 1: {[10], [20, 20, 20 ... } continues to next page. { - std::shared_ptr page = MakeDataPage( + std::shared_ptr page = makeDataPage( descr_, - /*values=*/{10, 20, 20, 20}, - /*num_vals=*/4, - Encoding::PLAIN, - /*indices=*/{}, - /*indices_size=*/0, - /*def_levels=*/{1, 1, 1, 1}, - descr_->max_definition_level(), - /*rep_levels=*/{0, 0, 1, 1}, - descr_->max_repetition_level()); + {10, 20, 20, 20}, + 4, + Encoding::kPlain, + {}, + 0, + {1, 1, 1, 1}, + descr_->maxDefinitionLevel(), + {0, 0, 1, 1}, + descr_->maxRepetitionLevel()); pages.push_back(std::move(page)); } // Page 2: {... 20, 20, ...} continues from previous page and to next page. { - std::shared_ptr page = MakeDataPage( + std::shared_ptr page = makeDataPage( descr_, - /*values=*/{20, 20}, - /*num_vals=*/2, - Encoding::PLAIN, - /*indices=*/{}, - /*indices_size=*/0, - /*def_levels=*/{1, 1}, - descr_->max_definition_level(), - /*rep_levels=*/{1, 1}, - descr_->max_repetition_level()); + {20, 20}, + 2, + Encoding::kPlain, + {}, + 0, + {1, 1}, + descr_->maxDefinitionLevel(), + {1, 1}, + descr_->maxRepetitionLevel()); pages.push_back(std::move(page)); } // Page 3: { ... 20], [30]} continues from previous page. { - std::shared_ptr page = MakeDataPage( + std::shared_ptr page = makeDataPage( descr_, - /*values=*/{20, 30}, - /*num_vals=*/2, - Encoding::PLAIN, - /*indices=*/{}, - /*indices_size=*/0, - /*def_levels=*/{1, 1}, - descr_->max_definition_level(), - /*rep_levels=*/{1, 0}, - descr_->max_repetition_level()); + {20, 30}, + 2, + Encoding::kPlain, + {}, + 0, + {1, 1}, + descr_->maxDefinitionLevel(), + {1, 0}, + descr_->maxRepetitionLevel()); pages.push_back(std::move(page)); } auto pager = std::make_unique(pages); - record_reader_->SetPageReader(std::move(pager)); + recordReader_->setPageReader(std::move(pager)); { - // Read [10] - int64_t records_read = record_reader_->ReadRecords(/*num_records=*/1); - ASSERT_EQ(records_read, 1); - CheckState( - /*values_written=*/1, - /*null_count=*/0, - /*levels_written=*/4, - /*levels_position=*/1); - CheckReadValues( - /*expected_values=*/{10}, - /*expected_defs=*/{1}, - /*expected_reps=*/{0}); + // Read [10]. + int64_t recordsRead = recordReader_->readRecords(1); + ASSERT_EQ(recordsRead, 1); + checkState(1, 0, 4, 1); + checkReadValues({10}, {1}, {0}); } { // Read [20, 20, 20, 20, 20, 20] that spans multiple pages. - int64_t records_read = record_reader_->ReadRecords(/*num_records=*/1); - ASSERT_EQ(records_read, 1); - CheckState( - /*values_written=*/7, - /*null_count=*/0, - /*levels_written=*/8, - /*levels_position=*/7); - CheckReadValues( - /*expected_values=*/{10, 20, 20, 20, 20, 20, 20}, - /*expected_defs=*/{1, 1, 1, 1, 1, 1, 1}, - /*expected_reps=*/{0, 0, 1, 1, 1, 1, 1}); + int64_t recordsRead = recordReader_->readRecords(1); + ASSERT_EQ(recordsRead, 1); + checkState(7, 0, 8, 7); + checkReadValues( + {10, 20, 20, 20, 20, 20, 20}, + {1, 1, 1, 1, 1, 1, 1}, + {0, 0, 1, 1, 1, 1, 1}); } { - // Read [30] - int64_t records_read = record_reader_->ReadRecords(/*num_records=*/1); - ASSERT_EQ(records_read, 1); - CheckState( - /*values_written=*/8, - /*null_count=*/0, - /*levels_written=*/8, - /*levels_position=*/8); - CheckReadValues( - /*expected_values=*/{10, 20, 20, 20, 20, 20, 20, 30}, - /*expected_defs=*/{1, 1, 1, 1, 1, 1, 1, 1}, - /*expected_reps=*/{0, 0, 1, 1, 1, 1, 1, 0}); + // Read [30]. + int64_t recordsRead = recordReader_->readRecords(1); + ASSERT_EQ(recordsRead, 1); + checkState(8, 0, 8, 8); + checkReadValues( + {10, 20, 20, 20, 20, 20, 20, 30}, + {1, 1, 1, 1, 1, 1, 1, 1}, + {0, 0, 1, 1, 1, 1, 1, 0}); } } -// Test skipping for repeated fields for the case when one record spans multiple -// pages. +// Test skipping for repeated fields for the case when one record spans +// multiple. Pages. TEST_P(RecordReaderPrimitiveTypeTest, SkipPartialRecord) { - Init(schema::Int32("b", Repetition::REPEATED)); + init(schema::int32("b", Repetition::kRepeated)); std::vector> pages; // Page 1: {[10], [20, 20, 20 ... } continues to next page. { - std::shared_ptr page = MakeDataPage( + std::shared_ptr page = makeDataPage( descr_, - /*values=*/{10, 20, 20, 20}, - /*num_vals=*/4, - Encoding::PLAIN, - /*indices=*/{}, - /*indices_size=*/0, - /*def_levels=*/{1, 1, 1, 1}, - descr_->max_definition_level(), - /*rep_levels=*/{0, 0, 1, 1}, - descr_->max_repetition_level()); + {10, 20, 20, 20}, + 4, + Encoding::kPlain, + {}, + 0, + {1, 1, 1, 1}, + descr_->maxDefinitionLevel(), + {0, 0, 1, 1}, + descr_->maxRepetitionLevel()); pages.push_back(std::move(page)); } // Page 2: {... 20, 20, ...} continues from previous page and to next page. { - std::shared_ptr page = MakeDataPage( + std::shared_ptr page = makeDataPage( descr_, - /*values=*/{20, 20}, - /*num_vals=*/2, - Encoding::PLAIN, - /*indices=*/{}, - /*indices_size=*/0, - /*def_levels=*/{1, 1}, - descr_->max_definition_level(), - /*rep_levels=*/{1, 1}, - descr_->max_repetition_level()); + {20, 20}, + 2, + Encoding::kPlain, + {}, + 0, + {1, 1}, + descr_->maxDefinitionLevel(), + {1, 1}, + descr_->maxRepetitionLevel()); pages.push_back(std::move(page)); } // Page 3: { ... 20, [30]} continues from previous page. { - std::shared_ptr page = MakeDataPage( + std::shared_ptr page = makeDataPage( descr_, - /*values=*/{20, 30}, - /*num_vals=*/2, - Encoding::PLAIN, - /*indices=*/{}, - /*indices_size=*/0, - /*def_levels=*/{1, 1}, - descr_->max_definition_level(), - /*rep_levels=*/{1, 0}, - descr_->max_repetition_level()); + {20, 30}, + 2, + Encoding::kPlain, + {}, + 0, + {1, 1}, + descr_->maxDefinitionLevel(), + {1, 0}, + descr_->maxRepetitionLevel()); pages.push_back(std::move(page)); } auto pager = std::make_unique(pages); - record_reader_->SetPageReader(std::move(pager)); + recordReader_->setPageReader(std::move(pager)); { - // Read [10] - int64_t records_read = record_reader_->ReadRecords(/*num_records=*/1); - ASSERT_EQ(records_read, 1); + // Read [10]. + int64_t recordsRead = recordReader_->readRecords(1); + ASSERT_EQ(recordsRead, 1); // There are 4 levels in the first page. - CheckState( - /*values_written=*/1, - /*null_count=*/0, - /*levels_written=*/4, - /*levels_position=*/1); - CheckReadValues( - /*expected_values=*/{10}, - /*expected_defs=*/{1}, - /*expected_reps=*/{0}); + checkState(1, 0, 4, 1); + checkReadValues({10}, {1}, {0}); } { // Skip the record that goes across pages. - int64_t records_skipped = record_reader_->SkipRecords(/*num_records=*/1); - ASSERT_EQ(records_skipped, 1); - CheckState( - /*values_written=*/1, - /*null_count=*/0, - /*levels_written=*/2, - /*levels_position=*/1); - CheckReadValues( - /*expected_values=*/{10}, - /*expected_defs=*/{1}, - /*expected_reps=*/{0}); + int64_t recordsSkipped = recordReader_->skipRecords(1); + ASSERT_EQ(recordsSkipped, 1); + checkState(1, 0, 2, 1); + checkReadValues({10}, {1}, {0}); } { - // Read [30] - int64_t records_read = record_reader_->ReadRecords(/*num_records=*/1); - - ASSERT_EQ(records_read, 1); - CheckState( - /*values_written=*/2, - /*null_count=*/0, - /*levels_written=*/2, - /*levels_position=*/2); - CheckReadValues( - /*expected_values=*/{10, 30}, - /*expected_defs=*/{1, 1}, - /*expected_reps=*/{0, 0}); + // Read [30]. + int64_t recordsRead = recordReader_->readRecords(1); + + ASSERT_EQ(recordsRead, 1); + checkState(2, 0, 2, 2); + checkReadValues({10, 30}, {1, 1}, {0, 0}); } } INSTANTIATE_TEST_SUITE_P( RecordReaderPrimitveTypeTests, RecordReaderPrimitiveTypeTest, - ::testing::Values(/*read_dense_for_nullable=*/true, false), + ::testing::Values(true, false), testing::PrintToStringParamName()); // Parameterized test for FLBA record reader. class FLBARecordReaderTest : public ::testing::TestWithParam { public: - bool read_dense_for_nullable() { + bool readDenseForNullable() { return GetParam(); } - void - MakeRecordReader(int levels_per_page, int num_pages, int FLBA_type_length) { - levels_per_page_ = levels_per_page; - FLBA_type_length_ = FLBA_type_length; - LevelInfo level_info; - level_info.defLevel = 1; - level_info.repLevel = 0; - NodePtr type = schema::PrimitiveNode::Make( + void makeRecordReader(int levelsPerPage, int numPages, int flbaTypeLength) { + levelsPerPage_ = levelsPerPage; + flbaTypeLength_ = flbaTypeLength; + LevelInfo levelInfo; + levelInfo.defLevel = 1; + levelInfo.repLevel = 0; + NodePtr type = schema::PrimitiveNode::make( "b", - Repetition::OPTIONAL, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::NONE, - FLBA_type_length_); + Repetition::kOptional, + Type::kFixedLenByteArray, + ConvertedType::kNone, + flbaTypeLength_); descr_ = std::make_unique( - type, level_info.defLevel, level_info.repLevel); - MakePages( + type, levelInfo.defLevel, levelInfo.repLevel); + makePages( descr_.get(), - num_pages, - levels_per_page, - def_levels_, - rep_levels_, + numPages, + levelsPerPage, + defLevels_, + repLevels_, values_, buffer_, pages_, - Encoding::PLAIN); + Encoding::kPlain); auto pager = std::make_unique(pages_); - record_reader_ = internal::RecordReader::Make( + recordReader_ = internal::RecordReader::make( descr_.get(), - level_info, + levelInfo, ::arrow::default_memory_pool(), - /*read_dictionary=*/false, - read_dense_for_nullable()); - record_reader_->SetPageReader(std::move(pager)); + false, + readDenseForNullable()); + recordReader_->setPageReader(std::move(pager)); } // Returns expected values in row range. // We need this since some values are null. - std::vector expected_values(int start, int end) { + std::vector expectedValues(int start, int end) { std::vector result; // Find out where in the values_ vector we start from. - size_t values_index = 0; + size_t valuesIndex = 0; for (int i = 0; i < start; ++i) { - if (def_levels_[i] != 0) { - ++values_index; + if (defLevels_[i] != 0) { + ++valuesIndex; } } for (int i = start; i < end; ++i) { - if (def_levels_[i] == 0) { - if (!read_dense_for_nullable()) { + if (defLevels_[i] == 0) { + if (!readDenseForNullable()) { result.emplace_back(); } continue; } result.emplace_back( - reinterpret_cast(values_[values_index].ptr), - FLBA_type_length_); - ++values_index; + reinterpret_cast(values_[valuesIndex].ptr), + flbaTypeLength_); + ++valuesIndex; } return result; } - void CheckReadValues(int start, int end) { - auto binary_reader = - dynamic_cast(record_reader_.get()); - ASSERT_NE(binary_reader, nullptr); + void checkReadValues(int start, int end) { + auto binaryReader = dynamic_cast(recordReader_.get()); + ASSERT_NE(binaryReader, nullptr); // Chunks are reset after this call. - ::arrow::ArrayVector array_vector = binary_reader->GetBuilderChunks(); - ASSERT_EQ(array_vector.size(), 1); - auto binary_array = - dynamic_cast<::arrow::FixedSizeBinaryArray*>(array_vector[0].get()); - - ASSERT_NE(binary_array, nullptr); - ASSERT_EQ(binary_array->length(), record_reader_->values_written()); - if (read_dense_for_nullable()) { - ASSERT_EQ(binary_array->null_count(), 0); - ASSERT_EQ(record_reader_->null_count(), 0); + ::arrow::ArrayVector arrayVector = binaryReader->getBuilderChunks(); + ASSERT_EQ(arrayVector.size(), 1); + auto binaryArray = + dynamic_cast<::arrow::FixedSizeBinaryArray*>(arrayVector[0].get()); + + ASSERT_NE(binaryArray, nullptr); + ASSERT_EQ(binaryArray->length(), recordReader_->valuesWritten()); + if (readDenseForNullable()) { + ASSERT_EQ(binaryArray->null_count(), 0); + ASSERT_EQ(recordReader_->nullCount(), 0); } else { - ASSERT_EQ(binary_array->null_count(), record_reader_->null_count()); + ASSERT_EQ(binaryArray->null_count(), recordReader_->nullCount()); } - std::vector expected = expected_values(start, end); + std::vector expected = expectedValues(start, end); for (size_t i = 0; i < expected.size(); ++i) { - if (def_levels_[i + start] == 0) { - ASSERT_EQ(!read_dense_for_nullable(), binary_array->IsNull(i)); + if (defLevels_[i + start] == 0) { + ASSERT_EQ(!readDenseForNullable(), binaryArray->IsNull(i)); } else { - ASSERT_EQ(expected[i].compare(binary_array->GetView(i)), 0); - ASSERT_FALSE(binary_array->IsNull(i)); + ASSERT_EQ(expected[i].compare(binaryArray->GetView(i)), 0); + ASSERT_FALSE(binaryArray->IsNull(i)); } } } protected: - std::shared_ptr record_reader_; + std::shared_ptr recordReader_; private: - int levels_per_page_; - int FLBA_type_length_; + int levelsPerPage_; + int flbaTypeLength_; std::vector> pages_; - std::vector def_levels_; - std::vector rep_levels_; + std::vector defLevels_; + std::vector repLevels_; std::vector values_; std::vector buffer_; std::unique_ptr descr_; }; -// Similar to above, except for Byte arrays. FLBA and Byte arrays are -// sufficiently different to warrant a separate class for readability. +// Similar to above, except for Byte arrays. FLBA and Byte arrays are. +// Sufficiently different to warrant a separate class for readability. class ByteArrayRecordReaderTest : public ::testing::TestWithParam { public: - bool read_dense_for_nullable() { + bool readDenseForNullable() { return GetParam(); } - void MakeRecordReader(int levels_per_page, int num_pages) { - levels_per_page_ = levels_per_page; - LevelInfo level_info; - level_info.defLevel = 1; - level_info.repLevel = 0; - NodePtr type = schema::ByteArray("b", Repetition::OPTIONAL); + void makeRecordReader(int levelsPerPage, int numPages) { + levelsPerPage_ = levelsPerPage; + LevelInfo levelInfo; + levelInfo.defLevel = 1; + levelInfo.repLevel = 0; + NodePtr type = schema::byteArray("b", Repetition::kOptional); descr_ = std::make_unique( - type, level_info.defLevel, level_info.repLevel); - MakePages( + type, levelInfo.defLevel, levelInfo.repLevel); + makePages( descr_.get(), - num_pages, - levels_per_page, - def_levels_, - rep_levels_, + numPages, + levelsPerPage, + defLevels_, + repLevels_, values_, buffer_, pages_, - Encoding::PLAIN); + Encoding::kPlain); auto pager = std::make_unique(pages_); - record_reader_ = internal::RecordReader::Make( + recordReader_ = internal::RecordReader::make( descr_.get(), - level_info, + levelInfo, ::arrow::default_memory_pool(), - /*read_dictionary=*/false, - read_dense_for_nullable()); - record_reader_->SetPageReader(std::move(pager)); + false, + readDenseForNullable()); + recordReader_->setPageReader(std::move(pager)); } // Returns expected values in row range. // We need this since some values are null. - std::vector expected_values(int start, int end) { + std::vector expectedValues(int start, int end) { std::vector result; // Find out where in the values_ vector we start from. - size_t values_index = 0; + size_t valuesIndex = 0; for (int i = 0; i < start; ++i) { - if (def_levels_[i] != 0) { - ++values_index; + if (defLevels_[i] != 0) { + ++valuesIndex; } } for (int i = start; i < end; ++i) { - if (def_levels_[i] == 0) { - if (!read_dense_for_nullable()) { + if (defLevels_[i] == 0) { + if (!readDenseForNullable()) { result.emplace_back(); } continue; } result.emplace_back( - reinterpret_cast(values_[values_index].ptr), - values_[values_index].len); - ++values_index; + reinterpret_cast(values_[valuesIndex].ptr), + values_[valuesIndex].len); + ++valuesIndex; } return result; } - void CheckReadValues(int start, int end) { - auto binary_reader = - dynamic_cast(record_reader_.get()); - ASSERT_NE(binary_reader, nullptr); + void checkReadValues(int start, int end) { + auto binaryReader = dynamic_cast(recordReader_.get()); + ASSERT_NE(binaryReader, nullptr); // Chunks are reset after this call. - ::arrow::ArrayVector array_vector = binary_reader->GetBuilderChunks(); - ASSERT_EQ(array_vector.size(), 1); - ::arrow::BinaryArray* binary_array = - dynamic_cast<::arrow::BinaryArray*>(array_vector[0].get()); - - ASSERT_NE(binary_array, nullptr); - ASSERT_EQ(binary_array->length(), record_reader_->values_written()); - if (read_dense_for_nullable()) { - ASSERT_EQ(binary_array->null_count(), 0); - ASSERT_EQ(record_reader_->null_count(), 0); + ::arrow::ArrayVector arrayVector = binaryReader->getBuilderChunks(); + ASSERT_EQ(arrayVector.size(), 1); + ::arrow::BinaryArray* binaryArray = + dynamic_cast<::arrow::BinaryArray*>(arrayVector[0].get()); + + ASSERT_NE(binaryArray, nullptr); + ASSERT_EQ(binaryArray->length(), recordReader_->valuesWritten()); + if (readDenseForNullable()) { + ASSERT_EQ(binaryArray->null_count(), 0); + ASSERT_EQ(recordReader_->nullCount(), 0); } else { - ASSERT_EQ(binary_array->null_count(), record_reader_->null_count()); + ASSERT_EQ(binaryArray->null_count(), recordReader_->nullCount()); } - std::vector expected = expected_values(start, end); + std::vector expected = expectedValues(start, end); for (size_t i = 0; i < expected.size(); ++i) { - if (def_levels_[i + start] == 0) { - ASSERT_EQ(!read_dense_for_nullable(), binary_array->IsNull(i)); + if (defLevels_[i + start] == 0) { + ASSERT_EQ(!readDenseForNullable(), binaryArray->IsNull(i)); } else { - ASSERT_EQ(expected[i].compare(binary_array->GetView(i)), 0); - ASSERT_FALSE(binary_array->IsNull(i)); + ASSERT_EQ(expected[i].compare(binaryArray->GetView(i)), 0); + ASSERT_FALSE(binaryArray->IsNull(i)); } } } protected: - std::shared_ptr record_reader_; + std::shared_ptr recordReader_; private: - int levels_per_page_; + int levelsPerPage_; std::vector> pages_; - std::vector def_levels_; - std::vector rep_levels_; + std::vector defLevels_; + std::vector repLevels_; std::vector values_; std::vector buffer_; std::unique_ptr descr_; }; // Tests reading and skipping a ByteArray field. -// The binary readers only differ in DeocdeDense and DecodeSpaced functions, so -// testing optional is sufficient in excercising those code paths. +// The binary readers only differ in DeocdeDense and DecodeSpaced functions, so. +// Testing optional is sufficient in excercising those code paths. TEST_P(ByteArrayRecordReaderTest, ReadAndSkipOptional) { - MakeRecordReader(/*levels_per_page=*/90, /*num_pages=*/1); + makeRecordReader(90, 1); // Read one-third of the page. - ASSERT_EQ(record_reader_->ReadRecords(/*num_records=*/30), 30); - CheckReadValues(0, 30); - record_reader_->Reset(); + ASSERT_EQ(recordReader_->readRecords(30), 30); + checkReadValues(0, 30); + recordReader_->reset(); // Skip 30 records. - ASSERT_EQ(record_reader_->SkipRecords(/*num_records=*/30), 30); + ASSERT_EQ(recordReader_->skipRecords(30), 30); - // Read 60 more records. Only 30 will be read, since we read 30 and skipped - // 30, so only 30 is left. - ASSERT_EQ(record_reader_->ReadRecords(/*num_records=*/60), 30); - CheckReadValues(60, 90); - record_reader_->Reset(); + // Read 60 more records. Only 30 will be read, since we read 30 and skipped. + // 30, So only 30 is left. + ASSERT_EQ(recordReader_->readRecords(60), 30); + checkReadValues(60, 90); + recordReader_->reset(); } // Tests reading and skipping an optional FLBA field. -// The binary readers only differ in DeocdeDense and DecodeSpaced functions, so -// testing optional is sufficient in excercising those code paths. +// The binary readers only differ in DeocdeDense and DecodeSpaced functions, so. +// Testing optional is sufficient in excercising those code paths. TEST_P(FLBARecordReaderTest, ReadAndSkipOptional) { - MakeRecordReader( - /*levels_per_page=*/90, /*num_pages=*/1, /*FLBA_type_length=*/4); + makeRecordReader(90, 1, 4); // Read one-third of the page. - ASSERT_EQ(record_reader_->ReadRecords(/*num_records=*/30), 30); - CheckReadValues(0, 30); - record_reader_->Reset(); + ASSERT_EQ(recordReader_->readRecords(30), 30); + checkReadValues(0, 30); + recordReader_->reset(); // Skip 30 records. - ASSERT_EQ(record_reader_->SkipRecords(/*num_records=*/30), 30); + ASSERT_EQ(recordReader_->skipRecords(30), 30); - // Read 60 more records. Only 30 will be read, since we read 30 and skipped - // 30, so only 30 is left. - ASSERT_EQ(record_reader_->ReadRecords(/*num_records=*/60), 30); - CheckReadValues(60, 90); - record_reader_->Reset(); + // Read 60 more records. Only 30 will be read, since we read 30 and skipped. + // 30, So only 30 is left. + ASSERT_EQ(recordReader_->readRecords(60), 30); + checkReadValues(60, 90); + recordReader_->reset(); } INSTANTIATE_TEST_SUITE_P( @@ -2130,142 +1834,142 @@ class RecordReaderStressTest : public ::testing::TestWithParam {}; TEST_P(RecordReaderStressTest, StressTest) { - LevelInfo level_info; + LevelInfo levelInfo; // Define these boolean variables for improving readability below. bool repeated = false, required = false; - if (GetParam() == Repetition::REQUIRED) { - level_info.defLevel = 0; - level_info.repLevel = 0; + if (GetParam() == Repetition::kRequired) { + levelInfo.defLevel = 0; + levelInfo.repLevel = 0; required = true; - } else if (GetParam() == Repetition::OPTIONAL) { - level_info.defLevel = 1; - level_info.repLevel = 0; + } else if (GetParam() == Repetition::kOptional) { + levelInfo.defLevel = 1; + levelInfo.repLevel = 0; } else { - level_info.defLevel = 1; - level_info.repLevel = 1; + levelInfo.defLevel = 1; + levelInfo.repLevel = 1; repeated = true; } - NodePtr type = schema::Int32("b", GetParam()); - const ColumnDescriptor descr(type, level_info.defLevel, level_info.repLevel); + NodePtr type = schema::int32("b", GetParam()); + const ColumnDescriptor descr(type, levelInfo.defLevel, levelInfo.repLevel); auto seed1 = static_cast(time(0)); std::default_random_engine gen(seed1); // Generate random number of pages with random number of values per page. std::uniform_int_distribution d(0, 2000); - const int num_pages = d(gen); - const int levels_per_page = d(gen); + const int numPages = d(gen); + const int levelsPerPage = d(gen); std::vector values; - std::vector def_levels; - std::vector rep_levels; - std::vector data_buffer; + std::vector defLevels; + std::vector repLevels; + std::vector dataBuffer; std::vector> pages; auto seed2 = static_cast(time(0)); - // Uses time(0) as seed so it would run a different test every time it is - // run. - MakePages( + // Uses time(0) as seed so it would run a different test every time it is. + // Run. + makePages( &descr, - num_pages, - levels_per_page, - def_levels, - rep_levels, + numPages, + levelsPerPage, + defLevels, + repLevels, values, - data_buffer, + dataBuffer, pages, - Encoding::PLAIN, + Encoding::kPlain, seed2); std::unique_ptr pager; pager.reset(new test::MockPageReader(pages)); // Set up the RecordReader. - std::shared_ptr record_reader = - internal::RecordReader::Make(&descr, level_info); - record_reader->SetPageReader(std::move(pager)); + std::shared_ptr recordReader = + internal::RecordReader::make(&descr, levelInfo); + recordReader->setPageReader(std::move(pager)); // Figure out how many total records. - int total_records = 0; + int totalRecords = 0; if (repeated) { - for (int16_t rep : rep_levels) { + for (int16_t rep : repLevels) { if (rep == 0) { - ++total_records; + ++totalRecords; } } } else { - total_records = static_cast(def_levels.size()); + totalRecords = static_cast(defLevels.size()); } // Generate a sequence of reads and skips. - int records_left = total_records; + int recordsLeft = totalRecords; // The first element of the pair is 1 if SkipRecords and 0 if ReadRecords. - // The second element indicates the number of records for reading or - // skipping. + // The second element indicates the number of records for reading or. + // Skipping. std::vector> sequence; - while (records_left > 0) { - std::uniform_int_distribution d(0, records_left); + while (recordsLeft > 0) { + std::uniform_int_distribution d(0, recordsLeft); // Generate a number to decide if this is a skip or read. - bool is_skip = d(gen) < records_left / 2; - int num_records = d(gen); + bool isSkip = d(gen) < recordsLeft / 2; + int numRecords = d(gen); - sequence.emplace_back(is_skip, num_records); - records_left -= num_records; + sequence.emplace_back(isSkip, numRecords); + recordsLeft -= numRecords; } - // The levels_index and values_index are over the original vectors that have - // all the rep/def values for all the records. In the following loop, we will - // read/skip a numebr of records and Reset the reader after each iteration. + // The levels_index and values_index are over the original vectors that have. + // All the rep/def values for all the records. In the following loop, we will. + // Read/skip a numebr of records and Reset the reader after each iteration. // This is on-par with how the record reader is used. - size_t levels_index = 0; - size_t values_index = 0; - for (const auto& [is_skip, num_records] : sequence) { + size_t levelsIndex = 0; + size_t valuesIndex = 0; + for (const auto& [isSkip, numRecords] : sequence) { // Reset the reader before the next round of read/skip. - record_reader->Reset(); + recordReader->reset(); // Prepare the expected result and do the SkipRecords and ReadRecords. - std::vector expected_values; - std::vector expected_def_levels; - std::vector expected_rep_levels; - bool inside_repeated_field = false; - - int read_records = 0; - while (read_records < num_records || inside_repeated_field) { - if (!repeated || (repeated && rep_levels[levels_index] == 0)) { - ++read_records; + std::vector expectedValues; + std::vector expectedDefLevels; + std::vector expectedRepLevels; + bool insideRepeatedField = false; + + int readRecords = 0; + while (readRecords < numRecords || insideRepeatedField) { + if (!repeated || (repeated && repLevels[levelsIndex] == 0)) { + ++readRecords; } - bool has_value = required || - (!required && def_levels[levels_index] == level_info.defLevel); + bool hasValue = required || + (!required && defLevels[levelsIndex] == levelInfo.defLevel); - // If we are not skipping, we need to update the expected values and - // rep/defs. If we are skipping, we just keep going. - if (!is_skip) { + // If we are not skipping, we need to update the expected values and. + // Rep/defs. If we are skipping, we just keep going. + if (!isSkip) { if (!required) { - expected_def_levels.push_back(def_levels[levels_index]); - if (!has_value) { - expected_values.push_back(-1); + expectedDefLevels.push_back(defLevels[levelsIndex]); + if (!hasValue) { + expectedValues.push_back(-1); } } if (repeated) { - expected_rep_levels.push_back(rep_levels[levels_index]); + expectedRepLevels.push_back(repLevels[levelsIndex]); } - if (has_value) { - expected_values.push_back(values[values_index]); + if (hasValue) { + expectedValues.push_back(values[valuesIndex]); } } - if (has_value) { - ++values_index; + if (hasValue) { + ++valuesIndex; } - // If we are in the middle of a repeated field, we should keep going - // until we consume it all. - if (repeated && levels_index + 1 < rep_levels.size() && - rep_levels[levels_index + 1] == 1) { - inside_repeated_field = true; + // If we are in the middle of a repeated field, we should keep going. + // Until we consume it all. + if (repeated && levelsIndex + 1 < repLevels.size() && + repLevels[levelsIndex + 1] == 1) { + insideRepeatedField = true; } else { - inside_repeated_field = false; + insideRepeatedField = false; } - ++levels_index; + ++levelsIndex; } // Print out the seeds with each failing ASSERT to easily reproduce the bug. @@ -2273,50 +1977,50 @@ TEST_P(RecordReaderStressTest, StressTest) { "seeds: " + std::to_string(seed1) + " " + std::to_string(seed2); // Perform the actual read/skip. - if (is_skip) { - int64_t skipped_records = record_reader->SkipRecords(num_records); - ASSERT_EQ(skipped_records, num_records) << seeds; + if (isSkip) { + int64_t skippedRecords = recordReader->skipRecords(numRecords); + ASSERT_EQ(skippedRecords, numRecords) << seeds; } else { - int64_t read_records = record_reader->ReadRecords(num_records); - ASSERT_EQ(read_records, num_records) << seeds; + int64_t readRecords = recordReader->readRecords(numRecords); + ASSERT_EQ(readRecords, numRecords) << seeds; } - const auto read_values = - reinterpret_cast(record_reader->values()); + const auto readValues = + reinterpret_cast(recordReader->values()); if (required) { - ASSERT_EQ(record_reader->null_count(), 0) << seeds; + ASSERT_EQ(recordReader->nullCount(), 0) << seeds; } - std::vector read_vals( - read_values, read_values + record_reader->values_written()); - for (size_t i = 0; i < expected_values.size(); ++i) { - if (expected_values[i] != -1) { - ASSERT_EQ(read_vals[i], expected_values[i]) << seeds; + std::vector readVals( + readValues, readValues + recordReader->valuesWritten()); + for (size_t i = 0; i < expectedValues.size(); ++i) { + if (expectedValues[i] != -1) { + ASSERT_EQ(readVals[i], expectedValues[i]) << seeds; } } if (!required) { - std::vector read_def_levels( - record_reader->def_levels(), - record_reader->def_levels() + record_reader->levels_position()); - ASSERT_TRUE(vector_equal(read_def_levels, expected_def_levels)) << seeds; + std::vector readDefLevels( + recordReader->defLevels(), + recordReader->defLevels() + recordReader->levelsPosition()); + ASSERT_TRUE(vectorEqual(readDefLevels, expectedDefLevels)) << seeds; } if (repeated) { - std::vector read_rep_levels( - record_reader->rep_levels(), - record_reader->rep_levels() + record_reader->levels_position()); - ASSERT_TRUE(vector_equal(read_rep_levels, expected_rep_levels)) << seeds; + std::vector readRepLevels( + recordReader->repLevels(), + recordReader->repLevels() + recordReader->levelsPosition()); + ASSERT_TRUE(vectorEqual(readRepLevels, expectedRepLevels)) << seeds; } } } INSTANTIATE_TEST_SUITE_P( - Repetition_type, + repetitionType, RecordReaderStressTest, ::testing::Values( - Repetition::REQUIRED, - Repetition::OPTIONAL, - Repetition::REPEATED)); + Repetition::kRequired, + Repetition::kOptional, + Repetition::kRepeated)); } // namespace test } // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/tests/ColumnScanner.cpp b/velox/dwio/parquet/writer/arrow/tests/ColumnScanner.cpp index 030f9c1c446..7404949a2b2 100644 --- a/velox/dwio/parquet/writer/arrow/tests/ColumnScanner.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/ColumnScanner.cpp @@ -27,78 +27,78 @@ using arrow::MemoryPool; namespace facebook::velox::parquet::arrow { -std::shared_ptr Scanner::Make( - std::shared_ptr col_reader, - int64_t batch_size, +std::shared_ptr Scanner::make( + std::shared_ptr colReader, + int64_t batchSize, MemoryPool* pool) { - switch (col_reader->type()) { - case Type::BOOLEAN: + switch (colReader->type()) { + case Type::kBoolean: return std::make_shared( - std::move(col_reader), batch_size, pool); - case Type::INT32: + std::move(colReader), batchSize, pool); + case Type::kInt32: return std::make_shared( - std::move(col_reader), batch_size, pool); - case Type::INT64: + std::move(colReader), batchSize, pool); + case Type::kInt64: return std::make_shared( - std::move(col_reader), batch_size, pool); - case Type::INT96: + std::move(colReader), batchSize, pool); + case Type::kInt96: return std::make_shared( - std::move(col_reader), batch_size, pool); - case Type::FLOAT: + std::move(colReader), batchSize, pool); + case Type::kFloat: return std::make_shared( - std::move(col_reader), batch_size, pool); - case Type::DOUBLE: + std::move(colReader), batchSize, pool); + case Type::kDouble: return std::make_shared( - std::move(col_reader), batch_size, pool); - case Type::BYTE_ARRAY: + std::move(colReader), batchSize, pool); + case Type::kByteArray: return std::make_shared( - std::move(col_reader), batch_size, pool); - case Type::FIXED_LEN_BYTE_ARRAY: + std::move(colReader), batchSize, pool); + case Type::kFixedLenByteArray: return std::make_shared( - std::move(col_reader), batch_size, pool); + std::move(colReader), batchSize, pool); default: ParquetException::NYI("type reader not implemented"); } - // Unreachable code, but suppress compiler warning + // Unreachable code, but suppress compiler warning. return std::shared_ptr(nullptr); } -int64_t ScanAllValues( - int32_t batch_size, - int16_t* def_levels, - int16_t* rep_levels, +int64_t scanAllValues( + int32_t batchSize, + int16_t* defLevels, + int16_t* repLevels, uint8_t* values, - int64_t* values_buffered, + int64_t* valuesBuffered, ColumnReader* reader) { switch (reader->type()) { - case parquet::Type::BOOLEAN: - return ScanAll( - batch_size, def_levels, rep_levels, values, values_buffered, reader); - case parquet::Type::INT32: - return ScanAll( - batch_size, def_levels, rep_levels, values, values_buffered, reader); - case parquet::Type::INT64: - return ScanAll( - batch_size, def_levels, rep_levels, values, values_buffered, reader); - case parquet::Type::INT96: - return ScanAll( - batch_size, def_levels, rep_levels, values, values_buffered, reader); - case parquet::Type::FLOAT: - return ScanAll( - batch_size, def_levels, rep_levels, values, values_buffered, reader); - case parquet::Type::DOUBLE: - return ScanAll( - batch_size, def_levels, rep_levels, values, values_buffered, reader); - case parquet::Type::BYTE_ARRAY: - return ScanAll( - batch_size, def_levels, rep_levels, values, values_buffered, reader); - case parquet::Type::FIXED_LEN_BYTE_ARRAY: - return ScanAll( - batch_size, def_levels, rep_levels, values, values_buffered, reader); + case parquet::Type::kBoolean: + return scanAll( + batchSize, defLevels, repLevels, values, valuesBuffered, reader); + case parquet::Type::kInt32: + return scanAll( + batchSize, defLevels, repLevels, values, valuesBuffered, reader); + case parquet::Type::kInt64: + return scanAll( + batchSize, defLevels, repLevels, values, valuesBuffered, reader); + case parquet::Type::kInt96: + return scanAll( + batchSize, defLevels, repLevels, values, valuesBuffered, reader); + case parquet::Type::kFloat: + return scanAll( + batchSize, defLevels, repLevels, values, valuesBuffered, reader); + case parquet::Type::kDouble: + return scanAll( + batchSize, defLevels, repLevels, values, valuesBuffered, reader); + case parquet::Type::kByteArray: + return scanAll( + batchSize, defLevels, repLevels, values, valuesBuffered, reader); + case parquet::Type::kFixedLenByteArray: + return scanAll( + batchSize, defLevels, repLevels, values, valuesBuffered, reader); default: ParquetException::NYI("type reader not implemented"); } - // Unreachable code, but suppress compiler warning + // Unreachable code, but suppress compiler warning. return 0; } diff --git a/velox/dwio/parquet/writer/arrow/tests/ColumnScanner.h b/velox/dwio/parquet/writer/arrow/tests/ColumnScanner.h index 5729a4c05ce..907def5f954 100644 --- a/velox/dwio/parquet/writer/arrow/tests/ColumnScanner.h +++ b/velox/dwio/parquet/writer/arrow/tests/ColumnScanner.h @@ -41,227 +41,227 @@ class PARQUET_EXPORT Scanner { public: explicit Scanner( std::shared_ptr reader, - int64_t batch_size = DEFAULT_SCANNER_BATCH_SIZE, + int64_t batchSize = DEFAULT_SCANNER_BATCH_SIZE, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()) - : batch_size_(batch_size), - level_offset_(0), - levels_buffered_(0), - value_buffer_(AllocateBuffer(pool)), - value_offset_(0), - values_buffered_(0), + : batchSize_(batchSize), + levelOffset_(0), + levelsBuffered_(0), + valueBuffer_(allocateBuffer(pool)), + valueOffset_(0), + valuesBuffered_(0), reader_(std::move(reader)) { - def_levels_.resize(descr()->max_definition_level() > 0 ? batch_size_ : 0); - rep_levels_.resize(descr()->max_repetition_level() > 0 ? batch_size_ : 0); + defLevels_.resize(descr()->maxDefinitionLevel() > 0 ? batchSize_ : 0); + repLevels_.resize(descr()->maxRepetitionLevel() > 0 ? batchSize_ : 0); } virtual ~Scanner() {} - static std::shared_ptr Make( - std::shared_ptr col_reader, - int64_t batch_size = DEFAULT_SCANNER_BATCH_SIZE, + static std::shared_ptr make( + std::shared_ptr colReader, + int64_t batchSize = DEFAULT_SCANNER_BATCH_SIZE, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); virtual void - PrintNext(std::ostream& out, int width, bool with_levels = false) = 0; + printNext(std::ostream& out, int width, bool withLevels = false) = 0; - bool HasNext() { - return level_offset_ < levels_buffered_ || reader_->HasNext(); + bool hasNext() { + return levelOffset_ < levelsBuffered_ || reader_->hasNext(); } const ColumnDescriptor* descr() const { return reader_->descr(); } - int64_t batch_size() const { - return batch_size_; + int64_t batchSize() const { + return batchSize_; } - void SetBatchSize(int64_t batch_size) { - batch_size_ = batch_size; + void setBatchSize(int64_t batchSize) { + batchSize_ = batchSize; } protected: - int64_t batch_size_; + int64_t batchSize_; - std::vector def_levels_; - std::vector rep_levels_; - int level_offset_; - int levels_buffered_; + std::vector defLevels_; + std::vector repLevels_; + int levelOffset_; + int levelsBuffered_; - std::shared_ptr value_buffer_; - int value_offset_; - int64_t values_buffered_; + std::shared_ptr valueBuffer_; + int valueOffset_; + int64_t valuesBuffered_; std::shared_ptr reader_; }; template class PARQUET_TEMPLATE_CLASS_EXPORT TypedScanner : public Scanner { public: - typedef typename DType::c_type T; + typedef typename DType::CType T; explicit TypedScanner( std::shared_ptr reader, - int64_t batch_size = DEFAULT_SCANNER_BATCH_SIZE, + int64_t batchSize = DEFAULT_SCANNER_BATCH_SIZE, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()) - : Scanner(std::move(reader), batch_size, pool) { - typed_reader_ = static_cast*>(reader_.get()); - int value_byte_size = type_traits::value_byte_size; - PARQUET_THROW_NOT_OK(value_buffer_->Resize(batch_size_ * value_byte_size)); - values_ = reinterpret_cast(value_buffer_->mutable_data()); + : Scanner(std::move(reader), batchSize, pool) { + typedReader_ = static_cast*>(reader_.get()); + int valueByteSize = TypeTraits::valueByteSize; + PARQUET_THROW_NOT_OK(valueBuffer_->Resize(batchSize_ * valueByteSize)); + values_ = reinterpret_cast(valueBuffer_->mutable_data()); } virtual ~TypedScanner() {} - bool NextLevels(int16_t* def_level, int16_t* rep_level) { - if (level_offset_ == levels_buffered_) { - levels_buffered_ = static_cast(typed_reader_->ReadBatch( - static_cast(batch_size_), - def_levels_.data(), - rep_levels_.data(), + bool nextLevels(int16_t* defLevel, int16_t* repLevel) { + if (levelOffset_ == levelsBuffered_) { + levelsBuffered_ = static_cast(typedReader_->readBatch( + static_cast(batchSize_), + defLevels_.data(), + repLevels_.data(), values_, - &values_buffered_)); + &valuesBuffered_)); - value_offset_ = 0; - level_offset_ = 0; - if (!levels_buffered_) { + valueOffset_ = 0; + levelOffset_ = 0; + if (!levelsBuffered_) { return false; } } - *def_level = - descr()->max_definition_level() > 0 ? def_levels_[level_offset_] : 0; - *rep_level = - descr()->max_repetition_level() > 0 ? rep_levels_[level_offset_] : 0; - level_offset_++; + *defLevel = + descr()->maxDefinitionLevel() > 0 ? defLevels_[levelOffset_] : 0; + *repLevel = + descr()->maxRepetitionLevel() > 0 ? repLevels_[levelOffset_] : 0; + levelOffset_++; return true; } - bool Next(T* val, int16_t* def_level, int16_t* rep_level, bool* is_null) { - if (level_offset_ == levels_buffered_) { - if (!HasNext()) { - // Out of data pages + bool next(T* val, int16_t* defLevel, int16_t* repLevel, bool* isNull) { + if (levelOffset_ == levelsBuffered_) { + if (!hasNext()) { + // Out of data pages. return false; } } - NextLevels(def_level, rep_level); - *is_null = *def_level < descr()->max_definition_level(); + nextLevels(defLevel, repLevel); + *isNull = *defLevel < descr()->maxDefinitionLevel(); - if (*is_null) { + if (*isNull) { return true; } - if (value_offset_ == values_buffered_) { + if (valueOffset_ == valuesBuffered_) { throw ParquetException("Value was non-null, but has not been buffered"); } - *val = values_[value_offset_++]; + *val = values_[valueOffset_++]; return true; } - // Returns true if there is a next value - bool NextValue(T* val, bool* is_null) { - if (level_offset_ == levels_buffered_) { - if (!HasNext()) { - // Out of data pages + // Returns true if there is a next value. + bool nextValue(T* val, bool* isNull) { + if (levelOffset_ == levelsBuffered_) { + if (!hasNext()) { + // Out of data pages. return false; } } - // Out of values - int16_t def_level = -1; - int16_t rep_level = -1; - NextLevels(&def_level, &rep_level); - *is_null = def_level < descr()->max_definition_level(); + // Out of values. + int16_t defLevel = -1; + int16_t repLevel = -1; + nextLevels(&defLevel, &repLevel); + *isNull = defLevel < descr()->maxDefinitionLevel(); - if (*is_null) { + if (*isNull) { return true; } - if (value_offset_ == values_buffered_) { + if (valueOffset_ == valuesBuffered_) { throw ParquetException("Value was non-null, but has not been buffered"); } - *val = values_[value_offset_++]; + *val = values_[valueOffset_++]; return true; } virtual void - PrintNext(std::ostream& out, int width, bool with_levels = false) { + printNext(std::ostream& out, int width, bool withLevels = false) { T val{}; - int16_t def_level = -1; - int16_t rep_level = -1; - bool is_null = false; + int16_t defLevel = -1; + int16_t repLevel = -1; + bool isNull = false; char buffer[80]; - if (!Next(&val, &def_level, &rep_level, &is_null)) { + if (!next(&val, &defLevel, &repLevel, &isNull)) { throw ParquetException("No more values buffered"); } - if (with_levels) { - out << " D:" << def_level << " R:" << rep_level << " "; - if (!is_null) { + if (withLevels) { + out << " D:" << defLevel << " R:" << repLevel << " "; + if (!isNull) { out << "V:"; } } - if (is_null) { - std::string null_fmt = format_fwf(width); - snprintf(buffer, sizeof(buffer), null_fmt.c_str(), "NULL"); + if (isNull) { + std::string nullFmt = formatFwf(width); + snprintf(buffer, sizeof(buffer), nullFmt.c_str(), "NULL"); } else { - FormatValue(&val, buffer, sizeof(buffer), width); + formatValue(&val, buffer, sizeof(buffer), width); } out << buffer; } private: - // The ownership of this object is expressed through the reader_ variable in - // the base - TypedColumnReader* typed_reader_; + // The ownership of this object is expressed through the reader_ variable in. + // The base. + TypedColumnReader* typedReader_; - inline void FormatValue(void* val, char* buffer, int bufsize, int width); + inline void formatValue(void* val, char* buffer, int bufsize, int width); T* values_; }; template -inline void TypedScanner::FormatValue( +inline void TypedScanner::formatValue( void* val, char* buffer, int bufsize, int width) { - std::string fmt = format_fwf(width); + std::string fmt = formatFwf(width); snprintf(buffer, bufsize, fmt.c_str(), *reinterpret_cast(val)); } template <> -inline void TypedScanner::FormatValue( +inline void TypedScanner::formatValue( void* val, char* buffer, int bufsize, int width) { - std::string fmt = format_fwf(width); - std::string result = Int96ToString(*reinterpret_cast(val)); + std::string fmt = formatFwf(width); + std::string result = int96ToString(*reinterpret_cast(val)); snprintf(buffer, bufsize, fmt.c_str(), result.c_str()); } template <> -inline void TypedScanner::FormatValue( +inline void TypedScanner::formatValue( void* val, char* buffer, int bufsize, int width) { - std::string fmt = format_fwf(width); - std::string result = ByteArrayToString(*reinterpret_cast(val)); + std::string fmt = formatFwf(width); + std::string result = byteArrayToString(*reinterpret_cast(val)); snprintf(buffer, bufsize, fmt.c_str(), result.c_str()); } template <> -inline void TypedScanner::FormatValue( +inline void TypedScanner::formatValue( void* val, char* buffer, int bufsize, int width) { - std::string fmt = format_fwf(width); - std::string result = FixedLenByteArrayToString( - *reinterpret_cast(val), descr()->type_length()); + std::string fmt = formatFwf(width); + std::string result = fixedLenByteArrayToString( + *reinterpret_cast(val), descr()->typeLength()); snprintf(buffer, bufsize, fmt.c_str(), result.c_str()); } @@ -275,26 +275,26 @@ typedef TypedScanner ByteArrayScanner; typedef TypedScanner FixedLenByteArrayScanner; template -int64_t ScanAll( - int32_t batch_size, - int16_t* def_levels, - int16_t* rep_levels, +int64_t scanAll( + int32_t batchSize, + int16_t* defLevels, + int16_t* repLevels, uint8_t* values, - int64_t* values_buffered, + int64_t* valuesBuffered, ColumnReader* reader) { typedef typename RType::T Type; - auto typed_reader = static_cast(reader); + auto typedReader = static_cast(reader); auto vals = reinterpret_cast(&values[0]); - return typed_reader->ReadBatch( - batch_size, def_levels, rep_levels, vals, values_buffered); + return typedReader->readBatch( + batchSize, defLevels, repLevels, vals, valuesBuffered); } -int64_t PARQUET_EXPORT ScanAllValues( - int32_t batch_size, - int16_t* def_levels, - int16_t* rep_levels, +int64_t PARQUET_EXPORT scanAllValues( + int32_t batchSize, + int16_t* defLevels, + int16_t* repLevels, uint8_t* values, - int64_t* values_buffered, + int64_t* valuesBuffered, ColumnReader* reader); } // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/tests/ColumnWriterTest.cpp b/velox/dwio/parquet/writer/arrow/tests/ColumnWriterTest.cpp index 4d9b10f3568..d9716a76580 100644 --- a/velox/dwio/parquet/writer/arrow/tests/ColumnWriterTest.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/ColumnWriterTest.cpp @@ -56,15 +56,15 @@ const int SMALL_SIZE = 100; const int LARGE_SIZE = 10000; // Very large size to test dictionary fallback. const int VERY_LARGE_SIZE = 40000; -// Reduced dictionary page size to use for testing dictionary fallback with -// valgrind +// Reduced dictionary page size to use for testing dictionary fallback with. +// Valgrind. const int64_t DICTIONARY_PAGE_SIZE = 1024; #else // Larger size to test some corner cases, only used in some specific cases. const int LARGE_SIZE = 100000; // Very large size to test dictionary fallback. const int VERY_LARGE_SIZE = 400000; -// Dictionary page size to use for testing dictionary fallback +// Dictionary page size to use for testing dictionary fallback. const int64_t DICTIONARY_PAGE_SIZE = 1024 * 1024; #endif @@ -72,62 +72,62 @@ template class TestPrimitiveWriter : public PrimitiveTypedTest { public: void SetUp() { - this->SetupValuesOut(SMALL_SIZE); - writer_properties_ = default_writer_properties(); - definition_levels_out_.resize(SMALL_SIZE); - repetition_levels_out_.resize(SMALL_SIZE); + this->setupValuesOut(SMALL_SIZE); + writerProperties_ = defaultWriterProperties(); + definitionLevelsOut_.resize(SMALL_SIZE); + repetitionLevelsOut_.resize(SMALL_SIZE); - this->SetUpSchema(Repetition::REQUIRED); + this->setUpSchema(Repetition::kRequired); - descr_ = this->schema_.Column(0); + descr_ = this->schema_.column(0); } - Type::type type_num() { - return TestType::type_num; + Type::type typeNum() { + return TestType::typeNum; } - void BuildReader( - int64_t num_rows, + void buildReader( + int64_t numRows, Compression::type compression = Compression::UNCOMPRESSED, - bool page_checksum_verify = false) { + bool pageChecksumVerify = false) { ASSERT_OK_AND_ASSIGN(auto buffer, sink_->Finish()); auto source = std::make_shared<::arrow::io::BufferReader>(buffer); ReaderProperties readerProperties; - readerProperties.set_page_checksum_verification(page_checksum_verify); - std::unique_ptr page_reader = PageReader::Open( - std::move(source), num_rows, compression, readerProperties); + readerProperties.setPageChecksumVerification(pageChecksumVerify); + std::unique_ptr pageReader = PageReader::open( + std::move(source), numRows, compression, readerProperties); reader_ = std::static_pointer_cast>( - ColumnReader::Make(this->descr_, std::move(page_reader))); + ColumnReader::make(this->descr_, std::move(pageReader))); } - std::shared_ptr> BuildWriter( - int64_t output_size = SMALL_SIZE, - const ColumnProperties& column_properties = ColumnProperties(), + std::shared_ptr> buildWriter( + int64_t outputSize = SMALL_SIZE, + const ColumnProperties& columnProps = ColumnProperties(), const ParquetVersion::type version = ParquetVersion::PARQUET_1_0, - bool enable_checksum = false) { - sink_ = CreateOutputStream(); - WriterProperties::Builder wp_builder; - wp_builder.version(version); - if (column_properties.encoding() == Encoding::PLAIN_DICTIONARY || - column_properties.encoding() == Encoding::RLE_DICTIONARY) { - wp_builder.enable_dictionary(); - wp_builder.dictionary_pagesize_limit(DICTIONARY_PAGE_SIZE); + bool enableChecksum = false) { + sink_ = createOutputStream(); + WriterProperties::Builder wpBuilder; + wpBuilder.version(version); + if (columnProps.encoding() == Encoding::kPlainDictionary || + columnProps.encoding() == Encoding::kRleDictionary) { + wpBuilder.enableDictionary(); + wpBuilder.dictionaryPagesizeLimit(DICTIONARY_PAGE_SIZE); } else { - wp_builder.disable_dictionary(); - wp_builder.encoding(column_properties.encoding()); + wpBuilder.disableDictionary(); + wpBuilder.encoding(columnProps.encoding()); } - if (enable_checksum) { - wp_builder.enable_page_checksum(); + if (enableChecksum) { + wpBuilder.enablePageChecksum(); } - wp_builder.max_statistics_size(column_properties.max_statistics_size()); - writer_properties_ = wp_builder.build(); + wpBuilder.maxStatisticsSize(columnProps.maxStatisticsSize()); + writerProperties_ = wpBuilder.build(); metadata_ = - ColumnChunkMetaDataBuilder::Make(writer_properties_, this->descr_); - std::unique_ptr pager = PageWriter::Open( + ColumnChunkMetaDataBuilder::make(writerProperties_, this->descr_); + std::unique_ptr pager = PageWriter::open( sink_, - column_properties.compression(), - Codec::UseDefaultCompressionLevel(), + columnProps.compression(), + Codec::useDefaultCompressionLevel(), metadata_.get(), /* row_group_ordinal */ -1, /* column_chunk_ordinal*/ -1, @@ -135,374 +135,366 @@ class TestPrimitiveWriter : public PrimitiveTypedTest { /* buffered_row_group */ false, /* header_encryptor */ NULLPTR, /* data_encryptor */ NULLPTR, - enable_checksum); - std::shared_ptr writer = ColumnWriter::Make( - metadata_.get(), std::move(pager), writer_properties_.get()); + enableChecksum); + std::shared_ptr writer = ColumnWriter::make( + metadata_.get(), std::move(pager), writerProperties_.get()); return std::static_pointer_cast>(writer); } - void ReadColumn( + void readColumn( Compression::type compression = Compression::UNCOMPRESSED, - bool page_checksum_verify = false) { - BuildReader( - static_cast(this->values_out_.size()), + bool pageChecksumVerify = false) { + buildReader( + static_cast(this->valuesOut_.size()), compression, - page_checksum_verify); - reader_->ReadBatch( - static_cast(this->values_out_.size()), - definition_levels_out_.data(), - repetition_levels_out_.data(), - this->values_out_ptr_, - &values_read_); - this->SyncValuesOut(); + pageChecksumVerify); + reader_->readBatch( + static_cast(this->valuesOut_.size()), + definitionLevelsOut_.data(), + repetitionLevelsOut_.data(), + this->valuesOutPtr_, + &valuesRead_); + this->syncValuesOut(); } - void ReadColumnFully( + void readColumnFully( Compression::type compression = Compression::UNCOMPRESSED, - bool page_checksum_verify = false); + bool pageChecksumVerify = false); - void TestRequiredWithEncoding(Encoding::type encoding) { - return TestRequiredWithSettings( + void testRequiredWithEncoding(Encoding::type encoding) { + return testRequiredWithSettings( encoding, Compression::UNCOMPRESSED, false, false); } - void TestRequiredWithSettings( + void testRequiredWithSettings( Encoding::type encoding, Compression::type compression, - bool enable_dictionary, - bool enable_statistics, - int64_t num_rows = SMALL_SIZE, - int compression_level = Codec::UseDefaultCompressionLevel(), - bool enable_checksum = false) { - this->GenerateData(num_rows); - - this->WriteRequiredWithSettings( + bool enableDictionary, + bool enableStatistics, + int64_t numRows = SMALL_SIZE, + int compressionLevel = Codec::useDefaultCompressionLevel(), + bool enableChecksum = false) { + this->generateData(numRows); + + this->writeRequiredWithSettings( encoding, compression, - enable_dictionary, - enable_statistics, - compression_level, - num_rows, - enable_checksum); + enableDictionary, + enableStatistics, + compressionLevel, + numRows, + enableChecksum); ASSERT_NO_FATAL_FAILURE( - this->ReadAndCompare(compression, num_rows, enable_checksum)); + this->readAndCompare(compression, numRows, enableChecksum)); - this->WriteRequiredWithSettingsSpaced( + this->writeRequiredWithSettingsSpaced( encoding, compression, - enable_dictionary, - enable_statistics, - num_rows, - compression_level, - enable_checksum); + enableDictionary, + enableStatistics, + numRows, + compressionLevel, + enableChecksum); ASSERT_NO_FATAL_FAILURE( - this->ReadAndCompare(compression, num_rows, enable_checksum)); + this->readAndCompare(compression, numRows, enableChecksum)); } - void TestDictionaryFallbackEncoding(ParquetVersion::type version) { - this->GenerateData(VERY_LARGE_SIZE); - ColumnProperties column_properties; - column_properties.set_dictionary_enabled(true); + void testDictionaryFallbackEncoding(ParquetVersion::type version) { + this->generateData(VERY_LARGE_SIZE); + ColumnProperties columnProperties; + columnProperties.setDictionaryEnabled(true); if (version == ParquetVersion::PARQUET_1_0) { - column_properties.set_encoding(Encoding::PLAIN_DICTIONARY); + columnProperties.setEncoding(Encoding::kPlainDictionary); } else { - column_properties.set_encoding(Encoding::RLE_DICTIONARY); + columnProperties.setEncoding(Encoding::kRleDictionary); } - auto writer = - this->BuildWriter(VERY_LARGE_SIZE, column_properties, version); + auto writer = this->buildWriter(VERY_LARGE_SIZE, columnProperties, version); - writer->WriteBatch( - this->values_.size(), nullptr, nullptr, this->values_ptr_); - writer->Close(); + writer->writeBatch( + this->values_.size(), nullptr, nullptr, this->valuesPtr_); + writer->close(); - // Read all rows so we are sure that also the non-dictionary pages are read - // correctly - this->SetupValuesOut(VERY_LARGE_SIZE); - this->ReadColumnFully(); - ASSERT_EQ(VERY_LARGE_SIZE, this->values_read_); + // Read all rows so we are sure that also the non-dictionary pages are read. + // Correctly. + this->setupValuesOut(VERY_LARGE_SIZE); + this->readColumnFully(); + ASSERT_EQ(VERY_LARGE_SIZE, this->valuesRead_); this->values_.resize(VERY_LARGE_SIZE); - ASSERT_EQ(this->values_, this->values_out_); - std::vector encodings_vector = this->metadata_encodings(); + ASSERT_EQ(this->values_, this->valuesOut_); + std::vector encodingsVector = this->metadataEncodings(); std::set encodings( - encodings_vector.begin(), encodings_vector.end()); + encodingsVector.cbegin(), encodingsVector.cend()); - if (this->type_num() == Type::BOOLEAN) { - // Dictionary encoding is not allowed for boolean type - // There are 2 encodings (PLAIN, RLE) in a non dictionary encoding case - std::set expected({Encoding::PLAIN, Encoding::RLE}); + if (this->typeNum() == Type::kBoolean) { + // Dictionary encoding is not allowed for boolean type. + // There are 2 encodings (PLAIN, RLE) in a non dictionary encoding case. + std::set expected({Encoding::kPlain, Encoding::kRle}); ASSERT_EQ(encodings, expected); } else if (version == ParquetVersion::PARQUET_1_0) { - // There are 3 encodings (PLAIN_DICTIONARY, PLAIN, RLE) in a fallback case - // for version 1.0 + // There are 3 encodings (PLAIN_DICTIONARY, PLAIN, RLE) in a fallback + // case. For version 1.0. std::set expected( - {Encoding::PLAIN_DICTIONARY, Encoding::PLAIN, Encoding::RLE}); + {Encoding::kPlainDictionary, Encoding::kPlain, Encoding::kRle}); ASSERT_EQ(encodings, expected); } else { - // There are 3 encodings (RLE_DICTIONARY, PLAIN, RLE) in a fallback case - // for version 2.0 + // There are 3 encodings (RLE_DICTIONARY, PLAIN, RLE) in a fallback case. + // For version 2.0. std::set expected( - {Encoding::RLE_DICTIONARY, Encoding::PLAIN, Encoding::RLE}); + {Encoding::kRleDictionary, Encoding::kPlain, Encoding::kRle}); ASSERT_EQ(encodings, expected); } - std::vector encoding_stats = - this->metadata_encoding_stats(); - if (this->type_num() == Type::BOOLEAN) { - ASSERT_EQ(encoding_stats[0].encoding, Encoding::PLAIN); - ASSERT_EQ(encoding_stats[0].page_type, PageType::DATA_PAGE); + std::vector encodingStats = + this->metadataEncodingStats(); + if (this->typeNum() == Type::kBoolean) { + ASSERT_EQ(encodingStats[0].encoding, Encoding::kPlain); + ASSERT_EQ(encodingStats[0].pageType, PageType::kDataPage); } else if (version == ParquetVersion::PARQUET_1_0) { std::vector expected( - {Encoding::PLAIN_DICTIONARY, - Encoding::PLAIN, - Encoding::PLAIN_DICTIONARY}); - ASSERT_EQ(encoding_stats[0].encoding, expected[0]); - ASSERT_EQ(encoding_stats[0].page_type, PageType::DICTIONARY_PAGE); - for (size_t i = 1; i < encoding_stats.size(); i++) { - ASSERT_EQ(encoding_stats[i].encoding, expected[i]); - ASSERT_EQ(encoding_stats[i].page_type, PageType::DATA_PAGE); + {Encoding::kPlainDictionary, + Encoding::kPlain, + Encoding::kPlainDictionary}); + ASSERT_EQ(encodingStats[0].encoding, expected[0]); + ASSERT_EQ(encodingStats[0].pageType, PageType::kDictionaryPage); + for (size_t i = 1; i < encodingStats.size(); i++) { + ASSERT_EQ(encodingStats[i].encoding, expected[i]); + ASSERT_EQ(encodingStats[i].pageType, PageType::kDataPage); } } else { std::vector expected( - {Encoding::PLAIN, Encoding::PLAIN, Encoding::RLE_DICTIONARY}); - ASSERT_EQ(encoding_stats[0].encoding, expected[0]); - ASSERT_EQ(encoding_stats[0].page_type, PageType::DICTIONARY_PAGE); - for (size_t i = 1; i < encoding_stats.size(); i++) { - ASSERT_EQ(encoding_stats[i].encoding, expected[i]); - ASSERT_EQ(encoding_stats[i].page_type, PageType::DATA_PAGE); + {Encoding::kPlain, Encoding::kPlain, Encoding::kRleDictionary}); + ASSERT_EQ(encodingStats[0].encoding, expected[0]); + ASSERT_EQ(encodingStats[0].pageType, PageType::kDictionaryPage); + for (size_t i = 1; i < encodingStats.size(); i++) { + ASSERT_EQ(encodingStats[i].encoding, expected[i]); + ASSERT_EQ(encodingStats[i].pageType, PageType::kDataPage); } } } - void WriteRequiredWithSettings( + void writeRequiredWithSettings( Encoding::type encoding, Compression::type compression, - bool enable_dictionary, - bool enable_statistics, - int compression_level, - int64_t num_rows, - bool enable_checksum) { - ColumnProperties column_properties( - encoding, compression, enable_dictionary, enable_statistics); - column_properties.set_compression_level(compression_level); - std::shared_ptr> writer = this->BuildWriter( - num_rows, - column_properties, - ParquetVersion::PARQUET_1_0, - enable_checksum); - writer->WriteBatch( - this->values_.size(), nullptr, nullptr, this->values_ptr_); - // The behaviour should be independent from the number of Close() calls - writer->Close(); - writer->Close(); + bool enableDictionary, + bool enableStatistics, + int compressionLevel, + int64_t numRows, + bool enableChecksum) { + ColumnProperties columnProperties( + encoding, compression, enableDictionary, enableStatistics); + columnProperties.setCompressionLevel(compressionLevel); + std::shared_ptr> writer = this->buildWriter( + numRows, columnProperties, ParquetVersion::PARQUET_1_0, enableChecksum); + writer->writeBatch( + this->values_.size(), nullptr, nullptr, this->valuesPtr_); + // The behaviour should be independent from the number of Close() calls. + writer->close(); + writer->close(); } - void WriteRequiredWithSettingsSpaced( + void writeRequiredWithSettingsSpaced( Encoding::type encoding, Compression::type compression, - bool enable_dictionary, - bool enable_statistics, - int64_t num_rows, - int compression_level, - bool enable_checksum) { - std::vector valid_bits( + bool enableDictionary, + bool enableStatistics, + int64_t numRows, + int compressionLevel, + bool enableChecksum) { + std::vector validBits( ::arrow::bit_util::BytesForBits( static_cast(this->values_.size())) + 1, 255); - ColumnProperties column_properties( - encoding, compression, enable_dictionary, enable_statistics); - column_properties.set_compression_level(compression_level); - std::shared_ptr> writer = this->BuildWriter( - num_rows, - column_properties, - ParquetVersion::PARQUET_1_0, - enable_checksum); - writer->WriteBatchSpaced( + ColumnProperties columnProperties( + encoding, compression, enableDictionary, enableStatistics); + columnProperties.setCompressionLevel(compressionLevel); + std::shared_ptr> writer = this->buildWriter( + numRows, columnProperties, ParquetVersion::PARQUET_1_0, enableChecksum); + writer->writeBatchSpaced( this->values_.size(), nullptr, nullptr, - valid_bits.data(), + validBits.data(), 0, - this->values_ptr_); - // The behaviour should be independent from the number of Close() calls - writer->Close(); - writer->Close(); + this->valuesPtr_); + // The behaviour should be independent from the number of Close() calls. + writer->close(); + writer->close(); } - void ReadAndCompare( + void readAndCompare( Compression::type compression, - int64_t num_rows, - bool page_checksum_verify) { - this->SetupValuesOut(num_rows); - this->ReadColumnFully(compression, page_checksum_verify); - auto comparator = MakeComparator(this->descr_); + int64_t numRows, + bool pageChecksumVerify) { + this->setupValuesOut(numRows); + this->readColumnFully(compression, pageChecksumVerify); + auto Comparator = makeComparator(this->descr_); for (size_t i = 0; i < this->values_.size(); i++) { - if (comparator->Compare(this->values_[i], this->values_out_[i]) || - comparator->Compare(this->values_out_[i], this->values_[i])) { + if (Comparator->compare(this->values_[i], this->valuesOut_[i]) || + Comparator->compare(this->valuesOut_[i], this->values_[i])) { ARROW_SCOPED_TRACE("i = ", i); } - ASSERT_FALSE(comparator->Compare(this->values_[i], this->values_out_[i])); - ASSERT_FALSE(comparator->Compare(this->values_out_[i], this->values_[i])); + ASSERT_FALSE(Comparator->compare(this->values_[i], this->valuesOut_[i])); + ASSERT_FALSE(Comparator->compare(this->valuesOut_[i], this->values_[i])); } - ASSERT_EQ(this->values_, this->values_out_); + ASSERT_EQ(this->values_, this->valuesOut_); } - int64_t metadata_num_values() { + int64_t metadataNumValues() { // Metadata accessor must be created lazily. - // This is because the ColumnChunkMetaData semantics dictate the metadata - // object is complete (no changes to the metadata buffer can be made after + // This is because the ColumnChunkMetaData semantics dictate the metadata. + // Object is complete (no changes to the metadata buffer can be made after. // instantiation) - auto metadata_accessor = - ColumnChunkMetaData::Make(metadata_->contents(), this->descr_); - return metadata_accessor->num_values(); + auto metadataAccessor = + ColumnChunkMetaData::make(metadata_->Contents(), this->descr_); + return metadataAccessor->numValues(); } - bool metadata_is_stats_set() { + bool metadataIsStatsSet() { // Metadata accessor must be created lazily. - // This is because the ColumnChunkMetaData semantics dictate the metadata - // object is complete (no changes to the metadata buffer can be made after + // This is because the ColumnChunkMetaData semantics dictate the metadata. + // Object is complete (no changes to the metadata buffer can be made after. // instantiation) - ApplicationVersion app_version(this->writer_properties_->created_by()); - auto metadata_accessor = ColumnChunkMetaData::Make( - metadata_->contents(), + ApplicationVersion appVersion(this->writerProperties_->createdBy()); + auto metadataAccessor = ColumnChunkMetaData::make( + metadata_->Contents(), this->descr_, - default_reader_properties(), - &app_version); - return metadata_accessor->is_stats_set(); + defaultReaderProperties(), + &appVersion); + return metadataAccessor->isStatsSet(); } - std::pair metadata_stats_has_min_max() { + std::pair metadataStatsHasMinMax() { // Metadata accessor must be created lazily. - // This is because the ColumnChunkMetaData semantics dictate the metadata - // object is complete (no changes to the metadata buffer can be made after + // This is because the ColumnChunkMetaData semantics dictate the metadata. + // Object is complete (no changes to the metadata buffer can be made after. // instantiation) - ApplicationVersion app_version(this->writer_properties_->created_by()); - auto metadata_accessor = ColumnChunkMetaData::Make( - metadata_->contents(), + ApplicationVersion appVersion(this->writerProperties_->createdBy()); + auto metadataAccessor = ColumnChunkMetaData::make( + metadata_->Contents(), this->descr_, - default_reader_properties(), - &app_version); - auto encoded_stats = metadata_accessor->statistics()->Encode(); - return {encoded_stats.has_min, encoded_stats.has_max}; + defaultReaderProperties(), + &appVersion); + auto encodedStats = metadataAccessor->statistics()->encode(); + return {encodedStats.hasMin, encodedStats.hasMax}; } - std::vector metadata_encodings() { + std::vector metadataEncodings() { // Metadata accessor must be created lazily. - // This is because the ColumnChunkMetaData semantics dictate the metadata - // object is complete (no changes to the metadata buffer can be made after + // This is because the ColumnChunkMetaData semantics dictate the metadata. + // Object is complete (no changes to the metadata buffer can be made after. // instantiation) - auto metadata_accessor = - ColumnChunkMetaData::Make(metadata_->contents(), this->descr_); - return metadata_accessor->encodings(); + auto metadataAccessor = + ColumnChunkMetaData::make(metadata_->Contents(), this->descr_); + return metadataAccessor->encodings(); } - std::vector metadata_encoding_stats() { + std::vector metadataEncodingStats() { // Metadata accessor must be created lazily. - // This is because the ColumnChunkMetaData semantics dictate the metadata - // object is complete (no changes to the metadata buffer can be made after + // This is because the ColumnChunkMetaData semantics dictate the metadata. + // Object is complete (no changes to the metadata buffer can be made after. // instantiation) - auto metadata_accessor = - ColumnChunkMetaData::Make(metadata_->contents(), this->descr_); - return metadata_accessor->encoding_stats(); + auto metadataAccessor = + ColumnChunkMetaData::make(metadata_->Contents(), this->descr_); + return metadataAccessor->encodingStats(); } protected: - int64_t values_read_; - // Keep the reader alive as for ByteArray the lifetime of the ByteArray - // content is bound to the reader. + int64_t valuesRead_; + // Keep the reader alive as for ByteArray the lifetime of the ByteArray. + // Content is bound to the reader. std::shared_ptr> reader_; - std::vector definition_levels_out_; - std::vector repetition_levels_out_; + std::vector definitionLevelsOut_; + std::vector repetitionLevelsOut_; const ColumnDescriptor* descr_; private: std::unique_ptr metadata_; std::shared_ptr<::arrow::io::BufferOutputStream> sink_; - std::shared_ptr writer_properties_; - std::vector> data_buffer_; + std::shared_ptr writerProperties_; + std::vector> dataBuffer_; }; template -void TestPrimitiveWriter::ReadColumnFully( +void TestPrimitiveWriter::readColumnFully( Compression::type compression, - bool page_checksum_verify) { - int64_t total_values = static_cast(this->values_out_.size()); - BuildReader(total_values, compression, page_checksum_verify); - values_read_ = 0; - while (values_read_ < total_values) { - int64_t values_read_recently = 0; - reader_->ReadBatch( - static_cast(this->values_out_.size()) - - static_cast(values_read_), - definition_levels_out_.data() + values_read_, - repetition_levels_out_.data() + values_read_, - this->values_out_ptr_ + values_read_, - &values_read_recently); - values_read_ += values_read_recently; + bool pageChecksumVerify) { + int64_t totalValues = static_cast(this->valuesOut_.size()); + buildReader(totalValues, compression, pageChecksumVerify); + valuesRead_ = 0; + while (valuesRead_ < totalValues) { + int64_t valuesReadRecently = 0; + reader_->readBatch( + static_cast(this->valuesOut_.size()) - + static_cast(valuesRead_), + definitionLevelsOut_.data() + valuesRead_, + repetitionLevelsOut_.data() + valuesRead_, + this->valuesOutPtr_ + valuesRead_, + &valuesReadRecently); + valuesRead_ += valuesReadRecently; } - this->SyncValuesOut(); + this->syncValuesOut(); } template <> -void TestPrimitiveWriter::ReadAndCompare( +void TestPrimitiveWriter::readAndCompare( Compression::type compression, - int64_t num_rows, - bool page_checksum_verify) { - this->SetupValuesOut(num_rows); - this->ReadColumnFully(compression, page_checksum_verify); + int64_t numRows, + bool pageChecksumVerify) { + this->setupValuesOut(numRows); + this->readColumnFully(compression, pageChecksumVerify); - auto comparator = MakeComparator(Type::INT96, SortOrder::SIGNED); + auto Comparator = makeComparator(Type::kInt96, SortOrder::kSigned); for (size_t i = 0; i < this->values_.size(); i++) { - if (comparator->Compare(this->values_[i], this->values_out_[i]) || - comparator->Compare(this->values_out_[i], this->values_[i])) { + if (Comparator->compare(this->values_[i], this->valuesOut_[i]) || + Comparator->compare(this->valuesOut_[i], this->values_[i])) { ARROW_SCOPED_TRACE("i = ", i); } - ASSERT_FALSE(comparator->Compare(this->values_[i], this->values_out_[i])); - ASSERT_FALSE(comparator->Compare(this->values_out_[i], this->values_[i])); + ASSERT_FALSE(Comparator->compare(this->values_[i], this->valuesOut_[i])); + ASSERT_FALSE(Comparator->compare(this->valuesOut_[i], this->values_[i])); } - ASSERT_EQ(this->values_, this->values_out_); + ASSERT_EQ(this->values_, this->valuesOut_); } template <> -void TestPrimitiveWriter::ReadColumnFully( +void TestPrimitiveWriter::readColumnFully( Compression::type compression, - bool page_checksum_verify) { - int64_t total_values = static_cast(this->values_out_.size()); - BuildReader(total_values, compression, page_checksum_verify); - this->data_buffer_.clear(); - - values_read_ = 0; - while (values_read_ < total_values) { - int64_t values_read_recently = 0; - reader_->ReadBatch( - static_cast(this->values_out_.size()) - - static_cast(values_read_), - definition_levels_out_.data() + values_read_, - repetition_levels_out_.data() + values_read_, - this->values_out_ptr_ + values_read_, - &values_read_recently); - - // Copy contents of the pointers - std::vector data( - values_read_recently * this->descr_->type_length()); - uint8_t* data_ptr = data.data(); - for (int64_t i = 0; i < values_read_recently; i++) { + bool pageChecksumVerify) { + int64_t totalValues = static_cast(this->valuesOut_.size()); + buildReader(totalValues, compression, pageChecksumVerify); + this->dataBuffer_.clear(); + + valuesRead_ = 0; + while (valuesRead_ < totalValues) { + int64_t valuesReadRecently = 0; + reader_->readBatch( + static_cast(this->valuesOut_.size()) - + static_cast(valuesRead_), + definitionLevelsOut_.data() + valuesRead_, + repetitionLevelsOut_.data() + valuesRead_, + this->valuesOutPtr_ + valuesRead_, + &valuesReadRecently); + + // Copy contents of the pointers. + std::vector data(valuesReadRecently * this->descr_->typeLength()); + uint8_t* dataPtr = data.data(); + for (int64_t i = 0; i < valuesReadRecently; i++) { memcpy( - data_ptr + this->descr_->type_length() * i, - this->values_out_[i + values_read_].ptr, - this->descr_->type_length()); - this->values_out_[i + values_read_].ptr = - data_ptr + this->descr_->type_length() * i; + dataPtr + this->descr_->typeLength() * i, + this->valuesOut_[i + valuesRead_].ptr, + this->descr_->typeLength()); + this->valuesOut_[i + valuesRead_].ptr = + dataPtr + this->descr_->typeLength() * i; } - data_buffer_.emplace_back(std::move(data)); + dataBuffer_.emplace_back(std::move(data)); - values_read_ += values_read_recently; + valuesRead_ += valuesReadRecently; } - this->SyncValuesOut(); + this->syncValuesOut(); } typedef ::testing::Types< @@ -524,11 +516,11 @@ using TestByteArrayValuesWriter = TestPrimitiveWriter; using TestFixedLengthByteArrayValuesWriter = TestPrimitiveWriter; TYPED_TEST(TestPrimitiveWriter, RequiredPlain) { - this->TestRequiredWithEncoding(Encoding::PLAIN); + this->testRequiredWithEncoding(Encoding::kPlain); } TYPED_TEST(TestPrimitiveWriter, RequiredDictionary) { - this->TestRequiredWithEncoding(Encoding::PLAIN_DICTIONARY); + this->testRequiredWithEncoding(Encoding::kPlainDictionary); } /* @@ -542,15 +534,15 @@ TYPED_TEST(TestPrimitiveWriter, RequiredBitPacked) { */ TEST_F(TestValuesWriterInt32Type, RequiredDeltaBinaryPacked) { - this->TestRequiredWithEncoding(Encoding::DELTA_BINARY_PACKED); + this->testRequiredWithEncoding(Encoding::kDeltaBinaryPacked); } TEST_F(TestValuesWriterInt64Type, RequiredDeltaBinaryPacked) { - this->TestRequiredWithEncoding(Encoding::DELTA_BINARY_PACKED); + this->testRequiredWithEncoding(Encoding::kDeltaBinaryPacked); } TEST_F(TestByteArrayValuesWriter, RequiredDeltaLengthByteArray) { - this->TestRequiredWithEncoding(Encoding::DELTA_LENGTH_BYTE_ARRAY); + this->testRequiredWithEncoding(Encoding::kDeltaLengthByteArray); } /* @@ -564,432 +556,423 @@ TEST_F(TestFixedLengthByteArrayValuesWriter, RequiredDeltaByteArray) { */ TYPED_TEST(TestPrimitiveWriter, RequiredRLEDictionary) { - this->TestRequiredWithEncoding(Encoding::RLE_DICTIONARY); + this->testRequiredWithEncoding(Encoding::kRleDictionary); } TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithStats) { - this->TestRequiredWithSettings( - Encoding::PLAIN, Compression::UNCOMPRESSED, false, true, LARGE_SIZE); + this->testRequiredWithSettings( + Encoding::kPlain, Compression::UNCOMPRESSED, false, true, LARGE_SIZE); } TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithSnappyCompression) { - this->TestRequiredWithSettings( - Encoding::PLAIN, Compression::SNAPPY, false, false, LARGE_SIZE); + this->testRequiredWithSettings( + Encoding::kPlain, Compression::SNAPPY, false, false, LARGE_SIZE); } TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithStatsAndSnappyCompression) { - this->TestRequiredWithSettings( - Encoding::PLAIN, Compression::SNAPPY, false, true, LARGE_SIZE); + this->testRequiredWithSettings( + Encoding::kPlain, Compression::SNAPPY, false, true, LARGE_SIZE); } #ifdef ARROW_WITH_BROTLI TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithBrotliCompression) { - this->TestRequiredWithSettings( - Encoding::PLAIN, Compression::BROTLI, false, false, LARGE_SIZE); + this->testRequiredWithSettings( + Encoding::kPlain, Compression::BROTLI, false, false, LARGE_SIZE); } TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithBrotliCompressionAndLevel) { - this->TestRequiredWithSettings( - Encoding::PLAIN, Compression::BROTLI, false, false, LARGE_SIZE, 10); + this->testRequiredWithSettings( + Encoding::kPlain, Compression::BROTLI, false, false, LARGE_SIZE, 10); } TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithStatsAndBrotliCompression) { - this->TestRequiredWithSettings( - Encoding::PLAIN, Compression::BROTLI, false, true, LARGE_SIZE); + this->testRequiredWithSettings( + Encoding::kPlain, Compression::BROTLI, false, true, LARGE_SIZE); } #endif TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithGzipCompression) { - this->TestRequiredWithSettings( - Encoding::PLAIN, Compression::GZIP, false, false, LARGE_SIZE); + this->testRequiredWithSettings( + Encoding::kPlain, Compression::GZIP, false, false, LARGE_SIZE); } TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithGzipCompressionAndLevel) { - this->TestRequiredWithSettings( - Encoding::PLAIN, Compression::GZIP, false, false, LARGE_SIZE, 10); + this->testRequiredWithSettings( + Encoding::kPlain, Compression::GZIP, false, false, LARGE_SIZE, 10); } TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithStatsAndGzipCompression) { - this->TestRequiredWithSettings( - Encoding::PLAIN, Compression::GZIP, false, true, LARGE_SIZE); + this->testRequiredWithSettings( + Encoding::kPlain, Compression::GZIP, false, true, LARGE_SIZE); } TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithLz4Compression) { - this->TestRequiredWithSettings( - Encoding::PLAIN, Compression::LZ4, false, false, LARGE_SIZE); + this->testRequiredWithSettings( + Encoding::kPlain, Compression::LZ4, false, false, LARGE_SIZE); } TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithStatsAndLz4Compression) { - this->TestRequiredWithSettings( - Encoding::PLAIN, Compression::LZ4, false, true, LARGE_SIZE); + this->testRequiredWithSettings( + Encoding::kPlain, Compression::LZ4, false, true, LARGE_SIZE); } TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithZstdCompression) { - this->TestRequiredWithSettings( - Encoding::PLAIN, Compression::ZSTD, false, false, LARGE_SIZE); + this->testRequiredWithSettings( + Encoding::kPlain, Compression::ZSTD, false, false, LARGE_SIZE); } TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithZstdCompressionAndLevel) { - this->TestRequiredWithSettings( - Encoding::PLAIN, Compression::ZSTD, false, false, LARGE_SIZE, 6); + this->testRequiredWithSettings( + Encoding::kPlain, Compression::ZSTD, false, false, LARGE_SIZE, 6); } TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithStatsAndZstdCompression) { - this->TestRequiredWithSettings( - Encoding::PLAIN, Compression::ZSTD, false, true, LARGE_SIZE); + this->testRequiredWithSettings( + Encoding::kPlain, Compression::ZSTD, false, true, LARGE_SIZE); } TYPED_TEST(TestPrimitiveWriter, Optional) { - // Optional and non-repeated, with definition levels - // but no repetition levels - this->SetUpSchema(Repetition::OPTIONAL); + // Optional and non-repeated, with definition levels. + // But no repetition levels. + this->setUpSchema(Repetition::kOptional); - this->GenerateData(SMALL_SIZE); - std::vector definition_levels(SMALL_SIZE, 1); - definition_levels[1] = 0; + this->generateData(SMALL_SIZE); + std::vector definitionLevels(SMALL_SIZE, 1); + definitionLevels[1] = 0; - auto writer = this->BuildWriter(); - writer->WriteBatch( - this->values_.size(), - definition_levels.data(), - nullptr, - this->values_ptr_); - writer->Close(); + auto writer = this->buildWriter(); + writer->writeBatch( + this->values_.size(), definitionLevels.data(), nullptr, this->valuesPtr_); + writer->close(); - // PARQUET-703 - ASSERT_EQ(100, this->metadata_num_values()); + // PARQUET-703. + ASSERT_EQ(100, this->metadataNumValues()); - this->ReadColumn(); - ASSERT_EQ(99, this->values_read_); - this->values_out_.resize(99); + this->readColumn(); + ASSERT_EQ(99, this->valuesRead_); + this->valuesOut_.resize(99); this->values_.resize(99); - ASSERT_EQ(this->values_, this->values_out_); + ASSERT_EQ(this->values_, this->valuesOut_); } TYPED_TEST(TestPrimitiveWriter, OptionalSpaced) { - // Optional and non-repeated, with definition levels - // but no repetition levels - this->SetUpSchema(Repetition::OPTIONAL); + // Optional and non-repeated, with definition levels. + // But no repetition levels. + this->setUpSchema(Repetition::kOptional); - this->GenerateData(SMALL_SIZE); - std::vector definition_levels(SMALL_SIZE, 1); - std::vector valid_bits( + this->generateData(SMALL_SIZE); + std::vector definitionLevels(SMALL_SIZE, 1); + std::vector validBits( ::arrow::bit_util::BytesForBits(SMALL_SIZE), 255); - definition_levels[SMALL_SIZE - 1] = 0; - ::arrow::bit_util::ClearBit(valid_bits.data(), SMALL_SIZE - 1); - definition_levels[1] = 0; - ::arrow::bit_util::ClearBit(valid_bits.data(), 1); + definitionLevels[SMALL_SIZE - 1] = 0; + ::arrow::bit_util::ClearBit(validBits.data(), SMALL_SIZE - 1); + definitionLevels[1] = 0; + ::arrow::bit_util::ClearBit(validBits.data(), 1); - auto writer = this->BuildWriter(); - writer->WriteBatchSpaced( + auto writer = this->buildWriter(); + writer->writeBatchSpaced( this->values_.size(), - definition_levels.data(), + definitionLevels.data(), nullptr, - valid_bits.data(), + validBits.data(), 0, - this->values_ptr_); - writer->Close(); + this->valuesPtr_); + writer->close(); - // PARQUET-703 - ASSERT_EQ(100, this->metadata_num_values()); + // PARQUET-703. + ASSERT_EQ(100, this->metadataNumValues()); - this->ReadColumn(); - ASSERT_EQ(98, this->values_read_); - this->values_out_.resize(98); + this->readColumn(); + ASSERT_EQ(98, this->valuesRead_); + this->valuesOut_.resize(98); this->values_.resize(99); this->values_.erase(this->values_.begin() + 1); - ASSERT_EQ(this->values_, this->values_out_); + ASSERT_EQ(this->values_, this->valuesOut_); } TYPED_TEST(TestPrimitiveWriter, Repeated) { - // Optional and repeated, so definition and repetition levels - this->SetUpSchema(Repetition::REPEATED); + // Optional and repeated, so definition and repetition levels. + this->setUpSchema(Repetition::kRepeated); - this->GenerateData(SMALL_SIZE); - std::vector definition_levels(SMALL_SIZE, 1); - definition_levels[1] = 0; - std::vector repetition_levels(SMALL_SIZE, 0); + this->generateData(SMALL_SIZE); + std::vector definitionLevels(SMALL_SIZE, 1); + definitionLevels[1] = 0; + std::vector repetitionLevels(SMALL_SIZE, 0); - auto writer = this->BuildWriter(); - writer->WriteBatch( + auto writer = this->buildWriter(); + writer->writeBatch( this->values_.size(), - definition_levels.data(), - repetition_levels.data(), - this->values_ptr_); - writer->Close(); - - this->ReadColumn(); - ASSERT_EQ(SMALL_SIZE - 1, this->values_read_); - this->values_out_.resize(SMALL_SIZE - 1); + definitionLevels.data(), + repetitionLevels.data(), + this->valuesPtr_); + writer->close(); + + this->readColumn(); + ASSERT_EQ(SMALL_SIZE - 1, this->valuesRead_); + this->valuesOut_.resize(SMALL_SIZE - 1); this->values_.resize(SMALL_SIZE - 1); - ASSERT_EQ(this->values_, this->values_out_); + ASSERT_EQ(this->values_, this->valuesOut_); } TYPED_TEST(TestPrimitiveWriter, RequiredLargeChunk) { - this->GenerateData(LARGE_SIZE); + this->generateData(LARGE_SIZE); - // Test case 1: required and non-repeated, so no definition or repetition - // levels - auto writer = this->BuildWriter(LARGE_SIZE); - writer->WriteBatch(this->values_.size(), nullptr, nullptr, this->values_ptr_); - writer->Close(); + // Test case 1: required and non-repeated, so no definition or repetition. + // Levels. + auto writer = this->buildWriter(LARGE_SIZE); + writer->writeBatch(this->values_.size(), nullptr, nullptr, this->valuesPtr_); + writer->close(); - // Just read the first SMALL_SIZE rows to ensure we could read it back in - this->ReadColumn(); - ASSERT_EQ(SMALL_SIZE, this->values_read_); + // Just read the first SMALL_SIZE rows to ensure we could read it back in. + this->readColumn(); + ASSERT_EQ(SMALL_SIZE, this->valuesRead_); this->values_.resize(SMALL_SIZE); - ASSERT_EQ(this->values_, this->values_out_); + ASSERT_EQ(this->values_, this->valuesOut_); } -// Test cases for dictionary fallback encoding -TYPED_TEST(TestPrimitiveWriter, DictionaryFallbackVersion1_0) { - this->TestDictionaryFallbackEncoding(ParquetVersion::PARQUET_1_0); +// Test cases for dictionary fallback encoding. +TYPED_TEST(TestPrimitiveWriter, dictionaryfallbackversion10) { + this->testDictionaryFallbackEncoding(ParquetVersion::PARQUET_1_0); } -TYPED_TEST(TestPrimitiveWriter, DictionaryFallbackVersion2_0) { - this->TestDictionaryFallbackEncoding(ParquetVersion::PARQUET_2_4); - this->TestDictionaryFallbackEncoding(ParquetVersion::PARQUET_2_6); +TYPED_TEST(TestPrimitiveWriter, dictionaryfallbackversion20) { + this->testDictionaryFallbackEncoding(ParquetVersion::PARQUET_2_4); + this->testDictionaryFallbackEncoding(ParquetVersion::PARQUET_2_6); } TEST(TestWriter, NullValuesBuffer) { - std::shared_ptr<::arrow::io::BufferOutputStream> sink = CreateOutputStream(); - - const auto item_node = schema::PrimitiveNode::Make( - "item", Repetition::REQUIRED, LogicalType::Int(32, true), Type::INT32); - const auto list_node = - schema::GroupNode::Make("list", Repetition::REPEATED, {item_node}); - const auto column_node = schema::GroupNode::Make( + std::shared_ptr<::arrow::io::BufferOutputStream> sink = createOutputStream(); + + const auto itemNode = schema::PrimitiveNode::make( + "item", + Repetition::kRequired, + LogicalType::intType(32, true), + Type::kInt32); + const auto listNode = + schema::GroupNode::make("list", Repetition::kRepeated, {itemNode}); + const auto columnNode = schema::GroupNode::make( "array_of_ints_column", - Repetition::OPTIONAL, - {list_node}, - LogicalType::List()); - const auto schema_node = - schema::GroupNode::Make("schema", Repetition::REQUIRED, {column_node}); - - auto file_writer = ParquetFileWriter::Open( - sink, std::dynamic_pointer_cast(schema_node)); - auto group_writer = file_writer->AppendRowGroup(); - auto column_writer = group_writer->NextColumn(); - auto typed_writer = dynamic_cast(column_writer); - - const int64_t num_values = 1; - const int16_t def_levels[] = {0}; - const int16_t rep_levels[] = {0}; - const uint8_t valid_bits[] = {0}; - const int64_t valid_bits_offset = 0; + Repetition::kOptional, + {listNode}, + LogicalType::list()); + const auto schemaNode = + schema::GroupNode::make("schema", Repetition::kRequired, {columnNode}); + + auto fileWriter = ParquetFileWriter::open( + sink, std::dynamic_pointer_cast(schemaNode)); + auto groupWriter = fileWriter->appendRowGroup(); + auto columnWriter = groupWriter->nextColumn(); + auto typedWriter = dynamic_cast(columnWriter); + + const int64_t numValues = 1; + const int16_t defLevels[] = {0}; + const int16_t repLevels[] = {0}; + const uint8_t validBits[] = {0}; + const int64_t validBitsOffset = 0; const int32_t* values = nullptr; - typed_writer->WriteBatchSpaced( - num_values, - def_levels, - rep_levels, - valid_bits, - valid_bits_offset, - values); + typedWriter->writeBatchSpaced( + numValues, defLevels, repLevels, validBits, validBitsOffset, values); } TYPED_TEST(TestPrimitiveWriter, RequiredPlainChecksum) { - this->TestRequiredWithSettings( - Encoding::PLAIN, + this->testRequiredWithSettings( + Encoding::kPlain, Compression::UNCOMPRESSED, /* enable_dictionary */ false, false, SMALL_SIZE, - Codec::UseDefaultCompressionLevel(), + Codec::useDefaultCompressionLevel(), /* enable_checksum */ true); } TYPED_TEST(TestPrimitiveWriter, RequiredDictChecksum) { - this->TestRequiredWithSettings( - Encoding::PLAIN, + this->testRequiredWithSettings( + Encoding::kPlain, Compression::UNCOMPRESSED, /* enable_dictionary */ true, false, SMALL_SIZE, - Codec::UseDefaultCompressionLevel(), + Codec::useDefaultCompressionLevel(), /* enable_checksum */ true); } -// PARQUET-719 -// Test case for NULL values +// PARQUET-719. +// Test case for NULL values. TEST_F(TestValuesWriterInt32Type, OptionalNullValueChunk) { - this->SetUpSchema(Repetition::OPTIONAL); + this->setUpSchema(Repetition::kOptional); - this->GenerateData(LARGE_SIZE); + this->generateData(LARGE_SIZE); - std::vector definition_levels(LARGE_SIZE, 0); - std::vector repetition_levels(LARGE_SIZE, 0); + std::vector definitionLevels(LARGE_SIZE, 0); + std::vector repetitionLevels(LARGE_SIZE, 0); - auto writer = this->BuildWriter(LARGE_SIZE); - // All values being written are NULL - writer->WriteBatch( + auto writer = this->buildWriter(LARGE_SIZE); + // All values being written are NULL. + writer->writeBatch( this->values_.size(), - definition_levels.data(), - repetition_levels.data(), + definitionLevels.data(), + repetitionLevels.data(), nullptr); - writer->Close(); + writer->close(); - // Just read the first SMALL_SIZE rows to ensure we could read it back in - this->ReadColumn(); - ASSERT_EQ(0, this->values_read_); + // Just read the first SMALL_SIZE rows to ensure we could read it back in. + this->readColumn(); + ASSERT_EQ(0, this->valuesRead_); } -// PARQUET-764 -// Correct bitpacking for boolean write at non-byte boundaries +// PARQUET-764. +// Correct bitpacking for boolean write at non-byte boundaries. using TestBooleanValuesWriter = TestPrimitiveWriter; TEST_F(TestBooleanValuesWriter, AlternateBooleanValues) { - this->SetUpSchema(Repetition::REQUIRED); - auto writer = this->BuildWriter(); + this->setUpSchema(Repetition::kRequired); + auto writer = this->buildWriter(); for (int i = 0; i < SMALL_SIZE; i++) { bool value = (i % 2 == 0) ? true : false; - writer->WriteBatch(1, nullptr, nullptr, &value); + writer->writeBatch(1, nullptr, nullptr, &value); } - writer->Close(); - this->ReadColumn(); + writer->close(); + this->readColumn(); for (int i = 0; i < SMALL_SIZE; i++) { - ASSERT_EQ((i % 2 == 0) ? true : false, this->values_out_[i]) << i; + ASSERT_EQ((i % 2 == 0) ? true : false, this->valuesOut_[i]) << i; } } -// PARQUET-979 -// Prevent writing large MIN, MAX stats +// PARQUET-979. +// Prevent writing large MIN, MAX stats. TEST_F(TestByteArrayValuesWriter, OmitStats) { - int min_len = 1024 * 4; - int max_len = 1024 * 8; - this->SetUpSchema(Repetition::REQUIRED); - auto writer = this->BuildWriter(); + int minLen = 1024 * 4; + int maxLen = 1024 * 8; + this->setUpSchema(Repetition::kRequired); + auto writer = this->buildWriter(); values_.resize(SMALL_SIZE); - InitWideByteArrayValues( - SMALL_SIZE, this->values_, this->buffer_, min_len, max_len); - writer->WriteBatch(SMALL_SIZE, nullptr, nullptr, this->values_.data()); - writer->Close(); - - auto has_min_max = this->metadata_stats_has_min_max(); - ASSERT_FALSE(has_min_max.first); - ASSERT_FALSE(has_min_max.second); + initWideByteArrayValues( + SMALL_SIZE, this->values_, this->buffer_, minLen, maxLen); + writer->writeBatch(SMALL_SIZE, nullptr, nullptr, this->values_.data()); + writer->close(); + + auto hasMinMax = this->metadataStatsHasMinMax(); + ASSERT_FALSE(hasMinMax.first); + ASSERT_FALSE(hasMinMax.second); } -// PARQUET-1405 -// Prevent writing large stats in the DataPageHeader +// PARQUET-1405. +// Prevent writing large stats in the DataPageHeader. TEST_F(TestByteArrayValuesWriter, OmitDataPageStats) { - int min_len = static_cast(std::pow(10, 7)); - int max_len = static_cast(std::pow(10, 7)); - this->SetUpSchema(Repetition::REQUIRED); - ColumnProperties column_properties; - column_properties.set_statistics_enabled(false); - auto writer = this->BuildWriter(SMALL_SIZE, column_properties); + int minLen = static_cast(std::pow(10, 7)); + int maxLen = static_cast(std::pow(10, 7)); + this->setUpSchema(Repetition::kRequired); + ColumnProperties columnProperties; + columnProperties.setStatisticsEnabled(false); + auto writer = this->buildWriter(SMALL_SIZE, columnProperties); values_.resize(1); - InitWideByteArrayValues(1, this->values_, this->buffer_, min_len, max_len); - writer->WriteBatch(1, nullptr, nullptr, this->values_.data()); - writer->Close(); + initWideByteArrayValues(1, this->values_, this->buffer_, minLen, maxLen); + writer->writeBatch(1, nullptr, nullptr, this->values_.data()); + writer->close(); - ASSERT_NO_THROW(this->ReadColumn()); + ASSERT_NO_THROW(this->readColumn()); } TEST_F(TestByteArrayValuesWriter, LimitStats) { - int min_len = 1024 * 4; - int max_len = 1024 * 8; - this->SetUpSchema(Repetition::REQUIRED); - ColumnProperties column_properties; - column_properties.set_max_statistics_size(static_cast(max_len)); - auto writer = this->BuildWriter(SMALL_SIZE, column_properties); + int minLen = 1024 * 4; + int maxLen = 1024 * 8; + this->setUpSchema(Repetition::kRequired); + ColumnProperties columnProperties; + columnProperties.setMaxStatisticsSize(static_cast(maxLen)); + auto writer = this->buildWriter(SMALL_SIZE, columnProperties); values_.resize(SMALL_SIZE); - InitWideByteArrayValues( - SMALL_SIZE, this->values_, this->buffer_, min_len, max_len); - writer->WriteBatch(SMALL_SIZE, nullptr, nullptr, this->values_.data()); - writer->Close(); + initWideByteArrayValues( + SMALL_SIZE, this->values_, this->buffer_, minLen, maxLen); + writer->writeBatch(SMALL_SIZE, nullptr, nullptr, this->values_.data()); + writer->close(); - ASSERT_TRUE(this->metadata_is_stats_set()); + ASSERT_TRUE(this->metadataIsStatsSet()); } TEST_F(TestByteArrayValuesWriter, CheckDefaultStats) { - this->SetUpSchema(Repetition::REQUIRED); - auto writer = this->BuildWriter(); - this->GenerateData(SMALL_SIZE); + this->setUpSchema(Repetition::kRequired); + auto writer = this->buildWriter(); + this->generateData(SMALL_SIZE); - writer->WriteBatch(SMALL_SIZE, nullptr, nullptr, this->values_ptr_); - writer->Close(); + writer->writeBatch(SMALL_SIZE, nullptr, nullptr, this->valuesPtr_); + writer->close(); - ASSERT_TRUE(this->metadata_is_stats_set()); + ASSERT_TRUE(this->metadataIsStatsSet()); } TEST(TestColumnWriter, RepeatedListsUpdateSpacedBug) { - // In ARROW-3930 we discovered a bug when writing from Arrow when we had data - // that looks like this: + // In ARROW-3930 we discovered a bug when writing from Arrow when we had data. + // That looks like this: // - // [null, [0, 1, null, 2, 3, 4, null]] + // [Null, [0, 1, null, 2, 3, 4, null]]. - // Create schema - NodePtr item = schema::Int32("item"); // optional item - NodePtr list( - GroupNode::Make("b", Repetition::REPEATED, {item}, ConvertedType::LIST)); + // Create schema. + NodePtr item = schema::int32("item"); // optional item + NodePtr List( + GroupNode::make( + "b", Repetition::kRepeated, {item}, ConvertedType::kList)); NodePtr bag( - GroupNode::Make("bag", Repetition::OPTIONAL, {list})); // optional list + GroupNode::make("bag", Repetition::kOptional, {List})); // optional list std::vector fields = {bag}; - NodePtr root = GroupNode::Make("schema", Repetition::REPEATED, fields); + NodePtr root = GroupNode::make("schema", Repetition::kRepeated, fields); SchemaDescriptor schema; - schema.Init(root); + schema.init(root); - auto sink = CreateOutputStream(); + auto sink = createOutputStream(); auto props = WriterProperties::Builder().build(); - auto metadata = ColumnChunkMetaDataBuilder::Make(props, schema.Column(0)); - std::unique_ptr pager = PageWriter::Open( + auto metadata = ColumnChunkMetaDataBuilder::make(props, schema.column(0)); + std::unique_ptr pager = PageWriter::open( sink, Compression::UNCOMPRESSED, - Codec::UseDefaultCompressionLevel(), + Codec::useDefaultCompressionLevel(), metadata.get()); std::shared_ptr writer = - ColumnWriter::Make(metadata.get(), std::move(pager), props.get()); - auto typed_writer = + ColumnWriter::make(metadata.get(), std::move(pager), props.get()); + auto typedWriter = std::static_pointer_cast>(writer); - std::vector def_levels = {1, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3}; - std::vector rep_levels = {0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + std::vector defLevels = {1, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3}; + std::vector repLevels = {0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; std::vector values = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; - // Write the values into uninitialized memory - ASSERT_OK_AND_ASSIGN(auto values_buffer, ::arrow::AllocateBuffer(64)); - memcpy(values_buffer->mutable_data(), values.data(), 13 * sizeof(int32_t)); - auto values_data = reinterpret_cast(values_buffer->data()); + // Write the values into uninitialized memory. + ASSERT_OK_AND_ASSIGN(auto valuesBuffer, ::arrow::AllocateBuffer(64)); + memcpy(valuesBuffer->mutable_data(), values.data(), 13 * sizeof(int32_t)); + auto valuesData = reinterpret_cast(valuesBuffer->data()); - std::shared_ptr valid_bits; + std::shared_ptr validBits; ASSERT_OK_AND_ASSIGN( - valid_bits, + validBits, ::arrow::internal::BytesToBits({1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1})); - // valgrind will warn about out of bounds access into def_levels_data - typed_writer->WriteBatchSpaced( - 14, - def_levels.data(), - rep_levels.data(), - valid_bits->data(), - 0, - values_data); - writer->Close(); + // Valgrind will warn about out of bounds access into def_levels_data. + typedWriter->writeBatchSpaced( + 14, defLevels.data(), repLevels.data(), validBits->data(), 0, valuesData); + writer->close(); } -void GenerateLevels( - int min_repeat_factor, - int max_repeat_factor, - int max_level, - std::vector& input_levels) { - // for each repetition count up to max_repeat_factor - for (int repeat = min_repeat_factor; repeat <= max_repeat_factor; repeat++) { - // repeat count increases by a factor of 2 for every iteration - int repeat_count = (1 << repeat); - // generate levels for repetition count up to the maximum level +void generateLevels( + int minRepeatFactor, + int maxRepeatFactor, + int maxLevel, + std::vector& inputLevels) { + // For each repetition count up to max_repeat_factor. + for (int repeat = minRepeatFactor; repeat <= maxRepeatFactor; repeat++) { + // Repeat count increases by a factor of 2 for every iteration. + int repeatCount = (1 << repeat); + // Generate levels for repetition count up to the maximum level. int16_t value = 0; int bwidth = 0; - while (value <= max_level) { - for (int i = 0; i < repeat_count; i++) { - input_levels.push_back(value); + while (value <= maxLevel) { + for (int i = 0; i < repeatCount; i++) { + inputLevels.push_back(value); } value = static_cast((2 << bwidth) - 1); bwidth++; @@ -997,186 +980,185 @@ void GenerateLevels( } } -void EncodeLevels( +void encodeLevels( Encoding::type encoding, - int16_t max_level, - int num_levels, - const int16_t* input_levels, + int16_t maxLevel, + int numLevels, + const int16_t* inputLevels, std::vector& bytes) { LevelEncoder encoder; - int levels_count = 0; - bytes.resize(2 * num_levels); - ASSERT_EQ(2 * num_levels, static_cast(bytes.size())); - // encode levels - if (encoding == Encoding::RLE) { - // leave space to write the rle length value - encoder.Init( + int levelsCount = 0; + bytes.resize(2 * numLevels); + ASSERT_EQ(2 * numLevels, static_cast(bytes.size())); + // Encode levels. + if (encoding == Encoding::kRle) { + // Leave space to write the rle length value. + encoder.init( encoding, - max_level, - num_levels, + maxLevel, + numLevels, bytes.data() + sizeof(int32_t), static_cast(bytes.size())); - levels_count = encoder.Encode(num_levels, input_levels); + levelsCount = encoder.encode(numLevels, inputLevels); (reinterpret_cast(bytes.data()))[0] = encoder.len(); } else { - encoder.Init( + encoder.init( encoding, - max_level, - num_levels, + maxLevel, + numLevels, bytes.data(), static_cast(bytes.size())); - levels_count = encoder.Encode(num_levels, input_levels); + levelsCount = encoder.encode(numLevels, inputLevels); } - ASSERT_EQ(num_levels, levels_count); + ASSERT_EQ(numLevels, levelsCount); } -void VerifyDecodingLevels( +void verifyDecodingLevels( Encoding::type encoding, - int16_t max_level, - std::vector& input_levels, + int16_t maxLevel, + std::vector& inputLevels, std::vector& bytes) { LevelDecoder decoder; - int levels_count = 0; - std::vector output_levels; - int num_levels = static_cast(input_levels.size()); + int levelsCount = 0; + std::vector outputLevels; + int numLevels = static_cast(inputLevels.size()); - output_levels.resize(num_levels); - ASSERT_EQ(num_levels, static_cast(output_levels.size())); + outputLevels.resize(numLevels); + ASSERT_EQ(numLevels, static_cast(outputLevels.size())); - // Decode levels and test with multiple decode calls - decoder.SetData( + // Decode levels and test with multiple decode calls. + decoder.setData( encoding, - max_level, - num_levels, + maxLevel, + numLevels, bytes.data(), static_cast(bytes.size())); - int decode_count = 4; - int num_inner_levels = num_levels / decode_count; - // Try multiple decoding on a single SetData call - for (int ct = 0; ct < decode_count; ct++) { - int offset = ct * num_inner_levels; - levels_count = decoder.Decode(num_inner_levels, output_levels.data()); - ASSERT_EQ(num_inner_levels, levels_count); - for (int i = 0; i < num_inner_levels; i++) { - EXPECT_EQ(input_levels[i + offset], output_levels[i]); + int decodeCount = 4; + int numInnerLevels = numLevels / decodeCount; + // Try multiple decoding on a single SetData call. + for (int ct = 0; ct < decodeCount; ct++) { + int Offset = ct * numInnerLevels; + levelsCount = decoder.decode(numInnerLevels, outputLevels.data()); + ASSERT_EQ(numInnerLevels, levelsCount); + for (int i = 0; i < numInnerLevels; i++) { + EXPECT_EQ(inputLevels[i + Offset], outputLevels[i]); } } - // check the remaining levels - int num_levels_completed = decode_count * (num_levels / decode_count); - int num_remaining_levels = num_levels - num_levels_completed; - if (num_remaining_levels > 0) { - levels_count = decoder.Decode(num_remaining_levels, output_levels.data()); - ASSERT_EQ(num_remaining_levels, levels_count); - for (int i = 0; i < num_remaining_levels; i++) { - EXPECT_EQ(input_levels[i + num_levels_completed], output_levels[i]); + // Check the remaining levels. + int numLevelsCompleted = decodeCount * (numLevels / decodeCount); + int numRemainingLevels = numLevels - numLevelsCompleted; + if (numRemainingLevels > 0) { + levelsCount = decoder.decode(numRemainingLevels, outputLevels.data()); + ASSERT_EQ(numRemainingLevels, levelsCount); + for (int i = 0; i < numRemainingLevels; i++) { + EXPECT_EQ(inputLevels[i + numLevelsCompleted], outputLevels[i]); } } - // Test zero Decode values - ASSERT_EQ(0, decoder.Decode(1, output_levels.data())); + // Test zero Decode values. + ASSERT_EQ(0, decoder.decode(1, outputLevels.data())); } -void VerifyDecodingMultipleSetData( +void verifyDecodingMultipleSetData( Encoding::type encoding, - int16_t max_level, - std::vector& input_levels, + int16_t maxLevel, + std::vector& inputLevels, std::vector>& bytes) { LevelDecoder decoder; - int levels_count = 0; - std::vector output_levels; - - // Decode levels and test with multiple SetData calls - int setdata_count = static_cast(bytes.size()); - int num_levels = static_cast(input_levels.size()) / setdata_count; - output_levels.resize(num_levels); - // Try multiple SetData - for (int ct = 0; ct < setdata_count; ct++) { - int offset = ct * num_levels; - ASSERT_EQ(num_levels, static_cast(output_levels.size())); - decoder.SetData( + int levelsCount = 0; + std::vector outputLevels; + + // Decode levels and test with multiple SetData calls. + int setdataCount = static_cast(bytes.size()); + int numLevels = static_cast(inputLevels.size()) / setdataCount; + outputLevels.resize(numLevels); + // Try multiple SetData. + for (int ct = 0; ct < setdataCount; ct++) { + int Offset = ct * numLevels; + ASSERT_EQ(numLevels, static_cast(outputLevels.size())); + decoder.setData( encoding, - max_level, - num_levels, + maxLevel, + numLevels, bytes[ct].data(), static_cast(bytes[ct].size())); - levels_count = decoder.Decode(num_levels, output_levels.data()); - ASSERT_EQ(num_levels, levels_count); - for (int i = 0; i < num_levels; i++) { - EXPECT_EQ(input_levels[i + offset], output_levels[i]); + levelsCount = decoder.decode(numLevels, outputLevels.data()); + ASSERT_EQ(numLevels, levelsCount); + for (int i = 0; i < numLevels; i++) { + EXPECT_EQ(inputLevels[i + Offset], outputLevels[i]); } } } -// Test levels with maximum bit-width from 1 to 8 -// increase the repetition count for each iteration by a factor of 2 +// Test levels with maximum bit-width from 1 to 8. +// Increase the repetition count for each iteration by a factor of 2. TEST(TestLevels, TestLevelsDecodeMultipleBitWidth) { - int min_repeat_factor = 0; - int max_repeat_factor = 7; // 128 - int max_bit_width = 8; - std::vector input_levels; + int minRepeatFactor = 0; + int maxRepeatFactor = 7; // 128 + int maxBitWidth = 8; + std::vector inputLevels; std::vector bytes; - Encoding::type encodings[2] = {Encoding::RLE, Encoding::BIT_PACKED}; + Encoding::type encodings[2] = {Encoding::kRle, Encoding::kBitPacked}; - // for each encoding + // For each encoding. for (int encode = 0; encode < 2; encode++) { Encoding::type encoding = encodings[encode]; - // BIT_PACKED requires a sequence of at least 8 - if (encoding == Encoding::BIT_PACKED) - min_repeat_factor = 3; - // for each maximum bit-width - for (int bit_width = 1; bit_width <= max_bit_width; bit_width++) { - // find the maximum level for the current bit_width - int16_t max_level = static_cast((1 << bit_width) - 1); - // Generate levels - GenerateLevels( - min_repeat_factor, max_repeat_factor, max_level, input_levels); - ASSERT_NO_FATAL_FAILURE(EncodeLevels( + // BIT_PACKED requires a sequence of at least 8. + if (encoding == Encoding::kBitPacked) + minRepeatFactor = 3; + // For each maximum bit-width. + for (int bitWidth = 1; bitWidth <= maxBitWidth; bitWidth++) { + // Find the maximum level for the current bit_width. + int16_t maxLevel = static_cast((1 << bitWidth) - 1); + // Generate levels. + generateLevels(minRepeatFactor, maxRepeatFactor, maxLevel, inputLevels); + ASSERT_NO_FATAL_FAILURE(encodeLevels( encoding, - max_level, - static_cast(input_levels.size()), - input_levels.data(), + maxLevel, + static_cast(inputLevels.size()), + inputLevels.data(), bytes)); ASSERT_NO_FATAL_FAILURE( - VerifyDecodingLevels(encoding, max_level, input_levels, bytes)); - input_levels.clear(); + verifyDecodingLevels(encoding, maxLevel, inputLevels, bytes)); + inputLevels.clear(); } } } -// Test multiple decoder SetData calls +// Test multiple decoder SetData calls. TEST(TestLevels, TestLevelsDecodeMultipleSetData) { - int min_repeat_factor = 3; - int max_repeat_factor = 7; // 128 - int bit_width = 8; - int16_t max_level = static_cast((1 << bit_width) - 1); - std::vector input_levels; + int minRepeatFactor = 3; + int maxRepeatFactor = 7; // 128 + int bitWidth = 8; + int16_t maxLevel = static_cast((1 << bitWidth) - 1); + std::vector inputLevels; std::vector> bytes; - Encoding::type encodings[2] = {Encoding::RLE, Encoding::BIT_PACKED}; - GenerateLevels(min_repeat_factor, max_repeat_factor, max_level, input_levels); - int num_levels = static_cast(input_levels.size()); - int setdata_factor = 8; - int split_level_size = num_levels / setdata_factor; - bytes.resize(setdata_factor); - - // for each encoding + Encoding::type encodings[2] = {Encoding::kRle, Encoding::kBitPacked}; + generateLevels(minRepeatFactor, maxRepeatFactor, maxLevel, inputLevels); + int numLevels = static_cast(inputLevels.size()); + int setdataFactor = 8; + int splitLevelSize = numLevels / setdataFactor; + bytes.resize(setdataFactor); + + // For each encoding. for (int encode = 0; encode < 2; encode++) { Encoding::type encoding = encodings[encode]; - for (int rf = 0; rf < setdata_factor; rf++) { - int offset = rf * split_level_size; - ASSERT_NO_FATAL_FAILURE(EncodeLevels( + for (int rf = 0; rf < setdataFactor; rf++) { + int Offset = rf * splitLevelSize; + ASSERT_NO_FATAL_FAILURE(encodeLevels( encoding, - max_level, - split_level_size, - reinterpret_cast(input_levels.data()) + offset, + maxLevel, + splitLevelSize, + reinterpret_cast(inputLevels.data()) + Offset, bytes[rf])); } - ASSERT_NO_FATAL_FAILURE(VerifyDecodingMultipleSetData( - encoding, max_level, input_levels, bytes)); + ASSERT_NO_FATAL_FAILURE( + verifyDecodingMultipleSetData(encoding, maxLevel, inputLevels, bytes)); } } TEST(TestLevelEncoder, MinimumBufferSize) { - // PARQUET-676, PARQUET-698 + // PARQUET-676, PARQUET-698. const int kNumToEncode = 1024; std::vector levels; @@ -1189,23 +1171,23 @@ TEST(TestLevelEncoder, MinimumBufferSize) { } std::vector output( - LevelEncoder::MaxBufferSize(Encoding::RLE, 1, kNumToEncode)); + LevelEncoder::maxBufferSize(Encoding::kRle, 1, kNumToEncode)); LevelEncoder encoder; - encoder.Init( - Encoding::RLE, + encoder.init( + Encoding::kRle, 1, kNumToEncode, output.data(), static_cast(output.size())); - int encode_count = encoder.Encode(kNumToEncode, levels.data()); + int encodeCount = encoder.encode(kNumToEncode, levels.data()); - ASSERT_EQ(kNumToEncode, encode_count); + ASSERT_EQ(kNumToEncode, encodeCount); } TEST(TestLevelEncoder, MinimumBufferSize2) { - // PARQUET-708 - // Test the worst case for bit_width=2 consisting of + // PARQUET-708. + // Test the worst case for bit_width=2 consisting of. // LiteralRun(size=8) // RepeatedRun(size=8) // LiteralRun(size=8) @@ -1214,8 +1196,8 @@ TEST(TestLevelEncoder, MinimumBufferSize2) { std::vector levels; for (int i = 0; i < kNumToEncode; ++i) { - // This forces a literal run of 00000001 - // followed by eight 1s + // This forces a literal run of 00000001. + // Followed by eight 1s. if ((i % 16) < 7) { levels.push_back(0); } else { @@ -1223,332 +1205,326 @@ TEST(TestLevelEncoder, MinimumBufferSize2) { } } - for (int16_t bit_width = 1; bit_width <= 8; bit_width++) { + for (int16_t bitWidth = 1; bitWidth <= 8; bitWidth++) { std::vector output( - LevelEncoder::MaxBufferSize(Encoding::RLE, bit_width, kNumToEncode)); + LevelEncoder::maxBufferSize(Encoding::kRle, bitWidth, kNumToEncode)); LevelEncoder encoder; - encoder.Init( - Encoding::RLE, - bit_width, + encoder.init( + Encoding::kRle, + bitWidth, kNumToEncode, output.data(), static_cast(output.size())); - int encode_count = encoder.Encode(kNumToEncode, levels.data()); + int encodeCount = encoder.encode(kNumToEncode, levels.data()); - ASSERT_EQ(kNumToEncode, encode_count); + ASSERT_EQ(kNumToEncode, encodeCount); } } TEST(TestColumnWriter, WriteDataPageV2Header) { - auto sink = CreateOutputStream(); - auto schema = std::static_pointer_cast(GroupNode::Make( + auto sink = createOutputStream(); + auto schema = std::static_pointer_cast(GroupNode::make( "schema", - Repetition::REQUIRED, + Repetition::kRequired, { - schema::Int32("required", Repetition::REQUIRED), - schema::Int32("optional", Repetition::OPTIONAL), - schema::Int32("repeated", Repetition::REPEATED), + schema::int32("required", Repetition::kRequired), + schema::int32("optional", Repetition::kOptional), + schema::int32("repeated", Repetition::kRepeated), })); auto properties = WriterProperties::Builder() - .disable_dictionary() - ->data_page_version(ParquetDataPageVersion::V2) + .disableDictionary() + ->dataPageVersion(ParquetDataPageVersion::V2) ->build(); - auto file_writer = ParquetFileWriter::Open(sink, schema, properties); - auto rg_writer = file_writer->AppendRowGroup(); + auto fileWriter = ParquetFileWriter::open(sink, schema, properties); + auto rgWriter = fileWriter->appendRowGroup(); - constexpr int32_t num_rows = 100; + constexpr int32_t numRows = 100; - auto required_writer = static_cast(rg_writer->NextColumn()); - for (int32_t i = 0; i < num_rows; i++) { - required_writer->WriteBatch(1, nullptr, nullptr, &i); + auto requiredWriter = static_cast(rgWriter->nextColumn()); + for (int32_t i = 0; i < numRows; i++) { + requiredWriter->writeBatch(1, nullptr, nullptr, &i); } // Write a null value at every other row. - auto optional_writer = static_cast(rg_writer->NextColumn()); - for (int32_t i = 0; i < num_rows; i++) { - int16_t definition_level = i % 2 == 0 ? 1 : 0; - optional_writer->WriteBatch(1, &definition_level, nullptr, &i); + auto optionalWriter = static_cast(rgWriter->nextColumn()); + for (int32_t i = 0; i < numRows; i++) { + int16_t definitionLevel = i % 2 == 0 ? 1 : 0; + optionalWriter->writeBatch(1, &definitionLevel, nullptr, &i); } // Each row has repeated twice. - auto repeated_writer = static_cast(rg_writer->NextColumn()); - for (int i = 0; i < 2 * num_rows; i++) { + auto repeatedWriter = static_cast(rgWriter->nextColumn()); + for (int i = 0; i < 2 * numRows; i++) { int32_t value = i * 1000; - int16_t definition_level = 1; - int16_t repetition_level = i % 2 == 0 ? 1 : 0; - repeated_writer->WriteBatch( - 1, &definition_level, &repetition_level, &value); + int16_t definitionLevel = 1; + int16_t repetitionLevel = i % 2 == 0 ? 1 : 0; + repeatedWriter->writeBatch(1, &definitionLevel, &repetitionLevel, &value); } - ASSERT_NO_THROW(file_writer->Close()); + ASSERT_NO_THROW(fileWriter->close()); ASSERT_OK_AND_ASSIGN(auto buffer, sink->Finish()); - auto file_reader = ParquetFileReader::Open( + auto fileReader = ParquetFileReader::open( std::make_shared<::arrow::io::BufferReader>(buffer), - default_reader_properties()); - auto metadata = file_reader->metadata(); - ASSERT_EQ(1, metadata->num_row_groups()); - auto row_group_reader = file_reader->RowGroup(0); + defaultReaderProperties()); + auto metadata = fileReader->metadata(); + ASSERT_EQ(1, metadata->numRowGroups()); + auto rowGroupReader = fileReader->rowGroup(0); // Verify required column. { - auto page_reader = row_group_reader->GetColumnPageReader(0); - auto page = page_reader->NextPage(); + auto pageReader = rowGroupReader->getColumnPageReader(0); + auto page = pageReader->nextPage(); ASSERT_NE(page, nullptr); - auto data_page = std::static_pointer_cast(page); - EXPECT_EQ(num_rows, data_page->num_rows()); - EXPECT_EQ(num_rows, data_page->num_values()); - EXPECT_EQ(0, data_page->num_nulls()); - EXPECT_EQ(page_reader->NextPage(), nullptr); + auto dataPage = std::static_pointer_cast(page); + EXPECT_EQ(numRows, dataPage->numRows()); + EXPECT_EQ(numRows, dataPage->numValues()); + EXPECT_EQ(0, dataPage->numNulls()); + EXPECT_EQ(pageReader->nextPage(), nullptr); } // Verify optional column. { - auto page_reader = row_group_reader->GetColumnPageReader(1); - auto page = page_reader->NextPage(); + auto pageReader = rowGroupReader->getColumnPageReader(1); + auto page = pageReader->nextPage(); ASSERT_NE(page, nullptr); - auto data_page = std::static_pointer_cast(page); - EXPECT_EQ(num_rows, data_page->num_rows()); - EXPECT_EQ(num_rows, data_page->num_values()); - EXPECT_EQ(num_rows / 2, data_page->num_nulls()); - EXPECT_EQ(page_reader->NextPage(), nullptr); + auto dataPage = std::static_pointer_cast(page); + EXPECT_EQ(numRows, dataPage->numRows()); + EXPECT_EQ(numRows, dataPage->numValues()); + EXPECT_EQ(numRows / 2, dataPage->numNulls()); + EXPECT_EQ(pageReader->nextPage(), nullptr); } // Verify repeated column. { - auto page_reader = row_group_reader->GetColumnPageReader(2); - auto page = page_reader->NextPage(); + auto pageReader = rowGroupReader->getColumnPageReader(2); + auto page = pageReader->nextPage(); ASSERT_NE(page, nullptr); - auto data_page = std::static_pointer_cast(page); - EXPECT_EQ(num_rows, data_page->num_rows()); - EXPECT_EQ(num_rows * 2, data_page->num_values()); - EXPECT_EQ(0, data_page->num_nulls()); - EXPECT_EQ(page_reader->NextPage(), nullptr); + auto dataPage = std::static_pointer_cast(page); + EXPECT_EQ(numRows, dataPage->numRows()); + EXPECT_EQ(numRows * 2, dataPage->numValues()); + EXPECT_EQ(0, dataPage->numNulls()); + EXPECT_EQ(pageReader->nextPage(), nullptr); } } -// The test below checks that data page v2 changes on record boundaries for +// The test below checks that data page v2 changes on record boundaries for. // all repetition types (i.e. required, optional, and repeated) TEST(TestColumnWriter, WriteDataPagesChangeOnRecordBoundaries) { - auto sink = CreateOutputStream(); - auto schema = std::static_pointer_cast(GroupNode::Make( + auto sink = createOutputStream(); + auto schema = std::static_pointer_cast(GroupNode::make( "schema", - Repetition::REQUIRED, - {schema::Int32("required", Repetition::REQUIRED), - schema::Int32("optional", Repetition::OPTIONAL), - schema::Int32("repeated", Repetition::REPEATED)})); + Repetition::kRequired, + {schema::int32("required", Repetition::kRequired), + schema::int32("optional", Repetition::kOptional), + schema::int32("repeated", Repetition::kRepeated)})); // Write at most 11 levels per batch. - constexpr int64_t batch_size = 11; + constexpr int64_t batchSize = 11; auto properties = WriterProperties::Builder() - .disable_dictionary() - ->data_page_version(ParquetDataPageVersion::V2) - ->write_batch_size(batch_size) - ->data_pagesize(1) /* every page size check creates a new page */ + .disableDictionary() + ->dataPageVersion(ParquetDataPageVersion::V2) + ->writeBatchSize(batchSize) + ->dataPagesize(1) /* every page size check creates a new page */ ->build(); - auto file_writer = ParquetFileWriter::Open(sink, schema, properties); - auto rg_writer = file_writer->AppendRowGroup(); - - constexpr int32_t num_levels = 100; - const std::vector values(num_levels, 1024); - std::array def_levels; - std::array rep_levels; - for (int32_t i = 0; i < num_levels; i++) { - def_levels[i] = i % 2 == 0 ? 1 : 0; - rep_levels[i] = i % 2 == 0 ? 0 : 1; + auto fileWriter = ParquetFileWriter::open(sink, schema, properties); + auto rgWriter = fileWriter->appendRowGroup(); + + constexpr int32_t numLevels = 100; + const std::vector values(numLevels, 1024); + std::array defLevels; + std::array repLevels; + for (int32_t i = 0; i < numLevels; i++) { + defLevels[i] = i % 2 == 0 ? 1 : 0; + repLevels[i] = i % 2 == 0 ? 0 : 1; } - auto required_writer = static_cast(rg_writer->NextColumn()); - required_writer->WriteBatch(num_levels, nullptr, nullptr, values.data()); + auto requiredWriter = static_cast(rgWriter->nextColumn()); + requiredWriter->writeBatch(numLevels, nullptr, nullptr, values.data()); // Write a null value at every other row. - auto optional_writer = static_cast(rg_writer->NextColumn()); - optional_writer->WriteBatch( - num_levels, def_levels.data(), nullptr, values.data()); + auto optionalWriter = static_cast(rgWriter->nextColumn()); + optionalWriter->writeBatch( + numLevels, defLevels.data(), nullptr, values.data()); // Each row has repeated twice. - auto repeated_writer = static_cast(rg_writer->NextColumn()); - repeated_writer->WriteBatch( - num_levels, def_levels.data(), rep_levels.data(), values.data()); - repeated_writer->WriteBatch( - num_levels, def_levels.data(), rep_levels.data(), values.data()); + auto repeatedWriter = static_cast(rgWriter->nextColumn()); + repeatedWriter->writeBatch( + numLevels, defLevels.data(), repLevels.data(), values.data()); + repeatedWriter->writeBatch( + numLevels, defLevels.data(), repLevels.data(), values.data()); - ASSERT_NO_THROW(file_writer->Close()); + ASSERT_NO_THROW(fileWriter->close()); ASSERT_OK_AND_ASSIGN(auto buffer, sink->Finish()); - auto file_reader = ParquetFileReader::Open( + auto fileReader = ParquetFileReader::open( std::make_shared<::arrow::io::BufferReader>(buffer), - default_reader_properties()); - auto metadata = file_reader->metadata(); - ASSERT_EQ(1, metadata->num_row_groups()); - auto row_group_reader = file_reader->RowGroup(0); + defaultReaderProperties()); + auto metadata = fileReader->metadata(); + ASSERT_EQ(1, metadata->numRowGroups()); + auto rowGroupReader = fileReader->rowGroup(0); // Check if pages are changed on record boundaries. - constexpr int num_columns = 3; - const std::array expected_num_pages = {10, 10, 19}; - for (int i = 0; i < num_columns; ++i) { - auto page_reader = row_group_reader->GetColumnPageReader(i); - int64_t num_rows = 0; - int64_t num_pages = 0; + constexpr int numColumns = 3; + const std::array expectedNumPages = {10, 10, 19}; + for (int i = 0; i < numColumns; ++i) { + auto pageReader = rowGroupReader->getColumnPageReader(i); + int64_t numRows = 0; + int64_t numPages = 0; std::shared_ptr page; - while ((page = page_reader->NextPage()) != nullptr) { - auto data_page = std::static_pointer_cast(page); + while ((page = pageReader->nextPage()) != nullptr) { + auto dataPage = std::static_pointer_cast(page); if (i < 2) { - EXPECT_EQ(data_page->num_values(), data_page->num_rows()); + EXPECT_EQ(dataPage->numValues(), dataPage->numRows()); } else { - // Make sure repeated column has 2 values per row and not span multiple - // pages. - EXPECT_EQ(data_page->num_values(), 2 * data_page->num_rows()); + // Make sure repeated column has 2 values per row and not span multiple. + // Pages. + EXPECT_EQ(dataPage->numValues(), 2 * dataPage->numRows()); } - num_rows += data_page->num_rows(); - num_pages++; + numRows += dataPage->numRows(); + numPages++; } - EXPECT_EQ(num_levels, num_rows); - EXPECT_EQ(expected_num_pages[i], num_pages); + EXPECT_EQ(numLevels, numRows); + EXPECT_EQ(expectedNumPages[i], numPages); } } -// The test below checks that data page v2 changes on record boundaries for -// repeated columns with small batches. +// The test below checks that data page v2 changes on record boundaries for. +// Repeated columns with small batches. TEST(TestColumnWriter, WriteDataPagesChangeOnRecordBoundariesWithSmallBatches) { - auto sink = CreateOutputStream(); - auto schema = std::static_pointer_cast(GroupNode::Make( + auto sink = createOutputStream(); + auto schema = std::static_pointer_cast(GroupNode::make( "schema", - Repetition::REQUIRED, - {schema::Int32("tiny_repeat", Repetition::REPEATED), - schema::Int32("small_repeat", Repetition::REPEATED), - schema::Int32("medium_repeat", Repetition::REPEATED), - schema::Int32("large_repeat", Repetition::REPEATED)})); - - // The batch_size is large enough so each WriteBatch call checks page size at - // most once. - constexpr int64_t batch_size = std::numeric_limits::max(); + Repetition::kRequired, + {schema::int32("tiny_repeat", Repetition::kRepeated), + schema::int32("small_repeat", Repetition::kRepeated), + schema::int32("medium_repeat", Repetition::kRepeated), + schema::int32("large_repeat", Repetition::kRepeated)})); + + // The batch_size is large enough so each WriteBatch call checks page size at. + // Most once. + constexpr int64_t batchSize = std::numeric_limits::max(); auto properties = WriterProperties::Builder() - .disable_dictionary() - ->data_page_version(ParquetDataPageVersion::V2) - ->write_batch_size(batch_size) - ->data_pagesize(1) /* every page size check creates a new page */ + .disableDictionary() + ->dataPageVersion(ParquetDataPageVersion::V2) + ->writeBatchSize(batchSize) + ->dataPagesize(1) /* every page size check creates a new page */ ->build(); - auto file_writer = ParquetFileWriter::Open(sink, schema, properties); - auto rg_writer = file_writer->AppendRowGroup(); + auto fileWriter = ParquetFileWriter::open(sink, schema, properties); + auto rgWriter = fileWriter->appendRowGroup(); - constexpr int32_t num_cols = 4; - constexpr int64_t num_rows = 400; - constexpr int64_t num_levels = 100; - constexpr std::array num_levels_per_row_by_col = { + constexpr int32_t numCols = 4; + constexpr int64_t numRows = 400; + constexpr int64_t numLevels = 100; + constexpr std::array numLevelsPerRowByCol = { 1, 50, 99, 150}; // All values are not null and fixed to 1024 for simplicity. - const std::vector values(num_levels, 1024); - const std::vector def_levels(num_levels, 1); - std::vector rep_levels(num_levels, 0); - - for (int32_t i = 0; i < num_cols; ++i) { - auto writer = static_cast(rg_writer->NextColumn()); - const auto num_levels_per_row = num_levels_per_row_by_col[i]; - int64_t num_rows_written = 0; - int64_t num_levels_written_curr_row = 0; - while (num_rows_written < num_rows) { - int32_t num_levels_to_write = 0; - while (num_levels_to_write < num_levels) { - if (num_levels_written_curr_row == 0) { + const std::vector values(numLevels, 1024); + const std::vector defLevels(numLevels, 1); + std::vector repLevels(numLevels, 0); + + for (int32_t i = 0; i < numCols; ++i) { + auto writer = static_cast(rgWriter->nextColumn()); + const auto numLevelsPerRow = numLevelsPerRowByCol[i]; + int64_t numRowsWritten = 0; + int64_t numLevelsWrittenCurrRow = 0; + while (numRowsWritten < numRows) { + int32_t numLevelsToWrite = 0; + while (numLevelsToWrite < numLevels) { + if (numLevelsWrittenCurrRow == 0) { // A new record. - rep_levels[num_levels_to_write++] = 0; + repLevels[numLevelsToWrite++] = 0; } else { - rep_levels[num_levels_to_write++] = 1; + repLevels[numLevelsToWrite++] = 1; } - if (++num_levels_written_curr_row == num_levels_per_row) { + if (++numLevelsWrittenCurrRow == numLevelsPerRow) { // Current row has enough levels. - num_levels_written_curr_row = 0; - if (++num_rows_written == num_rows) { + numLevelsWrittenCurrRow = 0; + if (++numRowsWritten == numRows) { // Enough rows have been written. break; } } } - writer->WriteBatch( - num_levels_to_write, - def_levels.data(), - rep_levels.data(), - values.data()); + writer->writeBatch( + numLevelsToWrite, defLevels.data(), repLevels.data(), values.data()); } } - ASSERT_NO_THROW(file_writer->Close()); + ASSERT_NO_THROW(fileWriter->close()); ASSERT_OK_AND_ASSIGN(auto buffer, sink->Finish()); - auto file_reader = ParquetFileReader::Open( + auto fileReader = ParquetFileReader::open( std::make_shared<::arrow::io::BufferReader>(buffer), - default_reader_properties()); - auto metadata = file_reader->metadata(); - ASSERT_EQ(1, metadata->num_row_groups()); - auto row_group_reader = file_reader->RowGroup(0); + defaultReaderProperties()); + auto metadata = fileReader->metadata(); + ASSERT_EQ(1, metadata->numRowGroups()); + auto rowGroupReader = fileReader->rowGroup(0); // Check if pages are changed on record boundaries. - const std::array expect_num_pages_by_col = { - 5, 201, 397, 201}; - const std::array expect_num_rows_1st_page_by_col = { - 99, 1, 1, 1}; - const std::array expect_num_vals_1st_page_by_col = { + const std::array expectNumPagesByCol = {5, 201, 397, 201}; + const std::array expectNumRows1stPageByCol = {99, 1, 1, 1}; + const std::array expectNumVals1stPageByCol = { 99, 50, 99, 150}; - for (int32_t i = 0; i < num_cols; ++i) { - auto page_reader = row_group_reader->GetColumnPageReader(i); - int64_t num_rows_read = 0; - int64_t num_pages_read = 0; - int64_t num_values_read = 0; + for (int32_t i = 0; i < numCols; ++i) { + auto pageReader = rowGroupReader->getColumnPageReader(i); + int64_t numRowsRead = 0; + int64_t numPagesRead = 0; + int64_t numValuesRead = 0; std::shared_ptr page; - while ((page = page_reader->NextPage()) != nullptr) { - auto data_page = std::static_pointer_cast(page); - num_values_read += data_page->num_values(); - num_rows_read += data_page->num_rows(); - if (num_pages_read++ == 0) { - EXPECT_EQ(expect_num_rows_1st_page_by_col[i], data_page->num_rows()); - EXPECT_EQ(expect_num_vals_1st_page_by_col[i], data_page->num_values()); + while ((page = pageReader->nextPage()) != nullptr) { + auto dataPage = std::static_pointer_cast(page); + numValuesRead += dataPage->numValues(); + numRowsRead += dataPage->numRows(); + if (numPagesRead++ == 0) { + EXPECT_EQ(expectNumRows1stPageByCol[i], dataPage->numRows()); + EXPECT_EQ(expectNumVals1stPageByCol[i], dataPage->numValues()); } } - EXPECT_EQ(num_rows, num_rows_read); - EXPECT_EQ(expect_num_pages_by_col[i], num_pages_read); - EXPECT_EQ(num_levels_per_row_by_col[i] * num_rows, num_values_read); + EXPECT_EQ(numRows, numRowsRead); + EXPECT_EQ(expectNumPagesByCol[i], numPagesRead); + EXPECT_EQ(numLevelsPerRowByCol[i] * numRows, numValuesRead); } } class ColumnWriterTestSizeEstimated : public ::testing::Test { public: void SetUp() { - sink_ = CreateOutputStream(); - node_ = std::static_pointer_cast(GroupNode::Make( + sink_ = createOutputStream(); + node_ = std::static_pointer_cast(GroupNode::make( "schema", - Repetition::REQUIRED, + Repetition::kRequired, { - schema::Int32("required", Repetition::REQUIRED), + schema::int32("required", Repetition::kRequired), })); std::vector fields; - schema_descriptor_ = std::make_unique(); - schema_descriptor_->Init(node_); + schemaDescriptor_ = std::make_unique(); + schemaDescriptor_->init(node_); } - std::shared_ptr BuildWriter( + std::shared_ptr buildWriter( Compression::type compression, bool buffered, - bool enable_dictionary = false) { - auto builder = WriterProperties::Builder(); - builder.disable_dictionary() + bool enableDictionary = false) { + auto Builder = WriterProperties::Builder(); + Builder.disableDictionary() ->compression(compression) - ->data_pagesize(100 * sizeof(int)); - if (enable_dictionary) { - builder.enable_dictionary(); + ->dataPagesize(100 * sizeof(int)); + if (enableDictionary) { + Builder.enableDictionary(); } else { - builder.disable_dictionary(); + Builder.disableDictionary(); } - writer_properties_ = builder.build(); - metadata_ = ColumnChunkMetaDataBuilder::Make( - writer_properties_, schema_descriptor_->Column(0)); + writerProperties_ = Builder.build(); + metadata_ = ColumnChunkMetaDataBuilder::make( + writerProperties_, schemaDescriptor_->column(0)); - std::unique_ptr pager = PageWriter::Open( + std::unique_ptr pager = PageWriter::open( sink_, compression, - Codec::UseDefaultCompressionLevel(), + Codec::useDefaultCompressionLevel(), metadata_.get(), /* row_group_ordinal */ -1, /* column_chunk_ordinal*/ -1, @@ -1557,165 +1533,165 @@ class ColumnWriterTestSizeEstimated : public ::testing::Test { /* header_encryptor */ NULLPTR, /* data_encryptor */ NULLPTR, /* enable_checksum */ false); - return std::static_pointer_cast(ColumnWriter::Make( - metadata_.get(), std::move(pager), writer_properties_.get())); + return std::static_pointer_cast(ColumnWriter::make( + metadata_.get(), std::move(pager), writerProperties_.get())); } std::shared_ptr<::arrow::io::BufferOutputStream> sink_; std::shared_ptr node_; - std::unique_ptr schema_descriptor_; + std::unique_ptr schemaDescriptor_; - std::shared_ptr writer_properties_; + std::shared_ptr writerProperties_; std::unique_ptr metadata_; }; TEST_F(ColumnWriterTestSizeEstimated, NonBuffered) { - auto required_writer = - this->BuildWriter(Compression::UNCOMPRESSED, /* buffered*/ false); - // Write half page, page will not be flushed after loop + auto requiredWriter = + this->buildWriter(Compression::UNCOMPRESSED, /* buffered*/ false); + // Write half page, page will not be flushed after loop. for (int32_t i = 0; i < 50; i++) { - required_writer->WriteBatch(1, nullptr, nullptr, &i); + requiredWriter->writeBatch(1, nullptr, nullptr, &i); } - // Page not flushed, check size - EXPECT_EQ(0, required_writer->total_bytes_written()); - EXPECT_EQ(0, required_writer->total_compressed_bytes()); // unbuffered - EXPECT_EQ(0, required_writer->total_compressed_bytes_written()); - // Write half page, page be flushed after loop + // Page not flushed, check size. + EXPECT_EQ(0, requiredWriter->totalBytesWritten()); + EXPECT_EQ(0, requiredWriter->totalCompressedBytes()); // unbuffered + EXPECT_EQ(0, requiredWriter->totalCompressedBytesWritten()); + // Write half page, page be flushed after loop. for (int32_t i = 0; i < 50; i++) { - required_writer->WriteBatch(1, nullptr, nullptr, &i); + requiredWriter->writeBatch(1, nullptr, nullptr, &i); } - // Page flushed, check size - EXPECT_LT(400, required_writer->total_bytes_written()); - EXPECT_EQ(0, required_writer->total_compressed_bytes()); - EXPECT_LT(400, required_writer->total_compressed_bytes_written()); - - // Test after closed - int64_t written_size = required_writer->Close(); - EXPECT_EQ(0, required_writer->total_compressed_bytes()); - EXPECT_EQ(written_size, required_writer->total_bytes_written()); - // uncompressed writer should be equal - EXPECT_EQ(written_size, required_writer->total_compressed_bytes_written()); + // Page flushed, check size. + EXPECT_LT(400, requiredWriter->totalBytesWritten()); + EXPECT_EQ(0, requiredWriter->totalCompressedBytes()); + EXPECT_LT(400, requiredWriter->totalCompressedBytesWritten()); + + // Test after closed. + int64_t writtenSize = requiredWriter->close(); + EXPECT_EQ(0, requiredWriter->totalCompressedBytes()); + EXPECT_EQ(writtenSize, requiredWriter->totalBytesWritten()); + // Uncompressed writer should be equal. + EXPECT_EQ(writtenSize, requiredWriter->totalCompressedBytesWritten()); } TEST_F(ColumnWriterTestSizeEstimated, Buffered) { - auto required_writer = - this->BuildWriter(Compression::UNCOMPRESSED, /* buffered*/ true); - // Write half page, page will not be flushed after loop + auto requiredWriter = + this->buildWriter(Compression::UNCOMPRESSED, /* buffered*/ true); + // Write half page, page will not be flushed after loop. for (int32_t i = 0; i < 50; i++) { - required_writer->WriteBatch(1, nullptr, nullptr, &i); + requiredWriter->writeBatch(1, nullptr, nullptr, &i); } - // Page not flushed, check size - EXPECT_EQ(0, required_writer->total_bytes_written()); - EXPECT_EQ(0, required_writer->total_compressed_bytes()); // buffered - EXPECT_EQ(0, required_writer->total_compressed_bytes_written()); - // Write half page, page be flushed after loop + // Page not flushed, check size. + EXPECT_EQ(0, requiredWriter->totalBytesWritten()); + EXPECT_EQ(0, requiredWriter->totalCompressedBytes()); // buffered + EXPECT_EQ(0, requiredWriter->totalCompressedBytesWritten()); + // Write half page, page be flushed after loop. for (int32_t i = 0; i < 50; i++) { - required_writer->WriteBatch(1, nullptr, nullptr, &i); + requiredWriter->writeBatch(1, nullptr, nullptr, &i); } - // Page flushed, check size - EXPECT_LT(400, required_writer->total_bytes_written()); - EXPECT_EQ(0, required_writer->total_compressed_bytes()); - EXPECT_LT(400, required_writer->total_compressed_bytes_written()); - - // Test after closed - int64_t written_size = required_writer->Close(); - EXPECT_EQ(0, required_writer->total_compressed_bytes()); - EXPECT_EQ(written_size, required_writer->total_bytes_written()); - // uncompressed writer should be equal - EXPECT_EQ(written_size, required_writer->total_compressed_bytes_written()); + // Page flushed, check size. + EXPECT_LT(400, requiredWriter->totalBytesWritten()); + EXPECT_EQ(0, requiredWriter->totalCompressedBytes()); + EXPECT_LT(400, requiredWriter->totalCompressedBytesWritten()); + + // Test after closed. + int64_t writtenSize = requiredWriter->close(); + EXPECT_EQ(0, requiredWriter->totalCompressedBytes()); + EXPECT_EQ(writtenSize, requiredWriter->totalBytesWritten()); + // Uncompressed writer should be equal. + EXPECT_EQ(writtenSize, requiredWriter->totalCompressedBytesWritten()); } TEST_F(ColumnWriterTestSizeEstimated, NonBufferedDictionary) { - auto required_writer = - this->BuildWriter(Compression::UNCOMPRESSED, /* buffered*/ false, true); - // for dict, keep all values equal - int32_t dict_value = 1; + auto requiredWriter = + this->buildWriter(Compression::UNCOMPRESSED, /* buffered*/ false, true); + // For dict, keep all values equal. + int32_t dictValue = 1; for (int32_t i = 0; i < 50; i++) { - required_writer->WriteBatch(1, nullptr, nullptr, &dict_value); + requiredWriter->writeBatch(1, nullptr, nullptr, &dictValue); } - // Page not flushed, check size - EXPECT_EQ(0, required_writer->total_bytes_written()); - EXPECT_EQ(0, required_writer->total_compressed_bytes()); - EXPECT_EQ(0, required_writer->total_compressed_bytes_written()); - // write a huge batch to trigger page flush + // Page not flushed, check size. + EXPECT_EQ(0, requiredWriter->totalBytesWritten()); + EXPECT_EQ(0, requiredWriter->totalCompressedBytes()); + EXPECT_EQ(0, requiredWriter->totalCompressedBytesWritten()); + // Write a huge batch to trigger page flush. for (int32_t i = 0; i < 50000; i++) { - required_writer->WriteBatch(1, nullptr, nullptr, &dict_value); + requiredWriter->writeBatch(1, nullptr, nullptr, &dictValue); } - // Page flushed, check size - EXPECT_EQ(0, required_writer->total_bytes_written()); - EXPECT_LT(400, required_writer->total_compressed_bytes()); - EXPECT_EQ(0, required_writer->total_compressed_bytes_written()); - - required_writer->Close(); - - // Test after closed - int64_t written_size = required_writer->Close(); - EXPECT_EQ(0, required_writer->total_compressed_bytes()); - EXPECT_EQ(written_size, required_writer->total_bytes_written()); - // uncompressed writer should be equal - EXPECT_EQ(written_size, required_writer->total_compressed_bytes_written()); + // Page flushed, check size. + EXPECT_EQ(0, requiredWriter->totalBytesWritten()); + EXPECT_LT(400, requiredWriter->totalCompressedBytes()); + EXPECT_EQ(0, requiredWriter->totalCompressedBytesWritten()); + + requiredWriter->close(); + + // Test after closed. + int64_t writtenSize = requiredWriter->close(); + EXPECT_EQ(0, requiredWriter->totalCompressedBytes()); + EXPECT_EQ(writtenSize, requiredWriter->totalBytesWritten()); + // Uncompressed writer should be equal. + EXPECT_EQ(writtenSize, requiredWriter->totalCompressedBytesWritten()); } TEST_F(ColumnWriterTestSizeEstimated, BufferedCompression) { - auto required_writer = this->BuildWriter(Compression::SNAPPY, true); + auto requiredWriter = this->buildWriter(Compression::SNAPPY, true); - // Write half page + // Write half page. for (int32_t i = 0; i < 50; i++) { - required_writer->WriteBatch(1, nullptr, nullptr, &i); + requiredWriter->writeBatch(1, nullptr, nullptr, &i); } - // Page not flushed, check size - EXPECT_EQ(0, required_writer->total_bytes_written()); - EXPECT_EQ(0, required_writer->total_compressed_bytes()); // buffered - EXPECT_EQ(0, required_writer->total_compressed_bytes_written()); + // Page not flushed, check size. + EXPECT_EQ(0, requiredWriter->totalBytesWritten()); + EXPECT_EQ(0, requiredWriter->totalCompressedBytes()); // buffered + EXPECT_EQ(0, requiredWriter->totalCompressedBytesWritten()); for (int32_t i = 0; i < 50; i++) { - required_writer->WriteBatch(1, nullptr, nullptr, &i); + requiredWriter->writeBatch(1, nullptr, nullptr, &i); } - // Page flushed, check size - EXPECT_LT(400, required_writer->total_bytes_written()); - EXPECT_EQ(0, required_writer->total_compressed_bytes()); + // Page flushed, check size. + EXPECT_LT(400, requiredWriter->totalBytesWritten()); + EXPECT_EQ(0, requiredWriter->totalCompressedBytes()); EXPECT_LT( - required_writer->total_compressed_bytes_written(), - required_writer->total_bytes_written()); - - // Test after closed - int64_t written_size = required_writer->Close(); - EXPECT_EQ(0, required_writer->total_compressed_bytes()); - EXPECT_EQ(written_size, required_writer->total_bytes_written()); - EXPECT_GT(written_size, required_writer->total_compressed_bytes_written()); + requiredWriter->totalCompressedBytesWritten(), + requiredWriter->totalBytesWritten()); + + // Test after closed. + int64_t writtenSize = requiredWriter->close(); + EXPECT_EQ(0, requiredWriter->totalCompressedBytes()); + EXPECT_EQ(writtenSize, requiredWriter->totalBytesWritten()); + EXPECT_GT(writtenSize, requiredWriter->totalCompressedBytesWritten()); } TEST(TestColumnWriter, WriteDataPageV2HeaderNullCount) { - auto sink = CreateOutputStream(); - auto list_type = GroupNode::Make( + auto sink = createOutputStream(); + auto listType = GroupNode::make( "list", - Repetition::REPEATED, - {schema::Int32("elem", Repetition::OPTIONAL)}); - auto schema = std::static_pointer_cast(GroupNode::Make( + Repetition::kRepeated, + {schema::int32("elem", Repetition::kOptional)}); + auto schema = std::static_pointer_cast(GroupNode::make( "schema", - Repetition::REQUIRED, + Repetition::kRequired, { - schema::Int32("non_null", Repetition::OPTIONAL), - schema::Int32("half_null", Repetition::OPTIONAL), - schema::Int32("all_null", Repetition::OPTIONAL), - GroupNode::Make("half_null_list", Repetition::OPTIONAL, {list_type}), - GroupNode::Make("half_empty_list", Repetition::OPTIONAL, {list_type}), - GroupNode::Make( - "half_list_of_null", Repetition::OPTIONAL, {list_type}), - GroupNode::Make("all_single_list", Repetition::OPTIONAL, {list_type}), + schema::int32("non_null", Repetition::kOptional), + schema::int32("half_null", Repetition::kOptional), + schema::int32("all_null", Repetition::kOptional), + GroupNode::make("half_null_list", Repetition::kOptional, {listType}), + GroupNode::make("half_empty_list", Repetition::kOptional, {listType}), + GroupNode::make( + "half_list_of_null", Repetition::kOptional, {listType}), + GroupNode::make("all_single_list", Repetition::kOptional, {listType}), })); auto properties = WriterProperties::Builder() /* Use V2 data page to read null_count from header */ - .data_page_version(ParquetDataPageVersion::V2) + .dataPageVersion(ParquetDataPageVersion::V2) /* Disable stats to test null_count is properly set */ - ->disable_statistics() - ->disable_dictionary() + ->disableStatistics() + ->disableDictionary() ->build(); - auto file_writer = ParquetFileWriter::Open(sink, schema, properties); - auto rg_writer = file_writer->AppendRowGroup(); + auto fileWriter = ParquetFileWriter::open(sink, schema, properties); + auto rgWriter = fileWriter->appendRowGroup(); - constexpr int32_t num_rows = 10; - constexpr int32_t num_cols = 7; - const std::vector> def_levels_by_col = { + constexpr int32_t numRows = 10; + constexpr int32_t numCols = 7; + const std::vector> defLevelsByCol = { {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, {1, 0, 1, 0, 1, 0, 1, 0, 1, 0}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, @@ -1724,43 +1700,43 @@ TEST(TestColumnWriter, WriteDataPageV2HeaderNullCount) { {2, 3, 2, 3, 2, 3, 2, 3, 2, 3}, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, }; - const std::vector ref_levels(num_rows, 0); - const std::vector values(num_rows, 123); - const std::vector expect_null_count_by_col = {0, 5, 10, 5, 5, 5, 0}; - - for (int32_t i = 0; i < num_cols; ++i) { - auto writer = static_cast(rg_writer->NextColumn()); - writer->WriteBatch( - num_rows, - def_levels_by_col[i].data(), - i >= 3 ? ref_levels.data() : nullptr, + const std::vector refLevels(numRows, 0); + const std::vector values(numRows, 123); + const std::vector expectNullCountByCol = {0, 5, 10, 5, 5, 5, 0}; + + for (int32_t i = 0; i < numCols; ++i) { + auto writer = static_cast(rgWriter->nextColumn()); + writer->writeBatch( + numRows, + defLevelsByCol[i].data(), + i >= 3 ? refLevels.data() : nullptr, values.data()); } - ASSERT_NO_THROW(file_writer->Close()); + ASSERT_NO_THROW(fileWriter->close()); ASSERT_OK_AND_ASSIGN(auto buffer, sink->Finish()); - auto file_reader = ParquetFileReader::Open( + auto fileReader = ParquetFileReader::open( std::make_shared<::arrow::io::BufferReader>(buffer), - default_reader_properties()); - auto metadata = file_reader->metadata(); - ASSERT_EQ(1, metadata->num_row_groups()); - auto row_group_reader = file_reader->RowGroup(0); + defaultReaderProperties()); + auto metadata = fileReader->metadata(); + ASSERT_EQ(1, metadata->numRowGroups()); + auto rowGroupReader = fileReader->rowGroup(0); std::shared_ptr page; - for (int32_t i = 0; i < num_cols; ++i) { - auto page_reader = row_group_reader->GetColumnPageReader(i); - int64_t num_nulls_read = 0; - int64_t num_rows_read = 0; - int64_t num_values_read = 0; - while ((page = page_reader->NextPage()) != nullptr) { - auto data_page = std::static_pointer_cast(page); - num_nulls_read += data_page->num_nulls(); - num_rows_read += data_page->num_rows(); - num_values_read += data_page->num_values(); + for (int32_t i = 0; i < numCols; ++i) { + auto pageReader = rowGroupReader->getColumnPageReader(i); + int64_t numNullsRead = 0; + int64_t numRowsRead = 0; + int64_t numValuesRead = 0; + while ((page = pageReader->nextPage()) != nullptr) { + auto dataPage = std::static_pointer_cast(page); + numNullsRead += dataPage->numNulls(); + numRowsRead += dataPage->numRows(); + numValuesRead += dataPage->numValues(); } - EXPECT_EQ(expect_null_count_by_col[i], num_nulls_read); - EXPECT_EQ(num_rows, num_rows_read); - EXPECT_EQ(num_rows, num_values_read); + EXPECT_EQ(expectNullCountByCol[i], numNullsRead); + EXPECT_EQ(numRows, numRowsRead); + EXPECT_EQ(numRows, numValuesRead); } } diff --git a/velox/dwio/parquet/writer/arrow/tests/EncodingTest.cpp b/velox/dwio/parquet/writer/arrow/tests/EncodingTest.cpp index 92570dc139a..ab97f63702f 100644 --- a/velox/dwio/parquet/writer/arrow/tests/EncodingTest.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/EncodingTest.cpp @@ -68,8 +68,8 @@ namespace facebook::velox::parquet::arrow { namespace { constexpr int64_t kInMemoryDefaultCapacity = 1024; -// The Parquet spec isn't very clear whether ByteArray lengths are signed or -// unsigned, but the Java implementation uses signed ints. +// The Parquet spec isn't very clear whether ByteArray lengths are signed or. +// Unsigned, but the Java implementation uses signed ints. constexpr size_t kMaxByteArraySize = std::numeric_limits::max(); class EncoderImpl : virtual public Encoder { @@ -81,93 +81,91 @@ class EncoderImpl : virtual public Encoder { : descr_(descr), encoding_(encoding), pool_(pool), - type_length_(descr ? descr->type_length() : -1) {} + typeLength_(descr ? descr->typeLength() : -1) {} Encoding::type encoding() const override { return encoding_; } - MemoryPool* memory_pool() const override { + MemoryPool* memoryPool() const override { return pool_; } protected: - // For accessing type-specific metadata, like FIXED_LEN_BYTE_ARRAY + // For accessing type-specific metadata, like FIXED_LEN_BYTE_ARRAY. const ColumnDescriptor* descr_; const Encoding::type encoding_; MemoryPool* pool_; - /// Type length from descr - int type_length_; + /// Type length from descr. + int typeLength_; }; -// ---------------------------------------------------------------------- -// Plain encoder implementation +// ----------------------------------------------------------------------. +// Plain encoder implementation. template class PlainEncoder : public EncoderImpl, virtual public TypedEncoder { public: - using T = typename DType::c_type; + using T = typename DType::CType; explicit PlainEncoder(const ColumnDescriptor* descr, MemoryPool* pool) - : EncoderImpl(descr, Encoding::PLAIN, pool), sink_(pool) {} + : EncoderImpl(descr, Encoding::kPlain, pool), sink_(pool) {} - int64_t EstimatedDataEncodedSize() override { + int64_t estimatedDataEncodedSize() override { return sink_.length(); } - std::shared_ptr FlushValues() override { + std::shared_ptr flushValues() override { std::shared_ptr buffer; PARQUET_THROW_NOT_OK(sink_.Finish(&buffer)); return buffer; } - using TypedEncoder::Put; + using TypedEncoder::put; - void Put(const T* buffer, int num_values) override; + void put(const T* buffer, int numValues) override; - void Put(const ::arrow::Array& values) override; + void put(const ::arrow::Array& values) override; - void PutSpaced( + void putSpaced( const T* src, - int num_values, - const uint8_t* valid_bits, - int64_t valid_bits_offset) override { - if (valid_bits != NULLPTR) { - PARQUET_ASSIGN_OR_THROW( - auto buffer, - ::arrow::AllocateBuffer(num_values * sizeof(T), this->memory_pool())); + int numValues, + const uint8_t* validBits, + int64_t validBitsOffset) override { + if (validBits != NULLPTR) { + auto buffer = allocateBuffer(this->memoryPool(), numValues * sizeof(T)); T* data = reinterpret_cast(buffer->mutable_data()); - int num_valid_values = ::arrow::util::internal::SpacedCompress( - src, num_values, valid_bits, valid_bits_offset, data); - Put(data, num_valid_values); + int numValidValues = ::arrow::util::internal::SpacedCompress( + src, numValues, validBits, validBitsOffset, data); + put(data, numValidValues); } else { - Put(src, num_values); + put(src, numValues); } } - void UnsafePutByteArray(const void* data, uint32_t length) { + void unsafePutByteArray(const void* data, uint32_t length) { VELOX_DCHECK(length == 0 || data != nullptr, "Value ptr cannot be NULL"); sink_.UnsafeAppend(&length, sizeof(uint32_t)); sink_.UnsafeAppend(data, static_cast(length)); } - void Put(const ByteArray& val) { - // Write the result to the output stream + void put(const ByteArray& val) { + // Write the result to the output stream. const int64_t increment = static_cast(val.len + sizeof(uint32_t)); if (ARROW_PREDICT_FALSE(sink_.length() + increment > sink_.capacity())) { PARQUET_THROW_NOT_OK(sink_.Reserve(increment)); } - UnsafePutByteArray(val.ptr, val.len); + unsafePutByteArray(val.ptr, val.len); } protected: template - void PutBinaryArray(const ArrayType& array) { - const int64_t total_bytes = + void putBinaryArray(const ArrayType& array) { + const int64_t totalBytes = array.value_offset(array.length()) - array.value_offset(0); PARQUET_THROW_NOT_OK( - sink_.Reserve(total_bytes + array.length() * sizeof(uint32_t))); + sink_.Reserve(totalBytes + array.length() * sizeof(uint32_t))); PARQUET_THROW_NOT_OK( ::arrow::VisitArraySpanInline( @@ -177,7 +175,7 @@ class PlainEncoder : public EncoderImpl, virtual public TypedEncoder { return Status::Invalid( "Parquet cannot store strings with size 2GB or more"); } - UnsafePutByteArray( + unsafePutByteArray( view.data(), static_cast(view.size())); return Status::OK(); }, @@ -188,125 +186,124 @@ class PlainEncoder : public EncoderImpl, virtual public TypedEncoder { }; template -void PlainEncoder::Put(const T* buffer, int num_values) { - if (num_values > 0) { - PARQUET_THROW_NOT_OK(sink_.Append(buffer, num_values * sizeof(T))); +void PlainEncoder::put(const T* buffer, int numValues) { + if (numValues > 0) { + PARQUET_THROW_NOT_OK(sink_.Append(buffer, numValues * sizeof(T))); } } template <> -inline void PlainEncoder::Put( +inline void PlainEncoder::put( const ByteArray* src, - int num_values) { - for (int i = 0; i < num_values; ++i) { - Put(src[i]); + int numValues) { + for (int i = 0; i < numValues; ++i) { + put(src[i]); } } template -void DirectPutImpl(const ::arrow::Array& values, ::arrow::BufferBuilder* sink) { +void directPutImpl(const ::arrow::Array& values, ::arrow::BufferBuilder* sink) { if (values.type_id() != ArrayType::TypeClass::type_id) { - std::string type_name = ArrayType::TypeClass::type_name(); + std::string typeName = ArrayType::TypeClass::type_name(); throw ParquetException( - "direct put to " + type_name + " from " + values.type()->ToString() + + "direct put to " + typeName + " from " + values.type()->ToString() + " not supported"); } - using value_type = typename ArrayType::value_type; - constexpr auto value_size = sizeof(value_type); + using ValueType = typename ArrayType::value_type; + constexpr auto valueSize = sizeof(ValueType); auto raw_values = checked_cast(values).raw_values(); if (values.null_count() == 0) { - // no nulls, just dump the data - PARQUET_THROW_NOT_OK( - sink->Append(raw_values, values.length() * value_size)); + // No nulls, just dump the data. + PARQUET_THROW_NOT_OK(sink->Append(raw_values, values.length() * valueSize)); } else { PARQUET_THROW_NOT_OK( - sink->Reserve((values.length() - values.null_count()) * value_size)); + sink->Reserve((values.length() - values.null_count()) * valueSize)); for (int64_t i = 0; i < values.length(); i++) { if (values.IsValid(i)) { - sink->UnsafeAppend(&raw_values[i], value_size); + sink->UnsafeAppend(&raw_values[i], valueSize); } } } } template <> -void PlainEncoder::Put(const ::arrow::Array& values) { - DirectPutImpl<::arrow::Int32Array>(values, &sink_); +void PlainEncoder::put(const ::arrow::Array& values) { + directPutImpl<::arrow::Int32Array>(values, &sink_); } template <> -void PlainEncoder::Put(const ::arrow::Array& values) { - DirectPutImpl<::arrow::Int64Array>(values, &sink_); +void PlainEncoder::put(const ::arrow::Array& values) { + directPutImpl<::arrow::Int64Array>(values, &sink_); } template <> -void PlainEncoder::Put(const ::arrow::Array& values) { +void PlainEncoder::put(const ::arrow::Array& values) { ParquetException::NYI("direct put to Int96"); } template <> -void PlainEncoder::Put(const ::arrow::Array& values) { - DirectPutImpl<::arrow::FloatArray>(values, &sink_); +void PlainEncoder::put(const ::arrow::Array& values) { + directPutImpl<::arrow::FloatArray>(values, &sink_); } template <> -void PlainEncoder::Put(const ::arrow::Array& values) { - DirectPutImpl<::arrow::DoubleArray>(values, &sink_); +void PlainEncoder::put(const ::arrow::Array& values) { + directPutImpl<::arrow::DoubleArray>(values, &sink_); } template -void PlainEncoder::Put(const ::arrow::Array& values) { +void PlainEncoder::put(const ::arrow::Array& values) { ParquetException::NYI("direct put of " + values.type()->ToString()); } -void AssertBaseBinary(const ::arrow::Array& values) { +void assertBaseBinary(const ::arrow::Array& values) { if (!::arrow::is_base_binary_like(values.type_id())) { throw ParquetException("Only BaseBinaryArray and subclasses supported"); } } template <> -inline void PlainEncoder::Put(const ::arrow::Array& values) { - AssertBaseBinary(values); +inline void PlainEncoder::put(const ::arrow::Array& values) { + assertBaseBinary(values); if (::arrow::is_binary_like(values.type_id())) { - PutBinaryArray(checked_cast(values)); + putBinaryArray(checked_cast(values)); } else { VELOX_DCHECK(::arrow::is_large_binary_like(values.type_id())); - PutBinaryArray(checked_cast(values)); + putBinaryArray(checked_cast(values)); } } -void AssertFixedSizeBinary(const ::arrow::Array& values, int type_length) { +void assertFixedSizeBinary(const ::arrow::Array& values, int typeLength) { if (values.type_id() != ::arrow::Type::FIXED_SIZE_BINARY && - values.type_id() != ::arrow::Type::DECIMAL) { + values.type_id() != ::arrow::Type::DECIMAL128) { throw ParquetException( "Only FixedSizeBinaryArray and subclasses supported"); } if (checked_cast(*values.type()) - .byte_width() != type_length) { + .byte_width() != typeLength) { throw ParquetException( "Size mismatch: " + values.type()->ToString() + " should have been " + - std::to_string(type_length) + " wide"); + std::to_string(typeLength) + " wide"); } } template <> -inline void PlainEncoder::Put(const ::arrow::Array& values) { - AssertFixedSizeBinary(values, descr_->type_length()); +inline void PlainEncoder::put(const ::arrow::Array& values) { + assertFixedSizeBinary(values, descr_->typeLength()); const auto& data = checked_cast(values); if (data.null_count() == 0) { - // no nulls, just dump the data + // No nulls, just dump the data. PARQUET_THROW_NOT_OK( sink_.Append(data.raw_values(), data.length() * data.byte_width())); } else { - const int64_t total_bytes = data.length() * data.byte_width() - + const int64_t totalBytes = data.length() * data.byte_width() - data.null_count() * data.byte_width(); - PARQUET_THROW_NOT_OK(sink_.Reserve(total_bytes)); + PARQUET_THROW_NOT_OK(sink_.Reserve(totalBytes)); for (int64_t i = 0; i < data.length(); i++) { if (data.IsValid(i)) { sink_.UnsafeAppend(data.Value(i), data.byte_width()); @@ -316,16 +313,16 @@ inline void PlainEncoder::Put(const ::arrow::Array& values) { } template <> -inline void PlainEncoder::Put( +inline void PlainEncoder::put( const FixedLenByteArray* src, - int num_values) { - if (descr_->type_length() == 0) { + int numValues) { + if (descr_->typeLength() == 0) { return; } - for (int i = 0; i < num_values; ++i) { - // Write the result to the output stream + for (int i = 0; i < numValues; ++i) { + // Write the result to the output stream. VELOX_DCHECK_NOT_NULL(src[i].ptr, "Value ptr cannot be NULL"); - PARQUET_THROW_NOT_OK(sink_.Append(src[i].ptr, descr_->type_length())); + PARQUET_THROW_NOT_OK(sink_.Append(src[i].ptr, descr_->typeLength())); } } @@ -334,40 +331,38 @@ class PlainEncoder : public EncoderImpl, virtual public BooleanEncoder { public: explicit PlainEncoder(const ColumnDescriptor* descr, MemoryPool* pool) - : EncoderImpl(descr, Encoding::PLAIN, pool), - bits_available_(kInMemoryDefaultCapacity * 8), - bits_buffer_(AllocateBuffer(pool, kInMemoryDefaultCapacity)), + : EncoderImpl(descr, Encoding::kPlain, pool), + bitsAvailable_(kInMemoryDefaultCapacity * 8), + bitsBuffer_(allocateBuffer(pool, kInMemoryDefaultCapacity)), sink_(pool), - bit_writer_( - bits_buffer_->mutable_data(), - static_cast(bits_buffer_->size())) {} + bitWriter_( + bitsBuffer_->mutable_data(), + static_cast(bitsBuffer_->size())) {} - int64_t EstimatedDataEncodedSize() override; - std::shared_ptr FlushValues() override; + int64_t estimatedDataEncodedSize() override; + std::shared_ptr flushValues() override; - void Put(const bool* src, int num_values) override; + void put(const bool* src, int numValues) override; - void Put(const std::vector& src, int num_values) override; + void put(const std::vector& src, int numValues) override; - void PutSpaced( + void putSpaced( const bool* src, - int num_values, - const uint8_t* valid_bits, - int64_t valid_bits_offset) override { - if (valid_bits != NULLPTR) { - PARQUET_ASSIGN_OR_THROW( - auto buffer, - ::arrow::AllocateBuffer(num_values * sizeof(T), this->memory_pool())); + int numValues, + const uint8_t* validBits, + int64_t validBitsOffset) override { + if (validBits != NULLPTR) { + auto buffer = allocateBuffer(this->memoryPool(), numValues * sizeof(T)); T* data = reinterpret_cast(buffer->mutable_data()); - int num_valid_values = ::arrow::util::internal::SpacedCompress( - src, num_values, valid_bits, valid_bits_offset, data); - Put(data, num_valid_values); + int numValidValues = ::arrow::util::internal::SpacedCompress( + src, numValues, validBits, validBitsOffset, data); + put(data, numValidValues); } else { - Put(src, num_values); + put(src, numValues); } } - void Put(const ::arrow::Array& values) override { + void put(const ::arrow::Array& values) override { if (values.type_id() != ::arrow::Type::BOOL) { throw ParquetException( "direct put to boolean from " + values.type()->ToString() + @@ -378,19 +373,19 @@ class PlainEncoder : public EncoderImpl, if (data.null_count() == 0) { PARQUET_THROW_NOT_OK( sink_.Reserve(::arrow::bit_util::BytesForBits(data.length()))); - // no nulls, just dump the data + // No nulls, just dump the data. ::arrow::internal::CopyBitmap( - data.data()->GetValues(1), + data.data()->GetValues(1, 0), data.offset(), data.length(), sink_.mutable_data(), sink_.length()); } else { - auto n_valid = + auto nValid = ::arrow::bit_util::BytesForBits(data.length() - data.null_count()); - PARQUET_THROW_NOT_OK(sink_.Reserve(n_valid)); + PARQUET_THROW_NOT_OK(sink_.Reserve(nValid)); ::arrow::internal::FirstTimeBitmapWriter writer( - sink_.mutable_data(), sink_.length(), n_valid); + sink_.mutable_data(), sink_.length(), nValid); for (int64_t i = 0; i < data.length(); i++) { if (data.IsValid(i)) { @@ -408,69 +403,69 @@ class PlainEncoder : public EncoderImpl, } private: - int bits_available_; - std::shared_ptr bits_buffer_; + int bitsAvailable_; + std::shared_ptr bitsBuffer_; ::arrow::BufferBuilder sink_; - BitWriter bit_writer_; + BitWriter bitWriter_; template - void PutImpl(const SequenceType& src, int num_values); + void putImpl(const SequenceType& src, int numValues); }; template -void PlainEncoder::PutImpl( +void PlainEncoder::putImpl( const SequenceType& src, - int num_values) { - int bit_offset = 0; - if (bits_available_ > 0) { - int bits_to_write = std::min(bits_available_, num_values); - for (int i = 0; i < bits_to_write; i++) { - bit_writer_.PutValue(src[i], 1); - } - bits_available_ -= bits_to_write; - bit_offset = bits_to_write; - - if (bits_available_ == 0) { - bit_writer_.Flush(); + int numValues) { + int bitOffset = 0; + if (bitsAvailable_ > 0) { + int bitsToWrite = std::min(bitsAvailable_, numValues); + for (int i = 0; i < bitsToWrite; i++) { + bitWriter_.PutValue(src[i], 1); + } + bitsAvailable_ -= bitsToWrite; + bitOffset = bitsToWrite; + + if (bitsAvailable_ == 0) { + bitWriter_.Flush(); PARQUET_THROW_NOT_OK( - sink_.Append(bit_writer_.buffer(), bit_writer_.bytesWritten())); - bit_writer_.Clear(); + sink_.Append(bitWriter_.buffer(), bitWriter_.bytesWritten())); + bitWriter_.Clear(); } } - int bits_remaining = num_values - bit_offset; - while (bit_offset < num_values) { - bits_available_ = static_cast(bits_buffer_->size()) * 8; + int bitsRemaining = numValues - bitOffset; + while (bitOffset < numValues) { + bitsAvailable_ = static_cast(bitsBuffer_->size()) * 8; - int bits_to_write = std::min(bits_available_, bits_remaining); - for (int i = bit_offset; i < bit_offset + bits_to_write; i++) { - bit_writer_.PutValue(src[i], 1); + int bitsToWrite = std::min(bitsAvailable_, bitsRemaining); + for (int i = bitOffset; i < bitOffset + bitsToWrite; i++) { + bitWriter_.PutValue(src[i], 1); } - bit_offset += bits_to_write; - bits_available_ -= bits_to_write; - bits_remaining -= bits_to_write; + bitOffset += bitsToWrite; + bitsAvailable_ -= bitsToWrite; + bitsRemaining -= bitsToWrite; - if (bits_available_ == 0) { - bit_writer_.Flush(); + if (bitsAvailable_ == 0) { + bitWriter_.Flush(); PARQUET_THROW_NOT_OK( - sink_.Append(bit_writer_.buffer(), bit_writer_.bytesWritten())); - bit_writer_.Clear(); + sink_.Append(bitWriter_.buffer(), bitWriter_.bytesWritten())); + bitWriter_.Clear(); } } } -int64_t PlainEncoder::EstimatedDataEncodedSize() { +int64_t PlainEncoder::estimatedDataEncodedSize() { int64_t position = sink_.length(); - return position + bit_writer_.bytesWritten(); + return position + bitWriter_.bytesWritten(); } -std::shared_ptr PlainEncoder::FlushValues() { - if (bits_available_ > 0) { - bit_writer_.Flush(); +std::shared_ptr PlainEncoder::flushValues() { + if (bitsAvailable_ > 0) { + bitWriter_.Flush(); PARQUET_THROW_NOT_OK( - sink_.Append(bit_writer_.buffer(), bit_writer_.bytesWritten())); - bit_writer_.Clear(); - bits_available_ = static_cast(bits_buffer_->size()) * 8; + sink_.Append(bitWriter_.buffer(), bitWriter_.bytesWritten())); + bitWriter_.Clear(); + bitsAvailable_ = static_cast(bitsBuffer_->size()) * 8; } std::shared_ptr buffer; @@ -478,23 +473,23 @@ std::shared_ptr PlainEncoder::FlushValues() { return buffer; } -void PlainEncoder::Put(const bool* src, int num_values) { - PutImpl(src, num_values); +void PlainEncoder::put(const bool* src, int numValues) { + putImpl(src, numValues); } -void PlainEncoder::Put( +void PlainEncoder::put( const std::vector& src, - int num_values) { - PutImpl(src, num_values); + int numValues) { + putImpl(src, numValues); } -// ---------------------------------------------------------------------- -// DictEncoder implementations +// ----------------------------------------------------------------------. +// DictEncoder implementations. template struct DictEncoderTraits { - using c_type = typename DType::c_type; - using MemoTableType = ::arrow::internal::ScalarMemoTable; + using CType = typename DType::CType; + using MemoTableType = ::arrow::internal::ScalarMemoTable; }; template <> @@ -509,129 +504,129 @@ struct DictEncoderTraits { ::arrow::internal::BinaryMemoTable<::arrow::BinaryBuilder>; }; -// Initially 1024 elements +// Initially 1024 elements. static constexpr int32_t kInitialHashTableSize = 1 << 10; -int RlePreserveBufferSize(int num_values, int bit_width) { +int rlePreserveBufferSize(int numValues, int bitWidth) { // Note: because of the way RleEncoder::CheckBufferFull() - // is called, we have to reserve an extra "RleEncoder::MinBufferSize" - // bytes. These extra bytes won't be used but not reserving them - // would cause the encoder to fail. - return RleEncoder::MaxBufferSize(bit_width, num_values) + - RleEncoder::MinBufferSize(bit_width); + // Is called, we have to reserve an extra "RleEncoder::MinBufferSize". + // Bytes. These extra bytes won't be used but not reserving them. + // Would cause the encoder to fail. + return RleEncoder::MaxBufferSize(bitWidth, numValues) + + RleEncoder::MinBufferSize(bitWidth); } -/// See the dictionary encoding section of +/// See the dictionary encoding section of. /// https://github.com/Parquet/parquet-format. The encoding supports -/// streaming encoding. Values are encoded as they are added while the -/// dictionary is being constructed. At any time, the buffered values -/// can be written out with the current dictionary size. More values -/// can then be added to the encoder, including new dictionary -/// entries. +/// Streaming encoding. Values are encoded as they are added while the. +/// Dictionary is being constructed. At any time, the buffered values. +/// Can be written out with the current dictionary size. More values. +/// Can then be added to the encoder, including new dictionary. +/// Entries. template class DictEncoderImpl : public EncoderImpl, virtual public DictEncoder { using MemoTableType = typename DictEncoderTraits::MemoTableType; public: - typedef typename DType::c_type T; + typedef typename DType::CType T; - /// In data page, the bit width used to encode the entry - /// ids stored as 1 byte (max bit width = 32). + /// In data page, the bit width used to encode the entry. + /// Ids stored as 1 byte (max bit width = 32). constexpr static int32_t kDataPageBitWidthBytes = 1; explicit DictEncoderImpl(const ColumnDescriptor* desc, MemoryPool* pool) - : EncoderImpl(desc, Encoding::PLAIN_DICTIONARY, pool), - buffered_indices_(::arrow::stl::allocator(pool)), - dict_encoded_size_(0), - memo_table_(pool, kInitialHashTableSize) {} + : EncoderImpl(desc, Encoding::kPlainDictionary, pool), + bufferedIndices_(::arrow::stl::allocator(pool)), + dictEncodedSize_(0), + memoTable_(pool, kInitialHashTableSize) {} ~DictEncoderImpl() override { - VELOX_DCHECK(buffered_indices_.empty()); + VELOX_DCHECK(bufferedIndices_.empty()); } - int dict_encoded_size() const override { - return dict_encoded_size_; + int dictEncodedSize() const override { + return dictEncodedSize_; } - int WriteIndices(uint8_t* buffer, int buffer_len) override { - // Write bit width in first byte - *buffer = static_cast(bit_width()); + int writeIndices(uint8_t* buffer, int bufferLen) override { + // Write bit width in first byte. + *buffer = static_cast(bitWidth()); ++buffer; - --buffer_len; + --bufferLen; - RleEncoder encoder(buffer, buffer_len, bit_width()); + RleEncoder encoder(buffer, bufferLen, bitWidth()); - for (int32_t index : buffered_indices_) { + for (int32_t index : bufferedIndices_) { if (ARROW_PREDICT_FALSE(!encoder.Put(index))) return -1; } encoder.Flush(); - ClearIndices(); + clearIndices(); return kDataPageBitWidthBytes + encoder.len(); } - void set_type_length(int type_length) { - this->type_length_ = type_length; + void setTypeLength(int typeLength) { + this->typeLength_ = typeLength; } - /// Returns a conservative estimate of the number of bytes needed to encode - /// the buffered indices. Used to size the buffer passed to WriteIndices(). - int64_t EstimatedDataEncodedSize() override { + /// Returns a conservative estimate of the number of bytes needed to encode. + /// The buffered indices. Used to size the buffer passed to WriteIndices(). + int64_t estimatedDataEncodedSize() override { return kDataPageBitWidthBytes + - RlePreserveBufferSize( - static_cast(buffered_indices_.size()), bit_width()); + rlePreserveBufferSize( + static_cast(bufferedIndices_.size()), bitWidth()); } /// The minimum bit width required to encode the currently buffered indices. - int bit_width() const override { - if (ARROW_PREDICT_FALSE(num_entries() == 0)) + int bitWidth() const override { + if (ARROW_PREDICT_FALSE(numEntries() == 0)) return 0; - if (ARROW_PREDICT_FALSE(num_entries() == 1)) + if (ARROW_PREDICT_FALSE(numEntries() == 1)) return 1; - return ::arrow::bit_util::Log2(num_entries()); + return ::arrow::bit_util::Log2(numEntries()); } - /// Encode value. Note that this does not actually write any data, just - /// buffers the value's index to be written later. - inline void Put(const T& value); + /// Encode value. Note that this does not actually write any data, just. + /// Buffers the value's index to be written later. + inline void put(const T& value); - // Not implemented for other data types - inline void PutByteArray(const void* ptr, int32_t length); + // Not implemented for other data types. + inline void putByteArray(const void* ptr, int32_t length); - void Put(const T* src, int num_values) override { - for (int32_t i = 0; i < num_values; i++) { - Put(SafeLoad(src + i)); + void put(const T* src, int numValues) override { + for (int32_t i = 0; i < numValues; i++) { + put(SafeLoad(src + i)); } } - void PutSpaced( + void putSpaced( const T* src, - int num_values, - const uint8_t* valid_bits, - int64_t valid_bits_offset) override { + int numValues, + const uint8_t* validBits, + int64_t validBitsOffset) override { ::arrow::internal::VisitSetBitRunsVoid( - valid_bits, - valid_bits_offset, - num_values, + validBits, + validBitsOffset, + numValues, [&](int64_t position, int64_t length) { for (int64_t i = 0; i < length; i++) { - Put(SafeLoad(src + i + position)); + put(SafeLoad(src + i + position)); } }); } - using TypedEncoder::Put; + using TypedEncoder::put; - void Put(const ::arrow::Array& values) override; - void PutDictionary(const ::arrow::Array& values) override; + void put(const ::arrow::Array& values) override; + void putDictionary(const ::arrow::Array& values) override; template - void PutIndicesTyped(const ::arrow::Array& data) { + void putIndicesTyped(const ::arrow::Array& data) { auto values = data.data()->GetValues(1); - size_t buffer_position = buffered_indices_.size(); - buffered_indices_.resize( - buffer_position + + size_t bufferPosition = bufferedIndices_.size(); + bufferedIndices_.resize( + bufferPosition + static_cast(data.length() - data.null_count())); ::arrow::internal::VisitSetBitRunsVoid( data.null_bitmap_data(), @@ -639,60 +634,60 @@ class DictEncoderImpl : public EncoderImpl, virtual public DictEncoder { data.length(), [&](int64_t position, int64_t length) { for (int64_t i = 0; i < length; ++i) { - buffered_indices_[buffer_position++] = + bufferedIndices_[bufferPosition++] = static_cast(values[i + position]); } }); } - void PutIndices(const ::arrow::Array& data) override { + void putIndices(const ::arrow::Array& data) override { switch (data.type()->id()) { case ::arrow::Type::UINT8: case ::arrow::Type::INT8: - return PutIndicesTyped<::arrow::UInt8Type>(data); + return putIndicesTyped<::arrow::UInt8Type>(data); case ::arrow::Type::UINT16: case ::arrow::Type::INT16: - return PutIndicesTyped<::arrow::UInt16Type>(data); + return putIndicesTyped<::arrow::UInt16Type>(data); case ::arrow::Type::UINT32: case ::arrow::Type::INT32: - return PutIndicesTyped<::arrow::UInt32Type>(data); + return putIndicesTyped<::arrow::UInt32Type>(data); case ::arrow::Type::UINT64: case ::arrow::Type::INT64: - return PutIndicesTyped<::arrow::UInt64Type>(data); + return putIndicesTyped<::arrow::UInt64Type>(data); default: throw ParquetException("Passed non-integer array to PutIndices"); } } - std::shared_ptr FlushValues() override { + std::shared_ptr flushValues() override { std::shared_ptr buffer = - AllocateBuffer(this->pool_, EstimatedDataEncodedSize()); - int result_size = WriteIndices( - buffer->mutable_data(), static_cast(EstimatedDataEncodedSize())); - PARQUET_THROW_NOT_OK(buffer->Resize(result_size, false)); + allocateBuffer(this->pool_, estimatedDataEncodedSize()); + int resultSize = writeIndices( + buffer->mutable_data(), static_cast(estimatedDataEncodedSize())); + PARQUET_THROW_NOT_OK(buffer->Resize(resultSize, false)); return std::move(buffer); } - /// Writes out the encoded dictionary to buffer. buffer must be preallocated - /// to dict_encoded_size() bytes. - void WriteDict(uint8_t* buffer) const override; + /// Writes out the encoded dictionary to buffer. buffer must be preallocated. + /// To dict_encoded_size() bytes. + void writeDict(uint8_t* buffer) const override; /// The number of entries in the dictionary. - int num_entries() const override { - return memo_table_.size(); + int numEntries() const override { + return memoTable_.size(); } private: /// Clears all the indices (but leaves the dictionary). - void ClearIndices() { - buffered_indices_.clear(); + void clearIndices() { + bufferedIndices_.clear(); } /// Indices that have not yet be written out by WriteIndices(). - ArrowPoolVector buffered_indices_; + ArrowPoolVector bufferedIndices_; template - void PutBinaryArray(const ArrayType& array) { + void putBinaryArray(const ArrayType& array) { PARQUET_THROW_NOT_OK( ::arrow::VisitArraySpanInline( *array.data(), @@ -701,14 +696,14 @@ class DictEncoderImpl : public EncoderImpl, virtual public DictEncoder { return Status::Invalid( "Parquet cannot store strings with size 2GB or more"); } - PutByteArray(view.data(), static_cast(view.size())); + putByteArray(view.data(), static_cast(view.size())); return Status::OK(); }, []() { return Status::OK(); })); } template - void PutBinaryDictionaryArray(const ArrayType& array) { + void putBinaryDictionaryArray(const ArrayType& array) { VELOX_DCHECK_EQ(array.null_count(), 0); for (int64_t i = 0; i < array.length(); i++) { auto v = array.GetView(i); @@ -716,31 +711,31 @@ class DictEncoderImpl : public EncoderImpl, virtual public DictEncoder { throw ParquetException( "Parquet cannot store strings with size 2GB or more"); } - dict_encoded_size_ += static_cast(v.size() + sizeof(uint32_t)); - int32_t unused_memo_index; - PARQUET_THROW_NOT_OK(memo_table_.GetOrInsert( - v.data(), static_cast(v.size()), &unused_memo_index)); + dictEncodedSize_ += static_cast(v.size() + sizeof(uint32_t)); + int32_t unusedMemoIndex; + PARQUET_THROW_NOT_OK(memoTable_.GetOrInsert( + v.data(), static_cast(v.size()), &unusedMemoIndex)); } } /// The number of bytes needed to encode the dictionary. - int dict_encoded_size_; + int dictEncodedSize_; - MemoTableType memo_table_; + MemoTableType memoTable_; }; template -void DictEncoderImpl::WriteDict(uint8_t* buffer) const { - // For primitive types, only a memcpy +void DictEncoderImpl::writeDict(uint8_t* buffer) const { + // For primitive types, only a memcpy. VELOX_DCHECK_EQ( - static_cast(dict_encoded_size_), sizeof(T) * memo_table_.size()); - memo_table_.CopyValues(0 /* start_pos */, reinterpret_cast(buffer)); + static_cast(dictEncodedSize_), sizeof(T) * memoTable_.size()); + memoTable_.CopyValues(0 /* start_pos */, reinterpret_cast(buffer)); } -// ByteArray and FLBA already have the dictionary encoded in their data heaps +// ByteArray and FLBA already have the dictionary encoded in their data heaps. template <> -void DictEncoderImpl::WriteDict(uint8_t* buffer) const { - memo_table_.VisitValues(0, [&buffer](::std::string_view v) { +void DictEncoderImpl::writeDict(uint8_t* buffer) const { + memoTable_.VisitValues(0, [&buffer](::std::string_view v) { uint32_t len = static_cast(v.length()); memcpy(buffer, &len, sizeof(len)); buffer += sizeof(len); @@ -750,231 +745,231 @@ void DictEncoderImpl::WriteDict(uint8_t* buffer) const { } template <> -void DictEncoderImpl::WriteDict(uint8_t* buffer) const { - memo_table_.VisitValues(0, [&](::std::string_view v) { - VELOX_DCHECK_EQ(v.length(), static_cast(type_length_)); - memcpy(buffer, v.data(), type_length_); - buffer += type_length_; +void DictEncoderImpl::writeDict(uint8_t* buffer) const { + memoTable_.VisitValues(0, [&](::std::string_view v) { + VELOX_DCHECK_EQ(v.length(), static_cast(typeLength_)); + memcpy(buffer, v.data(), typeLength_); + buffer += typeLength_; }); } template -inline void DictEncoderImpl::Put(const T& v) { - // Put() implementation for primitive types - auto on_found = [](int32_t memo_index) {}; - auto on_not_found = [this](int32_t memo_index) { - dict_encoded_size_ += static_cast(sizeof(T)); +inline void DictEncoderImpl::put(const T& v) { + // Put() implementation for primitive types. + auto onFound = [](int32_t memoIndex) {}; + auto onNotFound = [this](int32_t memoIndex) { + dictEncodedSize_ += static_cast(sizeof(T)); }; - int32_t memo_index; + int32_t memoIndex; PARQUET_THROW_NOT_OK( - memo_table_.GetOrInsert(v, on_found, on_not_found, &memo_index)); - buffered_indices_.push_back(memo_index); + memoTable_.GetOrInsert(v, onFound, onNotFound, &memoIndex)); + bufferedIndices_.push_back(memoIndex); } template -inline void DictEncoderImpl::PutByteArray( +inline void DictEncoderImpl::putByteArray( const void* ptr, int32_t length) { VELOX_DCHECK(false); } template <> -inline void DictEncoderImpl::PutByteArray( +inline void DictEncoderImpl::putByteArray( const void* ptr, int32_t length) { static const uint8_t empty[] = {0}; - auto on_found = [](int32_t memo_index) {}; - auto on_not_found = [&](int32_t memo_index) { - dict_encoded_size_ += static_cast(length + sizeof(uint32_t)); + auto onFound = [](int32_t memoIndex) {}; + auto onNotFound = [&](int32_t memoIndex) { + dictEncodedSize_ += static_cast(length + sizeof(uint32_t)); }; VELOX_DCHECK(ptr != nullptr || length == 0); ptr = (ptr != nullptr) ? ptr : empty; - int32_t memo_index; - PARQUET_THROW_NOT_OK(memo_table_.GetOrInsert( - ptr, length, on_found, on_not_found, &memo_index)); - buffered_indices_.push_back(memo_index); + int32_t memoIndex; + PARQUET_THROW_NOT_OK( + memoTable_.GetOrInsert(ptr, length, onFound, onNotFound, &memoIndex)); + bufferedIndices_.push_back(memoIndex); } template <> -inline void DictEncoderImpl::Put(const ByteArray& val) { - return PutByteArray(val.ptr, static_cast(val.len)); +inline void DictEncoderImpl::put(const ByteArray& val) { + return putByteArray(val.ptr, static_cast(val.len)); } template <> -inline void DictEncoderImpl::Put(const FixedLenByteArray& v) { +inline void DictEncoderImpl::put(const FixedLenByteArray& v) { static const uint8_t empty[] = {0}; - auto on_found = [](int32_t memo_index) {}; - auto on_not_found = [this](int32_t memo_index) { - dict_encoded_size_ += type_length_; + auto onFound = [](int32_t memoIndex) {}; + auto onNotFound = [this](int32_t memoIndex) { + dictEncodedSize_ += typeLength_; }; - VELOX_DCHECK(v.ptr != nullptr || type_length_ == 0); + VELOX_DCHECK(v.ptr != nullptr || typeLength_ == 0); const void* ptr = (v.ptr != nullptr) ? v.ptr : empty; - int32_t memo_index; - PARQUET_THROW_NOT_OK(memo_table_.GetOrInsert( - ptr, type_length_, on_found, on_not_found, &memo_index)); - buffered_indices_.push_back(memo_index); + int32_t memoIndex; + PARQUET_THROW_NOT_OK(memoTable_.GetOrInsert( + ptr, typeLength_, onFound, onNotFound, &memoIndex)); + bufferedIndices_.push_back(memoIndex); } template <> -void DictEncoderImpl::Put(const ::arrow::Array& values) { +void DictEncoderImpl::put(const ::arrow::Array& values) { ParquetException::NYI("Direct put to Int96"); } template <> -void DictEncoderImpl::PutDictionary(const ::arrow::Array& values) { +void DictEncoderImpl::putDictionary(const ::arrow::Array& values) { ParquetException::NYI("Direct put to Int96"); } template -void DictEncoderImpl::Put(const ::arrow::Array& values) { +void DictEncoderImpl::put(const ::arrow::Array& values) { using ArrayType = - typename ::arrow::CTypeTraits::ArrayType; + typename ::arrow::CTypeTraits::ArrayType; const auto& data = checked_cast(values); if (data.null_count() == 0) { - // no nulls, just dump the data + // No nulls, just dump the data. for (int64_t i = 0; i < data.length(); i++) { - Put(data.Value(i)); + put(data.Value(i)); } } else { for (int64_t i = 0; i < data.length(); i++) { if (data.IsValid(i)) { - Put(data.Value(i)); + put(data.Value(i)); } } } } template <> -void DictEncoderImpl::Put(const ::arrow::Array& values) { - AssertFixedSizeBinary(values, type_length_); +void DictEncoderImpl::put(const ::arrow::Array& values) { + assertFixedSizeBinary(values, typeLength_); const auto& data = checked_cast(values); if (data.null_count() == 0) { - // no nulls, just dump the data + // No nulls, just dump the data. for (int64_t i = 0; i < data.length(); i++) { - Put(FixedLenByteArray(data.Value(i))); + put(FixedLenByteArray(data.Value(i))); } } else { - std::vector empty(type_length_, 0); + std::vector empty(typeLength_, 0); for (int64_t i = 0; i < data.length(); i++) { if (data.IsValid(i)) { - Put(FixedLenByteArray(data.Value(i))); + put(FixedLenByteArray(data.Value(i))); } } } } template <> -void DictEncoderImpl::Put(const ::arrow::Array& values) { - AssertBaseBinary(values); +void DictEncoderImpl::put(const ::arrow::Array& values) { + assertBaseBinary(values); if (::arrow::is_binary_like(values.type_id())) { - PutBinaryArray(checked_cast(values)); + putBinaryArray(checked_cast(values)); } else { VELOX_DCHECK(::arrow::is_large_binary_like(values.type_id())); - PutBinaryArray(checked_cast(values)); + putBinaryArray(checked_cast(values)); } } template -void AssertCanPutDictionary( +void assertCanPutDictionary( DictEncoderImpl* encoder, const ::arrow::Array& dict) { if (dict.null_count() > 0) { throw ParquetException("Inserted dictionary cannot cannot contain nulls"); } - if (encoder->num_entries() > 0) { + if (encoder->numEntries() > 0) { throw ParquetException( "Can only call PutDictionary on an empty DictEncoder"); } } template -void DictEncoderImpl::PutDictionary(const ::arrow::Array& values) { - AssertCanPutDictionary(this, values); +void DictEncoderImpl::putDictionary(const ::arrow::Array& values) { + assertCanPutDictionary(this, values); using ArrayType = - typename ::arrow::CTypeTraits::ArrayType; + typename ::arrow::CTypeTraits::ArrayType; const auto& data = checked_cast(values); - dict_encoded_size_ += - static_cast(sizeof(typename DType::c_type) * data.length()); + dictEncodedSize_ += + static_cast(sizeof(typename DType::CType) * data.length()); for (int64_t i = 0; i < data.length(); i++) { - int32_t unused_memo_index; + int32_t unusedMemoIndex; PARQUET_THROW_NOT_OK( - memo_table_.GetOrInsert(data.Value(i), &unused_memo_index)); + memoTable_.GetOrInsert(data.Value(i), &unusedMemoIndex)); } } template <> -void DictEncoderImpl::PutDictionary(const ::arrow::Array& values) { - AssertFixedSizeBinary(values, type_length_); - AssertCanPutDictionary(this, values); +void DictEncoderImpl::putDictionary(const ::arrow::Array& values) { + assertFixedSizeBinary(values, typeLength_); + assertCanPutDictionary(this, values); const auto& data = checked_cast(values); - dict_encoded_size_ += static_cast(type_length_ * data.length()); + dictEncodedSize_ += static_cast(typeLength_ * data.length()); for (int64_t i = 0; i < data.length(); i++) { - int32_t unused_memo_index; - PARQUET_THROW_NOT_OK(memo_table_.GetOrInsert( - data.Value(i), type_length_, &unused_memo_index)); + int32_t unusedMemoIndex; + PARQUET_THROW_NOT_OK( + memoTable_.GetOrInsert(data.Value(i), typeLength_, &unusedMemoIndex)); } } template <> -void DictEncoderImpl::PutDictionary( +void DictEncoderImpl::putDictionary( const ::arrow::Array& values) { - AssertBaseBinary(values); - AssertCanPutDictionary(this, values); + assertBaseBinary(values); + assertCanPutDictionary(this, values); if (::arrow::is_binary_like(values.type_id())) { - PutBinaryDictionaryArray(checked_cast(values)); + putBinaryDictionaryArray(checked_cast(values)); } else { VELOX_DCHECK(::arrow::is_large_binary_like(values.type_id())); - PutBinaryDictionaryArray( + putBinaryDictionaryArray( checked_cast(values)); } } -// ---------------------------------------------------------------------- -// ByteStreamSplitEncoder implementations +// ----------------------------------------------------------------------. +// ByteStreamSplitEncoder implementations. template class ByteStreamSplitEncoder : public EncoderImpl, virtual public TypedEncoder { public: - using T = typename DType::c_type; - using TypedEncoder::Put; + using T = typename DType::CType; + using TypedEncoder::put; explicit ByteStreamSplitEncoder( const ColumnDescriptor* descr, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); - int64_t EstimatedDataEncodedSize() override; - std::shared_ptr FlushValues() override; + int64_t estimatedDataEncodedSize() override; + std::shared_ptr flushValues() override; - void Put(const T* buffer, int num_values) override; - void Put(const ::arrow::Array& values) override; - void PutSpaced( + void put(const T* buffer, int numValues) override; + void put(const ::arrow::Array& values) override; + void putSpaced( const T* src, - int num_values, - const uint8_t* valid_bits, - int64_t valid_bits_offset) override; + int numValues, + const uint8_t* validBits, + int64_t validBitsOffset) override; protected: template - void PutImpl(const ::arrow::Array& values) { + void putImpl(const ::arrow::Array& values) { if (values.type_id() != ArrowType::type_id) { throw ParquetException( std::string() + "direct put to " + ArrowType::type_name() + " from " + values.type()->ToString() + " not supported"); } const auto& data = *values.data(); - PutSpaced( + putSpaced( data.GetValues(1), static_cast(data.length), data.GetValues(0, 0), @@ -982,82 +977,79 @@ class ByteStreamSplitEncoder : public EncoderImpl, } ::arrow::BufferBuilder sink_; - int64_t num_values_in_buffer_; + int64_t numValuesInBuffer_; }; template ByteStreamSplitEncoder::ByteStreamSplitEncoder( const ColumnDescriptor* descr, ::arrow::MemoryPool* pool) - : EncoderImpl(descr, Encoding::BYTE_STREAM_SPLIT, pool), + : EncoderImpl(descr, Encoding::kByteStreamSplit, pool), sink_{pool}, - num_values_in_buffer_{0} {} + numValuesInBuffer_{0} {} template -int64_t ByteStreamSplitEncoder::EstimatedDataEncodedSize() { +int64_t ByteStreamSplitEncoder::estimatedDataEncodedSize() { return sink_.length(); } template -std::shared_ptr ByteStreamSplitEncoder::FlushValues() { - std::shared_ptr output_buffer = - AllocateBuffer(this->memory_pool(), EstimatedDataEncodedSize()); - uint8_t* output_buffer_raw = output_buffer->mutable_data(); +std::shared_ptr ByteStreamSplitEncoder::flushValues() { + std::shared_ptr outputBuffer = + allocateBuffer(this->memoryPool(), estimatedDataEncodedSize()); + uint8_t* outputBufferRaw = outputBuffer->mutable_data(); const uint8_t* raw_values = sink_.data(); - ByteStreamSplitEncode( - raw_values, num_values_in_buffer_, output_buffer_raw); + byteStreamSplitEncode(raw_values, numValuesInBuffer_, outputBufferRaw); sink_.Reset(); - num_values_in_buffer_ = 0; - return std::move(output_buffer); + numValuesInBuffer_ = 0; + return std::move(outputBuffer); } template -void ByteStreamSplitEncoder::Put(const T* buffer, int num_values) { - if (num_values > 0) { - PARQUET_THROW_NOT_OK(sink_.Append(buffer, num_values * sizeof(T))); - num_values_in_buffer_ += num_values; +void ByteStreamSplitEncoder::put(const T* buffer, int numValues) { + if (numValues > 0) { + PARQUET_THROW_NOT_OK(sink_.Append(buffer, numValues * sizeof(T))); + numValuesInBuffer_ += numValues; } } template <> -void ByteStreamSplitEncoder::Put(const ::arrow::Array& values) { - PutImpl<::arrow::FloatType>(values); +void ByteStreamSplitEncoder::put(const ::arrow::Array& values) { + putImpl<::arrow::FloatType>(values); } template <> -void ByteStreamSplitEncoder::Put(const ::arrow::Array& values) { - PutImpl<::arrow::DoubleType>(values); +void ByteStreamSplitEncoder::put(const ::arrow::Array& values) { + putImpl<::arrow::DoubleType>(values); } template -void ByteStreamSplitEncoder::PutSpaced( +void ByteStreamSplitEncoder::putSpaced( const T* src, - int num_values, - const uint8_t* valid_bits, - int64_t valid_bits_offset) { - if (valid_bits != NULLPTR) { - PARQUET_ASSIGN_OR_THROW( - auto buffer, - ::arrow::AllocateBuffer(num_values * sizeof(T), this->memory_pool())); + int numValues, + const uint8_t* validBits, + int64_t validBitsOffset) { + if (validBits != NULLPTR) { + auto buffer = allocateBuffer(this->memoryPool(), numValues * sizeof(T)); T* data = reinterpret_cast(buffer->mutable_data()); - int num_valid_values = ::arrow::util::internal::SpacedCompress( - src, num_values, valid_bits, valid_bits_offset, data); - Put(data, num_valid_values); + int numValidValues = ::arrow::util::internal::SpacedCompress( + src, numValues, validBits, validBitsOffset, data); + put(data, numValidValues); } else { - Put(src, num_values); + put(src, numValues); } } class DecoderImpl : virtual public Decoder { public: - void SetData(int num_values, const uint8_t* data, int len) override { - num_values_ = num_values; + void setData(int numValues, const uint8_t* data, int len) override { + numValues_ = numValues; data_ = data; len_ = len; } - int values_left() const override { - return num_values_; + int valuesLeft() const override { + return numValues_; } Encoding::type encoding() const override { return encoding_; @@ -1067,553 +1059,545 @@ class DecoderImpl : virtual public Decoder { explicit DecoderImpl(const ColumnDescriptor* descr, Encoding::type encoding) : descr_(descr), encoding_(encoding), - num_values_(0), + numValues_(0), data_(NULLPTR), len_(0) {} - // For accessing type-specific metadata, like FIXED_LEN_BYTE_ARRAY + // For accessing type-specific metadata, like FIXED_LEN_BYTE_ARRAY. const ColumnDescriptor* descr_; const Encoding::type encoding_; - int num_values_; + int numValues_; const uint8_t* data_; int len_; - int type_length_; + int typeLength_; }; template class PlainDecoder : public DecoderImpl, virtual public TypedDecoder { public: - using T = typename DType::c_type; + using T = typename DType::CType; explicit PlainDecoder(const ColumnDescriptor* descr); - int Decode(T* buffer, int max_values) override; - - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::Accumulator* builder) override; - - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) override; + int decode(T* buffer, int maxValues) override; + + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::Accumulator* Builder) override; + + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) override; }; template <> -inline int PlainDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::Accumulator* builder) { +inline int PlainDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::Accumulator* Builder) { ParquetException::NYI("DecodeArrow not supported for Int96"); } template <> -inline int PlainDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) { +inline int PlainDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) { ParquetException::NYI("DecodeArrow not supported for Int96"); } template <> -inline int PlainDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) { +inline int PlainDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) { ParquetException::NYI("dictionaries of BooleanType"); } template -int PlainDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::Accumulator* builder) { - using value_type = typename DType::c_type; +int PlainDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::Accumulator* Builder) { + using ValueType = typename DType::CType; - constexpr int value_size = static_cast(sizeof(value_type)); - int values_decoded = num_values - null_count; - if (ARROW_PREDICT_FALSE(len_ < value_size * values_decoded)) { - ParquetException::EofException(); + constexpr int valueSize = static_cast(sizeof(ValueType)); + int valuesDecoded = numValues - nullCount; + if (ARROW_PREDICT_FALSE(len_ < valueSize * valuesDecoded)) { + ParquetException::eofException(); } - PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); + PARQUET_THROW_NOT_OK(Builder->Reserve(numValues)); VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { - builder->UnsafeAppend(SafeLoadAs(data_)); - data_ += sizeof(value_type); + Builder->UnsafeAppend(SafeLoadAs(data_)); + data_ += sizeof(ValueType); }, - [&]() { builder->UnsafeAppendNull(); }); + [&]() { Builder->UnsafeAppendNull(); }); - num_values_ -= values_decoded; - len_ -= sizeof(value_type) * values_decoded; - return values_decoded; + numValues_ -= valuesDecoded; + len_ -= sizeof(ValueType) * valuesDecoded; + return valuesDecoded; } template -int PlainDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) { - using value_type = typename DType::c_type; +int PlainDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) { + using ValueType = typename DType::CType; - constexpr int value_size = static_cast(sizeof(value_type)); - int values_decoded = num_values - null_count; - if (ARROW_PREDICT_FALSE(len_ < value_size * values_decoded)) { - ParquetException::EofException(); + constexpr int valueSize = static_cast(sizeof(ValueType)); + int valuesDecoded = numValues - nullCount; + if (ARROW_PREDICT_FALSE(len_ < valueSize * valuesDecoded)) { + ParquetException::eofException(); } - PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); + PARQUET_THROW_NOT_OK(Builder->Reserve(numValues)); VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { - PARQUET_THROW_NOT_OK(builder->Append(SafeLoadAs(data_))); - data_ += sizeof(value_type); + PARQUET_THROW_NOT_OK(Builder->Append(SafeLoadAs(data_))); + data_ += sizeof(ValueType); }, - [&]() { PARQUET_THROW_NOT_OK(builder->AppendNull()); }); + [&]() { PARQUET_THROW_NOT_OK(Builder->AppendNull()); }); - num_values_ -= values_decoded; - len_ -= sizeof(value_type) * values_decoded; - return values_decoded; + numValues_ -= valuesDecoded; + len_ -= sizeof(ValueType) * valuesDecoded; + return valuesDecoded; } -// Decode routine templated on C++ type rather than type enum +// Decode routine templated on C++ type rather than type enum. template -inline int DecodePlain( +inline int decodePlain( const uint8_t* data, - int64_t data_size, - int num_values, - int type_length, + int64_t dataSize, + int numValues, + int typeLength, T* out) { - int64_t bytes_to_decode = num_values * static_cast(sizeof(T)); - if (bytes_to_decode > data_size || bytes_to_decode > INT_MAX) { - ParquetException::EofException(); + int64_t bytesToDecode = numValues * static_cast(sizeof(T)); + if (bytesToDecode > dataSize || bytesToDecode > INT_MAX) { + ParquetException::eofException(); } - // If bytes_to_decode == 0, data could be null - if (bytes_to_decode > 0) { - memcpy(out, data, bytes_to_decode); + // If bytes_to_decode == 0, data could be null. + if (bytesToDecode > 0) { + memcpy(out, data, bytesToDecode); } - return static_cast(bytes_to_decode); + return static_cast(bytesToDecode); } template PlainDecoder::PlainDecoder(const ColumnDescriptor* descr) - : DecoderImpl(descr, Encoding::PLAIN) { - if (descr_ && descr_->physical_type() == Type::FIXED_LEN_BYTE_ARRAY) { - type_length_ = descr_->type_length(); + : DecoderImpl(descr, Encoding::kPlain) { + if (descr_ && descr_->physicalType() == Type::kFixedLenByteArray) { + typeLength_ = descr_->typeLength(); } else { - type_length_ = -1; + typeLength_ = -1; } } -// Template specialization for BYTE_ARRAY. The written values do not own their -// own data. +// Template specialization for BYTE_ARRAY. The written values do not own their. +// Own data. static inline int64_t -ReadByteArray(const uint8_t* data, int64_t data_size, ByteArray* out) { - if (ARROW_PREDICT_FALSE(data_size < 4)) { - ParquetException::EofException(); +readByteArray(const uint8_t* data, int64_t dataSize, ByteArray* out) { + if (ARROW_PREDICT_FALSE(dataSize < 4)) { + ParquetException::eofException(); } const int32_t len = SafeLoadAs(data); if (len < 0) { throw ParquetException("Invalid BYTE_ARRAY value"); } - const int64_t consumed_length = static_cast(len) + 4; - if (ARROW_PREDICT_FALSE(data_size < consumed_length)) { - ParquetException::EofException(); + const int64_t consumedLength = static_cast(len) + 4; + if (ARROW_PREDICT_FALSE(dataSize < consumedLength)) { + ParquetException::eofException(); } *out = ByteArray{static_cast(len), data + 4}; - return consumed_length; + return consumedLength; } template <> -inline int DecodePlain( +inline int decodePlain( const uint8_t* data, - int64_t data_size, - int num_values, - int type_length, + int64_t dataSize, + int numValues, + int typeLength, ByteArray* out) { - int bytes_decoded = 0; - for (int i = 0; i < num_values; ++i) { - const auto increment = ReadByteArray(data, data_size, out + i); - if (ARROW_PREDICT_FALSE(increment > INT_MAX - bytes_decoded)) { + int bytesDecoded = 0; + for (int i = 0; i < numValues; ++i) { + const auto increment = readByteArray(data, dataSize, out + i); + if (ARROW_PREDICT_FALSE(increment > INT_MAX - bytesDecoded)) { throw ParquetException("BYTE_ARRAY chunk too large"); } data += increment; - data_size -= increment; - bytes_decoded += static_cast(increment); + dataSize -= increment; + bytesDecoded += static_cast(increment); } - return bytes_decoded; + return bytesDecoded; } -// Template specialization for FIXED_LEN_BYTE_ARRAY. The written values do not -// own their own data. +// Template specialization for FIXED_LEN_BYTE_ARRAY. The written values do not. +// Own their own data. template <> -inline int DecodePlain( +inline int decodePlain( const uint8_t* data, - int64_t data_size, - int num_values, - int type_length, + int64_t dataSize, + int numValues, + int typeLength, FixedLenByteArray* out) { - int64_t bytes_to_decode = static_cast(type_length) * num_values; - if (bytes_to_decode > data_size || bytes_to_decode > INT_MAX) { - ParquetException::EofException(); + int64_t bytesToDecode = static_cast(typeLength) * numValues; + if (bytesToDecode > dataSize || bytesToDecode > INT_MAX) { + ParquetException::eofException(); } - for (int i = 0; i < num_values; ++i) { + for (int i = 0; i < numValues; ++i) { out[i].ptr = data; - data += type_length; - data_size -= type_length; + data += typeLength; + dataSize -= typeLength; } - return static_cast(bytes_to_decode); + return static_cast(bytesToDecode); } template -int PlainDecoder::Decode(T* buffer, int max_values) { - max_values = std::min(max_values, num_values_); - int bytes_consumed = - DecodePlain(data_, len_, max_values, type_length_, buffer); - data_ += bytes_consumed; - len_ -= bytes_consumed; - num_values_ -= max_values; - return max_values; +int PlainDecoder::decode(T* buffer, int maxValues) { + maxValues = std::min(maxValues, numValues_); + int bytesConsumed = + decodePlain(data_, len_, maxValues, typeLength_, buffer); + data_ += bytesConsumed; + len_ -= bytesConsumed; + numValues_ -= maxValues; + return maxValues; } class PlainBooleanDecoder : public DecoderImpl, virtual public BooleanDecoder { public: explicit PlainBooleanDecoder(const ColumnDescriptor* descr); - void SetData(int num_values, const uint8_t* data, int len) override; - - // Two flavors of bool decoding - int Decode(uint8_t* buffer, int max_values) override; - int Decode(bool* buffer, int max_values) override; - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + void setData(int numValues, const uint8_t* data, int len) override; + + // Two flavors of bool decoding. + int decode(uint8_t* buffer, int maxValues) override; + int decode(bool* buffer, int maxValues) override; + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::Accumulator* out) override; - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::DictAccumulator* out) override; private: - std::unique_ptr bit_reader_; + std::unique_ptr bitReader_; }; PlainBooleanDecoder::PlainBooleanDecoder(const ColumnDescriptor* descr) - : DecoderImpl(descr, Encoding::PLAIN) {} + : DecoderImpl(descr, Encoding::kPlain) {} -void PlainBooleanDecoder::SetData( - int num_values, - const uint8_t* data, - int len) { - num_values_ = num_values; - bit_reader_ = std::make_unique(data, len); +void PlainBooleanDecoder::setData(int numValues, const uint8_t* data, int len) { + numValues_ = numValues; + bitReader_ = std::make_unique(data, len); } -int PlainBooleanDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::Accumulator* builder) { - int values_decoded = num_values - null_count; - if (ARROW_PREDICT_FALSE(num_values_ < values_decoded)) { - ParquetException::EofException(); +int PlainBooleanDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::Accumulator* Builder) { + int valuesDecoded = numValues - nullCount; + if (ARROW_PREDICT_FALSE(numValues_ < valuesDecoded)) { + ParquetException::eofException(); } - PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); + PARQUET_THROW_NOT_OK(Builder->Reserve(numValues)); VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { bool value; - ARROW_IGNORE_EXPR(bit_reader_->GetValue(1, &value)); - builder->UnsafeAppend(value); + ARROW_IGNORE_EXPR(bitReader_->GetValue(1, &value)); + Builder->UnsafeAppend(value); }, - [&]() { builder->UnsafeAppendNull(); }); + [&]() { Builder->UnsafeAppendNull(); }); - num_values_ -= values_decoded; - return values_decoded; + numValues_ -= valuesDecoded; + return valuesDecoded; } -inline int PlainBooleanDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) { +inline int PlainBooleanDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) { ParquetException::NYI("dictionaries of BooleanType"); } -int PlainBooleanDecoder::Decode(uint8_t* buffer, int max_values) { - max_values = std::min(max_values, num_values_); +int PlainBooleanDecoder::decode(uint8_t* buffer, int maxValues) { + maxValues = std::min(maxValues, numValues_); bool val; - ::arrow::internal::BitmapWriter bit_writer(buffer, 0, max_values); - for (int i = 0; i < max_values; ++i) { - if (!bit_reader_->GetValue(1, &val)) { - ParquetException::EofException(); + ::arrow::internal::BitmapWriter bitWriter(buffer, 0, maxValues); + for (int i = 0; i < maxValues; ++i) { + if (!bitReader_->GetValue(1, &val)) { + ParquetException::eofException(); } if (val) { - bit_writer.Set(); + bitWriter.Set(); } - bit_writer.Next(); + bitWriter.Next(); } - bit_writer.Finish(); - num_values_ -= max_values; - return max_values; + bitWriter.Finish(); + numValues_ -= maxValues; + return maxValues; } -int PlainBooleanDecoder::Decode(bool* buffer, int max_values) { - max_values = std::min(max_values, num_values_); - if (bit_reader_->GetBatch(1, buffer, max_values) != max_values) { - ParquetException::EofException(); +int PlainBooleanDecoder::decode(bool* buffer, int maxValues) { + maxValues = std::min(maxValues, numValues_); + if (bitReader_->GetBatch(1, buffer, maxValues) != maxValues) { + ParquetException::eofException(); } - num_values_ -= max_values; - return max_values; + numValues_ -= maxValues; + return maxValues; } struct ArrowBinaryHelper { explicit ArrowBinaryHelper( typename EncodingTraits::Accumulator* out) { this->out = out; - this->builder = out->builder.get(); - this->chunk_space_remaining = - ::arrow::kBinaryMemoryLimit - this->builder->value_data_length(); + this->Builder = out->Builder.get(); + this->chunkSpaceRemaining = + ::arrow::kBinaryMemoryLimit - this->Builder->value_data_length(); } - Status PushChunk() { + Status pushChunk() { std::shared_ptr<::arrow::Array> result; - RETURN_NOT_OK(builder->Finish(&result)); + RETURN_NOT_OK(Builder->Finish(&result)); out->chunks.push_back(result); - chunk_space_remaining = ::arrow::kBinaryMemoryLimit; + chunkSpaceRemaining = ::arrow::kBinaryMemoryLimit; return Status::OK(); } - bool CanFit(int64_t length) const { - return length <= chunk_space_remaining; + bool canFit(int64_t length) const { + return length <= chunkSpaceRemaining; } void UnsafeAppend(const uint8_t* data, int32_t length) { - chunk_space_remaining -= length; - builder->UnsafeAppend(data, length); + chunkSpaceRemaining -= length; + Builder->UnsafeAppend(data, length); } void UnsafeAppendNull() { - builder->UnsafeAppendNull(); + Builder->UnsafeAppendNull(); } - Status Append(const uint8_t* data, int32_t length) { - chunk_space_remaining -= length; - return builder->Append(data, length); + Status append(const uint8_t* data, int32_t length) { + chunkSpaceRemaining -= length; + return Builder->Append(data, length); } - Status AppendNull() { - return builder->AppendNull(); + Status appendNull() { + return Builder->AppendNull(); } typename EncodingTraits::Accumulator* out; - ::arrow::BinaryBuilder* builder; - int64_t chunk_space_remaining; + ::arrow::BinaryBuilder* Builder; + int64_t chunkSpaceRemaining; }; template <> -inline int PlainDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::Accumulator* builder) { +inline int PlainDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::Accumulator* Builder) { ParquetException::NYI(); } template <> -inline int PlainDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) { +inline int PlainDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) { ParquetException::NYI(); } template <> -inline int PlainDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::Accumulator* builder) { - int values_decoded = num_values - null_count; - if (ARROW_PREDICT_FALSE(len_ < descr_->type_length() * values_decoded)) { - ParquetException::EofException(); +inline int PlainDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::Accumulator* Builder) { + int valuesDecoded = numValues - nullCount; + if (ARROW_PREDICT_FALSE(len_ < descr_->typeLength() * valuesDecoded)) { + ParquetException::eofException(); } - PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); + PARQUET_THROW_NOT_OK(Builder->Reserve(numValues)); VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { - builder->UnsafeAppend(data_); - data_ += descr_->type_length(); + Builder->UnsafeAppend(data_); + data_ += descr_->typeLength(); }, - [&]() { builder->UnsafeAppendNull(); }); + [&]() { Builder->UnsafeAppendNull(); }); - num_values_ -= values_decoded; - len_ -= descr_->type_length() * values_decoded; - return values_decoded; + numValues_ -= valuesDecoded; + len_ -= descr_->typeLength() * valuesDecoded; + return valuesDecoded; } template <> -inline int PlainDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) { - int values_decoded = num_values - null_count; - if (ARROW_PREDICT_FALSE(len_ < descr_->type_length() * values_decoded)) { - ParquetException::EofException(); +inline int PlainDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) { + int valuesDecoded = numValues - nullCount; + if (ARROW_PREDICT_FALSE(len_ < descr_->typeLength() * valuesDecoded)) { + ParquetException::eofException(); } - PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); + PARQUET_THROW_NOT_OK(Builder->Reserve(numValues)); VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { - PARQUET_THROW_NOT_OK(builder->Append(data_)); - data_ += descr_->type_length(); + PARQUET_THROW_NOT_OK(Builder->Append(data_)); + data_ += descr_->typeLength(); }, - [&]() { PARQUET_THROW_NOT_OK(builder->AppendNull()); }); + [&]() { PARQUET_THROW_NOT_OK(Builder->AppendNull()); }); - num_values_ -= values_decoded; - len_ -= descr_->type_length() * values_decoded; - return values_decoded; + numValues_ -= valuesDecoded; + len_ -= descr_->typeLength() * valuesDecoded; + return valuesDecoded; } class PlainByteArrayDecoder : public PlainDecoder, virtual public ByteArrayDecoder { public: using Base = PlainDecoder; - using Base::DecodeSpaced; + using Base::decodeSpaced; using Base::PlainDecoder; - // ---------------------------------------------------------------------- - // Dictionary read paths + // ----------------------------------------------------------------------. + // Dictionary read paths. - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - ::arrow::BinaryDictionary32Builder* builder) override { + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + ::arrow::BinaryDictionary32Builder* Builder) override { int result = 0; - PARQUET_THROW_NOT_OK(DecodeArrow( - num_values, - null_count, - valid_bits, - valid_bits_offset, - builder, - &result)); + PARQUET_THROW_NOT_OK(decodeArrow( + numValues, nullCount, validBits, validBitsOffset, Builder, &result)); return result; } - // ---------------------------------------------------------------------- - // Optimized dense binary read paths + // ----------------------------------------------------------------------. + // Optimized dense binary read paths. - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::Accumulator* out) override { int result = 0; - PARQUET_THROW_NOT_OK(DecodeArrowDense( - num_values, null_count, valid_bits, valid_bits_offset, out, &result)); + PARQUET_THROW_NOT_OK(decodeArrowDense( + numValues, nullCount, validBits, validBitsOffset, out, &result)); return result; } private: - Status DecodeArrowDense( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + Status decodeArrowDense( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::Accumulator* out, - int* out_values_decoded) { + int* outValuesDecoded) { ArrowBinaryHelper helper(out); - int values_decoded = 0; + int valuesDecoded = 0; - RETURN_NOT_OK(helper.builder->Reserve(num_values)); - RETURN_NOT_OK(helper.builder->ReserveData( - std::min(len_, helper.chunk_space_remaining))); + RETURN_NOT_OK(helper.Builder->Reserve(numValues)); + RETURN_NOT_OK(helper.Builder->ReserveData( + std::min(len_, helper.chunkSpaceRemaining))); int i = 0; RETURN_NOT_OK(VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { if (ARROW_PREDICT_FALSE(len_ < 4)) { - ParquetException::EofException(); + ParquetException::eofException(); } - auto value_len = SafeLoadAs(data_); - if (ARROW_PREDICT_FALSE(value_len < 0 || value_len > INT32_MAX - 4)) { + auto valueLen = SafeLoadAs(data_); + if (ARROW_PREDICT_FALSE(valueLen < 0 || valueLen > INT32_MAX - 4)) { return Status::Invalid( - "Invalid or corrupted value_len '", value_len, "'"); + "Invalid or corrupted value_len '", valueLen, "'"); } - auto increment = value_len + 4; + auto increment = valueLen + 4; if (ARROW_PREDICT_FALSE(len_ < increment)) { - ParquetException::EofException(); + ParquetException::eofException(); } - if (ARROW_PREDICT_FALSE(!helper.CanFit(value_len))) { - // This element would exceed the capacity of a chunk - RETURN_NOT_OK(helper.PushChunk()); - RETURN_NOT_OK(helper.builder->Reserve(num_values - i)); - RETURN_NOT_OK(helper.builder->ReserveData( - std::min(len_, helper.chunk_space_remaining))); + if (ARROW_PREDICT_FALSE(!helper.canFit(valueLen))) { + // This element would exceed the capacity of a chunk. + RETURN_NOT_OK(helper.pushChunk()); + RETURN_NOT_OK(helper.Builder->Reserve(numValues - i)); + RETURN_NOT_OK(helper.Builder->ReserveData( + std::min(len_, helper.chunkSpaceRemaining))); } - helper.UnsafeAppend(data_ + 4, value_len); + helper.UnsafeAppend(data_ + 4, valueLen); data_ += increment; len_ -= increment; - ++values_decoded; + ++valuesDecoded; ++i; return Status::OK(); }, @@ -1623,50 +1607,50 @@ class PlainByteArrayDecoder : public PlainDecoder, return Status::OK(); })); - num_values_ -= values_decoded; - *out_values_decoded = values_decoded; + numValues_ -= valuesDecoded; + *outValuesDecoded = valuesDecoded; return Status::OK(); } template - Status DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - BuilderType* builder, - int* out_values_decoded) { - RETURN_NOT_OK(builder->Reserve(num_values)); - int values_decoded = 0; + Status decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + BuilderType* Builder, + int* outValuesDecoded) { + RETURN_NOT_OK(Builder->Reserve(numValues)); + int valuesDecoded = 0; RETURN_NOT_OK(VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { if (ARROW_PREDICT_FALSE(len_ < 4)) { - ParquetException::EofException(); + ParquetException::eofException(); } - auto value_len = SafeLoadAs(data_); - if (ARROW_PREDICT_FALSE(value_len < 0 || value_len > INT32_MAX - 4)) { + auto valueLen = SafeLoadAs(data_); + if (ARROW_PREDICT_FALSE(valueLen < 0 || valueLen > INT32_MAX - 4)) { return Status::Invalid( - "Invalid or corrupted value_len '", value_len, "'"); + "Invalid or corrupted value_len '", valueLen, "'"); } - auto increment = value_len + 4; + auto increment = valueLen + 4; if (ARROW_PREDICT_FALSE(len_ < increment)) { - ParquetException::EofException(); + ParquetException::eofException(); } - RETURN_NOT_OK(builder->Append(data_ + 4, value_len)); + RETURN_NOT_OK(Builder->Append(data_ + 4, valueLen)); data_ += increment; len_ -= increment; - ++values_decoded; + ++valuesDecoded; return Status::OK(); }, - [&]() { return builder->AppendNull(); })); + [&]() { return Builder->AppendNull(); })); - num_values_ -= values_decoded; - *out_values_decoded = values_decoded; + numValues_ -= valuesDecoded; + *outValuesDecoded = valuesDecoded; return Status::OK(); } }; @@ -1678,486 +1662,475 @@ class PlainFLBADecoder : public PlainDecoder, using Base::PlainDecoder; }; -// ---------------------------------------------------------------------- -// Dictionary encoding and decoding +// ----------------------------------------------------------------------. +// Dictionary encoding and decoding. template class DictDecoderImpl : public DecoderImpl, virtual public DictDecoder { public: - typedef typename Type::c_type T; + typedef typename Type::CType T; - // Initializes the dictionary with values from 'dictionary'. The data in - // dictionary is not guaranteed to persist in memory after this call so the - // dictionary decoder needs to copy the data out if necessary. + // Initializes the dictionary with values from 'dictionary'. The data in. + // Dictionary is not guaranteed to persist in memory after this call so the. + // Dictionary decoder needs to copy the data out if necessary. explicit DictDecoderImpl( const ColumnDescriptor* descr, MemoryPool* pool = ::arrow::default_memory_pool()) - : DecoderImpl(descr, Encoding::RLE_DICTIONARY), - dictionary_(AllocateBuffer(pool, 0)), - dictionary_length_(0), - byte_array_data_(AllocateBuffer(pool, 0)), - byte_array_offsets_(AllocateBuffer(pool, 0)), - indices_scratch_space_(AllocateBuffer(pool, 0)) {} - - // Perform type-specific initiatialization - void SetDict(TypedDecoder* dictionary) override; - - void SetData(int num_values, const uint8_t* data, int len) override { - num_values_ = num_values; + : DecoderImpl(descr, Encoding::kRleDictionary), + dictionary_(allocateBuffer(pool, 0)), + dictionaryLength_(0), + byteArrayData_(allocateBuffer(pool, 0)), + byteArrayOffsets_(allocateBuffer(pool, 0)), + indicesScratchSpace_(allocateBuffer(pool, 0)) {} + + // Perform type-specific initiatialization. + void setDict(TypedDecoder* dictionary) override; + + void setData(int numValues, const uint8_t* data, int len) override { + numValues_ = numValues; if (len == 0) { - // Initialize dummy decoder to avoid crashes later on - idx_decoder_ = RleDecoder(data, len, /*bitWidth=*/1); + // Initialize dummy decoder to avoid crashes later on. + idxDecoder_ = RleDecoder(data, len, 1); return; } - uint8_t bit_width = *data; - if (ARROW_PREDICT_FALSE(bit_width > 32)) { + uint8_t bitWidth = *data; + if (ARROW_PREDICT_FALSE(bitWidth > 32)) { throw ParquetException( - "Invalid or corrupted bit_width " + std::to_string(bit_width) + + "Invalid or corrupted bit_width " + std::to_string(bitWidth) + ". Maximum allowed is 32."); } - idx_decoder_ = RleDecoder(++data, --len, bit_width); + idxDecoder_ = RleDecoder(++data, --len, bitWidth); } - int Decode(T* buffer, int num_values) override { - num_values = std::min(num_values, num_values_); - int decoded_values = idx_decoder_.GetBatchWithDict( + int decode(T* buffer, int numValues) override { + numValues = std::min(numValues, numValues_); + int decodedValues = idxDecoder_.GetBatchWithDict( reinterpret_cast(dictionary_->data()), - dictionary_length_, + dictionaryLength_, buffer, - num_values); - if (decoded_values != num_values) { - ParquetException::EofException(); + numValues); + if (decodedValues != numValues) { + ParquetException::eofException(); } - num_values_ -= num_values; - return num_values; + numValues_ -= numValues; + return numValues; } - int DecodeSpaced( + int decodeSpaced( T* buffer, - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset) override { - num_values = std::min(num_values, num_values_); - if (num_values != - idx_decoder_.GetBatchWithDictSpaced( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset) override { + numValues = std::min(numValues, numValues_); + if (numValues != + idxDecoder_.GetBatchWithDictSpaced( reinterpret_cast(dictionary_->data()), - dictionary_length_, + dictionaryLength_, buffer, - num_values, - null_count, - valid_bits, - valid_bits_offset)) { - ParquetException::EofException(); + numValues, + nullCount, + validBits, + validBitsOffset)) { + ParquetException::eofException(); } - num_values_ -= num_values; - return num_values; + numValues_ -= numValues; + return numValues; } - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::Accumulator* out) override; - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::DictAccumulator* out) override; - void InsertDictionary(::arrow::ArrayBuilder* builder) override; - - int DecodeIndicesSpaced( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - ::arrow::ArrayBuilder* builder) override { - if (num_values > 0) { - // TODO(wesm): Refactor to batch reads for improved memory use. It is not - // trivial because the null_count is relative to the entire bitmap - PARQUET_THROW_NOT_OK(indices_scratch_space_->TypedResize( - num_values, /*shrink_to_fit=*/false)); + void insertDictionary(::arrow::ArrayBuilder* Builder) override; + + int decodeIndicesSpaced( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + ::arrow::ArrayBuilder* Builder) override { + if (numValues > 0) { + // TODO(wesm): Refactor to batch reads for improved memory use. It is not. + // Trivial because the null_count is relative to the entire bitmap. + PARQUET_THROW_NOT_OK( + indicesScratchSpace_->TypedResize(numValues, false)); } - auto indices_buffer = - reinterpret_cast(indices_scratch_space_->mutable_data()); + auto indicesBuffer = + reinterpret_cast(indicesScratchSpace_->mutable_data()); - if (num_values != - idx_decoder_.GetBatchSpaced( - num_values, - null_count, - valid_bits, - valid_bits_offset, - indices_buffer)) { - ParquetException::EofException(); + if (numValues != + idxDecoder_.GetBatchSpaced( + numValues, nullCount, validBits, validBitsOffset, indicesBuffer)) { + ParquetException::eofException(); } - // XXX(wesm): Cannot append "valid bits" directly to the builder - std::vector valid_bytes(num_values, 0); + // XXX(wesm): Cannot append "valid bits" directly to the builder. + std::vector validBytes(numValues, 0); int64_t i = 0; VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, - [&]() { valid_bytes[i++] = 1; }, + validBits, + validBitsOffset, + numValues, + nullCount, + [&]() { validBytes[i++] = 1; }, [&]() { ++i; }); - auto binary_builder = - checked_cast<::arrow::BinaryDictionary32Builder*>(builder); - PARQUET_THROW_NOT_OK(binary_builder->AppendIndices( - indices_buffer, num_values, valid_bytes.data())); - num_values_ -= num_values - null_count; - return num_values - null_count; - } - - int DecodeIndices(int num_values, ::arrow::ArrayBuilder* builder) override { - num_values = std::min(num_values, num_values_); - if (num_values > 0) { - // TODO(wesm): Refactor to batch reads for improved memory use. This is - // relatively simple here because we don't have to do any bookkeeping of - // nulls - PARQUET_THROW_NOT_OK(indices_scratch_space_->TypedResize( - num_values, /*shrink_to_fit=*/false)); - } - auto indices_buffer = - reinterpret_cast(indices_scratch_space_->mutable_data()); - if (num_values != idx_decoder_.GetBatch(indices_buffer, num_values)) { - ParquetException::EofException(); - } - auto binary_builder = - checked_cast<::arrow::BinaryDictionary32Builder*>(builder); + auto binaryBuilder = + checked_cast<::arrow::BinaryDictionary32Builder*>(Builder); + PARQUET_THROW_NOT_OK(binaryBuilder->AppendIndices( + indicesBuffer, numValues, validBytes.data())); + numValues_ -= numValues - nullCount; + return numValues - nullCount; + } + + int decodeIndices(int numValues, ::arrow::ArrayBuilder* Builder) override { + numValues = std::min(numValues, numValues_); + if (numValues > 0) { + // TODO(wesm): Refactor to batch reads for improved memory use. This is. + // Relatively simple here because we don't have to do any bookkeeping of. + // Nulls. + PARQUET_THROW_NOT_OK( + indicesScratchSpace_->TypedResize(numValues, false)); + } + auto indicesBuffer = + reinterpret_cast(indicesScratchSpace_->mutable_data()); + if (numValues != idxDecoder_.GetBatch(indicesBuffer, numValues)) { + ParquetException::eofException(); + } + auto binaryBuilder = + checked_cast<::arrow::BinaryDictionary32Builder*>(Builder); PARQUET_THROW_NOT_OK( - binary_builder->AppendIndices(indices_buffer, num_values)); - num_values_ -= num_values; - return num_values; + binaryBuilder->AppendIndices(indicesBuffer, numValues)); + numValues_ -= numValues; + return numValues; } - int DecodeIndices(int num_values, int32_t* indices) override { - if (num_values != idx_decoder_.GetBatch(indices, num_values)) { - ParquetException::EofException(); + int decodeIndices(int numValues, int32_t* indices) override { + if (numValues != idxDecoder_.GetBatch(indices, numValues)) { + ParquetException::eofException(); } - num_values_ -= num_values; - return num_values; + numValues_ -= numValues; + return numValues; } - void GetDictionary(const T** dictionary, int32_t* dictionary_length) - override { - *dictionary_length = dictionary_length_; + void getDictionary(const T** dictionary, int32_t* dictionaryLength) override { + *dictionaryLength = dictionaryLength_; *dictionary = reinterpret_cast(dictionary_->mutable_data()); } protected: - Status IndexInBounds(int32_t index) { - if (ARROW_PREDICT_TRUE(0 <= index && index < dictionary_length_)) { + Status indexInBounds(int32_t index) { + if (ARROW_PREDICT_TRUE(0 <= index && index < dictionaryLength_)) { return Status::OK(); } return Status::Invalid("Index not in dictionary bounds"); } - inline void DecodeDict(TypedDecoder* dictionary) { - dictionary_length_ = static_cast(dictionary->values_left()); - PARQUET_THROW_NOT_OK(dictionary_->Resize( - dictionary_length_ * sizeof(T), - /*shrink_to_fit=*/false)); - dictionary->Decode( - reinterpret_cast(dictionary_->mutable_data()), dictionary_length_); + inline void decodeDict(TypedDecoder* dictionary) { + dictionaryLength_ = static_cast(dictionary->valuesLeft()); + PARQUET_THROW_NOT_OK( + dictionary_->Resize(dictionaryLength_ * sizeof(T), false)); + dictionary->decode( + reinterpret_cast(dictionary_->mutable_data()), dictionaryLength_); } // Only one is set. std::shared_ptr dictionary_; - int32_t dictionary_length_; + int32_t dictionaryLength_; - // Data that contains the byte array data (byte_array_dictionary_ just has the - // pointers). - std::shared_ptr byte_array_data_; + // Data that contains the byte array data (byte_array_dictionary_ just has + // the. Pointers). + std::shared_ptr byteArrayData_; - // Arrow-style byte offsets for each dictionary value. We maintain two - // representations of the dictionary, one as ByteArray* for non-Arrow - // consumers and this one for Arrow consumers. Since dictionaries are - // generally pretty small to begin with this doesn't mean too much extra - // memory use in most cases - std::shared_ptr byte_array_offsets_; + // Arrow-style byte offsets for each dictionary value. We maintain two. + // Representations of the dictionary, one as ByteArray* for non-Arrow. + // Consumers and this one for Arrow consumers. Since dictionaries are. + // Generally pretty small to begin with this doesn't mean too much extra. + // Memory use in most cases. + std::shared_ptr byteArrayOffsets_; - // Reusable buffer for decoding dictionary indices to be appended to a - // BinaryDictionary32Builder - std::shared_ptr indices_scratch_space_; + // Reusable buffer for decoding dictionary indices to be appended to a. + // BinaryDictionary32Builder. + std::shared_ptr indicesScratchSpace_; - RleDecoder idx_decoder_; + RleDecoder idxDecoder_; }; template -void DictDecoderImpl::SetDict(TypedDecoder* dictionary) { - DecodeDict(dictionary); +void DictDecoderImpl::setDict(TypedDecoder* dictionary) { + decodeDict(dictionary); } template <> -void DictDecoderImpl::SetDict( +void DictDecoderImpl::setDict( TypedDecoder* dictionary) { ParquetException::NYI( "Dictionary encoding is not implemented for boolean values"); } template <> -void DictDecoderImpl::SetDict( +void DictDecoderImpl::setDict( TypedDecoder* dictionary) { - DecodeDict(dictionary); + decodeDict(dictionary); - auto dict_values = reinterpret_cast(dictionary_->mutable_data()); + auto dictValues = reinterpret_cast(dictionary_->mutable_data()); - int total_size = 0; - for (int i = 0; i < dictionary_length_; ++i) { - total_size += dict_values[i].len; + int totalSize = 0; + for (int i = 0; i < dictionaryLength_; ++i) { + totalSize += dictValues[i].len; } - PARQUET_THROW_NOT_OK(byte_array_data_->Resize( - total_size, - /*shrink_to_fit=*/false)); - PARQUET_THROW_NOT_OK(byte_array_offsets_->Resize( - (dictionary_length_ + 1) * sizeof(int32_t), - /*shrink_to_fit=*/false)); + PARQUET_THROW_NOT_OK(byteArrayData_->Resize(totalSize, false)); + PARQUET_THROW_NOT_OK(byteArrayOffsets_->Resize( + (dictionaryLength_ + 1) * sizeof(int32_t), false)); int32_t offset = 0; - uint8_t* bytes_data = byte_array_data_->mutable_data(); - int32_t* bytes_offsets = - reinterpret_cast(byte_array_offsets_->mutable_data()); - for (int i = 0; i < dictionary_length_; ++i) { - memcpy(bytes_data + offset, dict_values[i].ptr, dict_values[i].len); - bytes_offsets[i] = offset; - dict_values[i].ptr = bytes_data + offset; - offset += dict_values[i].len; + uint8_t* bytesData = byteArrayData_->mutable_data(); + int32_t* bytesOffsets = + reinterpret_cast(byteArrayOffsets_->mutable_data()); + for (int i = 0; i < dictionaryLength_; ++i) { + memcpy(bytesData + offset, dictValues[i].ptr, dictValues[i].len); + bytesOffsets[i] = offset; + dictValues[i].ptr = bytesData + offset; + offset += dictValues[i].len; } - bytes_offsets[dictionary_length_] = offset; + bytesOffsets[dictionaryLength_] = offset; } template <> -inline void DictDecoderImpl::SetDict( +inline void DictDecoderImpl::setDict( TypedDecoder* dictionary) { - DecodeDict(dictionary); + decodeDict(dictionary); - auto dict_values = reinterpret_cast(dictionary_->mutable_data()); + auto dictValues = reinterpret_cast(dictionary_->mutable_data()); - int fixed_len = descr_->type_length(); - int total_size = dictionary_length_ * fixed_len; + int fixedLen = descr_->typeLength(); + int totalSize = dictionaryLength_ * fixedLen; - PARQUET_THROW_NOT_OK(byte_array_data_->Resize( - total_size, - /*shrink_to_fit=*/false)); - uint8_t* bytes_data = byte_array_data_->mutable_data(); - for (int32_t i = 0, offset = 0; i < dictionary_length_; - ++i, offset += fixed_len) { - memcpy(bytes_data + offset, dict_values[i].ptr, fixed_len); - dict_values[i].ptr = bytes_data + offset; + PARQUET_THROW_NOT_OK(byteArrayData_->Resize(totalSize, false)); + uint8_t* bytesData = byteArrayData_->mutable_data(); + for (int32_t i = 0, offset = 0; i < dictionaryLength_; + ++i, offset += fixedLen) { + memcpy(bytesData + offset, dictValues[i].ptr, fixedLen); + dictValues[i].ptr = bytesData + offset; } } template <> -inline int DictDecoderImpl::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::Accumulator* builder) { +inline int DictDecoderImpl::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::Accumulator* Builder) { ParquetException::NYI("DecodeArrow to Int96Type"); } template <> -inline int DictDecoderImpl::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) { +inline int DictDecoderImpl::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) { ParquetException::NYI("DecodeArrow to Int96Type"); } template <> -inline int DictDecoderImpl::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::Accumulator* builder) { +inline int DictDecoderImpl::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::Accumulator* Builder) { ParquetException::NYI("DecodeArrow implemented elsewhere"); } template <> -inline int DictDecoderImpl::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) { +inline int DictDecoderImpl::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) { ParquetException::NYI("DecodeArrow implemented elsewhere"); } template -int DictDecoderImpl::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) { - PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); +int DictDecoderImpl::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) { + PARQUET_THROW_NOT_OK(Builder->Reserve(numValues)); - auto dict_values = - reinterpret_cast(dictionary_->data()); + auto dictValues = + reinterpret_cast(dictionary_->data()); VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { int32_t index; - if (ARROW_PREDICT_FALSE(!idx_decoder_.Get(&index))) { + if (ARROW_PREDICT_FALSE(!idxDecoder_.Get(&index))) { throw ParquetException(""); } - PARQUET_THROW_NOT_OK(IndexInBounds(index)); - PARQUET_THROW_NOT_OK(builder->Append(dict_values[index])); + PARQUET_THROW_NOT_OK(indexInBounds(index)); + PARQUET_THROW_NOT_OK(Builder->Append(dictValues[index])); }, - [&]() { PARQUET_THROW_NOT_OK(builder->AppendNull()); }); + [&]() { PARQUET_THROW_NOT_OK(Builder->AppendNull()); }); - return num_values - null_count; + return numValues - nullCount; } template <> -int DictDecoderImpl::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) { +int DictDecoderImpl::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) { ParquetException::NYI("No dictionary encoding for BooleanType"); } template <> -inline int DictDecoderImpl::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::Accumulator* builder) { - if (builder->byte_width() != descr_->type_length()) { +inline int DictDecoderImpl::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::Accumulator* Builder) { + if (Builder->byte_width() != descr_->typeLength()) { throw ParquetException( "Byte width mismatch: builder was " + - std::to_string(builder->byte_width()) + " but decoder was " + - std::to_string(descr_->type_length())); + std::to_string(Builder->byte_width()) + " but decoder was " + + std::to_string(descr_->typeLength())); } - PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); + PARQUET_THROW_NOT_OK(Builder->Reserve(numValues)); - auto dict_values = reinterpret_cast(dictionary_->data()); + auto dictValues = reinterpret_cast(dictionary_->data()); VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { int32_t index; - if (ARROW_PREDICT_FALSE(!idx_decoder_.Get(&index))) { + if (ARROW_PREDICT_FALSE(!idxDecoder_.Get(&index))) { throw ParquetException(""); } - PARQUET_THROW_NOT_OK(IndexInBounds(index)); - builder->UnsafeAppend(dict_values[index].ptr); + PARQUET_THROW_NOT_OK(indexInBounds(index)); + Builder->UnsafeAppend(dictValues[index].ptr); }, - [&]() { builder->UnsafeAppendNull(); }); + [&]() { Builder->UnsafeAppendNull(); }); - return num_values - null_count; + return numValues - nullCount; } template <> -int DictDecoderImpl::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) { - auto value_type = - checked_cast(*builder->type()) +int DictDecoderImpl::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) { + auto ValueType = + checked_cast(*Builder->type()) .value_type(); auto byte_width = - checked_cast(*value_type) + checked_cast(*ValueType) .byte_width(); - if (byte_width != descr_->type_length()) { + if (byte_width != descr_->typeLength()) { throw ParquetException( "Byte width mismatch: builder was " + std::to_string(byte_width) + - " but decoder was " + std::to_string(descr_->type_length())); + " but decoder was " + std::to_string(descr_->typeLength())); } - PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); + PARQUET_THROW_NOT_OK(Builder->Reserve(numValues)); - auto dict_values = reinterpret_cast(dictionary_->data()); + auto dictValues = reinterpret_cast(dictionary_->data()); VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { int32_t index; - if (ARROW_PREDICT_FALSE(!idx_decoder_.Get(&index))) { + if (ARROW_PREDICT_FALSE(!idxDecoder_.Get(&index))) { throw ParquetException(""); } - PARQUET_THROW_NOT_OK(IndexInBounds(index)); - PARQUET_THROW_NOT_OK(builder->Append(dict_values[index].ptr)); + PARQUET_THROW_NOT_OK(indexInBounds(index)); + PARQUET_THROW_NOT_OK(Builder->Append(dictValues[index].ptr)); }, - [&]() { PARQUET_THROW_NOT_OK(builder->AppendNull()); }); + [&]() { PARQUET_THROW_NOT_OK(Builder->AppendNull()); }); - return num_values - null_count; + return numValues - nullCount; } template -int DictDecoderImpl::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::Accumulator* builder) { - PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); +int DictDecoderImpl::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::Accumulator* Builder) { + PARQUET_THROW_NOT_OK(Builder->Reserve(numValues)); - using value_type = typename Type::c_type; - auto dict_values = reinterpret_cast(dictionary_->data()); + using ValueType = typename Type::CType; + auto dictValues = reinterpret_cast(dictionary_->data()); VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { int32_t index; - if (ARROW_PREDICT_FALSE(!idx_decoder_.Get(&index))) { + if (ARROW_PREDICT_FALSE(!idxDecoder_.Get(&index))) { throw ParquetException(""); } - PARQUET_THROW_NOT_OK(IndexInBounds(index)); - builder->UnsafeAppend(dict_values[index]); + PARQUET_THROW_NOT_OK(indexInBounds(index)); + Builder->UnsafeAppend(dictValues[index]); }, - [&]() { builder->UnsafeAppendNull(); }); + [&]() { Builder->UnsafeAppendNull(); }); - return num_values - null_count; + return numValues - nullCount; } template -void DictDecoderImpl::InsertDictionary(::arrow::ArrayBuilder* builder) { +void DictDecoderImpl::insertDictionary(::arrow::ArrayBuilder* Builder) { ParquetException::NYI( "InsertDictionary only implemented for BYTE_ARRAY types"); } template <> -void DictDecoderImpl::InsertDictionary( - ::arrow::ArrayBuilder* builder) { - auto binary_builder = - checked_cast<::arrow::BinaryDictionary32Builder*>(builder); +void DictDecoderImpl::insertDictionary( + ::arrow::ArrayBuilder* Builder) { + auto binaryBuilder = + checked_cast<::arrow::BinaryDictionary32Builder*>(Builder); - // Make a BinaryArray referencing the internal dictionary data + // Make a BinaryArray referencing the internal dictionary data. auto arr = std::make_shared<::arrow::BinaryArray>( - dictionary_length_, byte_array_offsets_, byte_array_data_); - PARQUET_THROW_NOT_OK(binary_builder->InsertMemoValues(*arr)); + dictionaryLength_, byteArrayOffsets_, byteArrayData_); + PARQUET_THROW_NOT_OK(binaryBuilder->InsertMemoValues(*arr)); } class DictByteArrayDecoderImpl : public DictDecoderImpl, @@ -2166,498 +2139,490 @@ class DictByteArrayDecoderImpl : public DictDecoderImpl, using BASE = DictDecoderImpl; using BASE::DictDecoderImpl; - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - ::arrow::BinaryDictionary32Builder* builder) override { + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + ::arrow::BinaryDictionary32Builder* Builder) override { int result = 0; - if (null_count == 0) { - PARQUET_THROW_NOT_OK(DecodeArrowNonNull(num_values, builder, &result)); + if (nullCount == 0) { + PARQUET_THROW_NOT_OK(decodeArrowNonNull(numValues, Builder, &result)); } else { - PARQUET_THROW_NOT_OK(DecodeArrow( - num_values, - null_count, - valid_bits, - valid_bits_offset, - builder, - &result)); + PARQUET_THROW_NOT_OK(decodeArrow( + numValues, nullCount, validBits, validBitsOffset, Builder, &result)); } return result; } - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::Accumulator* out) override { int result = 0; - if (null_count == 0) { - PARQUET_THROW_NOT_OK(DecodeArrowDenseNonNull(num_values, out, &result)); + if (nullCount == 0) { + PARQUET_THROW_NOT_OK(decodeArrowDenseNonNull(numValues, out, &result)); } else { - PARQUET_THROW_NOT_OK(DecodeArrowDense( - num_values, null_count, valid_bits, valid_bits_offset, out, &result)); + PARQUET_THROW_NOT_OK(decodeArrowDense( + numValues, nullCount, validBits, validBitsOffset, out, &result)); } return result; } private: - Status DecodeArrowDense( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + Status decodeArrowDense( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::Accumulator* out, - int* out_num_values) { + int* outNumValues) { constexpr int32_t kBufferSize = 1024; int32_t indices[kBufferSize]; ArrowBinaryHelper helper(out); - auto dict_values = reinterpret_cast(dictionary_->data()); - int values_decoded = 0; - int num_indices = 0; - int pos_indices = 0; - - auto visit_valid = [&](int64_t position) -> Status { - if (num_indices == pos_indices) { - // Refill indices buffer - const auto batch_size = std::min( - kBufferSize, num_values - null_count - values_decoded); - num_indices = idx_decoder_.GetBatch(indices, batch_size); - if (ARROW_PREDICT_FALSE(num_indices < 1)) { - return Status::Invalid("Invalid number of indices: ", num_indices); + auto dictValues = reinterpret_cast(dictionary_->data()); + int valuesDecoded = 0; + int numIndices = 0; + int posIndices = 0; + + auto visitValid = [&](int64_t position) -> Status { + if (numIndices == posIndices) { + // Refill indices buffer. + const auto batchSize = std::min( + kBufferSize, numValues - nullCount - valuesDecoded); + numIndices = idxDecoder_.GetBatch(indices, batchSize); + if (ARROW_PREDICT_FALSE(numIndices < 1)) { + return Status::Invalid("Invalid number of indices: ", numIndices); } - pos_indices = 0; + posIndices = 0; } - const auto index = indices[pos_indices++]; - RETURN_NOT_OK(IndexInBounds(index)); - const auto& val = dict_values[index]; - if (ARROW_PREDICT_FALSE(!helper.CanFit(val.len))) { - RETURN_NOT_OK(helper.PushChunk()); + const auto index = indices[posIndices++]; + RETURN_NOT_OK(indexInBounds(index)); + const auto& val = dictValues[index]; + if (ARROW_PREDICT_FALSE(!helper.canFit(val.len))) { + RETURN_NOT_OK(helper.pushChunk()); } - RETURN_NOT_OK(helper.Append(val.ptr, static_cast(val.len))); - ++values_decoded; + RETURN_NOT_OK(helper.append(val.ptr, static_cast(val.len))); + ++valuesDecoded; return Status::OK(); }; - auto visit_null = [&]() -> Status { - RETURN_NOT_OK(helper.AppendNull()); + auto visitNull = [&]() -> Status { + RETURN_NOT_OK(helper.appendNull()); return Status::OK(); }; - ::arrow::internal::BitBlockCounter bit_blocks( - valid_bits, valid_bits_offset, num_values); + ::arrow::internal::BitBlockCounter bitBlocks( + validBits, validBitsOffset, numValues); int64_t position = 0; - while (position < num_values) { - const auto block = bit_blocks.NextWord(); + while (position < numValues) { + const auto block = bitBlocks.NextWord(); if (block.AllSet()) { for (int64_t i = 0; i < block.length; ++i, ++position) { - ARROW_RETURN_NOT_OK(visit_valid(position)); + ARROW_RETURN_NOT_OK(visitValid(position)); } } else if (block.NoneSet()) { for (int64_t i = 0; i < block.length; ++i, ++position) { - ARROW_RETURN_NOT_OK(visit_null()); + ARROW_RETURN_NOT_OK(visitNull()); } } else { for (int64_t i = 0; i < block.length; ++i, ++position) { if (::arrow::bit_util::GetBit( - valid_bits, valid_bits_offset + position)) { - ARROW_RETURN_NOT_OK(visit_valid(position)); + validBits, validBitsOffset + position)) { + ARROW_RETURN_NOT_OK(visitValid(position)); } else { - ARROW_RETURN_NOT_OK(visit_null()); + ARROW_RETURN_NOT_OK(visitNull()); } } } } - *out_num_values = values_decoded; + *outNumValues = valuesDecoded; return Status::OK(); } - Status DecodeArrowDenseNonNull( - int num_values, + Status decodeArrowDenseNonNull( + int numValues, typename EncodingTraits::Accumulator* out, - int* out_num_values) { + int* outNumValues) { constexpr int32_t kBufferSize = 2048; int32_t indices[kBufferSize]; - int values_decoded = 0; + int valuesDecoded = 0; ArrowBinaryHelper helper(out); - auto dict_values = reinterpret_cast(dictionary_->data()); - - while (values_decoded < num_values) { - int32_t batch_size = - std::min(kBufferSize, num_values - values_decoded); - int num_indices = idx_decoder_.GetBatch(indices, batch_size); - if (num_indices == 0) - ParquetException::EofException(); - for (int i = 0; i < num_indices; ++i) { + auto dictValues = reinterpret_cast(dictionary_->data()); + + while (valuesDecoded < numValues) { + int32_t batchSize = + std::min(kBufferSize, numValues - valuesDecoded); + int numIndices = idxDecoder_.GetBatch(indices, batchSize); + if (numIndices == 0) + ParquetException::eofException(); + for (int i = 0; i < numIndices; ++i) { auto idx = indices[i]; - RETURN_NOT_OK(IndexInBounds(idx)); - const auto& val = dict_values[idx]; - if (ARROW_PREDICT_FALSE(!helper.CanFit(val.len))) { - RETURN_NOT_OK(helper.PushChunk()); + RETURN_NOT_OK(indexInBounds(idx)); + const auto& val = dictValues[idx]; + if (ARROW_PREDICT_FALSE(!helper.canFit(val.len))) { + RETURN_NOT_OK(helper.pushChunk()); } - RETURN_NOT_OK(helper.Append(val.ptr, static_cast(val.len))); + RETURN_NOT_OK(helper.append(val.ptr, static_cast(val.len))); } - values_decoded += num_indices; + valuesDecoded += numIndices; } - *out_num_values = values_decoded; + *outNumValues = valuesDecoded; return Status::OK(); } template - Status DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - BuilderType* builder, - int* out_num_values) { + Status decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + BuilderType* Builder, + int* outNumValues) { constexpr int32_t kBufferSize = 1024; int32_t indices[kBufferSize]; - RETURN_NOT_OK(builder->Reserve(num_values)); - ::arrow::internal::BitmapReader bit_reader( - valid_bits, valid_bits_offset, num_values); + RETURN_NOT_OK(Builder->Reserve(numValues)); + ::arrow::internal::BitmapReader bitReader( + validBits, validBitsOffset, numValues); - auto dict_values = reinterpret_cast(dictionary_->data()); + auto dictValues = reinterpret_cast(dictionary_->data()); - int values_decoded = 0; - int num_appended = 0; - while (num_appended < num_values) { - bool is_valid = bit_reader.IsSet(); - bit_reader.Next(); + int valuesDecoded = 0; + int numAppended = 0; + while (numAppended < numValues) { + bool IsValid = bitReader.IsSet(); + bitReader.Next(); - if (is_valid) { - int32_t batch_size = std::min( - kBufferSize, num_values - num_appended - null_count); - int num_indices = idx_decoder_.GetBatch(indices, batch_size); + if (IsValid) { + int32_t batchSize = + std::min(kBufferSize, numValues - numAppended - nullCount); + int numIndices = idxDecoder_.GetBatch(indices, batchSize); int i = 0; while (true) { - // Consume all indices - if (is_valid) { + // Consume all indices. + if (IsValid) { auto idx = indices[i]; - RETURN_NOT_OK(IndexInBounds(idx)); - const auto& val = dict_values[idx]; - RETURN_NOT_OK(builder->Append(val.ptr, val.len)); + RETURN_NOT_OK(indexInBounds(idx)); + const auto& val = dictValues[idx]; + RETURN_NOT_OK(Builder->Append(val.ptr, val.len)); ++i; - ++values_decoded; + ++valuesDecoded; } else { - RETURN_NOT_OK(builder->AppendNull()); - --null_count; + RETURN_NOT_OK(Builder->AppendNull()); + --nullCount; } - ++num_appended; - if (i == num_indices) { - // Do not advance the bit_reader if we have fulfilled the decode - // request + ++numAppended; + if (i == numIndices) { + // Do not advance the bit_reader if we have fulfilled the decode. + // Request. break; } - is_valid = bit_reader.IsSet(); - bit_reader.Next(); + IsValid = bitReader.IsSet(); + bitReader.Next(); } } else { - RETURN_NOT_OK(builder->AppendNull()); - --null_count; - ++num_appended; + RETURN_NOT_OK(Builder->AppendNull()); + --nullCount; + ++numAppended; } } - *out_num_values = values_decoded; + *outNumValues = valuesDecoded; return Status::OK(); } template - Status DecodeArrowNonNull( - int num_values, - BuilderType* builder, - int* out_num_values) { + Status + decodeArrowNonNull(int numValues, BuilderType* Builder, int* outNumValues) { constexpr int32_t kBufferSize = 2048; int32_t indices[kBufferSize]; - RETURN_NOT_OK(builder->Reserve(num_values)); + RETURN_NOT_OK(Builder->Reserve(numValues)); - auto dict_values = reinterpret_cast(dictionary_->data()); + auto dictValues = reinterpret_cast(dictionary_->data()); - int values_decoded = 0; - while (values_decoded < num_values) { - int32_t batch_size = - std::min(kBufferSize, num_values - values_decoded); - int num_indices = idx_decoder_.GetBatch(indices, batch_size); - if (num_indices == 0) - ParquetException::EofException(); - for (int i = 0; i < num_indices; ++i) { + int valuesDecoded = 0; + while (valuesDecoded < numValues) { + int32_t batchSize = + std::min(kBufferSize, numValues - valuesDecoded); + int numIndices = idxDecoder_.GetBatch(indices, batchSize); + if (numIndices == 0) + ParquetException::eofException(); + for (int i = 0; i < numIndices; ++i) { auto idx = indices[i]; - RETURN_NOT_OK(IndexInBounds(idx)); - const auto& val = dict_values[idx]; - RETURN_NOT_OK(builder->Append(val.ptr, val.len)); + RETURN_NOT_OK(indexInBounds(idx)); + const auto& val = dictValues[idx]; + RETURN_NOT_OK(Builder->Append(val.ptr, val.len)); } - values_decoded += num_indices; + valuesDecoded += numIndices; } - *out_num_values = values_decoded; + *outNumValues = valuesDecoded; return Status::OK(); } }; -// ---------------------------------------------------------------------- -// DeltaBitPackEncoder +// ----------------------------------------------------------------------. +// DeltaBitPackEncoder. -/// DeltaBitPackEncoder is an encoder for the DeltaBinary Packing format -/// as per the parquet spec. See: +/// DeltaBitPackEncoder is an encoder for the DeltaBinary Packing format. +/// As per the parquet spec. See: /// https://github.com/apache/parquet-format/blob/master/Encodings.md#delta-encoding-delta_binary_packed--5 /// -/// Consists of a header followed by blocks of delta encoded values binary -/// packed. +/// Consists of a header followed by blocks of delta encoded values binary. +/// Packed. /// -/// Format -/// [header] [block 1] [block 2] ... [block N] +/// Format. +/// [Header] [block 1] [block 2] ... [block N]. /// -/// Header -/// [block size] [number of mini blocks per block] [total value count] [first -/// value] +/// Header. +/// [Block size] [number of mini blocks per block] [total value count] +/// [first. Value]. /// -/// Block -/// [min delta] [list of bitwidths of the mini blocks] [miniblocks] +/// Block. +/// [Min delta] [list of bitwidths of the mini blocks] [miniblocks]. /// -/// Sets aside bytes at the start of the internal buffer where the header will -/// be written, and only writes the header when FlushValues is called before -/// returning it. +/// Sets aside bytes at the start of the internal buffer where the header will. +/// Be written, and only writes the header when FlushValues is called before. +/// Returning it. /// /// To encode a block, we will: /// -/// 1. Compute the differences between consecutive elements. For the first -/// element in the block, use the last element in the previous block or, in the -/// case of the first block, use the first value of the whole sequence, stored -/// in the header. +/// 1. Compute the differences between consecutive elements. For the first. +/// Element in the block, use the last element in the previous block or, in the. +/// Case of the first block, use the first value of the whole sequence, stored. +/// In the header. /// /// 2. Compute the frame of reference (the minimum of the deltas in the block). -/// Subtract this min delta from all deltas in the block. This guarantees that -/// all values are non-negative. +/// Subtract this min delta from all deltas in the block. This guarantees that. +/// All values are non-negative. /// -/// 3. Encode the frame of reference (min delta) as a zigzag ULEB128 int -/// followed by the bit widths of the mini blocks and the delta values (minus -/// the min delta) bit packed per mini block. +/// 3. Encode the frame of reference (min delta) as a zigzag ULEB128 int. +/// Followed by the bit widths of the mini blocks and the delta values (minus. +/// The min delta) bit packed per mini block. /// /// Supports only INT32 and INT64. template class DeltaBitPackEncoder : public EncoderImpl, virtual public TypedEncoder { - // Maximum possible header size + // Maximum possible header size. static constexpr uint32_t kMaxPageHeaderWriterSize = 32; static constexpr uint32_t kValuesPerBlock = - std::is_same_v ? 128 : 256; + std::is_same_v ? 128 : 256; static constexpr uint32_t kMiniBlocksPerBlock = 4; public: - using T = typename DType::c_type; + using T = typename DType::CType; using UT = std::make_unsigned_t; - using TypedEncoder::Put; + using TypedEncoder::put; explicit DeltaBitPackEncoder( const ColumnDescriptor* descr, MemoryPool* pool, - const uint32_t values_per_block = kValuesPerBlock, - const uint32_t mini_blocks_per_block = kMiniBlocksPerBlock) - : EncoderImpl(descr, Encoding::DELTA_BINARY_PACKED, pool), - values_per_block_(values_per_block), - mini_blocks_per_block_(mini_blocks_per_block), - values_per_mini_block_(values_per_block / mini_blocks_per_block), - deltas_(values_per_block, ::arrow::stl::allocator(pool)), - bits_buffer_(AllocateBuffer( + const uint32_t valuesPerBlock = kValuesPerBlock, + const uint32_t miniBlocksPerBlock = kMiniBlocksPerBlock) + : EncoderImpl(descr, Encoding::kDeltaBinaryPacked, pool), + valuesPerBlock_(valuesPerBlock), + miniBlocksPerBlock_(miniBlocksPerBlock), + valuesPerMiniBlock_(valuesPerBlock / miniBlocksPerBlock), + deltas_(valuesPerBlock, ::arrow::stl::allocator(pool)), + bitsBuffer_(allocateBuffer( pool, - (kMiniBlocksPerBlock + values_per_block) * sizeof(T))), + (kMiniBlocksPerBlock + valuesPerBlock) * sizeof(T))), sink_(pool), - bit_writer_( - bits_buffer_->mutable_data(), - static_cast(bits_buffer_->size())) { - if (values_per_block_ % 128 != 0) { + bitWriter_( + bitsBuffer_->mutable_data(), + static_cast(bitsBuffer_->size())) { + if (valuesPerBlock_ % 128 != 0) { throw ParquetException( "the number of values in a block must be multiple of 128, but it's " + - std::to_string(values_per_block_)); + std::to_string(valuesPerBlock_)); } - if (values_per_mini_block_ % 32 != 0) { + if (valuesPerMiniBlock_ % 32 != 0) { throw ParquetException( "the number of values in a miniblock must be multiple of 32, but it's " + - std::to_string(values_per_mini_block_)); + std::to_string(valuesPerMiniBlock_)); } - if (values_per_block % mini_blocks_per_block != 0) { + if (valuesPerBlock % miniBlocksPerBlock != 0) { throw ParquetException( "the number of values per block % number of miniblocks per block must be 0, " "but it's " + - std::to_string(values_per_block % mini_blocks_per_block)); + std::to_string(valuesPerBlock % miniBlocksPerBlock)); } - // Reserve enough space at the beginning of the buffer for largest possible - // header. + // Reserve enough space at the beginning of the buffer for largest possible. + // Header. PARQUET_THROW_NOT_OK(sink_.Advance(kMaxPageHeaderWriterSize)); } - std::shared_ptr FlushValues() override; + std::shared_ptr flushValues() override; - int64_t EstimatedDataEncodedSize() override { + int64_t estimatedDataEncodedSize() override { return sink_.length(); } - void Put(const ::arrow::Array& values) override; + void put(const ::arrow::Array& values) override; - void Put(const T* buffer, int num_values) override; + void put(const T* buffer, int numValues) override; - void PutSpaced( + void putSpaced( const T* src, - int num_values, - const uint8_t* valid_bits, - int64_t valid_bits_offset) override; + int numValues, + const uint8_t* validBits, + int64_t validBitsOffset) override; - void FlushBlock(); + void flushBlock(); private: - const uint32_t values_per_block_; - const uint32_t mini_blocks_per_block_; - const uint32_t values_per_mini_block_; - uint32_t values_current_block_{0}; - uint32_t total_value_count_{0}; - UT first_value_{0}; - UT current_value_{0}; + const uint32_t valuesPerBlock_; + const uint32_t miniBlocksPerBlock_; + const uint32_t valuesPerMiniBlock_; + uint32_t valuesCurrentBlock_{0}; + uint32_t totalValueCount_{0}; + UT firstValue_{0}; + UT currentValue_{0}; ArrowPoolVector deltas_; - std::shared_ptr bits_buffer_; + std::shared_ptr bitsBuffer_; ::arrow::BufferBuilder sink_; - BitWriter bit_writer_; + BitWriter bitWriter_; }; template -void DeltaBitPackEncoder::Put(const T* src, int num_values) { - if (num_values == 0) { +void DeltaBitPackEncoder::put(const T* src, int numValues) { + if (numValues == 0) { return; } int idx = 0; - if (total_value_count_ == 0) { - current_value_ = src[0]; - first_value_ = current_value_; + if (totalValueCount_ == 0) { + currentValue_ = src[0]; + firstValue_ = currentValue_; idx = 1; } - total_value_count_ += num_values; + totalValueCount_ += numValues; - while (idx < num_values) { + while (idx < numValues) { UT value = static_cast(src[idx]); - // Calculate deltas. The possible overflow is handled by use of unsigned - // integers making subtraction operations well-defined and correct even in - // case of overflow. Encoded integers will wrap back around on decoding. See - // http://en.wikipedia.org/wiki/Modular_arithmetic#Integers_modulo_n - deltas_[values_current_block_] = value - current_value_; - current_value_ = value; + // Calculate deltas. The possible overflow is handled by use of unsigned. + // Integers making subtraction operations well-defined and correct even in. + // Case of overflow. Encoded integers will wrap back around on decoding. + // See. http://en.wikipedia.org/wiki/Modular_arithmetic#Integers_modulo_n + deltas_[valuesCurrentBlock_] = value - currentValue_; + currentValue_ = value; idx++; - values_current_block_++; - if (values_current_block_ == values_per_block_) { - FlushBlock(); + valuesCurrentBlock_++; + if (valuesCurrentBlock_ == valuesPerBlock_) { + flushBlock(); } } } template -void DeltaBitPackEncoder::FlushBlock() { - if (values_current_block_ == 0) { +void DeltaBitPackEncoder::flushBlock() { + if (valuesCurrentBlock_ == 0) { return; } - const UT min_delta = *std::min_element( - deltas_.begin(), deltas_.begin() + values_current_block_); - bit_writer_.PutZigZagVlqInt(static_cast(min_delta)); + const UT minDelta = + *std::min_element(deltas_.begin(), deltas_.begin() + valuesCurrentBlock_); + bitWriter_.PutZigZagVlqInt(static_cast(minDelta)); - // Call to GetNextBytePtr reserves mini_blocks_per_block_ bytes of space to - // write bit widths of miniblocks as they become known during the encoding. - uint8_t* bit_width_data = bit_writer_.GetNextBytePtr(mini_blocks_per_block_); - VELOX_DCHECK(bit_width_data != nullptr); + // Call to GetNextBytePtr reserves mini_blocks_per_block_ bytes of space to. + // Write bit widths of miniblocks as they become known during the encoding. + uint8_t* bitWidthData = bitWriter_.GetNextBytePtr(miniBlocksPerBlock_); + VELOX_DCHECK(bitWidthData != nullptr); - const uint32_t num_miniblocks = static_cast(std::ceil( - static_cast(values_current_block_) / - static_cast(values_per_mini_block_))); - for (uint32_t i = 0; i < num_miniblocks; i++) { - const uint32_t values_current_mini_block = - std::min(values_per_mini_block_, values_current_block_); + const uint32_t numMiniblocks = static_cast(std::ceil( + static_cast(valuesCurrentBlock_) / + static_cast(valuesPerMiniBlock_))); + for (uint32_t i = 0; i < numMiniblocks; i++) { + const uint32_t valuesCurrentMiniBlock = + std::min(valuesPerMiniBlock_, valuesCurrentBlock_); - const uint32_t start = i * values_per_mini_block_; - const UT max_delta = *std::max_element( + const uint32_t start = i * valuesPerMiniBlock_; + const UT maxDelta = *std::max_element( deltas_.begin() + start, - deltas_.begin() + start + values_current_mini_block); + deltas_.begin() + start + valuesCurrentMiniBlock); - // The minimum number of bits required to write any of values in deltas_ - // vector. See overflow comment above. - const auto bit_width = bit_width_data[i] = - ::arrow::bit_util::NumRequiredBits(max_delta - min_delta); + // The minimum number of bits required to write any of values in deltas_. + // Vector. See overflow comment above. + const auto bitWidth = bitWidthData[i] = + ::arrow::bit_util::NumRequiredBits(maxDelta - minDelta); - for (uint32_t j = start; j < start + values_current_mini_block; j++) { + for (uint32_t j = start; j < start + valuesCurrentMiniBlock; j++) { // See overflow comment above. - const UT value = deltas_[j] - min_delta; - bit_writer_.PutValue(value, bit_width); + const UT value = deltas_[j] - minDelta; + bitWriter_.PutValue(value, bitWidth); } - // If there are not enough values to fill the last mini block, we pad the - // mini block with zeroes so that its length is the number of values in a - // full mini block multiplied by the bit width. - for (uint32_t j = values_current_mini_block; j < values_per_mini_block_; - j++) { - bit_writer_.PutValue(0, bit_width); + // If there are not enough values to fill the last mini block, we pad the. + // Mini block with zeroes so that its length is the number of values in a. + // Full mini block multiplied by the bit width. + for (uint32_t j = valuesCurrentMiniBlock; j < valuesPerMiniBlock_; j++) { + bitWriter_.PutValue(0, bitWidth); } - values_current_block_ -= values_current_mini_block; + valuesCurrentBlock_ -= valuesCurrentMiniBlock; } - // If, in the last block, less than - // miniblocks are needed to store the values, the bytes storing the bit widths - // of the unneeded miniblocks are still present, their value should be zero, - // but readers must accept arbitrary values as well. - for (uint32_t i = num_miniblocks; i < mini_blocks_per_block_; i++) { - bit_width_data[i] = 0; + // If, in the last block, less than . + // Miniblocks are needed to store the values, the bytes storing the bit + // widths. Of the unneeded miniblocks are still present, their value should be + // zero,. But readers must accept arbitrary values as well. + for (uint32_t i = numMiniblocks; i < miniBlocksPerBlock_; i++) { + bitWidthData[i] = 0; } - VELOX_DCHECK_EQ(values_current_block_, 0); + VELOX_DCHECK_EQ(valuesCurrentBlock_, 0); - bit_writer_.Flush(); + bitWriter_.Flush(); PARQUET_THROW_NOT_OK( - sink_.Append(bit_writer_.buffer(), bit_writer_.bytesWritten())); - bit_writer_.Clear(); + sink_.Append(bitWriter_.buffer(), bitWriter_.bytesWritten())); + bitWriter_.Clear(); } template -std::shared_ptr DeltaBitPackEncoder::FlushValues() { - if (values_current_block_ > 0) { - FlushBlock(); - } - PARQUET_ASSIGN_OR_THROW(auto buffer, sink_.Finish(/*shrink_to_fit=*/true)); - - uint8_t header_buffer_[kMaxPageHeaderWriterSize] = {}; - BitWriter header_writer(header_buffer_, sizeof(header_buffer_)); - if (!header_writer.PutVlqInt(values_per_block_) || - !header_writer.PutVlqInt(mini_blocks_per_block_) || - !header_writer.PutVlqInt(total_value_count_) || - !header_writer.PutZigZagVlqInt(static_cast(first_value_))) { +std::shared_ptr DeltaBitPackEncoder::flushValues() { + if (valuesCurrentBlock_ > 0) { + flushBlock(); + } + PARQUET_ASSIGN_OR_THROW(auto buffer, sink_.Finish(true)); + + uint8_t headerBuffer_[kMaxPageHeaderWriterSize] = {}; + BitWriter headerWriter(headerBuffer_, sizeof(headerBuffer_)); + if (!headerWriter.PutVlqInt(valuesPerBlock_) || + !headerWriter.PutVlqInt(miniBlocksPerBlock_) || + !headerWriter.PutVlqInt(totalValueCount_) || + !headerWriter.PutZigZagVlqInt(static_cast(firstValue_))) { throw ParquetException("header writing error"); } - header_writer.Flush(); + headerWriter.Flush(); - // We reserved enough space at the beginning of the buffer for largest - // possible header and data was written immediately after. We now write the - // header data immediately before the end of reserved space. - const size_t offset_bytes = - kMaxPageHeaderWriterSize - header_writer.bytesWritten(); + // We reserved enough space at the beginning of the buffer for largest. + // Possible header and data was written immediately after. We now write the. + // Header data immediately before the end of reserved space. + const size_t offsetBytes = + kMaxPageHeaderWriterSize - headerWriter.bytesWritten(); std::memcpy( - buffer->mutable_data() + offset_bytes, - header_buffer_, - header_writer.bytesWritten()); - - // Reset counter of cached values - total_value_count_ = 0; - // Reserve enough space at the beginning of the buffer for largest possible - // header. + buffer->mutable_data() + offsetBytes, + headerBuffer_, + headerWriter.bytesWritten()); + + // Reset counter of cached values. + totalValueCount_ = 0; + // Reserve enough space at the beginning of the buffer for largest possible. + // Header. PARQUET_THROW_NOT_OK(sink_.Advance(kMaxPageHeaderWriterSize)); // Excess bytes at the beginning are sliced off and ignored. - return SliceBuffer(buffer, offset_bytes); + return ::arrow::SliceBuffer(buffer, offsetBytes); } template <> -void DeltaBitPackEncoder::Put(const ::arrow::Array& values) { +void DeltaBitPackEncoder::put(const ::arrow::Array& values) { const ::arrow::ArrayData& data = *values.data(); if (values.type_id() != ::arrow::Type::INT32) { throw ParquetException( @@ -2669,9 +2634,9 @@ void DeltaBitPackEncoder::Put(const ::arrow::Array& values) { } if (values.null_count() == 0) { - Put(data.GetValues(1), static_cast(data.length)); + put(data.GetValues(1), static_cast(data.length)); } else { - PutSpaced( + putSpaced( data.GetValues(1), static_cast(data.length), data.GetValues(0, 0), @@ -2680,7 +2645,7 @@ void DeltaBitPackEncoder::Put(const ::arrow::Array& values) { } template <> -void DeltaBitPackEncoder::Put(const ::arrow::Array& values) { +void DeltaBitPackEncoder::put(const ::arrow::Array& values) { const ::arrow::ArrayData& data = *values.data(); if (values.type_id() != ::arrow::Type::INT64) { throw ParquetException( @@ -2691,9 +2656,9 @@ void DeltaBitPackEncoder::Put(const ::arrow::Array& values) { "Array cannot be longer than ", std::numeric_limits::max()); } if (values.null_count() == 0) { - Put(data.GetValues(1), static_cast(data.length)); + put(data.GetValues(1), static_cast(data.length)); } else { - PutSpaced( + putSpaced( data.GetValues(1), static_cast(data.length), data.GetValues(0, 0), @@ -2702,275 +2667,272 @@ void DeltaBitPackEncoder::Put(const ::arrow::Array& values) { } template -void DeltaBitPackEncoder::PutSpaced( +void DeltaBitPackEncoder::putSpaced( const T* src, - int num_values, - const uint8_t* valid_bits, - int64_t valid_bits_offset) { - if (valid_bits != NULLPTR) { - PARQUET_ASSIGN_OR_THROW( - auto buffer, - ::arrow::AllocateBuffer(num_values * sizeof(T), this->memory_pool())); + int numValues, + const uint8_t* validBits, + int64_t validBitsOffset) { + if (validBits != NULLPTR) { + auto buffer = allocateBuffer(this->memoryPool(), numValues * sizeof(T)); T* data = reinterpret_cast(buffer->mutable_data()); - int num_valid_values = ::arrow::util::internal::SpacedCompress( - src, num_values, valid_bits, valid_bits_offset, data); - Put(data, num_valid_values); + int numValidValues = ::arrow::util::internal::SpacedCompress( + src, numValues, validBits, validBitsOffset, data); + put(data, numValidValues); } else { - Put(src, num_values); + put(src, numValues); } } -// ---------------------------------------------------------------------- -// DeltaBitPackDecoder +// ----------------------------------------------------------------------. +// DeltaBitPackDecoder. template class DeltaBitPackDecoder : public DecoderImpl, virtual public TypedDecoder { public: - typedef typename DType::c_type T; + typedef typename DType::CType T; using UT = std::make_unsigned_t; explicit DeltaBitPackDecoder( const ColumnDescriptor* descr, MemoryPool* pool = ::arrow::default_memory_pool()) - : DecoderImpl(descr, Encoding::DELTA_BINARY_PACKED), pool_(pool) { - if (DType::type_num != Type::INT32 && DType::type_num != Type::INT64) { + : DecoderImpl(descr, Encoding::kDeltaBinaryPacked), pool_(pool) { + if (DType::typeNum != Type::kInt32 && DType::typeNum != Type::kInt64) { throw ParquetException( "Delta bit pack encoding should only be for integer data."); } } - void SetData(int num_values, const uint8_t* data, int len) override { - // num_values is equal to page's num_values, including null values in this - // page - this->num_values_ = num_values; + void setData(int numValues, const uint8_t* data, int len) override { + // Num_values is equal to page's num_values, including null values in this. + // Page. + this->numValues_ = numValues; decoder_ = std::make_shared(data, len); - InitHeader(); + initHeader(); } - // Set BitReader which is already initialized by DeltaLengthByteArrayDecoder - // or DeltaByteArrayDecoder - void SetDecoder(int num_values, std::shared_ptr decoder) { - this->num_values_ = num_values; + // Set BitReader which is already initialized by DeltaLengthByteArrayDecoder. + // Or DeltaByteArrayDecoder. + void setDecoder(int numValues, std::shared_ptr decoder) { + this->numValues_ = numValues; decoder_ = std::move(decoder); - InitHeader(); + initHeader(); } - int ValidValuesCount() { - // total_values_remaining_ in header ignores of null values - return static_cast(total_values_remaining_); + int validValuesCount() { + // Total_values_remaining_ in header ignores of null values. + return static_cast(totalValuesRemaining_); } - int Decode(T* buffer, int max_values) override { - return GetInternal(buffer, max_values); + int decode(T* buffer, int maxValues) override { + return getInternal(buffer, maxValues); } - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::Accumulator* out) override { - if (null_count != 0) { + if (nullCount != 0) { // TODO(ARROW-34660): implement DecodeArrow with null slots. ParquetException::NYI("Delta bit pack DecodeArrow with null slots"); } - std::vector values(num_values); - int decoded_count = GetInternal(values.data(), num_values); - PARQUET_THROW_NOT_OK(out->AppendValues(values.data(), decoded_count)); - return decoded_count; + std::vector values(numValues); + int decodedCount = getInternal(values.data(), numValues); + PARQUET_THROW_NOT_OK(out->AppendValues(values.data(), decodedCount)); + return decodedCount; } - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::DictAccumulator* out) override { - if (null_count != 0) { + if (nullCount != 0) { // TODO(ARROW-34660): implement DecodeArrow with null slots. ParquetException::NYI("Delta bit pack DecodeArrow with null slots"); } - std::vector values(num_values); - int decoded_count = GetInternal(values.data(), num_values); - PARQUET_THROW_NOT_OK(out->Reserve(decoded_count)); - for (int i = 0; i < decoded_count; ++i) { + std::vector values(numValues); + int decodedCount = getInternal(values.data(), numValues); + PARQUET_THROW_NOT_OK(out->Reserve(decodedCount)); + for (int i = 0; i < decodedCount; ++i) { PARQUET_THROW_NOT_OK(out->Append(values[i])); } - return decoded_count; + return decodedCount; } private: static constexpr int kMaxDeltaBitWidth = static_cast(sizeof(T) * 8); - void InitHeader() { - if (!decoder_->GetVlqInt(&values_per_block_) || - !decoder_->GetVlqInt(&mini_blocks_per_block_) || - !decoder_->GetVlqInt(&total_value_count_) || - !decoder_->GetZigZagVlqInt(&last_value_)) { - ParquetException::EofException("InitHeader EOF"); + void initHeader() { + if (!decoder_->GetVlqInt(&valuesPerBlock_) || + !decoder_->GetVlqInt(&miniBlocksPerBlock_) || + !decoder_->GetVlqInt(&totalValueCount_) || + !decoder_->GetZigZagVlqInt(&lastValue_)) { + ParquetException::eofException("InitHeader EOF"); } - if (values_per_block_ == 0) { + if (valuesPerBlock_ == 0) { throw ParquetException("cannot have zero value per block"); } - if (values_per_block_ % 128 != 0) { + if (valuesPerBlock_ % 128 != 0) { throw ParquetException( "the number of values in a block must be multiple of 128, but it's " + - std::to_string(values_per_block_)); + std::to_string(valuesPerBlock_)); } - if (mini_blocks_per_block_ == 0) { + if (miniBlocksPerBlock_ == 0) { throw ParquetException("cannot have zero miniblock per block"); } - values_per_mini_block_ = values_per_block_ / mini_blocks_per_block_; - if (values_per_mini_block_ == 0) { + valuesPerMiniBlock_ = valuesPerBlock_ / miniBlocksPerBlock_; + if (valuesPerMiniBlock_ == 0) { throw ParquetException("cannot have zero value per miniblock"); } - if (values_per_mini_block_ % 32 != 0) { + if (valuesPerMiniBlock_ % 32 != 0) { throw ParquetException( "the number of values in a miniblock must be multiple of 32, but it's " + - std::to_string(values_per_mini_block_)); + std::to_string(valuesPerMiniBlock_)); } - total_values_remaining_ = total_value_count_; - if (delta_bit_widths_ == nullptr) { - delta_bit_widths_ = AllocateBuffer(pool_, mini_blocks_per_block_); + totalValuesRemaining_ = totalValueCount_; + if (deltaBitWidths_ == nullptr) { + deltaBitWidths_ = allocateBuffer(pool_, miniBlocksPerBlock_); } else { - PARQUET_THROW_NOT_OK(delta_bit_widths_->Resize( - mini_blocks_per_block_, /*shrink_to_fit*/ false)); + PARQUET_THROW_NOT_OK(deltaBitWidths_->Resize( + miniBlocksPerBlock_, /*shrink_to_fit*/ false)); } - first_block_initialized_ = false; - values_remaining_current_mini_block_ = 0; + firstBlockInitialized_ = false; + valuesRemainingCurrentMiniBlock_ = 0; } - void InitBlock() { - VELOX_DCHECK_GT(total_values_remaining_, 0, "InitBlock called at EOF"); + void initBlock() { + VELOX_DCHECK_GT(totalValuesRemaining_, 0, "InitBlock called at EOF"); - if (!decoder_->GetZigZagVlqInt(&min_delta_)) - ParquetException::EofException("InitBlock EOF"); + if (!decoder_->GetZigZagVlqInt(&minDelta_)) + ParquetException::eofException("InitBlock EOF"); - // read the bitwidth of each miniblock - uint8_t* bit_width_data = delta_bit_widths_->mutable_data(); - for (uint32_t i = 0; i < mini_blocks_per_block_; ++i) { - if (!decoder_->GetAligned(1, bit_width_data + i)) { - ParquetException::EofException("Decode bit-width EOF"); + // Read the bitwidth of each miniblock. + uint8_t* bitWidthData = deltaBitWidths_->mutable_data(); + for (uint32_t i = 0; i < miniBlocksPerBlock_; ++i) { + if (!decoder_->GetAligned(1, bitWidthData + i)) { + ParquetException::eofException("Decode bit-width EOF"); } - // Note that non-conformant bitwidth entries are allowed by the Parquet - // spec for extraneous miniblocks in the last block (GH-14923), so we - // check the bitwidths when actually using them (see InitMiniBlock()). + // Note that non-conformant bitwidth entries are allowed by the Parquet. + // Spec for extraneous miniblocks in the last block (GH-14923), so we. + // Check the bitwidths when actually using them (see InitMiniBlock()). } - mini_block_idx_ = 0; - first_block_initialized_ = true; - InitMiniBlock(bit_width_data[0]); + miniBlockIdx_ = 0; + firstBlockInitialized_ = true; + initMiniBlock(bitWidthData[0]); } - void InitMiniBlock(int bit_width) { - if (ARROW_PREDICT_FALSE(bit_width > kMaxDeltaBitWidth)) { + void initMiniBlock(int bitWidth) { + if (ARROW_PREDICT_FALSE(bitWidth > kMaxDeltaBitWidth)) { throw ParquetException("delta bit width larger than integer bit width"); } - delta_bit_width_ = bit_width; - values_remaining_current_mini_block_ = values_per_mini_block_; + deltaBitWidth_ = bitWidth; + valuesRemainingCurrentMiniBlock_ = valuesPerMiniBlock_; } - int GetInternal(T* buffer, int max_values) { - max_values = static_cast( - std::min(max_values, total_values_remaining_)); - if (max_values == 0) { + int getInternal(T* buffer, int maxValues) { + maxValues = + static_cast(std::min(maxValues, totalValuesRemaining_)); + if (maxValues == 0) { return 0; } int i = 0; - while (i < max_values) { - if (ARROW_PREDICT_FALSE(values_remaining_current_mini_block_ == 0)) { - if (ARROW_PREDICT_FALSE(!first_block_initialized_)) { - buffer[i++] = last_value_; + while (i < maxValues) { + if (ARROW_PREDICT_FALSE(valuesRemainingCurrentMiniBlock_ == 0)) { + if (ARROW_PREDICT_FALSE(!firstBlockInitialized_)) { + buffer[i++] = lastValue_; VELOX_DCHECK_EQ(i, 1); // we're at the beginning of the page - if (ARROW_PREDICT_FALSE(i == max_values)) { - // When block is uninitialized and i reaches max_values we have two - // different possibilities: - // 1. total_value_count_ == 1, which means that the page may have - // only one value (encoded in the header), and we should not - // initialize any block. - // 2. total_value_count_ != 1, which means we should initialize the - // incoming block for subsequent reads. - if (total_value_count_ != 1) { - InitBlock(); + if (ARROW_PREDICT_FALSE(i == maxValues)) { + // When block is uninitialized and i reaches max_values we have two. + // Different possibilities: + // 1. Total_value_count_ == 1, which means that the page may have. + // Only one value (encoded in the header), and we should not. + // Initialize any block. + // 2. Total_value_count_ != 1, which means we should initialize the. + // Incoming block for subsequent reads. + if (totalValueCount_ != 1) { + initBlock(); } break; } - InitBlock(); + initBlock(); } else { - ++mini_block_idx_; - if (mini_block_idx_ < mini_blocks_per_block_) { - InitMiniBlock(delta_bit_widths_->data()[mini_block_idx_]); + ++miniBlockIdx_; + if (miniBlockIdx_ < miniBlocksPerBlock_) { + initMiniBlock(deltaBitWidths_->data()[miniBlockIdx_]); } else { - InitBlock(); + initBlock(); } } } - int values_decode = std::min( - values_remaining_current_mini_block_, - static_cast(max_values - i)); - if (decoder_->GetBatch(delta_bit_width_, buffer + i, values_decode) != - values_decode) { - ParquetException::EofException(); + int valuesDecode = std::min( + valuesRemainingCurrentMiniBlock_, + static_cast(maxValues - i)); + if (decoder_->GetBatch(deltaBitWidth_, buffer + i, valuesDecode) != + valuesDecode) { + ParquetException::eofException(); } - for (int j = 0; j < values_decode; ++j) { - // Addition between min_delta, packed int and last_value should be - // treated as unsigned addition. Overflow is as expected. - buffer[i + j] = static_cast(min_delta_) + - static_cast(buffer[i + j]) + static_cast(last_value_); - last_value_ = buffer[i + j]; + for (int j = 0; j < valuesDecode; ++j) { + // Addition between min_delta, packed int and last_value should be. + // Treated as unsigned addition. Overflow is as expected. + buffer[i + j] = static_cast(minDelta_) + + static_cast(buffer[i + j]) + static_cast(lastValue_); + lastValue_ = buffer[i + j]; } - values_remaining_current_mini_block_ -= values_decode; - i += values_decode; - } - total_values_remaining_ -= max_values; - this->num_values_ -= max_values; - - if (ARROW_PREDICT_FALSE(total_values_remaining_ == 0)) { - uint32_t padding_bits = - values_remaining_current_mini_block_ * delta_bit_width_; - // skip the padding bits - if (!decoder_->Advance(padding_bits)) { - ParquetException::EofException(); + valuesRemainingCurrentMiniBlock_ -= valuesDecode; + i += valuesDecode; + } + totalValuesRemaining_ -= maxValues; + this->numValues_ -= maxValues; + + if (ARROW_PREDICT_FALSE(totalValuesRemaining_ == 0)) { + uint32_t paddingBits = valuesRemainingCurrentMiniBlock_ * deltaBitWidth_; + // Skip the padding bits. + if (!decoder_->Advance(paddingBits)) { + ParquetException::eofException(); } - values_remaining_current_mini_block_ = 0; + valuesRemainingCurrentMiniBlock_ = 0; } - return max_values; + return maxValues; } MemoryPool* pool_; std::shared_ptr decoder_; - uint32_t values_per_block_; - uint32_t mini_blocks_per_block_; - uint32_t values_per_mini_block_; - uint32_t total_value_count_; - - uint32_t total_values_remaining_; - // Remaining values in current mini block. If the current block is the last - // mini block, values_remaining_current_mini_block_ may greater than - // total_values_remaining_. - uint32_t values_remaining_current_mini_block_; - - // If the page doesn't contain any block, `first_block_initialized_` will - // always be false. Otherwise, it will be true when first block initialized. - bool first_block_initialized_; - T min_delta_; - uint32_t mini_block_idx_; - std::shared_ptr delta_bit_widths_; - int delta_bit_width_; - - T last_value_; + uint32_t valuesPerBlock_; + uint32_t miniBlocksPerBlock_; + uint32_t valuesPerMiniBlock_; + uint32_t totalValueCount_; + + uint32_t totalValuesRemaining_; + // Remaining values in current mini block. If the current block is the last. + // Mini block, values_remaining_current_mini_block_ may greater than. + // Total_values_remaining_. + uint32_t valuesRemainingCurrentMiniBlock_; + + // If the page doesn't contain any block, `first_block_initialized_` will. + // Always be false. Otherwise, it will be true when first block initialized. + bool firstBlockInitialized_; + T minDelta_; + uint32_t miniBlockIdx_; + std::shared_ptr deltaBitWidths_; + int deltaBitWidth_; + + T lastValue_; }; -// ---------------------------------------------------------------------- -// DELTA_LENGTH_BYTE_ARRAY +// ----------------------------------------------------------------------. +// DELTA_LENGTH_BYTE_ARRAY. -// ---------------------------------------------------------------------- -// DeltaLengthByteArrayEncoder +// ----------------------------------------------------------------------. +// DeltaLengthByteArrayEncoder. template class DeltaLengthByteArrayEncoder : public EncoderImpl, @@ -2981,33 +2943,33 @@ class DeltaLengthByteArrayEncoder : public EncoderImpl, MemoryPool* pool) : EncoderImpl( descr, - Encoding::DELTA_LENGTH_BYTE_ARRAY, + Encoding::kDeltaLengthByteArray, pool = ::arrow::default_memory_pool()), sink_(pool), - length_encoder_(nullptr, pool), - encoded_size_{0} {} + lengthEncoder_(nullptr, pool), + encodedSize_{0} {} - std::shared_ptr FlushValues() override; + std::shared_ptr flushValues() override; - int64_t EstimatedDataEncodedSize() override { - return encoded_size_ + length_encoder_.EstimatedDataEncodedSize(); + int64_t estimatedDataEncodedSize() override { + return encodedSize_ + lengthEncoder_.estimatedDataEncodedSize(); } - using TypedEncoder::Put; + using TypedEncoder::put; - void Put(const ::arrow::Array& values) override; + void put(const ::arrow::Array& values) override; - void Put(const T* buffer, int num_values) override; + void put(const T* buffer, int numValues) override; - void PutSpaced( + void putSpaced( const T* src, - int num_values, - const uint8_t* valid_bits, - int64_t valid_bits_offset) override; + int numValues, + const uint8_t* validBits, + int64_t validBitsOffset) override; protected: template - void PutBinaryArray(const ArrayType& array) { + void putBinaryArray(const ArrayType& array) { PARQUET_THROW_NOT_OK( ::arrow::VisitArraySpanInline( *array.data(), @@ -3016,7 +2978,7 @@ class DeltaLengthByteArrayEncoder : public EncoderImpl, return Status::Invalid( "Parquet cannot store strings with size 2GB or more"); } - length_encoder_.Put({static_cast(view.length())}, 1); + lengthEncoder_.put({static_cast(view.length())}, 1); PARQUET_THROW_NOT_OK(sink_.Append(view.data(), view.length())); return Status::OK(); }, @@ -3024,90 +2986,88 @@ class DeltaLengthByteArrayEncoder : public EncoderImpl, } ::arrow::BufferBuilder sink_; - DeltaBitPackEncoder length_encoder_; - uint32_t encoded_size_; + DeltaBitPackEncoder lengthEncoder_; + uint32_t encodedSize_; }; template -void DeltaLengthByteArrayEncoder::Put(const ::arrow::Array& values) { - AssertBaseBinary(values); +void DeltaLengthByteArrayEncoder::put(const ::arrow::Array& values) { + assertBaseBinary(values); if (::arrow::is_binary_like(values.type_id())) { - PutBinaryArray(checked_cast(values)); + putBinaryArray(checked_cast(values)); } else { - PutBinaryArray(checked_cast(values)); + putBinaryArray(checked_cast(values)); } } template -void DeltaLengthByteArrayEncoder::Put(const T* src, int num_values) { - if (num_values == 0) { +void DeltaLengthByteArrayEncoder::put(const T* src, int numValues) { + if (numValues == 0) { return; } constexpr int kBatchSize = 256; std::array lengths; - uint32_t total_increment_size = 0; - for (int idx = 0; idx < num_values; idx += kBatchSize) { - const int batch_size = std::min(kBatchSize, num_values - idx); - for (int j = 0; j < batch_size; ++j) { + uint32_t totalIncrementSize = 0; + for (int idx = 0; idx < numValues; idx += kBatchSize) { + const int batchSize = std::min(kBatchSize, numValues - idx); + for (int j = 0; j < batchSize; ++j) { const int32_t len = src[idx + j].len; - if (AddWithOverflow(total_increment_size, len, &total_increment_size)) { + if (AddWithOverflow(totalIncrementSize, len, &totalIncrementSize)) { throw ParquetException("excess expansion in DELTA_LENGTH_BYTE_ARRAY"); } lengths[j] = len; } - length_encoder_.Put(lengths.data(), batch_size); + lengthEncoder_.put(lengths.data(), batchSize); } - if (AddWithOverflow(encoded_size_, total_increment_size, &encoded_size_)) { + if (AddWithOverflow(encodedSize_, totalIncrementSize, &encodedSize_)) { throw ParquetException("excess expansion in DELTA_LENGTH_BYTE_ARRAY"); } - PARQUET_THROW_NOT_OK(sink_.Reserve(total_increment_size)); - for (int idx = 0; idx < num_values; idx++) { + PARQUET_THROW_NOT_OK(sink_.Reserve(totalIncrementSize)); + for (int idx = 0; idx < numValues; idx++) { sink_.UnsafeAppend(src[idx].ptr, src[idx].len); } } template -void DeltaLengthByteArrayEncoder::PutSpaced( +void DeltaLengthByteArrayEncoder::putSpaced( const T* src, - int num_values, - const uint8_t* valid_bits, - int64_t valid_bits_offset) { - if (valid_bits != NULLPTR) { - PARQUET_ASSIGN_OR_THROW( - auto buffer, - ::arrow::AllocateBuffer(num_values * sizeof(T), this->memory_pool())); + int numValues, + const uint8_t* validBits, + int64_t validBitsOffset) { + if (validBits != NULLPTR) { + auto buffer = allocateBuffer(this->memoryPool(), numValues * sizeof(T)); T* data = reinterpret_cast(buffer->mutable_data()); - int num_valid_values = ::arrow::util::internal::SpacedCompress( - src, num_values, valid_bits, valid_bits_offset, data); - Put(data, num_valid_values); + int numValidValues = ::arrow::util::internal::SpacedCompress( + src, numValues, validBits, validBitsOffset, data); + put(data, numValidValues); } else { - Put(src, num_values); + put(src, numValues); } } template -std::shared_ptr DeltaLengthByteArrayEncoder::FlushValues() { - std::shared_ptr encoded_lengths = length_encoder_.FlushValues(); +std::shared_ptr DeltaLengthByteArrayEncoder::flushValues() { + std::shared_ptr encodedLengths = lengthEncoder_.flushValues(); std::shared_ptr data; PARQUET_THROW_NOT_OK(sink_.Finish(&data)); sink_.Reset(); - PARQUET_THROW_NOT_OK(sink_.Resize(encoded_lengths->size() + data->size())); + PARQUET_THROW_NOT_OK(sink_.Resize(encodedLengths->size() + data->size())); PARQUET_THROW_NOT_OK( - sink_.Append(encoded_lengths->data(), encoded_lengths->size())); + sink_.Append(encodedLengths->data(), encodedLengths->size())); PARQUET_THROW_NOT_OK(sink_.Append(data->data(), data->size())); std::shared_ptr buffer; PARQUET_THROW_NOT_OK(sink_.Finish(&buffer, true)); - encoded_size_ = 0; + encodedSize_ = 0; return buffer; } -// ---------------------------------------------------------------------- -// DeltaLengthByteArrayDecoder +// ----------------------------------------------------------------------. +// DeltaLengthByteArrayDecoder. class DeltaLengthByteArrayDecoder : public DecoderImpl, virtual public TypedDecoder { @@ -3115,155 +3075,152 @@ class DeltaLengthByteArrayDecoder : public DecoderImpl, explicit DeltaLengthByteArrayDecoder( const ColumnDescriptor* descr, MemoryPool* pool = ::arrow::default_memory_pool()) - : DecoderImpl(descr, Encoding::DELTA_LENGTH_BYTE_ARRAY), - len_decoder_(nullptr, pool), - buffered_length_(AllocateBuffer(pool, 0)) {} + : DecoderImpl(descr, Encoding::kDeltaLengthByteArray), + lenDecoder_(nullptr, pool), + bufferedLength_(allocateBuffer(pool, 0)) {} - void SetData(int num_values, const uint8_t* data, int len) override { - DecoderImpl::SetData(num_values, data, len); + void setData(int numValues, const uint8_t* data, int len) override { + DecoderImpl::setData(numValues, data, len); decoder_ = std::make_shared(data, len); - DecodeLengths(); + decodeLengths(); } - int Decode(ByteArray* buffer, int max_values) override { - // Decode up to `max_values` strings into an internal buffer - // and reference them into `buffer`. - max_values = std::min(max_values, num_valid_values_); - VELOX_DCHECK_GE(max_values, 0); - if (max_values == 0) { + int decode(ByteArray* buffer, int maxValues) override { + // Decode up to `max_values` strings into an internal buffer. + // And reference them into `buffer`. + maxValues = std::min(maxValues, numValidValues_); + VELOX_DCHECK_GE(maxValues, 0); + if (maxValues == 0) { return 0; } - int32_t data_size = 0; - const int32_t* length_ptr = - reinterpret_cast(buffered_length_->data()) + - length_idx_; - int bytes_offset = len_ - decoder_->bytesLeft(); - for (int i = 0; i < max_values; ++i) { - int32_t len = length_ptr[i]; + int32_t dataSize = 0; + const int32_t* lengthPtr = + reinterpret_cast(bufferedLength_->data()) + lengthIdx_; + int bytesOffset = len_ - decoder_->bytesLeft(); + for (int i = 0; i < maxValues; ++i) { + int32_t len = lengthPtr[i]; if (ARROW_PREDICT_FALSE(len < 0)) { throw ParquetException("negative string delta length"); } buffer[i].len = len; - if (AddWithOverflow(data_size, len, &data_size)) { + if (AddWithOverflow(dataSize, len, &dataSize)) { throw ParquetException("excess expansion in DELTA_(LENGTH_)BYTE_ARRAY"); } } - length_idx_ += max_values; + lengthIdx_ += maxValues; if (ARROW_PREDICT_FALSE( - !decoder_->Advance(8 * static_cast(data_size)))) { - ParquetException::EofException(); + !decoder_->Advance(8 * static_cast(dataSize)))) { + ParquetException::eofException(); } - const uint8_t* data_ptr = data_ + bytes_offset; - for (int i = 0; i < max_values; ++i) { - buffer[i].ptr = data_ptr; - data_ptr += buffer[i].len; + const uint8_t* dataPtr = data_ + bytesOffset; + for (int i = 0; i < maxValues; ++i) { + buffer[i].ptr = dataPtr; + dataPtr += buffer[i].len; } - this->num_values_ -= max_values; - num_valid_values_ -= max_values; - return max_values; + this->numValues_ -= maxValues; + numValidValues_ -= maxValues; + return maxValues; } - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::Accumulator* out) override { int result = 0; - PARQUET_THROW_NOT_OK(DecodeArrowDense( - num_values, null_count, valid_bits, valid_bits_offset, out, &result)); + PARQUET_THROW_NOT_OK(decodeArrowDense( + numValues, nullCount, validBits, validBitsOffset, out, &result)); return result; } - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::DictAccumulator* out) override { ParquetException::NYI( "DecodeArrow of DictAccumulator for DeltaLengthByteArrayDecoder"); } private: - // Decode all the encoded lengths. The decoder_ will be at the start of the - // encoded data after that. - void DecodeLengths() { - len_decoder_.SetDecoder(num_values_, decoder_); - - // get the number of encoded lengths - int num_length = len_decoder_.ValidValuesCount(); - PARQUET_THROW_NOT_OK( - buffered_length_->Resize(num_length * sizeof(int32_t))); - - // call len_decoder_.Decode to decode all the lengths. - // all the lengths are buffered in buffered_length_. - VELOX_DEBUG_ONLY int ret = len_decoder_.Decode( - reinterpret_cast(buffered_length_->mutable_data()), - num_length); - VELOX_DCHECK_EQ(ret, num_length); - length_idx_ = 0; - num_valid_values_ = num_length; - } - - Status DecodeArrowDense( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + // Decode all the encoded lengths. The decoder_ will be at the start of the. + // Encoded data after that. + void decodeLengths() { + lenDecoder_.setDecoder(numValues_, decoder_); + + // Get the number of encoded lengths. + int numLength = lenDecoder_.validValuesCount(); + PARQUET_THROW_NOT_OK(bufferedLength_->Resize(numLength * sizeof(int32_t))); + + // Call len_decoder_.Decode to decode all the lengths. + // All the lengths are buffered in buffered_length_. + VELOX_DEBUG_ONLY int ret = lenDecoder_.decode( + reinterpret_cast(bufferedLength_->mutable_data()), numLength); + VELOX_DCHECK_EQ(ret, numLength); + lengthIdx_ = 0; + numValidValues_ = numLength; + } + + Status decodeArrowDense( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::Accumulator* out, - int* out_num_values) { + int* outNumValues) { ArrowBinaryHelper helper(out); - std::vector values(num_values - null_count); - const int num_valid_values = Decode(values.data(), num_values - null_count); - if (ARROW_PREDICT_FALSE(num_values - null_count != num_valid_values)) { + std::vector values(numValues - nullCount); + const int numValidValues = decode(values.data(), numValues - nullCount); + if (ARROW_PREDICT_FALSE(numValues - nullCount != numValidValues)) { throw ParquetException( "Expected to decode ", - num_values - null_count, + numValues - nullCount, " values, but decoded ", - num_valid_values, + numValidValues, " values."); } - auto values_ptr = values.data(); - int value_idx = 0; + auto valuesPtr = values.data(); + int valueIdx = 0; RETURN_NOT_OK(VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { - const auto& val = values_ptr[value_idx]; - if (ARROW_PREDICT_FALSE(!helper.CanFit(val.len))) { - RETURN_NOT_OK(helper.PushChunk()); + const auto& val = valuesPtr[valueIdx]; + if (ARROW_PREDICT_FALSE(!helper.canFit(val.len))) { + RETURN_NOT_OK(helper.pushChunk()); } - RETURN_NOT_OK(helper.Append(val.ptr, static_cast(val.len))); - ++value_idx; + RETURN_NOT_OK(helper.append(val.ptr, static_cast(val.len))); + ++valueIdx; return Status::OK(); }, [&]() { - RETURN_NOT_OK(helper.AppendNull()); - --null_count; + RETURN_NOT_OK(helper.appendNull()); + --nullCount; return Status::OK(); })); - VELOX_DCHECK_EQ(null_count, 0); - *out_num_values = num_valid_values; + VELOX_DCHECK_EQ(nullCount, 0); + *outNumValues = numValidValues; return Status::OK(); } std::shared_ptr decoder_; - DeltaBitPackDecoder len_decoder_; - int num_valid_values_; - uint32_t length_idx_; - std::shared_ptr buffered_length_; + DeltaBitPackDecoder lenDecoder_; + int numValidValues_; + uint32_t lengthIdx_; + std::shared_ptr bufferedLength_; }; -// ---------------------------------------------------------------------- -// RLE_BOOLEAN_ENCODER +// ----------------------------------------------------------------------. +// RLE_BOOLEAN_ENCODER. class RleBooleanEncoder final : public EncoderImpl, virtual public BooleanEncoder { @@ -3271,199 +3228,195 @@ class RleBooleanEncoder final : public EncoderImpl, explicit RleBooleanEncoder( const ColumnDescriptor* descr, ::arrow::MemoryPool* pool) - : EncoderImpl(descr, Encoding::RLE, pool), - buffered_append_values_(::arrow::stl::allocator(pool)) {} + : EncoderImpl(descr, Encoding::kRle, pool), + bufferedAppendValues_(::arrow::stl::allocator(pool)) {} - int64_t EstimatedDataEncodedSize() override { - return kRleLengthInBytes + MaxRleBufferSize(); + int64_t estimatedDataEncodedSize() override { + return kRleLengthInBytes + maxRleBufferSize(); } - std::shared_ptr FlushValues() override; + std::shared_ptr flushValues() override; - void Put(const T* buffer, int num_values) override; - void Put(const ::arrow::Array& values) override { + void put(const T* buffer, int numValues) override; + void put(const ::arrow::Array& values) override { if (values.type_id() != ::arrow::Type::BOOL) { throw ParquetException( "RleBooleanEncoder expects BooleanArray, got ", values.type()->ToString()); } - const auto& boolean_array = + const auto& booleanArray = checked_cast(values); if (values.null_count() == 0) { - for (int i = 0; i < boolean_array.length(); ++i) { - // null_count == 0, so just call Value directly is ok. - buffered_append_values_.push_back(boolean_array.Value(i)); + for (int i = 0; i < booleanArray.length(); ++i) { + // Null_count == 0, so just call Value directly is ok. + bufferedAppendValues_.push_back(booleanArray.Value(i)); } } else { - PARQUET_THROW_NOT_OK(::arrow::VisitArraySpanInline<::arrow::BooleanType>( - *boolean_array.data(), - [&](bool value) { - buffered_append_values_.push_back(value); - return Status::OK(); - }, - []() { return Status::OK(); })); + PARQUET_THROW_NOT_OK( + ::arrow::VisitArraySpanInline<::arrow::BooleanType>( + *booleanArray.data(), + [&](bool value) { + bufferedAppendValues_.push_back(value); + return Status::OK(); + }, + []() { return Status::OK(); })); } } - void PutSpaced( + void putSpaced( const T* src, - int num_values, - const uint8_t* valid_bits, - int64_t valid_bits_offset) override { - if (valid_bits != NULLPTR) { - PARQUET_ASSIGN_OR_THROW( - auto buffer, - ::arrow::AllocateBuffer(num_values * sizeof(T), this->memory_pool())); + int numValues, + const uint8_t* validBits, + int64_t validBitsOffset) override { + if (validBits != NULLPTR) { + auto buffer = allocateBuffer(this->memoryPool(), numValues * sizeof(T)); T* data = reinterpret_cast(buffer->mutable_data()); - int num_valid_values = ::arrow::util::internal::SpacedCompress( - src, num_values, valid_bits, valid_bits_offset, data); - Put(data, num_valid_values); + int numValidValues = ::arrow::util::internal::SpacedCompress( + src, numValues, validBits, validBitsOffset, data); + put(data, numValidValues); } else { - Put(src, num_values); + put(src, numValues); } } - void Put(const std::vector& src, int num_values) override; + void put(const std::vector& src, int numValues) override; protected: template - void PutImpl(const SequenceType& src, int num_values); + void putImpl(const SequenceType& src, int numValues); - int MaxRleBufferSize() const noexcept { - return RlePreserveBufferSize( - static_cast(buffered_append_values_.size()), kBitWidth); + int maxRleBufferSize() const noexcept { + return rlePreserveBufferSize( + static_cast(bufferedAppendValues_.size()), kBitWidth); } constexpr static int32_t kBitWidth = 1; - /// 4 bytes in little-endian, which indicates the length. + /// 4 Bytes in little-endian, which indicates the length. constexpr static int32_t kRleLengthInBytes = 4; - // std::vector in C++ is tricky, because it's a bitmap. - // Here RleBooleanEncoder will only append values into it, and - // dump values into Buffer, so using it here is ok. - ArrowPoolVector buffered_append_values_; + // Std::vector in C++ is tricky, because it's a bitmap. + // Here RleBooleanEncoder will only append values into it, and. + // Dump values into Buffer, so using it here is ok. + ArrowPoolVector bufferedAppendValues_; }; -void RleBooleanEncoder::Put(const bool* src, int num_values) { - PutImpl(src, num_values); +void RleBooleanEncoder::put(const bool* src, int numValues) { + putImpl(src, numValues); } -void RleBooleanEncoder::Put(const std::vector& src, int num_values) { - PutImpl(src, num_values); +void RleBooleanEncoder::put(const std::vector& src, int numValues) { + putImpl(src, numValues); } template -void RleBooleanEncoder::PutImpl(const SequenceType& src, int num_values) { - for (int i = 0; i < num_values; ++i) { - buffered_append_values_.push_back(src[i]); +void RleBooleanEncoder::putImpl(const SequenceType& src, int numValues) { + for (int i = 0; i < numValues; ++i) { + bufferedAppendValues_.push_back(src[i]); } } -std::shared_ptr RleBooleanEncoder::FlushValues() { - int rle_buffer_size_max = MaxRleBufferSize(); +std::shared_ptr RleBooleanEncoder::flushValues() { + int rleBufferSizeMax = maxRleBufferSize(); std::shared_ptr buffer = - AllocateBuffer(this->pool_, rle_buffer_size_max + kRleLengthInBytes); + allocateBuffer(this->pool_, rleBufferSizeMax + kRleLengthInBytes); RleEncoder encoder( buffer->mutable_data() + kRleLengthInBytes, - rle_buffer_size_max, + rleBufferSizeMax, /*bit_width*/ kBitWidth); - for (bool value : buffered_append_values_) { + for (bool value : bufferedAppendValues_) { encoder.Put(value ? 1 : 0); } encoder.Flush(); ::arrow::util::SafeStore( buffer->mutable_data(), ::arrow::bit_util::ToLittleEndian(encoder.len())); PARQUET_THROW_NOT_OK(buffer->Resize(kRleLengthInBytes + encoder.len())); - buffered_append_values_.clear(); + bufferedAppendValues_.clear(); return buffer; } -// ---------------------------------------------------------------------- -// RLE_BOOLEAN_DECODER +// ----------------------------------------------------------------------. +// RLE_BOOLEAN_DECODER. class RleBooleanDecoder : public DecoderImpl, virtual public BooleanDecoder { public: explicit RleBooleanDecoder(const ColumnDescriptor* descr) - : DecoderImpl(descr, Encoding::RLE) {} + : DecoderImpl(descr, Encoding::kRle) {} - void SetData(int num_values, const uint8_t* data, int len) override { - num_values_ = num_values; - uint32_t num_bytes = 0; + void setData(int numValues, const uint8_t* data, int len) override { + numValues_ = numValues; + uint32_t numBytes = 0; if (len < 4) { throw ParquetException( "Received invalid length : " + std::to_string(len) + " (corrupt data page?)"); } - // Load the first 4 bytes in little-endian, which indicates the length - num_bytes = ::arrow::bit_util::FromLittleEndian(SafeLoadAs(data)); - if (num_bytes > static_cast(len - 4)) { + // Load the first 4 bytes in little-endian, which indicates the length. + numBytes = ::arrow::bit_util::FromLittleEndian(SafeLoadAs(data)); + if (numBytes > static_cast(len - 4)) { throw ParquetException( - "Received invalid number of bytes : " + std::to_string(num_bytes) + + "Received invalid number of bytes : " + std::to_string(numBytes) + " (corrupt data page?)"); } - auto decoder_data = data + 4; + auto decoderData = data + 4; if (decoder_ == nullptr) { - decoder_ = std::make_shared( - decoder_data, - num_bytes, - /*bit_width=*/1); + decoder_ = std::make_shared(decoderData, numBytes, 1); } else { - decoder_->Reset(decoder_data, num_bytes, /*bitWidth=*/1); + decoder_->Reset(decoderData, numBytes, 1); } } - int Decode(bool* buffer, int max_values) override { - max_values = std::min(max_values, num_values_); + int decode(bool* buffer, int maxValues) override { + maxValues = std::min(maxValues, numValues_); - if (decoder_->GetBatch(buffer, max_values) != max_values) { - ParquetException::EofException(); + if (decoder_->GetBatch(buffer, maxValues) != maxValues) { + ParquetException::eofException(); } - num_values_ -= max_values; - return max_values; + numValues_ -= maxValues; + return maxValues; } - int Decode(uint8_t* buffer, int max_values) override { + int decode(uint8_t* buffer, int maxValues) override { ParquetException::NYI("Decode(uint8_t*, int) for RleBooleanDecoder"); } - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::Accumulator* out) override { - if (null_count != 0) { + if (nullCount != 0) { // TODO(ARROW-34660): implement DecodeArrow with null slots. ParquetException::NYI("RleBoolean DecodeArrow with null slots"); } constexpr int kBatchSize = 1024; std::array values; - int sum_decode_count = 0; + int sumDecodeCount = 0; do { - int current_batch = std::min(kBatchSize, num_values); - int decoded_count = decoder_->GetBatch(values.data(), current_batch); - if (decoded_count == 0) { + int currentBatch = std::min(kBatchSize, numValues); + int decodedCount = decoder_->GetBatch(values.data(), currentBatch); + if (decodedCount == 0) { break; } - sum_decode_count += decoded_count; - PARQUET_THROW_NOT_OK(out->Reserve(sum_decode_count)); - for (int i = 0; i < decoded_count; ++i) { + sumDecodeCount += decodedCount; + PARQUET_THROW_NOT_OK(out->Reserve(sumDecodeCount)); + for (int i = 0; i < decodedCount; ++i) { PARQUET_THROW_NOT_OK(out->Append(values[i])); } - num_values -= decoded_count; - } while (num_values > 0); - return sum_decode_count; + numValues -= decodedCount; + } while (numValues > 0); + return sumDecodeCount; } - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) override { + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) override { ParquetException::NYI("DecodeArrow for RleBooleanDecoder"); } @@ -3471,8 +3424,8 @@ class RleBooleanDecoder : public DecoderImpl, virtual public BooleanDecoder { std::shared_ptr decoder_; }; -// ---------------------------------------------------------------------- -// DELTA_BYTE_ARRAY +// ----------------------------------------------------------------------. +// DELTA_BYTE_ARRAY. class DeltaByteArrayDecoder : public DecoderImpl, virtual public TypedDecoder { @@ -3480,224 +3433,224 @@ class DeltaByteArrayDecoder : public DecoderImpl, explicit DeltaByteArrayDecoder( const ColumnDescriptor* descr, MemoryPool* pool = ::arrow::default_memory_pool()) - : DecoderImpl(descr, Encoding::DELTA_BYTE_ARRAY), - prefix_len_decoder_(nullptr, pool), - suffix_decoder_(nullptr, pool), - last_value_in_previous_page_(""), - buffered_prefix_length_(AllocateBuffer(pool, 0)), - buffered_data_(AllocateBuffer(pool, 0)) {} - - void SetData(int num_values, const uint8_t* data, int len) override { - num_values_ = num_values; + : DecoderImpl(descr, Encoding::kDeltaByteArray), + prefixLenDecoder_(nullptr, pool), + suffixDecoder_(nullptr, pool), + lastValueInPreviousPage_(""), + bufferedPrefixLength_(allocateBuffer(pool, 0)), + bufferedData_(allocateBuffer(pool, 0)) {} + + void setData(int numValues, const uint8_t* data, int len) override { + numValues_ = numValues; decoder_ = std::make_shared(data, len); - prefix_len_decoder_.SetDecoder(num_values, decoder_); + prefixLenDecoder_.setDecoder(numValues, decoder_); - // get the number of encoded prefix lengths - int num_prefix = prefix_len_decoder_.ValidValuesCount(); - // call prefix_len_decoder_.Decode to decode all the prefix lengths. - // all the prefix lengths are buffered in buffered_prefix_length_. + // Get the number of encoded prefix lengths. + int numPrefix = prefixLenDecoder_.validValuesCount(); + // Call prefix_len_decoder_.Decode to decode all the prefix lengths. + // All the prefix lengths are buffered in buffered_prefix_length_. PARQUET_THROW_NOT_OK( - buffered_prefix_length_->Resize(num_prefix * sizeof(int32_t))); - VELOX_DEBUG_ONLY int ret = prefix_len_decoder_.Decode( - reinterpret_cast(buffered_prefix_length_->mutable_data()), - num_prefix); - VELOX_DCHECK_EQ(ret, num_prefix); - prefix_len_offset_ = 0; - num_valid_values_ = num_prefix; - - int bytes_left = decoder_->bytesLeft(); + bufferedPrefixLength_->Resize(numPrefix * sizeof(int32_t))); + VELOX_DEBUG_ONLY int ret = prefixLenDecoder_.decode( + reinterpret_cast(bufferedPrefixLength_->mutable_data()), + numPrefix); + VELOX_DCHECK_EQ(ret, numPrefix); + prefixLenOffset_ = 0; + numValidValues_ = numPrefix; + + int bytesLeft = decoder_->bytesLeft(); // If len < bytes_left, prefix_len_decoder.Decode will throw exception. - VELOX_DCHECK_GE(len, bytes_left); - int suffix_begins = len - bytes_left; - // at this time, the decoder_ will be at the start of the encoded suffix - // data. - suffix_decoder_.SetData(num_values, data + suffix_begins, bytes_left); - - // TODO: read corrupted files written with bug(PARQUET-246). last_value_ - // should be set to last_value_in_previous_page_ when decoding a new + VELOX_DCHECK_GE(len, bytesLeft); + int suffixBegins = len - bytesLeft; + // At this time, the decoder_ will be at the start of the encoded suffix. + // Data. + suffixDecoder_.setData(numValues, data + suffixBegins, bytesLeft); + + // TODO: read corrupted files written with bug(PARQUET-246). last_value_. + // Should be set to last_value_in_previous_page_ when decoding a new. // page(except the first page) - last_value_ = ""; + lastValue_ = ""; } - int Decode(ByteArray* buffer, int max_values) override { - return GetInternal(buffer, max_values); + int decode(ByteArray* buffer, int maxValues) override { + return getInternal(buffer, maxValues); } - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::Accumulator* out) override { int result = 0; - PARQUET_THROW_NOT_OK(DecodeArrowDense( - num_values, null_count, valid_bits, valid_bits_offset, out, &result)); + PARQUET_THROW_NOT_OK(decodeArrowDense( + numValues, nullCount, validBits, validBitsOffset, out, &result)); return result; } - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) override { ParquetException::NYI( "DecodeArrow of DictAccumulator for DeltaByteArrayDecoder"); } private: - int GetInternal(ByteArray* buffer, int max_values) { - // Decode up to `max_values` strings into an internal buffer - // and reference them into `buffer`. - max_values = std::min(max_values, num_valid_values_); - if (max_values == 0) { - return max_values; - } - - int suffix_read = suffix_decoder_.Decode(buffer, max_values); - if (ARROW_PREDICT_FALSE(suffix_read != max_values)) { - ParquetException::EofException( - "Read " + std::to_string(suffix_read) + ", expecting " + - std::to_string(max_values) + " from suffix decoder"); - } - - int64_t data_size = 0; - const int32_t* prefix_len_ptr = - reinterpret_cast(buffered_prefix_length_->data()) + - prefix_len_offset_; - for (int i = 0; i < max_values; ++i) { - if (ARROW_PREDICT_FALSE(prefix_len_ptr[i] < 0)) { + int getInternal(ByteArray* buffer, int maxValues) { + // Decode up to `max_values` strings into an internal buffer. + // And reference them into `buffer`. + maxValues = std::min(maxValues, numValidValues_); + if (maxValues == 0) { + return maxValues; + } + + int suffixRead = suffixDecoder_.decode(buffer, maxValues); + if (ARROW_PREDICT_FALSE(suffixRead != maxValues)) { + ParquetException::eofException( + "Read " + std::to_string(suffixRead) + ", expecting " + + std::to_string(maxValues) + " from suffix decoder"); + } + + int64_t dataSize = 0; + const int32_t* prefixLenPtr = + reinterpret_cast(bufferedPrefixLength_->data()) + + prefixLenOffset_; + for (int i = 0; i < maxValues; ++i) { + if (ARROW_PREDICT_FALSE(prefixLenPtr[i] < 0)) { throw ParquetException("negative prefix length in DELTA_BYTE_ARRAY"); } if (ARROW_PREDICT_FALSE( - AddWithOverflow(data_size, prefix_len_ptr[i], &data_size) || - AddWithOverflow(data_size, buffer[i].len, &data_size))) { + AddWithOverflow(dataSize, prefixLenPtr[i], &dataSize) || + AddWithOverflow(dataSize, buffer[i].len, &dataSize))) { throw ParquetException("excess expansion in DELTA_BYTE_ARRAY"); } } - PARQUET_THROW_NOT_OK(buffered_data_->Resize(data_size)); + PARQUET_THROW_NOT_OK(bufferedData_->Resize(dataSize)); - string_view prefix{last_value_}; - uint8_t* data_ptr = buffered_data_->mutable_data(); - for (int i = 0; i < max_values; ++i) { + string_view prefix{lastValue_}; + uint8_t* dataPtr = bufferedData_->mutable_data(); + for (int i = 0; i < maxValues; ++i) { if (ARROW_PREDICT_FALSE( - static_cast(prefix_len_ptr[i]) > prefix.length())) { + static_cast(prefixLenPtr[i]) > prefix.length())) { throw ParquetException("prefix length too large in DELTA_BYTE_ARRAY"); } - memcpy(data_ptr, prefix.data(), prefix_len_ptr[i]); - // buffer[i] currently points to the string suffix - memcpy(data_ptr + prefix_len_ptr[i], buffer[i].ptr, buffer[i].len); - buffer[i].ptr = data_ptr; - buffer[i].len += prefix_len_ptr[i]; - data_ptr += buffer[i].len; + memcpy(dataPtr, prefix.data(), prefixLenPtr[i]); + // Buffer[i] currently points to the string suffix. + memcpy(dataPtr + prefixLenPtr[i], buffer[i].ptr, buffer[i].len); + buffer[i].ptr = dataPtr; + buffer[i].len += prefixLenPtr[i]; + dataPtr += buffer[i].len; prefix = string_view{ reinterpret_cast(buffer[i].ptr), buffer[i].len}; } - prefix_len_offset_ += max_values; - this->num_values_ -= max_values; - num_valid_values_ -= max_values; - last_value_ = std::string{prefix}; + prefixLenOffset_ += maxValues; + this->numValues_ -= maxValues; + numValidValues_ -= maxValues; + lastValue_ = std::string{prefix}; - if (num_valid_values_ == 0) { - last_value_in_previous_page_ = last_value_; + if (numValidValues_ == 0) { + lastValueInPreviousPage_ = lastValue_; } - return max_values; + return maxValues; } - Status DecodeArrowDense( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, + Status decodeArrowDense( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, typename EncodingTraits::Accumulator* out, - int* out_num_values) { + int* outNumValues) { ArrowBinaryHelper helper(out); - std::vector values(num_values); - const int num_valid_values = - GetInternal(values.data(), num_values - null_count); - VELOX_DCHECK_EQ(num_values - null_count, num_valid_values); + std::vector values(numValues); + const int numValidValues = + getInternal(values.data(), numValues - nullCount); + VELOX_DCHECK_EQ(numValues - nullCount, numValidValues); - auto values_ptr = reinterpret_cast(values.data()); - int value_idx = 0; + auto valuesPtr = reinterpret_cast(values.data()); + int valueIdx = 0; RETURN_NOT_OK(VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { - const auto& val = values_ptr[value_idx]; - if (ARROW_PREDICT_FALSE(!helper.CanFit(val.len))) { - RETURN_NOT_OK(helper.PushChunk()); + const auto& val = valuesPtr[valueIdx]; + if (ARROW_PREDICT_FALSE(!helper.canFit(val.len))) { + RETURN_NOT_OK(helper.pushChunk()); } - RETURN_NOT_OK(helper.Append(val.ptr, static_cast(val.len))); - ++value_idx; + RETURN_NOT_OK(helper.append(val.ptr, static_cast(val.len))); + ++valueIdx; return Status::OK(); }, [&]() { - RETURN_NOT_OK(helper.AppendNull()); - --null_count; + RETURN_NOT_OK(helper.appendNull()); + --nullCount; return Status::OK(); })); - VELOX_DCHECK_EQ(null_count, 0); - *out_num_values = num_valid_values; + VELOX_DCHECK_EQ(nullCount, 0); + *outNumValues = numValidValues; return Status::OK(); } std::shared_ptr decoder_; - DeltaBitPackDecoder prefix_len_decoder_; - DeltaLengthByteArrayDecoder suffix_decoder_; - std::string last_value_; - // string buffer for last value in previous page - std::string last_value_in_previous_page_; - int num_valid_values_; - uint32_t prefix_len_offset_; - std::shared_ptr buffered_prefix_length_; - std::shared_ptr buffered_data_; + DeltaBitPackDecoder prefixLenDecoder_; + DeltaLengthByteArrayDecoder suffixDecoder_; + std::string lastValue_; + // String buffer for last value in previous page. + std::string lastValueInPreviousPage_; + int numValidValues_; + uint32_t prefixLenOffset_; + std::shared_ptr bufferedPrefixLength_; + std::shared_ptr bufferedData_; }; -// ---------------------------------------------------------------------- -// BYTE_STREAM_SPLIT +// ----------------------------------------------------------------------. +// BYTE_STREAM_SPLIT. template class ByteStreamSplitDecoder : public DecoderImpl, virtual public TypedDecoder { public: - using T = typename DType::c_type; + using T = typename DType::CType; explicit ByteStreamSplitDecoder(const ColumnDescriptor* descr); - int Decode(T* buffer, int max_values) override; + int decode(T* buffer, int maxValues) override; - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::Accumulator* builder) override; + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::Accumulator* Builder) override; - int DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) override; + int decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) override; - void SetData(int num_values, const uint8_t* data, int len) override; + void setData(int numValues, const uint8_t* data, int len) override; - T* EnsureDecodeBuffer(int64_t min_values) { - const int64_t size = sizeof(T) * min_values; - if (!decode_buffer_ || decode_buffer_->size() < size) { - PARQUET_ASSIGN_OR_THROW(decode_buffer_, ::arrow::AllocateBuffer(size)); + T* ensureDecodeBuffer(int64_t minValues) { + const int64_t size = sizeof(T) * minValues; + if (!decodeBuffer_ || decodeBuffer_->size() < size) { + decodeBuffer_ = allocateBuffer(this->memoryPool(), size); } - return reinterpret_cast(decode_buffer_->mutable_data()); + return reinterpret_cast(decodeBuffer_->mutable_data()); } private: - int num_values_in_buffer_{0}; - std::shared_ptr decode_buffer_; + int numValuesInBuffer_{0}; + std::shared_ptr decodeBuffer_; static constexpr size_t kNumStreams = sizeof(T); }; @@ -3705,14 +3658,14 @@ class ByteStreamSplitDecoder : public DecoderImpl, template ByteStreamSplitDecoder::ByteStreamSplitDecoder( const ColumnDescriptor* descr) - : DecoderImpl(descr, Encoding::BYTE_STREAM_SPLIT) {} + : DecoderImpl(descr, Encoding::kByteStreamSplit) {} template -void ByteStreamSplitDecoder::SetData( - int num_values, +void ByteStreamSplitDecoder::setData( + int numValues, const uint8_t* data, int len) { - if (num_values * static_cast(sizeof(T)) < len) { + if (numValues * static_cast(sizeof(T)) < len) { throw ParquetException( "Data size too large for number of values (padding in byte stream split data " "page?)"); @@ -3720,184 +3673,183 @@ void ByteStreamSplitDecoder::SetData( if (len % sizeof(T) != 0) { throw ParquetException( "ByteStreamSplit data size " + std::to_string(len) + - " not aligned with type " + TypeToString(DType::type_num)); + " not aligned with type " + typeToString(DType::typeNum)); } - num_values = len / sizeof(T); - DecoderImpl::SetData(num_values, data, len); - num_values_in_buffer_ = num_values_; + numValues = len / sizeof(T); + DecoderImpl::setData(numValues, data, len); + numValuesInBuffer_ = numValues_; } template -int ByteStreamSplitDecoder::Decode(T* buffer, int max_values) { - const int values_to_decode = std::min(num_values_, max_values); - const int num_decoded_previously = num_values_in_buffer_ - num_values_; - const uint8_t* data = data_ + num_decoded_previously; +int ByteStreamSplitDecoder::decode(T* buffer, int maxValues) { + const int valuesToDecode = std::min(numValues_, maxValues); + const int numDecodedPreviously = numValuesInBuffer_ - numValues_; + const uint8_t* data = data_ + numDecodedPreviously; - ByteStreamSplitDecode( - data, values_to_decode, num_values_in_buffer_, buffer); - num_values_ -= values_to_decode; - len_ -= sizeof(T) * values_to_decode; - return values_to_decode; + byteStreamSplitDecode(data, valuesToDecode, numValuesInBuffer_, buffer); + numValues_ -= valuesToDecode; + len_ -= sizeof(T) * valuesToDecode; + return valuesToDecode; } template -int ByteStreamSplitDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::Accumulator* builder) { - constexpr int value_size = static_cast(kNumStreams); - int values_decoded = num_values - null_count; - if (ARROW_PREDICT_FALSE(len_ < value_size * values_decoded)) { - ParquetException::EofException(); - } - - PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); - - const int num_decoded_previously = num_values_in_buffer_ - num_values_; - const uint8_t* data = data_ + num_decoded_previously; +int ByteStreamSplitDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::Accumulator* Builder) { + constexpr int valueSize = static_cast(kNumStreams); + int valuesDecoded = numValues - nullCount; + if (ARROW_PREDICT_FALSE(len_ < valueSize * valuesDecoded)) { + ParquetException::eofException(); + } + + PARQUET_THROW_NOT_OK(Builder->Reserve(numValues)); + + const int numDecodedPreviously = numValuesInBuffer_ - numValues_; + const uint8_t* data = data_ + numDecodedPreviously; int offset = 0; #if defined(ARROW_HAVE_SIMD_SPLIT) - // Use fast decoding into intermediate buffer. This will also decode - // some null values, but it's fast enough that we don't care. - T* decode_out = EnsureDecodeBuffer(values_decoded); - ::arrow::util::internal::ByteStreamSplitDecode( - data, values_decoded, num_values_in_buffer_, decode_out); - - // XXX If null_count is 0, we could even append in bulk or decode directly - // into builder + // Use fast decoding into intermediate buffer. This will also decode. + // Some null values, but it's fast enough that we don't care. + T* decodeOut = ensureDecodeBuffer(valuesDecoded); + ::arrow::util::internal::byte_stream_split_decode( + data, valuesDecoded, numValuesInBuffer_, decodeOut); + + // XXX If null_count is 0, we could even append in bulk or decode directly. + // Into builder. VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { - builder->UnsafeAppend(decode_out[offset]); + Builder->UnsafeAppend(decodeOut[offset]); ++offset; }, - [&]() { builder->UnsafeAppendNull(); }); + [&]() { Builder->UnsafeAppendNull(); }); #else VisitNullBitmapInline( - valid_bits, - valid_bits_offset, - num_values, - null_count, + validBits, + validBitsOffset, + numValues, + nullCount, [&]() { - uint8_t gathered_byte_data[kNumStreams]; + uint8_t gatheredByteData[kNumStreams]; for (size_t b = 0; b < kNumStreams; ++b) { - const size_t byte_index = b * num_values_in_buffer_ + offset; - gathered_byte_data[b] = data[byte_index]; + const size_t byteIndex = b * numValuesInBuffer_ + offset; + gatheredByteData[b] = data[byteIndex]; } - builder->UnsafeAppend(SafeLoadAs(&gathered_byte_data[0])); + Builder->UnsafeAppend(SafeLoadAs(&gatheredByteData[0])); ++offset; }, - [&]() { builder->UnsafeAppendNull(); }); + [&]() { Builder->UnsafeAppendNull(); }); #endif - num_values_ -= values_decoded; - len_ -= sizeof(T) * values_decoded; - return values_decoded; + numValues_ -= valuesDecoded; + len_ -= sizeof(T) * valuesDecoded; + return valuesDecoded; } template -int ByteStreamSplitDecoder::DecodeArrow( - int num_values, - int null_count, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - typename EncodingTraits::DictAccumulator* builder) { +int ByteStreamSplitDecoder::decodeArrow( + int numValues, + int nullCount, + const uint8_t* validBits, + int64_t validBitsOffset, + typename EncodingTraits::DictAccumulator* Builder) { ParquetException::NYI("DecodeArrow for ByteStreamSplitDecoder"); } } // namespace -// ---------------------------------------------------------------------- -// Encoder and decoder factory functions +// ----------------------------------------------------------------------. +// Encoder and decoder factory functions. -std::unique_ptr MakeEncoder( - Type::type type_num, +std::unique_ptr makeEncoder( + Type::type typeNum, Encoding::type encoding, - bool use_dictionary, + bool useDictionary, const ColumnDescriptor* descr, MemoryPool* pool) { - if (use_dictionary) { - switch (type_num) { - case Type::INT32: + if (useDictionary) { + switch (typeNum) { + case Type::kInt32: return std::make_unique>(descr, pool); - case Type::INT64: + case Type::kInt64: return std::make_unique>(descr, pool); - case Type::INT96: + case Type::kInt96: return std::make_unique>(descr, pool); - case Type::FLOAT: + case Type::kFloat: return std::make_unique>(descr, pool); - case Type::DOUBLE: + case Type::kDouble: return std::make_unique>(descr, pool); - case Type::BYTE_ARRAY: + case Type::kByteArray: return std::make_unique>(descr, pool); - case Type::FIXED_LEN_BYTE_ARRAY: + case Type::kFixedLenByteArray: return std::make_unique>(descr, pool); default: VELOX_DCHECK(false, "Encoder not implemented"); break; } - } else if (encoding == Encoding::PLAIN) { - switch (type_num) { - case Type::BOOLEAN: + } else if (encoding == Encoding::kPlain) { + switch (typeNum) { + case Type::kBoolean: return std::make_unique>(descr, pool); - case Type::INT32: + case Type::kInt32: return std::make_unique>(descr, pool); - case Type::INT64: + case Type::kInt64: return std::make_unique>(descr, pool); - case Type::INT96: + case Type::kInt96: return std::make_unique>(descr, pool); - case Type::FLOAT: + case Type::kFloat: return std::make_unique>(descr, pool); - case Type::DOUBLE: + case Type::kDouble: return std::make_unique>(descr, pool); - case Type::BYTE_ARRAY: + case Type::kByteArray: return std::make_unique>(descr, pool); - case Type::FIXED_LEN_BYTE_ARRAY: + case Type::kFixedLenByteArray: return std::make_unique>(descr, pool); default: VELOX_DCHECK(false, "Encoder not implemented"); break; } - } else if (encoding == Encoding::BYTE_STREAM_SPLIT) { - switch (type_num) { - case Type::FLOAT: + } else if (encoding == Encoding::kByteStreamSplit) { + switch (typeNum) { + case Type::kFloat: return std::make_unique>(descr, pool); - case Type::DOUBLE: + case Type::kDouble: return std::make_unique>( descr, pool); default: throw ParquetException( "BYTE_STREAM_SPLIT only supports FLOAT and DOUBLE"); } - } else if (encoding == Encoding::DELTA_BINARY_PACKED) { - switch (type_num) { - case Type::INT32: + } else if (encoding == Encoding::kDeltaBinaryPacked) { + switch (typeNum) { + case Type::kInt32: return std::make_unique>(descr, pool); - case Type::INT64: + case Type::kInt64: return std::make_unique>(descr, pool); default: throw ParquetException( "DELTA_BINARY_PACKED encoder only supports INT32 and INT64"); } - } else if (encoding == Encoding::DELTA_LENGTH_BYTE_ARRAY) { - switch (type_num) { - case Type::BYTE_ARRAY: + } else if (encoding == Encoding::kDeltaLengthByteArray) { + switch (typeNum) { + case Type::kByteArray: return std::make_unique>( descr, pool); default: throw ParquetException( "DELTA_LENGTH_BYTE_ARRAY only supports BYTE_ARRAY"); } - } else if (encoding == Encoding::RLE) { - switch (type_num) { - case Type::BOOLEAN: + } else if (encoding == Encoding::kRle) { + switch (typeNum) { + case Type::kBoolean: return std::make_unique(descr, pool); default: throw ParquetException("RLE only supports BOOLEAN"); @@ -3909,64 +3861,64 @@ std::unique_ptr MakeEncoder( return nullptr; } -std::unique_ptr MakeDecoder( - Type::type type_num, +std::unique_ptr makeDecoder( + Type::type typeNum, Encoding::type encoding, const ColumnDescriptor* descr, ::arrow::MemoryPool* pool) { - if (encoding == Encoding::PLAIN) { - switch (type_num) { - case Type::BOOLEAN: + if (encoding == Encoding::kPlain) { + switch (typeNum) { + case Type::kBoolean: return std::make_unique(descr); - case Type::INT32: + case Type::kInt32: return std::make_unique>(descr); - case Type::INT64: + case Type::kInt64: return std::make_unique>(descr); - case Type::INT96: + case Type::kInt96: return std::make_unique>(descr); - case Type::FLOAT: + case Type::kFloat: return std::make_unique>(descr); - case Type::DOUBLE: + case Type::kDouble: return std::make_unique>(descr); - case Type::BYTE_ARRAY: + case Type::kByteArray: return std::make_unique(descr); - case Type::FIXED_LEN_BYTE_ARRAY: + case Type::kFixedLenByteArray: return std::make_unique(descr); default: break; } - } else if (encoding == Encoding::BYTE_STREAM_SPLIT) { - switch (type_num) { - case Type::FLOAT: + } else if (encoding == Encoding::kByteStreamSplit) { + switch (typeNum) { + case Type::kFloat: return std::make_unique>(descr); - case Type::DOUBLE: + case Type::kDouble: return std::make_unique>(descr); default: throw ParquetException( "BYTE_STREAM_SPLIT only supports FLOAT and DOUBLE"); } - } else if (encoding == Encoding::DELTA_BINARY_PACKED) { - switch (type_num) { - case Type::INT32: + } else if (encoding == Encoding::kDeltaBinaryPacked) { + switch (typeNum) { + case Type::kInt32: return std::make_unique>(descr, pool); - case Type::INT64: + case Type::kInt64: return std::make_unique>(descr, pool); default: throw ParquetException( "DELTA_BINARY_PACKED decoder only supports INT32 and INT64"); } - } else if (encoding == Encoding::DELTA_BYTE_ARRAY) { - if (type_num == Type::BYTE_ARRAY) { + } else if (encoding == Encoding::kDeltaByteArray) { + if (typeNum == Type::kByteArray) { return std::make_unique(descr, pool); } throw ParquetException("DELTA_BYTE_ARRAY only supports BYTE_ARRAY"); - } else if (encoding == Encoding::DELTA_LENGTH_BYTE_ARRAY) { - if (type_num == Type::BYTE_ARRAY) { + } else if (encoding == Encoding::kDeltaLengthByteArray) { + if (typeNum == Type::kByteArray) { return std::make_unique(descr, pool); } throw ParquetException("DELTA_LENGTH_BYTE_ARRAY only supports BYTE_ARRAY"); - } else if (encoding == Encoding::RLE) { - if (type_num == Type::BOOLEAN) { + } else if (encoding == Encoding::kRle) { + if (typeNum == Type::kBoolean) { return std::make_unique(descr); } throw ParquetException("RLE encoding only supports BOOLEAN"); @@ -3978,27 +3930,27 @@ std::unique_ptr MakeDecoder( } namespace detail { -std::unique_ptr MakeDictDecoder( - Type::type type_num, +std::unique_ptr makeDictDecoder( + Type::type typeNum, const ColumnDescriptor* descr, MemoryPool* pool) { - switch (type_num) { - case Type::BOOLEAN: + switch (typeNum) { + case Type::kBoolean: ParquetException::NYI( "Dictionary encoding not implemented for boolean type"); - case Type::INT32: + case Type::kInt32: return std::make_unique>(descr, pool); - case Type::INT64: + case Type::kInt64: return std::make_unique>(descr, pool); - case Type::INT96: + case Type::kInt96: return std::make_unique>(descr, pool); - case Type::FLOAT: + case Type::kFloat: return std::make_unique>(descr, pool); - case Type::DOUBLE: + case Type::kDouble: return std::make_unique>(descr, pool); - case Type::BYTE_ARRAY: + case Type::kByteArray: return std::make_unique(descr, pool); - case Type::FIXED_LEN_BYTE_ARRAY: + case Type::kFixedLenByteArray: return std::make_unique>(descr, pool); default: break; diff --git a/velox/dwio/parquet/writer/arrow/tests/FileDeserializeTest.cpp b/velox/dwio/parquet/writer/arrow/tests/FileDeserializeTest.cpp index 9c144739158..eb6279f55ce 100644 --- a/velox/dwio/parquet/writer/arrow/tests/FileDeserializeTest.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/FileDeserializeTest.cpp @@ -19,20 +19,22 @@ #include #include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/io/IoStatistics.h" +#include "velox/common/testutil/TempFilePath.h" #include "velox/dwio/parquet/reader/ParquetReader.h" #include "velox/dwio/parquet/writer/arrow/Exception.h" #include "velox/dwio/parquet/writer/arrow/FileWriter.h" #include "velox/dwio/parquet/writer/arrow/ThriftInternal.h" #include "velox/dwio/parquet/writer/arrow/tests/TestUtil.h" -#include "velox/exec/tests/utils/TempFilePath.h" #include "arrow/testing/gtest_util.h" #include "arrow/util/crc32.h" namespace facebook::velox::parquet::arrow { +using namespace facebook::velox::common::testutil; namespace { void writeToFile( - std::shared_ptr filePath, + std::shared_ptr filePath, std::shared_ptr buffer) { auto localWriteFile = std::make_unique(filePath->getPath(), false, false); @@ -45,20 +47,20 @@ void writeToFile( using ::arrow::io::BufferReader; -// Adds page statistics occupying a certain amount of bytes (for testing very +// Adds page statistics occupying a certain amount of bytes (for testing very. // large page headers) template static inline void -AddDummyStats(int stat_size, H& header, bool fill_all_stats = false) { - std::vector stat_bytes(stat_size); - // Some non-zero value - std::fill(stat_bytes.begin(), stat_bytes.end(), 1); +addDummyStats(int statSize, H& header, bool fillAllStats = false) { + std::vector statBytes(statSize); + // Some non-zero value. + std::fill(statBytes.begin(), statBytes.end(), 1); header.statistics.__set_max( - std::string(reinterpret_cast(stat_bytes.data()), stat_size)); + std::string(reinterpret_cast(statBytes.data()), statSize)); - if (fill_all_stats) { - header.statistics.__set_min(std::string( - reinterpret_cast(stat_bytes.data()), stat_size)); + if (fillAllStats) { + header.statistics.__set_min( + std::string(reinterpret_cast(statBytes.data()), statSize)); header.statistics.__set_null_count(42); header.statistics.__set_distinct_count(1); } @@ -67,7 +69,7 @@ AddDummyStats(int stat_size, H& header, bool fill_all_stats = false) { } template -static inline void CheckStatistics( +static inline void checkStatistics( const H& expected, const EncodedStatistics& actual) { if (expected.statistics.__isset.max) { @@ -77,434 +79,417 @@ static inline void CheckStatistics( ASSERT_EQ(expected.statistics.min, actual.min()); } if (expected.statistics.__isset.null_count) { - ASSERT_EQ(expected.statistics.null_count, actual.null_count); + ASSERT_EQ(expected.statistics.null_count, actual.nullCount); } if (expected.statistics.__isset.distinct_count) { - ASSERT_EQ(expected.statistics.distinct_count, actual.distinct_count); + ASSERT_EQ(expected.statistics.distinct_count, actual.distinctCount); } } -static std::vector GetSupportedCodecTypes() { - std::vector codec_types; +static std::vector getSupportedCodecTypes() { + std::vector codecTypes; - codec_types.push_back(Compression::SNAPPY); + codecTypes.push_back(Compression::SNAPPY); #ifdef ARROW_WITH_BROTLI - codec_types.push_back(Compression::BROTLI); + codecTypes.push_back(Compression::BROTLI); #endif - codec_types.push_back(Compression::GZIP); + codecTypes.push_back(Compression::GZIP); - codec_types.push_back(Compression::LZ4); - codec_types.push_back(Compression::LZ4_HADOOP); + codecTypes.push_back(Compression::LZ4); + codecTypes.push_back(Compression::LZ4_HADOOP); - codec_types.push_back(Compression::ZSTD); - return codec_types; + codecTypes.push_back(Compression::ZSTD); + return codecTypes; } class TestPageSerde : public ::testing::Test { public: void SetUp() { - data_page_header_.encoding = + dataPageHeader_.encoding = facebook::velox::parquet::thrift::Encoding::PLAIN; - data_page_header_.definition_level_encoding = + dataPageHeader_.definition_level_encoding = facebook::velox::parquet::thrift::Encoding::RLE; - data_page_header_.repetition_level_encoding = + dataPageHeader_.repetition_level_encoding = facebook::velox::parquet::thrift::Encoding::RLE; - ResetStream(); + resetStream(); } - void InitSerializedPageReader( - int64_t num_rows, - Compression::type codec = Compression::UNCOMPRESSED, + void initSerializedPageReader( + int64_t numRows, + Compression::type Codec = Compression::UNCOMPRESSED, const ReaderProperties& properties = ReaderProperties()) { - EndStream(); + endStream(); - auto stream = std::make_shared<::arrow::io::BufferReader>(out_buffer_); - page_reader_ = PageReader::Open(stream, num_rows, codec, properties); + auto stream = std::make_shared<::arrow::io::BufferReader>(outBuffer_); + pageReader_ = PageReader::open(stream, numRows, Codec, properties); } - void WriteDataPageHeader( - int max_serialized_len = 1024, - int32_t uncompressed_size = 0, - int32_t compressed_size = 0, + void writeDataPageHeader( + int maxSerializedLen = 1024, + int32_t uncompressedSize = 0, + int32_t compressedSize = 0, std::optional checksum = std::nullopt) { - // Simplifying writing serialized data page headers which may or may not - // have meaningful data associated with them - - // Serialize the Page header - page_header_.__set_data_page_header(data_page_header_); - page_header_.uncompressed_page_size = uncompressed_size; - page_header_.compressed_page_size = compressed_size; - page_header_.type = facebook::velox::parquet::thrift::PageType::DATA_PAGE; + // Simplifying writing serialized data page headers which may or may not. + // Have meaningful data associated with them. + + // Serialize the Page header. + pageHeader_.__set_data_page_header(dataPageHeader_); + pageHeader_.uncompressed_page_size = uncompressedSize; + pageHeader_.compressed_page_size = compressedSize; + pageHeader_.type = facebook::velox::parquet::thrift::PageType::DATA_PAGE; if (checksum.has_value()) { - page_header_.__set_crc(checksum.value()); + pageHeader_.__set_crc(checksum.value()); } ThriftSerializer serializer; - ASSERT_NO_THROW(serializer.Serialize(&page_header_, out_stream_.get())); + ASSERT_NO_THROW(serializer.serialize(&pageHeader_, outStream_.get())); } - void WriteDataPageHeaderV2( - int max_serialized_len = 1024, - int32_t uncompressed_size = 0, - int32_t compressed_size = 0, + void writeDataPageHeaderV2( + int maxSerializedLen = 1024, + int32_t uncompressedSize = 0, + int32_t compressedSize = 0, std::optional checksum = std::nullopt) { - // Simplifying writing serialized data page V2 headers which may or may not - // have meaningful data associated with them - - // Serialize the Page header - page_header_.__set_data_page_header_v2(data_page_header_v2_); - page_header_.uncompressed_page_size = uncompressed_size; - page_header_.compressed_page_size = compressed_size; - page_header_.type = - facebook::velox::parquet::thrift::PageType::DATA_PAGE_V2; + // Simplifying writing serialized data page V2 headers which may or may not. + // Have meaningful data associated with them. + + // Serialize the Page header. + pageHeader_.__set_data_page_header_v2(dataPageHeaderV2_); + pageHeader_.uncompressed_page_size = uncompressedSize; + pageHeader_.compressed_page_size = compressedSize; + pageHeader_.type = facebook::velox::parquet::thrift::PageType::DATA_PAGE_V2; if (checksum.has_value()) { - page_header_.__set_crc(checksum.value()); + pageHeader_.__set_crc(checksum.value()); } ThriftSerializer serializer; - ASSERT_NO_THROW(serializer.Serialize(&page_header_, out_stream_.get())); + ASSERT_NO_THROW(serializer.serialize(&pageHeader_, outStream_.get())); } - void WriteDictionaryPageHeader( - int32_t uncompressed_size = 0, - int32_t compressed_size = 0, + void writeDictionaryPageHeader( + int32_t uncompressedSize = 0, + int32_t compressedSize = 0, std::optional checksum = std::nullopt) { - page_header_.__set_dictionary_page_header(dictionary_page_header_); - page_header_.uncompressed_page_size = uncompressed_size; - page_header_.compressed_page_size = compressed_size; - page_header_.type = + pageHeader_.__set_dictionary_page_header(dictionaryPageHeader_); + pageHeader_.uncompressed_page_size = uncompressedSize; + pageHeader_.compressed_page_size = compressedSize; + pageHeader_.type = facebook::velox::parquet::thrift::PageType::DICTIONARY_PAGE; if (checksum.has_value()) { - page_header_.__set_crc(checksum.value()); + pageHeader_.__set_crc(checksum.value()); } ThriftSerializer serializer; - ASSERT_NO_THROW(serializer.Serialize(&page_header_, out_stream_.get())); + ASSERT_NO_THROW(serializer.serialize(&pageHeader_, outStream_.get())); } - void WriteIndexPageHeader( - int32_t uncompressed_size = 0, - int32_t compressed_size = 0) { - page_header_.__set_index_page_header(index_page_header_); - page_header_.uncompressed_page_size = uncompressed_size; - page_header_.compressed_page_size = compressed_size; - page_header_.type = facebook::velox::parquet::thrift::PageType::INDEX_PAGE; + void writeIndexPageHeader( + int32_t uncompressedSize = 0, + int32_t compressedSize = 0) { + pageHeader_.__set_index_page_header(indexPageHeader_); + pageHeader_.uncompressed_page_size = uncompressedSize; + pageHeader_.compressed_page_size = compressedSize; + pageHeader_.type = facebook::velox::parquet::thrift::PageType::INDEX_PAGE; ThriftSerializer serializer; - ASSERT_NO_THROW(serializer.Serialize(&page_header_, out_stream_.get())); + ASSERT_NO_THROW(serializer.serialize(&pageHeader_, outStream_.get())); } - void ResetStream() { - out_stream_ = CreateOutputStream(); + void resetStream() { + outStream_ = createOutputStream(); } - void EndStream() { - PARQUET_ASSIGN_OR_THROW(out_buffer_, out_stream_->Finish()); + void endStream() { + PARQUET_ASSIGN_OR_THROW(outBuffer_, outStream_->Finish()); } - void TestPageSerdeCrc( - bool write_checksum, - bool write_page_corrupt, - bool verification_checksum, - bool has_dictionary = false, - bool write_data_page_v2 = false); + void testPageSerdeCrc( + bool writeChecksum, + bool writePageCorrupt, + bool verificationChecksum, + bool hasDictionary = false, + bool writeDataPageV2 = false); - void TestPageCompressionRoundTrip(const std::vector& page_sizes); + void testPageCompressionRoundTrip(const std::vector& pageSizes); protected: - std::shared_ptr<::arrow::io::BufferOutputStream> out_stream_; - std::shared_ptr out_buffer_; - - std::unique_ptr page_reader_; - facebook::velox::parquet::thrift::PageHeader page_header_; - facebook::velox::parquet::thrift::DataPageHeader data_page_header_; - facebook::velox::parquet::thrift::DataPageHeaderV2 data_page_header_v2_; - facebook::velox::parquet::thrift::IndexPageHeader index_page_header_; - facebook::velox::parquet::thrift::DictionaryPageHeader - dictionary_page_header_; + std::shared_ptr<::arrow::io::BufferOutputStream> outStream_; + std::shared_ptr outBuffer_; + + std::unique_ptr pageReader_; + facebook::velox::parquet::thrift::PageHeader pageHeader_; + facebook::velox::parquet::thrift::DataPageHeader dataPageHeader_; + facebook::velox::parquet::thrift::DataPageHeaderV2 dataPageHeaderV2_; + facebook::velox::parquet::thrift::IndexPageHeader indexPageHeader_; + facebook::velox::parquet::thrift::DictionaryPageHeader dictionaryPageHeader_; }; -void TestPageSerde::TestPageSerdeCrc( - bool write_checksum, - bool write_page_corrupt, - bool verification_checksum, - bool has_dictionary, - bool write_data_page_v2) { - auto codec_types = GetSupportedCodecTypes(); - codec_types.push_back(Compression::UNCOMPRESSED); - const int32_t num_rows = 32; // dummy value - if (write_data_page_v2) { - data_page_header_v2_.num_values = num_rows; +void TestPageSerde::testPageSerdeCrc( + bool writeChecksum, + bool writePageCorrupt, + bool verificationChecksum, + bool hasDictionary, + bool writeDataPageV2) { + auto codecTypes = getSupportedCodecTypes(); + codecTypes.push_back(Compression::UNCOMPRESSED); + const int32_t numRows = 32; // dummy value + if (writeDataPageV2) { + dataPageHeaderV2_.num_values = numRows; } else { - data_page_header_.num_values = num_rows; + dataPageHeader_.num_values = numRows; } - dictionary_page_header_.num_values = num_rows; + dictionaryPageHeader_.num_values = numRows; - const int num_pages = 10; + const int numPages = 10; - std::vector> faux_data; - faux_data.resize(num_pages); - for (int i = 0; i < num_pages; ++i) { - // The pages keep getting larger - int page_size = (i + 1) * 64; - test::random_bytes(page_size, 0, &faux_data[i]); + std::vector> fauxData; + fauxData.resize(numPages); + for (int i = 0; i < numPages; ++i) { + // The pages keep getting larger. + int pageSize = (i + 1) * 64; + test::randomBytes(pageSize, 0, &fauxData[i]); } - for (auto codec_type : codec_types) { - auto codec = GetCodec(codec_type); + for (auto codecType : codecTypes) { + auto Codec = getCodec(codecType); std::vector buffer; - for (int i = 0; i < num_pages; ++i) { - const uint8_t* data = faux_data[i].data(); - int data_size = static_cast(faux_data[i].size()); - int64_t actual_size; - if (codec == nullptr) { - buffer = faux_data[i]; - actual_size = data_size; + for (int i = 0; i < numPages; ++i) { + const uint8_t* data = fauxData[i].data(); + int dataSize = static_cast(fauxData[i].size()); + int64_t actualSize; + if (Codec == nullptr) { + buffer = fauxData[i]; + actualSize = dataSize; } else { - int64_t max_compressed_size = codec->MaxCompressedLen(data_size, data); - buffer.resize(max_compressed_size); + int64_t maxCompressedSize = Codec->maxCompressedLen(dataSize, data); + buffer.resize(maxCompressedSize); ASSERT_OK_AND_ASSIGN( - actual_size, - codec->Compress(data_size, data, max_compressed_size, &buffer[0])); + actualSize, + Codec->compress(dataSize, data, maxCompressedSize, &buffer[0])); } - std::optional checksum_opt; - if (write_checksum) { + std::optional checksumOpt; + if (writeChecksum) { uint32_t checksum = - ::arrow::internal::crc32(/* prev */ 0, buffer.data(), actual_size); - if (write_page_corrupt) { + ::arrow::internal::crc32(/* prev */ 0, buffer.data(), actualSize); + if (writePageCorrupt) { checksum += 1; // write a bad checksum } - checksum_opt = checksum; + checksumOpt = checksum; } - if (has_dictionary && i == 0) { - ASSERT_NO_FATAL_FAILURE(WriteDictionaryPageHeader( - data_size, static_cast(actual_size), checksum_opt)); + if (hasDictionary && i == 0) { + ASSERT_NO_FATAL_FAILURE(writeDictionaryPageHeader( + dataSize, static_cast(actualSize), checksumOpt)); } else { - if (write_data_page_v2) { - ASSERT_NO_FATAL_FAILURE(WriteDataPageHeaderV2( - 1024, - data_size, - static_cast(actual_size), - checksum_opt)); + if (writeDataPageV2) { + ASSERT_NO_FATAL_FAILURE(writeDataPageHeaderV2( + 1024, dataSize, static_cast(actualSize), checksumOpt)); } else { - ASSERT_NO_FATAL_FAILURE(WriteDataPageHeader( - 1024, - data_size, - static_cast(actual_size), - checksum_opt)); + ASSERT_NO_FATAL_FAILURE(writeDataPageHeader( + 1024, dataSize, static_cast(actualSize), checksumOpt)); } } - ASSERT_OK(out_stream_->Write(buffer.data(), actual_size)); + ASSERT_OK(outStream_->Write(buffer.data(), actualSize)); } - ReaderProperties readerProperties; - readerProperties.set_page_checksum_verification(verification_checksum); - InitSerializedPageReader( - num_rows * num_pages, codec_type, readerProperties); + ReaderProperties ReaderProperties; + ReaderProperties.setPageChecksumVerification(verificationChecksum); + initSerializedPageReader(numRows * numPages, codecType, ReaderProperties); - for (int i = 0; i < num_pages; ++i) { - if (write_checksum && write_page_corrupt && verification_checksum) { + for (int i = 0; i < numPages; ++i) { + if (writeChecksum && writePageCorrupt && verificationChecksum) { EXPECT_THROW_THAT( - [&]() { page_reader_->NextPage(); }, + [&]() { pageReader_->nextPage(); }, ParquetException, ::testing::Property( &ParquetException::what, ::testing::HasSubstr("CRC checksum verification failed"))); } else { - const auto page = page_reader_->NextPage(); - const int data_size = static_cast(faux_data[i].size()); - if (has_dictionary && i == 0) { - ASSERT_EQ(PageType::DICTIONARY_PAGE, page->type()); - const auto dict_page = static_cast(page.get()); - ASSERT_EQ(data_size, dict_page->size()); - ASSERT_EQ( - 0, memcmp(faux_data[i].data(), dict_page->data(), data_size)); - } else if (write_data_page_v2) { - ASSERT_EQ(PageType::DATA_PAGE_V2, page->type()); - const auto data_page = static_cast(page.get()); - ASSERT_EQ(data_size, data_page->size()); - ASSERT_EQ( - 0, memcmp(faux_data[i].data(), data_page->data(), data_size)); + const auto page = pageReader_->nextPage(); + const int dataSize = static_cast(fauxData[i].size()); + if (hasDictionary && i == 0) { + ASSERT_EQ(PageType::kDictionaryPage, page->type()); + const auto dictPage = static_cast(page.get()); + ASSERT_EQ(dataSize, dictPage->size()); + ASSERT_EQ(0, memcmp(fauxData[i].data(), dictPage->data(), dataSize)); + } else if (writeDataPageV2) { + ASSERT_EQ(PageType::kDataPageV2, page->type()); + const auto dataPage = static_cast(page.get()); + ASSERT_EQ(dataSize, dataPage->size()); + ASSERT_EQ(0, memcmp(fauxData[i].data(), dataPage->data(), dataSize)); } else { - ASSERT_EQ(PageType::DATA_PAGE, page->type()); - const auto data_page = static_cast(page.get()); - ASSERT_EQ(data_size, data_page->size()); - ASSERT_EQ( - 0, memcmp(faux_data[i].data(), data_page->data(), data_size)); + ASSERT_EQ(PageType::kDataPage, page->type()); + const auto dataPage = static_cast(page.get()); + ASSERT_EQ(dataSize, dataPage->size()); + ASSERT_EQ(0, memcmp(fauxData[i].data(), dataPage->data(), dataSize)); } } } - ResetStream(); + resetStream(); } } -void CheckDataPageHeader( +void checkDataPageHeader( const facebook::velox::parquet::thrift::DataPageHeader& expected, const Page* page) { - ASSERT_EQ(PageType::DATA_PAGE, page->type()); + ASSERT_EQ(PageType::kDataPage, page->type()); - const DataPageV1* data_page = static_cast(page); - ASSERT_EQ(expected.num_values, data_page->num_values()); - ASSERT_EQ(expected.encoding, data_page->encoding()); + const DataPageV1* dataPage = static_cast(page); + ASSERT_EQ(expected.num_values, dataPage->numValues()); + ASSERT_EQ(expected.encoding, dataPage->encoding()); ASSERT_EQ( - expected.definition_level_encoding, - data_page->definition_level_encoding()); + expected.definition_level_encoding, dataPage->definitionLevelEncoding()); ASSERT_EQ( - expected.repetition_level_encoding, - data_page->repetition_level_encoding()); - CheckStatistics(expected, data_page->statistics()); + expected.repetition_level_encoding, dataPage->repetitionLevelEncoding()); + checkStatistics(expected, dataPage->statistics()); } // Overload for DataPageV2 tests. -void CheckDataPageHeader( +void checkDataPageHeader( const facebook::velox::parquet::thrift::DataPageHeaderV2& expected, const Page* page) { - ASSERT_EQ(PageType::DATA_PAGE_V2, page->type()); + ASSERT_EQ(PageType::kDataPageV2, page->type()); - const DataPageV2* data_page = static_cast(page); - ASSERT_EQ(expected.num_values, data_page->num_values()); - ASSERT_EQ(expected.num_nulls, data_page->num_nulls()); - ASSERT_EQ(expected.num_rows, data_page->num_rows()); - ASSERT_EQ(expected.encoding, data_page->encoding()); + const DataPageV2* dataPage = static_cast(page); + ASSERT_EQ(expected.num_values, dataPage->numValues()); + ASSERT_EQ(expected.num_nulls, dataPage->numNulls()); + ASSERT_EQ(expected.num_rows, dataPage->numRows()); + ASSERT_EQ(expected.encoding, dataPage->encoding()); ASSERT_EQ( expected.definition_levels_byte_length, - data_page->definition_levels_byte_length()); + dataPage->definitionLevelsByteLength()); ASSERT_EQ( expected.repetition_levels_byte_length, - data_page->repetition_levels_byte_length()); - ASSERT_EQ(expected.is_compressed, data_page->is_compressed()); - CheckStatistics(expected, data_page->statistics()); + dataPage->repetitionLevelsByteLength()); + ASSERT_EQ(expected.is_compressed, dataPage->isCompressed()); + checkStatistics(expected, dataPage->statistics()); } TEST_F(TestPageSerde, DataPageV1) { - int stats_size = 512; - const int32_t num_rows = 4444; - AddDummyStats(stats_size, data_page_header_, /* fill_all_stats = */ true); - data_page_header_.num_values = num_rows; - - ASSERT_NO_FATAL_FAILURE(WriteDataPageHeader()); - InitSerializedPageReader(num_rows); - std::shared_ptr current_page = page_reader_->NextPage(); + int statsSize = 512; + const int32_t numRows = 4444; + addDummyStats(statsSize, dataPageHeader_, /*fill_all_stats=*/true); + dataPageHeader_.num_values = numRows; + + ASSERT_NO_FATAL_FAILURE(writeDataPageHeader()); + initSerializedPageReader(numRows); + std::shared_ptr currentPage = pageReader_->nextPage(); ASSERT_NO_FATAL_FAILURE( - CheckDataPageHeader(data_page_header_, current_page.get())); + checkDataPageHeader(dataPageHeader_, currentPage.get())); } -// Templated test class to test page filtering for both -// facebook::velox::parquet::thrift::DataPageHeader and -// facebook::velox::parquet::thrift::DataPageHeaderV2. +// Templated test class to test page filtering for both. +// Facebook::velox::parquet::thrift::DataPageHeader and. +// Facebook::velox::parquet::thrift::DataPageHeaderV2. template class PageFilterTest : public TestPageSerde { public: const int kNumPages = 10; - void WriteStream(); - void WritePageWithoutStats(); - void CheckNumRows(std::optional num_rows, const T& header); + void writeStream(); + void writePageWithoutStats(); + void checkNumRows(std::optional numRows, const T& header); protected: - std::vector data_page_headers_; - int total_rows_ = 0; + std::vector dataPageHeaders_; + int totalRows_ = 0; }; template <> void PageFilterTest< - facebook::velox::parquet::thrift::DataPageHeader>::WriteStream() { + facebook::velox::parquet::thrift::DataPageHeader>::writeStream() { for (int i = 0; i < kNumPages; ++i) { // Vary the number of rows to produce different headers. - int32_t num_rows = i + 100; - total_rows_ += num_rows; - int data_size = i + 1024; - this->data_page_header_.__set_num_values(num_rows); - this->data_page_header_.statistics.__set_min_value("A" + std::to_string(i)); - this->data_page_header_.statistics.__set_max_value("Z" + std::to_string(i)); - this->data_page_header_.statistics.__set_null_count(0); - this->data_page_header_.statistics.__set_distinct_count(num_rows); - this->data_page_header_.__isset.statistics = true; - ASSERT_NO_FATAL_FAILURE(this->WriteDataPageHeader( - /*max_serialized_len=*/1024, data_size, data_size)); - data_page_headers_.push_back(this->data_page_header_); + int32_t numRows = i + 100; + totalRows_ += numRows; + int dataSize = i + 1024; + this->dataPageHeader_.__set_num_values(numRows); + this->dataPageHeader_.statistics.__set_min_value("A" + std::to_string(i)); + this->dataPageHeader_.statistics.__set_max_value("Z" + std::to_string(i)); + this->dataPageHeader_.statistics.__set_null_count(0); + this->dataPageHeader_.statistics.__set_distinct_count(numRows); + this->dataPageHeader_.__isset.statistics = true; + ASSERT_NO_FATAL_FAILURE( + this->writeDataPageHeader(1024, dataSize, dataSize)); + dataPageHeaders_.push_back(this->dataPageHeader_); // Also write data, to make sure we skip the data correctly. - std::vector faux_data(data_size); - ASSERT_OK(this->out_stream_->Write(faux_data.data(), data_size)); + std::vector fauxData(dataSize); + ASSERT_OK(this->outStream_->Write(fauxData.data(), dataSize)); } - this->EndStream(); + this->endStream(); } template <> void PageFilterTest< - facebook::velox::parquet::thrift::DataPageHeaderV2>::WriteStream() { + facebook::velox::parquet::thrift::DataPageHeaderV2>::writeStream() { for (int i = 0; i < kNumPages; ++i) { // Vary the number of rows to produce different headers. - int32_t num_rows = i + 100; - total_rows_ += num_rows; - int data_size = i + 1024; - this->data_page_header_v2_.__set_num_values(num_rows); - this->data_page_header_v2_.__set_num_rows(num_rows); - this->data_page_header_v2_.statistics.__set_min_value( - "A" + std::to_string(i)); - this->data_page_header_v2_.statistics.__set_max_value( - "Z" + std::to_string(i)); - this->data_page_header_v2_.statistics.__set_null_count(0); - this->data_page_header_v2_.statistics.__set_distinct_count(num_rows); - this->data_page_header_v2_.__isset.statistics = true; - ASSERT_NO_FATAL_FAILURE(this->WriteDataPageHeaderV2( - /*max_serialized_len=*/1024, data_size, data_size)); - data_page_headers_.push_back(this->data_page_header_v2_); + int32_t numRows = i + 100; + totalRows_ += numRows; + int dataSize = i + 1024; + this->dataPageHeaderV2_.__set_num_values(numRows); + this->dataPageHeaderV2_.__set_num_rows(numRows); + this->dataPageHeaderV2_.statistics.__set_min_value("A" + std::to_string(i)); + this->dataPageHeaderV2_.statistics.__set_max_value("Z" + std::to_string(i)); + this->dataPageHeaderV2_.statistics.__set_null_count(0); + this->dataPageHeaderV2_.statistics.__set_distinct_count(numRows); + this->dataPageHeaderV2_.__isset.statistics = true; + ASSERT_NO_FATAL_FAILURE( + this->writeDataPageHeaderV2(1024, dataSize, dataSize)); + dataPageHeaders_.push_back(this->dataPageHeaderV2_); // Also write data, to make sure we skip the data correctly. - std::vector faux_data(data_size); - ASSERT_OK(this->out_stream_->Write(faux_data.data(), data_size)); + std::vector fauxData(dataSize); + ASSERT_OK(this->outStream_->Write(fauxData.data(), dataSize)); } - this->EndStream(); + this->endStream(); } template <> void PageFilterTest< - facebook::velox::parquet::thrift::DataPageHeader>::WritePageWithoutStats() { - int32_t num_rows = 100; - total_rows_ += num_rows; - int data_size = 1024; - this->data_page_header_.__set_num_values(num_rows); - ASSERT_NO_FATAL_FAILURE(this->WriteDataPageHeader( - /*max_serialized_len=*/1024, data_size, data_size)); - data_page_headers_.push_back(this->data_page_header_); - std::vector faux_data(data_size); - ASSERT_OK(this->out_stream_->Write(faux_data.data(), data_size)); - this->EndStream(); + facebook::velox::parquet::thrift::DataPageHeader>::writePageWithoutStats() { + int32_t numRows = 100; + totalRows_ += numRows; + int dataSize = 1024; + this->dataPageHeader_.__set_num_values(numRows); + ASSERT_NO_FATAL_FAILURE(this->writeDataPageHeader(1024, dataSize, dataSize)); + dataPageHeaders_.push_back(this->dataPageHeader_); + std::vector fauxData(dataSize); + ASSERT_OK(this->outStream_->Write(fauxData.data(), dataSize)); + this->endStream(); } template <> void PageFilterTest:: - WritePageWithoutStats() { - int32_t num_rows = 100; - total_rows_ += num_rows; - int data_size = 1024; - this->data_page_header_v2_.__set_num_values(num_rows); - this->data_page_header_v2_.__set_num_rows(num_rows); - ASSERT_NO_FATAL_FAILURE(this->WriteDataPageHeaderV2( - /*max_serialized_len=*/1024, data_size, data_size)); - data_page_headers_.push_back(this->data_page_header_v2_); - std::vector faux_data(data_size); - ASSERT_OK(this->out_stream_->Write(faux_data.data(), data_size)); - this->EndStream(); + writePageWithoutStats() { + int32_t numRows = 100; + totalRows_ += numRows; + int dataSize = 1024; + this->dataPageHeaderV2_.__set_num_values(numRows); + this->dataPageHeaderV2_.__set_num_rows(numRows); + ASSERT_NO_FATAL_FAILURE( + this->writeDataPageHeaderV2(1024, dataSize, dataSize)); + dataPageHeaders_.push_back(this->dataPageHeaderV2_); + std::vector fauxData(dataSize); + ASSERT_OK(this->outStream_->Write(fauxData.data(), dataSize)); + this->endStream(); } template <> void PageFilterTest:: - CheckNumRows( - std::optional num_rows, + checkNumRows( + std::optional numRows, const facebook::velox::parquet::thrift::DataPageHeader& header) { - ASSERT_EQ(num_rows, std::nullopt); + ASSERT_EQ(numRows, std::nullopt); } template <> void PageFilterTest:: - CheckNumRows( - std::optional num_rows, + checkNumRows( + std::optional numRows, const facebook::velox::parquet::thrift::DataPageHeaderV2& header) { - ASSERT_EQ(*num_rows, header.num_rows); + ASSERT_EQ(*numRows, header.num_rows); } using DataPageHeaderTypes = ::testing::Types< @@ -512,415 +497,400 @@ using DataPageHeaderTypes = ::testing::Types< facebook::velox::parquet::thrift::DataPageHeaderV2>; TYPED_TEST_SUITE(PageFilterTest, DataPageHeaderTypes); -// Test that the returned encoded_statistics is nullptr when there are no -// statistics in the page header. +// Test that the returned encoded_statistics is nullptr when there are no. +// Statistics in the page header. TYPED_TEST(PageFilterTest, TestPageWithoutStatistics) { - this->WritePageWithoutStats(); + this->writePageWithoutStats(); - auto stream = std::make_shared<::arrow::io::BufferReader>(this->out_buffer_); - this->page_reader_ = - PageReader::Open(stream, this->total_rows_, Compression::UNCOMPRESSED); + auto stream = std::make_shared<::arrow::io::BufferReader>(this->outBuffer_); + this->pageReader_ = + PageReader::open(stream, this->totalRows_, Compression::UNCOMPRESSED); - int num_pages = 0; - bool is_stats_null = false; - auto read_all_pages = [&](const DataPageStats& stats) -> bool { - is_stats_null = stats.encoded_statistics == nullptr; - ++num_pages; + int numPages = 0; + bool isStatsNull = false; + auto readAllPages = [&](const DataPageStats& stats) -> bool { + isStatsNull = stats.encodedStatistics == nullptr; + ++numPages; return false; }; - this->page_reader_->set_data_page_filter(read_all_pages); - std::shared_ptr current_page = this->page_reader_->NextPage(); - ASSERT_EQ(num_pages, 1); - ASSERT_EQ(is_stats_null, true); - ASSERT_EQ(this->page_reader_->NextPage(), nullptr); + this->pageReader_->setDataPageFilter(readAllPages); + std::shared_ptr currentPage = this->pageReader_->nextPage(); + ASSERT_EQ(numPages, 1); + ASSERT_EQ(isStatsNull, true); + ASSERT_EQ(this->pageReader_->nextPage(), nullptr); } -// Creates a number of pages and skips some of them with the page filter -// callback. +// Creates a number of pages and skips some of them with the page filter. +// Callback. TYPED_TEST(PageFilterTest, TestPageFilterCallback) { - this->WriteStream(); + this->writeStream(); { // Read all pages. - // Also check that the encoded statistics passed to the callback function - // are right. - auto stream = - std::make_shared<::arrow::io::BufferReader>(this->out_buffer_); - this->page_reader_ = - PageReader::Open(stream, this->total_rows_, Compression::UNCOMPRESSED); - - std::vector read_stats; - std::vector read_num_values; - std::vector> read_num_rows; - auto read_all_pages = [&](const DataPageStats& stats) -> bool { - VELOX_DCHECK_NOT_NULL(stats.encoded_statistics); - read_stats.push_back(*stats.encoded_statistics); - read_num_values.push_back(stats.num_values); - read_num_rows.push_back(stats.num_rows); + // Also check that the encoded statistics passed to the callback function. + // Are right. + auto stream = std::make_shared<::arrow::io::BufferReader>(this->outBuffer_); + this->pageReader_ = + PageReader::open(stream, this->totalRows_, Compression::UNCOMPRESSED); + + std::vector readStats; + std::vector readNumValues; + std::vector> readNumRows; + auto readAllPages = [&](const DataPageStats& stats) -> bool { + VELOX_DCHECK_NOT_NULL(stats.encodedStatistics); + readStats.push_back(*stats.encodedStatistics); + readNumValues.push_back(stats.numValues); + readNumRows.push_back(stats.numRows); return false; }; - this->page_reader_->set_data_page_filter(read_all_pages); + this->pageReader_->setDataPageFilter(readAllPages); for (int i = 0; i < this->kNumPages; ++i) { - std::shared_ptr current_page = this->page_reader_->NextPage(); - ASSERT_NE(current_page, nullptr); + std::shared_ptr currentPage = this->pageReader_->nextPage(); + ASSERT_NE(currentPage, nullptr); ASSERT_NO_FATAL_FAILURE( - CheckDataPageHeader(this->data_page_headers_[i], current_page.get())); - auto data_page = static_cast(current_page.get()); - const EncodedStatistics encoded_statistics = data_page->statistics(); - ASSERT_EQ(read_stats[i].max(), encoded_statistics.max()); - ASSERT_EQ(read_stats[i].min(), encoded_statistics.min()); - ASSERT_EQ(read_stats[i].null_count, encoded_statistics.null_count); - ASSERT_EQ( - read_stats[i].distinct_count, encoded_statistics.distinct_count); - ASSERT_EQ(read_num_values[i], this->data_page_headers_[i].num_values); - this->CheckNumRows(read_num_rows[i], this->data_page_headers_[i]); + checkDataPageHeader(this->dataPageHeaders_[i], currentPage.get())); + auto dataPage = static_cast(currentPage.get()); + const EncodedStatistics EncodedStatistics = dataPage->statistics(); + ASSERT_EQ(readStats[i].max(), EncodedStatistics.max()); + ASSERT_EQ(readStats[i].min(), EncodedStatistics.min()); + ASSERT_EQ(readStats[i].nullCount, EncodedStatistics.nullCount); + ASSERT_EQ(readStats[i].distinctCount, EncodedStatistics.distinctCount); + ASSERT_EQ(readNumValues[i], this->dataPageHeaders_[i].num_values); + this->checkNumRows(readNumRows[i], this->dataPageHeaders_[i]); } - ASSERT_EQ(this->page_reader_->NextPage(), nullptr); + ASSERT_EQ(this->pageReader_->nextPage(), nullptr); } { // Skip all pages. - auto stream = - std::make_shared<::arrow::io::BufferReader>(this->out_buffer_); - this->page_reader_ = - PageReader::Open(stream, this->total_rows_, Compression::UNCOMPRESSED); + auto stream = std::make_shared<::arrow::io::BufferReader>(this->outBuffer_); + this->pageReader_ = + PageReader::open(stream, this->totalRows_, Compression::UNCOMPRESSED); - auto skip_all_pages = [](const DataPageStats& stats) -> bool { - return true; - }; + auto skipAllPages = [](const DataPageStats& stats) -> bool { return true; }; - this->page_reader_->set_data_page_filter(skip_all_pages); - std::shared_ptr current_page = this->page_reader_->NextPage(); - ASSERT_EQ(this->page_reader_->NextPage(), nullptr); + this->pageReader_->setDataPageFilter(skipAllPages); + std::shared_ptr currentPage = this->pageReader_->nextPage(); + ASSERT_EQ(this->pageReader_->nextPage(), nullptr); } { // Skip every other page. - auto stream = - std::make_shared<::arrow::io::BufferReader>(this->out_buffer_); - this->page_reader_ = - PageReader::Open(stream, this->total_rows_, Compression::UNCOMPRESSED); + auto stream = std::make_shared<::arrow::io::BufferReader>(this->outBuffer_); + this->pageReader_ = + PageReader::open(stream, this->totalRows_, Compression::UNCOMPRESSED); // Skip pages with even number of values. - auto skip_even_pages = [](const DataPageStats& stats) -> bool { - if (stats.num_values % 2 == 0) + auto skipEvenPages = [](const DataPageStats& stats) -> bool { + if (stats.numValues % 2 == 0) return true; return false; }; - this->page_reader_->set_data_page_filter(skip_even_pages); + this->pageReader_->setDataPageFilter(skipEvenPages); for (int i = 0; i < this->kNumPages; ++i) { // Only pages with odd number of values are read. if (i % 2 != 0) { - std::shared_ptr current_page = this->page_reader_->NextPage(); - ASSERT_NE(current_page, nullptr); - ASSERT_NO_FATAL_FAILURE(CheckDataPageHeader( - this->data_page_headers_[i], current_page.get())); + std::shared_ptr currentPage = this->pageReader_->nextPage(); + ASSERT_NE(currentPage, nullptr); + ASSERT_NO_FATAL_FAILURE( + checkDataPageHeader(this->dataPageHeaders_[i], currentPage.get())); } } // We should have exhausted reading the pages by reading the odd pages only. - ASSERT_EQ(this->page_reader_->NextPage(), nullptr); + ASSERT_EQ(this->pageReader_->nextPage(), nullptr); } } -// Set the page filter more than once. The new filter should be effective -// on the next NextPage() call. +// Set the page filter more than once. The new filter should be effective. +// On the next NextPage() call. TYPED_TEST(PageFilterTest, TestChangingPageFilter) { - this->WriteStream(); + this->writeStream(); - auto stream = std::make_shared<::arrow::io::BufferReader>(this->out_buffer_); - this->page_reader_ = - PageReader::Open(stream, this->total_rows_, Compression::UNCOMPRESSED); + auto stream = std::make_shared<::arrow::io::BufferReader>(this->outBuffer_); + this->pageReader_ = + PageReader::open(stream, this->totalRows_, Compression::UNCOMPRESSED); // This callback will always return false. - auto read_all_pages = [](const DataPageStats& stats) -> bool { - return false; - }; - this->page_reader_->set_data_page_filter(read_all_pages); - std::shared_ptr current_page = this->page_reader_->NextPage(); - ASSERT_NE(current_page, nullptr); + auto readAllPages = [](const DataPageStats& stats) -> bool { return false; }; + this->pageReader_->setDataPageFilter(readAllPages); + std::shared_ptr currentPage = this->pageReader_->nextPage(); + ASSERT_NE(currentPage, nullptr); ASSERT_NO_FATAL_FAILURE( - CheckDataPageHeader(this->data_page_headers_[0], current_page.get())); + checkDataPageHeader(this->dataPageHeaders_[0], currentPage.get())); // This callback will skip all pages. - auto skip_all_pages = [](const DataPageStats& stats) -> bool { return true; }; - this->page_reader_->set_data_page_filter(skip_all_pages); - ASSERT_EQ(this->page_reader_->NextPage(), nullptr); + auto skipAllPages = [](const DataPageStats& stats) -> bool { return true; }; + this->pageReader_->setDataPageFilter(skipAllPages); + ASSERT_EQ(this->pageReader_->nextPage(), nullptr); } // Test that we do not skip dictionary pages. TEST_F(TestPageSerde, DoesNotFilterDictionaryPages) { - int data_size = 1024; - std::vector faux_data(data_size); + int dataSize = 1024; + std::vector fauxData(dataSize); - ASSERT_NO_FATAL_FAILURE( - WriteDataPageHeader(/*max_serialized_len=*/1024, data_size, data_size)); - ASSERT_OK(out_stream_->Write(faux_data.data(), data_size)); + ASSERT_NO_FATAL_FAILURE(writeDataPageHeader(1024, dataSize, dataSize)); + ASSERT_OK(outStream_->Write(fauxData.data(), dataSize)); - ASSERT_NO_FATAL_FAILURE(WriteDictionaryPageHeader(data_size, data_size)); - ASSERT_OK(out_stream_->Write(faux_data.data(), data_size)); + ASSERT_NO_FATAL_FAILURE(writeDictionaryPageHeader(dataSize, dataSize)); + ASSERT_OK(outStream_->Write(fauxData.data(), dataSize)); - ASSERT_NO_FATAL_FAILURE( - WriteDataPageHeader(/*max_serialized_len=*/1024, data_size, data_size)); - ASSERT_OK(out_stream_->Write(faux_data.data(), data_size)); - EndStream(); + ASSERT_NO_FATAL_FAILURE(writeDataPageHeader(1024, dataSize, dataSize)); + ASSERT_OK(outStream_->Write(fauxData.data(), dataSize)); + endStream(); // Try to read it back while asking for all data pages to be skipped. - auto stream = std::make_shared<::arrow::io::BufferReader>(out_buffer_); - page_reader_ = - PageReader::Open(stream, /*num_rows=*/100, Compression::UNCOMPRESSED); + auto stream = std::make_shared<::arrow::io::BufferReader>(outBuffer_); + pageReader_ = PageReader::open(stream, 100, Compression::UNCOMPRESSED); - auto skip_all_pages = [](const DataPageStats& stats) -> bool { return true; }; + auto skipAllPages = [](const DataPageStats& stats) -> bool { return true; }; - page_reader_->set_data_page_filter(skip_all_pages); + pageReader_->setDataPageFilter(skipAllPages); // The first data page is skipped, so we are now at the dictionary page. - std::shared_ptr current_page = page_reader_->NextPage(); - ASSERT_NE(current_page, nullptr); - ASSERT_EQ(current_page->type(), PageType::DICTIONARY_PAGE); + std::shared_ptr currentPage = pageReader_->nextPage(); + ASSERT_NE(currentPage, nullptr); + ASSERT_EQ(currentPage->type(), PageType::kDictionaryPage); // The data page after dictionary page is skipped. - ASSERT_EQ(page_reader_->NextPage(), nullptr); + ASSERT_EQ(pageReader_->nextPage(), nullptr); } // Tests that we successfully skip non-data pages. TEST_F(TestPageSerde, SkipsNonDataPages) { - int data_size = 1024; - std::vector faux_data(data_size); - ASSERT_NO_FATAL_FAILURE(WriteIndexPageHeader(data_size, data_size)); - ASSERT_OK(out_stream_->Write(faux_data.data(), data_size)); + int dataSize = 1024; + std::vector fauxData(dataSize); + ASSERT_NO_FATAL_FAILURE(writeIndexPageHeader(dataSize, dataSize)); + ASSERT_OK(outStream_->Write(fauxData.data(), dataSize)); - ASSERT_NO_FATAL_FAILURE( - WriteDataPageHeader(/*max_serialized_len=*/1024, data_size, data_size)); - ASSERT_OK(out_stream_->Write(faux_data.data(), data_size)); + ASSERT_NO_FATAL_FAILURE(writeDataPageHeader(1024, dataSize, dataSize)); + ASSERT_OK(outStream_->Write(fauxData.data(), dataSize)); - ASSERT_NO_FATAL_FAILURE(WriteIndexPageHeader(data_size, data_size)); - ASSERT_OK(out_stream_->Write(faux_data.data(), data_size)); - ASSERT_NO_FATAL_FAILURE(WriteIndexPageHeader(data_size, data_size)); - ASSERT_OK(out_stream_->Write(faux_data.data(), data_size)); + ASSERT_NO_FATAL_FAILURE(writeIndexPageHeader(dataSize, dataSize)); + ASSERT_OK(outStream_->Write(fauxData.data(), dataSize)); + ASSERT_NO_FATAL_FAILURE(writeIndexPageHeader(dataSize, dataSize)); + ASSERT_OK(outStream_->Write(fauxData.data(), dataSize)); - ASSERT_NO_FATAL_FAILURE( - WriteDataPageHeader(/*max_serialized_len=*/1024, data_size, data_size)); - ASSERT_OK(out_stream_->Write(faux_data.data(), data_size)); - ASSERT_NO_FATAL_FAILURE(WriteIndexPageHeader(data_size, data_size)); - ASSERT_OK(out_stream_->Write(faux_data.data(), data_size)); - EndStream(); + ASSERT_NO_FATAL_FAILURE(writeDataPageHeader(1024, dataSize, dataSize)); + ASSERT_OK(outStream_->Write(fauxData.data(), dataSize)); + ASSERT_NO_FATAL_FAILURE(writeIndexPageHeader(dataSize, dataSize)); + ASSERT_OK(outStream_->Write(fauxData.data(), dataSize)); + endStream(); - auto stream = std::make_shared<::arrow::io::BufferReader>(out_buffer_); - page_reader_ = - PageReader::Open(stream, /*num_rows=*/100, Compression::UNCOMPRESSED); + auto stream = std::make_shared<::arrow::io::BufferReader>(outBuffer_); + pageReader_ = PageReader::open(stream, 100, Compression::UNCOMPRESSED); // Only the two data pages are returned. - std::shared_ptr current_page = page_reader_->NextPage(); - ASSERT_EQ(current_page->type(), PageType::DATA_PAGE); - current_page = page_reader_->NextPage(); - ASSERT_EQ(current_page->type(), PageType::DATA_PAGE); - ASSERT_EQ(page_reader_->NextPage(), nullptr); + std::shared_ptr currentPage = pageReader_->nextPage(); + ASSERT_EQ(currentPage->type(), PageType::kDataPage); + currentPage = pageReader_->nextPage(); + ASSERT_EQ(currentPage->type(), PageType::kDataPage); + ASSERT_EQ(pageReader_->nextPage(), nullptr); } TEST_F(TestPageSerde, DataPageV2) { - int stats_size = 512; - const int32_t num_rows = 4444; - AddDummyStats(stats_size, data_page_header_v2_, /* fill_all_stats = */ true); - data_page_header_v2_.num_values = num_rows; - - ASSERT_NO_FATAL_FAILURE(WriteDataPageHeaderV2()); - InitSerializedPageReader(num_rows); - std::shared_ptr current_page = page_reader_->NextPage(); + int statsSize = 512; + const int32_t numRows = 4444; + addDummyStats(statsSize, dataPageHeaderV2_, /*fill_all_stats=*/true); + dataPageHeaderV2_.num_values = numRows; + + ASSERT_NO_FATAL_FAILURE(writeDataPageHeaderV2()); + initSerializedPageReader(numRows); + std::shared_ptr currentPage = pageReader_->nextPage(); ASSERT_NO_FATAL_FAILURE( - CheckDataPageHeader(data_page_header_v2_, current_page.get())); + checkDataPageHeader(dataPageHeaderV2_, currentPage.get())); } TEST_F(TestPageSerde, TestLargePageHeaders) { - int stats_size = 256 * 1024; // 256 KB - AddDummyStats(stats_size, data_page_header_); + int statsSize = 256 * 1024; // 256 KB + addDummyStats(statsSize, dataPageHeader_); - // Any number to verify metadata roundtrip - const int32_t num_rows = 4141; - data_page_header_.num_values = num_rows; + // Any number to verify metadata roundtrip. + const int32_t numRows = 4141; + dataPageHeader_.num_values = numRows; - int max_header_size = 512 * 1024; // 512 KB - ASSERT_NO_FATAL_FAILURE(WriteDataPageHeader(max_header_size)); + int maxHeaderSize = 512 * 1024; // 512 KB + ASSERT_NO_FATAL_FAILURE(writeDataPageHeader(maxHeaderSize)); - ASSERT_OK_AND_ASSIGN(int64_t position, out_stream_->Tell()); - ASSERT_GE(max_header_size, position); + ASSERT_OK_AND_ASSIGN(int64_t position, outStream_->Tell()); + ASSERT_GE(maxHeaderSize, position); - // check header size is between 256 KB to 16 MB - ASSERT_LE(stats_size, position); + // Check header size is between 256 KB to 16 MB. + ASSERT_LE(statsSize, position); ASSERT_GE(kDefaultMaxPageHeaderSize, position); - InitSerializedPageReader(num_rows); - std::shared_ptr current_page = page_reader_->NextPage(); + initSerializedPageReader(numRows); + std::shared_ptr currentPage = pageReader_->nextPage(); ASSERT_NO_FATAL_FAILURE( - CheckDataPageHeader(data_page_header_, current_page.get())); + checkDataPageHeader(dataPageHeader_, currentPage.get())); } TEST_F(TestPageSerde, TestFailLargePageHeaders) { - const int32_t num_rows = 1337; // dummy value + const int32_t numRows = 1337; // dummy value - int stats_size = 256 * 1024; // 256 KB - AddDummyStats(stats_size, data_page_header_); + int statsSize = 256 * 1024; // 256 KB + addDummyStats(statsSize, dataPageHeader_); - // Serialize the Page header - int max_header_size = 512 * 1024; // 512 KB - ASSERT_NO_FATAL_FAILURE(WriteDataPageHeader(max_header_size)); - ASSERT_OK_AND_ASSIGN(int64_t position, out_stream_->Tell()); - ASSERT_GE(max_header_size, position); + // Serialize the Page header. + int maxHeaderSize = 512 * 1024; // 512 KB + ASSERT_NO_FATAL_FAILURE(writeDataPageHeader(maxHeaderSize)); + ASSERT_OK_AND_ASSIGN(int64_t position, outStream_->Tell()); + ASSERT_GE(maxHeaderSize, position); - int smaller_max_size = 128 * 1024; - ASSERT_LE(smaller_max_size, position); - InitSerializedPageReader(num_rows); + int smallerMaxSize = 128 * 1024; + ASSERT_LE(smallerMaxSize, position); + initSerializedPageReader(numRows); - // Set the max page header size to 128 KB, which is less than the current - // header size - page_reader_->set_max_page_header_size(smaller_max_size); - ASSERT_THROW(page_reader_->NextPage(), ParquetException); + // Set the max page header size to 128 KB, which is less than the current. + // Header size. + pageReader_->setMaxPageHeaderSize(smallerMaxSize); + ASSERT_THROW(pageReader_->nextPage(), ParquetException); } -void TestPageSerde::TestPageCompressionRoundTrip( - const std::vector& page_sizes) { - auto codec_types = GetSupportedCodecTypes(); +void TestPageSerde::testPageCompressionRoundTrip( + const std::vector& pageSizes) { + auto codecTypes = getSupportedCodecTypes(); - const int32_t num_rows = 32; // dummy value - data_page_header_.num_values = num_rows; + const int32_t numRows = 32; // dummy value + dataPageHeader_.num_values = numRows; - std::vector> faux_data; - int num_pages = static_cast(page_sizes.size()); - faux_data.resize(num_pages); - for (int i = 0; i < num_pages; ++i) { - test::random_bytes(page_sizes[i], 0, &faux_data[i]); + std::vector> fauxData; + int numPages = static_cast(pageSizes.size()); + fauxData.resize(numPages); + for (int i = 0; i < numPages; ++i) { + test::randomBytes(pageSizes[i], 0, &fauxData[i]); } - for (auto codec_type : codec_types) { - auto codec = GetCodec(codec_type); + for (auto codecType : codecTypes) { + auto Codec = getCodec(codecType); std::vector buffer; - for (int i = 0; i < num_pages; ++i) { - const uint8_t* data = faux_data[i].data(); - int data_size = static_cast(faux_data[i].size()); + for (int i = 0; i < numPages; ++i) { + const uint8_t* data = fauxData[i].data(); + int dataSize = static_cast(fauxData[i].size()); - int64_t max_compressed_size = codec->MaxCompressedLen(data_size, data); - buffer.resize(max_compressed_size); + int64_t maxCompressedSize = Codec->maxCompressedLen(dataSize, data); + buffer.resize(maxCompressedSize); - int64_t actual_size; + int64_t actualSize; ASSERT_OK_AND_ASSIGN( - actual_size, - codec->Compress(data_size, data, max_compressed_size, &buffer[0])); + actualSize, + Codec->compress(dataSize, data, maxCompressedSize, &buffer[0])); - ASSERT_NO_FATAL_FAILURE(WriteDataPageHeader( - 1024, data_size, static_cast(actual_size))); - ASSERT_OK(out_stream_->Write(buffer.data(), actual_size)); + ASSERT_NO_FATAL_FAILURE(writeDataPageHeader( + 1024, dataSize, static_cast(actualSize))); + ASSERT_OK(outStream_->Write(buffer.data(), actualSize)); } - InitSerializedPageReader(num_rows * num_pages, codec_type); + initSerializedPageReader(numRows * numPages, codecType); std::shared_ptr page; - const DataPageV1* data_page; - for (int i = 0; i < num_pages; ++i) { - int data_size = static_cast(faux_data[i].size()); - page = page_reader_->NextPage(); - data_page = static_cast(page.get()); - ASSERT_EQ(data_size, data_page->size()); - ASSERT_EQ(0, memcmp(faux_data[i].data(), data_page->data(), data_size)); + const DataPageV1* dataPage; + for (int i = 0; i < numPages; ++i) { + int dataSize = static_cast(fauxData[i].size()); + page = pageReader_->nextPage(); + dataPage = static_cast(page.get()); + ASSERT_EQ(dataSize, dataPage->size()); + ASSERT_EQ(0, memcmp(fauxData[i].data(), dataPage->data(), dataSize)); } - ResetStream(); + resetStream(); } } TEST_F(TestPageSerde, Compression) { - std::vector page_sizes; - page_sizes.reserve(10); + std::vector pageSizes; + pageSizes.reserve(10); for (int i = 0; i < 10; ++i) { - // The pages keep getting larger - page_sizes.push_back((i + 1) * 64); + // The pages keep getting larger. + pageSizes.push_back((i + 1) * 64); } - this->TestPageCompressionRoundTrip(page_sizes); + this->testPageCompressionRoundTrip(pageSizes); } TEST_F(TestPageSerde, PageSizeResetWhenRead) { - // GH-35423: Parquet SerializedPageReader need to - // reset the size after getting a smaller page. - std::vector page_sizes; - page_sizes.reserve(10); + // GH-35423: Parquet SerializedPageReader need to. + // Reset the size after getting a smaller page. + std::vector pageSizes; + pageSizes.reserve(10); for (int i = 0; i < 10; ++i) { - // The pages keep getting smaller - page_sizes.push_back((10 - i) * 64); + // The pages keep getting smaller. + pageSizes.push_back((10 - i) * 64); } - this->TestPageCompressionRoundTrip(page_sizes); + this->testPageCompressionRoundTrip(pageSizes); } TEST_F(TestPageSerde, LZONotSupported) { - // Must await PARQUET-530 - int data_size = 1024; - std::vector faux_data(data_size); - ASSERT_NO_FATAL_FAILURE(WriteDataPageHeader(1024, data_size, data_size)); - ASSERT_OK(out_stream_->Write(faux_data.data(), data_size)); + // Must await PARQUET-530. + int dataSize = 1024; + std::vector fauxData(dataSize); + ASSERT_NO_FATAL_FAILURE(writeDataPageHeader(1024, dataSize, dataSize)); + ASSERT_OK(outStream_->Write(fauxData.data(), dataSize)); ASSERT_THROW( - InitSerializedPageReader(data_size, Compression::LZO), ParquetException); + initSerializedPageReader(dataSize, Compression::LZO), ParquetException); } TEST_F(TestPageSerde, NoCrc) { - int stats_size = 512; - const int32_t num_rows = 4444; - AddDummyStats(stats_size, data_page_header_, /*fill_all_stats=*/true); - data_page_header_.num_values = num_rows; - - ASSERT_NO_FATAL_FAILURE(WriteDataPageHeader()); - ReaderProperties readerProperties; - readerProperties.set_page_checksum_verification(true); - InitSerializedPageReader( - num_rows, Compression::UNCOMPRESSED, readerProperties); - std::shared_ptr current_page = page_reader_->NextPage(); + int statsSize = 512; + const int32_t numRows = 4444; + addDummyStats(statsSize, dataPageHeader_, true); + dataPageHeader_.num_values = numRows; + + ASSERT_NO_FATAL_FAILURE(writeDataPageHeader()); + ReaderProperties ReaderProperties; + ReaderProperties.setPageChecksumVerification(true); + initSerializedPageReader( + numRows, Compression::UNCOMPRESSED, ReaderProperties); + std::shared_ptr currentPage = pageReader_->nextPage(); ASSERT_NO_FATAL_FAILURE( - CheckDataPageHeader(data_page_header_, current_page.get())); + checkDataPageHeader(dataPageHeader_, currentPage.get())); } TEST_F(TestPageSerde, NoCrcDict) { - const int32_t num_rows = 4444; - dictionary_page_header_.num_values = num_rows; + const int32_t numRows = 4444; + dictionaryPageHeader_.num_values = numRows; - ASSERT_NO_FATAL_FAILURE(WriteDictionaryPageHeader()); - ReaderProperties readerProperties; - readerProperties.set_page_checksum_verification(true); - InitSerializedPageReader( - num_rows, Compression::UNCOMPRESSED, readerProperties); - std::shared_ptr current_page = page_reader_->NextPage(); + ASSERT_NO_FATAL_FAILURE(writeDictionaryPageHeader()); + ReaderProperties ReaderProperties; + ReaderProperties.setPageChecksumVerification(true); + initSerializedPageReader( + numRows, Compression::UNCOMPRESSED, ReaderProperties); + std::shared_ptr currentPage = pageReader_->nextPage(); - ASSERT_EQ(PageType::DICTIONARY_PAGE, current_page->type()); + ASSERT_EQ(PageType::kDictionaryPage, currentPage->type()); - const auto* dict_page = - static_cast(current_page.get()); - EXPECT_EQ(num_rows, dict_page->num_values()); + const auto* dictPage = static_cast(currentPage.get()); + EXPECT_EQ(numRows, dictPage->numValues()); } TEST_F(TestPageSerde, CrcCheckSuccessful) { - this->TestPageSerdeCrc( + this->testPageSerdeCrc( /* write_checksum */ true, /* write_page_corrupt */ false, /* verification_checksum */ true); } TEST_F(TestPageSerde, CrcCheckFail) { - this->TestPageSerdeCrc( + this->testPageSerdeCrc( /* write_checksum */ true, /* write_page_corrupt */ true, /* verification_checksum */ true); } TEST_F(TestPageSerde, CrcCorruptNotChecked) { - this->TestPageSerdeCrc( + this->testPageSerdeCrc( /* write_checksum */ true, /* write_page_corrupt */ true, /* verification_checksum */ false); } TEST_F(TestPageSerde, CrcCheckNonExistent) { - this->TestPageSerdeCrc( + this->testPageSerdeCrc( /* write_checksum */ false, /* write_page_corrupt */ false, /* verification_checksum */ true); } TEST_F(TestPageSerde, DictCrcCheckSuccessful) { - this->TestPageSerdeCrc( + this->testPageSerdeCrc( /* write_checksum */ true, /* write_page_corrupt */ false, /* verification_checksum */ true, @@ -928,7 +898,7 @@ TEST_F(TestPageSerde, DictCrcCheckSuccessful) { } TEST_F(TestPageSerde, DictCrcCheckFail) { - this->TestPageSerdeCrc( + this->testPageSerdeCrc( /* write_checksum */ true, /* write_page_corrupt */ true, /* verification_checksum */ true, @@ -936,7 +906,7 @@ TEST_F(TestPageSerde, DictCrcCheckFail) { } TEST_F(TestPageSerde, DictCrcCorruptNotChecked) { - this->TestPageSerdeCrc( + this->testPageSerdeCrc( /* write_checksum */ true, /* write_page_corrupt */ true, /* verification_checksum */ false, @@ -944,7 +914,7 @@ TEST_F(TestPageSerde, DictCrcCorruptNotChecked) { } TEST_F(TestPageSerde, DictCrcCheckNonExistent) { - this->TestPageSerdeCrc( + this->testPageSerdeCrc( /* write_checksum */ false, /* write_page_corrupt */ false, /* verification_checksum */ true, @@ -952,7 +922,7 @@ TEST_F(TestPageSerde, DictCrcCheckNonExistent) { } TEST_F(TestPageSerde, DataPageV2CrcCheckSuccessful) { - this->TestPageSerdeCrc( + this->testPageSerdeCrc( /* write_checksum */ true, /* write_page_corrupt */ false, /* verification_checksum */ true, @@ -961,7 +931,7 @@ TEST_F(TestPageSerde, DataPageV2CrcCheckSuccessful) { } TEST_F(TestPageSerde, DataPageV2CrcCheckFail) { - this->TestPageSerdeCrc( + this->testPageSerdeCrc( /* write_checksum */ true, /* write_page_corrupt */ true, /* verification_checksum */ true, @@ -970,7 +940,7 @@ TEST_F(TestPageSerde, DataPageV2CrcCheckFail) { } TEST_F(TestPageSerde, DataPageV2CrcCorruptNotChecked) { - this->TestPageSerdeCrc( + this->testPageSerdeCrc( /* write_checksum */ true, /* write_page_corrupt */ true, /* verification_checksum */ false, @@ -979,7 +949,7 @@ TEST_F(TestPageSerde, DataPageV2CrcCorruptNotChecked) { } TEST_F(TestPageSerde, DataPageV2CrcCheckNonExistent) { - this->TestPageSerdeCrc( + this->testPageSerdeCrc( /* write_checksum */ false, /* write_page_corrupt */ false, /* verification_checksum */ true, @@ -987,22 +957,26 @@ TEST_F(TestPageSerde, DataPageV2CrcCheckNonExistent) { /* write_data_page_v2 */ true); } -// ---------------------------------------------------------------------- -// File structure tests +// ----------------------------------------------------------------------. +// File structure tests. class TestParquetFileReader : public ::testing::Test { public: - void AssertInvalidFileThrows(const std::shared_ptr& buffer) { + void assertInvalidFileThrows(const std::shared_ptr& buffer) { auto reader = std::make_shared(buffer); - // Write the buffer to a temp file path - auto filePath = exec::test::TempFilePath::create(); + // Write the buffer to a temp file path. + auto filePath = common::testutil::TempFilePath::create(); writeToFile(filePath, buffer); memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); std::shared_ptr rootPool = memory::memoryManager()->addRootPool("MetadataTest"); std::shared_ptr leafPool = rootPool->addLeafChild("MetadataTest"); - dwio::common::ReaderOptions readerOptions{leafPool.get()}; + auto dataIoStats = std::make_shared(); + auto metadataIoStats = std::make_shared(); + dwio::common::ReaderOptions readerOptions(leafPool.get()); + readerOptions.setDataIoStats(dataIoStats); + readerOptions.setMetadataIoStats(metadataIoStats); auto input = std::make_unique( std::make_shared(filePath->getPath()), readerOptions.memoryPool()); @@ -1014,26 +988,26 @@ class TestParquetFileReader : public ::testing::Test { }; TEST_F(TestParquetFileReader, InvalidHeader) { - const char* bad_header = "PAR2"; + const char* badHeader = "PAR2"; - auto buffer = Buffer::Wrap(bad_header, strlen(bad_header)); - ASSERT_NO_FATAL_FAILURE(AssertInvalidFileThrows(buffer)); + auto buffer = Buffer::Wrap(badHeader, strlen(badHeader)); + ASSERT_NO_FATAL_FAILURE(assertInvalidFileThrows(buffer)); } TEST_F(TestParquetFileReader, InvalidFooter) { - // File is smaller than FOOTER_SIZE - const char* bad_file = "PAR1PAR"; - auto buffer = Buffer::Wrap(bad_file, strlen(bad_file)); - ASSERT_NO_FATAL_FAILURE(AssertInvalidFileThrows(buffer)); - - // Magic number incorrect - const char* bad_file2 = "PAR1PAR2"; - buffer = Buffer::Wrap(bad_file2, strlen(bad_file2)); - ASSERT_NO_FATAL_FAILURE(AssertInvalidFileThrows(buffer)); + // File is smaller than FOOTER_SIZE. + const char* badFile = "PAR1PAR"; + auto buffer = Buffer::Wrap(badFile, strlen(badFile)); + ASSERT_NO_FATAL_FAILURE(assertInvalidFileThrows(buffer)); + + // Magic number incorrect. + const char* badFile2 = "PAR1PAR2"; + buffer = Buffer::Wrap(badFile2, strlen(badFile2)); + ASSERT_NO_FATAL_FAILURE(assertInvalidFileThrows(buffer)); } TEST_F(TestParquetFileReader, IncompleteMetadata) { - auto stream = CreateOutputStream(); + auto stream = createOutputStream(); const char* magic = "PAR1"; @@ -1041,24 +1015,28 @@ TEST_F(TestParquetFileReader, IncompleteMetadata) { stream->Write(reinterpret_cast(magic), strlen(magic))); std::vector bytes(10); ASSERT_OK(stream->Write(bytes.data(), bytes.size())); - uint32_t metadata_len = 24; + uint32_t metadataLen = 24; ASSERT_OK(stream->Write( - reinterpret_cast(&metadata_len), sizeof(uint32_t))); + reinterpret_cast(&metadataLen), sizeof(uint32_t))); ASSERT_OK( stream->Write(reinterpret_cast(magic), strlen(magic))); ASSERT_OK_AND_ASSIGN(auto buffer, stream->Finish()); auto reader = std::make_shared(buffer); - // Write the buffer to a temp file path - auto filePath = exec::test::TempFilePath::create(); + // Write the buffer to a temp file path. + auto filePath = TempFilePath::create(); writeToFile(filePath, buffer); memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); std::shared_ptr rootPool = memory::memoryManager()->addRootPool("MetadataTest"); std::shared_ptr leafPool = rootPool->addLeafChild("MetadataTest"); - dwio::common::ReaderOptions readerOptions{leafPool.get()}; + auto dataIoStats = std::make_shared(); + auto metadataIoStats = std::make_shared(); + dwio::common::ReaderOptions readerOptions(leafPool.get()); + readerOptions.setDataIoStats(dataIoStats); + readerOptions.setMetadataIoStats(metadataIoStats); auto input = std::make_unique( std::make_shared(filePath->getPath()), readerOptions.memoryPool()); diff --git a/velox/dwio/parquet/writer/arrow/tests/FileReader.cpp b/velox/dwio/parquet/writer/arrow/tests/FileReader.cpp index b1a8f3148bf..b80f25f1c96 100644 --- a/velox/dwio/parquet/writer/arrow/tests/FileReader.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/FileReader.cpp @@ -55,64 +55,63 @@ using arrow::internal::AddWithOverflow; namespace facebook::velox::parquet::arrow { -// PARQUET-978: Minimize footer reads by reading 64 KB from the end of the file +// PARQUET-978: Minimize footer reads by reading 64 KB from the end of the file. static constexpr int64_t kDefaultFooterReadSize = 64 * 1024; static constexpr uint32_t kFooterSize = 8; -// For PARQUET-816 +// For PARQUET-816. static constexpr int64_t kMaxDictHeaderSize = 100; -// ---------------------------------------------------------------------- -// RowGroupReader public API +// ----------------------------------------------------------------------. +// RowGroupReader public API. RowGroupReader::RowGroupReader(std::unique_ptr contents) : contents_(std::move(contents)) {} -std::shared_ptr RowGroupReader::Column(int i) { - if (i >= metadata()->num_columns()) { +std::shared_ptr RowGroupReader::column(int i) { + if (i >= metadata()->numColumns()) { std::stringstream ss; ss << "Trying to read column index " << i - << " but row group metadata has only " << metadata()->num_columns() + << " but row group metadata has only " << metadata()->numColumns() << " columns"; throw ParquetException(ss.str()); } - const ColumnDescriptor* descr = metadata()->schema()->Column(i); + const ColumnDescriptor* descr = metadata()->schema()->column(i); - std::unique_ptr page_reader = contents_->GetColumnPageReader(i); - return ColumnReader::Make( + std::unique_ptr PageReader = contents_->getColumnPageReader(i); + return ColumnReader::make( descr, - std::move(page_reader), - const_cast(contents_->properties())->memory_pool()); + std::move(PageReader), + const_cast(contents_->properties())->memoryPool()); } -std::shared_ptr RowGroupReader::ColumnWithExposeEncoding( +std::shared_ptr RowGroupReader::columnWithExposeEncoding( int i, - ExposedEncoding encoding_to_expose) { - std::shared_ptr reader = Column(i); + ExposedEncoding encodingToExpose) { + std::shared_ptr reader = column(i); - if (encoding_to_expose == ExposedEncoding::DICTIONARY) { + if (encodingToExpose == ExposedEncoding::kDictionary) { // Check the encoding_stats to see if all data pages are dictionary encoded. - std::unique_ptr col = metadata()->ColumnChunk(i); - const std::vector& encoding_stats = - col->encoding_stats(); - if (encoding_stats.empty()) { - // Some parquet files may have empty encoding_stats. In this case we are - // not sure whether all data pages are dictionary encoded. So we do not - // enable exposing dictionary. + std::unique_ptr col = metadata()->columnChunk(i); + const std::vector& encodingStats = col->encodingStats(); + if (encodingStats.empty()) { + // Some parquet files may have empty encoding_stats. In this case we are. + // Not sure whether all data pages are dictionary encoded. So we do not. + // Enable exposing dictionary. return reader; } // The 1st page should be the dictionary page. - if (encoding_stats[0].page_type != PageType::DICTIONARY_PAGE || - (encoding_stats[0].encoding != Encoding::PLAIN && - encoding_stats[0].encoding != Encoding::PLAIN_DICTIONARY)) { + if (encodingStats[0].pageType != PageType::kDictionaryPage || + (encodingStats[0].encoding != Encoding::kPlain && + encodingStats[0].encoding != Encoding::kPlainDictionary)) { return reader; } // The following pages should be dictionary encoded data pages. - for (size_t idx = 1; idx < encoding_stats.size(); ++idx) { - if ((encoding_stats[idx].encoding != Encoding::RLE_DICTIONARY && - encoding_stats[idx].encoding != Encoding::PLAIN_DICTIONARY) || - (encoding_stats[idx].page_type != PageType::DATA_PAGE && - encoding_stats[idx].page_type != PageType::DATA_PAGE_V2)) { + for (size_t idx = 1; idx < encodingStats.size(); ++idx) { + if ((encodingStats[idx].encoding != Encoding::kRleDictionary && + encodingStats[idx].encoding != Encoding::kPlainDictionary) || + (encodingStats[idx].pageType != PageType::kDataPage && + encodingStats[idx].pageType != PageType::kDataPageV2)) { return reader; } } @@ -122,138 +121,135 @@ std::shared_ptr RowGroupReader::ColumnWithExposeEncoding( } // Set exposed encoding. - reader->SetExposedEncoding(encoding_to_expose); + reader->setExposedEncoding(encodingToExpose); return reader; } -std::unique_ptr RowGroupReader::GetColumnPageReader(int i) { - if (i >= metadata()->num_columns()) { +std::unique_ptr RowGroupReader::getColumnPageReader(int i) { + if (i >= metadata()->numColumns()) { std::stringstream ss; ss << "Trying to read column index " << i - << " but row group metadata has only " << metadata()->num_columns() + << " but row group metadata has only " << metadata()->numColumns() << " columns"; throw ParquetException(ss.str()); } - return contents_->GetColumnPageReader(i); + return contents_->getColumnPageReader(i); } -// Returns the rowgroup metadata +// Returns the rowgroup metadata. const RowGroupMetaData* RowGroupReader::metadata() const { return contents_->metadata(); } -/// Compute the section of the file that should be read for the given -/// row group and column chunk. -::arrow::io::ReadRange ComputeColumnChunkRange( - FileMetaData* file_metadata, - int64_t source_size, - int row_group_index, - int column_index) { - auto row_group_metadata = file_metadata->RowGroup(row_group_index); - auto column_metadata = row_group_metadata->ColumnChunk(column_index); - - int64_t col_start = column_metadata->data_page_offset(); - if (column_metadata->has_dictionary_page() && - column_metadata->dictionary_page_offset() > 0 && - col_start > column_metadata->dictionary_page_offset()) { - col_start = column_metadata->dictionary_page_offset(); +/// Compute the section of the file that should be read for the given. +/// Row group and column chunk. +::arrow::io::ReadRange computeColumnChunkRange( + FileMetaData* fileMetadata, + int64_t sourceSize, + int rowGroupIndex, + int columnIndex) { + auto rowGroupMetadata = fileMetadata->rowGroup(rowGroupIndex); + auto columnMetadata = rowGroupMetadata->columnChunk(columnIndex); + + int64_t colStart = columnMetadata->dataPageOffset(); + if (columnMetadata->hasDictionaryPage() && + columnMetadata->dictionaryPageOffset() > 0 && + colStart > columnMetadata->dictionaryPageOffset()) { + colStart = columnMetadata->dictionaryPageOffset(); } - int64_t col_length = column_metadata->total_compressed_size(); - int64_t col_end; - if (col_start < 0 || col_length < 0) { + int64_t colLength = columnMetadata->totalCompressedSize(); + int64_t colEnd; + if (colStart < 0 || colLength < 0) { throw ParquetException("Invalid column metadata (corrupt file?)"); } - if (AddWithOverflow(col_start, col_length, &col_end) || - col_end > source_size) { + if (AddWithOverflow(colStart, colLength, &colEnd) || colEnd > sourceSize) { throw ParquetException("Invalid column metadata (corrupt file?)"); } - // PARQUET-816 workaround for old files created by older parquet-mr - const ApplicationVersion& version = file_metadata->writer_version(); - if (version.VersionLt(ApplicationVersion::PARQUET_816_FIXED_VERSION())) { - // The Parquet MR writer had a bug in 1.2.8 and below where it didn't - // include the dictionary page header size in total_compressed_size and - // total_uncompressed_size (see IMPALA-694). We add padding to compensate. - int64_t bytes_remaining = source_size - col_end; - int64_t padding = std::min(kMaxDictHeaderSize, bytes_remaining); - col_length += padding; + // PARQUET-816 workaround for old files created by older parquet-mr. + const ApplicationVersion& version = fileMetadata->writerVersion(); + if (version.versionLt(ApplicationVersion::PARQUET_816_FIXED_VERSION())) { + // The Parquet MR writer had a bug in 1.2.8 and below where it didn't. + // Include the dictionary page header size in total_compressed_size and. + // Total_uncompressed_size (see IMPALA-694). We add padding to compensate. + int64_t bytesRemaining = sourceSize - colEnd; + int64_t padding = std::min(kMaxDictHeaderSize, bytesRemaining); + colLength += padding; } - return {col_start, col_length}; + return {colStart, colLength}; } -// RowGroupReader::Contents implementation for the Parquet file specification +// RowGroupReader::Contents implementation for the Parquet file specification. class SerializedRowGroup : public RowGroupReader::Contents { public: SerializedRowGroup( std::shared_ptr source, - std::shared_ptr<::arrow::io::internal::ReadRangeCache> cached_source, - int64_t source_size, - FileMetaData* file_metadata, - int row_group_number, + std::shared_ptr<::arrow::io::internal::ReadRangeCache> cachedSource, + int64_t sourceSize, + FileMetaData* fileMetadata, + int rowGroupNumber, const ReaderProperties& props, - std::shared_ptr prebuffered_column_chunks_bitmap, - std::shared_ptr file_decryptor = nullptr) + std::shared_ptr prebufferedColumnChunksBitmap, + std::shared_ptr fileDecryptor = nullptr) : source_(std::move(source)), - cached_source_(std::move(cached_source)), - source_size_(source_size), - file_metadata_(file_metadata), + cachedSource_(std::move(cachedSource)), + sourceSize_(sourceSize), + fileMetadata_(fileMetadata), properties_(props), - row_group_ordinal_(row_group_number), - prebuffered_column_chunks_bitmap_( - std::move(prebuffered_column_chunks_bitmap)), - file_decryptor_(file_decryptor) { - row_group_metadata_ = file_metadata->RowGroup(row_group_number); + rowGroupOrdinal_(rowGroupNumber), + prebufferedColumnChunksBitmap_( + std::move(prebufferedColumnChunksBitmap)), + fileDecryptor_(fileDecryptor) { + rowGroupMetadata_ = fileMetadata->rowGroup(rowGroupNumber); } const RowGroupMetaData* metadata() const override { - return row_group_metadata_.get(); + return rowGroupMetadata_.get(); } const ReaderProperties* properties() const override { return &properties_; } - std::unique_ptr GetColumnPageReader(int i) override { - // Read column chunk from the file - auto col = row_group_metadata_->ColumnChunk(i); + std::unique_ptr getColumnPageReader(int i) override { + // Read column chunk from the file. + auto col = rowGroupMetadata_->columnChunk(i); - ::arrow::io::ReadRange col_range = ComputeColumnChunkRange( - file_metadata_, source_size_, row_group_ordinal_, i); + ::arrow::io::ReadRange colRange = computeColumnChunkRange( + fileMetadata_, sourceSize_, rowGroupOrdinal_, i); std::shared_ptr stream; - if (cached_source_ && prebuffered_column_chunks_bitmap_ != nullptr && - ::arrow::bit_util::GetBit( - prebuffered_column_chunks_bitmap_->data(), i)) { - // PARQUET-1698: if read coalescing is enabled, read from pre-buffered - // segments. - PARQUET_ASSIGN_OR_THROW(auto buffer, cached_source_->Read(col_range)); + if (cachedSource_ && prebufferedColumnChunksBitmap_ != nullptr && + ::arrow::bit_util::GetBit(prebufferedColumnChunksBitmap_->data(), i)) { + // PARQUET-1698: if read coalescing is enabled, read from pre-buffered. + // Segments. + PARQUET_ASSIGN_OR_THROW(auto buffer, cachedSource_->Read(colRange)); stream = std::make_shared<::arrow::io::BufferReader>(buffer); } else { - stream = - properties_.GetStream(source_, col_range.offset, col_range.length); + stream = properties_.getStream(source_, colRange.offset, colRange.length); } - std::unique_ptr crypto_metadata = - col->crypto_metadata(); + std::unique_ptr cryptoMetadata = + col->cryptoMetadata(); - // Prior to Arrow 3.0.0, is_compressed was always set to false in column - // headers, even if compression was used. See ARROW-17100. - bool always_compressed = file_metadata_->writer_version().VersionLt( + // Prior to Arrow 3.0.0, is_compressed was always set to false in column. + // Headers, even if compression was used. See ARROW-17100. + bool alwaysCompressed = fileMetadata_->writerVersion().versionLt( ApplicationVersion::PARQUET_CPP_10353_FIXED_VERSION()); // Column is encrypted only if crypto_metadata exists. - if (!crypto_metadata) { - return PageReader::Open( + if (!cryptoMetadata) { + return PageReader::open( stream, - col->num_values(), + col->numValues(), col->compression(), properties_, - always_compressed); + alwaysCompressed); } - if (file_decryptor_ == nullptr) { + if (fileDecryptor_ == nullptr) { throw ParquetException( "RowGroup is noted as encrypted but no file decryptor"); } @@ -264,420 +260,415 @@ class SerializedRowGroup : public RowGroupReader::Contents { "Encrypted files cannot contain more than 32767 row groups"); } - // The column is encrypted - std::shared_ptr meta_decryptor; - std::shared_ptr data_decryptor; - // The column is encrypted with footer key - if (crypto_metadata->encrypted_with_footer_key()) { - meta_decryptor = file_decryptor_->GetFooterDecryptorForColumnMeta(); - data_decryptor = file_decryptor_->GetFooterDecryptorForColumnData(); + // The column is encrypted. + std::shared_ptr metaDecryptor; + std::shared_ptr dataDecryptor; + // The column is encrypted with footer key. + if (cryptoMetadata->encryptedWithFooterKey()) { + metaDecryptor = fileDecryptor_->getFooterDecryptorForColumnMeta(); + dataDecryptor = fileDecryptor_->getFooterDecryptorForColumnData(); CryptoContext ctx( - col->has_dictionary_page(), - row_group_ordinal_, + col->hasDictionaryPage(), + rowGroupOrdinal_, static_cast(i), - meta_decryptor, - data_decryptor); - return PageReader::Open( + metaDecryptor, + dataDecryptor); + return PageReader::open( stream, - col->num_values(), + col->numValues(), col->compression(), properties_, - always_compressed, + alwaysCompressed, &ctx); } - // The column is encrypted with its own key - std::string column_key_metadata = crypto_metadata->key_metadata(); - const std::string column_path = - crypto_metadata->path_in_schema()->ToDotString(); + // The column is encrypted with its own key. + std::string columnKeyMetadata = cryptoMetadata->keyMetadata(); + const std::string ColumnPath = + cryptoMetadata->pathInSchema()->toDotString(); - meta_decryptor = file_decryptor_->GetColumnMetaDecryptor( - column_path, column_key_metadata); - data_decryptor = file_decryptor_->GetColumnDataDecryptor( - column_path, column_key_metadata); + metaDecryptor = + fileDecryptor_->getColumnMetaDecryptor(ColumnPath, columnKeyMetadata); + dataDecryptor = + fileDecryptor_->getColumnDataDecryptor(ColumnPath, columnKeyMetadata); CryptoContext ctx( - col->has_dictionary_page(), - row_group_ordinal_, + col->hasDictionaryPage(), + rowGroupOrdinal_, static_cast(i), - meta_decryptor, - data_decryptor); - return PageReader::Open( + metaDecryptor, + dataDecryptor); + return PageReader::open( stream, - col->num_values(), + col->numValues(), col->compression(), properties_, - always_compressed, + alwaysCompressed, &ctx); } private: std::shared_ptr source_; // Will be nullptr if PreBuffer() is not called. - std::shared_ptr<::arrow::io::internal::ReadRangeCache> cached_source_; - int64_t source_size_; - FileMetaData* file_metadata_; - std::unique_ptr row_group_metadata_; + std::shared_ptr<::arrow::io::internal::ReadRangeCache> cachedSource_; + int64_t sourceSize_; + FileMetaData* fileMetadata_; + std::unique_ptr rowGroupMetadata_; ReaderProperties properties_; - int row_group_ordinal_; - const std::shared_ptr prebuffered_column_chunks_bitmap_; - std::shared_ptr file_decryptor_; + int rowGroupOrdinal_; + const std::shared_ptr prebufferedColumnChunksBitmap_; + std::shared_ptr fileDecryptor_; }; -// ---------------------------------------------------------------------- -// SerializedFile: An implementation of ParquetFileReader::Contents that deals -// with the Parquet file structure, Thrift deserialization, and other internal -// matters +// ----------------------------------------------------------------------. +// SerializedFile: An implementation of ParquetFileReader::Contents that deals. +// With the Parquet file structure, Thrift deserialization, and other internal. +// Matters. -// This class takes ownership of the provided data source +// This class takes ownership of the provided data source. class SerializedFile : public ParquetFileReader::Contents { public: SerializedFile( std::shared_ptr source, - const ReaderProperties& props = default_reader_properties()) + const ReaderProperties& props = defaultReaderProperties()) : source_(std::move(source)), properties_(props) { - PARQUET_ASSIGN_OR_THROW(source_size_, source_->GetSize()); + PARQUET_ASSIGN_OR_THROW(sourceSize_, source_->GetSize()); } ~SerializedFile() override { try { - Close(); + close(); } catch (...) { } } - void Close() override { - if (file_decryptor_) - file_decryptor_->WipeOutDecryptionKeys(); + void close() override { + if (fileDecryptor_) + fileDecryptor_->wipeOutDecryptionKeys(); } - std::shared_ptr GetRowGroup(int i) override { - std::shared_ptr prebuffered_column_chunks_bitmap; + std::shared_ptr getRowGroup(int i) override { + std::shared_ptr prebufferedColumnChunksBitmap; // Avoid updating the bitmap as this function can be called concurrently. // The bitmap can only be updated within Prebuffer(). - auto prebuffered_column_chunks_iter = prebuffered_column_chunks_.find(i); - if (prebuffered_column_chunks_iter != prebuffered_column_chunks_.end()) { - prebuffered_column_chunks_bitmap = prebuffered_column_chunks_iter->second; + auto prebufferedColumnChunksIter = prebufferedColumnChunks_.find(i); + if (prebufferedColumnChunksIter != prebufferedColumnChunks_.end()) { + prebufferedColumnChunksBitmap = prebufferedColumnChunksIter->second; } std::unique_ptr contents = std::make_unique( source_, - cached_source_, - source_size_, - file_metadata_.get(), + cachedSource_, + sourceSize_, + fileMetadata_.get(), i, properties_, - std::move(prebuffered_column_chunks_bitmap), - file_decryptor_); + std::move(prebufferedColumnChunksBitmap), + fileDecryptor_); return std::make_shared(std::move(contents)); } std::shared_ptr metadata() const override { - return file_metadata_; + return fileMetadata_; } - std::shared_ptr GetPageIndexReader() override { - if (!file_metadata_) { + std::shared_ptr getPageIndexReader() override { + if (!fileMetadata_) { // Usually this won't happen if user calls one of the static Open() - // functions to create a ParquetFileReader instance. But if user calls the - // constructor directly and calls GetPageIndexReader() before Open() then - // this could happen. + // Functions to create a ParquetFileReader instance. But if user calls + // the. Constructor directly and calls GetPageIndexReader() before Open() + // then. This could happen. throw ParquetException( "Cannot call GetPageIndexReader() due to missing file metadata. Did you " "forget to call ParquetFileReader::Open() first?"); } - if (!page_index_reader_) { - page_index_reader_ = PageIndexReader::Make( - source_.get(), file_metadata_, properties_, file_decryptor_); + if (!pageIndexReader_) { + pageIndexReader_ = PageIndexReader::make( + source_.get(), fileMetadata_, properties_, fileDecryptor_); } - return page_index_reader_; + return pageIndexReader_; } - BloomFilterReader& GetBloomFilterReader() override { - if (!file_metadata_) { + BloomFilterReader& getBloomFilterReader() override { + if (!fileMetadata_) { // Usually this won't happen if user calls one of the static Open() - // functions to create a ParquetFileReader instance. But if user calls the - // constructor directly and calls GetBloomFilterReader() before Open() - // then this could happen. + // Functions to create a ParquetFileReader instance. But if user calls + // the. constructor directly and calls GetBloomFilterReader() before + // Open() Then this could happen. throw ParquetException( "Cannot call GetBloomFilterReader() due to missing file metadata. Did you " "forget to call ParquetFileReader::Open() first?"); } - if (!bloom_filter_reader_) { - bloom_filter_reader_ = BloomFilterReader::Make( - source_, file_metadata_, properties_, file_decryptor_); - if (bloom_filter_reader_ == nullptr) { + if (!bloomFilterReader_) { + bloomFilterReader_ = BloomFilterReader::make( + source_, fileMetadata_, properties_, fileDecryptor_); + if (bloomFilterReader_ == nullptr) { throw ParquetException("Cannot create BloomFilterReader"); } } - return *bloom_filter_reader_; + return *bloomFilterReader_; } - void set_metadata(std::shared_ptr metadata) { - file_metadata_ = std::move(metadata); + void setMetadata(std::shared_ptr metadata) { + fileMetadata_ = std::move(metadata); } - void PreBuffer( - const std::vector& row_groups, - const std::vector& column_indices, + void preBuffer( + const std::vector& rowGroups, + const std::vector& columnIndices, const ::arrow::io::IOContext& ctx, const ::arrow::io::CacheOptions& options) { - cached_source_ = std::make_shared<::arrow::io::internal::ReadRangeCache>( + cachedSource_ = std::make_shared<::arrow::io::internal::ReadRangeCache>( source_, ctx, options); std::vector<::arrow::io::ReadRange> ranges; - prebuffered_column_chunks_.clear(); - for (int row : row_groups) { - std::shared_ptr& col_bitmap = prebuffered_column_chunks_[row]; - int num_cols = file_metadata_->num_columns(); + prebufferedColumnChunks_.clear(); + for (int row : rowGroups) { + std::shared_ptr& colBitmap = prebufferedColumnChunks_[row]; + int numCols = fileMetadata_->numColumns(); PARQUET_THROW_NOT_OK( - AllocateEmptyBitmap(num_cols, properties_.memory_pool()) - .Value(&col_bitmap)); - for (int col : column_indices) { - ::arrow::bit_util::SetBit(col_bitmap->mutable_data(), col); - ranges.push_back(ComputeColumnChunkRange( - file_metadata_.get(), source_size_, row, col)); + ::arrow::AllocateEmptyBitmap(numCols, properties_.memoryPool()) + .Value(&colBitmap)); + for (int col : columnIndices) { + ::arrow::bit_util::SetBit(colBitmap->mutable_data(), col); + ranges.push_back(computeColumnChunkRange( + fileMetadata_.get(), sourceSize_, row, col)); } } - PARQUET_THROW_NOT_OK(cached_source_->Cache(ranges)); + PARQUET_THROW_NOT_OK(cachedSource_->Cache(ranges)); } - ::arrow::Future<> WhenBuffered( - const std::vector& row_groups, - const std::vector& column_indices) const { - if (!cached_source_) { + ::arrow::Future<> whenBuffered( + const std::vector& rowGroups, + const std::vector& columnIndices) const { + if (!cachedSource_) { return ::arrow::Status::Invalid( "Must call PreBuffer before WhenBuffered"); } std::vector<::arrow::io::ReadRange> ranges; - for (int row : row_groups) { - for (int col : column_indices) { - ranges.push_back(ComputeColumnChunkRange( - file_metadata_.get(), source_size_, row, col)); + for (int row : rowGroups) { + for (int col : columnIndices) { + ranges.push_back(computeColumnChunkRange( + fileMetadata_.get(), sourceSize_, row, col)); } } - return cached_source_->WaitFor(ranges); + return cachedSource_->WaitFor(ranges); } - // Metadata/footer parsing. Divided up to separate sync/async paths, and to - // use exceptions for error handling (with the async path converting to + // Metadata/footer parsing. Divided up to separate sync/async paths, and to. + // Use exceptions for error handling (with the async path converting to. // Future/Status). - void ParseMetaData() { - int64_t footer_read_size = GetFooterReadSize(); + void parseMetaData() { + int64_t footerReadSize = getFooterReadSize(); PARQUET_ASSIGN_OR_THROW( - auto footer_buffer, - source_->ReadAt(source_size_ - footer_read_size, footer_read_size)); - uint32_t metadata_len = ParseFooterLength(footer_buffer, footer_read_size); - int64_t metadata_start = source_size_ - kFooterSize - metadata_len; - - std::shared_ptr<::arrow::Buffer> metadata_buffer; - if (footer_read_size >= (metadata_len + kFooterSize)) { - metadata_buffer = SliceBuffer( - footer_buffer, - footer_read_size - metadata_len - kFooterSize, - metadata_len); + auto footerBuffer, + source_->ReadAt(sourceSize_ - footerReadSize, footerReadSize)); + uint32_t metadataLen = parseFooterLength(footerBuffer, footerReadSize); + int64_t metadataStart = sourceSize_ - kFooterSize - metadataLen; + + std::shared_ptr<::arrow::Buffer> metadataBuffer; + if (footerReadSize >= (metadataLen + kFooterSize)) { + metadataBuffer = ::arrow::SliceBuffer( + footerBuffer, + footerReadSize - metadataLen - kFooterSize, + metadataLen); } else { PARQUET_ASSIGN_OR_THROW( - metadata_buffer, source_->ReadAt(metadata_start, metadata_len)); + metadataBuffer, source_->ReadAt(metadataStart, metadataLen)); } - // Parse the footer depending on encryption type - const bool is_encrypted_footer = - memcmp( - footer_buffer->data() + footer_read_size - 4, kParquetEMagic, 4) == + // Parse the footer depending on encryption type. + const bool isEncryptedFooter = + memcmp(footerBuffer->data() + footerReadSize - 4, kParquetEMagic, 4) == 0; - if (is_encrypted_footer) { + if (isEncryptedFooter) { // Encrypted file with Encrypted footer. - const std::pair read_size = - ParseMetaDataOfEncryptedFileWithEncryptedFooter( - metadata_buffer, metadata_len); - // Read the actual footer - metadata_start = read_size.first; - metadata_len = read_size.second; + const std::pair readSize = + parseMetaDataOfEncryptedFileWithEncryptedFooter( + metadataBuffer, metadataLen); + // Read the actual footer. + metadataStart = readSize.first; + metadataLen = readSize.second; PARQUET_ASSIGN_OR_THROW( - metadata_buffer, source_->ReadAt(metadata_start, metadata_len)); - // Fall through + metadataBuffer, source_->ReadAt(metadataStart, metadataLen)); + // Fall through. } - const uint32_t read_metadata_len = - ParseUnencryptedFileMetadata(metadata_buffer, metadata_len); - auto file_decryption_properties = - properties_.file_decryption_properties().get(); - if (is_encrypted_footer) { + const uint32_t readMetadataLen = + parseUnencryptedFileMetadata(metadataBuffer, metadataLen); + auto fileDecryptionProperties = + properties_.fileDecryptionProperties().get(); + if (isEncryptedFooter) { // Nothing else to do here. return; - } else if (!file_metadata_ - ->is_encryption_algorithm_set()) { // Non encrypted file. - if (file_decryption_properties != nullptr) { - if (!file_decryption_properties->plaintext_files_allowed()) { + } else if (!fileMetadata_ + ->isEncryptionAlgorithmSet()) { // Non encrypted file. + if (fileDecryptionProperties != nullptr) { + if (!fileDecryptionProperties->plaintextFilesAllowed()) { throw ParquetException( "Applying decryption properties on plaintext file"); } } } else { // Encrypted file with plaintext footer mode. - ParseMetaDataOfEncryptedFileWithPlaintextFooter( - file_decryption_properties, - metadata_buffer, - metadata_len, - read_metadata_len); + parseMetaDataOfEncryptedFileWithPlaintextFooter( + fileDecryptionProperties, + metadataBuffer, + metadataLen, + readMetadataLen); } } // Validate the source size and get the initial read size. - int64_t GetFooterReadSize() { - if (source_size_ == 0) { + int64_t getFooterReadSize() { + if (sourceSize_ == 0) { throw ParquetInvalidOrCorruptedFileException( "Parquet file size is 0 bytes"); - } else if (source_size_ < kFooterSize) { + } else if (sourceSize_ < kFooterSize) { throw ParquetInvalidOrCorruptedFileException( "Parquet file size is ", - source_size_, + sourceSize_, " bytes, smaller than the minimum file footer (", kFooterSize, " bytes)"); } - return std::min(source_size_, kDefaultFooterReadSize); + return std::min(sourceSize_, kDefaultFooterReadSize); } // Validate the magic bytes and get the length of the full footer. - uint32_t ParseFooterLength( - const std::shared_ptr<::arrow::Buffer>& footer_buffer, - const int64_t footer_read_size) { - // Check if all bytes are read. Check if last 4 bytes read have the magic - // bits - if (footer_buffer->size() != footer_read_size || - (memcmp( - footer_buffer->data() + footer_read_size - 4, kParquetMagic, 4) != + uint32_t parseFooterLength( + const std::shared_ptr<::arrow::Buffer>& footerBuffer, + const int64_t footerReadSize) { + // Check if all bytes are read. Check if last 4 bytes read have the magic. + // Bits. + if (footerBuffer->size() != footerReadSize || + (memcmp(footerBuffer->data() + footerReadSize - 4, kParquetMagic, 4) != 0 && - memcmp( - footer_buffer->data() + footer_read_size - 4, kParquetEMagic, 4) != + memcmp(footerBuffer->data() + footerReadSize - 4, kParquetEMagic, 4) != 0)) { throw ParquetInvalidOrCorruptedFileException( "Parquet magic bytes not found in footer. Either the file is corrupted or this " "is not a parquet file."); } // Both encrypted/unencrypted footers have the same footer length check. - uint32_t metadata_len = ::arrow::util::SafeLoadAs( - reinterpret_cast(footer_buffer->data()) + - footer_read_size - kFooterSize); - if (metadata_len > source_size_ - kFooterSize) { + uint32_t metadataLen = ::arrow::util::SafeLoadAs( + reinterpret_cast(footerBuffer->data()) + + footerReadSize - kFooterSize); + if (metadataLen > sourceSize_ - kFooterSize) { throw ParquetInvalidOrCorruptedFileException( "Parquet file size is ", - source_size_, + sourceSize_, " bytes, smaller than the size reported by footer's (", - metadata_len, + metadataLen, "bytes)"); } - return metadata_len; + return metadataLen; } // Does not throw. - ::arrow::Future<> ParseMetaDataAsync() { - int64_t footer_read_size; + ::arrow::Future<> parseMetaDataAsync() { + int64_t footerReadSize; BEGIN_PARQUET_CATCH_EXCEPTIONS - footer_read_size = GetFooterReadSize(); + footerReadSize = getFooterReadSize(); END_PARQUET_CATCH_EXCEPTIONS - // Assumes this is kept alive externally - return source_->ReadAsync(source_size_ - footer_read_size, footer_read_size) + // Assumes this is kept alive externally. + return source_->ReadAsync(sourceSize_ - footerReadSize, footerReadSize) .Then( - [this, footer_read_size]( - const std::shared_ptr<::arrow::Buffer>& footer_buffer) + [this, footerReadSize]( + const std::shared_ptr<::arrow::Buffer>& footerBuffer) -> ::arrow::Future<> { - uint32_t metadata_len; + uint32_t metadataLen; BEGIN_PARQUET_CATCH_EXCEPTIONS - metadata_len = ParseFooterLength(footer_buffer, footer_read_size); + metadataLen = parseFooterLength(footerBuffer, footerReadSize); END_PARQUET_CATCH_EXCEPTIONS - int64_t metadata_start = - source_size_ - kFooterSize - metadata_len; - - std::shared_ptr<::arrow::Buffer> metadata_buffer; - if (footer_read_size >= (metadata_len + kFooterSize)) { - metadata_buffer = SliceBuffer( - footer_buffer, - footer_read_size - metadata_len - kFooterSize, - metadata_len); - return ParseMaybeEncryptedMetaDataAsync( - footer_buffer, - std::move(metadata_buffer), - footer_read_size, - metadata_len); + int64_t metadataStart = sourceSize_ - kFooterSize - metadataLen; + + std::shared_ptr<::arrow::Buffer> metadataBuffer; + if (footerReadSize >= (metadataLen + kFooterSize)) { + metadataBuffer = ::arrow::SliceBuffer( + footerBuffer, + footerReadSize - metadataLen - kFooterSize, + metadataLen); + return parseMaybeEncryptedMetaDataAsync( + footerBuffer, + std::move(metadataBuffer), + footerReadSize, + metadataLen); } - return source_->ReadAsync(metadata_start, metadata_len) - .Then([this, footer_buffer, footer_read_size, metadata_len]( + return source_->ReadAsync(metadataStart, metadataLen) + .Then([this, footerBuffer, footerReadSize, metadataLen]( const std::shared_ptr<::arrow::Buffer>& - metadata_buffer) { - return ParseMaybeEncryptedMetaDataAsync( - footer_buffer, - metadata_buffer, - footer_read_size, - metadata_len); + metadataBuffer) { + return parseMaybeEncryptedMetaDataAsync( + footerBuffer, + metadataBuffer, + footerReadSize, + metadataLen); }); }); } - // Continuation - ::arrow::Future<> ParseMaybeEncryptedMetaDataAsync( - std::shared_ptr<::arrow::Buffer> footer_buffer, - std::shared_ptr<::arrow::Buffer> metadata_buffer, - int64_t footer_read_size, - uint32_t metadata_len) { - // Parse the footer depending on encryption type - const bool is_encrypted_footer = - memcmp( - footer_buffer->data() + footer_read_size - 4, kParquetEMagic, 4) == + // Continuation. + ::arrow::Future<> parseMaybeEncryptedMetaDataAsync( + std::shared_ptr<::arrow::Buffer> footerBuffer, + std::shared_ptr<::arrow::Buffer> metadataBuffer, + int64_t footerReadSize, + uint32_t metadataLen) { + // Parse the footer depending on encryption type. + const bool isEncryptedFooter = + memcmp(footerBuffer->data() + footerReadSize - 4, kParquetEMagic, 4) == 0; - if (is_encrypted_footer) { + if (isEncryptedFooter) { // Encrypted file with Encrypted footer. - std::pair read_size; + std::pair readSize; BEGIN_PARQUET_CATCH_EXCEPTIONS - read_size = ParseMetaDataOfEncryptedFileWithEncryptedFooter( - metadata_buffer, metadata_len); + readSize = parseMetaDataOfEncryptedFileWithEncryptedFooter( + metadataBuffer, metadataLen); END_PARQUET_CATCH_EXCEPTIONS - // Read the actual footer - int64_t metadata_start = read_size.first; - metadata_len = read_size.second; - return source_->ReadAsync(metadata_start, metadata_len) - .Then([this, metadata_len, is_encrypted_footer]( - const std::shared_ptr<::arrow::Buffer>& metadata_buffer) { - // Continue and read the file footer - return ParseMetaDataFinal( - metadata_buffer, metadata_len, is_encrypted_footer); + // Read the actual footer. + int64_t metadataStart = readSize.first; + metadataLen = readSize.second; + return source_->ReadAsync(metadataStart, metadataLen) + .Then([this, metadataLen, isEncryptedFooter]( + const std::shared_ptr<::arrow::Buffer>& metadataBuffer) { + // Continue and read the file footer. + return parseMetaDataFinal( + metadataBuffer, metadataLen, isEncryptedFooter); }); } - return ParseMetaDataFinal( - std::move(metadata_buffer), metadata_len, is_encrypted_footer); + return parseMetaDataFinal( + std::move(metadataBuffer), metadataLen, isEncryptedFooter); } - // Continuation - ::arrow::Status ParseMetaDataFinal( - std::shared_ptr<::arrow::Buffer> metadata_buffer, - uint32_t metadata_len, - const bool is_encrypted_footer) { + // Continuation. + ::arrow::Status parseMetaDataFinal( + std::shared_ptr<::arrow::Buffer> metadataBuffer, + uint32_t metadataLen, + const bool isEncryptedFooter) { BEGIN_PARQUET_CATCH_EXCEPTIONS - const uint32_t read_metadata_len = - ParseUnencryptedFileMetadata(metadata_buffer, metadata_len); - auto file_decryption_properties = - properties_.file_decryption_properties().get(); - if (is_encrypted_footer) { + const uint32_t readMetadataLen = + parseUnencryptedFileMetadata(metadataBuffer, metadataLen); + auto fileDecryptionProperties = + properties_.fileDecryptionProperties().get(); + if (isEncryptedFooter) { // Nothing else to do here. return ::arrow::Status::OK(); - } else if (!file_metadata_ - ->is_encryption_algorithm_set()) { // Non encrypted file. - if (file_decryption_properties != nullptr) { - if (!file_decryption_properties->plaintext_files_allowed()) { + } else if (!fileMetadata_ + ->isEncryptionAlgorithmSet()) { // Non encrypted file. + if (fileDecryptionProperties != nullptr) { + if (!fileDecryptionProperties->plaintextFilesAllowed()) { throw ParquetException( "Applying decryption properties on plaintext file"); } } } else { // Encrypted file with plaintext footer mode. - ParseMetaDataOfEncryptedFileWithPlaintextFooter( - file_decryption_properties, - metadata_buffer, - metadata_len, - read_metadata_len); + parseMetaDataOfEncryptedFileWithPlaintextFooter( + fileDecryptionProperties, + metadataBuffer, + metadataLen, + readMetadataLen); } END_PARQUET_CATCH_EXCEPTIONS return ::arrow::Status::OK(); @@ -685,130 +676,126 @@ class SerializedFile : public ParquetFileReader::Contents { private: std::shared_ptr source_; - std::shared_ptr<::arrow::io::internal::ReadRangeCache> cached_source_; - int64_t source_size_; - std::shared_ptr file_metadata_; + std::shared_ptr<::arrow::io::internal::ReadRangeCache> cachedSource_; + int64_t sourceSize_; + std::shared_ptr fileMetadata_; ReaderProperties properties_; - std::shared_ptr page_index_reader_; - std::unique_ptr bloom_filter_reader_; - // Maps row group ordinal and prebuffer status of its column chunks in the - // form of a bitmap buffer. - std::unordered_map> prebuffered_column_chunks_; - std::shared_ptr file_decryptor_; - - // \return The true length of the metadata in bytes - uint32_t ParseUnencryptedFileMetadata( - const std::shared_ptr& footer_buffer, - const uint32_t metadata_len); - - std::string HandleAadPrefix( - FileDecryptionProperties* file_decryption_properties, + std::shared_ptr pageIndexReader_; + std::unique_ptr bloomFilterReader_; + // Maps row group ordinal and prebuffer status of its column chunks in the. + // Form of a bitmap buffer. + std::unordered_map> prebufferedColumnChunks_; + std::shared_ptr fileDecryptor_; + + // \return The true length of the metadata in bytes. + uint32_t parseUnencryptedFileMetadata( + const std::shared_ptr& footerBuffer, + const uint32_t metadataLen); + + std::string handleAadPrefix( + FileDecryptionProperties* fileDecryptionProperties, EncryptionAlgorithm& algo); - void ParseMetaDataOfEncryptedFileWithPlaintextFooter( - FileDecryptionProperties* file_decryption_properties, - const std::shared_ptr& metadata_buffer, - uint32_t metadata_len, - uint32_t read_metadata_len); + void parseMetaDataOfEncryptedFileWithPlaintextFooter( + FileDecryptionProperties* fileDecryptionProperties, + const std::shared_ptr& metadataBuffer, + uint32_t metadataLen, + uint32_t readMetadataLen); - // \return The position and size of the actual footer - std::pair ParseMetaDataOfEncryptedFileWithEncryptedFooter( - const std::shared_ptr& crypto_metadata_buffer, - uint32_t footer_len); + // \return The position and size of the actual footer. + std::pair parseMetaDataOfEncryptedFileWithEncryptedFooter( + const std::shared_ptr& cryptoMetadataBuffer, + uint32_t footerLen); }; -uint32_t SerializedFile::ParseUnencryptedFileMetadata( - const std::shared_ptr& metadata_buffer, - const uint32_t metadata_len) { - if (metadata_buffer->size() != metadata_len) { +uint32_t SerializedFile::parseUnencryptedFileMetadata( + const std::shared_ptr& metadataBuffer, + const uint32_t metadataLen) { + if (metadataBuffer->size() != metadataLen) { throw ParquetException( "Failed reading metadata buffer (requested " + - std::to_string(metadata_len) + " bytes but got " + - std::to_string(metadata_buffer->size()) + " bytes)"); + std::to_string(metadataLen) + " bytes but got " + + std::to_string(metadataBuffer->size()) + " bytes)"); } - uint32_t read_metadata_len = metadata_len; - // The encrypted read path falls through to here, so pass in the decryptor - file_metadata_ = FileMetaData::Make( - metadata_buffer->data(), - &read_metadata_len, - properties_, - file_decryptor_); - return read_metadata_len; + uint32_t readMetadataLen = metadataLen; + // The encrypted read path falls through to here, so pass in the decryptor. + fileMetadata_ = FileMetaData::make( + metadataBuffer->data(), &readMetadataLen, properties_, fileDecryptor_); + return readMetadataLen; } std::pair -SerializedFile::ParseMetaDataOfEncryptedFileWithEncryptedFooter( - const std::shared_ptr<::arrow::Buffer>& crypto_metadata_buffer, - // both metadata & crypto metadata length - const uint32_t footer_len) { - // encryption with encrypted footer - // Check if the footer_buffer contains the entire metadata - if (crypto_metadata_buffer->size() != footer_len) { +SerializedFile::parseMetaDataOfEncryptedFileWithEncryptedFooter( + const std::shared_ptr<::arrow::Buffer>& cryptoMetadataBuffer, + // Both metadata & crypto metadata length. + const uint32_t footerLen) { + // Encryption with encrypted footer. + // Check if the footer_buffer contains the entire metadata. + if (cryptoMetadataBuffer->size() != footerLen) { throw ParquetException( "Failed reading encrypted metadata buffer (requested " + - std::to_string(footer_len) + " bytes but got " + - std::to_string(crypto_metadata_buffer->size()) + " bytes)"); + std::to_string(footerLen) + " bytes but got " + + std::to_string(cryptoMetadataBuffer->size()) + " bytes)"); } - auto file_decryption_properties = - properties_.file_decryption_properties().get(); - if (file_decryption_properties == nullptr) { + auto fileDecryptionProperties = properties_.fileDecryptionProperties().get(); + if (fileDecryptionProperties == nullptr) { throw ParquetException( "Could not read encrypted metadata, no decryption found in reader's properties"); } - uint32_t crypto_metadata_len = footer_len; - std::shared_ptr file_crypto_metadata = - FileCryptoMetaData::Make( - crypto_metadata_buffer->data(), &crypto_metadata_len); - // Handle AAD prefix - EncryptionAlgorithm algo = file_crypto_metadata->encryption_algorithm(); - std::string file_aad = HandleAadPrefix(file_decryption_properties, algo); - file_decryptor_ = std::make_shared( - file_decryption_properties, - file_aad, + uint32_t cryptoMetadataLen = footerLen; + std::shared_ptr fileCryptoMetadata = + FileCryptoMetaData::make( + cryptoMetadataBuffer->data(), &cryptoMetadataLen); + // Handle AAD prefix. + EncryptionAlgorithm algo = fileCryptoMetadata->encryptionAlgorithm(); + std::string fileAad = handleAadPrefix(fileDecryptionProperties, algo); + fileDecryptor_ = std::make_shared( + fileDecryptionProperties, + fileAad, algo.algorithm, - file_crypto_metadata->key_metadata(), - properties_.memory_pool()); + fileCryptoMetadata->keyMetadata(), + properties_.memoryPool()); - int64_t metadata_offset = - source_size_ - kFooterSize - footer_len + crypto_metadata_len; - uint32_t metadata_len = footer_len - crypto_metadata_len; - return std::make_pair(metadata_offset, metadata_len); + int64_t metadataOffset = + sourceSize_ - kFooterSize - footerLen + cryptoMetadataLen; + uint32_t metadataLen = footerLen - cryptoMetadataLen; + return std::make_pair(metadataOffset, metadataLen); } -void SerializedFile::ParseMetaDataOfEncryptedFileWithPlaintextFooter( - FileDecryptionProperties* file_decryption_properties, - const std::shared_ptr& metadata_buffer, - uint32_t metadata_len, - uint32_t read_metadata_len) { - // Providing decryption properties in plaintext footer mode is not mandatory, - // for example when reading by legacy reader. - if (file_decryption_properties != nullptr) { - EncryptionAlgorithm algo = file_metadata_->encryption_algorithm(); - // Handle AAD prefix - std::string file_aad = HandleAadPrefix(file_decryption_properties, algo); - file_decryptor_ = std::make_shared( - file_decryption_properties, - file_aad, +void SerializedFile::parseMetaDataOfEncryptedFileWithPlaintextFooter( + FileDecryptionProperties* fileDecryptionProperties, + const std::shared_ptr& metadataBuffer, + uint32_t metadataLen, + uint32_t readMetadataLen) { + // Providing decryption properties in plaintext footer mode is not mandatory,. + // For example when reading by legacy reader. + if (fileDecryptionProperties != nullptr) { + EncryptionAlgorithm algo = fileMetadata_->encryptionAlgorithm(); + // Handle AAD prefix. + std::string fileAad = handleAadPrefix(fileDecryptionProperties, algo); + fileDecryptor_ = std::make_shared( + fileDecryptionProperties, + fileAad, algo.algorithm, - file_metadata_->footer_signing_key_metadata(), - properties_.memory_pool()); - // set the InternalFileDecryptor in the metadata as well, as it's used - // for signature verification and for ColumnChunkMetaData creation. - file_metadata_->set_file_decryptor(file_decryptor_); - - if (file_decryption_properties->check_plaintext_footer_integrity()) { - if (metadata_len - read_metadata_len != + fileMetadata_->footerSigningKeyMetadata(), + properties_.memoryPool()); + // Set the InternalFileDecryptor in the metadata as well, as it's used. + // For signature verification and for ColumnChunkMetaData creation. + fileMetadata_->setFileDecryptor(fileDecryptor_); + + if (fileDecryptionProperties->checkPlaintextFooterIntegrity()) { + if (metadataLen - readMetadataLen != (encryption::kGcmTagLength + encryption::kNonceLength)) { throw ParquetInvalidOrCorruptedFileException( "Failed reading metadata for encryption signature (requested ", encryption::kGcmTagLength + encryption::kNonceLength, " bytes but have ", - metadata_len - read_metadata_len, + metadataLen - readMetadataLen, " bytes)"); } - if (!file_metadata_->VerifySignature( - metadata_buffer->data() + read_metadata_len)) { + if (!fileMetadata_->verifySignature( + metadataBuffer->data() + readMetadataLen)) { throw ParquetInvalidOrCorruptedFileException( "Parquet crypto signature verification failed"); } @@ -816,88 +803,87 @@ void SerializedFile::ParseMetaDataOfEncryptedFileWithPlaintextFooter( } } -std::string SerializedFile::HandleAadPrefix( - FileDecryptionProperties* file_decryption_properties, +std::string SerializedFile::handleAadPrefix( + FileDecryptionProperties* fileDecryptionProperties, EncryptionAlgorithm& algo) { - std::string aad_prefix_in_properties = - file_decryption_properties->aad_prefix(); - std::string aad_prefix = aad_prefix_in_properties; - bool file_has_aad_prefix = algo.aad.aad_prefix.empty() ? false : true; - std::string aad_prefix_in_file = algo.aad.aad_prefix; + std::string aadPrefixInProperties = fileDecryptionProperties->aadPrefix(); + std::string aadPrefix = aadPrefixInProperties; + bool fileHasAadPrefix = algo.aad.aadPrefix.empty() ? false : true; + std::string aadPrefixInFile = algo.aad.aadPrefix; - if (algo.aad.supply_aad_prefix && aad_prefix_in_properties.empty()) { + if (algo.aad.supplyAadPrefix && aadPrefixInProperties.empty()) { throw ParquetException( "AAD prefix used for file encryption, " "but not stored in file and not supplied " "in decryption properties"); } - if (file_has_aad_prefix) { - if (!aad_prefix_in_properties.empty()) { - if (aad_prefix_in_properties.compare(aad_prefix_in_file) != 0) { + if (fileHasAadPrefix) { + if (!aadPrefixInProperties.empty()) { + if (aadPrefixInProperties.compare(aadPrefixInFile) != 0) { throw ParquetException( "AAD Prefix in file and in properties " "is not the same"); } } - aad_prefix = aad_prefix_in_file; - std::shared_ptr aad_prefix_verifier = - file_decryption_properties->aad_prefix_verifier(); - if (aad_prefix_verifier != nullptr) - aad_prefix_verifier->Verify(aad_prefix); + aadPrefix = aadPrefixInFile; + std::shared_ptr aadPrefixVerifier = + fileDecryptionProperties->aadPrefixVerifier(); + if (aadPrefixVerifier != nullptr) + aadPrefixVerifier->verify(aadPrefix); } else { - if (!algo.aad.supply_aad_prefix && !aad_prefix_in_properties.empty()) { + if (!algo.aad.supplyAadPrefix && !aadPrefixInProperties.empty()) { throw ParquetException( "AAD Prefix set in decryption properties, but was not used " "for file encryption"); } - std::shared_ptr aad_prefix_verifier = - file_decryption_properties->aad_prefix_verifier(); - if (aad_prefix_verifier != nullptr) { + std::shared_ptr aadPrefixVerifier = + fileDecryptionProperties->aadPrefixVerifier(); + if (aadPrefixVerifier != nullptr) { throw ParquetException( "AAD Prefix Verifier is set, but AAD Prefix not found in file"); } } - return aad_prefix + algo.aad.aad_file_unique; + return aadPrefix + algo.aad.aadFileUnique; } -// ---------------------------------------------------------------------- -// ParquetFileReader public API +// ----------------------------------------------------------------------. +// ParquetFileReader public API. ParquetFileReader::ParquetFileReader() {} ParquetFileReader::~ParquetFileReader() { try { - Close(); + close(); } catch (...) { } } -// Open the file. If no metadata is passed, it is parsed from the footer of -// the file -std::unique_ptr ParquetFileReader::Contents::Open( +// Open the file. If no metadata is passed, it is parsed from the footer of. +// The file. +std::unique_ptr ParquetFileReader::Contents::open( std::shared_ptr source, const ReaderProperties& props, std::shared_ptr metadata) { std::unique_ptr result( new SerializedFile(std::move(source), props)); - // Access private methods here, but otherwise unavailable + // Access private methods here, but otherwise unavailable. SerializedFile* file = static_cast(result.get()); if (metadata == nullptr) { - // Validates magic bytes, parses metadata, and initializes the - // SchemaDescriptor - file->ParseMetaData(); + // Validates magic bytes, parses metadata, and initializes the. + // SchemaDescriptor. + file->parseMetaData(); } else { - file->set_metadata(std::move(metadata)); + file->setMetadata(std::move(metadata)); } return result; } ::arrow::Future> -ParquetFileReader::Contents::OpenAsync( +ParquetFileReader::Contents::openAsync( std::shared_ptr source, const ReaderProperties& props, std::shared_ptr metadata) { @@ -906,7 +892,7 @@ ParquetFileReader::Contents::OpenAsync( new SerializedFile(std::move(source), props)); SerializedFile* file = static_cast(result.get()); if (metadata == nullptr) { - // TODO(ARROW-12259): workaround since we have Future<(move-only type)> + // TODO(ARROW-12259): workaround since we have Future<(move-only type)>. struct { ::arrow::Result> operator()() { @@ -916,80 +902,80 @@ ParquetFileReader::Contents::OpenAsync( std::unique_ptr result; } Continuation; Continuation.result = std::move(result); - return file->ParseMetaDataAsync().Then(std::move(Continuation)); + return file->parseMetaDataAsync().Then(std::move(Continuation)); } else { - file->set_metadata(std::move(metadata)); + file->setMetadata(std::move(metadata)); return ::arrow::Future>:: MakeFinished(std::move(result)); } END_PARQUET_CATCH_EXCEPTIONS } -std::unique_ptr ParquetFileReader::Open( +std::unique_ptr ParquetFileReader::open( std::shared_ptr<::arrow::io::RandomAccessFile> source, const ReaderProperties& props, std::shared_ptr metadata) { auto contents = - SerializedFile::Open(std::move(source), props, std::move(metadata)); + SerializedFile::open(std::move(source), props, std::move(metadata)); std::unique_ptr result = std::make_unique(); - result->Open(std::move(contents)); + result->open(std::move(contents)); return result; } -std::unique_ptr ParquetFileReader::OpenFile( +std::unique_ptr ParquetFileReader::openFile( const std::string& path, - bool memory_map, + bool memoryMap, const ReaderProperties& props, std::shared_ptr metadata) { std::shared_ptr<::arrow::io::RandomAccessFile> source; - if (memory_map) { + if (memoryMap) { PARQUET_ASSIGN_OR_THROW( source, ::arrow::io::MemoryMappedFile::Open(path, ::arrow::io::FileMode::READ)); } else { PARQUET_ASSIGN_OR_THROW( - source, ::arrow::io::ReadableFile::Open(path, props.memory_pool())); + source, ::arrow::io::ReadableFile::Open(path, props.memoryPool())); } - return Open(std::move(source), props, std::move(metadata)); + return open(std::move(source), props, std::move(metadata)); } ::arrow::Future> -ParquetFileReader::OpenAsync( +ParquetFileReader::openAsync( std::shared_ptr<::arrow::io::RandomAccessFile> source, const ReaderProperties& props, std::shared_ptr metadata) { BEGIN_PARQUET_CATCH_EXCEPTIONS auto fut = - SerializedFile::OpenAsync(std::move(source), props, std::move(metadata)); - // TODO(ARROW-12259): workaround since we have Future<(move-only type)> + SerializedFile::openAsync(std::move(source), props, std::move(metadata)); + // TODO(ARROW-12259): workaround since we have Future<(move-only type)>. auto completed = ::arrow::Future>::Make(); fut.AddCallback( [fut, completed]( const ::arrow::Result>& - contents) mutable { - if (!contents.ok()) { - completed.MarkFinished(contents.status()); + Contents) mutable { + if (!Contents.ok()) { + completed.MarkFinished(Contents.status()); return; } std::unique_ptr result = std::make_unique(); - result->Open(fut.MoveResult().MoveValueUnsafe()); + result->open(fut.MoveResult().MoveValueUnsafe()); completed.MarkFinished(std::move(result)); }); return completed; END_PARQUET_CATCH_EXCEPTIONS } -void ParquetFileReader::Open( +void ParquetFileReader::open( std::unique_ptr contents) { contents_ = std::move(contents); } -void ParquetFileReader::Close() { +void ParquetFileReader::close() { if (contents_) { - contents_->Close(); + contents_->close(); } } @@ -997,120 +983,120 @@ std::shared_ptr ParquetFileReader::metadata() const { return contents_->metadata(); } -std::shared_ptr ParquetFileReader::GetPageIndexReader() { - return contents_->GetPageIndexReader(); +std::shared_ptr ParquetFileReader::getPageIndexReader() { + return contents_->getPageIndexReader(); } -BloomFilterReader& ParquetFileReader::GetBloomFilterReader() { - return contents_->GetBloomFilterReader(); +BloomFilterReader& ParquetFileReader::getBloomFilterReader() { + return contents_->getBloomFilterReader(); } -std::shared_ptr ParquetFileReader::RowGroup(int i) { - if (i >= metadata()->num_row_groups()) { +std::shared_ptr ParquetFileReader::rowGroup(int i) { + if (i >= metadata()->numRowGroups()) { std::stringstream ss; ss << "Trying to read row group " << i << " but file only has " - << metadata()->num_row_groups() << " row groups"; + << metadata()->numRowGroups() << " row groups"; throw ParquetException(ss.str()); } - return contents_->GetRowGroup(i); + return contents_->getRowGroup(i); } -void ParquetFileReader::PreBuffer( - const std::vector& row_groups, - const std::vector& column_indices, +void ParquetFileReader::preBuffer( + const std::vector& rowGroups, + const std::vector& columnIndices, const ::arrow::io::IOContext& ctx, const ::arrow::io::CacheOptions& options) { - // Access private methods here + // Access private methods here. SerializedFile* file = ::arrow::internal::checked_cast(contents_.get()); - file->PreBuffer(row_groups, column_indices, ctx, options); + file->preBuffer(rowGroups, columnIndices, ctx, options); } -::arrow::Future<> ParquetFileReader::WhenBuffered( - const std::vector& row_groups, - const std::vector& column_indices) const { - // Access private methods here +::arrow::Future<> ParquetFileReader::whenBuffered( + const std::vector& rowGroups, + const std::vector& columnIndices) const { + // Access private methods here. SerializedFile* file = ::arrow::internal::checked_cast(contents_.get()); - return file->WhenBuffered(row_groups, column_indices); + return file->whenBuffered(rowGroups, columnIndices); } -// ---------------------------------------------------------------------- -// File metadata helpers +// ----------------------------------------------------------------------. +// File metadata helpers. -std::shared_ptr ReadMetaData( +std::shared_ptr readMetaData( const std::shared_ptr<::arrow::io::RandomAccessFile>& source) { - return ParquetFileReader::Open(source)->metadata(); + return ParquetFileReader::open(source)->metadata(); } -// ---------------------------------------------------------------------- -// File scanner for performance testing +// ----------------------------------------------------------------------. +// File Scanner for performance testing. -int64_t ScanFileContents( +int64_t scanFileContents( std::vector columns, - const int32_t column_batch_size, + const int32_t columnBatchSize, ParquetFileReader* reader) { - std::vector rep_levels(column_batch_size); - std::vector def_levels(column_batch_size); + std::vector repLevels(columnBatchSize); + std::vector defLevels(columnBatchSize); - int num_columns = static_cast(columns.size()); + int numColumns = static_cast(columns.size()); - // columns are not specified explicitly. Add all columns + // Columns are not specified explicitly. Add all columns. if (columns.size() == 0) { - num_columns = reader->metadata()->num_columns(); - columns.resize(num_columns); - for (int i = 0; i < num_columns; i++) { + numColumns = reader->metadata()->numColumns(); + columns.resize(numColumns); + for (int i = 0; i < numColumns; i++) { columns[i] = i; } } - if (num_columns == 0) { - // If we still have no columns(none in file), return early. The remainder of - // function expects there to be at least one column. + if (numColumns == 0) { + // If we still have no columns(none in file), return early. The remainder + // of. Function expects there to be at least one column. return 0; } - std::vector total_rows(num_columns, 0); + std::vector totalRows(numColumns, 0); - for (int r = 0; r < reader->metadata()->num_row_groups(); ++r) { - auto group_reader = reader->RowGroup(r); + for (int r = 0; r < reader->metadata()->numRowGroups(); ++r) { + auto groupReader = reader->rowGroup(r); int col = 0; for (auto i : columns) { - std::shared_ptr col_reader = group_reader->Column(i); - size_t value_byte_size = - GetTypeByteSize(col_reader->descr()->physical_type()); - std::vector values(column_batch_size * value_byte_size); - - int64_t values_read = 0; - while (col_reader->HasNext()) { - int64_t levels_read = ScanAllValues( - column_batch_size, - def_levels.data(), - rep_levels.data(), + std::shared_ptr colReader = groupReader->column(i); + size_t valueByteSize = + getTypeByteSize(colReader->descr()->physicalType()); + std::vector values(columnBatchSize * valueByteSize); + + int64_t valuesRead = 0; + while (colReader->hasNext()) { + int64_t levelsRead = scanAllValues( + columnBatchSize, + defLevels.data(), + repLevels.data(), values.data(), - &values_read, - col_reader.get()); - if (col_reader->descr()->max_repetition_level() > 0) { - for (int64_t i = 0; i < levels_read; i++) { - if (rep_levels[i] == 0) { - total_rows[col]++; + &valuesRead, + colReader.get()); + if (colReader->descr()->maxRepetitionLevel() > 0) { + for (int64_t i = 0; i < levelsRead; i++) { + if (repLevels[i] == 0) { + totalRows[col]++; } } } else { - total_rows[col] += levels_read; + totalRows[col] += levelsRead; } } col++; } } - for (int i = 1; i < num_columns; ++i) { - if (total_rows[0] != total_rows[i]) { + for (int i = 1; i < numColumns; ++i) { + if (totalRows[0] != totalRows[i]) { throw ParquetException( "Parquet error: Total rows among columns do not match"); } } - return total_rows[0]; + return totalRows[0]; } } // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/tests/FileReader.h b/velox/dwio/parquet/writer/arrow/tests/FileReader.h index 046655e28b5..cf6e058e50a 100644 --- a/velox/dwio/parquet/writer/arrow/tests/FileReader.h +++ b/velox/dwio/parquet/writer/arrow/tests/FileReader.h @@ -40,178 +40,178 @@ class RowGroupMetaData; class PARQUET_EXPORT RowGroupReader { public: - // Forward declare a virtual class 'Contents' to aid dependency injection and - // more easily create test fixtures An implementation of the Contents class is - // defined in the .cc file + // Forward declare a virtual class 'Contents' to aid dependency injection and. + // More easily create test fixtures An implementation of the Contents class + // is. Defined in the .cc file. struct Contents { virtual ~Contents() {} - virtual std::unique_ptr GetColumnPageReader(int i) = 0; + virtual std::unique_ptr getColumnPageReader(int i) = 0; virtual const RowGroupMetaData* metadata() const = 0; virtual const ReaderProperties* properties() const = 0; }; explicit RowGroupReader(std::unique_ptr contents); - // Returns the rowgroup metadata + // Returns the rowgroup metadata. const RowGroupMetaData* metadata() const; - // Construct a ColumnReader for the indicated row group-relative - // column. Ownership is shared with the RowGroupReader. - std::shared_ptr Column(int i); + // Construct a ColumnReader for the indicated row group-relative. + // Column. Ownership is shared with the RowGroupReader. + std::shared_ptr column(int i); // Construct a ColumnReader, trying to enable exposed encoding. // - // For dictionary encoding, currently we only support column chunks that are - // fully dictionary encoded, i.e., all data pages in the column chunk are - // dictionary encoded. If a column chunk uses dictionary encoding but then - // falls back to plain encoding, the encoding will not be exposed. + // For dictionary encoding, currently we only support column chunks that are. + // Fully dictionary encoded, i.e., all data pages in the column chunk are. + // Dictionary encoded. If a column chunk uses dictionary encoding but then. + // Falls back to plain encoding, the encoding will not be exposed. // - // The returned column reader provides an API GetExposedEncoding() for the - // users to check the exposed encoding and determine how to read the batches. + // The returned column reader provides an API GetExposedEncoding() for the. + // Users to check the exposed encoding and determine how to read the batches. // - // \note API EXPERIMENTAL - std::shared_ptr ColumnWithExposeEncoding( + // \note API EXPERIMENTAL. + std::shared_ptr columnWithExposeEncoding( int i, - ExposedEncoding encoding_to_expose); + ExposedEncoding encodingToExpose); - std::unique_ptr GetColumnPageReader(int i); + std::unique_ptr getColumnPageReader(int i); private: - // Holds a pointer to an instance of Contents implementation + // Holds a pointer to an instance of Contents implementation. std::unique_ptr contents_; }; class PARQUET_EXPORT ParquetFileReader { public: - // Declare a virtual class 'Contents' to aid dependency injection and more - // easily create test fixtures - // An implementation of the Contents class is defined in the .cc file + // Declare a virtual class 'Contents' to aid dependency injection and more. + // Easily create test fixtures. + // An implementation of the Contents class is defined in the .cc file. struct PARQUET_EXPORT Contents { - static std::unique_ptr Open( + static std::unique_ptr open( std::shared_ptr<::arrow::io::RandomAccessFile> source, - const ReaderProperties& props = default_reader_properties(), + const ReaderProperties& props = defaultReaderProperties(), std::shared_ptr metadata = NULLPTR); - static ::arrow::Future> OpenAsync( + static ::arrow::Future> openAsync( std::shared_ptr<::arrow::io::RandomAccessFile> source, - const ReaderProperties& props = default_reader_properties(), + const ReaderProperties& props = defaultReaderProperties(), std::shared_ptr metadata = NULLPTR); virtual ~Contents() = default; - // Perform any cleanup associated with the file contents - virtual void Close() = 0; - virtual std::shared_ptr GetRowGroup(int i) = 0; + // Perform any cleanup associated with the file contents. + virtual void close() = 0; + virtual std::shared_ptr getRowGroup(int i) = 0; virtual std::shared_ptr metadata() const = 0; - virtual std::shared_ptr GetPageIndexReader() = 0; - virtual BloomFilterReader& GetBloomFilterReader() = 0; + virtual std::shared_ptr getPageIndexReader() = 0; + virtual BloomFilterReader& getBloomFilterReader() = 0; }; ParquetFileReader(); ~ParquetFileReader(); - // Create a file reader instance from an Arrow file object. Thread-safety is - // the responsibility of the file implementation - static std::unique_ptr Open( + // Create a file reader instance from an Arrow file object. Thread-safety is. + // The responsibility of the file implementation. + static std::unique_ptr open( std::shared_ptr<::arrow::io::RandomAccessFile> source, - const ReaderProperties& props = default_reader_properties(), + const ReaderProperties& props = defaultReaderProperties(), std::shared_ptr metadata = NULLPTR); - // API Convenience to open a serialized Parquet file on disk, using Arrow IO - // interfaces. - static std::unique_ptr OpenFile( + // API Convenience to open a serialized Parquet file on disk, using Arrow IO. + // Interfaces. + static std::unique_ptr openFile( const std::string& path, - bool memory_map = false, - const ReaderProperties& props = default_reader_properties(), + bool memoryMap = false, + const ReaderProperties& props = defaultReaderProperties(), std::shared_ptr metadata = NULLPTR); // Asynchronously open a file reader from an Arrow file object. // Does not throw - all errors are reported through the Future. - static ::arrow::Future> OpenAsync( + static ::arrow::Future> openAsync( std::shared_ptr<::arrow::io::RandomAccessFile> source, - const ReaderProperties& props = default_reader_properties(), + const ReaderProperties& props = defaultReaderProperties(), std::shared_ptr metadata = NULLPTR); - void Open(std::unique_ptr contents); - void Close(); + void open(std::unique_ptr contents); + void close(); - // The RowGroupReader is owned by the FileReader - std::shared_ptr RowGroup(int i); + // The RowGroupReader is owned by the FileReader. + std::shared_ptr rowGroup(int i); - // Returns the file metadata. Only one instance is ever created + // Returns the file metadata. Only one instance is ever created. std::shared_ptr metadata() const; /// Returns the PageIndexReader. Only one instance is ever created. /// /// If the file does not have the page index, nullptr may be returned. - /// Because it pays to check existence of page index in the file, it - /// is possible to return a non null value even if page index does - /// not exist. It is the caller's responsibility to check the return - /// value and follow-up calls to PageIndexReader. + /// Because it pays to check existence of page index in the file, it. + /// Is possible to return a non null value even if page index does. + /// Not exist. It is the caller's responsibility to check the return. + /// Value and follow-up calls to PageIndexReader. /// - /// WARNING: The returned PageIndexReader must not outlive the + /// WARNING: The returned PageIndexReader must not outlive the. /// ParquetFileReader. Initialize GetPageIndexReader() is not thread-safety. - std::shared_ptr GetPageIndexReader(); + std::shared_ptr getPageIndexReader(); /// Returns the BloomFilterReader. Only one instance is ever created. /// - /// WARNING: The returned BloomFilterReader must not outlive the + /// WARNING: The returned BloomFilterReader must not outlive the. /// ParquetFileReader. Initialize GetBloomFilterReader() is not thread-safety. - BloomFilterReader& GetBloomFilterReader(); + BloomFilterReader& getBloomFilterReader(); /// Pre-buffer the specified column indices in all row groups. /// - /// Readers can optionally call this to cache the necessary slices - /// of the file in-memory before deserialization. Arrow readers can - /// automatically do this via an option. This is intended to - /// increase performance when reading from high-latency filesystems - /// (e.g. Amazon S3). + /// Readers can optionally call this to cache the necessary slices. + /// Of the file in-memory before deserialization. Arrow readers can. + /// Automatically do this via an option. This is intended to. + /// Increase performance when reading from high-latency filesystems. + /// (E.g. Amazon S3). /// - /// After calling this, creating readers for row groups/column - /// indices that were not buffered may fail. Creating multiple - /// readers for the a subset of the buffered regions is - /// acceptable. This may be called again to buffer a different set - /// of row groups/columns. + /// After calling this, creating readers for row groups/column. + /// Indices that were not buffered may fail. Creating multiple. + /// Readers for the a subset of the buffered regions is. + /// Acceptable. This may be called again to buffer a different set. + /// Of row groups/columns. /// - /// If memory usage is a concern, note that data will remain - /// buffered in memory until either \a PreBuffer() is called again, - /// or the reader itself is destructed. Reading - and buffering - - /// only one row group at a time may be useful. + /// If memory usage is a concern, note that data will remain. + /// Buffered in memory until either \a PreBuffer() is called again,. + /// Or the reader itself is destructed. Reading - and buffering -. + /// Only one row group at a time may be useful. /// /// This method may throw. - void PreBuffer( - const std::vector& row_groups, - const std::vector& column_indices, + void preBuffer( + const std::vector& rowGroups, + const std::vector& columnIndices, const ::arrow::io::IOContext& ctx, const ::arrow::io::CacheOptions& options); /// Wait for the specified row groups and column indices to be pre-buffered. /// - /// After the returned Future completes, reading the specified row - /// groups/columns will not block. + /// After the returned Future completes, reading the specified row. + /// Groups/columns will not block. /// /// PreBuffer must be called first. This method does not throw. - ::arrow::Future<> WhenBuffered( - const std::vector& row_groups, - const std::vector& column_indices) const; + ::arrow::Future<> whenBuffered( + const std::vector& rowGroups, + const std::vector& columnIndices) const; private: - // Holds a pointer to an instance of Contents implementation + // Holds a pointer to an instance of Contents implementation. std::unique_ptr contents_; }; -// Read only Parquet file metadata +// Read only Parquet file metadata. std::shared_ptr PARQUET_EXPORT -ReadMetaData(const std::shared_ptr<::arrow::io::RandomAccessFile>& source); +readMetaData(const std::shared_ptr<::arrow::io::RandomAccessFile>& source); -/// \brief Scan all values in file. Useful for performance testing -/// \param[in] columns the column numbers to scan. If empty scans all -/// \param[in] column_batch_size number of values to read at a time when -/// scanning column \param[in] reader a ParquetFileReader instance \return -/// number of semantic rows in file +/// \brief Scan all values in file. Useful for performance testing. +/// \param[in] columns the column numbers to scan. If empty scans all. +/// \param[in] column_batch_size number of values to read at a time when. +/// Scanning column \param[in] reader a ParquetFileReader instance \return. +/// Number of semantic rows in file. PARQUET_EXPORT -int64_t ScanFileContents( +int64_t scanFileContents( std::vector columns, - const int32_t column_batch_size, + const int32_t columnBatchSize, ParquetFileReader* reader); } // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/tests/FileSerializeTest.cpp b/velox/dwio/parquet/writer/arrow/tests/FileSerializeTest.cpp index 1a939b49a9a..98527fe63ed 100644 --- a/velox/dwio/parquet/writer/arrow/tests/FileSerializeTest.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/FileSerializeTest.cpp @@ -42,291 +42,288 @@ template class TestSerialize : public PrimitiveTypedTest { public: void SetUp() { - num_columns_ = 4; - num_rowgroups_ = 4; - rows_per_rowgroup_ = 50; - rows_per_batch_ = 10; - this->SetUpSchema(Repetition::OPTIONAL, num_columns_); + numColumns_ = 4; + numRowgroups_ = 4; + rowsPerRowgroup_ = 50; + rowsPerBatch_ = 10; + this->setUpSchema(Repetition::kOptional, numColumns_); } protected: - int num_columns_; - int num_rowgroups_; - int rows_per_rowgroup_; - int rows_per_batch_; + int numColumns_; + int numRowgroups_; + int rowsPerRowgroup_; + int rowsPerBatch_; - void FileSerializeTest(Compression::type codec_type) { - FileSerializeTest(codec_type, codec_type); + void fileSerializeTest(Compression::type codecType) { + fileSerializeTest(codecType, codecType); } - void FileSerializeTest( - Compression::type codec_type, - Compression::type expected_codec_type) { - auto sink = CreateOutputStream(); + void fileSerializeTest( + Compression::type codecType, + Compression::type expectedCodecType) { + auto sink = createOutputStream(); auto gnode = std::static_pointer_cast(this->node_); - WriterProperties::Builder prop_builder; + WriterProperties::Builder propBuilder; - for (int i = 0; i < num_columns_; ++i) { - prop_builder.compression(this->schema_.Column(i)->name(), codec_type); + for (int i = 0; i < numColumns_; ++i) { + propBuilder.compression(this->schema_.column(i)->name(), codecType); } - std::shared_ptr writer_properties = prop_builder.build(); - - auto file_writer = ParquetFileWriter::Open(sink, gnode, writer_properties); - this->GenerateData(rows_per_rowgroup_); - for (int rg = 0; rg < num_rowgroups_ / 2; ++rg) { - RowGroupWriter* row_group_writer; - row_group_writer = file_writer->AppendRowGroup(); - for (int col = 0; col < num_columns_; ++col) { - auto column_writer = static_cast*>( - row_group_writer->NextColumn()); - column_writer->WriteBatch( - rows_per_rowgroup_, - this->def_levels_.data(), + std::shared_ptr WriterProperties = propBuilder.build(); + + auto fileWriter = ParquetFileWriter::open(sink, gnode, WriterProperties); + this->generateData(rowsPerRowgroup_); + for (int rg = 0; rg < numRowgroups_ / 2; ++rg) { + RowGroupWriter* rowGroupWriter; + rowGroupWriter = fileWriter->appendRowGroup(); + for (int col = 0; col < numColumns_; ++col) { + auto columnWriter = static_cast*>( + rowGroupWriter->nextColumn()); + columnWriter->writeBatch( + rowsPerRowgroup_, + this->defLevels_.data(), nullptr, - this->values_ptr_); - column_writer->Close(); - // Ensure column() API which is specific to BufferedRowGroup cannot be - // called - ASSERT_THROW(row_group_writer->column(col), ParquetException); + this->valuesPtr_); + columnWriter->close(); + // Ensure column() API which is specific to BufferedRowGroup cannot be. + // Called. + ASSERT_THROW(rowGroupWriter->column(col), ParquetException); } - EXPECT_EQ(0, row_group_writer->total_compressed_bytes()); - EXPECT_NE(0, row_group_writer->total_bytes_written()); - EXPECT_NE(0, row_group_writer->total_compressed_bytes_written()); - row_group_writer->Close(); - EXPECT_EQ(0, row_group_writer->total_compressed_bytes()); - EXPECT_NE(0, row_group_writer->total_bytes_written()); - EXPECT_NE(0, row_group_writer->total_compressed_bytes_written()); + EXPECT_EQ(0, rowGroupWriter->totalCompressedBytes()); + EXPECT_NE(0, rowGroupWriter->totalBytesWritten()); + EXPECT_NE(0, rowGroupWriter->totalCompressedBytesWritten()); + rowGroupWriter->close(); + EXPECT_EQ(0, rowGroupWriter->totalCompressedBytes()); + EXPECT_NE(0, rowGroupWriter->totalBytesWritten()); + EXPECT_NE(0, rowGroupWriter->totalCompressedBytesWritten()); } - // Write half BufferedRowGroups - for (int rg = 0; rg < num_rowgroups_ / 2; ++rg) { - RowGroupWriter* row_group_writer; - row_group_writer = file_writer->AppendBufferedRowGroup(); - for (int batch = 0; batch < (rows_per_rowgroup_ / rows_per_batch_); - ++batch) { - for (int col = 0; col < num_columns_; ++col) { - auto column_writer = static_cast*>( - row_group_writer->column(col)); - column_writer->WriteBatch( - rows_per_batch_, - this->def_levels_.data() + (batch * rows_per_batch_), + // Write half BufferedRowGroups. + for (int rg = 0; rg < numRowgroups_ / 2; ++rg) { + RowGroupWriter* rowGroupWriter; + rowGroupWriter = fileWriter->appendBufferedRowGroup(); + for (int batch = 0; batch < (rowsPerRowgroup_ / rowsPerBatch_); ++batch) { + for (int col = 0; col < numColumns_; ++col) { + auto columnWriter = static_cast*>( + rowGroupWriter->column(col)); + columnWriter->writeBatch( + rowsPerBatch_, + this->defLevels_.data() + (batch * rowsPerBatch_), nullptr, - this->values_ptr_ + (batch * rows_per_batch_)); - // Ensure NextColumn() API which is specific to RowGroup cannot be - // called - ASSERT_THROW(row_group_writer->NextColumn(), ParquetException); + this->valuesPtr_ + (batch * rowsPerBatch_)); + // Ensure NextColumn() API which is specific to RowGroup cannot be. + // Called. + ASSERT_THROW(rowGroupWriter->nextColumn(), ParquetException); } } - // total_compressed_bytes() may equal to 0 if no dictionary enabled and no - // buffered values. - EXPECT_EQ(0, row_group_writer->total_bytes_written()); - EXPECT_EQ(0, row_group_writer->total_compressed_bytes_written()); - for (int col = 0; col < num_columns_; ++col) { - auto column_writer = static_cast*>( - row_group_writer->column(col)); - column_writer->Close(); + // Total_compressed_bytes() may equal to 0 if no dictionary enabled and + // no. Buffered values. + EXPECT_EQ(0, rowGroupWriter->totalBytesWritten()); + EXPECT_EQ(0, rowGroupWriter->totalCompressedBytesWritten()); + for (int col = 0; col < numColumns_; ++col) { + auto columnWriter = static_cast*>( + rowGroupWriter->column(col)); + columnWriter->close(); } - row_group_writer->Close(); - EXPECT_EQ(0, row_group_writer->total_compressed_bytes()); - EXPECT_NE(0, row_group_writer->total_bytes_written()); - EXPECT_NE(0, row_group_writer->total_compressed_bytes_written()); + rowGroupWriter->close(); + EXPECT_EQ(0, rowGroupWriter->totalCompressedBytes()); + EXPECT_NE(0, rowGroupWriter->totalBytesWritten()); + EXPECT_NE(0, rowGroupWriter->totalCompressedBytesWritten()); } - file_writer->Close(); + fileWriter->close(); PARQUET_ASSIGN_OR_THROW(auto buffer, sink->Finish()); - int num_rows_ = num_rowgroups_ * rows_per_rowgroup_; + int numRows_ = numRowgroups_ * rowsPerRowgroup_; auto source = std::make_shared<::arrow::io::BufferReader>(buffer); - auto file_reader = ParquetFileReader::Open(source); - ASSERT_EQ(num_columns_, file_reader->metadata()->num_columns()); - ASSERT_EQ(num_rowgroups_, file_reader->metadata()->num_row_groups()); - ASSERT_EQ(num_rows_, file_reader->metadata()->num_rows()); - - for (int rg = 0; rg < num_rowgroups_; ++rg) { - auto rg_reader = file_reader->RowGroup(rg); - auto rg_metadata = rg_reader->metadata(); - ASSERT_EQ(num_columns_, rg_metadata->num_columns()); - ASSERT_EQ(rows_per_rowgroup_, rg_metadata->num_rows()); + auto fileReader = ParquetFileReader::open(source); + ASSERT_EQ(numColumns_, fileReader->metadata()->numColumns()); + ASSERT_EQ(numRowgroups_, fileReader->metadata()->numRowGroups()); + ASSERT_EQ(numRows_, fileReader->metadata()->numRows()); + + for (int rg = 0; rg < numRowgroups_; ++rg) { + auto rgReader = fileReader->rowGroup(rg); + auto rgMetadata = rgReader->metadata(); + ASSERT_EQ(numColumns_, rgMetadata->numColumns()); + ASSERT_EQ(rowsPerRowgroup_, rgMetadata->numRows()); // Check that the specified compression was actually used. - ASSERT_EQ( - expected_codec_type, rg_metadata->ColumnChunk(0)->compression()); - - const int64_t total_byte_size = rg_metadata->total_byte_size(); - const int64_t total_compressed_size = - rg_metadata->total_compressed_size(); - if (expected_codec_type == Compression::UNCOMPRESSED) { - ASSERT_EQ(total_byte_size, total_compressed_size); + ASSERT_EQ(expectedCodecType, rgMetadata->columnChunk(0)->compression()); + + const int64_t totalByteSize = rgMetadata->totalByteSize(); + const int64_t totalCompressedSize = rgMetadata->totalCompressedSize(); + if (expectedCodecType == Compression::UNCOMPRESSED) { + ASSERT_EQ(totalByteSize, totalCompressedSize); } else { - ASSERT_NE(total_byte_size, total_compressed_size); + ASSERT_NE(totalByteSize, totalCompressedSize); } - int64_t total_column_byte_size = 0; - int64_t total_column_compressed_size = 0; - - for (int i = 0; i < num_columns_; ++i) { - int64_t values_read; - ASSERT_FALSE(rg_metadata->ColumnChunk(i)->has_index_page()); - total_column_byte_size += - rg_metadata->ColumnChunk(i)->total_uncompressed_size(); - total_column_compressed_size += - rg_metadata->ColumnChunk(i)->total_compressed_size(); - - std::vector def_levels_out(rows_per_rowgroup_); - std::vector rep_levels_out(rows_per_rowgroup_); - auto col_reader = std::static_pointer_cast>( - rg_reader->Column(i)); - this->SetupValuesOut(rows_per_rowgroup_); - col_reader->ReadBatch( - rows_per_rowgroup_, - def_levels_out.data(), - rep_levels_out.data(), - this->values_out_ptr_, - &values_read); - this->SyncValuesOut(); - ASSERT_EQ(rows_per_rowgroup_, values_read); - ASSERT_EQ(this->values_, this->values_out_); - ASSERT_EQ(this->def_levels_, def_levels_out); + int64_t totalColumnByteSize = 0; + int64_t totalColumnCompressedSize = 0; + + for (int i = 0; i < numColumns_; ++i) { + int64_t valuesRead; + ASSERT_FALSE(rgMetadata->columnChunk(i)->hasIndexPage()); + totalColumnByteSize += + rgMetadata->columnChunk(i)->totalUncompressedSize(); + totalColumnCompressedSize += + rgMetadata->columnChunk(i)->totalCompressedSize(); + + std::vector defLevelsOut(rowsPerRowgroup_); + std::vector repLevelsOut(rowsPerRowgroup_); + auto colReader = std::static_pointer_cast>( + rgReader->column(i)); + this->setupValuesOut(rowsPerRowgroup_); + colReader->readBatch( + rowsPerRowgroup_, + defLevelsOut.data(), + repLevelsOut.data(), + this->valuesOutPtr_, + &valuesRead); + this->syncValuesOut(); + ASSERT_EQ(rowsPerRowgroup_, valuesRead); + ASSERT_EQ(this->values_, this->valuesOut_); + ASSERT_EQ(this->defLevels_, defLevelsOut); } - ASSERT_EQ(total_byte_size, total_column_byte_size); - ASSERT_EQ(total_compressed_size, total_column_compressed_size); + ASSERT_EQ(totalByteSize, totalColumnByteSize); + ASSERT_EQ(totalCompressedSize, totalColumnCompressedSize); } } - void UnequalNumRows( - int64_t max_rows, - const std::vector rows_per_column) { - auto sink = CreateOutputStream(); + void unequalNumRows( + int64_t maxRows, + const std::vector rowsPerColumn) { + auto sink = createOutputStream(); auto gnode = std::static_pointer_cast(this->node_); std::shared_ptr props = WriterProperties::Builder().build(); - auto file_writer = ParquetFileWriter::Open(sink, gnode, props); + auto fileWriter = ParquetFileWriter::open(sink, gnode, props); - RowGroupWriter* row_group_writer; - row_group_writer = file_writer->AppendRowGroup(); + RowGroupWriter* rowGroupWriter; + rowGroupWriter = fileWriter->appendRowGroup(); - this->GenerateData(max_rows); - for (int col = 0; col < num_columns_; ++col) { - auto column_writer = static_cast*>( - row_group_writer->NextColumn()); - column_writer->WriteBatch( - rows_per_column[col], - this->def_levels_.data(), + this->generateData(maxRows); + for (int col = 0; col < numColumns_; ++col) { + auto columnWriter = static_cast*>( + rowGroupWriter->nextColumn()); + columnWriter->writeBatch( + rowsPerColumn[col], + this->defLevels_.data(), nullptr, - this->values_ptr_); - column_writer->Close(); + this->valuesPtr_); + columnWriter->close(); } - row_group_writer->Close(); - file_writer->Close(); + rowGroupWriter->close(); + fileWriter->close(); } - void UnequalNumRowsBuffered( - int64_t max_rows, - const std::vector rows_per_column) { - auto sink = CreateOutputStream(); + void unequalNumRowsBuffered( + int64_t maxRows, + const std::vector rowsPerColumn) { + auto sink = createOutputStream(); auto gnode = std::static_pointer_cast(this->node_); std::shared_ptr props = WriterProperties::Builder().build(); - auto file_writer = ParquetFileWriter::Open(sink, gnode, props); + auto fileWriter = ParquetFileWriter::open(sink, gnode, props); - RowGroupWriter* row_group_writer; - row_group_writer = file_writer->AppendBufferedRowGroup(); + RowGroupWriter* rowGroupWriter; + rowGroupWriter = fileWriter->appendBufferedRowGroup(); - this->GenerateData(max_rows); - for (int col = 0; col < num_columns_; ++col) { - auto column_writer = static_cast*>( - row_group_writer->column(col)); - column_writer->WriteBatch( - rows_per_column[col], - this->def_levels_.data(), + this->generateData(maxRows); + for (int col = 0; col < numColumns_; ++col) { + auto columnWriter = static_cast*>( + rowGroupWriter->column(col)); + columnWriter->writeBatch( + rowsPerColumn[col], + this->defLevels_.data(), nullptr, - this->values_ptr_); - column_writer->Close(); + this->valuesPtr_); + columnWriter->close(); } - row_group_writer->Close(); - file_writer->Close(); + rowGroupWriter->close(); + fileWriter->close(); } - void RepeatedUnequalRows() { - // Optional and repeated, so definition and repetition levels - this->SetUpSchema(Repetition::REPEATED); + void repeatedUnequalRows() { + // Optional and repeated, so definition and repetition levels. + this->setUpSchema(Repetition::kRepeated); const int kNumRows = 100; - this->GenerateData(kNumRows); + this->generateData(kNumRows); - auto sink = CreateOutputStream(); + auto sink = createOutputStream(); auto gnode = std::static_pointer_cast(this->node_); std::shared_ptr props = WriterProperties::Builder().build(); - auto file_writer = ParquetFileWriter::Open(sink, gnode, props); + auto fileWriter = ParquetFileWriter::open(sink, gnode, props); - RowGroupWriter* row_group_writer; - row_group_writer = file_writer->AppendRowGroup(); + RowGroupWriter* rowGroupWriter; + rowGroupWriter = fileWriter->appendRowGroup(); - this->GenerateData(kNumRows); + this->generateData(kNumRows); - std::vector definition_levels(kNumRows, 1); - std::vector repetition_levels(kNumRows, 0); + std::vector definitionLevels(kNumRows, 1); + std::vector repetitionLevels(kNumRows, 0); { - auto column_writer = static_cast*>( - row_group_writer->NextColumn()); - column_writer->WriteBatch( + auto columnWriter = static_cast*>( + rowGroupWriter->nextColumn()); + columnWriter->writeBatch( kNumRows, - definition_levels.data(), - repetition_levels.data(), - this->values_ptr_); - column_writer->Close(); + definitionLevels.data(), + repetitionLevels.data(), + this->valuesPtr_); + columnWriter->close(); } - definition_levels[1] = 0; - repetition_levels[3] = 1; + definitionLevels[1] = 0; + repetitionLevels[3] = 1; { - auto column_writer = static_cast*>( - row_group_writer->NextColumn()); - column_writer->WriteBatch( + auto columnWriter = static_cast*>( + rowGroupWriter->nextColumn()); + columnWriter->writeBatch( kNumRows, - definition_levels.data(), - repetition_levels.data(), - this->values_ptr_); - column_writer->Close(); + definitionLevels.data(), + repetitionLevels.data(), + this->valuesPtr_); + columnWriter->close(); } } - void ZeroRowsRowGroup() { - auto sink = CreateOutputStream(); + void zeroRowsRowGroup() { + auto sink = createOutputStream(); auto gnode = std::static_pointer_cast(this->node_); std::shared_ptr props = WriterProperties::Builder().build(); - auto file_writer = ParquetFileWriter::Open(sink, gnode, props); + auto fileWriter = ParquetFileWriter::open(sink, gnode, props); - RowGroupWriter* row_group_writer; + RowGroupWriter* rowGroupWriter; - row_group_writer = file_writer->AppendRowGroup(); - for (int col = 0; col < num_columns_; ++col) { - auto column_writer = static_cast*>( - row_group_writer->NextColumn()); - column_writer->Close(); + rowGroupWriter = fileWriter->appendRowGroup(); + for (int col = 0; col < numColumns_; ++col) { + auto columnWriter = static_cast*>( + rowGroupWriter->nextColumn()); + columnWriter->close(); } - row_group_writer->Close(); + rowGroupWriter->close(); - row_group_writer = file_writer->AppendBufferedRowGroup(); - for (int col = 0; col < num_columns_; ++col) { - auto column_writer = static_cast*>( - row_group_writer->column(col)); - column_writer->Close(); + rowGroupWriter = fileWriter->appendBufferedRowGroup(); + for (int col = 0; col < numColumns_; ++col) { + auto columnWriter = static_cast*>( + rowGroupWriter->column(col)); + columnWriter->close(); } - row_group_writer->Close(); + rowGroupWriter->close(); - file_writer->Close(); + fileWriter->close(); } }; @@ -344,177 +341,176 @@ typedef ::testing::Types< TYPED_TEST_SUITE(TestSerialize, TestTypes); TYPED_TEST(TestSerialize, SmallFileUncompressed) { - ASSERT_NO_FATAL_FAILURE(this->FileSerializeTest(Compression::UNCOMPRESSED)); + ASSERT_NO_FATAL_FAILURE(this->fileSerializeTest(Compression::UNCOMPRESSED)); } TYPED_TEST(TestSerialize, TooFewRows) { - std::vector num_rows = {100, 100, 100, 99}; - ASSERT_THROW(this->UnequalNumRows(100, num_rows), ParquetException); - ASSERT_THROW(this->UnequalNumRowsBuffered(100, num_rows), ParquetException); + std::vector numRows = {100, 100, 100, 99}; + ASSERT_THROW(this->unequalNumRows(100, numRows), ParquetException); + ASSERT_THROW(this->unequalNumRowsBuffered(100, numRows), ParquetException); } TYPED_TEST(TestSerialize, TooManyRows) { - std::vector num_rows = {100, 100, 100, 101}; - ASSERT_THROW(this->UnequalNumRows(101, num_rows), ParquetException); - ASSERT_THROW(this->UnequalNumRowsBuffered(101, num_rows), ParquetException); + std::vector numRows = {100, 100, 100, 101}; + ASSERT_THROW(this->unequalNumRows(101, numRows), ParquetException); + ASSERT_THROW(this->unequalNumRowsBuffered(101, numRows), ParquetException); } TYPED_TEST(TestSerialize, ZeroRows) { - ASSERT_NO_THROW(this->ZeroRowsRowGroup()); + ASSERT_NO_THROW(this->zeroRowsRowGroup()); } TYPED_TEST(TestSerialize, RepeatedTooFewRows) { - ASSERT_THROW(this->RepeatedUnequalRows(), ParquetException); + ASSERT_THROW(this->repeatedUnequalRows(), ParquetException); } TYPED_TEST(TestSerialize, SmallFileSnappy) { - ASSERT_NO_FATAL_FAILURE(this->FileSerializeTest(Compression::SNAPPY)); + ASSERT_NO_FATAL_FAILURE(this->fileSerializeTest(Compression::SNAPPY)); } #ifdef ARROW_WITH_BROTLI TYPED_TEST(TestSerialize, SmallFileBrotli) { - ASSERT_NO_FATAL_FAILURE(this->FileSerializeTest(Compression::BROTLI)); + ASSERT_NO_FATAL_FAILURE(this->fileSerializeTest(Compression::BROTLI)); } #endif TYPED_TEST(TestSerialize, SmallFileGzip) { - ASSERT_NO_FATAL_FAILURE(this->FileSerializeTest(Compression::GZIP)); + ASSERT_NO_FATAL_FAILURE(this->fileSerializeTest(Compression::GZIP)); } TYPED_TEST(TestSerialize, SmallFileLz4) { - ASSERT_NO_FATAL_FAILURE(this->FileSerializeTest(Compression::LZ4)); + ASSERT_NO_FATAL_FAILURE(this->fileSerializeTest(Compression::LZ4)); } TYPED_TEST(TestSerialize, SmallFileLz4Hadoop) { - ASSERT_NO_FATAL_FAILURE(this->FileSerializeTest(Compression::LZ4_HADOOP)); + ASSERT_NO_FATAL_FAILURE(this->fileSerializeTest(Compression::LZ4_HADOOP)); } TYPED_TEST(TestSerialize, SmallFileZstd) { - ASSERT_NO_FATAL_FAILURE(this->FileSerializeTest(Compression::ZSTD)); + ASSERT_NO_FATAL_FAILURE(this->fileSerializeTest(Compression::ZSTD)); } TEST(TestBufferedRowGroupWriter, DisabledDictionary) { // PARQUET-1706: - // Wrong dictionary_page_offset when writing only data pages via - // BufferedPageWriter - auto sink = CreateOutputStream(); - auto writer_props = WriterProperties::Builder().disable_dictionary()->build(); + // Wrong dictionary_page_offset when writing only data pages via. + // BufferedPageWriter. + auto sink = createOutputStream(); + auto writerProps = WriterProperties::Builder().disableDictionary()->build(); schema::NodeVector fields; fields.push_back( - PrimitiveNode::Make("col", Repetition::REQUIRED, Type::INT32)); + PrimitiveNode::make("col", Repetition::kRequired, Type::kInt32)); auto schema = std::static_pointer_cast( - GroupNode::Make("schema", Repetition::REQUIRED, fields)); - auto file_writer = ParquetFileWriter::Open(sink, schema, writer_props); - auto rg_writer = file_writer->AppendBufferedRowGroup(); - auto col_writer = static_cast(rg_writer->column(0)); + GroupNode::make("schema", Repetition::kRequired, fields)); + auto fileWriter = ParquetFileWriter::open(sink, schema, writerProps); + auto rgWriter = fileWriter->appendBufferedRowGroup(); + auto colWriter = static_cast(rgWriter->column(0)); int value = 0; - col_writer->WriteBatch(1, nullptr, nullptr, &value); - rg_writer->Close(); - file_writer->Close(); + colWriter->writeBatch(1, nullptr, nullptr, &value); + rgWriter->close(); + fileWriter->close(); PARQUET_ASSIGN_OR_THROW(auto buffer, sink->Finish()); auto source = std::make_shared<::arrow::io::BufferReader>(buffer); - auto file_reader = ParquetFileReader::Open(source); - ASSERT_EQ(1, file_reader->metadata()->num_row_groups()); - auto rg_reader = file_reader->RowGroup(0); - ASSERT_EQ(1, rg_reader->metadata()->num_columns()); - ASSERT_EQ(1, rg_reader->metadata()->num_rows()); - ASSERT_FALSE(rg_reader->metadata()->ColumnChunk(0)->has_dictionary_page()); + auto fileReader = ParquetFileReader::open(source); + ASSERT_EQ(1, fileReader->metadata()->numRowGroups()); + auto rgReader = fileReader->rowGroup(0); + ASSERT_EQ(1, rgReader->metadata()->numColumns()); + ASSERT_EQ(1, rgReader->metadata()->numRows()); + ASSERT_FALSE(rgReader->metadata()->columnChunk(0)->hasDictionaryPage()); } TEST(TestBufferedRowGroupWriter, MultiPageDisabledDictionary) { constexpr int kValueCount = 10000; constexpr int kPageSize = 16384; - auto sink = CreateOutputStream(); - auto writer_props = WriterProperties::Builder() - .disable_dictionary() - ->data_pagesize(kPageSize) - ->build(); + auto sink = createOutputStream(); + auto writerProps = WriterProperties::Builder() + .disableDictionary() + ->dataPagesize(kPageSize) + ->build(); schema::NodeVector fields; fields.push_back( - PrimitiveNode::Make("col", Repetition::REQUIRED, Type::INT32)); + PrimitiveNode::make("col", Repetition::kRequired, Type::kInt32)); auto schema = std::static_pointer_cast( - GroupNode::Make("schema", Repetition::REQUIRED, fields)); - auto file_writer = ParquetFileWriter::Open(sink, schema, writer_props); - auto rg_writer = file_writer->AppendBufferedRowGroup(); - auto col_writer = static_cast(rg_writer->column(0)); - std::vector values_in; + GroupNode::make("schema", Repetition::kRequired, fields)); + auto fileWriter = ParquetFileWriter::open(sink, schema, writerProps); + auto rgWriter = fileWriter->appendBufferedRowGroup(); + auto colWriter = static_cast(rgWriter->column(0)); + std::vector valuesIn; for (int i = 0; i < kValueCount; ++i) { - values_in.push_back((i % 100) + 1); + valuesIn.push_back((i % 100) + 1); } - col_writer->WriteBatch(kValueCount, nullptr, nullptr, values_in.data()); - rg_writer->Close(); - file_writer->Close(); + colWriter->writeBatch(kValueCount, nullptr, nullptr, valuesIn.data()); + rgWriter->close(); + fileWriter->close(); PARQUET_ASSIGN_OR_THROW(auto buffer, sink->Finish()); auto source = std::make_shared<::arrow::io::BufferReader>(buffer); - auto file_reader = ParquetFileReader::Open(source); - auto file_metadata = file_reader->metadata(); - ASSERT_EQ(1, file_reader->metadata()->num_row_groups()); - std::vector values_out(kValueCount); - for (int r = 0; r < file_metadata->num_row_groups(); ++r) { - auto rg_reader = file_reader->RowGroup(r); - ASSERT_EQ(1, rg_reader->metadata()->num_columns()); - ASSERT_EQ(kValueCount, rg_reader->metadata()->num_rows()); - int64_t total_values_read = 0; - std::shared_ptr col_reader; - ASSERT_NO_THROW(col_reader = rg_reader->Column(0)); - Int32Reader* int32_reader = static_cast(col_reader.get()); + auto fileReader = ParquetFileReader::open(source); + auto fileMetadata = fileReader->metadata(); + ASSERT_EQ(1, fileReader->metadata()->numRowGroups()); + std::vector valuesOut(kValueCount); + for (int r = 0; r < fileMetadata->numRowGroups(); ++r) { + auto rgReader = fileReader->rowGroup(r); + ASSERT_EQ(1, rgReader->metadata()->numColumns()); + ASSERT_EQ(kValueCount, rgReader->metadata()->numRows()); + int64_t totalValuesRead = 0; + std::shared_ptr colReader; + ASSERT_NO_THROW(colReader = rgReader->column(0)); + Int32Reader* int32Reader = static_cast(colReader.get()); int64_t vn = kValueCount; - int32_t* vx = values_out.data(); - while (int32_reader->HasNext()) { - int64_t values_read; - int32_reader->ReadBatch(vn, nullptr, nullptr, vx, &values_read); - vn -= values_read; - vx += values_read; - total_values_read += values_read; + int32_t* vx = valuesOut.data(); + while (int32Reader->hasNext()) { + int64_t valuesRead; + int32Reader->readBatch(vn, nullptr, nullptr, vx, &valuesRead); + vn -= valuesRead; + vx += valuesRead; + totalValuesRead += valuesRead; } - ASSERT_EQ(kValueCount, total_values_read); - ASSERT_EQ(values_in, values_out); + ASSERT_EQ(kValueCount, totalValuesRead); + ASSERT_EQ(valuesIn, valuesOut); } } TEST(ParquetRoundtrip, AllNulls) { - auto primitive_node = - PrimitiveNode::Make("nulls", Repetition::OPTIONAL, nullptr, Type::INT32); - schema::NodeVector columns({primitive_node}); + auto primitiveNode = PrimitiveNode::make( + "nulls", Repetition::kOptional, nullptr, Type::kInt32); + schema::NodeVector columns({primitiveNode}); - auto root_node = - GroupNode::Make("root", Repetition::REQUIRED, columns, nullptr); + auto rootNode = + GroupNode::make("root", Repetition::kRequired, columns, nullptr); - auto sink = CreateOutputStream(); + auto sink = createOutputStream(); - auto file_writer = ParquetFileWriter::Open( - sink, std::static_pointer_cast(root_node)); - auto row_group_writer = file_writer->AppendRowGroup(); - auto column_writer = - static_cast(row_group_writer->NextColumn()); + auto fileWriter = ParquetFileWriter::open( + sink, std::static_pointer_cast(rootNode)); + auto rowGroupWriter = fileWriter->appendRowGroup(); + auto columnWriter = static_cast(rowGroupWriter->nextColumn()); int32_t values[3]; - int16_t def_levels[] = {0, 0, 0}; + int16_t defLevels[] = {0, 0, 0}; - column_writer->WriteBatch(3, def_levels, nullptr, values); + columnWriter->writeBatch(3, defLevels, nullptr, values); - column_writer->Close(); - row_group_writer->Close(); - file_writer->Close(); + columnWriter->close(); + rowGroupWriter->close(); + fileWriter->close(); - ReaderProperties props = default_reader_properties(); - props.enable_buffered_stream(); + ReaderProperties props = defaultReaderProperties(); + props.enableBufferedStream(); PARQUET_ASSIGN_OR_THROW(auto buffer, sink->Finish()); auto source = std::make_shared<::arrow::io::BufferReader>(buffer); - auto file_reader = ParquetFileReader::Open(source, props); - auto row_group_reader = file_reader->RowGroup(0); - auto column_reader = - std::static_pointer_cast(row_group_reader->Column(0)); - - int64_t values_read; - def_levels[0] = -1; - def_levels[1] = -1; - def_levels[2] = -1; - column_reader->ReadBatch(3, def_levels, nullptr, values, &values_read); - EXPECT_THAT(def_levels, ElementsAre(0, 0, 0)); + auto fileReader = ParquetFileReader::open(source, props); + auto RowGroupReader = fileReader->rowGroup(0); + auto ColumnReader = + std::static_pointer_cast(RowGroupReader->column(0)); + + int64_t valuesRead; + defLevels[0] = -1; + defLevels[1] = -1; + defLevels[2] = -1; + ColumnReader->readBatch(3, defLevels, nullptr, values, &valuesRead); + EXPECT_THAT(defLevels, ElementsAre(0, 0, 0)); } } // namespace test diff --git a/velox/dwio/parquet/writer/arrow/tests/Hasher.h b/velox/dwio/parquet/writer/arrow/tests/Hasher.h index 3ffd42ed240..81f12b1f8f5 100644 --- a/velox/dwio/parquet/writer/arrow/tests/Hasher.h +++ b/velox/dwio/parquet/writer/arrow/tests/Hasher.h @@ -23,121 +23,121 @@ namespace facebook::velox::parquet::arrow { -// Abstract class for hash +// Abstract class for hash. class Hasher { public: /// Compute hash for 32 bits value by using its plain encoding result. /// /// @param value the value to hash. /// @return hash result. - virtual uint64_t Hash(int32_t value) const = 0; + virtual uint64_t hash(int32_t value) const = 0; /// Compute hash for 64 bits value by using its plain encoding result. /// /// @param value the value to hash. /// @return hash result. - virtual uint64_t Hash(int64_t value) const = 0; + virtual uint64_t hash(int64_t value) const = 0; /// Compute hash for float value by using its plain encoding result. /// /// @param value the value to hash. /// @return hash result. - virtual uint64_t Hash(float value) const = 0; + virtual uint64_t hash(float value) const = 0; /// Compute hash for double value by using its plain encoding result. /// /// @param value the value to hash. /// @return hash result. - virtual uint64_t Hash(double value) const = 0; + virtual uint64_t hash(double value) const = 0; /// Compute hash for Int96 value by using its plain encoding result. /// /// @param value the value to hash. /// @return hash result. - virtual uint64_t Hash(const Int96* value) const = 0; + virtual uint64_t hash(const Int96* value) const = 0; /// Compute hash for ByteArray value by using its plain encoding result. /// /// @param value the value to hash. /// @return hash result. - virtual uint64_t Hash(const ByteArray* value) const = 0; + virtual uint64_t hash(const ByteArray* value) const = 0; - /// Compute hash for fixed byte array value by using its plain encoding - /// result. + /// Compute hash for fixed byte array value by using its plain encoding. + /// Result. /// /// @param value the value address. /// @param len the value length. - virtual uint64_t Hash(const FLBA* value, uint32_t len) const = 0; + virtual uint64_t hash(const FLBA* value, uint32_t len) const = 0; - /// Batch compute hashes for 32 bits values by using its plain encoding - /// result. + /// Batch compute hashes for 32 bits values by using its plain encoding. + /// Result. /// /// @param values a pointer to the values to hash. /// @param num_values the number of values to hash. - /// @param hashes a pointer to the output hash values, its length should be - /// equal to num_values. - virtual void Hashes(const int32_t* values, int num_values, uint64_t* hashes) + /// @param hashes a pointer to the output hash values, its length should be. + /// Equal to num_values. + virtual void hashes(const int32_t* values, int numValues, uint64_t* hashes) const = 0; - /// Batch compute hashes for 64 bits values by using its plain encoding - /// result. + /// Batch compute hashes for 64 bits values by using its plain encoding. + /// Result. /// /// @param values a pointer to the values to hash. /// @param num_values the number of values to hash. - /// @param hashes a pointer to the output hash values, its length should be - /// equal to num_values. - virtual void Hashes(const int64_t* values, int num_values, uint64_t* hashes) + /// @param hashes a pointer to the output hash values, its length should be. + /// Equal to num_values. + virtual void hashes(const int64_t* values, int numValues, uint64_t* hashes) const = 0; /// Batch compute hashes for float values by using its plain encoding result. /// /// @param values a pointer to the values to hash. /// @param num_values the number of values to hash. - /// @param hashes a pointer to the output hash values, its length should be - /// equal to num_values. - virtual void Hashes(const float* values, int num_values, uint64_t* hashes) + /// @param hashes a pointer to the output hash values, its length should be. + /// Equal to num_values. + virtual void hashes(const float* values, int numValues, uint64_t* hashes) const = 0; /// Batch compute hashes for double values by using its plain encoding result. /// /// @param values a pointer to the values to hash. /// @param num_values the number of values to hash. - /// @param hashes a pointer to the output hash values, its length should be - /// equal to num_values. - virtual void Hashes(const double* values, int num_values, uint64_t* hashes) + /// @param hashes a pointer to the output hash values, its length should be. + /// Equal to num_values. + virtual void hashes(const double* values, int numValues, uint64_t* hashes) const = 0; /// Batch compute hashes for Int96 values by using its plain encoding result. /// /// @param values a pointer to the values to hash. /// @param num_values the number of values to hash. - /// @param hashes a pointer to the output hash values, its length should be - /// equal to num_values. - virtual void Hashes(const Int96* values, int num_values, uint64_t* hashes) + /// @param hashes a pointer to the output hash values, its length should be. + /// Equal to num_values. + virtual void hashes(const Int96* values, int numValues, uint64_t* hashes) const = 0; - /// Batch compute hashes for ByteArray values by using its plain encoding - /// result. + /// Batch compute hashes for ByteArray values by using its plain encoding. + /// Result. /// /// @param values a pointer to the values to hash. /// @param num_values the number of values to hash. - /// @param hashes a pointer to the output hash values, its length should be - /// equal to num_values. - virtual void Hashes(const ByteArray* values, int num_values, uint64_t* hashes) + /// @param hashes a pointer to the output hash values, its length should be. + /// Equal to num_values. + virtual void hashes(const ByteArray* values, int numValues, uint64_t* hashes) const = 0; - /// Batch compute hashes for fixed byte array values by using its plain - /// encoding result. + /// Batch compute hashes for fixed byte array values by using its plain. + /// Encoding result. /// /// @param values the value address. /// @param type_len the value length. /// @param num_values the number of values to hash. - /// @param hashes a pointer to the output hash values, its length should be - /// equal to num_values. - virtual void Hashes( + /// @param hashes a pointer to the output hash values, its length should be. + /// Equal to num_values. + virtual void hashes( const FLBA* values, - uint32_t type_len, - int num_values, + uint32_t typeLen, + int numValues, uint64_t* hashes) const = 0; virtual ~Hasher() = default; diff --git a/velox/dwio/parquet/writer/arrow/tests/MetadataTest.cpp b/velox/dwio/parquet/writer/arrow/tests/MetadataTest.cpp index a0d0d2e5cb5..2efb7aa02f8 100644 --- a/velox/dwio/parquet/writer/arrow/tests/MetadataTest.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/MetadataTest.cpp @@ -21,16 +21,20 @@ #include #include "arrow/util/key_value_metadata.h" +#include "velox/common/io/IoStatistics.h" +#include "velox/common/testutil/TempFilePath.h" #include "velox/dwio/parquet/reader/ParquetReader.h" #include "velox/dwio/parquet/writer/arrow/FileWriter.h" #include "velox/dwio/parquet/writer/arrow/tests/TestUtil.h" -#include "velox/exec/tests/utils/TempFilePath.h" namespace facebook::velox::parquet::arrow { namespace metadata { + +using namespace facebook::velox::common::testutil; + namespace { void writeToFile( - std::shared_ptr filePath, + std::shared_ptr filePath, std::shared_ptr buffer) { auto localWriteFile = std::make_unique(filePath->getPath(), false, false); @@ -41,29 +45,29 @@ void writeToFile( } } // namespace -// Helper function for generating table metadata -std::unique_ptr GenerateTableMetaData( +// Helper function for generating table metadata. +std::unique_ptr generateTableMetaData( const SchemaDescriptor& schema, const std::shared_ptr& props, const int64_t& nrows, - EncodedStatistics stats_int, - EncodedStatistics stats_float) { - auto f_builder = FileMetaDataBuilder::Make(&schema, props); - auto rg1_builder = f_builder->AppendRowGroup(); - // Write the metadata - // rowgroup1 metadata - auto col1_builder = rg1_builder->NextColumnChunk(); - auto col2_builder = rg1_builder->NextColumnChunk(); - // column metadata - std::map dict_encoding_stats( - {{Encoding::RLE_DICTIONARY, 1}}); - std::map data_encoding_stats( - {{Encoding::PLAIN, 1}, {Encoding::RLE, 1}}); - stats_int.set_is_signed(true); - col1_builder->SetStatistics(stats_int); - stats_float.set_is_signed(true); - col2_builder->SetStatistics(stats_float); - col1_builder->Finish( + EncodedStatistics statsInt, + EncodedStatistics statsFloat) { + auto fBuilder = FileMetaDataBuilder::make(&schema, props); + auto rg1Builder = fBuilder->appendRowGroup(); + // Write the metadata. + // Rowgroup1 metadata. + auto col1Builder = rg1Builder->nextColumnChunk(); + auto col2Builder = rg1Builder->nextColumnChunk(); + // Column metadata. + std::map dictEncodingStats( + {{Encoding::kRleDictionary, 1}}); + std::map dataEncodingStats( + {{Encoding::kPlain, 1}, {Encoding::kRle, 1}}); + statsInt.setIsSigned(true); + col1Builder->setStatistics(statsInt); + statsFloat.setIsSigned(true); + col2Builder->setStatistics(statsFloat); + col1Builder->finish( nrows / 2, 4, 0, @@ -72,9 +76,9 @@ std::unique_ptr GenerateTableMetaData( 600, true, false, - dict_encoding_stats, - data_encoding_stats); - col2_builder->Finish( + dictEncodingStats, + dataEncodingStats); + col2Builder->finish( nrows / 2, 24, 0, @@ -83,31 +87,31 @@ std::unique_ptr GenerateTableMetaData( 600, true, false, - dict_encoding_stats, - data_encoding_stats); - - rg1_builder->set_num_rows(nrows / 2); - rg1_builder->Finish(1024); - - // rowgroup2 metadata - auto rg2_builder = f_builder->AppendRowGroup(); - col1_builder = rg2_builder->NextColumnChunk(); - col2_builder = rg2_builder->NextColumnChunk(); - // column metadata - col1_builder->SetStatistics(stats_int); - col2_builder->SetStatistics(stats_float); - col1_builder->Finish( + dictEncodingStats, + dataEncodingStats); + + rg1Builder->setNumRows(nrows / 2); + rg1Builder->finish(1024); + + // Rowgroup2 metadata. + auto rg2Builder = fBuilder->appendRowGroup(); + col1Builder = rg2Builder->nextColumnChunk(); + col2Builder = rg2Builder->nextColumnChunk(); + // Column metadata. + col1Builder->setStatistics(statsInt); + col2Builder->setStatistics(statsFloat); + col1Builder->finish( nrows / 2, - /*dictionary_page_offset=*/0, + 0, 0, 10, 512, 600, - /*has_dictionary=*/false, false, - dict_encoding_stats, - data_encoding_stats); - col2_builder->Finish( + false, + dictEncodingStats, + dataEncodingStats); + col2Builder->finish( nrows / 2, 16, 0, @@ -116,17 +120,17 @@ std::unique_ptr GenerateTableMetaData( 600, true, false, - dict_encoding_stats, - data_encoding_stats); + dictEncodingStats, + dataEncodingStats); - rg2_builder->set_num_rows(nrows / 2); - rg2_builder->Finish(1024); + rg2Builder->setNumRows(nrows / 2); + rg2Builder->finish(1024); - // Return the metadata accessor - return f_builder->Finish(); + // Return the metadata accessor. + return fBuilder->finish(); } -void AssertEncodings( +void assertEncodings( const ColumnChunkMetaData& data, const std::set& expected) { std::set encodings( @@ -139,210 +143,210 @@ TEST(Metadata, TestBuildAccess) { schema::NodePtr root; SchemaDescriptor schema; - WriterProperties::Builder prop_builder; + WriterProperties::Builder propBuilder; std::shared_ptr props = - prop_builder.version(ParquetVersion::PARQUET_2_6)->build(); + propBuilder.version(ParquetVersion::PARQUET_2_6)->build(); - fields.push_back(schema::Int32("int_col", Repetition::REQUIRED)); - fields.push_back(schema::Float("float_col", Repetition::REQUIRED)); - root = schema::GroupNode::Make("schema", Repetition::REPEATED, fields); - schema.Init(root); + fields.push_back(schema::int32("int_col", Repetition::kRequired)); + fields.push_back(schema::floatType("float_col", Repetition::kRequired)); + root = schema::GroupNode::make("schema", Repetition::kRepeated, fields); + schema.init(root); int64_t nrows = 1000; - int32_t int_min = 100, int_max = 200; - EncodedStatistics stats_int; - stats_int.set_null_count(0) - .set_distinct_count(nrows) - .set_min(std::string(reinterpret_cast(&int_min), 4)) - .set_max(std::string(reinterpret_cast(&int_max), 4)); - EncodedStatistics stats_float; - float float_min = 100.100f, float_max = 200.200f; - stats_float.set_null_count(0) - .set_distinct_count(nrows) - .set_min(std::string(reinterpret_cast(&float_min), 4)) - .set_max(std::string(reinterpret_cast(&float_max), 4)); - - // Generate the metadata - auto f_accessor = - GenerateTableMetaData(schema, props, nrows, stats_int, stats_float); - - std::string f_accessor_serialized_metadata = f_accessor->SerializeToString(); - uint32_t expected_len = - static_cast(f_accessor_serialized_metadata.length()); - - // decoded_len is an in-out parameter - uint32_t decoded_len = expected_len; - auto f_accessor_copy = - FileMetaData::Make(f_accessor_serialized_metadata.data(), &decoded_len); - - // Check that all of the serialized data is consumed - ASSERT_EQ(expected_len, decoded_len); + int32_t intMin = 100, intMax = 200; + EncodedStatistics statsInt; + statsInt.setNullCount(0) + .setDistinctCount(nrows) + .setMin(std::string(reinterpret_cast(&intMin), 4)) + .setMax(std::string(reinterpret_cast(&intMax), 4)); + EncodedStatistics statsFloat; + float floatMin = 100.100f, floatMax = 200.200f; + statsFloat.setNullCount(0) + .setDistinctCount(nrows) + .setMin(std::string(reinterpret_cast(&floatMin), 4)) + .setMax(std::string(reinterpret_cast(&floatMax), 4)); + + // Generate the metadata. + auto fAccessor = + generateTableMetaData(schema, props, nrows, statsInt, statsFloat); + + std::string fAccessorSerializedMetadata = fAccessor->serializeToString(); + uint32_t expectedLen = + static_cast(fAccessorSerializedMetadata.length()); + + // Decoded_len is an in-out parameter. + uint32_t decodedLen = expectedLen; + auto fAccessorCopy = + FileMetaData::make(fAccessorSerializedMetadata.data(), &decodedLen); + + // Check that all of the serialized data is consumed. + ASSERT_EQ(expectedLen, decodedLen); // Run this block twice, one for f_accessor, one for f_accessor_copy. // To make sure SerializedMetadata was deserialized correctly. - std::vector f_accessors = { - f_accessor.get(), f_accessor_copy.get()}; - for (int loop_index = 0; loop_index < 2; loop_index++) { - // file metadata - ASSERT_EQ(nrows, f_accessors[loop_index]->num_rows()); - ASSERT_LE(0, static_cast(f_accessors[loop_index]->size())); - ASSERT_EQ(2, f_accessors[loop_index]->num_row_groups()); - ASSERT_EQ(ParquetVersion::PARQUET_2_6, f_accessors[loop_index]->version()); + std::vector fAccessors = { + fAccessor.get(), fAccessorCopy.get()}; + for (int loopIndex = 0; loopIndex < 2; loopIndex++) { + // File metadata. + ASSERT_EQ(nrows, fAccessors[loopIndex]->numRows()); + ASSERT_LE(0, static_cast(fAccessors[loopIndex]->size())); + ASSERT_EQ(2, fAccessors[loopIndex]->numRowGroups()); + ASSERT_EQ(ParquetVersion::PARQUET_2_6, fAccessors[loopIndex]->version()); ASSERT_TRUE( - f_accessors[loop_index]->created_by().find(DEFAULT_CREATED_BY) != + fAccessors[loopIndex]->createdBy().find(DEFAULT_CREATED_BY) != std::string::npos); - ASSERT_EQ(3, f_accessors[loop_index]->num_schema_elements()); - - // row group1 metadata - auto rg1_accessor = f_accessors[loop_index]->RowGroup(0); - ASSERT_EQ(2, rg1_accessor->num_columns()); - ASSERT_EQ(nrows / 2, rg1_accessor->num_rows()); - ASSERT_EQ(1024, rg1_accessor->total_byte_size()); - ASSERT_EQ(1024, rg1_accessor->total_compressed_size()); + ASSERT_EQ(3, fAccessors[loopIndex]->numSchemaElements()); + + // Row group1 metadata. + auto rg1Accessor = fAccessors[loopIndex]->rowGroup(0); + ASSERT_EQ(2, rg1Accessor->numColumns()); + ASSERT_EQ(nrows / 2, rg1Accessor->numRows()); + ASSERT_EQ(1024, rg1Accessor->totalByteSize()); + ASSERT_EQ(1024, rg1Accessor->totalCompressedSize()); EXPECT_EQ( - rg1_accessor->file_offset(), - rg1_accessor->ColumnChunk(0)->dictionary_page_offset()); - - auto rg1_column1 = rg1_accessor->ColumnChunk(0); - auto rg1_column2 = rg1_accessor->ColumnChunk(1); - ASSERT_EQ(true, rg1_column1->is_stats_set()); - ASSERT_EQ(true, rg1_column2->is_stats_set()); - ASSERT_EQ(stats_float.min(), rg1_column2->statistics()->EncodeMin()); - ASSERT_EQ(stats_float.max(), rg1_column2->statistics()->EncodeMax()); - ASSERT_EQ(stats_int.min(), rg1_column1->statistics()->EncodeMin()); - ASSERT_EQ(stats_int.max(), rg1_column1->statistics()->EncodeMax()); - ASSERT_EQ(0, rg1_column1->statistics()->null_count()); - ASSERT_EQ(0, rg1_column2->statistics()->null_count()); - ASSERT_EQ(nrows, rg1_column1->statistics()->distinct_count()); - ASSERT_EQ(nrows, rg1_column2->statistics()->distinct_count()); - ASSERT_EQ(DEFAULT_COMPRESSION_TYPE, rg1_column1->compression()); - ASSERT_EQ(DEFAULT_COMPRESSION_TYPE, rg1_column2->compression()); - ASSERT_EQ(nrows / 2, rg1_column1->num_values()); - ASSERT_EQ(nrows / 2, rg1_column2->num_values()); + rg1Accessor->fileOffset(), + rg1Accessor->columnChunk(0)->dictionaryPageOffset()); + + auto rg1Column1 = rg1Accessor->columnChunk(0); + auto rg1Column2 = rg1Accessor->columnChunk(1); + ASSERT_EQ(true, rg1Column1->isStatsSet()); + ASSERT_EQ(true, rg1Column2->isStatsSet()); + ASSERT_EQ(statsFloat.min(), rg1Column2->statistics()->encodeMin()); + ASSERT_EQ(statsFloat.max(), rg1Column2->statistics()->encodeMax()); + ASSERT_EQ(statsInt.min(), rg1Column1->statistics()->encodeMin()); + ASSERT_EQ(statsInt.max(), rg1Column1->statistics()->encodeMax()); + ASSERT_EQ(0, rg1Column1->statistics()->nullCount()); + ASSERT_EQ(0, rg1Column2->statistics()->nullCount()); + ASSERT_EQ(nrows, rg1Column1->statistics()->distinctCount()); + ASSERT_EQ(nrows, rg1Column2->statistics()->distinctCount()); + ASSERT_EQ(DEFAULT_COMPRESSION_TYPE, rg1Column1->compression()); + ASSERT_EQ(DEFAULT_COMPRESSION_TYPE, rg1Column2->compression()); + ASSERT_EQ(nrows / 2, rg1Column1->numValues()); + ASSERT_EQ(nrows / 2, rg1Column2->numValues()); { std::set encodings{ - Encoding::RLE, Encoding::RLE_DICTIONARY, Encoding::PLAIN}; - AssertEncodings(*rg1_column1, encodings); + Encoding::kRle, Encoding::kRleDictionary, Encoding::kPlain}; + assertEncodings(*rg1Column1, encodings); } { std::set encodings{ - Encoding::RLE, Encoding::RLE_DICTIONARY, Encoding::PLAIN}; - AssertEncodings(*rg1_column2, encodings); + Encoding::kRle, Encoding::kRleDictionary, Encoding::kPlain}; + assertEncodings(*rg1Column2, encodings); } - ASSERT_EQ(512, rg1_column1->total_compressed_size()); - ASSERT_EQ(512, rg1_column2->total_compressed_size()); - ASSERT_EQ(600, rg1_column1->total_uncompressed_size()); - ASSERT_EQ(600, rg1_column2->total_uncompressed_size()); - ASSERT_EQ(4, rg1_column1->dictionary_page_offset()); - ASSERT_EQ(24, rg1_column2->dictionary_page_offset()); - ASSERT_EQ(10, rg1_column1->data_page_offset()); - ASSERT_EQ(30, rg1_column2->data_page_offset()); - ASSERT_EQ(3, rg1_column1->encoding_stats().size()); - ASSERT_EQ(3, rg1_column2->encoding_stats().size()); - - auto rg2_accessor = f_accessors[loop_index]->RowGroup(1); - ASSERT_EQ(2, rg2_accessor->num_columns()); - ASSERT_EQ(nrows / 2, rg2_accessor->num_rows()); - ASSERT_EQ(1024, rg2_accessor->total_byte_size()); - ASSERT_EQ(1024, rg2_accessor->total_compressed_size()); + ASSERT_EQ(512, rg1Column1->totalCompressedSize()); + ASSERT_EQ(512, rg1Column2->totalCompressedSize()); + ASSERT_EQ(600, rg1Column1->totalUncompressedSize()); + ASSERT_EQ(600, rg1Column2->totalUncompressedSize()); + ASSERT_EQ(4, rg1Column1->dictionaryPageOffset()); + ASSERT_EQ(24, rg1Column2->dictionaryPageOffset()); + ASSERT_EQ(10, rg1Column1->dataPageOffset()); + ASSERT_EQ(30, rg1Column2->dataPageOffset()); + ASSERT_EQ(3, rg1Column1->encodingStats().size()); + ASSERT_EQ(3, rg1Column2->encodingStats().size()); + + auto rg2Accessor = fAccessors[loopIndex]->rowGroup(1); + ASSERT_EQ(2, rg2Accessor->numColumns()); + ASSERT_EQ(nrows / 2, rg2Accessor->numRows()); + ASSERT_EQ(1024, rg2Accessor->totalByteSize()); + ASSERT_EQ(1024, rg2Accessor->totalCompressedSize()); EXPECT_EQ( - rg2_accessor->file_offset(), - rg2_accessor->ColumnChunk(0)->data_page_offset()); - - auto rg2_column1 = rg2_accessor->ColumnChunk(0); - auto rg2_column2 = rg2_accessor->ColumnChunk(1); - ASSERT_EQ(true, rg2_column1->is_stats_set()); - ASSERT_EQ(true, rg2_column2->is_stats_set()); - ASSERT_EQ(stats_float.min(), rg2_column2->statistics()->EncodeMin()); - ASSERT_EQ(stats_float.max(), rg2_column2->statistics()->EncodeMax()); - ASSERT_EQ(stats_int.min(), rg1_column1->statistics()->EncodeMin()); - ASSERT_EQ(stats_int.max(), rg1_column1->statistics()->EncodeMax()); - ASSERT_EQ(0, rg2_column1->statistics()->null_count()); - ASSERT_EQ(0, rg2_column2->statistics()->null_count()); - ASSERT_EQ(nrows, rg2_column1->statistics()->distinct_count()); - ASSERT_EQ(nrows, rg2_column2->statistics()->distinct_count()); - ASSERT_EQ(nrows / 2, rg2_column1->num_values()); - ASSERT_EQ(nrows / 2, rg2_column2->num_values()); - ASSERT_EQ(DEFAULT_COMPRESSION_TYPE, rg2_column1->compression()); - ASSERT_EQ(DEFAULT_COMPRESSION_TYPE, rg2_column2->compression()); + rg2Accessor->fileOffset(), + rg2Accessor->columnChunk(0)->dataPageOffset()); + + auto rg2Column1 = rg2Accessor->columnChunk(0); + auto rg2Column2 = rg2Accessor->columnChunk(1); + ASSERT_EQ(true, rg2Column1->isStatsSet()); + ASSERT_EQ(true, rg2Column2->isStatsSet()); + ASSERT_EQ(statsFloat.min(), rg2Column2->statistics()->encodeMin()); + ASSERT_EQ(statsFloat.max(), rg2Column2->statistics()->encodeMax()); + ASSERT_EQ(statsInt.min(), rg1Column1->statistics()->encodeMin()); + ASSERT_EQ(statsInt.max(), rg1Column1->statistics()->encodeMax()); + ASSERT_EQ(0, rg2Column1->statistics()->nullCount()); + ASSERT_EQ(0, rg2Column2->statistics()->nullCount()); + ASSERT_EQ(nrows, rg2Column1->statistics()->distinctCount()); + ASSERT_EQ(nrows, rg2Column2->statistics()->distinctCount()); + ASSERT_EQ(nrows / 2, rg2Column1->numValues()); + ASSERT_EQ(nrows / 2, rg2Column2->numValues()); + ASSERT_EQ(DEFAULT_COMPRESSION_TYPE, rg2Column1->compression()); + ASSERT_EQ(DEFAULT_COMPRESSION_TYPE, rg2Column2->compression()); { - std::set encodings{Encoding::RLE, Encoding::PLAIN}; - AssertEncodings(*rg2_column1, encodings); + std::set encodings{Encoding::kRle, Encoding::kPlain}; + assertEncodings(*rg2Column1, encodings); } { std::set encodings{ - Encoding::RLE, Encoding::RLE_DICTIONARY, Encoding::PLAIN}; - AssertEncodings(*rg2_column2, encodings); + Encoding::kRle, Encoding::kRleDictionary, Encoding::kPlain}; + assertEncodings(*rg2Column2, encodings); } - ASSERT_EQ(512, rg2_column1->total_compressed_size()); - ASSERT_EQ(512, rg2_column2->total_compressed_size()); - ASSERT_EQ(600, rg2_column1->total_uncompressed_size()); - ASSERT_EQ(600, rg2_column2->total_uncompressed_size()); - EXPECT_FALSE(rg2_column1->has_dictionary_page()); - ASSERT_EQ(0, rg2_column1->dictionary_page_offset()); - ASSERT_EQ(16, rg2_column2->dictionary_page_offset()); - ASSERT_EQ(10, rg2_column1->data_page_offset()); - ASSERT_EQ(26, rg2_column2->data_page_offset()); - ASSERT_EQ(2, rg2_column1->encoding_stats().size()); - ASSERT_EQ(3, rg2_column2->encoding_stats().size()); - - // Test FileMetaData::set_file_path - ASSERT_TRUE(rg2_column1->file_path().empty()); - f_accessors[loop_index]->set_file_path("/foo/bar/bar.parquet"); - ASSERT_EQ("/foo/bar/bar.parquet", rg2_column1->file_path()); + ASSERT_EQ(512, rg2Column1->totalCompressedSize()); + ASSERT_EQ(512, rg2Column2->totalCompressedSize()); + ASSERT_EQ(600, rg2Column1->totalUncompressedSize()); + ASSERT_EQ(600, rg2Column2->totalUncompressedSize()); + EXPECT_FALSE(rg2Column1->hasDictionaryPage()); + ASSERT_EQ(0, rg2Column1->dictionaryPageOffset()); + ASSERT_EQ(16, rg2Column2->dictionaryPageOffset()); + ASSERT_EQ(10, rg2Column1->dataPageOffset()); + ASSERT_EQ(26, rg2Column2->dataPageOffset()); + ASSERT_EQ(2, rg2Column1->encodingStats().size()); + ASSERT_EQ(3, rg2Column2->encodingStats().size()); + + // Test FileMetaData::set_file_path. + ASSERT_TRUE(rg2Column1->filePath().empty()); + fAccessors[loopIndex]->setFilePath("/foo/bar/bar.parquet"); + ASSERT_EQ("/foo/bar/bar.parquet", rg2Column1->filePath()); } - // Test AppendRowGroups - auto f_accessor_2 = - GenerateTableMetaData(schema, props, nrows, stats_int, stats_float); - f_accessor->AppendRowGroups(*f_accessor_2); - ASSERT_EQ(4, f_accessor->num_row_groups()); - ASSERT_EQ(nrows * 2, f_accessor->num_rows()); - ASSERT_LE(0, static_cast(f_accessor->size())); - ASSERT_EQ(ParquetVersion::PARQUET_2_6, f_accessor->version()); + // Test AppendRowGroups. + auto fAccessor2 = + generateTableMetaData(schema, props, nrows, statsInt, statsFloat); + fAccessor->appendRowGroups(*fAccessor2); + ASSERT_EQ(4, fAccessor->numRowGroups()); + ASSERT_EQ(nrows * 2, fAccessor->numRows()); + ASSERT_LE(0, static_cast(fAccessor->size())); + ASSERT_EQ(ParquetVersion::PARQUET_2_6, fAccessor->version()); ASSERT_TRUE( - f_accessor->created_by().find(DEFAULT_CREATED_BY) != std::string::npos); - ASSERT_EQ(3, f_accessor->num_schema_elements()); + fAccessor->createdBy().find(DEFAULT_CREATED_BY) != std::string::npos); + ASSERT_EQ(3, fAccessor->numSchemaElements()); // Test AppendRowGroups from self (ARROW-13654) - f_accessor->AppendRowGroups(*f_accessor); - ASSERT_EQ(8, f_accessor->num_row_groups()); - ASSERT_EQ(nrows * 4, f_accessor->num_rows()); - ASSERT_EQ(3, f_accessor->num_schema_elements()); - - // Test Subset - auto f_accessor_1 = f_accessor->Subset({2, 3}); - ASSERT_TRUE(f_accessor_1->Equals(*f_accessor_2)); - - f_accessor_1 = f_accessor_2->Subset({0}); - f_accessor_1->AppendRowGroups(*f_accessor->Subset({0})); - ASSERT_TRUE(f_accessor_1->Equals(*f_accessor->Subset({2, 0}))); + fAccessor->appendRowGroups(*fAccessor); + ASSERT_EQ(8, fAccessor->numRowGroups()); + ASSERT_EQ(nrows * 4, fAccessor->numRows()); + ASSERT_EQ(3, fAccessor->numSchemaElements()); + + // Test Subset. + auto fAccessor1 = fAccessor->subset({2, 3}); + ASSERT_TRUE(fAccessor1->equals(*fAccessor2)); + + fAccessor1 = fAccessor2->subset({0}); + fAccessor1->appendRowGroups(*fAccessor->subset({0})); + ASSERT_TRUE(fAccessor1->equals(*fAccessor->subset({2, 0}))); } TEST(Metadata, TestV1Version) { - // PARQUET-839 + // PARQUET-839. schema::NodeVector fields; schema::NodePtr root; SchemaDescriptor schema; - WriterProperties::Builder prop_builder; + WriterProperties::Builder propBuilder; std::shared_ptr props = - prop_builder.version(ParquetVersion::PARQUET_1_0)->build(); + propBuilder.version(ParquetVersion::PARQUET_1_0)->build(); - fields.push_back(schema::Int32("int_col", Repetition::REQUIRED)); - fields.push_back(schema::Float("float_col", Repetition::REQUIRED)); - root = schema::GroupNode::Make("schema", Repetition::REPEATED, fields); - schema.Init(root); + fields.push_back(schema::int32("int_col", Repetition::kRequired)); + fields.push_back(schema::floatType("float_col", Repetition::kRequired)); + root = schema::GroupNode::make("schema", Repetition::kRepeated, fields); + schema.init(root); - auto f_builder = FileMetaDataBuilder::Make(&schema, props); + auto fBuilder = FileMetaDataBuilder::make(&schema, props); - // Read the metadata - auto f_accessor = f_builder->Finish(); + // Read the metadata. + auto fAccessor = fBuilder->finish(); - // file metadata - ASSERT_EQ(ParquetVersion::PARQUET_1_0, f_accessor->version()); + // File metadata. + ASSERT_EQ(ParquetVersion::PARQUET_1_0, fAccessor->version()); } TEST(Metadata, TestKeyValueMetadata) { @@ -350,69 +354,72 @@ TEST(Metadata, TestKeyValueMetadata) { schema::NodePtr root; SchemaDescriptor schema; - WriterProperties::Builder prop_builder; + WriterProperties::Builder propBuilder; std::shared_ptr props = - prop_builder.version(ParquetVersion::PARQUET_1_0)->build(); + propBuilder.version(ParquetVersion::PARQUET_1_0)->build(); - fields.push_back(schema::Int32("int_col", Repetition::REQUIRED)); - fields.push_back(schema::Float("float_col", Repetition::REQUIRED)); - root = schema::GroupNode::Make("schema", Repetition::REPEATED, fields); - schema.Init(root); + fields.push_back(schema::int32("int_col", Repetition::kRequired)); + fields.push_back(schema::floatType("float_col", Repetition::kRequired)); + root = schema::GroupNode::make("schema", Repetition::kRepeated, fields); + schema.init(root); auto kvmeta = std::make_shared(); kvmeta->Append("test_key", "test_value"); - auto f_builder = FileMetaDataBuilder::Make(&schema, props); + auto fBuilder = FileMetaDataBuilder::make(&schema, props); - // Read the metadata - auto f_accessor = f_builder->Finish(kvmeta); + // Read the metadata. + auto fAccessor = fBuilder->finish(kvmeta); - // Key value metadata - ASSERT_TRUE(f_accessor->key_value_metadata()); - EXPECT_TRUE(f_accessor->key_value_metadata()->Equals(*kvmeta)); + // Key value metadata. + ASSERT_TRUE(fAccessor->keyValueMetadata()); + EXPECT_TRUE(fAccessor->keyValueMetadata()->Equals(*kvmeta)); } TEST(Metadata, TestAddKeyValueMetadata) { schema::NodeVector fields; - fields.push_back(schema::Int32("int_col", Repetition::REQUIRED)); + fields.push_back(schema::int32("int_col", Repetition::kRequired)); auto schema = std::static_pointer_cast( - schema::GroupNode::Make("schema", Repetition::REQUIRED, fields)); + schema::GroupNode::make("schema", Repetition::kRequired, fields)); - auto kv_meta = std::make_shared(); - kv_meta->Append("test_key_1", "test_value_1"); - kv_meta->Append("test_key_2", "test_value_2_"); + auto kvMeta = std::make_shared(); + kvMeta->Append("test_key_1", "test_value_1"); + kvMeta->Append("test_key_2", "test_value_2_"); - auto sink = CreateOutputStream(); - auto writer_props = WriterProperties::Builder().disable_dictionary()->build(); - auto file_writer = - ParquetFileWriter::Open(sink, schema, writer_props, kv_meta); + auto sink = createOutputStream(); + auto writerProps = WriterProperties::Builder().disableDictionary()->build(); + auto fileWriter = ParquetFileWriter::open(sink, schema, writerProps, kvMeta); // Key value metadata that will be added to the file. - auto kv_meta_added = std::make_shared(); - kv_meta_added->Append("test_key_2", "test_value_2"); - kv_meta_added->Append("test_key_3", "test_value_3"); + auto kvMetaAdded = std::make_shared(); + kvMetaAdded->Append("test_key_2", "test_value_2"); + kvMetaAdded->Append("test_key_3", "test_value_3"); - file_writer->AddKeyValueMetadata(kv_meta_added); - file_writer->Close(); + fileWriter->addKeyValueMetadata(kvMetaAdded); + fileWriter->close(); // Throw if appending key value metadata to closed file. - auto kv_meta_ignored = std::make_shared(); - kv_meta_ignored->Append("test_key_4", "test_value_4"); + auto kvMetaIgnored = std::make_shared(); + kvMetaIgnored->Append("test_key_4", "test_value_4"); EXPECT_THROW( - file_writer->AddKeyValueMetadata(kv_meta_ignored), ParquetException); + fileWriter->addKeyValueMetadata(kvMetaIgnored), ParquetException); PARQUET_ASSIGN_OR_THROW(auto buffer, sink->Finish()); - // Write the buffer to a temp file path - auto filePath = exec::test::TempFilePath::create(); + // Write the buffer to a temp file path. + auto filePath = TempFilePath::create(); writeToFile(filePath, buffer); memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); std::shared_ptr rootPool = memory::memoryManager()->addRootPool("MetadataTest"); std::shared_ptr leafPool = rootPool->addLeafChild("MetadataTest"); - dwio::common::ReaderOptions readerOptions{leafPool.get()}; + auto dataIoStats = std::make_shared(); + auto metadataIoStats = std::make_shared(); + dwio::common::ReaderOptions readerOptions(leafPool.get()); + readerOptions.setDataIoStats(dataIoStats); + readerOptions.setMetadataIoStats(metadataIoStats); auto input = std::make_unique( std::make_shared(filePath->getPath()), readerOptions.memoryPool()); @@ -470,7 +477,7 @@ TEST(Metadata, TestReadPageIndex) { auto col_chunk_metadata = row_group_metadata->ColumnChunk(i); auto ci_location = col_chunk_metadata->GetColumnIndexLocation(); if (i == 10) { - // column_id 10 does not have column index + // Column_id 10 does not have column index. ASSERT_FALSE(ci_location.has_value()); } else { ASSERT_TRUE(ci_location.has_value()); @@ -490,48 +497,52 @@ TEST(Metadata, TestReadPageIndex) { TEST(Metadata, TestSortingColumns) { schema::NodeVector fields; - fields.push_back(schema::Int32("sort_col", Repetition::REQUIRED)); - fields.push_back(schema::Int32("int_col", Repetition::REQUIRED)); + fields.push_back(schema::int32("sort_col", Repetition::kRequired)); + fields.push_back(schema::int32("int_col", Repetition::kRequired)); auto schema = std::static_pointer_cast( - schema::GroupNode::Make("schema", Repetition::REQUIRED, fields)); + schema::GroupNode::make("schema", Repetition::kRequired, fields)); std::vector sortingColumns; { SortingColumn sortingColumn; - sortingColumn.column_idx = 0; + sortingColumn.columnIdx = 0; sortingColumn.descending = false; - sortingColumn.nulls_first = false; + sortingColumn.nullsFirst = false; sortingColumns.push_back(sortingColumn); } auto createdBy = CREATED_BY_VERSION + std::string(" version 1.0"); - auto sink = CreateOutputStream(); + auto sink = createOutputStream(); auto writerProps = WriterProperties::Builder() - .disable_dictionary() - ->set_sorting_columns(sortingColumns) - ->created_by(createdBy) + .disableDictionary() + ->setSortingColumns(sortingColumns) + ->createdBy(createdBy) ->build(); - EXPECT_EQ(sortingColumns, writerProps->sorting_columns()); + EXPECT_EQ(sortingColumns, writerProps->sortingColumns()); - auto fileWriter = ParquetFileWriter::Open(sink, schema, writerProps); + auto fileWriter = ParquetFileWriter::open(sink, schema, writerProps); - auto rowGroupWriter = fileWriter->AppendBufferedRowGroup(); - rowGroupWriter->Close(); - fileWriter->Close(); + auto rowGroupWriter = fileWriter->appendBufferedRowGroup(); + rowGroupWriter->close(); + fileWriter->close(); PARQUET_ASSIGN_OR_THROW(auto buffer, sink->Finish()); - // Write the buffer to a temp file path - auto filePath = exec::test::TempFilePath::create(); + // Write the buffer to a temp file path. + auto filePath = TempFilePath::create(); writeToFile(filePath, buffer); memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); std::shared_ptr rootPool = memory::memoryManager()->addRootPool("MetadataTest"); std::shared_ptr leafPool = rootPool->addLeafChild("MetadataTest"); - dwio::common::ReaderOptions readerOptions{leafPool.get()}; + auto dataIoStats = std::make_shared(); + auto metadataIoStats = std::make_shared(); + dwio::common::ReaderOptions readerOptions(leafPool.get()); + readerOptions.setDataIoStats(dataIoStats); + readerOptions.setMetadataIoStats(metadataIoStats); auto input = std::make_unique( std::make_shared(filePath->getPath()), readerOptions.memoryPool()); @@ -539,9 +550,9 @@ TEST(Metadata, TestSortingColumns) { std::make_unique(std::move(input), readerOptions); ASSERT_EQ(1, reader->fileMetaData().numRowGroups()); auto rowGroup = reader->fileMetaData().rowGroup(0); - EXPECT_EQ(sortingColumns[0].column_idx, rowGroup.sortingColumnIdx(0)); + EXPECT_EQ(sortingColumns[0].columnIdx, rowGroup.sortingColumnIdx(0)); EXPECT_EQ(sortingColumns[0].descending, rowGroup.sortingColumnDescending(0)); - EXPECT_EQ(sortingColumns[0].nulls_first, rowGroup.sortingColumnNullsFirst(0)); + EXPECT_EQ(sortingColumns[0].nullsFirst, rowGroup.sortingColumnNullsFirst(0)); ASSERT_EQ(createdBy, reader->fileMetaData().createdBy()); } @@ -570,53 +581,53 @@ TEST(ApplicationVersion, Basics) { ASSERT_EQ(5, version4.version.minor); ASSERT_EQ(0, version4.version.patch); ASSERT_EQ("ab", version4.version.unknown); - ASSERT_EQ("cdh5.5.0", version4.version.pre_release); - ASSERT_EQ("cd", version4.version.build_info); + ASSERT_EQ("cdh5.5.0", version4.version.preRelease); + ASSERT_EQ("cd", version4.version.buildInfo); ASSERT_EQ("parquet-mr", version5.application_); ASSERT_EQ(0, version5.version.major); ASSERT_EQ(0, version5.version.minor); ASSERT_EQ(0, version5.version.patch); - ASSERT_EQ(true, version.VersionLt(version1)); + ASSERT_EQ(true, version.versionLt(version1)); EncodedStatistics stats; ASSERT_FALSE( - version1.HasCorrectStatistics(Type::INT96, stats, SortOrder::UNKNOWN)); + version1.hasCorrectStatistics(Type::kInt96, stats, SortOrder::kUnknown)); ASSERT_TRUE( - version.HasCorrectStatistics(Type::INT32, stats, SortOrder::SIGNED)); - ASSERT_FALSE( - version.HasCorrectStatistics(Type::BYTE_ARRAY, stats, SortOrder::SIGNED)); - ASSERT_TRUE(version1.HasCorrectStatistics( - Type::BYTE_ARRAY, stats, SortOrder::SIGNED)); - ASSERT_FALSE(version1.HasCorrectStatistics( - Type::BYTE_ARRAY, stats, SortOrder::UNSIGNED)); - ASSERT_TRUE(version3.HasCorrectStatistics( - Type::FIXED_LEN_BYTE_ARRAY, stats, SortOrder::SIGNED)); - - // Check that the old stats are correct if min and max are the same - // regardless of sort order - EncodedStatistics stats_str; - stats_str.set_min("a").set_max("b"); - ASSERT_FALSE(version1.HasCorrectStatistics( - Type::BYTE_ARRAY, stats_str, SortOrder::UNSIGNED)); - stats_str.set_max("a"); - ASSERT_TRUE(version1.HasCorrectStatistics( - Type::BYTE_ARRAY, stats_str, SortOrder::UNSIGNED)); - - // Check that the same holds true for ints - int32_t int_min = 100, int_max = 200; - EncodedStatistics stats_int; - stats_int.set_min(std::string(reinterpret_cast(&int_min), 4)) - .set_max(std::string(reinterpret_cast(&int_max), 4)); - ASSERT_FALSE(version1.HasCorrectStatistics( - Type::BYTE_ARRAY, stats_int, SortOrder::UNSIGNED)); - stats_int.set_max(std::string(reinterpret_cast(&int_min), 4)); - ASSERT_TRUE(version1.HasCorrectStatistics( - Type::BYTE_ARRAY, stats_int, SortOrder::UNSIGNED)); + version.hasCorrectStatistics(Type::kInt32, stats, SortOrder::kSigned)); + ASSERT_FALSE(version.hasCorrectStatistics( + Type::kByteArray, stats, SortOrder::kSigned)); + ASSERT_TRUE(version1.hasCorrectStatistics( + Type::kByteArray, stats, SortOrder::kSigned)); + ASSERT_FALSE(version1.hasCorrectStatistics( + Type::kByteArray, stats, SortOrder::kUnsigned)); + ASSERT_TRUE(version3.hasCorrectStatistics( + Type::kFixedLenByteArray, stats, SortOrder::kSigned)); + + // Check that the old stats are correct if min and max are the same. + // Regardless of sort order. + EncodedStatistics statsStr; + statsStr.setMin("a").setMax("b"); + ASSERT_FALSE(version1.hasCorrectStatistics( + Type::kByteArray, statsStr, SortOrder::kUnsigned)); + statsStr.setMax("a"); + ASSERT_TRUE(version1.hasCorrectStatistics( + Type::kByteArray, statsStr, SortOrder::kUnsigned)); + + // Check that the same holds true for ints. + int32_t intMin = 100, intMax = 200; + EncodedStatistics statsInt; + statsInt.setMin(std::string(reinterpret_cast(&intMin), 4)) + .setMax(std::string(reinterpret_cast(&intMax), 4)); + ASSERT_FALSE(version1.hasCorrectStatistics( + Type::kByteArray, statsInt, SortOrder::kUnsigned)); + statsInt.setMax(std::string(reinterpret_cast(&intMin), 4)); + ASSERT_TRUE(version1.hasCorrectStatistics( + Type::kByteArray, statsInt, SortOrder::kUnsigned)); } -TEST(ApplicationVersion, Empty) { +TEST(ApplicationVersion, empty) { ApplicationVersion version(""); ASSERT_EQ("", version.application_); @@ -625,8 +636,8 @@ TEST(ApplicationVersion, Empty) { ASSERT_EQ(0, version.version.minor); ASSERT_EQ(0, version.version.patch); ASSERT_EQ("", version.version.unknown); - ASSERT_EQ("", version.version.pre_release); - ASSERT_EQ("", version.version.build_info); + ASSERT_EQ("", version.version.preRelease); + ASSERT_EQ("", version.version.buildInfo); } TEST(ApplicationVersion, NoVersion) { @@ -638,8 +649,8 @@ TEST(ApplicationVersion, NoVersion) { ASSERT_EQ(0, version.version.minor); ASSERT_EQ(0, version.version.patch); ASSERT_EQ("", version.version.unknown); - ASSERT_EQ("", version.version.pre_release); - ASSERT_EQ("", version.version.build_info); + ASSERT_EQ("", version.version.preRelease); + ASSERT_EQ("", version.version.buildInfo); } TEST(ApplicationVersion, VersionEmpty) { @@ -651,8 +662,8 @@ TEST(ApplicationVersion, VersionEmpty) { ASSERT_EQ(0, version.version.minor); ASSERT_EQ(0, version.version.patch); ASSERT_EQ("", version.version.unknown); - ASSERT_EQ("", version.version.pre_release); - ASSERT_EQ("", version.version.build_info); + ASSERT_EQ("", version.version.preRelease); + ASSERT_EQ("", version.version.buildInfo); } TEST(ApplicationVersion, VersionNoMajor) { @@ -664,8 +675,8 @@ TEST(ApplicationVersion, VersionNoMajor) { ASSERT_EQ(0, version.version.minor); ASSERT_EQ(0, version.version.patch); ASSERT_EQ("", version.version.unknown); - ASSERT_EQ("", version.version.pre_release); - ASSERT_EQ("", version.version.build_info); + ASSERT_EQ("", version.version.preRelease); + ASSERT_EQ("", version.version.buildInfo); } TEST(ApplicationVersion, VersionInvalidMajor) { @@ -677,8 +688,8 @@ TEST(ApplicationVersion, VersionInvalidMajor) { ASSERT_EQ(0, version.version.minor); ASSERT_EQ(0, version.version.patch); ASSERT_EQ("", version.version.unknown); - ASSERT_EQ("", version.version.pre_release); - ASSERT_EQ("", version.version.build_info); + ASSERT_EQ("", version.version.preRelease); + ASSERT_EQ("", version.version.buildInfo); } TEST(ApplicationVersion, VersionMajorOnly) { @@ -690,8 +701,8 @@ TEST(ApplicationVersion, VersionMajorOnly) { ASSERT_EQ(0, version.version.minor); ASSERT_EQ(0, version.version.patch); ASSERT_EQ("", version.version.unknown); - ASSERT_EQ("", version.version.pre_release); - ASSERT_EQ("", version.version.build_info); + ASSERT_EQ("", version.version.preRelease); + ASSERT_EQ("", version.version.buildInfo); } TEST(ApplicationVersion, VersionNoMinor) { @@ -703,8 +714,8 @@ TEST(ApplicationVersion, VersionNoMinor) { ASSERT_EQ(0, version.version.minor); ASSERT_EQ(0, version.version.patch); ASSERT_EQ("", version.version.unknown); - ASSERT_EQ("", version.version.pre_release); - ASSERT_EQ("", version.version.build_info); + ASSERT_EQ("", version.version.preRelease); + ASSERT_EQ("", version.version.buildInfo); } TEST(ApplicationVersion, VersionMajorMinorOnly) { @@ -716,8 +727,8 @@ TEST(ApplicationVersion, VersionMajorMinorOnly) { ASSERT_EQ(7, version.version.minor); ASSERT_EQ(0, version.version.patch); ASSERT_EQ("", version.version.unknown); - ASSERT_EQ("", version.version.pre_release); - ASSERT_EQ("", version.version.build_info); + ASSERT_EQ("", version.version.preRelease); + ASSERT_EQ("", version.version.buildInfo); } TEST(ApplicationVersion, VersionInvalidMinor) { @@ -729,8 +740,8 @@ TEST(ApplicationVersion, VersionInvalidMinor) { ASSERT_EQ(0, version.version.minor); ASSERT_EQ(0, version.version.patch); ASSERT_EQ("", version.version.unknown); - ASSERT_EQ("", version.version.pre_release); - ASSERT_EQ("", version.version.build_info); + ASSERT_EQ("", version.version.preRelease); + ASSERT_EQ("", version.version.buildInfo); } TEST(ApplicationVersion, VersionNoPatch) { @@ -742,8 +753,8 @@ TEST(ApplicationVersion, VersionNoPatch) { ASSERT_EQ(7, version.version.minor); ASSERT_EQ(0, version.version.patch); ASSERT_EQ("", version.version.unknown); - ASSERT_EQ("", version.version.pre_release); - ASSERT_EQ("", version.version.build_info); + ASSERT_EQ("", version.version.preRelease); + ASSERT_EQ("", version.version.buildInfo); } TEST(ApplicationVersion, VersionInvalidPatch) { @@ -755,8 +766,8 @@ TEST(ApplicationVersion, VersionInvalidPatch) { ASSERT_EQ(7, version.version.minor); ASSERT_EQ(0, version.version.patch); ASSERT_EQ("", version.version.unknown); - ASSERT_EQ("", version.version.pre_release); - ASSERT_EQ("", version.version.build_info); + ASSERT_EQ("", version.version.preRelease); + ASSERT_EQ("", version.version.buildInfo); } TEST(ApplicationVersion, VersionNoUnknown) { @@ -768,8 +779,8 @@ TEST(ApplicationVersion, VersionNoUnknown) { ASSERT_EQ(7, version.version.minor); ASSERT_EQ(9, version.version.patch); ASSERT_EQ("", version.version.unknown); - ASSERT_EQ("cdh5.5.0", version.version.pre_release); - ASSERT_EQ("cd", version.version.build_info); + ASSERT_EQ("cdh5.5.0", version.version.preRelease); + ASSERT_EQ("cd", version.version.buildInfo); } TEST(ApplicationVersion, VersionNoPreRelease) { @@ -781,8 +792,8 @@ TEST(ApplicationVersion, VersionNoPreRelease) { ASSERT_EQ(7, version.version.minor); ASSERT_EQ(9, version.version.patch); ASSERT_EQ("ab", version.version.unknown); - ASSERT_EQ("", version.version.pre_release); - ASSERT_EQ("cd", version.version.build_info); + ASSERT_EQ("", version.version.preRelease); + ASSERT_EQ("cd", version.version.buildInfo); } TEST(ApplicationVersion, VersionNoUnknownNoPreRelease) { @@ -794,8 +805,8 @@ TEST(ApplicationVersion, VersionNoUnknownNoPreRelease) { ASSERT_EQ(7, version.version.minor); ASSERT_EQ(9, version.version.patch); ASSERT_EQ("", version.version.unknown); - ASSERT_EQ("", version.version.pre_release); - ASSERT_EQ("cd", version.version.build_info); + ASSERT_EQ("", version.version.preRelease); + ASSERT_EQ("cd", version.version.buildInfo); } TEST(ApplicationVersion, VersionNoUnknownBuildInfoPreRelease) { @@ -807,8 +818,8 @@ TEST(ApplicationVersion, VersionNoUnknownBuildInfoPreRelease) { ASSERT_EQ(7, version.version.minor); ASSERT_EQ(9, version.version.patch); ASSERT_EQ("", version.version.unknown); - ASSERT_EQ("", version.version.pre_release); - ASSERT_EQ("cd-cdh5.5.0", version.version.build_info); + ASSERT_EQ("", version.version.preRelease); + ASSERT_EQ("cd-cdh5.5.0", version.version.buildInfo); } TEST(ApplicationVersion, FullWithSpaces) { @@ -821,8 +832,8 @@ TEST(ApplicationVersion, FullWithSpaces) { ASSERT_EQ(5, version.version.minor); ASSERT_EQ(3, version.version.patch); ASSERT_EQ("ab", version.version.unknown); - ASSERT_EQ("cdh5.5.0", version.version.pre_release); - ASSERT_EQ("cd", version.version.build_info); + ASSERT_EQ("cdh5.5.0", version.version.preRelease); + ASSERT_EQ("cd", version.version.buildInfo); } } // namespace metadata diff --git a/velox/dwio/parquet/writer/arrow/tests/PageIndexTest.cpp b/velox/dwio/parquet/writer/arrow/tests/PageIndexTest.cpp index fd175e2a821..1d4797bb4dd 100644 --- a/velox/dwio/parquet/writer/arrow/tests/PageIndexTest.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/PageIndexTest.cpp @@ -26,117 +26,115 @@ namespace facebook::velox::parquet::arrow { struct PageIndexRanges { - int64_t column_index_offset; - int64_t column_index_length; - int64_t offset_index_offset; - int64_t offset_index_length; + int64_t columnIndexOffset; + int64_t columnIndexLength; + int64_t offsetIndexOffset; + int64_t offsetIndexLength; }; using RowGroupRanges = std::vector; -/// Creates an FileMetaData object w/ single row group based on data in -/// 'row_group_ranges'. It sets the offsets and sizes of the column index and -/// offset index members of the row group. It doesn't set the member if the -/// input value is -1. -std::shared_ptr ConstructFakeMetaData( - const RowGroupRanges& row_group_ranges) { - facebook::velox::parquet::thrift::RowGroup row_group; - for (auto& page_index_ranges : row_group_ranges) { - facebook::velox::parquet::thrift::ColumnChunk col_chunk; - if (page_index_ranges.column_index_offset != -1) { - col_chunk.__set_column_index_offset( - page_index_ranges.column_index_offset); +/// Creates an FileMetaData object w/ single row group based on data in. +/// 'Row_group_ranges'. It sets the offsets and sizes of the column index and. +/// Offset index members of the row group. It doesn't set the member if the. +/// Input value is -1. +std::shared_ptr constructFakeMetaData( + const RowGroupRanges& rowGroupRanges) { + facebook::velox::parquet::thrift::RowGroup rowGroup; + for (auto& pageIndexRanges : rowGroupRanges) { + facebook::velox::parquet::thrift::ColumnChunk colChunk; + if (pageIndexRanges.columnIndexOffset != -1) { + colChunk.__set_column_index_offset(pageIndexRanges.columnIndexOffset); } - if (page_index_ranges.column_index_length != -1) { - col_chunk.__set_column_index_length( - static_cast(page_index_ranges.column_index_length)); + if (pageIndexRanges.columnIndexLength != -1) { + colChunk.__set_column_index_length( + static_cast(pageIndexRanges.columnIndexLength)); } - if (page_index_ranges.offset_index_offset != -1) { - col_chunk.__set_offset_index_offset( - page_index_ranges.offset_index_offset); + if (pageIndexRanges.offsetIndexOffset != -1) { + colChunk.__set_offset_index_offset(pageIndexRanges.offsetIndexOffset); } - if (page_index_ranges.offset_index_length != -1) { - col_chunk.__set_offset_index_length( - static_cast(page_index_ranges.offset_index_length)); + if (pageIndexRanges.offsetIndexLength != -1) { + colChunk.__set_offset_index_length( + static_cast(pageIndexRanges.offsetIndexLength)); } - row_group.columns.push_back(col_chunk); + rowGroup.columns.push_back(colChunk); } facebook::velox::parquet::thrift::FileMetaData metadata; - metadata.row_groups.push_back(row_group); + metadata.row_groups.push_back(rowGroup); metadata.schema.emplace_back(); schema::NodeVector fields; - for (size_t i = 0; i < row_group_ranges.size(); ++i) { - fields.push_back(schema::Int64(std::to_string(i))); + for (size_t i = 0; i < rowGroupRanges.size(); ++i) { + fields.push_back(schema::int64(std::to_string(i))); metadata.schema.emplace_back(); - fields.back()->ToParquet(&metadata.schema.back()); + fields.back()->toParquet(&metadata.schema.back()); } - schema::GroupNode::Make("schema", Repetition::REPEATED, fields) - ->ToParquet(&metadata.schema.front()); + schema::GroupNode::make("schema", Repetition::kRepeated, fields) + ->toParquet(&metadata.schema.front()); - auto sink = CreateOutputStream(); - ThriftSerializer{}.Serialize(&metadata, sink.get()); + auto sink = createOutputStream(); + ThriftSerializer{}.serialize(&metadata, sink.get()); auto buffer = sink->Finish().MoveValueUnsafe(); uint32_t len = static_cast(buffer->size()); - return FileMetaData::Make(buffer->data(), &len); + return FileMetaData::make(buffer->data(), &len); } -/// Validates that 'DeterminePageIndexRangesInRowGroup()' selects the expected -/// file offsets and sizes or returns false when the row group doesn't have a -/// page index. -void ValidatePageIndexRange( - const RowGroupRanges& row_group_ranges, - const std::vector& column_indices, - bool expected_has_column_index, - bool expected_has_offset_index, - int expected_ci_start, - int expected_ci_size, - int expected_oi_start, - int expected_oi_size) { - auto file_metadata = ConstructFakeMetaData(row_group_ranges); - auto read_range = PageIndexReader::DeterminePageIndexRangesInRowGroup( - *file_metadata->RowGroup(0), column_indices); - ASSERT_EQ(expected_has_column_index, read_range.column_index.has_value()); - ASSERT_EQ(expected_has_offset_index, read_range.offset_index.has_value()); - if (expected_has_column_index) { - EXPECT_EQ(expected_ci_start, read_range.column_index->offset); - EXPECT_EQ(expected_ci_size, read_range.column_index->length); +/// Validates that 'DeterminePageIndexRangesInRowGroup()' selects the expected. +/// File offsets and sizes or returns false when the row group doesn't have a. +/// Page index. +void validatePageIndexRange( + const RowGroupRanges& rowGroupRanges, + const std::vector& columnIndices, + bool expectedHasColumnIndex, + bool expectedHasOffsetIndex, + int expectedCiStart, + int expectedCiSize, + int expectedOiStart, + int expectedOiSize) { + auto fileMetadata = constructFakeMetaData(rowGroupRanges); + auto readRange = PageIndexReader::determinePageIndexRangesInRowGroup( + *fileMetadata->rowGroup(0), columnIndices); + ASSERT_EQ(expectedHasColumnIndex, readRange.columnIndex.has_value()); + ASSERT_EQ(expectedHasOffsetIndex, readRange.offsetIndex.has_value()); + if (expectedHasColumnIndex) { + EXPECT_EQ(expectedCiStart, readRange.columnIndex->offset); + EXPECT_EQ(expectedCiSize, readRange.columnIndex->length); } - if (expected_has_offset_index) { - EXPECT_EQ(expected_oi_start, read_range.offset_index->offset); - EXPECT_EQ(expected_oi_size, read_range.offset_index->length); + if (expectedHasOffsetIndex) { + EXPECT_EQ(expectedOiStart, readRange.offsetIndex->offset); + EXPECT_EQ(expectedOiSize, readRange.offsetIndex->length); } } -/// This test constructs a couple of artificial row groups with page index -/// offsets in them. Then it validates if -/// PageIndexReader::DeterminePageIndexRangesInRowGroup() properly computes the -/// file range that contains the whole page index. -TEST(PageIndex, DeterminePageIndexRangesInRowGroup) { - // No Column chunks - ValidatePageIndexRange({}, {}, false, false, -1, -1, -1, -1); +/// This test constructs a couple of artificial row groups with page index. +/// Offsets in them. Then it validates if. +/// PageIndexReader::DeterminePageIndexRangesInRowGroup() properly computes the. +/// File range that contains the whole page index. +TEST(PageIndex, determinePageIndexRangesInRowGroup) { + // No Column chunks. + validatePageIndexRange({}, {}, false, false, -1, -1, -1, -1); // No page index at all. - ValidatePageIndexRange({{-1, -1, -1, -1}}, {}, false, false, -1, -1, -1, -1); + validatePageIndexRange({{-1, -1, -1, -1}}, {}, false, false, -1, -1, -1, -1); // Page index for single column chunk. - ValidatePageIndexRange({{10, 5, 15, 5}}, {}, true, true, 10, 5, 15, 5); + validatePageIndexRange({{10, 5, 15, 5}}, {}, true, true, 10, 5, 15, 5); // Page index for two column chunks. - ValidatePageIndexRange( + validatePageIndexRange( {{10, 5, 30, 25}, {15, 15, 50, 20}}, {}, true, true, 10, 20, 30, 40); // Page index for second column chunk. - ValidatePageIndexRange( + validatePageIndexRange( {{-1, -1, -1, -1}, {20, 10, 30, 25}}, {}, true, true, 20, 10, 30, 25); // Page index for first column chunk. - ValidatePageIndexRange( + validatePageIndexRange( {{10, 5, 15, 5}, {-1, -1, -1, -1}}, {}, true, true, 10, 5, 15, 5); // Missing offset index for first column chunk. Gap in column index. - ValidatePageIndexRange( + validatePageIndexRange( {{10, 5, -1, -1}, {20, 10, 30, 25}}, {}, true, true, 10, 20, 30, 25); // Missing offset index for second column chunk. - ValidatePageIndexRange( + validatePageIndexRange( {{10, 5, 25, 5}, {20, 10, -1, -1}}, {}, true, true, 10, 20, 25, 5); // Four column chunks. - ValidatePageIndexRange( + validatePageIndexRange( {{100, 10, 220, 30}, {110, 25, 250, 10}, {140, 30, 260, 40}, @@ -150,23 +148,23 @@ TEST(PageIndex, DeterminePageIndexRangesInRowGroup) { 180); } -/// This test constructs a couple of artificial row groups with page index -/// offsets in them. Then it validates if -/// PageIndexReader::DeterminePageIndexRangesInRowGroup() properly computes the -/// file range that contains the page index of selected columns. +/// This test constructs a couple of artificial row groups with page index. +/// Offsets in them. Then it validates if. +/// PageIndexReader::DeterminePageIndexRangesInRowGroup() properly computes the. +/// File range that contains the page index of selected columns. TEST(PageIndex, DeterminePageIndexRangesInRowGroupWithPartialColumnsSelected) { // No page index at all. - ValidatePageIndexRange({{-1, -1, -1, -1}}, {0}, false, false, -1, -1, -1, -1); + validatePageIndexRange({{-1, -1, -1, -1}}, {0}, false, false, -1, -1, -1, -1); // Page index for single column chunk. - ValidatePageIndexRange({{10, 5, 15, 5}}, {0}, true, true, 10, 5, 15, 5); + validatePageIndexRange({{10, 5, 15, 5}}, {0}, true, true, 10, 5, 15, 5); // Page index for the 1st column chunk. - ValidatePageIndexRange( + validatePageIndexRange( {{10, 5, 30, 25}, {15, 15, 50, 20}}, {0}, true, true, 10, 5, 30, 25); // Page index for the 2nd column chunk. - ValidatePageIndexRange( + validatePageIndexRange( {{10, 5, 30, 25}, {15, 15, 50, 20}}, {1}, true, true, 15, 15, 50, 20); // Only 2nd column is selected among four column chunks. - ValidatePageIndexRange( + validatePageIndexRange( {{100, 10, 220, 30}, {110, 25, 250, 10}, {140, 30, 260, 40}, @@ -179,7 +177,7 @@ TEST(PageIndex, DeterminePageIndexRangesInRowGroupWithPartialColumnsSelected) { 250, 10); // Only 2nd and 3rd columns are selected among four column chunks. - ValidatePageIndexRange( + validatePageIndexRange( {{100, 10, 220, 30}, {110, 25, 250, 10}, {140, 30, 260, 40}, @@ -192,7 +190,7 @@ TEST(PageIndex, DeterminePageIndexRangesInRowGroupWithPartialColumnsSelected) { 250, 50); // Only 2nd and 4th columns are selected among four column chunks. - ValidatePageIndexRange( + validatePageIndexRange( {{100, 10, 220, 30}, {110, 25, 250, 10}, {140, 30, 260, 40}, @@ -205,7 +203,7 @@ TEST(PageIndex, DeterminePageIndexRangesInRowGroupWithPartialColumnsSelected) { 250, 150); // Only 1st, 2nd and 4th columns are selected among four column chunks. - ValidatePageIndexRange( + validatePageIndexRange( {{100, 10, 220, 30}, {110, 25, 250, 10}, {140, 30, 260, 40}, @@ -217,9 +215,9 @@ TEST(PageIndex, DeterminePageIndexRangesInRowGroupWithPartialColumnsSelected) { 110, 220, 180); - // 3rd column is selected but not present in the row group. + // 3Rd column is selected but not present in the row group. EXPECT_THROW( - ValidatePageIndexRange( + validatePageIndexRange( {{10, 5, 30, 25}, {15, 15, 50, 20}}, {2}, false, @@ -231,101 +229,101 @@ TEST(PageIndex, DeterminePageIndexRangesInRowGroupWithPartialColumnsSelected) { ParquetException); } -/// This test constructs a couple of artificial row groups with page index -/// offsets in them. Then it validates if -/// PageIndexReader::DeterminePageIndexRangesInRowGroup() properly detects if -/// column index or offset index is missing. +/// This test constructs a couple of artificial row groups with page index. +/// Offsets in them. Then it validates if. +/// PageIndexReader::DeterminePageIndexRangesInRowGroup() properly detects if. +/// Column index or offset index is missing. TEST(PageIndex, DeterminePageIndexRangesInRowGroupWithMissingPageIndex) { // No column index at all. - ValidatePageIndexRange({{-1, -1, 15, 5}}, {}, false, true, -1, -1, 15, 5); + validatePageIndexRange({{-1, -1, 15, 5}}, {}, false, true, -1, -1, 15, 5); // No offset index at all. - ValidatePageIndexRange({{10, 5, -1, -1}}, {}, true, false, 10, 5, -1, -1); + validatePageIndexRange({{10, 5, -1, -1}}, {}, true, false, 10, 5, -1, -1); // No column index at all among two column chunks. - ValidatePageIndexRange( + validatePageIndexRange( {{-1, -1, 30, 25}, {-1, -1, 50, 20}}, {}, false, true, -1, -1, 30, 40); // No offset index at all among two column chunks. - ValidatePageIndexRange( + validatePageIndexRange( {{10, 5, -1, -1}, {15, 15, -1, -1}}, {}, true, false, 10, 20, -1, -1); } TEST(PageIndex, WriteOffsetIndex) { /// Create offset index via the OffsetIndexBuilder interface. - auto builder = OffsetIndexBuilder::Make(); - const size_t num_pages = 5; + auto Builder = OffsetIndexBuilder::make(); + const size_t numPages = 5; const std::vector offsets = {100, 200, 300, 400, 500}; - const std::vector page_sizes = {1024, 2048, 3072, 4096, 8192}; - const std::vector first_row_indices = { - 0, 10000, 20000, 30000, 40000}; - for (size_t i = 0; i < num_pages; ++i) { - builder->AddPage(offsets[i], page_sizes[i], first_row_indices[i]); + const std::vector pageSizes = {1024, 2048, 3072, 4096, 8192}; + const std::vector firstRowIndices = {0, 10000, 20000, 30000, 40000}; + for (size_t i = 0; i < numPages; ++i) { + Builder->addPage(offsets[i], pageSizes[i], firstRowIndices[i]); } - const int64_t final_position = 4096; - builder->Finish(final_position); - - std::vector> offset_indexes; - /// 1st element is the offset index just built. - offset_indexes.emplace_back(builder->Build()); - /// 2nd element is the offset index restored by serialize-then-deserialize - /// round trip. - auto sink = CreateOutputStream(); - builder->WriteTo(sink.get()); + const int64_t finalPosition = 4096; + Builder->finish(finalPosition); + + std::vector> offsetIndexes; + /// 1St element is the offset index just built. + offsetIndexes.emplace_back(Builder->build()); + /// 2Nd element is the offset index restored by serialize-then-deserialize. + /// Round trip. + auto sink = createOutputStream(); + Builder->writeTo(sink.get()); PARQUET_ASSIGN_OR_THROW(auto buffer, sink->Finish()); - offset_indexes.emplace_back(OffsetIndex::Make( - buffer->data(), - static_cast(buffer->size()), - default_reader_properties())); + offsetIndexes.emplace_back( + OffsetIndex::make( + buffer->data(), + static_cast(buffer->size()), + defaultReaderProperties())); /// Verify the data of the offset index. - for (const auto& offset_index : offset_indexes) { - ASSERT_EQ(num_pages, offset_index->page_locations().size()); - for (size_t i = 0; i < num_pages; ++i) { - const auto& page_location = offset_index->page_locations().at(i); - ASSERT_EQ(offsets[i] + final_position, page_location.offset); - ASSERT_EQ(page_sizes[i], page_location.compressed_page_size); - ASSERT_EQ(first_row_indices[i], page_location.first_row_index); + for (const auto& offsetIndex : offsetIndexes) { + ASSERT_EQ(numPages, offsetIndex->pageLocations().size()); + for (size_t i = 0; i < numPages; ++i) { + const auto& pageLocation = offsetIndex->pageLocations().at(i); + ASSERT_EQ(offsets[i] + finalPosition, pageLocation.offset); + ASSERT_EQ(pageSizes[i], pageLocation.compressedPageSize); + ASSERT_EQ(firstRowIndices[i], pageLocation.firstRowIndex); } } } -void TestWriteTypedColumnIndex( - schema::NodePtr node, - const std::vector& page_stats, - BoundaryOrder::type boundary_order, - bool has_null_counts) { - auto descr = - std::make_unique(node, /*max_definition_level=*/1, 0); - - auto builder = ColumnIndexBuilder::Make(descr.get()); - for (const auto& stats : page_stats) { - builder->AddPage(stats); +void testWriteTypedColumnIndex( + schema::NodePtr Node, + const std::vector& pageStats, + BoundaryOrder::type boundaryOrder, + bool hasNullCounts) { + auto descr = std::make_unique(Node, 1, 0); + + auto Builder = ColumnIndexBuilder::make(descr.get()); + for (const auto& stats : pageStats) { + Builder->addPage(stats); } - ASSERT_NO_THROW(builder->Finish()); - - std::vector> column_indexes; - /// 1st element is the column index just built. - column_indexes.emplace_back(builder->Build()); - /// 2nd element is the column index restored by serialize-then-deserialize - /// round trip. - auto sink = CreateOutputStream(); - builder->WriteTo(sink.get()); + ASSERT_NO_THROW(Builder->finish()); + + std::vector> columnIndexes; + /// 1St element is the column index just built. + columnIndexes.emplace_back(Builder->build()); + /// 2Nd element is the column index restored by serialize-then-deserialize. + /// Round trip. + auto sink = createOutputStream(); + Builder->writeTo(sink.get()); PARQUET_ASSIGN_OR_THROW(auto buffer, sink->Finish()); - column_indexes.emplace_back(ColumnIndex::Make( - *descr, - buffer->data(), - static_cast(buffer->size()), - default_reader_properties())); + columnIndexes.emplace_back( + ColumnIndex::make( + *descr, + buffer->data(), + static_cast(buffer->size()), + defaultReaderProperties())); /// Verify the data of the column index. - for (const auto& column_index : column_indexes) { - ASSERT_EQ(boundary_order, column_index->boundary_order()); - ASSERT_EQ(has_null_counts, column_index->has_null_counts()); - const size_t num_pages = column_index->null_pages().size(); - for (size_t i = 0; i < num_pages; ++i) { - ASSERT_EQ(page_stats[i].all_null_value, column_index->null_pages()[i]); - ASSERT_EQ(page_stats[i].min(), column_index->encoded_min_values()[i]); - ASSERT_EQ(page_stats[i].max(), column_index->encoded_max_values()[i]); - if (has_null_counts) { - ASSERT_EQ(page_stats[i].null_count, column_index->null_counts()[i]); + for (const auto& columnIndex : columnIndexes) { + ASSERT_EQ(boundaryOrder, columnIndex->boundaryOrder()); + ASSERT_EQ(hasNullCounts, columnIndex->hasNullCounts()); + const size_t numPages = columnIndex->nullPages().size(); + for (size_t i = 0; i < numPages; ++i) { + ASSERT_EQ(pageStats[i].allNullValue, columnIndex->nullPages()[i]); + ASSERT_EQ(pageStats[i].min(), columnIndex->encodedMinValues()[i]); + ASSERT_EQ(pageStats[i].max(), columnIndex->encodedMaxValues()[i]); + if (hasNullCounts) { + ASSERT_EQ(pageStats[i].nullCount, columnIndex->nullCounts()[i]); } } } @@ -337,16 +335,13 @@ TEST(PageIndex, WriteInt32ColumnIndex) { }; // Integer values in the ascending order. - std::vector page_stats(3); - page_stats.at(0).set_null_count(1).set_min(encode(1)).set_max(encode(2)); - page_stats.at(1).set_null_count(2).set_min(encode(2)).set_max(encode(3)); - page_stats.at(2).set_null_count(3).set_min(encode(3)).set_max(encode(4)); - - TestWriteTypedColumnIndex( - schema::Int32("c1"), - page_stats, - BoundaryOrder::Ascending, - /*has_null_counts=*/true); + std::vector pageStats(3); + pageStats.at(0).setNullCount(1).setMin(encode(1)).setMax(encode(2)); + pageStats.at(1).setNullCount(2).setMin(encode(2)).setMax(encode(3)); + pageStats.at(2).setNullCount(3).setMin(encode(3)).setMax(encode(4)); + + testWriteTypedColumnIndex( + schema::int32("c1"), pageStats, BoundaryOrder::kAscending, true); } TEST(PageIndex, WriteInt64ColumnIndex) { @@ -355,16 +350,13 @@ TEST(PageIndex, WriteInt64ColumnIndex) { }; // Integer values in the descending order. - std::vector page_stats(3); - page_stats.at(0).set_null_count(4).set_min(encode(-1)).set_max(encode(-2)); - page_stats.at(1).set_null_count(0).set_min(encode(-2)).set_max(encode(-3)); - page_stats.at(2).set_null_count(4).set_min(encode(-3)).set_max(encode(-4)); - - TestWriteTypedColumnIndex( - schema::Int64("c1"), - page_stats, - BoundaryOrder::Descending, - /*has_null_counts=*/true); + std::vector pageStats(3); + pageStats.at(0).setNullCount(4).setMin(encode(-1)).setMax(encode(-2)); + pageStats.at(1).setNullCount(0).setMin(encode(-2)).setMax(encode(-3)); + pageStats.at(2).setNullCount(4).setMin(encode(-3)).setMax(encode(-4)); + + testWriteTypedColumnIndex( + schema::int64("c1"), pageStats, BoundaryOrder::kDescending, true); } TEST(PageIndex, WriteFloatColumnIndex) { @@ -373,25 +365,13 @@ TEST(PageIndex, WriteFloatColumnIndex) { }; // Float values with no specific order. - std::vector page_stats(3); - page_stats.at(0) - .set_null_count(0) - .set_min(encode(2.2F)) - .set_max(encode(4.4F)); - page_stats.at(1) - .set_null_count(0) - .set_min(encode(1.1F)) - .set_max(encode(5.5F)); - page_stats.at(2) - .set_null_count(0) - .set_min(encode(3.3F)) - .set_max(encode(6.6F)); - - TestWriteTypedColumnIndex( - schema::Float("c1"), - page_stats, - BoundaryOrder::Unordered, - /*has_null_counts=*/true); + std::vector pageStats(3); + pageStats.at(0).setNullCount(0).setMin(encode(2.2F)).setMax(encode(4.4F)); + pageStats.at(1).setNullCount(0).setMin(encode(1.1F)).setMax(encode(5.5F)); + pageStats.at(2).setNullCount(0).setMin(encode(3.3F)).setMax(encode(6.6F)); + + testWriteTypedColumnIndex( + schema::floatType("c1"), pageStats, BoundaryOrder::kUnordered, true); } TEST(PageIndex, WriteDoubleColumnIndex) { @@ -400,66 +380,54 @@ TEST(PageIndex, WriteDoubleColumnIndex) { }; // Double values with no specific order and without null count. - std::vector page_stats(3); - page_stats.at(0).set_min(encode(1.2)).set_max(encode(4.4)); - page_stats.at(1).set_min(encode(2.2)).set_max(encode(5.5)); - page_stats.at(2).set_min(encode(3.3)).set_max(encode(-6.6)); - - TestWriteTypedColumnIndex( - schema::Double("c1"), - page_stats, - BoundaryOrder::Unordered, - /*has_null_counts=*/false); + std::vector pageStats(3); + pageStats.at(0).setMin(encode(1.2)).setMax(encode(4.4)); + pageStats.at(1).setMin(encode(2.2)).setMax(encode(5.5)); + pageStats.at(2).setMin(encode(3.3)).setMax(encode(-6.6)); + + testWriteTypedColumnIndex( + schema::doubleType("c1"), pageStats, BoundaryOrder::kUnordered, false); } TEST(PageIndex, WriteByteArrayColumnIndex) { // Byte array values with identical min/max. - std::vector page_stats(3); - page_stats.at(0).set_min("bar").set_max("foo"); - page_stats.at(1).set_min("bar").set_max("foo"); - page_stats.at(2).set_min("bar").set_max("foo"); - - TestWriteTypedColumnIndex( - schema::ByteArray("c1"), - page_stats, - BoundaryOrder::Ascending, - /*has_null_counts=*/false); + std::vector pageStats(3); + pageStats.at(0).setMin("bar").setMax("foo"); + pageStats.at(1).setMin("bar").setMax("foo"); + pageStats.at(2).setMin("bar").setMax("foo"); + + testWriteTypedColumnIndex( + schema::byteArray("c1"), pageStats, BoundaryOrder::kAscending, false); } TEST(PageIndex, WriteFLBAColumnIndex) { - // FLBA values in the ascending order with some null pages - std::vector page_stats(5); - page_stats.at(0).set_min("abc").set_max("ABC"); - page_stats.at(1).all_null_value = true; - page_stats.at(2).set_min("foo").set_max("FOO"); - page_stats.at(3).all_null_value = true; - page_stats.at(4).set_min("xyz").set_max("XYZ"); - - auto node = schema::PrimitiveNode::Make( + // FLBA values in the ascending order with some null pages. + std::vector pageStats(5); + pageStats.at(0).setMin("abc").setMax("ABC"); + pageStats.at(1).allNullValue = true; + pageStats.at(2).setMin("foo").setMax("FOO"); + pageStats.at(3).allNullValue = true; + pageStats.at(4).setMin("xyz").setMax("XYZ"); + + auto Node = schema::PrimitiveNode::make( "c1", - Repetition::OPTIONAL, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::NONE, - /*length=*/3); - TestWriteTypedColumnIndex( - std::move(node), - page_stats, - BoundaryOrder::Ascending, - /*has_null_counts=*/false); + Repetition::kOptional, + Type::kFixedLenByteArray, + ConvertedType::kNone, + 3); + testWriteTypedColumnIndex( + std::move(Node), pageStats, BoundaryOrder::kAscending, false); } TEST(PageIndex, WriteColumnIndexWithAllNullPages) { // All values are null. - std::vector page_stats(3); - page_stats.at(0).set_null_count(100).all_null_value = true; - page_stats.at(1).set_null_count(100).all_null_value = true; - page_stats.at(2).set_null_count(100).all_null_value = true; - - TestWriteTypedColumnIndex( - schema::Int32("c1"), - page_stats, - BoundaryOrder::Unordered, - /*has_null_counts=*/true); + std::vector pageStats(3); + pageStats.at(0).setNullCount(100).allNullValue = true; + pageStats.at(1).setNullCount(100).allNullValue = true; + pageStats.at(2).setNullCount(100).allNullValue = true; + + testWriteTypedColumnIndex( + schema::int32("c1"), pageStats, BoundaryOrder::kUnordered, true); } TEST(PageIndex, WriteColumnIndexWithInvalidNullCounts) { @@ -467,17 +435,14 @@ TEST(PageIndex, WriteColumnIndexWithInvalidNullCounts) { return std::string(reinterpret_cast(&value), sizeof(int32_t)); }; - // Some pages do not provide null_count - std::vector page_stats(3); - page_stats.at(0).set_min(encode(1)).set_max(encode(2)).set_null_count(0); - page_stats.at(1).set_min(encode(1)).set_max(encode(3)); - page_stats.at(2).set_min(encode(2)).set_max(encode(3)).set_null_count(0); - - TestWriteTypedColumnIndex( - schema::Int32("c1"), - page_stats, - BoundaryOrder::Ascending, - /*has_null_counts=*/false); + // Some pages do not provide null_count. + std::vector pageStats(3); + pageStats.at(0).setMin(encode(1)).setMax(encode(2)).setNullCount(0); + pageStats.at(1).setMin(encode(1)).setMax(encode(3)); + pageStats.at(2).setMin(encode(2)).setMax(encode(3)).setNullCount(0); + + testWriteTypedColumnIndex( + schema::int32("c1"), pageStats, BoundaryOrder::kAscending, false); } TEST(PageIndex, WriteColumnIndexWithCorruptedStats) { @@ -485,153 +450,150 @@ TEST(PageIndex, WriteColumnIndexWithCorruptedStats) { return std::string(reinterpret_cast(&value), sizeof(int32_t)); }; - // 2nd page does not set anything - std::vector page_stats(3); - page_stats.at(0).set_min(encode(1)).set_max(encode(2)); - page_stats.at(2).set_min(encode(3)).set_max(encode(4)); + // 2Nd page does not set anything. + std::vector pageStats(3); + pageStats.at(0).setMin(encode(1)).setMax(encode(2)); + pageStats.at(2).setMin(encode(3)).setMax(encode(4)); - ColumnDescriptor descr(schema::Int32("c1"), /*max_definition_level=*/1, 0); - auto builder = ColumnIndexBuilder::Make(&descr); - for (const auto& stats : page_stats) { - builder->AddPage(stats); + ColumnDescriptor descr(schema::int32("c1"), 1, 0); + auto Builder = ColumnIndexBuilder::make(&descr); + for (const auto& stats : pageStats) { + Builder->addPage(stats); } - ASSERT_NO_THROW(builder->Finish()); - ASSERT_EQ(nullptr, builder->Build()); + ASSERT_NO_THROW(Builder->finish()); + ASSERT_EQ(nullptr, Builder->build()); - auto sink = CreateOutputStream(); - builder->WriteTo(sink.get()); + auto sink = createOutputStream(); + Builder->writeTo(sink.get()); PARQUET_ASSIGN_OR_THROW(auto buffer, sink->Finish()); EXPECT_EQ(0, buffer->size()); } TEST(PageIndex, TestPageIndexBuilderWithZeroRowGroup) { - schema::NodeVector fields = {schema::Int32("c1"), schema::ByteArray("c2")}; + schema::NodeVector fields = {schema::int32("c1"), schema::byteArray("c2")}; schema::NodePtr root = - schema::GroupNode::Make("schema", Repetition::REPEATED, fields); + schema::GroupNode::make("schema", Repetition::kRepeated, fields); SchemaDescriptor schema; - schema.Init(root); + schema.init(root); - auto builder = PageIndexBuilder::Make(&schema); + auto Builder = PageIndexBuilder::make(&schema); // AppendRowGroup() is not called and expect throw. - ASSERT_THROW(builder->GetColumnIndexBuilder(0), ParquetException); - ASSERT_THROW(builder->GetOffsetIndexBuilder(0), ParquetException); + ASSERT_THROW(Builder->getColumnIndexBuilder(0), ParquetException); + ASSERT_THROW(Builder->getOffsetIndexBuilder(0), ParquetException); // Finish the builder without calling AppendRowGroup(). - ASSERT_NO_THROW(builder->Finish()); + ASSERT_NO_THROW(Builder->finish()); // Verify WriteTo does not write anything. - auto sink = CreateOutputStream(); + auto sink = createOutputStream(); PageIndexLocation location; - builder->WriteTo(sink.get(), &location); + Builder->writeTo(sink.get(), &location); PARQUET_ASSIGN_OR_THROW(auto buffer, sink->Finish()); ASSERT_EQ(0, buffer->size()); - ASSERT_TRUE(location.column_index_location.empty()); - ASSERT_TRUE(location.offset_index_location.empty()); + ASSERT_TRUE(location.columnIndexLocation.empty()); + ASSERT_TRUE(location.offsetIndexLocation.empty()); } class PageIndexBuilderTest : public ::testing::Test { public: - void WritePageIndexes( - int num_row_groups, - int num_columns, - const std::vector>& page_stats, - const std::vector>& page_locations, - int final_position) { - auto builder = PageIndexBuilder::Make(&schema_); - for (int row_group = 0; row_group < num_row_groups; ++row_group) { - ASSERT_NO_THROW(builder->AppendRowGroup()); - - for (int column = 0; column < num_columns; ++column) { - if (static_cast(column) < page_stats[row_group].size()) { - auto column_index_builder = builder->GetColumnIndexBuilder(column); + void writePageIndexes( + int numRowGroups, + int numColumns, + const std::vector>& pageStats, + const std::vector>& pageLocations, + int finalPosition) { + auto Builder = PageIndexBuilder::make(&schema_); + for (int rowGroup = 0; rowGroup < numRowGroups; ++rowGroup) { + ASSERT_NO_THROW(Builder->appendRowGroup()); + + for (int column = 0; column < numColumns; ++column) { + if (static_cast(column) < pageStats[rowGroup].size()) { + auto ColumnIndexBuilder = Builder->getColumnIndexBuilder(column); ASSERT_NO_THROW( - column_index_builder->AddPage(page_stats[row_group][column])); - ASSERT_NO_THROW(column_index_builder->Finish()); + ColumnIndexBuilder->addPage(pageStats[rowGroup][column])); + ASSERT_NO_THROW(ColumnIndexBuilder->finish()); } - if (static_cast(column) < page_locations[row_group].size()) { - auto offset_index_builder = builder->GetOffsetIndexBuilder(column); + if (static_cast(column) < pageLocations[rowGroup].size()) { + auto OffsetIndexBuilder = Builder->getOffsetIndexBuilder(column); ASSERT_NO_THROW( - offset_index_builder->AddPage(page_locations[row_group][column])); - ASSERT_NO_THROW(offset_index_builder->Finish(final_position)); + OffsetIndexBuilder->addPage(pageLocations[rowGroup][column])); + ASSERT_NO_THROW(OffsetIndexBuilder->finish(finalPosition)); } } } - ASSERT_NO_THROW(builder->Finish()); + ASSERT_NO_THROW(Builder->finish()); - auto sink = CreateOutputStream(); - builder->WriteTo(sink.get(), &page_index_location_); + auto sink = createOutputStream(); + Builder->writeTo(sink.get(), &pageIndexLocation_); PARQUET_ASSIGN_OR_THROW(buffer_, sink->Finish()); ASSERT_EQ( - static_cast(num_row_groups), - page_index_location_.column_index_location.size()); + static_cast(numRowGroups), + pageIndexLocation_.columnIndexLocation.size()); ASSERT_EQ( - static_cast(num_row_groups), - page_index_location_.offset_index_location.size()); - for (int row_group = 0; row_group < num_row_groups; ++row_group) { + static_cast(numRowGroups), + pageIndexLocation_.offsetIndexLocation.size()); + for (int rowGroup = 0; rowGroup < numRowGroups; ++rowGroup) { ASSERT_EQ( - static_cast(num_columns), - page_index_location_.column_index_location[row_group].size()); + static_cast(numColumns), + pageIndexLocation_.columnIndexLocation[rowGroup].size()); ASSERT_EQ( - static_cast(num_columns), - page_index_location_.offset_index_location[row_group].size()); + static_cast(numColumns), + pageIndexLocation_.offsetIndexLocation[rowGroup].size()); } } void - CheckColumnIndex(int row_group, int column, const EncodedStatistics& stats) { - auto column_index = ReadColumnIndex(row_group, column); - ASSERT_NE(nullptr, column_index); - ASSERT_EQ(size_t{1}, column_index->null_pages().size()); - ASSERT_EQ(stats.all_null_value, column_index->null_pages()[0]); - ASSERT_EQ(stats.min(), column_index->encoded_min_values()[0]); - ASSERT_EQ(stats.max(), column_index->encoded_max_values()[0]); - ASSERT_EQ(stats.has_null_count, column_index->has_null_counts()); - if (stats.has_null_count) { - ASSERT_EQ(stats.null_count, column_index->null_counts()[0]); + checkColumnIndex(int rowGroup, int column, const EncodedStatistics& stats) { + auto columnIndex = readColumnIndex(rowGroup, column); + ASSERT_NE(nullptr, columnIndex); + ASSERT_EQ(size_t{1}, columnIndex->nullPages().size()); + ASSERT_EQ(stats.allNullValue, columnIndex->nullPages()[0]); + ASSERT_EQ(stats.min(), columnIndex->encodedMinValues()[0]); + ASSERT_EQ(stats.max(), columnIndex->encodedMaxValues()[0]); + ASSERT_EQ(stats.hasNullCount, columnIndex->hasNullCounts()); + if (stats.hasNullCount) { + ASSERT_EQ(stats.nullCount, columnIndex->nullCounts()[0]); } } - void CheckOffsetIndex( - int row_group, + void checkOffsetIndex( + int rowGroup, int column, - const PageLocation& expected_location, - int64_t final_location) { - auto offset_index = ReadOffsetIndex(row_group, column); - ASSERT_NE(nullptr, offset_index); - ASSERT_EQ(size_t{1}, offset_index->page_locations().size()); - const auto& location = offset_index->page_locations()[0]; - ASSERT_EQ(expected_location.offset + final_location, location.offset); - ASSERT_EQ( - expected_location.compressed_page_size, location.compressed_page_size); - ASSERT_EQ(expected_location.first_row_index, location.first_row_index); + const PageLocation& expectedLocation, + int64_t finalLocation) { + auto offsetIndex = readOffsetIndex(rowGroup, column); + ASSERT_NE(nullptr, offsetIndex); + ASSERT_EQ(size_t{1}, offsetIndex->pageLocations().size()); + const auto& location = offsetIndex->pageLocations()[0]; + ASSERT_EQ(expectedLocation.offset + finalLocation, location.offset); + ASSERT_EQ(expectedLocation.compressedPageSize, location.compressedPageSize); + ASSERT_EQ(expectedLocation.firstRowIndex, location.firstRowIndex); } protected: - std::unique_ptr ReadColumnIndex(int row_group, int column) { - auto location = - page_index_location_.column_index_location[row_group][column]; + std::unique_ptr readColumnIndex(int rowGroup, int column) { + auto location = pageIndexLocation_.columnIndexLocation[rowGroup][column]; if (!location.has_value()) { return nullptr; } - auto properties = default_reader_properties(); - return ColumnIndex::Make( - *schema_.Column(column), + auto properties = defaultReaderProperties(); + return ColumnIndex::make( + *schema_.column(column), buffer_->data() + location->offset, static_cast(location->length), properties); } - std::unique_ptr ReadOffsetIndex(int row_group, int column) { - auto location = - page_index_location_.offset_index_location[row_group][column]; + std::unique_ptr readOffsetIndex(int rowGroup, int column) { + auto location = pageIndexLocation_.offsetIndexLocation[rowGroup][column]; if (!location.has_value()) { return nullptr; } - auto properties = default_reader_properties(); - return OffsetIndex::Make( + auto properties = defaultReaderProperties(); + return OffsetIndex::make( buffer_->data() + location->offset, static_cast(location->length), properties); @@ -639,118 +601,96 @@ class PageIndexBuilderTest : public ::testing::Test { SchemaDescriptor schema_; std::shared_ptr buffer_; - PageIndexLocation page_index_location_; + PageIndexLocation pageIndexLocation_; }; TEST_F(PageIndexBuilderTest, SingleRowGroup) { - schema::NodePtr root = schema::GroupNode::Make( + schema::NodePtr root = schema::GroupNode::make( "schema", - Repetition::REPEATED, - {schema::ByteArray("c1"), - schema::ByteArray("c2"), - schema::ByteArray("c3")}); - schema_.Init(root); + Repetition::kRepeated, + {schema::byteArray("c1"), + schema::byteArray("c2"), + schema::byteArray("c3")}); + schema_.init(root); // Prepare page stats and page locations for single row group. - // Note that the 3rd column does not have any stats and its page index is - // disabled. - const int num_row_groups = 1; - const int num_columns = 3; - const std::vector> page_stats = { + // Note that the 3rd column does not have any stats and its page index is. + // Disabled. + const int numRowGroups = 1; + const int numColumns = 3; + const std::vector> pageStats = { /*row_group_id=0*/ - {/*column_id=0*/ EncodedStatistics() - .set_null_count(0) - .set_min("a") - .set_max("b"), + {/*column_id=0*/ EncodedStatistics().setNullCount(0).setMin("a").setMax( + "b"), /*column_id=1*/ - EncodedStatistics().set_null_count(0).set_min("A").set_max("B")}}; - const std::vector> page_locations = { + EncodedStatistics().setNullCount(0).setMin("A").setMax("B")}}; + const std::vector> pageLocations = { /*row_group_id=0*/ - {/*column_id=0*/ { - /*offset=*/128, - /*compressed_page_size=*/512, - /*first_row_index=*/0}, + {/*column_id=0*/ {128, 512, 0}, /*column_id=1*/ - {/*offset=*/1024, - /*compressed_page_size=*/512, - /*first_row_index=*/0}}}; - const int64_t final_position = 200; + {1024, 512, 0}}}; + const int64_t finalPosition = 200; - WritePageIndexes( - num_row_groups, num_columns, page_stats, page_locations, final_position); + writePageIndexes( + numRowGroups, numColumns, pageStats, pageLocations, finalPosition); // Verify that first two columns have good page indexes. for (int column = 0; column < 2; ++column) { - CheckColumnIndex(/*row_group=*/0, column, page_stats[0][column]); - CheckOffsetIndex( - /*row_group=*/0, column, page_locations[0][column], final_position); + checkColumnIndex(0, column, pageStats[0][column]); + checkOffsetIndex(0, column, pageLocations[0][column], finalPosition); } // Verify the 3rd column does not have page indexes. - ASSERT_EQ(nullptr, ReadColumnIndex(/*row_group=*/0, /*column=*/2)); - ASSERT_EQ(nullptr, ReadOffsetIndex(/*row_group=*/0, /*column=*/2)); + ASSERT_EQ(nullptr, readColumnIndex(0, 2)); + ASSERT_EQ(nullptr, readOffsetIndex(0, 2)); } TEST_F(PageIndexBuilderTest, TwoRowGroups) { - schema::NodePtr root = schema::GroupNode::Make( + schema::NodePtr root = schema::GroupNode::make( "schema", - Repetition::REPEATED, - {schema::ByteArray("c1"), schema::ByteArray("c2")}); - schema_.Init(root); + Repetition::kRepeated, + {schema::byteArray("c1"), schema::byteArray("c2")}); + schema_.init(root); // Prepare page stats and page locations for two row groups. // Note that the 2nd column in the 2nd row group has corrupted stats. - const int num_row_groups = 2; - const int num_columns = 2; - const std::vector> page_stats = { + const int numRowGroups = 2; + const int numColumns = 2; + const std::vector> pageStats = { /*row_group_id=0*/ - {/*column_id=0*/ EncodedStatistics().set_min("a").set_max("b"), + {/*column_id=0*/ EncodedStatistics().setMin("a").setMax("b"), /*column_id=1*/ - EncodedStatistics().set_null_count(0).set_min("A").set_max("B")}, + EncodedStatistics().setNullCount(0).setMin("A").setMax("B")}, /*row_group_id=1*/ {/*column_id=0*/ EncodedStatistics() /* corrupted stats */, /*column_id=1*/ - EncodedStatistics().set_null_count(0).set_min("bar").set_max("foo")}}; - const std::vector> page_locations = { + EncodedStatistics().setNullCount(0).setMin("bar").setMax("foo")}}; + const std::vector> pageLocations = { /*row_group_id=0*/ - {/*column_id=0*/ { - /*offset=*/128, - /*compressed_page_size=*/512, - /*first_row_index=*/0}, + {/*column_id=0*/ {128, 512, 0}, /*column_id=1*/ - {/*offset=*/1024, - /*compressed_page_size=*/512, - /*first_row_index=*/0}}, + {1024, 512, 0}}, /*row_group_id=0*/ - {/*column_id=0*/ { - /*offset=*/128, - /*compressed_page_size=*/512, - /*first_row_index=*/0}, + {/*column_id=0*/ {128, 512, 0}, /*column_id=1*/ - {/*offset=*/1024, - /*compressed_page_size=*/512, - /*first_row_index=*/0}}}; - const int64_t final_position = 200; + {1024, 512, 0}}}; + const int64_t finalPosition = 200; - WritePageIndexes( - num_row_groups, num_columns, page_stats, page_locations, final_position); + writePageIndexes( + numRowGroups, numColumns, pageStats, pageLocations, finalPosition); - // Verify that all columns have good column indexes except the 2nd column in - // the 2nd row group. - CheckColumnIndex(/*row_group=*/0, /*column=*/0, page_stats[0][0]); - CheckColumnIndex(/*row_group=*/0, /*column=*/1, page_stats[0][1]); - CheckColumnIndex(/*row_group=*/1, /*column=*/1, page_stats[1][1]); - ASSERT_EQ(nullptr, ReadColumnIndex(/*row_group=*/1, /*column=*/0)); + // Verify that all columns have good column indexes except the 2nd column in. + // The 2nd row group. + checkColumnIndex(0, 0, pageStats[0][0]); + checkColumnIndex(0, 1, pageStats[0][1]); + checkColumnIndex(1, 1, pageStats[1][1]); + ASSERT_EQ(nullptr, readColumnIndex(1, 0)); // Verify that two columns have good offset indexes. - CheckOffsetIndex( - /*row_group=*/0, /*column=*/0, page_locations[0][0], final_position); - CheckOffsetIndex( - /*row_group=*/0, /*column=*/1, page_locations[0][1], final_position); - CheckOffsetIndex( - /*row_group=*/1, /*column=*/0, page_locations[1][0], final_position); - CheckOffsetIndex( - /*row_group=*/1, /*column=*/1, page_locations[1][1], final_position); + checkOffsetIndex(0, 0, pageLocations[0][0], finalPosition); + checkOffsetIndex(0, 1, pageLocations[0][1], finalPosition); + checkOffsetIndex(1, 0, pageLocations[1][0], finalPosition); + checkOffsetIndex(1, 1, pageLocations[1][1], finalPosition); } } // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/tests/PropertiesTest.cpp b/velox/dwio/parquet/writer/arrow/tests/PropertiesTest.cpp index 0b3f29737e6..27f9589b793 100644 --- a/velox/dwio/parquet/writer/arrow/tests/PropertiesTest.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/PropertiesTest.cpp @@ -34,63 +34,63 @@ namespace test { TEST(TestReaderProperties, Basics) { ReaderProperties props; - ASSERT_EQ(props.buffer_size(), kDefaultBufferSize); - ASSERT_FALSE(props.is_buffered_stream_enabled()); - ASSERT_FALSE(props.page_checksum_verification()); + ASSERT_EQ(props.bufferSize(), kDefaultBufferSize); + ASSERT_FALSE(props.isBufferedStreamEnabled()); + ASSERT_FALSE(props.pageChecksumVerification()); } TEST(TestWriterProperties, Basics) { std::shared_ptr props = WriterProperties::Builder().build(); - ASSERT_EQ(kDefaultDataPageSize, props->data_pagesize()); + ASSERT_EQ(kDefaultDataPageSize, props->dataPagesize()); ASSERT_EQ( - DEFAULT_DICTIONARY_PAGE_SIZE_LIMIT, props->dictionary_pagesize_limit()); + DEFAULT_DICTIONARY_PAGE_SIZE_LIMIT, props->dictionaryPagesizeLimit()); ASSERT_EQ(ParquetVersion::PARQUET_2_6, props->version()); - ASSERT_EQ(ParquetDataPageVersion::V1, props->data_page_version()); - ASSERT_FALSE(props->page_checksum_enabled()); + ASSERT_EQ(ParquetDataPageVersion::V1, props->dataPageVersion()); + ASSERT_FALSE(props->pageChecksumEnabled()); } TEST(TestWriterProperties, AdvancedHandling) { - WriterProperties::Builder builder; - builder.compression("gzip", Compression::GZIP); - builder.compression("zstd", Compression::ZSTD); - builder.compression(Compression::SNAPPY); - builder.encoding(Encoding::DELTA_BINARY_PACKED); - builder.encoding("delta-length", Encoding::DELTA_LENGTH_BYTE_ARRAY); - builder.data_page_version(ParquetDataPageVersion::V2); - std::shared_ptr props = builder.build(); + WriterProperties::Builder Builder; + Builder.compression("gzip", Compression::GZIP); + Builder.compression("zstd", Compression::ZSTD); + Builder.compression(Compression::SNAPPY); + Builder.encoding(Encoding::kDeltaBinaryPacked); + Builder.encoding("delta-length", Encoding::kDeltaLengthByteArray); + Builder.dataPageVersion(ParquetDataPageVersion::V2); + std::shared_ptr props = Builder.build(); ASSERT_EQ( - Compression::GZIP, props->compression(ColumnPath::FromDotString("gzip"))); + Compression::GZIP, props->compression(ColumnPath::fromDotString("gzip"))); ASSERT_EQ( - Compression::ZSTD, props->compression(ColumnPath::FromDotString("zstd"))); + Compression::ZSTD, props->compression(ColumnPath::fromDotString("zstd"))); ASSERT_EQ( Compression::SNAPPY, - props->compression(ColumnPath::FromDotString("delta-length"))); + props->compression(ColumnPath::fromDotString("delta-length"))); ASSERT_EQ( - Encoding::DELTA_BINARY_PACKED, - props->encoding(ColumnPath::FromDotString("gzip"))); + Encoding::kDeltaBinaryPacked, + props->encoding(ColumnPath::fromDotString("gzip"))); ASSERT_EQ( - Encoding::DELTA_LENGTH_BYTE_ARRAY, - props->encoding(ColumnPath::FromDotString("delta-length"))); - ASSERT_EQ(ParquetDataPageVersion::V2, props->data_page_version()); + Encoding::kDeltaLengthByteArray, + props->encoding(ColumnPath::fromDotString("delta-length"))); + ASSERT_EQ(ParquetDataPageVersion::V2, props->dataPageVersion()); } TEST(TestReaderProperties, GetStreamInsufficientData) { - // ARROW-6058 + // ARROW-6058. std::string data = "shorter than expected"; auto buf = std::make_shared(data); auto reader = std::make_shared<::arrow::io::BufferReader>(buf); ReaderProperties props; try { - ARROW_UNUSED(props.GetStream(reader, 12, 15)); + ARROW_UNUSED(props.getStream(reader, 12, 15)); FAIL() << "No exception raised"; } catch (const ParquetException& e) { - std::string ex_what = + std::string exWhat = ("Tried reading 15 bytes starting at position 12" " from file but only got 9"); - ASSERT_EQ(ex_what, e.what()); + ASSERT_EQ(exWhat, e.what()); } } diff --git a/velox/dwio/parquet/writer/arrow/tests/SchemaTest.cpp b/velox/dwio/parquet/writer/arrow/tests/SchemaTest.cpp index f312ec7e7b4..fb22fc3d681 100644 --- a/velox/dwio/parquet/writer/arrow/tests/SchemaTest.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/SchemaTest.cpp @@ -42,507 +42,513 @@ using facebook::velox::parquet::thrift::SchemaElement; namespace schema { -static inline SchemaElement NewPrimitive( +static inline SchemaElement newPrimitive( const std::string& name, FieldRepetitionType::type repetition, Type::type type, - int field_id = -1) { + int fieldId = -1) { SchemaElement result; result.__set_name(name); result.__set_repetition_type(repetition); result.__set_type( static_cast(type)); - if (field_id >= 0) { - result.__set_field_id(field_id); + if (fieldId >= 0) { + result.__set_field_id(fieldId); } return result; } -static inline SchemaElement NewGroup( +static inline SchemaElement newGroup( const std::string& name, FieldRepetitionType::type repetition, - int num_children, - int field_id = -1) { + int numChildren, + int fieldId = -1) { SchemaElement result; result.__set_name(name); result.__set_repetition_type(repetition); - result.__set_num_children(num_children); + result.__set_num_children(numChildren); - if (field_id >= 0) { - result.__set_field_id(field_id); + if (fieldId >= 0) { + result.__set_field_id(fieldId); } return result; } template -static void CheckNodeRoundtrip(const Node& node) { +static void checkNodeRoundtrip(const Node& node) { facebook::velox::parquet::thrift::SchemaElement serialized; - node.ToParquet(&serialized); - std::unique_ptr recovered = NodeType::FromParquet(&serialized); - ASSERT_TRUE(node.Equals(recovered.get())) + node.toParquet(&serialized); + std::unique_ptr recovered = NodeType::fromParquet(&serialized); + ASSERT_TRUE(node.equals(recovered.get())) << "Recovered node not equivalent to original node constructed " - << "with logical type " << node.logical_type()->ToString() << " got " - << recovered->logical_type()->ToString(); + << "with logical type " << node.logicalType()->toString() << " got " + << recovered->logicalType()->toString(); } -static void ConfirmPrimitiveNodeRoundtrip( - const std::shared_ptr& logical_type, - Type::type physical_type, - int physical_length, - int field_id = -1) { - auto node = PrimitiveNode::Make( +static void confirmPrimitiveNodeRoundtrip( + const std::shared_ptr& logicalType, + Type::type physicalType, + int physicalLength, + int fieldId = -1) { + auto Node = PrimitiveNode::make( "something", - Repetition::REQUIRED, - logical_type, - physical_type, - physical_length, - field_id); - CheckNodeRoundtrip(*node); + Repetition::kRequired, + logicalType, + physicalType, + physicalLength, + fieldId); + checkNodeRoundtrip(*Node); } -static void ConfirmGroupNodeRoundtrip( +static void confirmGroupNodeRoundtrip( std::string name, - const std::shared_ptr& logical_type, - int field_id = -1) { - auto node = - GroupNode::Make(name, Repetition::REQUIRED, {}, logical_type, field_id); - CheckNodeRoundtrip(*node); + const std::shared_ptr& logicalType, + int fieldId = -1) { + auto Node = + GroupNode::make(name, Repetition::kRequired, {}, logicalType, fieldId); + checkNodeRoundtrip(*Node); } -// ---------------------------------------------------------------------- -// ColumnPath +// ----------------------------------------------------------------------. +// ColumnPath. TEST(TestColumnPath, TestAttrs) { ColumnPath path(std::vector({"toplevel", "leaf"})); - ASSERT_EQ(path.ToDotString(), "toplevel.leaf"); + ASSERT_EQ(path.toDotString(), "toplevel.leaf"); - std::shared_ptr path_ptr = - ColumnPath::FromDotString("toplevel.leaf"); - ASSERT_EQ(path_ptr->ToDotString(), "toplevel.leaf"); + std::shared_ptr pathPtr = + ColumnPath::fromDotString("toplevel.leaf"); + ASSERT_EQ(pathPtr->toDotString(), "toplevel.leaf"); - std::shared_ptr extended = path_ptr->extend("anotherlevel"); - ASSERT_EQ(extended->ToDotString(), "toplevel.leaf.anotherlevel"); + std::shared_ptr extended = pathPtr->extend("anotherlevel"); + ASSERT_EQ(extended->toDotString(), "toplevel.leaf.anotherlevel"); } -// ---------------------------------------------------------------------- -// Primitive node +// ----------------------------------------------------------------------. +// Primitive node. class TestPrimitiveNode : public ::testing::Test { public: void SetUp() { name_ = "name"; - field_id_ = 5; + fieldId_ = 5; } - void Convert(const facebook::velox::parquet::thrift::SchemaElement* element) { - node_ = PrimitiveNode::FromParquet(element); - ASSERT_TRUE(node_->is_primitive()); - prim_node_ = static_cast(node_.get()); + void convert(const facebook::velox::parquet::thrift::SchemaElement* element) { + node_ = PrimitiveNode::fromParquet(element); + ASSERT_TRUE(node_->isPrimitive()); + primNode_ = static_cast(node_.get()); } protected: std::string name_; - const PrimitiveNode* prim_node_; + const PrimitiveNode* primNode_; - int field_id_; + int fieldId_; std::unique_ptr node_; }; TEST_F(TestPrimitiveNode, Attrs) { - PrimitiveNode node1("foo", Repetition::REPEATED, Type::INT32); + PrimitiveNode node1("foo", Repetition::kRepeated, Type::kInt32); PrimitiveNode node2( - "bar", Repetition::OPTIONAL, Type::BYTE_ARRAY, ConvertedType::UTF8); + "bar", Repetition::kOptional, Type::kByteArray, ConvertedType::kUtf8); ASSERT_EQ("foo", node1.name()); - ASSERT_TRUE(node1.is_primitive()); - ASSERT_FALSE(node1.is_group()); + ASSERT_TRUE(node1.isPrimitive()); + ASSERT_FALSE(node1.isGroup()); - ASSERT_EQ(Repetition::REPEATED, node1.repetition()); - ASSERT_EQ(Repetition::OPTIONAL, node2.repetition()); + ASSERT_EQ(Repetition::kRepeated, node1.repetition()); + ASSERT_EQ(Repetition::kOptional, node2.repetition()); - ASSERT_EQ(Node::PRIMITIVE, node1.node_type()); + ASSERT_EQ(Node::kPrimitive, node1.nodeType()); - ASSERT_EQ(Type::INT32, node1.physical_type()); - ASSERT_EQ(Type::BYTE_ARRAY, node2.physical_type()); + ASSERT_EQ(Type::kInt32, node1.physicalType()); + ASSERT_EQ(Type::kByteArray, node2.physicalType()); - // logical types - ASSERT_EQ(ConvertedType::NONE, node1.converted_type()); - ASSERT_EQ(ConvertedType::UTF8, node2.converted_type()); + // Logical types. + ASSERT_EQ(ConvertedType::kNone, node1.convertedType()); + ASSERT_EQ(ConvertedType::kUtf8, node2.convertedType()); - // repetition - PrimitiveNode node3("foo", Repetition::REPEATED, Type::INT32); - PrimitiveNode node4("foo", Repetition::REQUIRED, Type::INT32); - PrimitiveNode node5("foo", Repetition::OPTIONAL, Type::INT32); + // Repetition. + PrimitiveNode node3("foo", Repetition::kRepeated, Type::kInt32); + PrimitiveNode node4("foo", Repetition::kRequired, Type::kInt32); + PrimitiveNode node5("foo", Repetition::kOptional, Type::kInt32); - ASSERT_TRUE(node3.is_repeated()); - ASSERT_FALSE(node3.is_optional()); + ASSERT_TRUE(node3.isRepeated()); + ASSERT_FALSE(node3.isOptional()); - ASSERT_TRUE(node4.is_required()); + ASSERT_TRUE(node4.isRequired()); - ASSERT_TRUE(node5.is_optional()); - ASSERT_FALSE(node5.is_required()); + ASSERT_TRUE(node5.isOptional()); + ASSERT_FALSE(node5.isRequired()); } -TEST_F(TestPrimitiveNode, FromParquet) { - SchemaElement elt = NewPrimitive( - name_, FieldRepetitionType::OPTIONAL, Type::INT32, field_id_); - ASSERT_NO_FATAL_FAILURE(Convert(&elt)); - ASSERT_EQ(name_, prim_node_->name()); - ASSERT_EQ(field_id_, prim_node_->field_id()); - ASSERT_EQ(Repetition::OPTIONAL, prim_node_->repetition()); - ASSERT_EQ(Type::INT32, prim_node_->physical_type()); - ASSERT_EQ(ConvertedType::NONE, prim_node_->converted_type()); - - // Test a logical type - elt = NewPrimitive( - name_, FieldRepetitionType::REQUIRED, Type::BYTE_ARRAY, field_id_); +TEST_F(TestPrimitiveNode, fromParquet) { + SchemaElement elt = newPrimitive( + name_, FieldRepetitionType::OPTIONAL, Type::kInt32, fieldId_); + ASSERT_NO_FATAL_FAILURE(convert(&elt)); + ASSERT_EQ(name_, primNode_->name()); + ASSERT_EQ(fieldId_, primNode_->fieldId()); + ASSERT_EQ(Repetition::kOptional, primNode_->repetition()); + ASSERT_EQ(Type::kInt32, primNode_->physicalType()); + ASSERT_EQ(ConvertedType::kNone, primNode_->convertedType()); + + // Test a logical type. + elt = newPrimitive( + name_, FieldRepetitionType::REQUIRED, Type::kByteArray, fieldId_); elt.__set_converted_type( facebook::velox::parquet::thrift::ConvertedType::UTF8); - ASSERT_NO_FATAL_FAILURE(Convert(&elt)); - ASSERT_EQ(Repetition::REQUIRED, prim_node_->repetition()); - ASSERT_EQ(Type::BYTE_ARRAY, prim_node_->physical_type()); - ASSERT_EQ(ConvertedType::UTF8, prim_node_->converted_type()); - - // FIXED_LEN_BYTE_ARRAY - elt = NewPrimitive( - name_, - FieldRepetitionType::OPTIONAL, - Type::FIXED_LEN_BYTE_ARRAY, - field_id_); + ASSERT_NO_FATAL_FAILURE(convert(&elt)); + ASSERT_EQ(Repetition::kRequired, primNode_->repetition()); + ASSERT_EQ(Type::kByteArray, primNode_->physicalType()); + ASSERT_EQ(ConvertedType::kUtf8, primNode_->convertedType()); + + // FIXED_LEN_BYTE_ARRAY. + elt = newPrimitive( + name_, FieldRepetitionType::OPTIONAL, Type::kFixedLenByteArray, fieldId_); elt.__set_type_length(16); - ASSERT_NO_FATAL_FAILURE(Convert(&elt)); - ASSERT_EQ(name_, prim_node_->name()); - ASSERT_EQ(field_id_, prim_node_->field_id()); - ASSERT_EQ(Repetition::OPTIONAL, prim_node_->repetition()); - ASSERT_EQ(Type::FIXED_LEN_BYTE_ARRAY, prim_node_->physical_type()); - ASSERT_EQ(16, prim_node_->type_length()); - - // facebook::velox::parquet::thrift::ConvertedType::Decimal - elt = NewPrimitive( - name_, - FieldRepetitionType::OPTIONAL, - Type::FIXED_LEN_BYTE_ARRAY, - field_id_); + ASSERT_NO_FATAL_FAILURE(convert(&elt)); + ASSERT_EQ(name_, primNode_->name()); + ASSERT_EQ(fieldId_, primNode_->fieldId()); + ASSERT_EQ(Repetition::kOptional, primNode_->repetition()); + ASSERT_EQ(Type::kFixedLenByteArray, primNode_->physicalType()); + ASSERT_EQ(16, primNode_->typeLength()); + + // Facebook::velox::parquet::thrift::ConvertedType::Decimal. + elt = newPrimitive( + name_, FieldRepetitionType::OPTIONAL, Type::kFixedLenByteArray, fieldId_); elt.__set_converted_type( facebook::velox::parquet::thrift::ConvertedType::DECIMAL); elt.__set_type_length(6); elt.__set_scale(2); elt.__set_precision(12); - ASSERT_NO_FATAL_FAILURE(Convert(&elt)); - ASSERT_EQ(Type::FIXED_LEN_BYTE_ARRAY, prim_node_->physical_type()); - ASSERT_EQ(ConvertedType::DECIMAL, prim_node_->converted_type()); - ASSERT_EQ(6, prim_node_->type_length()); - ASSERT_EQ(2, prim_node_->decimal_metadata().scale); - ASSERT_EQ(12, prim_node_->decimal_metadata().precision); + ASSERT_NO_FATAL_FAILURE(convert(&elt)); + ASSERT_EQ(Type::kFixedLenByteArray, primNode_->physicalType()); + ASSERT_EQ(ConvertedType::kDecimal, primNode_->convertedType()); + ASSERT_EQ(6, primNode_->typeLength()); + ASSERT_EQ(2, primNode_->decimalMetadata().scale); + ASSERT_EQ(12, primNode_->decimalMetadata().precision); } -TEST_F(TestPrimitiveNode, Equals) { - PrimitiveNode node1("foo", Repetition::REQUIRED, Type::INT32); - PrimitiveNode node2("foo", Repetition::REQUIRED, Type::INT64); - PrimitiveNode node3("bar", Repetition::REQUIRED, Type::INT32); - PrimitiveNode node4("foo", Repetition::OPTIONAL, Type::INT32); - PrimitiveNode node5("foo", Repetition::REQUIRED, Type::INT32); +TEST_F(TestPrimitiveNode, equals) { + PrimitiveNode node1("foo", Repetition::kRequired, Type::kInt32); + PrimitiveNode node2("foo", Repetition::kRequired, Type::kInt64); + PrimitiveNode node3("bar", Repetition::kRequired, Type::kInt32); + PrimitiveNode node4("foo", Repetition::kOptional, Type::kInt32); + PrimitiveNode node5("foo", Repetition::kRequired, Type::kInt32); - ASSERT_TRUE(node1.Equals(&node1)); - ASSERT_FALSE(node1.Equals(&node2)); - ASSERT_FALSE(node1.Equals(&node3)); - ASSERT_FALSE(node1.Equals(&node4)); - ASSERT_TRUE(node1.Equals(&node5)); + ASSERT_TRUE(node1.equals(&node1)); + ASSERT_FALSE(node1.equals(&node2)); + ASSERT_FALSE(node1.equals(&node3)); + ASSERT_FALSE(node1.equals(&node4)); + ASSERT_TRUE(node1.equals(&node5)); PrimitiveNode flba1( "foo", - Repetition::REQUIRED, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::DECIMAL, + Repetition::kRequired, + Type::kFixedLenByteArray, + ConvertedType::kDecimal, 12, 4, 2); PrimitiveNode flba2( "foo", - Repetition::REQUIRED, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::DECIMAL, + Repetition::kRequired, + Type::kFixedLenByteArray, + ConvertedType::kDecimal, 1, 4, 2); - flba2.SetTypeLength(12); + flba2.setTypeLength(12); PrimitiveNode flba3( "foo", - Repetition::REQUIRED, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::DECIMAL, + Repetition::kRequired, + Type::kFixedLenByteArray, + ConvertedType::kDecimal, 1, 4, 2); - flba3.SetTypeLength(16); + flba3.setTypeLength(16); PrimitiveNode flba4( "foo", - Repetition::REQUIRED, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::DECIMAL, + Repetition::kRequired, + Type::kFixedLenByteArray, + ConvertedType::kDecimal, 12, 4, 0); PrimitiveNode flba5( "foo", - Repetition::REQUIRED, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::NONE, + Repetition::kRequired, + Type::kFixedLenByteArray, + ConvertedType::kNone, 12, 4, 0); - ASSERT_TRUE(flba1.Equals(&flba2)); - ASSERT_FALSE(flba1.Equals(&flba3)); - ASSERT_FALSE(flba1.Equals(&flba4)); - ASSERT_FALSE(flba1.Equals(&flba5)); + ASSERT_TRUE(flba1.equals(&flba2)); + ASSERT_FALSE(flba1.equals(&flba3)); + ASSERT_FALSE(flba1.equals(&flba4)); + ASSERT_FALSE(flba1.equals(&flba5)); } TEST_F(TestPrimitiveNode, PhysicalLogicalMapping) { - ASSERT_NO_THROW(PrimitiveNode::Make( - "foo", Repetition::REQUIRED, Type::INT32, ConvertedType::INT_32)); - ASSERT_NO_THROW(PrimitiveNode::Make( - "foo", Repetition::REQUIRED, Type::BYTE_ARRAY, ConvertedType::JSON)); + ASSERT_NO_THROW( + PrimitiveNode::make( + "foo", Repetition::kRequired, Type::kInt32, ConvertedType::kInt32)); + ASSERT_NO_THROW( + PrimitiveNode::make( + "foo", + Repetition::kRequired, + Type::kByteArray, + ConvertedType::kJson)); ASSERT_THROW( - PrimitiveNode::Make( - "foo", Repetition::REQUIRED, Type::INT32, ConvertedType::JSON), + PrimitiveNode::make( + "foo", Repetition::kRequired, Type::kInt32, ConvertedType::kJson), ParquetException); - ASSERT_NO_THROW(PrimitiveNode::Make( - "foo", - Repetition::REQUIRED, - Type::INT64, - ConvertedType::TIMESTAMP_MILLIS)); + ASSERT_NO_THROW( + PrimitiveNode::make( + "foo", + Repetition::kRequired, + Type::kInt64, + ConvertedType::kTimestampMillis)); ASSERT_THROW( - PrimitiveNode::Make( - "foo", Repetition::REQUIRED, Type::INT32, ConvertedType::INT_64), + PrimitiveNode::make( + "foo", Repetition::kRequired, Type::kInt32, ConvertedType::kInt64), ParquetException); ASSERT_THROW( - PrimitiveNode::Make( - "foo", Repetition::REQUIRED, Type::BYTE_ARRAY, ConvertedType::INT_8), + PrimitiveNode::make( + "foo", Repetition::kRequired, Type::kByteArray, ConvertedType::kInt8), ParquetException); ASSERT_THROW( - PrimitiveNode::Make( + PrimitiveNode::make( "foo", - Repetition::REQUIRED, - Type::BYTE_ARRAY, - ConvertedType::INTERVAL), + Repetition::kRequired, + Type::kByteArray, + ConvertedType::kInterval), ParquetException); ASSERT_THROW( - PrimitiveNode::Make( + PrimitiveNode::make( "foo", - Repetition::REQUIRED, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::ENUM), + Repetition::kRequired, + Type::kFixedLenByteArray, + ConvertedType::kEnum), ParquetException); - ASSERT_NO_THROW(PrimitiveNode::Make( - "foo", Repetition::REQUIRED, Type::BYTE_ARRAY, ConvertedType::ENUM)); + ASSERT_NO_THROW( + PrimitiveNode::make( + "foo", + Repetition::kRequired, + Type::kByteArray, + ConvertedType::kEnum)); ASSERT_THROW( - PrimitiveNode::Make( + PrimitiveNode::make( "foo", - Repetition::REQUIRED, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::DECIMAL, + Repetition::kRequired, + Type::kFixedLenByteArray, + ConvertedType::kDecimal, 0, 2, 4), ParquetException); ASSERT_THROW( - PrimitiveNode::Make( + PrimitiveNode::make( "foo", - Repetition::REQUIRED, - Type::FLOAT, - ConvertedType::DECIMAL, + Repetition::kRequired, + Type::kFloat, + ConvertedType::kDecimal, 0, 2, 4), ParquetException); ASSERT_THROW( - PrimitiveNode::Make( + PrimitiveNode::make( "foo", - Repetition::REQUIRED, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::DECIMAL, + Repetition::kRequired, + Type::kFixedLenByteArray, + ConvertedType::kDecimal, 0, 4, 0), ParquetException); ASSERT_THROW( - PrimitiveNode::Make( + PrimitiveNode::make( "foo", - Repetition::REQUIRED, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::DECIMAL, + Repetition::kRequired, + Type::kFixedLenByteArray, + ConvertedType::kDecimal, 10, 0, 4), ParquetException); ASSERT_THROW( - PrimitiveNode::Make( + PrimitiveNode::make( "foo", - Repetition::REQUIRED, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::DECIMAL, + Repetition::kRequired, + Type::kFixedLenByteArray, + ConvertedType::kDecimal, 10, 4, -1), ParquetException); ASSERT_THROW( - PrimitiveNode::Make( + PrimitiveNode::make( "foo", - Repetition::REQUIRED, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::DECIMAL, + Repetition::kRequired, + Type::kFixedLenByteArray, + ConvertedType::kDecimal, 10, 2, 4), ParquetException); - ASSERT_NO_THROW(PrimitiveNode::Make( - "foo", - Repetition::REQUIRED, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::DECIMAL, - 10, - 6, - 4)); - ASSERT_NO_THROW(PrimitiveNode::Make( - "foo", - Repetition::REQUIRED, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::INTERVAL, - 12)); + ASSERT_NO_THROW( + PrimitiveNode::make( + "foo", + Repetition::kRequired, + Type::kFixedLenByteArray, + ConvertedType::kDecimal, + 10, + 6, + 4)); + ASSERT_NO_THROW( + PrimitiveNode::make( + "foo", + Repetition::kRequired, + Type::kFixedLenByteArray, + ConvertedType::kInterval, + 12)); ASSERT_THROW( - PrimitiveNode::Make( + PrimitiveNode::make( "foo", - Repetition::REQUIRED, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::INTERVAL, + Repetition::kRequired, + Type::kFixedLenByteArray, + ConvertedType::kInterval, 10), ParquetException); } -// ---------------------------------------------------------------------- -// Group node +// ----------------------------------------------------------------------. +// Group node. class TestGroupNode : public ::testing::Test { public: - NodeVector Fields1() { + NodeVector fields1() { NodeVector fields; - fields.push_back(Int32("one", Repetition::REQUIRED)); - fields.push_back(Int64("two")); - fields.push_back(Double("three")); + fields.push_back(int32("one", Repetition::kRequired)); + fields.push_back(int64("two")); + fields.push_back(doubleType("three")); return fields; } - NodeVector Fields2() { - // Fields with a duplicate name + NodeVector fields2() { + // Fields with a duplicate name. NodeVector fields; - fields.push_back(Int32("duplicate", Repetition::REQUIRED)); - fields.push_back(Int64("unique")); - fields.push_back(Double("duplicate")); + fields.push_back(int32("duplicate", Repetition::kRequired)); + fields.push_back(int64("unique")); + fields.push_back(doubleType("duplicate")); return fields; } }; TEST_F(TestGroupNode, Attrs) { - NodeVector fields = Fields1(); + NodeVector fields = fields1(); - GroupNode node1("foo", Repetition::REPEATED, fields); - GroupNode node2("bar", Repetition::OPTIONAL, fields, ConvertedType::LIST); + GroupNode node1("foo", Repetition::kRepeated, fields); + GroupNode node2("bar", Repetition::kOptional, fields, ConvertedType::kList); ASSERT_EQ("foo", node1.name()); - ASSERT_TRUE(node1.is_group()); - ASSERT_FALSE(node1.is_primitive()); + ASSERT_TRUE(node1.isGroup()); + ASSERT_FALSE(node1.isPrimitive()); - ASSERT_EQ(fields.size(), node1.field_count()); + ASSERT_EQ(fields.size(), node1.fieldCount()); - ASSERT_TRUE(node1.is_repeated()); - ASSERT_TRUE(node2.is_optional()); + ASSERT_TRUE(node1.isRepeated()); + ASSERT_TRUE(node2.isOptional()); - ASSERT_EQ(Repetition::REPEATED, node1.repetition()); - ASSERT_EQ(Repetition::OPTIONAL, node2.repetition()); + ASSERT_EQ(Repetition::kRepeated, node1.repetition()); + ASSERT_EQ(Repetition::kOptional, node2.repetition()); - ASSERT_EQ(Node::GROUP, node1.node_type()); + ASSERT_EQ(Node::kGroup, node1.nodeType()); - // logical types - ASSERT_EQ(ConvertedType::NONE, node1.converted_type()); - ASSERT_EQ(ConvertedType::LIST, node2.converted_type()); + // Logical types. + ASSERT_EQ(ConvertedType::kNone, node1.convertedType()); + ASSERT_EQ(ConvertedType::kList, node2.convertedType()); } -TEST_F(TestGroupNode, Equals) { - NodeVector f1 = Fields1(); - NodeVector f2 = Fields1(); +TEST_F(TestGroupNode, equals) { + NodeVector f1 = fields1(); + NodeVector f2 = fields1(); - GroupNode group1("group", Repetition::REPEATED, f1); - GroupNode group2("group", Repetition::REPEATED, f2); - GroupNode group3("group2", Repetition::REPEATED, f2); + GroupNode group1("group", Repetition::kRepeated, f1); + GroupNode group2("group", Repetition::kRepeated, f2); + GroupNode group3("group2", Repetition::kRepeated, f2); - // This is copied in the GroupNode ctor, so this is okay - f2.push_back(Float("four", Repetition::OPTIONAL)); - GroupNode group4("group", Repetition::REPEATED, f2); - GroupNode group5("group", Repetition::REPEATED, Fields1()); + // This is copied in the GroupNode ctor, so this is okay. + f2.push_back(floatType("four", Repetition::kOptional)); + GroupNode group4("group", Repetition::kRepeated, f2); + GroupNode group5("group", Repetition::kRepeated, fields1()); - ASSERT_TRUE(group1.Equals(&group1)); - ASSERT_TRUE(group1.Equals(&group2)); - ASSERT_FALSE(group1.Equals(&group3)); + ASSERT_TRUE(group1.equals(&group1)); + ASSERT_TRUE(group1.equals(&group2)); + ASSERT_FALSE(group1.equals(&group3)); - ASSERT_FALSE(group1.Equals(&group4)); - ASSERT_FALSE(group5.Equals(&group4)); + ASSERT_FALSE(group1.equals(&group4)); + ASSERT_FALSE(group5.equals(&group4)); } -TEST_F(TestGroupNode, FieldIndex) { - NodeVector fields = Fields1(); - GroupNode group("group", Repetition::REQUIRED, fields); +TEST_F(TestGroupNode, fieldIndex) { + NodeVector fields = fields1(); + GroupNode group("group", Repetition::kRequired, fields); for (size_t i = 0; i < fields.size(); i++) { auto field = group.field(static_cast(i)); - ASSERT_EQ(i, group.FieldIndex(*field)); + ASSERT_EQ(i, group.fieldIndex(*field)); } - // Test a non field node - auto non_field_alien = Int32("alien", Repetition::REQUIRED); // other name - auto non_field_familiar = Int32("one", Repetition::REPEATED); // other node - ASSERT_LT(group.FieldIndex(*non_field_alien), 0); - ASSERT_LT(group.FieldIndex(*non_field_familiar), 0); + // Test a non field node. + auto nonFieldAlien = int32("alien", Repetition::kRequired); // other name + auto nonFieldFamiliar = int32("one", Repetition::kRepeated); // other node + ASSERT_LT(group.fieldIndex(*nonFieldAlien), 0); + ASSERT_LT(group.fieldIndex(*nonFieldFamiliar), 0); } TEST_F(TestGroupNode, FieldIndexDuplicateName) { - NodeVector fields = Fields2(); - GroupNode group("group", Repetition::REQUIRED, fields); + NodeVector fields = fields2(); + GroupNode group("group", Repetition::kRequired, fields); for (size_t i = 0; i < fields.size(); i++) { auto field = group.field(static_cast(i)); - ASSERT_EQ(i, group.FieldIndex(*field)); + ASSERT_EQ(i, group.fieldIndex(*field)); } } -// ---------------------------------------------------------------------- -// Test convert group +// ----------------------------------------------------------------------. +// Test convert group. class TestSchemaConverter : public ::testing::Test { public: - void setUp() { + void SetUp() { name_ = "parquet_schema"; } - void Convert( + void convert( const facebook::velox::parquet::thrift::SchemaElement* elements, int length) { - node_ = Unflatten(elements, length); - ASSERT_TRUE(node_->is_group()); + node_ = unflatten(elements, length); + ASSERT_TRUE(node_->isGroup()); group_ = static_cast(node_.get()); } @@ -552,16 +558,16 @@ class TestSchemaConverter : public ::testing::Test { std::unique_ptr node_; }; -bool check_for_parent_consistency(const GroupNode* node) { - // Each node should have the group as parent - for (int i = 0; i < node->field_count(); i++) { - const NodePtr& field = node->field(i); - if (field->parent() != node) { +bool checkForParentConsistency(const GroupNode* Node) { + // Each node should have the group as parent. + for (int i = 0; i < Node->fieldCount(); i++) { + const NodePtr& field = Node->field(i); + if (field->parent() != Node) { return false; } - if (field->is_group()) { + if (field->isGroup()) { const GroupNode* group = static_cast(field.get()); - if (!check_for_parent_consistency(group)) { + if (!checkForParentConsistency(group)) { return false; } } @@ -572,106 +578,99 @@ bool check_for_parent_consistency(const GroupNode* node) { TEST_F(TestSchemaConverter, NestedExample) { SchemaElement elt; std::vector elements; - elements.push_back(NewGroup( - name_, - FieldRepetitionType::REPEATED, - /*num_children=*/2, - /*field_id=*/0)); + elements.push_back(newGroup(name_, FieldRepetitionType::REPEATED, 2, 0)); - // A primitive one + // A primitive one. elements.push_back( - NewPrimitive("a", FieldRepetitionType::REQUIRED, Type::INT32, 1)); + newPrimitive("a", FieldRepetitionType::REQUIRED, Type::kInt32, 1)); - // A group - elements.push_back(NewGroup("bag", FieldRepetitionType::OPTIONAL, 1, 2)); + // A group. + elements.push_back(newGroup("bag", FieldRepetitionType::OPTIONAL, 1, 2)); - // 3-level list encoding, by hand - elt = NewGroup("b", FieldRepetitionType::REPEATED, 1, 3); + // 3-Level list encoding, by hand. + elt = newGroup("b", FieldRepetitionType::REPEATED, 1, 3); elt.__set_converted_type( facebook::velox::parquet::thrift::ConvertedType::LIST); elements.push_back(elt); elements.push_back( - NewPrimitive("item", FieldRepetitionType::OPTIONAL, Type::INT64, 4)); + newPrimitive("item", FieldRepetitionType::OPTIONAL, Type::kInt64, 4)); ASSERT_NO_FATAL_FAILURE( - Convert(&elements[0], static_cast(elements.size()))); + convert(&elements[0], static_cast(elements.size()))); - // Construct the expected schema + // Construct the expected schema. NodeVector fields; - fields.push_back(Int32("a", Repetition::REQUIRED, 1)); - - // 3-level list encoding - NodePtr item = Int64("item", Repetition::OPTIONAL, 4); - NodePtr list(GroupNode::Make( - "b", Repetition::REPEATED, {item}, ConvertedType::LIST, 3)); - NodePtr bag(GroupNode::Make( - "bag", Repetition::OPTIONAL, {list}, /*logical_type=*/nullptr, 2)); + fields.push_back(int32("a", Repetition::kRequired, 1)); + + // 3-Level list encoding. + NodePtr item = int64("item", Repetition::kOptional, 4); + NodePtr List( + GroupNode::make( + "b", Repetition::kRepeated, {item}, ConvertedType::kList, 3)); + NodePtr bag( + GroupNode::make("bag", Repetition::kOptional, {List}, nullptr, 2)); fields.push_back(bag); - NodePtr schema = GroupNode::Make( - name_, - Repetition::REPEATED, - fields, - /*logical_type=*/nullptr, - 0); + NodePtr schema = + GroupNode::make(name_, Repetition::kRepeated, fields, nullptr, 0); - ASSERT_TRUE(schema->Equals(group_)); + ASSERT_TRUE(schema->equals(group_)); - // Check that the parent relationship in each node is consistent + // Check that the parent relationship in each node is consistent. ASSERT_EQ(group_->parent(), nullptr); - ASSERT_TRUE(check_for_parent_consistency(group_)); + ASSERT_TRUE(checkForParentConsistency(group_)); } TEST_F(TestSchemaConverter, ZeroColumns) { - // ARROW-3843 + // ARROW-3843. SchemaElement elements[1]; - elements[0] = NewGroup("schema", FieldRepetitionType::REPEATED, 0, 0); - ASSERT_NO_THROW(Convert(elements, 1)); + elements[0] = newGroup("schema", FieldRepetitionType::REPEATED, 0, 0); + ASSERT_NO_THROW(convert(elements, 1)); } TEST_F(TestSchemaConverter, InvalidRoot) { - // According to the Parquet specification, the first element in the + // According to the Parquet specification, the first element in the. // list is a group whose children (and their descendants) - // contain all of the rest of the flattened schema elements. If the first - // element is not a group, it is a malformed Parquet file. + // Contain all of the rest of the flattened schema elements. If the first. + // Element is not a group, it is a malformed Parquet file. SchemaElement elements[2]; - elements[0] = NewPrimitive( - "not-a-group", FieldRepetitionType::REQUIRED, Type::INT32, 0); - ASSERT_THROW(Convert(elements, 2), ParquetException); - - // While the Parquet spec indicates that the root group should have REPEATED - // repetition type, some implementations may return REQUIRED or OPTIONAL - // groups as the first element. These tests check that this is okay as a - // practicality matter. - elements[0] = NewGroup("not-repeated", FieldRepetitionType::REQUIRED, 1, 0); + elements[0] = newPrimitive( + "not-a-group", FieldRepetitionType::REQUIRED, Type::kInt32, 0); + ASSERT_THROW(convert(elements, 2), ParquetException); + + // While the Parquet spec indicates that the root group should have REPEATED. + // Repetition type, some implementations may return REQUIRED or OPTIONAL. + // Groups as the first element. These tests check that this is okay as a. + // Practicality matter. + elements[0] = newGroup("not-repeated", FieldRepetitionType::REQUIRED, 1, 0); elements[1] = - NewPrimitive("a", FieldRepetitionType::REQUIRED, Type::INT32, 1); - ASSERT_NO_FATAL_FAILURE(Convert(elements, 2)); + newPrimitive("a", FieldRepetitionType::REQUIRED, Type::kInt32, 1); + ASSERT_NO_FATAL_FAILURE(convert(elements, 2)); - elements[0] = NewGroup("not-repeated", FieldRepetitionType::OPTIONAL, 1, 0); - ASSERT_NO_FATAL_FAILURE(Convert(elements, 2)); + elements[0] = newGroup("not-repeated", FieldRepetitionType::OPTIONAL, 1, 0); + ASSERT_NO_FATAL_FAILURE(convert(elements, 2)); } TEST_F(TestSchemaConverter, NotEnoughChildren) { - // Throw a ParquetException, but don't core dump or anything + // Throw a ParquetException, but don't core dump or anything. SchemaElement elt; std::vector elements; - elements.push_back(NewGroup(name_, FieldRepetitionType::REPEATED, 2, 0)); - ASSERT_THROW(Convert(&elements[0], 1), ParquetException); + elements.push_back(newGroup(name_, FieldRepetitionType::REPEATED, 2, 0)); + ASSERT_THROW(convert(&elements[0], 1), ParquetException); } -// ---------------------------------------------------------------------- -// Schema tree flatten / unflatten +// ----------------------------------------------------------------------. +// Schema tree flatten / unflatten. class TestSchemaFlatten : public ::testing::Test { public: - void setUp() { + void SetUp() { name_ = "parquet_schema"; } - void Flatten(const GroupNode* schema) { - ToParquet(schema, &elements_); + void flatten(const GroupNode* schema) { + toParquet(schema, &elements_); } protected: @@ -680,42 +679,42 @@ class TestSchemaFlatten : public ::testing::Test { }; TEST_F(TestSchemaFlatten, DecimalMetadata) { - // Checks that DecimalMetadata is only set for DecimalTypes - NodePtr node = PrimitiveNode::Make( + // Checks that DecimalMetadata is only set for DecimalTypes. + NodePtr Node = PrimitiveNode::make( "decimal", - Repetition::REQUIRED, - Type::INT64, - ConvertedType::DECIMAL, + Repetition::kRequired, + Type::kInt64, + ConvertedType::kDecimal, -1, 8, 4); - NodePtr group = GroupNode::Make( - "group", Repetition::REPEATED, {node}, ConvertedType::LIST); - Flatten(reinterpret_cast(group.get())); + NodePtr group = GroupNode::make( + "group", Repetition::kRepeated, {Node}, ConvertedType::kList); + flatten(reinterpret_cast(group.get())); ASSERT_EQ("decimal", elements_[1].name); ASSERT_TRUE(elements_[1].__isset.precision); ASSERT_TRUE(elements_[1].__isset.scale); elements_.clear(); - // ... including those created with new logical types - node = PrimitiveNode::Make( + // ... Including those created with new logical types. + Node = PrimitiveNode::make( "decimal", - Repetition::REQUIRED, - DecimalLogicalType::Make(10, 5), - Type::INT64, + Repetition::kRequired, + DecimalLogicalType::make(10, 5), + Type::kInt64, -1); - group = GroupNode::Make( - "group", Repetition::REPEATED, {node}, ListLogicalType::Make()); - Flatten(reinterpret_cast(group.get())); + group = GroupNode::make( + "group", Repetition::kRepeated, {Node}, ListLogicalType::make()); + flatten(reinterpret_cast(group.get())); ASSERT_EQ("decimal", elements_[1].name); ASSERT_TRUE(elements_[1].__isset.precision); ASSERT_TRUE(elements_[1].__isset.scale); elements_.clear(); - // Not for integers with no logical type - group = GroupNode::Make( - "group", Repetition::REPEATED, {Int64("int64")}, ConvertedType::LIST); - Flatten(reinterpret_cast(group.get())); + // Not for integers with no logical type. + group = GroupNode::make( + "group", Repetition::kRepeated, {int64("int64")}, ConvertedType::kList); + flatten(reinterpret_cast(group.get())); ASSERT_EQ("int64", elements_[1].name); ASSERT_FALSE(elements_[0].__isset.precision); ASSERT_FALSE(elements_[0].__isset.scale); @@ -724,17 +723,17 @@ TEST_F(TestSchemaFlatten, DecimalMetadata) { TEST_F(TestSchemaFlatten, NestedExample) { SchemaElement elt; std::vector elements; - elements.push_back(NewGroup(name_, FieldRepetitionType::REPEATED, 2, 0)); + elements.push_back(newGroup(name_, FieldRepetitionType::REPEATED, 2, 0)); - // A primitive one + // A primitive one. elements.push_back( - NewPrimitive("a", FieldRepetitionType::REQUIRED, Type::INT32, 1)); + newPrimitive("a", FieldRepetitionType::REQUIRED, Type::kInt32, 1)); - // A group - elements.push_back(NewGroup("bag", FieldRepetitionType::OPTIONAL, 1, 2)); + // A group. + elements.push_back(newGroup("bag", FieldRepetitionType::OPTIONAL, 1, 2)); - // 3-level list encoding, by hand - elt = NewGroup("b", FieldRepetitionType::REPEATED, 1, 3); + // 3-Level list encoding, by hand. + elt = newGroup("b", FieldRepetitionType::REPEATED, 1, 3); elt.__set_converted_type( facebook::velox::parquet::thrift::ConvertedType::LIST); facebook::velox::parquet::thrift::ListType ls; @@ -743,32 +742,25 @@ TEST_F(TestSchemaFlatten, NestedExample) { elt.__set_logicalType(lt); elements.push_back(elt); elements.push_back( - NewPrimitive("item", FieldRepetitionType::OPTIONAL, Type::INT64, 4)); + newPrimitive("item", FieldRepetitionType::OPTIONAL, Type::kInt64, 4)); - // Construct the schema + // Construct the schema. NodeVector fields; - fields.push_back(Int32("a", Repetition::REQUIRED, 1)); - - // 3-level list encoding - NodePtr item = Int64("item", Repetition::OPTIONAL, 4); - NodePtr list(GroupNode::Make( - "b", Repetition::REPEATED, {item}, ConvertedType::LIST, 3)); - NodePtr bag(GroupNode::Make( - "bag", - Repetition::OPTIONAL, - {list}, - /*logical_type=*/nullptr, - 2)); + fields.push_back(int32("a", Repetition::kRequired, 1)); + + // 3-Level list encoding. + NodePtr item = int64("item", Repetition::kOptional, 4); + NodePtr List( + GroupNode::make( + "b", Repetition::kRepeated, {item}, ConvertedType::kList, 3)); + NodePtr bag( + GroupNode::make("bag", Repetition::kOptional, {List}, nullptr, 2)); fields.push_back(bag); - NodePtr schema = GroupNode::Make( - name_, - Repetition::REPEATED, - fields, - /*logical_type=*/nullptr, - 0); + NodePtr schema = + GroupNode::make(name_, Repetition::kRepeated, fields, nullptr, 0); - Flatten(static_cast(schema.get())); + flatten(static_cast(schema.get())); ASSERT_EQ(elements_.size(), elements.size()); for (size_t i = 0; i < elements_.size(); i++) { ASSERT_EQ(elements_[i], elements[i]); @@ -776,18 +768,18 @@ TEST_F(TestSchemaFlatten, NestedExample) { } TEST(TestColumnDescriptor, TestAttrs) { - NodePtr node = PrimitiveNode::Make( - "name", Repetition::OPTIONAL, Type::BYTE_ARRAY, ConvertedType::UTF8); - ColumnDescriptor descr(node, 4, 1); + NodePtr Node = PrimitiveNode::make( + "name", Repetition::kOptional, Type::kByteArray, ConvertedType::kUtf8); + ColumnDescriptor descr(Node, 4, 1); ASSERT_EQ("name", descr.name()); - ASSERT_EQ(4, descr.max_definition_level()); - ASSERT_EQ(1, descr.max_repetition_level()); + ASSERT_EQ(4, descr.maxDefinitionLevel()); + ASSERT_EQ(1, descr.maxRepetitionLevel()); - ASSERT_EQ(Type::BYTE_ARRAY, descr.physical_type()); + ASSERT_EQ(Type::kByteArray, descr.physicalType()); - ASSERT_EQ(-1, descr.type_length()); - const char* expected_descr = R"(column descriptor = { + ASSERT_EQ(-1, descr.typeLength()); + const char* expectedDescr = R"(column descriptor = { name: name, path: , physical_type: BYTE_ARRAY, @@ -796,23 +788,23 @@ TEST(TestColumnDescriptor, TestAttrs) { max_definition_level: 4, max_repetition_level: 1, })"; - ASSERT_EQ(expected_descr, descr.ToString()); + ASSERT_EQ(expectedDescr, descr.toString()); - // Test FIXED_LEN_BYTE_ARRAY - node = PrimitiveNode::Make( + // Test FIXED_LEN_BYTE_ARRAY. + Node = PrimitiveNode::make( "name", - Repetition::OPTIONAL, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::DECIMAL, + Repetition::kOptional, + Type::kFixedLenByteArray, + ConvertedType::kDecimal, 12, 10, 4); - ColumnDescriptor descr2(node, 4, 1); + ColumnDescriptor descr2(Node, 4, 1); - ASSERT_EQ(Type::FIXED_LEN_BYTE_ARRAY, descr2.physical_type()); - ASSERT_EQ(12, descr2.type_length()); + ASSERT_EQ(Type::kFixedLenByteArray, descr2.physicalType()); + ASSERT_EQ(12, descr2.typeLength()); - expected_descr = R"(column descriptor = { + expectedDescr = R"(column descriptor = { name: name, path: , physical_type: FIXED_LEN_BYTE_ARRAY, @@ -824,248 +816,259 @@ TEST(TestColumnDescriptor, TestAttrs) { precision: 10, scale: 4, })"; - ASSERT_EQ(expected_descr, descr2.ToString()); + ASSERT_EQ(expectedDescr, descr2.toString()); } class TestSchemaDescriptor : public ::testing::Test { public: - void setUp() {} + void SetUp() {} protected: SchemaDescriptor descr_; }; TEST_F(TestSchemaDescriptor, InitNonGroup) { - NodePtr node = - PrimitiveNode::Make("field", Repetition::OPTIONAL, Type::INT32); + NodePtr Node = + PrimitiveNode::make("field", Repetition::kOptional, Type::kInt32); - ASSERT_THROW(descr_.Init(node), ParquetException); + ASSERT_THROW(descr_.init(Node), ParquetException); } -TEST_F(TestSchemaDescriptor, Equals) { +TEST_F(TestSchemaDescriptor, equals) { NodePtr schema; - NodePtr inta = Int32("a", Repetition::REQUIRED); - NodePtr intb = Int64("b", Repetition::OPTIONAL); - NodePtr intb2 = Int64("b2", Repetition::OPTIONAL); - NodePtr intc = ByteArray("c", Repetition::REPEATED); + NodePtr inta = int32("a", Repetition::kRequired); + NodePtr intb = int64("b", Repetition::kOptional); + NodePtr intb2 = int64("b2", Repetition::kOptional); + NodePtr intc = byteArray("c", Repetition::kRepeated); - NodePtr item1 = Int64("item1", Repetition::REQUIRED); - NodePtr item2 = Boolean("item2", Repetition::OPTIONAL); - NodePtr item3 = Int32("item3", Repetition::REPEATED); - NodePtr list(GroupNode::Make( - "records", - Repetition::REPEATED, - {item1, item2, item3}, - ConvertedType::LIST)); + NodePtr item1 = int64("item1", Repetition::kRequired); + NodePtr item2 = boolean("item2", Repetition::kOptional); + NodePtr item3 = int32("item3", Repetition::kRepeated); + NodePtr List( + GroupNode::make( + "records", + Repetition::kRepeated, + {item1, item2, item3}, + ConvertedType::kList)); - NodePtr bag(GroupNode::Make("bag", Repetition::OPTIONAL, {list})); - NodePtr bag2(GroupNode::Make("bag", Repetition::REQUIRED, {list})); + NodePtr bag(GroupNode::make("bag", Repetition::kOptional, {List})); + NodePtr bag2(GroupNode::make("bag", Repetition::kRequired, {List})); SchemaDescriptor descr1; - descr1.Init( - GroupNode::Make("schema", Repetition::REPEATED, {inta, intb, intc, bag})); + descr1.init( + GroupNode::make( + "schema", Repetition::kRepeated, {inta, intb, intc, bag})); - ASSERT_TRUE(descr1.Equals(descr1)); + ASSERT_TRUE(descr1.equals(descr1)); SchemaDescriptor descr2; - descr2.Init(GroupNode::Make( - "schema", Repetition::REPEATED, {inta, intb, intc, bag2})); - ASSERT_FALSE(descr1.Equals(descr2)); + descr2.init( + GroupNode::make( + "schema", Repetition::kRepeated, {inta, intb, intc, bag2})); + ASSERT_FALSE(descr1.equals(descr2)); SchemaDescriptor descr3; - descr3.Init(GroupNode::Make( - "schema", Repetition::REPEATED, {inta, intb2, intc, bag})); - ASSERT_FALSE(descr1.Equals(descr3)); + descr3.init( + GroupNode::make( + "schema", Repetition::kRepeated, {inta, intb2, intc, bag})); + ASSERT_FALSE(descr1.equals(descr3)); - // Robust to name of parent node + // Robust to name of parent node. SchemaDescriptor descr4; - descr4.Init( - GroupNode::Make("SCHEMA", Repetition::REPEATED, {inta, intb, intc, bag})); - ASSERT_TRUE(descr1.Equals(descr4)); + descr4.init( + GroupNode::make( + "SCHEMA", Repetition::kRepeated, {inta, intb, intc, bag})); + ASSERT_TRUE(descr1.equals(descr4)); SchemaDescriptor descr5; - descr5.Init(GroupNode::Make( - "schema", Repetition::REPEATED, {inta, intb, intc, bag, intb2})); - ASSERT_FALSE(descr1.Equals(descr5)); + descr5.init( + GroupNode::make( + "schema", Repetition::kRepeated, {inta, intb, intc, bag, intb2})); + ASSERT_FALSE(descr1.equals(descr5)); - // Different max repetition / definition levels + // Different max repetition / definition levels. ColumnDescriptor col1(inta, 5, 1); ColumnDescriptor col2(inta, 6, 1); ColumnDescriptor col3(inta, 5, 2); - ASSERT_TRUE(col1.Equals(col1)); - ASSERT_FALSE(col1.Equals(col2)); - ASSERT_FALSE(col1.Equals(col3)); + ASSERT_TRUE(col1.equals(col1)); + ASSERT_FALSE(col1.equals(col2)); + ASSERT_FALSE(col1.equals(col3)); } -TEST_F(TestSchemaDescriptor, BuildTree) { +TEST_F(TestSchemaDescriptor, buildTree) { NodeVector fields; NodePtr schema; - NodePtr inta = Int32("a", Repetition::REQUIRED); + NodePtr inta = int32("a", Repetition::kRequired); fields.push_back(inta); - fields.push_back(Int64("b", Repetition::OPTIONAL)); - fields.push_back(ByteArray("c", Repetition::REPEATED)); - - // 3-level list encoding - NodePtr item1 = Int64("item1", Repetition::REQUIRED); - NodePtr item2 = Boolean("item2", Repetition::OPTIONAL); - NodePtr item3 = Int32("item3", Repetition::REPEATED); - NodePtr list(GroupNode::Make( - "records", - Repetition::REPEATED, - {item1, item2, item3}, - ConvertedType::LIST)); - NodePtr bag(GroupNode::Make("bag", Repetition::OPTIONAL, {list})); + fields.push_back(int64("b", Repetition::kOptional)); + fields.push_back(byteArray("c", Repetition::kRepeated)); + + // 3-Level list encoding. + NodePtr item1 = int64("item1", Repetition::kRequired); + NodePtr item2 = boolean("item2", Repetition::kOptional); + NodePtr item3 = int32("item3", Repetition::kRepeated); + NodePtr List( + GroupNode::make( + "records", + Repetition::kRepeated, + {item1, item2, item3}, + ConvertedType::kList)); + NodePtr bag(GroupNode::make("bag", Repetition::kOptional, {List})); fields.push_back(bag); - schema = GroupNode::Make("schema", Repetition::REPEATED, fields); + schema = GroupNode::make("schema", Repetition::kRepeated, fields); - descr_.Init(schema); + descr_.init(schema); int nleaves = 6; - // 6 leaves - ASSERT_EQ(nleaves, descr_.num_columns()); - - // mdef mrep - // required int32 a 0 0 - // optional int64 b 1 0 - // repeated byte_array c 1 1 - // optional group bag 1 0 - // repeated group records 2 1 - // required int64 item1 2 1 - // optional boolean item2 3 1 - // repeated int32 item3 3 2 - int16_t ex_max_def_levels[6] = {0, 1, 1, 2, 3, 3}; - int16_t ex_max_rep_levels[6] = {0, 0, 1, 1, 1, 2}; + // 6 Leaves. + ASSERT_EQ(nleaves, descr_.numColumns()); + + // Mdef mrep. + // Required int32 a 0 0. + // Optional int64 b 1 0. + // Repeated byte_array c 1 1. + // Optional group bag 1 0. + // Repeated group records 2 1. + // Required int64 item1 2 1. + // Optional boolean item2 3 1. + // Repeated int32 item3 3 2. + int16_t exMaxDefLevels[6] = {0, 1, 1, 2, 3, 3}; + int16_t exMaxRepLevels[6] = {0, 0, 1, 1, 1, 2}; for (int i = 0; i < nleaves; ++i) { - const ColumnDescriptor* col = descr_.Column(i); - EXPECT_EQ(ex_max_def_levels[i], col->max_definition_level()) << i; - EXPECT_EQ(ex_max_rep_levels[i], col->max_repetition_level()) << i; + const ColumnDescriptor* col = descr_.column(i); + EXPECT_EQ(exMaxDefLevels[i], col->maxDefinitionLevel()) << i; + EXPECT_EQ(exMaxRepLevels[i], col->maxRepetitionLevel()) << i; } - ASSERT_EQ(descr_.Column(0)->path()->ToDotString(), "a"); - ASSERT_EQ(descr_.Column(1)->path()->ToDotString(), "b"); - ASSERT_EQ(descr_.Column(2)->path()->ToDotString(), "c"); - ASSERT_EQ(descr_.Column(3)->path()->ToDotString(), "bag.records.item1"); - ASSERT_EQ(descr_.Column(4)->path()->ToDotString(), "bag.records.item2"); - ASSERT_EQ(descr_.Column(5)->path()->ToDotString(), "bag.records.item3"); + ASSERT_EQ(descr_.column(0)->path()->toDotString(), "a"); + ASSERT_EQ(descr_.column(1)->path()->toDotString(), "b"); + ASSERT_EQ(descr_.column(2)->path()->toDotString(), "c"); + ASSERT_EQ(descr_.column(3)->path()->toDotString(), "bag.records.item1"); + ASSERT_EQ(descr_.column(4)->path()->toDotString(), "bag.records.item2"); + ASSERT_EQ(descr_.column(5)->path()->toDotString(), "bag.records.item3"); for (int i = 0; i < nleaves; ++i) { - auto col = descr_.Column(i); - ASSERT_EQ(i, descr_.ColumnIndex(*col->schema_node())); + auto col = descr_.column(i); + ASSERT_EQ(i, descr_.columnIndex(*col->schemaNode())); } - // Test non-column nodes find - NodePtr non_column_alien = Int32("alien", Repetition::REQUIRED); // other path - NodePtr non_column_familiar = Int32("a", Repetition::REPEATED); // other node - ASSERT_LT(descr_.ColumnIndex(*non_column_alien), 0); - ASSERT_LT(descr_.ColumnIndex(*non_column_familiar), 0); + // Test non-column nodes find. + NodePtr nonColumnAlien = int32("alien", Repetition::kRequired); // other path + NodePtr nonColumnFamiliar = int32("a", Repetition::kRepeated); // other node + ASSERT_LT(descr_.columnIndex(*nonColumnAlien), 0); + ASSERT_LT(descr_.columnIndex(*nonColumnFamiliar), 0); - ASSERT_EQ(inta.get(), descr_.GetColumnRoot(0)); - ASSERT_EQ(bag.get(), descr_.GetColumnRoot(3)); - ASSERT_EQ(bag.get(), descr_.GetColumnRoot(4)); - ASSERT_EQ(bag.get(), descr_.GetColumnRoot(5)); + ASSERT_EQ(inta.get(), descr_.getColumnRoot(0)); + ASSERT_EQ(bag.get(), descr_.getColumnRoot(3)); + ASSERT_EQ(bag.get(), descr_.getColumnRoot(4)); + ASSERT_EQ(bag.get(), descr_.getColumnRoot(5)); - ASSERT_EQ(schema.get(), descr_.group_node()); + ASSERT_EQ(schema.get(), descr_.groupNode()); - // Init clears the leaves - descr_.Init(schema); - ASSERT_EQ(nleaves, descr_.num_columns()); + // Init clears the leaves. + descr_.init(schema); + ASSERT_EQ(nleaves, descr_.numColumns()); } -TEST_F(TestSchemaDescriptor, HasRepeatedFields) { +TEST_F(TestSchemaDescriptor, hasRepeatedFields) { NodeVector fields; NodePtr schema; - NodePtr inta = Int32("a", Repetition::REQUIRED); + NodePtr inta = int32("a", Repetition::kRequired); fields.push_back(inta); - fields.push_back(Int64("b", Repetition::OPTIONAL)); - fields.push_back(ByteArray("c", Repetition::REPEATED)); - - schema = GroupNode::Make("schema", Repetition::REPEATED, fields); - descr_.Init(schema); - ASSERT_EQ(true, descr_.HasRepeatedFields()); - - // 3-level list encoding - NodePtr item1 = Int64("item1", Repetition::REQUIRED); - NodePtr item2 = Boolean("item2", Repetition::OPTIONAL); - NodePtr item3 = Int32("item3", Repetition::REPEATED); - NodePtr list(GroupNode::Make( - "records", - Repetition::REPEATED, - {item1, item2, item3}, - ConvertedType::LIST)); - NodePtr bag(GroupNode::Make("bag", Repetition::OPTIONAL, {list})); + fields.push_back(int64("b", Repetition::kOptional)); + fields.push_back(byteArray("c", Repetition::kRepeated)); + + schema = GroupNode::make("schema", Repetition::kRepeated, fields); + descr_.init(schema); + ASSERT_EQ(true, descr_.hasRepeatedFields()); + + // 3-Level list encoding. + NodePtr item1 = int64("item1", Repetition::kRequired); + NodePtr item2 = boolean("item2", Repetition::kOptional); + NodePtr item3 = int32("item3", Repetition::kRepeated); + NodePtr List( + GroupNode::make( + "records", + Repetition::kRepeated, + {item1, item2, item3}, + ConvertedType::kList)); + NodePtr bag(GroupNode::make("bag", Repetition::kOptional, {List})); fields.push_back(bag); - schema = GroupNode::Make("schema", Repetition::REPEATED, fields); - descr_.Init(schema); - ASSERT_EQ(true, descr_.HasRepeatedFields()); - - // 3-level list encoding - NodePtr item_key = Int64("key", Repetition::REQUIRED); - NodePtr item_value = Boolean("value", Repetition::OPTIONAL); - NodePtr map(GroupNode::Make( - "map", Repetition::REPEATED, {item_key, item_value}, ConvertedType::MAP)); - NodePtr my_map(GroupNode::Make("my_map", Repetition::OPTIONAL, {map})); - fields.push_back(my_map); - - schema = GroupNode::Make("schema", Repetition::REPEATED, fields); - descr_.Init(schema); - ASSERT_EQ(true, descr_.HasRepeatedFields()); - ASSERT_EQ(true, descr_.HasRepeatedFields()); + schema = GroupNode::make("schema", Repetition::kRepeated, fields); + descr_.init(schema); + ASSERT_EQ(true, descr_.hasRepeatedFields()); + + // 3-Level list encoding. + NodePtr itemKey = int64("key", Repetition::kRequired); + NodePtr itemValue = boolean("value", Repetition::kOptional); + NodePtr Map( + GroupNode::make( + "map", + Repetition::kRepeated, + {itemKey, itemValue}, + ConvertedType::kMap)); + NodePtr myMap(GroupNode::make("my_map", Repetition::kOptional, {Map})); + fields.push_back(myMap); + + schema = GroupNode::make("schema", Repetition::kRepeated, fields); + descr_.init(schema); + ASSERT_EQ(true, descr_.hasRepeatedFields()); + ASSERT_EQ(true, descr_.hasRepeatedFields()); } -static std::string Print(const NodePtr& node) { +static std::string print(const NodePtr& Node) { std::stringstream ss; - PrintSchema(node.get(), ss); + printSchema(Node.get(), ss); return ss.str(); } TEST(TestSchemaPrinter, Examples) { - // Test schema 1 + // Test schema 1. NodeVector fields; - fields.push_back(Int32("a", Repetition::REQUIRED, 1)); - - // 3-level list encoding - NodePtr item1 = Int64("item1", Repetition::OPTIONAL, 4); - NodePtr item2 = Boolean("item2", Repetition::REQUIRED, 5); - NodePtr list(GroupNode::Make( - "b", Repetition::REPEATED, {item1, item2}, ConvertedType::LIST, 3)); - NodePtr bag(GroupNode::Make( - "bag", Repetition::OPTIONAL, {list}, /*logical_type=*/nullptr, 2)); + fields.push_back(int32("a", Repetition::kRequired, 1)); + + // 3-Level list encoding. + NodePtr item1 = int64("item1", Repetition::kOptional, 4); + NodePtr item2 = boolean("item2", Repetition::kRequired, 5); + NodePtr List( + GroupNode::make( + "b", Repetition::kRepeated, {item1, item2}, ConvertedType::kList, 3)); + NodePtr bag( + GroupNode::make("bag", Repetition::kOptional, {List}, nullptr, 2)); fields.push_back(bag); - fields.push_back(PrimitiveNode::Make( - "c", - Repetition::REQUIRED, - Type::INT32, - ConvertedType::DECIMAL, - -1, - 3, - 2, - 6)); - - fields.push_back(PrimitiveNode::Make( - "d", - Repetition::REQUIRED, - DecimalLogicalType::Make(10, 5), - Type::INT64, - /*length=*/-1, - 7)); - - NodePtr schema = GroupNode::Make( - "schema", - Repetition::REPEATED, - fields, - /*logical_type=*/nullptr, - 0); + fields.push_back( + PrimitiveNode::make( + "c", + Repetition::kRequired, + Type::kInt32, + ConvertedType::kDecimal, + -1, + 3, + 2, + 6)); + + fields.push_back( + PrimitiveNode::make( + "d", + Repetition::kRequired, + DecimalLogicalType::make(10, 5), + Type::kInt64, + -1, + 7)); - std::string result = Print(schema); + NodePtr schema = + GroupNode::make("schema", Repetition::kRepeated, fields, nullptr, 0); + + std::string result = print(schema); std::string expected = R"(repeated group field_id=0 schema { required int32 field_id=1 a; @@ -1082,349 +1085,330 @@ TEST(TestSchemaPrinter, Examples) { ASSERT_EQ(expected, result); } -static void ConfirmFactoryEquivalence( - ConvertedType::type converted_type, - const std::shared_ptr& from_make, +static void confirmFactoryEquivalence( + ConvertedType::type convertedType, + const std::shared_ptr& fromMake, std::function&)> - check_is_type) { - std::shared_ptr from_converted_type = - LogicalType::FromConvertedType(converted_type); - ASSERT_EQ(from_converted_type->type(), from_make->type()) - << from_make->ToString() + checkIsType) { + std::shared_ptr fromConvertedType = + LogicalType::fromConvertedType(convertedType); + ASSERT_EQ(fromConvertedType->type(), fromMake->type()) + << fromMake->toString() << " logical types unexpectedly do not match on type"; - ASSERT_TRUE(from_converted_type->Equals(*from_make)) - << from_make->ToString() << " logical types unexpectedly not equivalent"; - ASSERT_TRUE(check_is_type(from_converted_type)) - << from_converted_type->ToString() + ASSERT_TRUE(fromConvertedType->equals(*fromMake)) + << fromMake->toString() << " logical types unexpectedly not equivalent"; + ASSERT_TRUE(checkIsType(fromConvertedType)) + << fromConvertedType->toString() << " logical type (from converted type) does not have expected type property"; - ASSERT_TRUE(check_is_type(from_make)) - << from_make->ToString() + ASSERT_TRUE(checkIsType(fromMake)) + << fromMake->toString() << " logical type (from Make()) does not have expected type property"; return; } TEST(TestLogicalTypeConstruction, FactoryEquivalence) { - // For each legacy converted type, ensure that the equivalent logical type + // For each legacy converted type, ensure that the equivalent logical type. // object can be obtained from either the base class's FromConvertedType() - // factory method or the logical type type class's Make() method (accessed via - // convenience methods on the base class) and that these logical type objects - // are equivalent + // Factory method or the logical type type class's Make() method (accessed + // via. Convenience methods on the base class) and that these logical type + // objects. Are equivalent. struct ConfirmFactoryEquivalenceArguments { - ConvertedType::type converted_type; - std::shared_ptr logical_type; - std::function&)> - check_is_type; + ConvertedType::type convertedType; + std::shared_ptr logicalType; + std::function&)> checkIsType; }; - auto check_is_string = - [](const std::shared_ptr& logical_type) { - return logical_type->is_string(); - }; - auto check_is_map = - [](const std::shared_ptr& logical_type) { - return logical_type->is_map(); - }; - auto check_is_list = - [](const std::shared_ptr& logical_type) { - return logical_type->is_list(); - }; - auto check_is_enum = - [](const std::shared_ptr& logical_type) { - return logical_type->is_enum(); - }; - auto check_is_date = - [](const std::shared_ptr& logical_type) { - return logical_type->is_date(); - }; - auto check_is_time = - [](const std::shared_ptr& logical_type) { - return logical_type->is_time(); - }; - auto check_is_timestamp = - [](const std::shared_ptr& logical_type) { - return logical_type->is_timestamp(); - }; - auto check_is_int = - [](const std::shared_ptr& logical_type) { - return logical_type->is_int(); - }; - auto check_is_JSON = - [](const std::shared_ptr& logical_type) { - return logical_type->is_JSON(); - }; - auto check_is_BSON = - [](const std::shared_ptr& logical_type) { - return logical_type->is_BSON(); + auto checkIsString = + [](const std::shared_ptr& logicalType) { + return logicalType->isString(); }; - auto check_is_interval = - [](const std::shared_ptr& logical_type) { - return logical_type->is_interval(); + auto checkIsMap = [](const std::shared_ptr& logicalType) { + return logicalType->isMap(); + }; + auto checkIsList = [](const std::shared_ptr& logicalType) { + return logicalType->isList(); + }; + auto checkIsenum = [](const std::shared_ptr& logicalType) { + return logicalType->isEnum(); + }; + auto checkIsDate = [](const std::shared_ptr& logicalType) { + return logicalType->isDate(); + }; + auto checkIsTime = [](const std::shared_ptr& logicalType) { + return logicalType->isTime(); + }; + auto checkIsTimestamp = + [](const std::shared_ptr& logicalType) { + return logicalType->isTimestamp(); }; - auto check_is_none = - [](const std::shared_ptr& logical_type) { - return logical_type->is_none(); + auto checkIsInt = [](const std::shared_ptr& logicalType) { + return logicalType->isInt(); + }; + auto checkIsJson = [](const std::shared_ptr& logicalType) { + return logicalType->isJson(); + }; + auto checkIsBson = [](const std::shared_ptr& logicalType) { + return logicalType->isBson(); + }; + auto checkIsInterval = + [](const std::shared_ptr& logicalType) { + return logicalType->isInterval(); }; + auto checkIsNone = [](const std::shared_ptr& logicalType) { + return logicalType->isNone(); + }; std::vector cases = { - {ConvertedType::UTF8, LogicalType::String(), check_is_string}, - {ConvertedType::MAP, LogicalType::Map(), check_is_map}, - {ConvertedType::MAP_KEY_VALUE, LogicalType::Map(), check_is_map}, - {ConvertedType::LIST, LogicalType::List(), check_is_list}, - {ConvertedType::ENUM, LogicalType::Enum(), check_is_enum}, - {ConvertedType::DATE, LogicalType::Date(), check_is_date}, - {ConvertedType::TIME_MILLIS, - LogicalType::Time(true, LogicalType::TimeUnit::MILLIS), - check_is_time}, - {ConvertedType::TIME_MICROS, - LogicalType::Time(true, LogicalType::TimeUnit::MICROS), - check_is_time}, - {ConvertedType::TIMESTAMP_MILLIS, - LogicalType::Timestamp(true, LogicalType::TimeUnit::MILLIS), - check_is_timestamp}, - {ConvertedType::TIMESTAMP_MICROS, - LogicalType::Timestamp(true, LogicalType::TimeUnit::MICROS), - check_is_timestamp}, - {ConvertedType::UINT_8, LogicalType::Int(8, false), check_is_int}, - {ConvertedType::UINT_16, LogicalType::Int(16, false), check_is_int}, - {ConvertedType::UINT_32, LogicalType::Int(32, false), check_is_int}, - {ConvertedType::UINT_64, LogicalType::Int(64, false), check_is_int}, - {ConvertedType::INT_8, LogicalType::Int(8, true), check_is_int}, - {ConvertedType::INT_16, LogicalType::Int(16, true), check_is_int}, - {ConvertedType::INT_32, LogicalType::Int(32, true), check_is_int}, - {ConvertedType::INT_64, LogicalType::Int(64, true), check_is_int}, - {ConvertedType::JSON, LogicalType::JSON(), check_is_JSON}, - {ConvertedType::BSON, LogicalType::BSON(), check_is_BSON}, - {ConvertedType::INTERVAL, LogicalType::Interval(), check_is_interval}, - {ConvertedType::NONE, LogicalType::None(), check_is_none}}; + {ConvertedType::kUtf8, LogicalType::string(), checkIsString}, + {ConvertedType::kMap, LogicalType::map(), checkIsMap}, + {ConvertedType::kMapKeyValue, LogicalType::map(), checkIsMap}, + {ConvertedType::kList, LogicalType::list(), checkIsList}, + {ConvertedType::kEnum, LogicalType::enumType(), checkIsenum}, + {ConvertedType::kDate, LogicalType::date(), checkIsDate}, + {ConvertedType::kTimeMillis, + LogicalType::time(true, LogicalType::TimeUnit::kMillis), + checkIsTime}, + {ConvertedType::kTimeMicros, + LogicalType::time(true, LogicalType::TimeUnit::kMicros), + checkIsTime}, + {ConvertedType::kTimestampMillis, + LogicalType::timestamp(true, LogicalType::TimeUnit::kMillis), + checkIsTimestamp}, + {ConvertedType::kTimestampMicros, + LogicalType::timestamp(true, LogicalType::TimeUnit::kMicros), + checkIsTimestamp}, + {ConvertedType::kUint8, LogicalType::intType(8, false), checkIsInt}, + {ConvertedType::kUint16, LogicalType::intType(16, false), checkIsInt}, + {ConvertedType::kUint32, LogicalType::intType(32, false), checkIsInt}, + {ConvertedType::kUint64, LogicalType::intType(64, false), checkIsInt}, + {ConvertedType::kInt8, LogicalType::intType(8, true), checkIsInt}, + {ConvertedType::kInt16, LogicalType::intType(16, true), checkIsInt}, + {ConvertedType::kInt32, LogicalType::intType(32, true), checkIsInt}, + {ConvertedType::kInt64, LogicalType::intType(64, true), checkIsInt}, + {ConvertedType::kJson, LogicalType::json(), checkIsJson}, + {ConvertedType::kBson, LogicalType::bson(), checkIsBson}, + {ConvertedType::kInterval, LogicalType::interval(), checkIsInterval}, + {ConvertedType::kNone, LogicalType::none(), checkIsNone}}; for (const ConfirmFactoryEquivalenceArguments& c : cases) { - ConfirmFactoryEquivalence( - c.converted_type, c.logical_type, c.check_is_type); + confirmFactoryEquivalence(c.convertedType, c.logicalType, c.checkIsType); } - // ConvertedType::DECIMAL, LogicalType::Decimal, is_decimal - schema::DecimalMetadata converted_decimal_metadata; - converted_decimal_metadata.isset = true; - converted_decimal_metadata.precision = 10; - converted_decimal_metadata.scale = 4; - std::shared_ptr from_converted_type = - LogicalType::FromConvertedType( - ConvertedType::DECIMAL, converted_decimal_metadata); - std::shared_ptr from_make = LogicalType::Decimal(10, 4); - ASSERT_EQ(from_converted_type->type(), from_make->type()); - ASSERT_TRUE(from_converted_type->Equals(*from_make)); - ASSERT_TRUE(from_converted_type->is_decimal()); - ASSERT_TRUE(from_make->is_decimal()); - ASSERT_TRUE(LogicalType::Decimal(16)->Equals(*LogicalType::Decimal(16, 0))); + // ConvertedType::kDecimal, LogicalType::decimal, is_decimal. + schema::DecimalMetadata convertedDecimalMetadata; + convertedDecimalMetadata.isset = true; + convertedDecimalMetadata.precision = 10; + convertedDecimalMetadata.scale = 4; + std::shared_ptr fromConvertedType = + LogicalType::fromConvertedType( + ConvertedType::kDecimal, convertedDecimalMetadata); + std::shared_ptr fromMake = LogicalType::decimal(10, 4); + ASSERT_EQ(fromConvertedType->type(), fromMake->type()); + ASSERT_TRUE(fromConvertedType->equals(*fromMake)); + ASSERT_TRUE(fromConvertedType->isDecimal()); + ASSERT_TRUE(fromMake->isDecimal()); + ASSERT_TRUE(LogicalType::decimal(16)->equals(*LogicalType::decimal(16, 0))); } -static void ConfirmConvertedTypeCompatibility( +static void confirmConvertedTypeCompatibility( const std::shared_ptr& original, - ConvertedType::type expected_converted_type) { - ASSERT_TRUE(original->is_valid()) - << original->ToString() << " logical type unexpectedly is not valid"; - schema::DecimalMetadata converted_decimal_metadata; - ConvertedType::type converted_type = - original->ToConvertedType(&converted_decimal_metadata); - ASSERT_EQ(converted_type, expected_converted_type) - << original->ToString() + ConvertedType::type expectedConvertedType) { + ASSERT_TRUE(original->isValid()) + << original->toString() << " logical type unexpectedly is not valid"; + schema::DecimalMetadata convertedDecimalMetadata; + ConvertedType::type convertedType = + original->toConvertedType(&convertedDecimalMetadata); + ASSERT_EQ(convertedType, expectedConvertedType) + << original->toString() << " logical type unexpectedly returns incorrect converted type"; - ASSERT_FALSE(converted_decimal_metadata.isset) - << original->ToString() + ASSERT_FALSE(convertedDecimalMetadata.isset) + << original->toString() << " logical type unexpectedly returns converted decimal metadata that is set"; - ASSERT_TRUE( - original->is_compatible(converted_type, converted_decimal_metadata)) - << original->ToString() + ASSERT_TRUE(original->isCompatible(convertedType, convertedDecimalMetadata)) + << original->toString() << " logical type unexpectedly is incompatible with converted type and decimal " "metadata it returned"; - ASSERT_FALSE(original->is_compatible(converted_type, {true, 1, 1})) - << original->ToString() + ASSERT_FALSE(original->isCompatible(convertedType, {true, 1, 1})) + << original->toString() << " logical type unexpectedly is compatible with converted decimal metadata that " "is " "set"; - ASSERT_TRUE(original->is_compatible(converted_type)) - << original->ToString() + ASSERT_TRUE(original->isCompatible(convertedType)) + << original->toString() << " logical type unexpectedly is incompatible with converted type it returned"; std::shared_ptr reconstructed = - LogicalType::FromConvertedType( - converted_type, converted_decimal_metadata); - ASSERT_TRUE(reconstructed->is_valid()) - << "Reconstructed " << reconstructed->ToString() + LogicalType::fromConvertedType(convertedType, convertedDecimalMetadata); + ASSERT_TRUE(reconstructed->isValid()) + << "Reconstructed " << reconstructed->toString() << " logical type unexpectedly is not valid"; - ASSERT_TRUE(reconstructed->Equals(*original)) - << "Reconstructed logical type (" << reconstructed->ToString() + ASSERT_TRUE(reconstructed->equals(*original)) + << "Reconstructed logical type (" << reconstructed->toString() << ") unexpectedly not equivalent to original logical type (" - << original->ToString() << ")"; + << original->toString() << ")"; return; } TEST(TestLogicalTypeConstruction, ConvertedTypeCompatibility) { - // For each legacy converted type, ensure that the equivalent logical type - // emits correct, compatible converted type information and that the emitted - // information can be used to reconstruct another equivalent logical type. + // For each legacy converted type, ensure that the equivalent logical type. + // Emits correct, compatible converted type information and that the emitted. + // Information can be used to reconstruct another equivalent logical type. struct ExpectedConvertedType { - std::shared_ptr logical_type; - ConvertedType::type converted_type; + std::shared_ptr logicalType; + ConvertedType::type convertedType; }; std::vector cases = { - {LogicalType::String(), ConvertedType::UTF8}, - {LogicalType::Map(), ConvertedType::MAP}, - {LogicalType::List(), ConvertedType::LIST}, - {LogicalType::Enum(), ConvertedType::ENUM}, - {LogicalType::Date(), ConvertedType::DATE}, - {LogicalType::Time(true, LogicalType::TimeUnit::MILLIS), - ConvertedType::TIME_MILLIS}, - {LogicalType::Time(true, LogicalType::TimeUnit::MICROS), - ConvertedType::TIME_MICROS}, - {LogicalType::Timestamp(true, LogicalType::TimeUnit::MILLIS), - ConvertedType::TIMESTAMP_MILLIS}, - {LogicalType::Timestamp(true, LogicalType::TimeUnit::MICROS), - ConvertedType::TIMESTAMP_MICROS}, - {LogicalType::Int(8, false), ConvertedType::UINT_8}, - {LogicalType::Int(16, false), ConvertedType::UINT_16}, - {LogicalType::Int(32, false), ConvertedType::UINT_32}, - {LogicalType::Int(64, false), ConvertedType::UINT_64}, - {LogicalType::Int(8, true), ConvertedType::INT_8}, - {LogicalType::Int(16, true), ConvertedType::INT_16}, - {LogicalType::Int(32, true), ConvertedType::INT_32}, - {LogicalType::Int(64, true), ConvertedType::INT_64}, - {LogicalType::JSON(), ConvertedType::JSON}, - {LogicalType::BSON(), ConvertedType::BSON}, - {LogicalType::Interval(), ConvertedType::INTERVAL}, - {LogicalType::None(), ConvertedType::NONE}}; + {LogicalType::string(), ConvertedType::kUtf8}, + {LogicalType::map(), ConvertedType::kMap}, + {LogicalType::list(), ConvertedType::kList}, + {LogicalType::enumType(), ConvertedType::kEnum}, + {LogicalType::date(), ConvertedType::kDate}, + {LogicalType::time(true, LogicalType::TimeUnit::kMillis), + ConvertedType::kTimeMillis}, + {LogicalType::time(true, LogicalType::TimeUnit::kMicros), + ConvertedType::kTimeMicros}, + {LogicalType::timestamp(true, LogicalType::TimeUnit::kMillis), + ConvertedType::kTimestampMillis}, + {LogicalType::timestamp(true, LogicalType::TimeUnit::kMicros), + ConvertedType::kTimestampMicros}, + {LogicalType::intType(8, false), ConvertedType::kUint8}, + {LogicalType::intType(16, false), ConvertedType::kUint16}, + {LogicalType::intType(32, false), ConvertedType::kUint32}, + {LogicalType::intType(64, false), ConvertedType::kUint64}, + {LogicalType::intType(8, true), ConvertedType::kInt8}, + {LogicalType::intType(16, true), ConvertedType::kInt16}, + {LogicalType::intType(32, true), ConvertedType::kInt32}, + {LogicalType::intType(64, true), ConvertedType::kInt64}, + {LogicalType::json(), ConvertedType::kJson}, + {LogicalType::bson(), ConvertedType::kBson}, + {LogicalType::interval(), ConvertedType::kInterval}, + {LogicalType::none(), ConvertedType::kNone}}; for (const ExpectedConvertedType& c : cases) { - ConfirmConvertedTypeCompatibility(c.logical_type, c.converted_type); + confirmConvertedTypeCompatibility(c.logicalType, c.convertedType); } // Special cases ... std::shared_ptr original; - ConvertedType::type converted_type; - schema::DecimalMetadata converted_decimal_metadata; + ConvertedType::type convertedType; + schema::DecimalMetadata convertedDecimalMetadata; std::shared_ptr reconstructed; - // DECIMAL + // DECIMAL. std::memset( - &converted_decimal_metadata, 0x00, sizeof(converted_decimal_metadata)); - original = LogicalType::Decimal(6, 2); - ASSERT_TRUE(original->is_valid()); - converted_type = original->ToConvertedType(&converted_decimal_metadata); - ASSERT_EQ(converted_type, ConvertedType::DECIMAL); - ASSERT_TRUE(converted_decimal_metadata.isset); - ASSERT_EQ(converted_decimal_metadata.precision, 6); - ASSERT_EQ(converted_decimal_metadata.scale, 2); - ASSERT_TRUE( - original->is_compatible(converted_type, converted_decimal_metadata)); - reconstructed = LogicalType::FromConvertedType( - converted_type, converted_decimal_metadata); - ASSERT_TRUE(reconstructed->is_valid()); - ASSERT_TRUE(reconstructed->Equals(*original)); - - // Undefined - original = UndefinedLogicalType::Make(); - ASSERT_TRUE(original->is_invalid()); - ASSERT_FALSE(original->is_valid()); - converted_type = original->ToConvertedType(&converted_decimal_metadata); - ASSERT_EQ(converted_type, ConvertedType::UNDEFINED); - ASSERT_FALSE(converted_decimal_metadata.isset); - ASSERT_TRUE( - original->is_compatible(converted_type, converted_decimal_metadata)); - ASSERT_TRUE(original->is_compatible(converted_type)); - reconstructed = LogicalType::FromConvertedType( - converted_type, converted_decimal_metadata); - ASSERT_TRUE(reconstructed->is_invalid()); - ASSERT_TRUE(reconstructed->Equals(*original)); + &convertedDecimalMetadata, 0x00, sizeof(convertedDecimalMetadata)); + original = LogicalType::decimal(6, 2); + ASSERT_TRUE(original->isValid()); + convertedType = original->toConvertedType(&convertedDecimalMetadata); + ASSERT_EQ(convertedType, ConvertedType::kDecimal); + ASSERT_TRUE(convertedDecimalMetadata.isset); + ASSERT_EQ(convertedDecimalMetadata.precision, 6); + ASSERT_EQ(convertedDecimalMetadata.scale, 2); + ASSERT_TRUE(original->isCompatible(convertedType, convertedDecimalMetadata)); + reconstructed = + LogicalType::fromConvertedType(convertedType, convertedDecimalMetadata); + ASSERT_TRUE(reconstructed->isValid()); + ASSERT_TRUE(reconstructed->equals(*original)); + + // Undefined. + original = UndefinedLogicalType::make(); + ASSERT_TRUE(original->isInvalid()); + ASSERT_FALSE(original->isValid()); + convertedType = original->toConvertedType(&convertedDecimalMetadata); + ASSERT_EQ(convertedType, ConvertedType::kUndefined); + ASSERT_FALSE(convertedDecimalMetadata.isset); + ASSERT_TRUE(original->isCompatible(convertedType, convertedDecimalMetadata)); + ASSERT_TRUE(original->isCompatible(convertedType)); + reconstructed = + LogicalType::fromConvertedType(convertedType, convertedDecimalMetadata); + ASSERT_TRUE(reconstructed->isInvalid()); + ASSERT_TRUE(reconstructed->equals(*original)); } -static void ConfirmNewTypeIncompatibility( - const std::shared_ptr& logical_type, +static void confirmNewTypeIncompatibility( + const std::shared_ptr& logicalType, std::function&)> - check_is_type) { - ASSERT_TRUE(logical_type->is_valid()) - << logical_type->ToString() << " logical type unexpectedly is not valid"; - ASSERT_TRUE(check_is_type(logical_type)) - << logical_type->ToString() + checkIsType) { + ASSERT_TRUE(logicalType->isValid()) + << logicalType->toString() << " logical type unexpectedly is not valid"; + ASSERT_TRUE(checkIsType(logicalType)) + << logicalType->toString() << " logical type is not expected logical type"; - schema::DecimalMetadata converted_decimal_metadata; - ConvertedType::type converted_type = - logical_type->ToConvertedType(&converted_decimal_metadata); - ASSERT_EQ(converted_type, ConvertedType::NONE) - << logical_type->ToString() + schema::DecimalMetadata convertedDecimalMetadata; + ConvertedType::type convertedType = + logicalType->toConvertedType(&convertedDecimalMetadata); + ASSERT_EQ(convertedType, ConvertedType::kNone) + << logicalType->toString() << " logical type converted type unexpectedly is not NONE"; - ASSERT_FALSE(converted_decimal_metadata.isset) - << logical_type->ToString() + ASSERT_FALSE(convertedDecimalMetadata.isset) + << logicalType->toString() << " logical type converted decimal metadata unexpectedly is set"; return; } TEST(TestLogicalTypeConstruction, NewTypeIncompatibility) { - // For each new logical type, ensure that the type - // correctly reports that it has no legacy equivalent + // For each new logical type, ensure that the type. + // Correctly reports that it has no legacy equivalent. struct ConfirmNewTypeIncompatibilityArguments { - std::shared_ptr logical_type; - std::function&)> - check_is_type; + std::shared_ptr logicalType; + std::function&)> checkIsType; }; - auto check_is_UUID = - [](const std::shared_ptr& logical_type) { - return logical_type->is_UUID(); - }; - auto check_is_null = - [](const std::shared_ptr& logical_type) { - return logical_type->is_null(); - }; - auto check_is_time = - [](const std::shared_ptr& logical_type) { - return logical_type->is_time(); - }; - auto check_is_timestamp = - [](const std::shared_ptr& logical_type) { - return logical_type->is_timestamp(); + auto checkIsUuid = [](const std::shared_ptr& logicalType) { + return logicalType->isUuid(); + }; + auto checkIsNull = [](const std::shared_ptr& logicalType) { + return logicalType->isNull(); + }; + auto checkIsTime = [](const std::shared_ptr& logicalType) { + return logicalType->isTime(); + }; + auto checkIsTimestamp = + [](const std::shared_ptr& logicalType) { + return logicalType->isTimestamp(); }; std::vector cases = { - {LogicalType::UUID(), check_is_UUID}, - {LogicalType::Null(), check_is_null}, - {LogicalType::Time(false, LogicalType::TimeUnit::MILLIS), check_is_time}, - {LogicalType::Time(false, LogicalType::TimeUnit::MICROS), check_is_time}, - {LogicalType::Time(false, LogicalType::TimeUnit::NANOS), check_is_time}, - {LogicalType::Time(true, LogicalType::TimeUnit::NANOS), check_is_time}, - {LogicalType::Timestamp(false, LogicalType::TimeUnit::NANOS), - check_is_timestamp}, - {LogicalType::Timestamp(true, LogicalType::TimeUnit::NANOS), - check_is_timestamp}, + {LogicalType::uuid(), checkIsUuid}, + {LogicalType::nullType(), checkIsNull}, + {LogicalType::time(false, LogicalType::TimeUnit::kMillis), checkIsTime}, + {LogicalType::time(false, LogicalType::TimeUnit::kMicros), checkIsTime}, + {LogicalType::time(false, LogicalType::TimeUnit::kNanos), checkIsTime}, + {LogicalType::time(true, LogicalType::TimeUnit::kNanos), checkIsTime}, + {LogicalType::timestamp(false, LogicalType::TimeUnit::kNanos), + checkIsTimestamp}, + {LogicalType::timestamp(true, LogicalType::TimeUnit::kNanos), + checkIsTimestamp}, }; for (const ConfirmNewTypeIncompatibilityArguments& c : cases) { - ConfirmNewTypeIncompatibility(c.logical_type, c.check_is_type); + confirmNewTypeIncompatibility(c.logicalType, c.checkIsType); } } TEST(TestLogicalTypeConstruction, FactoryExceptions) { - // Ensure that logical type construction catches invalid arguments + // Ensure that logical type construction catches invalid arguments. std::vector> cases = { []() { - TimeLogicalType::Make(true, LogicalType::TimeUnit::UNKNOWN); + TimeLogicalType::make(true, LogicalType::TimeUnit::kUnknown); }, // Invalid TimeUnit []() { - TimestampLogicalType::Make(true, LogicalType::TimeUnit::UNKNOWN); + TimestampLogicalType::make(true, LogicalType::TimeUnit::kUnknown); }, // Invalid TimeUnit - []() { IntLogicalType::Make(-1, false); }, // Invalid bit width - []() { IntLogicalType::Make(0, false); }, // Invalid bit width - []() { IntLogicalType::Make(1, false); }, // Invalid bit width - []() { IntLogicalType::Make(65, false); }, // Invalid bit width - []() { DecimalLogicalType::Make(-1); }, // Invalid precision - []() { DecimalLogicalType::Make(0); }, // Invalid precision - []() { DecimalLogicalType::Make(0, 0); }, // Invalid precision - []() { DecimalLogicalType::Make(10, -1); }, // Invalid scale - []() { DecimalLogicalType::Make(10, 11); } // Invalid scale + []() { IntLogicalType::make(-1, false); }, // Invalid bit width + []() { IntLogicalType::make(0, false); }, // Invalid bit width + []() { IntLogicalType::make(1, false); }, // Invalid bit width + []() { IntLogicalType::make(65, false); }, // Invalid bit width + []() { DecimalLogicalType::make(-1); }, // Invalid precision + []() { DecimalLogicalType::make(0); }, // Invalid precision + []() { DecimalLogicalType::make(0, 0); }, // Invalid precision + []() { DecimalLogicalType::make(10, -1); }, // Invalid scale + []() { DecimalLogicalType::make(10, 11); } // Invalid scale }; for (auto f : cases) { @@ -1432,247 +1416,248 @@ TEST(TestLogicalTypeConstruction, FactoryExceptions) { } } -static void ConfirmLogicalTypeProperties( - const std::shared_ptr& logical_type, +static void confirmLogicalTypeProperties( + const std::shared_ptr& logicalType, bool nested, bool serialized, bool valid) { - ASSERT_TRUE(logical_type->is_nested() == nested) - << logical_type->ToString() + ASSERT_TRUE(logicalType->isNested() == nested) + << logicalType->toString() << " logical type has incorrect nested() property"; - ASSERT_TRUE(logical_type->is_serialized() == serialized) - << logical_type->ToString() + ASSERT_TRUE(logicalType->isSerialized() == serialized) + << logicalType->toString() << " logical type has incorrect serialized() property"; - ASSERT_TRUE(logical_type->is_valid() == valid) - << logical_type->ToString() + ASSERT_TRUE(logicalType->isValid() == valid) + << logicalType->toString() << " logical type has incorrect valid() property"; - ASSERT_TRUE(logical_type->is_nonnested() != nested) - << logical_type->ToString() + ASSERT_TRUE(logicalType->isNonnested() != nested) + << logicalType->toString() << " logical type has incorrect nonnested() property"; - ASSERT_TRUE(logical_type->is_invalid() != valid) - << logical_type->ToString() + ASSERT_TRUE(logicalType->isInvalid() != valid) + << logicalType->toString() << " logical type has incorrect invalid() property"; return; } TEST(TestLogicalTypeOperation, LogicalTypeProperties) { - // For each logical type, ensure that the correct general properties are - // reported + // For each logical type, ensure that the correct general properties are. + // Reported. struct ExpectedProperties { - std::shared_ptr logical_type; + std::shared_ptr logicalType; bool nested; bool serialized; bool valid; }; std::vector cases = { - {StringLogicalType::Make(), false, true, true}, - {MapLogicalType::Make(), true, true, true}, - {ListLogicalType::Make(), true, true, true}, - {EnumLogicalType::Make(), false, true, true}, - {DecimalLogicalType::Make(16, 6), false, true, true}, - {DateLogicalType::Make(), false, true, true}, - {TimeLogicalType::Make(true, LogicalType::TimeUnit::MICROS), + {StringLogicalType::make(), false, true, true}, + {MapLogicalType::make(), true, true, true}, + {ListLogicalType::make(), true, true, true}, + {EnumLogicalType::make(), false, true, true}, + {DecimalLogicalType::make(16, 6), false, true, true}, + {DateLogicalType::make(), false, true, true}, + {TimeLogicalType::make(true, LogicalType::TimeUnit::kMicros), false, true, true}, - {TimestampLogicalType::Make(true, LogicalType::TimeUnit::MICROS), + {TimestampLogicalType::make(true, LogicalType::TimeUnit::kMicros), false, true, true}, - {IntervalLogicalType::Make(), false, true, true}, - {IntLogicalType::Make(8, false), false, true, true}, - {IntLogicalType::Make(64, true), false, true, true}, - {NullLogicalType::Make(), false, true, true}, - {JSONLogicalType::Make(), false, true, true}, - {BSONLogicalType::Make(), false, true, true}, - {UUIDLogicalType::Make(), false, true, true}, - {NoLogicalType::Make(), false, false, true}, + {IntervalLogicalType::make(), false, true, true}, + {IntLogicalType::make(8, false), false, true, true}, + {IntLogicalType::make(64, true), false, true, true}, + {NullLogicalType::make(), false, true, true}, + {JsonLogicalType::make(), false, true, true}, + {BsonLogicalType::make(), false, true, true}, + {UuidLogicalType::make(), false, true, true}, + {NoLogicalType::make(), false, false, true}, }; for (const ExpectedProperties& c : cases) { - ConfirmLogicalTypeProperties( - c.logical_type, c.nested, c.serialized, c.valid); + confirmLogicalTypeProperties( + c.logicalType, c.nested, c.serialized, c.valid); } } static constexpr int PHYSICAL_TYPE_COUNT = 8; -static Type::type physical_type[PHYSICAL_TYPE_COUNT] = { - Type::BOOLEAN, - Type::INT32, - Type::INT64, - Type::INT96, - Type::FLOAT, - Type::DOUBLE, - Type::BYTE_ARRAY, - Type::FIXED_LEN_BYTE_ARRAY}; - -static void ConfirmSinglePrimitiveTypeApplicability( - const std::shared_ptr& logical_type, - Type::type applicable_type) { +static Type::type physicalType[PHYSICAL_TYPE_COUNT] = { + Type::kBoolean, + Type::kInt32, + Type::kInt64, + Type::kInt96, + Type::kFloat, + Type::kDouble, + Type::kByteArray, + Type::kFixedLenByteArray}; + +static void confirmSinglePrimitiveTypeApplicability( + const std::shared_ptr& logicalType, + Type::type applicableType) { for (int i = 0; i < PHYSICAL_TYPE_COUNT; ++i) { - if (physical_type[i] == applicable_type) { - ASSERT_TRUE(logical_type->is_applicable(physical_type[i])) - << logical_type->ToString() + if (physicalType[i] == applicableType) { + ASSERT_TRUE(logicalType->isApplicable(physicalType[i])) + << logicalType->toString() << " logical type unexpectedly inapplicable to physical type " - << TypeToString(physical_type[i]); + << typeToString(physicalType[i]); } else { - ASSERT_FALSE(logical_type->is_applicable(physical_type[i])) - << logical_type->ToString() + ASSERT_FALSE(logicalType->isApplicable(physicalType[i])) + << logicalType->toString() << " logical type unexpectedly applicable to physical type " - << TypeToString(physical_type[i]); + << typeToString(physicalType[i]); } } return; } -static void ConfirmAnyPrimitiveTypeApplicability( - const std::shared_ptr& logical_type) { +static void confirmAnyPrimitiveTypeApplicability( + const std::shared_ptr& logicalType) { for (int i = 0; i < PHYSICAL_TYPE_COUNT; ++i) { - ASSERT_TRUE(logical_type->is_applicable(physical_type[i])) - << logical_type->ToString() + ASSERT_TRUE(logicalType->isApplicable(physicalType[i])) + << logicalType->toString() << " logical type unexpectedly inapplicable to physical type " - << TypeToString(physical_type[i]); + << typeToString(physicalType[i]); } return; } -static void ConfirmNoPrimitiveTypeApplicability( - const std::shared_ptr& logical_type) { +static void confirmNoPrimitiveTypeApplicability( + const std::shared_ptr& logicalType) { for (int i = 0; i < PHYSICAL_TYPE_COUNT; ++i) { - ASSERT_FALSE(logical_type->is_applicable(physical_type[i])) - << logical_type->ToString() + ASSERT_FALSE(logicalType->isApplicable(physicalType[i])) + << logicalType->toString() << " logical type unexpectedly applicable to physical type " - << TypeToString(physical_type[i]); + << typeToString(physicalType[i]); } return; } TEST(TestLogicalTypeOperation, LogicalTypeApplicability) { - // Check that each logical type correctly reports which - // underlying primitive type(s) it can be applied to + // Check that each logical type correctly reports which. + // Underlying primitive type(s) it can be applied to. struct ExpectedApplicability { - std::shared_ptr logical_type; - Type::type applicable_type; + std::shared_ptr logicalType; + Type::type applicableType; }; - std::vector single_type_cases = { - {LogicalType::String(), Type::BYTE_ARRAY}, - {LogicalType::Enum(), Type::BYTE_ARRAY}, - {LogicalType::Date(), Type::INT32}, - {LogicalType::Time(true, LogicalType::TimeUnit::MILLIS), Type::INT32}, - {LogicalType::Time(true, LogicalType::TimeUnit::MICROS), Type::INT64}, - {LogicalType::Time(true, LogicalType::TimeUnit::NANOS), Type::INT64}, - {LogicalType::Timestamp(true, LogicalType::TimeUnit::MILLIS), - Type::INT64}, - {LogicalType::Timestamp(true, LogicalType::TimeUnit::MICROS), - Type::INT64}, - {LogicalType::Timestamp(true, LogicalType::TimeUnit::NANOS), Type::INT64}, - {LogicalType::Int(8, false), Type::INT32}, - {LogicalType::Int(16, false), Type::INT32}, - {LogicalType::Int(32, false), Type::INT32}, - {LogicalType::Int(64, false), Type::INT64}, - {LogicalType::Int(8, true), Type::INT32}, - {LogicalType::Int(16, true), Type::INT32}, - {LogicalType::Int(32, true), Type::INT32}, - {LogicalType::Int(64, true), Type::INT64}, - {LogicalType::JSON(), Type::BYTE_ARRAY}, - {LogicalType::BSON(), Type::BYTE_ARRAY}}; - - for (const ExpectedApplicability& c : single_type_cases) { - ConfirmSinglePrimitiveTypeApplicability(c.logical_type, c.applicable_type); + std::vector singleTypeCases = { + {LogicalType::string(), Type::kByteArray}, + {LogicalType::enumType(), Type::kByteArray}, + {LogicalType::date(), Type::kInt32}, + {LogicalType::time(true, LogicalType::TimeUnit::kMillis), Type::kInt32}, + {LogicalType::time(true, LogicalType::TimeUnit::kMicros), Type::kInt64}, + {LogicalType::time(true, LogicalType::TimeUnit::kNanos), Type::kInt64}, + {LogicalType::timestamp(true, LogicalType::TimeUnit::kMillis), + Type::kInt64}, + {LogicalType::timestamp(true, LogicalType::TimeUnit::kMicros), + Type::kInt64}, + {LogicalType::timestamp(true, LogicalType::TimeUnit::kNanos), + Type::kInt64}, + {LogicalType::intType(8, false), Type::kInt32}, + {LogicalType::intType(16, false), Type::kInt32}, + {LogicalType::intType(32, false), Type::kInt32}, + {LogicalType::intType(64, false), Type::kInt64}, + {LogicalType::intType(8, true), Type::kInt32}, + {LogicalType::intType(16, true), Type::kInt32}, + {LogicalType::intType(32, true), Type::kInt32}, + {LogicalType::intType(64, true), Type::kInt64}, + {LogicalType::json(), Type::kByteArray}, + {LogicalType::bson(), Type::kByteArray}}; + + for (const ExpectedApplicability& c : singleTypeCases) { + confirmSinglePrimitiveTypeApplicability(c.logicalType, c.applicableType); } - std::vector> no_type_cases = { - LogicalType::Map(), LogicalType::List()}; + std::vector> noTypeCases = { + LogicalType::map(), LogicalType::list()}; - for (auto c : no_type_cases) { - ConfirmNoPrimitiveTypeApplicability(c); + for (auto c : noTypeCases) { + confirmNoPrimitiveTypeApplicability(c); } - std::vector> any_type_cases = { - LogicalType::Null(), LogicalType::None(), UndefinedLogicalType::Make()}; + std::vector> anyTypeCases = { + LogicalType::nullType(), + LogicalType::none(), + UndefinedLogicalType::make()}; - for (auto c : any_type_cases) { - ConfirmAnyPrimitiveTypeApplicability(c); + for (auto c : anyTypeCases) { + confirmAnyPrimitiveTypeApplicability(c); } // Fixed binary, exact length cases ... struct InapplicableType { - Type::type physical_type; - int physical_length; + Type::type physicalType; + int physicalLength; }; - std::vector inapplicable_types = { - {Type::FIXED_LEN_BYTE_ARRAY, 8}, - {Type::FIXED_LEN_BYTE_ARRAY, 20}, - {Type::BOOLEAN, -1}, - {Type::INT32, -1}, - {Type::INT64, -1}, - {Type::INT96, -1}, - {Type::FLOAT, -1}, - {Type::DOUBLE, -1}, - {Type::BYTE_ARRAY, -1}}; - - std::shared_ptr logical_type; - - logical_type = LogicalType::Interval(); - ASSERT_TRUE(logical_type->is_applicable(Type::FIXED_LEN_BYTE_ARRAY, 12)); - for (const InapplicableType& t : inapplicable_types) { - ASSERT_FALSE( - logical_type->is_applicable(t.physical_type, t.physical_length)); + std::vector inapplicableTypes = { + {Type::kFixedLenByteArray, 8}, + {Type::kFixedLenByteArray, 20}, + {Type::kBoolean, -1}, + {Type::kInt32, -1}, + {Type::kInt64, -1}, + {Type::kInt96, -1}, + {Type::kFloat, -1}, + {Type::kDouble, -1}, + {Type::kByteArray, -1}}; + + std::shared_ptr logicalType; + + logicalType = LogicalType::interval(); + ASSERT_TRUE(logicalType->isApplicable(Type::kFixedLenByteArray, 12)); + for (const InapplicableType& t : inapplicableTypes) { + ASSERT_FALSE(logicalType->isApplicable(t.physicalType, t.physicalLength)); } - logical_type = LogicalType::UUID(); - ASSERT_TRUE(logical_type->is_applicable(Type::FIXED_LEN_BYTE_ARRAY, 16)); - for (const InapplicableType& t : inapplicable_types) { - ASSERT_FALSE( - logical_type->is_applicable(t.physical_type, t.physical_length)); + logicalType = LogicalType::uuid(); + ASSERT_TRUE(logicalType->isApplicable(Type::kFixedLenByteArray, 16)); + for (const InapplicableType& t : inapplicableTypes) { + ASSERT_FALSE(logicalType->isApplicable(t.physicalType, t.physicalLength)); } } TEST(TestLogicalTypeOperation, DecimalLogicalTypeApplicability) { - // Check that the decimal logical type correctly reports which - // underlying primitive type(s) it can be applied to + // Check that the decimal logical type correctly reports which. + // Underlying primitive type(s) it can be applied to. - std::shared_ptr logical_type; + std::shared_ptr logicalType; for (int32_t precision = 1; precision <= 9; ++precision) { - logical_type = DecimalLogicalType::Make(precision, 0); - ASSERT_TRUE(logical_type->is_applicable(Type::INT32)) - << logical_type->ToString() + logicalType = DecimalLogicalType::make(precision, 0); + ASSERT_TRUE(logicalType->isApplicable(Type::kInt32)) + << logicalType->toString() << " unexpectedly inapplicable to physical type INT32"; } - logical_type = DecimalLogicalType::Make(10, 0); - ASSERT_FALSE(logical_type->is_applicable(Type::INT32)) - << logical_type->ToString() + logicalType = DecimalLogicalType::make(10, 0); + ASSERT_FALSE(logicalType->isApplicable(Type::kInt32)) + << logicalType->toString() << " unexpectedly applicable to physical type INT32"; for (int32_t precision = 1; precision <= 18; ++precision) { - logical_type = DecimalLogicalType::Make(precision, 0); - ASSERT_TRUE(logical_type->is_applicable(Type::INT64)) - << logical_type->ToString() + logicalType = DecimalLogicalType::make(precision, 0); + ASSERT_TRUE(logicalType->isApplicable(Type::kInt64)) + << logicalType->toString() << " unexpectedly inapplicable to physical type INT64"; } - logical_type = DecimalLogicalType::Make(19, 0); - ASSERT_FALSE(logical_type->is_applicable(Type::INT64)) - << logical_type->ToString() + logicalType = DecimalLogicalType::make(19, 0); + ASSERT_FALSE(logicalType->isApplicable(Type::kInt64)) + << logicalType->toString() << " unexpectedly applicable to physical type INT64"; for (int32_t precision = 1; precision <= 36; ++precision) { - logical_type = DecimalLogicalType::Make(precision, 0); - ASSERT_TRUE(logical_type->is_applicable(Type::BYTE_ARRAY)) - << logical_type->ToString() + logicalType = DecimalLogicalType::make(precision, 0); + ASSERT_TRUE(logicalType->isApplicable(Type::kByteArray)) + << logicalType->toString() << " unexpectedly inapplicable to physical type BYTE_ARRAY"; } struct PrecisionLimits { - int32_t physical_length; - int32_t precision_limit; + int32_t physicalLength; + int32_t precisionLimit; }; std::vector cases = { @@ -1688,607 +1673,645 @@ TEST(TestLogicalTypeOperation, DecimalLogicalTypeApplicability) { for (const PrecisionLimits& c : cases) { int32_t precision; - for (precision = 1; precision <= c.precision_limit; ++precision) { - logical_type = DecimalLogicalType::Make(precision, 0); - ASSERT_TRUE(logical_type->is_applicable( - Type::FIXED_LEN_BYTE_ARRAY, c.physical_length)) - << logical_type->ToString() + for (precision = 1; precision <= c.precisionLimit; ++precision) { + logicalType = DecimalLogicalType::make(precision, 0); + ASSERT_TRUE( + logicalType->isApplicable(Type::kFixedLenByteArray, c.physicalLength)) + << logicalType->toString() << " unexpectedly inapplicable to physical type FIXED_LEN_BYTE_ARRAY with " "length " - << c.physical_length; + << c.physicalLength; } - logical_type = DecimalLogicalType::Make(precision, 0); - ASSERT_FALSE(logical_type->is_applicable( - Type::FIXED_LEN_BYTE_ARRAY, c.physical_length)) - << logical_type->ToString() + logicalType = DecimalLogicalType::make(precision, 0); + ASSERT_FALSE( + logicalType->isApplicable(Type::kFixedLenByteArray, c.physicalLength)) + << logicalType->toString() << " unexpectedly applicable to physical type FIXED_LEN_BYTE_ARRAY with length " - << c.physical_length; + << c.physicalLength; } - ASSERT_FALSE((DecimalLogicalType::Make(16, 6))->is_applicable(Type::BOOLEAN)); - ASSERT_FALSE((DecimalLogicalType::Make(16, 6))->is_applicable(Type::FLOAT)); - ASSERT_FALSE((DecimalLogicalType::Make(16, 6))->is_applicable(Type::DOUBLE)); + ASSERT_FALSE((DecimalLogicalType::make(16, 6))->isApplicable(Type::kBoolean)); + ASSERT_FALSE((DecimalLogicalType::make(16, 6))->isApplicable(Type::kFloat)); + ASSERT_FALSE((DecimalLogicalType::make(16, 6))->isApplicable(Type::kDouble)); } TEST(TestLogicalTypeOperation, LogicalTypeRepresentation) { - // Ensure that each logical type prints a correct string and - // JSON representation + // Ensure that each logical type prints a correct string and. + // JSON representation. struct ExpectedRepresentation { - std::shared_ptr logical_type; - const char* string_representation; - const char* JSON_representation; + std::shared_ptr logicalType; + const char* stringRepresentation; + const char* jsonRepresentation; }; std::vector cases = { - {UndefinedLogicalType::Make(), "Undefined", R"({"Type": "Undefined"})"}, - {LogicalType::String(), "String", R"({"Type": "String"})"}, - {LogicalType::Map(), "Map", R"({"Type": "Map"})"}, - {LogicalType::List(), "List", R"({"Type": "List"})"}, - {LogicalType::Enum(), "Enum", R"({"Type": "Enum"})"}, - {LogicalType::Decimal(10, 4), + {UndefinedLogicalType::make(), "Undefined", R"({"Type": "Undefined"})"}, + {LogicalType::string(), "String", R"({"Type": "String"})"}, + {LogicalType::map(), "Map", R"({"Type": "Map"})"}, + {LogicalType::list(), "List", R"({"Type": "List"})"}, + {LogicalType::enumType(), "Enum", R"({"Type": "Enum"})"}, + {LogicalType::decimal(10, 4), "Decimal(precision=10, scale=4)", R"({"Type": "Decimal", "precision": 10, "scale": 4})"}, - {LogicalType::Decimal(10), + {LogicalType::decimal(10), "Decimal(precision=10, scale=0)", R"({"Type": "Decimal", "precision": 10, "scale": 0})"}, - {LogicalType::Date(), "Date", R"({"Type": "Date"})"}, - {LogicalType::Time(true, LogicalType::TimeUnit::MILLIS), + {LogicalType::date(), "Date", R"({"Type": "Date"})"}, + {LogicalType::time(true, LogicalType::TimeUnit::kMillis), "Time(isAdjustedToUTC=true, timeUnit=milliseconds)", R"({"Type": "Time", "isAdjustedToUTC": true, "timeUnit": "milliseconds"})"}, - {LogicalType::Time(true, LogicalType::TimeUnit::MICROS), + {LogicalType::time(true, LogicalType::TimeUnit::kMicros), "Time(isAdjustedToUTC=true, timeUnit=microseconds)", R"({"Type": "Time", "isAdjustedToUTC": true, "timeUnit": "microseconds"})"}, - {LogicalType::Time(true, LogicalType::TimeUnit::NANOS), + {LogicalType::time(true, LogicalType::TimeUnit::kNanos), "Time(isAdjustedToUTC=true, timeUnit=nanoseconds)", R"({"Type": "Time", "isAdjustedToUTC": true, "timeUnit": "nanoseconds"})"}, - {LogicalType::Time(false, LogicalType::TimeUnit::MILLIS), + {LogicalType::time(false, LogicalType::TimeUnit::kMillis), "Time(isAdjustedToUTC=false, timeUnit=milliseconds)", R"({"Type": "Time", "isAdjustedToUTC": false, "timeUnit": "milliseconds"})"}, - {LogicalType::Time(false, LogicalType::TimeUnit::MICROS), + {LogicalType::time(false, LogicalType::TimeUnit::kMicros), "Time(isAdjustedToUTC=false, timeUnit=microseconds)", R"({"Type": "Time", "isAdjustedToUTC": false, "timeUnit": "microseconds"})"}, - {LogicalType::Time(false, LogicalType::TimeUnit::NANOS), + {LogicalType::time(false, LogicalType::TimeUnit::kNanos), "Time(isAdjustedToUTC=false, timeUnit=nanoseconds)", R"({"Type": "Time", "isAdjustedToUTC": false, "timeUnit": "nanoseconds"})"}, - {LogicalType::Timestamp(true, LogicalType::TimeUnit::MILLIS), + {LogicalType::timestamp(true, LogicalType::TimeUnit::kMillis), "Timestamp(isAdjustedToUTC=true, timeUnit=milliseconds, " "is_from_converted_type=false, force_set_converted_type=false)", R"({"Type": "Timestamp", "isAdjustedToUTC": true, "timeUnit": "milliseconds", )" - R"("is_from_converted_type": false, "force_set_converted_type": false})"}, - {LogicalType::Timestamp(true, LogicalType::TimeUnit::MICROS), + R"("isFromConvertedType": false, "forceSetConvertedType": false})"}, + {LogicalType::timestamp(true, LogicalType::TimeUnit::kMicros), "Timestamp(isAdjustedToUTC=true, timeUnit=microseconds, " "is_from_converted_type=false, force_set_converted_type=false)", R"({"Type": "Timestamp", "isAdjustedToUTC": true, "timeUnit": "microseconds", )" - R"("is_from_converted_type": false, "force_set_converted_type": false})"}, - {LogicalType::Timestamp(true, LogicalType::TimeUnit::NANOS), + R"("isFromConvertedType": false, "forceSetConvertedType": false})"}, + {LogicalType::timestamp(true, LogicalType::TimeUnit::kNanos), "Timestamp(isAdjustedToUTC=true, timeUnit=nanoseconds, " "is_from_converted_type=false, force_set_converted_type=false)", R"({"Type": "Timestamp", "isAdjustedToUTC": true, "timeUnit": "nanoseconds", )" - R"("is_from_converted_type": false, "force_set_converted_type": false})"}, - {LogicalType::Timestamp(false, LogicalType::TimeUnit::MILLIS, true, true), + R"("isFromConvertedType": false, "forceSetConvertedType": false})"}, + {LogicalType::timestamp( + false, LogicalType::TimeUnit::kMillis, true, true), "Timestamp(isAdjustedToUTC=false, timeUnit=milliseconds, " "is_from_converted_type=true, force_set_converted_type=true)", R"({"Type": "Timestamp", "isAdjustedToUTC": false, "timeUnit": "milliseconds", )" - R"("is_from_converted_type": true, "force_set_converted_type": true})"}, - {LogicalType::Timestamp(false, LogicalType::TimeUnit::MICROS), + R"("isFromConvertedType": true, "forceSetConvertedType": true})"}, + {LogicalType::timestamp(false, LogicalType::TimeUnit::kMicros), "Timestamp(isAdjustedToUTC=false, timeUnit=microseconds, " "is_from_converted_type=false, force_set_converted_type=false)", R"({"Type": "Timestamp", "isAdjustedToUTC": false, "timeUnit": "microseconds", )" - R"("is_from_converted_type": false, "force_set_converted_type": false})"}, - {LogicalType::Timestamp(false, LogicalType::TimeUnit::NANOS), + R"("isFromConvertedType": false, "forceSetConvertedType": false})"}, + {LogicalType::timestamp(false, LogicalType::TimeUnit::kNanos), "Timestamp(isAdjustedToUTC=false, timeUnit=nanoseconds, " "is_from_converted_type=false, force_set_converted_type=false)", R"({"Type": "Timestamp", "isAdjustedToUTC": false, "timeUnit": "nanoseconds", )" - R"("is_from_converted_type": false, "force_set_converted_type": false})"}, - {LogicalType::Interval(), "Interval", R"({"Type": "Interval"})"}, - {LogicalType::Int(8, false), + R"("isFromConvertedType": false, "forceSetConvertedType": false})"}, + {LogicalType::interval(), "Interval", R"({"Type": "Interval"})"}, + {LogicalType::intType(8, false), "Int(bitWidth=8, isSigned=false)", - R"({"Type": "Int", "bitWidth": 8, "isSigned": false})"}, - {LogicalType::Int(16, false), + R"({"Type": "int", "bitWidth": 8, "isSigned": false})"}, + {LogicalType::intType(16, false), "Int(bitWidth=16, isSigned=false)", - R"({"Type": "Int", "bitWidth": 16, "isSigned": false})"}, - {LogicalType::Int(32, false), + R"({"Type": "int", "bitWidth": 16, "isSigned": false})"}, + {LogicalType::intType(32, false), "Int(bitWidth=32, isSigned=false)", - R"({"Type": "Int", "bitWidth": 32, "isSigned": false})"}, - {LogicalType::Int(64, false), + R"({"Type": "int", "bitWidth": 32, "isSigned": false})"}, + {LogicalType::intType(64, false), "Int(bitWidth=64, isSigned=false)", - R"({"Type": "Int", "bitWidth": 64, "isSigned": false})"}, - {LogicalType::Int(8, true), + R"({"Type": "int", "bitWidth": 64, "isSigned": false})"}, + {LogicalType::intType(8, true), "Int(bitWidth=8, isSigned=true)", - R"({"Type": "Int", "bitWidth": 8, "isSigned": true})"}, - {LogicalType::Int(16, true), + R"({"Type": "int", "bitWidth": 8, "isSigned": true})"}, + {LogicalType::intType(16, true), "Int(bitWidth=16, isSigned=true)", - R"({"Type": "Int", "bitWidth": 16, "isSigned": true})"}, - {LogicalType::Int(32, true), + R"({"Type": "int", "bitWidth": 16, "isSigned": true})"}, + {LogicalType::intType(32, true), "Int(bitWidth=32, isSigned=true)", - R"({"Type": "Int", "bitWidth": 32, "isSigned": true})"}, - {LogicalType::Int(64, true), + R"({"Type": "int", "bitWidth": 32, "isSigned": true})"}, + {LogicalType::intType(64, true), "Int(bitWidth=64, isSigned=true)", - R"({"Type": "Int", "bitWidth": 64, "isSigned": true})"}, - {LogicalType::Null(), "Null", R"({"Type": "Null"})"}, - {LogicalType::JSON(), "JSON", R"({"Type": "JSON"})"}, - {LogicalType::BSON(), "BSON", R"({"Type": "BSON"})"}, - {LogicalType::UUID(), "UUID", R"({"Type": "UUID"})"}, - {LogicalType::None(), "None", R"({"Type": "None"})"}, + R"({"Type": "int", "bitWidth": 64, "isSigned": true})"}, + {LogicalType::nullType(), "Null", R"({"Type": "Null"})"}, + {LogicalType::json(), "JSON", R"({"Type": "JSON"})"}, + {LogicalType::bson(), "BSON", R"({"Type": "BSON"})"}, + {LogicalType::uuid(), "UUID", R"({"Type": "UUID"})"}, + {LogicalType::none(), "None", R"({"Type": "None"})"}, }; for (const ExpectedRepresentation& c : cases) { - ASSERT_STREQ(c.logical_type->ToString().c_str(), c.string_representation); - ASSERT_STREQ(c.logical_type->ToJSON().c_str(), c.JSON_representation); + ASSERT_STREQ(c.logicalType->toString().c_str(), c.stringRepresentation); + ASSERT_STREQ(c.logicalType->toJson().c_str(), c.jsonRepresentation); } } TEST(TestLogicalTypeOperation, LogicalTypeSortOrder) { - // Ensure that each logical type reports the correct sort order + // Ensure that each logical type reports the correct sort order. struct ExpectedSortOrder { - std::shared_ptr logical_type; - SortOrder::type sort_order; + std::shared_ptr logicalType; + SortOrder::type sortOrder; }; std::vector cases = { - {LogicalType::String(), SortOrder::UNSIGNED}, - {LogicalType::Map(), SortOrder::UNKNOWN}, - {LogicalType::List(), SortOrder::UNKNOWN}, - {LogicalType::Enum(), SortOrder::UNSIGNED}, - {LogicalType::Decimal(8, 2), SortOrder::SIGNED}, - {LogicalType::Date(), SortOrder::SIGNED}, - {LogicalType::Time(true, LogicalType::TimeUnit::MILLIS), - SortOrder::SIGNED}, - {LogicalType::Time(true, LogicalType::TimeUnit::MICROS), - SortOrder::SIGNED}, - {LogicalType::Time(true, LogicalType::TimeUnit::NANOS), - SortOrder::SIGNED}, - {LogicalType::Time(false, LogicalType::TimeUnit::MILLIS), - SortOrder::SIGNED}, - {LogicalType::Time(false, LogicalType::TimeUnit::MICROS), - SortOrder::SIGNED}, - {LogicalType::Time(false, LogicalType::TimeUnit::NANOS), - SortOrder::SIGNED}, - {LogicalType::Timestamp(true, LogicalType::TimeUnit::MILLIS), - SortOrder::SIGNED}, - {LogicalType::Timestamp(true, LogicalType::TimeUnit::MICROS), - SortOrder::SIGNED}, - {LogicalType::Timestamp(true, LogicalType::TimeUnit::NANOS), - SortOrder::SIGNED}, - {LogicalType::Timestamp(false, LogicalType::TimeUnit::MILLIS), - SortOrder::SIGNED}, - {LogicalType::Timestamp(false, LogicalType::TimeUnit::MICROS), - SortOrder::SIGNED}, - {LogicalType::Timestamp(false, LogicalType::TimeUnit::NANOS), - SortOrder::SIGNED}, - {LogicalType::Interval(), SortOrder::UNKNOWN}, - {LogicalType::Int(8, false), SortOrder::UNSIGNED}, - {LogicalType::Int(16, false), SortOrder::UNSIGNED}, - {LogicalType::Int(32, false), SortOrder::UNSIGNED}, - {LogicalType::Int(64, false), SortOrder::UNSIGNED}, - {LogicalType::Int(8, true), SortOrder::SIGNED}, - {LogicalType::Int(16, true), SortOrder::SIGNED}, - {LogicalType::Int(32, true), SortOrder::SIGNED}, - {LogicalType::Int(64, true), SortOrder::SIGNED}, - {LogicalType::Null(), SortOrder::UNKNOWN}, - {LogicalType::JSON(), SortOrder::UNSIGNED}, - {LogicalType::BSON(), SortOrder::UNSIGNED}, - {LogicalType::UUID(), SortOrder::UNSIGNED}, - {LogicalType::None(), SortOrder::UNKNOWN}}; + {LogicalType::string(), SortOrder::kUnsigned}, + {LogicalType::map(), SortOrder::kUnknown}, + {LogicalType::list(), SortOrder::kUnknown}, + {LogicalType::enumType(), SortOrder::kUnsigned}, + {LogicalType::decimal(8, 2), SortOrder::kSigned}, + {LogicalType::date(), SortOrder::kSigned}, + {LogicalType::time(true, LogicalType::TimeUnit::kMillis), + SortOrder::kSigned}, + {LogicalType::time(true, LogicalType::TimeUnit::kMicros), + SortOrder::kSigned}, + {LogicalType::time(true, LogicalType::TimeUnit::kNanos), + SortOrder::kSigned}, + {LogicalType::time(false, LogicalType::TimeUnit::kMillis), + SortOrder::kSigned}, + {LogicalType::time(false, LogicalType::TimeUnit::kMicros), + SortOrder::kSigned}, + {LogicalType::time(false, LogicalType::TimeUnit::kNanos), + SortOrder::kSigned}, + {LogicalType::timestamp(true, LogicalType::TimeUnit::kMillis), + SortOrder::kSigned}, + {LogicalType::timestamp(true, LogicalType::TimeUnit::kMicros), + SortOrder::kSigned}, + {LogicalType::timestamp(true, LogicalType::TimeUnit::kNanos), + SortOrder::kSigned}, + {LogicalType::timestamp(false, LogicalType::TimeUnit::kMillis), + SortOrder::kSigned}, + {LogicalType::timestamp(false, LogicalType::TimeUnit::kMicros), + SortOrder::kSigned}, + {LogicalType::timestamp(false, LogicalType::TimeUnit::kNanos), + SortOrder::kSigned}, + {LogicalType::interval(), SortOrder::kUnknown}, + {LogicalType::intType(8, false), SortOrder::kUnsigned}, + {LogicalType::intType(16, false), SortOrder::kUnsigned}, + {LogicalType::intType(32, false), SortOrder::kUnsigned}, + {LogicalType::intType(64, false), SortOrder::kUnsigned}, + {LogicalType::intType(8, true), SortOrder::kSigned}, + {LogicalType::intType(16, true), SortOrder::kSigned}, + {LogicalType::intType(32, true), SortOrder::kSigned}, + {LogicalType::intType(64, true), SortOrder::kSigned}, + {LogicalType::nullType(), SortOrder::kUnknown}, + {LogicalType::json(), SortOrder::kUnsigned}, + {LogicalType::bson(), SortOrder::kUnsigned}, + {LogicalType::uuid(), SortOrder::kUnsigned}, + {LogicalType::none(), SortOrder::kUnknown}}; for (const ExpectedSortOrder& c : cases) { - ASSERT_EQ(c.logical_type->sort_order(), c.sort_order) - << c.logical_type->ToString() + ASSERT_EQ(c.logicalType->sortOrder(), c.sortOrder) + << c.logicalType->toString() << " logical type has incorrect sort order"; } } -static void ConfirmPrimitiveNodeFactoryEquivalence( - const std::shared_ptr& logical_type, - ConvertedType::type converted_type, - Type::type physical_type, - int physical_length, +static void confirmPrimitiveNodeFactoryEquivalence( + const std::shared_ptr& logicalType, + ConvertedType::type convertedType, + Type::type physicalType, + int physicalLength, int precision, int scale) { std::string name = "something"; - Repetition::type repetition = Repetition::REQUIRED; - NodePtr from_converted_type = PrimitiveNode::Make( + Repetition::type repetition = Repetition::kRequired; + NodePtr fromConvertedType = PrimitiveNode::make( name, repetition, - physical_type, - converted_type, - physical_length, + physicalType, + convertedType, + physicalLength, precision, scale); - NodePtr from_logical_type = PrimitiveNode::Make( - name, repetition, logical_type, physical_type, physical_length); - ASSERT_TRUE(from_converted_type->Equals(from_logical_type.get())) + NodePtr fromLogicalType = PrimitiveNode::make( + name, repetition, logicalType, physicalType, physicalLength); + ASSERT_TRUE(fromConvertedType->equals(fromLogicalType.get())) << "Primitive node constructed with converted type " - << ConvertedTypeToString(converted_type) + << convertedTypeToString(convertedType) << " unexpectedly not equivalent to primitive node constructed with logical " "type " - << logical_type->ToString(); + << logicalType->toString(); return; } -static void ConfirmGroupNodeFactoryEquivalence( +static void confirmGroupNodeFactoryEquivalence( std::string name, - const std::shared_ptr& logical_type, - ConvertedType::type converted_type) { - Repetition::type repetition = Repetition::OPTIONAL; - NodePtr from_converted_type = - GroupNode::Make(name, repetition, {}, converted_type); - NodePtr from_logical_type = - GroupNode::Make(name, repetition, {}, logical_type); - ASSERT_TRUE(from_converted_type->Equals(from_logical_type.get())) + const std::shared_ptr& logicalType, + ConvertedType::type convertedType) { + Repetition::type repetition = Repetition::kOptional; + NodePtr fromConvertedType = + GroupNode::make(name, repetition, {}, convertedType); + NodePtr fromLogicalType = GroupNode::make(name, repetition, {}, logicalType); + ASSERT_TRUE(fromConvertedType->equals(fromLogicalType.get())) << "Group node constructed with converted type " - << ConvertedTypeToString(converted_type) + << convertedTypeToString(convertedType) << " unexpectedly not equivalent to group node constructed with logical type " - << logical_type->ToString(); + << logicalType->toString(); return; } TEST(TestSchemaNodeCreation, FactoryEquivalence) { - // Ensure that the Node factory methods produce equivalent results regardless - // of whether they are given a converted type or a logical type. + // Ensure that the Node factory methods produce equivalent results regardless. + // Of whether they are given a converted type or a logical type. // Primitive nodes ... struct PrimitiveNodeFactoryArguments { - std::shared_ptr logical_type; - ConvertedType::type converted_type; - Type::type physical_type; - int physical_length; + std::shared_ptr logicalType; + ConvertedType::type convertedType; + Type::type physicalType; + int physicalLength; int precision; int scale; }; std::vector cases = { - {LogicalType::String(), - ConvertedType::UTF8, - Type::BYTE_ARRAY, + {LogicalType::string(), + ConvertedType::kUtf8, + Type::kByteArray, -1, -1, -1}, - {LogicalType::Enum(), ConvertedType::ENUM, Type::BYTE_ARRAY, -1, -1, -1}, - {LogicalType::Decimal(16, 6), - ConvertedType::DECIMAL, - Type::INT64, + {LogicalType::enumType(), + ConvertedType::kEnum, + Type::kByteArray, + -1, + -1, + -1}, + {LogicalType::decimal(16, 6), + ConvertedType::kDecimal, + Type::kInt64, -1, 16, 6}, - {LogicalType::Date(), ConvertedType::DATE, Type::INT32, -1, -1, -1}, - {LogicalType::Time(true, LogicalType::TimeUnit::MILLIS), - ConvertedType::TIME_MILLIS, - Type::INT32, + {LogicalType::date(), ConvertedType::kDate, Type::kInt32, -1, -1, -1}, + {LogicalType::time(true, LogicalType::TimeUnit::kMillis), + ConvertedType::kTimeMillis, + Type::kInt32, -1, -1, -1}, - {LogicalType::Time(true, LogicalType::TimeUnit::MICROS), - ConvertedType::TIME_MICROS, - Type::INT64, + {LogicalType::time(true, LogicalType::TimeUnit::kMicros), + ConvertedType::kTimeMicros, + Type::kInt64, -1, -1, -1}, - {LogicalType::Timestamp(true, LogicalType::TimeUnit::MILLIS), - ConvertedType::TIMESTAMP_MILLIS, - Type::INT64, + {LogicalType::timestamp(true, LogicalType::TimeUnit::kMillis), + ConvertedType::kTimestampMillis, + Type::kInt64, -1, -1, -1}, - {LogicalType::Timestamp(true, LogicalType::TimeUnit::MICROS), - ConvertedType::TIMESTAMP_MICROS, - Type::INT64, + {LogicalType::timestamp(true, LogicalType::TimeUnit::kMicros), + ConvertedType::kTimestampMicros, + Type::kInt64, -1, -1, -1}, - {LogicalType::Interval(), - ConvertedType::INTERVAL, - Type::FIXED_LEN_BYTE_ARRAY, + {LogicalType::interval(), + ConvertedType::kInterval, + Type::kFixedLenByteArray, 12, -1, -1}, - {LogicalType::Int(8, false), - ConvertedType::UINT_8, - Type::INT32, + {LogicalType::intType(8, false), + ConvertedType::kUint8, + Type::kInt32, -1, -1, -1}, - {LogicalType::Int(8, true), - ConvertedType::INT_8, - Type::INT32, + {LogicalType::intType(8, true), + ConvertedType::kInt8, + Type::kInt32, -1, -1, -1}, - {LogicalType::Int(16, false), - ConvertedType::UINT_16, - Type::INT32, + {LogicalType::intType(16, false), + ConvertedType::kUint16, + Type::kInt32, -1, -1, -1}, - {LogicalType::Int(16, true), - ConvertedType::INT_16, - Type::INT32, + {LogicalType::intType(16, true), + ConvertedType::kInt16, + Type::kInt32, -1, -1, -1}, - {LogicalType::Int(32, false), - ConvertedType::UINT_32, - Type::INT32, + {LogicalType::intType(32, false), + ConvertedType::kUint32, + Type::kInt32, -1, -1, -1}, - {LogicalType::Int(32, true), - ConvertedType::INT_32, - Type::INT32, + {LogicalType::intType(32, true), + ConvertedType::kInt32, + Type::kInt32, -1, -1, -1}, - {LogicalType::Int(64, false), - ConvertedType::UINT_64, - Type::INT64, + {LogicalType::intType(64, false), + ConvertedType::kUint64, + Type::kInt64, -1, -1, -1}, - {LogicalType::Int(64, true), - ConvertedType::INT_64, - Type::INT64, + {LogicalType::intType(64, true), + ConvertedType::kInt64, + Type::kInt64, -1, -1, -1}, - {LogicalType::JSON(), ConvertedType::JSON, Type::BYTE_ARRAY, -1, -1, -1}, - {LogicalType::BSON(), ConvertedType::BSON, Type::BYTE_ARRAY, -1, -1, -1}, - {LogicalType::None(), ConvertedType::NONE, Type::INT64, -1, -1, -1}}; + {LogicalType::json(), ConvertedType::kJson, Type::kByteArray, -1, -1, -1}, + {LogicalType::bson(), ConvertedType::kBson, Type::kByteArray, -1, -1, -1}, + {LogicalType::none(), ConvertedType::kNone, Type::kInt64, -1, -1, -1}}; for (const PrimitiveNodeFactoryArguments& c : cases) { - ConfirmPrimitiveNodeFactoryEquivalence( - c.logical_type, - c.converted_type, - c.physical_type, - c.physical_length, + confirmPrimitiveNodeFactoryEquivalence( + c.logicalType, + c.convertedType, + c.physicalType, + c.physicalLength, c.precision, c.scale); } // Group nodes ... - ConfirmGroupNodeFactoryEquivalence( - "map", LogicalType::Map(), ConvertedType::MAP); - ConfirmGroupNodeFactoryEquivalence( - "list", LogicalType::List(), ConvertedType::LIST); + confirmGroupNodeFactoryEquivalence( + "map", LogicalType::map(), ConvertedType::kMap); + confirmGroupNodeFactoryEquivalence( + "list", LogicalType::list(), ConvertedType::kList); } TEST(TestSchemaNodeCreation, FactoryExceptions) { - // Ensure that the Node factory method that accepts a logical type refuses to - // create an object if compatibility conditions are not met + // Ensure that the Node factory method that accepts a logical type refuses to. + // Create an object if compatibility conditions are not met. // Nested logical type on non-group node ... - ASSERT_ANY_THROW(PrimitiveNode::Make( - "map", Repetition::REQUIRED, MapLogicalType::Make(), Type::INT64)); + ASSERT_ANY_THROW( + PrimitiveNode::make( + "map", Repetition::kRequired, MapLogicalType::make(), Type::kInt64)); // Incompatible primitive type ... - ASSERT_ANY_THROW(PrimitiveNode::Make( - "string", - Repetition::REQUIRED, - StringLogicalType::Make(), - Type::BOOLEAN)); + ASSERT_ANY_THROW( + PrimitiveNode::make( + "string", + Repetition::kRequired, + StringLogicalType::make(), + Type::kBoolean)); // Incompatible primitive length ... - ASSERT_ANY_THROW(PrimitiveNode::Make( - "interval", - Repetition::REQUIRED, - IntervalLogicalType::Make(), - Type::FIXED_LEN_BYTE_ARRAY, - 11)); + ASSERT_ANY_THROW( + PrimitiveNode::make( + "interval", + Repetition::kRequired, + IntervalLogicalType::make(), + Type::kFixedLenByteArray, + 11)); // Scale is greater than precision. - ASSERT_ANY_THROW(PrimitiveNode::Make( - "decimal", - Repetition::REQUIRED, - DecimalLogicalType::Make(10, 11), - Type::INT64)); - ASSERT_ANY_THROW(PrimitiveNode::Make( - "decimal", - Repetition::REQUIRED, - DecimalLogicalType::Make(17, 18), - Type::INT64)); + ASSERT_ANY_THROW( + PrimitiveNode::make( + "decimal", + Repetition::kRequired, + DecimalLogicalType::make(10, 11), + Type::kInt64)); + ASSERT_ANY_THROW( + PrimitiveNode::make( + "decimal", + Repetition::kRequired, + DecimalLogicalType::make(17, 18), + Type::kInt64)); // Primitive too small for given precision ... - ASSERT_ANY_THROW(PrimitiveNode::Make( - "decimal", - Repetition::REQUIRED, - DecimalLogicalType::Make(16, 6), - Type::INT32)); - ASSERT_ANY_THROW(PrimitiveNode::Make( - "decimal", - Repetition::REQUIRED, - DecimalLogicalType::Make(10, 9), - Type::INT32)); - ASSERT_ANY_THROW(PrimitiveNode::Make( - "decimal", - Repetition::REQUIRED, - DecimalLogicalType::Make(19, 17), - Type::INT64)); - ASSERT_ANY_THROW(PrimitiveNode::Make( - "decimal", - Repetition::REQUIRED, - DecimalLogicalType::Make(308, 6), - Type::FIXED_LEN_BYTE_ARRAY, - 128)); - // Length is too long - ASSERT_ANY_THROW(PrimitiveNode::Make( - "decimal", - Repetition::REQUIRED, - DecimalLogicalType::Make(10, 6), - Type::FIXED_LEN_BYTE_ARRAY, - 891723283)); + ASSERT_ANY_THROW( + PrimitiveNode::make( + "decimal", + Repetition::kRequired, + DecimalLogicalType::make(16, 6), + Type::kInt32)); + ASSERT_ANY_THROW( + PrimitiveNode::make( + "decimal", + Repetition::kRequired, + DecimalLogicalType::make(10, 9), + Type::kInt32)); + ASSERT_ANY_THROW( + PrimitiveNode::make( + "decimal", + Repetition::kRequired, + DecimalLogicalType::make(19, 17), + Type::kInt64)); + ASSERT_ANY_THROW( + PrimitiveNode::make( + "decimal", + Repetition::kRequired, + DecimalLogicalType::make(308, 6), + Type::kFixedLenByteArray, + 128)); + // Length is too long. + ASSERT_ANY_THROW( + PrimitiveNode::make( + "decimal", + Repetition::kRequired, + DecimalLogicalType::make(10, 6), + Type::kFixedLenByteArray, + 891723283)); // Incompatible primitive length ... - ASSERT_ANY_THROW(PrimitiveNode::Make( - "uuid", - Repetition::REQUIRED, - UUIDLogicalType::Make(), - Type::FIXED_LEN_BYTE_ARRAY, - 64)); + ASSERT_ANY_THROW( + PrimitiveNode::make( + "uuid", + Repetition::kRequired, + UuidLogicalType::make(), + Type::kFixedLenByteArray, + 64)); // Non-positive length argument for fixed length binary ... - ASSERT_ANY_THROW(PrimitiveNode::Make( - "negative_length", - Repetition::REQUIRED, - NoLogicalType::Make(), - Type::FIXED_LEN_BYTE_ARRAY, - -16)); + ASSERT_ANY_THROW( + PrimitiveNode::make( + "negative_length", + Repetition::kRequired, + NoLogicalType::make(), + Type::kFixedLenByteArray, + -16)); // Non-positive length argument for fixed length binary ... - ASSERT_ANY_THROW(PrimitiveNode::Make( - "zero_length", - Repetition::REQUIRED, - NoLogicalType::Make(), - Type::FIXED_LEN_BYTE_ARRAY, - 0)); + ASSERT_ANY_THROW( + PrimitiveNode::make( + "zero_length", + Repetition::kRequired, + NoLogicalType::make(), + Type::kFixedLenByteArray, + 0)); // Non-nested logical type on group node ... - ASSERT_ANY_THROW(GroupNode::Make( - "list", Repetition::REPEATED, {}, JSONLogicalType::Make())); + ASSERT_ANY_THROW( + GroupNode::make( + "list", Repetition::kRepeated, {}, JsonLogicalType::make())); - // nullptr logical type arguments convert to NoLogicalType/ConvertedType::NONE + // Nullptr logical type arguments convert to + // NoLogicalType/ConvertedType::kNone. std::shared_ptr empty; - NodePtr node; + NodePtr Node; ASSERT_NO_THROW( - node = PrimitiveNode::Make( - "value", Repetition::REQUIRED, empty, Type::DOUBLE)); - ASSERT_TRUE(node->logical_type()->is_none()); - ASSERT_EQ(node->converted_type(), ConvertedType::NONE); + Node = PrimitiveNode::make( + "value", Repetition::kRequired, empty, Type::kDouble)); + ASSERT_TRUE(Node->logicalType()->isNone()); + ASSERT_EQ(Node->convertedType(), ConvertedType::kNone); ASSERT_NO_THROW( - node = GroupNode::Make("items", Repetition::REPEATED, {}, empty)); - ASSERT_TRUE(node->logical_type()->is_none()); - ASSERT_EQ(node->converted_type(), ConvertedType::NONE); + Node = GroupNode::make("items", Repetition::kRepeated, {}, empty)); + ASSERT_TRUE(Node->logicalType()->isNone()); + ASSERT_EQ(Node->convertedType(), ConvertedType::kNone); // Invalid ConvertedType in deserialized element ... - node = PrimitiveNode::Make( + Node = PrimitiveNode::make( "string", - Repetition::REQUIRED, - StringLogicalType::Make(), - Type::BYTE_ARRAY); - ASSERT_EQ(node->logical_type()->type(), LogicalType::Type::STRING); - ASSERT_TRUE(node->logical_type()->is_valid()); - ASSERT_TRUE(node->logical_type()->is_serialized()); - facebook::velox::parquet::thrift::SchemaElement string_intermediary; - node->ToParquet(&string_intermediary); - // ... corrupt the Thrift intermediary .... - string_intermediary.logicalType.__isset.STRING = false; - ASSERT_ANY_THROW(node = PrimitiveNode::FromParquet(&string_intermediary)); + Repetition::kRequired, + StringLogicalType::make(), + Type::kByteArray); + ASSERT_EQ(Node->logicalType()->type(), LogicalType::Type::kString); + ASSERT_TRUE(Node->logicalType()->isValid()); + ASSERT_TRUE(Node->logicalType()->isSerialized()); + facebook::velox::parquet::thrift::SchemaElement stringIntermediary; + Node->toParquet(&stringIntermediary); + // ... Corrupt the Thrift intermediary .... + stringIntermediary.logicalType.__isset.STRING = false; + ASSERT_ANY_THROW(Node = PrimitiveNode::fromParquet(&stringIntermediary)); // Invalid TimeUnit in deserialized TimeLogicalType ... - node = PrimitiveNode::Make( + Node = PrimitiveNode::make( "time", - Repetition::REQUIRED, - TimeLogicalType::Make(true, LogicalType::TimeUnit::NANOS), - Type::INT64); - facebook::velox::parquet::thrift::SchemaElement time_intermediary; - node->ToParquet(&time_intermediary); - // ... corrupt the Thrift intermediary .... - time_intermediary.logicalType.TIME.unit.__isset.NANOS = false; - ASSERT_ANY_THROW(PrimitiveNode::FromParquet(&time_intermediary)); + Repetition::kRequired, + TimeLogicalType::make(true, LogicalType::TimeUnit::kNanos), + Type::kInt64); + facebook::velox::parquet::thrift::SchemaElement timeIntermediary; + Node->toParquet(&timeIntermediary); + // ... Corrupt the Thrift intermediary .... + timeIntermediary.logicalType.TIME.unit.__isset.NANOS = false; + ASSERT_ANY_THROW(PrimitiveNode::fromParquet(&timeIntermediary)); // Invalid TimeUnit in deserialized TimestampLogicalType ... - node = PrimitiveNode::Make( + Node = PrimitiveNode::make( "timestamp", - Repetition::REQUIRED, - TimestampLogicalType::Make(true, LogicalType::TimeUnit::NANOS), - Type::INT64); - facebook::velox::parquet::thrift::SchemaElement timestamp_intermediary; - node->ToParquet(×tamp_intermediary); - // ... corrupt the Thrift intermediary .... - timestamp_intermediary.logicalType.TIMESTAMP.unit.__isset.NANOS = false; - ASSERT_ANY_THROW(PrimitiveNode::FromParquet(×tamp_intermediary)); + Repetition::kRequired, + TimestampLogicalType::make(true, LogicalType::TimeUnit::kNanos), + Type::kInt64); + facebook::velox::parquet::thrift::SchemaElement timestampIntermediary; + Node->toParquet(×tampIntermediary); + // ... Corrupt the Thrift intermediary .... + timestampIntermediary.logicalType.TIMESTAMP.unit.__isset.NANOS = false; + ASSERT_ANY_THROW(PrimitiveNode::fromParquet(×tampIntermediary)); } struct SchemaElementConstructionArguments { + SchemaElementConstructionArguments( + std::string name, + std::shared_ptr logicalType, + Type::type physicalType, + int physicalLength, + bool expectConvertedType, + ConvertedType::type convertedType, + bool expectLogicaltype, + std::function checkLogicaltype) + : name(std::move(name)), + logicalType(std::move(logicalType)), + physicalType(physicalType), + physicalLength(physicalLength), + expectConvertedType(expectConvertedType), + convertedType(convertedType), + expectLogicaltype(expectLogicaltype), + checkLogicaltype(std::move(checkLogicaltype)) {} + std::string name; - std::shared_ptr logical_type; - Type::type physical_type; - int physical_length; - bool expect_converted_type; - ConvertedType::type converted_type; - bool expect_logicalType; - std::function check_logicalType; + std::shared_ptr logicalType; + Type::type physicalType; + int physicalLength; + bool expectConvertedType; + ConvertedType::type convertedType; + bool expectLogicaltype; + std::function checkLogicaltype; }; struct LegacySchemaElementConstructionArguments { std::string name; - Type::type physical_type; - int physical_length; - bool expect_converted_type; - ConvertedType::type converted_type; - bool expect_logicalType; - std::function check_logicalType; + Type::type physicalType; + int physicalLength; + bool expectConvertedType; + ConvertedType::type convertedType; + bool expectLogicaltype; + std::function checkLogicaltype; }; class TestSchemaElementConstruction : public ::testing::Test { public: - TestSchemaElementConstruction* Reconstruct( + TestSchemaElementConstruction* reconstruct( const SchemaElementConstructionArguments& c) { // Make node, create serializable Thrift object from it ... - node_ = PrimitiveNode::Make( + node_ = PrimitiveNode::make( c.name, - Repetition::REQUIRED, - c.logical_type, - c.physical_type, - c.physical_length); + Repetition::kRequired, + c.logicalType, + c.physicalType, + c.physicalLength); element_.reset(new facebook::velox::parquet::thrift::SchemaElement); - node_->ToParquet(element_.get()); + node_->toParquet(element_.get()); - // ... then set aside some values for later inspection. + // ... Then set aside some values for later inspection. name_ = c.name; - expect_converted_type_ = c.expect_converted_type; - converted_type_ = c.converted_type; - expect_logicalType_ = c.expect_logicalType; - check_logicalType_ = c.check_logicalType; + expectConvertedType_ = c.expectConvertedType; + convertedType_ = c.convertedType; + expectLogicaltype_ = c.expectLogicaltype; + checkLogicaltype_ = c.checkLogicaltype; return this; } - TestSchemaElementConstruction* LegacyReconstruct( + TestSchemaElementConstruction* legacyReconstruct( const LegacySchemaElementConstructionArguments& c) { // Make node, create serializable Thrift object from it ... - node_ = PrimitiveNode::Make( + node_ = PrimitiveNode::make( c.name, - Repetition::REQUIRED, - c.physical_type, - c.converted_type, - c.physical_length); + Repetition::kRequired, + c.physicalType, + c.convertedType, + c.physicalLength); element_.reset(new facebook::velox::parquet::thrift::SchemaElement); - node_->ToParquet(element_.get()); + node_->toParquet(element_.get()); - // ... then set aside some values for later inspection. + // ... Then set aside some values for later inspection. name_ = c.name; - expect_converted_type_ = c.expect_converted_type; - converted_type_ = c.converted_type; - expect_logicalType_ = c.expect_logicalType; - check_logicalType_ = c.check_logicalType; + expectConvertedType_ = c.expectConvertedType; + convertedType_ = c.convertedType; + expectLogicaltype_ = c.expectLogicaltype; + checkLogicaltype_ = c.checkLogicaltype; return this; } - void Inspect() { + void inspect() { ASSERT_EQ(element_->name, name_); - if (expect_converted_type_) { + if (expectConvertedType_) { ASSERT_TRUE(element_->__isset.converted_type) - << node_->logical_type()->ToString() + << node_->logicalType()->toString() << " logical type unexpectedly failed to generate a converted type in the " "Thrift " "intermediate object"; - ASSERT_EQ(element_->converted_type, ToThrift(converted_type_)) - << node_->logical_type()->ToString() + ASSERT_EQ(element_->converted_type, toThrift(convertedType_)) + << node_->logicalType()->toString() << " logical type unexpectedly failed to generate correct converted type in " "the " "Thrift intermediate object"; } else { ASSERT_FALSE(element_->__isset.converted_type) - << node_->logical_type()->ToString() + << node_->logicalType()->toString() << " logical type unexpectedly generated a converted type in the Thrift " "intermediate object"; } - if (expect_logicalType_) { + if (expectLogicaltype_) { ASSERT_TRUE(element_->__isset.logicalType) - << node_->logical_type()->ToString() + << node_->logicalType()->toString() << " logical type unexpectedly failed to genverate a logicalType in the Thrift " "intermediate object"; - ASSERT_TRUE(check_logicalType_()) - << node_->logical_type()->ToString() + ASSERT_TRUE(checkLogicaltype_()) + << node_->logicalType()->toString() << " logical type generated incorrect logicalType " "settings in the Thrift intermediate object"; } else { ASSERT_FALSE(element_->__isset.logicalType) - << node_->logical_type()->ToString() + << node_->logicalType()->toString() << " logical type unexpectedly generated a logicalType in the Thrift " "intermediate object"; } @@ -2299,13 +2322,13 @@ class TestSchemaElementConstruction : public ::testing::Test { NodePtr node_; std::unique_ptr element_; std::string name_; - bool expect_converted_type_; + bool expectConvertedType_; ConvertedType::type - converted_type_; // expected converted type in Thrift object - bool expect_logicalType_; + convertedType_; // expected converted type in Thrift object + bool expectLogicaltype_; std::function - check_logicalType_; // specialized (by logical type) - // logicalType check for Thrift object + checkLogicaltype_; // specialized (by logical type) + // LogicalType check for Thrift object. }; /* @@ -2316,125 +2339,125 @@ class TestSchemaElementConstruction : public ::testing::Test { */ TEST_F(TestSchemaElementConstruction, SimpleCases) { - auto check_nothing = []() { + auto checkNothing = []() { return true; }; // used for logical types that don't expect a logicalType to be set std::vector cases = { {"string", - LogicalType::String(), - Type::BYTE_ARRAY, + LogicalType::string(), + Type::kByteArray, -1, true, - ConvertedType::UTF8, + ConvertedType::kUtf8, true, [this]() { return element_->logicalType.__isset.STRING; }}, {"enum", - LogicalType::Enum(), - Type::BYTE_ARRAY, + LogicalType::enumType(), + Type::kByteArray, -1, true, - ConvertedType::ENUM, + ConvertedType::kEnum, true, [this]() { return element_->logicalType.__isset.ENUM; }}, {"date", - LogicalType::Date(), - Type::INT32, + LogicalType::date(), + Type::kInt32, -1, true, - ConvertedType::DATE, + ConvertedType::kDate, true, [this]() { return element_->logicalType.__isset.DATE; }}, {"interval", - LogicalType::Interval(), - Type::FIXED_LEN_BYTE_ARRAY, + LogicalType::interval(), + Type::kFixedLenByteArray, 12, true, - ConvertedType::INTERVAL, + ConvertedType::kInterval, false, - check_nothing}, + checkNothing}, {"null", - LogicalType::Null(), - Type::DOUBLE, + LogicalType::nullType(), + Type::kDouble, -1, false, - ConvertedType::NA, + ConvertedType::kNa, true, [this]() { return element_->logicalType.__isset.UNKNOWN; }}, {"json", - LogicalType::JSON(), - Type::BYTE_ARRAY, + LogicalType::json(), + Type::kByteArray, -1, true, - ConvertedType::JSON, + ConvertedType::kJson, true, [this]() { return element_->logicalType.__isset.JSON; }}, {"bson", - LogicalType::BSON(), - Type::BYTE_ARRAY, + LogicalType::bson(), + Type::kByteArray, -1, true, - ConvertedType::BSON, + ConvertedType::kBson, true, [this]() { return element_->logicalType.__isset.BSON; }}, {"uuid", - LogicalType::UUID(), - Type::FIXED_LEN_BYTE_ARRAY, + LogicalType::uuid(), + Type::kFixedLenByteArray, 16, false, - ConvertedType::NA, + ConvertedType::kNa, true, [this]() { return element_->logicalType.__isset.UUID; }}, {"none", - LogicalType::None(), - Type::INT64, + LogicalType::none(), + Type::kInt64, -1, false, - ConvertedType::NA, + ConvertedType::kNa, false, - check_nothing}}; + checkNothing}}; for (const SchemaElementConstructionArguments& c : cases) { - this->Reconstruct(c)->Inspect(); + this->reconstruct(c)->inspect(); } - std::vector legacy_cases = { + std::vector legacyCases = { {"timestamp_ms", - Type::INT64, + Type::kInt64, -1, true, - ConvertedType::TIMESTAMP_MILLIS, + ConvertedType::kTimestampMillis, false, - check_nothing}, + checkNothing}, {"timestamp_us", - Type::INT64, + Type::kInt64, -1, true, - ConvertedType::TIMESTAMP_MICROS, + ConvertedType::kTimestampMicros, false, - check_nothing}, + checkNothing}, }; - for (const LegacySchemaElementConstructionArguments& c : legacy_cases) { - this->LegacyReconstruct(c)->Inspect(); + for (const LegacySchemaElementConstructionArguments& c : legacyCases) { + this->legacyReconstruct(c)->inspect(); } } class TestDecimalSchemaElementConstruction : public TestSchemaElementConstruction { public: - TestDecimalSchemaElementConstruction* Reconstruct( + TestDecimalSchemaElementConstruction* reconstruct( const SchemaElementConstructionArguments& c) { - TestSchemaElementConstruction::Reconstruct(c); - const auto& decimal_logical_type = - checked_cast(*c.logical_type); - precision_ = decimal_logical_type.precision(); - scale_ = decimal_logical_type.scale(); + TestSchemaElementConstruction::reconstruct(c); + const auto& decimalLogicalType = + checked_cast(*c.logicalType); + precision_ = decimalLogicalType.precision(); + scale_ = decimalLogicalType.scale(); return this; } - void Inspect() { - TestSchemaElementConstruction::Inspect(); + void inspect() { + TestSchemaElementConstruction::inspect(); ASSERT_EQ(element_->precision, precision_); ASSERT_EQ(element_->scale, scale_); ASSERT_EQ(element_->logicalType.DECIMAL.precision, precision_); @@ -2448,87 +2471,87 @@ class TestDecimalSchemaElementConstruction }; TEST_F(TestDecimalSchemaElementConstruction, DecimalCases) { - auto check_DECIMAL = [this]() { + auto checkDecimal = [this]() { return element_->logicalType.__isset.DECIMAL; }; std::vector cases = { {"decimal", - LogicalType::Decimal(16, 6), - Type::INT64, + LogicalType::decimal(16, 6), + Type::kInt64, -1, true, - ConvertedType::DECIMAL, + ConvertedType::kDecimal, true, - check_DECIMAL}, + checkDecimal}, {"decimal", - LogicalType::Decimal(1, 0), - Type::INT32, + LogicalType::decimal(1, 0), + Type::kInt32, -1, true, - ConvertedType::DECIMAL, + ConvertedType::kDecimal, true, - check_DECIMAL}, + checkDecimal}, {"decimal", - LogicalType::Decimal(10), - Type::INT64, + LogicalType::decimal(10), + Type::kInt64, -1, true, - ConvertedType::DECIMAL, + ConvertedType::kDecimal, true, - check_DECIMAL}, + checkDecimal}, {"decimal", - LogicalType::Decimal(11, 11), - Type::INT64, + LogicalType::decimal(11, 11), + Type::kInt64, -1, true, - ConvertedType::DECIMAL, + ConvertedType::kDecimal, true, - check_DECIMAL}, + checkDecimal}, {"decimal", - LogicalType::Decimal(9, 9), - Type::INT32, + LogicalType::decimal(9, 9), + Type::kInt32, -1, true, - ConvertedType::DECIMAL, + ConvertedType::kDecimal, true, - check_DECIMAL}, + checkDecimal}, {"decimal", - LogicalType::Decimal(18, 18), - Type::INT64, + LogicalType::decimal(18, 18), + Type::kInt64, -1, true, - ConvertedType::DECIMAL, + ConvertedType::kDecimal, true, - check_DECIMAL}, + checkDecimal}, {"decimal", - LogicalType::Decimal(307, 7), - Type::FIXED_LEN_BYTE_ARRAY, + LogicalType::decimal(307, 7), + Type::kFixedLenByteArray, 128, true, - ConvertedType::DECIMAL, + ConvertedType::kDecimal, true, - check_DECIMAL}, + checkDecimal}, {"decimal", - LogicalType::Decimal(310, 32), - Type::FIXED_LEN_BYTE_ARRAY, + LogicalType::decimal(310, 32), + Type::kFixedLenByteArray, 129, true, - ConvertedType::DECIMAL, + ConvertedType::kDecimal, true, - check_DECIMAL}, + checkDecimal}, {"decimal", - LogicalType::Decimal(2147483645, 2147483645), - Type::FIXED_LEN_BYTE_ARRAY, + LogicalType::decimal(2147483645, 2147483645), + Type::kFixedLenByteArray, 891723282, true, - ConvertedType::DECIMAL, + ConvertedType::kDecimal, true, - check_DECIMAL}, + checkDecimal}, }; for (const SchemaElementConstructionArguments& c : cases) { - this->Reconstruct(c)->Inspect(); + this->reconstruct(c)->inspect(); } } @@ -2536,42 +2559,42 @@ class TestTemporalSchemaElementConstruction : public TestSchemaElementConstruction { public: template - TestTemporalSchemaElementConstruction* Reconstruct( + TestTemporalSchemaElementConstruction* reconstruct( const SchemaElementConstructionArguments& c) { - TestSchemaElementConstruction::Reconstruct(c); - const auto& t = checked_cast(*c.logical_type); - adjusted_ = t.is_adjusted_to_utc(); - unit_ = t.time_unit(); + TestSchemaElementConstruction::reconstruct(c); + const auto& t = checked_cast(*c.logicalType); + adjusted_ = t.isAdjustedToUtc(); + unit_ = t.timeUnit(); return this; } template - void Inspect() { + void inspect() { FAIL() << "Invalid typename specified in test suite"; return; } protected: bool adjusted_; - LogicalType::TimeUnit::unit unit_; + LogicalType::TimeUnit::Unit unit_; }; template <> -void TestTemporalSchemaElementConstruction::Inspect< +void TestTemporalSchemaElementConstruction::inspect< facebook::velox::parquet::thrift::TimeType>() { - TestSchemaElementConstruction::Inspect(); + TestSchemaElementConstruction::inspect(); ASSERT_EQ(element_->logicalType.TIME.isAdjustedToUTC, adjusted_); switch (unit_) { - case LogicalType::TimeUnit::MILLIS: + case LogicalType::TimeUnit::kMillis: ASSERT_TRUE(element_->logicalType.TIME.unit.__isset.MILLIS); break; - case LogicalType::TimeUnit::MICROS: + case LogicalType::TimeUnit::kMicros: ASSERT_TRUE(element_->logicalType.TIME.unit.__isset.MICROS); break; - case LogicalType::TimeUnit::NANOS: + case LogicalType::TimeUnit::kNanos: ASSERT_TRUE(element_->logicalType.TIME.unit.__isset.NANOS); break; - case LogicalType::TimeUnit::UNKNOWN: + case LogicalType::TimeUnit::kUnknown: default: FAIL() << "Invalid time unit in test case"; } @@ -2579,21 +2602,21 @@ void TestTemporalSchemaElementConstruction::Inspect< } template <> -void TestTemporalSchemaElementConstruction::Inspect< +void TestTemporalSchemaElementConstruction::inspect< facebook::velox::parquet::thrift::TimestampType>() { - TestSchemaElementConstruction::Inspect(); + TestSchemaElementConstruction::inspect(); ASSERT_EQ(element_->logicalType.TIMESTAMP.isAdjustedToUTC, adjusted_); switch (unit_) { - case LogicalType::TimeUnit::MILLIS: + case LogicalType::TimeUnit::kMillis: ASSERT_TRUE(element_->logicalType.TIMESTAMP.unit.__isset.MILLIS); break; - case LogicalType::TimeUnit::MICROS: + case LogicalType::TimeUnit::kMicros: ASSERT_TRUE(element_->logicalType.TIMESTAMP.unit.__isset.MICROS); break; - case LogicalType::TimeUnit::NANOS: + case LogicalType::TimeUnit::kNanos: ASSERT_TRUE(element_->logicalType.TIMESTAMP.unit.__isset.NANOS); break; - case LogicalType::TimeUnit::UNKNOWN: + case LogicalType::TimeUnit::kUnknown: default: FAIL() << "Invalid time unit in test case"; } @@ -2601,164 +2624,158 @@ void TestTemporalSchemaElementConstruction::Inspect< } TEST_F(TestTemporalSchemaElementConstruction, TemporalCases) { - auto check_TIME = [this]() { return element_->logicalType.__isset.TIME; }; + auto checkTime = [this]() { return element_->logicalType.__isset.TIME; }; - std::vector time_cases = { + std::vector timeCases = { {"time_T_ms", - LogicalType::Time(true, LogicalType::TimeUnit::MILLIS), - Type::INT32, + LogicalType::time(true, LogicalType::TimeUnit::kMillis), + Type::kInt32, -1, true, - ConvertedType::TIME_MILLIS, + ConvertedType::kTimeMillis, true, - check_TIME}, + checkTime}, {"time_F_ms", - LogicalType::Time(false, LogicalType::TimeUnit::MILLIS), - Type::INT32, + LogicalType::time(false, LogicalType::TimeUnit::kMillis), + Type::kInt32, -1, false, - ConvertedType::NA, + ConvertedType::kNa, true, - check_TIME}, + checkTime}, {"time_T_us", - LogicalType::Time(true, LogicalType::TimeUnit::MICROS), - Type::INT64, + LogicalType::time(true, LogicalType::TimeUnit::kMicros), + Type::kInt64, -1, true, - ConvertedType::TIME_MICROS, + ConvertedType::kTimeMicros, true, - check_TIME}, + checkTime}, {"time_F_us", - LogicalType::Time(false, LogicalType::TimeUnit::MICROS), - Type::INT64, + LogicalType::time(false, LogicalType::TimeUnit::kMicros), + Type::kInt64, -1, false, - ConvertedType::NA, + ConvertedType::kNa, true, - check_TIME}, + checkTime}, {"time_T_ns", - LogicalType::Time(true, LogicalType::TimeUnit::NANOS), - Type::INT64, + LogicalType::time(true, LogicalType::TimeUnit::kNanos), + Type::kInt64, -1, false, - ConvertedType::NA, + ConvertedType::kNa, true, - check_TIME}, + checkTime}, {"time_F_ns", - LogicalType::Time(false, LogicalType::TimeUnit::NANOS), - Type::INT64, + LogicalType::time(false, LogicalType::TimeUnit::kNanos), + Type::kInt64, -1, false, - ConvertedType::NA, + ConvertedType::kNa, true, - check_TIME}, + checkTime}, }; - for (const SchemaElementConstructionArguments& c : time_cases) { - this->Reconstruct(c) - ->Inspect(); + for (const SchemaElementConstructionArguments& c : timeCases) { + this->reconstruct(c) + ->inspect(); } - auto check_TIMESTAMP = [this]() { + auto checkTimestamp = [this]() { return element_->logicalType.__isset.TIMESTAMP; }; - std::vector timestamp_cases = { + std::vector timestampCases = { {"timestamp_T_ms", - LogicalType::Timestamp(true, LogicalType::TimeUnit::MILLIS), - Type::INT64, + LogicalType::timestamp(true, LogicalType::TimeUnit::kMillis), + Type::kInt64, -1, true, - ConvertedType::TIMESTAMP_MILLIS, + ConvertedType::kTimestampMillis, true, - check_TIMESTAMP}, + checkTimestamp}, {"timestamp_F_ms", - LogicalType::Timestamp(false, LogicalType::TimeUnit::MILLIS), - Type::INT64, + LogicalType::timestamp(false, LogicalType::TimeUnit::kMillis), + Type::kInt64, -1, false, - ConvertedType::NA, + ConvertedType::kNa, true, - check_TIMESTAMP}, + checkTimestamp}, {"timestamp_F_ms_force", - LogicalType::Timestamp( - false, - LogicalType::TimeUnit::MILLIS, - /*is_from_converted_type=*/false, - /*force_set_converted_type=*/true), - Type::INT64, + LogicalType::timestamp( + false, LogicalType::TimeUnit::kMillis, false, true), + Type::kInt64, -1, true, - ConvertedType::TIMESTAMP_MILLIS, + ConvertedType::kTimestampMillis, true, - check_TIMESTAMP}, + checkTimestamp}, {"timestamp_T_us", - LogicalType::Timestamp(true, LogicalType::TimeUnit::MICROS), - Type::INT64, + LogicalType::timestamp(true, LogicalType::TimeUnit::kMicros), + Type::kInt64, -1, true, - ConvertedType::TIMESTAMP_MICROS, + ConvertedType::kTimestampMicros, true, - check_TIMESTAMP}, + checkTimestamp}, {"timestamp_F_us", - LogicalType::Timestamp(false, LogicalType::TimeUnit::MICROS), - Type::INT64, + LogicalType::timestamp(false, LogicalType::TimeUnit::kMicros), + Type::kInt64, -1, false, - ConvertedType::NA, + ConvertedType::kNa, true, - check_TIMESTAMP}, + checkTimestamp}, {"timestamp_F_us_force", - LogicalType::Timestamp( - false, - LogicalType::TimeUnit::MILLIS, - /*is_from_converted_type=*/false, - /*force_set_converted_type=*/true), - Type::INT64, + LogicalType::timestamp( + false, LogicalType::TimeUnit::kMillis, false, true), + Type::kInt64, -1, true, - ConvertedType::TIMESTAMP_MILLIS, + ConvertedType::kTimestampMillis, true, - check_TIMESTAMP}, + checkTimestamp}, {"timestamp_T_ns", - LogicalType::Timestamp(true, LogicalType::TimeUnit::NANOS), - Type::INT64, + LogicalType::timestamp(true, LogicalType::TimeUnit::kNanos), + Type::kInt64, -1, false, - ConvertedType::NA, + ConvertedType::kNa, true, - check_TIMESTAMP}, + checkTimestamp}, {"timestamp_F_ns", - LogicalType::Timestamp(false, LogicalType::TimeUnit::NANOS), - Type::INT64, + LogicalType::timestamp(false, LogicalType::TimeUnit::kNanos), + Type::kInt64, -1, false, - ConvertedType::NA, + ConvertedType::kNa, true, - check_TIMESTAMP}, + checkTimestamp}, }; - for (const SchemaElementConstructionArguments& c : timestamp_cases) { - this->Reconstruct(c) - ->Inspect(); + for (const SchemaElementConstructionArguments& c : timestampCases) { + this->reconstruct(c) + ->inspect(); } } class TestIntegerSchemaElementConstruction : public TestSchemaElementConstruction { public: - TestIntegerSchemaElementConstruction* Reconstruct( + TestIntegerSchemaElementConstruction* reconstruct( const SchemaElementConstructionArguments& c) { - TestSchemaElementConstruction::Reconstruct(c); - const auto& int_logical_type = - checked_cast(*c.logical_type); - width_ = int_logical_type.bit_width(); - signed_ = int_logical_type.is_signed(); + TestSchemaElementConstruction::reconstruct(c); + const auto& intLogicalType = + checked_cast(*c.logicalType); + width_ = intLogicalType.bitWidth(); + signed_ = intLogicalType.isSigned(); return this; } - void Inspect() { - TestSchemaElementConstruction::Inspect(); + void inspect() { + TestSchemaElementConstruction::inspect(); ASSERT_EQ(element_->logicalType.INTEGER.bitWidth, width_); ASSERT_EQ(element_->logicalType.INTEGER.isSigned, signed_); return; @@ -2770,221 +2787,229 @@ class TestIntegerSchemaElementConstruction }; TEST_F(TestIntegerSchemaElementConstruction, IntegerCases) { - auto check_INTEGER = [this]() { + auto checkInteger = [this]() { return element_->logicalType.__isset.INTEGER; }; std::vector cases = { {"uint8", - LogicalType::Int(8, false), - Type::INT32, + LogicalType::intType(8, false), + Type::kInt32, -1, true, - ConvertedType::UINT_8, + ConvertedType::kUint8, true, - check_INTEGER}, + checkInteger}, {"uint16", - LogicalType::Int(16, false), - Type::INT32, + LogicalType::intType(16, false), + Type::kInt32, -1, true, - ConvertedType::UINT_16, + ConvertedType::kUint16, true, - check_INTEGER}, + checkInteger}, {"uint32", - LogicalType::Int(32, false), - Type::INT32, + LogicalType::intType(32, false), + Type::kInt32, -1, true, - ConvertedType::UINT_32, + ConvertedType::kUint32, true, - check_INTEGER}, + checkInteger}, {"uint64", - LogicalType::Int(64, false), - Type::INT64, + LogicalType::intType(64, false), + Type::kInt64, -1, true, - ConvertedType::UINT_64, + ConvertedType::kUint64, true, - check_INTEGER}, + checkInteger}, {"int8", - LogicalType::Int(8, true), - Type::INT32, + LogicalType::intType(8, true), + Type::kInt32, -1, true, - ConvertedType::INT_8, + ConvertedType::kInt8, true, - check_INTEGER}, + checkInteger}, {"int16", - LogicalType::Int(16, true), - Type::INT32, + LogicalType::intType(16, true), + Type::kInt32, -1, true, - ConvertedType::INT_16, + ConvertedType::kInt16, true, - check_INTEGER}, + checkInteger}, {"int32", - LogicalType::Int(32, true), - Type::INT32, + LogicalType::intType(32, true), + Type::kInt32, -1, true, - ConvertedType::INT_32, + ConvertedType::kInt32, true, - check_INTEGER}, + checkInteger}, {"int64", - LogicalType::Int(64, true), - Type::INT64, + LogicalType::intType(64, true), + Type::kInt64, -1, true, - ConvertedType::INT_64, + ConvertedType::kInt64, true, - check_INTEGER}, + checkInteger}, }; for (const SchemaElementConstructionArguments& c : cases) { - this->Reconstruct(c)->Inspect(); + this->reconstruct(c)->inspect(); } } TEST(TestLogicalTypeSerialization, SchemaElementNestedCases) { - // Confirm that the intermediate Thrift objects created during node - // serialization contain correct ConvertedType and ConvertedType information + // Confirm that the intermediate Thrift objects created during node. + // Serialization contain correct ConvertedType and ConvertedType information. - NodePtr string_node = PrimitiveNode::Make( + NodePtr stringNode = PrimitiveNode::make( "string", - Repetition::REQUIRED, - StringLogicalType::Make(), - Type::BYTE_ARRAY); - NodePtr date_node = PrimitiveNode::Make( - "date", Repetition::REQUIRED, DateLogicalType::Make(), Type::INT32); - NodePtr json_node = PrimitiveNode::Make( - "json", Repetition::REQUIRED, JSONLogicalType::Make(), Type::BYTE_ARRAY); - NodePtr uuid_node = PrimitiveNode::Make( + Repetition::kRequired, + StringLogicalType::make(), + Type::kByteArray); + NodePtr dateNode = PrimitiveNode::make( + "date", Repetition::kRequired, DateLogicalType::make(), Type::kInt32); + NodePtr jsonNode = PrimitiveNode::make( + "json", Repetition::kRequired, JsonLogicalType::make(), Type::kByteArray); + NodePtr uuidNode = PrimitiveNode::make( "uuid", - Repetition::REQUIRED, - UUIDLogicalType::Make(), - Type::FIXED_LEN_BYTE_ARRAY, + Repetition::kRequired, + UuidLogicalType::make(), + Type::kFixedLenByteArray, 16); - NodePtr timestamp_node = PrimitiveNode::Make( + NodePtr timestampNode = PrimitiveNode::make( "timestamp", - Repetition::REQUIRED, - TimestampLogicalType::Make(false, LogicalType::TimeUnit::NANOS), - Type::INT64); - NodePtr int_node = PrimitiveNode::Make( + Repetition::kRequired, + TimestampLogicalType::make(false, LogicalType::TimeUnit::kNanos), + Type::kInt64); + NodePtr intNode = PrimitiveNode::make( "int", - Repetition::REQUIRED, - IntLogicalType::Make(64, false), - Type::INT64); - NodePtr decimal_node = PrimitiveNode::Make( + Repetition::kRequired, + IntLogicalType::make(64, false), + Type::kInt64); + NodePtr decimalNode = PrimitiveNode::make( "decimal", - Repetition::REQUIRED, - DecimalLogicalType::Make(16, 6), - Type::INT64); + Repetition::kRequired, + DecimalLogicalType::make(16, 6), + Type::kInt64); - NodePtr list_node = GroupNode::Make( + NodePtr listNode = GroupNode::make( "list", - Repetition::REPEATED, - {string_node, - date_node, - json_node, - uuid_node, - timestamp_node, - int_node, - decimal_node}, - ListLogicalType::Make()); - std::vector list_elements; - ToParquet(reinterpret_cast(list_node.get()), &list_elements); - ASSERT_EQ(list_elements[0].name, "list"); - ASSERT_TRUE(list_elements[0].__isset.converted_type); - ASSERT_TRUE(list_elements[0].__isset.logicalType); - ASSERT_EQ(list_elements[0].converted_type, ToThrift(ConvertedType::LIST)); - ASSERT_TRUE(list_elements[0].logicalType.__isset.LIST); - ASSERT_TRUE(list_elements[1].logicalType.__isset.STRING); - ASSERT_TRUE(list_elements[2].logicalType.__isset.DATE); - ASSERT_TRUE(list_elements[3].logicalType.__isset.JSON); - ASSERT_TRUE(list_elements[4].logicalType.__isset.UUID); - ASSERT_TRUE(list_elements[5].logicalType.__isset.TIMESTAMP); - ASSERT_TRUE(list_elements[6].logicalType.__isset.INTEGER); - ASSERT_TRUE(list_elements[7].logicalType.__isset.DECIMAL); - - NodePtr map_node = - GroupNode::Make("map", Repetition::REQUIRED, {}, MapLogicalType::Make()); - std::vector map_elements; - ToParquet(reinterpret_cast(map_node.get()), &map_elements); - ASSERT_EQ(map_elements[0].name, "map"); - ASSERT_TRUE(map_elements[0].__isset.converted_type); - ASSERT_TRUE(map_elements[0].__isset.logicalType); - ASSERT_EQ(map_elements[0].converted_type, ToThrift(ConvertedType::MAP)); - ASSERT_TRUE(map_elements[0].logicalType.__isset.MAP); + Repetition::kRepeated, + {stringNode, + dateNode, + jsonNode, + uuidNode, + timestampNode, + intNode, + decimalNode}, + ListLogicalType::make()); + std::vector listElements; + toParquet(reinterpret_cast(listNode.get()), &listElements); + ASSERT_EQ(listElements[0].name, "list"); + ASSERT_TRUE(listElements[0].__isset.converted_type); + ASSERT_TRUE(listElements[0].__isset.logicalType); + ASSERT_EQ(listElements[0].converted_type, toThrift(ConvertedType::kList)); + ASSERT_TRUE(listElements[0].logicalType.__isset.LIST); + ASSERT_TRUE(listElements[1].logicalType.__isset.STRING); + ASSERT_TRUE(listElements[2].logicalType.__isset.DATE); + ASSERT_TRUE(listElements[3].logicalType.__isset.JSON); + ASSERT_TRUE(listElements[4].logicalType.__isset.UUID); + ASSERT_TRUE(listElements[5].logicalType.__isset.TIMESTAMP); + ASSERT_TRUE(listElements[6].logicalType.__isset.INTEGER); + ASSERT_TRUE(listElements[7].logicalType.__isset.DECIMAL); + + NodePtr mapNode = + GroupNode::make("map", Repetition::kRequired, {}, MapLogicalType::make()); + std::vector mapElements; + toParquet(reinterpret_cast(mapNode.get()), &mapElements); + ASSERT_EQ(mapElements[0].name, "map"); + ASSERT_TRUE(mapElements[0].__isset.converted_type); + ASSERT_TRUE(mapElements[0].__isset.logicalType); + ASSERT_EQ(mapElements[0].converted_type, toThrift(ConvertedType::kMap)); + ASSERT_TRUE(mapElements[0].logicalType.__isset.MAP); } TEST(TestLogicalTypeSerialization, Roundtrips) { - // Confirm that Thrift serialization-deserialization of nodes with logical - // types produces equivalent reconstituted nodes + // Confirm that Thrift serialization-deserialization of nodes with logical. + // Types produces equivalent reconstituted nodes. // Primitive nodes ... struct AnnotatedPrimitiveNodeFactoryArguments { - std::shared_ptr logical_type; - Type::type physical_type; - int physical_length; + std::shared_ptr logicalType; + Type::type physicalType; + int physicalLength; }; std::vector cases = { - {LogicalType::String(), Type::BYTE_ARRAY, -1}, - {LogicalType::Enum(), Type::BYTE_ARRAY, -1}, - {LogicalType::Decimal(16, 6), Type::INT64, -1}, - {LogicalType::Date(), Type::INT32, -1}, - {LogicalType::Time(true, LogicalType::TimeUnit::MILLIS), Type::INT32, -1}, - {LogicalType::Time(true, LogicalType::TimeUnit::MICROS), Type::INT64, -1}, - {LogicalType::Time(true, LogicalType::TimeUnit::NANOS), Type::INT64, -1}, - {LogicalType::Time(false, LogicalType::TimeUnit::MILLIS), - Type::INT32, + {LogicalType::string(), Type::kByteArray, -1}, + {LogicalType::enumType(), Type::kByteArray, -1}, + {LogicalType::decimal(16, 6), Type::kInt64, -1}, + {LogicalType::date(), Type::kInt32, -1}, + {LogicalType::time(true, LogicalType::TimeUnit::kMillis), + Type::kInt32, + -1}, + {LogicalType::time(true, LogicalType::TimeUnit::kMicros), + Type::kInt64, + -1}, + {LogicalType::time(true, LogicalType::TimeUnit::kNanos), + Type::kInt64, + -1}, + {LogicalType::time(false, LogicalType::TimeUnit::kMillis), + Type::kInt32, + -1}, + {LogicalType::time(false, LogicalType::TimeUnit::kMicros), + Type::kInt64, -1}, - {LogicalType::Time(false, LogicalType::TimeUnit::MICROS), - Type::INT64, + {LogicalType::time(false, LogicalType::TimeUnit::kNanos), + Type::kInt64, -1}, - {LogicalType::Time(false, LogicalType::TimeUnit::NANOS), Type::INT64, -1}, - {LogicalType::Timestamp(true, LogicalType::TimeUnit::MILLIS), - Type::INT64, + {LogicalType::timestamp(true, LogicalType::TimeUnit::kMillis), + Type::kInt64, -1}, - {LogicalType::Timestamp(true, LogicalType::TimeUnit::MICROS), - Type::INT64, + {LogicalType::timestamp(true, LogicalType::TimeUnit::kMicros), + Type::kInt64, -1}, - {LogicalType::Timestamp(true, LogicalType::TimeUnit::NANOS), - Type::INT64, + {LogicalType::timestamp(true, LogicalType::TimeUnit::kNanos), + Type::kInt64, -1}, - {LogicalType::Timestamp(false, LogicalType::TimeUnit::MILLIS), - Type::INT64, + {LogicalType::timestamp(false, LogicalType::TimeUnit::kMillis), + Type::kInt64, -1}, - {LogicalType::Timestamp(false, LogicalType::TimeUnit::MICROS), - Type::INT64, + {LogicalType::timestamp(false, LogicalType::TimeUnit::kMicros), + Type::kInt64, -1}, - {LogicalType::Timestamp(false, LogicalType::TimeUnit::NANOS), - Type::INT64, + {LogicalType::timestamp(false, LogicalType::TimeUnit::kNanos), + Type::kInt64, -1}, - {LogicalType::Interval(), Type::FIXED_LEN_BYTE_ARRAY, 12}, - {LogicalType::Int(8, false), Type::INT32, -1}, - {LogicalType::Int(16, false), Type::INT32, -1}, - {LogicalType::Int(32, false), Type::INT32, -1}, - {LogicalType::Int(64, false), Type::INT64, -1}, - {LogicalType::Int(8, true), Type::INT32, -1}, - {LogicalType::Int(16, true), Type::INT32, -1}, - {LogicalType::Int(32, true), Type::INT32, -1}, - {LogicalType::Int(64, true), Type::INT64, -1}, - {LogicalType::Null(), Type::BOOLEAN, -1}, - {LogicalType::JSON(), Type::BYTE_ARRAY, -1}, - {LogicalType::BSON(), Type::BYTE_ARRAY, -1}, - {LogicalType::UUID(), Type::FIXED_LEN_BYTE_ARRAY, 16}, - {LogicalType::None(), Type::BOOLEAN, -1}}; + {LogicalType::interval(), Type::kFixedLenByteArray, 12}, + {LogicalType::intType(8, false), Type::kInt32, -1}, + {LogicalType::intType(16, false), Type::kInt32, -1}, + {LogicalType::intType(32, false), Type::kInt32, -1}, + {LogicalType::intType(64, false), Type::kInt64, -1}, + {LogicalType::intType(8, true), Type::kInt32, -1}, + {LogicalType::intType(16, true), Type::kInt32, -1}, + {LogicalType::intType(32, true), Type::kInt32, -1}, + {LogicalType::intType(64, true), Type::kInt64, -1}, + {LogicalType::nullType(), Type::kBoolean, -1}, + {LogicalType::json(), Type::kByteArray, -1}, + {LogicalType::bson(), Type::kByteArray, -1}, + {LogicalType::uuid(), Type::kFixedLenByteArray, 16}, + {LogicalType::none(), Type::kBoolean, -1}}; for (const AnnotatedPrimitiveNodeFactoryArguments& c : cases) { - ConfirmPrimitiveNodeRoundtrip( - c.logical_type, c.physical_type, c.physical_length); + confirmPrimitiveNodeRoundtrip( + c.logicalType, c.physicalType, c.physicalLength); } // Group nodes ... - ConfirmGroupNodeRoundtrip("map", LogicalType::Map()); - ConfirmGroupNodeRoundtrip("list", LogicalType::List()); + confirmGroupNodeRoundtrip("map", LogicalType::map()); + confirmGroupNodeRoundtrip("list", LogicalType::list()); } } // namespace schema diff --git a/velox/dwio/parquet/writer/arrow/tests/StatisticsTest.cpp b/velox/dwio/parquet/writer/arrow/tests/StatisticsTest.cpp index 85603e04b15..008d95d3df9 100644 --- a/velox/dwio/parquet/writer/arrow/tests/StatisticsTest.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/StatisticsTest.cpp @@ -20,10 +20,12 @@ #include "arrow/testing/builder.h" +#include "velox/common/io/IoStatistics.h" +#include "velox/common/testutil/TempFilePath.h" #include "velox/dwio/parquet/reader/ParquetReader.h" #include "velox/dwio/parquet/writer/arrow/FileWriter.h" +#include "velox/dwio/parquet/writer/arrow/StringTruncation.h" #include "velox/dwio/parquet/writer/arrow/tests/TestUtil.h" -#include "velox/exec/tests/utils/TempFilePath.h" using arrow::default_memory_pool; using arrow::MemoryPool; @@ -32,6 +34,7 @@ using arrow::util::SafeCopy; namespace bit_util = arrow::bit_util; namespace facebook::velox::parquet::arrow { +using namespace facebook::velox::common::testutil; using schema::GroupNode; using schema::NodePtr; @@ -40,8 +43,8 @@ using schema::PrimitiveNode; namespace test { namespace { void writeToFile( - std::shared_ptr filePath, - std::shared_ptr buffer) { + std::shared_ptr filePath, + std::shared_ptr<::arrow::Buffer> buffer) { auto localWriteFile = std::make_unique(filePath->getPath(), false, false); auto bufferReader = std::make_shared<::arrow::io::BufferReader>(buffer); @@ -51,15 +54,15 @@ void writeToFile( } } // namespace -// ---------------------------------------------------------------------- -// Test comparators +// ----------------------------------------------------------------------. +// Test Comparators. -static ByteArray ByteArrayFromString(const std::string& s) { +static ByteArray byteArrayFromString(const std::string& s) { auto ptr = reinterpret_cast(s.data()); return ByteArray(static_cast(s.size()), ptr); } -static FLBA FLBAFromString(const std::string& s) { +static FLBA fLBAFromString(const std::string& s) { auto ptr = reinterpret_cast(s.data()); return FLBA(ptr); } @@ -69,12 +72,12 @@ TEST(Comparison, SignedByteArray) { // decimals are encoded as byte arrays they use twos complement big-endian // encoded values. Comparisons of byte arrays of unequal types need to handle // sign extension. - auto comparator = - MakeComparator(Type::BYTE_ARRAY, SortOrder::SIGNED); + auto Comparator = + makeComparator(Type::kByteArray, SortOrder::kSigned); struct Case { std::vector bytes; int order; - ByteArray ToByteArray() const { + ByteArray toByteArray() const { return ByteArray(static_cast(bytes.size()), bytes.data()); } }; @@ -99,23 +102,23 @@ TEST(Comparison, SignedByteArray) { for (size_t x = 0; x < cases.size(); x++) { const auto& case1 = cases[x]; - // Empty array is always the smallest values - EXPECT_TRUE(comparator->Compare(ByteArray(), case1.ToByteArray())) << x; - EXPECT_FALSE(comparator->Compare(case1.ToByteArray(), ByteArray())) << x; + // Empty array is always the smallest values. + EXPECT_TRUE(Comparator->compare(ByteArray(), case1.toByteArray())) << x; + EXPECT_FALSE(Comparator->compare(case1.toByteArray(), ByteArray())) << x; // Equals is always false. - EXPECT_FALSE(comparator->Compare(case1.ToByteArray(), case1.ToByteArray())) + EXPECT_FALSE(Comparator->compare(case1.toByteArray(), case1.toByteArray())) << x; for (size_t y = 0; y < cases.size(); y++) { const auto& case2 = cases[y]; if (case1.order < case2.order) { EXPECT_TRUE( - comparator->Compare(case1.ToByteArray(), case2.ToByteArray())) + Comparator->compare(case1.toByteArray(), case2.toByteArray())) << x << " (order: " << case1.order << ") " << y << " (order: " << case2.order << ")"; } else { EXPECT_FALSE( - comparator->Compare(case1.ToByteArray(), case2.ToByteArray())) + Comparator->compare(case1.toByteArray(), case2.toByteArray())) << x << " (order: " << case1.order << ") " << y << " (order: " << case2.order << ")"; } @@ -124,36 +127,36 @@ TEST(Comparison, SignedByteArray) { } TEST(Comparison, UnsignedByteArray) { - // Check if UTF-8 is compared using unsigned correctly - auto comparator = - MakeComparator(Type::BYTE_ARRAY, SortOrder::UNSIGNED); + // Check if UTF-8 is compared using unsigned correctly. + auto Comparator = + makeComparator(Type::kByteArray, SortOrder::kUnsigned); std::string s1 = "arrange"; std::string s2 = "arrangement"; - ByteArray s1ba = ByteArrayFromString(s1); - ByteArray s2ba = ByteArrayFromString(s2); - ASSERT_TRUE(comparator->Compare(s1ba, s2ba)); + ByteArray s1ba = byteArrayFromString(s1); + ByteArray s2ba = byteArrayFromString(s2); + ASSERT_TRUE(Comparator->compare(s1ba, s2ba)); - // Multi-byte UTF-8 characters + // Multi-byte UTF-8 characters. s1 = "braten"; s2 = "bügeln"; - s1ba = ByteArrayFromString(s1); - s2ba = ByteArrayFromString(s2); - ASSERT_TRUE(comparator->Compare(s1ba, s2ba)); + s1ba = byteArrayFromString(s1); + s2ba = byteArrayFromString(s2); + ASSERT_TRUE(Comparator->compare(s1ba, s2ba)); s1 = "ünk123456"; // ü = 252 s2 = "ănk123456"; // ă = 259 - s1ba = ByteArrayFromString(s1); - s2ba = ByteArrayFromString(s2); - ASSERT_TRUE(comparator->Compare(s1ba, s2ba)); + s1ba = byteArrayFromString(s1); + s2ba = byteArrayFromString(s2); + ASSERT_TRUE(Comparator->compare(s1ba, s2ba)); } TEST(Comparison, SignedFLBA) { int size = 4; - auto comparator = MakeComparator( - Type::FIXED_LEN_BYTE_ARRAY, SortOrder::SIGNED, size); + auto Comparator = makeComparator( + Type::kFixedLenByteArray, SortOrder::kSigned, size); - std::vector byte_values[] = { + std::vector byteValues[] = { {0x80, 0, 0, 0}, {0xFF, 0xFF, 0x01, 0}, {0xFF, 0xFF, 0x80, 0}, @@ -162,21 +165,18 @@ TEST(Comparison, SignedFLBA) { {0, 0, 0x01, 0x01}, {0, 0x01, 0x01, 0}, {0x01, 0x01, 0, 0}}; - std::vector values_to_compare; - for (auto& bytes : byte_values) { - values_to_compare.emplace_back(FLBA(bytes.data())); + std::vector valuesToCompare; + for (auto& bytes : byteValues) { + valuesToCompare.emplace_back(FLBA(bytes.data())); } - for (size_t x = 0; x < values_to_compare.size(); x++) { - EXPECT_FALSE( - comparator->Compare(values_to_compare[x], values_to_compare[x])) + for (size_t x = 0; x < valuesToCompare.size(); x++) { + EXPECT_FALSE(Comparator->compare(valuesToCompare[x], valuesToCompare[x])) << x; - for (size_t y = x + 1; y < values_to_compare.size(); y++) { - EXPECT_TRUE( - comparator->Compare(values_to_compare[x], values_to_compare[y])) + for (size_t y = x + 1; y < valuesToCompare.size(); y++) { + EXPECT_TRUE(Comparator->compare(valuesToCompare[x], valuesToCompare[y])) << x << " " << y; - EXPECT_FALSE( - comparator->Compare(values_to_compare[y], values_to_compare[x])) + EXPECT_FALSE(Comparator->compare(valuesToCompare[y], valuesToCompare[x])) << y << " " << x; } } @@ -184,20 +184,20 @@ TEST(Comparison, SignedFLBA) { TEST(Comparison, UnsignedFLBA) { int size = 10; - auto comparator = MakeComparator( - Type::FIXED_LEN_BYTE_ARRAY, SortOrder::UNSIGNED, size); + auto Comparator = makeComparator( + Type::kFixedLenByteArray, SortOrder::kUnsigned, size); std::string s1 = "Anti123456"; std::string s2 = "Bunkd123456"; - FLBA s1flba = FLBAFromString(s1); - FLBA s2flba = FLBAFromString(s2); - ASSERT_TRUE(comparator->Compare(s1flba, s2flba)); + FLBA s1flba = fLBAFromString(s1); + FLBA s2flba = fLBAFromString(s2); + ASSERT_TRUE(Comparator->compare(s1flba, s2flba)); s1 = "Bunk123456"; s2 = "Bünk123456"; - s1flba = FLBAFromString(s1); - s2flba = FLBAFromString(s2); - ASSERT_TRUE(comparator->Compare(s1flba, s2flba)); + s1flba = fLBAFromString(s1); + s2flba = fLBAFromString(s2); + ASSERT_TRUE(Comparator->compare(s1flba, s2flba)); } TEST(Comparison, SignedInt96) { @@ -205,11 +205,11 @@ TEST(Comparison, SignedInt96) { Int96 aa{{1, 41, 14}}, bb{{1, 41, 14}}; Int96 aaa{{1, 41, static_cast(-14)}}, bbb{{1, 41, 42}}; - auto comparator = MakeComparator(Type::INT96, SortOrder::SIGNED); + auto Comparator = makeComparator(Type::kInt96, SortOrder::kSigned); - ASSERT_TRUE(comparator->Compare(a, b)); - ASSERT_TRUE(!comparator->Compare(aa, bb) && !comparator->Compare(bb, aa)); - ASSERT_TRUE(comparator->Compare(aaa, bbb)); + ASSERT_TRUE(Comparator->compare(a, b)); + ASSERT_TRUE(!Comparator->compare(aa, bb) && !Comparator->compare(bb, aa)); + ASSERT_TRUE(Comparator->compare(aaa, bbb)); } TEST(Comparison, UnsignedInt96) { @@ -217,35 +217,36 @@ TEST(Comparison, UnsignedInt96) { Int96 aa{{1, 41, 14}}, bb{{1, 41, static_cast(-14)}}; Int96 aaa, bbb; - auto comparator = MakeComparator(Type::INT96, SortOrder::UNSIGNED); + auto Comparator = + makeComparator(Type::kInt96, SortOrder::kUnsigned); - ASSERT_TRUE(comparator->Compare(a, b)); - ASSERT_TRUE(comparator->Compare(aa, bb)); + ASSERT_TRUE(Comparator->compare(a, b)); + ASSERT_TRUE(Comparator->compare(aa, bb)); - // INT96 Timestamp + // INT96 Timestamp. aaa.value[2] = 2451545; // 2000-01-01 bbb.value[2] = 2451546; // 2000-01-02 - // 12 hours + 34 minutes + 56 seconds. - Int96SetNanoSeconds(aaa, 45296000000000); - // 12 hours + 34 minutes + 50 seconds. - Int96SetNanoSeconds(bbb, 45290000000000); - ASSERT_TRUE(comparator->Compare(aaa, bbb)); + // 12 Hours + 34 minutes + 56 seconds. + int96SetNanoSeconds(aaa, 45296000000000); + // 12 Hours + 34 minutes + 50 seconds. + int96SetNanoSeconds(bbb, 45290000000000); + ASSERT_TRUE(Comparator->compare(aaa, bbb)); aaa.value[2] = 2451545; // 2000-01-01 bbb.value[2] = 2451545; // 2000-01-01 - // 11 hours + 34 minutes + 56 seconds. - Int96SetNanoSeconds(aaa, 41696000000000); - // 12 hours + 34 minutes + 50 seconds. - Int96SetNanoSeconds(bbb, 45290000000000); - ASSERT_TRUE(comparator->Compare(aaa, bbb)); + // 11 Hours + 34 minutes + 56 seconds. + int96SetNanoSeconds(aaa, 41696000000000); + // 12 Hours + 34 minutes + 50 seconds. + int96SetNanoSeconds(bbb, 45290000000000); + ASSERT_TRUE(Comparator->compare(aaa, bbb)); aaa.value[2] = 2451545; // 2000-01-01 bbb.value[2] = 2451545; // 2000-01-01 - // 12 hours + 34 minutes + 55 seconds. - Int96SetNanoSeconds(aaa, 45295000000000); - // 12 hours + 34 minutes + 56 seconds. - Int96SetNanoSeconds(bbb, 45296000000000); - ASSERT_TRUE(comparator->Compare(aaa, bbb)); + // 12 Hours + 34 minutes + 55 seconds. + int96SetNanoSeconds(aaa, 45295000000000); + // 12 Hours + 34 minutes + 56 seconds. + int96SetNanoSeconds(bbb, 45296000000000); + ASSERT_TRUE(Comparator->compare(aaa, bbb)); } TEST(Comparison, SignedInt64) { @@ -253,15 +254,15 @@ TEST(Comparison, SignedInt64) { int64_t aa = 1, bb = 1; int64_t aaa = -1, bbb = 1; - NodePtr node = - PrimitiveNode::Make("SignedInt64", Repetition::REQUIRED, Type::INT64); - ColumnDescriptor descr(node, 0, 0); + NodePtr Node = + PrimitiveNode::make("SignedInt64", Repetition::kRequired, Type::kInt64); + ColumnDescriptor descr(Node, 0, 0); - auto comparator = MakeComparator(&descr); + auto Comparator = makeComparator(&descr); - ASSERT_TRUE(comparator->Compare(a, b)); - ASSERT_TRUE(!comparator->Compare(aa, bb) && !comparator->Compare(bb, aa)); - ASSERT_TRUE(comparator->Compare(aaa, bbb)); + ASSERT_TRUE(Comparator->compare(a, b)); + ASSERT_TRUE(!Comparator->compare(aa, bb) && !Comparator->compare(bb, aa)); + ASSERT_TRUE(Comparator->compare(aaa, bbb)); } TEST(Comparison, UnsignedInt64) { @@ -269,19 +270,19 @@ TEST(Comparison, UnsignedInt64) { uint64_t aa = 1, bb = 1; uint64_t aaa = 1, bbb = -1; - NodePtr node = PrimitiveNode::Make( + NodePtr Node = PrimitiveNode::make( "UnsignedInt64", - Repetition::REQUIRED, - Type::INT64, - ConvertedType::UINT_64); - ColumnDescriptor descr(node, 0, 0); + Repetition::kRequired, + Type::kInt64, + ConvertedType::kUint64); + ColumnDescriptor descr(Node, 0, 0); - ASSERT_EQ(SortOrder::UNSIGNED, descr.sort_order()); - auto comparator = MakeComparator(&descr); + ASSERT_EQ(SortOrder::kUnsigned, descr.sortOrder()); + auto Comparator = makeComparator(&descr); - ASSERT_TRUE(comparator->Compare(a, b)); - ASSERT_TRUE(!comparator->Compare(aa, bb) && !comparator->Compare(bb, aa)); - ASSERT_TRUE(comparator->Compare(aaa, bbb)); + ASSERT_TRUE(Comparator->compare(a, b)); + ASSERT_TRUE(!Comparator->compare(aa, bb) && !Comparator->compare(bb, aa)); + ASSERT_TRUE(Comparator->compare(aaa, bbb)); } TEST(Comparison, UnsignedInt32) { @@ -289,174 +290,174 @@ TEST(Comparison, UnsignedInt32) { uint32_t aa = 1, bb = 1; uint32_t aaa = 1, bbb = -1; - NodePtr node = PrimitiveNode::Make( + NodePtr Node = PrimitiveNode::make( "UnsignedInt32", - Repetition::REQUIRED, - Type::INT32, - ConvertedType::UINT_32); - ColumnDescriptor descr(node, 0, 0); + Repetition::kRequired, + Type::kInt32, + ConvertedType::kUint32); + ColumnDescriptor descr(Node, 0, 0); - ASSERT_EQ(SortOrder::UNSIGNED, descr.sort_order()); - auto comparator = MakeComparator(&descr); + ASSERT_EQ(SortOrder::kUnsigned, descr.sortOrder()); + auto Comparator = makeComparator(&descr); - ASSERT_TRUE(comparator->Compare(a, b)); - ASSERT_TRUE(!comparator->Compare(aa, bb) && !comparator->Compare(bb, aa)); - ASSERT_TRUE(comparator->Compare(aaa, bbb)); + ASSERT_TRUE(Comparator->compare(a, b)); + ASSERT_TRUE(!Comparator->compare(aa, bb) && !Comparator->compare(bb, aa)); + ASSERT_TRUE(Comparator->compare(aaa, bbb)); } TEST(Comparison, UnknownSortOrder) { - NodePtr node = PrimitiveNode::Make( + NodePtr Node = PrimitiveNode::make( "Unknown", - Repetition::REQUIRED, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::INTERVAL, + Repetition::kRequired, + Type::kFixedLenByteArray, + ConvertedType::kInterval, 12); - ColumnDescriptor descr(node, 0, 0); + ColumnDescriptor descr(Node, 0, 0); - ASSERT_THROW(Comparator::Make(&descr), ParquetException); + ASSERT_THROW(Comparator::make(&descr), ParquetException); } -// ---------------------------------------------------------------------- +// ----------------------------------------------------------------------. template class TestStatistics : public PrimitiveTypedTest { public: - using c_type = typename TestType::c_type; + using CType = typename TestType::CType; - std::vector GetDeepCopy( - const std::vector&); // allocates new memory for FLBA/ByteArray + std::vector getDeepCopy( + const std::vector&); // allocates new memory for FLBA/ByteArray - c_type* GetValuesPointer(std::vector&); - void DeepFree(std::vector&); + CType* getValuesPointer(std::vector&); + void deepFree(std::vector&); - void TestMinMaxEncode() { - this->GenerateData(1000); + void testMinMaxEncode() { + this->generateData(1000); - auto statistics1 = MakeStatistics(this->schema_.Column(0)); - statistics1->Update(this->values_ptr_, this->values_.size(), 0); - std::string encoded_min = statistics1->EncodeMin(); - std::string encoded_max = statistics1->EncodeMax(); + auto statistics1 = makeStatistics(this->schema_.column(0)); + statistics1->update(this->valuesPtr_, this->values_.size(), 0); + std::string encodedMin = statistics1->encodeMin(); + std::string encodedMax = statistics1->encodeMax(); - auto statistics2 = MakeStatistics( - this->schema_.Column(0), - encoded_min, - encoded_max, + auto statistics2 = makeStatistics( + this->schema_.column(0), + encodedMin, + encodedMax, this->values_.size(), - 0, - 0, - true, - true, - true); - - auto statistics3 = MakeStatistics(this->schema_.Column(0)); - std::vector valid_bits( + 0, // nullCount. + 0, // distinctCount. + true, // hasMinMax. + true, // hasNullCount. + true, // hasDistinctCount. + false, // hasNaNCount. + 0); // nanCount. + + auto statistics3 = makeStatistics(this->schema_.column(0)); + std::vector validBits( ::arrow::bit_util::BytesForBits( static_cast(this->values_.size())) + 1, 255); - statistics3->UpdateSpaced( - this->values_ptr_, - valid_bits.data(), + statistics3->updateSpaced( + this->valuesPtr_, + validBits.data(), 0, this->values_.size(), this->values_.size(), 0); - std::string encoded_min_spaced = statistics3->EncodeMin(); - std::string encoded_max_spaced = statistics3->EncodeMax(); + std::string encodedMinSpaced = statistics3->encodeMin(); + std::string encodedMaxSpaced = statistics3->encodeMax(); - ASSERT_EQ(encoded_min, statistics2->EncodeMin()); - ASSERT_EQ(encoded_max, statistics2->EncodeMax()); + ASSERT_EQ(encodedMin, statistics2->encodeMin()); + ASSERT_EQ(encodedMax, statistics2->encodeMax()); ASSERT_EQ(statistics1->min(), statistics2->min()); ASSERT_EQ(statistics1->max(), statistics2->max()); - ASSERT_EQ(encoded_min_spaced, statistics2->EncodeMin()); - ASSERT_EQ(encoded_max_spaced, statistics2->EncodeMax()); + ASSERT_EQ(encodedMinSpaced, statistics2->encodeMin()); + ASSERT_EQ(encodedMaxSpaced, statistics2->encodeMax()); ASSERT_EQ(statistics3->min(), statistics2->min()); ASSERT_EQ(statistics3->max(), statistics2->max()); } - void TestReset() { - this->GenerateData(1000); - - auto statistics = MakeStatistics(this->schema_.Column(0)); - statistics->Update(this->values_ptr_, this->values_.size(), 0); - ASSERT_EQ(this->values_.size(), statistics->num_values()); - - statistics->Reset(); - ASSERT_TRUE(statistics->HasNullCount()); - ASSERT_FALSE(statistics->HasMinMax()); - ASSERT_FALSE(statistics->HasDistinctCount()); - ASSERT_EQ(0, statistics->null_count()); - ASSERT_EQ(0, statistics->num_values()); - ASSERT_EQ(0, statistics->distinct_count()); - ASSERT_EQ("", statistics->EncodeMin()); - ASSERT_EQ("", statistics->EncodeMax()); + void testReset() { + this->generateData(1000); + + auto Statistics = makeStatistics(this->schema_.column(0)); + Statistics->update(this->valuesPtr_, this->values_.size(), 0); + ASSERT_EQ(this->values_.size(), Statistics->numValues()); + + Statistics->reset(); + ASSERT_TRUE(Statistics->hasNullCount()); + ASSERT_FALSE(Statistics->hasMinMax()); + ASSERT_FALSE(Statistics->hasDistinctCount()); + ASSERT_EQ(0, Statistics->nullCount()); + ASSERT_EQ(0, Statistics->numValues()); + ASSERT_EQ(0, Statistics->distinctCount()); + ASSERT_EQ("", Statistics->encodeMin()); + ASSERT_EQ("", Statistics->encodeMax()); } - void TestMerge() { - int num_null[2]; - random_numbers(2, 42, 0, 100, num_null); + void testMerge() { + int numNull[2]; + randomNumbers(2, 42, 0, 100, numNull); - auto statistics1 = MakeStatistics(this->schema_.Column(0)); - this->GenerateData(1000); - statistics1->Update( - this->values_ptr_, this->values_.size() - num_null[0], num_null[0]); + auto statistics1 = makeStatistics(this->schema_.column(0)); + this->generateData(1000); + statistics1->update( + this->valuesPtr_, this->values_.size() - numNull[0], numNull[0]); - auto statistics2 = MakeStatistics(this->schema_.Column(0)); - this->GenerateData(1000); - statistics2->Update( - this->values_ptr_, this->values_.size() - num_null[1], num_null[1]); + auto statistics2 = makeStatistics(this->schema_.column(0)); + this->generateData(1000); + statistics2->update( + this->valuesPtr_, this->values_.size() - numNull[1], numNull[1]); - auto total = MakeStatistics(this->schema_.Column(0)); - total->Merge(*statistics1); - total->Merge(*statistics2); + auto total = makeStatistics(this->schema_.column(0)); + total->merge(*statistics1); + total->merge(*statistics2); - ASSERT_EQ(num_null[0] + num_null[1], total->null_count()); + ASSERT_EQ(numNull[0] + numNull[1], total->nullCount()); ASSERT_EQ( - this->values_.size() * 2 - num_null[0] - num_null[1], - total->num_values()); + this->values_.size() * 2 - numNull[0] - numNull[1], total->numValues()); ASSERT_EQ(total->min(), std::min(statistics1->min(), statistics2->min())); ASSERT_EQ(total->max(), std::max(statistics1->max(), statistics2->max())); } - void TestEquals() { - const auto n_values = 1; - auto statistics_have_minmax1 = - MakeStatistics(this->schema_.Column(0)); + void testEquals() { + const auto nValues = 1; + auto statisticsHaveMinmax1 = + makeStatistics(this->schema_.column(0)); const auto seed1 = 1; - this->GenerateData(n_values, seed1); - statistics_have_minmax1->Update(this->values_ptr_, this->values_.size(), 0); - auto statistics_have_minmax2 = - MakeStatistics(this->schema_.Column(0)); + this->generateData(nValues, seed1); + statisticsHaveMinmax1->update(this->valuesPtr_, this->values_.size(), 0); + auto statisticsHaveMinmax2 = + makeStatistics(this->schema_.column(0)); const auto seed2 = 9999; - this->GenerateData(n_values, seed2); - statistics_have_minmax2->Update(this->values_ptr_, this->values_.size(), 0); - auto statistics_no_minmax = - MakeStatistics(this->schema_.Column(0)); - - ASSERT_EQ(true, statistics_have_minmax1->Equals(*statistics_have_minmax1)); - ASSERT_EQ(true, statistics_no_minmax->Equals(*statistics_no_minmax)); - ASSERT_EQ(false, statistics_have_minmax1->Equals(*statistics_have_minmax2)); - ASSERT_EQ(false, statistics_have_minmax1->Equals(*statistics_no_minmax)); + this->generateData(nValues, seed2); + statisticsHaveMinmax2->update(this->valuesPtr_, this->values_.size(), 0); + auto statisticsNoMinmax = makeStatistics(this->schema_.column(0)); + + ASSERT_EQ(true, statisticsHaveMinmax1->equals(*statisticsHaveMinmax1)); + ASSERT_EQ(true, statisticsNoMinmax->equals(*statisticsNoMinmax)); + ASSERT_EQ(false, statisticsHaveMinmax1->equals(*statisticsHaveMinmax2)); + ASSERT_EQ(false, statisticsHaveMinmax1->equals(*statisticsNoMinmax)); } - void TestFullRoundtrip(int64_t numValues, int64_t nullCount) { - this->GenerateData(numValues); + void testFullRoundtrip(int64_t numValues, int64_t nullCount) { + this->generateData(numValues); - // compute statistics for the whole batch - auto expectedStats = MakeStatistics(this->schema_.Column(0)); - expectedStats->Update(this->values_ptr_, numValues - nullCount, nullCount); + // Compute statistics for the whole batch. + auto expectedStats = makeStatistics(this->schema_.column(0)); + expectedStats->update(this->valuesPtr_, numValues - nullCount, nullCount); - auto sink = CreateOutputStream(); + auto sink = createOutputStream(); auto gnode = std::static_pointer_cast(this->node_); - std::shared_ptr writerProperties = - WriterProperties::Builder().enable_statistics("column")->build(); - auto fileWriter = ParquetFileWriter::Open(sink, gnode, writerProperties); - auto rowGroupWriter = fileWriter->AppendRowGroup(); + std::shared_ptr WriterProperties = + WriterProperties::Builder().enableStatistics("column")->build(); + auto fileWriter = ParquetFileWriter::open(sink, gnode, WriterProperties); + auto rowGroupWriter = fileWriter->appendRowGroup(); auto columnWriter = - static_cast*>(rowGroupWriter->NextColumn()); + static_cast*>(rowGroupWriter->nextColumn()); - // simulate the case when data comes from multiple buffers, - // in which case special care is necessary for FLBA/ByteArray types + // Simulate the case when data comes from multiple buffers, + // in which case special care is necessary for FLBA/ByteArray types. for (int i = 0; i < 2; i++) { int64_t batchNumValues = i ? numValues - numValues / 2 : numValues / 2; int64_t batchNullCount = i ? nullCount : 0; @@ -464,29 +465,33 @@ class TestStatistics : public PrimitiveTypedTest { std::vector definitionLevels(batchNullCount, 0); definitionLevels.insert( definitionLevels.end(), batchNumValues - batchNullCount, 1); - auto beg = this->values_.begin() + i * numValues / 2; + auto beg = this->values_.cbegin() + i * numValues / 2; auto end = beg + batchNumValues; - std::vector batch = GetDeepCopy(std::vector(beg, end)); - c_type* batchValuesPtr = GetValuesPointer(batch); - columnWriter->WriteBatch( + std::vector batch = getDeepCopy(std::vector(beg, end)); + CType* batchValuesPtr = getValuesPointer(batch); + columnWriter->writeBatch( batchNumValues, definitionLevels.data(), nullptr, batchValuesPtr); - DeepFree(batch); + deepFree(batch); } - columnWriter->Close(); - rowGroupWriter->Close(); - fileWriter->Close(); + columnWriter->close(); + rowGroupWriter->close(); + fileWriter->close(); ASSERT_OK_AND_ASSIGN(auto buffer, sink->Finish()); - // Write the buffer to a temp file - auto filePath = exec::test::TempFilePath::create(); + // Write the buffer to a temp file. + auto filePath = TempFilePath::create(); writeToFile(filePath, buffer); memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); std::shared_ptr rootPool = memory::memoryManager()->addRootPool("StatisticsTest"); std::shared_ptr leafPool = rootPool->addLeafChild("StatisticsTest"); - dwio::common::ReaderOptions readerOptions{leafPool.get()}; + auto dataIoStats = std::make_shared(); + auto metadataIoStats = std::make_shared(); + dwio::common::ReaderOptions readerOptions(leafPool.get()); + readerOptions.setDataIoStats(dataIoStats); + readerOptions.setMetadataIoStats(metadataIoStats); auto input = std::make_unique( std::make_shared(filePath->getPath()), readerOptions.memoryPool()); @@ -495,12 +500,12 @@ class TestStatistics : public PrimitiveTypedTest { auto rowGroup = reader->fileMetaData().rowGroup(0); auto columnChunk = rowGroup.columnChunk(0); EXPECT_EQ(nullCount, columnChunk.getColumnMetadataStatsNullCount()); - EXPECT_TRUE(expectedStats->HasMinMax()); + EXPECT_TRUE(expectedStats->hasMinMax()); EXPECT_EQ( - expectedStats->EncodeMin(), + expectedStats->encodeMin(), columnChunk.getColumnMetadataStatsMinValue()); EXPECT_EQ( - expectedStats->EncodeMax(), + expectedStats->encodeMax(), columnChunk.getColumnMetadataStatsMaxValue()); auto columnStats = columnChunk.getColumnStatistics(INTEGER(), rowGroup.numRows()); @@ -509,29 +514,29 @@ class TestStatistics : public PrimitiveTypedTest { }; template -typename TestType::c_type* TestStatistics::GetValuesPointer( - std::vector& values) { +typename TestType::CType* TestStatistics::getValuesPointer( + std::vector& values) { return values.data(); } template <> -bool* TestStatistics::GetValuesPointer(std::vector& values) { - static std::vector bool_buffer; - bool_buffer.clear(); - bool_buffer.resize(values.size()); - std::copy(values.begin(), values.end(), bool_buffer.begin()); - return reinterpret_cast(bool_buffer.data()); +bool* TestStatistics::getValuesPointer(std::vector& values) { + static std::vector boolBuffer; + boolBuffer.clear(); + boolBuffer.resize(values.size()); + std::copy(values.begin(), values.end(), boolBuffer.begin()); + return reinterpret_cast(boolBuffer.data()); } template -typename std::vector -TestStatistics::GetDeepCopy( - const std::vector& values) { +typename std::vector +TestStatistics::getDeepCopy( + const std::vector& values) { return values; } template <> -std::vector TestStatistics::GetDeepCopy( +std::vector TestStatistics::getDeepCopy( const std::vector& values) { std::vector copy; MemoryPool* pool = ::arrow::default_memory_pool(); @@ -545,7 +550,7 @@ std::vector TestStatistics::GetDeepCopy( } template <> -std::vector TestStatistics::GetDeepCopy( +std::vector TestStatistics::getDeepCopy( const std::vector& values) { std::vector copy; MemoryPool* pool = default_memory_pool(); @@ -559,11 +564,11 @@ std::vector TestStatistics::GetDeepCopy( } template -void TestStatistics::DeepFree( - std::vector& values) {} +void TestStatistics::deepFree( + std::vector& values) {} template <> -void TestStatistics::DeepFree(std::vector& values) { +void TestStatistics::deepFree(std::vector& values) { MemoryPool* pool = default_memory_pool(); for (FLBA& flba : values) { auto ptr = const_cast(flba.ptr); @@ -573,7 +578,7 @@ void TestStatistics::DeepFree(std::vector& values) { } template <> -void TestStatistics::DeepFree(std::vector& values) { +void TestStatistics::deepFree(std::vector& values) { MemoryPool* pool = default_memory_pool(); for (ByteArray& ba : values) { auto ptr = const_cast(ba.ptr); @@ -583,39 +588,41 @@ void TestStatistics::DeepFree(std::vector& values) { } template <> -void TestStatistics::TestMinMaxEncode() { - this->GenerateData(1000); - // Test that we encode min max strings correctly - auto statistics1 = MakeStatistics(this->schema_.Column(0)); - statistics1->Update(this->values_ptr_, this->values_.size(), 0); - std::string encoded_min = statistics1->EncodeMin(); - std::string encoded_max = statistics1->EncodeMax(); - - // encoded is same as unencoded +void TestStatistics::testMinMaxEncode() { + this->generateData(1000); + // Test that we encode min max strings correctly. + auto statistics1 = makeStatistics(this->schema_.column(0)); + statistics1->update(this->valuesPtr_, this->values_.size(), 0); + std::string encodedMin = statistics1->encodeMin(); + std::string encodedMax = statistics1->encodeMax(); + + // Encoded is same as unencoded. ASSERT_EQ( - encoded_min, + encodedMin, std::string( reinterpret_cast(statistics1->min().ptr), statistics1->min().len)); ASSERT_EQ( - encoded_max, + encodedMax, std::string( reinterpret_cast(statistics1->max().ptr), statistics1->max().len)); - auto statistics2 = MakeStatistics( - this->schema_.Column(0), - encoded_min, - encoded_max, + auto statistics2 = makeStatistics( + this->schema_.column(0), + encodedMin, + encodedMax, this->values_.size(), - 0, - 0, - true, - true, - true); - - ASSERT_EQ(encoded_min, statistics2->EncodeMin()); - ASSERT_EQ(encoded_max, statistics2->EncodeMax()); + 0, // nullCount + 0, // distinctCount + true, // hasMinMax + true, // hasNullCount + true, // hasDistinctCount + false, // hasNaNCount + 0); // nanCount + + ASSERT_EQ(encodedMin, statistics2->encodeMin()); + ASSERT_EQ(encodedMax, statistics2->encodeMax()); ASSERT_EQ(statistics1->min(), statistics2->min()); ASSERT_EQ(statistics1->max(), statistics2->max()); } @@ -632,25 +639,25 @@ using Types = ::testing::Types< TYPED_TEST_SUITE(TestStatistics, Types); TYPED_TEST(TestStatistics, MinMaxEncode) { - this->SetUpSchema(Repetition::REQUIRED); - ASSERT_NO_FATAL_FAILURE(this->TestMinMaxEncode()); + this->setUpSchema(Repetition::kRequired); + ASSERT_NO_FATAL_FAILURE(this->testMinMaxEncode()); } -TYPED_TEST(TestStatistics, Reset) { - this->SetUpSchema(Repetition::OPTIONAL); - ASSERT_NO_FATAL_FAILURE(this->TestReset()); +TYPED_TEST(TestStatistics, reset) { + this->setUpSchema(Repetition::kOptional); + ASSERT_NO_FATAL_FAILURE(this->testReset()); } -TYPED_TEST(TestStatistics, Equals) { - this->SetUpSchema(Repetition::OPTIONAL); - ASSERT_NO_FATAL_FAILURE(this->TestEquals()); +TYPED_TEST(TestStatistics, equals) { + this->setUpSchema(Repetition::kOptional); + ASSERT_NO_FATAL_FAILURE(this->testEquals()); } TYPED_TEST(TestStatistics, FullRoundtrip) { - this->SetUpSchema(Repetition::OPTIONAL); - ASSERT_NO_FATAL_FAILURE(this->TestFullRoundtrip(100, 31)); - ASSERT_NO_FATAL_FAILURE(this->TestFullRoundtrip(1000, 415)); - ASSERT_NO_FATAL_FAILURE(this->TestFullRoundtrip(10000, 926)); + this->setUpSchema(Repetition::kOptional); + ASSERT_NO_FATAL_FAILURE(this->testFullRoundtrip(100, 31)); + ASSERT_NO_FATAL_FAILURE(this->testFullRoundtrip(1000, 415)); + ASSERT_NO_FATAL_FAILURE(this->testFullRoundtrip(10000, 926)); } template @@ -661,113 +668,103 @@ using NumericTypes = TYPED_TEST_SUITE(TestNumericStatistics, NumericTypes); -TYPED_TEST(TestNumericStatistics, Merge) { - this->SetUpSchema(Repetition::OPTIONAL); - ASSERT_NO_FATAL_FAILURE(this->TestMerge()); +TYPED_TEST(TestNumericStatistics, merge) { + this->setUpSchema(Repetition::kOptional); + ASSERT_NO_FATAL_FAILURE(this->testMerge()); } -TYPED_TEST(TestNumericStatistics, Equals) { - this->SetUpSchema(Repetition::OPTIONAL); - ASSERT_NO_FATAL_FAILURE(this->TestEquals()); +TYPED_TEST(TestNumericStatistics, equals) { + this->setUpSchema(Repetition::kOptional); + ASSERT_NO_FATAL_FAILURE(this->testEquals()); } template class TestStatisticsHasFlag : public TestStatistics { public: void SetUp() override { - TestStatistics::SetUp(); - this->SetUpSchema(Repetition::OPTIONAL); + this->setUpSchema(Repetition::kOptional); } - std::shared_ptr> MergedStatistics( + std::shared_ptr> mergedStatistics( const TypedStatistics& stats1, const TypedStatistics& stats2) { - auto chunk_statistics = MakeStatistics(this->schema_.Column(0)); - chunk_statistics->Merge(stats1); - chunk_statistics->Merge(stats2); - return chunk_statistics; + auto chunkStatistics = makeStatistics(this->schema_.column(0)); + chunkStatistics->merge(stats1); + chunkStatistics->merge(stats2); + return chunkStatistics; } - void VerifyMergedStatistics( + void verifyMergedStatistics( const TypedStatistics& stats1, const TypedStatistics& stats2, - const std::function*)>& test_fn) { - ASSERT_NO_FATAL_FAILURE(test_fn(MergedStatistics(stats1, stats2).get())); - ASSERT_NO_FATAL_FAILURE(test_fn(MergedStatistics(stats2, stats1).get())); + const std::function*)>& testFn) { + ASSERT_NO_FATAL_FAILURE(testFn(mergedStatistics(stats1, stats2).get())); + ASSERT_NO_FATAL_FAILURE(testFn(mergedStatistics(stats2, stats1).get())); } // Distinct count should set to false when Merge is called. - void TestMergeDistinctCount() { + void testMergeDistinctCount() { // Create a statistics object with distinct count. std::shared_ptr> statistics1; { - EncodedStatistics encoded_statistics1; - statistics1 = - std::dynamic_pointer_cast>(Statistics::Make( - this->schema_.Column(0), - &encoded_statistics1, - /*num_values=*/1000)); - EXPECT_FALSE(statistics1->HasDistinctCount()); + EncodedStatistics encodedStatistics1; + statistics1 = std::dynamic_pointer_cast>( + Statistics::make(this->schema_.column(0), &encodedStatistics1, 1000)); + EXPECT_FALSE(statistics1->hasDistinctCount()); } // Create a statistics object with distinct count. std::shared_ptr> statistics2; { - EncodedStatistics encoded_statistics2; - encoded_statistics2.has_distinct_count = true; - encoded_statistics2.distinct_count = 500; - statistics2 = - std::dynamic_pointer_cast>(Statistics::Make( - this->schema_.Column(0), - &encoded_statistics2, - /*num_values=*/1000)); - EXPECT_TRUE(statistics2->HasDistinctCount()); + EncodedStatistics encodedStatistics2; + encodedStatistics2.hasDistinctCount = true; + encodedStatistics2.distinctCount = 500; + statistics2 = std::dynamic_pointer_cast>( + Statistics::make(this->schema_.column(0), &encodedStatistics2, 1000)); + EXPECT_TRUE(statistics2->hasDistinctCount()); } - VerifyMergedStatistics( + verifyMergedStatistics( *statistics1, *statistics2, - [](TypedStatistics* merged_statistics) { - EXPECT_FALSE(merged_statistics->HasDistinctCount()); - EXPECT_FALSE(merged_statistics->Encode().has_distinct_count); + [](TypedStatistics* mergedStatistics) { + EXPECT_FALSE(mergedStatistics->hasDistinctCount()); + EXPECT_FALSE(mergedStatistics->encode().hasDistinctCount); }); } // If all values in a page are null or nan, its stats should not set min-max. // Merging its stats with another page having good min-max stats should not // drop the valid min-max from the latter page. - void TestMergeMinMax() { - this->GenerateData(1000); + void testMergeMinMax() { + this->generateData(1000); // Create a statistics object without min-max. std::shared_ptr> statistics1; { - statistics1 = MakeStatistics(this->schema_.Column(0)); - statistics1->Update( - this->values_ptr_, - /*num_values=*/0, - /*null_count=*/this->values_.size()); - auto encoded_stats1 = statistics1->Encode(); - EXPECT_FALSE(statistics1->HasMinMax()); - EXPECT_FALSE(encoded_stats1.has_min); - EXPECT_FALSE(encoded_stats1.has_max); + statistics1 = makeStatistics(this->schema_.column(0)); + statistics1->update(this->valuesPtr_, 0, this->values_.size()); + auto encodedStats1 = statistics1->encode(); + EXPECT_FALSE(statistics1->hasMinMax()); + EXPECT_FALSE(encodedStats1.hasMin); + EXPECT_FALSE(encodedStats1.hasMax); } // Create a statistics object with min-max. std::shared_ptr> statistics2; { - statistics2 = MakeStatistics(this->schema_.Column(0)); - statistics2->Update(this->values_ptr_, this->values_.size(), 0); - auto encoded_stats2 = statistics2->Encode(); - EXPECT_TRUE(statistics2->HasMinMax()); - EXPECT_TRUE(encoded_stats2.has_min); - EXPECT_TRUE(encoded_stats2.has_max); + statistics2 = makeStatistics(this->schema_.column(0)); + statistics2->update(this->valuesPtr_, this->values_.size(), 0); + auto encodedStats2 = statistics2->encode(); + EXPECT_TRUE(statistics2->hasMinMax()); + EXPECT_TRUE(encodedStats2.hasMin); + EXPECT_TRUE(encodedStats2.hasMax); } - VerifyMergedStatistics( + verifyMergedStatistics( *statistics1, *statistics2, - [](TypedStatistics* merged_statistics) { - EXPECT_TRUE(merged_statistics->HasMinMax()); - EXPECT_TRUE(merged_statistics->Encode().has_min); - EXPECT_TRUE(merged_statistics->Encode().has_max); + [](TypedStatistics* mergedStatistics) { + EXPECT_TRUE(mergedStatistics->hasMinMax()); + EXPECT_TRUE(mergedStatistics->encode().hasMin); + EXPECT_TRUE(mergedStatistics->encode().hasMax); }); } @@ -775,254 +772,275 @@ class TestStatisticsHasFlag : public TestStatistics { // However, if statistics is created from thrift message, it might not // have null_count. Merging statistics from such page will result in an // invalid null_count as well. - void TestMergeNullCount() { - this->GenerateData(/*num_values=*/1000); + void testMergeNullCount() { + this->generateData(1000); - // Page should have null-count even if no nulls + // Page should have null-count even if no nulls. std::shared_ptr> statistics1; { - statistics1 = MakeStatistics(this->schema_.Column(0)); - statistics1->Update( - this->values_ptr_, - /*num_values=*/this->values_.size(), - /*null_count=*/0); - auto encoded_stats1 = statistics1->Encode(); - EXPECT_TRUE(statistics1->HasNullCount()); - EXPECT_EQ(0, statistics1->null_count()); - EXPECT_TRUE(statistics1->Encode().has_null_count); + statistics1 = makeStatistics(this->schema_.column(0)); + statistics1->update(this->valuesPtr_, this->values_.size(), 0); + auto encodedStats1 = statistics1->encode(); + EXPECT_TRUE(statistics1->hasNullCount()); + EXPECT_EQ(0, statistics1->nullCount()); + EXPECT_TRUE(statistics1->encode().hasNullCount); } - // Merge with null-count should also have null count - VerifyMergedStatistics( + // Merge with null-count should also have null count. + verifyMergedStatistics( *statistics1, *statistics1, - [](TypedStatistics* merged_statistics) { - EXPECT_TRUE(merged_statistics->HasNullCount()); - EXPECT_EQ(0, merged_statistics->null_count()); - auto encoded = merged_statistics->Encode(); - EXPECT_TRUE(encoded.has_null_count); - EXPECT_EQ(0, encoded.null_count); + [](TypedStatistics* mergedStatistics) { + EXPECT_TRUE(mergedStatistics->hasNullCount()); + EXPECT_EQ(0, mergedStatistics->nullCount()); + auto encoded = mergedStatistics->encode(); + EXPECT_TRUE(encoded.hasNullCount); + EXPECT_EQ(0, encoded.nullCount); }); // When loaded from thrift, might not have null count. std::shared_ptr> statistics2; { - EncodedStatistics encoded_statistics2; - encoded_statistics2.has_null_count = false; - statistics2 = - std::dynamic_pointer_cast>(Statistics::Make( - this->schema_.Column(0), - &encoded_statistics2, - /*num_values=*/1000)); - EXPECT_FALSE(statistics2->Encode().has_null_count); - EXPECT_FALSE(statistics2->HasNullCount()); + EncodedStatistics encodedStatistics2; + encodedStatistics2.hasNullCount = false; + statistics2 = std::dynamic_pointer_cast>( + Statistics::make(this->schema_.column(0), &encodedStatistics2, 1000)); + EXPECT_FALSE(statistics2->encode().hasNullCount); + EXPECT_FALSE(statistics2->hasNullCount()); } - // Merge without null-count should not have null count - VerifyMergedStatistics( + // Merge without null-count should not have null count. + verifyMergedStatistics( *statistics1, *statistics2, - [](TypedStatistics* merged_statistics) { - EXPECT_FALSE(merged_statistics->HasNullCount()); - EXPECT_FALSE(merged_statistics->Encode().has_null_count); + [](TypedStatistics* mergedStatistics) { + EXPECT_FALSE(mergedStatistics->hasNullCount()); + EXPECT_FALSE(mergedStatistics->encode().hasNullCount); }); } - // statistics.all_null_value is used to build the page index. + // Statistics.all_null_value is used to build the page index. // If statistics doesn't have null count, all_null_value should be false. - void TestMissingNullCount() { - EncodedStatistics encoded_statistics; - encoded_statistics.has_null_count = false; - auto statistics = Statistics::Make( - this->schema_.Column(0), - &encoded_statistics, - /*num_values=*/1000); - auto typed_stats = - std::dynamic_pointer_cast>(statistics); - EXPECT_FALSE(typed_stats->HasNullCount()); - auto encoded = typed_stats->Encode(); - EXPECT_FALSE(encoded.all_null_value); - EXPECT_FALSE(encoded.has_null_count); - EXPECT_FALSE(encoded.has_distinct_count); - EXPECT_FALSE(encoded.has_min); - EXPECT_FALSE(encoded.has_max); + void testMissingNullCount() { + EncodedStatistics EncodedStatistics; + EncodedStatistics.hasNullCount = false; + auto Statistics = + Statistics::make(this->schema_.column(0), &EncodedStatistics, 1000); + auto typedStats = + std::dynamic_pointer_cast>(Statistics); + EXPECT_FALSE(typedStats->hasNullCount()); + auto encoded = typedStats->encode(); + EXPECT_FALSE(encoded.allNullValue); + EXPECT_FALSE(encoded.hasNullCount); + EXPECT_FALSE(encoded.hasDistinctCount); + EXPECT_FALSE(encoded.hasMin); + EXPECT_FALSE(encoded.hasMax); } }; TYPED_TEST_SUITE(TestStatisticsHasFlag, Types); TYPED_TEST(TestStatisticsHasFlag, MergeDistinctCount) { - ASSERT_NO_FATAL_FAILURE(this->TestMergeDistinctCount()); + ASSERT_NO_FATAL_FAILURE(this->testMergeDistinctCount()); } TYPED_TEST(TestStatisticsHasFlag, MergeNullCount) { - ASSERT_NO_FATAL_FAILURE(this->TestMergeNullCount()); + ASSERT_NO_FATAL_FAILURE(this->testMergeNullCount()); } TYPED_TEST(TestStatisticsHasFlag, MergeMinMax) { - ASSERT_NO_FATAL_FAILURE(this->TestMergeMinMax()); + ASSERT_NO_FATAL_FAILURE(this->testMergeMinMax()); } TYPED_TEST(TestStatisticsHasFlag, MissingNullCount) { - ASSERT_NO_FATAL_FAILURE(this->TestMissingNullCount()); + ASSERT_NO_FATAL_FAILURE(this->testMissingNullCount()); } -// Helper for basic statistics tests below -void AssertStatsSet( +// Helper for basic statistics tests below. +void assertStatsSet( const ApplicationVersion& version, std::shared_ptr props, const ColumnDescriptor* column, - bool expected_is_set) { - auto metadata_builder = ColumnChunkMetaDataBuilder::Make(props, column); - auto column_chunk = ColumnChunkMetaData::Make( - metadata_builder->contents(), - column, - default_reader_properties(), - &version); + bool expectedIsSet) { + auto metadataBuilder = ColumnChunkMetaDataBuilder::make(props, column); + auto columnChunk = ColumnChunkMetaData::make( + metadataBuilder->Contents(), column, defaultReaderProperties(), &version); EncodedStatistics stats; - stats.set_is_signed(false); - metadata_builder->SetStatistics(stats); - ASSERT_EQ(column_chunk->is_stats_set(), expected_is_set); + stats.setIsSigned(false); + metadataBuilder->setStatistics(stats); + ASSERT_EQ(columnChunk->isStatsSet(), expectedIsSet); } -// Statistics are restricted for few types in older parquet version +// Statistics are restricted for few types in older parquet version. TEST(CorruptStatistics, Basics) { - std::string created_by = "parquet-mr version 1.8.0"; - ApplicationVersion version(created_by); + std::string createdBy = "parquet-mr version 1.8.0"; + ApplicationVersion version(createdBy); SchemaDescriptor schema; - schema::NodePtr node; + schema::NodePtr Node; std::vector fields; - // Test Physical Types - fields.push_back(schema::PrimitiveNode::Make( - "col1", Repetition::OPTIONAL, Type::INT32, ConvertedType::NONE)); - fields.push_back(schema::PrimitiveNode::Make( - "col2", Repetition::OPTIONAL, Type::BYTE_ARRAY, ConvertedType::NONE)); - // Test Logical Types - fields.push_back(schema::PrimitiveNode::Make( - "col3", Repetition::OPTIONAL, Type::INT32, ConvertedType::DATE)); - fields.push_back(schema::PrimitiveNode::Make( - "col4", Repetition::OPTIONAL, Type::INT32, ConvertedType::UINT_32)); - fields.push_back(schema::PrimitiveNode::Make( - "col5", - Repetition::OPTIONAL, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::INTERVAL, - 12)); - fields.push_back(schema::PrimitiveNode::Make( - "col6", Repetition::OPTIONAL, Type::BYTE_ARRAY, ConvertedType::UTF8)); - node = schema::GroupNode::Make("schema", Repetition::REQUIRED, fields); - schema.Init(node); - - WriterProperties::Builder builder; - builder.created_by(created_by); - std::shared_ptr props = builder.build(); - - AssertStatsSet(version, props, schema.Column(0), true); - AssertStatsSet(version, props, schema.Column(1), false); - AssertStatsSet(version, props, schema.Column(2), true); - AssertStatsSet(version, props, schema.Column(3), false); - AssertStatsSet(version, props, schema.Column(4), false); - AssertStatsSet(version, props, schema.Column(5), false); -} - -// Statistics for all types have no restrictions in newer parquet version + // Test Physical Types. + fields.push_back( + schema::PrimitiveNode::make( + "col1", Repetition::kOptional, Type::kInt32, ConvertedType::kNone)); + fields.push_back( + schema::PrimitiveNode::make( + "col2", + Repetition::kOptional, + Type::kByteArray, + ConvertedType::kNone)); + // Test Logical Types. + fields.push_back( + schema::PrimitiveNode::make( + "col3", Repetition::kOptional, Type::kInt32, ConvertedType::kDate)); + fields.push_back( + schema::PrimitiveNode::make( + "col4", Repetition::kOptional, Type::kInt32, ConvertedType::kUint32)); + fields.push_back( + schema::PrimitiveNode::make( + "col5", + Repetition::kOptional, + Type::kFixedLenByteArray, + ConvertedType::kInterval, + 12)); + fields.push_back( + schema::PrimitiveNode::make( + "col6", + Repetition::kOptional, + Type::kByteArray, + ConvertedType::kUtf8)); + Node = schema::GroupNode::make("schema", Repetition::kRequired, fields); + schema.init(Node); + + WriterProperties::Builder Builder; + Builder.createdBy(createdBy); + std::shared_ptr props = Builder.build(); + + assertStatsSet(version, props, schema.column(0), true); + assertStatsSet(version, props, schema.column(1), false); + assertStatsSet(version, props, schema.column(2), true); + assertStatsSet(version, props, schema.column(3), false); + assertStatsSet(version, props, schema.column(4), false); + assertStatsSet(version, props, schema.column(5), false); +} + +// Statistics for all types have no restrictions in newer parquet version. TEST(CorrectStatistics, Basics) { - std::string created_by = "parquet-cpp version 1.3.0"; - ApplicationVersion version(created_by); + std::string createdBy = "parquet-cpp version 1.3.0"; + ApplicationVersion version(createdBy); SchemaDescriptor schema; - schema::NodePtr node; + schema::NodePtr Node; std::vector fields; - // Test Physical Types - fields.push_back(schema::PrimitiveNode::Make( - "col1", Repetition::OPTIONAL, Type::INT32, ConvertedType::NONE)); - fields.push_back(schema::PrimitiveNode::Make( - "col2", Repetition::OPTIONAL, Type::BYTE_ARRAY, ConvertedType::NONE)); - // Test Logical Types - fields.push_back(schema::PrimitiveNode::Make( - "col3", Repetition::OPTIONAL, Type::INT32, ConvertedType::DATE)); - fields.push_back(schema::PrimitiveNode::Make( - "col4", Repetition::OPTIONAL, Type::INT32, ConvertedType::UINT_32)); - fields.push_back(schema::PrimitiveNode::Make( - "col5", - Repetition::OPTIONAL, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::INTERVAL, - 12)); - fields.push_back(schema::PrimitiveNode::Make( - "col6", Repetition::OPTIONAL, Type::BYTE_ARRAY, ConvertedType::UTF8)); - node = schema::GroupNode::Make("schema", Repetition::REQUIRED, fields); - schema.Init(node); - - WriterProperties::Builder builder; - builder.created_by(created_by); - std::shared_ptr props = builder.build(); - - AssertStatsSet(version, props, schema.Column(0), true); - AssertStatsSet(version, props, schema.Column(1), true); - AssertStatsSet(version, props, schema.Column(2), true); - AssertStatsSet(version, props, schema.Column(3), true); - AssertStatsSet(version, props, schema.Column(4), false); - AssertStatsSet(version, props, schema.Column(5), true); -} - -// Test SortOrder class + // Test Physical Types. + fields.push_back( + schema::PrimitiveNode::make( + "col1", Repetition::kOptional, Type::kInt32, ConvertedType::kNone)); + fields.push_back( + schema::PrimitiveNode::make( + "col2", + Repetition::kOptional, + Type::kByteArray, + ConvertedType::kNone)); + // Test Logical Types. + fields.push_back( + schema::PrimitiveNode::make( + "col3", Repetition::kOptional, Type::kInt32, ConvertedType::kDate)); + fields.push_back( + schema::PrimitiveNode::make( + "col4", Repetition::kOptional, Type::kInt32, ConvertedType::kUint32)); + fields.push_back( + schema::PrimitiveNode::make( + "col5", + Repetition::kOptional, + Type::kFixedLenByteArray, + ConvertedType::kInterval, + 12)); + fields.push_back( + schema::PrimitiveNode::make( + "col6", + Repetition::kOptional, + Type::kByteArray, + ConvertedType::kUtf8)); + Node = schema::GroupNode::make("schema", Repetition::kRequired, fields); + schema.init(Node); + + WriterProperties::Builder Builder; + Builder.createdBy(createdBy); + std::shared_ptr props = Builder.build(); + + assertStatsSet(version, props, schema.column(0), true); + assertStatsSet(version, props, schema.column(1), true); + assertStatsSet(version, props, schema.column(2), true); + assertStatsSet(version, props, schema.column(3), true); + assertStatsSet(version, props, schema.column(4), false); + assertStatsSet(version, props, schema.column(5), true); +} + +// Test SortOrder class. static const int NUM_VALUES = 10; template class TestStatisticsSortOrder : public ::testing::Test { public: - using c_type = typename TestType::c_type; - - void AddNodes(std::string name) { - fields_.push_back(schema::PrimitiveNode::Make( - name, Repetition::REQUIRED, TestType::type_num, ConvertedType::NONE)); + using CType = typename TestType::CType; + + void addNodes(std::string name) { + fields_.push_back( + schema::PrimitiveNode::make( + name, + Repetition::kRequired, + TestType::typeNum, + ConvertedType::kNone)); } - void SetUpSchema() { + void setUpSchema() { stats_.resize(fields_.size()); values_.resize(NUM_VALUES); schema_ = std::static_pointer_cast( - GroupNode::Make("Schema", Repetition::REQUIRED, fields_)); + GroupNode::make("Schema", Repetition::kRequired, fields_)); - parquet_sink_ = CreateOutputStream(); + parquetSink_ = createOutputStream(); } - void SetValues(); + void setValues(); - void WriteParquet() { - // Add writer properties - WriterProperties::Builder builder; - builder.compression(Compression::SNAPPY); - builder.created_by("parquet-cpp version 1.3.0"); - std::shared_ptr props = builder.build(); + void writeParquet() { + // Add writer properties. + WriterProperties::Builder Builder; + Builder.compression(Compression::SNAPPY); + Builder.createdBy("parquet-cpp version 1.3.0"); + std::shared_ptr props = Builder.build(); - // Create a ParquetFileWriter instance - auto file_writer = ParquetFileWriter::Open(parquet_sink_, schema_, props); + // Create a ParquetFileWriter instance. + auto fileWriter = ParquetFileWriter::open(parquetSink_, schema_, props); // Append a RowGroup with a specific number of rows. - auto rg_writer = file_writer->AppendRowGroup(); + auto rgWriter = fileWriter->appendRowGroup(); - this->SetValues(); + this->setValues(); - // Insert Values + // Insert Values. for (int i = 0; i < static_cast(fields_.size()); i++) { - auto column_writer = - static_cast*>(rg_writer->NextColumn()); - column_writer->WriteBatch(NUM_VALUES, nullptr, nullptr, values_.data()); + auto columnWriter = + static_cast*>(rgWriter->nextColumn()); + columnWriter->writeBatch(NUM_VALUES, nullptr, nullptr, values_.data()); } } - void VerifyParquetStats() { - ASSERT_OK_AND_ASSIGN(auto pbuffer, parquet_sink_->Finish()); + void verifyParquetStats() { + ASSERT_OK_AND_ASSIGN(auto pbuffer, parquetSink_->Finish()); - // Write the pbuffer to a temp file - auto filePath = exec::test::TempFilePath::create(); + // Write the pbuffer to a temp file. + auto filePath = TempFilePath::create(); writeToFile(filePath, pbuffer); memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); std::shared_ptr rootPool = memory::memoryManager()->addRootPool("StatisticsTest"); std::shared_ptr leafPool = rootPool->addLeafChild("StatisticsTest"); - dwio::common::ReaderOptions readerOptions{leafPool.get()}; + auto dataIoStats = std::make_shared(); + auto metadataIoStats = std::make_shared(); + dwio::common::ReaderOptions readerOptions(leafPool.get()); + readerOptions.setDataIoStats(dataIoStats); + readerOptions.setMetadataIoStats(metadataIoStats); auto input = std::make_unique( std::make_shared(filePath->getPath()), readerOptions.memoryPool()); @@ -1038,126 +1056,143 @@ class TestStatisticsSortOrder : public ::testing::Test { } protected: - std::vector values_; - std::vector values_buf_; + std::vector values_; + std::vector valuesBuf_; std::vector fields_; std::shared_ptr schema_; - std::shared_ptr<::arrow::io::BufferOutputStream> parquet_sink_; + std::shared_ptr<::arrow::io::BufferOutputStream> parquetSink_; std::vector stats_; }; using CompareTestTypes = ::testing:: Types; -// TYPE::INT32 +// TYPE::INT32. template <> -void TestStatisticsSortOrder::AddNodes(std::string name) { - // UINT_32 logical type to set Unsigned Statistics - fields_.push_back(schema::PrimitiveNode::Make( - name, Repetition::REQUIRED, Type::INT32, ConvertedType::UINT_32)); - // INT_32 logical type to set Signed Statistics - fields_.push_back(schema::PrimitiveNode::Make( - name, Repetition::REQUIRED, Type::INT32, ConvertedType::INT_32)); +void TestStatisticsSortOrder::addNodes(std::string name) { + // UINT_32 logical type to set Unsigned Statistics. + fields_.push_back( + schema::PrimitiveNode::make( + name, Repetition::kRequired, Type::kInt32, ConvertedType::kUint32)); + // INT_32 logical type to set Signed Statistics. + fields_.push_back( + schema::PrimitiveNode::make( + name, Repetition::kRequired, Type::kInt32, ConvertedType::kInt32)); } template <> -void TestStatisticsSortOrder::SetValues() { +void TestStatisticsSortOrder::setValues() { for (int i = 0; i < NUM_VALUES; i++) { values_[i] = i - 5; // {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4}; } - // Write UINT32 min/max values + // Write UINT32 min/max values. stats_[0] - .set_min(std::string( - reinterpret_cast(&values_[5]), sizeof(c_type))) - .set_max(std::string( - reinterpret_cast(&values_[4]), sizeof(c_type))); - - // Write INT32 min/max values + .setMin( + std::string( + reinterpret_cast(&values_[5]), sizeof(CType))) + .setMax( + std::string( + reinterpret_cast(&values_[4]), sizeof(CType))); + + // Write INT32 min/max values. stats_[1] - .set_min(std::string( - reinterpret_cast(&values_[0]), sizeof(c_type))) - .set_max(std::string( - reinterpret_cast(&values_[9]), sizeof(c_type))); + .setMin( + std::string( + reinterpret_cast(&values_[0]), sizeof(CType))) + .setMax( + std::string( + reinterpret_cast(&values_[9]), sizeof(CType))); } -// TYPE::INT64 +// TYPE::INT64. template <> -void TestStatisticsSortOrder::AddNodes(std::string name) { - // UINT_64 logical type to set Unsigned Statistics - fields_.push_back(schema::PrimitiveNode::Make( - name, Repetition::REQUIRED, Type::INT64, ConvertedType::UINT_64)); - // INT_64 logical type to set Signed Statistics - fields_.push_back(schema::PrimitiveNode::Make( - name, Repetition::REQUIRED, Type::INT64, ConvertedType::INT_64)); +void TestStatisticsSortOrder::addNodes(std::string name) { + // UINT_64 logical type to set Unsigned Statistics. + fields_.push_back( + schema::PrimitiveNode::make( + name, Repetition::kRequired, Type::kInt64, ConvertedType::kUint64)); + // INT_64 logical type to set Signed Statistics. + fields_.push_back( + schema::PrimitiveNode::make( + name, Repetition::kRequired, Type::kInt64, ConvertedType::kInt64)); } template <> -void TestStatisticsSortOrder::SetValues() { +void TestStatisticsSortOrder::setValues() { for (int i = 0; i < NUM_VALUES; i++) { values_[i] = i - 5; // {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4}; } - // Write UINT64 min/max values + // Write UINT64 min/max values. stats_[0] - .set_min(std::string( - reinterpret_cast(&values_[5]), sizeof(c_type))) - .set_max(std::string( - reinterpret_cast(&values_[4]), sizeof(c_type))); - - // Write INT64 min/max values + .setMin( + std::string( + reinterpret_cast(&values_[5]), sizeof(CType))) + .setMax( + std::string( + reinterpret_cast(&values_[4]), sizeof(CType))); + + // Write INT64 min/max values. stats_[1] - .set_min(std::string( - reinterpret_cast(&values_[0]), sizeof(c_type))) - .set_max(std::string( - reinterpret_cast(&values_[9]), sizeof(c_type))); + .setMin( + std::string( + reinterpret_cast(&values_[0]), sizeof(CType))) + .setMax( + std::string( + reinterpret_cast(&values_[9]), sizeof(CType))); } -// TYPE::FLOAT +// TYPE::FLOAT. template <> -void TestStatisticsSortOrder::SetValues() { +void TestStatisticsSortOrder::setValues() { for (int i = 0; i < NUM_VALUES; i++) { values_[i] = static_cast(i) - 5; // {-5.0, -4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0}; } - // Write Float min/max values + // Write Float min/max values. stats_[0] - .set_min(std::string( - reinterpret_cast(&values_[0]), sizeof(c_type))) - .set_max(std::string( - reinterpret_cast(&values_[9]), sizeof(c_type))); + .setMin( + std::string( + reinterpret_cast(&values_[0]), sizeof(CType))) + .setMax( + std::string( + reinterpret_cast(&values_[9]), sizeof(CType))); } -// TYPE::DOUBLE +// TYPE::DOUBLE. template <> -void TestStatisticsSortOrder::SetValues() { +void TestStatisticsSortOrder::setValues() { for (int i = 0; i < NUM_VALUES; i++) { values_[i] = static_cast(i) - 5; // {-5.0, -4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0}; } - // Write Double min/max values + // Write Double min/max values. stats_[0] - .set_min(std::string( - reinterpret_cast(&values_[0]), sizeof(c_type))) - .set_max(std::string( - reinterpret_cast(&values_[9]), sizeof(c_type))); + .setMin( + std::string( + reinterpret_cast(&values_[0]), sizeof(CType))) + .setMax( + std::string( + reinterpret_cast(&values_[9]), sizeof(CType))); } -// TYPE::ByteArray +// TYPE::ByteArray. template <> -void TestStatisticsSortOrder::AddNodes(std::string name) { - // UTF8 logical type to set Unsigned Statistics - fields_.push_back(schema::PrimitiveNode::Make( - name, Repetition::REQUIRED, Type::BYTE_ARRAY, ConvertedType::UTF8)); +void TestStatisticsSortOrder::addNodes(std::string name) { + // UTF8 logical type to set Unsigned Statistics. + fields_.push_back( + schema::PrimitiveNode::make( + name, Repetition::kRequired, Type::kByteArray, ConvertedType::kUtf8)); } template <> -void TestStatisticsSortOrder::SetValues() { - int max_byte_array_len = 10; - size_t nbytes = NUM_VALUES * max_byte_array_len; - values_buf_.resize(nbytes); +void TestStatisticsSortOrder::setValues() { + int maxByteArrayLen = 10; + size_t nbytes = NUM_VALUES * maxByteArrayLen; + valuesBuf_.resize(nbytes); std::vector vals = { "c123", "b123", @@ -1170,7 +1205,7 @@ void TestStatisticsSortOrder::SetValues() { "i123", "ü123"}; - uint8_t* base = &values_buf_.data()[0]; + uint8_t* base = &valuesBuf_.data()[0]; for (int i = 0; i < NUM_VALUES; i++) { memcpy(base, vals[i].c_str(), vals[i].length()); values_[i].ptr = base; @@ -1178,30 +1213,34 @@ void TestStatisticsSortOrder::SetValues() { base += vals[i].length(); } - // Write String min/max values + // Write String min/max values. stats_[0] - .set_min(std::string( - reinterpret_cast(vals[2].c_str()), vals[2].length())) - .set_max(std::string( - reinterpret_cast(vals[9].c_str()), vals[9].length())); + .setMin( + std::string( + reinterpret_cast(vals[2].c_str()), vals[2].length())) + .setMax( + std::string( + reinterpret_cast(vals[9].c_str()), + vals[9].length())); } -// TYPE::FLBAArray +// TYPE::FLBAArray. template <> -void TestStatisticsSortOrder::AddNodes(std::string name) { - // FLBA has only Unsigned Statistics - fields_.push_back(schema::PrimitiveNode::Make( - name, - Repetition::REQUIRED, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::NONE, - FLBA_LENGTH)); +void TestStatisticsSortOrder::addNodes(std::string name) { + // FLBA has only Unsigned Statistics. + fields_.push_back( + schema::PrimitiveNode::make( + name, + Repetition::kRequired, + Type::kFixedLenByteArray, + ConvertedType::kNone, + FLBA_LENGTH)); } template <> -void TestStatisticsSortOrder::SetValues() { +void TestStatisticsSortOrder::setValues() { size_t nbytes = NUM_VALUES * FLBA_LENGTH; - values_buf_.resize(nbytes); + valuesBuf_.resize(nbytes); char vals[NUM_VALUES][FLBA_LENGTH] = { "b12345", "a12345", @@ -1214,88 +1253,93 @@ void TestStatisticsSortOrder::SetValues() { "z12345", "a12345"}; - uint8_t* base = &values_buf_.data()[0]; + uint8_t* base = &valuesBuf_.data()[0]; for (int i = 0; i < NUM_VALUES; i++) { memcpy(base, &vals[i][0], FLBA_LENGTH); values_[i].ptr = base; base += FLBA_LENGTH; } - // Write FLBA min,max values + // Write FLBA min,max values. stats_[0] - .set_min( + .setMin( std::string(reinterpret_cast(&vals[1][0]), FLBA_LENGTH)) - .set_max( + .setMax( std::string(reinterpret_cast(&vals[8][0]), FLBA_LENGTH)); } TYPED_TEST_SUITE(TestStatisticsSortOrder, CompareTestTypes); TYPED_TEST(TestStatisticsSortOrder, MinMax) { - this->AddNodes("Column "); - this->SetUpSchema(); - this->WriteParquet(); - ASSERT_NO_FATAL_FAILURE(this->VerifyParquetStats()); + this->addNodes("Column "); + this->setUpSchema(); + this->writeParquet(); + ASSERT_NO_FATAL_FAILURE(this->verifyParquetStats()); } template -void TestByteArrayStatisticsFromArrow() { +void testByteArrayStatisticsFromArrow() { using TypeTraits = ::arrow::TypeTraits; using ArrayType = typename TypeTraits::ArrayType; - auto values = ArrayFromJSON( + auto values = ::arrow::ArrayFromJSON( TypeTraits::type_singleton(), "[\"c123\", \"b123\", \"a123\", null, " "null, \"f123\", \"g123\", \"h123\", \"i123\", \"ü123\"]"); - const auto& typed_values = static_cast(*values); + const auto& typedValues = static_cast(*values); - NodePtr node = PrimitiveNode::Make( - "field", Repetition::REQUIRED, Type::BYTE_ARRAY, ConvertedType::UTF8); - ColumnDescriptor descr(node, 0, 0); - auto stats = MakeStatistics(&descr); - ASSERT_NO_FATAL_FAILURE(stats->Update(*values)); + NodePtr Node = PrimitiveNode::make( + "field", Repetition::kRequired, Type::kByteArray, ConvertedType::kUtf8); + ColumnDescriptor descr(Node, 0, 0); + auto stats = makeStatistics(&descr); + ASSERT_NO_FATAL_FAILURE(stats->update(*values)); - ASSERT_EQ(ByteArray(typed_values.GetView(2)), stats->min()); - ASSERT_EQ(ByteArray(typed_values.GetView(9)), stats->max()); - ASSERT_EQ(2, stats->null_count()); + ASSERT_EQ(ByteArray(typedValues.GetView(2)), stats->min()); + ASSERT_EQ(ByteArray(typedValues.GetView(9)), stats->max()); + ASSERT_EQ(2, stats->nullCount()); } -TEST(TestByteArrayStatisticsFromArrow, StringType) { - // Part of ARROW-3246. Replicating TestStatisticsSortOrder test but via Arrow - TestByteArrayStatisticsFromArrow<::arrow::StringType>(); +TEST(testByteArrayStatisticsFromArrow, StringType) { + // Part of ARROW-3246. Replicating TestStatisticsSortOrder test but via Arrow. + testByteArrayStatisticsFromArrow<::arrow::StringType>(); } -TEST(TestByteArrayStatisticsFromArrow, LargeStringType) { - TestByteArrayStatisticsFromArrow<::arrow::LargeStringType>(); +TEST(testByteArrayStatisticsFromArrow, LargeStringType) { + testByteArrayStatisticsFromArrow<::arrow::LargeStringType>(); } -// Ensure Decimal sort order is handled properly +// Ensure Decimal sort order is handled properly. using TestStatisticsSortOrderFLBA = TestStatisticsSortOrder; TEST_F(TestStatisticsSortOrderFLBA, decimalSortOrder) { - this->fields_.push_back(schema::PrimitiveNode::Make( - "Column 0", - Repetition::REQUIRED, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::DECIMAL, - FLBA_LENGTH, - 12, - 2)); - this->SetUpSchema(); - this->WriteParquet(); - - ASSERT_OK_AND_ASSIGN(auto pbuffer, parquet_sink_->Finish()); - - // Write the pbuffer to a temp file - auto filePath = exec::test::TempFilePath::create(); + this->fields_.push_back( + schema::PrimitiveNode::make( + "Column 0", + Repetition::kRequired, + Type::kFixedLenByteArray, + ConvertedType::kDecimal, + FLBA_LENGTH, + 12, + 2)); + this->setUpSchema(); + this->writeParquet(); + + ASSERT_OK_AND_ASSIGN(auto pbuffer, parquetSink_->Finish()); + + // Write the pbuffer to a temp file. + auto filePath = TempFilePath::create(); writeToFile(filePath, pbuffer); memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); std::shared_ptr rootPool = memory::memoryManager()->addRootPool("StatisticsTest"); std::shared_ptr leafPool = rootPool->addLeafChild("StatisticsTest"); - dwio::common::ReaderOptions readerOptions{leafPool.get()}; + auto dataIoStats = std::make_shared(); + auto metadataIoStats = std::make_shared(); + dwio::common::ReaderOptions readerOptions(leafPool.get()); + readerOptions.setDataIoStats(dataIoStats); + readerOptions.setMetadataIoStats(metadataIoStats); auto input = std::make_unique( std::make_shared(filePath->getPath()), readerOptions.memoryPool()); @@ -1310,65 +1354,65 @@ template < typename Stats, typename Array, typename T = typename Array::value_type> -void AssertMinMaxAre( +void assertMinMaxAre( Stats stats, const Array& values, - T expected_min, - T expected_max) { - stats->Update(values.data(), values.size(), 0); - ASSERT_TRUE(stats->HasMinMax()); - EXPECT_EQ(stats->min(), expected_min); - EXPECT_EQ(stats->max(), expected_max); + T expectedMin, + T expectedMax) { + stats->update(values.data(), values.size(), 0); + ASSERT_TRUE(stats->hasMinMax()); + EXPECT_EQ(stats->min(), expectedMin); + EXPECT_EQ(stats->max(), expectedMax); } template -void AssertMinMaxAre( +void assertMinMaxAre( Stats stats, const Array& values, - const uint8_t* valid_bitmap, - T expected_min, - T expected_max) { - auto n_values = values.size(); - auto null_count = ::arrow::internal::CountSetBits(valid_bitmap, n_values, 0); - auto non_null_count = n_values - null_count; - stats->UpdateSpaced( + const uint8_t* validBitmap, + T expectedMin, + T expectedMax) { + auto nValues = values.size(); + auto nullCount = ::arrow::internal::CountSetBits(validBitmap, nValues, 0); + auto nonNullCount = nValues - nullCount; + stats->updateSpaced( values.data(), - valid_bitmap, + validBitmap, 0, - non_null_count + null_count, - non_null_count, - null_count); - ASSERT_TRUE(stats->HasMinMax()); - EXPECT_EQ(stats->min(), expected_min); - EXPECT_EQ(stats->max(), expected_max); + nonNullCount + nullCount, + nonNullCount, + nullCount); + ASSERT_TRUE(stats->hasMinMax()); + EXPECT_EQ(stats->min(), expectedMin); + EXPECT_EQ(stats->max(), expectedMax); } template -void AssertUnsetMinMax(Stats stats, const Array& values) { - stats->Update(values.data(), values.size(), 0); - ASSERT_FALSE(stats->HasMinMax()); +void assertUnsetMinMax(Stats stats, const Array& values) { + stats->update(values.data(), values.size(), 0); + ASSERT_FALSE(stats->hasMinMax()); } template -void AssertUnsetMinMax( +void assertUnsetMinMax( Stats stats, const Array& values, - const uint8_t* valid_bitmap) { - auto n_values = values.size(); - auto null_count = ::arrow::internal::CountSetBits(valid_bitmap, n_values, 0); - auto non_null_count = n_values - null_count; - stats->UpdateSpaced( + const uint8_t* validBitmap) { + auto nValues = values.size(); + auto nullCount = ::arrow::internal::CountSetBits(validBitmap, nValues, 0); + auto nonNullCount = nValues - nullCount; + stats->updateSpaced( values.data(), - valid_bitmap, + validBitmap, 0, - non_null_count + null_count, - non_null_count, - null_count); - ASSERT_FALSE(stats->HasMinMax()); + nonNullCount + nullCount, + nonNullCount, + nullCount); + ASSERT_FALSE(stats->hasMinMax()); } -template -void CheckExtrema() { +template +void checkExtrema() { using UT = typename std::make_unsigned::type; const T smin = std::numeric_limits::min(); @@ -1380,18 +1424,18 @@ void CheckExtrema() { std::array values{ 0, smin, smax, umin, umax, smin + 1, smax - 1, umax - 1}; - NodePtr unsigned_node = PrimitiveNode::Make( + NodePtr unsignedNode = PrimitiveNode::make( "uint", - Repetition::OPTIONAL, - LogicalType::Int(sizeof(T) * CHAR_BIT, false /*signed*/), - ParquetType::type_num); - ColumnDescriptor unsigned_descr(unsigned_node, 1, 1); - NodePtr signed_node = PrimitiveNode::Make( + Repetition::kOptional, + LogicalType::intType(sizeof(T) * CHAR_BIT, false /*signed*/), + ParquetType::typeNum); + ColumnDescriptor unsignedDescr(unsignedNode, 1, 1); + NodePtr signedNode = PrimitiveNode::make( "int", - Repetition::OPTIONAL, - LogicalType::Int(sizeof(T) * CHAR_BIT, true /*signed*/), - ParquetType::type_num); - ColumnDescriptor signed_descr(signed_node, 1, 1); + Repetition::kOptional, + LogicalType::intType(sizeof(T) * CHAR_BIT, true /*signed*/), + ParquetType::typeNum); + ColumnDescriptor signedDescr(signedNode, 1, 1); { ARROW_SCOPED_TRACE( @@ -1400,13 +1444,13 @@ void CheckExtrema() { ", umax = ", umax, ", node type = ", - unsigned_node->logical_type()->ToString(), + unsignedNode->logicalType()->toString(), ", physical type = ", - unsigned_descr.physical_type(), + unsignedDescr.physicalType(), ", sort order = ", - unsigned_descr.sort_order()); - auto unsigned_stats = MakeStatistics(&unsigned_descr); - AssertMinMaxAre(unsigned_stats, values, umin, umax); + unsignedDescr.sortOrder()); + auto unsignedStats = makeStatistics(&unsignedDescr); + assertMinMaxAre(unsignedStats, values, umin, umax); } { ARROW_SCOPED_TRACE( @@ -1415,20 +1459,20 @@ void CheckExtrema() { ", smax = ", smax, ", node type = ", - signed_node->logical_type()->ToString(), + signedNode->logicalType()->toString(), ", physical type = ", - signed_descr.physical_type(), + signedDescr.physicalType(), ", sort order = ", - signed_descr.sort_order()); - auto signed_stats = MakeStatistics(&signed_descr); - AssertMinMaxAre(signed_stats, values, smin, smax); + signedDescr.sortOrder()); + auto signedStats = makeStatistics(&signedDescr); + assertMinMaxAre(signedStats, values, smin, smax); } - // With validity bitmap - std::vector is_valid = { + // With validity bitmap. + std::vector isValid = { true, false, false, false, false, true, true, true}; - std::shared_ptr valid_bitmap; - ::arrow::BitmapFromVector(is_valid, &valid_bitmap); + std::shared_ptr validBitmap; + ::arrow::BitmapFromVector(isValid, &validBitmap); { ARROW_SCOPED_TRACE( "spaced unsigned statistics: umin = ", @@ -1436,14 +1480,13 @@ void CheckExtrema() { ", umax = ", umax, ", node type = ", - unsigned_node->logical_type()->ToString(), + unsignedNode->logicalType()->toString(), ", physical type = ", - unsigned_descr.physical_type(), + unsignedDescr.physicalType(), ", sort order = ", - unsigned_descr.sort_order()); - auto unsigned_stats = MakeStatistics(&unsigned_descr); - AssertMinMaxAre( - unsigned_stats, values, valid_bitmap->data(), T{0}, umax - 1); + unsignedDescr.sortOrder()); + auto unsignedStats = makeStatistics(&unsignedDescr); + assertMinMaxAre(unsignedStats, values, validBitmap->data(), T{0}, umax - 1); } { ARROW_SCOPED_TRACE( @@ -1452,99 +1495,102 @@ void CheckExtrema() { ", smax = ", smax, ", node type = ", - signed_node->logical_type()->ToString(), + signedNode->logicalType()->toString(), ", physical type = ", - signed_descr.physical_type(), + signedDescr.physicalType(), ", sort order = ", - signed_descr.sort_order()); - auto signed_stats = MakeStatistics(&signed_descr); - AssertMinMaxAre( - signed_stats, values, valid_bitmap->data(), smin + 1, smax - 1); + signedDescr.sortOrder()); + auto signedStats = makeStatistics(&signedDescr); + assertMinMaxAre( + signedStats, values, validBitmap->data(), smin + 1, smax - 1); } } TEST(TestStatistic, Int32Extrema) { - CheckExtrema(); + checkExtrema(); } TEST(TestStatistic, Int64Extrema) { - CheckExtrema(); + checkExtrema(); } -// PARQUET-1225: Float NaN values may lead to incorrect min-max +// PARQUET-1225: Float NaN values may lead to incorrect min-max. template -void CheckNaNs() { - using T = typename ParquetType::c_type; +void checkNaNs() { + using T = typename ParquetType::CType; constexpr int kNumValues = 8; - NodePtr node = - PrimitiveNode::Make("f", Repetition::OPTIONAL, ParquetType::type_num); - ColumnDescriptor descr(node, 1, 1); + NodePtr Node = + PrimitiveNode::make("f", Repetition::kOptional, ParquetType::typeNum); + ColumnDescriptor descr(Node, 1, 1); constexpr T nan = std::numeric_limits::quiet_NaN(); constexpr T min = -4.0f; constexpr T max = 3.0f; - std::array all_nans{nan, nan, nan, nan, nan, nan, nan, nan}; - std::array some_nans{ + std::array allNans{nan, nan, nan, nan, nan, nan, nan, nan}; + std::array someNans{ nan, max, -3.0f, -1.0f, nan, 2.0f, min, nan}; - uint8_t valid_bitmap = 0x7F; // 0b01111111 - // NaNs excluded - uint8_t valid_bitmap_no_nans = 0x6E; // 0b01101110 - - // Test values - auto some_nan_stats = MakeStatistics(&descr); - // Ingesting only nans should not yield valid min max - AssertUnsetMinMax(some_nan_stats, all_nans); + uint8_t validBitmap = 0x7F; // 0b01111111 + // NaNs excluded. + uint8_t validBitmapNoNans = 0x6E; // 0b01101110 + + // Test values. + auto someNanStats = makeStatistics(&descr); + // Ingesting only nans should not yield valid min max. + assertUnsetMinMax(someNanStats, allNans); + EXPECT_EQ(someNanStats->nanCount(), allNans.size()); // Ingesting a mix of NaNs and non-NaNs should not yield valid min max. - AssertMinMaxAre(some_nan_stats, some_nans, min, max); - // Ingesting only nans after a valid min/max, should have not effect - AssertMinMaxAre(some_nan_stats, all_nans, min, max); + assertMinMaxAre(someNanStats, someNans, min, max); + // Ingesting only nans after a valid min/max, should have not effect. + assertMinMaxAre(someNanStats, allNans, min, max); - some_nan_stats = MakeStatistics(&descr); - AssertUnsetMinMax(some_nan_stats, all_nans, &valid_bitmap); + someNanStats = makeStatistics(&descr); + assertUnsetMinMax(someNanStats, allNans, &validBitmap); // NaNs should not pollute min max when excluded via null bitmap. - AssertMinMaxAre(some_nan_stats, some_nans, &valid_bitmap_no_nans, min, max); + assertMinMaxAre(someNanStats, someNans, &validBitmapNoNans, min, max); // Ingesting NaNs with a null bitmap should not change the result. - AssertMinMaxAre(some_nan_stats, some_nans, &valid_bitmap, min, max); + assertMinMaxAre(someNanStats, someNans, &validBitmap, min, max); - // An array that doesn't start with NaN - std::array other_nans{ + // An array that doesn't start with NaN. + std::array otherNans{ 1.5f, max, -3.0f, -1.0f, nan, 2.0f, min, nan}; - auto other_stats = MakeStatistics(&descr); - AssertMinMaxAre(other_stats, other_nans, min, max); + auto otherStats = makeStatistics(&descr); + assertMinMaxAre(otherStats, otherNans, min, max); + EXPECT_EQ(otherStats->nanCount(), 2); } TEST(TestStatistic, NaNFloatValues) { - CheckNaNs(); + checkNaNs(); } TEST(TestStatistic, NaNDoubleValues) { - CheckNaNs(); + checkNaNs(); } -// ARROW-7376 +// ARROW-7376. TEST(TestStatisticsSortOrderFloatNaN, NaNAndNullsInfiniteLoop) { constexpr int kNumValues = 8; - NodePtr node = - PrimitiveNode::Make("nan_float", Repetition::OPTIONAL, Type::FLOAT); - ColumnDescriptor descr(node, 1, 1); + NodePtr Node = + PrimitiveNode::make("nan_float", Repetition::kOptional, Type::kFloat); + ColumnDescriptor descr(Node, 1, 1); constexpr float nan = std::numeric_limits::quiet_NaN(); - std::array nans_but_last{ + std::array nansButLast{ nan, nan, nan, nan, nan, nan, nan, 0.0f}; - uint8_t all_but_last_valid = 0x7F; // 0b01111111 - auto stats = MakeStatistics(&descr); - AssertUnsetMinMax(stats, nans_but_last, &all_but_last_valid); + uint8_t allButLastValid = 0x7F; // 0b01111111 + auto stats = makeStatistics(&descr); + assertUnsetMinMax(stats, nansButLast, &allButLastValid); + EXPECT_EQ(stats->nanCount(), kNumValues - 1); } template < typename Stats, typename Array, typename T = typename Array::value_type> -void AssertMinMaxZeroesSign(Stats stats, const Array& values) { - stats->Update(values.data(), values.size(), 0); - ASSERT_TRUE(stats->HasMinMax()); +void assertMinMaxZeroesSign(Stats stats, const Array& values) { + stats->update(values.data(), values.size(), 0); + ASSERT_TRUE(stats->hasMinMax()); T zero{}; ASSERT_EQ(stats->min(), zero); @@ -1554,52 +1600,218 @@ void AssertMinMaxZeroesSign(Stats stats, const Array& values) { ASSERT_FALSE(std::signbit(stats->max())); } -// ARROW-5562: Ensure that -0.0f and 0.0f values are properly handled like in -// parquet-mr +// ARROW-5562: Ensure that -0.0f and 0.0f values are properly handled like in. +// Parquet-mr. template -void CheckNegativeZeroStats() { - using T = typename ParquetType::c_type; +void checkNegativeZeroStats() { + using T = typename ParquetType::CType; - NodePtr node = - PrimitiveNode::Make("f", Repetition::OPTIONAL, ParquetType::type_num); - ColumnDescriptor descr(node, 1, 1); + NodePtr Node = + PrimitiveNode::make("f", Repetition::kOptional, ParquetType::typeNum); + ColumnDescriptor descr(Node, 1, 1); T zero{}; { std::array values{-zero, zero}; - auto stats = MakeStatistics(&descr); - AssertMinMaxZeroesSign(stats, values); + auto stats = makeStatistics(&descr); + assertMinMaxZeroesSign(stats, values); } { std::array values{zero, -zero}; - auto stats = MakeStatistics(&descr); - AssertMinMaxZeroesSign(stats, values); + auto stats = makeStatistics(&descr); + assertMinMaxZeroesSign(stats, values); } { std::array values{-zero, -zero}; - auto stats = MakeStatistics(&descr); - AssertMinMaxZeroesSign(stats, values); + auto stats = makeStatistics(&descr); + assertMinMaxZeroesSign(stats, values); } { std::array values{zero, zero}; - auto stats = MakeStatistics(&descr); - AssertMinMaxZeroesSign(stats, values); + auto stats = makeStatistics(&descr); + assertMinMaxZeroesSign(stats, values); } } TEST(TestStatistics, FloatNegativeZero) { - CheckNegativeZeroStats(); + checkNegativeZeroStats(); } TEST(TestStatistics, DoubleNegativeZero) { - CheckNegativeZeroStats(); + checkNegativeZeroStats(); +} + +// Test infinity handling in statistics. +template +void checkInfinityStats() { + using T = typename ParquetType::CType; + + constexpr int32_t kNumValues = 8; + NodePtr Node = PrimitiveNode::make( + "infinity_test", Repetition::kOptional, ParquetType::typeNum); + ColumnDescriptor descr(Node, 1, 1); + + constexpr T posInf = std::numeric_limits::infinity(); + constexpr T negInf = -std::numeric_limits::infinity(); + constexpr T min = -1.0f; + constexpr T max = 1.0f; + + { + std::array allPosInf{ + posInf, posInf, posInf, posInf, posInf, posInf, posInf, posInf}; + auto stats = makeStatistics(&descr); + assertMinMaxAre(stats, allPosInf, posInf, posInf); + } + + { + std::array allNegInf{ + negInf, negInf, negInf, negInf, negInf, negInf, negInf, negInf}; + auto stats = makeStatistics(&descr); + assertMinMaxAre(stats, allNegInf, negInf, negInf); + } + + { + std::array mixedInf{ + posInf, negInf, posInf, negInf, posInf, negInf, posInf, negInf}; + auto stats = makeStatistics(&descr); + assertMinMaxAre(stats, mixedInf, negInf, posInf); + } + + { + std::array mixedValues{ + posInf, max, min, min, negInf, max, min, posInf}; + auto stats = makeStatistics(&descr); + assertMinMaxAre(stats, mixedValues, negInf, posInf); + } + + { + constexpr T nan = std::numeric_limits::quiet_NaN(); + std::array mixedWithNan{ + posInf, nan, max, negInf, nan, min, posInf, nan}; + auto stats = makeStatistics(&descr); + assertMinMaxAre(stats, mixedWithNan, negInf, posInf); + } +} + +TEST(TestStatistics, FloatInfinityValues) { + checkInfinityStats(); +} + +TEST(TestStatistics, DoubleInfinityValues) { + checkInfinityStats(); +} + +// Test infinity values with validity bitmap. +TEST(TestStatistics, InfinityWithNullBitmap) { + constexpr int kNumValues = 8; + NodePtr Node = PrimitiveNode::make( + "infinity_null_test", Repetition::kOptional, Type::kFloat); + ColumnDescriptor descr(Node, 1, 1); + + constexpr float posInf = std::numeric_limits::infinity(); + constexpr float negInf = -std::numeric_limits::infinity(); + + // Test with some infinity values marked as null. + std::array valuesWithNulls{ + posInf, negInf, 1.0f, 2.0f, posInf, -1.0f, 3.0f, negInf}; + + // Bitmap: exclude first posInf and last negInf (01111110 = 0x7E). + uint8_t validBitmap = 0x7E; + + auto stats = makeStatistics(&descr); + assertMinMaxAre(stats, valuesWithNulls, &validBitmap, negInf, posInf); + valuesWithNulls = {posInf, 0.0f, 1.0f, 2.0f, -2.0f, -1.0f, 3.0f, negInf}; + + stats = makeStatistics(&descr); + assertMinMaxAre(stats, valuesWithNulls, &validBitmap, -2.0f, 3.0f); +} + +// Test merging statistics with infinity values. +TEST(TestStatistics, MergeInfinityStatistics) { + NodePtr Node = PrimitiveNode::make( + "merge_infinity", Repetition::kOptional, Type::kDouble); + ColumnDescriptor descr(Node, 1, 1); + + constexpr double posInf = std::numeric_limits::infinity(); + constexpr double negInf = -std::numeric_limits::infinity(); + + auto stats1 = makeStatistics(&descr); + std::array normalValues{-1.0f, 0.0f, 1.0f}; + assertMinMaxAre(stats1, normalValues, -1.0f, 1.0f); + + auto stats2 = makeStatistics(&descr); + std::array infinityValues{negInf, posInf}; + assertMinMaxAre(stats2, infinityValues, negInf, posInf); + + auto mergedStats = makeStatistics(&descr); + mergedStats->merge(*stats1); + mergedStats->merge(*stats2); + + // Result should have infinity bounds. + ASSERT_TRUE(mergedStats->hasMinMax()); + ASSERT_EQ(negInf, mergedStats->min()); + ASSERT_EQ(posInf, mergedStats->max()); +} + +TEST(TestStatistics, CleanInfinityStatistics) { + constexpr int kNumValues = 4; + NodePtr Node = PrimitiveNode::make( + "clean_stat_nullopt", Repetition::kOptional, Type::kFloat); + ColumnDescriptor descr(Node, 1, 1); + + constexpr float nan = std::numeric_limits::quiet_NaN(); + + { + std::array allNans{nan, nan, nan, nan}; + auto stats = makeStatistics(&descr); + assertUnsetMinMax(stats, allNans); + } + + { + std::array values{1.0f, 2.0f, 3.0f, 4.0f}; + uint8_t allNullBitmap = 0x00; + + auto stats = makeStatistics(&descr); + assertUnsetMinMax(stats, values, &allNullBitmap); + } + + { + std::array mixedNans{nan, 1.0f, nan, 2.0f}; + uint8_t partialNullBitmap = 0x05; + + auto stats = makeStatistics(&descr); + assertUnsetMinMax(stats, mixedNans, &partialNullBitmap); + } +} + +TEST(TestStatistics, InfinityCleanStatisticValid) { + constexpr int kNumValues = 4; + NodePtr Node = PrimitiveNode::make( + "clean_stat_valid", Repetition::kOptional, Type::kDouble); + ColumnDescriptor descr(Node, 1, 1); + + constexpr double posInf = std::numeric_limits::infinity(); + constexpr double negInf = -std::numeric_limits::infinity(); + constexpr double nan = std::numeric_limits::quiet_NaN(); + + { + std::array mixedValues{posInf, nan, negInf, nan}; + auto stats = makeStatistics(&descr); + assertMinMaxAre(stats, mixedValues, negInf, posInf); + } + + { + std::array singleInf{negInf}; + auto stats = makeStatistics(&descr); + assertMinMaxAre(stats, singleInf, negInf, negInf); + } } // TODO: disabled as it requires Arrow parquet data dir. -// Test statistics for binary column with UNSIGNED sort order +// Test statistics for binary column with UNSIGNED sort order. /* TEST(TestStatisticsSortOrderMinMax, Unsigned) { std::string dir_string(test::get_data_dir()); @@ -1607,15 +1819,15 @@ TEST(TestStatisticsSortOrderMinMax, Unsigned) { ss << dir_string << "/binary.parquet"; auto path = ss.str(); - // The file is generated by parquet-mr 1.10.0, the first version that - // supports correct statistics for binary data (see PARQUET-1025). It - // contains a single column of binary type. Data is just single byte values - // from 0x00 to 0x0B. + // The file is generated by parquet-mr 1.10.0, the first version that. + // Supports correct statistics for binary data (see PARQUET-1025). It. + // Contains a single column of binary type. Data is just single byte values. + // From 0x00 to 0x0B. auto file_reader = ParquetFileReader::OpenFile(path); auto rg_reader = file_reader->RowGroup(0); auto metadata = rg_reader->metadata(); auto column_schema = metadata->schema()->Column(0); - ASSERT_EQ(SortOrder::UNSIGNED, column_schema->sort_order()); + ASSERT_EQ(SortOrder::kUnsigned, column_schema->sort_order()); auto column_chunk = metadata->ColumnChunk(0); ASSERT_TRUE(column_chunk->is_stats_set()); @@ -1645,5 +1857,397 @@ TEST(TestEncodedStatistics, CopySafe) { } */ +namespace { + +constexpr int32_t kTruncLen = 16; + +template +std::shared_ptr> makeStats( + const ColumnDescriptor* descr, + std::initializer_list values) { + auto stats = makeStatistics(descr); + std::vector v(values); + stats->update(v.data(), v.size(), 0); + return stats; +} + +std::shared_ptr> makeStats( + const ColumnDescriptor* descr, + std::initializer_list values) { + auto stats = makeStatistics(descr); + std::vector strings(values); + std::vector byteArrays; + byteArrays.reserve(strings.size()); + for (const auto& s : strings) { + byteArrays.push_back(byteArrayFromString(s)); + } + stats->update(byteArrays.data(), byteArrays.size(), 0); + return stats; +} + +} // namespace + +TEST(IcebergStatistics, decimalMinMaxValue) { + const NodePtr Node = PrimitiveNode::make( + "decimal_col", + Repetition::kRequired, + LogicalType::decimal(10, 2), + Type::kInt64); + ColumnDescriptor descr(Node, 0, 0); + + auto stats = + makeStats(&descr, {12345, -67890, 100, 50000}); + + ASSERT_TRUE(stats->hasMinMax()); + EXPECT_EQ(stats->min(), -67890); + EXPECT_EQ(stats->max(), 50000); + + const auto lowerBound = stats->icebergLowerBoundInclusive(kTruncLen); + const auto upperBound = stats->icebergUpperBoundExclusive(kTruncLen); + + ASSERT_TRUE(upperBound.has_value()); + + // Verify the encoding is big-endian. + int64_t decodedMin = ::arrow::bit_util::FromBigEndian( + *reinterpret_cast(lowerBound.data())); + int64_t decodedMax = ::arrow::bit_util::FromBigEndian( + *reinterpret_cast(upperBound->data())); + + EXPECT_EQ(decodedMin, -67890); + EXPECT_EQ(decodedMax, 50000); +} + +TEST(IcebergStatistics, nonDecimalBounds) { + const NodePtr Node = + PrimitiveNode::make("int_col", Repetition::kRequired, Type::kInt64); + ColumnDescriptor descr(Node, 0, 0); + + auto stats = makeStats(&descr, {100, -200, 300, 50}); + + ASSERT_TRUE(stats->hasMinMax()); + + // For non-decimal INT64, IcebergLowerBound should equal EncodeMin (plain. + // Encoding). + EXPECT_EQ(stats->icebergLowerBoundInclusive(kTruncLen), stats->encodeMin()); + const auto upperBound = stats->icebergUpperBoundExclusive(kTruncLen); + ASSERT_TRUE(upperBound.has_value()); + EXPECT_EQ(*upperBound, stats->encodeMax()); +} + +TEST(IcebergStatistics, byteArrayBounds) { + NodePtr Node = PrimitiveNode::make( + "string_col", + Repetition::kRequired, + Type::kByteArray, + ConvertedType::kUtf8); + ColumnDescriptor descr(Node, 0, 0); + + auto stats = makeStats( + &descr, + {"AAAAAAAAAAAAAAAAAAAAAAAAA", "ZZZZZZZZZZZZZZZZZZZZZZZZZ", "Hello"}); + + ASSERT_TRUE(stats->hasMinMax()); + + // IcebergLowerBound should be truncated to 16 characters. + const auto lowerBound = stats->icebergLowerBoundInclusive(kTruncLen); + EXPECT_EQ(lowerBound, "AAAAAAAAAAAAAAAA"); + + // IcebergUpperBound should be truncated and incremented. + const auto upperBound = stats->icebergUpperBoundExclusive(kTruncLen); + ASSERT_TRUE(upperBound.has_value()); + // 'Z' (0x5A) incremented becomes '[' (0x5B). + EXPECT_EQ(*upperBound, "ZZZZZZZZZZZZZZZ["); +} + +TEST(IcebergStatistics, byteArrayBoundsNoTruncation) { + NodePtr Node = PrimitiveNode::make( + "string_col", + Repetition::kRequired, + Type::kByteArray, + ConvertedType::kUtf8); + ColumnDescriptor descr(Node, 0, 0); + + // Create strings shorter than 16 characters. + auto stats = makeStats(&descr, {"apple", "zebra", "banana"}); + + ASSERT_TRUE(stats->hasMinMax()); + + // For short strings, IcebergLowerBound/IcebergUpperBound should be the same. + // As EncodeMin/EncodeMax. + EXPECT_EQ(stats->icebergLowerBoundInclusive(kTruncLen), stats->encodeMin()); + const auto upperBound = stats->icebergUpperBoundExclusive(kTruncLen); + ASSERT_TRUE(upperBound.has_value()); + EXPECT_EQ(*upperBound, stats->encodeMax()); +} + +TEST(IcebergStatistics, byteArrayBoundsUnicode) { + NodePtr Node = PrimitiveNode::make( + "string_col", + Repetition::kRequired, + Type::kByteArray, + ConvertedType::kUtf8); + ColumnDescriptor descr(Node, 0, 0); + + // Create Unicode strings longer than 16 characters to trigger truncation. + // Truncation is based on character count (16 chars), not byte count. + // "世" is U+4E16 (3 bytes in UTF-8: E4 B8 96) + // "界" is U+754C (3 bytes in UTF-8: E7 95 8C) + // "你" is U+4F60 (3 bytes in UTF-8: E4 BD A0) + // "好" is U+597D (3 bytes in UTF-8: E5 A5 BD) + auto stats = makeStats( + &descr, + {"AAAA世界世界世界世界世界世界世", + "ZZZZ你好你好你好你好你好你好你", + "Hello, 世界!"}); + + ASSERT_TRUE(stats->hasMinMax()); + + const auto lowerBound = stats->icebergLowerBoundInclusive(kTruncLen); + const auto upperBound = stats->icebergUpperBoundExclusive(kTruncLen); + + ASSERT_TRUE(upperBound.has_value()); + + // Str1 has 18 characters (4 ASCII + 14 Chinese). + // Truncated to 16 chars: "AAAA" (4) + "世界世界世界世界世界世界 (12)". + const std::string expectedLower = "AAAA世界世界世界世界世界世界"; + EXPECT_EQ(lowerBound, expectedLower); + + // Str2 has 18 characters (4 ASCII + 14 Chinese). + // For upperBound, the last character is incremented after truncation. + // "好" (U+597D) incremented becomes U+597E "奾". + const std::string expectedUpper = "ZZZZ你好你好你好你好你好你奾"; + EXPECT_EQ(*upperBound, expectedUpper); +} + +// Verifies the lower and upper bounds for non-string ByteArray (BINARY / +// VARBINARY) take a raw-byte prefix and a byte-level round-up rather than +// going through the UTF-8 paths. Inputs include 0xff bytes that would not +// form a valid UTF-8 sequence; the lower bound must be a byte-truncation +// (string_view::substr), and the upper bound must use the binary round-up +// (which propagates carry on 0xff and yields an exclusive upper bound). +TEST(IcebergStatistics, byteArrayBoundsBinary) { + NodePtr Node = PrimitiveNode::make( + "binary_col", + Repetition::kRequired, + Type::kByteArray, + ConvertedType::kNone); + ColumnDescriptor descr(Node, 0, 0); + + // Each value is 20 raw bytes, longer than kTruncLen (16). Bytes include + // 0xff and other non-ASCII values to ensure UTF-8 parsing would not be + // appropriate. + const std::string min(20, '\x10'); + const std::string max = std::string(15, '\xff') + std::string(5, '\x10'); + auto stats = makeStats(&descr, {min, max}); + + ASSERT_TRUE(stats->hasMinMax()); + + // Lower bound: raw-byte prefix of 'min' (16 bytes of 0x10). + const auto lowerBound = stats->icebergLowerBoundInclusive(kTruncLen); + EXPECT_EQ(lowerBound, std::string(16, '\x10')); + + // Upper bound: max truncated to 16 bytes is 15 * 0xff + 1 * 0x10. The + // binary round-up walks from the end and increments the first non-0xff + // byte (0x10 -> 0x11), then truncates anything past it. + const auto upperBound = stats->icebergUpperBoundExclusive(kTruncLen); + ASSERT_TRUE(upperBound.has_value()); + EXPECT_EQ(*upperBound, std::string(15, '\xff') + std::string(1, '\x11')); +} + +TEST(IcebergStatistics, floatBounds) { + NodePtr Node = + PrimitiveNode::make("float_col", Repetition::kRequired, Type::kFloat); + ColumnDescriptor descr(Node, 0, 0); + + auto stats = makeStats(&descr, {1.5f, -2.5f, 3.5f, 0.5f}); + + ASSERT_TRUE(stats->hasMinMax()); + EXPECT_EQ(stats->icebergLowerBoundInclusive(kTruncLen), stats->encodeMin()); + const auto upperBound = stats->icebergUpperBoundExclusive(kTruncLen); + ASSERT_TRUE(upperBound.has_value()); + EXPECT_EQ(*upperBound, stats->encodeMax()); +} + +TEST(IcebergStatistics, doubleBounds) { + NodePtr Node = + PrimitiveNode::make("double_col", Repetition::kRequired, Type::kDouble); + ColumnDescriptor descr(Node, 0, 0); + + auto stats = makeStats(&descr, {1.5, -2.5, 3.5, 0.5}); + + ASSERT_TRUE(stats->hasMinMax()); + EXPECT_EQ(stats->icebergLowerBoundInclusive(kTruncLen), stats->encodeMin()); + const auto upperBound = stats->icebergUpperBoundExclusive(kTruncLen); + ASSERT_TRUE(upperBound.has_value()); + EXPECT_EQ(*upperBound, stats->encodeMax()); +} + +TEST(IcebergStatistics, emptyStringBounds) { + NodePtr Node = PrimitiveNode::make( + "string_col", + Repetition::kRequired, + Type::kByteArray, + ConvertedType::kUtf8); + ColumnDescriptor descr(Node, 0, 0); + + { + auto stats = makeStats(&descr, {"", "hello", ""}); + + ASSERT_TRUE(stats->hasMinMax()); + EXPECT_EQ(stats->icebergLowerBoundInclusive(kTruncLen), ""); + + const auto upperBound = stats->icebergUpperBoundExclusive(kTruncLen); + ASSERT_TRUE(upperBound.has_value()); + EXPECT_EQ(*upperBound, "hello"); + } + + { + auto stats = makeStats(&descr, {"", ""}); + + ASSERT_TRUE(stats->hasMinMax()); + EXPECT_EQ(stats->min(), ByteArray("")); + EXPECT_EQ(stats->max(), ByteArray("")); + + EXPECT_EQ(stats->icebergLowerBoundInclusive(kTruncLen), ""); + + const auto upperBound = stats->icebergUpperBoundExclusive(kTruncLen); + ASSERT_TRUE(upperBound.has_value()); + EXPECT_EQ(*upperBound, ""); + } +} + +TEST(IcebergStatistics, unboundedUpperBound) { + NodePtr Node = PrimitiveNode::make( + "string_col", + Repetition::kRequired, + Type::kByteArray, + ConvertedType::kUtf8); + ColumnDescriptor descr(Node, 0, 0); + + { + std::string allMaxAscii(20, '\x7F'); + auto stats = makeStats(&descr, {"hello", allMaxAscii}); + + ASSERT_TRUE(stats->hasMinMax()); + + const auto lowerBound = stats->icebergLowerBoundInclusive(kTruncLen); + const auto upperBound = stats->icebergUpperBoundExclusive(kTruncLen); + + EXPECT_EQ(lowerBound, stats->encodeMin()); + EXPECT_FALSE(upperBound.has_value()); + } + + { + std::string allMaxUnicode; + allMaxUnicode.reserve(17 * 4); + for (int i = 0; i < 17; ++i) { + allMaxUnicode += "\U0010FFFF"; + } + auto stats = makeStats(&descr, {"hello", allMaxUnicode}); + + ASSERT_TRUE(stats->hasMinMax()); + + const auto lowerBound = stats->icebergLowerBoundInclusive(kTruncLen); + const auto upperBound = stats->icebergUpperBoundExclusive(kTruncLen); + EXPECT_EQ(lowerBound, stats->encodeMin()); + EXPECT_FALSE(upperBound.has_value()); + } + + { + std::string allMaxAscii(20, '\x7F'); + auto stats = makeStats(&descr, {allMaxAscii, allMaxAscii}); + + ASSERT_TRUE(stats->hasMinMax()); + EXPECT_EQ(stats->min(), ByteArray(allMaxAscii)); + EXPECT_EQ(stats->max(), ByteArray(allMaxAscii)); + + const auto lowerBound = stats->icebergLowerBoundInclusive(kTruncLen); + const auto upperBound = stats->icebergUpperBoundExclusive(kTruncLen); + EXPECT_TRUE(stats->encodeMin().starts_with(lowerBound)); + EXPECT_FALSE(upperBound.has_value()); + } + + { + std::string allMaxUnicode; + allMaxUnicode.reserve(17 * 4); + for (int i = 0; i < 17; ++i) { + allMaxUnicode += "\U0010FFFF"; + } + auto stats = makeStats(&descr, {allMaxUnicode, allMaxUnicode}); + + ASSERT_TRUE(stats->hasMinMax()); + EXPECT_EQ(stats->min(), ByteArray(allMaxUnicode)); + EXPECT_EQ(stats->max(), ByteArray(allMaxUnicode)); + + const auto lowerBound = stats->icebergLowerBoundInclusive(kTruncLen); + const auto upperBound = stats->icebergUpperBoundExclusive(kTruncLen); + EXPECT_TRUE(stats->encodeMin().starts_with(lowerBound)); + EXPECT_FALSE(upperBound.has_value()); + } +} + +TEST(StatisticsComparison, withInt64) { + NodePtr Node = + PrimitiveNode::make("int_col", Repetition::kRequired, Type::kInt64); + ColumnDescriptor descr(Node, 0, 0); + + auto stats1 = makeStats(&descr, {10, 20, 30}); + auto stats2 = makeStats(&descr, {5, 15, 25}); + auto stats3 = makeStats(&descr, {10, 20, 30}); + + ASSERT_TRUE(stats1->hasMinMax()); + ASSERT_TRUE(stats2->hasMinMax()); + ASSERT_TRUE(stats3->hasMinMax()); + + EXPECT_TRUE(stats1->maxGreaterThan(*stats2)); + EXPECT_FALSE(stats2->maxGreaterThan(*stats1)); + EXPECT_TRUE(stats1->maxGreaterThan(*stats3)); + EXPECT_TRUE(stats3->maxGreaterThan(*stats1)); + + EXPECT_FALSE(stats1->minLessThan(*stats2)); + EXPECT_TRUE(stats2->minLessThan(*stats1)); + EXPECT_FALSE(stats1->minLessThan(*stats3)); + EXPECT_FALSE(stats3->minLessThan(*stats1)); +} + +TEST(StatisticsComparison, withDouble) { + NodePtr Node = + PrimitiveNode::make("double_col", Repetition::kRequired, Type::kDouble); + ColumnDescriptor descr(Node, 0, 0); + + auto stats1 = makeStats(&descr, {1.0, 2.0, 3.0}); + auto stats2 = makeStats(&descr, {0.5, 1.5, 2.5}); + + ASSERT_TRUE(stats1->hasMinMax()); + ASSERT_TRUE(stats2->hasMinMax()); + + EXPECT_TRUE(stats1->maxGreaterThan(*stats2)); + EXPECT_FALSE(stats2->maxGreaterThan(*stats1)); + + EXPECT_FALSE(stats1->minLessThan(*stats2)); + EXPECT_TRUE(stats2->minLessThan(*stats1)); +} + +TEST(StatisticsComparison, withByteArray) { + NodePtr Node = PrimitiveNode::make( + "string_col", + Repetition::kRequired, + Type::kByteArray, + ConvertedType::kUtf8); + ColumnDescriptor descr(Node, 0, 0); + + auto stats1 = makeStats(&descr, {"apple", "zebra"}); + auto stats2 = makeStats(&descr, {"banana", "mangomango"}); + + ASSERT_TRUE(stats1->hasMinMax()); + ASSERT_TRUE(stats2->hasMinMax()); + + EXPECT_TRUE(stats1->maxGreaterThan(*stats2)); + EXPECT_FALSE(stats2->maxGreaterThan(*stats1)); + + EXPECT_TRUE(stats1->minLessThan(*stats2)); + EXPECT_FALSE(stats2->minLessThan(*stats1)); +} + } // namespace test } // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/tests/StringTruncationTest.cpp b/velox/dwio/parquet/writer/arrow/tests/StringTruncationTest.cpp new file mode 100644 index 00000000000..a0486b0c1d0 --- /dev/null +++ b/velox/dwio/parquet/writer/arrow/tests/StringTruncationTest.cpp @@ -0,0 +1,271 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "velox/dwio/parquet/writer/arrow/StringTruncation.h" + +namespace facebook::velox::parquet::arrow { + +// Tests for string utility functions used in statistics +TEST(StringTruncation, truncateUtf8) { + auto testTruncate = [](const std::string& input, + int32_t numCodePoints, + const std::string& expected) { + EXPECT_EQ(truncateUtf8(input, numCodePoints), expected); + }; + + // ASCII string. + std::string ascii = "Hello, world!"; + testTruncate(ascii, 0, ""); + testTruncate(ascii, 1, "H"); + testTruncate(ascii, 5, "Hello"); + testTruncate(ascii, 13, ascii); + testTruncate(ascii, 20, ascii); + + // String with multi-bytes characters. + std::string unicode = "Hello, 世界!"; + testTruncate(unicode, 7, "Hello, "); + testTruncate(unicode, 8, "Hello, 世"); + testTruncate(unicode, 9, "Hello, 世界"); + testTruncate(unicode, 10, unicode); + testTruncate(unicode, 20, unicode); + + // String with emoji (surrogate pairs). + std::string emoji = "Hello 🌍!"; + testTruncate(emoji, 6, "Hello "); + testTruncate(emoji, 7, "Hello 🌍"); + testTruncate(emoji, 8, emoji); + testTruncate(emoji, 10, emoji); + + std::string empty = ""; + testTruncate(empty, 0, ""); + testTruncate(empty, 5, ""); + + std::string mixed = "café世界🌍"; + testTruncate(mixed, 3, "caf"); + testTruncate(mixed, 4, "café"); + testTruncate(mixed, 5, "café世"); + testTruncate(mixed, 6, "café世界"); + testTruncate(mixed, 7, mixed); +} + +TEST(StringTruncation, roundUpUtf8) { + auto testRoundUp = [](const std::string& input, + int32_t numCodePoints, + const std::optional& expected) { + EXPECT_EQ(roundUpUtf8(input, numCodePoints), expected); + }; + + std::string ascii = "Hello, world!"; + // Empty truncation returns nullopt. + testRoundUp(ascii, 0, std::nullopt); + // 'o' -> 'p'. + testRoundUp(ascii, 5, "Hellp"); + testRoundUp(ascii, ascii.length(), ascii); + + ascii = "Customer#000001500"; + // '5' -> '6'. + testRoundUp(ascii, 16, "Customer#0000016"); + + std::string unicode = "Hello, 世界!"; + testRoundUp(unicode, 8, "Hello, 丗"); + + // No truncation needed. + std::string shortString = "Hi"; + testRoundUp(shortString, 2, shortString); + testRoundUp(shortString, 20, shortString); + + // Last character is already at maximum value, returns nullopt. + std::string maxChar = "Hello\U0010FFFF"; + testRoundUp(maxChar, 6, maxChar); + + std::string empty = ""; + testRoundUp(empty, 0, ""); + testRoundUp(empty, 5, ""); + + std::string single = "a"; + // No truncation needed. + testRoundUp(single, 1, "a"); + + std::string zChar = "zz"; + // 'z' -> '{'. + testRoundUp(zChar, 1, "{"); + + std::string emojiTest = "🌍!!"; + // U1F30D (🌍) -> U1F30E. + testRoundUp(emojiTest, 1, "\U0001F30E"); + + std::string multiByteTest = "café+"; + // 'f' -> 'g'. + testRoundUp(multiByteTest, 3, "cag"); + // 'é' -> 'ê'. + testRoundUp(multiByteTest, 4, "cafê"); + + // Test surrogate boundary: U+D7FF should increment to U+E000 (skipping + // surrogate range U+D800-U+DFFF). + // U+D7FF followed by "!!" + std::string surrogateTest = "\xED\x9F\xBF!!"; + // U+E000 + testRoundUp(surrogateTest, 1, "\xEE\x80\x80"); + + // Test all max code points - should return nullopt. + std::string allMax = "\U0010FFFF\U0010FFFF"; + testRoundUp(allMax, 1, std::nullopt); +} + +TEST(StringTruncation, roundUpBinary) { + auto testRoundUpBinary = [](const std::string& input, + int32_t truncateLength, + const std::optional& expected) { + EXPECT_EQ(roundUpBinary(input, truncateLength), expected); + }; + + // Basic binary data with truncation. + std::string binary = "Hello, world!"; + // Empty truncation returns nullopt. + testRoundUpBinary(binary, 0, std::nullopt); + // 'o' (0x6F) -> 'p' (0x70). + testRoundUpBinary(binary, 5, "Hellp"); + // No truncation needed - returns input unchanged. + testRoundUpBinary(binary, binary.length(), binary); + testRoundUpBinary(binary, binary.length() + 10, binary); + + // Test with numeric data. + std::string numeric = "Customer#000001500"; + // '5' (0x35) -> '6' (0x36). + testRoundUpBinary(numeric, 16, "Customer#0000016"); + + // Test with binary data containing high bytes. + std::string highBytes = "data\xFE\xFD"; + // No truncation needed - returns input unchanged. + testRoundUpBinary(highBytes, 6, highBytes); + // Truncate to 5 bytes "data\xFE", 0xFE -> 0xFF. + testRoundUpBinary(highBytes, 5, "data\xFF"); + + // Test with all 0xFF bytes - should return nullopt. + std::string allFF = "\xFF\xFF\xFF"; + testRoundUpBinary(allFF, 1, std::nullopt); + testRoundUpBinary(allFF, 2, std::nullopt); + // No truncation needed - returns input unchanged. + testRoundUpBinary(allFF, 3, allFF); + + // Test with trailing 0xFF bytes. + std::string trailingFF = "abc\xFF\xFF"; + // No truncation needed - returns input unchanged. + testRoundUpBinary(trailingFF, 5, trailingFF); + // Truncate to 4 bytes "abc\xFF", 0xFF overflows, 'c' (0x63) -> 'd' (0x64). + testRoundUpBinary(trailingFF, 4, "abd"); + // Truncate to 3 bytes "abc", 'c' (0x63) -> 'd' (0x64). + testRoundUpBinary(trailingFF, 3, "abd"); + + // Test empty string. + std::string empty = ""; + testRoundUpBinary(empty, 0, std::nullopt); + testRoundUpBinary(empty, 5, ""); + + // Test single byte. + std::string single = "a"; + // No truncation needed - returns input unchanged. + testRoundUpBinary(single, 1, "a"); + testRoundUpBinary(single, 10, "a"); + + // Test incrementing single byte with truncation. + std::string singleZ = "zz"; + // Truncate to 1 byte "z", 'z' (0x7A) -> '{' (0x7B). + testRoundUpBinary(singleZ, 1, "{"); + + // Test with null bytes. + std::string withNull = std::string("ab\0cd", 5); + // No truncation needed - returns input unchanged. + testRoundUpBinary(withNull, 5, withNull); + // Truncate to 4 bytes "ab\0c", 'c' (0x63) -> 'd' (0x64). + testRoundUpBinary(withNull, 4, std::string("ab\0d", 4)); + + // Test boundary case: 0xFE -> 0xFF. + std::string boundaryFE = "test\xFE"; + // No truncation needed - returns input unchanged. + testRoundUpBinary(boundaryFE, 5, boundaryFE); + // Truncate to 5 bytes and increment would give same result. + std::string boundaryFE2 = std::string("test\xFE", 5) + "abc"; + testRoundUpBinary(boundaryFE2, 5, "test\xFF"); + + // Test mixed case with overflow in middle. + std::string mixedOverflow = "a\xFF\xFFz"; + // Truncate to 3 bytes "a\xFF\xFF", both 0xFF overflow, 'a' (0x61) -> 'b' + // (0x62). + testRoundUpBinary(mixedOverflow, 3, "b"); + + // Test truncation removes trailing bytes after increment. + std::string longString = "abcdefgh"; + // Truncate to 3 bytes "abc", 'c' (0x63) -> 'd' (0x64), result is "abd". + testRoundUpBinary(longString, 3, "abd"); + // Truncate to 5 bytes "abcde", 'e' (0x65) -> 'f' (0x66), result is "abcdf". + testRoundUpBinary(longString, 5, "abcdf"); + + // Test with UTF-8 multi-byte sequences (treated as raw bytes). + std::string utf8Bytes = "café"; + // Truncate to 3 bytes "caf", 'f' (0x66) -> 'g' (0x67). + testRoundUpBinary(utf8Bytes, 3, "cag"); + // No truncation needed - returns input unchanged. + testRoundUpBinary(utf8Bytes, 5, utf8Bytes); + // Truncate to 5 bytes and increment last byte. + std::string utf8Bytes2 = "café!"; + // Truncate to 5 bytes "café" (caf + 0xC3 0xA9), 0xA9 -> 0xAA. + testRoundUpBinary(utf8Bytes2, 5, "caf\xC3\xAA"); + + // Test with INVALID UTF-8 sequences - this is the key use case for + // roundUpBinary. These sequences would cause roundUpUtf8 to fail, but + // roundUpBinary treats them as raw bytes. + + // Invalid UTF-8: lone continuation byte 0x80. + std::string invalidUtf8_1 = std::string("test\x80", 5); + testRoundUpBinary(invalidUtf8_1, 5, invalidUtf8_1); + testRoundUpBinary(invalidUtf8_1, 4, "tesu"); + + // Invalid UTF-8: incomplete multi-byte sequence (0xC3 without continuation). + std::string invalidUtf8_2 = std::string("data\xC3", 5); + testRoundUpBinary(invalidUtf8_2, 5, invalidUtf8_2); + testRoundUpBinary(invalidUtf8_2, 4, "datb"); + + // Invalid UTF-8: overlong encoding (0xC0 0x80 for null byte). + std::string invalidUtf8_3 = std::string("ab\xC0\x80", 4); + testRoundUpBinary(invalidUtf8_3, 4, invalidUtf8_3); + testRoundUpBinary(invalidUtf8_3, 3, std::string("ab\xC1", 3)); + + // Invalid UTF-8: invalid start byte 0xFE. + std::string invalidUtf8_4 = std::string("xyz\xFE", 4); + testRoundUpBinary(invalidUtf8_4, 4, invalidUtf8_4); + testRoundUpBinary(invalidUtf8_4, 3, "xy{"); + + // Invalid UTF-8: truncated 3-byte sequence (0xE0 0x80 without third byte). + std::string invalidUtf8_5 = std::string("foo\xE0\x80", 5); + testRoundUpBinary(invalidUtf8_5, 5, invalidUtf8_5); + testRoundUpBinary(invalidUtf8_5, 4, std::string("foo\xE1", 4)); + + // Invalid UTF-8: sequence with 0xFF (which is never valid in UTF-8). + std::string invalidUtf8_6 = std::string("bar\xFF", 4); + testRoundUpBinary(invalidUtf8_6, 4, invalidUtf8_6); + // Truncate to 3 bytes "bar", 'r' (0x72) -> 's' (0x73). + testRoundUpBinary(invalidUtf8_6, 3, "bas"); + + // Test with all 0xFF in invalid UTF-8 context. + std::string invalidUtf8_7 = std::string("\xFF\xFF\xFF", 3); + testRoundUpBinary(invalidUtf8_7, 2, std::nullopt); + testRoundUpBinary(invalidUtf8_7, 3, invalidUtf8_7); +} + +} // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/tests/TestUtil.cpp b/velox/dwio/parquet/writer/arrow/tests/TestUtil.cpp index 47156197e1b..58b25d47cc3 100644 --- a/velox/dwio/parquet/writer/arrow/tests/TestUtil.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/TestUtil.cpp @@ -36,7 +36,7 @@ namespace facebook::velox::parquet::arrow { namespace test { -const char* get_data_dir() { +const char* getDataDir() { const auto result = std::getenv("PARQUET_TEST_DATA"); if (!result || !result[0]) { throw ParquetTestException( @@ -46,30 +46,30 @@ const char* get_data_dir() { return result; } -std::string get_bad_data_dir() { - // PARQUET_TEST_DATA should point to - // ARROW_HOME/cpp/submodules/parquet-testing/data so need to reach one folder - // up to access the "bad_data" folder. - std::string data_dir(get_data_dir()); +std::string getBadDataDir() { + // PARQUET_TEST_DATA should point to. + // ARROW_HOME/cpp/submodules/parquet-testing/data so need to reach one folder. + // Up to access the "bad_data" folder. + std::string dataDir(getDataDir()); std::stringstream ss; - ss << data_dir << "/../bad_data"; + ss << dataDir << "/../bad_data"; return ss.str(); } -std::string get_data_file(const std::string& filename, bool is_good) { +std::string getDataFile(const std::string& filename, bool isGood) { std::stringstream ss; - if (is_good) { - ss << get_data_dir(); + if (isGood) { + ss << getDataDir(); } else { - ss << get_bad_data_dir(); + ss << getBadDataDir(); } ss << "/" << filename; return ss.str(); } -void random_bytes(int n, uint32_t seed, std::vector* out) { +void randomBytes(int n, uint32_t seed, std::vector* out) { std::default_random_engine gen(seed); std::uniform_int_distribution d(0, 255); @@ -79,7 +79,7 @@ void random_bytes(int n, uint32_t seed, std::vector* out) { } } -void random_bools(int n, double p, uint32_t seed, bool* out) { +void randomBools(int n, double p, uint32_t seed, bool* out) { std::default_random_engine gen(seed); std::bernoulli_distribution d(p); for (int i = 0; i < n; ++i) { @@ -87,14 +87,14 @@ void random_bools(int n, double p, uint32_t seed, bool* out) { } } -void random_Int96_numbers( +void randomInt96Numbers( int n, uint32_t seed, - int32_t min_value, - int32_t max_value, + int32_t minValue, + int32_t maxValue, Int96* out) { std::default_random_engine gen(seed); - std::uniform_int_distribution d(min_value, max_value); + std::uniform_int_distribution d(minValue, maxValue); for (int i = 0; i < n; ++i) { out[i].value[0] = d(gen); out[i].value[1] = d(gen); @@ -102,7 +102,7 @@ void random_Int96_numbers( } } -void random_fixed_byte_array( +void randomFixedByteArray( int n, uint32_t seed, uint8_t* buf, @@ -119,15 +119,15 @@ void random_fixed_byte_array( } } -void random_byte_array( +void randomByteArray( int n, uint32_t seed, uint8_t* buf, ByteArray* out, - int min_size, - int max_size) { + int minSize, + int maxSize) { std::default_random_engine gen(seed); - std::uniform_int_distribution d1(min_size, max_size); + std::uniform_int_distribution d1(minSize, maxSize); std::uniform_int_distribution d2(0, 255); for (int i = 0; i < n; ++i) { int len = d1(gen); @@ -140,13 +140,13 @@ void random_byte_array( } } -void random_byte_array( +void randomByteArray( int n, uint32_t seed, uint8_t* buf, ByteArray* out, - int max_size) { - random_byte_array(n, seed, buf, out, 0, max_size); + int maxSize) { + randomByteArray(n, seed, buf, out, 0, maxSize); } } // namespace test diff --git a/velox/dwio/parquet/writer/arrow/tests/TestUtil.h b/velox/dwio/parquet/writer/arrow/tests/TestUtil.h index 79f3dc3992c..019f79541c4 100644 --- a/velox/dwio/parquet/writer/arrow/tests/TestUtil.h +++ b/velox/dwio/parquet/writer/arrow/tests/TestUtil.h @@ -38,14 +38,14 @@ #include "velox/dwio/parquet/writer/arrow/tests/ColumnReader.h" // https://github.com/google/googletest/pull/2904 might not be available -// in our version of gtest/gmock -#define EXPECT_THROW_THAT(callable, ex_type, property) \ - EXPECT_THROW( \ - try { (callable)(); } catch (const ex_type& err) { \ - EXPECT_THAT(err, (property)); \ - throw; \ - }, \ - ex_type) +// In our version of gtest/gmock. +#define EXPECT_THROW_THAT(callable, exType, property) \ + EXPECT_THROW( \ + try { (callable)(); } catch (const exType& err) { \ + EXPECT_THAT(err, (property)); \ + throw; \ + }, \ + exType) namespace facebook::velox::parquet::arrow { @@ -72,13 +72,13 @@ class ParquetTestException : public ParquetException { using ParquetException::ParquetException; }; -const char* get_data_dir(); -std::string get_bad_data_dir(); +const char* getDataDir(); +std::string getBadDataDir(); -std::string get_data_file(const std::string& filename, bool is_good = true); +std::string getDataFile(const std::string& filename, bool isGood = true); template -static inline void assert_vector_equal( +static inline void assertVectorEqual( const std::vector& left, const std::vector& right) { ASSERT_EQ(left.size(), right.size()); @@ -89,7 +89,7 @@ static inline void assert_vector_equal( } template -static inline bool vector_equal( +static inline bool vectorEqual( const std::vector& left, const std::vector& right) { if (left.size() != right.size()) { @@ -120,96 +120,96 @@ static std::vector slice(const std::vector& values, int start, int end) { return out; } -void random_bytes(int n, uint32_t seed, std::vector* out); -void random_bools(int n, double p, uint32_t seed, bool* out); +void randomBytes(int n, uint32_t seed, std::vector* out); +void randomBools(int n, double p, uint32_t seed, bool* out); template inline void -random_numbers(int n, uint32_t seed, T min_value, T max_value, T* out) { +randomNumbers(int n, uint32_t seed, T minValue, T maxValue, T* out) { std::default_random_engine gen(seed); - std::uniform_int_distribution d(min_value, max_value); + std::uniform_int_distribution d(minValue, maxValue); for (int i = 0; i < n; ++i) { out[i] = d(gen); } } template <> -inline void random_numbers( +inline void randomNumbers( int n, uint32_t seed, - float min_value, - float max_value, + float minValue, + float maxValue, float* out) { std::default_random_engine gen(seed); - std::uniform_real_distribution d(min_value, max_value); + std::uniform_real_distribution d(minValue, maxValue); for (int i = 0; i < n; ++i) { out[i] = d(gen); } } template <> -inline void random_numbers( +inline void randomNumbers( int n, uint32_t seed, - double min_value, - double max_value, + double minValue, + double maxValue, double* out) { std::default_random_engine gen(seed); - std::uniform_real_distribution d(min_value, max_value); + std::uniform_real_distribution d(minValue, maxValue); for (int i = 0; i < n; ++i) { out[i] = d(gen); } } -void random_Int96_numbers( +void randomInt96Numbers( int n, uint32_t seed, - int32_t min_value, - int32_t max_value, + int32_t minValue, + int32_t maxValue, Int96* out); -void random_fixed_byte_array( +void randomFixedByteArray( int n, uint32_t seed, uint8_t* buf, int len, FLBA* out); -void random_byte_array( +void randomByteArray( int n, uint32_t seed, uint8_t* buf, ByteArray* out, - int min_size, - int max_size); + int minSize, + int maxSize); -void random_byte_array( +void randomByteArray( int n, uint32_t seed, uint8_t* buf, ByteArray* out, - int max_size); + int maxSize); template -std::shared_ptr EncodeValues( +std::shared_ptr encodeValues( Encoding::type encoding, - bool use_dictionary, + bool useDictionary, const Sequence& values, int length, const ColumnDescriptor* descr) { - auto encoder = MakeTypedEncoder(encoding, use_dictionary, descr); - encoder->Put(values, length); - return encoder->FlushValues(); + auto encoder = makeTypedEncoder(encoding, useDictionary, descr); + encoder->put(values, length); + return encoder->flushValues(); } template -static void InitValues( - int num_values, +static void initValues( + int numValues, uint32_t seed, std::vector& values, std::vector& buffer) { - random_numbers( - num_values, + randomNumbers( + numValues, seed, std::numeric_limits::min(), std::numeric_limits::max(), @@ -217,283 +217,283 @@ static void InitValues( } template -static void InitValues( - int num_values, +static void initValues( + int numValues, std::vector& values, std::vector& buffer) { - InitValues(num_values, 0, values, buffer); + initValues(numValues, 0, values, buffer); } template -static void InitDictValues( - int num_values, - int num_dicts, +static void initDictValues( + int numValues, + int numDicts, std::vector& values, std::vector& buffer) { - int repeat_factor = num_values / num_dicts; - InitValues(num_dicts, values, buffer); - // add some repeated values - for (int j = 1; j < repeat_factor; ++j) { - for (int i = 0; i < num_dicts; ++i) { - std::memcpy(&values[num_dicts * j + i], &values[i], sizeof(T)); + int repeatFactor = numValues / numDicts; + initValues(numDicts, values, buffer); + // Add some repeated values. + for (int j = 1; j < repeatFactor; ++j) { + for (int i = 0; i < numDicts; ++i) { + std::memcpy(&values[numDicts * j + i], &values[i], sizeof(T)); } } - // computed only dict_per_page * repeat_factor - 1 values < num_values - // compute remaining - for (int i = num_dicts * repeat_factor; i < num_values; ++i) { - std::memcpy(&values[i], &values[i - num_dicts * repeat_factor], sizeof(T)); + // Computed only dict_per_page * repeat_factor - 1 values < num_values. + // Compute remaining. + for (int i = numDicts * repeatFactor; i < numValues; ++i) { + std::memcpy(&values[i], &values[i - numDicts * repeatFactor], sizeof(T)); } } template <> -inline void InitDictValues( - int num_values, - int num_dicts, +inline void initDictValues( + int numValues, + int numDicts, std::vector& values, std::vector& buffer) { - // No op for bool + // No op for bool. } class MockPageReader : public PageReader { public: explicit MockPageReader(const std::vector>& pages) - : pages_(pages), page_index_(0) {} + : pages_(pages), pageIndex_(0) {} - std::shared_ptr NextPage() override { - if (page_index_ == static_cast(pages_.size())) { - // EOS to consumer + std::shared_ptr nextPage() override { + if (pageIndex_ == static_cast(pages_.size())) { + // EOS to consumer. return std::shared_ptr(nullptr); } - return pages_[page_index_++]; + return pages_[pageIndex_++]; } - // No-op - void set_max_page_header_size(uint32_t size) override {} + // No-op. + void setMaxPageHeaderSize(uint32_t size) override {} private: std::vector> pages_; - int page_index_; + int pageIndex_; }; -// TODO(wesm): this is only used for testing for now. Refactor to form part of -// primary file write path +// TODO(wesm): this is only used for testing for now. Refactor to form part of. +// Primary file write path. template class DataPageBuilder { public: - using c_type = typename Type::c_type; + using CType = typename Type::CType; - // This class writes data and metadata to the passed inputs + // This class writes data and metadata to the passed inputs. explicit DataPageBuilder(ArrowOutputStream* sink) : sink_(sink), - num_values_(0), - encoding_(Encoding::PLAIN), - definition_level_encoding_(Encoding::RLE), - repetition_level_encoding_(Encoding::RLE), - have_def_levels_(false), - have_rep_levels_(false), - have_values_(false) {} - - void AppendDefLevels( + numValues_(0), + encoding_(Encoding::kPlain), + definitionLevelEncoding_(Encoding::kRle), + repetitionLevelEncoding_(Encoding::kRle), + haveDefLevels_(false), + haveRepLevels_(false), + haveValues_(false) {} + + void appendDefLevels( const std::vector& levels, - int16_t max_level, - Encoding::type encoding = Encoding::RLE) { - AppendLevels(levels, max_level, encoding); + int16_t maxLevel, + Encoding::type encoding = Encoding::kRle) { + appendLevels(levels, maxLevel, encoding); - num_values_ = std::max(static_cast(levels.size()), num_values_); - definition_level_encoding_ = encoding; - have_def_levels_ = true; + numValues_ = std::max(static_cast(levels.size()), numValues_); + definitionLevelEncoding_ = encoding; + haveDefLevels_ = true; } - void AppendRepLevels( + void appendRepLevels( const std::vector& levels, - int16_t max_level, - Encoding::type encoding = Encoding::RLE) { - AppendLevels(levels, max_level, encoding); + int16_t maxLevel, + Encoding::type encoding = Encoding::kRle) { + appendLevels(levels, maxLevel, encoding); - num_values_ = std::max(static_cast(levels.size()), num_values_); - repetition_level_encoding_ = encoding; - have_rep_levels_ = true; + numValues_ = std::max(static_cast(levels.size()), numValues_); + repetitionLevelEncoding_ = encoding; + haveRepLevels_ = true; } - void AppendValues( + void appendValues( const ColumnDescriptor* d, - const std::vector& values, - Encoding::type encoding = Encoding::PLAIN) { - std::shared_ptr values_sink = EncodeValues( + const std::vector& values, + Encoding::type encoding = Encoding::kPlain) { + std::shared_ptr valuesSink = encodeValues( encoding, false, values.data(), static_cast(values.size()), d); - PARQUET_THROW_NOT_OK( - sink_->Write(values_sink->data(), values_sink->size())); + PARQUET_THROW_NOT_OK(sink_->Write(valuesSink->data(), valuesSink->size())); - num_values_ = std::max(static_cast(values.size()), num_values_); + numValues_ = std::max(static_cast(values.size()), numValues_); encoding_ = encoding; - have_values_ = true; + haveValues_ = true; } - int32_t num_values() const { - return num_values_; + int32_t numValues() const { + return numValues_; } Encoding::type encoding() const { return encoding_; } - Encoding::type rep_level_encoding() const { - return repetition_level_encoding_; + Encoding::type repLevelEncoding() const { + return repetitionLevelEncoding_; } - Encoding::type def_level_encoding() const { - return definition_level_encoding_; + Encoding::type defLevelEncoding() const { + return definitionLevelEncoding_; } private: ArrowOutputStream* sink_; - int32_t num_values_; + int32_t numValues_; Encoding::type encoding_; - Encoding::type definition_level_encoding_; - Encoding::type repetition_level_encoding_; + Encoding::type definitionLevelEncoding_; + Encoding::type repetitionLevelEncoding_; - bool have_def_levels_; - bool have_rep_levels_; - bool have_values_; + bool haveDefLevels_; + bool haveRepLevels_; + bool haveValues_; - // Used internally for both repetition and definition levels - void AppendLevels( + // Used internally for both repetition and definition levels. + void appendLevels( const std::vector& levels, - int16_t max_level, + int16_t maxLevel, Encoding::type encoding) { - if (encoding != Encoding::RLE) { + if (encoding != Encoding::kRle) { ParquetException::NYI("only rle encoding currently implemented"); } - std::vector encode_buffer(LevelEncoder::MaxBufferSize( - Encoding::RLE, max_level, static_cast(levels.size()))); + std::vector encodeBuffer( + LevelEncoder::maxBufferSize( + Encoding::kRle, maxLevel, static_cast(levels.size()))); - // We encode into separate memory from the output stream because the - // RLE-encoded bytes have to be preceded in the stream by their absolute - // size. + // We encode into separate memory from the output stream because the. + // RLE-encoded bytes have to be preceded in the stream by their absolute. + // Size. LevelEncoder encoder; - encoder.Init( + encoder.init( encoding, - max_level, + maxLevel, static_cast(levels.size()), - encode_buffer.data(), - static_cast(encode_buffer.size())); + encodeBuffer.data(), + static_cast(encodeBuffer.size())); - encoder.Encode(static_cast(levels.size()), levels.data()); + encoder.encode(static_cast(levels.size()), levels.data()); - int32_t rle_bytes = encoder.len(); + int32_t rleBytes = encoder.len(); PARQUET_THROW_NOT_OK(sink_->Write( - reinterpret_cast(&rle_bytes), sizeof(int32_t))); - PARQUET_THROW_NOT_OK(sink_->Write(encode_buffer.data(), rle_bytes)); + reinterpret_cast(&rleBytes), sizeof(int32_t))); + PARQUET_THROW_NOT_OK(sink_->Write(encodeBuffer.data(), rleBytes)); } }; template <> -inline void DataPageBuilder::AppendValues( +inline void DataPageBuilder::appendValues( const ColumnDescriptor* d, const std::vector& values, Encoding::type encoding) { - if (encoding != Encoding::PLAIN) { + if (encoding != Encoding::kPlain) { ParquetException::NYI("only plain encoding currently implemented"); } - auto encoder = MakeTypedEncoder(Encoding::PLAIN, false, d); + auto encoder = makeTypedEncoder(Encoding::kPlain, false, d); dynamic_cast(encoder.get()) - ->Put(values, static_cast(values.size())); - std::shared_ptr buffer = encoder->FlushValues(); + ->put(values, static_cast(values.size())); + std::shared_ptr buffer = encoder->flushValues(); PARQUET_THROW_NOT_OK(sink_->Write(buffer->data(), buffer->size())); - num_values_ = std::max(static_cast(values.size()), num_values_); + numValues_ = std::max(static_cast(values.size()), numValues_); encoding_ = encoding; - have_values_ = true; + haveValues_ = true; } template -static std::shared_ptr MakeDataPage( +static std::shared_ptr makeDataPage( const ColumnDescriptor* d, - const std::vector& values, - int num_vals, + const std::vector& values, + int numVals, Encoding::type encoding, const uint8_t* indices, - int indices_size, - const std::vector& def_levels, - int16_t max_def_level, - const std::vector& rep_levels, - int16_t max_rep_level) { - int num_values = 0; + int indicesSize, + const std::vector& defLevels, + int16_t maxDefLevel, + const std::vector& repLevels, + int16_t maxRepLevel) { + int numValues = 0; - auto page_stream = CreateOutputStream(); - test::DataPageBuilder page_builder(page_stream.get()); + auto pageStream = createOutputStream(); + test::DataPageBuilder pageBuilder(pageStream.get()); - if (!rep_levels.empty()) { - page_builder.AppendRepLevels(rep_levels, max_rep_level); + if (!repLevels.empty()) { + pageBuilder.appendRepLevels(repLevels, maxRepLevel); } - if (!def_levels.empty()) { - page_builder.AppendDefLevels(def_levels, max_def_level); + if (!defLevels.empty()) { + pageBuilder.appendDefLevels(defLevels, maxDefLevel); } - if (encoding == Encoding::PLAIN) { - page_builder.AppendValues(d, values, encoding); - num_values = std::max(page_builder.num_values(), num_vals); + if (encoding == Encoding::kPlain) { + pageBuilder.appendValues(d, values, encoding); + numValues = std::max(pageBuilder.numValues(), numVals); } else { // DICTIONARY PAGES - PARQUET_THROW_NOT_OK(page_stream->Write(indices, indices_size)); - num_values = std::max(page_builder.num_values(), num_vals); + PARQUET_THROW_NOT_OK(pageStream->Write(indices, indicesSize)); + numValues = std::max(pageBuilder.numValues(), numVals); } - PARQUET_ASSIGN_OR_THROW(auto buffer, page_stream->Finish()); + PARQUET_ASSIGN_OR_THROW(auto buffer, pageStream->Finish()); return std::make_shared( buffer, - num_values, + numValues, encoding, - page_builder.def_level_encoding(), - page_builder.rep_level_encoding(), + pageBuilder.defLevelEncoding(), + pageBuilder.repLevelEncoding(), buffer->size()); } template class DictionaryPageBuilder { public: - typedef typename TYPE::c_type TC; - static constexpr int TN = TYPE::type_num; + typedef typename TYPE::CType TC; + static constexpr int TN = TYPE::typeNum; using SpecializedEncoder = typename EncodingTraits::Encoder; - // This class writes data and metadata to the passed inputs + // This class writes data and metadata to the passed inputs. explicit DictionaryPageBuilder(const ColumnDescriptor* d) - : num_dict_values_(0), have_values_(false) { - auto encoder = MakeTypedEncoder(Encoding::PLAIN, true, d); - dict_traits_ = dynamic_cast*>(encoder.get()); + : numDictValues_(0), haveValues_(false) { + auto encoder = makeTypedEncoder(Encoding::kPlain, true, d); + dictTraits_ = dynamic_cast*>(encoder.get()); encoder_.reset(dynamic_cast(encoder.release())); } ~DictionaryPageBuilder() {} - std::shared_ptr AppendValues(const std::vector& values) { - int num_values = static_cast(values.size()); - // Dictionary encoding - encoder_->Put(values.data(), num_values); - num_dict_values_ = dict_traits_->num_entries(); - have_values_ = true; - return encoder_->FlushValues(); + std::shared_ptr appendValues(const std::vector& values) { + int numValues = static_cast(values.size()); + // Dictionary encoding. + encoder_->put(values.data(), numValues); + numDictValues_ = dictTraits_->numEntries(); + haveValues_ = true; + return encoder_->flushValues(); } - std::shared_ptr WriteDict() { - std::shared_ptr dict_buffer = AllocateBuffer( - ::arrow::default_memory_pool(), dict_traits_->dict_encoded_size()); - dict_traits_->WriteDict(dict_buffer->mutable_data()); - return dict_buffer; + std::shared_ptr writeDict() { + std::shared_ptr dictBuffer = allocateBuffer( + ::arrow::default_memory_pool(), dictTraits_->dictEncodedSize()); + dictTraits_->writeDict(dictBuffer->mutable_data()); + return dictBuffer; } - int32_t num_values() const { - return num_dict_values_; + int32_t numValues() const { + return numDictValues_; } private: - DictEncoder* dict_traits_; + DictEncoder* dictTraits_; std::unique_ptr encoder_; - int32_t num_dict_values_; - bool have_values_; + int32_t numDictValues_; + bool haveValues_; }; template <> @@ -504,14 +504,14 @@ inline DictionaryPageBuilder::DictionaryPageBuilder( } template <> -inline std::shared_ptr DictionaryPageBuilder::WriteDict() { +inline std::shared_ptr DictionaryPageBuilder::writeDict() { ParquetException::NYI( "only plain encoding currently implemented for boolean"); return nullptr; } template <> -inline std::shared_ptr DictionaryPageBuilder::AppendValues( +inline std::shared_ptr DictionaryPageBuilder::appendValues( const std::vector& values) { ParquetException::NYI( "only plain encoding currently implemented for boolean"); @@ -519,213 +519,212 @@ inline std::shared_ptr DictionaryPageBuilder::AppendValues( } template -inline static std::shared_ptr MakeDictPage( +inline static std::shared_ptr makeDictPage( const ColumnDescriptor* d, - const std::vector& values, - const std::vector& values_per_page, + const std::vector& values, + const std::vector& valuesPerPage, Encoding::type encoding, - std::vector>& rle_indices) { - test::DictionaryPageBuilder page_builder(d); - int num_pages = static_cast(values_per_page.size()); - int value_start = 0; + std::vector>& rleIndices) { + test::DictionaryPageBuilder pageBuilder(d); + int numPages = static_cast(valuesPerPage.size()); + int valueStart = 0; - for (int i = 0; i < num_pages; i++) { - rle_indices.push_back(page_builder.AppendValues( - slice(values, value_start, value_start + values_per_page[i]))); - value_start += values_per_page[i]; + for (int i = 0; i < numPages; i++) { + rleIndices.push_back(pageBuilder.appendValues( + slice(values, valueStart, valueStart + valuesPerPage[i]))); + valueStart += valuesPerPage[i]; } - auto buffer = page_builder.WriteDict(); + auto buffer = pageBuilder.writeDict(); return std::make_shared( - buffer, page_builder.num_values(), Encoding::PLAIN); + buffer, pageBuilder.numValues(), Encoding::kPlain); } -// Given def/rep levels and values create multiple dict pages +// Given def/rep levels and values create multiple dict pages. template -inline static void PaginateDict( +inline static void paginateDict( const ColumnDescriptor* d, - const std::vector& values, - const std::vector& def_levels, - int16_t max_def_level, - const std::vector& rep_levels, - int16_t max_rep_level, - int num_levels_per_page, - const std::vector& values_per_page, + const std::vector& values, + const std::vector& defLevels, + int16_t maxDefLevel, + const std::vector& repLevels, + int16_t maxRepLevel, + int numLevelsPerPage, + const std::vector& valuesPerPage, std::vector>& pages, - Encoding::type encoding = Encoding::RLE_DICTIONARY) { - int num_pages = static_cast(values_per_page.size()); - std::vector> rle_indices; - std::shared_ptr dict_page = - MakeDictPage(d, values, values_per_page, encoding, rle_indices); - pages.push_back(dict_page); - int def_level_start = 0; - int def_level_end = 0; - int rep_level_start = 0; - int rep_level_end = 0; - for (int i = 0; i < num_pages; i++) { - if (max_def_level > 0) { - def_level_start = i * num_levels_per_page; - def_level_end = (i + 1) * num_levels_per_page; + Encoding::type encoding = Encoding::kRleDictionary) { + int numPages = static_cast(valuesPerPage.size()); + std::vector> rleIndices; + std::shared_ptr dictPage = + makeDictPage(d, values, valuesPerPage, encoding, rleIndices); + pages.push_back(dictPage); + int defLevelStart = 0; + int defLevelEnd = 0; + int repLevelStart = 0; + int repLevelEnd = 0; + for (int i = 0; i < numPages; i++) { + if (maxDefLevel > 0) { + defLevelStart = i * numLevelsPerPage; + defLevelEnd = (i + 1) * numLevelsPerPage; } - if (max_rep_level > 0) { - rep_level_start = i * num_levels_per_page; - rep_level_end = (i + 1) * num_levels_per_page; + if (maxRepLevel > 0) { + repLevelStart = i * numLevelsPerPage; + repLevelEnd = (i + 1) * numLevelsPerPage; } - std::shared_ptr data_page = MakeDataPage( + std::shared_ptr dataPage = makeDataPage( d, {}, - values_per_page[i], + valuesPerPage[i], encoding, - rle_indices[i]->data(), - static_cast(rle_indices[i]->size()), - slice(def_levels, def_level_start, def_level_end), - max_def_level, - slice(rep_levels, rep_level_start, rep_level_end), - max_rep_level); - pages.push_back(data_page); + rleIndices[i]->data(), + static_cast(rleIndices[i]->size()), + slice(defLevels, defLevelStart, defLevelEnd), + maxDefLevel, + slice(repLevels, repLevelStart, repLevelEnd), + maxRepLevel); + pages.push_back(dataPage); } } -// Given def/rep levels and values create multiple plain pages +// Given def/rep levels and values create multiple plain pages. template -static inline void PaginatePlain( +static inline void paginatePlain( const ColumnDescriptor* d, - const std::vector& values, - const std::vector& def_levels, - int16_t max_def_level, - const std::vector& rep_levels, - int16_t max_rep_level, - int num_levels_per_page, - const std::vector& values_per_page, + const std::vector& values, + const std::vector& defLevels, + int16_t maxDefLevel, + const std::vector& repLevels, + int16_t maxRepLevel, + int numLevelsPerPage, + const std::vector& valuesPerPage, std::vector>& pages, - Encoding::type encoding = Encoding::PLAIN) { - int num_pages = static_cast(values_per_page.size()); - int def_level_start = 0; - int def_level_end = 0; - int rep_level_start = 0; - int rep_level_end = 0; - int value_start = 0; - for (int i = 0; i < num_pages; i++) { - if (max_def_level > 0) { - def_level_start = i * num_levels_per_page; - def_level_end = (i + 1) * num_levels_per_page; + Encoding::type encoding = Encoding::kPlain) { + int numPages = static_cast(valuesPerPage.size()); + int defLevelStart = 0; + int defLevelEnd = 0; + int repLevelStart = 0; + int repLevelEnd = 0; + int valueStart = 0; + for (int i = 0; i < numPages; i++) { + if (maxDefLevel > 0) { + defLevelStart = i * numLevelsPerPage; + defLevelEnd = (i + 1) * numLevelsPerPage; } - if (max_rep_level > 0) { - rep_level_start = i * num_levels_per_page; - rep_level_end = (i + 1) * num_levels_per_page; + if (maxRepLevel > 0) { + repLevelStart = i * numLevelsPerPage; + repLevelEnd = (i + 1) * numLevelsPerPage; } - std::shared_ptr page = MakeDataPage( + std::shared_ptr page = makeDataPage( d, - slice(values, value_start, value_start + values_per_page[i]), - values_per_page[i], + slice(values, valueStart, valueStart + valuesPerPage[i]), + valuesPerPage[i], encoding, nullptr, 0, - slice(def_levels, def_level_start, def_level_end), - max_def_level, - slice(rep_levels, rep_level_start, rep_level_end), - max_rep_level); + slice(defLevels, defLevelStart, defLevelEnd), + maxDefLevel, + slice(repLevels, repLevelStart, repLevelEnd), + maxRepLevel); pages.push_back(page); - value_start += values_per_page[i]; + valueStart += valuesPerPage[i]; } } -// Generates pages from randomly generated data +// Generates pages from randomly generated data. template -static inline int MakePages( +static inline int makePages( const ColumnDescriptor* d, - int num_pages, - int levels_per_page, - std::vector& def_levels, - std::vector& rep_levels, - std::vector& values, + int numPages, + int levelsPerPage, + std::vector& defLevels, + std::vector& repLevels, + std::vector& values, std::vector& buffer, std::vector>& pages, - Encoding::type encoding = Encoding::PLAIN, + Encoding::type encoding = Encoding::kPlain, uint32_t seed = 0) { - int num_levels = levels_per_page * num_pages; - int num_values = 0; + int numLevels = levelsPerPage * numPages; + int numValues = 0; int16_t zero = 0; - int16_t max_def_level = d->max_definition_level(); - int16_t max_rep_level = d->max_repetition_level(); - std::vector values_per_page(num_pages, levels_per_page); - // Create definition levels - if (max_def_level > 0 && num_levels != 0) { - def_levels.resize(num_levels); - random_numbers(num_levels, seed, zero, max_def_level, def_levels.data()); - for (int p = 0; p < num_pages; p++) { - int num_values_per_page = 0; - for (int i = 0; i < levels_per_page; i++) { - if (def_levels[i + p * levels_per_page] == max_def_level) { - num_values_per_page++; - num_values++; + int16_t maxDefLevel = d->maxDefinitionLevel(); + int16_t maxRepLevel = d->maxRepetitionLevel(); + std::vector valuesPerPage(numPages, levelsPerPage); + // Create definition levels. + if (maxDefLevel > 0 && numLevels != 0) { + defLevels.resize(numLevels); + randomNumbers(numLevels, seed, zero, maxDefLevel, defLevels.data()); + for (int p = 0; p < numPages; p++) { + int numValuesPerPage = 0; + for (int i = 0; i < levelsPerPage; i++) { + if (defLevels[i + p * levelsPerPage] == maxDefLevel) { + numValuesPerPage++; + numValues++; } } - values_per_page[p] = num_values_per_page; + valuesPerPage[p] = numValuesPerPage; } } else { - num_values = num_levels; + numValues = numLevels; } - // Create repitition levels - if (max_rep_level > 0 && num_levels != 0) { - rep_levels.resize(num_levels); + // Create repitition levels. + if (maxRepLevel > 0 && numLevels != 0) { + repLevels.resize(numLevels); // Using a different seed so that def_levels and rep_levels are different. - random_numbers( - num_levels, seed + 789, zero, max_rep_level, rep_levels.data()); - // The generated levels are random. Force the very first page to start with - // a new record. - rep_levels[0] = 0; + randomNumbers(numLevels, seed + 789, zero, maxRepLevel, repLevels.data()); + // The generated levels are random. Force the very first page to start with. + // A new record. + repLevels[0] = 0; // For a null value, rep_levels and def_levels are both 0. - // If we have a repeated value right after this, it needs to start with - // rep_level = 0 to indicate a new record. - for (int i = 0; i < num_levels - 1; ++i) { - if (rep_levels[i] == 0 && def_levels[i] == 0) { - rep_levels[i + 1] = 0; + // If we have a repeated value right after this, it needs to start with. + // Rep_level = 0 to indicate a new record. + for (int i = 0; i < numLevels - 1; ++i) { + if (repLevels[i] == 0 && defLevels[i] == 0) { + repLevels[i + 1] = 0; } } } - // Create values - values.resize(num_values); - if (encoding == Encoding::PLAIN) { - InitValues(num_values, values, buffer); - PaginatePlain( + // Create values. + values.resize(numValues); + if (encoding == Encoding::kPlain) { + initValues(numValues, values, buffer); + paginatePlain( d, values, - def_levels, - max_def_level, - rep_levels, - max_rep_level, - levels_per_page, - values_per_page, + defLevels, + maxDefLevel, + repLevels, + maxRepLevel, + levelsPerPage, + valuesPerPage, pages); } else if ( - encoding == Encoding::RLE_DICTIONARY || - encoding == Encoding::PLAIN_DICTIONARY) { - // Calls InitValues and repeats the data - InitDictValues( - num_values, levels_per_page, values, buffer); - PaginateDict( + encoding == Encoding::kRleDictionary || + encoding == Encoding::kPlainDictionary) { + // Calls InitValues and repeats the data. + initDictValues( + numValues, levelsPerPage, values, buffer); + paginateDict( d, values, - def_levels, - max_def_level, - rep_levels, - max_rep_level, - levels_per_page, - values_per_page, + defLevels, + maxDefLevel, + repLevels, + maxRepLevel, + levelsPerPage, + valuesPerPage, pages); } - return num_values; + return numValues; } -// ---------------------------------------------------------------------- -// Test data generation +// ----------------------------------------------------------------------. +// Test data generation. template <> -void inline InitValues( - int num_values, +void inline initValues( + int numValues, uint32_t seed, std::vector& values, std::vector& buffer) { @@ -733,185 +732,183 @@ void inline InitValues( if (seed == 0) { seed = static_cast(::arrow::random_seed()); } - ::arrow::random_is_valid(num_values, 0.5, &values, static_cast(seed)); + ::arrow::random_is_valid(numValues, 0.5, &values, static_cast(seed)); } template <> -inline void InitValues( - int num_values, +inline void initValues( + int numValues, uint32_t seed, std::vector& values, std::vector& buffer) { - int max_byte_array_len = 12; - int num_bytes = static_cast(max_byte_array_len + sizeof(uint32_t)); - size_t nbytes = num_values * num_bytes; + int maxByteArrayLen = 12; + int numBytes = static_cast(maxByteArrayLen + sizeof(uint32_t)); + size_t nbytes = numValues * numBytes; buffer.resize(nbytes); - random_byte_array( - num_values, seed, buffer.data(), values.data(), max_byte_array_len); + randomByteArray( + numValues, seed, buffer.data(), values.data(), maxByteArrayLen); } -inline void InitWideByteArrayValues( - int num_values, +inline void initWideByteArrayValues( + int numValues, std::vector& values, std::vector& buffer, - int min_len, - int max_len) { - int num_bytes = static_cast(max_len + sizeof(uint32_t)); - size_t nbytes = num_values * num_bytes; + int minLen, + int maxLen) { + int numBytes = static_cast(maxLen + sizeof(uint32_t)); + size_t nbytes = numValues * numBytes; buffer.resize(nbytes); - random_byte_array( - num_values, 0, buffer.data(), values.data(), min_len, max_len); + randomByteArray(numValues, 0, buffer.data(), values.data(), minLen, maxLen); } template <> -inline void InitValues( - int num_values, +inline void initValues( + int numValues, uint32_t seed, std::vector& values, std::vector& buffer) { - size_t nbytes = num_values * FLBA_LENGTH; + size_t nbytes = numValues * FLBA_LENGTH; buffer.resize(nbytes); - random_fixed_byte_array( - num_values, seed, buffer.data(), FLBA_LENGTH, values.data()); + randomFixedByteArray( + numValues, seed, buffer.data(), FLBA_LENGTH, values.data()); } template <> -inline void InitValues( - int num_values, +inline void initValues( + int numValues, uint32_t seed, std::vector& values, std::vector& buffer) { - random_Int96_numbers( - num_values, + randomInt96Numbers( + numValues, seed, std::numeric_limits::min(), std::numeric_limits::max(), values.data()); } -inline std::string TestColumnName(int i) { - std::stringstream col_name; - col_name << "column_" << i; - return col_name.str(); +inline std::string testColumnName(int i) { + std::stringstream colName; + colName << "column_" << i; + return colName.str(); } -// This class lives here because of its dependency on the InitValues -// specializations. +// This class lives here because of its dependency on the InitValues. +// Specializations. template class PrimitiveTypedTest : public ::testing::Test { public: - using c_type = typename TestType::c_type; + using CType = typename TestType::CType; - void SetUpSchema(Repetition::type repetition, int num_columns = 1) { + void setUpSchema(Repetition::type repetition, int numColumns = 1) { std::vector fields; - for (int i = 0; i < num_columns; ++i) { - std::string name = TestColumnName(i); - fields.push_back(schema::PrimitiveNode::Make( - name, - repetition, - TestType::type_num, - ConvertedType::NONE, - FLBA_LENGTH)); + for (int i = 0; i < numColumns; ++i) { + std::string name = testColumnName(i); + fields.push_back( + schema::PrimitiveNode::make( + name, + repetition, + TestType::typeNum, + ConvertedType::kNone, + FLBA_LENGTH)); } - node_ = schema::GroupNode::Make("schema", Repetition::REQUIRED, fields); - schema_.Init(node_); + node_ = schema::GroupNode::make("schema", Repetition::kRequired, fields); + schema_.init(node_); } - void GenerateData(int64_t num_values, uint32_t seed = 0); - void SetupValuesOut(int64_t num_values); - void SyncValuesOut(); + void generateData(int64_t numValues, uint32_t seed = 0); + void setupValuesOut(int64_t numValues); + void syncValuesOut(); protected: schema::NodePtr node_; SchemaDescriptor schema_; - // Input buffers - std::vector values_; + // Input buffers. + std::vector values_; - std::vector def_levels_; + std::vector defLevels_; std::vector buffer_; // Pointer to the values, needed as we cannot use std::vector::data() - c_type* values_ptr_; - std::vector bool_buffer_; + CType* valuesPtr_; + std::vector boolBuffer_; - // Output buffers - std::vector values_out_; - std::vector bool_buffer_out_; - c_type* values_out_ptr_; + // Output buffers. + std::vector valuesOut_; + std::vector boolBufferOut_; + CType* valuesOutPtr_; }; template -inline void PrimitiveTypedTest::SyncValuesOut() {} +inline void PrimitiveTypedTest::syncValuesOut() {} template <> -inline void PrimitiveTypedTest::SyncValuesOut() { - std::vector::const_iterator source_iterator = - bool_buffer_out_.begin(); - std::vector::iterator destination_iterator = values_out_.begin(); - while (source_iterator != bool_buffer_out_.end()) { - *destination_iterator++ = *source_iterator++ != 0; +inline void PrimitiveTypedTest::syncValuesOut() { + std::vector::const_iterator sourceIterator = boolBufferOut_.begin(); + std::vector::iterator destinationIterator = valuesOut_.begin(); + while (sourceIterator != boolBufferOut_.end()) { + *destinationIterator++ = *sourceIterator++ != 0; } } template -inline void PrimitiveTypedTest::SetupValuesOut(int64_t num_values) { - values_out_.clear(); - values_out_.resize(num_values); - values_out_ptr_ = values_out_.data(); +inline void PrimitiveTypedTest::setupValuesOut(int64_t numValues) { + valuesOut_.clear(); + valuesOut_.resize(numValues); + valuesOutPtr_ = valuesOut_.data(); } template <> -inline void PrimitiveTypedTest::SetupValuesOut( - int64_t num_values) { - values_out_.clear(); - values_out_.resize(num_values); +inline void PrimitiveTypedTest::setupValuesOut(int64_t numValues) { + valuesOut_.clear(); + valuesOut_.resize(numValues); - bool_buffer_out_.clear(); - bool_buffer_out_.resize(num_values); - // Write once to all values so we can copy it without getting Valgrind errors - // about uninitialised values. - std::fill(bool_buffer_out_.begin(), bool_buffer_out_.end(), true); - values_out_ptr_ = reinterpret_cast(bool_buffer_out_.data()); + boolBufferOut_.clear(); + boolBufferOut_.resize(numValues); + // Write once to all values so we can copy it without getting Valgrind errors. + // About uninitialised values. + std::fill(boolBufferOut_.begin(), boolBufferOut_.end(), true); + valuesOutPtr_ = reinterpret_cast(boolBufferOut_.data()); } template -inline void PrimitiveTypedTest::GenerateData( - int64_t num_values, +inline void PrimitiveTypedTest::generateData( + int64_t numValues, uint32_t seed) { - def_levels_.resize(num_values); - values_.resize(num_values); + defLevels_.resize(numValues); + values_.resize(numValues); - InitValues(static_cast(num_values), seed, values_, buffer_); - values_ptr_ = values_.data(); + initValues(static_cast(numValues), seed, values_, buffer_); + valuesPtr_ = values_.data(); - std::fill(def_levels_.begin(), def_levels_.end(), 1); + std::fill(defLevels_.begin(), defLevels_.end(), 1); } template <> -inline void PrimitiveTypedTest::GenerateData( - int64_t num_values, +inline void PrimitiveTypedTest::generateData( + int64_t numValues, uint32_t seed) { - def_levels_.resize(num_values); - values_.resize(num_values); + defLevels_.resize(numValues); + values_.resize(numValues); - InitValues(static_cast(num_values), seed, values_, buffer_); - bool_buffer_.resize(num_values); - std::copy(values_.begin(), values_.end(), bool_buffer_.begin()); - values_ptr_ = reinterpret_cast(bool_buffer_.data()); + initValues(static_cast(numValues), seed, values_, buffer_); + boolBuffer_.resize(numValues); + std::copy(values_.begin(), values_.end(), boolBuffer_.begin()); + valuesPtr_ = reinterpret_cast(boolBuffer_.data()); - std::fill(def_levels_.begin(), def_levels_.end(), 1); + std::fill(defLevels_.begin(), defLevels_.end(), 1); } -// ---------------------------------------------------------------------- -// test data generation +// ----------------------------------------------------------------------. +// Test data generation. template -inline void GenerateData(int num_values, T* out, std::vector* heap) { - // seed the prng so failure is deterministic - random_numbers( - num_values, +inline void generateData(int numValues, T* out, std::vector* heap) { + // Seed the prng so failure is deterministic. + randomNumbers( + numValues, 0, std::numeric_limits::min(), std::numeric_limits::max(), @@ -919,29 +916,29 @@ inline void GenerateData(int num_values, T* out, std::vector* heap) { } template -inline void GenerateBoundData( - int num_values, +inline void generateBoundData( + int numValues, T* out, T min, T max, std::vector* heap) { - // seed the prng so failure is deterministic - random_numbers(num_values, 0, min, max, out); + // Seed the prng so failure is deterministic. + randomNumbers(numValues, 0, min, max, out); } template <> inline void -GenerateData(int num_values, bool* out, std::vector* heap) { - // seed the prng so failure is deterministic - random_bools(num_values, 0.5, 0, out); +generateData(int numValues, bool* out, std::vector* heap) { + // Seed the prng so failure is deterministic. + randomBools(numValues, 0.5, 0, out); } template <> inline void -GenerateData(int num_values, Int96* out, std::vector* heap) { - // seed the prng so failure is deterministic - random_Int96_numbers( - num_values, +generateData(int numValues, Int96* out, std::vector* heap) { + // Seed the prng so failure is deterministic. + randomInt96Numbers( + numValues, 0, std::numeric_limits::min(), std::numeric_limits::max(), @@ -949,25 +946,25 @@ GenerateData(int num_values, Int96* out, std::vector* heap) { } template <> -inline void GenerateData( - int num_values, +inline void generateData( + int numValues, ByteArray* out, std::vector* heap) { - // seed the prng so failure is deterministic - int max_byte_array_len = 12; - heap->resize(num_values * max_byte_array_len); - random_byte_array(num_values, 0, heap->data(), out, 2, max_byte_array_len); + // Seed the prng so failure is deterministic. + int maxByteArrayLen = 12; + heap->resize(numValues * maxByteArrayLen); + randomByteArray(numValues, 0, heap->data(), out, 2, maxByteArrayLen); } static constexpr int kGenerateDataFLBALength = 8; template <> inline void -GenerateData(int num_values, FLBA* out, std::vector* heap) { - // seed the prng so failure is deterministic - heap->resize(num_values * kGenerateDataFLBALength); - random_fixed_byte_array( - num_values, 0, heap->data(), kGenerateDataFLBALength, out); +generateData(int numValues, FLBA* out, std::vector* heap) { + // Seed the prng so failure is deterministic. + heap->resize(numValues * kGenerateDataFLBALength); + randomFixedByteArray( + numValues, 0, heap->data(), kGenerateDataFLBALength, out); } } // namespace test diff --git a/velox/dwio/parquet/writer/arrow/tests/TypesTest.cpp b/velox/dwio/parquet/writer/arrow/tests/TypesTest.cpp index 495b02c1ee8..476cc7f9fea 100644 --- a/velox/dwio/parquet/writer/arrow/tests/TypesTest.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/TypesTest.cpp @@ -26,54 +26,54 @@ namespace facebook::velox::parquet::arrow { TEST(TestTypeToString, PhysicalTypes) { - ASSERT_STREQ("BOOLEAN", TypeToString(Type::BOOLEAN).c_str()); - ASSERT_STREQ("INT32", TypeToString(Type::INT32).c_str()); - ASSERT_STREQ("INT64", TypeToString(Type::INT64).c_str()); - ASSERT_STREQ("INT96", TypeToString(Type::INT96).c_str()); - ASSERT_STREQ("FLOAT", TypeToString(Type::FLOAT).c_str()); - ASSERT_STREQ("DOUBLE", TypeToString(Type::DOUBLE).c_str()); - ASSERT_STREQ("BYTE_ARRAY", TypeToString(Type::BYTE_ARRAY).c_str()); + ASSERT_STREQ("BOOLEAN", typeToString(Type::kBoolean).c_str()); + ASSERT_STREQ("INT32", typeToString(Type::kInt32).c_str()); + ASSERT_STREQ("INT64", typeToString(Type::kInt64).c_str()); + ASSERT_STREQ("INT96", typeToString(Type::kInt96).c_str()); + ASSERT_STREQ("FLOAT", typeToString(Type::kFloat).c_str()); + ASSERT_STREQ("DOUBLE", typeToString(Type::kDouble).c_str()); + ASSERT_STREQ("BYTE_ARRAY", typeToString(Type::kByteArray).c_str()); ASSERT_STREQ( - "FIXED_LEN_BYTE_ARRAY", TypeToString(Type::FIXED_LEN_BYTE_ARRAY).c_str()); + "FIXED_LEN_BYTE_ARRAY", typeToString(Type::kFixedLenByteArray).c_str()); } TEST(TestConvertedTypeToString, ConvertedTypes) { - ASSERT_STREQ("NONE", ConvertedTypeToString(ConvertedType::NONE).c_str()); - ASSERT_STREQ("UTF8", ConvertedTypeToString(ConvertedType::UTF8).c_str()); - ASSERT_STREQ("MAP", ConvertedTypeToString(ConvertedType::MAP).c_str()); + ASSERT_STREQ("NONE", convertedTypeToString(ConvertedType::kNone).c_str()); + ASSERT_STREQ("UTF8", convertedTypeToString(ConvertedType::kUtf8).c_str()); + ASSERT_STREQ("MAP", convertedTypeToString(ConvertedType::kMap).c_str()); ASSERT_STREQ( "MAP_KEY_VALUE", - ConvertedTypeToString(ConvertedType::MAP_KEY_VALUE).c_str()); - ASSERT_STREQ("LIST", ConvertedTypeToString(ConvertedType::LIST).c_str()); - ASSERT_STREQ("ENUM", ConvertedTypeToString(ConvertedType::ENUM).c_str()); + convertedTypeToString(ConvertedType::kMapKeyValue).c_str()); + ASSERT_STREQ("LIST", convertedTypeToString(ConvertedType::kList).c_str()); + ASSERT_STREQ("ENUM", convertedTypeToString(ConvertedType::kEnum).c_str()); ASSERT_STREQ( - "DECIMAL", ConvertedTypeToString(ConvertedType::DECIMAL).c_str()); - ASSERT_STREQ("DATE", ConvertedTypeToString(ConvertedType::DATE).c_str()); + "DECIMAL", convertedTypeToString(ConvertedType::kDecimal).c_str()); + ASSERT_STREQ("DATE", convertedTypeToString(ConvertedType::kDate).c_str()); ASSERT_STREQ( - "TIME_MILLIS", ConvertedTypeToString(ConvertedType::TIME_MILLIS).c_str()); + "TIME_MILLIS", convertedTypeToString(ConvertedType::kTimeMillis).c_str()); ASSERT_STREQ( - "TIME_MICROS", ConvertedTypeToString(ConvertedType::TIME_MICROS).c_str()); + "TIME_MICROS", convertedTypeToString(ConvertedType::kTimeMicros).c_str()); ASSERT_STREQ( "TIMESTAMP_MILLIS", - ConvertedTypeToString(ConvertedType::TIMESTAMP_MILLIS).c_str()); + convertedTypeToString(ConvertedType::kTimestampMillis).c_str()); ASSERT_STREQ( "TIMESTAMP_MICROS", - ConvertedTypeToString(ConvertedType::TIMESTAMP_MICROS).c_str()); - ASSERT_STREQ("UINT_8", ConvertedTypeToString(ConvertedType::UINT_8).c_str()); + convertedTypeToString(ConvertedType::kTimestampMicros).c_str()); + ASSERT_STREQ("UINT_8", convertedTypeToString(ConvertedType::kUint8).c_str()); ASSERT_STREQ( - "UINT_16", ConvertedTypeToString(ConvertedType::UINT_16).c_str()); + "UINT_16", convertedTypeToString(ConvertedType::kUint16).c_str()); ASSERT_STREQ( - "UINT_32", ConvertedTypeToString(ConvertedType::UINT_32).c_str()); + "UINT_32", convertedTypeToString(ConvertedType::kUint32).c_str()); ASSERT_STREQ( - "UINT_64", ConvertedTypeToString(ConvertedType::UINT_64).c_str()); - ASSERT_STREQ("INT_8", ConvertedTypeToString(ConvertedType::INT_8).c_str()); - ASSERT_STREQ("INT_16", ConvertedTypeToString(ConvertedType::INT_16).c_str()); - ASSERT_STREQ("INT_32", ConvertedTypeToString(ConvertedType::INT_32).c_str()); - ASSERT_STREQ("INT_64", ConvertedTypeToString(ConvertedType::INT_64).c_str()); - ASSERT_STREQ("JSON", ConvertedTypeToString(ConvertedType::JSON).c_str()); - ASSERT_STREQ("BSON", ConvertedTypeToString(ConvertedType::BSON).c_str()); + "UINT_64", convertedTypeToString(ConvertedType::kUint64).c_str()); + ASSERT_STREQ("INT_8", convertedTypeToString(ConvertedType::kInt8).c_str()); + ASSERT_STREQ("INT_16", convertedTypeToString(ConvertedType::kInt16).c_str()); + ASSERT_STREQ("INT_32", convertedTypeToString(ConvertedType::kInt32).c_str()); + ASSERT_STREQ("INT_64", convertedTypeToString(ConvertedType::kInt64).c_str()); + ASSERT_STREQ("JSON", convertedTypeToString(ConvertedType::kJson).c_str()); + ASSERT_STREQ("BSON", convertedTypeToString(ConvertedType::kBson).c_str()); ASSERT_STREQ( - "INTERVAL", ConvertedTypeToString(ConvertedType::INTERVAL).c_str()); + "INTERVAL", convertedTypeToString(ConvertedType::kInterval).c_str()); } #ifdef __GNUC__ @@ -87,86 +87,87 @@ TEST(TestConvertedTypeToString, ConvertedTypes) { TEST(TypePrinter, StatisticsTypes) { std::string smin; std::string smax; - int32_t int_min = 1024; - int32_t int_max = 2048; - smin = std::string(reinterpret_cast(&int_min), sizeof(int32_t)); - smax = std::string(reinterpret_cast(&int_max), sizeof(int32_t)); - ASSERT_STREQ("1024", FormatStatValue(Type::INT32, smin).c_str()); - ASSERT_STREQ("2048", FormatStatValue(Type::INT32, smax).c_str()); - - int64_t int64_min = 10240000000000; - int64_t int64_max = 20480000000000; - smin = std::string(reinterpret_cast(&int64_min), sizeof(int64_t)); - smax = std::string(reinterpret_cast(&int64_max), sizeof(int64_t)); - ASSERT_STREQ("10240000000000", FormatStatValue(Type::INT64, smin).c_str()); - ASSERT_STREQ("20480000000000", FormatStatValue(Type::INT64, smax).c_str()); - - float float_min = 1.024f; - float float_max = 2.048f; - smin = std::string(reinterpret_cast(&float_min), sizeof(float)); - smax = std::string(reinterpret_cast(&float_max), sizeof(float)); - ASSERT_STREQ("1.024", FormatStatValue(Type::FLOAT, smin).c_str()); - ASSERT_STREQ("2.048", FormatStatValue(Type::FLOAT, smax).c_str()); - - double double_min = 1.0245; - double double_max = 2.0489; - smin = std::string(reinterpret_cast(&double_min), sizeof(double)); - smax = std::string(reinterpret_cast(&double_max), sizeof(double)); - ASSERT_STREQ("1.0245", FormatStatValue(Type::DOUBLE, smin).c_str()); - ASSERT_STREQ("2.0489", FormatStatValue(Type::DOUBLE, smax).c_str()); + int32_t intMin = 1024; + int32_t intMax = 2048; + smin = std::string(reinterpret_cast(&intMin), sizeof(int32_t)); + smax = std::string(reinterpret_cast(&intMax), sizeof(int32_t)); + ASSERT_STREQ("1024", formatStatValue(Type::kInt32, smin).c_str()); + ASSERT_STREQ("2048", formatStatValue(Type::kInt32, smax).c_str()); + + int64_t int64Min = 10240000000000; + int64_t int64Max = 20480000000000; + smin = std::string(reinterpret_cast(&int64Min), sizeof(int64_t)); + smax = std::string(reinterpret_cast(&int64Max), sizeof(int64_t)); + ASSERT_STREQ("10240000000000", formatStatValue(Type::kInt64, smin).c_str()); + ASSERT_STREQ("20480000000000", formatStatValue(Type::kInt64, smax).c_str()); + + float floatMin = 1.024f; + float floatMax = 2.048f; + smin = std::string(reinterpret_cast(&floatMin), sizeof(float)); + smax = std::string(reinterpret_cast(&floatMax), sizeof(float)); + ASSERT_STREQ("1.024", formatStatValue(Type::kFloat, smin).c_str()); + ASSERT_STREQ("2.048", formatStatValue(Type::kFloat, smax).c_str()); + + double doubleMin = 1.0245; + double doubleMax = 2.0489; + smin = std::string(reinterpret_cast(&doubleMin), sizeof(double)); + smax = std::string(reinterpret_cast(&doubleMax), sizeof(double)); + ASSERT_STREQ("1.0245", formatStatValue(Type::kDouble, smin).c_str()); + ASSERT_STREQ("2.0489", formatStatValue(Type::kDouble, smax).c_str()); #if ARROW_LITTLE_ENDIAN - Int96 Int96_min = {{1024, 2048, 4096}}; - Int96 Int96_max = {{2048, 4096, 8192}}; + Int96 int96Min = {{1024, 2048, 4096}}; + Int96 int96Max = {{2048, 4096, 8192}}; #else - Int96 Int96_min = {{2048, 1024, 4096}}; - Int96 Int96_max = {{4096, 2048, 8192}}; + Int96 int96Min = {{2048, 1024, 4096}}; + Int96 int96Max = {{4096, 2048, 8192}}; #endif - smin = std::string(reinterpret_cast(&Int96_min), sizeof(Int96)); - smax = std::string(reinterpret_cast(&Int96_max), sizeof(Int96)); - ASSERT_STREQ("1024 2048 4096", FormatStatValue(Type::INT96, smin).c_str()); - ASSERT_STREQ("2048 4096 8192", FormatStatValue(Type::INT96, smax).c_str()); + smin = std::string(reinterpret_cast(&int96Min), sizeof(Int96)); + smax = std::string(reinterpret_cast(&int96Max), sizeof(Int96)); + ASSERT_STREQ("1024 2048 4096", formatStatValue(Type::kInt96, smin).c_str()); + ASSERT_STREQ("2048 4096 8192", formatStatValue(Type::kInt96, smax).c_str()); smin = std::string("abcdef"); smax = std::string("ijklmnop"); - ASSERT_STREQ("abcdef", FormatStatValue(Type::BYTE_ARRAY, smin).c_str()); - ASSERT_STREQ("ijklmnop", FormatStatValue(Type::BYTE_ARRAY, smax).c_str()); + ASSERT_STREQ("abcdef", formatStatValue(Type::kByteArray, smin).c_str()); + ASSERT_STREQ("ijklmnop", formatStatValue(Type::kByteArray, smax).c_str()); - // PARQUET-1357: FormatStatValue truncates binary statistics on zero character + // PARQUET-1357: FormatStatValue truncates binary statistics on zero + // character. smax.push_back('\0'); - ASSERT_EQ(smax, FormatStatValue(Type::BYTE_ARRAY, smax)); + ASSERT_EQ(smax, formatStatValue(Type::kByteArray, smax)); smin = std::string("abcdefgh"); smax = std::string("ijklmnop"); ASSERT_STREQ( - "abcdefgh", FormatStatValue(Type::FIXED_LEN_BYTE_ARRAY, smin).c_str()); + "abcdefgh", formatStatValue(Type::kFixedLenByteArray, smin).c_str()); ASSERT_STREQ( - "ijklmnop", FormatStatValue(Type::FIXED_LEN_BYTE_ARRAY, smax).c_str()); + "ijklmnop", formatStatValue(Type::kFixedLenByteArray, smax).c_str()); } TEST(TestInt96Timestamp, Decoding) { - auto check = [](int32_t julian_day, uint64_t nanoseconds) { + auto check = [](int32_t julianDay, uint64_t nanoseconds) { #if ARROW_LITTLE_ENDIAN Int96 i96{ static_cast(nanoseconds), static_cast(nanoseconds >> 32), - static_cast(julian_day)}; + static_cast(julianDay)}; #else Int96 i96{ static_cast(nanoseconds >> 32), static_cast(nanoseconds), - static_cast(julian_day)}; + static_cast(julianDay)}; #endif - // Official formula according to + // Official formula according to. // https://github.com/apache/parquet-format/pull/49 int64_t expected = - (julian_day - 2440588) * (86400LL * 1000 * 1000 * 1000) + nanoseconds; - int64_t actual = Int96GetNanoSeconds(i96); + (julianDay - 2440588) * (86400LL * 1000 * 1000 * 1000) + nanoseconds; + int64_t actual = int96GetNanoSeconds(i96); ASSERT_EQ(expected, actual); }; - // [2333837, 2547339] is the range of Julian days that can be converted to - // 64-bit Unix timestamps. + // [2333837, 2547339] Is the range of Julian days that can be converted to. + // 64-Bit Unix timestamps. check(2333837, 0); check(2333855, 0); check(2547330, 0); diff --git a/velox/dwio/parquet/writer/arrow/tests/XxHasher.cpp b/velox/dwio/parquet/writer/arrow/tests/XxHasher.cpp index cecdfc9e7de..a7b47b27d76 100644 --- a/velox/dwio/parquet/writer/arrow/tests/XxHasher.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/XxHasher.cpp @@ -18,88 +18,87 @@ #include "velox/dwio/parquet/writer/arrow/tests/XxHasher.h" -#define XXH_INLINE_ALL -#include +#include "velox/common/base/XxHashInline.h" namespace facebook::velox::parquet::arrow { namespace { template -uint64_t XxHashHelper(T value, uint32_t seed) { +uint64_t xxHashHelper(T value, uint32_t seed) { return XXH64(reinterpret_cast(&value), sizeof(T), seed); } template -void XxHashesHelper( +void xxHashesHelper( const T* values, uint32_t seed, - int num_values, + int numValues, uint64_t* results) { - for (int i = 0; i < num_values; ++i) { - results[i] = XxHashHelper(values[i], seed); + for (int i = 0; i < numValues; ++i) { + results[i] = xxHashHelper(values[i], seed); } } } // namespace -uint64_t XxHasher::Hash(int32_t value) const { - return XxHashHelper(value, kParquetBloomXxHashSeed); +uint64_t XxHasher::hash(int32_t value) const { + return xxHashHelper(value, kParquetBloomXxHashSeed); } -uint64_t XxHasher::Hash(int64_t value) const { - return XxHashHelper(value, kParquetBloomXxHashSeed); +uint64_t XxHasher::hash(int64_t value) const { + return xxHashHelper(value, kParquetBloomXxHashSeed); } -uint64_t XxHasher::Hash(float value) const { - return XxHashHelper(value, kParquetBloomXxHashSeed); +uint64_t XxHasher::hash(float value) const { + return xxHashHelper(value, kParquetBloomXxHashSeed); } -uint64_t XxHasher::Hash(double value) const { - return XxHashHelper(value, kParquetBloomXxHashSeed); +uint64_t XxHasher::hash(double value) const { + return xxHashHelper(value, kParquetBloomXxHashSeed); } -uint64_t XxHasher::Hash(const FLBA* value, uint32_t len) const { +uint64_t XxHasher::hash(const FLBA* value, uint32_t len) const { return XXH64( reinterpret_cast(value->ptr), len, kParquetBloomXxHashSeed); } -uint64_t XxHasher::Hash(const Int96* value) const { +uint64_t XxHasher::hash(const Int96* value) const { return XXH64( reinterpret_cast(value->value), sizeof(value->value), kParquetBloomXxHashSeed); } -uint64_t XxHasher::Hash(const ByteArray* value) const { +uint64_t XxHasher::hash(const ByteArray* value) const { return XXH64( reinterpret_cast(value->ptr), value->len, kParquetBloomXxHashSeed); } -void XxHasher::Hashes(const int32_t* values, int num_values, uint64_t* hashes) +void XxHasher::hashes(const int32_t* values, int numValues, uint64_t* hashes) const { - XxHashesHelper(values, kParquetBloomXxHashSeed, num_values, hashes); + xxHashesHelper(values, kParquetBloomXxHashSeed, numValues, hashes); } -void XxHasher::Hashes(const int64_t* values, int num_values, uint64_t* hashes) +void XxHasher::hashes(const int64_t* values, int numValues, uint64_t* hashes) const { - XxHashesHelper(values, kParquetBloomXxHashSeed, num_values, hashes); + xxHashesHelper(values, kParquetBloomXxHashSeed, numValues, hashes); } -void XxHasher::Hashes(const float* values, int num_values, uint64_t* hashes) +void XxHasher::hashes(const float* values, int numValues, uint64_t* hashes) const { - XxHashesHelper(values, kParquetBloomXxHashSeed, num_values, hashes); + xxHashesHelper(values, kParquetBloomXxHashSeed, numValues, hashes); } -void XxHasher::Hashes(const double* values, int num_values, uint64_t* hashes) +void XxHasher::hashes(const double* values, int numValues, uint64_t* hashes) const { - XxHashesHelper(values, kParquetBloomXxHashSeed, num_values, hashes); + xxHashesHelper(values, kParquetBloomXxHashSeed, numValues, hashes); } -void XxHasher::Hashes(const Int96* values, int num_values, uint64_t* hashes) +void XxHasher::hashes(const Int96* values, int numValues, uint64_t* hashes) const { - for (int i = 0; i < num_values; ++i) { + for (int i = 0; i < numValues; ++i) { hashes[i] = XXH64( reinterpret_cast(values[i].value), sizeof(values[i].value), @@ -107,9 +106,9 @@ void XxHasher::Hashes(const Int96* values, int num_values, uint64_t* hashes) } } -void XxHasher::Hashes(const ByteArray* values, int num_values, uint64_t* hashes) +void XxHasher::hashes(const ByteArray* values, int numValues, uint64_t* hashes) const { - for (int i = 0; i < num_values; ++i) { + for (int i = 0; i < numValues; ++i) { hashes[i] = XXH64( reinterpret_cast(values[i].ptr), values[i].len, @@ -117,15 +116,15 @@ void XxHasher::Hashes(const ByteArray* values, int num_values, uint64_t* hashes) } } -void XxHasher::Hashes( +void XxHasher::hashes( const FLBA* values, - uint32_t type_len, - int num_values, + uint32_t typeLen, + int numValues, uint64_t* hashes) const { - for (int i = 0; i < num_values; ++i) { + for (int i = 0; i < numValues; ++i) { hashes[i] = XXH64( reinterpret_cast(values[i].ptr), - type_len, + typeLen, kParquetBloomXxHashSeed); } } diff --git a/velox/dwio/parquet/writer/arrow/tests/XxHasher.h b/velox/dwio/parquet/writer/arrow/tests/XxHasher.h index 48650dba114..7b51cffc6de 100644 --- a/velox/dwio/parquet/writer/arrow/tests/XxHasher.h +++ b/velox/dwio/parquet/writer/arrow/tests/XxHasher.h @@ -28,30 +28,30 @@ namespace facebook::velox::parquet::arrow { class PARQUET_EXPORT XxHasher : public Hasher { public: - uint64_t Hash(int32_t value) const override; - uint64_t Hash(int64_t value) const override; - uint64_t Hash(float value) const override; - uint64_t Hash(double value) const override; - uint64_t Hash(const Int96* value) const override; - uint64_t Hash(const ByteArray* value) const override; - uint64_t Hash(const FLBA* val, uint32_t len) const override; - - void Hashes(const int32_t* values, int num_values, uint64_t* hashes) + uint64_t hash(int32_t value) const override; + uint64_t hash(int64_t value) const override; + uint64_t hash(float value) const override; + uint64_t hash(double value) const override; + uint64_t hash(const Int96* value) const override; + uint64_t hash(const ByteArray* value) const override; + uint64_t hash(const FLBA* val, uint32_t len) const override; + + void hashes(const int32_t* values, int numValues, uint64_t* hashes) const override; - void Hashes(const int64_t* values, int num_values, uint64_t* hashes) + void hashes(const int64_t* values, int numValues, uint64_t* hashes) const override; - void Hashes(const float* values, int num_values, uint64_t* hashes) + void hashes(const float* values, int numValues, uint64_t* hashes) const override; - void Hashes(const double* values, int num_values, uint64_t* hashes) + void hashes(const double* values, int numValues, uint64_t* hashes) const override; - void Hashes(const Int96* values, int num_values, uint64_t* hashes) + void hashes(const Int96* values, int numValues, uint64_t* hashes) const override; - void Hashes(const ByteArray* values, int num_values, uint64_t* hashes) + void hashes(const ByteArray* values, int numValues, uint64_t* hashes) const override; - void Hashes( + void hashes( const FLBA* values, - uint32_t type_len, - int num_values, + uint32_t typeLen, + int numValues, uint64_t* hashes) const override; static constexpr int kParquetBloomXxHashSeed = 0; diff --git a/velox/dwio/parquet/writer/arrow/util/ByteStreamSplitInternal.h b/velox/dwio/parquet/writer/arrow/util/ByteStreamSplitInternal.h index 45e5b025ff1..aa2f83f6763 100644 --- a/velox/dwio/parquet/writer/arrow/util/ByteStreamSplitInternal.h +++ b/velox/dwio/parquet/writer/arrow/util/ByteStreamSplitInternal.h @@ -27,21 +27,21 @@ #include #ifdef ARROW_HAVE_SSE4_2 -// Enable the SIMD for ByteStreamSplit Encoder/Decoder +// Enable the SIMD for ByteStreamSplit Encoder/Decoder. #define ARROW_HAVE_SIMD_SPLIT #endif // ARROW_HAVE_SSE4_2 namespace facebook::velox::parquet::arrow { // -// SIMD implementations +// SIMD implementations. // #if defined(ARROW_HAVE_SSE4_2) template -void ByteStreamSplitDecodeSse2( +void byteStreamSplitDecodeSse2( const uint8_t* data, - int64_t num_values, + int64_t numValues, int64_t stride, T* out) { constexpr size_t kNumStreams = sizeof(T); @@ -50,87 +50,88 @@ void ByteStreamSplitDecodeSse2( constexpr size_t kNumStreamsLog2 = (kNumStreams == 8U ? 3U : 2U); constexpr int64_t kBlockSize = sizeof(__m128i) * kNumStreams; - const int64_t size = num_values * sizeof(T); - const int64_t num_blocks = size / kBlockSize; - uint8_t* output_data = reinterpret_cast(out); + const int64_t size = numValues * sizeof(T); + const int64_t numBlocks = size / kBlockSize; + uint8_t* outputData = reinterpret_cast(out); // First handle suffix. - // This helps catch if the simd-based processing overflows into the suffix - // since almost surely a test would fail. - const int64_t num_processed_elements = - (num_blocks * kBlockSize) / kNumStreams; - for (int64_t i = num_processed_elements; i < num_values; ++i) { - uint8_t gathered_byte_data[kNumStreams]; + // This helps catch if the simd-based processing overflows into the suffix. + // Since almost surely a test would fail. + const int64_t numProcessedElements = (numBlocks * kBlockSize) / kNumStreams; + for (int64_t i = numProcessedElements; i < numValues; ++i) { + uint8_t gatheredByteData[kNumStreams]; for (size_t b = 0; b < kNumStreams; ++b) { - const size_t byte_index = b * stride + i; - gathered_byte_data[b] = data[byte_index]; + const size_t byteIndex = b * stride + i; + gatheredByteData[b] = data[byteIndex]; } - out[i] = arrow::util::SafeLoadAs(&gathered_byte_data[0]); + out[i] = arrow::util::SafeLoadAs(&gatheredByteData[0]); } // The blocks get processed hierarchically using the unpack intrinsics. // Example with four streams: - // Stage 1: AAAA BBBB CCCC DDDD - // Stage 2: ACAC ACAC BDBD BDBD - // Stage 3: ABCD ABCD ABCD ABCD + // Stage 1: AAAA BBBB CCCC DDDD. + // Stage 2: ACAC ACAC BDBD BDBD. + // Stage 3: ABCD ABCD ABCD ABCD. __m128i stage[kNumStreamsLog2 + 1U][kNumStreams]; constexpr size_t kNumStreamsHalf = kNumStreams / 2U; - for (int64_t i = 0; i < num_blocks; ++i) { + for (int64_t i = 0; i < numBlocks; ++i) { for (size_t j = 0; j < kNumStreams; ++j) { - stage[0][j] = _mm_loadu_si128(reinterpret_cast( - &data[i * sizeof(__m128i) + j * stride])); + stage[0][j] = mmLoaduSi128( + reinterpret_cast( + &data[i * sizeof(__m128i) + j * stride])); } for (size_t step = 0; step < kNumStreamsLog2; ++step) { for (size_t j = 0; j < kNumStreamsHalf; ++j) { stage[step + 1U][j * 2] = - _mm_unpacklo_epi8(stage[step][j], stage[step][kNumStreamsHalf + j]); + mmUnpackloEpi8(stage[step][j], stage[step][kNumStreamsHalf + j]); stage[step + 1U][j * 2 + 1U] = - _mm_unpackhi_epi8(stage[step][j], stage[step][kNumStreamsHalf + j]); + mmUnpackhiEpi8(stage[step][j], stage[step][kNumStreamsHalf + j]); } } for (size_t j = 0; j < kNumStreams; ++j) { - _mm_storeu_si128( + mmStoreuSi128( reinterpret_cast<__m128i*>( - &output_data[(i * kNumStreams + j) * sizeof(__m128i)]), + &outputData[(i * kNumStreams + j) * sizeof(__m128i)]), stage[kNumStreamsLog2][j]); } } } template -void ByteStreamSplitEncodeSse2( - const uint8_t* raw_values, - const size_t num_values, - uint8_t* output_buffer_raw) { +void byteStreamSplitEncodeSse2( + const uint8_t* rawValues, + const size_t numValues, + uint8_t* outputBufferRaw) { constexpr size_t kNumStreams = sizeof(T); static_assert( kNumStreams == 4U || kNumStreams == 8U, "Invalid number of streams."); constexpr size_t kBlockSize = sizeof(__m128i) * kNumStreams; __m128i stage[3][kNumStreams]; - __m128i final_result[kNumStreams]; + __m128i finalResult[kNumStreams]; - const size_t size = num_values * sizeof(T); - const size_t num_blocks = size / kBlockSize; - const __m128i* raw_values_sse = reinterpret_cast(raw_values); - __m128i* output_buffer_streams[kNumStreams]; + const size_t size = numValues * sizeof(T); + const size_t numBlocks = size / kBlockSize; + const __m128i* rawValuesSse = reinterpret_cast(rawValues); + __m128i* outputBufferStreams[kNumStreams]; for (size_t i = 0; i < kNumStreams; ++i) { - output_buffer_streams[i] = - reinterpret_cast<__m128i*>(&output_buffer_raw[num_values * i]); + outputBufferStreams[i] = + reinterpret_cast<__m128i*>(&outputBufferRaw[numValues * i]); } // First handle suffix. - const size_t num_processed_elements = (num_blocks * kBlockSize) / sizeof(T); - for (size_t i = num_processed_elements; i < num_values; ++i) { + const size_t numProcessedElements = (numBlocks * kBlockSize) / sizeof(T); + for (size_t i = numProcessedElements; i < numValues; ++i) { for (size_t j = 0U; j < kNumStreams; ++j) { - const uint8_t byte_in_value = raw_values[i * kNumStreams + j]; - output_buffer_raw[j * num_values + i] = byte_in_value; + const uint8_t byteInValue = rawValues[i * kNumStreams + j]; + outputBufferRaw[j * numValues + i] = byteInValue; } } - // The current shuffling algorithm diverges for float and double types but the - // compiler should be able to remove the branch since only one path is taken - // for each template instantiation. Example run for floats: Step 0, copy: + // The current shuffling algorithm diverges for float and double types but + // the. Compiler should be able to remove the branch since only one path is + // taken. For each template instantiation. Example run for floats: Step 0, + // copy: // 0: ABCD ABCD ABCD ABCD 1: ABCD ABCD ABCD ABCD ... // Step 1: _mm_unpacklo_epi8 and mm_unpackhi_epi8: // 0: AABB CCDD AABB CCDD 1: AABB CCDD AABB CCDD ... @@ -139,51 +140,49 @@ void ByteStreamSplitEncodeSse2( // 0: AAAA AAAA BBBB BBBB 1: CCCC CCCC DDDD DDDD ... // Step 4: __mm_unpacklo_epi64 and _mm_unpackhi_epi64: // 0: AAAA AAAA AAAA AAAA 1: BBBB BBBB BBBB BBBB ... - for (size_t block_index = 0; block_index < num_blocks; ++block_index) { + for (size_t blockIndex = 0; blockIndex < numBlocks; ++blockIndex) { // First copy the data to stage 0. for (size_t i = 0; i < kNumStreams; ++i) { - stage[0][i] = - _mm_loadu_si128(&raw_values_sse[block_index * kNumStreams + i]); + stage[0][i] = mmLoaduSi128(&rawValuesSse[blockIndex * kNumStreams + i]); } // The shuffling of bytes is performed through the unpack intrinsics. // In my measurements this gives better performance then an implementation // which uses the shuffle intrinsics. - for (size_t stage_lvl = 0; stage_lvl < 2U; ++stage_lvl) { + for (size_t stageLvl = 0; stageLvl < 2U; ++stageLvl) { for (size_t i = 0; i < kNumStreams / 2U; ++i) { - stage[stage_lvl + 1][i * 2] = _mm_unpacklo_epi8( - stage[stage_lvl][i * 2], stage[stage_lvl][i * 2 + 1]); - stage[stage_lvl + 1][i * 2 + 1] = _mm_unpackhi_epi8( - stage[stage_lvl][i * 2], stage[stage_lvl][i * 2 + 1]); + stage[stageLvl + 1][i * 2] = + mmUnpackloEpi8(stage[stageLvl][i * 2], stage[stageLvl][i * 2 + 1]); + stage[stageLvl + 1][i * 2 + 1] = + mmUnpackhiEpi8(stage[stageLvl][i * 2], stage[stageLvl][i * 2 + 1]); } } if constexpr (kNumStreams == 8U) { // This is the path for double. __m128i tmp[8]; for (size_t i = 0; i < 4; ++i) { - tmp[i * 2] = _mm_unpacklo_epi32(stage[2][i], stage[2][i + 4]); - tmp[i * 2 + 1] = _mm_unpackhi_epi32(stage[2][i], stage[2][i + 4]); + tmp[i * 2] = mmUnpackloEpi32(stage[2][i], stage[2][i + 4]); + tmp[i * 2 + 1] = mmUnpackhiEpi32(stage[2][i], stage[2][i + 4]); } for (size_t i = 0; i < 4; ++i) { - final_result[i * 2] = _mm_unpacklo_epi32(tmp[i], tmp[i + 4]); - final_result[i * 2 + 1] = _mm_unpackhi_epi32(tmp[i], tmp[i + 4]); + finalResult[i * 2] = mmUnpackloEpi32(tmp[i], tmp[i + 4]); + finalResult[i * 2 + 1] = mmUnpackhiEpi32(tmp[i], tmp[i + 4]); } } else { - // this is the path for float. + // This is the path for float. __m128i tmp[4]; for (size_t i = 0; i < 2; ++i) { - tmp[i * 2] = _mm_unpacklo_epi8(stage[2][i * 2], stage[2][i * 2 + 1]); - tmp[i * 2 + 1] = - _mm_unpackhi_epi8(stage[2][i * 2], stage[2][i * 2 + 1]); + tmp[i * 2] = mmUnpackloEpi8(stage[2][i * 2], stage[2][i * 2 + 1]); + tmp[i * 2 + 1] = mmUnpackhiEpi8(stage[2][i * 2], stage[2][i * 2 + 1]); } for (size_t i = 0; i < 2; ++i) { - final_result[i * 2] = _mm_unpacklo_epi64(tmp[i], tmp[i + 2]); - final_result[i * 2 + 1] = _mm_unpackhi_epi64(tmp[i], tmp[i + 2]); + finalResult[i * 2] = mmUnpackloEpi64(tmp[i], tmp[i + 2]); + finalResult[i * 2 + 1] = mmUnpackhiEpi64(tmp[i], tmp[i + 2]); } } for (size_t i = 0; i < kNumStreams; ++i) { - _mm_storeu_si128(&output_buffer_streams[i][block_index], final_result[i]); + mmStoreuSi128(&outputBufferStreams[i][blockIndex], finalResult[i]); } } } @@ -191,9 +190,9 @@ void ByteStreamSplitEncodeSse2( #if defined(ARROW_HAVE_AVX2) template -void ByteStreamSplitDecodeAvx2( +void byteStreamSplitDecodeAvx2( const uint8_t* data, - int64_t num_values, + int64_t numValues, int64_t stride, T* out) { constexpr size_t kNumStreams = sizeof(T); @@ -202,119 +201,117 @@ void ByteStreamSplitDecodeAvx2( constexpr size_t kNumStreamsLog2 = (kNumStreams == 8U ? 3U : 2U); constexpr int64_t kBlockSize = sizeof(__m256i) * kNumStreams; - const int64_t size = num_values * sizeof(T); + const int64_t size = numValues * sizeof(T); if (size < kBlockSize) // Back to SSE for small size - return ByteStreamSplitDecodeSse2(data, num_values, stride, out); - const int64_t num_blocks = size / kBlockSize; - uint8_t* output_data = reinterpret_cast(out); + return byteStreamSplitDecodeSse2(data, numValues, stride, out); + const int64_t numBlocks = size / kBlockSize; + uint8_t* outputData = reinterpret_cast(out); // First handle suffix. - const int64_t num_processed_elements = - (num_blocks * kBlockSize) / kNumStreams; - for (int64_t i = num_processed_elements; i < num_values; ++i) { - uint8_t gathered_byte_data[kNumStreams]; + const int64_t numProcessedElements = (numBlocks * kBlockSize) / kNumStreams; + for (int64_t i = numProcessedElements; i < numValues; ++i) { + uint8_t gatheredByteData[kNumStreams]; for (size_t b = 0; b < kNumStreams; ++b) { - const size_t byte_index = b * stride + i; - gathered_byte_data[b] = data[byte_index]; + const size_t byteIndex = b * stride + i; + gatheredByteData[b] = data[byteIndex]; } - out[i] = arrow::util::SafeLoadAs(&gathered_byte_data[0]); + out[i] = arrow::util::SafeLoadAs(&gatheredByteData[0]); } // Processed hierarchically using unpack intrinsics, then permute intrinsics. __m256i stage[kNumStreamsLog2 + 1U][kNumStreams]; - __m256i final_result[kNumStreams]; + __m256i finalResult[kNumStreams]; constexpr size_t kNumStreamsHalf = kNumStreams / 2U; - for (int64_t i = 0; i < num_blocks; ++i) { + for (int64_t i = 0; i < numBlocks; ++i) { for (size_t j = 0; j < kNumStreams; ++j) { - stage[0][j] = _mm256_loadu_si256(reinterpret_cast( - &data[i * sizeof(__m256i) + j * stride])); + stage[0][j] = mm256LoaduSi256( + reinterpret_cast( + &data[i * sizeof(__m256i) + j * stride])); } for (size_t step = 0; step < kNumStreamsLog2; ++step) { for (size_t j = 0; j < kNumStreamsHalf; ++j) { - stage[step + 1U][j * 2] = _mm256_unpacklo_epi8( - stage[step][j], stage[step][kNumStreamsHalf + j]); - stage[step + 1U][j * 2 + 1U] = _mm256_unpackhi_epi8( - stage[step][j], stage[step][kNumStreamsHalf + j]); + stage[step + 1U][j * 2] = + mm256UnpackloEpi8(stage[step][j], stage[step][kNumStreamsHalf + j]); + stage[step + 1U][j * 2 + 1U] = + mm256UnpackhiEpi8(stage[step][j], stage[step][kNumStreamsHalf + j]); } } if constexpr (kNumStreams == 8U) { - // path for double, 128i index: - // {0x00, 0x08}, {0x01, 0x09}, {0x02, 0x0A}, {0x03, 0x0B}, - // {0x04, 0x0C}, {0x05, 0x0D}, {0x06, 0x0E}, {0x07, 0x0F}, - final_result[0] = _mm256_permute2x128_si256( + // Path for double, 128i index: + // {0X00, 0x08}, {0x01, 0x09}, {0x02, 0x0A}, {0x03, 0x0B}, + // {0X04, 0x0C}, {0x05, 0x0D}, {0x06, 0x0E}, {0x07, 0x0F}, + finalResult[0] = mm256Permute2x128Si256( stage[kNumStreamsLog2][0], stage[kNumStreamsLog2][1], 0b00100000); - final_result[1] = _mm256_permute2x128_si256( + finalResult[1] = mm256Permute2x128Si256( stage[kNumStreamsLog2][2], stage[kNumStreamsLog2][3], 0b00100000); - final_result[2] = _mm256_permute2x128_si256( + finalResult[2] = mm256Permute2x128Si256( stage[kNumStreamsLog2][4], stage[kNumStreamsLog2][5], 0b00100000); - final_result[3] = _mm256_permute2x128_si256( + finalResult[3] = mm256Permute2x128Si256( stage[kNumStreamsLog2][6], stage[kNumStreamsLog2][7], 0b00100000); - final_result[4] = _mm256_permute2x128_si256( + finalResult[4] = mm256Permute2x128Si256( stage[kNumStreamsLog2][0], stage[kNumStreamsLog2][1], 0b00110001); - final_result[5] = _mm256_permute2x128_si256( + finalResult[5] = mm256Permute2x128Si256( stage[kNumStreamsLog2][2], stage[kNumStreamsLog2][3], 0b00110001); - final_result[6] = _mm256_permute2x128_si256( + finalResult[6] = mm256Permute2x128Si256( stage[kNumStreamsLog2][4], stage[kNumStreamsLog2][5], 0b00110001); - final_result[7] = _mm256_permute2x128_si256( + finalResult[7] = mm256Permute2x128Si256( stage[kNumStreamsLog2][6], stage[kNumStreamsLog2][7], 0b00110001); } else { - // path for float, 128i index: + // Path for float, 128i index: // {0x00, 0x04}, {0x01, 0x05}, {0x02, 0x06}, {0x03, 0x07} - final_result[0] = _mm256_permute2x128_si256( + finalResult[0] = mm256Permute2x128Si256( stage[kNumStreamsLog2][0], stage[kNumStreamsLog2][1], 0b00100000); - final_result[1] = _mm256_permute2x128_si256( + finalResult[1] = mm256Permute2x128Si256( stage[kNumStreamsLog2][2], stage[kNumStreamsLog2][3], 0b00100000); - final_result[2] = _mm256_permute2x128_si256( + finalResult[2] = mm256Permute2x128Si256( stage[kNumStreamsLog2][0], stage[kNumStreamsLog2][1], 0b00110001); - final_result[3] = _mm256_permute2x128_si256( + finalResult[3] = mm256Permute2x128Si256( stage[kNumStreamsLog2][2], stage[kNumStreamsLog2][3], 0b00110001); } for (size_t j = 0; j < kNumStreams; ++j) { - _mm256_storeu_si256( + mm256StoreuSi256( reinterpret_cast<__m256i*>( - &output_data[(i * kNumStreams + j) * sizeof(__m256i)]), - final_result[j]); + &outputData[(i * kNumStreams + j) * sizeof(__m256i)]), + finalResult[j]); } } } template -void ByteStreamSplitEncodeAvx2( - const uint8_t* raw_values, - const size_t num_values, - uint8_t* output_buffer_raw) { +void byteStreamSplitEncodeAvx2( + const uint8_t* rawValues, + const size_t numValues, + uint8_t* outputBufferRaw) { constexpr size_t kNumStreams = sizeof(T); static_assert( kNumStreams == 4U || kNumStreams == 8U, "Invalid number of streams."); constexpr size_t kBlockSize = sizeof(__m256i) * kNumStreams; if constexpr (kNumStreams == 8U) // Back to SSE, currently no path for double. - return ByteStreamSplitEncodeSse2( - raw_values, num_values, output_buffer_raw); + return byteStreamSplitEncodeSse2(rawValues, numValues, outputBufferRaw); - const size_t size = num_values * sizeof(T); + const size_t size = numValues * sizeof(T); if (size < kBlockSize) // Back to SSE for small size - return ByteStreamSplitEncodeSse2( - raw_values, num_values, output_buffer_raw); - const size_t num_blocks = size / kBlockSize; - const __m256i* raw_values_simd = reinterpret_cast(raw_values); - __m256i* output_buffer_streams[kNumStreams]; + return byteStreamSplitEncodeSse2(rawValues, numValues, outputBufferRaw); + const size_t numBlocks = size / kBlockSize; + const __m256i* rawValuesSimd = reinterpret_cast(rawValues); + __m256i* outputBufferStreams[kNumStreams]; for (size_t i = 0; i < kNumStreams; ++i) { - output_buffer_streams[i] = - reinterpret_cast<__m256i*>(&output_buffer_raw[num_values * i]); + outputBufferStreams[i] = + reinterpret_cast<__m256i*>(&outputBufferRaw[numValues * i]); } // First handle suffix. - const size_t num_processed_elements = (num_blocks * kBlockSize) / sizeof(T); - for (size_t i = num_processed_elements; i < num_values; ++i) { + const size_t numProcessedElements = (numBlocks * kBlockSize) / sizeof(T); + for (size_t i = numProcessedElements; i < numValues; ++i) { for (size_t j = 0U; j < kNumStreams; ++j) { - const uint8_t byte_in_value = raw_values[i * kNumStreams + j]; - output_buffer_raw[j * num_values + i] = byte_in_value; + const uint8_t byteInValue = rawValues[i * kNumStreams + j]; + outputBufferRaw[j * numValues + i] = byteInValue; } } @@ -325,42 +322,36 @@ void ByteStreamSplitEncodeAvx2( constexpr size_t kNumUnpack = 3U; __m256i stage[kNumUnpack + 1][kNumStreams]; static const __m256i kPermuteMask = - _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); + mm256SetEpi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); __m256i permute[kNumStreams]; - __m256i final_result[kNumStreams]; + __m256i finalResult[kNumStreams]; - for (size_t block_index = 0; block_index < num_blocks; ++block_index) { + for (size_t blockIndex = 0; blockIndex < numBlocks; ++blockIndex) { for (size_t i = 0; i < kNumStreams; ++i) { stage[0][i] = - _mm256_loadu_si256(&raw_values_simd[block_index * kNumStreams + i]); + mm256LoaduSi256(&rawValuesSimd[blockIndex * kNumStreams + i]); } - for (size_t stage_lvl = 0; stage_lvl < kNumUnpack; ++stage_lvl) { + for (size_t stageLvl = 0; stageLvl < kNumUnpack; ++stageLvl) { for (size_t i = 0; i < kNumStreams / 2U; ++i) { - stage[stage_lvl + 1][i * 2] = _mm256_unpacklo_epi8( - stage[stage_lvl][i * 2], stage[stage_lvl][i * 2 + 1]); - stage[stage_lvl + 1][i * 2 + 1] = _mm256_unpackhi_epi8( - stage[stage_lvl][i * 2], stage[stage_lvl][i * 2 + 1]); + stage[stageLvl + 1][i * 2] = mm256UnpackloEpi8( + stage[stageLvl][i * 2], stage[stageLvl][i * 2 + 1]); + stage[stageLvl + 1][i * 2 + 1] = mm256UnpackhiEpi8( + stage[stageLvl][i * 2], stage[stageLvl][i * 2 + 1]); } } for (size_t i = 0; i < kNumStreams; ++i) { - permute[i] = - _mm256_permutevar8x32_epi32(stage[kNumUnpack][i], kPermuteMask); + permute[i] = mm256Permutevar8x32Epi32(stage[kNumUnpack][i], kPermuteMask); } - final_result[0] = - _mm256_permute2x128_si256(permute[0], permute[2], 0b00100000); - final_result[1] = - _mm256_permute2x128_si256(permute[0], permute[2], 0b00110001); - final_result[2] = - _mm256_permute2x128_si256(permute[1], permute[3], 0b00100000); - final_result[3] = - _mm256_permute2x128_si256(permute[1], permute[3], 0b00110001); + finalResult[0] = mm256Permute2x128Si256(permute[0], permute[2], 0b00100000); + finalResult[1] = mm256Permute2x128Si256(permute[0], permute[2], 0b00110001); + finalResult[2] = mm256Permute2x128Si256(permute[1], permute[3], 0b00100000); + finalResult[3] = mm256Permute2x128Si256(permute[1], permute[3], 0b00110001); for (size_t i = 0; i < kNumStreams; ++i) { - _mm256_storeu_si256( - &output_buffer_streams[i][block_index], final_result[i]); + mm256StoreuSi256(&outputBufferStreams[i][blockIndex], finalResult[i]); } } } @@ -368,9 +359,9 @@ void ByteStreamSplitEncodeAvx2( #if defined(ARROW_HAVE_AVX512) template -void ByteStreamSplitDecodeAvx512( +void byteStreamSplitDecodeAvx512( const uint8_t* data, - int64_t num_values, + int64_t numValues, int64_t stride, T* out) { constexpr size_t kNumStreams = sizeof(T); @@ -379,157 +370,144 @@ void ByteStreamSplitDecodeAvx512( constexpr size_t kNumStreamsLog2 = (kNumStreams == 8U ? 3U : 2U); constexpr int64_t kBlockSize = sizeof(__m512i) * kNumStreams; - const int64_t size = num_values * sizeof(T); + const int64_t size = numValues * sizeof(T); if (size < kBlockSize) // Back to AVX2 for small size - return ByteStreamSplitDecodeAvx2(data, num_values, stride, out); - const int64_t num_blocks = size / kBlockSize; - uint8_t* output_data = reinterpret_cast(out); + return byteStreamSplitDecodeAvx2(data, numValues, stride, out); + const int64_t numBlocks = size / kBlockSize; + uint8_t* outputData = reinterpret_cast(out); // First handle suffix. - const int64_t num_processed_elements = - (num_blocks * kBlockSize) / kNumStreams; - for (int64_t i = num_processed_elements; i < num_values; ++i) { - uint8_t gathered_byte_data[kNumStreams]; + const int64_t numProcessedElements = (numBlocks * kBlockSize) / kNumStreams; + for (int64_t i = numProcessedElements; i < numValues; ++i) { + uint8_t gatheredByteData[kNumStreams]; for (size_t b = 0; b < kNumStreams; ++b) { - const size_t byte_index = b * stride + i; - gathered_byte_data[b] = data[byte_index]; + const size_t byteIndex = b * stride + i; + gatheredByteData[b] = data[byteIndex]; } - out[i] = arrow::util::SafeLoadAs(&gathered_byte_data[0]); + out[i] = arrow::util::SafeLoadAs(&gatheredByteData[0]); } // Processed hierarchically using the unpack, then two shuffles. __m512i stage[kNumStreamsLog2 + 1U][kNumStreams]; __m512i shuffle[kNumStreams]; - __m512i final_result[kNumStreams]; + __m512i finalResult[kNumStreams]; constexpr size_t kNumStreamsHalf = kNumStreams / 2U; - for (int64_t i = 0; i < num_blocks; ++i) { + for (int64_t i = 0; i < numBlocks; ++i) { for (size_t j = 0; j < kNumStreams; ++j) { - stage[0][j] = _mm512_loadu_si512(reinterpret_cast( - &data[i * sizeof(__m512i) + j * stride])); + stage[0][j] = mm512LoaduSi512( + reinterpret_cast( + &data[i * sizeof(__m512i) + j * stride])); } for (size_t step = 0; step < kNumStreamsLog2; ++step) { for (size_t j = 0; j < kNumStreamsHalf; ++j) { - stage[step + 1U][j * 2] = _mm512_unpacklo_epi8( - stage[step][j], stage[step][kNumStreamsHalf + j]); - stage[step + 1U][j * 2 + 1U] = _mm512_unpackhi_epi8( - stage[step][j], stage[step][kNumStreamsHalf + j]); + stage[step + 1U][j * 2] = + mm512UnpackloEpi8(stage[step][j], stage[step][kNumStreamsHalf + j]); + stage[step + 1U][j * 2 + 1U] = + mm512UnpackhiEpi8(stage[step][j], stage[step][kNumStreamsHalf + j]); } } if constexpr (kNumStreams == 8U) { - // path for double, 128i index: - // {0x00, 0x04, 0x08, 0x0C}, {0x10, 0x14, 0x18, 0x1C}, - // {0x01, 0x05, 0x09, 0x0D}, {0x11, 0x15, 0x19, 0x1D}, - // {0x02, 0x06, 0x0A, 0x0E}, {0x12, 0x16, 0x1A, 0x1E}, - // {0x03, 0x07, 0x0B, 0x0F}, {0x13, 0x17, 0x1B, 0x1F}, - shuffle[0] = _mm512_shuffle_i32x4( + // Path for double, 128i index: + // {0X00, 0x04, 0x08, 0x0C}, {0x10, 0x14, 0x18, 0x1C}, + // {0X01, 0x05, 0x09, 0x0D}, {0x11, 0x15, 0x19, 0x1D}, + // {0X02, 0x06, 0x0A, 0x0E}, {0x12, 0x16, 0x1A, 0x1E}, + // {0X03, 0x07, 0x0B, 0x0F}, {0x13, 0x17, 0x1B, 0x1F}, + shuffle[0] = mm512ShuffleI32x4( stage[kNumStreamsLog2][0], stage[kNumStreamsLog2][1], 0b01000100); - shuffle[1] = _mm512_shuffle_i32x4( + shuffle[1] = mm512ShuffleI32x4( stage[kNumStreamsLog2][2], stage[kNumStreamsLog2][3], 0b01000100); - shuffle[2] = _mm512_shuffle_i32x4( + shuffle[2] = mm512ShuffleI32x4( stage[kNumStreamsLog2][4], stage[kNumStreamsLog2][5], 0b01000100); - shuffle[3] = _mm512_shuffle_i32x4( + shuffle[3] = mm512ShuffleI32x4( stage[kNumStreamsLog2][6], stage[kNumStreamsLog2][7], 0b01000100); - shuffle[4] = _mm512_shuffle_i32x4( + shuffle[4] = mm512ShuffleI32x4( stage[kNumStreamsLog2][0], stage[kNumStreamsLog2][1], 0b11101110); - shuffle[5] = _mm512_shuffle_i32x4( + shuffle[5] = mm512ShuffleI32x4( stage[kNumStreamsLog2][2], stage[kNumStreamsLog2][3], 0b11101110); - shuffle[6] = _mm512_shuffle_i32x4( + shuffle[6] = mm512ShuffleI32x4( stage[kNumStreamsLog2][4], stage[kNumStreamsLog2][5], 0b11101110); - shuffle[7] = _mm512_shuffle_i32x4( + shuffle[7] = mm512ShuffleI32x4( stage[kNumStreamsLog2][6], stage[kNumStreamsLog2][7], 0b11101110); - final_result[0] = - _mm512_shuffle_i32x4(shuffle[0], shuffle[1], 0b10001000); - final_result[1] = - _mm512_shuffle_i32x4(shuffle[2], shuffle[3], 0b10001000); - final_result[2] = - _mm512_shuffle_i32x4(shuffle[0], shuffle[1], 0b11011101); - final_result[3] = - _mm512_shuffle_i32x4(shuffle[2], shuffle[3], 0b11011101); - final_result[4] = - _mm512_shuffle_i32x4(shuffle[4], shuffle[5], 0b10001000); - final_result[5] = - _mm512_shuffle_i32x4(shuffle[6], shuffle[7], 0b10001000); - final_result[6] = - _mm512_shuffle_i32x4(shuffle[4], shuffle[5], 0b11011101); - final_result[7] = - _mm512_shuffle_i32x4(shuffle[6], shuffle[7], 0b11011101); + finalResult[0] = mm512ShuffleI32x4(shuffle[0], shuffle[1], 0b10001000); + finalResult[1] = mm512ShuffleI32x4(shuffle[2], shuffle[3], 0b10001000); + finalResult[2] = mm512ShuffleI32x4(shuffle[0], shuffle[1], 0b11011101); + finalResult[3] = mm512ShuffleI32x4(shuffle[2], shuffle[3], 0b11011101); + finalResult[4] = mm512ShuffleI32x4(shuffle[4], shuffle[5], 0b10001000); + finalResult[5] = mm512ShuffleI32x4(shuffle[6], shuffle[7], 0b10001000); + finalResult[6] = mm512ShuffleI32x4(shuffle[4], shuffle[5], 0b11011101); + finalResult[7] = mm512ShuffleI32x4(shuffle[6], shuffle[7], 0b11011101); } else { - // path for float, 128i index: + // Path for float, 128i index: // {0x00, 0x04, 0x08, 0x0C}, {0x01, 0x05, 0x09, 0x0D} - // {0x02, 0x06, 0x0A, 0x0E}, {0x03, 0x07, 0x0B, 0x0F}, - shuffle[0] = _mm512_shuffle_i32x4( + // {0X02, 0x06, 0x0A, 0x0E}, {0x03, 0x07, 0x0B, 0x0F}, + shuffle[0] = mm512ShuffleI32x4( stage[kNumStreamsLog2][0], stage[kNumStreamsLog2][1], 0b01000100); - shuffle[1] = _mm512_shuffle_i32x4( + shuffle[1] = mm512ShuffleI32x4( stage[kNumStreamsLog2][2], stage[kNumStreamsLog2][3], 0b01000100); - shuffle[2] = _mm512_shuffle_i32x4( + shuffle[2] = mm512ShuffleI32x4( stage[kNumStreamsLog2][0], stage[kNumStreamsLog2][1], 0b11101110); - shuffle[3] = _mm512_shuffle_i32x4( + shuffle[3] = mm512ShuffleI32x4( stage[kNumStreamsLog2][2], stage[kNumStreamsLog2][3], 0b11101110); - final_result[0] = - _mm512_shuffle_i32x4(shuffle[0], shuffle[1], 0b10001000); - final_result[1] = - _mm512_shuffle_i32x4(shuffle[0], shuffle[1], 0b11011101); - final_result[2] = - _mm512_shuffle_i32x4(shuffle[2], shuffle[3], 0b10001000); - final_result[3] = - _mm512_shuffle_i32x4(shuffle[2], shuffle[3], 0b11011101); + finalResult[0] = mm512ShuffleI32x4(shuffle[0], shuffle[1], 0b10001000); + finalResult[1] = mm512ShuffleI32x4(shuffle[0], shuffle[1], 0b11011101); + finalResult[2] = mm512ShuffleI32x4(shuffle[2], shuffle[3], 0b10001000); + finalResult[3] = mm512ShuffleI32x4(shuffle[2], shuffle[3], 0b11011101); } for (size_t j = 0; j < kNumStreams; ++j) { - _mm512_storeu_si512( + mm512StoreuSi512( reinterpret_cast<__m512i*>( - &output_data[(i * kNumStreams + j) * sizeof(__m512i)]), - final_result[j]); + &outputData[(i * kNumStreams + j) * sizeof(__m512i)]), + finalResult[j]); } } } template -void ByteStreamSplitEncodeAvx512( - const uint8_t* raw_values, - const size_t num_values, - uint8_t* output_buffer_raw) { +void byteStreamSplitEncodeAvx512( + const uint8_t* rawValues, + const size_t numValues, + uint8_t* outputBufferRaw) { constexpr size_t kNumStreams = sizeof(T); static_assert( kNumStreams == 4U || kNumStreams == 8U, "Invalid number of streams."); constexpr size_t kBlockSize = sizeof(__m512i) * kNumStreams; - const size_t size = num_values * sizeof(T); + const size_t size = numValues * sizeof(T); if (size < kBlockSize) // Back to AVX2 for small size - return ByteStreamSplitEncodeAvx2( - raw_values, num_values, output_buffer_raw); + return byteStreamSplitEncodeAvx2(rawValues, numValues, outputBufferRaw); - const size_t num_blocks = size / kBlockSize; - const __m512i* raw_values_simd = reinterpret_cast(raw_values); - __m512i* output_buffer_streams[kNumStreams]; + const size_t numBlocks = size / kBlockSize; + const __m512i* rawValuesSimd = reinterpret_cast(rawValues); + __m512i* outputBufferStreams[kNumStreams]; for (size_t i = 0; i < kNumStreams; ++i) { - output_buffer_streams[i] = - reinterpret_cast<__m512i*>(&output_buffer_raw[num_values * i]); + outputBufferStreams[i] = + reinterpret_cast<__m512i*>(&outputBufferRaw[numValues * i]); } // First handle suffix. - const size_t num_processed_elements = (num_blocks * kBlockSize) / sizeof(T); - for (size_t i = num_processed_elements; i < num_values; ++i) { + const size_t numProcessedElements = (numBlocks * kBlockSize) / sizeof(T); + for (size_t i = numProcessedElements; i < numValues; ++i) { for (size_t j = 0U; j < kNumStreams; ++j) { - const uint8_t byte_in_value = raw_values[i * kNumStreams + j]; - output_buffer_raw[j * num_values + i] = byte_in_value; + const uint8_t byteInValue = rawValues[i * kNumStreams + j]; + outputBufferRaw[j * numValues + i] = byteInValue; } } constexpr size_t KNumUnpack = (kNumStreams == 8U) ? 2U : 3U; - __m512i final_result[kNumStreams]; + __m512i finalResult[kNumStreams]; __m512i unpack[KNumUnpack + 1][kNumStreams]; __m512i permutex[kNumStreams]; - __m512i permutex_mask; + __m512i permutexMask; if constexpr (kNumStreams == 8U) { - // use _mm512_set_epi32, no _mm512_set_epi16 for some old gcc version. - permutex_mask = _mm512_set_epi32( + // Use _mm512_set_epi32, no _mm512_set_epi16 for some old gcc version. + permutexMask = mm512SetEpi32( 0x001F0017, 0x000F0007, 0x001E0016, @@ -547,7 +525,7 @@ void ByteStreamSplitEncodeAvx512( 0x00180010, 0x00080000); } else { - permutex_mask = _mm512_set_epi32( + permutexMask = mm512SetEpi32( 0x0F, 0x0B, 0x07, @@ -566,60 +544,52 @@ void ByteStreamSplitEncodeAvx512( 0x00); } - for (size_t block_index = 0; block_index < num_blocks; ++block_index) { + for (size_t blockIndex = 0; blockIndex < numBlocks; ++blockIndex) { for (size_t i = 0; i < kNumStreams; ++i) { unpack[0][i] = - _mm512_loadu_si512(&raw_values_simd[block_index * kNumStreams + i]); + mm512LoaduSi512(&rawValuesSimd[blockIndex * kNumStreams + i]); } - for (size_t unpack_lvl = 0; unpack_lvl < KNumUnpack; ++unpack_lvl) { + for (size_t unpackLvl = 0; unpackLvl < KNumUnpack; ++unpackLvl) { for (size_t i = 0; i < kNumStreams / 2U; ++i) { - unpack[unpack_lvl + 1][i * 2] = _mm512_unpacklo_epi8( - unpack[unpack_lvl][i * 2], unpack[unpack_lvl][i * 2 + 1]); - unpack[unpack_lvl + 1][i * 2 + 1] = _mm512_unpackhi_epi8( - unpack[unpack_lvl][i * 2], unpack[unpack_lvl][i * 2 + 1]); + unpack[unpackLvl + 1][i * 2] = mm512UnpackloEpi8( + unpack[unpackLvl][i * 2], unpack[unpackLvl][i * 2 + 1]); + unpack[unpackLvl + 1][i * 2 + 1] = mm512UnpackhiEpi8( + unpack[unpackLvl][i * 2], unpack[unpackLvl][i * 2 + 1]); } } if constexpr (kNumStreams == 8U) { - // path for double - // 1. unpack to epi16 block - // 2. permutexvar_epi16 to 128i block - // 3. shuffle 128i to final 512i target, index: - // {0x00, 0x04, 0x08, 0x0C}, {0x10, 0x14, 0x18, 0x1C}, - // {0x01, 0x05, 0x09, 0x0D}, {0x11, 0x15, 0x19, 0x1D}, - // {0x02, 0x06, 0x0A, 0x0E}, {0x12, 0x16, 0x1A, 0x1E}, - // {0x03, 0x07, 0x0B, 0x0F}, {0x13, 0x17, 0x1B, 0x1F}, + // Path for double. + // 1. Unpack to epi16 block. + // 2. Permutexvar_epi16 to 128i block. + // 3. Shuffle 128i to final 512i target, index: + // {0X00, 0x04, 0x08, 0x0C}, {0x10, 0x14, 0x18, 0x1C}, + // {0X01, 0x05, 0x09, 0x0D}, {0x11, 0x15, 0x19, 0x1D}, + // {0X02, 0x06, 0x0A, 0x0E}, {0x12, 0x16, 0x1A, 0x1E}, + // {0X03, 0x07, 0x0B, 0x0F}, {0x13, 0x17, 0x1B, 0x1F}, for (size_t i = 0; i < kNumStreams; ++i) permutex[i] = - _mm512_permutexvar_epi16(permutex_mask, unpack[KNumUnpack][i]); + mm512PermutexvarEpi16(permutexMask, unpack[KNumUnpack][i]); __m512i shuffle[kNumStreams]; - shuffle[0] = _mm512_shuffle_i32x4(permutex[0], permutex[2], 0b01000100); - shuffle[1] = _mm512_shuffle_i32x4(permutex[4], permutex[6], 0b01000100); - shuffle[2] = _mm512_shuffle_i32x4(permutex[0], permutex[2], 0b11101110); - shuffle[3] = _mm512_shuffle_i32x4(permutex[4], permutex[6], 0b11101110); - shuffle[4] = _mm512_shuffle_i32x4(permutex[1], permutex[3], 0b01000100); - shuffle[5] = _mm512_shuffle_i32x4(permutex[5], permutex[7], 0b01000100); - shuffle[6] = _mm512_shuffle_i32x4(permutex[1], permutex[3], 0b11101110); - shuffle[7] = _mm512_shuffle_i32x4(permutex[5], permutex[7], 0b11101110); - - final_result[0] = - _mm512_shuffle_i32x4(shuffle[0], shuffle[1], 0b10001000); - final_result[1] = - _mm512_shuffle_i32x4(shuffle[0], shuffle[1], 0b11011101); - final_result[2] = - _mm512_shuffle_i32x4(shuffle[2], shuffle[3], 0b10001000); - final_result[3] = - _mm512_shuffle_i32x4(shuffle[2], shuffle[3], 0b11011101); - final_result[4] = - _mm512_shuffle_i32x4(shuffle[4], shuffle[5], 0b10001000); - final_result[5] = - _mm512_shuffle_i32x4(shuffle[4], shuffle[5], 0b11011101); - final_result[6] = - _mm512_shuffle_i32x4(shuffle[6], shuffle[7], 0b10001000); - final_result[7] = - _mm512_shuffle_i32x4(shuffle[6], shuffle[7], 0b11011101); + shuffle[0] = mm512ShuffleI32x4(permutex[0], permutex[2], 0b01000100); + shuffle[1] = mm512ShuffleI32x4(permutex[4], permutex[6], 0b01000100); + shuffle[2] = mm512ShuffleI32x4(permutex[0], permutex[2], 0b11101110); + shuffle[3] = mm512ShuffleI32x4(permutex[4], permutex[6], 0b11101110); + shuffle[4] = mm512ShuffleI32x4(permutex[1], permutex[3], 0b01000100); + shuffle[5] = mm512ShuffleI32x4(permutex[5], permutex[7], 0b01000100); + shuffle[6] = mm512ShuffleI32x4(permutex[1], permutex[3], 0b11101110); + shuffle[7] = mm512ShuffleI32x4(permutex[5], permutex[7], 0b11101110); + + finalResult[0] = mm512ShuffleI32x4(shuffle[0], shuffle[1], 0b10001000); + finalResult[1] = mm512ShuffleI32x4(shuffle[0], shuffle[1], 0b11011101); + finalResult[2] = mm512ShuffleI32x4(shuffle[2], shuffle[3], 0b10001000); + finalResult[3] = mm512ShuffleI32x4(shuffle[2], shuffle[3], 0b11011101); + finalResult[4] = mm512ShuffleI32x4(shuffle[4], shuffle[5], 0b10001000); + finalResult[5] = mm512ShuffleI32x4(shuffle[4], shuffle[5], 0b11011101); + finalResult[6] = mm512ShuffleI32x4(shuffle[6], shuffle[7], 0b10001000); + finalResult[7] = mm512ShuffleI32x4(shuffle[6], shuffle[7], 0b11011101); } else { // Path for float. // 1. Processed hierarchically to 32i block using the unpack intrinsics. @@ -627,21 +597,16 @@ void ByteStreamSplitEncodeAvx512( // 3. Pack final 256i block with _mm256_permute2x128_si256. for (size_t i = 0; i < kNumStreams; ++i) permutex[i] = - _mm512_permutexvar_epi32(permutex_mask, unpack[KNumUnpack][i]); + mm512PermutexvarEpi32(permutexMask, unpack[KNumUnpack][i]); - final_result[0] = - _mm512_shuffle_i32x4(permutex[0], permutex[2], 0b01000100); - final_result[1] = - _mm512_shuffle_i32x4(permutex[0], permutex[2], 0b11101110); - final_result[2] = - _mm512_shuffle_i32x4(permutex[1], permutex[3], 0b01000100); - final_result[3] = - _mm512_shuffle_i32x4(permutex[1], permutex[3], 0b11101110); + finalResult[0] = mm512ShuffleI32x4(permutex[0], permutex[2], 0b01000100); + finalResult[1] = mm512ShuffleI32x4(permutex[0], permutex[2], 0b11101110); + finalResult[2] = mm512ShuffleI32x4(permutex[1], permutex[3], 0b01000100); + finalResult[3] = mm512ShuffleI32x4(permutex[1], permutex[3], 0b11101110); } for (size_t i = 0; i < kNumStreams; ++i) { - _mm512_storeu_si512( - &output_buffer_streams[i][block_index], final_result[i]); + mm512StoreuSi512(&outputBufferStreams[i][blockIndex], finalResult[i]); } } } @@ -649,36 +614,36 @@ void ByteStreamSplitEncodeAvx512( #if defined(ARROW_HAVE_SIMD_SPLIT) template -void inline ByteStreamSplitDecodeSimd( +void inline byteStreamSplitDecodeSimd( const uint8_t* data, - int64_t num_values, + int64_t numValues, int64_t stride, T* out) { #if defined(ARROW_HAVE_AVX512) - return ByteStreamSplitDecodeAvx512(data, num_values, stride, out); + return byteStreamSplitDecodeAvx512(data, numValues, stride, out); #elif defined(ARROW_HAVE_AVX2) - return ByteStreamSplitDecodeAvx2(data, num_values, stride, out); + return byteStreamSplitDecodeAvx2(data, numValues, stride, out); #elif defined(ARROW_HAVE_SSE4_2) - return ByteStreamSplitDecodeSse2(data, num_values, stride, out); + return byteStreamSplitDecodeSse2(data, numValues, stride, out); #else #error "ByteStreamSplitDecodeSimd not implemented" #endif } template -void inline ByteStreamSplitEncodeSimd( - const uint8_t* raw_values, - const int64_t num_values, - uint8_t* output_buffer_raw) { +void inline byteStreamSplitEncodeSimd( + const uint8_t* rawValues, + const int64_t numValues, + uint8_t* outputBufferRaw) { #if defined(ARROW_HAVE_AVX512) - return ByteStreamSplitEncodeAvx512( - raw_values, static_cast(num_values), output_buffer_raw); + return byteStreamSplitEncodeAvx512( + rawValues, static_cast(numValues), outputBufferRaw); #elif defined(ARROW_HAVE_AVX2) - return ByteStreamSplitEncodeAvx2( - raw_values, static_cast(num_values), output_buffer_raw); + return byteStreamSplitEncodeAvx2( + rawValues, static_cast(numValues), outputBufferRaw); #elif defined(ARROW_HAVE_SSE4_2) - return ByteStreamSplitEncodeSse2( - raw_values, static_cast(num_values), output_buffer_raw); + return byteStreamSplitEncodeSse2( + rawValues, static_cast(numValues), outputBufferRaw); #else #error "ByteStreamSplitEncodeSimd not implemented" #endif @@ -686,21 +651,21 @@ void inline ByteStreamSplitEncodeSimd( #endif // -// Scalar implementations +// Scalar implementations. // -inline void DoSplitStreams( +inline void doSplitStreams( const uint8_t* src, int width, int64_t nvalues, - uint8_t** dest_streams) { + uint8_t** destStreams) { // Value empirically chosen to provide the best performance on the author's - // machine + // machine. constexpr int kBlockSize = 32; while (nvalues >= kBlockSize) { for (int stream = 0; stream < width; ++stream) { - uint8_t* dest = dest_streams[stream]; + uint8_t* dest = destStreams[stream]; for (int i = 0; i < kBlockSize; i += 8) { uint64_t a = src[stream + i * width]; uint64_t b = src[stream + (i + 1) * width]; @@ -719,35 +684,35 @@ inline void DoSplitStreams( #endif ::arrow::util::SafeStore(&dest[i], r); } - dest_streams[stream] += kBlockSize; + destStreams[stream] += kBlockSize; } src += width * kBlockSize; nvalues -= kBlockSize; } - // Epilog + // Epilog. for (int stream = 0; stream < width; ++stream) { - uint8_t* dest = dest_streams[stream]; + uint8_t* dest = destStreams[stream]; for (int64_t i = 0; i < nvalues; ++i) { dest[i] = src[stream + i * width]; } } } -inline void DoMergeStreams( - const uint8_t** src_streams, +inline void doMergeStreams( + const uint8_t** srcStreams, int width, int64_t nvalues, uint8_t* dest) { // Value empirically chosen to provide the best performance on the author's - // machine + // machine. constexpr int kBlockSize = 128; while (nvalues >= kBlockSize) { for (int stream = 0; stream < width; ++stream) { - // Take kBlockSize bytes from the given stream and spread them - // to their logical places in destination. - const uint8_t* src = src_streams[stream]; + // Take kBlockSize bytes from the given stream and spread them. + // To their logical places in destination. + const uint8_t* src = srcStreams[stream]; for (int i = 0; i < kBlockSize; i += 8) { uint64_t v = ::arrow::util::SafeLoadAs(&src[i]); #if ARROW_LITTLE_ENDIAN @@ -770,15 +735,15 @@ inline void DoMergeStreams( dest[stream + (i + 7) * width] = static_cast(v); #endif } - src_streams[stream] += kBlockSize; + srcStreams[stream] += kBlockSize; } dest += width * kBlockSize; nvalues -= kBlockSize; } - // Epilog + // Epilog. for (int stream = 0; stream < width; ++stream) { - const uint8_t* src = src_streams[stream]; + const uint8_t* src = srcStreams[stream]; for (int64_t i = 0; i < nvalues; ++i) { dest[stream + i * width] = src[i]; } @@ -786,60 +751,58 @@ inline void DoMergeStreams( } template -void ByteStreamSplitEncodeScalar( - const uint8_t* raw_values, - const int64_t num_values, - uint8_t* output_buffer_raw) { +void byteStreamSplitEncodeScalar( + const uint8_t* rawValues, + const int64_t numValues, + uint8_t* outputBufferRaw) { constexpr int kNumStreams = static_cast(sizeof(T)); - std::array dest_streams; + std::array destStreams; for (int stream = 0; stream < kNumStreams; ++stream) { - dest_streams[stream] = &output_buffer_raw[stream * num_values]; + destStreams[stream] = &outputBufferRaw[stream * numValues]; } - DoSplitStreams(raw_values, kNumStreams, num_values, dest_streams.data()); + doSplitStreams(rawValues, kNumStreams, numValues, destStreams.data()); } template -void ByteStreamSplitDecodeScalar( +void byteStreamSplitDecodeScalar( const uint8_t* data, - int64_t num_values, + int64_t numValues, int64_t stride, T* out) { constexpr int kNumStreams = static_cast(sizeof(T)); - std::array src_streams; + std::array srcStreams; for (int stream = 0; stream < kNumStreams; ++stream) { - src_streams[stream] = &data[stream * stride]; + srcStreams[stream] = &data[stream * stride]; } - DoMergeStreams( - src_streams.data(), + doMergeStreams( + srcStreams.data(), kNumStreams, - num_values, + numValues, reinterpret_cast(out)); } template -void inline ByteStreamSplitEncode( - const uint8_t* raw_values, - const int64_t num_values, - uint8_t* output_buffer_raw) { +void inline byteStreamSplitEncode( + const uint8_t* rawValues, + const int64_t numValues, + uint8_t* outputBufferRaw) { #if defined(ARROW_HAVE_SIMD_SPLIT) - return ByteStreamSplitEncodeSimd( - raw_values, num_values, output_buffer_raw); + return byteStreamSplitEncodeSimd(rawValues, numValues, outputBufferRaw); #else - return ByteStreamSplitEncodeScalar( - raw_values, num_values, output_buffer_raw); + return byteStreamSplitEncodeScalar(rawValues, numValues, outputBufferRaw); #endif } template -void inline ByteStreamSplitDecode( +void inline byteStreamSplitDecode( const uint8_t* data, - int64_t num_values, + int64_t numValues, int64_t stride, T* out) { #if defined(ARROW_HAVE_SIMD_SPLIT) - return ByteStreamSplitDecodeSimd(data, num_values, stride, out); + return byteStreamSplitDecodeSimd(data, numValues, stride, out); #else - return ByteStreamSplitDecodeScalar(data, num_values, stride, out); + return byteStreamSplitDecodeScalar(data, numValues, stride, out); #endif } diff --git a/velox/dwio/parquet/writer/arrow/util/CMakeLists.txt b/velox/dwio/parquet/writer/arrow/util/CMakeLists.txt index 2aff416fcdb..ce41ffe680b 100644 --- a/velox/dwio/parquet/writer/arrow/util/CMakeLists.txt +++ b/velox/dwio/parquet/writer/arrow/util/CMakeLists.txt @@ -21,6 +21,15 @@ velox_add_library( CompressionLZ4.cpp Hashing.cpp Crc32.cpp + HEADERS + ByteStreamSplitInternal.h + Compression.h + CompressionInternal.h + Crc32.h + Hashing.h + OverflowUtilInternal.h + VisitArrayInline.h + safe-math.h ) velox_link_libraries( diff --git a/velox/dwio/parquet/writer/arrow/util/Compression.cpp b/velox/dwio/parquet/writer/arrow/util/Compression.cpp index f5f77313d10..a51c29b917a 100644 --- a/velox/dwio/parquet/writer/arrow/util/Compression.cpp +++ b/velox/dwio/parquet/writer/arrow/util/Compression.cpp @@ -24,15 +24,14 @@ #include "arrow/result.h" #include "arrow/status.h" -#include "arrow/util/logging.h" #include "velox/dwio/parquet/writer/arrow/util/CompressionInternal.h" namespace facebook::velox::parquet::arrow::util { namespace { -Status CheckSupportsCompressionLevel(Compression::type type) { - if (!Codec::SupportsCompressionLevel(type)) { +Status checkSupportsCompressionLevel(Compression::type type) { + if (!Codec::supportsCompressionLevel(type)) { return Status::Invalid( "The specified codec does not support the compression level parameter"); } @@ -41,20 +40,20 @@ Status CheckSupportsCompressionLevel(Compression::type type) { } // namespace -int Codec::UseDefaultCompressionLevel() { +int Codec::useDefaultCompressionLevel() { return kUseDefaultCompressionLevel; } -Status Codec::Init() { +Status Codec::init() { return Status::OK(); } -const std::string& Codec::GetCodecAsString(Compression::type t) { +const std::string& Codec::getCodecAsString(Compression::type t) { static const std::string uncompressed = "uncompressed", snappy = "snappy", gzip = "gzip", lzo = "lzo", brotli = "brotli", - lz4_raw = "lz4_raw", lz4 = "lz4", - lz4_hadoop = "lz4_hadoop", zstd = "zstd", - bz2 = "bz2", unknown = "unknown"; + lz4Raw = "lz4_raw", lz4 = "lz4", + lz4Hadoop = "lz4_hadoop", zstd = "zstd", bz2 = "bz2", + unknown = "unknown"; switch (t) { case Compression::UNCOMPRESSED: @@ -68,11 +67,11 @@ const std::string& Codec::GetCodecAsString(Compression::type t) { case Compression::BROTLI: return brotli; case Compression::LZ4: - return lz4_raw; + return lz4Raw; case Compression::LZ4_FRAME: return lz4; case Compression::LZ4_HADOOP: - return lz4_hadoop; + return lz4Hadoop; case Compression::ZSTD: return zstd; case Compression::BZ2: @@ -82,7 +81,7 @@ const std::string& Codec::GetCodecAsString(Compression::type t) { } } -Result Codec::GetCompressionType(const std::string& name) { +Result Codec::getCompressionType(const std::string& name) { if (name == "uncompressed") { return Compression::UNCOMPRESSED; } else if (name == "gzip") { @@ -108,7 +107,7 @@ Result Codec::GetCompressionType(const std::string& name) { } } -bool Codec::SupportsCompressionLevel(Compression::type codec) { +bool Codec::supportsCompressionLevel(Compression::type codec) { switch (codec) { case Compression::GZIP: case Compression::BROTLI: @@ -122,88 +121,88 @@ bool Codec::SupportsCompressionLevel(Compression::type codec) { } } -Result Codec::MaximumCompressionLevel(Compression::type codec_type) { - RETURN_NOT_OK(CheckSupportsCompressionLevel(codec_type)); - ARROW_ASSIGN_OR_RAISE(auto codec, Codec::Create(codec_type)); - return codec->maximum_compression_level(); +Result Codec::maximumCompressionLevel(Compression::type codecType) { + RETURN_NOT_OK(checkSupportsCompressionLevel(codecType)); + ARROW_ASSIGN_OR_RAISE(auto codec, Codec::create(codecType)); + return codec->maximumCompressionLevel(); } -Result Codec::MinimumCompressionLevel(Compression::type codec_type) { - RETURN_NOT_OK(CheckSupportsCompressionLevel(codec_type)); - ARROW_ASSIGN_OR_RAISE(auto codec, Codec::Create(codec_type)); - return codec->minimum_compression_level(); +Result Codec::minimumCompressionLevel(Compression::type codecType) { + RETURN_NOT_OK(checkSupportsCompressionLevel(codecType)); + ARROW_ASSIGN_OR_RAISE(auto codec, Codec::create(codecType)); + return codec->minimumCompressionLevel(); } -Result Codec::DefaultCompressionLevel(Compression::type codec_type) { - RETURN_NOT_OK(CheckSupportsCompressionLevel(codec_type)); - ARROW_ASSIGN_OR_RAISE(auto codec, Codec::Create(codec_type)); - return codec->default_compression_level(); +Result Codec::defaultCompressionLevel(Compression::type codecType) { + RETURN_NOT_OK(checkSupportsCompressionLevel(codecType)); + ARROW_ASSIGN_OR_RAISE(auto codec, Codec::create(codecType)); + return codec->defaultCompressionLevel(); } -Result> Codec::Create( - Compression::type codec_type, - const CodecOptions& codec_options) { - if (!IsAvailable(codec_type)) { - if (codec_type == Compression::LZO) { +Result> Codec::create( + Compression::type codecType, + const CodecOptions& codecOptions) { + if (!isAvailable(codecType)) { + if (codecType == Compression::LZO) { return Status::NotImplemented("LZO codec not implemented"); } - auto name = GetCodecAsString(codec_type); + auto name = getCodecAsString(codecType); if (name == "unknown") { return Status::Invalid("Unrecognized codec"); } return Status::NotImplemented( - "Support for codec '", GetCodecAsString(codec_type), "' not built"); + "Support for codec '", getCodecAsString(codecType), "' not built"); } - auto compression_level = codec_options.compression_level; - if (compression_level != kUseDefaultCompressionLevel && - !SupportsCompressionLevel(codec_type)) { + auto compressionLevel = codecOptions.compressionLevel; + if (compressionLevel != kUseDefaultCompressionLevel && + !supportsCompressionLevel(codecType)) { return Status::Invalid( "Codec '", - GetCodecAsString(codec_type), + getCodecAsString(codecType), "' doesn't support setting a compression level."); } std::unique_ptr codec; - switch (codec_type) { + switch (codecType) { case Compression::UNCOMPRESSED: return nullptr; case Compression::SNAPPY: - codec = internal::MakeSnappyCodec(); + codec = internal::makeSnappyCodec(); break; case Compression::GZIP: { - auto opt = dynamic_cast(&codec_options); - codec = internal::MakeGZipCodec( - compression_level, - opt ? opt->gzip_format : GZipFormat::GZIP, - opt ? opt->window_bits : std::nullopt); + auto opt = dynamic_cast(&codecOptions); + codec = internal::makeGZipCodec( + compressionLevel, + opt ? opt->gzipFormat : GZipFormat::GZIP, + opt ? opt->windowBits : std::nullopt); break; } case Compression::BROTLI: { #ifdef ARROW_WITH_BROTLI - auto opt = dynamic_cast(&codec_options); - codec = internal::MakeBrotliCodec( - compression_level, opt ? opt->window_bits : std::nullopt); + auto opt = dynamic_cast(&codecOptions); + codec = internal::makeBrotliCodec( + compressionLevel, opt ? opt->windowBits : std::nullopt); #endif break; } case Compression::LZ4: - codec = internal::MakeLz4RawCodec(compression_level); + codec = internal::makeLz4RawCodec(compressionLevel); break; case Compression::LZ4_FRAME: - codec = internal::MakeLz4FrameCodec(compression_level); + codec = internal::makeLz4FrameCodec(compressionLevel); break; case Compression::LZ4_HADOOP: - codec = internal::MakeLz4HadoopRawCodec(); + codec = internal::makeLz4HadoopRawCodec(); break; case Compression::ZSTD: - codec = internal::MakeZSTDCodec(compression_level); + codec = internal::makeZSTDCodec(compressionLevel); break; case Compression::BZ2: #ifdef ARROW_WITH_BZ2 - codec = internal::MakeBZ2Codec(compression_level); + codec = internal::makeBZ2Codec(compressionLevel); #endif break; default: @@ -214,19 +213,19 @@ Result> Codec::Create( return Status::NotImplemented("LZO codec not implemented"); } - RETURN_NOT_OK(codec->Init()); + RETURN_NOT_OK(codec->init()); return std::move(codec); } -// use compression level to create Codec -Result> Codec::Create( - Compression::type codec_type, - int compression_level) { - return Codec::Create(codec_type, CodecOptions{compression_level}); +// Use compression level to create Codec. +Result> Codec::create( + Compression::type codecType, + int compressionLevel) { + return Codec::create(codecType, CodecOptions{compressionLevel}); } -bool Codec::IsAvailable(Compression::type codec_type) { - switch (codec_type) { +bool Codec::isAvailable(Compression::type codecType) { + switch (codecType) { case Compression::UNCOMPRESSED: case Compression::SNAPPY: case Compression::GZIP: diff --git a/velox/dwio/parquet/writer/arrow/util/Compression.h b/velox/dwio/parquet/writer/arrow/util/Compression.h index 7432008fd20..be92fddeebc 100644 --- a/velox/dwio/parquet/writer/arrow/util/Compression.h +++ b/velox/dwio/parquet/writer/arrow/util/Compression.h @@ -31,7 +31,7 @@ namespace facebook::velox::parquet::arrow { struct Compression { - /// \brief Compression algorithm + /// \brief Compression algorithm. enum type { UNCOMPRESSED, SNAPPY, @@ -53,54 +53,53 @@ using namespace ::arrow; constexpr int kUseDefaultCompressionLevel = std::numeric_limits::min(); -/// \brief Streaming compressor interface +/// \brief Streaming compressor interface. /// class ARROW_EXPORT Compressor { public: virtual ~Compressor() = default; struct CompressResult { - int64_t bytes_read; - int64_t bytes_written; + int64_t bytesRead; + int64_t bytesWritten; }; struct FlushResult { - int64_t bytes_written; - bool should_retry; + int64_t bytesWritten; + bool shouldRetry; }; struct EndResult { - int64_t bytes_written; - bool should_retry; + int64_t bytesWritten; + bool shouldRetry; }; /// \brief Compress some input. /// /// If bytes_read is 0 on return, then a larger output buffer should be /// supplied. - virtual Result Compress( - int64_t input_len, + virtual Result compress( + int64_t inputLen, const uint8_t* input, - int64_t output_len, + int64_t outputLen, uint8_t* output) = 0; /// \brief Flush part of the compressed output. /// - /// If should_retry is true on return, Flush() should be called again - /// with a larger buffer. - virtual Result Flush(int64_t output_len, uint8_t* output) = 0; + /// If should_retry is true on return, Flush() should be called again with a + /// larger buffer. + virtual Result flush(int64_t outputLen, uint8_t* output) = 0; /// \brief End compressing, doing whatever is necessary to end the stream. /// - /// If should_retry is true on return, End() should be called again - /// with a larger buffer. Otherwise, the Compressor should not be used - /// anymore. + /// If should_retry is true on return, End() should be called again with a + /// larger buffer. Otherwise, the Compressor should not be used anymore. /// /// End() implies Flush(). - virtual Result End(int64_t output_len, uint8_t* output) = 0; + virtual Result end(int64_t outputLen, uint8_t* output) = 0; // XXX add methods for buffer size heuristics? }; -/// \brief Streaming decompressor interface +/// \brief Streaming decompressor interface. /// class ARROW_EXPORT Decompressor { public: @@ -108,49 +107,48 @@ class ARROW_EXPORT Decompressor { struct DecompressResult { // XXX is need_more_output necessary? (Brotli?) - int64_t bytes_read; - int64_t bytes_written; - bool need_more_output; + int64_t bytesRead; + int64_t bytesWritten; + bool needMoreOutput; }; /// \brief Decompress some input. /// - /// If need_more_output is true on return, a larger output buffer needs - /// to be supplied. - virtual Result Decompress( - int64_t input_len, + /// If need_more_output is true on return, a larger output buffer needs to be + /// supplied. + virtual Result decompress( + int64_t inputLen, const uint8_t* input, - int64_t output_len, + int64_t outputLen, uint8_t* output) = 0; /// \brief Return whether the compressed stream is finished. /// - /// This is a heuristic. If true is returned, then it is guaranteed - /// that the stream is finished. If false is returned, however, it may - /// simply be that the underlying library isn't able to provide the - /// information. - virtual bool IsFinished() = 0; + /// This is a heuristic. If true is returned, then it is guaranteed that the + /// stream is finished. If false is returned, however, it may simply be that + /// the underlying library isn't able to provide the information. + virtual bool isFinished() = 0; /// \brief Reinitialize decompressor, making it ready for a new compressed /// stream. - virtual Status Reset() = 0; + virtual Status reset() = 0; // XXX add methods for buffer size heuristics? }; -/// \brief Compression codec options +/// \brief Compression codec options. class ARROW_EXPORT CodecOptions { public: - explicit CodecOptions(int compression_level = kUseDefaultCompressionLevel) - : compression_level(compression_level) {} + explicit CodecOptions(int compressionLevel = kUseDefaultCompressionLevel) + : compressionLevel(compressionLevel) {} virtual ~CodecOptions() = default; - int compression_level; + int compressionLevel; }; -// ---------------------------------------------------------------------- -// GZip codec options implementation +// ----------------------------------------------------------------------. +// GZip codec options implementation. enum class GZipFormat { ZLIB, @@ -160,123 +158,124 @@ enum class GZipFormat { class ARROW_EXPORT GZipCodecOptions : public CodecOptions { public: - GZipFormat gzip_format = GZipFormat::GZIP; - std::optional window_bits; + GZipFormat gzipFormat = GZipFormat::GZIP; + std::optional windowBits; }; -// ---------------------------------------------------------------------- -// brotli codec options implementation +// ----------------------------------------------------------------------. +// Brotli codec options implementation. class ARROW_EXPORT BrotliCodecOptions : public CodecOptions { public: - std::optional window_bits; + std::optional windowBits; }; -/// \brief Compression codec +/// \brief Compression codec. class ARROW_EXPORT Codec { public: virtual ~Codec() = default; /// \brief Return special value to indicate that a codec implementation - /// should use its default compression level - static int UseDefaultCompressionLevel(); + /// should use its default compression level. + static int useDefaultCompressionLevel(); - /// \brief Return a string name for compression type - static const std::string& GetCodecAsString(Compression::type t); + /// \brief Return a string name for compression type. + static const std::string& getCodecAsString(Compression::type t); - /// \brief Return compression type for name (all lower case) - static Result GetCompressionType(const std::string& name); + /// \brief Return compression type for name (all lower case). + static Result getCompressionType(const std::string& name); - /// \brief Create a codec for the given compression algorithm with - /// CodecOptions - static Result> Create( + /// \brief Create a codec for the given compression algorithm with codec + /// options. + static Result> create( Compression::type codec, - const CodecOptions& codec_options = CodecOptions{}); + const CodecOptions& codecOptions = CodecOptions{}); - /// \brief Create a codec for the given compression algorithm - static Result> Create( + /// \brief Create a codec for the given compression algorithm. + static Result> create( Compression::type codec, - int compression_level); + int compressionLevel); - /// \brief Return true if support for indicated codec has been enabled - static bool IsAvailable(Compression::type codec); + /// \brief Return true if support for indicated codec has been enabled. + static bool isAvailable(Compression::type codec); - /// \brief Return true if indicated codec supports setting a compression level - static bool SupportsCompressionLevel(Compression::type codec); + /// \brief Return true if indicated codec supports setting a compression + /// level. + static bool supportsCompressionLevel(Compression::type codec); - /// \brief Return the smallest supported compression level for the codec - /// Note: This function creates a temporary Codec instance - static Result MinimumCompressionLevel(Compression::type codec); + /// \brief Return the smallest supported compression level for the codec. + /// Note: This function creates a temporary Codec instance. + static Result minimumCompressionLevel(Compression::type codec); - /// \brief Return the largest supported compression level for the codec - /// Note: This function creates a temporary Codec instance - static Result MaximumCompressionLevel(Compression::type codec); + /// \brief Return the largest supported compression level for the codec. + /// Note: This function creates a temporary Codec instance. + static Result maximumCompressionLevel(Compression::type codec); - /// \brief Return the default compression level - /// Note: This function creates a temporary Codec instance - static Result DefaultCompressionLevel(Compression::type codec); + /// \brief Return the default compression level. + /// Note: This function creates a temporary Codec instance. + static Result defaultCompressionLevel(Compression::type codec); - /// \brief Return the smallest supported compression level - virtual int minimum_compression_level() const = 0; + /// \brief Return the smallest supported compression level. + virtual int minimumCompressionLevel() const = 0; - /// \brief Return the largest supported compression level - virtual int maximum_compression_level() const = 0; + /// \brief Return the largest supported compression level. + virtual int maximumCompressionLevel() const = 0; - /// \brief Return the default compression level - virtual int default_compression_level() const = 0; + /// \brief Return the default compression level. + virtual int defaultCompressionLevel() const = 0; - /// \brief One-shot decompression function + /// \brief One-shot decompression function. /// - /// output_buffer_len must be correct and therefore be obtained in advance. + /// Output_buffer_len must be correct and therefore be obtained in advance. /// The actual decompressed length is returned. /// /// \note One-shot decompression is not always compatible with streaming /// compression. Depending on the codec (e.g. LZ4), different formats may /// be used. - virtual Result Decompress( - int64_t input_len, + virtual Result decompress( + int64_t inputLen, const uint8_t* input, - int64_t output_buffer_len, - uint8_t* output_buffer) = 0; + int64_t outputBufferLen, + uint8_t* outputBuffer) = 0; - /// \brief One-shot compression function + /// \brief One-shot compression function. /// - /// output_buffer_len must first have been computed using MaxCompressedLen(). + /// Output_buffer_len must first have been computed using MaxCompressedLen(). /// The actual compressed length is returned. /// /// \note One-shot compression is not always compatible with streaming - /// decompression. Depending on the codec (e.g. LZ4), different formats may - /// be used. - virtual Result Compress( - int64_t input_len, + /// decompression. Depending on the codec (e.g. LZ4), different formats may be + /// used. + virtual Result compress( + int64_t inputLen, const uint8_t* input, - int64_t output_buffer_len, - uint8_t* output_buffer) = 0; + int64_t outputBufferLen, + uint8_t* outputBuffer) = 0; - virtual int64_t MaxCompressedLen(int64_t input_len, const uint8_t* input) = 0; + virtual int64_t maxCompressedLen(int64_t inputLen, const uint8_t* input) = 0; - /// \brief Create a streaming compressor instance - virtual Result> MakeCompressor() = 0; + /// \brief Create a streaming compressor instance. + virtual Result> makeCompressor() = 0; - /// \brief Create a streaming compressor instance - virtual Result> MakeDecompressor() = 0; + /// \brief Create a streaming compressor instance. + virtual Result> makeDecompressor() = 0; - /// \brief This Codec's compression type - virtual Compression::type compression_type() const = 0; + /// \brief This Codec's compression type. + virtual Compression::type compressionType() const = 0; - /// \brief The name of this Codec's compression type + /// \brief The name of this Codec's compression type. const std::string& name() const { - return GetCodecAsString(compression_type()); + return getCodecAsString(compressionType()); } - /// \brief This Codec's compression level, if applicable - virtual int compression_level() const { - return UseDefaultCompressionLevel(); + /// \brief This Codec's compression level, if applicable. + virtual int compressionLevel() const { + return useDefaultCompressionLevel(); } private: /// \brief Initializes the codec's resources. - virtual Status Init(); + virtual Status init(); }; } // namespace facebook::velox::parquet::arrow::util diff --git a/velox/dwio/parquet/writer/arrow/util/CompressionInternal.h b/velox/dwio/parquet/writer/arrow/util/CompressionInternal.h index 53ddd6e20d0..d43ad595642 100644 --- a/velox/dwio/parquet/writer/arrow/util/CompressionInternal.h +++ b/velox/dwio/parquet/writer/arrow/util/CompressionInternal.h @@ -24,8 +24,8 @@ namespace facebook::velox::parquet::arrow::util { -// ---------------------------------------------------------------------- -// Internal Codec factories +// ----------------------------------------------------------------------. +// Internal Codec factories. namespace internal { @@ -34,49 +34,49 @@ namespace internal { constexpr int kBrotliDefaultCompressionLevel = 8; // Brotli codec. -std::unique_ptr MakeBrotliCodec( - int compression_level = kBrotliDefaultCompressionLevel, - std::optional window_bits = std::nullopt); +std::unique_ptr makeBrotliCodec( + int compressionLevel = kBrotliDefaultCompressionLevel, + std::optional windowBits = std::nullopt); // BZ2 codec. constexpr int kBZ2DefaultCompressionLevel = 9; -std::unique_ptr MakeBZ2Codec( - int compression_level = kBZ2DefaultCompressionLevel); +std::unique_ptr makeBZ2Codec( + int compressionLevel = kBZ2DefaultCompressionLevel); -// GZip +// GZip. constexpr int kGZipDefaultCompressionLevel = 9; -std::unique_ptr MakeGZipCodec( - int compression_level = kGZipDefaultCompressionLevel, +std::unique_ptr makeGZipCodec( + int compressionLevel = kGZipDefaultCompressionLevel, GZipFormat format = GZipFormat::GZIP, - std::optional window_bits = std::nullopt); + std::optional windowBits = std::nullopt); -// Snappy -std::unique_ptr MakeSnappyCodec(); +// Snappy. +std::unique_ptr makeSnappyCodec(); -// Lz4 Codecs +// Lz4 Codecs. constexpr int kLz4DefaultCompressionLevel = 1; // Lz4 frame format codec. -std::unique_ptr MakeLz4FrameCodec( - int compression_level = kLz4DefaultCompressionLevel); +std::unique_ptr makeLz4FrameCodec( + int compressionLevel = kLz4DefaultCompressionLevel); // Lz4 "raw" format codec. -std::unique_ptr MakeLz4RawCodec( - int compression_level = kLz4DefaultCompressionLevel); +std::unique_ptr makeLz4RawCodec( + int compressionLevel = kLz4DefaultCompressionLevel); // Lz4 "Hadoop" format codec (== Lz4 raw codec prefixed with lengths header) -std::unique_ptr MakeLz4HadoopRawCodec(); +std::unique_ptr makeLz4HadoopRawCodec(); // ZSTD codec. -// XXX level = 1 probably doesn't compress very much +// XXX level = 1 probably doesn't compress very much. constexpr int kZSTDDefaultCompressionLevel = 1; -std::unique_ptr MakeZSTDCodec( - int compression_level = kZSTDDefaultCompressionLevel); +std::unique_ptr makeZSTDCodec( + int compressionLevel = kZSTDDefaultCompressionLevel); } // namespace internal } // namespace facebook::velox::parquet::arrow::util diff --git a/velox/dwio/parquet/writer/arrow/util/CompressionLZ4.cpp b/velox/dwio/parquet/writer/arrow/util/CompressionLZ4.cpp index 399fe58b1b0..0b0aeb09d36 100644 --- a/velox/dwio/parquet/writer/arrow/util/CompressionLZ4.cpp +++ b/velox/dwio/parquet/writer/arrow/util/CompressionLZ4.cpp @@ -39,25 +39,25 @@ namespace { constexpr int kLz4MinCompressionLevel = 1; -static Status LZ4Error(LZ4F_errorCode_t ret, const char* prefix_msg) { - return Status::IOError(prefix_msg, LZ4F_getErrorName(ret)); +static Status lZ4Error(LZ4F_errorCode_t ret, const char* prefixMsg) { + return Status::IOError(prefixMsg, LZ4F_getErrorName(ret)); } -static LZ4F_preferences_t DefaultPreferences() { +static LZ4F_preferences_t defaultPreferences() { LZ4F_preferences_t prefs; memset(&prefs, 0, sizeof(prefs)); return prefs; } -static LZ4F_preferences_t PreferencesWithCompressionLevel( - int compression_level) { - LZ4F_preferences_t prefs = DefaultPreferences(); - prefs.compressionLevel = compression_level; +static LZ4F_preferences_t preferencesWithCompressionLevel( + int compressionLevel) { + LZ4F_preferences_t prefs = defaultPreferences(); + prefs.compressionLevel = compressionLevel; return prefs; } -// ---------------------------------------------------------------------- -// Lz4 frame decompressor implementation +// ----------------------------------------------------------------------. +// Lz4 frame Decompressor implementation. class LZ4Decompressor : public Decompressor { public: @@ -69,21 +69,21 @@ class LZ4Decompressor : public Decompressor { } } - Status Init() { + Status init() { LZ4F_errorCode_t ret; finished_ = false; ret = LZ4F_createDecompressionContext(&ctx_, LZ4F_VERSION); if (LZ4F_isError(ret)) { - return LZ4Error(ret, "LZ4 init failed: "); + return lZ4Error(ret, "LZ4 init failed: "); } else { return Status::OK(); } } - Status Reset() override { + Status reset() override { #if defined(LZ4_VERSION_NUMBER) && LZ4_VERSION_NUMBER >= 10800 - // LZ4F_resetDecompressionContext appeared in 1.8.0 + // LZ4F_resetDecompressionContext appeared in 1.8.0. VELOX_DCHECK_NOT_NULL(ctx_); LZ4F_resetDecompressionContext(ctx_); finished_ = false; @@ -92,34 +92,34 @@ class LZ4Decompressor : public Decompressor { if (ctx_ != nullptr) { ARROW_UNUSED(LZ4F_freeDecompressionContext(ctx_)); } - return Init(); + return init(); #endif } - Result Decompress( - int64_t input_len, + Result decompress( + int64_t inputLen, const uint8_t* input, - int64_t output_len, + int64_t outputLen, uint8_t* output) override { auto src = input; auto dst = output; - auto src_size = static_cast(input_len); - auto dst_capacity = static_cast(output_len); + auto srcSize = static_cast(inputLen); + auto dstCapacity = static_cast(outputLen); size_t ret; ret = LZ4F_decompress( - ctx_, dst, &dst_capacity, src, &src_size, nullptr /* options */); + ctx_, dst, &dstCapacity, src, &srcSize, nullptr /* options */); if (LZ4F_isError(ret)) { - return LZ4Error(ret, "LZ4 decompress failed: "); + return lZ4Error(ret, "LZ4 decompress failed: "); } finished_ = (ret == 0); return DecompressResult{ - static_cast(src_size), - static_cast(dst_capacity), - (src_size == 0 && dst_capacity == 0)}; + static_cast(srcSize), + static_cast(dstCapacity), + (srcSize == 0 && dstCapacity == 0)}; } - bool IsFinished() override { + bool isFinished() override { return finished_; } @@ -128,13 +128,13 @@ class LZ4Decompressor : public Decompressor { bool finished_; }; -// ---------------------------------------------------------------------- -// Lz4 frame compressor implementation +// ----------------------------------------------------------------------. +// Lz4 frame compressor implementation. class LZ4Compressor : public Compressor { public: - explicit LZ4Compressor(int compression_level) - : compression_level_(compression_level) {} + explicit LZ4Compressor(int compressionLevel) + : compressionLevel_(compressionLevel) {} ~LZ4Compressor() override { if (ctx_ != nullptr) { @@ -142,482 +142,480 @@ class LZ4Compressor : public Compressor { } } - Status Init() { + Status init() { LZ4F_errorCode_t ret; - prefs_ = PreferencesWithCompressionLevel(compression_level_); - first_time_ = true; + prefs_ = preferencesWithCompressionLevel(compressionLevel_); + firstTime_ = true; ret = LZ4F_createCompressionContext(&ctx_, LZ4F_VERSION); if (LZ4F_isError(ret)) { - return LZ4Error(ret, "LZ4 init failed: "); + return lZ4Error(ret, "LZ4 init failed: "); } else { return Status::OK(); } } -#define BEGIN_COMPRESS(dst, dst_capacity, output_too_small) \ - if (first_time_) { \ - if (dst_capacity < LZ4F_HEADER_SIZE_MAX) { \ - /* Output too small to write LZ4F header */ \ - return (output_too_small); \ - } \ - ret = LZ4F_compressBegin(ctx_, dst, dst_capacity, &prefs_); \ - if (LZ4F_isError(ret)) { \ - return LZ4Error(ret, "LZ4 compress begin failed: "); \ - } \ - first_time_ = false; \ - dst += ret; \ - dst_capacity -= ret; \ - bytes_written += static_cast(ret); \ - } - - Result Compress( - int64_t input_len, +#define BEGIN_COMPRESS(dst, dstCapacity, outputTooSmall) \ + if (firstTime_) { \ + if (dstCapacity < LZ4F_HEADER_SIZE_MAX) { \ + /* Output too small to write LZ4F header */ \ + return (outputTooSmall); \ + } \ + ret = LZ4F_compressBegin(ctx_, dst, dstCapacity, &prefs_); \ + if (LZ4F_isError(ret)) { \ + return lZ4Error(ret, "LZ4 compress begin failed: "); \ + } \ + firstTime_ = false; \ + dst += ret; \ + dstCapacity -= ret; \ + bytesWritten += static_cast(ret); \ + } + + Result compress( + int64_t inputLen, const uint8_t* input, - int64_t output_len, + int64_t outputLen, uint8_t* output) override { auto src = input; auto dst = output; - auto src_size = static_cast(input_len); - auto dst_capacity = static_cast(output_len); + auto srcSize = static_cast(inputLen); + auto dstCapacity = static_cast(outputLen); size_t ret; - int64_t bytes_written = 0; + int64_t bytesWritten = 0; - BEGIN_COMPRESS(dst, dst_capacity, (CompressResult{0, 0})); + BEGIN_COMPRESS(dst, dstCapacity, (CompressResult{0, 0})); - if (dst_capacity < LZ4F_compressBound(src_size, &prefs_)) { - // Output too small to compress into - return CompressResult{0, bytes_written}; + if (dstCapacity < LZ4F_compressBound(srcSize, &prefs_)) { + // Output too small to compress into. + return CompressResult{0, bytesWritten}; } ret = LZ4F_compressUpdate( - ctx_, dst, dst_capacity, src, src_size, nullptr /* options */); + ctx_, dst, dstCapacity, src, srcSize, nullptr /* options */); if (LZ4F_isError(ret)) { - return LZ4Error(ret, "LZ4 compress update failed: "); + return lZ4Error(ret, "LZ4 compress update failed: "); } - bytes_written += static_cast(ret); - VELOX_DCHECK_LE(bytes_written, output_len); - return CompressResult{input_len, bytes_written}; + bytesWritten += static_cast(ret); + VELOX_DCHECK_LE(bytesWritten, outputLen); + return CompressResult{inputLen, bytesWritten}; } - Result Flush(int64_t output_len, uint8_t* output) override { + Result flush(int64_t outputLen, uint8_t* output) override { auto dst = output; - auto dst_capacity = static_cast(output_len); + auto dstCapacity = static_cast(outputLen); size_t ret; - int64_t bytes_written = 0; + int64_t bytesWritten = 0; - BEGIN_COMPRESS(dst, dst_capacity, (FlushResult{0, true})); + BEGIN_COMPRESS(dst, dstCapacity, (FlushResult{0, true})); - if (dst_capacity < LZ4F_compressBound(0, &prefs_)) { - // Output too small to flush into - return FlushResult{bytes_written, true}; + if (dstCapacity < LZ4F_compressBound(0, &prefs_)) { + // Output too small to flush into. + return FlushResult{bytesWritten, true}; } - ret = LZ4F_flush(ctx_, dst, dst_capacity, nullptr /* options */); + ret = LZ4F_flush(ctx_, dst, dstCapacity, nullptr /* options */); if (LZ4F_isError(ret)) { - return LZ4Error(ret, "LZ4 flush failed: "); + return lZ4Error(ret, "LZ4 flush failed: "); } - bytes_written += static_cast(ret); - VELOX_DCHECK_LE(bytes_written, output_len); - return FlushResult{bytes_written, false}; + bytesWritten += static_cast(ret); + VELOX_DCHECK_LE(bytesWritten, outputLen); + return FlushResult{bytesWritten, false}; } - Result End(int64_t output_len, uint8_t* output) override { + Result end(int64_t outputLen, uint8_t* output) override { auto dst = output; - auto dst_capacity = static_cast(output_len); + auto dstCapacity = static_cast(outputLen); size_t ret; - int64_t bytes_written = 0; + int64_t bytesWritten = 0; - BEGIN_COMPRESS(dst, dst_capacity, (EndResult{0, true})); + BEGIN_COMPRESS(dst, dstCapacity, (EndResult{0, true})); - if (dst_capacity < LZ4F_compressBound(0, &prefs_)) { - // Output too small to end frame into - return EndResult{bytes_written, true}; + if (dstCapacity < LZ4F_compressBound(0, &prefs_)) { + // Output too small to end frame into. + return EndResult{bytesWritten, true}; } - ret = LZ4F_compressEnd(ctx_, dst, dst_capacity, nullptr /* options */); + ret = LZ4F_compressEnd(ctx_, dst, dstCapacity, nullptr /* options */); if (LZ4F_isError(ret)) { - return LZ4Error(ret, "LZ4 end failed: "); + return lZ4Error(ret, "LZ4 end failed: "); } - bytes_written += static_cast(ret); - VELOX_DCHECK_LE(bytes_written, output_len); - return EndResult{bytes_written, false}; + bytesWritten += static_cast(ret); + VELOX_DCHECK_LE(bytesWritten, outputLen); + return EndResult{bytesWritten, false}; } #undef BEGIN_COMPRESS protected: - int compression_level_; + int compressionLevel_; LZ4F_compressionContext_t ctx_ = nullptr; LZ4F_preferences_t prefs_; - bool first_time_; + bool firstTime_; }; -// ---------------------------------------------------------------------- -// Lz4 frame codec implementation +// ----------------------------------------------------------------------. +// Lz4 frame codec implementation. class Lz4FrameCodec : public Codec { public: - explicit Lz4FrameCodec(int compression_level) - : compression_level_( - compression_level == kUseDefaultCompressionLevel + explicit Lz4FrameCodec(int compressionLevel) + : compressionLevel_( + compressionLevel == kUseDefaultCompressionLevel ? kLz4DefaultCompressionLevel - : compression_level), - prefs_(PreferencesWithCompressionLevel(compression_level_)) {} + : compressionLevel), + prefs_(preferencesWithCompressionLevel(compressionLevel_)) {} - int64_t MaxCompressedLen( - int64_t input_len, + int64_t maxCompressedLen( + int64_t inputLen, const uint8_t* ARROW_ARG_UNUSED(input)) override { return static_cast( - LZ4F_compressFrameBound(static_cast(input_len), &prefs_)); + LZ4F_compressFrameBound(static_cast(inputLen), &prefs_)); } - Result Compress( - int64_t input_len, + Result compress( + int64_t inputLen, const uint8_t* input, - int64_t output_buffer_len, - uint8_t* output_buffer) override { - auto output_len = LZ4F_compressFrame( - output_buffer, - static_cast(output_buffer_len), + int64_t outputBufferLen, + uint8_t* outputBuffer) override { + auto outputLen = LZ4F_compressFrame( + outputBuffer, + static_cast(outputBufferLen), input, - static_cast(input_len), + static_cast(inputLen), &prefs_); - if (LZ4F_isError(output_len)) { - return LZ4Error(output_len, "Lz4 compression failure: "); + if (LZ4F_isError(outputLen)) { + return lZ4Error(outputLen, "Lz4 compression failure: "); } - return static_cast(output_len); + return static_cast(outputLen); } - Result Decompress( - int64_t input_len, + Result decompress( + int64_t inputLen, const uint8_t* input, - int64_t output_buffer_len, - uint8_t* output_buffer) override { - ARROW_ASSIGN_OR_RAISE(auto decomp, MakeDecompressor()); + int64_t outputBufferLen, + uint8_t* outputBuffer) override { + ARROW_ASSIGN_OR_RAISE(auto decomp, makeDecompressor()); - int64_t total_bytes_written = 0; - while (!decomp->IsFinished() && input_len != 0) { + int64_t totalBytesWritten = 0; + while (!decomp->isFinished() && inputLen != 0) { ARROW_ASSIGN_OR_RAISE( auto res, - decomp->Decompress( - input_len, input, output_buffer_len, output_buffer)); - input += res.bytes_read; - input_len -= res.bytes_read; - output_buffer += res.bytes_written; - output_buffer_len -= res.bytes_written; - total_bytes_written += res.bytes_written; - if (res.need_more_output) { + decomp->decompress(inputLen, input, outputBufferLen, outputBuffer)); + input += res.bytesRead; + inputLen -= res.bytesRead; + outputBuffer += res.bytesWritten; + outputBufferLen -= res.bytesWritten; + totalBytesWritten += res.bytesWritten; + if (res.needMoreOutput) { return Status::IOError("Lz4 decompression buffer too small"); } } - if (!decomp->IsFinished()) { + if (!decomp->isFinished()) { return Status::IOError( "Lz4 compressed input contains less than one frame"); } - if (input_len != 0) { + if (inputLen != 0) { return Status::IOError( "Lz4 compressed input contains more than one frame"); } - return total_bytes_written; + return totalBytesWritten; } - Result> MakeCompressor() override { - auto ptr = std::make_shared(compression_level_); - RETURN_NOT_OK(ptr->Init()); + Result> makeCompressor() override { + auto ptr = std::make_shared(compressionLevel_); + RETURN_NOT_OK(ptr->init()); return ptr; } - Result> MakeDecompressor() override { + Result> makeDecompressor() override { auto ptr = std::make_shared(); - RETURN_NOT_OK(ptr->Init()); + RETURN_NOT_OK(ptr->init()); return ptr; } - Compression::type compression_type() const override { + Compression::type compressionType() const override { return Compression::LZ4_FRAME; } - int minimum_compression_level() const override { + int minimumCompressionLevel() const override { return kLz4MinCompressionLevel; } #if (defined(LZ4_VERSION_NUMBER) && LZ4_VERSION_NUMBER < 10800) - int maximum_compression_level() const override { + int maximumCompressionLevel() const override { return 12; } #else - int maximum_compression_level() const override { + int maximumCompressionLevel() const override { return LZ4F_compressionLevel_max(); } #endif - int default_compression_level() const override { + int defaultCompressionLevel() const override { return kLz4DefaultCompressionLevel; } - int compression_level() const override { - return compression_level_; + int compressionLevel() const override { + return compressionLevel_; } protected: - const int compression_level_; + const int compressionLevel_; const LZ4F_preferences_t prefs_; }; -// ---------------------------------------------------------------------- -// Lz4 "raw" codec implementation +// ----------------------------------------------------------------------. +// Lz4 "raw" codec implementation. class Lz4Codec : public Codec { public: - explicit Lz4Codec(int compression_level) - : compression_level_( - compression_level == kUseDefaultCompressionLevel + explicit Lz4Codec(int compressionLevel) + : compressionLevel_( + compressionLevel == kUseDefaultCompressionLevel ? kLz4DefaultCompressionLevel - : compression_level) {} + : compressionLevel) {} - Result Decompress( - int64_t input_len, + Result decompress( + int64_t inputLen, const uint8_t* input, - int64_t output_buffer_len, - uint8_t* output_buffer) override { - int64_t decompressed_size = LZ4_decompress_safe( + int64_t outputBufferLen, + uint8_t* outputBuffer) override { + int64_t decompressedSize = LZ4_decompress_safe( reinterpret_cast(input), - reinterpret_cast(output_buffer), - static_cast(input_len), - static_cast(output_buffer_len)); - if (decompressed_size < 0) { + reinterpret_cast(outputBuffer), + static_cast(inputLen), + static_cast(outputBufferLen)); + if (decompressedSize < 0) { return Status::IOError("Corrupt Lz4 compressed data."); } - return decompressed_size; + return decompressedSize; } - int64_t MaxCompressedLen( - int64_t input_len, + int64_t maxCompressedLen( + int64_t inputLen, const uint8_t* ARROW_ARG_UNUSED(input)) override { - return LZ4_compressBound(static_cast(input_len)); + return LZ4_compressBound(static_cast(inputLen)); } - Result Compress( - int64_t input_len, + Result compress( + int64_t inputLen, const uint8_t* input, - int64_t output_buffer_len, - uint8_t* output_buffer) override { - int64_t output_len; + int64_t outputBufferLen, + uint8_t* outputBuffer) override { + int64_t outputLen; #ifdef LZ4HC_CLEVEL_MIN - constexpr int min_hc_clevel = LZ4HC_CLEVEL_MIN; + constexpr int minHcClevel = LZ4HC_CLEVEL_MIN; #else // For older versions of the lz4 library - constexpr int min_hc_clevel = 3; + constexpr int minHcClevel = 3; #endif - if (compression_level_ < min_hc_clevel) { - output_len = LZ4_compress_default( + if (compressionLevel_ < minHcClevel) { + outputLen = LZ4_compress_default( reinterpret_cast(input), - reinterpret_cast(output_buffer), - static_cast(input_len), - static_cast(output_buffer_len)); + reinterpret_cast(outputBuffer), + static_cast(inputLen), + static_cast(outputBufferLen)); } else { - output_len = LZ4_compress_HC( + outputLen = LZ4_compress_HC( reinterpret_cast(input), - reinterpret_cast(output_buffer), - static_cast(input_len), - static_cast(output_buffer_len), - compression_level_); + reinterpret_cast(outputBuffer), + static_cast(inputLen), + static_cast(outputBufferLen), + compressionLevel_); } - if (output_len == 0) { + if (outputLen == 0) { return Status::IOError("Lz4 compression failure."); } - return output_len; + return outputLen; } - Result> MakeCompressor() override { + Result> makeCompressor() override { return Status::NotImplemented( "Streaming compression unsupported with LZ4 raw format. " "Try using LZ4 frame format instead."); } - Result> MakeDecompressor() override { + Result> makeDecompressor() override { return Status::NotImplemented( "Streaming decompression unsupported with LZ4 raw format. " "Try using LZ4 frame format instead."); } - Compression::type compression_type() const override { + Compression::type compressionType() const override { return Compression::LZ4; } - int minimum_compression_level() const override { + int minimumCompressionLevel() const override { return kLz4MinCompressionLevel; } #if (defined(LZ4_VERSION_NUMBER) && LZ4_VERSION_NUMBER < 10800) - int maximum_compression_level() const override { + int maximumCompressionLevel() const override { return 12; } #else - int maximum_compression_level() const override { + int maximumCompressionLevel() const override { return LZ4F_compressionLevel_max(); } #endif - int default_compression_level() const override { + int defaultCompressionLevel() const override { return kLz4DefaultCompressionLevel; } protected: - int compression_level_; + int compressionLevel_; }; -// ---------------------------------------------------------------------- -// Lz4 Hadoop "raw" codec implementation +// ----------------------------------------------------------------------. +// Lz4 Hadoop "raw" codec implementation. class Lz4HadoopCodec : public Lz4Codec { public: Lz4HadoopCodec() : Lz4Codec(kUseDefaultCompressionLevel) {} - Result Decompress( - int64_t input_len, + Result decompress( + int64_t inputLen, const uint8_t* input, - int64_t output_buffer_len, - uint8_t* output_buffer) override { - const int64_t decompressed_size = - TryDecompressHadoop(input_len, input, output_buffer_len, output_buffer); - if (decompressed_size != kNotHadoop) { - return decompressed_size; + int64_t outputBufferLen, + uint8_t* outputBuffer) override { + const int64_t decompressedSize = + tryDecompressHadoop(inputLen, input, outputBufferLen, outputBuffer); + if (decompressedSize != kNotHadoop) { + return decompressedSize; } - // Fall back on raw LZ4 codec (for files produces by earlier versions of + // Fall back on raw LZ4 codec (for files produces by earlier versions of. // Parquet C++) - return Lz4Codec::Decompress( - input_len, input, output_buffer_len, output_buffer); + return Lz4Codec::decompress(inputLen, input, outputBufferLen, outputBuffer); } - int64_t MaxCompressedLen( - int64_t input_len, + int64_t maxCompressedLen( + int64_t inputLen, const uint8_t* ARROW_ARG_UNUSED(input)) override { - return kPrefixLength + Lz4Codec::MaxCompressedLen(input_len, nullptr); + return kPrefixLength + Lz4Codec::maxCompressedLen(inputLen, nullptr); } - Result Compress( - int64_t input_len, + Result compress( + int64_t inputLen, const uint8_t* input, - int64_t output_buffer_len, - uint8_t* output_buffer) override { - if (output_buffer_len < kPrefixLength) { + int64_t outputBufferLen, + uint8_t* outputBuffer) override { + if (outputBufferLen < kPrefixLength) { return Status::Invalid( "Output buffer too small for Lz4HadoopCodec compression"); } ARROW_ASSIGN_OR_RAISE( - int64_t output_len, - Lz4Codec::Compress( - input_len, + int64_t outputLen, + Lz4Codec::compress( + inputLen, input, - output_buffer_len - kPrefixLength, - output_buffer + kPrefixLength)); + outputBufferLen - kPrefixLength, + outputBuffer + kPrefixLength)); - // Prepend decompressed size in bytes and compressed size in bytes - // to be compatible with Hadoop Lz4Codec - const uint32_t decompressed_size = - bit_util::ToBigEndian(static_cast(input_len)); - const uint32_t compressed_size = - bit_util::ToBigEndian(static_cast(output_len)); - ::arrow::util::SafeStore(output_buffer, decompressed_size); - ::arrow::util::SafeStore(output_buffer + sizeof(uint32_t), compressed_size); + // Prepend decompressed size in bytes and compressed size in bytes. + // To be compatible with Hadoop Lz4Codec. + const uint32_t decompressedSize = + bit_util::ToBigEndian(static_cast(inputLen)); + const uint32_t compressedSize = + bit_util::ToBigEndian(static_cast(outputLen)); + ::arrow::util::SafeStore(outputBuffer, decompressedSize); + ::arrow::util::SafeStore(outputBuffer + sizeof(uint32_t), compressedSize); - return kPrefixLength + output_len; + return kPrefixLength + outputLen; } - Result> MakeCompressor() override { + Result> makeCompressor() override { return Status::NotImplemented( "Streaming compression unsupported with LZ4 Hadoop raw format. " "Try using LZ4 frame format instead."); } - Result> MakeDecompressor() override { + Result> makeDecompressor() override { return Status::NotImplemented( "Streaming decompression unsupported with LZ4 Hadoop raw format. " "Try using LZ4 frame format instead."); } - Compression::type compression_type() const override { + Compression::type compressionType() const override { return Compression::LZ4_HADOOP; } protected: - // Offset starting at which page data can be read/written + // Offset starting at which page data can be read/written. static const int64_t kPrefixLength = sizeof(uint32_t) * 2; static const int64_t kNotHadoop = -1; - int64_t TryDecompressHadoop( - int64_t input_len, + int64_t tryDecompressHadoop( + int64_t inputLen, const uint8_t* input, - int64_t output_buffer_len, - uint8_t* output_buffer) { + int64_t outputBufferLen, + uint8_t* outputBuffer) { // Parquet files written with the Hadoop Lz4Codec use their own framing. - // The input buffer can contain an arbitrary number of "frames", each - // with the following structure: - // - bytes 0..3: big-endian uint32_t representing the frame decompressed - // size - // - bytes 4..7: big-endian uint32_t representing the frame compressed size - // - bytes 8...: frame compressed data + // The input buffer can contain an arbitrary number of "frames", each. + // With the following structure: + // - Bytes 0..3: big-endian uint32_t representing the frame decompressed + // size. + // - Bytes 4..7: big-endian uint32_t representing the frame compressed size. + // - Bytes 8...: frame compressed data. // // The Hadoop Lz4Codec source code can be found here: // https://github.com/apache/hadoop/blob/trunk/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-nativetask/src/main/native/src/codec/Lz4Codec.cc - int64_t total_decompressed_size = 0; + int64_t totalDecompressedSize = 0; - while (input_len >= kPrefixLength) { - const uint32_t expected_decompressed_size = + while (inputLen >= kPrefixLength) { + const uint32_t expectedDecompressedSize = bit_util::FromBigEndian(::arrow::util::SafeLoadAs(input)); - const uint32_t expected_compressed_size = bit_util::FromBigEndian( + const uint32_t expectedCompressedSize = bit_util::FromBigEndian( ::arrow::util::SafeLoadAs(input + sizeof(uint32_t))); input += kPrefixLength; - input_len -= kPrefixLength; + inputLen -= kPrefixLength; - if (input_len < expected_compressed_size) { - // Not enough bytes for Hadoop "frame" + if (inputLen < expectedCompressedSize) { + // Not enough bytes for Hadoop "frame". return kNotHadoop; } - if (output_buffer_len < expected_decompressed_size) { - // Not enough bytes to hold advertised output => probably not Hadoop + if (outputBufferLen < expectedDecompressedSize) { + // Not enough bytes to hold advertised output => probably not Hadoop. return kNotHadoop; } - // Try decompressing and compare with expected decompressed length - auto maybe_decompressed_size = Lz4Codec::Decompress( - expected_compressed_size, input, output_buffer_len, output_buffer); - if (!maybe_decompressed_size.ok() || - *maybe_decompressed_size != expected_decompressed_size) { + // Try decompressing and compare with expected decompressed length. + auto maybeDecompressedSize = Lz4Codec::decompress( + expectedCompressedSize, input, outputBufferLen, outputBuffer); + if (!maybeDecompressedSize.ok() || + *maybeDecompressedSize != expectedDecompressedSize) { return kNotHadoop; } - input += expected_compressed_size; - input_len -= expected_compressed_size; - output_buffer += expected_decompressed_size; - output_buffer_len -= expected_decompressed_size; - total_decompressed_size += expected_decompressed_size; + input += expectedCompressedSize; + inputLen -= expectedCompressedSize; + outputBuffer += expectedDecompressedSize; + outputBufferLen -= expectedDecompressedSize; + totalDecompressedSize += expectedDecompressedSize; } - if (input_len == 0) { - return total_decompressed_size; + if (inputLen == 0) { + return totalDecompressedSize; } else { return kNotHadoop; } } - int minimum_compression_level() const override { + int minimumCompressionLevel() const override { return kUseDefaultCompressionLevel; } - int maximum_compression_level() const override { + int maximumCompressionLevel() const override { return kUseDefaultCompressionLevel; } - int default_compression_level() const override { + int defaultCompressionLevel() const override { return kUseDefaultCompressionLevel; } }; } // namespace -std::unique_ptr MakeLz4FrameCodec(int compression_level) { - return std::make_unique(compression_level); +std::unique_ptr makeLz4FrameCodec(int compressionLevel) { + return std::make_unique(compressionLevel); } -std::unique_ptr MakeLz4HadoopRawCodec() { +std::unique_ptr makeLz4HadoopRawCodec() { return std::make_unique(); } -std::unique_ptr MakeLz4RawCodec(int compression_level) { - return std::make_unique(compression_level); +std::unique_ptr makeLz4RawCodec(int compressionLevel) { + return std::make_unique(compressionLevel); } } // namespace facebook::velox::parquet::arrow::util::internal diff --git a/velox/dwio/parquet/writer/arrow/util/CompressionSnappy.cpp b/velox/dwio/parquet/writer/arrow/util/CompressionSnappy.cpp index a79a4c95ccd..eaa186cd23c 100644 --- a/velox/dwio/parquet/writer/arrow/util/CompressionSnappy.cpp +++ b/velox/dwio/parquet/writer/arrow/util/CompressionSnappy.cpp @@ -37,88 +37,88 @@ namespace { using ::arrow::Result; -// ---------------------------------------------------------------------- -// Snappy implementation +// ----------------------------------------------------------------------. +// Snappy implementation. class SnappyCodec : public Codec { public: - Result Decompress( - int64_t input_len, + Result decompress( + int64_t inputLen, const uint8_t* input, - int64_t output_buffer_len, - uint8_t* output_buffer) override { - size_t decompressed_size; + int64_t outputBufferLen, + uint8_t* outputBuffer) override { + size_t decompressedSize; if (!snappy::GetUncompressedLength( reinterpret_cast(input), - static_cast(input_len), - &decompressed_size)) { + static_cast(inputLen), + &decompressedSize)) { return Status::IOError("Corrupt snappy compressed data."); } - if (output_buffer_len < static_cast(decompressed_size)) { + if (outputBufferLen < static_cast(decompressedSize)) { return Status::Invalid( "Output buffer size (", - output_buffer_len, + outputBufferLen, ") must be ", - decompressed_size, + decompressedSize, " or larger."); } if (!snappy::RawUncompress( reinterpret_cast(input), - static_cast(input_len), - reinterpret_cast(output_buffer))) { + static_cast(inputLen), + reinterpret_cast(outputBuffer))) { return Status::IOError("Corrupt snappy compressed data."); } - return static_cast(decompressed_size); + return static_cast(decompressedSize); } - int64_t MaxCompressedLen( - int64_t input_len, + int64_t maxCompressedLen( + int64_t inputLen, const uint8_t* ARROW_ARG_UNUSED(input)) override { - VELOX_DCHECK_GE(input_len, 0); - return snappy::MaxCompressedLength(static_cast(input_len)); + VELOX_DCHECK_GE(inputLen, 0); + return snappy::MaxCompressedLength(static_cast(inputLen)); } - Result Compress( - int64_t input_len, + Result compress( + int64_t inputLen, const uint8_t* input, - int64_t ARROW_ARG_UNUSED(output_buffer_len), - uint8_t* output_buffer) override { - size_t output_size; + int64_t ARROW_ARG_UNUSED(outputBufferLen), + uint8_t* outputBuffer) override { + size_t outputSize; snappy::RawCompress( reinterpret_cast(input), - static_cast(input_len), - reinterpret_cast(output_buffer), - &output_size); - return static_cast(output_size); + static_cast(inputLen), + reinterpret_cast(outputBuffer), + &outputSize); + return static_cast(outputSize); } - Result> MakeCompressor() override { + Result> makeCompressor() override { return Status::NotImplemented( "Streaming compression unsupported with Snappy"); } - Result> MakeDecompressor() override { + Result> makeDecompressor() override { return Status::NotImplemented( "Streaming decompression unsupported with Snappy"); } - Compression::type compression_type() const override { + Compression::type compressionType() const override { return Compression::SNAPPY; } - int minimum_compression_level() const override { + int minimumCompressionLevel() const override { return kUseDefaultCompressionLevel; } - int maximum_compression_level() const override { + int maximumCompressionLevel() const override { return kUseDefaultCompressionLevel; } - int default_compression_level() const override { + int defaultCompressionLevel() const override { return kUseDefaultCompressionLevel; } }; } // namespace -std::unique_ptr MakeSnappyCodec() { +std::unique_ptr makeSnappyCodec() { return std::make_unique(); } diff --git a/velox/dwio/parquet/writer/arrow/util/CompressionZlib.cpp b/velox/dwio/parquet/writer/arrow/util/CompressionZlib.cpp index 23ab7f25e34..0e5347138e1 100644 --- a/velox/dwio/parquet/writer/arrow/util/CompressionZlib.cpp +++ b/velox/dwio/parquet/writer/arrow/util/CompressionZlib.cpp @@ -34,19 +34,19 @@ namespace facebook::velox::parquet::arrow::util::internal { namespace { -// ---------------------------------------------------------------------- -// gzip implementation +// ----------------------------------------------------------------------. +// Gzip implementation. -// These are magic numbers from zlib.h. Not clear why they are not defined -// there. +// These are magic numbers from zlib.h. Not clear why they are not defined. +// There. -// Maximum window size +// Maximum window size. constexpr int kGZipMaxWindowBits = 15; -// Minimum window size +// Minimum window size. constexpr int kGZipMinWindowBits = 9; -// Default window size +// Default window size. constexpr int kGZipDefaultWindowBits = 15; // Output Gzip. @@ -58,41 +58,41 @@ constexpr int DETECT_CODEC = 32; constexpr int kGZipMinCompressionLevel = 1; constexpr int kGZipMaxCompressionLevel = 9; -int CompressionWindowBitsForFormat(GZipFormat format, int window_bits) { +int compressionWindowBitsForFormat(GZipFormat format, int windowBits) { switch (format) { case GZipFormat::DEFLATE: - window_bits = -window_bits; + windowBits = -windowBits; break; case GZipFormat::GZIP: - window_bits += GZIP_CODEC; + windowBits += GZIP_CODEC; break; case GZipFormat::ZLIB: break; } - return window_bits; + return windowBits; } -int DecompressionWindowBitsForFormat(GZipFormat format, int window_bits) { +int decompressionWindowBitsForFormat(GZipFormat format, int windowBits) { if (format == GZipFormat::DEFLATE) { - return -window_bits; + return -windowBits; } else { /* If not deflate, autodetect format from header */ - return window_bits | DETECT_CODEC; + return windowBits | DETECT_CODEC; } } -Status ZlibErrorPrefix(const char* prefix_msg, const char* msg) { - return Status::IOError(prefix_msg, (msg) ? msg : "(unknown error)"); +Status zlibErrorPrefix(const char* prefixMsg, const char* msg) { + return Status::IOError(prefixMsg, (msg) ? msg : "(unknown error)"); } -// ---------------------------------------------------------------------- -// gzip decompressor implementation +// ----------------------------------------------------------------------. +// Gzip Decompressor implementation. class GZipDecompressor : public Decompressor { public: - explicit GZipDecompressor(GZipFormat format, int window_bits) + explicit GZipDecompressor(GZipFormat format, int windowBits) : format_(format), - window_bits_(window_bits), + windowBits_(windowBits), initialized_(false), finished_(false) {} @@ -102,88 +102,88 @@ class GZipDecompressor : public Decompressor { } } - Status Init() { + Status init() { VELOX_DCHECK(!initialized_); memset(&stream_, 0, sizeof(stream_)); finished_ = false; int ret; - int window_bits = DecompressionWindowBitsForFormat(format_, window_bits_); - if ((ret = inflateInit2(&stream_, window_bits)) != Z_OK) { - return ZlibError("zlib inflateInit failed: "); + int windowBits = decompressionWindowBitsForFormat(format_, windowBits_); + if ((ret = inflateInit2(&stream_, windowBits)) != Z_OK) { + return zlibError("zlib inflateInit failed: "); } else { initialized_ = true; return Status::OK(); } } - Status Reset() override { + Status reset() override { VELOX_DCHECK(initialized_); finished_ = false; int ret; if ((ret = inflateReset(&stream_)) != Z_OK) { - return ZlibError("zlib inflateReset failed: "); + return zlibError("zlib inflateReset failed: "); } else { return Status::OK(); } } - Result Decompress( - int64_t input_len, + Result decompress( + int64_t inputLen, const uint8_t* input, - int64_t output_len, + int64_t outputLen, uint8_t* output) override { - static constexpr auto input_limit = + static constexpr auto inputLimit = static_cast(std::numeric_limits::max()); stream_.next_in = const_cast(reinterpret_cast(input)); - stream_.avail_in = static_cast(std::min(input_len, input_limit)); + stream_.avail_in = static_cast(std::min(inputLen, inputLimit)); stream_.next_out = reinterpret_cast(output); - stream_.avail_out = static_cast(std::min(output_len, input_limit)); + stream_.avail_out = static_cast(std::min(outputLen, inputLimit)); int ret; ret = inflate(&stream_, Z_SYNC_FLUSH); if (ret == Z_DATA_ERROR || ret == Z_STREAM_ERROR || ret == Z_MEM_ERROR) { - return ZlibError("zlib inflate failed: "); + return zlibError("zlib inflate failed: "); } if (ret == Z_NEED_DICT) { - return ZlibError("zlib inflate failed (need preset dictionary): "); + return zlibError("zlib inflate failed (need preset dictionary): "); } finished_ = (ret == Z_STREAM_END); if (ret == Z_BUF_ERROR) { - // No progress was possible + // No progress was possible. return DecompressResult{0, 0, true}; } else { VELOX_DCHECK(ret == Z_OK || ret == Z_STREAM_END); - // Some progress has been made + // Some progress has been made. return DecompressResult{ - input_len - stream_.avail_in, output_len - stream_.avail_out, false}; + inputLen - stream_.avail_in, outputLen - stream_.avail_out, false}; } return Status::OK(); } - bool IsFinished() override { + bool isFinished() override { return finished_; } protected: - Status ZlibError(const char* prefix_msg) { - return ZlibErrorPrefix(prefix_msg, stream_.msg); + Status zlibError(const char* prefixMsg) { + return zlibErrorPrefix(prefixMsg, stream_.msg); } z_stream stream_; GZipFormat format_; - int window_bits_; + int windowBits_; bool initialized_; bool finished_; }; -// ---------------------------------------------------------------------- -// gzip compressor implementation +// ----------------------------------------------------------------------. +// Gzip Compressor implementation. class GZipCompressor : public Compressor { public: - explicit GZipCompressor(int compression_level) - : initialized_(false), compression_level_(compression_level) {} + explicit GZipCompressor(int compressionLevel) + : initialized_(false), compressionLevel_(compressionLevel) {} ~GZipCompressor() override { if (initialized_) { @@ -191,378 +191,377 @@ class GZipCompressor : public Compressor { } } - Status Init(GZipFormat format, int input_window_bits) { + Status init(GZipFormat format, int inputWindowBits) { VELOX_DCHECK(!initialized_); memset(&stream_, 0, sizeof(stream_)); int ret; - // Initialize to run specified format - int window_bits = CompressionWindowBitsForFormat(format, input_window_bits); + // Initialize to run specified format. + int windowBits = compressionWindowBitsForFormat(format, inputWindowBits); if ((ret = deflateInit2( &stream_, Z_DEFAULT_COMPRESSION, Z_DEFLATED, - window_bits, - compression_level_, + windowBits, + compressionLevel_, Z_DEFAULT_STRATEGY)) != Z_OK) { - return ZlibError("zlib deflateInit failed: "); + return zlibError("zlib deflateInit failed: "); } else { initialized_ = true; return Status::OK(); } } - Result Compress( - int64_t input_len, + Result compress( + int64_t inputLen, const uint8_t* input, - int64_t output_len, + int64_t outputLen, uint8_t* output) override { VELOX_DCHECK(initialized_, "Called on non-initialized stream"); - static constexpr auto input_limit = + static constexpr auto inputLimit = static_cast(std::numeric_limits::max()); stream_.next_in = const_cast(reinterpret_cast(input)); - stream_.avail_in = static_cast(std::min(input_len, input_limit)); + stream_.avail_in = static_cast(std::min(inputLen, inputLimit)); stream_.next_out = reinterpret_cast(output); - stream_.avail_out = static_cast(std::min(output_len, input_limit)); + stream_.avail_out = static_cast(std::min(outputLen, inputLimit)); int64_t ret = 0; ret = deflate(&stream_, Z_NO_FLUSH); if (ret == Z_STREAM_ERROR) { - return ZlibError("zlib compress failed: "); + return zlibError("zlib compress failed: "); } if (ret == Z_OK) { - // Some progress has been made + // Some progress has been made. return CompressResult{ - input_len - stream_.avail_in, output_len - stream_.avail_out}; + inputLen - stream_.avail_in, outputLen - stream_.avail_out}; } else { - // No progress was possible + // No progress was possible. VELOX_DCHECK_EQ(ret, Z_BUF_ERROR); return CompressResult{0, 0}; } } - Result Flush(int64_t output_len, uint8_t* output) override { + Result flush(int64_t outputLen, uint8_t* output) override { VELOX_DCHECK(initialized_, "Called on non-initialized stream"); - static constexpr auto input_limit = + static constexpr auto inputLimit = static_cast(std::numeric_limits::max()); stream_.avail_in = 0; stream_.next_out = reinterpret_cast(output); - stream_.avail_out = static_cast(std::min(output_len, input_limit)); + stream_.avail_out = static_cast(std::min(outputLen, inputLimit)); int64_t ret = 0; ret = deflate(&stream_, Z_SYNC_FLUSH); if (ret == Z_STREAM_ERROR) { - return ZlibError("zlib flush failed: "); + return zlibError("zlib flush failed: "); } - int64_t bytes_written; + int64_t bytesWritten; if (ret == Z_OK) { - bytes_written = output_len - stream_.avail_out; + bytesWritten = outputLen - stream_.avail_out; } else { VELOX_DCHECK_EQ(ret, Z_BUF_ERROR); - bytes_written = 0; + bytesWritten = 0; } - // "If deflate returns with avail_out == 0, this function must be called - // again with the same value of the flush parameter and more output space - // (updated avail_out), until the flush is complete (deflate returns - // with non-zero avail_out)." - // "Note that Z_BUF_ERROR is not fatal, and deflate() can be called again - // with more input and more output space to continue compressing." - return FlushResult{bytes_written, stream_.avail_out == 0}; + // "If deflate returns with avail_out == 0, this function must be called. + // Again with the same value of the flush parameter and more output space. + // (Updated avail_out), until the flush is complete (deflate returns. + // With non-zero avail_out).". + // "Note that Z_BUF_ERROR is not fatal, and deflate() can be called again. + // With more input and more output space to continue compressing.". + return FlushResult{bytesWritten, stream_.avail_out == 0}; } - Result End(int64_t output_len, uint8_t* output) override { + Result end(int64_t outputLen, uint8_t* output) override { VELOX_DCHECK(initialized_, "Called on non-initialized stream"); - static constexpr auto input_limit = + static constexpr auto inputLimit = static_cast(std::numeric_limits::max()); stream_.avail_in = 0; stream_.next_out = reinterpret_cast(output); - stream_.avail_out = static_cast(std::min(output_len, input_limit)); + stream_.avail_out = static_cast(std::min(outputLen, inputLimit)); int64_t ret = 0; ret = deflate(&stream_, Z_FINISH); if (ret == Z_STREAM_ERROR) { - return ZlibError("zlib flush failed: "); + return zlibError("zlib flush failed: "); } - int64_t bytes_written = output_len - stream_.avail_out; + int64_t bytesWritten = outputLen - stream_.avail_out; if (ret == Z_STREAM_END) { - // Flush complete, we can now end the stream + // Flush complete, we can now end the stream. initialized_ = false; ret = deflateEnd(&stream_); if (ret == Z_OK) { - return EndResult{bytes_written, false}; + return EndResult{bytesWritten, false}; } else { - return ZlibError("zlib end failed: "); + return zlibError("zlib end failed: "); } } else { - // Not everything could be flushed, - return EndResult{bytes_written, true}; + // Not everything could be flushed,. + return EndResult{bytesWritten, true}; } } protected: - Status ZlibError(const char* prefix_msg) { - return ZlibErrorPrefix(prefix_msg, stream_.msg); + Status zlibError(const char* prefixMsg) { + return zlibErrorPrefix(prefixMsg, stream_.msg); } z_stream stream_; bool initialized_; - int compression_level_; + int compressionLevel_; }; -// ---------------------------------------------------------------------- -// gzip codec implementation +// ----------------------------------------------------------------------. +// Gzip codec implementation. class GZipCodec : public Codec { public: - explicit GZipCodec(int compression_level, GZipFormat format, int window_bits) + explicit GZipCodec(int compressionLevel, GZipFormat format, int windowBits) : format_(format), - window_bits_(window_bits), - compressor_initialized_(false), - decompressor_initialized_(false) { - compression_level_ = compression_level == kUseDefaultCompressionLevel + windowBits_(windowBits), + compressorInitialized_(false), + decompressorInitialized_(false) { + compressionLevel_ = compressionLevel == kUseDefaultCompressionLevel ? kGZipDefaultCompressionLevel - : compression_level; + : compressionLevel; } ~GZipCodec() override { - EndCompressor(); - EndDecompressor(); + endCompressor(); + endDecompressor(); } - Result> MakeCompressor() override { - auto ptr = std::make_shared(compression_level_); - RETURN_NOT_OK(ptr->Init(format_, window_bits_)); + Result> makeCompressor() override { + auto ptr = std::make_shared(compressionLevel_); + RETURN_NOT_OK(ptr->init(format_, windowBits_)); return ptr; } - Result> MakeDecompressor() override { - auto ptr = std::make_shared(format_, window_bits_); - RETURN_NOT_OK(ptr->Init()); + Result> makeDecompressor() override { + auto ptr = std::make_shared(format_, windowBits_); + RETURN_NOT_OK(ptr->init()); return ptr; } - Status InitCompressor() { - EndDecompressor(); + Status initCompressor() { + endDecompressor(); memset(&stream_, 0, sizeof(stream_)); int ret; - // Initialize to run specified format - int window_bits = CompressionWindowBitsForFormat(format_, window_bits_); + // Initialize to run specified format. + int windowBits = compressionWindowBitsForFormat(format_, windowBits_); if ((ret = deflateInit2( &stream_, Z_DEFAULT_COMPRESSION, Z_DEFLATED, - window_bits, - compression_level_, + windowBits, + compressionLevel_, Z_DEFAULT_STRATEGY)) != Z_OK) { - return ZlibErrorPrefix("zlib deflateInit failed: ", stream_.msg); + return zlibErrorPrefix("zlib deflateInit failed: ", stream_.msg); } - compressor_initialized_ = true; + compressorInitialized_ = true; return Status::OK(); } - void EndCompressor() { - if (compressor_initialized_) { + void endCompressor() { + if (compressorInitialized_) { static_cast(deflateEnd(&stream_)); } - compressor_initialized_ = false; + compressorInitialized_ = false; } - Status InitDecompressor() { - EndCompressor(); + Status initDecompressor() { + endCompressor(); memset(&stream_, 0, sizeof(stream_)); int ret; - // Initialize to run either deflate or zlib/gzip format - int window_bits = DecompressionWindowBitsForFormat(format_, window_bits_); - if ((ret = inflateInit2(&stream_, window_bits)) != Z_OK) { - return ZlibErrorPrefix("zlib inflateInit failed: ", stream_.msg); + // Initialize to run either deflate or zlib/gzip format. + int windowBits = decompressionWindowBitsForFormat(format_, windowBits_); + if ((ret = inflateInit2(&stream_, windowBits)) != Z_OK) { + return zlibErrorPrefix("zlib inflateInit failed: ", stream_.msg); } - decompressor_initialized_ = true; + decompressorInitialized_ = true; return Status::OK(); } - void EndDecompressor() { - if (decompressor_initialized_) { + void endDecompressor() { + if (decompressorInitialized_) { static_cast(inflateEnd(&stream_)); } - decompressor_initialized_ = false; + decompressorInitialized_ = false; } - Result Decompress( - int64_t input_length, + Result decompress( + int64_t inputLength, const uint8_t* input, - int64_t output_buffer_length, + int64_t outputBufferLength, uint8_t* output) override { - if (!decompressor_initialized_) { - RETURN_NOT_OK(InitDecompressor()); + if (!decompressorInitialized_) { + RETURN_NOT_OK(initDecompressor()); } - if (output_buffer_length == 0) { - // The zlib library does not allow *output to be NULL, even when - // output_buffer_length is 0 (inflate() will return Z_STREAM_ERROR). We - // don't consider this an error, so bail early if no output is expected. - // Note that we don't signal an error if the input actually contains - // compressed data. + if (outputBufferLength == 0) { + // The zlib library does not allow *output to be NULL, even when. + // Output_buffer_length is 0 (inflate() will return Z_STREAM_ERROR). We. + // Don't consider this an error, so bail early if no output is expected. + // Note that we don't signal an error if the input actually contains. + // Compressed data. return 0; } - // Reset the stream for this block + // Reset the stream for this block. if (inflateReset(&stream_) != Z_OK) { - return ZlibErrorPrefix("zlib inflateReset failed: ", stream_.msg); + return zlibErrorPrefix("zlib inflateReset failed: ", stream_.msg); } int ret = 0; - // gzip can run in streaming mode or non-streaming mode. We only - // support the non-streaming use case where we present it the entire - // compressed input and a buffer big enough to contain the entire - // compressed output. In the case where we don't know the output, - // we just make a bigger buffer and try the non-streaming mode - // from the beginning again. + // Gzip can run in streaming mode or non-streaming mode. We only. + // Support the non-streaming use case where we present it the entire. + // Compressed input and a buffer big enough to contain the entire. + // Compressed output. In the case where we don't know the output,. + // We just make a bigger buffer and try the non-streaming mode. + // From the beginning again. while (ret != Z_STREAM_END) { stream_.next_in = const_cast(reinterpret_cast(input)); - stream_.avail_in = static_cast(input_length); + stream_.avail_in = static_cast(inputLength); stream_.next_out = reinterpret_cast(output); - stream_.avail_out = static_cast(output_buffer_length); + stream_.avail_out = static_cast(outputBufferLength); - // We know the output size. In this case, we can use Z_FINISH - // which is more efficient. + // We know the output size. In this case, we can use Z_FINISH. + // Which is more efficient. ret = inflate(&stream_, Z_FINISH); if (ret == Z_STREAM_END || ret != Z_OK) { break; } - // Failure, buffer was too small + // Failure, buffer was too small. return Status::IOError( "Too small a buffer passed to GZipCodec. InputLength=", - input_length, + inputLength, " OutputLength=", - output_buffer_length); + outputBufferLength); } - // Failure for some other reason + // Failure for some other reason. if (ret != Z_STREAM_END) { - return ZlibErrorPrefix("GZipCodec failed: ", stream_.msg); + return zlibErrorPrefix("GZipCodec failed: ", stream_.msg); } return stream_.total_out; } - int64_t MaxCompressedLen( - int64_t input_length, + int64_t maxCompressedLen( + int64_t inputLength, const uint8_t* ARROW_ARG_UNUSED(input)) override { - // Must be in compression mode - if (!compressor_initialized_) { - Status s = InitCompressor(); + // Must be in compression mode. + if (!compressorInitialized_) { + Status s = initCompressor(); VELOX_DCHECK(s.ok(), s.ToString()); } - int64_t max_len = deflateBound(&stream_, static_cast(input_length)); - // ARROW-3514: return a more pessimistic estimate to account for bugs - // in old zlib versions. - return max_len + 12; + int64_t maxLen = deflateBound(&stream_, static_cast(inputLength)); + // ARROW-3514: return a more pessimistic estimate to account for bugs. + // In old zlib versions. + return maxLen + 12; } - Result Compress( - int64_t input_length, + Result compress( + int64_t inputLength, const uint8_t* input, - int64_t output_buffer_len, + int64_t outputBufferLen, uint8_t* output) override { - if (!compressor_initialized_) { - RETURN_NOT_OK(InitCompressor()); + if (!compressorInitialized_) { + RETURN_NOT_OK(initCompressor()); } stream_.next_in = const_cast(reinterpret_cast(input)); - stream_.avail_in = static_cast(input_length); + stream_.avail_in = static_cast(inputLength); stream_.next_out = reinterpret_cast(output); - stream_.avail_out = static_cast(output_buffer_len); + stream_.avail_out = static_cast(outputBufferLen); int64_t ret = 0; if ((ret = deflate(&stream_, Z_FINISH)) != Z_STREAM_END) { if (ret == Z_OK) { - // Will return Z_OK (and stream.msg NOT set) if stream.avail_out is too - // small + // Will return Z_OK (and stream.msg NOT set) if stream.avail_out is too. + // Small. return Status::IOError("zlib deflate failed, output buffer too small"); } - return ZlibErrorPrefix("zlib deflate failed: ", stream_.msg); + return zlibErrorPrefix("zlib deflate failed: ", stream_.msg); } if (deflateReset(&stream_) != Z_OK) { - return ZlibErrorPrefix("zlib deflateReset failed: ", stream_.msg); + return zlibErrorPrefix("zlib deflateReset failed: ", stream_.msg); } - // Actual output length - return output_buffer_len - stream_.avail_out; + // Actual output length. + return outputBufferLen - stream_.avail_out; } - Status Init() override { - if (window_bits_ < kGZipMinWindowBits || - window_bits_ > kGZipMaxWindowBits) { + Status init() override { + if (windowBits_ < kGZipMinWindowBits || windowBits_ > kGZipMaxWindowBits) { return Status::Invalid( "GZip window_bits should be between ", kGZipMinWindowBits, " and ", kGZipMaxWindowBits); } - const Status init_compressor_status = InitCompressor(); - if (!init_compressor_status.ok()) { - return init_compressor_status; + const Status initCompressorStatus = initCompressor(); + if (!initCompressorStatus.ok()) { + return initCompressorStatus; } - return InitDecompressor(); + return initDecompressor(); } - Compression::type compression_type() const override { + Compression::type compressionType() const override { return Compression::GZIP; } - int compression_level() const override { - return compression_level_; + int compressionLevel() const override { + return compressionLevel_; } - int minimum_compression_level() const override { + int minimumCompressionLevel() const override { return kGZipMinCompressionLevel; } - int maximum_compression_level() const override { + int maximumCompressionLevel() const override { return kGZipMaxCompressionLevel; } - int default_compression_level() const override { + int defaultCompressionLevel() const override { return kGZipDefaultCompressionLevel; } private: - // zlib is stateful and the z_stream state variable must be initialized - // before + // Zlib is stateful and the z_stream state variable must be initialized. + // Before. z_stream stream_; - // Realistically, this will always be GZIP, but we leave the option open to - // configure + // Realistically, this will always be GZIP, but we leave the option open to. + // Configure. GZipFormat format_; - // These variables are mutually exclusive. When the codec is in "compressor" - // state, compressor_initialized_ is true while decompressor_initialized_ is - // false. When it's decompressing, the opposite is true. + // These variables are mutually exclusive. When the codec is in "Compressor". + // State, compressor_initialized_ is true while decompressor_initialized_ is. + // False. When it's decompressing, the opposite is true. // - // Indeed, this is slightly hacky, but the alternative is having separate - // Compressor and Decompressor classes. If this ever becomes an issue, we can - // perform the refactoring then - int window_bits_; - bool compressor_initialized_; - bool decompressor_initialized_; - int compression_level_; + // Indeed, this is slightly hacky, but the alternative is having separate. + // Compressor and Decompressor classes. If this ever becomes an issue, we can. + // Perform the refactoring then. + int windowBits_; + bool compressorInitialized_; + bool decompressorInitialized_; + int compressionLevel_; }; } // namespace -std::unique_ptr MakeGZipCodec( - int compression_level, +std::unique_ptr makeGZipCodec( + int compressionLevel, GZipFormat format, - std::optional window_bits) { + std::optional windowBits) { return std::make_unique( - compression_level, format, window_bits.value_or(kGZipDefaultWindowBits)); + compressionLevel, format, windowBits.value_or(kGZipDefaultWindowBits)); } } // namespace facebook::velox::parquet::arrow::util::internal diff --git a/velox/dwio/parquet/writer/arrow/util/CompressionZstd.cpp b/velox/dwio/parquet/writer/arrow/util/CompressionZstd.cpp index 22408c77210..71fc0f7bad8 100644 --- a/velox/dwio/parquet/writer/arrow/util/CompressionZstd.cpp +++ b/velox/dwio/parquet/writer/arrow/util/CompressionZstd.cpp @@ -35,12 +35,12 @@ using std::size_t; namespace facebook::velox::parquet::arrow::util::internal { namespace { -Status ZSTDError(size_t ret, const char* prefix_msg) { - return Status::IOError(prefix_msg, ZSTD_getErrorName(ret)); +Status zSTDError(size_t ret, const char* prefixMsg) { + return Status::IOError(prefixMsg, ZSTD_getErrorName(ret)); } -// ---------------------------------------------------------------------- -// ZSTD decompressor implementation +// ----------------------------------------------------------------------. +// ZSTD decompressor implementation. class ZSTDDecompressor : public Decompressor { public: @@ -50,48 +50,48 @@ class ZSTDDecompressor : public Decompressor { ZSTD_freeDStream(stream_); } - Status Init() { + Status init() { finished_ = false; size_t ret = ZSTD_initDStream(stream_); if (ZSTD_isError(ret)) { - return ZSTDError(ret, "ZSTD init failed: "); + return zSTDError(ret, "ZSTD init failed: "); } else { return Status::OK(); } } - Result Decompress( - int64_t input_len, + Result decompress( + int64_t inputLen, const uint8_t* input, - int64_t output_len, + int64_t outputLen, uint8_t* output) override { - ZSTD_inBuffer in_buf; - ZSTD_outBuffer out_buf; + ZSTD_inBuffer inBuf; + ZSTD_outBuffer outBuf; - in_buf.src = input; - in_buf.size = static_cast(input_len); - in_buf.pos = 0; - out_buf.dst = output; - out_buf.size = static_cast(output_len); - out_buf.pos = 0; + inBuf.src = input; + inBuf.size = static_cast(inputLen); + inBuf.pos = 0; + outBuf.dst = output; + outBuf.size = static_cast(outputLen); + outBuf.pos = 0; size_t ret; - ret = ZSTD_decompressStream(stream_, &out_buf, &in_buf); + ret = ZSTD_decompressStream(stream_, &outBuf, &inBuf); if (ZSTD_isError(ret)) { - return ZSTDError(ret, "ZSTD decompress failed: "); + return zSTDError(ret, "ZSTD decompress failed: "); } finished_ = (ret == 0); return DecompressResult{ - static_cast(in_buf.pos), - static_cast(out_buf.pos), - in_buf.pos == 0 && out_buf.pos == 0}; + static_cast(inBuf.pos), + static_cast(outBuf.pos), + inBuf.pos == 0 && outBuf.pos == 0}; } - Status Reset() override { - return Init(); + Status reset() override { + return init(); } - bool IsFinished() override { + bool isFinished() override { return finished_; } @@ -100,186 +100,186 @@ class ZSTDDecompressor : public Decompressor { bool finished_; }; -// ---------------------------------------------------------------------- -// ZSTD compressor implementation +// ----------------------------------------------------------------------. +// ZSTD compressor implementation. class ZSTDCompressor : public Compressor { public: - explicit ZSTDCompressor(int compression_level) - : stream_(ZSTD_createCStream()), compression_level_(compression_level) {} + explicit ZSTDCompressor(int compressionLevel) + : stream_(ZSTD_createCStream()), compressionLevel_(compressionLevel) {} ~ZSTDCompressor() override { ZSTD_freeCStream(stream_); } - Status Init() { - size_t ret = ZSTD_initCStream(stream_, compression_level_); + Status init() { + size_t ret = ZSTD_initCStream(stream_, compressionLevel_); if (ZSTD_isError(ret)) { - return ZSTDError(ret, "ZSTD init failed: "); + return zSTDError(ret, "ZSTD init failed: "); } else { return Status::OK(); } } - Result Compress( - int64_t input_len, + Result compress( + int64_t inputLen, const uint8_t* input, - int64_t output_len, + int64_t outputLen, uint8_t* output) override { - ZSTD_inBuffer in_buf; - ZSTD_outBuffer out_buf; + ZSTD_inBuffer inBuf; + ZSTD_outBuffer outBuf; - in_buf.src = input; - in_buf.size = static_cast(input_len); - in_buf.pos = 0; - out_buf.dst = output; - out_buf.size = static_cast(output_len); - out_buf.pos = 0; + inBuf.src = input; + inBuf.size = static_cast(inputLen); + inBuf.pos = 0; + outBuf.dst = output; + outBuf.size = static_cast(outputLen); + outBuf.pos = 0; size_t ret; - ret = ZSTD_compressStream(stream_, &out_buf, &in_buf); + ret = ZSTD_compressStream(stream_, &outBuf, &inBuf); if (ZSTD_isError(ret)) { - return ZSTDError(ret, "ZSTD compress failed: "); + return zSTDError(ret, "ZSTD compress failed: "); } return CompressResult{ - static_cast(in_buf.pos), static_cast(out_buf.pos)}; + static_cast(inBuf.pos), static_cast(outBuf.pos)}; } - Result Flush(int64_t output_len, uint8_t* output) override { - ZSTD_outBuffer out_buf; + Result flush(int64_t outputLen, uint8_t* output) override { + ZSTD_outBuffer outBuf; - out_buf.dst = output; - out_buf.size = static_cast(output_len); - out_buf.pos = 0; + outBuf.dst = output; + outBuf.size = static_cast(outputLen); + outBuf.pos = 0; size_t ret; - ret = ZSTD_flushStream(stream_, &out_buf); + ret = ZSTD_flushStream(stream_, &outBuf); if (ZSTD_isError(ret)) { - return ZSTDError(ret, "ZSTD flush failed: "); + return zSTDError(ret, "ZSTD flush failed: "); } - return FlushResult{static_cast(out_buf.pos), ret > 0}; + return FlushResult{static_cast(outBuf.pos), ret > 0}; } - Result End(int64_t output_len, uint8_t* output) override { - ZSTD_outBuffer out_buf; + Result end(int64_t outputLen, uint8_t* output) override { + ZSTD_outBuffer outBuf; - out_buf.dst = output; - out_buf.size = static_cast(output_len); - out_buf.pos = 0; + outBuf.dst = output; + outBuf.size = static_cast(outputLen); + outBuf.pos = 0; size_t ret; - ret = ZSTD_endStream(stream_, &out_buf); + ret = ZSTD_endStream(stream_, &outBuf); if (ZSTD_isError(ret)) { - return ZSTDError(ret, "ZSTD end failed: "); + return zSTDError(ret, "ZSTD end failed: "); } - return EndResult{static_cast(out_buf.pos), ret > 0}; + return EndResult{static_cast(outBuf.pos), ret > 0}; } protected: ZSTD_CStream* stream_; private: - int compression_level_; + int compressionLevel_; }; -// ---------------------------------------------------------------------- -// ZSTD codec implementation +// ----------------------------------------------------------------------. +// ZSTD codec implementation. class ZSTDCodec : public Codec { public: - explicit ZSTDCodec(int compression_level) - : compression_level_( - compression_level == kUseDefaultCompressionLevel + explicit ZSTDCodec(int compressionLevel) + : compressionLevel_( + compressionLevel == kUseDefaultCompressionLevel ? kZSTDDefaultCompressionLevel - : compression_level) {} + : compressionLevel) {} - Result Decompress( - int64_t input_len, + Result decompress( + int64_t inputLen, const uint8_t* input, - int64_t output_buffer_len, - uint8_t* output_buffer) override { - if (output_buffer == nullptr) { - // We may pass a NULL 0-byte output buffer but some zstd versions demand - // a valid pointer: https://github.com/facebook/zstd/issues/1385 - static uint8_t empty_buffer; - VELOX_DCHECK_EQ(output_buffer_len, 0); - output_buffer = &empty_buffer; + int64_t outputBufferLen, + uint8_t* outputBuffer) override { + if (outputBuffer == nullptr) { + // We may pass a NULL 0-byte output buffer but some zstd versions demand a + // valid pointer: https://github.com/facebook/zstd/issues/1385. + static uint8_t emptyBuffer; + VELOX_DCHECK_EQ(outputBufferLen, 0); + outputBuffer = &emptyBuffer; } size_t ret = ZSTD_decompress( - output_buffer, - static_cast(output_buffer_len), + outputBuffer, + static_cast(outputBufferLen), input, - static_cast(input_len)); + static_cast(inputLen)); if (ZSTD_isError(ret)) { - return ZSTDError(ret, "ZSTD decompression failed: "); + return zSTDError(ret, "ZSTD decompression failed: "); } - if (static_cast(ret) != output_buffer_len) { + if (static_cast(ret) != outputBufferLen) { return Status::IOError("Corrupt ZSTD compressed data."); } return static_cast(ret); } - int64_t MaxCompressedLen( - int64_t input_len, + int64_t maxCompressedLen( + int64_t inputLen, const uint8_t* ARROW_ARG_UNUSED(input)) override { - VELOX_DCHECK_GE(input_len, 0); - return ZSTD_compressBound(static_cast(input_len)); + VELOX_DCHECK_GE(inputLen, 0); + return ZSTD_compressBound(static_cast(inputLen)); } - Result Compress( - int64_t input_len, + Result compress( + int64_t inputLen, const uint8_t* input, - int64_t output_buffer_len, - uint8_t* output_buffer) override { + int64_t outputBufferLen, + uint8_t* outputBuffer) override { size_t ret = ZSTD_compress( - output_buffer, - static_cast(output_buffer_len), + outputBuffer, + static_cast(outputBufferLen), input, - static_cast(input_len), - compression_level_); + static_cast(inputLen), + compressionLevel_); if (ZSTD_isError(ret)) { - return ZSTDError(ret, "ZSTD compression failed: "); + return zSTDError(ret, "ZSTD compression failed: "); } return static_cast(ret); } - Result> MakeCompressor() override { - auto ptr = std::make_shared(compression_level_); - RETURN_NOT_OK(ptr->Init()); + Result> makeCompressor() override { + auto ptr = std::make_shared(compressionLevel_); + RETURN_NOT_OK(ptr->init()); return ptr; } - Result> MakeDecompressor() override { + Result> makeDecompressor() override { auto ptr = std::make_shared(); - RETURN_NOT_OK(ptr->Init()); + RETURN_NOT_OK(ptr->init()); return ptr; } - Compression::type compression_type() const override { + Compression::type compressionType() const override { return Compression::ZSTD; } - int minimum_compression_level() const override { + int minimumCompressionLevel() const override { return ZSTD_minCLevel(); } - int maximum_compression_level() const override { + int maximumCompressionLevel() const override { return ZSTD_maxCLevel(); } - int default_compression_level() const override { + int defaultCompressionLevel() const override { return kZSTDDefaultCompressionLevel; } - int compression_level() const override { - return compression_level_; + int compressionLevel() const override { + return compressionLevel_; } private: - const int compression_level_; + const int compressionLevel_; }; } // namespace -std::unique_ptr MakeZSTDCodec(int compression_level) { - return std::make_unique(compression_level); +std::unique_ptr makeZSTDCodec(int compressionLevel) { + return std::make_unique(compressionLevel); } -} // namespace facebook::velox::parquet::arrow::util::internal. +} // namespace facebook::velox::parquet::arrow::util::internal diff --git a/velox/dwio/parquet/writer/arrow/util/Crc32.cpp b/velox/dwio/parquet/writer/arrow/util/Crc32.cpp index 9b00633f32f..1e37d763aa1 100644 --- a/velox/dwio/parquet/writer/arrow/util/Crc32.cpp +++ b/velox/dwio/parquet/writer/arrow/util/Crc32.cpp @@ -97,7 +97,7 @@ namespace facebook::velox::parquet::arrow::internal { #define ALIGNOF_UINT32_T alignof(uint32_t) -static const uint32_t crc32_lookup[16][256] = { +static const uint32_t crc32Lookup[16][256] = { /* same algorithm as crc32_bitwise for (int i = 0; i <= 0xFF; i++) @@ -861,7 +861,7 @@ static const uint32_t crc32_lookup[16][256] = { uint32_t crc32(uint32_t prev, const void* data, size_t length) { uint32_t crc = ~prev; unsigned unaligned; - const uint8_t* current_char; + const uint8_t* currentChar; const uint32_t* current; unaligned = @@ -870,20 +870,20 @@ uint32_t crc32(uint32_t prev, const void* data, size_t length) { unaligned = 0; /* process a byte at a time until we hit an alignment boundary (max 3) */ - current_char = reinterpret_cast(data); + currentChar = reinterpret_cast(data); for (; unaligned && length; unaligned--, length--) - crc = (crc >> 8) ^ crc32_lookup[0][(crc & 0xFF) ^ *current_char++]; + crc = (crc >> 8) ^ crc32Lookup[0][(crc & 0xFF) ^ *currentChar++]; - current = reinterpret_cast(current_char); + current = reinterpret_cast(currentChar); /* process 64 bytes at once (Slicing-by-16) */ /* enabling optimization (at least -O2) automatically unrolls the inner * for-loop */ const size_t unroll = 4; - const size_t bytes_at_once = 16 * unroll; + const size_t bytesAtOnce = 16 * unroll; - while (length >= bytes_at_once) { + while (length >= bytesAtOnce) { size_t unrolling; for (unrolling = 0; unrolling < unroll; unrolling++) { #if ARROW_LITTLE_ENDIAN @@ -891,39 +891,39 @@ uint32_t crc32(uint32_t prev, const void* data, size_t length) { uint32_t two = *current++; uint32_t three = *current++; uint32_t four = *current++; - crc = crc32_lookup[0][(four >> 24) & 0xFF] ^ - crc32_lookup[1][(four >> 16) & 0xFF] ^ - crc32_lookup[2][(four >> 8) & 0xFF] ^ crc32_lookup[3][four & 0xFF] ^ - crc32_lookup[4][(three >> 24) & 0xFF] ^ - crc32_lookup[5][(three >> 16) & 0xFF] ^ - crc32_lookup[6][(three >> 8) & 0xFF] ^ crc32_lookup[7][three & 0xFF] ^ - crc32_lookup[8][(two >> 24) & 0xFF] ^ - crc32_lookup[9][(two >> 16) & 0xFF] ^ - crc32_lookup[10][(two >> 8) & 0xFF] ^ crc32_lookup[11][two & 0xFF] ^ - crc32_lookup[12][(one >> 24) & 0xFF] ^ - crc32_lookup[13][(one >> 16) & 0xFF] ^ - crc32_lookup[14][(one >> 8) & 0xFF] ^ crc32_lookup[15][one & 0xFF]; + crc = crc32Lookup[0][(four >> 24) & 0xFF] ^ + crc32Lookup[1][(four >> 16) & 0xFF] ^ + crc32Lookup[2][(four >> 8) & 0xFF] ^ crc32Lookup[3][four & 0xFF] ^ + crc32Lookup[4][(three >> 24) & 0xFF] ^ + crc32Lookup[5][(three >> 16) & 0xFF] ^ + crc32Lookup[6][(three >> 8) & 0xFF] ^ crc32Lookup[7][three & 0xFF] ^ + crc32Lookup[8][(two >> 24) & 0xFF] ^ + crc32Lookup[9][(two >> 16) & 0xFF] ^ + crc32Lookup[10][(two >> 8) & 0xFF] ^ crc32Lookup[11][two & 0xFF] ^ + crc32Lookup[12][(one >> 24) & 0xFF] ^ + crc32Lookup[13][(one >> 16) & 0xFF] ^ + crc32Lookup[14][(one >> 8) & 0xFF] ^ crc32Lookup[15][one & 0xFF]; #else uint32_t one = *current++ ^ ::arrow::bit_util::ByteSwap(crc); uint32_t two = *current++; uint32_t three = *current++; uint32_t four = *current++; - crc = crc32_lookup[0][four & 0xFF] ^ crc32_lookup[1][(four >> 8) & 0xFF] ^ - crc32_lookup[2][(four >> 16) & 0xFF] ^ - crc32_lookup[3][(four >> 24) & 0xFF] ^ crc32_lookup[4][three & 0xFF] ^ - crc32_lookup[5][(three >> 8) & 0xFF] ^ - crc32_lookup[6][(three >> 16) & 0xFF] ^ - crc32_lookup[7][(three >> 24) & 0xFF] ^ crc32_lookup[8][two & 0xFF] ^ - crc32_lookup[9][(two >> 8) & 0xFF] ^ - crc32_lookup[10][(two >> 16) & 0xFF] ^ - crc32_lookup[11][(two >> 24) & 0xFF] ^ crc32_lookup[12][one & 0xFF] ^ - crc32_lookup[13][(one >> 8) & 0xFF] ^ - crc32_lookup[14][(one >> 16) & 0xFF] ^ - crc32_lookup[15][(one >> 24) & 0xFF]; + crc = crc32Lookup[0][four & 0xFF] ^ crc32Lookup[1][(four >> 8) & 0xFF] ^ + crc32Lookup[2][(four >> 16) & 0xFF] ^ + crc32Lookup[3][(four >> 24) & 0xFF] ^ crc32Lookup[4][three & 0xFF] ^ + crc32Lookup[5][(three >> 8) & 0xFF] ^ + crc32Lookup[6][(three >> 16) & 0xFF] ^ + crc32Lookup[7][(three >> 24) & 0xFF] ^ crc32Lookup[8][two & 0xFF] ^ + crc32Lookup[9][(two >> 8) & 0xFF] ^ + crc32Lookup[10][(two >> 16) & 0xFF] ^ + crc32Lookup[11][(two >> 24) & 0xFF] ^ crc32Lookup[12][one & 0xFF] ^ + crc32Lookup[13][(one >> 8) & 0xFF] ^ + crc32Lookup[14][(one >> 16) & 0xFF] ^ + crc32Lookup[15][(one >> 24) & 0xFF]; #endif } - length -= bytes_at_once; + length -= bytesAtOnce; } /* process eight bytes at once (Slicing-by-8) */ @@ -932,21 +932,19 @@ uint32_t crc32(uint32_t prev, const void* data, size_t length) { #if ARROW_LITTLE_ENDIAN uint32_t one = *current++ ^ crc; uint32_t two = *current++; - crc = crc32_lookup[0][(two >> 24) & 0xFF] ^ - crc32_lookup[1][(two >> 16) & 0xFF] ^ - crc32_lookup[2][(two >> 8) & 0xFF] ^ crc32_lookup[3][two & 0xFF] ^ - crc32_lookup[4][(one >> 24) & 0xFF] ^ - crc32_lookup[5][(one >> 16) & 0xFF] ^ - crc32_lookup[6][(one >> 8) & 0xFF] ^ crc32_lookup[7][one & 0xFF]; + crc = crc32Lookup[0][(two >> 24) & 0xFF] ^ + crc32Lookup[1][(two >> 16) & 0xFF] ^ crc32Lookup[2][(two >> 8) & 0xFF] ^ + crc32Lookup[3][two & 0xFF] ^ crc32Lookup[4][(one >> 24) & 0xFF] ^ + crc32Lookup[5][(one >> 16) & 0xFF] ^ crc32Lookup[6][(one >> 8) & 0xFF] ^ + crc32Lookup[7][one & 0xFF]; #else uint32_t one = *current++ ^ ::arrow::bit_util::ByteSwap(crc); uint32_t two = *current++; - crc = crc32_lookup[0][two & 0xFF] ^ crc32_lookup[1][(two >> 8) & 0xFF] ^ - crc32_lookup[2][(two >> 16) & 0xFF] ^ - crc32_lookup[3][(two >> 24) & 0xFF] ^ crc32_lookup[4][one & 0xFF] ^ - crc32_lookup[5][(one >> 8) & 0xFF] ^ - crc32_lookup[6][(one >> 16) & 0xFF] ^ - crc32_lookup[7][(one >> 24) & 0xFF]; + crc = crc32Lookup[0][two & 0xFF] ^ crc32Lookup[1][(two >> 8) & 0xFF] ^ + crc32Lookup[2][(two >> 16) & 0xFF] ^ + crc32Lookup[3][(two >> 24) & 0xFF] ^ crc32Lookup[4][one & 0xFF] ^ + crc32Lookup[5][(one >> 8) & 0xFF] ^ crc32Lookup[6][(one >> 16) & 0xFF] ^ + crc32Lookup[7][(one >> 24) & 0xFF]; #endif length -= 8; @@ -955,14 +953,13 @@ uint32_t crc32(uint32_t prev, const void* data, size_t length) { if (length >= 4) { #if ARROW_LITTLE_ENDIAN uint32_t one = *current++ ^ crc; - crc = crc32_lookup[0][(one >> 24) & 0xFF] ^ - crc32_lookup[1][(one >> 16) & 0xFF] ^ - crc32_lookup[2][(one >> 8) & 0xFF] ^ crc32_lookup[3][one & 0xFF]; + crc = crc32Lookup[0][(one >> 24) & 0xFF] ^ + crc32Lookup[1][(one >> 16) & 0xFF] ^ crc32Lookup[2][(one >> 8) & 0xFF] ^ + crc32Lookup[3][one & 0xFF]; #else uint32_t one = *current++ ^ ::arrow::bit_util::ByteSwap(crc); - crc = crc32_lookup[0][one & 0xFF] ^ crc32_lookup[1][(one >> 8) & 0xFF] ^ - crc32_lookup[2][(one >> 16) & 0xFF] ^ - crc32_lookup[3][(one >> 24) & 0xFF]; + crc = crc32Lookup[0][one & 0xFF] ^ crc32Lookup[1][(one >> 8) & 0xFF] ^ + crc32Lookup[2][(one >> 16) & 0xFF] ^ crc32Lookup[3][(one >> 24) & 0xFF]; #endif length -= 4; @@ -970,10 +967,10 @@ uint32_t crc32(uint32_t prev, const void* data, size_t length) { /* Finish with any remaining bytes one by one */ - current_char = reinterpret_cast(current); + currentChar = reinterpret_cast(current); /* remaining 1 to 3 bytes (standard algorithm) */ while (length-- != 0) - crc = (crc >> 8) ^ crc32_lookup[0][(crc & 0xFF) ^ *current_char++]; + crc = (crc >> 8) ^ crc32Lookup[0][(crc & 0xFF) ^ *currentChar++]; return ~crc; } diff --git a/velox/dwio/parquet/writer/arrow/util/Crc32.h b/velox/dwio/parquet/writer/arrow/util/Crc32.h index afb2045f979..6fd2f9fe2c5 100644 --- a/velox/dwio/parquet/writer/arrow/util/Crc32.h +++ b/velox/dwio/parquet/writer/arrow/util/Crc32.h @@ -23,7 +23,7 @@ namespace facebook::velox::parquet::arrow::internal { -/// \brief Compute the CRC32 checksum of the given data +/// \brief Compute the CRC32 checksum of the given data. /// /// This function computes CRC32 with the polynomial 0x04C11DB7, /// as used in zlib and others (note this is different from CRC32C). diff --git a/velox/dwio/parquet/writer/arrow/util/Hashing.cpp b/velox/dwio/parquet/writer/arrow/util/Hashing.cpp index eabab4b54c4..76c1508ad6d 100644 --- a/velox/dwio/parquet/writer/arrow/util/Hashing.cpp +++ b/velox/dwio/parquet/writer/arrow/util/Hashing.cpp @@ -24,32 +24,32 @@ namespace facebook::velox::parquet::arrow::internal { namespace { -/// \brief A hash function for bitmaps that can handle offsets and lengths in -/// terms of number of bits. The hash only depends on the bits actually hashed. +/// \brief A hash function for bitmaps that can handle offsets and lengths in. +/// Terms of number of bits. The hash only depends on the bits actually hashed. /// -/// This implementation is based on 64-bit versions of MurmurHash2 by Austin +/// This implementation is based on 64-bit versions of MurmurHash2 by Austin. /// Appleby. /// -/// It's the caller's responsibility to ensure that bits_offset + num_bits are -/// readable from the bitmap. +/// It's the caller's responsibility to ensure that bits_offset + num_bits are. +/// Readable from the bitmap. /// /// \param key The pointer to the bitmap. -/// \param seed The seed for the hash function (useful when chaining hash -/// functions). \param bits_offset The offset in bits relative to the start of -/// the bitmap. \param num_bits The number of bits after the offset to be -/// hashed. -uint64_t MurmurHashBitmap64( +/// \param seed The seed for the hash function (useful when chaining hash. +/// Functions). \param bits_offset The offset in bits relative to the start of. +/// The bitmap. \param num_bits The number of bits after the offset to be. +/// Hashed. +uint64_t murmurHashBitmap64( const uint8_t* key, uint64_t seed, - uint64_t bits_offset, - uint64_t num_bits) { + uint64_t bitsOffset, + uint64_t numBits) { const uint64_t m = 0xc6a4a7935bd1e995LLU; const int r = 47; - uint64_t h = seed ^ (num_bits * m); + uint64_t h = seed ^ (numBits * m); ::arrow::internal::BitmapWordReader reader( - key, bits_offset, num_bits); + key, bitsOffset, numBits); auto nwords = reader.words(); while (nwords--) { auto k = reader.NextWord(); @@ -60,12 +60,12 @@ uint64_t MurmurHashBitmap64( h ^= k; h *= m; } - int valid_bits; + int validBits; auto nbytes = reader.trailing_bytes(); if (nbytes) { uint64_t k = 0; do { - auto byte = reader.NextTrailingByte(valid_bits); + auto byte = reader.NextTrailingByte(validBits); k = (k << 8) | static_cast(byte); } while (--nbytes); h ^= k; @@ -80,14 +80,14 @@ uint64_t MurmurHashBitmap64( } // namespace -hash_t ComputeBitmapHash( +hash_t computeBitmapHash( const uint8_t* bitmap, hash_t seed, - int64_t bits_offset, - int64_t num_bits) { - VELOX_DCHECK_GE(bits_offset, 0); - VELOX_DCHECK_GE(num_bits, 0); - return MurmurHashBitmap64(bitmap, seed, bits_offset, num_bits); + int64_t bitsOffset, + int64_t numBits) { + VELOX_DCHECK_GE(bitsOffset, 0); + VELOX_DCHECK_GE(numBits, 0); + return murmurHashBitmap64(bitmap, seed, bitsOffset, numBits); } } // namespace facebook::velox::parquet::arrow::internal diff --git a/velox/dwio/parquet/writer/arrow/util/Hashing.h b/velox/dwio/parquet/writer/arrow/util/Hashing.h index cadca7c6890..b183df953cf 100644 --- a/velox/dwio/parquet/writer/arrow/util/Hashing.h +++ b/velox/dwio/parquet/writer/arrow/util/Hashing.h @@ -16,7 +16,7 @@ // Adapted from Apache Arrow. -// Private header, not to be exported +// Private header, not to be exported. #pragma once @@ -46,9 +46,7 @@ #include "velox/common/base/Exceptions.h" -#define XXH_INLINE_ALL - -#include +#include "velox/common/base/XxHashInline.h" namespace facebook::velox::parquet::arrow::internal { @@ -62,12 +60,12 @@ typedef uint64_t hash_t; // Notes about the choice of a hash function. // - XXH3 is extremely fast on most data sizes, from small to huge; -// faster even than HW CRC-based hashing schemes -// - our custom hash function for tiny values (< 16 bytes) is still -// significantly faster (~30%), at least on this machine and compiler +// Faster even than HW CRC-based hashing schemes. +// - Our custom hash function for tiny values (< 16 bytes) is still. +// Significantly faster (~30%), at least on this machine and compiler. template -inline hash_t ComputeStringHash(const void* data, int64_t length); +inline hash_t computeStringHash(const void* data, int64_t length); /// \brief A hash function for bitmaps that can handle offsets and lengths in /// terms of number of bits. The hash only depends on the bits actually hashed. @@ -75,35 +73,35 @@ inline hash_t ComputeStringHash(const void* data, int64_t length); /// It's the caller's responsibility to ensure that bits_offset + num_bits are /// readable from the bitmap. /// -/// \pre bits_offset >= 0 -/// \pre num_bits >= 0 +/// \pre bits_offset >= 0. +/// \pre num_bits >= 0. /// \pre (bits_offset + num_bits + 7) / 8 <= readable length in bytes from -/// bitmap +/// bitmap. /// /// \param bitmap The pointer to the bitmap. -/// \param seed The seed for the hash function (useful when chaining hash -/// functions). \param bits_offset The offset in bits relative to the start of -/// the bitmap. \param num_bits The number of bits after the offset to be -/// hashed. -ARROW_EXPORT hash_t ComputeBitmapHash( +/// \param seed The seed for the hash function (useful when chaining hash. +/// Functions). \param bits_offset The offset in bits relative to the start of. +/// The bitmap. \param num_bits The number of bits after the offset to be. +/// Hashed. +ARROW_EXPORT hash_t computeBitmapHash( const uint8_t* bitmap, hash_t seed, - int64_t bits_offset, - int64_t num_bits); + int64_t bitsOffset, + int64_t numBits); template struct ScalarHelperBase { - static bool CompareScalars(Scalar u, Scalar v) { + static bool compareScalars(Scalar u, Scalar v) { return u == v; } - static hash_t ComputeHash(const Scalar& value) { - // Generic hash computation for scalars. Simply apply the string hash - // to the bit representation of the value. + static hash_t computeHash(const Scalar& value) { + // Generic hash computation for scalars. Simply apply the string hash. + // To the bit representation of the value. // XXX in the case of FP values, we'd like equal values to have the same // hash, even if they have different bit representations... - return ComputeStringHash(&value, sizeof(value)); + return computeStringHash(&value, sizeof(value)); } }; @@ -116,20 +114,20 @@ struct ScalarHelper< AlgNum, enable_if_t::value>> : public ScalarHelperBase { - // ScalarHelper specialization for integers + // ScalarHelper specialization for integers. - static hash_t ComputeHash(const Scalar& value) { + static hash_t computeHash(const Scalar& value) { // Faster hash computation for integers. - // Two of xxhash's prime multipliers (which are chosen for their + // Two of xxhash's prime multipliers (which are chosen for their. // bit dispersion properties) static constexpr uint64_t multipliers[] = { 11400714785074694791ULL, 14029467366897019727ULL}; - // Multiplying by the prime number mixes the low bits into the high bits, - // then byte-swapping (which is a single CPU instruction) allows the - // combined high and low bits to participate in the initial hash table - // index. + // Multiplying by the prime number mixes the low bits into the high bits,. + // Then byte-swapping (which is a single CPU instruction) allows the. + // Combined high and low bits to participate in the initial hash table. + // Index. auto h = static_cast(value); return ::arrow::bit_util::ByteSwap(multipliers[AlgNum] * h); } @@ -141,10 +139,10 @@ struct ScalarHelper< AlgNum, enable_if_t::value>> : public ScalarHelperBase { - // ScalarHelper specialization for std::string_view + // ScalarHelper specialization for std::string_view. - static hash_t ComputeHash(std::string_view value) { - return ComputeStringHash( + static hash_t computeHash(std::string_view value) { + return computeStringHash( value.data(), static_cast(value.size())); } }; @@ -155,9 +153,9 @@ struct ScalarHelper< AlgNum, enable_if_t::value>> : public ScalarHelperBase { - // ScalarHelper specialization for reals + // ScalarHelper specialization for reals. - static bool CompareScalars(Scalar u, Scalar v) { + static bool compareScalars(Scalar u, Scalar v) { if (std::isnan(u)) { // XXX should we do a bit-precise comparison? return std::isnan(v); @@ -167,10 +165,10 @@ struct ScalarHelper< }; template -hash_t ComputeStringHash(const void* data, int64_t length) { +hash_t computeStringHash(const void* data, int64_t length) { if (ARROW_PREDICT_TRUE(length <= 16)) { - // Specialize for small hash strings, as they are quite common as - // hash table keys. Even XXH3 isn't quite as fast. + // Specialize for small hash strings, as they are quite common as. + // Hash table keys. Even XXH3 isn't quite as fast. auto p = reinterpret_cast(data); auto n = static_cast(length); if (n <= 8) { @@ -179,28 +177,28 @@ hash_t ComputeStringHash(const void* data, int64_t length) { return 1U; } uint32_t x = (n << 24) ^ (p[0] << 16) ^ (p[n / 2] << 8) ^ p[n - 1]; - return ScalarHelper::ComputeHash(x); + return ScalarHelper::computeHash(x); } - // 4 <= length <= 8 - // We can read the string as two overlapping 32-bit ints, apply - // different hash functions to each of them in parallel, then XOR - // the results + // 4 <= Length <= 8. + // We can read the string as two overlapping 32-bit ints, apply. + // Different hash functions to each of them in parallel, then XOR. + // The results. uint32_t x, y; hash_t hx, hy; x = ::arrow::util::SafeLoadAs(p + n - 4); y = ::arrow::util::SafeLoadAs(p); - hx = ScalarHelper::ComputeHash(x); - hy = ScalarHelper::ComputeHash(y); + hx = ScalarHelper::computeHash(x); + hy = ScalarHelper::computeHash(y); return n ^ hx ^ hy; } - // 8 <= length <= 16 - // Apply the same principle as above + // 8 <= Length <= 16. + // Apply the same principle as above. uint64_t x, y; hash_t hx, hy; x = ::arrow::util::SafeLoadAs(p + n - 8); y = ::arrow::util::SafeLoadAs(p); - hx = ScalarHelper::ComputeHash(x); - hy = ScalarHelper::ComputeHash(y); + hx = ScalarHelper::computeHash(x); + hy = ScalarHelper::computeHash(y); return n ^ hx ^ hy; } @@ -208,9 +206,9 @@ hash_t ComputeStringHash(const void* data, int64_t length) { #error XXH3_SECRET_SIZE_MIN changed, please fix kXxh3Secrets #endif - // XXH3_64bits_withSeed generates a secret based on the seed, which is too - // slow. Instead, we use hard-coded random secrets. To maximize cache - // efficiency, they reuse the same memory area. + // XXH3_64bits_withSeed generates a secret based on the seed, which is too. + // Slow. Instead, we use hard-coded random secrets. To maximize cache. + // Efficiency, they reuse the same memory area. static constexpr unsigned char kXxh3Secrets[XXH3_SECRET_SIZE_MIN + 1] = { 0xe7, 0x8b, 0x13, 0xf9, 0xfc, 0xb5, 0x8e, 0xef, 0x81, 0x48, 0x2c, 0xbf, 0xf9, 0x9f, 0xc1, 0x1e, 0x43, 0x6d, 0xbf, 0xa6, 0x6d, 0xb5, 0x72, 0xbc, @@ -233,7 +231,7 @@ hash_t ComputeStringHash(const void* data, int64_t length) { // XXX add a HashEq struct with both hash and compare functions? -// ---------------------------------------------------------------------- +// ----------------------------------------------------------------------. // An open-addressing insert-only hash table (no deletes) template @@ -246,51 +244,51 @@ class HashTable { hash_t h; Payload payload; - // An entry is valid if the hash is different from the sentinel value + // An entry is valid if the hash is different from the sentinel value. operator bool() const { return h != kSentinel; } }; - HashTable(MemoryPool* pool, uint64_t capacity) : entries_builder_(pool) { + HashTable(MemoryPool* pool, uint64_t capacity) : entriesBuilder_(pool) { VELOX_DCHECK_NOT_NULL(pool); - // Minimum of 32 elements + // Minimum of 32 elements. capacity = std::max(capacity, 32UL); capacity_ = ::arrow::bit_util::NextPower2(capacity); - capacity_mask_ = capacity_ - 1; + capacityMask_ = capacity_ - 1; size_ = 0; - auto status = UpsizeBuffer(capacity_); + auto status = upsizeBuffer(capacity_); VELOX_DCHECK(status.ok(), status.ToString()); } - // Lookup with non-linear probing - // cmp_func should have signature bool(const Payload*). + // Lookup with non-linear probing. + // Cmp_func should have signature bool(const Payload*). // Return a (Entry*, found) pair. template - std::pair Lookup(hash_t h, CmpFunc&& cmp_func) { - auto p = Lookup( - h, entries_, capacity_mask_, std::forward(cmp_func)); + std::pair lookup(hash_t h, CmpFunc&& cmpFunc) { + auto p = lookup( + h, entries_, capacityMask_, std::forward(cmpFunc)); return {&entries_[p.first], p.second}; } template - std::pair Lookup(hash_t h, CmpFunc&& cmp_func) const { - auto p = Lookup( - h, entries_, capacity_mask_, std::forward(cmp_func)); + std::pair lookup(hash_t h, CmpFunc&& cmpFunc) const { + auto p = lookup( + h, entries_, capacityMask_, std::forward(cmpFunc)); return {&entries_[p.first], p.second}; } - Status Insert(Entry* entry, hash_t h, const Payload& payload) { - // Ensure entry is empty before inserting + Status insert(Entry* entry, hash_t h, const Payload& payload) { + // Ensure entry is empty before inserting. assert(!*entry); - entry->h = FixHash(h); + entry->h = fixHash(h); entry->payload = payload; ++size_; - if (ARROW_PREDICT_FALSE(NeedUpsizing())) { - // Resize less frequently since it is expensive - return Upsize(capacity_ * kLoadFactor * 2); + if (ARROW_PREDICT_FALSE(needUpsizing())) { + // Resize less frequently since it is expensive. + return upsize(capacity_ * kLoadFactor * 2); } return Status::OK(); } @@ -299,129 +297,129 @@ class HashTable { return size_; } - // Visit all non-empty entries in the table + // Visit all non-empty entries in the table. // The visit_func should have signature void(const Entry*) template - void VisitEntries(VisitFunc&& visit_func) const { + void visitEntries(VisitFunc&& visitFunc) const { for (uint64_t i = 0; i < capacity_; i++) { const auto& entry = entries_[i]; if (entry) { - visit_func(&entry); + visitFunc(&entry); } } } protected: - // NoCompare is for when the value is known not to exist in the table + // NoCompare is for when the value is known not to exist in the table. enum CompareKind { DoCompare, NoCompare }; - // The workhorse lookup function + // The workhorse lookup function. template - std::pair Lookup( + std::pair lookup( hash_t h, const Entry* entries, - uint64_t size_mask, - CmpFunc&& cmp_func) const { - static constexpr uint8_t perturb_shift = 5; + uint64_t sizeMask, + CmpFunc&& cmpFunc) const { + static constexpr uint8_t perturbShift = 5; uint64_t index, perturb; const Entry* entry; - h = FixHash(h); - index = h & size_mask; - perturb = (h >> perturb_shift) + 1U; + h = fixHash(h); + index = h & sizeMask; + perturb = (h >> perturbShift) + 1U; while (true) { entry = &entries[index]; - if (CompareEntry( - h, entry, std::forward(cmp_func))) { - // Found + if (compareEntry( + h, entry, std::forward(cmpFunc))) { + // Found. return {index, true}; } if (entry->h == kSentinel) { - // Empty slot + // Empty slot. return {index, false}; } // Perturbation logic inspired from CPython's set / dict object. - // The goal is that all 64 bits of the unmasked hash value eventually - // participate in the probing sequence, to minimize clustering. - index = (index + perturb) & size_mask; - perturb = (perturb >> perturb_shift) + 1U; + // The goal is that all 64 bits of the unmasked hash value eventually. + // Participate in the probing sequence, to minimize clustering. + index = (index + perturb) & sizeMask; + perturb = (perturb >> perturbShift) + 1U; } } template - bool CompareEntry(hash_t h, const Entry* entry, CmpFunc&& cmp_func) const { + bool compareEntry(hash_t h, const Entry* entry, CmpFunc&& cmpFunc) const { if (CKind == NoCompare) { return false; } else { - return entry->h == h && cmp_func(&entry->payload); + return entry->h == h && cmpFunc(&entry->payload); } } - bool NeedUpsizing() const { - // Keep the load factor <= 1/2 + bool needUpsizing() const { + // Keep the load factor <= 1/2. return size_ * kLoadFactor >= capacity_; } - Status UpsizeBuffer(uint64_t capacity) { - RETURN_NOT_OK(entries_builder_.Resize(capacity)); - entries_ = entries_builder_.mutable_data(); + Status upsizeBuffer(uint64_t capacity) { + RETURN_NOT_OK(entriesBuilder_.Resize(capacity)); + entries_ = entriesBuilder_.mutable_data(); memset(static_cast(entries_), 0, capacity * sizeof(Entry)); return Status::OK(); } - Status Upsize(uint64_t new_capacity) { - assert(new_capacity > capacity_); - uint64_t new_mask = new_capacity - 1; - assert((new_capacity & new_mask) == 0); // it's a power of two + Status upsize(uint64_t newCapacity) { + assert(newCapacity > capacity_); + uint64_t newMask = newCapacity - 1; + assert((newCapacity & newMask) == 0); // it's a power of two - // Stash old entries and seal builder, effectively resetting the Buffer - const Entry* old_entries = entries_; + // Stash old entries and seal builder, effectively resetting the Buffer. + const Entry* oldEntries = entries_; ARROW_ASSIGN_OR_RAISE( - auto previous, entries_builder_.FinishWithLength(capacity_)); - // Allocate new buffer - RETURN_NOT_OK(UpsizeBuffer(new_capacity)); + auto previous, entriesBuilder_.FinishWithLength(capacity_)); + // Allocate new buffer. + RETURN_NOT_OK(upsizeBuffer(newCapacity)); for (uint64_t i = 0; i < capacity_; i++) { - const auto& entry = old_entries[i]; + const auto& entry = oldEntries[i]; if (entry) { - // Dummy compare function will not be called - auto p = Lookup( - entry.h, entries_, new_mask, [](const Payload*) { return false; }); - // Lookup (and CompareEntry) ensure that an - // empty slots is always returned + // Dummy compare function will not be called. + auto p = lookup( + entry.h, entries_, newMask, [](const Payload*) { return false; }); + // Lookup (and CompareEntry) ensure that an. + // Empty slots is always returned. assert(!p.second); entries_[p.first] = entry; } } - capacity_ = new_capacity; - capacity_mask_ = new_mask; + capacity_ = newCapacity; + capacityMask_ = newMask; return Status::OK(); } - hash_t FixHash(hash_t h) const { + hash_t fixHash(hash_t h) const { return (h == kSentinel) ? 42U : h; } // The number of slots available in the hash table array. uint64_t capacity_; - uint64_t capacity_mask_; + uint64_t capacityMask_; // The number of used slots in the hash table array. uint64_t size_; Entry* entries_; - ::arrow::TypedBufferBuilder entries_builder_; + ::arrow::TypedBufferBuilder entriesBuilder_; }; // XXX typedef memo_index_t int32_t ? constexpr int32_t kKeyNotFound = -1; -// ---------------------------------------------------------------------- +// ----------------------------------------------------------------------. // A base class for memoization table. class MemoTable { @@ -431,11 +429,11 @@ class MemoTable { virtual int32_t size() const = 0; }; -// ---------------------------------------------------------------------- +// ----------------------------------------------------------------------. // A memoization table for memory-cheap scalar values. -// The memoization table remembers and allows to look up the insertion -// index for each key. +// The memoization table remembers and allows to look up the insertion. +// Index for each key. template < typename Scalar, @@ -443,131 +441,130 @@ template < class ScalarMemoTable : public MemoTable { public: explicit ScalarMemoTable(MemoryPool* pool, int64_t entries = 0) - : hash_table_(pool, static_cast(entries)) {} + : hashTable_(pool, static_cast(entries)) {} - int32_t Get(const Scalar& value) const { - auto cmp_func = [value](const Payload* payload) -> bool { - return ScalarHelper::CompareScalars(payload->value, value); + int32_t get(const Scalar& value) const { + auto cmpFunc = [value](const Payload* payload) -> bool { + return ScalarHelper::compareScalars(payload->value, value); }; - hash_t h = ComputeHash(value); - auto p = hash_table_.Lookup(h, cmp_func); + hash_t h = computeHash(value); + auto p = hashTable_.lookup(h, cmpFunc); if (p.second) { - return p.first->payload.memo_index; + return p.first->payload.memoIndex; } else { return kKeyNotFound; } } template - Status GetOrInsert( + Status getOrInsert( const Scalar& value, - Func1&& on_found, - Func2&& on_not_found, - int32_t* out_memo_index) { - auto cmp_func = [value](const Payload* payload) -> bool { - return ScalarHelper::CompareScalars(value, payload->value); + Func1&& onFound, + Func2&& onNotFound, + int32_t* outMemoIndex) { + auto cmpFunc = [value](const Payload* payload) -> bool { + return ScalarHelper::compareScalars(value, payload->value); }; - hash_t h = ComputeHash(value); - auto p = hash_table_.Lookup(h, cmp_func); - int32_t memo_index; + hash_t h = computeHash(value); + auto p = hashTable_.lookup(h, cmpFunc); + int32_t memoIndex; if (p.second) { - memo_index = p.first->payload.memo_index; - on_found(memo_index); + memoIndex = p.first->payload.memoIndex; + onFound(memoIndex); } else { - memo_index = size(); - RETURN_NOT_OK(hash_table_.Insert(p.first, h, {value, memo_index})); - on_not_found(memo_index); + memoIndex = size(); + RETURN_NOT_OK(hashTable_.insert(p.first, h, {value, memoIndex})); + onNotFound(memoIndex); } - *out_memo_index = memo_index; + *outMemoIndex = memoIndex; return Status::OK(); } - Status GetOrInsert(const Scalar& value, int32_t* out_memo_index) { - return GetOrInsert( - value, [](int32_t i) {}, [](int32_t i) {}, out_memo_index); + Status getOrInsert(const Scalar& value, int32_t* outMemoIndex) { + return getOrInsert(value, [](int32_t i) {}, [](int32_t i) {}, outMemoIndex); } - int32_t GetNull() const { - return null_index_; + int32_t getNull() const { + return nullIndex_; } template - int32_t GetOrInsertNull(Func1&& on_found, Func2&& on_not_found) { - int32_t memo_index = GetNull(); - if (memo_index != kKeyNotFound) { - on_found(memo_index); + int32_t getOrInsertNull(Func1&& onFound, Func2&& onNotFound) { + int32_t memoIndex = getNull(); + if (memoIndex != kKeyNotFound) { + onFound(memoIndex); } else { - null_index_ = memo_index = size(); - on_not_found(memo_index); + nullIndex_ = memoIndex = size(); + onNotFound(memoIndex); } - return memo_index; + return memoIndex; } - int32_t GetOrInsertNull() { - return GetOrInsertNull([](int32_t i) {}, [](int32_t i) {}); + int32_t getOrInsertNull() { + return getOrInsertNull([](int32_t i) {}, [](int32_t i) {}); } // The number of entries in the memo table +1 if null was added. // (which is also 1 + the largest memo index) int32_t size() const override { - return static_cast(hash_table_.size()) + - (GetNull() != kKeyNotFound); + return static_cast(hashTable_.size()) + + (getNull() != kKeyNotFound); } - // Copy values starting from index `start` into `out_data` - void CopyValues(int32_t start, Scalar* out_data) const { - hash_table_.VisitEntries([=](const HashTableEntry* entry) { - int32_t index = entry->payload.memo_index - start; + // Copy values starting from index `start` into `out_data`. + void copyValues(int32_t start, Scalar* outData) const { + hashTable_.visitEntries([=](const HashTableEntry* entry) { + int32_t index = entry->payload.memoIndex - start; if (index >= 0) { - out_data[index] = entry->payload.value; + outData[index] = entry->payload.value; } }); - // Zero-initialize the null entry - if (null_index_ != kKeyNotFound) { - int32_t index = null_index_ - start; + // Zero-initialize the null entry. + if (nullIndex_ != kKeyNotFound) { + int32_t index = nullIndex_ - start; if (index >= 0) { - out_data[index] = Scalar{}; + outData[index] = Scalar{}; } } } - void CopyValues(Scalar* out_data) const { - CopyValues(0, out_data); + void copyValues(Scalar* outData) const { + copyValues(0, outData); } protected: struct Payload { Scalar value; - int32_t memo_index; + int32_t memoIndex; }; using HashTableType = HashTableTemplateType; using HashTableEntry = typename HashTableType::Entry; - HashTableType hash_table_; - int32_t null_index_ = kKeyNotFound; + HashTableType hashTable_; + int32_t nullIndex_ = kKeyNotFound; - hash_t ComputeHash(const Scalar& value) const { - return ScalarHelper::ComputeHash(value); + hash_t computeHash(const Scalar& value) const { + return ScalarHelper::computeHash(value); } public: - // defined here so that `HashTableType` is visible + // Defined here so that `HashTableType` is visible. // Merge entries from `other_table` into `this->hash_table_`. - Status MergeTable(const ScalarMemoTable& other_table) { - const HashTableType& other_hashtable = other_table.hash_table_; + Status mergeTable(const ScalarMemoTable& otherTable) { + const HashTableType& otherHashtable = otherTable.hashTable_; - other_hashtable.VisitEntries([this](const HashTableEntry* other_entry) { + otherHashtable.visitEntries([this](const HashTableEntry* otherEntry) { int32_t unused; - auto status = this->GetOrInsert(other_entry->payload.value, &unused); + auto status = this->getOrInsert(otherEntry->payload.value, &unused); VELOX_DCHECK(status.ok(), status.ToString()); }); - // TODO: ARROW-17074 - implement proper error handling + // TODO: ARROW-17074 - implement proper error handling. return Status::OK(); } }; -// ---------------------------------------------------------------------- -// A memoization table for small scalar values, using direct indexing +// ----------------------------------------------------------------------. +// A memoization table for small scalar values, using direct indexing. template struct SmallScalarTraits {}; @@ -576,7 +573,7 @@ template <> struct SmallScalarTraits { static constexpr int32_t cardinality = 2; - static uint32_t AsIndex(bool value) { + static uint32_t asIndex(bool value) { return value ? 1 : 0; } }; @@ -588,7 +585,7 @@ struct SmallScalarTraits::value>> { static constexpr int32_t cardinality = 1U + std::numeric_limits::max(); - static uint32_t AsIndex(Scalar value) { + static uint32_t asIndex(Scalar value) { return static_cast(value); } }; @@ -599,94 +596,93 @@ template < class SmallScalarMemoTable : public MemoTable { public: explicit SmallScalarMemoTable(MemoryPool* pool, int64_t entries = 0) { - std::fill(value_to_index_, value_to_index_ + cardinality + 1, kKeyNotFound); - index_to_value_.reserve(cardinality); + std::fill(valueToIndex_, valueToIndex_ + cardinality + 1, kKeyNotFound); + indexToValue_.reserve(cardinality); } - int32_t Get(const Scalar value) const { - auto value_index = AsIndex(value); - return value_to_index_[value_index]; + int32_t get(const Scalar value) const { + auto valueIndex = asIndex(value); + return valueToIndex_[valueIndex]; } template - Status GetOrInsert( + Status getOrInsert( const Scalar value, - Func1&& on_found, - Func2&& on_not_found, - int32_t* out_memo_index) { - auto value_index = AsIndex(value); - auto memo_index = value_to_index_[value_index]; - if (memo_index == kKeyNotFound) { - memo_index = static_cast(index_to_value_.size()); - index_to_value_.push_back(value); - value_to_index_[value_index] = memo_index; - VELOX_DCHECK_LT(memo_index, cardinality + 1); - on_not_found(memo_index); + Func1&& onFound, + Func2&& onNotFound, + int32_t* outMemoIndex) { + auto valueIndex = asIndex(value); + auto memoIndex = valueToIndex_[valueIndex]; + if (memoIndex == kKeyNotFound) { + memoIndex = static_cast(indexToValue_.size()); + indexToValue_.push_back(value); + valueToIndex_[valueIndex] = memoIndex; + VELOX_DCHECK_LT(memoIndex, cardinality + 1); + onNotFound(memoIndex); } else { - on_found(memo_index); + onFound(memoIndex); } - *out_memo_index = memo_index; + *outMemoIndex = memoIndex; return Status::OK(); } - Status GetOrInsert(const Scalar value, int32_t* out_memo_index) { - return GetOrInsert( - value, [](int32_t i) {}, [](int32_t i) {}, out_memo_index); + Status getOrInsert(const Scalar value, int32_t* outMemoIndex) { + return getOrInsert(value, [](int32_t i) {}, [](int32_t i) {}, outMemoIndex); } - int32_t GetNull() const { - return value_to_index_[cardinality]; + int32_t getNull() const { + return valueToIndex_[cardinality]; } template - int32_t GetOrInsertNull(Func1&& on_found, Func2&& on_not_found) { - auto memo_index = GetNull(); - if (memo_index == kKeyNotFound) { - memo_index = value_to_index_[cardinality] = size(); - index_to_value_.push_back(0); - on_not_found(memo_index); + int32_t getOrInsertNull(Func1&& onFound, Func2&& onNotFound) { + auto memoIndex = getNull(); + if (memoIndex == kKeyNotFound) { + memoIndex = valueToIndex_[cardinality] = size(); + indexToValue_.push_back(0); + onNotFound(memoIndex); } else { - on_found(memo_index); + onFound(memoIndex); } - return memo_index; + return memoIndex; } - int32_t GetOrInsertNull() { - return GetOrInsertNull([](int32_t i) {}, [](int32_t i) {}); + int32_t getOrInsertNull() { + return getOrInsertNull([](int32_t i) {}, [](int32_t i) {}); } - // The number of entries in the memo table + // The number of entries in the memo table. // (which is also 1 + the largest memo index) int32_t size() const override { - return static_cast(index_to_value_.size()); + return static_cast(indexToValue_.size()); } // Merge entries from `other_table` into `this`. - Status MergeTable(const SmallScalarMemoTable& other_table) { - for (const Scalar& other_val : other_table.index_to_value_) { + Status mergeTable(const SmallScalarMemoTable& otherTable) { + for (const Scalar& otherVal : otherTable.indexToValue_) { int32_t unused; - RETURN_NOT_OK(this->GetOrInsert(other_val, &unused)); + RETURN_NOT_OK(this->getOrInsert(otherVal, &unused)); } return Status::OK(); } - // Copy values starting from index `start` into `out_data` - void CopyValues(int32_t start, Scalar* out_data) const { + // Copy values starting from index `start` into `out_data`. + void copyValues(int32_t start, Scalar* outData) const { VELOX_DCHECK_GE(start, 0); - VELOX_DCHECK_LE(static_cast(start), index_to_value_.size()); + VELOX_DCHECK_LE(static_cast(start), indexToValue_.size()); int64_t offset = start * static_cast(sizeof(Scalar)); memcpy( - out_data, - index_to_value_.data() + offset, + outData, + indexToValue_.data() + offset, (size() - start) * sizeof(Scalar)); } - void CopyValues(Scalar* out_data) const { - CopyValues(0, out_data); + void copyValues(Scalar* outData) const { + copyValues(0, outData); } const std::vector& values() const { - return index_to_value_; + return indexToValue_; } protected: @@ -695,287 +691,285 @@ class SmallScalarMemoTable : public MemoTable { cardinality <= 256, "cardinality too large for direct-addressed table"); - uint32_t AsIndex(Scalar value) const { - return SmallScalarTraits::AsIndex(value); + uint32_t asIndex(Scalar value) const { + return SmallScalarTraits::asIndex(value); } // The last index is reserved for the null element. - int32_t value_to_index_[cardinality + 1]; - std::vector index_to_value_; + int32_t valueToIndex_[cardinality + 1]; + std::vector indexToValue_; }; -// ---------------------------------------------------------------------- +// ----------------------------------------------------------------------. // A memoization table for variable-sized binary data. template class BinaryMemoTable : public MemoTable { public: - using builder_offset_type = typename BinaryBuilderT::offset_type; + using BuilderOffsetType = typename BinaryBuilderT::offset_type; explicit BinaryMemoTable( MemoryPool* pool, int64_t entries = 0, - int64_t values_size = -1) - : hash_table_(pool, static_cast(entries)), - binary_builder_(pool) { - const int64_t data_size = (values_size < 0) ? entries * 4 : values_size; - auto status = binary_builder_.Resize(entries); + int64_t valuesSize = -1) + : hashTable_(pool, static_cast(entries)), binaryBuilder_(pool) { + const int64_t dataSize = (valuesSize < 0) ? entries * 4 : valuesSize; + auto status = binaryBuilder_.Reserve(entries); VELOX_DCHECK(status.ok(), status.ToString()); - status = binary_builder_.ReserveData(data_size); + status = binaryBuilder_.ReserveData(dataSize); VELOX_DCHECK(status.ok(), status.ToString()); } - int32_t Get(const void* data, builder_offset_type length) const { - hash_t h = ComputeStringHash<0>(data, length); - auto p = Lookup(h, data, length); + int32_t get(const void* data, BuilderOffsetType length) const { + hash_t h = computeStringHash<0>(data, length); + auto p = lookup(h, data, length); if (p.second) { - return p.first->payload.memo_index; + return p.first->payload.memoIndex; } else { return kKeyNotFound; } } - int32_t Get(std::string_view value) const { - return Get(value.data(), static_cast(value.length())); + int32_t get(std::string_view value) const { + return get(value.data(), static_cast(value.length())); } template - Status GetOrInsert( + Status getOrInsert( const void* data, - builder_offset_type length, - Func1&& on_found, - Func2&& on_not_found, - int32_t* out_memo_index) { - hash_t h = ComputeStringHash<0>(data, length); - auto p = Lookup(h, data, length); - int32_t memo_index; + BuilderOffsetType length, + Func1&& onFound, + Func2&& onNotFound, + int32_t* outMemoIndex) { + hash_t h = computeStringHash<0>(data, length); + auto p = lookup(h, data, length); + int32_t memoIndex; if (p.second) { - memo_index = p.first->payload.memo_index; - on_found(memo_index); + memoIndex = p.first->payload.memoIndex; + onFound(memoIndex); } else { - memo_index = size(); - // Insert string value + memoIndex = size(); + // Insert string value. RETURN_NOT_OK( - binary_builder_.Append(static_cast(data), length)); - // Insert hash entry - RETURN_NOT_OK(hash_table_.Insert( - const_cast(p.first), h, {memo_index})); + binaryBuilder_.Append(static_cast(data), length)); + // Insert hash entry. + RETURN_NOT_OK(hashTable_.insert( + const_cast(p.first), h, {memoIndex})); - on_not_found(memo_index); + onNotFound(memoIndex); } - *out_memo_index = memo_index; + *outMemoIndex = memoIndex; return Status::OK(); } template - Status GetOrInsert( + Status getOrInsert( std::string_view value, - Func1&& on_found, - Func2&& on_not_found, - int32_t* out_memo_index) { - return GetOrInsert( + Func1&& onFound, + Func2&& onNotFound, + int32_t* outMemoIndex) { + return getOrInsert( value.data(), - static_cast(value.length()), - std::forward(on_found), - std::forward(on_not_found), - out_memo_index); + static_cast(value.length()), + std::forward(onFound), + std::forward(onNotFound), + outMemoIndex); } - Status GetOrInsert( + Status getOrInsert( const void* data, - builder_offset_type length, - int32_t* out_memo_index) { - return GetOrInsert( - data, length, [](int32_t i) {}, [](int32_t i) {}, out_memo_index); + BuilderOffsetType length, + int32_t* outMemoIndex) { + return getOrInsert( + data, length, [](int32_t i) {}, [](int32_t i) {}, outMemoIndex); } - Status GetOrInsert(std::string_view value, int32_t* out_memo_index) { - return GetOrInsert( + Status getOrInsert(std::string_view value, int32_t* outMemoIndex) { + return getOrInsert( value.data(), - static_cast(value.length()), - out_memo_index); + static_cast(value.length()), + outMemoIndex); } - int32_t GetNull() const { - return null_index_; + int32_t getNull() const { + return nullIndex_; } template - int32_t GetOrInsertNull(Func1&& on_found, Func2&& on_not_found) { - int32_t memo_index = GetNull(); - if (memo_index == kKeyNotFound) { - memo_index = null_index_ = size(); - auto status = binary_builder_.AppendNull(); + int32_t getOrInsertNull(Func1&& onFound, Func2&& onNotFound) { + int32_t memoIndex = getNull(); + if (memoIndex == kKeyNotFound) { + memoIndex = nullIndex_ = size(); + auto status = binaryBuilder_.AppendNull(); VELOX_DCHECK(status.ok(), status.ToString()); - on_not_found(memo_index); + onNotFound(memoIndex); } else { - on_found(memo_index); + onFound(memoIndex); } - return memo_index; + return memoIndex; } - int32_t GetOrInsertNull() { - return GetOrInsertNull([](int32_t i) {}, [](int32_t i) {}); + int32_t getOrInsertNull() { + return getOrInsertNull([](int32_t i) {}, [](int32_t i) {}); } - // The number of entries in the memo table + // The number of entries in the memo table. // (which is also 1 + the largest memo index) int32_t size() const override { return static_cast( - hash_table_.size() + (GetNull() != kKeyNotFound)); + hashTable_.size() + (getNull() != kKeyNotFound)); } - int64_t values_size() const { - return binary_builder_.value_data_length(); + int64_t valuesSize() const { + return binaryBuilder_.value_data_length(); } - // Copy (n + 1) offsets starting from index `start` into `out_data` + // Copy (n + 1) offsets starting from index `start` into `out_data`. template - void CopyOffsets(int32_t start, Offset* out_data) const { + void copyOffsets(int32_t start, Offset* outData) const { VELOX_DCHECK_LE(start, size()); - const builder_offset_type* offsets = binary_builder_.offsets_data(); - const builder_offset_type delta = - start < binary_builder_.length() ? offsets[start] : 0; + const BuilderOffsetType* offsets = binaryBuilder_.offsets_data(); + const BuilderOffsetType delta = + start < binaryBuilder_.length() ? offsets[start] : 0; for (int32_t i = start; i < size(); ++i) { - const builder_offset_type adjusted_offset = offsets[i] - delta; - Offset cast_offset = static_cast(adjusted_offset); + const BuilderOffsetType adjustedOffset = offsets[i] - delta; + Offset castOffset = static_cast(adjustedOffset); assert( - static_cast(cast_offset) == - adjusted_offset); // avoid truncation - *out_data++ = cast_offset; + static_cast(castOffset) == + adjustedOffset); // avoid truncation + *outData++ = castOffset; } // Copy last value since BinaryBuilder only materializes it on in Finish() - *out_data = - static_cast(binary_builder_.value_data_length() - delta); + *outData = static_cast(binaryBuilder_.value_data_length() - delta); } template - void CopyOffsets(Offset* out_data) const { - CopyOffsets(0, out_data); + void copyOffsets(Offset* outData) const { + copyOffsets(0, outData); } - // Copy values starting from index `start` into `out_data` - void CopyValues(int32_t start, uint8_t* out_data) const { - CopyValues(start, -1, out_data); + // Copy values starting from index `start` into `out_data`. + void copyValues(int32_t start, uint8_t* outData) const { + copyValues(start, -1, outData); } - // Same as above, but check output size in debug mode - void CopyValues(int32_t start, int64_t out_size, uint8_t* out_data) const { + // Same as above, but check output size in debug mode. + void copyValues(int32_t start, int64_t outSize, uint8_t* outData) const { VELOX_DCHECK_LE(start, size()); // The absolute byte offset of `start` value in the binary buffer. - const builder_offset_type offset = binary_builder_.offset(start); + const BuilderOffsetType offset = binaryBuilder_.offsets_data()[start]; const auto length = - binary_builder_.value_data_length() - static_cast(offset); + binaryBuilder_.value_data_length() - static_cast(offset); - if (out_size != -1) { - assert(static_cast(length) <= out_size); + if (outSize != -1) { + assert(static_cast(length) <= outSize); } - auto view = binary_builder_.GetView(start); - memcpy(out_data, view.data(), length); + auto view = binaryBuilder_.GetView(start); + memcpy(outData, view.data(), length); } - void CopyValues(uint8_t* out_data) const { - CopyValues(0, -1, out_data); + void copyValues(uint8_t* outData) const { + copyValues(0, -1, outData); } - void CopyValues(int64_t out_size, uint8_t* out_data) const { - CopyValues(0, out_size, out_data); + void copyValues(int64_t outSize, uint8_t* outData) const { + copyValues(0, outSize, outData); } - void CopyFixedWidthValues( + void copyFixedWidthValues( int32_t start, - int32_t width_size, - int64_t out_size, - uint8_t* out_data) const { - // This method exists to cope with the fact that the BinaryMemoTable does - // not know the fixed width when inserting the null value. The data - // buffer hold a zero length string for the null value (if found). + int32_t widthSize, + int64_t outSize, + uint8_t* outData) const { + // This method exists to cope with the fact that the BinaryMemoTable does. + // Not know the fixed width when inserting the null value. The data. + // Buffer hold a zero length string for the null value (if found). // - // Thus, the method will properly inject an empty value of the proper width - // in the output buffer. + // Thus, the method will properly inject an empty value of the proper width. + // In the output buffer. // if (start >= size()) { return; } - int32_t null_index = GetNull(); - if (null_index < start) { + int32_t nullIndex = getNull(); + if (nullIndex < start) { // Nothing to skip, proceed as usual. - CopyValues(start, out_size, out_data); + copyValues(start, outSize, outData); return; } - builder_offset_type left_offset = binary_builder_.offset(start); + BuilderOffsetType leftOffset = binaryBuilder_.offsets_data()[start]; - // Ensure that the data length is exactly missing width_size bytes to fit - // in the expected output (n_values * width_size). + // Ensure that the data length is exactly missing width_size bytes to fit. + // In the expected output (n_values * width_size). #ifndef NDEBUG - int64_t data_length = values_size() - static_cast(left_offset); - assert(data_length + width_size == out_size); - ARROW_UNUSED(data_length); + int64_t dataLength = valuesSize() - static_cast(leftOffset); + assert(dataLength + widthSize == outSize); + ARROW_UNUSED(dataLength); #endif - auto in_data = binary_builder_.value_data() + left_offset; - // The null use 0-length in the data, slice the data in 2 and skip by - // width_size in out_data. [part_1][width_size][part_2] - auto null_data_offset = binary_builder_.offset(null_index); - auto left_size = null_data_offset - left_offset; - if (left_size > 0) { - memcpy(out_data, in_data + left_offset, left_size); + auto inData = binaryBuilder_.value_data() + leftOffset; + // The null use 0-length in the data, slice the data in 2 and skip by. + // Width_size in out_data. [part_1][width_size][part_2]. + auto nullDataOffset = binaryBuilder_.offsets_data()[nullIndex]; + auto leftSize = nullDataOffset - leftOffset; + if (leftSize > 0) { + memcpy(outData, inData + leftOffset, leftSize); } - // Zero-initialize the null entry - memset(out_data + left_size, 0, width_size); - - auto right_size = values_size() - static_cast(null_data_offset); - if (right_size > 0) { - // skip the null fixed size value. - auto out_offset = left_size + width_size; - assert(out_data + out_offset + right_size == out_data + out_size); - memcpy(out_data + out_offset, in_data + null_data_offset, right_size); + // Zero-initialize the null entry. + memset(outData + leftSize, 0, widthSize); + + auto rightSize = valuesSize() - static_cast(nullDataOffset); + if (rightSize > 0) { + // Skip the null fixed size value. + auto outOffset = leftSize + widthSize; + assert(outData + outOffset + rightSize == outData + outSize); + memcpy(outData + outOffset, inData + nullDataOffset, rightSize); } } // Visit the stored values in insertion order. - // The visitor function should have the signature `void(std::string_view)` - // or `void(const std::string_view&)`. + // The visitor function should have the signature `void(std::string_view)`. + // Or `void(const std::string_view&)`. template - void VisitValues(int32_t start, VisitFunc&& visit) const { + void visitValues(int32_t start, VisitFunc&& visit) const { for (int32_t i = start; i < size(); ++i) { - auto sv = binary_builder_.GetView(i); + auto sv = binaryBuilder_.GetView(i); visit(std::string_view(sv.data(), sv.size())); } } protected: struct Payload { - int32_t memo_index; + int32_t memoIndex; }; using HashTableType = HashTable; using HashTableEntry = typename HashTable::Entry; - HashTableType hash_table_; - BinaryBuilderT binary_builder_; + HashTableType hashTable_; + BinaryBuilderT binaryBuilder_; - int32_t null_index_ = kKeyNotFound; + int32_t nullIndex_ = kKeyNotFound; std::pair - Lookup(hash_t h, const void* data, builder_offset_type length) const { - auto cmp_func = [&](const Payload* payload) { - auto lhs = binary_builder_.GetView(payload->memo_index); + lookup(hash_t h, const void* data, BuilderOffsetType length) const { + auto cmpFunc = [&](const Payload* payload) { + auto lhs = binaryBuilder_.GetView(payload->memoIndex); auto rhs = std::string_view(static_cast(data), length); return lhs == rhs; }; - return hash_table_.Lookup(h, cmp_func); + return hashTable_.lookup(h, cmpFunc); } public: - Status MergeTable(const BinaryMemoTable& other_table) { - other_table.VisitValues(0, [this](std::string_view other_value) { + Status mergeTable(const BinaryMemoTable& otherTable) { + otherTable.visitValues(0, [this](std::string_view otherValue) { int32_t unused; - auto status = this->GetOrInsert(other_value, &unused); + auto status = this->getOrInsert(otherValue, &unused); VELOX_DCHECK(status.ok(), status.ToString()); }); return Status::OK(); @@ -992,8 +986,8 @@ struct HashTraits<::arrow::BooleanType> { template struct HashTraits> { - using c_type = typename T::c_type; - using MemoTableType = SmallScalarMemoTable; + using CType = typename T::CType; + using MemoTableType = SmallScalarMemoTable; }; template @@ -1001,8 +995,8 @@ struct HashTraits< T, enable_if_t< ::arrow::has_c_type::value && !::arrow::is_8bit_int::value>> { - using c_type = typename T::c_type; - using MemoTableType = ScalarMemoTable; + using CType = typename T::CType; + using MemoTableType = ScalarMemoTable; }; template @@ -1027,35 +1021,35 @@ struct HashTraits< }; template -static inline Status ComputeNullBitmap( +static inline Status computeNullBitmap( MemoryPool* pool, - const MemoTableType& memo_table, - int64_t start_offset, - int64_t* null_count, - std::shared_ptr<::arrow::Buffer>* null_bitmap) { - int64_t dict_length = static_cast(memo_table.size()) - start_offset; - int64_t null_index = memo_table.GetNull(); - - *null_count = 0; - *null_bitmap = nullptr; - - if (null_index != kKeyNotFound && null_index >= start_offset) { - null_index -= start_offset; - *null_count = 1; + const MemoTableType& memoTable, + int64_t startOffset, + int64_t* nullCount, + std::shared_ptr<::arrow::Buffer>* nullBitmap) { + int64_t dictLength = static_cast(memoTable.size()) - startOffset; + int64_t nullIndex = memoTable.getNull(); + + *nullCount = 0; + *nullBitmap = nullptr; + + if (nullIndex != kKeyNotFound && nullIndex >= startOffset) { + nullIndex -= startOffset; + *nullCount = 1; ARROW_ASSIGN_OR_RAISE( - *null_bitmap, - ::arrow::internal::BitmapAllButOne(pool, dict_length, null_index)); + *nullBitmap, + ::arrow::internal::BitmapAllButOne(pool, dictLength, nullIndex)); } return Status::OK(); } struct StringViewHash { - // std::hash compatible hasher for use with std::unordered_* - // (the std::hash specialization provided by nonstd constructs std::string + // Std::hash compatible hasher for use with std::unordered_*. + // (The std::hash specialization provided by nonstd constructs std::string. // temporaries then invokes std::hash against those) hash_t operator()(std::string_view value) const { - return ComputeStringHash<0>( + return computeStringHash<0>( value.data(), static_cast(value.size())); } }; diff --git a/velox/dwio/parquet/writer/arrow/util/OverflowUtilInternal.h b/velox/dwio/parquet/writer/arrow/util/OverflowUtilInternal.h index 97efda714a8..c0c7d434c07 100644 --- a/velox/dwio/parquet/writer/arrow/util/OverflowUtilInternal.h +++ b/velox/dwio/parquet/writer/arrow/util/OverflowUtilInternal.h @@ -26,60 +26,60 @@ #include "arrow/util/macros.h" #include "arrow/util/visibility.h" -// "safe-math.h" includes from the Windows headers. +// "Safe-math.h" includes from the Windows headers #include "arrow/util/windows_compatibility.h" -// #include "arrow/vendored/portable-snippets/safe-math.h" +// #include "arrow/vendored/portable-snippets/safe-math.h". #include "velox/dwio/parquet/writer/arrow/util/safe-math.h" // clang-format off (avoid include reordering) #include "arrow/util/windows_fixup.h" -// clang-format on +// clang-format on. namespace arrow { namespace internal { -// Define functions AddWithOverflow, SubtractWithOverflow, MultiplyWithOverflow +// Define functions AddWithOverflow, SubtractWithOverflow, MultiplyWithOverflow. // with the signature `bool(T u, T v, T* out)` where T is an integer type. -// On overflow, these functions return true. Otherwise, false is returned +// on overflow, these functions return true. Otherwise, false is returned. // and `out` is updated with the result of the operation. -#define OP_WITH_OVERFLOW(_func_name, _psnip_op, _type, _psnip_type) \ - [[nodiscard]] static inline bool _func_name(_type u, _type v, _type* out) { \ - return !psnip_safe_##_psnip_type##_##_psnip_op(out, u, v); \ +#define OP_WITH_OVERFLOW(funcName, psnipOp, Type, PsnipType) \ + [[nodiscard]] static inline bool funcName(Type u, Type v, Type* out) { \ + return !psnipSafe_##PsnipType##_##psnipOp(out, u, v); \ } -#define OPS_WITH_OVERFLOW(_func_name, _psnip_op) \ - OP_WITH_OVERFLOW(_func_name, _psnip_op, int8_t, int8) \ - OP_WITH_OVERFLOW(_func_name, _psnip_op, int16_t, int16) \ - OP_WITH_OVERFLOW(_func_name, _psnip_op, int32_t, int32) \ - OP_WITH_OVERFLOW(_func_name, _psnip_op, int64_t, int64) \ - OP_WITH_OVERFLOW(_func_name, _psnip_op, uint8_t, uint8) \ - OP_WITH_OVERFLOW(_func_name, _psnip_op, uint16_t, uint16) \ - OP_WITH_OVERFLOW(_func_name, _psnip_op, uint32_t, uint32) \ - OP_WITH_OVERFLOW(_func_name, _psnip_op, uint64_t, uint64) - -OPS_WITH_OVERFLOW(AddWithOverflow, add) +#define OPS_WITH_OVERFLOW(funcName, psnipOp) \ + OP_WITH_OVERFLOW(funcName, psnipOp, int8_t, int8) \ + OP_WITH_OVERFLOW(funcName, psnipOp, int16_t, int16) \ + OP_WITH_OVERFLOW(funcName, psnipOp, int32_t, int32) \ + OP_WITH_OVERFLOW(funcName, psnipOp, int64_t, int64) \ + OP_WITH_OVERFLOW(funcName, psnipOp, uint8_t, uint8) \ + OP_WITH_OVERFLOW(funcName, psnipOp, uint16_t, uint16) \ + OP_WITH_OVERFLOW(funcName, psnipOp, uint32_t, uint32) \ + OP_WITH_OVERFLOW(funcName, psnipOp, uint64_t, uint64) + +OPS_WITH_OVERFLOW(addWithOverflow, add) OPS_WITH_OVERFLOW(SubtractWithOverflow, sub) -OPS_WITH_OVERFLOW(MultiplyWithOverflow, mul) +OPS_WITH_OVERFLOW(multiplyWithOverflow, mul) OPS_WITH_OVERFLOW(DivideWithOverflow, div) #undef OP_WITH_OVERFLOW #undef OPS_WITH_OVERFLOW -// Define function NegateWithOverflow with the signature `bool(T u, T* out)` +// Define function NegateWithOverflow with the signature `bool(T u, T* out)`. // where T is a signed integer type. On overflow, these functions return true. -// Otherwise, false is returned and `out` is updated with the result of the +// otherwise, false is returned and `out` is updated with the result of the. // operation. -#define UNARY_OP_WITH_OVERFLOW(_func_name, _psnip_op, _type, _psnip_type) \ - [[nodiscard]] static inline bool _func_name(_type u, _type* out) { \ - return !psnip_safe_##_psnip_type##_##_psnip_op(out, u); \ +#define UNARY_OP_WITH_OVERFLOW(funcName, psnipOp, Type, PsnipType) \ + [[nodiscard]] static inline bool funcName(Type u, Type* out) { \ + return !psnipSafe_##PsnipType##_##psnipOp(out, u); \ } -#define SIGNED_UNARY_OPS_WITH_OVERFLOW(_func_name, _psnip_op) \ - UNARY_OP_WITH_OVERFLOW(_func_name, _psnip_op, int8_t, int8) \ - UNARY_OP_WITH_OVERFLOW(_func_name, _psnip_op, int16_t, int16) \ - UNARY_OP_WITH_OVERFLOW(_func_name, _psnip_op, int32_t, int32) \ - UNARY_OP_WITH_OVERFLOW(_func_name, _psnip_op, int64_t, int64) +#define SIGNED_UNARY_OPS_WITH_OVERFLOW(funcName, psnipOp) \ + UNARY_OP_WITH_OVERFLOW(funcName, psnipOp, int8_t, int8) \ + UNARY_OP_WITH_OVERFLOW(funcName, psnipOp, int16_t, int16) \ + UNARY_OP_WITH_OVERFLOW(funcName, psnipOp, int32_t, int32) \ + UNARY_OP_WITH_OVERFLOW(funcName, psnipOp, int64_t, int64) SIGNED_UNARY_OPS_WITH_OVERFLOW(NegateWithOverflow, neg) @@ -88,7 +88,7 @@ SIGNED_UNARY_OPS_WITH_OVERFLOW(NegateWithOverflow, neg) /// Signed addition with well-defined behaviour on overflow (as unsigned) template -SignedInt SafeSignedAdd(SignedInt u, SignedInt v) { +SignedInt safeSignedAdd(SignedInt u, SignedInt v) { using UnsignedInt = typename std::make_unsigned::type; return static_cast( static_cast(u) + static_cast(v)); @@ -96,7 +96,7 @@ SignedInt SafeSignedAdd(SignedInt u, SignedInt v) { /// Signed subtraction with well-defined behaviour on overflow (as unsigned) template -SignedInt SafeSignedSubtract(SignedInt u, SignedInt v) { +SignedInt safeSignedSubtract(SignedInt u, SignedInt v) { using UnsignedInt = typename std::make_unsigned::type; return static_cast( static_cast(u) - static_cast(v)); @@ -104,15 +104,15 @@ SignedInt SafeSignedSubtract(SignedInt u, SignedInt v) { /// Signed negation with well-defined behaviour on overflow (as unsigned) template -SignedInt SafeSignedNegate(SignedInt u) { +SignedInt safeSignedNegate(SignedInt u) { using UnsignedInt = typename std::make_unsigned::type; return static_cast(~static_cast(u) + 1); } -/// Signed left shift with well-defined behaviour on negative numbers or -/// overflow +/// Signed left shift with well-defined behaviour on negative numbers or. +/// overflow. template -SignedInt SafeLeftShift(SignedInt u, Shift shift) { +SignedInt safeLeftShift(SignedInt u, Shift shift) { using UnsignedInt = typename std::make_unsigned::type; return static_cast(static_cast(u) << shift); } diff --git a/velox/dwio/parquet/writer/arrow/util/VisitArrayInline.h b/velox/dwio/parquet/writer/arrow/util/VisitArrayInline.h index 886f4b0451e..cd4f1be452b 100644 --- a/velox/dwio/parquet/writer/arrow/util/VisitArrayInline.h +++ b/velox/dwio/parquet/writer/arrow/util/VisitArrayInline.h @@ -24,24 +24,23 @@ namespace facebook::velox::parquet::arrow::util { -#define ARRAY_VISIT_INLINE(TYPE_CLASS) \ - case ::arrow::TYPE_CLASS##Type::type_id: \ - return visitor->Visit( \ - ::arrow::internal::checked_cast< \ - const typename TypeTraits<::arrow::TYPE_CLASS##Type>::ArrayType&>( \ - array), \ +#define ARRAY_VISIT_INLINE(TYPE_CLASS) \ + case ::arrow::TYPE_CLASS##Type::type_id: \ + return visitor->visit( \ + ::arrow::internal::checked_cast::ArrayType&>(array), \ std::forward(args)...); -/// \brief Apply the visitors Visit() method specialized to the array type +/// \brief Apply the visitors Visit() method specialized to the array type. /// /// \tparam VISITOR Visitor type that implements Visit() for all array types. -/// \tparam ARGS Additional arguments, if any, will be passed to the Visit -/// function after the `arr` argument \return Status +/// \tparam ARGS Additional arguments, if any, will be passed to the Visit. +/// Function after the `arr` argument \return Status. /// /// A visitor is a type that implements specialized logic for each Arrow type. /// Example usage: /// -/// ``` +/// ```. /// class ExampleVisitor { /// arrow::Status Visit(arrow::NumericArray arr) { ... } /// arrow::Status Visit(arrow::NumericArray arr) { ... } @@ -49,10 +48,10 @@ namespace facebook::velox::parquet::arrow::util { /// } /// ExampleVisitor visitor; /// VisitArrayInline(some_array, &visitor); -/// ``` +/// ```. template inline Status -VisitArrayInline(const Array& array, VISITOR* visitor, ARGS&&... args) { +visitArrayInline(const Array& array, VISITOR* visitor, ARGS&&... args) { switch (array.type_id()) { ARROW_GENERATE_FOR_ALL_TYPES(ARRAY_VISIT_INLINE); default: diff --git a/velox/dwio/parquet/writer/arrow/util/safe-math.h b/velox/dwio/parquet/writer/arrow/util/safe-math.h index 661b62887bc..2164e60355e 100644 --- a/velox/dwio/parquet/writer/arrow/util/safe-math.h +++ b/velox/dwio/parquet/writer/arrow/util/safe-math.h @@ -75,39 +75,39 @@ PSNIP_SAFE__COMPILER_ATTRIBUTES static PSNIP_SAFE__INLINE #endif -// !defined(__cplusplus) added for Solaris support +// !Defined(__cplusplus) added for Solaris support. #if !defined(__cplusplus) && defined(__STDC_VERSION__) && \ __STDC_VERSION__ >= 199901L -#define psnip_safe_bool _Bool +#define psnipSafeBool bool #else -#define psnip_safe_bool int +#define psnipSafeBool int #endif #if !defined(PSNIP_SAFE_NO_FIXED) /* For maximum portability include the exact-int module from portable snippets. */ -#if !defined(psnip_int64_t) || !defined(psnip_uint64_t) || \ - !defined(psnip_int32_t) || !defined(psnip_uint32_t) || \ - !defined(psnip_int16_t) || !defined(psnip_uint16_t) || \ +#if !defined(Psnip_int64_t) || !defined(Psnip_uint64_t) || \ + !defined(Psnip_int32_t) || !defined(Psnip_uint32_t) || \ + !defined(Psnip_int16_t) || !defined(Psnip_uint16_t) || \ !defined(psnip_int8_t) || !defined(psnip_uint8_t) #include -#if !defined(psnip_int64_t) -#define psnip_int64_t int64_t +#if !defined(Psnip_int64_t) +#define Psnip_int64_t int64_t #endif -#if !defined(psnip_uint64_t) -#define psnip_uint64_t uint64_t +#if !defined(Psnip_uint64_t) +#define Psnip_uint64_t uint64_t #endif -#if !defined(psnip_int32_t) -#define psnip_int32_t int32_t +#if !defined(Psnip_int32_t) +#define Psnip_int32_t int32_t #endif -#if !defined(psnip_uint32_t) -#define psnip_uint32_t uint32_t +#if !defined(Psnip_uint32_t) +#define Psnip_uint32_t uint32_t #endif -#if !defined(psnip_int16_t) -#define psnip_int16_t int16_t +#if !defined(Psnip_int16_t) +#define Psnip_int16_t int16_t #endif -#if !defined(psnip_uint16_t) -#define psnip_uint16_t uint16_t +#if !defined(Psnip_uint16_t) +#define Psnip_uint16_t uint16_t #endif #if !defined(psnip_int8_t) #define psnip_int8_t int8_t @@ -154,16 +154,16 @@ #if !defined(PSNIP_SAFE_NO_PROMOTIONS) -#define PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, op_name, op) \ - PSNIP_SAFE__FUNCTION psnip_safe_##name##_larger \ - psnip_safe_larger_##name##_##op_name(T a, T b) { \ - return ((psnip_safe_##name##_larger)a)op((psnip_safe_##name##_larger)b); \ +#define PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, opName, op) \ + PSNIP_SAFE__FUNCTION psnipSafe##name##Larger \ + psnipSafeLarger_##name##_##opName(T a, T b) { \ + return ((psnipSafe##name##Larger)a)op((psnipSafe##name##Larger)b); \ } -#define PSNIP_SAFE_DEFINE_LARGER_UNARY_OP(T, name, op_name, op) \ - PSNIP_SAFE__FUNCTION psnip_safe_##name##_larger \ - psnip_safe_larger_##name##_##op_name(T value) { \ - return (op((psnip_safe_##name##_larger)value)); \ +#define PSNIP_SAFE_DEFINE_LARGER_UNARY_OP(T, name, opName, op) \ + PSNIP_SAFE__FUNCTION psnipSafe##name##Larger \ + psnipSafeLarger_##name##_##opName(T value) { \ + return (op((psnipSafe##name##Larger)value)); \ } #define PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(T, name) \ @@ -188,235 +188,235 @@ ((__GNUC__ >= 4) || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)) && \ defined(__SIZEOF_INT128__) && !defined(__ibmxl__) #define PSNIP_SAFE_HAVE_128 -typedef __int128 psnip_safe_int128_t; -typedef unsigned __int128 psnip_safe_uint128_t; +typedef __int128 Psnip_safe_int128_t; +typedef unsigned __int128 Psnip_safe_uint128_t; #endif /* defined(__GNUC__) */ #if !defined(PSNIP_SAFE_NO_FIXED) #define PSNIP_SAFE_HAVE_INT8_LARGER #define PSNIP_SAFE_HAVE_UINT8_LARGER -typedef psnip_int16_t psnip_safe_int8_larger; -typedef psnip_uint16_t psnip_safe_uint8_larger; +typedef Psnip_int16_t psnipSafeint8Larger; +typedef Psnip_uint16_t psnipSafeuint8Larger; #define PSNIP_SAFE_HAVE_INT16_LARGER -typedef psnip_int32_t psnip_safe_int16_larger; -typedef psnip_uint32_t psnip_safe_uint16_larger; +typedef Psnip_int32_t psnipSafeint16Larger; +typedef Psnip_uint32_t psnipSafeuint16Larger; #define PSNIP_SAFE_HAVE_INT32_LARGER -typedef psnip_int64_t psnip_safe_int32_larger; -typedef psnip_uint64_t psnip_safe_uint32_larger; +typedef Psnip_int64_t psnipSafeint32Larger; +typedef Psnip_uint64_t psnipSafeuint32Larger; #if defined(PSNIP_SAFE_HAVE_128) #define PSNIP_SAFE_HAVE_INT64_LARGER -typedef psnip_safe_int128_t psnip_safe_int64_larger; -typedef psnip_safe_uint128_t psnip_safe_uint64_larger; +typedef Psnip_safe_int128_t psnipSafeint64Larger; +typedef Psnip_safe_uint128_t psnipSafeuint64Larger; #endif /* defined(PSNIP_SAFE_HAVE_128) */ #endif /* !defined(PSNIP_SAFE_NO_FIXED) */ #define PSNIP_SAFE_HAVE_LARGER_SCHAR #if PSNIP_SAFE_IS_LARGER(SCHAR_MAX, SHRT_MAX) -typedef short psnip_safe_schar_larger; +typedef short psnipSafescharLarger; #elif PSNIP_SAFE_IS_LARGER(SCHAR_MAX, INT_MAX) -typedef int psnip_safe_schar_larger; +typedef int psnipSafescharLarger; #elif PSNIP_SAFE_IS_LARGER(SCHAR_MAX, LONG_MAX) -typedef long psnip_safe_schar_larger; +typedef long psnipSafescharLarger; #elif PSNIP_SAFE_IS_LARGER(SCHAR_MAX, LLONG_MAX) -typedef long long psnip_safe_schar_larger; +typedef long long psnipSafescharLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(SCHAR_MAX, 0x7fff) -typedef psnip_int16_t psnip_safe_schar_larger; +typedef Psnip_int16_t psnipSafescharLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && \ PSNIP_SAFE_IS_LARGER(SCHAR_MAX, 0x7fffffffLL) -typedef psnip_int32_t psnip_safe_schar_larger; +typedef Psnip_int32_t psnipSafescharLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && \ PSNIP_SAFE_IS_LARGER(SCHAR_MAX, 0x7fffffffffffffffLL) -typedef psnip_int64_t psnip_safe_schar_larger; +typedef Psnip_int64_t psnipSafescharLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && \ (SCHAR_MAX <= 0x7fffffffffffffffLL) -typedef psnip_safe_int128_t psnip_safe_schar_larger; +typedef Psnip_safe_int128_t psnipSafescharLarger; #else #undef PSNIP_SAFE_HAVE_LARGER_SCHAR #endif #define PSNIP_SAFE_HAVE_LARGER_UCHAR #if PSNIP_SAFE_IS_LARGER(UCHAR_MAX, USHRT_MAX) -typedef unsigned short psnip_safe_uchar_larger; +typedef unsigned short psnipSafeucharLarger; #elif PSNIP_SAFE_IS_LARGER(UCHAR_MAX, UINT_MAX) -typedef unsigned int psnip_safe_uchar_larger; +typedef unsigned int psnipSafeucharLarger; #elif PSNIP_SAFE_IS_LARGER(UCHAR_MAX, ULONG_MAX) -typedef unsigned long psnip_safe_uchar_larger; +typedef unsigned long psnipSafeucharLarger; #elif PSNIP_SAFE_IS_LARGER(UCHAR_MAX, ULLONG_MAX) -typedef unsigned long long psnip_safe_uchar_larger; +typedef unsigned long long psnipSafeucharLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(UCHAR_MAX, 0xffffU) -typedef psnip_uint16_t psnip_safe_uchar_larger; +typedef Psnip_uint16_t psnipSafeucharLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && \ PSNIP_SAFE_IS_LARGER(UCHAR_MAX, 0xffffffffUL) -typedef psnip_uint32_t psnip_safe_uchar_larger; +typedef Psnip_uint32_t psnipSafeucharLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && \ PSNIP_SAFE_IS_LARGER(UCHAR_MAX, 0xffffffffffffffffULL) -typedef psnip_uint64_t psnip_safe_uchar_larger; +typedef Psnip_uint64_t psnipSafeucharLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && \ (UCHAR_MAX <= 0xffffffffffffffffULL) -typedef psnip_safe_uint128_t psnip_safe_uchar_larger; +typedef Psnip_safe_uint128_t psnipSafeucharLarger; #else #undef PSNIP_SAFE_HAVE_LARGER_UCHAR #endif #if CHAR_MIN == 0 && defined(PSNIP_SAFE_HAVE_LARGER_UCHAR) #define PSNIP_SAFE_HAVE_LARGER_CHAR -typedef psnip_safe_uchar_larger psnip_safe_char_larger; +typedef psnipSafeucharLarger psnipSafecharLarger; #elif CHAR_MIN < 0 && defined(PSNIP_SAFE_HAVE_LARGER_SCHAR) #define PSNIP_SAFE_HAVE_LARGER_CHAR -typedef psnip_safe_schar_larger psnip_safe_char_larger; +typedef psnipSafescharLarger psnipSafecharLarger; #endif #define PSNIP_SAFE_HAVE_LARGER_SHRT #if PSNIP_SAFE_IS_LARGER(SHRT_MAX, INT_MAX) -typedef int psnip_safe_short_larger; +typedef int psnipSafeshortLarger; #elif PSNIP_SAFE_IS_LARGER(SHRT_MAX, LONG_MAX) -typedef long psnip_safe_short_larger; +typedef long psnipSafeshortLarger; #elif PSNIP_SAFE_IS_LARGER(SHRT_MAX, LLONG_MAX) -typedef long long psnip_safe_short_larger; +typedef long long psnipSafeshortLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(SHRT_MAX, 0x7fff) -typedef psnip_int16_t psnip_safe_short_larger; +typedef Psnip_int16_t psnipSafeshortLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && \ PSNIP_SAFE_IS_LARGER(SHRT_MAX, 0x7fffffffLL) -typedef psnip_int32_t psnip_safe_short_larger; +typedef Psnip_int32_t psnipSafeshortLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && \ PSNIP_SAFE_IS_LARGER(SHRT_MAX, 0x7fffffffffffffffLL) -typedef psnip_int64_t psnip_safe_short_larger; +typedef Psnip_int64_t psnipSafeshortLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && \ (SHRT_MAX <= 0x7fffffffffffffffLL) -typedef psnip_safe_int128_t psnip_safe_short_larger; +typedef Psnip_safe_int128_t psnipSafeshortLarger; #else #undef PSNIP_SAFE_HAVE_LARGER_SHRT #endif #define PSNIP_SAFE_HAVE_LARGER_USHRT #if PSNIP_SAFE_IS_LARGER(USHRT_MAX, UINT_MAX) -typedef unsigned int psnip_safe_ushort_larger; +typedef unsigned int psnipSafeushortLarger; #elif PSNIP_SAFE_IS_LARGER(USHRT_MAX, ULONG_MAX) -typedef unsigned long psnip_safe_ushort_larger; +typedef unsigned long psnipSafeushortLarger; #elif PSNIP_SAFE_IS_LARGER(USHRT_MAX, ULLONG_MAX) -typedef unsigned long long psnip_safe_ushort_larger; +typedef unsigned long long psnipSafeushortLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(USHRT_MAX, 0xffff) -typedef psnip_uint16_t psnip_safe_ushort_larger; +typedef Psnip_uint16_t psnipSafeushortLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && \ PSNIP_SAFE_IS_LARGER(USHRT_MAX, 0xffffffffUL) -typedef psnip_uint32_t psnip_safe_ushort_larger; +typedef Psnip_uint32_t psnipSafeushortLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && \ PSNIP_SAFE_IS_LARGER(USHRT_MAX, 0xffffffffffffffffULL) -typedef psnip_uint64_t psnip_safe_ushort_larger; +typedef Psnip_uint64_t psnipSafeushortLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && \ (USHRT_MAX <= 0xffffffffffffffffULL) -typedef psnip_safe_uint128_t psnip_safe_ushort_larger; +typedef Psnip_safe_uint128_t psnipSafeushortLarger; #else #undef PSNIP_SAFE_HAVE_LARGER_USHRT #endif #define PSNIP_SAFE_HAVE_LARGER_INT #if PSNIP_SAFE_IS_LARGER(INT_MAX, LONG_MAX) -typedef long psnip_safe_int_larger; +typedef long psnipSafeintLarger; #elif PSNIP_SAFE_IS_LARGER(INT_MAX, LLONG_MAX) -typedef long long psnip_safe_int_larger; +typedef long long psnipSafeintLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(INT_MAX, 0x7fff) -typedef psnip_int16_t psnip_safe_int_larger; +typedef Psnip_int16_t psnipSafeintLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && \ PSNIP_SAFE_IS_LARGER(INT_MAX, 0x7fffffffLL) -typedef psnip_int32_t psnip_safe_int_larger; +typedef Psnip_int32_t psnipSafeintLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && \ PSNIP_SAFE_IS_LARGER(INT_MAX, 0x7fffffffffffffffLL) -typedef psnip_int64_t psnip_safe_int_larger; +typedef Psnip_int64_t psnipSafeintLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && \ (INT_MAX <= 0x7fffffffffffffffLL) -typedef psnip_safe_int128_t psnip_safe_int_larger; +typedef Psnip_safe_int128_t psnipSafeintLarger; #else #undef PSNIP_SAFE_HAVE_LARGER_INT #endif #define PSNIP_SAFE_HAVE_LARGER_UINT #if PSNIP_SAFE_IS_LARGER(UINT_MAX, ULONG_MAX) -typedef unsigned long psnip_safe_uint_larger; +typedef unsigned long psnipSafeuintLarger; #elif PSNIP_SAFE_IS_LARGER(UINT_MAX, ULLONG_MAX) -typedef unsigned long long psnip_safe_uint_larger; +typedef unsigned long long psnipSafeuintLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(UINT_MAX, 0xffff) -typedef psnip_uint16_t psnip_safe_uint_larger; +typedef Psnip_uint16_t psnipSafeuintLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && \ PSNIP_SAFE_IS_LARGER(UINT_MAX, 0xffffffffUL) -typedef psnip_uint32_t psnip_safe_uint_larger; +typedef Psnip_uint32_t psnipSafeuintLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && \ PSNIP_SAFE_IS_LARGER(UINT_MAX, 0xffffffffffffffffULL) -typedef psnip_uint64_t psnip_safe_uint_larger; +typedef Psnip_uint64_t psnipSafeuintLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && \ (UINT_MAX <= 0xffffffffffffffffULL) -typedef psnip_safe_uint128_t psnip_safe_uint_larger; +typedef Psnip_safe_uint128_t psnipSafeuintLarger; #else #undef PSNIP_SAFE_HAVE_LARGER_UINT #endif #define PSNIP_SAFE_HAVE_LARGER_LONG #if PSNIP_SAFE_IS_LARGER(LONG_MAX, LLONG_MAX) -typedef long long psnip_safe_long_larger; +typedef long long psnipSafelongLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(LONG_MAX, 0x7fff) -typedef psnip_int16_t psnip_safe_long_larger; +typedef Psnip_int16_t psnipSafelongLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && \ PSNIP_SAFE_IS_LARGER(LONG_MAX, 0x7fffffffLL) -typedef psnip_int32_t psnip_safe_long_larger; +typedef Psnip_int32_t psnipSafelongLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && \ PSNIP_SAFE_IS_LARGER(LONG_MAX, 0x7fffffffffffffffLL) -typedef psnip_int64_t psnip_safe_long_larger; +typedef Psnip_int64_t psnipSafelongLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && \ (LONG_MAX <= 0x7fffffffffffffffLL) -typedef psnip_safe_int128_t psnip_safe_long_larger; +typedef Psnip_safe_int128_t psnipSafelongLarger; #else #undef PSNIP_SAFE_HAVE_LARGER_LONG #endif #define PSNIP_SAFE_HAVE_LARGER_ULONG #if PSNIP_SAFE_IS_LARGER(ULONG_MAX, ULLONG_MAX) -typedef unsigned long long psnip_safe_ulong_larger; +typedef unsigned long long psnipSafeulongLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(ULONG_MAX, 0xffff) -typedef psnip_uint16_t psnip_safe_ulong_larger; +typedef Psnip_uint16_t psnipSafeulongLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && \ PSNIP_SAFE_IS_LARGER(ULONG_MAX, 0xffffffffUL) -typedef psnip_uint32_t psnip_safe_ulong_larger; +typedef Psnip_uint32_t psnipSafeulongLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && \ PSNIP_SAFE_IS_LARGER(ULONG_MAX, 0xffffffffffffffffULL) -typedef psnip_uint64_t psnip_safe_ulong_larger; +typedef Psnip_uint64_t psnipSafeulongLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && \ (ULONG_MAX <= 0xffffffffffffffffULL) -typedef psnip_safe_uint128_t psnip_safe_ulong_larger; +typedef Psnip_safe_uint128_t psnipSafeulongLarger; #else #undef PSNIP_SAFE_HAVE_LARGER_ULONG #endif #define PSNIP_SAFE_HAVE_LARGER_LLONG #if !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(LLONG_MAX, 0x7fff) -typedef psnip_int16_t psnip_safe_llong_larger; +typedef Psnip_int16_t psnipSafellongLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && \ PSNIP_SAFE_IS_LARGER(LLONG_MAX, 0x7fffffffLL) -typedef psnip_int32_t psnip_safe_llong_larger; +typedef Psnip_int32_t psnipSafellongLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && \ PSNIP_SAFE_IS_LARGER(LLONG_MAX, 0x7fffffffffffffffLL) -typedef psnip_int64_t psnip_safe_llong_larger; +typedef Psnip_int64_t psnipSafellongLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && \ (LLONG_MAX <= 0x7fffffffffffffffLL) -typedef psnip_safe_int128_t psnip_safe_llong_larger; +typedef Psnip_safe_int128_t psnipSafellongLarger; #else #undef PSNIP_SAFE_HAVE_LARGER_LLONG #endif #define PSNIP_SAFE_HAVE_LARGER_ULLONG #if !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(ULLONG_MAX, 0xffff) -typedef psnip_uint16_t psnip_safe_ullong_larger; +typedef Psnip_uint16_t psnipSafeullongLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && \ PSNIP_SAFE_IS_LARGER(ULLONG_MAX, 0xffffffffUL) -typedef psnip_uint32_t psnip_safe_ullong_larger; +typedef Psnip_uint32_t psnipSafeullongLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && \ PSNIP_SAFE_IS_LARGER(ULLONG_MAX, 0xffffffffffffffffULL) -typedef psnip_uint64_t psnip_safe_ullong_larger; +typedef Psnip_uint64_t psnipSafeullongLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && \ (ULLONG_MAX <= 0xffffffffffffffffULL) -typedef psnip_safe_uint128_t psnip_safe_ullong_larger; +typedef Psnip_safe_uint128_t psnipSafeullongLarger; #else #undef PSNIP_SAFE_HAVE_LARGER_ULLONG #endif @@ -424,25 +424,25 @@ typedef psnip_safe_uint128_t psnip_safe_ullong_larger; #if defined(PSNIP_SAFE_SIZE_MAX) #define PSNIP_SAFE_HAVE_LARGER_SIZE #if PSNIP_SAFE_IS_LARGER(PSNIP_SAFE_SIZE_MAX, USHRT_MAX) -typedef unsigned short psnip_safe_size_larger; +typedef unsigned short psnipSafesizeLarger; #elif PSNIP_SAFE_IS_LARGER(PSNIP_SAFE_SIZE_MAX, UINT_MAX) -typedef unsigned int psnip_safe_size_larger; +typedef unsigned int psnipSafesizeLarger; #elif PSNIP_SAFE_IS_LARGER(PSNIP_SAFE_SIZE_MAX, ULONG_MAX) -typedef unsigned long psnip_safe_size_larger; +typedef unsigned long psnipSafesizeLarger; #elif PSNIP_SAFE_IS_LARGER(PSNIP_SAFE_SIZE_MAX, ULLONG_MAX) -typedef unsigned long long psnip_safe_size_larger; +typedef unsigned long long psnipSafesizeLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && \ PSNIP_SAFE_IS_LARGER(PSNIP_SAFE_SIZE_MAX, 0xffff) -typedef psnip_uint16_t psnip_safe_size_larger; +typedef Psnip_uint16_t psnipSafesizeLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && \ PSNIP_SAFE_IS_LARGER(PSNIP_SAFE_SIZE_MAX, 0xffffffffUL) -typedef psnip_uint32_t psnip_safe_size_larger; +typedef Psnip_uint32_t psnipSafesizeLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && \ PSNIP_SAFE_IS_LARGER(PSNIP_SAFE_SIZE_MAX, 0xffffffffffffffffULL) -typedef psnip_uint64_t psnip_safe_size_larger; +typedef Psnip_uint64_t psnipSafesizeLarger; #elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && \ (PSNIP_SAFE_SIZE_MAX <= 0xffffffffffffffffULL) -typedef psnip_safe_uint128_t psnip_safe_size_larger; +typedef Psnip_safe_uint128_t psnipSafesizeLarger; #else #undef PSNIP_SAFE_HAVE_LARGER_SIZE #endif @@ -503,180 +503,177 @@ PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(size_t, size) #if !defined(PSNIP_SAFE_NO_FIXED) PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(psnip_int8_t, int8) PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(psnip_uint8_t, uint8) -PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(psnip_int16_t, int16) -PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(psnip_uint16_t, uint16) -PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(psnip_int32_t, int32) -PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(psnip_uint32_t, uint32) +PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(Psnip_int16_t, int16) +PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(Psnip_uint16_t, uint16) +PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(Psnip_int32_t, int32) +PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(Psnip_uint32_t, uint32) #if defined(PSNIP_SAFE_HAVE_128) -PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(psnip_int64_t, int64) -PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(psnip_uint64_t, uint64) +PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(Psnip_int64_t, int64) +PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(Psnip_uint64_t, uint64) #endif #endif #endif /* !defined(PSNIP_SAFE_NO_PROMOTIONS) */ -#define PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(T, name, op_name) \ - PSNIP_SAFE__FUNCTION psnip_safe_bool psnip_safe_##name##_##op_name( \ - T* res, T a, T b) { \ - return !__builtin_##op_name##_overflow(a, b, res); \ +#define PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(T, name, opName) \ + PSNIP_SAFE__FUNCTION psnipSafeBool psnipSafe_##name##_##opName( \ + T* res, T a, T b) { \ + return !__builtin_##opName##_overflow(a, b, res); \ } -#define PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP( \ - T, name, op_name, min, max) \ - PSNIP_SAFE__FUNCTION psnip_safe_bool psnip_safe_##name##_##op_name( \ - T* res, T a, T b) { \ - const psnip_safe_##name##_larger r = \ - psnip_safe_larger_##name##_##op_name(a, b); \ - *res = (T)r; \ - return (r >= min) && (r <= max); \ +#define PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(T, name, opName, min, max) \ + PSNIP_SAFE__FUNCTION psnipSafeBool psnipSafe_##name##_##opName( \ + T* res, T a, T b) { \ + const psnipSafe_##name##Larger r = \ + psnipSafeLarger_##name##_##opName(a, b); \ + *res = (T)r; \ + return (r >= min) && (r <= max); \ } -#define PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(T, name, op_name, max) \ - PSNIP_SAFE__FUNCTION psnip_safe_bool psnip_safe_##name##_##op_name( \ - T* res, T a, T b) { \ - const psnip_safe_##name##_larger r = \ - psnip_safe_larger_##name##_##op_name(a, b); \ - *res = (T)r; \ - return (r <= max); \ +#define PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(T, name, opName, max) \ + PSNIP_SAFE__FUNCTION psnipSafeBool psnipSafe_##name##_##opName( \ + T* res, T a, T b) { \ + const psnipSafe_##name##Larger r = \ + psnipSafeLarger_##name##_##opName(a, b); \ + *res = (T)r; \ + return (r <= max); \ } #define PSNIP_SAFE_DEFINE_SIGNED_ADD(T, name, min, max) \ - PSNIP_SAFE__FUNCTION psnip_safe_bool psnip_safe_##name##_add( \ + PSNIP_SAFE__FUNCTION psnipSafeBool psnipSafe_##name##_add( \ T* res, T a, T b) { \ - psnip_safe_bool r = \ + psnipSafeBool r = \ !(((b > 0) && (a > (max - b))) || ((b < 0) && (a < (min - b)))); \ if (PSNIP_SAFE_LIKELY(r)) \ *res = a + b; \ return r; \ } -#define PSNIP_SAFE_DEFINE_UNSIGNED_ADD(T, name, max) \ - PSNIP_SAFE__FUNCTION psnip_safe_bool psnip_safe_##name##_add( \ - T* res, T a, T b) { \ - *res = (T)(a + b); \ - return !PSNIP_SAFE_UNLIKELY((b > 0) && (a > (max - b))); \ +#define PSNIP_SAFE_DEFINE_UNSIGNED_ADD(T, name, max) \ + PSNIP_SAFE__FUNCTION psnipSafeBool psnipSafe_##name##_add( \ + T* res, T a, T b) { \ + *res = (T)(a + b); \ + return !PSNIP_SAFE_UNLIKELY((b > 0) && (a > (max - b))); \ } -#define PSNIP_SAFE_DEFINE_SIGNED_SUB(T, name, min, max) \ - PSNIP_SAFE__FUNCTION psnip_safe_bool psnip_safe_##name##_sub( \ - T* res, T a, T b) { \ - psnip_safe_bool r = \ - !((b > 0 && a < (min + b)) || (b < 0 && a > (max + b))); \ - if (PSNIP_SAFE_LIKELY(r)) \ - *res = a - b; \ - return r; \ +#define PSNIP_SAFE_DEFINE_SIGNED_SUB(T, name, min, max) \ + PSNIP_SAFE__FUNCTION psnipSafeBool psnipSafe_##name##_sub( \ + T* res, T a, T b) { \ + psnipSafeBool r = !((b > 0 && a < (min + b)) || (b < 0 && a > (max + b))); \ + if (PSNIP_SAFE_LIKELY(r)) \ + *res = a - b; \ + return r; \ } -#define PSNIP_SAFE_DEFINE_UNSIGNED_SUB(T, name, max) \ - PSNIP_SAFE__FUNCTION psnip_safe_bool psnip_safe_##name##_sub( \ - T* res, T a, T b) { \ - *res = a - b; \ - return !PSNIP_SAFE_UNLIKELY(b > a); \ +#define PSNIP_SAFE_DEFINE_UNSIGNED_SUB(T, name, max) \ + PSNIP_SAFE__FUNCTION psnipSafeBool psnipSafe_##name##_sub( \ + T* res, T a, T b) { \ + *res = a - b; \ + return !PSNIP_SAFE_UNLIKELY(b > a); \ } -#define PSNIP_SAFE_DEFINE_SIGNED_MUL(T, name, min, max) \ - PSNIP_SAFE__FUNCTION psnip_safe_bool psnip_safe_##name##_mul( \ - T* res, T a, T b) { \ - psnip_safe_bool r = 1; \ - if (a > 0) { \ - if (b > 0) { \ - if (a > (max / b)) { \ - r = 0; \ - } \ - } else { \ - if (b < (min / a)) { \ - r = 0; \ - } \ - } \ - } else { \ - if (b > 0) { \ - if (a < (min / b)) { \ - r = 0; \ - } \ - } else { \ - if ((a != 0) && (b < (max / a))) { \ - r = 0; \ - } \ - } \ - } \ - if (PSNIP_SAFE_LIKELY(r)) \ - *res = a * b; \ - return r; \ +#define PSNIP_SAFE_DEFINE_SIGNED_MUL(T, name, min, max) \ + PSNIP_SAFE__FUNCTION psnipSafeBool psnipSafe_##name##_mul( \ + T* res, T a, T b) { \ + psnipSafeBool r = 1; \ + if (a > 0) { \ + if (b > 0) { \ + if (a > (max / b)) { \ + r = 0; \ + } \ + } else { \ + if (b < (min / a)) { \ + r = 0; \ + } \ + } \ + } else { \ + if (b > 0) { \ + if (a < (min / b)) { \ + r = 0; \ + } \ + } else { \ + if ((a != 0) && (b < (max / a))) { \ + r = 0; \ + } \ + } \ + } \ + if (PSNIP_SAFE_LIKELY(r)) \ + *res = a * b; \ + return r; \ } #define PSNIP_SAFE_DEFINE_UNSIGNED_MUL(T, name, max) \ - PSNIP_SAFE__FUNCTION psnip_safe_bool psnip_safe_##name##_mul( \ + PSNIP_SAFE__FUNCTION psnipSafeBool psnipSafe_##name##_mul( \ T* res, T a, T b) { \ *res = (T)(a * b); \ return !PSNIP_SAFE_UNLIKELY((a > 0) && (b > 0) && (a > (max / b))); \ } -#define PSNIP_SAFE_DEFINE_SIGNED_DIV(T, name, min, max) \ - PSNIP_SAFE__FUNCTION psnip_safe_bool psnip_safe_##name##_div( \ - T* res, T a, T b) { \ - if (PSNIP_SAFE_UNLIKELY(b == 0)) { \ - *res = 0; \ - return 0; \ - } else if (PSNIP_SAFE_UNLIKELY(a == min && b == -1)) { \ - *res = min; \ - return 0; \ - } else { \ - *res = (T)(a / b); \ - return 1; \ - } \ +#define PSNIP_SAFE_DEFINE_SIGNED_DIV(T, name, min, max) \ + PSNIP_SAFE__FUNCTION psnipSafeBool psnipSafe_##name##_div( \ + T* res, T a, T b) { \ + if (PSNIP_SAFE_UNLIKELY(b == 0)) { \ + *res = 0; \ + return 0; \ + } else if (PSNIP_SAFE_UNLIKELY(a == min && b == -1)) { \ + *res = min; \ + return 0; \ + } else { \ + *res = (T)(a / b); \ + return 1; \ + } \ } -#define PSNIP_SAFE_DEFINE_UNSIGNED_DIV(T, name, max) \ - PSNIP_SAFE__FUNCTION psnip_safe_bool psnip_safe_##name##_div( \ - T* res, T a, T b) { \ - if (PSNIP_SAFE_UNLIKELY(b == 0)) { \ - *res = 0; \ - return 0; \ - } else { \ - *res = a / b; \ - return 1; \ - } \ +#define PSNIP_SAFE_DEFINE_UNSIGNED_DIV(T, name, max) \ + PSNIP_SAFE__FUNCTION psnipSafeBool psnipSafe_##name##_div( \ + T* res, T a, T b) { \ + if (PSNIP_SAFE_UNLIKELY(b == 0)) { \ + *res = 0; \ + return 0; \ + } else { \ + *res = a / b; \ + return 1; \ + } \ } -#define PSNIP_SAFE_DEFINE_SIGNED_MOD(T, name, min, max) \ - PSNIP_SAFE__FUNCTION psnip_safe_bool psnip_safe_##name##_mod( \ - T* res, T a, T b) { \ - if (PSNIP_SAFE_UNLIKELY(b == 0)) { \ - *res = 0; \ - return 0; \ - } else if (PSNIP_SAFE_UNLIKELY(a == min && b == -1)) { \ - *res = min; \ - return 0; \ - } else { \ - *res = (T)(a % b); \ - return 1; \ - } \ +#define PSNIP_SAFE_DEFINE_SIGNED_MOD(T, name, min, max) \ + PSNIP_SAFE__FUNCTION psnipSafeBool psnipSafe_##name##_mod( \ + T* res, T a, T b) { \ + if (PSNIP_SAFE_UNLIKELY(b == 0)) { \ + *res = 0; \ + return 0; \ + } else if (PSNIP_SAFE_UNLIKELY(a == min && b == -1)) { \ + *res = min; \ + return 0; \ + } else { \ + *res = (T)(a % b); \ + return 1; \ + } \ } -#define PSNIP_SAFE_DEFINE_UNSIGNED_MOD(T, name, max) \ - PSNIP_SAFE__FUNCTION psnip_safe_bool psnip_safe_##name##_mod( \ - T* res, T a, T b) { \ - if (PSNIP_SAFE_UNLIKELY(b == 0)) { \ - *res = 0; \ - return 0; \ - } else { \ - *res = a % b; \ - return 1; \ - } \ +#define PSNIP_SAFE_DEFINE_UNSIGNED_MOD(T, name, max) \ + PSNIP_SAFE__FUNCTION psnipSafeBool psnipSafe_##name##_mod( \ + T* res, T a, T b) { \ + if (PSNIP_SAFE_UNLIKELY(b == 0)) { \ + *res = 0; \ + return 0; \ + } else { \ + *res = a % b; \ + return 1; \ + } \ } -#define PSNIP_SAFE_DEFINE_SIGNED_NEG(T, name, min, max) \ - PSNIP_SAFE__FUNCTION psnip_safe_bool psnip_safe_##name##_neg( \ - T* res, T value) { \ - psnip_safe_bool r = value != min; \ - *res = PSNIP_SAFE_LIKELY(r) ? -value : max; \ - return r; \ +#define PSNIP_SAFE_DEFINE_SIGNED_NEG(T, name, min, max) \ + PSNIP_SAFE__FUNCTION psnipSafeBool psnipSafe_##name##_neg(T* res, T value) { \ + psnipSafeBool r = value != min; \ + *res = PSNIP_SAFE_LIKELY(r) ? -value : max; \ + return r; \ } -#define PSNIP_SAFE_DEFINE_INTSAFE(T, name, op, isf) \ - PSNIP_SAFE__FUNCTION psnip_safe_bool psnip_safe_##name##_##op( \ - T* res, T a, T b) { \ - return isf(a, b, res) == S_OK; \ +#define PSNIP_SAFE_DEFINE_INTSAFE(T, name, op, isf) \ + PSNIP_SAFE__FUNCTION psnipSafeBool psnipSafe_##name##_##op( \ + T* res, T a, T b) { \ + return isf(a, b, res) == S_OK; \ } #if CHAR_MIN == 0 @@ -1071,260 +1068,260 @@ PSNIP_SAFE_DEFINE_UNSIGNED_DIV(psnip_uint8_t, uint8, 0xff) PSNIP_SAFE_DEFINE_UNSIGNED_MOD(psnip_uint8_t, uint8, 0xff) #if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) -PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int16_t, int16, add) -PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int16_t, int16, sub) -PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int16_t, int16, mul) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(Psnip_int16_t, int16, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(Psnip_int16_t, int16, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(Psnip_int16_t, int16, mul) #elif defined(PSNIP_SAFE_HAVE_LARGER_INT16) PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP( - psnip_int16_t, + Psnip_int16_t, int16, add, (-32767 - 1), 0x7fff) PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP( - psnip_int16_t, + Psnip_int16_t, int16, sub, (-32767 - 1), 0x7fff) PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP( - psnip_int16_t, + Psnip_int16_t, int16, mul, (-32767 - 1), 0x7fff) #else -PSNIP_SAFE_DEFINE_SIGNED_ADD(psnip_int16_t, int16, (-32767 - 1), 0x7fff) -PSNIP_SAFE_DEFINE_SIGNED_SUB(psnip_int16_t, int16, (-32767 - 1), 0x7fff) -PSNIP_SAFE_DEFINE_SIGNED_MUL(psnip_int16_t, int16, (-32767 - 1), 0x7fff) +PSNIP_SAFE_DEFINE_SIGNED_ADD(Psnip_int16_t, int16, (-32767 - 1), 0x7fff) +PSNIP_SAFE_DEFINE_SIGNED_SUB(Psnip_int16_t, int16, (-32767 - 1), 0x7fff) +PSNIP_SAFE_DEFINE_SIGNED_MUL(Psnip_int16_t, int16, (-32767 - 1), 0x7fff) #endif -PSNIP_SAFE_DEFINE_SIGNED_DIV(psnip_int16_t, int16, (-32767 - 1), 0x7fff) -PSNIP_SAFE_DEFINE_SIGNED_MOD(psnip_int16_t, int16, (-32767 - 1), 0x7fff) -PSNIP_SAFE_DEFINE_SIGNED_NEG(psnip_int16_t, int16, (-32767 - 1), 0x7fff) +PSNIP_SAFE_DEFINE_SIGNED_DIV(Psnip_int16_t, int16, (-32767 - 1), 0x7fff) +PSNIP_SAFE_DEFINE_SIGNED_MOD(Psnip_int16_t, int16, (-32767 - 1), 0x7fff) +PSNIP_SAFE_DEFINE_SIGNED_NEG(Psnip_int16_t, int16, (-32767 - 1), 0x7fff) #if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) -PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint16_t, uint16, add) -PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint16_t, uint16, sub) -PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint16_t, uint16, mul) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(Psnip_uint16_t, uint16, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(Psnip_uint16_t, uint16, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(Psnip_uint16_t, uint16, mul) #elif defined(PSNIP_SAFE_HAVE_INTSAFE_H) && defined(_WIN32) -PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint16_t, uint16, add, UShortAdd) -PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint16_t, uint16, sub, UShortSub) -PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint16_t, uint16, mul, UShortMult) +PSNIP_SAFE_DEFINE_INTSAFE(Psnip_uint16_t, uint16, add, UShortAdd) +PSNIP_SAFE_DEFINE_INTSAFE(Psnip_uint16_t, uint16, sub, UShortSub) +PSNIP_SAFE_DEFINE_INTSAFE(Psnip_uint16_t, uint16, mul, UShortMult) #elif defined(PSNIP_SAFE_HAVE_LARGER_UINT16) PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP( - psnip_uint16_t, + Psnip_uint16_t, uint16, add, 0xffff) PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP( - psnip_uint16_t, + Psnip_uint16_t, uint16, sub, 0xffff) PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP( - psnip_uint16_t, + Psnip_uint16_t, uint16, mul, 0xffff) #else -PSNIP_SAFE_DEFINE_UNSIGNED_ADD(psnip_uint16_t, uint16, 0xffff) -PSNIP_SAFE_DEFINE_UNSIGNED_SUB(psnip_uint16_t, uint16, 0xffff) -PSNIP_SAFE_DEFINE_UNSIGNED_MUL(psnip_uint16_t, uint16, 0xffff) +PSNIP_SAFE_DEFINE_UNSIGNED_ADD(Psnip_uint16_t, uint16, 0xffff) +PSNIP_SAFE_DEFINE_UNSIGNED_SUB(Psnip_uint16_t, uint16, 0xffff) +PSNIP_SAFE_DEFINE_UNSIGNED_MUL(Psnip_uint16_t, uint16, 0xffff) #endif -PSNIP_SAFE_DEFINE_UNSIGNED_DIV(psnip_uint16_t, uint16, 0xffff) -PSNIP_SAFE_DEFINE_UNSIGNED_MOD(psnip_uint16_t, uint16, 0xffff) +PSNIP_SAFE_DEFINE_UNSIGNED_DIV(Psnip_uint16_t, uint16, 0xffff) +PSNIP_SAFE_DEFINE_UNSIGNED_MOD(Psnip_uint16_t, uint16, 0xffff) #if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) -PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int32_t, int32, add) -PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int32_t, int32, sub) -PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int32_t, int32, mul) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(Psnip_int32_t, int32, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(Psnip_int32_t, int32, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(Psnip_int32_t, int32, mul) #elif defined(PSNIP_SAFE_HAVE_LARGER_INT32) PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP( - psnip_int32_t, + Psnip_int32_t, int32, add, (-0x7fffffffLL - 1), 0x7fffffffLL) PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP( - psnip_int32_t, + Psnip_int32_t, int32, sub, (-0x7fffffffLL - 1), 0x7fffffffLL) PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP( - psnip_int32_t, + Psnip_int32_t, int32, mul, (-0x7fffffffLL - 1), 0x7fffffffLL) #else PSNIP_SAFE_DEFINE_SIGNED_ADD( - psnip_int32_t, + Psnip_int32_t, int32, (-0x7fffffffLL - 1), 0x7fffffffLL) PSNIP_SAFE_DEFINE_SIGNED_SUB( - psnip_int32_t, + Psnip_int32_t, int32, (-0x7fffffffLL - 1), 0x7fffffffLL) PSNIP_SAFE_DEFINE_SIGNED_MUL( - psnip_int32_t, + Psnip_int32_t, int32, (-0x7fffffffLL - 1), 0x7fffffffLL) #endif PSNIP_SAFE_DEFINE_SIGNED_DIV( - psnip_int32_t, + Psnip_int32_t, int32, (-0x7fffffffLL - 1), 0x7fffffffLL) PSNIP_SAFE_DEFINE_SIGNED_MOD( - psnip_int32_t, + Psnip_int32_t, int32, (-0x7fffffffLL - 1), 0x7fffffffLL) PSNIP_SAFE_DEFINE_SIGNED_NEG( - psnip_int32_t, + Psnip_int32_t, int32, (-0x7fffffffLL - 1), 0x7fffffffLL) #if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) -PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint32_t, uint32, add) -PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint32_t, uint32, sub) -PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint32_t, uint32, mul) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(Psnip_uint32_t, uint32, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(Psnip_uint32_t, uint32, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(Psnip_uint32_t, uint32, mul) #elif defined(PSNIP_SAFE_HAVE_INTSAFE_H) && defined(_WIN32) -PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint32_t, uint32, add, UIntAdd) -PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint32_t, uint32, sub, UIntSub) -PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint32_t, uint32, mul, UIntMult) +PSNIP_SAFE_DEFINE_INTSAFE(Psnip_uint32_t, uint32, add, UIntAdd) +PSNIP_SAFE_DEFINE_INTSAFE(Psnip_uint32_t, uint32, sub, UIntSub) +PSNIP_SAFE_DEFINE_INTSAFE(Psnip_uint32_t, uint32, mul, UIntMult) #elif defined(PSNIP_SAFE_HAVE_LARGER_UINT32) PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP( - psnip_uint32_t, + Psnip_uint32_t, uint32, add, 0xffffffffUL) PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP( - psnip_uint32_t, + Psnip_uint32_t, uint32, sub, 0xffffffffUL) PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP( - psnip_uint32_t, + Psnip_uint32_t, uint32, mul, 0xffffffffUL) #else -PSNIP_SAFE_DEFINE_UNSIGNED_ADD(psnip_uint32_t, uint32, 0xffffffffUL) -PSNIP_SAFE_DEFINE_UNSIGNED_SUB(psnip_uint32_t, uint32, 0xffffffffUL) -PSNIP_SAFE_DEFINE_UNSIGNED_MUL(psnip_uint32_t, uint32, 0xffffffffUL) +PSNIP_SAFE_DEFINE_UNSIGNED_ADD(Psnip_uint32_t, uint32, 0xffffffffUL) +PSNIP_SAFE_DEFINE_UNSIGNED_SUB(Psnip_uint32_t, uint32, 0xffffffffUL) +PSNIP_SAFE_DEFINE_UNSIGNED_MUL(Psnip_uint32_t, uint32, 0xffffffffUL) #endif -PSNIP_SAFE_DEFINE_UNSIGNED_DIV(psnip_uint32_t, uint32, 0xffffffffUL) -PSNIP_SAFE_DEFINE_UNSIGNED_MOD(psnip_uint32_t, uint32, 0xffffffffUL) +PSNIP_SAFE_DEFINE_UNSIGNED_DIV(Psnip_uint32_t, uint32, 0xffffffffUL) +PSNIP_SAFE_DEFINE_UNSIGNED_MOD(Psnip_uint32_t, uint32, 0xffffffffUL) #if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) -PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int64_t, int64, add) -PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int64_t, int64, sub) -PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int64_t, int64, mul) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(Psnip_int64_t, int64, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(Psnip_int64_t, int64, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(Psnip_int64_t, int64, mul) #elif defined(PSNIP_SAFE_HAVE_LARGER_INT64) PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP( - psnip_int64_t, + Psnip_int64_t, int64, add, (-0x7fffffffffffffffLL - 1), 0x7fffffffffffffffLL) PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP( - psnip_int64_t, + Psnip_int64_t, int64, sub, (-0x7fffffffffffffffLL - 1), 0x7fffffffffffffffLL) PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP( - psnip_int64_t, + Psnip_int64_t, int64, mul, (-0x7fffffffffffffffLL - 1), 0x7fffffffffffffffLL) #else PSNIP_SAFE_DEFINE_SIGNED_ADD( - psnip_int64_t, + Psnip_int64_t, int64, (-0x7fffffffffffffffLL - 1), 0x7fffffffffffffffLL) PSNIP_SAFE_DEFINE_SIGNED_SUB( - psnip_int64_t, + Psnip_int64_t, int64, (-0x7fffffffffffffffLL - 1), 0x7fffffffffffffffLL) PSNIP_SAFE_DEFINE_SIGNED_MUL( - psnip_int64_t, + Psnip_int64_t, int64, (-0x7fffffffffffffffLL - 1), 0x7fffffffffffffffLL) #endif PSNIP_SAFE_DEFINE_SIGNED_DIV( - psnip_int64_t, + Psnip_int64_t, int64, (-0x7fffffffffffffffLL - 1), 0x7fffffffffffffffLL) PSNIP_SAFE_DEFINE_SIGNED_MOD( - psnip_int64_t, + Psnip_int64_t, int64, (-0x7fffffffffffffffLL - 1), 0x7fffffffffffffffLL) PSNIP_SAFE_DEFINE_SIGNED_NEG( - psnip_int64_t, + Psnip_int64_t, int64, (-0x7fffffffffffffffLL - 1), 0x7fffffffffffffffLL) #if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) -PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint64_t, uint64, add) -PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint64_t, uint64, sub) -PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint64_t, uint64, mul) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(Psnip_uint64_t, uint64, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(Psnip_uint64_t, uint64, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(Psnip_uint64_t, uint64, mul) #elif defined(PSNIP_SAFE_HAVE_INTSAFE_H) && defined(_WIN32) -PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint64_t, uint64, add, ULongLongAdd) -PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint64_t, uint64, sub, ULongLongSub) -PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint64_t, uint64, mul, ULongLongMult) +PSNIP_SAFE_DEFINE_INTSAFE(Psnip_uint64_t, uint64, add, ULongLongAdd) +PSNIP_SAFE_DEFINE_INTSAFE(Psnip_uint64_t, uint64, sub, ULongLongSub) +PSNIP_SAFE_DEFINE_INTSAFE(Psnip_uint64_t, uint64, mul, ULongLongMult) #elif defined(PSNIP_SAFE_HAVE_LARGER_UINT64) PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP( - psnip_uint64_t, + Psnip_uint64_t, uint64, add, 0xffffffffffffffffULL) PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP( - psnip_uint64_t, + Psnip_uint64_t, uint64, sub, 0xffffffffffffffffULL) PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP( - psnip_uint64_t, + Psnip_uint64_t, uint64, mul, 0xffffffffffffffffULL) #else -PSNIP_SAFE_DEFINE_UNSIGNED_ADD(psnip_uint64_t, uint64, 0xffffffffffffffffULL) -PSNIP_SAFE_DEFINE_UNSIGNED_SUB(psnip_uint64_t, uint64, 0xffffffffffffffffULL) -PSNIP_SAFE_DEFINE_UNSIGNED_MUL(psnip_uint64_t, uint64, 0xffffffffffffffffULL) +PSNIP_SAFE_DEFINE_UNSIGNED_ADD(Psnip_uint64_t, uint64, 0xffffffffffffffffULL) +PSNIP_SAFE_DEFINE_UNSIGNED_SUB(Psnip_uint64_t, uint64, 0xffffffffffffffffULL) +PSNIP_SAFE_DEFINE_UNSIGNED_MUL(Psnip_uint64_t, uint64, 0xffffffffffffffffULL) #endif -PSNIP_SAFE_DEFINE_UNSIGNED_DIV(psnip_uint64_t, uint64, 0xffffffffffffffffULL) -PSNIP_SAFE_DEFINE_UNSIGNED_MOD(psnip_uint64_t, uint64, 0xffffffffffffffffULL) +PSNIP_SAFE_DEFINE_UNSIGNED_DIV(Psnip_uint64_t, uint64, 0xffffffffffffffffULL) +PSNIP_SAFE_DEFINE_UNSIGNED_MOD(Psnip_uint64_t, uint64, 0xffffffffffffffffULL) #endif /* !defined(PSNIP_SAFE_NO_FIXED) */ #define PSNIP_SAFE_C11_GENERIC_SELECTION(res, op) \ - _Generic( \ + generic( \ (*res), \ - char: psnip_safe_char_##op, \ - unsigned char: psnip_safe_uchar_##op, \ - short: psnip_safe_short_##op, \ - unsigned short: psnip_safe_ushort_##op, \ - int: psnip_safe_int_##op, \ - unsigned int: psnip_safe_uint_##op, \ - long: psnip_safe_long_##op, \ - unsigned long: psnip_safe_ulong_##op, \ - long long: psnip_safe_llong_##op, \ - unsigned long long: psnip_safe_ullong_##op) + char : psnipSafeChar_##op, \ + unsigned char : psnipSafeUchar_##op, \ + short : psnipSafeShort_##op, \ + unsigned short : psnipSafeUshort_##op, \ + int : psnipSafeInt_##op, \ + unsigned int : psnipSafeUint_##op, \ + long : psnipSafeLong_##op, \ + unsigned long : psnipSafeUlong_##op, \ + long long : psnipSafeLlong_##op, \ + unsigned long long : psnipSafeUllong_##op) #define PSNIP_SAFE_C11_GENERIC_BINARY_OP(op, res, a, b) \ PSNIP_SAFE_C11_GENERIC_SELECTION(res, op)(res, a, b) @@ -1332,12 +1329,12 @@ PSNIP_SAFE_DEFINE_UNSIGNED_MOD(psnip_uint64_t, uint64, 0xffffffffffffffffULL) PSNIP_SAFE_C11_GENERIC_SELECTION(res, op)(res, v) #if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) -#define psnip_safe_add(res, a, b) (!__builtin_add_overflow(a, b, res)) -#define psnip_safe_sub(res, a, b) (!__builtin_sub_overflow(a, b, res)) -#define psnip_safe_mul(res, a, b) (!__builtin_mul_overflow(a, b, res)) -#define psnip_safe_div(res, a, b) (!__builtin_div_overflow(a, b, res)) -#define psnip_safe_mod(res, a, b) (!__builtin_mod_overflow(a, b, res)) -#define psnip_safe_neg(res, v) PSNIP_SAFE_C11_GENERIC_UNARY_OP(neg, res, v) +#define psnipSafeAdd(res, a, b) (!__builtin_add_overflow(a, b, res)) +#define psnipSafeSub(res, a, b) (!__builtin_sub_overflow(a, b, res)) +#define psnipSafeMul(res, a, b) (!__builtin_mul_overflow(a, b, res)) +#define psnipSafeDiv(res, a, b) (!__builtin_div_overflow(a, b, res)) +#define psnipSafeMod(res, a, b) (!__builtin_mod_overflow(a, b, res)) +#define psnipSafeNeg(res, v) PSNIP_SAFE_C11_GENERIC_UNARY_OP(neg, res, v) #elif defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L) /* The are no fixed-length or size selections because they cause an @@ -1345,42 +1342,37 @@ PSNIP_SAFE_DEFINE_UNSIGNED_MOD(psnip_uint64_t, uint64, 0xffffffffffffffffULL) * this doesn't cause problems on exotic platforms, but if it does * please let me know and I'll try to figure something out. */ -#define psnip_safe_add(res, a, b) \ - PSNIP_SAFE_C11_GENERIC_BINARY_OP(add, res, a, b) -#define psnip_safe_sub(res, a, b) \ - PSNIP_SAFE_C11_GENERIC_BINARY_OP(sub, res, a, b) -#define psnip_safe_mul(res, a, b) \ - PSNIP_SAFE_C11_GENERIC_BINARY_OP(mul, res, a, b) -#define psnip_safe_div(res, a, b) \ - PSNIP_SAFE_C11_GENERIC_BINARY_OP(div, res, a, b) -#define psnip_safe_mod(res, a, b) \ - PSNIP_SAFE_C11_GENERIC_BINARY_OP(mod, res, a, b) -#define psnip_safe_neg(res, v) PSNIP_SAFE_C11_GENERIC_UNARY_OP(neg, res, v) +#define psnipSafeAdd(res, a, b) PSNIP_SAFE_C11_GENERIC_BINARY_OP(add, res, a, b) +#define psnipSafeSub(res, a, b) PSNIP_SAFE_C11_GENERIC_BINARY_OP(sub, res, a, b) +#define psnipSafeMul(res, a, b) PSNIP_SAFE_C11_GENERIC_BINARY_OP(mul, res, a, b) +#define psnipSafeDiv(res, a, b) PSNIP_SAFE_C11_GENERIC_BINARY_OP(div, res, a, b) +#define psnipSafeMod(res, a, b) PSNIP_SAFE_C11_GENERIC_BINARY_OP(mod, res, a, b) +#define psnipSafeNeg(res, v) PSNIP_SAFE_C11_GENERIC_UNARY_OP(neg, res, v) #endif #if !defined(PSNIP_SAFE_HAVE_BUILTINS) && \ (defined(PSNIP_SAFE_EMULATE_NATIVE) || \ defined(PSNIP_BUILTIN_EMULATE_NATIVE)) -#define __builtin_sadd_overflow(a, b, res) (!psnip_safe_int_add(res, a, b)) -#define __builtin_saddl_overflow(a, b, res) (!psnip_safe_long_add(res, a, b)) -#define __builtin_saddll_overflow(a, b, res) (!psnip_safe_llong_add(res, a, b)) -#define __builtin_uadd_overflow(a, b, res) (!psnip_safe_uint_add(res, a, b)) -#define __builtin_uaddl_overflow(a, b, res) (!psnip_safe_ulong_add(res, a, b)) -#define __builtin_uaddll_overflow(a, b, res) (!psnip_safe_ullong_add(res, a, b)) - -#define __builtin_ssub_overflow(a, b, res) (!psnip_safe_int_sub(res, a, b)) -#define __builtin_ssubl_overflow(a, b, res) (!psnip_safe_long_sub(res, a, b)) -#define __builtin_ssubll_overflow(a, b, res) (!psnip_safe_llong_sub(res, a, b)) -#define __builtin_usub_overflow(a, b, res) (!psnip_safe_uint_sub(res, a, b)) -#define __builtin_usubl_overflow(a, b, res) (!psnip_safe_ulong_sub(res, a, b)) -#define __builtin_usubll_overflow(a, b, res) (!psnip_safe_ullong_sub(res, a, b)) - -#define __builtin_smul_overflow(a, b, res) (!psnip_safe_int_mul(res, a, b)) -#define __builtin_smull_overflow(a, b, res) (!psnip_safe_long_mul(res, a, b)) -#define __builtin_smulll_overflow(a, b, res) (!psnip_safe_llong_mul(res, a, b)) -#define __builtin_umul_overflow(a, b, res) (!psnip_safe_uint_mul(res, a, b)) -#define __builtin_umull_overflow(a, b, res) (!psnip_safe_ulong_mul(res, a, b)) -#define __builtin_umulll_overflow(a, b, res) (!psnip_safe_ullong_mul(res, a, b)) +#define __builtin_sadd_overflow(a, b, res) (!psnipSafeIntAdd(res, a, b)) +#define __builtin_saddl_overflow(a, b, res) (!psnipSafeLongAdd(res, a, b)) +#define __builtin_saddll_overflow(a, b, res) (!psnipSafeLlongAdd(res, a, b)) +#define __builtin_uadd_overflow(a, b, res) (!psnipSafeUintAdd(res, a, b)) +#define __builtin_uaddl_overflow(a, b, res) (!psnipSafeUlongAdd(res, a, b)) +#define __builtin_uaddll_overflow(a, b, res) (!psnipSafeUllongAdd(res, a, b)) + +#define __builtin_ssub_overflow(a, b, res) (!psnipSafeIntSub(res, a, b)) +#define __builtin_ssubl_overflow(a, b, res) (!psnipSafeLongSub(res, a, b)) +#define __builtin_ssubll_overflow(a, b, res) (!psnipSafeLlongSub(res, a, b)) +#define __builtin_usub_overflow(a, b, res) (!psnipSafeUintSub(res, a, b)) +#define __builtin_usubl_overflow(a, b, res) (!psnipSafeUlongSub(res, a, b)) +#define __builtin_usubll_overflow(a, b, res) (!psnipSafeUllongSub(res, a, b)) + +#define __builtin_smul_overflow(a, b, res) (!psnipSafeIntMul(res, a, b)) +#define __builtin_smull_overflow(a, b, res) (!psnipSafeLongMul(res, a, b)) +#define __builtin_smulll_overflow(a, b, res) (!psnipSafeLlongMul(res, a, b)) +#define __builtin_umul_overflow(a, b, res) (!psnipSafeUintMul(res, a, b)) +#define __builtin_umull_overflow(a, b, res) (!psnipSafeUlongMul(res, a, b)) +#define __builtin_umulll_overflow(a, b, res) (!psnipSafeUllongMul(res, a, b)) #endif #endif /* !defined(PSNIP_SAFE_H) */ diff --git a/velox/dwio/text/CMakeLists.txt b/velox/dwio/text/CMakeLists.txt index d11825c4e30..040e3ca05d0 100644 --- a/velox/dwio/text/CMakeLists.txt +++ b/velox/dwio/text/CMakeLists.txt @@ -18,7 +18,13 @@ endif() add_subdirectory(writer) -velox_add_library(velox_dwio_text_writer_register RegisterTextWriter.cpp) +velox_add_library( + velox_dwio_text_writer_register + RegisterTextWriter.cpp + HEADERS + RegisterTextReader.h + RegisterTextWriter.h +) velox_link_libraries(velox_dwio_text_writer_register velox_dwio_text_writer) diff --git a/velox/dwio/text/reader/CMakeLists.txt b/velox/dwio/text/reader/CMakeLists.txt index d65f1ffbc22..7910a383b7e 100644 --- a/velox/dwio/text/reader/CMakeLists.txt +++ b/velox/dwio/text/reader/CMakeLists.txt @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -velox_add_library(velox_dwio_text_reader TextReader.cpp) +velox_add_library(velox_dwio_text_reader TextReader.cpp HEADERS TextReader.h) velox_link_libraries( velox_dwio_text_reader diff --git a/velox/dwio/text/reader/TextReader.cpp b/velox/dwio/text/reader/TextReader.cpp index a157144ea46..d093e377eef 100644 --- a/velox/dwio/text/reader/TextReader.cpp +++ b/velox/dwio/text/reader/TextReader.cpp @@ -15,13 +15,16 @@ */ #include "velox/dwio/text/reader/TextReader.h" + +#include +#include + #include "velox/common/encode/Base64.h" #include "velox/dwio/common/exception/Exceptions.h" #include "velox/type/fbhive/HiveTypeParser.h" -#include - namespace facebook::velox::text { +namespace { using common::CompressionKind; @@ -29,16 +32,18 @@ using dwio::common::EOFError; using dwio::common::RowReader; using dwio::common::verify; -using folly::AsciiCaseInsensitive; -using folly::StringPiece; - -constexpr const char* kTextfileCompressionExtensionGzip = ".gz"; -constexpr const char* kTextfileCompressionExtensionDeflate = ".deflate"; -constexpr const char* kTextfileCompressionExtensionZst = ".zst"; +static constexpr std::string_view kTextfileCompressionExtensionGzip{".gz"}; +static constexpr std::string_view kTextfileCompressionExtensionDeflate{ + ".deflate"}; +static constexpr std::string_view kTextfileCompressionExtensionZst{".zst"}; +static constexpr std::string_view kTextfileCompressionExtensionLz4{".lz4"}; +static constexpr std::string_view kTextfileCompressionExtensionLzo{".lzo"}; +static constexpr std::string_view kTextfileCompressionExtensionSnappy{ + ".snappy"}; static std::string emptyString = std::string(); -namespace { +constexpr const int32_t kDecompressionBufferFactor = 3; void resizeVector( BaseVector* FOLLY_NULLABLE data, @@ -91,24 +96,24 @@ void resizeVector( } } -bool endsWith(const std::string& str, const std::string& suffix) { - return str.size() >= suffix.size() && - str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0; -} - void setCompressionSettings( const std::string& filename, CompressionKind& kind, dwio::common::compression::CompressionOptions& compressionOptions) { - if (endsWith(filename, kTextfileCompressionExtensionGzip)) { + if (filename.ends_with(kTextfileCompressionExtensionLz4) || + filename.ends_with(kTextfileCompressionExtensionLzo) || + filename.ends_with(kTextfileCompressionExtensionSnappy)) { + VELOX_FAIL("Unsupported compression extension for file: {}", filename); + } + if (filename.ends_with(kTextfileCompressionExtensionGzip)) { kind = CompressionKind::CompressionKind_GZIP; compressionOptions.format.zlib.windowBits = 15; // 2^15-byte deflate window size - } else if (endsWith(filename, kTextfileCompressionExtensionDeflate)) { + } else if (filename.ends_with(kTextfileCompressionExtensionDeflate)) { kind = CompressionKind::CompressionKind_ZLIB; compressionOptions.format.zlib.windowBits = -15; // raw deflate, 2^15-byte window size - } else if (endsWith(filename, kTextfileCompressionExtensionZst)) { + } else if (filename.ends_with(kTextfileCompressionExtensionZst)) { kind = CompressionKind::CompressionKind_ZSTD; } else { kind = CompressionKind::CompressionKind_NONE; @@ -156,20 +161,17 @@ TextRowReader::TextRowReader( std::make_shared>(contents_->pool)} { // Seek to first line at or after the specified region. if (contents_->compression == CompressionKind::CompressionKind_NONE) { - /** - * TODO: Inconsistent row skipping behavior (kept for Presto compatibility) - * - * Issue: When reading from byte offset > 0, we skip rows inclusively at the - * start position, but when reading from byte 0, no rows are skipped. This - * creates inconsistent behavior where a row at the boundary may be skipped - * when it should be included. - * - * Example: If pos_ = 10 is the first byte of row 2, that entire row gets - * skipped, even though it should be read. - * - * Proposed fix: streamPosition_ = (pos_ == 0) ? 0 : --pos_; - * This would skip rows exclusively of pos_, ensuring consistent behavior. - */ + // TODO: Inconsistent row skipping behavior (kept for Presto compatibility) + // Issue: When reading from byte offset > 0, we skip rows inclusively at the + // start position, but when reading from byte 0, no rows are skipped. This + // creates inconsistent behavior where a row at the boundary may be skipped + // when it should be included. + // + // Example: If pos_ = 10 is the first byte of row 2, that entire row gets + // skipped, even though it should be read. + // + // Proposed fix: streamPosition_ = (pos_ == 0) ? 0 : --pos_; + // This would skip rows exclusively of pos_, ensuring consistent behavior. const auto streamPosition_ = pos_; contents_->inputStream = contents_->input->read( @@ -191,34 +193,15 @@ TextRowReader::TextRowReader( } limit_ = std::numeric_limits::max(); - /** - * The output buffer for decompression is allocated based on the - * uncompressed length of the stream. - * - * For decompressors other than ZlibDecompressor, the uncompressed length is - * obtained via getDecompressedLength, and blockSize serves only as a - * fallbak when getDecompressedLength fails to return a valid length. - * - * ZlibDecompressor does not implement getDecompressedLength because the - * DEFLATE algorithm used by zlib does not inherently includes the - * uncompressed length in the compressed stream. As a result, blockSize is - * used to set z_stream.avail_out during decompression to ensure enough - * buffer allocated for the output. Since zlib requires avail_out to be a - * uInt (unsigned int), blockSize is set to std::numeric_limits::max() for full compatibility. - */ - const auto blockSize = - (contents_->compression == CompressionKind::CompressionKind_ZLIB || - contents_->compression == CompressionKind::CompressionKind_GZIP) - ? std::numeric_limits::max() - : std::numeric_limits::max(); - contents_->inputStream = contents_->input->loadCompleteFile(); auto name = contents_->inputStream->getName(); contents_->decompressedInputStream = createDecompressor( contents_->compression, std::move(contents_->inputStream), - blockSize, + // An estimated value used as the output buffer size for the zlib + // decompressor, and as the fallback value of the decompressed length + // for other decompressors. + kDecompressionBufferFactor * contents_->fileLength, contents_->pool, contents_->compressionOptions, fmt::format("Text Reader: Stream {}", name), @@ -274,34 +257,27 @@ uint64_t TextRowReader::next( DelimType delim = DelimTypeNone; const auto& ct = t->childAt(i); const auto& rct = reqT->childAt(i); - auto childVector = rowVecPtr->childAt(i).get(); + BaseVector* childVector = nullptr; if (isSelectedField(ct)) { + childVector = rowVecPtr->childAt(i).get(); ++colIndex; } else if (colIndex < reqChildCount && !projectSelectedType) { - // not selected and not projecting: set to null - if (childVector != nullptr) { - rowVecPtr->setNull(i, true); - childVector = nullptr; - } + // Not selected and not projecting: discard the child by setting it to + // nullptr. The projectColumns() function will later filter out unneeded + // columns based on the ScanSpec. + rowVecPtr->childAt(i) = nullptr; ++colIndex; } else { - // not selected and projecting: just discard the field - childVector = nullptr; + // Not selected and projecting: discard the child. Same reasoning as + // above. + rowVecPtr->childAt(i) = nullptr; } resizeVector(childVector, rowsRead); readElement(ct->type(), rct->type(), childVector, rowsRead, delim); } - // set null property - for (uint64_t i = colIndex; i < reqChildCount; i++) { - auto childVector = rowVecPtr->childAt(i).get(); - - if (childVector != nullptr) { - rowVecPtr->setNull(static_cast(i), true); - } - } (void)skipLine(); ++currentRow_; ++rowsRead; @@ -504,11 +480,9 @@ TextRowReader::getString(TextRowReader& th, bool& isNull, DelimType& delim) { bool wasEscaped = false; th.ownedString_.clear(); - /** - Processing has to be done character by characater instad of chunk by chunk. - This is to avoid edge case handling if escape character(s) are cut off at - the end of the chunk. - */ + // Processing has to be done character by characater instad of chunk by chunk. + // This is to avoid edge case handling if escape character(s) are cut off at + // the end of the chunk. while (true) { auto v = th.getByteOptimized(delim); if (!th.isNone(delim)) { @@ -559,16 +533,22 @@ TextRowReader::getString(TextRowReader& th, bool& isNull, DelimType& delim) { return th.ownedString_; } -uint8_t TextRowReader::getByte(DelimType& delim) { - setNone(delim); - auto v = getByteUnchecked(delim); - if (isNone(delim)) { - if (v == '\r') { - v = getByteUnchecked(delim); // always returns '\n' in this case - } - delim = getDelimType(v); +template +void TextRowReader::setValueFromString( + const std::string& str, + BaseVector* data, + vector_size_t insertionRow, + std::function(const std::string&)> convert) { + if ((atEOF_ && atSOL_) || data == nullptr) { + return; + } + auto flatVector = data->asChecked>(); + auto result = str.empty() ? std::nullopt : convert(str); + if (result) { + flatVector->set(insertionRow, *result); + } else { + flatVector->setNull(insertionRow, true); } - return v; } uint8_t TextRowReader::getByteOptimized(DelimType& delim) { @@ -612,48 +592,6 @@ DelimType TextRowReader::getDelimType(uint8_t v) { return delim; } -template -char TextRowReader::getByteUnchecked(DelimType& delim) { - if (atEOL_) { - if (!skipLF) { - delim = DelimTypeEOR; // top level EOR - } - return '\n'; - } - - try { - char v; - if (!unreadData_.empty()) { - v = unreadData_[0]; - unreadData_.erase(0, 1); - } else { - contents_->inputStream->readFully(&v, 1); - } - pos_++; - - // only when previous char == '\r' - if (skipLF) { - if (v != '\n') { - pos_--; - return '\n'; - } - } else { - atSOL_ = false; - } - return v; - } catch (EOFError&) { - } catch (std::runtime_error& e) { - if (std::string(e.what()).find("Short read of") != 0 && !skipLF) { - throw; - } - } - if (!skipLF) { - setEOF(); - delim = DelimTypeEOR; - } - return '\n'; -} - template char TextRowReader::getByteUncheckedOptimized(DelimType& delim) { if (atEOL_) { @@ -854,8 +792,8 @@ T TextRowReader::getInteger(TextRowReader& th, bool& isNull, DelimType& delim) { namespace { -static const StringView trueStringView = StringView{"TRUE"}; -static const StringView falseStringView = StringView{"FALSE"}; +static constexpr std::string_view kTrueStringView{"TRUE"}; +static constexpr std::string_view kFalseStringView{"FALSE"}; } // namespace @@ -870,21 +808,21 @@ bool TextRowReader::getBoolean( if (isNull) { return false; } - if (str.compare(trueStringView) == 0) { + if (str.compare(kTrueStringView) == 0) { return true; } - if (str.compare(falseStringView) == 0) { + if (str.compare(kFalseStringView) == 0) { return false; } switch (str.size()) { case 4: - if (StringPiece(str).equals("TRUE", AsciiCaseInsensitive())) { + if (boost::algorithm::iequals(str, kTrueStringView)) { return true; } break; case 5: - if (StringPiece(str).equals("FALSE", AsciiCaseInsensitive())) { + if (boost::algorithm::iequals(str, kFalseStringView)) { return false; } break; @@ -898,11 +836,11 @@ bool TextRowReader::getBoolean( namespace { -static const StringView NaNStringView = StringView{"NaN"}; -static const StringView InfinityStringView = StringView{"Infinity"}; -static const StringView ShortInfinityStringView = StringView{"Inf"}; -static const StringView NegInfinityStringView = StringView{"-Infinity"}; -static const StringView ShortNegInfinityStringView = StringView{"-Inf"}; +static constexpr std::string_view kNaNStringView{"NaN"}; +static constexpr std::string_view kInfinityStringView{"Infinity"}; +static constexpr std::string_view kShortInfinityStringView{"Inf"}; +static constexpr std::string_view kNegInfinityStringView{"-Infinity"}; +static constexpr std::string_view kShortNegInfinityStringView{"-Inf"}; bool unacceptableFloatingPoint(std::string& s) { for (int i = 0; i < s.size(); ++i) { @@ -912,18 +850,14 @@ bool unacceptableFloatingPoint(std::string& s) { } } - bool isNaN = - StringPiece(s).equals(StringPiece(NaNStringView), AsciiCaseInsensitive()); + bool isNaN = boost::algorithm::iequals(s, kNaNStringView); - bool isInf = StringPiece(s).equals( - StringPiece(InfinityStringView), AsciiCaseInsensitive()); - bool isShortInf = StringPiece(s).equals( - StringPiece(ShortInfinityStringView), AsciiCaseInsensitive()); + bool isInf = boost::algorithm::iequals(s, kInfinityStringView); + bool isShortInf = boost::algorithm::iequals(s, kShortInfinityStringView); - bool isNegInf = StringPiece(s).equals( - StringPiece(NegInfinityStringView), AsciiCaseInsensitive()); - bool isShortNegInf = StringPiece(s).equals( - StringPiece(ShortNegInfinityStringView), AsciiCaseInsensitive()); + bool isNegInf = boost::algorithm::iequals(s, kNegInfinityStringView); + bool isShortNegInf = + boost::algorithm::iequals(s, kShortNegInfinityStringView); return (!isNaN && !isInf && !isShortInf && !isNegInf && !isShortNegInf); } @@ -1052,8 +986,19 @@ void TextRowReader::readElement( getInteger, data, insertionRow, delim); break; case TypeKind::INTEGER: - putValue( - getInteger, data, insertionRow, delim); + if (reqT->isDate()) { + const std::string& str = getString(*this, isNull, delim); + setValueFromString( + str, + data, + insertionRow, + [](const std::string& s) -> std::optional { + return DATE()->toDays(s); + }); + } else { + putValue( + getInteger, data, insertionRow, delim); + } break; default: VELOX_FAIL( @@ -1065,10 +1010,61 @@ void TextRowReader::readElement( break; case TypeKind::BIGINT: - putValue( - getInteger, data, insertionRow, delim); + if (reqT->isShortDecimal()) { + const std::string& str = getString(*this, isNull, delim); + auto decimalParams = getDecimalPrecisionScale(*reqT); + const auto precision = decimalParams.first; + const auto scale = decimalParams.second; + setValueFromString( + str, + data, + insertionRow, + [precision, scale](const std::string& s) -> std::optional { + int64_t v = 0; + const auto status = DecimalUtil::castFromString( + StringView(s.data(), static_cast(s.size())), + precision, + scale, + v); + return status.ok() ? std::optional(v) : std::nullopt; + }); + } else { + putValue( + getInteger, data, insertionRow, delim); + } break; + case TypeKind::HUGEINT: { + const std::string& str = getString(*this, isNull, delim); + if (reqT->isLongDecimal()) { + auto decimalParams = getDecimalPrecisionScale(*reqT); + const auto precision = decimalParams.first; + const auto scale = decimalParams.second; + setValueFromString( + str, + data, + insertionRow, + [precision, + scale](const std::string& s) -> std::optional { + int128_t v = 0; + const auto status = DecimalUtil::castFromString( + StringView(s.data(), static_cast(s.size())), + precision, + scale, + v); + return status.ok() ? std::optional(v) : std::nullopt; + }); + } else { + setValueFromString( + str, + data, + insertionRow, + [](const std::string& s) -> std::optional { + return HugeInt::parse(s); + }); + } + break; + } case TypeKind::SMALLINT: switch (reqT->kind()) { case TypeKind::BIGINT: @@ -1639,17 +1635,4 @@ uint64_t TextReader::getFileLength() const { return contents_->fileLength; } -uint64_t TextReader::getMemoryUse() { - uint64_t memory = std::min( - uint64_t(contents_->fileLength), - contents_->input->getInputStream()->getNaturalReadSize()); - - // Decompressor needs a buffer. - if (contents_->compression != CompressionKind::CompressionKind_NONE) { - memory *= 3; - } - - return memory; -} - } // namespace facebook::velox::text diff --git a/velox/dwio/text/reader/TextReader.h b/velox/dwio/text/reader/TextReader.h index e1563508705..435de81c35a 100644 --- a/velox/dwio/text/reader/TextReader.h +++ b/velox/dwio/text/reader/TextReader.h @@ -85,8 +85,6 @@ class TextReader : public dwio::common::Reader { uint64_t getFileLength() const; - uint64_t getMemoryUse(); - private: ReaderOptions options_; mutable std::shared_ptr typeWithId_; @@ -206,6 +204,13 @@ class TextRowReader : public dwio::common::RowReader { vector_size_t insertionRow, DelimType& delim); + template + void setValueFromString( + const std::string& str, + BaseVector* FOLLY_NULLABLE data, + vector_size_t insertionRow, + std::function(const std::string&)> convert); + const std::shared_ptr contents_; const std::shared_ptr schemaWithId_; const std::shared_ptr& scanSpec_; diff --git a/velox/dwio/text/tests/CMakeLists.txt b/velox/dwio/text/tests/CMakeLists.txt index 032c209e75a..b1975b85607 100644 --- a/velox/dwio/text/tests/CMakeLists.txt +++ b/velox/dwio/text/tests/CMakeLists.txt @@ -17,7 +17,7 @@ set( velox_dwio_common_test_utils velox_vector_test_lib velox_exec_test_lib - velox_temp_path + velox_test_util GTest::gtest GTest::gtest_main GTest::gmock diff --git a/velox/dwio/text/tests/reader/TextReaderTest.cpp b/velox/dwio/text/tests/reader/TextReaderTest.cpp index b1b3bbda23e..b6c559087b4 100644 --- a/velox/dwio/text/tests/reader/TextReaderTest.cpp +++ b/velox/dwio/text/tests/reader/TextReaderTest.cpp @@ -14,6 +14,10 @@ * limitations under the License. */ +#include "velox/common/file/File.h" +#include "velox/common/io/IoStatistics.h" +#include "velox/common/testutil/TempFilePath.h" +#include "velox/connectors/hive/ExtractionUtils.h" #include "velox/dwio/common/tests/utils/DataFiles.h" #include "velox/dwio/text/RegisterTextReader.h" #include "velox/vector/tests/utils/VectorTestBase.h" @@ -23,11 +27,16 @@ extern long timezone; using namespace facebook::velox; using namespace facebook::velox::test; +using facebook::velox::common::testutil::TempFilePath; namespace facebook::velox::text { namespace { +int32_t parseDate(const std::string& text) { + return DATE()->toDays(text); +} + class TextReaderTest : public testing::Test, public velox::test::VectorTestBase { protected: @@ -53,6 +62,11 @@ class TextReaderTest : public testing::Test, options.setScanSpec(spec); } + std::shared_ptr dataIoStats_ = + std::make_shared(); + std::shared_ptr metadataIoStats_ = + std::make_shared(); + private: std::shared_ptr readFile_; }; @@ -122,7 +136,9 @@ TEST_F(TextReaderTest, basic) { "examples/simple_types_compressed_file.gz"); auto readFile = std::make_shared(path); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(type); auto input = @@ -232,7 +248,9 @@ TEST_F(TextReaderTest, headerAndCustomNullString) { "velox/dwio/text/tests/reader/", "examples/simple_types_with_header"); auto readFile = std::make_shared(path); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(type); auto rowReaderOptions = dwio::common::RowReaderOptions(); setScanSpec(*type, rowReaderOptions); @@ -406,7 +424,9 @@ TEST_F(TextReaderTest, complexTypesWithCustomDelimiters) { auto readFile = std::make_shared(path); auto serDeOptions = dwio::common::SerDeOptions('\t', '|', '#', '\\', true); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(type); readerOptions.setSerDeOptions(serDeOptions); @@ -471,7 +491,9 @@ TEST_F(TextReaderTest, projectComplexTypesWithCustomDelimiters) { auto readFile = std::make_shared(path); auto serDeOptions = dwio::common::SerDeOptions('\t', '|', '#', '\\', true); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(type); readerOptions.setSerDeOptions(serDeOptions); @@ -487,8 +509,9 @@ TEST_F(TextReaderTest, projectComplexTypesWithCustomDelimiters) { dwio::common::RowReaderOptions rowOptions; rowOptions.setScanSpec(spec); - rowOptions.select(std::make_shared( - type, std::vector({"col_string", "col_map"}))); + rowOptions.select( + std::make_shared( + type, std::vector({"col_string", "col_map"}))); auto rowReader = reader->createRowReader(rowOptions); VectorPtr result; @@ -573,7 +596,9 @@ TEST_F(TextReaderTest, projectPrimitiveTypes) { auto path = velox::test::getDataFilePath( "velox/dwio/text/tests/reader/", "examples/simple_types"); auto readFile = std::make_shared(path); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(type); auto input = std::make_unique(readFile, poolRef()); @@ -587,8 +612,10 @@ TEST_F(TextReaderTest, projectPrimitiveTypes) { dwio::common::RowReaderOptions rowOptions; rowOptions.setScanSpec(spec); - rowOptions.select(std::make_shared( - type, std::vector({"col_tiny", "col_int", "col_double"}))); + rowOptions.select( + std::make_shared( + type, + std::vector({"col_tiny", "col_int", "col_double"}))); auto rowReader = reader->createRowReader(rowOptions); VectorPtr result; @@ -639,7 +666,9 @@ TEST_F(TextReaderTest, projectColumns) { "velox/dwio/text/tests/reader/", "examples/simple_types_compressed_file.gz"); auto readFile = std::make_shared(path); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(type); auto input = std::make_unique(readFile, poolRef()); @@ -650,8 +679,9 @@ TEST_F(TextReaderTest, projectColumns) { spec->addField("col_float", 1); dwio::common::RowReaderOptions rowOptions; rowOptions.setScanSpec(spec); - rowOptions.select(std::make_shared( - type, std::vector({"col_float"}))); + rowOptions.select( + std::make_shared( + type, std::vector({"col_float"}))); auto rowReader = reader->createRowReader(rowOptions); VectorPtr result; ASSERT_EQ(rowReader->next(10, result), 10); @@ -693,7 +723,9 @@ TEST_F(TextReaderTest, projectNone) { "velox/dwio/text/tests/reader/", "examples/simple_types"); auto readFile = std::make_shared(path); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(type); dwio::common::RowReaderOptions rowReaderOptions; @@ -728,7 +760,9 @@ TEST_F(TextReaderTest, compressedProjectNone) { "examples/simple_types_compressed_file.gz"); auto readFile = std::make_shared(path); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(type); dwio::common::RowReaderOptions rowReaderOptions; @@ -759,7 +793,9 @@ TEST_F(TextReaderTest, compressedFilter) { "velox/dwio/text/tests/reader/", "examples/simple_types_compressed_file.gz"); auto readFile = std::make_shared(path); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(type); auto input = std::make_unique(readFile, poolRef()); @@ -769,8 +805,9 @@ TEST_F(TextReaderTest, compressedFilter) { BaseVector::createConstant(VARCHAR(), "2023-07-18", 1, pool())); spec->addField("col_int", 1); spec->getOrCreateChild(common::Subfield("col_string")) - ->setFilter(std::make_unique( - std::vector({"BAR"}), false)); + ->setFilter( + std::make_unique( + std::vector({"BAR"}), false)); dwio::common::RowReaderOptions rowOptions; rowOptions.setScanSpec(spec); rowOptions.select( @@ -802,7 +839,9 @@ TEST_F(TextReaderTest, filter) { "velox/dwio/text/tests/reader/", "examples/more_simple_types"); auto readFile = std::make_shared(path); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(type); auto input = @@ -814,8 +853,9 @@ TEST_F(TextReaderTest, filter) { BaseVector::createConstant(VARCHAR(), "2023-07-18", 1, pool())); spec->addField("col_big_int", 1); spec->getOrCreateChild(common::Subfield("col_string")) - ->setFilter(std::make_unique( - std::vector({"BAR", "BAZ"}), false)); + ->setFilter( + std::make_unique( + std::vector({"BAR", "BAZ"}), false)); dwio::common::RowReaderOptions rowOptions; rowOptions.setScanSpec(spec); @@ -859,7 +899,9 @@ TEST_F(TextReaderTest, shrinkBatch) { auto path = velox::test::getDataFilePath( "velox/dwio/text/tests/reader/", "examples/simple_types"); auto readFile = std::make_shared(path); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(type); auto input = std::make_unique(readFile, poolRef()); @@ -892,7 +934,9 @@ TEST_F(TextReaderTest, compressedShrinkBatch) { "velox/dwio/text/tests/reader/", "examples/simple_types_compressed_file.gz"); auto readFile = std::make_shared(path); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(type); auto input = std::make_unique(readFile, poolRef()); @@ -922,7 +966,9 @@ TEST_F(TextReaderTest, emptyFile) { auto path = velox::test::getDataFilePath( "velox/dwio/text/tests/reader/", "examples/empty.gz"); auto readFile = std::make_shared(path); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(type); auto rowReaderOptions = dwio::common::RowReaderOptions(); setScanSpec(*type, rowReaderOptions); @@ -969,7 +1015,9 @@ TEST_F(TextReaderTest, readRanges) { "examples/simple_types_10_bytes_per_row"); auto readFile = std::make_shared(path); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(type); auto input = @@ -1066,7 +1114,9 @@ TEST_F(TextReaderTest, readFloatAsInt) { "velox/dwio/text/tests/reader/", "examples/simple_types"); auto readFile = std::make_shared(path); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(type); auto input = @@ -1165,7 +1215,9 @@ TEST_F(TextReaderTest, simpleTypes) { "velox/dwio/text/tests/reader/", "examples/more_simple_types"); auto readFile = std::make_shared(path); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(type); auto input = @@ -1225,7 +1277,9 @@ TEST_F(TextReaderTest, primitiveLimitsStressTest) { "velox/dwio/text/tests/reader/", "examples/primitive_limits"); auto readFile = std::make_shared(path); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); auto serDeOptions = dwio::common::SerDeOptions('\t', '=', '|', '\\', true); readerOptions.setFileSchema(type); readerOptions.setSerDeOptions(serDeOptions); @@ -1383,7 +1437,9 @@ TEST_F(TextReaderTest, DISABLED_nestedComplexTypesWithCustomDelimiters) { auto serDeOptions = dwio::common::SerDeOptions('\t', '=', '|', '\\', true); serDeOptions.separators[3] = ','; serDeOptions.separators[4] = ':'; - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(type); readerOptions.setSerDeOptions(serDeOptions); @@ -1454,7 +1510,9 @@ TEST_F(TextReaderTest, nestedArraysWithCustomDelimiters) { // - Pipe ('|') for outer array element separation (depth 1) // - Comma (',') for inner array element separation (depth 2) auto serDeOptions = dwio::common::SerDeOptions('\t', '|', ',', '\\', true); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(type); readerOptions.setSerDeOptions(serDeOptions); @@ -1545,7 +1603,9 @@ TEST_F(TextReaderTest, tripleNestedArraysWithCustomDelimiters) { // - Hash ('#') for innermost array element separation (depth 3) auto serDeOptions = dwio::common::SerDeOptions('\t', '|', ',', '\\', true); serDeOptions.separators[3] = '#'; - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(type); readerOptions.setSerDeOptions(serDeOptions); @@ -1579,7 +1639,9 @@ TEST_F(TextReaderTest, varbinarySuccessfulDecoding) { auto readFile = std::make_shared(path); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(type); auto input = @@ -1614,7 +1676,9 @@ TEST_F(TextReaderTest, varbinaryUnsuccessfulDecoding) { auto readFile = std::make_shared(path); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(type); auto input = @@ -1639,6 +1703,97 @@ TEST_F(TextReaderTest, varbinaryUnsuccessfulDecoding) { EXPECT_EQ(binaryVector->valueAt(1), StringView("Another@Invalid#String")); } +TEST_F(TextReaderTest, logicalTypes) { + auto expected = makeRowVector( + {makeNullableFlatVector( + {0, + 123, + -1234567, + 999999999999999, + std::nullopt, + 4242, + -1, + std::nullopt, + 314159265358979, + 77777, + 100000000000000, + -5432199, + std::nullopt, + 1234, + -999999999999999, + 999999999999999}, + DECIMAL(15, 2)), + makeNullableFlatVector( + {0, + HugeInt::parse("999999999999999999999"), + HugeInt::parse("123456789012345678901234567890"), + HugeInt::parse("-99999999999999999999999999"), + HugeInt::parse("88888888888888888888"), + std::nullopt, + 1, + std::nullopt, + HugeInt::parse("27182818284590452353612"), + HugeInt::parse("-123456789012345678999"), + HugeInt::parse("12345678901234567890123456789012345678"), + 987654321012, + std::nullopt, + 5678, + -123, + HugeInt::parse("99999999999999999999999999999999")}, + DECIMAL(38, 2)), + makeNullableFlatVector( + { + parseDate("1970-01-01"), + parseDate("2024-02-29"), + parseDate("1900-01-01"), + parseDate("2099-12-31"), + parseDate("2001-09-11"), + parseDate("2025-09-10"), + std::nullopt, + std::nullopt, + parseDate("1999-12-31"), + parseDate("2012-12-21"), + parseDate("2200-01-01"), + parseDate("1988-08-08"), + parseDate("1969-07-20"), + parseDate("2000-01-01"), + parseDate("1800-06-15"), + parseDate("2500-12-31"), + }, + DATE())}); + + auto type = + ROW({{"c0", DECIMAL(15, 2)}, {"c1", DECIMAL(38, 2)}, {"c2", DATE()}}); + + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", "examples/logical_types.gz"); + + auto readFile = std::make_shared(path); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setFileSchema(type); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + dwio::common::RowReaderOptions rowReaderOptions; + setScanSpec(*type, rowReaderOptions); + auto rowReader = reader->createRowReader(rowReaderOptions); + EXPECT_EQ(*reader->rowType(), *type); + + VectorPtr result; + ASSERT_EQ(rowReader->next(10, result), 10); + for (int i = 0; i < 10; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, i)); + } + ASSERT_EQ(rowReader->next(10, result), 6); + for (int i = 0; i < 6; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 10 + i)); + } +} + TEST_F(TextReaderTest, nestedRows) { auto nestedRowChildren = std::vector{ makeFlatVector({42, 100, -5, 0, 999}), @@ -1674,7 +1829,9 @@ TEST_F(TextReaderTest, nestedRows) { auto readFile = std::make_shared(path); auto serDeOptions = dwio::common::SerDeOptions('&', ',', '#', '\\', true); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(type); readerOptions.setSerDeOptions(serDeOptions); @@ -1753,7 +1910,9 @@ TEST_P(TextReaderDecompressionTest, tests) { velox::test::getDataFilePath("velox/dwio/text/tests/reader/", filepath); auto readFile = std::make_shared(path); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(type); auto input = @@ -1811,6 +1970,316 @@ INSTANTIATE_TEST_SUITE_P( testing::ValuesIn(params), [](const auto& paramInfo) { return paramInfo.param.compression; }); +TEST_F(TextReaderTest, unsupportedCompressedKind) { + auto type = ROW( + {{"col_string", VARCHAR()}, + {"col_int", INTEGER()}, + {"col_float", DOUBLE()}, + {"col_bool", BOOLEAN()}}); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + const std::string kBaseDir = "velox/dwio/text/tests/reader/"; + std::vector paths = { + getDataFilePath(kBaseDir, "examples/simple_types_compressed_file.lz4"), + getDataFilePath(kBaseDir, "examples/simple_types_compressed_file.lzo"), + getDataFilePath( + kBaseDir, "examples/simple_types_compressed_file.snappy")}; + for (const auto& path : paths) { + auto readFile = std::make_shared(path); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setFileSchema(type); + auto input = + std::make_unique(readFile, poolRef()); + EXPECT_THROW( + factory->createReader(std::move(input), readerOptions), + VeloxRuntimeError); + } +} + +TEST_F(TextReaderTest, extractionMapKeys) { + // Read a text file with a MAP column and apply a MapKeys extraction transform + // via ScanSpec. The text reader uses RowReader::projectColumns() as + // fallback, which should apply the transform. + const auto type = ROW( + {{"col_string", VARCHAR()}, + {"col_bigint_arr", ARRAY(BIGINT())}, + {"col_double_arr", ARRAY(DOUBLE())}, + {"col_map", MAP(BIGINT(), BOOLEAN())}}); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", "examples/custom_delimiters_file"); + auto readFile = std::make_shared(path); + + auto serDeOptions = dwio::common::SerDeOptions('\t', '|', '#', '\\', true); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setFileSchema(type); + readerOptions.setSerDeOptions(serDeOptions); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + + // Build a ScanSpec that projects only col_map with a MapKeys extraction + // transform. + auto spec = std::make_shared("root"); + auto* mapSpec = spec->addField("col_map", 0); + mapSpec->addAllChildFields(*MAP(BIGINT(), BOOLEAN())); + + using connector::hive::applyExtractionChain; + using connector::hive::ExtractionPathElement; + using connector::hive::ExtractionPathElementPtr; + using connector::hive::ExtractionStep; + auto chain = std::vector{ + ExtractionPathElement::simple(ExtractionStep::kMapKeys)}; + mapSpec->setExtractionType(common::ScanSpec::ExtractionType::kKeys); + mapSpec->setTransform( + [chain](const VectorPtr& input, memory::MemoryPool* pool) -> VectorPtr { + return applyExtractionChain(input, chain, pool); + }, + ARRAY(BIGINT())); + + auto rowReaderOptions = dwio::common::RowReaderOptions(); + rowReaderOptions.setScanSpec(spec); + rowReaderOptions.range(0, 544); + auto rowReader = reader->createRowReader(rowReaderOptions); + + VectorPtr result; + auto numRows = rowReader->next(100, result); + ASSERT_GT(numRows, 0); + auto* row = result->as(); + ASSERT_EQ(row->childrenSize(), 1); + + // Verify the transform was applied: the result should be ARRAY(BIGINT) + // (keys), not MAP. + auto* keysArray = row->childAt(0)->as(); + ASSERT_NE(keysArray, nullptr); + ASSERT_EQ(keysArray->size(), numRows); + + // Each map in the test file has some keys. Verify array sizes are + // non-negative. + for (int i = 0; i < numRows; ++i) { + ASSERT_GE(keysArray->sizeAt(i), 0); + } +} + +TEST_F(TextReaderTest, extractionMapValues) { + // Read a text file with a MAP column and apply a MapValues extraction + // transform via ScanSpec. + const auto type = ROW( + {{"col_string", VARCHAR()}, + {"col_bigint_arr", ARRAY(BIGINT())}, + {"col_double_arr", ARRAY(DOUBLE())}, + {"col_map", MAP(BIGINT(), BOOLEAN())}}); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", "examples/custom_delimiters_file"); + auto readFile = std::make_shared(path); + + auto serDeOptions = dwio::common::SerDeOptions('\t', '|', '#', '\\', true); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setFileSchema(type); + readerOptions.setSerDeOptions(serDeOptions); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + + // Build a ScanSpec that projects only col_map with a MapValues extraction + // transform. + auto spec = std::make_shared("root"); + auto* mapSpec = spec->addField("col_map", 0); + mapSpec->addAllChildFields(*MAP(BIGINT(), BOOLEAN())); + + using connector::hive::applyExtractionChain; + using connector::hive::ExtractionPathElement; + using connector::hive::ExtractionPathElementPtr; + using connector::hive::ExtractionStep; + auto chain = std::vector{ + ExtractionPathElement::simple(ExtractionStep::kMapValues)}; + mapSpec->setExtractionType(common::ScanSpec::ExtractionType::kValues); + mapSpec->setTransform( + [chain](const VectorPtr& input, memory::MemoryPool* pool) -> VectorPtr { + return applyExtractionChain(input, chain, pool); + }, + ARRAY(BOOLEAN())); + + auto rowReaderOptions = dwio::common::RowReaderOptions(); + rowReaderOptions.setScanSpec(spec); + rowReaderOptions.range(0, 544); + auto rowReader = reader->createRowReader(rowReaderOptions); + + VectorPtr result; + auto numRows = rowReader->next(100, result); + ASSERT_GT(numRows, 0); + auto* row = result->as(); + ASSERT_EQ(row->childrenSize(), 1); + + // Verify the transform was applied: the result should be ARRAY(BOOLEAN) + // (values), not MAP. + auto* valuesArray = row->childAt(0)->as(); + ASSERT_NE(valuesArray, nullptr); + ASSERT_EQ(valuesArray->size(), numRows); + + // Verify array sizes are non-negative. + for (int i = 0; i < numRows; ++i) { + ASSERT_GE(valuesArray->sizeAt(i), 0); + } +} + +TEST_F(TextReaderTest, extractionMapKeyFilter) { + // Write a text file with a MAP(VARCHAR, BIGINT) column and apply a + // MapKeyFilter extraction transform via ScanSpec. + auto textFile = TempFilePath::create(); + { + auto writeFile = + std::make_unique(textFile->getPath(), true, false); + // Row format: map entries separated by \x02, key-value by \x03. + // Row 0: {"a":1, "b":2, "c":3} + // Row 1: {"a":10, "d":40} + writeFile->append( + "a\x03" + "1\x02" + "b\x03" + "2\x02" + "c\x03" + "3\n" + "a\x03" + "10\x02" + "d\x03" + "40\n"); + writeFile->close(); + } + + const auto type = ROW({{"col_map", MAP(VARCHAR(), BIGINT())}}); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + + auto readFile = std::make_shared(textFile->getPath()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setFileSchema(type); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + + // Build a ScanSpec with MapKeyFilter extraction for keys {"a", "b"}. + auto spec = std::make_shared("root"); + auto* mapSpec = spec->addField("col_map", 0); + mapSpec->addAllChildFields(*MAP(VARCHAR(), BIGINT())); + + using connector::hive::applyExtractionChain; + using connector::hive::ExtractionPathElement; + using connector::hive::ExtractionPathElementPtr; + using connector::hive::ExtractionStep; + auto chain = std::vector{ + ExtractionPathElement::mapKeyFilter(std::vector{"a", "b"})}; + mapSpec->setTransform( + [chain](const VectorPtr& input, memory::MemoryPool* pool) -> VectorPtr { + return applyExtractionChain(input, chain, pool); + }, + MAP(VARCHAR(), BIGINT())); + + auto serDeOptions = dwio::common::SerDeOptions('\x01', '\x02', '\x03'); + readerOptions.setSerDeOptions(serDeOptions); + input = std::make_unique(readFile, poolRef()); + reader = factory->createReader(std::move(input), readerOptions); + + auto rowReaderOptions = dwio::common::RowReaderOptions(); + rowReaderOptions.setScanSpec(spec); + auto rowReader = reader->createRowReader(rowReaderOptions); + + VectorPtr result; + auto numRows = rowReader->next(100, result); + ASSERT_EQ(numRows, 2); + auto* row = result->as(); + ASSERT_EQ(row->childrenSize(), 1); + + // Verify the MapKeyFilter was applied: only keys "a" and "b" remain. + auto* filteredMap = row->childAt(0)->as(); + ASSERT_NE(filteredMap, nullptr); + // Row 0: {"a":1, "b":2} kept, "c" filtered out. + ASSERT_EQ(filteredMap->sizeAt(0), 2); + // Row 1: {"a":10} kept, "d" filtered out. + ASSERT_EQ(filteredMap->sizeAt(1), 1); +} + +TEST_F(TextReaderTest, extractionTransformOnMapColumn) { + // Read a text file with a MAP column and apply a Size extraction transform + // via ScanSpec. The text reader uses RowReader::projectColumns() as + // fallback, which should apply the transform. + const auto type = ROW( + {{"col_string", VARCHAR()}, + {"col_bigint_arr", ARRAY(BIGINT())}, + {"col_double_arr", ARRAY(DOUBLE())}, + {"col_map", MAP(BIGINT(), BOOLEAN())}}); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", "examples/custom_delimiters_file"); + auto readFile = std::make_shared(path); + + auto serDeOptions = dwio::common::SerDeOptions('\t', '|', '#', '\\', true); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); + readerOptions.setFileSchema(type); + readerOptions.setSerDeOptions(serDeOptions); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + + // Build a ScanSpec that projects only col_map with a Size extraction + // transform. + auto spec = std::make_shared("root"); + auto* mapSpec = spec->addField("col_map", 0); + mapSpec->addAllChildFields(*MAP(BIGINT(), BOOLEAN())); + + using connector::hive::applyExtractionChain; + using connector::hive::ExtractionPathElement; + using connector::hive::ExtractionPathElementPtr; + using connector::hive::ExtractionStep; + auto chain = std::vector{ + ExtractionPathElement::simple(ExtractionStep::kSize)}; + mapSpec->setExtractionType(common::ScanSpec::ExtractionType::kSize); + mapSpec->setTransform( + [chain](const VectorPtr& input, memory::MemoryPool* pool) -> VectorPtr { + return applyExtractionChain(input, chain, pool); + }, + BIGINT()); + + auto rowReaderOptions = dwio::common::RowReaderOptions(); + rowReaderOptions.setScanSpec(spec); + rowReaderOptions.range(0, 544); + auto rowReader = reader->createRowReader(rowReaderOptions); + + VectorPtr result; + auto numRows = rowReader->next(100, result); + ASSERT_GT(numRows, 0); + auto* row = result->as(); + ASSERT_EQ(row->childrenSize(), 1); + + // Verify the transform was applied: the result should be BIGINT (sizes), + // not MAP. + auto* sizes = row->childAt(0)->as>(); + ASSERT_NE(sizes, nullptr); + ASSERT_EQ(sizes->size(), numRows); + + // Each map entry in the test file has some keys. Just verify sizes are + // non-negative. + for (int i = 0; i < numRows; ++i) { + ASSERT_GE(sizes->valueAt(i), 0); + } +} + } // namespace } // namespace facebook::velox::text diff --git a/velox/dwio/text/tests/reader/examples/logical_types.gz b/velox/dwio/text/tests/reader/examples/logical_types.gz new file mode 100644 index 00000000000..3a20dc8ceee Binary files /dev/null and b/velox/dwio/text/tests/reader/examples/logical_types.gz differ diff --git a/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.lz4 b/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.lz4 new file mode 100644 index 00000000000..fa1adc4234d Binary files /dev/null and b/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.lz4 differ diff --git a/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.lzo b/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.lzo new file mode 100644 index 00000000000..951ed722424 Binary files /dev/null and b/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.lzo differ diff --git a/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.snappy b/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.snappy new file mode 100644 index 00000000000..3343cdfc53e Binary files /dev/null and b/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.snappy differ diff --git a/velox/dwio/text/tests/writer/BufferedWriterSinkTest.cpp b/velox/dwio/text/tests/writer/BufferedWriterSinkTest.cpp index a7372b7b6a7..c19d5f3961c 100644 --- a/velox/dwio/text/tests/writer/BufferedWriterSinkTest.cpp +++ b/velox/dwio/text/tests/writer/BufferedWriterSinkTest.cpp @@ -16,11 +16,12 @@ #include #include "velox/common/file/FileSystems.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/dwio/text/tests/writer/FileReaderUtil.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/vector/tests/utils/VectorTestBase.h" namespace facebook::velox::text { +using namespace facebook::velox::common::testutil; class BufferedWriterSinkTest : public testing::Test, public velox::test::VectorTestBase { @@ -30,7 +31,7 @@ class BufferedWriterSinkTest : public testing::Test, dwio::common::LocalFileSink::registerFactory(); rootPool_ = memory::memoryManager()->addRootPool("BufferedWriterSinkTest"); leafPool_ = rootPool_->addLeafChild("BufferedWriterSinkTest"); - tempPath_ = exec::test::TempDirectoryPath::create(); + tempPath_ = TempDirectoryPath::create(); } protected: @@ -40,7 +41,7 @@ class BufferedWriterSinkTest : public testing::Test, std::shared_ptr rootPool_; std::shared_ptr leafPool_; - std::shared_ptr tempPath_; + std::shared_ptr tempPath_; }; TEST_F(BufferedWriterSinkTest, write) { @@ -80,4 +81,103 @@ TEST_F(BufferedWriterSinkTest, abort) { uint64_t result = readFile(tempPath_->getPath(), filename); EXPECT_EQ(result, 10); } + +TEST_F(BufferedWriterSinkTest, oversizedWriteBypassesBuffer) { + // Regression test: a write whose payload exceeds the flush buffer size must + // drain anything buffered (preserving order) and forward the oversized + // payload directly to the underlying sink, instead of asserting. + constexpr uint64_t kFlushBufferSize = 16; + const std::string oversized(64, 'x'); + + const auto tempPath = tempPath_->getPath(); + const auto filename = "test_buffered_oversized.txt"; + + auto filePath = fs::path(fmt::format("{}/{}", tempPath, filename)); + auto sink = std::make_unique( + filePath, dwio::common::FileSink::Options{.pool = leafPool_.get()}); + + auto bufferedWriterSink = std::make_unique( + std::move(sink), + rootPool_->addLeafChild("bufferedWriterSinkTest"), + kFlushBufferSize); + + // Small write that fits in the buffer, then an oversized write that must + // first drain the small one and then write directly to the sink. + const std::string prefix = "head:"; + bufferedWriterSink->write(prefix.data(), prefix.size()); + bufferedWriterSink->write(oversized.data(), oversized.size()); + // Another small write after the bypass to confirm buffering still works. + const std::string suffix = ":tail"; + bufferedWriterSink->write(suffix.data(), suffix.size()); + bufferedWriterSink->close(); + + const auto fs = filesystems::getFileSystem(tempPath, nullptr); + const auto& file = fs->openFileForRead(filePath.string()); + const auto fileSize = file->size(); + std::string contents(fileSize, '\0'); + file->pread(0, fileSize, contents.data()); + + EXPECT_EQ(contents, prefix + oversized + suffix); +} + +namespace { +// Reads the file at 'path/name' into a string. +std::string readFileBytes(const std::string& path, const std::string& name) { + const auto fs = filesystems::getFileSystem(path, nullptr); + const auto filePath = fs::path(fmt::format("{}/{}", path, name)); + const auto& file = fs->openFileForRead(filePath.string()); + const auto fileSize = file->size(); + std::string out(fileSize, '\0'); + if (fileSize > 0) { + file->pread(0, fileSize, out.data()); + } + return out; +} +} // namespace + +TEST_F(BufferedWriterSinkTest, writeEqualToFlushBufferUsesBufferedPath) { + // Boundary: a write of exactly flushBufferSize_ bytes is *not* oversized + // (the bypass condition is strict `size > flushBufferSize_`); it must go + // through the buffered path and land on disk intact after close(). + constexpr uint64_t kFlushBufferSize = 16; + const std::string exact(kFlushBufferSize, 'a'); + + const auto tempPath = tempPath_->getPath(); + const auto filename = "test_buffered_equal.txt"; + auto filePath = fs::path(fmt::format("{}/{}", tempPath, filename)); + auto sink = std::make_unique( + filePath, dwio::common::FileSink::Options{.pool = leafPool_.get()}); + auto bufferedWriterSink = std::make_unique( + std::move(sink), + rootPool_->addLeafChild("bufferedWriterSinkTest"), + kFlushBufferSize); + + bufferedWriterSink->write(exact.data(), exact.size()); + bufferedWriterSink->close(); + + EXPECT_EQ(readFileBytes(tempPath, filename), exact); +} + +TEST_F(BufferedWriterSinkTest, writeOneByteOverFlushBufferTriggersBypass) { + // Boundary: a write of flushBufferSize_ + 1 bytes is the smallest oversized + // write and must hit the bypass path. The buffer is empty so the flush() is + // a no-op; verify the payload still lands on disk in full. + constexpr uint64_t kFlushBufferSize = 16; + const std::string oneOver(kFlushBufferSize + 1, 'b'); + + const auto tempPath = tempPath_->getPath(); + const auto filename = "test_buffered_one_over.txt"; + auto filePath = fs::path(fmt::format("{}/{}", tempPath, filename)); + auto sink = std::make_unique( + filePath, dwio::common::FileSink::Options{.pool = leafPool_.get()}); + auto bufferedWriterSink = std::make_unique( + std::move(sink), + rootPool_->addLeafChild("bufferedWriterSinkTest"), + kFlushBufferSize); + + bufferedWriterSink->write(oneOver.data(), oneOver.size()); + bufferedWriterSink->close(); + + EXPECT_EQ(readFileBytes(tempPath, filename), oneOver); +} } // namespace facebook::velox::text diff --git a/velox/dwio/text/tests/writer/CMakeLists.txt b/velox/dwio/text/tests/writer/CMakeLists.txt index f53ecec5bc3..eea5ac780da 100644 --- a/velox/dwio/text/tests/writer/CMakeLists.txt +++ b/velox/dwio/text/tests/writer/CMakeLists.txt @@ -18,6 +18,7 @@ add_executable( BufferedWriterSinkTest.cpp FileReaderUtil.cpp ) +velox_add_test_headers(velox_text_writer_test FileReaderUtil.h) add_test( NAME velox_text_writer_test @@ -33,7 +34,6 @@ target_link_libraries( velox_dwio_text_reader_register velox_dwio_text_writer_register velox_link_libs - Boost::regex Folly::folly ${TEST_LINK_LIBS} GTest::gtest diff --git a/velox/dwio/text/tests/writer/TextWriterTest.cpp b/velox/dwio/text/tests/writer/TextWriterTest.cpp index d007c932362..c488ece97f6 100644 --- a/velox/dwio/text/tests/writer/TextWriterTest.cpp +++ b/velox/dwio/text/tests/writer/TextWriterTest.cpp @@ -17,10 +17,11 @@ #include "velox/dwio/text/writer/TextWriter.h" #include "velox/buffer/Buffer.h" #include "velox/common/file/FileSystems.h" +#include "velox/common/io/IoStatistics.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/dwio/text/RegisterTextReader.h" #include "velox/dwio/text/RegisterTextWriter.h" #include "velox/dwio/text/tests/writer/FileReaderUtil.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/vector/tests/utils/VectorTestBase.h" #include @@ -28,6 +29,8 @@ /// TODO: Add fuzzer test. namespace facebook::velox::text { +using namespace facebook::velox::common::testutil; + class TextWriterTest : public testing::Test, public velox::test::VectorTestBase { public: @@ -37,7 +40,7 @@ class TextWriterTest : public testing::Test, registerTextReaderFactory(); rootPool_ = memory::memoryManager()->addRootPool("TextWriterTests"); leafPool_ = rootPool_->addLeafChild("TextWriterTests"); - tempPath_ = exec::test::TempDirectoryPath::create(); + tempPath_ = TempDirectoryPath::create(); } void TearDown() override { @@ -62,9 +65,13 @@ class TextWriterTest : public testing::Test, constexpr static float kInf = std::numeric_limits::infinity(); constexpr static double kNaN = std::numeric_limits::quiet_NaN(); + std::shared_ptr dataIoStats_ = + std::make_shared(); + std::shared_ptr metadataIoStats_ = + std::make_shared(); std::shared_ptr rootPool_; std::shared_ptr leafPool_; - std::shared_ptr tempPath_; + std::shared_ptr tempPath_; }; TEST_F(TextWriterTest, write) { @@ -225,7 +232,9 @@ TEST_F(TextWriterTest, verifyWriteWithTextReader) { auto readerFactory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); auto readFile = std::make_shared(filePath.string()); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(schema); auto input = std::make_unique(readFile, poolRef()); @@ -493,7 +502,9 @@ TEST_F(TextWriterTest, verifyMapAndArrayComplexTypesWithTextReader) { auto readerFactory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); auto readFile = std::make_shared(filePath.string()); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(schema); readerOptions.setSerDeOptions(serDeOptions); auto input = @@ -684,7 +695,9 @@ TEST_F(TextWriterTest, verifyArrayTypesWithTextReader) { auto readerFactory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); auto readFile = std::make_shared(filePath.string()); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(schema); readerOptions.setSerDeOptions(serDeOptions); @@ -922,7 +935,9 @@ TEST_F(TextWriterTest, verifyMapTypesWithTextReader) { auto readerFactory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); auto readFile = std::make_shared(filePath.string()); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(schema); readerOptions.setSerDeOptions(serDeOptions); @@ -1077,7 +1092,9 @@ TEST_F(TextWriterTest, verifyNestedRowTypesWithTextReader) { auto readerFactory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); auto readFile = std::make_shared(filePath.string()); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(schema); readerOptions.setSerDeOptions(serDeOptions); @@ -1465,7 +1482,9 @@ TEST_F( auto readerFactory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); auto readFile = std::make_shared(filePath.string()); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(schema); readerOptions.setSerDeOptions(serDeOptions); @@ -1594,7 +1613,9 @@ TEST_F(TextWriterTest, verifySimpleEscapeCharTestWithTextReader) { auto readerFactory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); auto readFile = std::make_shared(filePath.string()); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(schema); readerOptions.setSerDeOptions(serDeOptions); @@ -1726,7 +1747,9 @@ TEST_F(TextWriterTest, verifyCustomEscapeCharTestWithTextReader) { auto readerFactory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); auto readFile = std::make_shared(filePath.string()); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(schema); readerOptions.setSerDeOptions(serDeOptions); @@ -1848,7 +1871,9 @@ TEST_F(TextWriterTest, verifyHeaderTestWithTextReader) { auto readerFactory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); auto readFile = std::make_shared(filePath.string()); - auto readerOptions = dwio::common::ReaderOptions(pool()); + dwio::common::ReaderOptions readerOptions(pool()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileSchema(schema); readerOptions.setSerDeOptions(serDeOptions); diff --git a/velox/dwio/text/writer/BufferedWriterSink.cpp b/velox/dwio/text/writer/BufferedWriterSink.cpp index e5de4a20401..3f62721b7ef 100644 --- a/velox/dwio/text/writer/BufferedWriterSink.cpp +++ b/velox/dwio/text/writer/BufferedWriterSink.cpp @@ -41,15 +41,20 @@ void BufferedWriterSink::write(char value) { } void BufferedWriterSink::write(const char* data, uint64_t size) { - // TODO Add logic for when size is larger than flushCount_ - VELOX_CHECK_GE( - flushBufferSize_, - size, - "write data size exceeds flush buffer size limit"); - if (buf_->size() + size > flushBufferSize_) { flush(); } + + if (size > flushBufferSize_) { + // Oversized payload: grow the (now-empty) buffer to fit, fill it, then + // flush directly to the sink. flush() resets the buffer back to + // flushBufferSize_ on the way out via reserveBuffer(). + buf_->reserve(size); + buf_->append(0, data, size); + flush(); + return; + } + buf_->append(buf_->size(), data, size); } diff --git a/velox/dwio/text/writer/CMakeLists.txt b/velox/dwio/text/writer/CMakeLists.txt index cc2657f5b0b..eba803d2958 100644 --- a/velox/dwio/text/writer/CMakeLists.txt +++ b/velox/dwio/text/writer/CMakeLists.txt @@ -12,6 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -velox_add_library(velox_dwio_text_writer TextWriter.cpp BufferedWriterSink.cpp) +velox_add_library( + velox_dwio_text_writer + TextWriter.cpp + BufferedWriterSink.cpp + HEADERS + BufferedWriterSink.h + TextWriter.h +) velox_link_libraries(velox_dwio_text_writer velox_dwio_common fmt::fmt) diff --git a/velox/dwio/text/writer/TextWriter.cpp b/velox/dwio/text/writer/TextWriter.cpp index f221257f1dc..be837b12522 100644 --- a/velox/dwio/text/writer/TextWriter.cpp +++ b/velox/dwio/text/writer/TextWriter.cpp @@ -70,13 +70,15 @@ TextWriter::TextWriter( const std::shared_ptr& options, const SerDeOptions& serDeOptions) : schema_(std::move(schema)), - bufferedWriterSink_(std::make_unique( - std::move(sink), - options->memoryPool->addLeafChild(fmt::format( - "{}.text_writer_node.{}", - options->memoryPool->name(), - folly::to(folly::Random::rand64()))), - options->defaultFlushCount)), + bufferedWriterSink_( + std::make_unique( + std::move(sink), + options->memoryPool->addLeafChild( + fmt::format( + "{}.text_writer_node.{}", + options->memoryPool->name(), + folly::to(folly::Random::rand64()))), + options->defaultFlushCount)), headerLineCount_(options->headerLineCount), serDeOptions_(serDeOptions) { VELOX_CHECK_LE(headerLineCount_, 1, "Header line count must be <= 1"); @@ -173,8 +175,9 @@ void TextWriter::flush() { bufferedWriterSink_->flush(); } -void TextWriter::close() { +std::unique_ptr TextWriter::close() { bufferedWriterSink_->close(); + return std::make_unique(); } void TextWriter::abort() { @@ -348,7 +351,7 @@ void TextWriter::writeCellValue( [[fallthrough]]; default: VELOX_NYI( - "Text writer does not support type {}", mapTypeKindToName(type)); + "Text writer does not support type {}", TypeKindName::toName(type)); } VELOX_CHECK( diff --git a/velox/dwio/text/writer/TextWriter.h b/velox/dwio/text/writer/TextWriter.h index 2ec11b82685..045f7b7f5cb 100644 --- a/velox/dwio/text/writer/TextWriter.h +++ b/velox/dwio/text/writer/TextWriter.h @@ -16,6 +16,7 @@ #pragma once +#include "velox/dwio/common/FileMetadata.h" #include "velox/dwio/common/FileSink.h" #include "velox/dwio/common/Options.h" #include "velox/dwio/common/Writer.h" @@ -26,6 +27,9 @@ namespace facebook::velox::text { using dwio::common::SerDeOptions; +/// Text-specific file metadata wrapper. Currently a placeholder. +class TextFileMetadata : public dwio::common::FileMetadata {}; + struct WriterOptions : public dwio::common::WriterOptions { int64_t defaultFlushCount = 10 << 10; uint8_t headerLineCount = @@ -58,7 +62,7 @@ class TextWriter : public dwio::common::Writer { return true; } - void close() override; + std::unique_ptr close() override; void abort() override; diff --git a/velox/examples/OperatorExtensibility.cpp b/velox/examples/OperatorExtensibility.cpp index 6f48b4e9bbd..a6e4aaec682 100644 --- a/velox/examples/OperatorExtensibility.cpp +++ b/velox/examples/OperatorExtensibility.cpp @@ -124,8 +124,9 @@ class DuplicateRowOperator : public exec::Operator { outputChildren.reserve(input->childrenSize()); for (const auto& child : input->children()) { - outputChildren.push_back(BaseVector::wrapInDictionary( - BufferPtr(), indices, outputSize, child)); + outputChildren.push_back( + BaseVector::wrapInDictionary( + BufferPtr(), indices, outputSize, child)); } return std::make_shared( pool(), diff --git a/velox/examples/ScanAndSort.cpp b/velox/examples/ScanAndSort.cpp index aee9144b172..4cc8cc97d2c 100644 --- a/velox/examples/ScanAndSort.cpp +++ b/velox/examples/ScanAndSort.cpp @@ -17,22 +17,24 @@ #include "velox/common/base/Fs.h" #include "velox/common/file/FileSystems.h" #include "velox/common/memory/Memory.h" +#include "velox/common/testutil/TempDirectoryPath.h" +#include "velox/connectors/ConnectorRegistry.h" #include "velox/connectors/hive/HiveConnector.h" #include "velox/connectors/hive/HiveConnectorSplit.h" #include "velox/dwio/common/FileSink.h" #include "velox/dwio/dwrf/RegisterDwrfReader.h" #include "velox/dwio/dwrf/RegisterDwrfWriter.h" #include "velox/exec/Task.h" -#include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/type/Type.h" #include "velox/vector/BaseVector.h" #include +#include #include using namespace facebook::velox; +using namespace facebook::velox::common::testutil; // This file contains a step-by-step minimal example of a workflow that: // @@ -91,7 +93,8 @@ int main(int argc, char** argv) { kHiveConnectorId, std::make_shared( std::unordered_map())); - connector::registerConnector(hiveConnector); + connector::ConnectorRegistry::global().insert( + hiveConnector->connectorId(), hiveConnector); // To be able to read local files, we need to register the local file // filesystem. We also need to register the dwrf reader factory as well as a @@ -104,7 +107,7 @@ int main(int argc, char** argv) { // Create a temporary dir to store the local file created. Note that this // directory is automatically removed when the `tempDir` object runs out of // scope. - auto tempDir = exec::test::TempDirectoryPath::create(); + auto tempDir = TempDirectoryPath::create(); auto absTempDirPath = tempDir->getPath(); // Once we finalize setting up the Hive connector, let's define our query @@ -125,7 +128,7 @@ int main(int argc, char** argv) { std::shared_ptr executor( std::make_shared( - std::thread::hardware_concurrency())); + folly::available_concurrency())); // Task is the top-level execution concept. A task needs a taskId (as a // string), the plan fragment to execute, a destination (only used for @@ -135,7 +138,8 @@ int main(int argc, char** argv) { writerPlanFragment, /*destination=*/0, core::QueryCtx::create(executor.get()), - exec::Task::ExecutionMode::kSerial); + exec::Task::ExecutionMode::kSerial, + exec::Consumer{}); // next() starts execution using the client thread. The loop pumps output // vectors out of the task (there are none in this query fragment). @@ -165,7 +169,8 @@ int main(int argc, char** argv) { readPlanFragment, /*destination=*/0, core::QueryCtx::create(executor.get()), - exec::Task::ExecutionMode::kSerial); + exec::Task::ExecutionMode::kSerial, + exec::Consumer{}); // Now that we have the query fragment and Task structure set up, we will // add data to it via `splits`. diff --git a/velox/examples/ScanOrc.cpp b/velox/examples/ScanOrc.cpp index 1f75f698324..0dbb79697a5 100644 --- a/velox/examples/ScanOrc.cpp +++ b/velox/examples/ScanOrc.cpp @@ -18,14 +18,16 @@ #include #include "velox/common/file/FileSystems.h" +#include "velox/common/io/IoStatistics.h" #include "velox/common/memory/Memory.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/dwio/common/Reader.h" #include "velox/dwio/common/ReaderFactory.h" #include "velox/dwio/dwrf/RegisterDwrfReader.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/vector/BaseVector.h" using namespace facebook::velox; +using namespace facebook::velox::common::testutil; using namespace facebook::velox::dwio::common; using namespace facebook::velox::dwrf; @@ -48,7 +50,11 @@ int main(int argc, char** argv) { auto pool = facebook::velox::memory::memoryManager()->addLeafPool(); std::string filePath{argv[1]}; - dwio::common::ReaderOptions readerOpts{pool.get()}; + auto dataIoStats = std::make_shared(); + auto metadataIoStats = std::make_shared(); + dwio::common::ReaderOptions readerOpts(pool.get()); + readerOpts.setDataIoStats(dataIoStats); + readerOpts.setMetadataIoStats(metadataIoStats); // To make DwrfReader reads ORC file, setFileFormat to FileFormat::ORC readerOpts.setFileFormat(FileFormat::ORC); auto reader = dwio::common::getReaderFactory(FileFormat::ORC) diff --git a/velox/examples/SimpleFunctions.cpp b/velox/examples/SimpleFunctions.cpp index 4c12cb4bf51..e34b0598a22 100644 --- a/velox/examples/SimpleFunctions.cpp +++ b/velox/examples/SimpleFunctions.cpp @@ -341,7 +341,7 @@ struct MyRegexpMatchFunction { // quite expensive to compile it on a per-row basis. In this example we // support both modes (const and non-const). if (pattern != nullptr) { - re_.emplace(*pattern); + re_.emplace(std::string_view(*pattern)); } // Optionally, one could also inspect the session configs in `QueryConfig`. @@ -359,7 +359,8 @@ struct MyRegexpMatchFunction { // > `my_regexp_match(col1, col2)` result = re_.has_value() ? RE2::PartialMatch(toStringPiece(input), *re_) - : RE2::PartialMatch(toStringPiece(input), ::re2::RE2(pattern)); + : RE2::PartialMatch( + toStringPiece(input), ::re2::RE2(std::string_view(pattern))); return true; } diff --git a/velox/exec/AdaptivePrefetch.h b/velox/exec/AdaptivePrefetch.h new file mode 100644 index 00000000000..afc44128116 --- /dev/null +++ b/velox/exec/AdaptivePrefetch.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace facebook::velox::exec { + +/// Adaptive prefetch look-ahead for loops with random memory accesses where +/// the working set exceeds CPU cache. Measures per-iteration latency during +/// initial iterations, then returns a fixed look-ahead for the remainder. +class AdaptivePrefetch { + public: + explicit AdaptivePrefetch(int32_t numIterations) + : numIterations_(numIterations), + start_(std::chrono::steady_clock::now()) {} + + /// Returns look-ahead distance, or 0 when too close to the end. + int32_t lookAhead() { + if (iteration_ == kMeasurementIterations) { + computeLookAhead(); + } + ++iteration_; + if (iteration_ + lookAhead_ > numIterations_) { + return 0; + } + return lookAhead_; + } + + private: + static constexpr int32_t kMeasurementIterations = 16; + static constexpr int32_t kMinLookAhead = 4; + static constexpr int32_t kMaxLookAhead = 32; + static constexpr int64_t kAssumedDramLatencyNs = 100; + static constexpr int64_t kCoefficient = 4; + + void computeLookAhead() { + auto elapsedNs = std::chrono::duration_cast( + std::chrono::steady_clock::now() - start_) + .count(); + lookAhead_ = std::clamp( + static_cast( + kCoefficient * kAssumedDramLatencyNs * kMeasurementIterations / + elapsedNs), + kMinLookAhead, + kMaxLookAhead); + } + + int32_t numIterations_; + int32_t iteration_{0}; + int32_t lookAhead_{kMinLookAhead}; + std::chrono::steady_clock::time_point start_; +}; + +} // namespace facebook::velox::exec diff --git a/velox/exec/Aggregate.cpp b/velox/exec/Aggregate.cpp index dd07b7e4436..c48cfe47a91 100644 --- a/velox/exec/Aggregate.cpp +++ b/velox/exec/Aggregate.cpp @@ -20,7 +20,6 @@ #include "velox/exec/AggregateCompanionAdapter.h" #include "velox/exec/AggregateCompanionSignatures.h" #include "velox/exec/AggregateWindow.h" -#include "velox/expression/SignatureBinder.h" namespace facebook::velox::exec { @@ -139,6 +138,7 @@ std::vector registerAggregateFunction( bool registerCompanionFunctions, bool overwrite) { auto size = names.size(); + VELOX_CHECK_NE(size, 0, "Aggregate function registered without a name."); std::vector registrationResults{size}; for (int i = 0; i < size; ++i) { registrationResults[i] = registerAggregateFunction( @@ -295,24 +295,6 @@ std::unique_ptr Aggregate::create( const std::vector& argTypes, const TypePtr& resultType, const core::QueryConfig& config) { - // TODO(timaou, kletkavrubashku): Reneable the validation once "regr_slope" - // signature is fixed - // - // Validate the result type. if (isPartialOutput(step)) { - // auto intermediateType = Aggregate::intermediateType(name, argTypes); - // VELOX_CHECK( - // resultType->equivalent(*intermediateType), - // "Intermediate type mismatch. Expected: {}, actual: {}", - // intermediateType->toString(), - // resultType->toString()); - // } else { - // auto finalType = Aggregate::finalType(name, argTypes); - // VELOX_CHECK( - // resultType->equivalent(*finalType), - // "Final type mismatch. Expected: {}, actual: {}", - // finalType->toString(), - // resultType->toString()); - // } // Lookup the function in the new registry first. if (auto func = getAggregateFunctionEntry(name)) { return func->factory(step, argTypes, resultType, config); diff --git a/velox/exec/Aggregate.h b/velox/exec/Aggregate.h index 69e7e16bd43..d38f91bc3af 100644 --- a/velox/exec/Aggregate.h +++ b/velox/exec/Aggregate.h @@ -93,6 +93,17 @@ class Aggregate { setAllocatorInternal(allocator); } + /// Called after construction to pass constant input values to the aggregate. + /// Non-constant inputs have null entries in 'constantInputs'. This is called + /// before any data is processed, allowing the aggregate to read constant + /// arguments (e.g., flags, configuration values) at initialization time + /// rather than extracting them from 'args' at runtime. + /// + /// Default implementation is a no-op. Override in subclasses that need + /// access to constant arguments. + virtual void setConstantInputs( + const std::vector& /*constantInputs*/) {} + /// Called for functions that take one or more lambda expression as input. /// These expressions must appear after all non-lambda inputs. /// These expressions cannot use captures. @@ -302,9 +313,22 @@ class Aggregate { for (auto* group : groups) { group[initializedByte_] &= ~initializedMask_; + clearNull(group); } } + /// Returns true if this aggregate function supports lightweight memory + /// compaction via compact(). + virtual bool supportsCompact() const { + return false; + } + + /// Invoked by GroupingSet::compact() to perform lightweight memory compaction + /// on the given 'groups', freeing unused memory without spilling to disk. + virtual uint64_t compact(folly::Range /*groups*/) { + return 0; + } + // Clears state between reuses, e.g. this is called before reusing // the aggregation operator's state after flushing a partial // aggregation. @@ -505,6 +529,10 @@ using AggregateFunctionFactory = std::function( const core::QueryConfig& config)>; struct AggregateFunctionMetadata { + /// True if results of the aggregation ignore duplicate values. + /// For example, min and max ignore duplicates while sum does not. + bool ignoreDuplicates{false}; + /// True if results of the aggregation depend on the order of inputs. For /// example, array_agg is order sensitive while count is not. bool orderSensitive{true}; diff --git a/velox/exec/AggregateCompanionAdapter.cpp b/velox/exec/AggregateCompanionAdapter.cpp index ea3c64c5c9e..9aab4fbac9b 100644 --- a/velox/exec/AggregateCompanionAdapter.cpp +++ b/velox/exec/AggregateCompanionAdapter.cpp @@ -340,25 +340,35 @@ bool registerMergeExtractFunctionInternal( mergeExtractFunctionName, std::move(mergeExtractSignatures), [name, mergeExtractFunctionName]( - core::AggregationNode::Step /*step*/, + core::AggregationNode::Step step, const std::vector& argTypes, const TypePtr& resultType, const core::QueryConfig& config) -> std::unique_ptr { - const auto& [originalResultType, _] = - resolveAggregateFunction(mergeExtractFunctionName, argTypes); - if (!originalResultType) { - // TODO: limitation -- result type must be resolvable given - // intermediate type of the original UDAF. - VELOX_UNREACHABLE( - "Signatures whose result types are not resolvable given intermediate types should have been excluded."); + TypePtr functionResultType; + if (step == core::AggregationNode::Step::kFinal || + step == core::AggregationNode::Step::kSingle) { + functionResultType = resultType; + } else { + // When step is kPartial or kIntermediate, 'resultType' is + // the intermediate type and the original result type needs to + // be resolved for the aggregate function creation. + const auto& originalResultType = + resolveResultType(mergeExtractFunctionName, argTypes); + if (!originalResultType) { + // Result type must be resolvable given intermediate type of + // the original UDAF. + VELOX_FAIL( + "Signatures' result types must be resolvable given intermediate types."); + } + functionResultType = originalResultType; } if (auto func = getAggregateFunctionEntry(name)) { auto fn = func->factory( core::AggregationNode::Step::kFinal, argTypes, - originalResultType, + functionResultType, config); VELOX_CHECK_NOT_NULL(fn); return std::make_unique< diff --git a/velox/exec/AggregateCompanionAdapter.h b/velox/exec/AggregateCompanionAdapter.h index 8a6af66ff51..fc0a5909f8a 100644 --- a/velox/exec/AggregateCompanionAdapter.h +++ b/velox/exec/AggregateCompanionAdapter.h @@ -169,21 +169,35 @@ struct AggregateCompanionAdapter { }; }; +/// In Velox, "Step" is a property of the aggregate operator, whereas in +/// Spark, it is tied to individual aggregate functions. Spark executes +/// aggregates using a mix of partial, intermediate, and final aggregate +/// functions. To bridge the two systems, the planner translates Spark's +/// aggregate modes into corresponding Velox companion functions and assigns +/// the "single" step to Velox’s AggregationNode. These companion functions +/// are intended for internal use within the aggregate operator and are not +/// designed to be used as standalone functions, and their result types +/// may not always be inferable from intermediate types. More details can be +/// found in +/// https://github.com/facebookincubator/velox/pull/11999#issuecomment-3274577979 +/// and https://github.com/facebookincubator/velox/issues/12830. class CompanionFunctionsRegistrar { public: - // Register the partial companion function for an aggregation function of - // `name` and `signatures`. When there is already a function of the same name, - // if `overwrite` is true, the registration is replaced. Otherwise, return - // false without overwriting the registry. + /// Register the partial companion function for an aggregate function of + /// `name` and `signatures`. When there is already a function of the same + /// name, if `overwrite` is true, the registration is replaced. Otherwise, + /// return false without overwriting the registry. This function supports + /// generating Spark compatible companion functions. static bool registerPartialFunction( const std::string& name, const std::vector& signatures, const AggregateFunctionMetadata& metadata, bool overwrite = false); - // When there is already a function of the same name as the merge companion - // function, if `overwrite` is true, the registration is replaced. Otherwise, - // return false without overwriting the registry. + /// When there is already a function of the same name as the merge companion + /// function, if `overwrite` is true, the registration is replaced. Otherwise, + /// return false without overwriting the registry. This function supports + /// generating Spark compatible companion functions. static bool registerMergeFunction( const std::string& name, const std::vector& signatures, @@ -204,14 +218,16 @@ class CompanionFunctionsRegistrar { const std::vector& signatures, bool overwrite = false); - // Similar to registerExtractFunction(), the result type of the original - // aggregation function is required to be resolvable given its intermediate - // type. If there are multiple signatures of the original aggregation function - // with the same intermediate type, register merge-extract functions with + /// If there are multiple signatures of the original aggregate function + /// with the same intermediate type, register merge-extract functions with // suffix of their result types in the function names for each of them. When - // there is already a function of the same name as the merge-extract companion - // function, if `overwrite` is true, the registration is replaced. Otherwise, - // return false without overwriting the registry. + /// there is already a function of the same name as the merge-extract + /// companion function, if `overwrite` is true, the registration is replaced. + /// Otherwise, return false without overwriting the registry. This function + /// supports generating Spark compatible companion functions only when the + /// return types are explicitly specified (typically in "single" or "final" + /// steps). It will throw an exception if return types are not provided and + /// cannot be resolved from the intermediate types. static bool registerMergeExtractFunction( const std::string& name, const std::vector& signatures, diff --git a/velox/exec/AggregateCompanionSignatures.cpp b/velox/exec/AggregateCompanionSignatures.cpp index 2fa7e8d11dc..c6da4488dc6 100644 --- a/velox/exec/AggregateCompanionSignatures.cpp +++ b/velox/exec/AggregateCompanionSignatures.cpp @@ -102,20 +102,18 @@ CompanionSignatures::partialFunctionSignatures( const std::vector& signatures) { std::vector partialSignatures; for (const auto& signature : signatures) { - if (!isResultTypeResolvableGivenIntermediateType(signature)) { - continue; - } std::vector usedTypes = signature->argumentTypes(); usedTypes.push_back(signature->intermediateType()); auto variables = usedTypeVariables(usedTypes, signature->variables()); - partialSignatures.push_back(std::make_shared( - /*variables*/ variables, - /*returnType*/ signature->intermediateType(), - /*intermediateType*/ signature->intermediateType(), - /*argumentTypes*/ signature->argumentTypes(), - /*constantArguments*/ signature->constantArguments(), - /*variableArity*/ signature->variableArity())); + partialSignatures.push_back( + std::make_shared( + /*variables*/ variables, + /*returnType*/ signature->intermediateType(), + /*intermediateType*/ signature->intermediateType(), + /*argumentTypes*/ signature->argumentTypes(), + /*constantArguments*/ signature->constantArguments(), + /*variableArity*/ signature->variableArity())); } return partialSignatures; } @@ -126,10 +124,6 @@ std::string CompanionSignatures::partialFunctionName(const std::string& name) { AggregateFunctionSignaturePtr CompanionSignatures::mergeFunctionSignature( const AggregateFunctionSignaturePtr& signature) { - if (!isResultTypeResolvableGivenIntermediateType(signature)) { - return nullptr; - } - std::vector usedTypes = {signature->intermediateType()}; auto variables = usedTypeVariables(usedTypes, signature->variables()); return std::make_shared( @@ -172,10 +166,6 @@ bool CompanionSignatures::hasSameIntermediateTypesAcrossSignatures( AggregateFunctionSignaturePtr CompanionSignatures::mergeExtractFunctionSignature( const AggregateFunctionSignaturePtr& signature) { - if (!isResultTypeResolvableGivenIntermediateType(signature)) { - return nullptr; - } - std::vector usedTypes = { signature->intermediateType(), signature->returnType()}; auto variables = usedTypeVariables(usedTypes, signature->variables()); diff --git a/velox/exec/AggregateCompanionSignatures.h b/velox/exec/AggregateCompanionSignatures.h index ea7d4b2d6ce..c0557adc091 100644 --- a/velox/exec/AggregateCompanionSignatures.h +++ b/velox/exec/AggregateCompanionSignatures.h @@ -106,8 +106,9 @@ class CompanionSignatures { normalizeType(signature->intermediateType(), signature->variables()); auto normalizedReturnType = normalizeType(signature->returnType(), signature->variables()); - if (distinctIntermediateAndResultTypes.count(std::make_pair( - normalizedIntermediateType, normalizedReturnType))) { + if (distinctIntermediateAndResultTypes.count( + std::make_pair( + normalizedIntermediateType, normalizedReturnType))) { continue; } diff --git a/velox/exec/AggregateFunctionRegistry.cpp b/velox/exec/AggregateFunctionRegistry.cpp index 87373f37aa8..259888032cf 100644 --- a/velox/exec/AggregateFunctionRegistry.cpp +++ b/velox/exec/AggregateFunctionRegistry.cpp @@ -21,25 +21,75 @@ namespace facebook::velox::exec { -std::pair resolveAggregateFunction( +namespace { +std::string makeSignatureNotSupportedError( + const std::string& name, + const std::vector& argTypes, + const std::vector>& + signatures) { + std::stringstream error; + error << "Aggregate function signature is not supported: " + << toString(name, argTypes) + << ". Supported signatures: " << toString(signatures) << "."; + return error.str(); +} +} // namespace + +TypePtr resolveResultType( const std::string& name, const std::vector& argTypes) { if (auto signatures = getAggregateFunctionSignatures(name)) { for (const auto& signature : signatures.value()) { - SignatureBinder binder(*signature, argTypes); + SignatureBinder binder(*signature, argTypes, TypeCoercer::defaults()); if (binder.tryBind()) { - return std::make_pair( - binder.tryResolveReturnType(), - binder.tryResolveType(signature->intermediateType())); + return binder.tryResolveReturnType(); } } - std::stringstream error; - error << "Aggregate function signature is not supported: " - << toString(name, argTypes) - << ". Supported signatures: " << toString(signatures.value()) << "."; - VELOX_USER_FAIL(error.str()); + VELOX_USER_FAIL( + makeSignatureNotSupportedError(name, argTypes, signatures.value())); + } + + VELOX_USER_FAIL("Aggregate function not registered: {}", name); +} + +TypePtr resolveResultTypeWithCoercions( + const std::string& name, + const std::vector& argTypes, + std::vector& coercions, + const TypeCoercer& coercer) { + coercions.clear(); + + if (auto signatures = getAggregateFunctionSignatures(name)) { + std::vector baseSignatures( + signatures.value().begin(), signatures.value().end()); + if (auto type = tryResolveReturnTypeWithCoercions( + baseSignatures, argTypes, coercions, coercer)) { + return type; + } + + VELOX_USER_FAIL( + makeSignatureNotSupportedError(name, argTypes, signatures.value())); + } + + VELOX_USER_FAIL("Aggregate function not registered: {}", name); +} + +TypePtr resolveIntermediateType( + const std::string& name, + const std::vector& argTypes) { + if (auto signatures = getAggregateFunctionSignatures(name)) { + for (const auto& signature : signatures.value()) { + SignatureBinder binder(*signature, argTypes, TypeCoercer::defaults()); + if (binder.tryBind()) { + return binder.tryResolveType(signature->intermediateType()); + } + } + VELOX_USER_FAIL( + "Aggregate function signature is not supported: {}. Supported signatures: {}.", + toString(name, argTypes), + toString(signatures.value())); } else { VELOX_USER_FAIL("Aggregate function not registered: {}", name); } diff --git a/velox/exec/AggregateFunctionRegistry.h b/velox/exec/AggregateFunctionRegistry.h index 2ee6045907d..53eb8b87b32 100644 --- a/velox/exec/AggregateFunctionRegistry.h +++ b/velox/exec/AggregateFunctionRegistry.h @@ -19,15 +19,43 @@ #include #include "velox/type/Type.h" +#include "velox/type/TypeCoercer.h" namespace facebook::velox::exec { -/// Given a name of aggregate function and argument types, returns a pair of the -/// return type and intermediate type if the function exists. Throws if function -/// doesn't exist or doesn't support specified argument types. +/// Given a name of aggregate function and argument types, returns the result +/// type if the function exists. Throws if function doesn't exist or doesn't +/// support specified argument types. Since aggregate functions can be +/// integrated into internal steps of an aggregate operator — rather than +/// always being used as standalone functions at the SQL level — their result +/// types may not always be inferable from the intermediate types. As a +/// result, an exception might be thrown during the type resolution process. In +/// such cases, the caller should explicitly specify the result type. More +/// details can be found in +/// https://github.com/facebookincubator/velox/pull/11999#issuecomment-3274577979 +/// and https://github.com/facebookincubator/velox/issues/12830. +TypePtr resolveResultType( + const std::string& name, + const std::vector& argTypes); + +/// Like 'resolveResultType', but with support for applying type conversions if +/// a function signature doesn't match 'argTypes' exactly. /// -/// @return a pair of {finalType, intermediateType} -std::pair resolveAggregateFunction( +/// @param coercions A list of optional type coercions that were applied to +/// resolve a function successfully. Contains one entry per argument. The entry +/// is null if no coercion is required for that argument. The entry is not null +/// if coercion is necessary. +/// @param coercer Coercion rule set to use when resolving type coercions. +TypePtr resolveResultTypeWithCoercions( + const std::string& name, + const std::vector& argTypes, + std::vector& coercions, + const TypeCoercer& coercer); + +/// Given a name of aggregate function and argument types, returns the +/// intermediate type if the function exists. Throws if function doesn't exist +/// or doesn't support specified argument types. +TypePtr resolveIntermediateType( const std::string& name, const std::vector& argTypes); diff --git a/velox/exec/AggregateInfo.cpp b/velox/exec/AggregateInfo.cpp index a3f28c0684d..044a510e848 100644 --- a/velox/exec/AggregateInfo.cpp +++ b/velox/exec/AggregateInfo.cpp @@ -81,11 +81,10 @@ std::vector toAggregateInfo( arg->toString()); } } + const auto& name = aggregate.call->name(); - info.distinct = aggregate.distinct; - info.intermediateType = resolveAggregateFunction( - aggregate.call->name(), aggregate.rawInputTypes) - .second; + info.intermediateType = + resolveIntermediateType(name, aggregate.rawInputTypes); // Setup aggregation mask: convert the Variable Reference name to the // channel (projection) index, if there is a mask. @@ -98,13 +97,19 @@ std::vector toAggregateInfo( auto index = numKeys + i; const auto& aggResultType = outputType->childAt(index); info.function = Aggregate::create( - aggregate.call->name(), + name, isPartialOutput(step) ? core::AggregationNode::Step::kPartial : core::AggregationNode::Step::kSingle, aggregate.rawInputTypes, aggResultType, operatorCtx.driverCtx()->queryConfig()); + // Pass constant inputs to the aggregate so it can read constant arguments + // (e.g., flags) at initialization time. + if (!constants.empty()) { + info.function->setConstantInputs(constants); + } + auto lambdas = extractLambdaInputs(aggregate); if (!lambdas.empty()) { if (expressionEvaluator == nullptr) { @@ -114,10 +119,13 @@ std::vector toAggregateInfo( info.function->setLambdaExpressions(lambdas, expressionEvaluator); } - // Ignore sorting properties if aggregate function is not sensitive to the - // order of inputs. - auto* entry = getAggregateFunctionEntry(aggregate.call->name()); + // 1. Ignore duplicates property + // if aggregate function is not sensitive to duplicates. + // 2. Ignore sorting properties + // if aggregate function is not sensitive to the order of inputs. + auto* entry = getAggregateFunctionEntry(name); const auto& metadata = entry->metadata; + info.distinct = !metadata.ignoreDuplicates && aggregate.distinct; if (metadata.orderSensitive) { // Sorting keys and orders. const auto numSortingKeys = aggregate.sortingKeys.size(); diff --git a/velox/exec/AggregateWindow.cpp b/velox/exec/AggregateWindow.cpp index 2bdad5342c6..6c90a4fd1d5 100644 --- a/velox/exec/AggregateWindow.cpp +++ b/velox/exec/AggregateWindow.cpp @@ -150,7 +150,7 @@ class AggregateWindowFunction : public exec::WindowFunction { // This is the start of a new incremental aggregation. So the // aggregate_ function object should be initialized. auto singleGroup = std::vector{0}; - aggregate_->clear(); + aggregate_->destroy(folly::Range(&rawSingleGroupRow_, 1)); aggregate_->initializeNewGroups(&rawSingleGroupRow_, singleGroup); aggregateInitialized_ = true; } @@ -257,6 +257,11 @@ class AggregateWindowFunction : public exec::WindowFunction { void fillArgVectors(vector_size_t firstRow, vector_size_t lastRow) { vector_size_t numFrameRows = lastRow + 1 - firstRow; for (int i = 0; i < argIndices_.size(); i++) { + // Without the following call to `prepareForReuse`, if the type of + // `argVectors_[i]` is VARCHAR, then the string buffers will accumulate + // across calculations. As a result, memory consumption will increase over + // time. So we call `prepareForReuse` to clear string buffers timely. + argVectors_[i]->prepareForReuse(); argVectors_[i]->resize(numFrameRows); // Only non-constant field argument vectors need to be populated. The // constant vectors are correctly set during aggregate initialization @@ -331,7 +336,7 @@ class AggregateWindowFunction : public exec::WindowFunction { // TODO : Try to re-use previous computations by advancing and retracting // the aggregation based on the frame changes with each row. This would // require adding new APIs to the Aggregate framework. - aggregate_->clear(); + aggregate_->destroy(folly::Range(&rawSingleGroupRow_, 1)); aggregate_->initializeNewGroups(&rawSingleGroupRow_, kSingleGroup); aggregateInitialized_ = true; diff --git a/velox/exec/ArrowStream.cpp b/velox/exec/ArrowStream.cpp index e368844e41f..a456852c7d9 100644 --- a/velox/exec/ArrowStream.cpp +++ b/velox/exec/ArrowStream.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ #include "velox/exec/ArrowStream.h" +#include "velox/exec/OperatorType.h" #include "velox/vector/arrow/Bridge.h" namespace facebook::velox::exec { @@ -27,7 +28,7 @@ ArrowStream::ArrowStream( arrowStreamNode->outputType(), operatorId, arrowStreamNode->id(), - "ArrowStream") { + OperatorType::kArrowStream) { arrowStream_ = arrowStreamNode->arrowStream(); } diff --git a/velox/exec/AssignUniqueId.cpp b/velox/exec/AssignUniqueId.cpp index e08898decfd..5428d326069 100644 --- a/velox/exec/AssignUniqueId.cpp +++ b/velox/exec/AssignUniqueId.cpp @@ -18,6 +18,8 @@ #include #include +#include "velox/exec/OperatorType.h" + namespace facebook::velox::exec { AssignUniqueId::AssignUniqueId( @@ -31,7 +33,7 @@ AssignUniqueId::AssignUniqueId( planNode->outputType(), operatorId, planNode->id(), - "AssignUniqueId"), + OperatorType::kAssignUniqueId), rowIdPool_(std::move(rowIdPool)) { VELOX_USER_CHECK_LT( uniqueTaskId, diff --git a/velox/exec/BarrierSplit.h b/velox/exec/BarrierSplit.h new file mode 100644 index 00000000000..3249b2eeaa2 --- /dev/null +++ b/velox/exec/BarrierSplit.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +namespace facebook::velox::exec { + +/// A special split used by task barrier processing to signal output drain +/// processing. When task barrier processing is triggered, one barrier split is +/// added to each leaf source node. Once a source node receives the barrier +/// split, it will produce output and propagate the barrier down to the root +/// node of a pipeline (typically exchange or partitioned output). +struct BarrierSplit { + /// The number of drivers in the source pipeline that share this barrier + /// split. It is used to deduplicate barrier processing at the root node of + /// the pipeline which blocks the barrier processing until all the pipelines + /// drivers have reached it. + uint32_t numDrivers; +}; + +} // namespace facebook::velox::exec diff --git a/velox/exec/BlockingReason.cpp b/velox/exec/BlockingReason.cpp index ffe07aa41d8..325249defff 100644 --- a/velox/exec/BlockingReason.cpp +++ b/velox/exec/BlockingReason.cpp @@ -16,6 +16,8 @@ #include "velox/exec/BlockingReason.h" +#include "velox/common/EnumDefine.h" + namespace facebook::velox::exec { namespace { @@ -35,6 +37,8 @@ const auto& blockingReasonNames() { {BlockingReason::kWaitForArbitration, "kWaitForArbitration"}, {BlockingReason::kWaitForScanScaleUp, "kWaitForScanScaleUp"}, {BlockingReason::kWaitForIndexLookup, "kWaitForIndexLookup"}, + {BlockingReason::kWaitForIndexSplits, "kWaitForIndexSplits"}, + {BlockingReason::kWaitForRPC, "kWaitForRPC"}, }; return kNames; } diff --git a/velox/exec/BlockingReason.h b/velox/exec/BlockingReason.h index f89e808b099..6ecc81cd465 100644 --- a/velox/exec/BlockingReason.h +++ b/velox/exec/BlockingReason.h @@ -16,7 +16,8 @@ #pragma once -#include "velox/common/Enums.h" +#include +#include "velox/common/EnumDeclare.h" namespace facebook::velox::exec { @@ -55,6 +56,12 @@ enum class BlockingReason { /// Used by IndexLookupJoin operator, indicating that it was blocked by the /// async index lookup. kWaitForIndexLookup, + /// Used by IndexLookupJoin follower operators waiting for the leader to + /// finish collecting index splits via the IndexLookupJoinBridge. + kWaitForIndexSplits, + /// Used by RPC operators, indicating that the operator is blocked waiting + /// for an async RPC response (e.g., LLM inference, embedding lookups). + kWaitForRPC, }; VELOX_DECLARE_ENUM_NAME(BlockingReason); diff --git a/velox/exec/CMakeLists.txt b/velox/exec/CMakeLists.txt index ba8457f888e..cbea17dc9b3 100644 --- a/velox/exec/CMakeLists.txt +++ b/velox/exec/CMakeLists.txt @@ -19,14 +19,14 @@ velox_add_library( AggregateCompanionSignatures.cpp AggregateFunctionRegistry.cpp AggregateInfo.cpp - AggregationMasks.cpp AggregateWindow.cpp + AggregationMasks.cpp ArrowStream.cpp AssignUniqueId.cpp BlockingReason.cpp CallbackSink.cpp - ContainerRowSerde.cpp ColumnStatsCollector.cpp + ContainerRowSerde.cpp DistinctAggregations.cpp Driver.cpp EnforceSingleRow.cpp @@ -34,6 +34,7 @@ velox_add_library( ExchangeClient.cpp ExchangeQueue.cpp ExchangeSource.cpp + SerializedPage.cpp Expand.cpp FilterProject.cpp GroupId.cpp @@ -44,57 +45,66 @@ velox_add_library( HashPartitionFunction.cpp HashProbe.cpp HashTable.cpp + HashTableCache.cpp IndexLookupJoin.cpp + IndexLookupJoinBridge.cpp JoinBridge.cpp Limit.cpp LocalPartition.cpp LocalPlanner.cpp MarkDistinct.cpp + MarkSorted.cpp + EnforceDistinct.cpp MemoryReclaimer.cpp Merge.cpp MergeJoin.cpp MergeSource.cpp + MixedUnion.cpp NestedLoopJoinBuild.cpp NestedLoopJoinProbe.cpp + SpatialIndex.cpp SpatialJoinBuild.cpp SpatialJoinProbe.cpp Operator.cpp + OperatorTraceCtx.cpp + OperatorTraceReader.cpp + OperatorTraceScan.cpp + OperatorTraceWriter.cpp OperatorUtils.cpp OrderBy.cpp OutputBuffer.cpp OutputBufferManager.cpp - OperatorTraceReader.cpp - OperatorTraceScan.cpp - OperatorTraceWriter.cpp ParallelProject.cpp - TaskStructs.cpp - TaskTraceReader.cpp - TaskTraceWriter.cpp - Trace.cpp - TraceUtil.cpp - PartitionedOutput.cpp PartitionFunction.cpp PartitionStreamingWindowBuild.cpp + PartitionedOutput.cpp PlanNodeStats.cpp PrefixSort.cpp ProbeOperatorState.cpp - RowsStreamingWindowBuild.cpp + SubPartitionedSortWindowBuild.cpp RowContainer.cpp RowNumber.cpp - ScaledScanController.cpp + RowsStreamingWindowBuild.cpp ScaleWriterLocalPartition.cpp + ScaledScanController.cpp SortBuffer.cpp - SortedAggregations.cpp SortWindowBuild.cpp + SortedAggregations.cpp + SpatialJoinBuild.cpp + SpatialJoinProbe.cpp Spill.cpp SpillFile.cpp Spiller.cpp StreamingAggregation.cpp + StreamingEnforceDistinct.cpp Strings.cpp TableScan.cpp TableWriteMerge.cpp TableWriter.cpp Task.cpp + TaskStructs.cpp + TaskTraceReader.cpp + TaskTraceWriter.cpp TopN.cpp TopNRowNumber.cpp Unnest.cpp @@ -104,36 +114,153 @@ velox_add_library( WindowBuild.cpp WindowFunction.cpp WindowPartition.cpp + HEADERS + AddressableNonNullValueList.h + Aggregate.h + AdaptivePrefetch.h + AggregateCompanionAdapter.h + AggregateCompanionSignatures.h + AggregateFunctionRegistry.h + AggregateInfo.h + AggregateWindow.h + AggregationMasks.h + ArrowStream.h + AssignUniqueId.h + BarrierSplit.h + BlockingReason.h + CallbackSink.h + ColumnStatsCollector.h + ContainerRowSerde.h + DistinctAggregations.h + Driver.h + DriverStats.h + EnforceDistinct.h + EnforceSingleRow.h + Exchange.h + ExchangeClient.h + ExchangeQueue.h + ExchangeSource.h + Expand.h + FilterProject.h + GroupId.h + GroupingSet.h + HashAggregation.h + HashBitRange.h + HashBuild.h + HashJoinBridge.h + HashPartitionFunction.h + HashProbe.h + HashTable.h + HashTableCache.h + HilbertIndex.h + IndexLookupJoin.h + IndexLookupJoinBridge.h + JoinBridge.h + Limit.h + LocalPartition.h + LocalPlanner.h + MarkDistinct.h + MarkSorted.h + MemoryReclaimer.h + Merge.h + MergeJoin.h + MergeSource.h + MixedUnion.h + NestedLoopJoinBuild.h + NestedLoopJoinProbe.h + OneWayStatusFlag.h + Operator.h + OperatorStats.h + OperatorTraceCtx.h + OperatorTraceReader.h + OperatorTraceScan.h + OperatorTraceWriter.h + OperatorType.h + OperatorUtils.h + OrderBy.h + OutputBuffer.h + OutputBufferManager.h + ParallelProject.h + PartitionFunction.h + PartitionStreamingWindowBuild.h + PartitionedOutput.h + PlanNodeStats.h + PrefixSort.h + ProbeOperatorState.h + RoundRobinPartitionFunction.h + RowContainer.h + RowNumber.h + RowsStreamingWindowBuild.h + ScaleWriterLocalPartition.h + ScaledScanController.h + SerializedPage.h + SetAccumulator.h + SimpleAggregateAdapter.h + SortBuffer.h + SortWindowBuild.h + SortedAggregations.h + SpatialIndex.h + SpatialJoinBuild.h + SpatialJoinProbe.h + Spill.h + SpillFile.h + Spiller.h + Split.h + StreamingAggregation.h + StreamingEnforceDistinct.h + Strings.h + SubPartitionedSortWindowBuild.h + TableScan.h + TableWriteMerge.h + TableWriter.h + Task.h + TaskStats.h + TaskStructs.h + TaskTraceReader.h + TaskTraceWriter.h + TopN.h + TopNRowNumber.h + TraceUtil.h + Unnest.h + UnorderedStreamReader.h + Values.h + VectorHasher-inl.h + VectorHasher.h + Window.h + WindowBuild.h + WindowFunction.h + WindowPartition.h ) +velox_add_library(velox_exec_spill_stats SpillStats.cpp HEADERS SpillStats.h) +velox_link_libraries(velox_exec_spill_stats velox_common_base velox_file Folly::folly) + velox_link_libraries( velox_exec - velox_file + velox_arrow_bridge + velox_common_base + velox_common_compression velox_core - velox_vector velox_connector + velox_connector_registry + velox_exec_spill_stats velox_expression + velox_file + velox_presto_serializer + velox_trace velox_time - velox_common_base velox_test_util - velox_arrow_bridge - velox_common_compression + velox_vector ) -velox_add_library(velox_cursor Cursor.cpp) -velox_link_libraries( - velox_cursor - velox_core - velox_exception - velox_expression - velox_dwio_common - velox_dwio_dwrf_reader - velox_dwio_dwrf_writer - velox_type_fbhive - velox_presto_serializer - velox_functions_prestosql - velox_aggregates -) +if(VELOX_ENABLE_GEO) + velox_compile_definitions(velox_exec PRIVATE VELOX_ENABLE_GEO) + velox_link_libraries(velox_exec velox_common_geospatial_serde) +endif() + +velox_add_library(velox_cursor Cursor.cpp HEADERS Cursor.h) + +velox_link_libraries(velox_cursor velox_core velox_exception velox_expression) if(${VELOX_BUILD_TESTING}) add_subdirectory(fuzzer) @@ -147,3 +274,9 @@ if(${VELOX_ENABLE_BENCHMARKS}) endif() add_subdirectory(prefixsort) +add_subdirectory(rpc) +add_subdirectory(trace) + +velox_add_library(velox_aggregate_util INTERFACE HEADERS AggregateUtil.h) + +velox_add_library(velox_operator_type INTERFACE HEADERS OperatorType.h) diff --git a/velox/exec/CallbackSink.h b/velox/exec/CallbackSink.h index dc8afd64763..199bbd9d8aa 100644 --- a/velox/exec/CallbackSink.h +++ b/velox/exec/CallbackSink.h @@ -16,6 +16,7 @@ #pragma once #include "velox/exec/Operator.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/OperatorUtils.h" namespace facebook::velox::exec { @@ -26,8 +27,14 @@ class CallbackSink : public Operator { int32_t operatorId, DriverCtx* driverCtx, Consumer consumeCb, - std::function startedCb = nullptr) - : Operator(driverCtx, nullptr, operatorId, "N/A", "CallbackSink"), + std::function startedCb = nullptr, + const std::string& planNodeId = "N/A") + : Operator( + driverCtx, + nullptr, + operatorId, + planNodeId, + OperatorType::kCallbackSink), startedCb_{std::move(startedCb)}, consumeCb_{std::move(consumeCb)} {} diff --git a/velox/exec/ColumnStatsCollector.cpp b/velox/exec/ColumnStatsCollector.cpp index 1fcfab46cfe..1ec85678924 100644 --- a/velox/exec/ColumnStatsCollector.cpp +++ b/velox/exec/ColumnStatsCollector.cpp @@ -54,30 +54,9 @@ void ColumnStatsCollector::initialize() { VELOX_CHECK_NOT_NULL(groupingSet_); } -// static -RowTypePtr ColumnStatsCollector::outputType( - const core::ColumnStatsSpec& statsSpec) { - // Create output type based on the column stats collection specs. - std::vector names; - std::vector types; - const auto numAggregates = statsSpec.aggregates.size(); - const auto outputTypeSize = statsSpec.groupingKeys.size() + numAggregates; - names.reserve(outputTypeSize); - types.reserve(outputTypeSize); - for (const auto& key : statsSpec.groupingKeys) { - names.push_back(key->name()); - types.push_back(key->type()); - } - for (auto i = 0; i < numAggregates; ++i) { - names.push_back(statsSpec.aggregateNames[i]); - types.push_back(statsSpec.aggregates[i].call->type()); - } - return ROW(std::move(names), std::move(types)); -} - void ColumnStatsCollector::setOutputType() { VELOX_CHECK_NULL(outputType_); - outputType_ = outputType(statsSpec_); + outputType_ = statsSpec_.outputType(); } std::pair, std::vector> @@ -120,9 +99,8 @@ std::vector ColumnStatsCollector::createAggregates( } } VELOX_CHECK(!aggregate.distinct); - info.intermediateType = resolveAggregateFunction( - aggregate.call->name(), aggregate.rawInputTypes) - .second; + info.intermediateType = resolveIntermediateType( + aggregate.call->name(), aggregate.rawInputTypes); // Column stats collection doesn't support aggregation mask. VELOX_CHECK_NULL(aggregate.mask); info.mask = std::nullopt; @@ -181,7 +159,6 @@ void ColumnStatsCollector::addInput(RowVectorPtr input) { return; } - // Add input to the grouping set groupingSet_->addInput(input, /*mayPushdown=*/false); } diff --git a/velox/exec/ColumnStatsCollector.h b/velox/exec/ColumnStatsCollector.h index 26af22142e3..5c76f9af391 100644 --- a/velox/exec/ColumnStatsCollector.h +++ b/velox/exec/ColumnStatsCollector.h @@ -51,11 +51,6 @@ class ColumnStatsCollector { memory::MemoryPool* pool, tsan_atomic* nonReclaimableSection); - /// Returns the output row type that will be produced by this collector. - /// The output type is determined by the grouping keys and aggregate functions - /// specified in the ColumnStatsSpec. - static RowTypePtr outputType(const core::ColumnStatsSpec& statsSpec); - /// Initializes the stats collector. Must be called exactly once before /// adding any input data. Sets up internal aggregation structures based /// on the provided ColumnStatsSpec. diff --git a/velox/exec/ContainerRowSerde.cpp b/velox/exec/ContainerRowSerde.cpp index 89f2de3f3a3..b571a0a7b95 100644 --- a/velox/exec/ContainerRowSerde.cpp +++ b/velox/exec/ContainerRowSerde.cpp @@ -388,11 +388,8 @@ std::optional compareSwitch( template < bool typeProvidesCustomComparison, TypeKind Kind, - std::enable_if_t< - Kind != TypeKind::VARCHAR && Kind != TypeKind::VARBINARY && - Kind != TypeKind::ARRAY && Kind != TypeKind::MAP && - Kind != TypeKind::ROW, - int32_t> = 0> + std::enable_if_t = + 0> std::optional compare( ByteInputStream& left, const BaseVector& right, @@ -693,11 +690,8 @@ std::optional compareSwitch( template < bool typeProvidesCustomComparison, TypeKind Kind, - std::enable_if_t< - Kind != TypeKind::VARCHAR && Kind != TypeKind::VARBINARY && - Kind != TypeKind::ARRAY && Kind != TypeKind::MAP && - Kind != TypeKind::ROW, - int32_t> = 0> + std::enable_if_t = + 0> std::optional compare( ByteInputStream& left, ByteInputStream& right, @@ -898,11 +892,8 @@ uint64_t hashSwitch(ByteInputStream& stream, const Type* type); template < bool typeProvidesCustomComparison, TypeKind Kind, - std::enable_if_t< - Kind != TypeKind::VARBINARY && Kind != TypeKind::VARCHAR && - Kind != TypeKind::ARRAY && Kind != TypeKind::MAP && - Kind != TypeKind::ROW, - int32_t> = 0> + std::enable_if_t = + 0> uint64_t hashOne(ByteInputStream& stream, const Type* type) { using T = typename TypeTraits::NativeType; diff --git a/velox/exec/Cursor.cpp b/velox/exec/Cursor.cpp index bdbbc0fe50c..1d807609e18 100644 --- a/velox/exec/Cursor.cpp +++ b/velox/exec/Cursor.cpp @@ -13,38 +13,89 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #include "velox/exec/Cursor.h" -#include "velox/common/file/FileSystems.h" +#include #include +#include + +#include "velox/common/file/FileSystems.h" +#include "velox/common/future/VeloxPromise.h" +#include "velox/exec/BlockingReason.h" +#include "velox/vector/EncodedVectorCopy.h" namespace facebook::velox::exec { +namespace { -bool waitForTaskDriversToFinish(exec::Task* task, uint64_t maxWaitMicros) { - VELOX_USER_CHECK(!task->isRunning()); - uint64_t waitMicros = 0; - while ((task->numFinishedDrivers() != task->numTotalDrivers()) && - (waitMicros < maxWaitMicros)) { - const uint64_t kWaitMicros = 1000; - std::this_thread::sleep_for(std::chrono::microseconds(kWaitMicros)); - waitMicros += kWaitMicros; +class TaskQueue { + public: + struct TaskQueueEntry { + RowVectorPtr vector; + uint64_t bytes; + }; + + explicit TaskQueue( + uint64_t maxBytes, + const std::shared_ptr& outputPool) + : pool_( + outputPool != nullptr ? outputPool + : memory::memoryManager()->addLeafPool()), + maxBytes_(maxBytes) {} + + void setNumProducers(int32_t n) { + numProducers_ = n; } - if (task->numFinishedDrivers() != task->numTotalDrivers()) { - LOG(ERROR) << "Timed out waiting for all drivers of task " << task->taskId() - << " to finish. Finished drivers: " << task->numFinishedDrivers() - << ". Total drivers: " << task->numTotalDrivers(); + // Adds a batch of rows to the queue and returns kNotBlocked if the + // producer may continue. Returns kWaitForConsumer if the queue is + // full after the addition and sets '*future' to a future that is + // realized when the producer may continue. + exec::BlockingReason + enqueue(RowVectorPtr vector, bool drained, velox::ContinueFuture* future); + + // Returns nullptr when all producers are at end. Otherwise blocks. + RowVectorPtr dequeue(); + + void close(); + + bool hasNext(); + + velox::memory::MemoryPool* pool() const { + return pool_.get(); } - return task->numFinishedDrivers() == task->numTotalDrivers(); -} + std::optional numProducers_; + std::atomic_int32_t numDrainedProducers_{0}; + + private: + // Owns the vectors in 'queue_', hence must be declared first. + std::shared_ptr pool_; + std::deque queue_; + int32_t producersFinished_ = 0; + uint64_t totalBytes_ = 0; + // Blocks the producer if 'totalBytes' exceeds 'maxBytes' after + // adding the result. + uint64_t maxBytes_; + std::mutex mutex_; + std::vector producerUnblockPromises_; + bool consumerBlocked_ = false; + ContinuePromise consumerPromise_{ContinuePromise::makeEmpty()}; + ContinueFuture consumerFuture_; + bool closed_ = false; +}; exec::BlockingReason TaskQueue::enqueue( RowVectorPtr vector, + bool drained, velox::ContinueFuture* future) { if (!vector) { std::lock_guard l(mutex_); - ++producersFinished_; + if (drained) { + ++numDrainedProducers_; + } else { + ++producersFinished_; + } if (consumerBlocked_) { consumerBlocked_ = false; consumerPromise_.setValue(); @@ -67,7 +118,8 @@ exec::BlockingReason TaskQueue::enqueue( consumerPromise_.setValue(); } if (totalBytes_ > maxBytes_) { - auto [unblockPromise, unblockFuture] = makeVeloxContinuePromiseContract(); + auto [unblockPromise, unblockFuture] = + makeVeloxContinuePromiseContract("TaskQueue::enqueue"); producerUnblockPromises_.emplace_back(std::move(unblockPromise)); *future = std::move(unblockFuture); return exec::BlockingReason::kWaitForConsumer; @@ -96,10 +148,13 @@ RowVectorPtr TaskQueue::dequeue() { } else if ( numProducers_.has_value() && producersFinished_ == numProducers_) { return nullptr; + } else if (numDrainedProducers_ == numProducers_) { + return nullptr; } + if (!vector) { consumerBlocked_ = true; - consumerPromise_ = ContinuePromise(); + consumerPromise_ = ContinuePromise("TaskQueue::dequeue"); consumerFuture_ = consumerPromise_.getFuture(); } } @@ -161,8 +216,15 @@ class TaskCursorBase : public TaskCursor { fmt::format("TaskCursorQuery_{}", cursorQueryId++)); } - if (!params.queryConfigs.empty()) { - auto configCopy = params.queryConfigs; + // If query configs needs to be overwritten in queryCtx. + if (!params.queryConfigs.empty() || !params.breakpoints.empty()) { + auto configCopy = !params.queryConfigs.empty() + ? params.queryConfigs + : queryCtx_->queryConfig().rawConfigsCopy(); + + if (!params.breakpoints.empty()) { + configCopy.insert({core::QueryConfig::kQueryTraceEnabled, "true"}); + } queryCtx_->testingOverrideConfigUnsafe(std::move(configCopy)); } @@ -174,22 +236,25 @@ class TaskCursorBase : public TaskCursor { if (!params.spillDirectory.empty()) { taskSpillDirectory_ = params.spillDirectory + "/" + taskId_; - auto fileSystem = - velox::filesystems::getFileSystem(taskSpillDirectory_, nullptr); - VELOX_CHECK_NOT_NULL(fileSystem, "File System is null!"); - try { - fileSystem->mkdir(taskSpillDirectory_); - } catch (...) { - LOG(ERROR) << "Faield to create task spill directory " - << taskSpillDirectory_ << " base director " - << params.spillDirectory << " exists[" - << std::filesystem::exists(taskSpillDirectory_) << "]"; - - std::rethrow_exception(std::current_exception()); - } + taskSpillDirectoryCb_ = params.spillDirectoryCallback; + if (taskSpillDirectoryCb_ == nullptr) { + auto fileSystem = + velox::filesystems::getFileSystem(taskSpillDirectory_, nullptr); + VELOX_CHECK_NOT_NULL(fileSystem, "File System is null!"); + try { + fileSystem->mkdir(taskSpillDirectory_); + } catch (...) { + LOG(ERROR) << "Faield to create task spill directory " + << taskSpillDirectory_ << " base director " + << params.spillDirectory << " exists[" + << std::filesystem::exists(taskSpillDirectory_) << "]"; + + std::rethrow_exception(std::current_exception()); + } - LOG(INFO) << "Task spill directory[" << taskSpillDirectory_ - << "] created"; + LOG(INFO) << "Task spill directory[" << taskSpillDirectory_ + << "] created"; + } } } @@ -198,6 +263,7 @@ class TaskCursorBase : public TaskCursor { std::shared_ptr queryCtx_; core::PlanFragment planFragment_; std::string taskSpillDirectory_; + std::function taskSpillDirectoryCb_; private: std::shared_ptr executor_; @@ -209,7 +275,7 @@ class MultiThreadedTaskCursor : public TaskCursorBase { : TaskCursorBase( params, std::make_shared( - std::thread::hardware_concurrency())), + folly::available_concurrency())), maxDrivers_{params.maxDrivers}, numConcurrentSplitGroups_{params.numConcurrentSplitGroups}, numSplitGroups_{params.numSplitGroups} { @@ -222,7 +288,14 @@ class MultiThreadedTaskCursor : public TaskCursorBase { std::make_shared(params.bufferedBytes, params.outputPool); // Captured as a shared_ptr by the consumer callback of task_. - auto queue = queue_; + auto queueHolder = std::weak_ptr(queue_); + std::optional spillDiskOpts; + if (!taskSpillDirectory_.empty()) { + spillDiskOpts = common::SpillDiskOptions{ + .spillDirPath = taskSpillDirectory_, + .spillDirCreated = taskSpillDirectoryCb_ == nullptr, + .spillDirCreateCb = taskSpillDirectoryCb_}; + } task_ = Task::create( taskId_, std::move(planFragment_), @@ -230,35 +303,39 @@ class MultiThreadedTaskCursor : public TaskCursorBase { std::move(queryCtx_), Task::ExecutionMode::kParallel, // consumer - [queue, copyResult = params.copyResult]( + [queueHolder, copyResult = params.copyResult, taskId = taskId_]( const RowVectorPtr& vector, bool drained, velox::ContinueFuture* future) { - VELOX_CHECK( - !drained, "Unexpected drain in multithreaded task cursor"); - if (!vector || !copyResult) { - return queue->enqueue(vector, future); + auto queue = queueHolder.lock(); + if (queue == nullptr) { + LOG(ERROR) << "TaskQueue has been destroyed, taskId: " << taskId; + return exec::BlockingReason::kNotBlocked; } - // Make sure to load lazy vector if not loaded already. - for (auto& child : vector->children()) { - child->loadedVector(); + + if (!vector || !copyResult) { + return queue->enqueue(vector, drained, future); } - auto copy = BaseVector::create( - vector->type(), vector->size(), queue->pool()); - copy->copy(vector.get(), 0, 0, vector->size()); - return queue->enqueue(std::move(copy), future); + VectorPtr copy = encodedVectorCopy( + {.pool = queue->pool(), .reuseSource = false}, vector); + return queue->enqueue( + std::static_pointer_cast(std::move(copy)), + drained, + future); }, 0, - [queue](std::exception_ptr) { + std::move(spillDiskOpts), + [queueHolder, taskId = taskId_](std::exception_ptr) { // onError close the queue to unblock producers and consumers. // moveNext will handle rethrowing the error once it's // unblocked. + auto queue = queueHolder.lock(); + if (queue == nullptr) { + LOG(ERROR) << "TaskQueue has been destroyed, taskId: " << taskId; + return; + } queue->close(); }); - - if (!taskSpillDirectory_.empty()) { - task_->setSpillDirectory(taskSpillDirectory_); - } } ~MultiThreadedTaskCursor() override { @@ -300,11 +377,22 @@ class MultiThreadedTaskCursor : public TaskCursorBase { checkTaskError(); if (!current_) { + if (queue_->numDrainedProducers_ > 0) { + VELOX_CHECK(queue_->numProducers_.has_value()); + VELOX_CHECK_EQ( + queue_->numDrainedProducers_.load(), queue_->numProducers_.value()); + queue_->numDrainedProducers_ = 0; + return false; + } atEnd_ = true; } return current_ != nullptr; } + bool moveStep(const core::PlanNodeId& /*planId*/ = "") override { + return moveNext(); + } + void setNoMoreSplits() override { VELOX_CHECK(!noMoreSplits_); noMoreSplits_ = true; @@ -314,14 +402,14 @@ class MultiThreadedTaskCursor : public TaskCursorBase { return noMoreSplits_; } - bool hasNext() override { - return queue_->hasNext(); - } - RowVectorPtr& current() override { return current_; } + core::PlanNodeId at() const override { + return ""; // always at task output. + } + void setError(std::exception_ptr error) override { error_ = error; if (task_) { @@ -371,17 +459,22 @@ class SingleThreadedTaskCursor : public TaskCursorBase { VELOX_CHECK( !queryCtx_->isExecutorSupplied(), "Executor should not be set in serial task cursor"); - + std::optional spillDiskOpts; + if (!taskSpillDirectory_.empty()) { + spillDiskOpts = common::SpillDiskOptions{ + .spillDirPath = taskSpillDirectory_, + .spillDirCreated = true, + .spillDirCreateCb = taskSpillDirectoryCb_}; + } task_ = Task::create( taskId_, std::move(planFragment_), params.destination, std::move(queryCtx_), - Task::ExecutionMode::kSerial); - - if (!taskSpillDirectory_.empty()) { - task_->setSpillDirectory(taskSpillDirectory_); - } + Task::ExecutionMode::kSerial, + std::function{}, + 0, + std::move(spillDiskOpts)); VELOX_CHECK( task_->supportSerialExecutionMode(), @@ -389,7 +482,7 @@ class SingleThreadedTaskCursor : public TaskCursorBase { } ~SingleThreadedTaskCursor() override { - if (task_ && !SingleThreadedTaskCursor::hasNext()) { + if (task_) { task_->requestCancel().wait(); } } @@ -408,26 +501,15 @@ class SingleThreadedTaskCursor : public TaskCursorBase { } bool moveNext() override { - if (!hasNext()) { - return false; - } - current_ = next_; - next_ = nullptr; - return true; - }; - - bool hasNext() override { - if (next_) { - return true; - } if (!task_->isRunning()) { return false; } + while (true) { ContinueFuture future = ContinueFuture::makeEmpty(); RowVectorPtr next = task_->next(&future); if (next != nullptr) { - next_ = next; + current_ = next; return true; } // When next is returned from task as a null pointer. @@ -439,12 +521,21 @@ class SingleThreadedTaskCursor : public TaskCursorBase { VELOX_CHECK_NULL(next); future.wait(); } - }; + return false; + } + + bool moveStep(const core::PlanNodeId& /*planId*/ = "") override { + return moveNext(); + } RowVectorPtr& current() override { return current_; } + core::PlanNodeId at() const override { + return ""; // always at task output. + } + void setError(std::exception_ptr error) override { error_ = error; if (task_) { @@ -460,16 +551,434 @@ class SingleThreadedTaskCursor : public TaskCursorBase { std::shared_ptr task_; bool noMoreSplits_{false}; RowVectorPtr current_; - RowVectorPtr next_; std::exception_ptr error_; }; -std::unique_ptr TaskCursor::create(const CursorParameters& params) { - if (params.serialExecution) { - return std::make_unique(params); +/// Common base class for debugging cursors that support breakpoints. +/// +/// Provides shared infrastructure for pausing execution at traced operators +/// and inspecting intermediate results. Subclasses implement the execution +/// model (serial vs parallel) via the `advance()` and `start()` methods. +class TaskDebuggerCursorBase : public TaskCursorBase { + protected: + // Internal state for coordinating between the tracer and cursor. + // + // This struct manages the synchronization between the trace writer + // (which produces intermediate results) and the cursor (which consumes + // them). + struct TraceDriverState { + // Promise used to signal the tracer to continue after a partial result + // has been consumed. + ContinuePromise tracePromise{ContinuePromise::makeEmpty()}; + + // The most recent intermediate result from a traced operator. + RowVectorPtr traceData; + + // The plan id where this state came from. + core::PlanNodeId planId; + }; + + struct TraceState { + std::deque queue; + std::mutex mutex; + + // Consumer blocking fields used by the parallel cursor to coordinate + // between the consumer callback and the advance() loop. In serial mode, + // consumerBlocked is never set to true, so the wakeup code is a no-op. + bool consumerBlocked = false; + ContinuePromise consumerPromise{ContinuePromise::makeEmpty()}; + ContinueFuture consumerFuture; + }; + + // Custom trace context implementation for the debugger. + // + // This trace context pauses execution at traced operators by blocking + // the trace writer until the cursor consumes the intermediate result. + class TaskDebuggerTraceCtx : public trace::TraceCtx { + public: + // Constructs a trace context for the specified plan nodes. + // + // @param breakpoints Map of plan node IDs to optional callbacks. + // @param traceState Reference to the shared trace state for coordination. + TaskDebuggerTraceCtx( + const CursorParameters::TBreakpointMap& breakpoints, + TraceState& traceState) + : TraceCtx(false), breakpoints_(breakpoints), traceState_(traceState) {} + + // Determines whether a given operator should be traced. + // + // @param op The operator to check. + // @return true if the operator's plan node ID is in the traced set. + bool shouldTrace(const Operator& op) const override { + return breakpoints_.contains(op.planNodeId()); + } + + // Creates an input trace writer for the given operator. + // + // @param op The operator to create a tracer for. + // @return A unique pointer to the trace input writer. + std::unique_ptr createInputTracer( + Operator& op) const override { + auto it = breakpoints_.find(op.planNodeId()); + return std::make_unique( + op.planNodeId(), + it != breakpoints_.end() ? it->second : nullptr, + traceState_); + } + + private: + // Trace writer that captures input vectors and pauses execution. + // + // When an input vector is written, this writer stores it in the shared + // trace state and blocks until the cursor signals to continue. + class TaskDebuggerTraceInputWriter : public trace::TraceInputWriter { + public: + TaskDebuggerTraceInputWriter( + const core::PlanNodeId& planId, + CursorParameters::BreakpointCallback callback, + TraceState& traceState) + : planId_(planId), + callback_(std::move(callback)), + traceState_(traceState) {} + + // Writes an input vector and potentially pauses execution. + // + // Invokes the callback if set. If the callback returns false, the writer + // does not block and execution continues. If the callback returns true + // (or is null), stores the vector in the trace state and creates a future + // that blocks until the cursor consumes the result and signals + // continuation. + // + // @param vector The input vector to trace. + // @param future Output parameter set to a future that blocks until + // the cursor is ready to continue. + // @return true if the writer is blocked waiting for the future, false + // if execution should continue without blocking. + bool write(const RowVectorPtr& vector, ContinueFuture* future) override { + // Invoke the callback if set. If it returns false, don't block. + if (callback_ && !callback_(vector)) { + return false; + } + + std::lock_guard l(traceState_.mutex); + traceState_.queue.push_back( + TraceDriverState{ + .tracePromise = ContinuePromise("TaskQueue::dequeue"), + .traceData = vector, + .planId = planId_, + }); + *future = traceState_.queue.back().tracePromise.getFuture(); + + // If the consumer is blocked waiting for output, unblock it. + if (traceState_.consumerBlocked) { + traceState_.consumerBlocked = false; + traceState_.consumerPromise.setValue(); + } + return true; + } + + // Called when tracing is complete for this operator. + void finish() override {} + + private: + const core::PlanNodeId planId_; + const CursorParameters::BreakpointCallback callback_; + TraceState& traceState_; + }; + + CursorParameters::TBreakpointMap breakpoints_; + TraceState& traceState_; + }; + + using TaskCursorBase::TaskCursorBase; + + /// Ensures the task completes before cleanup. + ~TaskDebuggerCursorBase() override { + if (task_) { + task_->requestCancel().wait(); + } + } + + public: + bool moveNext() override { + return advance(false); + } + + bool moveStep(const core::PlanNodeId& planId = "") override { + return advance(true, planId); + } + + RowVectorPtr& current() override { + return current_; + } + + core::PlanNodeId at() const override { + if (pendingTraceDriverState_) { + return pendingTraceDriverState_->planId; + } + return ""; } - return std::make_unique(params); -} + + void setNoMoreSplits() override { + VELOX_CHECK(!noMoreSplits_); + noMoreSplits_ = true; + } + + bool noMoreSplits() const override { + return noMoreSplits_; + } + + void setError(std::exception_ptr error) override { + error_ = error; + if (task_) { + task_->setError(error); + } + } + + const std::shared_ptr& task() override { + return task_; + } + + protected: + // Advance to the next vector to produce, storing it in `current_`. If + // `isStep` is true, move to the next trace point or task output. If false, + // moves to the next task output. + // + // If `isStep` is true and `planId` is non-empty, only stops at trace points + // matching the given plan node ID; other trace points are skipped. + // + // Returns false when the task is done producing output. + virtual bool advance(bool isStep, const core::PlanNodeId& planId = "") = 0; + + // Unblocks the trace writer (driver) from the previously consumed trace + // state. + void unblockPendingState() { + if (pendingTraceDriverState_) { + pendingTraceDriverState_->tracePromise.setValue(); + pendingTraceDriverState_.reset(); + } + } + + std::shared_ptr task_; + bool noMoreSplits_{false}; + RowVectorPtr current_; + std::exception_ptr error_; + + // Holds the trace state that was returned to the user via moveStep(), + // so its promise can be fulfilled on the next advance() call. + std::optional pendingTraceDriverState_; +}; + +/// A debugging cursor for interactive serial task execution. +/// +/// @note This class assumes serial (single-threaded) execution mode. +class TaskDebuggerSerialCursor : public TaskDebuggerCursorBase { + public: + explicit TaskDebuggerSerialCursor(const CursorParameters& params) + : TaskDebuggerCursorBase(params, nullptr) { + // Installs the required trace provider. + queryCtx_->setTraceCtxProvider( + [&](core::QueryCtx&, const core::PlanFragment&) { + return std::make_unique( + params.breakpoints, traceState_); + }); + + task_ = Task::create( + taskId_, + std::move(planFragment_), + params.destination, + std::move(queryCtx_), + Task::ExecutionMode::kSerial); + } + + // no-op + void start() override {} + + private: + bool advance(bool isStep, const core::PlanNodeId& planId = "") override { + if (error_) { + std::rethrow_exception(error_); + } + + unblockPendingState(); + + while (true) { + ContinueFuture future = ContinueFuture::makeEmpty(); + + if (auto vector = task_->next(&future)) { + current_ = vector; + return true; + } + + // Check if any trace states have been queued by writers. + { + std::lock_guard l(traceState_.mutex); + if (!traceState_.queue.empty()) { + auto state = std::move(traceState_.queue.front()); + traceState_.queue.pop_front(); + + if (isStep && (planId.empty() || state.planId == planId)) { + current_ = state.traceData; + pendingTraceDriverState_ = std::move(state); + return true; + } + + // Signal the task driver to unblock. + state.tracePromise.setValue(); + } + } + + // Wait until the task future is unblocked. + if (future.valid()) { + future.wait(); + } else { + // When no vector was produced and the future is not valid, it's the + // task signal that it has finished producing output. + VELOX_CHECK(!task_->isRunning() || !noMoreSplits_); + break; + } + } + return false; + } + + TraceState traceState_; +}; + +/// A debugging cursor for interactive parallel task execution. +/// +/// Uses a consumer callback to receive output from parallel drivers. +class TaskDebuggerParallelCursor : public TaskDebuggerCursorBase { + public: + explicit TaskDebuggerParallelCursor(const CursorParameters& params) + : TaskDebuggerCursorBase( + params, + std::make_shared( + folly::available_concurrency())), + maxDrivers_(params.maxDrivers), + numConcurrentSplitGroups_(params.numConcurrentSplitGroups) { + // Installs the required trace provider. + queryCtx_->setTraceCtxProvider( + [&](core::QueryCtx&, const core::PlanFragment&) { + return std::make_unique( + params.breakpoints, traceState_); + }); + + task_ = Task::create( + taskId_, + std::move(planFragment_), + params.destination, + std::move(queryCtx_), + Task::ExecutionMode::kParallel, + // consumer + [&](const RowVectorPtr& vector, + bool drained, + velox::ContinueFuture* future) { + VELOX_CHECK( + !drained, "Unexpected drain in multithreaded task cursor"); + + if (!vector) { + // End-of-stream from a driver. Track completion and wake the + // consumer if it's blocked waiting for data. + std::lock_guard l(traceState_.mutex); + ++traceState_.numFinishedProducers; + if (traceState_.consumerBlocked) { + traceState_.consumerBlocked = false; + traceState_.consumerPromise.setValue(); + } + return exec::BlockingReason::kNotBlocked; + } + + std::lock_guard l(traceState_.mutex); + traceState_.queue.push_back( + TraceDriverState{ + .tracePromise = ContinuePromise("TaskQueue::dequeue"), + .traceData = vector, + .planId = "", + }); + *future = traceState_.queue.back().tracePromise.getFuture(); + + if (traceState_.consumerBlocked) { + traceState_.consumerBlocked = false; + traceState_.consumerPromise.setValue(); + } + return exec::BlockingReason::kWaitForConsumer; + }); + } + + void start() override { + if (!started_) { + started_ = true; + try { + task_->start(maxDrivers_, numConcurrentSplitGroups_); + numProducers_ = task_->numOutputDrivers(); + } catch (const VeloxException& e) { + // Could not find output pipeline, due to Task terminated before + // start. Do not override the error. + if (e.message().find("Output pipeline not found for task") == + std::string::npos) { + throw; + } + } + } + } + + private: + bool advance(bool isStep, const core::PlanNodeId& planId = "") override { + start(); + if (error_) { + std::rethrow_exception(error_); + } + + unblockPendingState(); + + while (true) { + // Check if any trace states have been queued by writers. + { + std::lock_guard l(traceState_.mutex); + if (!traceState_.queue.empty()) { + auto state = std::move(traceState_.queue.front()); + traceState_.queue.pop_front(); + + const bool matchesPlanId = planId.empty() || state.planId == planId; + if ((isStep && matchesPlanId) || state.planId.empty()) { + current_ = state.traceData; + pendingTraceDriverState_ = std::move(state); + return true; + } + + // moveNext() skips breakpoint trace data, or the planId filter + // didn't match; unblock the trace writer (driver). + state.tracePromise.setValue(); + } + + // Queue is empty. If all producers have finished, we're done. + if (traceState_.numFinishedProducers >= numProducers_) { + break; + } + + traceState_.consumerBlocked = true; + traceState_.consumerPromise = ContinuePromise("TaskQueue::dequeue"); + traceState_.consumerFuture = traceState_.consumerPromise.getFuture(); + } + traceState_.consumerFuture.wait(); + } + return false; + } + + struct ParallelTraceState : TraceState { + // Number of output drivers that have finished (sent a null vector). + int numFinishedProducers{0}; + }; + + ParallelTraceState traceState_; + + bool started_{false}; + int32_t maxDrivers_; + int32_t numConcurrentSplitGroups_; + int numProducers_{0}; +}; + +} // namespace bool RowCursor::next() { if (++currentRow_ < numRows_) { @@ -498,8 +1007,38 @@ bool RowCursor::next() { return true; } -bool RowCursor::hasNext() { - return currentRow_ < numRows_ || cursor_->hasNext(); +std::unique_ptr TaskCursor::create(const CursorParameters& params) { + if (!params.breakpoints.empty()) { + if (params.serialExecution) { + return std::make_unique(params); + } else { + return std::make_unique(params); + } + } + + if (params.serialExecution) { + return std::make_unique(params); + } + return std::make_unique(params); +} + +bool waitForTaskDriversToFinish(exec::Task* task, uint64_t maxWaitMicros) { + VELOX_USER_CHECK(!task->isRunning()); + uint64_t waitMicros = 0; + while ((task->numFinishedDrivers() != task->numTotalDrivers()) && + (waitMicros < maxWaitMicros)) { + const uint64_t kWaitMicros = 1000; + std::this_thread::sleep_for(std::chrono::microseconds(kWaitMicros)); + waitMicros += kWaitMicros; + } + + if (task->numFinishedDrivers() != task->numTotalDrivers()) { + LOG(ERROR) << "Timed out waiting for all drivers of task " << task->taskId() + << " to finish. Finished drivers: " << task->numFinishedDrivers() + << ". Total drivers: " << task->numTotalDrivers(); + } + + return task->numFinishedDrivers() == task->numTotalDrivers(); } } // namespace facebook::velox::exec diff --git a/velox/exec/Cursor.h b/velox/exec/Cursor.h index 57d1ec8380a..7f859e98c61 100644 --- a/velox/exec/Cursor.h +++ b/velox/exec/Cursor.h @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #pragma once #include "velox/core/PlanNode.h" @@ -21,19 +22,13 @@ namespace facebook::velox::exec { -/// Wait up to maxWaitMicros for all the task drivers to finish. The function -/// returns true if all the drivers have finished, otherwise false. -/// -/// NOTE: user must call this on a finished or failed task. -bool waitForTaskDriversToFinish( - exec::Task* task, - uint64_t maxWaitMicros = 1'000'000); - /// Parameters for initializing a TaskCursor or RowCursor. struct CursorParameters { /// Root node of the plan tree std::shared_ptr planNode; + /// Partition number if task is expected to receive data from a remote data + /// shuffle. Used to initialize ExchangeClient. int32_t destination{0}; /// Maximum number of drivers per pipeline. @@ -69,7 +64,13 @@ struct CursorParameters { /// Spilling directory, if not empty, then the task's spilling directory /// would be built from it. - std::string spillDirectory; + std::string spillDirectory = ""; + + /// Callback function to dynamically create or determine the spill directory + /// path at runtime. If provided, this callback is invoked when spilling is + /// needed and must return a valid directory path. This allows for dynamic + /// spill directory creation or path resolution based on runtime conditions. + std::function spillDirectoryCallback = nullptr; bool copyResult = true; @@ -81,65 +82,58 @@ struct CursorParameters { /// If both 'queryConfigs' and 'queryCtx' are specified, the configurations /// in 'queryCtx' will be overridden by 'queryConfig'. - std::unordered_map queryConfigs; -}; + std::unordered_map queryConfigs = {}; -class TaskQueue { - public: - struct TaskQueueEntry { - RowVectorPtr vector; - uint64_t bytes; - }; - - explicit TaskQueue( - uint64_t maxBytes, - const std::shared_ptr& outputPool) - : pool_( - outputPool != nullptr ? outputPool - : memory::memoryManager()->addLeafPool()), - maxBytes_(maxBytes) {} - - void setNumProducers(int32_t n) { - numProducers_ = n; - } - - // Adds a batch of rows to the queue and returns kNotBlocked if the - // producer may continue. Returns kWaitForConsumer if the queue is - // full after the addition and sets '*future' to a future that is - // realized when the producer may continue. - exec::BlockingReason enqueue( - RowVectorPtr vector, - velox::ContinueFuture* future); - - // Returns nullptr when all producers are at end. Otherwise blocks. - RowVectorPtr dequeue(); - - void close(); + // Debugging related structures: - bool hasNext(); - - velox::memory::MemoryPool* pool() const { - return pool_.get(); - } - - private: - // Owns the vectors in 'queue_', hence must be declared first. - std::shared_ptr pool_; - std::deque queue_; - std::optional numProducers_; - int32_t producersFinished_ = 0; - uint64_t totalBytes_ = 0; - // Blocks the producer if 'totalBytes' exceeds 'maxBytes' after - // adding the result. - uint64_t maxBytes_; - std::mutex mutex_; - std::vector producerUnblockPromises_; - bool consumerBlocked_ = false; - ContinuePromise consumerPromise_; - ContinueFuture consumerFuture_; - bool closed_ = false; + /// Callback type for breakpoints. + /// + /// Called with the current vector when the breakpoint is hit. The return + /// value semantics is "should block?"; returns true if the driver should stop + /// and produce the vector, false to continue without stopping. + /// + /// If callback is not specified (nullptr), assume the partial vector should + /// always be produced, which is the same as callback returning true (always + /// stop). + using BreakpointCallback = std::function; + + /// Map type for breakpoints: plan node ID to optional callback. + using TBreakpointMap = + std::unordered_map; + + /// Breakpoints enable step-by-step execution of a query plan, allowing users + /// to inspect intermediate results at operator boundaries containing + /// breakpoints. This is useful for debugging query execution and + /// understanding data flow through operators. + /// + /// Maps plan node IDs to optional callbacks. When a breakpoint is hit, the + /// callback (if non-null) is invoked with the current vector before the + /// cursor pauses. + TBreakpointMap breakpoints = {}; }; +/// Abstract interface for iterating over query results. TaskCursor manages +/// task execution and provides batch-level access to output vectors. +/// +/// Example usage: +/// @code +/// +/// auto cursor = TaskCursor:create({ +/// .planNode = node, +/// ); +/// +/// // Run through every output. +/// while (cursor->moveNext()) { +/// auto vector = cursor->current(); +/// } +/// @endcode +/// +/// If "breakpoints" are set in the CursorParameters input, then +/// `cursor->moveStep()` will move the cursor to the next breakpoint, which is +/// either the input of an operator with a breakpoint installed, or the next +/// task output. +/// +/// `cursor->moveNext()` will always move the cursor to the next task output. class TaskCursor { public: virtual ~TaskCursor() = default; @@ -149,14 +143,32 @@ class TaskCursor { /// Starts the task if not started yet. virtual void start() = 0; - /// Fetches another batch from the task queue. - /// Starts the task if not started yet. + /// Fetches another batch from the task queue. Starts the task if not started + /// yet. + /// + /// @return Returns false is the task is done producing output. virtual bool moveNext() = 0; - virtual bool hasNext() = 0; + /// Steps through execution, returning either the input to the next operator + /// with a breakpoint installed, or the next task output. If no breakpoints + /// are set, then moveStep() == moveNext(). + /// + /// If @planId is non-empty, only stops at a breakpoint whose plan node ID + /// matches @planId; breakpoints for other plan nodes are skipped + /// (unblocked) automatically. When empty (the default), stops at the next + /// breakpoint regardless of plan node ID. + /// + /// @return Returns false is the task is done producing output. + virtual bool moveStep(const core::PlanNodeId& planId = "") = 0; + /// Returns the vector the cursor is currently on. virtual RowVectorPtr& current() = 0; + /// If breakpoints are set, returns the plan node that generated the trace. If + /// the cursor is at the task output or if there are no breakpoints, + /// returns empty string. + virtual core::PlanNodeId at() const = 0; + virtual void setError(std::exception_ptr error) = 0; virtual bool noMoreSplits() const = 0; @@ -166,6 +178,8 @@ class TaskCursor { virtual const std::shared_ptr& task() = 0; }; +/// Row-level cursor that wraps a TaskCursor and provides access to individual +/// rows and column values within the result set. class RowCursor { public: explicit RowCursor(CursorParameters& params) { @@ -185,8 +199,6 @@ class RowCursor { bool next(); - bool hasNext(); - std::shared_ptr task() const { return cursor_->task(); } @@ -204,4 +216,12 @@ class RowCursor { vector_size_t numRows_ = 0; }; +/// Wait up to maxWaitMicros for all the task drivers to finish. The function +/// returns true if all the drivers have finished, otherwise false. +/// +/// NOTE: user must call this on a finished or failed task. +bool waitForTaskDriversToFinish( + exec::Task* task, + uint64_t maxWaitMicros = 1'000'000); + } // namespace facebook::velox::exec diff --git a/velox/exec/DistinctAggregations.cpp b/velox/exec/DistinctAggregations.cpp index e51413d4843..81936cd7a3f 100644 --- a/velox/exec/DistinctAggregations.cpp +++ b/velox/exec/DistinctAggregations.cpp @@ -20,6 +20,152 @@ namespace facebook::velox::exec { namespace { +// Handles distinct aggregations where all inputs are constants. +// The distinct set of a constant tuple is always either empty or a single +// element, so it only needs a boolean flag per group indicating whether any +// row was seen. +class ConstantDistinctAggregations : public DistinctAggregations { + public: + ConstantDistinctAggregations( + std::vector aggregates, + memory::MemoryPool* pool) + : pool_{pool}, aggregates_{std::move(aggregates)} { + for (const auto& aggregate : aggregates_) { + for (size_t i = 0; i < aggregate->inputs.size(); ++i) { + VELOX_DCHECK_NOT_NULL(aggregate->constantInputs[i]); + } + } + } + + Accumulator accumulator() const override { + return {/*isFixedSize=*/true, + sizeof(bool), + /*usesExternalMemory=*/false, + /*alignment=*/1, + BOOLEAN(), + /*spillExtractFunction=*/ + [this](folly::Range groups, VectorPtr& result) { + extractForSpill(groups, result); + }, + /*destroyFunction=*/nullptr}; + } + + void addInput( + char** groups, + const RowVectorPtr& /*input*/, + const SelectivityVector& rows) override { + rows.applyToSelected([&](vector_size_t i) { value(groups[i]) = true; }); + } + + void addSingleGroupInput( + char* group, + const RowVectorPtr& /*input*/, + const SelectivityVector& /*rows*/) override { + value(group) = true; + } + + void extractValues(folly::Range groups, const RowVectorPtr& result) + override { + raw_vector indices(pool_); + for (const auto& aggregate : aggregates_) { + // All inputs are constant, so the aggregate result is identical for every + // group that saw rows. Compute it once and broadcast. + // Check whether any group is non-empty. + bool hasNonEmpty = false; + for (vector_size_t i = 0; i < groups.size(); ++i) { + if (value(groups[i])) { + hasNonEmpty = true; + break; + } + } + + if (groups.size() < 2 || !hasNonEmpty) { + // With 0 or 1 groups no broadcasting is needed. If all groups + // are empty, there is no need to add input, so just extract directly. + if (hasNonEmpty) { + VELOX_CHECK_EQ(groups.size(), 1); + const SelectivityVector rows(1); + aggregate->function->addSingleGroupRawInput( + groups[0], rows, aggregate->constantInputs, false); + } + aggregate->function->extractValues( + groups.data(), groups.size(), &result->childAt(aggregate->output)); + } else { + // Use groups[0] as the non-empty representative and groups[1] as the + // empty representative. Extract from these two into a 2-row vector, + // then wrap the output in a dictionary that maps each group to the + // appropriate row. + const SelectivityVector rows(1); + aggregate->function->addSingleGroupRawInput( + groups[0], rows, aggregate->constantInputs, false); + + std::array representatives = {groups[0], groups[1]}; + auto extracted = + BaseVector::create(aggregate->function->resultType(), 2, pool_); + aggregate->function->extractValues( + representatives.data(), 2, &extracted); + + auto dictIndices = allocateIndices(groups.size(), pool_); + auto* rawIndices = dictIndices->asMutable(); + for (vector_size_t i = 0; i < groups.size(); ++i) { + // Index 0 = non-empty value, index 1 = empty/default value. + rawIndices[i] = value(groups[i]) ? 0 : 1; + } + result->childAt(aggregate->output) = BaseVector::wrapInDictionary( + nullptr, + std::move(dictIndices), + groups.size(), + std::move(extracted)); + } + + aggregate->function->destroy(groups); + aggregate->function->initializeNewGroups( + groups.data(), + folly::Range( + iota(groups.size(), indices), groups.size())); + } + } + + void addSingleGroupSpillInput( + char* group, + const VectorPtr& input, + vector_size_t index) override { + if (input->as>()->valueAt(index)) { + value(group) = true; + } + } + + protected: + void initializeNewGroupsInternal( + char** groups, + folly::Range indices) override { + for (auto index : indices) { + groups[index][nullByte_] |= nullMask_; + value(groups[index]) = false; + } + + for (const auto& aggregate : aggregates_) { + aggregate->function->initializeNewGroups(groups, indices); + } + } + + private: + bool& value(char* group) const { + return *reinterpret_cast(group + offset_); + } + + void extractForSpill(folly::Range groups, VectorPtr& result) const { + auto* flatResult = result->asFlatVector(); + flatResult->resize(groups.size()); + for (auto i = 0; i < groups.size(); ++i) { + flatResult->set(i, value(groups[i])); + } + } + + memory::MemoryPool* const pool_; + const std::vector aggregates_; +}; + template < typename T, typename AccumulatorType = aggregate::prestosql::SetAccumulator> @@ -28,35 +174,37 @@ class TypedDistinctAggregations : public DistinctAggregations { TypedDistinctAggregations( std::vector aggregates, const RowTypePtr& inputType, + std::vector nonConstantInputs, memory::MemoryPool* pool) : pool_{pool}, aggregates_{std::move(aggregates)}, - inputs_{aggregates_[0]->inputs}, - inputType_(TypedDistinctAggregations::makeInputTypeForAccumulator( - inputType, - inputs_)) {} + nonConstantInputs_{std::move(nonConstantInputs)}, + inputType_(makeInputTypeForAccumulator(inputType, nonConstantInputs_)), + spillType_(ARRAY(inputType_)), + singleNonConstantInput_(nonConstantInputs_.size() == 1) {} /// Returns metadata about the accumulator used to store unique inputs. Accumulator accumulator() const override { - return { - false, // isFixedSize - sizeof(AccumulatorType), - false, // usesExternalMemory - 1, // alignment - nullptr, - [](folly::Range /*groups*/, VectorPtr& /*result*/) { - VELOX_UNREACHABLE(); - }, - [this](folly::Range groups) { - for (auto* group : groups) { - if (!isInitialized(group)) { - continue; - } - auto* accumulator = - reinterpret_cast(group + offset_); - accumulator->free(*allocator_); - } - }}; + return {/*isFixedSize=*/false, + sizeof(AccumulatorType), + /*usesExternalMemory=*/false, + /*alignment=*/1, + spillType_, + /*spillExtractFunction=*/ + [this](folly::Range groups, VectorPtr& result) { + extractForSpill(groups, result); + }, + /*destroyFunction=*/ + [this](folly::Range groups) { + for (auto* group : groups) { + if (!isInitialized(group)) { + continue; + } + auto* accumulator = + reinterpret_cast(group + offset_); + accumulator->free(*allocator_); + } + }}; } void addInput( @@ -99,12 +247,18 @@ class TypedDistinctAggregations : public DistinctAggregations { const auto& aggregate = *aggregates_[i]; // For each group, add distinct inputs to aggregate. + VectorPtr data; for (auto* group : groups) { auto* accumulator = reinterpret_cast(group + offset_); // TODO Process group rows in batches to avoid creating very large input // vectors. - auto data = BaseVector::create(inputType_, accumulator->size(), pool_); + if (!data) { + data = BaseVector::create(inputType_, accumulator->size(), pool_); + } else { + BaseVector::prepareForReuse(data, accumulator->size()); + } + if constexpr (std::is_same_v) { accumulator->extractValues(*data, 0); } else { @@ -114,7 +268,7 @@ class TypedDistinctAggregations : public DistinctAggregations { if (data->size() > 0) { rows.resize(data->size()); std::vector inputForAggregation = - makeInputForAggregation(data); + makeInputForAggregation(data, aggregate); aggregate.function->addSingleGroupRawInput( group, rows, inputForAggregation, false); } @@ -137,13 +291,25 @@ class TypedDistinctAggregations : public DistinctAggregations { } } + void addSingleGroupSpillInput( + char* group, + const VectorPtr& input, + vector_size_t index) override { + auto* elementArray = input->asChecked(); + decodedInput_.decode(*elementArray->elements()); + + auto* accumulator = reinterpret_cast(group + offset_); + RowSizeTracker tracker(group[rowSizeOffset_], *allocator_); + accumulator->addValues(*elementArray, index, decodedInput_, allocator_); + } + protected: void initializeNewGroupsInternal( char** groups, folly::Range indices) override { - for (auto i : indices) { - groups[i][nullByte_] |= nullMask_; - new (groups[i] + offset_) AccumulatorType(inputType_, allocator_); + for (auto index : indices) { + groups[index][nullByte_] |= nullMask_; + new (groups[index] + offset_) AccumulatorType(inputType_, allocator_); } for (auto i = 0; i < aggregates_.size(); ++i) { @@ -153,10 +319,6 @@ class TypedDistinctAggregations : public DistinctAggregations { } private: - bool isSingleInputAggregate() const { - return aggregates_[0]->inputs.size() == 1; - } - void decodeInput(const RowVectorPtr& input, const SelectivityVector& rows) { inputForAccumulator_ = makeInputForAccumulator(input); decodedInput_.decode(*inputForAccumulator_, rows); @@ -164,45 +326,108 @@ class TypedDistinctAggregations : public DistinctAggregations { static TypePtr makeInputTypeForAccumulator( const RowTypePtr& rowType, - const std::vector& inputs) { - if (inputs.size() == 1) { - return rowType->childAt(inputs[0]); + const std::vector& inputChannels) { + const auto numInputChannels = inputChannels.size(); + if (numInputChannels == 1) { + return rowType->childAt(inputChannels[0]); } // Otherwise, synthesize a ROW(distinct_channels[0..N]) std::vector types; + types.reserve(numInputChannels); std::vector names; - for (column_index_t channelIndex : inputs) { - names.emplace_back(rowType->nameOf(channelIndex)); - types.emplace_back(rowType->childAt(channelIndex)); + names.reserve(numInputChannels); + for (column_index_t inputChannel : inputChannels) { + names.emplace_back(rowType->nameOf(inputChannel)); + types.emplace_back(rowType->childAt(inputChannel)); } return ROW(std::move(names), std::move(types)); } VectorPtr makeInputForAccumulator(const RowVectorPtr& input) const { - if (isSingleInputAggregate()) { - return input->childAt(inputs_[0]); + if (singleNonConstantInput_) { + return input->childAt(nonConstantInputs_[0]); } - std::vector newChildren(inputs_.size()); - for (int i = 0; i < inputs_.size(); ++i) { - newChildren[i] = input->childAt(inputs_[i]); + std::vector newChildren(nonConstantInputs_.size()); + for (size_t i = 0; i < nonConstantInputs_.size(); ++i) { + newChildren[i] = input->childAt(nonConstantInputs_[i]); } return std::make_shared( pool_, inputType_, nullptr, input->size(), newChildren); } - std::vector makeInputForAggregation(const VectorPtr& input) const { - if (isSingleInputAggregate()) { - return {std::move(input)}; + /// Build the full input vector list for the aggregate function from the + /// extracted distinct values, splicing constant inputs back in at the + /// correct positions. + std::vector makeInputForAggregation( + VectorPtr& data, + const AggregateInfo& aggregate) const { + const auto& inputs = aggregate.inputs; + const auto& constants = aggregate.constantInputs; + std::vector result(inputs.size()); + + std::vector distinctColumns; + if (singleNonConstantInput_) { + distinctColumns.push_back(data); + } else { + distinctColumns = data->template asUnchecked()->children(); + } + + size_t nonConstantIndex = 0; + for (size_t i = 0; i < inputs.size(); ++i) { + if (inputs[i] == kConstantChannel) { + result[i] = constants[i]; + } else { + result[i] = distinctColumns[nonConstantIndex++]; + } + } + VELOX_DCHECK_EQ(nonConstantIndex, distinctColumns.size()); + return result; + } + + void extractForSpill(folly::Range groups, VectorPtr& result) const { + auto* arrayVector = result->asChecked(); + arrayVector->resize(groups.size()); + + auto* rawOffsets = arrayVector->offsets()->asMutable(); + auto* rawSizes = arrayVector->sizes()->asMutable(); + + vector_size_t offset = 0; + for (auto i = 0; i < groups.size(); ++i) { + auto* accumulator = + reinterpret_cast(groups[i] + offset_); + + const auto numDistinct = accumulator->size(); + VELOX_DCHECK_GT(numDistinct, 0); + + rawSizes[i] = numDistinct; + rawOffsets[i] = offset; + + offset += numDistinct; + } + + auto& elementsVector = arrayVector->elements(); + elementsVector->resize(offset); + + offset = 0; + for (const auto group : groups) { + auto* accumulator = reinterpret_cast(group + offset_); + if constexpr (std::is_same_v) { + offset += accumulator->extractValues(*elementsVector, offset); + } else { + offset += accumulator->extractValues( + *(elementsVector->template as>()), offset); + } } - return input->template asUnchecked()->children(); } memory::MemoryPool* const pool_; const std::vector aggregates_; - const std::vector inputs_; + const std::vector nonConstantInputs_; const TypePtr inputType_; + const TypePtr spillType_; + const bool singleNonConstantInput_; DecodedVector decodedInput_; VectorPtr inputForAccumulator_; @@ -211,13 +436,14 @@ class TypedDistinctAggregations : public DistinctAggregations { template std::unique_ptr createDistinctAggregationsWithCustomCompare( - std::vector aggregates, + const std::vector& aggregates, const RowTypePtr& inputType, + std::vector nonConstantInputs, memory::MemoryPool* pool) { return std::make_unique::NativeType, aggregate::prestosql::CustomComparisonSetAccumulator>>( - aggregates, inputType, pool); + aggregates, inputType, std::move(nonConstantInputs), pool); } } // namespace @@ -229,13 +455,27 @@ std::unique_ptr DistinctAggregations::create( VELOX_CHECK_EQ(aggregates.size(), 1); VELOX_CHECK(!aggregates[0]->inputs.empty()); - const bool isSingleInput = aggregates[0]->inputs.size() == 1; - if (!isSingleInput) { + // Collect non-constant input channels to determine the type for the + // set accumulator. Constant inputs are not deduplicated — they are + // spliced back during extraction. + std::vector nonConstantInputs; + for (auto i = 0; i < aggregates[0]->inputs.size(); ++i) { + if (aggregates[0]->inputs[i] != kConstantChannel) { + nonConstantInputs.push_back(aggregates[0]->inputs[i]); + } + } + + if (nonConstantInputs.empty()) { + return std::make_unique( + std::move(aggregates), pool); + } + + if (nonConstantInputs.size() > 1) { return std::make_unique>( - aggregates, inputType, pool); + std::move(aggregates), inputType, std::move(nonConstantInputs), pool); } - const auto type = inputType->childAt(aggregates[0]->inputs[0]); + const auto type = inputType->childAt(nonConstantInputs[0]); if (type->providesCustomComparison()) { return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( @@ -243,50 +483,51 @@ std::unique_ptr DistinctAggregations::create( type->kind(), aggregates, inputType, + std::move(nonConstantInputs), pool); } switch (type->kind()) { case TypeKind::BOOLEAN: return std::make_unique>( - aggregates, inputType, pool); + aggregates, inputType, std::move(nonConstantInputs), pool); case TypeKind::TINYINT: return std::make_unique>( - aggregates, inputType, pool); + aggregates, inputType, std::move(nonConstantInputs), pool); case TypeKind::SMALLINT: return std::make_unique>( - aggregates, inputType, pool); + aggregates, inputType, std::move(nonConstantInputs), pool); case TypeKind::INTEGER: return std::make_unique>( - aggregates, inputType, pool); + aggregates, inputType, std::move(nonConstantInputs), pool); case TypeKind::BIGINT: return std::make_unique>( - aggregates, inputType, pool); + aggregates, inputType, std::move(nonConstantInputs), pool); case TypeKind::HUGEINT: return std::make_unique>( - aggregates, inputType, pool); + aggregates, inputType, std::move(nonConstantInputs), pool); case TypeKind::REAL: return std::make_unique>( - aggregates, inputType, pool); + aggregates, inputType, std::move(nonConstantInputs), pool); case TypeKind::DOUBLE: return std::make_unique>( - aggregates, inputType, pool); + aggregates, inputType, std::move(nonConstantInputs), pool); case TypeKind::TIMESTAMP: return std::make_unique>( - aggregates, inputType, pool); + aggregates, inputType, std::move(nonConstantInputs), pool); case TypeKind::VARBINARY: [[fallthrough]]; case TypeKind::VARCHAR: return std::make_unique>( - aggregates, inputType, pool); + aggregates, inputType, std::move(nonConstantInputs), pool); case TypeKind::ARRAY: case TypeKind::MAP: case TypeKind::ROW: return std::make_unique>( - aggregates, inputType, pool); + aggregates, inputType, std::move(nonConstantInputs), pool); case TypeKind::UNKNOWN: return std::make_unique>( - aggregates, inputType, pool); + aggregates, inputType, std::move(nonConstantInputs), pool); default: VELOX_UNREACHABLE("Unexpected type {}", type->toString()); } diff --git a/velox/exec/DistinctAggregations.h b/velox/exec/DistinctAggregations.h index fdd13a89fd0..0991b1a5e6f 100644 --- a/velox/exec/DistinctAggregations.h +++ b/velox/exec/DistinctAggregations.h @@ -91,6 +91,15 @@ class DistinctAggregations { folly::Range groups, const RowVectorPtr& result) = 0; + /// Update the single accumulator using previously spilled data. + /// @param group Pointer to the start of the group row. + /// @param input Restored spill data to be added to the accumulator. + /// @param index The index indicating which row in `input` is being added. + virtual void addSingleGroupSpillInput( + char* group, + const VectorPtr& input, + vector_size_t index) = 0; + protected: // Initializes null flags and accumulators for newly encountered groups. This // function should be called only once for each group. diff --git a/velox/exec/Driver.cpp b/velox/exec/Driver.cpp index 202553ef41b..1514f9a8516 100644 --- a/velox/exec/Driver.cpp +++ b/velox/exec/Driver.cpp @@ -16,15 +16,30 @@ #include "velox/exec/Driver.h" +#include + #include "velox/common/process/TraceContext.h" +#include "velox/exec/Operator.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/Task.h" #include "velox/vector/LazyVector.h" using facebook::velox::common::testutil::TestValue; namespace facebook::velox::exec { + +Driver::~Driver() = default; + namespace { +/// Returns current time in microseconds using high_resolution_clock. +/// Used for driver-level lifecycle timing to match BlockingState::sinceUs_. +inline uint64_t currentTimeMicrosHires() { + return std::chrono::duration_cast( + std::chrono::high_resolution_clock::now().time_since_epoch()) + .count(); +} + // Checks if output channel is produced using identity projection and returns // input channel if so. std::optional getIdentityProjection( @@ -44,7 +59,7 @@ void recordSilentThrows(Operator& op) { auto numThrow = threadNumVeloxThrow(); if (numThrow > 0) { op.stats().wlock()->addRuntimeStat( - "numSilentThrow", RuntimeCounter(numThrow)); + std::string(DriverStats::kNumSilentThrow), RuntimeCounter(numThrow)); } } @@ -96,8 +111,8 @@ const core::QueryConfig& DriverCtx::queryConfig() const { return task->queryCtx()->queryConfig(); } -const std::optional& DriverCtx::traceConfig() const { - return task->traceConfig(); +const trace::TraceCtx* DriverCtx::traceCtx() const { + return task->traceCtx(); } velox::memory::MemoryPool* DriverCtx::addOperatorPool( @@ -107,8 +122,25 @@ velox::memory::MemoryPool* DriverCtx::addOperatorPool( planNodeId, splitGroupId, pipelineId, driverId, operatorType); } +namespace { +bool isHashJoinSpillOperator(std::string_view operatorType) { + return operatorType == OperatorType::kHashBuild || + operatorType == OperatorType::kHashProbe; +} + +bool isAggregationSpillOperator(std::string_view operatorType) { + return operatorType == OperatorType::kAggregation || + operatorType == OperatorType::kPartialAggregation; +} + +bool isRowNumberSpillOperator(std::string_view operatorType) { + return operatorType == OperatorType::kRowNumber; +} +} // namespace + std::optional DriverCtx::makeSpillConfig( - int32_t operatorId) const { + int32_t operatorId, + std::string_view operatorType) const { const auto& queryConfig = task->queryCtx()->queryConfig(); if (!queryConfig.spillEnabled()) { return std::nullopt; @@ -126,6 +158,26 @@ std::optional DriverCtx::makeSpillConfig( [this](uint64_t bytes) { task->queryCtx()->updateSpilledBytesAndCheckLimit(bytes); }; + + std::string fileCreateConfig = queryConfig.spillFileCreateConfig(); + if (isHashJoinSpillOperator(operatorType)) { + const auto& hashJoinConfig = queryConfig.hashJoinSpillFileCreateConfig(); + if (!hashJoinConfig.empty()) { + fileCreateConfig = hashJoinConfig; + } + } else if (isAggregationSpillOperator(operatorType)) { + const auto& aggregationConfig = + queryConfig.aggregationSpillFileCreateConfig(); + if (!aggregationConfig.empty()) { + fileCreateConfig = aggregationConfig; + } + } else if (isRowNumberSpillOperator(operatorType)) { + const auto& rowNumberConfig = queryConfig.rowNumberSpillFileCreateConfig(); + if (!rowNumberConfig.empty()) { + fileCreateConfig = rowNumberConfig; + } + } + return common::SpillConfig( std::move(getSpillDirPathCb), std::move(updateAndCheckSpillLimitCb), @@ -142,10 +194,12 @@ std::optional DriverCtx::makeSpillConfig( queryConfig.maxSpillRunRows(), queryConfig.writerFlushThresholdBytes(), queryConfig.spillCompressionKind(), + queryConfig.spillNumMaxMergeFiles(), queryConfig.spillPrefixSortEnabled() ? std::optional(prefixSortConfig()) : std::nullopt, - queryConfig.spillFileCreateConfig()); + fileCreateConfig, + queryConfig.windowSpillMinReadBatchRows()); } std::atomic_uint64_t BlockingState::numBlockedDrivers_{0}; @@ -159,11 +213,13 @@ BlockingState::BlockingState( future_(std::move(future)), operator_(op), reason_(reason), - sinceUs_(std::chrono::duration_cast( - std::chrono::high_resolution_clock::now().time_since_epoch()) - .count()) { + sinceUs_( + std::chrono::duration_cast( + std::chrono::high_resolution_clock::now().time_since_epoch()) + .count()) { // Set before leaving the thread. driver_->state().hasBlockingFuture = true; + driver_->state().blockingStartUs = sinceUs_; numBlockedDrivers_++; } @@ -180,6 +236,10 @@ void BlockingState::setResume(std::shared_ptr state) { std::lock_guard l(task->mutex()); if (!driver->state().isTerminated) { state->operator_->recordBlockingTime(state->sinceUs_, state->reason_); + // Accumulate driver-level blocked time using high_resolution_clock, + // matching sinceUs_ and all other driver lifecycle timing. + driver->addDriverBlockedTime( + (currentTimeMicrosHires() - state->sinceUs_) * 1'000); } VELOX_CHECK(!driver->state().suspended()); VELOX_CHECK(driver->state().hasBlockingFuture); @@ -278,6 +338,7 @@ RowVectorPtr Driver::next( ScopedDriverThreadContext scopedDriverThreadContext(self->driverCtx()); std::shared_ptr blockingState; RowVectorPtr result; + const auto stop = runInternal(self, blockingState, result); if (blockingState != nullptr) { @@ -309,7 +370,7 @@ void Driver::enqueueInternal() { VELOX_CHECK(!state_.isEnqueued); state_.isEnqueued = true; // When enqueuing, starting timing the queue time. - queueTimeStartUs_ = getCurrentTimeMicro(); + queueTimeStartUs_ = currentTimeMicrosHires(); } // Call an Operator method. record silenced throws, but not a query @@ -337,13 +398,13 @@ void Driver::enqueueInternal() { } void OpCallStatus::start(int32_t operatorId, const char* operatorMethod) { - timeStartMs = getCurrentTimeMs(); - opId = operatorId; - method = operatorMethod; + timeStartMs_ = getCurrentTimeMs(); + opId_ = operatorId; + method_ = operatorMethod; } void OpCallStatus::stop() { - timeStartMs = 0; + timeStartMs_ = 0; } size_t OpCallStatusRaw::callDuration() const { @@ -365,10 +426,14 @@ CpuWallTiming Driver::processLazyIoStats( if (&op == operators_[0].get()) { return timing; } + static const std::string kCpuNanosKey(LazyVector::kCpuNanos); + static const std::string kWallNanosKey(LazyVector::kWallNanos); + static const std::string kInputBytesKey(LazyVector::kInputBytes); + auto lockStats = op.stats().wlock(); // Checks and tries to update cpu time from lazy loads. - auto it = lockStats->runtimeStats.find(LazyVector::kCpuNanos); + auto it = lockStats->runtimeStats.find(kCpuNanosKey); if (it == lockStats->runtimeStats.end()) { // Return early if no lazy activity. Lazy CPU and wall times are recorded // together, checking one is enough. @@ -386,7 +451,7 @@ CpuWallTiming Driver::processLazyIoStats( // Checks and tries to update wall time from lazy loads. int64_t wallDelta = 0; - it = lockStats->runtimeStats.find(LazyVector::kWallNanos); + it = lockStats->runtimeStats.find(kWallNanosKey); if (it != lockStats->runtimeStats.end()) { const int64_t wall = it->second.sum; wallDelta = std::max(0, wall - lockStats->lastLazyWallNanos); @@ -397,7 +462,7 @@ CpuWallTiming Driver::processLazyIoStats( // Checks and tries to update input bytes from lazy loads. int64_t inputBytesDelta = 0; - it = lockStats->runtimeStats.find(LazyVector::kInputBytes); + it = lockStats->runtimeStats.find(kInputBytesKey); if (it != lockStats->runtimeStats.end()) { const int64_t inputBytes = it->second.sum; inputBytesDelta = inputBytes - lockStats->lastLazyInputBytes; @@ -410,11 +475,12 @@ CpuWallTiming Driver::processLazyIoStats( cpuDelta = std::min(cpuDelta, timing.cpuNanos); wallDelta = std::min(wallDelta, timing.wallNanos); lockStats = operators_[0]->stats().wlock(); - lockStats->getOutputTiming.add(CpuWallTiming{ - 1, - static_cast(wallDelta), - static_cast(cpuDelta), - }); + lockStats->getOutputTiming.add( + CpuWallTiming{ + 1, + static_cast(wallDelta), + static_cast(cpuDelta), + }); lockStats->inputBytes += inputBytesDelta; lockStats->outputBytes += inputBytesDelta; return CpuWallTiming{ @@ -454,8 +520,24 @@ StopReason Driver::runInternal( std::shared_ptr& self, std::shared_ptr& blockingState, RowVectorPtr& result) { - const auto now = getCurrentTimeMicro(); + // All driver timing uses high_resolution_clock consistently + // (matching BlockingState::sinceUs_ used for blocked time). + const auto now = currentTimeMicrosHires(); const auto queuedTimeUs = now - queueTimeStartUs_; + + totalDriverQueuedNanos_ += queuedTimeUs * 1'000; + onThreadStartUs_ = now; + // For the normal close path, closeOperators() finalizes and clears + // onThreadStartUs_ before reporting. This guard handles early returns + // (e.g. Task::enter() failure) and non-close exit paths. + auto onThreadTimeGuard = folly::makeGuard([this]() { + if (onThreadStartUs_ > 0) { + totalDriverOnThreadNanos_ += + (currentTimeMicrosHires() - onThreadStartUs_) * 1'000; + onThreadStartUs_ = 0; + } + }); + // Update the next operator's queueTime. StopReason stop = closed_ ? StopReason::kTerminate : task()->enter(state_, now); @@ -474,7 +556,7 @@ StopReason Driver::runInternal( // been deleted. if (curOperatorId_ < operators_.size()) { operators_[curOperatorId_]->addRuntimeStat( - "queuedWallNanos", + std::string(DriverStats::kQueuedWallNanos), RuntimeCounter(queuedTimeUs * 1'000, RuntimeCounter::Unit::kNanos)); RECORD_HISTOGRAM_METRIC_VALUE( kMetricDriverQueueTimeMs, queuedTimeUs / 1'000); @@ -492,14 +574,27 @@ StopReason Driver::runInternal( try { // Invoked to initialize the operators once before driver starts execution. initializeOperators(); + int32_t startingOperator = getStartingOperator(); TestValue::adjust("facebook::velox::exec::Driver::runInternal", this); - const int32_t numOperators = operators_.size(); + // If the driver is coming back from a trace interruption, feed the + // intermediate result into the next operator, then resume execution at the + // exact operator where the trace was interrupted. + if (traceInput_ != nullptr) { + Operator* tracedOp = operators_[startingOperator].get(); + CALL_OPERATOR( + addInput(tracedOp, traceInput_), + tracedOp, + startingOperator, + kOpMethodAddInput); + traceInput_ = nullptr; + } + ContinueFuture future = ContinueFuture::makeEmpty(); for (;;) { - for (int32_t i = numOperators - 1; i >= 0; --i) { + for (int32_t i = startingOperator; i >= 0; --i) { stop = task()->shouldStop(); if (stop != StopReason::kNone) { guard.notThrown(); @@ -544,7 +639,7 @@ StopReason Driver::runInternal( return blockDriver(self, i, std::move(future), blockingState, guard); } - if (i < numOperators - 1) { + if (i < operators_.size() - 1) { Operator* nextOp = operators_[i + 1].get(); withDeltaCpuWallTimer(nextOp, &OperatorStats::isBlockedTiming, [&]() { @@ -565,9 +660,11 @@ StopReason Driver::runInternal( nextOp, curOperatorId_ + 1, kOpMethodNeedsInput); + if (needsInput) { uint64_t resultBytes = 0; RowVectorPtr intermediateResult; + withDeltaCpuWallTimer(op, &OperatorStats::getOutputTiming, [&]() { TestValue::adjust( "facebook::velox::exec::Driver::runInternal::getOutput", op); @@ -587,6 +684,16 @@ StopReason Driver::runInternal( } }); if (intermediateResult) { + const bool block = + nextOp->traceInput(intermediateResult, &future); + + if (block) { + blockingReason_ = BlockingReason::kWaitForConsumer; + traceInput_ = intermediateResult; + return blockDriver( + self, i + 1, std::move(future), blockingState, guard); + } + withDeltaCpuWallTimer( nextOp, &OperatorStats::addInputTiming, [&]() { { @@ -594,7 +701,7 @@ StopReason Driver::runInternal( lockedStats->addInputVector( resultBytes, intermediateResult->size()); } - nextOp->traceInput(intermediateResult); + TestValue::adjust( "facebook::velox::exec::Driver::runInternal::addInput", nextOp); @@ -702,7 +809,7 @@ StopReason Driver::runInternal( } } catch (velox::VeloxException&) { task()->setError(std::current_exception()); - // The CancelPoolGuard will close 'self' and remove from Task. + // The CancelGuard will close 'self' and remove from Task. return StopReason::kAlreadyTerminated; } catch (std::exception&) { task()->setError(std::current_exception()); @@ -794,10 +901,40 @@ void Driver::closeOperators() { op->close(); } + // Report driver-level lifecycle timing to the Task accumulator. + // Use partitionId (0..numDrivers-1) so same-index drivers across split + // groups in grouped execution are summed together. + // Finalize on-thread time here (the onThreadTimeGuard in runInternal + // hasn't fired yet since CancelGuard destructs before it). + if (onThreadStartUs_ > 0) { + totalDriverOnThreadNanos_ += + (currentTimeMicrosHires() - onThreadStartUs_) * 1'000; + onThreadStartUs_ = 0; // Prevent double-counting in the guard. + } + task()->addDriverLifecycleStats( + static_cast(ctx_->pipelineId), + ctx_->partitionId, + totalDriverQueuedNanos_, + totalDriverOnThreadNanos_, + totalDriverBlockedNanos_); + // Add operator stats to the task. for (auto& op : operators_) { auto stats = op->stats(true); stats.numDrivers = 1; + + // Calculate this driver's CPU time for this specific operator and add it as + // a runtime stat. This will be aggregated across all drivers, with the max + // field containing the CPU time from the longest running driver. + uint64_t operatorCpuNanos = stats.addInputTiming.cpuNanos + + stats.getOutputTiming.cpuNanos + stats.finishTiming.cpuNanos + + stats.isBlockedTiming.cpuNanos; + + if (operatorCpuNanos > 0) { + stats.runtimeStats[std::string(OperatorStats::kDriverCpuTime)] = + RuntimeMetric(operatorCpuNanos, RuntimeCounter::Unit::kNanos); + } + task()->addOperatorStats(stats); } } @@ -805,22 +942,40 @@ void Driver::closeOperators() { void Driver::updateStats() { DriverStats stats; if (state_.totalPauseTimeMs > 0) { - stats.runtimeStats[DriverStats::kTotalPauseTime] = RuntimeMetric( - 1'000'000 * state_.totalPauseTimeMs, RuntimeCounter::Unit::kNanos); + stats.runtimeStats[std::string(DriverStats::kTotalPauseTime)] = + RuntimeMetric( + 1'000'000 * state_.totalPauseTimeMs, RuntimeCounter::Unit::kNanos); } if (state_.totalOffThreadTimeMs > 0) { - stats.runtimeStats[DriverStats::kTotalOffThreadTime] = RuntimeMetric( - 1'000'000 * state_.totalOffThreadTimeMs, RuntimeCounter::Unit::kNanos); + stats.runtimeStats[std::string(DriverStats::kTotalOffThreadTime)] = + RuntimeMetric( + 1'000'000 * state_.totalOffThreadTimeMs, + RuntimeCounter::Unit::kNanos); } + task()->addDriverStats(ctx_->pipelineId, std::move(stats)); } +void Driver::updateOperatorBlockingStats() { + // Record blocked time if the driver was blocked when terminated. + // This ensures we don't lose blocked time metrics when a query is aborted. + if (state_.hasBlockingFuture) { + // Accumulate driver-level blocked time unconditionally. + totalDriverBlockedNanos_ += + (currentTimeMicrosHires() - state_.blockingStartUs) * 1'000; + // Record per-operator blocked time if operator is available. + if (blockedOperatorId_ < operators_.size()) { + operators_[blockedOperatorId_]->recordBlockingTime( + state_.blockingStartUs, blockingReason_); + } + } +} + void Driver::startBarrier() { VELOX_CHECK(ctx_->task->underBarrier()); VELOX_CHECK( - !barrier_.has_value(), - "The driver has already started barrier processing"); - barrier_ = BarrierState{}; + !hasBarrier(), "The driver has already started barrier processing"); + barrier_.start(); } void Driver::drainOutput() { @@ -828,38 +983,38 @@ void Driver::drainOutput() { hasBarrier(), "Can't drain a driver not under barrier processing"); VELOX_CHECK(!isDraining(), "The driver is already draining"); // Starts to drain from the source operator. - barrier_->drainingOpId = 0; + barrier_.drainingOpId = 0; drainNextOperator(); } bool Driver::isDraining() const { - return hasBarrier() && barrier_->drainingOpId.has_value(); + return hasBarrier() && barrier_.drainingOpId.has_value(); } bool Driver::isDraining(int32_t operatorId) const { - return isDraining() && operatorId == barrier_->drainingOpId; + return isDraining() && operatorId == barrier_.drainingOpId; } bool Driver::hasDrained(int32_t operatorId) const { - return isDraining() && operatorId < barrier_->drainingOpId; + return isDraining() && operatorId < barrier_.drainingOpId; } void Driver::finishDrain(int32_t operatorId) { VELOX_CHECK(isDraining()); - VELOX_CHECK_EQ(barrier_->drainingOpId.value(), operatorId); - barrier_->drainingOpId = barrier_->drainingOpId.value() + 1; + VELOX_CHECK_EQ(barrier_.drainingOpId.value(), operatorId); + barrier_.drainingOpId = barrier_.drainingOpId.value() + 1; drainNextOperator(); } void Driver::drainNextOperator() { VELOX_CHECK(isDraining()); - for (; barrier_->drainingOpId < operators_.size(); - barrier_->drainingOpId = barrier_->drainingOpId.value() + 1) { - if (operators_[barrier_->drainingOpId.value()]->startDrain()) { + for (; barrier_.drainingOpId < operators_.size(); + barrier_.drainingOpId = barrier_.drainingOpId.value() + 1) { + if (operators_[barrier_.drainingOpId.value()]->startDrain()) { break; } } - if (barrier_->drainingOpId == operators_.size()) { + if (barrier_.drainingOpId == operators_.size()) { finishBarrier(); } } @@ -870,21 +1025,37 @@ void Driver::dropInput(int32_t operatorId) { return; } VELOX_CHECK_LT(operatorId, operators_.size()); - if (!barrier_->dropInputOpId.has_value()) { - barrier_->dropInputOpId = operatorId; - } else { - barrier_->dropInputOpId = std::max(*barrier_->dropInputOpId, operatorId); - } + // dropInput() is only called from operators within this driver's pipeline + // during barrier processing. Since a driver runs on a single thread at a + // time, we don't need compare-and-swap here. We simply keep the maximum + // operator id - all operators upstream (with smaller ids) will drop their + // output. + barrier_.dropInputOpId = std::max( + barrier_.dropInputOpId.load(std::memory_order_relaxed), operatorId); } bool Driver::shouldDropOutput(int32_t operatorId) const { - return hasBarrier() && barrier_->dropInputOpId.has_value() && - operatorId < *barrier_->dropInputOpId; + const int32_t dropOpId = + barrier_.dropInputOpId.load(std::memory_order_acquire); + return hasBarrier() && dropOpId != BarrierState::kNoDropInput && + operatorId < dropOpId; +} + +void Driver::BarrierState::start() { + VELOX_CHECK(!active.load(std::memory_order_acquire)); + active.store(true, std::memory_order_release); +} + +void Driver::BarrierState::reset() { + VELOX_CHECK(active.load(std::memory_order_acquire)); + active.store(false, std::memory_order_release); + drainingOpId = std::nullopt; + dropInputOpId.store(kNoDropInput, std::memory_order_relaxed); } void Driver::finishBarrier() { VELOX_CHECK(isDraining()); - VELOX_CHECK_EQ(barrier_->drainingOpId.value(), operators_.size()); + VELOX_CHECK_EQ(barrier_.drainingOpId.value(), operators_.size()); barrier_.reset(); ctx_->task->finishDriverBarrier(); } @@ -906,6 +1077,7 @@ void Driver::close() { void Driver::closeByTask() { VELOX_CHECK(isOnThread()); VELOX_CHECK(isTerminated()); + updateOperatorBlockingStats(); closeOperators(); updateStats(); closed_ = true; @@ -1035,11 +1207,13 @@ int Driver::pushdownFilters( operators_[j]->addDynamicFilterLocked(filterSource->planNodeId(), *lk); } operators_[j]->addRuntimeStat( - "dynamicFiltersAccepted", RuntimeCounter(numFiltersAccepted[j])); + std::string(DriverStats::kDynamicFiltersAccepted), + RuntimeCounter(numFiltersAccepted[j])); } if (numFiltersProduced > 0) { filterSource->addRuntimeStat( - "dynamicFiltersProduced", RuntimeCounter(numFiltersProduced)); + std::string(DriverStats::kDynamicFiltersProduced), + RuntimeCounter(numFiltersProduced)); } return numFiltersProduced; } @@ -1102,9 +1276,13 @@ std::string Driver::toString() const { } out << "{Operators: "; - for (auto& op : operators_) { - out << op->toString() << ", "; - } + std::vector opStrs; + opStrs.reserve(operators_.size()); + std::ranges::transform( + operators_, std::back_inserter(opStrs), [](const auto& op) { + return op->toString(); + }); + out << folly::join(", ", opStrs); out << "}"; const auto ocs = opCallStatus(); if (!ocs.empty()) { @@ -1220,6 +1398,14 @@ StopReason Driver::blockDriver( return StopReason::kBlock; } +int32_t Driver::getStartingOperator() const { + if (traceInput_ != nullptr) { + return blockedOperatorId_; + } + // By default, start at the last (the consumer). + return operators_.size() - 1; +} + std::string Driver::label() const { return fmt::format("", task()->taskId(), ctx_->driverId); } diff --git a/velox/exec/Driver.h b/velox/exec/Driver.h index cdcccbf64aa..2407e7240cc 100644 --- a/velox/exec/Driver.h +++ b/velox/exec/Driver.h @@ -16,18 +16,21 @@ #pragma once +#include #include +#include #include #include #include #include "velox/common/base/Counters.h" +#include "velox/common/base/Portability.h" #include "velox/common/base/StatsReporter.h" -#include "velox/common/base/TraceConfig.h" #include "velox/common/time/CpuWallTimer.h" #include "velox/core/PlanFragment.h" #include "velox/exec/BlockingReason.h" +#include "velox/exec/trace/TraceCtx.h" namespace facebook::velox::exec { @@ -86,7 +89,7 @@ std::ostream& operator<<(std::ostream& out, const StopReason& reason); /// Terminated - 'isTerminated' is set. The Driver cannot run after this and /// the state is final. /// -/// CancelPool allows terminating or pausing a set of Drivers. The Task API +/// Task allows terminating or pausing a set of Drivers. The Task API /// allows starting or resuming Drivers. When terminate is requested the request /// is successful when all Drivers are off thread, blocked or suspended. When /// pause is requested, we have success when all Drivers are either enqueued, @@ -102,7 +105,10 @@ struct ThreadState { std::atomic isTerminated{false}; /// True if there is a future outstanding that will schedule this on an /// executor thread when some promise is realized. - bool hasBlockingFuture{false}; + tsan_atomic hasBlockingFuture{false}; + /// Timestamp in microseconds when the driver became blocked. Used to record + /// blocked time when the driver is terminated while still blocked. + tsan_atomic blockingStartUs{0}; /// The number of suspension requests on a on-thread driver. If > 0, this /// driver thread is in a (recursive) section waiting for RPC or memory /// strategy decision. The thread is not supposed to access its memory, which @@ -167,7 +173,8 @@ struct ThreadState { obj["tid"] = tid.load(); obj["isTerminated"] = isTerminated.load(); obj["isEnqueued"] = isEnqueued.load(); - obj["hasBlockingFuture"] = hasBlockingFuture; + obj["hasBlockingFuture"] = tsanAtomicValue(hasBlockingFuture); + obj["blockingStartUs"] = tsanAtomicValue(blockingStartUs); obj["isSuspended"] = suspended(); obj["startExecTime"] = startExecTimeMs; return obj; @@ -224,9 +231,11 @@ constexpr uint32_t kUngroupedGroupId{std::numeric_limits::max()}; struct DriverCtx { const int driverId; const int pipelineId; + /// Id of the split group this driver should process in case of grouped /// execution, kUngroupedGroupId otherwise. const uint32_t splitGroupId; + /// Id of the partition to use by this driver. For local exchange, for /// instance. const uint32_t partitionId; @@ -234,6 +243,7 @@ struct DriverCtx { std::shared_ptr task; Driver* driver{nullptr}; facebook::velox::process::ThreadDebugInfo threadDebugInfo; + /// Tracks the traced operator ids. It is also used to avoid tracing the /// auxiliary operator such as the aggregation operator used by the table /// writer to generate the columns stats. @@ -248,14 +258,17 @@ struct DriverCtx { const core::QueryConfig& queryConfig() const; - const std::optional& traceConfig() const; + const trace::TraceCtx* traceCtx() const; velox::memory::MemoryPool* addOperatorPool( const core::PlanNodeId& planNodeId, const std::string& operatorType); - /// Builds the spill config for the operator with specified 'operatorId'. - std::optional makeSpillConfig(int32_t operatorId) const; + /// Builds the spill config for the operator with specified 'operatorId' and + /// 'operatorType'. + std::optional makeSpillConfig( + int32_t operatorId, + std::string_view operatorType) const; common::PrefixSortConfig prefixSortConfig() const { return common::PrefixSortConfig{ @@ -273,13 +286,13 @@ constexpr const char* kOpMethodAddInput = "addInput"; constexpr const char* kOpMethodNoMoreInput = "noMoreInput"; constexpr const char* kOpMethodIsFinished = "isFinished"; -/// Same as the structure below, but does not have atomic members. -/// Used to return the status from the struct with atomics. +/// Non-atomic snapshot of OpCallStatus, used to return a consistent status from +/// OpCallStatus which uses atomic members internally. struct OpCallStatusRaw { /// Time (ms) when the operator call started. size_t timeStartMs{0}; - /// Id of the operator, method of which is currently running. It is index into - /// the vector of Driver's operators. + /// Id of the operator, method of which is currently running. It is the index + /// into the vector of Driver's operators. int32_t opId{0}; /// Method of the operator, which is currently running. const char* method{kOpMethodNone}; @@ -293,14 +306,15 @@ struct OpCallStatusRaw { }; /// Structure holds the information about the current operator call the driver -/// is in. Can be used to detect deadlocks and otherwise blocked calls. -/// If timeStartMs is zero, then we aren't in an operator call. -struct OpCallStatus { +/// is in. Can be used to detect deadlocks and otherwise blocked calls. If +/// 'timeStartMs_' is zero, then we aren't in an operator call. +class OpCallStatus { + public: OpCallStatus() {} /// The status accessor. OpCallStatusRaw operator()() const { - return OpCallStatusRaw{timeStartMs, opId, method}; + return OpCallStatusRaw{timeStartMs_, opId_, method_}; } void start(int32_t operatorId, const char* operatorMethod); @@ -308,12 +322,12 @@ struct OpCallStatus { private: /// Time (ms) when the operator call started. - std::atomic_size_t timeStartMs{0}; - /// Id of the operator, method of which is currently running. It is index into - /// the vector of Driver's operators. - std::atomic_int32_t opId{0}; + std::atomic_size_t timeStartMs_{0}; + /// Id of the operator, method of which is currently running. It is the index + /// into the vector of Driver's operators. + std::atomic_int32_t opId_{0}; /// Method of the operator, which is currently running. - std::atomic method{kOpMethodNone}; + std::atomic method_{kOpMethodNone}; }; struct PushdownFilters { @@ -338,6 +352,14 @@ using PipelinePushdownFilters = class Driver : public std::enable_shared_from_this { public: + ~Driver(); + + // Disable copy and move + Driver(const Driver&) = delete; + Driver& operator=(const Driver&) = delete; + Driver(Driver&&) = delete; + Driver& operator=(Driver&&) = delete; + static void enqueue(std::shared_ptr instance); /// Run the pipeline until it produces a batch of data or gets blocked. @@ -394,6 +416,12 @@ class Driver : public std::enable_shared_from_this { /// memory arbitration finishes. bool checkUnderArbitration(ContinueFuture* future); + /// Accumulates blocked time for driver-level lifecycle tracking. + /// Called from BlockingState::setResume() when the blocking future resolves. + void addDriverBlockedTime(uint64_t nanos) { + totalDriverBlockedNanos_ += nanos; + } + void initializeOperatorStats(std::vector& stats); /// Close operators and add operator stats to the task. @@ -481,7 +509,7 @@ class Driver : public std::enable_shared_from_this { /// Returns true if the driver is under barrier processing. bool hasBarrier() const { - return barrier_.has_value(); + return barrier_.active.load(std::memory_order_acquire); } /// Invoked to start draining the output of this driver pipeline from the @@ -579,15 +607,33 @@ class Driver : public std::enable_shared_from_this { void updateStats(); + /// Records operator blocked time ins case the driver was off the thrread and + /// blocked when terminated. + void updateOperatorBlockingStats(); + // Defines the driver barrier processing state. struct BarrierState { + // True if the driver is under barrier processing. This is set by + // startBarrier() and read by hasBarrier() from different threads. + std::atomic_bool active{false}; // If set, the driver has started output draining. It points to the operator // that is currently draining output. std::optional drainingOpId{std::nullopt}; - // If set, the specified operator doesn't need any more input to finish the - // draining operation. All the upstream operators within the same driver - // should drop their output or output processing. - std::optional dropInputOpId{std::nullopt}; + // The operator id that doesn't need any more input to finish the draining + // operation. All the upstream operators within the same driver should drop + // their output or output processing. -1 means not set. This is accessed + // from different driver threads so it needs to be atomic. + static constexpr int32_t kNoDropInput = -1; + std::atomic_int32_t dropInputOpId{kNoDropInput}; + + BarrierState() = default; + + /// Starts the barrier processing. Must be called when the barrier is not + /// active. + void start(); + + /// Resets the barrier state. Must be called when the barrier is active. + void reset(); }; // Invoked to start draining on the next operator. If there is no "next" @@ -627,6 +673,15 @@ class Driver : public std::enable_shared_from_this { std::shared_ptr& blockingState, CancelGuard& guard); + // Returns the operator to start from. The Driver always start the driver + // pipeline from the consumer (leaf), then walk backwards to the root based on + // whether they are ready to consume data, and the previous is ready to + // produce data. + // + // The only exception is when resuming from a trace, in which case the Driver + // must resume at the operator where it was blocked. + int32_t getStartingOperator() const; + std::unique_ptr ctx_; // If set, the operator output batch size stats will be collected during @@ -640,8 +695,8 @@ class Driver : public std::enable_shared_from_this { std::atomic_bool closed_{false}; - // If set, the driver is under a barrier processing. - std::optional barrier_; + // The driver barrier processing state. + BarrierState barrier_; OpCallStatus opCallStatus_; @@ -650,16 +705,44 @@ class Driver : public std::enable_shared_from_this { // Timer used to track down the time we are sitting in the driver queue. size_t queueTimeStartUs_{0}; + + // Driver-level lifecycle timing: independently tracks the three states + // a driver can be in (queued, on-thread, blocked) to enable gap analysis. + // Reported as RuntimeStats on the source operator at close time. + // All three use high_resolution_clock for consistency, matching + // BlockingState::sinceUs_ and enabling accurate gap analysis + // (queued + on-thread + blocked ≈ elapsed time). + // Atomic because closeByTask() may read these from a different thread + // than the one running the onThreadTimeGuard scope guard. + std::atomic totalDriverQueuedNanos_{0}; + std::atomic totalDriverOnThreadNanos_{0}; + std::atomic totalDriverBlockedNanos_{0}; + // Timestamp (micros, high_resolution_clock) when the current on-thread + // period started. Set at the beginning of runInternal, used by + // closeOperators to snapshot on-thread time before the scope guard fires. + // Atomic because closeByTask() may access it from a different thread + // than the one running the onThreadTimeGuard scope guard in runInternal. + std::atomic onThreadStartUs_{0}; + // Id (index in the vector) of the current operator to run (or the 1st one if // we haven't started yet). Used to determine which operator's queueTime we // should update. size_t curOperatorId_{0}; - std::vector> operators_; + std::vector> operators_; // NOLINT - BlockingReason blockingReason_{BlockingReason::kNotBlocked}; + tsan_atomic blockingReason_{BlockingReason::kNotBlocked}; + + // Stores the operator where the driver was last blocked. Note that the driver + // always resumes at the leaf (consumer) to prioritize getting data out of the + // task. The only exception is when resuming from a trace. size_t blockedOperatorId_{0}; + // If this driver is being traced, store a pointer to the current data. Once + // the trace client unblocks the driver, we will feed this vector to the next + // operator in the pipeline. + RowVectorPtr traceInput_; + bool trackOperatorCpuUsage_; // Indicates that a DriverAdapter can rearrange Operators. Set to false at end @@ -822,6 +905,10 @@ struct DriverFactory { /// based on this pipeline. std::vector needsSpatialJoinBridges() const; + /// Returns plan node IDs for which IndexLookupJoin Bridges must be created + /// based on this pipeline. + std::vector needsIndexLookupJoinBridges() const; + static std::vector adapters; }; diff --git a/velox/exec/DriverStats.h b/velox/exec/DriverStats.h index 3b047f4c9e0..d99e3afd464 100644 --- a/velox/exec/DriverStats.h +++ b/velox/exec/DriverStats.h @@ -16,16 +16,29 @@ #pragma once +#include #include #include "velox/common/base/RuntimeMetrics.h" namespace facebook::velox::exec { struct DriverStats { - static constexpr const char* kTotalPauseTime = "totalDriverPauseWallNanos"; - static constexpr const char* kTotalOffThreadTime = + static constexpr std::string_view kTotalPauseTime = + "totalDriverPauseWallNanos"; + static constexpr std::string_view kTotalOffThreadTime = "totalDriverOffThreadWallNanos"; + /// Number of silent Velox throws during operator execution. + static constexpr std::string_view kNumSilentThrow = "numSilentThrow"; + /// Time an operator spent queued before execution. + static constexpr std::string_view kQueuedWallNanos = "queuedWallNanos"; + /// Number of dynamic filters accepted by an operator. + static constexpr std::string_view kDynamicFiltersAccepted = + "dynamicFiltersAccepted"; + /// Number of dynamic filters produced by an operator. + static constexpr std::string_view kDynamicFiltersProduced = + "dynamicFiltersProduced"; + std::unordered_map runtimeStats; }; diff --git a/velox/exec/EnforceDistinct.cpp b/velox/exec/EnforceDistinct.cpp new file mode 100644 index 00000000000..8007edd7b94 --- /dev/null +++ b/velox/exec/EnforceDistinct.cpp @@ -0,0 +1,79 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/EnforceDistinct.h" + +#include "velox/exec/OperatorType.h" + +namespace facebook::velox::exec { + +EnforceDistinct::EnforceDistinct( + int32_t operatorId, + DriverCtx* driverCtx, + const std::shared_ptr& planNode) + : Operator( + driverCtx, + planNode->outputType(), + operatorId, + planNode->id(), + OperatorType::kEnforceDistinct), + errorMessage_{planNode->errorMessage()} { + const auto& inputType = planNode->sources()[0]->outputType(); + + for (auto i = 0; i < inputType->size(); ++i) { + identityProjections_.emplace_back(i, i); + } + + groupingSet_ = GroupingSet::createForDistinct( + inputType, + createVectorHashers(inputType, planNode->distinctKeys()), + toChannels( + inputType, + std::vector{ + planNode->preGroupedKeys().begin(), + planNode->preGroupedKeys().end()}), + operatorCtx_.get(), + &nonReclaimableSection_); +} + +void EnforceDistinct::addInput(RowVectorPtr input) { + groupingSet_->addInput(input, /*mayPushdown=*/false); + + const auto& newGroups = groupingSet_->hashLookup().newGroups; + if (newGroups.size() != input->size()) { + VELOX_USER_FAIL("{}", errorMessage_); + } + + input_ = std::move(input); +} + +RowVectorPtr EnforceDistinct::getOutput() { + if (isFinished() || !input_) { + return nullptr; + } + + auto output = fillOutput(input_->size(), nullptr); + + input_ = nullptr; + + return output; +} + +bool EnforceDistinct::isFinished() { + return noMoreInput_ && !input_; +} + +} // namespace facebook::velox::exec diff --git a/velox/exec/EnforceDistinct.h b/velox/exec/EnforceDistinct.h new file mode 100644 index 00000000000..a09d14f25b2 --- /dev/null +++ b/velox/exec/EnforceDistinct.h @@ -0,0 +1,57 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/exec/GroupingSet.h" +#include "velox/exec/Operator.h" + +namespace facebook::velox::exec { + +/// Enforces uniqueness of rows based on specified key columns. Passes through +/// all input rows unchanged. Throws an exception if any duplicate key values +/// are detected. +class EnforceDistinct : public Operator { + public: + EnforceDistinct( + int32_t operatorId, + DriverCtx* driverCtx, + const std::shared_ptr& planNode); + + bool preservesOrder() const override { + return true; + } + + bool needsInput() const override { + return !noMoreInput_ && !input_; + } + + void addInput(RowVectorPtr input) override; + + RowVectorPtr getOutput() override; + + BlockingReason isBlocked(ContinueFuture* /*future*/) override { + return BlockingReason::kNotBlocked; + } + + bool isFinished() override; + + private: + const std::string errorMessage_; + std::unique_ptr groupingSet_; +}; + +} // namespace facebook::velox::exec diff --git a/velox/exec/EnforceSingleRow.cpp b/velox/exec/EnforceSingleRow.cpp index 85d629431fc..c7f71b08edb 100644 --- a/velox/exec/EnforceSingleRow.cpp +++ b/velox/exec/EnforceSingleRow.cpp @@ -15,6 +15,8 @@ */ #include "velox/exec/EnforceSingleRow.h" +#include "velox/exec/OperatorType.h" + namespace facebook::velox::exec { EnforceSingleRow::EnforceSingleRow( @@ -26,7 +28,7 @@ EnforceSingleRow::EnforceSingleRow( planNode->outputType(), operatorId, planNode->id(), - "EnforceSingleRow") { + OperatorType::kEnforceSingleRow) { isIdentityProjection_ = true; } diff --git a/velox/exec/Exchange.cpp b/velox/exec/Exchange.cpp index 4afdb425e7a..2d67c88d54b 100644 --- a/velox/exec/Exchange.cpp +++ b/velox/exec/Exchange.cpp @@ -14,8 +14,9 @@ * limitations under the License. */ #include "velox/exec/Exchange.h" - +#include "velox/common/Casts.h" #include "velox/common/serialization/Serializable.h" +#include "velox/exec/OperatorUtils.h" #include "velox/exec/Task.h" #include "velox/serializers/CompactRowSerializer.h" @@ -42,16 +43,18 @@ void RemoteConnectorSplit::registerSerDe() { } namespace { -std::unique_ptr getVectorSerdeOptions( - const core::QueryConfig& queryConfig, - VectorSerde::Kind kind) { - std::unique_ptr options = - kind == VectorSerde::Kind::kPresto - ? std::make_unique() - : std::make_unique(); - options->compressionKind = - common::stringToCompressionKind(queryConfig.shuffleCompressionKind()); - return options; +std::unique_ptr mergePages( + std::vector>& pages) { + VELOX_CHECK(!pages.empty()); + std::unique_ptr mergedBufs; + for (const auto& page : pages) { + if (mergedBufs == nullptr) { + mergedBufs = page->getIOBuf(); + } else { + mergedBufs->appendToChain(page->getIOBuf()); + } + } + return mergedBufs; } } // namespace @@ -60,7 +63,7 @@ Exchange::Exchange( DriverCtx* driverCtx, const std::shared_ptr& exchangeNode, std::shared_ptr exchangeClient, - const std::string& operatorType) + std::string_view operatorType) : SourceOperator( driverCtx, exchangeNode->outputType(), @@ -71,8 +74,14 @@ Exchange::Exchange( driverCtx->queryConfig().preferredOutputBatchBytes()}, serdeKind_{exchangeNode->serdeKind()}, serdeOptions_{getVectorSerdeOptions( - operatorCtx_->driverCtx()->queryConfig(), - serdeKind_)}, + common::stringToCompressionKind(operatorCtx_->driverCtx() + ->queryConfig() + .shuffleCompressionKind()), + serdeKind_, + std::nullopt, + operatorCtx_->driverCtx() + ->queryConfig() + .minShuffleCompressionPageSizeBytes())}, processSplits_{operatorCtx_->driverCtx()->driverId == 0}, driverId_{driverCtx->driverId}, exchangeClient_{std::move(exchangeClient)} {} @@ -85,42 +94,48 @@ void Exchange::addRemoteTaskIds(std::vector& remoteTaskIds) { stats_.wlock()->numSplits += remoteTaskIds.size(); } -bool Exchange::getSplits(ContinueFuture* future) { +void Exchange::getSplits(ContinueFuture* future) { if (!processSplits_) { - return false; + return; } if (noMoreSplits_) { - return false; + return; } std::vector remoteTaskIds; for (;;) { exec::Split split; - auto reason = operatorCtx_->task()->getSplitOrFuture( - operatorCtx_->driverCtx()->splitGroupId, planNodeId(), split, *future); - if (reason == BlockingReason::kNotBlocked) { - if (split.hasConnectorSplit()) { - auto remoteSplit = std::dynamic_pointer_cast( - split.connectorSplit); - VELOX_CHECK_NOT_NULL(remoteSplit, "Wrong type of split"); - if (FOLLY_UNLIKELY(splitTracer_ != nullptr)) { - splitTracer_->write(split); - } - remoteTaskIds.push_back(remoteSplit->taskId); - } else { - addRemoteTaskIds(remoteTaskIds); - exchangeClient_->noMoreRemoteTasks(); - noMoreSplits_ = true; - if (atEnd_) { - operatorCtx_->task()->multipleSplitsFinished( - false, stats_.rlock()->numSplits, 0); - recordExchangeClientStats(); - } - return false; - } - } else { + const auto reason = operatorCtx_->task()->getSplitOrFuture( + operatorCtx_->driverCtx()->driverId, + operatorCtx_->driverCtx()->splitGroupId, + planNodeId(), + /*maxPreloadSplits=*/0, + /*preload=*/nullptr, + split, + *future); + if (reason != BlockingReason::kNotBlocked) { addRemoteTaskIds(remoteTaskIds); - return true; + return; + } + + if (split.hasConnectorSplit()) { + auto remoteSplit = + checkedPointerCast(split.connectorSplit); + if (FOLLY_UNLIKELY(splitTracer_ != nullptr)) { + splitTracer_->write(split); + } + remoteTaskIds.push_back(remoteSplit->taskId); + continue; + } + + addRemoteTaskIds(remoteTaskIds); + exchangeClient_->noMoreRemoteTasks(); + noMoreSplits_ = true; + if (atEnd_) { + operatorCtx_->task()->multipleSplitsFinished( + false, stats_.rlock()->numSplits, 0); + recordExchangeClientStats(); } + return; } } @@ -130,7 +145,6 @@ BlockingReason Exchange::isBlocked(ContinueFuture* future) { } // Start fetching data right away. Do not wait for all splits to be available. - if (!splitFuture_.valid()) { getSplits(&splitFuture_); } @@ -159,6 +173,7 @@ BlockingReason Exchange::isBlocked(ContinueFuture* future) { } // Block until data becomes available. + VELOX_CHECK(dataFuture.valid()); *future = std::move(dataFuture); return BlockingReason::kWaitForProducer; } @@ -167,102 +182,142 @@ bool Exchange::isFinished() { return atEnd_ && currentPages_.empty(); } -namespace { -std::unique_ptr mergePages( - std::vector>& pages) { - VELOX_CHECK(!pages.empty()); - std::unique_ptr mergedBufs; - for (const auto& page : pages) { - if (mergedBufs == nullptr) { - mergedBufs = page->getIOBuf(); - } else { - mergedBufs->appendToChain(page->getIOBuf()); - } - } - return mergedBufs; -} -} // namespace - RowVectorPtr Exchange::getOutput() { auto* serde = getSerde(); if (serde->supportsAppendInDeserialize()) { - uint64_t rawInputBytes{0}; - if (currentPages_.empty()) { - return nullptr; - } - vector_size_t resultOffset = 0; - for (const auto& page : currentPages_) { + return getOutputFromColumnarPages(serde); + } + return getOutputFromRowPages(serde); +} + +RowVectorPtr Exchange::getOutputFromColumnarPages(VectorSerde* serde) { + if (currentPages_.empty()) { + return nullptr; + } + + // Calculate target row count based on estimated row size, similar to + // getOutputFromRowPages. + // Start conservatively, then use estimates. + const auto numRows = estimatedRowSize_.has_value() + ? std::max( + (preferredOutputBatchBytes_ / estimatedRowSize_.value()), + kInitialOutputRows) + : kInitialOutputRows; + + // Process pages one-by-one from currentPages_ pointed by columnarPageIdx_. + // Within each page, deserialize vectors incrementally until we hit the target + // batch size. + uint64_t rawInputBytes = 0; + vector_size_t resultOffset{0}; + + // Should be either starting fresh or continuing from a previous partial page + VELOX_CHECK( + inputStream_ == nullptr || columnarPageIdx_ < currentPages_.size()); + + // Iterate through pages + while (columnarPageIdx_ < currentPages_.size()) { + auto& page = currentPages_[columnarPageIdx_]; + + if (!inputStream_) { + // NOTE: 'rawInputBytes' only counts bytes from pages processed from the + // beginning in this call. If processing resumes from the middle of a + // page, that page's bytes are not counted. This ensures each page is + // counted only once in 'rawInputBytes' across multiple calls. rawInputBytes += page->size(); + inputStream_ = page->prepareStreamForDeserialize(); + } - auto inputStream = page->prepareStreamForDeserialize(); - while (!inputStream->atEnd()) { - serde->deserialize( - inputStream.get(), - pool(), - outputType_, - &result_, - resultOffset, - serdeOptions_.get()); - resultOffset = result_->size(); - } + // Inner loop: deserialize vectors from current page until batch is full + // or page is exhausted. + while (!inputStream_->atEnd() && resultOffset < numRows) { + serde->deserialize( + inputStream_.get(), + pool(), + outputType_, + &result_, + resultOffset, + serdeOptions_.get()); + + resultOffset = result_->size(); + } + + if (inputStream_->atEnd()) { + // Page is fully consumed, free memory immediately, and move to the next. + inputStream_ = nullptr; + page.reset(); + ++columnarPageIdx_; + } + + // Stop if accumulated enough rows for this batch. + if (resultOffset >= numRows) { + break; } - currentPages_.clear(); - recordInputStats(rawInputBytes); - return result_; - } - if (serde->kind() == VectorSerde::Kind::kCompactRow) { - return getOutputFromCompactRows(serde); } - if (serde->kind() == VectorSerde::Kind::kUnsafeRow) { - return getOutputFromUnsafeRows(serde); + + const auto numOutputRows = result_->size(); + VELOX_CHECK_GT(numOutputRows, 0); + + estimatedRowSize_ = std::max( + result_->estimateFlatSize() / numOutputRows, + estimatedRowSize_.value_or(1L)); + + // If processed all pages, clear the vector and reset state. + if (columnarPageIdx_ >= currentPages_.size()) { + VELOX_CHECK_NULL(inputStream_); + currentPages_.clear(); + columnarPageIdx_ = 0; } - VELOX_UNREACHABLE( - "Unsupported serde kind: {}", VectorSerde::kindName(serde->kind())); + + recordInputStats(rawInputBytes); + return result_; } -RowVectorPtr Exchange::getOutputFromCompactRows(VectorSerde* serde) { +RowVectorPtr Exchange::getOutputFromRowPages(VectorSerde* serde) { uint64_t rawInputBytes{0}; if (currentPages_.empty()) { - VELOX_CHECK_NULL(compactRowInputStream_); - VELOX_CHECK_NULL(compactRowIterator_); + VELOX_CHECK_NULL(inputStream_); + VELOX_CHECK_NULL(rowIterator_); return nullptr; } - if (compactRowInputStream_ == nullptr) { + if (inputStream_ == nullptr) { std::unique_ptr mergedBufs = mergePages(currentPages_); rawInputBytes += mergedBufs->computeChainDataLength(); - compactRowPages_ = std::make_unique(std::move(mergedBufs)); - compactRowInputStream_ = compactRowPages_->prepareStreamForDeserialize(); + mergedRowPage_ = + std::make_unique(std::move(mergedBufs)); + inputStream_ = mergedRowPage_->prepareStreamForDeserialize(); } - auto numRows = kInitialOutputCompactRows; - if (estimatedCompactRowSize_.has_value()) { + auto numRows = kInitialOutputRows; + if (estimatedRowSize_.has_value()) { numRows = std::max( - (preferredOutputBatchBytes_ / estimatedCompactRowSize_.value()), - kInitialOutputCompactRows); + (preferredOutputBatchBytes_ / estimatedRowSize_.value()), + kInitialOutputRows); } + // Check if the serde supports batched deserialization serde->deserialize( - compactRowInputStream_.get(), - compactRowIterator_, + inputStream_.get(), + rowIterator_, numRows, outputType_, &result_, pool(), serdeOptions_.get()); + const auto numOutputRows = result_->size(); VELOX_CHECK_GT(numOutputRows, 0); - estimatedCompactRowSize_ = std::max( + estimatedRowSize_ = std::max( result_->estimateFlatSize() / numOutputRows, - estimatedCompactRowSize_.value_or(1L)); + estimatedRowSize_.value_or(1L)); - if (compactRowInputStream_->atEnd() && compactRowIterator_ == nullptr) { + if (inputStream_->atEnd() && rowIterator_ == nullptr) { // only clear the input stream if we have reached the end of the row // iterator because row iterator may depend on input stream if serialized // rows are not compressed. - compactRowInputStream_ = nullptr; - compactRowPages_ = nullptr; + inputStream_ = nullptr; + mergedRowPage_ = nullptr; currentPages_.clear(); } @@ -270,22 +325,6 @@ RowVectorPtr Exchange::getOutputFromCompactRows(VectorSerde* serde) { return result_; } -RowVectorPtr Exchange::getOutputFromUnsafeRows(VectorSerde* serde) { - uint64_t rawInputBytes{0}; - if (currentPages_.empty()) { - return nullptr; - } - std::unique_ptr mergedBufs = mergePages(currentPages_); - rawInputBytes += mergedBufs->computeChainDataLength(); - auto mergedPages = std::make_unique(std::move(mergedBufs)); - auto source = mergedPages->prepareStreamForDeserialize(); - serde->deserialize( - source.get(), pool(), outputType_, &result_, serdeOptions_.get()); - currentPages_.clear(); - recordInputStats(rawInputBytes); - return result_; -} - void Exchange::recordInputStats(uint64_t rawInputBytes) { auto lockedStats = stats_.wlock(); lockedStats->rawInputBytes += rawInputBytes; @@ -297,16 +336,27 @@ void Exchange::close() { SourceOperator::close(); currentPages_.clear(); result_ = nullptr; + + // Clean up stateful deserialization state + inputStream_ = nullptr; + mergedRowPage_ = nullptr; + rowIterator_ = nullptr; + columnarPageIdx_ = 0; + if (exchangeClient_) { - recordExchangeClientStats(); + // Close the client before recording stats so that stats are captured + // from the final state. ExchangeClient::close() caches final stats + // before clearing sources. exchangeClient_->close(); + recordExchangeClientStats(); } exchangeClient_ = nullptr; { auto lockedStats = stats_.wlock(); lockedStats->addRuntimeStat( Operator::kShuffleSerdeKind, - RuntimeCounter(static_cast(serdeKind_))); + RuntimeCounter( + static_cast(VectorSerde::kindByName(serdeKind_)))); lockedStats->addRuntimeStat( Operator::kShuffleCompressionKind, RuntimeCounter(static_cast(serdeOptions_->compressionKind))); @@ -325,14 +375,13 @@ void Exchange::recordExchangeClientStats() { lockedStats->runtimeStats.insert({name, value}); } - auto backgroundCpuTimeMs = - exchangeClientStats.find(ExchangeClient::kBackgroundCpuTimeMs); - if (backgroundCpuTimeMs != exchangeClientStats.end()) { + const auto iter = + exchangeClientStats.find(std::string(Operator::kBackgroundCpuTimeNanos)); + if (iter != exchangeClientStats.end()) { const CpuWallTiming backgroundTiming{ - static_cast(backgroundCpuTimeMs->second.count), + static_cast(iter->second.count), 0, - static_cast(backgroundCpuTimeMs->second.sum) * - Timestamp::kNanosecondsInMillisecond}; + static_cast(iter->second.sum)}; lockedStats->backgroundTiming.clear(); lockedStats->backgroundTiming.add(backgroundTiming); } diff --git a/velox/exec/Exchange.h b/velox/exec/Exchange.h index 2e235b698e9..05e984fc595 100644 --- a/velox/exec/Exchange.h +++ b/velox/exec/Exchange.h @@ -19,6 +19,7 @@ #include "velox/exec/ExchangeClient.h" #include "velox/exec/Operator.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/OutputBufferManager.h" #include "velox/serializers/PrestoSerializer.h" #include "velox/serializers/RowSerializer.h" @@ -50,7 +51,7 @@ class Exchange : public SourceOperator { DriverCtx* driverCtx, const std::shared_ptr& exchangeNode, std::shared_ptr exchangeClient, - const std::string& operatorType = "Exchange"); + std::string_view operatorType = OperatorType::kExchange); ~Exchange() override { close(); @@ -67,12 +68,11 @@ class Exchange : public SourceOperator { protected: virtual VectorSerde* getSerde(); - private: - // When 'estimatedCompactRowSize_' is unset, meaning we haven't materialized + // When 'estimatedRowSize_' is unset, meaning we haven't materialized // and returned any output from this exchange operator, we return this // conservative number of output rows, to make sure memory does not grow too // much. - static constexpr uint64_t kInitialOutputCompactRows = 64; + static constexpr uint64_t kInitialOutputRows = 64; // Invoked to create exchange client for remote tasks. The function shuffles // the source task ids first to randomize the source tasks we fetch data from. @@ -82,11 +82,8 @@ class Exchange : public SourceOperator { // Fetches splits from the task until there are no more splits or task returns // a future that will be complete when more splits arrive. Adds splits to - // exchangeClient_. Returns true if received a future from the task and sets - // the 'future' parameter. Returns false if fetched all splits or if this - // operator is not the first operator in the pipeline and therefore is not - // responsible for fetching splits and adding them to the exchangeClient_. - bool getSplits(ContinueFuture* future); + // exchangeClient_. + void getSplits(ContinueFuture* future); // Fetches runtime stats from ExchangeClient and replaces these in this // operator's stats. @@ -94,13 +91,13 @@ class Exchange : public SourceOperator { void recordInputStats(uint64_t rawInputBytes); - RowVectorPtr getOutputFromCompactRows(VectorSerde* serde); + RowVectorPtr getOutputFromColumnarPages(VectorSerde* serde); - RowVectorPtr getOutputFromUnsafeRows(VectorSerde* serde); + RowVectorPtr getOutputFromRowPages(VectorSerde* serde); const uint64_t preferredOutputBatchBytes_; - const VectorSerde::Kind serdeKind_; + const std::string serdeKind_; const std::unique_ptr serdeOptions_; @@ -121,19 +118,25 @@ class Exchange : public SourceOperator { // Reusable result vector. RowVectorPtr result_; - std::vector> currentPages_; + std::vector> currentPages_; bool atEnd_{false}; std::default_random_engine rng_{std::random_device{}()}; - // Memory holders needed by compact row serde to perform cursor like reads - // across 'getOutputFromCompactRows' calls. - std::unique_ptr compactRowPages_; - std::unique_ptr compactRowInputStream_; - std::unique_ptr compactRowIterator_; + // Memory holders for deserialization across 'getOutput' calls. + // The merged pages for row serialization. + std::unique_ptr mergedRowPage_; + std::unique_ptr rowIterator_; + + // State for columnar page deserialization. + // Index of the current page in 'currentPages_' being processed. + size_t columnarPageIdx_{0}; + + // Stream for deserialization used by both row and columnar. + std::unique_ptr inputStream_; // The estimated bytes per row of the output of this exchange operator // computed from the last processed output. - std::optional estimatedCompactRowSize_; + std::optional estimatedRowSize_; }; } // namespace facebook::velox::exec diff --git a/velox/exec/ExchangeClient.cpp b/velox/exec/ExchangeClient.cpp index 0a8b92b2f01..0b1ae2ef098 100644 --- a/velox/exec/ExchangeClient.cpp +++ b/velox/exec/ExchangeClient.cpp @@ -53,7 +53,16 @@ void ExchangeClient::addRemoteTaskId(const std::string& remoteTaskId) { sources_.push_back(source); queue_->addSourceLocked(); emptySources_.push(source); - requestSpecs = pickSourcesToRequestLocked(); + // When lazyFetching_ is true, I/O will be triggered lazily when next() is + // called from Exchange::isBlocked(). This allows waiter tasks using + // cached hash tables to skip I/O entirely when the table is already + // cached - the HashBuild operator will finish before + // Exchange::isBlocked() is ever called, so no unnecessary data fetching + // occurs. + if (!lazyFetching_) { + // Start fetching data immediately. + requestSpecs = pickSourcesToRequestLocked(); + } } } @@ -78,6 +87,11 @@ void ExchangeClient::close() { if (closed_) { return; } + + // Capture stats BEFORE clearing sources_. + // This allows stats() to return meaningful data even after close(). + stats_ = collectStatsLocked(); + closed_ = true; sources = std::move(sources_); producingSources = std::move(producingSources_); @@ -91,41 +105,52 @@ void ExchangeClient::close() { queue_->close(); } -folly::F14FastMap ExchangeClient::stats() const { - folly::F14FastMap stats; +folly::F14FastMap ExchangeClient::stats() { std::lock_guard l(queue_->mutex()); + if (stats_.empty()) { + stats_ = collectStatsLocked(); + } + return stats_; +} + +folly::F14FastMap +ExchangeClient::collectStatsLocked() const { + folly::F14FastMap stats; for (const auto& source : sources_) { if (source->supportsMetrics()) { for (const auto& [name, value] : source->metrics()) { - if (UNLIKELY(stats.count(name) == 0)) { - stats.insert(std::pair(name, RuntimeMetric(value.unit))); - } - stats[name].merge(value); + auto [iter, inserted] = stats.try_emplace(name, value.unit); + iter->second.merge(value); } } else { for (const auto& [name, value] : source->stats()) { - stats[name].addValue(value); + auto [iter, inserted] = stats.try_emplace(name); + iter->second.addValue(value); } } } - stats["peakBytes"] = - RuntimeMetric(queue_->peakBytes(), RuntimeCounter::Unit::kBytes); - stats["numReceivedPages"] = RuntimeMetric(queue_->receivedPages()); - stats["averageReceivedPageBytes"] = RuntimeMetric( - queue_->averageReceivedPageBytes(), RuntimeCounter::Unit::kBytes); + stats.insert_or_assign( + "peakBytes", + RuntimeMetric(queue_->peakBytes(), RuntimeCounter::Unit::kBytes)); + stats.insert_or_assign( + "numReceivedPages", RuntimeMetric(queue_->receivedPages())); + stats.insert_or_assign( + "averageReceivedPageBytes", + RuntimeMetric( + queue_->averageReceivedPageBytes(), RuntimeCounter::Unit::kBytes)); return stats; } -std::vector> ExchangeClient::next( +std::vector> ExchangeClient::next( int consumerId, uint32_t maxBytes, bool* atEnd, ContinueFuture* future) { std::vector requestSpecs; - std::vector> pages; + std::vector> pages; ContinuePromise stalePromise = ContinuePromise::makeEmpty(); { std::lock_guard l(queue_->mutex()); @@ -161,7 +186,7 @@ void ExchangeClient::request(std::vector&& requestSpecs) { for (auto& spec : requestSpecs) { auto future = folly::SemiFuture::makeEmpty(); if (spec.maxBytes == 0) { - future = spec.source->requestDataSizes(kRequestDataSizesMaxWaitSec_); + future = spec.source->requestDataSizes(requestDataSizesMaxWaitSec_); } else { future = spec.source->request(spec.maxBytes, kRequestDataMaxWait); } @@ -185,7 +210,7 @@ void ExchangeClient::request(std::vector&& requestSpecs) { RECORD_METRIC_VALUE(kMetricExchangeDataCount); } - bool pauseCurrentSource = false; + bool pauseCurrentSource{false}; std::vector requestSpecs; std::shared_ptr currentSource = spec.source; { @@ -232,6 +257,9 @@ ExchangeClient::pickSourcesToRequestLocked() { if (closed_) { return {}; } + if (skipRequestDataSizeWithSingleSource()) { + return pickupSingleSourceToRequestLocked(); + } std::vector requestSpecs; while (!emptySources_.empty()) { auto& source = emptySources_.front(); @@ -267,10 +295,10 @@ ExchangeClient::pickSourcesToRequestLocked() { // 1. We have full capacity but still cannot initiate one single data // transfer. Let the transfer happen in this case to avoid getting stuck. // - // 2. We have some data in the queue that is not big enough for - // consumers and it is big enough to not allow ExchangeClient to - // initiate request for more data. Let transfer happen in this case - // to avoid this deadlock situation. + // 2. We have some data in the queue that is not big enough for consumers, + // and it is big enough to not allow ExchangeClient to initiate request + // for more data. Let transfer happen in this case to avoid this deadlock + // situation. auto& source = producingSources_.front().source; auto requestBytes = producingSources_.front().remainingBytes.at(0); LOG(INFO) << "Requesting large single page " << requestBytes @@ -283,6 +311,43 @@ ExchangeClient::pickSourcesToRequestLocked() { return requestSpecs; } +std::vector +ExchangeClient::pickupSingleSourceToRequestLocked() { + VELOX_CHECK_EQ(sources_.size(), 1); + VELOX_CHECK(!closed_); + if (emptySources_.empty() && producingSources_.empty()) { + return {}; + } + + VELOX_CHECK_EQ(totalPendingBytes_, 0); + VELOX_CHECK_LE(!!emptySources_.empty() + !!producingSources_.empty(), 1); + const auto requestBytes = maxQueuedBytes_ - queue_->totalBytes(); + + if (requestBytes <= 0) { + return {}; + } + std::vector requestSpecs; + SCOPE_EXIT { + totalPendingBytes_ += requestBytes; + }; + if (!emptySources_.empty()) { + VELOX_CHECK_EQ(emptySources_.size(), 1); + auto& source = emptySources_.front(); + VELOX_CHECK(source->shouldRequestLocked()); + requestSpecs.push_back({std::move(source), requestBytes}); + emptySources_.pop(); + return requestSpecs; + } + + VELOX_CHECK_EQ(producingSources_.size(), 1); + auto& source = producingSources_.front().source; + VELOX_CHECK(source->shouldRequestLocked()); + VELOX_CHECK(!producingSources_.front().remainingBytes.empty()); + requestSpecs.push_back({std::move(source), requestBytes}); + producingSources_.pop(); + return requestSpecs; +} + ExchangeClient::~ExchangeClient() { close(); } diff --git a/velox/exec/ExchangeClient.h b/velox/exec/ExchangeClient.h index b99fc49e885..cb279d74184 100644 --- a/velox/exec/ExchangeClient.h +++ b/velox/exec/ExchangeClient.h @@ -26,7 +26,6 @@ class ExchangeClient : public std::enable_shared_from_this { public: static constexpr int32_t kDefaultMaxQueuedBytes = 32 << 20; // 32 MB. static constexpr std::chrono::milliseconds kRequestDataMaxWait{100}; - static inline const std::string kBackgroundCpuTimeMs = "backgroundCpuTimeMs"; ExchangeClient( std::string taskId, @@ -36,23 +35,29 @@ class ExchangeClient : public std::enable_shared_from_this { uint64_t minOutputBatchBytes, memory::MemoryPool* pool, folly::Executor* executor, - int32_t requestDataSizesMaxWaitSec = 10) + int32_t requestDataSizesMaxWaitSec = 10, + bool skipRequestDataSizeWithSingleSource = false, + bool lazyFetching = false) : taskId_{std::move(taskId)}, destination_(destination), maxQueuedBytes_{maxQueuedBytes}, - kRequestDataSizesMaxWaitSec_{requestDataSizesMaxWaitSec}, + requestDataSizesMaxWaitSec_{requestDataSizesMaxWaitSec}, pool_(pool), executor_(executor), - queue_(std::make_shared( - numberOfConsumers, - minOutputBatchBytes)), + queue_( + std::make_shared( + numberOfConsumers, + minOutputBatchBytes)), // See comment in 'pickSourcesToRequestLocked' for why this is needed // for 'minOutputBatchBytes_'. Note: ExchangeQueue does not need max(1, // minOutputBatchBytes) because for 'MergeExchangeSource', we want // ExchangeQueue 'minOutputBatchBytes' to be be 0 so that it always // unblocks. In short, 0 has a special meaning for ExchangeQueue minOutputBatchBytes_( - std::max(static_cast(1), minOutputBatchBytes)) { + std::max(static_cast(1), minOutputBatchBytes)), + skipRequestDataSizeWithSingleSource_( + skipRequestDataSizeWithSingleSource), + lazyFetching_(lazyFetching) { VELOX_CHECK_NOT_NULL(pool_); VELOX_CHECK_NOT_NULL(executor_); // NOTE: the executor is used to run async response callback from the @@ -87,8 +92,8 @@ class ExchangeClient : public std::enable_shared_from_this { // Returns runtime statistics aggregated across all of the exchange sources. // ExchangeClient is expected to report background CPU time by including a - // runtime metric named ExchangeClient::kBackgroundCpuTimeMs. - folly::F14FastMap stats() const; + // runtime metric named Operator::kBackgroundCpuTimeNanos. + folly::F14FastMap stats(); const std::shared_ptr& queue() const { return queue_; @@ -102,7 +107,7 @@ class ExchangeClient : public std::enable_shared_from_this { /// /// The data may be compressed, in which case 'maxBytes' applies to compressed /// size. - std::vector> + std::vector> next(int consumerId, uint32_t maxBytes, bool* atEnd, ContinueFuture* future); std::string toString() const; @@ -110,7 +115,7 @@ class ExchangeClient : public std::enable_shared_from_this { folly::dynamic toJson() const; std::chrono::seconds requestDataSizesMaxWaitSec() const { - return kRequestDataSizesMaxWaitSec_; + return requestDataSizesMaxWaitSec_; } const std::unordered_set& getRemoteTaskIdList() const { @@ -131,15 +136,42 @@ class ExchangeClient : public std::enable_shared_from_this { std::vector remainingBytes; }; + // Selects exchange sources to request data from based on available queue + // capacity. Handles multiple sources by first requesting data sizes from all + // empty sources, then requesting actual data from producing sources based on + // their remaining bytes and available capacity. May initiate out-of-band + // transfers for large pages that exceed capacity to avoid deadlock + // situations. For single source case, delegates to + // pickupSingleSourceToRequestLocked which sets max request bytes based on + // available queue space instead of reported remaining bytes from exchange + // sources. std::vector pickSourcesToRequestLocked(); + // Specialized single-source request picker for single-source exchange + // clients. Sets the max request bytes based on available space in the queue + // rather than the reported remaining bytes from exchange sources. The reason + // is that single source has no other alternative so just fetch as much as + // possible from that source. Returns a request spec for the single source + // when there is available capacity in the queue and no pending requests. If + // capacity is unavailable or requests are already pending, returns empty + // vector. + std::vector pickupSingleSourceToRequestLocked(); void request(std::vector&& requestSpecs); + /// Returns true if skip request data size optimization is enabled for single + /// source exchanges. + bool skipRequestDataSizeWithSingleSource() const { + return skipRequestDataSizeWithSingleSource_ && queue_->hasNoMoreSources() && + sources_.size() == 1; + } + + folly::F14FastMap collectStatsLocked() const; + // Handy for ad-hoc logging. const std::string taskId_; const int destination_; const int64_t maxQueuedBytes_; - const std::chrono::seconds kRequestDataSizesMaxWaitSec_; + const std::chrono::seconds requestDataSizesMaxWaitSec_; memory::MemoryPool* const pool_; folly::Executor* const executor_; @@ -149,10 +181,21 @@ class ExchangeClient : public std::enable_shared_from_this { std::vector> sources_; bool closed_{false}; + folly::F14FastMap stats_; + // The minimum byte size the consumer is expected to consume from // the exchange queue. const uint64_t minOutputBatchBytes_; + // Enable single source exchange optimization query config flag + // when there is only one exchange source. + const bool skipRequestDataSizeWithSingleSource_; + + // If true, defer fetching until next() is called. + // If false (default), start fetching data immediately when remote tasks are + // added. + const bool lazyFetching_; + // Total number of bytes in flight. int64_t totalPendingBytes_{0}; diff --git a/velox/exec/ExchangeQueue.cpp b/velox/exec/ExchangeQueue.cpp index 7d1c24369b9..9a9114ae2c4 100644 --- a/velox/exec/ExchangeQueue.cpp +++ b/velox/exec/ExchangeQueue.cpp @@ -22,40 +22,12 @@ using facebook::velox::common::testutil::TestValue; namespace facebook::velox::exec { -SerializedPage::SerializedPage( - std::unique_ptr iobuf, - std::function onDestructionCb, - std::optional numRows) - : iobuf_(std::move(iobuf)), - iobufBytes_(chainBytes(*iobuf_.get())), - numRows_(numRows), - onDestructionCb_(onDestructionCb) { - VELOX_CHECK_NOT_NULL(iobuf_); - for (auto& buf : *iobuf_) { - int32_t bufSize = buf.size(); - ranges_.push_back(ByteRange{ - const_cast(reinterpret_cast(buf.data())), - bufSize, - 0}); - } -} - -SerializedPage::~SerializedPage() { - if (onDestructionCb_) { - onDestructionCb_(*iobuf_.get()); - } -} - -std::unique_ptr SerializedPage::prepareStreamForDeserialize() { - return std::make_unique(std::move(ranges_)); -} - void ExchangeQueue::noMoreSources() { std::vector promises; { std::lock_guard l(mutex_); noMoreSources_ = true; - promises = checkCompleteLocked(); + promises = checkNoMoreInput(); } clearPromises(promises); } @@ -70,20 +42,21 @@ void ExchangeQueue::close() { } int64_t ExchangeQueue::minOutputBatchBytesLocked() const { - // always allow to unblock when at end - if (atEnd_) { + // Allow to unblock if no more input. + if (noMoreInput_) { return 0; } - // At most 1% of received bytes so far to minimize latency for small exchanges + // At most 1% of received bytes so far to minimize latency for small + // exchanges. return std::min(minOutputBatchBytes_, receivedBytes_ / 100); } void ExchangeQueue::enqueueLocked( - std::unique_ptr&& page, + std::unique_ptr&& page, std::vector& promises) { if (page == nullptr) { ++numCompleted_; - auto completedPromises = checkCompleteLocked(); + auto completedPromises = checkNoMoreInput(); promises.reserve(promises.size() + completedPromises.size()); for (auto& promise : completedPromises) { promises.push_back(std::move(promise)); @@ -128,12 +101,12 @@ void ExchangeQueue::addPromiseLocked( *stalePromise = std::move(it->second); it->second = std::move(promise); } else { - promises_[consumerId] = std::move(promise); + promises_.emplace(consumerId, std::move(promise)); } VELOX_CHECK_LE(promises_.size(), numberOfConsumers_); } -std::vector> ExchangeQueue::dequeueLocked( +std::vector> ExchangeQueue::dequeueLocked( int consumerId, uint32_t maxBytes, bool* atEnd, @@ -156,11 +129,11 @@ std::vector> ExchangeQueue::dequeueLocked( return {}; } - std::vector> pages; + std::vector> pages; uint32_t pageBytes = 0; for (;;) { if (queue_.empty()) { - if (atEnd_) { + if (noMoreInput_) { *atEnd = true; } else if (pages.empty()) { addPromiseLocked(consumerId, future, stalePromise); @@ -189,9 +162,9 @@ void ExchangeQueue::setError(const std::string& error) { return; } error_ = error; - atEnd_ = true; - // NOTE: clear the serialized page queue as we won't consume from an - // errored queue. + noMoreInput_ = true; + // NOTE: clear the serialized page queue as we won't consume from an errored + // queue. queue_.clear(); promises = clearAllPromisesLocked(); } diff --git a/velox/exec/ExchangeQueue.h b/velox/exec/ExchangeQueue.h index 4f77360fdbc..91a633d5366 100644 --- a/velox/exec/ExchangeQueue.h +++ b/velox/exec/ExchangeQueue.h @@ -15,66 +15,11 @@ */ #pragma once -#include "velox/common/memory/ByteStream.h" +#include "velox/exec/SerializedPage.h" -namespace facebook::velox::exec { - -/// Corresponds to Presto SerializedPage, i.e. a container for serialize vectors -/// in Presto wire format. -class SerializedPage { - public: - /// Construct from IOBuf chain. - explicit SerializedPage( - std::unique_ptr iobuf, - std::function onDestructionCb = nullptr, - std::optional numRows = std::nullopt); - - ~SerializedPage(); +#include - /// Returns the size of the serialized data in bytes. - uint64_t size() const { - return iobufBytes_; - } - - std::optional numRows() const { - return numRows_; - } - - /// Makes 'input' ready for deserializing 'this' with - /// VectorStreamGroup::read(). - std::unique_ptr prepareStreamForDeserialize(); - - std::unique_ptr getIOBuf() const { - return iobuf_->clone(); - } - - private: - static int64_t chainBytes(folly::IOBuf& iobuf) { - int64_t size = 0; - for (auto& range : iobuf) { - size += range.size(); - } - return size; - } - - // Buffers containing the serialized data. The memory is owned by 'iobuf_'. - std::vector ranges_; - - // IOBuf holding the data in 'ranges_. - std::unique_ptr iobuf_; - - // Number of payload bytes in 'iobuf_'. - const int64_t iobufBytes_; - - // Number of payload rows, if provided. - const std::optional numRows_; - - // Callback that will be called on destruction of the SerializedPage, - // primarily used to free externally allocated memory backing folly::IOBuf - // from caller. Caller is responsible to pass in proper cleanup logic to - // prevent any memory leak. - std::function onDestructionCb_; -}; +namespace facebook::velox::exec { /// Queue of results retrieved from source. Owned by shared_ptr by /// Exchange and client threads and registered callbacks waiting @@ -108,7 +53,7 @@ class ExchangeQueue { /// returned in 'promises'. When 'page' is nullptr and the queue is not /// completed serving data, no 'promises' will be added and returned. void enqueueLocked( - std::unique_ptr&& page, + std::unique_ptr&& page, std::vector& promises); /// If data is permanently not available, e.g. the source cannot be @@ -127,7 +72,7 @@ class ExchangeQueue { /// /// The data may be compressed, in which case 'maxBytes' applies to compressed /// size. - std::vector> dequeueLocked( + std::vector> dequeueLocked( int consumerId, uint32_t maxBytes, bool* atEnd, @@ -162,6 +107,10 @@ class ExchangeQueue { void noMoreSources(); + bool hasNoMoreSources() const { + return noMoreSources_; + } + void close(); private: @@ -170,9 +119,9 @@ class ExchangeQueue { return clearAllPromisesLocked(); } - std::vector checkCompleteLocked() { + std::vector checkNoMoreInput() { if (noMoreSources_ && numCompleted_ == numSources_) { - atEnd_ = true; + noMoreInput_ = true; return clearAllPromisesLocked(); } return {}; @@ -193,7 +142,9 @@ class ExchangeQueue { } std::vector clearAllPromisesLocked() { - std::vector promises(promises_.size()); + std::vector promises; + promises.reserve(promises_.size()); + auto it = promises_.begin(); while (it != promises_.end()) { promises.push_back(std::move(it->second)); @@ -216,11 +167,14 @@ class ExchangeQueue { int numCompleted_{0}; int numSources_{0}; - bool noMoreSources_{false}; - bool atEnd_{false}; + tsan_atomic noMoreSources_{false}; + // True if no more pages will be enqueued. This can be due to all sources + // completing normally or an error. Note that the queue itself may still + // contain data to be consumed. + bool noMoreInput_{false}; std::mutex mutex_; - std::deque> queue_; + std::deque> queue_; // The map from consumer id to the waiting promise folly::F14FastMap promises_; diff --git a/velox/exec/ExchangeSource.h b/velox/exec/ExchangeSource.h index 2b0f74fa205..79ec65781b8 100644 --- a/velox/exec/ExchangeSource.h +++ b/velox/exec/ExchangeSource.h @@ -106,7 +106,7 @@ class ExchangeSource : public std::enable_shared_from_this { // Returns runtime statistics. ExchangeSource is expected to report // background CPU time by including a runtime metric named - // ExchangeClient::kBackgroundCpuTimeMs. + // Operator::kBackgroundCpuTimeNanos. virtual folly::F14FastMap stats() const { VELOX_UNREACHABLE(); } diff --git a/velox/exec/Expand.cpp b/velox/exec/Expand.cpp index 5d866b888c1..39a35053c0e 100644 --- a/velox/exec/Expand.cpp +++ b/velox/exec/Expand.cpp @@ -15,6 +15,8 @@ */ #include "velox/exec/Expand.h" +#include "velox/exec/OperatorType.h" + namespace facebook::velox::exec { Expand::Expand( @@ -26,11 +28,12 @@ Expand::Expand( expandNode->outputType(), operatorId, expandNode->id(), - "Expand") { + OperatorType::kExpand) { const auto& inputType = expandNode->inputType(); const auto numRows = expandNode->projections().size(); fieldProjections_.reserve(numRows); constantProjections_.reserve(numRows); + constantOutputs_.reserve(numRows); const auto numColumns = expandNode->names().size(); for (const auto& rowProjections : expandNode->projections()) { std::vector rowProjection; @@ -58,6 +61,26 @@ Expand::Expand( } } +void Expand::initialize() { + Operator::initialize(); + if (constantProjections_.empty()) { + return; + } + const auto numColumns = constantProjections_[0].size(); + for (const auto& projections : constantProjections_) { + std::vector constantOutput; + constantOutput.reserve(numColumns); + for (const auto& constant : projections) { + if (constant) { + constantOutput.push_back(constant->toConstantVector(pool())); + } else { + constantOutput.push_back(nullptr); + } + } + constantOutputs_.emplace_back(std::move(constantOutput)); + } +} + bool Expand::needsInput() const { return !noMoreInput_ && input_ == nullptr; } @@ -81,21 +104,13 @@ RowVectorPtr Expand::getOutput() { std::vector outputColumns(outputType_->size()); const auto& rowProjection = fieldProjections_[rowIndex_]; - const auto& constantProjection = constantProjections_[rowIndex_]; + const auto& constantProjection = constantOutputs_[rowIndex_]; const auto numColumns = rowProjection.size(); for (auto i = 0; i < numColumns; ++i) { if (rowProjection[i] == kConstantChannel) { - const auto& constantExpr = constantProjection[i]; - if (constantExpr->value().isNull()) { - // Add null column. - outputColumns[i] = BaseVector::createNullConstant( - outputType_->childAt(i), numInput, pool()); - } else { - // Add constant column. - outputColumns[i] = BaseVector::createConstant( - constantExpr->type(), constantExpr->value(), numInput, pool()); - } + outputColumns[i] = + BaseVector::wrapInConstant(numInput, 0, constantProjection[i]); } else { outputColumns[i] = input_->childAt(rowProjection[i]); } diff --git a/velox/exec/Expand.h b/velox/exec/Expand.h index 97c737c1d1f..adf87a71526 100644 --- a/velox/exec/Expand.h +++ b/velox/exec/Expand.h @@ -42,11 +42,15 @@ class Expand : public Operator { } private: + void initialize() override; + std::vector> fieldProjections_; std::vector>> constantProjections_; + std::vector> constantOutputs_; + // Used to indicate the index of fieldProjections_. int32_t rowIndex_{0}; }; diff --git a/velox/exec/FilterProject.cpp b/velox/exec/FilterProject.cpp index b8189539d8f..a583b34db2e 100644 --- a/velox/exec/FilterProject.cpp +++ b/velox/exec/FilterProject.cpp @@ -15,6 +15,8 @@ */ #include "velox/exec/FilterProject.h" #include "velox/core/Expressions.h" +#include "velox/exec/Driver.h" +#include "velox/exec/OperatorType.h" #include "velox/expression/Expr.h" #include "velox/expression/FieldReference.h" @@ -72,6 +74,65 @@ std::vector splitStats( return {std::move(projectStats), std::move(filterStats)}; } +// Unwraps dictionary/constant/sequence encodings to get to the underlying +// Lazy vector if it exists. Otherwise, returns nullptr. +const BaseVector* unwrapToLazy(const VectorPtr& vector) { + switch (vector->encoding()) { + case VectorEncoding::Simple::CONSTANT: + case VectorEncoding::Simple::DICTIONARY: + case VectorEncoding::Simple::SEQUENCE: { + return vector->valueVector() ? unwrapToLazy(vector->valueVector()) + : nullptr; + } + case VectorEncoding::Simple::LAZY: + return vector.get(); + default: + return nullptr; + } +} + +// Returns a unique identity pointer for Lazy vectors. For other vectors, +// returns nullptr. +const void* lazyIdentityPtr(const VectorPtr& vec) { + if (!vec) { + return nullptr; + } + const BaseVector* lazyVec = unwrapToLazy(vec); + if (!lazyVec) { + return nullptr; + } + return static_cast(lazyVec); +} + +// Load the reused lazy vectors in the output. A lazy vector cannot be reused +// across different output fields because its contents may be loaded via a +// hook during pushdown. Accessing a loaded vector in this case can lead to +// incorrect results or even a crash. +void loadReusedLazyVectors(const RowVectorPtr& output) { + if (!output || !output->containsLazyNotLoaded()) { + return; + } + + const auto& vectors = output->children(); + + // Build a map of lazy identity pointers to their occurrence count. + std::unordered_map lazyIdentityCounts; + for (const auto& vector : vectors) { + const void* id = lazyIdentityPtr(vector); + if (id) { + lazyIdentityCounts[id]++; + } + } + + // Load only the vectors whose lazy identity appears more than once. + for (auto& vector : vectors) { + const void* id = lazyIdentityPtr(vector); + if (id && lazyIdentityCounts[id] > 1) { + vector->loadedVector(); + } + } +} + } // namespace FilterProject::FilterProject( @@ -84,7 +145,7 @@ FilterProject::FilterProject( project ? project->outputType() : filter->outputType(), operatorId, project ? project->id() : filter->id(), - "FilterProject"), + OperatorType::kFilterProject), hasFilter_(filter != nullptr), lazyDereference_( dynamic_cast(project.get()) != @@ -148,6 +209,11 @@ void FilterProject::initialize() { } filter_.reset(); project_.reset(); + + if (const auto* traceCtx = operatorCtx_->driverCtx()->traceCtx(); + traceCtx && traceCtx->shouldTrace(*this)) { + exprs_->maybeSetupTracers(*this, *traceCtx); + } } void FilterProject::addInput(RowVectorPtr input) { @@ -185,7 +251,9 @@ RowVectorPtr FilterProject::getOutput() { if (!hasFilter_) { VELOX_CHECK(!isIdentityProjection_); auto results = project(*rows, evalCtx); - return fillOutput(size, nullptr, results); + auto output = fillOutput(size, nullptr, results); + loadReusedLazyVectors(output); + return output; } // evaluate filter @@ -205,10 +273,12 @@ RowVectorPtr FilterProject::getOutput() { results = project(*rows, evalCtx); } - return fillOutput( + auto output = fillOutput( numOut, allRowsSelected ? nullptr : filterEvalCtx_.selectedIndices, results); + loadReusedLazyVectors(output); + return output; } std::vector FilterProject::project( diff --git a/velox/exec/FilterProject.h b/velox/exec/FilterProject.h index 79aafcc4a8d..85947b21d97 100644 --- a/velox/exec/FilterProject.h +++ b/velox/exec/FilterProject.h @@ -59,6 +59,7 @@ class FilterProject : public Operator { void close() override { Operator::close(); if (exprs_ != nullptr) { + exprs_->finishTracers(); exprs_->clear(); } else { VELOX_CHECK(!initialized_); @@ -82,6 +83,12 @@ class FilterProject : public Operator { /// tracking is enabled via query config. OperatorStats stats(bool clear) override; + /// Returns the filterNode, call this function before initialize the operator, + /// this field is reset in function initialize. + const std::shared_ptr& filterNode() const { + return filter_; + } + private: // Evaluate filter on all rows. Return number of rows that passed the filter. // Populate filterEvalCtx_.selectedBits and selectedIndices with the indices diff --git a/velox/exec/GroupId.cpp b/velox/exec/GroupId.cpp index f64c4e51fc3..cf51b4ed2ba 100644 --- a/velox/exec/GroupId.cpp +++ b/velox/exec/GroupId.cpp @@ -15,6 +15,8 @@ */ #include "velox/exec/GroupId.h" +#include "velox/exec/OperatorType.h" + namespace facebook::velox::exec { GroupId::GroupId( @@ -26,7 +28,7 @@ GroupId::GroupId( groupIdNode->outputType(), operatorId, groupIdNode->id(), - "GroupId") { + OperatorType::kGroupId) { const auto& inputType = groupIdNode->sources()[0]->outputType(); std::unordered_map diff --git a/velox/exec/GroupingSet.cpp b/velox/exec/GroupingSet.cpp index 63bbf4b5c3c..ab885023b27 100644 --- a/velox/exec/GroupingSet.cpp +++ b/velox/exec/GroupingSet.cpp @@ -15,6 +15,7 @@ */ #include "velox/exec/GroupingSet.h" #include "velox/common/testutil/TestValue.h" +#include "velox/exec/OperatorUtils.h" #include "velox/exec/Task.h" using facebook::velox::common::testutil::TestValue; @@ -55,26 +56,26 @@ GroupingSet::GroupingSet( tsan_atomic* nonReclaimableSection, const core::QueryConfig* queryConfig, memory::MemoryPool* pool, - folly::Synchronized* spillStats) + exec::SpillStats* spillStats) : preGroupedKeyChannels_(std::move(preGroupedKeys)), - groupingKeyOutputProjections_(std::move(groupingKeyOutputProjections)), - hashers_(std::move(hashers)), - isGlobal_(hashers_.empty()), + isGlobal_(hashers.empty()), isPartial_(isPartial), isRawInput_(isRawInput), + ignoreNullKeys_(ignoreNullKeys), + isAdaptive_(queryConfig->hashAdaptivityEnabled()), + globalGroupingSets_(globalGroupingSets), + nonReclaimableSection_(nonReclaimableSection), + spillConfig_(spillConfig), queryConfig_(queryConfig), pool_(pool), + spillStats_(spillStats), + groupingKeyOutputProjections_(std::move(groupingKeyOutputProjections)), + hashers_(std::move(hashers)), aggregates_(std::move(aggregates)), masks_(extractMaskChannels(aggregates_)), - ignoreNullKeys_(ignoreNullKeys), - globalGroupingSets_(globalGroupingSets), groupIdChannel_(groupIdChannel), - spillConfig_(spillConfig), - nonReclaimableSection_(nonReclaimableSection), stringAllocator_(pool_), - rows_(pool_), - isAdaptive_(queryConfig_->hashAdaptivityEnabled()), - spillStats_(spillStats) { + rows_(pool_) { VELOX_CHECK_NOT_NULL(nonReclaimableSection_); VELOX_CHECK(pool_->trackUsage()); @@ -122,6 +123,10 @@ GroupingSet::GroupingSet( } else { distinctAggregations_.push_back(nullptr); } + + if (aggregate.function->supportsCompact()) { + hasCompactableAggregates_ = true; + } } } @@ -131,15 +136,16 @@ GroupingSet::~GroupingSet() { } } -std::unique_ptr GroupingSet::createForMarkDistinct( +std::unique_ptr GroupingSet::createForDistinct( const RowTypePtr& inputType, std::vector>&& hashers, + std::vector&& preGroupedKeys, OperatorCtx* operatorCtx, tsan_atomic* nonReclaimableSection) { return std::make_unique( inputType, std::move(hashers), - /*preGroupedKeys=*/std::vector{}, + std::move(preGroupedKeys), /*groupingKeyOutputProjections=*/std::vector{}, /*aggregates=*/std::vector{}, /*ignoreNullKeys=*/false, @@ -172,6 +178,7 @@ bool equalKeys( } // namespace void GroupingSet::addInput(const RowVectorPtr& input, bool mayPushdown) { + drainedNewGroups_ = {}; if (isGlobal_) { addGlobalAggregationInput(input, mayPushdown); return; @@ -205,6 +212,7 @@ void GroupingSet::addInput(const RowVectorPtr& input, bool mayPushdown) { } void GroupingSet::noMoreInput() { + drainedNewGroups_ = {}; noMoreInput_ = true; if (remainingInput_) { @@ -230,6 +238,38 @@ bool GroupingSet::hasSpilled() const { return outputSpiller_ != nullptr; } +uint64_t GroupingSet::compact() { + VELOX_CHECK(hasCompactableAggregates_); + + uint64_t freedBytes = 0; + + if (isGlobal_) { + if (globalAggregationInitialized_) { + VELOX_CHECK_NOT_NULL(lookup_); + VELOX_CHECK_EQ(lookup_->hits.size(), 1); + char* group = lookup_->hits[0]; + for (auto& aggregate : aggregates_) { + freedBytes += aggregate.function->compact(folly::Range(&group, 1)); + } + } + } else if (table_ != nullptr) { + auto* rows = table_->rows(); + if (rows != nullptr && rows->numRows() > 0) { + RowContainerIterator iter; + std::vector groups(1'000); + while (const auto numRows = rows->listRows( + &iter, static_cast(groups.size()), groups.data())) { + for (auto& aggregate : aggregates_) { + freedBytes += + aggregate.function->compact(folly::Range(groups.data(), numRows)); + } + } + } + } + + return freedBytes; +} + bool GroupingSet::hasOutput() { return noMoreInput_ || remainingInput_; } @@ -238,7 +278,7 @@ void GroupingSet::addInputForActiveRows( const RowVectorPtr& input, bool mayPushdown) { VELOX_CHECK(!isGlobal_); - if (!table_) { + if (table_ == nullptr) { createHashTable(); } ensureInputFits(input); @@ -320,6 +360,23 @@ void GroupingSet::addRemainingInput() { activeRows_.updateBounds(); addInputForActiveRows(remainingInput_, remainingMayPushdown_); + + if (isDistinct() && !preGroupedKeyChannels_.empty()) { + // Pre-grouped distinct aggregation does not currently spill + // (AggregationNode::canSpill returns false when preGroupedKeys is + // non-empty; see velox#3264). The captured row pointers below + // reference table_->rows() and would dangle after table_->clear() in + // any spill path. If spill becomes enabled here, restore a + // materialize-before-clear path or otherwise preserve these rows. + VELOX_CHECK(!hasSpilled()); + const auto& newGroups = lookup_->newGroups; + drainedNewGroups_.clear(); + drainedNewGroups_.reserve(newGroups.size()); + for (auto idx : newGroups) { + drainedNewGroups_.push_back(lookup_->hits[idx]); + } + } + remainingInput_.reset(); } @@ -605,7 +662,7 @@ bool GroupingSet::getGlobalAggregationOutput( initializeGlobalAggregation(); - auto groups = lookup_->hits.data(); + auto* groups = lookup_->hits.data(); for (int32_t i = 0; i < aggregates_.size(); ++i) { if (!aggregates_[i].sortingKeys.empty()) { continue; @@ -815,6 +872,30 @@ void GroupingSet::extractGroups( } } +bool GroupingSet::hasDrainedNewGroups() const { + return !drainedNewGroups_.empty(); +} + +vector_size_t GroupingSet::drainedNewGroupsCount() const { + return static_cast(drainedNewGroups_.size()); +} + +void GroupingSet::extractDrainedNewGroups(const RowVectorPtr& result) { + VELOX_CHECK(isDistinct()); + VELOX_DCHECK(!drainedNewGroups_.empty()); + result->resize(drainedNewGroups_.size()); + auto* rowContainer = table_->rows(); + for (vector_size_t i = 0; i < result->childrenSize(); ++i) { + auto& keyVector = result->childAt(i); + rowContainer->extractColumn( + drainedNewGroups_.data(), + drainedNewGroups_.size(), + groupingKeyOutputProjections_[i], + keyVector); + } + drainedNewGroups_.clear(); +} + void GroupingSet::resetTable(bool freeTable) { if (table_ != nullptr) { table_->clear(freeTable); @@ -932,8 +1013,11 @@ void GroupingSet::ensureInputFits(const RowVectorPtr& input) { } LOG(WARNING) << "Failed to reserve " << succinctBytes(targetIncrementBytes) << " for memory pool " << pool_->name() - << ", usage: " << succinctBytes(pool_->usedBytes()) - << ", reservation: " << succinctBytes(pool_->reservedBytes()); + << ", root pool: " << pool_->root()->name() + << ", used: " << succinctBytes(pool_->usedBytes()) + << ", reservation: " << succinctBytes(pool_->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool_->root()->reservedBytes()); } void GroupingSet::ensureOutputFits() { @@ -972,8 +1056,11 @@ void GroupingSet::ensureOutputFits() { LOG(WARNING) << "Failed to reserve " << succinctBytes(outputBufferSizeToReserve) << " for memory pool " << pool_->name() - << ", usage: " << succinctBytes(pool_->usedBytes()) - << ", reservation: " << succinctBytes(pool_->reservedBytes()); + << ", root pool: " << pool_->root()->name() + << ", used: " << succinctBytes(pool_->usedBytes()) + << ", reservation: " << succinctBytes(pool_->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool_->root()->reservedBytes()); } RowTypePtr GroupingSet::makeSpillType() const { @@ -985,6 +1072,7 @@ RowTypePtr GroupingSet::makeSpillType() const { } std::vector names; + names.reserve(types.size()); for (auto i = 0; i < types.size(); ++i) { names.push_back(fmt::format("s{}", i)); } @@ -992,7 +1080,7 @@ RowTypePtr GroupingSet::makeSpillType() const { return ROW(std::move(names), std::move(types)); } -std::optional GroupingSet::spilledStats() const { +std::optional GroupingSet::spilledStats() const { if (!hasSpilled()) { return std::nullopt; } @@ -1100,6 +1188,8 @@ bool GroupingSet::getOutputWithSpill( false, false, false, + false, // hasCountFlag + false, false, pool_); @@ -1136,8 +1226,7 @@ bool GroupingSet::prepareNextSpillPartitionOutput() { auto it = spillPartitionSet_.begin(); VELOX_CHECK_NE(outputSpillPartition_, it->first.partitionNumber()); outputSpillPartition_ = it->first.partitionNumber(); - merge_ = it->second->createOrderedReader( - spillConfig_->readBufferSize, pool_, spillStats_); + merge_ = it->second->createOrderedReader(*spillConfig_, pool_, spillStats_); spillPartitionSet_.erase(it); return true; } @@ -1246,6 +1335,9 @@ void GroupingSet::prepareSpillResultWithoutAggregates( spillResultWithoutAggregates_->childAt(groupingKeyOutputProjections_[i]) = std::move(result->childAt(i)); } + + spillSources_.resize(maxOutputRows); + spillSourceRows_.resize(maxOutputRows); } void GroupingSet::projectResult(const RowVectorPtr& result) { @@ -1266,6 +1358,7 @@ bool GroupingSet::mergeNextWithoutAggregates( VELOX_CHECK_EQ( numDistinctSpillFilesPerPartition_.size(), 1 << spillConfig_->numPartitionBits); + VELOX_CHECK(pool_ == result->pool()); // We are looping over sorted rows produced by tree-of-losers. We logically // split the stream into runs of duplicate rows. As we process each run we @@ -1280,12 +1373,15 @@ bool GroupingSet::mergeNextWithoutAggregates( // less than 'numDistinctSpillFilesPerPartition_'. bool newDistinct{true}; int32_t numOutputRows{0}; + int32_t outputSize{0}; + bool endOfBatch = false; prepareSpillResultWithoutAggregates(maxOutputRows, result); - while (numOutputRows < maxOutputRows) { + while (numOutputRows + outputSize < maxOutputRows) { const auto next = merge_->nextWithEquals(); auto* stream = next.first; if (stream == nullptr) { + VELOX_CHECK_EQ(outputSize, 0); if (numOutputRows > 0) { break; } @@ -1300,17 +1396,40 @@ bool GroupingSet::mergeNextWithoutAggregates( numDistinctSpillFilesPerPartition_[outputSpillPartition_]) { newDistinct = false; } - if (next.second) { - stream->pop(); - continue; - } - if (newDistinct) { + auto index = stream->currentIndex(&endOfBatch); + if (!next.second && newDistinct) { // Yield result for new distinct. - spillResultWithoutAggregates_->copy( - &stream->current(), numOutputRows++, stream->currentIndex(), 1); + spillSources_[outputSize] = &stream->current(); + spillSourceRows_[outputSize] = index; + ++outputSize; + } + + if (FOLLY_UNLIKELY(endOfBatch)) { + // The stream is at end of input batch. Need to copy out the rows before + // fetching next batch in 'pop'. + gatherCopy( + spillResultWithoutAggregates_.get(), + numOutputRows, + outputSize, + spillSources_, + spillSourceRows_); + numOutputRows += outputSize; + outputSize = 0; } stream->pop(); - newDistinct = true; + // Reset newDistinct flag for new row. + if (!next.second) { + newDistinct = true; + } + } + if (FOLLY_LIKELY(outputSize != 0)) { + gatherCopy( + spillResultWithoutAggregates_.get(), + numOutputRows, + outputSize, + spillSources_, + spillSourceRows_); + numOutputRows += outputSize; } spillResultWithoutAggregates_->resize(numOutputRows); projectResult(result); @@ -1322,12 +1441,17 @@ void GroupingSet::initializeRow(SpillMergeStream& stream, char* row) { mergeRows_->store(stream.decoded(i), stream.currentIndex(), mergeState_, i); } vector_size_t zero = 0; - for (auto& aggregate : aggregates_) { - if (!aggregate.sortingKeys.empty()) { + for (auto i = 0; i < aggregates_.size(); ++i) { + if (!aggregates_[i].sortingKeys.empty()) { continue; } - aggregate.function->initializeNewGroups( - &row, folly::Range(&zero, 1)); + if (!aggregates_[i].distinct) { + aggregates_[i].function->initializeNewGroups( + &row, folly::Range(&zero, 1)); + } else { + distinctAggregations_[i]->initializeNewGroups( + &row, folly::Range(&zero, 1)); + } } if (sortedAggregations_ != nullptr) { @@ -1380,11 +1504,22 @@ void GroupingSet::updateRow(SpillMergeStream& input, char* row) { } mergeSelection_.setValid(input.currentIndex(), false); + auto sortOrDistinctAggIndex = aggregates_.size() + keyChannels_.size(); if (sortedAggregations_ != nullptr) { - const auto& vector = - input.current().childAt(aggregates_.size() + keyChannels_.size()); + const auto& vector = input.current().childAt(sortOrDistinctAggIndex); sortedAggregations_->addSingleGroupSpillInput( row, vector, input.currentIndex()); + ++sortOrDistinctAggIndex; + } + + for (const auto& distinctAgg : distinctAggregations_) { + if (distinctAgg != nullptr) { + distinctAgg->addSingleGroupSpillInput( + row, + input.current().childAt(sortOrDistinctAggIndex), + input.currentIndex()); + ++sortOrDistinctAggIndex; + } } } @@ -1406,6 +1541,8 @@ void GroupingSet::abandonPartialAggregation() { false, false, false, + false, // hasCountFlag + false, false, pool_); initializeAggregates(aggregates_, *intermediateRows_, true); @@ -1502,8 +1639,9 @@ void GroupingSet::toIntermediate( &aggregateVector); } if (intermediateRows_) { - intermediateRows_->eraseRows(folly::Range( - intermediateGroups_.data(), intermediateGroups_.size())); + intermediateRows_->eraseRows( + folly::Range( + intermediateGroups_.data(), intermediateGroups_.size())); } // It's unnecessary to call function->clear() to reset the internal states of @@ -1526,7 +1664,7 @@ AggregationInputSpiller::AggregationInputSpiller( const HashBitRange& hashBitRange, const std::vector& sortingKeys, const common::SpillConfig* spillConfig, - folly::Synchronized* spillStats) + exec::SpillStats* spillStats) : SpillerBase( container, std::move(rowType), @@ -1542,7 +1680,7 @@ AggregationOutputSpiller::AggregationOutputSpiller( RowContainer* container, RowTypePtr rowType, const common::SpillConfig* spillConfig, - folly::Synchronized* spillStats) + exec::SpillStats* spillStats) : SpillerBase( container, std::move(rowType), diff --git a/velox/exec/GroupingSet.h b/velox/exec/GroupingSet.h index 43ce19eed1c..3d70a841aa1 100644 --- a/velox/exec/GroupingSet.h +++ b/velox/exec/GroupingSet.h @@ -15,13 +15,14 @@ */ #pragma once +#include "velox/common/base/TreeOfLosers.h" +#include "velox/common/file/FileSystems.h" #include "velox/exec/AggregateInfo.h" #include "velox/exec/AggregationMasks.h" #include "velox/exec/DistinctAggregations.h" #include "velox/exec/HashTable.h" #include "velox/exec/SortedAggregations.h" #include "velox/exec/Spiller.h" -#include "velox/exec/TreeOfLosers.h" #include "velox/exec/VectorHasher.h" namespace facebook::velox::exec { @@ -45,14 +46,18 @@ class GroupingSet { tsan_atomic* nonReclaimableSection, const core::QueryConfig* queryConfig, memory::MemoryPool* pool, - folly::Synchronized* spillStats); + exec::SpillStats* spillStats); ~GroupingSet(); - /// Used by MarkDistinct operator to identify rows with unique values. - static std::unique_ptr createForMarkDistinct( + /// Used by MarkDistinct and EnforceDistinct operators to identify rows with + /// unique values for a set of keys. + /// @param preGroupedKeys Subset of grouping keys that input is already + /// clustered on. + static std::unique_ptr createForDistinct( const RowTypePtr& inputType, std::vector>&& hashers, + std::vector&& preGroupedKeys, OperatorCtx* operatorCtx, tsan_atomic* nonReclaimableSection); @@ -109,6 +114,18 @@ class GroupingSet { const HashLookup& hashLookup() const; + /// Returns true if there are pending new-group rows from the most recent + /// tail auto-drain that have not yet been consumed. + bool hasDrainedNewGroups() const; + + /// Returns the number of pending drained new-group rows. + vector_size_t drainedNewGroupsCount() const; + + /// Extracts the pending drained new-group rows into 'result' by reading + /// directly from row-container pointers. Clears the pending state after + /// extraction. + void extractDrainedNewGroups(const RowVectorPtr& result); + /// Spills all the rows in container. void spill(); @@ -118,11 +135,22 @@ class GroupingSet { void spill(const RowContainerIterator& rowIterator); /// Returns the spiller stats including total bytes and rows spilled so far. - std::optional spilledStats() const; + std::optional spilledStats() const; /// Returns true if spilling has triggered on this grouping set. bool hasSpilled() const; + /// Performs lightweight memory compaction across all aggregates before + /// spilling. Iterates over all groups and calls Aggregate::compact() on each + /// aggregate function. Returns the total number of bytes freed. + uint64_t compact(); + + /// Returns true if any aggregate function supports lightweight memory + /// compaction. + bool hasCompactableAggregates() const { + return hasCompactableAggregates_; + } + /// Returns the hashtable stats. HashTableStats hashTableStats() const { return table_ ? table_->stats() : HashTableStats{}; @@ -133,6 +161,12 @@ class GroupingSet { return table_ ? table_->rows()->numRows() : 0; } + /// Returns the underlying hash table, or nullptr if it has not been created + /// yet. + BaseHashTable* table() const { + return table_.get(); + } + /// Frees hash tables and other state when giving up partial aggregation as /// non-productive. Must be called before toIntermediate() is used. void abandonPartialAggregation(); @@ -289,43 +323,45 @@ class GroupingSet { // 'toIntermediate'. std::vector accumulators(bool excludeToIntermediate); - std::vector keyChannels_; - // A subset of grouping keys on which the input is clustered. const std::vector preGroupedKeyChannels_; - // Provides the column projections for extracting the grouping keys from - // 'table_' for output. The vector index is the output channel and the value - // is the corresponding column index stored in 'table_'. - std::vector groupingKeyOutputProjections_; - - std::vector> hashers_; const bool isGlobal_; const bool isPartial_; const bool isRawInput_; + const bool ignoreNullKeys_; + const bool isAdaptive_; + // List of global grouping set numbers, if being used with a GROUPING SET. + const std::vector globalGroupingSets_; + // Indicates if this grouping set and the associated hash aggregation operator + // is under non-reclaimable execution section or not. + tsan_atomic* const nonReclaimableSection_; + + const common::SpillConfig* const spillConfig_; const core::QueryConfig* const queryConfig_; memory::MemoryPool* const pool_; + exec::SpillStats* const spillStats_; + + std::vector keyChannels_; + // Provides the column projections for extracting the grouping keys from + // 'table_' for output. The vector index is the output channel and the value + // is the corresponding column index stored in 'table_'. + std::vector groupingKeyOutputProjections_; + std::vector> hashers_; std::vector aggregates_; AggregationMasks masks_; std::unique_ptr sortedAggregations_; std::vector> distinctAggregations_; - const bool ignoreNullKeys_; + // Boolean indicating whether any aggregate supports compact(). + bool hasCompactableAggregates_{false}; uint64_t numInputRows_ = 0; - // List of global grouping set numbers, if being used with a GROUPING SET. - const std::vector globalGroupingSets_; // Column for groupId for a GROUPING SET. std::optional groupIdChannel_; - const common::SpillConfig* const spillConfig_; - - // Indicates if this grouping set and the associated hash aggregation operator - // is under non-reclaimable execution section or not. - tsan_atomic* const nonReclaimableSection_; - // Boolean indicating whether accumulators for a global aggregation (i.e. // aggregation with no grouping keys) have been initialized. bool globalAggregationInitialized_{false}; @@ -342,7 +378,6 @@ class GroupingSet { // aggregation HashStringAllocator stringAllocator_; memory::AllocationPool rows_; - const bool isAdaptive_; bool noMoreInput_{false}; @@ -354,12 +389,23 @@ class GroupingSet { // First row in remainingInput_ that needs to be processed. vector_size_t firstRemainingRow_; + // Populated by addRemainingInput(); reset at top of addInput()/noMoreInput(). + // Pointers reference table_->rows() and remain valid only while no spill + // path runs. Pre-grouped distinct aggregation does not currently spill, so + // the row container is not cleared between capture and consumption. + std::vector drainedNewGroups_; + // In case of distinct aggregation without aggregates and the grouping key // reordered, the spilled data is first loaded into // 'spillResultWithoutAggregates_' and then reordered back and load to // result. RowVectorPtr spillResultWithoutAggregates_{nullptr}; + // Records the source rows to copy to 'output_' in order. + std::vector spillSources_; + + std::vector spillSourceRows_; + // The value of mayPushdown flag specified in addInput() for the // 'remainingInput_'. bool remainingMayPushdown_; @@ -408,8 +454,6 @@ class GroupingSet { // Temporary for case where an aggregate in toIntermediate() outputs post-init // state of aggregate for all rows. std::vector firstGroup_; - - folly::Synchronized* const spillStats_; }; class AggregationInputSpiller : public SpillerBase { @@ -422,7 +466,7 @@ class AggregationInputSpiller : public SpillerBase { const HashBitRange& hashBitRange, const std::vector& sortingKeys, const common::SpillConfig* spillConfig, - folly::Synchronized* spillStats); + exec::SpillStats* spillStats); void spill(); @@ -444,7 +488,7 @@ class AggregationOutputSpiller : public SpillerBase { RowContainer* container, RowTypePtr rowType, const common::SpillConfig* spillConfig, - folly::Synchronized* spillStats); + exec::SpillStats* spillStats); void spill(const RowContainerIterator& startRowIter); diff --git a/velox/exec/HashAggregation.cpp b/velox/exec/HashAggregation.cpp index bcc3b4b97cb..ad81999bf5e 100644 --- a/velox/exec/HashAggregation.cpp +++ b/velox/exec/HashAggregation.cpp @@ -16,10 +16,14 @@ #include "velox/exec/HashAggregation.h" #include +#include "velox/common/testutil/TestValue.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/PrefixSort.h" #include "velox/exec/Task.h" #include "velox/expression/Expr.h" +using facebook::velox::common::testutil::TestValue; + namespace facebook::velox::exec { HashAggregation::HashAggregation( @@ -32,15 +36,22 @@ HashAggregation::HashAggregation( operatorId, aggregationNode->id(), aggregationNode->step() == core::AggregationNode::Step::kPartial - ? "PartialAggregation" - : "Aggregation", + ? OperatorType::kPartialAggregation + : OperatorType::kAggregation, aggregationNode->canSpill(driverCtx->queryConfig()) - ? driverCtx->makeSpillConfig(operatorId) + ? driverCtx->makeSpillConfig( + operatorId, + aggregationNode->step() == + core::AggregationNode::Step::kPartial + ? OperatorType::kPartialAggregation + : OperatorType::kAggregation) : std::nullopt), aggregationNode_(aggregationNode), isPartialOutput_(isPartialOutput(aggregationNode->step())), isGlobal_(aggregationNode->groupingKeys().empty()), isDistinct_(!isGlobal_ && aggregationNode->aggregates().empty()), + memoryCompactionEnabled_( + driverCtx->queryConfig().aggregationMemoryCompactionReclaimEnabled()), maxExtendedPartialAggregationMemoryUsage_( driverCtx->queryConfig().maxExtendedPartialAggregationMemoryUsage()), abandonPartialAggregationMinRows_( @@ -116,6 +127,8 @@ void HashAggregation::initialize() { operatorCtx_->pool(), spillStats_.get()); + hasCompactableAggregates_ = groupingSet_->hasCompactableAggregates(); + aggregationNode_.reset(); } @@ -220,8 +233,6 @@ void HashAggregation::updateRuntimeStats() { const auto& hashers = groupingSet_->hashLookup().hashers; uint64_t asRange{0}; uint64_t asDistinct{0}; - const auto hashTableStats = groupingSet_->hashTableStats(); - auto lockedStats = stats_.wlock(); auto& runtimeStats = lockedStats->runtimeStats; @@ -235,14 +246,9 @@ void HashAggregation::updateRuntimeStats() { } } - runtimeStats[BaseHashTable::kCapacity] = - RuntimeMetric(hashTableStats.capacity); - runtimeStats[BaseHashTable::kNumRehashes] = - RuntimeMetric(hashTableStats.numRehashes); - runtimeStats[BaseHashTable::kNumDistinct] = - RuntimeMetric(hashTableStats.numDistinct); - runtimeStats[BaseHashTable::kNumTombstones] = - RuntimeMetric(hashTableStats.numTombstones); + if (auto* table = groupingSet_->table()) { + table->addRuntimeStats(runtimeStats); + } } void HashAggregation::prepareOutput(vector_size_t size) { @@ -266,10 +272,13 @@ void HashAggregation::resetPartialOutputIfNeed() { { auto lockedStats = stats_.wlock(); lockedStats->addRuntimeStat( - "flushRowCount", RuntimeCounter(numOutputRows_)); - lockedStats->addRuntimeStat("flushTimes", RuntimeCounter(1)); + std::string(HashAggregation::kFlushRowCount), + RuntimeCounter(numOutputRows_)); + lockedStats->addRuntimeStat( + std::string(HashAggregation::kFlushTimes), RuntimeCounter(1)); lockedStats->addRuntimeStat( - "partialAggregationPct", RuntimeCounter(aggregationPct)); + std::string(HashAggregation::kPartialAggregationPct), + RuntimeCounter(saturateCast(aggregationPct))); } groupingSet_->resetTable(/*freeTable=*/false); partialFull_ = false; @@ -293,7 +302,6 @@ void HashAggregation::maybeIncreasePartialAggregationMemoryUsage( maxExtendedPartialAggregationMemoryUsage_)) { groupingSet_->abandonPartialAggregation(); pool()->release(); - addRuntimeStat("abandonedPartialAggregation", RuntimeCounter(1)); abandonedPartialAggregation_ = true; return; } @@ -332,6 +340,9 @@ RowVectorPtr HashAggregation::getOutput() { } prepareOutput(input_->size()); groupingSet_->toIntermediate(input_, output_); + addRuntimeStat( + std::string(HashAggregation::kAbandonedPartialAggregationRows), + RuntimeCounter(input_->size())); numOutputRows_ += input_->size(); input_ = nullptr; return output_; @@ -343,7 +354,7 @@ RowVectorPtr HashAggregation::getOutput() { // - distinct aggregation has new keys; // - running in partial streaming mode and have some output ready. if (!noMoreInput_ && !partialFull_ && !newDistincts_ && - !groupingSet_->hasOutput()) { + !groupingSet_->hasDrainedNewGroups() && !groupingSet_->hasOutput()) { input_ = nullptr; return nullptr; } @@ -379,6 +390,15 @@ RowVectorPtr HashAggregation::getDistinctOutput() { VELOX_CHECK(isDistinct_); VELOX_CHECK(!finished_); + if (groupingSet_->hasDrainedNewGroups()) { + auto size = groupingSet_->drainedNewGroupsCount(); + prepareOutput(size); + groupingSet_->extractDrainedNewGroups(output_); + numOutputRows_ += size; + resetPartialOutputIfNeed(); + return output_; + } + if (newDistincts_) { VELOX_CHECK_NOT_NULL(input_); @@ -458,13 +478,37 @@ void HashAggregation::reclaim( updateEstimatedOutputRowSize(); + // Try lightweight compaction first before spilling. + if (memoryCompactionEnabled_) { + uint64_t compactedBytes{0}; + if (hasCompactableAggregates_) { + compactedBytes = groupingSet_->compact(); + } + TestValue::adjust( + "facebook::velox::exec::HashAggregation::reclaim::compact", + &compactedBytes); + if (compactedBytes > 0) { + stats.reclaimedBytes += compactedBytes; + pool()->release(); + if (compactedBytes >= targetBytes) { + return; + } + } + } + + if (!canSpill()) { + return; + } + if (noMoreInput_) { if (groupingSet_->hasSpilled()) { LOG(WARNING) << "Can't reclaim from aggregation operator which has spilled and is under output processing, pool " - << pool()->name() - << ", memory usage: " << succinctBytes(pool()->usedBytes()) - << ", reservation: " << succinctBytes(pool()->reservedBytes()); + << pool()->name() << ", root pool: " << pool()->root()->name() + << ", used: " << succinctBytes(pool()->usedBytes()) + << ", reservation: " << succinctBytes(pool()->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool()->root()->reservedBytes()); return; } if (isDistinct_) { diff --git a/velox/exec/HashAggregation.h b/velox/exec/HashAggregation.h index 5cd44f77cb4..800c90d8539 100644 --- a/velox/exec/HashAggregation.h +++ b/velox/exec/HashAggregation.h @@ -15,6 +15,8 @@ */ #pragma once +#include + #include "velox/exec/GroupingSet.h" #include "velox/exec/Operator.h" @@ -22,6 +24,18 @@ namespace facebook::velox::exec { class HashAggregation : public Operator { public: + /// Runtime stat keys for hash aggregation. + /// Number of rows flushed in partial aggregation output. + static constexpr std::string_view kFlushRowCount = "flushRowCount"; + /// Number of partial aggregation flush operations. + static constexpr std::string_view kFlushTimes = "flushTimes"; + /// Ratio of output to input rows in partial aggregation as a percentage. + static constexpr std::string_view kPartialAggregationPct = + "partialAggregationPct"; + /// Number of rows emitted after partial aggregation was abandoned. + static constexpr std::string_view kAbandonedPartialAggregationRows = + "abandonedPartialAggregationRows"; + HashAggregation( int32_t operatorId, DriverCtx* driverCtx, @@ -45,6 +59,13 @@ class HashAggregation : public Operator { bool isFinished() override; + /// HashAggregation can reclaim memory via lightweight compaction even when + /// spilling is not enabled. + bool canReclaim() const override { + return (memoryCompactionEnabled_ && hasCompactableAggregates_) || + canSpill(); + } + void reclaim(uint64_t targetBytes, memory::MemoryReclaimer::Stats& stats) override; @@ -89,6 +110,7 @@ class HashAggregation : public Operator { const bool isPartialOutput_; const bool isGlobal_; const bool isDistinct_; + const bool memoryCompactionEnabled_; const int64_t maxExtendedPartialAggregationMemoryUsage_; // Minimum number of rows to see before deciding to give up on partial // aggregation. @@ -100,6 +122,11 @@ class HashAggregation : public Operator { int64_t maxPartialAggregationMemoryUsage_; std::unique_ptr groupingSet_; + // Cached from groupingSet_->hasCompactableAggregates() during initialize(). + // Stored separately to allow safe access from the arbitration thread without + // dereferencing groupingSet_. + bool hasCompactableAggregates_{false}; + // Size of a single output row estimated using // 'groupingSet_->estimateRowSize()'. If spilling, this value is set to max // 'groupingSet_->estimateRowSize()' across all accumulated data set. diff --git a/velox/exec/HashBuild.cpp b/velox/exec/HashBuild.cpp index 3dde65a5750..7760a1f5d83 100644 --- a/velox/exec/HashBuild.cpp +++ b/velox/exec/HashBuild.cpp @@ -18,6 +18,8 @@ #include "velox/common/base/Counters.h" #include "velox/common/base/StatsReporter.h" #include "velox/common/testutil/TestValue.h" +#include "velox/exec/HashTableCache.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/OperatorUtils.h" #include "velox/exec/Task.h" #include "velox/exec/VectorHasher.h" @@ -55,14 +57,22 @@ HashBuild::HashBuild( nullptr, operatorId, joinNode->id(), - "HashBuild", + OperatorType::kHashBuild, joinNode->canSpill(driverCtx->queryConfig()) - ? driverCtx->makeSpillConfig(operatorId) + ? driverCtx->makeSpillConfig(operatorId, OperatorType::kHashBuild) : std::nullopt), joinNode_(std::move(joinNode)), joinType_{joinNode_->joinType()}, nullAware_{joinNode_->isNullAware()}, + nullAsValue_{joinNode_->isNullAsValue()}, needProbedFlagSpill_{needRightSideJoin(joinType_)}, + dropDuplicates_(joinNode_->canDropDuplicates()), + vectorHasherMaxNumDistinct_( + driverCtx->queryConfig().joinBuildVectorHasherMaxNumDistinct()), + abandonHashBuildDedupMinRows_( + driverCtx->queryConfig().abandonHashBuildDedupMinRows()), + abandonHashBuildDedupMinPct_( + driverCtx->queryConfig().abandonHashBuildDedupMinPct()), joinBridge_(operatorCtx_->task()->getHashJoinBridgeLocked( operatorCtx_->driverCtx()->splitGroupId, planNodeId())), @@ -86,36 +96,136 @@ HashBuild::HashBuild( // Identify the non-key build side columns and make a decoder for each. const int32_t numDependents = inputType->size() - numKeys; - if (numDependents > 0) { - // Number of join keys (numKeys) may be less then number of input columns - // (inputType->size()). In this case numDependents is negative and cannot be - // used to call 'reserve'. This happens when we join different probe side - // keys with the same build side key: SELECT * FROM t LEFT JOIN u ON t.k1 = - // u.k AND t.k2 = u.k. - dependentChannels_.reserve(numDependents); - decoders_.reserve(numDependents); - } - for (auto i = 0; i < inputType->size(); ++i) { - if (keyChannelMap_.find(i) == keyChannelMap_.end()) { - dependentChannels_.emplace_back(i); - decoders_.emplace_back(std::make_unique()); + if (!dropDuplicates_) { + if (numDependents > 0) { + // Number of join keys (numKeys) may be less then number of input columns + // (inputType->size()). In this case numDependents is negative and cannot + // be used to call 'reserve'. This happens when we join different probe + // side keys with the same build side key: SELECT * FROM t LEFT JOIN u ON + // t.k1 = u.k AND t.k2 = u.k. + dependentChannels_.reserve(numDependents); + decoders_.reserve(numDependents); + } + + for (auto i = 0; i < inputType->size(); ++i) { + if (keyChannelMap_.find(i) == keyChannelMap_.end()) { + dependentChannels_.emplace_back(i); + decoders_.emplace_back(std::make_unique()); + } } } tableType_ = hashJoinTableType(joinNode_); - setupTable(); - setupSpiller(); + stateCleared_ = false; } void HashBuild::initialize() { Operator::initialize(); + if (setupCachedHashTable()) { + return; + } + + // Set up table and spiller now that cache state is initialized. + // This ensures tableMemoryPool() returns the cache's tablePool when enabled. + setupTable(); + setupSpiller(); + if (isAntiJoin(joinType_) && joinNode_->filter()) { setupFilterForAntiJoins(keyChannelMap_); } } +bool HashBuild::setupCachedHashTable() { + if (!joinNode_->useHashTableCache()) { + return false; + } + + const auto& queryId = operatorCtx_->task()->queryCtx()->queryId(); + cacheKey_ = fmt::format("{}:{}", queryId, planNodeId()); + + // Get or create the cache entry (which includes the pool). + // If another task is already building, future_ will be set. + auto* cache = HashTableCache::instance(); + auto* queryCtx = operatorCtx_->task()->queryCtx().get(); + cacheEntry_ = cache->get(cacheKey_, taskId(), queryCtx, &future_); + VELOX_CHECK_NOT_NULL(cacheEntry_); + VELOX_CHECK_NOT_NULL(cacheEntry_->tablePool); + + // Check if table is already built. + if (cacheEntry_->buildComplete) { + noMoreInput(); + return true; + } + + // Check if we're a waiter task (future was set by get). + if (future_.valid()) { + setState(State::kWaitForBuild); + return true; + } + + // This is the builder task - proceed with building. + return false; +} + +bool HashBuild::getHashTableFromCache() { + if (!useHashTableCache()) { + return false; + } + + if (!cacheEntry_->buildComplete) { + // Cache miss - we need to build the table. + stats_.wlock()->addRuntimeStat( + std::string(BaseHashTable::kHashTableCacheMiss), RuntimeCounter(1)); + return false; + } + + // Table already built by a previous task! Use it directly. + // Notify the bridge with the cached table. + // We pass a shared_ptr copy (not std::move) since the cache retains + // ownership. + joinBridge_->setHashTable( + cacheEntry_->table, {}, cacheEntry_->hasNullKeys, nullptr); + // Record cache hit metric. + stats_.wlock()->addRuntimeStat( + std::string(BaseHashTable::kHashTableCacheHit), RuntimeCounter(1)); + return true; +} + +void HashBuild::maybeSetHashTableInCache( + const std::shared_ptr& table) { + if (!useHashTableCache()) { + return; + } + auto* cache = HashTableCache::instance(); + cache->put(cacheKey(), table, joinHasNullKeys_); +} + +bool HashBuild::receivedCachedHashTable() { + if (!useHashTableCache() || future_.valid()) { + return false; + } + // Builder task drivers coordinate via allPeersFinished and should fall + // through to the kWaitForProbe path in isBlocked(). Only waiter task + // drivers (different taskId than the builder) should enter here. + VELOX_CHECK_NOT_NULL(cacheEntry_); + if (hashTableCacheBuilderTask()) { + return false; + } + // We were waiting on cached table from another task. + // Ensure that table is ready. + VELOX_CHECK( + cacheEntry_->buildComplete, + "Hash table cache build failed for key '{}'. " + "The builder task may have encountered an error (e.g., OOM).", + cacheKey_); + // Proceed through normal noMoreInput flow which will use the cache. + setRunning(); + noMoreInput(); + return true; +} + void HashBuild::setupTable() { VELOX_CHECK_NULL(table_); @@ -133,6 +243,7 @@ void HashBuild::setupTable() { for (int i = numKeys; i < tableType_->size(); ++i) { dependentTypes.emplace_back(tableType_->childAt(i)); } + auto& queryConfig = operatorCtx_->driverCtx()->queryConfig(); if (joinNode_->isRightJoin() || joinNode_->isFullJoin() || joinNode_->isRightSemiProjectJoin()) { // Do not ignore null keys. @@ -141,44 +252,47 @@ void HashBuild::setupTable() { dependentTypes, true, // allowDuplicates true, // hasProbedFlag - operatorCtx_->driverCtx() - ->queryConfig() - .minTableRowsForParallelJoinBuild(), - pool()); + false, // hasCountFlag + queryConfig.minTableRowsForParallelJoinBuild(), + tableMemoryPool()); } else { - // (Left) semi and anti join with no extra filter only needs to know whether - // there is a match. Hence, no need to store entries with duplicate keys. - const bool dropDuplicates = !joinNode_->filter() && - (joinNode_->isLeftSemiFilterJoin() || - joinNode_->isLeftSemiProjectJoin() || isAntiJoin(joinType_)); // Right semi join needs to tag build rows that were probed. const bool needProbedFlag = joinNode_->isRightSemiFilterJoin(); - if (isLeftNullAwareJoinWithFilter(joinNode_)) { + const bool hasCountFlag = joinNode_->isCountingJoin(); + if (nullAsValue_ || isLeftNullAwareJoinWithFilter(joinNode_)) { // We need to check null key rows in build side in case of null-aware anti // or left semi project join with filter set. table_ = HashTable::createForJoin( std::move(keyHashers), dependentTypes, - !dropDuplicates, // allowDuplicates + !dropDuplicates_, // allowDuplicates needProbedFlag, // hasProbedFlag - operatorCtx_->driverCtx() - ->queryConfig() - .minTableRowsForParallelJoinBuild(), - pool()); + hasCountFlag, + queryConfig.minTableRowsForParallelJoinBuild(), + tableMemoryPool()); } else { // Ignore null keys table_ = HashTable::createForJoin( std::move(keyHashers), dependentTypes, - !dropDuplicates, // allowDuplicates + !dropDuplicates_, // allowDuplicates needProbedFlag, // hasProbedFlag - operatorCtx_->driverCtx() - ->queryConfig() - .minTableRowsForParallelJoinBuild(), - pool()); + hasCountFlag, + queryConfig.minTableRowsForParallelJoinBuild(), + tableMemoryPool(), + queryConfig.hashProbeBloomFilterPushdownMaxSize()); } } analyzeKeys_ = table_->hashMode() != BaseHashTable::HashMode::kHash; + if (abandonHashBuildDedupMinPct_ == 0 && !joinNode_->isCountingJoin()) { + // Building a HashTable without duplicates is disabled if + // abandonBuildNoDupHashMinPct_ is 0. Counting joins always require dedup. + abandonHashBuildDedup_ = true; + table_->setAllowDuplicates(true); + return; + } + // Only create HashLookup when dedup is enabled. + lookup_ = std::make_unique(table_->hashers(), pool()); } void HashBuild::setupSpiller(SpillPartition* spillPartition) { @@ -214,12 +328,19 @@ void HashBuild::setupSpiller(SpillPartition* spillPartition) { numPartitionBits; // Disable spilling if exceeding the max spill level and the query might run // out of memory if the restored partition still can't fit in memory. - if (config->exceedSpillLevelLimit(startPartitionBit)) { + if (FOLLY_UNLIKELY(config->exceedSpillLevelLimit(startPartitionBit))) { RECORD_METRIC_VALUE(kMetricMaxSpillLevelExceededCount); LOG(WARNING) << "Exceeded spill level limit: " << config->maxSpillLevel << ", and disable spilling for memory pool: " - << pool()->name(); - ++spillStats_->wlock()->spillMaxLevelExceededCount; + << pool()->name() + << ", root pool: " << pool()->root()->name() + << ", used: " << succinctBytes(pool()->usedBytes()) + << ", reservation: " + << succinctBytes(pool()->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool()->root()->reservedBytes()); + spillStats_->spillMaxLevelExceededCount.fetch_add( + 1, std::memory_order_relaxed); exceededMaxSpillLevelLimit_ = true; return; } @@ -309,6 +430,11 @@ void HashBuild::removeInputRowsForAntiJoinFilter() { void HashBuild::addInput(RowVectorPtr input) { checkRunning(); + + VELOX_CHECK( + !useHashTableCache() || + (cacheEntry_->builderTaskId == taskId() && !cacheEntry_->buildComplete)); + ensureInputFits(input); TestValue::adjust("facebook::velox::exec::HashBuild::addInput", this); @@ -336,7 +462,7 @@ void HashBuild::addInput(RowVectorPtr input) { } if (!isRightJoin(joinType_) && !isFullJoin(joinType_) && - !isRightSemiProjectJoin(joinType_) && + !isRightSemiProjectJoin(joinType_) && !nullAsValue_ && !isLeftNullAwareJoinWithFilter(joinNode_)) { deselectRowsWithNulls(hashers, activeRows_); if (nullAware_ && !joinHasNullKeys_ && @@ -377,6 +503,41 @@ void HashBuild::addInput(RowVectorPtr input) { return; } + if (dropDuplicates_ && !abandonHashBuildDedup_) { + // Counting joins must not abandon dedup — accurate counts are required. + const bool abandonEarly = !joinNode_->isCountingJoin() && + abandonHashBuildDedupEarly(table_->numDistinct()); + if (!abandonEarly) { + numHashInputRows_ += activeRows_.countSelected(); + table_->prepareForGroupProbe( + *lookup_, + input, + activeRows_, + BaseHashTable::kNoSpillInputStartPartitionBit); + if (lookup_->rows.empty()) { + return; + } + table_->groupProbe( + *lookup_, BaseHashTable::kNoSpillInputStartPartitionBit); + + // For counting joins, increment the count for duplicate rows. + // New rows are initialized with count = 1 by initializeRow. + // Increment count for all rows, then decrement for new rows to + // correct the over-counting. + if (joinNode_->isCountingJoin()) { + auto* rows = table_->rows(); + for (auto row : lookup_->rows) { + rows->incrementCount(lookup_->hits[row]); + } + for (auto newRow : lookup_->newGroups) { + rows->decrementCount(lookup_->hits[newRow]); + } + } + return; + } + abandonHashBuildDedup(); + } + if (analyzeKeys_ && hashes_.size() < activeRows_.end()) { hashes_.resize(activeRows_.end()); } @@ -506,8 +667,11 @@ void HashBuild::ensureInputFits(RowVectorPtr& input) { } LOG(WARNING) << "Failed to reserve " << succinctBytes(targetIncrementBytes) << " for memory pool " << pool()->name() - << ", usage: " << succinctBytes(pool()->usedBytes()) - << ", reservation: " << succinctBytes(pool()->reservedBytes()); + << ", root pool: " << pool()->root()->name() + << ", used: " << succinctBytes(pool()->usedBytes()) + << ", reservation: " << succinctBytes(pool()->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool()->root()->reservedBytes()); } void HashBuild::spillInput(const RowVectorPtr& input) { @@ -627,6 +791,7 @@ void HashBuild::noMoreInput() { if (noMoreInput_) { return; } + Operator::noMoreInput(); noMoreInputInternal(); @@ -656,7 +821,19 @@ bool HashBuild::finishHashBuild() { // build pipeline. if (!operatorCtx_->task()->allPeersFinished( planNodeId(), operatorCtx_->driver(), &future_, promises, peers)) { - setState(State::kWaitForBuild); + if (useHashTableCache() && !hashTableCacheBuilderTask()) { + // Waiter task non-last driver: no partial table was built (we used the + // cached table). Nothing to contribute — finish immediately. Clear the + // future since allPeersFinished() set it but we don't need to wait. + VELOX_CHECK_NULL( + table_, "Waiter task should not have built a partial hash table"); + future_ = folly::SemiFuture::makeEmpty(); + setState(State::kFinish); + } else { + // Builder task non-last driver: the last driver needs our partial + // table. Wait in kWaitForBuild until it has moved our table out. + setState(State::kWaitForBuild); + } return false; } @@ -671,6 +848,10 @@ bool HashBuild::finishHashBuild() { } }; + if (getHashTableFromCache()) { + return true; + } + if (joinHasNullKeys_ && isAntiJoin(joinType_) && nullAware_ && !joinNode_->filter()) { joinBridge_->setAntiJoinHasNullKeys(); @@ -746,8 +927,6 @@ bool HashBuild::finishHashBuild() { pool()->release(); }; - // TODO: Re-enable parallel join build with spilling triggered after - // https://github.com/facebookincubator/velox/issues/3567 is fixed. CpuWallTiming timing; { CpuWallTimer cpuWallTimer{timing}; @@ -755,11 +934,13 @@ bool HashBuild::finishHashBuild() { std::move(otherTables), isInputFromSpill() ? spillConfig()->startPartitionBit : BaseHashTable::kNoSpillInputStartPartitionBit, + vectorHasherMaxNumDistinct_, + dropDuplicates_, allowParallelJoinBuild ? operatorCtx_->task()->queryCtx()->executor() : nullptr); } stats_.wlock()->addRuntimeStat( - BaseHashTable::kBuildWallNanos, + std::string(BaseHashTable::kBuildWallNanos), RuntimeCounter(timing.wallNanos, RuntimeCounter::Unit::kNanos)); addRuntimeStats(); @@ -784,11 +965,16 @@ bool HashBuild::finishHashBuild() { spillStats); }; } + + // For hash table caching: the last driver caches the merged table. + std::shared_ptr table = std::move(table_); + maybeSetHashTableInCache(table); joinBridge_->setHashTable( - std::move(table_), + table, std::move(spillPartitions), joinHasNullKeys_, std::move(tableSpillFunc)); + if (canSpill()) { stateCleared_ = true; } @@ -833,14 +1019,15 @@ void HashBuild::ensureTableFits(uint64_t numRows) { LOG(WARNING) << "Failed to reserve " << succinctBytes(memoryBytesToReserve) << " for join table build from last hash build operator " - << pool()->name() - << ", usage: " << succinctBytes(pool()->usedBytes()) - << ", reservation: " << succinctBytes(pool()->reservedBytes()); + << pool()->name() << ", root pool: " << pool()->root()->name() + << ", used: " << succinctBytes(pool()->usedBytes()) + << ", reservation: " << succinctBytes(pool()->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool()->root()->reservedBytes()); } void HashBuild::postHashBuildProcess() { checkRunning(); - if (!canSpill()) { setState(State::kFinish); return; @@ -879,6 +1066,7 @@ void HashBuild::setupSpillInput(HashJoinBridge::SpillInput spillInput) { setupTable(); setupSpiller(spillInput.spillPartition.get()); stateCleared_ = false; + numHashInputRows_ = 0; // Start to process spill input. processSpillInput(); @@ -904,7 +1092,6 @@ void HashBuild::processSpillInput() { void HashBuild::addRuntimeStats() { // Report range sizes and number of distinct values for the join keys. const auto& hashers = table_->hashers(); - const auto hashTableStats = table_->stats(); uint64_t asRange{0}; uint64_t asDistinct{0}; auto lockedStats = stats_.wlock(); @@ -912,23 +1099,54 @@ void HashBuild::addRuntimeStats() { for (const auto& timing : table_->parallelJoinBuildStats().partitionTimings) { lockedStats->getOutputTiming.add(timing); lockedStats->addRuntimeStat( - BaseHashTable::kParallelJoinPartitionWallNanos, + std::string(BaseHashTable::kParallelJoinPartitionWallNanos), RuntimeCounter(timing.wallNanos, RuntimeCounter::Unit::kNanos)); lockedStats->addRuntimeStat( - BaseHashTable::kParallelJoinPartitionCpuNanos, + std::string(BaseHashTable::kParallelJoinPartitionCpuNanos), RuntimeCounter(timing.cpuNanos, RuntimeCounter::Unit::kNanos)); } for (const auto& timing : table_->parallelJoinBuildStats().buildTimings) { lockedStats->getOutputTiming.add(timing); lockedStats->addRuntimeStat( - BaseHashTable::kParallelJoinBuildWallNanos, + std::string(BaseHashTable::kParallelJoinBuildWallNanos), RuntimeCounter(timing.wallNanos, RuntimeCounter::Unit::kNanos)); lockedStats->addRuntimeStat( - BaseHashTable::kParallelJoinBuildCpuNanos, + std::string(BaseHashTable::kParallelJoinBuildCpuNanos), RuntimeCounter(timing.cpuNanos, RuntimeCounter::Unit::kNanos)); } + for (const auto& timing : + table_->parallelJoinBuildStats().bloomFilterPartitionTimings) { + lockedStats->getOutputTiming.add(timing); + if (timing.wallNanos > 0) { + lockedStats->addRuntimeStat( + std::string( + BaseHashTable::kParallelJoinBloomFilterPartitionWallNanos), + RuntimeCounter(timing.wallNanos, RuntimeCounter::Unit::kNanos)); + } + if (timing.cpuNanos > 0) { + lockedStats->addRuntimeStat( + std::string(BaseHashTable::kParallelJoinBloomFilterPartitionCpuNanos), + RuntimeCounter(timing.cpuNanos, RuntimeCounter::Unit::kNanos)); + } + } + + for (const auto& timing : + table_->parallelJoinBuildStats().bloomFilterBuildTimings) { + lockedStats->getOutputTiming.add(timing); + if (timing.wallNanos > 0) { + lockedStats->addRuntimeStat( + std::string(BaseHashTable::kParallelJoinBloomFilterBuildWallNanos), + RuntimeCounter(timing.wallNanos, RuntimeCounter::Unit::kNanos)); + } + if (timing.cpuNanos > 0) { + lockedStats->addRuntimeStat( + std::string(BaseHashTable::kParallelJoinBloomFilterBuildCpuNanos), + RuntimeCounter(timing.cpuNanos, RuntimeCounter::Unit::kNanos)); + } + } + for (auto i = 0; i < hashers.size(); i++) { hashers[i]->cardinality(0, asRange, asDistinct); if (asRange != VectorHasher::kRangeTooLarge) { @@ -941,24 +1159,21 @@ void HashBuild::addRuntimeStats() { } } - lockedStats->runtimeStats[BaseHashTable::kCapacity] = - RuntimeMetric(hashTableStats.capacity); - lockedStats->runtimeStats[BaseHashTable::kNumRehashes] = - RuntimeMetric(hashTableStats.numRehashes); - lockedStats->runtimeStats[BaseHashTable::kNumDistinct] = - RuntimeMetric(hashTableStats.numDistinct); - if (hashTableStats.numTombstones != 0) { - lockedStats->runtimeStats[BaseHashTable::kNumTombstones] = - RuntimeMetric(hashTableStats.numTombstones); - } + table_->addRuntimeStats(lockedStats->runtimeStats); // Add max spilling level stats if spilling has been triggered. if (spiller_ != nullptr && spiller_->spillTriggered()) { lockedStats->addRuntimeStat( - "maxSpillLevel", + std::string(HashBuild::kMaxSpillLevel), RuntimeCounter( spillConfig()->spillLevel(spiller_->hashBits().begin()))); } + + lockedStats->addRuntimeStat( + std::string(BaseHashTable::kVectorHasherMergeCpuNanos), + RuntimeCounter( + table_->vectorHasherMergeTiming().cpuNanos, + RuntimeCounter::Unit::kNanos)); } BlockingReason HashBuild::isBlocked(ContinueFuture* future) { @@ -976,6 +1191,11 @@ BlockingReason HashBuild::isBlocked(ContinueFuture* future) { case State::kFinish: break; case State::kWaitForBuild: + if (receivedCachedHashTable()) { + break; + } + // We were waiting for peer drivers to finish - fall through to + // kWaitForProbe which has the same logic. [[fallthrough]]; case State::kWaitForProbe: if (!future_.valid()) { @@ -1059,6 +1279,14 @@ bool HashBuild::canSpill() const { if (!Operator::canSpill()) { return false; } + // For Cached hash table, we don't support spill either by the + // task thats building or by the task that is re-using it + if (useHashTableCache()) { + return false; + } + if (joinNode_->isCountingJoin()) { + return false; + } if (operatorCtx_->task()->hasMixedExecutionGroupJoin(joinNode_.get())) { return operatorCtx_->driverCtx() ->queryConfig() @@ -1091,8 +1319,12 @@ void HashBuild::reclaim( LOG(WARNING) << "Can't reclaim from hash build operator, exceeded maximum spill " "level of " - << config->maxSpillLevel << ", " << pool()->name() << ", usage " - << succinctBytes(pool()->usedBytes()); + << config->maxSpillLevel << ", " << pool()->name() + << ", root pool: " << pool()->root()->name() + << ", used: " << succinctBytes(pool()->usedBytes()) + << ", reservation: " << succinctBytes(pool()->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool()->root()->reservedBytes()); return; } @@ -1105,11 +1337,16 @@ void HashBuild::reclaim( LOG(WARNING) << "Can't reclaim from hash build operator, state_[" << stateName(state_) << "], nonReclaimableSection_[" << nonReclaimableSection_ << "], spiller_[" - << (stateCleared_ ? "cleared" - : (spiller_->finalized() ? "finalized" - : "non-finalized")) + << (stateCleared_ ? "cleared" + : spiller_ == nullptr ? "null" + : spiller_->finalized() ? "finalized" + : "non-finalized") << "] " << pool()->name() - << ", usage: " << succinctBytes(pool()->usedBytes()); + << ", root pool: " << pool()->root()->name() + << ", used: " << succinctBytes(pool()->usedBytes()) + << ", reservation: " << succinctBytes(pool()->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool()->root()->reservedBytes()); return; } @@ -1128,9 +1365,18 @@ void HashBuild::reclaim( ++stats.numNonReclaimableAttempts; LOG(WARNING) << "Can't reclaim from hash build operator, state_[" << stateName(buildOp->state_) << "], nonReclaimableSection_[" - << buildOp->nonReclaimableSection_ << "], " - << buildOp->pool()->name() << ", usage: " - << succinctBytes(buildOp->pool()->usedBytes()); + << buildOp->nonReclaimableSection_ << "], spiller_[" + << (buildOp->stateCleared_ ? "cleared" + : buildOp->spiller_ == nullptr ? "null" + : buildOp->spiller_->finalized() ? "finalized" + : "non-finalized") + << "], " << buildOp->pool()->name() + << ", root pool: " << buildOp->pool()->root()->name() + << ", used: " << succinctBytes(buildOp->pool()->usedBytes()) + << ", reservation: " + << succinctBytes(buildOp->pool()->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(buildOp->pool()->root()->reservedBytes()); return; } } @@ -1150,6 +1396,19 @@ void HashBuild::reclaim( } } +memory::MemoryPool* HashBuild::tableMemoryPool() const { + if (useHashTableCache()) { + // Cached hash tables use a leaf pool under the query pool (from cache + // entry). This allows the table to outlive the task while still supporting + // allocations. + VELOX_CHECK_NOT_NULL(cacheEntry_); + VELOX_CHECK_NOT_NULL(cacheEntry_->tablePool); + return cacheEntry_->tablePool.get(); + } + // Regular joins use operator pool + return pool(); +} + bool HashBuild::nonReclaimableState() const { // Apart from being in the nonReclaimable section, it's also not reclaimable // if: @@ -1167,6 +1426,11 @@ bool HashBuild::nonReclaimableState() const { void HashBuild::close() { Operator::close(); + if (useHashTableCache() && cacheEntry_ != nullptr && + !cacheEntry_->buildComplete && hashTableCacheBuilderTask()) { + HashTableCache::instance()->drop(cacheKey_); + } + { // Free up major memory usage. Gate access to them as they can be accessed // by the last build thread that finishes building the hash table. @@ -1185,7 +1449,7 @@ HashBuildSpiller::HashBuildSpiller( RowTypePtr rowType, HashBitRange bits, const common::SpillConfig* spillConfig, - folly::Synchronized* spillStats) + exec::SpillStats* spillStats) : SpillerBase( container, std::move(rowType), @@ -1240,4 +1504,22 @@ void HashBuildSpiller::extractSpill( rows.data(), rows.size(), false, false, result->childAt(types.size())); } } + +bool HashBuild::abandonHashBuildDedupEarly(int64_t numDistinct) const { + VELOX_CHECK(dropDuplicates_); + return numHashInputRows_ > abandonHashBuildDedupMinRows_ && + 100 * numDistinct / numHashInputRows_ >= abandonHashBuildDedupMinPct_; +} + +void HashBuild::abandonHashBuildDedup() { + // The hash table is no longer directly constructed in addInput. The data + // that was previously inserted into the hash table is already in the + // RowContainer. + addRuntimeStat( + std::string(HashBuild::kAbandonBuildNoDupHash), RuntimeCounter(1)); + abandonHashBuildDedup_ = true; + table_->setAllowDuplicates(true); + lookup_.reset(); +} + } // namespace facebook::velox::exec diff --git a/velox/exec/HashBuild.h b/velox/exec/HashBuild.h index 9485662bd0a..96beccf9850 100644 --- a/velox/exec/HashBuild.h +++ b/velox/exec/HashBuild.h @@ -15,8 +15,11 @@ */ #pragma once +#include + #include "velox/exec/HashJoinBridge.h" #include "velox/exec/HashTable.h" +#include "velox/exec/HashTableCache.h" #include "velox/exec/Operator.h" #include "velox/exec/Spill.h" #include "velox/exec/Spiller.h" @@ -54,6 +57,13 @@ class HashBuild final : public Operator { }; static std::string stateName(State state); + /// Runtime stat keys for hash build. + /// Maximum spill level reached during hash join build. + static constexpr std::string_view kMaxSpillLevel = "maxSpillLevel"; + /// Whether dedup hash build was abandoned. + static constexpr std::string_view kAbandonBuildNoDupHash = + "abandonBuildNoDupHash"; + HashBuild( int32_t operatorId, DriverCtx* driverCtx, @@ -104,9 +114,35 @@ class HashBuild final : public Operator { bool isRunning() const; void checkRunning() const; + // Returns true if this task is the builder task for the hash table cache + // entry (i.e. the task that builds the table, as opposed to a waiter task + // that reuses a cached table built by another task). + bool hashTableCacheBuilderTask() const { + return cacheEntry_->builderTaskId == taskId(); + } + // Invoked to set up hash table to build. void setupTable(); + // Sets up hash table caching if enabled. Returns true if the cached table + // is already available or if this operator should wait for another task + // to build it, in which case further initialization should be skipped. + // Returns false if this operator should proceed with building the table. + bool setupCachedHashTable(); + + // Checks if a cached hash table is available and uses it if so. + // Returns true if the cached table was used (build can be skipped). + // Returns false if we need to build the table (cache miss). + bool getHashTableFromCache(); + + // Called when waiting for a cached hash table from another task. + // Returns true if the cached table was received and noMoreInput was called. + bool receivedCachedHashTable(); + + // Stores the built hash table in the cache for reuse by other tasks. + // No-op if hash table caching is not enabled. + void maybeSetHashTableInCache(const std::shared_ptr& table); + // Invoked when operator has finished processing the build input and wait for // all the other drivers to finish the processing. The last driver that // reaches to the hash build barrier, is responsible to build the hash table @@ -204,12 +240,43 @@ class HashBuild final : public Operator { // not. bool nonReclaimableState() const; + // True if we have enough rows and not enough duplicate join keys, i.e. more + // than 'abandonHashBuildDedupMinRows_' rows and more than + // 'abandonHashBuildDedupMinPct_' % of rows are unique. + bool abandonHashBuildDedupEarly(int64_t numDistinct) const; + + // Invoked to abandon build deduped hash table. + void abandonHashBuildDedup(); + + // Returns true if this operator is using a cached hash table. + // When enabled, the hash table is built once and cached for reuse + // by other tasks within the same query and stage. + bool useHashTableCache() const { + return !cacheKey_.empty(); + } + + // Returns the hash table cache key for this operator. + // Only valid if useHashTableCache() returns true. + const std::string& cacheKey() const { + VELOX_CHECK( + useHashTableCache(), + "cacheKey() called when table caching is not enabled"); + return cacheKey_; + } + + // Determines the memory pool to use for the hash table. + // For cached hash tables, uses query-level pool so the table can + // outlive the task. For regular joins, uses operator pool. + memory::MemoryPool* tableMemoryPool() const; + const std::shared_ptr joinNode_; const core::JoinType joinType_; const bool nullAware_; + const bool nullAsValue_; + // Sets to true for join type which needs right side join processing. The hash // table spiller then needs to record the probed flag, and the spilled input // reader also needs to restore the recorded probed flag. This is used to @@ -217,12 +284,37 @@ class HashBuild final : public Operator { // not. const bool needProbedFlagSpill_; + // Indicates whether drop duplicate rows. Rows containing duplicate keys + // can be removed for left semi and anti join. + const bool dropDuplicates_; + + // Maximum number of distinct values to keep when merging vector hashers + const size_t vectorHasherMaxNumDistinct_; + + // Minimum number of rows to see before deciding to give up build no + // duplicates hash table. + const int32_t abandonHashBuildDedupMinRows_; + + // Min unique rows pct for give up build deduped hash table. If more + // than this many rows are unique, build hash table in addInput phase is not + // worthwhile. + const int32_t abandonHashBuildDedupMinPct_; + std::shared_ptr joinBridge_; tsan_atomic exceededMaxSpillLevelLimit_{false}; State state_{State::kRunning}; + // For hash table caching: the cache key passed in at construction. + // If set, this operator coordinates via HashTableCache. + // Key format: "queryId:planNodeId" + std::string cacheKey_; + + // For hash table caching: cached entry containing the shared table and pool. + // Retrieved from HashTableCache. + std::shared_ptr cacheEntry_; + // The row type used for hash table build and disk spilling. RowTypePtr tableType_; @@ -242,6 +334,9 @@ class HashBuild final : public Operator { // Container for the rows being accumulated. std::unique_ptr table_; + // Used for building hash table while adding input rows. + std::unique_ptr lookup_; + // Key channels in 'input_' std::vector keyChannels_; @@ -269,6 +364,10 @@ class HashBuild final : public Operator { // at least one entry with null join keys. bool joinHasNullKeys_{false}; + // Whether to abandon building a HashTable without duplicates in HashBuild + // addInput phase for left semi/anti join. + bool abandonHashBuildDedup_{false}; + // The type used to spill hash table which might attach a boolean column to // record the probed flag if 'needProbedFlagSpill_' is true. RowTypePtr spillType_; @@ -310,6 +409,10 @@ class HashBuild final : public Operator { // Maps key channel in 'input_' to channel in key. folly::F14FastMap keyChannelMap_; + + // Count the number of hash table input rows for building deduped + // hash table. It will not be updated after abandonBuildNoDupHash_ is true. + int64_t numHashInputRows_ = 0; }; inline std::ostream& operator<<(std::ostream& os, HashBuild::State state) { @@ -328,7 +431,7 @@ class HashBuildSpiller : public SpillerBase { RowTypePtr rowType, HashBitRange bits, const common::SpillConfig* spillConfig, - folly::Synchronized* spillStats); + exec::SpillStats* spillStats); /// Invoked to spill all the rows stored in the row container of the hash /// build. diff --git a/velox/exec/HashJoinBridge.cpp b/velox/exec/HashJoinBridge.cpp index c9dc19948a0..4e9e25eab26 100644 --- a/velox/exec/HashJoinBridge.cpp +++ b/velox/exec/HashJoinBridge.cpp @@ -42,13 +42,18 @@ RowTypePtr hashJoinTableType( types.emplace_back(inputType->childAt(channel)); } + if (joinNode->canDropDuplicates()) { + // For left semi and anti join with no extra filter, hash table does not + // store dependent columns. + return ROW(std::move(names), std::move(types)); + } + for (auto i = 0; i < inputType->size(); ++i) { if (keyChannelSet.find(i) == keyChannelSet.end()) { names.emplace_back(inputType->nameOf(i)); types.emplace_back(inputType->childAt(i)); } } - return ROW(std::move(names), std::move(types)); } @@ -98,7 +103,7 @@ std::unique_ptr createSpiller( const RowTypePtr& tableType, const HashBitRange& hashBitRange, const common::SpillConfig* spillConfig, - folly::Synchronized* stats) { + exec::SpillStats* stats) { return std::make_unique( joinType, parentId, @@ -166,7 +171,7 @@ SpillPartitionSet spillHashJoinTable( const HashBitRange& hashBitRange, const std::shared_ptr& joinNode, const common::SpillConfig* spillConfig, - folly::Synchronized* stats) { + exec::SpillStats* spillStats) { VELOX_CHECK_NOT_NULL(table); VELOX_CHECK_NOT_NULL(spillConfig); if (table->numDistinct() == 0) { @@ -189,7 +194,7 @@ SpillPartitionSet spillHashJoinTable( tableType, hashBitRange, spillConfig, - stats)); + spillStats)); spillers.push_back(spillersHolder.back().get()); } if (spillersHolder.empty()) { @@ -211,7 +216,7 @@ SpillPartitionSet spillHashJoinTable( } void HashJoinBridge::setHashTable( - std::unique_ptr table, + std::shared_ptr table, SpillPartitionSet spillPartitionSet, bool hasNullKeys, HashJoinTableSpillFunc&& tableSpillFunc) { @@ -452,11 +457,11 @@ uint64_t HashJoinMemoryReclaimer::reclaim( } bool isHashBuildMemoryPool(const memory::MemoryPool& pool) { - return folly::StringPiece(pool.name()).endsWith("HashBuild"); + return pool.name().ends_with("HashBuild"); } bool isHashProbeMemoryPool(const memory::MemoryPool& pool) { - return folly::StringPiece(pool.name()).endsWith("HashProbe"); + return pool.name().ends_with("HashProbe"); } bool needRightSideJoin(core::JoinType joinType) { diff --git a/velox/exec/HashJoinBridge.h b/velox/exec/HashJoinBridge.h index 879eab6801f..bd175f99c22 100644 --- a/velox/exec/HashJoinBridge.h +++ b/velox/exec/HashJoinBridge.h @@ -53,8 +53,9 @@ class HashJoinBridge : public JoinBridge { /// Invoked by the build operator to set the built hash table. /// 'spillPartitionSet' contains the spilled partitions while building /// 'table' which only applies if the disk spilling is enabled. + /// Accepts both unique_ptr (regular joins) and shared_ptr (broadcast joins). void setHashTable( - std::unique_ptr table, + std::shared_ptr table, SpillPartitionSet spillPartitionSet, bool hasNullKeys, HashJoinTableSpillFunc&& tableSpillFunc); @@ -143,6 +144,17 @@ class HashJoinBridge : public JoinBridge { bool testingHasMoreSpilledPartitions(); + /// Return the next unclaimed row container id in the current hash table for + /// HashProbe drivers to output build-side rows. + int getAndIncrementUnclaimedRowContainerId() { + return unclaimedRowContainerId_.fetch_add(1); + } + + /// Reset the next unclaimed row container id to 0. + void resetUnclaimedRowContainerId() { + unclaimedRowContainerId_.store(0); + } + private: void appendSpilledHashTablePartitionsLocked( SpillPartitionSet&& spillPartitionSet); @@ -183,6 +195,13 @@ class HashJoinBridge : public JoinBridge { // processing. bool probeStarted_; + // Keep track of the next row container id in a hash table that has not been + // processed by any hash probe driver. This is used when hash probe drivers + // output build-side rows. When drivers are allowed to output build-side rows + // in parallel, drivers call getAndIncrementClaimedRowContainerId() to ensure + // the row containers they process do not overlap with each other. + std::atomic_int unclaimedRowContainerId_{0}; + friend test::HashJoinBridgeTestHelper; }; @@ -253,7 +272,7 @@ SpillPartitionSet spillHashJoinTable( const HashBitRange& hashBitRange, const std::shared_ptr& joinNode, const common::SpillConfig* spillConfig, - folly::Synchronized* stats); + exec::SpillStats* spillStats); /// Returns the type used to spill a given hash table type. The function /// might attach a boolean column at the end of 'tableType' if 'joinType' needs diff --git a/velox/exec/HashPartitionFunction.cpp b/velox/exec/HashPartitionFunction.cpp index 896facc4efa..60323b9703b 100644 --- a/velox/exec/HashPartitionFunction.cpp +++ b/velox/exec/HashPartitionFunction.cpp @@ -16,8 +16,7 @@ #include #include -#define XXH_INLINE_ALL -#include // @manual=third-party//xxHash:xxhash +#include "velox/common/base/XxHashInline.h" namespace facebook::velox::exec { namespace { diff --git a/velox/exec/HashProbe.cpp b/velox/exec/HashProbe.cpp index 7bbfe47656a..3ce4f60018d 100644 --- a/velox/exec/HashProbe.cpp +++ b/velox/exec/HashProbe.cpp @@ -15,9 +15,11 @@ */ #include "velox/exec/HashProbe.h" +#include #include "velox/common/base/Counters.h" #include "velox/common/base/StatsReporter.h" #include "velox/common/testutil/TestValue.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/OperatorUtils.h" #include "velox/exec/Task.h" #include "velox/expression/FieldReference.h" @@ -118,15 +120,19 @@ HashProbe::HashProbe( joinNode->outputType(), operatorId, joinNode->id(), - "HashProbe", + OperatorType::kHashProbe, joinNode->canSpill(driverCtx->queryConfig()) - ? driverCtx->makeSpillConfig(operatorId) + ? driverCtx->makeSpillConfig(operatorId, OperatorType::kHashProbe) : std::nullopt), outputBatchSize_{outputBatchRows()}, joinNode_(std::move(joinNode)), joinType_{joinNode_->joinType()}, nullAware_{joinNode_->isNullAware()}, + nullAsValue_{joinNode_->isNullAsValue()}, probeType_(joinNode_->sources()[0]->outputType()), + canOutputBuildRowsInParallel_( + driverCtx->queryConfig().parallelOutputJoinBuildRowsEnabled() && + !canSpill()), joinBridge_(operatorCtx_->task()->getHashJoinBridgeLocked( operatorCtx_->driverCtx()->splitGroupId, planNodeId())), @@ -376,6 +382,10 @@ void HashProbe::initializeResultIter() { void HashProbe::pushdownDynamicFilters() { auto* driver = operatorCtx_->driverCtx()->driver; + const bool hashProbeStringDynamicFilterPushdownEnabled = + operatorCtx_->driverCtx() + ->queryConfig() + .hashProbeStringDynamicFilterPushdownEnabled(); auto numFilters = driver->pushdownFilters( this, keyChannels_, @@ -384,10 +394,27 @@ void HashProbe::pushdownDynamicFilters() { if (dynamicFiltersProducedOnChannels_.contains(sourceChannel)) { return true; } - filter = table_->hashers()[sourceChannel]->getFilter(false); + auto& hasher = *table_->hashers()[sourceChannel]; + if (hasher.typeKind() == TypeKind::VARCHAR || + hasher.typeKind() == TypeKind::VARBINARY) { + if (!hashProbeStringDynamicFilterPushdownEnabled) { + return false; + } + } + filter = hasher.getFilter(false); if (!filter) { - return false; + filter = hasher.getBloomFilter(); + if (!filter) { + return false; + } + auto* bloomFilter = + checkedPointerCast( + filter.get()); + addRuntimeStat( + std::string(HashProbe::kBloomFilterSize), + RuntimeCounter(bloomFilter->blocksByteSize())); } + dynamicFiltersProducedOnChannels_.insert(sourceChannel); for (auto* peer : findPeerOperators()) { peer->dynamicFiltersProducedOnChannels_.insert(sourceChannel); } @@ -399,7 +426,7 @@ void HashProbe::pushdownDynamicFilters() { // * build side has no dependent columns. if (keyChannels_.size() == 1 && !table_->hasDuplicateKeys() && tableOutputProjections_.empty() && !filter_ && numFilters > 0 && - !isRightJoin(joinType_)) { + !table_->hashers()[0]->getBloomFilter() && !isRightJoin(joinType_)) { canReplaceWithDynamicFilter_ = true; } } @@ -456,10 +483,14 @@ void HashProbe::asyncWaitForHashTable() { } } else if ( (isInnerJoin(joinType_) || isLeftSemiFilterJoin(joinType_) || + isCountingLeftSemiFilterJoin(joinType_) || isRightSemiFilterJoin(joinType_) || (isRightSemiProjectJoin(joinType_) && !nullAware_) || isRightJoin(joinType_)) && table_->hashMode() != BaseHashTable::HashMode::kHash && !isSpillInput() && + operatorCtx_->driverCtx() + ->queryConfig() + .hashProbeDynamicFilterPushdownEnabled() && !hasMoreSpillData()) { // Find out whether there are any upstream operators that can accept dynamic // filters on all or a subset of the join keys. Create dynamic filters to @@ -493,7 +524,6 @@ void HashProbe::prepareForSpillRestore() { restoringPartitionId_.reset(); spillInputPartitionIds_.clear(); spillOutputReader_.reset(); - lastProbeIterator_.reset(); VELOX_CHECK(promises_.empty() || lastProber_); if (!lastProber_) { @@ -634,7 +664,6 @@ BlockingReason HashProbe::isBlocked(ContinueFuture* future) { } break; case ProbeOperatorState::kWaitForPeers: - VELOX_CHECK(canSpill()); if (!future_.valid()) { setRunning(); } @@ -662,7 +691,9 @@ void HashProbe::decodeAndDetectNonNullKeys() { hashers_[i]->decode(*key, nonNullInputRows_); } - deselectRowsWithNulls(hashers_, nonNullInputRows_); + if (!nullAsValue_) { + deselectRowsWithNulls(hashers_, nonNullInputRows_); + } if (isRightSemiProjectJoin(joinType_) && nonNullInputRows_.countSelected() < input_->size()) { probeSideHasNullKeys_ = true; @@ -839,6 +870,7 @@ void HashProbe::fillLeftSemiProjectMatchColumn(vector_size_t size) { } void HashProbe::fillOutput(vector_size_t size) { + TestValue::adjust("facebook::velox::exec::HashProbe::fillOutput", this); prepareOutput(size); for (auto [in, out] : projectedInputColumns_) { @@ -863,28 +895,49 @@ void HashProbe::fillOutput(vector_size_t size) { } RowVectorPtr HashProbe::getBuildSideOutput() { - auto* outputTableRows = + if (buildSideOutputRowContainerId_ == -1) { + buildSideOutputRowContainerId_ = + joinBridge_->getAndIncrementUnclaimedRowContainerId(); + lastProbeIterator_.reset(); + } + if (buildSideOutputRowContainerId_ >= table_->numRowContainers()) { + return nullptr; + } + + char** outputTableRows = initBuffer(outputTableRows_, outputTableRowsCapacity_, pool()); - int32_t numOut; - if (isRightSemiFilterJoin(joinType_)) { - numOut = table_->listProbedRows( - &lastProbeIterator_, - outputTableRowsCapacity_, - RowContainer::kUnlimited, - outputTableRows); - } else if (isRightSemiProjectJoin(joinType_)) { - numOut = table_->listAllRows( - &lastProbeIterator_, - outputTableRowsCapacity_, - RowContainer::kUnlimited, - outputTableRows); - } else { - // Must be a right join or full join. - numOut = table_->listNotProbedRows( - &lastProbeIterator_, - outputTableRowsCapacity_, - RowContainer::kUnlimited, - outputTableRows); + int32_t numOut{0}; + while (numOut == 0 && + buildSideOutputRowContainerId_ < table_->numRowContainers()) { + if (isRightSemiFilterJoin(joinType_)) { + numOut = table_->listProbedRows( + lastProbeIterator_, + buildSideOutputRowContainerId_, + outputTableRowsCapacity_, + RowContainer::kUnlimited, + outputTableRows); + } else if (isRightSemiProjectJoin(joinType_)) { + numOut = table_->listAllRows( + lastProbeIterator_, + buildSideOutputRowContainerId_, + outputTableRowsCapacity_, + RowContainer::kUnlimited, + outputTableRows); + + } else { + // Must be a right join or full join. + numOut = table_->listNotProbedRows( + lastProbeIterator_, + buildSideOutputRowContainerId_, + outputTableRowsCapacity_, + RowContainer::kUnlimited, + outputTableRows); + } + if (numOut == 0) { + buildSideOutputRowContainerId_ = + joinBridge_->getAndIncrementUnclaimedRowContainerId(); + lastProbeIterator_.reset(); + } } if (numOut == 0) { return nullptr; @@ -908,9 +961,9 @@ RowVectorPtr HashProbe::getBuildSideOutput() { if (isRightSemiProjectJoin(joinType_)) { // Populate 'match' column. - if (noInput_) { + if (noInput_ && nullAware_) { // Probe side is empty. All rows should return 'match = false', even ones - // with a null join key. + // with a null join key. (This applies to null-aware joins only.) matchColumn() = createConstantFalse(numOut, pool()); } else { table_->rows()->extractProbedFlags( @@ -940,14 +993,20 @@ bool HashProbe::needLastProbe() const { bool HashProbe::skipProbeOnEmptyBuild() const { return isInnerJoin(joinType_) || isLeftSemiFilterJoin(joinType_) || - isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_) || - isRightSemiProjectJoin(joinType_); + isCountingLeftSemiFilterJoin(joinType_) || isRightJoin(joinType_) || + isRightSemiFilterJoin(joinType_) || isRightSemiProjectJoin(joinType_); } bool HashProbe::canSpill() const { if (!Operator::canSpill()) { return false; } + // Hash table caching is incompatible with spilling. When the table is + // cached and shared across tasks, clearing it after probe would corrupt + // the cache for subsequent tasks. + if (joinNode_->useHashTableCache()) { + return false; + } if (operatorCtx_->task()->hasMixedExecutionGroupJoin(joinNode_.get())) { return operatorCtx_->driverCtx() ->queryConfig() @@ -978,16 +1037,11 @@ void HashProbe::checkStateTransition(ProbeOperatorState state) { VELOX_CHECK_NE(state_, state); switch (state) { case ProbeOperatorState::kRunning: - if (!canSpill()) { - VELOX_CHECK_EQ(state_, ProbeOperatorState::kWaitForBuild); - } else { - VELOX_CHECK( - state_ == ProbeOperatorState::kWaitForBuild || - state_ == ProbeOperatorState::kWaitForPeers); - } + VELOX_CHECK( + state_ == ProbeOperatorState::kWaitForBuild || + state_ == ProbeOperatorState::kWaitForPeers); break; case ProbeOperatorState::kWaitForPeers: - VELOX_CHECK(canSpill()); [[fallthrough]]; case ProbeOperatorState::kWaitForBuild: [[fallthrough]]; @@ -1028,12 +1082,12 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) { clearProjectedOutput(); - if (!input_) { + if (input_ == nullptr) { if (hasMoreInput()) { return nullptr; } - if (needLastProbe() && lastProber_) { + if (needLastProbe() && (canOutputBuildRowsInParallel_ || lastProber_)) { auto output = getBuildSideOutput(); if (output != nullptr) { return output; @@ -1085,7 +1139,9 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) { const auto inputSize = input_->size(); if (replacedWithDynamicFilter_) { - addRuntimeStat("replacedWithDynamicFilterRows", RuntimeCounter(inputSize)); + addRuntimeStat( + std::string(HashProbe::kReplacedWithDynamicFilterRows), + RuntimeCounter(inputSize)); auto output = Operator::fillOutput(inputSize, nullptr); input_ = nullptr; return output; @@ -1093,7 +1149,7 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) { const bool isLeftSemiOrAntiJoinNoFilter = !filter_ && (isLeftSemiFilterJoin(joinType_) || isLeftSemiProjectJoin(joinType_) || - isAntiJoin(joinType_)); + isAntiJoin(joinType_) || isCountingJoin(joinType_)); const bool emptyBuildSide = (table_->numDistinct() == 0); @@ -1155,6 +1211,31 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) { } } } + } else if (isCountingAntiJoin(joinType_)) { + // Counting anti-join: emit row if no match or count has reached zero. + // Decrement count on match. + auto* rows = table_->rows(); + for (auto i = 0; i < inputSize; ++i) { + auto* hit = lookup_->hits[i]; + if (!activeRows_.isValid(i) || !hit || rows->count(hit) == 0) { + mapping[numOut] = i; + ++numOut; + } else { + rows->decrementCount(hit); + } + } + } else if (isCountingLeftSemiFilterJoin(joinType_)) { + // Counting semi-join: emit row if match and count > 0. + // Decrement count on match. + auto* rows = table_->rows(); + for (auto i = 0; i < inputSize; ++i) { + auto* hit = lookup_->hits[i]; + if (activeRows_.isValid(i) && hit && rows->count(hit) > 0) { + rows->decrementCount(hit); + mapping[numOut] = i; + ++numOut; + } + } } else { numOut = table_->listJoinResults( *resultIter_, @@ -1173,9 +1254,30 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) { } VELOX_CHECK_LE(numOut, outputBatchSize); + // Pre-load lazy input vectors within a reclaimable section so that the + // subsequent non-reclaimable evalFilter/fillOutput does not trigger OOM. + // If memory reclaim spills the hash table during lazy loading, + // ensureLazyInputLoaded() spills the current input and sets input_ to + // nullptr. We must bail out because resultIter_ now references freed + // hash table rows. + if (!toSpillOutput) { + ensureLazyInputLoaded(); + if (input_ == nullptr) { + return nullptr; + } + VELOX_CHECK_NOT_NULL(table_); + } + numOut = evalFilter(numOut); if (numOut == 0) { + // The hash probe might get stuck in the output loop if the filter is + // highly selective. This does not apply if the call is made during + // spilling, because we cannot break out and resume when the operator is + // undergoing spilling. + if (!toSpillOutput && shouldYield()) { + return nullptr; + } continue; } @@ -1251,7 +1353,7 @@ RowVectorPtr HashProbe::createFilterInput(vector_size_t size) { } void HashProbe::prepareFilterRowsForNullAwareJoin( - RowVectorPtr& filterInput, + const RowVector* filterInput, vector_size_t numRows, bool filterPropagateNulls) { VELOX_CHECK_LE(numRows, kBatchSize); @@ -1260,6 +1362,18 @@ void HashProbe::prepareFilterRowsForNullAwareJoin( BaseVector::create(filterInputType_, kBatchSize, pool()); } + // When no rows are selected, filterInput must be nullptr since + // createFilterInput() cannot be called when outputRowMapping_ has been + // overwritten. When rows are selected, filterInput must be valid. + if (FOLLY_UNLIKELY(!filterInputRows_.hasSelections())) { + VELOX_CHECK_NULL(filterInput); + if (filterPropagateNulls) { + nullFilterInputRows_.resizeFill(numRows, false); + } + return; + } + + VELOX_CHECK_NOT_NULL(filterInput); if (filterPropagateNulls) { nullFilterInputRows_.resizeFill(numRows, false); auto* rawNullRows = nullFilterInputRows_.asMutableRange().bits(); @@ -1331,7 +1445,7 @@ const uint64_t* getFlatFilterResult(VectorPtr& result) { } // namespace void HashProbe::applyFilterOnTableRowsForNullAwareJoin( - const SelectivityVector& rows, + SelectivityVector& rows, SelectivityVector& filterPassedRows, std::function iterator) { if (!rows.hasSelections()) { @@ -1339,47 +1453,53 @@ void HashProbe::applyFilterOnTableRowsForNullAwareJoin( } VELOX_CHECK(table_->rows(), "Should not move rows in hash joins"); char* data[kBatchSize]; - while (auto numRows = iterator(data, kBatchSize)) { - filterTableInput_->resize(numRows); - filterTableInputRows_.resizeFill(numRows, true); + + while (auto numBuildRows = iterator(data, kBatchSize)) { + // Extract build-side columns once per build batch. + filterTableInput_->resize(numBuildRows); + filterTableInputRows_.resizeFill(numBuildRows, true); for (auto& projection : filterTableProjections_) { table_->extractColumn( - folly::Range(data, numRows), + folly::Range(data, numBuildRows), projection.inputChannel, filterTableInput_->childAt(projection.outputChannel)); } + + // Skip probe rows that already passed the filter on a previous build batch. + rows.deselect(filterPassedRows); rows.applyToSelected([&](vector_size_t row) { for (auto& projection : filterInputProjections_) { filterTableInput_->childAt(projection.outputChannel) = BaseVector::wrapInConstant( - numRows, row, input_->childAt(projection.inputChannel)); + numBuildRows, row, input_->childAt(projection.inputChannel)); } EvalCtx evalCtx( operatorCtx_->execCtx(), filter_.get(), filterTableInput_.get()); filter_->eval(filterTableInputRows_, evalCtx, filterTableResult_); + + bool passed = false; if (auto* values = getFlatFilterResult(filterTableResult_[0])) { - if (!bits::testSetBits( - values, 0, numRows, [](vector_size_t) { return false; })) { - filterPassedRows.setValid(row, true); - } + passed = !bits::testSetBits( + values, 0, numBuildRows, [](vector_size_t) { return false; }); } else { decodedFilterTableResult_.decode( *filterTableResult_[0], filterTableInputRows_); if (decodedFilterTableResult_.isConstantMapping()) { - if (!decodedFilterTableResult_.isNullAt(0) && - decodedFilterTableResult_.valueAt(0)) { - filterPassedRows.setValid(row, true); - } + passed = !decodedFilterTableResult_.isNullAt(0) && + decodedFilterTableResult_.valueAt(0); } else { - for (vector_size_t i = 0; i < numRows; ++i) { + for (vector_size_t i = 0; i < numBuildRows; ++i) { if (!decodedFilterTableResult_.isNullAt(i) && decodedFilterTableResult_.valueAt(i)) { - filterPassedRows.setValid(row, true); + passed = true; break; } } } } + if (passed) { + filterPassedRows.setValid(row, true); + } }); } } @@ -1423,7 +1543,7 @@ SelectivityVector HashProbe::evalFilterForNullAwareJoin( if (buildSideHasNullKeys_) { prepareNullKeyProbeHashers(); BaseHashTable::NullKeyRowsIterator iter; - nullKeyProbeRows.deselect(filterPassedRows); + nullKeyProbeRows.updateBounds(); applyFilterOnTableRowsForNullAwareJoin( nullKeyProbeRows, filterPassedRows, [&](char** data, int32_t maxRows) { return table_->listNullKeyRows( @@ -1431,7 +1551,7 @@ SelectivityVector HashProbe::evalFilterForNullAwareJoin( }); } BaseHashTable::RowsIterator iter; - crossJoinProbeRows.deselect(filterPassedRows); + crossJoinProbeRows.updateBounds(); applyFilterOnTableRowsForNullAwareJoin( crossJoinProbeRows, filterPassedRows, [&](char** data, int32_t maxRows) { return table_->listAllRows( @@ -1484,17 +1604,26 @@ int32_t HashProbe::evalFilter(int32_t numRows) { filterInputRows_.updateBounds(); } - RowVectorPtr filterInput = createFilterInput(numRows); - - if (nullAware_) { - prepareFilterRowsForNullAwareJoin( - filterInput, numRows, filterPropagateNulls); - } + // Skip filter evaluation when no rows are selected. filterPassed() + // short-circuits on filterInputRows_, so the result is never read. + // We cannot call createFilterInput() here because it wraps probe columns + // in a dictionary over outputRowMapping_, which listJoinResults() may have + // already overwritten in the previous iteration — the dictionary would fail + // validation in debug builds. + if (FOLLY_LIKELY(filterInputRows_.hasSelections())) { + RowVectorPtr filterInput = createFilterInput(numRows); + if (nullAware_) { + prepareFilterRowsForNullAwareJoin( + filterInput.get(), numRows, filterPropagateNulls); + } - EvalCtx evalCtx(operatorCtx_->execCtx(), filter_.get(), filterInput.get()); - filter_->eval(0, 1, true, filterInputRows_, evalCtx, filterResult_); + EvalCtx evalCtx(operatorCtx_->execCtx(), filter_.get(), filterInput.get()); + filter_->eval(0, 1, true, filterInputRows_, evalCtx, filterResult_); - decodedFilterResult_.decode(*filterResult_[0], filterInputRows_); + decodedFilterResult_.decode(*filterResult_[0], filterInputRows_); + } else if (nullAware_) { + prepareFilterRowsForNullAwareJoin(nullptr, numRows, filterPropagateNulls); + } int32_t numPassed = 0; if (isLeftJoin(joinType_) || isFullJoin(joinType_)) { @@ -1655,7 +1784,7 @@ void HashProbe::ensureLoadedIfNotAtEnd(column_index_t channel) { void HashProbe::ensureLoaded(column_index_t channel) { if (!filter_ && (isLeftSemiFilterJoin(joinType_) || isLeftSemiProjectJoin(joinType_) || - isAntiJoin(joinType_))) { + isAntiJoin(joinType_) || isCountingJoin(joinType_))) { return; } @@ -1700,21 +1829,34 @@ void HashProbe::noMoreInputInternal() { spillInputPartitionIds_.size(), inputSpiller_->state().spilledPartitionIdSet().size()); inputSpiller_->finishSpill(inputSpillPartitionSet_); - VELOX_CHECK_EQ(spillStats_->rlock()->spillSortTimeNanos, 0); + VELOX_CHECK_EQ( + spillStats_->spillSortTimeNanos.load(std::memory_order_relaxed), 0); } - const bool hasSpillEnabled = canSpill(); std::vector promises; std::vector> peers; - // The last operator to finish processing inputs is responsible for - // producing build-side rows based on the join. + + // Reset flags about outputting build-side rows in parallel. + buildSideOutputRowContainerId_ = -1; + + // NOTE: if 'canSpill()' is false and 'outputBuildRowsInParallel' is + // false too, then a hash probe operator doesn't need to wait for all the + // other peers to finish probe processing. If 'canSpill()' is true, it + // needs to wait and might expect spill gets triggered by the other probe + // operators, or there is previously spilled table partition(s) that needs to + // restore. If 'outputBuildRowsInParallel', it needs to wait to all drivers + // start outputting build-side rows in parallel only after all drivers finish + // probe processing. + const bool outputBuildRowsInParallel = + canOutputBuildRowsInParallel_ && needLastProbe(); + const bool shouldBlock = canSpill() || outputBuildRowsInParallel; if (!operatorCtx_->task()->allPeersFinished( planNodeId(), operatorCtx_->driver(), - hasSpillEnabled ? &future_ : nullptr, - hasSpillEnabled ? promises_ : promises, + shouldBlock ? &future_ : nullptr, + shouldBlock ? promises_ : promises, peers)) { - if (hasSpillEnabled) { + if (shouldBlock) { VELOX_CHECK(future_.valid()); setState(ProbeOperatorState::kWaitForPeers); VELOX_DCHECK(promises_.empty()); @@ -1725,13 +1867,14 @@ void HashProbe::noMoreInputInternal() { } VELOX_CHECK(promises.empty()); - // NOTE: if 'hasSpillEnabled' is false, then a hash probe operator doesn't - // need to wait for all the other peers to finish probe processing. - // Otherwise, it needs to wait and might expect spill gets triggered by the - // other probe operators, or there is previously spilled table partition(s) - // that needs to restore. - VELOX_CHECK(hasSpillEnabled || peers.empty()); lastProber_ = true; + joinBridge_->resetUnclaimedRowContainerId(); + // If 'outputBuildRowsInParallel' is true, wake up all peers to start + // outputting build-side rows in parallel. Otherwise, only let the last prober + // proceed. + if (outputBuildRowsInParallel) { + wakeupPeerOperators(); + } } bool HashProbe::isFinished() { @@ -1775,8 +1918,9 @@ void HashProbe::ensureOutputFits() { } // We only need to reserve memory for output if need. - if (input_ == nullptr && - (hasMoreInput() || !(needLastProbe() && lastProber_))) { + bool outputBuildSideRows = + needLastProbe() && (canOutputBuildRowsInParallel_ || lastProber_); + if (input_ == nullptr && (hasMoreInput() || !outputBuildSideRows)) { return; } @@ -1785,9 +1929,11 @@ void HashProbe::ensureOutputFits() { memory::testingRunArbitration(pool()); } - const uint64_t bytesToReserve = - operatorCtx_->driverCtx()->queryConfig().preferredOutputBatchBytes() * - 1.2; + const uint64_t bytesToReserve = static_cast( + static_cast(operatorCtx_->driverCtx() + ->queryConfig() + .preferredOutputBatchBytes()) * + 1.2); if (pool()->availableReservation() >= bytesToReserve) { return; } @@ -1799,8 +1945,102 @@ void HashProbe::ensureOutputFits() { } LOG(WARNING) << "Failed to reserve " << succinctBytes(bytesToReserve) << " for memory pool " << pool()->name() - << ", usage: " << succinctBytes(pool()->usedBytes()) - << ", reservation: " << succinctBytes(pool()->reservedBytes()); + << ", root pool: " << pool()->root()->name() + << ", used: " << succinctBytes(pool()->usedBytes()) + << ", reservation: " << succinctBytes(pool()->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool()->root()->reservedBytes()); +} + +bool HashProbe::needLazyLoadProbeInput() const { + if (!canReclaim() || input_ == nullptr || replacedWithDynamicFilter_) { + return false; + } + + // For left semi filter/project and anti joins without a filter, + // ensureLoaded() is a no-op — lazy vectors are passed through zero-copy + // via wrapChild() in fillOutput() and never loaded during probe. + // Preloading is unnecessary because no loading happens. + if (!filter_ && + (isLeftSemiFilterJoin(joinType_) || isLeftSemiProjectJoin(joinType_) || + isAntiJoin(joinType_))) { + return false; + } + return true; +} + +void HashProbe::ensureLazyInputLoaded() { + if (!needLazyLoadProbeInput()) { + return; + } + + VELOX_CHECK_NOT_NULL(input_); + + // When atEnd (single output batch), fillOutput() skips loading via + // ensureLoadedIfNotAtEnd() and wraps lazy vectors zero-copy. + // However, createFilterInput() unconditionally loads filter+projected + // columns, so those still need preloading when a filter is present. + const bool atEnd{resultIter_->atEnd()}; + if (atEnd && !filter_) { + return; + } + + std::vector lazyChannels; + if (!atEnd) { + for (const auto& [in, _] : projectedInputColumns_) { + if (isLazyNotLoaded(*input_->childAt(in))) { + lazyChannels.push_back(in); + } + } + } + if (filter_) { + for (const auto& projection : filterInputProjections_) { + if (atEnd && + projectedInputColumns_.find(projection.inputChannel) == + projectedInputColumns_.end()) { + continue; + } + if (isLazyNotLoaded(*input_->childAt(projection.inputChannel))) { + lazyChannels.push_back(projection.inputChannel); + } + } + } + if (lazyChannels.empty()) { + return; + } + + const bool tableWasNonEmpty{table_->numDistinct() > 0}; + { + loadingLazyInput_ = true; + SCOPE_EXIT { + loadingLazyInput_ = false; + }; + Operator::ReclaimableSectionGuard guard(this); + if (testingTriggerSpill(pool()->name())) { + memory::testingRunArbitration(pool()); + } + for (const auto channel : lazyChannels) { + ensureLoaded(channel); + } + } + + // If the hash table was spilled during lazy loading (it was non-empty + // before but is now empty), spill the current input so it can be re-probed + // during the restore pass. After spilling, resultIter_ references freed + // hash table rows, so we must abandon the current iteration by setting + // input_ to nullptr. + VELOX_CHECK_NOT_NULL(input_); + if (tableWasNonEmpty && table_->numDistinct() == 0) { + VELOX_CHECK_NOT_NULL(inputSpiller_); + spillInput(input_); + if (input_ != nullptr) { + // Non-spilled rows remain. For join types that skip probe on empty + // build side (inner, right, semi filter), dropping them is correct. + // For left/anti/full joins this would be data loss — fail loudly. + VELOX_CHECK(skipProbeOnEmptyBuild()); + } + input_ = nullptr; + } } bool HashProbe::canReclaim() const { @@ -1826,8 +2066,12 @@ void HashProbe::reclaim( LOG(WARNING) << "Can't reclaim from hash probe operator, exceeded maximum spill " "level of " - << config->maxSpillLevel << ", " << pool()->name() << ", usage " - << succinctBytes(pool()->usedBytes()); + << config->maxSpillLevel << ", " << pool()->name() + << ", root pool: " << pool()->root()->name() + << ", used: " << succinctBytes(pool()->usedBytes()) + << ", reservation: " << succinctBytes(pool()->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool()->root()->reservedBytes()); return; } @@ -1844,9 +2088,13 @@ void HashProbe::reclaim( << (table_ == nullptr ? "nullptr" : std::to_string(table_->numDistinct())) << "], " << pool()->name() - << ", usage: " << succinctBytes(pool()->usedBytes()) + << ", root pool: " << pool()->root()->name() + << ", used: " << succinctBytes(pool()->usedBytes()) + << ", reservation: " << succinctBytes(pool()->reservedBytes()) << ", node pool reservation: " - << succinctBytes(pool()->parent()->reservedBytes()); + << succinctBytes(pool()->parent()->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool()->root()->reservedBytes()); return; } @@ -1874,9 +2122,14 @@ void HashProbe::reclaim( ? "nullptr" : std::to_string(probeOp->table_->numDistinct())) << "], " << peerPool->name() - << ", usage: " << succinctBytes(peerPool->usedBytes()) + << ", root pool: " << peerPool->root()->name() + << ", used: " << succinctBytes(peerPool->usedBytes()) + << ", reservation: " + << succinctBytes(peerPool->reservedBytes()) << ", node pool reservation: " - << succinctBytes(peerPool->parent()->reservedBytes()); + << succinctBytes(peerPool->parent()->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(peerPool->root()->reservedBytes()); return; } hasMoreProbeInput |= !probeOp->noMoreSpillInput_; @@ -1932,6 +2185,13 @@ void HashProbe::spillOutput(const std::vector& operators) { auto* spillExecutor = spillConfig()->executor; for (auto* op : operators) { HashProbe* probeOp = static_cast(op); + // Skip operators that are in the middle of loading lazy input vectors. + // Their reclaim path re-enters getOutputInternal()->fillOutput(), or + // evalFilter()->createFilterInput() which loads the same lazy vectors, + // deadlocking on the column reader's folly::basic_once_flag. + if (probeOp->loadingLazyInput_) { + continue; + } spillTasks.push_back( memory::createAsyncMemoryReclaimTask([probeOp]() { try { @@ -2048,7 +2308,8 @@ void HashProbe::checkMaxSpillLevel( << "Exceeded spill level limit: " << config->maxSpillLevel << ", and disable spilling for memory pool: " << pool()->name(); exceededMaxSpillLevelLimit_ = true; - ++spillStats_->wlock()->spillMaxLevelExceededCount; + spillStats_->spillMaxLevelExceededCount.fetch_add( + 1, std::memory_order_relaxed); return; } } diff --git a/velox/exec/HashProbe.h b/velox/exec/HashProbe.h index e04521d5431..aacf12aa321 100644 --- a/velox/exec/HashProbe.h +++ b/velox/exec/HashProbe.h @@ -15,6 +15,8 @@ */ #pragma once +#include + #include "velox/exec/HashBuild.h" #include "velox/exec/HashTable.h" #include "velox/exec/Operator.h" @@ -26,6 +28,13 @@ namespace facebook::velox::exec { // Probes a hash table made by HashBuild. class HashProbe : public Operator { public: + /// Runtime stat keys for hash probe. + /// Size of the bloom filter in bytes. + static constexpr std::string_view kBloomFilterSize = "bloomFilterSize"; + /// Number of rows bypassed via dynamic filter replacement. + static constexpr std::string_view kReplacedWithDynamicFilterRows = + "replacedWithDynamicFilterRows"; + HashProbe( int32_t operatorId, DriverCtx* driverCtx, @@ -110,7 +119,8 @@ class HashProbe : public Operator { // output. static bool joinIncludesMissesFromLeft(core::JoinType joinType) { return isLeftJoin(joinType) || isFullJoin(joinType) || - isAntiJoin(joinType) || isLeftSemiProjectJoin(joinType); + isAntiJoin(joinType) || isCountingAntiJoin(joinType) || + isLeftSemiProjectJoin(joinType); } void setState(ProbeOperatorState state); @@ -190,7 +200,7 @@ class HashProbe : public Operator { // 'filterPropagateNulls' is true, the probe input row which has null in any // probe filter column can't pass the filter. void prepareFilterRowsForNullAwareJoin( - RowVectorPtr& filterInput, + const RowVector* filterInput, vector_size_t numRows, bool filterPropagateNulls); @@ -210,7 +220,7 @@ class HashProbe : public Operator { // that pass the filter in 'filterPassedRows'. Used in null-aware join // processing. void applyFilterOnTableRowsForNullAwareJoin( - const SelectivityVector& rows, + SelectivityVector& rows, SelectivityVector& filterPassedRows, std::function iterator); @@ -269,6 +279,17 @@ class HashProbe : public Operator { // operator is set to reclaimable at this stage. void ensureOutputFits(); + // Returns true if the current input batch requires lazy preloading before + // the non-reclaimable probe output loop. Checks canReclaim(), input state, + // and join type to determine if preloading is needed. + bool needLazyLoadProbeInput() const; + + // Pre-loads lazy input vectors in a reclaimable section so that the + // subsequent non-reclaimable evalFilter/fillOutput does not trigger OOM. + // Called after listJoinResults() so that atEnd() is known and preloading + // can be skipped for columns that fillOutput/createFilterInput won't load. + void ensureLazyInputLoaded(); + // Setups spilled output reader if 'spillOutputPartitionSet_' is not empty. void maybeSetupSpillOutputReader(); @@ -376,8 +397,18 @@ class HashProbe : public Operator { const bool nullAware_; + const bool nullAsValue_; + const RowTypePtr probeType_; + // Flag to indicate whether this hash probe operator can output build-side + // rows in parallel with the peer operators for the current hash table. + // Outputting build-side rows in parallel is currently not allowed in either + // of the following cases: + // 1. QueryConfig::kParallelOutputJoinBuildRowsEnabled is false. + // 2. Spill is enabled. + const bool canOutputBuildRowsInParallel_; + std::shared_ptr joinBridge_; ProbeOperatorState state_{ProbeOperatorState::kWaitForBuild}; @@ -630,7 +661,7 @@ class HashProbe : public Operator { std::optional currentRowPassed; }; - BaseHashTable::RowsIterator lastProbeIterator_; + RowContainerIterator lastProbeIterator_; // For left and anti join with filter, tracks the probe side rows which had // matches on the build side but didn't pass the filter. @@ -666,6 +697,11 @@ class HashProbe : public Operator { // input. SelectivityVector passingInputRows_; + // Set while loading lazy input vectors inside a ReclaimableSectionGuard. + // Tells reclaim() to skip spillOutput() for this operator to avoid + // re-entering the column reader (which would deadlock on call_once). + tsan_atomic loadingLazyInput_{false}; + // Indicates if this hash probe has exceeded max spill limit which is not // allowed to spill. This is reset when hash probe operator starts to probe // the next previously spilled hash table partition. @@ -728,6 +764,10 @@ class HashProbe : public Operator { // Input vector used for listing rows with null keys. VectorPtr nullKeyProbeInput_; + + // The index of the row container in the current hash table that this hash + // probe oprator is processing to output build-side rows. + int buildSideOutputRowContainerId_{-1}; }; inline std::ostream& operator<<(std::ostream& os, ProbeOperatorState state) { diff --git a/velox/exec/HashTable.cpp b/velox/exec/HashTable.cpp index 4c4a8c46382..4c811c47e95 100644 --- a/velox/exec/HashTable.cpp +++ b/velox/exec/HashTable.cpp @@ -22,12 +22,13 @@ #include "velox/common/process/ProcessBase.h" #include "velox/common/process/TraceContext.h" #include "velox/common/testutil/TestValue.h" +#include "velox/exec/AdaptivePrefetch.h" #include "velox/exec/OperatorUtils.h" -#include "velox/vector/VectorTypeUtils.h" using facebook::velox::common::testutil::TestValue; namespace facebook::velox::exec { + // static std::string BaseHashTable::modeString(HashMode mode) { switch (mode) { @@ -51,17 +52,22 @@ HashTable::HashTable( bool allowDuplicates, bool isJoinBuild, bool hasProbedFlag, + bool hasCountFlag, uint32_t minTableSizeForParallelJoinBuild, - memory::MemoryPool* pool) + memory::MemoryPool* pool, + uint64_t bloomFilterMaxSize) : BaseHashTable(std::move(hashers)), pool_(pool), minTableSizeForParallelJoinBuild_(minTableSizeForParallelJoinBuild), + bloomFilterMaxSize_(bloomFilterMaxSize), isJoinBuild_(isJoinBuild), + allowDuplicates_(allowDuplicates), buildPartitionBounds_(raw_vector(pool)) { + VELOX_CHECK(bloomFilterMaxSize_ == 0 || (isJoinBuild && ignoreNullKeys)); std::vector keys; for (auto& hasher : hashers_) { keys.push_back(hasher->type()); - if (!VectorHasher::typeKindSupportsValueIds(hasher->typeKind())) { + if (!hasher->typeSupportsValueIds()) { hashMode_ = HashMode::kHash; } } @@ -74,7 +80,9 @@ HashTable::HashTable( allowDuplicates, isJoinBuild, hasProbedFlag, + hasCountFlag, hashMode_ != HashMode::kHash, + /*useListRowIndex=*/false, pool); nextOffset_ = rows_->nextOffset(); } @@ -234,8 +242,7 @@ class ProbeState { int64_t numProbedBuckets = 0; while (numProbedBuckets < table.numBuckets()) { if (!hits_) { - const uint16_t empty = simd::toBitMask(tagsInTable_ == kEmptyGroup); - if (empty) { + if (simd::any(tagsInTable_ == kEmptyGroup)) { return nullptr; } } else { @@ -275,8 +282,7 @@ class ProbeState { template void eraseHit(Table& table, int64_t& numTombstones) { const auto kEmptyGroup = BaseHashTable::TagVector::broadcast(kEmptyTag); - const bool hasEmptyGroup = - simd::toBitMask(tagsInTable_ == kEmptyGroup) != 0; + const bool hasEmptyGroup = simd::any(tagsInTable_ == kEmptyGroup); table.bucketAt(bucketOffset_) ->setTag(indexInTags_, hasEmptyGroup ? 0 : kTombstoneTag); @@ -332,6 +338,8 @@ char* HashTable::insertEntry( HashLookup& lookup, uint64_t index, vector_size_t row) { + TestValue::adjust( + "facebook::velox::exec::HashTable::insertEntry", rows_->pool()); char* group = rows_->newRow(); lookup.hits[row] = group; // NOLINT storeKeys(lookup, row); @@ -358,8 +366,13 @@ bool HashTable::compareKeys( int32_t i = 0; do { auto& hasher = lookup.hashers[i]; - if (!rows_->equals( - group, rows_->columnAt(i), hasher->decodedVector(), row)) { + if (rows_->compare( + group, + rows_->columnAt(i), + hasher->decodedVector(), + row, + CompareFlags::equality( + CompareFlags::NullHandlingMode::kNullAsValue)) != 0) { return false; } } while (++i < numKeys); @@ -417,9 +430,12 @@ FOLLY_ALWAYS_INLINE void HashTable::fullProbe( } namespace { + // Group prefetch size for join build & probe. constexpr int32_t kPrefetchSize = 64; +constexpr int32_t kHashBatchSize = 1024; + // Normalized keys have non0-random bits. Bits need to be propagated // up to make a tag byte and down so that non-lowest bits of // normalized key affect the hash table index. @@ -798,7 +814,14 @@ bool HashTable::hashRows( return true; } if (!initNormalizedKeys && hashMode_ == HashMode::kNormalizedKey) { - for (auto i = 0; i < rows.size(); ++i) { + // Prefetch row pointers ahead to hide DRAM latency when reading + // normalizedKey from random RowContainer arena addresses. + const auto numRows = static_cast(rows.size()); + AdaptivePrefetch prefetch(numRows); + for (int32_t i = 0; i < numRows; ++i) { + if (auto ahead = prefetch.lookAhead()) { + __builtin_prefetch(rows[i + ahead] - sizeof(normalized_key_t)); + } hashes[i] = mixNormalizedKey(RowContainer::normalizedKey(rows[i]), sizeBits_); } @@ -834,12 +857,87 @@ bool HashTable::hashRows( } namespace { + +template +void partitionBloomFilterRowsImpl( + int32_t offset, + const common::BigintValuesUsingBloomFilter& filter, + const RowContainer& rowContainer, + uint8_t partitionMask, + RowPartitions& rowPartitions) { + char* rows[kHashBatchSize]; + uint8_t partitions[kHashBatchSize]; + RowContainerIterator iter; + while (auto numRows = rowContainer.listRows( + &iter, kHashBatchSize, RowContainer::kUnlimited, rows)) { + for (int i = 0; i < numRows; ++i) { + auto value = folly::loadUnaligned(rows[i] + offset); + partitions[i] = filter.blockIndex(value) & partitionMask; + } + rowPartitions.appendPartitions( + folly::Range(partitions, numRows)); + } +} + +void partitionBloomFilterRows( + const VectorHasher& hasher, + int32_t offset, + const common::BigintValuesUsingBloomFilter& filter, + const RowContainer& rowContainer, + uint8_t numPartitions, + RowPartitions& rowPartitions) { + VELOX_DCHECK(hasher.supportsBloomFilter()); + VELOX_DCHECK(bits::isPowerOfTwo(numPartitions)); + switch (hasher.typeKind()) { + case TypeKind::INTEGER: + partitionBloomFilterRowsImpl( + offset, filter, rowContainer, numPartitions - 1, rowPartitions); + break; + case TypeKind::BIGINT: + partitionBloomFilterRowsImpl( + offset, filter, rowContainer, numPartitions - 1, rowPartitions); + break; + default: + VELOX_UNREACHABLE(); + } +} + +template +void buildBloomFilterImpl( + int32_t offset, + char** rows, + int numRows, + common::BigintValuesUsingBloomFilter& filter) { + for (int i = 0; i < numRows; ++i) { + filter.insert(folly::loadUnaligned(rows[i] + offset)); + } +} + +void buildBloomFilter( + const VectorHasher& hasher, + int32_t offset, + char** rows, + int numRows, + common::BigintValuesUsingBloomFilter& filter) { + VELOX_DCHECK(hasher.supportsBloomFilter()); + switch (hasher.typeKind()) { + case TypeKind::INTEGER: + buildBloomFilterImpl(offset, rows, numRows, filter); + break; + case TypeKind::BIGINT: + buildBloomFilterImpl(offset, rows, numRows, filter); + break; + default: + VELOX_UNREACHABLE(); + } +} + template void syncWorkItems( std::vector>& items, - std::exception_ptr& error, std::vector& timings, - bool log = false) { + bool throwError) { + std::exception_ptr error; // All items must be synced also in case of error because the items // hold references to the table and rows which could be destructed // if unwinding the stack did not pause to sync. @@ -850,15 +948,42 @@ void syncWorkItems( timings.push_back(item->prepareTiming()); } } catch (const std::exception& e) { - if (log) { + if (!throwError) { LOG(ERROR) << "Error in async hash build: " << e.what(); + } else { + error = std::current_exception(); } - error = std::current_exception(); } } + if (error) { + std::rethrow_exception(error); + } } + } // namespace +template <> +bool HashTable::bloomFilterSupported() const { + if (!(bloomFilterMaxSize_ > 0 && + common::BigintValuesUsingBloomFilter::numBlocks(numDistinct_) * + sizeof(SplitBlockBloomFilter::Block) <= + bloomFilterMaxSize_)) { + return false; + } + for (auto& hasher : hashers_) { + if (hasher->supportsBloomFilter()) { + return true; + } + } + return false; +} + +template <> +bool HashTable::bloomFilterSupported() const { + VELOX_CHECK_EQ(bloomFilterMaxSize_, 0); + return false; +} + template bool HashTable::canApplyParallelJoinBuild() const { if (!isJoinBuild_ || buildExecutor_ == nullptr) { @@ -908,23 +1033,50 @@ void HashTable::parallelJoinBuild() { buildPartitionBounds_.back() = sizeMask_ + 1; std::vector>> partitionSteps; std::vector>> buildSteps; + std::vector>> bloomFilterPartitionSteps; + std::vector>> bloomFilterBuildSteps; // rowPartitions are used in the async threads, so declare them before the // sync guard. std::vector> rowPartitions; auto sync = folly::makeGuard([&]() { // This is executed on returning path, possibly in unwinding, so must not // throw. - std::exception_ptr error; syncWorkItems( - partitionSteps, error, parallelJoinBuildStats_.partitionTimings, true); + partitionSteps, parallelJoinBuildStats_.partitionTimings, false); + syncWorkItems(buildSteps, parallelJoinBuildStats_.buildTimings, false); syncWorkItems( - buildSteps, error, parallelJoinBuildStats_.buildTimings, true); + bloomFilterPartitionSteps, + parallelJoinBuildStats_.bloomFilterPartitionTimings, + false); + syncWorkItems( + bloomFilterBuildSteps, + parallelJoinBuildStats_.bloomFilterBuildTimings, + false); // Release the partition bounds to reduce memory usage. buildPartitionBounds_ = raw_vector(pool_); }); - const auto getTable = [this](size_t i) INLINE_LAMBDA { - return i == 0 ? this : otherTables_[i - 1].get(); + // Passing driver context directly to avoid cross thread access to thread + // local driver thread context. + const DriverCtx* driverCtx{nullptr}; + if (const auto* driverThreadCtx = driverThreadContext()) { + driverCtx = driverThreadCtx->driverCtx(); + } + + const auto runStep = [&](auto& steps, auto&& work, bool runInCurrentThread) { + auto step = std::make_shared>([work = std::move(work)] { + work(); + return std::make_unique(true); + }); + steps.push_back(step); + if (runInCurrentThread) { + step->prepare(); + } else { + buildExecutor_->add([driverCtx, step]() { + ScopedDriverThreadContext scopedDriverThreadContext(driverCtx); + step->prepare(); + }); + } }; // This step can involve large memory allocations, so there is a chance of @@ -932,69 +1084,120 @@ void HashTable::parallelJoinBuild() { // concurrency issues. rowPartitions.reserve(numPartitions); for (auto i = 0; i < numPartitions; ++i) { - auto* table = getTable(i); - rowPartitions.push_back(table->rows()->createRowPartitions(*rows_->pool())); - } - - // Passing driver context directly to avoid cross thread access to thread - // local driver thread context. - const DriverCtx* driverCtx{nullptr}; - if (const auto* driverThreadCtx = driverThreadContext()) { - driverCtx = driverThreadCtx->driverCtx(); + rowPartitions.push_back( + tableAt(i)->rows()->createRowPartitions(*rows_->pool())); } // The parallel table partitioning step. for (auto i = 0; i < numPartitions; ++i) { - auto* table = getTable(i); - partitionSteps.push_back(std::make_shared>( - [this, table, rawRowPartitions = rowPartitions[i].get()]() { + auto* table = tableAt(i); + bool last = i == numPartitions - 1; + runStep( + partitionSteps, + [this, table, rawRowPartitions = rowPartitions[i].get()] { partitionRows(*table, *rawRowPartitions); - return std::make_unique(true); - })); - VELOX_CHECK(!partitionSteps.empty()); - buildExecutor_->add([driverCtx, step = partitionSteps.back()]() { - ScopedDriverThreadContext scopedDriverThreadContext(driverCtx); - step->prepare(); - }); - } - - std::exception_ptr error; - syncWorkItems( - partitionSteps, error, parallelJoinBuildStats_.partitionTimings); - if (error != nullptr) { - std::rethrow_exception(error); + }, + // run last partition on current thread to avoid wasting current thread + // on just waiting + last); } + syncWorkItems(partitionSteps, parallelJoinBuildStats_.partitionTimings, true); - // The parallel table building step. + // The parallel table building step. Each partition collects overflow rows + // (rows whose target bucket falls outside its [start, end) range) along + // with their hashes, so the final serial re-insertion phase below does + // not need to recompute them. std::vector> overflowPerPartition(numPartitions); + std::vector> overflowHashesPerPartition(numPartitions); for (auto i = 0; i < numPartitions; ++i) { - buildSteps.push_back(std::make_shared>( - [this, i, &overflowPerPartition, &rowPartitions]() { - buildJoinPartition(i, rowPartitions, overflowPerPartition[i]); - return std::make_unique(true); - })); - VELOX_CHECK(!buildSteps.empty()); - buildExecutor_->add([driverCtx, step = buildSteps.back()]() { - ScopedDriverThreadContext scopedDriverThreadContext(driverCtx); - step->prepare(); - }); - } - syncWorkItems(buildSteps, error, parallelJoinBuildStats_.buildTimings); - - if (error != nullptr) { - std::rethrow_exception(error); + bool last = i == numPartitions - 1; + runStep( + buildSteps, + [this, + i, + &overflowPerPartition, + &overflowHashesPerPartition, + &rowPartitions] { + buildJoinPartition( + i, + rowPartitions, + overflowPerPartition[i], + overflowHashesPerPartition[i]); + }, + // run last partition on current thread to avoid wasting current thread + // on just waiting + last); + } + syncWorkItems(buildSteps, parallelJoinBuildStats_.buildTimings, true); + + if (bloomFilterSupported()) { + const auto numBloomFilterPartitions = bits::isPowerOfTwo(numPartitions) + ? numPartitions + : bits::nextPowerOfTwo(numPartitions) / 2; + VELOX_CHECK_GT(numBloomFilterPartitions, 0); + for (int i = 0; i < hashers_.size(); ++i) { + if (!hashers_[i]->supportsBloomFilter()) { + continue; + } + auto filter = std::make_shared( + numDistinct_, false); + hashers_[i]->setBloomFilter(filter); + for (auto j = 0; j < numPartitions; ++j) { + bool last = j == numPartitions - 1; + auto* rows = tableAt(j)->rows(); + rowPartitions[j]->reset(); + runStep( + bloomFilterPartitionSteps, + [hasher = hashers_[i].get(), + offset = rows->columnAt(i).offset(), + filter, + rows, + numBloomFilterPartitions, + rowPartitions = rowPartitions[j].get()] { + partitionBloomFilterRows( + *hasher, + offset, + *filter, + *rows, + numBloomFilterPartitions, + *rowPartitions); + }, + // run last partition on current thread to avoid wasting current + // thread on just waiting + last); + } + syncWorkItems( + bloomFilterPartitionSteps, + parallelJoinBuildStats_.bloomFilterPartitionTimings, + true); + for (auto j = 0; j < numBloomFilterPartitions; ++j) { + bool last = j == numBloomFilterPartitions - 1; + runStep( + bloomFilterBuildSteps, + [this, i, j, &rowPartitions] { + buildBloomFilterPartition(i, j, rowPartitions); + }, + // run last partition on current thread to avoid wasting current + // thread on just waiting + last); + } + syncWorkItems( + bloomFilterBuildSteps, + parallelJoinBuildStats_.bloomFilterBuildTimings, + true); + } } - raw_vector hashes(pool_); + // Serially re-insert overflow rows that didn't fit in any partition's + // bucket range. Hashes were captured during the parallel build phase, so + // we reuse them instead of re-hashing. for (auto i = 0; i < numPartitions; ++i) { auto& overflows = overflowPerPartition[i]; - hashes.resize(overflows.size()); - hashRows( - folly::Range(overflows.data(), overflows.size()), - false, - hashes); - auto table = i == 0 ? this : otherTables_[i - 1].get(); - insertForJoin(overflows.data(), hashes.data(), overflows.size(), nullptr); + auto& overflowHashes = overflowHashesPerPartition[i]; + VELOX_CHECK_EQ(overflows.size(), overflowHashes.size()); + insertForJoin( + overflows.data(), overflowHashes.data(), overflows.size(), nullptr); + auto* table = tableAt(i); VELOX_CHECK_EQ(table->rows()->numRows(), table->numParallelBuildRows_); } } @@ -1025,14 +1228,14 @@ template void HashTable::partitionRows( HashTable& subtable, RowPartitions& rowPartitions) { - constexpr int32_t kBatch = 1024; - raw_vector rows(kBatch, pool_); - raw_vector hashes(kBatch, pool_); - raw_vector partitions(kBatch, pool_); + raw_vector rows(kHashBatchSize, pool_); + raw_vector hashes(kHashBatchSize, pool_); + raw_vector partitions(kHashBatchSize, pool_); RowContainerIterator iter; while (auto numRows = subtable.rows_->listRows( - &iter, kBatch, RowContainer::kUnlimited, rows.data())) { - hashRows(folly::Range(rows.data(), numRows), true, hashes); + &iter, kHashBatchSize, RowContainer::kUnlimited, rows.data())) { + VELOX_CHECK( + hashRows(folly::Range(rows.data(), numRows), true, hashes)); VELOX_DCHECK_EQ( 0, buildPartitionBounds_.capacity() % @@ -1052,27 +1255,57 @@ template void HashTable::buildJoinPartition( uint8_t partition, const std::vector>& rowPartitions, - std::vector& overflow) { - constexpr int32_t kBatch = 1024; - raw_vector rows(kBatch, pool_); - raw_vector hashes(kBatch, pool_); + std::vector& overflow, + std::vector& overflowHashes) { + raw_vector rows(kHashBatchSize, pool_); + raw_vector hashes(kHashBatchSize, pool_); const int32_t numPartitions = 1 + otherTables_.size(); TableInsertPartitionInfo partitionInfo{ buildPartitionBounds_[partition], buildPartitionBounds_[partition + 1], - overflow}; + overflow, + overflowHashes}; for (auto i = 0; i < numPartitions; ++i) { - auto* table = i == 0 ? this : otherTables_[i - 1].get(); + auto* table = tableAt(i); RowContainerIterator iter; - while (const auto numRows = table->rows_->listPartitionRows( - iter, partition, kBatch, *rowPartitions[i], rows.data())) { - hashRows(folly::Range(rows.data(), numRows), false, hashes); + while ( + const auto numRows = table->rows_->listPartitionRows( + iter, partition, kHashBatchSize, *rowPartitions[i], rows.data())) { + VELOX_CHECK(hashRows(folly::Range(rows.data(), numRows), false, hashes)); insertForJoin(rows.data(), hashes.data(), numRows, &partitionInfo); table->numParallelBuildRows_ += numRows; } } } +template <> +void HashTable::buildBloomFilterPartition( + column_index_t columnIndex, + uint8_t partition, + const std::vector>& rowPartitions) { + char* rows[kHashBatchSize]; + for (auto i = 0; i < 1 + otherTables_.size(); ++i) { + auto* table = tableAt(i); + auto rowColumn = table->rows_->columnAt(columnIndex); + auto* filter = checkedPointerCast( + hashers_[columnIndex]->getBloomFilter().get()); + RowContainerIterator iter; + while (auto numRows = table->rows_->listPartitionRows( + iter, partition, kHashBatchSize, *rowPartitions[i], rows)) { + buildBloomFilter( + *hashers_[columnIndex], rowColumn.offset(), rows, numRows, *filter); + } + } +} + +template <> +void HashTable::buildBloomFilterPartition( + column_index_t /*columnIndex*/, + uint8_t /*partition*/, + const std::vector>& /*rowPartitions*/) { + VELOX_UNREACHABLE(); +} + template bool HashTable::insertBatch( char** groups, @@ -1093,7 +1326,7 @@ bool HashTable::insertBatch( template void HashTable::insertForGroupBy( char** groups, - uint64_t* hashes, + const uint64_t* hashes, int32_t numGroups) { if (hashMode_ == HashMode::kArray) { for (auto i = 0; i < numGroups; ++i) { @@ -1119,9 +1352,24 @@ void HashTable::insertForGroupBy( bool inserted{false}; for (int64_t numProbedBuckets = 0; numProbedBuckets < numBuckets(); ++numProbedBuckets) { + // We are populating a newly allocated table during rehash(), so + // here each tag is either empty (zero) or in-use (high bit set) + // and never a tombstone (0x7f, high bit unset). Therefore, the two + // approaches below are equivalent: + // - On x86 with SSE2, we benefit from a single instruction that builds + // a bit mask by combining top bits of the tags in a vector, so we + // pass the raw tags to simd::toBitMask(). + // - On all architectures, the universally robust way is to call + // simd::toBitMask() with an intermediate xsimd::batch_bool + // generated from a comparison between the tags and an empty batch. MaskType free = ~simd::toBitMask( - BaseHashTable::TagVector::batch_bool_type(tagsInTable)) & +#if XSIMD_WITH_SSE2 + BaseHashTable::TagVector::batch_bool_type(tagsInTable) +#else + tagsInTable != TagVector::broadcast(ProbeState::kEmptyTag) +#endif + ) & ProbeState::kFullMask; if (free) { auto freeOffset = bits::getAndClearLastSetBit(free); @@ -1151,7 +1399,9 @@ bool HashTable::arrayPushRow(char* row, int32_t index) { hasDuplicates_.set(); } } else if (existing) { - // Semijoin or a known unique build side ignores a repeat of a key. + if (rows_->countOffset() > 0) { + rows_->addCount(existing, rows_->count(row)); + } return false; } table_[index] = row; @@ -1179,7 +1429,7 @@ FOLLY_ALWAYS_INLINE void HashTable::buildFullProbe( -static_cast(sizeof(normalized_key_t)); auto insertFn = [&](int32_t /*row*/, PartitionBoundIndexType index) { if (partitionInfo != nullptr && !partitionInfo->inRange(index)) { - partitionInfo->addOverflow(inserted); + partitionInfo->addOverflow(inserted, hash); return nullptr; } storeRowPointer(index, hash, inserted); @@ -1194,6 +1444,8 @@ FOLLY_ALWAYS_INLINE void HashTable::buildFullProbe( RowContainer::normalizedKey(inserted)) { if (nextOffset_ > 0) { pushNext(group, inserted); + } else if (rows_->countOffset() > 0) { + rows_->addCount(group, rows_->count(inserted)); } return true; } @@ -1211,6 +1463,8 @@ FOLLY_ALWAYS_INLINE void HashTable::buildFullProbe( if (compareKeys(group, inserted)) { if (nextOffset_ > 0) { pushNext(group, inserted); + } else if (rows_->countOffset() > 0) { + rows_->addCount(group, rows_->count(inserted)); } return true; } @@ -1227,7 +1481,7 @@ template template FOLLY_ALWAYS_INLINE void HashTable::insertForJoinWithPrefetch( char** groups, - uint64_t* hashes, + const uint64_t* hashes, int32_t numGroups, TableInsertPartitionInfo* partitionInfo) { auto i = 0; @@ -1263,7 +1517,7 @@ FOLLY_ALWAYS_INLINE void HashTable::insertForJoinWithPrefetch( template void HashTable::insertForJoin( char** groups, - uint64_t* hashes, + const uint64_t* hashes, int32_t numGroups, TableInsertPartitionInfo* partitionInfo) { // The insertable rows are in the table, all get put in the hash table or @@ -1288,7 +1542,6 @@ void HashTable::rehash( bool initNormalizedKeys, int8_t spillInputStartPartitionBit) { ++numRehashes_; - constexpr int32_t kHashBatchSize = 1024; if (canApplyParallelJoinBuild()) { parallelJoinBuild(); return; @@ -1296,22 +1549,48 @@ void HashTable::rehash( raw_vector hashes(pool_); hashes.resize(kHashBatchSize); char* groups[kHashBatchSize]; + const bool shouldBuildBloomFilter = bloomFilterSupported(); + std::vector bloomFilters; + if (shouldBuildBloomFilter) { + bloomFilters.resize(hashers_.size()); + for (int i = 0; i < hashers_.size(); ++i) { + if (!hashers_[i]->supportsBloomFilter()) { + continue; + } + auto filter = std::make_shared( + numDistinct_, false); + bloomFilters[i] = filter.get(); + hashers_[i]->setBloomFilter(filter); + } + } // A join build can have multiple payload tables. Loop over 'this' // and the possible other tables and put all the data in the table // of 'this'. for (int32_t i = 0; i <= otherTables_.size(); ++i) { RowContainerIterator iterator; int32_t numGroups; + auto* table = tableAt(i); do { - numGroups = (i == 0 ? this : otherTables_[i - 1].get()) - ->rows() - ->listRows(&iterator, kHashBatchSize, groups); + numGroups = table->rows()->listRows(&iterator, kHashBatchSize, groups); if (!insertBatch( groups, numGroups, hashes, initNormalizedKeys || i != 0)) { VELOX_CHECK_NE(hashMode_, HashMode::kHash); setHashMode(HashMode::kHash, 0, spillInputStartPartitionBit); return; } + if (shouldBuildBloomFilter) { + for (int j = 0; j < hashers_.size(); ++j) { + if (!hashers_[j]->supportsBloomFilter()) { + continue; + } + buildBloomFilter( + *hashers_[j], + table->rows()->columnAt(j).offset(), + groups, + numGroups, + *bloomFilters[j]); + } + } } while (numGroups > 0); } } @@ -1350,7 +1629,6 @@ void HashTable::setHashMode( template bool HashTable::analyze() { - constexpr int32_t kHashBatchSize = 1024; // @lint-ignore CLANGTIDY char* groups[kHashBatchSize]; RowContainerIterator iterator; @@ -1487,7 +1765,9 @@ void HashTable::decideHashMode( return; } disableRangeArrayHash_ |= disableRangeArrayHash; - if (numDistinct_ && !isJoinBuild_) { + if (numDistinct_ && (!isJoinBuild_ || joinBuildNoDuplicates())) { + // If the join type is left semi and anti, allowDuplicates_ will be false, + // and join build is building hash table while adding input rows. if (!analyze()) { setHashMode(HashMode::kHash, numNew, spillInputStartPartitionBit); return; @@ -1709,12 +1989,26 @@ template void HashTable::prepareJoinTable( std::vector> tables, int8_t spillInputStartPartitionBit, + size_t vectorHasherMaxNumDistinct, + bool dropDuplicates, folly::Executor* executor) { buildExecutor_ = executor; + if (dropDuplicates) { + if (table_ != nullptr) { + // Reset table_ and capacity_ to trigger rehash. + rows_->pool()->freeContiguous(tableAllocation_); + table_ = nullptr; + capacity_ = 0; + } + // Call analyze to insert all unique values in row container to the + // table hashers' uniqueValues_; + analyze(); + } otherTables_.reserve(tables.size()); for (auto& table : tables) { - otherTables_.emplace_back(std::unique_ptr>( - dynamic_cast*>(table.release()))); + otherTables_.emplace_back( + std::unique_ptr>( + dynamic_cast*>(table.release()))); } // If there are multiple tables, we need to merge the 'columnHasNulls' flags @@ -1732,6 +2026,7 @@ void HashTable::prepareJoinTable( bool useValueIds = mayUseValueIds(*this); if (useValueIds) { + CpuWallTimer timer(vectorHasherMergeTiming_); for (auto& other : otherTables_) { if (!mayUseValueIds(*other)) { useValueIds = false; @@ -1740,8 +2035,13 @@ void HashTable::prepareJoinTable( } if (useValueIds) { for (auto& other : otherTables_) { + if (dropDuplicates) { + // Before merging with the current hashers, all values in the row + // containers of other table need to be inserted into uniqueValues_. + other->analyze(); + } for (auto i = 0; i < hashers_.size(); ++i) { - hashers_[i]->merge(*other->hashers_[i]); + hashers_[i]->merge(*other->hashers_[i], vectorHasherMaxNumDistinct); if (!hashers_[i]->mayUseValueIds()) { useValueIds = false; break; @@ -1967,6 +2267,55 @@ int32_t HashTable::listAllRows( return listRows(iter, maxRows, maxBytes, rows); } +template +template +int32_t HashTable::listRows( + RowContainerIterator& rowContainerIt, + int rowContainerId, + int32_t maxRows, + uint64_t maxBytes, + char** rows) { + const auto& rowContainer = rowContainerId == 0 + ? rows_.get() + : otherTables_[rowContainerId - 1]->rows(); + const auto numRows = rowContainer->template listRows( + &rowContainerIt, maxRows, maxBytes, rows); + return numRows; +} + +template +int32_t HashTable::listNotProbedRows( + RowContainerIterator& rowContainerIt, + int rowContainerId, + int32_t maxRows, + uint64_t maxBytes, + char** rows) { + return listRows( + rowContainerIt, rowContainerId, maxRows, maxBytes, rows); +} + +template +int32_t HashTable::listProbedRows( + RowContainerIterator& rowContainerIt, + int rowContainerId, + int32_t maxRows, + uint64_t maxBytes, + char** rows) { + return listRows( + rowContainerIt, rowContainerId, maxRows, maxBytes, rows); +} + +template +int32_t HashTable::listAllRows( + RowContainerIterator& rowContainerIt, + int rowContainerId, + int32_t maxRows, + uint64_t maxBytes, + char** rows) { + return listRows( + rowContainerIt, rowContainerId, maxRows, maxBytes, rows); +} + template <> int32_t HashTable::listNullKeyRows( NullKeyRowsIterator* iter, diff --git a/velox/exec/HashTable.h b/velox/exec/HashTable.h index c19d74727a1..b8d3a33ee46 100644 --- a/velox/exec/HashTable.h +++ b/velox/exec/HashTable.h @@ -16,9 +16,8 @@ #pragma once #include "velox/common/base/Portability.h" -#include "velox/common/memory/MemoryAllocator.h" +#include "velox/common/base/RuntimeMetrics.h" #include "velox/exec/OneWayStatusFlag.h" -#include "velox/exec/Operator.h" #include "velox/exec/RowContainer.h" #include "velox/exec/VectorHasher.h" @@ -30,15 +29,22 @@ struct TableInsertPartitionInfo { /// ['start', 'end') specifies the insert range of this table partition. PartitionBoundIndexType start; PartitionBoundIndexType end; - /// Used to contains the overflowed rows which can't be inserted into the - /// given table partition range. + /// Holds overflowed rows that can't be inserted into the given table + /// partition range. Re-inserted serially after all partitions finish. std::vector& overflows; + /// Hashes of the rows in 'overflows', stored 1:1. Carried through so + /// the serial re-insertion phase does not need to re-hash overflow rows. + std::vector& overflowHashes; TableInsertPartitionInfo( PartitionBoundIndexType _start, PartitionBoundIndexType _end, - std::vector& _overflows) - : start(_start), end(_end), overflows(_overflows) { + std::vector& _overflows, + std::vector& _overflowHashes) + : start(_start), + end(_end), + overflows(_overflows), + overflowHashes(_overflowHashes) { VELOX_CHECK_GE(start, 0); VELOX_CHECK_LT(start, end); } @@ -48,9 +54,10 @@ struct TableInsertPartitionInfo { return index >= start && index < end; } - /// Adds 'row' falls outside of this partititon range into 'overflows'. - void addOverflow(char* row) { + /// Records 'row' (with its already-computed 'hash') as an overflow. + void addOverflow(char* row, uint64_t hash) { overflows.push_back(row); + overflowHashes.push_back(hash); } }; @@ -115,6 +122,8 @@ struct HashTableStats { struct ParallelJoinBuildStats { std::vector partitionTimings; std::vector buildTimings; + std::vector bloomFilterPartitionTimings; + std::vector bloomFilterBuildTimings; }; class BaseHashTable { @@ -140,21 +149,38 @@ class BaseHashTable { /// The name of the runtime stats collected and reported by operators that use /// the HashTable (HashBuild, HashAggregation). - static inline const std::string kCapacity{"hashtable.capacity"}; - static inline const std::string kNumRehashes{"hashtable.numRehashes"}; - static inline const std::string kNumDistinct{"hashtable.numDistinct"}; - static inline const std::string kNumTombstones{"hashtable.numTombstones"}; + static constexpr std::string_view kCapacity{"hashtable.capacity"}; + static constexpr std::string_view kNumRehashes{"hashtable.numRehashes"}; + static constexpr std::string_view kNumDistinct{"hashtable.numDistinct"}; + static constexpr std::string_view kNumTombstones{"hashtable.numTombstones"}; + static constexpr std::string_view kHashMode{"hashtable.hashMode"}; /// The same as above but only reported by the HashBuild operator. - static inline const std::string kBuildWallNanos{"hashtable.buildWallNanos"}; - static inline const std::string kParallelJoinPartitionWallNanos{ + static constexpr std::string_view kBuildWallNanos{"hashtable.buildWallNanos"}; + static constexpr std::string_view kParallelJoinPartitionWallNanos{ "hashtable.parallelJoinPartitionWallNanos"}; - static inline const std::string kParallelJoinPartitionCpuNanos{ + static constexpr std::string_view kParallelJoinPartitionCpuNanos{ "hashtable.parallelJoinPartitionCpuNanos"}; - static inline const std::string kParallelJoinBuildWallNanos{ + static constexpr std::string_view kParallelJoinBuildWallNanos{ "hashtable.parallelJoinBuildWallNanos"}; - static inline const std::string kParallelJoinBuildCpuNanos{ + static constexpr std::string_view kParallelJoinBuildCpuNanos{ "hashtable.parallelJoinBuildCpuNanos"}; + static constexpr std::string_view kParallelJoinBloomFilterPartitionWallNanos{ + "hashtable.parallelJoinBloomFilterPartitionWallNanos"}; + static constexpr std::string_view kParallelJoinBloomFilterPartitionCpuNanos{ + "hashtable.parallelJoinBloomFilterPartitionCpuNanos"}; + static constexpr std::string_view kParallelJoinBloomFilterBuildWallNanos{ + "hashtable.parallelJoinBloomFilterBuildWallNanos"}; + static constexpr std::string_view kParallelJoinBloomFilterBuildCpuNanos{ + "hashtable.parallelJoinBloomFilterBuildCpuNanos"}; + static constexpr std::string_view kVectorHasherMergeCpuNanos{ + "hashtable.vectorHasherMergeCpuNanos"}; + static constexpr std::string_view kHashTableCacheHit{"hashtable.cacheHit"}; + static constexpr std::string_view kHashTableCacheMiss{"hashtable.cacheMiss"}; + + /// Populates 'runtimeStats' with hash table stats. + virtual void addRuntimeStats( + std::unordered_map& runtimeStats) const = 0; /// Returns the string of the given 'mode'. static std::string modeString(HashMode mode); @@ -285,6 +311,15 @@ class BaseHashTable { uint64_t maxBytes, char** rows) = 0; + /// Same as above, but only return rows from the row container of + /// 'rowContainerId'. + virtual int32_t listNotProbedRows( + RowContainerIterator& rowContainerIt, + int rowContainerId, + int32_t maxRows, + uint64_t maxBytes, + char** rows) = 0; + /// Returns rows with 'probed' flag set. Used by the right semi join. virtual int32_t listProbedRows( RowsIterator* iter, @@ -292,6 +327,15 @@ class BaseHashTable { uint64_t maxBytes, char** rows) = 0; + /// Same as above, but only return rows from the row container of + /// 'rowContainerId'. + virtual int32_t listProbedRows( + RowContainerIterator& rowContainerIt, + int rowContainerId, + int32_t maxRows, + uint64_t maxBytes, + char** rows) = 0; + /// Returns all rows. Used by the right semi join project. virtual int32_t listAllRows( RowsIterator* iter, @@ -299,6 +343,15 @@ class BaseHashTable { uint64_t maxBytes, char** rows) = 0; + /// Same as above, but only return rows from the row container of + /// 'rowContainerId'. + virtual int32_t listAllRows( + RowContainerIterator& rowContainerIt, + int rowContainerId, + int32_t maxRows, + uint64_t maxBytes, + char** rows) = 0; + /// Returns all rows with null keys. Used by null-aware joins (e.g. anti or /// left semi project). virtual int32_t listNullKeyRows( @@ -310,8 +363,22 @@ class BaseHashTable { virtual void prepareJoinTable( std::vector> tables, int8_t spillInputStartPartitionBit, + size_t vectorHasherMaxNumDistinct, + bool dropDuplicates = false, folly::Executor* executor = nullptr) = 0; + /// The hash table used for join build in left semi and anti join may not + /// retain duplicate join keys when allowDuplicates_ is false. This is + /// achieved by constructing the hash table in the addInput phase to eliminate + /// duplicate join keys. When the percentage of duplicate data is small, it + /// will adaptively adjust to not build the hash table in the addInput phase. + /// Instead, it operates like other join types by reading all the data before + /// building the hash table. This function is used to change the behavior of + /// building hash table, if allowDuplicates is true, the join hash table will + /// not be built during the addInput phase, and the input data will also not + /// be deduplicated, but it will not impact the containing row container. + virtual void setAllowDuplicates(bool allowDuplicates) = 0; + /// Returns the memory footprint in bytes for any data structures /// owned by 'this'. virtual int64_t allocatedBytes() const = 0; @@ -328,6 +395,9 @@ class BaseHashTable { /// side. This is used for sizing the internal hash table. virtual uint64_t numDistinct() const = 0; + /// Returns the number of row containers in this hash table. + virtual int32_t numRowContainers() const = 0; + /// Return a number of current stats that can help with debugging and /// profiling. virtual HashTableStats stats() const = 0; @@ -414,8 +484,7 @@ class BaseHashTable { __attribute__((__no_sanitize__("thread"))) #endif #endif - static TagVector - loadTags(uint8_t* tags, int64_t tagIndex) { + static TagVector loadTags(uint8_t* tags, int64_t tagIndex) { // Cannot use xsimd::batch::unaligned here because we need to skip TSAN. auto src = tags + tagIndex; #if XSIMD_WITH_SSE2 @@ -429,6 +498,10 @@ class BaseHashTable { return parallelJoinBuildStats_; } + const CpuWallTiming& vectorHasherMergeTiming() const { + return vectorHasherMergeTiming_; + } + /// Copies the values at 'columnIndex' into 'result' for the 'rows.size' rows /// pointed to by 'rows'. If an entry in 'rows' is null, sets corresponding /// row in 'result' to null. @@ -452,6 +525,7 @@ class BaseHashTable { std::unique_ptr rows_; ParallelJoinBuildStats parallelJoinBuildStats_; + CpuWallTiming vectorHasherMergeTiming_; }; FOLLY_ALWAYS_INLINE std::ostream& operator<<( @@ -484,8 +558,10 @@ class HashTable : public BaseHashTable { bool allowDuplicates, bool isJoinBuild, bool hasProbedFlag, + bool hasCountFlag, uint32_t minTableSizeForParallelJoinBuild, - memory::MemoryPool* pool); + memory::MemoryPool* pool, + uint64_t bloomFilterMaxSize = 0); ~HashTable() override = default; @@ -500,6 +576,7 @@ class HashTable : public BaseHashTable { false, // allowDuplicates false, // isJoinBuild false, // hasProbedFlag + false, // hasCountFlag 0, // minTableSizeForParallelJoinBuild pool); } @@ -509,8 +586,10 @@ class HashTable : public BaseHashTable { const std::vector& dependentTypes, bool allowDuplicates, bool hasProbedFlag, + bool hasCountFlag, uint32_t minTableSizeForParallelJoinBuild, - memory::MemoryPool* pool) { + memory::MemoryPool* pool, + uint64_t bloomFilterMaxSize = 0) { return std::make_unique( std::move(hashers), std::vector{}, @@ -518,8 +597,10 @@ class HashTable : public BaseHashTable { allowDuplicates, true, // isJoinBuild hasProbedFlag, + hasCountFlag, minTableSizeForParallelJoinBuild, - pool); + pool, + bloomFilterMaxSize); } void groupProbe(HashLookup& lookup, int8_t spillInputStartPartitionBit) @@ -552,6 +633,27 @@ class HashTable : public BaseHashTable { uint64_t maxBytes, char** rows) override; + int32_t listNotProbedRows( + RowContainerIterator& rowContainerIt, + int rowContainerId, + int32_t maxRows, + uint64_t maxBytes, + char** rows) override; + + int32_t listProbedRows( + RowContainerIterator& rowContainerIt, + int rowContainerId, + int32_t maxRows, + uint64_t maxBytes, + char** rows) override; + + int32_t listAllRows( + RowContainerIterator& rowContainerIt, + int rowContainerId, + int32_t maxRows, + uint64_t maxBytes, + char** rows) override; + int32_t listNullKeyRows( NullKeyRowsIterator* iter, int32_t maxRows, @@ -578,6 +680,10 @@ class HashTable : public BaseHashTable { return numDistinct_; } + int32_t numRowContainers() const override { + return otherTables_.size() + 1; + } + HashTableStats stats() const override { return HashTableStats{ capacity_, numRehashes_, numDistinct_, numTombstones_}; @@ -587,10 +693,27 @@ class HashTable : public BaseHashTable { return hasDuplicates_.check(); } + void setAllowDuplicates(const bool allowDuplicates) override { + allowDuplicates_ = allowDuplicates; + } + HashMode hashMode() const override { return hashMode_; } + void addRuntimeStats( + std::unordered_map& runtimeStats) + const override { + runtimeStats[std::string(kCapacity)] = RuntimeMetric(capacity_); + runtimeStats[std::string(kHashMode)] = + RuntimeMetric(static_cast(hashMode_)); + runtimeStats[std::string(kNumRehashes)] = RuntimeMetric(numRehashes_); + runtimeStats[std::string(kNumDistinct)] = RuntimeMetric(numDistinct_); + if (numTombstones_ != 0) { + runtimeStats[std::string(kNumTombstones)] = RuntimeMetric(numTombstones_); + } + } + void decideHashMode( int32_t numNew, int8_t spillInputStartPartitionBit, @@ -611,6 +734,8 @@ class HashTable : public BaseHashTable { void prepareJoinTable( std::vector> tables, int8_t spillInputStartPartitionBit, + size_t vectorHasherMaxNumDistinct, + bool dropDuplicates = false, folly::Executor* executor = nullptr) override; void prepareForJoinProbe( @@ -766,6 +891,14 @@ class HashTable : public BaseHashTable { int32_t listRows(RowsIterator* iter, int32_t maxRows, uint64_t maxBytes, char** rows); + template + int32_t listRows( + RowContainerIterator& rowContainerIt, + int rowContainerId, + int32_t maxRows, + uint64_t maxBytes, + char** rows); + char*& nextRow(char* row) { return *reinterpret_cast(row + nextOffset_); } @@ -846,7 +979,7 @@ class HashTable : public BaseHashTable { // to the end of 'overflows' in 'partitionInfo'. void insertForJoin( char** groups, - uint64_t* hashes, + const uint64_t* hashes, int32_t numGroups, TableInsertPartitionInfo* partitionInfo); @@ -854,7 +987,8 @@ class HashTable : public BaseHashTable { // contents in a RowContainer owned by 'this'. 'hashes' are the hash // numbers or array indices (if kArray mode) for each // group. 'groups' is expected to have no duplicate keys. - void insertForGroupBy(char** groups, uint64_t* hashes, int32_t numGroups); + void + insertForGroupBy(char** groups, const uint64_t* hashes, int32_t numGroups); // Checks if we can apply parallel table build optimization for hash join. // The function returns true if all of the following conditions: @@ -876,12 +1010,14 @@ class HashTable : public BaseHashTable { void parallelJoinBuild(); // Inserts the rows in 'partition' from this and 'otherTables' into 'this'. - // The rows that would have gone past the end of the partition are returned in - // 'overflow'. + // The rows that would have gone past the end of the partition are returned + // in 'overflow', along with their already-computed hashes in + // 'overflowHashes' so the caller can re-insert without re-hashing. void buildJoinPartition( uint8_t partition, const std::vector>& rowPartitions, - std::vector& overflow); + std::vector& overflow, + std::vector& overflowHashes); // Assigns a partition to each row of 'subtable' in RowPartitions of // subtable's RowContainer. If 'hashMode_' is kNormalizedKeys, records the @@ -890,6 +1026,20 @@ class HashTable : public BaseHashTable { HashTable& subtable, RowPartitions& rowPartitions); + // Whether we should build Bloom filters. If Bloom filter pushdown is + // enabled, and the size fits, this returns true if any of the key columns + // support it. The actual build should build a Bloom filter for each key + // column that supports it, and skip the ones that do not. + bool bloomFilterSupported() const; + + // Populate the Bloom filter for the key column with `columnIndex` and rows + // with certain `partition`. The partitions information is stored in + // `rowPartitions`. + void buildBloomFilterPartition( + column_index_t columnIndex, + uint8_t partition, + const std::vector>& rowPartitions); + // Calculates hashes for 'rows' and returns them in 'hashes'. If // 'initNormalizedKeys' is true, the normalized keys are stored below each row // in the container. If 'initNormalizedKeys' is false and the table is in @@ -950,7 +1100,7 @@ class HashTable : public BaseHashTable { template void insertForJoinWithPrefetch( char** groups, - uint64_t* hashes, + const uint64_t* hashes, int32_t numGroups, TableInsertPartitionInfo* partitionInfo); @@ -967,7 +1117,13 @@ class HashTable : public BaseHashTable { // or distinct mode VectorHashers in a group by hash table. 0 for // join build sides. int32_t reservePct() const { - return isJoinBuild_ ? 0 : 50; + return (isJoinBuild_ && allowDuplicates_) ? 0 : 50; + } + + // Used to indicate whether it is a HashTable that does not contain duplicate + // join keys. + bool joinBuildNoDuplicates() const { + return isJoinBuild_ && !allowDuplicates_; } // Returns the byte offset of the bucket for 'hash' starting from 'table_'. @@ -1023,6 +1179,11 @@ class HashTable : public BaseHashTable { } } + // Returns the i-th sub-table participating in parallel join build. + inline HashTable* tableAt(size_t idx) { + return idx == 0 ? this : otherTables_[idx - 1].get(); + } + // We don't want any overlap in the bit ranges used by bucket index and those // used by spill partitioning; otherwise because we receive data from only one // partition, the overlapped bits would be the same and only a fraction of the @@ -1035,8 +1196,11 @@ class HashTable : public BaseHashTable { // The min table size in row to trigger parallel join table build. const uint32_t minTableSizeForParallelJoinBuild_; + const uint64_t bloomFilterMaxSize_; + int8_t sizeBits_; bool isJoinBuild_ = false; + bool allowDuplicates_ = true; // Set at join build time if the table has duplicates, meaning that // the join can be cardinality increasing. Atomic for tsan because diff --git a/velox/exec/HashTableCache.cpp b/velox/exec/HashTableCache.cpp new file mode 100644 index 00000000000..dcecce4cc7e --- /dev/null +++ b/velox/exec/HashTableCache.cpp @@ -0,0 +1,146 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/HashTableCache.h" + +#include + +#include "velox/core/QueryCtx.h" +#include "velox/exec/MemoryReclaimer.h" + +namespace facebook::velox::exec { + +HashTableCache* HashTableCache::instance() { + static HashTableCache instance; + return &instance; +} + +std::shared_ptr HashTableCache::get( + const std::string& key, + const std::string& taskId, + core::QueryCtx* queryCtx, + ContinueFuture* future) { + VELOX_CHECK_NOT_NULL(future, "future parameter must not be null"); + VELOX_CHECK_NOT_NULL(queryCtx, "queryCtx parameter must not be null"); + + std::lock_guard guard(lock_); + + auto it = tables_.find(key); + if (it == tables_.end()) { + // No entry exists - create a placeholder for this task to build the table. + auto* queryPool = queryCtx->pool(); + auto entry = std::make_shared( + key, + taskId, + // Add memory reclaimer that is not reclaimable. + queryPool->addLeafChild( + fmt::format("cached_table_{}", key), + /* threadsafe */ true, + exec::MemoryReclaimer::create())); + tables_.insert({key, entry}); + + // Register callback to clean up this cache entry when QueryCtx is + // destroyed. This ensures tablePool memory is freed before the query + // pool is destroyed. + queryCtx->addReleaseCallback( + [cacheKey = key]() { HashTableCache::instance()->drop(cacheKey); }); + + // Return entry with pool, table will be filled later. + return entry; + } + + auto& entry = it->second; + + // Check if build is complete + if (entry->buildComplete) { + return entry; + } + + // If this is the builder task, don't wait - all drivers of the builder task + // should proceed to build (they coordinate via JoinBridge, not here). + if (entry->builderTaskId == taskId) { + return entry; + } + + auto [promise, _future] = + makeVeloxContinuePromiseContract(fmt::format("HashTableCache::{}", key)); + entry->buildPromises.push_back(std::move(promise)); + *future = std::move(_future); + + return entry; +} + +void HashTableCache::put( + const std::string& key, + std::shared_ptr table, + bool hasNullKeys) { + std::vector promises; + + { + std::lock_guard guard(lock_); + + auto it = tables_.find(key); + VELOX_CHECK( + it != tables_.end(), + "Cache entry for key '{}' must be created by get() before put()", + key); + + auto& entry = it->second; + VELOX_CHECK(!entry->buildComplete); + VELOX_CHECK_NULL(entry->table); + // Update the entry with the built table + entry->table = std::move(table); + entry->hasNullKeys = hasNullKeys; + entry->buildComplete = true; + + // Collect promises to notify waiters + promises = std::move(entry->buildPromises); + } + + // Notify all waiting tasks outside the lock + for (auto& promise : promises) { + promise.setValue(); + } +} + +void HashTableCache::drop(const std::string& key) { + std::shared_ptr entry; + std::vector promises; + { + std::lock_guard guard(lock_); + auto it = tables_.find(key); + if (it != tables_.end()) { + entry = std::move(it->second); + tables_.erase(it); + } + } + + // Clear the table outside the lock to free memory before the entry + // is destroyed. This ensures the tablePool's memory is released + // before any parent pools are destroyed. + if (entry) { + promises = std::move(entry->buildPromises); + entry->table.reset(); + } + + // Fulfill any pending build promises so waiting tasks are unblocked + // rather than hanging forever (e.g., after builder OOM). + for (auto& promise : promises) { + promise.setValue(); + } +} + +} // namespace facebook::velox::exec diff --git a/velox/exec/HashTableCache.h b/velox/exec/HashTableCache.h new file mode 100644 index 00000000000..e1d9717e89e --- /dev/null +++ b/velox/exec/HashTableCache.h @@ -0,0 +1,83 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +#include "velox/exec/HashTable.h" + +namespace facebook::velox::core { +class QueryCtx; +} + +namespace facebook::velox::exec { + +/// Cached hash table entry with build coordination metadata. +struct HashTableCacheEntry { + HashTableCacheEntry( + std::string _cacheKey, + std::string _builderTaskId, + std::shared_ptr _tablePool) + : cacheKey(std::move(_cacheKey)), + builderTaskId(std::move(_builderTaskId)), + tablePool(std::move(_tablePool)) {} + + const std::string cacheKey; + const std::string builderTaskId; + const std::shared_ptr tablePool; + std::shared_ptr table; + bool hasNullKeys{false}; + tsan_atomic buildComplete{false}; + std::vector buildPromises; +}; + +/// Global cache for hash tables shared across tasks within the same query. +/// First task builds the table, subsequent tasks wait and reuse it. +class HashTableCache { + public: + static HashTableCache* instance(); + + /// Gets or creates a cache entry. First caller becomes the builder. + /// Subsequent callers from different tasks get a future to wait on. + /// When a new entry is created, a release callback is registered on queryCtx + /// to clean up the entry when the query completes. + /// @param future Must be non-null; set if caller needs to wait. + std::shared_ptr get( + const std::string& key, + const std::string& taskId, + core::QueryCtx* queryCtx, + ContinueFuture* future); + + /// Stores a built hash table and notifies waiting tasks. + void put( + const std::string& key, + std::shared_ptr table, + bool hasNullKeys); + + /// Removes a cache entry. + void drop(const std::string& key); + + private: + HashTableCache() = default; + + std::mutex lock_; + std::unordered_map> tables_; +}; + +} // namespace facebook::velox::exec diff --git a/velox/exec/HilbertIndex.h b/velox/exec/HilbertIndex.h new file mode 100644 index 00000000000..9d8fd096f95 --- /dev/null +++ b/velox/exec/HilbertIndex.h @@ -0,0 +1,156 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// Based off of https://threadlocalmutex.com/?p=126 + +#pragma once + +#include +#include + +#include "velox/common/base/Exceptions.h" + +namespace facebook::velox::exec { + +class HilbertIndex { + public: + /// Construct a Hilber index. If a min value is greater than the max value, + /// this will panic. + HilbertIndex(float minX, float minY, float maxX, float maxY) + : minX_(minX), minY_(minY), maxX_(maxX), maxY_(maxY) { + VELOX_CHECK(minX_ <= maxX_); + VELOX_CHECK(minY_ <= maxY_); + + float deltaX = maxX_ - minX_; + // Subnormals cause numerical instability. + // NOLINTNEXTLINE(facebook-hte-FloatingPointMin) + if (deltaX < std::numeric_limits::min()) { + xScale_ = 0; + } else { + xScale_ = kHilbertMax / deltaX; + } + + float deltaY = maxY_ - minY_; + // Subnormals cause numerical instability. + // NOLINTNEXTLINE(facebook-hte-FloatingPointMin) + if (deltaY < std::numeric_limits::min()) { + yScale_ = 0; + } else { + yScale_ = kHilbertMax / deltaY; + } + } + + uint32_t inline indexOf(float x, float y) const { + if (!(x >= minX_ && x <= maxX_ && y >= minY_ && y <= maxY_)) { + // Put things outside the bounds at the end of the Hilbert curve. + // Negation handles NaNs + return std::numeric_limits::max(); + } + + float maxFloat = static_cast(std::numeric_limits::max()); + + uint32_t xInt = static_cast( + std::clamp(xScale_ * (x - minX_), 0.0f, maxFloat)); + uint32_t yInt = static_cast( + std::clamp(yScale_ * (y - minY_), 0.0f, maxFloat)); + return discreteIndexOf(xInt, yInt); + } + + private: + static inline uint32_t interleave(uint32_t x) { + x = (x | (x << 8)) & 0x00FF00FF; + x = (x | (x << 4)) & 0x0F0F0F0F; + x = (x | (x << 2)) & 0x33333333; + x = (x | (x << 1)) & 0x55555555; + return x; + } + + static inline uint32_t discreteIndexOf(uint32_t x, uint32_t y) { + uint32_t A, B, C, D; + + // Initial prefix scan round, prime with x and y + { + uint32_t a = x ^ y; + uint32_t b = 0xFFFF ^ a; + uint32_t c = 0xFFFF ^ (x | y); + uint32_t d = x & (y ^ 0xFFFF); + + A = a | (b >> 1); + B = (a >> 1) ^ a; + + C = ((c >> 1) ^ (b & (d >> 1))) ^ c; + D = ((a & (c >> 1)) ^ (d >> 1)) ^ d; + } + + { + uint32_t a = A; + uint32_t b = B; + uint32_t c = C; + uint32_t d = D; + + A = ((a & (a >> 2)) ^ (b & (b >> 2))); + B = ((a & (b >> 2)) ^ (b & ((a ^ b) >> 2))); + + C ^= ((a & (c >> 2)) ^ (b & (d >> 2))); + D ^= ((b & (c >> 2)) ^ ((a ^ b) & (d >> 2))); + } + + { + uint32_t a = A; + uint32_t b = B; + uint32_t c = C; + uint32_t d = D; + + A = ((a & (a >> 4)) ^ (b & (b >> 4))); + B = ((a & (b >> 4)) ^ (b & ((a ^ b) >> 4))); + + C ^= ((a & (c >> 4)) ^ (b & (d >> 4))); + D ^= ((b & (c >> 4)) ^ ((a ^ b) & (d >> 4))); + } + + // Final round and projection + { + uint32_t a = A; + uint32_t b = B; + uint32_t c = C; + uint32_t d = D; + + C ^= ((a & (c >> 8)) ^ (b & (d >> 8))); + D ^= ((b & (c >> 8)) ^ ((a ^ b) & (d >> 8))); + } + + // Undo transformation prefix scan + uint32_t a = C ^ (C >> 1); + uint32_t b = D ^ (D >> 1); + + // Recover index bits + uint32_t i0 = x ^ y; + uint32_t i1 = b | (0xFFFF ^ (i0 | a)); + + return (interleave(i1) << 1) | interleave(i0); + } + + static const int8_t kHilbertBits = 16; + static constexpr float kHilbertMax = (1 << kHilbertBits) - 1; + + const float minX_; + const float minY_; + const float maxX_; + const float maxY_; + float xScale_; + float yScale_; +}; + +} // namespace facebook::velox::exec diff --git a/velox/exec/IndexLookupJoin.cpp b/velox/exec/IndexLookupJoin.cpp index f0e437ec713..982f86fefbb 100644 --- a/velox/exec/IndexLookupJoin.cpp +++ b/velox/exec/IndexLookupJoin.cpp @@ -16,14 +16,52 @@ #include "velox/exec/IndexLookupJoin.h" #include "velox/buffer/Buffer.h" +#include "velox/common/base/RuntimeMetrics.h" +#include "velox/common/testutil/TestValue.h" #include "velox/connectors/Connector.h" +#include "velox/connectors/ConnectorRegistry.h" +#include "velox/core/QueryConfig.h" +#include "velox/exec/OperatorTraceWriter.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/OperatorUtils.h" #include "velox/exec/Task.h" +#include "velox/exec/trace/TraceUtil.h" #include "velox/expression/Expr.h" #include "velox/expression/FieldReference.h" +#include "velox/vector/LazyVector.h" + +using facebook::velox::common::testutil::TestValue; namespace facebook::velox::exec { +using IndexSource = connector::IndexSource; + +void IndexLookupJoin::IndexStatWriter::addRuntimeStat( + std::string_view name, + const RuntimeCounter& value) { + auto lockedStats = runtimeStats_.wlock(); + auto it = lockedStats->find(std::string(name)); + if (it != lockedStats->end()) { + it->second.addValue(value.value); + } else { + RuntimeMetric metric(value.unit); + metric.addValue(value.value); + lockedStats->emplace(std::string(name), std::move(metric)); + } +} + +void IndexLookupJoin::IndexStatWriter::setRuntimeStat( + const std::string& name, + const RuntimeMetric& metric) { + runtimeStats_.wlock()->insert_or_assign(name, metric); +} + +std::unordered_map +IndexLookupJoin::IndexStatWriter::runtimeStats() const { + return *runtimeStats_.rlock(); +} + namespace { + void duplicateJoinKeyCheck( const std::vector& keys) { folly::F14FastSet lookupKeyNames; @@ -60,6 +98,36 @@ void addLookupInputColumn( lookupInputNameSet.insert(columnName); } +// Normalizes all join conditions into a unified representation by converting +// equi-join keys (leftKeys/rightKeys) to EqualIndexLookupCondition objects. +// Each leftKey/rightKey pair is converted to an EqualIndexLookupCondition +// where: +// - key: the index column expression (from rightKeys) +// - value: the probe column expression (from leftKeys) +// The resulting vector contains the converted equi-join conditions followed by +// the original joinConditions. +std::vector getJoinConditions( + const std::vector& leftKeys, + const std::vector& rightKeys, + const std::vector& joinConditions) { + VELOX_CHECK_EQ(leftKeys.size(), rightKeys.size()); + + std::vector normalizedConditions; + normalizedConditions.reserve(leftKeys.size() + joinConditions.size()); + + for (size_t i = 0; i < leftKeys.size(); ++i) { + normalizedConditions.push_back( + std::make_shared( + rightKeys[i], leftKeys[i])); + } + + for (const auto& condition : joinConditions) { + normalizedConditions.push_back(condition); + } + + return normalizedConditions; +} + // Validates one of between bound, and update the lookup input channels and type // to include the corresponding probe input column if the bound is not constant. bool addBetweenConditionBound( @@ -86,8 +154,9 @@ bool addBetweenConditionBound( lookupInputChannels, lookupInputNameSet); } else { - VELOX_USER_CHECK(core::TypedExprs::asConstant(typeExpr)->type()->equivalent( - *indexKeyType)); + VELOX_USER_CHECK( + core::TypedExprs::asConstant(typeExpr)->type()->equivalent( + *indexKeyType)); } return isConstant; } @@ -127,8 +196,102 @@ void addBetweenCondition( "At least one of the between condition bounds needs to be not constant: {}", betweenCondition->toString()); } + +// Create a row vector wrapper without allocating any buffer. +// We expect the child vectors to be directly set from probe inputs and join +// outputs. +inline RowVectorPtr createRowVector( + velox::memory::MemoryPool* pool, + const RowTypePtr& type, + vector_size_t numRows) { + std::vector children(type->size(), nullptr); + return std::make_shared( + pool, type, nullptr, numRows, std::move(children)); +} + +// Extracts a runtime stat from the map, removes it, and returns its sum value. +// Returns 0 if the stat is not found. +int64_t extractStatSum( + std::unordered_map& stats, + const std::string& name) { + auto it = stats.find(name); + if (it == stats.end()) { + return 0; + } + const auto value = it->second.sum; + stats.erase(it); + return value; +} + } // namespace +// static +std::vector IndexLookupJoin::splitStats( + const OperatorStats& combinedStats, + const core::PlanNodeId& indexSourceNodeId, + const IndexStatWriter& indexSourceStatWriter) { + // Create stats for the IndexSource node from the accumulated index source + // runtime stats. + OperatorStats indexSourceStats; + indexSourceStats.operatorId = combinedStats.operatorId; + indexSourceStats.pipelineId = combinedStats.pipelineId; + indexSourceStats.planNodeId = indexSourceNodeId; + indexSourceStats.operatorType = "IndexSource"; + indexSourceStats.numDrivers = combinedStats.numDrivers; + + // Populate IndexSource runtime stats from the writer's accumulated map. + indexSourceStats.runtimeStats = indexSourceStatWriter.runtimeStats(); + + // Extract standard operator stats from runtime stats into OperatorStats + // fields so they show up in PlanNodeStats (inputRows, outputRows, etc.). + indexSourceStats.outputPositions = + extractStatSum(indexSourceStats.runtimeStats, "outputPositions"); + indexSourceStats.outputBytes = + extractStatSum(indexSourceStats.runtimeStats, "outputBytes"); + indexSourceStats.outputVectors = + extractStatSum(indexSourceStats.runtimeStats, "outputVectors"); + indexSourceStats.inputPositions = + extractStatSum(indexSourceStats.runtimeStats, "inputPositions"); + indexSourceStats.inputBytes = + extractStatSum(indexSourceStats.runtimeStats, "inputBytes"); + + // Populate IndexSource addInputTiming from lookup timing stats recorded by + // recordConnectorStats into the writer. + const auto lookupCount = + extractStatSum(indexSourceStats.runtimeStats, "lookupCount"); + const auto lookupWallNanos = + extractStatSum(indexSourceStats.runtimeStats, "lookupWallNanos"); + const auto lookupCpuNanos = + extractStatSum(indexSourceStats.runtimeStats, "lookupCpuNanos"); + if (lookupCount > 0) { + indexSourceStats.addInputTiming = CpuWallTiming{ + static_cast(lookupCount), + static_cast(lookupWallNanos), + static_cast(lookupCpuNanos)}; + } + + // Create stats for the IndexLookupJoin node. + auto joinStats = combinedStats; + + // Remove residual probe-side lazy loading stats from join stats. These + // accumulate in every non-scan operator's runtimeStats and are normally + // harmless, but splitStats copies combinedStats to create the join node + // entry, so we must erase them to avoid exposing them on the join node. + joinStats.runtimeStats.erase(std::string(LazyVector::kInputBytes)); + joinStats.runtimeStats.erase(std::string(LazyVector::kCpuNanos)); + joinStats.runtimeStats.erase(std::string(LazyVector::kWallNanos)); + + // Remove index source stats from join stats. recordConnectorStats copies + // these into the operator's runtimeStats (prefixed with "indexSource.") so + // they flow through the task-level stats path, but they belong to the + // IndexSource node, not the join node. + for (const auto& [name, _] : indexSourceStatWriter.runtimeStats()) { + joinStats.runtimeStats.erase(fmt::format("indexSource.{}", name)); + } + + return {std::move(joinStats), std::move(indexSourceStats)}; +} + IndexLookupJoin::IndexLookupJoin( int32_t operatorId, DriverCtx* driverCtx, @@ -138,7 +301,8 @@ IndexLookupJoin::IndexLookupJoin( joinNode->outputType(), operatorId, joinNode->id(), - "IndexLookupJoin"), + OperatorType::kIndexLookupJoin), + splitOutput_{driverCtx->queryConfig().indexLookupJoinSplitOutput()}, // TODO: support to update output batch size with output size stats during // the lookup processing. outputBatchSize_{ @@ -146,12 +310,15 @@ IndexLookupJoin::IndexLookupJoin( ? outputBatchRows() : std::numeric_limits::max()}, joinType_{joinNode->joinType()}, - includeMatchColumn_(joinNode->includeMatchColumn()), - numKeys_{joinNode->leftKeys().size()}, + hasMarker_(joinNode->hasMarker()), probeType_{joinNode->sources()[0]->outputType()}, lookupType_{joinNode->lookupSource()->outputType()}, + indexSourceNodeId_(joinNode->lookupSource()->id()), lookupTableHandle_{joinNode->lookupSource()->tableHandle()}, - lookupConditions_{joinNode->joinConditions()}, + joinConditions_{getJoinConditions( + joinNode->leftKeys(), + joinNode->rightKeys(), + joinNode->joinConditions())}, lookupColumnHandles_(joinNode->lookupSource()->assignments()), connectorQueryCtx_{operatorCtx_->createConnectorQueryCtx( lookupTableHandle_->connectorId(), @@ -163,12 +330,29 @@ IndexLookupJoin::IndexLookupJoin( operatorType(), lookupTableHandle_->connectorId()), spillConfig_.has_value() ? &(spillConfig_.value()) : nullptr)}, - connector_(connector::getConnector(lookupTableHandle_->connectorId())), + connector_( + connector::ConnectorRegistry::tryGet( + *driverCtx->task->queryCtx(), + lookupTableHandle_->connectorId())), maxNumInputBatches_( 1 + driverCtx->queryConfig().indexLookupJoinMaxPrefetchBatches()), - joinNode_{joinNode} { + isIndexSplitCollector_{driverCtx->partitionId == 0}, + joinNode_{joinNode}, + indexStatWriter_(std::make_shared()) { duplicateJoinKeyCheck(joinNode_->leftKeys()); duplicateJoinKeyCheck(joinNode_->rightKeys()); + + // Set up the stat splitter to report separate OperatorStats for + // IndexLookupJoin and IndexSource nodes. Must be done in the constructor + // (not initialize()) because operator stats are copied to task stats before + // initialize() is called. + stats_.withWLock([&](auto& stats) { + stats.setStatSplitter([indexSourceId = indexSourceNodeId_, + indexSourceStatWriter = + indexStatWriter_](const auto& combinedStats) { + return splitStats(combinedStats, indexSourceId, *indexSourceStatWriter); + }); + }); } void IndexLookupJoin::initialize() { @@ -184,15 +368,56 @@ void IndexLookupJoin::initialize() { initLookupInput(); initLookupOutput(); initOutputProjections(); + initFilter(); indexSource_ = connector_->createIndexSource( lookupInputType_, - numKeys_, - lookupConditions_, + joinConditions_, lookupOutputType_, lookupTableHandle_, lookupColumnHandles_, connectorQueryCtx_.get()); + + if (lookupTableHandle_->needsIndexSplit()) { + auto* driverCtx = operatorCtx_->driverCtx(); + joinBridge_ = driverCtx->task->getIndexLookupJoinBridge( + driverCtx->splitGroupId, planNodeId()); + VELOX_CHECK_NOT_NULL(joinBridge_); + } + + createIndexSplitTracer(); +} + +void IndexLookupJoin::createIndexSplitTracer() { + if (inputTracer_ == nullptr) { + return; + } + const auto& queryConfig = operatorCtx_->driverCtx()->queryConfig(); + const auto taskTraceDir = exec::trace::getTaskTraceDirectory( + queryConfig.queryTraceDir(), + operatorCtx_->driverCtx()->task->queryCtx()->queryId(), + operatorCtx_->taskId()); + const auto opTraceDir = exec::trace::getOpTraceDirectory( + taskTraceDir, + planNodeId(), + operatorCtx_->driverCtx()->pipelineId, + operatorCtx_->driverCtx()->driverId); + indexSplitTracer_ = + std::make_unique(this, opTraceDir); +} + +void IndexLookupJoin::traceIndexSplit(const exec::Split& split) { + if (FOLLY_UNLIKELY(indexSplitTracer_ != nullptr)) { + auto connectorSplit = split.connectorSplit; + indexSplitTracer_->write(exec::Split(std::move(connectorSplit))); + } +} + +void IndexLookupJoin::closeIndexSplitTracer() { + // NOTE: skip calling finish() because the input tracer (via Operator::close) + // already wrote the summary file. Both tracers share the same summary path + // and finish() would fail with O_EXCL. + indexSplitTracer_.reset(); } void IndexLookupJoin::ensureInputLoaded(const InputBatchState& batch) { @@ -215,14 +440,13 @@ void IndexLookupJoin::initLookupInput() { VELOX_CHECK(lookupInputChannels_.empty()); std::vector lookupInputNames; - lookupInputNames.reserve(numKeys_ + lookupConditions_.size()); + lookupInputNames.reserve(joinConditions_.size()); std::vector lookupInputTypes; - lookupInputTypes.reserve(numKeys_ + lookupConditions_.size()); - lookupInputChannels_.reserve(numKeys_ + lookupConditions_.size()); + lookupInputTypes.reserve(joinConditions_.size()); + lookupInputChannels_.reserve(joinConditions_.size()); SCOPE_EXIT { - VELOX_CHECK_GE( - lookupInputNames.size(), numKeys_ + lookupConditions_.size()); + VELOX_CHECK_GE(lookupInputNames.size(), joinConditions_.size()); VELOX_CHECK_EQ(lookupInputNames.size(), lookupInputChannels_.size()); lookupInputType_ = ROW(std::move(lookupInputNames), std::move(lookupInputTypes)); @@ -231,25 +455,6 @@ void IndexLookupJoin::initLookupInput() { folly::F14FastSet lookupInputColumnSet; folly::F14FastSet lookupIndexColumnSet; - // List probe columns used in join-equi caluse first. - for (auto keyIdx = 0; keyIdx < numKeys_; ++keyIdx) { - const auto probeKeyName = joinNode_->leftKeys()[keyIdx]->name(); - const auto indexKeyName = joinNode_->rightKeys()[keyIdx]->name(); - VELOX_USER_CHECK_EQ(lookupIndexColumnSet.count(indexKeyName), 0); - lookupIndexColumnSet.insert(indexKeyName); - const auto probeKeyChannel = probeType_->getChildIdx(probeKeyName); - const auto probeKeyType = probeType_->childAt(probeKeyChannel); - VELOX_USER_CHECK( - lookupType_->findChild(indexKeyName)->equivalent(*probeKeyType)); - addLookupInputColumn( - indexKeyName, - probeKeyType, - probeKeyChannel, - lookupInputNames, - lookupInputTypes, - lookupInputChannels_, - lookupInputColumnSet); - } SCOPE_EXIT { VELOX_CHECK(lookupKeyOrConditionHashers_.empty()); @@ -257,19 +462,42 @@ void IndexLookupJoin::initLookupInput() { lookupKeyOrConditionHashers_ = createVectorHashers(probeType_, lookupInputChannels_); }; - if (lookupConditions_.empty()) { - return; - } - for (const auto& lookupCondition : lookupConditions_) { - const auto indexKeyName = getColumnName(lookupCondition->key); + for (const auto& condition : joinConditions_) { + const auto indexKeyName = getColumnName(condition->key); VELOX_USER_CHECK_EQ(lookupIndexColumnSet.count(indexKeyName), 0); lookupIndexColumnSet.insert(indexKeyName); const auto indexKeyType = lookupType_->findChild(indexKeyName); + if (const auto equalCondition = + std::dynamic_pointer_cast( + condition)) { + VELOX_CHECK( + !equalCondition->isFilter(), + "Constant equal condition in join not supported"); + // Process as a join condition - value references a probe column. + const auto probeKeyName = getColumnName(equalCondition->value); + const auto probeKeyChannel = probeType_->getChildIdx(probeKeyName); + const auto probeKeyType = probeType_->childAt(probeKeyChannel); + VELOX_USER_CHECK( + indexKeyType->equivalent(*probeKeyType), + "Index key type {} must be equivalent to probe key type {}", + indexKeyType->toString(), + probeKeyType->toString()); + addLookupInputColumn( + probeKeyName, + probeKeyType, + probeKeyChannel, + lookupInputNames, + lookupInputTypes, + lookupInputChannels_, + lookupInputColumnSet); + continue; + } + if (const auto inCondition = std::dynamic_pointer_cast( - lookupCondition)) { + condition)) { const auto conditionInputName = getColumnName(inCondition->list); const auto conditionInputChannel = probeType_->getChildIdx(conditionInputName); @@ -286,11 +514,12 @@ void IndexLookupJoin::initLookupInput() { lookupInputTypes, lookupInputChannels_, lookupInputColumnSet); + continue; } if (const auto betweenCondition = std::dynamic_pointer_cast( - lookupCondition)) { + condition)) { addBetweenCondition( betweenCondition, probeType_, @@ -299,21 +528,10 @@ void IndexLookupJoin::initLookupInput() { lookupInputTypes, lookupInputChannels_, lookupInputColumnSet); + continue; } - if (const auto equalCondition = - std::dynamic_pointer_cast( - lookupCondition)) { - // Process an equal join condition by validating that the value is - // constant. Equal conditions only support constant values for filtering. - VELOX_USER_CHECK( - core::TypedExprs::isConstant(equalCondition->value), - "Equal condition value must be constant: {}", - equalCondition->toString()); - VELOX_USER_CHECK(core::TypedExprs::asConstant(equalCondition->value) - ->type() - ->equivalent(*indexKeyType)); - } + VELOX_UNSUPPORTED("Unsupported join condition type"); } } @@ -371,15 +589,129 @@ void IndexLookupJoin::initOutputProjections() { } lookupOutputProjections_.emplace_back(i, outputChannelOpt.value()); } - if (includeMatchColumn_) { + if (hasMarker_) { matchOutputChannel_ = outputType_->size() - 1; } + VELOX_USER_CHECK_EQ( probeOutputProjections_.size() + lookupOutputProjections_.size() + !!matchOutputChannel_.has_value(), outputType_->size()); } +void IndexLookupJoin::initFilter() { + VELOX_CHECK_NULL(filter_); + + if (joinNode_->filter() == nullptr) { + return; + } + + std::vector filters = {joinNode_->filter()}; + filter_ = + std::make_unique(std::move(filters), operatorCtx_->execCtx()); + + std::vector names; + std::vector types; + const auto numFields = filter_->expr(0)->distinctFields().size(); + names.reserve(numFields); + types.reserve(numFields); + + column_index_t filterChannel{0}; + const auto addChannel = [&](column_index_t channel, + const RowTypePtr& inputType, + std::vector& projections) { + names.emplace_back(inputType->nameOf(channel)); + types.emplace_back(inputType->childAt(channel)); + projections.emplace_back(channel, filterChannel++); + }; + + for (const auto& field : filter_->expr(0)->distinctFields()) { + const auto& name = field->field(); + auto channel = probeType_->getChildIdxIfExists(name); + if (channel.has_value()) { + addChannel(channel.value(), probeType_, filterProbeInputProjections_); + continue; + } + channel = lookupOutputType_->getChildIdxIfExists(name); + if (channel.has_value()) { + addChannel( + channel.value(), lookupOutputType_, filterLookupInputProjections_); + continue; + } + VELOX_FAIL( + "Index lookup join filter field not found in either left or right input: {}", + field->toString()); + } + + filterInputType_ = ROW(std::move(names), std::move(types)); +} + +bool IndexLookupJoin::collectIndexSplits(ContinueFuture* future) { + VELOX_CHECK(needsIndexSplits()); + + TestValue::adjust( + "facebook::velox::exec::IndexLookupJoin::collectIndexSplits", this); + + auto* driverCtx = operatorCtx_->driverCtx(); + while (true) { + exec::Split split; + const auto reason = driverCtx->task->getSplitOrFuture( + driverCtx->driverId, + driverCtx->splitGroupId, + indexSourceNodeId_, + /*maxPreloadSplits=*/0, + /*preload=*/nullptr, + split, + indexSplitFuture_); + if (reason != BlockingReason::kNotBlocked) { + *future = std::move(indexSplitFuture_); + return false; + } + + if (!split.hasConnectorSplit()) { + noMoreIndexSplits_ = true; + VELOX_CHECK(!hasNoIndexSplits_); + VELOX_CHECK_NOT_NULL(joinBridge_); + joinBridge_->setIndexSplits(indexSplits_); + { + auto lockedStats = stats_.wlock(); + lockedStats->addRuntimeStat( + kNumIndexSplits, + RuntimeCounter( + static_cast(indexSplits_.size()), + RuntimeCounter::Unit::kNone)); + } + if (indexSplits_.empty()) { + hasNoIndexSplits_ = true; + } else { + indexSource_->addSplits(std::move(indexSplits_)); + } + return true; + } + + traceIndexSplit(split); + indexSplits_.push_back(std::move(split.connectorSplit)); + } +} + +bool IndexLookupJoin::waitForIndexSplits(ContinueFuture* future) { + VELOX_CHECK(needsIndexSplits()); + VELOX_CHECK(!isIndexSplitCollector_); + VELOX_CHECK(indexSplits_.empty()); + + auto splits = joinBridge_->splitsOrFuture(future); + if (future->valid()) { + return false; + } + noMoreIndexSplits_ = true; + if (splits.empty()) { + hasNoIndexSplits_ = true; + } else { + indexSource_->addSplits(std::move(splits)); + } + return true; +} + bool IndexLookupJoin::startDrain() { return numInputBatches() != 0; } @@ -388,6 +720,13 @@ bool IndexLookupJoin::needsInput() const { if (noMoreInput_ || isDraining()) { return false; } + // Don't accept input until we have collected all splits for index source. + if (needsIndexSplits()) { + return false; + } + if (shouldSkipInput()) { + return false; + } if (numInputBatches() >= maxNumInputBatches_) { return false; } @@ -402,6 +741,21 @@ bool IndexLookupJoin::needsInput() const { } BlockingReason IndexLookupJoin::isBlocked(ContinueFuture* future) { + // Handle split collection for index sources that require splits. + if (needsIndexSplits()) { + if (isIndexSplitCollector_) { + if (!collectIndexSplits(future)) { + VELOX_CHECK(future->valid()); + return BlockingReason::kWaitForSplit; + } + } else { + if (!waitForIndexSplits(future)) { + VELOX_CHECK(future->valid()); + return BlockingReason::kWaitForIndexSplits; + } + } + } + auto& batch = currentInputBatch(); if (!batch.lookupFuture.valid()) { endLookupBlockWait(); @@ -438,16 +792,33 @@ void IndexLookupJoin::endLookupBlockWait() { void IndexLookupJoin::addInput(RowVectorPtr input) { VELOX_CHECK_GT(input->size(), 0); + VELOX_CHECK(!shouldSkipInput()); auto& batch = nextInputBatch(); VELOX_CHECK_LE(numInputBatches(), maxNumInputBatches_); batch.input = std::move(input); + // Probe-side lazy loading happens here — stats go to kInputBytes (standard + // names) and are transferred to the scan by Driver::processLazyIoStats(). ensureInputLoaded(batch); decodeAndDetectNonNullKeys(batch); prepareLookup(batch); - startLookup(batch); + recordIndexSourceInputStats(batch); + // startLookup may trigger index-side lazy loading via mergeLookupResults() + // when sync lookups return multiple partial results. Redirect those stats to + // index-specific names. + { + RuntimeStatWriterScopeGuard guard(indexStatWriter_.get()); + startLookup(batch); + } } RowVectorPtr IndexLookupJoin::getOutput() { + // Redirect lazy loading stats during getOutput() to index-specific names. + // This separates index-side lazy loading (loading lookup result vectors) from + // probe-side lazy loading (loading scan input vectors in addInput()), so that + // Driver::processLazyIoStats() correctly attributes only probe-side stats to + // the scan operator. + RuntimeStatWriterScopeGuard guard(indexStatWriter_.get()); + SCOPE_EXIT { if (numInputBatches() == 0 && isDraining()) { finishDrain(); @@ -478,14 +849,8 @@ void IndexLookupJoin::prepareLookup(InputBatchState& batch) { const size_t numLookupRows = batch.lookupInputHasNullKeys ? batch.nonNullInputRows.countSelected() : batch.input->size(); - if (batch.lookupInput == nullptr) { - batch.lookupInput = - BaseVector::create(lookupInputType_, numLookupRows, pool()); - } else { - VectorPtr lookupInputVector = std::move(batch.lookupInput); - BaseVector::prepareForReuse(lookupInputVector, numLookupRows); - batch.lookupInput = std::static_pointer_cast(lookupInputVector); - } + batch.lookupInput = createRowVector( + pool(), lookupInputType_, static_cast(numLookupRows)); if (!batch.lookupInputHasNullKeys) { for (auto i = 0; i < lookupInputType_->size(); ++i) { @@ -525,6 +890,110 @@ void IndexLookupJoin::prepareLookup(InputBatchState& batch) { } } +void IndexLookupJoin::mergeLookupResults(InputBatchState& batch) { + VELOX_CHECK(!batch.partialOutputs.empty()); + VELOX_CHECK_NULL(batch.lookupResult); + + if (batch.partialOutputs.size() == 1) { + batch.lookupResult = std::move(batch.partialOutputs[0]); + SCOPE_EXIT { + VELOX_CHECK_NOT_NULL(batch.lookupResult); + batch.partialOutputs.clear(); + }; + return; + } + + // Calculate total size. + vector_size_t totalSize = 0; + for (const auto& result : batch.partialOutputs) { + totalSize += static_cast(result->size()); + } + + // Merge inputHits buffers. + auto mergedInputHits = allocateIndices(totalSize, pool()); + auto* rawMergedInputHits = mergedInputHits->asMutable(); + vector_size_t offset = 0; + for (const auto& result : batch.partialOutputs) { + std::memcpy( + rawMergedInputHits + offset, + result->inputHits->as(), + result->size() * sizeof(vector_size_t)); + offset += static_cast(result->size()); + } + + // Merge output RowVectors. + // NOTE: Uncommon path for connectors that do not respect output batch size + // properly + auto mergedOutput = BaseVector::create( + batch.partialOutputs[0]->output->type(), totalSize, pool()); + vector_size_t outputOffset = 0; + for (const auto& result : batch.partialOutputs) { + mergedOutput->copy(result->output.get(), outputOffset, 0, result->size()); + outputOffset += static_cast(result->size()); + } + + batch.lookupResult = std::make_unique( + std::move(mergedInputHits), std::move(mergedOutput)); + batch.partialOutputs.clear(); +} + +bool IndexLookupJoin::getLookupResults(InputBatchState& batch) { + VELOX_CHECK_NOT_NULL(batch.lookupInput); + VELOX_CHECK_NOT_NULL(batch.lookupResultIter); + VELOX_CHECK(!batch.lookupFuture.valid()); + + // Result is ready. + if (batch.lookupResult != nullptr) { + return true; + } + + // Fetch the first result if not already fetched. + if (batch.lookupResult == nullptr && batch.partialOutputs.empty()) { + auto lookupResultOr = + batch.lookupResultIter->next(outputBatchSize_, batch.lookupFuture); + if (!lookupResultOr.has_value()) { + VELOX_CHECK(batch.lookupFuture.valid()); + return false; + } + VELOX_CHECK(!batch.lookupFuture.valid()); + + // Either splitOutput_ is true, or no more results, or first result is null. + if (splitOutput_ || !batch.lookupResultIter->hasNext()) { + batch.lookupResult = std::move(lookupResultOr).value(); + return true; + } + + // Otherwise start accumulating results. + batch.partialOutputs.push_back(std::move(lookupResultOr).value()); + } + + // Continue accumulating remaining results when splitOutput_ is false. + // This handles both initial accumulation and resuming after async + // interruption. + VELOX_CHECK(!splitOutput_); + VELOX_CHECK(!batch.partialOutputs.empty()); + VELOX_CHECK_NULL(batch.lookupResult); + + while (batch.lookupResultIter->hasNext()) { + auto nextResultOr = + batch.lookupResultIter->next(outputBatchSize_, batch.lookupFuture); + if (!nextResultOr.has_value()) { + // Need to wait for async operation. + VELOX_CHECK(batch.lookupFuture.valid()); + return false; + } + VELOX_CHECK(!batch.lookupFuture.valid()); + auto nextResult = std::move(nextResultOr).value(); + if (nextResult != nullptr) { + batch.partialOutputs.push_back(std::move(nextResult)); + } + } + + // All results accumulated, merge them. + mergeLookupResults(batch); + return true; +} + void IndexLookupJoin::decodeAndDetectNonNullKeys(InputBatchState& batch) { const auto numRows = batch.input->size(); batch.nonNullInputRows.resize(numRows); @@ -555,60 +1024,55 @@ void IndexLookupJoin::startLookup(InputBatchState& batch) { VELOX_CHECK_NULL(batch.lookupResult); VELOX_CHECK(!batch.lookupFuture.valid()); - if (batch.lookupInput->size() == 0) { - // No need to start lookup for empty lookup input. + if (shouldSkipLookup(batch)) { return; } - batch.lookupResultIter = indexSource_->lookup( - connector::IndexSource::LookupRequest{batch.lookupInput}); - auto lookupResultOr = - batch.lookupResultIter->next(outputBatchSize_, batch.lookupFuture); - if (!lookupResultOr.has_value()) { - VELOX_CHECK(batch.lookupFuture.valid()); - return; - } - VELOX_CHECK(!batch.lookupFuture.valid()); - batch.lookupResult = std::move(lookupResultOr).value(); + // Create the lookup result iterator. + batch.lookupResultIter = + indexSource_->lookup(connector::IndexSource::Request{batch.lookupInput}); + + getLookupResults(batch); } RowVectorPtr IndexLookupJoin::getOutputFromLookupResult( InputBatchState& batch) { VELOX_CHECK(!batch.empty()); + VELOX_CHECK(!shouldSkipInput()); VELOX_CHECK(!batch.lookupFuture.valid() || batch.lookupFuture.isReady()); batch.lookupFuture = ContinueFuture::makeEmpty(); - if (batch.lookupInput->size() == 0) { - if (hasRemainingOutputForLeftJoin(batch)) { - return produceRemainingOutputForLeftJoin(batch); - } - finishInput(batch); + if (shouldSkipLookup(batch)) { + return produceRemainingOutput(batch); + } + + if (!getLookupResults(batch)) { + // Async operation pending, need to wait. + VELOX_CHECK(batch.lookupFuture.valid()); return nullptr; } - VELOX_CHECK_NOT_NULL(batch.lookupResultIter); + VELOX_CHECK(!batch.lookupFuture.valid()); + VELOX_CHECK(batch.partialOutputs.empty()); if (batch.lookupResult == nullptr) { - auto resultOptional = - batch.lookupResultIter->next(outputBatchSize_, batch.lookupFuture); - if (!resultOptional.has_value()) { - VELOX_CHECK(batch.lookupFuture.valid()); - return nullptr; - } - VELOX_CHECK(!batch.lookupFuture.valid()); - - batch.lookupResult = std::move(resultOptional).value(); - if (batch.lookupResult == nullptr) { - if (hasRemainingOutputForLeftJoin(batch)) { - return produceRemainingOutputForLeftJoin(batch); - } - finishInput(batch); - return nullptr; + if (hasRemainingOutputForLeftJoin(batch)) { + return produceRemainingOutputForLeftJoin(batch); } + finishInput(batch); + return nullptr; } + prepareLookupResult(batch); VELOX_CHECK_NOT_NULL(batch.lookupResult); + if (!applyFilterOnLookupResult(batch)) { + VELOX_CHECK_NULL(batch.lookupResult); + // All rows in lookup result are filtered out, and fetch next lookup result + // batch. + return nullptr; + } + SCOPE_EXIT { maybeFinishLookupResult(batch); }; @@ -618,8 +1082,52 @@ RowVectorPtr IndexLookupJoin::getOutputFromLookupResult( return produceOutputForLeftJoin(batch); } +RowVectorPtr IndexLookupJoin::produceRemainingOutput(InputBatchState& batch) { + if (hasRemainingOutputForLeftJoin(batch)) { + return produceRemainingOutputForLeftJoin(batch); + } + finishInput(batch); + return nullptr; +} + +void IndexLookupJoin::recordIndexSourceInputStats( + const InputBatchState& batch) { + if (batch.lookupInput == nullptr || batch.lookupInput->size() == 0) { + return; + } + indexStatWriter_->addRuntimeStat( + "inputPositions", + RuntimeCounter( + static_cast(batch.lookupInput->size()), + RuntimeCounter::Unit::kNone)); + indexStatWriter_->addRuntimeStat( + "inputBytes", + RuntimeCounter( + static_cast(batch.lookupInput->estimateFlatSize()), + RuntimeCounter::Unit::kBytes)); +} + +void IndexLookupJoin::recordIndexSourceOutputStats( + const InputBatchState& batch) { + indexStatWriter_->addRuntimeStat( + "outputPositions", + RuntimeCounter( + static_cast(batch.lookupResult->size()), + RuntimeCounter::Unit::kNone)); + indexStatWriter_->addRuntimeStat( + "outputBytes", + RuntimeCounter( + static_cast(batch.lookupResult->output->estimateFlatSize()), + RuntimeCounter::Unit::kBytes)); + indexStatWriter_->addRuntimeStat( + "outputVectors", RuntimeCounter(1, RuntimeCounter::Unit::kNone)); +} + void IndexLookupJoin::prepareLookupResult(InputBatchState& batch) { VELOX_CHECK_NOT_NULL(batch.lookupResult); + + recordIndexSourceOutputStats(batch); + if (rawLookupInputHitIndices_ != nullptr) { return; } @@ -630,28 +1138,8 @@ void IndexLookupJoin::prepareLookupResult(InputBatchState& batch) { return; } VELOX_CHECK_NOT_NULL(batch.nonNullInputMappings); - vector_size_t* rawLookupInputHitIndices{nullptr}; - if (batch.lookupResult->inputHits->isMutable()) { - rawLookupInputHitIndices = - batch.lookupResult->inputHits->asMutable(); - } else { - const auto indicesByteSize = - batch.lookupResult->size() * sizeof(vector_size_t); - if ((batch.resultInputHitIndices == nullptr) || - !batch.resultInputHitIndices->unique() || - (batch.resultInputHitIndices->capacity() < indicesByteSize)) { - batch.resultInputHitIndices = allocateIndices(indicesByteSize, pool()); - } else { - batch.resultInputHitIndices->setSize(indicesByteSize); - } - rawLookupInputHitIndices = - batch.resultInputHitIndices->asMutable(); - std::memcpy( - rawLookupInputHitIndices, - batch.lookupResult->inputHits->as(), - indicesByteSize); - batch.lookupResult->inputHits = batch.resultInputHitIndices; - } + vector_size_t* rawLookupInputHitIndices = + batch.ensureInputHitsWritable(pool()); for (auto i = 0; i < batch.lookupResult->size(); ++i) { rawLookupInputHitIndices[i] = batch.rawNonNullInputMappings[rawLookupInputHitIndices[i]]; @@ -665,15 +1153,44 @@ void IndexLookupJoin::prepareLookupResult(InputBatchState& batch) { rawLookupInputHitIndices_ = rawLookupInputHitIndices; } +vector_size_t* IndexLookupJoin::InputBatchState::ensureInputHitsWritable( + memory::MemoryPool* pool) { + VELOX_CHECK_NOT_NULL(lookupResult); + if (lookupResult->inputHits->isMutable()) { + return lookupResult->inputHits->asMutable(); + } + + const auto indicesByteSize = lookupResult->size() * sizeof(vector_size_t); + if ((resultInputHitIndices == nullptr) || + !resultInputHitIndices->isMutable() || + (resultInputHitIndices->capacity() < indicesByteSize)) { + resultInputHitIndices = allocateIndices(indicesByteSize, pool); + } else { + resultInputHitIndices->setSize(indicesByteSize); + } + auto* rawLookupInputHitIndices = + resultInputHitIndices->asMutable(); + std::memcpy( + rawLookupInputHitIndices, + lookupResult->inputHits->as(), + indicesByteSize); + lookupResult->inputHits = resultInputHitIndices; + return rawLookupInputHitIndices; +} + void IndexLookupJoin::maybeFinishLookupResult(InputBatchState& batch) { VELOX_CHECK_NOT_NULL(batch.lookupResult); if (nextOutputResultRow_ == batch.lookupResult->size()) { - batch.lookupResult = nullptr; - nextOutputResultRow_ = 0; - rawLookupInputHitIndices_ = nullptr; + finishLookupResult(batch); } } +void IndexLookupJoin::finishLookupResult(InputBatchState& batch) { + batch.lookupResult = nullptr; + nextOutputResultRow_ = 0; + rawLookupInputHitIndices_ = nullptr; +} + bool IndexLookupJoin::hasRemainingOutputForLeftJoin( const InputBatchState& batch) const { if (joinType_ != core::JoinType::kLeft) { @@ -687,8 +1204,8 @@ bool IndexLookupJoin::hasRemainingOutputForLeftJoin( void IndexLookupJoin::finishInput(InputBatchState& batch) { VELOX_CHECK_NOT_NULL(batch.input); - VELOX_CHECK_EQ( - batch.lookupInput->size() == 0, batch.lookupResultIter == nullptr); + VELOX_CHECK(!shouldSkipInput()); + VELOX_CHECK_EQ(shouldSkipLookup(batch), batch.lookupResultIter == nullptr); VELOX_CHECK(!batch.lookupFuture.valid()); batch.input = nullptr; @@ -706,19 +1223,14 @@ void IndexLookupJoin::finishInput(InputBatchState& batch) { VELOX_CHECK(!nextBatch.lookupFuture.valid()); } else { VELOX_CHECK_EQ( - nextBatch.lookupInput->size() != 0, nextBatch.lookupFuture.valid()); + !shouldSkipLookup(nextBatch), nextBatch.lookupFuture.valid()); } } } -void IndexLookupJoin::prepareOutput(vector_size_t numOutputRows) { - if (output_ == nullptr) { - output_ = BaseVector::create(outputType_, numOutputRows, pool()); - } else { - VectorPtr output = std::move(output_); - BaseVector::prepareForReuse(output, numOutputRows); - output_ = std::static_pointer_cast(output); - } +RowVectorPtr IndexLookupJoin::prepareOutput(vector_size_t numOutputRows) { + std::vector children(outputType_->size(), nullptr); + return createRowVector(pool(), outputType_, numOutputRows); } RowVectorPtr IndexLookupJoin::produceOutputForInnerJoin( @@ -729,22 +1241,22 @@ RowVectorPtr IndexLookupJoin::produceOutputForInnerJoin( const size_t numOutputRows = std::min( batch.lookupResult->size() - nextOutputResultRow_, outputBatchSize_); - prepareOutput(numOutputRows); + auto output = prepareOutput(numOutputRows); if (numOutputRows == batch.lookupResult->size()) { for (const auto& projection : probeOutputProjections_) { - output_->childAt(projection.outputChannel) = BaseVector::wrapInDictionary( + output->childAt(projection.outputChannel) = BaseVector::wrapInDictionary( nullptr, batch.lookupResult->inputHits, numOutputRows, batch.input->childAt(projection.inputChannel)); } for (const auto& projection : lookupOutputProjections_) { - output_->childAt(projection.outputChannel) = + output->childAt(projection.outputChannel) = batch.lookupResult->output->childAt(projection.inputChannel); } } else { for (const auto& projection : probeOutputProjections_) { - output_->childAt(projection.outputChannel) = BaseVector::wrapInDictionary( + output->childAt(projection.outputChannel) = BaseVector::wrapInDictionary( nullptr, Buffer::slice( batch.lookupResult->inputHits, @@ -755,14 +1267,14 @@ RowVectorPtr IndexLookupJoin::produceOutputForInnerJoin( batch.input->childAt(projection.inputChannel)); } for (const auto& projection : lookupOutputProjections_) { - output_->childAt(projection.outputChannel) = + output->childAt(projection.outputChannel) = batch.lookupResult->output->childAt(projection.inputChannel) ->slice(nextOutputResultRow_, numOutputRows); } } nextOutputResultRow_ += numOutputRows; VELOX_CHECK_LE(nextOutputResultRow_, batch.lookupResult->size()); - return output_; + return output; } void IndexLookupJoin::fillOutputMatchRows( @@ -775,7 +1287,7 @@ void IndexLookupJoin::fillOutputMatchRows( offset, offset + size, match ? bits::kNotNull : bits::kNull); - if (!includeMatchColumn_) { + if (!hasMarker_) { return; } VELOX_CHECK_NOT_NULL(rawMatchValues_); @@ -803,26 +1315,38 @@ RowVectorPtr IndexLookupJoin::produceOutputForLeftJoin( VELOX_CHECK_NOT_NULL(rawLookupOutputNulls_); size_t numOutputRows{0}; size_t totalMissedInputRows{0}; + + // Outputs up to 'numMisses' missed (unmatched) input rows into the current + // output batch, capped by the remaining output capacity. Updates + // numOutputRows, lastProcessedInputRow, and totalMissedInputRows. + const auto outputMissedRows = [&](vector_size_t numMisses) INLINE_LAMBDA { + const auto numToOutput = + std::min(numMisses, maxOutputRows - numOutputRows); + if (totalMissedInputRows == 0) { + ensureMatchColumn(maxOutputRows); + fillOutputMatchRows(0, maxOutputRows, true); + } + fillOutputMatchRows(numOutputRows, numToOutput, false); + for (vector_size_t i = 0; i < numToOutput; ++i) { + rawProbeOutputRowIndices_[numOutputRows++] = ++lastProcessedInputRow; + } + totalMissedInputRows += numToOutput; + }; + for (; numOutputRows < maxOutputRows && nextOutputResultRow_ < batch.lookupResult->size();) { VELOX_CHECK_GE( - rawLookupInputHitIndices_[nextOutputResultRow_], lastProcessedInputRow); + rawLookupInputHitIndices_[nextOutputResultRow_], + lastProcessedInputRow, + "nextOutputResultRow_ {}, batch.lookupResult->size() {}", + nextOutputResultRow_, + batch.lookupResult->size()); const vector_size_t numMissedInputRows = rawLookupInputHitIndices_[nextOutputResultRow_] - lastProcessedInputRow - 1; VELOX_CHECK_GE(numMissedInputRows, -1); if (numMissedInputRows > 0) { - if (totalMissedInputRows == 0) { - ensureMatchColumn(maxOutputRows); - fillOutputMatchRows(0, maxOutputRows, true); - } - const auto numOutputMissedInputRows = std::min( - numMissedInputRows, maxOutputRows - numOutputRows); - fillOutputMatchRows(numOutputRows, numOutputMissedInputRows, false); - for (auto i = 0; i < numOutputMissedInputRows; ++i) { - rawProbeOutputRowIndices_[numOutputRows++] = ++lastProcessedInputRow; - } - totalMissedInputRows += numOutputMissedInputRows; + outputMissedRows(numMissedInputRows); continue; } @@ -833,6 +1357,17 @@ RowVectorPtr IndexLookupJoin::produceOutputForLeftJoin( ++nextOutputResultRow_; ++numOutputRows; } + + // If splitOutput_ is false, include any trailing missed input rows. + if (!splitOutput_ && nextOutputResultRow_ == batch.lookupResult->size() && + numOutputRows < maxOutputRows) { + const vector_size_t numRemainingInputRows = + batch.input->size() - lastProcessedInputRow - 1; + if (numRemainingInputRows > 0) { + outputMissedRows(numRemainingInputRows); + } + } + VELOX_CHECK( numOutputRows == maxOutputRows || nextOutputResultRow_ == batch.lookupResult->size()); @@ -850,24 +1385,24 @@ RowVectorPtr IndexLookupJoin::produceOutputForLeftJoin( return nullptr; } - prepareOutput(numOutputRows); + auto output = prepareOutput(numOutputRows); const auto numInputRows = lastProcessedInputRow - startProcessInputRow + 1; if (numInputRows == numOutputRows) { if (startProcessInputRow == 0 && numInputRows == batch.input->size()) { for (const auto& projection : probeOutputProjections_) { - output_->childAt(projection.outputChannel) = + output->childAt(projection.outputChannel) = batch.input->childAt(projection.inputChannel); } } else { for (const auto& projection : probeOutputProjections_) { - output_->childAt(projection.outputChannel) = + output->childAt(projection.outputChannel) = batch.input->childAt(projection.inputChannel) ->slice(startProcessInputRow, numInputRows); } } } else { for (const auto& projection : probeOutputProjections_) { - output_->childAt(projection.outputChannel) = BaseVector::wrapInDictionary( + output->childAt(projection.outputChannel) = BaseVector::wrapInDictionary( nullptr, probeOutputRowMapping_, numOutputRows, @@ -877,39 +1412,39 @@ RowVectorPtr IndexLookupJoin::produceOutputForLeftJoin( if (totalMissedInputRows > 0) { for (const auto& projection : lookupOutputProjections_) { - output_->childAt(projection.outputChannel) = BaseVector::wrapInDictionary( + output->childAt(projection.outputChannel) = BaseVector::wrapInDictionary( lookupOutputNulls_, lookupOutputRowMapping_, numOutputRows, batch.lookupResult->output->childAt(projection.inputChannel)); } - if (includeMatchColumn_) { - output_->childAt(matchOutputChannel_.value()) = matchColumn_; + if (hasMarker_) { + output->childAt(matchOutputChannel_.value()) = matchColumn_; } } else { if (startOutputRow == 0 && numOutputRows == batch.lookupResult->output->size()) { for (const auto& projection : lookupOutputProjections_) { - output_->childAt(projection.outputChannel) = + output->childAt(projection.outputChannel) = batch.lookupResult->output->childAt(projection.inputChannel); } } else { for (const auto& projection : lookupOutputProjections_) { - output_->childAt(projection.outputChannel) = + output->childAt(projection.outputChannel) = batch.lookupResult->output->childAt(projection.inputChannel) ->slice(startOutputRow, numOutputRows); } } - if (includeMatchColumn_) { - output_->childAt(matchOutputChannel_.value()) = + if (hasMarker_) { + output->childAt(matchOutputChannel_.value()) = BaseVector::createConstant(BOOLEAN(), true, numOutputRows, pool()); } } - return output_; + return output; } void IndexLookupJoin::ensureMatchColumn(vector_size_t maxOutputRows) { - if (!includeMatchColumn_) { + if (!hasMarker_) { return; } if (matchColumn_) { @@ -925,7 +1460,7 @@ void IndexLookupJoin::ensureMatchColumn(vector_size_t maxOutputRows) { } void IndexLookupJoin::setMatchColumnSize(vector_size_t numOutputRows) { - if (!includeMatchColumn_) { + if (!hasMarker_) { return; } VELOX_CHECK_NOT_NULL(matchColumn_); @@ -945,31 +1480,29 @@ RowVectorPtr IndexLookupJoin::produceRemainingOutputForLeftJoin( outputBatchSize_, batch.input->size() - startProcessInputRow); VELOX_CHECK_GT(numOutputRows, 0); VELOX_CHECK_LE(numOutputRows, batch.input->size()); - prepareOutput(numOutputRows); + auto output = prepareOutput(numOutputRows); if (numOutputRows != batch.input->size()) { for (const auto& projection : probeOutputProjections_) { - output_->childAt(projection.outputChannel) = + output->childAt(projection.outputChannel) = batch.input->childAt(projection.inputChannel) ->slice(startProcessInputRow, numOutputRows); } } else { for (const auto& projection : probeOutputProjections_) { - output_->childAt(projection.outputChannel) = + output->childAt(projection.outputChannel) = batch.input->childAt(projection.inputChannel); } } for (const auto& projection : lookupOutputProjections_) { - output_->childAt(projection.outputChannel) = BaseVector::createNullConstant( - output_->type()->childAt(projection.outputChannel), - numOutputRows, - pool()); + output->childAt(projection.outputChannel) = BaseVector::createNullConstant( + outputType_->childAt(projection.outputChannel), numOutputRows, pool()); } - if (includeMatchColumn_) { - output_->childAt(matchOutputChannel_.value()) = + if (hasMarker_) { + output->childAt(matchOutputChannel_.value()) = BaseVector::createConstant(BOOLEAN(), false, numOutputRows, pool()); } lastProcessedInputRow_ = lastProcessedInputRow + numOutputRows; - return output_; + return output; } void IndexLookupJoin::prepareOutputRowMappings(size_t outputBatchSize) { @@ -1015,6 +1548,106 @@ void IndexLookupJoin::close() { lookupOutputNulls_ = nullptr; Operator::close(); + + closeIndexSplitTracer(); +} + +bool IndexLookupJoin::applyFilterOnLookupResult(InputBatchState& batch) { + VELOX_CHECK_NOT_NULL(batch.lookupResult); + if (!filter_) { + return true; + } + if (batch.lookupResult->size() == 0) { + return true; + } + + const auto numResultRows = batch.lookupResult->size(); + + // Prepare filter input vector + filterRows_.resize(numResultRows); + filterRows_.setAll(); + filterInput_ = createRowVector( + pool(), filterInputType_, static_cast(numResultRows)); + + // Populate filter input from probe input. + for (const auto& projection : filterProbeInputProjections_) { + // Get the probe input column and dictionary-wrap it with hit indices + filterInput_->childAt(projection.outputChannel) = + BaseVector::wrapInDictionary( + nullptr, + batch.lookupResult->inputHits, + numResultRows, + batch.input->childAt(projection.inputChannel)); + } + + // Populate filter input from lookup result. + for (const auto& projection : filterLookupInputProjections_) { + filterInput_->childAt(projection.outputChannel) = + batch.lookupResult->output->childAt(projection.inputChannel); + } + + // Evaluate filter + filterResult_.resize(1); + EvalCtx evalCtx(operatorCtx_->execCtx(), filter_.get(), filterInput_.get()); + filter_->eval(filterRows_, evalCtx, filterResult_); + decodedFilterResult_.decode(*filterResult_[0], filterRows_); + + const auto indicesByteSize = numResultRows * sizeof(vector_size_t); + if (!filteredIndices_ || !filteredIndices_->isMutable() || + filteredIndices_->capacity() < indicesByteSize) { + filteredIndices_ = allocateIndices(numResultRows, pool()); + } else { + filteredIndices_->setSize(indicesByteSize); + } + auto* rawFilteredIndices = filteredIndices_->asMutable(); + + vector_size_t numPassed{0}; + for (auto i = 0; i < numResultRows; ++i) { + if (!decodedFilterResult_.isNullAt(i) && + decodedFilterResult_.valueAt(i)) { + rawFilteredIndices[numPassed++] = i; + } + } + + if (numPassed == 0) { + finishLookupResult(batch); + return false; + } + + if (numPassed == numResultRows) { + return true; + } + + // Some rows passed - create filtered lookup result. + filteredIndices_->setSize(numPassed * sizeof(vector_size_t)); + + // Update the inputHits buffer. + auto* rawLookupInputHitIndices = batch.ensureInputHitsWritable(pool()); + for (auto i = 0; i < numPassed; ++i) { + rawLookupInputHitIndices[i] = + rawLookupInputHitIndices_[rawFilteredIndices[i]]; +#ifdef NDEBUG + if (i > 0) { + VELOX_DCHECK_LE( + rawLookupInputHitIndices[i - 1], rawLookupInputHitIndices[i]); + } +#endif + } + batch.lookupResult->inputHits->setSize(numPassed * sizeof(vector_size_t)); + rawLookupInputHitIndices_ = rawLookupInputHitIndices; + + // Create the filtered result vector. + auto filteredOutput = BaseVector::create( + batch.lookupResult->output->type(), numPassed, pool()); + for (auto i = 0; i < batch.lookupResult->output->childrenSize(); ++i) { + filteredOutput->childAt(i) = BaseVector::wrapInDictionary( + nullptr, + filteredIndices_, + numPassed, + batch.lookupResult->output->childAt(i)); + } + batch.lookupResult->output = std::move(filteredOutput); + return true; } void IndexLookupJoin::recordConnectorStats() { @@ -1023,23 +1656,46 @@ void IndexLookupJoin::recordConnectorStats() { // in that case. return; } - auto lockedStats = stats_.wlock(); auto connectorStats = indexSource_->runtimeStats(); for (auto& [name, value] : connectorStats) { - lockedStats->runtimeStats.erase(name); - lockedStats->runtimeStats.emplace(name, std::move(value)); - } - if (connectorStats.count(kConnectorLookupWallTime) != 0) { - const CpuWallTiming backgroundTiming{ - static_cast(connectorStats[kConnectorLookupWallTime].count), - static_cast(connectorStats[kConnectorLookupWallTime].sum), - // NOTE: this might not be accurate as it doesn't include the time - // spent inside the index storage client. - static_cast(connectorStats[kConnectorResultPrepareTime].sum) + - connectorStats[kClientRequestProcessTime].sum + - connectorStats[kClientResultProcessTime].sum}; - lockedStats->backgroundTiming.clear(); - lockedStats->backgroundTiming.add(backgroundTiming); + indexStatWriter_->setRuntimeStat(name, value); + } + + // Record lookup timing into the index stat writer. splitStats extracts these + // to populate the IndexSource node's addInputTiming. + if (connectorStats.count(std::string(kConnectorLookupWallTime)) != 0) { + const auto& lookupWallTime = + connectorStats[std::string(kConnectorLookupWallTime)]; + indexStatWriter_->addRuntimeStat( + "lookupCount", + RuntimeCounter( + static_cast(lookupWallTime.count), + RuntimeCounter::Unit::kNone)); + indexStatWriter_->addRuntimeStat( + "lookupWallNanos", + RuntimeCounter( + static_cast(lookupWallTime.sum), + RuntimeCounter::Unit::kNanos)); + // NOTE: lookupCpuNanos may undercount CPU consumed on prefetch worker + // threads or async I/O completion handlers, since CpuWallTimer measures + // CPU on the calling thread only. + indexStatWriter_->addRuntimeStat( + "lookupCpuNanos", + RuntimeCounter( + static_cast( + connectorStats[std::string(kConnectorResultPrepareTime)].sum + + connectorStats[std::string(kClientRequestProcessTime)].sum + + connectorStats[std::string(kClientResultProcessTime)].sum), + RuntimeCounter::Unit::kNanos)); + } + // Copy index source stats into the operator's own runtimeStats so they are + // visible through the task-level runtime stats path. + const auto indexSourceStats = indexStatWriter_->runtimeStats(); + { + auto lockedStats = stats_.wlock(); + for (const auto& [name, value] : indexSourceStats) { + lockedStats->runtimeStats[fmt::format("indexSource.{}", name)] = value; + } } } } // namespace facebook::velox::exec diff --git a/velox/exec/IndexLookupJoin.h b/velox/exec/IndexLookupJoin.h index 159324b7bb7..369a6fb7207 100644 --- a/velox/exec/IndexLookupJoin.h +++ b/velox/exec/IndexLookupJoin.h @@ -14,6 +14,8 @@ * limitations under the License. */ #pragma once + +#include "velox/exec/IndexLookupJoinBridge.h" #include "velox/exec/Operator.h" #include "velox/exec/VectorHasher.h" @@ -39,7 +41,7 @@ class IndexLookupJoin : public Operator { RowVectorPtr getOutput() override; bool isFinished() override { - return noMoreInput_ && (numInputBatches() == 0); + return (noMoreInput_ || shouldSkipInput()) && (numInputBatches() == 0); } void close() override; @@ -47,41 +49,73 @@ class IndexLookupJoin : public Operator { /// Defines lookup runtime stats. /// The end-to-end walltime in nanoseconds that the index connector do the /// lookup. - static inline const std::string kConnectorLookupWallTime{ + static constexpr std::string_view kConnectorLookupWallTime{ "connectorLookupWallNanos"}; /// The cpu time in nanoseconds that the index connector process response from /// storage client for followup processing by index join operator. - static inline const std::string kConnectorResultPrepareTime{ + static constexpr std::string_view kConnectorResultPrepareTime{ "connectorResultPrepareCpuNanos"}; /// The cpu time in nanoseconds that the storage client process request for /// remote storage lookup such as encoding the lookup input data into remotr /// storage request. - static inline const std::string kClientRequestProcessTime{ + static constexpr std::string_view kClientRequestProcessTime{ "clientRequestProcessCpuNanos"}; /// The walltime in nanoseconds that the storage client wait for the lookup /// from remote storage. - static inline const std::string kClientLookupWaitWallTime{ + static constexpr std::string_view kClientLookupWaitWallTime{ "clientlookupWaitWallNanos"}; /// The number of split requests sent to remote storage for a client lookup /// request. - static inline const std::string kClientNumStorageRequests{ + static constexpr std::string_view kClientNumStorageRequests{ "clientNumStorageRequests"}; /// The cpu time in nanoseconds that the storage client process response from /// remote storage lookup such as decoding the response data into velox /// vectors. - static inline const std::string kClientResultProcessTime{ + static constexpr std::string_view kClientResultProcessTime{ "clientResultProcessCpuNanos"}; /// The byte size of the raw result received from the remote storage lookup. - static inline const std::string kClientLookupResultRawSize{ + static constexpr std::string_view kClientLookupResultRawSize{ "clientLookupResultRawSize"}; /// The byte size of the result data in velox vectors that are decoded from /// the raw data received from the remote storage lookup. - static inline const std::string kClientLookupResultSize{ + static constexpr std::string_view kClientLookupResultSize{ "clientLookupResultSize"}; + /// The number of lookup results received from remote storage with error. + static constexpr std::string_view kClientNumErrorResults{ + "clientNumErrorResults"}; + /// The number of index splits provided for index lookup. + static constexpr std::string_view kNumIndexSplits{"numIndexSplits"}; private: - using LookupResultIter = connector::IndexSource::LookupResultIterator; - using LookupResult = connector::IndexSource::LookupResult; + // Intercepts runtime stats emitted during index-side operations (getOutput / + // startLookup) and accumulates them into a local map, separating them from + // probe-side stats so Driver::processLazyIoStats() correctly attributes + // only probe-side stats to the scan operator. Held via shared_ptr so the + // stat splitter lambda can outlive the operator and read the final stats. + class IndexStatWriter : public BaseRuntimeStatWriter { + public: + void addRuntimeStat(std::string_view name, const RuntimeCounter& value) + override; + + // Sets a runtime metric in the index source stats map. Thread-safe. + void setRuntimeStat(const std::string& name, const RuntimeMetric& metric); + + // Returns a snapshot of the accumulated index source runtime stats. + std::unordered_map runtimeStats() const; + + private: + folly::Synchronized> + runtimeStats_; + }; + + // Produces separate OperatorStats for IndexLookupJoin and IndexSource nodes. + static std::vector splitStats( + const OperatorStats& combinedStats, + const core::PlanNodeId& indexSourceNodeId, + const IndexStatWriter& indexSourceStatWriter); + + using ResultIterator = connector::IndexSource::ResultIterator; + using Result = connector::IndexSource::Result; // Contains the state of an input batch processing. struct InputBatchState { @@ -100,14 +134,14 @@ class IndexLookupJoin : public Operator { // The reusable vector projected from 'input' as index lookup input. RowVectorPtr lookupInput; // Used to fetch lookup results for an input batch. - std::shared_ptr lookupResultIter; + std::shared_ptr lookupResultIter; // Used for synchronization with the async fetch result from index source // through 'lookupResultIter'. ContinueFuture lookupFuture; // Used to store the lookup result fetched from 'lookupResultIter' for // output processing. We might split the output result into multiple output // batches based on the operator's output batch size limit. - std::unique_ptr lookupResult; + std::unique_ptr lookupResult; // Specifies the indices of input row in 'input' that have matches in // 'output' from 'lookupResult'. This is only used in case // 'lookupInputHasNullKeys' is true in which 'inputHits' in 'lookupResult' @@ -117,6 +151,10 @@ class IndexLookupJoin : public Operator { // row in 'input' through the mapping specified by 'nonNullInputMappings'. // The redirect input hit indices are stored in 'resultInputHitIndices'. BufferPtr resultInputHitIndices; + // When splitOutput_ is false, this tracks partially accumulated results + // that are waiting for async operations to complete before continuing + // accumulation. + std::vector> partialOutputs; InputBatchState() : lookupFuture(ContinueFuture::makeEmpty()) {} @@ -125,12 +163,20 @@ class IndexLookupJoin : public Operator { lookupResultIter = nullptr; lookupFuture = ContinueFuture::makeEmpty(); lookupResult = nullptr; + partialOutputs.clear(); } // Indicates if this input batch is empty. bool empty() const { return input == nullptr; } + + // Ensures that the lookup result's inputHits buffer is writable and returns + // a mutable pointer. If the buffer is already mutable, returns it directly. + // Otherwise, creates a new writable buffer by copying the existing data and + // returns a pointer to the new buffer. This is needed when filters or null + // key handling requires modifying the input hit indices. + vector_size_t* ensureInputHitsWritable(memory::MemoryPool* pool); }; void initInputBatches(); @@ -138,17 +184,51 @@ class IndexLookupJoin : public Operator { void initLookupInput(); void initLookupOutput(); void initOutputProjections(); + void initFilter(); + + // Collects splits for the index source until no more splits signal is + // received. Returns true if all splits have been collected and the index + // source is ready. Returns false if we are still waiting for splits. + // Only called by the split collector operator (partitionId == 0). + bool collectIndexSplits(ContinueFuture* future); + + // Waits for the split collector to share index splits via the bridge. + // Returns true if splits are available and have been added to the index + // source. Returns false if we are still waiting for splits. + bool waitForIndexSplits(ContinueFuture* future); + + // Applies the join filter directly on the lookup result, updating the + // lookup result to only include rows that pass the filter. Returns true if + // some rows passed the filter, otherwise false. + bool applyFilterOnLookupResult(InputBatchState& batch); + void ensureInputLoaded(const InputBatchState& batch); // Prepare index source lookup for a given 'input_'. void prepareLookup(InputBatchState& batch); void startLookup(InputBatchState& batch); + // Helper function to merge batch.partialOutputs into a single + // batch.lookupResult. This is used when splitOutput_ is false to ensure all + // results from an iterator are combined into one output batch. + void mergeLookupResults(InputBatchState& batch); + // Helper function to get all lookup results. Fetches the first result if not + // already fetched, and when splitOutput_ is false, accumulates all remaining + // results into a single batch. Handles both initial lookup and resuming + // accumulation after async interruption. Returns true if results are ready, + // false if an async operation is pending. + bool getLookupResults(InputBatchState& batch); + void startLookupBlockWait(); void endLookupBlockWait(); RowVectorPtr getOutputFromLookupResult(InputBatchState& batch); RowVectorPtr produceOutputForInnerJoin(const InputBatchState& batch); RowVectorPtr produceOutputForLeftJoin(const InputBatchState& batch); + // Handles production of remaining output after lookup result processing is + // complete. For left joins, this ensures unmatched rows from the probe side + // are included in the output with null values for lookup columns. For inner + // joins, this simply finishes the input batch. + RowVectorPtr produceRemainingOutput(InputBatchState& batch); // Produces output for the remaining input rows that has no matches from the // lookup at the end of current input batch processing. RowVectorPtr produceRemainingOutputForLeftJoin(const InputBatchState& batch); @@ -160,8 +240,10 @@ class IndexLookupJoin : public Operator { bool hasRemainingOutputForLeftJoin(const InputBatchState& batch) const; // Checks if we have finished processing the current 'lookupResult_'. If so, - // we reset 'lookupResult_' and corresponding processing state. + // call 'finishLookupResult' to reset 'lookupResult_' and corresponding + // processing state. void maybeFinishLookupResult(InputBatchState& batch); + void finishLookupResult(InputBatchState& batch); // Invoked after finished processing the current 'input_' batch. The function // resets the input batch and the lookup result states. @@ -172,8 +254,9 @@ class IndexLookupJoin : public Operator { // for output rows without lookup matches. void prepareOutputRowMappings(size_t outputBatchSize); - // Prepare 'output_' for the next output batch with size of 'numOutputRows'. - void prepareOutput(vector_size_t numOutputRows); + // Creates a new output RowVector with 'numOutputRows' rows and nullptr + // children. Callers populate the children before returning it. + RowVectorPtr prepareOutput(vector_size_t numOutputRows); // Invoked to ensure the match column is created to store the output match // result for the left join. @@ -196,6 +279,12 @@ class IndexLookupJoin : public Operator { // input rows. void prepareLookupResult(InputBatchState& batch); + // Records index source input stats from the lookup keys. + void recordIndexSourceInputStats(const InputBatchState& batch); + + // Records index source output stats from the lookup result. + void recordIndexSourceOutputStats(const InputBatchState& batch); + // Invoked at operator close to record the lookup stats. void recordConnectorStats(); @@ -205,6 +294,25 @@ class IndexLookupJoin : public Operator { return maxNumInputBatches_ > 1; } + // Returns true if the index source needs splits and we haven't received the + // no-more-splits signal yet. + bool needsIndexSplits() const { + return lookupTableHandle_->needsIndexSplit() && !noMoreIndexSplits_; + } + + // Returns true if input processing can be skipped entirely. This is the case + // for INNER JOIN when there are no index splits — every probe row will be + // discarded since there can be no matches. + bool shouldSkipInput() const { + return hasNoIndexSplits_ && joinType_ == core::JoinType::kInner; + } + + // Returns true if lookup should be skipped for the given batch, either + // because the index source has no splits or because all probe keys are null. + bool shouldSkipLookup(const InputBatchState& batch) const { + return hasNoIndexSplits_ || batch.lookupInput->size() == 0; + } + // Returns the number of input batches to process. size_t numInputBatches() const { VELOX_CHECK_LE(startBatchIndex_, endBatchIndex_); @@ -228,21 +336,29 @@ class IndexLookupJoin : public Operator { return inputBatches_[startBatchIndex_ % maxNumInputBatches_]; } + // If true, allows one input row to produce multiple output rows. + // If false, enforces one-to-one mapping. + const bool splitOutput_; // Maximum number of rows in the output batch. const vector_size_t outputBatchSize_; // Type of join. const core::JoinType joinType_; - const bool includeMatchColumn_; - const size_t numKeys_; + const bool hasMarker_; const RowTypePtr probeType_; const RowTypePtr lookupType_; + // The plan node id of the lookup source (index source). + const core::PlanNodeId indexSourceNodeId_; const connector::ConnectorTableHandlePtr lookupTableHandle_; - const std::vector lookupConditions_; + const std::vector joinConditions_; const connector::ColumnHandleMap lookupColumnHandles_; const std::shared_ptr connectorQueryCtx_; const std::shared_ptr connector_; const size_t maxNumInputBatches_; + // True if this operator (partitionId == 0) is responsible for collecting + // index splits from the task and sharing them via the bridge. + const bool isIndexSplitCollector_; + // The lookup join plan node used to initialize this operator and reset after // that. std::shared_ptr joinNode_; @@ -300,13 +416,64 @@ class IndexLookupJoin : public Operator { BufferPtr lookupOutputNulls_; uint64_t* rawLookupOutputNulls_{nullptr}; - // The reusable output vector for the join output. - RowVectorPtr output_; + // Join filter. + std::unique_ptr filter_; + + // Join filter input type. + RowTypePtr filterInputType_; + + // Maps probe-side input channels to channels in 'filterInputType_'. + std::vector filterProbeInputProjections_; + // Maps lookup-side input channels to channels in 'filterInputType_', + std::vector filterLookupInputProjections_; + + // Reusable memory for filter evaluations. + RowVectorPtr filterInput_; + SelectivityVector filterRows_; + std::vector filterResult_; + DecodedVector decodedFilterResult_; + BufferPtr filteredIndices_; + FlatVectorPtr matchColumn_{nullptr}; uint64_t* rawMatchValues_{nullptr}; // The start time of the current lookup driver block wait, and reset after the // driver wait completes. std::optional blockWaitStartNs_; + + // The bridge for sharing index splits across operators in the same pipeline. + // Null if the index source does not need splits. + std::shared_ptr joinBridge_; + + // Split collection state for index sources that require splits. + // True if we have received the no-more-splits signal for the index source. + bool noMoreIndexSplits_{false}; + // True if the index source received zero splits (e.g., partition pruning + // eliminated all index partitions). When set, lookups are skipped entirely. + bool hasNoIndexSplits_{false}; + // The future to wait for the next index split. + ContinueFuture indexSplitFuture_; + // The collected splits for the index source. It is passed to index source + // after the no-more-splits signal is received (i.e., 'noMoreIndexSplits_' is + // true). + std::vector> indexSplits_; + + // Traces the index splits received by this operator for replay. Set when + // tracing is enabled for this operator. + std::unique_ptr indexSplitTracer_; + + // Creates the index split tracer if input tracing is enabled. + void createIndexSplitTracer(); + + // Traces the given index split for replay. + void traceIndexSplit(const exec::Split& split); + + // Closes and resets the index split tracer. + void closeIndexSplitTracer(); + + // Intercepts and accumulates index source runtime stats. Held via + // shared_ptr so the stat splitter lambda can read the final stats after the + // operator is destroyed. + std::shared_ptr indexStatWriter_; }; } // namespace facebook::velox::exec diff --git a/velox/exec/IndexLookupJoinBridge.cpp b/velox/exec/IndexLookupJoinBridge.cpp new file mode 100644 index 00000000000..a00a7159e8a --- /dev/null +++ b/velox/exec/IndexLookupJoinBridge.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/exec/IndexLookupJoinBridge.h" + +namespace facebook::velox::exec { + +void IndexLookupJoinBridge::setIndexSplits( + std::vector> splits) { + std::vector promises; + { + std::lock_guard l(mutex_); + VELOX_CHECK(started_, "Bridge must be started before setting index splits"); + VELOX_CHECK( + !cancelled_, "Setting index splits after the bridge is cancelled"); + VELOX_CHECK(!splitsSet_, "setIndexSplits must be called only once"); + splitsSet_ = true; + indexSplits_ = std::move(splits); + promises = std::move(promises_); + } + notify(std::move(promises)); +} + +std::vector> +IndexLookupJoinBridge::splitsOrFuture(ContinueFuture* future) { + std::lock_guard l(mutex_); + VELOX_CHECK( + !cancelled_, "Getting index splits after the bridge is cancelled"); + if (splitsSet_) { + return indexSplits_; + } + promises_.emplace_back("IndexLookupJoinBridge::splitsOrFuture"); + *future = promises_.back().getSemiFuture(); + return {}; +} + +} // namespace facebook::velox::exec diff --git a/velox/exec/IndexLookupJoinBridge.h b/velox/exec/IndexLookupJoinBridge.h new file mode 100644 index 00000000000..b085c477918 --- /dev/null +++ b/velox/exec/IndexLookupJoinBridge.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/connectors/Connector.h" +#include "velox/exec/JoinBridge.h" + +namespace facebook::velox::exec { + +/// Coordinates sharing of index splits among multiple IndexLookupJoin operators +/// in the same pipeline. The leader operator (driverId 0) collects all splits +/// from the task and publishes them via setIndexSplits(). Follower operators +/// wait for splits via splitsOrFuture(). +class IndexLookupJoinBridge : public JoinBridge { + public: + /// Called by the leader operator after collecting all index splits. Stores + /// the splits and notifies all waiting followers. Must be called only once. + /// May be called with an empty vector (e.g., when partition pruning + /// eliminates all index partitions for a split group). + void setIndexSplits( + std::vector> splits); + + /// Returns the index splits if the leader has called setIndexSplits(), + /// otherwise sets 'future' and returns an empty vector. The caller should + /// check future->valid() to distinguish "not ready" (future set) from + /// "ready with 0 splits" (future not set, empty vector). + std::vector> splitsOrFuture( + ContinueFuture* future); + + private: + bool splitsSet_{false}; + std::vector> indexSplits_; +}; + +} // namespace facebook::velox::exec diff --git a/velox/exec/Limit.cpp b/velox/exec/Limit.cpp index 0ed5bbf7386..39bcf5ffdb8 100644 --- a/velox/exec/Limit.cpp +++ b/velox/exec/Limit.cpp @@ -15,6 +15,8 @@ */ #include "velox/exec/Limit.h" +#include "velox/exec/OperatorType.h" + namespace facebook::velox::exec { Limit::Limit( int32_t operatorId, @@ -25,7 +27,7 @@ Limit::Limit( limitNode->outputType(), operatorId, limitNode->id(), - "Limit"), + OperatorType::kLimit), remainingOffset_{limitNode->offset()}, remainingLimit_{limitNode->count()} { isIdentityProjection_ = true; diff --git a/velox/exec/LocalPartition.cpp b/velox/exec/LocalPartition.cpp index 1fbce8b8960..eb6eb81add3 100644 --- a/velox/exec/LocalPartition.cpp +++ b/velox/exec/LocalPartition.cpp @@ -15,6 +15,8 @@ */ #include "velox/exec/LocalPartition.h" +#include "velox/common/Casts.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/Task.h" #include "velox/vector/EncodedVectorCopy.h" @@ -263,7 +265,7 @@ LocalExchange::LocalExchange( std::move(outputType), operatorId, planNodeId, - "LocalExchange"), + OperatorType::kLocalExchange), partition_{partition}, queue_{operatorCtx_->task()->getLocalExchangeQueue( ctx->splitGroupId, @@ -332,7 +334,7 @@ LocalPartition::LocalPartition( planNode->outputType(), operatorId, planNode->id(), - "LocalPartition"), + OperatorType::kLocalPartition), queues_{ ctx->task->getLocalExchangeQueues(ctx->splitGroupId, planNode->id())}, numPartitions_{queues_.size()}, @@ -420,6 +422,7 @@ RowVectorPtr LocalPartition::wrapChildren( void LocalPartition::copy( const RowVectorPtr& input, const folly::Range& ranges, + const size_t partition, VectorPtr& target) { if (ranges.empty()) { return; @@ -432,57 +435,153 @@ void LocalPartition::copy( } if (!target) { - target = BaseVector::create(outputType_, 0, pool()); + target = getOrCreateVector(partition); } target->resize(target->size() + ranges.size()); target->copyRanges(input.get(), ranges); } -RowVectorPtr LocalPartition::processPartition( +VectorPtr LocalPartition::getOrCreateVector(const size_t partition) { + auto reusable = queues_[partition]->getVector(); + if (reusable) { + VELOX_CHECK_EQ(reusable->type(), outputType_); + reusable->unsafeResize(0); + for (auto i = 0; i < reusable->childrenSize(); ++i) { + reusable->childAt(i) = nullptr; + } + return reusable; + } else { + return BaseVector::create(outputType_, 0, pool()); + } +} + +void LocalPartition::populatePartitionBuffer( const RowVectorPtr& input, - vector_size_t size, - int partition, - const BufferPtr& indices, - const vector_size_t* rawIndices) { + const vector_size_t numPartitionRows, + const size_t partition, + const vector_size_t* rawIndices, + uint64_t& totalPartitionBufferSizeExcludingString, + uint64_t& totalPartitionStringBufferSize) { + VELOX_CHECK_GT(singlePartitionBufferSize_, 0); + copyRanges_.resize(numPartitionRows); + + auto& partitionBuffer = partitionBuffers_[partition]; + auto targetIndex = 0; + if (partitionBuffer) { + targetIndex = partitionBuffer->size(); + } + for (int i = 0; i < numPartitionRows; i++) { + copyRanges_[i] = {rawIndices[i], targetIndex, 1}; + targetIndex++; + } + + copy(input, copyRanges_, partition, partitionBuffer); + + if (partitionBuffer) { + uint64_t stringBufferSize{0}; + auto totalSize = partitionBuffer->retainedSize(stringBufferSize); + totalPartitionBufferSizeExcludingString += totalSize - stringBufferSize; + totalPartitionStringBufferSize += stringBufferSize; + } +} + +RowVectorPtr LocalPartition::createPartition( + const RowVectorPtr& input, + const vector_size_t numPartitionRows, + const size_t partition, + const BufferPtr& indices) { RowVectorPtr partitionData{nullptr}; if (singlePartitionBufferSize_ > 0) { - if (partitionBuffers_.empty()) { - partitionBuffers_.resize(numPartitions_); - } - if (copyRanges_.size() < size) { - copyRanges_.resize(size); - } - auto& partitionBuffer = partitionBuffers_[partition]; - auto targetIndex = 0; if (partitionBuffer) { - targetIndex = partitionBuffer->size(); - } - for (int i = 0; i < size; i++) { - copyRanges_[i] = {rawIndices[i], targetIndex, 1}; - targetIndex++; + partitionData = + checkedPointerCast(partitionBuffer); + partitionBuffers_[partition] = nullptr; } + } else if (numPartitionRows > 0) { + partitionData = wrapChildren( + input, numPartitionRows, indices, queues_[partition]->getVector()); + } + return partitionData; +} - copy( - input, - folly::Range{copyRanges_.data(), static_cast(size)}, - partitionBuffer); - - if (partitionBuffer && - partitionBuffer->retainedSize() >= singlePartitionBufferSize_) { - partitionData = std::dynamic_pointer_cast(partitionBuffer); - VELOX_CHECK(partitionData); - partitionBuffers_[partition] = nullptr; +void LocalPartition::populateAndEnqueuePartitions( + RowVectorPtr input, + const std::vector& numRowsPerPartition, + const std::vector& indexBuffers, + const std::vector& rawIndicesBuffers) { + uint64_t totalPartitionBufferSizeExcludingString = 0; + uint64_t totalPartitionStringBufferSize = 0; + uint16_t nonEmptyPartitionCount = 0; + + // Populate partition buffers if in buffer mode. + if (singlePartitionBufferSize_ > 0) { + if (partitionBuffers_.empty()) { + partitionBuffers_.resize(numPartitions_); + } + for (auto partition = 0; partition < numPartitions_; partition++) { + populatePartitionBuffer( + input, + numRowsPerPartition[partition], + partition, + rawIndicesBuffers[partition], + totalPartitionBufferSizeExcludingString, + totalPartitionStringBufferSize); + if (partitionBuffers_[partition]) { + nonEmptyPartitionCount++; + } } } else { - partitionData = - wrapChildren(input, size, indices, queues_[partition]->getVector()); + nonEmptyPartitionCount = numPartitions_ - + std::count(numRowsPerPartition.begin(), numRowsPerPartition.end(), 0); + } + VELOX_CHECK_GT( + nonEmptyPartitionCount, + 0, + "Input rows should be assigned to at least one partition"); + + // Calculate the partition buffer size across all partitions with amortized + // string buffer sizes. + auto balancedTotalPartitionBufferSize = + totalPartitionBufferSizeExcludingString + + (totalPartitionStringBufferSize / nonEmptyPartitionCount); + auto inputRetainedSize = input->retainedSize(); + + // Enqueue all partitions if one of the following conditions is met: + // 1. This operator is not in buffer mode. + // 2. This operator is in buffer mode and the total buffer size across all + // partitions exceeds 'singlePartitionBufferSize_ * numPartitions_'. + if (singlePartitionBufferSize_ == 0 || + balancedTotalPartitionBufferSize >= + singlePartitionBufferSize_ * numPartitions_) { + auto perPartitionAmortizedSize = + (singlePartitionBufferSize_ > 0 ? balancedTotalPartitionBufferSize + : inputRetainedSize) / + nonEmptyPartitionCount; + for (auto partition = 0; partition < numPartitions_; partition++) { + auto partitionSize = numRowsPerPartition[partition]; + auto partitionData = createPartition( + input, partitionSize, partition, indexBuffers[partition]); + if (!partitionData) { + continue; + } + + ContinueFuture future; + auto reason = queues_[partition]->enqueue( + std::move(partitionData), perPartitionAmortizedSize, &future); + if (reason != BlockingReason::kNotBlocked) { + blockingReasons_.push_back(reason); + futures_.push_back(std::move(future)); + } + } } - return partitionData; } void LocalPartition::addInput(RowVectorPtr input) { prepareForInput(input); + if (input->size() == 0) { + return; + } const auto singlePartition = numPartitions_ == 1 ? 0 @@ -512,31 +611,7 @@ void LocalPartition::addInput(RowVectorPtr input) { ++maxIndex[partition]; } - const int64_t totalSize = input->retainedSize(); - for (auto partition = 0; partition < numPartitions_; partition++) { - auto partitionSize = maxIndex[partition]; - if (partitionSize == 0) { - // Do not enqueue empty partitions. - continue; - } - - auto partitionData = processPartition( - input, - partitionSize, - partition, - indexBuffers_[partition], - rawIndices_[partition]); - - if (partitionData) { - ContinueFuture future; - auto reason = queues_[partition]->enqueue( - partitionData, totalSize * partitionSize / numInput, &future); - if (reason != BlockingReason::kNotBlocked) { - blockingReasons_.push_back(reason); - futures_.push_back(std::move(future)); - } - } - } + populateAndEnqueuePartitions(input, maxIndex, indexBuffers_, rawIndices_); } void LocalPartition::prepareForInput(RowVectorPtr& input) { @@ -566,19 +641,36 @@ BlockingReason LocalPartition::isBlocked(ContinueFuture* future) { void LocalPartition::noMoreInput() { Operator::noMoreInput(); if (!partitionBuffers_.empty()) { + uint64_t totalPartitionBufferSizeExcludingString = 0; + uint64_t totalPartitionStringBufferSize = 0; + uint16_t nonEmptyPartitionCount = 0; for (auto partition = 0; partition < numPartitions_; partition++) { - if (partitionBuffers_[partition] && - partitionBuffers_[partition]->size() > 0) { - auto partitionData = - std::dynamic_pointer_cast(partitionBuffers_[partition]); - VELOX_CHECK(partitionData); - ContinueFuture future; - queues_[partition]->enqueue( - partitionData, - partitionBuffers_[partition]->retainedSize(), - &future); + if (partitionBuffers_[partition]) { + uint64_t stringBufferSize{0}; + auto totalSize = + partitionBuffers_[partition]->retainedSize(stringBufferSize); + totalPartitionBufferSizeExcludingString += totalSize - stringBufferSize; + totalPartitionStringBufferSize += stringBufferSize; + nonEmptyPartitionCount++; + } + } + if (nonEmptyPartitionCount > 0) { + auto balancedPartitionBufferSize = + totalPartitionBufferSizeExcludingString + + (totalPartitionStringBufferSize / nonEmptyPartitionCount); + for (auto partition = 0; partition < numPartitions_; partition++) { + if (partitionBuffers_[partition]) { + auto partitionData = checkedPointerCast( + partitionBuffers_[partition]); + ContinueFuture future; + + queues_[partition]->enqueue( + partitionData, + balancedPartitionBufferSize / nonEmptyPartitionCount, + &future); + } + partitionBuffers_[partition] = nullptr; } - partitionBuffers_[partition] = nullptr; } partitionBuffers_.resize(0); copyRanges_.resize(0); diff --git a/velox/exec/LocalPartition.h b/velox/exec/LocalPartition.h index bdf7ab42df5..f2b1da8de50 100644 --- a/velox/exec/LocalPartition.h +++ b/velox/exec/LocalPartition.h @@ -240,12 +240,21 @@ class LocalPartition : public Operator { void allocateIndexBuffers(const std::vector& sizes); - RowVectorPtr processPartition( - const RowVectorPtr& input, - vector_size_t size, - int partition, - const BufferPtr& indices, - const vector_size_t* rawIndices); + /// Create partitions from 'input' according to 'numRowsPerPartition' and + /// 'indexBuffers', and enqueue the partitions to LocalExchangeQueues. The + /// behavior of partition vector creation varies depending on + /// 'singlePartitionBufferSize_'. If 'singlePartitionBufferSize_' is non-zero, + /// append rows from 'input' to 'partitionBuffers_' for every partition. When + /// the total size of all partition buffer vectors exceeds + /// 'singlePartitionBufferSize_ * numPartitions_', flush all partitionBuffers_ + /// vectors to LocalExchangeQueues. If 'singlePartitionBufferSize_' is zero, + /// create partition vectors by wrapping 'input' with indexBuffers and flush + /// them to LocalExchangeQueues immediately. + void populateAndEnqueuePartitions( + RowVectorPtr input, + const std::vector& numRowsPerPartition, + const std::vector& indexBuffers, + const std::vector& rawIndicesBuffers); const std::vector> queues_; const size_t numPartitions_; @@ -261,6 +270,11 @@ class LocalPartition : public Operator { std::vector rawIndices_; private: + // Try getting a reusable vector for 'partition' from the corresponding + // local-exchange vector pool of this partition. If none is available, create + // a new vector. + VectorPtr getOrCreateVector(const size_t partition); + RowVectorPtr wrapChildren( const RowVectorPtr& input, vector_size_t size, @@ -270,8 +284,30 @@ class LocalPartition : public Operator { void copy( const RowVectorPtr& input, const folly::Range& ranges, + const size_t partition, VectorPtr& target); + /// Add rows from 'input' to 'partitionBuffers_' that every row belongs to. + /// Also set 'totalPartitionBufferSizeExcludingString' to be the total size of + /// all partition buffer vectors excluding string buffers inside them, and set + /// 'totalPartitionStringBufferSize' to be the total size of all string + /// buffers in all partition buffer vectors. + void populatePartitionBuffer( + const RowVectorPtr& input, + const vector_size_t numPartitionRows, + const size_t partition, + const vector_size_t* rawIndices, + uint64_t& totalPartitionBufferSizeExcludingString, + uint64_t& totalPartitionStringBufferSize); + + /// Return the partition vector to be added to LocalExchangeQueue. This method + /// returns nullptr if no row belongs to 'partition' + RowVectorPtr createPartition( + const RowVectorPtr& input, + const vector_size_t numPartitionRows, + const size_t partition, + const BufferPtr& indices); + const uint64_t singlePartitionBufferSize_; std::vector copyRanges_; std::vector partitionBuffers_; diff --git a/velox/exec/LocalPlanner.cpp b/velox/exec/LocalPlanner.cpp index b78ca6ec8b9..39f009fe39a 100644 --- a/velox/exec/LocalPlanner.cpp +++ b/velox/exec/LocalPlanner.cpp @@ -18,6 +18,7 @@ #include "velox/exec/ArrowStream.h" #include "velox/exec/AssignUniqueId.h" #include "velox/exec/CallbackSink.h" +#include "velox/exec/EnforceDistinct.h" #include "velox/exec/EnforceSingleRow.h" #include "velox/exec/Exchange.h" #include "velox/exec/Expand.h" @@ -29,8 +30,10 @@ #include "velox/exec/IndexLookupJoin.h" #include "velox/exec/Limit.h" #include "velox/exec/MarkDistinct.h" +#include "velox/exec/MarkSorted.h" #include "velox/exec/Merge.h" #include "velox/exec/MergeJoin.h" +#include "velox/exec/MixedUnion.h" #include "velox/exec/NestedLoopJoinBuild.h" #include "velox/exec/NestedLoopJoinProbe.h" #include "velox/exec/OperatorTraceScan.h" @@ -43,6 +46,7 @@ #include "velox/exec/SpatialJoinBuild.h" #include "velox/exec/SpatialJoinProbe.h" #include "velox/exec/StreamingAggregation.h" +#include "velox/exec/StreamingEnforceDistinct.h" #include "velox/exec/TableScan.h" #include "velox/exec/TableWriteMerge.h" #include "velox/exec/TableWriter.h" @@ -84,6 +88,11 @@ bool mustStartNewPipeline( return true; } + if (std::dynamic_pointer_cast(planNode)) { + // MixedUnion's sources run on their own pipelines. + return true; + } + if (std::dynamic_pointer_cast(planNode)) { return true; } @@ -142,6 +151,27 @@ OperatorSupplier makeOperatorSupplier( }; } + if (auto mixedUnion = + std::dynamic_pointer_cast(planNode)) { + return [mixedUnion](int32_t operatorId, DriverCtx* ctx) { + auto mergeSource = ctx->task->addLocalMergeSource( + ctx->splitGroupId, + mixedUnion->id(), + mixedUnion->outputType(), + static_cast(ctx->queryConfig().localMergeSourceQueueSize())); + auto consumerCb = + [mergeSource]( + RowVectorPtr input, bool drained, ContinueFuture* future) { + return mergeSource->enqueue(std::move(input), future, drained); + }; + auto startCb = [mergeSource](ContinueFuture* future) { + return mergeSource->started(future); + }; + return std::make_unique( + operatorId, ctx, std::move(consumerCb), std::move(startCb)); + }; + } + if (auto localPartitionNode = std::dynamic_pointer_cast(planNode)) { if (localPartitionNode->scaleWriter()) { @@ -203,7 +233,15 @@ OperatorSupplier makeOperatorSupplier( return source->enqueue(std::move(input), future); } }; - return std::make_unique(operatorId, ctx, consumer); + // NOTE: Pass planNodeId to associate CallbackSink with the MergeJoin + // node for proper operator identification and input collection. + // Operator::maybeSetTracer() uses this to enable tracing. + return std::make_unique( + operatorId, + ctx, + consumer, + nullptr, + ctx->queryConfig().queryTraceEnabled() ? planNodeId : "N/A"); }; } @@ -261,74 +299,24 @@ uint32_t maxDrivers( return count; } for (auto& node : driverFactory.planNodes) { - if (auto topN = std::dynamic_pointer_cast(node)) { - if (!topN->isPartial()) { - // final topN must run single-threaded - return 1; - } - } else if ( - auto values = std::dynamic_pointer_cast(node)) { - // values node must run single-threaded, unless in test context - if (!values->testingIsParallelizable()) { - return 1; - } - } else if (std::dynamic_pointer_cast(node)) { - // ArrowStream node must run single-threaded. + if (node->requiresSingleThread()) { return 1; - } else if ( - auto limit = std::dynamic_pointer_cast(node)) { - // final limit must run single-threaded - if (!limit->isPartial()) { - return 1; - } - } else if ( - auto orderBy = - std::dynamic_pointer_cast(node)) { - // final orderby must run single-threaded - if (!orderBy->isPartial()) { - return 1; - } - } else if ( - auto localExchange = + } + + if (auto localExchange = std::dynamic_pointer_cast(node)) { - // Local gather must run single-threaded. - switch (localExchange->type()) { - case core::LocalPartitionNode::Type::kGather: - return 1; - case core::LocalPartitionNode::Type::kRepartition: - count = std::min(queryConfig.maxLocalExchangePartitionCount(), count); - break; - default: - VELOX_UNREACHABLE("Unexpected local exchange type"); - } - } else if (std::dynamic_pointer_cast(node)) { - // Local merge must run single-threaded. - return 1; - } else if (std::dynamic_pointer_cast(node)) { - // Merge exchange must run single-threaded. - return 1; - } else if (std::dynamic_pointer_cast(node)) { - // Merge join must run single-threaded. - return 1; - } else if ( - auto join = std::dynamic_pointer_cast(node)) { - // Right semi project doesn't support multi-threaded execution. - if (join->isRightSemiProjectJoin()) { - return 1; + // Repartition limits parallelism to the partition count. + if (localExchange->type() == + core::LocalPartitionNode::Type::kRepartition) { + count = std::min(queryConfig.maxLocalExchangePartitionCount(), count); } } else if ( auto tableWrite = std::dynamic_pointer_cast(node)) { - const auto& connectorInsertHandle = - tableWrite->insertTableHandle()->connectorInsertTableHandle(); - if (!connectorInsertHandle->supportsMultiThreading()) { - return 1; + if (tableWrite->hasPartitioningScheme()) { + return queryConfig.taskPartitionedWriterCount(); } else { - if (tableWrite->hasPartitioningScheme()) { - return queryConfig.taskPartitionedWriterCount(); - } else { - return queryConfig.taskWriterCount(); - } + return queryConfig.taskWriterCount(); } } else { auto result = Operator::maxDrivers(node); @@ -490,7 +478,7 @@ std::shared_ptr DriverFactory::createDriver( std::vector> operators; operators.reserve(planNodes.size()); - for (int32_t i = 0; i < planNodes.size(); i++) { + for (int32_t i = 0; i < planNodes.size(); ++i) { // Id of the Operator being made. This is not the same as 'i' // because some PlanNodes may get fused. auto id = operators.size(); @@ -501,8 +489,9 @@ std::shared_ptr DriverFactory::createDriver( auto next = planNodes[i + 1]; if (auto projectNode = std::dynamic_pointer_cast(next)) { - operators.push_back(std::make_unique( - id, ctx.get(), filterNode, projectNode)); + operators.push_back( + std::make_unique( + id, ctx.get(), filterNode, projectNode)); i++; continue; } @@ -543,8 +532,9 @@ std::shared_ptr DriverFactory::createDriver( auto tableWriteMergeNode = std::dynamic_pointer_cast( planNode)) { - operators.push_back(std::make_unique( - id, ctx.get(), tableWriteMergeNode)); + operators.push_back( + std::make_unique( + id, ctx.get(), tableWriteMergeNode)); } else if ( auto mergeExchangeNode = std::dynamic_pointer_cast( @@ -556,14 +546,16 @@ std::shared_ptr DriverFactory::createDriver( std::dynamic_pointer_cast(planNode)) { // NOTE: the exchange client can only be used by one operator in a driver. VELOX_CHECK_NOT_NULL(exchangeClient); - operators.push_back(std::make_unique( - id, ctx.get(), exchangeNode, std::move(exchangeClient))); + operators.push_back( + std::make_unique( + id, ctx.get(), exchangeNode, std::move(exchangeClient))); } else if ( auto partitionedOutputNode = std::dynamic_pointer_cast( planNode)) { - operators.push_back(std::make_unique( - id, ctx.get(), partitionedOutputNode, eagerFlush(*planNode))); + operators.push_back( + std::make_unique( + id, ctx.get(), partitionedOutputNode, eagerFlush(*planNode))); } else if ( auto joinNode = std::dynamic_pointer_cast(planNode)) { @@ -589,8 +581,9 @@ std::shared_ptr DriverFactory::createDriver( auto aggregationNode = std::dynamic_pointer_cast(planNode)) { if (aggregationNode->isPreGrouped()) { - operators.push_back(std::make_unique( - id, ctx.get(), aggregationNode)); + operators.push_back( + std::make_unique( + id, ctx.get(), aggregationNode)); } else { operators.push_back( std::make_unique(id, ctx.get(), aggregationNode)); @@ -637,12 +630,36 @@ std::shared_ptr DriverFactory::createDriver( std::dynamic_pointer_cast(planNode)) { operators.push_back( std::make_unique(id, ctx.get(), markDistinctNode)); + } else if ( + auto enforceDistinctNode = + std::dynamic_pointer_cast( + planNode)) { + if (enforceDistinctNode->isPreGrouped()) { + operators.push_back( + std::make_unique( + id, ctx.get(), enforceDistinctNode)); + } else { + operators.push_back( + std::make_unique( + id, ctx.get(), enforceDistinctNode)); + } + } else if ( + auto markSortedNode = + std::dynamic_pointer_cast(planNode)) { + operators.push_back( + std::make_unique(id, ctx.get(), markSortedNode)); } else if ( auto localMerge = std::dynamic_pointer_cast(planNode)) { auto localMergeOp = std::make_unique(id, ctx.get(), localMerge); operators.push_back(std::move(localMergeOp)); + } else if ( + auto mixedUnion = + std::dynamic_pointer_cast(planNode)) { + auto mixedUnionOp = + std::make_unique(id, ctx.get(), mixedUnion); + operators.push_back(std::move(mixedUnionOp)); } else if ( auto mergeJoin = std::dynamic_pointer_cast(planNode)) { @@ -653,12 +670,13 @@ std::shared_ptr DriverFactory::createDriver( auto localPartitionNode = std::dynamic_pointer_cast( planNode)) { - operators.push_back(std::make_unique( - id, - ctx.get(), - localPartitionNode->outputType(), - localPartitionNode->id(), - ctx->partitionId)); + operators.push_back( + std::make_unique( + id, + ctx.get(), + localPartitionNode->outputType(), + localPartitionNode->id(), + ctx->partitionId)); } else if ( auto unnest = std::dynamic_pointer_cast(planNode)) { @@ -673,17 +691,19 @@ std::shared_ptr DriverFactory::createDriver( auto assignUniqueIdNode = std::dynamic_pointer_cast( planNode)) { - operators.push_back(std::make_unique( - id, - ctx.get(), - assignUniqueIdNode, - assignUniqueIdNode->taskUniqueId(), - assignUniqueIdNode->uniqueIdCounter())); + operators.push_back( + std::make_unique( + id, + ctx.get(), + assignUniqueIdNode, + assignUniqueIdNode->taskUniqueId(), + assignUniqueIdNode->uniqueIdCounter())); } else if ( const auto traceScanNode = std::dynamic_pointer_cast(planNode)) { - operators.push_back(std::make_unique( - id, ctx.get(), traceScanNode)); + operators.push_back( + std::make_unique( + id, ctx.get(), traceScanNode)); } else { std::unique_ptr extended; if (planNode->requiresExchangeClient()) { @@ -731,7 +751,7 @@ std::vector> DriverFactory::replaceOperators( } driver.operators_.erase( - driver.operators_.begin() + begin, driver.operators_.begin() + end); + driver.operators_.cbegin() + begin, driver.operators_.cbegin() + end); // Insert the replacement at the place of the erase. Do manually because // insert() is not good with unique pointers. @@ -812,6 +832,21 @@ std::vector DriverFactory::needsSpatialJoinBridges() const { return planNodeIds; } +std::vector DriverFactory::needsIndexLookupJoinBridges() + const { + std::vector planNodeIds; + for (const auto& planNode : planNodes) { + if (auto joinNode = + std::dynamic_pointer_cast( + planNode)) { + if (joinNode->needsIndexSplit()) { + planNodeIds.emplace_back(joinNode->id()); + } + } + } + return planNodeIds; +} + // static void DriverFactory::registerAdapter(DriverAdapter adapter) { adapters.push_back(std::move(adapter)); diff --git a/velox/exec/MarkDistinct.cpp b/velox/exec/MarkDistinct.cpp index 8b4d0c09e55..2b562c714af 100644 --- a/velox/exec/MarkDistinct.cpp +++ b/velox/exec/MarkDistinct.cpp @@ -15,7 +15,11 @@ */ #include "velox/exec/MarkDistinct.h" -#include "velox/common/base/Range.h" + +#include "velox/common/memory/Memory.h" +#include "velox/common/memory/MemoryArbitrator.h" +#include "velox/exec/OperatorType.h" +#include "velox/exec/OperatorUtils.h" #include "velox/vector/FlatVector.h" #include @@ -32,39 +36,161 @@ MarkDistinct::MarkDistinct( planNode->outputType(), operatorId, planNode->id(), - "MarkDistinct") { - const auto& inputType = planNode->sources()[0]->outputType(); + OperatorType::kMarkDistinct, + planNode->canSpill(driverCtx->queryConfig()) + ? driverCtx->makeSpillConfig(operatorId, planNode->name()) + : std::nullopt) { + inputType_ = planNode->sources()[0]->outputType(); // Set all input columns as identity projection. - for (auto i = 0; i < inputType->size(); ++i) { + for (auto i = 0; i < inputType_->size(); ++i) { identityProjections_.emplace_back(i, i); } - // We will use result[0] for distinct mask output. - resultProjections_.emplace_back(0, inputType->size()); + // Use result[0] for distinct mask output. + resultProjections_.emplace_back(0, inputType_->size()); + + for (const auto& key : planNode->distinctKeys()) { + distinctKeyChannels_.push_back(inputType_->getChildIdx(key->name())); + } - groupingSet_ = GroupingSet::createForMarkDistinct( - inputType, - createVectorHashers(inputType, planNode->distinctKeys()), + groupingSet_ = GroupingSet::createForDistinct( + inputType_, + createVectorHashers(inputType_, planNode->distinctKeys()), + /*preGroupedKeys=*/{}, operatorCtx_.get(), &nonReclaimableSection_); results_.resize(1); + + if (spillEnabled()) { + setSpillPartitionBits(); + } } void MarkDistinct::addInput(RowVectorPtr input) { - groupingSet_->addInput(input, false /*mayPushdown*/); + ensureInputFits(input); + + if (inputSpiller_ != nullptr) { + spillInput(input, pool()); + return; + } + // Don't add to the hash table here. We defer it to getOutput() so that if + // spill() is called between addInput() and getOutput(), the hash table spill + // won't include this input's keys. This prevents those keys from being + // re-suppressed during restore. input_ = std::move(input); } +void MarkDistinct::noMoreInput() { + Operator::noMoreInput(); + + if (inputSpiller_ != nullptr) { + finishSpillInputAndRestoreNext(); + } +} + +void MarkDistinct::finishSpillInputAndRestoreNext() { + VELOX_CHECK_NOT_NULL(inputSpiller_); + inputSpiller_->finishSpill(spillInputPartitionSet_); + inputSpiller_.reset(); + removeEmptyPartitions(spillInputPartitionSet_); + restoreNextSpillPartition(); +} + +void MarkDistinct::restoreNextSpillPartition() { + if (spillInputPartitionSet_.empty()) { + return; + } + + auto it = spillInputPartitionSet_.begin(); + restoringPartitionId_ = it->first; + + spillInputReader_ = it->second->createUnorderedReader( + spillConfig_->readBufferSize, pool(), spillStats_.get()); + + auto hashTableIt = spillHashTablePartitionSet_.find(it->first); + if (hashTableIt != spillHashTablePartitionSet_.end()) { + auto reader = hashTableIt->second->createUnorderedReader( + spillConfig_->readBufferSize, pool(), spillStats_.get()); + + setSpillPartitionBits(&(it->first)); + + auto* table = groupingSet_->table(); + const auto& hashers = table->hashers(); + auto lookup = std::make_unique(hashers, pool()); + + std::vector columns(inputType_->size()); + RowVectorPtr data; + while (reader->nextBatch(data)) { + for (auto i = 0; i < hashers.size(); ++i) { + columns[hashers[i]->channel()] = data->childAt(i); + } + + auto input = std::make_shared( + pool(), inputType_, nullptr, data->size(), std::move(columns)); + + SelectivityVector rows(data->size()); + table->prepareForGroupProbe( + *lookup, input, rows, spillConfig_->startPartitionBit); + table->groupProbe(*lookup, spillConfig_->startPartitionBit); + + columns.assign(inputType_->size(), nullptr); + } + } + + spillInputPartitionSet_.erase(it); + + RowVectorPtr spilledInput; + spillInputReader_->nextBatch(spilledInput); + VELOX_CHECK_NOT_NULL(spilledInput); + addInput(std::move(spilledInput)); +} + +void MarkDistinct::recursiveSpillInput() { + RowVectorPtr spilledInput; + while (spillInputReader_->nextBatch(spilledInput)) { + spillInput(spilledInput, pool()); + + if (shouldYield()) { + yield_ = true; + return; + } + } + + finishSpillInputAndRestoreNext(); +} + RowVectorPtr MarkDistinct::getOutput() { - if (isFinished() || !input_) { + if (isFinished()) { return nullptr; } - auto outputSize = input_->size(); - // Re-use memory for the ID vector if possible. + if (input_ == nullptr) { + if (spillInputReader_ == nullptr) { + return nullptr; + } + + recursiveSpillInput(); + if (yield_) { + yield_ = false; + return nullptr; + } + + if (input_ == nullptr) { + return nullptr; + } + } + + // Add the current input to the hash table now, just before producing output. + // This is deferred from addInput() so that if spill() is called between + // addInput() and getOutput(), the hash table doesn't contain this input's + // keys — ensuring they are correctly marked as new during restore. + groupingSet_->addInput(input_, /*mayPushdown=*/false); + + const auto outputSize = input_->size(); + VectorPtr& result = results_[0]; if (result && result.use_count() == 1) { BaseVector::prepareForReuse(result, outputSize); @@ -72,26 +198,269 @@ RowVectorPtr MarkDistinct::getOutput() { result = BaseVector::create(BOOLEAN(), outputSize, operatorCtx_->pool()); } - // newGroups contains the indices of distinct rows. - // For each index in newGroups, we mark the index'th bit true in the result - // vector. - auto resultBits = + auto* resultBits = results_[0]->as>()->mutableRawValues(); - bits::fillBits(resultBits, 0, outputSize, false); for (const auto i : groupingSet_->hashLookup().newGroups) { bits::setBit(resultBits, i, true); } - auto output = fillOutput(outputSize, nullptr); - // Drop reference to input_ to make it singly-referenced at the producer and - // allow for memory reuse. + auto output = fillOutput(outputSize, nullptr); input_ = nullptr; + if (spillInputReader_ != nullptr) { + RowVectorPtr spilledInput; + if (spillInputReader_->nextBatch(spilledInput)) { + addInput(std::move(spilledInput)); + } else { + spillInputReader_.reset(); + restoringPartitionId_.reset(); + groupingSet_->resetTable(true); + restoreNextSpillPartition(); + } + } + return output; } bool MarkDistinct::isFinished() { - return noMoreInput_ && !input_; + return noMoreInput_ && input_ == nullptr && spillInputReader_ == nullptr; +} + +void MarkDistinct::ensureInputFits(const RowVectorPtr& input) { + if (!spillEnabled() || inputSpiller_ != nullptr) { + return; + } + + const auto numDistinct = groupingSet_->numDistinct(); + if (numDistinct == 0) { + return; + } + + auto* table = groupingSet_->table(); + auto* rows = table->rows(); + const auto [freeRows, outOfLineFreeBytes] = rows->freeSpace(); + const auto outOfLineBytes = + rows->stringAllocator().retainedSize() - outOfLineFreeBytes; + const auto outOfLineBytesPerRow = outOfLineBytes / numDistinct; + + if (testingTriggerSpill(pool()->name())) { + Operator::ReclaimableSectionGuard guard(this); + memory::testingRunArbitration(pool()); + return; + } + + const auto currentUsage = pool()->usedBytes(); + const auto minReservationBytes = + currentUsage * spillConfig_->minSpillableReservationPct / 100; + const auto availableReservationBytes = pool()->availableReservation(); + const auto tableIncrementBytes = + static_cast(table->hashTableSizeIncrease(input->size())); + const auto incrementBytes = + static_cast(rows->sizeIncrement( + input->size(), outOfLineBytesPerRow * input->size())) + + tableIncrementBytes; + + if (availableReservationBytes >= minReservationBytes) { + if ((tableIncrementBytes == 0) && (freeRows > input->size()) && + (outOfLineBytes == 0 || + outOfLineFreeBytes >= outOfLineBytesPerRow * input->size())) { + return; + } + + if (availableReservationBytes > 2 * incrementBytes) { + return; + } + } + + const auto targetIncrementBytes = std::max( + incrementBytes * 2, + currentUsage * spillConfig_->spillableReservationGrowthPct / 100); + { + Operator::ReclaimableSectionGuard guard(this); + if (pool()->maybeReserve(targetIncrementBytes)) { + if (inputSpiller_ != nullptr) { + pool()->release(); + } + return; + } + } + + LOG(WARNING) << "Failed to reserve " << succinctBytes(targetIncrementBytes) + << " for memory pool " << pool()->name() + << ", root pool: " << pool()->root()->name() + << ", used: " << succinctBytes(pool()->usedBytes()) + << ", reservation: " << succinctBytes(pool()->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool()->root()->reservedBytes()); +} + +void MarkDistinct::reclaim( + uint64_t /* unused */, + memory::MemoryReclaimer::Stats& /* unused */) { + VELOX_CHECK(canReclaim()); + VELOX_CHECK(!nonReclaimableSection_); + + if (groupingSet_->numDistinct() == 0) { + return; + } + + if (FOLLY_UNLIKELY(exceededMaxSpillLevelLimit_)) { + LOG(WARNING) << "Exceeded mark distinct spill level limit: " + << spillConfig_->maxSpillLevel + << ", and abandon spilling for memory pool: " << pool()->name() + << ", root pool: " << pool()->root()->name() + << ", used: " << succinctBytes(pool()->usedBytes()) + << ", reservation: " << succinctBytes(pool()->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool()->root()->reservedBytes()); + spillStats_->spillMaxLevelExceededCount.fetch_add( + 1, std::memory_order_relaxed); + return; + } + + spill(); +} + +SpillPartitionIdSet MarkDistinct::spillHashTable() { + VELOX_CHECK_GT(groupingSet_->numDistinct(), 0); + + auto* table = groupingSet_->table(); + auto columnTypes = table->rows()->columnTypes(); + auto tableType = ROW(std::move(columnTypes)); + + auto hashTableSpiller = std::make_unique( + table->rows(), + restoringPartitionId_, + tableType, + spillPartitionBits_, + &spillConfig_.value(), + spillStats_.get()); + + hashTableSpiller->spill(); + hashTableSpiller->finishSpill(spillHashTablePartitionSet_); + + groupingSet_->resetTable(true); + pool()->release(); + return hashTableSpiller->state().spilledPartitionIdSet(); +} + +void MarkDistinct::setupInputSpiller( + const SpillPartitionIdSet& spillPartitionIdSet) { + VELOX_CHECK(!spillPartitionIdSet.empty()); + + inputSpiller_ = std::make_unique( + inputType_, + restoringPartitionId_, + spillPartitionBits_, + &spillConfig_.value(), + spillStats_.get()); + + spillHashFunction_ = std::make_unique( + inputSpiller_->hashBits(), inputType_, distinctKeyChannels_); +} + +void MarkDistinct::spill() { + VELOX_CHECK(spillEnabled()); + + spilled_ = true; + + const auto spillPartitionIdSet = spillHashTable(); + VELOX_CHECK_EQ(groupingSet_->numDistinct(), 0); + + setupInputSpiller(spillPartitionIdSet); + if (input_ != nullptr) { + spillInput(input_, memory::spillMemoryPool()); + input_ = nullptr; + } + results_.clear(); + results_.resize(1); +} + +void MarkDistinct::spillInput( + const RowVectorPtr& input, + memory::MemoryPool* pool) { + const auto numRows = input->size(); + + std::vector partitionAssignments(numRows); + const auto singlePartition = + spillHashFunction_->partition(*input, partitionAssignments); + + const auto numPartitions = spillHashFunction_->numPartitions(); + + std::vector partitionIndices(numPartitions); + std::vector rawPartitionIndices(numPartitions); + + for (auto i = 0; i < numPartitions; ++i) { + partitionIndices[i] = allocateIndices(numRows, pool); + rawPartitionIndices[i] = partitionIndices[i]->asMutable(); + } + + std::vector numSpillInputs(numPartitions, 0); + + for (auto row = 0; row < numRows; ++row) { + const auto partition = singlePartition.has_value() + ? singlePartition.value() + : partitionAssignments[row]; + rawPartitionIndices[partition][numSpillInputs[partition]++] = row; + } + + // Ensure vector are lazy loaded before spilling. + for (auto i = 0; i < input->childrenSize(); ++i) { + input->childAt(i)->loadedVector(); + } + + for (auto partition = 0; partition < numSpillInputs.size(); ++partition) { + const auto numInputs = numSpillInputs[partition]; + if (numInputs == 0) { + continue; + } + + inputSpiller_->spill( + SpillPartitionId(partition), + wrap(numInputs, partitionIndices[partition], input)); + } +} + +void MarkDistinct::setSpillPartitionBits( + const SpillPartitionId* restoredPartitionId) { + const auto startPartitionBitOffset = restoredPartitionId == nullptr + ? spillConfig_->startPartitionBit + : partitionBitOffset( + *restoredPartitionId, + spillConfig_->startPartitionBit, + spillConfig_->numPartitionBits) + + spillConfig_->numPartitionBits; + if (spillConfig_->exceedSpillLevelLimit(startPartitionBitOffset)) { + exceededMaxSpillLevelLimit_ = true; + return; + } + + exceededMaxSpillLevelLimit_ = false; + spillPartitionBits_ = HashBitRange( + startPartitionBitOffset, + startPartitionBitOffset + spillConfig_->numPartitionBits); +} + +MarkDistinctHashTableSpiller::MarkDistinctHashTableSpiller( + RowContainer* container, + std::optional parentId, + RowTypePtr rowType, + HashBitRange bits, + const common::SpillConfig* spillConfig, + exec::SpillStats* spillStats) + : SpillerBase( + container, + std::move(rowType), + bits, + {}, + spillConfig->maxFileSize, + spillConfig->maxSpillRunRows, + parentId, + spillConfig, + spillStats) {} + +void MarkDistinctHashTableSpiller::spill() { + SpillerBase::spill(nullptr); } } // namespace facebook::velox::exec diff --git a/velox/exec/MarkDistinct.h b/velox/exec/MarkDistinct.h index dabba1b179a..c8c582b5ea8 100644 --- a/velox/exec/MarkDistinct.h +++ b/velox/exec/MarkDistinct.h @@ -17,10 +17,21 @@ #pragma once #include "velox/exec/GroupingSet.h" +#include "velox/exec/HashPartitionFunction.h" #include "velox/exec/Operator.h" +#include "velox/exec/Spiller.h" namespace facebook::velox::exec { +/// Marks distinct rows based on a set of grouping keys. For each input row, +/// produces a boolean output column indicating whether the row's key +/// combination is seen for the first time. +/// +/// Supports spilling by persisting the hash table state to disk (like +/// RowNumber). When spill is triggered, the hash table contents and future +/// input are partitioned and written to disk. During restore, each partition's +/// hash table is rebuilt from the spilled data, preserving knowledge of which +/// keys were already seen before spill. class MarkDistinct : public Operator { public: MarkDistinct( @@ -29,6 +40,10 @@ class MarkDistinct : public Operator { const std::shared_ptr& planNode); bool preservesOrder() const override { + return false; + } + + bool isFilter() const override { return true; } @@ -38,6 +53,8 @@ class MarkDistinct : public Operator { void addInput(RowVectorPtr input) override; + void noMoreInput() override; + RowVectorPtr getOutput() override; BlockingReason isBlocked(ContinueFuture* /*future*/) override { @@ -46,8 +63,85 @@ class MarkDistinct : public Operator { bool isFinished() override; + void reclaim(uint64_t targetBytes, memory::MemoryReclaimer::Stats& stats) + override; + private: - // TODO: Document spilling configuration in spilling.rst. + bool spillEnabled() const { + return spillConfig_.has_value(); + } + + void ensureInputFits(const RowVectorPtr& input); + + void spill(); + + void spillInput(const RowVectorPtr& input, memory::MemoryPool* pool); + + void setupInputSpiller(const SpillPartitionIdSet& spillPartitionIdSet); + + void setSpillPartitionBits( + const SpillPartitionId* restoredPartitionId = nullptr); + + SpillPartitionIdSet spillHashTable(); + + void restoreNextSpillPartition(); + + void finishSpillInputAndRestoreNext(); + + /// Drains remaining batches from the current spill partition reader and + /// spills them when a recursive spill was triggered during restore. + void recursiveSpillInput(); + + RowTypePtr inputType_; + std::unique_ptr groupingSet_; + + HashBitRange spillPartitionBits_; + + std::unique_ptr inputSpiller_; + + std::unique_ptr> spillInputReader_; + + std::optional restoringPartitionId_; + + SpillPartitionSet spillInputPartitionSet_; + + std::unique_ptr spillHashFunction_; + + SpillPartitionSet spillHashTablePartitionSet_; + + bool exceededMaxSpillLevelLimit_{false}; + + bool spilled_{false}; + + bool yield_{false}; + + std::vector distinctKeyChannels_; +}; + +/// Spills the hash table contents (distinct keys) from MarkDistinct's +/// GroupingSet to disk, partitioned by key hash. +class MarkDistinctHashTableSpiller : public SpillerBase { + public: + static constexpr std::string_view kType = "MarkDistinctHashTableSpiller"; + + MarkDistinctHashTableSpiller( + RowContainer* container, + std::optional parentId, + RowTypePtr rowType, + HashBitRange bits, + const common::SpillConfig* spillConfig, + exec::SpillStats* spillStats); + + void spill(); + + private: + bool needSort() const override { + return false; + } + + std::string type() const override { + return std::string(kType); + } }; } // namespace facebook::velox::exec diff --git a/velox/exec/MarkSorted.cpp b/velox/exec/MarkSorted.cpp new file mode 100644 index 00000000000..120175323a6 --- /dev/null +++ b/velox/exec/MarkSorted.cpp @@ -0,0 +1,330 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/MarkSorted.h" +#include "velox/common/base/BitUtil.h" +#include "velox/common/base/CompareFlags.h" +#include "velox/functions/lib/SIMDComparisonUtil.h" +#include "velox/vector/FlatVector.h" + +namespace facebook::velox::exec { + +namespace { +bool isSimdEligibleType(TypeKind kind) { + return kind == TypeKind::TINYINT || kind == TypeKind::SMALLINT || + kind == TypeKind::INTEGER || kind == TypeKind::BIGINT; +} +} // namespace + +MarkSorted::MarkSorted( + int32_t operatorId, + DriverCtx* driverCtx, + const std::shared_ptr& planNode) + : Operator( + driverCtx, + planNode->outputType(), + operatorId, + planNode->id(), + "MarkSorted"), + markerName_(planNode->markerName()), + zeroCopyThreshold_( + driverCtx->queryConfig().markSortedZeroCopyThreshold()) { + const auto& inputType = planNode->sources()[0]->outputType(); + const auto& sortingKeys = planNode->sortingKeys(); + const auto& sortingOrders = planNode->sortingOrders(); + + // Set all input columns as identity projection. + for (auto i = 0; i < inputType->size(); ++i) { + identityProjections_.emplace_back(i, i); + } + + // Map the marker result (results_[0]) to the last column position in output, + // immediately after all input columns. + resultProjections_.emplace_back(0, inputType->size()); + + // Extract channel indices for sorting keys. + sortingKeyChannels_.reserve(sortingKeys.size()); + compareFlags_.reserve(sortingKeys.size()); + sortingOrders_.reserve(sortingKeys.size()); + + for (auto i = 0; i < sortingKeys.size(); ++i) { + const auto& key = sortingKeys[i]; + auto channel = inputType->getChildIdx(key->name()); + sortingKeyChannels_.push_back(channel); + + const auto& order = sortingOrders[i]; + sortingOrders_.push_back(order); + + // Build CompareFlags from SortOrder. + compareFlags_.push_back( + {order.isNullsFirst(), + order.isAscending(), + false, // equalsOnly + CompareFlags::NullHandlingMode::kNullAsValue}); + } + + // Precompute type for lastRow_ (contains only sorting key columns). + std::vector keyNames; + std::vector keyTypes; + keyNames.reserve(sortingKeys.size()); + keyTypes.reserve(sortingKeys.size()); + for (auto i = 0; i < sortingKeys.size(); ++i) { + auto channel = sortingKeyChannels_[i]; + keyNames.push_back(inputType->nameOf(channel)); + keyTypes.push_back(inputType->childAt(channel)); + } + lastRowType_ = ROW(std::move(keyNames), std::move(keyTypes)); + + results_.resize(1); +} + +void MarkSorted::addInput(RowVectorPtr input) { + input_ = std::move(input); +} + +bool MarkSorted::isSortedRelativeTo( + const RowVectorPtr& currentData, + vector_size_t currentIndex, + const RowVectorPtr& prevData, + vector_size_t prevIndex) { + for (auto i = 0; i < sortingKeyChannels_.size(); ++i) { + auto channel = sortingKeyChannels_[i]; + const auto& currentColumn = currentData->childAt(channel); + const auto& prevColumn = prevData->childAt(channel); + + auto result = currentColumn->compare( + prevColumn.get(), currentIndex, prevIndex, compareFlags_[i]); + + if (result.has_value()) { + if (result.value() < 0) { + return false; + } else if (result.value() > 0) { + return true; + } + } + } + + // All keys are equal - this is still considered sorted. + return true; +} + +bool MarkSorted::allKeysConstant() const { + for (auto channel : sortingKeyChannels_) { + auto& keyCol = input_->childAt(channel); + if (!keyCol->isConstantEncoding() || keyCol->isNullAt(0)) { + return false; + } + } + return true; +} + +bool MarkSorted::canApplySimdPath() const { + if (sortingKeyChannels_.size() != 1) { + return false; + } + auto& keyCol = input_->childAt(sortingKeyChannels_[0]); + if (!keyCol->isFlatEncoding() || keyCol->mayHaveNulls()) { + return false; + } + if (keyCol->type()->providesCustomComparison()) { + return false; + } + return isSimdEligibleType(keyCol->typeKind()); +} + +void MarkSorted::applySimdComparison( + uint64_t* resultBits, + vector_size_t numRows) { + auto channel = sortingKeyChannels_[0]; + auto& keyCol = input_->childAt(channel); + bool ascending = sortingOrders_[0].isAscending(); + const auto numCompares = numRows - 1; + + // Reuse bit-packed result buffer for SIMD output. + const auto bufferBytes = bits::nbytes(numCompares); + if (!simdBuffer_ || simdBuffer_->size() < bufferBytes) { + simdBuffer_ = AlignedBuffer::allocate(bufferBytes, pool()); + } + auto simdResult = simdBuffer_->asMutable(); + memset(simdResult, 0, bufferBytes); + + // Dispatch by type to call applySimdComparison with typed raw data. + // Consecutive-row trick: rawData+1 as lhs, rawData as rhs. + // For ascending: row[i+1] >= row[i] means sorted. + // For descending: row[i+1] <= row[i] means sorted. + auto dispatchSimd = [&](auto dummy) { + using T = decltype(dummy); + const T* rawData = keyCol->asFlatVector()->rawValues(); + if (ascending) { + functions::applySimdComparison>( + 0, numCompares, rawData + 1, rawData, simdResult); + } else { + functions::applySimdComparison>( + 0, numCompares, rawData + 1, rawData, simdResult); + } + }; + + auto kind = keyCol->typeKind(); + if (kind == TypeKind::TINYINT) { + dispatchSimd(int8_t{}); + } else if (kind == TypeKind::SMALLINT) { + dispatchSimd(int16_t{}); + } else if (kind == TypeKind::INTEGER) { + dispatchSimd(int32_t{}); + } else { + VELOX_DCHECK_EQ(kind, TypeKind::BIGINT); + dispatchSimd(int64_t{}); + } + + // Copy SIMD results into resultBits with +1 offset. + // simdResult bit i = comparison result for (data[i+1] vs data[i]). + // Row 0 is handled by cross-batch logic, so we write starting at bit 1. + bits::copyBits( + reinterpret_cast(simdResult), + 0, + resultBits, + 1, + numCompares); +} + +RowVectorPtr MarkSorted::getOutput() { + if (isFinished() || !input_) { + return nullptr; + } + + auto outputSize = input_->size(); + + // Handle empty batches. + if (outputSize == 0) { + input_ = nullptr; + return nullptr; + } + + // Re-use memory for the marker vector if possible. + VectorPtr& result = results_[0]; + if (result && result.use_count() == 1) { + BaseVector::prepareForReuse(result, outputSize); + } else { + result = BaseVector::create(BOOLEAN(), outputSize, operatorCtx_->pool()); + } + + auto resultBits = + results_[0]->as>()->mutableRawValues(); + + // Initialize all bits to true (sorted), then clear for violations. + bits::fillBits(resultBits, 0, outputSize, true); + + // Cross-batch comparison: compare first row of current batch with last row + // of previous batch. + if (prevInput_) { + // Zero-copy mode: prevInput_ has same schema as current input. + for (column_index_t k = 0; k < sortingKeyChannels_.size(); ++k) { + auto channel = sortingKeyChannels_[k]; + auto cmp = prevInput_->childAt(channel)->compare( + input_->childAt(channel).get(), + prevInput_->size() - 1, + 0, + compareFlags_[k]); + if (cmp.has_value() && cmp.value() != 0) { + if (cmp.value() > 0) { + // Previous > current in sort order means NOT sorted. + bits::clearBit(resultBits, 0); + } + break; + } + } + prevInput_.reset(); + } else if (lastRow_) { + // Copy mode: lastRow_ has key columns at sequential indices. + for (column_index_t k = 0; k < sortingKeyChannels_.size(); ++k) { + auto channel = sortingKeyChannels_[k]; + auto cmp = lastRow_->childAt(k)->compare( + input_->childAt(channel).get(), 0, 0, compareFlags_[k]); + if (cmp.has_value() && cmp.value() != 0) { + if (cmp.value() > 0) { + bits::clearBit(resultBits, 0); + } + break; + } + } + } + + // Within-batch comparison. + if (allKeysConstant()) { + // ConstantVector fast path: all key columns are constant non-null, + // so all rows are trivially sorted. Bits are already true. + } else if (canApplySimdPath()) { + // SIMD fast path: single flat non-null primitive key. + applySimdComparison(resultBits, outputSize); + } else { + // Generic path: compare each row with its predecessor. + for (auto i = 1; i < outputSize; ++i) { + if (!isSortedRelativeTo(input_, i, input_, i - 1)) { + bits::setBit(resultBits, i, false); + } + } + } + + // Store last row for next batch's cross-batch comparison. + if (input_->size() < zeroCopyThreshold_) { + prevInput_ = input_; + lastRow_.reset(); + } else { + copyLastRowKeyColumns(); + prevInput_.reset(); + } + + auto output = fillOutput(outputSize, nullptr); + + // Drop reference to input_ to make it singly-referenced at the producer and + // allow for memory reuse. + input_ = nullptr; + + return output; +} + +void MarkSorted::copyLastRowKeyColumns() { + auto lastIndex = input_->size() - 1; + auto numKeys = sortingKeyChannels_.size(); + + std::vector keyChildren(numKeys); + for (auto i = 0; i < numKeys; ++i) { + auto channel = sortingKeyChannels_[i]; + const auto& sourceColumn = input_->childAt(channel); + keyChildren[i] = + BaseVector::create(sourceColumn->type(), 1, operatorCtx_->pool()); + keyChildren[i]->copy(sourceColumn.get(), 0, lastIndex, 1); + } + + lastRow_ = std::make_shared( + operatorCtx_->pool(), + lastRowType_, + nullptr, // no nulls + 1, // single row + std::move(keyChildren)); +} + +void MarkSorted::noMoreInput() { + Operator::noMoreInput(); + lastRow_.reset(); + prevInput_.reset(); +} + +bool MarkSorted::isFinished() { + return noMoreInput_ && !input_; +} + +} // namespace facebook::velox::exec diff --git a/velox/exec/MarkSorted.h b/velox/exec/MarkSorted.h new file mode 100644 index 00000000000..e546e3c00a8 --- /dev/null +++ b/velox/exec/MarkSorted.h @@ -0,0 +1,116 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/exec/Operator.h" + +namespace facebook::velox::exec { + +/// Marks each row with a boolean indicating whether it maintains sort order +/// relative to its predecessor. The first row is always marked true. +/// Subsequent rows are marked true if they are sorted relative to the +/// previous row based on the configured sorting keys and orders. +class MarkSorted : public Operator { + public: + MarkSorted( + int32_t operatorId, + DriverCtx* driverCtx, + const std::shared_ptr& planNode); + + bool preservesOrder() const override { + return true; + } + + bool isFilter() const override { + return true; + } + + bool needsInput() const override { + return !noMoreInput_ && !input_; + } + + void addInput(RowVectorPtr input) override; + + RowVectorPtr getOutput() override; + + BlockingReason isBlocked(ContinueFuture* /*future*/) override { + return BlockingReason::kNotBlocked; + } + + bool isFinished() override; + + void noMoreInput() override; + + private: + /// Compare row at currentIndex in currentData with row at prevIndex in + /// prevData. Both vectors must share the same schema (uses + /// sortingKeyChannels_ to access columns). + bool isSortedRelativeTo( + const RowVectorPtr& currentData, + vector_size_t currentIndex, + const RowVectorPtr& prevData, + vector_size_t prevIndex); + + /// Copy only sorting key columns of the last row from input_ into lastRow_. + /// Creates a single-row RowVector with key columns at sequential indices + /// (0, 1, 2, ...) to avoid holding a reference to the entire input batch. + void copyLastRowKeyColumns(); + + /// Returns true if all key columns in input_ use constant encoding with + /// non-null values. When true, within-batch comparison can be skipped + /// entirely since all rows have the same key values. + bool allKeysConstant() const; + + /// Returns true if input has a single sorting key that is flat, non-null, + /// and a SIMD-eligible primitive type. + bool canApplySimdPath() const; + + /// Apply SIMD comparison for a single primitive sorting key. Writes + /// bit-packed results into resultBits (clears bits for unsorted rows). + /// Row 0 is not modified (handled by cross-batch logic). + void applySimdComparison(uint64_t* resultBits, vector_size_t numRows); + + const std::string markerName_; + std::vector sortingKeyChannels_; + std::vector compareFlags_; + std::vector sortingOrders_; + + /// Key-only RowType for lastRow_ construction. Columns are at sequential + /// indices (0, 1, 2, ...) mapping to sortingKeyChannels_ in the input. + RowTypePtr lastRowType_; + + /// Stores only sorting key column values of the last row from the previous + /// batch, for cross-batch comparison. Has a different schema than input_ + /// (key columns only), so cross-batch comparison uses inline logic instead + /// of isSortedRelativeTo(). + RowVectorPtr lastRow_; + + /// Zero-copy: holds reference to the previous input batch for cross-batch + /// comparison when the batch is smaller than zeroCopyThreshold_. Mutually + /// exclusive with lastRow_ (one or neither is set, never both). + RowVectorPtr prevInput_; + + /// Batch size threshold for zero-copy optimization. Batches smaller than + /// this hold a reference to the entire batch; larger batches deep-copy + /// key columns only. + const int32_t zeroCopyThreshold_; + + /// Reusable buffer for SIMD comparison results. Allocated on first use + /// and grown as needed to avoid per-batch allocation overhead. + BufferPtr simdBuffer_; +}; +} // namespace facebook::velox::exec diff --git a/velox/exec/MemoryReclaimer.cpp b/velox/exec/MemoryReclaimer.cpp index 4aee429883d..1837da7881b 100644 --- a/velox/exec/MemoryReclaimer.cpp +++ b/velox/exec/MemoryReclaimer.cpp @@ -108,9 +108,10 @@ uint64_t ParallelMemoryReclaimer::reclaim( if (!reclaimableBytesOpt.has_value()) { continue; } - candidates.push_back(Candidate{ - std::move(child), - static_cast(reclaimableBytesOpt.value())}); + candidates.push_back( + Candidate{ + std::move(child), + static_cast(reclaimableBytesOpt.value())}); } } } @@ -134,21 +135,23 @@ uint64_t ParallelMemoryReclaimer::reclaim( if (candidate.reclaimableBytes == 0) { continue; } - reclaimTasks.push_back(memory::createAsyncMemoryReclaimTask( - [&, reclaimPool = candidate.pool]() { - try { - Stats reclaimStats; - const auto bytes = - reclaimPool->reclaim(targetBytes, maxWaitMs, reclaimStats); - return std::make_unique( - bytes, std::move(reclaimStats)); - } catch (const std::exception& e) { - VELOX_MEM_LOG(ERROR) << "Reclaim from memory pool " << pool->name() - << " failed: " << e.what(); - // The exception is captured and thrown by the caller. - return std::make_unique(std::current_exception()); - } - })); + reclaimTasks.push_back( + memory::createAsyncMemoryReclaimTask( + [&, reclaimPool = candidate.pool]() { + try { + Stats reclaimStats; + const auto bytes = + reclaimPool->reclaim(targetBytes, maxWaitMs, reclaimStats); + return std::make_unique( + bytes, std::move(reclaimStats)); + } catch (const std::exception& e) { + VELOX_MEM_LOG(ERROR) << "Reclaim from memory pool " + << pool->name() << " failed: " << e.what(); + // The exception is captured and thrown by the caller. + return std::make_unique( + std::current_exception()); + } + })); if (reclaimTasks.size() > 1) { executor_->add([source = reclaimTasks.back()]() { source->prepare(); }); } diff --git a/velox/exec/Merge.cpp b/velox/exec/Merge.cpp index 887fec7017e..aada7c0a413 100644 --- a/velox/exec/Merge.cpp +++ b/velox/exec/Merge.cpp @@ -15,26 +15,16 @@ */ #include "velox/exec/Merge.h" +#include +#include #include "velox/common/testutil/TestValue.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/OperatorUtils.h" #include "velox/exec/Task.h" using facebook::velox::common::testutil::TestValue; namespace facebook::velox::exec { -namespace { -std::unique_ptr getVectorSerdeOptions( - const core::QueryConfig& queryConfig, - VectorSerde::Kind kind) { - std::unique_ptr options = - kind == VectorSerde::Kind::kPresto - ? std::make_unique() - : std::make_unique(); - options->compressionKind = - common::stringToCompressionKind(queryConfig.shuffleCompressionKind()); - return options; -} -} // namespace Merge::Merge( int32_t operatorId, @@ -44,7 +34,7 @@ Merge::Merge( sortingKeys, const std::vector& sortingOrders, const std::string& planNodeId, - const std::string& operatorType, + std::string_view operatorType, const std::optional& spillConfig) : SourceOperator( driverCtx, @@ -170,8 +160,12 @@ void Merge::setupSpillMerger() { std::vector> spillReadFiles; spillReadFiles.reserve(spillFiles.size()); for (const auto& spillFile : spillFiles) { - spillReadFiles.emplace_back(SpillReadFile::create( - spillFile, spillConfig_->readBufferSize, pool(), spillStats_.get())); + spillReadFiles.emplace_back( + SpillReadFile::create( + spillFile, + spillConfig_->readBufferSize, + pool(), + spillStats_.get())); } spillReadFilesGroups.push_back(std::move(spillReadFiles)); } @@ -206,8 +200,9 @@ void Merge::maybeStartNextMergeSourceGroup() { std::vector> cursors; cursors.reserve(sources.size()); for (auto* source : sources) { - cursors.push_back(std::make_unique( - source, sortingKeys_, maxOutputBatchRows_)); + cursors.push_back( + std::make_unique( + source, sortingKeys_, maxOutputBatchRows_)); } // TODO: consider to provide a config other than the regular operator batch @@ -342,8 +337,9 @@ SourceMerger::SourceMerger( } return streams; }()), - merger_(std::make_unique>( - std::move(sourceStreams))), + merger_( + std::make_unique>( + std::move(sourceStreams))), pool_(pool) {} void SourceMerger::isBlocked( @@ -390,10 +386,7 @@ RowVectorPtr SourceMerger::getOutput( VELOX_CHECK_GT(outputBatchRows_, 0); if (!output_) { - output_ = BaseVector::create(type_, outputBatchRows_, pool_); - for (auto& child : output_->children()) { - child->resize(outputBatchRows_); - } + output_ = createOutputVector(); } for (;;) { @@ -439,6 +432,23 @@ RowVectorPtr SourceMerger::getOutput( } } +RowVectorPtr SourceMerger::createOutputVector() { + // Attempt to generate output vector using stream data to preserve encodings. + // First, find the first stream with non-null data to determine column + // encodings. + const RowVector* source = nullptr; + for (const auto* stream : streams_) { + if (stream->hasData() && (source = stream->data())) { + return BaseVector::createEmptyLike( + source, outputBatchRows_, pool_); + } + } + + // If a non-null stream cannot be found, default to generating row vector by + // type. + return BaseVector::create(type_, outputBatchRows_, pool_); +} + bool SourceStream::operator<(const MergeStream& other) const { const auto& otherCursor = static_cast(other); for (auto i = 0; i < sortingKeys_.size(); ++i) { @@ -496,7 +506,8 @@ void SourceStream::copyToOutput(RowVectorPtr& output) { bool SourceStream::fetchMoreData(std::vector& futures) { ContinueFuture future; - auto reason = source_->next(data_, &future); + bool drained{false}; + auto reason = source_->next(data_, &future, drained); if (reason != BlockingReason::kNotBlocked) { needData_ = true; futures.emplace_back(std::move(future)); @@ -528,7 +539,7 @@ SpillMerger::SpillMerger( uint64_t maxOutputBatchBytes, int mergeSourceQueueSize, const common::SpillConfig* spillConfig, - const std::shared_ptr>& spillStats, + const std::shared_ptr& spillStats, velox::memory::MemoryPool* pool) : executor_(spillConfig->executor), spillStats_(spillStats), @@ -562,14 +573,21 @@ void SpillMerger::start() { RowVectorPtr SpillMerger::getOutput( std::vector& sourceBlockingFutures, - bool& atEnd) const { + bool& atEnd) { TestValue::adjust( "facebook::velox::exec::SpillMerger::getOutput", &sourceBlockingFutures); sourceMerger_->isBlocked(sourceBlockingFutures); if (!sourceBlockingFutures.empty()) { return nullptr; } - return sourceMerger_->getOutput(sourceBlockingFutures, atEnd); + // SpillMerger::getOutput waits for all readers to finish, reaches EOF, + // and rethrows any captured error. Centralizing error propagation here + // helps prevent potential resource leaks. + auto output = sourceMerger_->getOutput(sourceBlockingFutures, atEnd); + if (atEnd) { + checkError(); + } + return output; } std::vector> SpillMerger::createMergeSources( @@ -609,68 +627,105 @@ std::unique_ptr SpillMerger::createSourceMerger( std::vector> streams; streams.reserve(sources.size()); for (const auto& source : sources) { - streams.push_back(std::make_unique( - source.get(), sortingKeys, maxOutputBatchRows)); + streams.push_back( + std::make_unique( + source.get(), sortingKeys, maxOutputBatchRows)); } return std::make_unique( type, std::move(streams), maxOutputBatchRows, maxOutputBatchBytes, pool); } -// static. -void SpillMerger::asyncReadFromSpillFileStream( +void SpillMerger::finishSource(size_t streamIdx) const { + ContinueFuture future{ContinueFuture::makeEmpty()}; + sources_[streamIdx]->enqueue(nullptr, &future); + VELOX_CHECK(!future.valid()); +} + +void SpillMerger::readFromSpillFileStream( const std::weak_ptr& mergeHolder, size_t streamIdx) { TestValue::adjust( - "facebook::velox::exec::SpillMerger::asyncReadFromSpillFileStream", - static_cast(0)); + "facebook::velox::exec::SpillMerger::readFromSpillFileStream", nullptr); const auto merger = mergeHolder.lock(); if (merger == nullptr) { LOG(ERROR) << "SpillMerger is destroyed, abandon reading from batch stream"; return; } - merger->readFromSpillFileStream(streamIdx); -} -void SpillMerger::readFromSpillFileStream(size_t streamIdx) { - RowVectorPtr vector; - ContinueFuture future{ContinueFuture::makeEmpty()}; - if (!batchStreams_[streamIdx]->nextBatch(vector)) { - VELOX_CHECK_NULL(vector); - sources_[streamIdx]->enqueue(nullptr, &future); - VELOX_CHECK(!future.valid()); - return; - } - const auto blockingReason = - sources_[streamIdx]->enqueue(std::move(vector), &future); - // TODO: add async error handling. - if (blockingReason == BlockingReason::kNotBlocked) { - VELOX_CHECK(!future.valid()); - executor_->add( - [mergeHolder = std::weak_ptr(shared_from_this()), streamIdx]() { - asyncReadFromSpillFileStream(mergeHolder, streamIdx); - }); - } else { - VELOX_CHECK(future.valid()); - std::move(future) - .via(executor_) - .thenValue([mergeHolder = std::weak_ptr(shared_from_this()), - streamIdx](folly::Unit) { - asyncReadFromSpillFileStream(mergeHolder, streamIdx); - }) - .thenError( - folly::tag_t{}, - [streamIdx](const std::exception& e) { - LOG(ERROR) << "Stop the " << streamIdx - << "th batch stream producer on error: " << e.what(); - }); + try { + if (hasError()) { + finishSource(streamIdx); + return; + } + + RowVectorPtr vector; + if (!batchStreams_[streamIdx]->nextBatch(vector)) { + VELOX_CHECK_NULL(vector); + finishSource(streamIdx); + return; + } + + ContinueFuture future{ContinueFuture::makeEmpty()}; + const auto blockingReason = + sources_[streamIdx]->enqueue(std::move(vector), &future); + if (blockingReason == BlockingReason::kNotBlocked) { + VELOX_CHECK(!future.valid()); + readFromSpillFileStream(mergeHolder, streamIdx); + } else { + VELOX_CHECK(future.valid()); + std::move(future) + .via(executor_) + .thenValue([this, mergeHolder, streamIdx](auto&&) { + readFromSpillFileStream(mergeHolder, streamIdx); + }) + .thenError( + folly::tag_t{}, + [this, mergeHolder, streamIdx](const std::exception& e) { + const auto merger = mergeHolder.lock(); + if (merger != nullptr) { + LOG(ERROR) << "Stop the " << streamIdx + << " th source on error: " << e.what(); + setError(std::make_exception_ptr(e)); + finishSource(streamIdx); + } + }); + } + } catch (const std::exception& e) { + LOG(ERROR) << "The " << streamIdx + << " spill stream failed with error: " << e.what(); + setError(std::current_exception()); + finishSource(streamIdx); } } void SpillMerger::scheduleAsyncSpillFileStreamReads() { VELOX_CHECK_EQ(batchStreams_.size(), sources_.size()); for (auto i = 0; i < batchStreams_.size(); ++i) { - executor_->add( - [&, streamIdx = i]() { readFromSpillFileStream(streamIdx); }); + executor_->add([&, streamIdx = i]() { + readFromSpillFileStream(std::weak_ptr(shared_from_this()), streamIdx); + }); + } +} + +void SpillMerger::setError(const std::exception_ptr& exception) { + std::lock_guard l(mutex_); + if (exception_ != nullptr) { + return; + } + exception_ = exception; +} + +bool SpillMerger::hasError() const { + std::lock_guard l(mutex_); + return exception_ != nullptr; +} + +void SpillMerger::checkError() { + if (hasError()) { + sourceMerger_.reset(); + batchStreams_.clear(); + sources_.clear(); + std::rethrow_exception(exception_); } } @@ -685,9 +740,11 @@ LocalMerge::LocalMerge( localMergeNode->sortingKeys(), localMergeNode->sortingOrders(), localMergeNode->id(), - "LocalMerge", + OperatorType::kLocalMerge, localMergeNode->canSpill(driverCtx->queryConfig()) - ? driverCtx->makeSpillConfig(operatorId) + ? driverCtx->makeSpillConfig( + operatorId, + OperatorType::kLocalMerge) : std::nullopt) { VELOX_CHECK_EQ( operatorCtx_->driverCtx()->driverId, @@ -722,11 +779,14 @@ MergeExchange::MergeExchange( mergeExchangeNode->sortingKeys(), mergeExchangeNode->sortingOrders(), mergeExchangeNode->id(), - "MergeExchange"), + OperatorType::kMergeExchange), serde_(getNamedVectorSerde(mergeExchangeNode->serdeKind())), serdeOptions_(getVectorSerdeOptions( - driverCtx->queryConfig(), - mergeExchangeNode->serdeKind())) {} + common::stringToCompressionKind( + driverCtx->queryConfig().shuffleCompressionKind()), + mergeExchangeNode->serdeKind(), + std::nullopt, + driverCtx->queryConfig().minShuffleCompressionPageSizeBytes())) {} BlockingReason MergeExchange::addMergeSources(ContinueFuture* future) { if (operatorCtx_->driverCtx()->driverId != 0) { @@ -741,7 +801,14 @@ BlockingReason MergeExchange::addMergeSources(ContinueFuture* future) { for (;;) { exec::Split split; auto reason = operatorCtx_->task()->getSplitOrFuture( - operatorCtx_->driverCtx()->splitGroupId, planNodeId(), split, *future); + + operatorCtx_->driverCtx()->driverId, + operatorCtx_->driverCtx()->splitGroupId, + planNodeId(), + /*maxPreloadSplits=*/0, + /*preload=*/nullptr, + split, + *future); if (reason != BlockingReason::kNotBlocked) { return reason; } @@ -770,13 +837,14 @@ BlockingReason MergeExchange::addMergeSources(ContinueFuture* future) { operatorCtx_->planNodeId(), operatorCtx_->driverCtx()->pipelineId, remoteSourceIndex); - sources_.emplace_back(MergeSource::createMergeExchangeSource( - this, - remoteSourceTaskIds_[remoteSourceIndex], - operatorCtx_->task()->destination(), - maxQueuedBytesPerSource, - pool, - operatorCtx_->task()->queryCtx()->executor())); + sources_.emplace_back( + MergeSource::createMergeExchangeSource( + this, + remoteSourceTaskIds_[remoteSourceIndex], + operatorCtx_->task()->destination(), + maxQueuedBytesPerSource, + pool, + operatorCtx_->task()->queryCtx()->executor())); } } // TODO Delay this call until all input data has been processed. @@ -795,7 +863,8 @@ void MergeExchange::close() { auto lockedStats = stats_.wlock(); lockedStats->addRuntimeStat( Operator::kShuffleSerdeKind, - RuntimeCounter(static_cast(serde_->kind()))); + RuntimeCounter( + static_cast(VectorSerde::kindByName(serde_->kind())))); lockedStats->addRuntimeStat( Operator::kShuffleCompressionKind, RuntimeCounter(static_cast(serdeOptions_->compressionKind))); diff --git a/velox/exec/Merge.h b/velox/exec/Merge.h index f41dfebd94f..799ba348fe0 100644 --- a/velox/exec/Merge.h +++ b/velox/exec/Merge.h @@ -15,11 +15,11 @@ */ #pragma once +#include "velox/common/base/TreeOfLosers.h" #include "velox/exec/Exchange.h" #include "velox/exec/MergeSource.h" #include "velox/exec/Spill.h" #include "velox/exec/Spiller.h" -#include "velox/exec/TreeOfLosers.h" namespace facebook::velox::exec { @@ -40,7 +40,7 @@ class Merge : public SourceOperator { sortingKeys, const std::vector& sortingOrders, const std::string& planNodeId, - const std::string& operatorType, + std::string_view operatorType, const std::optional& spillConfig = std::nullopt); void initialize() override; @@ -61,12 +61,12 @@ class Merge : public SourceOperator { /// The running wall time of the merge operator reading from the streaming /// source. If spilling is enabled for local merge, this also includes the /// time that writes to the spilled source. - static inline const std::string kStreamingSourceReadWallNanos{ + static constexpr std::string_view kStreamingSourceReadWallNanos{ "streamingSourceReadWallNanos"}; /// The running wall time of the merge operator reading from the spilled /// source to produce the final output. This only applies when spilling is /// enabled for local merge. - static inline const std::string kSpilledSourceReadWallNanos{ + static constexpr std::string_view kSpilledSourceReadWallNanos{ "spilledSourceReadWallNanos"}; protected: @@ -167,6 +167,10 @@ class SourceMerger { private: void setOutputBatchSize(); + /// Creates the output vector. If a template is available from input data, + /// creates output children with matching encodings to support FlatMapVector. + RowVectorPtr createOutputVector(); + const RowTypePtr type_; const vector_size_t maxOutputBatchRows_; const uint64_t maxOutputBatchBytes_; @@ -211,6 +215,12 @@ class SourceStream final : public MergeStream { return !atEnd_; } + /// Returns the current data batch from the source. Used for encoding + /// detection to create output vectors with matching encodings. + const RowVector* data() const { + return data_.get(); + } + // Returns the estimated row size based on the vector received from the // merge source. std::optional estimateRowSize() const { @@ -292,8 +302,7 @@ class SpillMerger : public std::enable_shared_from_this { uint64_t maxOutputBatchBytes, int mergeSourceQueueSize, const common::SpillConfig* spillConfig, - const std::shared_ptr>& - spillStats, + const std::shared_ptr& spillStats, velox::memory::MemoryPool* pool); ~SpillMerger(); @@ -302,7 +311,7 @@ class SpillMerger : public std::enable_shared_from_this { RowVectorPtr getOutput( std::vector& sourceBlockingFutures, - bool& atEnd) const; + bool& atEnd); private: static std::vector> createMergeSources( @@ -321,21 +330,32 @@ class SpillMerger : public std::enable_shared_from_this { uint64_t maxOutputBatchBytes, velox::memory::MemoryPool* pool); - static void asyncReadFromSpillFileStream( + void finishSource(size_t streamIdx) const; + + void readFromSpillFileStream( const std::weak_ptr& mergeHolder, size_t streamIdx); - void readFromSpillFileStream(size_t streamIdx); - void scheduleAsyncSpillFileStreamReads(); + // Sets 'exception_' when an async reader throws. + void setError(const std::exception_ptr& exception); + + // Returns true if any async reader has thrown an exception. + bool hasError() const; + + // If any async reader has thrown an exception, rethrows it. + void checkError(); + folly::Executor* const executor_; - const std::shared_ptr> spillStats_; + const std::shared_ptr spillStats_; const std::shared_ptr pool_; std::vector> sources_; std::vector> batchStreams_; std::unique_ptr sourceMerger_; + mutable std::timed_mutex mutex_; + std::exception_ptr exception_ = nullptr; }; // LocalMerge merges its source's output into a single stream of diff --git a/velox/exec/MergeJoin.cpp b/velox/exec/MergeJoin.cpp index b5cc721b04c..65d2d180718 100644 --- a/velox/exec/MergeJoin.cpp +++ b/velox/exec/MergeJoin.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ #include "velox/exec/MergeJoin.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/OperatorUtils.h" #include "velox/exec/Task.h" #include "velox/expression/FieldReference.h" @@ -48,11 +49,20 @@ MergeJoin::MergeJoin( joinNode->outputType(), operatorId, joinNode->id(), - "MergeJoin"), - outputBatchSize_{outputBatchRows()}, + OperatorType::kMergeJoin), + preferredOutputBatchBytes_{ + driverCtx->queryConfig().preferredOutputBatchBytes()}, + preferredOutputBatchRows_{ + driverCtx->queryConfig().preferredOutputBatchRows()}, + dynamicOutputBatchSizeEnabled_{ + driverCtx->queryConfig().mergeJoinOutputBatchStartSize() != 0}, joinType_{joinNode->joinType()}, numKeys_{joinNode->leftKeys().size()}, rightNodeId_{joinNode->sources()[1]->id()}, + outputBatchSize_{ + dynamicOutputBatchSizeEnabled_ + ? driverCtx->queryConfig().mergeJoinOutputBatchStartSize() + : preferredOutputBatchRows_}, joinNode_(joinNode) { VELOX_USER_CHECK( core::MergeJoinNode::isSupported(joinType_), @@ -111,12 +121,12 @@ void MergeJoin::initialize() { if (joinNode_->isLeftJoin() || joinNode_->isAntiJoin() || joinNode_->isRightJoin() || joinNode_->isFullJoin() || isSemiFilterJoin(joinType_)) { - joinTracker_ = JoinTracker(outputBatchSize_, pool()); + joinTracker_ = JoinTracker(preferredOutputBatchRows_, pool()); } } else if (joinNode_->isAntiJoin()) { // Anti join needs to track the left side rows that have no match on the // right. - joinTracker_ = JoinTracker(outputBatchSize_, pool()); + joinTracker_ = JoinTracker(preferredOutputBatchRows_, pool()); } joinNode_.reset(); @@ -226,28 +236,37 @@ void MergeJoin::addInput(RowVectorPtr input) { // static int32_t MergeJoin::compare( - const std::vector& keys, - const RowVectorPtr& batch, - vector_size_t index, - const std::vector& otherKeys, - const RowVectorPtr& otherBatch, - vector_size_t otherIndex) { - for (auto i = 0; i < keys.size(); ++i) { + const std::vector& leftKeys, + const RowVectorPtr& leftBatch, + vector_size_t leftIndex, + const std::vector& rightKeys, + const RowVectorPtr& rightBatch, + vector_size_t rightIndex) { + for (auto i = 0; i < leftKeys.size(); ++i) { static const CompareFlags kCompareFlags = { .equalsOnly = true, .nullHandlingMode = CompareFlags::NullHandlingMode::kNullAsIndeterminate}; - const auto compare = batch->childAt(keys[i])->compare( - otherBatch->childAt(otherKeys[i]).get(), - index, - otherIndex, - kCompareFlags); + const auto compare = leftBatch->childAt(leftKeys[i]) + ->compare( + rightBatch->childAt(rightKeys[i]).get(), + leftIndex, + rightIndex, + kCompareFlags); - // Comparing null with anything will return std::nullopt. if (!compare.has_value()) { - // The SQL semantics of Presto and Spark will always return false if - // comparing a NULL value with any other value. - return -1; + // Under CompareFlags::NullHandlingMode::kNullAsIndeterminate, + // std::nullopt is returned in three cases: + // 1) Both the left key and the right key are null. + // 2) The left key is null, and the right key is not null. + // 3) The left key is not null, and the right key is null. + // + // However, the comparison result semantics differ: + // - Cases (1) and (2): return -1, meaning input_ should catch up with + // rightInput_. + // - Case (3): return 1, indicating the left key is considered greater, + // so rightInput_ should catch up with input_ in the subsequent steps. + return leftBatch->childAt(leftKeys[i])->isNullAt(leftIndex) ? -1 : 1; } else if (compare.value() != 0) { return compare.value(); } @@ -559,7 +578,7 @@ bool MergeJoin::prepareOutput( } else { inputs[i] = BaseVector::create( filterInputType_->childAt(i), - outputBatchSize_, + preferredOutputBatchRows_, operatorCtx_->pool()); } } @@ -568,7 +587,7 @@ bool MergeJoin::prepareOutput( operatorCtx_->pool(), filterInputType_, nullptr, - outputBatchSize_, + preferredOutputBatchRows_, std::move(inputs)); } return false; @@ -751,6 +770,9 @@ RowVectorPtr MergeJoin::getOutput() { for (;;) { auto output = doGetOutput(); if (output != nullptr && output->size() > 0) { + // Update the batch size based on the output before filtering. + updateOutputBatchSize(output); + if (filter_) { output = applyFilter(output); if (output != nullptr) { @@ -996,19 +1018,23 @@ RowVectorPtr MergeJoin::doGetOutput() { leftEndRow < input_->size(), std::nullopt}; - vector_size_t endRightRow = rightRowIndex_ + 1; - while (endRightRow < rightInput_->size() && - compareRight(endRightRow) == 0) { - ++endRightRow; + vector_size_t rightEndRow = rightRowIndex_ + 1; + while (rightEndRow < rightInput_->size() && + compareRight(rightEndRow) == 0) { + ++rightEndRow; } rightMatch_ = Match{ {rightInput_}, rightRowIndex_, - endRightRow, - endRightRow < rightInput_->size(), + rightEndRow, + rightEndRow < rightInput_->size(), std::nullopt}; + // Track matched rows for this key match. + matchedLeftRows_ += leftEndRow - leftMatch_->startRowIndex; + matchedRightRows_ += rightEndRow - rightMatch_->startRowIndex; + if (!leftMatch_->complete || !rightMatch_->complete) { if (!leftMatch_->complete) { // Need to continue looking for the end of match. @@ -1024,10 +1050,10 @@ RowVectorPtr MergeJoin::doGetOutput() { leftRowIndex_ = leftEndRow; if (isFullJoin(joinType_) || isRightJoin(joinType_)) { - rightRowIndex_ = endRightRow; + rightRowIndex_ = rightEndRow; } else { rightRowIndex_ = - firstNonNull(rightInput_, rightKeyChannels_, endRightRow); + firstNonNull(rightInput_, rightKeyChannels_, rightEndRow); } if (rightBatchFinished()) { @@ -1267,6 +1293,34 @@ void MergeJoin::clearRightInput() { rightInput_ = nullptr; } +void MergeJoin::updateOutputBatchSize(const RowVectorPtr& output) { + if (!dynamicOutputBatchSizeEnabled_) { + return; + } + + VELOX_CHECK_NOT_NULL(output); + + const auto outputSize = output->size(); + VELOX_CHECK_GT(outputSize, 0); + + // Calculate average row size from the current output batch. + const auto avgRowSize = output->estimateFlatSize() / outputSize; + + if (avgRowSize == 0) { + // Avoid division by zero; keep current batch size. + return; + } + + outputBatchSize_ = std::min( + static_cast(preferredOutputBatchBytes_ / avgRowSize), + static_cast(preferredOutputBatchRows_)); + + // Ensure we have at least 1 row. + if (outputBatchSize_ == 0) { + outputBatchSize_ = 1; + } +} + RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { const auto numRows = output->size(); @@ -1464,6 +1518,16 @@ bool MergeJoin::isFinished() { } void MergeJoin::close() { + // Report match ratio statistics. + { + auto lockedStats = stats_.wlock(); + lockedStats->addRuntimeStat( + std::string(MergeJoin::kMatchedLeftRows), + RuntimeCounter(matchedLeftRows_)); + lockedStats->addRuntimeStat( + std::string(MergeJoin::kMatchedRightRows), + RuntimeCounter(matchedRightRows_)); + } if (rightSource_) { rightSource_->close(); } diff --git a/velox/exec/MergeJoin.h b/velox/exec/MergeJoin.h index 414cb77bcaa..58b41a2e659 100644 --- a/velox/exec/MergeJoin.h +++ b/velox/exec/MergeJoin.h @@ -14,6 +14,8 @@ * limitations under the License. */ #pragma once +#include + #include #include "velox/exec/MergeSource.h" @@ -44,6 +46,12 @@ namespace facebook::velox::exec { /// than one right vector, it gets copied and flattened. class MergeJoin : public Operator { public: + /// Runtime stat keys for merge join. + /// Number of left rows matched in merge join. + static constexpr std::string_view kMatchedLeftRows = "matchedLeftRows"; + /// Number of right rows matched in merge join. + static constexpr std::string_view kMatchedRightRows = "matchedRightRows"; + MergeJoin( int32_t operatorId, DriverCtx* driverCtx, @@ -346,6 +354,12 @@ class MergeJoin : public Operator { return std::move(output_); } + // Updates outputBatchSize_ dynamically based on the average row size of the + // current output batch. The new batch size is computed as: + // 1. preferredOutputBatchBytes / avgRowSize + // 2. min(result from step 1, preferredOutputBatchRows) + void updateOutputBatchSize(const RowVectorPtr& output); + // Evaluates join filter on 'filterInput_' and returns 'output' that contains // a subset of rows on which the filter passed. Returns nullptr if no rows // passed the filter. @@ -557,8 +571,15 @@ class MergeJoin : public Operator { // dictionaries wrapped around the right side input. bool isRightFlattened_{false}; - // Maximum number of rows in the output batch. - const vector_size_t outputBatchSize_; + // Preferred output batch size in bytes from QueryConfig. + const uint64_t preferredOutputBatchBytes_; + + // Preferred output batch size in rows from QueryConfig. + const vector_size_t preferredOutputBatchRows_; + + // Whether dynamic output batch sizing is enabled. When disabled (default), + // outputBatchSize_ is fixed at preferredOutputBatchRows_. + const bool dynamicOutputBatchSizeEnabled_; // Type of join. const core::JoinType joinType_; @@ -568,6 +589,11 @@ class MergeJoin : public Operator { const core::PlanNodeId rightNodeId_; + // Maximum number of rows in the output batch. This is dynamically adjusted + // based on the average row size of previous output batches when dynamic + // batching is enabled. + vector_size_t outputBatchSize_; + // The cached merge join plan node used to initialize this operator after the // driver has started execution. It is reset after the initialization. std::shared_ptr joinNode_; @@ -632,5 +658,11 @@ class MergeJoin : public Operator { bool leftHasDrained_{false}; bool rightHasDrained_{false}; + + // Stats for tracking matched rows from the left side + uint64_t matchedLeftRows_{0}; + + // Stats for tracking matched rows from the right side + uint64_t matchedRightRows_{0}; }; } // namespace facebook::velox::exec diff --git a/velox/exec/MergeSource.cpp b/velox/exec/MergeSource.cpp index 2c91c39002a..34ddf2e0cf2 100644 --- a/velox/exec/MergeSource.cpp +++ b/velox/exec/MergeSource.cpp @@ -80,14 +80,23 @@ class LocalMergeSource : public MergeSource { return queue_.withWLock([&](auto& queue) { return queue.started(future); }); } - BlockingReason next(RowVectorPtr& data, ContinueFuture* future) override { + BlockingReason next(RowVectorPtr& data, ContinueFuture* future, bool& drained) + override { + drained = false; ScopedPromiseNotification notification(1); - return queue_.withWLock( - [&](auto& queue) { return queue.next(data, future, notification); }); + return queue_.withWLock([&](auto& queue) { + return queue.next(data, future, drained, notification); + }); } - BlockingReason enqueue(RowVectorPtr input, ContinueFuture* future) override { + BlockingReason + enqueue(RowVectorPtr input, ContinueFuture* future, bool drained) override { ScopedPromiseNotification notification(1); + if (drained) { + VELOX_CHECK_NULL(input); + queue_.withWLock([&](auto& queue) { queue.drain(notification); }); + return BlockingReason::kNotBlocked; + } return queue_.withWLock([&](auto& queue) { return queue.enqueue(input, future, notification); }); @@ -118,6 +127,7 @@ class LocalMergeSource : public MergeSource { BlockingReason next( RowVectorPtr& data, ContinueFuture* future, + bool& drained, ScopedPromiseNotification& notification) { VELOX_CHECK(started_); data.reset(); @@ -126,6 +136,11 @@ class LocalMergeSource : public MergeSource { if (atEnd_) { return BlockingReason::kNotBlocked; } + if (drained_) { + drained = true; + drained_ = false; + return BlockingReason::kNotBlocked; + } consumerPromises_.emplace_back("LocalMergeSourceQueue::next"); *future = consumerPromises_.back().getSemiFuture(); return BlockingReason::kWaitForProducer; @@ -140,6 +155,12 @@ class LocalMergeSource : public MergeSource { return BlockingReason::kNotBlocked; } + void drain(ScopedPromiseNotification& notification) { + VELOX_CHECK(!atEnd_); + drained_ = true; + notifyConsumers(notification); + } + BlockingReason enqueue( RowVectorPtr input, ContinueFuture* future, @@ -180,6 +201,7 @@ class LocalMergeSource : public MergeSource { bool started_{false}; bool atEnd_{false}; + bool drained_{false}; boost::circular_buffer data_; std::vector consumerPromises_; std::vector producerPromises_; @@ -198,15 +220,16 @@ class MergeExchangeSource : public MergeSource { memory::MemoryPool* pool, folly::Executor* executor) : mergeExchange_(mergeExchange), - client_(std::make_shared( - mergeExchange->taskId(), - destination, - maxQueuedBytes, - 1, - // Deliver right away to avoid blocking other sources - 0, - pool, - executor)) { + client_( + std::make_shared( + mergeExchange->taskId(), + destination, + maxQueuedBytes, + 1, + // Deliver right away to avoid blocking other sources + 0, + pool, + executor)) { client_->addRemoteTaskId(taskId); client_->noMoreRemoteTasks(); } @@ -221,7 +244,9 @@ class MergeExchangeSource : public MergeSource { VELOX_NYI(); } - BlockingReason next(RowVectorPtr& data, ContinueFuture* future) override { + BlockingReason next(RowVectorPtr& data, ContinueFuture* future, bool& drained) + override { + drained = false; data.reset(); if (atEnd_ && !currentPage_) { @@ -277,16 +302,16 @@ class MergeExchangeSource : public MergeSource { } } - private: - BlockingReason enqueue(RowVectorPtr input, ContinueFuture* future) override { + BlockingReason + enqueue(RowVectorPtr input, ContinueFuture* future, bool drained) override { VELOX_FAIL(); } + private: MergeExchange* const mergeExchange_; - std::shared_ptr client_; std::unique_ptr inputStream_; - std::unique_ptr currentPage_; + std::unique_ptr currentPage_; bool atEnd_ = false; }; } // namespace diff --git a/velox/exec/MergeSource.h b/velox/exec/MergeSource.h index ed9f3d1f37f..86ce401d079 100644 --- a/velox/exec/MergeSource.h +++ b/velox/exec/MergeSource.h @@ -37,11 +37,14 @@ class MergeSource { /// from the consumer. virtual BlockingReason started(ContinueFuture* future) = 0; - virtual BlockingReason next(RowVectorPtr& data, ContinueFuture* future) = 0; - - virtual BlockingReason enqueue( - RowVectorPtr input, - ContinueFuture* future) = 0; + virtual BlockingReason + next(RowVectorPtr& data, ContinueFuture* future, bool& drained) = 0; + + /// Called by the producer to enqueue more data, signal end of data (nullptr + /// input), or signal drain (nullptr input with drained=true) under barrier + /// processing. + virtual BlockingReason + enqueue(RowVectorPtr input, ContinueFuture* future, bool drained = false) = 0; virtual void close() = 0; diff --git a/velox/exec/MixedUnion.cpp b/velox/exec/MixedUnion.cpp new file mode 100644 index 00000000000..93918bd361b --- /dev/null +++ b/velox/exec/MixedUnion.cpp @@ -0,0 +1,304 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/MixedUnion.h" +#include "velox/exec/OperatorType.h" +#include "velox/exec/OperatorUtils.h" +#include "velox/exec/Task.h" + +namespace facebook::velox::exec { + +MixedUnion::MixedUnion( + int32_t operatorId, + DriverCtx* driverCtx, + const std::shared_ptr& unionNode) + : SourceOperator( + driverCtx, + unionNode->outputType(), + operatorId, + unionNode->id(), + OperatorType::kMixedUnion), + unionNode_(unionNode), + maxOutputBatchRows_(outputBatchRows()), + maxOutputBatchBytes_( + driverCtx->queryConfig().preferredOutputBatchBytes()) {} + +BlockingReason MixedUnion::addMergeSources(ContinueFuture* /* future */) { + if (sources_.empty()) { + // Get merge sources from the task + sources_ = operatorCtx_->task()->getLocalMergeSources( + operatorCtx_->driverCtx()->splitGroupId, planNodeId()); + + // Initialize tracking vectors + const auto numSources = sources_.size(); + pendingData_.resize(numSources); + sourcesFinished_.resize(numSources, false); + sourcesDrained_.resize(numSources, false); + } + return BlockingReason::kNotBlocked; +} + +void MixedUnion::startSources() { + if (sourcesStarted_) { + return; + } + + // Start all sources + for (auto& source : sources_) { + source->start(); + } + sourcesStarted_ = true; +} + +BlockingReason MixedUnion::isBlocked(ContinueFuture* future) { + const auto reason = addMergeSources(future); + if (reason != BlockingReason::kNotBlocked) { + return reason; + } + + if (sources_.empty()) { + finished_ = true; + return BlockingReason::kNotBlocked; + } + + startSources(); + + // Pre-fetch data from each source into pendingData_. This is done in both + // normal and drain modes so that getOutputMixed() can uniformly drain + // pendingData_ without polling sources directly. + std::vector blockingFutures; + for (size_t i = 0; i < sources_.size(); ++i) { + if (sourcesFinished_[i] || sourcesDrained_[i] || pendingData_[i]) { + continue; + } + + ContinueFuture sourceFuture; + RowVectorPtr data; + bool drained{false}; + const auto blockingReason = sources_[i]->next(data, &sourceFuture, drained); + + if (blockingReason != BlockingReason::kNotBlocked) { + blockingFutures.push_back(std::move(sourceFuture)); + } else if (data) { + pendingData_[i] = std::move(data); + } else if (drained) { + sourcesDrained_[i] = true; + } else { + sourcesFinished_[i] = true; + } + } + + if (!blockingFutures.empty()) { + // Use collectAny to continue as soon as any source has data, allowing us to + // prefetch into pendingData_ incrementally. This differs from Merge which + // waits one source at a time since it needs sorted merging. + *future = folly::collectAny(std::move(blockingFutures)).unit(); + return BlockingReason::kWaitForProducer; + } + + return BlockingReason::kNotBlocked; +} + +bool MixedUnion::isFinished() { + return finished_; +} + +RowVectorPtr MixedUnion::getOutput() { + if (finished_) { + return nullptr; + } + + return getOutputMixed(); +} + +bool MixedUnion::hasPendingDrainData() const { + // Check if there's any pending data to drain. + for (const auto& data : pendingData_) { + if (data != nullptr) { + return true; + } + } + + // Check if any source still has data to drain. + // A source is considered drained if it's finished OR has signaled drained. + for (size_t i = 0; i < sources_.size(); ++i) { + if (!sourcesFinished_[i] && !sourcesDrained_[i]) { + return true; + } + } + + return false; +} + +bool MixedUnion::startDrain() { + VELOX_CHECK(isDraining()); + + // Note: We don't call source->drain() here because the producer's + // CallbackSink has already called it when it entered drain mode. + // We just need to check if there's pending data to drain and drain any + // remaining data from sources. + + if (hasPendingDrainData()) { + return true; + } + + // No data to drain. Reset state for next barrier cycle. + std::fill(sourcesDrained_.begin(), sourcesDrained_.end(), false); + + return false; +} + +void MixedUnion::maybeFinishDrain() { + if (!isDraining()) { + return; + } + + if (hasPendingDrainData()) { + return; + } + + finishDrain(); +} + +void MixedUnion::finishDrain() { + VELOX_CHECK(isDraining()); + + // Reset drain state for next barrier. + std::fill(sourcesDrained_.begin(), sourcesDrained_.end(), false); + + Operator::finishDrain(); +} + +RowVectorPtr MixedUnion::getOutputMixed() { + // Drain pendingData_ populated by isBlocked() in both normal and drain modes. + std::vector validInputs; + for (size_t i = 0; i < pendingData_.size(); ++i) { + if (pendingData_[i]) { + validInputs.push_back(std::move(pendingData_[i])); + pendingData_[i] = nullptr; + } + } + + if (!validInputs.empty()) { + auto result = combineResults(validInputs); + return result; + } + + // No pending data. Check termination conditions. + if (isDraining()) { + maybeFinishDrain(); + return nullptr; + } + + bool allFinished = true; + bool allDrained = true; + for (size_t i = 0; i < sources_.size(); ++i) { + if (sourcesFinished_[i]) { + continue; + } + allFinished = false; + if (!sourcesDrained_[i]) { + allDrained = false; + } + } + + if (allFinished) { + finished_ = true; + return nullptr; + } + + if (allDrained && operatorCtx_->driver()->hasBarrier()) { + operatorCtx_->driver()->drainOutput(); + } + + return nullptr; +} + +RowVectorPtr MixedUnion::combineResults(std::vector& results) { + if (results.empty()) { + return nullptr; + } + + if (results.size() == 1) { + auto result = std::move(results[0]); + // Record output statistics + { + auto lockedStats = stats_.wlock(); + lockedStats->addOutputVector(result->estimateFlatSize(), result->size()); + } + return result; + } + + // Calculate total number of rows + vector_size_t totalRows = 0; + for (const auto& result : results) { + totalRows += result->size(); + } + + if (totalRows == 0) { + return nullptr; + } + + // Create combined output vector + auto combinedResult = + BaseVector::create(outputType_, totalRows, pool()); + + // Copy data from all input vectors + vector_size_t currentOffset = 0; + for (const auto& result : results) { + if (result->size() > 0) { + for (auto i = 0; i < outputType_->size(); ++i) { + // Copy column data + std::vector ranges; + ranges.push_back({0, currentOffset, result->size()}); + + combinedResult->childAt(i)->copyRanges( + result->childAt(i).get(), ranges); + } + currentOffset += result->size(); + } + } + + // Record output statistics + { + auto lockedStats = stats_.wlock(); + lockedStats->addOutputVector( + combinedResult->estimateFlatSize(), combinedResult->size()); + } + + return combinedResult; +} + +bool MixedUnion::hasDataFromAllSources() const { + for (size_t i = 0; i < pendingData_.size(); ++i) { + // If source is not finished and has no data, we don't have all sources + if (!sourcesFinished_[i] && !pendingData_[i]) { + return false; + } + } + return true; +} + +void MixedUnion::close() { + // Close all sources + for (auto& source : sources_) { + source->close(); + } + pendingData_.clear(); + Operator::close(); +} + +} // namespace facebook::velox::exec diff --git a/velox/exec/MixedUnion.h b/velox/exec/MixedUnion.h new file mode 100644 index 00000000000..bba78d343a6 --- /dev/null +++ b/velox/exec/MixedUnion.h @@ -0,0 +1,102 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/core/PlanNode.h" +#include "velox/exec/MergeSource.h" +#include "velox/exec/Operator.h" + +namespace facebook::velox::exec { + +/// Union operator that processes splits from all inputs simultaneously as a +/// SourceOperator. Unlike traditional operators that receive input via +/// addInput(), MixedUnion manages multiple MergeSource objects internally, +/// pulling data from each source and combining results in a round-robin or +/// interleaved fashion. +class MixedUnion : public SourceOperator { + public: + MixedUnion( + int32_t operatorId, + DriverCtx* driverCtx, + const std::shared_ptr& unionNode); + + RowVectorPtr getOutput() override; + + BlockingReason isBlocked(ContinueFuture* future) override; + + bool isFinished() override; + + void close() override; + + /// Invoked by the driver to start draining output on this operator. + /// Returns true if this operator has buffered output to drain (pending data + /// in pendingData_ or sources that haven't signaled drained yet). + /// Returns false if there's no data to drain and the driver should proceed + /// to the next operator. + bool startDrain() override; + + void finishDrain() override; + + private: + /// Check if there's pending data to drain (either in pendingData_ or from + /// sources that haven't signaled drained yet). Used by both startDrain() and + /// maybeFinishDrain() to avoid duplicate code. + bool hasPendingDrainData() const; + + /// Check if all sources have been drained and finish drain if so. + void maybeFinishDrain(); + + /// Get merge sources from the task + BlockingReason addMergeSources(ContinueFuture* future); + + /// Start reading from sources + void startSources(); + + /// Process inputs in mixed mode (all at once) + RowVectorPtr getOutputMixed(); + + /// Combine multiple row vectors into a single output vector + RowVectorPtr combineResults(std::vector& results); + + /// Check if we have data from all active sources for mixed mode + bool hasDataFromAllSources() const; + + const std::shared_ptr unionNode_; + + /// MergeSource objects representing each input pipeline + std::vector> sources_; + + /// Track which sources have been started + bool sourcesStarted_{false}; + + /// Store pending data from each source + std::vector pendingData_; + + /// Track which sources have finished + std::vector sourcesFinished_; + + /// True when all sources are exhausted + bool finished_{false}; + + /// Tracks which sources have been drained (during barrier processing). + std::vector sourcesDrained_; + + /// Maximum output batch size + const vector_size_t maxOutputBatchRows_; + const uint64_t maxOutputBatchBytes_; +}; + +} // namespace facebook::velox::exec diff --git a/velox/exec/NestedLoopJoinBuild.cpp b/velox/exec/NestedLoopJoinBuild.cpp index 5fbd2acc4d0..f38098fab86 100644 --- a/velox/exec/NestedLoopJoinBuild.cpp +++ b/velox/exec/NestedLoopJoinBuild.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ #include "velox/exec/NestedLoopJoinBuild.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/Task.h" namespace facebook::velox::exec { @@ -50,7 +51,7 @@ NestedLoopJoinBuild::NestedLoopJoinBuild( nullptr, operatorId, joinNode->id(), - "NestedLoopJoinBuild") {} + OperatorType::kNestedLoopJoinBuild) {} void NestedLoopJoinBuild::addInput(RowVectorPtr input) { if (input->size() > 0) { diff --git a/velox/exec/NestedLoopJoinProbe.cpp b/velox/exec/NestedLoopJoinProbe.cpp index 3c72407796c..868e8088920 100644 --- a/velox/exec/NestedLoopJoinProbe.cpp +++ b/velox/exec/NestedLoopJoinProbe.cpp @@ -15,6 +15,7 @@ */ #include "velox/exec/NestedLoopJoinProbe.h" #include "velox/exec/DriverStats.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/OperatorUtils.h" #include "velox/exec/Task.h" #include "velox/expression/FieldReference.h" @@ -56,7 +57,7 @@ NestedLoopJoinProbe::NestedLoopJoinProbe( joinNode->outputType(), operatorId, joinNode->id(), - "NestedLoopJoinProbe"), + OperatorType::kNestedLoopJoinProbe), joinType_(joinNode->joinType()), outputBatchSize_{outputBatchRows()}, joinNode_(joinNode) { diff --git a/velox/exec/OneWayStatusFlag.h b/velox/exec/OneWayStatusFlag.h index 9986f5bafe7..4eea37c2471 100644 --- a/velox/exec/OneWayStatusFlag.h +++ b/velox/exec/OneWayStatusFlag.h @@ -16,53 +16,28 @@ #pragma once -#include #include namespace facebook::velox::exec { -/// A simple one way status flag that uses a non atomic flag to avoid -/// unnecessary atomic operations. class OneWayStatusFlag { public: - bool check() const { -#if defined(__x86_64__) - folly::annotate_ignore_thread_sanitizer_guard g(__FILE__, __LINE__); - return fastStatus_ || atomicStatus_.load(); -#else - return atomicStatus_.load(std::memory_order_relaxed) || - atomicStatus_.load(); -#endif + bool check() const noexcept { + return status_.load(std::memory_order_acquire); } - void set() { -#if defined(__x86_64__) - folly::annotate_ignore_thread_sanitizer_guard g(__FILE__, __LINE__); - if (!fastStatus_) { - atomicStatus_.store(true); - fastStatus_ = true; + void set() noexcept { + if (!status_.load(std::memory_order_relaxed)) { + status_.store(true, std::memory_order_release); } -#else - if (!atomicStatus_.load(std::memory_order_relaxed)) { - atomicStatus_.store(true); - } -#endif } - /// Operator overload to convert OneWayStatusFlag to bool - operator bool() const { + explicit operator bool() const noexcept { return check(); } private: -#if defined(__x86_64__) - // This flag can only go from false to true, and is only checked at the end of - // a loop. Given that once a flag is true it can never go back to false, we - // are ok to use this in a non synchronized manner to avoid the overhead. As - // such we consciously exempt ourselves here from TSAN detection. - bool fastStatus_{false}; -#endif - std::atomic_bool atomicStatus_{false}; + std::atomic_bool status_{false}; }; } // namespace facebook::velox::exec diff --git a/velox/exec/Operator.cpp b/velox/exec/Operator.cpp index b5dcdb849c7..972fa31fb9b 100644 --- a/velox/exec/Operator.cpp +++ b/velox/exec/Operator.cpp @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #include "velox/exec/Operator.h" #include "velox/common/base/Counters.h" #include "velox/common/base/StatsReporter.h" @@ -20,7 +21,7 @@ #include "velox/common/testutil/TestValue.h" #include "velox/exec/Driver.h" #include "velox/exec/OperatorUtils.h" -#include "velox/exec/TraceUtil.h" +#include "velox/exec/Task.h" #include "velox/expression/Expr.h" using facebook::velox::common::testutil::TestValue; @@ -31,12 +32,12 @@ OperatorCtx::OperatorCtx( DriverCtx* driverCtx, const core::PlanNodeId& planNodeId, int32_t operatorId, - const std::string& operatorType) + std::string_view operatorType) : driverCtx_(driverCtx), planNodeId_(planNodeId), operatorId_(operatorId), operatorType_(operatorType), - pool_(driverCtx_->addOperatorPool(planNodeId, operatorType)) {} + pool_(driverCtx_->addOperatorPool(planNodeId, operatorType_)) {} core::ExecCtx* OperatorCtx::execCtx() const { if (!execCtx_) { @@ -72,6 +73,8 @@ OperatorCtx::createConnectorQueryCtx( task->queryCtx()->fsTokenProvider()); connectorQueryCtx->setSelectiveNimbleReaderEnabled( driverCtx_->queryConfig().selectiveNimbleReaderEnabled()); + connectorQueryCtx->setRowSizeTrackingMode( + driverCtx_->queryConfig().rowSizeTrackingMode()); return connectorQueryCtx; } @@ -80,23 +83,25 @@ Operator::Operator( RowTypePtr outputType, int32_t operatorId, std::string planNodeId, - std::string operatorType, + std::string_view operatorType, std::optional spillConfig) - : operatorCtx_(std::make_unique( - driverCtx, - planNodeId, - operatorId, - operatorType)), + : operatorCtx_( + std::make_unique( + driverCtx, + planNodeId, + operatorId, + operatorType)), outputType_(std::move(outputType)), spillConfig_(std::move(spillConfig)), dryRun_( - operatorCtx_->driverCtx()->traceConfig().has_value() && - operatorCtx_->driverCtx()->traceConfig()->dryRun), - stats_(OperatorStats{ - operatorId, - driverCtx->pipelineId, - std::move(planNodeId), - std::move(operatorType)}) {} + operatorCtx_->driverCtx()->traceCtx() && + operatorCtx_->driverCtx()->traceCtx()->dryRun()), + stats_( + OperatorStats{ + operatorId, + driverCtx->pipelineId, + std::move(planNodeId), + std::string(operatorType)}) {} void Operator::maybeSetReclaimer() { VELOX_CHECK_NULL(pool()->reclaimer()); @@ -109,53 +114,22 @@ void Operator::maybeSetReclaimer() { } void Operator::maybeSetTracer() { - const auto& traceConfig = operatorCtx_->driverCtx()->traceConfig(); - if (!traceConfig.has_value()) { - return; - } - - const auto nodeId = planNodeId(); - if (traceConfig->queryNodeId.empty() || traceConfig->queryNodeId != nodeId) { - return; - } + const auto* traceCtx = operatorCtx_->driverCtx()->traceCtx(); - auto& tracedOpMap = operatorCtx_->driverCtx()->tracedOperatorMap; - if (const auto iter = tracedOpMap.find(operatorId()); - iter != tracedOpMap.end()) { - LOG(WARNING) << "Operator " << iter->first << " with type of " - << operatorType() << ", plan node " << nodeId - << " might be the auxiliary operator of " << iter->second - << " which has the same operator id"; - return; - } - tracedOpMap.emplace(operatorId(), operatorType()); - - if (!trace::canTrace(operatorType())) { - VELOX_UNSUPPORTED("{} does not support tracing", operatorType()); - } - - const auto pipelineId = operatorCtx_->driverCtx()->pipelineId; - const auto driverId = operatorCtx_->driverCtx()->driverId; - LOG(INFO) << "Trace input for operator type: " << operatorType() - << ", operator id: " << operatorId() << ", pipeline: " << pipelineId - << ", driver: " << driverId << ", task: " << taskId(); - const auto opTraceDirPath = trace::getOpTraceDirectory( - traceConfig->queryTraceDir, planNodeId(), pipelineId, driverId); - trace::createTraceDirectory( - opTraceDirPath, - operatorCtx_->driverCtx()->queryConfig().opTraceDirectoryCreateConfig()); - - if (dynamic_cast(this) != nullptr) { - setupSplitTracer(opTraceDirPath); - } else { - setupInputTracer(opTraceDirPath); + if (traceCtx && traceCtx->shouldTrace(*this)) { + if (dynamic_cast(this) != nullptr) { + splitTracer_ = traceCtx->createSplitTracer(*this); + } else { + inputTracer_ = traceCtx->createInputTracer(*this); + } } } -void Operator::traceInput(const RowVectorPtr& input) { +bool Operator::traceInput(const RowVectorPtr& input, ContinueFuture* future) { if (FOLLY_UNLIKELY(inputTracer_ != nullptr)) { - inputTracer_->write(input); + return inputTracer_->write(input, future); } + return false; } void Operator::finishTrace() { @@ -175,19 +149,6 @@ Operator::translators() { return translators; } -void Operator::setupInputTracer(const std::string& opTraceDirPath) { - inputTracer_ = std::make_unique( - this, - opTraceDirPath, - memory::traceMemoryPool(), - operatorCtx_->driverCtx()->traceConfig()->updateAndCheckTraceLimitCB); -} - -void Operator::setupSplitTracer(const std::string& opTraceDirPath) { - splitTracer_ = - std::make_unique(this, opTraceDirPath); -} - // static std::unique_ptr Operator::fromPlanNode( DriverCtx* ctx, @@ -391,107 +352,121 @@ void Operator::recordBlockingTime(uint64_t start, BlockingReason reason) { } void Operator::recordSpillStats() { - const auto lockedSpillStats = spillStats_->wlock(); auto lockedStats = stats_.wlock(); - lockedStats->spilledInputBytes += lockedSpillStats->spilledInputBytes; - lockedStats->spilledBytes += lockedSpillStats->spilledBytes; - lockedStats->spilledRows += lockedSpillStats->spilledRows; - lockedStats->spilledPartitions += lockedSpillStats->spilledPartitions; - lockedStats->spilledFiles += lockedSpillStats->spilledFiles; - if (lockedSpillStats->spillFillTimeNanos != 0) { + lockedStats->spilledInputBytes += + spillStats_->spilledInputBytes.load(std::memory_order_relaxed); + lockedStats->spilledBytes += + spillStats_->spilledBytes.load(std::memory_order_relaxed); + lockedStats->spilledRows += + spillStats_->spilledRows.load(std::memory_order_relaxed); + lockedStats->spilledPartitions += + spillStats_->spilledPartitions.load(std::memory_order_relaxed); + lockedStats->spilledFiles += + spillStats_->spilledFiles.load(std::memory_order_relaxed); + + const auto fillTime = + spillStats_->spillFillTimeNanos.load(std::memory_order_relaxed); + if (fillTime != 0) { lockedStats->addRuntimeStat( kSpillFillTime, - RuntimeCounter{ - static_cast(lockedSpillStats->spillFillTimeNanos), - RuntimeCounter::Unit::kNanos}); + RuntimeCounter{saturateCast(fillTime), RuntimeCounter::Unit::kNanos}); } - if (lockedSpillStats->spillSortTimeNanos != 0) { + const auto sortTime = + spillStats_->spillSortTimeNanos.load(std::memory_order_relaxed); + if (sortTime != 0) { lockedStats->addRuntimeStat( kSpillSortTime, - RuntimeCounter{ - static_cast(lockedSpillStats->spillSortTimeNanos), - RuntimeCounter::Unit::kNanos}); + RuntimeCounter{saturateCast(sortTime), RuntimeCounter::Unit::kNanos}); } - if (lockedSpillStats->spillExtractVectorTimeNanos != 0) { + const auto extractTime = + spillStats_->spillExtractVectorTimeNanos.load(std::memory_order_relaxed); + if (extractTime != 0) { lockedStats->addRuntimeStat( kSpillExtractVectorTime, RuntimeCounter{ - static_cast(lockedSpillStats->spillExtractVectorTimeNanos), - RuntimeCounter::Unit::kNanos}); + saturateCast(extractTime), RuntimeCounter::Unit::kNanos}); } - if (lockedSpillStats->spillSerializationTimeNanos != 0) { + const auto serializationTime = + spillStats_->spillSerializationTimeNanos.load(std::memory_order_relaxed); + if (serializationTime != 0) { lockedStats->addRuntimeStat( kSpillSerializationTime, RuntimeCounter{ - static_cast(lockedSpillStats->spillSerializationTimeNanos), - RuntimeCounter::Unit::kNanos}); + saturateCast(serializationTime), RuntimeCounter::Unit::kNanos}); } - if (lockedSpillStats->spillFlushTimeNanos != 0) { + const auto flushTime = + spillStats_->spillFlushTimeNanos.load(std::memory_order_relaxed); + if (flushTime != 0) { lockedStats->addRuntimeStat( kSpillFlushTime, - RuntimeCounter{ - static_cast(lockedSpillStats->spillFlushTimeNanos), - RuntimeCounter::Unit::kNanos}); + RuntimeCounter{saturateCast(flushTime), RuntimeCounter::Unit::kNanos}); } - if (lockedSpillStats->spillWrites != 0) { + const auto writes = spillStats_->spillWrites.load(std::memory_order_relaxed); + if (writes != 0) { lockedStats->addRuntimeStat( - kSpillWrites, - RuntimeCounter{static_cast(lockedSpillStats->spillWrites)}); + kSpillWrites, RuntimeCounter{saturateCast(writes)}); } - if (lockedSpillStats->spillWriteTimeNanos != 0) { + const auto writeTime = + spillStats_->spillWriteTimeNanos.load(std::memory_order_relaxed); + if (writeTime != 0) { lockedStats->addRuntimeStat( kSpillWriteTime, - RuntimeCounter{ - static_cast(lockedSpillStats->spillWriteTimeNanos), - RuntimeCounter::Unit::kNanos}); + RuntimeCounter{saturateCast(writeTime), RuntimeCounter::Unit::kNanos}); } - if (lockedSpillStats->spillRuns != 0) { - lockedStats->addRuntimeStat( - kSpillRuns, - RuntimeCounter{static_cast(lockedSpillStats->spillRuns)}); - common::updateGlobalSpillRunStats(lockedSpillStats->spillRuns); + const auto runs = spillStats_->spillRuns.load(std::memory_order_relaxed); + if (runs != 0) { + lockedStats->addRuntimeStat(kSpillRuns, RuntimeCounter{saturateCast(runs)}); + updateGlobalSpillRunStats(runs); } - if (lockedSpillStats->spillMaxLevelExceededCount != 0) { + const auto maxLevelExceeded = + spillStats_->spillMaxLevelExceededCount.load(std::memory_order_relaxed); + if (maxLevelExceeded != 0) { lockedStats->addRuntimeStat( - kExceededMaxSpillLevel, - RuntimeCounter{static_cast( - lockedSpillStats->spillMaxLevelExceededCount)}); - common::updateGlobalMaxSpillLevelExceededCount( - lockedSpillStats->spillMaxLevelExceededCount); + kExceededMaxSpillLevel, RuntimeCounter{saturateCast(maxLevelExceeded)}); + updateGlobalMaxSpillLevelExceededCount(maxLevelExceeded); } - if (lockedSpillStats->spillReadBytes != 0) { + const auto readBytes = + spillStats_->spillReadBytes.load(std::memory_order_relaxed); + if (readBytes != 0) { lockedStats->addRuntimeStat( kSpillReadBytes, - RuntimeCounter{ - static_cast(lockedSpillStats->spillReadBytes), - RuntimeCounter::Unit::kBytes}); + RuntimeCounter{saturateCast(readBytes), RuntimeCounter::Unit::kBytes}); } - if (lockedSpillStats->spillReads != 0) { + const auto reads = spillStats_->spillReads.load(std::memory_order_relaxed); + if (reads != 0) { lockedStats->addRuntimeStat( - kSpillReads, - RuntimeCounter{static_cast(lockedSpillStats->spillReads)}); + kSpillReads, RuntimeCounter{saturateCast(reads)}); } - if (lockedSpillStats->spillReadTimeNanos != 0) { + const auto readTime = + spillStats_->spillReadTimeNanos.load(std::memory_order_relaxed); + if (readTime != 0) { lockedStats->addRuntimeStat( kSpillReadTime, - RuntimeCounter{ - static_cast(lockedSpillStats->spillReadTimeNanos), - RuntimeCounter::Unit::kNanos}); + RuntimeCounter{saturateCast(readTime), RuntimeCounter::Unit::kNanos}); } - if (lockedSpillStats->spillDeserializationTimeNanos != 0) { + const auto deserializationTime = + spillStats_->spillDeserializationTimeNanos.load( + std::memory_order_relaxed); + if (deserializationTime != 0) { lockedStats->addRuntimeStat( kSpillDeserializationTime, RuntimeCounter{ - static_cast( - lockedSpillStats->spillDeserializationTimeNanos), - RuntimeCounter::Unit::kNanos}); + saturateCast(deserializationTime), RuntimeCounter::Unit::kNanos}); } - lockedSpillStats->reset(); + + // Collect filesystem I/O stats for spilling. + const auto ioStatsMap = spillStats_->ioStats.stats(); + for (const auto& [statName, statValue] : ioStatsMap) { + lockedStats->addRuntimeStat( + statName, RuntimeCounter(statValue.sum, statValue.unit)); + } + + spillStats_->reset(); } std::string Operator::toString() const { @@ -553,7 +528,7 @@ std::vector calculateOutputChannels( } void OperatorStats::addRuntimeStat( - const std::string& name, + std::string_view name, const RuntimeCounter& value) { addOperatorRuntimeStats(name, value, runtimeStats); } @@ -772,9 +747,12 @@ uint64_t Operator::MemoryReclaimer::reclaim( ++stats.numNonReclaimableAttempts; RECORD_METRIC_VALUE(kMetricMemoryNonReclaimableCount); LOG(WARNING) << "Can't reclaim from memory pool " << pool->name() - << " which is under non-reclaimable section, memory usage: " - << succinctBytes(pool->usedBytes()) - << ", reservation: " << succinctBytes(pool->reservedBytes()); + << " which is under non-reclaimable section" + << ", root pool: " << pool->root()->name() + << ", used: " << succinctBytes(pool->usedBytes()) + << ", reservation: " << succinctBytes(pool->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool->root()->reservedBytes()); return 0; } diff --git a/velox/exec/Operator.h b/velox/exec/Operator.h index a2459ba52d6..53e462da850 100644 --- a/velox/exec/Operator.h +++ b/velox/exec/Operator.h @@ -16,12 +16,14 @@ #pragma once #include +#include #include "velox/core/PlanNode.h" #include "velox/core/QueryCtx.h" #include "velox/exec/Driver.h" #include "velox/exec/JoinBridge.h" #include "velox/exec/OperatorStats.h" -#include "velox/exec/OperatorTraceWriter.h" +#include "velox/exec/SpillStats.h" +#include "velox/exec/trace/TraceWriter.h" namespace facebook::velox::exec { @@ -43,7 +45,7 @@ class OperatorCtx { DriverCtx* driverCtx, const core::PlanNodeId& planNodeId, int32_t operatorId, - const std::string& operatorType = ""); + std::string_view operatorType = ""); const std::shared_ptr& task() const { return driverCtx_->task; @@ -152,40 +154,45 @@ class Operator : public BaseRuntimeStatWriter { } }; + /// The name for background cpu time metric if operator has background cpu + /// usages outside its driver thread. + static constexpr std::string_view kBackgroundCpuTimeNanos = + "backgroundCpuTimeNanos"; + /// The name of the runtime spill stats collected and reported by operators /// that support spilling. /// This indicates the spill not supported for a spillable operator when the /// spill config is enabled. This is due to the spill limitation in certain /// plan node config such as unpartition window operator. - static inline const std::string kSpillNotSupported{"spillNotSupported"}; + static constexpr std::string_view kSpillNotSupported{"spillNotSupported"}; /// The spill write stats. - static inline const std::string kSpillFillTime{"spillFillWallNanos"}; - static inline const std::string kSpillSortTime{"spillSortWallNanos"}; - static inline const std::string kSpillExtractVectorTime{ + static constexpr std::string_view kSpillFillTime{"spillFillWallNanos"}; + static constexpr std::string_view kSpillSortTime{"spillSortWallNanos"}; + static constexpr std::string_view kSpillExtractVectorTime{ "spillExtractVectorWallNanos"}; - static inline const std::string kSpillSerializationTime{ + static constexpr std::string_view kSpillSerializationTime{ "spillSerializationWallNanos"}; - static inline const std::string kSpillFlushTime{"spillFlushWallNanos"}; - static inline const std::string kSpillWrites{"spillWrites"}; - static inline const std::string kSpillWriteTime{"spillWriteWallNanos"}; - static inline const std::string kSpillRuns{"spillRuns"}; - static inline const std::string kExceededMaxSpillLevel{ + static constexpr std::string_view kSpillFlushTime{"spillFlushWallNanos"}; + static constexpr std::string_view kSpillWrites{"spillWrites"}; + static constexpr std::string_view kSpillWriteTime{"spillWriteWallNanos"}; + static constexpr std::string_view kSpillRuns{"spillRuns"}; + static constexpr std::string_view kExceededMaxSpillLevel{ "exceededMaxSpillLevel"}; /// The spill read stats. - static inline const std::string kSpillReadBytes{"spillReadBytes"}; - static inline const std::string kSpillReads{"spillReads"}; - static inline const std::string kSpillReadTime{"spillReadWallNanos"}; - static inline const std::string kSpillDeserializationTime{ + static constexpr std::string_view kSpillReadBytes{"spillReadBytes"}; + static constexpr std::string_view kSpillReads{"spillReads"}; + static constexpr std::string_view kSpillReadTime{"spillReadWallNanos"}; + static constexpr std::string_view kSpillDeserializationTime{ "spillDeserializationWallNanos"}; /// The vector serde kind used by an operator for shuffle. The recorded /// runtime stats value is the corresponding enum value. - static inline const std::string kShuffleSerdeKind{"shuffleSerdeKind"}; + static constexpr std::string_view kShuffleSerdeKind{"shuffleSerdeKind"}; /// The compression kind used by an operator for shuffle. The recorded /// runtime stats value is the corresponding enum value. - static inline const std::string kShuffleCompressionKind{ + static constexpr std::string_view kShuffleCompressionKind{ "shuffleCompressionKind"}; /// 'operatorId' is the initial index of the 'this' in the Driver's list of @@ -203,7 +210,7 @@ class Operator : public BaseRuntimeStatWriter { RowTypePtr outputType, int32_t operatorId, std::string planNodeId, - std::string operatorType, + std::string_view operatorType, std::optional spillConfig = std::nullopt); virtual ~Operator() = default; @@ -212,7 +219,7 @@ class Operator : public BaseRuntimeStatWriter { /// allocation from memory pool that can't be done under operator constructor. /// /// NOTE: the default implementation set 'initialized_' to true to ensure we - /// never call this more than once. The overload initialize() implementation + /// never call this more than once. The overriding initialize() implementation /// must call this base implementation first. virtual void initialize(); @@ -290,7 +297,7 @@ class Operator : public BaseRuntimeStatWriter { } /// Traces input batch of the operator. - virtual void traceInput(const RowVectorPtr&); + virtual bool traceInput(const RowVectorPtr& input, ContinueFuture* future); /// Finishes tracing of the operator. virtual void finishTrace(); @@ -341,7 +348,7 @@ class Operator : public BaseRuntimeStatWriter { /// Add a single runtime stat to the operator stats under the write lock. /// This member overrides BaseRuntimeStatWriter's member. - void addRuntimeStat(const std::string& name, const RuntimeCounter& value) + void addRuntimeStat(std::string_view name, const RuntimeCounter& value) override { stats_.wlock()->addRuntimeStat(name, value); } @@ -516,6 +523,12 @@ class Operator : public BaseRuntimeStatWriter { return input_ != nullptr; } + /// Returns the spill config for this operator. This method is only used for + /// test. + const common::SpillConfig* testingSpillConfig() const { + return spillConfig(); + } + protected: static std::vector>& translators(); friend class NonReclaimableSection; @@ -628,14 +641,14 @@ class Operator : public BaseRuntimeStatWriter { bool initialized_{false}; folly::Synchronized stats_; - std::shared_ptr> spillStats_ = - std::make_shared>(); + std::shared_ptr spillStats_ = + std::make_shared(); /// NOTE: only one of the two could be set for an operator for tracing . /// 'splitTracer_' is only set for table scan to record the processed split /// for now. - std::unique_ptr inputTracer_{nullptr}; - std::unique_ptr splitTracer_{nullptr}; + std::unique_ptr inputTracer_{nullptr}; + std::unique_ptr splitTracer_{nullptr}; /// Indicates if an operator is under a non-reclaimable execution section. /// This prevents the memory arbitrator from reclaiming memory from this @@ -666,12 +679,6 @@ class Operator : public BaseRuntimeStatWriter { bool shouldYield() const { return operatorCtx_->driverCtx()->driver->shouldYield(); } - - private: - // Setup 'inputTracer_' to record the processed input vectors. - void setupInputTracer(const std::string& traceDir); - // Setup 'splitTracer_' for table scan to record the processed split. - void setupSplitTracer(const std::string& traceDir); }; /// Given a row type returns indices for the specified subset of columns. @@ -698,7 +705,7 @@ class SourceOperator : public Operator { RowTypePtr outputType, int32_t operatorId, const std::string& planNodeId, - const std::string& operatorType, + std::string_view operatorType, const std::optional& spillConfig = std::nullopt) : Operator( driverCtx, diff --git a/velox/exec/OperatorStats.h b/velox/exec/OperatorStats.h index 4715ccb06d5..19ca8e93318 100644 --- a/velox/exec/OperatorStats.h +++ b/velox/exec/OperatorStats.h @@ -15,6 +15,8 @@ */ #pragma once +#include + #include "velox/common/base/RuntimeMetrics.h" #include "velox/common/memory/MemoryPool.h" #include "velox/common/time/CpuWallTimer.h" @@ -89,6 +91,21 @@ struct DynamicFilterStats { }; struct OperatorStats { + /// Runtime stat name for per-driver CPU time (actual work time, not including + /// blocked time) for this operator. The max field will contain the CPU time + /// from the longest running single driver. + static constexpr const char* kDriverCpuTime = "driverCpuTimeNanos"; + + /// Running time metrics from CpuWallTiming structures, aggregated per thread. + static constexpr std::string_view kRunningAddInputWallNanos = + "runningAddInputWallNanos"; + static constexpr std::string_view kRunningGetOutputWallNanos = + "runningGetOutputWallNanos"; + static constexpr std::string_view kRunningFinishWallNanos = + "runningFinishWallNanos"; + static constexpr std::string_view kRunningIsBlockedWallNanos = + "runningIsBlockedWallNanos"; + /// Initial ordinal position in the operator's pipeline. int32_t operatorId = 0; int32_t pipelineId = 0; @@ -219,7 +236,7 @@ struct OperatorStats { outputVectors += 1; } - void addRuntimeStat(const std::string& name, const RuntimeCounter& value); + void addRuntimeStat(std::string_view name, const RuntimeCounter& value); void add(const OperatorStats& other); void clear(); }; diff --git a/velox/exec/OperatorTraceCtx.cpp b/velox/exec/OperatorTraceCtx.cpp new file mode 100644 index 00000000000..cb1c1d27056 --- /dev/null +++ b/velox/exec/OperatorTraceCtx.cpp @@ -0,0 +1,162 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/OperatorTraceCtx.h" + +#include +#include "velox/common/base/Exceptions.h" +#include "velox/exec/Operator.h" +#include "velox/exec/OperatorTraceWriter.h" +#include "velox/exec/TaskTraceWriter.h" +#include "velox/exec/trace/TraceUtil.h" + +namespace facebook::velox::exec::trace { +namespace { + +std::string setupTraceDirectory( + const Operator& op, + const std::string& queryTraceDir) { + const auto* operatorCtx = op.operatorCtx(); + const auto pipelineId = operatorCtx->driverCtx()->pipelineId; + const auto driverId = operatorCtx->driverCtx()->driverId; + + LOG(INFO) << "Trace input for operator type: " << op.operatorType() + << ", operator id: " << op.operatorId() + << ", pipeline: " << pipelineId << ", driver: " << driverId + << ", task: " << op.taskId(); + + const auto opTraceDirPath = + getOpTraceDirectory(queryTraceDir, op.planNodeId(), pipelineId, driverId); + + createTraceDirectory( + opTraceDirPath, + operatorCtx->driverCtx()->queryConfig().opTraceDirectoryCreateConfig()); + return opTraceDirPath; +} + +} // namespace + +OperatorTraceCtx::OperatorTraceCtx( + std::string queryNodeId, + std::string queryTraceDir, + UpdateAndCheckTraceLimitCB updateAndCheckTraceLimitCB, + std::string taskRegExp, + bool dryRun) + : TraceCtx(dryRun), + queryNodeId_(std::move(queryNodeId)), + queryTraceDir_(std::move(queryTraceDir)), + taskRegExp_(std::move(taskRegExp)), + updateAndCheckTraceLimitCB_(std::move(updateAndCheckTraceLimitCB)) { + VELOX_CHECK(!queryNodeId_.empty(), "The query trace node cannot be empty"); +} + +// static +std::unique_ptr OperatorTraceCtx::maybeCreate( + core::QueryCtx& queryCtx, + const core::PlanFragment& planFragment, + const std::string& taskId) { + const auto& queryConfig = queryCtx.queryConfig(); + + VELOX_USER_CHECK( + !queryConfig.queryTraceDir().empty(), + "Query trace enabled but the trace dir is not set"); + + VELOX_USER_CHECK( + !queryConfig.queryTraceTaskRegExp().empty(), + "Query trace enabled but the trace task regexp is not set"); + + if (!RE2::FullMatch(taskId, queryConfig.queryTraceTaskRegExp())) { + return nullptr; + } + + const auto traceNodeId = queryConfig.queryTraceNodeId(); + VELOX_USER_CHECK(!traceNodeId.empty(), "Query trace node ID are not set"); + + const auto traceDir = getTaskTraceDirectory( + queryConfig.queryTraceDir(), queryCtx.queryId(), taskId); + + VELOX_USER_CHECK_NOT_NULL( + core::PlanNode::findFirstNode( + planFragment.planNode.get(), + [traceNodeId](const core::PlanNode* node) -> bool { + return node->id() == traceNodeId; + }), + "Trace plan node ID = '{}' not found from task '{}'", + traceNodeId, + taskId); + + LOG(INFO) << "Trace input for plan nodes '" << traceNodeId << "' from task '" + << taskId << "'"; + + UpdateAndCheckTraceLimitCB updateAndCheckTraceLimitCB = [&](uint64_t bytes) { + queryCtx.updateTracedBytesAndCheckLimit(bytes); + }; + + return std::make_unique( + traceNodeId, + traceDir, + std::move(updateAndCheckTraceLimitCB), + queryConfig.queryTraceTaskRegExp(), + queryConfig.queryTraceDryRun()); +} + +bool OperatorTraceCtx::shouldTrace(const Operator& op) const { + const auto& nodeId = op.planNodeId(); + + if (queryNodeId_.empty() || queryNodeId_ != nodeId) { + return false; + } + + auto& tracedOpMap = op.operatorCtx()->driverCtx()->tracedOperatorMap; + if (const auto iter = tracedOpMap.find(op.operatorId()); + iter != tracedOpMap.end()) { + LOG(WARNING) << "Operator " << iter->first << " with type of " + << op.operatorType() << ", plan node " << nodeId + << " might be the auxiliary operator of " << iter->second + << " which has the same operator id"; + return false; + } + tracedOpMap.emplace(op.operatorId(), op.operatorType()); + + if (!canTrace(op.operatorType())) { + VELOX_UNSUPPORTED("{} does not support tracing", op.operatorType()); + } + return true; +} + +std::unique_ptr OperatorTraceCtx::createInputTracer( + Operator& op) const { + return std::make_unique( + &op, + setupTraceDirectory(op, queryTraceDir_), + memory::traceMemoryPool(), + updateAndCheckTraceLimitCB_); +} + +std::unique_ptr OperatorTraceCtx::createSplitTracer( + Operator& op) const { + return std::make_unique( + &op, setupTraceDirectory(op, queryTraceDir_)); +} + +std::unique_ptr OperatorTraceCtx::createMetadataTracer() + const { + createTraceDirectory(queryTraceDir_); + return std::make_unique( + queryTraceDir_, queryNodeId_, memory::traceMemoryPool()); +} + +} // namespace facebook::velox::exec::trace diff --git a/velox/common/base/TraceConfig.h b/velox/exec/OperatorTraceCtx.h similarity index 53% rename from velox/common/base/TraceConfig.h rename to velox/exec/OperatorTraceCtx.h index 0eb84070f22..26da31a2353 100644 --- a/velox/common/base/TraceConfig.h +++ b/velox/exec/OperatorTraceCtx.h @@ -16,45 +16,55 @@ #pragma once -#include -#include -#include -#include - -namespace facebook::velox { - -#define VELOX_TRACE_LIMIT_EXCEEDED(errorMessage) \ - _VELOX_THROW( \ - ::facebook::velox::VeloxRuntimeError, \ - ::facebook::velox::error_source::kErrorSourceRuntime.c_str(), \ - ::facebook::velox::error_code::kTraceLimitExceeded.c_str(), \ - /* isRetriable */ true, \ - "{}", \ - errorMessage); +#include "velox/core/PlanFragment.h" +#include "velox/core/QueryCtx.h" +#include "velox/exec/trace/TraceCtx.h" + +namespace facebook::velox::exec::trace { + +class TraceInputWriter; +class TraceSplitWriter; /// The callback used to update and aggregate the trace bytes of a query. If the /// query trace limit is set, the callback return true if the aggregate traced /// bytes exceed the set limit otherwise return false. using UpdateAndCheckTraceLimitCB = std::function; -struct TraceConfig { - /// Target query trace node id. - std::string queryNodeId; - /// Base dir of query trace. - std::string queryTraceDir; - UpdateAndCheckTraceLimitCB updateAndCheckTraceLimitCB; - /// The trace task regexp. - std::string taskRegExp; - /// If true, we only collect operator input trace without the actual - /// execution. This is used by crash debugging so that we can collect the - /// input that triggers the crash. - bool dryRun{false}; - - TraceConfig( +class OperatorTraceCtx : public TraceCtx { + public: + OperatorTraceCtx( std::string queryNodeId, std::string queryTraceDir, UpdateAndCheckTraceLimitCB updateAndCheckTraceLimitCB, std::string taskRegExp, bool dryRun); + + static std::unique_ptr maybeCreate( + core::QueryCtx& queryCtx, + const core::PlanFragment& planFragment, + const std::string& taskId); + + bool shouldTrace(const Operator& op) const override; + + std::unique_ptr createInputTracer( + Operator& op) const override; + + std::unique_ptr createSplitTracer( + Operator& op) const override; + + std::unique_ptr createMetadataTracer() const override; + + private: + /// Target query trace node id. + const std::string queryNodeId_; + + /// Base dir of query trace. + const std::string queryTraceDir_; + + /// The trace task regexp. + const std::string taskRegExp_; + + UpdateAndCheckTraceLimitCB updateAndCheckTraceLimitCB_; }; -} // namespace facebook::velox + +} // namespace facebook::velox::exec::trace diff --git a/velox/exec/OperatorTraceReader.cpp b/velox/exec/OperatorTraceReader.cpp index 2e92b6c4071..52edeffe323 100644 --- a/velox/exec/OperatorTraceReader.cpp +++ b/velox/exec/OperatorTraceReader.cpp @@ -19,9 +19,10 @@ #include #include "velox/common/file/FileInputStream.h" #include "velox/exec/OperatorTraceReader.h" -#include "velox/exec/TraceUtil.h" +#include "velox/exec/trace/TraceUtil.h" namespace facebook::velox::exec::trace { + OperatorTraceInputReader::OperatorTraceInputReader( std::string traceDir, RowTypePtr dataType, @@ -30,7 +31,7 @@ OperatorTraceInputReader::OperatorTraceInputReader( fs_(filesystems::getFileSystem(traceDir_, nullptr)), dataType_(std::move(dataType)), pool_(pool), - serde_(getNamedVectorSerde(VectorSerde::Kind::kPresto)), + serde_(getNamedVectorSerde("Presto")), inputStream_(getInputStream()) { VELOX_CHECK_NOT_NULL(dataType_); } @@ -155,4 +156,5 @@ std::vector OperatorTraceSplitReader::deserialize( } return splits; } + } // namespace facebook::velox::exec::trace diff --git a/velox/exec/OperatorTraceReader.h b/velox/exec/OperatorTraceReader.h index 6a8e7c65650..95355d09be6 100644 --- a/velox/exec/OperatorTraceReader.h +++ b/velox/exec/OperatorTraceReader.h @@ -19,10 +19,11 @@ #include "velox/common/file/FileInputStream.h" #include "velox/common/file/FileSystems.h" #include "velox/exec/Split.h" -#include "velox/exec/Trace.h" +#include "velox/exec/trace/Trace.h" #include "velox/serializers/PrestoSerializer.h" namespace facebook::velox::exec::trace { + /// Used to read an operator trace input. class OperatorTraceInputReader { public: @@ -95,4 +96,5 @@ class OperatorTraceSplitReader { const std::shared_ptr fs_; memory::MemoryPool* const pool_; }; + } // namespace facebook::velox::exec::trace diff --git a/velox/exec/OperatorTraceScan.cpp b/velox/exec/OperatorTraceScan.cpp index 965715dcc5f..e16c79ef746 100644 --- a/velox/exec/OperatorTraceScan.cpp +++ b/velox/exec/OperatorTraceScan.cpp @@ -15,8 +15,8 @@ */ #include "velox/exec/OperatorTraceScan.h" - -#include "velox/exec/TraceUtil.h" +#include "velox/exec/OperatorType.h" +#include "velox/exec/trace/TraceUtil.h" namespace facebook::velox::exec::trace { @@ -29,7 +29,7 @@ OperatorTraceScan::OperatorTraceScan( traceScanNode->outputType(), operatorId, traceScanNode->id(), - "OperatorTraceScan") { + OperatorType::kOperatorTraceScan) { traceReader_ = std::make_unique( getOpTraceDirectory( traceScanNode->traceDir(), diff --git a/velox/exec/OperatorTraceScan.h b/velox/exec/OperatorTraceScan.h index 1542e43f2c4..a18a31345dc 100644 --- a/velox/exec/OperatorTraceScan.h +++ b/velox/exec/OperatorTraceScan.h @@ -21,6 +21,7 @@ #include "velox/exec/OperatorTraceReader.h" namespace facebook::velox::exec::trace { + /// This is a scan operator for query replay. It uses traced data from a /// specific directory path, which is /// $traceRoot/$taskId/$nodeId/$pipelineId/$driverId. diff --git a/velox/exec/OperatorTraceWriter.cpp b/velox/exec/OperatorTraceWriter.cpp index 36719246f29..17dc05af0ad 100644 --- a/velox/exec/OperatorTraceWriter.cpp +++ b/velox/exec/OperatorTraceWriter.cpp @@ -23,15 +23,17 @@ #include "velox/common/file/File.h" #include "velox/common/file/FileSystems.h" #include "velox/exec/Operator.h" -#include "velox/exec/Trace.h" -#include "velox/exec/TraceUtil.h" +#include "velox/exec/OperatorType.h" +#include "velox/exec/trace/Trace.h" +#include "velox/exec/trace/TraceUtil.h" namespace facebook::velox::exec::trace { namespace { + void recordOperatorSummary(Operator* op, folly::dynamic& obj) { obj[OperatorTraceTraits::kOpTypeKey] = op->operatorType(); const auto stats = op->stats(/*clear=*/false); - if (op->operatorType() == "TableScan") { + if (op->operatorType() == OperatorType::kTableScan) { obj[OperatorTraceTraits::kNumSplitsKey] = stats.numSplits; } obj[OperatorTraceTraits::kPeakMemoryKey] = @@ -41,6 +43,7 @@ void recordOperatorSummary(Operator* op, folly::dynamic& obj) { obj[OperatorTraceTraits::kRawInputRowsKey] = stats.rawInputPositions; obj[OperatorTraceTraits::kRawInputBytesKey] = stats.rawInputBytes; } + } // namespace OperatorTraceInputWriter::OperatorTraceInputWriter( @@ -52,15 +55,17 @@ OperatorTraceInputWriter::OperatorTraceInputWriter( traceDir_(std::move(traceDir)), fs_(filesystems::getFileSystem(traceDir_, nullptr)), pool_(pool), - serde_(getNamedVectorSerde(VectorSerde::Kind::kPresto)), + serde_(getNamedVectorSerde("Presto")), updateAndCheckTraceLimitCB_(std::move(updateAndCheckTraceLimitCB)) { traceFile_ = fs_->openFileForWrite(getOpTraceInputFilePath(traceDir_)); VELOX_CHECK_NOT_NULL(traceFile_); } -void OperatorTraceInputWriter::write(const RowVectorPtr& rows) { +bool OperatorTraceInputWriter::write( + const RowVectorPtr& rows, + ContinueFuture*) { if (FOLLY_UNLIKELY(finished_)) { - return; + return false; } if (batch_ == nullptr) { @@ -80,6 +85,7 @@ void OperatorTraceInputWriter::write(const RowVectorPtr& rows) { auto iobuf = out.getIOBuf(); updateAndCheckTraceLimitCB_(iobuf->computeChainDataLength()); traceFile_->append(std::move(iobuf)); + return false; } void OperatorTraceInputWriter::finish() { @@ -146,9 +152,9 @@ std::unique_ptr OperatorTraceSplitWriter::serialize( auto ioBuf = folly::IOBuf::create(sizeof(length) + split.size() + sizeof(crc32)); folly::io::Appender appender(ioBuf.get(), 0); - appender.writeLE(length); + appender.writeLE(length); appender.push(reinterpret_cast(split.data()), length); - appender.writeLE(crc32); + appender.writeLE(crc32); return ioBuf; } diff --git a/velox/exec/OperatorTraceWriter.h b/velox/exec/OperatorTraceWriter.h index 189577dee1f..49b8a94a999 100644 --- a/velox/exec/OperatorTraceWriter.h +++ b/velox/exec/OperatorTraceWriter.h @@ -16,10 +16,12 @@ #pragma once -#include "velox/common/base/TraceConfig.h" #include "velox/common/file/File.h" #include "velox/common/file/FileSystems.h" +#include "velox/exec/OperatorTraceCtx.h" #include "velox/exec/Split.h" +#include "velox/exec/trace/TraceCtx.h" +#include "velox/exec/trace/TraceWriter.h" #include "velox/serializers/PrestoSerializer.h" #include "velox/vector/VectorStream.h" @@ -32,7 +34,7 @@ namespace facebook::velox::exec::trace { /// Used to serialize and write the input vectors from a particular operator /// into a data file. Additionally, it creates a corresponding summary file that /// contains information such as peak memory, input rows, operator type, etc. -class OperatorTraceInputWriter { +class OperatorTraceInputWriter : public TraceInputWriter { public: /// 'traceOp' is the operator to trace. 'traceDir' specifies the trace /// directory for the operator. @@ -43,10 +45,10 @@ class OperatorTraceInputWriter { UpdateAndCheckTraceLimitCB updateAndCheckTraceLimitCB); /// Serializes rows and writes out each batch. - void write(const RowVectorPtr& rows); + bool write(const RowVectorPtr& rows, ContinueFuture* future) override; /// Closes the data file and writes out the data summary. - void finish(); + void finish() override; private: // Flushes the trace data summaries to the disk. @@ -77,16 +79,16 @@ class OperatorTraceInputWriter { /// Currently, it only works with 'HiveConnectorSplit'. In the future, it will /// be extended to handle more types of splits, such as /// 'IcebergHiveConnectorSplit'. -class OperatorTraceSplitWriter { +class OperatorTraceSplitWriter : public TraceSplitWriter { public: explicit OperatorTraceSplitWriter(Operator* traceOp, std::string traceDir); /// Serializes and writes out each split. Each serialized split is immediately /// flushed to ensure that we can still replay a traced operator even if a /// crash occurs during execution. - void write(const exec::Split& split) const; + void write(const exec::Split& split) const override; - void finish(); + void finish() override; private: static std::unique_ptr serialize(const std::string& split); diff --git a/velox/exec/OperatorType.h b/velox/exec/OperatorType.h new file mode 100644 index 00000000000..40a11cdac8a --- /dev/null +++ b/velox/exec/OperatorType.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace facebook::velox::exec { + +/// Centralized constants for operator type strings used in operator +/// constructors and comparison sites throughout the execution engine. +struct OperatorType { + static constexpr std::string_view kAggregation = "Aggregation"; + static constexpr std::string_view kArrowStream = "ArrowStream"; + static constexpr std::string_view kAssignUniqueId = "AssignUniqueId"; + static constexpr std::string_view kBlockedOperator = "BlockedOperator"; + static constexpr std::string_view kCallbackSink = "CallbackSink"; + static constexpr std::string_view kEnforceDistinct = "EnforceDistinct"; + static constexpr std::string_view kEnforceSingleRow = "EnforceSingleRow"; + static constexpr std::string_view kExchange = "Exchange"; + static constexpr std::string_view kExpand = "Expand"; + static constexpr std::string_view kFilterProject = "FilterProject"; + static constexpr std::string_view kGroupId = "GroupId"; + static constexpr std::string_view kHashBuild = "HashBuild"; + static constexpr std::string_view kHashProbe = "HashProbe"; + static constexpr std::string_view kIndexLookupJoin = "IndexLookupJoin"; + static constexpr std::string_view kLimit = "Limit"; + static constexpr std::string_view kLocalExchange = "LocalExchange"; + static constexpr std::string_view kLocalMerge = "LocalMerge"; + static constexpr std::string_view kLocalPartition = "LocalPartition"; + static constexpr std::string_view kMarkDistinct = "MarkDistinct"; + static constexpr std::string_view kMergeExchange = "MergeExchange"; + static constexpr std::string_view kMergeJoin = "MergeJoin"; + static constexpr std::string_view kMixedUnion = "MixedUnion"; + static constexpr std::string_view kNestedLoopJoinBuild = + "NestedLoopJoinBuild"; + static constexpr std::string_view kNestedLoopJoinProbe = + "NestedLoopJoinProbe"; + static constexpr std::string_view kOperatorTraceScan = "OperatorTraceScan"; + static constexpr std::string_view kOrderBy = "OrderBy"; + static constexpr std::string_view kParallelProject = "ParallelProject"; + static constexpr std::string_view kPartialAggregation = "PartialAggregation"; + static constexpr std::string_view kPartitionedOutput = "PartitionedOutput"; + static constexpr std::string_view kRowNumber = "RowNumber"; + static constexpr std::string_view kSpatialJoinBuild = "SpatialJoinBuild"; + static constexpr std::string_view kSpatialJoinProbe = "SpatialJoinProbe"; + static constexpr std::string_view kStreamingEnforceDistinct = + "StreamingEnforceDistinct"; + static constexpr std::string_view kTableScan = "TableScan"; + static constexpr std::string_view kTableWrite = "TableWrite"; + static constexpr std::string_view kTableWriteMerge = "TableWriteMerge"; + static constexpr std::string_view kTopN = "TopN"; + static constexpr std::string_view kTopNRowNumber = "TopNRowNumber"; + static constexpr std::string_view kUnnest = "Unnest"; + static constexpr std::string_view kValues = "Values"; + static constexpr std::string_view kWindow = "Window"; +}; + +} // namespace facebook::velox::exec diff --git a/velox/exec/OperatorUtils.cpp b/velox/exec/OperatorUtils.cpp index eb7258202a9..184d8edd764 100644 --- a/velox/exec/OperatorUtils.cpp +++ b/velox/exec/OperatorUtils.cpp @@ -14,8 +14,10 @@ * limitations under the License. */ #include "velox/exec/OperatorUtils.h" +#include "velox/exec/PartitionedOutput.h" #include "velox/exec/VectorHasher.h" #include "velox/expression/EvalCtx.h" +#include "velox/serializers/PrestoSerializer.h" #include "velox/vector/ConstantVector.h" #include "velox/vector/FlatVector.h" #include "velox/vector/LazyVector.h" @@ -101,7 +103,11 @@ void gatherCopy( const std::vector& sources, const std::vector& sourceIndices, column_index_t sourceChannel) { - if (target->isScalar()) { + const bool flattenSources = + std::all_of(sources.begin(), sources.end(), [](const auto& source) { + return source->isFlatEncoding(); + }); + if (target->isScalar() && flattenSources) { VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( scalarGatherCopy, target->type()->kind(), @@ -124,7 +130,14 @@ bool shouldAggregateRuntimeMetric(const std::string& name) { "dataSourceAddSplitWallNanos", "dataSourceLazyWallNanos", "queuedWallNanos", - "flushTimes"}; + "flushTimes", + "driverCpuTimeNanos", + "ioWaitWallNanos", + "storageReadWallNanos", + "ssdCacheReadWallNanos", + "cacheWaitWallNanos", + "coalescedSsdLoadWallNanos", + "coalescedStorageLoadWallNanos"}; if (metricNames.contains(name)) { return true; } @@ -457,21 +470,21 @@ std::string makeOperatorSpillPath( } void setOperatorRuntimeStats( - const std::string& name, + std::string_view name, const RuntimeCounter& value, std::unordered_map& stats) { - stats[name] = RuntimeMetric(value.unit); - stats[name].addValue(value.value); + auto [it, _] = + stats.insert_or_assign(std::string(name), RuntimeMetric(value.unit)); + it->second.addValue(value.value); } void addOperatorRuntimeStats( - const std::string& name, + std::string_view name, const RuntimeCounter& value, std::unordered_map& stats) { - auto statIt = stats.find(name); - if (UNLIKELY(statIt == stats.end())) { - statIt = stats.insert(std::pair(name, RuntimeMetric(value.unit))).first; - } else { + auto [statIt, inserted] = + stats.emplace(std::string(name), RuntimeMetric(value.unit)); + if (!inserted) { VELOX_CHECK_EQ(statIt->second.unit, value.unit); } statIt->second.addValue(value.value); @@ -569,4 +582,21 @@ std::unique_ptr BlockedOperatorFactory::toOperator( } return nullptr; } + +std::unique_ptr getVectorSerdeOptions( + common::CompressionKind compressionKind, + const std::string& kind, + std::optional minCompressionRatio, + int32_t minCompressionPageSizeBytes) { + std::unique_ptr options = kind == "Presto" + ? std::make_unique() + : std::make_unique(); + options->compressionKind = compressionKind; + if (minCompressionRatio.has_value()) { + options->minCompressionRatio = minCompressionRatio.value(); + } + options->minCompressionPageSizeBytes = minCompressionPageSizeBytes; + return options; +} + } // namespace facebook::velox::exec diff --git a/velox/exec/OperatorUtils.h b/velox/exec/OperatorUtils.h index ac2698e437d..f80d48bc227 100644 --- a/velox/exec/OperatorUtils.h +++ b/velox/exec/OperatorUtils.h @@ -15,8 +15,12 @@ */ #pragma once +#include +#include "velox/core/QueryConfig.h" #include "velox/exec/Operator.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/Spiller.h" +#include "velox/vector/VectorStream.h" namespace facebook::velox::exec { @@ -147,13 +151,13 @@ std::string makeOperatorSpillPath( /// Set a named runtime metric in operator 'stats'. void setOperatorRuntimeStats( - const std::string& name, + std::string_view name, const RuntimeCounter& value, std::unordered_map& stats); /// Add a named runtime metric to operator 'stats'. void addOperatorRuntimeStats( - const std::string& name, + std::string_view name, const RuntimeCounter& value, std::unordered_map& stats); @@ -229,7 +233,12 @@ class BlockedOperator : public Operator { int32_t id, core::PlanNodePtr node, BlockedOperatorCb&& blockedCb) - : Operator(ctx, node->outputType(), id, node->id(), "BlockedOperator"), + : Operator( + ctx, + node->outputType(), + id, + node->id(), + OperatorType::kBlockedOperator), blockedCb_(std::move(blockedCb)) {} BlockingReason isBlocked(ContinueFuture* future) override { @@ -307,4 +316,13 @@ class BlockedOperatorFactory : public Operator::PlanNodeTranslator { private: BlockedOperatorCb blockedCb_{nullptr}; }; + +/// Creates VectorSerde::Options for the given VectorSerde kind with compression +/// settings. Optionally configures minimum compression ratio. +std::unique_ptr getVectorSerdeOptions( + common::CompressionKind compressionKind, + const std::string& kind, + std::optional minCompressionRatio = std::nullopt, + int32_t minCompressionPageSizeBytes = 0); + } // namespace facebook::velox::exec diff --git a/velox/exec/OrderBy.cpp b/velox/exec/OrderBy.cpp index dc6b5d54809..85957281898 100644 --- a/velox/exec/OrderBy.cpp +++ b/velox/exec/OrderBy.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ #include "velox/exec/OrderBy.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/OperatorUtils.h" #include "velox/exec/Task.h" #include "velox/vector/FlatVector.h" @@ -39,9 +40,9 @@ OrderBy::OrderBy( orderByNode->outputType(), operatorId, orderByNode->id(), - "OrderBy", + OperatorType::kOrderBy, orderByNode->canSpill(driverCtx->queryConfig()) - ? driverCtx->makeSpillConfig(operatorId) + ? driverCtx->makeSpillConfig(operatorId, OperatorType::kOrderBy) : std::nullopt) { maxOutputRows_ = outputBatchRows(std::nullopt); VELOX_CHECK(pool()->trackUsage()); diff --git a/velox/exec/OutputBuffer.cpp b/velox/exec/OutputBuffer.cpp index b2396d745ce..e3045d1f64e 100644 --- a/velox/exec/OutputBuffer.cpp +++ b/velox/exec/OutputBuffer.cpp @@ -30,10 +30,10 @@ void ArbitraryBuffer::noMoreData() { pages_.push_back(nullptr); } -void ArbitraryBuffer::enqueue(std::unique_ptr page) { +void ArbitraryBuffer::enqueue(std::unique_ptr page) { VELOX_CHECK_NOT_NULL(page, "Unexpected null page"); VELOX_CHECK(!hasNoMoreData(), "Arbitrary buffer has set no more data marker"); - pages_.push_back(std::shared_ptr(page.release())); + pages_.push_back(std::shared_ptr(page.release())); } void ArbitraryBuffer::getAvailablePageSizes(std::vector& out) const { @@ -45,7 +45,7 @@ void ArbitraryBuffer::getAvailablePageSizes(std::vector& out) const { } } -std::vector> ArbitraryBuffer::getPages( +std::vector> ArbitraryBuffer::getPages( uint64_t maxBytes) { if (maxBytes == 0 && !pages_.empty() && pages_.front() == nullptr) { // Always give out an end marker when this buffer is finished and fully @@ -57,7 +57,7 @@ std::vector> ArbitraryBuffer::getPages( VELOX_CHECK_EQ(pages_.size(), 1); return {nullptr}; } - std::vector> pages; + std::vector> pages; uint64_t bytesRemoved{0}; while (bytesRemoved < maxBytes && !pages_.empty()) { if (pages_.front() == nullptr) { @@ -81,7 +81,7 @@ std::string ArbitraryBuffer::toString() const { hasNoMoreData()); } -void DestinationBuffer::Stats::recordEnqueue(const SerializedPage& data) { +void DestinationBuffer::Stats::recordEnqueue(const SerializedPageBase& data) { const auto numRows = data.numRows(); VELOX_CHECK(numRows.has_value(), "SerializedPage's numRows must be valid"); bytesBuffered += data.size(); @@ -89,7 +89,8 @@ void DestinationBuffer::Stats::recordEnqueue(const SerializedPage& data) { ++pagesBuffered; } -void DestinationBuffer::Stats::recordAcknowledge(const SerializedPage& data) { +void DestinationBuffer::Stats::recordAcknowledge( + const SerializedPageBase& data) { const auto numRows = data.numRows(); VELOX_CHECK(numRows.has_value(), "SerializedPage's numRows must be valid"); const int64_t size = data.size(); @@ -104,7 +105,7 @@ void DestinationBuffer::Stats::recordAcknowledge(const SerializedPage& data) { ++pagesSent; } -void DestinationBuffer::Stats::recordDelete(const SerializedPage& data) { +void DestinationBuffer::Stats::recordDelete(const SerializedPageBase& data) { recordAcknowledge(data); } @@ -185,7 +186,7 @@ DestinationBuffer::Data DestinationBuffer::getData( return {std::move(data), std::move(remainingBytes), true}; } -void DestinationBuffer::enqueue(std::shared_ptr data) { +void DestinationBuffer::enqueue(std::shared_ptr data) { // Drop duplicate end markers. if (data == nullptr && !data_.empty() && data_.back() == nullptr) { return; @@ -245,7 +246,7 @@ void DestinationBuffer::loadData(ArbitraryBuffer* buffer, uint64_t maxBytes) { } } -std::vector> DestinationBuffer::acknowledge( +std::vector> DestinationBuffer::acknowledge( int64_t sequence, bool fromGetData) { const int64_t numDeleted = sequence - sequence_; @@ -268,7 +269,7 @@ std::vector> DestinationBuffer::acknowledge( VELOX_CHECK_LE( numDeleted, data_.size(), "Ack received for a not yet produced item"); - std::vector> freed; + std::vector> freed; for (auto i = 0; i < numDeleted; ++i) { if (data_[i] == nullptr) { VELOX_CHECK_EQ(i, data_.size() - 1, "null marker found in the middle"); @@ -277,14 +278,14 @@ std::vector> DestinationBuffer::acknowledge( stats_.recordAcknowledge(*data_[i]); freed.push_back(std::move(data_[i])); } - data_.erase(data_.begin(), data_.begin() + numDeleted); + data_.erase(data_.cbegin(), data_.cbegin() + numDeleted); sequence_ += numDeleted; return freed; } -std::vector> +std::vector> DestinationBuffer::deleteResults() { - std::vector> freed; + std::vector> freed; for (auto i = 0; i < data_.size(); ++i) { if (data_[i] == nullptr) { VELOX_CHECK_EQ(i, data_.size() - 1, "null marker found in the middle"); @@ -314,7 +315,7 @@ namespace { // that we do the expensive free outside and only then continue the // producers which will allocate more memory. void releaseAfterAcknowledge( - std::vector>& freed, + std::vector>& freed, std::vector& promises) { freed.clear(); for (auto& promise : promises) { @@ -445,7 +446,7 @@ void OutputBuffer::updateTotalBufferedBytesMsLocked() { bool OutputBuffer::enqueue( int destination, - std::unique_ptr data, + std::unique_ptr data, ContinueFuture* future) { VELOX_CHECK_NOT_NULL(data); VELOX_CHECK( @@ -492,13 +493,13 @@ bool OutputBuffer::enqueue( } void OutputBuffer::enqueueBroadcastOutputLocked( - std::unique_ptr data, + std::unique_ptr data, std::vector& dataAvailableCbs) { VELOX_DCHECK(isBroadcast()); VELOX_CHECK_NULL(arbitraryBuffer_); VELOX_DCHECK(dataAvailableCbs.empty()); - std::shared_ptr sharedData(data.release()); + std::shared_ptr sharedData(data.release()); for (auto& buffer : buffers_) { if (buffer != nullptr) { buffer->enqueue(sharedData); @@ -514,7 +515,7 @@ void OutputBuffer::enqueueBroadcastOutputLocked( } void OutputBuffer::enqueueArbitraryOutputLocked( - std::unique_ptr data, + std::unique_ptr data, std::vector& dataAvailableCbs) { VELOX_DCHECK(isArbitrary()); VELOX_DCHECK_NOT_NULL(arbitraryBuffer_); @@ -541,7 +542,7 @@ void OutputBuffer::enqueueArbitraryOutputLocked( void OutputBuffer::enqueuePartitionedOutputLocked( int destination, - std::unique_ptr data, + std::unique_ptr data, std::vector& dataAvailableCbs) { VELOX_DCHECK(isPartitioned()); VELOX_CHECK_NULL(arbitraryBuffer_); @@ -631,7 +632,7 @@ bool OutputBuffer::isFinishedLocked() { } void OutputBuffer::acknowledge(int destination, int64_t sequence) { - std::vector> freed; + std::vector> freed; std::vector promises; { std::lock_guard l(mutex_); @@ -649,7 +650,7 @@ void OutputBuffer::acknowledge(int destination, int64_t sequence) { } void OutputBuffer::updateAfterAcknowledgeLocked( - const std::vector>& freed, + const std::vector>& freed, std::vector& promises) { uint64_t freedBytes{0}; int freedPages{0}; @@ -673,7 +674,7 @@ void OutputBuffer::updateAfterAcknowledgeLocked( } bool OutputBuffer::deleteResults(int destination) { - std::vector> freed; + std::vector> freed; std::vector promises; bool isFinished; DataAvailable dataAvailable; @@ -717,7 +718,7 @@ void OutputBuffer::getData( DataAvailableCallback notify, DataConsumerActiveCheckCallback activeCheck) { DestinationBuffer::Data data; - std::vector> freed; + std::vector> freed; std::vector promises; { std::lock_guard l(mutex_); diff --git a/velox/exec/OutputBuffer.h b/velox/exec/OutputBuffer.h index 640c1bfcd8e..a21d9928e5a 100644 --- a/velox/exec/OutputBuffer.h +++ b/velox/exec/OutputBuffer.h @@ -15,6 +15,7 @@ */ #pragma once +#include "velox/common/base/Portability.h" #include "velox/core/PlanNode.h" #include "velox/exec/ExchangeQueue.h" @@ -75,11 +76,11 @@ class ArbitraryBuffer { /// appends a null page at the end of 'pages_' as end marker. void noMoreData(); - void enqueue(std::unique_ptr page); + void enqueue(std::unique_ptr page); /// Returns a number of pages with total bytes no less than 'maxBytes' if /// there are sufficient buffered pages. - std::vector> getPages(uint64_t maxBytes); + std::vector> getPages(uint64_t maxBytes); /// Append the available page sizes to `out'. void getAvailablePageSizes(std::vector& out) const; @@ -87,7 +88,7 @@ class ArbitraryBuffer { std::string toString() const; private: - std::deque> pages_; + std::deque> pages_; }; class DestinationBuffer { @@ -98,11 +99,11 @@ class DestinationBuffer { /// 2. Sent: the data is removed from the buffer after it is acked or /// deleted. struct Stats { - void recordEnqueue(const SerializedPage& data); + void recordEnqueue(const SerializedPageBase& data); - void recordAcknowledge(const SerializedPage& data); + void recordAcknowledge(const SerializedPageBase& data); - void recordDelete(const SerializedPage& data); + void recordDelete(const SerializedPageBase& data); bool finished{false}; @@ -117,7 +118,7 @@ class DestinationBuffer { int64_t pagesSent{0}; }; - void enqueue(std::shared_ptr data); + void enqueue(std::shared_ptr data); /// Invoked to load data with up to 'notifyMaxBytes_' bytes from arbitrary /// 'buffer' if there is pending fetch from this destination in which case @@ -165,12 +166,12 @@ class DestinationBuffer { /// do not give a warning for the case where no data is removed, otherwise we /// expect that data does get freed. We cannot assert that data gets deleted /// because acknowledge messages can arrive out of order. - std::vector> acknowledge( + std::vector> acknowledge( int64_t sequence, bool fromGetData); /// Removes all remaining data from the queue and returns the removed data. - std::vector> deleteResults(); + std::vector> deleteResults(); /// Returns and clears the notify callback, if any, along with arguments for /// the callback. @@ -187,7 +188,7 @@ class DestinationBuffer { private: void clearNotify(); - std::vector> data_; + std::vector> data_; // The sequence number of the first in 'data_'. int64_t sequence_ = 0; DataAvailableCallback notify_{nullptr}; @@ -280,7 +281,7 @@ class OutputBuffer { bool enqueue( int destination, - std::unique_ptr data, + std::unique_ptr data, ContinueFuture* future); void noMoreData(); @@ -345,7 +346,7 @@ class OutputBuffer { // Updates buffered size and returns possibly continuable producer promises // in 'promises'. void updateAfterAcknowledgeLocked( - const std::vector>& freed, + const std::vector>& freed, std::vector& promises); /// Given an updated total number of broadcast buffers, add any missing ones @@ -353,16 +354,16 @@ class OutputBuffer { void addOutputBuffersLocked(int numBuffers); void enqueueBroadcastOutputLocked( - std::unique_ptr data, + std::unique_ptr data, std::vector& dataAvailableCbs); void enqueueArbitraryOutputLocked( - std::unique_ptr data, + std::unique_ptr data, std::vector& dataAvailableCbs); void enqueuePartitionedOutputLocked( int destination, - std::unique_ptr data, + std::unique_ptr data, std::vector& dataAvailableCbs); std::string toStringLocked() const; @@ -401,11 +402,11 @@ class OutputBuffer { // While noMoreBuffers_ is false, stores the enqueued data to // broadcast to destinations that have not yet been initialized. Cleared // after receiving no-more-broadcast-buffers signal. - std::vector> dataToBroadcast_; + std::vector> dataToBroadcast_; std::mutex mutex_; // Actual data size in 'buffers_'. - int64_t bufferedBytes_{0}; + tsan_atomic bufferedBytes_{0}; // The number of buffered pages which corresponds to 'bufferedBytes_'. int64_t bufferedPages_{0}; // The total number of output bytes, rows and pages. @@ -425,7 +426,7 @@ class OutputBuffer { uint32_t numFinished_{0}; // When this reaches buffers_.size(), 'this' can be freed. int numFinalAcknowledges_ = 0; - bool atEnd_ = false; + tsan_atomic atEnd_{false}; // Time since last change in bufferedBytes_. Used to compute total time data // is buffered. Ignored if bufferedBytes_ is zero. diff --git a/velox/exec/OutputBufferManager.cpp b/velox/exec/OutputBufferManager.cpp index 4773911a530..f8183218fb4 100644 --- a/velox/exec/OutputBufferManager.cpp +++ b/velox/exec/OutputBufferManager.cpp @@ -57,7 +57,7 @@ uint64_t OutputBufferManager::numBuffers() const { bool OutputBufferManager::enqueue( const std::string& taskId, int destination, - std::unique_ptr data, + std::unique_ptr data, ContinueFuture* future) { return getBuffer(taskId)->enqueue(destination, std::move(data), future); } diff --git a/velox/exec/OutputBufferManager.h b/velox/exec/OutputBufferManager.h index 8affa6f9051..ef9487ee87b 100644 --- a/velox/exec/OutputBufferManager.h +++ b/velox/exec/OutputBufferManager.h @@ -53,7 +53,7 @@ class OutputBufferManager { bool enqueue( const std::string& taskId, int destination, - std::unique_ptr data, + std::unique_ptr data, ContinueFuture* future); void noMoreData(const std::string& taskId); diff --git a/velox/exec/ParallelProject.cpp b/velox/exec/ParallelProject.cpp index 8d3cf41e986..e772838e2bf 100644 --- a/velox/exec/ParallelProject.cpp +++ b/velox/exec/ParallelProject.cpp @@ -17,6 +17,7 @@ #include "velox/exec/ParallelProject.h" #include "velox/common/base/AsyncSource.h" #include "velox/exec/Operator.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/Task.h" namespace facebook::velox::exec { @@ -30,7 +31,7 @@ ParallelProject::ParallelProject( node->outputType(), operatorId, node->id(), - "ParallelProject"), + OperatorType::kParallelProject), node_(node) {} namespace { @@ -139,8 +140,9 @@ RowVectorPtr ParallelProject::getOutput() { std::vector results(outputType_->size()); for (auto i = 0; i < work_.size(); ++i) { - pending.push_back(std::make_shared>( - [i, &results, this]() { return doWork(i, results); })); + pending.push_back( + std::make_shared>( + [i, &results, this]() { return doWork(i, results); })); auto item = pending.back(); operatorCtx_->task()->queryCtx()->executor()->add( [item]() { item->prepare(); }); diff --git a/velox/exec/PartitionStreamingWindowBuild.cpp b/velox/exec/PartitionStreamingWindowBuild.cpp index 331d3e6f6e1..a3ce88cbe7b 100644 --- a/velox/exec/PartitionStreamingWindowBuild.cpp +++ b/velox/exec/PartitionStreamingWindowBuild.cpp @@ -78,7 +78,7 @@ PartitionStreamingWindowBuild::nextPartition() { data_->eraseRows( folly::Range(sortedRows_.data(), numPreviousPartitionRows)); sortedRows_.erase( - sortedRows_.begin(), sortedRows_.begin() + numPreviousPartitionRows); + sortedRows_.cbegin(), sortedRows_.cbegin() + numPreviousPartitionRows); sortedRows_.shrink_to_fit(); for (int i = currentPartition_; i < partitionStartRows_.size(); ++i) { partitionStartRows_[i] = diff --git a/velox/exec/PartitionStreamingWindowBuild.h b/velox/exec/PartitionStreamingWindowBuild.h index bb5cb352d24..2c9cc1e14c9 100644 --- a/velox/exec/PartitionStreamingWindowBuild.h +++ b/velox/exec/PartitionStreamingWindowBuild.h @@ -38,7 +38,7 @@ class PartitionStreamingWindowBuild : public WindowBuild { VELOX_UNREACHABLE(); } - std::optional spilledStats() const override { + std::optional spilledStats() const override { return std::nullopt; } diff --git a/velox/exec/PartitionedOutput.cpp b/velox/exec/PartitionedOutput.cpp index 974cfbb3ec8..ea221045b87 100644 --- a/velox/exec/PartitionedOutput.cpp +++ b/velox/exec/PartitionedOutput.cpp @@ -15,24 +15,12 @@ */ #include "velox/exec/PartitionedOutput.h" +#include "velox/exec/OperatorType.h" +#include "velox/exec/OperatorUtils.h" #include "velox/exec/OutputBufferManager.h" #include "velox/exec/Task.h" namespace facebook::velox::exec { -namespace { -std::unique_ptr getVectorSerdeOptions( - const core::QueryConfig& queryConfig, - VectorSerde::Kind kind) { - std::unique_ptr options = - kind == VectorSerde::Kind::kPresto - ? std::make_unique() - : std::make_unique(); - options->compressionKind = - common::stringToCompressionKind(queryConfig.shuffleCompressionKind()); - options->minCompressionRatio = PartitionedOutput::minCompressionRatio(); - return options; -} -} // namespace namespace detail { Destination::Destination( @@ -88,21 +76,17 @@ BlockingReason Destination::advance( } // Serialize - if (current_ == nullptr) { - current_ = std::make_unique(pool_, serde_); - const auto rowType = asRowType(output->type()); - current_->createStreamTree(rowType, rowsInCurrent_, serdeOptions_); - } + createVectorStreamGroup(output); const auto rows = folly::Range(&rows_[firstRow], rowIdx_ - firstRow); - if (serde_->kind() == VectorSerde::Kind::kCompactRow) { + if (serde_->kind() == "CompactRow") { VELOX_CHECK_NOT_NULL(outputCompactRow); current_->append(*outputCompactRow, rows, sizes); - } else if (serde_->kind() == VectorSerde::Kind::kUnsafeRow) { + } else if (serde_->kind() == "UnsafeRow") { VELOX_CHECK_NOT_NULL(outputUnsafeRow); current_->append(*outputUnsafeRow, rows, sizes); } else { - VELOX_CHECK_EQ(serde_->kind(), VectorSerde::Kind::kPresto); + VELOX_CHECK_EQ(serde_->kind(), "Presto"); current_->append(output, rows, scratch); } @@ -116,6 +100,26 @@ BlockingReason Destination::advance( return BlockingReason::kNotBlocked; } +void Destination::createVectorStreamGroup(const RowVectorPtr& output) { + if (current_ == nullptr || needsStreamTreeRecreation_) { + if (current_ == nullptr) { + current_ = std::make_unique(pool_, serde_); + } + const auto rowType = asRowType(output->type()); + current_->createStreamTree(rowType, rowsInCurrent_, serdeOptions_); + needsStreamTreeRecreation_ = false; + } +} + +void Destination::clearVectorStreamGroup() { + current_->clear(); + // Signal that createStreamTree() must be called before the next append + // to properly reinitialize the serializer with a fresh stream tree. + // This fixes a crash where the serializer was in an invalid state after + // clear() due to stale references to freed StreamArena memory. + needsStreamTreeRecreation_ = true; +} + BlockingReason Destination::flush( OutputBufferManager& bufferManager, const std::function& bufferReleaseFn, @@ -134,7 +138,20 @@ BlockingReason Destination::flush( const int64_t flushedRows = rowsInCurrent_; current_->flush(&stream); - current_->clear(); + + // Accumulate stats from the current serializer BEFORE clear() to preserve + // compression metrics across flushes. + const auto currentStats = current_->runtimeStats(); + for (const auto& [name, counter] : currentStats) { + auto it = accumulatedStats_.find(name); + if (it != accumulatedStats_.end()) { + it->second.value += counter.value; + } else { + accumulatedStats_.emplace(name, counter); + } + } + + clearVectorStreamGroup(); const int64_t flushedBytes = stream.tellp(); @@ -145,7 +162,7 @@ BlockingReason Destination::flush( bool blocked = bufferManager.enqueue( taskId_, destination_, - std::make_unique( + std::make_unique( stream.getIOBuf(bufferReleaseFn), nullptr, flushedRows), future); @@ -157,11 +174,18 @@ BlockingReason Destination::flush( void Destination::updateStats(Operator* op) { VELOX_CHECK(finished_); + auto lockedStats = op->stats().wlock(); + + // First add accumulated stats from previous serialization cycles. + for (const auto& [name, counter] : accumulatedStats_) { + lockedStats->addRuntimeStat(name, counter); + } + + // Then add stats from the current serializer (if any). if (current_) { const auto serializerStats = current_->runtimeStats(); - auto lockedStats = op->stats().wlock(); - for (auto& pair : serializerStats) { - lockedStats->addRuntimeStat(pair.first, pair.second); + for (const auto& [name, counter] : serializerStats) { + lockedStats->addRuntimeStat(name, counter); } } } @@ -178,7 +202,7 @@ PartitionedOutput::PartitionedOutput( planNode->outputType(), operatorId, planNode->id(), - "PartitionedOutput"), + OperatorType::kPartitionedOutput), keyChannels_(toChannels(planNode->inputType(), planNode->keys())), numDestinations_(planNode->numPartitions()), replicateNullsAndAny_(planNode->isReplicateNullsAndAny()), @@ -200,11 +224,19 @@ PartitionedOutput::PartitionedOutput( maxBufferedBytes_(ctx->task->queryCtx() ->queryConfig() .maxPartitionedOutputBufferSize()), - eagerFlush_(eagerFlush), + eagerFlush_( + eagerFlush || + ctx->task->queryCtx()->queryConfig().partitionedOutputEagerFlush()), serde_(getNamedVectorSerde(planNode->serdeKind())), serdeOptions_(getVectorSerdeOptions( - operatorCtx_->driverCtx()->queryConfig(), - planNode->serdeKind())) { + common::stringToCompressionKind(operatorCtx_->driverCtx() + ->queryConfig() + .shuffleCompressionKind()), + planNode->serdeKind(), + PartitionedOutput::minCompressionRatio(), + operatorCtx_->driverCtx() + ->queryConfig() + .minShuffleCompressionPageSizeBytes())) { if (!planNode->isPartitioned()) { VELOX_USER_CHECK_EQ(numDestinations_, 1); } @@ -245,9 +277,9 @@ void PartitionedOutput::initializeInput(RowVectorPtr input) { output_->childAt(i)->loadedVector(); } - if (serde_->kind() == VectorSerde::Kind::kCompactRow) { + if (serde_->kind() == "CompactRow") { outputCompactRow_ = std::make_unique(output_); - } else if (serde_->kind() == VectorSerde::Kind::kUnsafeRow) { + } else if (serde_->kind() == "UnsafeRow") { outputUnsafeRow_ = std::make_unique(output_); } } @@ -256,17 +288,18 @@ void PartitionedOutput::initializeDestinations() { if (destinations_.empty()) { auto taskId = operatorCtx_->taskId(); for (int i = 0; i < numDestinations_; ++i) { - destinations_.push_back(std::make_unique( - taskId, - i, - serde_, - serdeOptions_.get(), - pool(), - eagerFlush_, - [&](uint64_t bytes, uint64_t rows) { - auto lockedStats = stats_.wlock(); - lockedStats->addOutputVector(bytes, rows); - })); + destinations_.push_back( + std::make_unique( + taskId, + i, + serde_, + serdeOptions_.get(), + pool(), + eagerFlush_, + [&](uint64_t bytes, uint64_t rows) { + auto lockedStats = stats_.wlock(); + lockedStats->addOutputVector(bytes, rows); + })); } } } @@ -289,16 +322,16 @@ void PartitionedOutput::estimateRowSizes() { raw_vector storage(pool()); const auto numbers = iota(numInput, storage); const auto rows = folly::Range(numbers, numInput); - if (serde_->kind() == VectorSerde::Kind::kCompactRow) { + if (serde_->kind() == "CompactRow") { VELOX_CHECK_NOT_NULL(outputCompactRow_); serde_->estimateSerializedSize( outputCompactRow_.get(), rows, sizePointers_.data()); - } else if (serde_->kind() == VectorSerde::Kind::kUnsafeRow) { + } else if (serde_->kind() == "UnsafeRow") { VELOX_CHECK_NOT_NULL(outputUnsafeRow_); serde_->estimateSerializedSize( outputUnsafeRow_.get(), rows, sizePointers_.data()); } else { - VELOX_CHECK_EQ(serde_->kind(), VectorSerde::Kind::kPresto); + VELOX_CHECK_EQ(serde_->kind(), "Presto"); serde_->estimateSerializedSize( output_.get(), rows, sizePointers_.data(), scratch_); } @@ -481,7 +514,8 @@ void PartitionedOutput::close() { auto lockedStats = stats_.wlock(); lockedStats->addRuntimeStat( Operator::kShuffleSerdeKind, - RuntimeCounter(static_cast(serde_->kind()))); + RuntimeCounter( + static_cast(VectorSerde::kindByName(serde_->kind())))); lockedStats->addRuntimeStat( Operator::kShuffleCompressionKind, RuntimeCounter(static_cast(serdeOptions_->compressionKind))); diff --git a/velox/exec/PartitionedOutput.h b/velox/exec/PartitionedOutput.h index 5a1c44cf0b1..dd8c68e8b9e 100644 --- a/velox/exec/PartitionedOutput.h +++ b/velox/exec/PartitionedOutput.h @@ -102,6 +102,15 @@ class Destination { targetNumRows_ = (10'000 * targetSizePct_) / 100; } + // Creates VectorStreamGroup if needed. May recreate the stream tree + // after flush() to reinitialize the serializer. + void createVectorStreamGroup(const RowVectorPtr& output); + + // Clears the VectorStreamGroup and marks it for recreation. + // This ensures the serializer is properly reinitialized before the next + // append to avoid crashes from stale references to freed StreamArena memory. + void clearVectorStreamGroup(); + const std::string taskId_; const int destination_; VectorSerde* const serde_; @@ -122,6 +131,16 @@ class Destination { // The current stream where the input is serialized to. This is cleared on // every flush() call. std::unique_ptr current_; + + // Whether the stream tree needs to be recreated. Set after flush() to ensure + // proper initialization of the serializer before the next append. + bool needsStreamTreeRecreation_{false}; + + // Accumulated runtime stats from previous serialization cycles. Stats are + // collected before recreating the stream tree to avoid losing compression + // metrics from earlier flushes. + std::unordered_map accumulatedStats_; + bool finished_{false}; // Flush accumulated data to buffer manager after reaching this diff --git a/velox/exec/PlanNodeStats.cpp b/velox/exec/PlanNodeStats.cpp index ed3414fed7d..76c47b3db39 100644 --- a/velox/exec/PlanNodeStats.cpp +++ b/velox/exec/PlanNodeStats.cpp @@ -59,6 +59,14 @@ PlanNodeStats& PlanNodeStats::operator+=(const PlanNodeStats& another) { } } + for (const auto& [name, exprStats] : another.expressionStats) { + auto const [it, inserted] = + this->expressionStats.try_emplace(name, exprStats); + if (!inserted) { + it->second.add(exprStats); + } + } + // Populating number of drivers for plan nodes with multiple operators is not // useful. Each operator could have been executed in different pipelines with // different number of drivers. diff --git a/velox/exec/PrefixSort.cpp b/velox/exec/PrefixSort.cpp index 45bafec927e..6578429a02e 100644 --- a/velox/exec/PrefixSort.cpp +++ b/velox/exec/PrefixSort.cpp @@ -101,7 +101,7 @@ FOLLY_ALWAYS_INLINE void extractRowColumnToPrefix( default: VELOX_UNSUPPORTED( "prefix-sort does not support type kind: {}", - mapTypeKindToName(typeKind)); + TypeKindName::toName(typeKind)); } } @@ -310,7 +310,7 @@ void PrefixSort::extractRowAndEncodePrefixKeys(char* row, char* prefixBuffer) { } // static. -uint32_t PrefixSort::maxRequiredBytes( +uint64_t PrefixSort::maxRequiredBytes( const RowContainer* rowContainer, const std::vector& compareFlags, const velox::common::PrefixSortConfig& config, @@ -345,14 +345,15 @@ void PrefixSort::stdSort( }); } -uint32_t PrefixSort::maxRequiredBytes() const { +uint64_t PrefixSort::maxRequiredBytes() const { const auto numRows = rowContainer_->numRows(); const auto numPages = memory::AllocationTraits::numPages(numRows * sortLayout_.entrySize); // Prefix data size + swap buffer size. return memory::AllocationTraits::pageBytes(numPages) + - pool_->preferredSize(checkedPlus( - sortLayout_.entrySize, AlignedBuffer::kPaddedSize)) + + pool_->preferredSize( + checkedPlus( + sortLayout_.entrySize, AlignedBuffer::kPaddedSize)) + 2 * pool_->alignment(); } diff --git a/velox/exec/PrefixSort.h b/velox/exec/PrefixSort.h index 3d1fa46c2a5..e36c3588413 100644 --- a/velox/exec/PrefixSort.h +++ b/velox/exec/PrefixSort.h @@ -153,7 +153,7 @@ class PrefixSort { /// The std::sort won't require bytes while prefix sort may require buffers /// such as prefix data. The logic is similar to the above function /// PrefixSort::sort but returns the maximum buffer the sort may need. - static uint32_t maxRequiredBytes( + static uint64_t maxRequiredBytes( const RowContainer* rowContainer, const std::vector& compareFlags, const velox::common::PrefixSortConfig& config, @@ -161,7 +161,7 @@ class PrefixSort { /// The runtime stats name collected for prefix sort. /// The number of prefix sort keys. - static inline const std::string kNumPrefixSortKeys{"numPrefixSortKeys"}; + static constexpr std::string_view kNumPrefixSortKeys{"numPrefixSortKeys"}; private: /// Fallback to stdSort when prefix sort conditions such as config and memory @@ -205,7 +205,7 @@ class PrefixSort { // Estimates the memory required for prefix sort such as prefix buffer and // swap buffer. - uint32_t maxRequiredBytes() const; + uint64_t maxRequiredBytes() const; void sortInternal(std::vector>& rows); diff --git a/velox/exec/RowContainer.cpp b/velox/exec/RowContainer.cpp index cbf28e6c67a..5a3e8e291e0 100644 --- a/velox/exec/RowContainer.cpp +++ b/velox/exec/RowContainer.cpp @@ -42,8 +42,7 @@ static int32_t typeKindSize(TypeKind kind) { __attribute__((__no_sanitize__("thread"))) #endif #endif -inline void -setBit(char* bits, uint32_t idx) { +inline void setBit(char* bits, uint32_t idx) { auto bitsAs8Bit = reinterpret_cast(bits); bitsAs8Bit[idx / 8] |= (1 << (idx % 8)); } @@ -136,15 +135,19 @@ RowContainer::RowContainer( bool hasNext, bool isJoinBuild, bool hasProbedFlag, + bool hasCountFlag, bool hasNormalizedKeys, + bool useListRowIndex, memory::MemoryPool* pool) : keyTypes_(keyTypes), nullableKeys_(nullableKeys), isJoinBuild_(isJoinBuild), hasNormalizedKeys_(hasNormalizedKeys), + useListRowIndex_(useListRowIndex), stringAllocator_(std::make_unique(pool)), accumulators_(accumulators), - rows_(pool) { + rows_(pool), + rowPointers_(StlAllocator(stringAllocator_.get())) { // Compute the layout of the payload row. The row has keys, null flags, // accumulators, dependent fields. All fields are fixed width. If variable // width data is referenced, this is done with StringView(for VARCHAR) and @@ -245,6 +248,10 @@ RowContainer::RowContainer( nextOffset_ = offset; offset += sizeof(void*); } + if (hasCountFlag) { + countOffset_ = offset; + offset += sizeof(int32_t); + } fixedRowSize_ = bits::roundUp(offset, alignment_); originalNormalizedKeySize_ = hasNormalizedKeys_ ? bits::roundUp(sizeof(normalized_key_t), alignment_) @@ -280,6 +287,10 @@ char* RowContainer::newRow() { if (normalizedKeySize_) { ++numRowsWithNormalizedKey_; } + + if (useListRowIndex_) { + rowPointers_.push_back(row); + } } return initializeRow(row, false /* reuse */); } @@ -314,6 +325,9 @@ char* RowContainer::initializeRow(char* row, bool reuse) { variableRowSize(row) = 0; } bits::clearBit(row, freeFlagOffset_); + if (countOffset_) { + countRef(row) = 1; + } return row; } @@ -587,7 +601,7 @@ int32_t RowContainer::variableSizeAt(const char* row, column_index_t column) } const auto typeKind = typeKinds_[column]; - if (typeKind == TypeKind::VARCHAR || typeKind == TypeKind::VARBINARY) { + if (is_string_kind(typeKind)) { return reinterpret_cast(row + rowColumn.offset()) ->size(); } else { @@ -613,7 +627,7 @@ int32_t RowContainer::extractVariableSizeAt( } const auto typeKind = typeKinds_[column]; - if (typeKind == TypeKind::VARCHAR || typeKind == TypeKind::VARBINARY) { + if (is_string_kind(typeKind)) { const auto value = valueAt(row, rowColumn.offset()); const auto size = value.size(); ::memcpy(output, &size, 4); @@ -651,7 +665,7 @@ int32_t RowContainer::storeVariableSizeAt( // First 4 bytes is the size of the data. const auto size = *reinterpret_cast(data); - if (typeKind == TypeKind::VARCHAR || typeKind == TypeKind::VARBINARY) { + if (is_string_kind(typeKind)) { if (size > 0) { stringAllocator_->copyMultipart( StringView(data + 4, size), row, rowColumn.offset()); @@ -889,13 +903,11 @@ void RowContainer::hashTyped( : BaseVector::kNullHash; } else { uint64_t hash; - if constexpr (Kind == TypeKind::VARCHAR || Kind == TypeKind::VARBINARY) { + if constexpr (is_string_kind(Kind)) { hash = folly::hasher()(HashStringAllocator::contiguousString( valueAt(row, offset), storage)); - } else if constexpr ( - Kind == TypeKind::ROW || Kind == TypeKind::ARRAY || - Kind == TypeKind::MAP) { + } else if constexpr (is_nested_kind(Kind)) { auto in = prepareRead(row, offset); hash = ContainerRowSerde::hash(in, type); } else if constexpr (typeProvidesCustomComparison) { @@ -965,6 +977,8 @@ void RowContainer::clear() { hasDuplicateRows_ = false; rows_.clear(); + rowPointers_.clear(); + rowPointers_.shrink_to_fit(); stringAllocator_->clear(); numRows_ = 0; numRowsWithNormalizedKey_ = 0; @@ -1026,7 +1040,8 @@ std::optional RowContainer::estimateRowSize() const { } int64_t freeBytes = rows_.freeBytes() + fixedRowSize_ * numFreeRows_; int64_t usedSize = rows_.allocatedBytes() - freeBytes + - stringAllocator_->retainedSize() - stringAllocator_->freeSpace(); + stringAllocator_->retainedSize() - stringAllocator_->freeSpace() - + rowPointers_.capacity() * sizeof(char*); int64_t rowSize = usedSize / numRows_; VELOX_CHECK_GT( rowSize, 0, "Estimated row size of the RowContainer must be positive."); diff --git a/velox/exec/RowContainer.h b/velox/exec/RowContainer.h index 03166af6482..acdfb3ee608 100644 --- a/velox/exec/RowContainer.h +++ b/velox/exec/RowContainer.h @@ -83,6 +83,8 @@ struct RowContainerIterator { char* rowBegin{nullptr}; /// First byte after the end of the range containing 'currentRow'. char* endOfRun{nullptr}; + /// Cursor of the list row operation. + int32_t listRowCursor{0}; /// Returns the current row, skipping a possible normalized key below the /// first byte of row. @@ -118,6 +120,10 @@ class RowPartitions { return size_; } + void reset() { + size_ = 0; + } + private: const int32_t capacity_; @@ -273,6 +279,21 @@ class RowContainer { const std::vector& keyTypes, const std::vector& dependentTypes, memory::MemoryPool* pool) + : RowContainer( + keyTypes, + dependentTypes, + /*useListRowIndex=*/false, + pool) {} + + /// If 'useListRowIndex' is true, the container maintains an internal array of + /// row pointers so that listRowsFast() can return rows without scanning + /// underlying allocations or checking free/probe flags. It is intended to be + /// used in SortBuffer and SortInputSpiller to improve performance. + RowContainer( + const std::vector& keyTypes, + const std::vector& dependentTypes, + bool useListRowIndex, + memory::MemoryPool* pool) : RowContainer( keyTypes, true, // nullableKeys @@ -281,7 +302,9 @@ class RowContainer { false, // hasNext false, // isJoinBuild false, // hasProbedFlag + false, // hasCountFlag false, // hasNormalizedKey + useListRowIndex, pool) {} ~RowContainer(); @@ -312,7 +335,9 @@ class RowContainer { bool hasNext, bool isJoinBuild, bool hasProbedFlag, + bool hasCountFlag, bool hasNormalizedKey, + bool useListRowIndex, memory::MemoryPool* pool); /// Allocates a new row and initializes possible aggregates to null. @@ -573,8 +598,7 @@ class RowContainer { __attribute__((__no_sanitize__("thread"))) #endif #endif - int32_t - listRows( + int32_t listRows( RowContainerIterator* iter, int32_t maxRows, uint64_t maxBytes, @@ -638,6 +662,20 @@ class RowContainer { return count; } + /// Fast path for `listRows` that returns `rowPointers_` directly. Used by + /// `SortBuffer` and `SortInputSpiller`, so it skips checking the free and + /// probe flags. + int32_t listRowsFast(RowContainerIterator* iter, int32_t maxRows, char** rows) + const { + int32_t count = 0; + while (count < maxRows && iter->listRowCursor < rowPointers_.size()) { + char* row = rowPointers_[iter->listRowCursor]; + rows[count++] = row; + ++iter->listRowCursor; + } + return count; + } + /// Extracts up to 'maxRows' rows starting at the position of 'iter'. A /// default constructed or reset iter starts at the beginning. Returns the /// number of rows written to 'rows'. Returns 0 when at end. Stops after the @@ -652,6 +690,9 @@ class RowContainer { int32_t listRows(RowContainerIterator* iter, int32_t maxRows, char** rows) const { + if (useListRowIndex_) { + return listRowsFast(iter, maxRows, rows); + } return listRows(iter, maxRows, kUnlimited, rows); } @@ -668,21 +709,13 @@ class RowContainer { __attribute__((__no_sanitize__("thread"))) #endif #endif - void - setProbedFlag(char** rows, int32_t numRows); - - /// Returns true if 'row' at 'column' equals the value at 'index' in - /// 'decoded'. 'mayHaveNulls' specifies if nulls need to be checked. This is a - /// fast path for compare(). - template - bool equals( - const char* row, - RowColumn column, - const DecodedVector& decoded, - vector_size_t index) const; + void setProbedFlag(char** rows, int32_t numRows); /// Compares the value at 'column' in 'row' with the value at 'index' in /// 'decoded'. Returns 0 for equal, < 0 for 'row' < 'decoded', > 0 otherwise. + /// 'mayHaveNulls' specifies if nulls need to be checked. This is a fast path + /// for compare(). + template int32_t compare( const char* row, RowColumn column, @@ -735,6 +768,38 @@ class RowContainer { return probedFlagOffset_; } + /// Byte offset of the per-row count for counting joins. 0 if not applicable. + int32_t countOffset() const { + return countOffset_; + } + + /// Returns the count stored at the given row. Used for counting joins. + int32_t count(const char* row) const { + VELOX_DCHECK_NE(countOffset_, 0); + return countRef(const_cast(row)); + } + + /// Increments the count at the given row. Used during hash table build for + /// counting joins. + void incrementCount(char* row) const { + VELOX_DCHECK_NE(countOffset_, 0); + ++countRef(row); + } + + /// Decrements the count at the given row. Used during hash table probe for + /// counting joins. + void decrementCount(char* row) const { + VELOX_DCHECK_NE(countOffset_, 0); + --countRef(row); + } + + /// Adds 'n' to the count at the given row. Used during hash table merge + /// to combine counts from multiple build-side tables. + void addCount(char* row, int32_t n) const { + VELOX_DCHECK_NE(countOffset_, 0); + countRef(row) += n; + } + /// Returns the offset of a uint32_t row size or 0 if the row has no variable /// width fields or accumulators. int32_t rowSizeOffset() const { @@ -800,6 +865,10 @@ class RowContainer { return 0; } + const std::vector>& testingRowPointers() const { + return rowPointers_; + } + memory::MemoryPool* pool() const { return stringAllocator_->pool(); } @@ -1167,56 +1236,29 @@ class RowContainer { bool mix, uint64_t* result) const; - template - inline bool equalsWithNulls( - const char* row, - int32_t offset, - int32_t nullByte, - uint8_t nullMask, - const DecodedVector& decoded, - vector_size_t index) const { - bool rowIsNull = isNullAt(row, nullByte, nullMask); - bool indexIsNull = decoded.isNullAt(index); - if (rowIsNull || indexIsNull) { - return rowIsNull == indexIsNull; - } - - return equalsNoNulls( - row, offset, decoded, index); - } - - template - inline bool equalsNoNulls( + template + inline int compare( const char* row, - int32_t offset, + RowColumn column, const DecodedVector& decoded, - vector_size_t index) const { - using T = typename KindToFlatVector::HashRowType; - - if constexpr ( - Kind == TypeKind::ROW || Kind == TypeKind::ARRAY || - Kind == TypeKind::MAP) { - return compareComplexType(row, offset, decoded, index) == 0; - } else if constexpr ( - Kind == TypeKind::VARCHAR || Kind == TypeKind::VARBINARY) { - return compareStringAsc( - valueAt(row, offset), decoded, index) == 0; - } else if constexpr (typeProvidesCustomComparison) { - return SimpleVector::template comparePrimitiveAscWithCustomComparison< - Kind>( - decoded.base()->type().get(), - decoded.valueAt(index), - valueAt(row, offset)) == 0; + vector_size_t index, + CompareFlags flags) const { + if (decoded.base()->typeUsesCustomComparison()) { + return compare( + row, column, decoded, index, flags); } else { - return SimpleVector::comparePrimitiveAsc( - decoded.valueAt(index), valueAt(row, offset)) == 0; + return compare( + row, column, decoded, index, flags); } } template < bool typeProvidesCustomComparison, TypeKind Kind, - std::enable_if_t = 0> + bool mayHaveNulls, + std::enable_if_t< + Kind != TypeKind::OPAQUE && Kind != TypeKind::UNKNOWN, + int32_t> = 0> inline int compare( const char* row, RowColumn column, @@ -1224,20 +1266,23 @@ class RowContainer { vector_size_t index, CompareFlags flags) const { using T = typename KindToFlatVector::HashRowType; - bool rowIsNull = isNullAt(row, column.nullByte(), column.nullMask()); - bool indexIsNull = decoded.isNullAt(index); - if (rowIsNull) { - return indexIsNull ? 0 : flags.nullsFirst ? -1 : 1; - } - if (indexIsNull) { - return flags.nullsFirst ? 1 : -1; + + if constexpr (mayHaveNulls) { + bool rowIsNull = isNullAt(row, column.nullByte(), column.nullMask()); + bool indexIsNull = decoded.isNullAt(index); + if (rowIsNull) { + return indexIsNull ? 0 : flags.nullsFirst ? -1 : 1; + } + if (indexIsNull) { + return flags.nullsFirst ? 1 : -1; + } } + if constexpr ( Kind == TypeKind::ROW || Kind == TypeKind::ARRAY || Kind == TypeKind::MAP) { return compareComplexType(row, column.offset(), decoded, index, flags); - } else if constexpr ( - Kind == TypeKind::VARCHAR || Kind == TypeKind::VARBINARY) { + } else if constexpr (is_string_kind(Kind)) { auto result = compareStringAsc( valueAt(row, column.offset()), decoded, index); return flags.ascending ? result : result * -1; @@ -1261,6 +1306,22 @@ class RowContainer { template < bool typeProvidesCustomComparison, TypeKind Kind, + bool mayHaveNulls, + std::enable_if_t = 0> + inline int compare( + const char* row, + RowColumn column, + const DecodedVector& /*decoded*/, + vector_size_t /*index*/, + CompareFlags flags) const { + const bool rowIsNull = isNullAt(row, column.nullByte(), column.nullMask()); + return rowIsNull ? 0 : flags.nullsFirst ? 1 : -1; + } + + template < + bool typeProvidesCustomComparison, + TypeKind Kind, + bool mayHaveNulls, std::enable_if_t = 0> inline int compare( const char* /*row*/, @@ -1301,8 +1362,7 @@ class RowContainer { Kind == TypeKind::MAP) { return compareComplexType( left, right, type, leftOffset, rightOffset, flags); - } else if constexpr ( - Kind == TypeKind::VARCHAR || Kind == TypeKind::VARBINARY) { + } else if constexpr (is_string_kind(Kind)) { auto leftValue = valueAt(left, leftOffset); auto rightValue = valueAt(right, rightOffset); auto result = compareStringAsc(leftValue, rightValue); @@ -1482,12 +1542,17 @@ class RowContainer { } } + int32_t& countRef(char* row) const { + return *reinterpret_cast(row + countOffset_); + } + const std::vector keyTypes_; const bool nullableKeys_; const bool isJoinBuild_; // True if normalized keys are enabled in initial state. const bool hasNormalizedKeys_; - + // True if use 'listRowsFast'. + const bool useListRowIndex_; const std::unique_ptr stringAllocator_; // Indicates if we can add new row to this row container. It is set to false @@ -1521,6 +1586,9 @@ class RowContainer { // not applicable. int32_t probedFlagOffset_ = 0; + // Byte offset of the per-row count for counting joins. 0 if not applicable. + int32_t countOffset_ = 0; + // Bit position of free bit. int32_t freeFlagOffset_ = 0; int32_t rowSizeOffset_ = 0; @@ -1543,6 +1611,7 @@ class RowContainer { uint64_t numFreeRows_ = 0; memory::AllocationPool rows_; + std::vector> rowPointers_; int alignment_ = 1; @@ -1730,60 +1799,21 @@ inline void RowContainer::extractNulls( } template -inline bool RowContainer::equals( - const char* row, - RowColumn column, - const DecodedVector& decoded, - vector_size_t index) const { - auto typeKind = decoded.base()->typeKind(); - if (typeKind == TypeKind::UNKNOWN) { - return isNullAt(row, column.nullByte(), column.nullMask()); - } - - if constexpr (!mayHaveNulls) { - return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH( - equalsNoNulls, false, typeKind, row, column.offset(), decoded, index); - } else { - return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH( - equalsWithNulls, - false, - typeKind, - row, - column.offset(), - column.nullByte(), - column.nullMask(), - decoded, - index); - } -} - inline int RowContainer::compare( const char* row, RowColumn column, const DecodedVector& decoded, vector_size_t index, CompareFlags flags) const { - if (decoded.base()->typeUsesCustomComparison()) { - return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH_ALL( - compare, - true, - decoded.base()->typeKind(), - row, - column, - decoded, - index, - flags); - } else { - return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH_ALL( - compare, - false, - decoded.base()->typeKind(), - row, - column, - decoded, - index, - flags); - } + return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH_ALL( + compare, + mayHaveNulls, + decoded.base()->typeKind(), + row, + column, + decoded, + index, + flags); } inline int RowContainer::compare( diff --git a/velox/exec/RowNumber.cpp b/velox/exec/RowNumber.cpp index e57424da59d..cd2cd4ce36a 100644 --- a/velox/exec/RowNumber.cpp +++ b/velox/exec/RowNumber.cpp @@ -15,6 +15,7 @@ */ #include "velox/exec/RowNumber.h" #include "velox/common/memory/MemoryArbitrator.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/OperatorUtils.h" namespace facebook::velox::exec { @@ -28,9 +29,9 @@ RowNumber::RowNumber( rowNumberNode->outputType(), operatorId, rowNumberNode->id(), - "RowNumber", + OperatorType::kRowNumber, rowNumberNode->canSpill(driverCtx->queryConfig()) - ? driverCtx->makeSpillConfig(operatorId) + ? driverCtx->makeSpillConfig(operatorId, OperatorType::kRowNumber) : std::nullopt), limit_{rowNumberNode->limit()}, generateRowNumber_{rowNumberNode->generateRowNumber()} { @@ -46,6 +47,7 @@ RowNumber::RowNumber( false, // allowDuplicates false, // isJoinBuild false, // hasProbedFlag + false, // hasCountFlag 0, // minTableSizeForParallelJoinBuild pool()); lookup_ = std::make_unique(table_->hashers(), pool()); @@ -239,8 +241,11 @@ void RowNumber::ensureInputFits(const RowVectorPtr& input) { LOG(WARNING) << "Failed to reserve " << succinctBytes(targetIncrementBytes) << " for memory pool " << pool()->name() - << ", usage: " << succinctBytes(pool()->usedBytes()) - << ", reservation: " << succinctBytes(pool()->reservedBytes()); + << ", root pool: " << pool()->root()->name() + << ", used: " << succinctBytes(pool()->usedBytes()) + << ", reservation: " << succinctBytes(pool()->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool()->root()->reservedBytes()); } FlatVector& RowNumber::getOrCreateRowNumberVector(vector_size_t size) { @@ -383,12 +388,17 @@ void RowNumber::reclaim( return; } - if (exceededMaxSpillLevelLimit_) { + if (FOLLY_UNLIKELY(exceededMaxSpillLevelLimit_)) { LOG(WARNING) << "Exceeded row spill level limit: " << spillConfig_->maxSpillLevel - << ", and abandon spilling for memory pool: " - << pool()->name(); - ++spillStats_->wlock()->spillMaxLevelExceededCount; + << ", and abandon spilling for memory pool: " << pool()->name() + << ", root pool: " << pool()->root()->name() + << ", used: " << succinctBytes(pool()->usedBytes()) + << ", reservation: " << succinctBytes(pool()->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool()->root()->reservedBytes()); + spillStats_->spillMaxLevelExceededCount.fetch_add( + 1, std::memory_order_relaxed); return; } @@ -544,7 +554,7 @@ RowNumberHashTableSpiller::RowNumberHashTableSpiller( RowTypePtr rowType, HashBitRange bits, const common::SpillConfig* spillConfig, - folly::Synchronized* spillStats) + exec::SpillStats* spillStats) : SpillerBase( container, std::move(rowType), diff --git a/velox/exec/RowNumber.h b/velox/exec/RowNumber.h index 255827ece03..b34fc9d9c20 100644 --- a/velox/exec/RowNumber.h +++ b/velox/exec/RowNumber.h @@ -161,7 +161,7 @@ class RowNumberHashTableSpiller : public SpillerBase { RowTypePtr rowType, HashBitRange bits, const common::SpillConfig* spillConfig, - folly::Synchronized* spillStats); + exec::SpillStats* spillStats); void spill(); diff --git a/velox/exec/RowsStreamingWindowBuild.cpp b/velox/exec/RowsStreamingWindowBuild.cpp index a3dedd8a60c..7f03a465b6c 100644 --- a/velox/exec/RowsStreamingWindowBuild.cpp +++ b/velox/exec/RowsStreamingWindowBuild.cpp @@ -50,8 +50,9 @@ bool RowsStreamingWindowBuild::needsInput() { void RowsStreamingWindowBuild::ensureInputPartition() { if (windowPartitions_.empty() || windowPartitions_.back()->complete()) { - windowPartitions_.emplace_back(std::make_shared( - data_.get(), inversedInputChannels_, sortKeyInfo_)); + windowPartitions_.emplace_back( + std::make_shared( + data_.get(), inversedInputChannels_, sortKeyInfo_)); } } diff --git a/velox/exec/RowsStreamingWindowBuild.h b/velox/exec/RowsStreamingWindowBuild.h index 065ff4c7c44..421236889e3 100644 --- a/velox/exec/RowsStreamingWindowBuild.h +++ b/velox/exec/RowsStreamingWindowBuild.h @@ -41,7 +41,7 @@ class RowsStreamingWindowBuild : public WindowBuild { VELOX_UNREACHABLE(); } - std::optional spilledStats() const override { + std::optional spilledStats() const override { return std::nullopt; } diff --git a/velox/exec/ScaleWriterLocalPartition.cpp b/velox/exec/ScaleWriterLocalPartition.cpp index 3b9bcf4cee3..7530ff403a0 100644 --- a/velox/exec/ScaleWriterLocalPartition.cpp +++ b/velox/exec/ScaleWriterLocalPartition.cpp @@ -174,28 +174,11 @@ void ScaleWriterPartitioningLocalPartition::addInput(RowVectorPtr input) { row; } - for (auto i = 0; i < numPartitions_; ++i) { - const auto writerRowCount = writerAssignmentCounts_[i]; - if (writerRowCount == 0) { - continue; - } - - auto writerInput = processPartition( - input, - writerRowCount, - i, - std::move(writerAssignmmentIndicesBuffers_[i]), - rawWriterAssignmmentIndicesBuffers_[i]); - if (writerInput != nullptr) { - ContinueFuture future; - auto reason = queues_[i]->enqueue( - writerInput, totalInputBytes * writerRowCount / numInput, &future); - if (reason != BlockingReason::kNotBlocked) { - blockingReasons_.push_back(reason); - futures_.push_back(std::move(future)); - } - } - } + populateAndEnqueuePartitions( + input, + writerAssignmentCounts_, + writerAssignmmentIndicesBuffers_, + rawWriterAssignmmentIndicesBuffers_); } // Only update the scaling state if the memory used is below the diff --git a/velox/exec/ScaleWriterLocalPartition.h b/velox/exec/ScaleWriterLocalPartition.h index 3fba1e781d2..3bfd2c5467b 100644 --- a/velox/exec/ScaleWriterLocalPartition.h +++ b/velox/exec/ScaleWriterLocalPartition.h @@ -44,9 +44,9 @@ class ScaleWriterPartitioningLocalPartition : public LocalPartition { /// The name of the runtime stats of writer scaling. /// The number of times that we triggers the rebalance of table partitions. - static inline const std::string kRebalanceTriggers{"rebalanceTriggers"}; + static constexpr std::string_view kRebalanceTriggers{"rebalanceTriggers"}; /// The number of times that we scale a partition processing. - static inline const std::string kScaledPartitions{"scaledPartitions"}; + static constexpr std::string_view kScaledPartitions{"scaledPartitions"}; private: void prepareForWriterAssignments(vector_size_t numInput); @@ -98,7 +98,7 @@ class ScaleWriterLocalPartition : public LocalPartition { /// The name of the runtime stats of writer scaling. /// The number of scaled writers. - static inline const std::string kScaledWriters{"scaledWriters"}; + static constexpr std::string_view kScaledWriters{"scaledWriters"}; private: // Gets the writer id to process the next input in a round-robin manner. diff --git a/velox/exec/SerializedPage.cpp b/velox/exec/SerializedPage.cpp new file mode 100644 index 00000000000..62ab77218c9 --- /dev/null +++ b/velox/exec/SerializedPage.cpp @@ -0,0 +1,52 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/exec/SerializedPage.h" + +#include + +namespace facebook::velox::exec { + +PrestoSerializedPage::PrestoSerializedPage( + std::unique_ptr iobuf, + std::function onDestructionCb, + std::optional numRows) + : iobuf_(std::move(iobuf)), + iobufBytes_(chainBytes(*iobuf_.get())), + numRows_(numRows), + onDestructionCb_(std::move(onDestructionCb)) { + VELOX_CHECK_NOT_NULL(iobuf_); + for (auto& buf : *iobuf_) { + int32_t bufSize = buf.size(); + ranges_.push_back( + ByteRange{ + const_cast(reinterpret_cast(buf.data())), + bufSize, + 0}); + } +} + +PrestoSerializedPage::~PrestoSerializedPage() { + if (onDestructionCb_) { + onDestructionCb_(*iobuf_.get()); + } +} + +std::unique_ptr +PrestoSerializedPage::prepareStreamForDeserialize() { + return std::make_unique(std::move(ranges_)); +} + +} // namespace facebook::velox::exec diff --git a/velox/exec/SerializedPage.h b/velox/exec/SerializedPage.h new file mode 100644 index 00000000000..3f93de8012a --- /dev/null +++ b/velox/exec/SerializedPage.h @@ -0,0 +1,98 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/common/memory/ByteStream.h" + +namespace facebook::velox::exec { + +/// Interface for serialized pages. +class SerializedPageBase { + public: + virtual ~SerializedPageBase() = default; + + /// Returns the size of the serialized data in bytes. + virtual uint64_t size() const = 0; + + /// Returns the number of rows if available. + virtual std::optional numRows() const = 0; + + /// Makes 'input' ready for deserializing 'this' with + /// VectorStreamGroup::read(). + virtual std::unique_ptr prepareStreamForDeserialize() = 0; + + /// Returns a clone of the IOBuf. + virtual std::unique_ptr getIOBuf() const = 0; +}; + +/// Corresponds to Presto SerializedPage, i.e. a container for serialized +/// vectors in Presto wire format. +class PrestoSerializedPage : public SerializedPageBase { + public: + /// Construct from IOBuf chain. + explicit PrestoSerializedPage( + std::unique_ptr iobuf, + std::function onDestructionCb = nullptr, + std::optional numRows = std::nullopt); + + ~PrestoSerializedPage() override; + + uint64_t size() const override { + return iobufBytes_; + } + + std::optional numRows() const override { + return numRows_; + } + + std::unique_ptr prepareStreamForDeserialize() override; + + std::unique_ptr getIOBuf() const override { + return iobuf_->clone(); + } + + private: + static int64_t chainBytes(folly::IOBuf& iobuf) { + int64_t size = 0; + for (auto& range : iobuf) { + size += range.size(); + } + return size; + } + + // Buffers containing the serialized data. The memory is owned by 'iobuf_'. + std::vector ranges_; + + // IOBuf holding the data in 'ranges_. + std::unique_ptr iobuf_; + + // Number of payload bytes in 'iobuf_'. + const int64_t iobufBytes_; + + // Number of payload rows, if provided. + const std::optional numRows_; + + // Callback that will be called on destruction of the PrestoSerializedPage, + // primarily used to free externally allocated memory backing folly::IOBuf + // from caller. Caller is responsible to pass in proper cleanup logic to + // prevent any memory leak. + std::function onDestructionCb_; +}; + +// TODO: Remove after fully migration to new SerializedPageBase and +// PrestoSerializedPage API. +using SerializedPage = PrestoSerializedPage; +} // namespace facebook::velox::exec diff --git a/velox/exec/SimpleAggregateAdapter.h b/velox/exec/SimpleAggregateAdapter.h index 7ff74ae43b6..24a35955e5e 100644 --- a/velox/exec/SimpleAggregateAdapter.h +++ b/velox/exec/SimpleAggregateAdapter.h @@ -18,9 +18,43 @@ #include "velox/exec/Aggregate.h" #include "velox/expression/VectorReaders.h" #include "velox/expression/VectorWriters.h" +#include "velox/type/SimpleFunctionApi.h" namespace facebook::velox::exec { +// AggregateInputType is similar to Row but allows Variadic as the last element. +// Unlike Row, this type is not meant to be serialized or used as a data type +// in Velox vectors - it's only used for type deduction in +// SimpleAggregateAdapter. +template +struct AggregateInputType { + template + using type_at = typename std::tuple_element>::type; + + static const size_t size_ = sizeof...(T); + + // Verify that Variadic, if present, is only at the last position. + static_assert( + []() constexpr { + if constexpr (sizeof...(T) <= 1) { + return true; + } else { + // Check that no element except possibly the last is Variadic + constexpr bool checks[] = {!isVariadicType::value...}; + for (size_t i = 0; i < sizeof...(T) - 1; ++i) { + if (!checks[i]) { + return false; + } + } + return true; + } + }(), + "Variadic can only appear at the end of aggregation input type"); + + private: + AggregateInputType() {} +}; + // The writer type of T used in simple UDAF interface. An instance of // out_type allows writing one row into the output vector. template @@ -44,9 +78,15 @@ class SimpleAggregateAdapter : public Aggregate { explicit SimpleAggregateAdapter( core::AggregationNode::Step step, const std::vector& argTypes, - TypePtr resultType) + TypePtr resultType, + const core::QueryConfig* config = nullptr) : Aggregate(std::move(resultType)), fn_{std::make_unique()} { - if constexpr (support_initialize_) { + if constexpr (support_initialize_with_config_) { + VELOX_CHECK_NOT_NULL( + config, + "QueryConfig is required to initialize this aggregate function."); + fn_->initialize(step, argTypes, resultType_, *config); + } else if constexpr (support_initialize_) { fn_->initialize(step, argTypes, resultType_); } } @@ -159,6 +199,20 @@ class SimpleAggregateAdapter : public Aggregate { struct support_initialize> : std::true_type {}; + // Whether the function defines an initialize() method that accepts a + // QueryConfig parameter. If so, the config is forwarded during construction. + template + struct support_initialize_with_config : std::false_type {}; + + template + struct support_initialize_with_config< + T, + std::void_t().initialize( + std::declval(), + std::declval&>(), + std::declval(), + std::declval()))>> : std::true_type {}; + // Whether the accumulator requires aligned access. If it is defined, // SimpleAggregateAdapter::accumulatorAlignmentSize() returns // alignof(typename FUNC::AccumulatorType). @@ -188,6 +242,9 @@ class SimpleAggregateAdapter : public Aggregate { static constexpr bool support_initialize_ = support_initialize::value; + static constexpr bool support_initialize_with_config_ = + support_initialize_with_config::value; + static constexpr bool accumulator_is_aligned_ = accumulator_is_aligned::value; @@ -226,8 +283,7 @@ class SimpleAggregateAdapter : public Aggregate { inputDecoded_[i].decode(*args[i], rows); } - addRawInputImpl( - groups, rows, std::make_index_sequence{}); + addRawInputImpl(groups, rows); } // Similar to addRawInput, but add inputs to one single accumulator. @@ -244,8 +300,7 @@ class SimpleAggregateAdapter : public Aggregate { inputDecoded_[i].decode(*args[i], rows); } - addSingleGroupRawInputImpl( - group, rows, std::make_index_sequence{}); + addSingleGroupRawInputImpl(group, rows); } bool supportsToIntermediate() const override { @@ -262,11 +317,7 @@ class SimpleAggregateAdapter : public Aggregate { inputDecoded[i].decode(*args[i], rows); } - toIntermediateImpl( - inputDecoded, - rows, - result, - std::make_index_sequence{}); + toIntermediateImpl(inputDecoded, rows, result); } else { VELOX_UNREACHABLE( "toIntermediate should only be called when support_to_intermediate_ is true."); @@ -385,18 +436,97 @@ class SimpleAggregateAdapter : public Aggregate { } private: - template - void addRawInputImpl( + // Check if InputType has a variadic argument (must be at the last position). + static constexpr bool hasVariadicInput() { + if constexpr (FUNC::InputType::size_ == 0) { + return false; + } else { + return isVariadicType>::value; + } + } + + static constexpr bool has_variadic_input_ = hasVariadicInput(); + + // The position in InputType where the variadic argument starts. + static constexpr int32_t variadicStartPosition_ = + has_variadic_input_ ? FUNC::InputType::size_ - 1 : FUNC::InputType::size_; + + // Check if an input at POSITION is non-null for the given row. For + // non-variadic inputs this checks isSet(). For variadic inputs this checks + // hasTopLevelNull(), which inspects each variadic element's top-level + // nullity. This matches the expression layer's per-input null deselection + // for simple scalar functions. + template + static bool isInputNonNull(const TReader& reader, vector_size_t row) { + if constexpr (POSITION >= variadicStartPosition_) { + return !reader.hasTopLevelNull(row); + } else { + return reader.isSet(row); + } + } + + // Generic recursive unpacking function that creates VectorReaders for each + // input type and invokes the callback with all accumulated readers. + template + void unpackInputs( + const std::vector& inputDecoded, + const Callback& callback, + TReader&... readers) const { + if constexpr (POSITION == FUNC::InputType::size_) { + callback(readers...); + } else if constexpr (isVariadicType< + typename FUNC::InputType::template type_at< + POSITION>>::value) { + static_assert( + POSITION == FUNC::InputType::size_ - 1, + "Variadic must be the last argument"); + using VariadicT = typename FUNC::InputType::template type_at; + VectorReader variadicReader( + inputDecoded, variadicStartPosition_); + callback(readers..., variadicReader); + } else { + VectorReader> reader( + &inputDecoded[POSITION]); + unpackInputs(inputDecoded, callback, readers..., reader); + } + } + + void addRawInputImpl(char** groups, const SelectivityVector& rows) { + unpackInputs<0>(inputDecoded_, [&](auto&... readers) { + addRawInputWithReaders(groups, rows, readers...); + }); + } + + template + void addRawInputWithReaders( + char** groups, + const SelectivityVector& rows, + TReader&... readers) { + addRawInputWithReadersImpl( + groups, + rows, + std::make_index_sequence{}, + readers...); + } + + // Process rows with the accumulated readers. Use Is to access the original + // input types to create OptionalAccessor. + template + void addRawInputWithReadersImpl( char** groups, const SelectivityVector& rows, - std::index_sequence) { - std::tuple>...> - readers{&inputDecoded_[Is]...}; + std::index_sequence, + TReader&... readers) { + static_assert( + sizeof...(Is) == FUNC::InputType::size_, + "Reader count must match InputType size"); if constexpr (aggregate_default_null_behavior_) { - rows.applyToSelected([&](auto row) { - // If any input is null, we ignore the whole row. - if (!(std::get(readers).isSet(row) && ...)) { + rows.applyToSelected([&](vector_size_t row) { + // If any input is null, we ignore the whole row. For variadic + // arguments, this includes nulls within the variadic elements. + if (!(isInputNonNull(readers, row) && ...)) { return; } std::optional> tracker; @@ -404,11 +534,11 @@ class SimpleAggregateAdapter : public Aggregate { tracker.emplace(groups[row][rowSizeOffset_], *allocator_); } auto group = value(groups[row]); - group->addInput(allocator_, std::get(readers)[row]...); + group->addInput(allocator_, readers[row]...); clearNull(groups[row]); }); } else { - rows.applyToSelected([&](auto row) { + rows.applyToSelected([&](vector_size_t row) { std::optional> tracker; if constexpr (!accumulator_is_fixed_size_) { tracker.emplace(groups[row][rowSizeOffset_], *allocator_); @@ -417,7 +547,7 @@ class SimpleAggregateAdapter : public Aggregate { bool nonNull = group->addInput( allocator_, OptionalAccessor>{ - &std::get(readers), (int64_t)row}...); + &readers, (int64_t)row}...); if (nonNull) { clearNull(groups[row]); } @@ -425,30 +555,54 @@ class SimpleAggregateAdapter : public Aggregate { } } - template - void addSingleGroupRawInputImpl( + void addSingleGroupRawInputImpl(char* group, const SelectivityVector& rows) { + unpackInputs<0>(inputDecoded_, [&](auto&... readers) { + addSingleGroupRawInputWithReaders(group, rows, readers...); + }); + } + + template + void addSingleGroupRawInputWithReaders( + char* group, + const SelectivityVector& rows, + TReader&... readers) { + addSingleGroupRawInputWithReadersImpl( + group, + rows, + std::make_index_sequence{}, + readers...); + } + + // Process rows for a single group with the accumulated readers. Use Is to + // access the original input types to create OptionalAccessor. + template + void addSingleGroupRawInputWithReadersImpl( char* group, const SelectivityVector& rows, - std::index_sequence) { - std::tuple>...> - readers{&inputDecoded_[Is]...}; + std::index_sequence, + TReader&... readers) { + static_assert( + sizeof...(Is) == FUNC::InputType::size_, + "Reader count must match InputType size"); + auto accumulator = value(group); if constexpr (aggregate_default_null_behavior_) { - rows.applyToSelected([&](auto row) { - // If any input is null, we ignore the whole row. - if (!(std::get(readers).isSet(row) && ...)) { + rows.applyToSelected([&](vector_size_t row) { + // If any input is null, we ignore the whole row. For variadic + // arguments, this includes nulls within the variadic elements. + if (!(isInputNonNull(readers, row) && ...)) { return; } std::optional> tracker; if constexpr (!accumulator_is_fixed_size_) { tracker.emplace(group[rowSizeOffset_], *allocator_); } - accumulator->addInput(allocator_, std::get(readers)[row]...); + accumulator->addInput(allocator_, readers[row]...); clearNull(group); }); } else { - rows.applyToSelected([&](auto row) { + rows.applyToSelected([&](vector_size_t row) { std::optional> tracker; if constexpr (!accumulator_is_fixed_size_) { tracker.emplace(group[rowSizeOffset_], *allocator_); @@ -456,7 +610,7 @@ class SimpleAggregateAdapter : public Aggregate { bool nonNull = accumulator->addInput( allocator_, OptionalAccessor>{ - &std::get(readers), (int64_t)row}...); + &readers, (int64_t)row}...); if (nonNull) { clearNull(group); } @@ -464,15 +618,10 @@ class SimpleAggregateAdapter : public Aggregate { } } - template void toIntermediateImpl( const std::vector& inputDecoded, const SelectivityVector& rows, - VectorPtr& result, - std::index_sequence) const { - std::tuple>...> - readers{&inputDecoded[Is]...}; - + VectorPtr& result) const { VELOX_CHECK(result); result->ensureWritable(rows); auto* rawNulls = result->mutableRawNulls(); @@ -485,29 +634,56 @@ class SimpleAggregateAdapter : public Aggregate { exec::VectorWriter writer; writer.init(*flatResult); + unpackInputs<0>(inputDecoded, [&](auto&... readers) { + toIntermediateWithReaders(rows, writer, readers...); + }); + writer.finish(); + } + + template + void toIntermediateWithReaders( + const SelectivityVector& rows, + exec::VectorWriter& writer, + TReader&... readers) const { + toIntermediateWithReadersImpl( + rows, + writer, + std::make_index_sequence{}, + readers...); + } + + // Process rows for toIntermediate with the accumulated readers. + template + void toIntermediateWithReadersImpl( + const SelectivityVector& rows, + exec::VectorWriter& writer, + std::index_sequence, + TReader&... readers) const { + static_assert( + sizeof...(Is) == FUNC::InputType::size_, + "Reader count must match InputType size"); + if constexpr (aggregate_default_null_behavior_) { - rows.applyToSelected([&](auto row) { + rows.applyToSelected([&](vector_size_t row) { writer.setOffset(row); - // If any input is null, we ignore the whole row. - if (!(std::get(readers).isSet(row) && ...)) { + // If any input is null, we ignore the whole row. For variadic + // arguments, this includes nulls within the variadic elements. + if (!(isInputNonNull(readers, row) && ...)) { writer.commitNull(); return; } - bool nonNull = FUNC::toIntermediate( - writer.current(), std::get(readers)[row]...); + bool nonNull = FUNC::toIntermediate(writer.current(), readers[row]...); writer.commit(nonNull); }); - writer.finish(); } else { - rows.applyToSelected([&](auto row) { + rows.applyToSelected([&](vector_size_t row) { writer.setOffset(row); bool nonNull = FUNC::toIntermediate( writer.current(), OptionalAccessor>{ - &std::get(readers), (int64_t)row}...); + &readers, (int64_t)row}...); writer.commit(nonNull); }); - writer.finish(); } } @@ -519,7 +695,7 @@ class SimpleAggregateAdapter : public Aggregate { VectorReader reader(&intermediateDecoded_); if constexpr (aggregate_default_null_behavior_) { - rows.applyToSelected([&](auto row) { + rows.applyToSelected([&](vector_size_t row) { if (!reader.isSet(row)) { return; } @@ -532,7 +708,7 @@ class SimpleAggregateAdapter : public Aggregate { clearNull(groups[row]); }); } else { - rows.applyToSelected([&](auto row) { + rows.applyToSelected([&](vector_size_t row) { std::optional> tracker; if constexpr (!accumulator_is_fixed_size_) { tracker.emplace(groups[row][rowSizeOffset_], *allocator_); @@ -558,7 +734,7 @@ class SimpleAggregateAdapter : public Aggregate { auto accumulator = value(group); if constexpr (aggregate_default_null_behavior_) { - rows.applyToSelected([&](auto row) { + rows.applyToSelected([&](vector_size_t row) { if (!reader.isSet(row)) { return; } @@ -570,7 +746,7 @@ class SimpleAggregateAdapter : public Aggregate { clearNull(group); }); } else { - rows.applyToSelected([&](auto row) { + rows.applyToSelected([&](vector_size_t row) { std::optional> tracker; if constexpr (!accumulator_is_fixed_size_) { tracker.emplace(group[rowSizeOffset_], *allocator_); diff --git a/velox/exec/SortBuffer.cpp b/velox/exec/SortBuffer.cpp index d64f858d95f..fda9c19a5c9 100644 --- a/velox/exec/SortBuffer.cpp +++ b/velox/exec/SortBuffer.cpp @@ -28,7 +28,7 @@ SortBuffer::SortBuffer( tsan_atomic* nonReclaimableSection, common::PrefixSortConfig prefixSortConfig, const common::SpillConfig* spillConfig, - folly::Synchronized* spillStats) + exec::SpillStats* spillStats) : input_(input), sortCompareFlags_(sortCompareFlags), pool_(pool), @@ -37,7 +37,7 @@ SortBuffer::SortBuffer( spillConfig_(spillConfig), spillStats_(spillStats), sortedRows_(0, memory::StlAllocator(*pool)) { - VELOX_CHECK_GE(input_->size(), sortCompareFlags_.size()); + VELOX_CHECK_GE(input_->children().size(), sortCompareFlags_.size()); VELOX_CHECK_GT(sortCompareFlags_.size(), 0); VELOX_CHECK_EQ(sortColumnIndices.size(), sortCompareFlags_.size()); VELOX_CHECK_NOT_NULL(nonReclaimableSection_); @@ -74,7 +74,7 @@ SortBuffer::SortBuffer( } data_ = std::make_unique( - sortedColumnTypes, nonSortedColumnTypes, pool_); + sortedColumnTypes, nonSortedColumnTypes, /*useListRowIndex=*/true, pool_); spillerStoreType_ = ROW(std::move(sortedSpillColumnNames), std::move(sortedSpillColumnTypes)); } @@ -90,12 +90,12 @@ void SortBuffer::addInput(const VectorPtr& input) { VELOX_CHECK(!noMoreInput_); ensureInputFits(input); - SelectivityVector allRows(input->size()); + const SelectivityVector allRows(input->size()); std::vector rows(input->size()); for (int row = 0; row < input->size(); ++row) { rows[row] = data_->newRow(); } - auto* inputRow = input->as(); + const auto* inputRow = input->as(); for (const auto& columnProjection : columnMap_) { DecodedVector decoded( *inputRow->childAt(columnProjection.outputChannel), allRows); @@ -128,6 +128,7 @@ void SortBuffer::noMoreInput() { updateEstimatedOutputRowSize(); // Sort the pointers to the rows in RowContainer (data_) instead of sorting // the rows. + // TODO: Reuse 'RowContainer::rowPointers_'. sortedRows_.resize(numInputRows_); RowContainerIterator iter; data_->listRows(&iter, numInputRows_, sortedRows_.data()); @@ -168,7 +169,7 @@ RowVectorPtr SortBuffer::getOutput(vector_size_t maxOutputRows) { } else { getOutputWithoutSpill(); } - return output_; + return std::move(output_); } bool SortBuffer::hasSpilled() const { @@ -257,8 +258,11 @@ void SortBuffer::ensureInputFits(const VectorPtr& input) { } LOG(WARNING) << "Failed to reserve " << succinctBytes(targetIncrementBytes) << " for memory pool " << pool()->name() - << ", usage: " << succinctBytes(pool()->usedBytes()) - << ", reservation: " << succinctBytes(pool()->reservedBytes()); + << ", root pool: " << pool()->root()->name() + << ", used: " << succinctBytes(pool()->usedBytes()) + << ", reservation: " << succinctBytes(pool()->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool()->root()->reservedBytes()); } void SortBuffer::ensureOutputFits(vector_size_t batchSize) { @@ -289,8 +293,11 @@ void SortBuffer::ensureOutputFits(vector_size_t batchSize) { LOG(WARNING) << "Failed to reserve " << succinctBytes(outputBufferSizeToReserve) << " for memory pool " << pool_->name() - << ", usage: " << succinctBytes(pool_->usedBytes()) - << ", reservation: " << succinctBytes(pool_->reservedBytes()); + << ", root pool: " << pool_->root()->name() + << ", used: " << succinctBytes(pool_->usedBytes()) + << ", reservation: " << succinctBytes(pool_->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool_->root()->reservedBytes()); } void SortBuffer::ensureSortFits() { @@ -310,7 +317,7 @@ void SortBuffer::ensureSortFits() { } // The memory for std::vector sorted rows and prefix sort required buffer. - uint64_t sortBufferToReserve = + const auto sortBufferToReserve = numInputRows_ * sizeof(char*) + PrefixSort::maxRequiredBytes( data_.get(), sortCompareFlags_, prefixSortConfig_, pool_); @@ -389,10 +396,6 @@ void SortBuffer::prepareOutput(vector_size_t batchSize) { BaseVector::create(input_, batchSize, pool_)); } - for (auto& child : output_->children()) { - child->resize(batchSize); - } - if (hasSpilled()) { spillSources_.resize(batchSize); spillSourceRows_.resize(batchSize); @@ -488,7 +491,7 @@ void SortBuffer::prepareOutputWithSpill() { VELOX_CHECK_EQ(spillPartitionSet_.size(), 1); spillMerger_ = spillPartitionSet_.begin()->second->createOrderedReader( - spillConfig_->readBufferSize, pool(), spillStats_); + *spillConfig_, pool(), spillStats_); spillPartitionSet_.clear(); } } // namespace facebook::velox::exec diff --git a/velox/exec/SortBuffer.h b/velox/exec/SortBuffer.h index 9804e2f8589..99ddc866cf0 100644 --- a/velox/exec/SortBuffer.h +++ b/velox/exec/SortBuffer.h @@ -40,7 +40,7 @@ class SortBuffer { tsan_atomic* nonReclaimableSection, common::PrefixSortConfig prefixSortConfig, const common::SpillConfig* spillConfig = nullptr, - folly::Synchronized* spillStats = nullptr); + exec::SpillStats* spillStats = nullptr); ~SortBuffer(); @@ -125,7 +125,7 @@ class SortBuffer { const common::SpillConfig* const spillConfig_; - folly::Synchronized* const spillStats_; + exec::SpillStats* const spillStats_; // The column projection map between 'input_' and 'spillerStoreType_' as sort // buffer stores the sort columns first in 'data_'. diff --git a/velox/exec/SortWindowBuild.cpp b/velox/exec/SortWindowBuild.cpp index f25175cc2cf..89eeeb08160 100644 --- a/velox/exec/SortWindowBuild.cpp +++ b/velox/exec/SortWindowBuild.cpp @@ -16,6 +16,7 @@ #include "velox/exec/SortWindowBuild.h" #include "velox/exec/MemoryReclaimer.h" +#include "velox/exec/Window.h" namespace facebook::velox::exec { @@ -45,16 +46,19 @@ SortWindowBuild::SortWindowBuild( common::PrefixSortConfig&& prefixSortConfig, const common::SpillConfig* spillConfig, tsan_atomic* nonReclaimableSection, - folly::Synchronized* spillStats) + folly::Synchronized* opStats, + exec::SpillStats* spillStats) : WindowBuild(node, pool, spillConfig, nonReclaimableSection), numPartitionKeys_{node->partitionKeys().size()}, compareFlags_{makeCompareFlags(numPartitionKeys_, node->sortingOrders())}, pool_(pool), prefixSortConfig_(prefixSortConfig), + opStats_(opStats), spillStats_(spillStats), sortedRows_(0, memory::StlAllocator(*pool)), partitionStartRows_(0, memory::StlAllocator(*pool)) { VELOX_CHECK_NOT_NULL(pool_); + VELOX_CHECK_NOT_NULL(opStats_); allKeyInfo_.reserve(partitionKeyInfo_.size() + sortKeyInfo_.size()); allKeyInfo_.insert( allKeyInfo_.cend(), partitionKeyInfo_.begin(), partitionKeyInfo_.end()); @@ -72,13 +76,20 @@ void SortWindowBuild::addInput(RowVectorPtr input) { // Add all the rows into the RowContainer. for (auto row = 0; row < input->size(); ++row) { - char* newRow = data_->newRow(); + addDecodedInputRow(decodedInputVectors_, row); + } +} - for (auto col = 0; col < input->childrenSize(); ++col) { - data_->store(decodedInputVectors_[col], row, newRow, col); - } +void SortWindowBuild::addDecodedInputRow( + std::vector& decodedInputVectors, + vector_size_t row) { + char* newRow = data_->newRow(); + + for (auto col = 0; col < inputChannels_.size(); ++col) { + data_->store(decodedInputVectors[col], row, newRow, col); } - numRows_ += input->size(); + + numRows_++; } void SortWindowBuild::ensureInputFits(const RowVectorPtr& input) { @@ -135,9 +146,12 @@ void SortWindowBuild::ensureInputFits(const RowVectorPtr& input) { LOG(WARNING) << "Failed to reserve " << succinctBytes(targetIncrementBytes) << " for memory pool " << data_->pool()->name() - << ", usage: " << succinctBytes(data_->pool()->usedBytes()) + << ", root pool: " << data_->pool()->root()->name() + << ", used: " << succinctBytes(data_->pool()->usedBytes()) << ", reservation: " - << succinctBytes(data_->pool()->reservedBytes()); + << succinctBytes(data_->pool()->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(data_->pool()->root()->reservedBytes()); } void SortWindowBuild::ensureSortFits() { @@ -194,7 +208,7 @@ void SortWindowBuild::spill() { data_->pool()->release(); } -std::optional SortWindowBuild::spilledStats() const { +std::optional SortWindowBuild::spilledStats() const { if (spiller_ == nullptr) { return std::nullopt; } @@ -280,7 +294,7 @@ void SortWindowBuild::noMoreInput() { spiller_->finishSpill(spillPartitionSet); VELOX_CHECK_EQ(spillPartitionSet.size(), 1); merge_ = spillPartitionSet.begin()->second->createOrderedReader( - spillConfig_->readBufferSize, pool_, spillStats_); + *spillConfig_, pool_, spillStats_); } else { // At this point we have seen all the input rows. The operator is // being prepared to output rows now. @@ -294,14 +308,31 @@ void SortWindowBuild::noMoreInput() { pool_->release(); } -void SortWindowBuild::loadNextPartitionFromSpill() { +void SortWindowBuild::loadNextPartitionBatchFromSpill() { + // Check if current partition batch still has available partitions. If so, + // return directly. + if (currentPartition_ < static_cast(partitionStartRows_.size() - 2)) { + return; + } + + const int minReadBatchRows = spillConfig_->windowMinReadBatchRows; sortedRows_.clear(); - sortedRows_.shrink_to_fit(); + sortedRows_.reserve(minReadBatchRows); data_->clear(); + partitionStartRows_.clear(); + partitionStartRows_.reserve(minReadBatchRows); + partitionStartRows_.push_back(0); + currentPartition_ = -1; + numSpillReadBatches_++; + // Load at least #minReadBatchRows rows and a complete partition. The rows + // might contain multiple partitions. Record the partition boundaries as + // inMemory case. In this way, the logic of getting window partitions would be + // identical between inMemory and spill. for (;;) { auto next = merge_->next(); if (next == nullptr) { + partitionStartRows_.push_back(sortedRows_.size()); break; } @@ -324,7 +355,10 @@ void SortWindowBuild::loadNextPartitionFromSpill() { } if (newPartition) { - break; + partitionStartRows_.push_back(sortedRows_.size()); + if (sortedRows_.size() >= minReadBatchRows) { + break; + } } auto* newRow = data_->newRow(); @@ -334,16 +368,20 @@ void SortWindowBuild::loadNextPartitionFromSpill() { sortedRows_.push_back(newRow); next->pop(); } -} -std::shared_ptr SortWindowBuild::nextPartition() { - if (merge_ != nullptr) { - VELOX_CHECK(!sortedRows_.empty(), "No window partitions available"); - auto partition = folly::Range(sortedRows_.data(), sortedRows_.size()); - return std::make_shared( - data_.get(), partition, inversedInputChannels_, sortKeyInfo_); + // No more partition batches. All data is consumed. + if (sortedRows_.empty()) { + partitionStartRows_.clear(); + numSpillReadBatches_--; + + auto lockedOpStats = opStats_->wlock(); + lockedOpStats + ->runtimeStats[std::string(Window::kWindowSpillReadNumBatches)] = + RuntimeMetric(numSpillReadBatches_); } +} +std::shared_ptr SortWindowBuild::nextPartition() { VELOX_CHECK(!partitionStartRows_.empty(), "No window partitions available"); currentPartition_++; @@ -364,8 +402,7 @@ std::shared_ptr SortWindowBuild::nextPartition() { bool SortWindowBuild::hasNextPartition() { if (merge_ != nullptr) { - loadNextPartitionFromSpill(); - return !sortedRows_.empty(); + loadNextPartitionBatchFromSpill(); } return partitionStartRows_.size() > 0 && diff --git a/velox/exec/SortWindowBuild.h b/velox/exec/SortWindowBuild.h index 72875094007..47e51ad7979 100644 --- a/velox/exec/SortWindowBuild.h +++ b/velox/exec/SortWindowBuild.h @@ -16,6 +16,7 @@ #pragma once +#include "velox/common/file/FileSystems.h" #include "velox/exec/PrefixSort.h" #include "velox/exec/Spiller.h" #include "velox/exec/WindowBuild.h" @@ -32,7 +33,8 @@ class SortWindowBuild : public WindowBuild { common::PrefixSortConfig&& prefixSortConfig, const common::SpillConfig* spillConfig, tsan_atomic* nonReclaimableSection, - folly::Synchronized* spillStats); + folly::Synchronized* opStats, + exec::SpillStats* spillStats); ~SortWindowBuild() override { pool_->release(); @@ -45,9 +47,13 @@ class SortWindowBuild : public WindowBuild { void addInput(RowVectorPtr input) override; + void addDecodedInputRow( + std::vector& decodedInputVectors, + vector_size_t row); + void spill() override; - std::optional spilledStats() const override; + std::optional spilledStats() const override; void noMoreInput() override; @@ -55,9 +61,9 @@ class SortWindowBuild : public WindowBuild { std::shared_ptr nextPartition() override; - private: void ensureInputFits(const RowVectorPtr& input); + private: void ensureSortFits(); void setupSpiller(); @@ -75,8 +81,10 @@ class SortWindowBuild : public WindowBuild { // Find the next partition start row from start. vector_size_t findNextPartitionStartRow(vector_size_t start); - // Reads next partition from spilled data into 'data_' and 'sortedRows_'. - void loadNextPartitionFromSpill(); + // Load the next partition batch if needed. If current partition batch is not + // entirely consumed, return directly. Otherwise, read next partition batch + // from spilled data into 'data_' and set pointers in 'sortedRows_'. + void loadNextPartitionBatchFromSpill(); const size_t numPartitionKeys_; @@ -92,7 +100,9 @@ class SortWindowBuild : public WindowBuild { // Config for Prefix-sort. const common::PrefixSortConfig prefixSortConfig_; - folly::Synchronized* const spillStats_; + folly::Synchronized* const opStats_; + + exec::SpillStats* const spillStats_; // allKeyInfo_ is a combination of (partitionKeyInfo_ and sortKeyInfo_). // It is used to perform a full sorting of the input rows to be able to @@ -121,5 +131,8 @@ class SortWindowBuild : public WindowBuild { // Used to sort-merge spilled data. std::unique_ptr> merge_; + + // Number of batches of whole partitions read from spilled data. + uint64_t numSpillReadBatches_ = 0; }; } // namespace facebook::velox::exec diff --git a/velox/exec/SpatialIndex.cpp b/velox/exec/SpatialIndex.cpp new file mode 100644 index 00000000000..41adcff2ff2 --- /dev/null +++ b/velox/exec/SpatialIndex.cpp @@ -0,0 +1,179 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "velox/common/base/Exceptions.h" +#include "velox/exec/HilbertIndex.h" +#include "velox/exec/SpatialIndex.h" + +namespace facebook::velox::exec { + +std::vector RTreeLevel::query( + const Envelope& queryEnv, + const std::vector& branchIndices) const { + std::vector result; + + for (size_t branchIdx : branchIndices) { + size_t startIdx = branchIdx * branchSize_; + size_t endIdx = std::min(startIdx + branchSize_, minXs_.size()); + for (size_t idx = startIdx; idx < endIdx; ++idx) { + bool intersects = (queryEnv.maxX >= minXs_[idx]) && + (queryEnv.maxY >= minYs_[idx]) && (queryEnv.minX <= maxXs_[idx]) && + (queryEnv.minY <= maxYs_[idx]); + if (intersects) { + result.push_back(idx); + } + } + } + + return result; +} + +namespace { +std::pair> buildLevel( + uint32_t branchSize, + const std::vector& envelopes) { + std::vector minXs; + minXs.reserve(envelopes.size()); + std::vector minYs; + minYs.reserve(envelopes.size()); + std::vector maxXs; + maxXs.reserve(envelopes.size()); + std::vector maxYs; + maxYs.reserve(envelopes.size()); + + std::vector parentEnvelopes; + parentEnvelopes.reserve((envelopes.size() + branchSize - 1) / branchSize); + Envelope currentBounds = Envelope::empty(); + + uint32_t idx = 0; + for (const auto& env : envelopes) { + ++idx; + currentBounds.maxX = std::max(currentBounds.maxX, env.maxX); + currentBounds.maxY = std::max(currentBounds.maxY, env.maxY); + currentBounds.minX = std::min(currentBounds.minX, env.minX); + currentBounds.minY = std::min(currentBounds.minY, env.minY); + if (idx % branchSize == 0) { + parentEnvelopes.push_back(currentBounds); + currentBounds = Envelope::empty(); + } + + minXs.push_back(env.minX); + minYs.push_back(env.minY); + maxXs.push_back(env.maxX); + maxYs.push_back(env.maxY); + } + + if (!currentBounds.isEmpty()) { + parentEnvelopes.push_back(currentBounds); + } + + return { + RTreeLevel( + branchSize, + std::move(minXs), + std::move(minYs), + std::move(maxXs), + std::move(maxYs)), + std::move(parentEnvelopes)}; +} +} // namespace + +SpatialIndex::SpatialIndex( + Envelope bounds, + std::vector envelopes, + uint32_t branchSize) + : branchSize_(branchSize), bounds_(std::move(bounds)) { + VELOX_CHECK_GT(branchSize_, 1); + + if (!bounds_.isEmpty()) { + HilbertIndex hilbert( + bounds_.minX, bounds_.minY, bounds_.maxX, bounds_.maxY); + + std::sort( + envelopes.begin(), envelopes.end(), [&](const auto& a, const auto& b) { + return hilbert.indexOf(a.minX, a.minY) < + hilbert.indexOf(b.minX, b.minY); + }); + } + + rowIndices_.reserve(envelopes.size()); + for (const auto& env : envelopes) { + VELOX_CHECK(env.minX >= bounds_.minX); + VELOX_CHECK(env.minY >= bounds_.minY); + VELOX_CHECK(env.maxX <= bounds_.maxX); + VELOX_CHECK(env.maxY <= bounds_.maxY); + rowIndices_.push_back(env.rowIndex); + } + + if (envelopes.size() > 0) { + size_t numLevels = + std::ceil(std::log(envelopes.size()) / std::log(branchSize_)); + levels_.reserve(numLevels); + } + + while (envelopes.size() > branchSize_) { + auto [level, parentEnvelopes] = buildLevel(branchSize_, envelopes); + levels_.push_back(std::move(level)); + envelopes = std::move(parentEnvelopes); + } + + if (envelopes.size() > 1 || levels_.empty()) { + levels_.push_back(buildLevel(branchSize_, envelopes).first); + } + + VELOX_CHECK_GT(branchSize_ + 1, levels_.back().size()); +} + +std::vector SpatialIndex::query(const Envelope& queryEnv) const { + std::vector result; + if (!Envelope::intersects(queryEnv, bounds_)) { + return result; + } + + size_t thisLevel = levels_.size() - 1; + VELOX_CHECK_GT(levels_[thisLevel].size(), 0); + VELOX_CHECK_GT(branchSize_ + 1, levels_[thisLevel].size()); + + // The top level should have only one branch. + std::vector childIndices = {0}; + for (; thisLevel > 0; --thisLevel) { + // Avoiding thisLevel = 0 due to int underflow + childIndices = levels_[thisLevel].query(queryEnv, childIndices); + // If we have no matches, return. + if (childIndices.empty()) { + return result; + } + } + + // We're at level 0 now. The indices index into rowIndices. + VELOX_DCHECK_EQ(thisLevel, 0); + childIndices = levels_[thisLevel].query(queryEnv, childIndices); + result.reserve(childIndices.size()); + for (auto idx : childIndices) { + result.push_back(rowIndices_[idx]); + } + + return result; +} + +Envelope SpatialIndex::bounds() const { + return bounds_; +} + +} // namespace facebook::velox::exec diff --git a/velox/exec/SpatialIndex.h b/velox/exec/SpatialIndex.h new file mode 100644 index 00000000000..d6b2ad6d83b --- /dev/null +++ b/velox/exec/SpatialIndex.h @@ -0,0 +1,231 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include "velox/vector/TypeAliases.h" + +namespace facebook::velox::exec { + +/// A minimal envelope for a geometry. +/// It also includes an index for the geometry for later reference. This can +/// be -1 if the geometry is not indexed. +/// +/// Envelopes use float32s instead of float64s so that SIMD loops can be +/// twice as fast. Our geometries use float64 coordinates, so we have to +/// downcast them for the envelope. The loss of precision is theoretically fine +/// because the envelope checks are already approximate: either they don't +/// intersect or they might intersect. Thus, expanding the envelopes slightly +/// does not affect correctness (but it might affect efficiency slightly). +/// +/// We want to show that if two envelopes expressed with float64 precision would +/// intersect, the envelopes with float32 precision would also intersect. +/// +/// Define +/// ``` +/// nextUp(f) = std::nextafter(f, std::numeric_limits::infinity()) +/// nextDown(f) = std::nextafter(f, -std::numeric_limits::infinity()) +/// ``` +/// which move a float up or down one ulp (unit in the last place). +/// +/// Since the conditions are all of the form `maxX >= minX` for float64s maxX +/// and minX, we need to show that this implies `nextUp((float) maxX) >= +/// nextDown((float) minX)`. +/// +/// Assume you have a double `d` and two adjacent floats `f0` and `f1`, such +/// that `d` is "between" `f0` and `f1`: +/// +/// 1. `(double) f0 <= d <= (double) f1` +/// 2. `nextup(f0) == f1 && f0 = nextdown(f0)` +/// +/// This implies `nextdown((float) d) <= f0 && nextup((float) d) >= f1`. +/// +/// Let double `minX` have two adjacent floats `f0`, `f1` as above, and `maxX` +/// have two adjacent floats `g0`, `g1`. Then +/// ``` +/// (double) nextDown((float) minX) +/// <= (double) f0 +/// <= minX +/// <= maxX +/// <= (double) g1 +/// <= (double) nextUp((float) maxX) +/// ``` +/// +/// And this implies `nextDown((float) minX) <= nextUp((float) maxX)` as +/// desired. The same argument applies to all for members, so if we construct +/// the float32 precision envelope by applying nextDown to the minX/Ys and +/// nextUp to maxX/Ys, the float32 envelope intersects in all cases that the +/// float64 envelope would (but not necessarily the converse). +struct Envelope { + float minX{std::numeric_limits::infinity()}; + float minY{std::numeric_limits::infinity()}; + float maxX{-std::numeric_limits::infinity()}; + float maxY{-std::numeric_limits::infinity()}; + vector_size_t rowIndex = -1; + + /// Returns true if the intersection of two envelopes is not empty. + static inline bool intersects(const Envelope& left, const Envelope& right) { + return (left.maxX >= right.minX) && (left.minX <= right.maxX) && + (left.maxY >= right.minY) && (left.minY <= right.maxY); + } + + /// Returns true if the envelope contains at least one point. + /// An envelope of a point is not empty. + inline bool isEmpty() const { + // This negation handles NaNs correctly. + return !((minX <= maxX) && (minY <= maxY)); + } + + /// Expands this Envelope to also contain the other. + inline void merge(const Envelope& other) { + minX = std::min(minX, other.minX); + minY = std::min(minY, other.minY); + maxX = std::max(maxX, other.maxX); + maxY = std::max(maxY, other.maxY); + } + + /// Construct an empty envelope. + static constexpr inline Envelope empty() { + return Envelope{ + .minX = std::numeric_limits::infinity(), + .minY = std::numeric_limits::infinity(), + .maxX = -std::numeric_limits::infinity(), + .maxY = -std::numeric_limits::infinity()}; + } + + static constexpr inline Envelope from( + double minX, + double minY, + double maxX, + double maxY, + vector_size_t rowIndex = -1) { + return Envelope{ + .minX = std::nextafterf( + static_cast(minX), -std::numeric_limits::infinity()), + .minY = std::nextafterf( + static_cast(minY), -std::numeric_limits::infinity()), + .maxX = std::nextafterf( + static_cast(maxX), std::numeric_limits::infinity()), + .maxY = std::nextafterf( + static_cast(maxY), std::numeric_limits::infinity()), + .rowIndex = rowIndex}; + } + + static inline Envelope of(const std::vector& envelopes) { + Envelope result = Envelope::empty(); + for (const auto& envelope : envelopes) { + result.merge(envelope); + } + return result; + } +}; + +/// A single level of an R-tree. It is a set of envelopes that can be linearly +/// scanned for envelope intersection. +class RTreeLevel { + public: + RTreeLevel(const RTreeLevel&) = delete; + RTreeLevel& operator=(const RTreeLevel&) = delete; + + RTreeLevel() = default; + RTreeLevel(RTreeLevel&&) = default; + RTreeLevel& operator=(RTreeLevel&&) = default; + ~RTreeLevel() = default; + + explicit RTreeLevel( + size_t branchSize, + std::vector minXs, + std::vector minYs, + std::vector maxXs, + std::vector maxYs) + : branchSize_{branchSize}, + minXs_(std::move(minXs)), + minYs_(std::move(minYs)), + maxXs_(std::move(maxXs)), + maxYs_(std::move(maxYs)) {} + + /// Returns the internal indices of all envelopes that probeEnv intersects. + /// Order of the returned indices is an implementation detail and cannot be + /// relied upon. + /// This does not do a short-circuit bounds check: the caller should do that + /// first. + std::vector query( + const Envelope& queryEnv, + const std::vector& branchIndices) const; + + size_t size() const { + return minXs_.size(); + } + + private: + size_t branchSize_{}; + Envelope bounds_; + std::vector minXs_{}; + std::vector minYs_{}; + std::vector maxXs_{}; + std::vector maxYs_{}; +}; + +/// A spatial index for a set of geometries. The index only cares about the +/// envelopes of the geometries, and an index into the geometries (not stored in +/// SpatialIndex). +/// +/// The contract is that SpatialIndex::probe returns the indices of all +/// envelopes that probeEnv intersects. The form of the index is an +/// implementation detail. The order of the returned indicies is an +/// implementation detail. +class SpatialIndex { + public: + SpatialIndex(const SpatialIndex&) = delete; + SpatialIndex& operator=(const SpatialIndex&) = delete; + + SpatialIndex() = default; + SpatialIndex(SpatialIndex&&) = default; + SpatialIndex& operator=(SpatialIndex&&) = default; + ~SpatialIndex() = default; + + static const uint32_t kDefaultRTreeBranchSize = 32; + + /// Constructs a spatial index from envelopes contained with `bounds`. + /// `bounds` must contain all envelopes in `envelopes`, otherwise the + /// an assertio will fail. Envelopes should not contain NaN coordinates. + explicit SpatialIndex( + Envelope bounds, + std::vector envelopes, + uint32_t branchSize = kDefaultRTreeBranchSize); + + /// Returns the row indices of all envelopes that probeEnv intersects. + /// Order of the returned indices is an implementation detail and cannot be + /// relied upon. + std::vector query(const Envelope& queryEnv) const; + + /// Returns the envelope of the all envelopes in the index. + /// The returned envelope will have index = -1. + Envelope bounds() const; + + private: + uint32_t branchSize_ = kDefaultRTreeBranchSize; + + Envelope bounds_ = Envelope::empty(); + std::vector levels_{}; + std::vector rowIndices_{}; +}; + +} // namespace facebook::velox::exec diff --git a/velox/exec/SpatialJoinBuild.cpp b/velox/exec/SpatialJoinBuild.cpp index 240e3ba6854..238c9ba7b5b 100644 --- a/velox/exec/SpatialJoinBuild.cpp +++ b/velox/exec/SpatialJoinBuild.cpp @@ -13,30 +13,39 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #include "velox/exec/SpatialJoinBuild.h" +#include +#include "velox/common/geospatial/GeometryConstants.h" +#ifdef VELOX_ENABLE_GEO +#include "velox/common/geospatial/GeometrySerde.h" +#endif +#include "velox/exec/OperatorType.h" #include "velox/exec/Task.h" namespace facebook::velox::exec { -void SpatialJoinBridge::setData(std::vector buildVectors) { +using velox::common::geospatial::GeometrySerializationType; + +void SpatialJoinBridge::setData(SpatialBuildResult buildResult) { std::vector promises; { std::lock_guard l(mutex_); - VELOX_CHECK(!buildVectors_.has_value(), "setData must be called only once"); - buildVectors_ = std::move(buildVectors); + VELOX_CHECK(!buildResult_.has_value(), "setData must be called only once"); + buildResult_ = std::move(buildResult); promises = std::move(promises_); } notify(std::move(promises)); } -std::optional> SpatialJoinBridge::dataOrFuture( +std::optional SpatialJoinBridge::dataOrFuture( ContinueFuture* future) { std::lock_guard l(mutex_); VELOX_CHECK(!cancelled_, "Getting data after the build side is aborted"); - if (buildVectors_.has_value()) { - return buildVectors_; + if (buildResult_.has_value()) { + return buildResult_.value(); } - promises_.emplace_back("SpatialJoinBridge::tableOrFuture"); + promises_.emplace_back("SpatialJoinBridge::dataOrFuture"); *future = promises_.back().getSemiFuture(); return std::nullopt; } @@ -50,7 +59,20 @@ SpatialJoinBuild::SpatialJoinBuild( nullptr, operatorId, joinNode->id(), - "SpatialJoinBuild") {} + OperatorType::kSpatialJoinBuild) { + const auto& buildType = joinNode->rightNode()->outputType(); + buildGeometryChannel_ = + buildType->getChildIdx(joinNode->buildGeometry()->name()); + VELOX_CHECK_EQ( + buildType->childAt(buildGeometryChannel_), + joinNode->buildGeometry()->type()); + if (joinNode->radius().has_value()) { + auto radiusVar = joinNode->radius().value(); + uint32_t radiusChannel = buildType->getChildIdx(radiusVar->name()); + VELOX_CHECK_EQ(buildType->childAt(radiusChannel), radiusVar->type()); + radiusChannel_ = radiusChannel; + } +} void SpatialJoinBuild::addInput(RowVectorPtr input) { if (input->size() > 0) { @@ -106,6 +128,88 @@ std::vector SpatialJoinBuild::mergeDataVectors() const { return merged; } +Envelope SpatialJoinBuild::readEnvelope( + const StringView& geometryBytes, + double radius) { +#ifdef VELOX_ENABLE_GEO + radius = std::max(radius, 0.0); + auto geosEnvelope = + common::geospatial::GeometryDeserializer::deserializeEnvelope( + geometryBytes); + if (geosEnvelope->isNull()) { + return Envelope::empty(); + } else { + return Envelope::from( + geosEnvelope->getMinX() - radius, + geosEnvelope->getMinY() - radius, + geosEnvelope->getMaxX() + radius, + geosEnvelope->getMaxY() + radius); + } +#else + // When VELOX_ENABLE_GEO is not set, return an envelope of infinite area + // to ensure all geometries are considered for spatial join + return Envelope::from( + -std::numeric_limits::infinity(), + -std::numeric_limits::infinity(), + std::numeric_limits::infinity(), + std::numeric_limits::infinity()); +#endif +} + +SpatialIndex SpatialJoinBuild::buildSpatialIndex( + const std::vector& data, + column_index_t geometryIdx, + std::optional radiusIdx) { + size_t numRows = 0; + for (auto& vector : data) { + numRows += vector->size(); + } + std::vector envelopes; + // TODO: Chunk the data to avoid allocating a large vector. + envelopes.reserve(numRows); + + DecodedVector radiusCol; + DecodedVector geometryCol; + vector_size_t offset = 0; + Envelope bounds = Envelope::empty(); + + for (auto& vector : data) { + const auto& rawGeometryCol = + vector->childAt(geometryIdx)->asChecked>(); + geometryCol.decode(*rawGeometryCol); + + auto constantZero = velox::BaseVector::createConstant( + velox::DOUBLE(), 0.0, vector->size(), pool()); + if (radiusIdx.has_value()) { + const auto& rawRadiusCol = + vector->childAt(radiusIdx.value())->asChecked>(); + radiusCol.decode(*rawRadiusCol); + } else { + radiusCol.decode(*constantZero); + } + + // TODO: Make a selectivity vector based on nulls and use for DecodedVector. + for (vector_size_t i = 0; i < vector->size(); ++i) { + if (geometryCol.isNullAt(i) || radiusCol.isNullAt(i)) { + // If geometry or radius is null, it will not match the predicate and so + // we should skip the envelope. + continue; + } + double radius = radiusCol.valueAt(i); + const StringView geometryBytes = geometryCol.valueAt(i); + Envelope envelope = SpatialJoinBuild::readEnvelope(geometryBytes, radius); + if (FOLLY_UNLIKELY(envelope.isEmpty())) { + continue; + } + envelope.rowIndex = offset + geometryCol.index(i); + bounds.merge(envelope); + envelopes.push_back(std::move(envelope)); + } + offset += vector->size(); + } + return SpatialIndex(std::move(bounds), std::move(envelopes)); +} + void SpatialJoinBuild::noMoreInput() { Operator::noMoreInput(); std::vector promises; @@ -142,10 +246,17 @@ void SpatialJoinBuild::noMoreInput() { } dataVectors_ = mergeDataVectors(); + SpatialIndex spatialIndex = + buildSpatialIndex(dataVectors_, buildGeometryChannel_, radiusChannel_); + SpatialBuildResult buildResult; + buildResult.spatialIndex = + std::make_shared(std::move(spatialIndex)); + buildResult.buildVectors = std::move(dataVectors_); + operatorCtx_->task() ->getSpatialJoinBridge( operatorCtx_->driverCtx()->splitGroupId, planNodeId()) - ->setData(std::move(dataVectors_)); + ->setData(std::move(buildResult)); } bool SpatialJoinBuild::isFinished() { diff --git a/velox/exec/SpatialJoinBuild.h b/velox/exec/SpatialJoinBuild.h index bea8fdd354b..daceada9f29 100644 --- a/velox/exec/SpatialJoinBuild.h +++ b/velox/exec/SpatialJoinBuild.h @@ -18,17 +18,23 @@ #include "velox/core/PlanNode.h" #include "velox/exec/JoinBridge.h" #include "velox/exec/Operator.h" +#include "velox/exec/SpatialIndex.h" namespace facebook::velox::exec { +struct SpatialBuildResult { + std::vector buildVectors; + std::shared_ptr spatialIndex; +}; + class SpatialJoinBridge : public JoinBridge { public: - void setData(std::vector buildVectors); + void setData(SpatialBuildResult buildResult); - std::optional> dataOrFuture(ContinueFuture* future); + std::optional dataOrFuture(ContinueFuture* future); private: - std::optional> buildVectors_; + std::optional buildResult_; }; class SpatialJoinBuild : public Operator { @@ -59,11 +65,25 @@ class SpatialJoinBuild : public Operator { Operator::close(); } - std::vector mergeDataVectors() const; + static Envelope readEnvelope( + const StringView& serializedGeometry, + double radius); private: + std::vector mergeDataVectors() const; + + SpatialIndex buildSpatialIndex( + const std::vector& data, + column_index_t geometryIdx, + std::optional radiusIdx); + std::vector dataVectors_; + // Channel of geometry variable used to build spatial index + column_index_t buildGeometryChannel_; + // Channel (if set) of radius variable used to build spatial index + std::optional radiusChannel_{}; + // Future for synchronizing with other Drivers of the same pipeline. All build // Drivers must be completed before making data available for the probe side. ContinueFuture future_{ContinueFuture::makeEmpty()}; diff --git a/velox/exec/SpatialJoinProbe.cpp b/velox/exec/SpatialJoinProbe.cpp index bc49d24cba4..7fbdc6f519d 100644 --- a/velox/exec/SpatialJoinProbe.cpp +++ b/velox/exec/SpatialJoinProbe.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ #include "velox/exec/SpatialJoinProbe.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/OperatorUtils.h" #include "velox/exec/SpatialJoinBuild.h" #include "velox/exec/Task.h" @@ -42,6 +43,107 @@ std::vector extractProjections( } // namespace +////////////////// +// OUTPUT BUILDER + +void SpatialJoinOutputBuilder::initializeOutput( + const RowVectorPtr& input, + memory::MemoryPool* pool) { + if (output_ == nullptr) { + output_ = + BaseVector::create(outputType_, outputBatchSize_, pool); + } else { + VectorPtr outputVector = std::move(output_); + BaseVector::prepareForReuse(outputVector, outputBatchSize_); + output_ = std::static_pointer_cast(outputVector); + } + probeOutputIndices_ = allocateIndices(outputBatchSize_, pool); + rawProbeOutputIndices_ = probeOutputIndices_->asMutable(); + + // Add probe side projections as dictionary vectors + for (const auto& projection : probeProjections_) { + output_->childAt(projection.outputChannel) = wrapChild( + outputBatchSize_, + probeOutputIndices_, + input->childAt(projection.inputChannel)); + } + + // Add build side projections as uninitialized vectors + for (const auto& projection : buildProjections_) { + auto child = output_->childAt(projection.outputChannel); + if (child == nullptr) { + child = BaseVector::create( + outputType_->childAt(projection.outputChannel), + outputBatchSize_, + operatorCtx_.pool()); + } + } +} + +void SpatialJoinOutputBuilder::addOutputRow( + vector_size_t probeRow, + vector_size_t buildRow) { + VELOX_CHECK_NOT_NULL(probeOutputIndices_); + // Probe side is always a dictionary; just populate the index. + rawProbeOutputIndices_[outputRow_] = probeRow; + + // For the build side, we accumulate the ranges to copy, then copy all of + // them at once. Consecutive records are copied in one memcpy. + if (!buildCopyRanges_.empty() && + (buildCopyRanges_.back().sourceIndex + buildCopyRanges_.back().count) == + buildRow) { + ++buildCopyRanges_.back().count; + } else { + buildCopyRanges_.push_back({buildRow, outputRow_, 1}); + } + ++outputRow_; +} + +void SpatialJoinOutputBuilder::copyBuildValues( + const RowVectorPtr& buildVector) { + if (buildCopyRanges_.empty()) { + return; + } + + VELOX_CHECK_NOT_NULL(output_); + + for (const auto& projection : buildProjections_) { + const auto& buildChild = buildVector->childAt(projection.inputChannel); + const auto& outputChild = output_->childAt(projection.outputChannel); + outputChild->copyRanges(buildChild.get(), buildCopyRanges_); + } + buildCopyRanges_.clear(); +} + +void SpatialJoinOutputBuilder::addProbeMismatchRow(vector_size_t probeRow) { + VELOX_CHECK_NOT_NULL(output_); + + // Probe side is always a dictionary; just populate the index. + rawProbeOutputIndices_[outputRow_] = probeRow; + + // Null out build projections. + for (const auto& projection : buildProjections_) { + const auto& outputChild = output_->childAt(projection.outputChannel); + outputChild->setNull(outputRow_, true); + } + ++outputRow_; +} + +RowVectorPtr SpatialJoinOutputBuilder::takeOutput() { + VELOX_CHECK(buildCopyRanges_.empty()); + if (outputRow_ == 0 || !output_) { + return nullptr; + } + RowVectorPtr output = std::move(output_); + output->resize(outputRow_); + output_ = nullptr; + outputRow_ = 0; + return output; +} + +//////////////////// +// SpatialJoinProbe + SpatialJoinProbe::SpatialJoinProbe( int32_t operatorId, DriverCtx* driverCtx, @@ -51,25 +153,42 @@ SpatialJoinProbe::SpatialJoinProbe( joinNode->outputType(), operatorId, joinNode->id(), - "SpatialJoinProbe"), + OperatorType::kSpatialJoinProbe), joinType_(joinNode->joinType()), outputBatchSize_{outputBatchRows()}, - joinNode_(joinNode) { - auto probeType = joinNode_->sources()[0]->outputType(); - auto buildType = joinNode_->sources()[1]->outputType(); + joinNode_(joinNode), + buildProjections_(extractProjections( + joinNode_->rightNode()->outputType(), + outputType_)), + outputBuilder_{ + outputBatchSize_, + outputType_, + extractProjections( + joinNode_->leftNode()->outputType(), + outputType_), // these are the identity Projections + buildProjections_, + *operatorCtx_} { + auto probeType = joinNode_->leftNode()->outputType(); identityProjections_ = extractProjections(probeType, outputType_); - buildProjections_ = extractProjections(buildType, outputType_); + probeGeometryChannel_ = + probeType->getChildIdx(joinNode_->probeGeometry()->name()); + VELOX_CHECK_EQ( + probeType->childAt(probeGeometryChannel_), + joinNode_->probeGeometry()->type()); } +///////// +// SETUP + void SpatialJoinProbe::initialize() { Operator::initialize(); - VELOX_CHECK(joinNode_ != nullptr); + VELOX_CHECK_NOT_NULL(joinNode_); if (joinNode_->joinCondition() != nullptr) { initializeFilter( joinNode_->joinCondition(), - joinNode_->sources()[0]->outputType(), - joinNode_->sources()[1]->outputType()); + joinNode_->leftNode()->outputType(), + joinNode_->rightNode()->outputType()); } joinNode_.reset(); @@ -151,32 +270,15 @@ void SpatialJoinProbe::close() { joinCondition_->clear(); } buildVectors_.reset(); + spatialIndex_.reset(); Operator::close(); } -void SpatialJoinProbe::addInput(RowVectorPtr input) { - VELOX_CHECK_NULL(input_); - - // In getOutput(), we are going to wrap input in dictionaries a few rows at a - // time. Since lazy vectors cannot be wrapped in different dictionaries, we - // are going to load them here. - for (auto& child : input->children()) { - child->loadedVector(); - } - input_ = std::move(input); - if (input_->size() > 0) { - probeSideEmpty_ = false; - } - VELOX_CHECK_EQ(buildIndex_, 0); -} - void SpatialJoinProbe::noMoreInput() { Operator::noMoreInput(); - if (state_ != ProbeOperatorState::kRunning || input_ != nullptr) { - return; + if (state_ == ProbeOperatorState::kRunning && input_ == nullptr) { + setState(ProbeOperatorState::kFinish); } - setState(ProbeOperatorState::kFinish); - return; } bool SpatialJoinProbe::getBuildData(ContinueFuture* future) { @@ -191,15 +293,56 @@ bool SpatialJoinProbe::getBuildData(ContinueFuture* future) { return false; } - buildVectors_ = std::move(buildData); + buildVectors_ = buildData.value().buildVectors; + spatialIndex_ = buildData.value().spatialIndex; return true; } +void SpatialJoinProbe::checkStateTransition(ProbeOperatorState state) { + VELOX_CHECK_NE(state_, state); + switch (state) { + case ProbeOperatorState::kRunning: + VELOX_CHECK_EQ(state_, ProbeOperatorState::kWaitForBuild); + break; + case ProbeOperatorState::kWaitForBuild: + [[fallthrough]]; + case ProbeOperatorState::kFinish: + VELOX_CHECK_EQ(state_, ProbeOperatorState::kRunning); + break; + default: + VELOX_UNREACHABLE(probeOperatorStateName(state_)); + break; + } +} + +//////////////// +// INPUT/OUTPUT + +void SpatialJoinProbe::addInput(RowVectorPtr input) { + VELOX_CHECK_NULL(input_); + VELOX_CHECK_EQ(probeRow_, 0); + VELOX_CHECK(!probeHasMatch_); + VELOX_CHECK_EQ(buildVectorIndex_, 0); + VELOX_CHECK_EQ(candidateIndex_, 0); + + // In getOutput(), we are going to wrap input in dictionaries a few rows at a + // time. Since lazy vectors cannot be wrapped in different dictionaries, we + // are going to load them here. + for (auto& child : input->children()) { + child->loadedVector(); + } + input_ = std::move(input); + decodedGeometryCol_.decode(*input_->childAt(probeGeometryChannel_) + ->asChecked>()); + ++probeCount_; +} + RowVectorPtr SpatialJoinProbe::getOutput() { if (state_ == ProbeOperatorState::kFinish || state_ == ProbeOperatorState::kWaitForPeers) { return nullptr; } + RowVectorPtr output{nullptr}; while (output == nullptr) { // Need more input. @@ -207,155 +350,156 @@ RowVectorPtr SpatialJoinProbe::getOutput() { break; } + // If the task owning this operator isn't running, there is no point + // to continue executing this procedure, which may be long in degenerate + // cases. Exit the working loop and let the Driver handle exiting + // gracefully in its own loop. + if (!operatorCtx_->task()->isRunning()) { + break; + } + + if (shouldYield()) { + break; + } + // Generate actual join output by processing probe and build matches, and // probe mismaches (for left joins). output = generateOutput(); } + + if (output != nullptr) { + ++outputCount_; + } return output; } RowVectorPtr SpatialJoinProbe::generateOutput() { - // If addToOutput() returns false, output_ is filled. Need to produce it. - if (!addToOutput()) { - VELOX_CHECK_GT(output_->size(), 0); - return std::move(output_); + VELOX_CHECK_NOT_NULL(input_); + VELOX_CHECK_GT(input_->size(), probeRow_); + outputBuilder_.initializeOutput(input_, pool()); + + while (!isOutputDone()) { + // Fill output_ with the results from one row. This may produce too + // much output and only partially complete. If so, the next time we + // call this we'll get the next chunk. + // + // addProbeRowOutput is responsible for advancing probeRow_. + addProbeRowOutput(); } - // Try to advance the probe cursor; call finish if no more probe input. - if (advanceProbe()) { + // If we've exhausted the input, release it. + if (probeRow_ >= input_->size()) { finishProbeInput(); - if (numOutputRows_ == 0) { - // output_ can only be re-used across probe rows within the same input_. - // Here we have to abandon the emtpy non-null output_ before we advance to - // the next probe input. - output_ = nullptr; - } } - if (!readyToProduceOutput()) { - return nullptr; - } - - output_->resize(numOutputRows_); - return std::move(output_); + return outputBuilder_.takeOutput(); } -bool SpatialJoinProbe::readyToProduceOutput() { - if (!output_ || numOutputRows_ == 0) { - return false; - } +// Return true if adding output stops early because output is full. +void SpatialJoinProbe::addProbeRowOutput() { + VELOX_CHECK(buildVectors_.has_value()); + VELOX_CHECK(!outputBuilder_.isOutputFull()); - // If the input_ has no remaining rows or the output_ is fully filled, - // it's right time for output. - return !input_ || numOutputRows_ >= outputBatchSize_; -} - -bool SpatialJoinProbe::advanceProbe() { - if (hasProbedAllBuildData()) { - probeRow_ += 1; - probeRowHasMatch_ = false; - buildIndex_ = 0; - - // If we finished processing the probe side. - if (probeRow_ >= input_->size()) { - return true; - } + // Find the candidates for each probe row from the spatial index. Only do + // this at the start for each row. + if (buildVectorIndex_ == 0 && candidateIndex_ == 0) { + candidateBuildRows_ = querySpatialIndex(); } - return false; -} -bool SpatialJoinProbe::addToOutput() { - VELOX_CHECK_NOT_NULL(input_); - prepareOutput(); - - while (!hasProbedAllBuildData()) { - const auto& currentBuild = buildVectors_.value()[buildIndex_]; - - // Empty build vector; move to the next. - if (currentBuild->size() == 0) { - ++buildIndex_; - buildRow_ = 0; - continue; - } - - // Only re-calculate the filter if we have a new build vector. - if (buildRow_ == 0) { - evaluateSpatialJoinFilter(currentBuild); + while (!isProbeRowDone()) { + addBuildVectorOutput(buildVectors_.value()[buildVectorIndex_]); + if (outputBuilder_.isOutputFull()) { + // If full, don't advance buildVectorIndex_ because we may not have + // exhausted the current vector. Return instead of breaking so that we + // can add a mismatch row later if necessary. + return; } + advanceBuildVector(); + } - // Iterate over the filter results. For each match, add an output record. - for (vector_size_t i = buildRow_; i < decodedFilterResult_.size(); ++i) { - if (!isSpatialJoinConditionMatch(i)) { - continue; - } + // Now that we have finished the probe row, check if we need to add a probe + // mismatch record. + if (!probeHasMatch_ && needsProbeMismatch(joinType_)) { + outputBuilder_.addProbeMismatchRow(probeRow_); + } + // Advance here instead of the loop in generateOutput so that early return on + // full doesn't advance the probe. + advanceProbeRow(); +} - addOutputRow(i); - ++numOutputRows_; - probeRowHasMatch_ = true; +void SpatialJoinProbe::addBuildVectorOutput(const RowVectorPtr& buildVector) { + if (FOLLY_UNLIKELY(needsFilterEvaluated_)) { + // Evaluate join filter for the whole vector just once. + evaluateJoinFilter(buildVector); + needsFilterEvaluated_ = false; + } - // If the buffer is full, save state and produce it as output. - if (numOutputRows_ == outputBatchSize_) { - buildRow_ = i + 1; - copyBuildValues(currentBuild); - return false; - } + // Start where we left off: after the last buildRow_ that was processed. + while (!isBuildVectorDone(buildVector)) { + vector_size_t buildRow = relativeBuildRow(candidateIndex_); + if (isJoinConditionMatch(candidateIndex_)) { + outputBuilder_.addOutputRow(probeRow_, buildRow); + probeHasMatch_ = true; } - // Before moving to the next build vector, copy the needed ranges. - copyBuildValues(currentBuild); - ++buildIndex_; - buildRow_ = 0; + // Advance candidateIndex_ even if full, since we're finished with this row. + ++candidateIndex_; } - // Check if the current probed row needs to be added as a mismatch (for left - // and full outer joins). - checkProbeMismatchRow(); - - // Signals that all input has been generated for the probeRow and build - // vectors; safe to move to the next probe record. - return true; + // Since we are copying from the current buildVector, we must copy here. + outputBuilder_.copyBuildValues(buildVector); } -void SpatialJoinProbe::prepareOutput() { - if (output_ != nullptr) { - return; +std::vector SpatialJoinProbe::querySpatialIndex() { + VELOX_CHECK(spatialIndex_.has_value()); + VELOX_CHECK_NOT_NULL(spatialIndex_.value()); + + if (decodedGeometryCol_.isNullAt(probeRow_)) { + return std::vector{}; } - std::vector localColumns(outputType_->size()); + // Always apply radius to build side, not probe side. + Envelope envelope = SpatialJoinBuild::readEnvelope( + decodedGeometryCol_.valueAt(probeRow_), 0 /* radius */); + std::vector candidates = spatialIndex_.value()->query(envelope); + std::sort(candidates.begin(), candidates.end()); - probeOutputIndices_ = allocateIndices(outputBatchSize_, pool()); - rawProbeOutputIndices_ = probeOutputIndices_->asMutable(); + return candidates; +} - for (const auto& projection : identityProjections_) { - localColumns[projection.outputChannel] = BaseVector::wrapInDictionary( - {}, - probeOutputIndices_, - outputBatchSize_, - input_->childAt(projection.inputChannel)); +BufferPtr SpatialJoinProbe::makeBuildVectorIndices(vector_size_t vectorSize) { + // Find the slice of candidates that are in this build vector. + vector_size_t endIndex = candidateIndex_; + for (; endIndex < candidateBuildRows_.size(); ++endIndex) { + if (relativeBuildRow(endIndex) >= vectorSize) { + break; + } } - // For other join types, add build side projections - for (const auto& projection : buildProjections_) { - localColumns[projection.outputChannel] = BaseVector::create( - outputType_->childAt(projection.outputChannel), - outputBatchSize_, - operatorCtx_->pool()); + // Make an index vector to fit the candidates. Populate each entry with its + // relative build row. + vector_size_t indexCount = + static_cast(endIndex - candidateIndex_); + auto rowIndices = allocateIndices(indexCount, operatorCtx_->pool()); + auto rawIndices = rowIndices->asMutable(); + for (vector_size_t idx = 0; idx < indexCount; ++idx) { + rawIndices[idx] = relativeBuildRow(idx + candidateIndex_); } - numOutputRows_ = 0; - output_ = std::make_shared( - pool(), outputType_, nullptr, outputBatchSize_, std::move(localColumns)); + return rowIndices; } -void SpatialJoinProbe::evaluateSpatialJoinFilter( - const RowVectorPtr& buildVector) { - // First step to process is to get a batch so we can evaluate the join - // filter. - auto filterInput = getNextCrossProductBatch( +void SpatialJoinProbe::evaluateJoinFilter(const RowVectorPtr& buildVector) { + // Get the indices of the rows in the build vector that are candidates. + auto candidateRowsBuffer = makeBuildVectorIndices(buildVector->size()); + + // Now get the input for the spatial join filter, one row per candidate. + auto filterInput = getNextJoinBatch( buildVector, filterInputType_, filterProbeProjections_, - filterBuildProjections_); + filterBuildProjections_, + candidateRowsBuffer); if (filterInputRows_.size() != filterInput->size()) { filterInputRows_.resizeFill(filterInput->size(), true); @@ -371,28 +515,28 @@ void SpatialJoinProbe::evaluateSpatialJoinFilter( decodedFilterResult_.decode(*filterOutput_, filterInputRows_); } -RowVectorPtr SpatialJoinProbe::getNextCrossProductBatch( +RowVectorPtr SpatialJoinProbe::getNextJoinBatch( const RowVectorPtr& buildVector, const RowTypePtr& outputType, const std::vector& probeProjections, - const std::vector& buildProjections) { + const std::vector& buildProjections, + BufferPtr candidateRows) const { VELOX_CHECK_GT(buildVector->size(), 0); + // candidateRows is a buffer of vector_size_t indices into buildVector + const vector_size_t numOutputRows = + candidateRows->size() / sizeof(vector_size_t); + if (numOutputRows == 0) { + return RowVector::createEmpty(outputType, pool()); + } - return genCrossProductMultipleBuildVectors( - buildVector, outputType, probeProjections, buildProjections); -} - -RowVectorPtr SpatialJoinProbe::genCrossProductMultipleBuildVectors( - const RowVectorPtr& buildVector, - const RowTypePtr& outputType, - const std::vector& probeProjections, - const std::vector& buildProjections) { std::vector projectedChildren(outputType->size()); - const vector_size_t numOutputRows = buildVector->size(); - // Project columns from the build side. projectChildren( - projectedChildren, buildVector, buildProjections, numOutputRows, nullptr); + projectedChildren, + buildVector, + buildProjections, + numOutputRows, + candidateRows); // Wrap projections from the probe side as constants. for (const auto [inputChannel, outputChannel] : probeProjections) { @@ -404,68 +548,14 @@ RowVectorPtr SpatialJoinProbe::genCrossProductMultipleBuildVectors( pool(), outputType, nullptr, numOutputRows, std::move(projectedChildren)); } -void SpatialJoinProbe::addOutputRow(vector_size_t buildRow) { - // Probe side is always a dictionary; just populate the index. - rawProbeOutputIndices_[numOutputRows_] = probeRow_; - - // For the build side, we accumulate the ranges to copy, then copy all of them - // at once. If records are consecutive and can have a single copy range run. - if (!buildCopyRanges_.empty() && - (buildCopyRanges_.back().sourceIndex + buildCopyRanges_.back().count) == - buildRow) { - ++buildCopyRanges_.back().count; - } else { - buildCopyRanges_.push_back({buildRow, numOutputRows_, 1}); - } -} - -void SpatialJoinProbe::copyBuildValues(const RowVectorPtr& buildVector) { - if (buildCopyRanges_.empty() || isLeftSemiProjectJoin(joinType_)) { - return; - } - - for (const auto& projection : buildProjections_) { - const auto& buildChild = buildVector->childAt(projection.inputChannel); - const auto& outputChild = output_->childAt(projection.outputChannel); - outputChild->copyRanges(buildChild.get(), buildCopyRanges_); - } - buildCopyRanges_.clear(); -} - -void SpatialJoinProbe::checkProbeMismatchRow() { - // If we are processing the last batch of the build side, check if we need - // to add a probe mismatch record. - if (needsProbeMismatch(joinType_) && hasProbedAllBuildData() && - !probeRowHasMatch_) { - prepareOutput(); - addProbeMismatchRow(); - ++numOutputRows_; - } -} - -void SpatialJoinProbe::addProbeMismatchRow() { - // Probe side is always a dictionary; just populate the index. - rawProbeOutputIndices_[numOutputRows_] = probeRow_; - - // Null out build projections. - for (const auto& projection : buildProjections_) { - const auto& outputChild = output_->childAt(projection.outputChannel); - outputChild->setNull(numOutputRows_, true); - } -} - void SpatialJoinProbe::finishProbeInput() { VELOX_CHECK_NOT_NULL(input_); input_.reset(); - buildIndex_ = 0; probeRow_ = 0; - if (!noMoreInput_) { - return; + if (noMoreInput_) { + setState(ProbeOperatorState::kFinish); } - - setState(ProbeOperatorState::kFinish); - return; } } // namespace facebook::velox::exec diff --git a/velox/exec/SpatialJoinProbe.h b/velox/exec/SpatialJoinProbe.h index 54b2cbe701d..e9fc2eea13a 100644 --- a/velox/exec/SpatialJoinProbe.h +++ b/velox/exec/SpatialJoinProbe.h @@ -18,16 +18,71 @@ #include "velox/core/PlanNode.h" #include "velox/exec/Operator.h" #include "velox/exec/ProbeOperatorState.h" +#include "velox/exec/SpatialIndex.h" namespace facebook::velox::exec { +class SpatialJoinOutputBuilder { + public: + SpatialJoinOutputBuilder( + vector_size_t outputBatchSize, + RowTypePtr outputType, + std::vector probeProjections, + std::vector buildProjections, + const OperatorCtx& operatorCtx) + : outputBatchSize_{outputBatchSize}, + outputType_{std::move(outputType)}, + probeProjections_{std::move(probeProjections)}, + buildProjections_{std::move(buildProjections)}, + operatorCtx_{operatorCtx} { + VELOX_CHECK_GT(outputBatchSize_, 0); + } + + void initializeOutput(const RowVectorPtr& input, memory::MemoryPool* pool); + + bool isOutputFull() const { + return outputRow_ >= outputBatchSize_; + } + + void addOutputRow(vector_size_t probeRow, vector_size_t buildRow); + + /// Checks if it is required to add a probe mismatch row, and does it if + /// needed. The caller needs to ensure there is available space in `output_` + /// for the new record, which has nulled out build projections. + void addProbeMismatchRow(vector_size_t probeRow); + + void copyBuildValues(const RowVectorPtr& buildVector); + + RowVectorPtr takeOutput(); + + private: + // Initialization parameters + const vector_size_t outputBatchSize_; + const RowTypePtr outputType_; + const std::vector probeProjections_; + const std::vector buildProjections_; + const OperatorCtx& operatorCtx_; + + // Output state + RowVectorPtr output_; + vector_size_t outputRow_{0}; + // Dictionary indices for probe columns for output vector. + BufferPtr probeOutputIndices_; + // Mutable pointer to probeOutputIndices_ + vector_size_t* rawProbeOutputIndices_{}; + + // Stores the ranges of build values to be copied to the output vector (we + // batch them and copy once, instead of copying them row-by-row). + std::vector buildCopyRanges_{}; +}; + /// Implements a Spatial Join between records from the probe (input_) -/// and build (SpatialJoinBridge) sides. It supports inner, left, right and -/// full outer joins. +/// and build (SpatialJoinBridge) sides. It supports inner and left joins. /// /// This class is designed to evaluate spatial join conditions (e.g. -/// ST_INTERSECTS, ST_CONTAINS, ST_WITHIN) between geometric data types. It can -/// also implement spatial cross-join semantics if joinCondition is nullptr. +/// ST_INTERSECTS, ST_CONTAINS, ST_WITHIN) between geometric data types. It +/// can also implement spatial cross-join semantics if joinCondition is +/// nullptr. /// /// The output follows the order of the probe side rows (for inner and left /// joins). All build vectors are materialized upfront (check buildVectors_), @@ -39,11 +94,14 @@ namespace facebook::velox::exec { /// 1. Materialize a cross-product batch across probe and build. /// 2. Evaluate the spatial join condition. /// 3. Add spatial matches to the output. -/// 4. Once all build vectors are processed for a particular probe row, check if +/// 4. Once all build vectors are processed for a particular probe row, check +/// if /// a probe mismatch is needed (only for left and full outer joins). -/// 5. Once all probe and build inputs are processed, check if build mismatches +/// 5. Once all probe and build inputs are processed, check if build +/// mismatches /// are needed (only for right and full outer joins). -/// 6. If so, signal other peer operators; only a single operator instance will +/// 6. If so, signal other peer operators; only a single operator instance +/// will /// collect all build matches at the end, and emit any records that haven't /// been matched by any of the peers. /// @@ -79,6 +137,13 @@ class SpatialJoinProbe : public Operator { void close() override; private: + void checkStateTransition(ProbeOperatorState state); + + void setState(ProbeOperatorState state) { + checkStateTransition(state); + state_ = state; + } + // Initialize spatial filter for evaluating spatial join conditions. void initializeFilter( const core::TypedExprPtr& filter, @@ -91,94 +156,14 @@ class SpatialJoinProbe : public Operator { // `buildVectors_` before it can produce output. bool getBuildData(ContinueFuture* future); - // Generates output from spatial join matches between probe and build sides, - // as well as probe mismatches (for left and full outer joins). As much as - // possible, generates outputs `outputBatchSize_` records at a time, but - // batches may be smaller in some cases - outputs follow the probe side buffer - // boundaries. + // Produce as much output as possible for the current input. RowVectorPtr generateOutput(); - // For non cross-join mode, the `output_` can be reused across multiple probe - // rows. If the input_ has remaining rows and the output_ is not fully filled, - // it returns false here. - bool readyToProduceOutput(); - - // Fill in joined output to `output_` by matching the current probeRow_ and - // successive build vectors (using getNextCrossProductBatch()). Stops when - // either all build vectors were matched for the current probeRow (returns - // true), or if the output is full (returns false). If it returns false, a - // valid vector with more than zero records will be available at `output_`; - // if it returns true, either nullptr or zero records may be placed at - // `output_`. Also if it returns true, it's the caller's responsibility to - // decide when to set `output_` size. - // - // Also updates `buildMatched_` if the build records that received a match, so - // that they can be used to implement right and full outer join semantic once - // all probe data has been processed. - bool addToOutput(); - - // Advances 'probeRow_' and resets required state information. Returns true - // if there is no more probe data to be processed in the current `input_` - // (and hence a new probe input is required). False otherwise. - bool advanceProbe(); - - // Ensures a new batch of records is available at `output_` and ready to - // receive rows. Batches have space for `outputBatchSize_`. - void prepareOutput(); - - // Evaluates the spatial joinCondition for a given build vector. This method - // sets `filterOutput_` and `decodedFilterResult_`, which will be ready to be - // used by `isSpatialJoinConditionMatch(buildRow)` below. - void evaluateSpatialJoinFilter(const RowVectorPtr& buildVector); - - // Checks if the spatial join condition matched for a particular row. - bool isSpatialJoinConditionMatch(vector_size_t i) const { - return ( - !decodedFilterResult_.isNullAt(i) && - decodedFilterResult_.valueAt(i)); + // Returns true if the input is exhausted or the output is full. + bool isOutputDone() const { + return probeRow_ >= input_->size() || outputBuilder_.isOutputFull(); } - // Generates the next batch of a cross product between probe and build. It - // should be used as the entry point, and will internally delegate to one of - // the three functions below. - // - // Output projections can be specified so that this function can be used to - // generate both filter input and actual output (in case there is no join - // filter - cross join). - RowVectorPtr getNextCrossProductBatch( - const RowVectorPtr& buildVector, - const RowTypePtr& outputType, - const std::vector& probeProjections, - const std::vector& buildProjections); - - // As a fallback, process the current probe row to as much build data as - // possible (probe row as constant, and flat copied data for build records). - RowVectorPtr genCrossProductMultipleBuildVectors( - const RowVectorPtr& buildVector, - const RowTypePtr& outputType, - const std::vector& probeProjections, - const std::vector& buildProjections); - - // Add a single record to `output_` based on buildRow from buildVector, and - // the current probeRow and probe vector (input_). Probe side projections are - // zero-copy (dictionary indices), and build side projections are marked to be - // copied using `buildCopyRanges_`; they will be copied later on by - // `copyBuildValues()`. - void addOutputRow(vector_size_t buildRow); - - // Copies the ranges from buildVector specified by `buildCopyRanges_` to - // `output_`, one projected column at a time. Clears buildCopyRanges_. - void copyBuildValues(const RowVectorPtr& buildVector); - - // Checks if it is required to add a probe mismatch row, and does it if - // needed. The caller needs to ensure there is available space in `output_` - // for the new record, which has nulled out build projections. - void checkProbeMismatchRow(); - - // Add a probe mismatch (only for left/full outer joins). The record is based - // on the current probeRow and vector (input_) and build projections are null. - void addProbeMismatchRow(); - // Called when we are done processing the current probe batch, to signal we // are ready for the next one. // @@ -186,118 +171,226 @@ class SpatialJoinProbe : public Operator { // change the operator state to signal peers. void finishProbeInput(); - // Whether we have processed all build data for the current probe row (based - // on buildIndex_'s value). - bool hasProbedAllBuildData() const { - return (buildIndex_ >= buildVectors_.value().size()); - } + // Add the output for a single probe row. This will return early if the + // output vector is full. + void addProbeRowOutput(); - // If build has a single vector, we can wrap probe and build batches into - // dictionaries and produce as many combinations of probe and build rows, - // until `numOutputRows_` is filled. - bool isSingleBuildVector() const { - return buildVectors_->size() == 1; + // Returns true if all output for the current probe row has been produced. + bool isProbeRowDone() const { + return candidateIndex_ >= candidateBuildRows_.size() || + buildVectorIndex_ >= buildVectors_.value().size(); } - // If there are no incoming records in the build side. - bool isBuildSideEmpty() const { - return buildVectors_->empty(); + // Increment probeRow_ and reset associated fields + void advanceProbeRow() { + ++probeRow_; + probeHasMatch_ = false; + buildVectorIndex_ = 0; + candidateIndex_ = 0; + candidateOffsetForCurrentBuildVector_ = 0; + buildRowOffset_ = 0; + needsFilterEvaluated_ = true; } - // If build has a single row, we can simply add it as a constant to probe - // batches. - bool isSingleBuildRow() const { - return isSingleBuildVector() && buildVectors_->front()->size() == 1; + // Add the output for a single build vector for a single probe row. This will + // return early if the output vector is full. + void addBuildVectorOutput(const RowVectorPtr& buildVector); + + // Returns true if all the rows for the current build vector have been + // processed, or the output is full. + bool isBuildVectorDone(const RowVectorPtr& buildVector) const { + // Note that candidateBuildRows_ entries are row numbers across + // all build vectors. + return candidateIndex_ >= candidateBuildRows_.size() || + relativeBuildRow(candidateIndex_) >= buildVector->size() || + outputBuilder_.isOutputFull(); } - // TODO: Add state transition check. - void setState(ProbeOperatorState state) { - state_ = state; + // Increment buildVectorIndex_ and reset associated fields + void advanceBuildVector() { + VELOX_CHECK(buildVectors_.has_value()); + + buildRowOffset_ += buildVectors_.value()[buildVectorIndex_]->size(); + ++buildVectorIndex_; + needsFilterEvaluated_ = true; + candidateOffsetForCurrentBuildVector_ = candidateIndex_; } - const core::JoinType joinType_; + // Calculate candidate build rows from spatialIndex_ for the current probe + // row. This should be done each time the probe is advanced. + std::vector querySpatialIndex(); - // Output buffer members. + // Evaluates the spatial joinCondition for a given build vector. This method + // sets `filterOutput_` and `decodedFilterResult_`, which will be ready to + // be used by `isSpatialJoinConditionMatch()` below. + // This only evaluates rows that are in the candidateBuildRows_, restricted to + // those in the current build vector. Thus we must index into this with + // candidateIndex_. + void evaluateJoinFilter(const RowVectorPtr& buildVector); - // Maximum number of rows in the output batch. - const vector_size_t outputBatchSize_; + // Checks if the spatial join condition matched for a particular row. + bool isJoinConditionMatch(vector_size_t candidateIndex) const { + vector_size_t relativeIndex = + candidateIndex - candidateOffsetForCurrentBuildVector_; + VELOX_CHECK_GT(decodedFilterResult_.size(), relativeIndex); + return ( + !decodedFilterResult_.isNullAt(relativeIndex) && + decodedFilterResult_.valueAt(relativeIndex)); + } - // The current output batch being populated. - RowVectorPtr output_; + // Generates the next batch of a cross product between probe and build using + // the supplied projections. It uses the current probe row as constant, and + // flat copied data for build records. + RowVectorPtr getNextJoinBatch( + const RowVectorPtr& buildVector, + const RowTypePtr& outputType, + const std::vector& probeProjections, + const std::vector& buildProjections, + BufferPtr candidateRows) const; + + // Given a candidate index, return the row index into the current build + // vector. For example, if we have candidates [2, 50, 81] and have processed + // two build vectors with size 30 and 40, then `relativeBuildRow(2) == 11` + // (81 - (30 + 40)). + vector_size_t relativeBuildRow(vector_size_t candidateRow) const { + return candidateBuildRows_[candidateRow] - buildRowOffset_; + } - // Number of output rows in the current output batch. - vector_size_t numOutputRows_{0}; + // Make the indices of build vector candidates suitable for creating a + // DictionaryVector. + BufferPtr makeBuildVectorIndices(vector_size_t vectorSize); - // Dictionary indices for probe columns used to generate cross-product. - BufferPtr probeIndices_; + ///////// + // SETUP + // Variables set during operator setup that are used during execution. + // These should not be modified after the operator is initialized. - // Dictionary indices for probe columns for output vector. - BufferPtr probeOutputIndices_; - vector_size_t* rawProbeOutputIndices_{}; + const core::JoinType joinType_; - // Dictionary indices for build columns. - BufferPtr buildIndices_; + // Maximum number of rows in the output batch. + const vector_size_t outputBatchSize_; - // Spatial join condition expression. + // Join metadata and state. + std::shared_ptr joinNode_; + // Spatial join condition expression. // Must not be null std::unique_ptr joinCondition_; // Input type for the spatial join condition expression. RowTypePtr filterInputType_; - // Spatial join condition evaluation state that need to persisted across the - // generation of successive output buffers. - SelectivityVector filterInputRows_; - VectorPtr filterOutput_; - DecodedVector decodedFilterResult_; + // List of output projections from the build side. Note that the list of + // projections from the probe side is available at `identityProjections_`. + std::vector buildProjections_; - // Join metadata and state. - std::shared_ptr joinNode_; + // Projections needed as input to the filter to evaluation spatial join + // filter conditions. Note that if this is a cross-join, filter projections + // are the same as output projections. + std::vector filterProbeProjections_; + std::vector filterBuildProjections_; - ProbeOperatorState state_{ProbeOperatorState::kWaitForBuild}; - ContinueFuture future_{ContinueFuture::makeEmpty()}; + // Stores the build spatial index for the join + std::optional> spatialIndex_; + // Stores the data for build vectors (right side of the join). + std::optional> buildVectors_; - // Probe side state. + // Channel of geometry variable used to probe spatial index + column_index_t probeGeometryChannel_; - // Probe row being currently processed (related to `input_`). - vector_size_t probeRow_{0}; + ////////////////// + // OPERATOR STATE + // Variables used to track the general operator state during exection. + // These will change throughout setup and execution. - // Whether the current probeRow_ has produces a match. Used for left and full - // outer joins. - bool probeRowHasMatch_{false}; + ProbeOperatorState state_{ProbeOperatorState::kWaitForBuild}; + ContinueFuture future_{ContinueFuture::makeEmpty()}; - // Indicate if the probe side has empty input or not. For the last probe, - // this indicates if all the probe sides are empty or not. This flag is used - // for mismatched output producing. - bool probeSideEmpty_{true}; + // The information needed to produce an output RowVectorPtr. It is stored + // for all execution, but is reset on each output batch. + SpatialJoinOutputBuilder outputBuilder_; - // Build side state. + // Count of output batches produced (1-indexed). Primarily for debugging. + size_t outputCount_{0}; - // Stores the data for build vectors (right side of the join). - std::optional> buildVectors_; + // This is always set to all true, but we need it for eval/etc. Reuse between + // evaluations. + SelectivityVector filterInputRows_; + // The output result of the join condition evaluation on the **current** + // build vector. We must index into this with + // `candidateIndex_ - candidateOffsetForCurrentBuildVector_`. + VectorPtr filterOutput_; + // Decoded filterOutput: remove recursive dictionary/etc encodings. + // Like filterOutput_, this is only for the current build vector and we + // must index into this with + // `candidateIndex_ - candidateOffsetForCurrentBuildVector_`. + DecodedVector decodedFilterResult_; - // Index into `buildVectors_` for the build vector being currently processed. - size_t buildIndex_{0}; + // Decoded geometry vector. Must be reset whenever input_ is changed (it + // maintains a pointer to input_). + DecodedVector decodedGeometryCol_{}; - // Row being currently processed from `buildVectors_[buildIndex_]`. - vector_size_t buildRow_{0}; + /////////////// + // PROBE STATE + // Variables used to track the probe-side state state during exection. + // These will change throughout setup and execution. - // Stores the ranges of build values to be copied to the output vector (we - // batch them and copy once, instead of copying them row-by-row). - std::vector buildCopyRanges_; + // Count of probe batches added (1-indexed). Primarily for debugging. + size_t probeCount_{0}; - // List of output projections from the build side. Note that the list of - // projections from the probe side is available at `identityProjections_`. - std::vector buildProjections_; + // Probe row being currently processed (related to `input_`). + vector_size_t probeRow_{0}; - // Projections needed as input to the filter to evaluation spatial join filter - // conditions. Note that if this is a cross-join, filter projections are the - // same as output projections. - std::vector filterProbeProjections_; - std::vector filterBuildProjections_; + // Whether the current probeRow_ has found a match. Needed for left join. + bool probeHasMatch_{false}; - BufferPtr buildOutMapping_; + /////////////// + // BUILD STATE + // Variables used to track the build-side state state during exection. + // These will change throughout setup and execution. + // + // The build rows are stored in a vector of RowVectorPtrs. These are + // conceptually indexed by an absolute build row, which indexes into a + // flattened vector of rows. buildVectorIndex_ is the index to the current + // RowVectorPtr in buildVectors_, buildRowOffset_ is the sum of the sizes + // of the previous build vectors and should be subtracted from buildRow + // to index into the current build vector. + // + // We primarily use candidateBuildRows_, which is a vector of (absolute) + // build rows. candidateIndex_ indexes the entry in candidateBuildRows_, + // so candidateBuildrows_[candidiateIndex_] is the absolute build row + // of the current candidate. + + // Whether we need to evaluate the join filter on this build vector. It + // should be done once per build vector/probe row pair. + bool needsFilterEvaluated_{true}; + + // Index into `buildVectors_` for the build vector being currently + // processed. + size_t buildVectorIndex_{0}; + + // Keep track of how many build rows we've traversed in previous build + // RowVectors. Subtract this from the current element in candidateBuildRows_ + // to index into the current build RowVector. + vector_size_t buildRowOffset_{0}; + + // Build rows returned from the spatial index. + // The value is the row number over all build vectors, so if the have two + // build vectors of size 100 and 200, candidate row 50 is the 50th entry of + // the first vector, and 101 is the 2nd entry of the second vector. + std::vector candidateBuildRows_{}; + + // Index of candidate currently being processed from + // `buildVectors_[buildIndex_]`. + vector_size_t candidateIndex_{0}; + + // How many candidates were in previous build vectors. + // This is important because for each build vector, we calculate a + // decodedFilterResult_ with only the rows from from the candidates in + // that build vector. candidateIndex_ indexes over _all_ candidates, so + // we must substract candidateOffsetForCurrentBuildVector_ to index into the + // candidates for this build vector. + vector_size_t candidateOffsetForCurrentBuildVector_{0}; }; } // namespace facebook::velox::exec diff --git a/velox/exec/Spill.cpp b/velox/exec/Spill.cpp index 376d06502a3..193d86265c6 100644 --- a/velox/exec/Spill.cpp +++ b/velox/exec/Spill.cpp @@ -15,14 +15,82 @@ */ #include "velox/exec/Spill.h" +#include "velox/common/Casts.h" #include "velox/common/base/RuntimeMetrics.h" #include "velox/common/file/FileSystems.h" #include "velox/common/testutil/TestValue.h" +#include "velox/exec/OperatorUtils.h" #include "velox/serializers/PrestoSerializer.h" using facebook::velox::common::testutil::TestValue; namespace facebook::velox::exec { + +namespace { +/// gatherMerge merges & sorts with the mergeTree and gatherCopy the +/// results into target. 'target' is the result RowVector, and the copying +/// starts from row 0 up to row target.size(). 'mergeTree' is the data source. +/// 'totalNumRows' is the actual num of rows that is copied to target. +/// 'bufferSources' and 'bufferSourceIndices' are buffering vectors that could +/// be reused across calls. +void gatherMerge( + RowVectorPtr& target, + TreeOfLosers& mergeTree, + int32_t& totalNumRows, + std::vector& bufferSources, + std::vector& bufferSourceIndices) { + VELOX_CHECK_GE(bufferSources.size(), target->size()); + VELOX_CHECK_GE(bufferSourceIndices.size(), target->size()); + totalNumRows = 0; + int32_t numBatchRows = 0; + bool endOfBatch = false; + for (auto currentStream = mergeTree.next(); + currentStream != nullptr && totalNumRows + numBatchRows < target->size(); + currentStream = mergeTree.next()) { + bufferSources[numBatchRows] = ¤tStream->current(); + bufferSourceIndices[numBatchRows] = + currentStream->currentIndex(&endOfBatch); + ++numBatchRows; + if (FOLLY_UNLIKELY(endOfBatch)) { + // The stream is at end of input batch. Need to copy out the rows before + // fetching next batch in 'pop'. + gatherCopy( + target.get(), + totalNumRows, + numBatchRows, + bufferSources, + bufferSourceIndices); + totalNumRows += numBatchRows; + numBatchRows = 0; + } + // Advance the stream. + currentStream->pop(); + } + VELOX_CHECK_LE(totalNumRows + numBatchRows, target->size()); + + if (FOLLY_LIKELY(numBatchRows != 0)) { + gatherCopy( + target.get(), + totalNumRows, + numBatchRows, + bufferSources, + bufferSourceIndices); + totalNumRows += numBatchRows; + numBatchRows = 0; + } +} +} // namespace + +void testingGatherMerge( + RowVectorPtr& target, + TreeOfLosers& mergeTree, + int32_t& totalNumRows, + std::vector& bufferSources, + std::vector& bufferSourceIndices) { + gatherMerge( + target, mergeTree, totalNumRows, bufferSources, bufferSourceIndices); +} + void SpillMergeStream::pop() { VELOX_CHECK(!closed_); if (++index_ >= size_) { @@ -70,7 +138,7 @@ SpillState::SpillState( common::CompressionKind compressionKind, const std::optional& prefixSortConfig, memory::MemoryPool* pool, - folly::Synchronized* stats, + exec::SpillStats* stats, const std::string& fileCreateConfig) : getSpillDirPathCb_(getSpillDirPathCb), updateAndCheckSpillLimitCb_(updateAndCheckSpillLimitCb), @@ -110,8 +178,8 @@ std::vector SpillState::makeSortingKeys( void SpillState::setPartitionSpilled(const SpillPartitionId& id) { VELOX_DCHECK(!spilledPartitionIdSet_.contains(id)); spilledPartitionIdSet_.emplace(id); - ++stats_->wlock()->spilledPartitions; - common::incrementGlobalSpilledPartitionStats(); + stats_->spilledPartitions.fetch_add(1, std::memory_order_relaxed); + incrementGlobalSpilledPartitionStats(); } /*static*/ @@ -119,17 +187,17 @@ void SpillState::validateSpillBytesSize(uint64_t bytes) { static constexpr uint64_t kMaxSpillBytesPerWrite = std::numeric_limits::max(); if (bytes >= kMaxSpillBytesPerWrite) { - VELOX_GENERIC_SPILL_FAILURE(fmt::format( - "Spill bytes will overflow. Bytes {}, kMaxSpillBytesPerWrite: {}", - bytes, - kMaxSpillBytesPerWrite)); + VELOX_GENERIC_SPILL_FAILURE( + fmt::format( + "Spill bytes will overflow. Bytes {}, kMaxSpillBytesPerWrite: {}", + bytes, + kMaxSpillBytesPerWrite)); } } void SpillState::updateSpilledInputBytes(uint64_t bytes) { - auto statsLocked = stats_->wlock(); - statsLocked->spilledInputBytes += bytes; - common::updateGlobalSpillMemoryBytes(bytes); + stats_->spilledInputBytes.fetch_add(bytes, std::memory_order_relaxed); + updateGlobalSpillMemoryBytes(bytes); } uint64_t SpillState::appendToPartition( @@ -278,13 +346,14 @@ std::unique_ptr> SpillPartition::createUnorderedReader( uint64_t bufferSize, memory::MemoryPool* pool, - folly::Synchronized* spillStats) { + exec::SpillStats* spillStats) { VELOX_CHECK_NOT_NULL(pool); std::vector> streams; streams.reserve(files_.size()); for (auto& fileInfo : files_) { - streams.push_back(FileSpillBatchStream::create( - SpillReadFile::create(fileInfo, bufferSize, pool, spillStats))); + streams.push_back( + FileSpillBatchStream::create( + SpillReadFile::create(fileInfo, bufferSize, pool, spillStats))); } files_.clear(); return std::make_unique>( @@ -292,15 +361,16 @@ SpillPartition::createUnorderedReader( } std::unique_ptr> -SpillPartition::createOrderedReader( +SpillPartition::createOrderedReaderInternal( uint64_t bufferSize, memory::MemoryPool* pool, - folly::Synchronized* spillStats) { + exec::SpillStats* spillStats) { std::vector> streams; streams.reserve(files_.size()); for (auto& fileInfo : files_) { - streams.push_back(FileSpillMergeStream::create( - SpillReadFile::create(fileInfo, bufferSize, pool, spillStats))); + streams.push_back( + FileSpillMergeStream::create( + SpillReadFile::create(fileInfo, bufferSize, pool, spillStats))); } files_.clear(); // Check if the partition is empty or not. @@ -310,6 +380,174 @@ SpillPartition::createOrderedReader( return std::make_unique>(std::move(streams)); } +namespace { +size_t estimateOutputBatchRows( + const std::vector>& streams, + vector_size_t maxRows, + size_t maxBytes) { + size_t numEstimations{0}; + int64_t totalEstimatedBytes{0}; + for (const auto& stream : streams) { + const auto streamEstimateRowSize = stream->estimateRowSize(); + if (streamEstimateRowSize.has_value()) { + ++numEstimations; + totalEstimatedBytes += streamEstimateRowSize.value(); + } + } + + if (numEstimations == 0) { + return maxRows; + } + + const auto estimateRowSize = + std::max(1, totalEstimatedBytes / numEstimations); + return std::min( + std::max(1, maxBytes / estimateRowSize), maxRows); +} + +// This contains batching parameters and various kinds of batching buffers that +// are reused across multiple merging rounds. +struct SpillFileMergeParams { + static constexpr size_t kDefaultMaxBatchRows = 1'000; + static constexpr size_t kDefaultMaxBatchBytes = 64 * 1024; + + SpillFileMergeParams( + const TypePtr& type, + memory::MemoryPool* pool, + const vector_size_t _maxBatchRows = kDefaultMaxBatchRows, + const size_t _maxBatchBytes = kDefaultMaxBatchBytes) + : maxBatchRows(_maxBatchRows), maxBatchBytes(_maxBatchBytes) { + rowVector = std::static_pointer_cast( + BaseVector::create(type, maxBatchRows, pool)); + spillSources.resize(maxBatchRows); + spillSourceRows.resize(maxBatchRows); + } + + const vector_size_t maxBatchRows; + const size_t maxBatchBytes; + RowVectorPtr rowVector; + std::vector spillSources; + std::vector spillSourceRows; +}; + +SpillFileInfo mergeSpillFiles( + const std::vector& files, + const std::string& pathPrefix, + const common::UpdateAndCheckSpillLimitCB& updateAndCheckSpillLimitCb, + const std::string& fileCreateConfig, + uint64_t readBufferSize, + uint64_t writeBufferSize, + SpillFileMergeParams& mergeParams, + memory::MemoryPool* pool, + exec::SpillStats* spillStats) { + VELOX_CHECK_GT(files.size(), 0); + std::vector> streams; + streams.reserve(files.size()); + for (const auto& fileInfo : files) { + streams.push_back( + FileSpillMergeStream::create( + SpillReadFile::create(fileInfo, readBufferSize, pool, spillStats))); + } + const auto batchRows = estimateOutputBatchRows( + streams, mergeParams.maxBatchRows, mergeParams.maxBatchBytes); + + auto mergeTree = + std::make_unique>(std::move(streams)); + const auto type = files[0].type; + + auto writer = std::make_unique( + type, + files[0].sortingKeys, + files[0].compressionKind, + pathPrefix, + std::numeric_limits::max(), + writeBufferSize, + fileCreateConfig, + updateAndCheckSpillLimitCb, + pool, + spillStats); + + while (mergeTree->next()) { + VectorPtr tmpRowVector = std::move(mergeParams.rowVector); + BaseVector::prepareForReuse(tmpRowVector, batchRows); + mergeParams.rowVector = checkedPointerCast(tmpRowVector); + mergeParams.rowVector->resize(batchRows); + int32_t outputRow = 0; + gatherMerge( + mergeParams.rowVector, + *mergeTree, + outputRow, + mergeParams.spillSources, + mergeParams.spillSourceRows); + + IndexRange range{0, outputRow}; + writer->write(mergeParams.rowVector, folly::Range(&range, 1)); + } + auto resultFiles = writer->finish(); + VELOX_CHECK_EQ(resultFiles.size(), 1); + return std::move(resultFiles[0]); +} + +struct SpillFileCompare { + bool operator()(const SpillFileInfo& lhs, const SpillFileInfo& rhs) const { + return lhs.size > rhs.size; + } +}; +using SpillFileHeap = std:: + priority_queue, SpillFileCompare>; +} // namespace + +std::unique_ptr> +SpillPartition::createOrderedReader( + const common::SpillConfig& spillConfig, + memory::MemoryPool* pool, + exec::SpillStats* spillStats) { + const auto numMaxMergeFiles = spillConfig.numMaxMergeFiles; + VELOX_CHECK_NE(numMaxMergeFiles, 1); + if (numMaxMergeFiles == 0 || files_.size() <= numMaxMergeFiles) { + return createOrderedReaderInternal( + spillConfig.readBufferSize, pool, spillStats); + } + + SpillFileHeap orderedFiles(files_.begin(), files_.end()); + SpillFiles files; + files.reserve(numMaxMergeFiles); + const auto mergeFilePathPrefix = files_[0].path; + SpillFileMergeParams mergeParams(files_[0].type, pool); + + // Recursively merge the files. + for (uint32_t round = 0; orderedFiles.size() > numMaxMergeFiles; ++round) { + const uint64_t numMergeFiles = std::min( + static_cast(numMaxMergeFiles), + static_cast(orderedFiles.size() + 1 - numMaxMergeFiles)); + // Choose the top 'numMergeFiles' smallest files for merging to minimize IO. + for (uint32_t i = 0; i < numMergeFiles; i++) { + files.push_back(orderedFiles.top()); + orderedFiles.pop(); + } + auto mergedFile = mergeSpillFiles( + files, + fmt::format("{}-merge-round-{}", mergeFilePathPrefix, round), + spillConfig.updateAndCheckSpillLimitCb, + spillConfig.fileCreateConfig, + spillConfig.readBufferSize, + spillConfig.writeBufferSize, + mergeParams, + pool, + spillStats); + orderedFiles.push(mergedFile); + files.clear(); + } + + files_.clear(); + while (!orderedFiles.empty()) { + files_.push_back(orderedFiles.top()); + orderedFiles.pop(); + } + return createOrderedReaderInternal( + spillConfig.readBufferSize, pool, spillStats); +} + IterableSpillPartitionSet::IterableSpillPartitionSet() { spillPartitionIter_ = spillPartitions_.begin(); } @@ -436,6 +674,8 @@ std::unique_ptr ConcatFilesSpillBatchStream::create( } bool ConcatFilesSpillBatchStream::nextBatch(RowVectorPtr& batch) { + TestValue::adjust( + "facebook::velox::exec::ConcatFilesSpillBatchStream::nextBatch", nullptr); VELOX_CHECK_NULL(batch); VELOX_CHECK(!atEnd_); for (; fileIndex_ < spillFiles_.size(); ++fileIndex_) { @@ -454,10 +694,11 @@ bool ConcatFilesSpillBatchStream::nextBatch(RowVectorPtr& batch) { SpillPartitionId::SpillPartitionId(uint32_t partitionNumber) : encodedId_(partitionNumber) { if (FOLLY_UNLIKELY(partitionNumber >= (1 << kMaxPartitionBits))) { - VELOX_FAIL(fmt::format( - "Partition number {} exceeds max partition number {}", - partitionNumber, - 1 << kMaxPartitionBits)); + VELOX_FAIL( + fmt::format( + "Partition number {} exceeds max partition number {}", + partitionNumber, + 1 << kMaxPartitionBits)); } } @@ -466,10 +707,11 @@ SpillPartitionId::SpillPartitionId( uint32_t partitionNumber) { const auto childSpillLevel = parent.spillLevel() + 1; if (FOLLY_UNLIKELY(childSpillLevel > kMaxSpillLevel)) { - VELOX_FAIL(fmt::format( - "Spill level {} exceeds max spill level {}", - childSpillLevel, - kMaxSpillLevel)); + VELOX_FAIL( + fmt::format( + "Spill level {} exceeds max spill level {}", + childSpillLevel, + kMaxSpillLevel)); } encodedId_ = parent.encodedId_; encodedId_ = encodedId_ & ~kSpillLevelBitMask; diff --git a/velox/exec/Spill.h b/velox/exec/Spill.h index 18fd2cbfcb1..c523c6ed96b 100644 --- a/velox/exec/Spill.h +++ b/velox/exec/Spill.h @@ -20,12 +20,12 @@ #include #include "velox/common/base/SpillConfig.h" -#include "velox/common/base/SpillStats.h" +#include "velox/common/base/TreeOfLosers.h" #include "velox/common/compression/Compression.h" #include "velox/common/file/File.h" #include "velox/common/file/FileSystems.h" #include "velox/exec/SpillFile.h" -#include "velox/exec/TreeOfLosers.h" +#include "velox/exec/SpillStats.h" #include "velox/exec/UnorderedStreamReader.h" #include "velox/exec/VectorHasher.h" #include "velox/vector/ComplexVector.h" @@ -33,6 +33,23 @@ #include "velox/vector/VectorStream.h" namespace facebook::velox::exec { + +class SpillMergeStream; + +/// Testing gatherMerge without exposing the interface in the header. Used in +/// test only. gatherMerge merges & sorts with the mergeTree and gatherCopy the +/// results into target. 'target' is the result RowVector, and the copying +/// starts from row 0 up to row target.size(). 'mergeTree' is the data source. +/// 'totalNumRows' is the actual num of rows that is copied to target. +/// 'bufferSources' and 'bufferSourceIndices' are buffering vectors that could +/// be reused across calls. +void testingGatherMerge( + RowVectorPtr& target, + TreeOfLosers& mergeTree, + int32_t& totalNumRows, + std::vector& bufferSources, + std::vector& bufferSourceIndices); + class VectorHasher; /// A source of sorted spilled RowVectors coming either from a file or memory. @@ -80,6 +97,15 @@ class SpillMergeStream : public MergeStream { return decoded_[index]; } + /// Returns the estimated row size based on the vector received from the + /// merge source. + std::optional estimateRowSize() const { + if (rowVector_ == nullptr || rowVector_->size() == 0) { + return std::nullopt; + } + return rowVector_->estimateFlatSize() / rowVector_->size(); + } + protected: virtual const std::vector& sortingKeys() const = 0; @@ -477,22 +503,33 @@ class SpillPartition { std::unique_ptr> createUnorderedReader( uint64_t bufferSize, memory::MemoryPool* pool, - folly::Synchronized* spillStats); + exec::SpillStats* spillStats); + + /// Create an ordered stream reader from this spill partition. If the + /// partition has more than spillConfig.numMaxMergeFiles files, the files will + /// be pre-merged recursively to make sure the final ordered reader reads no + /// more than numMaxMergeFiles files. This behavior is to avoid OOM problem + /// when opening and reading too many files at the same time. If + /// numMaxMergeFiles < 2, the merge way is unlimited. + std::unique_ptr> createOrderedReader( + const common::SpillConfig& spillConfig, + memory::MemoryPool* pool, + exec::SpillStats* spillStats); + std::string toString() const; + + private: /// Invoked to create an ordered stream reader from this spill partition. /// The created reader will take the ownership of the spill files. /// 'bufferSize' specifies the read size from the storage. If the file /// system supports async read mode, then reader allocates two buffers with /// one buffer prefetch ahead. 'spillStats' is provided to collect the spill /// stats when reading data from spilled files. - std::unique_ptr> createOrderedReader( + std::unique_ptr> createOrderedReaderInternal( uint64_t bufferSize, memory::MemoryPool* pool, - folly::Synchronized* spillStats); + exec::SpillStats* spillStats); - std::string toString() const; - - private: SpillPartitionId id_; SpillFiles files_; // Counts the total file size in bytes from this spilled partition. @@ -549,7 +586,7 @@ class SpillState { /// 'numSortKeys' is the number of leading columns on which the data is /// sorted, 0 if only hash partitioning is used. 'targetFileSize' is the /// target size of a single file. 'pool' owns the memory for state and - /// results. + /// results. 'ioStats' is used to collect filesystem I/O stats. SpillState( const common::GetSpillDirectoryPathCB& getSpillDirectoryPath, const common::UpdateAndCheckSpillLimitCB& updateAndCheckSpillLimitCb, @@ -560,7 +597,7 @@ class SpillState { common::CompressionKind compressionKind, const std::optional& prefixSortConfig, memory::MemoryPool* pool, - folly::Synchronized* stats, + exec::SpillStats* stats, const std::string& fileCreateConfig = {}); static std::vector makeSortingKeys( @@ -665,7 +702,7 @@ class SpillState { const std::optional prefixSortConfig_; const std::string fileCreateConfig_; memory::MemoryPool* const pool_; - folly::Synchronized* const stats_; + exec::SpillStats* const stats_; // A set of spilled partition ids. SpillPartitionIdSet spilledPartitionIdSet_; diff --git a/velox/exec/SpillFile.cpp b/velox/exec/SpillFile.cpp index a5ad7d53d7e..cb4647c8ca8 100644 --- a/velox/exec/SpillFile.cpp +++ b/velox/exec/SpillFile.cpp @@ -16,8 +16,7 @@ #include "velox/exec/SpillFile.h" #include "velox/common/base/RuntimeMetrics.h" -#include "velox/common/file/FileSystems.h" -#include "velox/vector/VectorStream.h" +#include "velox/serializers/SerializedPageFile.h" namespace facebook::velox::exec { namespace { @@ -29,49 +28,6 @@ namespace { static const bool kDefaultUseLosslessTimestamp = true; } // namespace -std::unique_ptr SpillWriteFile::create( - uint32_t id, - const std::string& pathPrefix, - const std::string& fileCreateConfig) { - return std::unique_ptr( - new SpillWriteFile(id, pathPrefix, fileCreateConfig)); -} - -SpillWriteFile::SpillWriteFile( - uint32_t id, - const std::string& pathPrefix, - const std::string& fileCreateConfig) - : id_(id), path_(fmt::format("{}-{}", pathPrefix, ordinalCounter_++)) { - auto fs = filesystems::getFileSystem(path_, nullptr); - file_ = fs->openFileForWrite( - path_, - filesystems::FileOptions{ - {{filesystems::FileOptions::kFileCreateConfig.toString(), - fileCreateConfig}}, - nullptr, - std::nullopt}); -} - -void SpillWriteFile::finish() { - VELOX_CHECK_NOT_NULL(file_); - size_ = file_->size(); - file_->close(); - file_ = nullptr; -} - -uint64_t SpillWriteFile::size() const { - if (file_ != nullptr) { - return file_->size(); - } - return size_; -} - -uint64_t SpillWriteFile::write(std::unique_ptr iobuf) { - auto writtenBytes = iobuf->computeChainDataLength(); - file_->append(std::move(iobuf)); - return writtenBytes; -} - SpillWriter::SpillWriter( const RowTypePtr& type, const std::vector& sortingKeys, @@ -80,160 +36,82 @@ SpillWriter::SpillWriter( uint64_t targetFileSize, uint64_t writeBufferSize, const std::string& fileCreateConfig, - common::UpdateAndCheckSpillLimitCB& updateAndCheckSpillLimitCb, + const common::UpdateAndCheckSpillLimitCB& updateAndCheckSpillLimitCb, memory::MemoryPool* pool, - folly::Synchronized* stats) - : type_(type), + exec::SpillStats* stats) + : serializer::SerializedPageFileWriter( + pathPrefix, + targetFileSize, + writeBufferSize, + fileCreateConfig, + std::make_unique< + serializer::presto::PrestoVectorSerde::PrestoOptions>( + kDefaultUseLosslessTimestamp, + compressionKind, + 0.8, + /*_nullsFirst=*/true), + getNamedVectorSerde("Presto"), + pool, + &stats->ioStats), + type_(type), sortingKeys_(sortingKeys), - compressionKind_(compressionKind), - pathPrefix_(pathPrefix), - targetFileSize_(targetFileSize), - writeBufferSize_(writeBufferSize), - fileCreateConfig_(fileCreateConfig), - updateAndCheckSpillLimitCb_(updateAndCheckSpillLimitCb), - pool_(pool), - serde_(getNamedVectorSerde(VectorSerde::Kind::kPresto)), - stats_(stats) {} - -SpillWriteFile* SpillWriter::ensureFile() { - if ((currentFile_ != nullptr) && (currentFile_->size() > targetFileSize_)) { - closeFile(); - } - if (currentFile_ == nullptr) { - currentFile_ = SpillWriteFile::create( - nextFileId_++, - fmt::format("{}-{}", pathPrefix_, finishedFiles_.size()), - fileCreateConfig_); - } - return currentFile_.get(); -} - -void SpillWriter::closeFile() { - if (currentFile_ == nullptr) { - return; - } - currentFile_->finish(); - updateSpilledFileStats(currentFile_->size()); - finishedFiles_.push_back(SpillFileInfo{ - .id = currentFile_->id(), - .type = type_, - .path = currentFile_->path(), - .size = currentFile_->size(), - .sortingKeys = sortingKeys_, - .compressionKind = compressionKind_}); - currentFile_.reset(); -} - -size_t SpillWriter::numFinishedFiles() const { - return finishedFiles_.size(); -} - -uint64_t SpillWriter::flush() { - if (batch_ == nullptr) { - return 0; - } - - auto* file = ensureFile(); - VELOX_CHECK_NOT_NULL(file); - - IOBufOutputStream out( - *pool_, nullptr, std::max(64 * 1024, batch_->size())); - uint64_t flushTimeNs{0}; - { - NanosecondTimer timer(&flushTimeNs); - batch_->flush(&out); - } - batch_.reset(); - - uint64_t writeTimeNs{0}; - uint64_t writtenBytes{0}; - auto iobuf = out.getIOBuf(); - { - NanosecondTimer timer(&writeTimeNs); - writtenBytes = file->write(std::move(iobuf)); - } - updateWriteStats(writtenBytes, flushTimeNs, writeTimeNs); - updateAndCheckSpillLimitCb_(writtenBytes); - return writtenBytes; -} - -uint64_t SpillWriter::write( - const RowVectorPtr& rows, - const folly::Range& indices) { - checkNotFinished(); - - uint64_t timeNs{0}; - { - NanosecondTimer timer(&timeNs); - if (batch_ == nullptr) { - serializer::presto::PrestoVectorSerde::PrestoOptions options = { - kDefaultUseLosslessTimestamp, - compressionKind_, - 0.8, - /*_nullsFirst=*/true}; - batch_ = std::make_unique(pool_, serde_); - batch_->createStreamTree( - std::static_pointer_cast(rows->type()), - 1'000, - &options); - } - batch_->append(rows, indices); - } - updateAppendStats(rows->size(), timeNs); - if (batch_->size() < writeBufferSize_) { - return 0; - } - return flush(); -} + stats_(stats), + updateAndCheckLimitCb_(updateAndCheckSpillLimitCb) {} void SpillWriter::updateAppendStats( uint64_t numRows, uint64_t serializationTimeNs) { - auto statsLocked = stats_->wlock(); - statsLocked->spilledRows += numRows; - statsLocked->spillSerializationTimeNanos += serializationTimeNs; - common::updateGlobalSpillAppendStats(numRows, serializationTimeNs); + stats_->spilledRows.fetch_add(numRows, std::memory_order_relaxed); + stats_->spillSerializationTimeNanos.fetch_add( + serializationTimeNs, std::memory_order_relaxed); + updateGlobalSpillAppendStats(numRows, serializationTimeNs); } void SpillWriter::updateWriteStats( uint64_t spilledBytes, uint64_t flushTimeNs, uint64_t fileWriteTimeNs) { - auto statsLocked = stats_->wlock(); - statsLocked->spilledBytes += spilledBytes; - statsLocked->spillFlushTimeNanos += flushTimeNs; - statsLocked->spillWriteTimeNanos += fileWriteTimeNs; - ++statsLocked->spillWrites; - common::updateGlobalSpillWriteStats( - spilledBytes, flushTimeNs, fileWriteTimeNs); -} - -void SpillWriter::updateSpilledFileStats(uint64_t fileSize) { - ++stats_->wlock()->spilledFiles; + stats_->spillWrites.fetch_add(1, std::memory_order_relaxed); + stats_->spilledBytes.fetch_add(spilledBytes, std::memory_order_relaxed); + stats_->spillFlushTimeNanos.fetch_add(flushTimeNs, std::memory_order_relaxed); + stats_->spillWriteTimeNanos.fetch_add( + fileWriteTimeNs, std::memory_order_relaxed); + updateGlobalSpillWriteStats(spilledBytes, flushTimeNs, fileWriteTimeNs); + updateAndCheckLimitCb_(spilledBytes); +} + +void SpillWriter::updateFileStats( + const serializer::SerializedPageFile::FileInfo& file) { + stats_->spilledFiles.fetch_add(1, std::memory_order_relaxed); addThreadLocalRuntimeStat( - "spillFileSize", RuntimeCounter(fileSize, RuntimeCounter::Unit::kBytes)); - common::incrementGlobalSpilledFiles(); -} - -void SpillWriter::finishFile() { - checkNotFinished(); - flush(); - closeFile(); - VELOX_CHECK_NULL(currentFile_); + "spillFileSize", RuntimeCounter(file.size, RuntimeCounter::Unit::kBytes)); + incrementGlobalSpilledFiles(); } SpillFiles SpillWriter::finish() { - checkNotFinished(); - auto finishGuard = folly::makeGuard([this]() { finished_ = true; }); - - finishFile(); - return std::move(finishedFiles_); + const auto serializedPageFiles = + serializer::SerializedPageFileWriter::finish(); + SpillFiles spillFiles; + spillFiles.reserve(serializedPageFiles.size()); + for (const auto& fileInfo : serializedPageFiles) { + spillFiles.push_back( + SpillFileInfo{ + .id = fileInfo.id, + .type = type_, + .path = fileInfo.path, + .size = fileInfo.size, + .sortingKeys = sortingKeys_, + .compressionKind = serdeOptions_->compressionKind}); + } + return spillFiles; } std::vector SpillWriter::testingSpilledFilePaths() const { checkNotFinished(); std::vector spilledFilePaths; + spilledFilePaths.reserve( + finishedFiles_.size() + (currentFile_ != nullptr ? 1 : 0)); for (auto& file : finishedFiles_) { spilledFilePaths.push_back(file.path); } @@ -260,7 +138,7 @@ std::unique_ptr SpillReadFile::create( const SpillFileInfo& fileInfo, uint64_t bufferSize, memory::MemoryPool* pool, - folly::Synchronized* stats) { + exec::SpillStats* stats) { return std::unique_ptr(new SpillReadFile( fileInfo.id, fileInfo.path, @@ -282,52 +160,42 @@ SpillReadFile::SpillReadFile( const std::vector& sortingKeys, common::CompressionKind compressionKind, memory::MemoryPool* pool, - folly::Synchronized* stats) - : id_(id), + exec::SpillStats* stats) + : serializer::SerializedPageFileReader( + path, + bufferSize, + type, + getNamedVectorSerde("Presto"), + std::make_unique< + serializer::presto::PrestoVectorSerde::PrestoOptions>( + kDefaultUseLosslessTimestamp, + compressionKind, + 0.8, + /*_nullsFirst=*/true), + pool, + &stats->ioStats), + id_(id), path_(path), size_(size), - type_(type), sortingKeys_(sortingKeys), - compressionKind_(compressionKind), - readOptions_{ - kDefaultUseLosslessTimestamp, - compressionKind_, - 0.8, - /*_nullsFirst=*/true}, - pool_(pool), - serde_(getNamedVectorSerde(VectorSerde::Kind::kPresto)), - stats_(stats) { - auto fs = filesystems::getFileSystem(path_, nullptr); - auto file = fs->openFileForRead(path_); - input_ = std::make_unique( - std::move(file), bufferSize, pool_); -} - -bool SpillReadFile::nextBatch(RowVectorPtr& rowVector) { - if (input_->atEnd()) { - recordSpillStats(); - return false; - } - - uint64_t timeNs{0}; - { - NanosecondTimer timer{&timeNs}; - VectorStreamGroup::read( - input_.get(), pool_, type_, serde_, &rowVector, &readOptions_); - } - stats_->wlock()->spillDeserializationTimeNanos += timeNs; - common::updateGlobalSpillDeserializationTimeNs(timeNs); - return true; -} + stats_(stats) {} -void SpillReadFile::recordSpillStats() { +void SpillReadFile::updateFinalStats() { VELOX_CHECK(input_->atEnd()); const auto readStats = input_->stats(); - common::updateGlobalSpillReadStats( + updateGlobalSpillReadStats( readStats.numReads, readStats.readBytes, readStats.readTimeNs); - auto lockedSpillStats = stats_->wlock(); - lockedSpillStats->spillReads += readStats.numReads; - lockedSpillStats->spillReadTimeNanos += readStats.readTimeNs; - lockedSpillStats->spillReadBytes += readStats.readBytes; -} + stats_->spillReads.fetch_add(readStats.numReads, std::memory_order_relaxed); + stats_->spillReadBytes.fetch_add( + readStats.readBytes, std::memory_order_relaxed); + stats_->spillReadTimeNanos.fetch_add( + readStats.readTimeNs, std::memory_order_relaxed); +}; + +void SpillReadFile::updateSerializationTimeStats(uint64_t timeNs) { + stats_->spillDeserializationTimeNanos.fetch_add( + timeNs, std::memory_order_relaxed); + updateGlobalSpillDeserializationTimeNs(timeNs); +}; + } // namespace facebook::velox::exec diff --git a/velox/exec/SpillFile.h b/velox/exec/SpillFile.h index 55547eec0a5..fdedf817d7b 100644 --- a/velox/exec/SpillFile.h +++ b/velox/exec/SpillFile.h @@ -17,14 +17,16 @@ #pragma once #include +#include #include "velox/common/base/SpillConfig.h" -#include "velox/common/base/SpillStats.h" +#include "velox/common/base/TreeOfLosers.h" #include "velox/common/compression/Compression.h" #include "velox/common/file/File.h" #include "velox/common/file/FileInputStream.h" -#include "velox/exec/TreeOfLosers.h" +#include "velox/exec/SpillStats.h" #include "velox/serializers/PrestoSerializer.h" +#include "velox/serializers/SerializedPageFile.h" #include "velox/vector/ComplexVector.h" #include "velox/vector/DecodedVector.h" #include "velox/vector/VectorStream.h" @@ -32,53 +34,6 @@ namespace facebook::velox::exec { using SpillSortKey = std::pair; -/// Represents a spill file for writing the serialized spilled data into a disk -/// file. -class SpillWriteFile { - public: - static std::unique_ptr create( - uint32_t id, - const std::string& pathPrefix, - const std::string& fileCreateConfig); - - uint32_t id() const { - return id_; - } - - /// Returns the file size in bytes. - uint64_t size() const; - - const std::string& path() const { - return path_; - } - - uint64_t write(std::unique_ptr iobuf); - - WriteFile* file() { - return file_.get(); - } - - /// Finishes writing and flushes any unwritten data. - void finish(); - - private: - static inline std::atomic ordinalCounter_{0}; - - SpillWriteFile( - uint32_t id, - const std::string& pathPrefix, - const std::string& fileCreateConfig); - - // The spill file id which is monotonically increasing and unique for each - // associated spill partition. - const uint32_t id_; - const std::string path_; - - std::unique_ptr file_; - // Byte size of the backing file. Set when finishing writing. - uint64_t size_{0}; -}; - /// Records info of a finished spill file which is used for read. struct SpillFileInfo { uint32_t id; @@ -95,7 +50,7 @@ using SpillFiles = std::vector; /// Used to write the spilled data to a sequence of files for one partition. If /// data is sorted, each file is sorted. The globally sorted order is produced /// by merging the constituent files. -class SpillWriter { +class SpillWriter : public serializer::SerializedPageFileWriter { public: /// 'type' is a RowType describing the content. 'numSortKeys' is the number /// of leading columns on which the data is sorted. 'path' is a file path @@ -116,29 +71,13 @@ class SpillWriter { uint64_t targetFileSize, uint64_t writeBufferSize, const std::string& fileCreateConfig, - common::UpdateAndCheckSpillLimitCB& updateAndCheckSpillLimitCb, + const common::UpdateAndCheckSpillLimitCB& updateAndCheckSpillLimitCb, memory::MemoryPool* pool, - folly::Synchronized* stats); - - /// Adds 'rows' for the positions in 'indices' into 'this'. The indices - /// must produce a view where the rows are sorted if sorting is desired. - /// Consecutive calls must have sorted data so that the first row of the - /// next call is not less than the last row of the previous call. - /// Returns the size to write. - uint64_t write( - const RowVectorPtr& rows, - const folly::Range& indices); - - /// Closes the current output file if any. Subsequent calls to write will - /// start a new one. - void finishFile(); - - /// Returns the number of current finished files. - size_t numFinishedFiles() const; + exec::SpillStats* stats); /// Finishes this file writer and returns the written spill files info. /// - /// NOTE: we don't allow write to a spill writer after t + /// NOTE: we don't allow write to a spill writer after finish SpillFiles finish(); std::vector testingSpilledFilePaths() const; @@ -146,55 +85,29 @@ class SpillWriter { std::vector testingSpilledFileIds() const; private: - FOLLY_ALWAYS_INLINE void checkNotFinished() const { - VELOX_CHECK(!finished_, "SpillWriter has finished"); - } - - // Returns an open spill file for write. If there is no open spill file, then - // the function creates a new one. If the current open spill file exceeds the - // target file size limit, then it first closes the current one and then - // creates a new one. 'currentFile_' points to the current open spill file. - SpillWriteFile* ensureFile(); - - // Closes the current open spill file pointed by 'currentFile_'. - void closeFile(); - - // Writes data from 'batch_' to the current output file. Returns the actual - // written size. - uint64_t flush(); - // Invoked to increment the number of spilled files and the file size. - void updateSpilledFileStats(uint64_t fileSize); + void updateFileStats( + const serializer::SerializedPageFile::FileInfo& fileInfo) override; // Invoked to update the number of spilled rows. - void updateAppendStats(uint64_t numRows, uint64_t serializationTimeUs); + void updateAppendStats(uint64_t numRows, uint64_t serializationTimeUs) + override; // Invoked to update the disk write stats. void updateWriteStats( uint64_t spilledBytes, uint64_t flushTimeUs, - uint64_t writeTimeUs); + uint64_t writeTimeUs) override; const RowTypePtr type_; + const std::vector sortingKeys_; - const common::CompressionKind compressionKind_; - const std::string pathPrefix_; - const uint64_t targetFileSize_; - const uint64_t writeBufferSize_; - const std::string fileCreateConfig_; - // Updates the aggregated spill bytes of this query, and throws if exceeds - // the max spill bytes limit. - const common::UpdateAndCheckSpillLimitCB updateAndCheckSpillLimitCb_; - memory::MemoryPool* const pool_; - VectorSerde* const serde_; - folly::Synchronized* const stats_; + exec::SpillStats* const stats_; - bool finished_{false}; - uint32_t nextFileId_{0}; - std::unique_ptr batch_; - std::unique_ptr currentFile_; - SpillFiles finishedFiles_; + // Updates the aggregated bytes of this query, and throws if exceeds + // the max bytes limit. + const common::UpdateAndCheckSpillLimitCB updateAndCheckLimitCb_; }; /// Represents a spill file for read which turns the serialized spilled data @@ -204,13 +117,13 @@ class SpillWriter { /// needs to remove the unused spill files at some point later. For example, a /// query Task deletes all the generated spill files in one operation using /// rmdir() call. -class SpillReadFile { +class SpillReadFile : public serializer::SerializedPageFileReader { public: static std::unique_ptr create( const SpillFileInfo& fileInfo, uint64_t bufferSize, memory::MemoryPool* pool, - folly::Synchronized* stats); + exec::SpillStats* stats); uint32_t id() const { return id_; @@ -220,8 +133,6 @@ class SpillReadFile { return sortingKeys_; } - bool nextBatch(RowVectorPtr& rowVector); - /// Returns the file size in bytes. uint64_t size() const { return size_; @@ -241,26 +152,25 @@ class SpillReadFile { const std::vector& sortingKeys, common::CompressionKind compressionKind, memory::MemoryPool* pool, - folly::Synchronized* stats); + exec::SpillStats* stats); + + // Records spill read stats at the end of read input. + void updateFinalStats() override; - // Invoked to record spill read stats at the end of read input. - void recordSpillStats(); + void updateSerializationTimeStats(uint64_t timeNs) override; // The spill file id which is monotonically increasing and unique for each // associated spill partition. const uint32_t id_; + const std::string path_; + // The file size in bytes. const uint64_t size_; - // The data type of spilled data. - const RowTypePtr type_; + const std::vector sortingKeys_; - const common::CompressionKind compressionKind_; - const serializer::presto::PrestoVectorSerde::PrestoOptions readOptions_; - memory::MemoryPool* const pool_; - VectorSerde* const serde_; - folly::Synchronized* const stats_; - std::unique_ptr input_; + exec::SpillStats* const stats_; }; + } // namespace facebook::velox::exec diff --git a/velox/exec/SpillStats.cpp b/velox/exec/SpillStats.cpp new file mode 100644 index 00000000000..8b3aef5c0f6 --- /dev/null +++ b/velox/exec/SpillStats.cpp @@ -0,0 +1,509 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/SpillStats.h" +#include +#include +#include "velox/common/base/Counters.h" +#include "velox/common/base/StatsReporter.h" +#include "velox/common/base/SuccinctPrinter.h" + +namespace facebook::velox::exec { +namespace { +std::vector& allSpillStats() { + static std::vector spillStatsList(folly::available_concurrency()); + return spillStatsList; +} + +SpillStats& localSpillStats() { + const auto idx = std::hash{}(std::this_thread::get_id()); + auto& spillStatsVector = allSpillStats(); + return spillStatsVector[idx % spillStatsVector.size()]; +} +} // namespace + +SpillStats::SpillStats( + uint64_t _spillRuns, + uint64_t _spilledInputBytes, + uint64_t _spilledBytes, + uint64_t _spilledRows, + uint32_t _spilledPartitions, + uint64_t _spilledFiles, + uint64_t _spillFillTimeNanos, + uint64_t _spillSortTimeNanos, + uint64_t _spillExtractVectorTimeNanos, + uint64_t _spillSerializationTimeNanos, + uint64_t _spillWrites, + uint64_t _spillFlushTimeNanos, + uint64_t _spillWriteTimeNanos, + uint64_t _spillMaxLevelExceededCount, + uint64_t _spillReadBytes, + uint64_t _spillReads, + uint64_t _spillReadTimeNanos, + uint64_t _spillDeserializationTimeNanos) { + spillRuns.store(_spillRuns, std::memory_order_relaxed); + spilledInputBytes.store(_spilledInputBytes, std::memory_order_relaxed); + spilledBytes.store(_spilledBytes, std::memory_order_relaxed); + spilledRows.store(_spilledRows, std::memory_order_relaxed); + spilledPartitions.store(_spilledPartitions, std::memory_order_relaxed); + spilledFiles.store(_spilledFiles, std::memory_order_relaxed); + spillFillTimeNanos.store(_spillFillTimeNanos, std::memory_order_relaxed); + spillSortTimeNanos.store(_spillSortTimeNanos, std::memory_order_relaxed); + spillExtractVectorTimeNanos.store( + _spillExtractVectorTimeNanos, std::memory_order_relaxed); + spillSerializationTimeNanos.store( + _spillSerializationTimeNanos, std::memory_order_relaxed); + spillWrites.store(_spillWrites, std::memory_order_relaxed); + spillFlushTimeNanos.store(_spillFlushTimeNanos, std::memory_order_relaxed); + spillWriteTimeNanos.store(_spillWriteTimeNanos, std::memory_order_relaxed); + spillMaxLevelExceededCount.store( + _spillMaxLevelExceededCount, std::memory_order_relaxed); + spillReadBytes.store(_spillReadBytes, std::memory_order_relaxed); + spillReads.store(_spillReads, std::memory_order_relaxed); + spillReadTimeNanos.store(_spillReadTimeNanos, std::memory_order_relaxed); + spillDeserializationTimeNanos.store( + _spillDeserializationTimeNanos, std::memory_order_relaxed); +} + +SpillStats::SpillStats(const SpillStats& other) { + copyFrom(other); +} + +SpillStats& SpillStats::operator=(const SpillStats& other) { + if (this != &other) { + copyFrom(other); + } + return *this; +} + +SpillStats& SpillStats::operator+=(const SpillStats& other) { + spillRuns.fetch_add( + other.spillRuns.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spilledInputBytes.fetch_add( + other.spilledInputBytes.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spilledBytes.fetch_add( + other.spilledBytes.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spilledRows.fetch_add( + other.spilledRows.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spilledPartitions.fetch_add( + other.spilledPartitions.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spilledFiles.fetch_add( + other.spilledFiles.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spillFillTimeNanos.fetch_add( + other.spillFillTimeNanos.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spillSortTimeNanos.fetch_add( + other.spillSortTimeNanos.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spillExtractVectorTimeNanos.fetch_add( + other.spillExtractVectorTimeNanos.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spillSerializationTimeNanos.fetch_add( + other.spillSerializationTimeNanos.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spillWrites.fetch_add( + other.spillWrites.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spillFlushTimeNanos.fetch_add( + other.spillFlushTimeNanos.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spillWriteTimeNanos.fetch_add( + other.spillWriteTimeNanos.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spillMaxLevelExceededCount.fetch_add( + other.spillMaxLevelExceededCount.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spillReadBytes.fetch_add( + other.spillReadBytes.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spillReads.fetch_add( + other.spillReads.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spillReadTimeNanos.fetch_add( + other.spillReadTimeNanos.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spillDeserializationTimeNanos.fetch_add( + other.spillDeserializationTimeNanos.load(std::memory_order_relaxed), + std::memory_order_relaxed); + ioStats.merge(other.ioStats); + return *this; +} + +bool SpillStats::empty() const { + return spilledBytes.load(std::memory_order_relaxed) == 0; +} + +void SpillStats::copyFrom(const SpillStats& other) { + spillRuns.store( + other.spillRuns.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spilledInputBytes.store( + other.spilledInputBytes.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spilledBytes.store( + other.spilledBytes.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spilledRows.store( + other.spilledRows.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spilledPartitions.store( + other.spilledPartitions.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spilledFiles.store( + other.spilledFiles.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spillFillTimeNanos.store( + other.spillFillTimeNanos.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spillSortTimeNanos.store( + other.spillSortTimeNanos.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spillExtractVectorTimeNanos.store( + other.spillExtractVectorTimeNanos.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spillSerializationTimeNanos.store( + other.spillSerializationTimeNanos.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spillWrites.store( + other.spillWrites.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spillFlushTimeNanos.store( + other.spillFlushTimeNanos.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spillWriteTimeNanos.store( + other.spillWriteTimeNanos.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spillMaxLevelExceededCount.store( + other.spillMaxLevelExceededCount.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spillReadBytes.store( + other.spillReadBytes.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spillReads.store( + other.spillReads.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spillReadTimeNanos.store( + other.spillReadTimeNanos.load(std::memory_order_relaxed), + std::memory_order_relaxed); + spillDeserializationTimeNanos.store( + other.spillDeserializationTimeNanos.load(std::memory_order_relaxed), + std::memory_order_relaxed); + ioStats.merge(other.ioStats); +} + +SpillStats SpillStats::operator-(const SpillStats& other) const { + SpillStats result; + result.spillRuns.store( + spillRuns.load(std::memory_order_relaxed) - + other.spillRuns.load(std::memory_order_relaxed), + std::memory_order_relaxed); + result.spilledInputBytes.store( + spilledInputBytes.load(std::memory_order_relaxed) - + other.spilledInputBytes.load(std::memory_order_relaxed), + std::memory_order_relaxed); + result.spilledBytes.store( + spilledBytes.load(std::memory_order_relaxed) - + other.spilledBytes.load(std::memory_order_relaxed), + std::memory_order_relaxed); + result.spilledRows.store( + spilledRows.load(std::memory_order_relaxed) - + other.spilledRows.load(std::memory_order_relaxed), + std::memory_order_relaxed); + result.spilledPartitions.store( + spilledPartitions.load(std::memory_order_relaxed) - + other.spilledPartitions.load(std::memory_order_relaxed), + std::memory_order_relaxed); + result.spilledFiles.store( + spilledFiles.load(std::memory_order_relaxed) - + other.spilledFiles.load(std::memory_order_relaxed), + std::memory_order_relaxed); + result.spillFillTimeNanos.store( + spillFillTimeNanos.load(std::memory_order_relaxed) - + other.spillFillTimeNanos.load(std::memory_order_relaxed), + std::memory_order_relaxed); + result.spillSortTimeNanos.store( + spillSortTimeNanos.load(std::memory_order_relaxed) - + other.spillSortTimeNanos.load(std::memory_order_relaxed), + std::memory_order_relaxed); + result.spillExtractVectorTimeNanos.store( + spillExtractVectorTimeNanos.load(std::memory_order_relaxed) - + other.spillExtractVectorTimeNanos.load(std::memory_order_relaxed), + std::memory_order_relaxed); + result.spillSerializationTimeNanos.store( + spillSerializationTimeNanos.load(std::memory_order_relaxed) - + other.spillSerializationTimeNanos.load(std::memory_order_relaxed), + std::memory_order_relaxed); + result.spillWrites.store( + spillWrites.load(std::memory_order_relaxed) - + other.spillWrites.load(std::memory_order_relaxed), + std::memory_order_relaxed); + result.spillFlushTimeNanos.store( + spillFlushTimeNanos.load(std::memory_order_relaxed) - + other.spillFlushTimeNanos.load(std::memory_order_relaxed), + std::memory_order_relaxed); + result.spillWriteTimeNanos.store( + spillWriteTimeNanos.load(std::memory_order_relaxed) - + other.spillWriteTimeNanos.load(std::memory_order_relaxed), + std::memory_order_relaxed); + result.spillMaxLevelExceededCount.store( + spillMaxLevelExceededCount.load(std::memory_order_relaxed) - + other.spillMaxLevelExceededCount.load(std::memory_order_relaxed), + std::memory_order_relaxed); + result.spillReadBytes.store( + spillReadBytes.load(std::memory_order_relaxed) - + other.spillReadBytes.load(std::memory_order_relaxed), + std::memory_order_relaxed); + result.spillReads.store( + spillReads.load(std::memory_order_relaxed) - + other.spillReads.load(std::memory_order_relaxed), + std::memory_order_relaxed); + result.spillReadTimeNanos.store( + spillReadTimeNanos.load(std::memory_order_relaxed) - + other.spillReadTimeNanos.load(std::memory_order_relaxed), + std::memory_order_relaxed); + result.spillDeserializationTimeNanos.store( + spillDeserializationTimeNanos.load(std::memory_order_relaxed) - + other.spillDeserializationTimeNanos.load(std::memory_order_relaxed), + std::memory_order_relaxed); + return result; +} + +bool SpillStats::operator==(const SpillStats& other) const { + return spillRuns.load(std::memory_order_relaxed) == + other.spillRuns.load(std::memory_order_relaxed) && + spilledInputBytes.load(std::memory_order_relaxed) == + other.spilledInputBytes.load(std::memory_order_relaxed) && + spilledBytes.load(std::memory_order_relaxed) == + other.spilledBytes.load(std::memory_order_relaxed) && + spilledRows.load(std::memory_order_relaxed) == + other.spilledRows.load(std::memory_order_relaxed) && + spilledPartitions.load(std::memory_order_relaxed) == + other.spilledPartitions.load(std::memory_order_relaxed) && + spilledFiles.load(std::memory_order_relaxed) == + other.spilledFiles.load(std::memory_order_relaxed) && + spillFillTimeNanos.load(std::memory_order_relaxed) == + other.spillFillTimeNanos.load(std::memory_order_relaxed) && + spillSortTimeNanos.load(std::memory_order_relaxed) == + other.spillSortTimeNanos.load(std::memory_order_relaxed) && + spillExtractVectorTimeNanos.load(std::memory_order_relaxed) == + other.spillExtractVectorTimeNanos.load(std::memory_order_relaxed) && + spillSerializationTimeNanos.load(std::memory_order_relaxed) == + other.spillSerializationTimeNanos.load(std::memory_order_relaxed) && + spillWrites.load(std::memory_order_relaxed) == + other.spillWrites.load(std::memory_order_relaxed) && + spillFlushTimeNanos.load(std::memory_order_relaxed) == + other.spillFlushTimeNanos.load(std::memory_order_relaxed) && + spillWriteTimeNanos.load(std::memory_order_relaxed) == + other.spillWriteTimeNanos.load(std::memory_order_relaxed) && + spillMaxLevelExceededCount.load(std::memory_order_relaxed) == + other.spillMaxLevelExceededCount.load(std::memory_order_relaxed) && + spillReadBytes.load(std::memory_order_relaxed) == + other.spillReadBytes.load(std::memory_order_relaxed) && + spillReads.load(std::memory_order_relaxed) == + other.spillReads.load(std::memory_order_relaxed) && + spillReadTimeNanos.load(std::memory_order_relaxed) == + other.spillReadTimeNanos.load(std::memory_order_relaxed) && + spillDeserializationTimeNanos.load(std::memory_order_relaxed) == + other.spillDeserializationTimeNanos.load(std::memory_order_relaxed); +} + +void SpillStats::reset() { + spillRuns.store(0, std::memory_order_relaxed); + spilledInputBytes.store(0, std::memory_order_relaxed); + spilledBytes.store(0, std::memory_order_relaxed); + spilledRows.store(0, std::memory_order_relaxed); + spilledPartitions.store(0, std::memory_order_relaxed); + spilledFiles.store(0, std::memory_order_relaxed); + spillFillTimeNanos.store(0, std::memory_order_relaxed); + spillSortTimeNanos.store(0, std::memory_order_relaxed); + spillExtractVectorTimeNanos.store(0, std::memory_order_relaxed); + spillSerializationTimeNanos.store(0, std::memory_order_relaxed); + spillWrites.store(0, std::memory_order_relaxed); + spillFlushTimeNanos.store(0, std::memory_order_relaxed); + spillWriteTimeNanos.store(0, std::memory_order_relaxed); + spillMaxLevelExceededCount.store(0, std::memory_order_relaxed); + spillReadBytes.store(0, std::memory_order_relaxed); + spillReads.store(0, std::memory_order_relaxed); + spillReadTimeNanos.store(0, std::memory_order_relaxed); + spillDeserializationTimeNanos.store(0, std::memory_order_relaxed); + ioStats = IoStats(); +} + +std::string SpillStats::toString() const { + std::stringstream ss; + ss << "spillRuns[" << spillRuns.load(std::memory_order_relaxed) << "] " + << "spilledInputBytes[" + << succinctBytes(spilledInputBytes.load(std::memory_order_relaxed)) << "] " + << "spilledBytes[" + << succinctBytes(spilledBytes.load(std::memory_order_relaxed)) << "] " + << "spilledRows[" << spilledRows.load(std::memory_order_relaxed) << "] " + << "spilledPartitions[" + << spilledPartitions.load(std::memory_order_relaxed) << "] " + << "spilledFiles[" << spilledFiles.load(std::memory_order_relaxed) << "] " + << "spillFillTimeNanos[" + << succinctNanos(spillFillTimeNanos.load(std::memory_order_relaxed)) + << "] " + << "spillSortTimeNanos[" + << succinctNanos(spillSortTimeNanos.load(std::memory_order_relaxed)) + << "] " + << "spillExtractVectorTime[" + << succinctNanos( + spillExtractVectorTimeNanos.load(std::memory_order_relaxed)) + << "] " + << "spillSerializationTimeNanos[" + << succinctNanos( + spillSerializationTimeNanos.load(std::memory_order_relaxed)) + << "] " + << "spillWrites[" << spillWrites.load(std::memory_order_relaxed) << "] " + << "spillFlushTimeNanos[" + << succinctNanos(spillFlushTimeNanos.load(std::memory_order_relaxed)) + << "] " + << "spillWriteTimeNanos[" + << succinctNanos(spillWriteTimeNanos.load(std::memory_order_relaxed)) + << "] " + << "maxSpillExceededLimitCount[" + << spillMaxLevelExceededCount.load(std::memory_order_relaxed) << "] " + << "spillReadBytes[" + << succinctBytes(spillReadBytes.load(std::memory_order_relaxed)) << "] " + << "spillReads[" << spillReads.load(std::memory_order_relaxed) << "] " + << "spillReadTimeNanos[" + << succinctNanos(spillReadTimeNanos.load(std::memory_order_relaxed)) + << "] " + << "spillReadDeserializationTimeNanos[" + << succinctNanos( + spillDeserializationTimeNanos.load(std::memory_order_relaxed)) + << "]"; + + const auto ioStatsMap = ioStats.stats(); + if (!ioStatsMap.empty()) { + ss << " ioStats["; + bool first = true; + for (const auto& [name, metric] : ioStatsMap) { + if (!first) { + ss << ", "; + } + first = false; + ss << name << ":{sum:" << metric.sum << ", count:" << metric.count + << ", min:" << metric.min << ", max:" << metric.max << "}"; + } + ss << "]"; + } + return ss.str(); +} + +void updateGlobalSpillRunStats(uint64_t numRuns) { + localSpillStats().spillRuns.fetch_add(numRuns, std::memory_order_relaxed); +} + +void updateGlobalSpillAppendStats( + uint64_t numRows, + uint64_t serializationTimeNs) { + RECORD_METRIC_VALUE(kMetricSpilledRowsCount, numRows); + RECORD_HISTOGRAM_METRIC_VALUE( + kMetricSpillSerializationTimeMs, serializationTimeNs / 1'000'000); + auto& stats = localSpillStats(); + stats.spilledRows.fetch_add(numRows, std::memory_order_relaxed); + stats.spillSerializationTimeNanos.fetch_add( + serializationTimeNs, std::memory_order_relaxed); +} + +void incrementGlobalSpilledPartitionStats() { + localSpillStats().spilledPartitions.fetch_add(1, std::memory_order_relaxed); +} + +void updateGlobalSpillFillTime(uint64_t timeNs) { + RECORD_HISTOGRAM_METRIC_VALUE(kMetricSpillFillTimeMs, timeNs / 1'000'000); + localSpillStats().spillFillTimeNanos.fetch_add( + timeNs, std::memory_order_relaxed); +} + +void updateGlobalSpillSortTime(uint64_t timeNs) { + RECORD_HISTOGRAM_METRIC_VALUE(kMetricSpillSortTimeMs, timeNs / 1'000'000); + localSpillStats().spillSortTimeNanos.fetch_add( + timeNs, std::memory_order_relaxed); +} + +void updateGlobalSpillExtractVectorTime(uint64_t timeNs) { + RECORD_HISTOGRAM_METRIC_VALUE( + kMetricSpillExtractVectorTimeMs, timeNs / 1'000'000); + localSpillStats().spillExtractVectorTimeNanos.fetch_add( + timeNs, std::memory_order_relaxed); +} + +void updateGlobalSpillWriteStats( + uint64_t spilledBytes, + uint64_t flushTimeNs, + uint64_t writeTimeNs) { + RECORD_METRIC_VALUE(kMetricSpillWritesCount); + RECORD_METRIC_VALUE(kMetricSpilledBytes, spilledBytes); + RECORD_HISTOGRAM_METRIC_VALUE( + kMetricSpillFlushTimeMs, flushTimeNs / 1'000'000); + RECORD_HISTOGRAM_METRIC_VALUE( + kMetricSpillWriteTimeMs, writeTimeNs / 1'000'000); + auto& stats = localSpillStats(); + stats.spillWrites.fetch_add(1, std::memory_order_relaxed); + stats.spilledBytes.fetch_add(spilledBytes, std::memory_order_relaxed); + stats.spillFlushTimeNanos.fetch_add(flushTimeNs, std::memory_order_relaxed); + stats.spillWriteTimeNanos.fetch_add(writeTimeNs, std::memory_order_relaxed); +} + +void updateGlobalSpillReadStats( + uint64_t spillReads, + uint64_t spillReadBytes, + uint64_t spillReadTimeNs) { + auto& stats = localSpillStats(); + stats.spillReads.fetch_add(spillReads, std::memory_order_relaxed); + stats.spillReadBytes.fetch_add(spillReadBytes, std::memory_order_relaxed); + stats.spillReadTimeNanos.fetch_add( + spillReadTimeNs, std::memory_order_relaxed); +} + +void updateGlobalSpillMemoryBytes(uint64_t spilledInputBytes) { + RECORD_METRIC_VALUE(kMetricSpilledInputBytes, spilledInputBytes); + localSpillStats().spilledInputBytes.fetch_add( + spilledInputBytes, std::memory_order_relaxed); +} + +void incrementGlobalSpilledFiles() { + RECORD_METRIC_VALUE(kMetricSpilledFilesCount); + localSpillStats().spilledFiles.fetch_add(1, std::memory_order_relaxed); +} + +void updateGlobalMaxSpillLevelExceededCount( + uint64_t maxSpillLevelExceededCount) { + localSpillStats().spillMaxLevelExceededCount.fetch_add( + maxSpillLevelExceededCount, std::memory_order_relaxed); +} + +void updateGlobalSpillDeserializationTimeNs(uint64_t timeNs) { + localSpillStats().spillDeserializationTimeNanos.fetch_add( + timeNs, std::memory_order_relaxed); +} + +SpillStats globalSpillStats() { + SpillStats gSpillStats; + for (auto& stats : allSpillStats()) { + gSpillStats += stats; + } + return gSpillStats; +} +} // namespace facebook::velox::exec diff --git a/velox/common/base/SpillStats.h b/velox/exec/SpillStats.h similarity index 76% rename from velox/common/base/SpillStats.h rename to velox/exec/SpillStats.h index 36e3c7b33e7..7f4c09e2652 100644 --- a/velox/common/base/SpillStats.h +++ b/velox/exec/SpillStats.h @@ -15,60 +15,62 @@ */ #pragma once -#include +#include #include -#include -#include +#include "velox/common/base/Exceptions.h" +#include "velox/common/file/File.h" -#include "velox/common/compression/Compression.h" +namespace facebook::velox::exec { -namespace facebook::velox::common { -/// Provides the fine-grained spill execution stats. +/// Thread-safe spill statistics with atomic members. struct SpillStats { /// The number of times that spilling runs on an operator. - uint64_t spillRuns{0}; - /// The number of bytes in memory to spill - uint64_t spilledInputBytes{0}; + std::atomic_uint64_t spillRuns{0}; + /// The number of bytes in memory to spill. + std::atomic_uint64_t spilledInputBytes{0}; /// The number of bytes spilled to disks. - /// /// NOTE: if compression is enabled, this counts the compressed bytes. - uint64_t spilledBytes{0}; + std::atomic_uint64_t spilledBytes{0}; /// The number of spilled rows. - uint64_t spilledRows{0}; + std::atomic_uint64_t spilledRows{0}; /// NOTE: when we sum up the stats from a group of spill operators, it is /// the total number of spilled partitions X number of operators. - uint32_t spilledPartitions{0}; + std::atomic_uint32_t spilledPartitions{0}; /// The number of spilled files. - uint64_t spilledFiles{0}; + std::atomic_uint64_t spilledFiles{0}; /// The time spent on filling rows for spilling. - uint64_t spillFillTimeNanos{0}; + std::atomic_uint64_t spillFillTimeNanos{0}; /// The time spent on sorting rows for spilling. - uint64_t spillSortTimeNanos{0}; + std::atomic_uint64_t spillSortTimeNanos{0}; /// The time spent on extracting vector from RowContainer for spilling. - uint64_t spillExtractVectorTimeNanos{0}; + std::atomic_uint64_t spillExtractVectorTimeNanos{0}; /// The time spent on serializing rows for spilling. - uint64_t spillSerializationTimeNanos{0}; + std::atomic_uint64_t spillSerializationTimeNanos{0}; /// The number of spill writer flushes, equivalent to number of write calls to /// underlying filesystem. - uint64_t spillWrites{0}; + std::atomic_uint64_t spillWrites{0}; /// The time spent on copy out serialized rows for disk write. If compression /// is enabled, this includes the compression time. - uint64_t spillFlushTimeNanos{0}; + std::atomic_uint64_t spillFlushTimeNanos{0}; /// The time spent on writing spilled rows to disk. - uint64_t spillWriteTimeNanos{0}; + std::atomic_uint64_t spillWriteTimeNanos{0}; /// The number of times that an hash build operator exceeds the max spill /// limit. - uint64_t spillMaxLevelExceededCount{0}; + std::atomic_uint64_t spillMaxLevelExceededCount{0}; /// The number of bytes read from spilled files. - uint64_t spillReadBytes{0}; + std::atomic_uint64_t spillReadBytes{0}; /// The number of spill reader reads, equivalent to the number of read calls /// to the underlying filesystem. - uint64_t spillReads{0}; + std::atomic_uint64_t spillReads{0}; /// The time spent on read data from spilled files. - uint64_t spillReadTimeNanos{0}; + std::atomic_uint64_t spillReadTimeNanos{0}; /// The time spent on deserializing rows read from spilled files. - uint64_t spillDeserializationTimeNanos{0}; + std::atomic_uint64_t spillDeserializationTimeNanos{0}; + /// Filesystem I/O stats for spill operations. + IoStats ioStats; + + SpillStats() = default; SpillStats( uint64_t _spillRuns, @@ -90,24 +92,29 @@ struct SpillStats { uint64_t _spillReadTimeNanos, uint64_t _spillDeserializationTimeNanos); - SpillStats() = default; + SpillStats(const SpillStats& other); - bool empty() const { - return spilledBytes == 0; - } + SpillStats& operator=(const SpillStats& other); SpillStats& operator+=(const SpillStats& other); + SpillStats operator-(const SpillStats& other) const; - bool operator==(const SpillStats& other) const = default; + + bool operator==(const SpillStats& other) const; + + bool empty() const; void reset(); std::string toString() const; + + private: + void copyFrom(const SpillStats& other); }; FOLLY_ALWAYS_INLINE std::ostream& operator<<( std::ostream& o, - const common::SpillStats& stats) { + const SpillStats& stats) { return o << stats.toString(); } @@ -163,12 +170,12 @@ void updateGlobalSpillDeserializationTimeNs(uint64_t timeNs); /// Gets the cumulative global spill stats. SpillStats globalSpillStats(); -} // namespace facebook::velox::common +} // namespace facebook::velox::exec template <> -struct fmt::formatter +struct fmt::formatter : fmt::formatter { - auto format(const facebook::velox::common::SpillStats& s, format_context& ctx) + auto format(const facebook::velox::exec::SpillStats& s, format_context& ctx) const { return formatter::format(s.toString(), ctx); } diff --git a/velox/exec/Spiller.cpp b/velox/exec/Spiller.cpp index 35352aed980..65de8a906ce 100644 --- a/velox/exec/Spiller.cpp +++ b/velox/exec/Spiller.cpp @@ -37,7 +37,7 @@ SpillerBase::SpillerBase( uint64_t maxSpillRunRows, std::optional parentId, const common::SpillConfig* spillConfig, - folly::Synchronized* spillStats) + exec::SpillStats* spillStats) : container_(container), executor_(spillConfig->executor), bits_(bits), @@ -101,8 +101,9 @@ bool SpillerBase::fillSpillRuns(RowContainerIterator* iterator) { uint64_t totalRows{0}; for (;;) { - const auto numRows = container_->listRows( - iterator, rows.size(), RowContainer::kUnlimited, rows.data()); + // TODO: Reuse 'RowContainer::rowPointers_'. + const auto numRows = + container_->listRows(iterator, rows.size(), rows.data()); if (numRows == 0) { lastRun = true; break; @@ -143,7 +144,7 @@ bool SpillerBase::fillSpillRuns(RowContainerIterator* iterator) { } void SpillerBase::runSpill(bool lastRun) { - ++spillStats_->wlock()->spillRuns; + spillStats_->spillRuns.fetch_add(1, std::memory_order_relaxed); std::vector>> writes; for (const auto& [id, spillRun] : spillRuns_) { @@ -154,8 +155,9 @@ void SpillerBase::runSpill(bool lastRun) { if (spillRun.rows.empty()) { continue; } - writes.push_back(memory::createAsyncMemoryReclaimTask( - [partitionId = id, this]() { return writeSpill(partitionId); })); + writes.push_back( + memory::createAsyncMemoryReclaimTask( + [partitionId = id, this]() { return writeSpill(partitionId); })); if ((writes.size() > 1) && executor_ != nullptr) { executor_->add([source = writes.back()]() { source->prepare(); }); } @@ -307,13 +309,14 @@ void SpillerBase::extractSpill( } void SpillerBase::updateSpillExtractVectorTime(uint64_t timeNs) { - spillStats_->wlock()->spillExtractVectorTimeNanos += timeNs; - common::updateGlobalSpillExtractVectorTime(timeNs); + spillStats_->spillExtractVectorTimeNanos.fetch_add( + timeNs, std::memory_order_relaxed); + updateGlobalSpillExtractVectorTime(timeNs); } void SpillerBase::updateSpillSortTime(uint64_t timeNs) { - spillStats_->wlock()->spillSortTimeNanos += timeNs; - common::updateGlobalSpillSortTime(timeNs); + spillStats_->spillSortTimeNanos.fetch_add(timeNs, std::memory_order_relaxed); + updateGlobalSpillSortTime(timeNs); } void SpillerBase::checkEmptySpillRuns() const { @@ -326,8 +329,8 @@ void SpillerBase::checkEmptySpillRuns() const { } void SpillerBase::updateSpillFillTime(uint64_t timeNs) { - spillStats_->wlock()->spillFillTimeNanos += timeNs; - common::updateGlobalSpillFillTime(timeNs); + spillStats_->spillFillTimeNanos.fetch_add(timeNs, std::memory_order_relaxed); + updateGlobalSpillFillTime(timeNs); } void SpillerBase::finishSpill(SpillPartitionSet& partitionSet) { @@ -352,8 +355,8 @@ void SpillerBase::finishSpill(SpillPartitionSet& partitionSet) { } } -common::SpillStats SpillerBase::stats() const { - return spillStats_->copy(); +exec::SpillStats SpillerBase::stats() const { + return *spillStats_; } std::string SpillerBase::toString() const { @@ -387,7 +390,7 @@ NoRowContainerSpiller::NoRowContainerSpiller( std::optional parentId, HashBitRange bits, const common::SpillConfig* spillConfig, - folly::Synchronized* spillStats) + exec::SpillStats* spillStats) : NoRowContainerSpiller( std::move(rowType), parentId, @@ -402,7 +405,7 @@ NoRowContainerSpiller::NoRowContainerSpiller( HashBitRange bits, const std::vector& sortingKeys, const common::SpillConfig* spillConfig, - folly::Synchronized* spillStats) + exec::SpillStats* spillStats) : SpillerBase( nullptr, std::move(rowType), @@ -435,7 +438,7 @@ SortOutputSpiller::SortOutputSpiller( RowContainer* container, RowTypePtr rowType, const common::SpillConfig* spillConfig, - folly::Synchronized* spillStats) + exec::SpillStats* spillStats) : SpillerBase( container, std::move(rowType), diff --git a/velox/exec/Spiller.h b/velox/exec/Spiller.h index 9151a7eaf53..f651338fcea 100644 --- a/velox/exec/Spiller.h +++ b/velox/exec/Spiller.h @@ -17,6 +17,7 @@ #include "velox/common/base/SpillConfig.h" #include "velox/common/compression/Compression.h" +#include "velox/common/file/FileSystems.h" #include "velox/exec/HashBitRange.h" #include "velox/exec/RowContainer.h" @@ -46,7 +47,7 @@ class SpillerBase { return finalized_; } - common::SpillStats stats() const; + exec::SpillStats stats() const; std::string toString() const; @@ -60,7 +61,7 @@ class SpillerBase { uint64_t maxSpillRunRows, std::optional parentId, const common::SpillConfig* spillConfig, - folly::Synchronized* spillStats); + exec::SpillStats* spillStats); // Invoked to spill. If 'startRowIter' is not null, then we only spill rows // from row container starting at the offset pointed by 'startRowIter'. @@ -148,7 +149,7 @@ class SpillerBase { const std::optional parentId_; - folly::Synchronized* const spillStats_; + exec::SpillStats* const spillStats_; const std::vector compareFlags_; @@ -207,7 +208,7 @@ class NoRowContainerSpiller : public SpillerBase { std::optional parentId, HashBitRange bits, const common::SpillConfig* spillConfig, - folly::Synchronized* spillStats); + exec::SpillStats* spillStats); void spill( const SpillPartitionId& partitionId, @@ -226,7 +227,7 @@ class NoRowContainerSpiller : public SpillerBase { HashBitRange bits, const std::vector& sortingKeys, const common::SpillConfig* spillConfig, - folly::Synchronized* spillStats); + exec::SpillStats* spillStats); private: std::string type() const override { @@ -246,7 +247,7 @@ class MergeSpiller final : public NoRowContainerSpiller { HashBitRange bits, const std::vector& sortingKeys, const common::SpillConfig* spillConfig, - folly::Synchronized* spillStats) + exec::SpillStats* spillStats) : NoRowContainerSpiller( std::move(rowType), parentId, @@ -265,7 +266,7 @@ class SortInputSpiller : public SpillerBase { RowTypePtr rowType, const std::vector& sortingKeys, const common::SpillConfig* spillConfig, - folly::Synchronized* spillStats) + exec::SpillStats* spillStats) : SpillerBase( container, std::move(rowType), @@ -297,7 +298,7 @@ class SortOutputSpiller : public SpillerBase { RowContainer* container, RowTypePtr rowType, const common::SpillConfig* spillConfig, - folly::Synchronized* spillStats); + exec::SpillStats* spillStats); void spill(SpillRows& rows); diff --git a/velox/exec/Split.h b/velox/exec/Split.h index 1ccdcaeba16..58913feb951 100644 --- a/velox/exec/Split.h +++ b/velox/exec/Split.h @@ -15,7 +15,11 @@ */ #pragma once +#include +#include + #include "velox/connectors/Connector.h" +#include "velox/exec/BarrierSplit.h" namespace facebook::velox::exec { @@ -23,10 +27,8 @@ struct Split { std::shared_ptr connectorSplit{nullptr}; int32_t groupId{-1}; // Bucketed group id (-1 means 'none'). - /// Indicates if this is a barrier split. A barrier split is used by task - /// barrier processing which adds one barrier split to each leaf source node - /// to signal the output drain processing. - bool barrier{false}; + /// Indicates if this is a barrier split. + std::optional barrier; Split() = default; @@ -36,9 +38,10 @@ struct Split { : connectorSplit(std::move(connectorSplit)), groupId(groupId) {} /// Called by the task barrier to create a special barrier split. - static Split createBarrier() { - static Split barrierSplit; - barrierSplit.barrier = true; + static Split createBarrier(uint32_t numDrivers = 1) { + Split barrierSplit; + barrierSplit.barrier = BarrierSplit{numDrivers}; + VELOX_CHECK_NULL(barrierSplit.connectorSplit); return barrierSplit; } @@ -48,7 +51,7 @@ struct Split { Split& operator=(const Split& other) = default; bool isBarrier() const { - return barrier; + return barrier.has_value(); } inline bool hasConnectorSplit() const { diff --git a/velox/exec/StreamingAggregation.cpp b/velox/exec/StreamingAggregation.cpp index 873e2e853ce..ef21af360f6 100644 --- a/velox/exec/StreamingAggregation.cpp +++ b/velox/exec/StreamingAggregation.cpp @@ -15,6 +15,11 @@ */ #include "velox/exec/StreamingAggregation.h" +#include "velox/common/testutil/TestValue.h" + +using facebook::velox::common::testutil::TestValue; + +#include "velox/exec/OperatorType.h" namespace facebook::velox::exec { @@ -28,8 +33,8 @@ StreamingAggregation::StreamingAggregation( operatorId, aggregationNode->id(), aggregationNode->step() == core::AggregationNode::Step::kPartial - ? "PartialAggregation" - : "Aggregation"), + ? OperatorType::kPartialAggregation + : OperatorType::kAggregation), maxOutputBatchSize_{outputBatchRows()}, minOutputBatchSize_{ operatorCtx_->driverCtx() @@ -41,8 +46,11 @@ StreamingAggregation::StreamingAggregation( ->queryConfig() .streamingAggregationMinOutputBatchRows()) : maxOutputBatchSize_}, + maxOutputBatchBytes_{ + operatorCtx_->driverCtx()->queryConfig().preferredOutputBatchBytes()}, aggregationNode_{aggregationNode}, - step_{aggregationNode->step()} { + step_{aggregationNode->step()}, + noGroupsSpanBatches_{aggregationNode_->noGroupsSpanBatches()} { if (aggregationNode_->ignoreNullKeys()) { VELOX_UNSUPPORTED( "Streaming aggregation doesn't support ignoring null keys yet"); @@ -180,6 +188,14 @@ RowVectorPtr StreamingAggregation::createOutput(size_t numGroups) { } else { function->extractValues(groups_.data(), numGroups, &result); } + + // Clear any state the aggregations may be holding onto and return them + // to a valid initial state. + function->destroy(folly::Range(groups_.data(), numGroups)); + std::vector newGroups; + newGroups.resize(numGroups); + std::iota(newGroups.begin(), newGroups.end(), 0); + function->initializeNewGroups(groups_.data(), newGroups); } if (sortedAggregations_) { @@ -194,6 +210,9 @@ RowVectorPtr StreamingAggregation::createOutput(size_t numGroups) { } } + TestValue::adjust( + "facebook::velox::exec::StreamingAggregation::createOutput", this); + std::rotate(groups_.begin(), groups_.begin() + numGroups, groups_.end()); numGroups_ -= numGroups; @@ -356,23 +375,42 @@ RowVectorPtr StreamingAggregation::getOutput() { initializeNewGroups(numPrevGroups); evaluateAggregates(); + const auto estimatedRowBytes = rows_->estimateRowSize(); + const auto estimatedBatchBytes = + estimatedRowBytes.value_or(0) * rows_->numRows(); + RowVectorPtr output; - if ((numPrevGroups != 0) && (numGroups_ > minOutputBatchSize_)) { + // we do not respect minOutputBatchRows or outputDueToBatchBytes + // when noGroupsSpanBatches is set + const bool outputDueToBatchSize = numGroups_ > minOutputBatchSize_; + const bool outputDueToBatchBytes = + numGroups_ > 1 && estimatedBatchBytes > maxOutputBatchBytes_; + if (noGroupsSpanBatches_ || + (numPrevGroups > 0 && (outputDueToBatchSize || outputDueToBatchBytes))) { size_t numOutputGroups{0}; - // NOTE: we only want to apply the single group output optimization if - // 'minOutputBatchSize_' is set to one for eagerly streaming output - // producing. - if (!prevGroupAssigned || numPrevGroups == 1 || minOutputBatchSize_ != 1) { - numOutputGroups = std::min(numGroups_ - 1, numPrevGroups); + if (noGroupsSpanBatches_) { + numOutputGroups = numGroups_; } else { - numOutputGroups = std::min(numGroups_ - 1, numPrevGroups - 1); - outputFirstGroup_ = (numGroups_ - numOutputGroups) > 1; + // NOTE: we only want to apply the single group output optimization if + // 'minOutputBatchSize_' is set to one for eagerly streaming output + // producing. + if (!prevGroupAssigned || numPrevGroups == 1 || + minOutputBatchSize_ != 1) { + numOutputGroups = std::min(numGroups_ - 1, numPrevGroups); + } else { + numOutputGroups = std::min(numGroups_ - 1, numPrevGroups - 1); + outputFirstGroup_ = (numGroups_ - numOutputGroups) > 1; + } } VELOX_CHECK_GT(numOutputGroups, 0); output = createOutput(numOutputGroups); } prevInput_ = input_; + if (numGroups_ == 0) { + VELOX_CHECK(noGroupsSpanBatches_); + prevInput_ = nullptr; + } input_ = nullptr; return output; } @@ -404,6 +442,8 @@ std::unique_ptr StreamingAggregation::makeRowContainer( false, false, false, + false, // hasCountFlag + false, false, pool()); } diff --git a/velox/exec/StreamingAggregation.h b/velox/exec/StreamingAggregation.h index d01e8e36274..5fcf5d4b729 100644 --- a/velox/exec/StreamingAggregation.h +++ b/velox/exec/StreamingAggregation.h @@ -96,11 +96,20 @@ class StreamingAggregation : public Operator { // Maximum number of rows in the output batch. const vector_size_t minOutputBatchSize_; + // If the size of the data in the RowContainer exceeds this value, we will + // output a batch regardless of the number of rows. + const uint64_t maxOutputBatchBytes_; + // Used at initialize() and gets reset() afterward. std::shared_ptr aggregationNode_; const core::AggregationNode::Step step_; + // When true, indicates that no sort group spans across input batches. Each + // input batch contains complete data for its groups. This allows the + // streaming aggregation operator to produce all group results for each input. + const bool noGroupsSpanBatches_; + std::vector groupingKeys_; std::vector aggregates_; std::unique_ptr sortedAggregations_; diff --git a/velox/exec/StreamingEnforceDistinct.cpp b/velox/exec/StreamingEnforceDistinct.cpp new file mode 100644 index 00000000000..e19c6bbe01d --- /dev/null +++ b/velox/exec/StreamingEnforceDistinct.cpp @@ -0,0 +1,117 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/StreamingEnforceDistinct.h" + +#include "velox/exec/OperatorType.h" + +namespace facebook::velox::exec { + +namespace { +// Compares two rows in the same or different vectors and returns true if they +// match in all key columns. +bool equalKeys( + const std::vector& keyChannels, + const RowVectorPtr& batch, + vector_size_t index, + const RowVectorPtr& otherBatch, + vector_size_t otherIndex) { + for (auto channel : keyChannels) { + if (!batch->childAt(channel)->equalValueAt( + otherBatch->childAt(channel).get(), index, otherIndex)) { + return false; + } + } + return true; +} +} // namespace + +StreamingEnforceDistinct::StreamingEnforceDistinct( + int32_t operatorId, + DriverCtx* driverCtx, + const std::shared_ptr& planNode) + : Operator( + driverCtx, + planNode->outputType(), + operatorId, + planNode->id(), + OperatorType::kStreamingEnforceDistinct), + inputType_{planNode->sources()[0]->outputType()}, + keyChannels_{toChannels( + inputType_, + std::vector{ + planNode->distinctKeys().begin(), + planNode->distinctKeys().end()})}, + errorMessage_{planNode->errorMessage()} { + for (auto i = 0; i < inputType_->size(); ++i) { + identityProjections_.emplace_back(i, i); + } +} + +void StreamingEnforceDistinct::addInput(RowVectorPtr input) { + if (input->size() == 0) { + return; + } + + // Check first row against previous batch's last row. + if (prevKeyValues_ != nullptr && + equalKeys(keyChannels_, input, 0, prevKeyValues_, 0)) { + VELOX_USER_FAIL("{}", errorMessage_); + } + + // Check consecutive rows within this batch. + for (vector_size_t i = 1; i < input->size(); ++i) { + if (equalKeys(keyChannels_, input, i, input, i - 1)) { + VELOX_USER_FAIL("{}", errorMessage_); + } + } + + // Save key values from the last row for comparison with next batch. + + if (prevKeyValues_ == nullptr) { + std::vector keyVectors(inputType_->size()); + for (auto channel : keyChannels_) { + keyVectors[channel] = + BaseVector::create(inputType_->childAt(channel), 1, pool()); + } + prevKeyValues_ = std::make_shared( + pool(), inputType_, nullptr, 1, std::move(keyVectors)); + } + + const auto lastRow = input->size() - 1; + for (auto channel : keyChannels_) { + prevKeyValues_->childAt(channel)->copy( + input->childAt(channel).get(), 0, lastRow, 1); + } + + input_ = std::move(input); +} + +RowVectorPtr StreamingEnforceDistinct::getOutput() { + if (isFinished() || !input_) { + return nullptr; + } + + auto output = fillOutput(input_->size(), nullptr); + input_ = nullptr; + return output; +} + +bool StreamingEnforceDistinct::isFinished() { + return noMoreInput_ && !input_; +} + +} // namespace facebook::velox::exec diff --git a/velox/exec/StreamingEnforceDistinct.h b/velox/exec/StreamingEnforceDistinct.h new file mode 100644 index 00000000000..5f22c098712 --- /dev/null +++ b/velox/exec/StreamingEnforceDistinct.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/exec/Operator.h" + +namespace facebook::velox::exec { + +/// Streaming implementation of EnforceDistinct for pre-grouped input. +/// Compares each row with the previous row to detect duplicates. +/// Memory usage is O(1) - only stores the previous row's key values. +/// +/// Use this operator when input is clustered on distinct keys, i.e., rows with +/// the same key values are guaranteed to be adjacent. +class StreamingEnforceDistinct : public Operator { + public: + StreamingEnforceDistinct( + int32_t operatorId, + DriverCtx* driverCtx, + const std::shared_ptr& planNode); + + bool preservesOrder() const override { + return true; + } + + bool needsInput() const override { + return !noMoreInput_ && !input_; + } + + void addInput(RowVectorPtr input) override; + + RowVectorPtr getOutput() override; + + BlockingReason isBlocked(ContinueFuture* /*future*/) override { + return BlockingReason::kNotBlocked; + } + + bool isFinished() override; + + private: + const RowTypePtr inputType_; + const std::vector keyChannels_; + const std::string errorMessage_; + + // Key values from the last row of the previous batch for cross-batch + // comparison. Lazily initialized on first input batch. + RowVectorPtr prevKeyValues_; +}; + +} // namespace facebook::velox::exec diff --git a/velox/exec/SubPartitionedSortWindowBuild.cpp b/velox/exec/SubPartitionedSortWindowBuild.cpp new file mode 100644 index 00000000000..2f2a247a8d4 --- /dev/null +++ b/velox/exec/SubPartitionedSortWindowBuild.cpp @@ -0,0 +1,192 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/SubPartitionedSortWindowBuild.h" +#include "velox/exec/MemoryReclaimer.h" + +namespace facebook::velox::exec { + +SubPartitionedSortWindowBuild::SubPartitionedSortWindowBuild( + const std::shared_ptr& node, + int32_t numSubPartitions, + velox::memory::MemoryPool* pool, + common::PrefixSortConfig&& prefixSortConfig, + const common::SpillConfig* spillConfig, + tsan_atomic* nonReclaimableSection, + folly::Synchronized* opStats, + exec::SpillStats* spillStats) + : WindowBuild(node, pool, spillConfig, nonReclaimableSection), + numSubPartitions_(numSubPartitions), + numPartitionKeys_{node->partitionKeys().size()}, + pool_(pool), + spillStats_(spillStats) { + VELOX_CHECK_NOT_NULL(pool_); + data_.reset(); + + std::vector keyChannels(numPartitionKeys_); + for (int i = 0; i < numPartitionKeys_; i++) { + keyChannels[i] = inputChannels_[i]; + } + subPartitioningFunction_ = std::make_unique( + false, numSubPartitions_, node->inputType(), keyChannels); + subWindowBuilds_.resize(numSubPartitions_); + for (int i = 0; i < numSubPartitions_; i++) { + subWindowBuilds_[i] = std::make_unique( + node, + pool, + common::PrefixSortConfig(prefixSortConfig), + spillConfig, + nonReclaimableSection, + opStats, + spillStats); + } +} + +void SubPartitionedSortWindowBuild::addInput(RowVectorPtr input) { + VELOX_CHECK_LT(currentSubPartition_, 0); + + subPartitionIdsBuffer_.resize(input->size()); + subPartitioningFunction_->partition(*input, subPartitionIdsBuffer_); + + for (auto i = 0; i < inputChannels_.size(); ++i) { + decodedInputVectors_[i].decode(*input->childAt(inputChannels_[i])); + } + + ensureInputFits(input); + + for (auto row = 0; row < input->size(); ++row) { + auto& windowBuild = subWindowBuilds_[subPartitionIdsBuffer_[row]]; + windowBuild->addDecodedInputRow(decodedInputVectors_, row); + } + + numRows_ += input->size(); +} + +bool SubPartitionedSortWindowBuild::switchToNextSubPartition() { + if (currentSubPartition_ >= numSubPartitions_) { + return false; + } + + if (currentSubPartition_ >= 0) { + subWindowBuilds_[currentSubPartition_].reset(); + } + currentSubPartition_++; + if (currentSubPartition_ >= numSubPartitions_) { + return false; + } + + VELOX_CHECK_NOT_NULL(subWindowBuilds_[currentSubPartition_]); + // WindowBuild starts processing the partitions when 'noMoreInput' is called, + // which allocates additional memory. We want to defer the memory allocation + // as late as possible to reduce memory usage, so we don't call 'noMoreInput' + // until the sub partition's data is to be consumed. + subWindowBuilds_[currentSubPartition_]->noMoreInput(); + return true; +} + +void SubPartitionedSortWindowBuild::ensureInputFits(const RowVectorPtr& input) { + if (spillConfig_ == nullptr) { + // Spilling is disabled. + return; + } + + if (numRows_ == 0) { + // Nothing to spill. + return; + } + + // Test-only spill path. + if (testingTriggerSpill(pool_->name())) { + spill(); + return; + } + + VELOX_CHECK_LT(currentSubPartition_, 0); + for (auto& windowBuild : subWindowBuilds_) { + windowBuild->ensureInputFits(input); + } +} + +void SubPartitionedSortWindowBuild::spill() { + VELOX_CHECK_LT(currentSubPartition_, 0); + for (auto& windowBuild : subWindowBuilds_) { + windowBuild->spill(); + } + spilled_ = true; +} + +std::optional SubPartitionedSortWindowBuild::spilledStats() + const { + if (!spilled_) { + return std::nullopt; + } + return {*spillStats_}; +} + +void SubPartitionedSortWindowBuild::noMoreInput() { + if (numRows_ == 0) { + return; + } + + if (spilled_) { + // Spill remaining data to avoid running out of memory while sort-merging + // spilled data. + spill(); + } + + switchToNextSubPartition(); + + VELOX_CHECK_EQ(currentSubPartition_, 0); +} + +std::shared_ptr +SubPartitionedSortWindowBuild::nextPartition() { + VELOX_CHECK_GE(currentSubPartition_, 0); + VELOX_CHECK_LT(currentSubPartition_, numSubPartitions_); + VELOX_CHECK_NOT_NULL(subWindowBuilds_[currentSubPartition_]); + return subWindowBuilds_[currentSubPartition_]->nextPartition(); +} + +std::optional SubPartitionedSortWindowBuild::estimateRowSize() { + auto subPartition = std::max(currentSubPartition_, 0); + if (subPartition >= numSubPartitions_) { + return std::nullopt; + } + + if (subWindowBuilds_[subPartition]) { + return subWindowBuilds_[subPartition]->estimateRowSize(); + } + + return std::nullopt; +} + +bool SubPartitionedSortWindowBuild::hasNextPartition() { + // Check if the build hasn't begun or has finished. + if (currentSubPartition_ < 0 || currentSubPartition_ >= numSubPartitions_) { + return false; + } + + VELOX_CHECK_NOT_NULL(subWindowBuilds_[currentSubPartition_]); + if (subWindowBuilds_[currentSubPartition_]->hasNextPartition()) { + return true; + } + + if (switchToNextSubPartition()) { + return hasNextPartition(); + } + return false; +} +} // namespace facebook::velox::exec diff --git a/velox/exec/SubPartitionedSortWindowBuild.h b/velox/exec/SubPartitionedSortWindowBuild.h new file mode 100644 index 00000000000..8735f438d30 --- /dev/null +++ b/velox/exec/SubPartitionedSortWindowBuild.h @@ -0,0 +1,95 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/exec/HashPartitionFunction.h" +#include "velox/exec/PrefixSort.h" +#include "velox/exec/SortWindowBuild.h" +#include "velox/exec/Spiller.h" + +namespace facebook::velox::exec { +// Divides the input data into several sub partitions by partition keys, then +// sequentially sorts input data of each sub partition by {partition keys, sort +// keys} to identify window partitions with SortWindowBuild. As each sub +// partition has a smaller working set, the memory used by sorting is reduced. +// Besides, once a sub partition is completely consumed, its memory could be +// released immediately. +class SubPartitionedSortWindowBuild : public WindowBuild { + public: + SubPartitionedSortWindowBuild( + const std::shared_ptr& node, + int32_t numSubPartitions, + velox::memory::MemoryPool* pool, + common::PrefixSortConfig&& prefixSortConfig, + const common::SpillConfig* spillConfig, + tsan_atomic* nonReclaimableSection, + folly::Synchronized* opStats, + exec::SpillStats* spillStats); + + ~SubPartitionedSortWindowBuild() override { + pool_->release(); + } + + bool needsInput() override { + // No sub partitions are available yet, so can consume input rows. + return currentSubPartition_ < 0; + } + + void addInput(RowVectorPtr input) override; + + void spill() override; + + std::optional spilledStats() const override; + + void noMoreInput() override; + + bool hasNextPartition() override; + + std::shared_ptr nextPartition() override; + + std::optional estimateRowSize() override; + + private: + // The current sub partition's WindowBuild has finished producing all the + // data. Release all the memory of current sub partition's WindowBuild, and + // then switch to next sub partition's WindowBuild as the new current one. + bool switchToNextSubPartition(); + + void ensureInputFits(const RowVectorPtr& input); + + const int32_t numSubPartitions_; + + const size_t numPartitionKeys_; + + memory::MemoryPool* const pool_; + + exec::SpillStats* const spillStats_; + + // Divide input rows to the corresponding sub partitions. + std::unique_ptr subPartitioningFunction_; + + // WindowBuilds for each sub partition. + std::vector> subWindowBuilds_; + + bool spilled_{false}; + + // Buffers the subPartitionIds for each row. Reused across addInput calls. + std::vector subPartitionIdsBuffer_; + + int32_t currentSubPartition_ = -1; +}; +} // namespace facebook::velox::exec diff --git a/velox/exec/TableScan.cpp b/velox/exec/TableScan.cpp index 9157e876d5f..de261dda455 100644 --- a/velox/exec/TableScan.cpp +++ b/velox/exec/TableScan.cpp @@ -16,6 +16,8 @@ #include "velox/exec/TableScan.h" #include "velox/common/testutil/TestValue.h" #include "velox/common/time/Timer.h" +#include "velox/connectors/ConnectorRegistry.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/Task.h" using facebook::velox::common::testutil::TestValue; @@ -76,7 +78,7 @@ TableScan::TableScan( tableScanNode->outputType(), operatorId, tableScanNode->id(), - "TableScan"), + OperatorType::kTableScan), tableHandle_(tableScanNode->tableHandle()), columnHandles_(tableScanNode->assignments()), driverCtx_(driverCtx), @@ -89,9 +91,14 @@ TableScan::TableScan( driverCtx_->driverId, operatorType(), tableHandle_->connectorId())), - connector_(connector::getConnector(tableHandle_->connectorId())), + connector_( + connector::ConnectorRegistry::tryGet( + *driverCtx->task->queryCtx(), + tableHandle_->connectorId())), getOutputTimeLimitMs_( driverCtx_->queryConfig().tableScanGetOutputTimeLimitMs()), + outputBatchRowsOverride_( + driverCtx_->queryConfig().tableScanOutputBatchRowsOverride()), scaledController_(driverCtx_->task->getScaledScanControllerLocked( driverCtx_->splitGroupId, planNodeId())) { @@ -171,6 +178,7 @@ RowVectorPtr TableScan::getOutput() { }, &debugString_}); + checkPreload(); if (needNewSplit_) { const auto hasNewSplit = getSplit(); if (!hasNewSplit) { @@ -180,21 +188,15 @@ RowVectorPtr TableScan::getOutput() { } continue; } - const auto estimatedRowSize = dataSource_->estimatedRowSize(); - readBatchSize_ = - estimatedRowSize == connector::DataSource::kUnknownRowSize - ? outputBatchRows() - : outputBatchRows(estimatedRowSize); } VELOX_CHECK(!needNewSplit_); VELOX_CHECK(!hasDrained()); - int32_t readBatchSize = readBatchSize_; - if (maxFilteringRatio_ > 0) { - readBatchSize = std::min( - maxReadBatchSize_, - static_cast(readBatchSize / maxFilteringRatio_)); - } + const auto estimatedRowSize = dataSource_->estimatedRowSize(); + const int32_t readBatchSize = calculateBatchSize(estimatedRowSize); + + const auto prevCompletedRows = dataSource_->getCompletedRows(); + uint64_t ioTimeUs{0}; std::optional dataOptional; { @@ -203,11 +205,10 @@ RowVectorPtr TableScan::getOutput() { dataOptional = dataSource_->next(readBatchSize, blockingFuture_); } - checkPreload(); { auto lockedStats = stats_.wlock(); lockedStats->addRuntimeStat( - "dataSourceReadWallNanos", + std::string(TableScan::kDataSourceReadWallNanos), RuntimeCounter(ioTimeUs * 1'000, RuntimeCounter::Unit::kNanos)); if (!dataOptional.has_value()) { @@ -225,7 +226,19 @@ RowVectorPtr TableScan::getOutput() { if (data != nullptr && !shouldDropOutput()) { constexpr int kMaxSelectiveBatchSizeMultiplier = 4; if (data->size() > 0) { - lockedStats->addInputVector(data->estimateFlatSize(), data->size()); + uint64_t flatSize = 0; + if (driverCtx_->driver->enableOperatorBatchSizeStats()) { + flatSize = data->estimateFlatSize(); + } + lockedStats->addInputVector(flatSize, data->size()); + const auto completedRowsDelta = + dataSource_->getCompletedRows() - prevCompletedRows; + if (completedRowsDelta > 0) { + core::ScanBatchEvent event; + event.numRows = completedRowsDelta; + event.wallTimeMicros = ioTimeUs; + dataSource_->fireScanBatchCallback(event); + } maxFilteringRatio_ = std::max( {maxFilteringRatio_, 1.0 * data->size() / readBatchSize, @@ -234,8 +247,9 @@ RowVectorPtr TableScan::getOutput() { RECORD_METRIC_VALUE( velox::kMetricTableScanBatchProcessTimeMs, ioTimeUs / 1'000); } - RECORD_METRIC_VALUE( - velox::kMetricTableScanBatchBytes, data->estimateFlatSize()); + if (driverCtx_->driver->enableOperatorBatchSizeStats()) { + RECORD_METRIC_VALUE(velox::kMetricTableScanBatchBytes, flatSize); + } return data; } else { maxFilteringRatio_ = std::max( @@ -250,12 +264,14 @@ RowVectorPtr TableScan::getOutput() { auto lockedStats = stats_.wlock(); if (numPreloadedSplits_ > 0) { lockedStats->addRuntimeStat( - "preloadedSplits", RuntimeCounter(numPreloadedSplits_)); + std::string(TableScan::kPreloadedSplits), + RuntimeCounter(numPreloadedSplits_)); numPreloadedSplits_ = 0; } if (numReadyPreloadedSplits_ > 0) { lockedStats->addRuntimeStat( - "readyPreloadedSplits", RuntimeCounter(numReadyPreloadedSplits_)); + std::string(TableScan::kReadyPreloadedSplits), + RuntimeCounter(numReadyPreloadedSplits_)); numReadyPreloadedSplits_ = 0; } currNumRawInputRows = lockedStats->rawInputPositions; @@ -285,12 +301,13 @@ bool TableScan::getSplit() { exec::Split split; blockingReason_ = driverCtx_->task->getSplitOrFuture( + driverCtx_->driverId, driverCtx_->splitGroupId, planNodeId(), - split, - blockingFuture_, maxPreloadedSplits_, - splitPreloader_); + splitPreloader_, + split, + blockingFuture_); if (blockingReason_ != BlockingReason::kNotBlocked) { return false; } @@ -303,15 +320,15 @@ bool TableScan::getSplit() { if (!split.hasConnectorSplit()) { noMoreSplits_ = true; if (dataSource_) { - const auto connectorStats = dataSource_->runtimeStats(); + const auto connectorStats = dataSource_->getRuntimeStats(); auto lockedStats = stats_.wlock(); - for (const auto& [name, counter] : connectorStats) { + for (const auto& [name, metric] : connectorStats) { if (FOLLY_UNLIKELY(lockedStats->runtimeStats.count(name) == 0)) { - lockedStats->runtimeStats.emplace(name, RuntimeMetric(counter.unit)); + lockedStats->runtimeStats.emplace(name, RuntimeMetric(metric.unit)); } else { - VELOX_CHECK_EQ(lockedStats->runtimeStats.at(name).unit, counter.unit); + VELOX_CHECK_EQ(lockedStats->runtimeStats.at(name).unit, metric.unit); } - lockedStats->runtimeStats.at(name).addValue(counter.value); + lockedStats->runtimeStats.at(name).merge(metric); } } return false; @@ -322,7 +339,8 @@ bool TableScan::getSplit() { } stats_.wlock()->addRuntimeStat( - "connectorSplitSize", RuntimeCounter(split.connectorSplit->size())); + std::string(TableScan::kConnectorSplitSize), + RuntimeCounter(static_cast(split.connectorSplit->size()))); const auto& connectorSplit = split.connectorSplit; currentSplitWeight_ = connectorSplit->splitWeight; needNewSplit_ = false; @@ -346,6 +364,10 @@ bool TableScan::getSplit() { tableHandle_, columnHandles_, connectorQueryCtx_.get()); + if (const auto& callback = + operatorCtx_->driverCtx()->task->queryCtx()->scanBatchCallback()) { + dataSource_->setScanBatchCallback(callback); + } } debugString_ = fmt::format( @@ -368,18 +390,18 @@ bool TableScan::getSplit() { auto preparedDataSource = connectorSplit->dataSource->move(); auto endTimeNs = getCurrentTimeNano(); stats_.wlock()->addRuntimeStat( - "waitForPreloadSplitNanos", + std::string(TableScan::kWaitForPreloadSplitNanos), RuntimeCounter(endTimeNs - startTimeNs, RuntimeCounter::Unit::kNanos)); - stats_.wlock()->addRuntimeStat( - "preloadSplitPrepareTimeNanos", - RuntimeCounter( - connectorSplit->dataSource->prepareTiming().wallNanos, - RuntimeCounter::Unit::kNanos)); - if (!preparedDataSource) { + if (preparedDataSource == nullptr) { // There must be a cancellation. VELOX_CHECK(operatorCtx_->task()->isCancelled()); return false; } + stats_.wlock()->addRuntimeStat( + std::string(TableScan::kPreloadSplitPrepareTimeNanos), + RuntimeCounter( + connectorSplit->dataSource->prepareTiming().wallNanos, + RuntimeCounter::Unit::kNanos)); dataSource_->setFromDataSource(std::move(preparedDataSource)); } else { uint64_t addSplitTimeUs{0}; @@ -389,7 +411,7 @@ bool TableScan::getSplit() { dataSource_->addSplit(connectorSplit); } stats_.wlock()->addRuntimeStat( - "dataSourceAddSplitWallNanos", + std::string(TableScan::kDataSourceAddSplitWallNanos), RuntimeCounter(addSplitTimeUs * 1'000, RuntimeCounter::Unit::kNanos)); } ++stats_.wlock()->numSplits; @@ -471,21 +493,18 @@ void TableScan::checkPreload() { !connector_->supportsSplitPreload()) { return; } - if (dataSource_->allPrefetchIssued()) { - maxPreloadedSplits_ = driverCtx_->task->numDrivers(driverCtx_->driver) * - maxSplitPreloadPerDriver_; - if (!splitPreloader_) { - splitPreloader_ = - [ioExecutor, - this](const std::shared_ptr& split) { - preload(split); - - ioExecutor->add([connectorSplit = split]() mutable { - connectorSplit->dataSource->prepare(); - connectorSplit.reset(); - }); - }; - } + maxPreloadedSplits_ = driverCtx_->task->numDrivers(driverCtx_->driver) * + maxSplitPreloadPerDriver_; + if (!splitPreloader_) { + splitPreloader_ = + [ioExecutor, + this](const std::shared_ptr& split) { + preload(split); + ioExecutor->add([connectorSplit = split]() mutable { + connectorSplit->dataSource->prepare(); + connectorSplit.reset(); + }); + }; } } @@ -504,6 +523,35 @@ void TableScan::addDynamicFilterLocked( stats_.wlock()->dynamicFilterStats.producerNodeIds.emplace(producer); } +int32_t TableScan::calculateBatchSize(int64_t currentEstimatedRowSize) { + if (outputBatchRowsOverride_ > 0) { + return outputBatchRowsOverride_; + } + int64_t estimatedRowSize = connector::DataSource::kUnknownRowSize; + if (currentEstimatedRowSize != connector::DataSource::kUnknownRowSize) { + // Use current file estimate. + fileEstimatedRowSize_ = currentEstimatedRowSize; + estimatedRowSize = currentEstimatedRowSize; + } else if (fileEstimatedRowSize_ != connector::DataSource::kUnknownRowSize) { + // Fallback to previous file estimate. + estimatedRowSize = fileEstimatedRowSize_; + } + // Otherwise, no estimate available: use preferredOutputBatchRows() + // (readBatchSize_ default). + + if (estimatedRowSize != connector::DataSource::kUnknownRowSize) { + readBatchSize_ = outputBatchRows(estimatedRowSize); + } + + int32_t batchSize = readBatchSize_; + if (maxFilteringRatio_ > 0) { + batchSize = std::min( + maxReadBatchSize_, + static_cast(batchSize / maxFilteringRatio_)); + } + return batchSize; +} + void TableScan::close() { Operator::close(); @@ -524,7 +572,7 @@ void TableScan::close() { const auto scaledStats = scaledController_->stats(); auto lockedStats = stats_.wlock(); lockedStats->addRuntimeStat( - TableScan::kNumRunningScaleThreads, + std::string(TableScan::kNumRunningScaleThreads), RuntimeCounter(scaledStats.numRunningDrivers)); } } // namespace facebook::velox::exec diff --git a/velox/exec/TableScan.h b/velox/exec/TableScan.h index 69bff9ddc45..e3573211404 100644 --- a/velox/exec/TableScan.h +++ b/velox/exec/TableScan.h @@ -15,6 +15,8 @@ */ #pragma once +#include + #include "velox/core/PlanNode.h" #include "velox/exec/Operator.h" #include "velox/exec/ScaledScanController.h" @@ -61,13 +63,44 @@ class TableScan : public SourceOperator { /// /// NOTE: we only report the number of running scan drivers at the point that /// all the splits have been dispatched. - static inline const std::string kNumRunningScaleThreads{ - "numRunningScaleThreads"}; + static constexpr std::string_view kNumRunningScaleThreads = + "numRunningScaleThreads"; + + /// Time spent reading from the data source. + static constexpr std::string_view kDataSourceReadWallNanos = + "dataSourceReadWallNanos"; + + /// Number of splits that started background preload. + static constexpr std::string_view kPreloadedSplits = "preloadedSplits"; + + /// Number of preloaded splits that finished before being read. + static constexpr std::string_view kReadyPreloadedSplits = + "readyPreloadedSplits"; + + /// Size of the connector split. + static constexpr std::string_view kConnectorSplitSize = "connectorSplitSize"; + + /// Time waiting for a preloaded split to become available. + static constexpr std::string_view kWaitForPreloadSplitNanos = + "waitForPreloadSplitNanos"; + + /// Time for preload split preparation. + static constexpr std::string_view kPreloadSplitPrepareTimeNanos = + "preloadSplitPrepareTimeNanos"; + + /// Time spent adding a split to the data source. + static constexpr std::string_view kDataSourceAddSplitWallNanos = + "dataSourceAddSplitWallNanos"; std::shared_ptr testingScaledController() const { return scaledController_; } + /// Returns the current read batch size. Used for testing. + vector_size_t testingReadBatchSize() const { + return readBatchSize_; + } + private: // Checks if this table scan operator needs to yield before processing the // next split. @@ -99,6 +132,10 @@ class TableScan : public SourceOperator { // processing or not. void tryScaleUp(); + // Calculates the batch size to read based on available row size information. + // Returns the number of rows to read in the next batch. + int32_t calculateBatchSize(int64_t currentEstimatedRowSize); + const connector::ConnectorTableHandlePtr tableHandle_; const connector::ColumnHandleMap columnHandles_; DriverCtx* const driverCtx_; @@ -110,6 +147,8 @@ class TableScan : public SourceOperator { // limit'. const size_t getOutputTimeLimitMs_{0}; + const uint32_t outputBatchRowsOverride_; + // If set, used for scan scale processing. It is shared by all the scan // operators instantiated from the same table scan node. const std::shared_ptr scaledController_; @@ -141,6 +180,10 @@ class TableScan : public SourceOperator { double maxFilteringRatio_{0}; + // Row size estimate from the file reader. It is set to the last known + // estimated row size from the current split reader or the previous ones. + int64_t fileEstimatedRowSize_{connector::DataSource::kUnknownRowSize}; + // String shown in ExceptionContext inside DataSource and LazyVector loading. std::string debugString_; diff --git a/velox/exec/TableWriteMerge.cpp b/velox/exec/TableWriteMerge.cpp index 8a9b6c2b8ff..f8e96378a20 100644 --- a/velox/exec/TableWriteMerge.cpp +++ b/velox/exec/TableWriteMerge.cpp @@ -16,35 +16,10 @@ #include "velox/exec/TableWriteMerge.h" -#include "HashAggregation.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/TableWriter.h" -#include "velox/exec/Task.h" namespace facebook::velox::exec { -namespace { -bool isSameCommitContext( - const folly::dynamic& first, - const folly::dynamic& second) { - return std::tie( - first[TableWriteTraits::kTaskIdContextKey], - first[TableWriteTraits::kCommitStrategyContextKey]) == - std::tie( - second[TableWriteTraits::kTaskIdContextKey], - second[TableWriteTraits::kCommitStrategyContextKey]); -} - -bool containsNonNullRows(const VectorPtr& vector) { - if (!vector->mayHaveNulls()) { - return true; - } - for (int i = 0; i < vector->size(); ++i) { - if (!vector->isNullAt(i)) { - return true; - } - } - return false; -} -} // namespace TableWriteMerge::TableWriteMerge( int32_t operatorId, @@ -55,7 +30,7 @@ TableWriteMerge::TableWriteMerge( tableWriteMergeNode->outputType(), operatorId, tableWriteMergeNode->id(), - "TableWriteMerge") { + OperatorType::kTableWriteMerge) { if (tableWriteMergeNode->outputType()->size() == 1) { VELOX_USER_CHECK(!tableWriteMergeNode->hasColumnStatsSpec()); } else { @@ -79,35 +54,98 @@ void TableWriteMerge::initialize() { } } +namespace { +// Creates a RowVector containing only the rows at the given indices from +// 'input'. Each child vector is wrapped in a dictionary to avoid copying. +RowVectorPtr +selectRows(const RowVectorPtr& input, BufferPtr indices, vector_size_t size) { + std::vector children(input->childrenSize()); + for (auto i = 0; i < input->childrenSize(); ++i) { + children[i] = + BaseVector::wrapInDictionary(nullptr, indices, size, input->childAt(i)); + } + return std::make_shared( + input->pool(), input->type(), nullptr, size, std::move(children)); +} +} // namespace + void TableWriteMerge::addInput(RowVectorPtr input) { VELOX_CHECK(!noMoreInput_); VELOX_CHECK_GT(input->size(), 0); - if (isStatistics(input)) { - VELOX_CHECK_NOT_NULL(statsCollector_); - statsCollector_->addInput(input); - return; + // Possibly mixed batch: split into stats and data using dictionary wrapping. + auto statsIndices = allocateIndices(input->size(), pool()); + auto dataIndices = allocateIndices(input->size(), pool()); + auto* rawStatsIndices = statsIndices->asMutable(); + auto* rawDataIndices = dataIndices->asMutable(); + vector_size_t numStats{0}; + vector_size_t numData{0}; + for (vector_size_t i = 0; i < input->size(); ++i) { + if (TableWriteTraits::isStatisticsRow(input, i)) { + rawStatsIndices[numStats++] = i; + } else { + rawDataIndices[numData++] = i; + } } - // Increments row count. + if (numStats > 0) { + if (numData == 0) { + addStatisticsInput(input); + } else { + addStatisticsInput(selectRows(input, statsIndices, numStats)); + } + } + + if (numData > 0) { + if (numStats == 0) { + addDataInput(input); + } else { + addDataInput(selectRows(input, dataIndices, numData)); + } + } +} + +void TableWriteMerge::addStatisticsInput(const RowVectorPtr& input) { + VELOX_CHECK_NOT_NULL(statsCollector_); + statsCollector_->addInput(input); +} + +void TableWriteMerge::addDataInput(const RowVectorPtr& input) { numRows_ += TableWriteTraits::getRowCount(input); - // Makes sure the lifespan is the same. + // Validate commit strategy consistency. TaskId may differ in cross-worker + // merge (coordinator merging output from multiple workers). auto commitContext = TableWriteTraits::getTableCommitContext(input); if (lastCommitContext_ != nullptr) { - VELOX_CHECK( - isSameCommitContext(lastCommitContext_, commitContext), - "incompatible table commit context: {} is not compatible with {}", - lastCommitContext_.asString(), - commitContext.asString()); + VELOX_CHECK_EQ( + lastCommitContext_[TableWriteTraits::kCommitStrategyContextKey] + .asString(), + commitContext[TableWriteTraits::kCommitStrategyContextKey].asString(), + "Mismatched commit strategy in commit context"); } lastCommitContext_ = commitContext; - // Adds fragments to the buffer. Fragments will be emitted as soon as possible - // to avoid using extra memory. + // Buffer non-null fragments for early emission to free memory. The input + // may contain a mix of fragment rows (non-null) and summary rows (null + // fragment) when the TableWriter produces them in a single multi-row + // output. auto fragmentVector = input->childAt(TableWriteTraits::kFragmentChannel); - if (containsNonNullRows(fragmentVector)) { + if (!fragmentVector->mayHaveNulls()) { fragmentVectors_.push(fragmentVector); + } else { + auto indices = allocateIndices(fragmentVector->size(), pool()); + auto* rawIndices = indices->asMutable(); + vector_size_t numNonNull{0}; + for (vector_size_t i = 0; i < fragmentVector->size(); ++i) { + if (!fragmentVector->isNullAt(i)) { + rawIndices[numNonNull++] = i; + } + } + if (numNonNull > 0) { + fragmentVectors_.push( + BaseVector::wrapInDictionary( + nullptr, indices, numNonNull, fragmentVector)); + } } } @@ -213,8 +251,4 @@ RowVectorPtr TableWriteMerge::createLastOutput() { return output; } -bool TableWriteMerge::isStatistics(RowVectorPtr input) { - return input->childAt(TableWriteTraits::kRowCountChannel)->isNullAt(0) && - input->childAt(TableWriteTraits::kFragmentChannel)->isNullAt(0); -} } // namespace facebook::velox::exec diff --git a/velox/exec/TableWriteMerge.h b/velox/exec/TableWriteMerge.h index c5a0ade2ccd..8dd1ea99b3c 100644 --- a/velox/exec/TableWriteMerge.h +++ b/velox/exec/TableWriteMerge.h @@ -55,6 +55,12 @@ class TableWriteMerge : public Operator { void close() override; private: + // Processes a batch of statistics rows. + void addStatisticsInput(const RowVectorPtr& input); + + // Processes a batch of data rows (row counts, fragments, commit context). + void addDataInput(const RowVectorPtr& input); + // Creates non-last output with fragments and last commit context only. RowVectorPtr createFragmentsOutput(); @@ -65,9 +71,6 @@ class TableWriteMerge : public Operator { // Creates the last output and fragment columns must be null. RowVectorPtr createLastOutput(); - // Check if the input is statistics input. - bool isStatistics(RowVectorPtr input); - std::unique_ptr statsCollector_; bool finished_{false}; // The sum of written rows. diff --git a/velox/exec/TableWriter.cpp b/velox/exec/TableWriter.cpp index 776a6e4d653..fa1ba1f6e1a 100644 --- a/velox/exec/TableWriter.cpp +++ b/velox/exec/TableWriter.cpp @@ -15,6 +15,8 @@ */ #include "velox/exec/TableWriter.h" +#include "velox/connectors/ConnectorRegistry.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/Task.h" namespace facebook::velox::exec { @@ -28,9 +30,11 @@ TableWriter::TableWriter( tableWriteNode->outputType(), operatorId, tableWriteNode->id(), - "TableWrite", + OperatorType::kTableWrite, tableWriteNode->canSpill(driverCtx->queryConfig()) - ? driverCtx->makeSpillConfig(operatorId) + ? driverCtx->makeSpillConfig( + operatorId, + OperatorType::kTableWrite) : std::nullopt), driverCtx_(driverCtx), connectorPool_(driverCtx_->task->addConnectorPoolLocked( @@ -42,7 +46,7 @@ TableWriter::TableWriter( insertTableHandle_( tableWriteNode->insertTableHandle()->connectorInsertTableHandle()), commitStrategy_(tableWriteNode->commitStrategy()), - createTimeUs_(getCurrentTimeNano()) { + createTimeNs_(getCurrentTimeNano()) { setConnectorMemoryReclaimer(); if (tableWriteNode->outputType()->size() == 1) { VELOX_USER_CHECK(!tableWriteNode->columnStatsSpec().has_value()); @@ -60,7 +64,8 @@ TableWriter::TableWriter( &nonReclaimableSection_); } const auto& connectorId = tableWriteNode->insertTableHandle()->connectorId(); - connector_ = connector::getConnector(connectorId); + connector_ = connector::ConnectorRegistry::tryGet( + *driverCtx->task->queryCtx(), connectorId); connectorQueryCtx_ = operatorCtx_->createConnectorQueryCtx( connectorId, planNodeId(), @@ -258,8 +263,9 @@ RowVectorPtr TableWriter::getOutput() { if (statsCollector_ != nullptr) { for (int i = TableWriteTraits::kStatsChannel; i < outputType_->size(); ++i) { - columns.push_back(BaseVector::createNullConstant( - outputType_->childAt(i), writtenRowsVector->size(), pool())); + columns.push_back( + BaseVector::createNullConstant( + outputType_->childAt(i), writtenRowsVector->size(), pool())); } } @@ -280,7 +286,7 @@ std::string TableWriter::createTableCommitContext(bool lastOutput) { void TableWriter::updateStats(const connector::DataSink::Stats& stats) { const auto currentTimeNs = getCurrentTimeNano(); - VELOX_CHECK_GE(currentTimeNs, createTimeUs_); + VELOX_CHECK_GE(currentTimeNs, createTimeNs_); { auto lockedStats = stats_.wlock(); lockedStats->physicalWrittenBytes = stats.numWrittenBytes; @@ -314,10 +320,10 @@ void TableWriter::updateStats(const connector::DataSink::Stats& stats) { lockedStats->addRuntimeStat( kRunningWallNanos, RuntimeCounter( - currentTimeNs - createTimeUs_, RuntimeCounter::Unit::kNanos)); + currentTimeNs - createTimeNs_, RuntimeCounter::Unit::kNanos)); } if (!stats.spillStats.empty()) { - *spillStats_->wlock() += stats.spillStats; + *spillStats_ += stats.spillStats; } } @@ -336,8 +342,9 @@ void TableWriter::close() { void TableWriter::setConnectorMemoryReclaimer() { VELOX_CHECK_NOT_NULL(connectorPool_); if (connectorPool_->parent()->reclaimer() != nullptr) { - connectorPool_->setReclaimer(TableWriter::ConnectorReclaimer::create( - spillConfig_, operatorCtx_->driverCtx(), this)); + connectorPool_->setReclaimer( + TableWriter::ConnectorReclaimer::create( + spillConfig_, operatorCtx_->driverCtx(), this)); } } @@ -387,8 +394,11 @@ uint64_t TableWriter::ConnectorReclaimer::reclaim( // TODO: reduce the log frequency if it is too verbose. ++stats.numNonReclaimableAttempts; LOG(WARNING) << "Can't reclaim from a closed writer connector pool: " - << pool->name() - << ", memory usage: " << succinctBytes(pool->reservedBytes()); + << pool->name() << ", root pool: " << pool->root()->name() + << ", used: " << succinctBytes(pool->usedBytes()) + << ", reservation: " << succinctBytes(pool->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool->root()->reservedBytes()); return 0; } @@ -397,115 +407,15 @@ uint64_t TableWriter::ConnectorReclaimer::reclaim( ++stats.numNonReclaimableAttempts; LOG(WARNING) << "Can't reclaim from a writer connector pool which hasn't initialized yet: " - << pool->name() - << ", memory usage: " << succinctBytes(pool->reservedBytes()); + << pool->name() << ", root pool: " << pool->root()->name() + << ", used: " << succinctBytes(pool->usedBytes()) + << ", reservation: " << succinctBytes(pool->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool->root()->reservedBytes()); return 0; } RuntimeStatWriterScopeGuard opStatsGuard(op_); return ParallelMemoryReclaimer::reclaim(pool, targetBytes, maxWaitMs, stats); } -// static -RowVectorPtr TableWriteTraits::createAggregationStatsOutput( - RowTypePtr outputType, - RowVectorPtr aggregationOutput, - StringView tableCommitContext, - velox::memory::MemoryPool* pool) { - // TODO: record aggregation stats output time. - if (aggregationOutput == nullptr) { - return nullptr; - } - VELOX_CHECK_GT(aggregationOutput->childrenSize(), 0); - const vector_size_t numOutputRows = aggregationOutput->childAt(0)->size(); - std::vector columns; - for (int channel = 0; channel < outputType->size(); channel++) { - if (channel < TableWriteTraits::kContextChannel) { - // 1. Set null rows column. - // 2. Set null fragments column. - columns.push_back(BaseVector::createNullConstant( - outputType->childAt(channel), numOutputRows, pool)); - continue; - } - if (channel == TableWriteTraits::kContextChannel) { - // 3. Set commitcontext column. - columns.push_back(std::make_shared>( - pool, - numOutputRows, - false /*isNull*/, - VARBINARY(), - std::move(tableCommitContext))); - continue; - } - // 4. Set statistics columns. - columns.push_back( - aggregationOutput->childAt(channel - TableWriteTraits::kStatsChannel)); - } - return std::make_shared( - pool, outputType, nullptr, numOutputRows, columns); -} - -std::string TableWriteTraits::rowCountColumnName() { - static const std::string kRowCountName = "rows"; - return kRowCountName; -} - -std::string TableWriteTraits::fragmentColumnName() { - static const std::string kFragmentName = "fragments"; - return kFragmentName; -} - -std::string TableWriteTraits::contextColumnName() { - static const std::string kContextName = "commitcontext"; - return kContextName; -} - -const TypePtr& TableWriteTraits::rowCountColumnType() { - static const TypePtr kRowCountType = BIGINT(); - return kRowCountType; -} - -const TypePtr& TableWriteTraits::fragmentColumnType() { - static const TypePtr kFragmentType = VARBINARY(); - return kFragmentType; -} - -const TypePtr& TableWriteTraits::contextColumnType() { - static const TypePtr kContextType = VARBINARY(); - return kContextType; -} - -// static. -RowTypePtr TableWriteTraits::outputType( - const std::optional& columnStatsSpec) { - static const auto kOutputTypeWithoutStats = - ROW({rowCountColumnName(), fragmentColumnName(), contextColumnName()}, - {rowCountColumnType(), fragmentColumnType(), contextColumnType()}); - if (!columnStatsSpec.has_value()) { - return kOutputTypeWithoutStats; - } - return kOutputTypeWithoutStats->unionWith( - ColumnStatsCollector::outputType(columnStatsSpec.value())); -} - -folly::dynamic TableWriteTraits::getTableCommitContext( - const RowVectorPtr& input) { - VELOX_CHECK_GT(input->size(), 0); - auto* contextVector = - input->childAt(kContextChannel)->as>(); - return folly::parseJson(contextVector->valueAt(input->size() - 1)); -} - -int64_t TableWriteTraits::getRowCount(const RowVectorPtr& output) { - VELOX_CHECK_GT(output->size(), 0); - auto rowCountVector = - output->childAt(kRowCountChannel)->asFlatVector(); - VELOX_CHECK_NOT_NULL(rowCountVector); - int64_t rowCount{0}; - for (int i = 0; i < output->size(); ++i) { - if (!rowCountVector->isNullAt(i)) { - rowCount += rowCountVector->valueAt(i); - } - } - return rowCount; -} } // namespace facebook::velox::exec diff --git a/velox/exec/TableWriter.h b/velox/exec/TableWriter.h index f25ceeb28af..577dc394154 100644 --- a/velox/exec/TableWriter.h +++ b/velox/exec/TableWriter.h @@ -17,86 +17,13 @@ #pragma once #include "velox/core/PlanNode.h" +#include "velox/core/TableWriteTraits.h" #include "velox/exec/ColumnStatsCollector.h" #include "velox/exec/MemoryReclaimer.h" #include "velox/exec/Operator.h" namespace facebook::velox::exec { -/// Defines table writer output related config properties that are shared -/// between TableWriter and TableWriteMerger. -/// -/// TODO: the table write output processing is Prestissimo specific. Consider -/// move these part logic to Prestissimo and pass to Velox through a customized -/// output processing callback. -class TableWriteTraits { - public: - /// Defines the column names/types in table write output. - static std::string rowCountColumnName(); - static std::string fragmentColumnName(); - static std::string contextColumnName(); - - static const TypePtr& rowCountColumnType(); - static const TypePtr& fragmentColumnType(); - static const TypePtr& contextColumnType(); - - /// Defines the column channels in table write output. - /// Both the statistics and the row_count + fragments are transferred over the - /// same communication link between the TableWriter and TableFinish. Thus the - /// multiplexing is needed. - /// - /// The transferred page layout looks like: - /// [row_count_channel], [fragment_channel], [context_channel], - /// [statistic_channel_1] ... [statistic_channel_N]] - /// - /// [row_count_channel] - contains number of rows processed by a TableWriter - /// [fragment_channel] - contains data provided by the DataSink#finish - /// [statistic_channel_1] ...[statistic_channel_N] - - /// contain aggregated statistics computed by the statistics aggregation - /// within the TableWriter - /// - /// For convenience, we never set both: [row_count_channel] + - /// [fragment_channel] and the [statistic_channel_1] ... - /// [statistic_channel_N]. - /// - /// If this is a row that holds statistics - the [row_count_channel] + - /// [fragment_channel] will be NULL. - /// - /// If this is a row that holds the row count - /// or the fragment - all the statistics channels will be set to NULL. - static constexpr int32_t kRowCountChannel = 0; - static constexpr int32_t kFragmentChannel = 1; - static constexpr int32_t kContextChannel = 2; - static constexpr int32_t kStatsChannel = 3; - - /// Defines the names of metadata in commit context in table writer output. - static constexpr std::string_view kLifeSpanContextKey = "lifespan"; - static constexpr std::string_view kTaskIdContextKey = "taskId"; - static constexpr std::string_view kCommitStrategyContextKey = - "pageSinkCommitStrategy"; - static constexpr std::string_view klastPageContextKey = "lastPage"; - - static RowTypePtr outputType( - const std::optional& columnStatsSpec); - - /// Returns the parsed commit context from table writer 'output'. - static folly::dynamic getTableCommitContext(const RowVectorPtr& output); - - /// Returns the sum of row counts from table writer 'output'. - static int64_t getRowCount(const RowVectorPtr& output); - - /// Creates the statistics output. - /// Statistics page layout (aggregate by partition): - /// row fragments context [partition] stats1 stats2 ... - /// null null X [X] X X - /// null null X [X] X X - static RowVectorPtr createAggregationStatsOutput( - RowTypePtr outputType, - RowVectorPtr aggregationOutput, - StringView tableCommitContext, - velox::memory::MemoryPool* pool); -}; - class TableWriter : public Operator { public: TableWriter( @@ -150,15 +77,15 @@ class TableWriter : public Operator { /// The name of runtime stats specific to table writer. /// The running wall time of a writer operator from creation to close. - static inline const std::string kRunningWallNanos{"runningWallNanos"}; + static constexpr std::string_view kRunningWallNanos{"runningWallNanos"}; /// The number of files written by this writer operator. - static inline const std::string kNumWrittenFiles{"numWrittenFiles"}; + static constexpr std::string_view kNumWrittenFiles{"numWrittenFiles"}; /// The file write IO walltime. - static inline const std::string kWriteIOTime{"writeIOWallNanos"}; + static constexpr std::string_view kWriteIOTime{"writeIOWallNanos"}; /// The walltime spend on file write data recoding. - static inline const std::string kWriteRecodeTime{"writeRecodeWallNanos"}; + static constexpr std::string_view kWriteRecodeTime{"writeRecodeWallNanos"}; /// The walltime spent on file write data compression. - static inline const std::string kWriteCompressionTime{ + static constexpr std::string_view kWriteCompressionTime{ "writeCompressionWallNanos"}; private: @@ -233,9 +160,9 @@ class TableWriter : public Operator { const connector::ConnectorInsertTableHandlePtr insertTableHandle_; const connector::CommitStrategy commitStrategy_; // Records the writer operator creation time in ns. This is used to record - // the running wall time of a writer operator. This can helps to detect the + // the running wall time of a writer operator. This can help to detect the // slow scaled writer scheduling in Prestissimo. - const uint64_t createTimeUs_{0}; + const uint64_t createTimeNs_{0}; std::unique_ptr statsCollector_; std::shared_ptr connector_; @@ -260,4 +187,10 @@ class TableWriter : public Operator { bool closed_{false}; vector_size_t numWrittenRows_{0}; }; + +// TODO: TableWriteTraits got moved to velox/core as it pertains to plan +// metadata, not execution. Maintaining the alias here in order not to break +// backward compatibility. +using core::TableWriteTraits; + } // namespace facebook::velox::exec diff --git a/velox/exec/Task.cpp b/velox/exec/Task.cpp index eab1bf7d59c..9b1c2f7e385 100644 --- a/velox/exec/Task.cpp +++ b/velox/exec/Task.cpp @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #include #include #include @@ -26,22 +27,22 @@ #include "velox/common/time/Timer.h" #include "velox/exec/Exchange.h" #include "velox/exec/HashJoinBridge.h" +#include "velox/exec/IndexLookupJoinBridge.h" #include "velox/exec/LocalPlanner.h" #include "velox/exec/MemoryReclaimer.h" #include "velox/exec/NestedLoopJoinBuild.h" +#include "velox/exec/OperatorTraceCtx.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/OperatorUtils.h" #include "velox/exec/OutputBufferManager.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/SpatialJoinBuild.h" #include "velox/exec/TableScan.h" #include "velox/exec/Task.h" -#include "velox/exec/TaskTraceWriter.h" -#include "velox/exec/TraceUtil.h" using facebook::velox::common::testutil::TestValue; namespace facebook::velox::exec { - namespace { // RAII helper class to satisfy given promises and notify listeners of an event @@ -116,15 +117,49 @@ std::string errorMessageImpl(const std::exception_ptr& exception) { return message; } +// Returns the number of source nodes to process for a given plan node. +// For most nodes, this is the number of sources. For index lookup join, +// only the first source (probe side) is processed as a source node. +inline size_t numSourceNodes(const core::PlanNode* planNode) { + if (!isIndexLookupJoin(planNode)) { + return planNode->sources().size(); + } + const auto* indexNode = + checkedPointerCast(planNode); + VELOX_CHECK_EQ(indexNode->sources().size(), 2); + return !indexNode->needsIndexSplit() ? 1 : indexNode->sources().size(); +} + +// Collects IDs of IndexLookupJoinNode's lookup source (index) TableScanNodes. +// These nodes are NOT separate pipeline leaves in Velox's LocalPlanner +// (it only plans the probe source) and must be excluded from +// groupedExecutionLeafNodeIds validation. +void collectIndexSourceIds( + const core::PlanNodePtr& node, + std::unordered_set& indexSourceIds) { + if (auto indexJoin = + std::dynamic_pointer_cast(node)) { + indexSourceIds.insert(indexJoin->lookupSource()->id()); + // Only recurse into the probe side (sources()[0]), not the lookup source. + collectIndexSourceIds(indexJoin->sources()[0], indexSourceIds); + return; + } + for (const auto& source : node->sources()) { + collectIndexSourceIds(source, indexSourceIds); + } +} + // Add 'running time' metrics from CpuWallTiming structures to have them // available aggregated per thread. void addRunningTimeOperatorMetrics(exec::OperatorStats& op) { - op.runtimeStats["runningAddInputWallNanos"] = + op.runtimeStats[std::string(OperatorStats::kRunningAddInputWallNanos)] = RuntimeMetric(op.addInputTiming.wallNanos, RuntimeCounter::Unit::kNanos); - op.runtimeStats["runningGetOutputWallNanos"] = + op.runtimeStats[std::string(OperatorStats::kRunningGetOutputWallNanos)] = RuntimeMetric(op.getOutputTiming.wallNanos, RuntimeCounter::Unit::kNanos); - op.runtimeStats["runningFinishWallNanos"] = + op.runtimeStats[std::string(OperatorStats::kRunningFinishWallNanos)] = RuntimeMetric(op.finishTiming.wallNanos, RuntimeCounter::Unit::kNanos); + op.runtimeStats[std::string(OperatorStats::kRunningIsBlockedWallNanos)] = + RuntimeMetric(op.isBlockedTiming.wallNanos, RuntimeCounter::Unit::kNanos); } void buildSplitStates( @@ -150,7 +185,7 @@ void buildSplitStates( } const auto& sources = planNode->sources(); - const auto numSources = isIndexLookupJoin(planNode) ? 1 : sources.size(); + const auto numSources = numSourceNodes(planNode); for (auto i = 0; i < numSources; ++i) { buildSplitStates(sources[i].get(), allIds, splitStateMap); } @@ -173,26 +208,33 @@ std::string makeUuid() { // Returns true if an operator is a hash join operator given 'operatorType'. bool isHashJoinOperator(const std::string& operatorType) { - return (operatorType == "HashBuild") || (operatorType == "HashProbe"); + return (operatorType == OperatorType::kHashBuild) || + (operatorType == OperatorType::kHashProbe); } class QueueSplitsStore : public SplitsStore { public: using SplitsStore::SplitsStore; - void requestBarrier(std::vector& promises) override { - addSplit(Split::createBarrier(), promises); + void requestBarrier( + uint32_t numDrivers, + std::vector& promises) override { + addSplit(Split::createBarrier(numDrivers), promises); } bool nextSplit( - Split& split, - ContinueFuture& future, + std::optional driverId, int maxPreloadSplits, - const ConnectorSplitPreloadFunc& preload) override { + const ConnectorSplitPreloadFunc& preload, + Split& split, + ContinueFuture& future) override { if (!splits_.empty()) { split = getSplit(maxPreloadSplits, preload); return true; } + if (tryGetBarrier(driverId, split)) { + return true; + } if (noMoreSplits_) { return true; } @@ -300,7 +342,7 @@ bool registerSplitListenerFactory( bool unregisterSplitListenerFactory( const std::shared_ptr& factory) { return splitListenerFactories().withWLock([&](auto& factories) { - for (auto it = factories.begin(); it != factories.end(); ++it) { + for (auto it = factories.cbegin(); it != factories.cend(); ++it) { if ((*it) == factory) { factories.erase(it); return true; @@ -320,8 +362,8 @@ std::shared_ptr Task::create( ExecutionMode mode, Consumer consumer, int32_t memoryArbitrationPriority, + std::optional spillDiskOpts, std::function onError) { - VELOX_CHECK_NOT_NULL(planFragment.planNode); return Task::create( taskId, std::move(planFragment), @@ -331,6 +373,7 @@ std::shared_ptr Task::create( (consumer ? [c = std::move(consumer)]() { return c; } : ConsumerSupplier{}), memoryArbitrationPriority, + std::move(spillDiskOpts), std::move(onError)); } @@ -343,6 +386,7 @@ std::shared_ptr Task::create( ExecutionMode mode, ConsumerSupplier consumerSupplier, int32_t memoryArbitrationPriority, + std::optional spillDiskOpts, std::function onError) { VELOX_CHECK_NOT_NULL(planFragment.planNode); auto task = std::shared_ptr(new Task( @@ -354,7 +398,7 @@ std::shared_ptr Task::create( std::move(consumerSupplier), memoryArbitrationPriority, std::move(onError))); - task->initTaskPool(); + task->init(std::move(spillDiskOpts)); task->addToTaskList(); return task; } @@ -377,7 +421,7 @@ Task::Task( planFragment_(std::move(planFragment)), firstNodeNotSupportingBarrier_( planFragment_.firstNodeNotSupportingBarrier()), - traceConfig_(maybeMakeTraceConfig()), + traceCtx_(maybeMakeTraceCtx()), consumerSupplier_(std::move(consumerSupplier)), onError_(std::move(onError)), splitsStates_(buildSplitStates(planFragment_.planNode)), @@ -457,11 +501,6 @@ Task::~Task() { } void Task::ensureBarrierSupport() const { - VELOX_CHECK_EQ( - mode_, - Task::ExecutionMode::kSerial, - "Task doesn't support barriered execution."); - VELOX_CHECK_NULL( firstNodeNotSupportingBarrier_, "Task doesn't support barriered execution. Name of the first node that " @@ -469,6 +508,76 @@ void Task::ensureBarrierSupport() const { firstNodeNotSupportingBarrier_->name()); } +void Task::init(std::optional&& spillDiskOpts) { + VELOX_CHECK(driverFactories_.empty()); + initTaskPool(); + + setSpillDiskConfig(std::move(spillDiskOpts)); + + if (mode_ != Task::ExecutionMode::kSerial) { + return; + } + + // Create drivers. + VELOX_CHECK_NULL( + consumerSupplier_, + "Serial execution mode doesn't support delivering results to a " + "callback"); + + taskStats_.executionStartTimeMs = getCurrentTimeMs(); + LocalPlanner::plan( + planFragment_, nullptr, &driverFactories_, queryCtx_->queryConfig(), 1); + exchangeClients_.resize(driverFactories_.size()); + + // In Task::next() we always assume ungrouped execution. + for (const auto& factory : driverFactories_) { + VELOX_CHECK(factory->supportsSerialExecution()); + numDriversUngrouped_ += factory->numDrivers; + numTotalDrivers_ += factory->numTotalDrivers; + taskStats_.pipelineStats.emplace_back( + factory->inputDriver, factory->outputDriver); + VELOX_CHECK_EQ(factory->numDrivers, 1); + numDriversPerLeafNode_[factory->leafNodeId()] = factory->numDrivers; + } + + initDriverLifecycleStatsLocked(); + + // Create drivers. + createSplitGroupStateLocked(kUngroupedGroupId); + std::vector> drivers = + createDriversLocked(kUngroupedGroupId); + if (pool_->reservedBytes() != 0) { + VELOX_FAIL( + "Unexpected memory pool allocations during task[{}] driver initialization: {}", + taskId_, + pool_->treeMemoryUsage()); + } + + drivers_ = std::move(drivers); + driverBlockingStates_.reserve(drivers_.size()); + for (auto i = 0; i < drivers_.size(); ++i) { + driverBlockingStates_.emplace_back( + std::make_unique(drivers_[i].get())); + } +} + +void Task::setSpillDiskConfig( + std::optional&& spillDiskOpts) { + if (!spillDiskOpts.has_value()) { + return; + } + VELOX_CHECK( + !spillDiskOpts->spillDirPath.empty(), "Spill directory can't be empty"); + VELOX_CHECK( + spillDiskOpts->spillDirCreated || spillDiskOpts->spillDirCreateCb); + VELOX_CHECK_NULL(spillDirectoryCallback_); + VELOX_CHECK(!spillDirectoryCreated_); + VELOX_CHECK(spillDirectory_.empty()); + spillDirectory_ = std::move(spillDiskOpts->spillDirPath); + spillDirectoryCreated_ = spillDiskOpts->spillDirCreated; + spillDirectoryCallback_ = std::move(spillDiskOpts->spillDirCreateCb); +} + Task::TaskList& Task::taskList() { static TaskList taskList; return taskList; @@ -547,7 +656,7 @@ bool Task::allNodesReceivedNoMoreSplitsMessageLocked() const { const std::string& Task::getOrCreateSpillDirectory() { VELOX_CHECK( !spillDirectory_.empty() || spillDirectoryCallback_, - "Spill directory or spill directory callback must be set "); + "Spill directory or spill directory callback must be set"); if (spillDirectoryCreated_) { return spillDirectory_; } @@ -681,8 +790,9 @@ velox::memory::MemoryPool* Task::addOperatorPool( } else { nodePool = getOrAddNodePool(planNodeId); } - childPools_.push_back(nodePool->addLeafChild(fmt::format( - "op.{}.{}.{}.{}", planNodeId, pipelineId, driverId, operatorType))); + childPools_.push_back(nodePool->addLeafChild( + fmt::format( + "op.{}.{}.{}.{}", planNodeId, pipelineId, driverId, operatorType))); return childPools_.back().get(); } @@ -693,13 +803,14 @@ velox::memory::MemoryPool* Task::addConnectorPoolLocked( const std::string& operatorType, const std::string& connectorId) { auto* nodePool = getOrAddNodePool(planNodeId); - childPools_.push_back(nodePool->addAggregateChild(fmt::format( - "op.{}.{}.{}.{}.{}", - planNodeId, - pipelineId, - driverId, - operatorType, - connectorId))); + childPools_.push_back(nodePool->addAggregateChild( + fmt::format( + "op.{}.{}.{}.{}.{}", + planNodeId, + pipelineId, + driverId, + operatorType, + connectorId))); return childPools_.back().get(); } @@ -769,48 +880,8 @@ RowVectorPtr Task::next(ContinueFuture* future) { } } - // On first call, create the drivers. - if (driverFactories_.empty()) { - VELOX_CHECK_NULL( - consumerSupplier_, - "Serial execution mode doesn't support delivering results to a " - "callback"); - - taskStats_.executionStartTimeMs = getCurrentTimeMs(); - LocalPlanner::plan( - planFragment_, nullptr, &driverFactories_, queryCtx_->queryConfig(), 1); - exchangeClients_.resize(driverFactories_.size()); - - // In Task::next() we always assume ungrouped execution. - for (const auto& factory : driverFactories_) { - VELOX_CHECK(factory->supportsSerialExecution()); - numDriversUngrouped_ += factory->numDrivers; - numTotalDrivers_ += factory->numTotalDrivers; - taskStats_.pipelineStats.emplace_back( - factory->inputDriver, factory->outputDriver); - } - - // Create drivers. - createSplitGroupStateLocked(kUngroupedGroupId); - std::vector> drivers = - createDriversLocked(kUngroupedGroupId); - if (pool_->reservedBytes() != 0) { - VELOX_FAIL( - "Unexpected memory pool allocations during task[{}] driver initialization: {}", - taskId_, - pool_->treeMemoryUsage()); - } - - drivers_ = std::move(drivers); - driverBlockingStates_.reserve(drivers_.size()); - for (auto i = 0; i < drivers_.size(); ++i) { - driverBlockingStates_.emplace_back( - std::make_unique(drivers_[i].get())); - } - if (underBarrier()) { - startDriverBarriersLocked(); - } - } + VELOX_CHECK_EQ( + state_, TaskState::kRunning, "Task has already finished processing."); // Run drivers one at a time. If a driver blocks, continue running the other // drivers. Running other drivers is expected to unblock some or all blocked @@ -986,10 +1057,13 @@ void Task::createDriverFactoriesLocked(uint32_t maxDrivers) { numDriversUngrouped_ += factory->numDrivers; } numTotalDrivers_ += factory->numTotalDrivers; + numDriversPerLeafNode_[factory->leafNodeId()] = factory->numDrivers; taskStats_.pipelineStats.emplace_back( factory->inputDriver, factory->outputDriver); } + initDriverLifecycleStatsLocked(); + validateGroupedExecutionLeafNodes(); } @@ -1211,9 +1285,20 @@ void Task::validateGroupedExecutionLeafNodes() { !planFragment_.groupedExecutionLeafNodeIds.empty(), "groupedExecutionLeafNodeIds must not be empty in " "grouped execution mode"); + + // Collect IndexLookupJoin lookup source node IDs. These are in + // groupedExecutionLeafNodeIds (for coordinator-side grouped split + // scheduling) but are NOT separate pipeline leaves in Velox — + // IndexLookupJoin manages the index source internally. + std::unordered_set indexSourceIds; + collectIndexSourceIds(planFragment_.planNode, indexSourceIds); + // Check that each node designated as the grouped execution leaf node // existing in a pipeline that will run grouped execution. for (const auto& leafNodeId : planFragment_.groupedExecutionLeafNodeIds) { + if (indexSourceIds.count(leafNodeId)) { + continue; + } bool found{false}; for (auto& factory : driverFactories_) { if (leafNodeId == factory->leafNodeId()) { @@ -1263,6 +1348,8 @@ void Task::createSplitGroupStateLocked(uint32_t splitGroupId) { splitGroupId, factory->needsNestedLoopJoinBridges()); addSpatialJoinBridgesLocked( splitGroupId, factory->needsSpatialJoinBridges()); + addIndexLookupJoinBridgesLocked( + splitGroupId, factory->needsIndexLookupJoinBridges()); addCustomJoinBridgesLocked(splitGroupId, factory->planNodes); core::PlanNodeId tableScanNodeId; @@ -1600,18 +1687,12 @@ void Task::addSplitToStoreLocked( uint32_t groupId, const exec::Split& split, std::vector& promises) { - auto& splitsStore = splitsState.groupSplitsStores[groupId]; - if (!splitsStore) { - setSplitsStore( - splitsStore, - std::make_unique(!splitsState.sourceIsTableScan)); - } + auto* splitsStore = getOrCreateSplitsStoreLocked(splitsState, groupId); if (split.isBarrier()) { - splitsStore->requestBarrier(promises); + splitsStore->requestBarrier(split.barrier->numDrivers, promises); return; } - auto* queueSplitsStore = - checked_pointer_cast(splitsStore.get()); + auto* queueSplitsStore = checkedPointerCast(splitsStore); queueSplitsStore->addSplit(split, promises); } @@ -1625,7 +1706,7 @@ void Task::noMoreSplitsForGroup( auto& splitsState = getPlanNodeSplitsStateLocked(planNodeId); noMoreSplitsForStore( - splitsState.groupSplitsStores[splitGroupId].get(), promises); + getOrCreateSplitsStoreLocked(splitsState, splitGroupId), promises); // There were no splits in this group, hence, no active drivers. Mark the // group complete. @@ -1722,6 +1803,18 @@ void Task::setSplitsStore( splitsStore->setPreloadingSplits(preloadingSplits_); } +SplitsStore* Task::getOrCreateSplitsStoreLocked( + SplitsState& splitsState, + uint32_t splitGroupId) { + auto& splitsStore = splitsState.groupSplitsStores[splitGroupId]; + if (!splitsStore) { + setSplitsStore( + splitsStore, + std::make_unique(!splitsState.sourceIsTableScan)); + } + return splitsStore.get(); +} + ContinueFuture Task::requestBarrier() { ensureBarrierSupport(); return startBarrier("Task::requestBarrier"); @@ -1764,7 +1857,8 @@ ContinueFuture Task::startBarrier(std::string_view comment) { promises.reserve(leafPlanNodeIds.size()); for (const auto& leafPlanNode : leafPlanNodeIds) { - auto barrierSplit = Split::createBarrier(); + const auto barrierSplit = + Split::createBarrier(numDriversPerLeafNode_.at(leafPlanNode)); auto& splitState = getPlanNodeSplitsStateLocked(leafPlanNode); addSplitLocked(splitState, barrierSplit, promises); } @@ -1862,8 +1956,8 @@ void Task::dropInputLocked( VELOX_CHECK(!drivers.empty()); const auto dropNodeId = *dropNodeIds.begin(); bool foundDriver{false}; - auto it = drivers.begin(); - while (it != drivers.end()) { + auto it = drivers.cbegin(); + while (it != drivers.cend()) { Driver* driver = *it; VELOX_CHECK_NOT_NULL(driver); if (auto* dropOp = driver->findOperator(dropNodeId)) { @@ -1925,21 +2019,18 @@ bool Task::isAllSplitsFinishedLocked() { } BlockingReason Task::getSplitOrFuture( + uint32_t driverId, uint32_t splitGroupId, const core::PlanNodeId& planNodeId, - exec::Split& split, - ContinueFuture& future, int32_t maxPreloadSplits, - const ConnectorSplitPreloadFunc& preload) { + const ConnectorSplitPreloadFunc& preload, + exec::Split& split, + ContinueFuture& future) { std::lock_guard l(mutex_); auto& splitsState = getPlanNodeSplitsStateLocked(planNodeId); - auto& splitsStore = splitsState.groupSplitsStores[splitGroupId]; - if (!splitsStore) { - setSplitsStore( - splitsStore, - std::make_unique(!splitsState.sourceIsTableScan)); - } - return splitsStore->nextSplit(split, future, maxPreloadSplits, preload) + auto* splitsStore = getOrCreateSplitsStoreLocked(splitsState, splitGroupId); + return splitsStore->nextSplit( + driverId, maxPreloadSplits, preload, split, future) ? BlockingReason::kNotBlocked : BlockingReason::kWaitForSplit; } @@ -2301,6 +2392,26 @@ void Task::addSpatialJoinBridgesLocked( } } +void Task::addIndexLookupJoinBridgesLocked( + uint32_t splitGroupId, + const std::vector& planNodeIds) { + auto& splitGroupState = splitGroupStates_[splitGroupId]; + for (const auto& planNodeId : planNodeIds) { + auto const inserted = + splitGroupState.bridges + .emplace(planNodeId, std::make_shared()) + .second; + VELOX_CHECK( + inserted, "Join bridge for node {} is already present", planNodeId); + } +} + +std::shared_ptr Task::getIndexLookupJoinBridge( + uint32_t splitGroupId, + const core::PlanNodeId& planNodeId) { + return getJoinBridgeInternal(splitGroupId, planNodeId); +} + std::shared_ptr Task::getHashJoinBridge( uint32_t splitGroupId, const core::PlanNodeId& planNodeId) { @@ -2339,7 +2450,7 @@ template std::shared_ptr Task::getJoinBridgeInternalLocked( uint32_t splitGroupId, const core::PlanNodeId& planNodeId, - MemberType SplitGroupState::*bridges_member) { + MemberType SplitGroupState::* bridges_member) { const auto& splitGroupState = splitGroupStates_[splitGroupId]; auto it = (splitGroupState.*bridges_member).find(planNodeId); @@ -2409,22 +2520,27 @@ ContinueFuture Task::terminate(TaskState terminalState) { "Termination time has already been set, this should only happen once."); taskStats_.terminationTimeMs = getCurrentTimeMs(); if (state_ == TaskState::kCanceled || state_ == TaskState::kAborted) { - try { - VELOX_FAIL( - state_ == TaskState::kCanceled ? "Cancelled" - : "Aborted for external error"); - } catch (const std::exception&) { - exception_ = std::current_exception(); - } + // Construct the exception directly instead of going through VELOX_FAIL to + // avoid error log when cancellation is expected. + exception_ = std::make_exception_ptr(VeloxRuntimeError( + __FILE__, + __LINE__, + __FUNCTION__, + /*expression=*/"", + state_ == TaskState::kCanceled ? "Cancelled" + : "Aborted for external error", + error_source::kErrorSourceRuntime, + error_code::kInvalidState, + /*isRetriable=*/false)); } if (state_ != TaskState::kFinished) { VELOX_CHECK(!cancellationSource_.isCancellationRequested()); cancellationSource_.requestCancellation(); } - LOG(INFO) << "Terminating task " << taskId() << " with state " - << taskStateString(state_) << " after running for " - << succinctMillis(timeSinceStartMsLocked()); + VLOG(1) << "Terminating task " << taskId() << " with state " + << taskStateString(state_) << " after running for " + << succinctMillis(timeSinceStartMsLocked()); taskCompletionNotifier.activate( std::move(taskCompletionPromises_), [&]() { onTaskCompletion(); }); @@ -2452,6 +2568,9 @@ ContinueFuture Task::terminate(TaskState terminalState) { exchangeClients.swap(exchangeClients_); barrierPromises.swap(barrierFinishPromises_); + // Clear the barrier flag to ensure underBarrier() returns false after task + // termination. + barrierRequested_ = false; } taskCompletionNotifier.notify(); @@ -2514,8 +2633,13 @@ ContinueFuture Task::terminate(TaskState terminalState) { } while (!store->allSplitsConsumed()) { auto future = ContinueFuture::makeEmpty(); - VELOX_CHECK( - store->nextSplit(splits.emplace_back(), future, 0, nullptr)); + const auto hasNextSplit = store->nextSplit( + /*driverId=*/std::nullopt, + /*maxPreloadSplits=*/0, + /*preload=*/nullptr, + splits.emplace_back(), + future); + VELOX_CHECK(hasNextSplit); } } if (!splits.empty()) { @@ -2616,6 +2740,37 @@ void Task::addDriverStats(int pipelineId, DriverStats stats) { taskStats_.pipelineStats[pipelineId].driverStats.push_back(std::move(stats)); } +void Task::initDriverLifecycleStatsLocked() { + pipelineLifecycleStats_.resize(driverFactories_.size()); + for (size_t i = 0; i < driverFactories_.size(); ++i) { + auto& pls = pipelineLifecycleStats_[i]; + const auto& leafNode = driverFactories_[i]->planNodes.front(); + pls.sourceOperatorType = leafNode->name(); + pls.sourcePlanNodeId = leafNode->id(); + pls.driverTimes.resize(driverFactories_[i]->numDrivers); + } +} + +void Task::addDriverLifecycleStats( + uint32_t pipelineId, + uint32_t driverIndex, + uint64_t queuedNanos, + uint64_t onThreadNanos, + uint64_t blockedNanos) { + std::lock_guard l(mutex_); + if (pipelineId >= pipelineLifecycleStats_.size()) { + return; + } + auto& pls = pipelineLifecycleStats_[pipelineId]; + if (pls.driverTimes.empty()) { + return; + } + const auto idx = driverIndex % pls.driverTimes.size(); + pls.driverTimes[idx].queuedNanos += queuedNanos; + pls.driverTimes[idx].onThreadNanos += onThreadNanos; + pls.driverTimes[idx].blockedNanos += blockedNanos; +} + TaskStats Task::taskStats() const { std::lock_guard l(mutex_); @@ -2671,6 +2826,41 @@ TaskStats Task::taskStats() const { taskStats.longestRunningOpCallMs = 0; } + // Emit per-pipeline driver lifecycle timing. Merged into the first + // existing DriverStats entry for each pipeline to avoid inflating the + // driverStats vector. Each logical driver index contributes one sample, + // giving proper sum/count/min/max aggregation. + for (size_t i = 0; i < pipelineLifecycleStats_.size(); ++i) { + const auto& pls = pipelineLifecycleStats_[i]; + if (pls.driverTimes.empty()) { + continue; + } + auto& pipeDriverStats = taskStats.pipelineStats[i].driverStats; + if (pipeDriverStats.empty()) { + pipeDriverStats.emplace_back(); + } + auto& targetStats = pipeDriverStats[0].runtimeStats; + const auto prefix = fmt::format( + "P{}-{}.{}", i, pls.sourceOperatorType, pls.sourcePlanNodeId); + const auto queuedKey = fmt::format("{}.driverQueuedWallNanos", prefix); + const auto onThreadKey = fmt::format("{}.driverOnThreadWallNanos", prefix); + const auto blockedKey = fmt::format("{}.driverBlockedWallNanos", prefix); + for (const auto& timing : pls.driverTimes) { + auto addOrMerge = [&targetStats](const std::string& key, uint64_t nanos) { + const auto value = saturateCast(nanos); + auto it = targetStats.find(key); + if (it != targetStats.end()) { + it->second.merge(RuntimeMetric(value, RuntimeCounter::Unit::kNanos)); + } else { + targetStats[key] = RuntimeMetric(value, RuntimeCounter::Unit::kNanos); + } + }; + addOrMerge(queuedKey, timing.queuedNanos); + addOrMerge(onThreadKey, timing.onThreadNanos); + addOrMerge(blockedKey, timing.blockedNanos); + } + } + auto bufferManager = bufferManager_.lock(); taskStats.outputBufferUtilization = bufferManager->getUtilization(taskId_); taskStats.outputBufferOverutilized = bufferManager->isOverutilized(taskId_); @@ -2770,19 +2960,29 @@ void Task::onTaskCompletion() { } for (auto& listener : listeners) { - listener->onTaskCompletion( - uuid_, - taskId_, - state, - exception, - stats, - planFragment_, - exchangeClientByPlanNode_); + try { + listener->onTaskCompletion( + uuid_, + taskId_, + state, + exception, + stats, + planFragment_, + exchangeClientByPlanNode_); + } catch (const std::exception& e) { + LOG(ERROR) << "TaskCompletionListener threw for task " << taskId_ + << ": " << e.what(); + } } }); for (auto& listener : splitListeners_) { - listener->onTaskCompletion(); + try { + listener->onTaskCompletion(); + } catch (const std::exception& e) { + LOG(ERROR) << "SplitCompletionListener threw for task " << taskId_ << ": " + << e.what(); + } } } @@ -3006,8 +3206,9 @@ void Task::createLocalExchangeQueuesLocked( queryCtx_->queryConfig().maxLocalExchangeBufferSize()); exchange.queues.reserve(numPartitions); for (auto i = 0; i < numPartitions; ++i) { - exchange.queues.emplace_back(std::make_shared( - exchange.memoryManager, exchange.vectorPool, i)); + exchange.queues.emplace_back( + std::make_shared( + exchange.memoryManager, exchange.vectorPool, i)); } const auto partitionNode = @@ -3138,7 +3339,7 @@ std::string Task::errorMessage() const { } StopReason Task::enter(ThreadState& state, uint64_t nowMicros) { - TestValue::adjust("facebook::velox::exec::Task::enter", &state); + TestValue::adjust("facebook::velox::exec::Task::enter", this); std::lock_guard l(mutex_); VELOX_CHECK(state.isEnqueued); state.isEnqueued = false; @@ -3396,7 +3597,9 @@ void Task::createExchangeClientLocked( queryCtx()->queryConfig().minExchangeOutputBatchBytes(), addExchangeClientPool(planNodeId, pipelineId), queryCtx()->executor(), - queryCtx()->queryConfig().requestDataSizesMaxWaitSec()); + queryCtx()->queryConfig().requestDataSizesMaxWaitSec(), + queryCtx()->queryConfig().singleSourceExchangeOptimizationEnabled(), + queryCtx()->queryConfig().exchangeLazyFetchingEnabled()); exchangeClientByPlanNode_.emplace(planNodeId, exchangeClients_[pipelineId]); } @@ -3415,66 +3618,27 @@ std::shared_ptr Task::getExchangeClientLocked( return exchangeClients_[pipelineId]; } -std::optional Task::maybeMakeTraceConfig() const { - const auto& queryConfig = queryCtx_->queryConfig(); - if (!queryConfig.queryTraceEnabled()) { - return std::nullopt; +std::unique_ptr Task::maybeMakeTraceCtx() const { + if (!queryCtx_->queryConfig().queryTraceEnabled()) { + return nullptr; } - VELOX_USER_CHECK( - !queryConfig.queryTraceDir().empty(), - "Query trace enabled but the trace dir is not set"); - - VELOX_USER_CHECK( - !queryConfig.queryTraceTaskRegExp().empty(), - "Query trace enabled but the trace task regexp is not set"); - - if (!RE2::FullMatch(taskId_, queryConfig.queryTraceTaskRegExp())) { - return std::nullopt; + if (queryCtx_->traceCtxProvider()) { + return queryCtx_->traceCtxProvider()(*queryCtx_, planFragment()); } - - const auto traceNodeId = queryConfig.queryTraceNodeId(); - VELOX_USER_CHECK(!traceNodeId.empty(), "Query trace node ID are not set"); - - const auto traceDir = trace::getTaskTraceDirectory( - queryConfig.queryTraceDir(), queryCtx_->queryId(), taskId_); - - VELOX_USER_CHECK_NOT_NULL( - core::PlanNode::findFirstNode( - planFragment_.planNode.get(), - [traceNodeId](const core::PlanNode* node) -> bool { - return node->id() == traceNodeId; - }), - "Trace plan node ID = {} not found from task {}", - traceNodeId, - taskId_); - - LOG(INFO) << "Trace input for plan nodes " << traceNodeId << " from task " - << taskId_; - - UpdateAndCheckTraceLimitCB updateAndCheckTraceLimitCB = - [this](uint64_t bytes) { - queryCtx_->updateTracedBytesAndCheckLimit(bytes); - }; - return TraceConfig( - traceNodeId, - traceDir, - std::move(updateAndCheckTraceLimitCB), - queryConfig.queryTraceTaskRegExp(), - queryConfig.queryTraceDryRun()); + // Fallback to default trace. + return trace::OperatorTraceCtx::maybeCreate( + *queryCtx_, planFragment(), taskId()); } void Task::maybeInitTrace() { - if (!traceConfig_) { + if (!traceCtx_) { return; } - trace::createTraceDirectory(traceConfig_->queryTraceDir); - const auto metadataWriter = std::make_unique( - traceConfig_->queryTraceDir, memory::traceMemoryPool()); - auto traceNode = - trace::getTraceNode(planFragment_.planNode, traceConfig_->queryNodeId); - metadataWriter->write(queryCtx_, traceNode); + if (auto metadataWriter = traceCtx_->createMetadataTracer()) { + metadataWriter->write(*queryCtx_, *planFragment_.planNode); + } } void Task::testingVisitDrivers(const std::function& callback) { @@ -3685,8 +3849,8 @@ bool Task::DriverBlockingState::blocked(ContinueFuture* future) { VELOX_CHECK(promises_.empty()); return false; } - auto [blockPromise, blockFuture] = - makeVeloxContinuePromiseContract(fmt::format( + auto [blockPromise, blockFuture] = makeVeloxContinuePromiseContract( + fmt::format( "DriverBlockingState {} from task {}", driver_->driverCtx()->driverId, driver_->task()->taskId())); diff --git a/velox/exec/Task.h b/velox/exec/Task.h index 01b906df087..2b483b806a9 100644 --- a/velox/exec/Task.h +++ b/velox/exec/Task.h @@ -15,8 +15,9 @@ */ #pragma once +#include + #include "velox/common/base/SkewedPartitionBalancer.h" -#include "velox/common/base/TraceConfig.h" #include "velox/core/PlanFragment.h" #include "velox/core/QueryCtx.h" #include "velox/exec/Driver.h" @@ -26,6 +27,7 @@ #include "velox/exec/ScaledScanController.h" #include "velox/exec/TaskStats.h" #include "velox/exec/TaskStructs.h" +#include "velox/exec/trace/TraceCtx.h" #include "velox/vector/ComplexVector.h" namespace facebook::velox::exec { @@ -33,6 +35,7 @@ namespace facebook::velox::exec { class OutputBufferManager; class HashJoinBridge; +class IndexLookupJoinBridge; class NestedLoopJoinBridge; class SpatialJoinBridge; class SplitListener; @@ -66,11 +69,16 @@ class Task : public std::enable_shared_from_this { /// @param consumer Optional factory function to get callbacks to pass the /// results of the execution. In a parallel execution mode, results from each /// thread are passed on to a separate consumer. + /// @param memoryArbitrationPriority Priority used by the memory arbitrator + /// to determine which task should have its memory reclaimed first when the + /// system is under memory pressure. Higher values indicate higher priority + /// (lower likelihood of being reclaimed). Default is 0. + /// @param spillDiskOpts Optional configuration for spill disk storage. When + /// provided, allows operators to spill intermediate data to disk during + /// execution when memory pressure is high. Includes spill directory path + /// and callback options. Default is std::nullopt (no spilling). /// @param onError Optional callback to receive an exception if task /// execution fails. - /// @param memoryArbitrationPriority Optional priority on task that, in a - /// multi task system, is used for memory arbitration to decide the order of - /// reclaiming. static std::shared_ptr create( const std::string& taskId, core::PlanFragment planFragment, @@ -79,6 +87,7 @@ class Task : public std::enable_shared_from_this { ExecutionMode mode, Consumer consumer = nullptr, int32_t memoryArbitrationPriority = 0, + std::optional spillDiskOpts = std::nullopt, std::function onError = nullptr); static std::shared_ptr create( @@ -89,6 +98,7 @@ class Task : public std::enable_shared_from_this { ExecutionMode mode, ConsumerSupplier consumerSupplier, int32_t memoryArbitrationPriority = 0, + std::optional spillDiskOpts = std::nullopt, std::function onError = nullptr); /// Convenience function for shortening a Presto taskId. To be used @@ -97,22 +107,6 @@ class Task : public std::enable_shared_from_this { ~Task(); - /// Specify directory to which data will be spilled if spilling is enabled and - /// required. Set 'alreadyCreated' to true if the directory has already been - /// created by the caller. - void setSpillDirectory( - const std::string& spillDirectory, - bool alreadyCreated = true) { - spillDirectory_ = spillDirectory; - spillDirectoryCreated_ = alreadyCreated; - } - - void setCreateSpillDirectoryCb( - std::function spillDirectoryCallback) { - VELOX_CHECK_NULL(spillDirectoryCallback_); - spillDirectoryCallback_ = std::move(spillDirectoryCallback); - } - /// Returns human-friendly representation of the plan augmented with runtime /// statistics. The implementation invokes exec::printPlanWithStats(). /// @@ -160,8 +154,8 @@ class Task : public std::enable_shared_from_this { } /// Returns query trace config if specified. - const std::optional& traceConfig() const { - return traceConfig_; + const trace::TraceCtx* traceCtx() const { + return traceCtx_.get(); } /// Returns ConsumerSupplier passed in the constructor. @@ -446,12 +440,13 @@ class Task : public std::enable_shared_from_this { /// so many of splits at the head of the queue are preloading. If /// they are not, calls preload on them to start preload. BlockingReason getSplitOrFuture( + uint32_t driverId, uint32_t splitGroupId, const core::PlanNodeId& planNodeId, + int32_t maxPreloadSplits, + const ConnectorSplitPreloadFunc& preload, exec::Split& split, - ContinueFuture& future, - int32_t maxPreloadSplits = 0, - const ConnectorSplitPreloadFunc& preload = nullptr); + ContinueFuture& future); /// Returns the scaled scan controller for a given table scan node if the /// query has configured. @@ -520,7 +515,7 @@ class Task : public std::enable_shared_from_this { void setError(const std::string& message); /// Returns all the peer operators of the 'caller' operator from a given - /// 'pipelindId' in this task. + /// 'pipelineId' in this task. std::vector findPeerOperators(int pipelineId, Operator* caller); /// Synchronizes completion of an Operator across Drivers of 'this'. @@ -566,6 +561,11 @@ class Task : public std::enable_shared_from_this { uint32_t splitGroupId, const std::vector& planNodeIds); + /// Adds IndexLookupJoinBridge's for all the specified plan node IDs. + void addIndexLookupJoinBridgesLocked( + uint32_t splitGroupId, + const std::vector& planNodeIds); + /// Adds custom join bridges for all the specified plan nodes. void addCustomJoinBridgesLocked( uint32_t splitGroupId, @@ -593,6 +593,11 @@ class Task : public std::enable_shared_from_this { uint32_t splitGroupId, const core::PlanNodeId& planNodeId); + /// Returns an IndexLookupJoinBridge for 'planNodeId'. + std::shared_ptr getIndexLookupJoinBridge( + uint32_t splitGroupId, + const core::PlanNodeId& planNodeId); + /// Returns a custom join bridge for 'planNodeId'. std::shared_ptr getCustomJoinBridge( uint32_t splitGroupId, @@ -610,6 +615,16 @@ class Task : public std::enable_shared_from_this { /// Adds per driver statistics. Called from Drivers upon their closure. void addDriverStats(int pipelineId, DriverStats stats); + /// Accumulates driver lifecycle timing (queued, on-thread, blocked) for gap + /// analysis. Same-index drivers across split groups are summed together. + /// Called from Driver::closeOperators() upon driver closure. + void addDriverLifecycleStats( + uint32_t pipelineId, + uint32_t driverIndex, + uint64_t queuedNanos, + uint64_t onThreadNanos, + uint64_t blockedNanos); + /// Returns kNone if no pause or terminate is requested. The thread count is /// incremented if kNone is returned. If something else is returned the /// calling thread should unwind and return itself to its pool. If 'this' goes @@ -626,7 +641,7 @@ class Task : public std::enable_shared_from_this { StopReason enterForTerminateLocked(ThreadState& state); /// Marks that the Driver is not on thread. If no more Drivers in the - /// CancelPool are on thread, this realizes threadFinishFutures_. These allow + /// Task are on thread, this realizes threadFinishFutures_. /// syncing with pause or termination. The Driver may go off thread because of /// hasBlockingFuture or pause requested or terminate requested. The /// return value indicates the reason. If kTerminate is returned, the @@ -821,6 +836,13 @@ class Task : public std::enable_shared_from_this { int32_t memoryArbitrationPriority = 0, std::function onError = nullptr); + // Invoked to do post-create initialization. + void init(std::optional&& spillDiskOpts); + + // Invoked to initialize the spill storage config for this task. + void setSpillDiskConfig( + std::optional&& spillDiskOpts); + // Invoked to add this to the system-wide running task list on task creation. void addToTaskList(); @@ -859,7 +881,7 @@ class Task : public std::enable_shared_from_this { // message. bool allNodesReceivedNoMoreSplitsMessageLocked() const; - // Recursive helper for 'allSpilitsConsumed()' method. + // Recursive helper for 'allSplitsConsumed()' method. bool allSplitsConsumedHelper(const core::PlanNode* planNode) const; // Remove the spill directory, if the Task was creating it for potential @@ -944,7 +966,7 @@ class Task : public std::enable_shared_from_this { VELOX_CHECK_NOT_NULL(task); } - // Gets the shared pointer to the driver to ensure its liveness during the + // Gets the shared pointer to the task to ensure its liveness during the // memory reclaim operation. // // NOTE: a task's memory pool might outlive the task itself. @@ -976,7 +998,7 @@ class Task : public std::enable_shared_from_this { std::shared_ptr getJoinBridgeInternalLocked( uint32_t splitGroupId, const core::PlanNodeId& planNodeId, - MemberType SplitGroupState::*bridges_member); + MemberType SplitGroupState::* bridges_member); std::shared_ptr getCustomJoinBridgeInternal( uint32_t splitGroupId, @@ -1038,6 +1060,12 @@ class Task : public std::enable_shared_from_this { std::unique_ptr& splitsStore, std::unique_ptr newSplitsStore); + // Returns the splits store for the given group, creating one if it doesn't + // exist. + SplitsStore* getOrCreateSplitsStoreLocked( + SplitsState& splitsState, + uint32_t splitGroupId); + // Invoked when all the driver threads are off thread. The function returns // 'threadFinishPromises_' to fulfill. std::vector allThreadsFinishedLocked(); @@ -1118,7 +1146,7 @@ class Task : public std::enable_shared_from_this { int32_t pipelineId) const; // Builds the query trace config. - std::optional maybeMakeTraceConfig() const; + std::unique_ptr maybeMakeTraceCtx() const; // Create a 'QueryMetadtaWriter' to trace the query metadata if the query // trace enabled. @@ -1165,7 +1193,7 @@ class Task : public std::enable_shared_from_this { // and all its plan nodes support barrier processing. const core::PlanNode* firstNodeNotSupportingBarrier_{}; - const std::optional traceConfig_; + const std::unique_ptr traceCtx_; inline static std::atomic_uint64_t numCreatedTasks_; @@ -1176,7 +1204,7 @@ class Task : public std::enable_shared_from_this { // to pool_ must be defined after pool_, childPools_. std::shared_ptr pool_; - // Keep driver and operator memory pools alive for the duration of the task + // Keep plan node and operator memory pools alive for the duration of the task // to allow for sharing vectors across drivers without copy. std::vector> childPools_; @@ -1215,11 +1243,12 @@ class Task : public std::enable_shared_from_this { ConsumerSupplier consumerSupplier_; // The function that is executed when the task encounters its first error, - // that is, serError() is called for the first time. + // that is, setError() is called for the first time. std::function onError_; std::vector> driverFactories_; std::vector> drivers_; + std::unordered_map numDriversPerLeafNode_; // Tracks the blocking state for each driver under serialized execution mode. class DriverBlockingState { @@ -1349,6 +1378,27 @@ class Task : public std::enable_shared_from_this { TaskStats taskStats_; + // Per-pipeline driver lifecycle timing accumulator. For each pipeline, holds + // a vector of per-driver timing indexed by driver index within the pipeline. + // In grouped execution, same-index drivers across groups accumulate into the + // same entry, giving the total time for a logical driver across all groups. + // Reported as RuntimeMetrics at task close via taskStats(). + struct DriverLifecycleTiming { + uint64_t queuedNanos{0}; + uint64_t onThreadNanos{0}; + uint64_t blockedNanos{0}; + }; + + struct PipelineLifecycleStats { + std::string sourceOperatorType; + core::PlanNodeId sourcePlanNodeId; + std::vector driverTimes; + }; + std::vector pipelineLifecycleStats_; + + /// Initializes pipelineLifecycleStats_ from driverFactories_. + void initDriverLifecycleStatsLocked(); + // Stores inter-operator state (exchange, bridges) per split group. During // ungrouped execution we use the [0] entry in this vector. std::unordered_map splitGroupStates_; @@ -1377,7 +1427,7 @@ class Task : public std::enable_shared_from_this { // The promises for the futures returned to callers of requestBarrier(). std::vector barrierFinishPromises_; - std::atomic toYield_ = 0; + std::atomic_int32_t toYield_ = 0; int32_t numThreads_ = 0; // Microsecond real time when 'this' last went from no threads to // one thread running. Used to decide if continuous run should be diff --git a/velox/exec/TaskStructs.cpp b/velox/exec/TaskStructs.cpp index 21786166125..2d5cbc27d42 100644 --- a/velox/exec/TaskStructs.cpp +++ b/velox/exec/TaskStructs.cpp @@ -23,12 +23,23 @@ void SplitsStore::addSplit( std::vector& promises) { VELOX_CHECK(!noMoreSplits_); VELOX_CHECK(!(remoteSplit_ && split.isBarrier())); - splits_.push_back(std::move(split)); - if (promises_.empty()) { - return; + VELOX_CHECK(barrierSplits_.empty()); + if (split.isBarrier()) { + for (auto i = 0; i < split.barrier->numDrivers; ++i) { + barrierSplits_[i] = Split::createBarrier(); + } + VELOX_CHECK_LE(promises_.size(), split.barrier->numDrivers); + // A barrier is assigned to every driver; wake up all currently blocked + // drivers to process it. + std::move(promises_.begin(), promises_.end(), std::back_inserter(promises)); + promises_.clear(); + } else { + splits_.push_back(std::move(split)); + if (!promises_.empty()) { + promises.push_back(std::move(promises_.back())); + promises_.pop_back(); + } } - promises.push_back(std::move(promises_.back())); - promises_.pop_back(); } ContinueFuture SplitsStore::makeFuture() { @@ -85,4 +96,21 @@ Split SplitsStore::getSplit( return split; } +bool SplitsStore::tryGetBarrier( + std::optional driverId, + Split& split) { + if (!driverId.has_value()) { + barrierSplits_.clear(); + return false; + } + // Delivers a barrier exactly once for each driver from the same plan node. + auto it = barrierSplits_.find(*driverId); + if (it == barrierSplits_.end()) { + return false; + } + split = it->second; + barrierSplits_.erase(it); + return true; +} + } // namespace facebook::velox::exec diff --git a/velox/exec/TaskStructs.h b/velox/exec/TaskStructs.h index ba99890d3a0..59c8f0912b6 100644 --- a/velox/exec/TaskStructs.h +++ b/velox/exec/TaskStructs.h @@ -88,15 +88,29 @@ class SplitsStore { /// /// `promises` should be set by caller (potentially outside a lock), to notify /// any waiters on the splits. - virtual void requestBarrier(std::vector& promises) = 0; + virtual void requestBarrier( + uint32_t numDrivers, + std::vector& promises) = 0; - /// Return true when split is set or there is no more splits; false when + /// Returns the next split to process. + /// + /// @param driverId The driver id requesting the split. If set, a barrier + /// split will be delivered exactly once for each driver. If not set, all + /// barrier splits are cleared and no barrier split is returned. This is + /// used when cleaning up remaining remote splits during task termination. + /// @param maxPreloadSplits Maximum number of splits to preload. + /// @param preload Function to preload connector splits. + /// @param split Output parameter for the next split. + /// @param future Output parameter for the future to wait on if no split is + /// available. + /// @return true if a split is set or there are no more splits; false if the /// caller should retry when the future is fulfilled. virtual bool nextSplit( - Split& split, - ContinueFuture& future, + std::optional driverId, int maxPreloadSplits, - const ConnectorSplitPreloadFunc& preload) = 0; + const ConnectorSplitPreloadFunc& preload, + Split& split, + ContinueFuture& future) = 0; /// Return whether all splits has been consumed and there will be no more /// splits. @@ -139,6 +153,8 @@ class SplitsStore { ContinueFuture makeFuture(); + bool tryGetBarrier(std::optional driverId, Split& split); + const bool remoteSplit_; TaskStats* taskStats_{}; folly::F14FastSet>* @@ -146,6 +162,8 @@ class SplitsStore { // Arrived (added), but not distributed yet, splits. std::deque splits_; + // The map from driver id to barrier splits. + std::unordered_map barrierSplits_; // Signal, that no more splits will arrive. bool noMoreSplits_{false}; diff --git a/velox/exec/TaskTraceReader.cpp b/velox/exec/TaskTraceReader.cpp index 59647bd5305..9cf56807a9f 100644 --- a/velox/exec/TaskTraceReader.cpp +++ b/velox/exec/TaskTraceReader.cpp @@ -18,8 +18,8 @@ #include "velox/common/file/FileSystems.h" #include "velox/core/PlanNode.h" -#include "velox/exec/Trace.h" -#include "velox/exec/TraceUtil.h" +#include "velox/exec/trace/Trace.h" +#include "velox/exec/trace/TraceUtil.h" namespace facebook::velox::exec::trace { @@ -31,9 +31,10 @@ TaskTraceMetadataReader::TaskTraceMetadataReader( traceFilePath_(getTaskTraceMetaFilePath(traceDir_)), pool_(pool), metadataObj_(getTaskMetadata(traceFilePath_, fs_)), - tracePlanNode_(ISerializable::deserialize( - metadataObj_[TraceTraits::kPlanNodeKey], - pool_)) {} + tracePlanNode_( + ISerializable::deserialize( + metadataObj_[TraceTraits::kPlanNodeKey], + pool_)) {} std::unordered_map TaskTraceMetadataReader::queryConfigs() const { @@ -66,17 +67,24 @@ core::PlanNodePtr TaskTraceMetadataReader::queryPlan() const { } std::string TaskTraceMetadataReader::nodeName(const std::string& nodeId) const { - const auto* traceNode = core::PlanNode::findFirstNode( - tracePlanNode_.get(), - [&nodeId](const core::PlanNode* node) { return node->id() == nodeId; }); + LOG(ERROR) << "node id " << nodeId << " trace plan node " + << tracePlanNode_->toString(true, true); + const auto* traceNode = + core::PlanNode::findNodeById(tracePlanNode_.get(), nodeId); + VELOX_CHECK_NOT_NULL( + traceNode, "trace node id {} not found in the trace plan", nodeId); return std::string(traceNode->name()); } std::optional TaskTraceMetadataReader::connectorId( const std::string& nodeId) const { - const auto* traceNode = core::PlanNode::findFirstNode( - tracePlanNode_.get(), - [&nodeId](const core::PlanNode* node) { return node->id() == nodeId; }); + const auto* traceNode = + core::PlanNode::findNodeById(tracePlanNode_.get(), nodeId); + VELOX_CHECK_NOT_NULL( + traceNode, + "trace node id {} not found in the trace plan: {}", + nodeId, + tracePlanNode_->toString(true, true)); if (const auto* indexLookupJoinNode = dynamic_cast(traceNode)) { @@ -84,14 +92,16 @@ std::optional TaskTraceMetadataReader::connectorId( indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(); VELOX_CHECK(!indexLookupConnectorId.empty()); return indexLookupConnectorId; - } else if ( - const auto* tableScanNode = + } + + if (const auto* tableScanNode = dynamic_cast(traceNode)) { VELOX_CHECK_NOT_NULL(tableScanNode); const auto connectorId = tableScanNode->tableHandle()->connectorId(); VELOX_CHECK(!connectorId.empty()); return connectorId; } + return std::nullopt; } } // namespace facebook::velox::exec::trace diff --git a/velox/exec/TaskTraceWriter.cpp b/velox/exec/TaskTraceWriter.cpp index 46d28b66a41..218832c4e1e 100644 --- a/velox/exec/TaskTraceWriter.cpp +++ b/velox/exec/TaskTraceWriter.cpp @@ -18,15 +18,17 @@ #include "velox/common/file/File.h" #include "velox/core/PlanNode.h" #include "velox/core/QueryCtx.h" -#include "velox/exec/Trace.h" -#include "velox/exec/TraceUtil.h" +#include "velox/exec/trace/Trace.h" +#include "velox/exec/trace/TraceUtil.h" namespace facebook::velox::exec::trace { TaskTraceMetadataWriter::TaskTraceMetadataWriter( std::string traceDir, + std::string traceNodeId, memory::MemoryPool* /* pool */) : traceDir_(std::move(traceDir)), + traceNodeId_(std::move(traceNodeId)), fs_(filesystems::getFileSystem(traceDir_, nullptr)), traceFilePath_(getTaskTraceMetaFilePath(traceDir_)) { VELOX_CHECK_NOT_NULL(fs_); @@ -34,19 +36,22 @@ TaskTraceMetadataWriter::TaskTraceMetadataWriter( } void TaskTraceMetadataWriter::write( - const std::shared_ptr& queryCtx, - const core::PlanNodePtr& planNode) { + const core::QueryCtx& queryCtx, + const core::PlanNode& planNode) { VELOX_CHECK(!finished_, "Query metadata can only be written once"); finished_ = true; + + auto traceNode = trace::getTraceNode(planNode, traceNodeId_); + folly::dynamic queryConfigObj = folly::dynamic::object; - const auto configValues = queryCtx->queryConfig().rawConfigsCopy(); + const auto configValues = queryCtx.queryConfig().rawConfigsCopy(); for (const auto& [key, value] : configValues) { queryConfigObj[key] = value; } folly::dynamic connectorPropertiesObj = folly::dynamic::object; for (const auto& [connectorId, configs] : - queryCtx->connectorSessionProperties()) { + queryCtx.connectorSessionProperties()) { folly::dynamic obj = folly::dynamic::object; for (const auto& [key, value] : configs->rawConfigsCopy()) { obj[key] = value; @@ -57,7 +62,7 @@ void TaskTraceMetadataWriter::write( folly::dynamic metaObj = folly::dynamic::object; metaObj[TraceTraits::kQueryConfigKey] = queryConfigObj; metaObj[TraceTraits::kConnectorPropertiesKey] = connectorPropertiesObj; - metaObj[TraceTraits::kPlanNodeKey] = planNode->serialize(); + metaObj[TraceTraits::kPlanNodeKey] = traceNode->serialize(); const auto metaStr = folly::toJson(metaObj); const auto file = fs_->openFileForWrite(traceFilePath_); diff --git a/velox/exec/TaskTraceWriter.h b/velox/exec/TaskTraceWriter.h index a61cfb719eb..b10d6952058 100644 --- a/velox/exec/TaskTraceWriter.h +++ b/velox/exec/TaskTraceWriter.h @@ -19,18 +19,23 @@ #include "velox/common/file/FileSystems.h" #include "velox/core/PlanNode.h" #include "velox/core/QueryCtx.h" +#include "velox/exec/trace/TraceWriter.h" namespace facebook::velox::exec::trace { -class TaskTraceMetadataWriter { + +class TaskTraceMetadataWriter : public TraceMetadataWriter { public: - TaskTraceMetadataWriter(std::string traceDir, memory::MemoryPool* pool); + TaskTraceMetadataWriter( + std::string traceDir, + std::string traceNodeId, + memory::MemoryPool* pool); - void write( - const std::shared_ptr& queryCtx, - const core::PlanNodePtr& planNode); + void write(const core::QueryCtx& queryCtx, const core::PlanNode& planNode) + override; private: const std::string traceDir_; + const std::string traceNodeId_; const std::shared_ptr fs_; const std::string traceFilePath_; bool finished_{false}; diff --git a/velox/exec/TopN.cpp b/velox/exec/TopN.cpp index 1c7af492477..961661d6801 100644 --- a/velox/exec/TopN.cpp +++ b/velox/exec/TopN.cpp @@ -16,6 +16,7 @@ #include #include "velox/exec/ContainerRowSerde.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/TopN.h" #include "velox/vector/FlatVector.h" @@ -29,7 +30,7 @@ TopN::TopN( topNNode->outputType(), operatorId, topNNode->id(), - "TopN"), + OperatorType::kTopN), count_(topNNode->count()), data_(std::make_unique(outputType_->children(), pool())), comparator_( diff --git a/velox/exec/TopNRowNumber.cpp b/velox/exec/TopNRowNumber.cpp index 5bd614ffa1e..b26b8519fb0 100644 --- a/velox/exec/TopNRowNumber.cpp +++ b/velox/exec/TopNRowNumber.cpp @@ -15,10 +15,34 @@ */ #include "velox/exec/TopNRowNumber.h" +#include "velox/exec/OperatorType.h" + namespace facebook::velox::exec { namespace { +#define RANK_FUNCTION_DISPATCH(TEMPLATE_FUNC, functionKind, ...) \ + [&]() { \ + switch (functionKind) { \ + case core::TopNRowNumberNode::RankFunction::kRowNumber: { \ + return TEMPLATE_FUNC< \ + core::TopNRowNumberNode::RankFunction::kRowNumber>(__VA_ARGS__); \ + } \ + case core::TopNRowNumberNode::RankFunction::kRank: { \ + return TEMPLATE_FUNC( \ + __VA_ARGS__); \ + } \ + case core::TopNRowNumberNode::RankFunction::kDenseRank: { \ + return TEMPLATE_FUNC< \ + core::TopNRowNumberNode::RankFunction::kDenseRank>(__VA_ARGS__); \ + } \ + default: \ + VELOX_FAIL( \ + "not a rank function kind: {}", \ + core::TopNRowNumberNode::rankFunctionName(functionKind)); \ + } \ + }() + std::vector reorderInputChannels( const RowTypePtr& inputType, const std::vector& partitionKeys, @@ -109,13 +133,17 @@ TopNRowNumber::TopNRowNumber( node->outputType(), operatorId, node->id(), - "TopNRowNumber", + OperatorType::kTopNRowNumber, node->canSpill(driverCtx->queryConfig()) - ? driverCtx->makeSpillConfig(operatorId) + ? driverCtx->makeSpillConfig( + operatorId, + OperatorType::kTopNRowNumber) : std::nullopt), + rankFunction_(node->rankFunction()), limit_{node->limit()}, generateRowNumber_{node->generateRowNumber()}, numPartitionKeys_{node->partitionKeys().size()}, + numSortingKeys_{node->sortingKeys().size()}, inputChannels_{reorderInputChannels( node->inputType(), node->partitionKeys(), @@ -127,22 +155,20 @@ TopNRowNumber::TopNRowNumber( driverCtx->queryConfig().abandonPartialTopNRowNumberMinRows()), abandonPartialMinPct_( driverCtx->queryConfig().abandonPartialTopNRowNumberMinPct()), - data_(std::make_unique( - slice(inputType_->children(), 0, spillCompareFlags_.size()), - slice( - inputType_->children(), - spillCompareFlags_.size(), - inputType_->size()), - pool())), + data_( + std::make_unique( + slice(inputType_->children(), 0, spillCompareFlags_.size()), + slice( + inputType_->children(), + spillCompareFlags_.size(), + inputType_->size()), + pool())), comparator_( inputType_, node->sortingKeys(), node->sortingOrders(), data_.get()), decodedVectors_(inputType_->size()) { - VELOX_CHECK_EQ( - node->rankFunction(), core::TopNRowNumberNode::RankFunction::kRowNumber); - const auto& keys = node->partitionKeys(); const auto numKeys = keys.size(); @@ -163,6 +189,7 @@ TopNRowNumber::TopNRowNumber( false, // allowDuplicates false, // isJoinBuild false, // hasProbedFlag + false, // hasCountFlag 0, // minTableSizeForParallelJoinBuild pool()); partitionOffset_ = table_->rows()->columnAt(numKeys).offset(); @@ -203,7 +230,17 @@ void TopNRowNumber::addInput(RowVectorPtr input) { SelectivityVector rows(numInput); table_->prepareForGroupProbe( *lookup_, input, rows, BaseHashTable::kNoSpillInputStartPartitionBit); - table_->groupProbe(*lookup_, BaseHashTable::kNoSpillInputStartPartitionBit); + try { + table_->groupProbe( + *lookup_, BaseHashTable::kNoSpillInputStartPartitionBit); + } catch (...) { + // If groupProbe throws (e.g., due to OOM), we need to clean up the new + // groups that were inserted but not yet initialized by + // initializeNewPartitions(). Otherwise, close() will crash when trying to + // destroy uninitialized TopRows structures. + cleanupNewPartitions(); + throw; + } // Initialize new partitions. initializeNewPartitions(); @@ -211,18 +248,23 @@ void TopNRowNumber::addInput(RowVectorPtr input) { // Process input rows. For each row, lookup the partition. If the highest // (top) rank in that partition is less than limit, add the new row. // Otherwise, check if row should replace an existing row or be discarded. - processInputRowLoop(numInput); + RANK_FUNCTION_DISPATCH(processInputRowLoop, rankFunction_, numInput); + // It is determined that the TopNRowNumber (as a partial) is not rejecting + // enough input rows to make the duplicate detection worthwhile. Hence, + // abandon the processing at this partial TopN and let the final TopN do + // the processing. if (abandonPartialEarly()) { abandonedPartial_ = true; - addRuntimeStat("abandonedPartial", RuntimeCounter(1)); + addRuntimeStat( + std::string(TopNRowNumber::kAbandonedPartial), RuntimeCounter(1)); updateEstimatedOutputRowSize(); outputBatchSize_ = outputBatchRows(estimatedOutputRowSize_); outputRows_.resize(outputBatchSize_); } } else { - processInputRowLoop(numInput); + RANK_FUNCTION_DISPATCH(processInputRowLoop, rankFunction_, numInput); } } @@ -247,7 +289,60 @@ void TopNRowNumber::initializeNewPartitions() { } } -char* TopNRowNumber::processRowWithinLimit( +void TopNRowNumber::cleanupNewPartitions() { + std::vector newRows(lookup_->newGroups.size()); + for (auto i = 0; i < lookup_->newGroups.size(); ++i) { + newRows[i] = lookup_->hits[lookup_->newGroups[i]]; + } + table_->erase(folly::Range(newRows.data(), newRows.size())); + lookup_->newGroups.clear(); +} + +template <> +char* TopNRowNumber::processRowWithinLimit< + core::TopNRowNumberNode::RankFunction::kRank>( + vector_size_t index, + TopRows& partition) { + // The topRanks queue is not filled yet. + auto& topRows = partition.rows; + if (topRows.empty()) { + partition.topRank = 1; + } else { + // Rank assigns all peer rows the same rank, but the rank increments by + // the number of peers when moving between peers. So when adding a new + // row: + // If row == top rank then top rank is unchanged. + // If row < top rank then top rank += 1. + // If row > top, then rank += number of peers of top rank. + auto* topRow = topRows.top(); + auto result = comparator_.compare(decodedVectors_, index, topRow); + if (result < 0) { + partition.topRank += 1; + } else if (result > 0) { + partition.topRank += partition.numTopRankRows(); + } + } + return data_->newRow(); +} + +template <> +char* TopNRowNumber::processRowWithinLimit< + core::TopNRowNumberNode::RankFunction::kDenseRank>( + vector_size_t index, + TopRows& partition) { + // The topRanks queue is not filled yet. + // dense_rank will add this row to its partition. But the top rank is + // incremented only if the new row is not a peer of any other existing + // row in the partition queue. + if (!partition.isDuplicate(decodedVectors_, index)) { + partition.topRank++; + } + return data_->newRow(); +} + +template <> +char* TopNRowNumber::processRowWithinLimit< + core::TopNRowNumberNode::RankFunction::kRowNumber>( vector_size_t /*index*/, TopRows& partition) { // row_number accumulates the new row in the partition, and the top rank is @@ -256,7 +351,62 @@ char* TopNRowNumber::processRowWithinLimit( return data_->newRow(); } -char* TopNRowNumber::processRowExceedingLimit( +template <> +char* TopNRowNumber::processRowExceedingLimit< + core::TopNRowNumberNode::RankFunction::kRank>( + vector_size_t index, + TopRows& partition) { + auto& topRows = partition.rows; + // The new row < top rank + // For rank, the new row gets assigned its rank as per its position in the + // queue. But the ranks of all subsequent rows increment by 1. + // So we can remove the rows at the top rank as its rank > limit now. + char* topRow = partition.removeTopRankRows(); + char* newRow = data_->initializeRow(topRow, /*reuse=*/true); + if (topRows.empty()) { + partition.topRank = 1; + } else { + // The new top rank value depends on the number of peers of the top ranking + // row. If the current row also has the same value as the new top ranking + // row then it has to be counted as a peer as well. + auto numNewTopRankRows = partition.numTopRankRows(); + topRow = topRows.top(); + if (comparator_.compare(decodedVectors_, index, topRow) == 0) { + partition.topRank = topRows.size() - numNewTopRankRows + 1; + } else { + partition.topRank = topRows.size() - numNewTopRankRows + 2; + } + } + return newRow; +} + +template <> +char* TopNRowNumber::processRowExceedingLimit< + core::TopNRowNumberNode::RankFunction::kDenseRank>( + vector_size_t index, + TopRows& partition) { + char* newRow = nullptr; + // The new row < top rank + // For dense_rank: + // i) If the row is a peer of an existing row in the queue, then it + // has the same rank as it. The ranks of other rows are unchanged. So its + // only added to the queue. + // ii) If the row is a distinct new value in the queue, then it is assigned + // a rank as per its position, and the ranks of all subsequent rows += 1. + // So the current top rank rows can be removed from the queue as their new + // rank > limit. + if (partition.isDuplicate(decodedVectors_, index)) { + newRow = data_->newRow(); + } else { + char* topRow = partition.removeTopRankRows(); + newRow = data_->initializeRow(topRow, /*reuse=*/true); + } + return newRow; +} + +template <> +char* TopNRowNumber::processRowExceedingLimit< + core::TopNRowNumberNode::RankFunction::kRowNumber>( vector_size_t /*index*/, TopRows& partition) { // The new row has rank < highest (aka top) rank at 'limit' function value. @@ -269,13 +419,16 @@ char* TopNRowNumber::processRowExceedingLimit( return data_->initializeRow(topRow, true /* reuse */); } +template void TopNRowNumber::processInputRow(vector_size_t index, TopRows& partition) { auto& topRows = partition.rows; char* newRow = nullptr; if (partition.topRank < limit_) { - newRow = processRowWithinLimit(index, partition); + newRow = processRowWithinLimit(index, partition); } else { + // The partition has now accumulated >= limit rows. So the new rows can be + // rejected or replace existing rows based on the order_by values. char* topRow = topRows.top(); const auto result = comparator_.compare(decodedVectors_, index, topRow); @@ -284,14 +437,18 @@ void TopNRowNumber::processInputRow(vector_size_t index, TopRows& partition) { return; } - if (result == 0) { - // The new row has the same value as the top rank row. row_number rejects - // such rows. - return; + // This row has the same value as the top rank row. row_number rejects + // such rows, but are added to the queue for rank and dense_rank. The top + // rank remains unchanged. + else if (result == 0) { + if (rankFunction_ == core::TopNRowNumberNode::RankFunction::kRowNumber) { + return; + } + newRow = data_->newRow(); } - if (result < 0) { - newRow = processRowExceedingLimit(index, partition); + else if (result < 0) { + newRow = processRowExceedingLimit(index, partition); } } @@ -302,14 +459,15 @@ void TopNRowNumber::processInputRow(vector_size_t index, TopRows& partition) { topRows.push(newRow); } +template void TopNRowNumber::processInputRowLoop(vector_size_t numInput) { if (table_) { for (auto i = 0; i < numInput; ++i) { - processInputRow(i, partitionAt(lookup_->hits[i])); + processInputRow(i, partitionAt(lookup_->hits[i])); } } else { for (auto i = 0; i < numInput; ++i) { - processInputRow(i, *singlePartition_); + processInputRow(i, *singlePartition_); } } } @@ -330,7 +488,7 @@ void TopNRowNumber::noMoreInput() { spiller_->finishSpill(spillPartitionSet); VELOX_CHECK_EQ(spillPartitionSet.size(), 1); merge_ = spillPartitionSet.begin()->second->createOrderedReader( - spillConfig_->readBufferSize, pool(), spillStats_.get()); + *spillConfig_, pool(), spillStats_.get()); } else { outputRows_.resize(outputBatchSize_); } @@ -355,10 +513,46 @@ void TopNRowNumber::updateEstimatedOutputRowSize() { } } +// This function handles a special case when determining the starting +// rank value for the 'rank' function. +// If there are many peer rows for the highest rank, then topRank could +// oscillate between the two cases of topRank < limit and topRank > limit +// as rows are added +// E.g. If the input rows are 0, 0, 0, 5, 0, 0, 6 and we want rank <= 5, then +// at 0, 0, 0, 5 : +// topRows.pq - 0, 0, 0, 5 topRank -> 4 +// 0 is added. +// topRows.pq - 0, 0, 0, 0, 5 topRank -> 5 +// topRank = limit now. +// So when the next 0 is added, the last 5 is popped from TopRows and 0 is added +// topRows.pq - 0, 0, 0, 0, 0, topRank -> 1 +// This makes topRank < 5 and so when 6 comes by, 6 is pushed +// topRows.pq - 0, 0, 0, 0, 0, 6 topRank -> 6 +// So when doing getOutput, we need to adjust this case. +// Since topRank > limit, then the highest rank is popped and the +// topRank is adjusted as length(pq) - number_of_duplicates_of_new_top_row + 1. +vector_size_t TopNRowNumber::fixTopRank(TopRows& partition) { + if (rankFunction_ == core::TopNRowNumberNode::RankFunction::kRank) { + if (partition.topRank > limit_) { + partition.removeTopRankRows(); + auto numNewTopRankRows = partition.numTopRankRows(); + partition.topRank = partition.rows.size() - numNewTopRankRows + 1; + } + } + + return partition.topRank; +} + TopNRowNumber::TopRows* TopNRowNumber::nextPartition() { + auto setNextRankAndPeer = [&](TopRows& partition) { + nextRank_ = fixTopRank(partition); + numPeers_ = 1; + }; + if (!table_) { if (!outputPartitionNumber_) { outputPartitionNumber_ = 0; + setNextRankAndPeer(*singlePartition_); return singlePartition_.get(); } return nullptr; @@ -374,7 +568,6 @@ TopNRowNumber::TopRows* TopNRowNumber::nextPartition() { // No more partitions. return nullptr; } - outputPartitionNumber_ = 0; } else { ++outputPartitionNumber_.value(); @@ -384,24 +577,56 @@ TopNRowNumber::TopRows* TopNRowNumber::nextPartition() { } } - return &partitionAt(partitions_[outputPartitionNumber_.value()]); + auto partition = &partitionAt(partitions_[outputPartitionNumber_.value()]); + setNextRankAndPeer(*partition); + return partition; } +template +void TopNRowNumber::computeNextRankInMemory( + TopRows& partition, + vector_size_t outputIndex) { + if constexpr (TRank == core::TopNRowNumberNode::RankFunction::kRowNumber) { + nextRank_ -= 1; + return; + } + + // This is the logic for rank() and dense_rank(). + // If the next row is a peer of the current one, then the rank remains the + // same. + if (comparator_.compare(outputRows_[outputIndex], partition.rows.top()) == + 0) { + return; + } + + // The new row is not a peer of the current one. So dense_rank drops the + // rank by 1, but rank drops by the number of peers of the new top + // row (new rank) in TopRows queue. + if constexpr (TRank == core::TopNRowNumberNode::RankFunction::kDenseRank) { + nextRank_ -= 1; + } else { + nextRank_ -= partition.numTopRankRows(); + } +} + +template void TopNRowNumber::appendPartitionRows( TopRows& partition, vector_size_t numRows, vector_size_t outputOffset, - FlatVector* rowNumbers) { + FlatVector* rankValues) { // The partition.rows priority queue pops rows in order of reverse - // row numbers. - auto rowNumber = partition.rows.size(); + // ranks. Output rows based on nextRank_ and update it with each row. for (auto i = 0; i < numRows; ++i) { - const auto index = outputOffset + i; - if (rowNumbers) { - rowNumbers->set(index, rowNumber--); + auto index = outputOffset + i; + if (rankValues) { + rankValues->set(index, nextRank_); } outputRows_[index] = partition.rows.top(); partition.rows.pop(); + if (!partition.rows.empty()) { + computeNextRankInMemory(partition, index); + } } } @@ -417,7 +642,7 @@ RowVectorPtr TopNRowNumber::getOutput() { return output; } - // We may have input accumulated in 'data_'. + // There could be older rows accumulated in 'data_'. if (data_->numRows() > 0) { return getOutputFromMemory(); } @@ -426,6 +651,7 @@ RowVectorPtr TopNRowNumber::getOutput() { finished_ = true; } + // There is no data to return at this moment. return nullptr; } @@ -433,9 +659,11 @@ RowVectorPtr TopNRowNumber::getOutput() { return nullptr; } + // All the input data is received, so the operator can start producing + // output. RowVectorPtr output; if (merge_ != nullptr) { - output = getOutputFromSpill(); + output = RANK_FUNCTION_DISPATCH(getOutputFromSpill, rankFunction_); } else { output = getOutputFromMemory(); } @@ -472,20 +700,27 @@ RowVectorPtr TopNRowNumber::getOutputFromMemory() { const auto numOutputRowsLeft = outputBatchSize_ - offset; if (outputPartition_->rows.size() > numOutputRowsLeft) { - // Only a partial partition can be output in this getOutput() call. // Output as many rows as possible. - // NOTE: the partial output partition erases the yielded output rows - // and next getOutput() call starts with the remaining rows. - appendPartitionRows( - *outputPartition_, numOutputRowsLeft, offset, rowNumbers); + RANK_FUNCTION_DISPATCH( + appendPartitionRows, + rankFunction_, + *outputPartition_, + numOutputRowsLeft, + offset, + rowNumbers); offset += numOutputRowsLeft; break; } // Add all partition rows. - auto numPartitionRows = outputPartition_->rows.size(); - appendPartitionRows( - *outputPartition_, numPartitionRows, offset, rowNumbers); + const auto numPartitionRows = outputPartition_->rows.size(); + RANK_FUNCTION_DISPATCH( + appendPartitionRows, + rankFunction_, + *outputPartition_, + numPartitionRows, + offset, + rowNumbers); offset += numPartitionRows; outputPartition_ = nullptr; } @@ -512,13 +747,15 @@ RowVectorPtr TopNRowNumber::getOutputFromMemory() { return output; } -bool TopNRowNumber::isNewPartition( +bool TopNRowNumber::compareSpillRowColumns( const RowVectorPtr& output, vector_size_t index, - SpillMergeStream* next) { + const SpillMergeStream* next, + vector_size_t startColumn, + vector_size_t endColumn) { VELOX_CHECK_GT(index, 0); - for (auto i = 0; i < numPartitionKeys_; ++i) { + for (auto i = startColumn; i < endColumn; ++i) { if (!output->childAt(inputChannels_[i]) ->equalValueAt( next->current().childAt(i).get(), @@ -530,22 +767,79 @@ bool TopNRowNumber::isNewPartition( return false; } -void TopNRowNumber::setupNextOutput( +// Compares the partition keys for new partitions. +bool TopNRowNumber::isNewPartition( const RowVectorPtr& output, - int32_t rowNumber) { - auto* lookAhead = merge_->next(); - if (lookAhead == nullptr) { - nextRowNumber_ = 0; + vector_size_t index, + const SpillMergeStream* next) { + return compareSpillRowColumns(output, index, next, 0, numPartitionKeys_); +} + +// Compares the sorting keys for determining peers. +bool TopNRowNumber::isNewRank( + const RowVectorPtr& output, + vector_size_t index, + const SpillMergeStream* next) { + return compareSpillRowColumns( + output, + index, + next, + numPartitionKeys_, + numPartitionKeys_ + numSortingKeys_); +} + +template +void TopNRowNumber::computeNextRankInSpill( + const RowVectorPtr& output, + vector_size_t index, + const SpillMergeStream* next) { + if (isNewPartition(output, index, next)) { + nextRank_ = 1; + numPeers_ = 1; + return; + } + + if constexpr (TRank == core::TopNRowNumberNode::RankFunction::kRowNumber) { + nextRank_ += 1; return; } - if (isNewPartition(output, output->size(), lookAhead)) { - nextRowNumber_ = 0; + // The function is either rank or dense_rank. + // This row belongs to the same partition as the previous row. However, + // it should be determined if it is a peer row as well. If its a peer, + // then increase numPeers_ but the rank remains unchanged. + if (!isNewRank(output, index, next)) { + numPeers_ += 1; return; } - nextRowNumber_ = rowNumber; - if (nextRowNumber_ < limit_) { + // The row is not a peer, so increment the rank and peers accordingly. + if constexpr (TRank == core::TopNRowNumberNode::RankFunction::kDenseRank) { + nextRank_ += 1; + numPeers_ = 1; + return; + } + + // Rank function increments by number of peers. + nextRank_ += numPeers_; + numPeers_ = 1; +} + +template +void TopNRowNumber::setupNextOutput(const RowVectorPtr& output) { + auto resetNextRankAndPeer = [this]() { + nextRank_ = 1; + numPeers_ = 1; + }; + + auto* lookAhead = merge_->next(); + if (lookAhead == nullptr) { + resetNextRankAndPeer(); + return; + } + + computeNextRankInSpill(output, output->size(), lookAhead); + if (nextRank_ <= limit_) { return; } @@ -553,16 +847,17 @@ void TopNRowNumber::setupNextOutput( lookAhead->pop(); while (auto* next = merge_->next()) { if (isNewPartition(output, output->size(), next)) { - nextRowNumber_ = 0; + resetNextRankAndPeer(); return; } next->pop(); } // This partition is the last partition. - nextRowNumber_ = 0; + resetNextRankAndPeer(); } +template RowVectorPtr TopNRowNumber::getOutputFromSpill() { VELOX_CHECK_NOT_NULL(merge_); @@ -570,37 +865,32 @@ RowVectorPtr TopNRowNumber::getOutputFromSpill() { // All rows from the same partition will appear together. // We'll identify partition boundaries by comparing partition keys of the // current row with the previous row. When new partition starts, we'll reset - // row number to zero. Once row number reaches the 'limit_', we'll start + // nextRank_ and numPeers_. Once rank reaches the 'limit_', we'll start // dropping rows until the next partition starts. // We'll emit output every time we accumulate 'outputBatchSize_' rows. - auto output = BaseVector::create(outputType_, outputBatchSize_, pool()); - FlatVector* rowNumbers = nullptr; + FlatVector* rankValues = nullptr; if (generateRowNumber_) { - rowNumbers = output->children().back()->as>(); + rankValues = output->children().back()->as>(); } // Index of the next row to append to output. vector_size_t index = 0; - - // Row number of the next row in the current partition. - vector_size_t rowNumber = nextRowNumber_; - VELOX_CHECK_LT(rowNumber, limit_); + VELOX_CHECK_LE(nextRank_, limit_); for (;;) { auto next = merge_->next(); if (next == nullptr) { break; } - // Check if this row comes from a new partition. - if (index > 0 && isNewPartition(output, index, next)) { - rowNumber = 0; + if (index > 0) { + computeNextRankInSpill(output, index, next); } // Copy this row to the output buffer if this partition has // < limit_ rows output. - if (rowNumber < limit_) { + if (nextRank_ <= limit_) { for (auto i = 0; i < inputChannels_.size(); ++i) { output->childAt(inputChannels_[i]) ->copy( @@ -609,12 +899,11 @@ RowVectorPtr TopNRowNumber::getOutputFromSpill() { next->currentIndex(), 1); } - if (rowNumbers) { - // Row numbers start with 1. - rowNumbers->set(index, rowNumber + 1); + + if (rankValues) { + rankValues->set(index, nextRank_); } ++index; - ++rowNumber; } // Pop this row from the spill. @@ -625,8 +914,8 @@ RowVectorPtr TopNRowNumber::getOutputFromSpill() { // Prepare the next batch : // i) If 'limit_' is reached for this partition, then skip the rows // until the next partition. - // ii) If the next row is from a new partition, then reset rowNumber_. - setupNextOutput(output, rowNumber); + // ii) If the next row is from a new partition, then reset nextRank_. + setupNextOutput(output); return output; } } @@ -692,8 +981,11 @@ void TopNRowNumber::reclaim( // TODO Add support for spilling after noMoreInput(). LOG(WARNING) << "Can't reclaim from topNRowNumber operator which has started producing output: " - << pool()->name() << ", usage: " << succinctBytes(pool()->usedBytes()) - << ", reservation: " << succinctBytes(pool()->reservedBytes()); + << pool()->name() << ", root pool: " << pool()->root()->name() + << ", used: " << succinctBytes(pool()->usedBytes()) + << ", reservation: " << succinctBytes(pool()->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool()->root()->reservedBytes()); return; } @@ -741,14 +1033,15 @@ void TopNRowNumber::ensureInputFits(const RowVectorPtr& input) { if ((tableIncrementBytes == 0) && (freeRows > input->size()) && (outOfLineBytes == 0 || outOfLineFreeBytes >= outOfLineBytesPerRow * input->size())) { - // Enough free rows for input rows and enough variable length free space. + // Enough free rows for input rows and enough variable length free + // space. return; } } - // Check if we can increase reservation. The increment is the largest of twice - // the maximum increment from this input and 'spillableReservationGrowthPct_' - // of the current memory usage. + // Check if we can increase reservation. The increment is the largest of + // twice the maximum increment from this input and + // 'spillableReservationGrowthPct_' of the current memory usage. const auto targetIncrementBytes = std::max( incrementBytes * 2, currentUsage * spillConfig_->spillableReservationGrowthPct / 100); @@ -761,8 +1054,11 @@ void TopNRowNumber::ensureInputFits(const RowVectorPtr& input) { LOG(WARNING) << "Failed to reserve " << succinctBytes(targetIncrementBytes) << " for memory pool " << pool()->name() - << ", usage: " << succinctBytes(pool()->usedBytes()) - << ", reservation: " << succinctBytes(pool()->reservedBytes()); + << ", root pool: " << pool()->root()->name() + << ", used: " << succinctBytes(pool()->usedBytes()) + << ", reservation: " << succinctBytes(pool()->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool()->root()->reservedBytes()); } void TopNRowNumber::spill() { @@ -789,4 +1085,80 @@ void TopNRowNumber::setupSpiller() { &spillConfig_.value(), spillStats_.get()); } + +// Using the underlying vector of the priority queue for the algorithms to +// check duplicates and count the number of top rank rows. This makes the +// algorithms O(n). There could be other approaches to make the +// algorithms O(1), but would trade memory efficiency. +namespace { +template +S& PriorityQueueVector(std::priority_queue& q) { + struct PrivateQueue : private std::priority_queue { + static S& Container(std::priority_queue& q) { + return q.*&PrivateQueue::c; + } + }; + return PrivateQueue::Container(q); +} +} // namespace + +char* TopNRowNumber::TopRows::removeTopRankRows() { + VELOX_CHECK(!rows.empty()); + + char* topRow = rows.top(); + rows.pop(); + + while (!rows.empty()) { + char* newTopRow = rows.top(); + if (rowComparator.compare(topRow, newTopRow) != 0) { + return topRow; + } + rows.pop(); + } + return topRow; +} + +vector_size_t TopNRowNumber::TopRows::numTopRankRows() { + VELOX_CHECK(!rows.empty()); + + tempTopRankRows.clear(); + SCOPE_EXIT { + tempTopRankRows.clear(); + }; + auto popAndSaveTopRow = [&]() { + tempTopRankRows.push_back(rows.top()); + rows.pop(); + }; + + char* topRow = rows.top(); + popAndSaveTopRow(); + while (!rows.empty()) { + if (rowComparator.compare(topRow, rows.top()) == 0) { + popAndSaveTopRow(); + } else { + break; + } + } + + vector_size_t numTopRows = tempTopRankRows.size(); + // Re-insert all rows with the top rank row. + for (char* row : tempTopRankRows) { + rows.push(row); + } + return numTopRows; +} + +bool TopNRowNumber::TopRows::isDuplicate( + const std::vector& decodedVectors, + vector_size_t index) { + const std::vector> partitionRowsVector = + PriorityQueueVector(rows); + for (const char* row : partitionRowsVector) { + if (rowComparator.compare(decodedVectors, index, row) == 0) { + return true; + } + } + return false; +} + } // namespace facebook::velox::exec diff --git a/velox/exec/TopNRowNumber.h b/velox/exec/TopNRowNumber.h index dc21f0f93c4..12daa35cce3 100644 --- a/velox/exec/TopNRowNumber.h +++ b/velox/exec/TopNRowNumber.h @@ -15,6 +15,8 @@ */ #pragma once +#include + #include "velox/exec/HashTable.h" #include "velox/exec/Operator.h" #include "velox/exec/Spiller.h" @@ -22,22 +24,63 @@ namespace facebook::velox::exec { class TopNRowNumberSpiller; -/// Partitions the input using specified partitioning keys, sorts rows within -/// partitions using specified sorting keys, assigns row numbers and returns up -/// to specified number of rows per partition. +/// TopNRowNumber is an optimized version of a Window operator with a +/// single row_number or rank or dense_rank window function followed by a +/// rank <= N filter. N must be >= 0. If the TopNRowNumber has no partition +/// keys, then all the rows belong to a single partition. However, the +/// TopNRowNumber should have at least one sorting key specified. +/// +/// TopNRowNumber is more efficient than a general Window operator as it does +/// not store all rows of a partition. Instead, it only keeps the top N +/// rows of the partition at any point. +/// +/// The operator partitions the input using specified partitioning keys, +/// and maintains a TopRows structure per partition in a HashTable. The TopRows +/// maintains a priority queue of row pointers. The priority queue is +/// kept ordered by sorting keys of the TopNRowNumber. The TopRows only retains +/// rows whose ranks satisfy the filter condition (so rank <= N). N is also +/// called the limit of the operator. To aid this filtering, the TopRows tracks +/// the greatest rank seen for each partition. +/// +/// The operator processes all input rows before beginning to output rows. /// -/// It is allowed to not specify partitioning keys. In this case the whole input -/// is treated as a single partition. +/// For each input row, it retrieves the TopRows corresponding to the partition +/// keys. The TopRows is first filled until it has N rows. Thereafter, new rows +/// are compared with the top row in the TopRows priority queue. +/// If the new rows order by values are less than (for ASC) or greater than +/// (for DESC) so row rank <= topRank, then the row is added to TopRows. +/// For each outcome, the greatest rank of the TopRows is updated as per the +/// ranking function logic. +/// For each function type, the rank maintenance logic is in: +/// - processRowWithinLimit() function when the TopRows is filling the first +/// N rows. +/// - processRowExceedingLimit() function when the TopRows already has N rows. /// -/// At least one sorting key must be specified. +/// After processing all the input rows, the operator proceeds to output the +/// rows. The rows might all be in memory or spilled to disk if memory +/// reclamation was triggered during processing. /// -/// The limit (maximum number of rows to return per partition) must be greater -/// than zero. +/// If the rows are in memory, then the operator iterates over each partition +/// in the HashTable, and starts outputting rows from the partition. The +/// TopRows structure maintains the rows in descending order of their ranks +/// (greatest rank at the top of the priority queue). So when outputting, +/// the operator first fixes the top rank of the partition using fixTopRank() +/// and then computes the ranks of each row using computeNextRankInMemory(). +/// The logic of the next rank differs based on the ranking function. /// -/// This is an optimized version of a Window operator with a single row_number -/// window function followed by a row_number <= N filter. +/// If the rows are in the spill, then the spiller iterates over each spilled +/// partition in order of the ranks. For each row from the spill, the next +/// rank is computed using computeNextRankInSpill() function. The logic of +/// the next rank differs based on the ranking function. +/// Note : The spill could have > limit rows for a partition as each spill +/// resets the TopRows for the partition. So stop outputting rows after +/// reaching the limit for each partition. + class TopNRowNumber : public Operator { public: + /// Runtime stat key indicating partial TopN was abandoned. + static constexpr std::string_view kAbandonedPartial = "abandonedPartial"; + TopNRowNumber( int32_t operatorId, DriverCtx* driverCtx, @@ -71,7 +114,23 @@ class TopNRowNumber : public Operator { override; private: - // A priority queue to keep track of top 'limit' rows for a given partition. + // This structure holds the top rows for a partition. It uses a priority + // queue to maintain the top rows in order of their ranks. Note the rank + // logic depends on the respective function (row_number, rank or dense_rank). + // However, a common requirement across all three is to maintain the rows in + // order of their sort keys so that the greatest rank row is always at the top + // of the queue. This ordering is done using the RowComparator passed to the + // TopRows. + // + // The number of rows in TopRows are limited to 'limit' specified for the + // operator. The greatest rank of the rows in TopRows is maintained in the + // 'topRank' variable. + // + // The TopRows structure is first filled in order to collect 'limit' + // rows. Thereafter, new rows are compared with the top row and either kept + // or discarded and the new top rank is updated. The rank computation differs + // based on the ranking function. This structure has methods for abstractions + // used for the top rank maintenance algorithms. struct TopRows { struct Compare { RowComparator& comparator; @@ -84,17 +143,43 @@ class TopNRowNumber : public Operator { std::priority_queue>, Compare> rows; - // This is the highest rank (this code will be enhanced for rank, dense_rank - // soon) seen so far in the input rows. It is compared - // with the limit for the operator. + // Temporary storage for rows with the highest rank in the partition. + std::vector> tempTopRankRows; + + RowComparator& rowComparator; + + // This is the greatest rank seen so far in the input rows. Note: rank is + // the result of the respective function computation (row_number, rank or + // dense_rank). It is compared with the expected limit for the operator. int64_t topRank = 0; + // Number of rows with the highest rank in the partition. + vector_size_t numTopRankRows(); + + // Remove all rows with the highest rank in the partition. + // Returns a pointer to the last removed row. + char* removeTopRankRows(); + + // Returns true if the row at position index in decodedVectors + // has the same order by keys as another row in the TopRows + // priority_vector. + bool isDuplicate( + const std::vector& decodedVectors, + vector_size_t index); + TopRows(HashStringAllocator* allocator, RowComparator& comparator) - : rows{{comparator}, StlAllocator(allocator)} {} + : rows{{comparator}, StlAllocator(allocator)}, + tempTopRankRows(StlAllocator(allocator)), + rowComparator(comparator) {} }; void initializeNewPartitions(); + // Cleans up any newly inserted but uninitialized partitions from the hash + // table. This is called when groupProbe throws (e.g., due to OOM) to ensure + // close() doesn't crash trying to destroy uninitialized TopRows structures. + void cleanupNewPartitions(); + TopRows& partitionAt(char* group) { return *reinterpret_cast(group + partitionOffset_); } @@ -104,26 +189,42 @@ class TopNRowNumber : public Operator { // Handles input row when the partition has not yet accumulated 'limit' rows. // Returns a pointer to the row to add to the partition accumulator. + template char* processRowWithinLimit(vector_size_t index, TopRows& partition); // Handles input row when the partition has already accumulated 'limit' rows. // Returns a pointer to the row to add to the partition accumulator. + template char* processRowExceedingLimit(vector_size_t index, TopRows& partition); - // Loop to add each row to a partition or discard the row. + // Loop to process the numInput input rows received by the operator. + template void processInputRowLoop(vector_size_t numInput); // Adds input row to a partition or discards the row. + template void processInputRow(vector_size_t index, TopRows& partition); // Returns next partition to add to output or nullptr if there are no // partitions left. TopRows* nextPartition(); - // Appends numRows of the output partition the output. Note: The rows are - // popped in reverse order of the row_number. + // If there are many rows with the highest rank, then the topRank + // of the partition can oscillate between a very small value and a + // value > limit. Fix the partition for this condition before starting to + // output the partition. + vector_size_t fixTopRank(TopRows& partition); + + // Computes the rank for the next row to be output + // (all output rows in memory). + template + void computeNextRankInMemory(TopRows& partition, vector_size_t rowIndex); + + // Appends numRows of the current partition to the output. Note: The rows are + // popped in reverse order of the rank. // NOTE: This function erases the yielded output rows from the partition // and the next call starts with the remaining rows. + template void appendPartitionRows( TopRows& partition, vector_size_t numRows, @@ -141,6 +242,7 @@ class TopNRowNumber : public Operator { void setupSpiller(); + template RowVectorPtr getOutputFromSpill(); RowVectorPtr getOutputFromMemory(); @@ -150,17 +252,42 @@ class TopNRowNumber : public Operator { bool isNewPartition( const RowVectorPtr& output, vector_size_t index, - SpillMergeStream* next); + const SpillMergeStream* next); - // Sets nextRowNumber_ to rowNumber. Checks if next row in 'merge_' belongs to - // a different partition than last row in 'output' and if so updates - // nextRowNumber_ to 0. Also, checks current partition reached the limit on - // number of rows and if so advances 'merge_' to the first row on the next - // partition and sets nextRowNumber_ to 0. + // Returns true if 'next' row is a new rank (rows differ on order by keys) + // of the previous row in the partition (at output[index] of the + // output block). + bool isNewRank( + const RowVectorPtr& output, + vector_size_t index, + const SpillMergeStream* next); + + // Utility method to compare values from startColumn to endColumn for + // 'next' row from SpillMergeStream with current row of output (at index). + bool compareSpillRowColumns( + const RowVectorPtr& output, + vector_size_t index, + const SpillMergeStream* next, + vector_size_t startColumn, + vector_size_t endColumn); + + // Computes next rank value for spill output. + template + inline void computeNextRankInSpill( + const RowVectorPtr& output, + vector_size_t index, + const SpillMergeStream* next); + + // Checks if next row in 'merge_' belongs to a different partition than last + // row in 'output' and if so updates nextRank_ and numPeers_ to 1. + // Also, checks current partition reached the limit on rank and + // if so advances 'merge_' to the first row on the next + // partition and sets nextRank_ and numPeers_ to 0. // // @post 'merge_->next()' is either at end or points to a row that should be - // included in the next output batch using 'nextRowNumber_'. - void setupNextOutput(const RowVectorPtr& output, int32_t rowNumber); + // included in the next output batch using 'nextRank_'. + template + void setupNextOutput(const RowVectorPtr& output); // Called in noMoreInput() and spill(). void updateEstimatedOutputRowSize(); @@ -169,11 +296,15 @@ class TopNRowNumber : public Operator { // cardinality sufficiently. Returns false if spilling was triggered earlier. bool abandonPartialEarly() const; + // Rank function semantics of operator. + const core::TopNRowNumberNode::RankFunction rankFunction_; + const int32_t limit_; const bool generateRowNumber_; const size_t numPartitionKeys_; + const size_t numSortingKeys_; // Input columns in the order of: partition keys, sorting keys, the rest. const std::vector inputChannels_; @@ -260,7 +391,11 @@ class TopNRowNumber : public Operator { // Used to sort-merge spilled data. std::unique_ptr> merge_; - // Row number for the first row in the next output batch from the spiller. - int32_t nextRowNumber_{0}; + // Row number/rank or dense_rank for the first row in the next output batch + // from the spiller. + vector_size_t nextRank_{1}; + // Number of peers of first row in the previous output batch. This is used + // in rank calculation. + vector_size_t numPeers_{1}; }; } // namespace facebook::velox::exec diff --git a/velox/exec/TraceUtil.h b/velox/exec/TraceUtil.h index b9814e0e90d..a3483320e87 100644 --- a/velox/exec/TraceUtil.h +++ b/velox/exec/TraceUtil.h @@ -16,138 +16,5 @@ #pragma once -#include -#include -#include "velox/common/file/FileSystems.h" -#include "velox/core/PlanNode.h" -#include "velox/exec/Task.h" -#include "velox/type/Type.h" - -#include - -namespace facebook::velox::exec::trace { - -/// Creates a directory to store the query trace metdata and data. -void createTraceDirectory( - const std::string& traceDir, - const std::string& directoryConfig = ""); - -/// Returns the trace directory for a given query. -std::string getQueryTraceDirectory( - const std::string& traceDir, - const std::string& queryId); - -/// Returns the trace directory for a given query task. -std::string getTaskTraceDirectory( - const std::string& traceDir, - const Task& task); - -std::string getTaskTraceDirectory( - const std::string& traceDir, - const std::string& queryId, - const std::string& taskId); - -/// Returns the file path for a given task's metadata trace file. -std::string getTaskTraceMetaFilePath(const std::string& taskTraceDir); - -/// Returns the trace directory for a given traced plan node. -std::string getNodeTraceDirectory( - const std::string& taskTraceDir, - const std::string& nodeId); - -/// Returns the trace directory for a given traced pipeline. -std::string getPipelineTraceDirectory( - const std::string& nodeTraceDir, - uint32_t pipelineId); - -/// Returns the trace directory for a given traced operator. -std::string getOpTraceDirectory( - const std::string& taskTraceDir, - const std::string& nodeId, - uint32_t pipelineId, - uint32_t driverId); - -std::string getOpTraceDirectory( - const std::string& nodeTraceDir, - uint32_t pipelineId, - uint32_t driverId); - -/// Returns the file path for a given operator's traced input file. -std::string getOpTraceInputFilePath(const std::string& opTraceDir); - -/// Returns the file path for a given operator's traced split file. -std::string getOpTraceSplitFilePath(const std::string& opTraceDir); - -/// Returns the file path for a given operator's traced input file. -std::string getOpTraceSummaryFilePath(const std::string& opTraceDir); - -/// Extracts the input data type for the trace scan operator. The function first -/// uses the traced node id to find traced operator's plan node from the traced -/// plan fragment. Then it uses the specified source node index to find the -/// output data type from its source node plans as the input data type of the -/// traced plan node. -/// -/// For hash join plan node, there are two source nodes, the output data type -/// of the first node is the input data type of the 'HashProbe' operator, and -/// the output data type of the second one is the input data type of the -/// 'HashBuild' operator. -/// -/// @param tracedPlan The root node of the trace plan fragment. -/// @param tracedNodeId The node id of the trace node. -/// @param sourceIndex The source index of the specific traced operator. -RowTypePtr getDataType( - const core::PlanNodePtr& tracedPlan, - const std::string& tracedNodeId, - size_t sourceIndex = 0); - -/// Extracts pipeline IDs in ascending order by listing the trace directory, -/// then decoding the names of the subdirectories to obtain the pipeline IDs, -/// and finally sorting them. 'nodeTraceDir' corresponds to the trace directory -/// of the plan node. -std::vector listPipelineIds( - const std::string& nodeTraceDir, - const std::shared_ptr& fs); - -/// Extracts driver IDs in ascending order by listing the trace directory for a -/// given pipeline then decoding the names of the subdirectories to obtain the -/// driver IDs, and finally sorting them. 'nodeTraceDir' corresponds to the -/// trace directory of the plan node. -std::vector listDriverIds( - const std::string& nodeTraceDir, - uint32_t pipelineId, - const std::shared_ptr& fs); - -/// Extracts the driver IDs from the comma-separated list of driver IDs string. -std::vector extractDriverIds(const std::string& driverIds); - -/// Extracts task ids of the query tracing by listing the query trace directory. -/// 'traceDir' is the root trace directory. 'queryId' is the query id. -std::vector getTaskIds( - const std::string& traceDir, - const std::string& queryId, - const std::shared_ptr& fs); - -/// Gets the metadata from a given task metadata file which includes query plan, -/// configs and connector properties. -folly::dynamic getTaskMetadata( - const std::string& taskMetaFilePath, - const std::shared_ptr& fs); - -/// Checks whether the operator can be traced. -bool canTrace(const std::string& operatorType); - -/// Gets the specified the trace node from 'plan'. In the returned trace node, -/// we replace its source nodes with DummySourceNode for replay. -core::PlanNodePtr getTraceNode( - const core::PlanNodePtr& plan, - core::PlanNodeId nodeId); - -using TraceNodeFactory = std::function< - core::PlanNodePtr(const core::PlanNode*, const core::PlanNodeId&)>; - -void registerTraceNodeFactory( - const std::string& operatorType, - TraceNodeFactory&& factory); - -void registerDummySourceSerDe(); -} // namespace facebook::velox::exec::trace +// To maintain backward compatibility. +#include "velox/exec/trace/TraceUtil.h" diff --git a/velox/exec/Unnest.cpp b/velox/exec/Unnest.cpp index 99aa13fdd3b..7d464112cea 100644 --- a/velox/exec/Unnest.cpp +++ b/velox/exec/Unnest.cpp @@ -16,6 +16,7 @@ #include "velox/exec/Unnest.h" #include "velox/common/base/Nulls.h" +#include "velox/exec/OperatorType.h" #include "velox/vector/FlatVector.h" namespace facebook::velox::exec { @@ -40,11 +41,16 @@ Unnest::Unnest( unnestNode->outputType(), operatorId, unnestNode->id(), - "Unnest"), + OperatorType::kUnnest), withOrdinality_(unnestNode->hasOrdinality()), - withEmptyUnnestValue_(unnestNode->hasEmptyUnnestValue()), + withMarker_(unnestNode->hasMarker()), maxOutputSize_( - driverCtx->queryConfig().unnestSplitOutput() + // If splitOutput is set to true in the UnnestNode or it's not set at + // all and it's enabled in the QueryConfig. + ((unnestNode->splitOutput().has_value() && + unnestNode->splitOutput().value()) || + (!unnestNode->splitOutput().has_value() && + driverCtx->queryConfig().unnestSplitOutput())) ? outputBatchRows() : std::numeric_limits::max()) { const auto& inputType = unnestNode->sources()[0]->outputType(); @@ -60,11 +66,11 @@ Unnest::Unnest( unnestDecoded_.resize(unnestVariables.size()); column_index_t checkOutputChannel = outputType_->size() - 1; - if (withEmptyUnnestValue_) { + if (withMarker_) { VELOX_CHECK_EQ( outputType_->childAt(checkOutputChannel), BOOLEAN(), - "Empty unnest value column should be BOOLEAN type."); + "Marker column should be BOOLEAN type."); --checkOutputChannel; } if (withOrdinality_) { @@ -194,7 +200,7 @@ Unnest::RowRange Unnest::extractRowRange(vector_size_t inputSize) const { if (rawMaxSizes_[inputRow] == 0) { VELOX_CHECK_EQ(remainingInnerRows, 0); hasEmptyUnnestValue = true; - if (withEmptyUnnestValue_) { + if (withMarker_) { remainingInnerRows = 1; } } @@ -237,12 +243,11 @@ void Unnest::generateRepeatedColumns( vector_size_t* rawRepeatedIndices = repeatedIndices->asMutable(); - const bool generateEmptyUnnestValue = - withEmptyUnnestValue_ && range.hasEmptyUnnestValue; + const bool generateMarker = withMarker_ && range.hasEmptyUnnestValue; vector_size_t index{0}; VELOX_CHECK_GT(range.numInputRows, 0); // Record the row number to process. - if (generateEmptyUnnestValue) { + if (generateMarker) { range.forEachRow( [&](vector_size_t row, vector_size_t /*start*/, vector_size_t size) { if (FOLLY_UNLIKELY(size == 0)) { @@ -302,7 +307,7 @@ const Unnest::UnnestChannelEncoding Unnest::generateEncodingForChannel( range.forEachRow( [&](vector_size_t row, vector_size_t start, vector_size_t size) { const auto end = start + size; - if (size == 0 && withEmptyUnnestValue_) { + if (size == 0 && withMarker_) { identityMapping = false; bits::setNull(rawNulls, index++, true); } else if (!currentDecoded.isNullAt(row)) { @@ -343,9 +348,8 @@ VectorPtr Unnest::generateOrdinalityVector(const RowRange& range) { // Set the ordinality at each result row to be the index of the element in // the original array (or map) plus one. auto* rawOrdinality = ordinalityVector->mutableRawValues(); - const bool hasEmptyUnnestValue = - withEmptyUnnestValue_ && range.hasEmptyUnnestValue; - if (!hasEmptyUnnestValue) { + const bool hasMarker = withMarker_ && range.hasEmptyUnnestValue; + if (!hasMarker) { range.forEachRow( [&](vector_size_t /*row*/, vector_size_t start, vector_size_t size) { std::iota(rawOrdinality, rawOrdinality + size, start + 1); @@ -374,28 +378,28 @@ VectorPtr Unnest::generateOrdinalityVector(const RowRange& range) { return ordinalityVector; } -VectorPtr Unnest::generateEmptyUnnestValueVector(const RowRange& range) { - VELOX_CHECK(withEmptyUnnestValue_); +VectorPtr Unnest::generateMarkerVector(const RowRange& range) { + VELOX_CHECK(withMarker_); VELOX_DCHECK_GT(range.numInputRows, 0); if (!range.hasEmptyUnnestValue) { return BaseVector::createConstant( - BOOLEAN(), false, range.numInnerRows, pool()); + BOOLEAN(), true, range.numInnerRows, pool()); } - // Create a vector with all elements set to false initially assuming most + // Create a vector with all elements set to true initially assuming most // output rows have non-empty unnest values. - auto emptyBuffer = - velox::AlignedBuffer::allocate(range.numInnerRows, pool(), false); - auto emptyVector = std::make_shared>( + auto markerBuffer = + velox::AlignedBuffer::allocate(range.numInnerRows, pool(), true); + auto markerVector = std::make_shared>( pool(), /*type=*/BOOLEAN(), /*nulls=*/nullptr, range.numInnerRows, - /*values=*/std::move(emptyBuffer), + /*values=*/std::move(markerBuffer), /*stringBuffers=*/std::vector{}); - // Set each output row has empty unnest values. - auto* const rawEmpty = emptyVector->mutableRawValues(); + // Set each output row with empty unnest values to false. + auto* const rawMarker = markerVector->mutableRawValues(); size_t index{0}; range.forEachRow( [&](vector_size_t /*row*/, vector_size_t start, vector_size_t size) { @@ -403,12 +407,12 @@ VectorPtr Unnest::generateEmptyUnnestValueVector(const RowRange& range) { index += size; } else { VELOX_DCHECK_EQ(size, 0); - bits::setBit(rawEmpty, index++, true); + bits::setBit(rawMarker, index++, false); } }, rawMaxSizes_, firstInnerRowStart_); - return emptyVector; + return markerVector; } RowVectorPtr Unnest::generateOutput(const RowRange& range) { @@ -443,8 +447,8 @@ RowVectorPtr Unnest::generateOutput(const RowRange& range) { if (withOrdinality_) { outputs[outputColumnIndex++] = generateOrdinalityVector(range); } - if (withEmptyUnnestValue_) { - outputs[outputColumnIndex++] = generateEmptyUnnestValueVector(range); + if (withMarker_) { + outputs[outputColumnIndex++] = generateMarkerVector(range); } return std::make_shared( diff --git a/velox/exec/Unnest.h b/velox/exec/Unnest.h index 58812e86ddd..e1fb967525c 100644 --- a/velox/exec/Unnest.h +++ b/velox/exec/Unnest.h @@ -134,15 +134,15 @@ class Unnest : public Operator { // Invoked by generateOutput for the ordinality column. VectorPtr generateOrdinalityVector(const RowRange& rowRange); - // Invoked by generateOutput for the empty unnest value column. - VectorPtr generateEmptyUnnestValueVector(const RowRange& rowRange); + // Invoked by generateOutput for the marker column. + VectorPtr generateMarkerVector(const RowRange& rowRange); // Invoked when finish one input batch processing to reset the internal // execution state for the next batch. void finishInput(); const bool withOrdinality_; - const bool withEmptyUnnestValue_; + const bool withMarker_; // The maximum number of output batch rows. const vector_size_t maxOutputSize_; diff --git a/velox/exec/Values.cpp b/velox/exec/Values.cpp index 2660732979a..cc9503594e8 100644 --- a/velox/exec/Values.cpp +++ b/velox/exec/Values.cpp @@ -15,6 +15,7 @@ */ #include "velox/exec/Values.h" #include "velox/common/testutil/TestValue.h" +#include "velox/exec/OperatorType.h" using facebook::velox::common::testutil::TestValue; namespace facebook::velox::exec { @@ -28,7 +29,7 @@ Values::Values( values->outputType(), operatorId, values->id(), - "Values"), + OperatorType::kValues), valueNodes_(std::move(values)), roundsLeft_(valueNodes_->repeatTimes()) {} @@ -45,8 +46,9 @@ void Values::initialize() { // If this is parallelizable, copy the values to prevent Vectors from // being shared across threads. Note that the contract in ValuesNode is // that this should only be enabled for testing. - values_.emplace_back(std::static_pointer_cast( - vector->testingCopyPreserveEncodings())); + values_.emplace_back( + std::static_pointer_cast( + vector->testingCopyPreserveEncodings())); } else { values_.emplace_back(vector); } diff --git a/velox/exec/VectorHasher-inl.h b/velox/exec/VectorHasher-inl.h index 63f842c774b..6964d3a3c08 100644 --- a/velox/exec/VectorHasher-inl.h +++ b/velox/exec/VectorHasher-inl.h @@ -28,14 +28,14 @@ bool VectorHasher::tryMapToRangeSimd( constexpr int kWidth = xsimd::batch::size; for (; row + kWidth <= rows.end(); row += kWidth) { auto data = xsimd::load_unaligned(values + row); - int32_t gtMax = simd::toBitMask(data > allHigh); - int32_t ltMin = simd::toBitMask(data < allLow); + bool gtMax = simd::any(data > allHigh); + bool ltMin = simd::any(data < allLow); // value - (low - 1) doesn't work when low is the lowest possible (e.g. // std::numeric_limits::min()) if constexpr (sizeof(T) == sizeof(uint64_t)) { (data - allLow + allOne).store_unaligned(result + row); } - if ((gtMax | ltMin) != 0) { + if (gtMax || ltMin) { inRange = false; break; } diff --git a/velox/exec/VectorHasher.cpp b/velox/exec/VectorHasher.cpp index e502575aae7..70d8295b70c 100644 --- a/velox/exec/VectorHasher.cpp +++ b/velox/exec/VectorHasher.cpp @@ -23,35 +23,35 @@ namespace facebook::velox::exec { -#define VALUE_ID_TYPE_DISPATCH(TEMPLATE_FUNC, typeKind, ...) \ - [&]() { \ - switch (typeKind) { \ - case TypeKind::BOOLEAN: { \ - return TEMPLATE_FUNC(__VA_ARGS__); \ - } \ - case TypeKind::TINYINT: { \ - return TEMPLATE_FUNC(__VA_ARGS__); \ - } \ - case TypeKind::SMALLINT: { \ - return TEMPLATE_FUNC(__VA_ARGS__); \ - } \ - case TypeKind::INTEGER: { \ - return TEMPLATE_FUNC(__VA_ARGS__); \ - } \ - case TypeKind::BIGINT: { \ - return TEMPLATE_FUNC(__VA_ARGS__); \ - } \ - case TypeKind::VARCHAR: \ - case TypeKind::VARBINARY: { \ - return TEMPLATE_FUNC(__VA_ARGS__); \ - } \ - case TypeKind::TIMESTAMP: { \ - return TEMPLATE_FUNC(__VA_ARGS__); \ - } \ - default: \ - VELOX_UNREACHABLE( \ - "Unsupported value ID type: ", mapTypeKindToName(typeKind)); \ - } \ +#define VALUE_ID_TYPE_DISPATCH(TEMPLATE_FUNC, typeKind, ...) \ + [&]() { \ + switch (typeKind) { \ + case TypeKind::BOOLEAN: { \ + return TEMPLATE_FUNC(__VA_ARGS__); \ + } \ + case TypeKind::TINYINT: { \ + return TEMPLATE_FUNC(__VA_ARGS__); \ + } \ + case TypeKind::SMALLINT: { \ + return TEMPLATE_FUNC(__VA_ARGS__); \ + } \ + case TypeKind::INTEGER: { \ + return TEMPLATE_FUNC(__VA_ARGS__); \ + } \ + case TypeKind::BIGINT: { \ + return TEMPLATE_FUNC(__VA_ARGS__); \ + } \ + case TypeKind::VARCHAR: \ + case TypeKind::VARBINARY: { \ + return TEMPLATE_FUNC(__VA_ARGS__); \ + } \ + case TypeKind::TIMESTAMP: { \ + return TEMPLATE_FUNC(__VA_ARGS__); \ + } \ + default: \ + VELOX_UNREACHABLE( \ + "Unsupported value ID type: ", TypeKindName::toName(typeKind)); \ + } \ }() namespace { @@ -351,6 +351,8 @@ bool VectorHasher::makeValueIdsDecoded( bool VectorHasher::computeValueIds( const SelectivityVector& rows, raw_vector& result) { + checkTypeSupportsValueIds(); + return VALUE_ID_TYPE_DISPATCH(makeValueIds, typeKind_, rows, result.data()); } @@ -361,6 +363,8 @@ bool VectorHasher::computeValueIdsForRows( int32_t nullByte, uint8_t nullMask, raw_vector& result) { + checkTypeSupportsValueIds(); + return VALUE_ID_TYPE_DISPATCH( makeValueIdsForRows, typeKind_, @@ -422,32 +426,24 @@ void VectorHasher::lookupValueIdsTyped( result[row] = multiplier_ == 1 ? id : result[row] + multiplier_ * id; }); } - } else if (decoded.isIdentityMapping()) { + return; + } + + if (decoded.isIdentityMapping() && !decoded.mayHaveNulls()) { if (Kind == TypeKind::BIGINT && isRange_) { lookupIdsRangeSimd(decoded, rows, result); - } else if (Kind == TypeKind::INTEGER && isRange_) { + rows.updateBounds(); + return; + } + if (Kind == TypeKind::INTEGER && isRange_) { lookupIdsRangeSimd(decoded, rows, result); - } else { - rows.applyToSelected([&](vector_size_t row) INLINE_LAMBDA { - if (decoded.isNullAt(row)) { - if (multiplier_ == 1) { - result[row] = 0; - } - return; - } - T value = decoded.valueAt(row); - uint64_t id = lookupValueId(value); - if (id == kUnmappable) { - rows.setValid(row, false); - return; - } - result[row] = multiplier_ == 1 ? id : result[row] + multiplier_ * id; - }); + rows.updateBounds(); + return; } - rows.updateBounds(); - } else { - hashes.resize(decoded.base()->size()); - std::fill(hashes.begin(), hashes.end(), 0); + } + + if (decoded.isIdentityMapping() || + rows.countSelected() <= decoded.base()->size()) { rows.applyToSelected([&](vector_size_t row) INLINE_LAMBDA { if (decoded.isNullAt(row)) { if (multiplier_ == 1) { @@ -455,21 +451,41 @@ void VectorHasher::lookupValueIdsTyped( } return; } - auto baseIndex = decoded.index(row); - uint64_t id = hashes[baseIndex]; - if (id == 0) { - T value = decoded.valueAt(row); - id = lookupValueId(value); - if (id == kUnmappable) { - rows.setValid(row, false); - return; - } - hashes[baseIndex] = id; + T value = decoded.valueAt(row); + uint64_t id = lookupValueId(value); + if (id == kUnmappable) { + rows.setValid(row, false); + return; } result[row] = multiplier_ == 1 ? id : result[row] + multiplier_ * id; }); rows.updateBounds(); + return; } + + hashes.resize(decoded.base()->size()); + std::fill(hashes.begin(), hashes.end(), 0); + rows.applyToSelected([&](vector_size_t row) INLINE_LAMBDA { + if (decoded.isNullAt(row)) { + if (multiplier_ == 1) { + result[row] = 0; + } + return; + } + auto baseIndex = decoded.index(row); + uint64_t id = hashes[baseIndex]; + if (id == 0) { + T value = decoded.valueAt(row); + id = lookupValueId(value); + if (id == kUnmappable) { + rows.setValid(row, false); + return; + } + hashes[baseIndex] = id; + } + result[row] = multiplier_ == 1 ? id : result[row] + multiplier_ * id; + }); + rows.updateBounds(); } template @@ -533,6 +549,8 @@ void VectorHasher::lookupValueIds( SelectivityVector& rows, ScratchMemory& scratchMemory, raw_vector& result) const { + checkTypeSupportsValueIds(); + scratchMemory.decoded.decode(values, rows); VALUE_ID_TYPE_DISPATCH( lookupValueIdsTyped, @@ -552,7 +570,7 @@ void VectorHasher::hash( result[row] = mix ? bits::hashMix(result[row], kNullHash) : kNullHash; }); } else { - if (type_->providesCustomComparison()) { + if (typeProvidesCustomComparison_) { VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH( hashValues, true, typeKind_, rows, mix, result.data()); } else { @@ -581,7 +599,7 @@ void VectorHasher::precompute(const BaseVector& value) { const SelectivityVector rows(1, true); decoded_.decode(value, rows); - if (type_->providesCustomComparison()) { + if (typeProvidesCustomComparison_) { precomputedHash_ = VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH( hashOne, true, typeKind_, decoded_, 0); } else { @@ -596,6 +614,8 @@ void VectorHasher::analyze( int32_t offset, int32_t nullByte, uint8_t nullMask) { + checkTypeSupportsValueIds(); + VALUE_ID_TYPE_DISPATCH( analyzeTyped, typeKind_, groups, numGroups, offset, nullByte, nullMask); } @@ -668,6 +688,10 @@ void VectorHasher::setRangeOverflow() { std::unique_ptr VectorHasher::getFilter( bool nullAllowed) const { + if (typeProvidesCustomComparison_) { + return nullptr; + } + switch (typeKind_) { case TypeKind::TINYINT: [[fallthrough]]; @@ -686,8 +710,19 @@ std::unique_ptr VectorHasher::getFilter( return common::createBigintValues(values, nullAllowed); } [[fallthrough]]; + case TypeKind::VARCHAR: + [[fallthrough]]; + case TypeKind::VARBINARY: + if (!distinctOverflow_) { + std::vector values; + values.reserve(uniqueValues_.size()); + for (const auto& value : uniqueValues_) { + values.emplace_back(value.asString()); + } + return std::make_unique(values, nullAllowed); + } + [[fallthrough]]; default: - // TODO Add support for strings. return nullptr; } } @@ -767,6 +802,12 @@ void VectorHasher::cardinality( int32_t reservePct, uint64_t& asRange, uint64_t& asDistincts) { + if (!typeSupportsValueIds()) { + asRange = kRangeTooLarge; + asDistincts = kRangeTooLarge; + return; + } + if (typeKind_ == TypeKind::BOOLEAN) { hasRange_ = true; asRange = 3; @@ -808,6 +849,8 @@ uint64_t VectorHasher::enableValueIds(uint64_t multiplier, int32_t reservePct) { typeKind_, TypeKind::BOOLEAN, "A boolean VectorHasher should always be by range"); + checkTypeSupportsValueIds(); + multiplier_ = multiplier; rangeSize_ = addIdReserve(uniqueValues_.size(), reservePct) + 1; isRange_ = false; @@ -821,6 +864,8 @@ uint64_t VectorHasher::enableValueIds(uint64_t multiplier, int32_t reservePct) { uint64_t VectorHasher::enableValueRange( uint64_t multiplier, int32_t reservePct) { + checkTypeSupportsValueIds(); + multiplier_ = multiplier; VELOX_CHECK_LE(0, reservePct); VELOX_CHECK(hasRange_); @@ -849,7 +894,7 @@ void VectorHasher::copyStatsFrom(const VectorHasher& other) { uniqueValues_ = other.uniqueValues_; } -void VectorHasher::merge(const VectorHasher& other) { +void VectorHasher::merge(const VectorHasher& other, size_t maxNumDistinct) { if (typeKind_ == TypeKind::BOOLEAN) { return; } @@ -867,18 +912,25 @@ void VectorHasher::merge(const VectorHasher& other) { } else { setRangeOverflow(); } - if (!distinctOverflow_ && !other.distinctOverflow_) { - // Unique values can be merged without dispatch on type. All the - // merged hashers must stay live for string type columns. - for (UniqueValue value : other.uniqueValues_) { - // Assign a new id at end of range for the case 'value' is not - // in 'uniqueValues_'. We do not set overflow here because the - // memory is already allocated and there is a known cap on size. - value.setId(uniqueValues_.size() + 1); - uniqueValues_.insert(value); - } - } else { + if (distinctOverflow_) { + return; + } + if (other.distinctOverflow_) { setDistinctOverflow(); + return; + } + // Unique values can be merged without dispatch on type. All the + // merged hashers must stay live for string type columns. + for (UniqueValue value : other.uniqueValues_) { + // Assign a new id at end of range for the case 'value' is not + // in 'uniqueValues_'. We do not set overflow here because the + // memory is already allocated and there is a known cap on size. + value.setId(uniqueValues_.size() + 1); + if (uniqueValues_.insert(value).second && + uniqueValues_.size() > maxNumDistinct) { + setDistinctOverflow(); + break; + } } } diff --git a/velox/exec/VectorHasher.h b/velox/exec/VectorHasher.h index ebd3534c116..f580fc6f865 100644 --- a/velox/exec/VectorHasher.h +++ b/velox/exec/VectorHasher.h @@ -61,6 +61,15 @@ class UniqueValue { return data_; } + std::string asString() const { + if (size_ <= sizeof(int64_t)) { + // String is stored inline in data_. + return std::string{reinterpret_cast(&data_), size_}; + } + // String is stored as a pointer in data_. + return std::string{reinterpret_cast(data_), size_}; + } + void setData(int64_t data) { data_ = data; } @@ -131,8 +140,15 @@ class VectorHasher { static constexpr int32_t kNoLimit = -1; VectorHasher(TypePtr type, column_index_t channel) - : channel_(channel), type_(std::move(type)), typeKind_(type_->kind()) { - if (typeKind_ == TypeKind::BOOLEAN) { + : channel_(channel), + type_(std::move(type)), + typeKind_(type_->kind()), + typeProvidesCustomComparison_(type_->providesCustomComparison()) { + if (!typeSupportsValueIds()) { + // Ensure any range or unique value based hashing is disabled. + setRangeOverflow(); + setDistinctOverflow(); + } else if (typeKind_ == TypeKind::BOOLEAN) { // We do not need samples to know the cardinality or limits of a bool // vector. hasRange_ = true; @@ -235,15 +251,39 @@ class VectorHasher { ScratchMemory& scratchMemory, raw_vector& result) const; - // Returns true if either range or distinct values have not overflowed. + // Returns true if either range or distinct values have not overflowed and the + // type doesn't support custom comparison. bool mayUseValueIds() const { - return hasRange_ || !distinctOverflow_; + return typeSupportsValueIds() && (hasRange_ || !distinctOverflow_); } // Returns an instance of the filter corresponding to a set of unique values. // Returns null if distinctOverflow_ is true. std::unique_ptr getFilter(bool nullAllowed) const; + bool supportsBloomFilter() const { + if (typeProvidesCustomComparison_) { + return false; + } + switch (typeKind_) { + // Smaller integers would never overflow 100'000 distinct values. + case TypeKind::INTEGER: + case TypeKind::BIGINT: + return distinctOverflow_; + default: + return false; + } + } + + void setBloomFilter(common::FilterPtr filter) { + VELOX_DCHECK(supportsBloomFilter()); + bloomFilter_ = std::move(filter); + } + + const common::FilterPtr& getBloomFilter() const { + return bloomFilter_; + } + void resetStats() { uniqueValues_.clear(); uniqueValuesStorage_.clear(); @@ -284,8 +324,12 @@ class VectorHasher { return isRange_; } - static bool typeKindSupportsValueIds(TypeKind kind) { - switch (kind) { + bool typeSupportsValueIds() const { + if (typeProvidesCustomComparison_) { + return false; + } + + switch (typeKind_) { case TypeKind::BOOLEAN: case TypeKind::TINYINT: case TypeKind::SMALLINT: @@ -302,7 +346,7 @@ class VectorHasher { // Merges the value ids information of 'other' into 'this'. Ranges // and distinct values are unioned. - void merge(const VectorHasher& other); + void merge(const VectorHasher& other, size_t maxNumDistinct); // true if no values have been added. bool empty() const { @@ -535,6 +579,13 @@ class VectorHasher { void setRangeOverflow(); + inline void checkTypeSupportsValueIds() const { + VELOX_DCHECK( + typeSupportsValueIds(), + "Value IDs cannot be used, the type {} is not supported.", + type_->toString()); + } + static inline bool isNullAt(const char* group, int32_t nullByte, uint8_t nullMask) { return (group[nullByte] & nullMask) != 0; @@ -551,6 +602,7 @@ class VectorHasher { const column_index_t channel_; const TypePtr type_; const TypeKind typeKind_; + const bool typeProvidesCustomComparison_; DecodedVector decoded_; raw_vector cachedHashes_; @@ -587,6 +639,8 @@ class VectorHasher { // Memory for unique string values. std::vector uniqueValuesStorage_; uint64_t distinctStringsBytes_ = 0; + + common::FilterPtr bloomFilter_; }; template <> @@ -601,14 +655,6 @@ bool VectorHasher::makeValueIdsForRows( template <> void VectorHasher::analyzeValue(StringView value); -template <> -inline bool VectorHasher::tryMapToRange( - const StringView* /*values*/, - const SelectivityVector& /*rows*/, - uint64_t* /*result*/) { - return false; -} - template <> inline uint64_t VectorHasher::valueId(StringView value) { auto size = value.size(); @@ -703,10 +749,12 @@ inline bool VectorHasher::tryMapToRange( } template <> -bool VectorHasher::tryMapToRange( +inline bool VectorHasher::tryMapToRange( const StringView* /*values*/, const SelectivityVector& /*rows*/, - uint64_t* /*result*/); + uint64_t* /*result*/) { + return false; +} template <> bool VectorHasher::makeValueIdsFlatNoNulls( diff --git a/velox/exec/Window.cpp b/velox/exec/Window.cpp index 321a469ddae..f9107522f0a 100644 --- a/velox/exec/Window.cpp +++ b/velox/exec/Window.cpp @@ -14,14 +14,27 @@ * limitations under the License. */ #include "velox/exec/Window.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/OperatorUtils.h" #include "velox/exec/PartitionStreamingWindowBuild.h" #include "velox/exec/RowsStreamingWindowBuild.h" #include "velox/exec/SortWindowBuild.h" +#include "velox/exec/SubPartitionedSortWindowBuild.h" #include "velox/exec/Task.h" namespace facebook::velox::exec { +namespace { +common::PrefixSortConfig makePrefixSortConfig( + const core::QueryConfig& queryConfig) { + return common::PrefixSortConfig{ + queryConfig.prefixSortNormalizedKeyMaxBytes(), + queryConfig.prefixSortMinRows(), + queryConfig.prefixSortMaxStringPrefixLength()}; +} + +} // namespace + Window::Window( int32_t operatorId, DriverCtx* driverCtx, @@ -31,9 +44,9 @@ Window::Window( windowNode->outputType(), operatorId, windowNode->id(), - "Window", + OperatorType::kWindow, windowNode->canSpill(driverCtx->queryConfig()) - ? driverCtx->makeSpillConfig(operatorId) + ? driverCtx->makeSpillConfig(operatorId, OperatorType::kWindow) : std::nullopt), numInputColumns_(windowNode->inputType()->size()), windowNode_(windowNode), @@ -44,7 +57,8 @@ Window::Window( if (spillConfig == nullptr && operatorCtx_->driverCtx()->queryConfig().windowSpillEnabled()) { auto lockedStats = stats_.wlock(); - lockedStats->runtimeStats.emplace(kSpillNotSupported, RuntimeMetric(1)); + lockedStats->runtimeStats.emplace( + std::string(kSpillNotSupported), RuntimeMetric(1)); } if (windowNode->inputsSorted()) { if (supportRowsStreaming()) { @@ -55,16 +69,28 @@ Window::Window( windowNode, pool(), spillConfig, &nonReclaimableSection_); } } else { - windowBuild_ = std::make_unique( - windowNode, - pool(), - common::PrefixSortConfig{ - driverCtx->queryConfig().prefixSortNormalizedKeyMaxBytes(), - driverCtx->queryConfig().prefixSortMinRows(), - driverCtx->queryConfig().prefixSortMaxStringPrefixLength()}, - spillConfig, - &nonReclaimableSection_, - spillStats_.get()); + if (auto numSubPartitions = + operatorCtx_->driverCtx()->queryConfig().windowNumSubPartitions(); + numSubPartitions > 1) { + windowBuild_ = std::make_unique( + windowNode, + numSubPartitions, + pool(), + makePrefixSortConfig(driverCtx->queryConfig()), + spillConfig, + &nonReclaimableSection_, + &stats_, + spillStats_.get()); + } else { + windowBuild_ = std::make_unique( + windowNode, + pool(), + makePrefixSortConfig(driverCtx->queryConfig()), + spillConfig, + &nonReclaimableSection_, + &stats_, + spillStats_.get()); + } } } @@ -159,10 +185,11 @@ Window::WindowFrame Window::createWindowFrame( return std::make_optional( FrameChannelArg{kConstantChannel, nullptr, value}); } else { - return std::make_optional(FrameChannelArg{ - frameChannel, - BaseVector::create(frame->type(), 0, pool()), - std::nullopt}); + return std::make_optional( + FrameChannelArg{ + frameChannel, + BaseVector::create(frame->type(), 0, pool()), + std::nullopt}); } }; @@ -196,14 +223,15 @@ void Window::createWindowFunctions() { } } - windowFunctions_.push_back(WindowFunction::create( - windowNodeFunction.functionCall->name(), - functionArgs, - windowNodeFunction.functionCall->type(), - windowNodeFunction.ignoreNulls, - operatorCtx_->pool(), - &stringAllocator_, - operatorCtx_->driverCtx()->queryConfig())); + windowFunctions_.push_back( + WindowFunction::create( + windowNodeFunction.functionCall->name(), + functionArgs, + windowNodeFunction.functionCall->type(), + windowNodeFunction.ignoreNulls, + operatorCtx_->pool(), + &stringAllocator_, + operatorCtx_->driverCtx()->queryConfig())); windowFrames_.push_back( createWindowFrame(windowNode_, windowNodeFunction.frame, inputType)); @@ -251,8 +279,11 @@ void Window::reclaim( // TODO Add support for spilling after noMoreInput(). LOG(WARNING) << "Can't reclaim from window operator which has started producing output: " - << pool()->name() << ", usage: " << succinctBytes(pool()->usedBytes()) - << ", reservation: " << succinctBytes(pool()->reservedBytes()); + << pool()->name() << ", root pool: " << pool()->root()->name() + << ", used: " << succinctBytes(pool()->usedBytes()) + << ", reservation: " << succinctBytes(pool()->reservedBytes()) + << ", root pool reservation: " + << succinctBytes(pool()->root()->reservedBytes()); return; } diff --git a/velox/exec/Window.h b/velox/exec/Window.h index 8cbd3b1ee85..2ddca5065bf 100644 --- a/velox/exec/Window.h +++ b/velox/exec/Window.h @@ -67,6 +67,11 @@ class Window : public Operator { void reclaim(uint64_t targetBytes, memory::MemoryReclaimer::Stats& stats) override; + /// Runtime statistics holding total number of batches read from spilled data. + /// 0 if no spilling occurred. + static constexpr std::string_view kWindowSpillReadNumBatches{ + "windowSpillReadNumBatches"}; + private: // Used for k preceding/following frames. Index is the column index if k is a // column. value is used to read column values from the column index when k diff --git a/velox/exec/WindowBuild.h b/velox/exec/WindowBuild.h index c67e090a02a..730e15ebb28 100644 --- a/velox/exec/WindowBuild.h +++ b/velox/exec/WindowBuild.h @@ -48,7 +48,7 @@ class WindowBuild { virtual void spill() = 0; /// Returns the spiller stats including total bytes and rows spilled so far. - virtual std::optional spilledStats() const = 0; + virtual std::optional spilledStats() const = 0; /// The Window operator invokes this function to indicate that no more input /// rows will be passed from the Window operator to the WindowBuild. When @@ -70,14 +70,16 @@ class WindowBuild { /// Returns the average size of input rows in bytes stored in the data /// container of the WindowBuild. - std::optional estimateRowSize() { + virtual std::optional estimateRowSize() { return data_->estimateRowSize(); } /// Releases the memory held by the window build. This is called by the /// window operator when all rows have been processed. void release() { - data_->clear(); + if (data_) { + data_->clear(); + } } void setNumRowsPerOutput(vector_size_t numRowsPerOutput) { diff --git a/velox/exec/WindowFunction.cpp b/velox/exec/WindowFunction.cpp index 30c3d4840b7..1825efd2ecd 100644 --- a/velox/exec/WindowFunction.cpp +++ b/velox/exec/WindowFunction.cpp @@ -68,6 +68,52 @@ std::optional> getWindowFunctionSignatures( return std::nullopt; } +TypePtr resolveWindowResultType( + const std::string& name, + const std::vector& argTypes) { + auto sanitizedName = sanitizeName(name); + + if (auto signatures = getWindowFunctionSignatures(sanitizedName)) { + for (const auto& signature : signatures.value()) { + SignatureBinder binder(*signature, argTypes, TypeCoercer::defaults()); + if (binder.tryBind()) { + return binder.tryResolveReturnType(); + } + } + + VELOX_USER_FAIL( + "Window function signature is not supported: {}. Supported signatures: {}.", + toString(sanitizedName, argTypes), + toString(signatures.value())); + } + + VELOX_USER_FAIL("Window function not registered: {}", name); +} + +TypePtr resolveWindowResultTypeWithCoercions( + const std::string& name, + const std::vector& argTypes, + std::vector& coercions, + const TypeCoercer& coercer) { + coercions.clear(); + + auto sanitizedName = sanitizeName(name); + + if (auto signatures = getWindowFunctionSignatures(sanitizedName)) { + if (auto type = tryResolveReturnTypeWithCoercions( + signatures.value(), argTypes, coercions, coercer)) { + return type; + } + + VELOX_USER_FAIL( + "Window function signature is not supported: {}. Supported signatures: {}.", + toString(sanitizedName, argTypes), + toString(signatures.value())); + } + + VELOX_USER_FAIL("Window function not registered: {}", name); +} + std::unique_ptr WindowFunction::create( const std::string& name, const std::vector& args, @@ -86,7 +132,7 @@ std::unique_ptr WindowFunction::create( const auto& signatures = func.value()->signatures; for (auto& signature : signatures) { - SignatureBinder binder(*signature, argTypes); + SignatureBinder binder(*signature, argTypes, TypeCoercer::defaults()); if (binder.tryBind()) { auto type = binder.tryResolveType(signature->returnType()); VELOX_USER_CHECK( diff --git a/velox/exec/WindowFunction.h b/velox/exec/WindowFunction.h index e9bea92ee2c..dd0febc41fc 100644 --- a/velox/exec/WindowFunction.h +++ b/velox/exec/WindowFunction.h @@ -18,6 +18,7 @@ #include "velox/core/QueryConfig.h" #include "velox/exec/WindowPartition.h" #include "velox/expression/FunctionSignature.h" +#include "velox/type/TypeCoercer.h" #include "velox/vector/BaseVector.h" namespace facebook::velox::exec { @@ -179,6 +180,25 @@ bool registerWindowFunction( std::optional> getWindowFunctionSignatures( const std::string& name); +/// Resolves the return type of a window function. +/// Throws if no matching signature is found. +TypePtr resolveWindowResultType( + const std::string& name, + const std::vector& argTypes); + +/// Like 'resolveWindowResultType', but with support for applying type +/// coercions if a function signature doesn't match 'argTypes' exactly. +/// +/// @param coercions A list of optional type coercions applied to resolve the +/// function. Contains one entry per argument. The entry is null if no coercion +/// is required for that argument. +/// @param coercer Coercion rule set to use when resolving type coercions. +TypePtr resolveWindowResultTypeWithCoercions( + const std::string& name, + const std::vector& argTypes, + std::vector& coercions, + const TypeCoercer& coercer); + struct WindowFunctionEntry { std::vector signatures; WindowFunctionFactory factory; diff --git a/velox/exec/WindowPartition.cpp b/velox/exec/WindowPartition.cpp index 90b209732f8..954437d0ed6 100644 --- a/velox/exec/WindowPartition.cpp +++ b/velox/exec/WindowPartition.cpp @@ -74,7 +74,7 @@ void WindowPartition::removeProcessedRows(vector_size_t numRows) { previousRow_ = rows_[numRows - 1]; } - rows_.erase(rows_.begin(), rows_.begin() + numRows); + rows_.erase(rows_.cbegin(), rows_.cbegin() + numRows); partition_ = folly::Range(rows_.data(), rows_.size()); startRow_ += numRows; } diff --git a/velox/exec/benchmarks/AtomicsBench.cpp b/velox/exec/benchmarks/AtomicsBench.cpp index 74343e9303f..c86991ef6fc 100644 --- a/velox/exec/benchmarks/AtomicsBench.cpp +++ b/velox/exec/benchmarks/AtomicsBench.cpp @@ -16,16 +16,39 @@ #include #include -#include #include #include #include +#include "velox/common/base/Portability.h" #include "velox/exec/OneWayStatusFlag.h" -using namespace ::testing; -using namespace facebook::velox; -static const size_t kNumThreads = 88; -static const size_t kNumIterations = 10000; +namespace { + +using facebook::velox::exec::OneWayStatusFlag; +constexpr size_t kNumThreads = 88; +constexpr size_t kNumIterations = 10000; + +#if defined(__x86_64__) && !defined(TSAN_BUILD) + +class OneWayStatusFlagUnsafe { + public: + bool check() const { + return fastStatus_ || atomicStatus_.load(); + } + + void set() { + if (!fastStatus_) { + atomicStatus_.store(true); + fastStatus_ = true; + } + } + + private: + bool fastStatus_{false}; + std::atomic_bool atomicStatus_{false}; +}; + +#endif void runParallelUpdates( std::function callback, @@ -46,12 +69,28 @@ void runParallelUpdates( } } -BENCHMARK(std_atomic_bool_write) { - std::atomic flag{false}; +BENCHMARK(std_atomic_bool_write_seq_cst) { + std::atomic_bool flag{false}; runParallelUpdates( [&](size_t iters) { for (size_t i = 0; i < iters; ++i) { flag.store(true); + bool dummy{}; + folly::doNotOptimizeAway(dummy); + } + }, + kNumThreads, // Threads + kNumIterations); // Iterations per thread +} + +BENCHMARK(std_atomic_bool_write_release) { + std::atomic_bool flag{false}; + runParallelUpdates( + [&](size_t iters) { + for (size_t i = 0; i < iters; ++i) { + flag.store(true, std::memory_order_release); + bool dummy{}; + folly::doNotOptimizeAway(dummy); } }, kNumThreads, // Threads @@ -59,76 +98,82 @@ BENCHMARK(std_atomic_bool_write) { } BENCHMARK(std_atomic_bool_write_relaxed) { - std::atomic flag{false}; + std::atomic_bool flag{false}; runParallelUpdates( [&](size_t iters) { for (size_t i = 0; i < iters; ++i) { flag.store(true, std::memory_order_relaxed); + bool dummy{}; + folly::doNotOptimizeAway(dummy); } }, kNumThreads, // Threads kNumIterations); // Iterations per thread } -BENCHMARK(std_atomic_bool_read_write_relaxed) { - std::atomic flag{false}; +BENCHMARK(one_way_flag_write) { + OneWayStatusFlag flag; runParallelUpdates( [&](size_t iters) { for (size_t i = 0; i < iters; ++i) { - if (!flag.load(std::memory_order_relaxed)) { - flag.store(true, std::memory_order_acq_rel); - } + flag.set(); + bool dummy{}; + folly::doNotOptimizeAway(dummy); } }, kNumThreads, // Threads kNumIterations); // Iterations per thread } -BENCHMARK(one_way_flag_write) { - exec::OneWayStatusFlag flag; +#if defined(__x86_64__) && !defined(TSAN_BUILD) + +BENCHMARK(one_way_flag_unsafe_write) { + OneWayStatusFlagUnsafe flag; runParallelUpdates( [&](size_t iters) { for (size_t i = 0; i < iters; ++i) { flag.set(); + bool dummy{}; + folly::doNotOptimizeAway(dummy); } }, kNumThreads, // Threads kNumIterations); // Iterations per thread } +#endif + // Read Benchmarks -BENCHMARK(std_atomic_bool_read) { - std::atomic flag{false}; +BENCHMARK(std_atomic_bool_read_seq_cst) { + std::atomic_bool flag{false}; runParallelUpdates( [&](size_t iters) { for (size_t i = 0; i < iters; ++i) { - folly::doNotOptimizeAway(flag.load()); + folly::doNotOptimizeAway(flag.load(std::memory_order_seq_cst)); } }, kNumThreads, // Threads kNumIterations); // Iterations per thread } -BENCHMARK(std_atomic_bool_relaxed_read) { - std::atomic flag{false}; +BENCHMARK(std_atomic_bool_read_acquire) { + std::atomic_bool flag{false}; runParallelUpdates( [&](size_t iters) { for (size_t i = 0; i < iters; ++i) { - folly::doNotOptimizeAway(flag.load(std::memory_order_relaxed)); + folly::doNotOptimizeAway(flag.load(std::memory_order_acquire)); } }, kNumThreads, // Threads kNumIterations); // Iterations per thread } -BENCHMARK(std_atomic_bool_read_relaxed_acquire) { - std::atomic flag{false}; +BENCHMARK(std_atomic_bool_read_relaxed) { + std::atomic_bool flag{false}; runParallelUpdates( [&](size_t iters) { for (size_t i = 0; i < iters; ++i) { - folly::doNotOptimizeAway( - flag.load(std::memory_order_relaxed) || - flag.load(std::memory_order_acquire)); + folly::doNotOptimizeAway(flag.load(std::memory_order_relaxed)); } }, kNumThreads, // Threads @@ -136,7 +181,21 @@ BENCHMARK(std_atomic_bool_read_relaxed_acquire) { } BENCHMARK(one_way_flag_read) { - exec::OneWayStatusFlag flag; + OneWayStatusFlag flag; + runParallelUpdates( + [&](size_t iters) { + for (size_t i = 0; i < iters; ++i) { + folly::doNotOptimizeAway(flag.check()); + } + }, + kNumThreads, // Threads + kNumIterations); // Iterations per thread +} + +#if defined(__x86_64__) && !defined(TSAN_BUILD) + +BENCHMARK(one_way_flag_unsafe_read) { + OneWayStatusFlagUnsafe flag; runParallelUpdates( [&](size_t iters) { for (size_t i = 0; i < iters; ++i) { @@ -147,6 +206,10 @@ BENCHMARK(one_way_flag_read) { kNumIterations); // Iterations per thread } +#endif + +} // namespace + int main(int argc, char** argv) { folly::Init init(&argc, &argv); folly::runBenchmarks(); diff --git a/velox/exec/benchmarks/CMakeLists.txt b/velox/exec/benchmarks/CMakeLists.txt index 5dfe0ff853e..1ccecdf49d5 100644 --- a/velox/exec/benchmarks/CMakeLists.txt +++ b/velox/exec/benchmarks/CMakeLists.txt @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + add_executable(velox_exec_vector_hasher_benchmark VectorHasherBenchmark.cpp) target_link_libraries( @@ -51,6 +52,10 @@ target_link_libraries( GTest::gtest_main ) +add_executable(velox_hash_rows_benchmark HashRowsBenchmark.cpp) + +target_link_libraries(velox_hash_rows_benchmark Folly::folly Folly::follybenchmark) + add_executable(velox_hash_benchmark HashTableBenchmark.cpp) target_link_libraries( @@ -72,6 +77,17 @@ target_link_libraries( Folly::follybenchmark ) +add_executable(velox_hash_join_build_benchmark HashJoinBuildBenchmark.cpp) + +target_link_libraries( + velox_hash_join_build_benchmark + velox_exec + velox_exec_test_lib + velox_vector_fuzzer + velox_vector_test_lib + Folly::follybenchmark +) + add_executable(velox_hash_join_prepare_join_table_benchmark HashJoinPrepareJoinTableBenchmark.cpp) target_link_libraries( @@ -98,6 +114,7 @@ if(${VELOX_ENABLE_PARQUET}) endif() add_library(velox_orderby_benchmark_util OrderByBenchmarkUtil.cpp) +velox_add_test_headers(velox_orderby_benchmark_util OrderByBenchmarkUtil.h) target_link_libraries(velox_orderby_benchmark_util velox_vector_fuzzer velox_vector_test_lib) @@ -136,6 +153,30 @@ target_link_libraries( Folly::follybenchmark ) +add_executable(velox_window_sub_partitioned_sort_benchmark WindowSubPartitionedSortBenchmark.cpp) + +target_link_libraries( + velox_window_sub_partitioned_sort_benchmark + velox_aggregates + velox_exec + velox_exec_test_lib + velox_hive_connector + velox_vector_fuzzer + velox_vector_test_lib + velox_window + Folly::follybenchmark +) + +add_executable(velox_mark_sorted_benchmark MarkSortedBenchmark.cpp) + +target_link_libraries( + velox_mark_sorted_benchmark + velox_exec + velox_exec_test_lib + velox_vector_test_lib + Folly::follybenchmark +) + add_executable(velox_streaming_aggregation_benchmark StreamingAggregationBenchmark.cpp) target_link_libraries( @@ -158,3 +199,25 @@ target_link_libraries( velox_vector_fuzzer Folly::follybenchmark ) + +add_executable(velox_atomics_benchmark AtomicsBench.cpp) + +target_link_libraries(velox_atomics_benchmark Folly::follybenchmark) + +if(VELOX_ENABLE_GEO) + add_executable(velox_spatial_join_benchmark SpatialJoinBenchmark.cpp) + + target_compile_definitions(velox_spatial_join_benchmark PRIVATE VELOX_ENABLE_GEO) + + target_link_libraries( + velox_spatial_join_benchmark + velox_memory + velox_exec + velox_exec_test_lib + velox_parse_parser + velox_presto_types + velox_vector_test_lib + velox_functions_prestosql + Folly::follybenchmark + ) +endif() diff --git a/velox/exec/benchmarks/DuplicateProjectBenchmark.cpp b/velox/exec/benchmarks/DuplicateProjectBenchmark.cpp index 4c08384d3c8..a2251b26d7c 100644 --- a/velox/exec/benchmarks/DuplicateProjectBenchmark.cpp +++ b/velox/exec/benchmarks/DuplicateProjectBenchmark.cpp @@ -32,6 +32,7 @@ using namespace facebook::velox::test; using namespace facebook::velox::exec::test; namespace { +using namespace facebook::velox::common::testutil; static constexpr int32_t kNumVectors = 50; static constexpr int32_t kRowsPerVector = 10'000; diff --git a/velox/exec/benchmarks/ExchangeBenchmark.cpp b/velox/exec/benchmarks/ExchangeBenchmark.cpp index 17dbd6ee02e..45689ccbf64 100644 --- a/velox/exec/benchmarks/ExchangeBenchmark.cpp +++ b/velox/exec/benchmarks/ExchangeBenchmark.cpp @@ -162,7 +162,7 @@ class ExchangeBenchmark : public VectorTestBase { std::vector finalAggTaskIds; core::PlanNodePtr finalAggPlan = exec::test::PlanBuilder() - .exchange(leafPlan->outputType(), VectorSerde::Kind::kPresto) + .exchange(leafPlan->outputType(), "Presto") .capturePlanNodeId(exchangeId) .singleAggregation({}, {"count(1)"}) .partitionedOutput({}, 1) @@ -184,11 +184,10 @@ class ExchangeBenchmark : public VectorTestBase { })}); // plan: Agg/kSingle(1) <-- Exchange (0) - plan = - exec::test::PlanBuilder() - .exchange(finalAggPlan->outputType(), VectorSerde::Kind::kPresto) - .singleAggregation({}, {"sum(a0)"}) - .planNode(); + plan = exec::test::PlanBuilder() + .exchange(finalAggPlan->outputType(), "Presto") + .singleAggregation({}, {"sum(a0)"}) + .planNode(); }; exec::test::AssertQueryBuilder(plan) @@ -600,7 +599,7 @@ int main(int argc, char** argv) { functions::prestosql::registerAllScalarFunctions(); aggregate::prestosql::registerAllAggregateFunctions(); parse::registerTypeResolver(); - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kPresto)) { + if (!isRegisteredNamedVectorSerde("Presto")) { serializer::presto::PrestoVectorSerde::registerNamedVectorSerde(); } exec::ExchangeSource::registerFactory(exec::test::createLocalExchangeSource); diff --git a/velox/exec/benchmarks/FilterProjectBenchmark.cpp b/velox/exec/benchmarks/FilterProjectBenchmark.cpp index e34be8a68e8..4c0c26ac03e 100644 --- a/velox/exec/benchmarks/FilterProjectBenchmark.cpp +++ b/velox/exec/benchmarks/FilterProjectBenchmark.cpp @@ -94,10 +94,11 @@ class FilterProjectBenchmark : public VectorTestBase { auto& type = data[0]->type()->as(); builder.values(data); for (auto level = 0; level < numStages; ++level) { - builder.filter(fmt::format( - "c0 >= {}", - static_cast( - 1000000 - pow(passPct / 100.0, 1 + level) * 1000000))); + builder.filter( + fmt::format( + "c0 >= {}", + static_cast( + 1000000 - pow(passPct / 100.0, 1 + level) * 1000000))); std::vector projections = {"c0"}; int32_t nthBigint = 0; int32_t nthVarchar = 0; diff --git a/velox/exec/benchmarks/HashJoinBuildBenchmark.cpp b/velox/exec/benchmarks/HashJoinBuildBenchmark.cpp new file mode 100644 index 00000000000..783b48a4362 --- /dev/null +++ b/velox/exec/benchmarks/HashJoinBuildBenchmark.cpp @@ -0,0 +1,399 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include "velox/exec/HashBuild.h" +#include "velox/exec/HashTable.h" +#include "velox/exec/OperatorType.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/tests/utils/VectorTestUtil.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +using namespace facebook::velox; +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; +using namespace facebook::velox::test; + +namespace { +struct BenchmarkParams { + BenchmarkParams() = default; + + // Benchmark params, we need to provide: + // -the expect hash mode, + // -the build row schema, + // -the duplicate factor, + // -number of building rows, + // -number of probing rows, + // -the abandon percentage, + // -the number of build vector batches. + BenchmarkParams( + BaseHashTable::HashMode mode, + const TypePtr& buildType, + double dupFactor, + int64_t buildSize, + int64_t probeSize, + int32_t abandonPct, + int32_t numBuildBatches) + : mode{mode}, + buildType{buildType}, + hashTableSize{static_cast(std::floor(buildSize / dupFactor))}, + buildSize{buildSize}, + probeSize{probeSize}, + numBuildBatches{numBuildBatches}, + dupFactor{dupFactor}, + abandonPct{abandonPct} { + VELOX_CHECK_LE(hashTableSize, buildSize); + VELOX_CHECK_GE(numBuildBatches, 1); + + if (hashTableSize > BaseHashTable::kArrayHashMaxSize && + mode == BaseHashTable::HashMode::kArray) { + VELOX_FAIL("Bad hash mode."); + } + + numFields = buildType->size(); + if (mode == BaseHashTable::HashMode::kNormalizedKey) { + extraValue = BaseHashTable::kArrayHashMaxSize + 100; + } else if (mode == BaseHashTable::HashMode::kHash) { + extraValue = std::numeric_limits::max() - 1; + } else { + extraValue = 0; + } + + title = fmt::format( + "dupFactor:{:<2},abandonPct:{},{}", + dupFactor, + abandonPct, + BaseHashTable::modeString(mode)); + } + + // Expected mode. + BaseHashTable::HashMode mode; + + // Type of build & probe row. + TypePtr buildType; + + // Distinct rows in the table. + int64_t hashTableSize; + + // Number of build rows. + int64_t buildSize; + + // Number of probe rows. + int64_t probeSize; + + // Number of build RowContainers. + int32_t numBuildBatches; + + // Title for reporting. + std::string title; + + // The duplicate factor, 2 means every row will repeat 2 times. + double dupFactor; + + // This parameter controls the hashing mode. It is incorporated into the keys + // on the build side. If the expected mode is an array, its value is 0. If + // the expected mode is a normalized key, its value is 'kArrayHashMaxSize' + + // 100 to make the key range > 'kArrayHashMaxSize'. If the expected mode is a + // hash, its value is the maximum value of int64_t minus 1 to make the key + // range == 'kRangeTooLarge'. + int64_t extraValue; + + // Number of fields. + int32_t numFields; + + int32_t abandonPct; + + std::string toString() const { + return fmt::format( + "DupFactor:{:<2}, AbandonPct:{}, HashMode:{:<14}", + dupFactor, + abandonPct, + BaseHashTable::modeString(mode)); + } +}; + +struct BenchmarkResult { + BenchmarkParams params; + + uint64_t totalClock{0}; + + uint64_t hashBuildPeakMemoryBytes{0}; + + bool isBuildNoDupHashTableAbandon{false}; + + // The mode of the table. + BaseHashTable::HashMode hashMode; + + std::string toString() const { + return fmt::format( + "{}, isAbandon:{:<5}, totalClock:{}ms, peakMemoryBytes:{}", + params.toString(), + isBuildNoDupHashTableAbandon, + totalClock / 1000'000, + succinctBytes(hashBuildPeakMemoryBytes)); + } +}; + +class HashJoinBuildBenchmark : public VectorTestBase { + public: + HashJoinBuildBenchmark() : randomEngine_((std::random_device{}())) {} + + BenchmarkResult run(BenchmarkParams params) { + params_ = std::move(params); + BenchmarkResult result; + result.params = params_; + result.hashMode = params_.mode; + + std::vector buildVectors; + makeBuildBatches(buildVectors); + + int64_t sequence = 0; + int64_t batchSize = params_.probeSize / 4; + std::vector probeVectors; + for (auto i = 0; i < 4; ++i) { + auto batch = makeProbeVector(batchSize, params_.hashTableSize, sequence); + probeVectors.emplace_back(batch); + } + + uint64_t totalClocks{0}; + { + ClockTimer timer(totalClocks); + auto plan = makeHashJoinPlan(buildVectors, probeVectors); + CursorParameters cursorParams; + cursorParams.planNode = std::move(plan); + cursorParams.queryCtx = core::QueryCtx::create( + executor_.get(), + core::QueryConfig{{}}, + {}, + cache::AsyncDataCache::getInstance(), + rootPool_); + cursorParams.queryCtx->testingOverrideConfigUnsafe({ + {core::QueryConfig::kAbandonDedupHashMapMinPct, + std::to_string(params_.abandonPct)}, + {core::QueryConfig::kAbandonDedupHashMapMinRows, "1000000"}, + }); + + cursorParams.maxDrivers = 1; + auto cursor = TaskCursor::create(cursorParams); + auto* task = cursor->task().get(); + while (cursor->moveNext()) { + } + waitForTaskCompletion(task); + result.isBuildNoDupHashTableAbandon = isBuildNoDupHashTableAbandon(task); + } + result.totalClock = totalClocks; + + result.hashBuildPeakMemoryBytes = getHashBuildPeakMemory(rootPool_.get()); + return result; + } + + private: + std::shared_ptr makeHashJoinPlan( + const std::vector& buildVectors, + const std::vector& probeVectors) { + auto planNodeIdGenerator = std::make_shared(); + return exec::test::PlanBuilder(planNodeIdGenerator, pool_.get()) + .values(probeVectors) + .project({"c0 AS t0", "c1 as t1", "c2 as t2"}) + .hashJoin( + {"t0"}, + {"u0"}, + exec::test::PlanBuilder(planNodeIdGenerator) + .values(buildVectors) + .project({"c0 AS u0"}) + .planNode(), + "", + {"t0", "t1", "match"}, + core::JoinType::kLeftSemiProject) + .planNode(); + } + + // Create the row vector for the build side, where the first column is used + // as the join key, and the remaining columns are dependent fields. + // If expect mode is array, the key is within the range [0, hashTableSize]; + // If expect mode is normalized key, the key is within the range + // [0, hashTableSize] + extraValue(kArrayHashMaxSize + 100); + // If expect mode is hash, the key is within the range [0, hashTableSize] + + // extraValue(max_int64 -1); + RowVectorPtr makeBuildRows( + std::vector& data, + int64_t start, + int64_t end, + bool addExtraValue) { + auto subData = + std::vector(data.begin() + start, data.begin() + end); + if (addExtraValue) { + subData[0] = params_.extraValue; + } + + std::vector children; + children.push_back(makeFlatVector(subData)); + return makeRowVector(children); + } + + // Generate the build side data batches. + void makeBuildBatches(std::vector& batches) { + int64_t buildKey = 0; + std::vector data; + for (auto i = 0; i < params_.buildSize; ++i) { + data.emplace_back((buildKey++) % params_.hashTableSize); + } + std::shuffle(data.begin(), data.end(), randomEngine_); + + auto size = params_.buildSize / params_.numBuildBatches; + for (auto i = 0; i < params_.numBuildBatches; ++i) { + batches.push_back(makeBuildRows( + data, + i * size, + (i + 1) * size + 1, + i == params_.numBuildBatches - 1)); + } + } + + // Create the row vector for the probe side, where the first column is used + // as the join key, and the remaining columns are dependent fields. + // Probe key is within the range [0, hashTableSize]. + RowVectorPtr + makeProbeVector(int64_t size, int64_t hashTableSize, int64_t& sequence) { + std::vector children; + for (int32_t i = 0; i < params_.numFields; ++i) { + children.push_back( + makeFlatVector( + size, + [&](vector_size_t row) { + return (sequence + row) % hashTableSize; + }, + nullptr)); + } + sequence += size; + + for (int32_t i = 0; i < 2; ++i) { + children.push_back( + makeFlatVector( + size, [&](vector_size_t row) { return row + size; }, nullptr)); + } + return makeRowVector(children); + } + + static int64_t getHashBuildPeakMemory(memory::MemoryPool* rootPool) { + int64_t hashBuildPeakBytes = 0; + std::vector pools; + pools.push_back(rootPool); + while (!pools.empty()) { + std::vector childPools; + for (auto pool : pools) { + pool->visitChildren([&](memory::MemoryPool* childPool) -> bool { + if (childPool->name().find("HashBuild") != std::string::npos) { + hashBuildPeakBytes += childPool->peakBytes(); + } + childPools.push_back(childPool); + return true; + }); + } + pools.swap(childPools); + } + if (hashBuildPeakBytes == 0) { + VELOX_FAIL("Failed to get HashBuild peak memory"); + } + return hashBuildPeakBytes; + } + + static bool isBuildNoDupHashTableAbandon(exec::Task* task) { + for (auto& pipelineStat : task->taskStats().pipelineStats) { + for (auto& operatorStat : pipelineStat.operatorStats) { + if (operatorStat.operatorType == OperatorType::kHashBuild) { + return operatorStat + .runtimeStats[std::string( + HashBuild::kAbandonBuildNoDupHash)] + .count != 0; + } + } + } + return false; + } + + std::default_random_engine randomEngine_; + BenchmarkParams params_; +}; + +} // namespace + +int main(int argc, char** argv) { + folly::Init init{&argc, &argv}; + memory::MemoryManager::Options options; + options.useMmapAllocator = true; + options.allocatorCapacity = 10UL << 30; + options.useMmapArena = true; + options.mmapArenaCapacityRatio = 1; + memory::MemoryManager::initialize(options); + + auto bm = std::make_unique(); + std::vector results; + + auto buildRowSize = (2L << 20) - 3; + auto probeRowSize = 100000000L; + + TypePtr twoKeyType{ROW({"k1"}, {BIGINT()})}; + + const std::vector hashModes = { + BaseHashTable::HashMode::kArray, + BaseHashTable::HashMode::kNormalizedKey, + BaseHashTable::HashMode::kHash, + }; + const std::vector dupFactorVector = { + 2, + 8, + 32, + }; + const std::vector abandonPcts = { + 90, + 80, + 70, + 50, + 0, + }; + + std::vector params; + for (auto mode : hashModes) { + for (auto dupFactor : dupFactorVector) { + for (auto pct : abandonPcts) { + params.push_back(BenchmarkParams( + mode, twoKeyType, dupFactor, buildRowSize, probeRowSize, pct, 512)); + } + } + } + + for (auto& param : params) { + BenchmarkResult result; + folly::addBenchmark(__FILE__, param.title, [param, &results, &bm]() { + results.emplace_back(bm->run(param)); + return 1; + }); + } + + folly::runBenchmarks(); + + for (auto& result : results) { + std::cout << result.toString() << std::endl; + } + return 0; +} diff --git a/velox/exec/benchmarks/HashJoinListResultBenchmark.cpp b/velox/exec/benchmarks/HashJoinListResultBenchmark.cpp index f9b6f13f179..cb0c7788c0f 100644 --- a/velox/exec/benchmarks/HashJoinListResultBenchmark.cpp +++ b/velox/exec/benchmarks/HashJoinListResultBenchmark.cpp @@ -280,10 +280,11 @@ class HashTableListJoinResultBenchmark : public VectorTestBase { std::vector children; children.push_back(makeFlatVector(data)); for (int32_t i = 0; i < params_.numDependentFields; ++i) { - children.push_back(makeFlatVector( - data.size(), - [&](vector_size_t row) { return row + maxKey; }, - nullptr)); + children.push_back( + makeFlatVector( + data.size(), + [&](vector_size_t row) { return row + maxKey; }, + nullptr)); } return makeRowVector(children); } @@ -311,14 +312,16 @@ class HashTableListJoinResultBenchmark : public VectorTestBase { RowVectorPtr makeProbeVector(int32_t size, int64_t hashTableSize, int64_t& sequence) { std::vector children; - children.push_back(makeFlatVector( - size, - [&](vector_size_t row) { return (sequence + row) % hashTableSize; }, - nullptr)); + children.push_back( + makeFlatVector( + size, + [&](vector_size_t row) { return (sequence + row) % hashTableSize; }, + nullptr)); sequence += size; for (int32_t i = 0; i < params_.numDependentFields; ++i) { - children.push_back(makeFlatVector( - size, [&](vector_size_t row) { return row + size; }, nullptr)); + children.push_back( + makeFlatVector( + size, [&](vector_size_t row) { return row + size; }, nullptr)); } return makeRowVector(children); } @@ -375,6 +378,7 @@ class HashTableListJoinResultBenchmark : public VectorTestBase { dependentTypes, true, false, + false, // hasCountFlag 1'000, tablePools[i].get()); @@ -391,6 +395,8 @@ class HashTableListJoinResultBenchmark : public VectorTestBase { topTable_->prepareJoinTable( std::move(otherTables), BaseHashTable::kNoSpillInputStartPartitionBit, + 1'000'000, + false, executor_.get()); } buildTime_ = buildClocks; diff --git a/velox/exec/benchmarks/HashJoinPrepareJoinTableBenchmark.cpp b/velox/exec/benchmarks/HashJoinPrepareJoinTableBenchmark.cpp index 6d5988577ad..ac9ffddd4da 100644 --- a/velox/exec/benchmarks/HashJoinPrepareJoinTableBenchmark.cpp +++ b/velox/exec/benchmarks/HashJoinPrepareJoinTableBenchmark.cpp @@ -129,6 +129,8 @@ class HashJoinPrepareJoinTableBenchmark : public VectorTestBase { topTable_->prepareJoinTable( std::move(otherTables_), BaseHashTable::kNoSpillInputStartPartitionBit, + 1'000'000, + false, executor_.get()); VELOX_CHECK_EQ(topTable_->hashMode(), params_.mode); } @@ -219,6 +221,7 @@ class HashJoinPrepareJoinTableBenchmark : public VectorTestBase { dependentTypes, true, false, + false, // hasCountFlag 1'000, pool_.get()); diff --git a/velox/exec/benchmarks/HashRowsBenchmark.cpp b/velox/exec/benchmarks/HashRowsBenchmark.cpp new file mode 100644 index 00000000000..289731003f1 --- /dev/null +++ b/velox/exec/benchmarks/HashRowsBenchmark.cpp @@ -0,0 +1,153 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include +#include + +#include "velox/exec/AdaptivePrefetch.h" + +namespace { + +using namespace facebook::velox::exec; +using NormalizedKeyT = uint64_t; + +struct FakeRowContainer { + explicit FakeRowContainer(int64_t numRows) { + constexpr int32_t kRowSize = 32; + constexpr int32_t kKeySize = sizeof(NormalizedKeyT); + data_.resize(static_cast(numRows) * kRowSize); + rows_.resize(static_cast(numRows)); + + std::mt19937_64 rng(42); + for (int64_t i = 0; i < numRows; ++i) { + char* base = data_.data() + i * kRowSize; + auto* key = reinterpret_cast(base); + *key = rng(); + rows_[i] = base + kKeySize; + } + std::shuffle(rows_.begin(), rows_.end(), std::mt19937(123)); + } + + char** rows() { + return rows_.data(); + } + int64_t numRows() const { + return static_cast(rows_.size()); + } + + private: + std::vector data_; + std::vector rows_; +}; + +inline NormalizedKeyT& normalizedKey(char* row) { + return reinterpret_cast(row)[-1]; +} + +constexpr int32_t kBatchSize = 1024; + +std::unique_ptr g4K; +std::unique_ptr g40K; +std::unique_ptr g400K; +std::unique_ptr g4M; + +void hashRowsNoPrefetch(FakeRowContainer& container, int32_t iters) { + std::vector hashes(kBatchSize); + auto** allRows = container.rows(); + const int64_t numRows = container.numRows(); + + for (int32_t iter = 0; iter < iters; ++iter) { + for (int64_t start = 0; start + kBatchSize <= numRows; + start += kBatchSize) { + char** rows = allRows + start; + for (int32_t i = 0; i < kBatchSize; ++i) { + hashes[i] = folly::hasher()(normalizedKey(rows[i])); + } + } + } + folly::doNotOptimizeAway(hashes.data()); +} + +void hashRowsWithPrefetch(FakeRowContainer& container, int32_t iters) { + std::vector hashes(kBatchSize); + auto** allRows = container.rows(); + const int64_t numRows = container.numRows(); + + for (int32_t iter = 0; iter < iters; ++iter) { + for (int64_t start = 0; start + kBatchSize <= numRows; + start += kBatchSize) { + char** rows = allRows + start; + AdaptivePrefetch prefetch(kBatchSize); + for (int32_t i = 0; i < kBatchSize; ++i) { + if (auto ahead = prefetch.lookAhead()) { + __builtin_prefetch(rows[i + ahead] - sizeof(NormalizedKeyT)); + } + hashes[i] = folly::hasher()(normalizedKey(rows[i])); + } + } + } + folly::doNotOptimizeAway(hashes.data()); +} + +BENCHMARK(noPrefetch_4K) { + hashRowsNoPrefetch(*g4K, 5); +} +BENCHMARK_RELATIVE(withPrefetch_4K) { + hashRowsWithPrefetch(*g4K, 5); +} + +BENCHMARK_DRAW_LINE(); + +BENCHMARK(noPrefetch_40K) { + hashRowsNoPrefetch(*g40K, 5); +} +BENCHMARK_RELATIVE(withPrefetch_40K) { + hashRowsWithPrefetch(*g40K, 5); +} + +BENCHMARK_DRAW_LINE(); + +BENCHMARK(noPrefetch_400K) { + hashRowsNoPrefetch(*g400K, 5); +} +BENCHMARK_RELATIVE(withPrefetch_400K) { + hashRowsWithPrefetch(*g400K, 5); +} + +BENCHMARK_DRAW_LINE(); + +BENCHMARK(noPrefetch_4M) { + hashRowsNoPrefetch(*g4M, 5); +} +BENCHMARK_RELATIVE(withPrefetch_4M) { + hashRowsWithPrefetch(*g4M, 5); +} + +} // namespace + +int main(int argc, char** argv) { + folly::Init init(&argc, &argv); + g4K = std::make_unique(4'000); + g40K = std::make_unique(40'000); + g400K = std::make_unique(400'000); + g4M = std::make_unique(4'000'000); + folly::runBenchmarks(); + return 0; +} diff --git a/velox/exec/benchmarks/HashTableBenchmark.cpp b/velox/exec/benchmarks/HashTableBenchmark.cpp index 7bc4ec69a4d..4d6c9b918c7 100644 --- a/velox/exec/benchmarks/HashTableBenchmark.cpp +++ b/velox/exec/benchmarks/HashTableBenchmark.cpp @@ -173,14 +173,16 @@ class HashTableBenchmark : public VectorTestBase { std::vector batches; std::vector> keyHashers; for (auto channel = 0; channel < params_.numKeys; ++channel) { - keyHashers.emplace_back(std::make_unique( - params_.buildType->childAt(channel), channel)); + keyHashers.emplace_back( + std::make_unique( + params_.buildType->childAt(channel), channel)); } auto table = HashTable::createForJoin( std::move(keyHashers), dependentTypes, true, false, + false, // hasCountFlag 1'000, pool_.get()); @@ -198,6 +200,8 @@ class HashTableBenchmark : public VectorTestBase { topTable_->prepareJoinTable( std::move(otherTables), BaseHashTable::kNoSpillInputStartPartitionBit, + 1'000'000, + false, executor_.get()); LOG(INFO) << "Made table " << topTable_->toString(); @@ -409,8 +413,9 @@ class HashTableBenchmark : public VectorTestBase { TypePtr buildType, std::vector& batches) { for (auto i = 0; i < numBatches; ++i) { - batches.push_back(std::static_pointer_cast( - makeVector(buildType, batchSize, sequence))); + batches.push_back( + std::static_pointer_cast( + makeVector(buildType, batchSize, sequence))); sequence += batchSize; } } diff --git a/velox/exec/benchmarks/MarkSortedBenchmark.cpp b/velox/exec/benchmarks/MarkSortedBenchmark.cpp new file mode 100644 index 00000000000..87f84eee3f3 --- /dev/null +++ b/velox/exec/benchmarks/MarkSortedBenchmark.cpp @@ -0,0 +1,170 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/vector/tests/utils/VectorMaker.h" + +using namespace facebook::velox; +using namespace facebook::velox::exec::test; + +namespace { + +constexpr vector_size_t kBatchSize = 10'000; +constexpr int kIterations = 100; + +class BenchmarkHelper { + public: + memory::MemoryPool* pool() { + return pool_.get(); + } + + test::VectorMaker& vectorMaker() { + return vectorMaker_; + } + + void runMarkSorted( + const std::vector& input, + const std::vector& sortingKeys, + const std::vector& sortingOrders) { + auto plan = PlanBuilder() + .values(input) + .markSorted("is_sorted", sortingKeys, sortingOrders) + .planNode(); + auto result = AssertQueryBuilder(plan).copyResults(pool()); + folly::doNotOptimizeAway(result); + } + + private: + std::shared_ptr pool_{ + memory::memoryManager()->addLeafPool()}; + test::VectorMaker vectorMaker_{pool_.get()}; +}; + +// --- Sorted INTEGER (SIMD path) --- +BENCHMARK(sortedInteger) { + folly::BenchmarkSuspender suspender; + BenchmarkHelper helper; + auto data = helper.vectorMaker().rowVector({ + helper.vectorMaker().flatVector( + kBatchSize, [](vector_size_t i) { return i; }), + }); + suspender.dismiss(); + + for (int i = 0; i < kIterations; ++i) { + helper.runMarkSorted({data}, {"c0"}, {core::kAscNullsLast}); + } +} + +// --- Sorted BIGINT (SIMD path) --- +BENCHMARK_RELATIVE(sortedBigint) { + folly::BenchmarkSuspender suspender; + BenchmarkHelper helper; + auto data = helper.vectorMaker().rowVector({ + helper.vectorMaker().flatVector( + kBatchSize, [](vector_size_t i) { return static_cast(i); }), + }); + suspender.dismiss(); + + for (int i = 0; i < kIterations; ++i) { + helper.runMarkSorted({data}, {"c0"}, {core::kAscNullsLast}); + } +} + +// --- Sorted VARCHAR (generic path, no SIMD) --- +// U8 fix: strings stored in vector with lifetime spanning the benchmark. +BENCHMARK_RELATIVE(sortedVarchar) { + folly::BenchmarkSuspender suspender; + BenchmarkHelper helper; + std::vector storage(kBatchSize); + for (vector_size_t i = 0; i < kBatchSize; ++i) { + storage[i] = fmt::format("str_{:08d}", i); + } + auto data = helper.vectorMaker().rowVector({ + helper.vectorMaker().flatVector( + kBatchSize, + [&storage](vector_size_t i) { return StringView(storage[i]); }), + }); + suspender.dismiss(); + + for (int i = 0; i < kIterations; ++i) { + helper.runMarkSorted({data}, {"c0"}, {core::kAscNullsLast}); + } +} + +// --- Unsorted INTEGER (SIMD path, many false bits) --- +BENCHMARK(unsortedInteger) { + folly::BenchmarkSuspender suspender; + BenchmarkHelper helper; + auto data = helper.vectorMaker().rowVector({ + helper.vectorMaker().flatVector( + kBatchSize, + [](vector_size_t i) { + // Alternating pattern creates many unsorted pairs. + return (i % 2 == 0) ? i : kBatchSize - i; + }), + }); + suspender.dismiss(); + + for (int i = 0; i < kIterations; ++i) { + helper.runMarkSorted({data}, {"c0"}, {core::kAscNullsLast}); + } +} + +// --- ConstantVector (O(1) fast path) --- +BENCHMARK_RELATIVE(constantVector) { + folly::BenchmarkSuspender suspender; + BenchmarkHelper helper; + auto data = helper.vectorMaker().rowVector({ + BaseVector::createConstant(INTEGER(), 42, kBatchSize, helper.pool()), + }); + suspender.dismiss(); + + for (int i = 0; i < kIterations; ++i) { + helper.runMarkSorted({data}, {"c0"}, {core::kAscNullsLast}); + } +} + +// --- Cross-batch comparison (measures overhead) --- +BENCHMARK(crossBatch) { + folly::BenchmarkSuspender suspender; + BenchmarkHelper helper; + constexpr vector_size_t kSmallBatch = 100; + constexpr int kNumBatches = 100; + std::vector batches; + batches.reserve(kNumBatches); + for (int b = 0; b < kNumBatches; ++b) { + batches.push_back(helper.vectorMaker().rowVector({ + helper.vectorMaker().flatVector( + kSmallBatch, [b](vector_size_t i) { return b * kSmallBatch + i; }), + })); + } + suspender.dismiss(); + + for (int i = 0; i < kIterations; ++i) { + helper.runMarkSorted(batches, {"c0"}, {core::kAscNullsLast}); + } +} + +} // namespace + +int main(int argc, char** argv) { + folly::Init init{&argc, &argv}; + memory::MemoryManager::initialize(memory::MemoryManager::Options{}); + folly::runBenchmarks(); + return 0; +} diff --git a/velox/exec/benchmarks/MergeBenchmark.cpp b/velox/exec/benchmarks/MergeBenchmark.cpp index f5fbaee0ff6..118007abe0d 100644 --- a/velox/exec/benchmarks/MergeBenchmark.cpp +++ b/velox/exec/benchmarks/MergeBenchmark.cpp @@ -19,7 +19,7 @@ #include -#include "velox/exec/TreeOfLosers.h" +#include "velox/common/base/TreeOfLosers.h" #include "velox/exec/tests/utils/MergeTestBase.h" using namespace facebook::velox; diff --git a/velox/exec/benchmarks/OrderByBenchmark.cpp b/velox/exec/benchmarks/OrderByBenchmark.cpp index 326bbf0e53d..be01eed258f 100644 --- a/velox/exec/benchmarks/OrderByBenchmark.cpp +++ b/velox/exec/benchmarks/OrderByBenchmark.cpp @@ -56,7 +56,7 @@ class OrderByBenchmark { const auto start = getCurrentTimeMicro(); for (auto i = 0; i < iterations; ++i) { std::shared_ptr task; - test::AssertQueryBuilder(plan).runWithoutResults(task); + test::AssertQueryBuilder(plan).countResults(task); auto taskStats = exec::toPlanStats(task->taskStats()); auto& stats = taskStats.at(orderByNodeId); inputNs += stats.addInputTiming.wallNanos; @@ -79,8 +79,9 @@ class OrderByBenchmark { core::PlanNodeId& orderByNodeId) { folly::BenchmarkSuspender suspender; std::vector vectors; - vectors.emplace_back(OrderByBenchmarkUtil::fuzzRows( - test.rowType, test.numRows, pool_.get())); + vectors.emplace_back( + OrderByBenchmarkUtil::fuzzRows( + test.rowType, test.numRows, pool_.get())); std::vector keys; keys.reserve(test.numKeys); diff --git a/velox/exec/benchmarks/OrderByBenchmarkUtil.h b/velox/exec/benchmarks/OrderByBenchmarkUtil.h index 3605edc936f..8b3855daabc 100644 --- a/velox/exec/benchmarks/OrderByBenchmarkUtil.h +++ b/velox/exec/benchmarks/OrderByBenchmarkUtil.h @@ -23,12 +23,13 @@ class OrderByBenchmarkUtil { public: /// Add the benchmarks with the parameter. /// @param benchmarkFunc benchmark generator. - static void addBenchmarks(const std::function& benchmarkFunc); + static void addBenchmarks( + const std::function& benchmarkFunc); /// Generate RowVector by VectorFuzzer according to rowType. Use /// FLAGS_data_null_ratio to specify the columns null ratio diff --git a/velox/exec/benchmarks/RowContainerSortBenchmark.cpp b/velox/exec/benchmarks/RowContainerSortBenchmark.cpp index c1f5555d37f..3dcffe398ff 100644 --- a/velox/exec/benchmarks/RowContainerSortBenchmark.cpp +++ b/velox/exec/benchmarks/RowContainerSortBenchmark.cpp @@ -17,6 +17,7 @@ #include #include +#include "velox/common/io/IoStatistics.h" #include "velox/dwio/common/tests/utils/DataFiles.h" #include "velox/dwio/parquet/reader/ParquetReader.h" #include "velox/exec/RowContainer.h" @@ -64,7 +65,11 @@ std::vector> getDataFromFile() { const std::string sample(getExampleFilePath("str_sort.parquet")); auto rowType = ROW({"query_sig", "result_sig"}, {VARCHAR(), VARCHAR()}); auto pool = memory::memoryManager()->addLeafPool(); - facebook::velox::dwio::common::ReaderOptions readerOptions{pool.get()}; + auto dataIoStats = std::make_shared(); + auto metadataIoStats = std::make_shared(); + facebook::velox::dwio::common::ReaderOptions readerOptions(pool.get()); + readerOptions.setDataIoStats(dataIoStats); + readerOptions.setMetadataIoStats(metadataIoStats); facebook::velox::parquet::ParquetReader reader = createReader(sample, readerOptions); auto rowReaderOpts = getReaderOpts(rowType); diff --git a/velox/exec/benchmarks/SpatialJoinBenchmark.cpp b/velox/exec/benchmarks/SpatialJoinBenchmark.cpp new file mode 100644 index 00000000000..066cceb6379 --- /dev/null +++ b/velox/exec/benchmarks/SpatialJoinBenchmark.cpp @@ -0,0 +1,360 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "velox/common/memory/Memory.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/parse/TypeResolver.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +/// Benchmark for SpatialJoin operator, which implements a nested-loop join +/// with spatial predicates (e.g., ST_INTERSECTS, ST_CONTAINS, ST_WITHIN). +/// +/// This benchmark measures the performance of spatial joins under different +/// conditions: +/// - Different build and probe side sizes (cross join cardinality) +/// - Different spatial predicates +/// - Different data distributions (dense vs sparse geometries) +/// - Inner vs Left join types +/// +/// The benchmark creates synthetic geometric data and measures the throughput +/// of spatial join operations. The focus is on understanding how the nested +/// loop pattern performs with varying data sizes and selectivity. + +using namespace facebook::velox; +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; + +namespace { + +/// Spatial distribution patterns for geometry generation. +enum class Distribution { + kUniform, // Geometries uniformly distributed in space + kClustered // Geometries clustered in specific regions +}; + +// Constants for geometry generation. +constexpr int32_t kNullPatternModulo = 13; +constexpr int32_t kRandomCoordinateMax = 10000; +constexpr double kCoordinateScaleDivisor = 10.0; +constexpr int32_t kNumClusters = 5; +constexpr double kClusterSpacing = 200.0; +constexpr double kClusterCenterOffset = 100.0; +constexpr int32_t kClusterSpreadRange = 100; +constexpr int32_t kClusterSpreadHalf = 50; +constexpr double kPolygonSize = 10.0; + +// Constants for benchmark configuration. +constexpr int32_t kDefaultBatchSize = 10000; +constexpr int32_t kSmallBenchmarkSize = 1000; +constexpr int32_t kMediumProbeBenchmarkSize = 50000; +constexpr int32_t kMediumBuildBenchmarkSize = 5000; +constexpr int32_t kLargeProbeBenchmarkSize = 200000; +constexpr int32_t kLargeBuildBenchmarkSize = 50000; + +/// Parameters for a spatial join benchmark test case. +struct SpatialJoinBenchmarkParams { + /// Number of rows on the probe (left) side. + int32_t probeSize; + + /// Number of rows on the build (right) side. + int32_t buildSize; + + /// Spatial predicate to use (e.g., "ST_Intersects", "ST_Contains"). + std::string predicate; + + /// Join type (kInner or kLeft). + core::JoinType joinType; + + /// Spatial distribution pattern for geometry generation. + Distribution distribution; + + /// Description for benchmark naming. + std::string toString() const { + std::string joinTypeStr = + (joinType == core::JoinType::kInner) ? "Inner" : "Left"; + std::string distributionStr = + (distribution == Distribution::kUniform) ? "uniform" : "clustered"; + return fmt::format( + "{}x{}_{}_{}_{}", + probeSize, + buildSize, + predicate, + joinTypeStr, + distributionStr); + } +}; + +class SpatialJoinBenchmark : public facebook::velox::test::VectorTestBase { + public: + SpatialJoinBenchmark() : rng_((std::random_device{}())) {} + + /// Creates a vector of POINT geometries with specified distribution. + VectorPtr + makePointVector(int32_t size, Distribution distribution, bool nulls = false) { + return makeFlatVector( + size, + [&](vector_size_t row) { + if (nulls && (row % kNullPatternModulo == 0)) { + return std::string(""); + } + double x, y; + if (distribution == Distribution::kUniform) { + x = (folly::Random::rand32(rng_) % kRandomCoordinateMax) / + kCoordinateScaleDivisor; + y = (folly::Random::rand32(rng_) % kRandomCoordinateMax) / + kCoordinateScaleDivisor; + } else { + int cluster = row % kNumClusters; + double centerX = (cluster * kClusterSpacing) + kClusterCenterOffset; + double centerY = (cluster * kClusterSpacing) + kClusterCenterOffset; + x = centerX + + ((folly::Random::rand32(rng_) % kClusterSpreadRange) - + kClusterSpreadHalf); + y = centerY + + ((folly::Random::rand32(rng_) % kClusterSpreadRange) - + kClusterSpreadHalf); + } + return fmt::format("POINT ({} {})", x, y); + }, + [&](vector_size_t row) { + return nulls && (row % kNullPatternModulo == 0); + }); + } + + /// Creates a vector of POLYGON geometries with specified distribution. + VectorPtr makePolygonVector( + int32_t size, + Distribution distribution, + bool nulls = false) { + return makeFlatVector( + size, + [&](vector_size_t row) { + if (nulls && (row % kNullPatternModulo == 0)) { + return std::string(""); + } + double centerX, centerY; + if (distribution == Distribution::kUniform) { + centerX = (folly::Random::rand32(rng_) % kRandomCoordinateMax) / + kCoordinateScaleDivisor; + centerY = (folly::Random::rand32(rng_) % kRandomCoordinateMax) / + kCoordinateScaleDivisor; + } else { + int cluster = row % kNumClusters; + centerX = (cluster * kClusterSpacing) + kClusterCenterOffset; + centerY = (cluster * kClusterSpacing) + kClusterCenterOffset; + } + return fmt::format( + "POLYGON (({} {}, {} {}, {} {}, {} {}, {} {}))", + centerX - kPolygonSize, + centerY - kPolygonSize, + centerX + kPolygonSize, + centerY - kPolygonSize, + centerX + kPolygonSize, + centerY + kPolygonSize, + centerX - kPolygonSize, + centerY + kPolygonSize, + centerX - kPolygonSize, + centerY - kPolygonSize); + }, + [&](vector_size_t row) { + return nulls && (row % kNullPatternModulo == 0); + }); + } + + RowVectorPtr createProjectionVector( + const std::string& prefix, + RowVectorPtr input) { + const auto plan = PlanBuilder(std::make_shared()) + .values({input}) + .project( + {fmt::format("{}_id", prefix), + fmt::format( + "ST_GeometryFromText({}_geom) AS {}_geom", + prefix, + prefix)}) + .planNode(); + return AssertQueryBuilder(plan).copyResults(pool_.get()); + } + + /// Creates test data for the specified parameters. + std::pair, std::vector> makeTestData( + const SpatialJoinBenchmarkParams& params) { + // Create probe side data (points) + std::vector probeVectors; + const int32_t batchSize = std::min(params.probeSize, kDefaultBatchSize); + const int32_t numBatches = (params.probeSize + batchSize - 1) / batchSize; + + for (int32_t i = 0; i < numBatches; ++i) { + int32_t currentBatchSize = + std::min(batchSize, params.probeSize - (i * batchSize)); + auto geomVector = + makePointVector(currentBatchSize, params.distribution, false); + auto idVector = makeFlatVector( + currentBatchSize, + [i, batchSize](vector_size_t row) { return (i * batchSize) + row; }); + probeVectors.push_back(createProjectionVector( + "probe", + makeRowVector({"probe_id", "probe_geom"}, {idVector, geomVector}))); + } + + // Create build side data (polygons) + std::vector buildVectors; + const int32_t buildBatchSize = + std::min(params.buildSize, kDefaultBatchSize); + const int32_t numBuildBatches = + (params.buildSize + buildBatchSize - 1) / buildBatchSize; + + for (int32_t i = 0; i < numBuildBatches; ++i) { + int32_t currentBatchSize = + std::min(buildBatchSize, params.buildSize - (i * buildBatchSize)); + auto geomVector = + makePolygonVector(currentBatchSize, params.distribution, false); + auto idVector = makeFlatVector( + currentBatchSize, [i, buildBatchSize](vector_size_t row) { + return (i * buildBatchSize) + row; + }); + buildVectors.push_back(createProjectionVector( + "build", + makeRowVector({"build_id", "build_geom"}, {idVector, geomVector}))); + } + + return {probeVectors, buildVectors}; + } + + /// Creates a spatial join plan with the specified parameters. + std::shared_ptr makeSpatialJoinPlan( + std::vector&& probeVectors, + std::vector&& buildVectors, + const SpatialJoinBenchmarkParams& params) { + const auto planNodeIdGenerator = + std::make_shared(); + return PlanBuilder(planNodeIdGenerator) + .values(probeVectors) + .spatialJoin( + PlanBuilder(planNodeIdGenerator).values(buildVectors).planNode(), + fmt::format("{}(probe_geom, build_geom)", params.predicate), + "probe_geom", + "build_geom", + std::nullopt, + {"probe_id", "probe_geom", "build_id", "build_geom"}, + params.joinType) + .planNode(); + } + + /// Runs a single benchmark iteration. + uint64_t run( + std::shared_ptr plan, + const SpatialJoinBenchmarkParams& params) { + auto result = AssertQueryBuilder(plan).copyResults(pool_.get()); + return result->size(); + } + + /// Adds a benchmark for the given parameters. + void addBenchmark(const SpatialJoinBenchmarkParams& params) { + auto name = params.toString(); + folly::addBenchmark(__FILE__, name, [this, params]() { + std::shared_ptr plan; + BENCHMARK_SUSPEND { + auto [probeVectors, buildVectors] = makeTestData(params); + plan = makeSpatialJoinPlan( + std::move(probeVectors), std::move(buildVectors), params); + } + + run(plan, params); + return 1; + }); + } + + private: + std::default_random_engine rng_; +}; + +} // namespace + +int main(int argc, char** argv) { + folly::Init init{&argc, &argv}; + memory::initializeMemoryManager(memory::MemoryManager::Options{}); + parse::registerTypeResolver(); + functions::prestosql::registerAllScalarFunctions(); + + SpatialJoinBenchmark bm; + + // Small scale benchmarks (1K x 1K) + bm.addBenchmark( + {kSmallBenchmarkSize, + kSmallBenchmarkSize, + "ST_Intersects", + core::JoinType::kInner, + Distribution::kUniform}); + bm.addBenchmark( + {kSmallBenchmarkSize, + kSmallBenchmarkSize, + "ST_Intersects", + core::JoinType::kInner, + Distribution::kClustered}); + + // Medium scale benchmarks (50K x 5K) + bm.addBenchmark( + {kMediumProbeBenchmarkSize, + kMediumBuildBenchmarkSize, + "ST_Intersects", + core::JoinType::kInner, + Distribution::kUniform}); + bm.addBenchmark( + {kMediumProbeBenchmarkSize, + kMediumBuildBenchmarkSize, + "ST_Intersects", + core::JoinType::kInner, + Distribution::kClustered}); + + // Left join benchmarks (50K x 5K) + bm.addBenchmark( + {kMediumProbeBenchmarkSize / 2, + kMediumBuildBenchmarkSize, + "ST_Intersects", + core::JoinType::kLeft, + Distribution::kUniform}); + bm.addBenchmark( + {kMediumProbeBenchmarkSize / 2, + kMediumBuildBenchmarkSize, + "ST_Intersects", + core::JoinType::kLeft, + Distribution::kClustered}); + + // Contains predicate benchmarks (50K x 5K) + bm.addBenchmark( + {kMediumProbeBenchmarkSize / 2, + kMediumBuildBenchmarkSize, + "ST_Contains", + core::JoinType::kInner, + Distribution::kUniform}); + + // Large scale benchmark (200K x 50K) + bm.addBenchmark( + {kLargeProbeBenchmarkSize, + kLargeBuildBenchmarkSize, + "ST_Intersects", + core::JoinType::kInner, + Distribution::kUniform}); + + folly::runBenchmarks(); + return 0; +} diff --git a/velox/exec/benchmarks/StreamingAggregationBenchmark.cpp b/velox/exec/benchmarks/StreamingAggregationBenchmark.cpp index dc1f16ae6a1..9393d7cef7d 100644 --- a/velox/exec/benchmarks/StreamingAggregationBenchmark.cpp +++ b/velox/exec/benchmarks/StreamingAggregationBenchmark.cpp @@ -84,10 +84,9 @@ class StreamingAggregationBenchmark : public VectorTestBase { std::to_string(params.numGroups)); folly::addBenchmark(__FILE__, name, [plan = &test->plan]() { - std::shared_ptr task; exec::test::AssertQueryBuilder(*plan) .serialExecution(true) - .runWithoutResults(task); + .countResults(); return 1; }); diff --git a/velox/exec/benchmarks/WindowPrefixSortBenchmark.cpp b/velox/exec/benchmarks/WindowPrefixSortBenchmark.cpp index 27025f72351..bb012adc453 100644 --- a/velox/exec/benchmarks/WindowPrefixSortBenchmark.cpp +++ b/velox/exec/benchmarks/WindowPrefixSortBenchmark.cpp @@ -21,6 +21,7 @@ #include "velox/common/memory/SharedArbitrator.h" #include "velox/exec/Cursor.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" @@ -38,6 +39,7 @@ static constexpr int32_t kNumVectors = 50; static constexpr int32_t kRowsPerVector = 1'0000; namespace { +using namespace facebook::velox::common::testutil; class WindowPrefixSortBenchmark : public HiveConnectorTestBase { public: @@ -72,8 +74,9 @@ class WindowPrefixSortBenchmark : public HiveConnectorTestBase { // Generate key with a small number of unique values from a small range // (0-16). - children.emplace_back(makeFlatVector( - kRowsPerVector, [](auto row) { return row % 17; })); + children.emplace_back( + makeFlatVector( + kRowsPerVector, [](auto row) { return row % 17; })); // Generate key with a small number of unique values from a large range // (300 total values). @@ -94,8 +97,9 @@ class WindowPrefixSortBenchmark : public HiveConnectorTestBase { // Generate a column with increasing values to get a deterministic sort // order. - children.emplace_back(makeFlatVector( - kRowsPerVector, [](auto row) { return row; })); + children.emplace_back( + makeFlatVector( + kRowsPerVector, [](auto row) { return row; })); // Generate random values without nulls. children.emplace_back(fuzzer.fuzzFlat(INTEGER())); @@ -168,11 +172,11 @@ class WindowPrefixSortBenchmark : public HiveConnectorTestBase { auto stats = task->taskStats(); for (auto& pipeline : stats.pipelineStats) { for (auto& op : pipeline.operatorStats) { - if (op.operatorType == "Window") { + if (op.operatorType == OperatorType::kWindow) { windowNanos_.add(op.addInputTiming); windowNanos_.add(op.getOutputTiming); } - if (op.operatorType == "Values") { + if (op.operatorType == OperatorType::kValues) { // This is the timing for Window::noMoreInput() where the window // sorting happens. So including in the cpu timing. windowNanos_.add(op.finishTiming); @@ -192,7 +196,8 @@ class WindowPrefixSortBenchmark : public HiveConnectorTestBase { std::move(plan), 0, core::QueryCtx::create(executor_.get()), - Task::ExecutionMode::kSerial); + Task::ExecutionMode::kSerial, + exec::Consumer{}); } else { const std::unordered_map queryConfigMap( {{core::QueryConfig::kPrefixSortNormalizedKeyMaxBytes, "0"}}); @@ -202,7 +207,8 @@ class WindowPrefixSortBenchmark : public HiveConnectorTestBase { 0, core::QueryCtx::create( executor_.get(), core::QueryConfig(queryConfigMap)), - Task::ExecutionMode::kSerial); + Task::ExecutionMode::kSerial, + exec::Consumer{}); } } diff --git a/velox/exec/benchmarks/WindowSubPartitionedSortBenchmark.cpp b/velox/exec/benchmarks/WindowSubPartitionedSortBenchmark.cpp new file mode 100644 index 00000000000..0b64ab91bed --- /dev/null +++ b/velox/exec/benchmarks/WindowSubPartitionedSortBenchmark.cpp @@ -0,0 +1,410 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include "velox/common/memory/SharedArbitrator.h" +#include "velox/exec/Cursor.h" +#include "velox/exec/OperatorType.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" +#include "velox/functions/prestosql/window/WindowFunctionsRegistration.h" +#include "velox/vector/fuzzer/VectorFuzzer.h" + +DEFINE_int64(fuzzer_seed, 99887766, "Seed for random input dataset generator"); + +using namespace facebook::velox; +using namespace facebook::velox::test; +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; + +static constexpr int32_t kNumVectors = 200; +static constexpr int32_t kRowsPerVector = 10'000; + +namespace { +using namespace facebook::velox::common::testutil; + +class BenchmarkRecorder { + public: + BenchmarkRecorder() = default; + + void record(std::string name, uint64_t numBytes) { + // Only record the first apperance. + if (numBytesRecords_.count(name) == 0) { + numBytesRecords_[name] = {1, numBytes}; + names_.push_back(name); + } else { + auto& record = numBytesRecords_[name]; + record.numAppearance++; + record.totalCount += numBytes; + } + } + + std::string report() { + std::string result = "name, memory(MB)\n"; + for (auto& name : names_) { + auto& record = numBytesRecords_[name]; + result += fmt::format( + "{}, {}MB\n", + name, + record.totalCount / 1024 / 1024 / record.numAppearance); + } + return result; + } + + private: + struct Counter { + int32_t numAppearance{0}; + uint64_t totalCount{0}; + }; + std::vector names_; + std::unordered_map numBytesRecords_; +}; + +class WindowSubPartitionedSortBenchmark : public HiveConnectorTestBase { + public: + WindowSubPartitionedSortBenchmark( + int32_t numVectors, + int32_t rowsPerVector, + std::shared_ptr recorder) + : numVectors_(numVectors), + rowsPerVector_(rowsPerVector), + recorder_(recorder) { + memory::SharedArbitrator::registerFactory(); + HiveConnectorTestBase::SetUp(); + aggregate::prestosql::registerAllAggregateFunctions(); + window::prestosql::registerAllWindowFunctions(); + + inputType_ = ROW({ + {"k_array", INTEGER()}, + {"k_norm", INTEGER()}, + {"k_hash", INTEGER()}, + {"k_sort", INTEGER()}, + {"i32", INTEGER()}, + {"i64", BIGINT()}, + {"f32", REAL()}, + {"f64", DOUBLE()}, + {"i32_halfnull", INTEGER()}, + {"i64_halfnull", BIGINT()}, + {"f32_halfnull", REAL()}, + {"f64_halfnull", DOUBLE()}, + }); + + VectorFuzzer::Options opts; + opts.vectorSize = rowsPerVector_; + opts.nullRatio = 0; + VectorFuzzer fuzzer(opts, pool_.get(), FLAGS_fuzzer_seed); + std::vector inputVectors; + for (auto i = 0; i < numVectors_; ++i) { + std::vector children; + + // Generate key with a small number of unique values from a small range + // (0-16). + children.emplace_back( + makeFlatVector( + rowsPerVector_, [](auto row) { return row % 17; })); + + // Generate key with a small number of unique values from a large range + // (300 total values). + children.emplace_back( + makeFlatVector(rowsPerVector_, [](auto row) { + if (row % 3 == 0) { + return std::numeric_limits::max() - row % 100; + } else if (row % 3 == 1) { + return row % 100; + } else { + return std::numeric_limits::min() + row % 100; + } + })); + + // Generate key with many unique values from a large range (500K total + // values). + children.emplace_back(fuzzer.fuzzFlat(INTEGER())); + + // Generate a column with increasing values to get a deterministic sort + // order. + children.emplace_back( + makeFlatVector( + rowsPerVector_, [](auto row) { return row; })); + + // Generate random values without nulls. + children.emplace_back(fuzzer.fuzzFlat(INTEGER())); + children.emplace_back(fuzzer.fuzzFlat(BIGINT())); + children.emplace_back(fuzzer.fuzzFlat(REAL())); + children.emplace_back(fuzzer.fuzzFlat(DOUBLE())); + + // Generate random values with nulls. + opts.nullRatio = 0.05; // 5% + fuzzer.setOptions(opts); + + children.emplace_back(fuzzer.fuzzFlat(INTEGER())); + children.emplace_back(fuzzer.fuzzFlat(BIGINT())); + children.emplace_back(fuzzer.fuzzFlat(REAL())); + children.emplace_back(fuzzer.fuzzFlat(DOUBLE())); + + inputVectors.emplace_back(makeRowVector(inputType_->names(), children)); + } + + sourceFilePath_ = TempFilePath::create(); + writeToFile(sourceFilePath_->getPath(), inputVectors); + } + + ~WindowSubPartitionedSortBenchmark() override { + HiveConnectorTestBase::TearDown(); + } + + CpuWallTiming windowNanos() { + return windowNanos_; + } + + void TestBody() override {} + + void run( + const std::string& recordName, + const std::string& key, + const std::string& aggregate, + int32_t numSubPartitions) { + folly::BenchmarkSuspender suspender1; + + windowNanos_.clear(); + windowMems_.clear(); + + std::string functionSql = fmt::format( + "{} over (partition by {} order by k_sort)", aggregate, key); + + core::PlanNodeId tableScanPlanId; + core::PlanFragment plan = PlanBuilder() + .tableScan(inputType_) + .capturePlanNodeId(tableScanPlanId) + .window({functionSql}) + .planFragment(); + + vector_size_t numResultRows = 0; + auto task = makeTask(plan, numSubPartitions); + task->addSplit( + tableScanPlanId, + exec::Split(makeHiveConnectorSplit(sourceFilePath_->getPath()))); + task->noMoreSplits(tableScanPlanId); + suspender1.dismiss(); + + while (auto result = task->next()) { + numResultRows += result->size(); + } + + folly::BenchmarkSuspender suspender2; + auto stats = task->taskStats(); + for (auto& pipeline : stats.pipelineStats) { + for (auto& op : pipeline.operatorStats) { + if (op.operatorType == OperatorType::kWindow) { + windowNanos_.add(op.addInputTiming); + windowNanos_.add(op.getOutputTiming); + windowMems_.add(op.memoryStats); + } + if (op.operatorType == OperatorType::kValues) { + // This is the timing for Window::noMoreInput() where the window + // sorting happens. So including in the cpu timing. + windowNanos_.add(op.finishTiming); + } + } + } + recorder_->record(recordName, windowMems_.peakTotalMemoryReservation); + suspender2.dismiss(); + folly::doNotOptimizeAway(numResultRows); + } + + std::shared_ptr makeTask( + core::PlanFragment plan, + int32_t numSubPartitions) { + bool subPartitionedSort = numSubPartitions > 1; + if (subPartitionedSort) { + const std::unordered_map queryConfigMap( + {{core::QueryConfig::kWindowNumSubPartitions, + std::to_string(numSubPartitions)}}); + return exec::Task::create( + "t", + std::move(plan), + 0, + core::QueryCtx::create( + executor_.get(), core::QueryConfig(queryConfigMap)), + Task::ExecutionMode::kSerial); + + } else { + return exec::Task::create( + "t", + std::move(plan), + 0, + core::QueryCtx::create(executor_.get()), + Task::ExecutionMode::kSerial); + } + } + + uint64_t getLatestMemoryUsage() { + return windowMems_.peakTotalMemoryReservation; + } + + private: + const int32_t numVectors_; + const int32_t rowsPerVector_; + const std::shared_ptr recorder_; + RowTypePtr inputType_; + std::shared_ptr sourceFilePath_; + + CpuWallTiming windowNanos_; + MemoryStats windowMems_; +}; + +std::unique_ptr benchmark; +auto recorder = std::make_shared(); + +void doSortRun( + uint32_t, + const std::string& recordName, + int32_t numSubPartitions, + const std::string& key, + const std::string& aggregate) { + benchmark->run(recordName, key, aggregate, numSubPartitions); +} + +#define BENCHMARK_AND_RECORD_HEAD(_num_, _name_, _key_, _agg_) \ + BENCHMARK_NAMED_PARAM( \ + doSortRun, \ + num##_num_##_##_name_, \ + fmt::format("num{}_{}", #_num_, #_name_), \ + _num_, \ + _key_, \ + _agg_); + +#define BENCHMARK_AND_RECORD_TAIL(_num_, _name_, _key_, _agg_) \ + BENCHMARK_RELATIVE_NAMED_PARAM( \ + doSortRun, \ + num##_num_##_##_name_, \ + fmt::format("num{}_{}", #_num_, #_name_), \ + _num_, \ + _key_, \ + _agg_); + +#define BATCHED_BENCHMARKS(_name_, _key_, _agg_) \ + BENCHMARK_AND_RECORD_HEAD(1, _name_, _key_, _agg_); \ + BENCHMARK_AND_RECORD_TAIL(2, _name_, _key_, _agg_); \ + BENCHMARK_AND_RECORD_TAIL(4, _name_, _key_, _agg_); \ + BENCHMARK_AND_RECORD_TAIL(8, _name_, _key_, _agg_); \ + BENCHMARK_AND_RECORD_TAIL(16, _name_, _key_, _agg_); \ + BENCHMARK_AND_RECORD_TAIL(32, _name_, _key_, _agg_); \ + BENCHMARK_AND_RECORD_TAIL(64, _name_, _key_, _agg_); \ + BENCHMARK_AND_RECORD_TAIL(128, _name_, _key_, _agg_); \ + BENCHMARK_AND_RECORD_TAIL(256, _name_, _key_, _agg_); \ + BENCHMARK_AND_RECORD_TAIL(512, _name_, _key_, _agg_); \ + BENCHMARK_AND_RECORD_TAIL(1024, _name_, _key_, _agg_); \ + BENCHMARK_AND_RECORD_TAIL(2048, _name_, _key_, _agg_); + +#define AGG_BENCHMARKS(_name_, _key_) \ + BATCHED_BENCHMARKS( \ + _name_##_INTEGER_##_key_, #_key_, fmt::format("{}(i32)", (#_name_))); \ + BATCHED_BENCHMARKS( \ + _name_##_REAL_##_key_, #_key_, fmt::format("{}(f32)", (#_name_))); \ + BATCHED_BENCHMARKS( \ + _name_##_INTEGER_NULLS_##_key_, \ + #_key_, \ + fmt::format("{}(i32_halfnull)", (#_name_))); \ + BATCHED_BENCHMARKS( \ + _name_##_REAL_NULLS_##_key_, \ + #_key_, \ + fmt::format("{}(f32_halfnull)", (#_name_))); + +#define MULTI_KEY_AGG_BENCHMARKS(_name_, _key1_, _key2_) \ + BATCHED_BENCHMARKS( \ + _name_##_BIGINT_##_key1_##_key2_, \ + fmt::format("{},{}", (#_key1_), (#_key2_)), \ + fmt::format("{}(i64)", (#_name_))); \ + BATCHED_BENCHMARKS( \ + _name_##_BIGINT_NULLS_##_key1_##_key2_, \ + fmt::format("{},{}", (#_key1_), (#_key2_)), \ + fmt::format("{}(i64_halfnull)", (#_name_))); \ + BATCHED_BENCHMARKS( \ + _name_##_DOUBLE_##_key1_##_key2_, \ + fmt::format("{},{}", (#_key1_), (#_key2_)), \ + fmt::format("{}(f64)", (#_name_))); \ + BATCHED_BENCHMARKS( \ + _name_##_DOUBLE_NULLS_##_key1_##_key2_, \ + fmt::format("{},{}", (#_key1_), (#_key2_)), \ + fmt::format("{}(f64_halfnull)", (#_name_))); + +// Count(1) aggregate. +BATCHED_BENCHMARKS(count_k_array, "k_array", "count(1)"); +BATCHED_BENCHMARKS(count_k_norm, "k_norm", "count(1)"); +BATCHED_BENCHMARKS(count_k_hash, "k_hash", "count(1)"); +BATCHED_BENCHMARKS(count_k_array_k_hash, "k_array,i32", "count(1)"); +BENCHMARK_DRAW_LINE(); + +// Count aggregate. +AGG_BENCHMARKS(count, k_array) +AGG_BENCHMARKS(count, k_norm) +AGG_BENCHMARKS(count, k_hash) +MULTI_KEY_AGG_BENCHMARKS(count, k_array, i32) +MULTI_KEY_AGG_BENCHMARKS(count, k_array, i64) +MULTI_KEY_AGG_BENCHMARKS(count, k_hash, f32) +MULTI_KEY_AGG_BENCHMARKS(count, k_hash, f64) +BENCHMARK_DRAW_LINE(); + +// Avg aggregate. +AGG_BENCHMARKS(avg, k_array) +AGG_BENCHMARKS(avg, k_norm) +AGG_BENCHMARKS(avg, k_hash) +MULTI_KEY_AGG_BENCHMARKS(avg, k_array, i32) +MULTI_KEY_AGG_BENCHMARKS(avg, k_array, i64) +MULTI_KEY_AGG_BENCHMARKS(avg, k_hash, f32) +MULTI_KEY_AGG_BENCHMARKS(avg, k_hash, f64) +BENCHMARK_DRAW_LINE(); + +// Min aggregate. +AGG_BENCHMARKS(min, k_array) +AGG_BENCHMARKS(min, k_norm) +AGG_BENCHMARKS(min, k_hash) +MULTI_KEY_AGG_BENCHMARKS(min, k_array, i32) +MULTI_KEY_AGG_BENCHMARKS(min, k_array, i64) +MULTI_KEY_AGG_BENCHMARKS(min, k_hash, f32) +MULTI_KEY_AGG_BENCHMARKS(min, k_hash, f64) +BENCHMARK_DRAW_LINE(); + +// Max aggregate. +AGG_BENCHMARKS(max, k_array) +AGG_BENCHMARKS(max, k_norm) +AGG_BENCHMARKS(max, k_hash) +MULTI_KEY_AGG_BENCHMARKS(max, k_array, i32) +MULTI_KEY_AGG_BENCHMARKS(max, k_array, i64) +MULTI_KEY_AGG_BENCHMARKS(max, k_hash, f32) +MULTI_KEY_AGG_BENCHMARKS(max, k_hash, f64) +BENCHMARK_DRAW_LINE(); + +} // namespace + +int main(int argc, char** argv) { + folly::Init init(&argc, &argv); + facebook::velox::memory::MemoryManager::initialize( + facebook::velox::memory::MemoryManager::Options{}); + + benchmark = std::make_unique( + kNumVectors, kRowsPerVector, recorder); + folly::runBenchmarks(); + benchmark.reset(); + + std::cout << std::endl << recorder->report(); + return 0; +} diff --git a/velox/exec/fuzzer/AggregationFuzzer.cpp b/velox/exec/fuzzer/AggregationFuzzer.cpp index 98854f82d26..844e25dca97 100644 --- a/velox/exec/fuzzer/AggregationFuzzer.cpp +++ b/velox/exec/fuzzer/AggregationFuzzer.cpp @@ -21,8 +21,8 @@ #include "velox/connectors/hive/TableHandle.h" #include "velox/dwio/dwrf/reader/DwrfReader.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/exec/PartitionFunction.h" #include "velox/exec/fuzzer/FuzzerUtil.h" @@ -34,11 +34,18 @@ DEFINE_bool( true, "When true, generates plans with aggregations over sorted inputs"); +DEFINE_bool( + enable_streaming_aggregations, + true, + "When true, generates plans with streaming aggregations"); + using facebook::velox::fuzzer::CallableSignature; using facebook::velox::fuzzer::SignatureTemplate; namespace facebook::velox::exec::test { +using namespace facebook::velox::common::testutil; + class AggregationFuzzerBase; namespace { @@ -490,21 +497,23 @@ void makeAlternativePlansWithValues( const std::vector& projections, std::vector& plans) { // Partial -> final aggregation plan. - plans.push_back(PlanBuilder() - .values(inputVectors) - .projectExpressions(projections) - .partialAggregation(groupingKeys, aggregates, masks) - .finalAggregation() - .planNode()); + plans.push_back( + PlanBuilder() + .values(inputVectors) + .projectExpressions(projections) + .partialAggregation(groupingKeys, aggregates, masks) + .finalAggregation() + .planNode()); // Partial -> intermediate -> final aggregation plan. - plans.push_back(PlanBuilder() - .values(inputVectors) - .projectExpressions(projections) - .partialAggregation(groupingKeys, aggregates, masks) - .intermediateAggregation() - .finalAggregation() - .planNode()); + plans.push_back( + PlanBuilder() + .values(inputVectors) + .projectExpressions(projections) + .partialAggregation(groupingKeys, aggregates, masks) + .intermediateAggregation() + .finalAggregation() + .planNode()); // Partial -> local exchange -> final aggregation plan. auto numSources = std::min(4, inputVectors.size()); @@ -550,23 +559,25 @@ void makeAlternativePlansWithTableScan( // the false negatives. #ifndef TSAN_BUILD // Partial -> final aggregation plan. - plans.push_back(PlanBuilder() - .tableScan(inputRowType) - .projectExpressions(projections) - .partialAggregation(groupingKeys, aggregates, masks) - .localPartition(groupingKeys) - .finalAggregation() - .planNode()); + plans.push_back( + PlanBuilder() + .tableScan(inputRowType) + .projectExpressions(projections) + .partialAggregation(groupingKeys, aggregates, masks) + .localPartition(groupingKeys) + .finalAggregation() + .planNode()); // Partial -> intermediate -> final aggregation plan. - plans.push_back(PlanBuilder() - .tableScan(inputRowType) - .projectExpressions(projections) - .partialAggregation(groupingKeys, aggregates, masks) - .localPartition(groupingKeys) - .intermediateAggregation() - .finalAggregation() - .planNode()); + plans.push_back( + PlanBuilder() + .tableScan(inputRowType) + .projectExpressions(projections) + .partialAggregation(groupingKeys, aggregates, masks) + .localPartition(groupingKeys) + .intermediateAggregation() + .finalAggregation() + .planNode()); #endif } @@ -578,17 +589,18 @@ void makeStreamingPlansWithValues( const std::vector& projections, std::vector& plans) { // Single aggregation. - plans.push_back(PlanBuilder() - .values(inputVectors) - .projectExpressions(projections) - .orderBy(groupingKeys, false) - .streamingAggregation( - groupingKeys, - aggregates, - masks, - core::AggregationNode::Step::kSingle, - false) - .planNode()); + plans.push_back( + PlanBuilder() + .values(inputVectors) + .projectExpressions(projections) + .orderBy(groupingKeys, false) + .streamingAggregation( + groupingKeys, + aggregates, + masks, + core::AggregationNode::Step::kSingle, + false) + .planNode()); // Partial -> final aggregation plan. plans.push_back( @@ -643,17 +655,18 @@ void makeStreamingPlansWithTableScan( const std::vector& projections, std::vector& plans) { // Single aggregation. - plans.push_back(PlanBuilder() - .tableScan(inputRowType) - .projectExpressions(projections) - .orderBy(groupingKeys, false) - .streamingAggregation( - groupingKeys, - aggregates, - masks, - core::AggregationNode::Step::kSingle, - false) - .planNode()); + plans.push_back( + PlanBuilder() + .tableScan(inputRowType) + .projectExpressions(projections) + .orderBy(groupingKeys, false) + .streamingAggregation( + groupingKeys, + aggregates, + masks, + core::AggregationNode::Step::kSingle, + false) + .planNode()); // Partial -> final aggregation plan. plans.push_back( @@ -724,7 +737,7 @@ bool AggregationFuzzer::verifyAggregation( std::vector plans; plans.push_back({firstPlan, {}}); - auto directory = exec::test::TempDirectoryPath::create(); + auto directory = TempDirectoryPath::create(); // Alternate between using Values and TableScan node. @@ -741,7 +754,7 @@ bool AggregationFuzzer::verifyAggregation( projections, tableScanPlans); - if (!groupingKeys.empty()) { + if (FLAGS_enable_streaming_aggregations && !groupingKeys.empty()) { // Use OrderBy + StreamingAggregation on original input. makeStreamingPlansWithTableScan( groupingKeys, @@ -772,7 +785,7 @@ bool AggregationFuzzer::verifyAggregation( makeAlternativePlansWithValues( groupingKeys, aggregates, masks, flatInput, projections, valuesPlans); - if (!groupingKeys.empty()) { + if (FLAGS_enable_streaming_aggregations && !groupingKeys.empty()) { // Use OrderBy + StreamingAggregation on original input. makeStreamingPlansWithValues( groupingKeys, aggregates, masks, input, projections, valuesPlans); @@ -839,7 +852,7 @@ bool AggregationFuzzer::verifySortedAggregation( std::vector plans; plans.push_back({firstPlan, {}}); - if (!groupingKeys.empty()) { + if (FLAGS_enable_streaming_aggregations && !groupingKeys.empty()) { plans.push_back( {PlanBuilder() .values(input) @@ -855,10 +868,10 @@ bool AggregationFuzzer::verifySortedAggregation( {}}); } - std::shared_ptr directory; + std::shared_ptr directory; const auto inputRowType = asRowType(input[0]->type()); if (isTableScanSupported(inputRowType)) { - directory = exec::test::TempDirectoryPath::create(); + directory = TempDirectoryPath::create(); auto splits = makeSplits(input, directory->getPath(), writerPool_); plans.push_back( @@ -869,7 +882,7 @@ bool AggregationFuzzer::verifySortedAggregation( .planNode(), splits}); - if (!groupingKeys.empty()) { + if (FLAGS_enable_streaming_aggregations && !groupingKeys.empty()) { plans.push_back( {PlanBuilder() .tableScan(inputRowType) @@ -1066,7 +1079,7 @@ bool AggregationFuzzer::compareEquivalentPlanResults( firstPlan, referenceQueryRunner_.get()); stats_.updateReferenceQueryStats(referenceResult.second); - if (referenceResult.first) { + if (referenceResult.first && !referenceResult.first.value().empty()) { velox::fuzzer::ResultOrError expected; expected.result = fuzzer::mergeRowVectors( referenceResult.first.value(), pool_.get()); @@ -1135,7 +1148,7 @@ bool AggregationFuzzer::verifyDistinctAggregation( std::vector plans; plans.push_back({firstPlan, {}}); - if (!groupingKeys.empty()) { + if (FLAGS_enable_streaming_aggregations && !groupingKeys.empty()) { plans.push_back( {PlanBuilder() .values(input) @@ -1153,10 +1166,10 @@ bool AggregationFuzzer::verifyDistinctAggregation( // Alternate between using Values and TableScan node. - std::shared_ptr directory; + std::shared_ptr directory; const auto inputRowType = asRowType(input[0]->type()); if (isTableScanSupported(inputRowType) && vectorFuzzer_.coinToss(0.5)) { - directory = exec::test::TempDirectoryPath::create(); + directory = TempDirectoryPath::create(); auto splits = makeSplits(input, directory->getPath(), writerPool_); plans.push_back( @@ -1167,7 +1180,7 @@ bool AggregationFuzzer::verifyDistinctAggregation( .planNode(), splits}); - if (!groupingKeys.empty()) { + if (FLAGS_enable_streaming_aggregations && !groupingKeys.empty()) { plans.push_back( {PlanBuilder() .tableScan(inputRowType) diff --git a/velox/exec/fuzzer/AggregationFuzzerBase.cpp b/velox/exec/fuzzer/AggregationFuzzerBase.cpp index 6dfbf494746..753bd019537 100644 --- a/velox/exec/fuzzer/AggregationFuzzerBase.cpp +++ b/velox/exec/fuzzer/AggregationFuzzerBase.cpp @@ -18,10 +18,10 @@ #include #include "velox/common/base/Fs.h" #include "velox/common/base/VeloxException.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/connectors/hive/HiveConnectorSplit.h" #include "velox/dwio/dwrf/writer/Writer.h" #include "velox/exec/Spill.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/expression/SignatureBinder.h" #include "velox/expression/fuzzer/ArgumentTypeFuzzer.h" #include "velox/vector/VectorSaver.h" @@ -85,13 +85,15 @@ DEFINE_bool( namespace facebook::velox::exec::test { +using namespace facebook::velox::common::testutil; + int32_t AggregationFuzzerBase::randInt(int32_t min, int32_t max) { return boost::random::uniform_int_distribution(min, max)(rng_); } bool AggregationFuzzerBase::isSupportedType(const TypePtr& type) const { // Date / IntervalDayTime/ Unknown are not currently supported by DWRF. - if (type->isDate() || type->isIntervalDayTime() || type->isUnKnown()) { + if (type->isDate() || type->isIntervalDayTime() || type->isUnknown()) { return false; } @@ -303,8 +305,9 @@ std::vector AggregationFuzzerBase::generateInputData( children.push_back(vectorFuzzer_.fuzz(inputType->childAt(j), size)); } - input.push_back(std::make_shared( - pool_.get(), inputType, nullptr, size, std::move(children))); + input.push_back( + std::make_shared( + pool_.get(), inputType, nullptr, size, std::move(children))); } if (generator != nullptr) { @@ -404,16 +407,18 @@ std::vector AggregationFuzzerBase::generateInputDataWithRowNumber( // values. This is done to introduce some repetition of key values for // windowing. auto baseVector = vectorFuzzer_.fuzz(types[i], numPartitions); - children.push_back(BaseVector::wrapInDictionary( - partitionNulls, partitionIndices, size, baseVector)); + children.push_back( + BaseVector::wrapInDictionary( + partitionNulls, partitionIndices, size, baseVector)); } else if ( windowFrameBoundsSet.find(names[i]) != windowFrameBoundsSet.end()) { // Frame bound columns cannot have NULLs. children.push_back(vectorFuzzer_.fuzzNotNull(types[i], size)); } else if (sortingKeySet.find(names[i]) != sortingKeySet.end()) { auto baseVector = vectorFuzzer_.fuzz(types[i], numPeerGroups); - children.push_back(BaseVector::wrapInDictionary( - sortingNulls, sortingIndices, size, baseVector)); + children.push_back( + BaseVector::wrapInDictionary( + sortingNulls, sortingIndices, size, baseVector)); } else { children.push_back(vectorFuzzer_.fuzz(types[i], size)); } @@ -495,7 +500,7 @@ velox::fuzzer::ResultOrError AggregationFuzzerBase::execute( int32_t spillPct{0}; if (injectSpill) { - spillDirectory = exec::test::TempDirectoryPath::create(); + spillDirectory = TempDirectoryPath::create(); builder.spillDirectory(spillDirectory->getPath()) .config(core::QueryConfig::kSpillEnabled, "true") .config(core::QueryConfig::kAggregationSpillEnabled, "true") diff --git a/velox/exec/fuzzer/AggregationFuzzerBase.h b/velox/exec/fuzzer/AggregationFuzzerBase.h index 246823fdd75..ce97472983c 100644 --- a/velox/exec/fuzzer/AggregationFuzzerBase.h +++ b/velox/exec/fuzzer/AggregationFuzzerBase.h @@ -81,7 +81,11 @@ class AggregationFuzzerBase { : getFuzzerOptions(timestampPrecision), pool_.get()} { filesystems::registerLocalFileSystem(); - registerHiveConnector(hiveConfigs); + // Fuzzer test generates a lot of files and directories. Disable file + // handle cache to avoid EBADF errors. + auto configs = hiveConfigs; + configs[connector::hive::HiveConfig::kEnableFileHandleCache] = "false"; + registerHiveConnector(configs); dwrf::registerDwrfReaderFactory(); dwrf::registerDwrfWriterFactory(); @@ -121,6 +125,7 @@ class AggregationFuzzerBase { opts.stringVariableLength = true; opts.stringLength = 4'000; opts.nullRatio = FLAGS_null_ratio; + opts.useRandomNullPattern = true; opts.timestampPrecision = timestampPrecision; return opts; } diff --git a/velox/exec/fuzzer/AggregationFuzzerRunner.h b/velox/exec/fuzzer/AggregationFuzzerRunner.h index 6aa044d84ef..65f58aec142 100644 --- a/velox/exec/fuzzer/AggregationFuzzerRunner.h +++ b/velox/exec/fuzzer/AggregationFuzzerRunner.h @@ -107,13 +107,13 @@ class AggregationFuzzerRunner { facebook::velox::parse::registerTypeResolver(); facebook::velox::serializer::presto::PrestoVectorSerde:: registerVectorSerde(); - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kPresto)) { + if (!isRegisteredNamedVectorSerde("Presto")) { serializer::presto::PrestoVectorSerde::registerNamedVectorSerde(); } - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kCompactRow)) { + if (!isRegisteredNamedVectorSerde("CompactRow")) { serializer::CompactRowVectorSerde::registerNamedVectorSerde(); } - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kUnsafeRow)) { + if (!isRegisteredNamedVectorSerde("UnsafeRow")) { serializer::spark::UnsafeRowVectorSerde::registerNamedVectorSerde(); } facebook::velox::filesystems::registerLocalFileSystem(); diff --git a/velox/exec/fuzzer/CMakeLists.txt b/velox/exec/fuzzer/CMakeLists.txt index a826e6785ad..99e0e47e9f2 100644 --- a/velox/exec/fuzzer/CMakeLists.txt +++ b/velox/exec/fuzzer/CMakeLists.txt @@ -12,6 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Generate Thrift library for LocalRunnerService early since velox_fuzzer_util +# depends on it when VELOX_ENABLE_REMOTE_FUNCTIONS is enabled. +if(VELOX_ENABLE_REMOTE_FUNCTIONS) + include(FBThriftCppLibrary) + add_fbthrift_cpp_library( + local_runner_service_thrift + if/LocalRunnerService.thrift + SERVICES + LocalRunnerService + ) + target_compile_options(local_runner_service_thrift PRIVATE -Wno-error=deprecated-declarations) +endif() + add_library( velox_fuzzer_util ReferenceQueryRunner.cpp @@ -26,6 +39,30 @@ add_library( FuzzerUtil.cpp PrestoSql.cpp ) +velox_add_test_headers( + velox_fuzzer_util + DuckQueryRunner.h + DuckQueryRunnerToSqlPlanNodeVisitor.h + FuzzerUtil.h + PrestoQueryRunner.h + PrestoQueryRunnerIntermediateTypeTransforms.h + PrestoQueryRunnerIntervalTransform.h + PrestoQueryRunnerJsonTransform.h + PrestoQueryRunnerTimestampWithTimeZoneTransform.h + PrestoQueryRunnerToSqlPlanNodeVisitor.h + PrestoSql.h + ReferenceQueryRunner.h + VeloxQueryRunner.h +) + +# TODO Add VeloxQueryRunner to velox_fuzzer_util to support in +# ExpressionFuzzerTest. More information can be found here: +# https://github.com/facebookincubator/velox/issues/15414 +if(VELOX_ENABLE_REMOTE_FUNCTIONS) + target_sources(velox_fuzzer_util PRIVATE VeloxQueryRunner.cpp) + target_link_libraries(velox_fuzzer_util local_runner_service_thrift) + target_include_directories(velox_fuzzer_util PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) +endif() target_link_libraries( velox_fuzzer_util @@ -35,8 +72,7 @@ target_link_libraries( velox_exec_test_lib velox_expression_functions velox_presto_types - cpr::cpr - Boost::regex + CURL::libcurl velox_presto_type_parser Folly::folly velox_hive_connector @@ -46,11 +82,16 @@ target_link_libraries( ) add_library(velox_aggregation_fuzzer_base AggregationFuzzerBase.cpp) +velox_add_test_headers( + velox_aggregation_fuzzer_base + AggregationFuzzerBase.h + AggregationFuzzerOptions.h +) target_link_libraries( velox_aggregation_fuzzer_base velox_exec_test_lib - velox_temp_path + velox_test_util velox_common_base velox_file velox_hive_connector @@ -65,6 +106,7 @@ target_link_libraries( ) add_library(velox_aggregation_fuzzer AggregationFuzzer.cpp) +velox_add_test_headers(velox_aggregation_fuzzer AggregationFuzzer.h AggregationFuzzerRunner.h) target_link_libraries( velox_aggregation_fuzzer @@ -77,6 +119,7 @@ target_link_libraries( ) add_library(velox_window_fuzzer WindowFuzzer.cpp) +velox_add_test_headers(velox_window_fuzzer WindowFuzzer.h WindowFuzzerRunner.h) target_link_libraries( velox_window_fuzzer @@ -86,13 +129,14 @@ target_link_libraries( velox_exec_test_lib velox_expression_test_utility velox_aggregation_fuzzer_base - velox_temp_path + velox_test_util ) -add_library(velox_row_number_fuzzer_base_lib RowNumberFuzzerBase.cpp) +add_library(velox_spill_fuzzer_base_lib SpillFuzzerBase.cpp) +velox_add_test_headers(velox_spill_fuzzer_base_lib SpillFuzzerBase.h) target_link_libraries( - velox_row_number_fuzzer_base_lib + velox_spill_fuzzer_base_lib velox_dwio_dwrf_reader velox_fuzzer_util velox_vector_fuzzer @@ -100,10 +144,11 @@ target_link_libraries( ) add_library(velox_row_number_fuzzer_lib RowNumberFuzzer.cpp) +velox_add_test_headers(velox_row_number_fuzzer_lib RowNumberFuzzer.h) target_link_libraries( velox_row_number_fuzzer_lib - velox_row_number_fuzzer_base_lib + velox_spill_fuzzer_base_lib velox_type velox_expression_test_utility ) @@ -114,10 +159,11 @@ add_executable(velox_row_number_fuzzer RowNumberFuzzerRunner.cpp) target_link_libraries(velox_row_number_fuzzer velox_row_number_fuzzer_lib) add_library(velox_topn_row_number_fuzzer_lib TopNRowNumberFuzzer.cpp) +velox_add_test_headers(velox_topn_row_number_fuzzer_lib TopNRowNumberFuzzer.h) target_link_libraries( velox_topn_row_number_fuzzer_lib - velox_row_number_fuzzer_base_lib + velox_spill_fuzzer_base_lib velox_type velox_expression_test_utility ) @@ -125,10 +171,31 @@ target_link_libraries( # TopNRowNumber Fuzzer. add_executable(velox_topn_row_number_fuzzer TopNRowNumberFuzzerRunner.cpp) -target_link_libraries(velox_topn_row_number_fuzzer velox_topn_row_number_fuzzer_lib) +target_link_libraries( + velox_topn_row_number_fuzzer + velox_topn_row_number_fuzzer_lib + velox_functions_prestosql + velox_window +) + +add_library(velox_mark_distinct_fuzzer_lib MarkDistinctFuzzer.cpp) +velox_add_test_headers(velox_mark_distinct_fuzzer_lib MarkDistinctFuzzer.h) + +target_link_libraries( + velox_mark_distinct_fuzzer_lib + velox_spill_fuzzer_base_lib + velox_type + velox_expression_test_utility +) + +# MarkDistinct Fuzzer. +add_executable(velox_mark_distinct_fuzzer MarkDistinctFuzzerRunner.cpp) + +target_link_libraries(velox_mark_distinct_fuzzer velox_mark_distinct_fuzzer_lib velox_aggregates) # Join Fuzzer. add_executable(velox_join_fuzzer JoinFuzzerRunner.cpp JoinFuzzer.cpp JoinMaker.cpp) +velox_add_test_headers(velox_join_fuzzer JoinFuzzer.h JoinMaker.h) target_link_libraries( velox_join_fuzzer @@ -139,7 +206,22 @@ target_link_libraries( velox_expression_test_utility ) +# Spatial Join Fuzzer. +add_executable(velox_spatial_join_fuzzer SpatialJoinFuzzerRunner.cpp SpatialJoinFuzzer.cpp) +velox_add_test_headers(velox_spatial_join_fuzzer SpatialJoinFuzzer.h) + +target_link_libraries( + velox_spatial_join_fuzzer + velox_type + velox_vector_fuzzer + velox_fuzzer_util + velox_exec_test_lib + velox_expression_test_utility + velox_vector_test_lib +) + add_library(velox_writer_fuzzer WriterFuzzer.cpp) +velox_add_test_headers(velox_writer_fuzzer WriterFuzzer.h WriterFuzzerRunner.h) target_link_libraries( velox_writer_fuzzer @@ -148,7 +230,7 @@ target_link_libraries( velox_vector_fuzzer velox_exec_test_lib velox_expression_test_utility - velox_temp_path + velox_test_util velox_vector_test_lib velox_dwio_faulty_file_sink velox_file_test_utils @@ -160,6 +242,7 @@ add_executable( MemoryArbitrationFuzzerRunner.cpp MemoryArbitrationFuzzer.cpp ) +velox_add_test_headers(velox_memory_arbitration_fuzzer MemoryArbitrationFuzzer.h) target_link_libraries( velox_memory_arbitration_fuzzer @@ -167,6 +250,8 @@ target_link_libraries( velox_fuzzer_util velox_type velox_vector_fuzzer + velox_dwio_dwrf_reader + velox_dwio_dwrf_writer velox_exec_test_lib velox_expression_test_utility velox_functions_prestosql @@ -174,6 +259,7 @@ target_link_libraries( ) add_library(velox_cache_fuzzer_lib CacheFuzzer.cpp) +velox_add_test_headers(velox_cache_fuzzer_lib CacheFuzzer.h) # Cache Fuzzer add_executable(velox_cache_fuzzer CacheFuzzerRunner.cpp) @@ -183,7 +269,7 @@ target_link_libraries(velox_cache_fuzzer velox_cache_fuzzer_lib velox_fuzzer_uti target_link_libraries( velox_cache_fuzzer_lib velox_dwio_common - velox_temp_path + velox_test_util velox_vector_test_lib ) @@ -198,6 +284,53 @@ target_link_libraries( velox_vector_fuzzer ) +# LocalRunnerService Library (requires FBThrift support) +if(VELOX_ENABLE_REMOTE_FUNCTIONS) + add_library(velox_local_runner_service_lib LocalRunnerService.cpp) + velox_add_test_headers(velox_local_runner_service_lib LocalRunnerService.h) + + target_link_libraries( + velox_local_runner_service_lib + local_runner_service_thrift + velox_core + velox_exec + velox_exec_test_lib + velox_expression + velox_functions_prestosql + velox_common_base + velox_memory + Folly::folly + FBThrift::thriftcpp2 + gflags + glog::glog + ) + + target_include_directories(velox_local_runner_service_lib PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) + + # LocalRunnerService Executable + add_executable(velox_local_runner_service_runner LocalRunnerServiceRunner.cpp) + + target_link_libraries( + velox_local_runner_service_runner + velox_local_runner_service_lib + velox_functions_prestosql + gtest + gflags + Folly::folly + ) +endif() + if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) endif() + +velox_add_library(velox_expression_transformer INTERFACE HEADERS ExprTransformer.h) + +velox_add_library( + velox_fuzzer_generator_verifier + INTERFACE + HEADERS + InputGenerator.h + ResultVerifier.h + TransformResultVerifier.h +) diff --git a/velox/exec/fuzzer/CacheFuzzer.cpp b/velox/exec/fuzzer/CacheFuzzer.cpp index 57d364f7450..b518396413d 100644 --- a/velox/exec/fuzzer/CacheFuzzer.cpp +++ b/velox/exec/fuzzer/CacheFuzzer.cpp @@ -23,10 +23,11 @@ #include "velox/common/caching/SsdCache.h" #include "velox/common/file/FileSystems.h" #include "velox/common/file/tests/FaultyFileSystem.h" +#include "velox/common/io/IoStatistics.h" #include "velox/common/memory/Memory.h" #include "velox/common/memory/MmapAllocator.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/dwio/common/CachedBufferedInput.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" DEFINE_int32(steps, 10, "Number of plans to generate and test."); @@ -88,6 +89,7 @@ using namespace facebook::velox::dwio::common; using namespace facebook::velox::tests::utils; namespace facebook::velox::exec { +using namespace facebook::velox::common::testutil; namespace { class CacheFuzzer { @@ -157,11 +159,17 @@ class CacheFuzzer { // instead of random location for cache reuse. std::vector>> fileFragments_; std::vector> inputs_; - std::shared_ptr sourceDataDir_; - std::shared_ptr cacheDataDir_; + std::shared_ptr sourceDataDir_; + std::shared_ptr cacheDataDir_; std::unique_ptr memoryManager_; std::unique_ptr executor_; std::shared_ptr cache_; + + std::shared_ptr dataIoStats_{ + std::make_shared()}; + std::shared_ptr metadataIoStats_{ + std::make_shared()}; + // Save the config for the last iteration so they can be potentially reused // after restart. int64_t lastMemoryCacheBytes_; @@ -189,9 +197,8 @@ CacheFuzzer::CacheFuzzer(size_t initialSeed) { void CacheFuzzer::initSourceDataFiles() { // Skip errors on source data files. - sourceDataDir_ = exec::test::TempDirectoryPath::create(); - cacheDataDir_ = - exec::test::TempDirectoryPath::create(FLAGS_enable_file_faulty_injection); + sourceDataDir_ = TempDirectoryPath::create(); + cacheDataDir_ = TempDirectoryPath::create(FLAGS_enable_file_faulty_injection); fs_ = filesystems::getFileSystem(sourceDataDir_->getPath(), nullptr); // Create files with random sizes. @@ -370,27 +377,30 @@ void CacheFuzzer::initializeCache(bool restartCache) { } void CacheFuzzer::initializeInputs() { - const auto readOptions = io::ReaderOptions(pool_.get()); + io::ReaderOptions readOptions(pool_.get()); + readOptions.setDataIoStats(dataIoStats_); + readOptions.setMetadataIoStats(metadataIoStats_); auto tracker = std::make_shared( "testTracker", nullptr, 256 << 10 /*256KB*/); - auto ioStats = std::make_shared(); - auto fsStats = std::make_shared(); + auto ioStatistics = std::make_shared(); + auto ioStats = std::make_shared(); inputs_.reserve(FLAGS_num_source_files); for (auto i = 0; i < FLAGS_num_source_files; ++i) { // Initialize buffered input. auto readFile = fs_->openFileForRead(fileNames_[i]); auto const withExecutor = !folly::Random::oneIn(3, rng_); - inputs_.emplace_back(std::make_unique( - std::move(readFile), - MetricsLog::voidLog(), - fileIds_[i], // NOLINT - cache_.get(), - tracker, - fileIds_[i], // NOLINT - ioStats, - fsStats, - withExecutor ? executor_.get() : nullptr, - readOptions)); + inputs_.emplace_back( + std::make_unique( + std::move(readFile), + MetricsLog::voidLog(), + fileIds_[i], // NOLINT + cache_.get(), + tracker, + fileIds_[i], // NOLINT + ioStatistics, + ioStats, + withExecutor ? executor_.get() : nullptr, + readOptions)); // Divide file into fragments. std::vector> fragments; diff --git a/velox/exec/fuzzer/DuckQueryRunnerToSqlPlanNodeVisitor.cpp b/velox/exec/fuzzer/DuckQueryRunnerToSqlPlanNodeVisitor.cpp index 388e67ca36a..4ec4f00ca5c 100644 --- a/velox/exec/fuzzer/DuckQueryRunnerToSqlPlanNodeVisitor.cpp +++ b/velox/exec/fuzzer/DuckQueryRunnerToSqlPlanNodeVisitor.cpp @@ -277,7 +277,8 @@ void DuckQueryRunnerToSqlPlanNodeVisitor::visit( sql << inputType->nameOf(i); } - sql << ", row_number() OVER ("; + sql << ", " << core::TopNRowNumberNode::rankFunctionName(node.rankFunction()) + << "() OVER ("; const auto& partitionKeys = node.partitionKeys(); if (!partitionKeys.empty()) { diff --git a/velox/exec/fuzzer/DuckQueryRunnerToSqlPlanNodeVisitor.h b/velox/exec/fuzzer/DuckQueryRunnerToSqlPlanNodeVisitor.h index d7566ed833f..b7a655b7a88 100644 --- a/velox/exec/fuzzer/DuckQueryRunnerToSqlPlanNodeVisitor.h +++ b/velox/exec/fuzzer/DuckQueryRunnerToSqlPlanNodeVisitor.h @@ -92,6 +92,16 @@ class DuckQueryRunnerToSqlPlanNodeVisitor : public PrestoSqlPlanNodeVisitor { VELOX_NYI(); } + void visit(const core::EnforceDistinctNode&, core::PlanNodeVisitorContext&) + const override { + VELOX_NYI(); + } + + void visit(const core::MarkSortedNode&, core::PlanNodeVisitorContext&) + const override { + VELOX_NYI(); + } + void visit(const core::MergeExchangeNode&, core::PlanNodeVisitorContext&) const override { VELOX_NYI(); @@ -177,6 +187,11 @@ class DuckQueryRunnerToSqlPlanNodeVisitor : public PrestoSqlPlanNodeVisitor { VELOX_NYI(); } + void visit(const core::MixedUnionNode&, core::PlanNodeVisitorContext&) + const override { + VELOX_NYI(); + } + private: std::unordered_set aggregateFunctionNames_; }; diff --git a/velox/exec/fuzzer/ExchangeFuzzer.cpp b/velox/exec/fuzzer/ExchangeFuzzer.cpp index 4443b24fba2..9cd60c2a600 100644 --- a/velox/exec/fuzzer/ExchangeFuzzer.cpp +++ b/velox/exec/fuzzer/ExchangeFuzzer.cpp @@ -143,7 +143,7 @@ class ExchangeFuzzer : public VectorTestBase { } auto partialAggPlan = exec::test::PlanBuilder() - .exchange(leafPlan->outputType(), VectorSerde::Kind::kPresto) + .exchange(leafPlan->outputType(), "Presto") .partialAggregation({}, makeAggregates(rowType, 1)) .partitionedOutput({}, 1) .planNode(); @@ -158,14 +158,13 @@ class ExchangeFuzzer : public VectorTestBase { addRemoteSplits(task, leafTaskIds); } - auto plan = - exec::test::PlanBuilder() - .exchange(partialAggPlan->outputType(), VectorSerde::Kind::kPresto) - .finalAggregation( - {}, - makeAggregates(*partialAggPlan->outputType(), 0), - rawInputTypes) - .planNode(); + auto plan = exec::test::PlanBuilder() + .exchange(partialAggPlan->outputType(), "Presto") + .finalAggregation( + {}, + makeAggregates(*partialAggPlan->outputType(), 0), + rawInputTypes) + .planNode(); try { // Create the Task to do the final aggregation using a TaskCursor so we @@ -409,9 +408,11 @@ class ExchangeFuzzer : public VectorTestBase { LOG(INFO) << "Terminating with error"; exit(1); } - LOG(INFO) << "Memory after run=" - << succinctBytes(memory::AllocationTraits::pageBytes( - memory::memoryManager()->allocator()->numAllocated())); + LOG(INFO) + << "Memory after run=" + << succinctBytes( + memory::AllocationTraits::pageBytes( + memory::memoryManager()->allocator()->numAllocated())); if (FLAGS_duration_sec == 0 && FLAGS_steps && counter + 1 >= FLAGS_steps) { @@ -572,13 +573,13 @@ int main(int argc, char** argv) { aggregate::prestosql::registerAllAggregateFunctions(); parse::registerTypeResolver(); serializer::presto::PrestoVectorSerde::registerVectorSerde(); - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kPresto)) { + if (!isRegisteredNamedVectorSerde("Presto")) { serializer::presto::PrestoVectorSerde::registerNamedVectorSerde(); } - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kCompactRow)) { + if (!isRegisteredNamedVectorSerde("CompactRow")) { serializer::CompactRowVectorSerde::registerNamedVectorSerde(); } - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kUnsafeRow)) { + if (!isRegisteredNamedVectorSerde("UnsafeRow")) { serializer::spark::UnsafeRowVectorSerde::registerNamedVectorSerde(); } exec::ExchangeSource::registerFactory(exec::test::createLocalExchangeSource); diff --git a/velox/exec/fuzzer/FuzzerUtil.cpp b/velox/exec/fuzzer/FuzzerUtil.cpp index 04fc59363a8..60df03452e3 100644 --- a/velox/exec/fuzzer/FuzzerUtil.cpp +++ b/velox/exec/fuzzer/FuzzerUtil.cpp @@ -17,6 +17,7 @@ #include #include #include "velox/common/memory/SharedArbitrator.h" +#include "velox/connectors/ConnectorRegistry.h" #include "velox/connectors/hive/HiveConnector.h" #include "velox/connectors/hive/HiveConnectorSplit.h" #include "velox/dwio/catalog/fbhive/FileUtils.h" @@ -24,6 +25,7 @@ #include "velox/exec/fuzzer/DuckQueryRunner.h" #include "velox/exec/fuzzer/PrestoQueryRunner.h" #include "velox/expression/SignatureBinder.h" +#include "velox/functions/prestosql/types/IPAddressType.h" #include "velox/functions/prestosql/types/IPPrefixType.h" using namespace facebook::velox::dwio::catalog::fbhive; @@ -247,6 +249,35 @@ bool containsUnsupportedTypes(const TypePtr& type) { containsType(type, INTERVAL_DAY_TIME()); } +bool containsIPAddress(const TypePtr& type) { + if (type->isArray()) { + const auto& elementType = type->asArray().elementType(); + if (isIPAddressType(elementType)) { + return true; + } + return containsIPAddress(elementType); + } + + if (type->isMap()) { + const auto& keyType = type->asMap().keyType(); + const auto& valueType = type->asMap().valueType(); + if (isIPAddressType(keyType) || isIPAddressType(valueType)) { + return true; + } + return containsIPAddress(keyType) || containsIPAddress(valueType); + } + + if (type->isRow()) { + for (auto i = 0; i < type->size(); ++i) { + if (containsIPAddress(type->childAt(i))) { + return true; + } + } + } + + return false; +} + // Determine whether type is or contains typeName. typeName should be in lower // case. bool containTypeName( @@ -323,8 +354,9 @@ TypePtr sanitizeTryResolveType( const exec::TypeSignature& typeSignature, const std::unordered_map& variables, const std::unordered_map& resolvedTypeVariables) { - return sanitize(SignatureBinder::tryResolveType( - typeSignature, variables, resolvedTypeVariables)); + return sanitize( + SignatureBinder::tryResolveType( + typeSignature, variables, resolvedTypeVariables)); } TypePtr sanitizeTryResolveType( @@ -336,13 +368,14 @@ TypePtr sanitizeTryResolveType( longEnumParameterVariablesBindings, const std::unordered_map& varcharEnumParameterVariablesBindings) { - return sanitize(SignatureBinder::tryResolveType( - typeSignature, - variables, - typeVariablesBindings, - integerVariablesBindings, - longEnumParameterVariablesBindings, - varcharEnumParameterVariablesBindings)); + return sanitize( + SignatureBinder::tryResolveType( + typeSignature, + variables, + typeVariablesBindings, + integerVariablesBindings, + longEnumParameterVariablesBindings, + varcharEnumParameterVariablesBindings)); } void setupMemory( @@ -359,11 +392,13 @@ void setupMemory( options.checkUsageLeak = true; options.arbitrationStateCheckCb = memoryArbitrationStateCheck; options.extraArbitratorConfigs = { - {std::string(velox::memory::SharedArbitrator::ExtraConfig:: - kGlobalArbitrationEnabled), + {std::string( + velox::memory::SharedArbitrator::ExtraConfig:: + kGlobalArbitrationEnabled), enableGlobalArbitration ? "true" : "false"}, - {std::string(velox::memory::SharedArbitrator::ExtraConfig:: - kMemoryPoolMinReclaimBytes), + {std::string( + velox::memory::SharedArbitrator::ExtraConfig:: + kMemoryPoolMinReclaimBytes), "0B"}}; facebook::velox::memory::MemoryManager::initialize(options); } @@ -378,7 +413,8 @@ void registerHiveConnector( auto hiveConnector = factory.newConnector( kHiveConnectorId, std::make_shared(std::move(configs))); - connector::registerConnector(hiveConnector); + connector::ConnectorRegistry::global().insert( + hiveConnector->connectorId(), hiveConnector); } std::unique_ptr setupReferenceQueryRunner( diff --git a/velox/exec/fuzzer/FuzzerUtil.h b/velox/exec/fuzzer/FuzzerUtil.h index 5ad39818ca5..da306395546 100644 --- a/velox/exec/fuzzer/FuzzerUtil.h +++ b/velox/exec/fuzzer/FuzzerUtil.h @@ -93,6 +93,17 @@ RowTypePtr concat(const RowTypePtr& a, const RowTypePtr& b); /// TODO Investigate mismatches reported when comparing Varbinary. bool containsUnsupportedTypes(const TypePtr& type); +/// Checks if a type contains IPADDRESS in any container position (array +/// element, map key, or map value) at any nesting level. Returns false for +/// bare IPADDRESS or IPADDRESS directly in a ROW field. +/// +/// Presto's Int128ArrayBlock doesn't implement compareTo(), which causes +/// failures when IPADDRESS appears in containers that require element-level +/// comparison (arrays, map keys, map values). +/// +/// See: https://github.com/prestodb/presto/issues/26836 +bool containsIPAddress(const TypePtr& type); + /// Determines whether the signature has an argument that contains typeName. /// typeName should be in lower case. bool usesInputTypeName( diff --git a/velox/exec/fuzzer/JoinFuzzer.cpp b/velox/exec/fuzzer/JoinFuzzer.cpp index 3edbc634cbe..33801b3d4a3 100644 --- a/velox/exec/fuzzer/JoinFuzzer.cpp +++ b/velox/exec/fuzzer/JoinFuzzer.cpp @@ -16,6 +16,8 @@ #include "velox/exec/fuzzer/JoinFuzzer.h" #include #include "velox/common/file/FileSystems.h" +#include "velox/common/testutil/TempDirectoryPath.h" +#include "velox/connectors/ConnectorRegistry.h" #include "velox/connectors/hive/HiveConnector.h" #include "velox/dwio/dwrf/RegisterDwrfReader.h" #include "velox/dwio/dwrf/RegisterDwrfWriter.h" @@ -25,7 +27,6 @@ #include "velox/exec/fuzzer/JoinMaker.h" #include "velox/exec/fuzzer/ReferenceQueryRunner.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/vector/fuzzer/VectorFuzzer.h" DEFINE_int32(steps, 10, "Number of plans to generate and test."); @@ -66,6 +67,7 @@ DEFINE_double( "The chance of testing plans with filters enabled."); namespace facebook::velox::exec { +using namespace facebook::velox::common::testutil; namespace { @@ -88,6 +90,7 @@ class JoinFuzzer { opts.stringVariableLength = true; opts.stringLength = 100; opts.nullRatio = FLAGS_null_ratio; + opts.useRandomNullPattern = true; opts.timestampPrecision = VectorFuzzer::Options::TimestampPrecision::kMilliSeconds; return opts; @@ -210,7 +213,8 @@ JoinFuzzer::JoinFuzzer( auto hiveConnector = factory.newConnector( test::kHiveConnectorId, std::make_shared(std::move(hiveConfig))); - connector::registerConnector(hiveConnector); + connector::ConnectorRegistry::global().insert( + hiveConnector->connectorId(), hiveConnector); dwrf::registerDwrfReaderFactory(); dwrf::registerDwrfWriterFactory(); @@ -389,10 +393,10 @@ RowVectorPtr JoinFuzzer::execute( builder.numConcurrentSplitGroups(randInt(1, plan.numGroups)); } - std::shared_ptr spillDirectory; + std::shared_ptr spillDirectory; int32_t spillPct{0}; if (injectSpill) { - spillDirectory = exec::test::TempDirectoryPath::create(); + spillDirectory = TempDirectoryPath::create(); builder.config(core::QueryConfig::kSpillEnabled, true) .config(core::QueryConfig::kJoinSpillEnabled, true) .config(core::QueryConfig::kMixedGroupedModeHashJoinSpillEnabled, true) @@ -617,7 +621,7 @@ void JoinFuzzer::verify(core::JoinType joinType) { } } - const auto tableScanDir = exec::test::TempDirectoryPath::create(); + const auto tableScanDir = TempDirectoryPath::create(); auto localFs = filesystems::getFileSystem(tableScanDir->getPath(), nullptr); std::string probePath = fmt::format("{}/{}", tableScanDir->getPath(), "probe"); diff --git a/velox/exec/fuzzer/JoinFuzzerRunner.cpp b/velox/exec/fuzzer/JoinFuzzerRunner.cpp index 7c350643793..b8eefa21b67 100644 --- a/velox/exec/fuzzer/JoinFuzzerRunner.cpp +++ b/velox/exec/fuzzer/JoinFuzzerRunner.cpp @@ -101,18 +101,15 @@ int main(int argc, char** argv) { facebook::velox::filesystems::registerLocalFileSystem(); facebook::velox::functions::prestosql::registerAllScalarFunctions(); facebook::velox::parse::registerTypeResolver(); - if (!isRegisteredNamedVectorSerde( - facebook::velox::VectorSerde::Kind::kPresto)) { + if (!facebook::velox::isRegisteredNamedVectorSerde("Presto")) { facebook::velox::serializer::presto::PrestoVectorSerde:: registerNamedVectorSerde(); } - if (!isRegisteredNamedVectorSerde( - facebook::velox::VectorSerde::Kind::kCompactRow)) { + if (!facebook::velox::isRegisteredNamedVectorSerde("CompactRow")) { facebook::velox::serializer::CompactRowVectorSerde:: registerNamedVectorSerde(); } - if (!isRegisteredNamedVectorSerde( - facebook::velox::VectorSerde::Kind::kUnsafeRow)) { + if (!facebook::velox::isRegisteredNamedVectorSerde("UnsafeRow")) { facebook::velox::serializer::spark::UnsafeRowVectorSerde:: registerNamedVectorSerde(); } diff --git a/velox/exec/fuzzer/JoinMaker.cpp b/velox/exec/fuzzer/JoinMaker.cpp index 07e2b66d6b3..0addc7f94ba 100644 --- a/velox/exec/fuzzer/JoinMaker.cpp +++ b/velox/exec/fuzzer/JoinMaker.cpp @@ -62,10 +62,11 @@ std::vector makeSourcesForPartitionedJoinPlan( std::vector sourceNodes; for (const auto& sourceInput : sourceInputs) { - sourceNodes.push_back(test::PlanBuilder(planNodeIdGenerator) - .values(sourceInput) - .projectExpressions(joinSource->projections()) - .planNode()); + sourceNodes.push_back( + test::PlanBuilder(planNodeIdGenerator) + .values(sourceInput) + .projectExpressions(joinSource->projections()) + .planNode()); } return sourceNodes; @@ -610,7 +611,7 @@ core::JoinType flipJoinType(core::JoinType joinType) { auto flippedJoinType = tryFlipJoinType(joinType); if (!flippedJoinType.has_value()) { - VELOX_UNSUPPORTED(fmt::format("Unable to flip join type: {}", joinType)); + VELOX_UNSUPPORTED("Unable to flip join type: {}", joinType); } return *flippedJoinType; diff --git a/velox/exec/fuzzer/LocalRunnerService.cpp b/velox/exec/fuzzer/LocalRunnerService.cpp new file mode 100644 index 00000000000..2475f6af40e --- /dev/null +++ b/velox/exec/fuzzer/LocalRunnerService.cpp @@ -0,0 +1,168 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/fuzzer/if/gen-cpp2/LocalRunnerService.h" +#include +#include +#include +#include +#include +#include +#include +#include "velox/common/base/Exceptions.h" +#include "velox/common/memory/ByteStream.h" +#include "velox/exec/fuzzer/LocalRunnerService.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/serializers/PrestoSerializer.h" + +using namespace facebook::velox; +using namespace facebook::velox::runner; + +namespace facebook::velox::runner { +namespace { + +class StdoutCapture { + public: + StdoutCapture() { + oldCoutBuf_ = std::cout.rdbuf(); + std::cout.rdbuf(buffer_.rdbuf()); + } + ~StdoutCapture() { + std::cout.rdbuf(oldCoutBuf_); + } + std::string str() { + return buffer_.str(); + } + + private: + std::stringstream buffer_; + std::streambuf* oldCoutBuf_; +}; + +std::pair execute( + const std::string& serializedPlan, + const std::string& queryId, + std::shared_ptr pool) { + StdoutCapture stdoutCapture; + + core::PlanNodePtr plan; + try { + folly::dynamic planJson = folly::parseJson(serializedPlan); + VLOG(1) << "Deserializing plan:\n" << serializedPlan; + plan = core::PlanNode::deserialize(planJson, pool.get()); + } catch (const std::exception& e) { + throw std::runtime_error( + fmt::format("Failed to deserialize plan: {}", e.what())); + } + VLOG(1) << "Deserialized plan:\n" << plan->toString(true, true); + + try { + exec::test::AssertQueryBuilder queryBuilder(plan); + queryBuilder.config("session_timezone", "America/Los_Angeles"); + + std::shared_ptr task; + auto results = queryBuilder.copyResults(pool.get(), task); + + return {results, stdoutCapture.str()}; + } catch (const std::exception& e) { + throw std::runtime_error( + fmt::format("Error executing query: {}", e.what())); + } +} + +} // namespace + +std::string serializeBatch( + const RowVectorPtr& rowVector, + memory::MemoryPool* pool) { + std::ostringstream out; + + OStreamOutputStream outputStream(&out); + + auto serde = std::make_unique(); + serializer::presto::PrestoVectorSerde::PrestoOptions options; + + auto serializer = serde->createBatchSerializer(pool, &options); + serializer->serialize(rowVector, &outputStream); + + return out.str(); +} + +std::vector convertToBatches( + const std::vector& rowVectors, + memory::MemoryPool* pool) { + std::vector results; + + if (rowVectors.empty()) { + return results; + } + + auto leafPool = pool->addLeafChild("batchSerialization"); + + for (const auto& rowVector : rowVectors) { + Batch result; + const auto& rowType = rowVector->type()->asRow(); + + for (auto i = 0; i < rowType.size(); ++i) { + result.columnNames()->push_back(rowType.nameOf(i)); + result.columnTypes()->push_back(rowType.childAt(i)->toString()); + } + + std::string serializedData = serializeBatch(rowVector, leafPool.get()); + result.serializedData() = std::move(serializedData); + + results.push_back(std::move(result)); + } + + return results; +} + +void LocalRunnerServiceHandler::execute( + ExecutePlanResponse& response, + std::unique_ptr request) { + VLOG(1) << "Received executePlan request"; + + auto rootPool = memory::memoryManager()->addRootPool(); + auto pool = rootPool->addLeafChild("localRunnerHandler"); + + RowVectorPtr results; + std::string output; + + try { + VLOG(1) << "Executing plan in service handler"; + std::tie(results, output) = + ::execute(*request->serializedPlan(), *request->queryId(), pool); + + VLOG(1) << fmt::format( + "Result:\nresult rowVector: {}\nstdout: {}", + results->toString(true), + output); + } catch (const std::exception& e) { + VLOG(1) << "Exception executing plan: " << e.what(); + response.success() = false; + response.errorMessage() = e.what(); + return; + } + + VLOG(1) << "Converting results to Thrift response"; + auto resultBatches = convertToBatches({results}, rootPool.get()); + response.results() = std::move(resultBatches); + response.output() = output; + response.success() = true; + VLOG(1) << "Response sent"; +} + +} // namespace facebook::velox::runner diff --git a/velox/exec/fuzzer/LocalRunnerService.h b/velox/exec/fuzzer/LocalRunnerService.h new file mode 100644 index 00000000000..dbd5f5d8f17 --- /dev/null +++ b/velox/exec/fuzzer/LocalRunnerService.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/// Thrift service implementation and library for executing Velox query plans +/// remotely. +/// +/// This file provides conversion utilities and a service handler for the +/// LocalRunnerService. It enables remote execution of serialized Velox +/// expression evaluation primarily used for fuzzing where query plans need to +/// be executed on remote workers. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "velox/exec/fuzzer/if/gen-cpp2/LocalRunnerService.h" +#include "velox/vector/ComplexVector.h" + +namespace facebook::velox::runner { + +/// Converts a collection of Velox RowVectors into Thrift Batches using +/// binary serialization. +std::vector convertToBatches( + const std::vector& rowVectors, + memory::MemoryPool* pool); + +/// Thrift service handler for executing Velox query plans. +/// Executes a serialized Velox query plan. This method deserializes the plan +/// from JSON, configures execution, runs the query plan to completion, +/// converts results to Thrift Batches and captures any subsequent errors or +/// output. The method returns a Thrift response containing the results. +class LocalRunnerServiceHandler + : public apache::thrift::ServiceHandler { + public: + void execute( + ExecutePlanResponse& response, + std::unique_ptr request) override; +}; + +} // namespace facebook::velox::runner diff --git a/velox/exec/fuzzer/LocalRunnerServiceRunner.cpp b/velox/exec/fuzzer/LocalRunnerServiceRunner.cpp new file mode 100644 index 00000000000..7bb15f9d355 --- /dev/null +++ b/velox/exec/fuzzer/LocalRunnerServiceRunner.cpp @@ -0,0 +1,59 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include "velox/core/ITypedExpr.h" +#include "velox/core/PlanNode.h" +#include "velox/exec/fuzzer/LocalRunnerService.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/type/Type.h" + +using namespace facebook::velox; +using namespace facebook::velox::runner; + +DEFINE_int32( + port, + 9091, + "LocalRunnerService port number to be used in conjunction with ExpressionFuzzerTest flag local_runner_port."); + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + + folly::Init init(&argc, &argv); + + memory::initializeMemoryManager(memory::MemoryManager::Options{}); + Type::registerSerDe(); + core::PlanNode::registerSerDe(); + core::ITypedExpr::registerSerDe(); + functions::prestosql::registerAllScalarFunctions(); + functions::prestosql::registerInternalFunctions(); + + std::shared_ptr thriftServer = + std::make_shared(); + thriftServer->setPort(FLAGS_port); + thriftServer->setInterface(std::make_shared()); + thriftServer->setNumIOWorkerThreads(1); + thriftServer->setNumCPUWorkerThreads(1); + + VLOG(1) << "Starting LocalRunnerService"; + thriftServer->serve(); + + return 0; +} diff --git a/velox/exec/fuzzer/MarkDistinctFuzzer.cpp b/velox/exec/fuzzer/MarkDistinctFuzzer.cpp new file mode 100644 index 00000000000..af9ef508419 --- /dev/null +++ b/velox/exec/fuzzer/MarkDistinctFuzzer.cpp @@ -0,0 +1,223 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/fuzzer/MarkDistinctFuzzer.h" + +#include + +#include "velox/common/testutil/TempDirectoryPath.h" +#include "velox/exec/fuzzer/FuzzerUtil.h" +#include "velox/exec/fuzzer/SpillFuzzerBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" + +namespace facebook::velox::exec { +using namespace facebook::velox::common::testutil; +namespace { + +class MarkDistinctFuzzer : public SpillFuzzerBase { + public: + explicit MarkDistinctFuzzer( + size_t initialSeed, + std::unique_ptr); + + private: + void runSingleIteration() override; + + std::pair, std::vector> + generatePartitionKeys(); + + std::vector generateInput( + const std::vector& keyNames, + const std::vector& keyTypes); + + // Makes the query plan: Values -> MarkDistinct -> Aggregation. + // MarkDistinct marks distinct rows for (groupKey, distinctKey) combinations, + // then aggregation uses the mask to compute count(distinct distinctKey). + static PlanWithSplits makeDefaultPlan( + const std::string& groupKey, + const std::string& distinctKey, + const std::vector& input); + + static PlanWithSplits makePlanWithTableScan( + const RowTypePtr& type, + const std::string& groupKey, + const std::string& distinctKey, + const std::vector& splits); + + void addPlansWithTableScan( + const std::string& tableDir, + const std::string& groupKey, + const std::string& distinctKey, + const std::vector& input, + std::vector& altPlans); +}; + +MarkDistinctFuzzer::MarkDistinctFuzzer( + size_t initialSeed, + std::unique_ptr referenceQueryRunner) + : SpillFuzzerBase(initialSeed, std::move(referenceQueryRunner)) { + vectorFuzzer_.getMutableOptions().timestampPrecision = + fuzzer::FuzzerTimestampPrecision::kMilliSeconds; +} + +std::pair, std::vector> +MarkDistinctFuzzer::generatePartitionKeys() { + // Generate 1-3 keys for grouping/distinct. + const auto numKeys = randInt(1, 3); + std::vector names; + std::vector types; + for (auto i = 0; i < numKeys; ++i) { + names.push_back(fmt::format("c{}", i)); + types.push_back(vectorFuzzer_.randType(/*maxDepth=*/1)); + } + return std::make_pair(names, types); +} + +std::vector MarkDistinctFuzzer::generateInput( + const std::vector& keyNames, + const std::vector& keyTypes) { + std::vector names = keyNames; + std::vector types = keyTypes; + + // Add up to 2 payload columns. + const auto numPayload = randInt(0, 2); + for (auto i = 0; i < numPayload; ++i) { + names.push_back(fmt::format("c{}", i + keyNames.size())); + types.push_back(vectorFuzzer_.randType(/*maxDepth=*/2)); + } + + const auto inputType = ROW(std::move(names), std::move(types)); + std::vector input; + input.reserve(FLAGS_num_batches); + for (auto i = 0; i < FLAGS_num_batches; ++i) { + input.push_back(vectorFuzzer_.fuzzInputRow(inputType)); + } + + return input; +} + +PlanWithSplits MarkDistinctFuzzer::makeDefaultPlan( + const std::string& groupKey, + const std::string& distinctKey, + const std::vector& input) { + // Build plan: Values -> MarkDistinct -> Aggregation + // This is equivalent to: SELECT groupKey, count(DISTINCT distinctKey) + // FROM input GROUP BY groupKey + auto plan = test::PlanBuilder() + .values(input) + .markDistinct("distinct_marker", {groupKey, distinctKey}) + .singleAggregation( + {groupKey}, + {"count(\"" + distinctKey + "\")"}, + {"distinct_marker"}) + .planNode(); + return PlanWithSplits{std::move(plan)}; +} + +PlanWithSplits MarkDistinctFuzzer::makePlanWithTableScan( + const RowTypePtr& type, + const std::string& groupKey, + const std::string& distinctKey, + const std::vector& splits) { + auto plan = test::PlanBuilder() + .tableScan(type) + .markDistinct("distinct_marker", {groupKey, distinctKey}) + .singleAggregation( + {groupKey}, + {"count(\"" + distinctKey + "\")"}, + {"distinct_marker"}) + .planNode(); + return PlanWithSplits{plan, splits}; +} + +void MarkDistinctFuzzer::addPlansWithTableScan( + const std::string& tableDir, + const std::string& groupKey, + const std::string& distinctKey, + const std::vector& input, + std::vector& altPlans) { + VELOX_CHECK(!tableDir.empty()); + + if (!isTableScanSupported(input[0]->type())) { + return; + } + + const std::vector inputSplits = test::makeSplits( + input, fmt::format("{}/mark_distinct", tableDir), writerPool_); + altPlans.push_back(makePlanWithTableScan( + asRowType(input[0]->type()), groupKey, distinctKey, inputSplits)); +} + +void MarkDistinctFuzzer::runSingleIteration() { + const auto [keyNames, keyTypes] = generatePartitionKeys(); + + // We need at least 2 keys: one for GROUP BY and one for DISTINCT. + // If only 1 key was generated, add another one. + std::vector allNames = keyNames; + std::vector allTypes = keyTypes; + if (allNames.size() < 2) { + allNames.push_back(fmt::format("c{}", allNames.size())); + allTypes.push_back(vectorFuzzer_.randType(/*maxDepth=*/1)); + } + + const auto input = generateInput(allNames, allTypes); + test::logVectors(input); + + // Use first key as group-by key, second as distinct key. + const auto& groupKey = allNames[0]; + const auto& distinctKey = allNames[1]; + + auto defaultPlan = makeDefaultPlan(groupKey, distinctKey, input); + const auto expected = execute(defaultPlan, /*injectSpill=*/false, false); + + // Validate against DuckDB using an equivalent plan without MarkDistinct. + // DuckDB cannot translate MarkDistinctNode to SQL, so we build a reference + // plan that uses count(DISTINCT ...) directly. + if (expected != nullptr) { + auto referencePlan = + test::PlanBuilder() + .values(input) + .singleAggregation( + {groupKey}, + {fmt::format("count(DISTINCT \"{}\")", distinctKey)}) + .planNode(); + validateExpectedResults(referencePlan, input, expected); + } + + std::vector altPlans; + altPlans.push_back(std::move(defaultPlan)); + + const auto tableScanDir = TempDirectoryPath::create(); + addPlansWithTableScan( + tableScanDir->getPath(), groupKey, distinctKey, input, altPlans); + + for (auto i = 0; i < altPlans.size(); ++i) { + testPlan( + altPlans[i], + i, + expected, + "core::QueryConfig::kMarkDistinctSpillEnabled"); + } +} + +} // namespace + +void markDistinctFuzzer( + size_t seed, + std::unique_ptr referenceQueryRunner) { + MarkDistinctFuzzer(seed, std::move(referenceQueryRunner)).run(); +} +} // namespace facebook::velox::exec diff --git a/velox/experimental/wave/dwio/nimble/fuzzer/NimbleReaderFuzzer.h b/velox/exec/fuzzer/MarkDistinctFuzzer.h similarity index 50% rename from velox/experimental/wave/dwio/nimble/fuzzer/NimbleReaderFuzzer.h rename to velox/exec/fuzzer/MarkDistinctFuzzer.h index 26c3ede8213..bcd34e62a6b 100644 --- a/velox/experimental/wave/dwio/nimble/fuzzer/NimbleReaderFuzzer.h +++ b/velox/exec/fuzzer/MarkDistinctFuzzer.h @@ -13,30 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once #include #include "velox/exec/fuzzer/ReferenceQueryRunner.h" -namespace facebook::velox::wave { -struct NimbleReaderFuzzerOptions { - size_t minNumStreams{1}; - size_t maxNumStreams{10}; - size_t minNumChunks{1}; - size_t maxNumChunks{10}; - size_t minChunkSize{1}; - size_t minNumValues{10}; - size_t maxNumValues{10000}; - double hasNullProbability{0.5}; - double hasFilterProbability{0.75}; - double isFilterProbability{0.5}; - double nullAllowedProbability{0}; - double filterKeepValuesProbability{1.0}; - std::vector types{INTEGER(), BIGINT(), REAL(), DOUBLE()}; -}; -void nimbleReaderFuzzer( +namespace facebook::velox::exec { +void markDistinctFuzzer( size_t seed, - NimbleReaderFuzzerOptions options, - std::unique_ptr referenceQueryRunner); -} // namespace facebook::velox::wave + std::unique_ptr referenceQueryRunner); +} diff --git a/velox/exec/fuzzer/MarkDistinctFuzzerRunner.cpp b/velox/exec/fuzzer/MarkDistinctFuzzerRunner.cpp new file mode 100644 index 00000000000..3e2bb69e5c5 --- /dev/null +++ b/velox/exec/fuzzer/MarkDistinctFuzzerRunner.cpp @@ -0,0 +1,97 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "velox/common/memory/SharedArbitrator.h" +#include "velox/exec/fuzzer/FuzzerUtil.h" +#include "velox/exec/fuzzer/MarkDistinctFuzzer.h" +#include "velox/exec/fuzzer/ReferenceQueryRunner.h" +#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" + +/// MarkDistinctFuzzerRunner leverages MarkDistinctFuzzer and VectorFuzzer to +/// automatically generate and execute tests. It works as follows: +/// +/// 1. Plan Generation: Generate two equivalent query plans, one is +/// mark-distinct + aggregation over ValuesNode and the other is over +/// TableScanNode. +/// 2. Executes a variety of logically equivalent query plans and checks the +/// results are the same. +/// 3. Rinse and repeat. +/// +/// It is used as follows: +/// +/// $ ./velox_mark_distinct_fuzzer --duration_sec 600 +/// +/// The flags that configure MarkDistinctFuzzer's behavior are: +/// +/// --steps: how many iterations to run. +/// --duration_sec: alternatively, for how many seconds it should run (takes +/// precedence over --steps). +/// --seed: pass a deterministic seed to reproduce the behavior (each iteration +/// will print a seed as part of the logs). +/// --v=1: verbose logging; print a lot more details about the execution. +/// --batch_size: size of input vector batches generated. +/// --num_batches: number of input vector batches to generate. +/// --enable_spill: test plans with spilling enabled. +/// --enable_oom_injection: randomly trigger OOM while executing query plans. +/// e.g: +/// +/// $ ./velox_mark_distinct_fuzzer \ +/// --seed 123 \ +/// --duration_sec 600 \ +/// --v=1 + +DEFINE_int64( + seed, + 0, + "Initial seed for random number generator used to reproduce previous " + "results (0 means start with random seed)."); + +DEFINE_string( + presto_url, + "", + "Presto coordinator URI along with port. If set, we use Presto " + "source of truth. Otherwise, use DuckDB. Example: " + "--presto_url=http://127.0.0.1:8080"); + +DEFINE_uint32( + req_timeout_ms, + 1000, + "Timeout in milliseconds for HTTP requests made to reference DB, " + "such as Presto. Example: --req_timeout_ms=2000"); + +DEFINE_int64(allocator_capacity, 8L << 30, "Allocator capacity in bytes."); + +DEFINE_int64(arbitrator_capacity, 6L << 30, "Arbitrator capacity in bytes."); + +using namespace facebook::velox; + +int main(int argc, char** argv) { + folly::Init init(&argc, &argv); + exec::test::setupMemory(FLAGS_allocator_capacity, FLAGS_arbitrator_capacity); + aggregate::prestosql::registerAllAggregateFunctions(); + std::shared_ptr rootPool{ + memory::memoryManager()->addRootPool()}; + auto referenceQueryRunner = exec::test::setupReferenceQueryRunner( + rootPool.get(), + FLAGS_presto_url, + "mark_distinct_fuzzer", + FLAGS_req_timeout_ms); + const size_t initialSeed = FLAGS_seed == 0 ? std::time(nullptr) : FLAGS_seed; + exec::markDistinctFuzzer(initialSeed, std::move(referenceQueryRunner)); +} diff --git a/velox/exec/fuzzer/MemoryArbitrationFuzzer.cpp b/velox/exec/fuzzer/MemoryArbitrationFuzzer.cpp index 1d855140727..b0c0078f88b 100644 --- a/velox/exec/fuzzer/MemoryArbitrationFuzzer.cpp +++ b/velox/exec/fuzzer/MemoryArbitrationFuzzer.cpp @@ -18,18 +18,20 @@ #include #include +#include #include "velox/common/file/FileSystems.h" #include "velox/common/file/tests/FaultyFileSystem.h" #include "velox/common/fuzzer/Utils.h" +#include "velox/common/testutil/TempDirectoryPath.h" +#include "velox/connectors/ConnectorRegistry.h" #include "velox/connectors/hive/HiveConnector.h" -#include "velox/dwio/dwrf/RegisterDwrfReader.h" // @manual -#include "velox/dwio/dwrf/RegisterDwrfWriter.h" // @manual +#include "velox/dwio/dwrf/RegisterDwrfReader.h" +#include "velox/dwio/dwrf/RegisterDwrfWriter.h" #include "velox/exec/MemoryReclaimer.h" #include "velox/exec/fuzzer/FuzzerUtil.h" #include "velox/exec/tests/utils/ArbitratorTestUtil.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/serializers/CompactRowSerializer.h" #include "velox/serializers/PrestoSerializer.h" #include "velox/serializers/UnsafeRowSerializer.h" @@ -99,9 +101,16 @@ DEFINE_int32( "After each specified number of milliseconds, abort a random task." "If given 0, no task will be aborted."); +DEFINE_string( + plan_type, + "all", + "Type of plans to test. Options: all, hash_join, aggregate, " + "row_number, topn_row_number, order_by."); + using namespace facebook::velox::tests::utils; namespace facebook::velox::exec { +using namespace facebook::velox::common::testutil; namespace { using fuzzer::coinToss; @@ -150,7 +159,7 @@ class MemoryArbitrationFuzzer { return boost::random::uniform_int_distribution(min, max)(rng_); } - std::shared_ptr maybeGenerateFaultySpillDirectory(); + std::shared_ptr maybeGenerateFaultySpillDirectory(); // Returns a list of randomly generated key types for join and aggregation. std::vector generateKeyTypes(int32_t numKeys); @@ -161,7 +170,8 @@ class MemoryArbitrationFuzzer { // Returns randomly generated input with up to 3 additional payload columns. std::vector generateInput( const std::vector& keyNames, - const std::vector& keyTypes); + const std::vector& keyTypes, + int32_t minPayload = 0); // Reuses the 'generateInput' method to return randomly generated // probe input. @@ -181,6 +191,12 @@ class MemoryArbitrationFuzzer { const std::vector& keyNames, const std::vector& keyTypes); + // Reuses the 'generateInput' method to return randomly generated + // topN row number input. + std::vector generateTopNRowNumberInput( + const std::vector& keyNames, + const std::vector& keyTypes); + // Reuses the 'generateInput' method to return randomly generated // order by input. std::vector generateOrderByInput( @@ -210,6 +226,8 @@ class MemoryArbitrationFuzzer { std::vector rowNumberPlans(const std::string& tableDir); + std::vector topNRowNumberPlans(const std::string& tableDir); + std::vector orderByPlans(const std::string& tableDir); // Helper method that combines all above plan methods into one. @@ -237,6 +255,7 @@ class MemoryArbitrationFuzzer { {core::QueryConfig::kSpillStartPartitionBit, "29"}, {core::QueryConfig::kAggregationSpillEnabled, "true"}, {core::QueryConfig::kRowNumberSpillEnabled, "true"}, + {core::QueryConfig::kTopNRowNumberSpillEnabled, "true"}, {core::QueryConfig::kOrderBySpillEnabled, "true"}, }; @@ -256,7 +275,7 @@ class MemoryArbitrationFuzzer { VectorFuzzer vectorFuzzer_; std::shared_ptr executor_{ std::make_shared( - std::thread::hardware_concurrency())}; + folly::available_concurrency())}; folly::Synchronized stats_; }; @@ -266,13 +285,13 @@ MemoryArbitrationFuzzer::MemoryArbitrationFuzzer(size_t initialSeed) // paritition key, and presto doesn't supports nanosecond precision. vectorFuzzer_.getMutableOptions().timestampPrecision = fuzzer::FuzzerTimestampPrecision::kMilliSeconds; - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kPresto)) { + if (!isRegisteredNamedVectorSerde("Presto")) { serializer::presto::PrestoVectorSerde::registerNamedVectorSerde(); } - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kCompactRow)) { + if (!isRegisteredNamedVectorSerde("CompactRow")) { serializer::CompactRowVectorSerde::registerNamedVectorSerde(); } - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kUnsafeRow)) { + if (!isRegisteredNamedVectorSerde("UnsafeRow")) { serializer::spark::UnsafeRowVectorSerde::registerNamedVectorSerde(); } // Make sure not to run out of open file descriptors. @@ -283,7 +302,8 @@ MemoryArbitrationFuzzer::MemoryArbitrationFuzzer(size_t initialSeed) const auto hiveConnector = hiveFactory.newConnector( test::kHiveConnectorId, std::make_shared(std::move(hiveConfig))); - connector::registerConnector(hiveConnector); + connector::ConnectorRegistry::global().insert( + hiveConnector->connectorId(), hiveConnector); dwrf::registerDwrfReaderFactory(); dwrf::registerDwrfWriterFactory(); @@ -325,7 +345,8 @@ MemoryArbitrationFuzzer::generatePartitionKeys() { std::vector MemoryArbitrationFuzzer::generateInput( const std::vector& keyNames, - const std::vector& keyTypes) { + const std::vector& keyTypes, + int32_t minPayload) { std::vector names = keyNames; std::vector types = keyTypes; @@ -337,8 +358,7 @@ std::vector MemoryArbitrationFuzzer::generateInput( } } - // Add up to 3 payload columns. - const auto numPayload = randInt(0, 3); + const auto numPayload = randInt(minPayload, 3); for (auto i = 0; i < numPayload; ++i) { names.push_back(fmt::format("tp{}", i + keyNames.size())); types.push_back(vectorFuzzer_.randType(2 /*maxDepth*/)); @@ -433,6 +453,12 @@ std::vector MemoryArbitrationFuzzer::generateRowNumberInput( return generateInput(keyNames, keyTypes); } +std::vector MemoryArbitrationFuzzer::generateTopNRowNumberInput( + const std::vector& keyNames, + const std::vector& keyTypes) { + return generateInput(keyNames, keyTypes, 1); +} + std::vector MemoryArbitrationFuzzer::generateOrderByInput( const std::vector& keyNames, const std::vector& keyTypes) { @@ -503,9 +529,10 @@ MemoryArbitrationFuzzer::hashJoinPlans( joinType, false) .planNode(); - plans.push_back(PlanWithSplits{ - std::move(plan), - {{probeScanId, probeSplits}, {buildScanId, buildSplits}}}); + plans.push_back( + PlanWithSplits{ + std::move(plan), + {{probeScanId, probeSplits}, {buildScanId, buildSplits}}}); return plans; } @@ -680,6 +707,90 @@ MemoryArbitrationFuzzer::rowNumberPlans(const std::string& tableDir) { return plans; } +std::vector +MemoryArbitrationFuzzer::topNRowNumberPlans(const std::string& tableDir) { + static const std::vector kRankFunctions = { + "row_number", "rank", "dense_rank"}; + + const auto [keyNames, keyTypes] = generatePartitionKeys(); + const auto input = generateTopNRowNumberInput(keyNames, keyTypes); + + std::vector plans; + + const auto inputType = asRowType(input[0]->type()); + std::vector sortingKeys; + + std::unordered_set partitionKeySet( + keyNames.begin(), keyNames.end()); + for (const auto& name : inputType->names()) { + if (partitionKeySet.find(name) == partitionKeySet.end()) { + sortingKeys.push_back(name); + } + } + + const auto numSortingKeys = randInt(1, sortingKeys.size()); + sortingKeys.resize(numSortingKeys); + + const auto rankFunction = + kRankFunctions[randInt(0, kRankFunctions.size() - 1)]; + const auto limit = randInt(1, 100); + const bool generateRowNumber = vectorFuzzer_.coinToss(0.5); + + std::vector projectFields = keyNames; + if (generateRowNumber) { + projectFields.emplace_back("row_number"); + } + + // Values plan with Partiton Keys + auto plan = PlanWithSplits{ + test::PlanBuilder() + .values(input) + .topNRank( + rankFunction, keyNames, sortingKeys, limit, generateRowNumber) + .project(projectFields) + .planNode(), + {}}; + plans.push_back(std::move(plan)); + + if (!test::isTableScanSupported(input[0]->type())) { + return plans; + } + + const std::vector splits = test::makeSplits( + input, fmt::format("{}/topn_row_number", tableDir), writerPool_); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId scanId; + // TableScan Plan with Parition Keys + plan = PlanWithSplits{ + test::PlanBuilder(planNodeIdGenerator) + .tableScan(asRowType(input[0]->type())) + .capturePlanNodeId(scanId) + .topNRank( + rankFunction, keyNames, sortingKeys, limit, generateRowNumber) + .project(projectFields) + .planNode(), + {{scanId, splits}}}; + plans.push_back(std::move(plan)); + + std::vector globalProjectFields; + if (generateRowNumber) { + globalProjectFields.emplace_back("row_number"); + } + + // Global TopN + plan = PlanWithSplits{ + test::PlanBuilder() + .values(input) + .topNRank(rankFunction, {}, sortingKeys, limit, generateRowNumber) + .project(globalProjectFields) + .planNode(), + {}}; + plans.push_back(std::move(plan)); + + return plans; +} + std::vector MemoryArbitrationFuzzer::orderByPlans(const std::string& tableDir) { const auto [keyNames, keyTypes] = generatePartitionKeys(); @@ -716,18 +827,40 @@ MemoryArbitrationFuzzer::orderByPlans(const std::string& tableDir) { std::vector MemoryArbitrationFuzzer::allPlans(const std::string& tableDir) { std::vector plans; - for (const auto& plan : hashJoinPlans(tableDir)) { - plans.push_back(plan); - } - for (const auto& plan : aggregatePlans(tableDir)) { - plans.push_back(plan); - } - for (const auto& plan : rowNumberPlans(tableDir)) { - plans.push_back(plan); - } - for (const auto& plan : orderByPlans(tableDir)) { - plans.push_back(plan); - } + const std::string planType = FLAGS_plan_type; + + auto appendPlansIf = + [&](const std::string& type, + std::function(const std::string&)> + planGenerator) { + if (planType == "all" || planType == type) { + auto newPlans = planGenerator(tableDir); + plans.insert( + plans.end(), + std::make_move_iterator(newPlans.begin()), + std::make_move_iterator(newPlans.end())); + } + }; + appendPlansIf("hash_join", [this](const std::string& dir) { + return hashJoinPlans(dir); + }); + appendPlansIf("aggregate", [this](const std::string& dir) { + return aggregatePlans(dir); + }); + appendPlansIf("row_number", [this](const std::string& dir) { + return rowNumberPlans(dir); + }); + appendPlansIf("topn_row_number", [this](const std::string& dir) { + return topNRowNumberPlans(dir); + }); + appendPlansIf( + "order_by", [this](const std::string& dir) { return orderByPlans(dir); }); + + VELOX_USER_CHECK( + !plans.empty(), + "No plans generated for plan_type: {}. Valid options are: all, hash_join, aggregate, row_number, topn_row_number, order_by", + planType); + return plans; } @@ -746,12 +879,12 @@ std::string MemoryArbitrationFuzzer::extractQueryIdFromSpillPath( // Stats that keeps track of per thread execution status in verify() folly::ConcurrentHashMap spillFsTaskSet; -std::shared_ptr +std::shared_ptr MemoryArbitrationFuzzer::maybeGenerateFaultySpillDirectory() { FuzzerGenerator fsRng(rng_()); const auto injectFsFault = coinToss(fsRng, FLAGS_spill_faulty_fs_ratio); if (!injectFsFault) { - return exec::test::TempDirectoryPath::create(false); + return TempDirectoryPath::create(false); } using OpType = FaultFileOperation::Type; static const std::vector> opTypes{ @@ -762,7 +895,7 @@ MemoryArbitrationFuzzer::maybeGenerateFaultySpillDirectory() { {OpType::kRead, OpType::kWrite}, {OpType::kReadv, OpType::kWrite}}; - const auto directory = exec::test::TempDirectoryPath::create(true); + const auto directory = TempDirectoryPath::create(true); auto faultyFileSystem = std::dynamic_pointer_cast( filesystems::getFileSystem(directory->getPath(), nullptr)); faultyFileSystem->setFileInjectionHook( @@ -788,7 +921,7 @@ MemoryArbitrationFuzzer::maybeGenerateFaultySpillDirectory() { void MemoryArbitrationFuzzer::verify() { auto spillDirectory = maybeGenerateFaultySpillDirectory(); - const auto tableScanDir = exec::test::TempDirectoryPath::create(false); + const auto tableScanDir = TempDirectoryPath::create(false); auto plans = allPlans(tableScanDir->getPath()); @@ -824,7 +957,10 @@ void MemoryArbitrationFuzzer::verify() { const auto plan = plans.at(getRandomIndex(rng, plans.size() - 1)); test::AssertQueryBuilder builder(plan.plan); - builder.queryCtx(queryCtx); + // Use a long timeout (1 hour) to avoid false failures from CI thread + // starvation while still catching real deadlocks. + static constexpr uint64_t kOneHourUs{3'600'000'000ULL}; + builder.queryCtx(queryCtx).maxWaitMicros(kOneHourUs); for (const auto& [planNodeId, nodeSplits] : plan.splits) { builder.splits(planNodeId, nodeSplits); } @@ -856,21 +992,6 @@ void MemoryArbitrationFuzzer::verify() { const auto injectedTaskAbortRequest = queryTaskAbortRequestMap.find(queryId)->second; - // Debug logging to understand the failure - if (!injectedSpillFsFault && !injectedTaskAbortRequest) { - LOG(ERROR) << "============== VELOX_CHECK failure debug info:"; - LOG(ERROR) << " queryId: " << queryId; - LOG(ERROR) << " spillFsTaskSet size: " << spillFsTaskSet.size(); - LOG(ERROR) << " spillFsTaskSet contents:"; - // Iterate through spillFsTaskSet to log contents - for (auto it = spillFsTaskSet.cbegin(); - it != spillFsTaskSet.cend(); - ++it) { - LOG(ERROR) << " key: " << it->first; - } - LOG(ERROR) << " error message: " << e.message(); - } - VELOX_CHECK( injectedSpillFsFault || injectedTaskAbortRequest, "injectedSpillFsFault: {}, injectedTaskAbortRequest: {}, error message: {}", @@ -967,13 +1088,13 @@ void MemoryArbitrationFuzzer::go() { size_t iteration = 0; while (!isDone(iteration, startTime)) { - LOG(WARNING) << "==============================> Started iteration " - << iteration << " (seed: " << currentSeed_ << ")"; + LOG(INFO) << "==============================> Started iteration " + << iteration << " (seed: " << currentSeed_ << ")"; verify(); + stats_.rlock()->print(); LOG(INFO) << "==============================> Done with iteration " << iteration; - stats_.rlock()->print(); reSeed(); ++iteration; diff --git a/velox/exec/fuzzer/MemoryArbitrationFuzzerRunner.cpp b/velox/exec/fuzzer/MemoryArbitrationFuzzerRunner.cpp index cfa1e928c1d..01f3fccc038 100644 --- a/velox/exec/fuzzer/MemoryArbitrationFuzzerRunner.cpp +++ b/velox/exec/fuzzer/MemoryArbitrationFuzzerRunner.cpp @@ -28,6 +28,7 @@ #include "velox/exec/fuzzer/MemoryArbitrationFuzzer.h" #include "velox/exec/fuzzer/PrestoQueryRunner.h" #include "velox/exec/fuzzer/ReferenceQueryRunner.h" +#include "velox/serializers/PrestoSerializer.h" DEFINE_int64(allocator_capacity, 32L << 30, "Allocator capacity in bytes."); diff --git a/velox/exec/fuzzer/PrestoQueryRunner.cpp b/velox/exec/fuzzer/PrestoQueryRunner.cpp index 27b524a3af5..da32ed01f0d 100644 --- a/velox/exec/fuzzer/PrestoQueryRunner.cpp +++ b/velox/exec/fuzzer/PrestoQueryRunner.cpp @@ -14,9 +14,10 @@ * limitations under the License. */ -#include // @manual +#include #include #include +#include #include #include "velox/common/base/Fs.h" @@ -38,9 +39,13 @@ #include "velox/functions/prestosql/types/IPAddressType.h" #include "velox/functions/prestosql/types/IPPrefixType.h" #include "velox/functions/prestosql/types/JsonType.h" +#include "velox/functions/prestosql/types/KHyperLogLogType.h" +#include "velox/functions/prestosql/types/PrestoTypes.h" #include "velox/functions/prestosql/types/QDigestType.h" +#include "velox/functions/prestosql/types/SetDigestType.h" #include "velox/functions/prestosql/types/SfmSketchType.h" #include "velox/functions/prestosql/types/TDigestType.h" +#include "velox/functions/prestosql/types/TimeWithTimezoneType.h" #include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" #include "velox/functions/prestosql/types/UuidType.h" #include "velox/functions/prestosql/types/parser/TypeParser.h" @@ -51,6 +56,13 @@ using namespace facebook::velox; namespace facebook::velox::exec::test { namespace { +static size_t +writeFunction(char* data, size_t size, size_t nmemb, void* userdata) { + std::string* response = static_cast(userdata); + response->append(data, size * nmemb); + return size * nmemb; +} + void writeToFile( const std::string& path, const std::vector& data, @@ -132,8 +144,9 @@ class ServerResponse { std::vector types; for (const auto& column : response_["columns"]) { names.push_back(column["name"].asString()); - types.push_back(facebook::velox::functions::prestosql::parseType( - column["type"].asString())); + types.push_back( + facebook::velox::functions::prestosql::parseType( + column["type"].asString())); } auto rowType = ROW(std::move(names), std::move(types)); @@ -192,17 +205,24 @@ const std::vector& PrestoQueryRunner::supportedScalarTypes() const { VARBINARY(), TIMESTAMP(), TIMESTAMP_WITH_TIME_ZONE(), + IPADDRESS(), }; return kScalarTypes; } // static bool PrestoQueryRunner::isSupportedDwrfType(const TypePtr& type) { - if (type->isDate() || type->isIntervalDayTime() || type->isUnKnown() || + if (type->isDate() || type->isIntervalDayTime() || type->isUnknown() || isGeometryType(type)) { return false; } + // Block IPADDRESS in containers due to Presto's Int128ArrayBlock + // not supporting compareTo(). + if (containsIPAddress(type)) { + return false; + } + for (auto i = 0; i < type->size(); ++i) { if (!isSupportedDwrfType(type->childAt(i))) { return false; @@ -243,8 +263,9 @@ PrestoQueryRunner::inputProjections( children[batchIndex].push_back(input[batchIndex]->childAt(childIndex)); } - projections.push_back(std::make_shared( - names[childIndex], names[childIndex])); + projections.push_back( + std::make_shared( + names[childIndex], names[childIndex])); } } @@ -258,12 +279,13 @@ PrestoQueryRunner::inputProjections( std::vector output; output.reserve(input.size()); for (int batchIndex = 0; batchIndex < input.size(); batchIndex++) { - output.push_back(std::make_shared( - input[batchIndex]->pool(), - rowType, - input[batchIndex]->nulls(), - input[batchIndex]->size(), - std::move(children[batchIndex]))); + output.push_back( + std::make_shared( + input[batchIndex]->pool(), + rowType, + input[batchIndex]->nulls(), + input[batchIndex]->size(), + std::move(children[batchIndex]))); } return std::make_pair(output, projections); @@ -306,30 +328,48 @@ bool PrestoQueryRunner::isConstantExprSupported( // same timezone as Velox. Interval type cannot be used as the type of // constant literals in Presto SQL. auto& type = expr->type(); - return type->isPrimitiveType() && !type->isTimestamp() && + return type->isPrimitiveType() && !type->isTimestamp() && !type->isTime() && !isJsonType(type) && !type->isIntervalDayTime() && !isIPAddressType(type) && !isIPPrefixType(type) && !isUuidType(type) && !isTimestampWithTimeZoneType(type) && !isHyperLogLogType(type) && - !isTDigestType(type) && !isQDigestType(type) && !isBingTileType(type) && - !isSfmSketchType(type); - ; + !isKHyperLogLogType(type) && !isTDigestType(type) && + !isQDigestType(type) && !isSetDigestType(type) && + !isBingTileType(type) && !isSfmSketchType(type) && + !isTimeWithTimeZone(type); } return true; } bool PrestoQueryRunner::isSupported(const exec::FunctionSignature& signature) { - // TODO: support queries with these types. Among the types below, hugeint is - // not a native type in Presto, so fuzzer should not use it as the type of - // cast-to or constant literals. Hyperloglog and TDigest can only be casted - // from varbinary and cannot be used as the type of constant literals. - // Interval year to month can only be casted from NULL and cannot be used as - // the type of constant literals. Json, Ipaddress, Ipprefix, and UUID require - // special handling, because Presto requires literals of these types to be - // valid, and doesn't allow creating HIVE columns of these types. + // TODO: support queries with these types. + // Types not supported by PrestoQueryRunner and their reasons: + // + // hugeint: + // - Not a native type in Presto + // - Fuzzer should not use it for cast-to or constant literals + // + // interval year to month: + // - Can only be casted from NULL + // - Cannot be used as constant literal types + // + // ipaddress, ipprefix, uuid: + // - Require special handling in Presto + // - Presto requires literals of these types to be valid + // - Cannot create HIVE columns of these types + // + // geometry: + // - Under development in Presto + // - Cannot be used as constant literals + // - Expected differences between Presto Java and Velox C++ implementations + // + // p4hyperloglog: + // - Not a native type in Presto + // - Cannot create HIVE columns of these types return !( usesTypeName(signature, "interval year to month") || usesTypeName(signature, "hugeint") || - usesInputTypeName(signature, "ipaddress") || + usesTypeName(signature, "geometry") || usesTypeName(signature, "time") || + usesTypeName(signature, "p4hyperloglog") || usesInputTypeName(signature, "ipprefix") || usesInputTypeName(signature, "uuid")); } @@ -357,22 +397,24 @@ std::string PrestoQueryRunner::createTable( for (auto i = 0; i < inputType->size(); ++i) { appendComma(i, nullValues); nullValues << fmt::format( - "cast(null as {})", toTypeSql(inputType->childAt(i))); + "cast(null as {})", PrestoTypes::toSql(inputType->childAt(i))); } execute(fmt::format("DROP TABLE IF EXISTS {}", name)); - execute(fmt::format( - "CREATE TABLE {}({}) WITH (format = 'DWRF') AS SELECT {}", - name, - folly::join(", ", inputType->names()), - nullValues.str())); + execute( + fmt::format( + "CREATE TABLE {}({}) WITH (format = 'DWRF') AS SELECT {}", + name, + folly::join(", ", inputType->names()), + nullValues.str())); // Query Presto to find out table's location on disk. auto results = execute(fmt::format("SELECT \"$path\" FROM {}", name)); - auto filePath = extractSingleValue(results); - auto tableDirectoryPath = fs::path(filePath).parent_path(); + + // TODO: Remove explicit std::string_view cast. + auto tableDirectoryPath = fs::path(std::string_view(filePath)).parent_path(); // Delete the all-null row. execute(fmt::format("DELETE FROM {}", name)); @@ -471,38 +513,82 @@ std::vector PrestoQueryRunner::execute( std::string PrestoQueryRunner::startQuery( const std::string& sql, const std::string& sessionProperty) { - auto uri = fmt::format("{}/v1/statement?binaryResults=true", coordinatorUri_); - cpr::Url url{uri}; - cpr::Body body{sql}; - cpr::Header header( - {{"X-Presto-User", user_}, - {"X-Presto-Catalog", "hive"}, - {"X-Presto-Schema", "tpch"}, - {"Content-Type", "text/plain"}, - {"X-Presto-Session", sessionProperty}}); - cpr::Timeout timeout{timeout_}; - cpr::Response response = cpr::Post(url, body, header, timeout); + CURL* curl = curl_easy_init(); + VELOX_CHECK_NOT_NULL(curl, "Failed to initialize libcurl"); + + // Prepare curl headers + struct curl_slist* headers = nullptr; + headers = curl_slist_append( + headers, fmt::format("X-Presto-User: {}", user_).c_str()); + headers = curl_slist_append(headers, "X-Presto-Catalog: hive"); + headers = curl_slist_append(headers, "X-Presto-Schema: tpch"); + headers = curl_slist_append(headers, "Content-Type: text/plain"); + headers = curl_slist_append( + headers, fmt::format("X-Presto-Session: {}", sessionProperty).c_str()); + + std::string url = + fmt::format("{}/v1/statement?binaryResults=true", coordinatorUri_); + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + curl_easy_setopt(curl, CURLOPT_POST, 1L); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, sql.c_str()); + curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, sql.size()); + curl_easy_setopt(curl, CURLOPT_TIMEOUT_MS, timeout_); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, writeFunction); + curl_easy_setopt(curl, CURLOPT_SSLVERSION, CURL_SSLVERSION_TLSv1_2); + + std::string response; + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response); + + // Perform the request + CURLcode res = curl_easy_perform(curl); + + // Clean up CURL resources before checking the result to avoid leaks on + // error. + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + VELOX_CHECK_EQ( - response.status_code, - 200, + CURLE_OK, + res, "POST to {} failed: {}", - uri, - response.error.message); - return response.text; + coordinatorUri_, + curl_easy_strerror(res)); + + return response; } std::string PrestoQueryRunner::fetchNext(const std::string& nextUri) { - cpr::Url url(nextUri); - cpr::Header header({{"X-Presto-Client-Binary-Results", "true"}}); - cpr::Timeout timeout{timeout_}; - cpr::Response response = cpr::Get(url, header, timeout); + CURL* curl = curl_easy_init(); + VELOX_CHECK_NOT_NULL(curl, "Failed to initialize libcurl"); + + // Set up the request URL + curl_easy_setopt(curl, CURLOPT_URL, nextUri.c_str()); + + // Set up headers + struct curl_slist* headers = nullptr; + headers = curl_slist_append(headers, "X-Presto-Client-Binary-Results: true"); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + curl_easy_setopt(curl, CURLOPT_TIMEOUT_MS, timeout_); + curl_easy_setopt(curl, CURLOPT_SSLVERSION, CURL_SSLVERSION_TLSv1_2); + + // Capture the response body + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, writeFunction); + std::string response; + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response); + + // Perform GET request + CURLcode res = curl_easy_perform(curl); + + // Clean up CURL resources before checking the result to avoid leaks on + // error. + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + VELOX_CHECK_EQ( - response.status_code, - 200, - "GET from {} failed: {}", - nextUri, - response.error.message); - return response.text; + CURLE_OK, res, "Get request failed: {}", curl_easy_strerror(res)); + + return response; } bool PrestoQueryRunner::supportsVeloxVectorResults() const { @@ -510,3 +596,10 @@ bool PrestoQueryRunner::supportsVeloxVectorResults() const { } } // namespace facebook::velox::exec::test + +template <> +struct fmt::formatter : formatter { + auto format(CURLcode s, format_context& ctx) const { + return formatter::format(static_cast(s), ctx); + } +}; diff --git a/velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.cpp b/velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.cpp index 059263c9941..3d61b4ba83f 100644 --- a/velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.cpp +++ b/velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.cpp @@ -21,10 +21,14 @@ #include "velox/expression/Expr.h" #include "velox/functions/prestosql/types/BingTileType.h" #include "velox/functions/prestosql/types/HyperLogLogType.h" +#include "velox/functions/prestosql/types/IPAddressType.h" #include "velox/functions/prestosql/types/JsonType.h" +#include "velox/functions/prestosql/types/KHyperLogLogType.h" #include "velox/functions/prestosql/types/QDigestType.h" +#include "velox/functions/prestosql/types/SetDigestType.h" #include "velox/functions/prestosql/types/SfmSketchType.h" #include "velox/functions/prestosql/types/TDigestType.h" +#include "velox/functions/prestosql/types/TimeWithTimezoneType.h" #include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" #include "velox/parse/Expressions.h" #include "velox/parse/TypeResolver.h" @@ -50,6 +54,9 @@ intermediateTypeTransforms() { {HYPERLOGLOG(), std::make_shared( HYPERLOGLOG(), VARBINARY())}, + {KHYPERLOGLOG(), + std::make_shared( + KHYPERLOGLOG(), VARBINARY())}, {TDIGEST(DOUBLE()), std::make_shared( TDIGEST(DOUBLE()), VARBINARY())}, @@ -62,14 +69,26 @@ intermediateTypeTransforms() { {QDIGEST(REAL()), std::make_shared( QDIGEST(REAL()), VARBINARY())}, + {SETDIGEST(), + std::make_shared( + SETDIGEST(), VARBINARY())}, {SFMSKETCH(), std::make_shared( SFMSKETCH(), VARBINARY())}, {JSON(), std::make_shared()}, + {TIME(), + std::make_shared( + TIME(), VARCHAR())}, + {TIME_WITH_TIME_ZONE(), + std::make_shared( + TIME_WITH_TIME_ZONE(), VARCHAR())}, {BINGTILE(), std::make_shared( BINGTILE(), BIGINT())}, {INTERVAL_DAY_TIME(), std::make_shared()}, + {IPADDRESS(), + std::make_shared( + IPADDRESS(), VARCHAR())}, }; return intermediateTypeTransforms; } @@ -237,10 +256,11 @@ core::ExprPtr getProjectionForRow( rowType.nameOf(i), transformDirection)); } else { - children.push_back(std::make_shared( - rowType.nameOf(i), - rowType.nameOf(i), - std::vector{inputExpr})); + children.push_back( + std::make_shared( + rowType.nameOf(i), + rowType.nameOf(i), + std::vector{inputExpr})); } } diff --git a/velox/exec/fuzzer/PrestoQueryRunnerToSqlPlanNodeVisitor.cpp b/velox/exec/fuzzer/PrestoQueryRunnerToSqlPlanNodeVisitor.cpp index 4a828fd6813..04259b9c9c3 100644 --- a/velox/exec/fuzzer/PrestoQueryRunnerToSqlPlanNodeVisitor.cpp +++ b/velox/exec/fuzzer/PrestoQueryRunnerToSqlPlanNodeVisitor.cpp @@ -216,7 +216,7 @@ void PrestoQueryRunnerToSqlPlanNodeVisitor::visit( // ) // AS SELECT * FROM t_ std::stringstream sql; - sql << "CREATE TABLE tmp_write"; + sql << "CREATE TABLE " << ReferenceQueryRunner::getWriteTableName(); std::vector partitionKeys; for (auto i = 0; i < node.columnNames().size(); ++i) { if (insertTableHandle->inputColumns()[i]->isPartitionKey()) { @@ -291,7 +291,8 @@ void PrestoQueryRunnerToSqlPlanNodeVisitor::visit( sql << inputType->nameOf(i); } - sql << ", row_number() OVER ("; + sql << ", " << core::TopNRowNumberNode::rankFunctionName(node.rankFunction()) + << "() OVER ("; const auto& partitionKeys = node.partitionKeys(); if (!partitionKeys.empty()) { diff --git a/velox/exec/fuzzer/PrestoQueryRunnerToSqlPlanNodeVisitor.h b/velox/exec/fuzzer/PrestoQueryRunnerToSqlPlanNodeVisitor.h index dd26b66611c..ac58b4e7995 100644 --- a/velox/exec/fuzzer/PrestoQueryRunnerToSqlPlanNodeVisitor.h +++ b/velox/exec/fuzzer/PrestoQueryRunnerToSqlPlanNodeVisitor.h @@ -95,6 +95,16 @@ class PrestoQueryRunnerToSqlPlanNodeVisitor : public PrestoSqlPlanNodeVisitor { VELOX_NYI(); } + void visit(const core::EnforceDistinctNode&, core::PlanNodeVisitorContext&) + const override { + VELOX_NYI(); + } + + void visit(const core::MarkSortedNode&, core::PlanNodeVisitorContext&) + const override { + VELOX_NYI(); + } + void visit(const core::MergeExchangeNode&, core::PlanNodeVisitorContext&) const override { VELOX_NYI(); @@ -179,6 +189,11 @@ class PrestoQueryRunnerToSqlPlanNodeVisitor : public PrestoSqlPlanNodeVisitor { void visit(const core::WindowNode& node, core::PlanNodeVisitorContext& ctx) const override; + void visit(const core::MixedUnionNode&, core::PlanNodeVisitorContext&) + const override { + VELOX_NYI(); + } + /// Used to visit custom PlanNodes that extend the set provided by Velox. void visit(const core::PlanNode&, core::PlanNodeVisitorContext&) const override { diff --git a/velox/exec/fuzzer/PrestoSql.cpp b/velox/exec/fuzzer/PrestoSql.cpp index 6b034ef89d1..1effd97e49c 100644 --- a/velox/exec/fuzzer/PrestoSql.cpp +++ b/velox/exec/fuzzer/PrestoSql.cpp @@ -19,6 +19,7 @@ #include "velox/exec/fuzzer/PrestoQueryRunner.h" #include "velox/exec/fuzzer/ReferenceQueryRunner.h" #include "velox/functions/prestosql/types/JsonType.h" +#include "velox/functions/prestosql/types/PrestoTypes.h" #include "velox/vector/SimpleVector.h" namespace facebook::velox::exec::test { @@ -71,36 +72,6 @@ void appendComma(int32_t i, std::stringstream& sql) { } } -// Returns the SQL string of the given type. -std::string toTypeSql(const TypePtr& type) { - switch (type->kind()) { - case TypeKind::ARRAY: - return fmt::format("ARRAY({})", toTypeSql(type->childAt(0))); - case TypeKind::MAP: - return fmt::format( - "MAP({}, {})", - toTypeSql(type->childAt(0)), - toTypeSql(type->childAt(1))); - case TypeKind::ROW: { - const auto& rowType = type->asRow(); - std::stringstream sql; - sql << "ROW("; - for (auto i = 0; i < type->size(); ++i) { - appendComma(i, sql); - // TODO Field names may need to be quoted. - sql << rowType.nameOf(i) << " " << toTypeSql(type->childAt(i)); - } - sql << ")"; - return sql.str(); - } - default: - if (type->isPrimitiveType()) { - return type->toString(); - } - VELOX_UNSUPPORTED("Type is not supported: {}", type->toString()); - } -} - std::string toLambdaSql(const core::LambdaTypedExprPtr& lambda) { std::stringstream sql; const auto& signature = lambda->signature(); @@ -346,7 +317,7 @@ std::string toCastSql(const core::CastTypedExpr& cast) { sql << "cast("; } toCallInputsSql(cast.inputs(), sql); - sql << " as " << toTypeSql(cast.type()); + sql << " as " << facebook::velox::PrestoTypes::toSql(cast.type()); sql << ")"; return sql.str(); } @@ -355,7 +326,9 @@ std::string toConcatSql(const core::ConcatTypedExpr& concat) { std::stringstream input; toCallInputsSql(concat.inputs(), input); return fmt::format( - "cast(row({}) as {})", input.str(), toTypeSql(concat.type())); + "cast(row({}) as {})", + input.str(), + facebook::velox::PrestoTypes::toSql(concat.type())); } template @@ -381,7 +354,7 @@ std::string getConstantValue(const core::ConstantTypedExpr& expr) { std::string toConstantSql(const core::ConstantTypedExpr& constant) { const auto& type = constant.type(); - const auto typeSql = toTypeSql(type); + const auto typeSql = facebook::velox::PrestoTypes::toSql(type); std::stringstream sql; if (constant.isNull()) { diff --git a/velox/exec/fuzzer/PrestoSql.h b/velox/exec/fuzzer/PrestoSql.h index 2acfaaa2df9..7fce82729c8 100644 --- a/velox/exec/fuzzer/PrestoSql.h +++ b/velox/exec/fuzzer/PrestoSql.h @@ -23,9 +23,6 @@ namespace facebook::velox::exec::test { /// than 0. void appendComma(int32_t i, std::stringstream& sql); -/// Return the SQL string of type. -std::string toTypeSql(const TypePtr& type); - /// Converts input expressions into SQL string and appends to a given /// stringstream. void toCallInputsSql( diff --git a/velox/exec/fuzzer/ReferenceQueryRunner.cpp b/velox/exec/fuzzer/ReferenceQueryRunner.cpp index 8561db46ead..b6d212ec7d5 100644 --- a/velox/exec/fuzzer/ReferenceQueryRunner.cpp +++ b/velox/exec/fuzzer/ReferenceQueryRunner.cpp @@ -15,12 +15,36 @@ */ #include +#include + #include "velox/core/PlanNode.h" #include "velox/exec/fuzzer/PrestoSql.h" #include "velox/exec/fuzzer/ReferenceQueryRunner.h" +DEFINE_string( + table_name_prefix, + "", + "Prefix for temporary table names created by fuzzer reference query " + "runners. Use to avoid collisions when running multiple fuzzer instances " + "against the same database server."); + namespace facebook::velox::exec::test { +std::string ReferenceQueryRunner::getTableName( + const core::ValuesNode& valuesNode) { + if (FLAGS_table_name_prefix.empty()) { + return fmt::format("t_{}", valuesNode.id()); + } + return fmt::format("{}_t_{}", FLAGS_table_name_prefix, valuesNode.id()); +} + +std::string ReferenceQueryRunner::getWriteTableName() { + if (FLAGS_table_name_prefix.empty()) { + return "tmp_write"; + } + return fmt::format("{}_tmp_write", FLAGS_table_name_prefix); +} + std::unordered_map> ReferenceQueryRunner::getAllTables(const core::PlanNodePtr& plan) { std::unordered_map> result; diff --git a/velox/exec/fuzzer/ReferenceQueryRunner.h b/velox/exec/fuzzer/ReferenceQueryRunner.h index 22e01c6e622..d277ea4347f 100644 --- a/velox/exec/fuzzer/ReferenceQueryRunner.h +++ b/velox/exec/fuzzer/ReferenceQueryRunner.h @@ -55,7 +55,8 @@ class ReferenceQueryRunner { enum class RunnerType { kPrestoQueryRunner, kDuckQueryRunner, - kSparkQueryRunner + kSparkQueryRunner, + kVeloxQueryRunner }; // @param aggregatePool Used to allocate memory needed for vectors produced @@ -161,10 +162,15 @@ class ReferenceQueryRunner { VELOX_UNSUPPORTED(); } - /// Returns the name of the values node table in the form t_. - static std::string getTableName(const core::ValuesNode& valuesNode) { - return fmt::format("t_{}", valuesNode.id()); - } + /// Returns the name of the values node table in the form + /// [_]t_. When --table_name_prefix is set, the prefix is + /// prepended to avoid collisions between parallel fuzzer instances + /// sharing the same database server. + static std::string getTableName(const core::ValuesNode& valuesNode); + + /// Returns the write destination table name, incorporating the + /// --table_name_prefix when set. + static std::string getWriteTableName(); protected: memory::MemoryPool* aggregatePool() { diff --git a/velox/exec/fuzzer/RowNumberFuzzer.cpp b/velox/exec/fuzzer/RowNumberFuzzer.cpp index 7c8c911db75..7d43ec4fd38 100644 --- a/velox/exec/fuzzer/RowNumberFuzzer.cpp +++ b/velox/exec/fuzzer/RowNumberFuzzer.cpp @@ -18,15 +18,16 @@ #include #include "velox/common/file/FileSystems.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/exec/fuzzer/FuzzerUtil.h" -#include "velox/exec/fuzzer/RowNumberFuzzerBase.h" +#include "velox/exec/fuzzer/SpillFuzzerBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" namespace facebook::velox::exec { +using namespace facebook::velox::common::testutil; namespace { -class RowNumberFuzzer : public RowNumberFuzzerBase { +class RowNumberFuzzer : public SpillFuzzerBase { public: explicit RowNumberFuzzer( size_t initialSeed, @@ -68,7 +69,7 @@ class RowNumberFuzzer : public RowNumberFuzzerBase { RowNumberFuzzer::RowNumberFuzzer( size_t initialSeed, std::unique_ptr referenceQueryRunner) - : RowNumberFuzzerBase(initialSeed, std::move(referenceQueryRunner)) { + : SpillFuzzerBase(initialSeed, std::move(referenceQueryRunner)) { // Set timestamp precision as milliseconds, as timestamp may be used as // paritition key, and presto doesn't supports nanosecond precision. vectorFuzzer_.getMutableOptions().timestampPrecision = @@ -109,7 +110,7 @@ std::vector RowNumberFuzzer::generateInput( return input; } -RowNumberFuzzerBase::PlanWithSplits RowNumberFuzzer::makeDefaultPlan( +PlanWithSplits RowNumberFuzzer::makeDefaultPlan( const std::vector& partitionKeys, const std::vector& input) { auto planNodeIdGenerator = std::make_shared(); @@ -123,7 +124,7 @@ RowNumberFuzzerBase::PlanWithSplits RowNumberFuzzer::makeDefaultPlan( return PlanWithSplits{std::move(plan)}; } -RowNumberFuzzerBase::PlanWithSplits RowNumberFuzzer::makePlanWithTableScan( +PlanWithSplits RowNumberFuzzer::makePlanWithTableScan( const RowTypePtr& type, const std::vector& partitionKeys, const std::vector& splits) { @@ -164,8 +165,7 @@ void RowNumberFuzzer::runSingleIteration() { test::logVectors(input); auto defaultPlan = makeDefaultPlan(keyNames, input); - const auto expected = - execute(defaultPlan, pool_, /*injectSpill=*/false, false); + const auto expected = execute(defaultPlan, /*injectSpill=*/false, false); if (expected != nullptr) { validateExpectedResults(defaultPlan.plan, input, expected); @@ -174,7 +174,7 @@ void RowNumberFuzzer::runSingleIteration() { std::vector altPlans; altPlans.push_back(std::move(defaultPlan)); - const auto tableScanDir = exec::test::TempDirectoryPath::create(); + const auto tableScanDir = TempDirectoryPath::create(); addPlansWithTableScan(tableScanDir->getPath(), keyNames, input, altPlans); for (auto i = 0; i < altPlans.size(); ++i) { diff --git a/velox/exec/fuzzer/SpatialJoinFuzzer.cpp b/velox/exec/fuzzer/SpatialJoinFuzzer.cpp new file mode 100644 index 00000000000..546315defc3 --- /dev/null +++ b/velox/exec/fuzzer/SpatialJoinFuzzer.cpp @@ -0,0 +1,600 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/fuzzer/SpatialJoinFuzzer.h" + +#include "velox/common/file/FileSystems.h" +#include "velox/common/fuzzer/Utils.h" +#include "velox/exec/fuzzer/FuzzerUtil.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/vector/fuzzer/VectorFuzzer.h" + +DEFINE_int32(steps, 10, "Number of plans to generate and test."); + +DEFINE_int32( + duration_sec, + 0, + "For how long it should run (in seconds). If zero, " + "it executes exactly --steps iterations and exits."); + +DEFINE_int32( + batch_size, + 100, + "The number of elements on each generated vector."); + +DEFINE_int32(num_batches, 10, "The number of generated vectors."); + +DEFINE_double( + null_ratio, + 0.1, + "Chance of adding a null value in a vector " + "(expressed as double from 0 to 1)."); + +namespace facebook::velox::exec { + +namespace { +using namespace facebook::velox; + +/// Spatial distribution patterns for geometry generation. +enum class GeometryDistribution { + kUniform, // Geometries uniformly distributed in space + kClustered, // Geometries clustered in specific regions + kSparse // Sparse geometries with low overlap probability +}; + +// Constants for geometry generation. +constexpr int32_t kRandomCoordinateMax = 1000; +constexpr int32_t kNumClusters = 5; +constexpr double kClusterSpacing = 200.0; +constexpr double kClusterCenterOffset = 100.0; +constexpr int32_t kClusterSpreadRange = 100; +constexpr int32_t kClusterSpreadHalf = kClusterSpreadRange / 2; +constexpr double kPolygonSize = 10.0; +constexpr double kSparseSpread = 2000.0; +constexpr uint32_t kMaxRadius = 100; + +// Base class for geometry string generators. +class GeometryInputGenerator : public AbstractInputGenerator { + public: + GeometryInputGenerator( + GeometryDistribution distribution, + size_t seed, + double nullRatio) + : AbstractInputGenerator(seed, VARCHAR(), nullptr, nullRatio), + distribution_(distribution) {} + + protected: + std::pair generateCoordinates() { + double x{0.0}, y{0.0}; + switch (distribution_) { + case GeometryDistribution::kUniform: { + x = fuzzer::rand( + rng_, -kRandomCoordinateMax, kRandomCoordinateMax); + y = fuzzer::rand( + rng_, -kRandomCoordinateMax, kRandomCoordinateMax); + break; + } + case GeometryDistribution::kClustered: { + uint32_t cluster = fuzzer::rand(rng_, 0, kNumClusters); + double centerX = (cluster * kClusterSpacing) + kClusterCenterOffset; + double centerY = (cluster * kClusterSpacing) + kClusterCenterOffset; + x = centerX + + ((fuzzer::rand( + rng_, -kClusterSpreadRange, kClusterSpreadRange)) - + kClusterSpreadHalf); + y = centerY + + ((fuzzer::rand( + rng_, -kClusterSpreadRange, kClusterSpreadRange)) - + kClusterSpreadHalf); + break; + } + case GeometryDistribution::kSparse: { + x = fuzzer::rand(rng_, -kSparseSpread, kSparseSpread); + y = fuzzer::rand(rng_, -kSparseSpread, kSparseSpread); + break; + } + default: + VELOX_UNREACHABLE(); + } + return {x, y}; + } + + GeometryDistribution distribution_; +}; + +// Generates POINT geometry strings. +class PointInputGenerator : public GeometryInputGenerator { + public: + PointInputGenerator( + GeometryDistribution distribution, + size_t seed, + double nullRatio) + : GeometryInputGenerator(distribution, seed, nullRatio) {} + + variant generate() override { + if (fuzzer::coinToss(rng_, nullRatio_)) { + return variant::null(type_->kind()); + } + + auto [x, y] = generateCoordinates(); + return fmt::format("POINT ({} {})", x, y); + } +}; + +// Generates POLYGON geometry strings. +class PolygonInputGenerator : public GeometryInputGenerator { + public: + PolygonInputGenerator( + GeometryDistribution distribution, + size_t seed, + double nullRatio) + : GeometryInputGenerator(distribution, seed, nullRatio) {} + + variant generate() override { + if (fuzzer::coinToss(rng_, nullRatio_)) { + return variant::null(type_->kind()); + } + auto [centerX, centerY] = generateCoordinates(); + return fmt::format( + "POLYGON (({} {}, {} {}, {} {}, {} {}, {} {}))", + centerX - kPolygonSize, + centerY - kPolygonSize, + centerX + kPolygonSize, + centerY - kPolygonSize, + centerX + kPolygonSize, + centerY + kPolygonSize, + centerX - kPolygonSize, + centerY + kPolygonSize, + centerX - kPolygonSize, + centerY - kPolygonSize); + } +}; + +// Generates LINESTRING geometry strings. +class LineStringInputGenerator : public GeometryInputGenerator { + public: + LineStringInputGenerator( + GeometryDistribution distribution, + size_t seed, + double nullRatio) + : GeometryInputGenerator(distribution, seed, nullRatio) {} + + variant generate() override { + if (fuzzer::coinToss(rng_, nullRatio_)) { + return variant::null(type_->kind()); + } + auto [x1, y1] = generateCoordinates(); + double x2 = x1 + kPolygonSize; + double y2 = y1 + kPolygonSize; + return fmt::format("LINESTRING ({} {}, {} {})", x1, y1, x2, y2); + } +}; + +class SpatialJoinFuzzer { + public: + explicit SpatialJoinFuzzer(size_t initialSeed); + + void go(); + + private: + static VectorFuzzer::Options getFuzzerOptions() { + VectorFuzzer::Options opts; + opts.vectorSize = FLAGS_batch_size; + opts.stringVariableLength = true; + opts.stringLength = 100; + opts.nullRatio = FLAGS_null_ratio; + return opts; + } + + void seed(size_t seed) { + currentSeed_ = seed; + vectorFuzzer_.reSeed(seed); + rng_.seed(currentSeed_); + } + + void reSeed() { + seed(rng_()); + } + + // Randomly pick a join type supported by SpatialJoin. + core::JoinType pickJoinType(); + + // Randomly pick a spatial predicate function. + std::string pickSpatialPredicate(); + + // Randomly pick a geometry distribution pattern. + GeometryDistribution pickDistribution(); + + // Runs one test iteration from query plans generation, execution and result + // verification. + void verify(core::JoinType joinType); + + // Creates a vector of POINT geometries with specified distribution. + VectorPtr makePointVector(int32_t size, GeometryDistribution distribution); + + // Creates a vector of POLYGON geometries with specified distribution. + VectorPtr makePolygonVector(int32_t size, GeometryDistribution distribution); + + // Creates a vector of LINESTRING geometries with specified distribution. + VectorPtr makeLineStringVector( + int32_t size, + GeometryDistribution distribution); + + // Returns randomly generated probe input with geometry columns (as WKT + // strings). + std::vector generateProbeInput( + GeometryDistribution distribution); + + // Same as generateProbeInput() but copies over 10% of the input to ensure + // some matches during joining. Also generates an empty input with a 10% + // chance. + std::vector generateBuildInput( + const std::vector& probeInput, + GeometryDistribution distribution); + + // Executes a plan and returns the result. + RowVectorPtr execute(const core::PlanNodePtr& plan); + + int32_t randInt(int32_t min, int32_t max) { + return boost::random::uniform_int_distribution(min, max)(rng_); + } + + const std::shared_ptr pool_{ + memory::memoryManager()->addLeafPool()}; + std::mt19937 rng_; + size_t currentSeed_{0}; + + VectorFuzzer vectorFuzzer_; + + struct { + size_t numIterations{0}; + } stats_; +}; + +SpatialJoinFuzzer::SpatialJoinFuzzer(size_t initialSeed) + : vectorFuzzer_{getFuzzerOptions(), pool_.get()} { + filesystems::registerLocalFileSystem(); + seed(initialSeed); +} + +template +bool isDone(size_t i, T startTime) { + if (FLAGS_duration_sec > 0) { + std::chrono::duration elapsed = + std::chrono::system_clock::now() - startTime; + return elapsed.count() >= FLAGS_duration_sec; + } + return i >= FLAGS_steps; +} + +core::JoinType SpatialJoinFuzzer::pickJoinType() { + // SpatialJoin only supports INNER and LEFT join types. + static std::vector kJoinTypes = { + core::JoinType::kInner, core::JoinType::kLeft}; + + const size_t idx = randInt(0, kJoinTypes.size() - 1); + return kJoinTypes[idx]; +} + +std::string SpatialJoinFuzzer::pickSpatialPredicate() { + // Common spatial predicates supported by spatial joins. + static std::vector kPredicates = { + "ST_Intersects", + "ST_Contains", + "ST_Within", + "ST_Distance", + "ST_Overlaps", + "ST_Crosses", + "ST_Touches", + "ST_Equals"}; + + const size_t idx = randInt(0, kPredicates.size() - 1); + return kPredicates[idx]; +} + +GeometryDistribution SpatialJoinFuzzer::pickDistribution() { + static std::vector kDistributions = { + GeometryDistribution::kUniform, + GeometryDistribution::kClustered, + GeometryDistribution::kSparse}; + + const size_t idx = randInt(0, kDistributions.size() - 1); + return kDistributions[idx]; +} + +VectorPtr SpatialJoinFuzzer::makePointVector( + int32_t size, + GeometryDistribution distribution) { + auto generator = std::make_shared( + distribution, currentSeed_, getFuzzerOptions().nullRatio); + return vectorFuzzer_.fuzzFlat(VARCHAR(), size, generator); +} + +VectorPtr SpatialJoinFuzzer::makePolygonVector( + int32_t size, + GeometryDistribution distribution) { + auto generator = std::make_shared( + distribution, currentSeed_, getFuzzerOptions().nullRatio); + return vectorFuzzer_.fuzzFlat(VARCHAR(), size, generator); +} + +VectorPtr SpatialJoinFuzzer::makeLineStringVector( + int32_t size, + GeometryDistribution distribution) { + auto generator = std::make_shared( + distribution, currentSeed_, getFuzzerOptions().nullRatio); + return vectorFuzzer_.fuzzFlat(VARCHAR(), size, generator); +} + +std::vector SpatialJoinFuzzer::generateProbeInput( + GeometryDistribution distribution) { + std::vector input; + + const int32_t numRows = FLAGS_batch_size * FLAGS_num_batches; + const int32_t batchSize = FLAGS_batch_size; + const int32_t numBatches = FLAGS_num_batches; + + // Randomly pick geometry type for probe side. + const int geometryType = randInt(0, 2); + + for (int32_t i = 0; i < numBatches; ++i) { + int32_t currentBatchSize = std::min(batchSize, numRows - (i * batchSize)); + + VectorPtr geomVector; + if (geometryType == 0) { + geomVector = makePointVector(currentBatchSize, distribution); + } else if (geometryType == 1) { + geomVector = makePolygonVector(currentBatchSize, distribution); + } else { + geomVector = makeLineStringVector(currentBatchSize, distribution); + } + + auto idVector = vectorFuzzer_.fuzzFlat(BIGINT(), currentBatchSize); + auto rowType = ROW( + {"probe_id", "probe_geom_wkt"}, {idVector->type(), geomVector->type()}); + auto rowVector = std::make_shared( + pool_.get(), + rowType, + nullptr, + currentBatchSize, + std::vector{idVector, geomVector}); + input.push_back(rowVector); + } + + return input; +} + +std::vector SpatialJoinFuzzer::generateBuildInput( + const std::vector& probeInput, + GeometryDistribution distribution) { + std::vector input; + + // 1 in 10 times use empty build. + if (vectorFuzzer_.coinToss(0.1)) { + auto rowType = ROW({"build_id", "build_geom_wkt"}, {BIGINT(), VARCHAR()}); + auto rowVector = std::make_shared( + pool_.get(), + rowType, + nullptr, + 0, + std::vector{ + vectorFuzzer_.fuzzFlat(BIGINT(), 0), + vectorFuzzer_.fuzzFlat(VARCHAR(), 0)}); + return {rowVector}; + } + + // Randomly pick geometry type for build side. + const int geometryType = randInt(0, 2); + + for (const auto& probe : probeInput) { + auto numRows = 1 + probe->size() / 8; + + VectorPtr geomVector; + if (geometryType == 0) { + geomVector = makePointVector(numRows, distribution); + } else if (geometryType == 1) { + geomVector = makePolygonVector(numRows, distribution); + } else { + geomVector = makeLineStringVector(numRows, distribution); + } + + auto idVector = vectorFuzzer_.fuzzFlat(BIGINT(), numRows); + + // To ensure some matches, copy some geometries from probe side. + if (probe->size() > 0) { + std::vector rowNumbers(numRows); + SelectivityVector rows(numRows, false); + for (vector_size_t i = 0; i < numRows; ++i) { + if (vectorFuzzer_.coinToss(0.3)) { + rowNumbers[i] = randInt(0, probe->size() - 1); + rows.setValid(i, true); + } + } + + // Copy geometry from probe to build. + auto probeGeom = probe->childAt(1); + geomVector->copy(probeGeom.get(), rows, rowNumbers.data()); + } + + auto rowType = ROW( + {"build_id", "build_geom_wkt"}, {idVector->type(), geomVector->type()}); + auto rowVector = std::make_shared( + pool_.get(), + rowType, + nullptr, + numRows, + std::vector{idVector, geomVector}); + input.push_back(rowVector); + } + + return input; +} + +RowVectorPtr SpatialJoinFuzzer::execute(const core::PlanNodePtr& plan) { + LOG(INFO) << "Executing query plan: " << std::endl + << plan->toString(true, true); + + return test::AssertQueryBuilder(plan).copyResults(pool_.get()); +} + +void SpatialJoinFuzzer::verify(core::JoinType joinType) { + const auto distribution = pickDistribution(); + const auto predicate = pickSpatialPredicate(); + + // Generate test data (WKT strings). + auto probeInput = generateProbeInput(distribution); + auto buildInput = generateBuildInput(probeInput, distribution); + + if (VLOG_IS_ON(1)) { + VLOG(1) << "Probe input: " << probeInput[0]->toString(); + for (const auto& v : probeInput) { + VLOG(1) << std::endl << v->toString(0, v->size()); + } + + VLOG(1) << "Build input: " << buildInput[0]->toString(); + for (const auto& v : buildInput) { + VLOG(1) << std::endl << v->toString(0, v->size()); + } + } + + // Build spatial join plan with geometry conversion as part of the plan. + const auto planNodeIdGenerator = + std::make_shared(); + + std::string joinCondition; + std::optional radiusColumn; + std::optional radiusExpression; + if (predicate == "ST_Distance") { + // ST_Distance returns a value, use it with a threshold. + // For ST_Distance, we use a radius column instead of embedding the + // threshold in the join condition. + joinCondition = + fmt::format("{}(probe_geom, build_geom) < radius", predicate); + radiusColumn = "radius"; + radiusExpression = fmt::format( + "CAST({} AS DOUBLE) AS radius", + static_cast(randInt(0, kMaxRadius))); + } else { + // Other predicates return boolean. + joinCondition = fmt::format("{}(probe_geom, build_geom)", predicate); + } + + // Create SpatialJoin plan with geometry conversion projections. + auto spatialJoinPlan = + test::PlanBuilder(planNodeIdGenerator) + .values(probeInput) + // Convert probe WKT strings to Geometry + .project( + {"probe_id", + "ST_GeometryFromText(probe_geom_wkt) AS probe_geom", + "probe_geom_wkt"}) + .spatialJoin( + test::PlanBuilder(planNodeIdGenerator) + .values(buildInput) + // Convert build WKT strings to Geometry + .project( + radiusColumn.has_value() + ? std::vector< + std:: + string>{"build_id", "ST_GeometryFromText(build_geom_wkt) AS build_geom", "build_geom_wkt", radiusExpression.value()} + : std::vector< + std:: + string>{"build_id", "ST_GeometryFromText(build_geom_wkt) AS build_geom", "build_geom_wkt"}) + .planNode(), + joinCondition, + "probe_geom", + "build_geom", + radiusColumn, + {"probe_id", "probe_geom_wkt", "build_id", "build_geom_wkt"}, + joinType) + .planNode(); + + // Create equivalent NestedLoopJoin plan for comparison. + auto nestedLoopJoinPlan = + test::PlanBuilder(planNodeIdGenerator) + .values(probeInput) + // Convert probe WKT strings to Geometry + .project( + {"probe_id", + "ST_GeometryFromText(probe_geom_wkt) AS probe_geom", + "probe_geom_wkt"}) + .nestedLoopJoin( + test::PlanBuilder(planNodeIdGenerator) + .values(buildInput) + // Convert build WKT strings to Geometry + .project( + radiusColumn.has_value() + ? std::vector< + std:: + string>{"build_id", "ST_GeometryFromText(build_geom_wkt) AS build_geom", "build_geom_wkt", radiusExpression.value()} + : std::vector< + std:: + string>{"build_id", "ST_GeometryFromText(build_geom_wkt) AS build_geom", "build_geom_wkt"}) + .planNode(), + {joinCondition}, + {"probe_id", "probe_geom_wkt", "build_id", "build_geom_wkt"}, + joinType) + .planNode(); + + LOG(INFO) << "Executing SpatialJoin plan..."; + const auto spatialJoinResult = execute(spatialJoinPlan); + + LOG(INFO) << "Executing NestedLoopJoin plan..."; + const auto nestedLoopJoinResult = execute(nestedLoopJoinPlan); + + // Compare SpatialJoin vs NestedLoopJoin results. + auto result = + test::assertEqualResults({nestedLoopJoinResult}, {spatialJoinResult}); + VELOX_CHECK(result, "SpatialJoin and NestedLoopJoin results don't match"); + + LOG(INFO) << "SpatialJoin matches NestedLoopJoin."; +} + +void SpatialJoinFuzzer::go() { + VELOX_USER_CHECK( + FLAGS_steps > 0 || FLAGS_duration_sec > 0, + "Either --steps or --duration_sec needs to be greater than zero."); + VELOX_USER_CHECK_GE(FLAGS_batch_size, 10, "Batch size must be at least 10."); + + const auto startTime = std::chrono::system_clock::now(); + + while (!isDone(stats_.numIterations, startTime)) { + LOG(WARNING) << "==============================> Started iteration " + << stats_.numIterations << " (seed: " << currentSeed_ << ")"; + + // Pick join type. + const auto joinType = pickJoinType(); + + verify(joinType); + + LOG(WARNING) << "==============================> Done with iteration " + << stats_.numIterations; + + reSeed(); + ++stats_.numIterations; + } + + LOG(INFO) << "Total iterations: " << stats_.numIterations; +} + +} // namespace + +void spatialJoinFuzzer(size_t seed) { + SpatialJoinFuzzer(seed).go(); +} + +} // namespace facebook::velox::exec diff --git a/velox/exec/fuzzer/SpatialJoinFuzzer.h b/velox/exec/fuzzer/SpatialJoinFuzzer.h new file mode 100644 index 00000000000..92bdd967213 --- /dev/null +++ b/velox/exec/fuzzer/SpatialJoinFuzzer.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +namespace facebook::velox::exec { + +/// Runs the fuzzer for SpatialJoin operator. Generates random geometry data +/// and spatial join plans with various predicates (ST_Intersects, ST_Contains, +/// ST_Within, ST_Distance), comparing SpatialJoin results against +/// NestedLoopJoin as the reference implementation. +/// +/// The fuzzer tests: +/// - Different spatial predicates +/// - INNER and LEFT join types (the only types supported by SpatialJoin) +/// - Different geometry types (POINT, POLYGON, LINESTRING) +/// - Various data distributions (uniform, clustered, sparse) +/// - Different sizes of probe and build sides +/// - Plans with and without filters +/// - Different output column projections +void spatialJoinFuzzer(size_t seed); + +} // namespace facebook::velox::exec diff --git a/velox/exec/fuzzer/SpatialJoinFuzzerRunner.cpp b/velox/exec/fuzzer/SpatialJoinFuzzerRunner.cpp new file mode 100644 index 00000000000..162429d598d --- /dev/null +++ b/velox/exec/fuzzer/SpatialJoinFuzzerRunner.cpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "velox/common/file/FileSystems.h" +#include "velox/common/memory/Memory.h" +#include "velox/exec/fuzzer/SpatialJoinFuzzer.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/parse/TypeResolver.h" +#include "velox/serializers/PrestoSerializer.h" + +DEFINE_int64( + seed, + 0, + "Initial seed for random number generator used to reproduce previous " + "results (0 means start with random seed)."); + +using namespace facebook::velox; + +int main(int argc, char** argv) { + folly::Init init(&argc, &argv); + + // Initialize memory system. + memory::initializeMemoryManager(memory::MemoryManager::Options{}); + auto pool = memory::memoryManager()->addLeafPool(); + + // Register file systems. + filesystems::registerLocalFileSystem(); + + // Register Presto functions. + functions::prestosql::registerAllScalarFunctions(); + + // Register type resolver. + parse::registerTypeResolver(); + + // Register serializers. + if (!isRegisteredNamedVectorSerde("Presto")) { + serializer::presto::PrestoVectorSerde::registerNamedVectorSerde(); + } + + // Determine the seed. + size_t seed = FLAGS_seed == 0 ? std::random_device{}() : FLAGS_seed; + LOG(INFO) << "Using seed: " << seed; + + // Run the spatial join fuzzer. + exec::spatialJoinFuzzer(seed); + + return 0; +} diff --git a/velox/exec/fuzzer/RowNumberFuzzerBase.cpp b/velox/exec/fuzzer/SpillFuzzerBase.cpp similarity index 76% rename from velox/exec/fuzzer/RowNumberFuzzerBase.cpp rename to velox/exec/fuzzer/SpillFuzzerBase.cpp index 1b634b0f23b..b239881e1b1 100644 --- a/velox/exec/fuzzer/RowNumberFuzzerBase.cpp +++ b/velox/exec/fuzzer/SpillFuzzerBase.cpp @@ -14,15 +14,16 @@ * limitations under the License. */ -#include "velox/exec/fuzzer/RowNumberFuzzerBase.h" +#include "velox/exec/fuzzer/SpillFuzzerBase.h" #include +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/connectors/hive/HiveConnector.h" #include "velox/dwio/dwrf/RegisterDwrfReader.h" +#include "velox/dwio/dwrf/RegisterDwrfWriter.h" #include "velox/exec/Spill.h" #include "velox/exec/fuzzer/FuzzerUtil.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/serializers/CompactRowSerializer.h" #include "velox/serializers/PrestoSerializer.h" #include "velox/serializers/UnsafeRowSerializer.h" @@ -66,7 +67,29 @@ DEFINE_bool( namespace facebook::velox::exec { -RowNumberFuzzerBase::RowNumberFuzzerBase( +using namespace facebook::velox::common::testutil; + +namespace { +void compareResults(const RowVectorPtr& expected, const RowVectorPtr& actual) { + if (actual != nullptr && expected != nullptr) { + try { + VELOX_CHECK( + test::assertEqualResults({expected}, {actual}), + "Logically equivalent plans produced different results"); + } catch (const VeloxException&) { + LOG(ERROR) << "Expected\n" + << expected->toString(0, expected->size()) << "\nActual\n" + << actual->toString(0, actual->size()); + throw; + } + } else { + VELOX_CHECK( + FLAGS_enable_oom_injection, "Got unexpected nullptr for results"); + } +} +} // namespace + +SpillFuzzerBase::SpillFuzzerBase( size_t initialSeed, std::unique_ptr referenceQueryRunner) : vectorFuzzer_{getFuzzerOptions(), pool_.get()}, @@ -75,17 +98,21 @@ RowNumberFuzzerBase::RowNumberFuzzerBase( seed(initialSeed); } -void RowNumberFuzzerBase::setupReadWrite() { +void SpillFuzzerBase::setupReadWrite() { filesystems::registerLocalFileSystem(); dwrf::registerDwrfReaderFactory(); + dwrf::registerDwrfWriterFactory(); - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kPresto)) { + if (!isRegisteredNamedVectorSerde( + VectorSerde::kindName(VectorSerde::Kind::kPresto))) { serializer::presto::PrestoVectorSerde::registerNamedVectorSerde(); } - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kCompactRow)) { + if (!isRegisteredNamedVectorSerde( + VectorSerde::kindName(VectorSerde::Kind::kCompactRow))) { serializer::CompactRowVectorSerde::registerNamedVectorSerde(); } - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kUnsafeRow)) { + if (!isRegisteredNamedVectorSerde( + VectorSerde::kindName(VectorSerde::Kind::kUnsafeRow))) { serializer::spark::UnsafeRowVectorSerde::registerNamedVectorSerde(); } @@ -98,20 +125,19 @@ void RowNumberFuzzerBase::setupReadWrite() { // Sometimes we generate zero-column input of type ROW({}) or a column of type // UNKNOWN(). Such data cannot be written to a file and therefore cannot // be tested with TableScan. -bool RowNumberFuzzerBase::isTableScanSupported(const TypePtr& type) { - if (type->kind() == TypeKind::ROW && type->size() == 0) { - return false; - } - if (type->kind() == TypeKind::UNKNOWN) { - return false; - } - if (type->kind() == TypeKind::HUGEINT) { - return false; - } - // Disable testing with TableScan when input contains TIMESTAMP type, due to - // the issue #8127. - if (type->kind() == TypeKind::TIMESTAMP) { - return false; +bool SpillFuzzerBase::isTableScanSupported(const TypePtr& type) { + switch (type->kind()) { + case TypeKind::UNKNOWN: + case TypeKind::HUGEINT: + case TypeKind::TIMESTAMP: + return false; + case TypeKind::ROW: + if (type->size() == 0) { + return false; + } + break; + default: + break; } for (auto i = 0; i < type->size(); ++i) { @@ -123,7 +149,7 @@ bool RowNumberFuzzerBase::isTableScanSupported(const TypePtr& type) { return true; } -void RowNumberFuzzerBase::validateExpectedResults( +void SpillFuzzerBase::validateExpectedResults( const core::PlanNodePtr& plan, const std::vector& input, const RowVectorPtr& result) { @@ -149,7 +175,7 @@ bool isDone(size_t i, T startTime) { return i >= FLAGS_steps; } -void RowNumberFuzzerBase::run() { +void SpillFuzzerBase::run() { VELOX_USER_CHECK( FLAGS_steps > 0 || FLAGS_duration_sec > 0, "Either --steps or --duration_sec needs to be greater than zero."); @@ -170,9 +196,8 @@ void RowNumberFuzzerBase::run() { } } -RowVectorPtr RowNumberFuzzerBase::execute( +RowVectorPtr SpillFuzzerBase::execute( const PlanWithSplits& plan, - const std::shared_ptr& pool, bool injectSpill, bool injectOOM, const std::optional& spillConfig, @@ -191,8 +216,8 @@ RowVectorPtr RowNumberFuzzerBase::execute( "Spill config not set for execute with spilling"); VELOX_CHECK_GE( maxSpillLevel, 0, "Max spill should be set for execute with spilling"); - std::shared_ptr spillDirectory; - spillDirectory = exec::test::TempDirectoryPath::create(); + std::shared_ptr spillDirectory; + spillDirectory = TempDirectoryPath::create(); builder.config(core::QueryConfig::kSpillEnabled, true) .config(core::QueryConfig::kMaxSpillLevel, maxSpillLevel) .config(spillConfig.value(), true) @@ -216,7 +241,7 @@ RowVectorPtr RowNumberFuzzerBase::execute( TestScopedSpillInjection scopedSpillInjection(spillPct); RowVectorPtr result; try { - result = builder.copyResults(pool.get()); + result = builder.copyResults(pool_.get()); } catch (VeloxRuntimeError& e) { if (injectOOM && e.errorCode() == facebook::velox::error_code::kMemCapExceeded && @@ -236,50 +261,30 @@ RowVectorPtr RowNumberFuzzerBase::execute( return result; } -void RowNumberFuzzerBase::testPlan( +void SpillFuzzerBase::testPlan( const PlanWithSplits& plan, int32_t testNumber, const RowVectorPtr& expected, const std::optional& spillConfig) { LOG(INFO) << "Testing plan #" << testNumber; - auto actual = - execute(plan, pool_, /*injectSpill=*/false, FLAGS_enable_oom_injection); - if (actual != nullptr && expected != nullptr) { - VELOX_CHECK( - test::assertEqualResults({expected}, {actual}), - "Logically equivalent plans produced different results"); - } else { - VELOX_CHECK( - FLAGS_enable_oom_injection, "Got unexpected nullptr for results"); - } + // Test without spilling first, then with spilling if enabled. + const bool testSpill = FLAGS_enable_spill; + for (int round = 0; round < (testSpill ? 2 : 1); ++round) { + const bool injectSpill = round == 1; + if (injectSpill) { + LOG(INFO) << "Testing plan #" << testNumber << " with spilling"; + } - if (FLAGS_enable_spill) { - LOG(INFO) << "Testing plan #" << testNumber << " with spilling"; const auto fuzzMaxSpillLevel = FLAGS_max_spill_level == -1 ? randInt(0, 3) : FLAGS_max_spill_level; - actual = execute( + auto actual = execute( plan, - pool_, - /*=injectSpill=*/true, + injectSpill, FLAGS_enable_oom_injection, - spillConfig, - fuzzMaxSpillLevel); - if (actual != nullptr && expected != nullptr) { - try { - VELOX_CHECK( - test::assertEqualResults({expected}, {actual}), - "Logically equivalent plans produced different results"); - } catch (const VeloxException&) { - LOG(ERROR) << "Expected\n" - << expected->toString(0, expected->size()) << "\nActual\n" - << actual->toString(0, actual->size()); - throw; - } - } else { - VELOX_CHECK( - FLAGS_enable_oom_injection, "Got unexpected nullptr for results"); - } + injectSpill ? spillConfig : std::nullopt, + injectSpill ? fuzzMaxSpillLevel : -1); + compareResults(expected, actual); } } diff --git a/velox/exec/fuzzer/RowNumberFuzzerBase.h b/velox/exec/fuzzer/SpillFuzzerBase.h similarity index 83% rename from velox/exec/fuzzer/RowNumberFuzzerBase.h rename to velox/exec/fuzzer/SpillFuzzerBase.h index 1aaf5f4f3b0..4f25116bef1 100644 --- a/velox/exec/fuzzer/RowNumberFuzzerBase.h +++ b/velox/exec/fuzzer/SpillFuzzerBase.h @@ -41,21 +41,31 @@ DECLARE_bool(enable_oom_injection); namespace facebook::velox::exec { -class RowNumberFuzzerBase { +struct PlanWithSplits { + core::PlanNodePtr plan; + std::vector splits; + + explicit PlanWithSplits( + core::PlanNodePtr _plan, + const std::vector& _splits = {}) + : plan(std::move(_plan)), splits(_splits) {} +}; + +class SpillFuzzerBase { public: - explicit RowNumberFuzzerBase( + explicit SpillFuzzerBase( size_t initialSeed, std::unique_ptr); void run(); - virtual ~RowNumberFuzzerBase() = default; + virtual ~SpillFuzzerBase() = default; protected: bool isTableScanSupported(const TypePtr& type); - // Runs one test iteration from query plans generations, executions and result - // verifications. + // Runs one test iteration from query plans generations, executions and + // result verifications. virtual void runSingleIteration() = 0; // Sets up the Dwrf reader/writer, serializers and Hive connector for the @@ -91,27 +101,16 @@ class RowNumberFuzzerBase { const std::vector& input, const RowVectorPtr& result); - struct PlanWithSplits { - core::PlanNodePtr plan; - std::vector splits; - - explicit PlanWithSplits( - core::PlanNodePtr _plan, - const std::vector& _splits = {}) - : plan(std::move(_plan)), splits(_splits) {} - }; - - // Executes a plan with spilling and oom injection possibly. + // Executes a plan with optional spill and OOM injection. RowVectorPtr execute( const PlanWithSplits& plan, - const std::shared_ptr& pool, bool injectSpill, bool injectOOM, const std::optional& spillConfig = std::nullopt, int maxSpillLevel = -1); - // Tests a plan by executing it with and without spilling. OOM injection - // also might be done based on FLAG_enable_oom_injection. + // Tests a plan by executing it with and without spilling. OOM injection is + // controlled by FLAGS_enable_oom_injection. void testPlan( const PlanWithSplits& plan, int32_t testNumber, @@ -123,15 +122,15 @@ class RowNumberFuzzerBase { std::shared_ptr rootPool_{ memory::memoryManager()->addRootPool( - "rowNumberFuzzer", + "spillFuzzer", memory::kMaxMemory, exec::MemoryReclaimer::create())}; std::shared_ptr pool_{rootPool_->addLeafChild( - "rowNumberFuzzerLeaf", + "spillFuzzerLeaf", true, exec::MemoryReclaimer::create())}; std::shared_ptr writerPool_{rootPool_->addAggregateChild( - "rowNumberFuzzerWriter", + "spillFuzzerWriter", exec::MemoryReclaimer::create())}; VectorFuzzer vectorFuzzer_; std::unique_ptr referenceQueryRunner_; diff --git a/velox/exec/fuzzer/TopNRowNumberFuzzer.cpp b/velox/exec/fuzzer/TopNRowNumberFuzzer.cpp index 7d658aa6888..ac66df864c1 100644 --- a/velox/exec/fuzzer/TopNRowNumberFuzzer.cpp +++ b/velox/exec/fuzzer/TopNRowNumberFuzzer.cpp @@ -16,18 +16,21 @@ #include "velox/exec/fuzzer/TopNRowNumberFuzzer.h" +#include +#include #include +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/exec/fuzzer/FuzzerUtil.h" -#include "velox/exec/fuzzer/RowNumberFuzzerBase.h" +#include "velox/exec/fuzzer/SpillFuzzerBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/vector/tests/utils/VectorMaker.h" namespace facebook::velox::exec { +using namespace facebook::velox::common::testutil; namespace { -class TopNRowNumberFuzzer : public RowNumberFuzzerBase { +class TopNRowNumberFuzzer : public SpillFuzzerBase { public: explicit TopNRowNumberFuzzer( size_t initialSeed, @@ -41,6 +44,8 @@ class TopNRowNumberFuzzer : public RowNumberFuzzerBase { std::pair, std::vector> generateKeys( const std::string& prefix); + std::string generateRankFunction(); + std::vector generateInput( const std::vector& keyNames, const std::vector& keyTypes, @@ -49,24 +54,36 @@ class TopNRowNumberFuzzer : public RowNumberFuzzerBase { // Makes the query plan with default settings in TopNRowNumberFuzzer. std::pair makeDefaultPlan( + std::string_view rankFunction, const std::vector& partitionKeys, const std::vector& sortKeys, const std::vector& allKeys, const std::vector& input); PlanWithSplits makePlanWithTableScan( + std::string_view rankFunction, const std::vector& partitionKeys, const std::vector& sortKeys, const std::vector& allKeys, int limit, const std::vector& input, const std::string& tableDir); + + // Makes an alternate plan using WindowNode with the same rank function, + // then filters rows where the rank value is within the limit. + PlanWithSplits makeWindowPlan( + std::string_view rankFunction, + const std::vector& partitionKeys, + const std::vector& sortKeys, + const std::vector& allKeys, + int32_t limit, + const std::vector& input); }; TopNRowNumberFuzzer::TopNRowNumberFuzzer( size_t initialSeed, std::unique_ptr referenceQueryRunner) - : RowNumberFuzzerBase(initialSeed, std::move(referenceQueryRunner)) {} + : SpillFuzzerBase(initialSeed, std::move(referenceQueryRunner)) {} std::pair, std::vector> TopNRowNumberFuzzer::generateKeys(const std::string& prefix) { @@ -93,6 +110,19 @@ TopNRowNumberFuzzer::generateKeys(const std::string& prefix) { return std::make_pair(keys, types); } +std::string TopNRowNumberFuzzer::generateRankFunction() { + int32_t rankFunction = randInt(0, 2); + switch (rankFunction) { + case 0: + return "row_number"; + case 1: + return "rank"; + case 2: + return "dense_rank"; + } + return "row_number"; +} + std::vector TopNRowNumberFuzzer::generateInput( const std::vector& keyNames, const std::vector& keyTypes, @@ -153,12 +183,14 @@ std::vector TopNRowNumberFuzzer::generateInput( // values. This is done to introduce some repetition of key values for // windowing. auto baseVector = vectorFuzzer_.fuzz(keyTypes[i], numPartitions); - children.push_back(BaseVector::wrapInDictionary( - partitionNulls, partitionIndices, size, baseVector)); + children.push_back( + BaseVector::wrapInDictionary( + partitionNulls, partitionIndices, size, baseVector)); } else if (sortingKeySet.find(keyNames[i]) != sortingKeySet.end()) { auto baseVector = vectorFuzzer_.fuzz(keyTypes[i], numPeerGroups); - children.push_back(BaseVector::wrapInDictionary( - sortingNulls, sortingIndices, size, baseVector)); + children.push_back( + BaseVector::wrapInDictionary( + sortingNulls, sortingIndices, size, baseVector)); } else { children.push_back(vectorFuzzer_.fuzz(keyTypes[i], size)); } @@ -171,26 +203,26 @@ std::vector TopNRowNumberFuzzer::generateInput( return input; } -std::pair -TopNRowNumberFuzzer::makeDefaultPlan( +std::pair TopNRowNumberFuzzer::makeDefaultPlan( + std::string_view rankFunction, const std::vector& partitionKeys, const std::vector& sortKeys, const std::vector& allKeys, const std::vector& input) { - auto planNodeIdGenerator = std::make_shared(); std::vector projectFields = allKeys; projectFields.emplace_back("row_number"); int32_t limit = randInt(1, FLAGS_batch_size); auto plan = test::PlanBuilder() .values(input) - .topNRowNumber(partitionKeys, sortKeys, limit, true) + .topNRank(rankFunction, partitionKeys, sortKeys, limit, true) .project(projectFields) .planNode(); return std::make_pair(PlanWithSplits{std::move(plan)}, limit); } -RowNumberFuzzerBase::PlanWithSplits TopNRowNumberFuzzer::makePlanWithTableScan( +PlanWithSplits TopNRowNumberFuzzer::makePlanWithTableScan( + std::string_view rankFunction, const std::vector& partitionKeys, const std::vector& sortKeys, const std::vector& allKeys, @@ -205,7 +237,7 @@ RowNumberFuzzerBase::PlanWithSplits TopNRowNumberFuzzer::makePlanWithTableScan( auto planNodeIdGenerator = std::make_shared(); auto plan = test::PlanBuilder(planNodeIdGenerator) .tableScan(asRowType(input[0]->type())) - .topNRowNumber(partitionKeys, sortKeys, limit, true) + .topNRank(rankFunction, partitionKeys, sortKeys, limit, true) .project(projectFields) .planNode(); @@ -214,6 +246,39 @@ RowNumberFuzzerBase::PlanWithSplits TopNRowNumberFuzzer::makePlanWithTableScan( return PlanWithSplits{plan, splits}; } +PlanWithSplits TopNRowNumberFuzzer::makeWindowPlan( + std::string_view rankFunction, + const std::vector& partitionKeys, + const std::vector& sortKeys, + const std::vector& allKeys, + int32_t limit, + const std::vector& input) { + // Build the window OVER clause with optional PARTITION BY and ORDER BY. + std::string overClause; + if (!partitionKeys.empty()) { + overClause = + fmt::format("partition by {} ", folly::join(", ", partitionKeys)); + } + overClause += fmt::format("order by {}", folly::join(", ", sortKeys)); + + // Alias the window output to "row_number" to match the default plan's output + // schema, regardless of which rank function (row_number, rank, dense_rank) + // is used. + auto windowExpr = + fmt::format("{}() over ({}) AS row_number", rankFunction, overClause); + + std::vector projectFields = allKeys; + projectFields.emplace_back("row_number"); + + auto plan = test::PlanBuilder() + .values(input) + .window({windowExpr}) + .filter(fmt::format("row_number <= {}", limit)) + .project(projectFields) + .planNode(); + return PlanWithSplits{std::move(plan)}; +} + void TopNRowNumberFuzzer::runSingleIteration() { const auto [partitionKeys, partitionTypes] = generateKeys("p"); const auto [sortKeys, sortTypes] = generateKeys("s"); @@ -236,11 +301,12 @@ void TopNRowNumberFuzzer::runSingleIteration() { const auto input = generateInput(allKeys, allTypes, partitionKeys, sortKeys); test::logVectors(input); + const auto rankFunction = generateRankFunction(); + auto [defaultPlan, limit] = - makeDefaultPlan(partitionKeys, allSortKeys, allKeys, input); + makeDefaultPlan(rankFunction, partitionKeys, allSortKeys, allKeys, input); - const auto expected = - execute(defaultPlan, pool_, /*injectSpill=*/false, false); + const auto expected = execute(defaultPlan, /*injectSpill=*/false, false); if (expected != nullptr) { validateExpectedResults(defaultPlan.plan, input, expected); } @@ -248,9 +314,13 @@ void TopNRowNumberFuzzer::runSingleIteration() { std::vector altPlans; altPlans.push_back(std::move(defaultPlan)); - const auto tableScanDir = exec::test::TempDirectoryPath::create(); + altPlans.push_back(makeWindowPlan( + rankFunction, partitionKeys, allSortKeys, allKeys, limit, input)); + + const auto tableScanDir = TempDirectoryPath::create(); if (isTableScanSupported(input[0]->type())) { altPlans.push_back(makePlanWithTableScan( + rankFunction, partitionKeys, allSortKeys, allKeys, diff --git a/velox/exec/fuzzer/TopNRowNumberFuzzerRunner.cpp b/velox/exec/fuzzer/TopNRowNumberFuzzerRunner.cpp index b93b49be367..7368ad94c95 100644 --- a/velox/exec/fuzzer/TopNRowNumberFuzzerRunner.cpp +++ b/velox/exec/fuzzer/TopNRowNumberFuzzerRunner.cpp @@ -22,6 +22,9 @@ #include "velox/exec/fuzzer/FuzzerUtil.h" #include "velox/exec/fuzzer/ReferenceQueryRunner.h" #include "velox/exec/fuzzer/TopNRowNumberFuzzer.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/functions/prestosql/window/WindowFunctionsRegistration.h" +#include "velox/parse/TypeResolver.h" /// TopNRowNumberFuzzerRunner leverages TopNRowNumberFuzzer and VectorFuzzer to /// automatically generate and execute tests. It works as follows: @@ -85,6 +88,9 @@ int main(int argc, char** argv) { // singletons, installing proper signal handlers for better debugging // experience, and initialize glog and gflags. folly::Init init(&argc, &argv); + functions::prestosql::registerAllScalarFunctions(); + parse::registerTypeResolver(); + window::prestosql::registerAllWindowFunctions(); exec::test::setupMemory(FLAGS_allocator_capacity, FLAGS_arbitrator_capacity); std::shared_ptr rootPool{ memory::memoryManager()->addRootPool()}; diff --git a/velox/exec/fuzzer/VeloxQueryRunner.cpp b/velox/exec/fuzzer/VeloxQueryRunner.cpp new file mode 100644 index 00000000000..765dbc2beab --- /dev/null +++ b/velox/exec/fuzzer/VeloxQueryRunner.cpp @@ -0,0 +1,244 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/fuzzer/VeloxQueryRunner.h" + +#include +#include +#include +#include +#include +#include "velox/core/PlanNode.h" +#include "velox/exec/fuzzer/if/gen-cpp2/LocalRunnerService.h" +#include "velox/exec/tests/utils/QueryAssertions.h" +#include "velox/functions/prestosql/types/BingTileType.h" +#include "velox/functions/prestosql/types/GeometryType.h" +#include "velox/functions/prestosql/types/IPAddressType.h" +#include "velox/functions/prestosql/types/IPPrefixType.h" +#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" +#include "velox/functions/prestosql/types/UuidType.h" +#include "velox/functions/prestosql/types/parser/TypeParser.h" +#include "velox/serializers/PrestoSerializer.h" + +using namespace facebook::velox::runner; + +namespace facebook::velox::exec::test { + +namespace { + +RowTypePtr parseBatchRowType(Batch batch) { + std::vector names; + std::vector types; + + for (const auto& name : *batch.columnNames()) { + names.push_back(name); + } + + // Clean up type strings and format according to TypeParser::parseType + // expectation. Input types are serialized by type()->asRow() for Thrift + // struct by LocalRunnerService in the format of SOME_COMPLEX_TYPE (note the '<' and '>'). And in the case of complex type ROW, + // types follow column name and a semicolon: as an example, ROW. We need to change these particular character choices to + // paranthesis (in the case of the angled brackets) and spaces (in the case of + // the semicolon). As an example, + // MAP -> MAP(VARCHAR, TIMESTAMP) + // ROW -> ROW(f0 TYPE_1, f1 TYPE_2, etc.) + // as expected by TypeParser::parseType. Without this cleanup, parse will + // crash and fuzzer will fail. + for (const auto& typeString : *batch.columnTypes()) { + auto parsedTypeString = typeString; + std::replace(parsedTypeString.begin(), parsedTypeString.end(), '<', '('); + std::replace(parsedTypeString.begin(), parsedTypeString.end(), '>', ')'); + std::replace(parsedTypeString.begin(), parsedTypeString.end(), ':', ' '); + types.push_back(functions::prestosql::parseType(parsedTypeString)); + } + + return ROW(std::move(names), std::move(types)); +} + +std::vector deserializeBatches( + const std::vector& resultBatches, + memory::MemoryPool* pool) { + std::vector queryResults; + + auto serde = std::make_unique(); + serializer::presto::PrestoVectorSerde::PrestoOptions options; + + for (const auto& batch : resultBatches) { + VELOX_CHECK( + apache::thrift::is_non_optional_field_set_manually_or_by_serializer( + batch.serializedData())); + VELOX_CHECK(!batch.serializedData()->empty()); + + // Deserialize binary data. + const auto& serializedData = *batch.serializedData(); + ByteRange byteRange{ + reinterpret_cast(const_cast(serializedData.data())), + static_cast(serializedData.length()), + 0}; + auto byteStream = std::make_unique( + std::vector{{byteRange}}); + + RowVectorPtr rowVector; + serde->deserialize( + byteStream.get(), + pool, + parseBatchRowType(batch), + &rowVector, + 0, + &options); + + VELOX_CHECK_NOT_NULL(rowVector); + queryResults.push_back(rowVector); + } + + return queryResults; +} + +std::shared_ptr> createThriftClient( + const std::string& host, + int port, + std::chrono::milliseconds timeout, + folly::EventBase* evb) { + folly::SocketAddress addr(host, port); + auto socket = folly::AsyncSocket::newSocket(evb, addr, timeout.count()); + auto channel = + apache::thrift::RocketClientChannel::newChannel(std::move(socket)); + return std::make_shared>( + std::move(channel)); +} +} // namespace + +VeloxQueryRunner::VeloxQueryRunner( + memory::MemoryPool* aggregatePool, + std::string serviceUri, + std::chrono::milliseconds timeout) + : ReferenceQueryRunner(aggregatePool), + serviceUri_(std::move(serviceUri)), + timeout_(timeout) { + pool_ = aggregatePool->addLeafChild("leaf"); + + folly::Uri uri(serviceUri_); + thriftHost_ = uri.host(); + thriftPort_ = uri.port(); +} + +const std::vector& VeloxQueryRunner::supportedScalarTypes() const { + static const std::vector kScalarTypes{ + BOOLEAN(), + TINYINT(), + SMALLINT(), + INTEGER(), + BIGINT(), + REAL(), + DOUBLE(), + VARCHAR(), + VARBINARY(), + TIMESTAMP(), + TIMESTAMP_WITH_TIME_ZONE(), + IPADDRESS(), + UUID(), + // https://github.com/facebookincubator/velox/issues/15379 (IPPREFIX) + // https://github.com/facebookincubator/velox/issues/15380 (Non-orderable + // custom types such as HYPERLOGLOG, JSON, BINGTILE, GEOMETRY, etc.) + }; + return kScalarTypes; +} + +const std::unordered_map& +VeloxQueryRunner::aggregationFunctionDataSpecs() const { + static const std::unordered_map + kAggregationFunctionDataSpecs{}; + return kAggregationFunctionDataSpecs; +} + +std::optional VeloxQueryRunner::toSql( + const core::PlanNodePtr& /*plan*/) { + // We don't need to convert to SQL for VeloxQueryRunner + // as we're sending the serialized plan directly + VELOX_FAIL("VeloxQueryRunner does not support SQL conversion"); +} + +bool VeloxQueryRunner::isConstantExprSupported( + const core::TypedExprPtr& /*expr*/) { + // Since we're using Velox directly, we support all constant expressions + return true; +} + +bool VeloxQueryRunner::isSupported( + const exec::FunctionSignature& /*signature*/) { + // Since we're using Velox directly, we support all function signatures + return true; +} + +std::vector VeloxQueryRunner::execute( + const std::string& /*sql*/) { + VELOX_FAIL("VeloxQueryRunner does not support SQL execution"); +} + +std::vector VeloxQueryRunner::execute( + const std::string& /*sql*/, + const std::string& /*sessionProperty*/) { + VELOX_FAIL("VeloxQueryRunner does not support SQL execution"); +} + +std::pair< + std::optional>>, + ReferenceQueryErrorCode> +VeloxQueryRunner::execute(const core::PlanNodePtr& plan) { + auto serializedPlan = serializePlan(plan); + auto queryId = fmt::format("velox_local_query_runner_{}", rand()); + + auto client = + createThriftClient(thriftHost_, thriftPort_, timeout_, &eventBase_); + + // Create the request + ExecutePlanRequest request; + request.serializedPlan() = serializedPlan; + request.queryId() = queryId; + request.numWorkers() = 4; // Default value + request.numDrivers() = 2; // Default value + + // Send the request + ExecutePlanResponse response; + try { + client->sync_execute(response, request); + } catch (const std::exception& e) { + VELOX_FAIL("Thrift request failed: {}", e.what()); + } + + // Handle the response + if (*response.success()) { + LOG(INFO) << "Reference eval succeeded."; + return std::make_pair( + exec::test::materialize( + deserializeBatches(*response.results(), pool_.get())), + ReferenceQueryErrorCode::kSuccess); + } else { + LOG(INFO) << "Reference eval failed."; + return std::make_pair( + std::nullopt, ReferenceQueryErrorCode::kReferenceQueryFail); + } +} + +std::string VeloxQueryRunner::serializePlan(const core::PlanNodePtr& plan) { + // Serialize the plan to JSON + folly::dynamic serializedPlan = plan->serialize(); + return folly::toJson(serializedPlan); +} + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/fuzzer/VeloxQueryRunner.h b/velox/exec/fuzzer/VeloxQueryRunner.h new file mode 100644 index 00000000000..67826c62bce --- /dev/null +++ b/velox/exec/fuzzer/VeloxQueryRunner.h @@ -0,0 +1,77 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include "velox/exec/fuzzer/ReferenceQueryRunner.h" +#include "velox/vector/ComplexVector.h" + +namespace facebook::velox::exec::test { + +class VeloxQueryRunner : public ReferenceQueryRunner { + public: + /// @param serviceUri Thrift URI of the LocalRunnerService. + /// @param timeout Timeout in milliseconds of a request. + VeloxQueryRunner( + memory::MemoryPool* aggregatePool, + std::string serviceUri, + std::chrono::milliseconds timeout); + + RunnerType runnerType() const override { + return RunnerType::kVeloxQueryRunner; + } + + const std::vector& supportedScalarTypes() const override; + + const std::unordered_map& + aggregationFunctionDataSpecs() const override; + + std::optional toSql(const core::PlanNodePtr& plan) override; + + bool isConstantExprSupported(const core::TypedExprPtr& expr) override; + + bool isSupported(const exec::FunctionSignature& signature) override; + + std::pair< + std::optional>>, + ReferenceQueryErrorCode> + execute(const core::PlanNodePtr& plan) override; + + bool supportsVeloxVectorResults() const override { + return true; + } + + std::vector execute(const std::string& sql) override; + + std::vector execute( + const std::string& sql, + const std::string& sessionProperty) override; + + private: + // Serializes the plan node to JSON string + std::string serializePlan(const core::PlanNodePtr& plan); + + std::string serviceUri_; + std::chrono::milliseconds timeout_; + folly::EventBase eventBase_; + std::shared_ptr pool_; + + // Thrift-specific members + std::string thriftHost_; + int thriftPort_{9091}; +}; + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/fuzzer/WindowFuzzer.cpp b/velox/exec/fuzzer/WindowFuzzer.cpp index 81eddf15976..10160ad2029 100644 --- a/velox/exec/fuzzer/WindowFuzzer.cpp +++ b/velox/exec/fuzzer/WindowFuzzer.cpp @@ -17,8 +17,8 @@ #include "velox/exec/fuzzer/WindowFuzzer.h" #include +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/expression/ScopedVarSetter.h" DEFINE_bool( @@ -28,6 +28,8 @@ DEFINE_bool( namespace facebook::velox::exec::test { +using namespace facebook::velox::common::testutil; + namespace { bool supportIgnoreNulls(const std::string& name) { @@ -81,12 +83,20 @@ std::string WindowFuzzer::generateKRowsFrameBound( constexpr int64_t kMax = std::numeric_limits::max(); constexpr int64_t kMin = std::numeric_limits::min(); - // For frames with kPreceding, kFollowing bounds, pick a valid k, in the - // range of 1 to 10, 70% of times. Test for random k values remaining times. - int64_t minKValue, maxKValue; + // For frames with kPreceding, kFollowing bounds: + // - 70%: pick a valid k in the range [1, 10]. + // - 20%: pick a valid k in the full positive range [1, INT64_MAX]. + // - 10%: pick from the full range [INT64_MIN, INT64_MAX] to exercise + // invalid (negative/zero) offsets and verify both Velox and the reference + // DB reject them consistently. + int64_t minKValue; + int64_t maxKValue; if (vectorFuzzer_.coinToss(0.7)) { minKValue = 1; maxKValue = 10; + } else if (vectorFuzzer_.coinToss(2.0 / 3.0)) { + minKValue = 1; + maxKValue = kMax; } else { minKValue = kMin; maxKValue = kMax; @@ -609,8 +619,9 @@ void WindowFuzzer::testAlternativePlans( allKeys.emplace_back(key + " NULLS FIRST"); } for (const auto& keyAndOrder : sortingKeysAndOrders) { - allKeys.emplace_back(fmt::format( - "{} {}", keyAndOrder.key_, keyAndOrder.sortOrder_.toString())); + allKeys.emplace_back( + fmt::format( + "{} {}", keyAndOrder.key_, keyAndOrder.sortOrder_.toString())); } // Streaming window from values. @@ -627,7 +638,7 @@ void WindowFuzzer::testAlternativePlans( } // With TableScan. - auto directory = exec::test::TempDirectoryPath::create(); + auto directory = TempDirectoryPath::create(); const auto inputRowType = asRowType(input[0]->type()); if (isTableScanSupported(inputRowType)) { auto splits = makeSplits(input, directory->getPath(), writerPool_); diff --git a/velox/exec/fuzzer/WriterFuzzer.cpp b/velox/exec/fuzzer/WriterFuzzer.cpp index 9c2132c7403..71894a2fe55 100644 --- a/velox/exec/fuzzer/WriterFuzzer.cpp +++ b/velox/exec/fuzzer/WriterFuzzer.cpp @@ -16,6 +16,7 @@ #include "velox/exec/fuzzer/WriterFuzzer.h" #include +#include #include #include @@ -24,6 +25,7 @@ #include "velox/common/encode/Base64.h" #include "velox/common/file/FileSystems.h" #include "velox/common/file/tests/FaultyFileSystem.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/connectors/hive/HiveConnector.h" #include "velox/connectors/hive/HiveConnectorSplit.h" #include "velox/connectors/hive/TableHandle.h" @@ -32,7 +34,6 @@ #include "velox/exec/fuzzer/PrestoQueryRunner.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/expression/fuzzer/FuzzerToolkit.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" #include "velox/vector/VectorSaver.h" @@ -58,6 +59,13 @@ DEFINE_int32( DEFINE_int32(num_batches, 10, "The number of generated vectors."); +DEFINE_string( + max_target_file_size, + "", + "Maximum target file size for testing file rotation. " + "If empty, randomly selects from various sizes (10KB-5MB) to test edge cases. " + "Set to a specific value like '1MB' to use a fixed size."); + DEFINE_double( null_ratio, 0.1, @@ -68,6 +76,7 @@ using namespace facebook::velox::connector::hive; using namespace facebook::velox::test; namespace facebook::velox::exec::test { +using namespace facebook::velox::common::testutil; namespace { using facebook::velox::filesystems::FileSystem; @@ -148,6 +157,17 @@ class WriterFuzzer { const std::vector>& sortBy, const std::shared_ptr& outputDirectoryPath); + // Verifies file rotation by writing with a small max target file size. + // This specifically tests the file split logic for non-bucketed, + // non-sorted tables. + void verifyFileRotation( + const std::vector& input, + const std::vector& names, + const std::vector& types, + int32_t partitionOffset, + const std::vector& partitionKeys, + const std::shared_ptr& outputDirectoryPath); + // Generates table column handles based on table column properties connector::ColumnHandleMap getTableColumnHandles( const std::vector& names, @@ -159,7 +179,8 @@ class WriterFuzzer { RowVectorPtr execute( const core::PlanNodePtr& plan, const int32_t maxDrivers = 2, - const std::vector& splits = {}); + const std::vector& splits = {}, + const std::string& maxTargetFileSizeBytes = "0B"); RowVectorPtr veloxToPrestoResult(const RowVectorPtr& result); @@ -247,7 +268,7 @@ class WriterFuzzer { }; // Supported partition key column types - // According to VectorHasher::typeKindSupportsValueIds and + // According to VectorHasher::typeSupportsValueIds and // https://github.com/prestodb/presto/blob/10143be627beb2c61aba5b3d36af473d2a8ef65e/presto-hive/src/main/java/com/facebook/presto/hive/HiveUtil.java#L593 const std::vector kPartitionKeyTypes_{ BOOLEAN(), @@ -264,6 +285,10 @@ class WriterFuzzer { const std::string injectedErrorMsg_{"Injected Faulty File Error"}; std::atomic injectedErrorCount_{0}; + // Current max target file size for file rotation testing. + // Randomized per iteration to test various file split scenarios. + std::string currentMaxTargetFileSize_; + FuzzerGenerator rng_; size_t currentSeed_{0}; std::unique_ptr referenceQueryRunner_; @@ -332,8 +357,20 @@ void WriterFuzzer::go() { } while (!isDone(iteration, startTime)) { + // Randomize max target file size for file rotation testing. + // Use small sizes to trigger file splits. Include very small sizes + // (10KB, 50KB) to force many splits and test edge cases. + const std::vector fileSizes = { + "10KB", "50KB", "100KB", "500KB", "1MB", "2MB", "5MB"}; + auto fileSizeIdx = boost::random::uniform_int_distribution( + 0, fileSizes.size() - 1)(rng_); + currentMaxTargetFileSize_ = FLAGS_max_target_file_size.empty() + ? fileSizes[fileSizeIdx] + : FLAGS_max_target_file_size; + LOG(INFO) << "==============================> Started iteration " - << iteration << " (seed: " << currentSeed_ << ")"; + << iteration << " (seed: " << currentSeed_ + << ", maxTargetFileSize: " << currentMaxTargetFileSize_ << ")"; std::vector names; std::vector types; @@ -364,11 +401,12 @@ void WriterFuzzer::go() { sortColumnOffset -= offset; sortBy.reserve(sortColumns.size()); for (const auto& sortByColumn : sortColumns) { - sortBy.push_back(std::make_shared( - sortByColumn, - kSortOrderTypes_.at( - boost::random::uniform_int_distribution( - 0, 1)(rng_)))); + sortBy.push_back( + std::make_shared( + sortByColumn, + kSortOrderTypes_.at( + boost::random::uniform_int_distribution( + 0, 1)(rng_)))); } } } @@ -379,8 +417,8 @@ void WriterFuzzer::go() { } auto input = generateInputData(names, types, partitionOffset); - const auto outputDirPath = exec::test::TempDirectoryPath::create( - FLAGS_file_system_error_injection); + const auto outputDirPath = + TempDirectoryPath::create(FLAGS_file_system_error_injection); verifyWriter( input, @@ -394,6 +432,20 @@ void WriterFuzzer::go() { sortBy, outputDirPath); + // Test file rotation for non-bucketed, non-sorted tables. + // File rotation only works when bucketCount == 0 and sortBy is empty. + if (bucketCount == 0 && sortBy.empty()) { + const auto fileRotationOutputDirPath = + TempDirectoryPath::create(FLAGS_file_system_error_injection); + verifyFileRotation( + input, + names, + types, + partitionOffset, + partitionKeys, + fileRotationOutputDirPath); + } + LOG(INFO) << "==============================> Done with iteration " << iteration++; reSeed(); @@ -483,8 +535,9 @@ std::vector WriterFuzzer::generateInputData( partitionValues.at(j - partitionOffset), size)); } } - input.push_back(std::make_shared( - pool_.get(), inputType, nullptr, size, std::move(children))); + input.push_back( + std::make_shared( + pool_.get(), inputType, nullptr, size, std::move(children))); } return input; @@ -532,7 +585,8 @@ void WriterFuzzer::verifyWriter( } try { - referenceQueryRunner_->execute("DROP TABLE IF EXISTS tmp_write"); + referenceQueryRunner_->execute( + "DROP TABLE IF EXISTS " + ReferenceQueryRunner::getWriteTableName()); } catch (...) { LOG(WARNING) << "Drop table query failed in the reference DB"; return; @@ -578,7 +632,8 @@ void WriterFuzzer::verifyWriter( } try { auto referenceData = referenceQueryRunner_->execute( - "SELECT *" + bucketSql + " FROM tmp_write"); + "SELECT *" + bucketSql + " FROM " + + ReferenceQueryRunner::getWriteTableName()); VELOX_CHECK( assertEqualResults(referenceData, {actual}), "Velox and reference DB results don't match"); @@ -635,6 +690,134 @@ void WriterFuzzer::verifyWriter( LOG(INFO) << "Verified results against reference DB"; } +void WriterFuzzer::verifyFileRotation( + const std::vector& input, + const std::vector& names, + const std::vector& types, + const int32_t partitionOffset, + const std::vector& partitionKeys, + const std::shared_ptr& outputDirectoryPath) { + // Create a non-bucketed, non-sorted table write plan. + // File rotation only works for non-bucketed, non-sorted tables. + const auto plan = PlanBuilder() + .values(input) + .tableWrite( + outputDirectoryPath->getPath(), + partitionKeys, + dwio::common::FileFormat::DWRF) + .planNode(); + + const auto maxDrivers = + boost::random::uniform_int_distribution(1, 16)(rng_); + RowVectorPtr result; + const uint64_t prevInjectedErrorCount = injectedErrorCount_; + try { + result = veloxToPrestoResult( + execute(plan, maxDrivers, {}, currentMaxTargetFileSize_)); + } catch (VeloxRuntimeError& error) { + if (injectedErrorCount_ == prevInjectedErrorCount) { + throw error; + } + VELOX_CHECK_GT( + injectedErrorCount_, + prevInjectedErrorCount, + "Unexpected writer fuzzer failure: {}", + error.message()); + VELOX_CHECK_EQ( + error.message(), injectedErrorMsg_, "Unexpected writer fuzzer failure"); + return; + } + + const auto outputPath = outputDirectoryPath->getDelegatePath(); + + // 1. Count the number of files created to verify file rotation occurred. + const auto partitionNameAndFileCount = + getPartitionNameAndFilecount(outputPath); + int32_t totalFileCount = 0; + for (const auto& [partitionName, fileCount] : partitionNameAndFileCount) { + totalFileCount += fileCount; + } + LOG(INFO) << "File rotation: " << totalFileCount + << " files created with max target file size " + << currentMaxTargetFileSize_; + + // 2. Verify the written data by reading it back. + auto splits = makeSplits(outputPath); + auto columnHandles = + getTableColumnHandles(names, types, partitionOffset, /*bucketCount=*/0); + const auto rowType = generateOutputType(names, types, /*bucketCount=*/0); + + auto readPlan = PlanBuilder() + .tableScan(rowType, {}, "", rowType, columnHandles) + .planNode(); + auto actual = execute(readPlan, maxDrivers, splits); + + // 3. Compare row count with input. + int64_t expectedRowCount = 0; + for (const auto& batch : input) { + expectedRowCount += batch->size(); + } + VELOX_CHECK_EQ( + actual->size(), + expectedRowCount, + "File rotation: Row count mismatch. Expected {}, got {}", + expectedRowCount, + actual->size()); + + // 4. Verify the actual data content matches the input. + std::vector expectedBatches; + expectedBatches.reserve(input.size()); + for (const auto& batch : input) { + std::vector children; + children.reserve(batch->childrenSize()); + for (int32_t i = 0; i < partitionOffset; ++i) { + children.push_back(batch->childAt(i)); + } + for (int32_t i = partitionOffset; i < batch->childrenSize(); ++i) { + children.push_back(batch->childAt(i)); + } + expectedBatches.push_back( + std::make_shared( + pool_.get(), rowType, nullptr, batch->size(), std::move(children))); + } + + VELOX_CHECK( + assertEqualResults(expectedBatches, {actual}), + "File rotation: Data content mismatch between written and read-back data"); + + // 5. Compare with Presto as the source of truth. + try { + referenceQueryRunner_->execute( + "DROP TABLE IF EXISTS " + ReferenceQueryRunner::getWriteTableName()); + } catch (...) { + LOG(WARNING) << "Drop table query failed in the reference DB"; + return; + } + + auto prestoResult = referenceQueryRunner_->execute(plan); + if (!prestoResult.first.has_value()) { + LOG(WARNING) << "Presto write query failed, skipping comparison"; + return; + } + + VELOX_CHECK( + assertEqualResults(*prestoResult.first, plan->outputType(), {result}), + "File rotation: Velox and Presto row counts don't match"); + + try { + auto prestoData = referenceQueryRunner_->execute( + "SELECT * FROM " + ReferenceQueryRunner::getWriteTableName()); + VELOX_CHECK( + assertEqualResults(prestoData, {actual}), + "File rotation: Velox and Presto data don't match"); + } catch (...) { + LOG(WARNING) << "Query failed in the reference DB"; + return; + } + + LOG(INFO) << "File rotation verification succeeded"; +} + connector::ColumnHandleMap WriterFuzzer::getTableColumnHandles( const std::vector& names, const std::vector& types, @@ -642,11 +825,11 @@ connector::ColumnHandleMap WriterFuzzer::getTableColumnHandles( const int32_t bucketCount) { connector::ColumnHandleMap columnHandle; for (int i = 0; i < names.size(); ++i) { - HiveColumnHandle::ColumnType columnType; + FileColumnHandle::ColumnType columnType; if (i < partitionOffset) { - columnType = HiveColumnHandle::ColumnType::kRegular; + columnType = FileColumnHandle::ColumnType::kRegular; } else { - columnType = HiveColumnHandle::ColumnType::kPartitionKey; + columnType = FileColumnHandle::ColumnType::kPartitionKey; } columnHandle.insert( {names.at(i), @@ -659,7 +842,7 @@ connector::ColumnHandleMap WriterFuzzer::getTableColumnHandles( {"$bucket", std::make_shared( "$bucket", - HiveColumnHandle::ColumnType::kSynthesized, + FileColumnHandle::ColumnType::kSynthesized, INTEGER(), INTEGER())}); } @@ -669,7 +852,8 @@ connector::ColumnHandleMap WriterFuzzer::getTableColumnHandles( RowVectorPtr WriterFuzzer::execute( const core::PlanNodePtr& plan, const int32_t maxDrivers, - const std::vector& splits) { + const std::vector& splits, + const std::string& maxTargetFileSizeBytes) { LOG(INFO) << "Executing query plan: " << std::endl << plan->toString(true, true); fuzzer::ResultOrError resultOrError; @@ -682,6 +866,10 @@ RowVectorPtr WriterFuzzer::execute( kHiveConnectorId, connector::hive::HiveConfig::kMaxPartitionsPerWritersSession, "400") + .connectorSessionProperty( + kHiveConnectorId, + connector::hive::HiveConfig::kMaxTargetFileSizeSession, + maxTargetFileSizeBytes) .copyResults(pool_.get()); } @@ -702,10 +890,11 @@ RowVectorPtr WriterFuzzer::veloxToPrestoResult(const RowVectorPtr& result) { } std::string WriterFuzzer::getReferenceOutputDirectoryPath(int32_t layers) { - auto filePath = - referenceQueryRunner_->execute("SELECT \"$path\" FROM tmp_write"); + auto filePath = referenceQueryRunner_->execute( + "SELECT \"$path\" FROM " + ReferenceQueryRunner::getWriteTableName()); + auto stringView = extractSingleValue(filePath); auto tableDirectoryPath = - fs::path(extractSingleValue(filePath)).parent_path(); + fs::path(std::string_view(stringView)).parent_path(); while (layers-- > 0) { tableDirectoryPath = tableDirectoryPath.parent_path(); } @@ -853,8 +1042,8 @@ std::string WriterFuzzer::sortSql( } selectedColumns << sortBy.at(i)->sortColumn(); } - return "SELECT " + selectedColumns.str() + " FROM tmp_write " + - whereSql.str(); + return "SELECT " + selectedColumns.str() + " FROM " + + ReferenceQueryRunner::getWriteTableName() + " " + whereSql.str(); } std::string WriterFuzzer::partitionToSql( diff --git a/velox/exec/fuzzer/WriterFuzzerRunner.h b/velox/exec/fuzzer/WriterFuzzerRunner.h index fbe76685c89..97b240732b3 100644 --- a/velox/exec/fuzzer/WriterFuzzerRunner.h +++ b/velox/exec/fuzzer/WriterFuzzerRunner.h @@ -24,6 +24,7 @@ #include "velox/common/file/FileSystems.h" #include "velox/common/file/tests/FaultyFileSystem.h" +#include "velox/connectors/ConnectorRegistry.h" #include "velox/connectors/hive/HiveConnector.h" #include "velox/dwio/common/FileSink.h" #include "velox/dwio/common/tests/FaultyFileSink.h" @@ -82,7 +83,8 @@ class WriterFuzzerRunner { kHiveConnectorId, std::make_shared( std::unordered_map())); - connector::registerConnector(hiveConnector); + connector::ConnectorRegistry::global().insert( + hiveConnector->connectorId(), hiveConnector); dwrf::registerDwrfReaderFactory(); dwrf::registerDwrfWriterFactory(); dwio::common::registerFileSinks(); diff --git a/velox/exec/fuzzer/if/LocalRunnerService.thrift b/velox/exec/fuzzer/if/LocalRunnerService.thrift new file mode 100644 index 00000000000..5e28b5f0151 --- /dev/null +++ b/velox/exec/fuzzer/if/LocalRunnerService.thrift @@ -0,0 +1,58 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file defines a Thrift service for executing Velox query plans remotely. +// Results are returned using Presto's binary serialization format for efficient +// data transfer. + +namespace cpp2 facebook.velox.runner + +// Represents a batch of rows using Presto's binary serialization format. +// The serialized data can be deserialized using PrestoVectorSerde to reconstruct +// the original RowVector. +struct Batch { + // Binary serialized RowVector data in Presto format + 1: binary serializedData; + // Column names in the RowVector + 2: list columnNames; + // Column type strings in the RowVector + 3: list columnTypes; +} + +// Request to execute a serialized Velox query plan. +struct ExecutePlanRequest { + 1: string serializedPlan; + 2: string queryId; + 3: i32 numWorkers = 4; + 4: i32 numDrivers = 2; +} + +// Response from executing a query plan. +struct ExecutePlanResponse { + 1: list results; + 2: string output; + 3: bool success; + 4: optional string errorMessage; +} + +// Service for executing Velox query plans locally. +// This service enables remote execution of serialized query plans with +// configurable parallelism, returning results in a structured format. +service LocalRunnerService { + // Inputs a Thrift request and executes a serialized Velox query plan and + // returns the results as a Thrift response. + ExecutePlanResponse execute(1: ExecutePlanRequest request); +} diff --git a/velox/exec/fuzzer/tests/CMakeLists.txt b/velox/exec/fuzzer/tests/CMakeLists.txt index b7478d931a2..a77e6691a5c 100644 --- a/velox/exec/fuzzer/tests/CMakeLists.txt +++ b/velox/exec/fuzzer/tests/CMakeLists.txt @@ -16,3 +16,24 @@ add_executable(presto_sql_test PrestoSqlTest.cpp) add_test(presto_sql_test presto_sql_test) target_link_libraries(presto_sql_test velox_fuzzer_util velox_presto_types) + +# LocalRunnerService Test (requires FBThrift support) +if(VELOX_ENABLE_REMOTE_FUNCTIONS) + add_executable(local_runner_service_test LocalRunnerServiceTest.cpp) + add_test(local_runner_service_test local_runner_service_test) + + target_link_libraries( + local_runner_service_test + velox_local_runner_service_lib + local_runner_service_thrift + velox_core + velox_type + velox_functions_prestosql + velox_functions_test_lib + velox_vector_test_lib + velox_common_base + Folly::folly + gtest + gtest_main + ) +endif() diff --git a/velox/exec/fuzzer/tests/LocalRunnerServiceTest.cpp b/velox/exec/fuzzer/tests/LocalRunnerServiceTest.cpp new file mode 100644 index 00000000000..97ec6930877 --- /dev/null +++ b/velox/exec/fuzzer/tests/LocalRunnerServiceTest.cpp @@ -0,0 +1,256 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "velox/core/PlanNode.h" +#include "velox/exec/fuzzer/LocalRunnerService.h" +#include "velox/exec/fuzzer/if/gen-cpp2/LocalRunnerService.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" +#include "velox/serializers/PrestoSerializer.h" +#include "velox/type/Type.h" + +using namespace facebook::velox; +using namespace facebook::velox::runner; +using namespace facebook::velox::test; + +namespace facebook::velox::fuzzer::test { +class LocalRunnerServiceTest : public functions::test::FunctionBaseTest { + protected: + void SetUp() override { + Type::registerSerDe(); + core::PlanNode::registerSerDe(); + core::ITypedExpr::registerSerDe(); + functions::prestosql::registerAllScalarFunctions(); + functions::prestosql::registerInternalFunctions(); + + createTestData(); + } + + void createTestData() { + // Create test vectors for different data types + auto rowType = ROW({ + {"bool_col", BOOLEAN()}, + {"int_col", INTEGER()}, + {"bigint_col", BIGINT()}, + {"double_col", DOUBLE()}, + {"varchar_col", VARCHAR()}, + {"timestamp_col", TIMESTAMP()}, + {"array_col", ARRAY(ARRAY(INTEGER()))}, + }); + + testRowVector_ = makeRowVector( + {"bool_col", + "int_col", + "bigint_col", + "double_col", + "varchar_col", + "timestamp_col", + "array_col"}, + { + makeFlatVector( + 10, + [](auto row) { return row % 2 == 0; }, + [](auto row) { return row % 3 == 0; }), + makeFlatVector( + 10, + [](auto row) { return row; }, + [](auto row) { return row % 3 == 0; }), + makeFlatVector( + 10, + [](auto row) { return row; }, + [](auto row) { return row % 3 == 0; }), + makeFlatVector( + 10, + [](auto row) { return row * 1.1; }, + [](auto row) { return row % 3 == 0; }), + makeFlatVector( + 10, + [](auto row) { return fmt::format("str_{}", row); }, + [](auto row) { return row % 3 == 0; }), + makeFlatVector( + 10, + [](auto row) { return facebook::velox::Timestamp(row, 0); }, + [](auto row) { return row % 3 == 0; }), + makeNestedArrayVectorFromJson( + {"[[1, 2]]", + "[[3]]", + "[[4]]", + "[[5, 6]]", + "[[7]]", + "[[8]]", + "[[9]]", + "[[10]]", + "[[11]]", + "[[12]]"}), + }); + + testRowVectorWrapped_ = makeRowVector( + {"bool_col", + "int_col", + "bigint_col", + "double_col", + "varchar_col", + "timestamp_col", + "array_col"}, + { + makeFlatVector( + 5, + [](auto row) { return row % 2 == 0; }, + [](auto row) { return row % 3 == 0; }), + wrapInDictionary( + makeIndices(5, [](auto row) { return (row * 17 + 3) % 10; }), + 5, + makeFlatVector( + 10, + [](auto row) { return row; }, + [](auto row) { return row % 3 == 0; })), + BaseVector::wrapInConstant( + 5, + 0, + makeFlatVector( + 10, + [](auto row) { return row; }, + [](auto row) { return row % 3 == 0; })), + makeFlatVector( + 5, + [](auto row) { return row * 1.1; }, + [](auto row) { return row % 3 == 0; }), + makeFlatVector( + 5, + [](auto row) { return fmt::format("str_{}", row); }, + [](auto row) { return row % 3 == 0; }), + makeFlatVector( + 5, + [](auto row) { return facebook::velox::Timestamp(row, 0); }, + [](auto row) { return row % 3 == 0; }), + makeNestedArrayVectorFromJson( + {"[[1, 2]]", + "[[3]]", + "[[4]]", + "[[5, 6]]", + "[[7]]", + "[[8]]", + "[[9]]", + "[[10]]", + "[[11]]", + "[[12]]"}), + }); + } + + RowVectorPtr testRowVector_; + RowVectorPtr testRowVectorWrapped_; +}; + +TEST_F(LocalRunnerServiceTest, ConvertToBatchesRoundTrip) { + auto result = facebook::velox::runner::convertToBatches( + {testRowVector_}, rootPool_.get()); + + ASSERT_EQ(result.size(), 1); + ASSERT_EQ(result[0].columnNames()->size(), 7); + ASSERT_EQ(result[0].columnTypes()->size(), 7); + + // Verify serializedData is present + ASSERT_GT(result[0].serializedData()->size(), 0); + + // Deserialize and verify + auto leafPool = rootPool_->addLeafChild("deserialize"); + auto serde = std::make_unique< + facebook::velox::serializer::presto::PrestoVectorSerde>(); + facebook::velox::serializer::presto::PrestoVectorSerde::PrestoOptions options; + + const auto& serializedData = *result[0].serializedData(); + ByteRange byteRange{ + reinterpret_cast(const_cast(serializedData.data())), + static_cast(serializedData.length()), + 0}; + auto byteStream = std::make_unique( + std::vector{{byteRange}}); + + RowVectorPtr deserialized; + serde->deserialize( + byteStream.get(), + leafPool.get(), + asRowType(testRowVector_->type()), + &deserialized, + 0, + &options); + + ASSERT_NE(deserialized, nullptr); + ASSERT_EQ(deserialized->size(), testRowVector_->size()); + ASSERT_EQ(deserialized->childrenSize(), testRowVector_->childrenSize()); + + assertEqualVectors(deserialized, testRowVector_); +} + +TEST_F(LocalRunnerServiceTest, ServiceHandlerMockRequestIntegration) { + LocalRunnerServiceHandler handler; + + auto request = std::make_unique(); + // Serialized plan for the following: + // expressions: (p0:DOUBLE, plus(null,0.1646418017335236)) + request->serializedPlan() = + R"({"names":["p0","p1"],"id":"project","name":"ProjectNode","sources":[{"name":"ProjectNode","id":"transform","projections":[{"name":"FieldAccessTypedExpr","type":{"name":"Type","type":"BIGINT"},"inputs":[{"name":"InputTypedExpr","type":{"type":"ROW","name":"Type","names":["row_number"],"cTypes":[{"name":"Type","type":"BIGINT"}]}}],"fieldName":"row_number"}],"names":["row_number"],"sources":[{"name":"ValuesNode","id":"efb6650a_8541_4214_82dd_9792a4965380","data":"AAAAAF4AAAB7ImNUeXBlcyI6W3sidHlwZSI6IkJJR0lOVCIsIm5hbWUiOiJUeXBlIn1dLCJuYW1lcyI6WyJyb3dfbnVtYmVyIl0sInR5cGUiOiJST1ciLCJuYW1lIjoiVHlwZSJ9AQAAAAABAAAAAQAAAAAfAAAAeyJ0eXBlIjoiQklHSU5UIiwibmFtZSI6IlR5cGUifQEAAAAAAQgAAAAAAAAAAAAAAA==","parallelizable":false,"repeatTimes":1}]}],"projections":[{"name":"CallTypedExpr","type":{"name":"Type","type":"DOUBLE"},"functionName":"plus","inputs":[{"name":"ConstantTypedExpr","type":{"name":"Type","type":"DOUBLE"},"valueVector":"AQAAAB8AAAB7InR5cGUiOiJET1VCTEUiLCJuYW1lIjoiVHlwZSJ9AQAAAAE="},{"name":"ConstantTypedExpr","type":{"name":"Type","type":"DOUBLE"},"valueVector":"AQAAAB8AAAB7InR5cGUiOiJET1VCTEUiLCJuYW1lIjoiVHlwZSJ9AQAAAAABAAAAifsSxT8="}]},{"name":"FieldAccessTypedExpr","type":{"name":"Type","type":"BIGINT"},"fieldName":"row_number"}]})"; + request->queryId() = "query1"; + + ExecutePlanResponse response; + handler.execute(response, std::move(request)); + + EXPECT_TRUE(*response.success()); + EXPECT_EQ(response.results()->size(), 1); + + const auto& batch = (*response.results()).front(); + EXPECT_EQ(batch.columnNames()->size(), 2); + EXPECT_EQ((*batch.columnNames())[0], "p0"); + EXPECT_EQ(batch.columnTypes()->size(), 2); + EXPECT_EQ((*batch.columnTypes())[0], "DOUBLE"); + EXPECT_GT(batch.serializedData()->size(), 0); +} + +TEST_F(LocalRunnerServiceTest, ServiceHandlerMockRequestIntegrationFailure) { + LocalRunnerServiceHandler handler; + + auto request = std::make_unique(); + // Serialized plan for the following: + // expressions: (p0:TINYINT, divide(89,"c0") + // Will encounter divide by zero error. + request->serializedPlan() = + R"({"projections":[{"inputs":[{"valueVector":"AQAAACAAAAB7InR5cGUiOiJUSU5ZSU5UIiwibmFtZSI6IlR5cGUifQEAAAAAAVk=","type":{"type":"TINYINT","name":"Type"},"name":"ConstantTypedExpr"},{"fieldName":"c0","type":{"type":"TINYINT","name":"Type"},"name":"FieldAccessTypedExpr"}],"functionName":"divide","type":{"type":"TINYINT","name":"Type"},"name":"CallTypedExpr"},{"fieldName":"row_number","type":{"type":"BIGINT","name":"Type"},"name":"FieldAccessTypedExpr"}],"sources":[{"projections":[{"inputs":[{"type":{"cTypes":[{"type":"TINYINT","name":"Type"},{"type":"BIGINT","name":"Type"}],"names":["c0","row_number"],"type":"ROW","name":"Type"},"name":"InputTypedExpr"}],"fieldName":"c0","type":{"type":"TINYINT","name":"Type"},"name":"FieldAccessTypedExpr"},{"inputs":[{"type":{"cTypes":[{"type":"TINYINT","name":"Type"},{"type":"BIGINT","name":"Type"}],"names":["c0","row_number"],"type":"ROW","name":"Type"},"name":"InputTypedExpr"}],"fieldName":"row_number","type":{"type":"BIGINT","name":"Type"},"name":"FieldAccessTypedExpr"}],"sources":[{"parallelizable":false,"repeatTimes":1,"data":"AAAAAIQAAAB7ImNUeXBlcyI6W3sidHlwZSI6IlRJTllJTlQiLCJuYW1lIjoiVHlwZSJ9LHsidHlwZSI6IkJJR0lOVCIsIm5hbWUiOiJUeXBlIn1dLCJuYW1lcyI6WyJjMCIsInJvd19udW1iZXIiXSwidHlwZSI6IlJPVyIsIm5hbWUiOiJUeXBlIn0KAAAAAAIAAAABAgAAACAAAAB7InR5cGUiOiJUSU5ZSU5UIiwibmFtZSI6IlR5cGUifQoAAAAAKAAAAAMAAAACAAAABgAAAAAAAAABAAAACAAAAAUAAAAAAAAACAAAAAUAAAACAAAAIAAAAHsidHlwZSI6IlRJTllJTlQiLCJuYW1lIjoiVHlwZSJ9CgAAAAECAAAA9/8oAAAACQAAAAQAAAAJAAAAAAAAAAYAAAAHAAAABAAAAAYAAAAAAAAAAAAAAAIAAAAgAAAAeyJ0eXBlIjoiVElOWUlOVCIsIm5hbWUiOiJUeXBlIn0KAAAAAQIAAAD7oigAAAAJAAAAAQAAAAkAAAAHAAAAAAAAAAUAAAAEAAAAAwAAAAEAAAAAAAAAAAAAACAAAAB7InR5cGUiOiJUSU5ZSU5UIiwibmFtZSI6IlR5cGUifQoAAAAAAQoAAABTOkYvJBw5ZUAAAQAAAAAfAAAAeyJ0eXBlIjoiQklHSU5UIiwibmFtZSI6IlR5cGUifQoAAAAAAVAAAAAAAAAAAAAAAAEAAAAAAAAAAgAAAAAAAAADAAAAAAAAAAQAAAAAAAAABQAAAAAAAAAGAAAAAAAAAAcAAAAAAAAACAAAAAAAAAAJAAAAAAAAAA==","id":"d69f11dc_1f0e_40ae_8c5d_2cde4b784a12","name":"ValuesNode"}],"names":["c0","row_number"],"id":"transform","name":"ProjectNode"}],"names":["p0","p1"],"id":"project","name":"ProjectNode"})"; + request->queryId() = "query1"; + + ExecutePlanResponse response; + handler.execute(response, std::move(request)); + + ASSERT_TRUE(response.errorMessage().has_value()); + auto errorMsg = response.errorMessage().value(); + EXPECT_NE(errorMsg.find("Error Source: USER"), std::string::npos); + EXPECT_NE(errorMsg.find("Error Code: ARITHMETIC_ERROR"), std::string::npos); + EXPECT_NE(errorMsg.find("Reason: division by zero"), std::string::npos); + + EXPECT_FALSE(*response.success()); + EXPECT_EQ(response.results()->size(), 0); +} + +} // namespace facebook::velox::fuzzer::test + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + folly::Init init(&argc, &argv); + return RUN_ALL_TESTS(); +} diff --git a/velox/exec/fuzzer/tests/PrestoSqlTest.cpp b/velox/exec/fuzzer/tests/PrestoSqlTest.cpp index e1873d448e8..2ac4d863d8c 100644 --- a/velox/exec/fuzzer/tests/PrestoSqlTest.cpp +++ b/velox/exec/fuzzer/tests/PrestoSqlTest.cpp @@ -18,47 +18,10 @@ #include "velox/common/base/tests/GTestUtils.h" #include "velox/exec/fuzzer/PrestoSql.h" -#include "velox/functions/prestosql/types/JsonType.h" -#include "velox/functions/prestosql/types/QDigestType.h" -#include "velox/functions/prestosql/types/TDigestType.h" -#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" namespace facebook::velox::exec::test { namespace { -TEST(PrestoSqlTest, toTypeSql) { - EXPECT_EQ(toTypeSql(BOOLEAN()), "BOOLEAN"); - EXPECT_EQ(toTypeSql(TINYINT()), "TINYINT"); - EXPECT_EQ(toTypeSql(SMALLINT()), "SMALLINT"); - EXPECT_EQ(toTypeSql(INTEGER()), "INTEGER"); - EXPECT_EQ(toTypeSql(BIGINT()), "BIGINT"); - EXPECT_EQ(toTypeSql(REAL()), "REAL"); - EXPECT_EQ(toTypeSql(DOUBLE()), "DOUBLE"); - EXPECT_EQ(toTypeSql(VARCHAR()), "VARCHAR"); - EXPECT_EQ(toTypeSql(VARBINARY()), "VARBINARY"); - EXPECT_EQ(toTypeSql(TDIGEST(DOUBLE())), "TDIGEST(DOUBLE)"); - EXPECT_EQ(toTypeSql(TIMESTAMP()), "TIMESTAMP"); - EXPECT_EQ(toTypeSql(QDIGEST(DOUBLE())), "QDIGEST(DOUBLE)"); - EXPECT_EQ(toTypeSql(QDIGEST(BIGINT())), "QDIGEST(BIGINT)"); - EXPECT_EQ(toTypeSql(QDIGEST(REAL())), "QDIGEST(REAL)"); - EXPECT_EQ(toTypeSql(DATE()), "DATE"); - EXPECT_EQ(toTypeSql(TIMESTAMP_WITH_TIME_ZONE()), "TIMESTAMP WITH TIME ZONE"); - EXPECT_EQ(toTypeSql(ARRAY(BOOLEAN())), "ARRAY(BOOLEAN)"); - EXPECT_EQ(toTypeSql(MAP(BOOLEAN(), INTEGER())), "MAP(BOOLEAN, INTEGER)"); - EXPECT_EQ( - toTypeSql(ROW({{"a", BOOLEAN()}, {"b", INTEGER()}})), - "ROW(a BOOLEAN, b INTEGER)"); - EXPECT_EQ( - toTypeSql( - ROW({{"a_", BOOLEAN()}, {"b$", INTEGER()}, {"c d", INTEGER()}})), - "ROW(a_ BOOLEAN, b$ INTEGER, c d INTEGER)"); - EXPECT_EQ(toTypeSql(JSON()), "JSON"); - EXPECT_EQ(toTypeSql(UNKNOWN()), "UNKNOWN"); - VELOX_ASSERT_THROW( - toTypeSql(FUNCTION({INTEGER()}, INTEGER())), - "Type is not supported: FUNCTION"); -} - void toUnaryOperator( const std::string& operatorName, const std::string& expectedSql) { @@ -95,11 +58,12 @@ TEST(PrestoSqlTest, toCallSql) { toUnaryOperator("negate", "(- c0)"); toUnaryOperator("not", "(not c0)"); VELOX_ASSERT_THROW( - toCallSql(std::make_shared( - INTEGER(), - "not", - std::make_shared(VARCHAR(), "c0"), - std::make_shared(VARCHAR(), "c1"))), + toCallSql( + std::make_shared( + INTEGER(), + "not", + std::make_shared(VARCHAR(), "c0"), + std::make_shared(VARCHAR(), "c1"))), "Expected one argument to a unary operator"); // Binary operators @@ -116,113 +80,126 @@ TEST(PrestoSqlTest, toCallSql) { toBinaryOperator("gte", "(c0 >= c1)"); toBinaryOperator("distinct_from", "(c0 is distinct from c1)"); VELOX_ASSERT_THROW( - toCallSql(std::make_shared( - INTEGER(), - "plus", - std::make_shared(INTEGER(), "c0"), - std::make_shared(INTEGER(), "c1"), - std::make_shared(INTEGER(), "c3"))), + toCallSql( + std::make_shared( + INTEGER(), + "plus", + std::make_shared(INTEGER(), "c0"), + std::make_shared(INTEGER(), "c1"), + std::make_shared(INTEGER(), "c3"))), "Expected two arguments to a binary operator"); // Functions IS NULL and NOT NULL toIsNullOrIsNotNull("is_null", "(c0 is null)"); toIsNullOrIsNotNull("not_null", "(c0 is not null)"); VELOX_ASSERT_THROW( - toCallSql(std::make_shared( - BOOLEAN(), - "is_null", - std::make_shared(INTEGER(), "c0"), - std::make_shared(INTEGER(), "c1"))), + toCallSql( + std::make_shared( + BOOLEAN(), + "is_null", + std::make_shared(INTEGER(), "c0"), + std::make_shared(INTEGER(), "c1"))), "Expected one argument to function 'is_null' or 'not_null'"); // Function IN EXPECT_EQ( - toCallSql(std::make_shared( - BOOLEAN(), - "in", - std::make_shared(VARCHAR(), "a"), - std::make_shared(VARCHAR(), "b"))), + toCallSql( + std::make_shared( + BOOLEAN(), + "in", + std::make_shared(VARCHAR(), "a"), + std::make_shared(VARCHAR(), "b"))), "'a' in ('b')"); EXPECT_EQ( - toCallSql(std::make_shared( - BOOLEAN(), - "in", - std::make_shared(VARCHAR(), "a"), - std::make_shared(VARCHAR(), "b"), - std::make_shared(VARCHAR(), "c"), - std::make_shared(VARCHAR(), "d"))), + toCallSql( + std::make_shared( + BOOLEAN(), + "in", + std::make_shared(VARCHAR(), "a"), + std::make_shared(VARCHAR(), "b"), + std::make_shared(VARCHAR(), "c"), + std::make_shared(VARCHAR(), "d"))), "'a' in ('b', 'c', 'd')"); VELOX_ASSERT_THROW( - toCallSql(std::make_shared( - BOOLEAN(), - "in", - std::make_shared(VARCHAR(), "a"))), + toCallSql( + std::make_shared( + BOOLEAN(), + "in", + std::make_shared(VARCHAR(), "a"))), "Expected at least two arguments to function 'in'"); // Function LIKE EXPECT_EQ( - toCallSql(std::make_shared( - BOOLEAN(), - "like", - std::make_shared(VARCHAR(), "c0"), - std::make_shared(VARCHAR(), "a"))), + toCallSql( + std::make_shared( + BOOLEAN(), + "like", + std::make_shared(VARCHAR(), "c0"), + std::make_shared(VARCHAR(), "a"))), "(c0 like 'a')"); EXPECT_EQ( - toCallSql(std::make_shared( - BOOLEAN(), - "like", - std::make_shared(VARCHAR(), "c0"), - std::make_shared(VARCHAR(), "a"), - std::make_shared(VARCHAR(), "b"))), + toCallSql( + std::make_shared( + BOOLEAN(), + "like", + std::make_shared(VARCHAR(), "c0"), + std::make_shared(VARCHAR(), "a"), + std::make_shared(VARCHAR(), "b"))), "(c0 like 'a' escape 'b')"); VELOX_ASSERT_THROW( - toCallSql(std::make_shared( - BOOLEAN(), - "like", - std::make_shared(VARCHAR(), "a"))), + toCallSql( + std::make_shared( + BOOLEAN(), + "like", + std::make_shared(VARCHAR(), "a"))), "Expected at least two arguments to function 'like'"); VELOX_ASSERT_THROW( - toCallSql(std::make_shared( - BOOLEAN(), - "like", - std::make_shared(VARCHAR(), "a"), - std::make_shared(VARCHAR(), "b"), - std::make_shared(VARCHAR(), "c"), - std::make_shared(VARCHAR(), "d"))), + toCallSql( + std::make_shared( + BOOLEAN(), + "like", + std::make_shared(VARCHAR(), "a"), + std::make_shared(VARCHAR(), "b"), + std::make_shared(VARCHAR(), "c"), + std::make_shared(VARCHAR(), "d"))), "Expected at most three arguments to function 'like'"); // Functions OR and AND EXPECT_EQ( - toCallSql(std::make_shared( - BOOLEAN(), - "or", - std::make_shared(BOOLEAN(), true), - std::make_shared(BOOLEAN(), false))), + toCallSql( + std::make_shared( + BOOLEAN(), + "or", + std::make_shared(BOOLEAN(), true), + std::make_shared(BOOLEAN(), false))), "(BOOLEAN 'true' or BOOLEAN 'false')"); EXPECT_EQ( - toCallSql(std::make_shared( - BOOLEAN(), - "and", - std::make_shared(BOOLEAN(), true), - std::make_shared(BOOLEAN(), false))), + toCallSql( + std::make_shared( + BOOLEAN(), + "and", + std::make_shared(BOOLEAN(), true), + std::make_shared(BOOLEAN(), false))), "(BOOLEAN 'true' and BOOLEAN 'false')"); EXPECT_EQ( - toCallSql(std::make_shared( - BOOLEAN(), - "or", - std::make_shared(BOOLEAN(), true), - std::make_shared(BOOLEAN(), false), - std::make_shared(BOOLEAN(), true), - std::make_shared(BOOLEAN(), false))), + toCallSql( + std::make_shared( + BOOLEAN(), + "or", + std::make_shared(BOOLEAN(), true), + std::make_shared(BOOLEAN(), false), + std::make_shared(BOOLEAN(), true), + std::make_shared(BOOLEAN(), false))), "(BOOLEAN 'true' or BOOLEAN 'false' or BOOLEAN 'true' or BOOLEAN 'false')"); EXPECT_EQ( - toCallSql(std::make_shared( - BOOLEAN(), - "and", - std::make_shared(BOOLEAN(), true), - std::make_shared(BOOLEAN(), false), - std::make_shared(BOOLEAN(), true), - std::make_shared(BOOLEAN(), false))), + toCallSql( + std::make_shared( + BOOLEAN(), + "and", + std::make_shared(BOOLEAN(), true), + std::make_shared(BOOLEAN(), false), + std::make_shared(BOOLEAN(), true), + std::make_shared(BOOLEAN(), false))), "(BOOLEAN 'true' and BOOLEAN 'false' and BOOLEAN 'true' and BOOLEAN 'false')"); VELOX_ASSERT_THROW( toCallSql(std::make_shared(BOOLEAN(), "or")), @@ -230,25 +207,28 @@ TEST(PrestoSqlTest, toCallSql) { // Functions ARRAY_CONSTRUCTOR and ROW_CONSTRUCTOR EXPECT_EQ( - toCallSql(std::make_shared( - ARRAY(INTEGER()), - "array_constructor", - std::make_shared(VARCHAR(), "a"), - std::make_shared(VARCHAR(), "b"), - std::make_shared(VARCHAR(), "c"))), + toCallSql( + std::make_shared( + ARRAY(INTEGER()), + "array_constructor", + std::make_shared(VARCHAR(), "a"), + std::make_shared(VARCHAR(), "b"), + std::make_shared(VARCHAR(), "c"))), "ARRAY['a', 'b', 'c']"); EXPECT_EQ( - toCallSql(std::make_shared( - ARRAY(INTEGER()), "array_constructor")), + toCallSql( + std::make_shared( + ARRAY(INTEGER()), "array_constructor")), "ARRAY[]"); EXPECT_EQ( - toCallSql(std::make_shared( - BOOLEAN(), - "row_constructor", - std::make_shared(VARCHAR(), "a"), - std::make_shared(VARCHAR(), "b"), - std::make_shared(VARCHAR(), "c"), - std::make_shared(VARCHAR(), "d"))), + toCallSql( + std::make_shared( + BOOLEAN(), + "row_constructor", + std::make_shared(VARCHAR(), "a"), + std::make_shared(VARCHAR(), "b"), + std::make_shared(VARCHAR(), "c"), + std::make_shared(VARCHAR(), "d"))), "row('a', 'b', 'c', 'd')"); VELOX_ASSERT_THROW( toCallSql( @@ -257,27 +237,29 @@ TEST(PrestoSqlTest, toCallSql) { // Function BETWEEN EXPECT_EQ( - toCallSql(std::make_shared( - BOOLEAN(), - "between", - std::make_shared(INTEGER(), "c0"), - std::make_shared(INTEGER(), "c1"), - std::make_shared(INTEGER(), "c2"))), + toCallSql( + std::make_shared( + BOOLEAN(), + "between", + std::make_shared(INTEGER(), "c0"), + std::make_shared(INTEGER(), "c1"), + std::make_shared(INTEGER(), "c2"))), "(c0 between c1 and c2)"); // Edge case check for ambiguous parantheses processing, query will fail // without the parantheses wrapping the left-hand side. EXPECT_EQ( - toCallSql(std::make_shared( - BOOLEAN(), - "lt", + toCallSql( std::make_shared( BOOLEAN(), - "between", - std::make_shared(INTEGER(), "c0"), - std::make_shared(INTEGER(), "c0"), - std::make_shared( - INTEGER(), variant::null(TypeKind::INTEGER))), - std::make_shared(INTEGER(), "c0"))), + "lt", + std::make_shared( + BOOLEAN(), + "between", + std::make_shared(INTEGER(), "c0"), + std::make_shared(INTEGER(), "c0"), + std::make_shared( + INTEGER(), variant::null(TypeKind::INTEGER))), + std::make_shared(INTEGER(), "c0"))), "((c0 between c0 and cast(null as INTEGER)) < c0)"); VELOX_ASSERT_THROW( toCallSql(std::make_shared(BOOLEAN(), "between")), @@ -285,66 +267,74 @@ TEST(PrestoSqlTest, toCallSql) { // Function SUBSCRIPT, builds '[]' SQL EXPECT_EQ( - toCallSql(std::make_shared( - INTEGER(), - "subscript", - std::make_shared( - ARRAY(INTEGER()), "array"), - std::make_shared(INTEGER(), "c0"))), + toCallSql( + std::make_shared( + INTEGER(), + "subscript", + std::make_shared( + ARRAY(INTEGER()), "array"), + std::make_shared(INTEGER(), "c0"))), "array[c0]"); VELOX_ASSERT_THROW( - toCallSql(std::make_shared( - INTEGER(), - "subscript", - std::make_shared( - ARRAY(INTEGER()), "array"), - std::make_shared(INTEGER(), "c0"), - std::make_shared(INTEGER(), "c1"))), + toCallSql( + std::make_shared( + INTEGER(), + "subscript", + std::make_shared( + ARRAY(INTEGER()), "array"), + std::make_shared(INTEGER(), "c0"), + std::make_shared(INTEGER(), "c1"))), "Expected two arguments to function 'subscript'"); // Function SWITCH, builds 'CASE WHEN ... THEN ... ELSE ... END' SQL // SWITCH cases with no ELSE. EXPECT_EQ( - toCallSql(std::make_shared( - INTEGER(), - "switch", - std::make_shared(BOOLEAN(), "c0"), - std::make_shared(VARCHAR(), "c1"))), + toCallSql( + std::make_shared( + INTEGER(), + "switch", + std::make_shared(BOOLEAN(), "c0"), + std::make_shared(VARCHAR(), "c1"))), "case when c0 then c1 end"); EXPECT_EQ( - toCallSql(std::make_shared( - INTEGER(), - "switch", - std::make_shared(BOOLEAN(), "c0"), - std::make_shared(INTEGER(), "c1"), - std::make_shared(BOOLEAN(), "c2"), - std::make_shared(INTEGER(), "c3"))), + toCallSql( + std::make_shared( + INTEGER(), + "switch", + std::make_shared(BOOLEAN(), "c0"), + std::make_shared(INTEGER(), "c1"), + std::make_shared(BOOLEAN(), "c2"), + std::make_shared(INTEGER(), "c3"))), "case when c0 then c1 when c2 then c3 end"); // SWITCH case with ELSE. EXPECT_EQ( - toCallSql(std::make_shared( - INTEGER(), - "switch", - std::make_shared(BOOLEAN(), "c0"), - std::make_shared(INTEGER(), "c1"), - std::make_shared(BOOLEAN(), "c2"), - std::make_shared(INTEGER(), "c3"), - std::make_shared(INTEGER(), "c4"))), + toCallSql( + std::make_shared( + INTEGER(), + "switch", + std::make_shared(BOOLEAN(), "c0"), + std::make_shared(INTEGER(), "c1"), + std::make_shared(BOOLEAN(), "c2"), + std::make_shared(INTEGER(), "c3"), + std::make_shared(INTEGER(), "c4"))), "case when c0 then c1 when c2 then c3 else c4 end"); VELOX_ASSERT_THROW( - toCallSql(std::make_shared( - INTEGER(), - "switch", - std::make_shared(INTEGER(), "c0"))), + toCallSql( + std::make_shared( + INTEGER(), + "switch", + std::make_shared(INTEGER(), "c0"))), "Expected at least two arguments to function 'switch'"); // Generic functions EXPECT_EQ( - toCallSql(std::make_shared( - INTEGER(), - "array_top_n", - std::make_shared(ARRAY(INTEGER()), "c0"), - std::make_shared(INTEGER(), "c1"))), + toCallSql( + std::make_shared( + INTEGER(), + "array_top_n", + std::make_shared( + ARRAY(INTEGER()), "c0"), + std::make_shared(INTEGER(), "c1"))), "array_top_n(c0, c1)"); EXPECT_EQ( toCallSql(std::make_shared(REAL(), "infinity")), @@ -377,10 +367,10 @@ TEST(PrestoSqlTest, toCallInputsSql) { TEST(PrestoSqlTest, toConstantSql) { EXPECT_EQ( toConstantSql(core::ConstantTypedExpr(INTERVAL_YEAR_MONTH(), 123)), - "INTERVAL '123' YEAR TO MONTH"); + "INTERVAL '10-3' YEAR TO MONTH"); EXPECT_EQ( toConstantSql(core::ConstantTypedExpr(INTERVAL_DAY_TIME(), int64_t(123))), - "INTERVAL '123' DAY TO SECOND"); + "INTERVAL '0 00:00:00.123' DAY TO SECOND"); } } // namespace diff --git a/velox/exec/prefixsort/CMakeLists.txt b/velox/exec/prefixsort/CMakeLists.txt index aac7db501ec..9553d268332 100644 --- a/velox/exec/prefixsort/CMakeLists.txt +++ b/velox/exec/prefixsort/CMakeLists.txt @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +velox_add_library(velox_prefixsort INTERFACE HEADERS PrefixSortAlgorithm.h PrefixSortEncoder.h) + if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) elseif(${VELOX_BUILD_TEST_UTILS}) @@ -21,3 +23,5 @@ endif() if(${VELOX_ENABLE_BENCHMARKS}) add_subdirectory(benchmarks) endif() + +velox_install_library_headers() diff --git a/velox/exec/prefixsort/PrefixSortEncoder.h b/velox/exec/prefixsort/PrefixSortEncoder.h index 945b1de5e16..87756792a72 100644 --- a/velox/exec/prefixsort/PrefixSortEncoder.h +++ b/velox/exec/prefixsort/PrefixSortEncoder.h @@ -28,7 +28,7 @@ namespace facebook::velox::exec::prefixsort { class PrefixSortEncoder { public: PrefixSortEncoder(bool ascending, bool nullsFirst) - : ascending_(ascending), nullsFirst_(nullsFirst){}; + : ascending_(ascending), nullsFirst_(nullsFirst) {} /// Encode native primitive types(such as uint64_t, int64_t, uint32_t, /// int32_t, uint16_t, int16_t, float, double, Timestamp). diff --git a/velox/exec/prefixsort/tests/PrefixEncoderTest.cpp b/velox/exec/prefixsort/tests/PrefixEncoderTest.cpp index 398da750b04..ab88faa09cb 100644 --- a/velox/exec/prefixsort/tests/PrefixEncoderTest.cpp +++ b/velox/exec/prefixsort/tests/PrefixEncoderTest.cpp @@ -270,8 +270,7 @@ class PrefixEncoderTest : public testing::Test, const auto rightValue = rightVector->isNullAt(i) ? std::nullopt : std::optional(rightVector->valueAt(i)); - if constexpr ( - Kind == TypeKind::VARCHAR || Kind == TypeKind::VARBINARY) { + if constexpr (is_string_kind(Kind)) { encoder.encode(leftValue, leftEncoded, 17, true); encoder.encode(rightValue, rightEncoded, 17, true); } else { diff --git a/velox/exec/prefixsort/tests/utils/CMakeLists.txt b/velox/exec/prefixsort/tests/utils/CMakeLists.txt index c4691246a33..6b6c790cd83 100644 --- a/velox/exec/prefixsort/tests/utils/CMakeLists.txt +++ b/velox/exec/prefixsort/tests/utils/CMakeLists.txt @@ -13,5 +13,6 @@ # limitations under the License. add_library(velox_exec_prefixsort_test_lib EncoderTestUtils.cpp) +velox_add_test_headers(velox_exec_prefixsort_test_lib EncoderTestUtils.h) target_link_libraries(velox_exec_prefixsort_test_lib velox_vector_test_lib) diff --git a/velox/exec/rpc/CMakeLists.txt b/velox/exec/rpc/CMakeLists.txt new file mode 100644 index 00000000000..dafd5ad6fcd --- /dev/null +++ b/velox/exec/rpc/CMakeLists.txt @@ -0,0 +1,68 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# RPCNode is part of velox_core (in core/PlanNode.h) + +velox_add_library(velox_rpc_state RPCState.cpp HEADERS RPCState.h) + +velox_link_libraries( + velox_rpc_state + velox_rpc_types + velox_common_base + velox_future + velox_vector + Folly::folly +) + +velox_add_library(velox_rpc_rate_limiter RPCRateLimiter.cpp HEADERS RPCRateLimiter.h) + +velox_link_libraries(velox_rpc_rate_limiter velox_future) + +velox_add_library(velox_rpc_operator RPCOperator.cpp HEADERS RPCOperator.h) + +velox_link_libraries( + velox_rpc_operator + velox_rpc_state + velox_rpc_rate_limiter + velox_async_rpc_function + velox_async_rpc_function_registry + velox_exec + velox_expression + velox_vector + velox_buffer + velox_future + Folly::folly +) + +velox_add_library( + velox_rpc_plan_node_translator + RPCPlanNodeTranslator.cpp + HEADERS + RPCPlanNodeTranslator.h +) + +velox_link_libraries(velox_rpc_plan_node_translator velox_rpc_operator velox_exec) + +if(${VELOX_BUILD_TESTING}) + velox_add_library(velox_demo_rpc_function tests/DemoRPCFunction.cpp) + + velox_link_libraries( + velox_demo_rpc_function + velox_mock_rpc_client + velox_expression_functions + velox_async_rpc_function + ) + + add_subdirectory(tests) +endif() diff --git a/velox/exec/rpc/RPCOperator.cpp b/velox/exec/rpc/RPCOperator.cpp new file mode 100644 index 00000000000..b4f903fa50c --- /dev/null +++ b/velox/exec/rpc/RPCOperator.cpp @@ -0,0 +1,683 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/rpc/RPCOperator.h" + +#include "velox/common/time/CpuWallTimer.h" +#include "velox/common/time/Timer.h" +#include "velox/expression/rpc/AsyncRPCFunctionRegistry.h" + +#define RPC_OP_LOG(severity) LOG(severity) << "[RPC_OP] " +#define RPC_OP_VLOG(level) VLOG(level) << "[RPC_OP] " + +namespace facebook::velox::exec::rpc { + +RPCOperator::RPCOperator( + int32_t operatorId, + exec::DriverCtx* driverCtx, + std::shared_ptr rpcNode) + : exec::Operator( + driverCtx, + rpcNode->outputType(), + operatorId, + rpcNode->id(), + "RPC"), + rpcNode_(std::move(rpcNode)), + state_(std::make_shared()), + dispatchBatchSize_(rpcNode_->dispatchBatchSize()) { + // Configure RPCState with streaming mode. + state_->setStreamingMode(rpcNode_->streamingMode()); +} + +void RPCOperator::initialize() { + Operator::initialize(); + + // Resolve the AsyncRPCFunction by name from the registry. + function_ = AsyncRPCFunctionRegistry::create(rpcNode_->functionName()); + VELOX_CHECK( + function_, + "Unknown RPC function '{}'. Ensure it is registered via " + "AsyncRPCFunctionRegistry::registerFunction() before query execution.", + rpcNode_->functionName()); + + // Initialize the function with query config, argument types, and constants. + // The function creates/caches its own transport and clients internally. + function_->initialize( + operatorCtx_->driverCtx()->queryConfig(), + rpcNode_->argumentTypes(), + rpcNode_->constantInputs()); + + tierKey_ = function_->tierKey(); + + RPC_OP_VLOG(1) << "Created operator for function '" + << rpcNode_->functionName() << "', planNodeId=" << planNodeId() + << ", operatorId=" << operatorId() << ", streamingMode=" + << (rpcNode_->streamingMode() == RPCStreamingMode::kBatch + ? "BATCH" + : "PER_ROW"); + + // Precompute argument column indices for addInput(). + const auto& argCols = rpcNode_->argumentColumns(); + if (!argCols.empty()) { + auto sourceType = rpcNode_->source()->outputType(); + argumentColumnIndices_.reserve(argCols.size()); + for (const auto& colName : argCols) { + auto idx = sourceType->getChildIdx(colName); + argumentColumnIndices_.push_back(static_cast(idx)); + } + RPC_OP_VLOG(1) << "Initialized with " << argCols.size() + << " argument columns"; + } else { + RPC_OP_VLOG(1) << "Initialized with no argument columns " + << "(fallback to all input columns)"; + } + + // Precompute output column projections to avoid string lookups in + // buildOutputVector(). + initOutputProjections(); +} + +bool RPCOperator::needsInput() const { + if (noMoreInput_ || isDraining()) { + return false; + } + + // Don't accept input if we have results ready to output. + if (!claimedRows_.empty() || claimedBatch_.has_value()) { + return false; + } + + // Check per-state backpressure. + if (state_->isUnderBackpressure()) { + return false; + } + + return true; +} + +void RPCOperator::addInput(RowVectorPtr input) { + if (!input || input->size() == 0) { + RPC_OP_VLOG(2) << "addInput received empty input"; + return; + } + + RPC_OP_VLOG(1) << "addInput received " << input->size() << " rows with " + << input->childrenSize() << " input columns"; + + SelectivityVector rows(input->size()); + + // Read pre-computed argument columns by precomputed index. + std::vector args; + if (!argumentColumnIndices_.empty()) { + args.reserve(argumentColumnIndices_.size()); + for (auto idx : argumentColumnIndices_) { + args.push_back(input->childAt(idx)); + } + } else { + // Fallback: use all input columns as arguments. + for (auto i = 0; i < input->childrenSize(); ++i) { + args.push_back(input->childAt(i)); + } + } + + // Flatten/load all columns upfront to avoid issues with lazy vectors. + std::vector flattenedColumns; + flattenedColumns.reserve(input->childrenSize()); + for (int32_t j = 0; j < input->childrenSize(); ++j) { + auto column = BaseVector::loadedVectorShared(input->childAt(j)); + BaseVector::flattenVector(column); + flattenedColumns.push_back(column); + } + + auto streamingMode = state_->streamingMode(); + + if (streamingMode == RPCStreamingMode::kPerRow) { + // PER_ROW: function dispatches individual RPCs and returns futures. + auto futures = function_->dispatchPerRow(rows, args); + + auto batchIndex = state_->storeInputBatch( + flattenedColumns, static_cast(futures.size())); + numRequestsDispatched_ += static_cast(futures.size()); + + for (auto& [originalRowIndex, future] : futures) { + auto rowId = globalRowIdCounter_++; + auto token = std::make_shared( + RPCRateLimiter::acquire(tierKey_)); + + auto wrapped = + std::move(future) + .within(kBatchRpcTimeout) + .deferValue([rowId, token](RPCResponse resp) { + resp.rowId = rowId; + return resp; + }) + .deferError([token](folly::exception_wrapper ew) { + return folly::makeSemiFuture(std::move(ew)); + }); + + state_->addPendingRow( + state_, + rowId, + RPCState::RowLocation{batchIndex, originalRowIndex}, + std::move(wrapped)); + } + } else { + // BATCH: function accumulates typed data internally. + auto rowIndices = function_->accumulateBatch(rows, args); + + auto batchIndex = state_->storeInputBatch( + flattenedColumns, static_cast(rowIndices.size())); + numRequestsDispatched_ += static_cast(rowIndices.size()); + + for (auto originalRowIndex : rowIndices) { + auto rowId = globalRowIdCounter_++; + batchRowLocations_.push_back( + RPCState::RowLocation{batchIndex, originalRowIndex}); + batchRowIds_.push_back(rowId); + } + + if (dispatchBatchSize_ > 0 && + function_->pendingBatchSize() >= dispatchBatchSize_) { + // Flush in chunks of dispatchBatchSize_ to avoid sending one + // giant batch_predict call that overwhelms the server. + while (function_->pendingBatchSize() >= dispatchBatchSize_ && + !state_->isUnderBackpressure()) { + flushBatchRequests(dispatchBatchSize_); + } + } + } +} + +void RPCOperator::flushBatchRequests(int32_t maxRows) { + if (function_->pendingBatchSize() == 0) { + VELOX_CHECK( + batchRowLocations_.empty(), + "Operator has {} accumulated batch rows but function reports " + "pendingBatchSize=0. Function must override pendingBatchSize() " + "when using BATCH mode.", + batchRowLocations_.size()); + return; + } + + // Determine how many rows to flush. + auto flushCount = maxRows > 0 + ? std::min(static_cast(batchRowLocations_.size()), maxRows) + : static_cast(batchRowLocations_.size()); + + RPC_OP_LOG(INFO) << "Flushing batch with " << flushCount << " of " + << function_->pendingBatchSize() << " accumulated rows"; + + // Split off the rows to flush. + std::vector rowLocations( + batchRowLocations_.begin(), batchRowLocations_.begin() + flushCount); + std::vector rowIds( + batchRowIds_.begin(), batchRowIds_.begin() + flushCount); + batchRowLocations_.erase( + batchRowLocations_.begin(), batchRowLocations_.begin() + flushCount); + batchRowIds_.erase(batchRowIds_.begin(), batchRowIds_.begin() + flushCount); + + auto future = function_->flushBatch(maxRows); + + // Count each flushBatch() as 1 pending unit in the rate limiter. + auto token = std::make_shared( + RPCRateLimiter::acquire(tierKey_)); + + // Stamp rowIds onto responses. + auto wrapped = + std::move(future) + .within(kBatchRpcTimeout) + .deferValue([rowIds = std::move(rowIds), + token](std::vector resps) { + VELOX_CHECK_EQ( + resps.size(), + rowIds.size(), + "RPC batch response count ({}) does not match row count ({})", + resps.size(), + rowIds.size()); + for (size_t i = 0; i < resps.size(); ++i) { + resps[i].rowId = rowIds[i]; + } + return resps; + }) + .deferError([token](folly::exception_wrapper ew) { + RPC_OP_LOG(ERROR) << "RPC batch failed: " << ew.what(); + return folly::makeSemiFuture>( + std::move(ew)); + }); + + state_->addPendingBatch(state_, std::move(wrapped), std::move(rowLocations)); +} + +void RPCOperator::noMoreInput() { + exec::Operator::noMoreInput(); + + RPC_OP_VLOG(1) << "noMoreInput: totalRequestsDispatched=" + << numRequestsDispatched_; + + if (state_->streamingMode() == RPCStreamingMode::kBatch) { + // Flush any remaining accumulated rows in chunks. + while (function_->pendingBatchSize() > 0) { + flushBatchRequests(dispatchBatchSize_ > 0 ? dispatchBatchSize_ : 0); + } + } + + state_->setNoMoreInput(); +} + +RowVectorPtr RPCOperator::getOutput() { + auto streamingMode = state_->streamingMode(); + + if (streamingMode == RPCStreamingMode::kPerRow) { + if (claimedRows_.empty()) { + // If draining and nothing left to output, check finish. + if (isDraining() && state_->isFinished()) { + finished_ = true; + finishDrain(); + } + return nullptr; + } + + // Drain additional ready rows (non-blocking) for batched output. + // This amortizes RowVector allocation across multiple completed rows. + state_->drainReadyRows(claimedRows_, 1024); + + auto numRows = static_cast(claimedRows_.size()); + for (const auto& row : claimedRows_) { + if (row.response.hasError()) { + numErrors_++; + } + } + auto output = buildOutputFromReadyRows(claimedRows_); + numResponsesCollected_ += numRows; + claimedRows_.clear(); + return output; + } else { + if (!claimedBatch_.has_value()) { + // If draining and nothing left to output, check finish. + if (isDraining() && state_->isFinished()) { + finished_ = true; + finishDrain(); + } + return nullptr; + } + + // Fail loudly on batch errors instead of silently dropping rows. + if (claimedBatch_->error.has_value()) { + auto error = claimedBatch_->error.value(); + claimedBatch_.reset(); + VELOX_FAIL("RPC batch failed: {}", error); + } + + auto numRows = static_cast(claimedBatch_->responses.size()); + for (const auto& response : claimedBatch_->responses) { + if (response.hasError()) { + numErrors_++; + } + } + + // Delegate congestion evaluation to the function. + // The function knows its domain-specific error semantics. + auto signal = function_->evaluateCongestion(claimedBatch_->responses); + if (signal == AsyncRPCFunction::CongestionSignal::kError) { + state_->onBatchError(); + } else if (signal == AsyncRPCFunction::CongestionSignal::kSuccess) { + state_->onBatchSuccess(function_->congestionRecoveryIncrement()); + } + + auto output = buildOutputFromReadyBatch(*claimedBatch_); + numResponsesCollected_ += numRows; + claimedBatch_.reset(); + return output; + } +} + +exec::BlockingReason RPCOperator::isBlocked(ContinueFuture* future) { + // End any previous block wait measurement. + if (blockWaitStartNs_.has_value()) { + auto elapsed = getCurrentTimeNano() - blockWaitStartNs_.value(); + if (blockWaitIsBackpressure_) { + totalBackpressureWaitNanos_ += elapsed; + } else { + totalBlockWaitNanos_ += elapsed; + } + blockWaitStartNs_ = std::nullopt; + } + + // Check per-tier backpressure first. + if (auto backpressureFuture = RPCRateLimiter::checkBackpressure(tierKey_)) { + RPC_OP_VLOG(1) << "Backpressure applied for tier '" << tierKey_ + << "', pending=" << RPCRateLimiter::pendingCount(tierKey_); + *future = std::move(*backpressureFuture); + blockWaitStartNs_ = getCurrentTimeNano(); + blockWaitIsBackpressure_ = true; + return exec::BlockingReason::kWaitForRPC; + } + + // If we already have output ready, don't block. + if (!claimedRows_.empty() || claimedBatch_.has_value()) { + return exec::BlockingReason::kNotBlocked; + } + + // If finished, don't block. + if (finished_) { + return exec::BlockingReason::kNotBlocked; + } + + auto streamingMode = state_->streamingMode(); + + if (streamingMode == RPCStreamingMode::kPerRow) { + if (!noMoreInput_ && !isDraining()) { + auto claimedRow = state_->tryClaimReady(); + if (claimedRow) { + claimedRows_.push_back(std::move(*claimedRow)); + } + return exec::BlockingReason::kNotBlocked; + } + + std::optional claimedRow; + ContinueFuture waitFuture{ContinueFuture::makeEmpty()}; + auto result = state_->tryClaimOrWait(&waitFuture, &claimedRow); + + switch (result) { + case RPCState::ClaimResult::kClaimed: + claimedRows_.push_back(std::move(*claimedRow)); + return exec::BlockingReason::kNotBlocked; + + case RPCState::ClaimResult::kFinished: + finished_ = true; + return exec::BlockingReason::kNotBlocked; + + case RPCState::ClaimResult::kMustWait: + *future = std::move(waitFuture); + blockWaitStartNs_ = getCurrentTimeNano(); + blockWaitIsBackpressure_ = false; + return exec::BlockingReason::kWaitForRPC; + } + } else { + // BATCH mode + if (!noMoreInput_ && !isDraining()) { + auto readyBatch = state_->tryPollReady(); + if (readyBatch) { + if (readyBatch->error.has_value()) { + RPC_OP_LOG(WARNING) + << "Received batch with error: " << readyBatch->error.value(); + } + claimedBatch_ = std::move(*readyBatch); + } + return exec::BlockingReason::kNotBlocked; + } + + std::optional readyBatch; + ContinueFuture waitFuture{ContinueFuture::makeEmpty()}; + auto result = state_->tryPollBatchOrWait(&waitFuture, &readyBatch); + + switch (result) { + case RPCState::BatchPollResult::kGotBatch: + if (readyBatch->error.has_value()) { + RPC_OP_LOG(WARNING) + << "Received batch with error: " << readyBatch->error.value(); + } + claimedBatch_ = std::move(*readyBatch); + return exec::BlockingReason::kNotBlocked; + + case RPCState::BatchPollResult::kFinished: + finished_ = true; + return exec::BlockingReason::kNotBlocked; + + case RPCState::BatchPollResult::kMustWait: + *future = std::move(waitFuture); + blockWaitStartNs_ = getCurrentTimeNano(); + blockWaitIsBackpressure_ = false; + return exec::BlockingReason::kWaitForRPC; + } + } + + VELOX_UNREACHABLE(); +} + +bool RPCOperator::isFinished() { + return finished_ && claimedRows_.empty() && !claimedBatch_.has_value(); +} + +bool RPCOperator::startDrain() { + VELOX_CHECK(isDraining()); + VELOX_CHECK(!noMoreInput_); + + // Flush any undispatched accumulated rows. + if (function_->pendingBatchSize() > 0) { + flushBatchRequests(); + } + + // Signal RPCState that no more rows will be dispatched so it can + // detect the finish condition once in-flight RPCs complete. + state_->setNoMoreInput(); + + // If we have claimed output or pending in-flight RPCs, there is + // buffered data to drain. + if (!claimedRows_.empty() || claimedBatch_.has_value()) { + return true; + } + if (state_ && !state_->isFinished()) { + return true; + } + return false; +} + +void RPCOperator::close() { + recordRuntimeStats(); + + // Release resources explicitly. RPCState may be held alive by in-flight + // RPC callbacks (via shared_ptr capture), but we release our reference + // so that input batch memory can be freed as soon as possible. + state_.reset(); + function_.reset(); + claimedRows_.clear(); + claimedBatch_.reset(); + batchRowLocations_.clear(); + batchRowIds_.clear(); + reusableIndices_.reset(); + + Operator::close(); +} + +void RPCOperator::initOutputProjections() { + const auto& outputColumn = rpcNode_->outputColumn(); + const auto& outputType = rpcNode_->outputType(); + auto sourceType = rpcNode_->source()->outputType(); + + for (int32_t i = 0; i < static_cast(outputType->size()); ++i) { + const auto& colName = outputType->nameOf(i); + if (colName == outputColumn) { + rpcResultOutputChannel_ = static_cast(i); + } else { + auto colIdx = sourceType->getChildIdxIfExists(colName); + if (colIdx.has_value()) { + passthroughProjections_.push_back( + OutputProjection{ + .outputChannel = static_cast(i), + .sourceChannel = static_cast(colIdx.value())}); + } + } + } + + RPC_OP_VLOG(1) << "initOutputProjections: rpcResultChannel=" + << rpcResultOutputChannel_ << ", passthroughProjections=" + << passthroughProjections_.size(); +} + +void RPCOperator::recordRuntimeStats() { + auto lockedStats = stats_.wlock(); + lockedStats->addRuntimeStat( + kRpcRequestsDispatched, RuntimeCounter(numRequestsDispatched_)); + lockedStats->addRuntimeStat( + kRpcResponsesReceived, RuntimeCounter(numResponsesCollected_)); + lockedStats->addRuntimeStat(kRpcErrorCount, RuntimeCounter(numErrors_)); + if (totalBlockWaitNanos_ > 0) { + lockedStats->addRuntimeStat( + kRpcWaitWallNanos, + RuntimeCounter( + static_cast(totalBlockWaitNanos_), + RuntimeCounter::Unit::kNanos)); + } + if (totalBackpressureWaitNanos_ > 0) { + lockedStats->addRuntimeStat( + kRpcBackpressureWaitNanos, + RuntimeCounter( + static_cast(totalBackpressureWaitNanos_), + RuntimeCounter::Unit::kNanos)); + } + + if (totalBlockWaitNanos_ > 0 || numResponsesCollected_ > 0) { + const CpuWallTiming backgroundTiming{ + static_cast(numResponsesCollected_), totalBlockWaitNanos_, 0}; + lockedStats->backgroundTiming.clear(); + lockedStats->backgroundTiming.add(backgroundTiming); + } +} + +RowVectorPtr RPCOperator::buildOutputFromReadyRows( + std::vector& readyRows) { + std::vector responses; + responses.reserve(readyRows.size()); + + std::vector> locations; + locations.reserve(readyRows.size()); + + for (auto& row : readyRows) { + responses.push_back(std::move(row.response)); + locations.emplace_back(row.location.batchIndex, row.location.rowIndex); + } + + return buildOutputVector(responses, locations); +} + +RowVectorPtr RPCOperator::buildOutputFromReadyBatch( + RPCState::ReadyBatch& readyBatch) { + std::vector> locations; + locations.reserve(readyBatch.rowLocations.size()); + for (const auto& loc : readyBatch.rowLocations) { + locations.emplace_back(loc.batchIndex, loc.rowIndex); + } + + return buildOutputVector(readyBatch.responses, locations); +} + +RowVectorPtr RPCOperator::buildOutputVector( + const std::vector& responses, + const std::vector>& locations) { + const auto numRows = static_cast(responses.size()); + auto* pool = operatorCtx_->pool(); + + const auto& outputType = rpcNode_->outputType(); + + // Use AsyncRPCFunction to build RPC result column. + auto responseVector = function_->buildOutput(responses, pool); + + // Check if all rows come from the same batch (common for BATCH mode). + bool singleBatch = true; + if (numRows > 0) { + int32_t firstBatch = locations[0].first; + for (vector_size_t i = 1; i < numRows; ++i) { + if (locations[i].first != firstBatch) { + singleBatch = false; + break; + } + } + } + + std::vector outputChildren(outputType->size()); + + // Set RPC result column using precomputed index. + outputChildren[rpcResultOutputChannel_] = responseVector; + + // Set passthrough columns using precomputed projections. + if (numRows == 0) { + for (const auto& proj : passthroughProjections_) { + outputChildren[proj.outputChannel] = + BaseVector::create(outputType->childAt(proj.outputChannel), 0, pool); + } + } else if (singleBatch) { + // All rows from same batch: use dictionary wrapping (zero-copy). + const auto indicesByteSize = numRows * sizeof(vector_size_t); + if (!reusableIndices_ || !reusableIndices_->unique() || + reusableIndices_->capacity() < indicesByteSize) { + reusableIndices_ = allocateIndices(numRows, pool); + } + reusableIndices_->setSize(indicesByteSize); + auto rawIndices = reusableIndices_->asMutable(); + for (vector_size_t rowIdx = 0; rowIdx < numRows; ++rowIdx) { + rawIndices[rowIdx] = locations[rowIdx].second; + } + + const auto batchCols = state_->getInputBatchColumns(locations[0].first); + for (const auto& proj : passthroughProjections_) { + if (proj.sourceChannel < static_cast(batchCols.size())) { + outputChildren[proj.outputChannel] = BaseVector::wrapInDictionary( + nullptr, reusableIndices_, numRows, batchCols[proj.sourceChannel]); + } else { + outputChildren[proj.outputChannel] = BaseVector::createNullConstant( + outputType->childAt(proj.outputChannel), numRows, pool); + } + } + } else { + // Rows from multiple batches: fetch columns once per batch. + std::unordered_map> batchColsCache; + for (const auto& proj : passthroughProjections_) { + auto combined = BaseVector::create( + outputType->childAt(proj.outputChannel), numRows, pool); + for (vector_size_t rowIdx = 0; rowIdx < numRows; ++rowIdx) { + const auto& [batchIdx, rowInBatch] = locations[rowIdx]; + auto it = batchColsCache.find(batchIdx); + if (it == batchColsCache.end()) { + it = batchColsCache + .emplace(batchIdx, state_->getInputBatchColumns(batchIdx)) + .first; + } + const auto& batchCols = it->second; + if (proj.sourceChannel < + static_cast(batchCols.size())) { + combined->copy( + batchCols[proj.sourceChannel].get(), rowIdx, rowInBatch, 1); + } else { + combined->setNull(rowIdx, true); + } + } + outputChildren[proj.outputChannel] = combined; + } + } + + // Fill any remaining nullptr entries with null constants. + for (int32_t i = 0; i < static_cast(outputChildren.size()); ++i) { + if (!outputChildren[i]) { + outputChildren[i] = + BaseVector::createNullConstant(outputType->childAt(i), numRows, pool); + } + } + + // Release rows from their input batches. + std::unordered_map batchReleaseCounts; + for (const auto& loc : locations) { + batchReleaseCounts[loc.first]++; + } + for (const auto& [batchIdx, count] : batchReleaseCounts) { + state_->releaseRows(batchIdx, count); + } + + return std::make_shared( + pool, outputType, nullptr, numRows, std::move(outputChildren)); +} + +} // namespace facebook::velox::exec::rpc diff --git a/velox/exec/rpc/RPCOperator.h b/velox/exec/rpc/RPCOperator.h new file mode 100644 index 00000000000..aab13149bcf --- /dev/null +++ b/velox/exec/rpc/RPCOperator.h @@ -0,0 +1,201 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include "velox/buffer/Buffer.h" +#include "velox/common/future/VeloxPromise.h" +#include "velox/exec/Operator.h" +#include "velox/exec/rpc/RPCRateLimiter.h" +#include "velox/exec/rpc/RPCState.h" +#include "velox/expression/rpc/AsyncRPCFunction.h" + +namespace facebook::velox::exec::rpc { + +/// Single-operator implementation for async RPC execution in Velox. +/// +/// Handles both dispatch (send RPCs in addInput()) and join (receive results +/// in isBlocked()/getOutput()) within the same operator. +/// +/// Architecture: +/// TableScan -> RPCOperator -> downstream +/// +/// The operator lifecycle: +/// 1. addInput(): Reads pre-computed argument columns, dispatches RPCs via +/// the function's dispatchPerRow() or accumulateBatch()+flushBatch(). +/// 2. needsInput(): Returns false when under backpressure or when there are +/// ready results to output (prioritizes outputting over accepting input). +/// 3. isBlocked(): Checks RPCState for completed responses. Returns +/// kWaitForRPC when no results are ready yet. +/// 4. getOutput(): Builds output RowVector from completed RPC responses +/// combined with preserved input (passthrough) columns. +/// 5. noMoreInput(): In BATCH mode, flushes remaining accumulated rows. +/// Signals RPCState that no more rows will be dispatched. +/// +/// Supports two streaming modes: +/// - PER_ROW: Rows emitted as individual RPCs complete (out-of-order). +/// Lower tail latency for high-variance workloads (e.g., LLM inference). +/// - BATCH: All rows in a batch complete before emitting. Lower overhead +/// for uniform-latency workloads. Supports pipelined dispatch via +/// dispatchBatchSize. +/// +/// State is derived from data presence (no explicit state machine enum): +/// - Has output: claimedRows_ non-empty or claimedBatch_ has value +/// - Finished: noMoreInput_ && state_->isFinished() && no claimed data +/// +/// Thread safety model: +/// - addInput(), getOutput(), needsInput(), isBlocked() are called from +/// a single driver thread (Velox guarantee). No synchronization needed +/// for operator-local state (e.g., globalRowIdCounter_, claimedRows_). +/// - Async RPC callbacks may run on any thread (transport executor pool). +/// All cross-thread coordination goes through RPCState, which is fully +/// mutex-protected (see RPCState.h for per-method annotations). +/// - RPCRateLimiter tokens use RAII: destruction (including from cancelled +/// futures) automatically decrements the pending count and notifies +/// waiters. +/// +class RPCOperator : public exec::Operator { + public: + RPCOperator( + int32_t operatorId, + exec::DriverCtx* driverCtx, + std::shared_ptr rpcNode); + + void initialize() override; + + void close() override; + + bool needsInput() const override; + + void addInput(RowVectorPtr input) override; + + void noMoreInput() override; + + RowVectorPtr getOutput() override; + + exec::BlockingReason isBlocked(ContinueFuture* future) override; + + bool isFinished() override; + + bool startDrain() override; + + /// Runtime stat names. + static inline const std::string kRpcRequestsDispatched{ + "rpcRequestsDispatched"}; + static inline const std::string kRpcResponsesReceived{"rpcResponsesReceived"}; + static inline const std::string kRpcErrorCount{"rpcErrorCount"}; + static inline const std::string kRpcWaitWallNanos{"rpcWaitWallNanos"}; + static inline const std::string kRpcBackpressureWaitNanos{ + "rpcBackpressureWaitNanos"}; + + private: + /// Flush accumulated batch rows via function_->flushBatch(). + /// Called when threshold is reached or at noMoreInput/drain time. + /// @param maxRows Maximum rows to flush. 0 means flush all. + void flushBatchRequests(int32_t maxRows = 0); + + /// Build output RowVector from ready rows (PER_ROW mode). + /// Supports multiple rows via batched drain for pipeline efficiency. + RowVectorPtr buildOutputFromReadyRows( + std::vector& readyRows); + + /// Build output RowVector from a ready batch (BATCH mode). + RowVectorPtr buildOutputFromReadyBatch(RPCState::ReadyBatch& readyBatch); + + /// Common helper: build output vector from responses + input data lookup. + RowVectorPtr buildOutputVector( + const std::vector& responses, + const std::vector>& locations); + + /// Precompute output column projections from source type to output type. + /// Called once in initialize() to avoid repeated string lookups in + /// buildOutputVector(). + void initOutputProjections(); + + /// Record runtime stats into operator stats. Called from close(). + void recordRuntimeStats(); + + std::shared_ptr rpcNode_; + std::shared_ptr state_; + std::shared_ptr function_; + + // Tier key for per-tier rate limiting (from function_->tierKey()). + std::string tierKey_; + + // Precomputed argument column indices for reading from input in addInput(). + // Initialized in initialize() by looking up argumentColumns in source type. + std::vector argumentColumnIndices_; + + // Collected row locations for current batch (BATCH mode). + // Passed to addPendingBatch() when the batch is flushed. + std::vector batchRowLocations_; + + // Collected row IDs for current batch (BATCH mode). + // Used to stamp rowIds onto responses at flush time. + std::vector batchRowIds_; + + int64_t numRequestsDispatched_{0}; + int64_t numResponsesCollected_{0}; + int64_t numErrors_{0}; + + // Global row ID counter for unique IDs across all input batches. + int64_t globalRowIdCounter_{0}; + + // Dispatch batch size for pipelined BATCH mode. + // 0 = collect all rows, fire once in noMoreInput(). + // > 0 = fire flushBatch() every N rows during addInput(). + int32_t dispatchBatchSize_{0}; + + // Claimed rows/batch from isBlocked() for use in getOutput(). + // State is derived from these: if non-empty, we have output ready. + std::vector claimedRows_; + std::optional claimedBatch_; + + // Whether we've detected the finish condition. + bool finished_{false}; + + // Timeout for batch RPC calls (30 minutes). + // This is a ceiling — the operator returns as soon as results are ready. + // Batch LLM inference can take many minutes due to MetaGen queuing + // and GPU scheduling, so the timeout needs generous headroom. + static constexpr auto kBatchRpcTimeout = std::chrono::milliseconds(3'600'000); + + // Block wait time tracking for runtime stats. + std::optional blockWaitStartNs_; + bool blockWaitIsBackpressure_{false}; + uint64_t totalBlockWaitNanos_{0}; + uint64_t totalBackpressureWaitNanos_{0}; + + // Reusable indices buffer for dictionary wrapping in single-batch output + // path. + BufferPtr reusableIndices_; + + // Precomputed output column projections (initialized in initialize()). + // Maps output column index to source column index for passthrough columns. + // Avoids repeated string-based column lookups in buildOutputVector(). + struct OutputProjection { + column_index_t outputChannel; + column_index_t sourceChannel; + }; + std::vector passthroughProjections_; + column_index_t rpcResultOutputChannel_{0}; +}; + +} // namespace facebook::velox::exec::rpc diff --git a/velox/exec/rpc/RPCPlanNodeTranslator.cpp b/velox/exec/rpc/RPCPlanNodeTranslator.cpp new file mode 100644 index 00000000000..7ce309ff7eb --- /dev/null +++ b/velox/exec/rpc/RPCPlanNodeTranslator.cpp @@ -0,0 +1,77 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/rpc/RPCPlanNodeTranslator.h" + +#include + +#include "velox/exec/rpc/RPCOperator.h" + +namespace facebook::velox::exec::rpc { + +std::unique_ptr RPCPlanNodeTranslator::toOperator( + exec::DriverCtx* ctx, + int32_t id, + const core::PlanNodePtr& node) { + if (auto rpcNode = std::dynamic_pointer_cast(node)) { + VLOG(1) << "[RPC_TRANSLATOR] Creating RPCOperator for node id=" + << node->id(); + return std::make_unique(id, ctx, rpcNode); + } + return nullptr; +} + +std::optional RPCPlanNodeTranslator::maxDrivers( + const core::PlanNodePtr& node) { + if (auto rpcNode = std::dynamic_pointer_cast(node)) { + if (rpcNode->streamingMode() == rpc::RPCStreamingMode::kBatch) { + // BATCH mode: Force single-driver execution. Multiple drivers would + // race for batch results, with only one getting data and others + // finishing empty. + return 1; + } + + // When all arguments are constants and no real data columns flow from the + // source, the upstream is a synthetic single-row ValuesNode. The Java + // AddLocalExchanges optimizer inserts a ROUND_ROBIN LocalExchange that + // distributes this single row across N drivers, causing N-1 drivers to + // finish empty and the result to be lost. Force single-driver execution. + const auto& constantInputs = rpcNode->constantInputs(); + const auto& argumentColumns = rpcNode->argumentColumns(); + bool allConstant = !constantInputs.empty() && + std::all_of( + constantInputs.begin(), constantInputs.end(), [](const auto& v) { + return v != nullptr; + }); + if (allConstant) { + auto sourceType = rpcNode->source()->outputType(); + if (sourceType->size() <= argumentColumns.size()) { + return 1; + } + } + + // PER_ROW mode: Allow parallel execution. Each driver claims individual + // rows atomically via RPCState::tryClaimOrWait(). + return std::nullopt; + } + return std::nullopt; +} + +void registerRPCPlanNodeTranslator() { + exec::Operator::registerOperator(std::make_unique()); +} + +} // namespace facebook::velox::exec::rpc diff --git a/velox/exec/rpc/RPCPlanNodeTranslator.h b/velox/exec/rpc/RPCPlanNodeTranslator.h new file mode 100644 index 00000000000..b69e3bc532f --- /dev/null +++ b/velox/exec/rpc/RPCPlanNodeTranslator.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/exec/Operator.h" + +namespace facebook::velox::exec::rpc { + +/// PlanNodeTranslator for RPCNode. +/// +/// Creates RPCOperator from RPCNode plan nodes. +/// +/// maxDrivers() returns 1 for BATCH mode to prevent multiple drivers from +/// competing for batch results. PER_ROW mode allows parallel execution +/// since each driver claims individual rows atomically. +class RPCPlanNodeTranslator : public exec::Operator::PlanNodeTranslator { + public: + std::unique_ptr toOperator( + exec::DriverCtx* ctx, + int32_t id, + const core::PlanNodePtr& node) override; + + /// Returns 1 for BATCH mode (single-driver), nullopt for PER_ROW mode. + std::optional maxDrivers(const core::PlanNodePtr& node) override; +}; + +/// Register the RPCPlanNodeTranslator with Velox. +void registerRPCPlanNodeTranslator(); + +} // namespace facebook::velox::exec::rpc diff --git a/velox/exec/rpc/RPCRateLimiter.cpp b/velox/exec/rpc/RPCRateLimiter.cpp new file mode 100644 index 00000000000..f6da2aab80f --- /dev/null +++ b/velox/exec/rpc/RPCRateLimiter.cpp @@ -0,0 +1,186 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/rpc/RPCRateLimiter.h" + +#define RPC_RATE_LIMITER_LOG(severity) LOG(severity) << "[RPC_RATE_LIMITER] " +#define RPC_RATE_LIMITER_VLOG(level) VLOG(level) << "[RPC_RATE_LIMITER] " + +namespace facebook::velox::exec::rpc { + +// --- Token implementation --- + +RPCRateLimiter::Token::Token(const std::string& tierKey) + : tierKey_(tierKey), valid_(true) {} + +RPCRateLimiter::Token& RPCRateLimiter::Token::operator=( + Token&& other) noexcept { + if (this != &other) { + if (valid_) { + decrementPending(tierKey_); + } + tierKey_ = std::move(other.tierKey_); + valid_ = other.valid_; + other.valid_ = false; + } + return *this; +} + +RPCRateLimiter::Token::~Token() { + if (valid_) { + decrementPending(tierKey_); + } +} + +// --- Function-local statics --- + +std::mutex& RPCRateLimiter::mapMutex() { + static std::mutex mutex; + return mutex; +} + +std::atomic& RPCRateLimiter::defaultMaxPendingRef() { + // Default: 20 concurrent RPCs per process per tier. + static std::atomic maxPending{20}; + return maxPending; +} + +std::unordered_map>& +RPCRateLimiter::tiers() { + static std::unordered_map> tierMap; + return tierMap; +} + +// --- TierState lookup --- + +RPCRateLimiter::TierState& RPCRateLimiter::getOrCreateTierState( + const std::string& tierKey) { + std::lock_guard l(mapMutex()); + auto& tierMap = tiers(); + auto it = tierMap.find(tierKey); + if (it != tierMap.end()) { + return *it->second; + } + auto [newIt, _] = tierMap.emplace(tierKey, std::make_unique()); + return *newIt->second; +} + +// --- Public API --- + +RPCRateLimiter::Token RPCRateLimiter::acquire(const std::string& tierKey) { + incrementPending(tierKey); + return Token(tierKey); +} + +std::optional RPCRateLimiter::checkBackpressure( + const std::string& tierKey) { + auto& state = getOrCreateTierState(tierKey); + + std::lock_guard l(state.mutex); + + int64_t pending = state.pendingCount.load(); + int64_t maxPending = + state.maxPending > 0 ? state.maxPending : defaultMaxPendingRef().load(); + + if (pending < maxPending) { + RPC_RATE_LIMITER_VLOG(2) + << "checkBackpressure[" << tierKey << "]: OK (pending=" << pending + << ", max=" << maxPending << ")"; + return std::nullopt; + } + + RPC_RATE_LIMITER_VLOG(1) << "checkBackpressure[" << tierKey + << "]: BLOCKED (pending=" << pending + << ", max=" << maxPending + << "), creating wait promise #" + << state.waiters.size(); + state.waiters.emplace_back("RPCRateLimiter::checkBackpressure"); + return state.waiters.back().getSemiFuture(); +} + +int64_t RPCRateLimiter::pendingCount(const std::string& tierKey) { + auto& state = getOrCreateTierState(tierKey); + return state.pendingCount.load(); +} + +void RPCRateLimiter::setMaxPending(const std::string& tierKey, int64_t limit) { + auto& state = getOrCreateTierState(tierKey); + std::lock_guard l(state.mutex); + state.maxPending = limit; + RPC_RATE_LIMITER_VLOG(1) << "setMaxPending[" << tierKey << "]: set to " + << limit; +} + +void RPCRateLimiter::setDefaultMaxPending(int64_t limit) { + defaultMaxPendingRef().store(limit); + RPC_RATE_LIMITER_VLOG(1) << "setDefaultMaxPending: set to " << limit; +} + +int64_t RPCRateLimiter::defaultMaxPending() { + return defaultMaxPendingRef().load(); +} + +void RPCRateLimiter::testingResetAllState() { + std::lock_guard l(mapMutex()); + defaultMaxPendingRef().store(20); + tiers().clear(); +} + +// --- Internal helpers --- + +void RPCRateLimiter::incrementPending(const std::string& tierKey) { + auto& state = getOrCreateTierState(tierKey); + int64_t newCount = ++state.pendingCount; + RPC_RATE_LIMITER_VLOG(2) << "incrementPending[" << tierKey + << "]: pending=" << newCount; +} + +void RPCRateLimiter::decrementPending(const std::string& tierKey) { + auto& state = getOrCreateTierState(tierKey); + int64_t newCount = --state.pendingCount; + RPC_RATE_LIMITER_VLOG(2) << "decrementPending[" << tierKey + << "]: pending=" << newCount; + + // CRITICAL: Hold the per-tier mutex when checking whether to notify waiters. + // This prevents a TOCTOU race where: + // 1. We read newCount < maxPending (should notify) + // 2. A new waiter is added in checkBackpressure() between read and notify + // + // By holding the lock during check-and-notify, any waiter added in + // checkBackpressure() will either: + // - Be in state.waiters before we check (and we'll notify it) + // - See the updated count and not need to wait at all + // + // We notify only one waiter per decrement (FIFO) to avoid thundering herd. + std::optional waiterToNotify; + { + std::lock_guard l(state.mutex); + int64_t maxPending = + state.maxPending > 0 ? state.maxPending : defaultMaxPendingRef().load(); + if (newCount < maxPending && !state.waiters.empty()) { + RPC_RATE_LIMITER_VLOG(1) + << "decrementPending[" << tierKey << "]: notifying 1 of " + << state.waiters.size() << " waiters"; + waiterToNotify = std::move(state.waiters.front()); + state.waiters.pop_front(); + } + } + if (waiterToNotify) { + waiterToNotify->setValue(); + } +} + +} // namespace facebook::velox::exec::rpc diff --git a/velox/exec/rpc/RPCRateLimiter.h b/velox/exec/rpc/RPCRateLimiter.h new file mode 100644 index 00000000000..a9f2e0a4d45 --- /dev/null +++ b/velox/exec/rpc/RPCRateLimiter.h @@ -0,0 +1,123 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "velox/common/future/VeloxPromise.h" + +namespace facebook::velox::exec::rpc { + +/// Per-tier (per-process) rate limiter for RPC dispatch. +/// +/// Each backend service tier (e.g., "service.backend.prod") gets its own +/// independent concurrency limit and waiter queue. This allows different +/// backends to have different concurrency budgets. +/// +/// The rate limiter is per-process. Each Presto worker is a separate process +/// with its own ServiceRouter connections. Cross-worker coordination is +/// handled by the backend's own admission control. +/// +/// The tier key comes from IRPCClient::tierKey(). Empty string means +/// "no tier configured" and falls back to the global default limit. +class RPCRateLimiter { + public: + /// RAII token representing one in-flight request slot. + /// Decrements the pending count on destruction, guaranteeing cleanup + /// even if the RPC future is abandoned (e.g., query cancellation). + class Token { + public: + Token() = default; + + Token(Token&& other) noexcept + : tierKey_(std::move(other.tierKey_)), valid_(other.valid_) { + other.valid_ = false; + } + + Token& operator=(Token&& other) noexcept; + + ~Token(); + + // Non-copyable. + Token(const Token&) = delete; + Token& operator=(const Token&) = delete; + + private: + friend class RPCRateLimiter; + explicit Token(const std::string& tierKey); + + std::string tierKey_; + bool valid_{false}; + }; + + /// Acquire a slot for the given tier. Increments pending count and + /// returns a Token that will decrement on destruction. + static Token acquire(const std::string& tierKey); + + /// Check backpressure for a specific tier. + /// Returns a future to wait on if that tier's limit is reached. + /// @return Future to wait on if blocked, nullopt if can proceed. + static std::optional checkBackpressure( + const std::string& tierKey); + + /// Get current pending count for a specific tier. + static int64_t pendingCount(const std::string& tierKey); + + /// Configure the max pending limit for a specific tier. + /// If not configured, falls back to the global default. + static void setMaxPending(const std::string& tierKey, int64_t limit); + + /// Set the global default max pending (used when no per-tier config). + static void setDefaultMaxPending(int64_t limit); + + /// Get the global default max pending. + static int64_t defaultMaxPending(); + + /// Reset all state. Intended ONLY for unit tests + /// to avoid test contamination across test cases. + /// WARNING: Do NOT call this in production code. + static void testingResetAllState(); + + private: + /// Per-tier state. Allocated via unique_ptr for pointer stability across + /// map inserts (std::unordered_map does not invalidate existing entries). + struct TierState { + std::mutex mutex; + std::atomic pendingCount{0}; + int64_t maxPending{0}; // 0 = use global default + std::deque waiters; + }; + + static void incrementPending(const std::string& tierKey); + static void decrementPending(const std::string& tierKey); + + static TierState& getOrCreateTierState(const std::string& tierKey); + + // Global mutex protects only the tier map for inserts/lookups. + static std::mutex& mapMutex(); + static std::atomic& defaultMaxPendingRef(); + static std::unordered_map>& tiers(); +}; + +} // namespace facebook::velox::exec::rpc diff --git a/velox/exec/rpc/RPCState.cpp b/velox/exec/rpc/RPCState.cpp new file mode 100644 index 00000000000..4857d0d7258 --- /dev/null +++ b/velox/exec/rpc/RPCState.cpp @@ -0,0 +1,382 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/rpc/RPCState.h" + +#include + +#include "velox/common/base/Exceptions.h" + +#define RPC_STATE_LOG(severity) LOG(severity) << "[RPC_STATE] " +#define RPC_STATE_VLOG(level) VLOG(level) << "[RPC_STATE] " + +namespace facebook::velox::exec::rpc { + +// ===== Configuration ===== +// These setters are called during RPCOperator construction, before any +// concurrent access. No lock needed — avoids lock-order-inversion with the +// Task mutex (TSAN: M0→M1→M0 cycle). + +void RPCState::setStreamingMode(RPCStreamingMode mode) { + streamingMode_ = mode; +} + +RPCStreamingMode RPCState::streamingMode() const { + return streamingMode_; +} + +void RPCState::setMaxPendingRows(int64_t maxPendingRows) { + maxPendingRows_ = maxPendingRows; +} + +void RPCState::setMaxPendingBatches(int64_t maxPendingBatches) { + maxPendingBatches_ = maxPendingBatches; + effectiveMaxPendingBatches_ = maxPendingBatches; +} + +// ===== Input batch storage ===== + +int32_t RPCState::storeInputBatch( + std::vector flatColumns, + int64_t rowCount) { + std::lock_guard l(mutex_); + auto batchIndex = static_cast(inputBatches_.size()); + inputBatches_.push_back( + InputBatchRef{ + .flatColumns = std::move(flatColumns), .activeRowCount = rowCount}); + RPC_STATE_VLOG(2) << "storeInputBatch: batchIndex=" << batchIndex + << ", rowCount=" << rowCount; + return batchIndex; +} + +std::vector RPCState::getInputBatchColumns( + int32_t batchIndex) const { + std::lock_guard l(mutex_); + VELOX_CHECK_LT( + batchIndex, + static_cast(inputBatches_.size()), + "Invalid batchIndex {} for inputBatches_ of size {}", + batchIndex, + inputBatches_.size()); + return inputBatches_[batchIndex].flatColumns; +} + +void RPCState::releaseRows(int32_t batchIndex, int64_t count) { + std::lock_guard l(mutex_); + VELOX_CHECK_LT( + batchIndex, + static_cast(inputBatches_.size()), + "Invalid batchIndex {} for inputBatches_ of size {}", + batchIndex, + inputBatches_.size()); + auto& batch = inputBatches_[batchIndex]; + VELOX_CHECK_GE( + batch.activeRowCount, + count, + "Cannot release {} rows from batch {} with only {} active rows", + count, + batchIndex, + batch.activeRowCount); + batch.activeRowCount -= count; + if (batch.activeRowCount == 0) { + // Release the column vectors to free memory. + batch.flatColumns.clear(); + RPC_STATE_VLOG(2) << "releaseRows: batch " << batchIndex + << " fully released"; + } +} + +// ===== PER_ROW mode API ===== + +void RPCState::addPendingRow( + std::shared_ptr selfPtr, + int64_t rowId, + RowLocation location, + folly::SemiFuture future) { + { + std::lock_guard l(mutex_); + numPendingRows_++; + RPC_STATE_VLOG(2) << "addPendingRow: rowId=" << rowId + << ", pendingCount=" << numPendingRows_; + } + + // Attach completion callbacks that delegate to completeRow(). + // We capture selfPtr to keep this RPCState alive until the callback fires. + auto stateForError = selfPtr; + folly::futures::detachOn( + folly::getKeepAliveToken(folly::InlineExecutor::instance()), + std::move(future) + .deferValue([state = std::move(selfPtr), rowId, location]( + RPCResponse response) mutable { + state->completeRow(rowId, location, std::move(response)); + }) + .deferError([state = std::move(stateForError), rowId, location]( + const folly::exception_wrapper& ew) mutable { + RPC_STATE_LOG(ERROR) + << "RPC failed for rowId=" << rowId << ": " << ew.what(); + RPCResponse errorResponse; + errorResponse.rowId = rowId; + errorResponse.error = ew.what().toStdString(); + state->completeRow(rowId, location, std::move(errorResponse)); + })); +} + +void RPCState::completeRow( + int64_t rowId, + RowLocation location, + RPCResponse response) { + std::lock_guard l(mutex_); + readyRows_.push_back( + ReadyRow{ + .rowId = rowId, + .location = location, + .response = std::move(response)}); + numPendingRows_--; + + RPC_STATE_VLOG(2) << "Row completed: rowId=" << rowId + << ", readyRows=" << readyRows_.size() + << ", pendingCount=" << numPendingRows_; + + notifyWaitersLocked(); +} + +RPCState::ClaimResult RPCState::tryClaimOrWait( + ContinueFuture* future, + std::optional* claimedRow) { + std::lock_guard l(mutex_); + + // Step 1: Try to claim a ready row. + if (!readyRows_.empty()) { + *claimedRow = std::move(readyRows_.front()); + readyRows_.pop_front(); + RPC_STATE_VLOG(1) << "tryClaimOrWait: claimed rowId=" + << (*claimedRow)->rowId + << ", remaining ready=" << readyRows_.size(); + return ClaimResult::kClaimed; + } + + // Step 2: Check finish condition. + if (noMoreInput_ && numPendingRows_ == 0) { + RPC_STATE_VLOG(1) << "tryClaimOrWait: finish condition met"; + return ClaimResult::kFinished; + } + + // Step 3: Must wait — create a promise under the same lock to prevent + // TOCTOU races (no completion can slip between the check and wait). + RPC_STATE_VLOG(2) << "tryClaimOrWait: must wait (pending=" << numPendingRows_ + << ")"; + promises_.emplace_back("RPCState::tryClaimOrWait"); + *future = promises_.back().getSemiFuture(); + return ClaimResult::kMustWait; +} + +std::optional RPCState::tryClaimReady() { + std::lock_guard l(mutex_); + if (!readyRows_.empty()) { + auto row = std::move(readyRows_.front()); + readyRows_.pop_front(); + RPC_STATE_VLOG(1) << "tryClaimReady: claimed rowId=" << row.rowId + << ", remaining ready=" << readyRows_.size(); + return row; + } + return std::nullopt; +} + +void RPCState::drainReadyRows(std::vector& out, int32_t maxRows) { + std::lock_guard l(mutex_); + while (!readyRows_.empty() && static_cast(out.size()) < maxRows) { + out.push_back(std::move(readyRows_.front())); + readyRows_.pop_front(); + } +} + +int64_t RPCState::numPendingRows() { + std::lock_guard l(mutex_); + return numPendingRows_; +} + +// ===== BATCH mode API ===== + +void RPCState::addPendingBatch( + std::shared_ptr selfPtr, + folly::SemiFuture> future, + std::vector rowLocations) { + std::lock_guard l(mutex_); + + int64_t batchId = nextBatchId_++; + + // Attach a completion callback that notifies waiters when the batch + // completes. We use .via().thenValue() to run the callback eagerly. + auto callbackFuture = + std::move(future) + .via(folly::getKeepAliveToken(folly::InlineExecutor::instance())) + .thenValue([state = selfPtr](std::vector responses) { + RPC_STATE_VLOG(1) + << "Batch completed with " << responses.size() << " responses"; + { + std::lock_guard l(state->mutex_); + state->notifyWaitersLocked(); + } + return responses; + }) + .thenError([state = selfPtr](folly::exception_wrapper ew) { + RPC_STATE_LOG(ERROR) << "Batch failed: " << ew.what(); + { + std::lock_guard l(state->mutex_); + state->notifyWaitersLocked(); + } + return folly::makeSemiFuture>( + std::move(ew)); + }) + .semi(); + + pendingBatches_.push_back( + PendingBatch{ + .batchId = batchId, + .future = std::move(callbackFuture), + .rowLocations = std::move(rowLocations)}); + + RPC_STATE_VLOG(1) << "addPendingBatch: batchId=" << batchId + << ", totalPending=" << pendingBatches_.size(); +} + +RPCState::BatchPollResult RPCState::tryPollBatchOrWait( + ContinueFuture* future, + std::optional* readyBatch) { + std::lock_guard l(mutex_); + + // Step 1: Check for a ready batch (out-of-order: first ready wins). + for (auto it = pendingBatches_.begin(); it != pendingBatches_.end(); ++it) { + if (it->future.isReady()) { + ReadyBatch result; + result.batchId = it->batchId; + result.rowLocations = std::move(it->rowLocations); + + try { + result.responses = std::move(it->future).get(); + RPC_STATE_VLOG(1) << "tryPollBatchOrWait: batchId=" << result.batchId + << " ready with " << result.responses.size() + << " responses"; + } catch (const std::exception& e) { + result.error = e.what(); + result.responses = {}; + RPC_STATE_LOG(ERROR) << "tryPollBatchOrWait: batchId=" << result.batchId + << " failed: " << e.what(); + } + + pendingBatches_.erase(it); + *readyBatch = std::move(result); + return BatchPollResult::kGotBatch; + } + } + + // Step 2: Check finish condition. + if (noMoreInput_ && pendingBatches_.empty()) { + RPC_STATE_VLOG(1) << "tryPollBatchOrWait: finish condition met"; + return BatchPollResult::kFinished; + } + + // Step 3: Must wait. + RPC_STATE_VLOG(2) << "tryPollBatchOrWait: must wait"; + promises_.emplace_back("RPCState::tryPollBatchOrWait"); + *future = promises_.back().getSemiFuture(); + return BatchPollResult::kMustWait; +} + +std::optional RPCState::tryPollReady() { + std::lock_guard l(mutex_); + for (auto it = pendingBatches_.begin(); it != pendingBatches_.end(); ++it) { + if (it->future.isReady()) { + ReadyBatch result; + result.batchId = it->batchId; + result.rowLocations = std::move(it->rowLocations); + try { + result.responses = std::move(it->future).get(); + RPC_STATE_VLOG(1) << "tryPollReady: batchId=" << result.batchId + << " ready with " << result.responses.size() + << " responses"; + } catch (const std::exception& e) { + result.error = e.what(); + result.responses = {}; + RPC_STATE_LOG(ERROR) << "tryPollReady: batchId=" << result.batchId + << " failed: " << e.what(); + } + pendingBatches_.erase(it); + return result; + } + } + return std::nullopt; +} + +// ===== Common ===== + +void RPCState::setNoMoreInput() { + std::lock_guard l(mutex_); + noMoreInput_ = true; + RPC_STATE_VLOG(1) << "setNoMoreInput: pendingRows=" << numPendingRows_ + << ", pendingBatches=" << pendingBatches_.size() + << ", readyRows=" << readyRows_.size(); + notifyWaitersLocked(); +} + +bool RPCState::isFinished() { + std::lock_guard l(mutex_); + return noMoreInput_ && numPendingRows_ == 0 && readyRows_.empty() && + pendingBatches_.empty(); +} + +bool RPCState::isUnderBackpressure() { + std::lock_guard l(mutex_); + if (streamingMode_ == RPCStreamingMode::kBatch) { + return static_cast(pendingBatches_.size()) >= + effectiveMaxPendingBatches_; + } + return numPendingRows_ >= maxPendingRows_; +} + +void RPCState::onBatchSuccess(int64_t increment) { + std::lock_guard l(mutex_); + if (effectiveMaxPendingBatches_ < maxPendingBatches_) { + effectiveMaxPendingBatches_ = + std::min(effectiveMaxPendingBatches_ + increment, maxPendingBatches_); + RPC_STATE_LOG(INFO) << "RPC congestion: batch success, window increased to " + << effectiveMaxPendingBatches_ << "/" + << maxPendingBatches_; + } +} + +void RPCState::onBatchError() { + std::lock_guard l(mutex_); + auto prev = effectiveMaxPendingBatches_; + effectiveMaxPendingBatches_ = + std::max(effectiveMaxPendingBatches_ / 2, 1); + if (effectiveMaxPendingBatches_ < prev) { + RPC_STATE_LOG(WARNING) + << "RPC congestion: batch error, window decreased from " << prev + << " to " << effectiveMaxPendingBatches_; + } +} + +void RPCState::notifyWaitersLocked() { + // Fulfill all promises to wake up blocked drivers. + // Called while mutex_ is held. + for (auto& promise : promises_) { + promise.setValue(); + } + promises_.clear(); +} + +} // namespace facebook::velox::exec::rpc diff --git a/velox/exec/rpc/RPCState.h b/velox/exec/rpc/RPCState.h new file mode 100644 index 00000000000..a4c018d369f --- /dev/null +++ b/velox/exec/rpc/RPCState.h @@ -0,0 +1,277 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#include + +#include "velox/common/future/VeloxPromise.h" +#include "velox/common/rpc/RPCTypes.h" +#include "velox/vector/BaseVector.h" + +namespace facebook::velox::exec::rpc { + +// Import RPC types from velox/common/rpc into this namespace. +using velox::rpc::RPCResponse; +using velox::rpc::RPCStreamingMode; + +/// A stored input batch with its rows referenced by index. +/// Instead of slicing individual rows, we keep the entire batch and +/// use row indices to look up passthrough columns at output time. +struct InputBatchRef { + /// Flattened columns from the input batch (shared across all rows). + std::vector flatColumns; + + /// Number of rows from this batch that are still in-flight or pending output. + /// When this reaches 0, the batch can be released. + int64_t activeRowCount{0}; +}; + +/// Shared state between RPCOperator's driver thread and RPC completion +/// callbacks. +/// +/// A mutex-protected state object owned by RPCOperator. It coordinates +/// async RPC dispatch (addPendingRow/addPendingBatch) with result +/// consumption (tryClaimOrWait/tryPollBatchOrWait) across the driver +/// thread and RPC client executor threads. +/// +/// Thread safety: All public methods are thread-safe. The mutex_ protects +/// all mutable state. Completion callbacks from the RPC client's executor +/// threads call notifyWaitersLocked() to wake the driver thread. +/// +/// Two streaming modes: +/// - PER_ROW: Rows are emitted as they complete individually (out-of-order). +/// Lower tail latency for high-variance workloads (e.g., LLM inference). +/// - BATCH: All rows in a batch complete before emitting. Lower overhead +/// for uniform-latency workloads. +class RPCState { + public: + // ===== Data structures ===== + + /// Location of a row within an input batch. + struct RowLocation { + int32_t batchIndex{0}; + vector_size_t rowIndex{0}; + }; + + /// A completed row with its RPC response and location in the input batch. + struct ReadyRow { + int64_t rowId{0}; + RowLocation location; + RPCResponse response; + }; + + /// A batch of rows waiting for RPC response. + struct PendingBatch { + int64_t batchId; + folly::SemiFuture> future; + /// Row locations for mapping responses back to input batch positions. + /// Stored at batch level instead of per-row in a map, since callBatch() + /// returns responses in the same order as requests. + std::vector rowLocations; + }; + + /// A batch with completed RPC responses. + struct ReadyBatch { + int64_t batchId{0}; + std::vector responses; + std::optional error; + /// Row locations carried from PendingBatch for response-to-input mapping. + std::vector rowLocations; + }; + + RPCState() = default; + + // ===== Configuration ===== + // These must be called before any dispatch (single-threaded init phase). + + /// Set the streaming mode. Must be called before any dispatch. + void setStreamingMode(RPCStreamingMode mode); + + /// Get the current streaming mode. + RPCStreamingMode streamingMode() const; + + /// Set the maximum number of pending rows before backpressure. + void setMaxPendingRows(int64_t maxPendingRows); + + /// Set the maximum number of pending batches before backpressure (BATCH + /// mode). + void setMaxPendingBatches(int64_t maxPendingBatches); + + // ===== Input batch storage ===== + // Called from the driver thread (addInput/getOutput). Thread-safe. + + /// Store an input batch and return its index. Thread-safe. + /// The batch is reference-counted; call releaseRows() when rows are output. + int32_t storeInputBatch(std::vector flatColumns, int64_t rowCount); + + /// Get the flattened columns for a stored input batch. Thread-safe. + /// Returns by value to avoid returning a reference that outlives the lock. + std::vector getInputBatchColumns(int32_t batchIndex) const; + + /// Release rows from an input batch. Thread-safe. + /// When all rows are released, the batch columns are freed. + void releaseRows(int32_t batchIndex, int64_t count); + + // ===== PER_ROW mode API ===== + + /// Add a pending row with its RPC future and location in the input batch. + /// Thread-safe. Called from the driver thread in addInput(). + /// + /// Attaches a completion callback to the future that moves the response + /// into readyRows_ and notifies waiting drivers. The callback runs on + /// the RPC client's executor thread, acquiring mutex_ internally. + /// + /// @param selfPtr Shared pointer to this RPCState, captured by the callback + /// to prevent premature destruction. + /// @param rowId Globally unique row ID for correlation. + /// @param location Row's location in the stored input batch. + /// @param future The SemiFuture from client->call(). + void addPendingRow( + std::shared_ptr selfPtr, + int64_t rowId, + RowLocation location, + folly::SemiFuture future); + + /// Atomically try to claim a ready row, check finish, or wait. Thread-safe. + /// Called from the driver thread in isBlocked(). + /// + /// All three checks happen under a single lock to prevent TOCTOU races. + /// + /// @param[out] future Set to a wait future if kMustWait. + /// @param[out] claimedRow Set to the claimed row if kClaimed. + /// @return kClaimed, kFinished, or kMustWait. + enum class ClaimResult { kClaimed, kFinished, kMustWait }; + ClaimResult tryClaimOrWait( + ContinueFuture* future, + std::optional* claimedRow); + + /// Non-blocking claim of a ready row. Thread-safe. + /// Returns nullopt if no ready rows. + /// Unlike tryClaimOrWait(), does NOT create a promise on miss. + /// Use this pre-noMoreInput to avoid accumulating orphaned promises. + std::optional tryClaimReady(); + + /// Drain all currently ready rows (non-blocking, up to maxRows). Thread-safe. + /// Used for batched PER_ROW output to amortize RowVector allocation. + void drainReadyRows(std::vector& out, int32_t maxRows); + + /// Returns the number of pending (in-flight) rows. Thread-safe. + int64_t numPendingRows(); + + // ===== BATCH mode API ===== + + /// Add a pending batch future with row locations for response-to-input + /// mapping. Thread-safe. Called from the driver thread in + /// flushBatchRequests(). + /// + /// Attaches a completion callback that notifies waiters on completion. + /// The callback runs on the RPC client's executor thread, acquiring + /// mutex_ internally. + /// + /// @param selfPtr Shared pointer to this RPCState (prevent destruction). + /// @param future The SemiFuture from client->callBatch(). + /// @param rowLocations Locations mapping each request to its input batch + /// position. Stored on the PendingBatch and carried through to + /// ReadyBatch, eliminating per-row rowLocations_ map overhead. + void addPendingBatch( + std::shared_ptr selfPtr, + folly::SemiFuture> future, + std::vector rowLocations); + + /// Atomically try to poll a ready batch, check finish, or wait. Thread-safe. + /// Called from the driver thread in isBlocked(). + /// + /// @param[out] future Set to a wait future if kMustWait. + /// @param[out] readyBatch Set to the ready batch if kGotBatch. + /// @return kGotBatch, kFinished, or kMustWait. + enum class BatchPollResult { kGotBatch, kFinished, kMustWait }; + BatchPollResult tryPollBatchOrWait( + ContinueFuture* future, + std::optional* readyBatch); + + /// Non-blocking poll of a ready batch. Thread-safe. + /// Returns nullopt if no batch ready. + /// Unlike tryPollBatchOrWait(), does NOT create a promise on miss. + /// Use this pre-noMoreInput to avoid accumulating orphaned promises. + std::optional tryPollReady(); + + // ===== Common ===== + + /// Signal that no more rows will be dispatched. Thread-safe. + void setNoMoreInput(); + + /// Returns true when all work is complete. Thread-safe. + bool isFinished(); + + /// Returns true if backpressure should be applied. Thread-safe. + /// PER_ROW mode: pending rows >= maxPendingRows. + /// BATCH mode: pending batches >= effectiveMaxPendingBatches + /// (congestion-adjusted). + bool isUnderBackpressure(); + + /// Signal that a batch completed successfully (all responses non-empty). + /// Increases the effective concurrency window by increment (additive + /// increase). Thread-safe. + void onBatchSuccess(int64_t increment = 2); + + /// Signal that a batch had errors (e.g., empty responses from overload). + /// Halves the effective concurrency window (multiplicative decrease). + /// Thread-safe. + void onBatchError(); + + private: + /// Move a completed row into readyRows_ and notify waiters. + /// Called from the RPC completion callback (runs on executor thread). + void completeRow(int64_t rowId, RowLocation location, RPCResponse response); + + /// Fulfill all waiting promises and clear. Called under lock. + void notifyWaitersLocked(); + + mutable std::mutex mutex_; + std::vector promises_; + + // Input batch storage (shared across PER_ROW and BATCH modes) + std::vector inputBatches_; + + // PER_ROW state + std::deque readyRows_; + int64_t numPendingRows_{0}; + + // BATCH state + int64_t nextBatchId_{0}; + std::deque pendingBatches_; + + // Common + bool noMoreInput_{false}; + RPCStreamingMode streamingMode_{RPCStreamingMode::kPerRow}; + int64_t maxPendingRows_{100}; + int64_t maxPendingBatches_{2}; + + // Congestion control for BATCH mode. + // effectiveMaxPendingBatches_ starts at maxPendingBatches_ and adjusts: + // - On success: min(effective + 1, maxPendingBatches_) (additive increase) + // - On error: max(effective / 2, 1) (multiplicative + // decrease) + int64_t effectiveMaxPendingBatches_{2}; +}; + +} // namespace facebook::velox::exec::rpc diff --git a/velox/exec/rpc/tests/CMakeLists.txt b/velox/exec/rpc/tests/CMakeLists.txt new file mode 100644 index 00000000000..8b2f61801d5 --- /dev/null +++ b/velox/exec/rpc/tests/CMakeLists.txt @@ -0,0 +1,95 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_executable(velox_mock_rpc_client_test MockRPCClientTest.cpp) + +target_link_libraries( + velox_mock_rpc_client_test + velox_mock_rpc_client + GTest::gtest + GTest::gtest_main + Folly::folly +) + +add_test(NAME velox_mock_rpc_client_test COMMAND velox_mock_rpc_client_test) + +add_executable(velox_rpc_state_test RPCStateTest.cpp) + +target_link_libraries( + velox_rpc_state_test + velox_rpc_state + velox_rpc_types + GTest::gtest + GTest::gtest_main + Folly::folly +) + +add_test(NAME velox_rpc_state_test COMMAND velox_rpc_state_test) + +add_executable(velox_rpc_node_test RPCNodeTest.cpp) + +target_link_libraries( + velox_rpc_node_test + velox_mock_rpc_client + velox_core + velox_memory + velox_vector + GTest::gtest + GTest::gtest_main + Folly::folly +) + +add_test(NAME velox_rpc_node_test COMMAND velox_rpc_node_test) + +add_executable(velox_demo_rpc_function_test DemoRPCFunctionTest.cpp) +velox_add_test_headers(velox_demo_rpc_function_test DemoRPCFunction.h) + +target_link_libraries( + velox_demo_rpc_function_test + velox_demo_rpc_function + velox_core + velox_memory + velox_vector + GTest::gtest + GTest::gtest_main + Folly::folly +) + +add_test(NAME velox_demo_rpc_function_test COMMAND velox_demo_rpc_function_test) + +add_executable(velox_rpc_rate_limiter_test RPCRateLimiterTest.cpp) + +target_link_libraries( + velox_rpc_rate_limiter_test + velox_rpc_rate_limiter + GTest::gtest + GTest::gtest_main + Folly::folly +) + +add_test(NAME velox_rpc_rate_limiter_test COMMAND velox_rpc_rate_limiter_test) + +add_executable(velox_rpc_operator_test RPCOperatorTest.cpp Main.cpp) + +target_link_libraries( + velox_rpc_operator_test + velox_demo_rpc_function + velox_rpc_plan_node_translator + velox_async_rpc_function_registry + velox_exec_test_lib + GTest::gtest + Folly::folly +) + +add_test(NAME velox_rpc_operator_test COMMAND velox_rpc_operator_test) diff --git a/velox/exec/rpc/tests/DemoRPCFunction.cpp b/velox/exec/rpc/tests/DemoRPCFunction.cpp new file mode 100644 index 00000000000..150880cecda --- /dev/null +++ b/velox/exec/rpc/tests/DemoRPCFunction.cpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/rpc/tests/DemoRPCFunction.h" + +namespace facebook::velox::exec::rpc { + +void DemoAsyncRPCFunction::initialize( + const core::QueryConfig& /*queryConfig*/, + const std::vector& /*inputTypes*/, + const std::vector& /*constantInputs*/) { + // Create and cache the mock client during initialization. + client_ = std::make_shared( + std::chrono::milliseconds(1), // minimal latency + 0.0); // no errors +} + +std::vector>> +DemoAsyncRPCFunction::dispatchPerRow( + const SelectivityVector& rows, + const std::vector& args) { + std::vector>> results; + + if (args.empty()) { + return results; + } + + auto* promptVector = args[0]->as>(); + if (!promptVector) { + return results; + } + + rows.applyToSelected([&](vector_size_t row) { + if (promptVector->isNullAt(row)) { + // Null input → immediate error response. + results.emplace_back( + row, + folly::makeSemiFuture(RPCResponse{ + .rowId = 0, + .result = "", + .metadata = {}, + .error = "null_input"})); + return; + } + + // Build RPCRequest for MockRPCClient (test utility still uses payload). + RPCRequest request; + request.payload = promptVector->valueAt(row).str(); + + results.emplace_back(row, client_->call(request)); + }); + + return results; +} + +std::vector> +DemoAsyncRPCFunction::signatures() { + // 1-argument form: demo_rpc(prompt) + auto sig = exec::FunctionSignatureBuilder() + .returnType("varchar") + .argumentType("varchar") // prompt + .build(); + + return {std::move(sig)}; +} + +} // namespace facebook::velox::exec::rpc diff --git a/velox/exec/rpc/tests/DemoRPCFunction.h b/velox/exec/rpc/tests/DemoRPCFunction.h new file mode 100644 index 00000000000..73126ed8f35 --- /dev/null +++ b/velox/exec/rpc/tests/DemoRPCFunction.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "velox/common/rpc/clients/MockRPCClient.h" +#include "velox/expression/FunctionSignature.h" +#include "velox/expression/rpc/AsyncRPCFunction.h" + +namespace facebook::velox::exec::rpc { + +using velox::rpc::MockRPCClient; + +/// Demo AsyncRPCFunction that uses MockRPCClient for end-to-end testing. +/// +/// Demonstrates the full AsyncRPCFunction lifecycle: +/// 1. initialize() — creates and caches the MockRPCClient +/// 2. dispatchPerRow() — dispatches per-row RPCs via MockRPCClient +/// 3. buildOutput() — uses base class default (error→null, result→varchar) +/// +/// Returns "Response for: " for each input row. No external +/// dependencies — runs entirely in-process with simulated latency. +/// +/// SQL usage: +/// SELECT demo_rpc('hello world') +/// -- Returns: "Response for: hello world" +class DemoAsyncRPCFunction : public AsyncRPCFunction { + public: + /// Initialize the mock client. Called by RPCOperator during init. + void initialize( + const core::QueryConfig& queryConfig, + const std::vector& inputTypes, + const std::vector& constantInputs) override; + + std::string name() const override { + return "demo_rpc"; + } + + TypePtr resultType() const override { + return VARCHAR(); + } + + /// Dispatch individual RPCs for each active row via MockRPCClient. + /// Null-input rows get an immediate RPCResponse with error="null_input". + std::vector>> + dispatchPerRow( + const SelectivityVector& rows, + const std::vector& args) override; + + /// SQL function signatures for registration. + static std::vector> signatures(); + + private: + std::shared_ptr client_; +}; + +} // namespace facebook::velox::exec::rpc diff --git a/velox/exec/rpc/tests/DemoRPCFunctionRegistration.cpp b/velox/exec/rpc/tests/DemoRPCFunctionRegistration.cpp new file mode 100644 index 00000000000..e41ee48483d --- /dev/null +++ b/velox/exec/rpc/tests/DemoRPCFunctionRegistration.cpp @@ -0,0 +1,25 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/// Static registration of DemoAsyncRPCFunction with the +/// AsyncRPCFunctionRegistry. + +#include "velox/exec/rpc/tests/DemoRPCFunction.h" +#include "velox/expression/rpc/AsyncRPCFunctionRegistry.h" + +using namespace facebook::velox::exec::rpc; + +VELOX_REGISTER_RPC_FUNCTION(demo_rpc, DemoAsyncRPCFunction); diff --git a/velox/exec/rpc/tests/DemoRPCFunctionTest.cpp b/velox/exec/rpc/tests/DemoRPCFunctionTest.cpp new file mode 100644 index 00000000000..c2b24871a47 --- /dev/null +++ b/velox/exec/rpc/tests/DemoRPCFunctionTest.cpp @@ -0,0 +1,152 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/// DemoRPCFunctionTest - End-to-end test for the reference AsyncRPCFunction +/// implementation. +/// +/// Exercises the full lifecycle: initialize -> dispatchPerRow -> buildOutput, +/// verifying that DemoAsyncRPCFunction correctly follows the AsyncRPCFunction +/// contract. + +#include "velox/exec/rpc/tests/DemoRPCFunction.h" + +#include + +#include "velox/common/memory/Memory.h" +#include "velox/core/QueryConfig.h" +#include "velox/vector/FlatVector.h" + +namespace facebook::velox::exec::rpc { +namespace { + +class DemoRPCFunctionTest : public testing::Test { + protected: + static void SetUpTestSuite() { + memory::MemoryManager::testingSetInstance({}); + } + + void SetUp() override { + function_ = std::make_shared(); + pool_ = memory::memoryManager()->addLeafPool(); + + // Follow the lifecycle: initialize() before any dispatch. + function_->initialize(core::QueryConfig{{}}, {}, {}); + } + + std::shared_ptr function_; + std::shared_ptr pool_; +}; + +TEST_F(DemoRPCFunctionTest, endToEnd) { + // Build input vector. + std::vector prompts = {"hello world", "test prompt"}; + const auto numRows = static_cast(prompts.size()); + auto input = BaseVector::create>( + VARCHAR(), numRows, pool_.get()); + for (vector_size_t i = 0; i < numRows; ++i) { + input->set(i, StringView(prompts[i])); + } + + // Dispatch per-row and collect futures. + SelectivityVector rows(numRows); + auto futures = function_->dispatchPerRow(rows, {input}); + ASSERT_EQ(futures.size(), 2); + EXPECT_EQ(futures[0].first, 0); + EXPECT_EQ(futures[1].first, 1); + + // Resolve futures. + std::vector responses; + responses.reserve(futures.size()); + for (auto& [rowIdx, future] : futures) { + responses.push_back(std::move(future).get()); + } + ASSERT_EQ(responses.size(), 2); + for (const auto& resp : responses) { + EXPECT_FALSE(resp.hasError()); + } + + // Build output vector. + auto result = function_->buildOutput(responses, pool_.get()); + ASSERT_EQ(result->size(), 2); + auto* flat = result->asFlatVector(); + EXPECT_FALSE(flat->isNullAt(0)); + EXPECT_FALSE(flat->isNullAt(1)); + // MockRPCClient returns "Response for: ". + EXPECT_EQ(flat->valueAt(0).str(), "Response for: hello world"); + EXPECT_EQ(flat->valueAt(1).str(), "Response for: test prompt"); +} + +TEST_F(DemoRPCFunctionTest, nullInput) { + // Build input with a null row. + auto input = + BaseVector::create>(VARCHAR(), 2, pool_.get()); + input->set(0, StringView("valid prompt")); + input->setNull(1, true); + + SelectivityVector rows(2); + auto futures = function_->dispatchPerRow(rows, {input}); + + // Both rows produce futures (null row gets immediate error response). + ASSERT_EQ(futures.size(), 2); + EXPECT_EQ(futures[0].first, 0); + EXPECT_EQ(futures[1].first, 1); + + // Non-null row should succeed. + auto resp0 = std::move(futures[0].second).get(); + EXPECT_FALSE(resp0.hasError()); + + // Null row should get error="null_input". + auto resp1 = std::move(futures[1].second).get(); + EXPECT_TRUE(resp1.hasError()); + EXPECT_EQ(resp1.error.value(), "null_input"); +} + +TEST_F(DemoRPCFunctionTest, errorResponse) { + // Build output from a response with an error. + std::vector responses; + RPCResponse ok; + ok.rowId = 0; + ok.result = "good result"; + responses.push_back(std::move(ok)); + + RPCResponse err; + err.rowId = 1; + err.error = "RPC failed"; + responses.push_back(std::move(err)); + + auto result = function_->buildOutput(responses, pool_.get()); + ASSERT_EQ(result->size(), 2); + auto* flat = result->asFlatVector(); + EXPECT_FALSE(flat->isNullAt(0)); + EXPECT_EQ(flat->valueAt(0).str(), "good result"); + EXPECT_TRUE(flat->isNullAt(1)); +} + +TEST_F(DemoRPCFunctionTest, signatures) { + auto sigs = DemoAsyncRPCFunction::signatures(); + ASSERT_EQ(sigs.size(), 1); + // demo_rpc(varchar) -> varchar + EXPECT_EQ(sigs[0]->argumentTypes().size(), 1); +} + +TEST_F(DemoRPCFunctionTest, metadata) { + EXPECT_EQ(function_->name(), "demo_rpc"); + EXPECT_EQ(function_->resultType()->kind(), TypeKind::VARCHAR); + EXPECT_EQ(function_->tierKey(), ""); +} + +} // namespace +} // namespace facebook::velox::exec::rpc diff --git a/velox/exec/rpc/tests/Main.cpp b/velox/exec/rpc/tests/Main.cpp new file mode 100644 index 00000000000..39c009ebdd7 --- /dev/null +++ b/velox/exec/rpc/tests/Main.cpp @@ -0,0 +1,31 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/common/memory/Memory.h" +#include "velox/common/process/ThreadDebugInfo.h" + +#include +#include +#include + +// This main is needed for some tests on linux. +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + // Signal handler required for ThreadDebugInfoTest + facebook::velox::process::addDefaultFatalSignalHandler(); + folly::Init init(&argc, &argv, false); + facebook::velox::memory::MemoryManager::initialize({}); + return RUN_ALL_TESTS(); +} diff --git a/velox/exec/rpc/tests/MockRPCClientTest.cpp b/velox/exec/rpc/tests/MockRPCClientTest.cpp new file mode 100644 index 00000000000..84f679562a6 --- /dev/null +++ b/velox/exec/rpc/tests/MockRPCClientTest.cpp @@ -0,0 +1,105 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/// MockRPCClientTest - Tests the mock RPC backend. +/// +/// TESTS: +/// - basicCallAndError: Single call succeeds; error preserves rowId +/// - batchCallAndError: Batch call returns correct count; errors preserve +/// rowIds + +#include "velox/common/rpc/clients/MockRPCClient.h" + +#include +#include + +namespace facebook::velox::rpc { +namespace { + +class MockRPCClientTest : public testing::Test {}; + +TEST_F(MockRPCClientTest, basicCallAndError) { + // Success path + MockRPCClient client(std::chrono::milliseconds(1), 0.0); + + RPCRequest request; + request.rowId = 42; + request.payload = "What is the capital of France?"; + request.options[std::string(rpc::keys::kModel)] = "test-model"; + + auto response = client.call(request).get(); + + EXPECT_FALSE(response.hasError()); + EXPECT_FALSE(response.result.empty()); + EXPECT_EQ(response.rowId, 42); + EXPECT_EQ(client.callCount(), 1); + + // Error path: rowId must be preserved + MockRPCClient errorClient(std::chrono::milliseconds(1), 1.0); + + RPCRequest errorRequest; + errorRequest.rowId = 12345; + errorRequest.payload = "Test"; + + auto errorResponse = errorClient.call(errorRequest).get(); + + EXPECT_TRUE(errorResponse.hasError()); + EXPECT_EQ(errorResponse.rowId, 12345); +} + +TEST_F(MockRPCClientTest, batchCallAndError) { + // Success path + MockRPCClient client(std::chrono::milliseconds(1), 0.0); + + std::vector requests; + for (int i = 0; i < 5; i++) { + RPCRequest req; + req.rowId = i; + req.payload = "Prompt " + std::to_string(i); + requests.push_back(std::move(req)); + } + + auto responses = client.callBatch(requests).get(); + + EXPECT_EQ(responses.size(), 5); + for (int i = 0; i < 5; i++) { + EXPECT_FALSE(responses[i].hasError()); + EXPECT_EQ(responses[i].rowId, i); + } + EXPECT_EQ(client.callCount(), 5); + + // Error path: rowIds must be preserved + MockRPCClient errorClient(std::chrono::milliseconds(1), 1.0); + + std::vector errorRequests; + for (int i = 0; i < 5; i++) { + RPCRequest req; + req.rowId = 100 + i; + req.payload = "Test " + std::to_string(i); + errorRequests.push_back(std::move(req)); + } + + auto errorResponses = errorClient.callBatch(errorRequests).get(); + + EXPECT_EQ(errorResponses.size(), 5); + for (int i = 0; i < 5; i++) { + EXPECT_TRUE(errorResponses[i].hasError()); + EXPECT_EQ(errorResponses[i].rowId, 100 + i); + } +} + +} // namespace +} // namespace facebook::velox::rpc diff --git a/velox/exec/rpc/tests/RPCNodeTest.cpp b/velox/exec/rpc/tests/RPCNodeTest.cpp new file mode 100644 index 00000000000..9a31deac9e4 --- /dev/null +++ b/velox/exec/rpc/tests/RPCNodeTest.cpp @@ -0,0 +1,243 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/// RPCNodeTest - Tests plan node creation and AsyncRPCFunction behavior. +/// +/// TESTS: +/// - rpcNodeCreation: RPCNode can be created with correct fields +/// - rpcNodeWithBatchMode: Batch mode configuration works +/// - functionDispatchPerRow: AsyncRPCFunction.dispatchPerRow() works +/// - functionBuildOutput: AsyncRPCFunction.buildOutput() works +/// - planNodeToString: toString() includes configuration info +/// - planNodeSingleSource: Plan node has exactly one source + +#include "velox/core/PlanNode.h" + +#include + +#include "velox/common/memory/Memory.h" +#include "velox/common/rpc/clients/MockRPCClient.h" +#include "velox/expression/rpc/AsyncRPCFunction.h" +#include "velox/vector/FlatVector.h" + +namespace facebook::velox::exec::rpc { + +using velox::rpc::MockRPCClient; + +namespace { + +/// Mock implementation of AsyncRPCFunction for testing. +class MockAsyncRPCFunction : public AsyncRPCFunction { + public: + explicit MockAsyncRPCFunction(std::shared_ptr client) + : client_(std::move(client)) {} + + std::string name() const override { + return "mock_rpc_function"; + } + + TypePtr resultType() const override { + return VARCHAR(); + } + + std::vector>> + dispatchPerRow( + const SelectivityVector& rows, + const std::vector& args) override { + std::vector>> + results; + + if (args.empty()) { + return results; + } + + auto* promptVector = args[0]->as>(); + if (!promptVector) { + return results; + } + + rows.applyToSelected([&](vector_size_t row) { + if (promptVector->isNullAt(row)) { + results.emplace_back( + row, + folly::makeSemiFuture(RPCResponse{ + .rowId = static_cast(row), + .result = "", + .metadata = {}, + .error = "null_input"})); + return; + } + RPCRequest request; + request.payload = promptVector->valueAt(row).str(); + results.emplace_back(row, client_->call(request)); + }); + + return results; + } + + private: + std::shared_ptr client_; +}; + +class RPCNodeTest : public testing::Test { + protected: + static void SetUpTestSuite() { + memory::MemoryManager::testingSetInstance({}); + } + + void SetUp() override { + client_ = + std::make_shared(std::chrono::milliseconds(10), 0.0); + function_ = std::make_shared(client_); + pool_ = memory::memoryManager()->addLeafPool(); + } + + std::shared_ptr client_; + std::shared_ptr function_; + std::shared_ptr pool_; +}; + +TEST_F(RPCNodeTest, rpcNodeCreation) { + auto rpcNode = std::make_shared( + "rpc-1", + nullptr, + "mock_rpc_function", + VARCHAR(), + "response", + ROW({"response"}, {VARCHAR()}), + std::vector{}, + std::vector{}, + std::vector{}); + + EXPECT_EQ(rpcNode->id(), "rpc-1"); + EXPECT_EQ(rpcNode->name(), "RPC"); + EXPECT_EQ(rpcNode->functionName(), "mock_rpc_function"); + EXPECT_EQ(rpcNode->outputColumn(), "response"); + EXPECT_EQ(rpcNode->streamingMode(), rpc::RPCStreamingMode::kPerRow); + EXPECT_EQ(rpcNode->dispatchBatchSize(), 0); +} + +TEST_F(RPCNodeTest, rpcNodeWithBatchMode) { + auto rpcNode = std::make_shared( + "rpc-2", + nullptr, + "mock_rpc_function", + VARCHAR(), + "result", + ROW({"result"}, {VARCHAR()}), + std::vector{}, + std::vector{}, + std::vector{}, + rpc::RPCStreamingMode::kBatch, + 100); + + EXPECT_EQ(rpcNode->streamingMode(), rpc::RPCStreamingMode::kBatch); + EXPECT_EQ(rpcNode->dispatchBatchSize(), 100); +} + +TEST_F(RPCNodeTest, functionDispatchPerRow) { + std::vector prompts = { + "What is 2+2?", + "What is the capital of France?", + "Explain quantum computing."}; + + const auto numPrompts = static_cast(prompts.size()); + auto promptVector = BaseVector::create>( + VARCHAR(), numPrompts, pool_.get()); + for (vector_size_t i = 0; i < numPrompts; ++i) { + promptVector->set(i, StringView(prompts[i])); + } + + std::vector args = {promptVector}; + SelectivityVector rows(numPrompts); + + auto futures = function_->dispatchPerRow(rows, args); + + // Extract row indices and verify ordering. + std::vector rowIndices; + rowIndices.reserve(futures.size()); + for (const auto& [idx, _] : futures) { + rowIndices.push_back(idx); + } + ASSERT_EQ(rowIndices.size(), 3); + EXPECT_EQ(rowIndices[0], 0); + EXPECT_EQ(rowIndices[1], 1); + EXPECT_EQ(rowIndices[2], 2); + + // Resolve futures and verify responses. + for (auto& [rowIdx, future] : futures) { + auto response = std::move(future).get(); + EXPECT_FALSE(response.hasError()); + EXPECT_FALSE(response.result.empty()); + } +} + +TEST_F(RPCNodeTest, functionBuildOutput) { + std::vector responses; + for (int i = 0; i < 3; ++i) { + RPCResponse resp; + resp.rowId = i; + resp.result = "Response for prompt " + std::to_string(i); + responses.push_back(std::move(resp)); + } + + auto result = function_->buildOutput(responses, pool_.get()); + + EXPECT_EQ(result->size(), 3); + EXPECT_EQ(result->type()->kind(), TypeKind::VARCHAR); + + auto* flatResult = result->asFlatVector(); + EXPECT_EQ(flatResult->valueAt(0).str(), "Response for prompt 0"); + EXPECT_EQ(flatResult->valueAt(1).str(), "Response for prompt 1"); + EXPECT_EQ(flatResult->valueAt(2).str(), "Response for prompt 2"); +} + +TEST_F(RPCNodeTest, planNodeToString) { + auto rpcNode = std::make_shared( + "rpc-1", + nullptr, + "mock_rpc_function", + VARCHAR(), + "response", + ROW({"response"}, {VARCHAR()}), + std::vector{}, + std::vector{}, + std::vector{}); + + std::string str = rpcNode->toString(/*detailed=*/true); + + EXPECT_TRUE(str.find("RPC") != std::string::npos); + EXPECT_TRUE(str.find("mock_rpc_function") != std::string::npos); +} + +TEST_F(RPCNodeTest, planNodeSingleSource) { + auto rpcNode = std::make_shared( + "rpc-1", + nullptr, + "mock_rpc_function", + VARCHAR(), + "response", + ROW({"response"}, {VARCHAR()}), + std::vector{}, + std::vector{}, + std::vector{}); + + ASSERT_EQ(rpcNode->sources().size(), 1); + EXPECT_EQ(rpcNode->source().get(), nullptr); +} + +} // namespace +} // namespace facebook::velox::exec::rpc diff --git a/velox/exec/rpc/tests/RPCOperatorTest.cpp b/velox/exec/rpc/tests/RPCOperatorTest.cpp new file mode 100644 index 00000000000..bf546608a32 --- /dev/null +++ b/velox/exec/rpc/tests/RPCOperatorTest.cpp @@ -0,0 +1,189 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/// RPCOperatorTest - End-to-end task-level test for RPCOperator. +/// +/// Runs a full Velox Task/Driver pipeline: Values → RPCNode → output. +/// Verifies that RPCPlanNodeTranslator, RPCOperator, RPCState, and +/// AsyncRPCFunction wire together correctly through the execution engine. + +#include + +#include "velox/exec/rpc/RPCPlanNodeTranslator.h" +#include "velox/exec/rpc/RPCRateLimiter.h" +#include "velox/exec/rpc/tests/DemoRPCFunction.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/expression/rpc/AsyncRPCFunctionRegistry.h" + +namespace facebook::velox::exec::rpc { + +using namespace facebook::velox::exec::test; + +class RPCOperatorTest : public OperatorTestBase { + protected: + static void SetUpTestCase() { + OperatorTestBase::SetUpTestCase(); + registerRPCPlanNodeTranslator(); + AsyncRPCFunctionRegistry::registerFunction( + "demo_rpc", []() { return std::make_shared(); }); + } + + static void TearDownTestCase() { + OperatorTestBase::TearDownTestCase(); + // Reset MemoryManager to shut down SharedArbitrator executor threads. + // Without this, TSAN reports a non-zero exit because the executor + // threads are still running at process exit. + memory::MemoryManager::testingSetInstance({}); + } + + void TearDown() override { + RPCRateLimiter::testingResetAllState(); + OperatorTestBase::TearDown(); + } + + /// Build an RPCNode on top of a source plan node. + /// argumentColumnNames specifies which source columns are RPC arguments. + core::PlanNodePtr makeRPCNode( + const core::PlanNodePtr& source, + const std::vector& argumentColumnNames) { + auto sourceType = source->outputType(); + + std::vector argCols; + std::vector argTypes; + std::vector constantInputs; + for (const auto& colName : argumentColumnNames) { + argCols.push_back(colName); + argTypes.push_back(sourceType->findChild(colName)); + constantInputs.push_back(nullptr); // Variable, not constant. + } + + // Output type = all source columns + RPC result column. + auto outputNames = sourceType->names(); + auto outputTypes = sourceType->children(); + outputNames.emplace_back("__rpc_result"); + outputTypes.push_back(VARCHAR()); + auto outputType = ROW(std::move(outputNames), std::move(outputTypes)); + + return std::make_shared( + "rpc-0", + source, + "demo_rpc", + VARCHAR(), + "__rpc_result", + outputType, + argCols, + argTypes, + constantInputs); + } +}; + +/// Runs Values(3 rows) → RPCNode → verifies passthrough + RPC result. +TEST_F(RPCOperatorTest, basicPerRow) { + auto input = makeRowVector( + {"prompt"}, + {makeFlatVector( + {"hello world", "test prompt", "third row"})}); + + auto plan = makeRPCNode(PlanBuilder().values({input}).planNode(), {"prompt"}); + + auto result = AssertQueryBuilder(plan).copyResults(pool()); + + ASSERT_EQ(result->size(), 3); + ASSERT_EQ(result->type()->size(), 2); // prompt + __rpc_result + + // Rows may arrive out of order (async dispatch). Collect and sort to verify. + auto* prompts = result->childAt(0)->asFlatVector(); + auto* results = result->childAt(1)->asFlatVector(); + + std::map rows; + for (vector_size_t i = 0; i < result->size(); ++i) { + rows[prompts->valueAt(i).str()] = results->valueAt(i).str(); + } + + EXPECT_EQ(rows["hello world"], "Response for: hello world"); + EXPECT_EQ(rows["test prompt"], "Response for: test prompt"); + EXPECT_EQ(rows["third row"], "Response for: third row"); +} + +/// Null input rows should produce null in the RPC result column. +TEST_F(RPCOperatorTest, nullInput) { + auto promptVector = + makeNullableFlatVector({"valid prompt", std::nullopt}); + auto input = makeRowVector({"prompt"}, {promptVector}); + + auto plan = makeRPCNode(PlanBuilder().values({input}).planNode(), {"prompt"}); + + auto result = AssertQueryBuilder(plan).copyResults(pool()); + + ASSERT_EQ(result->size(), 2); + + auto* prompts = result->childAt(0)->asFlatVector(); + auto* results = result->childAt(1)->asFlatVector(); + + // Find which row is the valid one vs the null one. + for (vector_size_t i = 0; i < result->size(); ++i) { + if (prompts->isNullAt(i)) { + // Null input row should produce null result. + EXPECT_TRUE(results->isNullAt(i)); + } else { + EXPECT_EQ(prompts->valueAt(i).str(), "valid prompt"); + EXPECT_FALSE(results->isNullAt(i)); + EXPECT_EQ(results->valueAt(i).str(), "Response for: valid prompt"); + } + } +} + +/// Multiple source columns — verifies all passthrough columns are preserved. +TEST_F(RPCOperatorTest, multipleColumns) { + auto input = makeRowVector( + {"id", "prompt", "extra"}, + {makeFlatVector({100, 200}), + makeFlatVector({"question one", "question two"}), + makeFlatVector({1.5, 2.5})}); + + // Only "prompt" is an RPC argument; "id" and "extra" are passthrough. + auto plan = makeRPCNode(PlanBuilder().values({input}).planNode(), {"prompt"}); + + auto result = AssertQueryBuilder(plan).copyResults(pool()); + + ASSERT_EQ(result->size(), 2); + ASSERT_EQ(result->type()->size(), 4); // id, prompt, extra, __rpc_result + + // Rows may arrive out of order. Index by prompt to verify. + auto* prompts = result->childAt(1)->asFlatVector(); + auto* ids = result->childAt(0)->asFlatVector(); + auto* extras = result->childAt(2)->asFlatVector(); + auto* results = result->childAt(3)->asFlatVector(); + + std::map rowIndex; + for (vector_size_t i = 0; i < result->size(); ++i) { + rowIndex[prompts->valueAt(i).str()] = i; + } + + auto i1 = rowIndex["question one"]; + EXPECT_EQ(ids->valueAt(i1), 100); + EXPECT_EQ(extras->valueAt(i1), 1.5); + EXPECT_EQ(results->valueAt(i1).str(), "Response for: question one"); + + auto i2 = rowIndex["question two"]; + EXPECT_EQ(ids->valueAt(i2), 200); + EXPECT_EQ(extras->valueAt(i2), 2.5); + EXPECT_EQ(results->valueAt(i2).str(), "Response for: question two"); +} + +} // namespace facebook::velox::exec::rpc diff --git a/velox/exec/rpc/tests/RPCRateLimiterTest.cpp b/velox/exec/rpc/tests/RPCRateLimiterTest.cpp new file mode 100644 index 00000000000..127a0893159 --- /dev/null +++ b/velox/exec/rpc/tests/RPCRateLimiterTest.cpp @@ -0,0 +1,228 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/// RPCRateLimiterTest - Tests for the per-tier RPC rate limiter. +/// +/// RPCRateLimiter provides per-tier concurrency limits with FIFO waiter +/// notification and RAII token-based slot management. +/// +/// Tests cover: +/// - acquireAndRelease: Token acquire increments, destruction decrements. +/// - backpressureWhenAtLimit: checkBackpressure returns future at limit. +/// - backpressureReliefOnRelease: Waiter notified when token released. +/// - fifoWaiterNotification: Multiple waiters notified in FIFO order. +/// - perTierIsolation: Different tiers have independent limits. +/// - perTierMaxPending: setMaxPending overrides global default. +/// - defaultMaxPending: setDefaultMaxPending affects tiers without override. +/// - tokenMoveSemantics: Move constructor/assignment transfer ownership. +/// - testingResetAllState: Reset clears all tiers and restores defaults. + +#include "velox/exec/rpc/RPCRateLimiter.h" + +#include + +#include + +namespace facebook::velox::exec::rpc { +namespace { + +class RPCRateLimiterTest : public testing::Test { + protected: + void SetUp() override { + RPCRateLimiter::testingResetAllState(); + } + + void TearDown() override { + RPCRateLimiter::testingResetAllState(); + } +}; + +TEST_F(RPCRateLimiterTest, acquireAndRelease) { + const std::string tier = "test.tier"; + EXPECT_EQ(RPCRateLimiter::pendingCount(tier), 0); + + { + auto token = RPCRateLimiter::acquire(tier); + EXPECT_EQ(RPCRateLimiter::pendingCount(tier), 1); + } + // Token destroyed — count should be back to 0. + EXPECT_EQ(RPCRateLimiter::pendingCount(tier), 0); +} + +TEST_F(RPCRateLimiterTest, multipleAcquires) { + const std::string tier = "test.tier"; + + auto token1 = RPCRateLimiter::acquire(tier); + auto token2 = RPCRateLimiter::acquire(tier); + auto token3 = RPCRateLimiter::acquire(tier); + EXPECT_EQ(RPCRateLimiter::pendingCount(tier), 3); +} + +TEST_F(RPCRateLimiterTest, backpressureWhenAtLimit) { + const std::string tier = "test.tier"; + RPCRateLimiter::setMaxPending(tier, 2); + + auto token1 = RPCRateLimiter::acquire(tier); + // 1 pending, limit 2 — no backpressure. + EXPECT_FALSE(RPCRateLimiter::checkBackpressure(tier).has_value()); + + auto token2 = RPCRateLimiter::acquire(tier); + // 2 pending, limit 2 — at limit, should get backpressure. + auto future = RPCRateLimiter::checkBackpressure(tier); + EXPECT_TRUE(future.has_value()); +} + +TEST_F(RPCRateLimiterTest, backpressureReliefOnRelease) { + const std::string tier = "test.tier"; + RPCRateLimiter::setMaxPending(tier, 1); + + auto token1 = RPCRateLimiter::acquire(tier); + + // At limit — should block. + auto future = RPCRateLimiter::checkBackpressure(tier); + ASSERT_TRUE(future.has_value()); + EXPECT_FALSE(future->isReady()); + + // Release the token — waiter should be notified. + token1 = RPCRateLimiter::Token(); + EXPECT_TRUE(future->isReady()); + EXPECT_EQ(RPCRateLimiter::pendingCount(tier), 0); +} + +TEST_F(RPCRateLimiterTest, fifoWaiterNotification) { + const std::string tier = "test.tier"; + RPCRateLimiter::setMaxPending(tier, 1); + + auto token = RPCRateLimiter::acquire(tier); + + // Two waiters enqueue while at limit. + auto future1 = RPCRateLimiter::checkBackpressure(tier); + auto future2 = RPCRateLimiter::checkBackpressure(tier); + ASSERT_TRUE(future1.has_value()); + ASSERT_TRUE(future2.has_value()); + + // Release token — only first waiter should be notified (FIFO). + token = RPCRateLimiter::Token(); + EXPECT_TRUE(future1->isReady()); + EXPECT_FALSE(future2->isReady()); + + // Acquire and release again — second waiter should be notified. + { + auto token2 = RPCRateLimiter::acquire(tier); + } + EXPECT_TRUE(future2->isReady()); +} + +TEST_F(RPCRateLimiterTest, perTierIsolation) { + const std::string tier1 = "tier.one"; + const std::string tier2 = "tier.two"; + RPCRateLimiter::setMaxPending(tier1, 1); + RPCRateLimiter::setMaxPending(tier2, 1); + + auto tokenA = RPCRateLimiter::acquire(tier1); + EXPECT_EQ(RPCRateLimiter::pendingCount(tier1), 1); + EXPECT_EQ(RPCRateLimiter::pendingCount(tier2), 0); + + // tier1 at limit, tier2 not. + EXPECT_TRUE(RPCRateLimiter::checkBackpressure(tier1).has_value()); + EXPECT_FALSE(RPCRateLimiter::checkBackpressure(tier2).has_value()); +} + +TEST_F(RPCRateLimiterTest, perTierMaxPending) { + const std::string tier = "test.tier"; + RPCRateLimiter::setDefaultMaxPending(10); + RPCRateLimiter::setMaxPending(tier, 2); + + auto token1 = RPCRateLimiter::acquire(tier); + auto token2 = RPCRateLimiter::acquire(tier); + + // Per-tier limit of 2 applies, not global default of 10. + EXPECT_TRUE(RPCRateLimiter::checkBackpressure(tier).has_value()); +} + +TEST_F(RPCRateLimiterTest, defaultMaxPending) { + RPCRateLimiter::setDefaultMaxPending(2); + const std::string tier = "test.default.tier"; + + auto token1 = RPCRateLimiter::acquire(tier); + EXPECT_FALSE(RPCRateLimiter::checkBackpressure(tier).has_value()); + + auto token2 = RPCRateLimiter::acquire(tier); + // Global default of 2 reached. + EXPECT_TRUE(RPCRateLimiter::checkBackpressure(tier).has_value()); +} + +TEST_F(RPCRateLimiterTest, tokenMoveConstructor) { + const std::string tier = "test.tier"; + + auto token1 = RPCRateLimiter::acquire(tier); + EXPECT_EQ(RPCRateLimiter::pendingCount(tier), 1); + + // Move construct — ownership transfers, count stays 1. + auto token2 = std::move(token1); + EXPECT_EQ(RPCRateLimiter::pendingCount(tier), 1); + + // Destroy moved-from token — no effect. + // (token1 is already moved-from, but let it go out of scope naturally) +} + +TEST_F(RPCRateLimiterTest, tokenMoveAssignment) { + const std::string tier1 = "tier.one"; + const std::string tier2 = "tier.two"; + + auto token1 = RPCRateLimiter::acquire(tier1); + auto token2 = RPCRateLimiter::acquire(tier2); + EXPECT_EQ(RPCRateLimiter::pendingCount(tier1), 1); + EXPECT_EQ(RPCRateLimiter::pendingCount(tier2), 1); + + // Move-assign token2 into token1 — old token1 (tier1) released. + token1 = std::move(token2); + EXPECT_EQ(RPCRateLimiter::pendingCount(tier1), 0); + EXPECT_EQ(RPCRateLimiter::pendingCount(tier2), 1); +} + +TEST_F(RPCRateLimiterTest, testingResetAllState) { + const std::string tier = "test.tier"; + RPCRateLimiter::setDefaultMaxPending(5); + RPCRateLimiter::setMaxPending(tier, 3); + auto token = RPCRateLimiter::acquire(tier); + + // Move the token out so it doesn't decrement during reset. + // (In practice, testingResetAllState clears the tiers map, + // so the token's destructor will create a new empty tier state.) + + RPCRateLimiter::testingResetAllState(); + + // Default restored to 20. + EXPECT_EQ(RPCRateLimiter::defaultMaxPending(), 20); + // Tier state cleared — pending count is 0 for a fresh tier. + EXPECT_EQ(RPCRateLimiter::pendingCount(tier), 0); +} + +TEST_F(RPCRateLimiterTest, noBackpressureBelowLimit) { + const std::string tier = "test.tier"; + RPCRateLimiter::setMaxPending(tier, 5); + + std::vector tokens; + for (int i = 0; i < 4; ++i) { + tokens.push_back(RPCRateLimiter::acquire(tier)); + EXPECT_FALSE(RPCRateLimiter::checkBackpressure(tier).has_value()); + } + EXPECT_EQ(RPCRateLimiter::pendingCount(tier), 4); +} + +} // namespace +} // namespace facebook::velox::exec::rpc diff --git a/velox/exec/rpc/tests/RPCStateTest.cpp b/velox/exec/rpc/tests/RPCStateTest.cpp new file mode 100644 index 00000000000..4a7aa39c778 --- /dev/null +++ b/velox/exec/rpc/tests/RPCStateTest.cpp @@ -0,0 +1,490 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/// RPCStateTest - Tests for the RPCState shared state coordination. +/// +/// RPCState coordinates between the operator's driver thread and async RPC +/// completion callbacks using a mutex-protected state object. +/// +/// Tests cover both PER_ROW and BATCH streaming modes: +/// +/// PER_ROW tests: +/// - basicAddAndClaim: Add pending row, fulfill, claim via tryClaimOrWait +/// - claimOrWaitReturnsFinished: After noMoreInput + all claimed +/// - claimOrWaitMustWait: When no rows ready, returns kMustWait with future +/// - pendingRowCount: Tracks pending count correctly +/// +/// BATCH tests: +/// - basicAddAndPollBatch: Add pending batch, fulfill, poll +/// - pollBatchOrWaitFinished: After noMoreInput + all polled +/// - pollBatchOrWaitMustWait: When no batch ready, returns kMustWait +/// - batchErrorHandling: Exception in batch future handled gracefully +/// - batchRowLocationsCarriedThrough: Row locations carried from pending to +/// ready batch +/// +/// Common tests: +/// - noMoreInputAndIsFinished: Lifecycle signals +/// - inputBatchStorageAndRelease: Batch-reference input storage and release +/// - backpressure: isUnderBackpressure when pending >= max +/// - drainReadyRows: Batched drain of multiple ready rows + +#include "velox/exec/rpc/RPCState.h" + +#include +#include + +#include +#include +#include +#include + +namespace facebook::velox::exec::rpc { +namespace { + +class RPCStateTest : public testing::Test { + protected: + void SetUp() override { + state_ = std::make_shared(); + } + + /// Polls a condition with short waits until it becomes true or timeout. + static void waitFor( + const std::function& condition, + std::chrono::milliseconds timeout = std::chrono::milliseconds(5000)) { + auto deadline = std::chrono::steady_clock::now() + timeout; + while (!condition()) { + ASSERT_LT(std::chrono::steady_clock::now(), deadline) + << "Timed out waiting for condition"; + /* sleep override */ + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + } + + std::shared_ptr state_; +}; + +// ========== PER_ROW mode tests ========== + +TEST_F(RPCStateTest, basicAddAndClaim) { + state_->setStreamingMode(RPCStreamingMode::kPerRow); + + auto [promise, future] = folly::makePromiseContract(); + + state_->addPendingRow( + state_, 42, RPCState::RowLocation{0, 0}, std::move(future)); + EXPECT_EQ(state_->numPendingRows(), 1); + + // Fulfill the promise + RPCResponse response; + response.rowId = 42; + response.result = "test result"; + promise.setValue(std::move(response)); + + // Wait for async callback to move response into readyRows_ + waitFor([&]() { return state_->numPendingRows() == 0; }); + + // Now claim + ContinueFuture waitFuture{ContinueFuture::makeEmpty()}; + std::optional claimedRow; + auto result = state_->tryClaimOrWait(&waitFuture, &claimedRow); + + ASSERT_EQ(result, RPCState::ClaimResult::kClaimed); + ASSERT_TRUE(claimedRow.has_value()); + EXPECT_EQ(claimedRow->rowId, 42); + EXPECT_EQ(claimedRow->response.result, "test result"); +} + +TEST_F(RPCStateTest, addAndClaimDirect) { + state_->setStreamingMode(RPCStreamingMode::kPerRow); + + auto [promise, future] = folly::makePromiseContract(); + + state_->addPendingRow( + state_, 42, RPCState::RowLocation{0, 0}, std::move(future)); + + // Fulfill the promise + RPCResponse response; + response.rowId = 42; + response.result = "test result"; + promise.setValue(std::move(response)); + + // Wait for async callback to move response into readyRows_ + waitFor([&]() { return state_->numPendingRows() == 0; }); + + // Now claim + ContinueFuture waitFuture{ContinueFuture::makeEmpty()}; + std::optional claimedRow; + auto result = state_->tryClaimOrWait(&waitFuture, &claimedRow); + + ASSERT_EQ(result, RPCState::ClaimResult::kClaimed); + ASSERT_TRUE(claimedRow.has_value()); + EXPECT_EQ(claimedRow->rowId, 42); + EXPECT_EQ(claimedRow->response.result, "test result"); + EXPECT_FALSE(claimedRow->response.hasError()); +} + +TEST_F(RPCStateTest, claimOrWaitReturnsFinished) { + state_->setStreamingMode(RPCStreamingMode::kPerRow); + + // No rows, signal noMoreInput + state_->setNoMoreInput(); + + ContinueFuture waitFuture{ContinueFuture::makeEmpty()}; + std::optional claimedRow; + auto result = state_->tryClaimOrWait(&waitFuture, &claimedRow); + + EXPECT_EQ(result, RPCState::ClaimResult::kFinished); +} + +TEST_F(RPCStateTest, claimOrWaitMustWait) { + state_->setStreamingMode(RPCStreamingMode::kPerRow); + + // Add a pending row that hasn't completed yet + auto [promise, future] = folly::makePromiseContract(); + state_->addPendingRow( + state_, 1, RPCState::RowLocation{0, 0}, std::move(future)); + + // Try to claim — should return kMustWait because row isn't ready + ContinueFuture waitFuture{ContinueFuture::makeEmpty()}; + std::optional claimedRow; + auto result = state_->tryClaimOrWait(&waitFuture, &claimedRow); + + EXPECT_EQ(result, RPCState::ClaimResult::kMustWait); + EXPECT_FALSE(claimedRow.has_value()); + + // Fulfill to clean up + RPCResponse response; + response.rowId = 1; + response.result = "done"; + promise.setValue(std::move(response)); +} + +TEST_F(RPCStateTest, pendingRowCount) { + state_->setStreamingMode(RPCStreamingMode::kPerRow); + + EXPECT_EQ(state_->numPendingRows(), 0); + + auto [promise1, future1] = folly::makePromiseContract(); + auto [promise2, future2] = folly::makePromiseContract(); + + state_->addPendingRow( + state_, 1, RPCState::RowLocation{0, 0}, std::move(future1)); + EXPECT_EQ(state_->numPendingRows(), 1); + + state_->addPendingRow( + state_, 2, RPCState::RowLocation{0, 1}, std::move(future2)); + EXPECT_EQ(state_->numPendingRows(), 2); + + // Fulfill first + RPCResponse r1; + r1.rowId = 1; + r1.result = "r1"; + promise1.setValue(std::move(r1)); + + waitFor([&]() { return state_->numPendingRows() == 1; }); + EXPECT_EQ(state_->numPendingRows(), 1); + + // Fulfill second + RPCResponse r2; + r2.rowId = 2; + r2.result = "r2"; + promise2.setValue(std::move(r2)); + + waitFor([&]() { return state_->numPendingRows() == 0; }); + EXPECT_EQ(state_->numPendingRows(), 0); +} + +// ========== BATCH mode tests ========== + +TEST_F(RPCStateTest, basicAddAndPollBatch) { + state_->setStreamingMode(RPCStreamingMode::kBatch); + + auto [promise, future] = + folly::makePromiseContract>(); + + state_->addPendingBatch(state_, std::move(future), {}); + + // Fulfill the promise + std::vector responses; + RPCResponse batchResponse; + batchResponse.rowId = 1; + batchResponse.result = "test"; + responses.push_back(std::move(batchResponse)); + promise.setValue(std::move(responses)); + + // Wait for async callback and poll + waitFor([&]() { + ContinueFuture waitFuture{ContinueFuture::makeEmpty()}; + std::optional readyBatch; + auto result = state_->tryPollBatchOrWait(&waitFuture, &readyBatch); + if (result == RPCState::BatchPollResult::kGotBatch) { + // Verify row locations are empty (we passed empty). + EXPECT_TRUE(readyBatch->rowLocations.empty()); + return true; + } + return false; + }); +} + +TEST_F(RPCStateTest, pollBatchOrWaitFinished) { + state_->setStreamingMode(RPCStreamingMode::kBatch); + + // No batches, signal noMoreInput + state_->setNoMoreInput(); + + ContinueFuture waitFuture{ContinueFuture::makeEmpty()}; + std::optional readyBatch; + auto result = state_->tryPollBatchOrWait(&waitFuture, &readyBatch); + + EXPECT_EQ(result, RPCState::BatchPollResult::kFinished); +} + +TEST_F(RPCStateTest, pollBatchOrWaitMustWait) { + state_->setStreamingMode(RPCStreamingMode::kBatch); + + // Add a pending batch that hasn't completed yet + auto [promise, future] = + folly::makePromiseContract>(); + state_->addPendingBatch(state_, std::move(future), {}); + + // Try to poll — should return kMustWait + ContinueFuture waitFuture{ContinueFuture::makeEmpty()}; + std::optional readyBatch; + auto result = state_->tryPollBatchOrWait(&waitFuture, &readyBatch); + + EXPECT_EQ(result, RPCState::BatchPollResult::kMustWait); + EXPECT_FALSE(readyBatch.has_value()); + + // Fulfill to clean up + promise.setValue(std::vector{}); +} + +TEST_F(RPCStateTest, batchErrorHandling) { + state_->setStreamingMode(RPCStreamingMode::kBatch); + + auto [promise, future] = + folly::makePromiseContract>(); + + state_->addPendingBatch(state_, std::move(future), {}); + + // Set an exception + promise.setException(std::runtime_error("RPC batch failed")); + + // Wait for async callback and poll + ContinueFuture waitFuture{ContinueFuture::makeEmpty()}; + std::optional readyBatch; + + waitFor([&]() { + auto result = state_->tryPollBatchOrWait(&waitFuture, &readyBatch); + return result == RPCState::BatchPollResult::kGotBatch; + }); + + ASSERT_TRUE(readyBatch.has_value()); + ASSERT_TRUE(readyBatch->error.has_value()); + EXPECT_TRUE(readyBatch->error->find("RPC batch failed") != std::string::npos); + EXPECT_TRUE(readyBatch->responses.empty()); +} + +// ========== Common tests ========== + +TEST_F(RPCStateTest, noMoreInputAndIsFinished) { + EXPECT_FALSE(state_->isFinished()); + + state_->setNoMoreInput(); + EXPECT_TRUE(state_->isFinished()); +} + +TEST_F(RPCStateTest, isFinishedWithPendingRows) { + state_->setStreamingMode(RPCStreamingMode::kPerRow); + + auto [promise, future] = folly::makePromiseContract(); + state_->addPendingRow( + state_, 1, RPCState::RowLocation{0, 0}, std::move(future)); + + state_->setNoMoreInput(); + + // Not finished while rows are pending + EXPECT_FALSE(state_->isFinished()); + + // Fulfill and claim + RPCResponse response; + response.rowId = 1; + response.result = "done"; + promise.setValue(std::move(response)); + + waitFor([&]() { return state_->numPendingRows() == 0; }); + + // Claim the ready row + ContinueFuture waitFuture{ContinueFuture::makeEmpty()}; + std::optional claimedRow; + state_->tryClaimOrWait(&waitFuture, &claimedRow); + + // Now finished (noMoreInput + no pending + no ready) + EXPECT_TRUE(state_->isFinished()); +} + +TEST_F(RPCStateTest, inputBatchStorageAndRelease) { + // Store two input batches. + std::vector columns1; // Empty but valid for testing + std::vector columns2; + + auto batchIdx1 = state_->storeInputBatch(std::move(columns1), 3); + auto batchIdx2 = state_->storeInputBatch(std::move(columns2), 2); + + EXPECT_EQ(batchIdx1, 0); + EXPECT_EQ(batchIdx2, 1); + + // Verify we can retrieve columns (empty in this test). + const auto& cols1 = state_->getInputBatchColumns(batchIdx1); + EXPECT_TRUE(cols1.empty()); + + const auto& cols2 = state_->getInputBatchColumns(batchIdx2); + EXPECT_TRUE(cols2.empty()); + + // Release rows incrementally. + state_->releaseRows(batchIdx1, 2); + // Batch 1 still has 1 active row, columns not released yet. + const auto& cols1After = state_->getInputBatchColumns(batchIdx1); + // Still accessible (columns empty in this test, but structure is still + // valid). + (void)cols1After; + + // Release remaining row — batch should be fully released. + state_->releaseRows(batchIdx1, 1); + + // Release all rows from batch 2 at once. + state_->releaseRows(batchIdx2, 2); +} + +TEST_F(RPCStateTest, batchRowLocationsCarriedThrough) { + state_->setStreamingMode(RPCStreamingMode::kBatch); + + auto [promise, future] = + folly::makePromiseContract>(); + + // Pass row locations with the batch. + std::vector locations = {{0, 5}, {0, 10}, {1, 3}}; + state_->addPendingBatch(state_, std::move(future), locations); + + // Fulfill the promise. + std::vector responses; + for (int i = 0; i < 3; ++i) { + RPCResponse r; + r.rowId = i; + r.result = "result_" + std::to_string(i); + responses.push_back(std::move(r)); + } + promise.setValue(std::move(responses)); + + // Poll and verify row locations are carried through. + waitFor([&]() { + ContinueFuture waitFuture{ContinueFuture::makeEmpty()}; + std::optional readyBatch; + auto result = state_->tryPollBatchOrWait(&waitFuture, &readyBatch); + if (result == RPCState::BatchPollResult::kGotBatch) { + EXPECT_EQ(readyBatch->responses.size(), 3); + EXPECT_EQ(readyBatch->rowLocations.size(), 3); + EXPECT_EQ(readyBatch->rowLocations[0].batchIndex, 0); + EXPECT_EQ(readyBatch->rowLocations[0].rowIndex, 5); + EXPECT_EQ(readyBatch->rowLocations[1].batchIndex, 0); + EXPECT_EQ(readyBatch->rowLocations[1].rowIndex, 10); + EXPECT_EQ(readyBatch->rowLocations[2].batchIndex, 1); + EXPECT_EQ(readyBatch->rowLocations[2].rowIndex, 3); + return true; + } + return false; + }); +} + +TEST_F(RPCStateTest, drainReadyRows) { + state_->setStreamingMode(RPCStreamingMode::kPerRow); + + // Add 3 pending rows and fulfill them all. + std::vector> promises; + for (int i = 0; i < 3; ++i) { + auto [promise, future] = folly::makePromiseContract(); + state_->addPendingRow( + state_, + i, + RPCState::RowLocation{0, static_cast(i)}, + std::move(future)); + promises.push_back(std::move(promise)); + } + + // Fulfill all 3. + for (int i = 0; i < 3; ++i) { + RPCResponse response; + response.rowId = i; + response.result = "result_" + std::to_string(i); + promises[i].setValue(std::move(response)); + } + + waitFor([&]() { return state_->numPendingRows() == 0; }); + + // Drain up to 10 rows (should get all 3). + std::vector out; + state_->drainReadyRows(out, 10); + + ASSERT_EQ(out.size(), 3); + // Verify all responses are present (order may vary due to async). + std::set rowIds; + for (const auto& row : out) { + rowIds.insert(row.rowId); + } + EXPECT_EQ(rowIds.count(0), 1); + EXPECT_EQ(rowIds.count(1), 1); + EXPECT_EQ(rowIds.count(2), 1); + + // Drain again — should be empty. + std::vector out2; + state_->drainReadyRows(out2, 10); + EXPECT_TRUE(out2.empty()); +} + +TEST_F(RPCStateTest, backpressure) { + state_->setStreamingMode(RPCStreamingMode::kPerRow); + state_->setMaxPendingRows(2); + + EXPECT_FALSE(state_->isUnderBackpressure()); + + auto [promise1, future1] = folly::makePromiseContract(); + state_->addPendingRow( + state_, 1, RPCState::RowLocation{0, 0}, std::move(future1)); + EXPECT_FALSE(state_->isUnderBackpressure()); + + auto [promise2, future2] = folly::makePromiseContract(); + state_->addPendingRow( + state_, 2, RPCState::RowLocation{0, 1}, std::move(future2)); + EXPECT_TRUE(state_->isUnderBackpressure()); + + // Fulfill one to relieve backpressure + RPCResponse r1; + r1.rowId = 1; + r1.result = "r1"; + promise1.setValue(std::move(r1)); + + waitFor([&]() { return !state_->isUnderBackpressure(); }); + EXPECT_FALSE(state_->isUnderBackpressure()); + + // Clean up + RPCResponse r2; + r2.rowId = 2; + r2.result = "r2"; + promise2.setValue(std::move(r2)); +} + +} // namespace +} // namespace facebook::velox::exec::rpc diff --git a/velox/exec/tests/AdaptivePrefetchTest.cpp b/velox/exec/tests/AdaptivePrefetchTest.cpp new file mode 100644 index 00000000000..a48461a11c9 --- /dev/null +++ b/velox/exec/tests/AdaptivePrefetchTest.cpp @@ -0,0 +1,60 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/AdaptivePrefetch.h" +#include +#include + +namespace facebook::velox::exec { +namespace { + +TEST(AdaptivePrefetchTest, returnsInitialLookAheadDuringMeasurement) { + AdaptivePrefetch prefetch(1000); + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(prefetch.lookAhead(), 4); + } +} + +TEST(AdaptivePrefetchTest, slowIterationsClampToMin) { + AdaptivePrefetch prefetch(1000); + for (int i = 0; i < 16; ++i) { + prefetch.lookAhead(); + std::this_thread::sleep_for(std::chrono::microseconds(10)); + } + EXPECT_EQ(prefetch.lookAhead(), 4); +} + +TEST(AdaptivePrefetchTest, fastIterationsProduceHighLookAhead) { + AdaptivePrefetch prefetch(1000); + for (int i = 0; i < 16; ++i) { + prefetch.lookAhead(); + } + EXPECT_GT(prefetch.lookAhead(), 4); +} + +TEST(AdaptivePrefetchTest, returnsZeroNearEnd) { + AdaptivePrefetch prefetch(20); + int zeroCount = 0; + for (int i = 0; i < 20; ++i) { + if (prefetch.lookAhead() == 0) { + ++zeroCount; + } + } + EXPECT_GT(zeroCount, 0); +} + +} // namespace +} // namespace facebook::velox::exec diff --git a/velox/exec/tests/AggregateCompanionAdapterTest.cpp b/velox/exec/tests/AggregateCompanionAdapterTest.cpp index 61d27961ca8..e050dff7fdb 100644 --- a/velox/exec/tests/AggregateCompanionAdapterTest.cpp +++ b/velox/exec/tests/AggregateCompanionAdapterTest.cpp @@ -77,8 +77,8 @@ class AggregateCompanionRegistryTest : public testing::Test { const std::vector& argTypes, const TypePtr& intermediateType, const TypePtr& resultType) { - const auto& [resolvedResult, resolveIntermediate] = - resolveAggregateFunction(name, argTypes); + const auto& resolvedResult = resolveResultType(name, argTypes); + const auto& resolveIntermediate = resolveIntermediateType(name, argTypes); checkEqual(resolvedResult, resultType); checkEqual(resolveIntermediate, intermediateType); } @@ -414,22 +414,27 @@ TEST_F( TEST_F( AggregateCompanionRegistryTest, resultTypeNotResolvableFromIntermediateType) { - // We only register companion functions for original signatures whose result - // type can be resolved from its intermediate type. + // We only register partial, merge and merge_extract companion functions for + // original signatures whose result type cannot be resolved from its + // intermediate type. std::vector> signatures{ AggregateFunctionSignatureBuilder() - .typeVariable("T") - .returnType("array(T)") - .intermediateType("varbinary") - .argumentType("T") + .integerVariable("a_precision") + .integerVariable("a_scale") + .integerVariable("i_precision", "min(38, a_precision + 10)") + .integerVariable("r_precision", "min(38, a_precision + 4)") + .integerVariable("r_scale", "min(38, a_scale + 4)") + .returnType("DECIMAL(r_precision, r_scale)") + .intermediateType("ROW(DECIMAL(i_precision, a_scale), bigint)") + .argumentType("DECIMAL(a_precision, a_scale)") .build()}; registerDummyAggregateFunction("aggregateFunc6", signatures); - checkAggregateSignaturesCount("aggregateFunc6_partial", 0); + checkAggregateSignaturesCount("aggregateFunc6_partial", 1); - checkAggregateSignaturesCount("aggregateFunc6_merge", 0); + checkAggregateSignaturesCount("aggregateFunc6_merge", 1); - checkAggregateSignaturesCount("aggregateFunc6_merge_extract", 0); + checkAggregateSignaturesCount("aggregateFunc6_merge_extract", 1); checkScalarSignaturesCount("aggregateFunc6_extract", 0); } diff --git a/velox/exec/tests/AggregateFunctionRegistryTest.cpp b/velox/exec/tests/AggregateFunctionRegistryTest.cpp index a225e395729..8bd8dcf43a8 100644 --- a/velox/exec/tests/AggregateFunctionRegistryTest.cpp +++ b/velox/exec/tests/AggregateFunctionRegistryTest.cpp @@ -40,10 +40,44 @@ class AggregateFunctionRegistryTest : public testing::Test { const std::vector& argTypes, const TypePtr& expectedFinalType, const TypePtr& expectedIntermediateType) { - auto [finalType, intermediateType] = - resolveAggregateFunction(name, argTypes); - EXPECT_EQ(*finalType, *expectedFinalType); - EXPECT_EQ(*intermediateType, *expectedIntermediateType); + { + auto finalType = resolveResultType(name, argTypes); + auto intermediateType = resolveIntermediateType(name, argTypes); + VELOX_EXPECT_EQ_TYPES(finalType, expectedFinalType); + VELOX_EXPECT_EQ_TYPES(intermediateType, expectedIntermediateType); + } + + { + std::vector coercions; + auto finalType = resolveResultTypeWithCoercions( + name, argTypes, coercions, TypeCoercer::defaults()); + VELOX_EXPECT_EQ_TYPES(finalType, expectedFinalType); + + EXPECT_EQ(coercions.size(), argTypes.size()); + for (const auto& coercion : coercions) { + EXPECT_EQ(coercion, nullptr); + } + } + } + + void testCoersions( + const std::string& name, + const std::vector& argTypes, + const TypePtr& expectedFinalType, + const std::vector& expectedCoercions) { + VELOX_ASSERT_THROW( + resolveResultType(name, argTypes), + "Aggregate function signature is not supported"); + + std::vector coercions; + auto finalType = resolveResultTypeWithCoercions( + name, argTypes, coercions, TypeCoercer::defaults()); + VELOX_EXPECT_EQ_TYPES(finalType, expectedFinalType); + + EXPECT_EQ(coercions.size(), argTypes.size()); + for (int i = 0; i < coercions.size(); ++i) { + VELOX_EXPECT_EQ_TYPES(coercions[i], expectedCoercions[i]); + } } void clearRegistry() { @@ -67,26 +101,38 @@ TEST_F(AggregateFunctionRegistryTest, basic) { TEST_F(AggregateFunctionRegistryTest, wrongFunctionName) { VELOX_ASSERT_THROW( - resolveAggregateFunction("aggregate_func_nonexist", {BIGINT(), BIGINT()}), + resolveIntermediateType("aggregate_func_nonexist", {BIGINT(), BIGINT()}), "Aggregate function not registered: aggregate_func_nonexist"); VELOX_ASSERT_THROW( - resolveAggregateFunction("aggregate_func_nonexist", {}), + resolveIntermediateType("aggregate_func_nonexist", {}), "Aggregate function not registered: aggregate_func_nonexist"); } TEST_F(AggregateFunctionRegistryTest, wrongArgType) { VELOX_ASSERT_THROW( - resolveAggregateFunction("aggregate_func", {DOUBLE(), BIGINT()}), + resolveIntermediateType("aggregate_func", {DOUBLE(), BIGINT()}), "Aggregate function signature is not supported"); VELOX_ASSERT_THROW( - resolveAggregateFunction("aggregate_func", {BIGINT()}), + resolveResultType("aggregate_func", {BIGINT()}), "Aggregate function signature is not supported"); VELOX_ASSERT_THROW( - resolveAggregateFunction( - "aggregate_func", {BIGINT(), BIGINT(), BIGINT()}), + resolveResultType("aggregate_func", {BIGINT(), BIGINT(), BIGINT()}), "Aggregate function signature is not supported"); } +TEST_F(AggregateFunctionRegistryTest, coercions) { + // (bigint, double) -> bigint + // (T, T) -> T + testCoersions( + "aggregate_func", {DOUBLE(), BIGINT()}, DOUBLE(), {nullptr, DOUBLE()}); + + testCoersions( + "aggregate_func", {TINYINT(), BIGINT()}, BIGINT(), {BIGINT(), nullptr}); + + testCoersions( + "aggregate_func", {INTEGER(), DOUBLE()}, BIGINT(), {BIGINT(), nullptr}); +} + TEST_F(AggregateFunctionRegistryTest, functionNameInMixedCase) { testResolve( "aggregatE_funC", {BIGINT(), DOUBLE()}, BIGINT(), ARRAY(BIGINT())); diff --git a/velox/exec/tests/AggregateSpillBenchmarkBase.cpp b/velox/exec/tests/AggregateSpillBenchmarkBase.cpp index b8f15c3c799..03b5f8fdcf6 100644 --- a/velox/exec/tests/AggregateSpillBenchmarkBase.cpp +++ b/velox/exec/tests/AggregateSpillBenchmarkBase.cpp @@ -36,10 +36,12 @@ std::unique_ptr makeRowContainer( true, // nullableKeys std::vector{}, dependentTypes, - false, // hasNext - false, // isJoinBuild - false, // hasProbedFlag - false, // hasNormalizedKey + /*hasNext=*/false, + /*isJoinBuild=*/false, + /*hasProbedFlag=*/false, + /*hasCountFlag=*/false, + /*hasNormalizedKey=*/false, + /*useListRowIndex=*/false, pool.get()); } diff --git a/velox/exec/tests/AggregateSpillBenchmarkBase.h b/velox/exec/tests/AggregateSpillBenchmarkBase.h index aafb67d009e..f25f35831bb 100644 --- a/velox/exec/tests/AggregateSpillBenchmarkBase.h +++ b/velox/exec/tests/AggregateSpillBenchmarkBase.h @@ -20,7 +20,7 @@ namespace facebook::velox::exec::test { class AggregateSpillBenchmarkBase : public SpillerBenchmarkBase { public: explicit AggregateSpillBenchmarkBase(std::string spillerType) - : spillerType_(spillerType){}; + : spillerType_(spillerType) {}; /// Sets up the test. void setUp() override; diff --git a/velox/exec/tests/AggregationTest.cpp b/velox/exec/tests/AggregationTest.cpp index 58b6d58b443..fc2ecd4942e 100644 --- a/velox/exec/tests/AggregationTest.cpp +++ b/velox/exec/tests/AggregationTest.cpp @@ -18,14 +18,16 @@ #include #include -#include "folly/experimental/EventCount.h" +#include "folly/synchronization/EventCount.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/FileSystems.h" #include "velox/common/memory/SharedArbitrator.h" #include "velox/common/memory/tests/SharedArbitratorTestUtil.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/common/testutil/TestValue.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" #include "velox/exec/Aggregate.h" +#include "velox/exec/AggregateCompanionSignatures.h" #include "velox/exec/GroupingSet.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/PrefixSort.h" @@ -36,7 +38,7 @@ #include "velox/exec/tests/utils/OperatorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/exec/tests/utils/SumNonPODAggregate.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" +#include "velox/type/tests/utils/CustomTypesForTesting.h" namespace facebook::velox::exec::test { @@ -112,35 +114,43 @@ void checkSpillStats(PlanNodeStats& stats, bool expectedSpill) { ASSERT_GT(stats.spilledInputBytes, 0); ASSERT_GT(stats.spilledBytes, 0); ASSERT_GT(stats.spilledPartitions, 0); - ASSERT_GT(stats.customStats[Operator::kSpillRuns].sum, 0); - ASSERT_GT(stats.customStats[Operator::kSpillFillTime].sum, 0); - ASSERT_GT(stats.customStats[Operator::kSpillSortTime].sum, 0); - ASSERT_GT(stats.customStats[Operator::kSpillExtractVectorTime].sum, 0); - ASSERT_GT(stats.customStats[Operator::kSpillSerializationTime].sum, 0); - ASSERT_GT(stats.customStats[Operator::kSpillFlushTime].sum, 0); - ASSERT_GT(stats.customStats[Operator::kSpillWrites].sum, 0); - ASSERT_GT(stats.customStats[Operator::kSpillWriteTime].sum, 0); + ASSERT_GT(stats.customStats[std::string(Operator::kSpillRuns)].sum, 0); + ASSERT_GT(stats.customStats[std::string(Operator::kSpillFillTime)].sum, 0); + ASSERT_GT(stats.customStats[std::string(Operator::kSpillSortTime)].sum, 0); + ASSERT_GT( + stats.customStats[std::string(Operator::kSpillExtractVectorTime)].sum, + 0); + ASSERT_GT( + stats.customStats[std::string(Operator::kSpillSerializationTime)].sum, + 0); + ASSERT_GT(stats.customStats[std::string(Operator::kSpillFlushTime)].sum, 0); + ASSERT_GT(stats.customStats[std::string(Operator::kSpillWrites)].sum, 0); + ASSERT_GT(stats.customStats[std::string(Operator::kSpillWriteTime)].sum, 0); } else { ASSERT_EQ(stats.spilledRows, 0); ASSERT_EQ(stats.spilledInputBytes, 0); ASSERT_EQ(stats.spilledBytes, 0); ASSERT_EQ(stats.spilledPartitions, 0); ASSERT_EQ(stats.spilledFiles, 0); - ASSERT_EQ(stats.customStats[Operator::kSpillRuns].sum, 0); - ASSERT_EQ(stats.customStats[Operator::kSpillFillTime].sum, 0); - ASSERT_EQ(stats.customStats[Operator::kSpillSortTime].sum, 0); - ASSERT_EQ(stats.customStats[Operator::kSpillExtractVectorTime].sum, 0); - ASSERT_EQ(stats.customStats[Operator::kSpillSerializationTime].sum, 0); - ASSERT_EQ(stats.customStats[Operator::kSpillFlushTime].sum, 0); - ASSERT_EQ(stats.customStats[Operator::kSpillWrites].sum, 0); - ASSERT_EQ(stats.customStats[Operator::kSpillWriteTime].sum, 0); + ASSERT_EQ(stats.customStats[std::string(Operator::kSpillRuns)].sum, 0); + ASSERT_EQ(stats.customStats[std::string(Operator::kSpillFillTime)].sum, 0); + ASSERT_EQ(stats.customStats[std::string(Operator::kSpillSortTime)].sum, 0); + ASSERT_EQ( + stats.customStats[std::string(Operator::kSpillExtractVectorTime)].sum, + 0); + ASSERT_EQ( + stats.customStats[std::string(Operator::kSpillSerializationTime)].sum, + 0); + ASSERT_EQ(stats.customStats[std::string(Operator::kSpillFlushTime)].sum, 0); + ASSERT_EQ(stats.customStats[std::string(Operator::kSpillWrites)].sum, 0); + ASSERT_EQ(stats.customStats[std::string(Operator::kSpillWriteTime)].sum, 0); } ASSERT_EQ( - stats.customStats[Operator::kSpillSerializationTime].count, - stats.customStats[Operator::kSpillFlushTime].count); + stats.customStats[std::string(Operator::kSpillSerializationTime)].count, + stats.customStats[std::string(Operator::kSpillFlushTime)].count); ASSERT_EQ( - stats.customStats[Operator::kSpillWrites].count, - stats.customStats[Operator::kSpillWriteTime].count); + stats.customStats[std::string(Operator::kSpillWrites)].count, + stats.customStats[std::string(Operator::kSpillWriteTime)].count); } class AggregationTest : public OperatorTestBase { @@ -332,10 +342,12 @@ class AggregationTest : public OperatorTestBase { std::vector& batches) { std::vector children; dictionary->setSize(count * sizeof(vector_size_t)); - children.push_back(BaseVector::wrapInDictionary( - BufferPtr(nullptr), dictionary, count, rows->childAt(0))); - children.push_back(BaseVector::wrapInDictionary( - BufferPtr(nullptr), dictionary, count, rows->childAt(1))); + children.push_back( + BaseVector::wrapInDictionary( + BufferPtr(nullptr), dictionary, count, rows->childAt(0))); + children.push_back( + BaseVector::wrapInDictionary( + BufferPtr(nullptr), dictionary, count, rows->childAt(1))); children.push_back(children[1]); batches.push_back(vectorMaker_.rowVector(children)); dictionary = AlignedBuffer::allocate( @@ -379,7 +391,9 @@ class AggregationTest : public OperatorTestBase { false, false, true, + false, // hasCountFlag true, + false, pool_.get()); } @@ -478,7 +492,8 @@ TEST_F(AggregationTest, missingFunctionOrSignature) { std::vector{}, std::vector{"agg"}, aggregates, - false, + /*ignoreNullKeys=*/false, + /*noGroupsSpanBatches=*/false, std::move(source)); }) .planNode(); @@ -539,7 +554,8 @@ TEST_F(AggregationTest, missingLambdaFunction) { std::vector{}, std::vector{"agg"}, aggregates, - false, + /*ignoreNullKeys=*/false, + /*noGroupsSpanBatches=*/false, std::move(source)); }) .planNode(); @@ -550,47 +566,6 @@ TEST_F(AggregationTest, missingLambdaFunction) { readCursor(params), "Aggregate function not registered: missing-lambda"); } -TEST_F(AggregationTest, DISABLED_resultTypeMismatch) { - using Step = core::AggregationNode::Step; - - registerAggregateFunction( - "test_aggregate", - {AggregateFunctionSignatureBuilder() - .returnType("bigint") - .intermediateType("bigint") - .argumentType("bigint") - .build()}, - [&](Step /*step*/, - const std::vector& /*argTypes*/, - const TypePtr& /*resultType*/, - const core::QueryConfig& /*config*/) - -> std::unique_ptr { VELOX_UNREACHABLE(); }, - false /*registerCompanionFunctions*/, - true /*overwrite*/); - - for (auto step : {Step::kIntermediate, Step::kPartial}) { - VELOX_ASSERT_THROW( - Aggregate::create( - "test_aggregate", - step, - std::vector{BIGINT()}, - INTEGER(), - core::QueryConfig{{}}), - "Intermediate type mismatch"); - } - - for (auto step : {Step::kFinal, Step::kSingle}) { - VELOX_ASSERT_THROW( - Aggregate::create( - "test_aggregate", - step, - std::vector{BIGINT()}, - INTEGER(), - core::QueryConfig{{}}), - "Final type mismatch"); - } -} - TEST_F(AggregationTest, global) { auto vectors = makeVectors(rowType_, 10, 100); createDuckDbTable(vectors); @@ -670,8 +645,11 @@ TEST_F(AggregationTest, manyGlobalAggregations) { createDuckDbTable(vectors); aggregates.clear(); for (int i = 0; i < rowType->size(); i++) { - aggregates.push_back(fmt::format( - "array_agg({} ORDER BY {})", rowType->nameOf(i), rowType->nameOf(i))); + aggregates.push_back( + fmt::format( + "array_agg({} ORDER BY {})", + rowType->nameOf(i), + rowType->nameOf(i))); } op = PlanBuilder() @@ -876,8 +854,9 @@ TEST_F(AggregationTest, allKeyTypes) { std::vector batches; for (auto i = 0; i < 10; ++i) { - batches.push_back(std::static_pointer_cast( - BatchMaker::createBatch(rowType, 100, *pool_))); + batches.push_back( + std::static_pointer_cast( + BatchMaker::createBatch(rowType, 100, *pool_))); } createDuckDbTable(batches); auto op = @@ -911,12 +890,13 @@ TEST_F(AggregationTest, partialAggregationMemoryLimit) { core::PlanNodeId aggNodeId; auto task = AssertQueryBuilder(duckDbQueryRunner_) .config(QueryConfig::kMaxPartialAggregationMemory, 100) - .plan(PlanBuilder() - .values(vectors) - .partialAggregation({"c0"}, {}) - .capturePlanNodeId(aggNodeId) - .finalAggregation() - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .partialAggregation({"c0"}, {}) + .capturePlanNodeId(aggNodeId) + .finalAggregation() + .planNode()) .assertResults("SELECT distinct c0 FROM tmp"); EXPECT_GT( toPlanStats(task->taskStats()) @@ -934,12 +914,13 @@ TEST_F(AggregationTest, partialAggregationMemoryLimit) { // Count aggregation. task = AssertQueryBuilder(duckDbQueryRunner_) .config(QueryConfig::kMaxPartialAggregationMemory, 1) - .plan(PlanBuilder() - .values(vectors) - .partialAggregation({"c0"}, {"count(1)"}) - .capturePlanNodeId(aggNodeId) - .finalAggregation() - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .partialAggregation({"c0"}, {"count(1)"}) + .capturePlanNodeId(aggNodeId) + .finalAggregation() + .planNode()) .assertResults("SELECT c0, count(1) FROM tmp GROUP BY 1"); EXPECT_GT( toPlanStats(task->taskStats()) @@ -957,12 +938,13 @@ TEST_F(AggregationTest, partialAggregationMemoryLimit) { // Global aggregation. task = AssertQueryBuilder(duckDbQueryRunner_) .config(QueryConfig::kMaxPartialAggregationMemory, 1) - .plan(PlanBuilder() - .values(vectors) - .partialAggregation({}, {"sum(c0)"}) - .capturePlanNodeId(aggNodeId) - .finalAggregation() - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .partialAggregation({}, {"sum(c0)"}) + .capturePlanNodeId(aggNodeId) + .finalAggregation() + .planNode()) .assertResults("SELECT sum(c0) FROM tmp"); EXPECT_EQ( 0, @@ -997,11 +979,12 @@ TEST_F(AggregationTest, partialDistinctWithAbandon) { .config(QueryConfig::kAbandonPartialAggregationMinRows, 100) .config(QueryConfig::kAbandonPartialAggregationMinPct, 50) .maxDrivers(1) - .plan(PlanBuilder() - .values(vectors) - .partialAggregation({"c0"}, {}) - .finalAggregation() - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .partialAggregation({"c0"}, {}) + .finalAggregation() + .planNode()) .assertResults("SELECT distinct c0 FROM tmp"); // with aggregation, just in case. @@ -1009,11 +992,12 @@ TEST_F(AggregationTest, partialDistinctWithAbandon) { .config(QueryConfig::kAbandonPartialAggregationMinRows, 100) .config(QueryConfig::kAbandonPartialAggregationMinPct, 50) .maxDrivers(1) - .plan(PlanBuilder() - .values(vectors) - .partialAggregation({"c0"}, {"sum(c0)"}) - .finalAggregation() - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .partialAggregation({"c0"}, {"sum(c0)"}) + .finalAggregation() + .planNode()) .assertResults("SELECT distinct c0, sum(c0) FROM tmp group by c0"); } @@ -1031,6 +1015,7 @@ TEST_F(AggregationTest, distinctWithGroupingKeysReordered) { options.vectorSize = vectorSize; options.stringVariableLength = false; options.stringLength = 128; + options.nullRatio = 0.1; VectorFuzzer fuzzer(options, pool()); const int numVectors{5}; std::vector vectors; @@ -1042,22 +1027,22 @@ TEST_F(AggregationTest, distinctWithGroupingKeysReordered) { // Distinct aggregation with grouping key with larger prefix encoded size // first. - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); TestScopedSpillInjection scopedSpillInjection(100); - auto task = - AssertQueryBuilder(duckDbQueryRunner_) - .config(QueryConfig::kAbandonPartialAggregationMinRows, 100) - .config(QueryConfig::kAbandonPartialAggregationMinPct, 50) - .spillDirectory(spillDirectory->getPath()) - .config(QueryConfig::kSpillEnabled, true) - .config(QueryConfig::kAggregationSpillEnabled, true) - .config(QueryConfig::kSpillPrefixSortEnabled, true) - .maxDrivers(1) - .plan(PlanBuilder() - .values(vectors) - .singleAggregation({"c4", "c1", "c3", "c2", "c0"}, {}) - .planNode()) - .assertResults("SELECT distinct c4, c1, c3, c2, c0 FROM tmp"); + auto task = AssertQueryBuilder(duckDbQueryRunner_) + .config(QueryConfig::kAbandonPartialAggregationMinRows, 100) + .config(QueryConfig::kAbandonPartialAggregationMinPct, 50) + .spillDirectory(spillDirectory->getPath()) + .config(QueryConfig::kSpillEnabled, true) + .config(QueryConfig::kAggregationSpillEnabled, true) + .config(QueryConfig::kSpillPrefixSortEnabled, true) + .maxDrivers(1) + .plan( + PlanBuilder() + .values(vectors) + .singleAggregation({"c4", "c1", "c3", "c2", "c0"}, {}) + .planNode()) + .assertResults("SELECT distinct c4, c1, c3, c2, c0 FROM tmp"); } TEST_F(AggregationTest, largeValueRangeArray) { @@ -1157,12 +1142,13 @@ TEST_F(AggregationTest, partialAggregationMemoryLimitIncrease) { .config( QueryConfig::kMaxExtendedPartialAggregationMemory, std::to_string(testData.extendedPartialMemoryLimit)) - .plan(PlanBuilder() - .values(vectors) - .partialAggregation({"c0"}, {}) - .capturePlanNodeId(aggNodeId) - .finalAggregation() - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .partialAggregation({"c0"}, {}) + .capturePlanNodeId(aggNodeId) + .finalAggregation() + .planNode()) .assertResults("SELECT distinct c0 FROM tmp"); const auto runtimeStats = toPlanStats(task->taskStats()).at(aggNodeId).customStats; @@ -1233,13 +1219,13 @@ TEST_F(AggregationTest, partialAggregationMaybeReservationReleaseCheck) { TEST_F(AggregationTest, spillAll) { auto inputs = makeVectors(rowType_, 100, 10); - const auto numDistincts = - AssertQueryBuilder(PlanBuilder() - .values(inputs) - .singleAggregation({"c0"}, {}, {}) - .planNode()) - .copyResults(pool_.get()) - ->size(); + const auto numDistincts = AssertQueryBuilder( + PlanBuilder() + .values(inputs) + .singleAggregation({"c0"}, {}, {}) + .planNode()) + .copyResults(pool_.get()) + ->size(); auto plan = PlanBuilder() .values(inputs) @@ -1249,7 +1235,7 @@ TEST_F(AggregationTest, spillAll) { auto results = AssertQueryBuilder(plan).copyResults(pool_.get()); for (int numPartitionBits : {1, 2, 3}) { - auto tempDirectory = exec::test::TempDirectoryPath::create(); + auto tempDirectory = TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); TestScopedSpillInjection scopedSpillInjection(100); auto task = AssertQueryBuilder(plan) @@ -1263,7 +1249,11 @@ TEST_F(AggregationTest, spillAll) { auto stats = task->taskStats().pipelineStats; ASSERT_LT( - 0, stats[0].operatorStats[1].runtimeStats[Operator::kSpillRuns].count); + 0, + stats[0] + .operatorStats[1] + .runtimeStats[std::string(Operator::kSpillRuns)] + .count); // Check spilled bytes. ASSERT_LT(0, stats[0].operatorStats[1].spilledInputBytes); ASSERT_LT(0, stats[0].operatorStats[1].spilledBytes); @@ -1739,7 +1729,7 @@ TEST_F(AggregationTest, outputBatchSizeCheckWithSpill) { inputs = largeVectors; } createDuckDbTable(inputs); - auto tempDirectory = exec::test::TempDirectoryPath::create(); + auto tempDirectory = TempDirectoryPath::create(); core::PlanNodeId aggrNodeId; TestScopedSpillInjection scopedSpillInjection(100); auto task = @@ -1753,11 +1743,12 @@ TEST_F(AggregationTest, outputBatchSizeCheckWithSpill) { .config( QueryConfig::kMaxOutputBatchRows, std::to_string(testData.maxOutputRows)) - .plan(PlanBuilder() - .values(inputs) - .singleAggregation({"c0"}, {"array_agg(c1)"}) - .capturePlanNodeId(aggrNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(inputs) + .singleAggregation({"c0"}, {"array_agg(c1)"}) + .capturePlanNodeId(aggrNodeId) + .planNode()) .assertResults("SELECT c0, array_agg(c1) FROM tmp GROUP BY 1"); ASSERT_GT(toPlanStats(task->taskStats()).at(aggrNodeId).spilledBytes, 0); ASSERT_EQ( @@ -1803,7 +1794,7 @@ TEST_F(AggregationTest, outputBatchSizeCheckWithSpillForOrderedAggr) { SCOPED_TRACE(testData.debugString()); createDuckDbTable(vectors); - auto tempDirectory = exec::test::TempDirectoryPath::create(); + auto tempDirectory = TempDirectoryPath::create(); core::PlanNodeId aggrNodeId; TestScopedSpillInjection scopedSpillInjection(100); auto task = @@ -1817,11 +1808,12 @@ TEST_F(AggregationTest, outputBatchSizeCheckWithSpillForOrderedAggr) { .config( QueryConfig::kMaxOutputBatchRows, std::to_string(testData.maxOutputRows)) - .plan(PlanBuilder() - .values(vectors) - .singleAggregation({"c0"}, {"array_agg(c1 order by c1)"}) - .capturePlanNodeId(aggrNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .singleAggregation({"c0"}, {"array_agg(c1 order by c1)"}) + .capturePlanNodeId(aggrNodeId) + .planNode()) .assertResults( "SELECT c0, array_agg(c1 order by c1) FROM tmp GROUP BY 1"); ASSERT_GT(toPlanStats(task->taskStats()).at(aggrNodeId).spilledBytes, 0); @@ -1847,7 +1839,7 @@ TEST_F(AggregationTest, spillDuringOutputProcessing) { createDuckDbTable({input}); const int numOutputRows = 5; - auto tempDirectory = exec::test::TempDirectoryPath::create(); + auto tempDirectory = TempDirectoryPath::create(); core::PlanNodeId aggrNodeId; TestScopedSpillInjection scopedSpillInjection(100); auto task = @@ -1863,11 +1855,12 @@ TEST_F(AggregationTest, spillDuringOutputProcessing) { .config( QueryConfig::kMaxOutputBatchRows, std::to_string(numOutputRows)) .config(QueryConfig::kSpillNumPartitionBits, "0") - .plan(PlanBuilder() - .values({input}) - .singleAggregation({"c0", "c1"}, {"max(c2)", "min(c3)"}) - .capturePlanNodeId(aggrNodeId) - .planNode()) + .plan( + PlanBuilder() + .values({input}) + .singleAggregation({"c0", "c1"}, {"max(c2)", "min(c3)"}) + .capturePlanNodeId(aggrNodeId) + .planNode()) .assertResults( "SELECT c0, c1, max(c2), min(c3) FROM tmp GROUP BY 1, 2"); @@ -1943,11 +1936,12 @@ TEST_F(AggregationTest, outputBatchSizeCheckWithoutSpill) { .config( QueryConfig::kMaxOutputBatchRows, std::to_string(testData.maxOutputRows)) - .plan(PlanBuilder() - .values(inputs) - .singleAggregation({"c0"}, {"array_agg(c1)"}) - .capturePlanNodeId(aggrNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(inputs) + .singleAggregation({"c0"}, {"array_agg(c1)"}) + .capturePlanNodeId(aggrNodeId) + .planNode()) .assertResults("SELECT c0, array_agg(c1) FROM tmp GROUP BY 1"); ASSERT_EQ( @@ -1973,8 +1967,9 @@ DEBUG_ONLY_TEST_F(AggregationTest, minSpillableMemoryReservation) { createDuckDbTable(batches); for (int32_t minSpillableReservationPct : {5, 50, 100}) { - SCOPED_TRACE(fmt::format( - "minSpillableReservationPct: {}", minSpillableReservationPct)); + SCOPED_TRACE( + fmt::format( + "minSpillableReservationPct: {}", minSpillableReservationPct)); SCOPED_TESTVALUE_SET( "facebook::velox::exec::GroupingSet::addInputForActiveRows", @@ -1991,7 +1986,7 @@ DEBUG_ONLY_TEST_F(AggregationTest, minSpillableMemoryReservation) { currentUsedBytes * minSpillableReservationPct / 100); }))); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); auto task = AssertQueryBuilder(duckDbQueryRunner_) .spillDirectory(spillDirectory->getPath()) @@ -2003,10 +1998,11 @@ DEBUG_ONLY_TEST_F(AggregationTest, minSpillableMemoryReservation) { .config( QueryConfig::kSpillableReservationGrowthPct, std::to_string(minSpillableReservationPct + 1)) - .plan(PlanBuilder() - .values(batches) - .singleAggregation({"c0"}, {"array_agg(c2)", "max(c3)"}) - .planNode()) + .plan( + PlanBuilder() + .values(batches) + .singleAggregation({"c0"}, {"array_agg(c2)", "max(c3)"}) + .planNode()) .assertResults( "SELECT c0, array_agg(c2), max(c3) FROM tmp GROUP BY 1"); OperatorTestBase::deleteTaskAndCheckSpillDirectory(task); @@ -2034,18 +2030,19 @@ TEST_F(AggregationTest, distinctWithSpilling) { for (const auto& testParam : testParams) { createDuckDbTable(testParam.inputs); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); core::PlanNodeId aggrNodeId; TestScopedSpillInjection scopedSpillInjection(100); auto task = AssertQueryBuilder(duckDbQueryRunner_) .spillDirectory(spillDirectory->getPath()) .config(QueryConfig::kSpillEnabled, true) .config(QueryConfig::kAggregationSpillEnabled, true) - .plan(PlanBuilder() - .values(testParam.inputs) - .singleAggregation({"c0"}, {}, {}) - .capturePlanNodeId(aggrNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(testParam.inputs) + .singleAggregation({"c0"}, {}, {}) + .capturePlanNodeId(aggrNodeId) + .planNode()) .assertResults("SELECT distinct c0 FROM tmp"); // Verify that spilling is not triggered. @@ -2060,35 +2057,125 @@ TEST_F(AggregationTest, distinctWithSpilling) { } } -TEST_F(AggregationTest, spillingForAggrsWithDistinct) { - auto vectors = makeVectors(rowType_, 100, 10); +class DistinctAggregationTest : public AggregationTest, + public testing::WithParamInterface { + protected: + std::vector makeVectors( + const RowTypePtr& rowType, + size_t size, + int numVectors, + column_index_t keyChannel) { + std::vector vectors; + vectors.reserve(numVectors); + VectorFuzzer aggVectorfuzzer( + {.vectorSize = size, .nullRatio = GetParam()}, pool()); + // Key column is always non-null. + VectorFuzzer keyVectorFuzzer({.vectorSize = size, .nullRatio = 0}, pool()); + + for (int32_t i = 0; i < numVectors; ++i) { + std::vector children; + children.reserve(rowType->children().size()); + + for (auto idx = 0; idx < rowType->children().size(); idx++) { + auto& vectorFuzzer = + idx == keyChannel ? keyVectorFuzzer : aggVectorfuzzer; + children.push_back(vectorFuzzer.fuzzFlat(rowType->childAt(idx))); + } + + vectors.push_back( + std::make_shared( + pool(), rowType, nullptr, size, children)); + } + return vectors; + } +}; + +TEST_P(DistinctAggregationTest, spillingForAggrsWithDistinct) { + auto vectors = makeVectors(rowType_, 100, 10, 1); createDuckDbTable(vectors); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); core::PlanNodeId aggrNodeId; - TestScopedSpillInjection scopedSpillInjection(100); - auto task = - AssertQueryBuilder(duckDbQueryRunner_) - .spillDirectory(spillDirectory->getPath()) - .config(QueryConfig::kSpillEnabled, true) - .config(QueryConfig::kAggregationSpillEnabled, true) - .plan(PlanBuilder() - .values(vectors) - .singleAggregation({"c1"}, {"count(DISTINCT c0)"}, {}) - .capturePlanNodeId(aggrNodeId) - .planNode()) - .assertResults("SELECT c1, count(DISTINCT c0) FROM tmp GROUP BY c1"); - // Verify that spilling is not triggered. - const auto& queryConfig = task->queryCtx()->queryConfig(); - ASSERT_TRUE(queryConfig.spillEnabled()); - ASSERT_TRUE(queryConfig.aggregationSpillEnabled()); - ASSERT_EQ(toPlanStats(task->taskStats()).at(aggrNodeId).spilledBytes, 0); - OperatorTestBase::deleteTaskAndCheckSpillDirectory(task); + + auto testPlan = [&](const core::PlanNodePtr& plan, const std::string& sql) { + TestScopedSpillInjection scopedSpillInjection(100); + auto task = AssertQueryBuilder(duckDbQueryRunner_) + .spillDirectory(spillDirectory->getPath()) + .config(QueryConfig::kSpillEnabled, "true") + .config(QueryConfig::kAggregationSpillEnabled, "true") + .plan(plan) + .assertResults(sql); + + auto taskStats = exec::toPlanStats(task->taskStats()); + auto& stats = taskStats.at(aggrNodeId); + checkSpillStats(stats, true); + OperatorTestBase::deleteTaskAndCheckSpillDirectory(task); + }; + + // Single aggregate with single input scenario. + auto plan = PlanBuilder() + .values(vectors) + .singleAggregation({"c1"}, {"count(DISTINCT c0)"}, {}) + .capturePlanNodeId(aggrNodeId) + .planNode(); + testPlan(plan, "SELECT c1, count(DISTINCT c0) FROM tmp GROUP BY c1"); + + // Single aggregate with multiple input scenario. + plan = PlanBuilder() + .values(vectors) + .singleAggregation({"c1"}, {"covar_pop(DISTINCT c5, c5)"}, {}) + .capturePlanNodeId(aggrNodeId) + .planNode(); + testPlan(plan, "SELECT c1, covar_pop(DISTINCT c5, c5) FROM tmp GROUP BY c1"); + + // Mixed test including multiple types of distinct aggregate functions. + plan = PlanBuilder() + .values(vectors) + .singleAggregation( + {"c1"}, + {"min(c0)", + "count(c2)", + "count(DISTINCT c0)", + "covar_pop(DISTINCT c5, c5)", + "array_agg(c0 ORDER BY c0)"}, + {}) + .capturePlanNodeId(aggrNodeId) + .planNode(); + testPlan( + plan, + "SELECT c1, min(c0), count(c2), count(DISTINCT c0), covar_pop(DISTINCT c5, c5), array_agg(c0 ORDER BY c0) FROM tmp GROUP BY c1"); + + // Single aggregate with mixed column and constant inputs. + // Tests that constant inputs are properly filtered from the accumulator + // during spilling and spliced back during extraction. + plan = PlanBuilder() + .values(vectors) + .singleAggregation({"c1"}, {"covar_pop(DISTINCT c5, 1.0)"}, {}) + .capturePlanNodeId(aggrNodeId) + .planNode(); + testPlan(plan, "SELECT c1, covar_pop(DISTINCT c5, 1.0) FROM tmp GROUP BY c1"); + + // All-constant distinct inputs. + plan = PlanBuilder() + .values(vectors) + .singleAggregation({"c1"}, {"sum(DISTINCT 3)"}, {}) + .capturePlanNodeId(aggrNodeId) + .planNode(); + testPlan(plan, "SELECT c1, sum(DISTINCT 3) FROM tmp GROUP BY c1"); } +VELOX_INSTANTIATE_TEST_SUITE_P( + DistinctAggregationTest, + DistinctAggregationTest, + ::testing::Values(0, 0.5, 1), + [](const testing::TestParamInfo& info) { + int ratio = static_cast(info.param * 100); + return fmt::format("nullRatio_{}", ratio); + }); + TEST_F(AggregationTest, spillingForAggrsWithSorting) { auto vectors = makeVectors(rowType_, 100, 10); createDuckDbTable(vectors); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); core::PlanNodeId aggrNodeId; @@ -2289,7 +2376,7 @@ TEST_F(AggregationTest, spillPrefixSortOptimization) { 0}}; for (const auto& testData : testSettings) { - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); core::PlanNodeId aggrNodeId; @@ -2317,16 +2404,22 @@ TEST_F(AggregationTest, spillPrefixSortOptimization) { checkSpillStats(stats, true); if (testData.expectedNumPrefixSortKeys > 0) { ASSERT_GE( - stats.customStats.at(PrefixSort::kNumPrefixSortKeys).sum, + stats.customStats.at(std::string(PrefixSort::kNumPrefixSortKeys)) + .sum, testData.expectedNumPrefixSortKeys); ASSERT_EQ( - stats.customStats.at(PrefixSort::kNumPrefixSortKeys).max, + stats.customStats.at(std::string(PrefixSort::kNumPrefixSortKeys)) + .max, testData.expectedNumPrefixSortKeys); ASSERT_EQ( - stats.customStats.at(PrefixSort::kNumPrefixSortKeys).min, + stats.customStats.at(std::string(PrefixSort::kNumPrefixSortKeys)) + .min, testData.expectedNumPrefixSortKeys); } else { - ASSERT_EQ(stats.customStats.count(PrefixSort::kNumPrefixSortKeys), 0); + ASSERT_EQ( + stats.customStats.count( + std::string(PrefixSort::kNumPrefixSortKeys)), + 0); } OperatorTestBase::deleteTaskAndCheckSpillDirectory(task); }; @@ -2350,6 +2443,52 @@ TEST_F(AggregationTest, spillPrefixSortOptimization) { } } +TEST_F(AggregationTest, distinctWithProperPrefixPreGroupedKeys) { + std::vector vectors; + vectors.push_back(makeRowVector({ + makeFlatVector({0, 0, 1, 1, 2}), + makeFlatVector({0, 1, 10, 11, 20}), + })); + createDuckDbTable(vectors); + auto plan = PlanBuilder() + .values(vectors) + .aggregation( + {"c0", "c1"}, + {"c0"}, + {}, + {}, + core::AggregationNode::Step::kSingle, + false) + .planNode(); + AssertQueryBuilder(plan, duckDbQueryRunner_) + .assertResults("SELECT DISTINCT c0, c1 FROM tmp"); +} + +TEST_F(AggregationTest, distinctWithPreGroupedKeysAcrossBatches) { + std::vector vectors; + vectors.push_back(makeRowVector({ + makeFlatVector({0, 0, 1, 1}), + makeFlatVector({0, 1, 10, 11}), + })); + vectors.push_back(makeRowVector({ + makeFlatVector({1, 1, 2, 2}), + makeFlatVector({12, 13, 20, 21}), + })); + createDuckDbTable(vectors); + auto plan = PlanBuilder() + .values(vectors) + .aggregation( + {"c0", "c1"}, + {"c0"}, + {}, + {}, + core::AggregationNode::Step::kSingle, + false) + .planNode(); + AssertQueryBuilder(plan, duckDbQueryRunner_) + .assertResults("SELECT DISTINCT c0, c1 FROM tmp"); +} + TEST_F(AggregationTest, preGroupedAggregationWithSpilling) { std::vector vectors; int64_t val = 0; @@ -2362,7 +2501,7 @@ TEST_F(AggregationTest, preGroupedAggregationWithSpilling) { makeFlatVector(10, [](auto row) { return row; })})); } createDuckDbTable(vectors); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); core::PlanNodeId aggrNodeId; TestScopedSpillInjection scopedSpillInjection(100); auto task = @@ -2370,17 +2509,18 @@ TEST_F(AggregationTest, preGroupedAggregationWithSpilling) { .spillDirectory(spillDirectory->getPath()) .config(QueryConfig::kSpillEnabled, true) .config(QueryConfig::kAggregationSpillEnabled, true) - .plan(PlanBuilder() - .values(vectors) - .aggregation( - {"c0", "c1"}, - {"c0"}, - {"sum(c2)"}, - {}, - core::AggregationNode::Step::kSingle, - false) - .capturePlanNodeId(aggrNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .aggregation( + {"c0", "c1"}, + {"c0"}, + {"sum(c2)"}, + {}, + core::AggregationNode::Step::kSingle, + false) + .capturePlanNodeId(aggrNodeId) + .planNode()) .assertResults("SELECT c0, c1, sum(c2) FROM tmp GROUP BY c0, c1"); auto stats = task->taskStats().pipelineStats; // Verify that spilling is not triggered. @@ -2463,10 +2603,11 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimDuringInputProcessing) { for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); - auto tempDirectory = exec::test::TempDirectoryPath::create(); + auto tempDirectory = TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); auto expectedResult = AssertQueryBuilder( PlanBuilder() @@ -2613,15 +2754,17 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimDuringReserve) { batches.push_back(fuzzer.fuzzRow(rowType)); } - auto tempDirectory = exec::test::TempDirectoryPath::create(); + auto tempDirectory = TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); auto expectedResult = - AssertQueryBuilder(PlanBuilder() - .values(batches) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .planNode()) + AssertQueryBuilder( + PlanBuilder() + .values(batches) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .planNode()) .queryCtx(queryCtx) .copyResults(pool_.get()); @@ -2666,10 +2809,11 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimDuringReserve) { }))); std::thread taskThread([&]() { - AssertQueryBuilder(PlanBuilder() - .values(batches) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .planNode()) + AssertQueryBuilder( + PlanBuilder() + .values(batches) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .planNode()) .queryCtx(queryCtx) .spillDirectory(tempDirectory->getPath()) .config(QueryConfig::kSpillEnabled, true) @@ -2725,7 +2869,7 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimDuringAllocation) { for (const auto enableSpilling : enableSpillings) { SCOPED_TRACE(fmt::format("enableSpilling {}", enableSpilling)); - auto tempDirectory = exec::test::TempDirectoryPath::create(); + auto tempDirectory = TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); queryCtx->testingOverrideMemoryPool( memory::memoryManager()->addRootPool(queryCtx->queryId(), kMaxBytes)); @@ -2851,10 +2995,11 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimDuringOutputProcessing) { for (const auto enableSpilling : enableSpillings) { SCOPED_TRACE(fmt::format("enableSpilling {}", enableSpilling)); - auto tempDirectory = exec::test::TempDirectoryPath::create(); + auto tempDirectory = TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); auto expectedResult = AssertQueryBuilder( PlanBuilder() @@ -2996,7 +3141,7 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimDuringNonReclaimableSection) { for (const auto& testData : testSettings) { SCOPED_TRACE(fmt::format("testData {}", testData.debugString())); - auto tempDirectory = exec::test::TempDirectoryPath::create(); + auto tempDirectory = TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); queryCtx->testingOverrideMemoryPool( memory::memoryManager()->addRootPool(queryCtx->queryId(), kMaxBytes)); @@ -3149,7 +3294,7 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimWithEmptyAggregationTable) { for (const auto enableSpilling : enableSpillings) { SCOPED_TRACE(fmt::format("enableSpilling {}", enableSpilling)); - auto tempDirectory = exec::test::TempDirectoryPath::create(); + auto tempDirectory = TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); queryCtx->testingOverrideMemoryPool( memory::memoryManager()->addRootPool(queryCtx->queryId(), kMaxBytes)); @@ -3460,6 +3605,99 @@ TEST_F(AggregationTest, distinctHang) { .assertResults("SELECT distinct c0, c1 FROM tmp"); } +TEST_F(AggregationTest, distinctWithConstantInput) { + auto data = makeRowVector({ + makeFlatVector({1, 1, 2, 2, 3}), + makeFlatVector({10, 20, 30, 40, 50}), + makeFlatVector({0.1, 0.2, 0.3, 0.4, 0.5}), + }); + createDuckDbTable({data}); + + auto plan = + PlanBuilder() + .values({data}) + .singleAggregation( + {"c0"}, {"max_by(DISTINCT c1, 1)", "corr(DISTINCT 0.5, c2)"}) + .planNode(); + + assertQuery( + plan, + "SELECT c0, max_by(DISTINCT c1, 1), corr(DISTINCT 0.5, c2) FROM tmp GROUP BY c0"); + + // All-constant distinct inputs. + plan = PlanBuilder() + .values({data}) + .singleAggregation({"c0"}, {"sum(DISTINCT 3)"}) + .planNode(); + + assertQuery(plan, "SELECT c0, sum(DISTINCT 3) FROM tmp GROUP BY c0"); + + // All-constant distinct inputs with a filter. + plan = PlanBuilder() + .values({data}) + .project({"c0", "c1", "c2", "c1 > 1 as mask"}) + .singleAggregation({"c0"}, {"sum(DISTINCT 3)"}, {"mask"}) + .planNode(); + + assertQuery( + plan, + "SELECT c0, sum(DISTINCT 3) FILTER (WHERE c1 > 1) FROM tmp GROUP BY c0"); + + // All-constant distinct inputs together with column distinct inputs. + plan = PlanBuilder() + .values({data}) + .singleAggregation({"c0"}, {"sum(DISTINCT 1)", "count(c1)"}) + .planNode(); + + assertQuery( + plan, "SELECT c0, sum(DISTINCT 1), count(c1) FROM tmp GROUP BY c0"); + + // Global aggregation with constant distinct input. + plan = PlanBuilder() + .values({data}) + .project({"c0", "c1", "c2", "c1 > 1 as mask"}) + .singleAggregation( + {}, + {"max_by(DISTINCT c1, 1)", + "sum(DISTINCT 3)", + "sum(DISTINCT 3)", + "count(c1)"}, + {"", "", "mask", ""}) + .planNode(); + + assertQuery( + plan, + "SELECT max_by(DISTINCT c1, 1), sum(DISTINCT 3), sum(DISTINCT 3) FILTER (WHERE c1 > 1), count(c1) FROM tmp"); + + // Mixed empty and non-empty groups. + plan = PlanBuilder() + .values({data}) + .project({"c0", "c1 > 25 as mask"}) + .singleAggregation({"c0"}, {"sum(DISTINCT 3)"}, {"mask"}) + .planNode(); + assertQuery( + plan, + "SELECT c0, sum(DISTINCT 3) FILTER (WHERE c1 > 25) FROM tmp GROUP BY c0"); + + // Group-by with all groups empty. + plan = PlanBuilder() + .values({data}) + .project({"c0", "c1 > 100 as mask"}) + .singleAggregation({"c0"}, {"sum(DISTINCT 3)"}, {"mask"}) + .planNode(); + assertQuery( + plan, + "SELECT c0, sum(DISTINCT 3) FILTER (WHERE c1 > 100) FROM tmp GROUP BY c0"); + + // Global with empty group. + plan = PlanBuilder() + .values({data}) + .project({"c1 > 100 as mask"}) + .singleAggregation({}, {"sum(DISTINCT 3)"}, {"mask"}) + .planNode(); + assertQuery(plan, "SELECT sum(DISTINCT 3) FILTER (WHERE c1 > 100) FROM tmp"); +} + // Trigger memory pool allocation at HashAggregation::populateAggregateInputs by // aggregating null constant. Ensure the allocation happens outside of // HashAggregation's constructor. @@ -3510,10 +3748,11 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimEmptyInput) { } })); - auto tempDirectory = exec::test::TempDirectoryPath::create(); + auto tempDirectory = TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); core::PlanNodeId aggNodeId; auto task = AssertQueryBuilder( @@ -3541,10 +3780,11 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimEmptyOutput) { auto batches = makeVectors(rowType, 100, 5); auto expectedResult = - AssertQueryBuilder(PlanBuilder() - .values(batches) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .planNode()) + AssertQueryBuilder( + PlanBuilder() + .values(batches) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .planNode()) .copyResults(pool_.get()); std::atomic_int numGetOutput{0}; @@ -3582,26 +3822,27 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimEmptyOutput) { } }))); - auto tempDirectory = exec::test::TempDirectoryPath::create(); + auto tempDirectory = TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); core::PlanNodeId aggNodeId; - auto task = - AssertQueryBuilder(PlanBuilder() - .values(batches) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .capturePlanNodeId(aggNodeId) - .planNode()) - .spillDirectory(tempDirectory->getPath()) - .queryCtx(queryCtx) - .config(QueryConfig::kSpillEnabled, true) - .config(QueryConfig::kAggregationSpillEnabled, true) - // Set the output query configs to ensure fetch the result in one - // output batch. - .config(QueryConfig::kPreferredOutputBatchBytes, 1UL << 30) - .config(QueryConfig::kMaxOutputBatchRows, 1024) - .assertResults(expectedResult); + auto task = AssertQueryBuilder( + PlanBuilder() + .values(batches) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .capturePlanNodeId(aggNodeId) + .planNode()) + .spillDirectory(tempDirectory->getPath()) + .queryCtx(queryCtx) + .config(QueryConfig::kSpillEnabled, true) + .config(QueryConfig::kAggregationSpillEnabled, true) + // Set the output query configs to ensure fetch the result in + // one output batch. + .config(QueryConfig::kPreferredOutputBatchBytes, 1UL << 30) + .config(QueryConfig::kMaxOutputBatchRows, 1024) + .assertResults(expectedResult); // Since the spilling is triggered after the aggregation operator has produced // all the output, we don't expect any spilled data. auto taskStats = exec::toPlanStats(task->taskStats()); @@ -3621,7 +3862,7 @@ TEST_F(AggregationTest, maxSpillBytes) { .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) .capturePlanNodeId(aggregationNodeId) .planNode(); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); struct { int32_t maxSpilledBytes; @@ -3678,7 +3919,7 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimFromAggregation) { testingRunArbitration(op->pool()); }))); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); core::PlanNodeId aggrNodeId; auto task = AssertQueryBuilder(duckDbQueryRunner_) @@ -3688,11 +3929,12 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimFromAggregation) { .config( core::QueryConfig::kMaxSpillRunRows, std::to_string(maxSpillRunRows)) - .plan(PlanBuilder() - .values(vectors) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .capturePlanNodeId(aggrNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .capturePlanNodeId(aggrNodeId) + .planNode()) .assertResults( "SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"); auto taskStats = exec::toPlanStats(task->taskStats()); @@ -3702,7 +3944,8 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimFromAggregation) { // reporting in unit test. ASSERT_GE( planStats - .customStats[memory::SharedArbitrator::kMemoryArbitrationWallNanos] + .customStats[std::string( + memory::SharedArbitrator::kMemoryArbitrationWallNanos)] .sum, 0); task.reset(); @@ -3732,7 +3975,7 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimFromDistinctAggregation) { testingRunArbitration(op->pool()); }))); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); core::PlanNodeId aggrNodeId; auto task = AssertQueryBuilder(duckDbQueryRunner_) .spillDirectory(spillDirectory->getPath()) @@ -3741,11 +3984,12 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimFromDistinctAggregation) { .config( core::QueryConfig::kMaxSpillRunRows, std::to_string(maxSpillRunRows)) - .plan(PlanBuilder() - .values(vectors) - .singleAggregation({"c0"}, {}) - .capturePlanNodeId(aggrNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .singleAggregation({"c0"}, {}) + .capturePlanNodeId(aggrNodeId) + .planNode()) .assertResults("SELECT distinct c0 FROM tmp"); auto taskStats = exec::toPlanStats(task->taskStats()); auto& planStats = taskStats.at(aggrNodeId); @@ -3758,7 +4002,7 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimFromDistinctAggregation) { DEBUG_ONLY_TEST_F(AggregationTest, reclaimFromAggregationOnNoMoreInput) { std::vector vectors = createVectors(8, rowType_, fuzzerOpts_); createDuckDbTable(vectors); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); std::atomic injectNoMoreInputOnce{true}; SCOPED_TESTVALUE_SET( @@ -3780,10 +4024,11 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimFromAggregationOnNoMoreInput) { .config(core::QueryConfig::kSpillEnabled, true) .config(core::QueryConfig::kAggregationSpillEnabled, true) .maxDrivers(1) - .plan(PlanBuilder() - .values(vectors) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .planNode()) .assertResults( "SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"); auto stats = task->taskStats().pipelineStats; @@ -3803,7 +4048,7 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimFromAggregationDuringOutput) { } createDuckDbTable(vectors); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); std::atomic_int numInputs{0}; SCOPED_TESTVALUE_SET( "facebook::velox::exec::Driver::runInternal::getOutput", @@ -3825,10 +4070,11 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimFromAggregationDuringOutput) { .config(core::QueryConfig::kPreferredOutputBatchRows, numRows / 10) .maxDrivers(1) //.queryCtx(aggregationQueryCtx) - .plan(PlanBuilder() - .values(vectors) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .planNode()) .assertResults( "SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"); auto stats = task->taskStats().pipelineStats; @@ -3840,17 +4086,18 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimFromAggregationDuringOutput) { TEST_F(AggregationTest, reclaimFromCompletedAggregation) { std::vector vectors = createVectors(8, rowType_, fuzzerOpts_); createDuckDbTable(vectors); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); folly::EventCount arbitrationWait; std::atomic_bool arbitrationWaitFlag{true}; std::thread aggregationThread([&]() { auto task = AssertQueryBuilder(duckDbQueryRunner_) - .plan(PlanBuilder() - .values(vectors) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .planNode()) .assertResults( "SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"); waitForTaskCompletion(task.get()); @@ -3864,6 +4111,108 @@ TEST_F(AggregationTest, reclaimFromCompletedAggregation) { waitForAllTasksToBeDeleted(); } +DEBUG_ONLY_TEST_F(AggregationTest, reclaimWithCompact) { + const int numInputs = 8; + std::vector vectors = + createVectors(numInputs, rowType_, fuzzerOpts_); + createDuckDbTable(vectors); + + struct { + bool spillEnabled; + bool compactionEnabled; + uint64_t compactBytes; + bool expectedReclaimable; + bool expectSpill; + + std::string debugString() const { + return fmt::format( + "spillEnabled {}, compactionEnabled {}, compactBytes {}," + " expectedReclaimable {}, expectSpill {}", + spillEnabled, + compactionEnabled, + compactBytes, + expectedReclaimable, + expectSpill); + } + } testSettings[] = { + // Spill enabled, compaction enabled, compaction frees enough bytes -> no + // spill. + {true, true, 1UL << 30, true, false}, + // Spill enabled, compaction enabled, compaction frees 0 bytes -> spill. + {true, true, 0, true, true}, + // Spill enabled, compaction disabled -> reclaimable via spill. + {true, false, 0, true, true}, + // Spill disabled, compaction enabled -> non-reclaimable (no compactable + // aggregates with array_agg). + {false, true, 0, false, false}, + // Spill disabled, compaction disabled -> non-reclaimable. + {false, false, 0, false, false}, + }; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + std::atomic_int inputCount{0}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::HashAggregation::reclaim::compact", + std::function(([&](uint64_t* compactedBytes) { + *compactedBytes = testData.compactBytes; + }))); + + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::addInput", + std::function(([&](exec::Operator* op) { + if (op->operatorCtx()->operatorType() != "Aggregation") { + return; + } + if (++inputCount != numInputs / 2) { + return; + } + ASSERT_EQ(op->canReclaim(), testData.expectedReclaimable); + if (testData.expectedReclaimable) { + testingRunArbitration(op->pool()); + } else { + // When neither spill nor compaction can reclaim memory, calling + // reclaim() directly would fail the canReclaim() check. Under real + // memory pressure this operator would cause the query to OOM. + memory::MemoryReclaimer::Stats reclaimStats; + VELOX_ASSERT_THROW(op->reclaim(0, reclaimStats), ""); + } + }))); + + const auto spillDirectory = exec::test::TempDirectoryPath::create(); + core::PlanNodeId aggrNodeId; + AssertQueryBuilder queryBuilder(duckDbQueryRunner_); + queryBuilder.config( + core::QueryConfig::kAggregationMemoryCompactionReclaimEnabled, + testData.compactionEnabled); + if (testData.spillEnabled) { + queryBuilder.spillDirectory(spillDirectory->getPath()) + .config(core::QueryConfig::kSpillEnabled, true) + .config(core::QueryConfig::kAggregationSpillEnabled, true); + } + auto task = + queryBuilder + .plan( + PlanBuilder() + .values(vectors) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .capturePlanNodeId(aggrNodeId) + .planNode()) + .assertResults( + "SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"); + auto taskStats = exec::toPlanStats(task->taskStats()); + auto& planStats = taskStats.at(aggrNodeId); + if (testData.expectSpill) { + ASSERT_GT(planStats.spilledBytes, 0); + } else { + ASSERT_EQ(planStats.spilledBytes, 0); + } + task.reset(); + waitForAllTasksToBeDeleted(); + } +} + TEST_F(AggregationTest, ignoreNullKeys) { // Some keys are null. auto data = makeRowVector({ @@ -4029,7 +4378,9 @@ TEST_F(AggregationTest, destroyAfterPartialInitialization) { false, // hasNext false, // isJoinBuild false, // hasProbedFlag + false, // hasCountFlag false, // hasNormalizedKeys + false, // useListRowIndex pool()); const auto rowColumn = rows.columnAt(0); agg.setOffsets( @@ -4138,4 +4489,100 @@ TEST_F(AggregationTest, nanKeys) { {makeRowVector({c0, c1}), c1}, {makeRowVector({e0, e1}), e1}); } + +TEST_F(AggregationTest, keysProvideCustomComparison) { + // Columns reused across test cases. + auto c0 = makeFlatVector( + {0, 1, 256, 257, 512, 513}, + velox::test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON()); + auto c1 = makeFlatVector({1, 2, 1, 2, 1, 2}); + // Expected result columns reused across test cases. A deduplicated version of + // c0 and c1. + auto e0 = makeFlatVector( + {0, 1}, velox::test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON()); + auto e1 = makeFlatVector({1, 2}); + + auto testDistinctAgg = [&](const std::vector& aggKeys, + const std::vector& inputCols, + const std::vector& expectedCols) { + auto plan = PlanBuilder() + .values({makeRowVector(inputCols)}) + .singleAggregation(aggKeys, {}, {}) + .planNode(); + AssertQueryBuilder(plan).assertResults(makeRowVector(expectedCols)); + }; + + // Test with a primitive type key. + testDistinctAgg({"c0"}, {c0}, {e0}); + // Multiple key columns. + testDistinctAgg({"c0", "c1"}, {c0, c1}, {e0, e1}); + + // Test with a complex type key. + testDistinctAgg({"c0"}, {makeRowVector({c0, c1})}, {makeRowVector({e0, e1})}); + // Multiple key columns. + testDistinctAgg( + {"c0", "c1"}, + {makeRowVector({c0, c1}), c1}, + {makeRowVector({e0, e1}), e1}); +} + +// Test that aggregation spill uses the aggregation_spill_file_create_config +// when set, and other spillable operators use the default +// spill_file_create_config. +DEBUG_ONLY_TEST_F(AggregationTest, aggregationSpillFileCreateConfig) { + auto vectors = makeVectors(rowType_, 32, 100); + createDuckDbTable(vectors); + + auto tempDirectory = TempDirectoryPath::create(); + + std::atomic_bool aggregationConfigVerified{false}; + std::atomic_bool defaultConfigVerified{false}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::isBlocked", + std::function([&](exec::Operator* op) { + const auto* spillConfig = op->testingSpillConfig(); + if (spillConfig == nullptr) { + return; + } + const auto& opType = op->operatorType(); + if (opType == "Aggregation" || opType == "PartialAggregation") { + // Aggregation operators should use + // aggregation_spill_file_create_config. + ASSERT_EQ(spillConfig->fileCreateConfig, "test_aggregation_config") + << "Operator: " << opType; + aggregationConfigVerified = true; + } else { + // Other spillable operators (e.g., OrderBy) should use the default + // spill_file_create_config. + ASSERT_EQ(spillConfig->fileCreateConfig, "test_default_config") + << "Operator: " << opType; + defaultConfigVerified = true; + } + })); + + // Build a plan with aggregation and orderBy. Aggregation operators should use + // aggregation_spill_file_create_config and orderBy should use the default + // spill_file_create_config. + TestScopedSpillInjection scopedSpillInjection(100); + AssertQueryBuilder(duckDbQueryRunner_) + .spillDirectory(tempDirectory->getPath()) + .config(QueryConfig::kSpillEnabled, true) + .config(QueryConfig::kAggregationSpillEnabled, true) + .config(QueryConfig::kOrderBySpillEnabled, true) + .config(QueryConfig::kSpillFileCreateConfig, "test_default_config") + .config( + QueryConfig::kAggregationSpillFileCreateConfig, + "test_aggregation_config") + .plan( + PlanBuilder() + .values(vectors) + .singleAggregation({"c0", "c1"}, {"sum(c2)"}) + .orderBy({"c0 ASC NULLS LAST"}, false) + .planNode()) + .assertResults( + "SELECT c0, c1, sum(c2) FROM tmp GROUP BY c0, c1 ORDER BY c0 ASC NULLS LAST"); + + ASSERT_TRUE(aggregationConfigVerified.load()); + ASSERT_TRUE(defaultConfigVerified.load()); +} } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/AssertQueryBuilderTest.cpp b/velox/exec/tests/AssertQueryBuilderTest.cpp index 6ae70783f7a..e7155436c24 100644 --- a/velox/exec/tests/AssertQueryBuilderTest.cpp +++ b/velox/exec/tests/AssertQueryBuilderTest.cpp @@ -20,6 +20,7 @@ #include "velox/vector/fuzzer/VectorFuzzer.h" namespace facebook::velox::exec::test { +using namespace facebook::velox::common::testutil; class AssertQueryBuilderTest : public HiveConnectorTestBase {}; @@ -94,8 +95,11 @@ TEST_F(AssertQueryBuilderTest, hiveSplits) { // Single leaf node with two splits. auto makeSplits = [](const std::string& path, size_t numRepeats = 1) { - std::vector> splits( - numRepeats, makeHiveConnectorSplit(path)); + std::vector> splits; + splits.reserve(numRepeats); + for (size_t i = 0; i < numRepeats; ++i) { + splits.emplace_back(makeHiveConnectorSplit(path)); + } return splits; }; diff --git a/velox/exec/tests/AssignUniqueIdTest.cpp b/velox/exec/tests/AssignUniqueIdTest.cpp index 6a34e3bc141..762dbdcbe14 100644 --- a/velox/exec/tests/AssignUniqueIdTest.cpp +++ b/velox/exec/tests/AssignUniqueIdTest.cpp @@ -21,7 +21,7 @@ #include "velox/exec/tests/utils/QueryAssertions.h" namespace facebook::velox::exec { - +using namespace facebook::velox::common::testutil; using namespace facebook::velox::test; using namespace facebook::velox::exec::test; @@ -200,22 +200,32 @@ TEST_F(AssignUniqueIdTest, barrier) { .assignUniqueId("row_number") .project({"c0", "c1", "row_number"}) .planNode(); + struct { + bool hasBarrier; + bool serialExecution; - for (const auto barrierExecution : {false, true}) { - SCOPED_TRACE(fmt::format("barrierExecution {}", barrierExecution)); + std::string toString() const { + return fmt::format( + "hasBarrier: {}, serialExecution: {}", hasBarrier, serialExecution); + } + } testSettings[] = { + {false, false}, {false, true}, {true, false}, {true, true}}; + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.toString()); std::shared_ptr task; auto result = AssertQueryBuilder(plan) .splits(makeHiveConnectorSplits(tempFiles)) - .serialExecution(true) - .barrierExecution(barrierExecution) + .serialExecution(testData.serialExecution) + .maxDrivers(testData.serialExecution ? 1 : 3) + .barrierExecution(testData.hasBarrier) .copyResults(pool(), task); auto results = split(result, numSplits); verifyUniqueId(vectors, results); const auto taskStats = task->taskStats(); - ASSERT_EQ(taskStats.numBarriers, barrierExecution ? numSplits : 0); + ASSERT_EQ(taskStats.numBarriers, testData.hasBarrier ? numSplits : 0); ASSERT_EQ(taskStats.numFinishedSplits, numSplits); } } diff --git a/velox/exec/tests/AsyncConnectorTest.cpp b/velox/exec/tests/AsyncConnectorTest.cpp index 3904a610fd6..d2f5261c286 100644 --- a/velox/exec/tests/AsyncConnectorTest.cpp +++ b/velox/exec/tests/AsyncConnectorTest.cpp @@ -13,8 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include +#include #include "velox/connectors/Connector.h" +#include "velox/connectors/ConnectorRegistry.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/tests/utils/OperatorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" @@ -60,7 +61,8 @@ class TestSplit : public connector::ConnectorSplit { return ContinueFuture::makeEmpty(); } - auto [promise, future] = makeVeloxContinuePromiseContract(); + auto [promise, future] = + makeVeloxContinuePromiseContract("TestSplit::touch"); promise_ = std::move(promise); scheduler_.addFunction( @@ -126,7 +128,7 @@ class TestDataSource : public connector::DataSource { return 0; } - std::unordered_map runtimeStats() override { + std::unordered_map getRuntimeStats() override { return {}; } @@ -186,11 +188,12 @@ class AsyncConnectorTest : public OperatorTestBase { std::unordered_map()), nullptr, nullptr); - connector::registerConnector(testConnector); + connector::ConnectorRegistry::global().insert( + testConnector->connectorId(), testConnector); } void TearDown() override { - connector::unregisterConnector(kTestConnectorId); + connector::ConnectorRegistry::global().erase(kTestConnectorId); OperatorTestBase::TearDown(); } }; diff --git a/velox/exec/tests/CMakeLists.txt b/velox/exec/tests/CMakeLists.txt index 010f343dea0..bda084e8ad9 100644 --- a/velox/exec/tests/CMakeLists.txt +++ b/velox/exec/tests/CMakeLists.txt @@ -19,6 +19,11 @@ add_executable( AggregateCompanionSignaturesTest.cpp DummyAggregateFunction.cpp ) +velox_add_test_headers( + aggregate_companion_functions_test + DummyAggregateFunction.h + PrestoQueryRunnerIntermediateTypeTransformTestBase.h +) add_test( NAME aggregate_companion_functions_test @@ -36,116 +41,123 @@ target_link_libraries( GTest::gtest_main ) -add_executable( - velox_exec_test +# Sources are ordered via greedy bin-packing by measured EC2 execution time. +# With VELOX_TESTS_PER_GROUP=10, CMake takes files positionally in batches of +# 10 (files 1-10 = group0, 11-20 = group1, etc.). +# +# Measured per-file EC2 timings (sequential, 8-core, 30GB): +# IndexLookupJoinTest 783s, MultiFragmentTest 599s, +# HashJoinTestExtra ~585s, SpillerTest 584s, HashJoinTest ~575s, +# TableWriterTest 225s, TableScanTest 189s, IndexLookupJoinTestExtra 188s, +# MergeJoinTest 111s, AggregationTest 105s, OutputBufferManagerTest 98s, +# ScaleWriterLocalPartitionTest 96s, OrderByTest 71s, TopNRowNumberTest 64s, +# HashTableTest 60s, StreamingAggregationTest 58s, ExchangeClientTest 54s, +# RowNumberTest 42s. +# +# Max group ~784s (bounded by IndexLookupJoinTest.cpp). +set( + VELOX_EXEC_TEST_SOURCES + # group0 (~784s): IndexLookupJoinTest 783 + 9 lightweight + IndexLookupJoinTest.cpp + ConcatFilesSpillMergeStreamTest.cpp + MixedUnionWithTableScanTest.cpp + MemoryReclaimerTest.cpp + EnforceDistinctTest.cpp + TraceUtilTest.cpp + HashPartitionFunctionTest.cpp + SpatialIndexTest.cpp + ValuesTest.cpp + ParallelProjectTest.cpp + # group1 (~599s): MultiFragmentTest 599 + 9 lightweight + MultiFragmentTest.cpp + EnforceSingleRowTest.cpp + FilterToExpressionTest.cpp + ScaledScanControllerTest.cpp + HilbertIndexTest.cpp + OperatorTraceTest.cpp + LimitTest.cpp + SplitListenerTest.cpp AddressableNonNullValueListTest.cpp - AggregationTest.cpp - AggregateFunctionRegistryTest.cpp ArrowStreamTest.cpp - AssignUniqueIdTest.cpp - AsyncConnectorTest.cpp - ConcatFilesSpillMergeStreamTest.cpp - ContainerRowSerdeTest.cpp + # group2 (~585s): HashJoinTestExtra ~585 + 9 lightweight + HashJoinTestExtra.cpp + AggregateFunctionRegistryTest.cpp + RoundRobinPartitionFunctionTest.cpp ColumnStatsCollectorTest.cpp - CustomJoinTest.cpp - EnforceSingleRowTest.cpp - ExchangeClientTest.cpp + MixedUnionTest.cpp ExpandTest.cpp - FilterProjectTest.cpp - FilterToExpressionTest.cpp FunctionResolutionTest.cpp - HashBitRangeTest.cpp - HashJoinBridgeTest.cpp - HashJoinTest.cpp - HashPartitionFunctionTest.cpp - HashTableTest.cpp - IndexLookupJoinTest.cpp - LimitTest.cpp - LocalPartitionTest.cpp - Main.cpp - MarkDistinctTest.cpp - MemoryReclaimerTest.cpp - MergeJoinTest.cpp - MergeTest.cpp - MergerTest.cpp - MultiFragmentTest.cpp - NestedLoopJoinTest.cpp - OrderByTest.cpp - OperatorTraceTest.cpp - OutputBufferManagerTest.cpp - ParallelProjectTest.cpp + SpillStatsTest.cpp + MarkSortedTest.cpp PartitionedOutputTest.cpp - PlanNodeSerdeTest.cpp - PlanNodeToStringTest.cpp + # group3 (~584s): SpillerTest 584 + 9 lightweight + SpillerTest.cpp PlanNodeToSummaryStringTest.cpp + CountingJoinTest.cpp + CustomJoinTest.cpp + TaskListenerTest.cpp + SplitTest.cpp + SqlTest.cpp + WindowFunctionRegistryTest.cpp + StreamingEnforceDistinctTest.cpp + HashBitRangeTest.cpp + # group4 (~575s): HashJoinTest ~575 + 9 lightweight + HashJoinTest.cpp + SpillTest.cpp + WindowTest.cpp PrefixSortTest.cpp + MergerTest.cpp + LocalPartitionTest.cpp PrintPlanWithStatsTest.cpp ProbeOperatorStateTest.cpp - TraceUtilTest.cpp - RoundRobinPartitionFunctionTest.cpp - RowContainerTest.cpp - RowNumberTest.cpp - ScaledScanControllerTest.cpp - ScaleWriterLocalPartitionTest.cpp - SortBufferTest.cpp - SpillerTest.cpp - SpillTest.cpp - SplitListenerTest.cpp - SplitTest.cpp - SqlTest.cpp - StreamingAggregationTest.cpp - TableScanTest.cpp + MarkDistinctTest.cpp + MergeTest.cpp + # group5 (~445s): TableWriterTest 225 + OutputBufMgr 98 + # + TopNRowNumber 64 + StreamingAgg 58 TableWriterTest.cpp - TaskListenerTest.cpp - ThreadDebugInfoTest.cpp + OutputBufferManagerTest.cpp TopNRowNumberTest.cpp + StreamingAggregationTest.cpp + ContainerRowSerdeTest.cpp + RowContainerTest.cpp TopNTest.cpp + WriterFuzzerUtilTest.cpp + PlanNodeStatsTest.cpp + ThreadDebugInfoTest.cpp + # group6 (~467s): TableScanTest 189 + AggregationTest 105 + # + OrderBy 71 + HashTable 60 + RowNumber 42 + TableScanTest.cpp + AggregationTest.cpp + OrderByTest.cpp + HashTableCacheTest.cpp + HashTableTest.cpp + RowNumberTest.cpp UnnestTest.cpp - UnorderedStreamReaderTest.cpp - ValuesTest.cpp + PlanNodeSerdeTest.cpp + HashJoinBridgeTest.cpp + SortBufferTest.cpp VectorHasherTest.cpp - WindowFunctionRegistryTest.cpp - WindowTest.cpp - WriterFuzzerUtilTest.cpp -) - -if(VELOX_ENABLE_GEO) - target_sources(velox_exec_test PRIVATE SpatialJoinTest.cpp) -endif() - -add_executable( - velox_exec_infra_test - AssertQueryBuilderTest.cpp - DriverTest.cpp - FunctionSignatureBuilderTest.cpp - GroupedExecutionTest.cpp - Main.cpp - OperatorUtilsTest.cpp - PlanBuilderTest.cpp - PrestoQueryRunnerTest.cpp - QueryAssertionsTest.cpp - TaskTest.cpp - TreeOfLosersTest.cpp -) - -add_test(NAME velox_exec_test COMMAND velox_exec_test WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) - -# TODO: Revert back to 3000 once is fixed. -# https://github.com/facebookincubator/velox/issues/13879 -set_tests_properties(velox_exec_test PROPERTIES TIMEOUT 6000) - -add_test( - NAME velox_exec_infra_test - COMMAND velox_exec_infra_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + # group7 (~449s): IndexLookupJoinTestExtra 188 + # + MergeJoin 111 + ScaleWriter 96 + Exchange 54 + IndexLookupJoinTestExtra.cpp + MergeJoinTest.cpp + ScaleWriterLocalPartitionTest.cpp + ExchangeClientTest.cpp + NestedLoopJoinTest.cpp + UnorderedStreamReaderTest.cpp + PlanNodeToStringTest.cpp + AssignUniqueIdTest.cpp + FilterProjectTest.cpp + AsyncConnectorTest.cpp ) -target_link_libraries( - velox_exec_test +set( + VELOX_EXEC_TEST_DEPS velox_aggregates velox_dwio_common velox_dwio_common_exception velox_dwio_common_test_utils + velox_dwio_orc_reader velox_dwio_parquet_reader velox_dwio_parquet_writer velox_exec @@ -161,6 +173,7 @@ target_link_libraries( velox_test_util velox_type velox_type_test_lib + velox_trace velox_vector velox_vector_fuzzer velox_writer_fuzzer @@ -170,7 +183,6 @@ target_link_libraries( Boost::date_time Boost::filesystem Boost::program_options - Boost::regex Boost::thread Boost::system GTest::gtest @@ -182,6 +194,45 @@ target_link_libraries( fmt::fmt ) +velox_add_grouped_tests( + PREFIX velox_exec_test + SOURCES ${VELOX_EXEC_TEST_SOURCES} + DEPS ${VELOX_EXEC_TEST_DEPS} + EXTRA_SOURCES Main.cpp + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + +if(VELOX_ENABLE_GEO) + add_executable(velox_exec_SpatialJoinTest SpatialJoinTest.cpp Main.cpp) + add_test( + NAME velox_exec_SpatialJoinTest + COMMAND velox_exec_SpatialJoinTest + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + ) + target_link_libraries(velox_exec_SpatialJoinTest ${VELOX_EXEC_TEST_DEPS}) +endif() + +add_executable( + velox_exec_infra_test + AssertQueryBuilderTest.cpp + DriverTest.cpp + FunctionSignatureBuilderTest.cpp + GroupedExecutionTest.cpp + Main.cpp + OperatorUtilsTest.cpp + PlanBuilderTest.cpp + PrestoQueryRunnerTest.cpp + QueryAssertionsTest.cpp + TaskTest.cpp + TreeOfLosersTest.cpp +) + +add_test( + NAME velox_exec_infra_test + COMMAND velox_exec_infra_test + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + target_link_libraries( velox_exec_infra_test velox_dwio_common @@ -208,7 +259,6 @@ target_link_libraries( Boost::date_time Boost::filesystem Boost::program_options - Boost::regex Boost::thread Boost::system GTest::gtest @@ -220,22 +270,23 @@ target_link_libraries( fmt::fmt ) -add_executable( - velox_exec_util_test - Main.cpp +# Split velox_exec_util_test into individual test binaries for parallel execution. +# Each test depends on the shared base class PrestoQueryRunnerIntermediateTypeTransformTestBase.cpp. +set( + VELOX_EXEC_UTIL_TEST_SOURCES PrestoQueryRunnerHyperLogLogTransformTest.cpp - PrestoQueryRunnerIntermediateTypeTransformTestBase.cpp PrestoQueryRunnerTDigestTransformTest.cpp PrestoQueryRunnerQDigestTransformTest.cpp + PrestoQueryRunnerKHyperLogLogTransformTest.cpp + PrestoQueryRunnerSetDigestTransformTest.cpp PrestoQueryRunnerJsonTransformTest.cpp PrestoQueryRunnerIntervalTransformTest.cpp + PrestoQueryRunnerTimeTransformTest.cpp PrestoQueryRunnerTimestampWithTimeZoneTransformTest.cpp ) -add_test(velox_exec_util_test velox_exec_util_test) - -target_link_libraries( - velox_exec_util_test +set( + VELOX_EXEC_UTIL_TEST_DEPS velox_fuzzer_util velox_exec_test_lib velox_functions_test_lib @@ -245,6 +296,13 @@ target_link_libraries( GTest::gtest_main ) +velox_add_grouped_tests( + PREFIX velox_exec_util_test + SOURCES ${VELOX_EXEC_UTIL_TEST_SOURCES} + DEPS ${VELOX_EXEC_UTIL_TEST_DEPS} + EXTRA_SOURCES PrestoQueryRunnerIntermediateTypeTransformTestBase.cpp Main.cpp +) + add_executable(velox_in_10_min_demo VeloxIn10MinDemo.cpp) target_link_libraries( @@ -264,6 +322,7 @@ add_executable( TableEvolutionFuzzerTest.cpp TableEvolutionFuzzer.cpp ) +velox_add_test_headers(velox_table_evolution_fuzzer_test TableEvolutionFuzzer.h) target_link_libraries( velox_table_evolution_fuzzer_test @@ -273,7 +332,7 @@ target_link_libraries( velox_fuzzer_util velox_functions_prestosql velox_exec_test_lib - velox_temp_path + velox_test_util velox_vector_fuzzer GTest::gtest GTest::gtest_main @@ -297,7 +356,14 @@ target_link_libraries( GTest::gtest_main ) -add_library(velox_simple_aggregate SimpleAverageAggregate.cpp SimpleArrayAggAggregate.cpp) +add_library( + velox_simple_aggregate + SimpleAverageAggregate.cpp + SimpleArrayAggAggregate.cpp + SimpleVariadicArrayAggAggregate.cpp + SimpleVariadicSumAggregate.cpp +) +velox_add_test_headers(velox_simple_aggregate SimpleAggregateFunctionsRegistration.h) target_link_libraries( velox_simple_aggregate @@ -325,6 +391,12 @@ if(VELOX_ENABLE_BENCHMARKS) JoinSpillInputBenchmarkBase.cpp SpillerBenchmarkBase.cpp ) + velox_add_test_headers( + velox_spiller_join_benchmark_base + AggregateSpillBenchmarkBase.h + JoinSpillInputBenchmarkBase.h + SpillerBenchmarkBase.h + ) target_link_libraries( velox_spiller_join_benchmark_base velox_exec @@ -371,14 +443,6 @@ if(VELOX_ENABLE_BENCHMARKS) ) endif() -add_executable(cpr_http_client_test CprHttpClientTest.cpp) -add_test( - NAME cpr_http_client_test - COMMAND cpr_http_client_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} -) -target_link_libraries(cpr_http_client_test cpr::cpr GTest::gtest GTest::gtest_main) - add_executable(velox_driver_test OperatorReplacementTest.cpp Main.cpp) add_test( NAME velox_driver_test @@ -387,3 +451,13 @@ add_test( ) target_link_libraries(velox_driver_test velox_exec velox_exec_test_lib GTest::gtest) + +add_executable(velox_adaptive_prefetch_test AdaptivePrefetchTest.cpp) +add_test( + NAME velox_adaptive_prefetch_test + COMMAND velox_adaptive_prefetch_test + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) +target_link_libraries(velox_adaptive_prefetch_test GTest::gtest GTest::gtest_main) + +velox_add_library(velox_aggregate_registry_util INTERFACE HEADERS AggregateRegistryTestUtil.h) diff --git a/velox/exec/tests/ColumnStatsCollectorTest.cpp b/velox/exec/tests/ColumnStatsCollectorTest.cpp index 250ae45e920..339fdf3df74 100644 --- a/velox/exec/tests/ColumnStatsCollectorTest.cpp +++ b/velox/exec/tests/ColumnStatsCollectorTest.cpp @@ -64,20 +64,20 @@ class ColumnStatsCollectorTest : public OperatorTestBase { core::AggregationNode::Aggregate agg; agg.call = std::dynamic_pointer_cast( - core::Expressions::inferTypes(untypedExpr.expr, type, pool())); + core::Expressions::inferTypes(untypedExpr, type, pool())); for (const auto& input : agg.call->inputs()) { agg.rawInputTypes.push_back(input->type()); } - VELOX_CHECK_NULL(untypedExpr.maskExpr); - VELOX_CHECK(!untypedExpr.distinct); - VELOX_CHECK(untypedExpr.orderBy.empty()); + VELOX_CHECK_NULL(untypedExpr->filter()); + VELOX_CHECK(!untypedExpr->isDistinct()); + VELOX_CHECK(untypedExpr->orderBy().empty()); aggs.emplace_back(agg); - if (untypedExpr.expr->alias().has_value()) { - names.push_back(untypedExpr.expr->alias().value()); + if (untypedExpr->alias().has_value()) { + names.push_back(untypedExpr->alias().value()); } else { names.push_back(fmt::format("a{}", i)); } diff --git a/velox/exec/tests/ConcatFilesSpillMergeStreamTest.cpp b/velox/exec/tests/ConcatFilesSpillMergeStreamTest.cpp index fc8e81fde7f..25b1371618c 100644 --- a/velox/exec/tests/ConcatFilesSpillMergeStreamTest.cpp +++ b/velox/exec/tests/ConcatFilesSpillMergeStreamTest.cpp @@ -14,14 +14,14 @@ * limitations under the License. */ +#include #include "velox/common/file/FileSystems.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/exec/SortBuffer.h" #include "velox/exec/Spill.h" #include "velox/exec/tests/utils/OperatorTestBase.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/type/Type.h" #include "velox/vector/fuzzer/VectorFuzzer.h" -#include "velox/vector/tests/utils/VectorTestBase.h" #include @@ -31,6 +31,7 @@ using namespace facebook::velox; using namespace facebook::velox::memory; namespace facebook::velox::exec::test { +using namespace facebook::velox::common::testutil; class ConcatFilesSpillMergeStreamTest : public OperatorTestBase { protected: @@ -99,8 +100,12 @@ class ConcatFilesSpillMergeStreamTest : public OperatorTestBase { std::vector> spillReadFiles; spillReadFiles.reserve(spillFiles.size()); for (const auto& spillFile : spillFiles) { - spillReadFiles.emplace_back(SpillReadFile::create( - spillFile, spillConfig_.readBufferSize, pool_.get(), &spillStats_)); + spillReadFiles.emplace_back( + SpillReadFile::create( + spillFile, + spillConfig_.readBufferSize, + pool_.get(), + &spillStats_)); } auto stream = ConcatFilesSpillMergeStream::create(i - 1, std::move(spillReadFiles)); @@ -196,7 +201,7 @@ class ConcatFilesSpillMergeStreamTest : public OperatorTestBase { {"c3", VARCHAR()}}); const std::shared_ptr executor_{ std::make_shared( - std::thread::hardware_concurrency())}; + folly::available_concurrency())}; const std::vector sortColumnIndices_{0, 2}; const std::vector sortCompareFlags_{ CompareFlags{}, @@ -204,7 +209,7 @@ class ConcatFilesSpillMergeStreamTest : public OperatorTestBase { const std::vector sortingKeys_ = SpillState::makeSortingKeys(sortColumnIndices_, sortCompareFlags_); const std::shared_ptr spillDirectory_ = - exec::test::TempDirectoryPath::create(); + TempDirectoryPath::create(); const common::SpillConfig spillConfig_{ [&]() -> const std::string& { return spillDirectory_->getPath(); }, [&](uint64_t) {}, @@ -221,9 +226,10 @@ class ConcatFilesSpillMergeStreamTest : public OperatorTestBase { 0, 0, "none", + 0, std::nullopt}; - folly::Synchronized spillStats_; + exec::SpillStats spillStats_; tsan_atomic nonReclaimableSection_{false}; }; } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/CountingJoinTest.cpp b/velox/exec/tests/CountingJoinTest.cpp new file mode 100644 index 00000000000..da9db1630d4 --- /dev/null +++ b/velox/exec/tests/CountingJoinTest.cpp @@ -0,0 +1,418 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "velox/exec/HashTable.h" +#include "velox/exec/PlanNodeStats.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" + +namespace facebook::velox::exec::test { +namespace { + +class CountingJoinTest : public OperatorTestBase { + protected: + void assertCountingJoin( + const std::vector& probe, + const std::vector& build, + core::JoinType joinType, + const std::vector& expected) { + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(makeRowVector({makeFlatVector(probe)})) + .hashJoin( + {"c0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values(makeRowVector( + {"u0"}, {makeFlatVector(build)})) + .planNode(), + "", + {"c0"}, + joinType, + /*nullAware=*/false, + /*nullAsValue=*/true) + .planNode(); + AssertQueryBuilder(plan).assertResults( + makeRowVector({makeFlatVector(expected)})); + } +}; + +// {A, A, A, B, B} EXCEPT ALL {A, A, C} = {A, B, B} +TEST_F(CountingJoinTest, countingAnti) { + assertCountingJoin( + {1, 1, 1, 2, 2}, {1, 1, 3}, core::JoinType::kCountingAnti, {1, 2, 2}); +} + +// {A, A, A, B, B} INTERSECT ALL {A, A, C} = {A, A} +TEST_F(CountingJoinTest, countingSemi) { + assertCountingJoin( + {1, 1, 1, 2, 2}, + {1, 1, 3}, + core::JoinType::kCountingLeftSemiFilter, + {1, 1}); +} + +// Empty build side: returns all probe rows. +TEST_F(CountingJoinTest, countingAntiEmptyBuild) { + assertCountingJoin({1, 2, 3}, {}, core::JoinType::kCountingAnti, {1, 2, 3}); +} + +// Empty build side: returns nothing. +TEST_F(CountingJoinTest, countingSemiEmptyBuild) { + assertCountingJoin( + {1, 2, 3}, {}, core::JoinType::kCountingLeftSemiFilter, {}); +} + +// Empty probe side: returns nothing. +TEST_F(CountingJoinTest, emptyProbe) { + assertCountingJoin({}, {1, 2, 3}, core::JoinType::kCountingAnti, {}); + assertCountingJoin( + {}, {1, 2, 3}, core::JoinType::kCountingLeftSemiFilter, {}); +} + +// No duplicates on build side: degenerates to regular anti/semi. +TEST_F(CountingJoinTest, noDuplicatesOnBuild) { + assertCountingJoin( + {1, 2, 3, 4, 5}, {2, 4}, core::JoinType::kCountingAnti, {1, 3, 5}); + assertCountingJoin( + {1, 2, 3, 4, 5}, {2, 4}, core::JoinType::kCountingLeftSemiFilter, {2, 4}); +} + +// All duplicates on build side. +TEST_F(CountingJoinTest, allDuplicatesOnBuild) { + // EXCEPT ALL: 5 - 3 = 2 + assertCountingJoin( + {1, 1, 1, 1, 1}, {1, 1, 1}, core::JoinType::kCountingAnti, {1, 1}); + // INTERSECT ALL: min(5, 3) = 3 + assertCountingJoin( + {1, 1, 1, 1, 1}, + {1, 1, 1}, + core::JoinType::kCountingLeftSemiFilter, + {1, 1, 1}); +} + +// Multiple batches on probe side. +TEST_F(CountingJoinTest, multipleProbeBatches) { + auto probeVectors = { + makeRowVector({makeFlatVector({1, 1, 2})}), + makeRowVector({makeFlatVector({1, 2, 3})}), + makeRowVector({makeFlatVector({2, 3, 3})}), + }; + auto buildVector = + makeRowVector({"u0"}, {makeFlatVector({1, 2, 2, 3})}); + + auto test = [&](core::JoinType joinType, + const std::vector& expected) { + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .values(probeVectors) + .hashJoin( + {"c0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values(buildVector).planNode(), + "", + {"c0"}, + joinType, + /*nullAware=*/false, + /*nullAsValue=*/true) + .planNode(); + AssertQueryBuilder(plan).assertResults( + makeRowVector({makeFlatVector(expected)})); + }; + + // Probe: {1:3, 2:3, 3:3}, Build: {1:1, 2:2, 3:1} + // EXCEPT ALL: {1:2, 2:1, 3:2} + test(core::JoinType::kCountingAnti, {1, 1, 2, 3, 3}); + // INTERSECT ALL: {1:1, 2:2, 3:1} + test(core::JoinType::kCountingLeftSemiFilter, {1, 2, 2, 3}); +} + +// Multiple columns as join keys. +TEST_F(CountingJoinTest, multipleKeys) { + auto probeVector = makeRowVector( + {makeFlatVector({1, 1, 1, 2, 2}), + makeFlatVector({10, 10, 20, 10, 10})}); + auto buildVector = makeRowVector( + {"u0", "u1"}, + {makeFlatVector({1, 1, 2}), + makeFlatVector({10, 10, 10})}); + + // Probe: {(1,10):2, (1,20):1, (2,10):2}, Build: {(1,10):2, (2,10):1} + // EXCEPT ALL: {(1,20):1, (2,10):1} + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .values(probeVector) + .hashJoin( + {"c0", "c1"}, + {"u0", "u1"}, + PlanBuilder(planNodeIdGenerator).values(buildVector).planNode(), + "", + {"c0", "c1"}, + core::JoinType::kCountingAnti, + /*nullAware=*/false, + /*nullAsValue=*/true) + .planNode(); + + auto expected = makeRowVector({ + makeFlatVector({1, 2}), + makeFlatVector({20, 10}), + }); + AssertQueryBuilder(plan).assertResults(expected); +} + +// Null join keys match each other with nullAsValue (IS NOT DISTINCT FROM +// semantics), as required by SQL set operations (EXCEPT ALL, INTERSECT ALL). +TEST_F(CountingJoinTest, nullKeys) { + // Probe: {1, null, 1, null, 2}, Build: {1, null, 3} + auto probeVector = makeRowVector( + {makeNullableFlatVector({1, std::nullopt, 1, std::nullopt, 2})}); + auto buildVector = makeRowVector( + {"u0"}, {makeNullableFlatVector({1, std::nullopt, 3})}); + + auto test = [&](core::JoinType joinType, + const std::vector>& expected) { + SCOPED_TRACE(core::JoinTypeName::toName(joinType)); + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values({probeVector}) + .hashJoin( + {"c0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values({buildVector}) + .planNode(), + "", + {"c0"}, + joinType, + /*nullAware=*/false, + /*nullAsValue=*/true) + .planNode(); + AssertQueryBuilder(plan).assertResults( + makeRowVector({makeNullableFlatVector(expected)})); + }; + + // EXCEPT ALL: null keys match. Probe has 2 nulls, build has 1 null, so 1 + // null passes through. Key 1 appears twice on probe, once on build, so 1 + // copy passes through. Key 2 has no build match. + test(core::JoinType::kCountingAnti, {1, std::nullopt, 2}); + + // INTERSECT ALL: null keys match. min(2, 1) = 1 null kept. + // Key 1: min(2, 1) = 1. Key 2: no match. + test(core::JoinType::kCountingLeftSemiFilter, {1, std::nullopt}); +} + +// Verifies that the build side deduplicates rows and stores only distinct keys. +TEST_F(CountingJoinTest, buildSideDedup) { + // 10 distinct keys repeated 1'000 times each = 10'000 build rows. + auto buildVector = makeRowVector( + {"u0"}, {makeFlatVector(10, [](auto row) { return row; })}); + + // Probe has values 0..14. Build has values 0..9, each repeated 1'000 times. + // EXCEPT ALL: values 0..9 each appear once on probe but 1'000 times on build, + // so all are consumed. Values 10..14 have no match and are emitted. + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(makeRowVector({makeFlatVector( + 15, [](auto row) { return row; })})) + .hashJoin( + {"c0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values({buildVector}, false, 1'000) + .planNode(), + "", + {"c0"}, + core::JoinType::kCountingAnti, + /*nullAware=*/false, + /*nullAsValue=*/true) + .planNode(); + + auto expected = + makeRowVector({makeFlatVector({10, 11, 12, 13, 14})}); + auto task = AssertQueryBuilder(plan).assertResults(expected); + + // Verify the build received 10'000 rows but the hash table has only 10 + // distinct keys. + auto planStats = toPlanStats(task->taskStats()); + const auto& joinStats = planStats.at(plan->id()); + EXPECT_EQ(10'000, joinStats.operatorStats.at("HashBuild")->inputRows); + EXPECT_EQ( + 10, + joinStats.customStats.at(std::string(BaseHashTable::kNumDistinct)).sum); +} + +// Verifies that counts are correctly merged when multiple build drivers process +// the same keys. Each driver builds its own hash table with per-key counts; +// during the merge in prepareJoinTable, duplicate keys must have their counts +// summed rather than silently dropped. +// +// Tests three key configurations to exercise all hash table modes: +// - Small INT32 keys {1, 2}: array mode (arrayPushRow path). +// - Two INT32 key columns with ~1500 distinct values each: normalized-key +// mode (buildFullProbe with normalizedKey comparison). +// - Complex-type (ARRAY(INT32)) keys: hash mode (buildFullProbe with +// compareKeys). +TEST_F(CountingJoinTest, multipleBuildDrivers) { + constexpr int32_t kNumDrivers = 4; + + auto test = [&](const RowVectorPtr& probeVector, + const RowVectorPtr& buildVector, + core::JoinType joinType, + BaseHashTable::HashMode expectedHashMode, + const std::vector& expectedChildren) { + auto expected = + makeRowVector(probeVector->rowType()->names(), expectedChildren); + auto probeKeys = probeVector->rowType()->names(); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values({probeVector}) + .hashJoin( + probeKeys, + buildVector->rowType()->names(), + PlanBuilder(planNodeIdGenerator) + .values({buildVector}, /*parallelizable=*/true) + .planNode(), + "", + probeKeys, + joinType, + /*nullAware=*/false, + /*nullAsValue=*/true) + .planNode(); + auto task = AssertQueryBuilder(plan) + .maxDrivers(kNumDrivers) + .assertResults(expected); + + // Verify that all build drivers ran. Each driver processes 1 batch. + auto planStats = toPlanStats(task->taskStats()); + const auto& joinStats = planStats.at(plan->id()); + + const auto& buildStats = *joinStats.operatorStats.at("HashBuild"); + EXPECT_EQ(kNumDrivers, buildStats.numDrivers); + EXPECT_EQ(kNumDrivers, buildStats.inputVectors); + EXPECT_EQ(buildVector->size() * kNumDrivers, buildStats.inputRows); + + // Probe side runs single-threaded (Values node is not parallelizable by + // default). It processes 1 batch. + const auto& probeStats = *joinStats.operatorStats.at("HashProbe"); + EXPECT_EQ(1, probeStats.numDrivers); + EXPECT_EQ(1, probeStats.inputVectors); + EXPECT_EQ(probeVector->size(), probeStats.inputRows); + + EXPECT_EQ( + static_cast(expectedHashMode), + buildStats.customStats.at(std::string(BaseHashTable::kHashMode)).sum); + }; + + // Small keys {1, 2}: triggers array-based hash mode. + // Build: {1, 1, 2}, each of 4 drivers => merged {1:8, 2:4}. + // Probe: {1:3, 2:5}. + { + SCOPED_TRACE("small keys (array mode)"); + auto buildVector = + makeRowVector({"u0"}, {makeFlatVector({1, 1, 2})}); + auto probeVector = + makeRowVector({makeFlatVector({1, 1, 1, 2, 2, 2, 2, 2})}); + + // EXCEPT ALL: {1: max(3-8,0)=0, 2: max(5-4,0)=1} => {2}. + test( + probeVector, + buildVector, + core::JoinType::kCountingAnti, + BaseHashTable::HashMode::kArray, + {makeFlatVector({2})}); + // INTERSECT ALL: {1: min(3,8)=3, 2: min(5,4)=4}. + test( + probeVector, + buildVector, + core::JoinType::kCountingLeftSemiFilter, + BaseHashTable::HashMode::kArray, + {makeFlatVector({1, 1, 1, 2, 2, 2, 2})}); + } + + // Two INT32 key columns with ~1500 distinct values each. The combined + // cardinality (1500 * 1500 = 2.25M) exceeds kArrayHashMaxSize (2M), + // forcing normalized-key mode (buildFullProbe with normalizedKey comparison). + { + SCOPED_TRACE("multi-column keys (normalized-key mode)"); + constexpr int32_t kNumKeys = 1'500; + + // Build: 2-column keys (i, i) for i in 0..kNumKeys-1, plus duplicates of + // (0, 0) and (1, 1). Both columns are identical — we need 2 columns so + // the combined cardinality (1500 * 1500 = 2.25M) exceeds kArrayHashMaxSize + // (2M). + auto buildKeys = makeFlatVector(kNumKeys + 2, [](auto row) { + return row < kNumKeys ? row : row - kNumKeys; + }); + + // Probe: key (0,0) appears 10 times, key (1,1) appears 5 times. + auto probeKeys = + makeFlatVector({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); + + // Each of 4 drivers: (0,0) count=2, (1,1) count=2, others count=1. + // After merge: (0,0) count=8, (1,1) count=8. + // EXCEPT ALL: (0,0): max(10-8,0)=2, (1,1): max(5-8,0)=0. + auto expectedAnti = makeFlatVector({0, 0}); + test( + makeRowVector({probeKeys, probeKeys}), + makeRowVector({"u0", "u1"}, {buildKeys, buildKeys}), + core::JoinType::kCountingAnti, + BaseHashTable::HashMode::kNormalizedKey, + {expectedAnti, expectedAnti}); + // INTERSECT ALL: (0,0): min(10,8)=8, (1,1): min(5,8)=5. + auto expectedSemi = + makeFlatVector({0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); + test( + makeRowVector({probeKeys, probeKeys}), + makeRowVector({"u0", "u1"}, {buildKeys, buildKeys}), + core::JoinType::kCountingLeftSemiFilter, + BaseHashTable::HashMode::kNormalizedKey, + {expectedSemi, expectedSemi}); + } + + // ARRAY(INT32) type keys do not support value IDs in VectorHasher, forcing + // kHash mode (buildFullProbe with compareKeys). + { + SCOPED_TRACE("complex-type keys (hash mode)"); + auto buildVector = + makeRowVector({"u0"}, {makeArrayVector({{1}, {1}, {2}})}); + auto probeVector = makeRowVector( + {makeArrayVector({{1}, {1}, {1}, {2}, {2}, {2}, {2}, {2}})}); + + // EXCEPT ALL: {[1]: max(3-8,0)=0, [2]: max(5-4,0)=1} => {[2]}. + test( + probeVector, + buildVector, + core::JoinType::kCountingAnti, + BaseHashTable::HashMode::kHash, + {makeArrayVector({{2}})}); + // INTERSECT ALL: {[1]: min(3,8)=3, [2]: min(5,4)=4}. + test( + probeVector, + buildVector, + core::JoinType::kCountingLeftSemiFilter, + BaseHashTable::HashMode::kHash, + {makeArrayVector({{1}, {1}, {1}, {2}, {2}, {2}, {2}})}); + } +} + +} // namespace +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/CprHttpClientTest.cpp b/velox/exec/tests/CprHttpClientTest.cpp deleted file mode 100644 index 9a19ccaaed2..00000000000 --- a/velox/exec/tests/CprHttpClientTest.cpp +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "velox/common/base/tests/GTestUtils.h" - -#include -#include - -class CprHttpClientTest : public testing::Test {}; - -// This test requires open access to internet and most places test runners might -// be closed off from the general internet. And this test case is just an -// illustration of how to use cpr, so disable it by default. -TEST_F(CprHttpClientTest, DISABLED_basic) { - auto response = cpr::Get( - cpr::Url{"https://facebookincubator.github.io/velox/"}, - cpr::Timeout{std::chrono::seconds{3}}); - ASSERT_EQ(response.status_code, 200); - ASSERT_FALSE(response.text.empty()); - - response = cpr::Get(cpr::Url{"null"}); - ASSERT_NE(response.status_code, 200); - ASSERT_TRUE(response.text.empty()); - - response = cpr::Post( - cpr::Url{"https://facebookincubator.github.io/velox/"}, - cpr::Body{"select * from nation limit 1"}, - cpr::Header({{"Content-Type", "text/plain"}})); - ASSERT_EQ(response.status_code, 405); - ASSERT_FALSE(response.text.empty()); - - response = cpr::Post( - cpr::Url{"null"}, - cpr::Body{"select * from nation limit 1"}, - cpr::Header({{"Content-Type", "text/plain"}})); - ASSERT_NE(response.status_code, 200); - ASSERT_TRUE(response.text.empty()); -} diff --git a/velox/exec/tests/CustomTraceTest.cpp b/velox/exec/tests/CustomTraceTest.cpp new file mode 100644 index 00000000000..267217b7372 --- /dev/null +++ b/velox/exec/tests/CustomTraceTest.cpp @@ -0,0 +1,1180 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/trace/TraceCtx.h" +#include "velox/vector/BaseVector.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +namespace facebook::velox::exec::trace::test { +namespace { + +using exec::test::AssertQueryBuilder; +using exec::test::PlanBuilder; +using velox::test::assertEqualVectors; + +using TCapturedVectors = std::unordered_map; + +// A custom test trace implementation that only captures pointers to internal +// vectors. It only traces operators from `traceIds`, and assumes +// single-threaded execution. Vectors are captured in TCapturedVectors. +class TestTraceCtx : public TraceCtx { + public: + TestTraceCtx( + const std::vector& tracedIds, + TCapturedVectors& tracedVectors) + : TraceCtx(false), + tracedIds_(tracedIds.begin(), tracedIds.end()), + tracedVectors_(tracedVectors) {} + + bool shouldTrace(const Operator& op) const override { + return tracedIds_.contains(op.planNodeId()); + } + + class TestTraceInputWriter : public TraceInputWriter { + public: + TestTraceInputWriter( + const core::PlanNodeId& planId, + TCapturedVectors& tracedVectors) + : planId_(planId), tracedVectors_(tracedVectors) {} + + bool write(const RowVectorPtr& vector, ContinueFuture*) override { + tracedVectors_[planId_] = vector; + return false; + } + + void finish() override {} + + private: + const core::PlanNodeId planId_; + TCapturedVectors& tracedVectors_; + }; + + std::unique_ptr createInputTracer( + Operator& op) const override { + return std::make_unique( + op.planNodeId(), tracedVectors_); + } + + private: + std::unordered_set tracedIds_; + + TCapturedVectors& tracedVectors_; +}; + +// Captures expression output vectors keyed by function name. +using TExprCapturedVectors = std::unordered_map; + +// A test trace implementation for expression-level tracing. Traces output +// batches from named expression functions listed in tracedExprs. +class ExprTestTraceCtx : public TraceCtx { + public: + ExprTestTraceCtx( + const std::vector& tracedExprs, + TExprCapturedVectors& capturedVectors, + bool shouldTraceOp = true) + : TraceCtx(false), + shouldTraceOp_(shouldTraceOp), + tracedExprs_(tracedExprs.begin(), tracedExprs.end()), + capturedVectors_(capturedVectors) {} + + bool shouldTrace(const Operator& /*op*/) const override { + return shouldTraceOp_; + } + + bool shouldTraceExpr(std::string_view functionName) const override { + return tracedExprs_.contains(std::string(functionName)); + } + + std::unique_ptr createExprOutputTracer( + const Operator& /*op*/, + std::string_view functionName, + int instanceIndex) const override { + auto key = fmt::format("{}_{}", functionName, instanceIndex); + return std::make_unique(std::move(key), capturedVectors_); + } + + private: + class TestExprWriter : public TraceExprWriter { + public: + TestExprWriter( + std::string functionName, + TExprCapturedVectors& capturedVectors) + : functionName_(std::move(functionName)), + capturedVectors_(capturedVectors) {} + + void write(const VectorPtr& vector) override { + capturedVectors_[functionName_] = vector; + } + + void finish() override {} + + private: + const std::string functionName_; + TExprCapturedVectors& capturedVectors_; + }; + + const bool shouldTraceOp_; + std::unordered_set tracedExprs_; + TExprCapturedVectors& capturedVectors_; +}; + +class CustomTraceTest : public exec::test::HiveConnectorTestBase {}; + +TEST_F(CustomTraceTest, customTrace) { + auto vector = makeRowVector( + {"a"}, {makeFlatVector(10, [](auto row) { return row; })}); + + core::PlanNodeId traceNodeId1; + core::PlanNodeId traceNodeId2; + + // Trace the inputs from two operators. + auto plan = PlanBuilder() + .values({vector}) + .project({"a * 10 as a"}) + .capturePlanNodeId(traceNodeId1) + .project({"a * 10 as a"}) + .project({"a * 10 as a"}) + .capturePlanNodeId(traceNodeId2) + .planNode(); + + TCapturedVectors tracedVectors; + auto queryCtx = + core::QueryCtx::Builder() + .executor(executor_.get()) + .traceCtxProvider([&](core::QueryCtx&, const core::PlanFragment&) { + return std::make_unique( + std::vector{traceNodeId1, traceNodeId2}, + tracedVectors); + }) + .build(); + + std::shared_ptr task; + AssertQueryBuilder(plan) + .config(core::QueryConfig::kQueryTraceEnabled, true) + .queryCtx(queryCtx) + .countResults(task); + + auto it1 = tracedVectors.find(traceNodeId1); + auto it2 = tracedVectors.find(traceNodeId2); + + ASSERT_TRUE(it1 != tracedVectors.end()); + ASSERT_TRUE(it2 != tracedVectors.end()); + + auto expected1 = makeRowVector( + {"a"}, {makeFlatVector(10, [](auto row) { return row; })}); + auto expected2 = makeRowVector( + {"a"}, + {makeFlatVector(10, [](auto row) { return row * 10 * 10; })}); + + assertEqualVectors(it1->second, expected1); + assertEqualVectors(it2->second, expected2); + + // Vectors need to be destructed before the pool in the task dies. + tracedVectors.clear(); +} + +TEST_F(CustomTraceTest, exprOutputTrace) { + auto vector = makeRowVector( + {"a"}, {makeFlatVector(10, [](auto row) { return row; })}); + + auto plan = + PlanBuilder().values({vector}).project({"a * 10 as a"}).planNode(); + + TExprCapturedVectors capturedVectors; + auto queryCtx = + core::QueryCtx::Builder() + .executor(executor_.get()) + .traceCtxProvider([&](core::QueryCtx&, const core::PlanFragment&) { + return std::make_unique( + std::vector{"multiply"}, capturedVectors); + }) + .build(); + + std::shared_ptr task; + AssertQueryBuilder(plan) + .config(core::QueryConfig::kQueryTraceEnabled, true) + .queryCtx(queryCtx) + .countResults(task); + + auto iterator = capturedVectors.find("multiply_0"); + ASSERT_TRUE(iterator != capturedVectors.end()); + + auto expected = + makeFlatVector(10, [](auto row) { return row * 10; }); + assertEqualVectors(iterator->second, expected); + + capturedVectors.clear(); +} + +TEST_F(CustomTraceTest, exprTraceOnlyMatchingFunctions) { + auto vector = makeRowVector( + {"a"}, {makeFlatVector(10, [](auto row) { return row; })}); + + // Two projections: multiply then plus. Only trace "plus". + auto plan = PlanBuilder() + .values({vector}) + .project({"a * 10 as b", "a + 5 as c"}) + .planNode(); + + TExprCapturedVectors capturedVectors; + auto queryCtx = + core::QueryCtx::Builder() + .executor(executor_.get()) + .traceCtxProvider([&](core::QueryCtx&, const core::PlanFragment&) { + return std::make_unique( + std::vector{"plus"}, capturedVectors); + }) + .build(); + + std::shared_ptr task; + AssertQueryBuilder(plan) + .config(core::QueryConfig::kQueryTraceEnabled, true) + .queryCtx(queryCtx) + .countResults(task); + + // "plus" should be captured. + EXPECT_TRUE(capturedVectors.contains("plus_0")); + + // "multiply" should not be captured. + EXPECT_FALSE(capturedVectors.contains("multiply_0")); + + capturedVectors.clear(); +} + +TEST_F(CustomTraceTest, exprTraceMultipleInstances) { + auto vector = makeRowVector( + {"a", "b"}, + {makeFlatVector(10, [](auto row) { return row; }), + makeFlatVector(10, [](auto row) { return row + 100; })}); + + auto plan = PlanBuilder() + .values({vector}) + .project({"a * 10 as c", "b * 20 as d"}) + .planNode(); + + TExprCapturedVectors capturedVectors; + auto queryCtx = + core::QueryCtx::Builder() + .executor(executor_.get()) + .traceCtxProvider([&](core::QueryCtx&, const core::PlanFragment&) { + return std::make_unique( + std::vector{"multiply"}, capturedVectors); + }) + .build(); + + std::shared_ptr task; + AssertQueryBuilder(plan) + .config(core::QueryConfig::kQueryTraceEnabled, true) + .queryCtx(queryCtx) + .countResults(task); + + ASSERT_TRUE(capturedVectors.contains("multiply_0")); + ASSERT_TRUE(capturedVectors.contains("multiply_1")); + + auto expected0 = + makeFlatVector(10, [](auto row) { return row * 10; }); + auto expected1 = + makeFlatVector(10, [](auto row) { return (row + 100) * 20; }); + + assertEqualVectors(capturedVectors.at("multiply_0"), expected0); + assertEqualVectors(capturedVectors.at("multiply_1"), expected1); + + capturedVectors.clear(); +} + +TEST_F(CustomTraceTest, exprTraceCseSharedNode) { + auto vector = makeRowVector( + {"a"}, {makeFlatVector(10, [](auto row) { return row; })}); + + // Both projections share the "a * 10" subexpression (CSE). The visited set + // should ensure the shared multiply node is only traced once. + auto plan = PlanBuilder() + .values({vector}) + .project({"a * 10 + 1 as b", "a * 10 + 2 as c"}) + .planNode(); + + TExprCapturedVectors capturedVectors; + auto queryCtx = + core::QueryCtx::Builder() + .executor(executor_.get()) + .traceCtxProvider([&](core::QueryCtx&, const core::PlanFragment&) { + return std::make_unique( + std::vector{"multiply"}, capturedVectors); + }) + .build(); + + std::shared_ptr task; + AssertQueryBuilder(plan) + .config(core::QueryConfig::kQueryTraceEnabled, true) + .queryCtx(queryCtx) + .countResults(task); + + // The shared CSE node should produce exactly one multiply instance. + EXPECT_TRUE(capturedVectors.contains("multiply_0")); + EXPECT_FALSE(capturedVectors.contains("multiply_1")); + + capturedVectors.clear(); +} + +TEST_F(CustomTraceTest, exprTraceNullResult) { + auto vector = makeRowVector( + {"a"}, + {makeNullableFlatVector({1, std::nullopt, 3, std::nullopt, 5})}); + + // try(a / 0) produces all nulls without throwing. + auto plan = + PlanBuilder().values({vector}).project({"try(a / 0) as b"}).planNode(); + + TExprCapturedVectors capturedVectors; + auto queryCtx = + core::QueryCtx::Builder() + .executor(executor_.get()) + .traceCtxProvider([&](core::QueryCtx&, const core::PlanFragment&) { + return std::make_unique( + std::vector{"divide"}, capturedVectors); + }) + .build(); + + std::shared_ptr task; + AssertQueryBuilder(plan) + .config(core::QueryConfig::kQueryTraceEnabled, true) + .queryCtx(queryCtx) + .countResults(task); + + // divide is inside try() which catches the error and produces nulls. + // The tracer should still capture output (the null vector). + EXPECT_TRUE(capturedVectors.contains("divide_0")); + + capturedVectors.clear(); +} + +TEST_F(CustomTraceTest, exprTraceOperatorScoping) { + auto vector = makeRowVector( + {"a"}, {makeFlatVector(10, [](auto row) { return row; })}); + + auto plan = + PlanBuilder().values({vector}).project({"a * 10 as b"}).planNode(); + + TExprCapturedVectors capturedVectors; + auto queryCtx = + core::QueryCtx::Builder() + .executor(executor_.get()) + .traceCtxProvider([&](core::QueryCtx&, const core::PlanFragment&) { + return std::make_unique( + std::vector{"multiply"}, capturedVectors, false); + }) + .build(); + + std::shared_ptr task; + AssertQueryBuilder(plan) + .config(core::QueryConfig::kQueryTraceEnabled, true) + .queryCtx(queryCtx) + .countResults(task); + + EXPECT_TRUE(capturedVectors.empty()); +} + +// A trace context that creates writers which throw from write() and finish(). +class ThrowingExprTraceCtx : public TraceCtx { + public: + explicit ThrowingExprTraceCtx(const std::vector& tracedExprs) + : TraceCtx(false), tracedExprs_(tracedExprs.begin(), tracedExprs.end()) {} + + bool shouldTrace(const Operator& /*op*/) const override { + return true; + } + + bool shouldTraceExpr(std::string_view functionName) const override { + return tracedExprs_.contains(std::string(functionName)); + } + + std::unique_ptr createExprOutputTracer( + const Operator& /*op*/, + std::string_view /*functionName*/, + int /*instanceIndex*/) const override { + return std::make_unique(); + } + + private: + class ThrowingExprWriter : public TraceExprWriter { + public: + void write(const VectorPtr& /*result*/) override { + throw std::runtime_error("write failed"); + } + + void finish() override { + throw std::runtime_error("finish failed"); + } + }; + + std::unordered_set tracedExprs_; +}; + +TEST_F(CustomTraceTest, exprTraceThrowingWriter) { + auto vector = makeRowVector( + {"a"}, {makeFlatVector(10, [](auto row) { return row; })}); + + auto plan = + PlanBuilder().values({vector}).project({"a * 10 as a"}).planNode(); + + auto queryCtx = + core::QueryCtx::Builder() + .executor(executor_.get()) + .traceCtxProvider([&](core::QueryCtx&, const core::PlanFragment&) { + return std::make_unique( + std::vector{"multiply"}); + }) + .build(); + + std::shared_ptr task; + ASSERT_NO_THROW(AssertQueryBuilder(plan) + .config(core::QueryConfig::kQueryTraceEnabled, true) + .queryCtx(queryCtx) + .countResults(task)); +} + +// Helper to convert a vector of plan node IDs to a breakpoints map with null +// callbacks. +CursorParameters::TBreakpointMap toBreakpointsMap( + const std::vector& ids) { + CursorParameters::TBreakpointMap result; + for (const auto& id : ids) { + result[id] = nullptr; + } + return result; +} + +void assertCursorOutput( + const core::PlanNodePtr& plan, + const std::vector& breakpoints, + const std::vector& expectation) { + auto cursor = TaskCursor::create({ + .planNode = plan, + .serialExecution = true, + .breakpoints = toBreakpointsMap(breakpoints), + }); + size_t i = 0; + + while (cursor->moveStep()) { + if (i < expectation.size()) { + assertEqualVectors(cursor->current(), expectation[i++]); + } else { + ADD_FAILURE() << "Cursor output is longer than expectation: " << i; + } + } + EXPECT_EQ(i, expectation.size()); +} + +TEST_F(CustomTraceTest, taskDebuggerCursor) { + const size_t size = 10; + auto makeData = [&](std::function values) { + return makeRowVector( + {"a"}, {makeFlatVector(size, std::move(values))}); + }; + + // Two input vectors. + auto input1 = makeData([](auto row) { return row; }); + auto input2 = makeData([](auto row) { return row + 10; }); + + // Now, the expected input for a series of operators. + auto input1Project1 = makeData([](auto row) { return row; }); + auto input1Project2 = makeData([](auto row) { return row * 10; }); + auto input1Project3 = makeData([](auto row) { return row * 100; }); + auto input1Project4 = makeData([](auto row) { return row * 1'000; }); + auto output1 = makeData([](auto row) { return row * 10'000; }); + + auto input2Project1 = makeData([](auto row) { return (row + 10); }); + auto input2Project2 = makeData([](auto row) { return (row + 10) * 10; }); + auto input2Project3 = makeData([](auto row) { return (row + 10) * 100; }); + auto input2Project4 = makeData([](auto row) { return (row + 10) * 1'000; }); + auto output2 = makeData([](auto row) { return (row + 10) * 10'000; }); + + core::PlanNodeId project1, project2, project3, project4; + auto plan = PlanBuilder() + .values({input1, input2}) + .project({"a * 10 as a"}) + .capturePlanNodeId(project1) + .project({"a * 10 as a"}) + .capturePlanNodeId(project2) + .project({"a * 10 as a"}) + .capturePlanNodeId(project3) + .project({"a * 10 as a"}) + .capturePlanNodeId(project4) + .planNode(); + + // Test a series of combinations. + assertCursorOutput(plan, {}, {output1, output2}); + assertCursorOutput( + plan, {project1}, {input1Project1, output1, input2Project1, output2}); + assertCursorOutput( + plan, + {project1, project2}, + { + input1Project1, + input1Project2, + output1, + input2Project1, + input2Project2, + output2, + }); + assertCursorOutput( + plan, + {project2, project4}, + { + input1Project2, + input1Project4, + output1, + input2Project2, + input2Project4, + output2, + }); + assertCursorOutput( + plan, + {project1, project2, project3, project4}, + { + input1Project1, + input1Project2, + input1Project3, + input1Project4, + output1, + input2Project1, + input2Project2, + input2Project3, + input2Project4, + output2, + }); +} + +TEST_F(CustomTraceTest, cursorAt) { + const size_t size = 10; + auto input1 = makeRowVector( + {"a"}, {makeFlatVector(size, [](auto row) { return row; })}); + auto input2 = makeRowVector( + {"a"}, + {makeFlatVector(size, [](auto row) { return row + 10; })}); + + core::PlanNodeId project1, project2, project3, project4; + auto plan = PlanBuilder() + .values({input1, input2}) + .project({"a * 10 as a"}) + .capturePlanNodeId(project1) + .project({"a * 10 as a"}) + .capturePlanNodeId(project2) + .project({"a * 10 as a"}) + .capturePlanNodeId(project3) + .project({"a * 10 as a"}) + .capturePlanNodeId(project4) + .planNode(); + + auto cursor = TaskCursor::create({ + .planNode = plan, + .serialExecution = true, + .breakpoints = toBreakpointsMap({project1, project3}), + }); + + // Before any step, at() should return empty string. + EXPECT_EQ(cursor->at(), ""); + + // First step stops at project1 (first breakpoint). + EXPECT_TRUE(cursor->moveStep()); + EXPECT_EQ(cursor->at(), project1); + + // Second step stops at project3 (second breakpoint). + EXPECT_TRUE(cursor->moveStep()); + EXPECT_EQ(cursor->at(), project3); + + // Third step produces final output (no breakpoint, empty at()). + EXPECT_TRUE(cursor->moveStep()); + EXPECT_EQ(cursor->at(), ""); + + // Fourth step stops at project1 for second input batch. + EXPECT_TRUE(cursor->moveStep()); + EXPECT_EQ(cursor->at(), project1); + + // Fifth step stops at project3 for second input batch. + EXPECT_TRUE(cursor->moveStep()); + EXPECT_EQ(cursor->at(), project3); + + // Sixth step produces final output for second batch. + EXPECT_TRUE(cursor->moveStep()); + EXPECT_EQ(cursor->at(), ""); + + // No more data. + EXPECT_FALSE(cursor->moveStep()); + + // Test that moveNext() skips breakpoints and at() returns empty. + auto cursor2 = TaskCursor::create({ + .planNode = plan, + .serialExecution = true, + .breakpoints = toBreakpointsMap({project1, project3}), + }); + + EXPECT_EQ(cursor2->at(), ""); + + // moveNext() should skip to final output. + EXPECT_TRUE(cursor2->moveNext()); + EXPECT_EQ(cursor2->at(), ""); + + EXPECT_TRUE(cursor2->moveNext()); + EXPECT_EQ(cursor2->at(), ""); + + EXPECT_FALSE(cursor2->moveNext()); +} + +TEST_F(CustomTraceTest, breakpointCallbackAlwaysStop) { + const size_t size = 10; + auto input = makeRowVector( + {"a"}, {makeFlatVector(size, [](auto row) { return row; })}); + + core::PlanNodeId project1; + auto plan = PlanBuilder() + .values({input}) + .project({"a * 10 as a"}) + .capturePlanNodeId(project1) + .planNode(); + + // Callback that always returns true (always stop). + int callbackCount = 0; + CursorParameters::TBreakpointMap breakpoints; + breakpoints[project1] = [&](const RowVectorPtr& vector) { + ++callbackCount; + EXPECT_EQ(vector->size(), size); + return true; + }; + + auto cursor = TaskCursor::create({ + .planNode = plan, + .serialExecution = true, + .breakpoints = std::move(breakpoints), + }); + + // First step should stop at breakpoint. + EXPECT_TRUE(cursor->moveStep()); + EXPECT_EQ(cursor->at(), project1); + EXPECT_EQ(callbackCount, 1); + + // Second step should produce final output. + EXPECT_TRUE(cursor->moveStep()); + EXPECT_EQ(cursor->at(), ""); + + EXPECT_FALSE(cursor->moveStep()); +} + +TEST_F(CustomTraceTest, breakpointCallbackNeverStop) { + const size_t size = 10; + auto input = makeRowVector( + {"a"}, {makeFlatVector(size, [](auto row) { return row; })}); + + core::PlanNodeId project1; + auto plan = PlanBuilder() + .values({input}) + .project({"a * 10 as a"}) + .capturePlanNodeId(project1) + .planNode(); + + // Callback that always returns false (never stop). + int callbackCount = 0; + CursorParameters::TBreakpointMap breakpoints; + breakpoints[project1] = [&](const RowVectorPtr& vector) { + ++callbackCount; + EXPECT_EQ(vector->size(), size); + return false; + }; + + auto cursor = TaskCursor::create({ + .planNode = plan, + .serialExecution = true, + .breakpoints = std::move(breakpoints), + }); + + // Step should skip the breakpoint and go directly to final output. + EXPECT_TRUE(cursor->moveStep()); + EXPECT_EQ(cursor->at(), ""); + EXPECT_EQ(callbackCount, 1); // Callback was still invoked. + + EXPECT_FALSE(cursor->moveStep()); +} + +TEST_F(CustomTraceTest, breakpointCallbackConditional) { + const size_t size = 10; + auto input1 = makeRowVector( + {"a"}, {makeFlatVector(size, [](auto row) { return row; })}); + auto input2 = makeRowVector( + {"a"}, + {makeFlatVector(size, [](auto row) { return row + 100; })}); + + core::PlanNodeId project1; + auto plan = PlanBuilder() + .values({input1, input2}) + .project({"a * 10 as a"}) + .capturePlanNodeId(project1) + .planNode(); + + // Callback that stops only when first element is >= 100. + int callbackCount = 0; + CursorParameters::TBreakpointMap breakpoints; + breakpoints[project1] = [&](const RowVectorPtr& vector) { + ++callbackCount; + auto values = vector->childAt(0)->asFlatVector(); + return values->valueAt(0) >= 100; + }; + + auto cursor = TaskCursor::create({ + .planNode = plan, + .serialExecution = true, + .breakpoints = std::move(breakpoints), + }); + + // First batch: callback returns false (first element is 0), skips breakpoint. + // Goes to final output for first batch. + EXPECT_TRUE(cursor->moveStep()); + EXPECT_EQ(cursor->at(), ""); + EXPECT_EQ(callbackCount, 1); + + // Second batch: callback returns true (first element is 100), stops at + // breakpoint. + EXPECT_TRUE(cursor->moveStep()); + EXPECT_EQ(cursor->at(), project1); + EXPECT_EQ(callbackCount, 2); + + // Final output for second batch. + EXPECT_TRUE(cursor->moveStep()); + EXPECT_EQ(cursor->at(), ""); + + EXPECT_FALSE(cursor->moveStep()); +} + +TEST_F(CustomTraceTest, breakpointMixedCallbacks) { + const size_t size = 10; + auto input = makeRowVector( + {"a"}, {makeFlatVector(size, [](auto row) { return row; })}); + + core::PlanNodeId project1, project2; + auto plan = PlanBuilder() + .values({input}) + .project({"a * 10 as a"}) + .capturePlanNodeId(project1) + .project({"a * 10 as a"}) + .capturePlanNodeId(project2) + .planNode(); + + // project1 has callback returning false (don't stop). + // project2 has null callback (always stop). + int callbackCount = 0; + CursorParameters::TBreakpointMap breakpoints; + breakpoints[project1] = [&](const RowVectorPtr&) { + ++callbackCount; + return false; + }; + breakpoints[project2] = nullptr; // null callback = always stop. + + auto cursor = TaskCursor::create({ + .planNode = plan, + .serialExecution = true, + .breakpoints = std::move(breakpoints), + }); + + // project1 callback returns false, so it's skipped. + // project2 has null callback, so it stops. + EXPECT_TRUE(cursor->moveStep()); + EXPECT_EQ(cursor->at(), project2); + EXPECT_EQ(callbackCount, 1); + + // Final output. + EXPECT_TRUE(cursor->moveStep()); + EXPECT_EQ(cursor->at(), ""); + + EXPECT_FALSE(cursor->moveStep()); +} + +// Tests stepping through multiple input batches with a breakpoint on a single +// operator using parallel execution. With 4 drivers each processing 3 batches, +// we expect 12 breakpoint hits and 12 final outputs, interleaved in +// non-deterministic order across drivers. +TEST_F(CustomTraceTest, parallelSingleBreakpoint) { + const size_t size = 10; + auto input1 = makeRowVector( + {"a"}, {makeFlatVector(size, [](auto row) { return row; })}); + auto input2 = makeRowVector( + {"a"}, + {makeFlatVector(size, [](auto row) { return row + 10; })}); + auto input3 = makeRowVector( + {"a"}, + {makeFlatVector(size, [](auto row) { return row + 20; })}); + + core::PlanNodeId project1; + auto plan = PlanBuilder() + .values({input1, input2, input3}, /*parallelizable=*/true) + .project({"a * 10 as a"}) + .capturePlanNodeId(project1) + .planNode(); + + constexpr int kNumDrivers = 4; + constexpr int kNumBatches = 3; + + auto cursor = TaskCursor::create({ + .planNode = plan, + .maxDrivers = kNumDrivers, + .breakpoints = toBreakpointsMap({project1}), + }); + + int numBreakpointHits = 0; + int numFinalOutputs = 0; + + while (cursor->moveStep()) { + if (cursor->at() == project1) { + ++numBreakpointHits; + } else { + EXPECT_EQ(cursor->at(), ""); + ++numFinalOutputs; + } + } + + // Each of the 4 drivers processes all 3 batches, producing one breakpoint + // hit and one final output per batch. + EXPECT_EQ(numBreakpointHits, kNumDrivers * kNumBatches); + EXPECT_EQ(numFinalOutputs, kNumDrivers * kNumBatches); +} + +// Tests that moveNext() in parallel mode skips breakpoints and only produces +// final task outputs. +TEST_F(CustomTraceTest, parallelMoveNext) { + const size_t size = 10; + auto input1 = makeRowVector( + {"a"}, {makeFlatVector(size, [](auto row) { return row; })}); + auto input2 = makeRowVector( + {"a"}, + {makeFlatVector(size, [](auto row) { return row + 10; })}); + auto input3 = makeRowVector( + {"a"}, + {makeFlatVector(size, [](auto row) { return row + 20; })}); + + core::PlanNodeId project1; + auto plan = PlanBuilder() + .values({input1, input2, input3}, /*parallelizable=*/true) + .project({"a * 10 as a"}) + .capturePlanNodeId(project1) + .planNode(); + + constexpr int kNumDrivers = 4; + constexpr int kNumBatches = 3; + + auto cursor = TaskCursor::create({ + .planNode = plan, + .maxDrivers = kNumDrivers, + .breakpoints = toBreakpointsMap({project1}), + }); + + int numOutputs = 0; + + // moveNext() should skip all breakpoints and only return final outputs. + while (cursor->moveNext()) { + EXPECT_EQ(cursor->at(), "") << "moveNext() should skip breakpoints"; + ++numOutputs; + } + + // Each of the 4 drivers processes all 3 batches, producing only final + // outputs (no breakpoint hits). + EXPECT_EQ(numOutputs, kNumDrivers * kNumBatches); +} + +// Tests stepping through multiple input batches with breakpoints on multiple +// operators using parallel execution. With 4 drivers each processing 3 batches +// through 2 breakpoints, we expect 12 hits per breakpoint and 12 final outputs, +// interleaved in non-deterministic order across drivers. +TEST_F(CustomTraceTest, parallelMultipleBreakpoints) { + const size_t size = 10; + auto makeData = [&](std::function values) { + return makeRowVector( + {"a"}, {makeFlatVector(size, std::move(values))}); + }; + + auto input1 = makeData([](auto row) { return row; }); + auto input2 = makeData([](auto row) { return row + 10; }); + auto input3 = makeData([](auto row) { return row + 20; }); + + core::PlanNodeId project1, project2; + auto plan = PlanBuilder() + .values({input1, input2, input3}, /*parallelizable=*/true) + .project({"a * 10 as a"}) + .capturePlanNodeId(project1) + .project({"a * 10 as a"}) + .capturePlanNodeId(project2) + .planNode(); + + constexpr int kNumDrivers = 4; + constexpr int kNumBatches = 3; + + auto cursor = TaskCursor::create({ + .planNode = plan, + .maxDrivers = kNumDrivers, + .breakpoints = toBreakpointsMap({project1, project2}), + }); + + int numProject1Hits = 0; + int numProject2Hits = 0; + int numFinalOutputs = 0; + + while (cursor->moveStep()) { + if (cursor->at() == project1) { + ++numProject1Hits; + } else if (cursor->at() == project2) { + ++numProject2Hits; + } else { + EXPECT_EQ(cursor->at(), ""); + ++numFinalOutputs; + } + } + + // Each of the 4 drivers processes all 3 batches, hitting both breakpoints + // and producing one final output per batch. + EXPECT_EQ(numProject1Hits, kNumDrivers * kNumBatches); + EXPECT_EQ(numProject2Hits, kNumDrivers * kNumBatches); + EXPECT_EQ(numFinalOutputs, kNumDrivers * kNumBatches); +} + +// Tests the debugger cursor with a hash join plan, which creates multiple +// pipelines (probe and build). Verifies that breakpoints on operators in +// different pipelines are all correctly hit. +TEST_F(CustomTraceTest, parallelHashJoinBreakpoints) { + auto probeData = makeRowVector( + {"t_key", "t_val"}, + { + makeFlatVector({1, 2, 3, 4, 5}), + makeFlatVector({10, 20, 30, 40, 50}), + }); + + auto buildData = makeRowVector( + {"u_key", "u_val"}, + { + makeFlatVector({1, 3, 5, 7}), + makeFlatVector({100, 300, 500, 700}), + }); + + auto planNodeIdGenerator = std::make_shared(); + + core::PlanNodeId probeProjectId; + core::PlanNodeId buildProjectId; + core::PlanNodeId joinId; + + auto plan = PlanBuilder(planNodeIdGenerator) + .values({probeData}, /*parallelizable=*/true) + .project({"t_key", "t_val * 10 as t_val"}) + .capturePlanNodeId(probeProjectId) + .hashJoin( + {"t_key"}, + {"u_key"}, + PlanBuilder(planNodeIdGenerator) + .values({buildData}, /*parallelizable=*/true) + .project({"u_key", "u_val * 10 as u_val"}) + .capturePlanNodeId(buildProjectId) + .planNode(), + "", + {"t_key", "t_val", "u_val"}) + .capturePlanNodeId(joinId) + .planNode(); + + constexpr int kNumDrivers = 4; + + auto cursor = TaskCursor::create({ + .planNode = plan, + .maxDrivers = kNumDrivers, + .breakpoints = toBreakpointsMap({probeProjectId, buildProjectId}), + }); + + int numProbeProjectHits = 0; + int numBuildProjectHits = 0; + int numFinalOutputs = 0; + + while (cursor->moveStep()) { + if (cursor->at() == probeProjectId) { + ++numProbeProjectHits; + } else if (cursor->at() == buildProjectId) { + ++numBuildProjectHits; + } else { + EXPECT_EQ(cursor->at(), ""); + ++numFinalOutputs; + } + } + + // Each driver processes the single probe batch (1 breakpoint hit per driver). + EXPECT_EQ(numProbeProjectHits, kNumDrivers); + + // The build pipeline runs with a single driver regardless of maxDrivers, so + // only 1 breakpoint hit for the build project. + EXPECT_GT(numBuildProjectHits, 0); + + // The join should produce at least one final output (the 3 matching rows). + EXPECT_GT(numFinalOutputs, 0); +} + +// Tests that moveStep(planId) only stops at breakpoints matching the specified +// plan node ID, skipping breakpoints for other nodes. +TEST_F(CustomTraceTest, moveStepWithPlanId) { + const size_t size = 10; + auto input1 = makeRowVector( + {"a"}, {makeFlatVector(size, [](auto row) { return row; })}); + auto input2 = makeRowVector( + {"a"}, + {makeFlatVector(size, [](auto row) { return row + 10; })}); + + core::PlanNodeId project1, project2, project3; + auto plan = PlanBuilder() + .values({input1, input2}) + .project({"a * 10 as a"}) + .capturePlanNodeId(project1) + .project({"a * 10 as a"}) + .capturePlanNodeId(project2) + .project({"a * 10 as a"}) + .capturePlanNodeId(project3) + .planNode(); + + // Set breakpoints on all three project nodes. + auto cursor = TaskCursor::create({ + .planNode = plan, + .serialExecution = true, + .breakpoints = toBreakpointsMap({project1, project2, project3}), + }); + + // Step targeting project2 should skip project1 and stop at project2. + EXPECT_TRUE(cursor->moveStep(project2)); + EXPECT_EQ(cursor->at(), project2); + + // Step targeting project3 should skip project2 (for the remaining pipeline) + // and stop at project3. + EXPECT_TRUE(cursor->moveStep(project3)); + EXPECT_EQ(cursor->at(), project3); + + // Step with no filter stops at the next breakpoint (project1 for second + // batch). + EXPECT_TRUE(cursor->moveStep()); + EXPECT_EQ(cursor->at(), ""); + + // First batch final output was produced above. Now step targeting project1 + // for the second input batch. + EXPECT_TRUE(cursor->moveStep(project1)); + EXPECT_EQ(cursor->at(), project1); + + // Step targeting project3 should skip project2 and stop at project3. + EXPECT_TRUE(cursor->moveStep(project3)); + EXPECT_EQ(cursor->at(), project3); + + // Final output for second batch. + EXPECT_TRUE(cursor->moveStep()); + EXPECT_EQ(cursor->at(), ""); + + // No more data. + EXPECT_FALSE(cursor->moveStep()); +} + +// Tests that moveStep(planId) only stops at breakpoints matching the specified +// plan node ID when using moveStep with no filter (empty planId), the original +// behavior is preserved (stop at any breakpoint). +TEST_F(CustomTraceTest, moveStepWithPlanIdDefaultBehavior) { + const size_t size = 10; + auto input = makeRowVector( + {"a"}, {makeFlatVector(size, [](auto row) { return row; })}); + + core::PlanNodeId project1, project2; + auto plan = PlanBuilder() + .values({input}) + .project({"a * 10 as a"}) + .capturePlanNodeId(project1) + .project({"a * 10 as a"}) + .capturePlanNodeId(project2) + .planNode(); + + auto cursor = TaskCursor::create({ + .planNode = plan, + .serialExecution = true, + .breakpoints = toBreakpointsMap({project1, project2}), + }); + + // Default moveStep() (empty planId) should stop at project1. + EXPECT_TRUE(cursor->moveStep()); + EXPECT_EQ(cursor->at(), project1); + + // Default moveStep() should stop at project2. + EXPECT_TRUE(cursor->moveStep()); + EXPECT_EQ(cursor->at(), project2); + + // Final output. + EXPECT_TRUE(cursor->moveStep()); + EXPECT_EQ(cursor->at(), ""); + + EXPECT_FALSE(cursor->moveStep()); +} + +// Tests moveStep(planId) with parallel execution. +TEST_F(CustomTraceTest, parallelMoveStepWithPlanId) { + constexpr int kNumDrivers = 4; + constexpr int kNumBatches = 3; + + std::vector batches; + for (int i = 0; i < kNumBatches; ++i) { + batches.push_back(makeRowVector( + {"a"}, + { + makeFlatVector( + 10, [&](auto row) { return row + i * 100; }), + })); + } + + core::PlanNodeId project1, project2; + auto plan = PlanBuilder() + .values(batches, /*parallelizable=*/true) + .project({"a * 10 as a"}) + .capturePlanNodeId(project1) + .project({"a * 10 as a"}) + .capturePlanNodeId(project2) + .planNode(); + + auto cursor = TaskCursor::create({ + .planNode = plan, + .maxDrivers = kNumDrivers, + .breakpoints = toBreakpointsMap({project1, project2}), + }); + + int numProject2Hits = 0; + int numFinalOutputs = 0; + + // Only step to project2 breakpoints, skipping project1. + while (cursor->moveStep(project2)) { + if (cursor->at() == project2) { + ++numProject2Hits; + } else { + EXPECT_EQ(cursor->at(), ""); + ++numFinalOutputs; + } + } + + // Each of the 4 drivers processes all 3 batches, hitting project2 once per + // batch. + EXPECT_EQ(numProject2Hits, kNumDrivers * kNumBatches); + EXPECT_EQ(numFinalOutputs, kNumDrivers * kNumBatches); +} + +} // namespace +} // namespace facebook::velox::exec::trace::test diff --git a/velox/exec/tests/DriverTest.cpp b/velox/exec/tests/DriverTest.cpp index 275ed3bd345..2d72b9f7d6e 100644 --- a/velox/exec/tests/DriverTest.cpp +++ b/velox/exec/tests/DriverTest.cpp @@ -17,7 +17,7 @@ #include #include #include -#include "folly/experimental/EventCount.h" +#include "folly/synchronization/EventCount.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/testutil/TestValue.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" @@ -124,8 +124,9 @@ class DriverTest : public OperatorTestBase { bool addTestingPauser = false) { std::vector batches; for (int32_t i = 0; i < numBatches; ++i) { - batches.push_back(std::dynamic_pointer_cast( - BatchMaker::createBatch(rowType, rowsInBatch, *pool_))); + batches.push_back( + std::dynamic_pointer_cast( + BatchMaker::createBatch(rowType, rowsInBatch, *pool_))); } if (filterFunc) { int32_t hits = 0; @@ -188,10 +189,11 @@ class DriverTest : public OperatorTestBase { bool paused = false; for (;;) { if (operation == ResultOperation::kPause && paused) { - if (!cursor->hasNext()) { - paused = false; - Task::resume(cursor->task()); - } + // Resume the task so that next() can retrieve more data. + // If there's already buffered data, next() returns it immediately; + // otherwise it will wait for the resumed task to produce output. + paused = false; + Task::resume(cursor->task()); } if (!cursor->next()) { break; @@ -468,10 +470,11 @@ TEST_F(DriverTest, error) { EXPECT_EQ(numRead, 0); EXPECT_TRUE(stateFutures_.at(0).isReady()); // Realized immediately since task not running. - EXPECT_TRUE(tasks_[0] - ->taskCompletionFuture() - .within(std::chrono::microseconds(1'000'000)) - .isReady()); + EXPECT_TRUE( + tasks_[0] + ->taskCompletionFuture() + .within(std::chrono::microseconds(1'000'000)) + .isReady()); EXPECT_EQ(tasks_[0]->state(), TaskState::kFailed); } @@ -799,8 +802,9 @@ TEST_F(DriverTest, pauserNode) { // all its Tasks in the test instance to create inter-Task pauses. static DriverTest* testInstance; testInstance = this; - Operator::registerOperator(std::make_unique( - kThreadsPerTask, sequence, testInstance)); + Operator::registerOperator( + std::make_unique( + kThreadsPerTask, sequence, testInstance)); std::vector params(kNumTasks); int32_t hits{0}; @@ -1614,7 +1618,8 @@ DEBUG_ONLY_TEST_F(DriverTest, driverCpuTimeSlicingCheck) { 0, core::QueryCtx::create( driverExecutor_.get(), core::QueryConfig{std::move(queryConfig)}), - testParam.executionMode); + testParam.executionMode, + exec::Consumer{}); while (task->next() != nullptr) { } } diff --git a/velox/exec/tests/EnforceDistinctTest.cpp b/velox/exec/tests/EnforceDistinctTest.cpp new file mode 100644 index 00000000000..b947ec63892 --- /dev/null +++ b/velox/exec/tests/EnforceDistinctTest.cpp @@ -0,0 +1,133 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" + +using namespace facebook::velox::exec::test; + +namespace facebook::velox::exec { +namespace { + +class EnforceDistinctTest : public OperatorTestBase { + protected: + core::PlanNodePtr makePlan( + const std::vector& input, + const std::vector& keys, + const std::string& errorMessage) { + return PlanBuilder() + .values(input) + .enforceDistinct(keys, errorMessage) + .planNode(); + } + + core::PlanNodePtr makePlan( + const RowVectorPtr& input, + const std::string& key, + const std::string& errorMessage) { + return makePlan( + std::vector{input}, + std::vector{key}, + errorMessage); + } + + void assertDistinct( + const std::vector& input, + const std::vector& keys) { + auto plan = makePlan(input, keys, "Duplicate key found"); + AssertQueryBuilder(plan).assertResults(input); + } +}; + +TEST_F(EnforceDistinctTest, uniqueRowsSingleKey) { + auto data = makeRowVector({ + makeNullableFlatVector({1, 2, 3, std::nullopt, 5, 6, 7, 8, 9}), + makeFlatVector( + {"a", "a", "b", "b", "a", "a", "b", "b", "a"}), + }); + + assertDistinct(split(data, 3), {"c0"}); +} + +TEST_F(EnforceDistinctTest, uniqueRowsMultipleKeys) { + auto data = makeRowVector({ + makeFlatVector({1, 1, 2, 2, 3, 3, 4, 4, 5}), + makeFlatVector( + {"x", "x", "y", "y", "x", "x", "y", "y", "x"}), + makeFlatVector({10, 20, 10, 20, 10, 20, 10, 20, 10}), + }); + + assertDistinct(split(data, 3), {"c0", "c2"}); +} + +TEST_F(EnforceDistinctTest, duplicateWithinBatch) { + auto data = makeRowVector({ + makeFlatVector({1, 2, 3, 2, 5}), + }); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(makePlan(data, "c0", "Duplicate key found")) + .countResults(), + "Duplicate key found"); +} + +TEST_F(EnforceDistinctTest, duplicateAcrossBatches) { + auto batch1 = makeRowVector({ + makeFlatVector({1, 2, 3}), + }); + + auto batch2 = makeRowVector({ + makeFlatVector({4, 2, 6}), + }); + + VELOX_ASSERT_THROW( + AssertQueryBuilder( + makePlan({batch1, batch2}, {"c0"}, "Duplicate key found")) + .countResults(), + "Duplicate key found"); +} + +TEST_F(EnforceDistinctTest, emptyInput) { + auto data = makeRowVector({ + makeFlatVector({}), + }); + + assertDistinct({data}, {"c0"}); +} + +TEST_F(EnforceDistinctTest, singleRow) { + auto data = makeRowVector({ + makeFlatVector({42}), + }); + + assertDistinct({data}, {"c0"}); +} + +TEST_F(EnforceDistinctTest, duplicateNulls) { + auto data = makeRowVector({ + makeNullableFlatVector({1, std::nullopt, 3, std::nullopt, 5}), + }); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(makePlan(data, "c0", "Duplicate key found")) + .countResults(), + "Duplicate key found"); +} + +} // namespace +} // namespace facebook::velox::exec diff --git a/velox/exec/tests/ExchangeClientTest.cpp b/velox/exec/tests/ExchangeClientTest.cpp index f8a6a374eba..27d8404afd0 100644 --- a/velox/exec/tests/ExchangeClientTest.cpp +++ b/velox/exec/tests/ExchangeClientTest.cpp @@ -36,20 +36,19 @@ namespace { static constexpr int32_t kDefaultMinExchangeOutputBatchBytes{2 << 20}; // 2 MB. -class ExchangeClientTest - : public testing::Test, - public velox::test::VectorTestBase, - public testing::WithParamInterface { +class ExchangeClientTest : public testing::Test, + public velox::test::VectorTestBase, + public testing::WithParamInterface { protected: static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kPresto)) { + if (!isRegisteredNamedVectorSerde("Presto")) { serializer::presto::PrestoVectorSerde::registerNamedVectorSerde(); } - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kCompactRow)) { + if (!isRegisteredNamedVectorSerde("CompactRow")) { serializer::CompactRowVectorSerde::registerNamedVectorSerde(); } - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kUnsafeRow)) { + if (!isRegisteredNamedVectorSerde("UnsafeRow")) { serializer::spark::UnsafeRowVectorSerde::registerNamedVectorSerde(); } } @@ -85,13 +84,15 @@ class ExchangeClientTest executor_.get(), core::QueryConfig{std::move(config)}); queryCtx->testingOverrideMemoryPool( memory::memoryManager()->addRootPool(queryCtx->queryId())); - auto plan = test::PlanBuilder().values({}).planNode(); + auto plan = + test::PlanBuilder().values(std::vector{}).planNode(); return Task::create( taskId, core::PlanFragment{plan}, 0, std::move(queryCtx), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); } int32_t enqueue( @@ -108,9 +109,9 @@ class ExchangeClientTest return pageSize; } - std::vector> + std::vector> fetchPages(int consumerId, ExchangeClient& client, int32_t numPages) { - std::vector> allPages; + std::vector> allPages; for (auto i = 0; i < numPages; ++i) { bool atEnd{false}; ContinueFuture future; @@ -138,7 +139,7 @@ class ExchangeClientTest static void enqueue( ExchangeQueue& queue, - std::unique_ptr page) { + std::unique_ptr page) { std::vector promises; { std::lock_guard l(queue.mutex()); @@ -149,17 +150,17 @@ class ExchangeClientTest } } - static std::unique_ptr makePage(uint64_t size) { + static std::unique_ptr makePage(uint64_t size) { auto ioBuf = folly::IOBuf::create(size); ioBuf->append(size); - return std::make_unique(std::move(ioBuf), nullptr, 1); + return std::make_unique(std::move(ioBuf), nullptr, 1); } folly::Executor* executor() const { return executor_.get(); } - VectorSerde::Kind serdeKind_; + std::string serdeKind_; std::unique_ptr executor_; std::shared_ptr bufferManager_; }; @@ -588,7 +589,7 @@ TEST_P(ExchangeClientTest, acknowledge) { SCOPED_TESTVALUE_SET( "facebook::velox::exec::test::LocalExchangeSource::pause", std::function(([&numberOfAcknowledgeRequests](void*) { - numberOfAcknowledgeRequests++; + ++numberOfAcknowledgeRequests; }))); { @@ -609,10 +610,11 @@ TEST_P(ExchangeClientTest, acknowledge) { client->addRemoteTaskId(sourceTaskId); client->noMoreRemoteTasks(); - ASSERT_TRUE(std::move(future) - .via(executor()) - .wait(std::chrono::seconds{10}) - .isReady()); + ASSERT_TRUE( + std::move(future) + .via(executor()) + .wait(std::chrono::seconds{10}) + .isReady()); #ifndef NDEBUG // The client knew there is more data available but could not fetch any more @@ -663,7 +665,7 @@ TEST_P(ExchangeClientTest, acknowledge) { int attempts = 100; bool outputBuffersEmpty; while (attempts > 0) { - attempts--; + --attempts; outputBuffersEmpty = bufferManager_->getUtilization(sourceTaskId) == 0; if (outputBuffersEmpty) { break; @@ -689,10 +691,11 @@ TEST_P(ExchangeClientTest, acknowledge) { pages = client->next(1, 1, &atEnd, &dequeueEndOfDataFuture); ASSERT_EQ(0, pages.size()); - ASSERT_TRUE(std::move(dequeueEndOfDataFuture) - .via(executor()) - .wait(std::chrono::seconds{10}) - .isReady()); + ASSERT_TRUE( + std::move(dequeueEndOfDataFuture) + .via(executor()) + .wait(std::chrono::seconds{10}) + .isReady()); pages = client->next(1, 1, &atEnd, &dequeueEndOfDataFuture); ASSERT_EQ(0, pages.size()); ASSERT_TRUE(atEnd); @@ -976,13 +979,197 @@ TEST_P(ExchangeClientTest, minOutputBatchBytesMultipleConsumers) { client->close(); } +TEST_P(ExchangeClientTest, skipRequestDataSizeWithSingleSource) { + // Test skipRequestDataSizeWithSingleSource flag behavior + + struct { + bool skipEnabled; + + std::string debugString() const { + return fmt::format("skipEnabled={}", skipEnabled); + } + } testSettings[] = { + // skip enabled + {true}, + // skip disabled + {false}}; + + for (const auto& setting : testSettings) { + SCOPED_TRACE(setting.debugString()); + + auto client = std::make_shared( + "test-" + setting.debugString(), + 17, + 1024, + 1, + kDefaultMinExchangeOutputBatchBytes, + pool(), + executor(), + 10, + setting.skipEnabled); + + client->close(); + } +} + +TEST_P(ExchangeClientTest, skipRequestDataSizeNotTriggeredWithMultipleSources) { + // Test that optimization is NOT triggered with multiple sources + + auto data = makeRowVector({makeFlatVector(100, folly::identity)}); + auto page = test::toSerializedPage(data, serdeKind_, bufferManager_, pool()); + + // Client with optimization ENABLED but multiple sources + auto client = std::make_shared( + "test-multi-source", + 17, + page->size() * 10, + 1, + kDefaultMinExchangeOutputBatchBytes, + pool(), + executor(), + 10, + // enableSingleSourceOptimization = true (but won't trigger with + // multiple sources) + true); + + // Setup: Create tasks with TWO sources + std::vector> tasks; + for (int i = 0; i < 2; ++i) { + auto taskId = fmt::format("local://test-source-{}", i); + auto task = makeTask(taskId); + bufferManager_->initializeTask( + task, core::PartitionedOutputNode::Kind::kPartitioned, 100, 16); + + // Enqueue data + for (int j = 0; j < 3; ++j) { + enqueue(taskId, 17, data); + } + + tasks.push_back(task); + client->addRemoteTaskId(taskId); + } + + client->noMoreRemoteTasks(); + + // Fetch pages - should work with regular path (not single source + // optimization) + // 3 pages from each of 2 sources + auto pages = fetchPages(1, *client, 6); + ASSERT_EQ(pages.size(), 6); + + // Cleanup: Signal no more data first to allow faster task termination, + // then cancel and remove tasks. + for (auto& task : tasks) { + bufferManager_->noMoreData(task->taskId()); + } + for (auto& task : tasks) { + task->requestCancel(); + bufferManager_->removeTask(task->taskId()); + } + tasks.clear(); + + client->close(); +} + +// Test that lazyFetching=true defers data fetching until next() is called. +// When lazyFetching=false (default), fetching starts immediately when remote +// tasks are added via pickSourcesToRequestLocked(). When lazyFetching=true, +// pickSourcesToRequestLocked() is not called in addRemoteTaskId(), deferring +// the fetch until next() is called. This is useful for cached hash table +// scenarios where waiter tasks may not need the data if the table is already +// cached. +TEST_P(ExchangeClientTest, lazyFetching) { + auto data = makeRowVector({makeFlatVector({1, 2, 3, 4, 5})}); + + // Test with lazyFetching=false (default behavior). + // Verify that fetching starts and we can retrieve pages normally. + { + auto taskId = "local://eager-fetching-test"; + auto task = makeTask(taskId); + + bufferManager_->initializeTask( + task, core::PartitionedOutputNode::Kind::kPartitioned, 100, 16); + + auto client = std::make_shared( + "t", + 17, + ExchangeClient::kDefaultMaxQueuedBytes, + 1, + kDefaultMinExchangeOutputBatchBytes, + pool(), + executor(), + 10, // requestDataSizesMaxWaitSec + false, // skipRequestDataSizeWithSingleSource + false); // lazyFetching=false (default) + + client->addRemoteTaskId(taskId); + enqueue(taskId, 17, data); + + auto pages = fetchPages(1, *client, 1); + ASSERT_EQ(1, pages.size()); + + task->requestCancel(); + bufferManager_->removeTask(taskId); + task.reset(); + client->close(); + } + + // Test with lazyFetching=true. + // Verify that we can still retrieve pages (fetch is triggered by next()). + { + auto taskId = "local://lazy-fetching-test"; + auto task = makeTask(taskId); + + bufferManager_->initializeTask( + task, core::PartitionedOutputNode::Kind::kPartitioned, 100, 16); + + auto client = std::make_shared( + "t", + 17, + ExchangeClient::kDefaultMaxQueuedBytes, + 1, + kDefaultMinExchangeOutputBatchBytes, + pool(), + executor(), + 10, // requestDataSizesMaxWaitSec + false, // skipRequestDataSizeWithSingleSource + true); // lazyFetching=true + + client->addRemoteTaskId(taskId); + enqueue(taskId, 17, data); + + // Even with lazy fetching, we should be able to retrieve pages + // since next() triggers the fetch. + auto pages = fetchPages(1, *client, 1); + ASSERT_EQ(1, pages.size()); + + task->requestCancel(); + bufferManager_->removeTask(taskId); + task.reset(); + client->close(); + } +} + +// Test the new hasNoMoreSources() API +TEST_P(ExchangeClientTest, hasNoMoreSourcesApi) { + auto queue = std::make_shared(1, 0); + + // Initially, should return false + EXPECT_FALSE(queue->hasNoMoreSources()); + + // After calling noMoreSources(), should return true + queue->noMoreSources(); + + EXPECT_TRUE(queue->hasNoMoreSources()); +} + VELOX_INSTANTIATE_TEST_SUITE_P( ExchangeClientTest, ExchangeClientTest, - testing::Values( - VectorSerde::Kind::kPresto, - VectorSerde::Kind::kCompactRow, - VectorSerde::Kind::kUnsafeRow)); + testing::Values("Presto", "CompactRow", "UnsafeRow"), + [](const testing::TestParamInfo& info) { + return fmt::format("{}", info.param); + }); } // namespace } // namespace facebook::velox::exec diff --git a/velox/exec/tests/ExpandTest.cpp b/velox/exec/tests/ExpandTest.cpp index f541d20c04a..c9fec2a0d7f 100644 --- a/velox/exec/tests/ExpandTest.cpp +++ b/velox/exec/tests/ExpandTest.cpp @@ -21,7 +21,6 @@ using namespace facebook::velox; using namespace facebook::velox::exec::test; namespace facebook::velox::exec { - namespace { class ExpandTest : public OperatorTestBase { public: @@ -37,7 +36,31 @@ class ExpandTest : public OperatorTestBase { }); } }; -} // anonymous namespace + +TEST_F(ExpandTest, complexConstant) { + auto data = makeRowVectorData(3); + auto children = data->children(); + auto arrayVector = + makeArrayVector({{1, 2, 3}, {1, 2, 3}, {1, 2, 3}}); + children.push_back(arrayVector); + children.push_back(makeAllNullArrayVector(3, BIGINT())); + children.push_back(makeNullConstant(TypeKind::BIGINT, 3)); + auto expected = makeRowVector(children); + + auto plan = PlanBuilder(pool()) + .values({data}) + .expand( + {{"k1", + "k2", + "a", + "b", + "ARRAY[1, 2, 3] as c", + "null::bigint[] as d", + "null::bigint as e"}}) + .planNode(); + + assertQuery(plan, expected); +} TEST_F(ExpandTest, groupingSets) { auto data = makeRowVectorData(1'000); @@ -151,4 +174,5 @@ TEST_F(ExpandTest, invalidUseCases) { "projections must not be empty."); } +} // namespace } // namespace facebook::velox::exec diff --git a/velox/exec/tests/ExpressionBuilderTest.cpp b/velox/exec/tests/ExpressionBuilderTest.cpp new file mode 100644 index 00000000000..18f0a3dc879 --- /dev/null +++ b/velox/exec/tests/ExpressionBuilderTest.cpp @@ -0,0 +1,195 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/tests/utils/ExpressionBuilder.h" + +#include +#include "velox/parse/ExpressionsParser.h" +#include "velox/type/Variant.h" + +namespace facebook::velox::expr_builder::test { +namespace { + +// Test convenience functions for downcasting. +template +std::shared_ptr as(detail::ExprWrapper in) { + return std::dynamic_pointer_cast(in.expr()); +} + +template +bool is(detail::ExprWrapper in) { + return std::dynamic_pointer_cast(in.expr()) != nullptr; +} + +// Parses a SQL expression using DuckDB. +core::ExprPtr parseSql(const std::string& sql) { + return parse::DuckSqlExpressionsParser().parseExpr(sql); +} + +TEST(ExpressionBuilderTest, columnReference) { + EXPECT_EQ(col("c0"), parseSql("c0")); + EXPECT_EQ(parseSql("c0"), col("c0")); + EXPECT_EQ("c0"_c, parseSql("c0")); + + EXPECT_EQ(col("parent", "child"), parseSql("parent.child")); + EXPECT_EQ(col("parent").subfield("child"), parseSql("parent.child")); +} + +TEST(ExpressionBuilderTest, literals) { + auto validate = [](detail::ExprWrapper expr, + const TypePtr& expectedType, + variant expectedValue) { + EXPECT_TRUE(is(expr)); + auto constant = as(expr); + EXPECT_EQ(*constant->type(), *expectedType); + EXPECT_TRUE(constant->value().equalsWithEpsilon(expectedValue)); + }; + + // Integer literal types. + validate(lit(123456L), BIGINT(), variant(123456L)); + validate(lit(123), INTEGER(), variant(123)); + validate(lit(int16_t(123)), SMALLINT(), variant(int16_t(123))); + validate(lit(int8_t(123)), TINYINT(), variant(int8_t(123))); + + // Boolean. + validate(lit(true), BOOLEAN(), variant(true)); + validate(lit(false), BOOLEAN(), variant(false)); + + // Floating point. + validate(lit(10.1f), REAL(), variant(10.1f)); + validate(lit(10.1), DOUBLE(), variant(10.1)); + + // String. + validate(lit("str"), VARCHAR(), variant("str")); + + // Null. + validate(lit(nullptr), UNKNOWN(), variant::null(TypeKind::UNKNOWN)); +} + +TEST(ExpressionBuilderTest, comparisons) { + // Make sure all combinations work, as long as at least one side is a + // ExprWrapper. + EXPECT_EQ(col("a") == lit(10L), parseSql("a = 10")); + EXPECT_EQ(lit(10L) == col("a"), parseSql("10 = a")); + + EXPECT_EQ(col("a") == 10L, parseSql("a = 10")); + EXPECT_EQ(10L == col("a"), parseSql("10 = a")); + + EXPECT_EQ(col("a") == col("b"), parseSql("a = b")); + EXPECT_EQ(col("a") == nullptr, parseSql("a = null")); + + // Other comparisons. + EXPECT_EQ(col("a") != 1.1, parseSql("a != 1.1")); + EXPECT_EQ(col("a") != lit(1.1), parseSql("a != 1.1")); + EXPECT_EQ(col("a") > 42L, parseSql("a > 42")); + EXPECT_EQ(col("a") >= 42L, parseSql("a >= 42")); + EXPECT_EQ(col("a") < 42L, parseSql("a < 42")); + EXPECT_EQ(col("a") <= 42L, parseSql("a <= 42")); + + EXPECT_EQ(!col("a"), parseSql("not a")); + EXPECT_EQ(isNull(col("a")), parseSql("a is null")); + EXPECT_EQ(col("a").isNull(), parseSql("a is null")); + EXPECT_EQ(!isNull(col("a")), parseSql("a is not null")); + EXPECT_EQ(!col("a").isNull(), parseSql("a is not null")); + + EXPECT_EQ(isNull("a"), parseSql("\'a\' is null")); // this is "a" literal. +} + +TEST(ExpressionBuilderTest, between) { + EXPECT_EQ(between(col("a"), 0L, 10L), parseSql("a between 0 and 10")); + + EXPECT_EQ(col("a").between(0L, 10L), parseSql("a between 0 and 10")); +} + +TEST(ExpressionBuilderTest, arithmetics) { + EXPECT_EQ(col("b") + 1L, parseSql("b + 1")); + EXPECT_EQ(1L + col("b"), parseSql("1 + b")); + EXPECT_EQ(lit("str") + col("b"), parseSql("'str' + b")); + + EXPECT_EQ(col("b") - 1L, parseSql("b - 1")); + EXPECT_EQ(col("b") * 1L, parseSql("b * 1")); + EXPECT_EQ(col("b") / 1L, parseSql("b / 1")); + EXPECT_EQ(col("b") % 1L, parseSql("b % 1")); + + EXPECT_EQ(col("b") + 1L / col("c") * 10L, parseSql("b + 1 / c * 10")); +} + +TEST(ExpressionBuilderTest, conjuncts) { + EXPECT_EQ(col("b") && 1L, parseSql("b and 1")); + EXPECT_EQ(col("b") || 1L, parseSql("b or 1")); + EXPECT_EQ(col("b") || false, parseSql("b or false")); + + EXPECT_EQ(col("a") && col("b") || col("c"), parseSql("a and b or c")); +} + +TEST(ExpressionBuilderTest, functions) { + EXPECT_EQ(call("func"), parseSql("func()")); + EXPECT_EQ( + call("func", col("a"), 100L, col("c")), parseSql("func(a, 100, c)")); + + // Nested functions. + auto expr = call("f1", call("f2", col("a") > call("f3", col("d")))); + EXPECT_EQ(expr, parseSql("f1(f2(a > f3(d)))")); + + expr = 10L * col("c1") > call("func", 3.4, col("g") / col("h"), call("j")); + EXPECT_EQ(expr, parseSql("10 * c1 > func(3.4, g / h, j())")); +} + +TEST(ExpressionBuilderTest, casts) { + // Casts. + EXPECT_EQ(lit("1").cast(TINYINT()).toString(), "cast(1 as TINYINT)"); + EXPECT_EQ( + col("c0").cast(VARBINARY()).toString(), "cast(\"c0\" as VARBINARY)"); + + EXPECT_EQ(cast(1, TINYINT()).toString(), "cast(1 as TINYINT)"); + EXPECT_EQ( + cast(col("c0"), VARBINARY()).toString(), "cast(\"c0\" as VARBINARY)"); + + // Try casts. + EXPECT_EQ(lit("1").tryCast(TINYINT()).toString(), "try_cast(1 as TINYINT)"); + EXPECT_EQ( + col("c0").tryCast(VARBINARY()).toString(), + "try_cast(\"c0\" as VARBINARY)"); + + EXPECT_EQ(tryCast(1, TINYINT()).toString(), "try_cast(1 as TINYINT)"); + EXPECT_EQ( + tryCast(col("c0"), VARBINARY()).toString(), + "try_cast(\"c0\" as VARBINARY)"); +} + +TEST(ExpressionBuilderTest, alias) { + EXPECT_EQ(lit("str").alias("col"), parseSql("'str' as col")); + EXPECT_EQ(col("c1").alias("col"), parseSql("c1 as col")); + EXPECT_EQ((col("c1") > 1.1).alias("col"), parseSql("c1 > 1.1 as col")); + + EXPECT_EQ( + col("c1").between(1L, 10L).alias("my_col"), + parseSql("c1 between 1 and 10 as my_col")); + + // As a free function. + EXPECT_EQ(alias(col("c1") == "bla", "col"), parseSql("c1 = 'bla' as col")); +} + +TEST(ExpressionBuilderTest, lambdas) { + EXPECT_EQ(lambda("x", 1L), parseSql("x -> 1")); + EXPECT_EQ(lambda({"x"}, 1L), parseSql("x -> 1")); + EXPECT_EQ(lambda({"x"}, col("x") + 1L), parseSql("x -> x + 1")); + EXPECT_EQ( + lambda({"x", "y"}, col("x") * col("y")), parseSql("(x, y) -> x * y")); +} + +} // namespace +} // namespace facebook::velox::expr_builder::test diff --git a/velox/exec/tests/FilterProjectTest.cpp b/velox/exec/tests/FilterProjectTest.cpp index d9923793ce7..37e6c2d5cff 100644 --- a/velox/exec/tests/FilterProjectTest.cpp +++ b/velox/exec/tests/FilterProjectTest.cpp @@ -20,6 +20,7 @@ #include "velox/exec/tests/utils/PlanBuilder.h" namespace facebook::velox::exec { +using namespace facebook::velox::common::testutil; namespace { class FilterProjectTest : public test::HiveConnectorTestBase { @@ -351,7 +352,7 @@ TEST_F(FilterProjectTest, statsSplitter) { .planNode(); std::shared_ptr task; - test::AssertQueryBuilder(plan).runWithoutResults(task); + test::AssertQueryBuilder(plan).countResults(task); auto planStats = toPlanStats(task->taskStats()); @@ -377,11 +378,11 @@ TEST_F(FilterProjectTest, statsSplitter) { TEST_F(FilterProjectTest, barrier) { std::vector vectors; - std::vector> tempFiles; + std::vector> tempFiles; const int numSplits{5}; for (int32_t i = 0; i < numSplits; ++i) { vectors.push_back(makeTestVector()); - tempFiles.push_back(test::TempFilePath::create()); + tempFiles.push_back(TempFilePath::create()); } writeToFiles(toFilePaths(tempFiles), vectors); createDuckDbTable(vectors); @@ -394,16 +395,27 @@ TEST_F(FilterProjectTest, barrier) { .capturePlanNodeId(projectPlanNodeId) .planNode(); struct { + bool serialExecution; bool barrierExecution; int numOutputRows; std::string toString() const { return fmt::format( - "barrierExecution {}, numOutputRows {}", + "serialExecution {}, barrierExecution {}, numOutputRows {}", + serialExecution, barrierExecution, numOutputRows); } - } testSettings[] = {{true, 23}, {false, 23}, {true, 200}, {false, 200}}; + } testSettings[] = { + {true, true, 23}, + {true, false, 23}, + {false, true, 23}, + {false, false, 23}, + {true, true, 200}, + {true, false, 200}, + {false, true, 200}, + {false, false, 200}, + }; for (const auto& testData : testSettings) { SCOPED_TRACE(testData.toString()); auto task = @@ -415,7 +427,8 @@ TEST_F(FilterProjectTest, barrier) { core::QueryConfig::kPreferredOutputBatchRows, std::to_string(testData.numOutputRows)) .splits(makeHiveConnectorSplits(tempFiles)) - .serialExecution(true) + .serialExecution(testData.serialExecution) + .maxDrivers(testData.serialExecution ? 1 : 3) .barrierExecution(testData.barrierExecution) .assertResults("SELECT c0, c1, c0 + c1 FROM tmp WHERE c1 % 10 > 0"); const auto taskStats = task->taskStats(); @@ -441,7 +454,7 @@ TEST_F(FilterProjectTest, lazyDereference) { makeRowVector({expected[0], expected[1]}), makeRowVector({makeRowVector({expected[2]})}), }); - auto file = test::TempFilePath::create(); + auto file = TempFilePath::create(); writeToFile(file->getPath(), vector); CursorParameters params; params.copyResult = false; diff --git a/velox/exec/tests/FilterToExpressionTest.cpp b/velox/exec/tests/FilterToExpressionTest.cpp index 0e66e865fe1..8f84e04fede 100644 --- a/velox/exec/tests/FilterToExpressionTest.cpp +++ b/velox/exec/tests/FilterToExpressionTest.cpp @@ -16,9 +16,6 @@ #include "velox/exec/tests/utils/FilterToExpression.h" #include #include "velox/core/Expressions.h" -#include "velox/core/QueryCtx.h" -#include "velox/expression/Expr.h" -#include "velox/expression/ExprToSubfieldFilter.h" #include "velox/vector/tests/utils/VectorTestBase.h" namespace facebook::velox::core::test { @@ -30,14 +27,6 @@ class FilterToExpressionTest : public testing::Test, memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); } - // Helper method to create a row type for testing - RowTypePtr createTestRowType() { - return ROW( - {"a", "b", "c", "d", "e", "f"}, - {BIGINT(), DOUBLE(), VARCHAR(), BOOLEAN(), REAL(), TIMESTAMP()}); - } - - // Helper method to verify expression type and structure void verifyExpr( const TypedExprPtr& expr, const std::string& expectedType, @@ -50,28 +39,15 @@ class FilterToExpressionTest : public testing::Test, ASSERT_EQ(callExpr->name(), expectedName); } - // Helper method for round trip testing - void testRoundTrip( - const std::string& fieldName, - std::unique_ptr filter); - - core::ExpressionEvaluator* evaluator() { - return &evaluator_; + TypedExprPtr toExpr(const common::Filter* filter, const TypePtr& type) { + common::Subfield subfield("a"); + return filterToExpr(subfield, filter, ROW({"a"}, {type}), pool()); } - - private: - std::shared_ptr pool_ = - memory::memoryManager()->addLeafPool(); - std::shared_ptr queryCtx_{core::QueryCtx::create()}; - exec::SimpleExpressionEvaluator evaluator_{queryCtx_.get(), pool_.get()}; }; -TEST_F(FilterToExpressionTest, AlwaysTrue) { +TEST_F(FilterToExpressionTest, alwaysTrue) { auto filter = std::make_unique(); - common::Subfield subfield("a"); - auto rowType = createTestRowType(); - - auto expr = filterToExpr(subfield, filter.get(), rowType, pool()); + auto expr = toExpr(filter.get(), BIGINT()); ASSERT_TRUE(expr != nullptr); ASSERT_EQ(expr->type()->toString(), "BOOLEAN"); @@ -81,12 +57,9 @@ TEST_F(FilterToExpressionTest, AlwaysTrue) { ASSERT_TRUE(constantExpr->value().value()); } -TEST_F(FilterToExpressionTest, AlwaysFalse) { +TEST_F(FilterToExpressionTest, alwaysFalse) { auto filter = std::make_unique(); - common::Subfield subfield("a"); - auto rowType = createTestRowType(); - - auto expr = filterToExpr(subfield, filter.get(), rowType, pool()); + auto expr = toExpr(filter.get(), BIGINT()); ASSERT_TRUE(expr != nullptr); ASSERT_EQ(expr->type()->toString(), "BOOLEAN"); @@ -96,24 +69,18 @@ TEST_F(FilterToExpressionTest, AlwaysFalse) { ASSERT_FALSE(constantExpr->value().value()); } -TEST_F(FilterToExpressionTest, IsNull) { +TEST_F(FilterToExpressionTest, isNull) { auto filter = std::make_unique(); - common::Subfield subfield("a"); - auto rowType = createTestRowType(); - - auto expr = filterToExpr(subfield, filter.get(), rowType, pool()); + auto expr = toExpr(filter.get(), BIGINT()); verifyExpr(expr, "BOOLEAN", "is_null"); auto callExpr = std::dynamic_pointer_cast(expr); ASSERT_EQ(callExpr->inputs().size(), 1); } -TEST_F(FilterToExpressionTest, IsNotNull) { +TEST_F(FilterToExpressionTest, isNotNull) { auto filter = std::make_unique(); - common::Subfield subfield("a"); - auto rowType = createTestRowType(); - - auto expr = filterToExpr(subfield, filter.get(), rowType, pool()); + auto expr = toExpr(filter.get(), BIGINT()); verifyExpr(expr, "BOOLEAN", "not"); auto callExpr = std::dynamic_pointer_cast(expr); @@ -126,12 +93,9 @@ TEST_F(FilterToExpressionTest, IsNotNull) { ASSERT_EQ(isNullExpr->name(), "is_null"); } -TEST_F(FilterToExpressionTest, BoolValue) { +TEST_F(FilterToExpressionTest, boolValue) { auto filter = std::make_unique(true, false); - common::Subfield subfield("d"); - auto rowType = createTestRowType(); - - auto expr = filterToExpr(subfield, filter.get(), rowType, pool()); + auto expr = toExpr(filter.get(), BOOLEAN()); verifyExpr(expr, "BOOLEAN", "eq"); auto callExpr = std::dynamic_pointer_cast(expr); @@ -148,12 +112,9 @@ TEST_F(FilterToExpressionTest, BoolValue) { ASSERT_EQ(constantExpr->value().value(), true); } -TEST_F(FilterToExpressionTest, BigintRangeSingleValue) { +TEST_F(FilterToExpressionTest, bigintRangeSingleValue) { auto filter = std::make_unique(42, 42, false); - common::Subfield subfield("a"); - auto rowType = createTestRowType(); - - auto expr = filterToExpr(subfield, filter.get(), rowType, pool()); + auto expr = toExpr(filter.get(), BIGINT()); verifyExpr(expr, "BOOLEAN", "eq"); auto callExpr = std::dynamic_pointer_cast(expr); @@ -165,12 +126,9 @@ TEST_F(FilterToExpressionTest, BigintRangeSingleValue) { ASSERT_EQ(constantExpr->value().value(), 42); } -TEST_F(FilterToExpressionTest, BigintRangeWithRange) { +TEST_F(FilterToExpressionTest, bigintRangeWithRange) { auto filter = std::make_unique(10, 20, false); - common::Subfield subfield("a"); - auto rowType = createTestRowType(); - - auto expr = filterToExpr(subfield, filter.get(), rowType, pool()); + auto expr = toExpr(filter.get(), BIGINT()); verifyExpr(expr, "BOOLEAN", "and"); auto callExpr = std::dynamic_pointer_cast(expr); @@ -187,12 +145,9 @@ TEST_F(FilterToExpressionTest, BigintRangeWithRange) { ASSERT_EQ(lessOrEqual->name(), "lte"); } -TEST_F(FilterToExpressionTest, NegatedBigintRangeSingleValue) { +TEST_F(FilterToExpressionTest, negatedBigintRangeSingleValue) { auto filter = std::make_unique(42, 42, false); - common::Subfield subfield("a"); - auto rowType = createTestRowType(); - - auto expr = filterToExpr(subfield, filter.get(), rowType, pool()); + auto expr = toExpr(filter.get(), BIGINT()); // The implementation now uses getNonNegated() which creates a NOT expression // even for single values, so we expect "not" instead of "neq" @@ -234,13 +189,10 @@ TEST_F(FilterToExpressionTest, NegatedBigintRangeSingleValue) { } } -TEST_F(FilterToExpressionTest, DoubleRange) { +TEST_F(FilterToExpressionTest, doubleRange) { auto filter = std::make_unique( 1.5, false, false, 3.5, false, false, false); - common::Subfield subfield("b"); - auto rowType = createTestRowType(); - - auto expr = filterToExpr(subfield, filter.get(), rowType, pool()); + auto expr = toExpr(filter.get(), DOUBLE()); verifyExpr(expr, "BOOLEAN", "and"); auto callExpr = std::dynamic_pointer_cast(expr); @@ -257,13 +209,10 @@ TEST_F(FilterToExpressionTest, DoubleRange) { ASSERT_EQ(lessOrEqual->name(), "lte"); } -TEST_F(FilterToExpressionTest, FloatRange) { +TEST_F(FilterToExpressionTest, floatRange) { auto filter = std::make_unique( 1.5f, false, true, 3.5f, false, true, false); - common::Subfield subfield("e"); - auto rowType = createTestRowType(); - - auto expr = filterToExpr(subfield, filter.get(), rowType, pool()); + auto expr = toExpr(filter.get(), REAL()); verifyExpr(expr, "BOOLEAN", "and"); auto callExpr = std::dynamic_pointer_cast(expr); @@ -280,13 +229,10 @@ TEST_F(FilterToExpressionTest, FloatRange) { ASSERT_EQ(lessThan->name(), "lt"); } -TEST_F(FilterToExpressionTest, BytesRange) { +TEST_F(FilterToExpressionTest, bytesRange) { auto filter = std::make_unique( "apple", false, false, "orange", false, false, false); - common::Subfield subfield("c"); - auto rowType = createTestRowType(); - - auto expr = filterToExpr(subfield, filter.get(), rowType, pool()); + auto expr = toExpr(filter.get(), VARCHAR()); verifyExpr(expr, "BOOLEAN", "and"); auto callExpr = std::dynamic_pointer_cast(expr); @@ -303,13 +249,10 @@ TEST_F(FilterToExpressionTest, BytesRange) { ASSERT_EQ(lessOrEqual->name(), "lte"); } -TEST_F(FilterToExpressionTest, BigintValuesUsingHashTable) { +TEST_F(FilterToExpressionTest, bigintValuesUsingHashTable) { std::vector values = {10, 20, 30}; auto filter = common::createBigintValues(values, false); - common::Subfield subfield("a"); - auto rowType = createTestRowType(); - - auto expr = filterToExpr(subfield, filter.get(), rowType, pool()); + auto expr = toExpr(filter.get(), BIGINT()); // The implementation creates an optimized expression: (range check) AND (in // check) @@ -337,13 +280,10 @@ TEST_F(FilterToExpressionTest, BigintValuesUsingHashTable) { ASSERT_EQ(arrayExpr->inputs().size(), 3); } -TEST_F(FilterToExpressionTest, BytesValues) { +TEST_F(FilterToExpressionTest, bytesValues) { std::vector values = {"apple", "banana", "orange"}; auto filter = std::make_unique(values, false); - common::Subfield subfield("c"); - auto rowType = createTestRowType(); - - auto expr = filterToExpr(subfield, filter.get(), rowType, pool()); + auto expr = toExpr(filter.get(), VARCHAR()); verifyExpr(expr, "BOOLEAN", "in"); auto callExpr = std::dynamic_pointer_cast(expr); @@ -356,13 +296,10 @@ TEST_F(FilterToExpressionTest, BytesValues) { ASSERT_EQ(arrayExpr->inputs().size(), 3); } -TEST_F(FilterToExpressionTest, NegatedBytesValues) { +TEST_F(FilterToExpressionTest, negatedBytesValues) { std::vector values = {"apple", "banana", "orange"}; auto filter = std::make_unique(values, false); - common::Subfield subfield("c"); - auto rowType = createTestRowType(); - - auto expr = filterToExpr(subfield, filter.get(), rowType, pool()); + auto expr = toExpr(filter.get(), VARCHAR()); verifyExpr(expr, "BOOLEAN", "not"); auto callExpr = std::dynamic_pointer_cast(expr); @@ -375,14 +312,11 @@ TEST_F(FilterToExpressionTest, NegatedBytesValues) { ASSERT_TRUE(containsExpr->name() == "in" || containsExpr->name() == "or"); } -TEST_F(FilterToExpressionTest, NegatedBigintValuesUsingHashTable) { +TEST_F(FilterToExpressionTest, negatedBigintValuesUsingHashTable) { std::vector values = {10, 20, 30}; auto filter = std::make_unique( 10, 30, values, false); - common::Subfield subfield("a"); - auto rowType = createTestRowType(); - - auto expr = filterToExpr(subfield, filter.get(), rowType, pool()); + auto expr = toExpr(filter.get(), BIGINT()); // The implementation creates a NOT expression for the optimized IN check verifyExpr(expr, "BOOLEAN", "not"); @@ -409,15 +343,12 @@ TEST_F(FilterToExpressionTest, NegatedBigintValuesUsingHashTable) { ASSERT_EQ(isNullExpr->name(), "is_null"); } -TEST_F(FilterToExpressionTest, TimestampRange) { +TEST_F(FilterToExpressionTest, timestampRange) { auto timestamp1 = Timestamp::fromMillis(1609459200000); // 2021-01-01 auto timestamp2 = Timestamp::fromMillis(1640995200000); // 2022-01-01 auto filter = std::make_unique(timestamp1, timestamp2, false); - common::Subfield subfield("f"); - auto rowType = createTestRowType(); - - auto expr = filterToExpr(subfield, filter.get(), rowType, pool()); + auto expr = toExpr(filter.get(), TIMESTAMP()); verifyExpr(expr, "BOOLEAN", "and"); auto callExpr = std::dynamic_pointer_cast(expr); @@ -434,23 +365,20 @@ TEST_F(FilterToExpressionTest, TimestampRange) { ASSERT_EQ(lessOrEqual->name(), "lte"); } -TEST_F(FilterToExpressionTest, BigintMultiRange) { +TEST_F(FilterToExpressionTest, bigintMultiRange) { std::vector> ranges; ranges.push_back(std::make_unique(10, 20, false)); ranges.push_back(std::make_unique(30, 40, false)); auto filter = std::make_unique(std::move(ranges), false); - common::Subfield subfield("a"); - auto rowType = createTestRowType(); - - auto expr = filterToExpr(subfield, filter.get(), rowType, pool()); + auto expr = toExpr(filter.get(), BIGINT()); verifyExpr(expr, "BOOLEAN", "or"); auto callExpr = std::dynamic_pointer_cast(expr); ASSERT_EQ(callExpr->inputs().size(), 2); } -TEST_F(FilterToExpressionTest, MultiRange) { +TEST_F(FilterToExpressionTest, multiRange) { // Create a MultiRange filter with compatible filters for BIGINT field std::vector> filters; @@ -464,10 +392,7 @@ TEST_F(FilterToExpressionTest, MultiRange) { filters.push_back(std::make_unique(30, 40, false)); auto filter = std::make_unique(std::move(filters), false); - common::Subfield subfield("a"); - auto rowType = createTestRowType(); - - auto expr = filterToExpr(subfield, filter.get(), rowType, pool()); + auto expr = toExpr(filter.get(), BIGINT()); // Verify the top-level expression is an OR verifyExpr(expr, "BOOLEAN", "or"); @@ -497,132 +422,4 @@ TEST_F(FilterToExpressionTest, MultiRange) { ASSERT_EQ(thirdInput->inputs().size(), 2); } -// Helper method for round trip testing -void FilterToExpressionTest::testRoundTrip( - const std::string& fieldName, - std::unique_ptr filter) { - // Step 1: Convert filter to expression - common::Subfield subfield(fieldName); - auto rowType = createTestRowType(); - auto expr = filterToExpr(subfield, filter.get(), rowType, pool()); - ASSERT_TRUE(expr != nullptr); - - // Step 2: Convert expression back to filter - auto callExpr = std::dynamic_pointer_cast(expr); - if (!callExpr) { - // Some filters like AlwaysTrue/AlwaysFalse convert to ConstantTypedExpr - // which can't be converted back to a filter - return; - } - - // Special handling for BoolValue filter - if (filter->kind() == common::FilterKind::kBoolValue) { - // For BoolValue, we need to extract the eq expression from the and - // expression - if (callExpr->name() == "and" && callExpr->inputs().size() == 2) { - auto eqExpr = - std::dynamic_pointer_cast(callExpr->inputs()[0]); - if (eqExpr && eqExpr->name() == "eq") { - callExpr = eqExpr; - } - } - } - - // Special handling for "in" with array_constructor - if (callExpr->name() == "in" && callExpr->inputs().size() == 2) { - auto arrayExpr = - std::dynamic_pointer_cast(callExpr->inputs()[1]); - if (arrayExpr && arrayExpr->name() == "array_constructor") { - // Use toSubfieldFilter for array_constructor expressions - auto [roundTripSubfield, roundTripFilter] = - exec::toSubfieldFilter(expr, evaluator()); - - // Step 3: Verify the round-tripped filter and subfield - ASSERT_TRUE(roundTripFilter != nullptr); - ASSERT_EQ(roundTripSubfield.toString(), subfield.toString()); - - // Compare filter properties - this will vary based on filter type - // For this test we'll just verify the filter kind is the same - ASSERT_EQ(roundTripFilter->kind(), filter->kind()); - return; - } - } - - // Special handling for range filters (and expressions) - if (callExpr->name() == "and" && callExpr->inputs().size() == 2) { - auto firstInput = - std::dynamic_pointer_cast(callExpr->inputs()[0]); - auto secondInput = - std::dynamic_pointer_cast(callExpr->inputs()[1]); - - if (firstInput && secondInput && firstInput->name() == "gte" && - secondInput->name() == "lte") { - // Extract the field and bounds - auto field = firstInput->inputs()[0]; - auto lowerBound = firstInput->inputs()[1]; - auto upperBound = secondInput->inputs()[1]; - - // Create a between expression - auto betweenExpr = std::make_shared( - callExpr->type(), "between", field, lowerBound, upperBound); - - common::Subfield roundTripSubfield; - auto roundTripFilter = - exec::ExprToSubfieldFilterParser::getInstance() - ->leafCallToSubfieldFilter( - *betweenExpr, roundTripSubfield, evaluator(), false); - - // Step 3: Verify the round-tripped filter and subfield - ASSERT_TRUE(roundTripFilter != nullptr); - ASSERT_EQ(roundTripSubfield.toString(), subfield.toString()); - - // Compare filter properties - this will vary based on filter type - // For this test we'll just verify the filter kind is the same - ASSERT_EQ(roundTripFilter->kind(), filter->kind()); - return; - } - } - - // For all other expressions, use leafCallToSubfieldFilter directly - common::Subfield roundTripSubfield; - auto roundTripFilter = - exec::ExprToSubfieldFilterParser::getInstance()->leafCallToSubfieldFilter( - *callExpr, roundTripSubfield, evaluator(), false); - - // Step 3: Verify the round-tripped filter and subfield - ASSERT_TRUE(roundTripFilter != nullptr); - ASSERT_EQ(roundTripSubfield.toString(), subfield.toString()); - - // Compare filter properties - this will vary based on filter type - // For this test we'll just verify the filter kind is the same - ASSERT_EQ(roundTripFilter->kind(), filter->kind()); -} - -// Round trip tests for various filter types -TEST_F(FilterToExpressionTest, RoundTripBigintRangeSingleValue) { - auto filter = std::make_unique(42, 42, false); - testRoundTrip("a", std::move(filter)); -} - -TEST_F(FilterToExpressionTest, RoundTripBigintRangeWithRange) { - auto filter = std::make_unique(10, 20, false); - testRoundTrip("a", std::move(filter)); -} - -TEST_F(FilterToExpressionTest, RoundTripIsNull) { - auto filter = std::make_unique(); - testRoundTrip("a", std::move(filter)); -} - -TEST_F(FilterToExpressionTest, RoundTripBoolValue) { - auto filter = std::make_unique(true, false); - testRoundTrip("d", std::move(filter)); -} - -TEST_F(FilterToExpressionTest, RoundTripBytesRange) { - auto filter = std::make_unique( - "apple", false, false, "orange", false, false, false); - testRoundTrip("c", std::move(filter)); -} - } // namespace facebook::velox::core::test diff --git a/velox/exec/tests/FunctionSignatureBuilderTest.cpp b/velox/exec/tests/FunctionSignatureBuilderTest.cpp index 6ac14b0592a..832f32c957a 100644 --- a/velox/exec/tests/FunctionSignatureBuilderTest.cpp +++ b/velox/exec/tests/FunctionSignatureBuilderTest.cpp @@ -45,11 +45,12 @@ TEST_F(FunctionSignatureBuilderTest, basicTypeTests) { // Integer variables do not have to be used in the inputs, but in that case // must appear in the return. - ASSERT_NO_THROW(FunctionSignatureBuilder() - .integerVariable("a") - .returnType("DECIMAL(a, a)") - .argumentType("integer") - .build();); + ASSERT_NO_THROW( + FunctionSignatureBuilder() + .integerVariable("a") + .returnType("DECIMAL(a, a)") + .argumentType("integer") + .build();); VELOX_ASSERT_THROW( FunctionSignatureBuilder() @@ -124,7 +125,7 @@ TEST_F(FunctionSignatureBuilderTest, typeParamTests) { .returnType("integer") .argumentType("row(..., varchar)") .build(), - "Failed to parse type signature [row(..., varchar)]: syntax error, unexpected COMMA"); + "Failed to parse type signature [row(..., varchar)]: syntax error, unexpected ELLIPSIS"); // Type params cant have type params. VELOX_ASSERT_THROW( @@ -155,6 +156,60 @@ TEST_F(FunctionSignatureBuilderTest, anyInReturn) { "Type 'Any' cannot appear in return type"); } +TEST_F(FunctionSignatureBuilderTest, homogeneousRowInReturn) { + VELOX_ASSERT_USER_THROW( + exec::FunctionSignatureBuilder() + .typeVariable("T") + .returnType("row(T, ...)") + .argumentType("T") + .build(), + "Homogeneous row cannot appear in return type"); + + VELOX_ASSERT_USER_THROW( + exec::FunctionSignatureBuilder() + .returnType("array(row(bigint, ...))") + .argumentType("bigint") + .build(), + "Homogeneous row cannot appear in return type"); +} + +TEST_F(FunctionSignatureBuilderTest, variableArity) { + // .variableArity() requires at least one argument. + VELOX_ASSERT_USER_THROW( + exec::FunctionSignatureBuilder() + .returnType("bigint") + .variableArity() + .build(), + "Variable arity requires at least one argument"); + + // .variableArity() can be used only once. + VELOX_ASSERT_USER_THROW( + exec::FunctionSignatureBuilder() + .returnType("bigint") + .variableArity("bigint") + .variableArity("integer") + .build(), + "Cannot add arguments after variable arity argument"); + + VELOX_ASSERT_USER_THROW( + exec::FunctionSignatureBuilder() + .returnType("bigint") + .argumentType("bigint") + .variableArity() + .variableArity() + .build(), + "Only one variable arity argument is allowed"); + + // No arguments can be added after calling .variableArity(). + VELOX_ASSERT_USER_THROW( + exec::FunctionSignatureBuilder() + .returnType("bigint") + .variableArity("bigint") + .argumentType("boolean") + .build(), + "Cannot add arguments after variable arity argument"); +} + TEST_F(FunctionSignatureBuilderTest, scalarConstantFlags) { { auto signature = FunctionSignatureBuilder() diff --git a/velox/exec/tests/GroupedExecutionTest.cpp b/velox/exec/tests/GroupedExecutionTest.cpp index 07fc21ba3f7..477a41b020a 100644 --- a/velox/exec/tests/GroupedExecutionTest.cpp +++ b/velox/exec/tests/GroupedExecutionTest.cpp @@ -17,16 +17,18 @@ #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/memory/MemoryArbitrator.h" +#include "velox/common/testutil/TempDirectoryPath.h" +#include "velox/common/testutil/TempFilePath.h" #include "velox/exec/Cursor.h" #include "velox/exec/OutputBufferManager.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/type/Type.h" namespace facebook::velox::exec::test { +using namespace facebook::velox::common::testutil; class GroupedExecutionTest : public virtual HiveConnectorTestBase { protected: @@ -547,7 +549,7 @@ TEST_F(GroupedExecutionTest, hashJoinWithMixedGroupedExecutionWithSpill) { } TestScopedSpillInjection scopedSpillInjection(triggerSpill ? 100 : 0); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); AssertQueryBuilder queryBuilder(duckDbQueryRunner_); queryBuilder.plan(plan) .spillDirectory(spillDirectory->getPath()) @@ -675,16 +677,25 @@ DEBUG_ONLY_TEST_F( } })); + const auto spillDirectory = TempDirectoryPath::create(); + std::optional spillOpts; + if (testData.enableSpill) { + spillOpts = common::SpillDiskOptions{ + .spillDirPath = spillDirectory->getPath(), + .spillDirCreated = true, + .spillDirCreateCb = nullptr}; + } + auto task = exec::Task::create( "0", std::move(planFragment), 0, std::move(queryCtx), - Task::ExecutionMode::kParallel); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); - if (testData.enableSpill) { - task->setSpillDirectory(spillDirectory->getPath()); - } + Task::ExecutionMode::kParallel, + /*consumer=*/Consumer{}, + /*memoryArbitrationPriority=*/0, + spillOpts, + /*onError=*/nullptr); // 'numDriversPerGroup' drivers max to execute one group at a time. task->start(numDriversPerGroup, testData.groupConcurrency); @@ -817,15 +828,21 @@ DEBUG_ONLY_TEST_F( memory::testingRunArbitration(op->pool()); })); + const auto spillDirectory = TempDirectoryPath::create(); + common::SpillDiskOptions spillOpts{ + .spillDirPath = spillDirectory->getPath(), + .spillDirCreated = true, + .spillDirCreateCb = nullptr}; + auto task = exec::Task::create( "0", std::move(planFragment), 0, std::move(queryCtx), - Task::ExecutionMode::kParallel); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); - - task->setSpillDirectory(spillDirectory->getPath()); + Task::ExecutionMode::kParallel, + Consumer{}, + /*memoryArbitrationPriority=*/0, + spillOpts); // 'numDriversPerGroup' drivers max to execute one group at a time. task->start(numDriversPerGroup, 1); @@ -848,8 +865,8 @@ DEBUG_ONLY_TEST_F( } // Total drivers should be numDriversPerGroup * (numGroups + 1), but since - // probe does not receive termination signal, it cannot signal the build side - // to finish. we expect only build's numDriversPerGroup finished. + // probe does not receive termination signal, it cannot signal the build + // side to finish. we expect only build's numDriversPerGroup finished. waitForFinishedDrivers(task, numDriversPerGroup); // 'Delete results' from output buffer triggers 'set all output consumed', @@ -1025,8 +1042,8 @@ TEST_F(GroupedExecutionTest, groupedExecutionWithHashAndNestedLoopJoin) { const std::unordered_set expectedSplitGroupIds({1, 5, 8}); int numSplitGroupJoinNodes{0}; task->pool()->visitChildren([&](memory::MemoryPool* childPool) -> bool { - if (folly::StringPiece(childPool->name()) - .startsWith(fmt::format("node.{}[", joinNodeId))) { + if (childPool->name().starts_with( + fmt::format("node.{}[", joinNodeId))) { ++numSplitGroupJoinNodes; std::vector parts; folly::split(".", childPool->name(), parts); diff --git a/velox/exec/tests/HashJoinBridgeTest.cpp b/velox/exec/tests/HashJoinBridgeTest.cpp index 8b91ce52bf5..ccf3d411b89 100644 --- a/velox/exec/tests/HashJoinBridgeTest.cpp +++ b/velox/exec/tests/HashJoinBridgeTest.cpp @@ -16,15 +16,15 @@ #include "velox/exec/HashJoinBridge.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/FileSystems.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/exec/HashTable.h" #include "velox/exec/Spill.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" using namespace facebook::velox; using namespace facebook::velox::exec; -using facebook::velox::exec::test::TempDirectoryPath; namespace facebook::velox::exec::test { +using namespace facebook::velox::common::testutil; class HashJoinBridgeTestHelper { public: @@ -44,11 +44,17 @@ class HashJoinBridgeTestHelper { HashJoinBridge* const bridge_; }; +namespace { struct TestParam { int32_t numProbers{1}; int32_t numBuilders{1}; }; +inline void PrintTo(const TestParam& param, std::ostream* os) { + *os << fmt::format( + "probers:{}_builders:{}", param.numProbers, param.numBuilders); +} + class HashJoinBridgeTest : public testing::Test, public testing::WithParamInterface { public: @@ -70,7 +76,7 @@ class HashJoinBridgeTest : public testing::Test, void SetUp() override { rng_.seed(1245); - tempDir_ = exec::test::TempDirectoryPath::create(); + tempDir_ = TempDirectoryPath::create(); } void TearDown() override {} @@ -96,7 +102,7 @@ class HashJoinBridgeTest : public testing::Test, std::make_unique(rowType_->childAt(channel), channel)); } return HashTable::createForJoin( - std::move(keyHashers), {}, true, false, 1'000, pool_.get()); + std::move(keyHashers), {}, true, false, false, 1'000, pool_.get()); } std::vector createEmptyFutures(int32_t count) { @@ -670,12 +676,16 @@ TEST_P(HashJoinBridgeTest, hashJoinTableType) { std::vector buildKeys; std::vector probeKeys; for (uint32_t i = 0; i < testData.buildKeyType->size(); i++) { - buildKeys.push_back(std::make_shared( - testData.buildKeyType->childAt(i), testData.buildKeyType->nameOf(i))); + buildKeys.push_back( + std::make_shared( + testData.buildKeyType->childAt(i), + testData.buildKeyType->nameOf(i))); } for (uint32_t i = 0; i < testData.probeKeyType->size(); i++) { - probeKeys.push_back(std::make_shared( - testData.probeKeyType->childAt(i), testData.probeKeyType->nameOf(i))); + probeKeys.push_back( + std::make_shared( + testData.probeKeyType->childAt(i), + testData.probeKeyType->nameOf(i))); } const auto joinNode = std::make_shared( "join-bridge-test", @@ -724,4 +734,5 @@ TEST(HashJoinBridgeTest, hashJoinTableSpillType) { ASSERT_EQ(spillType->names(), testData.expectedTableSpillType->names()); } } +} // namespace } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/HashJoinTest.cpp b/velox/exec/tests/HashJoinTest.cpp index 510e063e919..1f10c643ce5 100644 --- a/velox/exec/tests/HashJoinTest.cpp +++ b/velox/exec/tests/HashJoinTest.cpp @@ -17,13 +17,16 @@ #include #include -#include "folly/experimental/EventCount.h" +#include "folly/synchronization/EventCount.h" #include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/common/testutil/TestValue.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" #include "velox/exec/Cursor.h" #include "velox/exec/HashBuild.h" #include "velox/exec/HashJoinBridge.h" +#include "velox/exec/HashProbe.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/OperatorUtils.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/tests/utils/ArbitratorTestUtil.h" @@ -31,8 +34,8 @@ #include "velox/exec/tests/utils/HashJoinTestBase.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/exec/tests/utils/VectorTestUtil.h" +#include "velox/type/tests/utils/CustomTypesForTesting.h" #include "velox/vector/fuzzer/VectorFuzzer.h" using namespace facebook::velox; @@ -44,27 +47,35 @@ using facebook::velox::test::BatchMaker; namespace facebook::velox::exec { namespace { -class HashJoinTest : public HashJoinTestBase { +class HashJoinTest : public HashJoinTestBase, + public testing::WithParamInterface { public: - HashJoinTest() : HashJoinTestBase(TestParam(1)) {} + HashJoinTest() : HashJoinTestBase(GetParam()) {} explicit HashJoinTest(const TestParam& param) : HashJoinTestBase(param) {} + + static std::vector getTestParams() { + return std::vector({TestParam{1, false}, TestParam{1, true}}); + } }; -class MultiThreadedHashJoinTest - : public HashJoinTest, - public testing::WithParamInterface { +class MultiThreadedHashJoinTest : public HashJoinTest { public: MultiThreadedHashJoinTest() : HashJoinTest(GetParam()) {} static std::vector getTestParams() { - return std::vector({TestParam{1}, TestParam{3}}); + return std::vector( + {TestParam{1, false}, + TestParam{1, true}, + TestParam{3, false}, + TestParam{3, true}}); } }; TEST_P(MultiThreadedHashJoinTest, bigintArray) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .keyTypes({BIGINT()}) .probeVectors(1600, 5) .buildVectors(1500, 5) @@ -76,6 +87,7 @@ TEST_P(MultiThreadedHashJoinTest, bigintArray) { TEST_P(MultiThreadedHashJoinTest, outOfJoinKeyColumnOrder) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeType(probeType_) .probeKeys({"t_k2"}) .probeVectors(5, 10) @@ -91,6 +103,7 @@ TEST_P(MultiThreadedHashJoinTest, outOfJoinKeyColumnOrder) { TEST_P(MultiThreadedHashJoinTest, joinWithCancellation) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .keyTypes({BIGINT()}) .probeVectors(1600, 5) .buildVectors(1500, 5) @@ -105,9 +118,10 @@ TEST_P(MultiThreadedHashJoinTest, joinWithCancellation) { } TEST_P(MultiThreadedHashJoinTest, testJoinWithSpillenabledCancellation) { - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .keyTypes({BIGINT()}) .probeVectors(1600, 5) .buildVectors(1500, 5) @@ -128,6 +142,7 @@ TEST_P(MultiThreadedHashJoinTest, emptyBuild) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .keyTypes({BIGINT()}) .probeVectors(1600, 5) .buildVectors(0, 5) @@ -159,6 +174,7 @@ TEST_P(MultiThreadedHashJoinTest, emptyBuild) { TEST_P(MultiThreadedHashJoinTest, emptyProbe) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .keyTypes({BIGINT()}) .probeVectors(0, 5) .buildVectors(1500, 5) @@ -195,6 +211,7 @@ TEST_P(MultiThreadedHashJoinTest, emptyProbe) { TEST_P(MultiThreadedHashJoinTest, normalizedKey) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .keyTypes({BIGINT(), VARCHAR()}) .probeVectors(1600, 5) .buildVectors(1500, 5) @@ -220,6 +237,7 @@ DEBUG_ONLY_TEST_P(MultiThreadedHashJoinTest, parallelJoinBuildCheck) { std::function([&](void*) { isParallelBuild = true; })); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .keyTypes({BIGINT(), VARCHAR()}) .probeVectors(1600, 5) .buildVectors(1500, 5) @@ -250,6 +268,7 @@ DEBUG_ONLY_TEST_P( VELOX_ASSERT_THROW( HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .keyTypes({BIGINT(), VARCHAR()}) .probeVectors(1600, 5) .buildVectors(1500, 5) @@ -280,6 +299,7 @@ TEST_P(MultiThreadedHashJoinTest, allTypes) { TEST_P(MultiThreadedHashJoinTest, filter) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .keyTypes({BIGINT()}) .probeVectors(1600, 5) .buildVectors(1500, 5) @@ -289,8 +309,101 @@ TEST_P(MultiThreadedHashJoinTest, filter) { .run(); } +// Regression test for a JoinFuzzer-found bug where HashProbe::evalFilter +// produces a DictionaryVector with indices pointing past the base +// vector's size. The issue involves the filter "t_N = true" on a +// dictionary-encoded probe boolean column combined with expression +// memoization across multiple output batches from the same probe input. +// In debug builds, this triggers a validation failure in +// DictionaryVector::validate(). +// +// This test exercises the same code path: hash join with boolean filter on +// a probe column that is also a join key, using dictionary-encoded probe +// input (matching the fuzzer's ENCODED input type) and small output batch +// sizes to force multiple output batches from the same probe input. +TEST_P(HashJoinTest, booleanJoinFilterDictionaryValidation) { + VectorFuzzer::Options opts; + opts.nullRatio = 0.1; + + for (int seed = 0; seed < 20; ++seed) { + SCOPED_TRACE(fmt::format("seed: {}", seed)); + opts.vectorSize = 10 + (seed % 20); + VectorFuzzer fuzzer(opts, pool_.get(), seed); + + auto probeType = ROW({"t0", "t1"}, {INTEGER(), BOOLEAN()}); + auto buildType = ROW({"u0", "u1"}, {INTEGER(), BOOLEAN()}); + + // Use fuzzRow which wraps columns in dictionary/constant encoding, + // matching the JoinFuzzer's ENCODED input type. + std::vector probeVectors = {fuzzer.fuzzRow(probeType)}; + std::vector buildVectors = {fuzzer.fuzzRow(buildType)}; + + for (int batchSize : {3, 5}) { + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(1) + .probeKeys({"t0", "t1"}) + .probeVectors(std::vector(probeVectors)) + .buildKeys({"u0", "u1"}) + .buildVectors(std::vector(buildVectors)) + .joinFilter("t1 = true") + .joinOutputLayout({"t0", "t1", "u0", "u1"}) + .config( + core::QueryConfig::kPreferredOutputBatchRows, + std::to_string(batchSize)) + .config( + core::QueryConfig::kMaxOutputBatchRows, std::to_string(batchSize)) + .injectSpill(false) + .referenceQuery( + "SELECT t.t0, t.t1, u.u0, u.u1 FROM t, u " + "WHERE t.t0 = u.u0 AND t.t1 = u.u1 AND t.t1 = true") + .run(); + } + } +} + +// Same as above but with the boolean filter column as a non-key payload +// column, exercising a slightly different code path where only integer +// keys are used for join matching. +TEST_P(HashJoinTest, booleanPayloadFilterDictionaryValidation) { + VectorFuzzer::Options opts; + opts.nullRatio = 0.1; + + for (int seed = 0; seed < 20; ++seed) { + SCOPED_TRACE(fmt::format("seed: {}", seed)); + opts.vectorSize = 10 + (seed % 20); + VectorFuzzer fuzzer(opts, pool_.get(), seed); + + auto probeType = ROW({"t0", "t1"}, {INTEGER(), BOOLEAN()}); + auto buildType = ROW({"u0", "u1"}, {INTEGER(), BOOLEAN()}); + + std::vector probeVectors = {fuzzer.fuzzRow(probeType)}; + std::vector buildVectors = {fuzzer.fuzzRow(buildType)}; + + for (int batchSize : {3, 5}) { + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(1) + .probeKeys({"t0"}) + .probeVectors(std::vector(probeVectors)) + .buildKeys({"u0"}) + .buildVectors(std::vector(buildVectors)) + .joinFilter("t1 = true") + .joinOutputLayout({"t0", "t1", "u0", "u1"}) + .config( + core::QueryConfig::kPreferredOutputBatchRows, + std::to_string(batchSize)) + .config( + core::QueryConfig::kMaxOutputBatchRows, std::to_string(batchSize)) + .injectSpill(false) + .referenceQuery( + "SELECT t.t0, t.t1, u.u0, u.u1 FROM t, u " + "WHERE t.t0 = u.u0 AND t.t1 = true") + .run(); + } + } +} + DEBUG_ONLY_TEST_P(MultiThreadedHashJoinTest, filterSpillOnFirstProbeInput) { - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); std::atomic_bool injectProbeSpillOnce{true}; SCOPED_TESTVALUE_SET( "facebook::velox::exec::Driver::runInternal::getOutput", @@ -312,6 +425,7 @@ DEBUG_ONLY_TEST_P(MultiThreadedHashJoinTest, filterSpillOnFirstProbeInput) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .keyTypes({BIGINT()}) .numDrivers(1) .probeVectors(1600, 5) @@ -362,6 +476,7 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithNull) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeType(probeType_) .probeKeys({"t_k2"}) .probeVectors(std::move(probeVectors)) @@ -380,6 +495,75 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithNull) { } } +// Verifies that nullAsValue flag makes NULL keys match each other in anti and +// left semi filter joins, as required by SQL set operations (EXCEPT, +// INTERSECT). +TEST_P(MultiThreadedHashJoinTest, nullAsValueAntiJoin) { + // Probe: (1, null), (2, 3). Build: (1, null), (2, 4). + // With nullAsValue, (1, null) matches on both keys and is anti-joined away. + // (2, 3) does not match (2, 4) because 3 != 4, so it passes through. + auto probeVector = makeRowVector( + {"c0", "c1"}, + {makeNullableFlatVector({1, 2}), + makeNullableFlatVector({std::nullopt, 3})}); + auto buildVector = makeRowVector( + {"u0", "u1"}, + {makeNullableFlatVector({1, 2}), + makeNullableFlatVector({std::nullopt, 4})}); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({probeVector}) + .hashJoin( + {"c0", "c1"}, + {"u0", "u1"}, + PlanBuilder(planNodeIdGenerator).values({buildVector}).planNode(), + "", + {"c0", "c1"}, + core::JoinType::kAnti, + /*nullAware=*/false, + /*nullAsValue=*/true) + .planNode(); + AssertQueryBuilder(plan).assertResults(makeRowVector( + {"c0", "c1"}, + {makeNullableFlatVector({2}), + makeNullableFlatVector({3})})); +} + +TEST_P(MultiThreadedHashJoinTest, nullAsValueSemiJoin) { + // Probe: (1, null), (2, 3). Build: (1, null), (2, 4). + // With nullAsValue, (1, null) matches and passes through semi join. + // (2, 3) does not match (2, 4), so it is filtered out. + auto probeVector = makeRowVector( + {"c0", "c1"}, + {makeNullableFlatVector({1, 2}), + makeNullableFlatVector({std::nullopt, 3})}); + auto buildVector = makeRowVector( + {"u0", "u1"}, + {makeNullableFlatVector({1, 2}), + makeNullableFlatVector({std::nullopt, 4})}); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({probeVector}) + .hashJoin( + {"c0", "c1"}, + {"u0", "u1"}, + PlanBuilder(planNodeIdGenerator).values({buildVector}).planNode(), + "", + {"c0", "c1"}, + core::JoinType::kLeftSemiFilter, + /*nullAware=*/false, + /*nullAsValue=*/true) + .planNode(); + AssertQueryBuilder(plan).assertResults(makeRowVector( + {"c0", "c1"}, + {makeNullableFlatVector({1}), + makeNullableFlatVector({std::nullopt})})); +} + TEST_P(MultiThreadedHashJoinTest, rightSemiJoinFilterWithLargeOutput) { // Build the identical left and right vectors to generate large join // outputs. @@ -401,6 +585,7 @@ TEST_P(MultiThreadedHashJoinTest, rightSemiJoinFilterWithLargeOutput) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"u0"}) @@ -452,6 +637,7 @@ TEST_P(MultiThreadedHashJoinTest, arrayBasedLookup) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"c0"}) @@ -518,6 +704,7 @@ TEST_P(MultiThreadedHashJoinTest, joinSidesDifferentSchema) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t_c0"}) .probeVectors(std::move(probeVectors)) .probeProjections({"c0 AS t_c0", "c1 AS t_c1", "c2 AS t_c2"}) @@ -557,6 +744,7 @@ TEST_P(MultiThreadedHashJoinTest, innerJoinWithEmptyBuild) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"c0"}) @@ -591,6 +779,7 @@ TEST_P(MultiThreadedHashJoinTest, innerJoinWithEmptyBuild) { TEST_P(MultiThreadedHashJoinTest, leftSemiJoinFilter) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeType(probeType_) .probeVectors(174, 5) .probeKeys({"t_k1"}) @@ -627,6 +816,7 @@ TEST_P(MultiThreadedHashJoinTest, leftSemiJoinFilterWithEmptyBuild) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"c0"}) @@ -668,6 +858,7 @@ TEST_P(MultiThreadedHashJoinTest, leftSemiJoinFilterWithExtraFilter) { auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"u0"}) @@ -684,6 +875,7 @@ TEST_P(MultiThreadedHashJoinTest, leftSemiJoinFilterWithExtraFilter) { auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"u0"}) @@ -700,6 +892,7 @@ TEST_P(MultiThreadedHashJoinTest, leftSemiJoinFilterWithExtraFilter) { TEST_P(MultiThreadedHashJoinTest, rightSemiJoinFilter) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeType(probeType_) .probeVectors(133, 3) .probeKeys({"t_k1"}) @@ -741,6 +934,7 @@ TEST_P(MultiThreadedHashJoinTest, rightSemiJoinFilterWithEmptyBuild) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"u0"}) @@ -798,6 +992,7 @@ TEST_P(MultiThreadedHashJoinTest, rightSemiJoinFilterWithAllMatches) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"u0"}) @@ -833,6 +1028,7 @@ TEST_P(MultiThreadedHashJoinTest, rightSemiJoinFilterWithExtraFilter) { auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"u0"}) @@ -855,6 +1051,7 @@ TEST_P(MultiThreadedHashJoinTest, rightSemiJoinFilterWithExtraFilter) { auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"u0"}) @@ -876,6 +1073,7 @@ TEST_P(MultiThreadedHashJoinTest, rightSemiJoinFilterWithExtraFilter) { auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"u0"}) @@ -941,23 +1139,21 @@ TEST_P(MultiThreadedHashJoinTest, semiFilterOverLazyVectors) { core::JoinType::kLeftSemiFilter) .planNode(); - SplitInput splitInput = { - {probeScanId, - {exec::Split(makeHiveConnectorSplit(probeFile->getPath()))}}, - {buildScanId, - {exec::Split(makeHiveConnectorSplit(buildFile->getPath()))}}, + SplitPath splitPaths = { + {probeScanId, {probeFile->getPath()}}, + {buildScanId, {buildFile->getPath()}}, }; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .planNode(plan) - .inputSplits(splitInput) + .inputSplits(splitPaths) .checkSpillStats(false) .referenceQuery("SELECT t0, t1 FROM t WHERE t0 IN (SELECT u0 FROM u)") .run(); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .planNode(flipJoinSides(plan)) - .inputSplits(splitInput) + .inputSplits(splitPaths) .checkSpillStats(false) .referenceQuery("SELECT t0, t1 FROM t WHERE t0 IN (SELECT u0 FROM u)") .run(); @@ -981,7 +1177,7 @@ TEST_P(MultiThreadedHashJoinTest, semiFilterOverLazyVectors) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .planNode(plan) - .inputSplits(splitInput) + .inputSplits(splitPaths) .checkSpillStats(false) .referenceQuery( "SELECT t0, t1 FROM t WHERE t0 IN (SELECT u0 FROM u WHERE (t1 + u1) % 3 = 0)") @@ -989,7 +1185,7 @@ TEST_P(MultiThreadedHashJoinTest, semiFilterOverLazyVectors) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .planNode(flipJoinSides(plan)) - .inputSplits(splitInput) + .inputSplits(splitPaths) .checkSpillStats(false) .referenceQuery( "SELECT t0, t1 FROM t WHERE t0 IN (SELECT u0 FROM u WHERE (t1 + u1) % 3 = 0)") @@ -1019,6 +1215,7 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoin) { auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"c0"}) @@ -1039,6 +1236,7 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoin) { auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"c0"}) @@ -1059,6 +1257,7 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoin) { auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"c0"}) @@ -1096,6 +1295,7 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithFilter) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"u0"}) @@ -1150,6 +1350,7 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithFilterAndEmptyBuild) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::vector(probeVectors)) .buildKeys({"u0"}) @@ -1209,6 +1410,7 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithFilterAndNullKey) { auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"u0"}) @@ -1269,6 +1471,7 @@ TEST_P( auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"u0"}) @@ -1307,6 +1510,7 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithFilterOnNullableColumn) { }); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"u0"}) @@ -1357,6 +1561,7 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithFilterOnNullableColumn) { }); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"u0"}) @@ -1386,6771 +1591,223 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithFilterOnNullableColumn) { } } -TEST_P(MultiThreadedHashJoinTest, antiJoin) { - auto probeVectors = makeBatches(64, [&](int32_t /*unused*/) { +TEST_P( + MultiThreadedHashJoinTest, + nullAwareAntiJoinWithFilterBatchedEvaluation) { + // Use >1024 build rows to trigger multiple batches in + // applyFilterOnTableRowsForNullAwareJoin (kBatchSize is 1024), exercising the + // per-batch deselect of filterPassedRows from rows. Include null probe keys + // so that crossJoinProbeRows is non-empty and the cross-join path iterates + // all 2048 build rows across 2 batches. + auto probeVectors = makeBatches(1, [&](int32_t /*unused*/) { return makeRowVector( {"t0", "t1"}, { - makeNullableFlatVector({std::nullopt, 1, 2}), - makeFlatVector({0, 1, 2}), + makeFlatVector( + 256, + [](auto row) { return row % 50; }, + [](auto row) { return row < 4; }), + makeFlatVector(256, [](auto row) { return row; }), }); }); - auto buildVectors = makeBatches(64, [&](int32_t /*unused*/) { + auto buildVectors = makeBatches(1, [&](int32_t /*unused*/) { return makeRowVector( {"u0", "u1"}, { - makeNullableFlatVector({std::nullopt, 2, 3}), - makeFlatVector({0, 2, 3}), + makeFlatVector(2048, [](auto row) { return row % 25; }), + makeFlatVector(2048, [](auto row) { return row * 2; }), }); }); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) .probeKeys({"t0"}) - .probeVectors(std::vector(probeVectors)) + .probeVectors(std::move(probeVectors)) .buildKeys({"u0"}) - .buildVectors(std::vector(buildVectors)) + .buildVectors(std::move(buildVectors)) .joinType(core::JoinType::kAnti) + .nullAware(true) + .joinFilter("t1 <> u1") .joinOutputLayout({"t0", "t1"}) .referenceQuery( - "SELECT t.* FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE u.u0 = t.t0)") + "SELECT t.* FROM t WHERE t0 NOT IN (SELECT u0 FROM u WHERE t1 <> u1)") + .checkSpillStats(false) .run(); - - std::vector filters({ - "u1 > t1", - "u1 * t1 > 0", - // This filter is true on rows without a match. It should not prevent - // the row from being returned. - "coalesce(u1, t1, 0::integer) is not null", - // This filter throws if evaluated on rows without a match. The join - // should not evaluate filter on those rows and therefore should not - // fail. - "t1 / coalesce(u1, 0::integer) is not null", - // This filter triggers memory pool allocation at - // HashBuild::setupFilterForAntiJoins, which should not be invoked in - // operator's constructor. - "contains(array[1, 2, NULL], 1)", - }); - for (const std::string& filter : filters) { - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .probeKeys({"t0"}) - .probeVectors(std::vector(probeVectors)) - .buildKeys({"u0"}) - .buildVectors(std::vector(buildVectors)) - .joinType(core::JoinType::kAnti) - .joinFilter(filter) - .joinOutputLayout({"t0", "t1"}) - .referenceQuery(fmt::format( - "SELECT t.* FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE u.u0 = t.t0 AND {})", - filter)) - .run(); - } -} - -TEST_P(MultiThreadedHashJoinTest, antiJoinWithFilterAndEmptyBuild) { - const std::vector finishOnEmptys = {false, true}; - for (const auto finishOnEmpty : finishOnEmptys) { - SCOPED_TRACE(fmt::format("finishOnEmpty: {}", finishOnEmpty)); - - auto probeVectors = makeBatches(4, [&](int32_t /*unused*/) { - return makeRowVector( - {"t0", "t1"}, - { - makeNullableFlatVector({std::nullopt, 1, 2}), - makeFlatVector({0, 1, 2}), - }); - }); - auto buildVectors = makeBatches(4, [&](int32_t /*unused*/) { - return makeRowVector( - {"u0", "u1"}, - { - makeNullableFlatVector({3, 2, 3}), - makeFlatVector({0, 2, 3}), - }); - }); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) - .numDrivers(numDrivers_) - .probeKeys({"t0"}) - .probeVectors(std::vector(probeVectors)) - .buildKeys({"u0"}) - .buildVectors(std::vector(buildVectors)) - .buildFilter("u0 < 0") - .joinType(core::JoinType::kAnti) - .joinFilter("u1 > t1") - .joinOutputLayout({"t0", "t1"}) - .referenceQuery( - "SELECT t.* FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE u0 < 0 AND u.u0 = t.t0)") - .checkSpillStats(false) - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - const auto statsPair = taskSpilledStats(*task); - ASSERT_EQ(statsPair.first.spilledRows, 0); - ASSERT_EQ(statsPair.first.spilledBytes, 0); - ASSERT_EQ(statsPair.first.spilledPartitions, 0); - ASSERT_EQ(statsPair.first.spilledFiles, 0); - ASSERT_EQ(statsPair.second.spilledRows, 0); - ASSERT_EQ(statsPair.second.spilledBytes, 0); - ASSERT_EQ(statsPair.second.spilledPartitions, 0); - ASSERT_EQ(statsPair.second.spilledFiles, 0); - verifyTaskSpilledRuntimeStats(*task, false); - ASSERT_EQ(maxHashBuildSpillLevel(*task), -1); - }) - .run(); - } } -TEST_P(MultiThreadedHashJoinTest, leftJoin) { - // Left side keys are [0, 1, 2,..20]. - // Use 3-rd column as row number to allow for asserting the order of - // results. - std::vector probeVectors = mergeBatches( - makeBatches( - 3, - [&](int32_t /*unused*/) { - return makeRowVector( - {"c0", "c1", "row_number"}, - { - makeFlatVector( - 77, [](auto row) { return row % 21; }, nullEvery(13)), - makeFlatVector(77, [](auto row) { return row; }), - makeFlatVector(77, [](auto row) { return row; }), - }); - }), - makeBatches( - 2, - [&](int32_t /*unused*/) { - return makeRowVector( - {"c0", "c1", "row_number"}, - { - makeFlatVector( - 97, - [](auto row) { return (row + 3) % 21; }, - nullEvery(13)), - makeFlatVector(97, [](auto row) { return row; }), - makeFlatVector( - 97, [](auto row) { return 97 + row; }), - }); - }), - true); - - std::vector buildVectors = - makeBatches(3, [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 73, [](auto row) { return row % 5; }, nullEvery(7)), - makeFlatVector( - 73, [](auto row) { return -111 + row * 2; }, nullEvery(7)), +TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithFilterEarlyTermination) { + auto probeVectors = makeBatches(1, [&](int32_t /*unused*/) { + return makeRowVector( + {"t0", "t1"}, + { + makeFlatVector(100, [](auto row) { return row % 10; }), + makeFlatVector(100, [](auto row) { return row; }), }); - }); + }); + auto buildVectors = makeBatches(1, [&](int32_t /*unused*/) { + return makeRowVector( + {"u0", "u1"}, + { + makeFlatVector(500, [](auto row) { return row % 10; }), + makeFlatVector(500, [](auto row) { return row; }), + }); + }); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) - .probeKeys({"c0"}) + .probeKeys({"t0"}) .probeVectors(std::move(probeVectors)) - .buildKeys({"u_c0"}) + .buildKeys({"u0"}) .buildVectors(std::move(buildVectors)) - .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) - .joinType(core::JoinType::kLeft) - .joinOutputLayout({"row_number", "c0", "c1", "u_c0"}) + .joinType(core::JoinType::kAnti) + .nullAware(true) + .joinFilter("t1 < u1") + .joinOutputLayout({"t0", "t1"}) .referenceQuery( - "SELECT t.row_number, t.c0, t.c1, u.c0 FROM t LEFT JOIN u ON t.c0 = u.c0") - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - int nullJoinBuildKeyCount = 0; - int nullJoinProbeKeyCount = 0; - - for (auto& pipeline : task->taskStats().pipelineStats) { - for (auto op : pipeline.operatorStats) { - if (op.operatorType == "HashBuild") { - nullJoinBuildKeyCount += op.numNullKeys; - } - if (op.operatorType == "HashProbe") { - nullJoinProbeKeyCount += op.numNullKeys; - } - } - } - ASSERT_EQ(nullJoinBuildKeyCount, 33 * GetParam().numDrivers); - ASSERT_EQ(nullJoinProbeKeyCount, 34 * GetParam().numDrivers); - }) + "SELECT t.* FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE t0 = u0 AND t1 < u1)") + .checkSpillStats(false) .run(); } -TEST_P(MultiThreadedHashJoinTest, nullStatsWithEmptyBuild) { - std::vector probeVectors = - makeBatches(1, [&](int32_t /*unused*/) { - return makeRowVector( - {"c0", "c1", "row_number"}, - { - makeFlatVector( - 77, [](auto row) { return row % 21; }, nullEvery(13)), - makeFlatVector(77, [](auto row) { return row; }), - makeFlatVector(77, [](auto row) { return row; }), - }); - }); - - // All null keys on build side. - std::vector buildVectors = - makeBatches(1, [&](int32_t /*unused*/) { - return makeRowVector({ +TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithFilterMixedNulls) { + auto probeVectors = makeBatches(2, [&](int32_t batch) { + return makeRowVector( + {"t0", "t1"}, + { + makeNullableFlatVector( + {std::nullopt, + 1, + 2, + std::nullopt, + 4, + 5, + 6, + std::nullopt, + 8, + 9}), makeFlatVector( - 1, [](auto row) { return row % 5; }, nullEvery(1)), + 10, [batch](auto row) { return batch * 10 + row; }), + }); + }); + auto buildVectors = makeBatches(2, [&](int32_t batch) { + return makeRowVector( + {"u0", "u1"}, + { + makeNullableFlatVector( + {1, + std::nullopt, + 3, + 4, + std::nullopt, + 6, + 7, + 8, + std::nullopt, + 10}), makeFlatVector( - 1, [](auto row) { return -111 + row * 2; }, nullEvery(1)), + 10, [batch](auto row) { return batch * 5 + row; }), }); - }); + }); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) - .probeKeys({"c0"}) + .probeKeys({"t0"}) .probeVectors(std::move(probeVectors)) - .buildKeys({"u_c0"}) + .buildKeys({"u0"}) .buildVectors(std::move(buildVectors)) - .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) - .joinType(core::JoinType::kLeft) - .joinOutputLayout({"row_number", "c0", "c1", "u_c0"}) + .joinType(core::JoinType::kAnti) + .nullAware(true) + .joinFilter("t1 + u1 < 50") + .joinOutputLayout({"t0", "t1"}) .referenceQuery( - "SELECT t.row_number, t.c0, t.c1, u.c0 FROM t LEFT JOIN u ON t.c0 = u.c0") - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - int nullJoinBuildKeyCount = 0; - int nullJoinProbeKeyCount = 0; - - for (auto& pipeline : task->taskStats().pipelineStats) { - for (auto op : pipeline.operatorStats) { - if (op.operatorType == "HashBuild") { - nullJoinBuildKeyCount += op.numNullKeys; - } - if (op.operatorType == "HashProbe") { - nullJoinProbeKeyCount += op.numNullKeys; - } - } - } - // Due to inaccurate stats tracking in case of empty build side, - // we will report 0 null keys on probe side. - ASSERT_EQ(nullJoinProbeKeyCount, 0); - ASSERT_EQ(nullJoinBuildKeyCount, 1 * GetParam().numDrivers); - }) + "SELECT t.* FROM t WHERE t0 NOT IN (SELECT u0 FROM u WHERE t1 + u1 < 50)") .checkSpillStats(false) .run(); } -TEST_P(MultiThreadedHashJoinTest, leftJoinWithEmptyBuild) { - const std::vector finishOnEmptys = {false, true}; - for (const auto finishOnEmpty : finishOnEmptys) { - SCOPED_TRACE(fmt::format("finishOnEmpty: {}", finishOnEmpty)); - - // Left side keys are [0, 1, 2,..10]. - // Use 3-rd column as row number to allow for asserting the order of - // results. - std::vector probeVectors = mergeBatches( - makeBatches( - 3, - [&](int32_t /*unused*/) { - return makeRowVector( - {"c0", "c1", "row_number"}, - { - makeFlatVector( - 77, [](auto row) { return row % 11; }, nullEvery(13)), - makeFlatVector(77, [](auto row) { return row; }), - makeFlatVector(77, [](auto row) { return row; }), - }); - }), - makeBatches( - 2, - [&](int32_t /*unused*/) { - return makeRowVector( - {"c0", "c1", "row_number"}, - { - makeFlatVector( - 97, - [](auto row) { return (row + 3) % 11; }, - nullEvery(13)), - makeFlatVector(97, [](auto row) { return row; }), - makeFlatVector( - 97, [](auto row) { return 97 + row; }), - }); - }), - true); - - std::vector buildVectors = - makeBatches(3, [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 73, [](auto row) { return row % 5; }, nullEvery(7)), - makeFlatVector( - 73, [](auto row) { return -111 + row * 2; }, nullEvery(7)), - }); +TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithFilterAllNullProbeKeys) { + auto probeVectors = makeBatches(1, [&](int32_t /*unused*/) { + return makeRowVector( + {"t0", "t1"}, + { + makeNullConstant(TypeKind::INTEGER, 64), + makeFlatVector(64, [](auto row) { return row; }), }); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) - .numDrivers(numDrivers_) - .probeKeys({"c0"}) - .probeVectors(std::move(probeVectors)) - .buildKeys({"u_c0"}) - .buildVectors(std::move(buildVectors)) - .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) - .buildFilter("c0 < 0") - .joinType(core::JoinType::kLeft) - .joinOutputLayout({"row_number", "c1"}) - .referenceQuery( - "SELECT t.row_number, t.c1 FROM t LEFT JOIN (SELECT c0 FROM u WHERE c0 < 0) u ON t.c0 = u.c0") - .checkSpillStats(false) - .run(); - } -} - -TEST_P(MultiThreadedHashJoinTest, leftJoinWithNoJoin) { - // Left side keys are [0, 1, 2,..10]. - // Use 3-rd column as row number to allow for asserting the order of - // results. - std::vector probeVectors = mergeBatches( - makeBatches( - 3, - [&](int32_t /*unused*/) { - return makeRowVector( - {"c0", "c1", "row_number"}, - { - makeFlatVector( - 77, [](auto row) { return row % 11; }, nullEvery(13)), - makeFlatVector(77, [](auto row) { return row; }), - makeFlatVector(77, [](auto row) { return row; }), - }); - }), - makeBatches( - 2, - [&](int32_t /*unused*/) { - return makeRowVector( - {"c0", "c1", "row_number"}, - { - makeFlatVector( - 97, - [](auto row) { return (row + 3) % 11; }, - nullEvery(13)), - makeFlatVector(97, [](auto row) { return row; }), - makeFlatVector( - 97, [](auto row) { return 97 + row; }), - }); - }), - true); - - std::vector buildVectors = - makeBatches(3, [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 73, [](auto row) { return row % 5; }, nullEvery(7)), - makeFlatVector( - 73, [](auto row) { return -111 + row * 2; }, nullEvery(7)), + }); + auto buildVectors = makeBatches(1, [&](int32_t /*unused*/) { + return makeRowVector( + {"u0", "u1"}, + { + makeFlatVector(32, [](auto row) { return row; }), + makeFlatVector(32, [](auto row) { return row * 3; }), }); - }); + }); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) - .probeKeys({"c0"}) + .probeKeys({"t0"}) .probeVectors(std::move(probeVectors)) - .buildKeys({"u_c0"}) + .buildKeys({"u0"}) .buildVectors(std::move(buildVectors)) - .buildProjections({"c0 - 123::INTEGER AS u_c0", "c1 AS u_c1"}) - .joinType(core::JoinType::kLeft) - .joinOutputLayout({"row_number", "c0", "u_c1"}) + .joinType(core::JoinType::kAnti) + .nullAware(true) + .joinFilter("t1 <> u1") + .joinOutputLayout({"t0", "t1"}) .referenceQuery( - "SELECT t.row_number, t.c0, u.c1 FROM t LEFT JOIN (SELECT c0 - 123::INTEGER AS u_c0, c1 FROM u) u ON t.c0 = u.u_c0") + "SELECT t.* FROM t WHERE t0 NOT IN (SELECT u0 FROM u WHERE t1 <> u1)") + .checkSpillStats(false) .run(); } -TEST_P(MultiThreadedHashJoinTest, leftJoinWithAllMatch) { - // Left side keys are [0, 1, 2,..10]. - // Use 3-rd column as row number to allow for asserting the order of - // results. - std::vector probeVectors = mergeBatches( - makeBatches( - 3, - [&](int32_t /*unused*/) { - return makeRowVector( - {"c0", "c1", "row_number"}, - { - makeFlatVector( - 77, [](auto row) { return row % 11; }, nullEvery(13)), - makeFlatVector(77, [](auto row) { return row; }), - makeFlatVector(77, [](auto row) { return row; }), - }); - }), - makeBatches( - 2, - [&](int32_t /*unused*/) { - return makeRowVector( - {"c0", "c1", "row_number"}, - { - makeFlatVector( - 97, - [](auto row) { return (row + 3) % 11; }, - nullEvery(13)), - makeFlatVector(97, [](auto row) { return row; }), - makeFlatVector( - 97, [](auto row) { return 97 + row; }), - }); - }), - true); - - std::vector buildVectors = - makeBatches(3, [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 73, [](auto row) { return row % 5; }, nullEvery(7)), - makeFlatVector( - 73, [](auto row) { return -111 + row * 2; }, nullEvery(7)), +TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithFilterEmptyBatch) { + auto probeVectors = makeBatches(1, [&](int32_t /*unused*/) { + return makeRowVector( + {"t0", "t1"}, + { + makeFlatVector(32, [](auto row) { return row; }), + makeFlatVector(32, [](auto row) { return row; }), }); - }); + }); + auto buildVectors = makeBatches(1, [&](int32_t /*unused*/) { + return makeRowVector( + {"u0", "u1"}, + { + makeFlatVector(32, [](auto row) { return row; }), + makeFlatVector(32, [](auto row) { return 1000 + row; }), + }); + }); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) - .probeKeys({"c0"}) + .probeKeys({"t0"}) .probeVectors(std::move(probeVectors)) - .probeFilter("c0 < 5") - .buildKeys({"u_c0"}) + .buildKeys({"u0"}) .buildVectors(std::move(buildVectors)) - .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) - .joinType(core::JoinType::kLeft) - .joinOutputLayout({"row_number", "c0", "c1", "u_c1"}) + .joinType(core::JoinType::kAnti) + .nullAware(true) + .joinFilter("t1 > u1") + .joinOutputLayout({"t0", "t1"}) .referenceQuery( - "SELECT t.row_number, t.c0, t.c1, u.c1 FROM (SELECT * FROM t WHERE c0 < 5) t LEFT JOIN u ON t.c0 = u.c0") + "SELECT t.* FROM t WHERE t0 NOT IN (SELECT u0 FROM u WHERE t1 > u1)") + .checkSpillStats(false) .run(); } -TEST_P(MultiThreadedHashJoinTest, leftJoinWithFilter) { - // Left side keys are [0, 1, 2,..10]. - // Use 3-rd column as row number to allow for asserting the order of - // results. - std::vector probeVectors = mergeBatches( - makeBatches( - 3, - [&](int32_t /*unused*/) { - return makeRowVector( - {"c0", "c1", "row_number"}, - { - makeFlatVector( - 77, [](auto row) { return row % 11; }, nullEvery(13)), - makeFlatVector(77, [](auto row) { return row; }), - makeFlatVector(77, [](auto row) { return row; }), - }); - }), - makeBatches( - 2, - [&](int32_t /*unused*/) { - return makeRowVector( - {"c0", "c1", "row_number"}, - { - makeFlatVector( - 97, - [](auto row) { return (row + 3) % 11; }, - nullEvery(13)), - makeFlatVector(97, [](auto row) { return row; }), - makeFlatVector( - 97, [](auto row) { return 97 + row; }), - }); - }), - true); +VELOX_INSTANTIATE_TEST_SUITE_P( + MultiThreadedHashJoinTest, + MultiThreadedHashJoinTest, + testing::ValuesIn(MultiThreadedHashJoinTest::getTestParams()), + [](const testing::TestParamInfo& info) { + return TestParamToName(info.param); + }); - std::vector buildVectors = - makeBatches(3, [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 73, [](auto row) { return row % 5; }, nullEvery(7)), - makeFlatVector( - 73, [](auto row) { return -111 + row * 2; }, nullEvery(7)), - }); - }); - - // Additional filter. - { - auto testProbeVectors = probeVectors; - auto testBuildVectors = buildVectors; - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .probeKeys({"c0"}) - .probeVectors(std::move(testProbeVectors)) - .buildKeys({"u_c0"}) - .buildVectors(std::move(testBuildVectors)) - .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) - .joinType(core::JoinType::kLeft) - .joinFilter("(c1 + u_c1) % 2 = 1") - .joinOutputLayout({"row_number", "c0", "c1", "u_c1"}) - .referenceQuery( - "SELECT t.row_number, t.c0, t.c1, u.c1 FROM t LEFT JOIN u ON t.c0 = u.c0 AND (t.c1 + u.c1) % 2 = 1") - .run(); - } - - // No rows pass the additional filter. - { - auto testProbeVectors = probeVectors; - auto testBuildVectors = buildVectors; - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .probeKeys({"c0"}) - .probeVectors(std::move(testProbeVectors)) - .buildKeys({"u_c0"}) - .buildVectors(std::move(testBuildVectors)) - .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) - .joinType(core::JoinType::kLeft) - .joinFilter("(c1 + u_c1) % 2 = 3") - .joinOutputLayout({"row_number", "c0", "c1", "u_c1"}) - .referenceQuery( - "SELECT t.row_number, t.c0, t.c1, u.c1 FROM t LEFT JOIN u ON t.c0 = u.c0 AND (t.c1 + u.c1) % 2 = 3") - .run(); - } -} - -/// Tests left join with a filter that may evaluate to true, false or null. -/// Makes sure that null filter results are handled correctly, e.g. as if the -/// filter returned false. -TEST_P(MultiThreadedHashJoinTest, leftJoinWithNullableFilter) { - std::vector probeVectors = mergeBatches( - makeBatches( - 5, - [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector({1, 2, 3, 4, 5}), - makeNullableFlatVector( - {10, std::nullopt, 30, std::nullopt, 50}), - }); - }), - makeBatches( - 5, - [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector({1, 2, 3, 4, 5}), - makeNullableFlatVector( - {std::nullopt, 20, 30, std::nullopt, 50}), - }); - }), - true); - - std::vector buildVectors = - makeBatches(5, [&](int32_t /*unused*/) { - return makeRowVector( - {makeFlatVector(128, [](vector_size_t row) { - if (row < 3) { - return row; - } - return row + 10; - })}); - }); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .probeKeys({"c0"}) - .probeVectors(std::move(probeVectors)) - .buildKeys({"u_c0"}) - .buildVectors(std::move(buildVectors)) - .buildProjections({"c0 AS u_c0"}) - .joinType(core::JoinType::kLeft) - .joinFilter("c1 + u_c0 > 0") - .joinOutputLayout({"c0", "c1", "u_c0"}) - .referenceQuery( - "SELECT * FROM t LEFT JOIN u ON (t.c0 = u.c0 AND t.c1 + u.c0 > 0)") - .run(); -} - -TEST_P(MultiThreadedHashJoinTest, rightJoin) { - // Left side keys are [0, 1, 2,..20]. - std::vector probeVectors = mergeBatches( - makeBatches( - 3, - [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 137, [](auto row) { return row % 21; }, nullEvery(13)), - makeFlatVector(137, [](auto row) { return row; }), - }); - }), - makeBatches( - 3, - [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 234, - [](auto row) { return (row + 3) % 21; }, - nullEvery(13)), - makeFlatVector(234, [](auto row) { return row; }), - }); - }), - true); - - // Right side keys are [-3, -2, -1, 0, 1, 2, 3]. - std::vector buildVectors = - makeBatches(3, [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 123, [](auto row) { return -3 + row % 7; }, nullEvery(11)), - makeFlatVector( - 123, [](auto row) { return -111 + row * 2; }, nullEvery(13)), - }); - }); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .probeKeys({"c0"}) - .probeVectors(std::move(probeVectors)) - .buildKeys({"u_c0"}) - .buildVectors(std::move(buildVectors)) - .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) - .joinType(core::JoinType::kRight) - .joinOutputLayout({"c0", "c1", "u_c1"}) - .referenceQuery( - "SELECT t.c0, t.c1, u.c1 FROM t RIGHT JOIN u ON t.c0 = u.c0") - .run(); -} - -TEST_P(MultiThreadedHashJoinTest, rightJoinWithEmptyBuild) { - const std::vector finishOnEmptys = {false, true}; - for (const auto finishOnEmpty : finishOnEmptys) { - SCOPED_TRACE(fmt::format("finishOnEmpty: {}", finishOnEmpty)); - - // Left side keys are [0, 1, 2,..10]. - std::vector probeVectors = mergeBatches( - makeBatches( - 3, - [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 137, [](auto row) { return row % 11; }, nullEvery(13)), - makeFlatVector(137, [](auto row) { return row; }), - }); - }), - makeBatches( - 3, - [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 234, - [](auto row) { return (row + 3) % 11; }, - nullEvery(13)), - makeFlatVector(234, [](auto row) { return row; }), - }); - }), - true); - - // Right side keys are [-3, -2, -1, 0, 1, 2, 3]. - std::vector buildVectors = - makeBatches(3, [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 123, [](auto row) { return -3 + row % 7; }, nullEvery(11)), - makeFlatVector( - 123, [](auto row) { return -111 + row * 2; }, nullEvery(13)), - }); - }); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) - .numDrivers(numDrivers_) - .probeKeys({"c0"}) - .probeVectors(std::move(probeVectors)) - .buildKeys({"u_c0"}) - .buildVectors(std::move(buildVectors)) - .buildFilter("c0 > 100") - .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) - .joinType(core::JoinType::kRight) - .joinOutputLayout({"c1"}) - .referenceQuery("SELECT null LIMIT 0") - .checkSpillStats(false) - .run(); - } -} - -TEST_P(MultiThreadedHashJoinTest, rightJoinWithAllMatch) { - // Left side keys are [0, 1, 2,..20]. - std::vector probeVectors = mergeBatches( - makeBatches( - 3, - [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 137, [](auto row) { return row % 21; }, nullEvery(13)), - makeFlatVector(137, [](auto row) { return row; }), - }); - }), - makeBatches( - 3, - [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 234, - [](auto row) { return (row + 3) % 21; }, - nullEvery(13)), - makeFlatVector(234, [](auto row) { return row; }), - }); - }), - true); - - // Right side keys are [-3, -2, -1, 0, 1, 2, 3]. - std::vector buildVectors = - makeBatches(3, [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 123, [](auto row) { return -3 + row % 7; }, nullEvery(11)), - makeFlatVector( - 123, [](auto row) { return -111 + row * 2; }, nullEvery(13)), - }); - }); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .probeKeys({"c0"}) - .probeVectors(std::move(probeVectors)) - .buildKeys({"u_c0"}) - .buildVectors(std::move(buildVectors)) - .buildFilter("c0 >= 0") - .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) - .joinType(core::JoinType::kRight) - .joinOutputLayout({"c0", "c1", "u_c1"}) - .referenceQuery( - "SELECT t.c0, t.c1, u.c1 FROM t RIGHT JOIN (SELECT * FROM u WHERE c0 >= 0) u ON t.c0 = u.c0") - .run(); -} - -TEST_P(MultiThreadedHashJoinTest, rightJoinWithFilter) { - // Left side keys are [0, 1, 2,..20]. - std::vector probeVectors = mergeBatches( - makeBatches( - 3, - [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 137, [](auto row) { return row % 21; }, nullEvery(13)), - makeFlatVector(137, [](auto row) { return row; }), - }); - }), - makeBatches( - 3, - [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 234, - [](auto row) { return (row + 3) % 21; }, - nullEvery(13)), - makeFlatVector(234, [](auto row) { return row; }), - }); - }), - true); - - // Right side keys are [-3, -2, -1, 0, 1, 2, 3]. - std::vector buildVectors = - makeBatches(3, [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 123, [](auto row) { return -3 + row % 7; }, nullEvery(11)), - makeFlatVector( - 123, [](auto row) { return -111 + row * 2; }, nullEvery(13)), - }); - }); - - // Filter with passed rows. - { - auto testProbeVectors = probeVectors; - auto testBuildVectors = buildVectors; - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .probeKeys({"c0"}) - .probeVectors(std::move(testProbeVectors)) - .buildKeys({"u_c0"}) - .buildVectors(std::move(testBuildVectors)) - .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) - .joinType(core::JoinType::kRight) - .joinFilter("(c1 + u_c1) % 2 = 1") - .joinOutputLayout({"c0", "c1", "u_c1"}) - .referenceQuery( - "SELECT t.c0, t.c1, u.c1 FROM t RIGHT JOIN u ON t.c0 = u.c0 AND (t.c1 + u.c1) % 2 = 1") - .run(); - } - - // Filter without passed rows. - { - auto testProbeVectors = probeVectors; - auto testBuildVectors = buildVectors; - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .probeKeys({"c0"}) - .probeVectors(std::move(testProbeVectors)) - .buildKeys({"u_c0"}) - .buildVectors(std::move(testBuildVectors)) - .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) - .joinType(core::JoinType::kRight) - .joinFilter("(c1 + u_c1) % 2 = 3") - .joinOutputLayout({"c0", "c1", "u_c1"}) - .referenceQuery( - "SELECT t.c0, t.c1, u.c1 FROM t RIGHT JOIN u ON t.c0 = u.c0 AND (t.c1 + u.c1) % 2 = 3") - .run(); - } -} - -TEST_P(MultiThreadedHashJoinTest, fullJoin) { - // Left side keys are [0, 1, 2,..20]. - std::vector probeVectors = mergeBatches( - makeBatches( - 3, - [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 213, [](auto row) { return row % 21; }, nullEvery(13)), - makeFlatVector(213, [](auto row) { return row; }), - }); - }), - makeBatches( - 2, - [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 137, - [](auto row) { return (row + 3) % 21; }, - nullEvery(13)), - makeFlatVector(137, [](auto row) { return row; }), - }); - }), - true); - - // Right side keys are [-3, -2, -1, - // 0, 1, 2, 3]. - std::vector buildVectors = - makeBatches(3, [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 123, [](auto row) { return -3 + row % 7; }, nullEvery(11)), - makeFlatVector( - 123, [](auto row) { return -111 + row * 2; }, nullEvery(13)), - }); - }); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .probeKeys({"c0"}) - .probeVectors(std::move(probeVectors)) - .buildKeys({"u_c0"}) - .buildVectors(std::move(buildVectors)) - .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) - .joinType(core::JoinType::kFull) - .joinOutputLayout({"c0", "c1", "u_c1"}) - .referenceQuery( - "SELECT t.c0, t.c1, u.c1 FROM t FULL OUTER JOIN u ON t.c0 = u.c0") - .run(); -} - -TEST_P(MultiThreadedHashJoinTest, fullJoinWithEmptyBuild) { - const std::vector finishOnEmptys = {false, true}; - for (const auto finishOnEmpty : finishOnEmptys) { - SCOPED_TRACE(fmt::format("finishOnEmpty: {}", finishOnEmpty)); - - // Left side keys are [0, 1, 2,..10]. - std::vector probeVectors = mergeBatches( - makeBatches( - 3, - [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 213, [](auto row) { return row % 11; }, nullEvery(13)), - makeFlatVector(213, [](auto row) { return row; }), - }); - }), - makeBatches( - 2, - [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 137, - [](auto row) { return (row + 3) % 11; }, - nullEvery(13)), - makeFlatVector(137, [](auto row) { return row; }), - }); - }), - true); - - // Right side keys are [-3, -2, -1, 0, 1, 2, 3]. - std::vector buildVectors = - makeBatches(3, [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 123, [](auto row) { return -3 + row % 7; }, nullEvery(11)), - makeFlatVector( - 123, [](auto row) { return -111 + row * 2; }, nullEvery(13)), - }); - }); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) - .numDrivers(numDrivers_) - .probeKeys({"c0"}) - .probeVectors(std::move(probeVectors)) - .buildKeys({"u_c0"}) - .buildVectors(std::move(buildVectors)) - .buildFilter("c0 > 100") - .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) - .joinType(core::JoinType::kFull) - .joinOutputLayout({"c1"}) - .referenceQuery( - "SELECT t.c1 FROM t FULL OUTER JOIN (SELECT * FROM u WHERE c0 > 100) u ON t.c0 = u.c0") - .checkSpillStats(false) - .run(); - } -} - -TEST_P(MultiThreadedHashJoinTest, fullJoinWithNoMatch) { - // Left side keys are [0, 1, 2,..10]. - std::vector probeVectors = mergeBatches( - makeBatches( - 3, - [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 213, [](auto row) { return row % 11; }, nullEvery(13)), - makeFlatVector(213, [](auto row) { return row; }), - }); - }), - makeBatches( - 2, - [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 137, - [](auto row) { return (row + 3) % 11; }, - nullEvery(13)), - makeFlatVector(137, [](auto row) { return row; }), - }); - }), - true); - - // Right side keys are [-3, -2, -1, 0, 1, 2, 3]. - std::vector buildVectors = - makeBatches(3, [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 123, [](auto row) { return -3 + row % 7; }, nullEvery(11)), - makeFlatVector( - 123, [](auto row) { return -111 + row * 2; }, nullEvery(13)), - }); - }); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .probeKeys({"c0"}) - .probeVectors(std::move(probeVectors)) - .buildKeys({"u_c0"}) - .buildVectors(std::move(buildVectors)) - .buildFilter("c0 < 0") - .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) - .joinType(core::JoinType::kFull) - .joinOutputLayout({"c1"}) - .referenceQuery( - "SELECT t.c1 FROM t FULL OUTER JOIN (SELECT * FROM u WHERE c0 < 0) u ON t.c0 = u.c0") - .run(); -} - -TEST_P(MultiThreadedHashJoinTest, fullJoinWithFilters) { - // Left side keys are [0, 1, 2,..10]. - std::vector probeVectors = mergeBatches( - makeBatches( - 3, - [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 213, [](auto row) { return row % 11; }, nullEvery(13)), - makeFlatVector(213, [](auto row) { return row; }), - }); - }), - makeBatches( - 2, - [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 137, - [](auto row) { return (row + 3) % 11; }, - nullEvery(13)), - makeFlatVector(137, [](auto row) { return row; }), - }); - }), - true); - - // Right side keys are [-3, -2, -1, 0, 1, 2, 3]. - std::vector buildVectors = - makeBatches(3, [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector( - 123, [](auto row) { return -3 + row % 7; }, nullEvery(11)), - makeFlatVector( - 123, [](auto row) { return -111 + row * 2; }, nullEvery(13)), - }); - }); - - // Filter with passed rows. - { - auto testProbeVectors = probeVectors; - auto testBuildVectors = buildVectors; - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .probeKeys({"c0"}) - .probeVectors(std::move(testProbeVectors)) - .buildKeys({"u_c0"}) - .buildVectors(std::move(testBuildVectors)) - .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) - .joinType(core::JoinType::kFull) - .joinFilter("(c1 + u_c1) % 2 = 1") - .joinOutputLayout({"c0", "c1", "u_c1"}) - .referenceQuery( - "SELECT t.c0, t.c1, u.c1 FROM t FULL OUTER JOIN u ON t.c0 = u.c0 AND (t.c1 + u.c1) % 2 = 1") - .run(); - } - - // Filter without passed rows. - { - auto testProbeVectors = probeVectors; - auto testBuildVectors = buildVectors; - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .probeKeys({"c0"}) - .probeVectors(std::move(testProbeVectors)) - .buildKeys({"u_c0"}) - .buildVectors(std::move(testBuildVectors)) - .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) - .joinType(core::JoinType::kFull) - .joinFilter("(c1 + u_c1) % 2 = 3") - .joinOutputLayout({"c0", "c1", "u_c1"}) - .referenceQuery( - "SELECT t.c0, t.c1, u.c1 FROM t FULL OUTER JOIN u ON t.c0 = u.c0 AND (t.c1 + u.c1) % 2 = 3") - .run(); - } -} - -TEST_P(MultiThreadedHashJoinTest, noSpillLevelLimit) { - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .keyTypes({INTEGER()}) - .probeVectors(1600, 5) - .buildVectors(1500, 5) - .referenceQuery( - "SELECT t_k0, t_data, u_k0, u_data FROM t, u WHERE t.t_k0 = u.u_k0") - .maxSpillLevel(-1) - .config(core::QueryConfig::kSpillStartPartitionBit, "51") - .config(core::QueryConfig::kSpillNumPartitionBits, "3") - .checkSpillStats(false) - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - if (!hasSpill) { - return; - } - ASSERT_EQ(maxHashBuildSpillLevel(*task), 3); - }) - .run(); -} - -// Verify that dynamic filter pushed down is turned off for null-aware right -// semi project join. -TEST_F(HashJoinTest, nullAwareRightSemiProjectOverScan) { - std::vector probes; - std::vector builds; - // Matches present: - probes.push_back(makeRowVector( - {"t0"}, - { - makeNullableFlatVector({1, std::nullopt, 2}), - })); - builds.push_back(makeRowVector( - {"u0"}, - { - makeNullableFlatVector({1, 2, 3, std::nullopt}), - })); - - // No matches present: - probes.push_back(makeRowVector( - {"t0"}, - { - makeFlatVector({5, 6}), - })); - builds.push_back(makeRowVector( - {"u0"}, - { - makeNullableFlatVector({1, 2, 3, std::nullopt}), - })); - - for (int i = 0; i < probes.size(); i++) { - RowVectorPtr& probe = probes[i]; - RowVectorPtr& build = builds[i]; - std::shared_ptr probeFile = TempFilePath::create(); - writeToFile(probeFile->getPath(), {probe}); - - std::shared_ptr buildFile = TempFilePath::create(); - writeToFile(buildFile->getPath(), {build}); - - createDuckDbTable("t", {probe}); - createDuckDbTable("u", {build}); - - core::PlanNodeId probeScanId; - core::PlanNodeId buildScanId; - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .tableScan(asRowType(probe->type())) - .capturePlanNodeId(probeScanId) - .hashJoin( - {"t0"}, - {"u0"}, - PlanBuilder(planNodeIdGenerator) - .tableScan(asRowType(build->type())) - .capturePlanNodeId(buildScanId) - .planNode(), - "", - {"u0", "match"}, - core::JoinType::kRightSemiProject, - true /*nullAware*/) - .planNode(); - - SplitInput splitInput = { - {probeScanId, - {exec::Split(makeHiveConnectorSplit(probeFile->getPath()))}}, - {buildScanId, - {exec::Split(makeHiveConnectorSplit(buildFile->getPath()))}}, - }; - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .inputSplits(splitInput) - .checkSpillStats(false) - .referenceQuery("SELECT u0, u0 IN (SELECT t0 FROM t) FROM u") - .run(); - } -} - -TEST_F(HashJoinTest, duplicateJoinKeys) { - auto leftVectors = makeBatches(3, [&](int32_t /*unused*/) { - return makeRowVector({ - makeNullableFlatVector( - {1, 2, 2, 3, 3, std::nullopt, 4, 5, 5, 6, 7}), - makeNullableFlatVector( - {1, 2, 2, std::nullopt, 3, 3, 4, 5, 5, 6, 8}), - }); - }); - - auto rightVectors = makeBatches(3, [&](int32_t /*unused*/) { - return makeRowVector({ - makeNullableFlatVector({1, 1, 3, 4, std::nullopt, 5, 7, 8}), - makeNullableFlatVector({1, 1, 3, 4, 5, std::nullopt, 7, 8}), - }); - }); - - createDuckDbTable("t", leftVectors); - createDuckDbTable("u", rightVectors); - - auto planNodeIdGenerator = std::make_shared(); - - auto assertPlan = [&](const std::vector& leftProject, - const std::vector& leftKeys, - const std::vector& rightProject, - const std::vector& rightKeys, - const std::vector& outputLayout, - core::JoinType joinType, - const std::string& query) { - auto plan = PlanBuilder(planNodeIdGenerator) - .values(leftVectors) - .project(leftProject) - .hashJoin( - leftKeys, - rightKeys, - PlanBuilder(planNodeIdGenerator) - .values(rightVectors) - .project(rightProject) - .planNode(), - "", - outputLayout, - joinType) - .planNode(); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .referenceQuery(query) - .run(); - }; - - std::vector> joins = { - {core::JoinType::kInner, "INNER JOIN"}, - {core::JoinType::kLeft, "LEFT JOIN"}, - {core::JoinType::kRight, "RIGHT JOIN"}, - {core::JoinType::kFull, "FULL OUTER JOIN"}}; - - for (const auto& [joinType, joinTypeSql] : joins) { - // Duplicate keys on the build side. - assertPlan( - {"c0 AS t0", "c1 as t1"}, // leftProject - {"t0", "t1"}, // leftKeys - {"c0 AS u0"}, // rightProject - {"u0", "u0"}, // rightKeys - {"t0", "t1", "u0"}, // outputLayout - joinType, - "SELECT t.c0, t.c1, u.c0 FROM t " + joinTypeSql + - " u ON t.c0 = u.c0 and t.c1 = u.c0"); - } - - for (const auto& [joinType, joinTypeSql] : joins) { - // Duplicated keys on the probe side. - assertPlan( - {"c0 AS t0"}, // leftProject - {"t0", "t0"}, // leftKeys - {"c0 AS u0", "c1 AS u1"}, // rightProject - {"u0", "u1"}, // rightKeys - {"t0", "u0", "u1"}, // outputLayout - joinType, - "SELECT t.c0, u.c0, u.c1 FROM t " + joinTypeSql + - " u ON t.c0 = u.c0 and t.c0 = u.c1"); - } -} - -TEST_F(HashJoinTest, semiProject) { - // Some keys have multiple rows: 2, 3, 5. - auto probeVectors = makeBatches(3, [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector({1, 2, 2, 3, 3, 3, 4, 5, 5, 6, 7}), - makeFlatVector({10, 20, 21, 30, 31, 32, 40, 50, 51, 60, 70}), - }); - }); - - // Some keys are missing: 2, 6. - // Some have multiple rows: 1, 5. - // Some keys are not present on probe side: 8. - auto buildVectors = makeBatches(3, [&](int32_t /*unused*/) { - return makeRowVector({ - makeFlatVector({1, 1, 3, 4, 5, 5, 7, 8}), - makeFlatVector({100, 101, 300, 400, 500, 501, 700, 800}), - }); - }); - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", buildVectors); - - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors) - .project({"c0 AS t0", "c1 AS t1"}) - .hashJoin( - {"t0"}, - {"u0"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors) - .project({"c0 AS u0", "c1 AS u1"}) - .planNode(), - "", - {"t0", "t1", "match"}, - core::JoinType::kLeftSemiProject) - .planNode(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .referenceQuery( - "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0) FROM t") - .run(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(flipJoinSides(plan)) - .referenceQuery( - "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0) FROM t") - .run(); - - // With extra filter. - planNodeIdGenerator = std::make_shared(); - plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors) - .project({"c0 AS t0", "c1 AS t1"}) - .hashJoin( - {"t0"}, - {"u0"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors) - .project({"c0 AS u0", "c1 AS u1"}) - .planNode(), - "t1 * 10 <> u1", - {"t0", "t1", "match"}, - core::JoinType::kLeftSemiProject) - .planNode(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .referenceQuery( - "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0 AND t.c1 * 10 <> u.c1) FROM t") - .run(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(flipJoinSides(plan)) - .referenceQuery( - "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0 AND t.c1 * 10 <> u.c1) FROM t") - .run(); - - // Empty build side. - planNodeIdGenerator = std::make_shared(); - plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors) - .project({"c0 AS t0", "c1 AS t1"}) - .hashJoin( - {"t0"}, - {"u0"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors) - .project({"c0 AS u0", "c1 AS u1"}) - .filter("u0 < 0") - .planNode(), - "", - {"t0", "t1", "match"}, - core::JoinType::kLeftSemiProject) - .planNode(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .referenceQuery( - "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE u.c0 < 0 AND t.c0 = u.c0) FROM t") - // NOTE: there is no spilling in empty build test case as all the - // build-side rows have been filtered out. - .checkSpillStats(false) - .run(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(flipJoinSides(plan)) - .referenceQuery( - "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE u.c0 < 0 AND t.c0 = u.c0) FROM t") - // NOTE: there is no spilling in empty build test case as all the - // build-side rows have been filtered out. - .checkSpillStats(false) - .run(); -} - -TEST_F(HashJoinTest, semiProjectWithNullKeys) { - // Some keys have multiple rows: 2, 3, 5. - auto probeVectors = makeBatches(3, [&](int32_t /*unused*/) { - return makeRowVector( - {"t0", "t1"}, - { - makeNullableFlatVector( - {1, 2, 2, 3, 3, 3, 4, std::nullopt, 5, 5, 6, 7}), - makeFlatVector( - {10, 20, 21, 30, 31, 32, 40, -1, 50, 51, 60, 70}), - }); - }); - - // Some keys are missing: 2, 6. - // Some have multiple rows: 1, 5. - // Some keys are not present on probe side: 8. - auto buildVectors = makeBatches(3, [&](int32_t /*unused*/) { - return makeRowVector( - {"u0", "u1"}, - { - makeNullableFlatVector( - {1, 1, 3, 4, std::nullopt, 5, 5, 7, 8}), - makeFlatVector( - {100, 101, 300, 400, -100, 500, 501, 700, 800}), - }); - }); - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", buildVectors); - - auto makePlan = [&](bool nullAware, - const std::string& probeFilter = "", - const std::string& buildFilter = "") { - auto planNodeIdGenerator = std::make_shared(); - return PlanBuilder(planNodeIdGenerator) - .values(probeVectors) - .optionalFilter(probeFilter) - .hashJoin( - {"t0"}, - {"u0"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors) - .optionalFilter(buildFilter) - .planNode(), - "", - {"t0", "t1", "match"}, - core::JoinType::kLeftSemiProject, - nullAware) - .planNode(); - }; - - // Null join keys on both sides. - auto plan = makePlan(false /*nullAware*/); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .referenceQuery( - "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0) FROM t") - .run(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(flipJoinSides(plan)) - .referenceQuery( - "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0) FROM t") - .run(); - - plan = makePlan(true /*nullAware*/); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .referenceQuery("SELECT t0, t1, t0 IN (SELECT u0 FROM u) FROM t") - .run(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(flipJoinSides(plan)) - .referenceQuery("SELECT t0, t1, t0 IN (SELECT u0 FROM u) FROM t") - .run(); - - // Null join keys on build side-only. - plan = makePlan(false /*nullAware*/, "t0 IS NOT NULL"); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .referenceQuery( - "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0) FROM t WHERE t0 IS NOT NULL") - .run(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(flipJoinSides(plan)) - .referenceQuery( - "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0) FROM t WHERE t0 IS NOT NULL") - .run(); - - plan = makePlan(true /*nullAware*/, "t0 IS NOT NULL"); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .referenceQuery( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u) FROM t WHERE t0 IS NOT NULL") - .run(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(flipJoinSides(plan)) - .referenceQuery( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u) FROM t WHERE t0 IS NOT NULL") - .run(); - - // Null join keys on probe side-only. - plan = makePlan(false /*nullAware*/, "", "u0 IS NOT NULL"); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .referenceQuery( - "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0 AND u0 IS NOT NULL) FROM t") - .run(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(flipJoinSides(plan)) - .referenceQuery( - "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0 AND u0 IS NOT NULL) FROM t") - .run(); - - plan = makePlan(true /*nullAware*/, "", "u0 IS NOT NULL"); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .referenceQuery( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE u0 IS NOT NULL) FROM t") - .run(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(flipJoinSides(plan)) - .referenceQuery( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE u0 IS NOT NULL) FROM t") - .run(); - - // Empty build side. - plan = makePlan(false /*nullAware*/, "", "u0 < 0"); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) - .planNode(plan) - .checkSpillStats(false) - .referenceQuery( - "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0 AND u0 < 0) FROM t") - .run(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) - .planNode(flipJoinSides(plan)) - .checkSpillStats(false) - .referenceQuery( - "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0 AND u0 < 0) FROM t") - .run(); - - plan = makePlan(true /*nullAware*/, "", "u0 < 0"); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) - .planNode(plan) - .checkSpillStats(false) - .referenceQuery( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE u0 < 0) FROM t") - .run(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) - .planNode(flipJoinSides(plan)) - .checkSpillStats(false) - .referenceQuery( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE u0 < 0) FROM t") - .run(); - - // Build side with all rows having null join keys. - plan = makePlan(false /*nullAware*/, "", "u0 IS NULL"); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) - .planNode(plan) - .checkSpillStats(false) - .referenceQuery( - "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0 AND u0 IS NULL) FROM t") - .run(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) - .planNode(flipJoinSides(plan)) - .checkSpillStats(false) - .referenceQuery( - "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0 AND u0 IS NULL) FROM t") - .run(); - - plan = makePlan(true /*nullAware*/, "", "u0 IS NULL"); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) - .planNode(plan) - .checkSpillStats(false) - .referenceQuery( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE u0 IS NULL) FROM t") - .run(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) - .planNode(flipJoinSides(plan)) - .checkSpillStats(false) - .referenceQuery( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE u0 IS NULL) FROM t") - .run(); -} - -TEST_F(HashJoinTest, semiProjectWithFilter) { - auto probeVectors = makeBatches(3, [&](auto /*unused*/) { - return makeRowVector( - {"t0", "t1"}, - { - makeNullableFlatVector({1, 2, 3, std::nullopt, 5}), - makeFlatVector({10, 20, 30, 40, 50}), - }); - }); - - auto buildVectors = makeBatches(3, [&](auto /*unused*/) { - return makeRowVector( - {"u0", "u1"}, - { - makeNullableFlatVector({1, 2, 3, std::nullopt}), - makeFlatVector({11, 22, 33, 44}), - }); - }); - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", buildVectors); - - auto makePlan = [&](bool nullAware, const std::string& filter) { - auto planNodeIdGenerator = std::make_shared(); - return PlanBuilder(planNodeIdGenerator) - .values(probeVectors) - .hashJoin( - {"t0"}, - {"u0"}, - PlanBuilder(planNodeIdGenerator).values(buildVectors).planNode(), - filter, - {"t0", "t1", "match"}, - core::JoinType::kLeftSemiProject, - nullAware) - .planNode(); - }; - - std::vector filters = { - "t1 <> u1", - "t1 < u1", - "t1 > u1", - "t1 is not null AND u1 is not null", - "t1 is null OR u1 is null", - }; - for (const auto& filter : filters) { - auto plan = makePlan(true /*nullAware*/, filter); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .referenceQuery(fmt::format( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE {}) FROM t", filter)) - .injectSpill(false) - .run(); - - plan = makePlan(false /*nullAware*/, filter); - - // DuckDB Exists operator returns NULL when u0 or t0 is NULL. We exclude - // these values. - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .referenceQuery(fmt::format( - "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE (u0 is not null OR t0 is not null) AND u0 = t0 AND {}) FROM t", - filter)) - .injectSpill(false) - .run(); - } -} - -TEST_F(HashJoinTest, nullAwareRightSemiProjectWithFilterNotAllowed) { - auto probe = makeRowVector(ROW({"t0", "t1"}, {INTEGER(), BIGINT()}), 10); - auto build = makeRowVector(ROW({"u0", "u1"}, {INTEGER(), BIGINT()}), 10); - - auto planNodeIdGenerator = std::make_shared(); - VELOX_ASSERT_THROW( - PlanBuilder(planNodeIdGenerator) - .values({probe}) - .hashJoin( - {"t0"}, - {"u0"}, - PlanBuilder(planNodeIdGenerator).values({build}).planNode(), - "t1 > u1", - {"u0", "u1", "match"}, - core::JoinType::kRightSemiProject, - true /* nullAware */), - "Null-aware right semi project join doesn't support extra filter"); -} - -TEST_F(HashJoinTest, leftSemiJoinWithExtraOutputCapacity) { - std::vector probeVectors; - std::vector buildVectors; - probeVectors.push_back(makeRowVector( - {"t0", "t1"}, - { - makeFlatVector({1, 2, 3, 4, 5, 6}), - makeFlatVector({10, 10, 10, 10, 10, 10}), - })); - - buildVectors.push_back(makeRowVector( - {"u0", "u1"}, - { - makeFlatVector({1, 1, 1, 1, 1}), - makeFlatVector({10, 10, 10, 10, 10}), - })); - buildVectors.push_back(makeRowVector( - {"u0", "u1"}, - { - makeFlatVector({2, 3, 4, 5, 6}), - makeFlatVector({10, 10, 10, 10, 10}), - })); - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", buildVectors); - auto runQuery = [&](const std::string& query, - const std::string& filter, - core::JoinType joinType) { - auto planNodeIdGenerator = std::make_shared(); - std::vector outputLayout = {"t0", "t1"}; - if (joinType == core::JoinType::kLeftSemiProject) { - outputLayout.push_back("match"); - } - auto plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors) - .hashJoin( - {"t0"}, - {"u0"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors) - .planNode(), - filter, - outputLayout, - joinType, - false) - .planNode(); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .config(core::QueryConfig::kPreferredOutputBatchRows, "5") - .referenceQuery(query) - .injectSpill(false) - .run(); - }; - { - SCOPED_TRACE("left semi filter join"); - std::string filter = "t1 = u1"; - runQuery( - fmt::format( - "SELECT t0, t1 FROM t WHERE EXISTS (SELECT u0 FROM u WHERE t0 = u0 AND {})", - filter), - filter, - core::JoinType::kLeftSemiFilter); - } - - { - SCOPED_TRACE("left semi project join"); - std::string filter = "t1 <> u1"; - runQuery( - fmt::format( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE {}) FROM t", filter), - filter, - core::JoinType::kLeftSemiProject); - } -} - -TEST_F(HashJoinTest, nullAwareMultiKeyNotAllowed) { - auto probe = makeRowVector( - ROW({"t0", "t1", "t2"}, {INTEGER(), BIGINT(), VARCHAR()}), 10); - auto build = makeRowVector( - ROW({"u0", "u1", "u2"}, {INTEGER(), BIGINT(), VARCHAR()}), 10); - - // Null-aware left semi project join. - auto planNodeIdGenerator = std::make_shared(); - VELOX_ASSERT_THROW( - PlanBuilder(planNodeIdGenerator) - .values({probe}) - .hashJoin( - {"t0", "t1"}, - {"u0", "u1"}, - PlanBuilder(planNodeIdGenerator).values({build}).planNode(), - "", - {"t0", "t1", "match"}, - core::JoinType::kLeftSemiProject, - true /* nullAware */), - "Null-aware joins allow only one join key"); - - // Null-aware right semi project join. - VELOX_ASSERT_THROW( - PlanBuilder(planNodeIdGenerator) - .values({probe}) - .hashJoin( - {"t0", "t1"}, - {"u0", "u1"}, - PlanBuilder(planNodeIdGenerator).values({build}).planNode(), - "", - {"u0", "u1", "match"}, - core::JoinType::kRightSemiProject, - true /* nullAware */), - "Null-aware joins allow only one join key"); - - // Null-aware anti join. - VELOX_ASSERT_THROW( - PlanBuilder(planNodeIdGenerator) - .values({probe}) - .hashJoin( - {"t0", "t1"}, - {"u0", "u1"}, - PlanBuilder(planNodeIdGenerator).values({build}).planNode(), - "", - {"t0", "t1"}, - core::JoinType::kAnti, - true /* nullAware */), - "Null-aware joins allow only one join key"); -} - -TEST_F(HashJoinTest, semiProjectOverLazyVectors) { - auto probeVectors = makeBatches(1, [&](auto /*unused*/) { - return makeRowVector( - {"t0", "t1"}, - { - makeFlatVector(1'000, [](auto row) { return row; }), - makeFlatVector(1'000, [](auto row) { return row * 10; }), - }); - }); - - auto buildVectors = makeBatches(3, [&](auto /*unused*/) { - return makeRowVector( - {"u0", "u1"}, - { - makeFlatVector( - 1'000, [](auto row) { return -100 + (row / 5); }), - makeFlatVector( - 1'000, [](auto row) { return -1000 + (row / 5) * 10; }), - }); - }); - - std::shared_ptr probeFile = TempFilePath::create(); - writeToFile(probeFile->getPath(), probeVectors); - - std::shared_ptr buildFile = TempFilePath::create(); - writeToFile(buildFile->getPath(), buildVectors); - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", buildVectors); - - core::PlanNodeId probeScanId; - core::PlanNodeId buildScanId; - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .tableScan(asRowType(probeVectors[0]->type())) - .capturePlanNodeId(probeScanId) - .hashJoin( - {"t0"}, - {"u0"}, - PlanBuilder(planNodeIdGenerator) - .tableScan(asRowType(buildVectors[0]->type())) - .capturePlanNodeId(buildScanId) - .planNode(), - "", - {"t0", "t1", "match"}, - core::JoinType::kLeftSemiProject) - .planNode(); - - SplitInput splitInput = { - {probeScanId, - {exec::Split(makeHiveConnectorSplit(probeFile->getPath()))}}, - {buildScanId, - {exec::Split(makeHiveConnectorSplit(buildFile->getPath()))}}, - }; - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .inputSplits(splitInput) - .checkSpillStats(false) - .referenceQuery("SELECT t0, t1, t0 IN (SELECT u0 FROM u) FROM t") - .run(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(flipJoinSides(plan)) - .inputSplits(splitInput) - .checkSpillStats(false) - .referenceQuery("SELECT t0, t1, t0 IN (SELECT u0 FROM u) FROM t") - .run(); - - // With extra filter. - planNodeIdGenerator = std::make_shared(); - plan = PlanBuilder(planNodeIdGenerator) - .tableScan(asRowType(probeVectors[0]->type())) - .capturePlanNodeId(probeScanId) - .hashJoin( - {"t0"}, - {"u0"}, - PlanBuilder(planNodeIdGenerator) - .tableScan(asRowType(buildVectors[0]->type())) - .capturePlanNodeId(buildScanId) - .planNode(), - "(t1 + u1) % 3 = 0", - {"t0", "t1", "match"}, - core::JoinType::kLeftSemiProject) - .planNode(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .inputSplits(splitInput) - .checkSpillStats(false) - .referenceQuery( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE (t1 + u1) % 3 = 0) FROM t") - .run(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(flipJoinSides(plan)) - .inputSplits(splitInput) - .checkSpillStats(false) - .referenceQuery( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE (t1 + u1) % 3 = 0) FROM t") - .run(); -} - -VELOX_INSTANTIATE_TEST_SUITE_P( - HashJoinTest, - MultiThreadedHashJoinTest, - testing::ValuesIn(MultiThreadedHashJoinTest::getTestParams())); - -// TODO: try to parallelize the following test cases if possible. -TEST_F(HashJoinTest, memory) { - // Measures memory allocation in a 1:n hash join followed by - // projection and aggregation. We expect vectors to be mostly - // reused, except for t_k0 + 1, which is a dictionary after the - // join. - std::vector probeVectors = - makeBatches(10, [&](int32_t /*unused*/) { - return std::dynamic_pointer_cast( - BatchMaker::createBatch(probeType_, 1000, *pool_)); - }); - - // auto buildType = makeRowType(keyTypes, "u_"); - std::vector buildVectors = - makeBatches(10, [&](int32_t /*unused*/) { - return std::dynamic_pointer_cast( - BatchMaker::createBatch(buildType_, 1000, *pool_)); - }); - - auto planNodeIdGenerator = std::make_shared(); - CursorParameters params; - params.planNode = PlanBuilder(planNodeIdGenerator) - .values(probeVectors, true) - .hashJoin( - {"t_k1"}, - {"u_k1"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors, true) - .planNode(), - "", - concat(probeType_->names(), buildType_->names())) - .project({"t_k1 % 1000 AS k1", "u_k1 % 1000 AS k2"}) - .singleAggregation({}, {"sum(k1)", "sum(k2)"}) - .planNode(); - params.queryCtx = core::QueryCtx::create(driverExecutor_.get()); - auto [taskCursor, rows] = readCursor(params); - EXPECT_GT(3'500, params.queryCtx->pool()->stats().numAllocs); - EXPECT_GT(40'000'000, params.queryCtx->pool()->stats().cumulativeBytes); -} - -TEST_F(HashJoinTest, lazyVectors) { - // a dataset of multiple row groups with multiple columns. We create - // different dictionary wrappings for different columns and load the - // rows in scope at different times. - auto probeVectors = makeBatches(3, [&](int32_t /*unused*/) { - return makeRowVector( - {makeFlatVector(3'000, [](auto row) { return row; }), - makeFlatVector(30'000, [](auto row) { return row % 23; }), - makeFlatVector(30'000, [](auto row) { return row % 31; }), - makeFlatVector(30'000, [](auto row) { - return StringView::makeInline(fmt::format("{} string", row % 43)); - })}); - }); - - std::vector buildVectors = - makeBatches(4, [&](int32_t /*unused*/) { - return makeRowVector( - {makeFlatVector(1'000, [](auto row) { return row * 3; }), - makeFlatVector( - 10'000, [](auto row) { return row % 31; })}); - }); - - std::vector> tempFiles; - - for (const auto& probeVector : probeVectors) { - tempFiles.push_back(TempFilePath::create()); - writeToFile(tempFiles.back()->getPath(), probeVector); - } - createDuckDbTable("t", probeVectors); - - for (const auto& buildVector : buildVectors) { - tempFiles.push_back(TempFilePath::create()); - writeToFile(tempFiles.back()->getPath(), buildVector); - } - createDuckDbTable("u", buildVectors); - - auto makeInputSplits = [&](const core::PlanNodeId& probeScanId, - const core::PlanNodeId& buildScanId) { - return [&] { - std::vector probeSplits; - for (int i = 0; i < probeVectors.size(); ++i) { - probeSplits.push_back( - exec::Split(makeHiveConnectorSplit(tempFiles[i]->getPath()))); - } - std::vector buildSplits; - for (int i = 0; i < buildVectors.size(); ++i) { - buildSplits.push_back(exec::Split(makeHiveConnectorSplit( - tempFiles[probeSplits.size() + i]->getPath()))); - } - SplitInput splits; - splits.emplace(probeScanId, probeSplits); - splits.emplace(buildScanId, buildSplits); - return splits; - }; - }; - - { - auto planNodeIdGenerator = std::make_shared(); - core::PlanNodeId probeScanId; - core::PlanNodeId buildScanId; - auto op = PlanBuilder(planNodeIdGenerator) - .tableScan(ROW({"c0", "c1"}, {INTEGER(), BIGINT()})) - .capturePlanNodeId(probeScanId) - .hashJoin( - {"c0"}, - {"c0"}, - PlanBuilder(planNodeIdGenerator) - .tableScan(ROW({"c0"}, {INTEGER()})) - .capturePlanNodeId(buildScanId) - .planNode(), - "", - {"c1"}) - .project({"c1 + 1"}) - .planNode(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .makeInputSplits(makeInputSplits(probeScanId, buildScanId)) - .referenceQuery("SELECT t.c1 + 1 FROM t, u WHERE t.c0 = u.c0") - .run(); - } - - { - auto planNodeIdGenerator = std::make_shared(); - core::PlanNodeId probeScanId; - core::PlanNodeId buildScanId; - auto op = PlanBuilder(planNodeIdGenerator) - .tableScan( - ROW({"c0", "c1", "c2", "c3"}, - {INTEGER(), BIGINT(), INTEGER(), VARCHAR()})) - .capturePlanNodeId(probeScanId) - .filter("c2 < 29") - .hashJoin( - {"c0"}, - {"bc0"}, - PlanBuilder(planNodeIdGenerator) - .tableScan(ROW({"c0", "c1"}, {INTEGER(), BIGINT()})) - .capturePlanNodeId(buildScanId) - .project({"c0 as bc0", "c1 as bc1"}) - .planNode(), - "(c1 + bc1) % 33 < 27", - {"c1", "bc1", "c3"}) - .project({"c1 + 1", "bc1", "length(c3)"}) - .planNode(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .makeInputSplits(makeInputSplits(probeScanId, buildScanId)) - .referenceQuery( - "SELECT t.c1 + 1, U.c1, length(t.c3) FROM t, u WHERE t.c0 = u.c0 and t.c2 < 29 and (t.c1 + u.c1) % 33 < 27") - .run(); - } -} - -TEST_F(HashJoinTest, lazyVectorNotLoadedInFilter) { - // Ensure that if lazy vectors are temporarily wrapped during a filter's - // execution and remain unloaded, the temporary wrap is promptly - // discarded. This precaution prevents the generation of the probe's output - // from wrapping an unloaded vector while the temporary wrap is - // still alive. - // This is done by generating a sufficiently small batch to allow the lazy - // vector to remain unloaded, as it doesn't need to be split between batches. - // Then we use a filter that skips the execution of the expression containing - // the lazy vector, thereby avoiding its loading. - - testLazyVectorsWithFilter( - core::JoinType::kInner, - "c1 >= 0 OR c2 > 0", - {"c1", "c2"}, - "SELECT t.c1, t.c2 FROM t, u WHERE t.c0 = u.c0"); -} - -TEST_F(HashJoinTest, lazyVectorPartiallyLoadedInFilterLeftJoin) { - // Test the case where a filter loads a subset of the rows that will be output - // from a column on the probe side. - - testLazyVectorsWithFilter( - core::JoinType::kLeft, - "c1 > 0 AND c2 > 0", - {"c1", "c2"}, - "SELECT t.c1, t.c2 FROM t LEFT JOIN u ON t.c0 = u.c0 AND (c1 > 0 AND c2 > 0)"); -} - -TEST_F(HashJoinTest, lazyVectorPartiallyLoadedInFilterFullJoin) { - // Test the case where a filter loads a subset of the rows that will be output - // from a column on the probe side. - - testLazyVectorsWithFilter( - core::JoinType::kFull, - "c1 > 0 AND c2 > 0", - {"c1", "c2"}, - "SELECT t.c1, t.c2 FROM t FULL OUTER JOIN u ON t.c0 = u.c0 AND (c1 > 0 AND c2 > 0)"); -} - -TEST_F(HashJoinTest, lazyVectorPartiallyLoadedInFilterLeftSemiProject) { - // Test the case where a filter loads a subset of the rows that will be output - // from a column on the probe side. - - testLazyVectorsWithFilter( - core::JoinType::kLeftSemiProject, - "c1 > 0 AND c2 > 0", - {"c1", "c2", "match"}, - "SELECT t.c1, t.c2, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0 AND (t.c1 > 0 AND t.c2 > 0)) FROM t"); -} - -TEST_F(HashJoinTest, lazyVectorPartiallyLoadedInFilterAntiJoin) { - // Test the case where a filter loads a subset of the rows that will be output - // from a column on the probe side. - - testLazyVectorsWithFilter( - core::JoinType::kAnti, - "c1 > 0 AND c2 > 0", - {"c1", "c2"}, - "SELECT t.c1, t.c2 FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE t.c0 = u.c0 AND (t.c1 > 0 AND t.c2 > 0))"); -} - -TEST_F(HashJoinTest, lazyVectorPartiallyLoadedInFilterInnerJoin) { - // Test the case where a filter loads a subset of the rows that will be output - // from a column on the probe side. - - testLazyVectorsWithFilter( - core::JoinType::kInner, - "not (c1 < 15 and c2 >= 0)", - {"c1", "c2"}, - "SELECT t.c1, t.c2 FROM t, u WHERE t.c0 = u.c0 AND NOT (c1 < 15 AND c2 >= 0)"); -} - -TEST_F(HashJoinTest, lazyVectorPartiallyLoadedInFilterLeftSemiFilter) { - // Test the case where a filter loads a subset of the rows that will be output - // from a column on the probe side. - - testLazyVectorsWithFilter( - core::JoinType::kLeftSemiFilter, - "not (c1 < 15 and c2 >= 0)", - {"c1", "c2"}, - "SELECT t.c1, t.c2 FROM t WHERE c0 IN (SELECT u.c0 FROM u WHERE t.c0 = u.c0 AND NOT (t.c1 < 15 AND t.c2 >= 0))"); -} - -TEST_F(HashJoinTest, dynamicFilters) { - const int32_t numSplits = 10; - const int32_t numRowsProbe = 333; - const int32_t numRowsBuild = 100; - - std::vector probeVectors; - probeVectors.reserve(numSplits); - - std::vector> tempFiles; - for (int32_t i = 0; i < numSplits; ++i) { - auto rowVector = makeRowVector({ - makeFlatVector( - numRowsProbe, [&](auto row) { return row - i * 10; }), - makeFlatVector(numRowsProbe, [](auto row) { return row; }), - }); - probeVectors.push_back(rowVector); - tempFiles.push_back(TempFilePath::create()); - writeToFile(tempFiles.back()->getPath(), rowVector); - } - auto makeInputSplits = [&](const core::PlanNodeId& nodeId) { - return [&] { - std::vector probeSplits; - for (auto& file : tempFiles) { - probeSplits.push_back( - exec::Split(makeHiveConnectorSplit(file->getPath()))); - } - SplitInput splits; - splits.emplace(nodeId, probeSplits); - return splits; - }; - }; - - // 100 key values in [35, 233] range. - std::vector buildVectors; - for (int i = 0; i < 5; ++i) { - buildVectors.push_back(makeRowVector({ - makeFlatVector( - numRowsBuild / 5, - [i](auto row) { return 35 + 2 * (row + i * numRowsBuild / 5); }), - makeFlatVector(numRowsBuild / 5, [](auto row) { return row; }), - })); - } - std::vector keyOnlyBuildVectors; - for (int i = 0; i < 5; ++i) { - keyOnlyBuildVectors.push_back( - makeRowVector({makeFlatVector(numRowsBuild / 5, [i](auto row) { - return 35 + 2 * (row + i * numRowsBuild / 5); - })})); - } - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", buildVectors); - - auto probeType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); - - auto planNodeIdGenerator = std::make_shared(); - - auto buildSide = PlanBuilder(planNodeIdGenerator, pool_.get()) - .values(buildVectors) - .project({"c0 AS u_c0", "c1 AS u_c1"}) - .planNode(); - auto keyOnlyBuildSide = PlanBuilder(planNodeIdGenerator, pool_.get()) - .values(keyOnlyBuildVectors) - .project({"c0 AS u_c0"}) - .planNode(); - - // Basic push-down. - { - SCOPED_TRACE("Inner join"); - core::PlanNodeId probeScanId; - core::PlanNodeId joinId; - auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) - .tableScan(probeType) - .capturePlanNodeId(probeScanId) - .hashJoin( - {"c0"}, - {"u_c0"}, - buildSide, - "", - {"c0", "c1", "u_c1"}, - core::JoinType::kInner) - .capturePlanNodeId(joinId) - .project({"c0", "c1 + 1", "c1 + u_c1"}) - .planNode(); - { - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .makeInputSplits(makeInputSplits(probeScanId)) - .referenceQuery( - "SELECT t.c0, t.c1 + 1, t.c1 + u.c1 FROM t, u WHERE t.c0 = u.c0") - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); - auto planStats = toPlanStats(task->taskStats()); - if (hasSpill) { - // Dynamic filtering should be disabled with spilling triggered. - ASSERT_EQ(0, getFiltersProduced(task, 1).sum); - ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); - } else { - ASSERT_EQ(1, getFiltersProduced(task, 1).sum); - ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(0, getReplacedWithFilterRows(task, 1).sum); - ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_EQ( - planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, - std::unordered_set({joinId})); - } - }) - .run(); - } - - // Left semi join. - op = PlanBuilder(planNodeIdGenerator, pool_.get()) - .tableScan(probeType) - .capturePlanNodeId(probeScanId) - .hashJoin( - {"c0"}, - {"u_c0"}, - buildSide, - "", - {"c0", "c1"}, - core::JoinType::kLeftSemiFilter) - .capturePlanNodeId(joinId) - .project({"c0", "c1 + 1"}) - .planNode(); - - { - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .makeInputSplits(makeInputSplits(probeScanId)) - .referenceQuery( - "SELECT t.c0, t.c1 + 1 FROM t WHERE t.c0 IN (SELECT c0 FROM u)") - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); - auto planStats = toPlanStats(task->taskStats()); - if (hasSpill) { - // Dynamic filtering should be disabled with spilling triggered. - ASSERT_EQ(0, getFiltersProduced(task, 1).sum); - ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(0, getReplacedWithFilterRows(task, 1).sum); - ASSERT_EQ(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); - } else { - ASSERT_EQ(1, getFiltersProduced(task, 1).sum); - ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); - ASSERT_GT(getReplacedWithFilterRows(task, 1).sum, 0); - ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_EQ( - planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, - std::unordered_set({joinId})); - } - }) - .run(); - } - - // Right semi join. - op = PlanBuilder(planNodeIdGenerator, pool_.get()) - .tableScan(probeType) - .capturePlanNodeId(probeScanId) - .hashJoin( - {"c0"}, - {"u_c0"}, - buildSide, - "", - {"u_c0", "u_c1"}, - core::JoinType::kRightSemiFilter) - .capturePlanNodeId(joinId) - .project({"u_c0", "u_c1 + 1"}) - .planNode(); - - { - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .makeInputSplits(makeInputSplits(probeScanId)) - .referenceQuery( - "SELECT u.c0, u.c1 + 1 FROM u WHERE u.c0 IN (SELECT c0 FROM t)") - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); - auto planStats = toPlanStats(task->taskStats()); - if (hasSpill) { - // Dynamic filtering should be disabled with spilling triggered. - ASSERT_EQ(0, getFiltersProduced(task, 1).sum); - ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(getReplacedWithFilterRows(task, 1).sum, 0); - ASSERT_EQ(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); - } else { - ASSERT_EQ(1, getFiltersProduced(task, 1).sum); - ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(getReplacedWithFilterRows(task, 1).sum, 0); - ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_EQ( - planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, - std::unordered_set({joinId})); - } - }) - .run(); - } - - // Right join. - op = PlanBuilder(planNodeIdGenerator, pool_.get()) - .tableScan(probeType) - .capturePlanNodeId(probeScanId) - .hashJoin( - {"c0"}, - {"u_c0"}, - buildSide, - "", - {"c0", "c1", "u_c1"}, - core::JoinType::kRight) - .capturePlanNodeId(joinId) - .project({"c0", "c1 + 1", "c1 + u_c1"}) - .planNode(); - { - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .makeInputSplits(makeInputSplits(probeScanId)) - .referenceQuery( - "SELECT t.c0, t.c1 + 1, t.c1 + u.c1 FROM t RIGHT JOIN u ON t.c0 = u.c0") - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); - auto planStats = toPlanStats(task->taskStats()); - if (hasSpill) { - // Dynamic filtering should be disabled with spilling triggered. - ASSERT_EQ(0, getFiltersProduced(task, 1).sum); - ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); - } else { - ASSERT_EQ(1, getFiltersProduced(task, 1).sum); - ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(0, getReplacedWithFilterRows(task, 1).sum); - ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_EQ( - planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, - std::unordered_set({joinId})); - } - }) - .run(); - } - } - - // Basic push-down with column names projected out of the table scan - // having different names than column names in the files. - { - SCOPED_TRACE("Inner join column rename"); - auto scanOutputType = ROW({"a", "b"}, {INTEGER(), BIGINT()}); - connector::ColumnHandleMap assignments; - assignments["a"] = regularColumn("c0", INTEGER()); - assignments["b"] = regularColumn("c1", BIGINT()); - - core::PlanNodeId probeScanId; - core::PlanNodeId joinId; - auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) - .startTableScan() - .outputType(scanOutputType) - .assignments(assignments) - .endTableScan() - .capturePlanNodeId(probeScanId) - .hashJoin({"a"}, {"u_c0"}, buildSide, "", {"a", "b", "u_c1"}) - .capturePlanNodeId(joinId) - .project({"a", "b + 1", "b + u_c1"}) - .planNode(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .makeInputSplits(makeInputSplits(probeScanId)) - .referenceQuery( - "SELECT t.c0, t.c1 + 1, t.c1 + u.c1 FROM t, u WHERE t.c0 = u.c0") - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); - auto planStats = toPlanStats(task->taskStats()); - if (hasSpill) { - // Dynamic filtering should be disabled with spilling triggered. - ASSERT_EQ(0, getFiltersProduced(task, 1).sum); - ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(0, getReplacedWithFilterRows(task, 1).sum); - ASSERT_EQ(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); - } else { - ASSERT_EQ(1, getFiltersProduced(task, 1).sum); - ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(0, getReplacedWithFilterRows(task, 1).sum); - ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_EQ( - planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, - std::unordered_set({joinId})); - } - }) - .run(); - } - - // Push-down that requires merging filters. - { - SCOPED_TRACE("Merge filters"); - core::PlanNodeId probeScanId; - core::PlanNodeId joinId; - auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) - .tableScan(probeType, {"c0 < 500::INTEGER"}) - .capturePlanNodeId(probeScanId) - .hashJoin({"c0"}, {"u_c0"}, buildSide, "", {"c1", "u_c1"}) - .capturePlanNodeId(joinId) - .project({"c1 + u_c1"}) - .planNode(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .makeInputSplits(makeInputSplits(probeScanId)) - .referenceQuery( - "SELECT t.c1 + u.c1 FROM t, u WHERE t.c0 = u.c0 AND t.c0 < 500") - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); - auto planStats = toPlanStats(task->taskStats()); - if (hasSpill) { - // Dynamic filtering should be disabled with spilling triggered. - ASSERT_EQ(0, getFiltersProduced(task, 1).sum); - ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(0, getReplacedWithFilterRows(task, 1).sum); - ASSERT_EQ(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); - } else { - ASSERT_EQ(1, getFiltersProduced(task, 1).sum); - ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(0, getReplacedWithFilterRows(task, 1).sum); - ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_EQ( - planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, - std::unordered_set({joinId})); - } - }) - .run(); - } - - // Push-down that turns join into a no-op. - { - SCOPED_TRACE("canReplaceWithDynamicFilter"); - core::PlanNodeId probeScanId; - core::PlanNodeId joinId; - auto op = - PlanBuilder(planNodeIdGenerator, pool_.get()) - .tableScan(probeType) - .capturePlanNodeId(probeScanId) - .hashJoin({"c0"}, {"u_c0"}, keyOnlyBuildSide, "", {"c0", "c1"}) - .capturePlanNodeId(joinId) - .project({"c0", "c1 + 1"}) - .planNode(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .makeInputSplits(makeInputSplits(probeScanId)) - .referenceQuery("SELECT t.c0, t.c1 + 1 FROM t, u WHERE t.c0 = u.c0") - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); - auto planStats = toPlanStats(task->taskStats()); - if (hasSpill) { - // Dynamic filtering should be disabled with spilling triggered. - ASSERT_EQ(0, getFiltersProduced(task, 1).sum); - ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(0, getReplacedWithFilterRows(task, 1).sum); - ASSERT_EQ(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); - } else { - ASSERT_EQ(1, getFiltersProduced(task, 1).sum); - ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); - ASSERT_EQ( - getReplacedWithFilterRows(task, 1).sum, - numRowsBuild * numSplits); - ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_EQ( - planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, - std::unordered_set({joinId})); - } - }) - .run(); - } - - // Push-down that turns join into a no-op with output having a different - // number of columns than the input. - { - SCOPED_TRACE("canReplaceWithDynamicFilter column rename"); - core::PlanNodeId probeScanId; - core::PlanNodeId joinId; - auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) - .tableScan(probeType) - .capturePlanNodeId(probeScanId) - .hashJoin({"c0"}, {"u_c0"}, keyOnlyBuildSide, "", {"c0"}) - .capturePlanNodeId(joinId) - .planNode(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .makeInputSplits(makeInputSplits(probeScanId)) - .referenceQuery("SELECT t.c0 FROM t JOIN u ON (t.c0 = u.c0)") - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); - auto planStats = toPlanStats(task->taskStats()); - if (hasSpill) { - // Dynamic filtering should be disabled with spilling triggered. - ASSERT_EQ(0, getFiltersProduced(task, 1).sum); - ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(0, getReplacedWithFilterRows(task, 1).sum); - ASSERT_EQ(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); - } else { - ASSERT_EQ(1, getFiltersProduced(task, 1).sum); - ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); - ASSERT_EQ( - getReplacedWithFilterRows(task, 1).sum, - numRowsBuild * numSplits); - ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_EQ( - planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, - std::unordered_set({joinId})); - } - }) - .run(); - } - - // Push-down that requires merging filters and turns join into a no-op. - { - SCOPED_TRACE("canReplaceWithDynamicFilter merge filters"); - core::PlanNodeId probeScanId; - core::PlanNodeId joinId; - auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) - .tableScan(probeType, {"c0 < 500::INTEGER"}) - .capturePlanNodeId(probeScanId) - .hashJoin({"c0"}, {"u_c0"}, keyOnlyBuildSide, "", {"c1"}) - .capturePlanNodeId(joinId) - .project({"c1 + 1"}) - .planNode(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .makeInputSplits(makeInputSplits(probeScanId)) - .referenceQuery( - "SELECT t.c1 + 1 FROM t, u WHERE t.c0 = u.c0 AND t.c0 < 500") - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); - auto planStats = toPlanStats(task->taskStats()); - if (hasSpill) { - // Dynamic filtering should be disabled with spilling triggered. - ASSERT_EQ(0, getFiltersProduced(task, 1).sum); - ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(getReplacedWithFilterRows(task, 1).sum, 0); - ASSERT_EQ(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); - } else { - ASSERT_EQ(1, getFiltersProduced(task, 1).sum); - ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); - ASSERT_GT(getReplacedWithFilterRows(task, 1).sum, 0); - ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_EQ( - planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, - std::unordered_set({joinId})); - } - }) - .run(); - } - - // Push-down with highly selective filter in the scan. - { - SCOPED_TRACE("Highly selective filter"); - // Inner join. - core::PlanNodeId probeScanId; - core::PlanNodeId joinId; - auto op = - PlanBuilder(planNodeIdGenerator, pool_.get()) - .tableScan(probeType, {"c0 < 200::INTEGER"}) - .capturePlanNodeId(probeScanId) - .hashJoin( - {"c0"}, {"u_c0"}, buildSide, "", {"c1"}, core::JoinType::kInner) - .capturePlanNodeId(joinId) - .project({"c1 + 1"}) - .planNode(); - - { - SCOPED_TRACE("Inner join"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .makeInputSplits(makeInputSplits(probeScanId)) - .referenceQuery( - "SELECT t.c1 + 1 FROM t, u WHERE t.c0 = u.c0 AND t.c0 < 200") - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); - auto planStats = toPlanStats(task->taskStats()); - if (hasSpill) { - // Dynamic filtering should be disabled with spilling triggered. - ASSERT_EQ(0, getFiltersProduced(task, 1).sum); - ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(getReplacedWithFilterRows(task, 1).sum, 0); - ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); - } else { - ASSERT_EQ(1, getFiltersProduced(task, 1).sum); - ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); - ASSERT_GT(getReplacedWithFilterRows(task, 1).sum, 0); - ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_EQ( - planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, - std::unordered_set({joinId})); - } - }) - .run(); - } - - // Left semi join. - op = PlanBuilder(planNodeIdGenerator, pool_.get()) - .tableScan(probeType, {"c0 < 200::INTEGER"}) - .capturePlanNodeId(probeScanId) - .hashJoin( - {"c0"}, - {"u_c0"}, - buildSide, - "", - {"c1"}, - core::JoinType::kLeftSemiFilter) - .capturePlanNodeId(joinId) - .project({"c1 + 1"}) - .planNode(); - - { - SCOPED_TRACE("Left semi join"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .makeInputSplits(makeInputSplits(probeScanId)) - .referenceQuery( - "SELECT t.c1 + 1 FROM t WHERE t.c0 IN (SELECT c0 FROM u) AND t.c0 < 200") - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); - auto planStats = toPlanStats(task->taskStats()); - if (hasSpill) { - // Dynamic filtering should be disabled with spilling triggered. - ASSERT_EQ(0, getFiltersProduced(task, 1).sum); - ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(getReplacedWithFilterRows(task, 1).sum, 0); - ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); - } else { - ASSERT_EQ(1, getFiltersProduced(task, 1).sum); - ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); - ASSERT_GT(getReplacedWithFilterRows(task, 1).sum, 0); - ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_EQ( - planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, - std::unordered_set({joinId})); - } - }) - .run(); - } - - // Right semi join. - op = PlanBuilder(planNodeIdGenerator, pool_.get()) - .tableScan(probeType, {"c0 < 200::INTEGER"}) - .capturePlanNodeId(probeScanId) - .hashJoin( - {"c0"}, - {"u_c0"}, - buildSide, - "", - {"u_c1"}, - core::JoinType::kRightSemiFilter) - .capturePlanNodeId(joinId) - .project({"u_c1 + 1"}) - .planNode(); - - { - SCOPED_TRACE("Right semi join"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .makeInputSplits(makeInputSplits(probeScanId)) - .referenceQuery( - "SELECT u.c1 + 1 FROM u WHERE u.c0 IN (SELECT c0 FROM t) AND u.c0 < 200") - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); - auto planStats = toPlanStats(task->taskStats()); - if (hasSpill) { - // Dynamic filtering should be disabled with spilling triggered. - ASSERT_EQ(0, getFiltersProduced(task, 1).sum); - ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(getReplacedWithFilterRows(task, 1).sum, 0); - ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); - } else { - ASSERT_EQ(1, getFiltersProduced(task, 1).sum); - ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(getReplacedWithFilterRows(task, 1).sum, 0); - ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_EQ( - planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, - std::unordered_set({joinId})); - } - }) - .run(); - } - - // Right join. - op = - PlanBuilder(planNodeIdGenerator, pool_.get()) - .tableScan(probeType, {"c0 < 200::INTEGER"}) - .capturePlanNodeId(probeScanId) - .hashJoin( - {"c0"}, {"u_c0"}, buildSide, "", {"c1"}, core::JoinType::kRight) - .capturePlanNodeId(joinId) - .project({"c1 + 1"}) - .planNode(); - - { - SCOPED_TRACE("Right join"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .makeInputSplits(makeInputSplits(probeScanId)) - .referenceQuery( - "SELECT t.c1 + 1 FROM (SELECT * FROM t WHERE t.c0 < 200) t RIGHT JOIN u ON t.c0 = u.c0") - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); - auto planStats = toPlanStats(task->taskStats()); - if (hasSpill) { - // Dynamic filtering should be disabled with spilling triggered. - ASSERT_EQ(0, getFiltersProduced(task, 1).sum); - ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(getReplacedWithFilterRows(task, 1).sum, 0); - ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); - } else { - ASSERT_EQ(1, getFiltersProduced(task, 1).sum); - ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(getReplacedWithFilterRows(task, 1).sum, 0); - ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); - ASSERT_EQ( - planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, - std::unordered_set({joinId})); - } - }) - .run(); - } - } - - // Disable filter push-down by using values in place of scan. - { - SCOPED_TRACE("Disabled in case of values node"); - core::PlanNodeId joinId; - auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) - .values(probeVectors) - .hashJoin({"c0"}, {"u_c0"}, buildSide, "", {"c1"}) - .capturePlanNodeId(joinId) - .project({"c1 + 1"}) - .planNode(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .referenceQuery("SELECT t.c1 + 1 FROM t, u WHERE t.c0 = u.c0") - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - auto planStats = toPlanStats(task->taskStats()); - ASSERT_EQ(0, getFiltersProduced(task, 1).sum); - ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(numRowsProbe * numSplits, getInputPositions(task, 1)); - }) - .run(); - } - - // Disable filter push-down by using an expression as the join key on the - // probe side. - { - SCOPED_TRACE("Disabled in case of join condition"); - core::PlanNodeId probeScanId; - core::PlanNodeId joinId; - auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) - .tableScan(probeType) - .capturePlanNodeId(probeScanId) - .project({"cast(c0 + 1 as integer) AS t_key", "c1"}) - .hashJoin({"t_key"}, {"u_c0"}, buildSide, "", {"c1"}) - .capturePlanNodeId(joinId) - .project({"c1 + 1"}) - .planNode(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .makeInputSplits(makeInputSplits(probeScanId)) - .referenceQuery("SELECT t.c1 + 1 FROM t, u WHERE (t.c0 + 1) = u.c0") - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - auto planStats = toPlanStats(task->taskStats()); - ASSERT_EQ(0, getFiltersProduced(task, 1).sum); - ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(numRowsProbe * numSplits, getInputPositions(task, 1)); - ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); - }) - .run(); - } -} - -TEST_F(HashJoinTest, dynamicFiltersStatsWithChainedJoins) { - const int32_t numSplits = 10; - const int32_t numProbeRows = 333; - const int32_t numBuildRows = 100; - - std::vector probeVectors; - probeVectors.reserve(numSplits); - std::vector> tempFiles; - for (int32_t i = 0; i < numSplits; ++i) { - auto rowVector = makeRowVector({ - makeFlatVector( - numProbeRows, [&](auto row) { return row - i * 10; }), - makeFlatVector(numProbeRows, [](auto row) { return row; }), - }); - probeVectors.push_back(rowVector); - tempFiles.push_back(TempFilePath::create()); - writeToFile(tempFiles.back()->getPath(), rowVector); - } - auto makeInputSplits = [&](const core::PlanNodeId& nodeId) { - return [&] { - std::vector probeSplits; - for (auto& file : tempFiles) { - probeSplits.push_back( - exec::Split(makeHiveConnectorSplit(file->getPath()))); - } - SplitInput splits; - splits.emplace(nodeId, probeSplits); - return splits; - }; - }; - - // 100 key values in [35, 233] range. - std::vector buildVectors; - for (int i = 0; i < 5; ++i) { - buildVectors.push_back(makeRowVector({ - makeFlatVector( - numBuildRows / 5, - [i](auto row) { return 35 + 2 * (row + i * numBuildRows / 5); }), - makeFlatVector(numBuildRows / 5, [](auto row) { return row; }), - })); - } - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", buildVectors); - - auto probeType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); - - auto planNodeIdGenerator = std::make_shared(); - - auto buildSide1 = PlanBuilder(planNodeIdGenerator, pool_.get()) - .values(buildVectors) - .project({"c0 AS u_c0", "c1 AS u_c1"}) - .planNode(); - auto buildSide2 = PlanBuilder(planNodeIdGenerator, pool_.get()) - .values(buildVectors) - .project({"c0 AS u_c0", "c1 AS u_c1"}) - .planNode(); - // Inner join pushdown. - core::PlanNodeId probeScanId; - core::PlanNodeId joinId1; - core::PlanNodeId joinId2; - auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) - .tableScan(probeType) - .capturePlanNodeId(probeScanId) - .hashJoin( - {"c0"}, - {"u_c0"}, - buildSide1, - "", - {"c0", "c1"}, - core::JoinType::kInner) - .capturePlanNodeId(joinId1) - .hashJoin( - {"c0"}, - {"u_c0"}, - buildSide2, - "", - {"c0", "c1", "u_c1"}, - core::JoinType::kInner) - .capturePlanNodeId(joinId2) - .project({"c0", "c1 + 1", "c1 + u_c1"}) - .planNode(); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .makeInputSplits(makeInputSplits(probeScanId)) - .injectSpill(false) - .referenceQuery( - "SELECT t.c0, t.c1 + 1, t.c1 + u.c1 FROM t, u WHERE t.c0 = u.c0") - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - auto planStats = toPlanStats(task->taskStats()); - ASSERT_EQ( - planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, - std::unordered_set({joinId1, joinId2})); - }) - .run(); -} - -TEST_F(HashJoinTest, dynamicFiltersWithSkippedSplits) { - const int32_t numSplits = 20; - const int32_t numNonSkippedSplits = 10; - const int32_t numRowsProbe = 333; - const int32_t numRowsBuild = 100; - - std::vector probeVectors; - probeVectors.reserve(numSplits); - - std::vector> tempFiles; - // Each split has a column containing - // the split number. This is used to filter out whole splits based - // on metadata. We test how using metadata for dropping splits - // interactts with dynamic filters. In specific, if the first split - // is discarded based on metadata, the dynamic filters must not be - // lost even if there is no actual reader for the split. - for (int32_t i = 0; i < numSplits; ++i) { - auto rowVector = makeRowVector({ - makeFlatVector( - numRowsProbe, [&](auto row) { return row - i * 10; }), - makeFlatVector(numRowsProbe, [](auto row) { return row; }), - makeFlatVector( - numRowsProbe, [&](auto /*row*/) { return i % 2 == 0 ? 0 : i; }), - }); - probeVectors.push_back(rowVector); - tempFiles.push_back(TempFilePath::create()); - writeToFile(tempFiles.back()->getPath(), rowVector); - } - - auto makeInputSplits = [&](const core::PlanNodeId& nodeId) { - return [&] { - std::vector probeSplits; - for (auto& file : tempFiles) { - probeSplits.push_back( - exec::Split(makeHiveConnectorSplit(file->getPath()))); - } - // We add splits that have no rows. - auto makeEmpty = [&]() { - return exec::Split( - HiveConnectorSplitBuilder(tempFiles.back()->getPath()) - .start(10000000) - .length(1) - .build()); - }; - std::vector emptyFront = {makeEmpty(), makeEmpty()}; - std::vector emptyMiddle = {makeEmpty(), makeEmpty()}; - probeSplits.insert( - probeSplits.begin(), emptyFront.begin(), emptyFront.end()); - probeSplits.insert( - probeSplits.begin() + 13, emptyMiddle.begin(), emptyMiddle.end()); - SplitInput splits; - splits.emplace(nodeId, probeSplits); - return splits; - }; - }; - - // 100 key values in [35, 233] range. - std::vector buildVectors; - for (int i = 0; i < 5; ++i) { - buildVectors.push_back(makeRowVector({ - makeFlatVector( - numRowsBuild / 5, - [i](auto row) { return 35 + 2 * (row + i * numRowsBuild / 5); }), - makeFlatVector(numRowsBuild / 5, [](auto row) { return row; }), - })); - } - std::vector keyOnlyBuildVectors; - for (int i = 0; i < 5; ++i) { - keyOnlyBuildVectors.push_back( - makeRowVector({makeFlatVector(numRowsBuild / 5, [i](auto row) { - return 35 + 2 * (row + i * numRowsBuild / 5); - })})); - } - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", buildVectors); - - auto probeType = ROW({"c0", "c1", "c2"}, {INTEGER(), BIGINT(), BIGINT()}); - - auto planNodeIdGenerator = std::make_shared(); - - auto buildSide = PlanBuilder(planNodeIdGenerator, pool_.get()) - .values(buildVectors) - .project({"c0 AS u_c0", "c1 AS u_c1"}) - .planNode(); - auto keyOnlyBuildSide = PlanBuilder(planNodeIdGenerator, pool_.get()) - .values(keyOnlyBuildVectors) - .project({"c0 AS u_c0"}) - .planNode(); - - // Basic push-down. - { - // Inner join. - core::PlanNodeId probeScanId; - auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) - .tableScan(probeType, {"c2 > 0"}) - .capturePlanNodeId(probeScanId) - .hashJoin( - {"c0"}, - {"u_c0"}, - buildSide, - "", - {"c0", "c1", "u_c1"}, - core::JoinType::kInner) - .project({"c0", "c1 + 1", "c1 + u_c1"}) - .planNode(); - { - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .numDrivers(1) - .makeInputSplits(makeInputSplits(probeScanId)) - .referenceQuery( - "SELECT t.c0, t.c1 + 1, t.c1 + u.c1 FROM t, u WHERE t.c0 = u.c0 AND t.c2 > 0") - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); - if (hasSpill) { - // Dynamic filtering should be disabled with spilling triggered. - ASSERT_EQ(0, getFiltersProduced(task, 1).sum); - ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); - ASSERT_EQ( - getInputPositions(task, 1), - numRowsProbe * numNonSkippedSplits); - } else { - ASSERT_EQ(1, getFiltersProduced(task, 1).sum); - ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(0, getReplacedWithFilterRows(task, 1).sum); - ASSERT_LT( - getInputPositions(task, 1), - numRowsProbe * numNonSkippedSplits); - } - }) - .run(); - } - - // Left semi join. - op = PlanBuilder(planNodeIdGenerator, pool_.get()) - .tableScan(probeType, {"c2 > 0"}) - .capturePlanNodeId(probeScanId) - .hashJoin( - {"c0"}, - {"u_c0"}, - buildSide, - "", - {"c0", "c1"}, - core::JoinType::kLeftSemiFilter) - .project({"c0", "c1 + 1"}) - .planNode(); - - { - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .numDrivers(1) - .makeInputSplits(makeInputSplits(probeScanId)) - .referenceQuery( - "SELECT t.c0, t.c1 + 1 FROM t WHERE t.c0 IN (SELECT c0 FROM u) AND t.c2 > 0") - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); - if (hasSpill) { - // Dynamic filtering should be disabled with spilling triggered. - ASSERT_EQ(0, getFiltersProduced(task, 1).sum); - ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(0, getReplacedWithFilterRows(task, 1).sum); - ASSERT_EQ( - getInputPositions(task, 1), - numRowsProbe * numNonSkippedSplits); - } else { - ASSERT_EQ(1, getFiltersProduced(task, 1).sum); - ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); - ASSERT_GT(getReplacedWithFilterRows(task, 1).sum, 0); - ASSERT_LT( - getInputPositions(task, 1), - numRowsProbe * numNonSkippedSplits); - } - }) - .run(); - } - - // Right semi join. - op = PlanBuilder(planNodeIdGenerator, pool_.get()) - .tableScan(probeType, {"c2 > 0"}) - .capturePlanNodeId(probeScanId) - .hashJoin( - {"c0"}, - {"u_c0"}, - buildSide, - "", - {"u_c0", "u_c1"}, - core::JoinType::kRightSemiFilter) - .project({"u_c0", "u_c1 + 1"}) - .planNode(); - - { - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .numDrivers(1) - .makeInputSplits(makeInputSplits(probeScanId)) - .referenceQuery( - "SELECT u.c0, u.c1 + 1 FROM u WHERE u.c0 IN (SELECT c0 FROM t WHERE t.c2 > 0)") - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); - if (hasSpill) { - // Dynamic filtering should be disabled with spilling triggered. - ASSERT_EQ(0, getFiltersProduced(task, 1).sum); - ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(getReplacedWithFilterRows(task, 1).sum, 0); - ASSERT_EQ( - getInputPositions(task, 1), - numRowsProbe * numNonSkippedSplits); - } else { - ASSERT_EQ(1, getFiltersProduced(task, 1).sum); - ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); - ASSERT_EQ(getReplacedWithFilterRows(task, 1).sum, 0); - ASSERT_LT( - getInputPositions(task, 1), - numRowsProbe * numNonSkippedSplits); - } - }) - .run(); - } - } -} - -TEST_F(HashJoinTest, dynamicFiltersAppliedToPreloadedSplits) { - vector_size_t size = 1000; - const int32_t numSplits = 5; - - std::vector probeVectors; - probeVectors.reserve(numSplits); - - // Prepare probe side table. - std::vector> tempFiles; - std::vector probeSplits; - for (int32_t i = 0; i < numSplits; ++i) { - auto rowVector = makeRowVector( - {"p0", "p1"}, - { - makeFlatVector( - size, [&](auto row) { return (row + 1) * (i + 1); }), - makeFlatVector(size, [&](auto /*row*/) { return i; }), - }); - probeVectors.push_back(rowVector); - tempFiles.push_back(TempFilePath::create()); - writeToFile(tempFiles.back()->getPath(), rowVector); - auto split = HiveConnectorSplitBuilder(tempFiles.back()->getPath()) - .partitionKey("p1", std::to_string(i)) - .build(); - probeSplits.push_back(exec::Split(split)); - } - - auto outputType = ROW({"p0", "p1"}, {BIGINT(), BIGINT()}); - connector::ColumnHandleMap assignments = { - {"p0", regularColumn("p0", BIGINT())}, - {"p1", partitionKey("p1", BIGINT())}}; - createDuckDbTable("p", probeVectors); - - // Prepare build side table. - std::vector buildVectors{ - makeRowVector({"b0"}, {makeFlatVector({0, numSplits})})}; - createDuckDbTable("b", buildVectors); - - // Executing the join with p1=b0, we expect a dynamic filter for p1 to prune - // the entire file/split. There are total of five splits, and all except the - // first one are expected to be pruned. The result 'preloadedSplits' > 1 - // confirms the successful push of dynamic filters to the preloading data - // source. - core::PlanNodeId probeScanId; - core::PlanNodeId joinNodeId; - auto planNodeIdGenerator = std::make_shared(); - auto op = - PlanBuilder(planNodeIdGenerator) - .startTableScan() - .outputType(outputType) - .assignments(assignments) - .endTableScan() - .capturePlanNodeId(probeScanId) - .hashJoin( - {"p1"}, - {"b0"}, - PlanBuilder(planNodeIdGenerator).values(buildVectors).planNode(), - "", - {"p0"}, - core::JoinType::kInner) - .capturePlanNodeId(joinNodeId) - .project({"p0"}) - .planNode(); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .config(core::QueryConfig::kMaxSplitPreloadPerDriver, "3") - .injectSpill(false) - .inputSplits({{probeScanId, probeSplits}}) - .referenceQuery("select p.p0 from p, b where b.b0 = p.p1") - .checkSpillStats(false) - .verifier([&](const std::shared_ptr& task, bool /*hasSpill*/) { - auto planStats = toPlanStats(task->taskStats()); - auto getStatSum = [&](const core::PlanNodeId& id, - const std::string& name) { - return planStats.at(id).customStats.at(name).sum; - }; - ASSERT_EQ(1, getStatSum(joinNodeId, "dynamicFiltersProduced")); - ASSERT_EQ(1, getStatSum(probeScanId, "dynamicFiltersAccepted")); - ASSERT_EQ(4, getStatSum(probeScanId, "skippedSplits")); - ASSERT_LT(1, getStatSum(probeScanId, "preloadedSplits")); - }) - .run(); -} - -TEST_F(HashJoinTest, dynamicFiltersPushDownThroughAgg) { - const int32_t numRowsProbe = 300; - const int32_t numRowsBuild = 100; - - // Create probe data - std::vector probeVectors{makeRowVector({ - makeFlatVector(numRowsProbe, [&](auto row) { return row - 10; }), - makeFlatVector(numRowsProbe, folly::identity), - })}; - std::shared_ptr probeFile = TempFilePath::create(); - writeToFile(probeFile->getPath(), probeVectors); - - // Create build data - std::vector buildVectors{makeRowVector( - {"u0"}, {makeFlatVector(numRowsBuild, [&](auto row) { - return 35 + 2 * (row + numRowsBuild / 5); - })})}; - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", buildVectors); - - auto probeType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); - auto planNodeIdGenerator = std::make_shared(); - auto buildSide = - PlanBuilder(planNodeIdGenerator).values(buildVectors).planNode(); - - // Inner join. - core::PlanNodeId scanNodeId; - core::PlanNodeId joinNodeId; - core::PlanNodeId aggNodeId; - auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) - .tableScan(probeType) - .capturePlanNodeId(scanNodeId) - .partialAggregation({"c0"}, {"sum(c1)"}) - .capturePlanNodeId(aggNodeId) - .hashJoin( - {"c0"}, - {"u0"}, - buildSide, - "", - {"c0", "a0"}, - core::JoinType::kInner) - .capturePlanNodeId(joinNodeId) - .planNode(); - - SplitInput splitInput = { - {scanNodeId, {Split(makeHiveConnectorSplit(probeFile->getPath()))}}}; - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .inputSplits(splitInput) - .injectSpill(false) - .checkSpillStats(false) - .referenceQuery("SELECT c0, sum(c1) FROM t, u WHERE c0 = u0 group by c0") - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - auto planStats = toPlanStats(task->taskStats()); - auto dynamicFilterStats = planStats.at(scanNodeId).dynamicFilterStats; - ASSERT_EQ( - 1, getFiltersProduced(task, getOperatorIndex(joinNodeId)).sum); - ASSERT_EQ( - 1, getFiltersAccepted(task, getOperatorIndex(scanNodeId)).sum); - ASSERT_LT( - getInputPositions(task, getOperatorIndex(aggNodeId)), numRowsProbe); - ASSERT_EQ( - dynamicFilterStats.producerNodeIds, - std::unordered_set({joinNodeId})); - }) - .run(); -} - -TEST_F(HashJoinTest, noDynamicFiltersPushDownThroughRightJoin) { - std::vector innerBuild = {makeRowVector( - {"a"}, - { - makeFlatVector(5, [](auto i) { return 2 * i; }), - })}; - std::vector rightBuild = {makeRowVector( - {"b"}, - { - makeFlatVector(5, [](auto i) { return 1 + 2 * i; }), - })}; - std::vector rightProbe = {makeRowVector( - {"aa", "bb"}, - { - makeFlatVector(10, folly::identity), - makeFlatVector(10, folly::identity), - })}; - auto file = TempFilePath::create(); - writeToFile(file->getPath(), rightProbe); - auto planNodeIdGenerator = std::make_shared(); - core::PlanNodeId scanNodeId; - auto plan = - PlanBuilder(planNodeIdGenerator) - .tableScan(asRowType(rightProbe[0]->type())) - .capturePlanNodeId(scanNodeId) - .hashJoin( - {"bb"}, - {"b"}, - PlanBuilder(planNodeIdGenerator).values(rightBuild).planNode(), - "", - {"aa", "b"}, - core::JoinType::kRight) - .hashJoin( - {"aa"}, - {"a"}, - PlanBuilder(planNodeIdGenerator).values(innerBuild).planNode(), - "", - {"aa"}) - .planNode(); - AssertQueryBuilder(plan) - .split(scanNodeId, Split(makeHiveConnectorSplit(file->getPath()))) - .assertResults( - BaseVector::create(innerBuild[0]->type(), 0, pool_.get())); -} - -// Verify the size of the join output vectors when projecting build-side -// variable-width column. -TEST_F(HashJoinTest, memoryUsage) { - std::vector probeVectors = - makeBatches(10, [&](int32_t /*unused*/) { - return makeRowVector( - {makeFlatVector(1'000, [](auto row) { return row % 5; })}); - }); - std::vector buildVectors = - makeBatches(5, [&](int32_t /*unused*/) { - return makeRowVector( - {"u_c0", "u_c1"}, - {makeFlatVector({0, 1, 2}), - makeFlatVector({ - std::string(40, 'a'), - std::string(50, 'b'), - std::string(30, 'c'), - })}); - }); - core::PlanNodeId joinNodeId; - - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors) - .hashJoin( - {"c0"}, - {"u_c0"}, - PlanBuilder(planNodeIdGenerator) - .values({buildVectors}) - .planNode(), - "", - {"c0", "u_c1"}) - .capturePlanNodeId(joinNodeId) - .singleAggregation({}, {"count(1)"}) - .planNode(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(plan)) - .referenceQuery("SELECT 30000") - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - if (hasSpill) { - return; - } - auto planStats = toPlanStats(task->taskStats()); - auto outputBytes = planStats.at(joinNodeId).outputBytes; - ASSERT_LT(outputBytes, ((40 + 50 + 30) / 3 + 8) * 1000 * 10 * 5); - // Verify number of memory allocations. Should not be too high if - // hash join is able to re-use output vectors that contain - // build-side data. - ASSERT_GT(40, task->pool()->stats().numAllocs); - }) - .run(); -} - -/// Test an edge case in producing small output batches where the logic to -/// calculate the set of probe-side rows to load lazy vectors for was -/// triggering a crash. -TEST_F(HashJoinTest, smallOutputBatchSize) { - // Setup probe data with 50 non-null matching keys followed by 50 null - // keys: 1, 2, 1, 2,...null, null. - auto probeVectors = makeRowVector({ - makeFlatVector( - 100, - [](auto row) { return 1 + row % 2; }, - [](auto row) { return row > 50; }), - makeFlatVector(100, [](auto row) { return row * 10; }), - }); - - // Setup build side to match non-null probe side keys. - auto buildVectors = makeRowVector( - {"u_c0", "u_c1"}, - { - makeFlatVector({1, 2}), - makeFlatVector({100, 200}), - }); - - createDuckDbTable("t", {probeVectors}); - createDuckDbTable("u", {buildVectors}); - - // Plan hash inner join with a filter. - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values({probeVectors}) - .hashJoin( - {"c0"}, - {"u_c0"}, - PlanBuilder(planNodeIdGenerator) - .values({buildVectors}) - .planNode(), - "c1 < u_c1", - {"c0", "u_c1"}) - .planNode(); - - // Use small output batch size to trigger logic for calculating set of - // probe-side rows to load lazy vectors for. - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(plan)) - .config(core::QueryConfig::kPreferredOutputBatchRows, std::to_string(10)) - .referenceQuery("SELECT c0, u_c1 FROM t, u WHERE c0 = u_c0 AND c1 < u_c1") - .injectSpill(false) - .run(); -} - -TEST_F(HashJoinTest, spillFileSize) { - const std::vector maxSpillFileSizes({0, 1, 1'000'000'000}); - for (const auto spillFileSize : maxSpillFileSizes) { - SCOPED_TRACE(fmt::format("spillFileSize: {}", spillFileSize)); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .keyTypes({BIGINT()}) - .probeVectors(100, 3) - .buildVectors(100, 3) - .referenceQuery( - "SELECT t_k0, t_data, u_k0, u_data FROM t, u WHERE t.t_k0 = u.u_k0") - .config(core::QueryConfig::kSpillStartPartitionBit, "48") - .config(core::QueryConfig::kSpillNumPartitionBits, "3") - .config( - core::QueryConfig::kMaxSpillFileSize, std::to_string(spillFileSize)) - .checkSpillStats(false) - .maxSpillLevel(0) - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - if (!hasSpill) { - return; - } - const auto statsPair = taskSpilledStats(*task); - const int32_t numPartitions = statsPair.first.spilledPartitions; - ASSERT_EQ(statsPair.second.spilledPartitions, numPartitions); - const auto fileSizes = numTaskSpillFiles(*task); - if (spillFileSize != 1) { - ASSERT_EQ(fileSizes.first, numPartitions); - } else { - ASSERT_GT(fileSizes.first, numPartitions); - } - verifyTaskSpilledRuntimeStats(*task, true); - }) - .run(); - } -} - -TEST_F(HashJoinTest, spillPartitionBitsOverlap) { - auto builder = - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .keyTypes({BIGINT(), BIGINT()}) - .probeVectors(2'000, 3) - .buildVectors(2'000, 3) - .referenceQuery( - "SELECT t_k0, t_k1, t_data, u_k0, u_k1, u_data FROM t, u WHERE t_k0 = u_k0 and t_k1 = u_k1") - .config(core::QueryConfig::kSpillStartPartitionBit, "8") - .config(core::QueryConfig::kSpillNumPartitionBits, "1") - .checkSpillStats(false) - .maxSpillLevel(0); - VELOX_ASSERT_THROW(builder.run(), "vs. 8"); -} - -// The test is to verify if the hash build reservation has been released on -// task error. -DEBUG_ONLY_TEST_F(HashJoinTest, buildReservationReleaseCheck) { - std::vector probeVectors = - makeBatches(1, [&](int32_t /*unused*/) { - return std::dynamic_pointer_cast( - BatchMaker::createBatch(probeType_, 1000, *pool_)); - }); - std::vector buildVectors = makeBatches(10, [&](int32_t index) { - return std::dynamic_pointer_cast( - BatchMaker::createBatch(buildType_, 5000 * (1 + index), *pool_)); - }); - - auto planNodeIdGenerator = std::make_shared(); - CursorParameters params; - params.planNode = PlanBuilder(planNodeIdGenerator) - .values(probeVectors, true) - .hashJoin( - {"t_k1"}, - {"u_k1"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors, true) - .planNode(), - "", - concat(probeType_->names(), buildType_->names())) - .planNode(); - params.queryCtx = core::QueryCtx::create(driverExecutor_.get()); - // NOTE: the spilling setup is to trigger memory reservation code path which - // only gets executed when spilling is enabled. We don't care about if - // spilling is really triggered in test or not. - auto spillDirectory = exec::test::TempDirectoryPath::create(); - params.spillDirectory = spillDirectory->getPath(); - params.queryCtx->testingOverrideConfigUnsafe( - {{core::QueryConfig::kSpillEnabled, "true"}, - {core::QueryConfig::kMaxSpillLevel, "0"}}); - params.maxDrivers = 1; - - auto cursor = TaskCursor::create(params); - auto* task = cursor->task().get(); - - // Set up a testvalue to trigger task abort when hash build tries to reserve - // memory. - SCOPED_TESTVALUE_SET( - "facebook::velox::common::memory::MemoryPoolImpl::maybeReserve", - std::function( - [&](memory::MemoryPool* /*unused*/) { task->requestAbort(); })); - auto runTask = [&]() { - while (cursor->moveNext()) { - } - }; - VELOX_ASSERT_THROW(runTask(), ""); - ASSERT_TRUE(waitForTaskAborted(task, 5'000'000)); -} - -TEST_F(HashJoinTest, dynamicFilterOnPartitionKey) { - vector_size_t size = 10; - auto filePaths = makeFilePaths(1); - auto rowVector = makeRowVector( - {makeFlatVector(size, [&](auto row) { return row; })}); - createDuckDbTable("u", {rowVector}); - writeToFile(filePaths[0]->getPath(), rowVector); - std::vector buildVectors{ - makeRowVector({"c0"}, {makeFlatVector({0, 1, 2})})}; - createDuckDbTable("t", buildVectors); - auto split = facebook::velox::exec::test::HiveConnectorSplitBuilder( - filePaths[0]->getPath()) - .partitionKey("k", "0") - .build(); - auto outputType = ROW({"n1_0", "n1_1"}, {BIGINT(), BIGINT()}); - connector::ColumnHandleMap assignments = { - {"n1_0", regularColumn("c0", BIGINT())}, - {"n1_1", partitionKey("k", BIGINT())}}; - - core::PlanNodeId probeScanId; - auto planNodeIdGenerator = std::make_shared(); - auto op = - PlanBuilder(planNodeIdGenerator) - .startTableScan() - .outputType(outputType) - .assignments(assignments) - .endTableScan() - .capturePlanNodeId(probeScanId) - .hashJoin( - {"n1_1"}, - {"c0"}, - PlanBuilder(planNodeIdGenerator).values(buildVectors).planNode(), - "", - {"c0"}, - core::JoinType::kInner) - .project({"c0"}) - .planNode(); - SplitInput splits = {{probeScanId, {exec::Split(split)}}}; - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .inputSplits(splits) - .referenceQuery("select t.c0 from t, u where t.c0 = 0") - .checkSpillStats(false) - .run(); -} - -TEST_F(HashJoinTest, probeMemoryLimitOnBuildProjection) { - const uint64_t numBuildRows = 20; - std::vector probeVectors = - makeBatches(10, [&](int32_t /*unused*/) { - return makeRowVector({makeFlatVector( - 1'000, [](auto row) { return row % 25; })}); - }); - - std::vector buildVectors = - makeBatches(1, [&](int32_t /*unused*/) { - return makeRowVector( - {"u_c0", "u_c1", "u_c2", "u_c3", "u_c4"}, - {makeFlatVector( - numBuildRows, [](auto row) { return row; }), - makeFlatVector( - numBuildRows, - [](auto /* row */) { return std::string(4096, 'a'); }), - makeFlatVector( - numBuildRows, - [](auto /* row */) { return std::string(4096, 'a'); }), - makeFlatVector( - numBuildRows, - [](auto row) { - // Row that has too large of size variation. - if (row == 0) { - return std::string(4096, 'a'); - } else { - return std::string(1, 'a'); - } - }), - makeFlatVector(numBuildRows, [](auto row) { - // Row that has tolerable size variation. - if (row == 0) { - return std::string(4096, 'a'); - } else { - return std::string(256, 'a'); - } - })}); - }); - - createDuckDbTable("t", {probeVectors}); - createDuckDbTable("u", {buildVectors}); - - struct TestParam { - std::vector varSizeColumns; - int32_t numExpectedBatches; - std::string referenceQuery; - std::string debugString() const { - std::stringstream ss; - ss << "varSizeColumns ["; - for (const auto& columnIndex : varSizeColumns) { - ss << columnIndex << ", "; - } - ss << "] "; - ss << "numExpectedBatches " << numExpectedBatches << ", referenceQuery '" - << referenceQuery << "'"; - return ss.str(); - } - }; - - std::vector testParams{ - {{}, 10, "SELECT t.c0 FROM t JOIN u ON t.c0 = u.u_c0"}, - {{1}, 4000, "SELECT t.c0, u.u_c1 FROM t JOIN u ON t.c0 = u.u_c0"}, - {{1, 2}, - 8000, - "SELECT t.c0, u.u_c1, u.u_c2 FROM t JOIN u ON t.c0 = u.u_c0"}, - {{3}, 210, "SELECT t.c0, u.u_c3 FROM t JOIN u ON t.c0 = u.u_c0"}, - {{4}, 2670, "SELECT t.c0, u.u_c4 FROM t JOIN u ON t.c0 = u.u_c0"}}; - - for (const auto& testParam : testParams) { - SCOPED_TRACE(testParam.debugString()); - core::PlanNodeId joinNodeId; - std::vector outputLayout; - outputLayout.push_back("c0"); - for (int32_t i = 0; i < testParam.varSizeColumns.size(); i++) { - outputLayout.push_back(fmt::format("u_c{}", testParam.varSizeColumns[i])); - } - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors) - .hashJoin( - {"c0"}, - {"u_c0"}, - PlanBuilder(planNodeIdGenerator) - .values({buildVectors}) - .planNode(), - "", - outputLayout) - .capturePlanNodeId(joinNodeId) - .planNode(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(plan)) - .config(core::QueryConfig::kPreferredOutputBatchBytes, "8192") - .injectSpill(false) - .referenceQuery(testParam.referenceQuery) - .verifier([&](const std::shared_ptr& task, bool /* unused */) { - auto planStats = toPlanStats(task->taskStats()); - auto outputBatches = planStats.at(joinNodeId).outputVectors; - ASSERT_EQ(outputBatches, testParam.numExpectedBatches); - }) - .run(); - } -} - -DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringInputProcessing) { - constexpr int64_t kMaxBytes = 1LL << 30; // 1GB - VectorFuzzer fuzzer({.vectorSize = 1000}, pool()); - const int32_t numBuildVectors = 10; - std::vector buildVectors; - for (int32_t i = 0; i < numBuildVectors; ++i) { - buildVectors.push_back(fuzzer.fuzzRow(buildType_)); - } - const int32_t numProbeVectors = 5; - std::vector probeVectors; - for (int32_t i = 0; i < numProbeVectors; ++i) { - probeVectors.push_back(fuzzer.fuzzRow(probeType_)); - } - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", buildVectors); - - struct { - // 0: trigger reclaim with some input processed. - // 1: trigger reclaim after all the inputs processed. - int triggerCondition; - bool spillEnabled; - bool expectedReclaimable; - - std::string debugString() const { - return fmt::format( - "triggerCondition {}, spillEnabled {}, expectedReclaimable {}", - triggerCondition, - spillEnabled, - expectedReclaimable); - } - } testSettings[] = { - {0, true, true}, {0, true, true}, {0, false, false}, {0, false, false}}; - for (const auto& testData : testSettings) { - SCOPED_TRACE(testData.debugString()); - - auto tempDirectory = exec::test::TempDirectoryPath::create(); - auto queryPool = memory::memoryManager()->addRootPool( - "", kMaxBytes, memory::MemoryReclaimer::create()); - - core::PlanNodeId probeScanId; - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors, false) - .hashJoin( - {"t_k1"}, - {"u_k1"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors, false) - .planNode(), - "", - concat(probeType_->names(), buildType_->names())) - .planNode(); - - folly::EventCount driverWait; - auto driverWaitKey = driverWait.prepareWait(); - folly::EventCount testWait; - auto testWaitKey = testWait.prepareWait(); - - std::atomic numInputs{0}; - Operator* op; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Driver::runInternal::addInput", - std::function(([&](Operator* testOp) { - if (testOp->operatorType() != "HashBuild") { - return; - } - op = testOp; - ++numInputs; - if (testData.triggerCondition == 0) { - if (numInputs != 2) { - return; - } - } - if (testData.triggerCondition == 1) { - if (numInputs != numBuildVectors) { - return; - } - } - ASSERT_EQ(op->canReclaim(), testData.expectedReclaimable); - uint64_t reclaimableBytes{0}; - const bool reclaimable = op->reclaimableBytes(reclaimableBytes); - ASSERT_EQ(reclaimable, testData.expectedReclaimable); - if (testData.expectedReclaimable) { - ASSERT_GT(reclaimableBytes, 0); - } else { - ASSERT_EQ(reclaimableBytes, 0); - } - testWait.notify(); - driverWait.wait(driverWaitKey); - }))); - - std::thread taskThread([&]() { - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .planNode(plan) - .queryPool(std::move(queryPool)) - .injectSpill(false) - .spillDirectory(testData.spillEnabled ? tempDirectory->getPath() : "") - .referenceQuery( - "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") - .config(core::QueryConfig::kSpillStartPartitionBit, "29") - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - const auto statsPair = taskSpilledStats(*task); - if (testData.expectedReclaimable) { - ASSERT_GT(statsPair.first.spilledBytes, 0); - ASSERT_EQ(statsPair.first.spilledPartitions, 8); - ASSERT_GT(statsPair.second.spilledBytes, 0); - ASSERT_EQ(statsPair.second.spilledPartitions, 8); - verifyTaskSpilledRuntimeStats(*task, true); - } else { - ASSERT_EQ(statsPair.first.spilledBytes, 0); - ASSERT_EQ(statsPair.first.spilledPartitions, 0); - ASSERT_EQ(statsPair.second.spilledBytes, 0); - ASSERT_EQ(statsPair.second.spilledPartitions, 0); - verifyTaskSpilledRuntimeStats(*task, false); - } - }) - .run(); - }); - - testWait.wait(testWaitKey); - ASSERT_TRUE(op != nullptr); - auto task = op->operatorCtx()->task(); - auto taskPauseWait = task->requestPause(); - driverWait.notify(); - taskPauseWait.wait(); - - uint64_t reclaimableBytes{0}; - const bool reclaimable = op->reclaimableBytes(reclaimableBytes); - ASSERT_EQ(op->canReclaim(), testData.expectedReclaimable); - ASSERT_EQ(reclaimable, testData.expectedReclaimable); - if (testData.expectedReclaimable) { - ASSERT_GT(reclaimableBytes, 0); - } else { - ASSERT_EQ(reclaimableBytes, 0); - } - - if (testData.expectedReclaimable) { - { - memory::ScopedMemoryArbitrationContext ctx(op->pool()); - op->pool()->reclaim( - folly::Random::oneIn(2) ? 0 : folly::Random::rand32(), - 0, - reclaimerStats_); - } - ASSERT_GT(reclaimerStats_.reclaimExecTimeUs, 0); - ASSERT_GT(reclaimerStats_.reclaimedBytes, 0); - reclaimerStats_.reset(); - ASSERT_EQ(op->pool()->usedBytes(), 0); - } else { - VELOX_ASSERT_THROW( - op->reclaim( - folly::Random::oneIn(2) ? 0 : folly::Random::rand32(), - reclaimerStats_), - ""); - } - - Task::resume(task); - task.reset(); - - taskThread.join(); - } - ASSERT_EQ(reclaimerStats_, memory::MemoryReclaimer::Stats{}); -} - -DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringReserve) { - constexpr int64_t kMaxBytes = 1LL << 30; // 1GB - const int32_t numBuildVectors = 3; - std::vector buildVectors; - for (int32_t i = 0; i < numBuildVectors; ++i) { - const size_t size = i == 0 ? 1 : 1'000; - VectorFuzzer fuzzer({.vectorSize = size}, pool()); - buildVectors.push_back(fuzzer.fuzzRow(buildType_)); - } - - const int32_t numProbeVectors = 3; - std::vector probeVectors; - for (int32_t i = 0; i < numProbeVectors; ++i) { - VectorFuzzer fuzzer({.vectorSize = 1'000}, pool()); - probeVectors.push_back(fuzzer.fuzzRow(probeType_)); - } - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", buildVectors); - - auto tempDirectory = exec::test::TempDirectoryPath::create(); - auto queryPool = memory::memoryManager()->addRootPool( - "", kMaxBytes, memory::MemoryReclaimer::create()); - - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors, false) - .hashJoin( - {"t_k1"}, - {"u_k1"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors, false) - .planNode(), - "", - concat(probeType_->names(), buildType_->names())) - .planNode(); - - folly::EventCount driverWait; - std::atomic_bool driverWaitFlag{true}; - folly::EventCount testWait; - std::atomic_bool testWaitFlag{true}; - - Operator* op{nullptr}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Driver::runInternal::addInput", - std::function(([&](Operator* testOp) { - if (testOp->operatorType() != "HashBuild") { - return; - } - op = testOp; - }))); - - std::atomic injectOnce{true}; - SCOPED_TESTVALUE_SET( - "facebook::velox::common::memory::MemoryPoolImpl::maybeReserve", - std::function( - ([&](memory::MemoryPoolImpl* pool) { - ASSERT_TRUE(op != nullptr); - if (!isHashBuildMemoryPool(*pool)) { - return; - } - ASSERT_TRUE(op->canReclaim()); - if (op->pool()->usedBytes() == 0) { - // We skip trigger memory reclaim when the hash table is empty on - // memory reservation. - return; - } - if (!injectOnce.exchange(false)) { - return; - } - uint64_t reclaimableBytes{0}; - const bool reclaimable = op->reclaimableBytes(reclaimableBytes); - ASSERT_TRUE(reclaimable); - ASSERT_GT(reclaimableBytes, 0); - auto* driver = op->operatorCtx()->driver(); - TestSuspendedSection suspendedSection(driver); - testWaitFlag = false; - testWait.notifyAll(); - driverWait.await([&]() { return !driverWaitFlag.load(); }); - }))); - - std::thread taskThread([&]() { - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .planNode(plan) - .queryPool(std::move(queryPool)) - .injectSpill(false) - .spillDirectory(tempDirectory->getPath()) - .referenceQuery( - "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") - .config(core::QueryConfig::kSpillStartPartitionBit, "29") - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - const auto statsPair = taskSpilledStats(*task); - ASSERT_GT(statsPair.first.spilledBytes, 0); - ASSERT_EQ(statsPair.first.spilledPartitions, 8); - ASSERT_GT(statsPair.second.spilledBytes, 0); - ASSERT_EQ(statsPair.second.spilledPartitions, 8); - verifyTaskSpilledRuntimeStats(*task, true); - }) - .run(); - }); - - testWait.await([&]() { return !testWaitFlag.load(); }); - ASSERT_TRUE(op != nullptr); - auto task = op->operatorCtx()->task(); - task->requestPause().wait(); - - uint64_t reclaimableBytes{0}; - const bool reclaimable = op->reclaimableBytes(reclaimableBytes); - ASSERT_TRUE(op->canReclaim()); - ASSERT_TRUE(reclaimable); - ASSERT_GT(reclaimableBytes, 0); - - { - memory::ScopedMemoryArbitrationContext ctx(op->pool()); - uint64_t reclaimedBytes = task->pool()->reclaim( - folly::Random::oneIn(2) ? 0 : folly::Random::rand32(), - 0, - reclaimerStats_); - ASSERT_GT(reclaimedBytes, 0); - } - ASSERT_GT(reclaimerStats_.reclaimedBytes, 0); - ASSERT_GT(reclaimerStats_.reclaimExecTimeUs, 0); - ASSERT_EQ(op->pool()->usedBytes(), 0); - - driverWaitFlag = false; - driverWait.notifyAll(); - Task::resume(task); - task.reset(); - - taskThread.join(); -} - -DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringAllocation) { - constexpr int64_t kMaxBytes = 1LL << 30; // 1GB - VectorFuzzer fuzzer({.vectorSize = 1000}, pool()); - const int32_t numBuildVectors = 10; - std::vector buildVectors; - for (int32_t i = 0; i < numBuildVectors; ++i) { - buildVectors.push_back(fuzzer.fuzzRow(buildType_)); - } - const int32_t numProbeVectors = 5; - std::vector probeVectors; - for (int32_t i = 0; i < numProbeVectors; ++i) { - probeVectors.push_back(fuzzer.fuzzRow(probeType_)); - } - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", buildVectors); - - const std::vector enableSpillings = {false, true}; - for (const auto enableSpilling : enableSpillings) { - SCOPED_TRACE(fmt::format("enableSpilling {}", enableSpilling)); - - auto tempDirectory = exec::test::TempDirectoryPath::create(); - auto queryPool = memory::memoryManager()->addRootPool("", kMaxBytes); - - core::PlanNodeId probeScanId; - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors, false) - .hashJoin( - {"t_k1"}, - {"u_k1"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors, false) - .planNode(), - "", - concat(probeType_->names(), buildType_->names())) - .planNode(); - - folly::EventCount driverWait; - auto driverWaitKey = driverWait.prepareWait(); - folly::EventCount testWait; - auto testWaitKey = testWait.prepareWait(); - - Operator* op; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Driver::runInternal::addInput", - std::function(([&](Operator* testOp) { - if (testOp->operatorType() != "HashBuild") { - return; - } - op = testOp; - }))); - - std::atomic injectOnce{true}; - SCOPED_TESTVALUE_SET( - "facebook::velox::common::memory::MemoryPoolImpl::allocateNonContiguous", - std::function( - ([&](memory::MemoryPoolImpl* pool) { - ASSERT_TRUE(op != nullptr); - const std::string re(".*HashBuild"); - if (!RE2::FullMatch(pool->name(), re)) { - return; - } - if (!injectOnce.exchange(false)) { - return; - } - ASSERT_EQ(op->canReclaim(), enableSpilling); - uint64_t reclaimableBytes{0}; - const bool reclaimable = op->reclaimableBytes(reclaimableBytes); - ASSERT_EQ(reclaimable, enableSpilling); - if (enableSpilling) { - ASSERT_GE(reclaimableBytes, 0); - } else { - ASSERT_EQ(reclaimableBytes, 0); - } - auto* driver = op->operatorCtx()->driver(); - TestSuspendedSection suspendedSection(driver); - testWait.notify(); - driverWait.wait(driverWaitKey); - }))); - - std::thread taskThread([&]() { - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .planNode(plan) - .queryPool(std::move(queryPool)) - .injectSpill(false) - .spillDirectory(enableSpilling ? tempDirectory->getPath() : "") - .referenceQuery( - "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - const auto statsPair = taskSpilledStats(*task); - ASSERT_EQ(statsPair.first.spilledBytes, 0); - ASSERT_EQ(statsPair.first.spilledPartitions, 0); - ASSERT_EQ(statsPair.second.spilledBytes, 0); - ASSERT_EQ(statsPair.second.spilledPartitions, 0); - verifyTaskSpilledRuntimeStats(*task, false); - }) - .run(); - }); - - testWait.wait(testWaitKey); - ASSERT_TRUE(op != nullptr); - auto task = op->operatorCtx()->task(); - auto taskPauseWait = task->requestPause(); - taskPauseWait.wait(); - - uint64_t reclaimableBytes{0}; - const bool reclaimable = op->reclaimableBytes(reclaimableBytes); - ASSERT_EQ(op->canReclaim(), enableSpilling); - ASSERT_EQ(reclaimable, enableSpilling); - if (enableSpilling) { - ASSERT_GE(reclaimableBytes, 0); - } else { - ASSERT_EQ(reclaimableBytes, 0); - } - VELOX_ASSERT_THROW( - op->reclaim( - folly::Random::oneIn(2) ? 0 : folly::Random::rand32(), - reclaimerStats_), - ""); - - driverWait.notify(); - Task::resume(task); - task.reset(); - - taskThread.join(); - } - ASSERT_EQ(reclaimerStats_, memory::MemoryReclaimer::Stats{0}); -} - -DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringOutputProcessing) { - constexpr int64_t kMaxBytes = 1LL << 30; // 1GB - VectorFuzzer fuzzer({.vectorSize = 1000}, pool()); - const int32_t numBuildVectors = 10; - std::vector buildVectors; - for (int32_t i = 0; i < numBuildVectors; ++i) { - buildVectors.push_back(fuzzer.fuzzRow(buildType_)); - } - const int32_t numProbeVectors = 5; - std::vector probeVectors; - for (int32_t i = 0; i < numProbeVectors; ++i) { - probeVectors.push_back(fuzzer.fuzzRow(probeType_)); - } - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", buildVectors); - - const std::vector enableSpillings = {false, true}; - for (const auto enableSpilling : enableSpillings) { - SCOPED_TRACE(fmt::format("enableSpilling {}", enableSpilling)); - auto tempDirectory = exec::test::TempDirectoryPath::create(); - auto queryPool = memory::memoryManager()->addRootPool( - "", kMaxBytes, memory::MemoryReclaimer::create()); - - core::PlanNodeId probeScanId; - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors, false) - .hashJoin( - {"t_k1"}, - {"u_k1"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors, false) - .planNode(), - "", - concat(probeType_->names(), buildType_->names())) - .planNode(); - - std::atomic_bool driverWaitFlag{true}; - folly::EventCount driverWait; - std::atomic_bool testWaitFlag{true}; - folly::EventCount testWait; - - std::atomic injectOnce{true}; - Operator* op; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Driver::runInternal::noMoreInput", - std::function(([&](Operator* testOp) { - if (testOp->operatorType() != "HashBuild") { - return; - } - op = testOp; - if (!injectOnce.exchange(false)) { - return; - } - ASSERT_EQ(op->canReclaim(), enableSpilling); - uint64_t reclaimableBytes{0}; - const bool reclaimable = op->reclaimableBytes(reclaimableBytes); - ASSERT_EQ(reclaimable, enableSpilling); - if (enableSpilling) { - ASSERT_GT(reclaimableBytes, 0); - } else { - ASSERT_EQ(reclaimableBytes, 0); - } - testWaitFlag = false; - testWait.notifyAll(); - driverWait.await([&]() { return !driverWaitFlag.load(); }); - }))); - - std::thread taskThread([&]() { - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .planNode(plan) - .queryPool(std::move(queryPool)) - .injectSpill(false) - .spillDirectory(enableSpilling ? tempDirectory->getPath() : "") - .referenceQuery( - "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - const auto statsPair = taskSpilledStats(*task); - ASSERT_EQ(statsPair.first.spilledBytes, 0); - ASSERT_EQ(statsPair.first.spilledPartitions, 0); - ASSERT_EQ(statsPair.second.spilledBytes, 0); - ASSERT_EQ(statsPair.second.spilledPartitions, 0); - verifyTaskSpilledRuntimeStats(*task, false); - }) - .run(); - }); - - testWait.await([&]() { return !testWaitFlag.load(); }); - ASSERT_TRUE(op != nullptr); - auto task = op->operatorCtx()->task(); - auto taskPauseWait = task->requestPause(); - driverWaitFlag = false; - driverWait.notifyAll(); - taskPauseWait.wait(); - - uint64_t reclaimableBytes{0}; - const bool reclaimable = op->reclaimableBytes(reclaimableBytes); - ASSERT_EQ(op->canReclaim(), enableSpilling); - ASSERT_EQ(reclaimable, enableSpilling); - - if (enableSpilling) { - ASSERT_GT(reclaimableBytes, 0); - const auto usedMemoryBytes = op->pool()->usedBytes(); - { - memory::ScopedMemoryArbitrationContext ctx(op->pool()); - op->pool()->reclaim( - folly::Random::oneIn(2) ? 0 : folly::Random::rand32(), - 0, - reclaimerStats_); - } - ASSERT_GE(reclaimerStats_.reclaimedBytes, 0); - ASSERT_GT(reclaimerStats_.reclaimExecTimeUs, 0); - // No reclaim as the operator has started output processing. - ASSERT_EQ(usedMemoryBytes, op->pool()->usedBytes()); - } else { - ASSERT_EQ(reclaimableBytes, 0); - VELOX_ASSERT_THROW( - op->reclaim( - folly::Random::oneIn(2) ? 0 : folly::Random::rand32(), - reclaimerStats_), - ""); - } - - Task::resume(task); - task.reset(); - - taskThread.join(); - } - ASSERT_EQ(reclaimerStats_.numNonReclaimableAttempts, 1); -} - -DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringWaitForProbe) { - constexpr int64_t kMaxBytes = 1LL << 30; // 1GB - VectorFuzzer fuzzer({.vectorSize = 1000}, pool()); - const int32_t numBuildVectors = 10; - std::vector buildVectors; - for (int32_t i = 0; i < numBuildVectors; ++i) { - buildVectors.push_back(fuzzer.fuzzRow(buildType_)); - } - const int32_t numProbeVectors = 5; - std::vector probeVectors; - for (int32_t i = 0; i < numProbeVectors; ++i) { - probeVectors.push_back(fuzzer.fuzzRow(probeType_)); - } - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", buildVectors); - - auto tempDirectory = exec::test::TempDirectoryPath::create(); - auto queryPool = memory::memoryManager()->addRootPool( - "", kMaxBytes, memory::MemoryReclaimer::create()); - - core::PlanNodeId probeScanId; - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors, false) - .hashJoin( - {"t_k1"}, - {"u_k1"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors, false) - .planNode(), - "", - concat(probeType_->names(), buildType_->names())) - .planNode(); - - std::atomic_bool driverWaitFlag{true}; - folly::EventCount driverWait; - std::atomic_bool testWaitFlag{true}; - folly::EventCount testWait; - - Operator* op{nullptr}; - std::atomic injectSpillOnce{true}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::HashBuild::finishHashBuild", - std::function(([&](Operator* testOp) { - if (testOp->operatorType() != "HashBuild") { - return; - } - op = testOp; - if (!injectSpillOnce.exchange(false)) { - return; - } - auto* driver = op->operatorCtx()->driver(); - auto task = driver->task(); - memory::ScopedMemoryArbitrationContext ctx(op->pool()); - Operator::ReclaimableSectionGuard guard(testOp); - testingRunArbitration(testOp->pool()); - }))); - - std::atomic injectOnce{true}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Driver::runInternal::noMoreInput", - std::function(([&](Operator* testOp) { - if (testOp->operatorType() != "HashProbe") { - return; - } - if (!injectOnce.exchange(false)) { - return; - } - ASSERT_TRUE(op != nullptr); - ASSERT_TRUE(op->canReclaim()); - uint64_t reclaimableBytes{0}; - const bool reclaimable = op->reclaimableBytes(reclaimableBytes); - ASSERT_TRUE(reclaimable); - ASSERT_GT(reclaimableBytes, 0); - testWaitFlag = false; - testWait.notifyAll(); - auto* driver = testOp->operatorCtx()->driver(); - auto task = driver->task(); - TestSuspendedSection suspendedSection(driver); - driverWait.await([&]() { return !driverWaitFlag.load(); }); - }))); - - std::thread taskThread([&]() { - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .planNode(plan) - .queryPool(std::move(queryPool)) - .injectSpill(false) - .spillDirectory(tempDirectory->getPath()) - .referenceQuery( - "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") - .config(core::QueryConfig::kSpillStartPartitionBit, "29") - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - const auto statsPair = taskSpilledStats(*task); - ASSERT_GT(statsPair.first.spilledBytes, 0); - ASSERT_EQ(statsPair.first.spilledPartitions, 8); - ASSERT_GT(statsPair.second.spilledBytes, 0); - ASSERT_EQ(statsPair.second.spilledPartitions, 8); - }) - .run(); - }); - - testWait.await([&]() { return !testWaitFlag.load(); }); - ASSERT_TRUE(op != nullptr); - auto task = op->operatorCtx()->task(); - auto taskPauseWait = task->requestPause(); - taskPauseWait.wait(); - - uint64_t reclaimableBytes{0}; - const bool reclaimable = op->reclaimableBytes(reclaimableBytes); - ASSERT_TRUE(op->canReclaim()); - ASSERT_TRUE(reclaimable); - ASSERT_GT(reclaimableBytes, 0); - - const auto usedMemoryBytes = op->pool()->usedBytes(); - reclaimerStats_.reset(); - { - memory::ScopedMemoryArbitrationContext ctx(op->pool()); - op->pool()->reclaim( - folly::Random::oneIn(2) ? 0 : folly::Random::rand32(), - 0, - reclaimerStats_); - } - ASSERT_GE(reclaimerStats_.reclaimedBytes, 0); - ASSERT_GT(reclaimerStats_.reclaimExecTimeUs, 0); - // No reclaim as the build operator is not in building table state. - ASSERT_EQ(usedMemoryBytes, op->pool()->usedBytes()); - - driverWaitFlag = false; - driverWait.notifyAll(); - Task::resume(task); - task.reset(); - - taskThread.join(); - ASSERT_EQ(reclaimerStats_.numNonReclaimableAttempts, 1); -} - -DEBUG_ONLY_TEST_F(HashJoinTest, hashBuildAbortDuringOutputProcessing) { - const auto buildVectors = makeVectors(buildType_, 10, 128); - const auto probeVectors = makeVectors(probeType_, 5, 128); - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", buildVectors); - - struct { - bool abortFromRootMemoryPool; - int numDrivers; - - std::string debugString() const { - return fmt::format( - "abortFromRootMemoryPool {} numDrivers {}", - abortFromRootMemoryPool, - numDrivers); - } - } testSettings[] = {{true, 1}, {false, 1}, {true, 4}, {false, 4}}; - - for (const auto& testData : testSettings) { - SCOPED_TRACE(testData.debugString()); - - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors, true) - .hashJoin( - {"t_k1"}, - {"u_k1"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors, true) - .planNode(), - "", - concat(probeType_->names(), buildType_->names())) - .planNode(); - - std::atomic injectOnce{true}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Driver::runInternal::noMoreInput", - std::function(([&](Operator* op) { - if (op->operatorType() != "HashBuild") { - return; - } - if (!injectOnce.exchange(false)) { - return; - } - ASSERT_GT(op->pool()->usedBytes(), 0); - auto* driver = op->operatorCtx()->driver(); - ASSERT_EQ( - driver->task()->enterSuspended(driver->state()), - StopReason::kNone); - testData.abortFromRootMemoryPool ? abortPool(op->pool()->root()) - : abortPool(op->pool()); - // We can't directly reclaim memory from this hash build operator as - // its driver thread is running and in suspension state. - ASSERT_GT(op->pool()->root()->usedBytes(), 0); - ASSERT_EQ( - driver->task()->leaveSuspended(driver->state()), - StopReason::kAlreadyTerminated); - ASSERT_TRUE(op->pool()->aborted()); - ASSERT_TRUE(op->pool()->root()->aborted()); - VELOX_MEM_POOL_ABORTED("Memory pool aborted"); - }))); - - VELOX_ASSERT_THROW( - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .planNode(plan) - .injectSpill(false) - .referenceQuery( - "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") - .run(), - "Manual MemoryPool Abortion"); - waitForAllTasksToBeDeleted(); - } -} - -DEBUG_ONLY_TEST_F(HashJoinTest, hashBuildAbortDuringInputProcessing) { - const auto buildVectors = makeVectors(buildType_, 10, 128); - const auto probeVectors = makeVectors(probeType_, 5, 128); - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", buildVectors); - - struct { - bool abortFromRootMemoryPool; - int numDrivers; - - std::string debugString() const { - return fmt::format( - "abortFromRootMemoryPool {} numDrivers {}", - abortFromRootMemoryPool, - numDrivers); - } - } testSettings[] = {{true, 1}, {false, 1}, {true, 4}, {false, 4}}; - - for (const auto& testData : testSettings) { - SCOPED_TRACE(testData.debugString()); - - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors, true) - .hashJoin( - {"t_k1"}, - {"u_k1"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors, true) - .planNode(), - "", - concat(probeType_->names(), buildType_->names())) - .planNode(); - - std::atomic numInputs{0}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Driver::runInternal::addInput", - std::function(([&](Operator* op) { - if (op->operatorType() != "HashBuild") { - return; - } - if (++numInputs != 2) { - return; - } - ASSERT_GT(op->pool()->usedBytes(), 0); - auto* driver = op->operatorCtx()->driver(); - ASSERT_EQ( - driver->task()->enterSuspended(driver->state()), - StopReason::kNone); - testData.abortFromRootMemoryPool ? abortPool(op->pool()->root()) - : abortPool(op->pool()); - // We can't directly reclaim memory from this hash build operator as - // its driver thread is running and in suspension state. - ASSERT_GT(op->pool()->root()->usedBytes(), 0); - ASSERT_EQ( - driver->task()->leaveSuspended(driver->state()), - StopReason::kAlreadyTerminated); - ASSERT_TRUE(op->pool()->aborted()); - ASSERT_TRUE(op->pool()->root()->aborted()); - VELOX_MEM_POOL_ABORTED("Memory pool aborted"); - }))); - - VELOX_ASSERT_THROW( - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .planNode(plan) - .injectSpill(false) - .referenceQuery( - "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") - .run(), - "Manual MemoryPool Abortion"); - - waitForAllTasksToBeDeleted(); - } -} - -DEBUG_ONLY_TEST_F(HashJoinTest, hashBuildAbortDuringAllocation) { - const auto buildVectors = makeVectors(buildType_, 10, 128); - const auto probeVectors = makeVectors(probeType_, 5, 128); - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", buildVectors); - - struct { - bool abortFromRootMemoryPool; - int numDrivers; - - std::string debugString() const { - return fmt::format( - "abortFromRootMemoryPool {} numDrivers {}", - abortFromRootMemoryPool, - numDrivers); - } - } testSettings[] = {{true, 1}, {false, 1}, {true, 4}, {false, 4}}; - - for (const auto& testData : testSettings) { - SCOPED_TRACE(testData.debugString()); - - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors, true) - .hashJoin( - {"t_k1"}, - {"u_k1"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors, true) - .planNode(), - "", - concat(probeType_->names(), buildType_->names())) - .planNode(); - - std::atomic_bool injectOnce{true}; - SCOPED_TESTVALUE_SET( - "facebook::velox::common::memory::MemoryPoolImpl::allocateNonContiguous", - std::function( - ([&](memory::MemoryPoolImpl* pool) { - if (!isHashBuildMemoryPool(*pool)) { - return; - } - if (!injectOnce.exchange(false)) { - return; - } - - const auto* driverCtx = driverThreadContext()->driverCtx(); - ASSERT_EQ( - driverCtx->task->enterSuspended(driverCtx->driver->state()), - StopReason::kNone); - testData.abortFromRootMemoryPool ? abortPool(pool->root()) - : abortPool(pool); - // We can't directly reclaim memory from this hash build operator - // as its driver thread is running and in suspegnsion state. - ASSERT_GE(pool->root()->usedBytes(), 0); - ASSERT_EQ( - driverCtx->task->leaveSuspended(driverCtx->driver->state()), - StopReason::kAlreadyTerminated); - ASSERT_TRUE(pool->aborted()); - ASSERT_TRUE(pool->root()->aborted()); - VELOX_MEM_POOL_ABORTED("Memory pool aborted"); - }))); - - VELOX_ASSERT_THROW( - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .planNode(plan) - .injectSpill(false) - .referenceQuery( - "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") - .run(), - "Manual MemoryPool Abortion"); - - waitForAllTasksToBeDeleted(); - } -} - -DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeAbortDuringInputProcessing) { - const auto buildVectors = makeVectors(buildType_, 10, 128); - const auto probeVectors = makeVectors(probeType_, 5, 128); - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", buildVectors); - - struct { - bool abortFromRootMemoryPool; - int numDrivers; - - std::string debugString() const { - return fmt::format( - "abortFromRootMemoryPool {} numDrivers {}", - abortFromRootMemoryPool, - numDrivers); - } - } testSettings[] = {{true, 1}, {false, 1}, {true, 4}, {false, 4}}; - - for (const auto& testData : testSettings) { - SCOPED_TRACE(testData.debugString()); - - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors, true) - .hashJoin( - {"t_k1"}, - {"u_k1"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors, true) - .planNode(), - "", - concat(probeType_->names(), buildType_->names())) - .planNode(); - - std::atomic numInputs{0}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Driver::runInternal::addInput", - std::function(([&](Operator* op) { - if (op->operatorType() != "HashProbe") { - return; - } - if (++numInputs != 2) { - return; - } - auto* driver = op->operatorCtx()->driver(); - ASSERT_EQ( - driver->task()->enterSuspended(driver->state()), - StopReason::kNone); - testData.abortFromRootMemoryPool ? abortPool(op->pool()->root()) - : abortPool(op->pool()); - ASSERT_EQ( - driver->task()->leaveSuspended(driver->state()), - StopReason::kAlreadyTerminated); - ASSERT_TRUE(op->pool()->aborted()); - ASSERT_TRUE(op->pool()->root()->aborted()); - VELOX_MEM_POOL_ABORTED("Memory pool aborted"); - }))); - - VELOX_ASSERT_THROW( - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .planNode(plan) - .injectSpill(false) - .referenceQuery( - "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") - .run(), - "Manual MemoryPool Abortion"); - waitForAllTasksToBeDeleted(); - } -} - -TEST_F(HashJoinTest, leftJoinWithMissAtEndOfBatch) { - // Tests some cases where the row at the end of an output batch fails the - // filter. - auto probeVectors = std::vector{makeRowVector( - {"t_k1", "t_k2"}, - {makeFlatVector(20, [](auto row) { return 1 + row % 2; }), - makeFlatVector(20, [](auto row) { return row; })})}; - auto buildVectors = std::vector{ - makeRowVector({"u_k1"}, {makeFlatVector({1, 2})})}; - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", {buildVectors}); - auto planNodeIdGenerator = std::make_shared(); - - auto test = [&](const std::string& filter) { - auto plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors, true) - .hashJoin( - {"t_k1"}, - {"u_k1"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors, true) - .planNode(), - filter, - {"t_k1", "u_k1"}, - core::JoinType::kLeft) - .planNode(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .injectSpill(false) - .checkSpillStats(false) - .maxSpillLevel(0) - .numDrivers(1) - .config( - core::QueryConfig::kPreferredOutputBatchRows, std::to_string(10)) - .referenceQuery(fmt::format( - "SELECT t_k1, u_k1 from t left join u on t_k1 = u_k1 and {}", - filter)) - .run(); - }; - - // Alternate rows pass this filter and last row of a batch fails. - test("t_k1=1"); - - // All rows fail this filter. - test("t_k1=5"); - - // All rows in the second batch pass this filter. - test("t_k2 > 9"); -} - -TEST_F(HashJoinTest, leftJoinWithMissAtEndOfBatchMultipleBuildMatches) { - // Tests some cases where the row at the end of an output batch fails the - // filter and there are multiple matches with the build side.. - auto probeVectors = std::vector{makeRowVector( - {"t_k1", "t_k2"}, - {makeFlatVector(10, [](auto row) { return 1 + row % 2; }), - makeFlatVector(10, [](auto row) { return row; })})}; - auto buildVectors = std::vector{ - makeRowVector({"u_k1"}, {makeFlatVector({1, 2, 1, 2})})}; - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", {buildVectors}); - auto planNodeIdGenerator = std::make_shared(); - - auto test = [&](const std::string& filter) { - auto plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors, true) - .hashJoin( - {"t_k1"}, - {"u_k1"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors, true) - .planNode(), - filter, - {"t_k1", "u_k1"}, - core::JoinType::kLeft) - .planNode(); - - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .injectSpill(false) - .checkSpillStats(false) - .maxSpillLevel(0) - .numDrivers(1) - .config( - core::QueryConfig::kPreferredOutputBatchRows, std::to_string(10)) - .referenceQuery(fmt::format( - "SELECT t_k1, u_k1 from t left join u on t_k1 = u_k1 and {}", - filter)) - .run(); - }; - - // In this case the rows with t_k2 = 4 appear at the end of the first batch, - // meaning the last rows in that output batch are misses, and don't get added. - // The rows with t_k2 = 8 appear in the second batch so only one row is - // written, meaning there is space in the second output batch for the miss - // with tk_2 = 4 to get written. - test("t_k2 != 4 and t_k2 != 8"); -} - -TEST_F(HashJoinTest, leftJoinPreserveProbeOrder) { - const std::vector probeVectors = { - makeRowVector( - {"k1", "v1"}, - { - makeConstant(0, 2), - makeFlatVector({1, 0}), - }), - }; - const std::vector buildVectors = { - makeRowVector( - {"k2", "v2"}, - { - makeConstant(0, 2), - makeConstant(0, 2), - }), - }; - auto planNodeIdGenerator = std::make_shared(); - auto plan = - PlanBuilder(planNodeIdGenerator) - .values(probeVectors) - .hashJoin( - {"k1"}, - {"k2"}, - PlanBuilder(planNodeIdGenerator).values(buildVectors).planNode(), - "v1 % 2 = v2 % 2", - {"v1"}, - core::JoinType::kLeft) - .planNode(); - auto result = AssertQueryBuilder(plan) - .config(core::QueryConfig::kPreferredOutputBatchRows, "1") - .serialExecution(true) - .copyResults(pool_.get()); - ASSERT_EQ(result->size(), 3); - auto* v1 = - result->childAt(0)->loadedVector()->asUnchecked>(); - ASSERT_FALSE(v1->mayHaveNulls()); - ASSERT_EQ(v1->valueAt(0), 1); - ASSERT_EQ(v1->valueAt(1), 0); - ASSERT_EQ(v1->valueAt(2), 0); -} - -DEBUG_ONLY_TEST_F(HashJoinTest, minSpillableMemoryReservation) { - VectorFuzzer fuzzer({.vectorSize = 1000}, pool()); - const int32_t numBuildVectors = 10; - std::vector buildVectors; - for (int32_t i = 0; i < numBuildVectors; ++i) { - buildVectors.push_back(fuzzer.fuzzInputRow(buildType_)); - } - const int32_t numProbeVectors = 5; - std::vector probeVectors; - for (int32_t i = 0; i < numProbeVectors; ++i) { - probeVectors.push_back(fuzzer.fuzzInputRow(probeType_)); - } - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", buildVectors); - - core::PlanNodeId probeScanId; - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors, false) - .hashJoin( - {"t_k1"}, - {"u_k1"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors, false) - .planNode(), - "", - concat(probeType_->names(), buildType_->names())) - .planNode(); - - for (int32_t minSpillableReservationPct : {5, 50, 100}) { - SCOPED_TRACE(fmt::format( - "minSpillableReservationPct: {}", minSpillableReservationPct)); - - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::HashBuild::addInput", - std::function(([&](exec::HashBuild* hashBuild) { - memory::MemoryPool* pool = hashBuild->pool(); - const auto availableReservationBytes = pool->availableReservation(); - const auto currentUsedBytes = pool->usedBytes(); - // Verifies we always have min reservation after ensuring the input. - ASSERT_GE( - availableReservationBytes, - currentUsedBytes * minSpillableReservationPct / 100); - }))); - - auto tempDirectory = exec::test::TempDirectoryPath::create(); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .planNode(plan) - .injectSpill(false) - .spillDirectory(tempDirectory->getPath()) - .referenceQuery( - "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") - .run(); - } -} - -DEBUG_ONLY_TEST_F(HashJoinTest, exceededMaxSpillLevel) { - VectorFuzzer fuzzer({.vectorSize = 1000}, pool()); - const int32_t numBuildVectors = 10; - std::vector buildVectors; - for (int32_t i = 0; i < numBuildVectors; ++i) { - buildVectors.push_back(fuzzer.fuzzRow(buildType_)); - } - const int32_t numProbeVectors = 5; - std::vector probeVectors; - for (int32_t i = 0; i < numProbeVectors; ++i) { - probeVectors.push_back(fuzzer.fuzzRow(probeType_)); - } - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", buildVectors); - - core::PlanNodeId probeScanId; - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors, false) - .hashJoin( - {"t_k1"}, - {"u_k1"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors, false) - .planNode(), - "", - concat(probeType_->names(), buildType_->names())) - .planNode(); - - auto tempDirectory = exec::test::TempDirectoryPath::create(); - const int exceededMaxSpillLevelCount = - common::globalSpillStats().spillMaxLevelExceededCount; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::HashBuild::reclaim", - std::function(([&](exec::Operator* op) { - HashBuild* hashBuild = static_cast(op); - ASSERT_FALSE(hashBuild->testingExceededMaxSpillLevelLimit()); - }))); - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::HashProbe::reclaim", - std::function(([&](exec::Operator* op) { - HashProbe* hashProbe = static_cast(op); - ASSERT_FALSE(hashProbe->testingExceededMaxSpillLevelLimit()); - }))); - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::HashBuild::finishHashBuild", - std::function(([&](exec::HashBuild* hashBuild) { - Operator::ReclaimableSectionGuard guard(hashBuild); - testingRunArbitration(hashBuild->pool()); - }))); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(1) - .planNode(plan) - // Always trigger spilling. - .injectSpill(false) - .maxSpillLevel(0) - .spillDirectory(tempDirectory->getPath()) - .referenceQuery( - "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") - .config(core::QueryConfig::kSpillStartPartitionBit, "29") - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - auto opStats = toOperatorStats(task->taskStats()); - ASSERT_EQ( - opStats.at("HashProbe") - .runtimeStats[Operator::kExceededMaxSpillLevel] - .sum, - 8); - ASSERT_EQ( - opStats.at("HashProbe") - .runtimeStats[Operator::kExceededMaxSpillLevel] - .count, - 1); - ASSERT_EQ( - opStats.at("HashBuild") - .runtimeStats[Operator::kExceededMaxSpillLevel] - .sum, - 8); - ASSERT_EQ( - opStats.at("HashBuild") - .runtimeStats[Operator::kExceededMaxSpillLevel] - .count, - 1); - }) - .run(); - ASSERT_EQ( - common::globalSpillStats().spillMaxLevelExceededCount, - exceededMaxSpillLevelCount + 16); -} - -TEST_F(HashJoinTest, maxSpillBytes) { - const auto rowType = - ROW({"c0", "c1", "c2"}, {INTEGER(), INTEGER(), VARCHAR()}); - const auto probeVectors = createVectors(rowType, 1024, 10 << 20); - const auto buildVectors = createVectors(rowType, 1024, 10 << 20); - - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors, true) - .project({"c0", "c1", "c2"}) - .hashJoin( - {"c0"}, - {"u1"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors, true) - .project({"c0 AS u0", "c1 AS u1", "c2 AS u2"}) - .planNode(), - "", - {"c0", "c1", "c2"}, - core::JoinType::kInner) - .planNode(); - - auto spillDirectory = exec::test::TempDirectoryPath::create(); - auto queryCtx = core::QueryCtx::create(executor_.get()); - - struct { - int32_t maxSpilledBytes; - bool expectedExceedLimit; - std::string debugString() const { - return fmt::format("maxSpilledBytes {}", maxSpilledBytes); - } - } testSettings[] = {{1 << 30, false}, {16 << 20, true}, {0, false}}; - - for (const auto& testData : testSettings) { - SCOPED_TRACE(testData.debugString()); - try { - TestScopedSpillInjection scopedSpillInjection(100); - AssertQueryBuilder(plan) - .spillDirectory(spillDirectory->getPath()) - .queryCtx(queryCtx) - .config(core::QueryConfig::kSpillEnabled, true) - .config(core::QueryConfig::kJoinSpillEnabled, true) - .config(core::QueryConfig::kMaxSpillBytes, testData.maxSpilledBytes) - .copyResults(pool_.get()); - ASSERT_FALSE(testData.expectedExceedLimit); - } catch (const VeloxRuntimeError& e) { - ASSERT_TRUE(testData.expectedExceedLimit); - ASSERT_NE( - e.message().find( - "Query exceeded per-query local spill limit of 16.00MB"), - std::string::npos); - ASSERT_EQ( - e.errorCode(), facebook::velox::error_code::kSpillLimitExceeded); - } - } -} - -TEST_F(HashJoinTest, onlyHashBuildMaxSpillBytes) { - const auto rowType = - ROW({"c0", "c1", "c2"}, {INTEGER(), INTEGER(), VARCHAR()}); - const auto probeVectors = createVectors(rowType, 32, 128); - const auto buildVectors = createVectors(rowType, 1024, 10 << 20); - - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors, true) - .hashJoin( - {"c0"}, - {"u1"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors, true) - .project({"c0 AS u0", "c1 AS u1", "c2 AS u2"}) - .planNode(), - "", - {"c0", "c1", "c2"}, - core::JoinType::kInner) - .planNode(); - - auto spillDirectory = exec::test::TempDirectoryPath::create(); - auto queryCtx = core::QueryCtx::create(executor_.get()); - - struct { - int32_t maxSpilledBytes; - bool expectedExceedLimit; - std::string debugString() const { - return fmt::format("maxSpilledBytes {}", maxSpilledBytes); - } - } testSettings[] = {{1 << 30, false}, {16 << 20, true}, {0, false}}; - - for (const auto& testData : testSettings) { - SCOPED_TRACE(testData.debugString()); - try { - TestScopedSpillInjection scopedSpillInjection(100); - AssertQueryBuilder(plan) - .spillDirectory(spillDirectory->getPath()) - .queryCtx(queryCtx) - .config(core::QueryConfig::kSpillEnabled, true) - .config(core::QueryConfig::kJoinSpillEnabled, true) - .config(core::QueryConfig::kMaxSpillBytes, testData.maxSpilledBytes) - .copyResults(pool_.get()); - ASSERT_FALSE(testData.expectedExceedLimit); - } catch (const VeloxRuntimeError& e) { - ASSERT_TRUE(testData.expectedExceedLimit); - ASSERT_NE( - e.message().find( - "Query exceeded per-query local spill limit of 16.00MB"), - std::string::npos); - ASSERT_EQ( - e.errorCode(), facebook::velox::error_code::kSpillLimitExceeded); - } - } -} - -TEST_F(HashJoinTest, reclaimFromJoinBuilderWithMultiDrivers) { - auto rowType = ROW({ - {"c0", INTEGER()}, - {"c1", INTEGER()}, - {"c2", VARCHAR()}, - }); - const auto vectors = createVectors(rowType, 64 << 20, fuzzerOpts_); - const int numDrivers = 4; - - memory::MemoryManager::Options options; - options.allocatorCapacity = 8L << 30; - auto memoryManagerWithoutArbitrator = - std::make_unique(options); - const auto expectedResult = - runHashJoinTask( - vectors, - newQueryCtx( - memoryManagerWithoutArbitrator.get(), executor_.get(), 8L << 30), - false, - numDrivers, - pool(), - false) - .data; - - auto memoryManagerWithArbitrator = createMemoryManager(); - const auto& arbitrator = memoryManagerWithArbitrator->arbitrator(); - // Create a query ctx with a small capacity to trigger spilling. - auto result = runHashJoinTask( - vectors, - newQueryCtx( - memoryManagerWithArbitrator.get(), executor_.get(), 128 << 20), - false, - numDrivers, - pool(), - true, - expectedResult); - auto taskStats = exec::toPlanStats(result.task->taskStats()); - auto& planStats = taskStats.at(result.planNodeId); - ASSERT_GT(planStats.spilledBytes, 0); - result.task.reset(); - - // This test uses on-demand created memory manager instead of the global - // one. We need to make sure any used memory got cleaned up before exiting - // the scope - waitForAllTasksToBeDeleted(); - ASSERT_GT(arbitrator->stats().numRequests, 0); - ASSERT_GT(arbitrator->stats().reclaimedUsedBytes, 0); -} - -DEBUG_ONLY_TEST_F( - HashJoinTest, - failedToReclaimFromHashJoinBuildersInNonReclaimableSection) { - auto rowType = ROW({ - {"c0", INTEGER()}, - {"c1", INTEGER()}, - {"c2", VARCHAR()}, - }); - const auto vectors = createVectors(rowType, 64 << 20, fuzzerOpts_); - const int numDrivers = 1; - std::shared_ptr queryCtx = - newQueryCtx(memory::memoryManager(), executor_.get(), 512 << 20); - const auto expectedResult = - runHashJoinTask(vectors, queryCtx, false, numDrivers, pool(), false).data; - - std::atomic_bool nonReclaimableSectionWaitFlag{true}; - std::atomic_bool reclaimerInitializationWaitFlag{true}; - folly::EventCount nonReclaimableSectionWait; - std::atomic_bool memoryArbitrationWaitFlag{true}; - folly::EventCount memoryArbitrationWait; - - std::atomic numInitializedDrivers{0}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Driver::runInternal", - std::function([&](exec::Driver* driver) { - numInitializedDrivers++; - // We need to make sure reclaimers on both build and probe side are set - // (in Operator::initialize) to avoid race conditions, producing - // consistent test results. - if (numInitializedDrivers.load() == 2) { - reclaimerInitializationWaitFlag = false; - nonReclaimableSectionWait.notifyAll(); - } - })); - - std::atomic injectNonReclaimableSectionOnce{true}; - SCOPED_TESTVALUE_SET( - "facebook::velox::common::memory::MemoryPoolImpl::allocateNonContiguous", - std::function( - ([&](memory::MemoryPoolImpl* pool) { - if (!isHashBuildMemoryPool(*pool)) { - return; - } - if (!injectNonReclaimableSectionOnce.exchange(false)) { - return; - } - - // Signal the test control that one of the hash build operator has - // entered into non-reclaimable section. - nonReclaimableSectionWaitFlag = false; - nonReclaimableSectionWait.notifyAll(); - - // Suspend the driver to simulate the arbitration. - pool->reclaimer()->enterArbitration(); - // Wait for the memory arbitration to complete. - memoryArbitrationWait.await( - [&]() { return !memoryArbitrationWaitFlag.load(); }); - pool->reclaimer()->leaveArbitration(); - }))); - - std::thread joinThread([&]() { - const auto result = runHashJoinTask( - vectors, queryCtx, false, numDrivers, pool(), true, expectedResult); - auto taskStats = exec::toPlanStats(result.task->taskStats()); - auto& planStats = taskStats.at(result.planNodeId); - ASSERT_EQ(planStats.spilledBytes, 0); - }); - - // Wait for the hash build operators to enter into non-reclaimable section. - nonReclaimableSectionWait.await([&]() { - return ( - !nonReclaimableSectionWaitFlag.load() && - !reclaimerInitializationWaitFlag.load()); - }); - - // We expect capacity grow fails as we can't reclaim from hash join operators. - memory::testingRunArbitration(); - - // Notify the hash build operator that memory arbitration has been done. - memoryArbitrationWaitFlag = false; - memoryArbitrationWait.notifyAll(); - - joinThread.join(); - - // This test uses on-demand created memory manager instead of the global - // one. We need to make sure any used memory got cleaned up before exiting - // the scope - waitForAllTasksToBeDeleted(); - ASSERT_EQ( - memory::memoryManager()->arbitrator()->stats().numNonReclaimableAttempts, - 2); -} - -DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringTableBuild) { - VectorFuzzer fuzzer({.vectorSize = 1000}, pool()); - const int32_t numBuildVectors = 5; - std::vector buildVectors; - for (int32_t i = 0; i < numBuildVectors; ++i) { - buildVectors.push_back(fuzzer.fuzzRow(buildType_)); - } - const int32_t numProbeVectors = 5; - std::vector probeVectors; - for (int32_t i = 0; i < numProbeVectors; ++i) { - probeVectors.push_back(fuzzer.fuzzRow(probeType_)); - } - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", buildVectors); - - core::PlanNodeId probeScanId; - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors, false) - .hashJoin( - {"t_k1"}, - {"u_k1"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors, false) - .planNode(), - "", - concat(probeType_->names(), buildType_->names())) - .planNode(); - - std::atomic_bool injectSpillOnce{true}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::HashBuild::finishHashBuild", - std::function([&](Operator* op) { - if (!injectSpillOnce.exchange(false)) { - return; - } - Operator::ReclaimableSectionGuard guard(op); - testingRunArbitration(op->pool()); - })); - - auto tempDirectory = exec::test::TempDirectoryPath::create(); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(4) - .planNode(plan) - .injectSpill(false) - .maxSpillLevel(0) - .spillDirectory(tempDirectory->getPath()) - .referenceQuery( - "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") - .config(core::QueryConfig::kSpillStartPartitionBit, "29") - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - auto opStats = toOperatorStats(task->taskStats()); - ASSERT_GT( - opStats.at("HashBuild").runtimeStats[Operator::kSpillWrites].sum, - 0); - }) - .run(); -} - -DEBUG_ONLY_TEST_F(HashJoinTest, exceptionDuringFinishJoinBuild) { - // This test is to make sure there is no memory leak when exceptions are - // thrown while parallelly preparing join table. - auto memoryManager = memory::memoryManager(); - const auto& arbitrator = memoryManager->arbitrator(); - const uint64_t numDrivers = 2; - const auto expectedFreeCapacityBytes = arbitrator->stats().freeCapacityBytes; - - const uint64_t numBuildSideRows = 500; - auto buildKeyVector = makeFlatVector( - numBuildSideRows, - [](vector_size_t row) { return folly::Random::rand64(); }); - auto buildSideVector = - makeRowVector({"b0", "b1"}, {buildKeyVector, buildKeyVector}); - std::vector buildSideVectors; - for (int i = 0; i < numDrivers; ++i) { - buildSideVectors.push_back(buildSideVector); - } - createDuckDbTable("build", buildSideVectors); - - const uint64_t numProbeSideRows = 10; - auto probeKeyVector = makeFlatVector( - numProbeSideRows, - [&](vector_size_t row) { return buildKeyVector->valueAt(row); }); - auto probeSideVector = - makeRowVector({"p0", "p1"}, {probeKeyVector, probeKeyVector}); - std::vector probeSideVectors; - for (int i = 0; i < numDrivers; ++i) { - probeSideVectors.push_back(probeSideVector); - } - createDuckDbTable("probe", probeSideVectors); - - ASSERT_EQ(arbitrator->stats().freeCapacityBytes, expectedFreeCapacityBytes); - - // We set the task to fail right before we reserve memory for other operators. - // We rely on the driver suspension before parallel join build to throw - // exceptions (suspension on an already terminated task throws). - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::HashBuild::ensureTableFits", - std::function([&](HashBuild* buildOp) { - try { - VELOX_FAIL("Simulated failure"); - } catch (VeloxException&) { - buildOp->operatorCtx()->task()->setError(std::current_exception()); - } - })); - - std::vector probeInput = {probeSideVector}; - std::vector buildInput = {buildSideVector}; - auto planNodeIdGenerator = std::make_shared(); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); - - ASSERT_EQ(arbitrator->stats().freeCapacityBytes, expectedFreeCapacityBytes); - VELOX_ASSERT_THROW( - AssertQueryBuilder(duckDbQueryRunner_) - .spillDirectory(spillDirectory->getPath()) - .config(core::QueryConfig::kSpillEnabled, true) - .config(core::QueryConfig::kJoinSpillEnabled, true) - .queryCtx( - newQueryCtx(memoryManager, executor_.get(), kMemoryCapacity)) - .maxDrivers(numDrivers) - .plan(PlanBuilder(planNodeIdGenerator) - .values(probeInput, true) - .hashJoin( - {"p0"}, - {"b0"}, - PlanBuilder(planNodeIdGenerator) - .values(buildInput, true) - .planNode(), - "", - {"p0", "p1", "b0", "b1"}, - core::JoinType::kInner) - .planNode()) - .assertResults( - "SELECT probe.p0, probe.p1, build.b0, build.b1 FROM probe " - "INNER JOIN build ON probe.p0 = build.b0"), - "Simulated failure"); - // This test uses on-demand created memory manager instead of the global - // one. We need to make sure any used memory got cleaned up before exiting - // the scope - waitForAllTasksToBeDeleted(); - ASSERT_EQ(arbitrator->stats().freeCapacityBytes, expectedFreeCapacityBytes); -} - -DEBUG_ONLY_TEST_F(HashJoinTest, arbitrationTriggeredDuringParallelJoinBuild) { - std::unique_ptr memoryManager = createMemoryManager(); - const uint64_t numDrivers = 2; - - // Large build side key product to bump hash mode to kHash instead of kArray - // to trigger parallel join build. - const uint64_t numBuildSideRows = 500; - auto buildKeyVector = makeFlatVector( - numBuildSideRows, - [](vector_size_t row) { return folly::Random::rand64(); }); - auto buildSideVector = makeRowVector( - {"b0", "b1", "b2"}, {buildKeyVector, buildKeyVector, buildKeyVector}); - std::vector buildSideVectors; - for (int i = 0; i < numDrivers; ++i) { - buildSideVectors.push_back(buildSideVector); - } - createDuckDbTable("build", buildSideVectors); - - const uint64_t numProbeSideRows = 10; - auto probeKeyVector = makeFlatVector( - numProbeSideRows, - [&](vector_size_t row) { return buildKeyVector->valueAt(row); }); - auto probeSideVector = makeRowVector( - {"p0", "p1", "p2"}, {probeKeyVector, probeKeyVector, probeKeyVector}); - std::vector probeSideVectors; - for (int i = 0; i < numDrivers; ++i) { - probeSideVectors.push_back(probeSideVector); - } - createDuckDbTable("probe", probeSideVectors); - - std::shared_ptr joinQueryCtx = - newQueryCtx(memoryManager.get(), executor_.get(), kMemoryCapacity); - - const int64_t allocSize = 512LL << 20; - std::atomic parallelBuildTriggered{false}; - std::atomic joinBuildPool{nullptr}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::HashTable::parallelJoinBuild", - std::function([&](memory::MemoryPool* pool) { - parallelBuildTriggered = true; - // Pick the last running driver threads' pool for later memory - // allocation. This pick is rather arbitrary, as it is un-important - // which pool is going to be allocated from later in a parallel join's - // off-driver thread. - joinBuildPool = pool; - })); - - std::atomic_bool offThreadAllocationTriggered{false}; - folly::EventCount asyncMoveWait; - std::atomic asyncMoveWaitFlag{true}; - SCOPED_TESTVALUE_SET( - "facebook::velox::AsyncSource::prepare", - std::function([&](void* /* unused */) { - if (!offThreadAllocationTriggered.exchange(true)) { - SCOPE_EXIT { - asyncMoveWaitFlag = false; - asyncMoveWait.notifyAll(); - }; - // Executed by the first thread hitting the test value location. This - // allocation will trigger arbitration and fail. - VELOX_ASSERT_THROW( - joinBuildPool.load()->allocate(allocSize), - "Exceeded memory pool cap"); - } - })); - - // Wait for allocation (hence arbitration) on the prepare thread to finish - // before calling AsyncSource::move(). This is to ensure no other AsyncSource - // (hence arbitration) is running on the driver thread (on-thread) before the - // ongoing arbitration finishes. Without ensuring this, the on-thread - // arbitration (triggered by calling AsyncSource::move() first) has - // thread-local driver context by default, defying the purpose of this test. - SCOPED_TESTVALUE_SET( - "facebook::velox::AsyncSource::move", - std::function([&](void* /* unused */) { - asyncMoveWait.await([&]() { return !asyncMoveWaitFlag.load(); }); - })); - - std::vector probeInput = {probeSideVector}; - std::vector buildInput = {buildSideVector}; - auto planNodeIdGenerator = std::make_shared(); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); - AssertQueryBuilder(duckDbQueryRunner_) - .spillDirectory(spillDirectory->getPath()) - .config(core::QueryConfig::kSpillEnabled, true) - .config(core::QueryConfig::kJoinSpillEnabled, true) - // Set very low table size threshold to trigger parallel build. - .config(core::QueryConfig::kMinTableRowsForParallelJoinBuild, 0) - // Set multiple hash build drivers to trigger parallel build. - .maxDrivers(numDrivers) - .queryCtx(joinQueryCtx) - .plan(PlanBuilder(planNodeIdGenerator) - .values(probeInput, true) - .hashJoin( - {"p0", "p1", "p2"}, - {"b0", "b1", "b2"}, - PlanBuilder(planNodeIdGenerator) - .values(buildInput, true) - .planNode(), - "", - {"p0", "p1", "b0", "b1"}, - core::JoinType::kInner) - .planNode()) - .assertResults( - "SELECT probe.p0, probe.p1, build.b0, build.b1 FROM probe " - "INNER JOIN build ON probe.p0 = build.b0 AND probe.p1 = build.b1 AND " - "probe.p2 = build.b2"); - ASSERT_TRUE(parallelBuildTriggered); - - // This test uses on-demand created memory manager instead of the global - // one. We need to make sure any used memory got cleaned up before exiting - // the scope - waitForAllTasksToBeDeleted(); -} - -DEBUG_ONLY_TEST_F(HashJoinTest, arbitrationTriggeredByEnsureJoinTableFit) { - // Use manual spill injection other than spill injection framework. This is - // because spill injection framework does not allow fine grain spill within a - // single operator (We do not want to spill during addInput() but only during - // finishHashBuild()). - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::HashBuild::ensureTableFits", - std::function(([&](Operator* op) { - Operator::ReclaimableSectionGuard guard(op); - memory::testingRunArbitration(op->pool()); - }))); - auto tempDirectory = exec::test::TempDirectoryPath::create(); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .injectSpill(false) - .spillDirectory(tempDirectory->getPath()) - .keyTypes({BIGINT()}) - .probeVectors(1600, 5) - .buildVectors(1500, 5) - .referenceQuery( - "SELECT t_k0, t_data, u_k0, u_data FROM t, u WHERE t.t_k0 = u.u_k0") - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - const auto statsPair = taskSpilledStats(*task); - ASSERT_GT(statsPair.first.spilledBytes, 0); - }) - .run(); -} - -DEBUG_ONLY_TEST_F(HashJoinTest, joinBuildSpillError) { - const int kMemoryCapacity = 32 << 20; - // Set a small memory capacity to trigger spill. - std::unique_ptr memoryManager = - createMemoryManager(kMemoryCapacity, 0); - const auto& arbitrator = memoryManager->arbitrator(); - auto rowType = ROW( - {{"c0", INTEGER()}, - {"c1", INTEGER()}, - {"c2", VARCHAR()}, - {"c3", VARCHAR()}}); - - std::vector vectors = createVectors(16, rowType, fuzzerOpts_); - createDuckDbTable(vectors); - - std::shared_ptr joinQueryCtx = - newQueryCtx(memoryManager.get(), executor_.get(), kMemoryCapacity); - - const int numDrivers = 4; - std::atomic numAppends{0}; - const std::string injectedErrorMsg("injected spillError"); - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::SpillState::appendToPartition", - std::function([&](exec::SpillState* state) { - if (++numAppends != numDrivers) { - return; - } - VELOX_FAIL(injectedErrorMsg); - })); - - auto planNodeIdGenerator = std::make_shared(); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values(vectors) - .project({"c0 AS t0", "c1 AS t1", "c2 AS t2"}) - .hashJoin( - {"t0"}, - {"u0"}, - PlanBuilder(planNodeIdGenerator) - .values(vectors) - .project({"c0 AS u0", "c1 AS u1", "c2 AS u2"}) - .planNode(), - "", - {"t1"}, - core::JoinType::kAnti) - .planNode(); - VELOX_ASSERT_THROW( - AssertQueryBuilder(plan) - .queryCtx(joinQueryCtx) - .spillDirectory(spillDirectory->getPath()) - .config(core::QueryConfig::kSpillEnabled, true) - .copyResults(pool()), - injectedErrorMsg); - - waitForAllTasksToBeDeleted(); - ASSERT_EQ(arbitrator->stats().numFailures, 1); - - // Wait again here as this test uses on-demand created memory manager instead - // of the global one. We need to make sure any used memory got cleaned up - // before exiting the scope - waitForAllTasksToBeDeleted(); -} - -DEBUG_ONLY_TEST_F(HashJoinTest, probeSpillOnWaitForPeers) { - // This test creates a scenario when tester probe thread finishes processing - // input, entering kWaitForPeers state, and the other thread is still - // processing, spill is triggered properly performed. - - folly::EventCount startWait; - folly::Synchronized testerOpName; - std::atomic_bool injectedSpillOnce{false}; - - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Driver::runInternal::getOutput", - std::function([&](Operator* op) { - if (!isHashProbeMemoryPool(*op->pool())) { - return; - } - testerOpName.withWLock([&](std::string& opName) { - if (opName.empty()) { - opName = op->pool()->name(); - } - }); - if (op->pool()->name() == *testerOpName.rlock()) { - // Do not block tester thread. - return; - } - startWait.await([&]() { return injectedSpillOnce.load(); }); - })); - - // tester probe operator is guaranteed to be in kWaitForPeers state the next - // isBlocked() is called after noMoreInput() is called. - std::atomic_bool noMoreInputCalled{false}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Driver::runInternal::noMoreInput", - std::function([&](Operator* op) { - if (!isHashProbeMemoryPool(*op->pool())) { - return; - } - noMoreInputCalled = true; - })); - - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Driver::runInternal::isBlocked", - std::function([&](Operator* op) { - if (!isHashProbeMemoryPool(*op->pool())) { - return; - } - if (injectedSpillOnce || !noMoreInputCalled) { - return; - } - injectedSpillOnce = true; - EXPECT_EQ( - dynamic_cast(op)->testingState(), - ProbeOperatorState::kWaitForPeers); - testingRunArbitration(op->pool()); - })); - - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Task::requestPauseLocked", - std::function([&](Task* task) { startWait.notifyAll(); })); - - const uint64_t numDrivers{2}; - std::shared_ptr joinQueryCtx = - newQueryCtx(memory::memoryManager(), executor_.get(), kMemoryCapacity); - auto rowType = ROW({{"c0", INTEGER()}, {"c1", INTEGER()}}); - fuzzerOpts_.vectorSize = 20; - std::vector vectors = createVectors(6, rowType, fuzzerOpts_); - std::vector totalVectors; - for (auto i = 0; i < numDrivers; ++i) { - totalVectors.insert(totalVectors.end(), vectors.begin(), vectors.end()); - } - createDuckDbTable(totalVectors); - auto spillDirectory = exec::test::TempDirectoryPath::create(); - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values(vectors, true) - .project({"c0 AS t0", "c1 AS t1"}) - .hashJoin( - {"t0"}, - {"u0"}, - PlanBuilder(planNodeIdGenerator) - .values(vectors, true) - .project({"c0 AS u0", "c1 AS u1"}) - .planNode(), - "", - {"t1"}, - core::JoinType::kInner) - .planNode(); - - { - auto task = - AssertQueryBuilder(duckDbQueryRunner_) - .plan(plan) - .queryCtx(joinQueryCtx) - .spillDirectory(spillDirectory->getPath()) - .config(core::QueryConfig::kSpillEnabled, true) - .maxDrivers(numDrivers) - .assertResults("SELECT a.c1 from tmp a join tmp b on a.c0 = b.c0"); - - auto opStats = toOperatorStats(task->taskStats()); - ASSERT_GT(opStats.at("HashProbe").spilledBytes, 0); - ASSERT_EQ(opStats.at("HashBuild").spilledBytes, 0); - - const auto* arbitrator = memory::memoryManager()->arbitrator(); - ASSERT_GT(arbitrator->stats().reclaimedUsedBytes, 0); - } - waitForAllTasksToBeDeleted(); -} - -DEBUG_ONLY_TEST_F(HashJoinTest, taskWaitTimeout) { - const int queryMemoryCapacity = 128 << 20; - // Creates a large number of vectors based on the query capacity to trigger - // memory arbitration. - fuzzerOpts_.vectorSize = 10'000; - auto rowType = ROW( - {{"c0", INTEGER()}, - {"c1", INTEGER()}, - {"c2", VARCHAR()}, - {"c3", VARCHAR()}}); - const auto vectors = - createVectors(rowType, queryMemoryCapacity / 2, fuzzerOpts_); - const int numDrivers = 4; - const auto expectedResult = - runHashJoinTask(vectors, nullptr, false, numDrivers, pool(), false).data; - - for (uint64_t timeoutMs : {1'000, 30'000}) { - SCOPED_TRACE(fmt::format("timeout {}", succinctMillis(timeoutMs))); - auto memoryManager = createMemoryManager(512 << 20, 0, timeoutMs); - auto queryCtx = - newQueryCtx(memoryManager.get(), executor_.get(), queryMemoryCapacity); - - // Set test injection to block one hash build operator to inject delay when - // memory reclaim waits for task to pause. - folly::EventCount buildBlockWait; - std::atomic buildBlockWaitFlag{true}; - std::atomic blockOneBuild{true}; - SCOPED_TESTVALUE_SET( - "facebook::velox::common::memory::MemoryPoolImpl::maybeReserve", - std::function([&](memory::MemoryPool* pool) { - const std::string re(".*HashBuild"); - if (!RE2::FullMatch(pool->name(), re)) { - return; - } - if (!blockOneBuild.exchange(false)) { - return; - } - buildBlockWait.await([&]() { return !buildBlockWaitFlag.load(); }); - })); - - folly::EventCount taskPauseWait; - std::atomic taskPauseWaitFlag{false}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Task::requestPauseLocked", - std::function(([&](Task* /*unused*/) { - taskPauseWaitFlag = true; - taskPauseWait.notifyAll(); - }))); - - std::thread queryThread([&]() { - // We expect failure on short time out. - if (timeoutMs == 1'000) { - VELOX_ASSERT_THROW( - runHashJoinTask( - vectors, - queryCtx, - false, - numDrivers, - pool(), - true, - expectedResult), - "Memory reclaim failed to wait"); - } else { - // We expect succeed on large time out or no timeout. - const auto result = runHashJoinTask( - vectors, queryCtx, false, numDrivers, pool(), true, expectedResult); - auto taskStats = exec::toPlanStats(result.task->taskStats()); - auto& planStats = taskStats.at(result.planNodeId); - ASSERT_GT(planStats.spilledBytes, 0); - } - }); - - // Wait for task pause to reach, and then delay for a while before unblock - // the blocked hash build operator. - taskPauseWait.await([&]() { return taskPauseWaitFlag.load(); }); - // Wait for two seconds and expect the short reclaim wait timeout. - std::this_thread::sleep_for(std::chrono::seconds(2)); - // Unblock the blocked build operator to let memory reclaim proceed. - buildBlockWaitFlag = false; - buildBlockWait.notifyAll(); - - queryThread.join(); - - // This test uses on-demand created memory manager instead of the global - // one. We need to make sure any used memory got cleaned up before exiting - // the scope - waitForAllTasksToBeDeleted(); - } -} - -DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeSpill) { - struct { - bool triggerBuildSpill; - // Triggers after no more input or not. - bool afterNoMoreInput; - // The index of get output call to trigger probe side spilling. - int probeOutputIndex; - - std::string debugString() const { - return fmt::format( - "triggerBuildSpill: {}, afterNoMoreInput: {}, probeOutputIndex: {}", - triggerBuildSpill, - afterNoMoreInput, - probeOutputIndex); - } - } testSettings[] = { - {false, false, 0}, - {false, false, 1}, - {false, false, 10}, - {false, true, 0}, - {true, false, 0}, - {true, false, 1}, - {true, false, 10}, - {true, true, 0}}; - - for (const auto& testData : testSettings) { - SCOPED_TRACE(testData.debugString()); - - std::atomic_bool injectBuildSpillOnce{true}; - std::atomic_int buildInputCount{0}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Driver::runInternal::addInput", - std::function([&](Operator* op) { - if (!testData.triggerBuildSpill) { - return; - } - if (!isHashBuildMemoryPool(*op->pool())) { - return; - } - if (buildInputCount++ != 1) { - return; - } - if (!injectBuildSpillOnce.exchange(false)) { - return; - } - testingRunArbitration(op->pool()); - })); - - std::atomic_bool injectProbeSpillOnce{true}; - std::atomic_int probeOutputCount{0}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Driver::runInternal::getOutput", - std::function([&](Operator* op) { - if (!isHashProbeMemoryPool(*op->pool())) { - return; - } - if (testData.afterNoMoreInput) { - if (!op->testingNoMoreInput()) { - return; - } - } else { - if (probeOutputCount++ != testData.probeOutputIndex) { - return; - } - } - if (!injectProbeSpillOnce.exchange(false)) { - return; - } - testingRunArbitration(op->pool()); - })); - - fuzzerOpts_.vectorSize = 128; - auto probeVectors = createVectors(10, probeType_, fuzzerOpts_); - auto buildVectors = createVectors(20, buildType_, fuzzerOpts_); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(1) - .spillDirectory(spillDirectory->getPath()) - .probeKeys({"t_k1"}) - .probeVectors(std::move(probeVectors)) - .buildKeys({"u_k1"}) - .buildVectors(std::move(buildVectors)) - .config(core::QueryConfig::kJoinSpillEnabled, "true") - .joinType(core::JoinType::kRight) - .joinOutputLayout({"t_k1", "t_k2", "u_k1", "t_v1"}) - .referenceQuery( - "SELECT t.t_k1, t.t_k2, u.u_k1, t.t_v1 FROM t RIGHT JOIN u ON t.t_k1 = u.u_k1") - .injectSpill(false) - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - auto opStats = toOperatorStats(task->taskStats()); - ASSERT_GT(opStats.at("HashProbe").spilledBytes, 0); - if (testData.triggerBuildSpill) { - ASSERT_GT(opStats.at("HashBuild").spilledBytes, 0); - } else { - ASSERT_EQ(opStats.at("HashBuild").spilledBytes, 0); - } - - const auto* arbitrator = memory::memoryManager()->arbitrator(); - ASSERT_GT(arbitrator->stats().numRequests, 0); - ASSERT_GT(arbitrator->stats().reclaimedUsedBytes, 0); - }) - .run(); - } -} - -DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeSpillInMiddeOfLastOutputProcessing) { - std::atomic_int outputCountAfterNoMoreInout{0}; - std::atomic_bool injectOnce{true}; - ::facebook::velox::common::testutil::ScopedTestValue abc( - "facebook::velox::exec::Driver::runInternal::getOutput", - std::function([&](Operator* op) { - if (!isHashProbeMemoryPool(*op->pool())) { - return; - } - if (!op->testingNoMoreInput()) { - return; - } - if (outputCountAfterNoMoreInout++ != 1) { - return; - } - if (!injectOnce.exchange(false)) { - return; - } - testingRunArbitration(op->pool()); - })); - - fuzzerOpts_.vectorSize = 128; - auto probeVectors = createVectors(10, probeType_, fuzzerOpts_); - auto buildVectors = createVectors(20, buildType_, fuzzerOpts_); - - const auto spillDirectory = exec::test::TempDirectoryPath::create(); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(1) - .spillDirectory(spillDirectory->getPath()) - .probeKeys({"t_k1"}) - .probeVectors(std::move(probeVectors)) - .buildKeys({"u_k1"}) - .buildVectors(std::move(buildVectors)) - .config(core::QueryConfig::kJoinSpillEnabled, "true") - .config(core::QueryConfig::kPreferredOutputBatchRows, std::to_string(10)) - .joinType(core::JoinType::kRight) - .joinOutputLayout({"t_k1", "t_k2", "u_k1", "t_v1"}) - .referenceQuery( - "SELECT t.t_k1, t.t_k2, u.u_k1, t.t_v1 FROM t RIGHT JOIN u ON t.t_k1 = u.u_k1") - .injectSpill(false) - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - auto opStats = toOperatorStats(task->taskStats()); - ASSERT_GT(opStats.at("HashProbe").spilledBytes, 0); - // Verifies that we only spill the output which is single partitioned - // but not the hash table. - ASSERT_EQ(opStats.at("HashProbe").spilledPartitions, 1); - }) - .run(); -} - -// Inject probe-side spilling in the middle of output processing. If -// 'recursiveSpill' is true, we trigger probe-spilling when probe the hash table -// built from spilled data. -DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeSpillInMiddeOfOutputProcessing) { - for (bool recursiveSpill : {false, true}) { - std::atomic_int buildInputCount{0}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Driver::runInternal::addInput", - std::function([&](Operator* op) { - if (!isHashBuildMemoryPool(*op->pool())) { - return; - } - if (!recursiveSpill) { - return; - } - // Trigger spill after the build side has processed some rows. - if (buildInputCount++ != 1) { - return; - } - testingRunArbitration(op->pool()); - })); - - std::atomic_bool injectProbeSpillOnce{true}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Driver::runInternal::getOutput", - std::function([&](Operator* op) { - if (!isHashProbeMemoryPool(*op->pool())) { - return; - } - - if (op->testingHasInput()) { - return; - } - if (recursiveSpill) { - if (static_cast(op)->testingHasInputSpiller()) { - return; - } - } - if (!injectProbeSpillOnce.exchange(false)) { - return; - } - testingRunArbitration(op->pool()); - })); - - fuzzerOpts_.vectorSize = 128; - auto probeVectors = createVectors(10, probeType_, fuzzerOpts_); - auto buildVectors = createVectors(20, buildType_, fuzzerOpts_); - - const auto spillDirectory = exec::test::TempDirectoryPath::create(); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(1) - .spillDirectory(spillDirectory->getPath()) - .probeKeys({"t_k1"}) - .probeVectors(std::move(probeVectors)) - .buildKeys({"u_k1"}) - .buildVectors(std::move(buildVectors)) - .config(core::QueryConfig::kJoinSpillEnabled, "true") - .config( - core::QueryConfig::kPreferredOutputBatchRows, std::to_string(10)) - .joinType(core::JoinType::kRight) - .joinOutputLayout({"t_k1", "t_k2", "u_k1", "t_v1"}) - .referenceQuery( - "SELECT t.t_k1, t.t_k2, u.u_k1, t.t_v1 FROM t RIGHT JOIN u ON t.t_k1 = u.u_k1") - .injectSpill(false) - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - auto opStats = toOperatorStats(task->taskStats()); - ASSERT_GT(opStats.at("HashProbe").spilledBytes, 0); - ASSERT_GT(opStats.at("HashProbe").spilledPartitions, 1); - }) - .run(); - } -} - -DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeSpillWhenOneOfProbeFinish) { - const int numDrivers{3}; - - std::atomic_bool probeWaitFlag{true}; - folly::EventCount probeWait; - std::atomic_int numBlockedProbeOps{0}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Driver::runInternal::getOutput", - std::function([&](Operator* op) { - if (!isHashProbeMemoryPool(*op->pool())) { - return; - } - if (++numBlockedProbeOps <= numDrivers - 1) { - probeWait.await([&]() { return !probeWaitFlag.load(); }); - return; - } - })); - - std::atomic_bool notifyOnce{true}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Driver::runInternal::noMoreInput", - std::function([&](Operator* op) { - if (!isHashProbeMemoryPool(*op->pool())) { - return; - } - if (!notifyOnce.exchange(false)) { - return; - } - probeWaitFlag = false; - probeWait.notifyAll(); - })); - - std::thread queryThread([&]() { - const auto spillDirectory = exec::test::TempDirectoryPath::create(); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers, true, true) - .spillDirectory(spillDirectory->getPath()) - .keyTypes({BIGINT()}) - .probeVectors(32, 5) - .buildVectors(32, 5) - .config(core::QueryConfig::kJoinSpillEnabled, "true") - .referenceQuery( - "SELECT t_k0, t_data, u_k0, u_data FROM t, u WHERE t.t_k0 = u.u_k0") - .injectSpill(false) - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - auto opStats = toOperatorStats(task->taskStats()); - ASSERT_EQ(opStats.at("HashBuild").spilledBytes, 0); - ASSERT_GT(opStats.at("HashProbe").spilledBytes, 0); - }) - .run(); - }); - // Wait until one of the hash probe operator has finished. - probeWait.await([&]() { return !probeWaitFlag.load(); }); - memory::testingRunArbitration(); - queryThread.join(); -} - -DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeSpillExceedLimit) { - // If 'buildTriggerSpill' is true, then spilling is triggered by hash build. - for (const bool buildTriggerSpill : {false, true}) { - SCOPED_TRACE(fmt::format("buildTriggerSpill {}", buildTriggerSpill)); - - SCOPED_TESTVALUE_SET( - "facebook::velox::common::memory::MemoryPoolImpl::maybeReserve", - std::function([&](memory::MemoryPool* pool) { - if (buildTriggerSpill && !isHashBuildMemoryPool(*pool)) { - return; - } - if (!buildTriggerSpill && !isHashProbeMemoryPool(*pool)) { - return; - } - testingRunArbitration(pool); - })); - - fuzzerOpts_.vectorSize = 128; - auto probeVectors = createVectors(32, probeType_, fuzzerOpts_); - auto buildVectors = createVectors(64, buildType_, fuzzerOpts_); - for (int i = 0; i < probeVectors.size(); ++i) { - const auto probeKeyChannel = probeType_->getChildIdx("t_k1"); - const auto buildKeyChannle = buildType_->getChildIdx("u_k1"); - probeVectors[i]->childAt(probeKeyChannel) = - buildVectors[i]->childAt(buildKeyChannle); - } - - const auto spillDirectory = exec::test::TempDirectoryPath::create(); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(1) - .spillDirectory(spillDirectory->getPath()) - .probeKeys({"t_k1"}) - .probeVectors(std::move(probeVectors)) - .buildKeys({"u_k1"}) - .buildVectors(std::move(buildVectors)) - .config(core::QueryConfig::kMaxSpillLevel, "1") - .config(core::QueryConfig::kSpillNumPartitionBits, "1") - .config(core::QueryConfig::kJoinSpillEnabled, "true") - // Set small write buffer size to have small vectors to read from - // spilled data. - .config(core::QueryConfig::kSpillWriteBufferSize, "1") - .config( - core::QueryConfig::kPreferredOutputBatchRows, std::to_string(10)) - .joinType(core::JoinType::kRight) - .joinOutputLayout({"t_k1", "t_k2", "u_k1", "t_v1"}) - .referenceQuery( - "SELECT t.t_k1, t.t_k2, u.u_k1, t.t_v1 FROM t RIGHT JOIN u ON t.t_k1 = u.u_k1") - .injectSpill(false) - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - auto opStats = toOperatorStats(task->taskStats()); - if (buildTriggerSpill) { - ASSERT_GT(opStats.at("HashProbe").spilledBytes, 0); - ASSERT_GT(opStats.at("HashBuild").spilledBytes, 0); - } else { - ASSERT_GT(opStats.at("HashProbe").spilledBytes, 0); - ASSERT_EQ(opStats.at("HashBuild").spilledBytes, 0); - } - ASSERT_GT( - opStats.at("HashProbe") - .runtimeStats[Operator::kExceededMaxSpillLevel] - .sum, - 0); - ASSERT_GT( - opStats.at("HashBuild") - .runtimeStats[Operator::kExceededMaxSpillLevel] - .sum, - 0); - }) - .run(); - } -} - -DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeSpillUnderNonReclaimableSection) { - std::atomic_bool injectOnce{true}; - SCOPED_TESTVALUE_SET( - "facebook::velox::common::memory::MemoryPoolImpl::allocateNonContiguous", - std::function([&](memory::MemoryPool* pool) { - if (!isHashProbeMemoryPool(*pool)) { - return; - } - if (!injectOnce.exchange(false)) { - return; - } - auto* arbitrator = memory::memoryManager()->arbitrator(); - const auto numNonReclaimableAttempts = - arbitrator->stats().numNonReclaimableAttempts; - testingRunArbitration(pool); - // Verifies that we run into non-reclaimable section when reclaim from - // hash probe. - ASSERT_EQ( - arbitrator->stats().numNonReclaimableAttempts, - numNonReclaimableAttempts + 1); - })); - - const auto spillDirectory = exec::test::TempDirectoryPath::create(); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(1) - .spillDirectory(spillDirectory->getPath()) - .keyTypes({BIGINT()}) - .probeVectors(32, 5) - .buildVectors(32, 5) - .config(core::QueryConfig::kJoinSpillEnabled, "true") - .referenceQuery( - "SELECT t_k0, t_data, u_k0, u_data FROM t, u WHERE t.t_k0 = u.u_k0") - .injectSpill(false) - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - auto opStats = toOperatorStats(task->taskStats()); - ASSERT_EQ(opStats.at("HashProbe").spilledBytes, 0); - ASSERT_EQ(opStats.at("HashBuild").spilledBytes, 0); - }) - .run(); -} - -// This test case is to cover the case that hash probe trigger spill for right -// semi join types and the pending input needs to be processed in multiple -// steps. -DEBUG_ONLY_TEST_F(HashJoinTest, spillOutputWithRightSemiJoins) { - for (const auto joinType : - {core::JoinType::kRightSemiFilter, core::JoinType::kRightSemiProject}) { - std::atomic_bool injectOnce{true}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Driver::runInternal::getOutput", - std::function([&](Operator* op) { - if (op->operatorCtx()->operatorType() != "HashProbe") { - return; - } - if (!op->testingHasInput()) { - return; - } - if (!injectOnce.exchange(false)) { - return; - } - testingRunArbitration(op->pool()); - })); - - std::string duckDbSqlReference; - std::vector joinOutputLayout; - bool nullAware{false}; - if (joinType == core::JoinType::kRightSemiProject) { - duckDbSqlReference = "SELECT u_k2, u_k1 IN (SELECT t_k1 FROM t) FROM u"; - joinOutputLayout = {"u_k2", "match"}; - // Null aware is only supported for semi projection join type. - nullAware = true; - } else { - duckDbSqlReference = - "SELECT u_k2 FROM u WHERE u_k1 IN (SELECT t_k1 FROM t)"; - joinOutputLayout = {"u_k2"}; - } - - const auto spillDirectory = exec::test::TempDirectoryPath::create(); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(1) - .spillDirectory(spillDirectory->getPath()) - .probeType(probeType_) - .probeVectors(128, 3) - .probeKeys({"t_k1"}) - .buildType(buildType_) - .buildVectors(128, 4) - .buildKeys({"u_k1"}) - .joinType(joinType) - // Set a small number of output rows to process the input in multiple - // steps. - .config( - core::QueryConfig::kPreferredOutputBatchRows, std::to_string(10)) - .injectSpill(false) - .joinOutputLayout(std::move(joinOutputLayout)) - .nullAware(nullAware) - .referenceQuery(duckDbSqlReference) - .run(); - } -} - -DEBUG_ONLY_TEST_F(HashJoinTest, spillCheckOnLeftSemiFilterWithDynamicFilters) { - const int32_t numSplits = 10; - const int32_t numRowsProbe = 333; - const int32_t numRowsBuild = 100; - - std::vector probeVectors; - probeVectors.reserve(numSplits); - - std::vector> tempFiles; - for (int32_t i = 0; i < numSplits; ++i) { - auto rowVector = makeRowVector({ - makeFlatVector( - numRowsProbe, [&](auto row) { return row - i * 10; }), - makeFlatVector(numRowsProbe, [](auto row) { return row; }), - }); - probeVectors.push_back(rowVector); - tempFiles.push_back(TempFilePath::create()); - writeToFile(tempFiles.back()->getPath(), rowVector); - } - auto makeInputSplits = [&](const core::PlanNodeId& nodeId) { - return [&] { - std::vector probeSplits; - for (auto& file : tempFiles) { - probeSplits.push_back( - exec::Split(makeHiveConnectorSplit(file->getPath()))); - } - SplitInput splits; - splits.emplace(nodeId, probeSplits); - return splits; - }; - }; - - // 100 key values in [35, 233] range. - std::vector buildVectors; - for (int i = 0; i < 5; ++i) { - buildVectors.push_back(makeRowVector({ - makeFlatVector( - numRowsBuild / 5, - [i](auto row) { return 35 + 2 * (row + i * numRowsBuild / 5); }), - makeFlatVector(numRowsBuild / 5, [](auto row) { return row; }), - })); - } - std::vector keyOnlyBuildVectors; - for (int i = 0; i < 5; ++i) { - keyOnlyBuildVectors.push_back( - makeRowVector({makeFlatVector(numRowsBuild / 5, [i](auto row) { - return 35 + 2 * (row + i * numRowsBuild / 5); - })})); - } - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", buildVectors); - - auto probeType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); - - auto planNodeIdGenerator = std::make_shared(); - - auto buildSide = PlanBuilder(planNodeIdGenerator, pool_.get()) - .values(buildVectors) - .project({"c0 AS u_c0", "c1 AS u_c1"}) - .planNode(); - - // Left semi join. - core::PlanNodeId probeScanId; - core::PlanNodeId joinNodeId; - const auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) - .tableScan(probeType) - .capturePlanNodeId(probeScanId) - .hashJoin( - {"c0"}, - {"u_c0"}, - buildSide, - "", - {"c0", "c1"}, - core::JoinType::kLeftSemiFilter) - .capturePlanNodeId(joinNodeId) - .project({"c0", "c1 + 1"}) - .planNode(); - - std::atomic_bool injectOnce{true}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Driver::runInternal::getOutput", - std::function([&](Operator* op) { - if (op->operatorCtx()->operatorType() != "HashProbe") { - return; - } - if (!op->testingHasInput()) { - return; - } - if (!injectOnce.exchange(false)) { - return; - } - testingRunArbitration(op->pool()); - })); - - auto spillDirectory = exec::test::TempDirectoryPath::create(); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(std::move(op)) - .makeInputSplits(makeInputSplits(probeScanId)) - .spillDirectory(spillDirectory->getPath()) - .injectSpill(false) - .referenceQuery( - "SELECT t.c0, t.c1 + 1 FROM t WHERE t.c0 IN (SELECT c0 FROM u)") - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - // Verify spill hasn't triggered. - auto taskStats = exec::toPlanStats(task->taskStats()); - auto& planStats = taskStats.at(joinNodeId); - ASSERT_GT(planStats.spilledBytes, 0); - }) - .run(); -} - -// This test is to verify there is no memory reservation made before hash probe -// start processing. This can cause unnecessary spill and query OOM under some -// real workload with many stages as each hash probe might reserve non-trivial -// amount of memory. -DEBUG_ONLY_TEST_F( +VELOX_INSTANTIATE_TEST_SUITE_P( HashJoinTest, - hashProbeMemoryReservationCheckBeforeProbeStartWithSpillEnabled) { - fuzzerOpts_.vectorSize = 128; - auto probeVectors = createVectors(10, probeType_, fuzzerOpts_); - auto buildVectors = createVectors(20, buildType_, fuzzerOpts_); - - std::atomic_bool checkOnce{true}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Driver::runInternal::addInput", - std::function(([&](Operator* op) { - if (op->operatorType() != "HashProbe") { - return; - } - if (!checkOnce.exchange(false)) { - return; - } - ASSERT_EQ(op->pool()->usedBytes(), 0); - ASSERT_EQ(op->pool()->reservedBytes(), 0); - }))); - - const auto spillDirectory = exec::test::TempDirectoryPath::create(); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(1) - .spillDirectory(spillDirectory->getPath()) - .probeKeys({"t_k1"}) - .probeVectors(std::move(probeVectors)) - .buildKeys({"u_k1"}) - .buildVectors(std::move(buildVectors)) - .config(core::QueryConfig::kJoinSpillEnabled, "true") - .joinType(core::JoinType::kInner) - .joinOutputLayout({"t_k1", "t_k2", "u_k1", "t_v1"}) - .referenceQuery( - "SELECT t.t_k1, t.t_k2, u.u_k1, t.t_v1 FROM t JOIN u ON t.t_k1 = u.u_k1") - .injectSpill(true) - .verifier([&](const std::shared_ptr& task, bool injectSpill) { - if (!injectSpill) { - return; - } - auto opStats = toOperatorStats(task->taskStats()); - ASSERT_GT(opStats.at("HashProbe").spilledBytes, 0); - ASSERT_GE(opStats.at("HashProbe").spilledPartitions, 1); - }) - .run(); -} - -TEST_F(HashJoinTest, nanKeys) { - // Verify the NaN values with different binary representations are considered - // equal. - static const double kNan = std::numeric_limits::quiet_NaN(); - static const double kSNaN = std::numeric_limits::signaling_NaN(); - auto probeInput = makeRowVector( - {makeFlatVector({kNan, kSNaN}), makeFlatVector({1, 2})}); - auto buildInput = makeRowVector({makeFlatVector({kNan, 1})}); - - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values({probeInput}) - .project({"c0 AS t0", "c1 AS t1"}) - .hashJoin( - {"t0"}, - {"u0"}, - PlanBuilder(planNodeIdGenerator) - .values({buildInput}) - .project({"c0 AS u0"}) - .planNode(), - "", - {"t0", "u0", "t1"}, - core::JoinType::kLeft) - .planNode(); - auto queryCtx = core::QueryCtx::create(executor_.get()); - auto result = - AssertQueryBuilder(plan).queryCtx(queryCtx).copyResults(pool_.get()); - auto expected = makeRowVector( - {makeFlatVector({kNan, kNan}), - makeFlatVector({kNan, kNan}), - makeFlatVector({1, 2})}); - facebook::velox::test::assertEqualVectors(expected, result); -} - -DEBUG_ONLY_TEST_F(HashJoinTest, spillOnBlockedProbe) { - auto blockedOperatorFactoryUniquePtr = - std::make_unique(); - auto blockedOperatorFactory = blockedOperatorFactoryUniquePtr.get(); - Operator::registerOperator(std::move(blockedOperatorFactoryUniquePtr)); - - std::vector unblockPromises; - std::atomic_bool shouldBlock{true}; - blockedOperatorFactory->setBlockedCb([&](ContinueFuture* future) { - if (!shouldBlock) { - return BlockingReason::kNotBlocked; - } - auto [p, f] = makeVeloxContinuePromiseContract("Blocked Operator"); - *future = std::move(f); - unblockPromises.push_back(std::move(p)); - return BlockingReason::kWaitForConsumer; - }); - - folly::EventCount arbitrationWait; - std::atomic arbitrationWaitFlag{true}; - ::facebook::velox::common::testutil::ScopedTestValue _scopedTestValue15( - "facebook::velox::exec::HashBuild::finishHashBuild", - std::function([&](Operator* /* unused */) { - arbitrationWaitFlag = false; - arbitrationWait.notifyAll(); - })); - std::thread arbitrationThread([&]() { - arbitrationWait.await([&]() { return !arbitrationWaitFlag.load(); }); - memory::memoryManager()->shrinkPools(); - shouldBlock = false; - for (auto& unblockPromise : unblockPromises) { - unblockPromise.setValue(); - } - }); - - auto rowType = ROW({{"c0", INTEGER()}, {"c1", INTEGER()}}); - std::vector vectors = createVectors(1, rowType, fuzzerOpts_); - createDuckDbTable(vectors); - auto planNodeIdGenerator = std::make_shared(); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values(vectors) - .project({"c0 AS t0", "c1 AS t1"}) - .hashJoin( - {"t0"}, - {"u0"}, - PlanBuilder(planNodeIdGenerator) - .values(vectors) - .project({"c0 AS u0", "c1 AS u1"}) - .planNode(), - "", - {"t1"}, - core::JoinType::kInner) - .addNode([&](std::string id, core::PlanNodePtr input) { - return std::make_shared(id, input); - }) - .planNode(); - - { - auto task = - AssertQueryBuilder(duckDbQueryRunner_) - .plan(plan) - .queryCtx(newQueryCtx( - memory::memoryManager(), executor_.get(), kMemoryCapacity)) - .spillDirectory(spillDirectory->getPath()) - .config(core::QueryConfig::kSpillEnabled, true) - .maxDrivers(1) - .assertResults("SELECT a.c1 from tmp a join tmp b on a.c0 = b.c0"); - auto joinSpillStats = taskSpilledStats(*task); - auto buildSpillStats = joinSpillStats.first; - ASSERT_GT(buildSpillStats.spilledBytes, 0); - } - arbitrationThread.join(); - waitForAllTasksToBeDeleted(30'000'000); -} - -DEBUG_ONLY_TEST_F(HashJoinTest, buildReclaimedMemoryReport) { - constexpr int64_t kMaxBytes = 1LL << 30; // 1GB - const int32_t numBuildVectors = 3; - std::vector buildVectors; - for (int32_t i = 0; i < numBuildVectors; ++i) { - VectorFuzzer fuzzer({.vectorSize = 200}, pool()); - buildVectors.push_back(fuzzer.fuzzRow(buildType_)); - } - - const int32_t numProbeVectors = 3; - std::vector probeVectors; - for (int32_t i = 0; i < numProbeVectors; ++i) { - VectorFuzzer fuzzer({.vectorSize = 200}, pool()); - probeVectors.push_back(fuzzer.fuzzRow(probeType_)); - } - - const int numDrivers{2}; - // duckdb need double probe and build inputs as we run two drivers for hash - // join. - std::vector totalProbeVectors = probeVectors; - totalProbeVectors.insert( - totalProbeVectors.end(), probeVectors.begin(), probeVectors.end()); - std::vector totalBuildVectors = buildVectors; - totalBuildVectors.insert( - totalBuildVectors.end(), buildVectors.begin(), buildVectors.end()); - - createDuckDbTable("t", totalProbeVectors); - createDuckDbTable("u", totalBuildVectors); - - auto tempDirectory = exec::test::TempDirectoryPath::create(); - auto queryPool = memory::memoryManager()->addRootPool( - "", kMaxBytes, memory::MemoryReclaimer::create()); - - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors, true) - .hashJoin( - {"t_k1"}, - {"u_k1"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors, true) - .planNode(), - "", - concat(probeType_->names(), buildType_->names())) - .planNode(); - - folly::EventCount driverWait; - std::atomic_bool driverWaitFlag{true}; - folly::EventCount taskWait; - std::atomic_bool taskWaitFlag{true}; - - Operator* op{nullptr}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::HashBuild::finishHashBuild", - std::function(([&](Operator* testOp) { op = testOp; }))); - - std::atomic_bool injectOnce{true}; - SCOPED_TESTVALUE_SET( - "facebook::velox::common::memory::MemoryPoolImpl::maybeReserve", - std::function( - ([&](memory::MemoryPoolImpl* pool) { - if (op == nullptr || op->pool() != pool) { - return; - } - ASSERT_TRUE(isHashBuildMemoryPool(*pool)); - ASSERT_TRUE(op->canReclaim()); - ASSERT_GT(op->pool()->usedBytes(), 0); - ASSERT_GT( - op->pool()->parent()->reservedBytes(), - op->pool()->reservedBytes()); - if (!injectOnce.exchange(false)) { - return; - } - uint64_t reclaimableBytes{0}; - const bool reclaimable = op->reclaimableBytes(reclaimableBytes); - ASSERT_TRUE(reclaimable); - ASSERT_GT(reclaimableBytes, 0); - auto* driver = op->operatorCtx()->driver(); - TestSuspendedSection suspendedSection(driver); - taskWaitFlag = false; - taskWait.notifyAll(); - driverWait.await([&]() { return !driverWaitFlag.load(); }); - }))); - - std::thread taskThread([&]() { - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers) - .planNode(plan) - .queryPool(std::move(queryPool)) - .injectSpill(false) - .spillDirectory(tempDirectory->getPath()) - .referenceQuery( - "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") - .config(core::QueryConfig::kSpillStartPartitionBit, "29") - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - const auto statsPair = taskSpilledStats(*task); - ASSERT_GT(statsPair.first.spilledBytes, 0); - ASSERT_EQ(statsPair.first.spilledPartitions, 16); - ASSERT_GT(statsPair.second.spilledBytes, 0); - ASSERT_EQ(statsPair.second.spilledPartitions, 16); - verifyTaskSpilledRuntimeStats(*task, true); - }) - .run(); - }); - - taskWait.await([&]() { return !taskWaitFlag.load(); }); - ASSERT_TRUE(op != nullptr); - auto task = op->operatorCtx()->task(); - auto* nodePool = op->pool()->parent(); - const auto nodeMemoryUsage = nodePool->reservedBytes(); - { - memory::ScopedMemoryArbitrationContext ctx(op->pool()); - const uint64_t reclaimedBytes = task->pool()->reclaim( - task->pool()->capacity(), 1'000'000, reclaimerStats_); - ASSERT_GT(reclaimedBytes, 0); - ASSERT_EQ(nodeMemoryUsage - nodePool->reservedBytes(), reclaimedBytes); - } - // Verify all the memory has been freed. - ASSERT_EQ(nodePool->reservedBytes(), 0); - - driverWaitFlag = false; - driverWait.notifyAll(); - task.reset(); - - taskThread.join(); -} - -DEBUG_ONLY_TEST_F(HashJoinTest, probeReclaimedMemoryReport) { - constexpr int64_t kMaxBytes = 1LL << 30; // 1GB - const int32_t numBuildVectors = 3; - std::vector buildVectors; - for (int32_t i = 0; i < numBuildVectors; ++i) { - VectorFuzzer fuzzer({.vectorSize = 200}, pool()); - buildVectors.push_back(fuzzer.fuzzRow(buildType_)); - } - - const int32_t numProbeVectors = 3; - std::vector probeVectors; - for (int32_t i = 0; i < numProbeVectors; ++i) { - VectorFuzzer fuzzer({.vectorSize = 200}, pool()); - probeVectors.push_back(fuzzer.fuzzRow(probeType_)); - } - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", buildVectors); - - auto tempDirectory = exec::test::TempDirectoryPath::create(); - auto queryPool = memory::memoryManager()->addRootPool( - "", kMaxBytes, memory::MemoryReclaimer::create()); - - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors, true) - .hashJoin( - {"t_k1"}, - {"u_k1"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors, true) - .planNode(), - "", - concat(probeType_->names(), buildType_->names())) - .planNode(); - - folly::EventCount driverWait; - std::atomic_bool driverWaitFlag{true}; - folly::EventCount taskWait; - std::atomic_bool taskWaitFlag{true}; - - Operator* op{nullptr}; - std::atomic_int probeInputCount{0}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Driver::runInternal::addInput", - std::function(([&](Operator* testOp) { - if (testOp->operatorType() != "HashProbe") { - return; - } - op = testOp; - - ASSERT_TRUE(op->canReclaim()); - if (probeInputCount++ != 1) { - return; - } - auto* driver = op->operatorCtx()->driver(); - TestSuspendedSection suspendedSection(driver); - taskWaitFlag = false; - taskWait.notifyAll(); - driverWait.await([&]() { return !driverWaitFlag.load(); }); - }))); - - std::thread taskThread([&]() { - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(1) - .planNode(plan) - .queryPool(std::move(queryPool)) - .injectSpill(false) - .spillDirectory(tempDirectory->getPath()) - .referenceQuery( - "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") - .config(core::QueryConfig::kSpillStartPartitionBit, "29") - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - const auto statsPair = taskSpilledStats(*task); - // The spill triggered at the probe side. - ASSERT_EQ(statsPair.first.spilledBytes, 0); - ASSERT_EQ(statsPair.first.spilledPartitions, 0); - ASSERT_GT(statsPair.second.spilledBytes, 0); - ASSERT_EQ(statsPair.second.spilledPartitions, 16); - }) - .run(); - }); - - taskWait.await([&]() { return !taskWaitFlag.load(); }); - ASSERT_TRUE(op != nullptr); - auto task = op->operatorCtx()->task(); - auto* nodePool = op->pool()->parent(); - const auto nodeMemoryUsage = nodePool->reservedBytes(); - { - memory::ScopedMemoryArbitrationContext ctx(op->pool()); - const uint64_t reclaimedBytes = task->pool()->reclaim( - task->pool()->capacity(), 1'000'000, reclaimerStats_); - ASSERT_GT(reclaimedBytes, 0); - ASSERT_EQ(nodeMemoryUsage - nodePool->reservedBytes(), reclaimedBytes); - } - // Verify all the memory has been freed, except for the ones for hash lookup. - ASSERT_EQ(nodePool->reservedBytes(), 1048576); - - driverWaitFlag = false; - driverWait.notifyAll(); - task.reset(); - - taskThread.join(); -} - -DEBUG_ONLY_TEST_F(HashJoinTest, hashTableCleanupAfterProbeFinish) { - auto buildVectors = makeVectors(buildType_, 5, 100); - auto probeVectors = makeVectors(probeType_, 5, 100); - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", buildVectors); - - HashProbe* probeOp{nullptr}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Driver::runInternal::getOutput", - std::function([&](Operator* op) { - if (probeOp == nullptr && op->operatorType() == "HashProbe") { - probeOp = dynamic_cast(op); - } - })); - - bool tableEmpty{false}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::Driver::runInternal::noMoreInput", - std::function([&](Operator* op) { - if (op->operatorType() == "FilterProject") { - tableEmpty = (probeOp->testingTable()->numDistinct() == 0); - } - })); - - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .values(probeVectors, true) - .hashJoin( - {"t_k1"}, - {"u_k1"}, - PlanBuilder(planNodeIdGenerator) - .values(buildVectors, true) - .planNode(), - "", - concat(probeType_->names(), buildType_->names())) - .project({"t_k1", "t_k2", "t_v1", "u_k1", "u_k2", "u_v1"}) - .planNode(); - - auto tempDirectory = exec::test::TempDirectoryPath::create(); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(1) - .planNode(plan) - .injectSpill(false) - .spillDirectory(tempDirectory->getPath()) - .referenceQuery( - "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") - .config(core::QueryConfig::kSpillStartPartitionBit, "29") - .run(); - ASSERT_TRUE(tableEmpty); -} + HashJoinTest, + testing::ValuesIn(HashJoinTest::getTestParams()), + [](const testing::TestParamInfo& info) { + return TestParamToName(info.param); + }); } // namespace } // namespace facebook::velox::exec diff --git a/velox/exec/tests/HashJoinTestExtra.cpp b/velox/exec/tests/HashJoinTestExtra.cpp new file mode 100644 index 00000000000..4e224f1632d --- /dev/null +++ b/velox/exec/tests/HashJoinTestExtra.cpp @@ -0,0 +1,7966 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include "folly/synchronization/EventCount.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/testutil/TempDirectoryPath.h" +#include "velox/common/testutil/TestValue.h" +#include "velox/dwio/common/tests/utils/BatchMaker.h" +#include "velox/exec/Cursor.h" +#include "velox/exec/HashBuild.h" +#include "velox/exec/HashJoinBridge.h" +#include "velox/exec/HashProbe.h" +#include "velox/exec/OperatorType.h" +#include "velox/exec/OperatorUtils.h" +#include "velox/exec/PlanNodeStats.h" +#include "velox/exec/tests/utils/ArbitratorTestUtil.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/HashJoinTestBase.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/tests/utils/VectorTestUtil.h" +#include "velox/type/tests/utils/CustomTypesForTesting.h" +#include "velox/vector/fuzzer/VectorFuzzer.h" + +using namespace facebook::velox; +using namespace facebook::velox::exec::test; +using namespace facebook::velox::common::testutil; + +using facebook::velox::test::BatchMaker; + +namespace facebook::velox::exec { +namespace { + +class HashJoinTest : public HashJoinTestBase, + public testing::WithParamInterface { + public: + HashJoinTest() : HashJoinTestBase(GetParam()) {} + + explicit HashJoinTest(const TestParam& param) : HashJoinTestBase(param) {} + + static std::vector getTestParams() { + return std::vector({TestParam{1, false}, TestParam{1, true}}); + } +}; + +class MultiThreadedHashJoinTest : public HashJoinTest { + public: + MultiThreadedHashJoinTest() : HashJoinTest(GetParam()) {} + + static std::vector getTestParams() { + return std::vector( + {TestParam{1, false}, + TestParam{1, true}, + TestParam{3, false}, + TestParam{3, true}}); + } +}; + +// TODO: try to parallelize the following test cases if possible. +TEST_P(HashJoinTest, memory) { + // Measures memory allocation in a 1:n hash join followed by + // projection and aggregation. We expect vectors to be mostly + // reused, except for t_k0 + 1, which is a dictionary after the + // join. + std::vector probeVectors = + makeBatches(10, [&](int32_t /*unused*/) { + return std::dynamic_pointer_cast( + BatchMaker::createBatch(probeType_, 1000, *pool_)); + }); + + // auto buildType = makeRowType(keyTypes, "u_"); + std::vector buildVectors = + makeBatches(10, [&](int32_t /*unused*/) { + return std::dynamic_pointer_cast( + BatchMaker::createBatch(buildType_, 1000, *pool_)); + }); + + auto planNodeIdGenerator = std::make_shared(); + CursorParameters params; + params.planNode = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, true) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, true) + .planNode(), + "", + concat(probeType_->names(), buildType_->names())) + .project({"t_k1 % 1000 AS k1", "u_k1 % 1000 AS k2"}) + .singleAggregation({}, {"sum(k1)", "sum(k2)"}) + .planNode(); + params.queryCtx = core::QueryCtx::create(driverExecutor_.get()); + auto [taskCursor, rows] = readCursor(params); + EXPECT_GT(3'500, params.queryCtx->pool()->stats().numAllocs); + EXPECT_GT(40'000'000, params.queryCtx->pool()->stats().cumulativeBytes); +} + +TEST_P(HashJoinTest, lazyVectors) { + // a dataset of multiple row groups with multiple columns. We create + // different dictionary wrappings for different columns and load the + // rows in scope at different times. + auto probeVectors = makeBatches(3, [&](int32_t /*unused*/) { + return makeRowVector( + {makeFlatVector(3'000, [](auto row) { return row; }), + makeFlatVector(30'000, [](auto row) { return row % 23; }), + makeFlatVector(30'000, [](auto row) { return row % 31; }), + makeFlatVector(30'000, [](auto row) { + return StringView::makeInline(fmt::format("{} string", row % 43)); + })}); + }); + + std::vector buildVectors = + makeBatches(4, [&](int32_t /*unused*/) { + return makeRowVector( + {makeFlatVector(1'000, [](auto row) { return row * 3; }), + makeFlatVector( + 10'000, [](auto row) { return row % 31; })}); + }); + + std::vector> tempFiles; + + for (const auto& probeVector : probeVectors) { + tempFiles.push_back(TempFilePath::create()); + writeToFile(tempFiles.back()->getPath(), probeVector); + } + createDuckDbTable("t", probeVectors); + + for (const auto& buildVector : buildVectors) { + tempFiles.push_back(TempFilePath::create()); + writeToFile(tempFiles.back()->getPath(), buildVector); + } + createDuckDbTable("u", buildVectors); + + auto makeInputSplits = [&](const core::PlanNodeId& probeScanId, + const core::PlanNodeId& buildScanId) { + return [&] { + std::vector probeSplits; + for (int i = 0; i < probeVectors.size(); ++i) { + probeSplits.push_back( + exec::Split(makeHiveConnectorSplit(tempFiles[i]->getPath()))); + } + std::vector buildSplits; + for (int i = 0; i < buildVectors.size(); ++i) { + buildSplits.push_back( + exec::Split(makeHiveConnectorSplit( + tempFiles[probeSplits.size() + i]->getPath()))); + } + SplitInput splits; + splits.emplace(probeScanId, probeSplits); + splits.emplace(buildScanId, buildSplits); + return splits; + }; + }; + + { + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId probeScanId; + core::PlanNodeId buildScanId; + auto op = PlanBuilder(planNodeIdGenerator) + .tableScan(ROW({"c0", "c1"}, {INTEGER(), BIGINT()})) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"c0"}, + {"c0"}, + PlanBuilder(planNodeIdGenerator) + .tableScan(ROW({"c0"}, {INTEGER()})) + .capturePlanNodeId(buildScanId) + .planNode(), + "", + {"c1"}) + .project({"c1 + 1"}) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .makeInputSplits(makeInputSplits(probeScanId, buildScanId)) + .referenceQuery("SELECT t.c1 + 1 FROM t, u WHERE t.c0 = u.c0") + .run(); + } + + { + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId probeScanId; + core::PlanNodeId buildScanId; + auto op = PlanBuilder(planNodeIdGenerator) + .tableScan( + ROW({"c0", "c1", "c2", "c3"}, + {INTEGER(), BIGINT(), INTEGER(), VARCHAR()})) + .capturePlanNodeId(probeScanId) + .filter("c2 < 29") + .hashJoin( + {"c0"}, + {"bc0"}, + PlanBuilder(planNodeIdGenerator) + .tableScan(ROW({"c0", "c1"}, {INTEGER(), BIGINT()})) + .capturePlanNodeId(buildScanId) + .project({"c0 as bc0", "c1 as bc1"}) + .planNode(), + "(c1 + bc1) % 33 < 27", + {"c1", "bc1", "c3"}) + .project({"c1 + 1", "bc1", "length(c3)"}) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .makeInputSplits(makeInputSplits(probeScanId, buildScanId)) + .referenceQuery( + "SELECT t.c1 + 1, U.c1, length(t.c3) FROM t, u WHERE t.c0 = u.c0 and t.c2 < 29 and (t.c1 + u.c1) % 33 < 27") + .run(); + } +} + +TEST_P(HashJoinTest, lazyVectorNotLoadedInFilter) { + // Ensure that if lazy vectors are temporarily wrapped during a filter's + // execution and remain unloaded, the temporary wrap is promptly + // discarded. This precaution prevents the generation of the probe's output + // from wrapping an unloaded vector while the temporary wrap is + // still alive. + // This is done by generating a sufficiently small batch to allow the lazy + // vector to remain unloaded, as it doesn't need to be split between batches. + // Then we use a filter that skips the execution of the expression containing + // the lazy vector, thereby avoiding its loading. + + testLazyVectorsWithFilter( + core::JoinType::kInner, + "c1 >= 0 OR c2 > 0", + {"c1", "c2"}, + "SELECT t.c1, t.c2 FROM t, u WHERE t.c0 = u.c0"); +} + +TEST_P(HashJoinTest, lazyVectorPartiallyLoadedInFilterLeftJoin) { + // Test the case where a filter loads a subset of the rows that will be output + // from a column on the probe side. + + testLazyVectorsWithFilter( + core::JoinType::kLeft, + "c1 > 0 AND c2 > 0", + {"c1", "c2"}, + "SELECT t.c1, t.c2 FROM t LEFT JOIN u ON t.c0 = u.c0 AND (c1 > 0 AND c2 > 0)"); +} + +TEST_P(HashJoinTest, lazyVectorPartiallyLoadedInFilterFullJoin) { + // Test the case where a filter loads a subset of the rows that will be output + // from a column on the probe side. + + testLazyVectorsWithFilter( + core::JoinType::kFull, + "c1 > 0 AND c2 > 0", + {"c1", "c2"}, + "SELECT t.c1, t.c2 FROM t FULL OUTER JOIN u ON t.c0 = u.c0 AND (c1 > 0 AND c2 > 0)"); +} + +TEST_P(HashJoinTest, lazyVectorPartiallyLoadedInFilterLeftSemiProject) { + // Test the case where a filter loads a subset of the rows that will be output + // from a column on the probe side. + + testLazyVectorsWithFilter( + core::JoinType::kLeftSemiProject, + "c1 > 0 AND c2 > 0", + {"c1", "c2", "match"}, + "SELECT t.c1, t.c2, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0 AND (t.c1 > 0 AND t.c2 > 0)) FROM t"); +} + +TEST_P(HashJoinTest, lazyVectorPartiallyLoadedInFilterAntiJoin) { + // Test the case where a filter loads a subset of the rows that will be output + // from a column on the probe side. + + testLazyVectorsWithFilter( + core::JoinType::kAnti, + "c1 > 0 AND c2 > 0", + {"c1", "c2"}, + "SELECT t.c1, t.c2 FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE t.c0 = u.c0 AND (t.c1 > 0 AND t.c2 > 0))"); +} + +TEST_P(HashJoinTest, lazyVectorPartiallyLoadedInFilterInnerJoin) { + // Test the case where a filter loads a subset of the rows that will be output + // from a column on the probe side. + + testLazyVectorsWithFilter( + core::JoinType::kInner, + "not (c1 < 15 and c2 >= 0)", + {"c1", "c2"}, + "SELECT t.c1, t.c2 FROM t, u WHERE t.c0 = u.c0 AND NOT (c1 < 15 AND c2 >= 0)"); +} + +TEST_P(HashJoinTest, lazyVectorPartiallyLoadedInFilterLeftSemiFilter) { + // Test the case where a filter loads a subset of the rows that will be output + // from a column on the probe side. + + testLazyVectorsWithFilter( + core::JoinType::kLeftSemiFilter, + "not (c1 < 15 and c2 >= 0)", + {"c1", "c2"}, + "SELECT t.c1, t.c2 FROM t WHERE c0 IN (SELECT u.c0 FROM u WHERE t.c0 = u.c0 AND NOT (t.c1 < 15 AND t.c2 >= 0))"); +} + +// Verifies that lazy probe-side vectors are loaded in ensureLazyInputLoaded() +// (inside a reclaimable section) BEFORE the non-reclaimable probe output loop +// in getOutputInternal(). Tests fuzzed data with partial/all lazy columns and +// deterministic low-selectivity data. When spill is injected via +// HashJoinBuilder, verifies spill actually occurs and results remain correct. +DEBUG_ONLY_TEST_P(HashJoinTest, hashProbeEnsureLazyInputLoaded) { + struct { + bool useDeterministicData; + std::vector lazyColumnIndices; + bool useHashJoinBuilder; + + std::string debugString() const { + return fmt::format( + "useDeterministicData: {}, lazyColumns: {}, useHashJoinBuilder: {}", + useDeterministicData, + lazyColumnIndices.empty() ? "all" + : folly::join(",", lazyColumnIndices), + useHashJoinBuilder); + } + } testSettings[] = { + // Fuzzed data, non-key columns lazy. + {false, {1, 2}, false}, + // Deterministic low selectivity (~1% match rate), non-key columns lazy. + {true, {1, 2}, false}, + // Fuzzed data, all columns lazy, with spill injection. + {false, {}, true}, + }; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + std::vector probeVectors; + std::vector buildInput; + + if (testData.useDeterministicData) { + auto probe = makeRowVector( + {"t_k1", "t_k2", "t_v1"}, + {makeFlatVector(1'000, [](auto row) { return row; }), + makeFlatVector( + 1'000, + [](auto /*row*/) { return StringView::makeInline("probe"); }), + makeFlatVector(1'000, [](auto /*row*/) { + return StringView::makeInline("val"); + })}); + probeVectors.push_back( + VectorFuzzer({.vectorSize = 1'000}, pool()) + .fuzzRowChildrenToLazy(probe, testData.lazyColumnIndices)); + createDuckDbTable( + "t", + {std::dynamic_pointer_cast( + probe->testingCopyPreserveEncodings())}); + + // Build data only has keys 0-9, so only ~1% of probe rows match. + // Use enough duplicates (200 per key) so the output exceeds + // outputBatchSize (default 1024) and spans multiple batches, ensuring + // the ensureLoadedIfNotAtEnd at-end optimization does not bypass the + // preloading path we are testing. + auto build = makeRowVector( + {"u_k1", "u_k2", "u_v1"}, + {makeFlatVector(2'000, [](auto row) { return row % 10; }), + makeFlatVector( + 2'000, + [](auto /*row*/) { return StringView::makeInline("build"); }), + makeFlatVector(2'000, [](auto row) { return row * 100; })}); + buildInput.push_back(build); + createDuckDbTable("u", {build}); + } else { + VectorFuzzer::Options opts; + opts.vectorSize = 1'000; + VectorFuzzer fuzzer(opts, pool()); + + auto nonLazy = fuzzer.fuzzRow(probeType_); + // Overwrite the join key column (index 0) with values from a small + // range to guarantee probe-build matches. Without this, random int32 + // keys across 1'000 rows have near-zero probability of overlapping. + nonLazy->childAt(0) = + makeFlatVector(1'000, [](auto row) { return row % 100; }); + createDuckDbTable( + "t", + {std::dynamic_pointer_cast( + nonLazy->testingCopyPreserveEncodings())}); + probeVectors.push_back( + testData.lazyColumnIndices.empty() + ? VectorFuzzer(opts, pool()).fuzzRowChildrenToLazy(nonLazy) + : VectorFuzzer(opts, pool()) + .fuzzRowChildrenToLazy( + nonLazy, testData.lazyColumnIndices)); + + auto buildRow = fuzzer.fuzzRow(buildType_); + buildRow->childAt(0) = + makeFlatVector(1'000, [](auto row) { return row % 100; }); + buildInput.push_back(buildRow); + createDuckDbTable("u", buildInput); + } + + // Track whether fillOutput (the non-reclaimable probe output loop) has + // been entered. Lazy loading should happen BEFORE this point. + std::atomic_bool fillOutputEntered{false}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::HashProbe::fillOutput", + std::function( + [&](HashProbe*) { fillOutputEntered = true; })); + + std::atomic_int lazyLoadsBeforeFillOutput{0}; + std::atomic_int lazyLoadsDuringFillOutput{0}; + SCOPED_TESTVALUE_SET( + "facebook::velox::{}::VectorLoaderWrap::loadInternal", + std::function([&](void*) { + if (fillOutputEntered) { + ++lazyLoadsDuringFillOutput; + } else { + ++lazyLoadsBeforeFillOutput; + } + })); + + const std::string referenceQuery = + "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1"; + + if (testData.useHashJoinBuilder) { + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .probeKeys({"t_k1"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u_k1"}) + .buildVectors(std::move(buildInput)) + .referenceQuery(referenceQuery) + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + if (hasSpill) { + const auto statsPair = taskSpilledStats(*task); + ASSERT_GT(statsPair.first.spilledBytes, 0); + } + }) + .run(); + } else { + const auto spillDirectory = TempDirectoryPath::create(); + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, false) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildInput, false) + .planNode(), + "", + concat(probeType_->names(), buildType_->names())) + .planNode(); + + AssertQueryBuilder(duckDbQueryRunner_) + .spillDirectory(spillDirectory->getPath()) + .config(core::QueryConfig::kSpillEnabled, true) + .config(core::QueryConfig::kJoinSpillEnabled, true) + .plan(plan) + .assertResults(referenceQuery); + } + + // HashJoinBuilder runs the query twice (without and with spill). On the + // first (non-spill) run canReclaim() is false, so ensureLazyInputLoaded() + // is a no-op and lazy vectors are loaded inside fillOutput(). The atomic + // counters accumulate across both runs, so only assert loading order for + // the single-run non-builder case where spill is explicitly configured. + if (!testData.useHashJoinBuilder) { + ASSERT_GT(lazyLoadsBeforeFillOutput, 0); + ASSERT_EQ(lazyLoadsDuringFillOutput, 0); + } + } +} + +TEST_P(HashJoinTest, dynamicFilters) { + const int32_t numSplits = 10; + const int32_t numRowsProbe = 333; + const int32_t numRowsBuild = 100; + + std::vector probeVectors; + probeVectors.reserve(numSplits); + + std::vector> tempFiles; + for (int32_t i = 0; i < numSplits; ++i) { + auto rowVector = makeRowVector({ + makeFlatVector( + numRowsProbe, [&](auto row) { return row - i * 10; }), + makeFlatVector(numRowsProbe, [](auto row) { return row; }), + }); + probeVectors.push_back(rowVector); + tempFiles.push_back(TempFilePath::create()); + writeToFile(tempFiles.back()->getPath(), rowVector); + } + auto makeInputSplits = [&](const core::PlanNodeId& nodeId) { + return [&] { + std::vector probeSplits; + for (auto& file : tempFiles) { + probeSplits.push_back( + exec::Split(makeHiveConnectorSplit(file->getPath()))); + } + SplitInput splits; + splits.emplace(nodeId, probeSplits); + return splits; + }; + }; + + // 100 key values in [35, 233] range. + std::vector buildVectors; + for (int i = 0; i < 5; ++i) { + buildVectors.push_back(makeRowVector({ + makeFlatVector( + numRowsBuild / 5, + [i](auto row) { return 35 + 2 * (row + i * numRowsBuild / 5); }), + makeFlatVector(numRowsBuild / 5, [](auto row) { return row; }), + })); + } + std::vector keyOnlyBuildVectors; + for (int i = 0; i < 5; ++i) { + keyOnlyBuildVectors.push_back( + makeRowVector({makeFlatVector(numRowsBuild / 5, [i](auto row) { + return 35 + 2 * (row + i * numRowsBuild / 5); + })})); + } + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto probeType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + + auto planNodeIdGenerator = std::make_shared(); + + auto buildSide = PlanBuilder(planNodeIdGenerator, pool_.get()) + .values(buildVectors) + .project({"c0 AS u_c0", "c1 AS u_c1"}) + .planNode(); + auto keyOnlyBuildSide = PlanBuilder(planNodeIdGenerator, pool_.get()) + .values(keyOnlyBuildVectors) + .project({"c0 AS u_c0"}) + .planNode(); + + // Basic push-down. + { + SCOPED_TRACE("Inner join"); + core::PlanNodeId probeScanId; + core::PlanNodeId joinId; + auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) + .tableScan(probeType) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"c0"}, + {"u_c0"}, + buildSide, + "", + {"c0", "c1", "u_c1"}, + core::JoinType::kInner) + .capturePlanNodeId(joinId) + .project({"c0", "c1 + 1", "c1 + u_c1"}) + .planNode(); + { + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .makeInputSplits(makeInputSplits(probeScanId)) + .referenceQuery( + "SELECT t.c0, t.c1 + 1, t.c1 + u.c1 FROM t, u WHERE t.c0 = u.c0") + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); + auto planStats = toPlanStats(task->taskStats()); + if (hasSpill) { + // Dynamic filtering should be disabled with spilling triggered. + ASSERT_EQ(0, getFiltersProduced(task, 1).sum); + ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); + } else { + ASSERT_EQ(1, getFiltersProduced(task, 1).sum); + ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(0, getReplacedWithFilterRows(task, 1).sum); + ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_EQ( + planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, + std::unordered_set({joinId})); + } + }) + .run(); + } + + // Left semi join. + op = PlanBuilder(planNodeIdGenerator, pool_.get()) + .tableScan(probeType) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"c0"}, + {"u_c0"}, + buildSide, + "", + {"c0", "c1"}, + core::JoinType::kLeftSemiFilter) + .capturePlanNodeId(joinId) + .project({"c0", "c1 + 1"}) + .planNode(); + + { + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .makeInputSplits(makeInputSplits(probeScanId)) + .referenceQuery( + "SELECT t.c0, t.c1 + 1 FROM t WHERE t.c0 IN (SELECT c0 FROM u)") + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); + auto planStats = toPlanStats(task->taskStats()); + if (hasSpill) { + // Dynamic filtering should be disabled with spilling triggered. + ASSERT_EQ(0, getFiltersProduced(task, 1).sum); + ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(0, getReplacedWithFilterRows(task, 1).sum); + ASSERT_EQ(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); + } else { + ASSERT_EQ(1, getFiltersProduced(task, 1).sum); + ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); + ASSERT_GT(getReplacedWithFilterRows(task, 1).sum, 0); + ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_EQ( + planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, + std::unordered_set({joinId})); + } + }) + .run(); + } + + // Right semi join. + op = PlanBuilder(planNodeIdGenerator, pool_.get()) + .tableScan(probeType) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"c0"}, + {"u_c0"}, + buildSide, + "", + {"u_c0", "u_c1"}, + core::JoinType::kRightSemiFilter) + .capturePlanNodeId(joinId) + .project({"u_c0", "u_c1 + 1"}) + .planNode(); + + { + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .makeInputSplits(makeInputSplits(probeScanId)) + .referenceQuery( + "SELECT u.c0, u.c1 + 1 FROM u WHERE u.c0 IN (SELECT c0 FROM t)") + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); + auto planStats = toPlanStats(task->taskStats()); + if (hasSpill) { + // Dynamic filtering should be disabled with spilling triggered. + ASSERT_EQ(0, getFiltersProduced(task, 1).sum); + ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(getReplacedWithFilterRows(task, 1).sum, 0); + ASSERT_EQ(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); + } else { + ASSERT_EQ(1, getFiltersProduced(task, 1).sum); + ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(getReplacedWithFilterRows(task, 1).sum, 0); + ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_EQ( + planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, + std::unordered_set({joinId})); + } + }) + .run(); + } + + // Right join. + op = PlanBuilder(planNodeIdGenerator, pool_.get()) + .tableScan(probeType) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"c0"}, + {"u_c0"}, + buildSide, + "", + {"c0", "c1", "u_c1"}, + core::JoinType::kRight) + .capturePlanNodeId(joinId) + .project({"c0", "c1 + 1", "c1 + u_c1"}) + .planNode(); + { + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .makeInputSplits(makeInputSplits(probeScanId)) + .referenceQuery( + "SELECT t.c0, t.c1 + 1, t.c1 + u.c1 FROM t RIGHT JOIN u ON t.c0 = u.c0") + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); + auto planStats = toPlanStats(task->taskStats()); + if (hasSpill) { + // Dynamic filtering should be disabled with spilling triggered. + ASSERT_EQ(0, getFiltersProduced(task, 1).sum); + ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); + } else { + ASSERT_EQ(1, getFiltersProduced(task, 1).sum); + ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(0, getReplacedWithFilterRows(task, 1).sum); + ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_EQ( + planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, + std::unordered_set({joinId})); + } + }) + .run(); + } + } + + // Basic push-down with column names projected out of the table scan + // having different names than column names in the files. + { + SCOPED_TRACE("Inner join column rename"); + auto scanOutputType = ROW({"a", "b"}, {INTEGER(), BIGINT()}); + connector::ColumnHandleMap assignments; + assignments["a"] = regularColumn("c0", INTEGER()); + assignments["b"] = regularColumn("c1", BIGINT()); + + core::PlanNodeId probeScanId; + core::PlanNodeId joinId; + auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) + .startTableScan() + .outputType(scanOutputType) + .assignments(assignments) + .endTableScan() + .capturePlanNodeId(probeScanId) + .hashJoin({"a"}, {"u_c0"}, buildSide, "", {"a", "b", "u_c1"}) + .capturePlanNodeId(joinId) + .project({"a", "b + 1", "b + u_c1"}) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .makeInputSplits(makeInputSplits(probeScanId)) + .referenceQuery( + "SELECT t.c0, t.c1 + 1, t.c1 + u.c1 FROM t, u WHERE t.c0 = u.c0") + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); + auto planStats = toPlanStats(task->taskStats()); + if (hasSpill) { + // Dynamic filtering should be disabled with spilling triggered. + ASSERT_EQ(0, getFiltersProduced(task, 1).sum); + ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(0, getReplacedWithFilterRows(task, 1).sum); + ASSERT_EQ(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); + } else { + ASSERT_EQ(1, getFiltersProduced(task, 1).sum); + ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(0, getReplacedWithFilterRows(task, 1).sum); + ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_EQ( + planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, + std::unordered_set({joinId})); + } + }) + .run(); + } + + // Push-down that requires merging filters. + { + SCOPED_TRACE("Merge filters"); + core::PlanNodeId probeScanId; + core::PlanNodeId joinId; + auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) + .tableScan(probeType, {"c0 < 500::INTEGER"}) + .capturePlanNodeId(probeScanId) + .hashJoin({"c0"}, {"u_c0"}, buildSide, "", {"c1", "u_c1"}) + .capturePlanNodeId(joinId) + .project({"c1 + u_c1"}) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .makeInputSplits(makeInputSplits(probeScanId)) + .referenceQuery( + "SELECT t.c1 + u.c1 FROM t, u WHERE t.c0 = u.c0 AND t.c0 < 500") + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); + auto planStats = toPlanStats(task->taskStats()); + if (hasSpill) { + // Dynamic filtering should be disabled with spilling triggered. + ASSERT_EQ(0, getFiltersProduced(task, 1).sum); + ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(0, getReplacedWithFilterRows(task, 1).sum); + ASSERT_EQ(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); + } else { + ASSERT_EQ(1, getFiltersProduced(task, 1).sum); + ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(0, getReplacedWithFilterRows(task, 1).sum); + ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_EQ( + planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, + std::unordered_set({joinId})); + } + }) + .run(); + } + + // Push-down that turns join into a no-op. + { + SCOPED_TRACE("canReplaceWithDynamicFilter"); + core::PlanNodeId probeScanId; + core::PlanNodeId joinId; + auto op = + PlanBuilder(planNodeIdGenerator, pool_.get()) + .tableScan(probeType) + .capturePlanNodeId(probeScanId) + .hashJoin({"c0"}, {"u_c0"}, keyOnlyBuildSide, "", {"c0", "c1"}) + .capturePlanNodeId(joinId) + .project({"c0", "c1 + 1"}) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .makeInputSplits(makeInputSplits(probeScanId)) + .referenceQuery("SELECT t.c0, t.c1 + 1 FROM t, u WHERE t.c0 = u.c0") + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); + auto planStats = toPlanStats(task->taskStats()); + if (hasSpill) { + // Dynamic filtering should be disabled with spilling triggered. + ASSERT_EQ(0, getFiltersProduced(task, 1).sum); + ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(0, getReplacedWithFilterRows(task, 1).sum); + ASSERT_EQ(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); + } else { + ASSERT_EQ(1, getFiltersProduced(task, 1).sum); + ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); + ASSERT_EQ( + getReplacedWithFilterRows(task, 1).sum, + numRowsBuild * numSplits); + ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_EQ( + planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, + std::unordered_set({joinId})); + } + }) + .run(); + } + + // Push-down that turns join into a no-op with output having a different + // number of columns than the input. + { + SCOPED_TRACE("canReplaceWithDynamicFilter column rename"); + core::PlanNodeId probeScanId; + core::PlanNodeId joinId; + auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) + .tableScan(probeType) + .capturePlanNodeId(probeScanId) + .hashJoin({"c0"}, {"u_c0"}, keyOnlyBuildSide, "", {"c0"}) + .capturePlanNodeId(joinId) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .makeInputSplits(makeInputSplits(probeScanId)) + .referenceQuery("SELECT t.c0 FROM t JOIN u ON (t.c0 = u.c0)") + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); + auto planStats = toPlanStats(task->taskStats()); + if (hasSpill) { + // Dynamic filtering should be disabled with spilling triggered. + ASSERT_EQ(0, getFiltersProduced(task, 1).sum); + ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(0, getReplacedWithFilterRows(task, 1).sum); + ASSERT_EQ(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); + } else { + ASSERT_EQ(1, getFiltersProduced(task, 1).sum); + ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); + ASSERT_EQ( + getReplacedWithFilterRows(task, 1).sum, + numRowsBuild * numSplits); + ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_EQ( + planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, + std::unordered_set({joinId})); + } + }) + .run(); + } + + // Push-down that requires merging filters and turns join into a no-op. + { + SCOPED_TRACE("canReplaceWithDynamicFilter merge filters"); + core::PlanNodeId probeScanId; + core::PlanNodeId joinId; + auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) + .tableScan(probeType, {"c0 < 500::INTEGER"}) + .capturePlanNodeId(probeScanId) + .hashJoin({"c0"}, {"u_c0"}, keyOnlyBuildSide, "", {"c1"}) + .capturePlanNodeId(joinId) + .project({"c1 + 1"}) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .makeInputSplits(makeInputSplits(probeScanId)) + .referenceQuery( + "SELECT t.c1 + 1 FROM t, u WHERE t.c0 = u.c0 AND t.c0 < 500") + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); + auto planStats = toPlanStats(task->taskStats()); + if (hasSpill) { + // Dynamic filtering should be disabled with spilling triggered. + ASSERT_EQ(0, getFiltersProduced(task, 1).sum); + ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(getReplacedWithFilterRows(task, 1).sum, 0); + ASSERT_EQ(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); + } else { + ASSERT_EQ(1, getFiltersProduced(task, 1).sum); + ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); + ASSERT_GT(getReplacedWithFilterRows(task, 1).sum, 0); + ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_EQ( + planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, + std::unordered_set({joinId})); + } + }) + .run(); + } + + // Push-down with highly selective filter in the scan. + { + SCOPED_TRACE("Highly selective filter"); + // Inner join. + core::PlanNodeId probeScanId; + core::PlanNodeId joinId; + auto op = + PlanBuilder(planNodeIdGenerator, pool_.get()) + .tableScan(probeType, {"c0 < 200::INTEGER"}) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"c0"}, {"u_c0"}, buildSide, "", {"c1"}, core::JoinType::kInner) + .capturePlanNodeId(joinId) + .project({"c1 + 1"}) + .planNode(); + + { + SCOPED_TRACE("Inner join"); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .makeInputSplits(makeInputSplits(probeScanId)) + .referenceQuery( + "SELECT t.c1 + 1 FROM t, u WHERE t.c0 = u.c0 AND t.c0 < 200") + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); + auto planStats = toPlanStats(task->taskStats()); + if (hasSpill) { + // Dynamic filtering should be disabled with spilling triggered. + ASSERT_EQ(0, getFiltersProduced(task, 1).sum); + ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(getReplacedWithFilterRows(task, 1).sum, 0); + ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); + } else { + ASSERT_EQ(1, getFiltersProduced(task, 1).sum); + ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); + ASSERT_GT(getReplacedWithFilterRows(task, 1).sum, 0); + ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_EQ( + planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, + std::unordered_set({joinId})); + } + }) + .run(); + } + + // Left semi join. + op = PlanBuilder(planNodeIdGenerator, pool_.get()) + .tableScan(probeType, {"c0 < 200::INTEGER"}) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"c0"}, + {"u_c0"}, + buildSide, + "", + {"c1"}, + core::JoinType::kLeftSemiFilter) + .capturePlanNodeId(joinId) + .project({"c1 + 1"}) + .planNode(); + + { + SCOPED_TRACE("Left semi join"); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .makeInputSplits(makeInputSplits(probeScanId)) + .referenceQuery( + "SELECT t.c1 + 1 FROM t WHERE t.c0 IN (SELECT c0 FROM u) AND t.c0 < 200") + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); + auto planStats = toPlanStats(task->taskStats()); + if (hasSpill) { + // Dynamic filtering should be disabled with spilling triggered. + ASSERT_EQ(0, getFiltersProduced(task, 1).sum); + ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(getReplacedWithFilterRows(task, 1).sum, 0); + ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); + } else { + ASSERT_EQ(1, getFiltersProduced(task, 1).sum); + ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); + ASSERT_GT(getReplacedWithFilterRows(task, 1).sum, 0); + ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_EQ( + planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, + std::unordered_set({joinId})); + } + }) + .run(); + } + + // Right semi join. + op = PlanBuilder(planNodeIdGenerator, pool_.get()) + .tableScan(probeType, {"c0 < 200::INTEGER"}) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"c0"}, + {"u_c0"}, + buildSide, + "", + {"u_c1"}, + core::JoinType::kRightSemiFilter) + .capturePlanNodeId(joinId) + .project({"u_c1 + 1"}) + .planNode(); + + { + SCOPED_TRACE("Right semi join"); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .makeInputSplits(makeInputSplits(probeScanId)) + .referenceQuery( + "SELECT u.c1 + 1 FROM u WHERE u.c0 IN (SELECT c0 FROM t) AND u.c0 < 200") + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); + auto planStats = toPlanStats(task->taskStats()); + if (hasSpill) { + // Dynamic filtering should be disabled with spilling triggered. + ASSERT_EQ(0, getFiltersProduced(task, 1).sum); + ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(getReplacedWithFilterRows(task, 1).sum, 0); + ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); + } else { + ASSERT_EQ(1, getFiltersProduced(task, 1).sum); + ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(getReplacedWithFilterRows(task, 1).sum, 0); + ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_EQ( + planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, + std::unordered_set({joinId})); + } + }) + .run(); + } + + // Right join. + op = + PlanBuilder(planNodeIdGenerator, pool_.get()) + .tableScan(probeType, {"c0 < 200::INTEGER"}) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"c0"}, {"u_c0"}, buildSide, "", {"c1"}, core::JoinType::kRight) + .capturePlanNodeId(joinId) + .project({"c1 + 1"}) + .planNode(); + + { + SCOPED_TRACE("Right join"); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .makeInputSplits(makeInputSplits(probeScanId)) + .referenceQuery( + "SELECT t.c1 + 1 FROM (SELECT * FROM t WHERE t.c0 < 200) t RIGHT JOIN u ON t.c0 = u.c0") + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); + auto planStats = toPlanStats(task->taskStats()); + if (hasSpill) { + // Dynamic filtering should be disabled with spilling triggered. + ASSERT_EQ(0, getFiltersProduced(task, 1).sum); + ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(getReplacedWithFilterRows(task, 1).sum, 0); + ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); + } else { + ASSERT_EQ(1, getFiltersProduced(task, 1).sum); + ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(getReplacedWithFilterRows(task, 1).sum, 0); + ASSERT_LT(getInputPositions(task, 1), numRowsProbe * numSplits); + ASSERT_EQ( + planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, + std::unordered_set({joinId})); + } + }) + .run(); + } + } + + // Disable filter push-down by using values in place of scan. + { + SCOPED_TRACE("Disabled in case of values node"); + core::PlanNodeId joinId; + auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) + .values(probeVectors) + .hashJoin({"c0"}, {"u_c0"}, buildSide, "", {"c1"}) + .capturePlanNodeId(joinId) + .project({"c1 + 1"}) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .referenceQuery("SELECT t.c1 + 1 FROM t, u WHERE t.c0 = u.c0") + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + auto planStats = toPlanStats(task->taskStats()); + ASSERT_EQ(0, getFiltersProduced(task, 1).sum); + ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(numRowsProbe * numSplits, getInputPositions(task, 1)); + }) + .run(); + } + + // Disable filter push-down by using an expression as the join key on the + // probe side. + { + SCOPED_TRACE("Disabled in case of join condition"); + core::PlanNodeId probeScanId; + core::PlanNodeId joinId; + auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) + .tableScan(probeType) + .capturePlanNodeId(probeScanId) + .project({"cast(c0 + 1 as integer) AS t_key", "c1"}) + .hashJoin({"t_key"}, {"u_c0"}, buildSide, "", {"c1"}) + .capturePlanNodeId(joinId) + .project({"c1 + 1"}) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .makeInputSplits(makeInputSplits(probeScanId)) + .referenceQuery("SELECT t.c1 + 1 FROM t, u WHERE (t.c0 + 1) = u.c0") + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + auto planStats = toPlanStats(task->taskStats()); + ASSERT_EQ(0, getFiltersProduced(task, 1).sum); + ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(numRowsProbe * numSplits, getInputPositions(task, 1)); + ASSERT_TRUE(planStats.at(probeScanId).dynamicFilterStats.empty()); + }) + .run(); + } +} + +TEST_P(HashJoinTest, dynamicFiltersStatsWithChainedJoins) { + const int32_t numSplits = 10; + const int32_t numProbeRows = 333; + const int32_t numBuildRows = 100; + + std::vector probeVectors; + probeVectors.reserve(numSplits); + std::vector> tempFiles; + for (int32_t i = 0; i < numSplits; ++i) { + auto rowVector = makeRowVector({ + makeFlatVector( + numProbeRows, [&](auto row) { return row - i * 10; }), + makeFlatVector(numProbeRows, [](auto row) { return row; }), + }); + probeVectors.push_back(rowVector); + tempFiles.push_back(TempFilePath::create()); + writeToFile(tempFiles.back()->getPath(), rowVector); + } + auto makeInputSplits = [&](const core::PlanNodeId& nodeId) { + return [&] { + std::vector probeSplits; + for (auto& file : tempFiles) { + probeSplits.push_back( + exec::Split(makeHiveConnectorSplit(file->getPath()))); + } + SplitInput splits; + splits.emplace(nodeId, probeSplits); + return splits; + }; + }; + + // 100 key values in [35, 233] range. + std::vector buildVectors; + for (int i = 0; i < 5; ++i) { + buildVectors.push_back(makeRowVector({ + makeFlatVector( + numBuildRows / 5, + [i](auto row) { return 35 + 2 * (row + i * numBuildRows / 5); }), + makeFlatVector(numBuildRows / 5, [](auto row) { return row; }), + })); + } + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto probeType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + + auto planNodeIdGenerator = std::make_shared(); + + auto buildSide1 = PlanBuilder(planNodeIdGenerator, pool_.get()) + .values(buildVectors) + .project({"c0 AS u_c0", "c1 AS u_c1"}) + .planNode(); + auto buildSide2 = PlanBuilder(planNodeIdGenerator, pool_.get()) + .values(buildVectors) + .project({"c0 AS u_c0", "c1 AS u_c1"}) + .planNode(); + // Inner join pushdown. + core::PlanNodeId probeScanId; + core::PlanNodeId joinId1; + core::PlanNodeId joinId2; + auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) + .tableScan(probeType) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"c0"}, + {"u_c0"}, + buildSide1, + "", + {"c0", "c1"}, + core::JoinType::kInner) + .capturePlanNodeId(joinId1) + .hashJoin( + {"c0"}, + {"u_c0"}, + buildSide2, + "", + {"c0", "c1", "u_c1"}, + core::JoinType::kInner) + .capturePlanNodeId(joinId2) + .project({"c0", "c1 + 1", "c1 + u_c1"}) + .planNode(); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .makeInputSplits(makeInputSplits(probeScanId)) + .injectSpill(false) + .referenceQuery( + "SELECT t.c0, t.c1 + 1, t.c1 + u.c1 FROM t, u WHERE t.c0 = u.c0") + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + auto planStats = toPlanStats(task->taskStats()); + ASSERT_EQ( + planStats.at(probeScanId).dynamicFilterStats.producerNodeIds, + std::unordered_set({joinId1, joinId2})); + }) + .run(); +} + +TEST_P(HashJoinTest, dynamicFiltersWithSkippedSplits) { + const int32_t numSplits = 20; + const int32_t numNonSkippedSplits = 10; + const int32_t numRowsProbe = 333; + const int32_t numRowsBuild = 100; + + std::vector probeVectors; + probeVectors.reserve(numSplits); + + std::vector> tempFiles; + // Each split has a column containing + // the split number. This is used to filter out whole splits based + // on metadata. We test how using metadata for dropping splits + // interactts with dynamic filters. In specific, if the first split + // is discarded based on metadata, the dynamic filters must not be + // lost even if there is no actual reader for the split. + for (int32_t i = 0; i < numSplits; ++i) { + auto rowVector = makeRowVector({ + makeFlatVector( + numRowsProbe, [&](auto row) { return row - i * 10; }), + makeFlatVector(numRowsProbe, [](auto row) { return row; }), + makeFlatVector( + numRowsProbe, [&](auto /*row*/) { return i % 2 == 0 ? 0 : i; }), + }); + probeVectors.push_back(rowVector); + tempFiles.push_back(TempFilePath::create()); + writeToFile(tempFiles.back()->getPath(), rowVector); + } + + auto makeInputSplits = [&](const core::PlanNodeId& nodeId) { + return [&] { + std::vector probeSplits; + for (auto& file : tempFiles) { + probeSplits.push_back( + exec::Split(makeHiveConnectorSplit(file->getPath()))); + } + // We add splits that have no rows. + auto makeEmpty = [&]() { + return exec::Split( + HiveConnectorSplitBuilder(tempFiles.back()->getPath()) + .start(10000000) + .length(1) + .build()); + }; + std::vector emptyFront = {makeEmpty(), makeEmpty()}; + std::vector emptyMiddle = {makeEmpty(), makeEmpty()}; + probeSplits.insert( + probeSplits.begin(), emptyFront.begin(), emptyFront.end()); + probeSplits.insert( + probeSplits.begin() + 13, emptyMiddle.begin(), emptyMiddle.end()); + SplitInput splits; + splits.emplace(nodeId, probeSplits); + return splits; + }; + }; + + // 100 key values in [35, 233] range. + std::vector buildVectors; + for (int i = 0; i < 5; ++i) { + buildVectors.push_back(makeRowVector({ + makeFlatVector( + numRowsBuild / 5, + [i](auto row) { return 35 + 2 * (row + i * numRowsBuild / 5); }), + makeFlatVector(numRowsBuild / 5, [](auto row) { return row; }), + })); + } + std::vector keyOnlyBuildVectors; + for (int i = 0; i < 5; ++i) { + keyOnlyBuildVectors.push_back( + makeRowVector({makeFlatVector(numRowsBuild / 5, [i](auto row) { + return 35 + 2 * (row + i * numRowsBuild / 5); + })})); + } + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto probeType = ROW({"c0", "c1", "c2"}, {INTEGER(), BIGINT(), BIGINT()}); + + auto planNodeIdGenerator = std::make_shared(); + + auto buildSide = PlanBuilder(planNodeIdGenerator, pool_.get()) + .values(buildVectors) + .project({"c0 AS u_c0", "c1 AS u_c1"}) + .planNode(); + auto keyOnlyBuildSide = PlanBuilder(planNodeIdGenerator, pool_.get()) + .values(keyOnlyBuildVectors) + .project({"c0 AS u_c0"}) + .planNode(); + + // Basic push-down. + { + // Inner join. + core::PlanNodeId probeScanId; + auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) + .tableScan(probeType, {"c2 > 0"}) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"c0"}, + {"u_c0"}, + buildSide, + "", + {"c0", "c1", "u_c1"}, + core::JoinType::kInner) + .project({"c0", "c1 + 1", "c1 + u_c1"}) + .planNode(); + { + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .numDrivers(1) + .makeInputSplits(makeInputSplits(probeScanId)) + .referenceQuery( + "SELECT t.c0, t.c1 + 1, t.c1 + u.c1 FROM t, u WHERE t.c0 = u.c0 AND t.c2 > 0") + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); + if (hasSpill) { + // Dynamic filtering should be disabled with spilling triggered. + ASSERT_EQ(0, getFiltersProduced(task, 1).sum); + ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); + ASSERT_EQ( + getInputPositions(task, 1), + numRowsProbe * numNonSkippedSplits); + } else { + ASSERT_EQ(1, getFiltersProduced(task, 1).sum); + ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(0, getReplacedWithFilterRows(task, 1).sum); + ASSERT_LT( + getInputPositions(task, 1), + numRowsProbe * numNonSkippedSplits); + } + }) + .run(); + } + + // Left semi join. + op = PlanBuilder(planNodeIdGenerator, pool_.get()) + .tableScan(probeType, {"c2 > 0"}) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"c0"}, + {"u_c0"}, + buildSide, + "", + {"c0", "c1"}, + core::JoinType::kLeftSemiFilter) + .project({"c0", "c1 + 1"}) + .planNode(); + + { + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .numDrivers(1) + .makeInputSplits(makeInputSplits(probeScanId)) + .referenceQuery( + "SELECT t.c0, t.c1 + 1 FROM t WHERE t.c0 IN (SELECT c0 FROM u) AND t.c2 > 0") + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); + if (hasSpill) { + // Dynamic filtering should be disabled with spilling triggered. + ASSERT_EQ(0, getFiltersProduced(task, 1).sum); + ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(0, getReplacedWithFilterRows(task, 1).sum); + ASSERT_EQ( + getInputPositions(task, 1), + numRowsProbe * numNonSkippedSplits); + } else { + ASSERT_EQ(1, getFiltersProduced(task, 1).sum); + ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); + ASSERT_GT(getReplacedWithFilterRows(task, 1).sum, 0); + ASSERT_LT( + getInputPositions(task, 1), + numRowsProbe * numNonSkippedSplits); + } + }) + .run(); + } + + // Right semi join. + op = PlanBuilder(planNodeIdGenerator, pool_.get()) + .tableScan(probeType, {"c2 > 0"}) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"c0"}, + {"u_c0"}, + buildSide, + "", + {"u_c0", "u_c1"}, + core::JoinType::kRightSemiFilter) + .project({"u_c0", "u_c1 + 1"}) + .planNode(); + + { + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .numDrivers(1) + .makeInputSplits(makeInputSplits(probeScanId)) + .referenceQuery( + "SELECT u.c0, u.c1 + 1 FROM u WHERE u.c0 IN (SELECT c0 FROM t WHERE t.c2 > 0)") + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); + if (hasSpill) { + // Dynamic filtering should be disabled with spilling triggered. + ASSERT_EQ(0, getFiltersProduced(task, 1).sum); + ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(getReplacedWithFilterRows(task, 1).sum, 0); + ASSERT_EQ( + getInputPositions(task, 1), + numRowsProbe * numNonSkippedSplits); + } else { + ASSERT_EQ(1, getFiltersProduced(task, 1).sum); + ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(getReplacedWithFilterRows(task, 1).sum, 0); + ASSERT_LT( + getInputPositions(task, 1), + numRowsProbe * numNonSkippedSplits); + } + }) + .run(); + } + } + + // Test with VARCHAR keys. + { + SCOPED_TRACE("VARCHAR keys with skipped splits"); + + std::vector stringProbeVectors; + std::vector> stringTempFiles; + for (int32_t i = 0; i < numSplits; ++i) { + auto rowVector = makeRowVector({ + makeFlatVector( + numRowsProbe, + [&](auto row) { + return StringView::makeInline( + fmt::format("key_{}", row - i * 10)); + }), + makeFlatVector(numRowsProbe, [](auto row) { return row; }), + makeFlatVector( + numRowsProbe, [&](auto /*row*/) { return i % 2 == 0 ? 0 : i; }), + }); + stringProbeVectors.push_back(rowVector); + stringTempFiles.push_back(TempFilePath::create()); + writeToFile(stringTempFiles.back()->getPath(), rowVector); + } + + auto makeStringInputSplits = [&](const core::PlanNodeId& nodeId) { + return [&] { + std::vector probeSplits; + for (auto& file : stringTempFiles) { + probeSplits.push_back( + exec::Split(makeHiveConnectorSplit(file->getPath()))); + } + // We add splits that have no rows. + auto makeEmpty = [&]() { + return exec::Split( + HiveConnectorSplitBuilder(stringTempFiles.back()->getPath()) + .start(10000000) + .length(1) + .build()); + }; + std::vector emptyFront = {makeEmpty(), makeEmpty()}; + std::vector emptyMiddle = {makeEmpty(), makeEmpty()}; + probeSplits.insert( + probeSplits.begin(), emptyFront.begin(), emptyFront.end()); + probeSplits.insert( + probeSplits.begin() + 13, emptyMiddle.begin(), emptyMiddle.end()); + SplitInput splits; + splits.emplace(nodeId, probeSplits); + return splits; + }; + }; + + // Create build vectors in range [key_35, key_233]. + std::vector stringBuildVectors; + for (int i = 0; i < 5; ++i) { + stringBuildVectors.push_back(makeRowVector({ + makeFlatVector( + numRowsBuild / 5, + [i](auto row) { + return StringView::makeInline( + fmt::format( + "key_{}", 35 + 2 * (row + i * numRowsBuild / 5))); + }), + makeFlatVector( + numRowsBuild / 5, [](auto row) { return row; }), + })); + } + + createDuckDbTable("t_str_skip", stringProbeVectors); + createDuckDbTable("u_str_skip", stringBuildVectors); + + auto stringProbeType = + ROW({"c0", "c1", "c2"}, {VARCHAR(), BIGINT(), BIGINT()}); + + auto stringBuildSide = PlanBuilder(planNodeIdGenerator, pool_.get()) + .values(stringBuildVectors) + .project({"c0 AS u_c0", "c1 AS u_c1"}) + .planNode(); + + // Inner join. + { + SCOPED_TRACE("VARCHAR Inner join with skipped splits"); + core::PlanNodeId probeScanId; + auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) + .tableScan(stringProbeType, {"c2 > 0"}) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"c0"}, + {"u_c0"}, + stringBuildSide, + "", + {"c0", "c1", "u_c1"}, + core::JoinType::kInner) + .project({"c0", "c1 + 1", "c1 + u_c1"}) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .numDrivers(1) + .makeInputSplits(makeStringInputSplits(probeScanId)) + .config( + core::QueryConfig::kHashProbeDynamicFilterPushdownEnabled, "true") + .config( + core::QueryConfig::kHashProbeStringDynamicFilterPushdownEnabled, + "true") + .referenceQuery( + "SELECT t_str_skip.c0, t_str_skip.c1 + 1, t_str_skip.c1 + u_str_skip.c1 FROM t_str_skip, u_str_skip WHERE t_str_skip.c0 = u_str_skip.c0 AND t_str_skip.c2 > 0") + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); + if (hasSpill) { + ASSERT_EQ(0, getFiltersProduced(task, 1).sum); + ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); + ASSERT_EQ( + getInputPositions(task, 1), + numRowsProbe * numNonSkippedSplits); + } else { + ASSERT_EQ(1, getFiltersProduced(task, 1).sum); + ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(0, getReplacedWithFilterRows(task, 1).sum); + ASSERT_LT( + getInputPositions(task, 1), + numRowsProbe * numNonSkippedSplits); + } + }) + .run(); + } + + // Left semi join. + { + SCOPED_TRACE("VARCHAR Left semi join with skipped splits"); + core::PlanNodeId probeScanId; + auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) + .tableScan(stringProbeType, {"c2 > 0"}) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"c0"}, + {"u_c0"}, + stringBuildSide, + "", + {"c0", "c1"}, + core::JoinType::kLeftSemiFilter) + .project({"c0", "c1 + 1"}) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .numDrivers(1) + .makeInputSplits(makeStringInputSplits(probeScanId)) + .config( + core::QueryConfig::kHashProbeDynamicFilterPushdownEnabled, "true") + .config( + core::QueryConfig::kHashProbeStringDynamicFilterPushdownEnabled, + "true") + .referenceQuery( + "SELECT t_str_skip.c0, t_str_skip.c1 + 1 FROM t_str_skip WHERE t_str_skip.c0 IN (SELECT c0 FROM u_str_skip) AND t_str_skip.c2 > 0") + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); + if (hasSpill) { + ASSERT_EQ(0, getFiltersProduced(task, 1).sum); + ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(0, getReplacedWithFilterRows(task, 1).sum); + ASSERT_EQ( + getInputPositions(task, 1), + numRowsProbe * numNonSkippedSplits); + } else { + ASSERT_EQ(1, getFiltersProduced(task, 1).sum); + ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); + ASSERT_GT(getReplacedWithFilterRows(task, 1).sum, 0); + ASSERT_LT( + getInputPositions(task, 1), + numRowsProbe * numNonSkippedSplits); + } + }) + .run(); + } + + // Right semi join. + { + SCOPED_TRACE("VARCHAR Right semi join with skipped splits"); + core::PlanNodeId probeScanId; + auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) + .tableScan(stringProbeType, {"c2 > 0"}) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"c0"}, + {"u_c0"}, + stringBuildSide, + "", + {"u_c0", "u_c1"}, + core::JoinType::kRightSemiFilter) + .project({"u_c0", "u_c1 + 1"}) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .numDrivers(1) + .makeInputSplits(makeStringInputSplits(probeScanId)) + .config( + core::QueryConfig::kHashProbeDynamicFilterPushdownEnabled, "true") + .config( + core::QueryConfig::kHashProbeStringDynamicFilterPushdownEnabled, + "true") + .referenceQuery( + "SELECT u_str_skip.c0, u_str_skip.c1 + 1 FROM u_str_skip WHERE u_str_skip.c0 IN (SELECT c0 FROM t_str_skip WHERE t_str_skip.c2 > 0)") + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + SCOPED_TRACE(fmt::format("hasSpill:{}", hasSpill)); + if (hasSpill) { + ASSERT_EQ(0, getFiltersProduced(task, 1).sum); + ASSERT_EQ(0, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(getReplacedWithFilterRows(task, 1).sum, 0); + ASSERT_EQ( + getInputPositions(task, 1), + numRowsProbe * numNonSkippedSplits); + } else { + ASSERT_EQ(1, getFiltersProduced(task, 1).sum); + ASSERT_EQ(1, getFiltersAccepted(task, 0).sum); + ASSERT_EQ(getReplacedWithFilterRows(task, 1).sum, 0); + ASSERT_LT( + getInputPositions(task, 1), + numRowsProbe * numNonSkippedSplits); + } + }) + .run(); + } + } +} + +TEST_P(HashJoinTest, dynamicFiltersAppliedToPreloadedSplits) { + vector_size_t size = 1000; + const int32_t numSplits = 5; + + std::vector probeVectors; + probeVectors.reserve(numSplits); + + // Prepare probe side table. + std::vector> tempFiles; + std::vector probeSplits; + for (int32_t i = 0; i < numSplits; ++i) { + auto rowVector = makeRowVector( + {"p0", "p1"}, + { + makeFlatVector( + size, [&](auto row) { return (row + 1) * (i + 1); }), + makeFlatVector(size, [&](auto /*row*/) { return i; }), + }); + probeVectors.push_back(rowVector); + tempFiles.push_back(TempFilePath::create()); + writeToFile(tempFiles.back()->getPath(), rowVector); + } + + auto makeInputSplits = [&](const core::PlanNodeId& nodeId) { + return [&] { + std::vector splits; + splits.reserve(tempFiles.size()); + for (int32_t i = 0; i < tempFiles.size(); ++i) { + auto split = HiveConnectorSplitBuilder(tempFiles[i]->getPath()) + .partitionKey("p1", std::to_string(i)) + .build(); + splits.emplace_back(exec::Split(split)); + } + SplitInput inputSplits; + inputSplits.emplace(nodeId, splits); + return inputSplits; + }; + }; + + auto outputType = ROW({"p0", "p1"}, {BIGINT(), BIGINT()}); + connector::ColumnHandleMap assignments = { + {"p0", regularColumn("p0", BIGINT())}, + {"p1", partitionKey("p1", BIGINT())}}; + createDuckDbTable("p", probeVectors); + + // Prepare build side table. + std::vector buildVectors{ + makeRowVector({"b0"}, {makeFlatVector({0, numSplits})})}; + createDuckDbTable("b", buildVectors); + + // Executing the join with p1=b0, we expect a dynamic filter for p1 to prune + // the entire file/split. There are total of five splits, and all except the + // first one are expected to be pruned. The result 'preloadedSplits' > 1 + // confirms the successful push of dynamic filters to the preloading data + // source. + core::PlanNodeId probeScanId; + core::PlanNodeId joinNodeId; + auto planNodeIdGenerator = std::make_shared(); + auto op = + PlanBuilder(planNodeIdGenerator) + .startTableScan() + .outputType(outputType) + .assignments(assignments) + .endTableScan() + .capturePlanNodeId(probeScanId) + .hashJoin( + {"p1"}, + {"b0"}, + PlanBuilder(planNodeIdGenerator).values(buildVectors).planNode(), + "", + {"p0"}, + core::JoinType::kInner) + .capturePlanNodeId(joinNodeId) + .project({"p0"}) + .planNode(); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .config(core::QueryConfig::kMaxSplitPreloadPerDriver, "3") + .injectSpill(false) + .makeInputSplits(makeInputSplits(probeScanId)) + .referenceQuery("select p.p0 from p, b where b.b0 = p.p1") + .checkSpillStats(false) + .verifier([&](const std::shared_ptr& task, bool /*hasSpill*/) { + auto planStats = toPlanStats(task->taskStats()); + auto getStatSum = [&](const core::PlanNodeId& id, + const std::string& name) { + return planStats.at(id).customStats.at(name).sum; + }; + ASSERT_EQ(1, getStatSum(joinNodeId, "dynamicFiltersProduced")); + ASSERT_EQ(1, getStatSum(probeScanId, "dynamicFiltersAccepted")); + ASSERT_EQ(4, getStatSum(probeScanId, "skippedSplits")); + ASSERT_LT(1, getStatSum(probeScanId, "preloadedSplits")); + }) + .run(); +} + +TEST_P(HashJoinTest, dynamicFiltersPushDownThroughAgg) { + const int32_t numRowsProbe = 300; + const int32_t numRowsBuild = 100; + + // Create probe data + std::vector probeVectors{makeRowVector({ + makeFlatVector(numRowsProbe, [&](auto row) { return row - 10; }), + makeFlatVector(numRowsProbe, folly::identity), + })}; + std::shared_ptr probeFile = TempFilePath::create(); + writeToFile(probeFile->getPath(), probeVectors); + + // Create build data + std::vector buildVectors{makeRowVector( + {"u0"}, {makeFlatVector(numRowsBuild, [&](auto row) { + return 35 + 2 * (row + numRowsBuild / 5); + })})}; + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto probeType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + auto planNodeIdGenerator = std::make_shared(); + auto buildSide = + PlanBuilder(planNodeIdGenerator).values(buildVectors).planNode(); + + // Inner join. + core::PlanNodeId scanNodeId; + core::PlanNodeId joinNodeId; + core::PlanNodeId aggNodeId; + auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) + .tableScan(probeType) + .capturePlanNodeId(scanNodeId) + .partialAggregation({"c0"}, {"sum(c1)"}) + .capturePlanNodeId(aggNodeId) + .hashJoin( + {"c0"}, + {"u0"}, + buildSide, + "", + {"c0", "a0"}, + core::JoinType::kInner) + .capturePlanNodeId(joinNodeId) + .planNode(); + + SplitPath splitPaths = {{scanNodeId, {probeFile->getPath()}}}; + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .inputSplits(splitPaths) + .injectSpill(false) + .checkSpillStats(false) + .referenceQuery("SELECT c0, sum(c1) FROM t, u WHERE c0 = u0 group by c0") + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + auto planStats = toPlanStats(task->taskStats()); + auto dynamicFilterStats = planStats.at(scanNodeId).dynamicFilterStats; + ASSERT_EQ( + 1, getFiltersProduced(task, getOperatorIndex(joinNodeId)).sum); + ASSERT_EQ( + 1, getFiltersAccepted(task, getOperatorIndex(scanNodeId)).sum); + ASSERT_LT( + getInputPositions(task, getOperatorIndex(aggNodeId)), numRowsProbe); + ASSERT_EQ( + dynamicFilterStats.producerNodeIds, + std::unordered_set({joinNodeId})); + }) + .run(); +} + +TEST_P(HashJoinTest, noDynamicFiltersPushDownThroughRightJoin) { + std::vector innerBuild = {makeRowVector( + {"a"}, + { + makeFlatVector(5, [](auto i) { return 2 * i; }), + })}; + std::vector rightBuild = {makeRowVector( + {"b"}, + { + makeFlatVector(5, [](auto i) { return 1 + 2 * i; }), + })}; + std::vector rightProbe = {makeRowVector( + {"aa", "bb"}, + { + makeFlatVector(10, folly::identity), + makeFlatVector(10, folly::identity), + })}; + auto file = TempFilePath::create(); + writeToFile(file->getPath(), rightProbe); + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId scanNodeId; + auto plan = + PlanBuilder(planNodeIdGenerator) + .tableScan(asRowType(rightProbe[0]->type())) + .capturePlanNodeId(scanNodeId) + .hashJoin( + {"bb"}, + {"b"}, + PlanBuilder(planNodeIdGenerator).values(rightBuild).planNode(), + "", + {"aa", "b"}, + core::JoinType::kRight) + .hashJoin( + {"aa"}, + {"a"}, + PlanBuilder(planNodeIdGenerator).values(innerBuild).planNode(), + "", + {"aa"}) + .planNode(); + AssertQueryBuilder(plan) + .split(scanNodeId, Split(makeHiveConnectorSplit(file->getPath()))) + .assertResults( + BaseVector::create(innerBuild[0]->type(), 0, pool_.get())); +} + +// Verify the size of the join output vectors when projecting build-side +// variable-width column. +TEST_P(HashJoinTest, memoryUsage) { + std::vector probeVectors = + makeBatches(10, [&](int32_t /*unused*/) { + return makeRowVector( + {makeFlatVector(1'000, [](auto row) { return row % 5; })}); + }); + std::vector buildVectors = + makeBatches(5, [&](int32_t /*unused*/) { + return makeRowVector( + {"u_c0", "u_c1"}, + {makeFlatVector({0, 1, 2}), + makeFlatVector({ + std::string(40, 'a'), + std::string(50, 'b'), + std::string(30, 'c'), + })}); + }); + core::PlanNodeId joinNodeId; + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors) + .hashJoin( + {"c0"}, + {"u_c0"}, + PlanBuilder(planNodeIdGenerator) + .values({buildVectors}) + .planNode(), + "", + {"c0", "u_c1"}) + .capturePlanNodeId(joinNodeId) + .singleAggregation({}, {"count(1)"}) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(plan)) + .referenceQuery("SELECT 30000") + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + if (hasSpill) { + return; + } + auto planStats = toPlanStats(task->taskStats()); + auto outputBytes = planStats.at(joinNodeId).outputBytes; + ASSERT_LT(outputBytes, ((40 + 50 + 30) / 3 + 8) * 1000 * 10 * 5); + // Verify number of memory allocations. Should not be too high if + // hash join is able to re-use output vectors that contain + // build-side data. + ASSERT_GT(40, task->pool()->stats().numAllocs); + }) + .run(); +} + +/// Test an edge case in producing small output batches where the logic to +/// calculate the set of probe-side rows to load lazy vectors for was +/// triggering a crash. +TEST_P(HashJoinTest, smallOutputBatchSize) { + // Setup probe data with 50 non-null matching keys followed by 50 null + // keys: 1, 2, 1, 2,...null, null. + auto probeVectors = makeRowVector({ + makeFlatVector( + 100, + [](auto row) { return 1 + row % 2; }, + [](auto row) { return row > 50; }), + makeFlatVector(100, [](auto row) { return row * 10; }), + }); + + // Setup build side to match non-null probe side keys. + auto buildVectors = makeRowVector( + {"u_c0", "u_c1"}, + { + makeFlatVector({1, 2}), + makeFlatVector({100, 200}), + }); + + createDuckDbTable("t", {probeVectors}); + createDuckDbTable("u", {buildVectors}); + + // Plan hash inner join with a filter. + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values({probeVectors}) + .hashJoin( + {"c0"}, + {"u_c0"}, + PlanBuilder(planNodeIdGenerator) + .values({buildVectors}) + .planNode(), + "c1 < u_c1", + {"c0", "u_c1"}) + .planNode(); + + // Use small output batch size to trigger logic for calculating set of + // probe-side rows to load lazy vectors for. + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(plan)) + .config(core::QueryConfig::kPreferredOutputBatchRows, std::to_string(10)) + .referenceQuery("SELECT c0, u_c1 FROM t, u WHERE c0 = u_c0 AND c1 < u_c1") + .injectSpill(false) + .run(); +} + +TEST_P(HashJoinTest, spillFileSize) { + const std::vector maxSpillFileSizes({0, 1, 1'000'000'000}); + for (const auto spillFileSize : maxSpillFileSizes) { + SCOPED_TRACE(fmt::format("spillFileSize: {}", spillFileSize)); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .keyTypes({BIGINT()}) + .probeVectors(100, 3) + .buildVectors(100, 3) + .referenceQuery( + "SELECT t_k0, t_data, u_k0, u_data FROM t, u WHERE t.t_k0 = u.u_k0") + .config(core::QueryConfig::kSpillStartPartitionBit, "48") + .config(core::QueryConfig::kSpillNumPartitionBits, "3") + .config( + core::QueryConfig::kMaxSpillFileSize, std::to_string(spillFileSize)) + .checkSpillStats(false) + .maxSpillLevel(0) + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + if (!hasSpill) { + return; + } + const auto statsPair = taskSpilledStats(*task); + const int32_t numPartitions = statsPair.first.spilledPartitions; + ASSERT_EQ(statsPair.second.spilledPartitions, numPartitions); + const auto fileSizes = numTaskSpillFiles(*task); + if (spillFileSize != 1) { + ASSERT_EQ(fileSizes.first, numPartitions); + } else { + ASSERT_GT(fileSizes.first, numPartitions); + } + verifyTaskSpilledRuntimeStats(*task, true); + }) + .run(); + } +} + +TEST_P(HashJoinTest, spillPartitionBitsOverlap) { + auto builder = + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .keyTypes({BIGINT(), BIGINT()}) + .probeVectors(2'000, 3) + .buildVectors(2'000, 3) + .referenceQuery( + "SELECT t_k0, t_k1, t_data, u_k0, u_k1, u_data FROM t, u WHERE t_k0 = u_k0 and t_k1 = u_k1") + .config(core::QueryConfig::kSpillStartPartitionBit, "8") + .config(core::QueryConfig::kSpillNumPartitionBits, "1") + .checkSpillStats(false) + .maxSpillLevel(0); + VELOX_ASSERT_THROW(builder.run(), "vs. 8"); +} + +// The test is to verify if the hash build reservation has been released on +// task error. +DEBUG_ONLY_TEST_P(HashJoinTest, buildReservationReleaseCheck) { + std::vector probeVectors = + makeBatches(1, [&](int32_t /*unused*/) { + return std::dynamic_pointer_cast( + BatchMaker::createBatch(probeType_, 1000, *pool_)); + }); + std::vector buildVectors = makeBatches(10, [&](int32_t index) { + return std::dynamic_pointer_cast( + BatchMaker::createBatch(buildType_, 5000 * (1 + index), *pool_)); + }); + + auto planNodeIdGenerator = std::make_shared(); + CursorParameters params; + params.planNode = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, true) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, true) + .planNode(), + "", + concat(probeType_->names(), buildType_->names())) + .planNode(); + params.queryCtx = core::QueryCtx::create(driverExecutor_.get()); + // NOTE: the spilling setup is to trigger memory reservation code path which + // only gets executed when spilling is enabled. We don't care about if + // spilling is really triggered in test or not. + auto spillDirectory = TempDirectoryPath::create(); + params.spillDirectory = spillDirectory->getPath(); + params.queryCtx->testingOverrideConfigUnsafe( + {{core::QueryConfig::kSpillEnabled, "true"}, + {core::QueryConfig::kMaxSpillLevel, "0"}}); + params.maxDrivers = 1; + + auto cursor = TaskCursor::create(params); + auto* task = cursor->task().get(); + + // Set up a testvalue to trigger task abort when hash build tries to reserve + // memory. + SCOPED_TESTVALUE_SET( + "facebook::velox::common::memory::MemoryPoolImpl::maybeReserve", + std::function( + [&](memory::MemoryPool* /*unused*/) { task->requestAbort(); })); + auto runTask = [&]() { + while (cursor->moveNext()) { + } + }; + VELOX_ASSERT_THROW(runTask(), ""); + ASSERT_TRUE(waitForTaskAborted(task, 5'000'000)); +} + +TEST_P(HashJoinTest, dynamicFilterOnPartitionKey) { + vector_size_t size = 10; + auto filePaths = makeFilePaths(1); + auto rowVector = makeRowVector( + {makeFlatVector(size, [&](auto row) { return row; })}); + createDuckDbTable("u", {rowVector}); + writeToFile(filePaths[0]->getPath(), rowVector); + std::vector buildVectors{ + makeRowVector({"c0"}, {makeFlatVector({0, 1, 2})})}; + createDuckDbTable("t", buildVectors); + auto outputType = ROW({"n1_0", "n1_1"}, {BIGINT(), BIGINT()}); + connector::ColumnHandleMap assignments = { + {"n1_0", regularColumn("c0", BIGINT())}, + {"n1_1", partitionKey("k", BIGINT())}}; + + core::PlanNodeId probeScanId; + auto planNodeIdGenerator = std::make_shared(); + auto op = + PlanBuilder(planNodeIdGenerator) + .startTableScan() + .outputType(outputType) + .assignments(assignments) + .endTableScan() + .capturePlanNodeId(probeScanId) + .hashJoin( + {"n1_1"}, + {"c0"}, + PlanBuilder(planNodeIdGenerator).values(buildVectors).planNode(), + "", + {"c0"}, + core::JoinType::kInner) + .project({"c0"}) + .planNode(); + + auto makeInputSplits = [&](const core::PlanNodeId& nodeId) { + return [&] { + auto split = facebook::velox::exec::test::HiveConnectorSplitBuilder( + filePaths[0]->getPath()) + .partitionKey("k", "0") + .build(); + SplitInput splits = {{nodeId, {exec::Split(split)}}}; + return splits; + }; + }; + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .makeInputSplits(makeInputSplits(probeScanId)) + .referenceQuery("select t.c0 from t, u where t.c0 = 0") + .checkSpillStats(false) + .run(); +} + +TEST_P(HashJoinTest, probeMemoryLimitOnBuildProjection) { + const uint64_t numBuildRows = 20; + std::vector probeVectors = + makeBatches(10, [&](int32_t /*unused*/) { + return makeRowVector({makeFlatVector( + 1'000, [](auto row) { return row % 25; })}); + }); + + std::vector buildVectors = + makeBatches(1, [&](int32_t /*unused*/) { + return makeRowVector( + {"u_c0", "u_c1", "u_c2", "u_c3", "u_c4"}, + {makeFlatVector( + numBuildRows, [](auto row) { return row; }), + makeFlatVector( + numBuildRows, + [](auto /* row */) { return std::string(4096, 'a'); }), + makeFlatVector( + numBuildRows, + [](auto /* row */) { return std::string(4096, 'a'); }), + makeFlatVector( + numBuildRows, + [](auto row) { + // Row that has too large of size variation. + if (row == 0) { + return std::string(4096, 'a'); + } else { + return std::string(1, 'a'); + } + }), + makeFlatVector(numBuildRows, [](auto row) { + // Row that has tolerable size variation. + if (row == 0) { + return std::string(4096, 'a'); + } else { + return std::string(256, 'a'); + } + })}); + }); + + createDuckDbTable("t", {probeVectors}); + createDuckDbTable("u", {buildVectors}); + + struct TestParam { + std::vector varSizeColumns; + int32_t numExpectedBatches; + std::string referenceQuery; + std::string debugString() const { + std::stringstream ss; + ss << "varSizeColumns ["; + for (const auto& columnIndex : varSizeColumns) { + ss << columnIndex << ", "; + } + ss << "] "; + ss << "numExpectedBatches " << numExpectedBatches << ", referenceQuery '" + << referenceQuery << "'"; + return ss.str(); + } + }; + + std::vector testParams{ + {{}, 10, "SELECT t.c0 FROM t JOIN u ON t.c0 = u.u_c0"}, + {{1}, 4000, "SELECT t.c0, u.u_c1 FROM t JOIN u ON t.c0 = u.u_c0"}, + {{1, 2}, + 8000, + "SELECT t.c0, u.u_c1, u.u_c2 FROM t JOIN u ON t.c0 = u.u_c0"}, + {{3}, 210, "SELECT t.c0, u.u_c3 FROM t JOIN u ON t.c0 = u.u_c0"}, + {{4}, 2670, "SELECT t.c0, u.u_c4 FROM t JOIN u ON t.c0 = u.u_c0"}}; + + for (const auto& testParam : testParams) { + SCOPED_TRACE(testParam.debugString()); + core::PlanNodeId joinNodeId; + std::vector outputLayout; + outputLayout.push_back("c0"); + for (int32_t i = 0; i < testParam.varSizeColumns.size(); i++) { + outputLayout.push_back(fmt::format("u_c{}", testParam.varSizeColumns[i])); + } + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors) + .hashJoin( + {"c0"}, + {"u_c0"}, + PlanBuilder(planNodeIdGenerator) + .values({buildVectors}) + .planNode(), + "", + outputLayout) + .capturePlanNodeId(joinNodeId) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(plan)) + .config(core::QueryConfig::kPreferredOutputBatchBytes, "8192") + .injectSpill(false) + .referenceQuery(testParam.referenceQuery) + .verifier([&](const std::shared_ptr& task, bool /* unused */) { + auto planStats = toPlanStats(task->taskStats()); + auto outputBatches = planStats.at(joinNodeId).outputVectors; + ASSERT_EQ(outputBatches, testParam.numExpectedBatches); + }) + .run(); + } +} + +DEBUG_ONLY_TEST_P(HashJoinTest, reclaimDuringInputProcessing) { + constexpr int64_t kMaxBytes = 1LL << 30; // 1GB + VectorFuzzer fuzzer({.vectorSize = 1000}, pool()); + const int32_t numBuildVectors = 10; + std::vector buildVectors; + for (int32_t i = 0; i < numBuildVectors; ++i) { + buildVectors.push_back(fuzzer.fuzzRow(buildType_)); + } + const int32_t numProbeVectors = 5; + std::vector probeVectors; + for (int32_t i = 0; i < numProbeVectors; ++i) { + probeVectors.push_back(fuzzer.fuzzRow(probeType_)); + } + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + struct { + // 0: trigger reclaim with some input processed. + // 1: trigger reclaim after all the inputs processed. + int triggerCondition; + bool spillEnabled; + bool expectedReclaimable; + + std::string debugString() const { + return fmt::format( + "triggerCondition {}, spillEnabled {}, expectedReclaimable {}", + triggerCondition, + spillEnabled, + expectedReclaimable); + } + } testSettings[] = { + {0, true, true}, {0, true, true}, {0, false, false}, {0, false, false}}; + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + auto tempDirectory = TempDirectoryPath::create(); + auto queryPool = memory::memoryManager()->addRootPool( + "", kMaxBytes, memory::MemoryReclaimer::create()); + + core::PlanNodeId probeScanId; + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, false) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, false) + .planNode(), + "", + concat(probeType_->names(), buildType_->names())) + .planNode(); + + folly::EventCount driverWait; + auto driverWaitKey = driverWait.prepareWait(); + folly::EventCount testWait; + auto testWaitKey = testWait.prepareWait(); + + std::atomic numInputs{0}; + Operator* op; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::addInput", + std::function(([&](Operator* testOp) { + if (testOp->operatorType() != "HashBuild") { + return; + } + op = testOp; + ++numInputs; + if (testData.triggerCondition == 0) { + if (numInputs != 2) { + return; + } + } + if (testData.triggerCondition == 1) { + if (numInputs != numBuildVectors) { + return; + } + } + ASSERT_EQ(op->canReclaim(), testData.expectedReclaimable); + uint64_t reclaimableBytes{0}; + const bool reclaimable = op->reclaimableBytes(reclaimableBytes); + ASSERT_EQ(reclaimable, testData.expectedReclaimable); + if (testData.expectedReclaimable) { + ASSERT_GT(reclaimableBytes, 0); + } else { + ASSERT_EQ(reclaimableBytes, 0); + } + testWait.notify(); + driverWait.wait(driverWaitKey); + }))); + + std::thread taskThread([&]() { + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .planNode(plan) + .queryPool(std::move(queryPool)) + .injectSpill(false) + .spillDirectory(testData.spillEnabled ? tempDirectory->getPath() : "") + .referenceQuery( + "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") + .config(core::QueryConfig::kSpillStartPartitionBit, "29") + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + const auto statsPair = taskSpilledStats(*task); + if (testData.expectedReclaimable) { + ASSERT_GT(statsPair.first.spilledBytes, 0); + ASSERT_EQ(statsPair.first.spilledPartitions, 8); + ASSERT_GT(statsPair.second.spilledBytes, 0); + ASSERT_EQ(statsPair.second.spilledPartitions, 8); + verifyTaskSpilledRuntimeStats(*task, true); + } else { + ASSERT_EQ(statsPair.first.spilledBytes, 0); + ASSERT_EQ(statsPair.first.spilledPartitions, 0); + ASSERT_EQ(statsPair.second.spilledBytes, 0); + ASSERT_EQ(statsPair.second.spilledPartitions, 0); + verifyTaskSpilledRuntimeStats(*task, false); + } + }) + .run(); + }); + + testWait.wait(testWaitKey); + ASSERT_TRUE(op != nullptr); + auto task = op->operatorCtx()->task(); + auto taskPauseWait = task->requestPause(); + driverWait.notify(); + taskPauseWait.wait(); + + uint64_t reclaimableBytes{0}; + const bool reclaimable = op->reclaimableBytes(reclaimableBytes); + ASSERT_EQ(op->canReclaim(), testData.expectedReclaimable); + ASSERT_EQ(reclaimable, testData.expectedReclaimable); + if (testData.expectedReclaimable) { + ASSERT_GT(reclaimableBytes, 0); + } else { + ASSERT_EQ(reclaimableBytes, 0); + } + + if (testData.expectedReclaimable) { + { + memory::ScopedMemoryArbitrationContext ctx(op->pool()); + op->pool()->reclaim( + folly::Random::oneIn(2) ? 0 : folly::Random::rand32(), + 0, + reclaimerStats_); + } + ASSERT_GT(reclaimerStats_.reclaimExecTimeUs, 0); + ASSERT_GT(reclaimerStats_.reclaimedBytes, 0); + reclaimerStats_.reset(); + ASSERT_EQ(op->pool()->usedBytes(), 0); + } else { + VELOX_ASSERT_THROW( + op->reclaim( + folly::Random::oneIn(2) ? 0 : folly::Random::rand32(), + reclaimerStats_), + ""); + } + + Task::resume(task); + task.reset(); + + taskThread.join(); + } + ASSERT_EQ(reclaimerStats_, memory::MemoryReclaimer::Stats{}); +} + +DEBUG_ONLY_TEST_P(HashJoinTest, reclaimDuringReserve) { + constexpr int64_t kMaxBytes = 1LL << 30; // 1GB + const int32_t numBuildVectors = 3; + std::vector buildVectors; + for (int32_t i = 0; i < numBuildVectors; ++i) { + const size_t size = i == 0 ? 1 : 1'000; + VectorFuzzer fuzzer({.vectorSize = size}, pool()); + buildVectors.push_back(fuzzer.fuzzRow(buildType_)); + } + + const int32_t numProbeVectors = 3; + std::vector probeVectors; + for (int32_t i = 0; i < numProbeVectors; ++i) { + VectorFuzzer fuzzer({.vectorSize = 1'000}, pool()); + probeVectors.push_back(fuzzer.fuzzRow(probeType_)); + } + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto tempDirectory = TempDirectoryPath::create(); + auto queryPool = memory::memoryManager()->addRootPool( + "", kMaxBytes, memory::MemoryReclaimer::create()); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, false) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, false) + .planNode(), + "", + concat(probeType_->names(), buildType_->names())) + .planNode(); + + folly::EventCount driverWait; + std::atomic_bool driverWaitFlag{true}; + folly::EventCount testWait; + std::atomic_bool testWaitFlag{true}; + + Operator* op{nullptr}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::addInput", + std::function(([&](Operator* testOp) { + if (testOp->operatorType() != "HashBuild") { + return; + } + op = testOp; + }))); + + std::atomic injectOnce{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::common::memory::MemoryPoolImpl::maybeReserve", + std::function( + ([&](memory::MemoryPoolImpl* pool) { + ASSERT_TRUE(op != nullptr); + if (!isHashBuildMemoryPool(*pool)) { + return; + } + ASSERT_TRUE(op->canReclaim()); + if (op->pool()->usedBytes() == 0) { + // We skip trigger memory reclaim when the hash table is empty on + // memory reservation. + return; + } + if (!injectOnce.exchange(false)) { + return; + } + uint64_t reclaimableBytes{0}; + const bool reclaimable = op->reclaimableBytes(reclaimableBytes); + ASSERT_TRUE(reclaimable); + ASSERT_GT(reclaimableBytes, 0); + auto* driver = op->operatorCtx()->driver(); + TestSuspendedSection suspendedSection(driver); + testWaitFlag = false; + testWait.notifyAll(); + driverWait.await([&]() { return !driverWaitFlag.load(); }); + }))); + + std::thread taskThread([&]() { + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .planNode(plan) + .queryPool(std::move(queryPool)) + .injectSpill(false) + .spillDirectory(tempDirectory->getPath()) + .referenceQuery( + "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") + .config(core::QueryConfig::kSpillStartPartitionBit, "29") + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + const auto statsPair = taskSpilledStats(*task); + ASSERT_GT(statsPair.first.spilledBytes, 0); + ASSERT_EQ(statsPair.first.spilledPartitions, 8); + ASSERT_GT(statsPair.second.spilledBytes, 0); + ASSERT_EQ(statsPair.second.spilledPartitions, 8); + verifyTaskSpilledRuntimeStats(*task, true); + }) + .run(); + }); + + testWait.await([&]() { return !testWaitFlag.load(); }); + ASSERT_TRUE(op != nullptr); + auto task = op->operatorCtx()->task(); + task->requestPause().wait(); + + uint64_t reclaimableBytes{0}; + const bool reclaimable = op->reclaimableBytes(reclaimableBytes); + ASSERT_TRUE(op->canReclaim()); + ASSERT_TRUE(reclaimable); + ASSERT_GT(reclaimableBytes, 0); + + { + memory::ScopedMemoryArbitrationContext ctx(op->pool()); + uint64_t reclaimedBytes = task->pool()->reclaim( + folly::Random::oneIn(2) ? 0 : folly::Random::rand32(), + 0, + reclaimerStats_); + ASSERT_GT(reclaimedBytes, 0); + } + ASSERT_GT(reclaimerStats_.reclaimedBytes, 0); + ASSERT_GT(reclaimerStats_.reclaimExecTimeUs, 0); + ASSERT_EQ(op->pool()->usedBytes(), 0); + + driverWaitFlag = false; + driverWait.notifyAll(); + Task::resume(task); + task.reset(); + + taskThread.join(); +} + +DEBUG_ONLY_TEST_P(HashJoinTest, reclaimDuringAllocation) { + constexpr int64_t kMaxBytes = 1LL << 30; // 1GB + VectorFuzzer fuzzer({.vectorSize = 1000}, pool()); + const int32_t numBuildVectors = 10; + std::vector buildVectors; + for (int32_t i = 0; i < numBuildVectors; ++i) { + buildVectors.push_back(fuzzer.fuzzRow(buildType_)); + } + const int32_t numProbeVectors = 5; + std::vector probeVectors; + for (int32_t i = 0; i < numProbeVectors; ++i) { + probeVectors.push_back(fuzzer.fuzzRow(probeType_)); + } + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + const std::vector enableSpillings = {false, true}; + for (const auto enableSpilling : enableSpillings) { + SCOPED_TRACE(fmt::format("enableSpilling {}", enableSpilling)); + + auto tempDirectory = TempDirectoryPath::create(); + auto queryPool = memory::memoryManager()->addRootPool("", kMaxBytes); + + core::PlanNodeId probeScanId; + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, false) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, false) + .planNode(), + "", + concat(probeType_->names(), buildType_->names())) + .planNode(); + + folly::EventCount driverWait; + auto driverWaitKey = driverWait.prepareWait(); + folly::EventCount testWait; + auto testWaitKey = testWait.prepareWait(); + + Operator* op; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::addInput", + std::function(([&](Operator* testOp) { + if (testOp->operatorType() != "HashBuild") { + return; + } + op = testOp; + }))); + + std::atomic injectOnce{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::common::memory::MemoryPoolImpl::allocateNonContiguous", + std::function( + ([&](memory::MemoryPoolImpl* pool) { + ASSERT_TRUE(op != nullptr); + const std::string re(".*HashBuild"); + if (!RE2::FullMatch(pool->name(), re)) { + return; + } + if (!injectOnce.exchange(false)) { + return; + } + ASSERT_EQ(op->canReclaim(), enableSpilling); + uint64_t reclaimableBytes{0}; + const bool reclaimable = op->reclaimableBytes(reclaimableBytes); + ASSERT_EQ(reclaimable, enableSpilling); + if (enableSpilling) { + ASSERT_GE(reclaimableBytes, 0); + } else { + ASSERT_EQ(reclaimableBytes, 0); + } + auto* driver = op->operatorCtx()->driver(); + TestSuspendedSection suspendedSection(driver); + testWait.notify(); + driverWait.wait(driverWaitKey); + }))); + + std::thread taskThread([&]() { + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .planNode(plan) + .queryPool(std::move(queryPool)) + .injectSpill(false) + .spillDirectory(enableSpilling ? tempDirectory->getPath() : "") + .referenceQuery( + "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + const auto statsPair = taskSpilledStats(*task); + ASSERT_EQ(statsPair.first.spilledBytes, 0); + ASSERT_EQ(statsPair.first.spilledPartitions, 0); + ASSERT_EQ(statsPair.second.spilledBytes, 0); + ASSERT_EQ(statsPair.second.spilledPartitions, 0); + verifyTaskSpilledRuntimeStats(*task, false); + }) + .run(); + }); + + testWait.wait(testWaitKey); + ASSERT_TRUE(op != nullptr); + auto task = op->operatorCtx()->task(); + auto taskPauseWait = task->requestPause(); + taskPauseWait.wait(); + + uint64_t reclaimableBytes{0}; + const bool reclaimable = op->reclaimableBytes(reclaimableBytes); + ASSERT_EQ(op->canReclaim(), enableSpilling); + ASSERT_EQ(reclaimable, enableSpilling); + if (enableSpilling) { + ASSERT_GE(reclaimableBytes, 0); + } else { + ASSERT_EQ(reclaimableBytes, 0); + } + VELOX_ASSERT_THROW( + op->reclaim( + folly::Random::oneIn(2) ? 0 : folly::Random::rand32(), + reclaimerStats_), + ""); + + driverWait.notify(); + Task::resume(task); + task.reset(); + + taskThread.join(); + } + ASSERT_EQ(reclaimerStats_, memory::MemoryReclaimer::Stats{0}); +} + +DEBUG_ONLY_TEST_P(HashJoinTest, reclaimDuringOutputProcessing) { + constexpr int64_t kMaxBytes = 1LL << 30; // 1GB + VectorFuzzer fuzzer({.vectorSize = 1000}, pool()); + const int32_t numBuildVectors = 10; + std::vector buildVectors; + for (int32_t i = 0; i < numBuildVectors; ++i) { + buildVectors.push_back(fuzzer.fuzzRow(buildType_)); + } + const int32_t numProbeVectors = 5; + std::vector probeVectors; + for (int32_t i = 0; i < numProbeVectors; ++i) { + probeVectors.push_back(fuzzer.fuzzRow(probeType_)); + } + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + const std::vector enableSpillings = {false, true}; + for (const auto enableSpilling : enableSpillings) { + SCOPED_TRACE(fmt::format("enableSpilling {}", enableSpilling)); + auto tempDirectory = TempDirectoryPath::create(); + auto queryPool = memory::memoryManager()->addRootPool( + "", kMaxBytes, memory::MemoryReclaimer::create()); + + core::PlanNodeId probeScanId; + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, false) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, false) + .planNode(), + "", + concat(probeType_->names(), buildType_->names())) + .planNode(); + + std::atomic_bool driverWaitFlag{true}; + folly::EventCount driverWait; + std::atomic_bool testWaitFlag{true}; + folly::EventCount testWait; + + std::atomic injectOnce{true}; + Operator* op; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::noMoreInput", + std::function(([&](Operator* testOp) { + if (testOp->operatorType() != "HashBuild") { + return; + } + op = testOp; + if (!injectOnce.exchange(false)) { + return; + } + ASSERT_EQ(op->canReclaim(), enableSpilling); + uint64_t reclaimableBytes{0}; + const bool reclaimable = op->reclaimableBytes(reclaimableBytes); + ASSERT_EQ(reclaimable, enableSpilling); + if (enableSpilling) { + ASSERT_GT(reclaimableBytes, 0); + } else { + ASSERT_EQ(reclaimableBytes, 0); + } + testWaitFlag = false; + testWait.notifyAll(); + driverWait.await([&]() { return !driverWaitFlag.load(); }); + }))); + + std::thread taskThread([&]() { + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .planNode(plan) + .queryPool(std::move(queryPool)) + .injectSpill(false) + .spillDirectory(enableSpilling ? tempDirectory->getPath() : "") + .referenceQuery( + "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + const auto statsPair = taskSpilledStats(*task); + ASSERT_EQ(statsPair.first.spilledBytes, 0); + ASSERT_EQ(statsPair.first.spilledPartitions, 0); + ASSERT_EQ(statsPair.second.spilledBytes, 0); + ASSERT_EQ(statsPair.second.spilledPartitions, 0); + verifyTaskSpilledRuntimeStats(*task, false); + }) + .run(); + }); + + testWait.await([&]() { return !testWaitFlag.load(); }); + ASSERT_TRUE(op != nullptr); + auto task = op->operatorCtx()->task(); + auto taskPauseWait = task->requestPause(); + driverWaitFlag = false; + driverWait.notifyAll(); + taskPauseWait.wait(); + + uint64_t reclaimableBytes{0}; + const bool reclaimable = op->reclaimableBytes(reclaimableBytes); + ASSERT_EQ(op->canReclaim(), enableSpilling); + ASSERT_EQ(reclaimable, enableSpilling); + + if (enableSpilling) { + ASSERT_GT(reclaimableBytes, 0); + const auto usedMemoryBytes = op->pool()->usedBytes(); + { + memory::ScopedMemoryArbitrationContext ctx(op->pool()); + op->pool()->reclaim( + folly::Random::oneIn(2) ? 0 : folly::Random::rand32(), + 0, + reclaimerStats_); + } + ASSERT_GE(reclaimerStats_.reclaimedBytes, 0); + ASSERT_GT(reclaimerStats_.reclaimExecTimeUs, 0); + // No reclaim as the operator has started output processing. + ASSERT_EQ(usedMemoryBytes, op->pool()->usedBytes()); + } else { + ASSERT_EQ(reclaimableBytes, 0); + VELOX_ASSERT_THROW( + op->reclaim( + folly::Random::oneIn(2) ? 0 : folly::Random::rand32(), + reclaimerStats_), + ""); + } + + Task::resume(task); + task.reset(); + + taskThread.join(); + } + ASSERT_EQ(reclaimerStats_.numNonReclaimableAttempts, 1); +} + +DEBUG_ONLY_TEST_P(HashJoinTest, reclaimDuringWaitForProbe) { + constexpr int64_t kMaxBytes = 1LL << 30; // 1GB + VectorFuzzer fuzzer({.vectorSize = 1000}, pool()); + const int32_t numBuildVectors = 10; + std::vector buildVectors; + for (int32_t i = 0; i < numBuildVectors; ++i) { + buildVectors.push_back(fuzzer.fuzzRow(buildType_)); + } + const int32_t numProbeVectors = 5; + std::vector probeVectors; + for (int32_t i = 0; i < numProbeVectors; ++i) { + probeVectors.push_back(fuzzer.fuzzRow(probeType_)); + } + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto tempDirectory = TempDirectoryPath::create(); + auto queryPool = memory::memoryManager()->addRootPool( + "", kMaxBytes, memory::MemoryReclaimer::create()); + + core::PlanNodeId probeScanId; + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, false) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, false) + .planNode(), + "", + concat(probeType_->names(), buildType_->names())) + .planNode(); + + std::atomic_bool driverWaitFlag{true}; + folly::EventCount driverWait; + std::atomic_bool testWaitFlag{true}; + folly::EventCount testWait; + + Operator* op{nullptr}; + std::atomic injectSpillOnce{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::HashBuild::finishHashBuild", + std::function(([&](Operator* testOp) { + if (testOp->operatorType() != "HashBuild") { + return; + } + op = testOp; + if (!injectSpillOnce.exchange(false)) { + return; + } + auto* driver = op->operatorCtx()->driver(); + auto task = driver->task(); + memory::ScopedMemoryArbitrationContext ctx(op->pool()); + Operator::ReclaimableSectionGuard guard(testOp); + testingRunArbitration(testOp->pool()); + }))); + + std::atomic injectOnce{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::noMoreInput", + std::function(([&](Operator* testOp) { + if (testOp->operatorType() != "HashProbe") { + return; + } + if (!injectOnce.exchange(false)) { + return; + } + ASSERT_TRUE(op != nullptr); + ASSERT_TRUE(op->canReclaim()); + uint64_t reclaimableBytes{0}; + const bool reclaimable = op->reclaimableBytes(reclaimableBytes); + ASSERT_TRUE(reclaimable); + ASSERT_GT(reclaimableBytes, 0); + testWaitFlag = false; + testWait.notifyAll(); + auto* driver = testOp->operatorCtx()->driver(); + auto task = driver->task(); + TestSuspendedSection suspendedSection(driver); + driverWait.await([&]() { return !driverWaitFlag.load(); }); + }))); + + std::thread taskThread([&]() { + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .planNode(plan) + .queryPool(std::move(queryPool)) + .injectSpill(false) + .spillDirectory(tempDirectory->getPath()) + .referenceQuery( + "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") + .config(core::QueryConfig::kSpillStartPartitionBit, "29") + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + const auto statsPair = taskSpilledStats(*task); + ASSERT_GT(statsPair.first.spilledBytes, 0); + ASSERT_EQ(statsPair.first.spilledPartitions, 8); + ASSERT_GT(statsPair.second.spilledBytes, 0); + ASSERT_EQ(statsPair.second.spilledPartitions, 8); + }) + .run(); + }); + + testWait.await([&]() { return !testWaitFlag.load(); }); + ASSERT_TRUE(op != nullptr); + auto task = op->operatorCtx()->task(); + auto taskPauseWait = task->requestPause(); + taskPauseWait.wait(); + + uint64_t reclaimableBytes{0}; + const bool reclaimable = op->reclaimableBytes(reclaimableBytes); + ASSERT_TRUE(op->canReclaim()); + ASSERT_TRUE(reclaimable); + ASSERT_GT(reclaimableBytes, 0); + + const auto usedMemoryBytes = op->pool()->usedBytes(); + reclaimerStats_.reset(); + { + memory::ScopedMemoryArbitrationContext ctx(op->pool()); + op->pool()->reclaim( + folly::Random::oneIn(2) ? 0 : folly::Random::rand32(), + 0, + reclaimerStats_); + } + ASSERT_GE(reclaimerStats_.reclaimedBytes, 0); + ASSERT_GT(reclaimerStats_.reclaimExecTimeUs, 0); + // No reclaim as the build operator is not in building table state. + ASSERT_EQ(usedMemoryBytes, op->pool()->usedBytes()); + + driverWaitFlag = false; + driverWait.notifyAll(); + Task::resume(task); + task.reset(); + + taskThread.join(); + ASSERT_EQ(reclaimerStats_.numNonReclaimableAttempts, 1); +} + +DEBUG_ONLY_TEST_P(HashJoinTest, hashBuildAbortDuringOutputProcessing) { + const auto buildVectors = makeVectors(buildType_, 10, 128); + const auto probeVectors = makeVectors(probeType_, 5, 128); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + struct { + bool abortFromRootMemoryPool; + int numDrivers; + + std::string debugString() const { + return fmt::format( + "abortFromRootMemoryPool {} numDrivers {}", + abortFromRootMemoryPool, + numDrivers); + } + } testSettings[] = {{true, 1}, {false, 1}, {true, 4}, {false, 4}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, true) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, true) + .planNode(), + "", + concat(probeType_->names(), buildType_->names())) + .planNode(); + + std::atomic injectOnce{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::noMoreInput", + std::function(([&](Operator* op) { + if (op->operatorType() != "HashBuild") { + return; + } + if (!injectOnce.exchange(false)) { + return; + } + ASSERT_GT(op->pool()->usedBytes(), 0); + auto* driver = op->operatorCtx()->driver(); + ASSERT_EQ( + driver->task()->enterSuspended(driver->state()), + StopReason::kNone); + testData.abortFromRootMemoryPool ? abortPool(op->pool()->root()) + : abortPool(op->pool()); + // We can't directly reclaim memory from this hash build operator as + // its driver thread is running and in suspension state. + ASSERT_GT(op->pool()->root()->usedBytes(), 0); + ASSERT_EQ( + driver->task()->leaveSuspended(driver->state()), + StopReason::kAlreadyTerminated); + ASSERT_TRUE(op->pool()->aborted()); + ASSERT_TRUE(op->pool()->root()->aborted()); + VELOX_MEM_POOL_ABORTED("Memory pool aborted"); + }))); + + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .planNode(plan) + .injectSpill(false) + .referenceQuery( + "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") + .run(), + "Manual MemoryPool Abortion"); + waitForAllTasksToBeDeleted(); + } +} + +DEBUG_ONLY_TEST_P(HashJoinTest, hashBuildAbortDuringInputProcessing) { + const auto buildVectors = makeVectors(buildType_, 10, 128); + const auto probeVectors = makeVectors(probeType_, 5, 128); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + struct { + bool abortFromRootMemoryPool; + int numDrivers; + + std::string debugString() const { + return fmt::format( + "abortFromRootMemoryPool {} numDrivers {}", + abortFromRootMemoryPool, + numDrivers); + } + } testSettings[] = {{true, 1}, {false, 1}, {true, 4}, {false, 4}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, true) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, true) + .planNode(), + "", + concat(probeType_->names(), buildType_->names())) + .planNode(); + + std::atomic numInputs{0}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::addInput", + std::function(([&](Operator* op) { + if (op->operatorType() != "HashBuild") { + return; + } + if (++numInputs != 2) { + return; + } + ASSERT_GT(op->pool()->usedBytes(), 0); + auto* driver = op->operatorCtx()->driver(); + ASSERT_EQ( + driver->task()->enterSuspended(driver->state()), + StopReason::kNone); + testData.abortFromRootMemoryPool ? abortPool(op->pool()->root()) + : abortPool(op->pool()); + // We can't directly reclaim memory from this hash build operator as + // its driver thread is running and in suspension state. + ASSERT_GT(op->pool()->root()->usedBytes(), 0); + ASSERT_EQ( + driver->task()->leaveSuspended(driver->state()), + StopReason::kAlreadyTerminated); + ASSERT_TRUE(op->pool()->aborted()); + ASSERT_TRUE(op->pool()->root()->aborted()); + VELOX_MEM_POOL_ABORTED("Memory pool aborted"); + }))); + + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .planNode(plan) + .injectSpill(false) + .referenceQuery( + "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") + .run(), + "Manual MemoryPool Abortion"); + + waitForAllTasksToBeDeleted(); + } +} + +DEBUG_ONLY_TEST_P(HashJoinTest, hashBuildAbortDuringAllocation) { + const auto buildVectors = makeVectors(buildType_, 10, 128); + const auto probeVectors = makeVectors(probeType_, 5, 128); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + struct { + bool abortFromRootMemoryPool; + int numDrivers; + + std::string debugString() const { + return fmt::format( + "abortFromRootMemoryPool {} numDrivers {}", + abortFromRootMemoryPool, + numDrivers); + } + } testSettings[] = {{true, 1}, {false, 1}, {true, 4}, {false, 4}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, true) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, true) + .planNode(), + "", + concat(probeType_->names(), buildType_->names())) + .planNode(); + + std::atomic_bool injectOnce{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::common::memory::MemoryPoolImpl::allocateNonContiguous", + std::function( + ([&](memory::MemoryPoolImpl* pool) { + if (!isHashBuildMemoryPool(*pool)) { + return; + } + if (!injectOnce.exchange(false)) { + return; + } + + const auto* driverCtx = driverThreadContext()->driverCtx(); + ASSERT_EQ( + driverCtx->task->enterSuspended(driverCtx->driver->state()), + StopReason::kNone); + testData.abortFromRootMemoryPool ? abortPool(pool->root()) + : abortPool(pool); + // We can't directly reclaim memory from this hash build operator + // as its driver thread is running and in suspegnsion state. + ASSERT_GE(pool->root()->usedBytes(), 0); + ASSERT_EQ( + driverCtx->task->leaveSuspended(driverCtx->driver->state()), + StopReason::kAlreadyTerminated); + ASSERT_TRUE(pool->aborted()); + ASSERT_TRUE(pool->root()->aborted()); + VELOX_MEM_POOL_ABORTED("Memory pool aborted"); + }))); + + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .planNode(plan) + .injectSpill(false) + .referenceQuery( + "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") + .run(), + "Manual MemoryPool Abortion"); + + waitForAllTasksToBeDeleted(); + } +} + +DEBUG_ONLY_TEST_P(HashJoinTest, hashProbeAbortDuringInputProcessing) { + const auto buildVectors = makeVectors(buildType_, 10, 128); + const auto probeVectors = makeVectors(probeType_, 5, 128); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + struct { + bool abortFromRootMemoryPool; + int numDrivers; + + std::string debugString() const { + return fmt::format( + "abortFromRootMemoryPool {} numDrivers {}", + abortFromRootMemoryPool, + numDrivers); + } + } testSettings[] = {{true, 1}, {false, 1}, {true, 4}, {false, 4}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, true) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, true) + .planNode(), + "", + concat(probeType_->names(), buildType_->names())) + .planNode(); + + std::atomic numInputs{0}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::addInput", + std::function(([&](Operator* op) { + if (op->operatorType() != "HashProbe") { + return; + } + if (++numInputs != 2) { + return; + } + auto* driver = op->operatorCtx()->driver(); + ASSERT_EQ( + driver->task()->enterSuspended(driver->state()), + StopReason::kNone); + testData.abortFromRootMemoryPool ? abortPool(op->pool()->root()) + : abortPool(op->pool()); + ASSERT_EQ( + driver->task()->leaveSuspended(driver->state()), + StopReason::kAlreadyTerminated); + ASSERT_TRUE(op->pool()->aborted()); + ASSERT_TRUE(op->pool()->root()->aborted()); + VELOX_MEM_POOL_ABORTED("Memory pool aborted"); + }))); + + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .planNode(plan) + .injectSpill(false) + .referenceQuery( + "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") + .run(), + "Manual MemoryPool Abortion"); + waitForAllTasksToBeDeleted(); + } +} + +TEST_P(HashJoinTest, leftJoinWithMissAtEndOfBatch) { + // Tests some cases where the row at the end of an output batch fails the + // filter. + auto probeVectors = std::vector{makeRowVector( + {"t_k1", "t_k2"}, + {makeFlatVector(20, [](auto row) { return 1 + row % 2; }), + makeFlatVector(20, [](auto row) { return row; })})}; + auto buildVectors = std::vector{ + makeRowVector({"u_k1"}, {makeFlatVector({1, 2})})}; + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {buildVectors}); + auto planNodeIdGenerator = std::make_shared(); + + auto test = [&](const std::string& filter) { + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, true) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, true) + .planNode(), + filter, + {"t_k1", "u_k1"}, + core::JoinType::kLeft) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .injectSpill(false) + .checkSpillStats(false) + .maxSpillLevel(0) + .numDrivers(1) + .config( + core::QueryConfig::kPreferredOutputBatchRows, std::to_string(10)) + .referenceQuery( + fmt::format( + "SELECT t_k1, u_k1 from t left join u on t_k1 = u_k1 and {}", + filter)) + .run(); + }; + + // Alternate rows pass this filter and last row of a batch fails. + test("t_k1=1"); + + // All rows fail this filter. + test("t_k1=5"); + + // All rows in the second batch pass this filter. + test("t_k2 > 9"); +} + +TEST_P(HashJoinTest, leftJoinWithMissAtEndOfBatchMultipleBuildMatches) { + // Tests some cases where the row at the end of an output batch fails the + // filter and there are multiple matches with the build side.. + auto probeVectors = std::vector{makeRowVector( + {"t_k1", "t_k2"}, + {makeFlatVector(10, [](auto row) { return 1 + row % 2; }), + makeFlatVector(10, [](auto row) { return row; })})}; + auto buildVectors = std::vector{ + makeRowVector({"u_k1"}, {makeFlatVector({1, 2, 1, 2})})}; + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {buildVectors}); + auto planNodeIdGenerator = std::make_shared(); + + auto test = [&](const std::string& filter) { + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, true) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, true) + .planNode(), + filter, + {"t_k1", "u_k1"}, + core::JoinType::kLeft) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .injectSpill(false) + .checkSpillStats(false) + .maxSpillLevel(0) + .numDrivers(1) + .config( + core::QueryConfig::kPreferredOutputBatchRows, std::to_string(10)) + .referenceQuery( + fmt::format( + "SELECT t_k1, u_k1 from t left join u on t_k1 = u_k1 and {}", + filter)) + .run(); + }; + + // In this case the rows with t_k2 = 4 appear at the end of the first batch, + // meaning the last rows in that output batch are misses, and don't get added. + // The rows with t_k2 = 8 appear in the second batch so only one row is + // written, meaning there is space in the second output batch for the miss + // with tk_2 = 4 to get written. + test("t_k2 != 4 and t_k2 != 8"); +} + +TEST_P(HashJoinTest, leftJoinPreserveProbeOrder) { + const std::vector probeVectors = { + makeRowVector( + {"k1", "v1"}, + { + makeConstant(0, 2), + makeFlatVector({1, 0}), + }), + }; + const std::vector buildVectors = { + makeRowVector( + {"k2", "v2"}, + { + makeConstant(0, 2), + makeConstant(0, 2), + }), + }; + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .values(probeVectors) + .hashJoin( + {"k1"}, + {"k2"}, + PlanBuilder(planNodeIdGenerator).values(buildVectors).planNode(), + "v1 % 2 = v2 % 2", + {"v1"}, + core::JoinType::kLeft) + .planNode(); + auto result = AssertQueryBuilder(plan) + .config(core::QueryConfig::kPreferredOutputBatchRows, "1") + .serialExecution(true) + .copyResults(pool_.get()); + ASSERT_EQ(result->size(), 3); + auto* v1 = + result->childAt(0)->loadedVector()->asUnchecked>(); + ASSERT_FALSE(v1->mayHaveNulls()); + ASSERT_EQ(v1->valueAt(0), 1); + ASSERT_EQ(v1->valueAt(1), 0); + ASSERT_EQ(v1->valueAt(2), 0); +} + +DEBUG_ONLY_TEST_P(HashJoinTest, minSpillableMemoryReservation) { + VectorFuzzer fuzzer({.vectorSize = 1000}, pool()); + const int32_t numBuildVectors = 10; + std::vector buildVectors; + for (int32_t i = 0; i < numBuildVectors; ++i) { + buildVectors.push_back(fuzzer.fuzzInputRow(buildType_)); + } + const int32_t numProbeVectors = 5; + std::vector probeVectors; + for (int32_t i = 0; i < numProbeVectors; ++i) { + probeVectors.push_back(fuzzer.fuzzInputRow(probeType_)); + } + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + core::PlanNodeId probeScanId; + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, false) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, false) + .planNode(), + "", + concat(probeType_->names(), buildType_->names())) + .planNode(); + + for (int32_t minSpillableReservationPct : {5, 50, 100}) { + SCOPED_TRACE( + fmt::format( + "minSpillableReservationPct: {}", minSpillableReservationPct)); + + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::HashBuild::addInput", + std::function(([&](exec::HashBuild* hashBuild) { + memory::MemoryPool* pool = hashBuild->pool(); + const auto availableReservationBytes = pool->availableReservation(); + const auto currentUsedBytes = pool->usedBytes(); + // Verifies we always have min reservation after ensuring the input. + ASSERT_GE( + availableReservationBytes, + currentUsedBytes * minSpillableReservationPct / 100); + }))); + + auto tempDirectory = TempDirectoryPath::create(); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .planNode(plan) + .injectSpill(false) + .spillDirectory(tempDirectory->getPath()) + .referenceQuery( + "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") + .run(); + } +} + +DEBUG_ONLY_TEST_P(HashJoinTest, exceededMaxSpillLevel) { + VectorFuzzer fuzzer({.vectorSize = 1000}, pool()); + const int32_t numBuildVectors = 10; + std::vector buildVectors; + for (int32_t i = 0; i < numBuildVectors; ++i) { + buildVectors.push_back(fuzzer.fuzzRow(buildType_)); + } + const int32_t numProbeVectors = 5; + std::vector probeVectors; + for (int32_t i = 0; i < numProbeVectors; ++i) { + probeVectors.push_back(fuzzer.fuzzRow(probeType_)); + } + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + core::PlanNodeId probeScanId; + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, false) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, false) + .planNode(), + "", + concat(probeType_->names(), buildType_->names())) + .planNode(); + + auto tempDirectory = TempDirectoryPath::create(); + const int exceededMaxSpillLevelCount = + globalSpillStats().spillMaxLevelExceededCount; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::HashBuild::reclaim", + std::function(([&](exec::Operator* op) { + HashBuild* hashBuild = static_cast(op); + ASSERT_FALSE(hashBuild->testingExceededMaxSpillLevelLimit()); + }))); + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::HashProbe::reclaim", + std::function(([&](exec::Operator* op) { + HashProbe* hashProbe = static_cast(op); + ASSERT_FALSE(hashProbe->testingExceededMaxSpillLevelLimit()); + }))); + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::HashBuild::finishHashBuild", + std::function(([&](exec::HashBuild* hashBuild) { + Operator::ReclaimableSectionGuard guard(hashBuild); + testingRunArbitration(hashBuild->pool()); + }))); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(1) + .planNode(plan) + // Always trigger spilling. + .injectSpill(false) + .maxSpillLevel(0) + .spillDirectory(tempDirectory->getPath()) + .referenceQuery( + "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") + .config(core::QueryConfig::kSpillStartPartitionBit, "29") + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + auto opStats = toOperatorStats(task->taskStats()); + ASSERT_EQ( + opStats.at("HashProbe") + .runtimeStats[std::string(Operator::kExceededMaxSpillLevel)] + .sum, + 8); + ASSERT_EQ( + opStats.at("HashProbe") + .runtimeStats[std::string(Operator::kExceededMaxSpillLevel)] + .count, + 1); + ASSERT_EQ( + opStats.at("HashBuild") + .runtimeStats[std::string(Operator::kExceededMaxSpillLevel)] + .sum, + 8); + ASSERT_EQ( + opStats.at("HashBuild") + .runtimeStats[std::string(Operator::kExceededMaxSpillLevel)] + .count, + 1); + }) + .run(); + ASSERT_EQ( + globalSpillStats().spillMaxLevelExceededCount, + exceededMaxSpillLevelCount + 16); +} + +TEST_P(HashJoinTest, maxSpillBytes) { + const auto rowType = + ROW({"c0", "c1", "c2"}, {INTEGER(), INTEGER(), VARCHAR()}); + const auto probeVectors = createVectors(rowType, 1024, 10 << 20); + const auto buildVectors = createVectors(rowType, 1024, 10 << 20); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, true) + .project({"c0", "c1", "c2"}) + .hashJoin( + {"c0"}, + {"u1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, true) + .project({"c0 AS u0", "c1 AS u1", "c2 AS u2"}) + .planNode(), + "", + {"c0", "c1", "c2"}, + core::JoinType::kInner) + .planNode(); + + auto spillDirectory = TempDirectoryPath::create(); + auto queryCtx = core::QueryCtx::create(executor_.get()); + + struct { + int32_t maxSpilledBytes; + bool expectedExceedLimit; + std::string debugString() const { + return fmt::format("maxSpilledBytes {}", maxSpilledBytes); + } + } testSettings[] = {{1 << 30, false}, {16 << 20, true}, {0, false}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + try { + TestScopedSpillInjection scopedSpillInjection(100); + AssertQueryBuilder(plan) + .spillDirectory(spillDirectory->getPath()) + .queryCtx(queryCtx) + .config(core::QueryConfig::kSpillEnabled, true) + .config(core::QueryConfig::kJoinSpillEnabled, true) + .config(core::QueryConfig::kMaxSpillBytes, testData.maxSpilledBytes) + .copyResults(pool_.get()); + ASSERT_FALSE(testData.expectedExceedLimit); + } catch (const VeloxRuntimeError& e) { + ASSERT_TRUE(testData.expectedExceedLimit); + ASSERT_NE( + e.message().find( + "Query exceeded per-query local spill limit of 16.00MB"), + std::string::npos); + ASSERT_EQ( + e.errorCode(), facebook::velox::error_code::kSpillLimitExceeded); + } + } +} + +TEST_P(HashJoinTest, onlyHashBuildMaxSpillBytes) { + const auto rowType = + ROW({"c0", "c1", "c2"}, {INTEGER(), INTEGER(), VARCHAR()}); + const auto probeVectors = createVectors(rowType, 32, 128); + const auto buildVectors = createVectors(rowType, 1024, 10 << 20); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, true) + .hashJoin( + {"c0"}, + {"u1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, true) + .project({"c0 AS u0", "c1 AS u1", "c2 AS u2"}) + .planNode(), + "", + {"c0", "c1", "c2"}, + core::JoinType::kInner) + .planNode(); + + auto spillDirectory = TempDirectoryPath::create(); + auto queryCtx = core::QueryCtx::create(executor_.get()); + + struct { + int32_t maxSpilledBytes; + bool expectedExceedLimit; + std::string debugString() const { + return fmt::format("maxSpilledBytes {}", maxSpilledBytes); + } + } testSettings[] = {{1 << 30, false}, {16 << 20, true}, {0, false}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + try { + TestScopedSpillInjection scopedSpillInjection(100); + AssertQueryBuilder(plan) + .spillDirectory(spillDirectory->getPath()) + .queryCtx(queryCtx) + .config(core::QueryConfig::kSpillEnabled, true) + .config(core::QueryConfig::kJoinSpillEnabled, true) + .config(core::QueryConfig::kMaxSpillBytes, testData.maxSpilledBytes) + .copyResults(pool_.get()); + ASSERT_FALSE(testData.expectedExceedLimit); + } catch (const VeloxRuntimeError& e) { + ASSERT_TRUE(testData.expectedExceedLimit); + ASSERT_NE( + e.message().find( + "Query exceeded per-query local spill limit of 16.00MB"), + std::string::npos); + ASSERT_EQ( + e.errorCode(), facebook::velox::error_code::kSpillLimitExceeded); + } + } +} + +TEST_P(HashJoinTest, reclaimFromJoinBuilderWithMultiDrivers) { + auto rowType = ROW({ + {"c0", INTEGER()}, + {"c1", INTEGER()}, + {"c2", VARCHAR()}, + }); + const auto vectors = createVectors(rowType, 64 << 20, fuzzerOpts_); + const int numDrivers = 4; + + memory::MemoryManager::Options options; + options.allocatorCapacity = 8L << 30; + auto memoryManagerWithoutArbitrator = + std::make_unique(options); + const auto expectedResult = + runHashJoinTask( + vectors, + newQueryCtx( + memoryManagerWithoutArbitrator.get(), executor_.get(), 8L << 30), + false, + numDrivers, + pool(), + false) + .data; + + auto memoryManagerWithArbitrator = createMemoryManager(); + const auto& arbitrator = memoryManagerWithArbitrator->arbitrator(); + // Create a query ctx with a small capacity to trigger spilling. + auto result = runHashJoinTask( + vectors, + newQueryCtx( + memoryManagerWithArbitrator.get(), executor_.get(), 128 << 20), + false, + numDrivers, + pool(), + true, + expectedResult); + auto taskStats = exec::toPlanStats(result.task->taskStats()); + auto& planStats = taskStats.at(result.planNodeId); + ASSERT_GT(planStats.spilledBytes, 0); + result.task.reset(); + + // This test uses on-demand created memory manager instead of the global + // one. We need to make sure any used memory got cleaned up before exiting + // the scope + waitForAllTasksToBeDeleted(); + ASSERT_GT(arbitrator->stats().numRequests, 0); + ASSERT_GT(arbitrator->stats().reclaimedUsedBytes, 0); +} + +TEST_P(HashJoinTest, semiJoinAbandonBuildNoDupHashEarly) { + auto probeVectors = makeBatches(3, [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector({1, 2, 2, 3, 3, 3, 4, 5, 5, 6, 7}), + makeFlatVector({10, 20, 21, 30, 31, 32, 40, 50, 51, 60, 70}), + }); + }); + + auto buildVectors = makeBatches(3, [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector({1, 1, 3, 4, 5, 5, 7, 8}), + makeFlatVector({100, 101, 300, 400, 500, 501, 700, 800}), + }); + }); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors) + .project({"c0 AS t0", "c1 AS t1"}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors) + .project({"c0 AS u0", "c1 AS u1"}) + .planNode(), + "", + {"t0", "t1", "match"}, + core::JoinType::kLeftSemiProject) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .config(core::QueryConfig::kAbandonDedupHashMapMinRows, "1") + .config(core::QueryConfig::kAbandonDedupHashMapMinPct, "10") + .planNode(plan) + .referenceQuery( + "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0) FROM t") + .run(); +} + +TEST_P(HashJoinTest, antiJoinAbandonBuildNoDupHashEarly) { + auto probeVectors = makeBatches(64, [&](int32_t /*unused*/) { + return makeRowVector( + {"t0", "t1"}, + { + makeNullableFlatVector({std::nullopt, 1, 2, 3, 4, 5, 6}), + makeFlatVector({0, 1, 2, 3, 4, 5, 6}), + }); + }); + auto buildVectors = makeBatches(64, [&](int32_t /*unused*/) { + return makeRowVector( + {"u0", "u1"}, + { + makeNullableFlatVector({std::nullopt, 2, 3, 4, 6, 7, 8}), + makeFlatVector({0, 2, 3, 4, 6, 7, 8}), + }); + }); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .config(core::QueryConfig::kAbandonDedupHashMapMinRows, "1") + .config(core::QueryConfig::kAbandonDedupHashMapMinPct, "10") + .numDrivers(numDrivers_) + .probeKeys({"t0"}) + .probeVectors(std::vector(probeVectors)) + .buildKeys({"u0"}) + .buildVectors(std::vector(buildVectors)) + .joinType(core::JoinType::kAnti) + .joinOutputLayout({"t0", "t1"}) + .referenceQuery( + "SELECT t.* FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE u.u0 = t.t0)") + .run(); +} + +TEST_P(HashJoinTest, semiJoinDeduplicateResetCapacity) { + const int32_t vectorSize = 10; + const int32_t numBatches = 210; + VectorFuzzer fuzzer({.vectorSize = vectorSize}, pool()); + + // Row type with double and int64_t columns + // Join Key is double -> VectorHasher::typeKindSupportsValueIds will + // return false -> HashMode is kHash + auto rowType = ROW({"c0", "c1"}, {DOUBLE(), BIGINT()}); + + auto probeVectors = makeBatches( + numBatches, [&](int32_t /*unused*/) { return fuzzer.fuzzRow(rowType); }); + + auto buildVectors = makeBatches( + numBatches, [&](int32_t /*unused*/) { return fuzzer.fuzzRow(rowType); }); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors) + .project({"c0 AS t0", "c1 AS t1"}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors) + .project({"c0 AS u0", "c1 AS u1"}) + .planNode(), + "", + {"t0", "t1", "match"}, + core::JoinType::kLeftSemiProject) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .config(core::QueryConfig::kAbandonDedupHashMapMinRows, "10") + .config(core::QueryConfig::kAbandonDedupHashMapMinPct, "50") + .numDrivers(1) + .checkSpillStats(false) + .planNode(plan) + .referenceQuery( + "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0) FROM t") + .run(); +} + +DEBUG_ONLY_TEST_P( + HashJoinTest, + failedToReclaimFromHashJoinBuildersInNonReclaimableSection) { + auto rowType = ROW({ + {"c0", INTEGER()}, + {"c1", INTEGER()}, + {"c2", VARCHAR()}, + }); + const auto vectors = createVectors(rowType, 64 << 20, fuzzerOpts_); + const int numDrivers = 1; + std::shared_ptr queryCtx = + newQueryCtx(memory::memoryManager(), executor_.get(), 512 << 20); + const auto expectedResult = + runHashJoinTask(vectors, queryCtx, false, numDrivers, pool(), false).data; + + std::atomic_bool nonReclaimableSectionWaitFlag{true}; + std::atomic_bool reclaimerInitializationWaitFlag{true}; + folly::EventCount nonReclaimableSectionWait; + std::atomic_bool memoryArbitrationWaitFlag{true}; + folly::EventCount memoryArbitrationWait; + + std::atomic numInitializedDrivers{0}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal", + std::function([&](exec::Driver* driver) { + numInitializedDrivers++; + // We need to make sure reclaimers on both build and probe side are set + // (in Operator::initialize) to avoid race conditions, producing + // consistent test results. + if (numInitializedDrivers.load() == 2) { + reclaimerInitializationWaitFlag = false; + nonReclaimableSectionWait.notifyAll(); + } + })); + + std::atomic injectNonReclaimableSectionOnce{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::common::memory::MemoryPoolImpl::allocateNonContiguous", + std::function( + ([&](memory::MemoryPoolImpl* pool) { + if (!isHashBuildMemoryPool(*pool)) { + return; + } + if (!injectNonReclaimableSectionOnce.exchange(false)) { + return; + } + + // Signal the test control that one of the hash build operator has + // entered into non-reclaimable section. + nonReclaimableSectionWaitFlag = false; + nonReclaimableSectionWait.notifyAll(); + + // Suspend the driver to simulate the arbitration. + pool->reclaimer()->enterArbitration(); + // Wait for the memory arbitration to complete. + memoryArbitrationWait.await( + [&]() { return !memoryArbitrationWaitFlag.load(); }); + pool->reclaimer()->leaveArbitration(); + }))); + + std::thread joinThread([&]() { + const auto result = runHashJoinTask( + vectors, queryCtx, false, numDrivers, pool(), true, expectedResult); + auto taskStats = exec::toPlanStats(result.task->taskStats()); + auto& planStats = taskStats.at(result.planNodeId); + ASSERT_EQ(planStats.spilledBytes, 0); + }); + + // Wait for the hash build operators to enter into non-reclaimable section. + nonReclaimableSectionWait.await([&]() { + return ( + !nonReclaimableSectionWaitFlag.load() && + !reclaimerInitializationWaitFlag.load()); + }); + + // We expect capacity grow fails as we can't reclaim from hash join operators. + memory::testingRunArbitration(); + + // Notify the hash build operator that memory arbitration has been done. + memoryArbitrationWaitFlag = false; + memoryArbitrationWait.notifyAll(); + + joinThread.join(); + + // This test uses on-demand created memory manager instead of the global + // one. We need to make sure any used memory got cleaned up before exiting + // the scope + waitForAllTasksToBeDeleted(); + ASSERT_EQ( + memory::memoryManager()->arbitrator()->stats().numNonReclaimableAttempts, + 2); +} + +DEBUG_ONLY_TEST_P(HashJoinTest, reclaimDuringTableBuild) { + VectorFuzzer fuzzer({.vectorSize = 1000}, pool()); + const int32_t numBuildVectors = 5; + std::vector buildVectors; + for (int32_t i = 0; i < numBuildVectors; ++i) { + buildVectors.push_back(fuzzer.fuzzRow(buildType_)); + } + const int32_t numProbeVectors = 5; + std::vector probeVectors; + for (int32_t i = 0; i < numProbeVectors; ++i) { + probeVectors.push_back(fuzzer.fuzzRow(probeType_)); + } + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + core::PlanNodeId probeScanId; + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, false) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, false) + .planNode(), + "", + concat(probeType_->names(), buildType_->names())) + .planNode(); + + std::atomic_bool injectSpillOnce{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::HashBuild::finishHashBuild", + std::function([&](Operator* op) { + if (!injectSpillOnce.exchange(false)) { + return; + } + Operator::ReclaimableSectionGuard guard(op); + testingRunArbitration(op->pool()); + })); + + auto tempDirectory = TempDirectoryPath::create(); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(4) + .planNode(plan) + .injectSpill(false) + .maxSpillLevel(0) + .spillDirectory(tempDirectory->getPath()) + .referenceQuery( + "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") + .config(core::QueryConfig::kSpillStartPartitionBit, "29") + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + auto opStats = toOperatorStats(task->taskStats()); + ASSERT_GT( + opStats.at("HashBuild") + .runtimeStats[std::string(Operator::kSpillWrites)] + .sum, + 0); + }) + .run(); +} + +DEBUG_ONLY_TEST_P(HashJoinTest, exceptionDuringFinishJoinBuild) { + // This test is to make sure there is no memory leak when exceptions are + // thrown while parallelly preparing join table. + auto memoryManager = memory::memoryManager(); + const auto& arbitrator = memoryManager->arbitrator(); + const uint64_t numDrivers = 2; + const auto expectedFreeCapacityBytes = arbitrator->stats().freeCapacityBytes; + + const uint64_t numBuildSideRows = 500; + auto buildKeyVector = makeFlatVector( + numBuildSideRows, + [](vector_size_t row) { return folly::Random::rand64(); }); + auto buildSideVector = + makeRowVector({"b0", "b1"}, {buildKeyVector, buildKeyVector}); + std::vector buildSideVectors; + for (int i = 0; i < numDrivers; ++i) { + buildSideVectors.push_back(buildSideVector); + } + createDuckDbTable("build", buildSideVectors); + + const uint64_t numProbeSideRows = 10; + auto probeKeyVector = makeFlatVector( + numProbeSideRows, + [&](vector_size_t row) { return buildKeyVector->valueAt(row); }); + auto probeSideVector = + makeRowVector({"p0", "p1"}, {probeKeyVector, probeKeyVector}); + std::vector probeSideVectors; + for (int i = 0; i < numDrivers; ++i) { + probeSideVectors.push_back(probeSideVector); + } + createDuckDbTable("probe", probeSideVectors); + + ASSERT_EQ(arbitrator->stats().freeCapacityBytes, expectedFreeCapacityBytes); + + // We set the task to fail right before we reserve memory for other operators. + // We rely on the driver suspension before parallel join build to throw + // exceptions (suspension on an already terminated task throws). + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::HashBuild::ensureTableFits", + std::function([&](HashBuild* buildOp) { + try { + VELOX_FAIL("Simulated failure"); + } catch (VeloxException&) { + buildOp->operatorCtx()->task()->setError(std::current_exception()); + } + })); + + std::vector probeInput = {probeSideVector}; + std::vector buildInput = {buildSideVector}; + auto planNodeIdGenerator = std::make_shared(); + const auto spillDirectory = TempDirectoryPath::create(); + + ASSERT_EQ(arbitrator->stats().freeCapacityBytes, expectedFreeCapacityBytes); + VELOX_ASSERT_THROW( + AssertQueryBuilder(duckDbQueryRunner_) + .spillDirectory(spillDirectory->getPath()) + .config(core::QueryConfig::kSpillEnabled, true) + .config(core::QueryConfig::kJoinSpillEnabled, true) + .queryCtx( + newQueryCtx(memoryManager, executor_.get(), kMemoryCapacity)) + .maxDrivers(numDrivers) + .plan(PlanBuilder(planNodeIdGenerator) + .values(probeInput, true) + .hashJoin( + {"p0"}, + {"b0"}, + PlanBuilder(planNodeIdGenerator) + .values(buildInput, true) + .planNode(), + "", + {"p0", "p1", "b0", "b1"}, + core::JoinType::kInner) + .planNode()) + .assertResults( + "SELECT probe.p0, probe.p1, build.b0, build.b1 FROM probe " + "INNER JOIN build ON probe.p0 = build.b0"), + "Simulated failure"); + // This test uses on-demand created memory manager instead of the global + // one. We need to make sure any used memory got cleaned up before exiting + // the scope + waitForAllTasksToBeDeleted(); + ASSERT_EQ(arbitrator->stats().freeCapacityBytes, expectedFreeCapacityBytes); +} + +DEBUG_ONLY_TEST_P(HashJoinTest, arbitrationTriggeredDuringParallelJoinBuild) { + std::unique_ptr memoryManager = createMemoryManager(); + const uint64_t numDrivers = 2; + + // Large build side key product to bump hash mode to kHash instead of kArray + // to trigger parallel join build. + const uint64_t numBuildSideRows = 500; + auto buildKeyVector = makeFlatVector( + numBuildSideRows, + [](vector_size_t row) { return folly::Random::rand64(); }); + auto buildSideVector = makeRowVector( + {"b0", "b1", "b2"}, {buildKeyVector, buildKeyVector, buildKeyVector}); + std::vector buildSideVectors; + for (int i = 0; i < numDrivers; ++i) { + buildSideVectors.push_back(buildSideVector); + } + createDuckDbTable("build", buildSideVectors); + + const uint64_t numProbeSideRows = 10; + auto probeKeyVector = makeFlatVector( + numProbeSideRows, + [&](vector_size_t row) { return buildKeyVector->valueAt(row); }); + auto probeSideVector = makeRowVector( + {"p0", "p1", "p2"}, {probeKeyVector, probeKeyVector, probeKeyVector}); + std::vector probeSideVectors; + for (int i = 0; i < numDrivers; ++i) { + probeSideVectors.push_back(probeSideVector); + } + createDuckDbTable("probe", probeSideVectors); + + std::shared_ptr joinQueryCtx = + newQueryCtx(memoryManager.get(), executor_.get(), kMemoryCapacity); + + const int64_t allocSize = 512LL << 20; + std::atomic parallelBuildTriggered{false}; + std::atomic joinBuildPool{nullptr}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::HashTable::parallelJoinBuild", + std::function([&](memory::MemoryPool* pool) { + parallelBuildTriggered = true; + // Pick the last running driver threads' pool for later memory + // allocation. This pick is rather arbitrary, as it is un-important + // which pool is going to be allocated from later in a parallel join's + // off-driver thread. + joinBuildPool = pool; + })); + + std::atomic_bool offThreadAllocationTriggered{false}; + folly::EventCount asyncMoveWait; + std::atomic asyncMoveWaitFlag{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::AsyncSource::prepare", + std::function([&](void* /* unused */) { + if (!offThreadAllocationTriggered.exchange(true)) { + SCOPE_EXIT { + asyncMoveWaitFlag = false; + asyncMoveWait.notifyAll(); + }; + // Executed by the first thread hitting the test value location. This + // allocation will trigger arbitration and fail. + VELOX_ASSERT_THROW( + joinBuildPool.load()->allocate(allocSize), + "Exceeded memory pool cap"); + } + })); + + // Wait for allocation (hence arbitration) on the prepare thread to finish + // before calling AsyncSource::move(). This is to ensure no other AsyncSource + // (hence arbitration) is running on the driver thread (on-thread) before the + // ongoing arbitration finishes. Without ensuring this, the on-thread + // arbitration (triggered by calling AsyncSource::move() first) has + // thread-local driver context by default, defying the purpose of this test. + SCOPED_TESTVALUE_SET( + "facebook::velox::AsyncSource::move", + std::function([&](void* /* unused */) { + asyncMoveWait.await([&]() { return !asyncMoveWaitFlag.load(); }); + })); + + std::vector probeInput = {probeSideVector}; + std::vector buildInput = {buildSideVector}; + auto planNodeIdGenerator = std::make_shared(); + const auto spillDirectory = TempDirectoryPath::create(); + AssertQueryBuilder(duckDbQueryRunner_) + .spillDirectory(spillDirectory->getPath()) + .config(core::QueryConfig::kSpillEnabled, true) + .config(core::QueryConfig::kJoinSpillEnabled, true) + // Set very low table size threshold to trigger parallel build. + .config(core::QueryConfig::kMinTableRowsForParallelJoinBuild, 0) + // Set multiple hash build drivers to trigger parallel build. + .maxDrivers(numDrivers) + .queryCtx(joinQueryCtx) + .plan(PlanBuilder(planNodeIdGenerator) + .values(probeInput, true) + .hashJoin( + {"p0", "p1", "p2"}, + {"b0", "b1", "b2"}, + PlanBuilder(planNodeIdGenerator) + .values(buildInput, true) + .planNode(), + "", + {"p0", "p1", "b0", "b1"}, + core::JoinType::kInner) + .planNode()) + .assertResults( + "SELECT probe.p0, probe.p1, build.b0, build.b1 FROM probe " + "INNER JOIN build ON probe.p0 = build.b0 AND probe.p1 = build.b1 AND " + "probe.p2 = build.b2"); + ASSERT_TRUE(parallelBuildTriggered); + + // This test uses on-demand created memory manager instead of the global + // one. We need to make sure any used memory got cleaned up before exiting + // the scope + waitForAllTasksToBeDeleted(); +} + +DEBUG_ONLY_TEST_P(HashJoinTest, arbitrationTriggeredByEnsureJoinTableFit) { + // Use manual spill injection other than spill injection framework. This is + // because spill injection framework does not allow fine grain spill within a + // single operator (We do not want to spill during addInput() but only during + // finishHashBuild()). + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::HashBuild::ensureTableFits", + std::function(([&](Operator* op) { + Operator::ReclaimableSectionGuard guard(op); + memory::testingRunArbitration(op->pool()); + }))); + auto tempDirectory = TempDirectoryPath::create(); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .injectSpill(false) + .spillDirectory(tempDirectory->getPath()) + .keyTypes({BIGINT()}) + .probeVectors(1600, 5) + .buildVectors(1500, 5) + .referenceQuery( + "SELECT t_k0, t_data, u_k0, u_data FROM t, u WHERE t.t_k0 = u.u_k0") + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + const auto statsPair = taskSpilledStats(*task); + ASSERT_GT(statsPair.first.spilledBytes, 0); + }) + .run(); +} + +DEBUG_ONLY_TEST_P(HashJoinTest, joinBuildSpillError) { + const int kMemoryCapacity = 27 << 20; + // Set a small memory capacity to trigger spill. + std::unique_ptr memoryManager = + createMemoryManager(kMemoryCapacity, 0); + const auto& arbitrator = memoryManager->arbitrator(); + auto rowType = ROW( + {{"c0", INTEGER()}, + {"c1", INTEGER()}, + {"c2", VARCHAR()}, + {"c3", VARCHAR()}}); + + std::vector vectors = createVectors(16, rowType, fuzzerOpts_); + createDuckDbTable(vectors); + + std::shared_ptr joinQueryCtx = + newQueryCtx(memoryManager.get(), executor_.get(), kMemoryCapacity); + + const int numDrivers = 4; + std::atomic numAppends{0}; + const std::string injectedErrorMsg("injected spillError"); + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::SpillState::appendToPartition", + std::function([&](exec::SpillState* state) { + if (++numAppends != numDrivers) { + return; + } + VELOX_FAIL(injectedErrorMsg); + })); + + auto planNodeIdGenerator = std::make_shared(); + const auto spillDirectory = TempDirectoryPath::create(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(vectors) + .project({"c0 AS t0", "c1 AS t1", "c2 AS t2"}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values(vectors) + .project({"c0 AS u0", "c1 AS u1", "c2 AS u2"}) + .planNode(), + "", + {"t1"}, + core::JoinType::kAnti) + .planNode(); + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .queryCtx(joinQueryCtx) + .spillDirectory(spillDirectory->getPath()) + .config(core::QueryConfig::kSpillEnabled, true) + .copyResults(pool()), + injectedErrorMsg); + + waitForAllTasksToBeDeleted(); + ASSERT_EQ(arbitrator->stats().numFailures, 1); + + // Wait again here as this test uses on-demand created memory manager instead + // of the global one. We need to make sure any used memory got cleaned up + // before exiting the scope + waitForAllTasksToBeDeleted(); +} + +DEBUG_ONLY_TEST_P(HashJoinTest, probeSpillOnWaitForPeers) { + // This test creates a scenario when tester probe thread finishes processing + // input, entering kWaitForPeers state, and the other thread is still + // processing, spill is triggered properly performed. + + folly::EventCount startWait; + folly::Synchronized testerOpName; + std::atomic_bool injectedSpillOnce{false}; + + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::getOutput", + std::function([&](Operator* op) { + if (!isHashProbeMemoryPool(*op->pool())) { + return; + } + testerOpName.withWLock([&](std::string& opName) { + if (opName.empty()) { + opName = op->pool()->name(); + } + }); + if (op->pool()->name() == *testerOpName.rlock()) { + // Do not block tester thread. + return; + } + startWait.await([&]() { return injectedSpillOnce.load(); }); + })); + + // tester probe operator is guaranteed to be in kWaitForPeers state the next + // isBlocked() is called after noMoreInput() is called. + std::atomic_bool noMoreInputCalled{false}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::noMoreInput", + std::function([&](Operator* op) { + if (!isHashProbeMemoryPool(*op->pool())) { + return; + } + noMoreInputCalled = true; + })); + + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::isBlocked", + std::function([&](Operator* op) { + if (!isHashProbeMemoryPool(*op->pool())) { + return; + } + if (injectedSpillOnce || !noMoreInputCalled) { + return; + } + injectedSpillOnce = true; + EXPECT_EQ( + dynamic_cast(op)->testingState(), + ProbeOperatorState::kWaitForPeers); + testingRunArbitration(op->pool()); + })); + + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Task::requestPauseLocked", + std::function([&](Task* task) { startWait.notifyAll(); })); + + const uint64_t numDrivers{2}; + std::shared_ptr joinQueryCtx = + newQueryCtx(memory::memoryManager(), executor_.get(), kMemoryCapacity); + auto rowType = ROW({{"c0", INTEGER()}, {"c1", INTEGER()}}); + fuzzerOpts_.vectorSize = 20; + std::vector vectors = createVectors(6, rowType, fuzzerOpts_); + std::vector totalVectors; + for (auto i = 0; i < numDrivers; ++i) { + totalVectors.insert(totalVectors.end(), vectors.begin(), vectors.end()); + } + createDuckDbTable(totalVectors); + auto spillDirectory = TempDirectoryPath::create(); + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(vectors, true) + .project({"c0 AS t0", "c1 AS t1"}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values(vectors, true) + .project({"c0 AS u0", "c1 AS u1"}) + .planNode(), + "", + {"t1"}, + core::JoinType::kInner) + .planNode(); + + { + auto task = + AssertQueryBuilder(duckDbQueryRunner_) + .plan(plan) + .queryCtx(joinQueryCtx) + .spillDirectory(spillDirectory->getPath()) + .config(core::QueryConfig::kSpillEnabled, true) + .maxDrivers(numDrivers) + .assertResults("SELECT a.c1 from tmp a join tmp b on a.c0 = b.c0"); + + auto opStats = toOperatorStats(task->taskStats()); + ASSERT_GT(opStats.at("HashProbe").spilledBytes, 0); + ASSERT_EQ(opStats.at("HashBuild").spilledBytes, 0); + + const auto* arbitrator = memory::memoryManager()->arbitrator(); + ASSERT_GT(arbitrator->stats().reclaimedUsedBytes, 0); + } + waitForAllTasksToBeDeleted(); +} + +DEBUG_ONLY_TEST_P(HashJoinTest, taskWaitTimeout) { + const int queryMemoryCapacity = 128 << 20; + // Creates a large number of vectors based on the query capacity to trigger + // memory arbitration. + fuzzerOpts_.vectorSize = 10'000; + auto rowType = ROW( + {{"c0", INTEGER()}, + {"c1", INTEGER()}, + {"c2", VARCHAR()}, + {"c3", VARCHAR()}}); + const auto vectors = + createVectors(rowType, queryMemoryCapacity / 2, fuzzerOpts_); + const int numDrivers = 4; + const auto expectedResult = + runHashJoinTask(vectors, nullptr, false, numDrivers, pool(), false).data; + + for (uint64_t timeoutMs : {1'000, 30'000}) { + SCOPED_TRACE(fmt::format("timeout {}", succinctMillis(timeoutMs))); + auto memoryManager = createMemoryManager(512 << 20, 0, timeoutMs); + auto queryCtx = + newQueryCtx(memoryManager.get(), executor_.get(), queryMemoryCapacity); + + // Set test injection to block one hash build operator to inject delay when + // memory reclaim waits for task to pause. + folly::EventCount buildBlockWait; + std::atomic buildBlockWaitFlag{true}; + std::atomic blockOneBuild{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::common::memory::MemoryPoolImpl::maybeReserve", + std::function([&](memory::MemoryPool* pool) { + const std::string re(".*HashBuild"); + if (!RE2::FullMatch(pool->name(), re)) { + return; + } + if (!blockOneBuild.exchange(false)) { + return; + } + buildBlockWait.await([&]() { return !buildBlockWaitFlag.load(); }); + })); + + folly::EventCount taskPauseWait; + std::atomic taskPauseWaitFlag{false}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Task::requestPauseLocked", + std::function(([&](Task* /*unused*/) { + taskPauseWaitFlag = true; + taskPauseWait.notifyAll(); + }))); + + std::thread queryThread([&]() { + // We expect failure on short time out. + if (timeoutMs == 1'000) { + VELOX_ASSERT_THROW( + runHashJoinTask( + vectors, + queryCtx, + false, + numDrivers, + pool(), + true, + expectedResult), + "Memory reclaim failed to wait"); + } else { + // With 30s timeout, we expect either: + // 1. Success with spilling (operators become reclaimable in time), OR + // 2. kMemCapExceeded if operators remain in non-reclaimable state + try { + const auto result = runHashJoinTask( + vectors, + queryCtx, + false, + numDrivers, + pool(), + true, + expectedResult); + auto taskStats = exec::toPlanStats(result.task->taskStats()); + auto& planStats = taskStats.at(result.planNodeId); + ASSERT_GT(planStats.spilledBytes, 0); + } catch (const VeloxRuntimeError& e) { + ASSERT_EQ(e.errorCode(), error_code::kMemCapExceeded) + << "Unexpected error: " << e.what(); + } + } + }); + + // Wait for task pause to reach, and then delay for a while before unblock + // the blocked hash build operator. + taskPauseWait.await([&]() { return taskPauseWaitFlag.load(); }); + // Wait for two seconds and expect the short reclaim wait timeout. + std::this_thread::sleep_for(std::chrono::seconds(2)); + // Unblock the blocked build operator to let memory reclaim proceed. + buildBlockWaitFlag = false; + buildBlockWait.notifyAll(); + + queryThread.join(); + + // This test uses on-demand created memory manager instead of the global + // one. We need to make sure any used memory got cleaned up before exiting + // the scope + waitForAllTasksToBeDeleted(); + } +} + +DEBUG_ONLY_TEST_P(HashJoinTest, hashProbeSpill) { + struct { + bool triggerBuildSpill; + // Triggers after no more input or not. + bool afterNoMoreInput; + // The index of get output call to trigger probe side spilling. + int probeOutputIndex; + + std::string debugString() const { + return fmt::format( + "triggerBuildSpill: {}, afterNoMoreInput: {}, probeOutputIndex: {}", + triggerBuildSpill, + afterNoMoreInput, + probeOutputIndex); + } + } testSettings[] = { + {false, false, 0}, + {false, false, 1}, + {false, false, 10}, + {false, true, 0}, + {true, false, 0}, + {true, false, 1}, + {true, false, 10}, + {true, true, 0}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + std::atomic_bool injectBuildSpillOnce{true}; + std::atomic_int buildInputCount{0}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::addInput", + std::function([&](Operator* op) { + if (!testData.triggerBuildSpill) { + return; + } + if (!isHashBuildMemoryPool(*op->pool())) { + return; + } + if (buildInputCount++ != 1) { + return; + } + if (!injectBuildSpillOnce.exchange(false)) { + return; + } + testingRunArbitration(op->pool()); + })); + + std::atomic_bool injectProbeSpillOnce{true}; + std::atomic_int probeOutputCount{0}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::getOutput", + std::function([&](Operator* op) { + if (!isHashProbeMemoryPool(*op->pool())) { + return; + } + if (testData.afterNoMoreInput) { + if (!op->testingNoMoreInput()) { + return; + } + } else { + if (probeOutputCount++ != testData.probeOutputIndex) { + return; + } + } + if (!injectProbeSpillOnce.exchange(false)) { + return; + } + testingRunArbitration(op->pool()); + })); + + fuzzerOpts_.vectorSize = 128; + auto probeVectors = createVectors(10, probeType_, fuzzerOpts_); + auto buildVectors = createVectors(20, buildType_, fuzzerOpts_); + const auto spillDirectory = TempDirectoryPath::create(); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(1) + .spillDirectory(spillDirectory->getPath()) + .probeKeys({"t_k1"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u_k1"}) + .buildVectors(std::move(buildVectors)) + .config(core::QueryConfig::kJoinSpillEnabled, "true") + .joinType(core::JoinType::kRight) + .joinOutputLayout({"t_k1", "t_k2", "u_k1", "t_v1"}) + .referenceQuery( + "SELECT t.t_k1, t.t_k2, u.u_k1, t.t_v1 FROM t RIGHT JOIN u ON t.t_k1 = u.u_k1") + .injectSpill(false) + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + auto opStats = toOperatorStats(task->taskStats()); + if (!parallelBuildSideRowsEnabled_) { + ASSERT_GT(opStats.at("HashProbe").spilledBytes, 0); + } + if (testData.triggerBuildSpill) { + ASSERT_GT(opStats.at("HashBuild").spilledBytes, 0); + } else { + ASSERT_EQ(opStats.at("HashBuild").spilledBytes, 0); + } + + const auto* arbitrator = memory::memoryManager()->arbitrator(); + ASSERT_GT(arbitrator->stats().numRequests, 0); + ASSERT_GT(arbitrator->stats().reclaimedUsedBytes, 0); + }) + .run(); + } +} + +DEBUG_ONLY_TEST_P(HashJoinTest, hashProbeSpillInMiddleOfLastOutputProcessing) { + std::atomic_int outputCountAfterNoMoreInout{0}; + std::atomic_bool injectOnce{true}; + ::facebook::velox::common::testutil::ScopedTestValue abc( + "facebook::velox::exec::Driver::runInternal::getOutput", + std::function([&](Operator* op) { + if (!isHashProbeMemoryPool(*op->pool())) { + return; + } + if (!op->testingNoMoreInput()) { + return; + } + if (outputCountAfterNoMoreInout++ != 1) { + return; + } + if (!injectOnce.exchange(false)) { + return; + } + testingRunArbitration(op->pool()); + })); + + fuzzerOpts_.vectorSize = 128; + auto probeVectors = createVectors(10, probeType_, fuzzerOpts_); + auto buildVectors = createVectors(20, buildType_, fuzzerOpts_); + + const auto spillDirectory = TempDirectoryPath::create(); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(1) + .spillDirectory(spillDirectory->getPath()) + .probeKeys({"t_k1"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u_k1"}) + .buildVectors(std::move(buildVectors)) + .config(core::QueryConfig::kJoinSpillEnabled, "true") + .config(core::QueryConfig::kPreferredOutputBatchRows, std::to_string(10)) + .joinType(core::JoinType::kRight) + .joinOutputLayout({"t_k1", "t_k2", "u_k1", "t_v1"}) + .referenceQuery( + "SELECT t.t_k1, t.t_k2, u.u_k1, t.t_v1 FROM t RIGHT JOIN u ON t.t_k1 = u.u_k1") + .injectSpill(false) + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + auto opStats = toOperatorStats(task->taskStats()); + ASSERT_GT(opStats.at("HashProbe").spilledBytes, 0); + // Verifies that we only spill the output which is single partitioned + // but not the hash table. + ASSERT_EQ(opStats.at("HashProbe").spilledPartitions, 1); + }) + .run(); +} + +// Inject probe-side spilling in the middle of output processing. If +// 'recursiveSpill' is true, we trigger probe-spilling when probe the hash table +// built from spilled data. +DEBUG_ONLY_TEST_P(HashJoinTest, hashProbeSpillInMiddleOfOutputProcessing) { + for (bool recursiveSpill : {false, true}) { + std::atomic_int buildInputCount{0}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::addInput", + std::function([&](Operator* op) { + if (!isHashBuildMemoryPool(*op->pool())) { + return; + } + if (!recursiveSpill) { + return; + } + // Trigger spill after the build side has processed some rows. + if (buildInputCount++ != 1) { + return; + } + testingRunArbitration(op->pool()); + })); + + std::atomic_bool injectProbeSpillOnce{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::getOutput", + std::function([&](Operator* op) { + if (!isHashProbeMemoryPool(*op->pool())) { + return; + } + + if (op->testingHasInput()) { + return; + } + if (recursiveSpill) { + if (static_cast(op)->testingHasInputSpiller()) { + return; + } + } + if (!injectProbeSpillOnce.exchange(false)) { + return; + } + testingRunArbitration(op->pool()); + })); + + fuzzerOpts_.vectorSize = 128; + auto probeVectors = createVectors(10, probeType_, fuzzerOpts_); + auto buildVectors = createVectors(20, buildType_, fuzzerOpts_); + + const auto spillDirectory = TempDirectoryPath::create(); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(1) + .spillDirectory(spillDirectory->getPath()) + .probeKeys({"t_k1"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u_k1"}) + .buildVectors(std::move(buildVectors)) + .config(core::QueryConfig::kJoinSpillEnabled, "true") + .config( + core::QueryConfig::kPreferredOutputBatchRows, std::to_string(10)) + .joinType(core::JoinType::kRight) + .joinOutputLayout({"t_k1", "t_k2", "u_k1", "t_v1"}) + .referenceQuery( + "SELECT t.t_k1, t.t_k2, u.u_k1, t.t_v1 FROM t RIGHT JOIN u ON t.t_k1 = u.u_k1") + .injectSpill(false) + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + auto opStats = toOperatorStats(task->taskStats()); + ASSERT_GT(opStats.at("HashProbe").spilledBytes, 0); + ASSERT_GT(opStats.at("HashProbe").spilledPartitions, 1); + }) + .run(); + } +} + +DEBUG_ONLY_TEST_P(HashJoinTest, hashProbeSpillWhenOneOfProbeFinish) { + const int numDrivers{3}; + + std::atomic_bool probeWaitFlag{true}; + folly::EventCount probeWait; + std::atomic_int numBlockedProbeOps{0}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::getOutput", + std::function([&](Operator* op) { + if (!isHashProbeMemoryPool(*op->pool())) { + return; + } + if (++numBlockedProbeOps <= numDrivers - 1) { + probeWait.await([&]() { return !probeWaitFlag.load(); }); + return; + } + })); + + std::atomic_bool notifyOnce{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::noMoreInput", + std::function([&](Operator* op) { + if (!isHashProbeMemoryPool(*op->pool())) { + return; + } + if (!notifyOnce.exchange(false)) { + return; + } + probeWaitFlag = false; + probeWait.notifyAll(); + })); + + std::thread queryThread([&]() { + const auto spillDirectory = TempDirectoryPath::create(); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers, true, true) + .spillDirectory(spillDirectory->getPath()) + .keyTypes({BIGINT()}) + .probeVectors(32, 5) + .buildVectors(32, 5) + .config(core::QueryConfig::kJoinSpillEnabled, "true") + .referenceQuery( + "SELECT t_k0, t_data, u_k0, u_data FROM t, u WHERE t.t_k0 = u.u_k0") + .injectSpill(false) + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + auto opStats = toOperatorStats(task->taskStats()); + ASSERT_EQ(opStats.at("HashBuild").spilledBytes, 0); + ASSERT_GT(opStats.at("HashProbe").spilledBytes, 0); + }) + .run(); + }); + // Wait until one of the hash probe operator has finished. + probeWait.await([&]() { return !probeWaitFlag.load(); }); + memory::testingRunArbitration(); + queryThread.join(); +} + +DEBUG_ONLY_TEST_P(HashJoinTest, hashProbeSpillExceedLimit) { + // If 'buildTriggerSpill' is true, then spilling is triggered by hash build. + for (const bool buildTriggerSpill : {false, true}) { + SCOPED_TRACE(fmt::format("buildTriggerSpill {}", buildTriggerSpill)); + + SCOPED_TESTVALUE_SET( + "facebook::velox::common::memory::MemoryPoolImpl::maybeReserve", + std::function([&](memory::MemoryPool* pool) { + if (buildTriggerSpill && !isHashBuildMemoryPool(*pool)) { + return; + } + if (!buildTriggerSpill && !isHashProbeMemoryPool(*pool)) { + return; + } + testingRunArbitration(pool); + })); + + fuzzerOpts_.vectorSize = 128; + auto probeVectors = createVectors(32, probeType_, fuzzerOpts_); + auto buildVectors = createVectors(64, buildType_, fuzzerOpts_); + for (int i = 0; i < probeVectors.size(); ++i) { + const auto probeKeyChannel = probeType_->getChildIdx("t_k1"); + const auto buildKeyChannle = buildType_->getChildIdx("u_k1"); + probeVectors[i]->childAt(probeKeyChannel) = + buildVectors[i]->childAt(buildKeyChannle); + } + + const auto spillDirectory = TempDirectoryPath::create(); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(1) + .spillDirectory(spillDirectory->getPath()) + .probeKeys({"t_k1"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u_k1"}) + .buildVectors(std::move(buildVectors)) + .config(core::QueryConfig::kMaxSpillLevel, "1") + .config(core::QueryConfig::kSpillNumPartitionBits, "1") + .config(core::QueryConfig::kJoinSpillEnabled, "true") + // Set small write buffer size to have small vectors to read from + // spilled data. + .config(core::QueryConfig::kSpillWriteBufferSize, "1") + .config( + core::QueryConfig::kPreferredOutputBatchRows, std::to_string(10)) + .joinType(core::JoinType::kRight) + .joinOutputLayout({"t_k1", "t_k2", "u_k1", "t_v1"}) + .referenceQuery( + "SELECT t.t_k1, t.t_k2, u.u_k1, t.t_v1 FROM t RIGHT JOIN u ON t.t_k1 = u.u_k1") + .injectSpill(false) + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + auto opStats = toOperatorStats(task->taskStats()); + if (buildTriggerSpill) { + ASSERT_GT(opStats.at("HashProbe").spilledBytes, 0); + ASSERT_GT(opStats.at("HashBuild").spilledBytes, 0); + } else { + ASSERT_GT(opStats.at("HashProbe").spilledBytes, 0); + ASSERT_EQ(opStats.at("HashBuild").spilledBytes, 0); + } + ASSERT_GT( + opStats.at("HashProbe") + .runtimeStats[std::string(Operator::kExceededMaxSpillLevel)] + .sum, + 0); + ASSERT_GT( + opStats.at("HashBuild") + .runtimeStats[std::string(Operator::kExceededMaxSpillLevel)] + .sum, + 0); + }) + .run(); + } +} + +DEBUG_ONLY_TEST_P(HashJoinTest, hashProbeSpillUnderNonReclaimableSection) { + std::atomic_bool injectOnce{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::common::memory::MemoryPoolImpl::allocateNonContiguous", + std::function([&](memory::MemoryPool* pool) { + if (!isHashProbeMemoryPool(*pool)) { + return; + } + if (!injectOnce.exchange(false)) { + return; + } + auto* arbitrator = memory::memoryManager()->arbitrator(); + const auto numNonReclaimableAttempts = + arbitrator->stats().numNonReclaimableAttempts; + testingRunArbitration(pool); + // Verifies that we run into non-reclaimable section when reclaim from + // hash probe. + ASSERT_EQ( + arbitrator->stats().numNonReclaimableAttempts, + numNonReclaimableAttempts + 1); + })); + + const auto spillDirectory = TempDirectoryPath::create(); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(1) + .spillDirectory(spillDirectory->getPath()) + .keyTypes({BIGINT()}) + .probeVectors(32, 5) + .buildVectors(32, 5) + .config(core::QueryConfig::kJoinSpillEnabled, "true") + .referenceQuery( + "SELECT t_k0, t_data, u_k0, u_data FROM t, u WHERE t.t_k0 = u.u_k0") + .injectSpill(false) + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + auto opStats = toOperatorStats(task->taskStats()); + ASSERT_EQ(opStats.at("HashProbe").spilledBytes, 0); + ASSERT_EQ(opStats.at("HashBuild").spilledBytes, 0); + }) + .run(); +} + +// This test case is to cover the case that hash probe trigger spill for right +// semi join types and the pending input needs to be processed in multiple +// steps. +DEBUG_ONLY_TEST_P(HashJoinTest, spillOutputWithRightSemiJoins) { + for (const auto joinType : + {core::JoinType::kRightSemiFilter, core::JoinType::kRightSemiProject}) { + std::atomic_bool injectOnce{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::getOutput", + std::function([&](Operator* op) { + if (op->operatorCtx()->operatorType() != "HashProbe") { + return; + } + if (!op->testingHasInput()) { + return; + } + if (!injectOnce.exchange(false)) { + return; + } + testingRunArbitration(op->pool()); + })); + + std::string duckDbSqlReference; + std::vector joinOutputLayout; + bool nullAware{false}; + if (joinType == core::JoinType::kRightSemiProject) { + duckDbSqlReference = "SELECT u_k2, u_k1 IN (SELECT t_k1 FROM t) FROM u"; + joinOutputLayout = {"u_k2", "match"}; + // Null aware is only supported for semi projection join type. + nullAware = true; + } else { + duckDbSqlReference = + "SELECT u_k2 FROM u WHERE u_k1 IN (SELECT t_k1 FROM t)"; + joinOutputLayout = {"u_k2"}; + } + + const auto spillDirectory = TempDirectoryPath::create(); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(1) + .spillDirectory(spillDirectory->getPath()) + .probeType(probeType_) + .probeVectors(128, 3) + .probeKeys({"t_k1"}) + .buildType(buildType_) + .buildVectors(128, 4) + .buildKeys({"u_k1"}) + .joinType(joinType) + // Set a small number of output rows to process the input in multiple + // steps. + .config( + core::QueryConfig::kPreferredOutputBatchRows, std::to_string(10)) + .injectSpill(false) + .joinOutputLayout(std::move(joinOutputLayout)) + .nullAware(nullAware) + .referenceQuery(duckDbSqlReference) + .run(); + } +} + +DEBUG_ONLY_TEST_P(HashJoinTest, spillCheckOnLeftSemiFilterWithDynamicFilters) { + const int32_t numSplits = 10; + const int32_t numRowsProbe = 333; + const int32_t numRowsBuild = 100; + + std::vector probeVectors; + probeVectors.reserve(numSplits); + + std::vector> tempFiles; + for (int32_t i = 0; i < numSplits; ++i) { + auto rowVector = makeRowVector({ + makeFlatVector( + numRowsProbe, [&](auto row) { return row - i * 10; }), + makeFlatVector(numRowsProbe, [](auto row) { return row; }), + }); + probeVectors.push_back(rowVector); + tempFiles.push_back(TempFilePath::create()); + writeToFile(tempFiles.back()->getPath(), rowVector); + } + auto makeInputSplits = [&](const core::PlanNodeId& nodeId) { + return [&] { + std::vector probeSplits; + for (auto& file : tempFiles) { + probeSplits.push_back( + exec::Split(makeHiveConnectorSplit(file->getPath()))); + } + SplitInput splits; + splits.emplace(nodeId, probeSplits); + return splits; + }; + }; + + // 100 key values in [35, 233] range. + std::vector buildVectors; + for (int i = 0; i < 5; ++i) { + buildVectors.push_back(makeRowVector({ + makeFlatVector( + numRowsBuild / 5, + [i](auto row) { return 35 + 2 * (row + i * numRowsBuild / 5); }), + makeFlatVector(numRowsBuild / 5, [](auto row) { return row; }), + })); + } + std::vector keyOnlyBuildVectors; + for (int i = 0; i < 5; ++i) { + keyOnlyBuildVectors.push_back( + makeRowVector({makeFlatVector(numRowsBuild / 5, [i](auto row) { + return 35 + 2 * (row + i * numRowsBuild / 5); + })})); + } + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto probeType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + + auto planNodeIdGenerator = std::make_shared(); + + auto buildSide = PlanBuilder(planNodeIdGenerator, pool_.get()) + .values(buildVectors) + .project({"c0 AS u_c0", "c1 AS u_c1"}) + .planNode(); + + // Left semi join. + core::PlanNodeId probeScanId; + core::PlanNodeId joinNodeId; + const auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) + .tableScan(probeType) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"c0"}, + {"u_c0"}, + buildSide, + "", + {"c0", "c1"}, + core::JoinType::kLeftSemiFilter) + .capturePlanNodeId(joinNodeId) + .project({"c0", "c1 + 1"}) + .planNode(); + + std::atomic_bool injectOnce{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::getOutput", + std::function([&](Operator* op) { + if (op->operatorCtx()->operatorType() != "HashProbe") { + return; + } + if (!op->testingHasInput()) { + return; + } + if (!injectOnce.exchange(false)) { + return; + } + testingRunArbitration(op->pool()); + })); + + auto spillDirectory = TempDirectoryPath::create(); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(std::move(op)) + .makeInputSplits(makeInputSplits(probeScanId)) + .spillDirectory(spillDirectory->getPath()) + .injectSpill(false) + .referenceQuery( + "SELECT t.c0, t.c1 + 1 FROM t WHERE t.c0 IN (SELECT c0 FROM u)") + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + // Verify spill hasn't triggered. + auto taskStats = exec::toPlanStats(task->taskStats()); + auto& planStats = taskStats.at(joinNodeId); + ASSERT_GT(planStats.spilledBytes, 0); + }) + .run(); +} + +// This test is to verify there is no memory reservation made before hash probe +// start processing. This can cause unnecessary spill and query OOM under some +// real workload with many stages as each hash probe might reserve non-trivial +// amount of memory. +DEBUG_ONLY_TEST_P( + HashJoinTest, + hashProbeMemoryReservationCheckBeforeProbeStartWithSpillEnabled) { + fuzzerOpts_.vectorSize = 128; + auto probeVectors = createVectors(10, probeType_, fuzzerOpts_); + auto buildVectors = createVectors(20, buildType_, fuzzerOpts_); + + std::atomic_bool checkOnce{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::addInput", + std::function(([&](Operator* op) { + if (op->operatorType() != "HashProbe") { + return; + } + if (!checkOnce.exchange(false)) { + return; + } + ASSERT_EQ(op->pool()->usedBytes(), 0); + ASSERT_EQ(op->pool()->reservedBytes(), 0); + }))); + + const auto spillDirectory = TempDirectoryPath::create(); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(1) + .spillDirectory(spillDirectory->getPath()) + .probeKeys({"t_k1"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u_k1"}) + .buildVectors(std::move(buildVectors)) + .config(core::QueryConfig::kJoinSpillEnabled, "true") + .joinType(core::JoinType::kInner) + .joinOutputLayout({"t_k1", "t_k2", "u_k1", "t_v1"}) + .referenceQuery( + "SELECT t.t_k1, t.t_k2, u.u_k1, t.t_v1 FROM t JOIN u ON t.t_k1 = u.u_k1") + .injectSpill(true) + .verifier([&](const std::shared_ptr& task, bool injectSpill) { + if (!injectSpill) { + return; + } + auto opStats = toOperatorStats(task->taskStats()); + ASSERT_GT(opStats.at("HashProbe").spilledBytes, 0); + ASSERT_GE(opStats.at("HashProbe").spilledPartitions, 1); + }) + .run(); +} + +TEST_P(HashJoinTest, nanKeys) { + // Verify the NaN values with different binary representations are considered + // equal. + static const double kNan = std::numeric_limits::quiet_NaN(); + static const double kSNaN = std::numeric_limits::signaling_NaN(); + auto probeInput = makeRowVector( + {makeFlatVector({kNan, kSNaN}), makeFlatVector({1, 2})}); + auto buildInput = makeRowVector({makeFlatVector({kNan, 1})}); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values({probeInput}) + .project({"c0 AS t0", "c1 AS t1"}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values({buildInput}) + .project({"c0 AS u0"}) + .planNode(), + "", + {"t0", "u0", "t1"}, + core::JoinType::kLeft) + .planNode(); + auto queryCtx = core::QueryCtx::create(executor_.get()); + auto result = + AssertQueryBuilder(plan).queryCtx(queryCtx).copyResults(pool_.get()); + auto expected = makeRowVector( + {makeFlatVector({kNan, kNan}), + makeFlatVector({kNan, kNan}), + makeFlatVector({1, 2})}); + facebook::velox::test::assertEqualVectors(expected, result); +} + +DEBUG_ONLY_TEST_P(HashJoinTest, spillOnBlockedProbe) { + auto blockedOperatorFactoryUniquePtr = + std::make_unique(); + auto blockedOperatorFactory = blockedOperatorFactoryUniquePtr.get(); + Operator::registerOperator(std::move(blockedOperatorFactoryUniquePtr)); + + std::vector unblockPromises; + std::atomic_bool shouldBlock{true}; + blockedOperatorFactory->setBlockedCb([&](ContinueFuture* future) { + if (!shouldBlock) { + return BlockingReason::kNotBlocked; + } + auto [p, f] = makeVeloxContinuePromiseContract("Blocked Operator"); + *future = std::move(f); + unblockPromises.push_back(std::move(p)); + return BlockingReason::kWaitForConsumer; + }); + + folly::EventCount arbitrationWait; + std::atomic arbitrationWaitFlag{true}; + ::facebook::velox::common::testutil::ScopedTestValue _scopedTestValue15( + "facebook::velox::exec::HashBuild::finishHashBuild", + std::function([&](Operator* /* unused */) { + arbitrationWaitFlag = false; + arbitrationWait.notifyAll(); + })); + std::thread arbitrationThread([&]() { + arbitrationWait.await([&]() { return !arbitrationWaitFlag.load(); }); + memory::memoryManager()->shrinkPools(); + shouldBlock = false; + for (auto& unblockPromise : unblockPromises) { + unblockPromise.setValue(); + } + }); + + auto rowType = ROW({{"c0", INTEGER()}, {"c1", INTEGER()}}); + std::vector vectors = createVectors(1, rowType, fuzzerOpts_); + createDuckDbTable(vectors); + auto planNodeIdGenerator = std::make_shared(); + const auto spillDirectory = TempDirectoryPath::create(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(vectors) + .project({"c0 AS t0", "c1 AS t1"}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values(vectors) + .project({"c0 AS u0", "c1 AS u1"}) + .planNode(), + "", + {"t1"}, + core::JoinType::kInner) + .addNode([&](std::string id, core::PlanNodePtr input) { + return std::make_shared(id, input); + }) + .planNode(); + + { + auto task = + AssertQueryBuilder(duckDbQueryRunner_) + .plan(plan) + .queryCtx(newQueryCtx( + memory::memoryManager(), executor_.get(), kMemoryCapacity)) + .spillDirectory(spillDirectory->getPath()) + .config(core::QueryConfig::kSpillEnabled, true) + .maxDrivers(1) + .assertResults("SELECT a.c1 from tmp a join tmp b on a.c0 = b.c0"); + auto joinSpillStats = taskSpilledStats(*task); + auto buildSpillStats = joinSpillStats.first; + ASSERT_GT(buildSpillStats.spilledBytes, 0); + } + arbitrationThread.join(); + waitForAllTasksToBeDeleted(30'000'000); + Operator::unregisterAllOperators(); +} + +DEBUG_ONLY_TEST_P(HashJoinTest, buildReclaimedMemoryReport) { + constexpr int64_t kMaxBytes = 1LL << 30; // 1GB + const int32_t numBuildVectors = 3; + std::vector buildVectors; + for (int32_t i = 0; i < numBuildVectors; ++i) { + VectorFuzzer fuzzer({.vectorSize = 200}, pool()); + buildVectors.push_back(fuzzer.fuzzRow(buildType_)); + } + + const int32_t numProbeVectors = 3; + std::vector probeVectors; + for (int32_t i = 0; i < numProbeVectors; ++i) { + VectorFuzzer fuzzer({.vectorSize = 200}, pool()); + probeVectors.push_back(fuzzer.fuzzRow(probeType_)); + } + + const int numDrivers{2}; + // duckdb need double probe and build inputs as we run two drivers for hash + // join. + std::vector totalProbeVectors = probeVectors; + totalProbeVectors.insert( + totalProbeVectors.end(), probeVectors.begin(), probeVectors.end()); + std::vector totalBuildVectors = buildVectors; + totalBuildVectors.insert( + totalBuildVectors.end(), buildVectors.begin(), buildVectors.end()); + + createDuckDbTable("t", totalProbeVectors); + createDuckDbTable("u", totalBuildVectors); + + auto tempDirectory = TempDirectoryPath::create(); + auto queryPool = memory::memoryManager()->addRootPool( + "", kMaxBytes, memory::MemoryReclaimer::create()); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, true) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, true) + .planNode(), + "", + concat(probeType_->names(), buildType_->names())) + .planNode(); + + folly::EventCount driverWait; + std::atomic_bool driverWaitFlag{true}; + folly::EventCount taskWait; + std::atomic_bool taskWaitFlag{true}; + + Operator* op{nullptr}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::HashBuild::finishHashBuild", + std::function(([&](Operator* testOp) { op = testOp; }))); + + std::atomic_bool injectOnce{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::common::memory::MemoryPoolImpl::maybeReserve", + std::function( + ([&](memory::MemoryPoolImpl* pool) { + if (op == nullptr || op->pool() != pool) { + return; + } + ASSERT_TRUE(isHashBuildMemoryPool(*pool)); + ASSERT_TRUE(op->canReclaim()); + ASSERT_GT(op->pool()->usedBytes(), 0); + ASSERT_GT( + op->pool()->parent()->reservedBytes(), + op->pool()->reservedBytes()); + if (!injectOnce.exchange(false)) { + return; + } + uint64_t reclaimableBytes{0}; + const bool reclaimable = op->reclaimableBytes(reclaimableBytes); + ASSERT_TRUE(reclaimable); + ASSERT_GT(reclaimableBytes, 0); + auto* driver = op->operatorCtx()->driver(); + TestSuspendedSection suspendedSection(driver); + taskWaitFlag = false; + taskWait.notifyAll(); + driverWait.await([&]() { return !driverWaitFlag.load(); }); + }))); + + std::thread taskThread([&]() { + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers) + .planNode(plan) + .queryPool(std::move(queryPool)) + .injectSpill(false) + .spillDirectory(tempDirectory->getPath()) + .referenceQuery( + "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") + .config(core::QueryConfig::kSpillStartPartitionBit, "29") + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + const auto statsPair = taskSpilledStats(*task); + ASSERT_GT(statsPair.first.spilledBytes, 0); + ASSERT_EQ(statsPair.first.spilledPartitions, 16); + ASSERT_GT(statsPair.second.spilledBytes, 0); + ASSERT_EQ(statsPair.second.spilledPartitions, 16); + verifyTaskSpilledRuntimeStats(*task, true); + }) + .run(); + }); + + taskWait.await([&]() { return !taskWaitFlag.load(); }); + ASSERT_TRUE(op != nullptr); + auto task = op->operatorCtx()->task(); + auto* nodePool = op->pool()->parent(); + const auto nodeMemoryUsage = nodePool->reservedBytes(); + { + memory::ScopedMemoryArbitrationContext ctx(op->pool()); + const uint64_t reclaimedBytes = task->pool()->reclaim( + task->pool()->capacity(), 1'000'000, reclaimerStats_); + ASSERT_GT(reclaimedBytes, 0); + ASSERT_EQ(nodeMemoryUsage - nodePool->reservedBytes(), reclaimedBytes); + } + // Verify all the memory has been freed. + ASSERT_EQ(nodePool->reservedBytes(), 0); + + driverWaitFlag = false; + driverWait.notifyAll(); + task.reset(); + + taskThread.join(); +} + +DEBUG_ONLY_TEST_P(HashJoinTest, probeReclaimedMemoryReport) { + constexpr int64_t kMaxBytes = 1LL << 30; // 1GB + const int32_t numBuildVectors = 3; + std::vector buildVectors; + for (int32_t i = 0; i < numBuildVectors; ++i) { + VectorFuzzer fuzzer({.vectorSize = 200}, pool()); + buildVectors.push_back(fuzzer.fuzzRow(buildType_)); + } + + const int32_t numProbeVectors = 3; + std::vector probeVectors; + for (int32_t i = 0; i < numProbeVectors; ++i) { + VectorFuzzer fuzzer({.vectorSize = 200}, pool()); + probeVectors.push_back(fuzzer.fuzzRow(probeType_)); + } + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto tempDirectory = TempDirectoryPath::create(); + auto queryPool = memory::memoryManager()->addRootPool( + "", kMaxBytes, memory::MemoryReclaimer::create()); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, true) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, true) + .planNode(), + "", + concat(probeType_->names(), buildType_->names())) + .planNode(); + + folly::EventCount driverWait; + std::atomic_bool driverWaitFlag{true}; + folly::EventCount taskWait; + std::atomic_bool taskWaitFlag{true}; + + Operator* op{nullptr}; + std::atomic_int probeInputCount{0}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::addInput", + std::function(([&](Operator* testOp) { + if (testOp->operatorType() != "HashProbe") { + return; + } + op = testOp; + + ASSERT_TRUE(op->canReclaim()); + if (probeInputCount++ != 1) { + return; + } + auto* driver = op->operatorCtx()->driver(); + TestSuspendedSection suspendedSection(driver); + taskWaitFlag = false; + taskWait.notifyAll(); + driverWait.await([&]() { return !driverWaitFlag.load(); }); + }))); + + std::thread taskThread([&]() { + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(1) + .planNode(plan) + .queryPool(std::move(queryPool)) + .injectSpill(false) + .spillDirectory(tempDirectory->getPath()) + .referenceQuery( + "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") + .config(core::QueryConfig::kSpillStartPartitionBit, "29") + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + const auto statsPair = taskSpilledStats(*task); + // The spill triggered at the probe side. + ASSERT_EQ(statsPair.first.spilledBytes, 0); + ASSERT_EQ(statsPair.first.spilledPartitions, 0); + ASSERT_GT(statsPair.second.spilledBytes, 0); + ASSERT_EQ(statsPair.second.spilledPartitions, 16); + }) + .run(); + }); + + taskWait.await([&]() { return !taskWaitFlag.load(); }); + ASSERT_TRUE(op != nullptr); + auto task = op->operatorCtx()->task(); + auto* nodePool = op->pool()->parent(); + const auto nodeMemoryUsage = nodePool->reservedBytes(); + { + memory::ScopedMemoryArbitrationContext ctx(op->pool()); + const uint64_t reclaimedBytes = task->pool()->reclaim( + task->pool()->capacity(), 1'000'000, reclaimerStats_); + ASSERT_GT(reclaimedBytes, 0); + ASSERT_EQ(nodeMemoryUsage - nodePool->reservedBytes(), reclaimedBytes); + } + // Verify all the memory has been freed, except for the ones for hash lookup. + ASSERT_EQ(nodePool->reservedBytes(), 1048576); + + driverWaitFlag = false; + driverWait.notifyAll(); + task.reset(); + + taskThread.join(); +} + +DEBUG_ONLY_TEST_P(HashJoinTest, hashTableCleanupAfterProbeFinish) { + auto buildVectors = makeVectors(buildType_, 5, 100); + auto probeVectors = makeVectors(probeType_, 5, 100); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + HashProbe* probeOp{nullptr}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::getOutput", + std::function([&](Operator* op) { + if (probeOp == nullptr && op->operatorType() == "HashProbe") { + probeOp = dynamic_cast(op); + } + })); + + bool tableEmpty{false}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::noMoreInput", + std::function([&](Operator* op) { + if (op->operatorType() == "FilterProject") { + tableEmpty = (probeOp->testingTable()->numDistinct() == 0); + } + })); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, true) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, true) + .planNode(), + "", + concat(probeType_->names(), buildType_->names())) + .project({"t_k1", "t_k2", "t_v1", "u_k1", "u_k2", "u_v1"}) + .planNode(); + + auto tempDirectory = TempDirectoryPath::create(); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(1) + .planNode(plan) + .injectSpill(false) + .spillDirectory(tempDirectory->getPath()) + .referenceQuery( + "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") + .config(core::QueryConfig::kSpillStartPartitionBit, "29") + .run(); + ASSERT_TRUE(tableEmpty); +} + +TEST_P(HashJoinTest, innerJoinForTypeWithCustomComparisonAndSmallVector) { + // This test corresponds to the SQL query: + // SELECT + // LEFT_TABLE.ip_addr as ip_left_string, + // RIGHT_TABLE.ip_addr as ip_right_string, + // CAST(LEFT_TABLE.ip_addr AS IPADDRESS) as ip_left_ip_address_type, + // CAST(RIGHT_TABLE.ip_addr AS IPADDRESS) as ip_right_ip_address_type, + // CAST(LEFT_TABLE.ip_addr AS IPADDRESS) = CAST(RIGHT_TABLE.ip_addr AS + // IPADDRESS) as are_equal_as_ip_address_type + // FROM + // (VALUES ('2620:10d:c0a8:f0::37'), ('2620:10d:c053:33::37')) AS + // LEFT_TABLE(ip_addr) INNER JOIN (VALUES ('2620:10d:c0a8:f0::37')) AS + // RIGHT_TABLE(ip_addr) ON CAST(LEFT_TABLE.ip_addr AS IPADDRESS) = + // CAST(RIGHT_TABLE.ip_addr AS IPADDRESS) + // LIMIT 1000 + + auto leftVectors = makeRowVector({makeFlatVector( + {StringView("2620:10d:c0a8:f0::37"), + StringView("2620:10d:c053:33::37")})}); + + auto rightVectors = makeRowVector( + {makeFlatVector({StringView("2620:10d:c0a8:f0::37")})}); + createDuckDbTable("t", {leftVectors}); + createDuckDbTable("u", {rightVectors}); + + auto planNodeIdGenerator = std::make_shared(); + + auto rightPlan = PlanBuilder(planNodeIdGenerator) + .values({rightVectors}) + .project( + {"c0 AS ip_addr_right", + "CAST(c0 AS IPADDRESS) AS ip_addr_cast_right"}) + .planNode(); + + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({leftVectors}) + .project({"c0 AS ip_addr", "CAST(c0 AS IPADDRESS) AS ip_addr_cast"}) + .hashJoin( + {"ip_addr_cast"}, + {"ip_addr_cast_right"}, + rightPlan, + "", + {"ip_addr", "ip_addr_cast"}, + core::JoinType::kInner) + .limit(0, 1000, false) + .planNode(); + + auto result = AssertQueryBuilder(plan).copyResults(pool()); + + ASSERT_EQ(result->size(), 1); + auto ipAddr = result->childAt(0)->as>(); + + // We expect 1 row (only the matching IPv6 address: 2620:10d:c0a8:f0::37) + ASSERT_EQ(ipAddr->valueAt(0), StringView("2620:10d:c0a8:f0::37")); + + // Test that different IPADDRESS values correctly don't match in hash join. + leftVectors = makeRowVector({ + makeFlatVector({ + "2620:10d:c053:33::37"_sv, + }), + }); + + rightVectors = makeRowVector({ + makeFlatVector({ + "2620:10d:c0a8:f0::37"_sv, + }), + }); + + planNodeIdGenerator = std::make_shared(); + + rightPlan = PlanBuilder(planNodeIdGenerator) + .values({rightVectors}) + .project( + {"c0 AS ip_addr_right", + "CAST(c0 AS IPADDRESS) AS ip_addr_cast_right"}) + .planNode(); + + plan = PlanBuilder(planNodeIdGenerator) + .values({leftVectors}) + .project( + {"c0 AS ip_left", + "CAST(c0 AS IPADDRESS) AS ip_left_cast", + "CAST(c0 AS VARCHAR) AS ip_left_string"}) + .hashJoin( + {"ip_left_cast"}, + {"ip_addr_cast_right"}, + rightPlan, + "", + {"ip_left_cast", "ip_addr_cast_right", "ip_addr_right"}, + core::JoinType::kInner) + .planNode(); + + // Result should be empty since the IP addresses are different + result = AssertQueryBuilder(plan).copyResults(pool()); + ASSERT_EQ(result->size(), 0) + << "Expected no matches between different IP addresses, but got " + << result->size() << " rows"; +} + +/// Test hash join where build-side keys have a type that supports custom +/// comparison and come from a small range which would allow for array-based +/// lookup instead of a hash table for other types. +TEST_P(HashJoinTest, arrayBasedLookupCustomComparisonType) { + std::vector probeVectors = { + makeRowVector({makeFlatVector( + 1'024, + [](auto row) { return row; }, + nullptr, + velox::test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON())})}; + + std::vector buildVectors = { + makeRowVector({makeFlatVector( + 256, + [](auto row) { return row; }, + nullptr, + velox::test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON())})}; + + auto planNodeIdGenerator = std::make_shared(); + + auto rightPlan = PlanBuilder(planNodeIdGenerator) + .values({buildVectors}) + .project({"c0 as right"}) + .planNode(); + + auto plan = PlanBuilder(planNodeIdGenerator) + .values({probeVectors}) + .project({"c0 as left"}) + .hashJoin( + {"left"}, + {"right"}, + rightPlan, + "", + {"left"}, + core::JoinType::kInner) + .planNode(); + + auto result = AssertQueryBuilder(plan).copyResults(pool()); + + // The probe side consists of the values 0-1023, the build side consists of + // the values 0-255. If custom comparison is not respected, the join will + // produce 256 values (0-255). When custom comparison is respected equality is + // treated mod 256 so we get 1024 values (0-1023). + EXPECT_EQ(result->size(), 1'024); +} + +DEBUG_ONLY_TEST_P( + HashJoinTest, + hashProbeShouldYieldWhenFilterConsistentlyRejectAll) { + const uint32_t kProbeSize = 100; + const uint32_t kBuildSize = 10'000; + const uint64_t kDriverCpuTimeSliceLimitMs = 1'000; + const std::string kLargeBatchSize = + folly::to(kProbeSize * kBuildSize); + + struct { + uint32_t numGetOutputCalls; + bool hasDelay; + std::string debugString() const { + return fmt::format( + "numGetOutputCalls: {}, hasDelay: {}", numGetOutputCalls, hasDelay); + } + } testSettings[] = {{0, false}, {0, true}}; + + // Create probe data with keys 0-99 and an additional filter column + const auto probeData = makeRowVector( + {"t_k1", "t_filter"}, + { + makeFlatVector(kProbeSize, [](auto row) { return row; }), + makeFlatVector( + kProbeSize, + [](/*row=*/auto) { return 1; }), // All rows have value 1 + }); + + const auto buildData = makeRowVector( + {"u_k1"}, + { + makeFlatVector(kBuildSize, [](auto row) { return row; }), + }); + + createDuckDbTable("t", {probeData}); + createDuckDbTable("u", {buildData}); + + auto planNodeIdGenerator = std::make_shared(); + auto planNode = + PlanBuilder(planNodeIdGenerator) + .values({probeData}) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator).values({buildData}).planNode(), + // Filter that DOES find join matches but then rejects all of them + // This ensures numOut > 0 after listJoinResults, but == 0 after + // evalFilter All probe rows have t_filter=1, so the condition + // t_filter > 100000 rejects all + "t_filter > 100000", + {"t_k1", "u_k1"}, + core::JoinType::kInner) + .planNode(); + + for (auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + std::atomic_int hashProbeGetOutputCalls{0}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::getOutput", + std::function([&](Operator* op) { + if (op->operatorType() == "HashProbe") { + // Inject delay on the 2nd getOutput call when hasDelay is true + // This simulates the scenario where: + // 1. First getOutput: Probe data added via addInput + // 2. Second getOutput: Join finds matches, filter rejects all + // During this call, we inject delay INSIDE the processing + // to simulate CPU-intensive work in the loop + if (hashProbeGetOutputCalls.fetch_add(1) == 1 && + testData.hasDelay) { + std::this_thread::sleep_for( + std::chrono::milliseconds(2 * kDriverCpuTimeSliceLimitMs)); + } + } + })); + + auto queryCtx = core::QueryCtx::create( + executor_.get(), + core::QueryConfig({ + {core::QueryConfig::kDriverCpuTimeSliceLimitMs, + folly::to(kDriverCpuTimeSliceLimitMs)}, + {core::QueryConfig::kPreferredOutputBatchRows, kLargeBatchSize}, + })); + + AssertQueryBuilder(planNode, duckDbQueryRunner_) + .queryCtx(queryCtx) + .maxDrivers(1) + .assertResults( + "SELECT t_k1, u_k1 FROM t, u WHERE t_k1 = u_k1 AND t_filter > 100000"); + testData.numGetOutputCalls = hashProbeGetOutputCalls.load(); + } + ASSERT_LT( + testSettings[0].numGetOutputCalls, testSettings[1].numGetOutputCalls); +} + +// This test validates that when spillOutput() is running (toSpillOutput=true), +// the operator should NOT yield even when shouldYield() returns true. This is +// critical because yielding during spillOutput would break the spilling loop. +DEBUG_ONLY_TEST_P( + HashJoinTest, + spillOutputShouldNotYieldWhenFilterConsistentlyRejectAll) { + const uint32_t kProbeSize = 100; + const uint32_t kBuildSize = 10'000; + const uint64_t driverCpuTimeSliceLimitMs = 1'000; + const std::string largeBatchSize = + folly::to(kProbeSize * kBuildSize); + + // Create probe data with keys 0-99 and an additional filter column + const auto probeData = makeRowVector( + {"t_k1", "t_filter"}, + { + makeFlatVector(kProbeSize, [](auto row) { return row; }), + makeFlatVector( + kProbeSize, + [](/*row=*/auto) { return 1; }), // All rows have value 1 + }); + + const auto buildData = makeRowVector( + {"u_k1"}, + { + makeFlatVector(kBuildSize, [](auto row) { return row; }), + }); + + createDuckDbTable("t", {probeData}); + createDuckDbTable("u", {buildData}); + + auto planNodeIdGenerator = std::make_shared(); + auto planNode = + PlanBuilder(planNodeIdGenerator) + .values({probeData}) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator).values({buildData}).planNode(), + // Filter that DOES find join matches but then rejects all of them + // This ensures numOut > 0 after listJoinResults, but == 0 after + // evalFilter. All probe rows have t_filter=1, so the condition + // t_filter > 100000 rejects all + "t_filter > 100000", + {"t_k1", "u_k1"}, + core::JoinType::kInner) + .planNode(); + + std::atomic_bool spillTriggered{false}; + ::facebook ::velox ::common ::testutil ::ScopedTestValue _scopedTestValue5200( + "facebook::velox::exec::Driver::runInternal::getOutput", + std::function([&](Operator* op) { + if (spillTriggered.load() || op->operatorType() != "HashProbe" || + !op->testingHasInput()) { + return; + } + spillTriggered = true; + testingRunArbitration(op->pool()); + })); + + // We inject delay in reclaim to trigger shouldYield(). + // The test verifies that the query completes successfully despite + // shouldYield() returning true, which would only happen if the + // !toSpillOutput check prevents early return. + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::HashProbe::reclaim", + std::function([&](HashProbe* probe) { + if (!spillTriggered.load()) { + return; + } + // Inject delay once to trigger shouldYield() + std::this_thread::sleep_for( + std::chrono::milliseconds(2 * driverCpuTimeSliceLimitMs)); + })); + + const auto spillDirectory = TempDirectoryPath::create(); + AssertQueryBuilder(planNode, duckDbQueryRunner_) + .queryCtx(core::QueryCtx::create(driverExecutor_.get())) + .maxDrivers(1) + .spillDirectory(spillDirectory->getPath()) + .config(core::QueryConfig::kSpillEnabled, true) + .config(core::QueryConfig::kJoinSpillEnabled, true) + .config(core::QueryConfig::kSpillStartPartitionBit, 29) + .config( + core::QueryConfig::kDriverCpuTimeSliceLimitMs, + driverCpuTimeSliceLimitMs) + .config(core::QueryConfig::kPreferredOutputBatchRows, largeBatchSize) + .assertResults( + "SELECT t_k1, u_k1 FROM t, u WHERE t_k1 = u_k1 AND t_filter > 100000"); + + ASSERT_TRUE(spillTriggered.load()); +} + +TEST_P(MultiThreadedHashJoinTest, antiJoin) { + auto probeVectors = makeBatches(64, [&](int32_t /*unused*/) { + return makeRowVector( + {"t0", "t1"}, + { + makeNullableFlatVector({std::nullopt, 1, 2}), + makeFlatVector({0, 1, 2}), + }); + }); + auto buildVectors = makeBatches(64, [&](int32_t /*unused*/) { + return makeRowVector( + {"u0", "u1"}, + { + makeNullableFlatVector({std::nullopt, 2, 3}), + makeFlatVector({0, 2, 3}), + }); + }); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .probeKeys({"t0"}) + .probeVectors(std::vector(probeVectors)) + .buildKeys({"u0"}) + .buildVectors(std::vector(buildVectors)) + .joinType(core::JoinType::kAnti) + .joinOutputLayout({"t0", "t1"}) + .referenceQuery( + "SELECT t.* FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE u.u0 = t.t0)") + .run(); + + std::vector filters({ + "u1 > t1", + "u1 * t1 > 0", + // This filter is true on rows without a match. It should not prevent + // the row from being returned. + "coalesce(u1, t1, 0::integer) is not null", + // This filter throws if evaluated on rows without a match. The join + // should not evaluate filter on those rows and therefore should not + // fail. + "t1 / coalesce(u1, 0::integer) is not null", + // This filter triggers memory pool allocation at + // HashBuild::setupFilterForAntiJoins, which should not be invoked in + // operator's constructor. + "contains(array[1, 2, NULL], 1)", + }); + for (const std::string& filter : filters) { + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .probeKeys({"t0"}) + .probeVectors(std::vector(probeVectors)) + .buildKeys({"u0"}) + .buildVectors(std::vector(buildVectors)) + .joinType(core::JoinType::kAnti) + .joinFilter(filter) + .joinOutputLayout({"t0", "t1"}) + .referenceQuery( + fmt::format( + "SELECT t.* FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE u.u0 = t.t0 AND {})", + filter)) + .run(); + } +} + +TEST_P(MultiThreadedHashJoinTest, antiJoinWithFilterAndEmptyBuild) { + const std::vector finishOnEmptys = {false, true}; + for (const auto finishOnEmpty : finishOnEmptys) { + SCOPED_TRACE(fmt::format("finishOnEmpty: {}", finishOnEmpty)); + + auto probeVectors = makeBatches(4, [&](int32_t /*unused*/) { + return makeRowVector( + {"t0", "t1"}, + { + makeNullableFlatVector({std::nullopt, 1, 2}), + makeFlatVector({0, 1, 2}), + }); + }); + auto buildVectors = makeBatches(4, [&](int32_t /*unused*/) { + return makeRowVector( + {"u0", "u1"}, + { + makeNullableFlatVector({3, 2, 3}), + makeFlatVector({0, 2, 3}), + }); + }); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .probeKeys({"t0"}) + .probeVectors(std::vector(probeVectors)) + .buildKeys({"u0"}) + .buildVectors(std::vector(buildVectors)) + .buildFilter("u0 < 0") + .joinType(core::JoinType::kAnti) + .joinFilter("u1 > t1") + .joinOutputLayout({"t0", "t1"}) + .referenceQuery( + "SELECT t.* FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE u0 < 0 AND u.u0 = t.t0)") + .checkSpillStats(false) + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + const auto statsPair = taskSpilledStats(*task); + ASSERT_EQ(statsPair.first.spilledRows, 0); + ASSERT_EQ(statsPair.first.spilledBytes, 0); + ASSERT_EQ(statsPair.first.spilledPartitions, 0); + ASSERT_EQ(statsPair.first.spilledFiles, 0); + ASSERT_EQ(statsPair.second.spilledRows, 0); + ASSERT_EQ(statsPair.second.spilledBytes, 0); + ASSERT_EQ(statsPair.second.spilledPartitions, 0); + ASSERT_EQ(statsPair.second.spilledFiles, 0); + verifyTaskSpilledRuntimeStats(*task, false); + ASSERT_EQ(maxHashBuildSpillLevel(*task), -1); + }) + .run(); + } +} + +// Reproduces a debug-mode crash in evalFilter for ANTI join with filter. +// A batch of all matched rows followed by a batch of all non-matching rows +// triggers DictionaryVector::validate on a stale filterResult_ whose shared +// outputRowMapping_ buffer has been overwritten by listJoinResults. +TEST_P(MultiThreadedHashJoinTest, antiJoinWithBooleanKeysAndFilter) { + // 5 build rows with key (true, true). + auto buildVectors = makeBatches(1, [&](int32_t /*unused*/) { + return makeRowVector( + {"u0", "u1"}, + { + makeFlatVector(5, [](auto /*row*/) { return true; }), + makeFlatVector(5, [](auto /*row*/) { return true; }), + }); + }); + + // 10 probe rows: rows 0-1 match (2 * 5 = 10 output entries), + // rows 2-9 don't match (8 miss entries). + auto probeVectors = makeBatches(1, [&](int32_t /*unused*/) { + return makeRowVector( + {"t0", "t1"}, + { + makeFlatVector(10, [](auto row) { return row < 2; }), + makeFlatVector(10, [](auto row) { return row < 2; }), + }); + }); + + // outputBatchSize=10 forces the split: 10 matches in batch 1, + // 8 misses in batch 2. + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(1) + .probeKeys({"t0", "t1"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u0", "u1"}) + .buildVectors(std::move(buildVectors)) + .joinType(core::JoinType::kAnti) + .joinFilter("t1 = true") + .joinOutputLayout({"t0", "t1"}) + .referenceQuery( + "SELECT t.* FROM t WHERE NOT EXISTS " + "(SELECT * FROM u WHERE t.t0 = u.u0 AND t.t1 = u.u1 AND t.t1 = true)") + .config(core::QueryConfig::kPreferredOutputBatchRows, std::to_string(10)) + .checkSpillStats(false) + .run(); +} + +TEST_P(MultiThreadedHashJoinTest, leftJoin) { + // Left side keys are [0, 1, 2,..20]. + // Use 3-rd column as row number to allow for asserting the order of + // results. + std::vector probeVectors = mergeBatches( + makeBatches( + 3, + [&](int32_t /*unused*/) { + return makeRowVector( + {"c0", "c1", "row_number"}, + { + makeFlatVector( + 77, [](auto row) { return row % 21; }, nullEvery(13)), + makeFlatVector(77, [](auto row) { return row; }), + makeFlatVector(77, [](auto row) { return row; }), + }); + }), + makeBatches( + 2, + [&](int32_t /*unused*/) { + return makeRowVector( + {"c0", "c1", "row_number"}, + { + makeFlatVector( + 97, + [](auto row) { return (row + 3) % 21; }, + nullEvery(13)), + makeFlatVector(97, [](auto row) { return row; }), + makeFlatVector( + 97, [](auto row) { return 97 + row; }), + }); + }), + true); + + std::vector buildVectors = + makeBatches(3, [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 73, [](auto row) { return row % 5; }, nullEvery(7)), + makeFlatVector( + 73, [](auto row) { return -111 + row * 2; }, nullEvery(7)), + }); + }); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .probeKeys({"c0"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u_c0"}) + .buildVectors(std::move(buildVectors)) + .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) + .joinType(core::JoinType::kLeft) + .joinOutputLayout({"row_number", "c0", "c1", "u_c0"}) + .referenceQuery( + "SELECT t.row_number, t.c0, t.c1, u.c0 FROM t LEFT JOIN u ON t.c0 = u.c0") + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + int nullJoinBuildKeyCount = 0; + int nullJoinProbeKeyCount = 0; + + for (auto& pipeline : task->taskStats().pipelineStats) { + for (auto op : pipeline.operatorStats) { + if (op.operatorType == OperatorType::kHashBuild) { + nullJoinBuildKeyCount += op.numNullKeys; + } + if (op.operatorType == OperatorType::kHashProbe) { + nullJoinProbeKeyCount += op.numNullKeys; + } + } + } + ASSERT_EQ(nullJoinBuildKeyCount, 33 * GetParam().numDrivers); + ASSERT_EQ(nullJoinProbeKeyCount, 34 * GetParam().numDrivers); + }) + .run(); +} + +TEST_P(MultiThreadedHashJoinTest, nullStatsWithEmptyBuild) { + std::vector probeVectors = + makeBatches(1, [&](int32_t /*unused*/) { + return makeRowVector( + {"c0", "c1", "row_number"}, + { + makeFlatVector( + 77, [](auto row) { return row % 21; }, nullEvery(13)), + makeFlatVector(77, [](auto row) { return row; }), + makeFlatVector(77, [](auto row) { return row; }), + }); + }); + + // All null keys on build side. + std::vector buildVectors = + makeBatches(1, [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 1, [](auto row) { return row % 5; }, nullEvery(1)), + makeFlatVector( + 1, [](auto row) { return -111 + row * 2; }, nullEvery(1)), + }); + }); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .probeKeys({"c0"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u_c0"}) + .buildVectors(std::move(buildVectors)) + .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) + .joinType(core::JoinType::kLeft) + .joinOutputLayout({"row_number", "c0", "c1", "u_c0"}) + .referenceQuery( + "SELECT t.row_number, t.c0, t.c1, u.c0 FROM t LEFT JOIN u ON t.c0 = u.c0") + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + int nullJoinBuildKeyCount = 0; + int nullJoinProbeKeyCount = 0; + + for (auto& pipeline : task->taskStats().pipelineStats) { + for (auto op : pipeline.operatorStats) { + if (op.operatorType == OperatorType::kHashBuild) { + nullJoinBuildKeyCount += op.numNullKeys; + } + if (op.operatorType == OperatorType::kHashProbe) { + nullJoinProbeKeyCount += op.numNullKeys; + } + } + } + // Due to inaccurate stats tracking in case of empty build side, + // we will report 0 null keys on probe side. + ASSERT_EQ(nullJoinProbeKeyCount, 0); + ASSERT_EQ(nullJoinBuildKeyCount, 1 * GetParam().numDrivers); + }) + .checkSpillStats(false) + .run(); +} + +TEST_P(MultiThreadedHashJoinTest, leftJoinWithEmptyBuild) { + const std::vector finishOnEmptys = {false, true}; + for (const auto finishOnEmpty : finishOnEmptys) { + SCOPED_TRACE(fmt::format("finishOnEmpty: {}", finishOnEmpty)); + + // Left side keys are [0, 1, 2,..10]. + // Use 3-rd column as row number to allow for asserting the order of + // results. + std::vector probeVectors = mergeBatches( + makeBatches( + 3, + [&](int32_t /*unused*/) { + return makeRowVector( + {"c0", "c1", "row_number"}, + { + makeFlatVector( + 77, [](auto row) { return row % 11; }, nullEvery(13)), + makeFlatVector(77, [](auto row) { return row; }), + makeFlatVector(77, [](auto row) { return row; }), + }); + }), + makeBatches( + 2, + [&](int32_t /*unused*/) { + return makeRowVector( + {"c0", "c1", "row_number"}, + { + makeFlatVector( + 97, + [](auto row) { return (row + 3) % 11; }, + nullEvery(13)), + makeFlatVector(97, [](auto row) { return row; }), + makeFlatVector( + 97, [](auto row) { return 97 + row; }), + }); + }), + true); + + std::vector buildVectors = + makeBatches(3, [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 73, [](auto row) { return row % 5; }, nullEvery(7)), + makeFlatVector( + 73, [](auto row) { return -111 + row * 2; }, nullEvery(7)), + }); + }); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .probeKeys({"c0"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u_c0"}) + .buildVectors(std::move(buildVectors)) + .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) + .buildFilter("c0 < 0") + .joinType(core::JoinType::kLeft) + .joinOutputLayout({"row_number", "c1"}) + .referenceQuery( + "SELECT t.row_number, t.c1 FROM t LEFT JOIN (SELECT c0 FROM u WHERE c0 < 0) u ON t.c0 = u.c0") + .checkSpillStats(false) + .run(); + } +} + +TEST_P(MultiThreadedHashJoinTest, leftJoinWithNoJoin) { + // Left side keys are [0, 1, 2,..10]. + // Use 3-rd column as row number to allow for asserting the order of + // results. + std::vector probeVectors = mergeBatches( + makeBatches( + 3, + [&](int32_t /*unused*/) { + return makeRowVector( + {"c0", "c1", "row_number"}, + { + makeFlatVector( + 77, [](auto row) { return row % 11; }, nullEvery(13)), + makeFlatVector(77, [](auto row) { return row; }), + makeFlatVector(77, [](auto row) { return row; }), + }); + }), + makeBatches( + 2, + [&](int32_t /*unused*/) { + return makeRowVector( + {"c0", "c1", "row_number"}, + { + makeFlatVector( + 97, + [](auto row) { return (row + 3) % 11; }, + nullEvery(13)), + makeFlatVector(97, [](auto row) { return row; }), + makeFlatVector( + 97, [](auto row) { return 97 + row; }), + }); + }), + true); + + std::vector buildVectors = + makeBatches(3, [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 73, [](auto row) { return row % 5; }, nullEvery(7)), + makeFlatVector( + 73, [](auto row) { return -111 + row * 2; }, nullEvery(7)), + }); + }); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .probeKeys({"c0"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u_c0"}) + .buildVectors(std::move(buildVectors)) + .buildProjections({"c0 - 123::INTEGER AS u_c0", "c1 AS u_c1"}) + .joinType(core::JoinType::kLeft) + .joinOutputLayout({"row_number", "c0", "u_c1"}) + .referenceQuery( + "SELECT t.row_number, t.c0, u.c1 FROM t LEFT JOIN (SELECT c0 - 123::INTEGER AS u_c0, c1 FROM u) u ON t.c0 = u.u_c0") + .run(); +} + +TEST_P(MultiThreadedHashJoinTest, leftJoinWithAllMatch) { + // Left side keys are [0, 1, 2,..10]. + // Use 3-rd column as row number to allow for asserting the order of + // results. + std::vector probeVectors = mergeBatches( + makeBatches( + 3, + [&](int32_t /*unused*/) { + return makeRowVector( + {"c0", "c1", "row_number"}, + { + makeFlatVector( + 77, [](auto row) { return row % 11; }, nullEvery(13)), + makeFlatVector(77, [](auto row) { return row; }), + makeFlatVector(77, [](auto row) { return row; }), + }); + }), + makeBatches( + 2, + [&](int32_t /*unused*/) { + return makeRowVector( + {"c0", "c1", "row_number"}, + { + makeFlatVector( + 97, + [](auto row) { return (row + 3) % 11; }, + nullEvery(13)), + makeFlatVector(97, [](auto row) { return row; }), + makeFlatVector( + 97, [](auto row) { return 97 + row; }), + }); + }), + true); + + std::vector buildVectors = + makeBatches(3, [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 73, [](auto row) { return row % 5; }, nullEvery(7)), + makeFlatVector( + 73, [](auto row) { return -111 + row * 2; }, nullEvery(7)), + }); + }); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .probeKeys({"c0"}) + .probeVectors(std::move(probeVectors)) + .probeFilter("c0 < 5") + .buildKeys({"u_c0"}) + .buildVectors(std::move(buildVectors)) + .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) + .joinType(core::JoinType::kLeft) + .joinOutputLayout({"row_number", "c0", "c1", "u_c1"}) + .referenceQuery( + "SELECT t.row_number, t.c0, t.c1, u.c1 FROM (SELECT * FROM t WHERE c0 < 5) t LEFT JOIN u ON t.c0 = u.c0") + .run(); +} + +TEST_P(MultiThreadedHashJoinTest, leftJoinWithFilter) { + // Left side keys are [0, 1, 2,..10]. + // Use 3-rd column as row number to allow for asserting the order of + // results. + std::vector probeVectors = mergeBatches( + makeBatches( + 3, + [&](int32_t /*unused*/) { + return makeRowVector( + {"c0", "c1", "row_number"}, + { + makeFlatVector( + 77, [](auto row) { return row % 11; }, nullEvery(13)), + makeFlatVector(77, [](auto row) { return row; }), + makeFlatVector(77, [](auto row) { return row; }), + }); + }), + makeBatches( + 2, + [&](int32_t /*unused*/) { + return makeRowVector( + {"c0", "c1", "row_number"}, + { + makeFlatVector( + 97, + [](auto row) { return (row + 3) % 11; }, + nullEvery(13)), + makeFlatVector(97, [](auto row) { return row; }), + makeFlatVector( + 97, [](auto row) { return 97 + row; }), + }); + }), + true); + + std::vector buildVectors = + makeBatches(3, [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 73, [](auto row) { return row % 5; }, nullEvery(7)), + makeFlatVector( + 73, [](auto row) { return -111 + row * 2; }, nullEvery(7)), + }); + }); + + // Additional filter. + { + auto testProbeVectors = probeVectors; + auto testBuildVectors = buildVectors; + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .probeKeys({"c0"}) + .probeVectors(std::move(testProbeVectors)) + .buildKeys({"u_c0"}) + .buildVectors(std::move(testBuildVectors)) + .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) + .joinType(core::JoinType::kLeft) + .joinFilter("(c1 + u_c1) % 2 = 1") + .joinOutputLayout({"row_number", "c0", "c1", "u_c1"}) + .referenceQuery( + "SELECT t.row_number, t.c0, t.c1, u.c1 FROM t LEFT JOIN u ON t.c0 = u.c0 AND (t.c1 + u.c1) % 2 = 1") + .run(); + } + + // No rows pass the additional filter. + { + auto testProbeVectors = probeVectors; + auto testBuildVectors = buildVectors; + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .probeKeys({"c0"}) + .probeVectors(std::move(testProbeVectors)) + .buildKeys({"u_c0"}) + .buildVectors(std::move(testBuildVectors)) + .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) + .joinType(core::JoinType::kLeft) + .joinFilter("(c1 + u_c1) % 2 = 3") + .joinOutputLayout({"row_number", "c0", "c1", "u_c1"}) + .referenceQuery( + "SELECT t.row_number, t.c0, t.c1, u.c1 FROM t LEFT JOIN u ON t.c0 = u.c0 AND (t.c1 + u.c1) % 2 = 3") + .run(); + } +} + +/// Tests left join with a filter that may evaluate to true, false or null. +/// Makes sure that null filter results are handled correctly, e.g. as if the +/// filter returned false. +TEST_P(MultiThreadedHashJoinTest, leftJoinWithNullableFilter) { + std::vector probeVectors = mergeBatches( + makeBatches( + 5, + [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector({1, 2, 3, 4, 5}), + makeNullableFlatVector( + {10, std::nullopt, 30, std::nullopt, 50}), + }); + }), + makeBatches( + 5, + [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector({1, 2, 3, 4, 5}), + makeNullableFlatVector( + {std::nullopt, 20, 30, std::nullopt, 50}), + }); + }), + true); + + std::vector buildVectors = + makeBatches(5, [&](int32_t /*unused*/) { + return makeRowVector( + {makeFlatVector(128, [](vector_size_t row) { + if (row < 3) { + return row; + } + return row + 10; + })}); + }); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .probeKeys({"c0"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u_c0"}) + .buildVectors(std::move(buildVectors)) + .buildProjections({"c0 AS u_c0"}) + .joinType(core::JoinType::kLeft) + .joinFilter("c1 + u_c0 > 0") + .joinOutputLayout({"c0", "c1", "u_c0"}) + .referenceQuery( + "SELECT * FROM t LEFT JOIN u ON (t.c0 = u.c0 AND t.c1 + u.c0 > 0)") + .run(); +} + +TEST_P(MultiThreadedHashJoinTest, rightJoin) { + // Left side keys are [0, 1, 2,..20]. + std::vector probeVectors = mergeBatches( + makeBatches( + 3, + [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 137, [](auto row) { return row % 21; }, nullEvery(13)), + makeFlatVector(137, [](auto row) { return row; }), + }); + }), + makeBatches( + 3, + [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 234, + [](auto row) { return (row + 3) % 21; }, + nullEvery(13)), + makeFlatVector(234, [](auto row) { return row; }), + }); + }), + true); + + // Right side keys are [-3, -2, -1, 0, 1, 2, 3]. + std::vector buildVectors = + makeBatches(3, [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 123, [](auto row) { return -3 + row % 7; }, nullEvery(11)), + makeFlatVector( + 123, [](auto row) { return -111 + row * 2; }, nullEvery(13)), + }); + }); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .probeKeys({"c0"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u_c0"}) + .buildVectors(std::move(buildVectors)) + .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) + .joinType(core::JoinType::kRight) + .joinOutputLayout({"c0", "c1", "u_c1"}) + .referenceQuery( + "SELECT t.c0, t.c1, u.c1 FROM t RIGHT JOIN u ON t.c0 = u.c0") + .run(); +} + +TEST_P(MultiThreadedHashJoinTest, rightJoinWithEmptyBuild) { + const std::vector finishOnEmptys = {false, true}; + for (const auto finishOnEmpty : finishOnEmptys) { + SCOPED_TRACE(fmt::format("finishOnEmpty: {}", finishOnEmpty)); + + // Left side keys are [0, 1, 2,..10]. + std::vector probeVectors = mergeBatches( + makeBatches( + 3, + [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 137, [](auto row) { return row % 11; }, nullEvery(13)), + makeFlatVector(137, [](auto row) { return row; }), + }); + }), + makeBatches( + 3, + [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 234, + [](auto row) { return (row + 3) % 11; }, + nullEvery(13)), + makeFlatVector(234, [](auto row) { return row; }), + }); + }), + true); + + // Right side keys are [-3, -2, -1, 0, 1, 2, 3]. + std::vector buildVectors = + makeBatches(3, [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 123, [](auto row) { return -3 + row % 7; }, nullEvery(11)), + makeFlatVector( + 123, [](auto row) { return -111 + row * 2; }, nullEvery(13)), + }); + }); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .probeKeys({"c0"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u_c0"}) + .buildVectors(std::move(buildVectors)) + .buildFilter("c0 > 100") + .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) + .joinType(core::JoinType::kRight) + .joinOutputLayout({"c1"}) + .referenceQuery("SELECT null LIMIT 0") + .checkSpillStats(false) + .run(); + } +} + +TEST_P(MultiThreadedHashJoinTest, rightJoinWithAllMatch) { + // Left side keys are [0, 1, 2,..20]. + std::vector probeVectors = mergeBatches( + makeBatches( + 3, + [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 137, [](auto row) { return row % 21; }, nullEvery(13)), + makeFlatVector(137, [](auto row) { return row; }), + }); + }), + makeBatches( + 3, + [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 234, + [](auto row) { return (row + 3) % 21; }, + nullEvery(13)), + makeFlatVector(234, [](auto row) { return row; }), + }); + }), + true); + + // Right side keys are [-3, -2, -1, 0, 1, 2, 3]. + std::vector buildVectors = + makeBatches(3, [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 123, [](auto row) { return -3 + row % 7; }, nullEvery(11)), + makeFlatVector( + 123, [](auto row) { return -111 + row * 2; }, nullEvery(13)), + }); + }); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .probeKeys({"c0"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u_c0"}) + .buildVectors(std::move(buildVectors)) + .buildFilter("c0 >= 0") + .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) + .joinType(core::JoinType::kRight) + .joinOutputLayout({"c0", "c1", "u_c1"}) + .referenceQuery( + "SELECT t.c0, t.c1, u.c1 FROM t RIGHT JOIN (SELECT * FROM u WHERE c0 >= 0) u ON t.c0 = u.c0") + .run(); +} + +TEST_P(MultiThreadedHashJoinTest, rightJoinWithFilter) { + // Left side keys are [0, 1, 2,..20]. + std::vector probeVectors = mergeBatches( + makeBatches( + 3, + [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 137, [](auto row) { return row % 21; }, nullEvery(13)), + makeFlatVector(137, [](auto row) { return row; }), + }); + }), + makeBatches( + 3, + [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 234, + [](auto row) { return (row + 3) % 21; }, + nullEvery(13)), + makeFlatVector(234, [](auto row) { return row; }), + }); + }), + true); + + // Right side keys are [-3, -2, -1, 0, 1, 2, 3]. + std::vector buildVectors = + makeBatches(3, [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 123, [](auto row) { return -3 + row % 7; }, nullEvery(11)), + makeFlatVector( + 123, [](auto row) { return -111 + row * 2; }, nullEvery(13)), + }); + }); + + // Filter with passed rows. + { + auto testProbeVectors = probeVectors; + auto testBuildVectors = buildVectors; + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .probeKeys({"c0"}) + .probeVectors(std::move(testProbeVectors)) + .buildKeys({"u_c0"}) + .buildVectors(std::move(testBuildVectors)) + .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) + .joinType(core::JoinType::kRight) + .joinFilter("(c1 + u_c1) % 2 = 1") + .joinOutputLayout({"c0", "c1", "u_c1"}) + .referenceQuery( + "SELECT t.c0, t.c1, u.c1 FROM t RIGHT JOIN u ON t.c0 = u.c0 AND (t.c1 + u.c1) % 2 = 1") + .run(); + } + + // Filter without passed rows. + { + auto testProbeVectors = probeVectors; + auto testBuildVectors = buildVectors; + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .probeKeys({"c0"}) + .probeVectors(std::move(testProbeVectors)) + .buildKeys({"u_c0"}) + .buildVectors(std::move(testBuildVectors)) + .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) + .joinType(core::JoinType::kRight) + .joinFilter("(c1 + u_c1) % 2 = 3") + .joinOutputLayout({"c0", "c1", "u_c1"}) + .referenceQuery( + "SELECT t.c0, t.c1, u.c1 FROM t RIGHT JOIN u ON t.c0 = u.c0 AND (t.c1 + u.c1) % 2 = 3") + .run(); + } +} + +TEST_P(MultiThreadedHashJoinTest, fullJoin) { + // Left side keys are [0, 1, 2,..20]. + std::vector probeVectors = mergeBatches( + makeBatches( + 3, + [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 213, [](auto row) { return row % 21; }, nullEvery(13)), + makeFlatVector(213, [](auto row) { return row; }), + }); + }), + makeBatches( + 2, + [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 137, + [](auto row) { return (row + 3) % 21; }, + nullEvery(13)), + makeFlatVector(137, [](auto row) { return row; }), + }); + }), + true); + + // Right side keys are [-3, -2, -1, + // 0, 1, 2, 3]. + std::vector buildVectors = + makeBatches(3, [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 123, [](auto row) { return -3 + row % 7; }, nullEvery(11)), + makeFlatVector( + 123, [](auto row) { return -111 + row * 2; }, nullEvery(13)), + }); + }); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .probeKeys({"c0"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u_c0"}) + .buildVectors(std::move(buildVectors)) + .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) + .joinType(core::JoinType::kFull) + .joinOutputLayout({"c0", "c1", "u_c1"}) + .referenceQuery( + "SELECT t.c0, t.c1, u.c1 FROM t FULL OUTER JOIN u ON t.c0 = u.c0") + .run(); +} + +TEST_P(MultiThreadedHashJoinTest, fullJoinWithEmptyBuild) { + const std::vector finishOnEmptys = {false, true}; + for (const auto finishOnEmpty : finishOnEmptys) { + SCOPED_TRACE(fmt::format("finishOnEmpty: {}", finishOnEmpty)); + + // Left side keys are [0, 1, 2,..10]. + std::vector probeVectors = mergeBatches( + makeBatches( + 3, + [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 213, [](auto row) { return row % 11; }, nullEvery(13)), + makeFlatVector(213, [](auto row) { return row; }), + }); + }), + makeBatches( + 2, + [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 137, + [](auto row) { return (row + 3) % 11; }, + nullEvery(13)), + makeFlatVector(137, [](auto row) { return row; }), + }); + }), + true); + + // Right side keys are [-3, -2, -1, 0, 1, 2, 3]. + std::vector buildVectors = + makeBatches(3, [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 123, [](auto row) { return -3 + row % 7; }, nullEvery(11)), + makeFlatVector( + 123, [](auto row) { return -111 + row * 2; }, nullEvery(13)), + }); + }); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .probeKeys({"c0"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u_c0"}) + .buildVectors(std::move(buildVectors)) + .buildFilter("c0 > 100") + .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) + .joinType(core::JoinType::kFull) + .joinOutputLayout({"c1"}) + .referenceQuery( + "SELECT t.c1 FROM t FULL OUTER JOIN (SELECT * FROM u WHERE c0 > 100) u ON t.c0 = u.c0") + .checkSpillStats(false) + .run(); + } +} + +TEST_P(MultiThreadedHashJoinTest, fullJoinWithNoMatch) { + // Left side keys are [0, 1, 2,..10]. + std::vector probeVectors = mergeBatches( + makeBatches( + 3, + [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 213, [](auto row) { return row % 11; }, nullEvery(13)), + makeFlatVector(213, [](auto row) { return row; }), + }); + }), + makeBatches( + 2, + [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 137, + [](auto row) { return (row + 3) % 11; }, + nullEvery(13)), + makeFlatVector(137, [](auto row) { return row; }), + }); + }), + true); + + // Right side keys are [-3, -2, -1, 0, 1, 2, 3]. + std::vector buildVectors = + makeBatches(3, [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 123, [](auto row) { return -3 + row % 7; }, nullEvery(11)), + makeFlatVector( + 123, [](auto row) { return -111 + row * 2; }, nullEvery(13)), + }); + }); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .probeKeys({"c0"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u_c0"}) + .buildVectors(std::move(buildVectors)) + .buildFilter("c0 < 0") + .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) + .joinType(core::JoinType::kFull) + .joinOutputLayout({"c1"}) + .referenceQuery( + "SELECT t.c1 FROM t FULL OUTER JOIN (SELECT * FROM u WHERE c0 < 0) u ON t.c0 = u.c0") + .run(); +} + +TEST_P(MultiThreadedHashJoinTest, fullJoinWithFilters) { + // Left side keys are [0, 1, 2,..10]. + std::vector probeVectors = mergeBatches( + makeBatches( + 3, + [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 213, [](auto row) { return row % 11; }, nullEvery(13)), + makeFlatVector(213, [](auto row) { return row; }), + }); + }), + makeBatches( + 2, + [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 137, + [](auto row) { return (row + 3) % 11; }, + nullEvery(13)), + makeFlatVector(137, [](auto row) { return row; }), + }); + }), + true); + + // Right side keys are [-3, -2, -1, 0, 1, 2, 3]. + std::vector buildVectors = + makeBatches(3, [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 123, [](auto row) { return -3 + row % 7; }, nullEvery(11)), + makeFlatVector( + 123, [](auto row) { return -111 + row * 2; }, nullEvery(13)), + }); + }); + + // Filter with passed rows. + { + auto testProbeVectors = probeVectors; + auto testBuildVectors = buildVectors; + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .probeKeys({"c0"}) + .probeVectors(std::move(testProbeVectors)) + .buildKeys({"u_c0"}) + .buildVectors(std::move(testBuildVectors)) + .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) + .joinType(core::JoinType::kFull) + .joinFilter("(c1 + u_c1) % 2 = 1") + .joinOutputLayout({"c0", "c1", "u_c1"}) + .referenceQuery( + "SELECT t.c0, t.c1, u.c1 FROM t FULL OUTER JOIN u ON t.c0 = u.c0 AND (t.c1 + u.c1) % 2 = 1") + .run(); + } + + // Filter without passed rows. + { + auto testProbeVectors = probeVectors; + auto testBuildVectors = buildVectors; + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .probeKeys({"c0"}) + .probeVectors(std::move(testProbeVectors)) + .buildKeys({"u_c0"}) + .buildVectors(std::move(testBuildVectors)) + .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) + .joinType(core::JoinType::kFull) + .joinFilter("(c1 + u_c1) % 2 = 3") + .joinOutputLayout({"c0", "c1", "u_c1"}) + .referenceQuery( + "SELECT t.c0, t.c1, u.c1 FROM t FULL OUTER JOIN u ON t.c0 = u.c0 AND (t.c1 + u.c1) % 2 = 3") + .run(); + } +} + +TEST_P(MultiThreadedHashJoinTest, noSpillLevelLimit) { + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .keyTypes({INTEGER()}) + .probeVectors(1600, 5) + .buildVectors(1500, 5) + .referenceQuery( + "SELECT t_k0, t_data, u_k0, u_data FROM t, u WHERE t.t_k0 = u.u_k0") + .maxSpillLevel(-1) + .config(core::QueryConfig::kSpillStartPartitionBit, "51") + .config(core::QueryConfig::kSpillNumPartitionBits, "3") + .checkSpillStats(false) + .verifier([&](const std::shared_ptr& task, bool hasSpill) { + if (!hasSpill) { + return; + } + ASSERT_EQ(maxHashBuildSpillLevel(*task), 3); + }) + .run(); +} + +// Verify that dynamic filter pushed down is turned off for null-aware right +// semi project join. +TEST_P(HashJoinTest, nullAwareRightSemiProjectOverScan) { + std::vector probes; + std::vector builds; + // Matches present: + probes.push_back(makeRowVector( + {"t0"}, + { + makeNullableFlatVector({1, std::nullopt, 2}), + })); + builds.push_back(makeRowVector( + {"u0"}, + { + makeNullableFlatVector({1, 2, 3, std::nullopt}), + })); + + // No matches present: + probes.push_back(makeRowVector( + {"t0"}, + { + makeFlatVector({5, 6}), + })); + builds.push_back(makeRowVector( + {"u0"}, + { + makeNullableFlatVector({1, 2, 3, std::nullopt}), + })); + + for (int i = 0; i < probes.size(); i++) { + RowVectorPtr& probe = probes[i]; + RowVectorPtr& build = builds[i]; + std::shared_ptr probeFile = TempFilePath::create(); + writeToFile(probeFile->getPath(), {probe}); + + std::shared_ptr buildFile = TempFilePath::create(); + writeToFile(buildFile->getPath(), {build}); + + createDuckDbTable("t", {probe}); + createDuckDbTable("u", {build}); + + core::PlanNodeId probeScanId; + core::PlanNodeId buildScanId; + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .tableScan(asRowType(probe->type())) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .tableScan(asRowType(build->type())) + .capturePlanNodeId(buildScanId) + .planNode(), + "", + {"u0", "match"}, + core::JoinType::kRightSemiProject, + true /*nullAware*/) + .planNode(); + + SplitPath splitPaths = { + {probeScanId, {probeFile->getPath()}}, + {buildScanId, {buildFile->getPath()}}, + }; + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .inputSplits(splitPaths) + .checkSpillStats(false) + .referenceQuery("SELECT u0, u0 IN (SELECT t0 FROM t) FROM u") + .run(); + } +} + +TEST_P(HashJoinTest, duplicateJoinKeys) { + auto leftVectors = makeBatches(3, [&](int32_t /*unused*/) { + return makeRowVector({ + makeNullableFlatVector( + {1, 2, 2, 3, 3, std::nullopt, 4, 5, 5, 6, 7}), + makeNullableFlatVector( + {1, 2, 2, std::nullopt, 3, 3, 4, 5, 5, 6, 8}), + }); + }); + + auto rightVectors = makeBatches(3, [&](int32_t /*unused*/) { + return makeRowVector({ + makeNullableFlatVector({1, 1, 3, 4, std::nullopt, 5, 7, 8}), + makeNullableFlatVector({1, 1, 3, 4, 5, std::nullopt, 7, 8}), + }); + }); + + createDuckDbTable("t", leftVectors); + createDuckDbTable("u", rightVectors); + + auto planNodeIdGenerator = std::make_shared(); + + auto assertPlan = [&](const std::vector& leftProject, + const std::vector& leftKeys, + const std::vector& rightProject, + const std::vector& rightKeys, + const std::vector& outputLayout, + core::JoinType joinType, + const std::string& query) { + auto plan = PlanBuilder(planNodeIdGenerator) + .values(leftVectors) + .project(leftProject) + .hashJoin( + leftKeys, + rightKeys, + PlanBuilder(planNodeIdGenerator) + .values(rightVectors) + .project(rightProject) + .planNode(), + "", + outputLayout, + joinType) + .planNode(); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .referenceQuery(query) + .run(); + }; + + std::vector> joins = { + {core::JoinType::kInner, "INNER JOIN"}, + {core::JoinType::kLeft, "LEFT JOIN"}, + {core::JoinType::kRight, "RIGHT JOIN"}, + {core::JoinType::kFull, "FULL OUTER JOIN"}}; + + for (const auto& [joinType, joinTypeSql] : joins) { + // Duplicate keys on the build side. + assertPlan( + {"c0 AS t0", "c1 as t1"}, // leftProject + {"t0", "t1"}, // leftKeys + {"c0 AS u0"}, // rightProject + {"u0", "u0"}, // rightKeys + {"t0", "t1", "u0"}, // outputLayout + joinType, + "SELECT t.c0, t.c1, u.c0 FROM t " + joinTypeSql + + " u ON t.c0 = u.c0 and t.c1 = u.c0"); + } + + for (const auto& [joinType, joinTypeSql] : joins) { + // Duplicated keys on the probe side. + assertPlan( + {"c0 AS t0"}, // leftProject + {"t0", "t0"}, // leftKeys + {"c0 AS u0", "c1 AS u1"}, // rightProject + {"u0", "u1"}, // rightKeys + {"t0", "u0", "u1"}, // outputLayout + joinType, + "SELECT t.c0, u.c0, u.c1 FROM t " + joinTypeSql + + " u ON t.c0 = u.c0 and t.c0 = u.c1"); + } +} + +TEST_P(HashJoinTest, semiProject) { + // Some keys have multiple rows: 2, 3, 5. + auto probeVectors = makeBatches(3, [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector({1, 2, 2, 3, 3, 3, 4, 5, 5, 6, 7}), + makeFlatVector({10, 20, 21, 30, 31, 32, 40, 50, 51, 60, 70}), + }); + }); + + // Some keys are missing: 2, 6. + // Some have multiple rows: 1, 5. + // Some keys are not present on probe side: 8. + auto buildVectors = makeBatches(3, [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector({1, 1, 3, 4, 5, 5, 7, 8}), + makeFlatVector({100, 101, 300, 400, 500, 501, 700, 800}), + }); + }); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors) + .project({"c0 AS t0", "c1 AS t1"}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors) + .project({"c0 AS u0", "c1 AS u1"}) + .planNode(), + "", + {"t0", "t1", "match"}, + core::JoinType::kLeftSemiProject) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .referenceQuery( + "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0) FROM t") + .run(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(flipJoinSides(plan)) + .referenceQuery( + "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0) FROM t") + .run(); + + // With extra filter. + planNodeIdGenerator = std::make_shared(); + plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors) + .project({"c0 AS t0", "c1 AS t1"}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors) + .project({"c0 AS u0", "c1 AS u1"}) + .planNode(), + "t1 * 10 <> u1", + {"t0", "t1", "match"}, + core::JoinType::kLeftSemiProject) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .referenceQuery( + "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0 AND t.c1 * 10 <> u.c1) FROM t") + .run(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(flipJoinSides(plan)) + .referenceQuery( + "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0 AND t.c1 * 10 <> u.c1) FROM t") + .run(); + + // Empty build side. + planNodeIdGenerator = std::make_shared(); + plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors) + .project({"c0 AS t0", "c1 AS t1"}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors) + .project({"c0 AS u0", "c1 AS u1"}) + .filter("u0 < 0") + .planNode(), + "", + {"t0", "t1", "match"}, + core::JoinType::kLeftSemiProject) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .referenceQuery( + "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE u.c0 < 0 AND t.c0 = u.c0) FROM t") + // NOTE: there is no spilling in empty build test case as all the + // build-side rows have been filtered out. + .checkSpillStats(false) + .run(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(flipJoinSides(plan)) + .referenceQuery( + "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE u.c0 < 0 AND t.c0 = u.c0) FROM t") + // NOTE: there is no spilling in empty build test case as all the + // build-side rows have been filtered out. + .checkSpillStats(false) + .run(); +} + +TEST_P(HashJoinTest, semiProjectWithNullKeys) { + // Some keys have multiple rows: 2, 3, 5. + auto probeVectors = makeBatches(3, [&](int32_t /*unused*/) { + return makeRowVector( + {"t0", "t1"}, + { + makeNullableFlatVector( + {1, 2, 2, 3, 3, 3, 4, std::nullopt, 5, 5, 6, 7}), + makeFlatVector( + {10, 20, 21, 30, 31, 32, 40, -1, 50, 51, 60, 70}), + }); + }); + + // Some keys are missing: 2, 6. + // Some have multiple rows: 1, 5. + // Some keys are not present on probe side: 8. + auto buildVectors = makeBatches(3, [&](int32_t /*unused*/) { + return makeRowVector( + {"u0", "u1"}, + { + makeNullableFlatVector( + {1, 1, 3, 4, std::nullopt, 5, 5, 7, 8}), + makeFlatVector( + {100, 101, 300, 400, -100, 500, 501, 700, 800}), + }); + }); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto makePlan = [&](bool nullAware, + const std::string& probeFilter = "", + const std::string& buildFilter = "") { + auto planNodeIdGenerator = std::make_shared(); + return PlanBuilder(planNodeIdGenerator) + .values(probeVectors) + .optionalFilter(probeFilter) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors) + .optionalFilter(buildFilter) + .planNode(), + "", + {"t0", "t1", "match"}, + core::JoinType::kLeftSemiProject, + nullAware) + .planNode(); + }; + + // Null join keys on both sides. + auto plan = makePlan(false /*nullAware*/); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .referenceQuery( + "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0) FROM t") + .run(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(flipJoinSides(plan)) + .referenceQuery( + "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0) FROM t") + .run(); + + plan = makePlan(true /*nullAware*/); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .referenceQuery("SELECT t0, t1, t0 IN (SELECT u0 FROM u) FROM t") + .run(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(flipJoinSides(plan)) + .referenceQuery("SELECT t0, t1, t0 IN (SELECT u0 FROM u) FROM t") + .run(); + + // Null join keys on build side-only. + plan = makePlan(false /*nullAware*/, "t0 IS NOT NULL"); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .referenceQuery( + "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0) FROM t WHERE t0 IS NOT NULL") + .run(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(flipJoinSides(plan)) + .referenceQuery( + "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0) FROM t WHERE t0 IS NOT NULL") + .run(); + + plan = makePlan(true /*nullAware*/, "t0 IS NOT NULL"); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .referenceQuery( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u) FROM t WHERE t0 IS NOT NULL") + .run(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(flipJoinSides(plan)) + .referenceQuery( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u) FROM t WHERE t0 IS NOT NULL") + .run(); + + // Null join keys on probe side-only. + plan = makePlan(false /*nullAware*/, "", "u0 IS NOT NULL"); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .referenceQuery( + "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0 AND u0 IS NOT NULL) FROM t") + .run(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(flipJoinSides(plan)) + .referenceQuery( + "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0 AND u0 IS NOT NULL) FROM t") + .run(); + + plan = makePlan(true /*nullAware*/, "", "u0 IS NOT NULL"); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .referenceQuery( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE u0 IS NOT NULL) FROM t") + .run(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(flipJoinSides(plan)) + .referenceQuery( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE u0 IS NOT NULL) FROM t") + .run(); + + // Empty build side. + plan = makePlan(false /*nullAware*/, "", "u0 < 0"); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) + .planNode(plan) + .checkSpillStats(false) + .referenceQuery( + "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0 AND u0 < 0) FROM t") + .run(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) + .planNode(flipJoinSides(plan)) + .checkSpillStats(false) + .referenceQuery( + "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0 AND u0 < 0) FROM t") + .run(); + + plan = makePlan(true /*nullAware*/, "", "u0 < 0"); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) + .planNode(plan) + .checkSpillStats(false) + .referenceQuery( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE u0 < 0) FROM t") + .run(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) + .planNode(flipJoinSides(plan)) + .checkSpillStats(false) + .referenceQuery( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE u0 < 0) FROM t") + .run(); + + // Build side with all rows having null join keys. + plan = makePlan(false /*nullAware*/, "", "u0 IS NULL"); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) + .planNode(plan) + .checkSpillStats(false) + .referenceQuery( + "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0 AND u0 IS NULL) FROM t") + .run(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) + .planNode(flipJoinSides(plan)) + .checkSpillStats(false) + .referenceQuery( + "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0 AND u0 IS NULL) FROM t") + .run(); + + plan = makePlan(true /*nullAware*/, "", "u0 IS NULL"); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) + .planNode(plan) + .checkSpillStats(false) + .referenceQuery( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE u0 IS NULL) FROM t") + .run(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) + .planNode(flipJoinSides(plan)) + .checkSpillStats(false) + .referenceQuery( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE u0 IS NULL) FROM t") + .run(); +} + +TEST_P(HashJoinTest, semiProjectWithFilter) { + auto probeVectors = makeBatches(3, [&](auto /*unused*/) { + return makeRowVector( + {"t0", "t1"}, + { + makeNullableFlatVector({1, 2, 3, std::nullopt, 5}), + makeFlatVector({10, 20, 30, 40, 50}), + }); + }); + + auto buildVectors = makeBatches(3, [&](auto /*unused*/) { + return makeRowVector( + {"u0", "u1"}, + { + makeNullableFlatVector({1, 2, 3, std::nullopt}), + makeFlatVector({11, 22, 33, 44}), + }); + }); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto makePlan = [&](bool nullAware, const std::string& filter) { + auto planNodeIdGenerator = std::make_shared(); + return PlanBuilder(planNodeIdGenerator) + .values(probeVectors) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values(buildVectors).planNode(), + filter, + {"t0", "t1", "match"}, + core::JoinType::kLeftSemiProject, + nullAware) + .planNode(); + }; + + std::vector filters = { + "t1 <> u1", + "t1 < u1", + "t1 > u1", + "t1 is not null AND u1 is not null", + "t1 is null OR u1 is null", + }; + for (const auto& filter : filters) { + auto plan = makePlan(true /*nullAware*/, filter); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .referenceQuery( + fmt::format( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE {}) FROM t", + filter)) + .injectSpill(false) + .run(); + + plan = makePlan(false /*nullAware*/, filter); + + // DuckDB Exists operator returns NULL when u0 or t0 is NULL. We exclude + // these values. + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .referenceQuery( + fmt::format( + "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE (u0 is not null OR t0 is not null) AND u0 = t0 AND {}) FROM t", + filter)) + .injectSpill(false) + .run(); + } +} + +TEST_P(HashJoinTest, nullAwareRightSemiProjectWithFilterNotAllowed) { + auto probe = makeRowVector(ROW({"t0", "t1"}, {INTEGER(), BIGINT()}), 10); + auto build = makeRowVector(ROW({"u0", "u1"}, {INTEGER(), BIGINT()}), 10); + + auto planNodeIdGenerator = std::make_shared(); + VELOX_ASSERT_THROW( + PlanBuilder(planNodeIdGenerator) + .values({probe}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({build}).planNode(), + "t1 > u1", + {"u0", "u1", "match"}, + core::JoinType::kRightSemiProject, + true /* nullAware */), + "Null-aware right semi project join doesn't support extra filter"); +} + +TEST_P(HashJoinTest, leftSemiJoinWithExtraOutputCapacity) { + std::vector probeVectors; + std::vector buildVectors; + probeVectors.push_back(makeRowVector( + {"t0", "t1"}, + { + makeFlatVector({1, 2, 3, 4, 5, 6}), + makeFlatVector({10, 10, 10, 10, 10, 10}), + })); + + buildVectors.push_back(makeRowVector( + {"u0", "u1"}, + { + makeFlatVector({1, 1, 1, 1, 1}), + makeFlatVector({10, 10, 10, 10, 10}), + })); + buildVectors.push_back(makeRowVector( + {"u0", "u1"}, + { + makeFlatVector({2, 3, 4, 5, 6}), + makeFlatVector({10, 10, 10, 10, 10}), + })); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + auto runQuery = [&](const std::string& query, + const std::string& filter, + core::JoinType joinType) { + auto planNodeIdGenerator = std::make_shared(); + std::vector outputLayout = {"t0", "t1"}; + if (joinType == core::JoinType::kLeftSemiProject) { + outputLayout.push_back("match"); + } + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors) + .planNode(), + filter, + outputLayout, + joinType, + false) + .planNode(); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .config(core::QueryConfig::kPreferredOutputBatchRows, "5") + .referenceQuery(query) + .injectSpill(false) + .run(); + }; + { + SCOPED_TRACE("left semi filter join"); + std::string filter = "t1 = u1"; + runQuery( + fmt::format( + "SELECT t0, t1 FROM t WHERE EXISTS (SELECT u0 FROM u WHERE t0 = u0 AND {})", + filter), + filter, + core::JoinType::kLeftSemiFilter); + } + + { + SCOPED_TRACE("left semi project join"); + std::string filter = "t1 <> u1"; + runQuery( + fmt::format( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE {}) FROM t", filter), + filter, + core::JoinType::kLeftSemiProject); + } +} + +TEST_P(HashJoinTest, nullAwareMultiKeyNotAllowed) { + auto probe = makeRowVector( + ROW({"t0", "t1", "t2"}, {INTEGER(), BIGINT(), VARCHAR()}), 10); + auto build = makeRowVector( + ROW({"u0", "u1", "u2"}, {INTEGER(), BIGINT(), VARCHAR()}), 10); + + // Null-aware left semi project join. + auto planNodeIdGenerator = std::make_shared(); + VELOX_ASSERT_THROW( + PlanBuilder(planNodeIdGenerator) + .values({probe}) + .hashJoin( + {"t0", "t1"}, + {"u0", "u1"}, + PlanBuilder(planNodeIdGenerator).values({build}).planNode(), + "", + {"t0", "t1", "match"}, + core::JoinType::kLeftSemiProject, + true /* nullAware */), + "Null-aware joins allow only one join key"); + + // Null-aware right semi project join. + VELOX_ASSERT_THROW( + PlanBuilder(planNodeIdGenerator) + .values({probe}) + .hashJoin( + {"t0", "t1"}, + {"u0", "u1"}, + PlanBuilder(planNodeIdGenerator).values({build}).planNode(), + "", + {"u0", "u1", "match"}, + core::JoinType::kRightSemiProject, + true /* nullAware */), + "Null-aware joins allow only one join key"); + + // Null-aware anti join. + VELOX_ASSERT_THROW( + PlanBuilder(planNodeIdGenerator) + .values({probe}) + .hashJoin( + {"t0", "t1"}, + {"u0", "u1"}, + PlanBuilder(planNodeIdGenerator).values({build}).planNode(), + "", + {"t0", "t1"}, + core::JoinType::kAnti, + true /* nullAware */), + "Null-aware joins allow only one join key"); +} + +TEST_P(HashJoinTest, semiProjectOverLazyVectors) { + auto probeVectors = makeBatches(1, [&](auto /*unused*/) { + return makeRowVector( + {"t0", "t1"}, + { + makeFlatVector(1'000, [](auto row) { return row; }), + makeFlatVector(1'000, [](auto row) { return row * 10; }), + }); + }); + + auto buildVectors = makeBatches(3, [&](auto /*unused*/) { + return makeRowVector( + {"u0", "u1"}, + { + makeFlatVector( + 1'000, [](auto row) { return -100 + (row / 5); }), + makeFlatVector( + 1'000, [](auto row) { return -1000 + (row / 5) * 10; }), + }); + }); + + std::shared_ptr probeFile = TempFilePath::create(); + writeToFile(probeFile->getPath(), probeVectors); + + std::shared_ptr buildFile = TempFilePath::create(); + writeToFile(buildFile->getPath(), buildVectors); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + core::PlanNodeId probeScanId; + core::PlanNodeId buildScanId; + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .tableScan(asRowType(probeVectors[0]->type())) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .tableScan(asRowType(buildVectors[0]->type())) + .capturePlanNodeId(buildScanId) + .planNode(), + "", + {"t0", "t1", "match"}, + core::JoinType::kLeftSemiProject) + .planNode(); + + SplitPath splitPaths = { + {probeScanId, {probeFile->getPath()}}, + {buildScanId, {buildFile->getPath()}}, + }; + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .inputSplits(splitPaths) + .checkSpillStats(false) + .referenceQuery("SELECT t0, t1, t0 IN (SELECT u0 FROM u) FROM t") + .run(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(flipJoinSides(plan)) + .inputSplits(splitPaths) + .checkSpillStats(false) + .referenceQuery("SELECT t0, t1, t0 IN (SELECT u0 FROM u) FROM t") + .run(); + + // With extra filter. + planNodeIdGenerator = std::make_shared(); + plan = PlanBuilder(planNodeIdGenerator) + .tableScan(asRowType(probeVectors[0]->type())) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .tableScan(asRowType(buildVectors[0]->type())) + .capturePlanNodeId(buildScanId) + .planNode(), + "(t1 + u1) % 3 = 0", + {"t0", "t1", "match"}, + core::JoinType::kLeftSemiProject) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .inputSplits(splitPaths) + .checkSpillStats(false) + .referenceQuery( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE (t1 + u1) % 3 = 0) FROM t") + .run(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(flipJoinSides(plan)) + .inputSplits(splitPaths) + .checkSpillStats(false) + .referenceQuery( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE (t1 + u1) % 3 = 0) FROM t") + .run(); +} + +// Verifies that hash join with lazy probe input produces correct results +// across multiple join types, with automatic spill injection via +// HashJoinBuilder. +TEST_P(MultiThreadedHashJoinTest, hashJoinWithLazyProbeInputAndSpill) { + VectorFuzzer::Options opts; + opts.vectorSize = 1'000; + VectorFuzzer fuzzer(opts, pool()); + + const int32_t numProbeVectors = 3; + std::vector probeVectors; + std::vector probeReference; + probeVectors.reserve(numProbeVectors); + probeReference.reserve(numProbeVectors); + for (int32_t i = 0; i < numProbeVectors; ++i) { + auto nonLazy = fuzzer.fuzzRow(probeType_); + probeReference.push_back( + std::dynamic_pointer_cast( + nonLazy->testingCopyPreserveEncodings())); + probeVectors.push_back( + VectorFuzzer(opts, pool()).fuzzRowChildrenToLazy(nonLazy)); + } + createDuckDbTable("t", probeReference); + + const int32_t numBuildVectors = 3; + std::vector buildVectors; + buildVectors.reserve(numBuildVectors); + for (int32_t i = 0; i < numBuildVectors; ++i) { + buildVectors.push_back(fuzzer.fuzzRow(buildType_)); + } + createDuckDbTable("u", buildVectors); + + struct { + core::JoinType joinType; + std::string filter; + std::vector outputColumns; + std::string referenceQuery; + } testCases[] = { + {core::JoinType::kInner, + "", + concat(probeType_->names(), buildType_->names()), + "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1"}, + {core::JoinType::kLeft, + "", + concat(probeType_->names(), buildType_->names()), + "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t LEFT JOIN u ON t.t_k1 = u.u_k1"}, + {core::JoinType::kFull, + "", + concat(probeType_->names(), buildType_->names()), + "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t FULL OUTER JOIN u ON t.t_k1 = u.u_k1"}, + {core::JoinType::kAnti, + "t_k2 <> u_k2", + probeType_->names(), + "SELECT t_k1, t_k2, t_v1 FROM t WHERE NOT EXISTS (SELECT 1 FROM u WHERE t.t_k1 = u.u_k1 AND t.t_k2 <> u.u_k2)"}, + {core::JoinType::kLeftSemiProject, + "", + concat(probeType_->names(), {"match"}), + "SELECT t_k1, t_k2, t_v1, EXISTS (SELECT 1 FROM u WHERE t.t_k1 = u.u_k1) FROM t"}, + }; + + for (const auto& testCase : testCases) { + SCOPED_TRACE( + fmt::format("joinType: {}", static_cast(testCase.joinType))); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) + .probeKeys({"t_k1"}) + .probeVectors(std::vector(probeVectors)) + .buildKeys({"u_k1"}) + .buildVectors(std::vector(buildVectors)) + .joinType(testCase.joinType) + .joinFilter(testCase.filter) + .joinOutputLayout(std::vector(testCase.outputColumns)) + .referenceQuery(testCase.referenceQuery) + .run(); + } +} + +VELOX_INSTANTIATE_TEST_SUITE_P( + MultiThreadedHashJoinTest, + MultiThreadedHashJoinTest, + testing::ValuesIn(MultiThreadedHashJoinTest::getTestParams()), + [](const testing::TestParamInfo& info) { + return TestParamToName(info.param); + }); + +VELOX_INSTANTIATE_TEST_SUITE_P( + HashJoinTest, + HashJoinTest, + testing::ValuesIn(HashJoinTest::getTestParams()), + [](const testing::TestParamInfo& info) { + return TestParamToName(info.param); + }); + +// Test that hash join spill uses the hash_join_spill_file_create_config when +// set, and other spillable operators use the default spill_file_create_config. +DEBUG_ONLY_TEST_P(HashJoinTest, hashJoinSpillFileCreateConfig) { + const auto rowType = + ROW({"c0", "c1", "c2"}, {INTEGER(), INTEGER(), VARCHAR()}); + const auto probeVectors = createVectors(rowType, 128, 10); + const auto buildVectors = createVectors(rowType, 128, 10); + + auto planNodeIdGenerator = std::make_shared(); + // Build a plan with hash join and orderBy. Hash join operators should use + // hash_join_spill_file_create_config and orderBy should use the default + // spill_file_create_config. + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, true) + .hashJoin( + {"c0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, true) + .project({"c0 AS u0", "c1 AS u1", "c2 AS u2"}) + .planNode(), + "", + {"c0", "c1", "c2"}, + core::JoinType::kInner) + .orderBy({"c0 ASC NULLS LAST"}, false) + .planNode(); + + auto spillDirectory = TempDirectoryPath::create(); + + std::atomic_bool hashJoinConfigVerified{false}; + std::atomic_bool defaultConfigVerified{false}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::isBlocked", + std::function([&](exec::Operator* op) { + const auto* spillConfig = op->testingSpillConfig(); + if (spillConfig == nullptr) { + return; + } + const auto& opType = op->operatorType(); + if (opType == "HashBuild" || opType == "HashProbe") { + // Hash join operators should use hash_join_spill_file_create_config. + ASSERT_EQ(spillConfig->fileCreateConfig, "test_hashjoin_config") + << "Operator: " << opType; + hashJoinConfigVerified = true; + } else { + // Other spillable operators (e.g., OrderBy) should use the default + // spill_file_create_config. + ASSERT_EQ(spillConfig->fileCreateConfig, "test_default_config") + << "Operator: " << opType; + defaultConfigVerified = true; + } + })); + + TestScopedSpillInjection scopedSpillInjection(100); + AssertQueryBuilder(plan) + .spillDirectory(spillDirectory->getPath()) + .config(core::QueryConfig::kSpillEnabled, true) + .config(core::QueryConfig::kJoinSpillEnabled, true) + .config(core::QueryConfig::kOrderBySpillEnabled, true) + .config(core::QueryConfig::kSpillFileCreateConfig, "test_default_config") + .config( + core::QueryConfig::kHashJoinSpillFileCreateConfig, + "test_hashjoin_config") + .copyResults(pool_.get()); + + ASSERT_TRUE(hashJoinConfigVerified.load()); + ASSERT_TRUE(defaultConfigVerified.load()); +} + +} // namespace +} // namespace facebook::velox::exec diff --git a/velox/exec/tests/HashJoinWithCacheTest.cpp b/velox/exec/tests/HashJoinWithCacheTest.cpp new file mode 100644 index 00000000000..07b7ab8ef0c --- /dev/null +++ b/velox/exec/tests/HashJoinWithCacheTest.cpp @@ -0,0 +1,827 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/memory/MemoryArbitrator.h" +#include "velox/common/testutil/TestValue.h" +#include "velox/exec/Cursor.h" +#include "velox/exec/HashBuild.h" +#include "velox/exec/HashJoinBridge.h" +#include "velox/exec/HashProbe.h" +#include "velox/exec/HashTable.h" +#include "velox/exec/HashTableCache.h" +#include "velox/exec/MemoryReclaimer.h" +#include "velox/exec/PlanNodeStats.h" + +#include "velox/exec/tests/utils/ArbitratorTestUtil.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/tests/utils/QueryAssertions.h" + +#include "velox/exec/tests/utils/VectorTestUtil.h" + +namespace facebook::velox::exec::test { +namespace { + +// Test fixture for hash join with hash table caching tests. +class HashJoinWithCacheTest : public HiveConnectorTestBase {}; + +// Tests hash table caching for broadcast joins. +// First task builds the table (cache miss), second task reuses it (cache hit). +TEST_F(HashJoinWithCacheTest, sequential) { + // Use a unique query ID for this test to ensure clean cache state. + const std::string queryId = + "hashTableCachingTest_" + + std::to_string( + std::chrono::steady_clock::now().time_since_epoch().count()); + + // Create probe and build vectors with distinct column names. + std::vector probeVectors = makeBatches(10, [&](int32_t) { + return makeRowVector( + {"t_k", "t_v"}, + { + makeFlatVector(100, [](auto row) { return row % 23; }), + makeFlatVector(100, [](auto row) { return row; }), + }); + }); + + std::vector buildVectors = makeBatches(5, [&](int32_t) { + return makeRowVector( + {"u_k", "u_v"}, + { + makeFlatVector(50, [](auto row) { return row % 31; }), + makeFlatVector(50, [](auto row) { return row * 10; }), + }); + }); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto planNodeIdGenerator = std::make_shared(); + + // Build the plan using HashJoinNode::Builder to set useHashTableCache. + auto buildPlanNode = + PlanBuilder(planNodeIdGenerator).values(buildVectors).planNode(); + + auto probePlanNode = + PlanBuilder(planNodeIdGenerator).values(probeVectors).planNode(); + + // Create HashJoinNode with useHashTableCache = true. + auto outputType = ROW( + {"t_k", "t_v", "u_k", "u_v"}, {INTEGER(), BIGINT(), INTEGER(), BIGINT()}); + + auto joinNode = + core::HashJoinNode::Builder() + .id(planNodeIdGenerator->next()) + .joinType(core::JoinType::kInner) + .nullAware(false) + .leftKeys( + {std::make_shared(INTEGER(), "t_k")}) + .rightKeys( + {std::make_shared(INTEGER(), "u_k")}) + .left(probePlanNode) + .right(buildPlanNode) + .outputType(outputType) + .useHashTableCache(true) + .build(); + const auto joinNodeId = joinNode->id(); + + const int numDrivers = 3; + + // Create a shared QueryCtx for all tasks. + auto queryCtx = core::QueryCtx::create( + driverExecutor_.get(), + core::QueryConfig({}), + std::unordered_map>{}, + cache::AsyncDataCache::getInstance(), + nullptr, + nullptr, + queryId); + + // Helper to run a task and return the completed task. + // Both tasks use the same queryCtx so they share the cache entry. + auto runTask = [&]() { + return AssertQueryBuilder(joinNode, duckDbQueryRunner_) + .queryCtx(queryCtx) + .maxDrivers(numDrivers) + .assertResults( + "SELECT t.t_k, t.t_v, u.u_k, u.u_v FROM t, u WHERE t.t_k = u.u_k"); + }; + + // First task - should build the table (cache miss). + auto task1 = runTask(); + + // Get stats from first task - expect cache miss. + auto opStats1 = toOperatorStats(task1->taskStats()); + ASSERT_EQ(opStats1.count("HashBuild"), 1); + auto& hashBuildStats1 = opStats1.at("HashBuild"); + + // The last driver that finishes building reports the cache miss stat. + ASSERT_EQ( + hashBuildStats1.runtimeStats.count( + std::string(BaseHashTable::kHashTableCacheMiss)), + 1) + << "First task should report cache miss"; + EXPECT_EQ( + hashBuildStats1.runtimeStats + .at(std::string(BaseHashTable::kHashTableCacheMiss)) + .count, + 1) + << "Exactly one driver should report cache miss (the one that builds)"; + EXPECT_EQ( + hashBuildStats1.runtimeStats.count( + std::string(BaseHashTable::kHashTableCacheHit)), + 0) + << "First task should not have any cache hits"; + + // Second task - should reuse the cached table (cache hit). + auto task2 = runTask(); + + // Get stats from second task - expect cache hit. + auto opStats2 = toOperatorStats(task2->taskStats()); + ASSERT_EQ(opStats2.count("HashBuild"), 1); + auto& hashBuildStats2 = opStats2.at("HashBuild"); + + // The last driver that finishes reports the cache hit stat. + ASSERT_EQ( + hashBuildStats2.runtimeStats.count( + std::string(BaseHashTable::kHashTableCacheHit)), + 1) + << "Second task should report cache hit"; + EXPECT_EQ( + hashBuildStats2.runtimeStats + .at(std::string(BaseHashTable::kHashTableCacheHit)) + .count, + 1) + << "Exactly one driver should report cache hit (the one after barrier)"; + EXPECT_EQ( + hashBuildStats2.runtimeStats.count( + std::string(BaseHashTable::kHashTableCacheMiss)), + 0) + << "Second task should not have any cache misses"; + + // Clean up cache entry before tasks are destroyed. + // The release callback on QueryCtx fires too late (after Task destruction + // starts), so we need explicit cleanup in tests. + const auto cacheKey = fmt::format("{}:{}", queryId, joinNodeId); + HashTableCache::instance()->drop(cacheKey); +} + +// Tests that multiple tasks running concurrently share the cached hash table. +// One task builds the table (cache miss), all others wait and reuse it +// (cache hits). +TEST_F(HashJoinWithCacheTest, concurrent) { + // Use a unique query ID for this test to ensure clean cache state. + const std::string queryId = + "hashTableCachingConcurrentTest_" + + std::to_string( + std::chrono::steady_clock::now().time_since_epoch().count()); + + // Create probe and build vectors with distinct column names. + std::vector probeVectors = makeBatches(10, [&](int32_t) { + return makeRowVector( + {"t_k", "t_v"}, + { + makeFlatVector(100, [](auto row) { return row % 23; }), + makeFlatVector(100, [](auto row) { return row; }), + }); + }); + + std::vector buildVectors = makeBatches(5, [&](int32_t) { + return makeRowVector( + {"u_k", "u_v"}, + { + makeFlatVector(50, [](auto row) { return row % 31; }), + makeFlatVector(50, [](auto row) { return row * 10; }), + }); + }); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto planNodeIdGenerator = std::make_shared(); + + // Build the plan using HashJoinNode::Builder to set useHashTableCache. + auto buildPlanNode = + PlanBuilder(planNodeIdGenerator).values(buildVectors).planNode(); + + auto probePlanNode = + PlanBuilder(planNodeIdGenerator).values(probeVectors).planNode(); + + // Create HashJoinNode with useHashTableCache = true. + auto outputType = ROW( + {"t_k", "t_v", "u_k", "u_v"}, {INTEGER(), BIGINT(), INTEGER(), BIGINT()}); + + auto joinNode = + core::HashJoinNode::Builder() + .id(planNodeIdGenerator->next()) + .joinType(core::JoinType::kInner) + .nullAware(false) + .leftKeys( + {std::make_shared(INTEGER(), "t_k")}) + .rightKeys( + {std::make_shared(INTEGER(), "u_k")}) + .left(probePlanNode) + .right(buildPlanNode) + .outputType(outputType) + .useHashTableCache(true) + .build(); + const auto joinNodeId = joinNode->id(); + + const int numDrivers = 3; + + // Create a shared QueryCtx for all tasks. + auto queryCtx = core::QueryCtx::create( + driverExecutor_.get(), + core::QueryConfig({}), + std::unordered_map>{}, + cache::AsyncDataCache::getInstance(), + nullptr, + nullptr, + queryId); + + // Helper to run a task and return the completed task. + // All tasks use the same queryCtx so they share the cache entry. + auto runTask = [&]() { + return AssertQueryBuilder(joinNode, duckDbQueryRunner_) + .queryCtx(queryCtx) + .maxDrivers(numDrivers) + .assertResults( + "SELECT t.t_k, t.t_v, u.u_k, u.u_v FROM t, u WHERE t.t_k = u.u_k"); + }; + + // Run 10 threads concurrently, each executing 5 tasks sequentially. + constexpr int kNumThreads = 10; + constexpr int kTasksPerThread = 5; + constexpr int kTotalTasks = kNumThreads * kTasksPerThread; + + // Each thread maintains its own local vector of tasks. + std::vector>> threadTasks(kNumThreads); + std::vector threads; + threads.reserve(kNumThreads); + + // Launch all threads at once. + for (int i = 0; i < kNumThreads; ++i) { + threads.emplace_back([&, i]() { + // Each thread runs multiple tasks sequentially into its local vector. + threadTasks[i].reserve(kTasksPerThread); + for (int j = 0; j < kTasksPerThread; ++j) { + threadTasks[i].push_back(runTask()); + } + }); + } + + // Wait for all threads to complete. + for (auto& thread : threads) { + thread.join(); + } + + // Merge all thread-local task vectors into a single vector. + std::vector> allTasks; + allTasks.reserve(kTotalTasks); + for (auto& tasks : threadTasks) { + for (auto& task : tasks) { + allTasks.push_back(std::move(task)); + } + } + + ASSERT_EQ(allTasks.size(), kTotalTasks); + + // Collect stats from all tasks. + int totalCacheMisses = 0; + int totalCacheHits = 0; + + for (const auto& task : allTasks) { + auto opStats = toOperatorStats(task->taskStats()); + ASSERT_EQ(opStats.count("HashBuild"), 1); + auto& hashBuildStats = opStats.at("HashBuild"); + + if (hashBuildStats.runtimeStats.count( + std::string(BaseHashTable::kHashTableCacheMiss))) { + totalCacheMisses += + hashBuildStats.runtimeStats + .at(std::string(BaseHashTable::kHashTableCacheMiss)) + .count; + } + if (hashBuildStats.runtimeStats.count( + std::string(BaseHashTable::kHashTableCacheHit))) { + totalCacheHits += hashBuildStats.runtimeStats + .at(std::string(BaseHashTable::kHashTableCacheHit)) + .count; + } + } + + // Exactly one task should build (cache miss) and all others should reuse + // (cache hits). + EXPECT_EQ(totalCacheMisses, 1) + << "Exactly one task should report a cache miss (the builder)"; + EXPECT_EQ(totalCacheHits, kTotalTasks - 1) + << "All other tasks should report cache hits"; + + // Clean up cache entry before tasks are destroyed. + const auto cacheKey = fmt::format("{}:{}", queryId, joinNodeId); + HashTableCache::instance()->drop(cacheKey); +} + +// Tests that HashBuild and HashProbe cannot reclaim when using a cached hash +// table. When useHashTableCache() is true, canReclaim() returns false for both +// operators because spilling would clear the cached table and corrupt it for +// other tasks. This test uses TestValue to verify canReclaim() returns false. +DEBUG_ONLY_TEST_F(HashJoinWithCacheTest, probeCannotSpillWithCachedTable) { + const std::string queryId = + "probeCannotSpillTest_" + + std::to_string( + std::chrono::steady_clock::now().time_since_epoch().count()); + + // Create build and probe vectors. + std::vector buildVectors = makeBatches(1, [&](int32_t) { + return makeRowVector( + {"u_k", "u_v"}, + { + makeFlatVector(100, [](auto row) { return row % 23; }), + makeFlatVector(100, [](auto row) { return row * 10; }), + }); + }); + + std::vector probeVectors = makeBatches(5, [&](int32_t) { + return makeRowVector( + {"t_k", "t_v"}, + { + makeFlatVector(100, [](auto row) { return row % 23; }), + makeFlatVector(100, [](auto row) { return row; }), + }); + }); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto planNodeIdGenerator = std::make_shared(); + auto buildPlanNode = + PlanBuilder(planNodeIdGenerator).values(buildVectors).planNode(); + auto probePlanNode = + PlanBuilder(planNodeIdGenerator).values(probeVectors).planNode(); + + auto outputType = ROW( + {"t_k", "t_v", "u_k", "u_v"}, {INTEGER(), BIGINT(), INTEGER(), BIGINT()}); + + auto joinNode = + core::HashJoinNode::Builder() + .id(planNodeIdGenerator->next()) + .joinType(core::JoinType::kInner) + .nullAware(false) + .leftKeys( + {std::make_shared(INTEGER(), "t_k")}) + .rightKeys( + {std::make_shared(INTEGER(), "u_k")}) + .left(probePlanNode) + .right(buildPlanNode) + .outputType(outputType) + .useHashTableCache(true) + .build(); + const auto joinNodeId = joinNode->id(); + + auto queryCtx = core::QueryCtx::create( + driverExecutor_.get(), + core::QueryConfig({}), + std::unordered_map>{}, + cache::AsyncDataCache::getInstance(), + nullptr, + nullptr, + queryId); + + // Use TestValue to verify canReclaim() returns false for HashBuild. + std::atomic_bool hashBuildChecked{false}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::HashBuild::addInput", + std::function([&](HashBuild* build) { + if (hashBuildChecked.exchange(true)) { + return; + } + ASSERT_FALSE(build->canReclaim()) + << "HashBuild should not be reclaimable with cached hash table"; + })); + + // Use TestValue to verify canReclaim() returns false for HashProbe. + std::atomic_bool hashProbeChecked{false}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::addInput", + std::function([&](Operator* op) { + if (!isHashProbeMemoryPool(*op->pool())) { + return; + } + if (hashProbeChecked.exchange(true)) { + return; + } + auto* probe = dynamic_cast(op); + ASSERT_NE(probe, nullptr); + ASSERT_FALSE(probe->canReclaim()) + << "HashProbe should not be reclaimable with cached hash table"; + })); + + // Run the query and verify results. + auto task = + AssertQueryBuilder(joinNode, duckDbQueryRunner_) + .queryCtx(queryCtx) + .maxDrivers(1) + .assertResults( + "SELECT t.t_k, t.t_v, u.u_k, u.u_v FROM t, u WHERE t.t_k = u.u_k"); + + // Verify that both operators were checked. + ASSERT_TRUE(hashBuildChecked) << "HashBuild canReclaim check was not reached"; + ASSERT_TRUE(hashProbeChecked) << "HashProbe canReclaim check was not reached"; + + // Verify that HashProbe operator stats show no spilling. + auto opStats = toOperatorStats(task->taskStats()); + ASSERT_EQ(opStats.count("HashProbe"), 1); + auto& probeStats = opStats.at("HashProbe"); + + EXPECT_EQ(probeStats.spilledInputBytes, 0) + << "HashProbe should not spill when using cached hash table"; + EXPECT_EQ(probeStats.spilledBytes, 0) + << "HashProbe should not spill when using cached hash table"; + + // Clean up cache entry before task is destroyed. + const auto cacheKey = fmt::format("{}:{}", queryId, joinNodeId); + HashTableCache::instance()->drop(cacheKey); +} + +// Tests OOM behavior with cached hash tables via memory arbitration. +// This test triggers memory arbitration during HashProbe to verify: +// 1. HashProbe::canReclaim() returns false when useHashTableCache=true +// 2. When an allocation exceeds capacity, arbitration runs but can't reclaim +// 3. OOM is thrown by the arbitration framework +// 4. Cleanup works correctly via QueryCtx release callbacks +DEBUG_ONLY_TEST_F(HashJoinWithCacheTest, probeOOMWithCachedTable) { + const std::string queryId = + "probeOOMTest_" + + std::to_string( + std::chrono::steady_clock::now().time_since_epoch().count()); + + // Create build side with ~1MB hash table. + // 10 batches × 10,000 rows = 100,000 rows × 12 bytes = ~1.2MB raw data. + std::vector buildVectors = makeBatches(10, [&](int32_t) { + return makeRowVector( + {"u_k", "u_v"}, + { + makeFlatVector(10000, [](auto row) { return row % 5000; }), + makeFlatVector(10000, [](auto row) { return row * 10; }), + }); + }); + + // Create probe vectors with matching key range. + std::vector probeVectors = makeBatches(10, [&](int32_t) { + return makeRowVector( + {"t_k", "t_v"}, + { + makeFlatVector(1000, [](auto row) { return row % 5000; }), + makeFlatVector(1000, [](auto row) { return row; }), + }); + }); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto planNodeIdGenerator = std::make_shared(); + auto buildPlanNode = + PlanBuilder(planNodeIdGenerator).values(buildVectors).planNode(); + auto probePlanNode = + PlanBuilder(planNodeIdGenerator).values(probeVectors).planNode(); + + auto outputType = ROW( + {"t_k", "t_v", "u_k", "u_v"}, {INTEGER(), BIGINT(), INTEGER(), BIGINT()}); + + auto joinNode = + core::HashJoinNode::Builder() + .id(planNodeIdGenerator->next()) + .joinType(core::JoinType::kInner) + .nullAware(false) + .leftKeys( + {std::make_shared(INTEGER(), "t_k")}) + .rightKeys( + {std::make_shared(INTEGER(), "u_k")}) + .left(probePlanNode) + .right(buildPlanNode) + .outputType(outputType) + .useHashTableCache(true) + .build(); + + // Create QueryCtx with sufficient memory for build but we'll exhaust it + // during probe via TestValue injection. + auto queryCtx = core::QueryCtx::create( + driverExecutor_.get(), + core::QueryConfig({}), + std::unordered_map>{}, + cache::AsyncDataCache::getInstance(), + nullptr, + nullptr, + queryId); + + // Use a pool with limited capacity. Build uses ~10MB, so 20MB should be + // enough for build but tight for additional allocations during probe. + constexpr int64_t kPoolCapacity = 20 * 1024 * 1024; // 20MB + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), kPoolCapacity, exec::MemoryReclaimer::create())); + + // Use TestValue at the addInput injection point to trigger OOM during + // HashProbe. We allocate more memory than the pool has available, which + // triggers arbitration. Since HashProbe::canReclaim() returns false when + // useHashTableCache=true, arbitration cannot reclaim and OOM is thrown. + std::atomic_bool injected{false}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::addInput", + std::function([&](Operator* op) { + // Only inject once, and only for HashProbe operator. + if (!isHashProbeMemoryPool(*op->pool())) { + return; + } + if (injected.exchange(true)) { + return; + } + + auto* probe = dynamic_cast(op); + ASSERT_NE(probe, nullptr); + + // Verify that HashProbe cannot reclaim when using cached hash table. + // canReclaim() returns false because canSpill() is false. + ASSERT_FALSE(probe->canReclaim()) + << "HashProbe should not be reclaimable with cached hash table"; + + // Allocate memory equal to pool capacity. + // If HashProbe could spill, arbitration would reclaim memory and this + // allocation would succeed. But since canReclaim() returns false with + // cached hash table, arbitration can't free memory and OOM is thrown. + auto* pool = op->pool(); + // This allocation will trigger arbitration and throw OOM. + pool->allocate(kPoolCapacity); + })); + + // This should throw OOM during probe. The cleanup should work correctly. + VELOX_ASSERT_THROW( + AssertQueryBuilder(joinNode, duckDbQueryRunner_) + .queryCtx(queryCtx) + .maxDrivers(1) + .copyResults(pool()), + "Exceeded memory pool capacity"); + + waitForAllTasksToBeDeleted(); + queryCtx.reset(); + // Cache should be cleaned up by QueryCtx destructor via release callback. +} + +// Verifies that the hash table cache pool has a memory reclaimer that properly +// suspends the driver thread during memory arbitration. Without a reclaimer, +// allocations from the cache pool on a driver thread cause: +// "Driver thread is not suspended under memory arbitration processing". +DEBUG_ONLY_TEST_F(HashJoinWithCacheTest, cachePoolArbitration) { + const std::string queryId = + "cachePoolArbitrationTest_" + + std::to_string( + std::chrono::steady_clock::now().time_since_epoch().count()); + + std::vector buildVectors = makeBatches(10, [&](int32_t) { + return makeRowVector( + {"u_k", "u_v"}, + { + makeFlatVector( + 10'000, [](auto row) { return row % 5'000; }), + makeFlatVector(10'000, [](auto row) { return row * 10; }), + }); + }); + + std::vector probeVectors = makeBatches(5, [&](int32_t) { + return makeRowVector( + {"t_k", "t_v"}, + { + makeFlatVector(100, [](auto row) { return row % 5'000; }), + makeFlatVector(100, [](auto row) { return row; }), + }); + }); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto planNodeIdGenerator = std::make_shared(); + auto buildPlanNode = + PlanBuilder(planNodeIdGenerator).values(buildVectors).planNode(); + auto probePlanNode = + PlanBuilder(planNodeIdGenerator).values(probeVectors).planNode(); + + auto outputType = ROW( + {"t_k", "t_v", "u_k", "u_v"}, {INTEGER(), BIGINT(), INTEGER(), BIGINT()}); + + auto joinNode = + core::HashJoinNode::Builder() + .id(planNodeIdGenerator->next()) + .joinType(core::JoinType::kInner) + .nullAware(false) + .leftKeys( + {std::make_shared(INTEGER(), "t_k")}) + .rightKeys( + {std::make_shared(INTEGER(), "u_k")}) + .left(probePlanNode) + .right(buildPlanNode) + .outputType(outputType) + .useHashTableCache(true) + .build(); + const auto joinNodeId = joinNode->id(); + + auto queryCtx = core::QueryCtx::create( + driverExecutor_.get(), + core::QueryConfig({}), + std::unordered_map>{}, + cache::AsyncDataCache::getInstance(), + nullptr, + nullptr, + queryId); + + // Tight pool so the large allocation in the TestValue callback triggers + // memory arbitration via growCapacity. + constexpr int64_t kPoolCapacity = 20 * 1024 * 1024; // 20MB + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), kPoolCapacity, exec::MemoryReclaimer::create())); + + const auto cacheKey = fmt::format("{}:{}", queryId, joinNodeId); + + std::atomic_bool injected{false}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::HashBuild::finishHashBuild", + std::function(([&](exec::HashBuild* hashBuild) { + if (injected.exchange(true)) { + return; + } + // Retrieve the cache entry to access the cache pool. Calling get() + // with the builder task's own ID returns the existing entry without + // side effects. + auto task = hashBuild->operatorCtx()->task(); + ContinueFuture future; + auto entry = HashTableCache::instance()->get( + cacheKey, task->taskId(), task->queryCtx().get(), &future); + + // Allocate enough from the cache pool to trigger growCapacity. + // growCapacity creates a MemoryPoolArbitrationSection on the + // requesting pool, which calls pool->enterArbitration(). The + // reclaimer's enterArbitration() must suspend the driver thread; + // without a reclaimer the driver is not suspended and the + // arbitration state check fails. + entry->tablePool->allocate(kPoolCapacity); + }))); + + // The query throws because the allocation exceeds capacity. + // With a proper reclaimer on the cache pool, the error is OOM + // ("Exceeded memory pool capacity"), not "Driver thread is not suspended". + VELOX_ASSERT_THROW( + AssertQueryBuilder(joinNode, duckDbQueryRunner_) + .queryCtx(queryCtx) + .maxDrivers(1) + .copyResults(pool()), + "Exceeded memory pool capacity"); + + ASSERT_TRUE(injected) << "TestValue injection was not triggered"; + + waitForAllTasksToBeDeleted(); + queryCtx.reset(); +} + +} // namespace + +// Reproduces a spin-loop bug where non-last drivers of the HashBuild pipeline +// never reach kFinish when hash table caching is enabled. +// +// Root cause: after allPeersFinished() returns false, non-last drivers were +// placed in kWaitForBuild state. When they wake from the peers-finished future, +// isBlocked() enters the kWaitForBuild case and calls +// receivedCachedHashTable(), which does setRunning() + noMoreInput(). +// noMoreInput() returns early (noMoreInput_ is already true), so the operator +// stays in kRunning forever -- never reaching kFinish. The driver loop spins +// calling isBlocked/getOutput/isFinished in a tight loop, burning CPU until the +// task is terminated by the probe pipeline completing. +// +// Fix (D100200527): in finishHashBuild(), when allPeersFinished() returns false +// and hash table caching is enabled, set kWaitForProbe instead of +// kWaitForBuild. This skips receivedCachedHashTable() and goes directly to +// postHashBuildProcess() -> kFinish. +// +// Detection: a TestValue fires at the exact point where the state is set for +// each non-last driver. We verify the state is kWaitForProbe (the fix), not +// kWaitForBuild (the bug). This check is deterministic: every non-last driver +// hits the TestValue exactly once, regardless of thread scheduling. +DEBUG_ONLY_TEST_F(HashJoinWithCacheTest, nonLastDriverSpinLoop) { + const std::string queryId = + "spinLoopTest_" + + std::to_string( + std::chrono::steady_clock::now().time_since_epoch().count()); + + std::vector buildVectors = makeBatches(1, [&](int32_t) { + return makeRowVector( + {"u_k", "u_v"}, + { + makeFlatVector(50, [](auto row) { return row % 31; }), + makeFlatVector(50, [](auto row) { return row * 10; }), + }); + }); + + std::vector probeVectors = makeBatches(3, [&](int32_t) { + return makeRowVector( + {"t_k", "t_v"}, + { + makeFlatVector(100, [](auto row) { return row % 23; }), + makeFlatVector(100, [](auto row) { return row; }), + }); + }); + + auto planNodeIdGenerator = std::make_shared(); + // The build ValuesNode must be parallelizable so the build pipeline gets + // multiple drivers. + auto buildPlanNode = PlanBuilder(planNodeIdGenerator) + .values(buildVectors, /*parallelizable=*/true) + .planNode(); + auto probePlanNode = + PlanBuilder(planNodeIdGenerator).values(probeVectors).planNode(); + + auto outputType = ROW( + {"t_k", "t_v", "u_k", "u_v"}, {INTEGER(), BIGINT(), INTEGER(), BIGINT()}); + + auto joinNode = + core::HashJoinNode::Builder() + .id(planNodeIdGenerator->next()) + .joinType(core::JoinType::kInner) + .nullAware(false) + .leftKeys( + {std::make_shared(INTEGER(), "t_k")}) + .rightKeys( + {std::make_shared(INTEGER(), "u_k")}) + .left(probePlanNode) + .right(buildPlanNode) + .outputType(outputType) + .useHashTableCache(true) + .build(); + const auto joinNodeId = joinNode->id(); + + const int numDrivers = 4; + + auto queryCtx = core::QueryCtx::create( + driverExecutor_.get(), + core::QueryConfig({}), + std::unordered_map>{}, + cache::AsyncDataCache::getInstance(), + nullptr, + nullptr, + queryId); + + // Track non-last drivers via the finishHashBuild TestValue. This fires + // right after setState(kWaitForBuild) for each non-last driver. We verify + // the allPeersFinished coordination works (expected count = numDrivers - 1). + // + // The spin loop bug was in receivedCachedHashTable(): when non-last drivers + // woke from the peers-finished future and re-entered isBlocked(), the + // function would call setRunning() + noMoreInput() (a no-op), leaving the + // operator in kRunning forever. The fix adds a guard in + // receivedCachedHashTable() via hashTableCacheBuilderTask() so it returns + // false for builder tasks, letting the driver fall through to + // postHashBuildProcess -> kFinish. + std::atomic_int lastDriverCount{0}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::HashBuild::finishHashBuild", + std::function([&](HashBuild*) { ++lastDriverCount; })); + + // Run the builder task. Skip DuckDB comparison — parallelizable build source + // produces numDrivers copies of build data. This test checks the state + // machine, not result correctness (other tests cover that). + auto task = AssertQueryBuilder(joinNode) + .queryCtx(queryCtx) + .maxDrivers(numDrivers) + .copyResults(pool()); + + // With numDrivers drivers, exactly 1 is the last driver. + EXPECT_EQ(lastDriverCount.load(), 1) + << "Expected exactly 1 last driver, got " << lastDriverCount.load(); + + // Verify the task completed successfully — if the spin loop bug were + // present, non-last drivers would spin in kRunning until the task is + // killed, wasting CPU. The task completing with correct output and + // exactly one last driver confirms the fix works. + + const auto cacheKey = fmt::format("{}:{}", queryId, joinNodeId); + HashTableCache::instance()->drop(cacheKey); +} + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/HashTableCacheTest.cpp b/velox/exec/tests/HashTableCacheTest.cpp new file mode 100644 index 00000000000..3736178af42 --- /dev/null +++ b/velox/exec/tests/HashTableCacheTest.cpp @@ -0,0 +1,276 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/HashTableCache.h" + +#include + +#include "velox/common/caching/AsyncDataCache.h" +#include "velox/common/memory/Memory.h" +#include "velox/core/QueryCtx.h" + +namespace facebook::velox::exec::test { + +class HashTableCacheTest : public testing::Test { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance({}); + } + + void SetUp() override { + pool_ = memory::memoryManager()->addRootPool("HashTableCacheTest"); + queryCtx_ = core::QueryCtx::create(); + } + + void TearDown() override { + // Clean up any cache entries created during tests. + auto* cache = HashTableCache::instance(); + for (const auto& key : createdKeys_) { + cache->drop(key); + } + createdKeys_.clear(); + queryCtx_.reset(); + } + + // Helper to track keys for cleanup. + void trackKey(const std::string& key) { + createdKeys_.push_back(key); + } + + std::shared_ptr pool_; + std::shared_ptr queryCtx_; + std::vector createdKeys_; +}; + +TEST_F(HashTableCacheTest, basicGet) { + auto* cache = HashTableCache::instance(); + const std::string key = "query1:node1"; + trackKey(key); + + ContinueFuture future = ContinueFuture::makeEmpty(); + auto entry = cache->get(key, "task1", queryCtx_.get(), &future); + + ASSERT_NE(entry, nullptr); + ASSERT_NE(entry->tablePool, nullptr); + EXPECT_EQ(entry->builderTaskId, "task1"); + EXPECT_FALSE(entry->buildComplete); + // First caller should not get a future (they are the builder). + EXPECT_FALSE(future.valid()); +} + +TEST_F(HashTableCacheTest, secondCallerGetsWaitFuture) { + auto* cache = HashTableCache::instance(); + const std::string key = "query2:node1"; + trackKey(key); + + // First caller (builder). + ContinueFuture future1 = ContinueFuture::makeEmpty(); + auto entry1 = cache->get(key, "task1", queryCtx_.get(), &future1); + EXPECT_FALSE(future1.valid()); + EXPECT_EQ(entry1->builderTaskId, "task1"); + + // Second caller (waiter). + ContinueFuture future2 = ContinueFuture::makeEmpty(); + auto entry2 = cache->get(key, "task2", queryCtx_.get(), &future2); + + // Should get the same entry. + EXPECT_EQ(entry1, entry2); + // Should get a valid future to wait on. + EXPECT_TRUE(future2.valid()); + // Builder task ID should not change. + EXPECT_EQ(entry2->builderTaskId, "task1"); +} + +TEST_F(HashTableCacheTest, putNotifiesWaiters) { + auto* cache = HashTableCache::instance(); + const std::string key = "query3:node1"; + trackKey(key); + + // First caller (builder). + ContinueFuture future1 = ContinueFuture::makeEmpty(); + auto entry = cache->get(key, "task1", queryCtx_.get(), &future1); + + // Second caller (waiter). + ContinueFuture future2 = ContinueFuture::makeEmpty(); + cache->get(key, "task2", queryCtx_.get(), &future2); + ASSERT_TRUE(future2.valid()); + + // Put the table. + cache->put(key, nullptr, false); + + // Entry should now be marked complete. + EXPECT_TRUE(entry->buildComplete); + + // The future should be fulfilled. + EXPECT_TRUE(future2.isReady()); +} + +TEST_F(HashTableCacheTest, getAfterBuildComplete) { + auto* cache = HashTableCache::instance(); + const std::string key = "query4:node1"; + trackKey(key); + + // First caller creates and builds. + ContinueFuture future1 = ContinueFuture::makeEmpty(); + cache->get(key, "task1", queryCtx_.get(), &future1); + cache->put(key, nullptr, true); + + // Later caller should get completed entry without waiting. + ContinueFuture future2 = ContinueFuture::makeEmpty(); + auto entry = cache->get(key, "task2", queryCtx_.get(), &future2); + + EXPECT_TRUE(entry->buildComplete); + EXPECT_FALSE(future2.valid()); // No need to wait. + EXPECT_TRUE(entry->hasNullKeys); +} + +TEST_F(HashTableCacheTest, drop) { + auto* cache = HashTableCache::instance(); + const std::string key = "query5:node1"; + // Don't track - we're testing drop. + + ContinueFuture future = ContinueFuture::makeEmpty(); + auto entry1 = cache->get(key, "task1", queryCtx_.get(), &future); + ASSERT_NE(entry1, nullptr); + + // Keep track of the original pool to verify it's different after flush. + auto originalPool = entry1->tablePool; + + // Drop the entry. + cache->drop(key); + + // Getting the same key should create a new entry. + // Use a new queryCtx with different pool to avoid the leaf child name + // collision. + auto pool2 = memory::memoryManager()->addRootPool("HashTableCacheTest2"); + auto queryCtx2 = core::QueryCtx::create( + nullptr, // executor + core::QueryConfig{{}}, // queryConfig + {}, // connectorConfigs + cache::AsyncDataCache::getInstance(), // cache + pool2); + ContinueFuture future2 = ContinueFuture::makeEmpty(); + auto entry2 = cache->get(key, "task2", queryCtx2.get(), &future2); + + EXPECT_NE(entry1, entry2); + EXPECT_EQ(entry2->builderTaskId, "task2"); + // The pool should be different (created under queryCtx2's pool). + EXPECT_NE(entry1->tablePool, entry2->tablePool); + + // Cleanup. + cache->drop(key); +} + +TEST_F(HashTableCacheTest, concurrentWaiters) { + auto* cache = HashTableCache::instance(); + const std::string key = "query6:node1"; + trackKey(key); + + // First caller (builder). + ContinueFuture builderFuture = ContinueFuture::makeEmpty(); + auto entry = cache->get(key, "builder", queryCtx_.get(), &builderFuture); + + // Multiple waiters. + constexpr int kNumWaiters = 5; + std::vector waiterFutures(kNumWaiters); + + for (int i = 0; i < kNumWaiters; ++i) { + waiterFutures[i] = ContinueFuture::makeEmpty(); + cache->get( + key, fmt::format("waiter{}", i), queryCtx_.get(), &waiterFutures[i]); + EXPECT_TRUE(waiterFutures[i].valid()); + } + + // Put the table. + cache->put(key, nullptr, false); + + // All waiters should be notified. + for (int i = 0; i < kNumWaiters; ++i) { + EXPECT_TRUE(waiterFutures[i].isReady()); + } +} + +TEST_F(HashTableCacheTest, builderTaskDriversDoNotWait) { + auto* cache = HashTableCache::instance(); + const std::string key = "query7:node1"; + trackKey(key); + + // First driver of builder task. + ContinueFuture future1 = ContinueFuture::makeEmpty(); + auto entry1 = cache->get(key, "task1", queryCtx_.get(), &future1); + EXPECT_FALSE(future1.valid()); + + // Second driver of the same builder task should not wait. + ContinueFuture future2 = ContinueFuture::makeEmpty(); + auto entry2 = cache->get(key, "task1", queryCtx_.get(), &future2); + + EXPECT_EQ(entry1, entry2); + // Same task should not get a future - they coordinate via JoinBridge. + EXPECT_FALSE(future2.valid()); +} + +// Verifies that when the builder task fails (e.g., OOM in finishHashBuild()), +// dropping the cache entry unblocks all waiting tasks and allows new tasks +// to become builders. +TEST_F(HashTableCacheTest, builderFailureUnblocksWaiters) { + auto* cache = HashTableCache::instance(); + const std::string key = "query_oom:node1"; + + // Builder task gets the entry. + ContinueFuture builderFuture = ContinueFuture::makeEmpty(); + auto builderEntry = + cache->get(key, "builder_task", queryCtx_.get(), &builderFuture); + ASSERT_FALSE(builderFuture.valid()); + + // Two waiter tasks arrive while builder is working. + ContinueFuture waiterFuture1 = ContinueFuture::makeEmpty(); + cache->get(key, "waiter_task_1", queryCtx_.get(), &waiterFuture1); + ASSERT_TRUE(waiterFuture1.valid()); + + ContinueFuture waiterFuture2 = ContinueFuture::makeEmpty(); + cache->get(key, "waiter_task_2", queryCtx_.get(), &waiterFuture2); + ASSERT_TRUE(waiterFuture2.valid()); + + // Simulate builder task failure: close() now calls drop() when the builder + // fails without completing the build. + cache->drop(key); + builderEntry.reset(); + + // Both waiters should be unblocked. + EXPECT_TRUE(waiterFuture1.isReady()) + << "Pre-existing waiter should be unblocked after builder failure"; + EXPECT_TRUE(waiterFuture2.isReady()) + << "Second waiter should also be unblocked after builder failure"; + + // A new task arriving after drop should become a fresh builder. + auto pool2 = memory::memoryManager()->addRootPool("HashTableCacheTest_retry"); + auto queryCtx2 = core::QueryCtx::create( + nullptr, + core::QueryConfig{{}}, + {}, + cache::AsyncDataCache::getInstance(), + pool2); + ContinueFuture newFuture = ContinueFuture::makeEmpty(); + auto newEntry = cache->get(key, "new_builder", queryCtx2.get(), &newFuture); + EXPECT_FALSE(newFuture.valid()) + << "New task after drop should become builder, not waiter"; + EXPECT_EQ(newEntry->builderTaskId, "new_builder"); + EXPECT_FALSE(newEntry->buildComplete); + + cache->drop(key); +} + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/HashTableTest.cpp b/velox/exec/tests/HashTableTest.cpp index bee323d38f8..5466bdf149d 100644 --- a/velox/exec/tests/HashTableTest.cpp +++ b/velox/exec/tests/HashTableTest.cpp @@ -15,7 +15,7 @@ */ #include "velox/exec/HashTable.h" -#include "folly/experimental/EventCount.h" +#include "folly/synchronization/EventCount.h" #include "velox/common/base/SelectivityInfo.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/testutil/TestValue.h" @@ -132,11 +132,18 @@ class HashTableTest : public testing::TestWithParam, std::vector batches; std::vector> keyHashers; for (auto channel = 0; channel < numKeys; ++channel) { - keyHashers.emplace_back(std::make_unique( - buildType->childAt(channel), channel)); + keyHashers.emplace_back( + std::make_unique( + buildType->childAt(channel), channel)); } auto table = HashTable::createForJoin( - std::move(keyHashers), dependentTypes, true, false, 1'000, pool()); + std::move(keyHashers), + dependentTypes, + true, + false, + false, + 1'000, + pool()); makeRows(size, 1, sequence, buildType, batches); copyVectorsToTable(batches, startOffset, table.get()); @@ -158,6 +165,8 @@ class HashTableTest : public testing::TestWithParam, topTable_->prepareJoinTable( std::move(otherTables), BaseHashTable::kNoSpillInputStartPartitionBit, + 1'000'000, + false, executor_.get()); ASSERT_GE( estimatedTableSize, @@ -431,8 +440,9 @@ class HashTableTest : public testing::TestWithParam, TypePtr buildType, std::vector& batches) { for (auto i = 0; i < numBatches; ++i) { - batches.push_back(std::static_pointer_cast( - makeVector(buildType, batchSize, sequence))); + batches.push_back( + std::static_pointer_cast( + makeVector(buildType, batchSize, sequence))); sequence += batchSize; } } @@ -539,10 +549,14 @@ class HashTableTest : public testing::TestWithParam, std::vector> hashers; hashers.push_back(std::make_unique(keys->type(), 0)); auto table = HashTable::createForJoin( - std::move(hashers), {BIGINT()}, true, false, 1'000, pool()); + std::move(hashers), {BIGINT()}, true, false, false, 1'000, pool()); copyVectorsToTable({batch}, 0, table.get()); table->prepareJoinTable( - {}, BaseHashTable::kNoSpillInputStartPartitionBit, executor_.get()); + {}, + BaseHashTable::kNoSpillInputStartPartitionBit, + 1'000'000, + false, + executor_.get()); ASSERT_EQ(table->hashMode(), mode); std::vector rows(nullValues.size()); BaseHashTable::NullKeyRowsIterator iter; @@ -835,12 +849,16 @@ TEST_P(HashTableTest, regularHashingTableSize) { std::make_unique(type->childAt(channel), channel)); } auto table = HashTable::createForJoin( - std::move(keyHashers), {}, true, false, 1'000, pool()); + std::move(keyHashers), {}, true, false, false, 1'000, pool()); std::vector batches; makeRows(1 << 12, 1, 0, type, batches); copyVectorsToTable(batches, 0, table.get()); table->prepareJoinTable( - {}, BaseHashTable::kNoSpillInputStartPartitionBit, executor_.get()); + {}, + BaseHashTable::kNoSpillInputStartPartitionBit, + 1'000'000, + false, + executor_.get()); ASSERT_EQ(table->hashMode(), mode); EXPECT_GE(table->testingRehashSize(), table->numDistinct()); }; @@ -872,6 +890,7 @@ TEST_P(HashTableTest, listJoinResultsSize) { {BIGINT(), VARCHAR()}, true, false, + false, kNumRows, pool()); std::vector batches; @@ -1026,7 +1045,7 @@ DEBUG_ONLY_TEST_P(HashTableTest, nextBucketOffset) { std::make_unique(type->childAt(channel), channel)); } auto table = HashTable::createForJoin( - std::move(keyHashers), {}, true, false, 1'000, pool()); + std::move(keyHashers), {}, true, false, false, 1'000, pool()); auto testHelper = HashTableTestHelper::create(table.get()); const uint64_t numDistincts = bits::nextPowerOfTwo( 2UL * std::numeric_limits::max() / testHelper.tableSlotSize()); @@ -1133,7 +1152,7 @@ DEBUG_ONLY_TEST_P(HashTableTest, failureInCreateRowPartitions) { // Set minTableSizeForParallelJoinBuild to be really small so we can trigger // a parallel join build without needing a lot of data. auto table = HashTable::createForJoin( - std::move(hashers), {BIGINT()}, true, false, 1, pool()); + std::move(hashers), {BIGINT()}, true, false, false, 1, pool()); copyVectorsToTable({batch}, 0, table.get()); if (topTable == nullptr) { @@ -1146,6 +1165,8 @@ DEBUG_ONLY_TEST_P(HashTableTest, failureInCreateRowPartitions) { topTable->prepareJoinTable( std::move(otherTables), BaseHashTable::kNoSpillInputStartPartitionBit, + 1'000'000, + false, executor_.get()); auto topTabletestHelper = HashTableTestHelper::create(topTable.get()); @@ -1225,6 +1246,7 @@ TEST_P(HashTableTest, toStringSingleKey) { {}, /*dependentTypes*/ true /*allowDuplicates*/, false /*hasProbedFlag*/, + false /*hasCountFlag*/, 1 /*minTableSizeForParallelJoinBuild*/, pool()); @@ -1234,7 +1256,8 @@ TEST_P(HashTableTest, toStringSingleKey) { store(*table->rows(), data); - table->prepareJoinTable({}, BaseHashTable::kNoSpillInputStartPartitionBit); + table->prepareJoinTable( + {}, BaseHashTable::kNoSpillInputStartPartitionBit, 1'000'000); ASSERT_NO_THROW(table->toString()); ASSERT_NO_THROW(table->toString(0)); @@ -1253,6 +1276,7 @@ TEST_P(HashTableTest, toStringMultipleKeys) { {}, /*dependentTypes*/ true /*allowDuplicates*/, false /*hasProbedFlag*/, + false /*hasCountFlag*/, 1 /*minTableSizeForParallelJoinBuild*/, pool()); @@ -1265,16 +1289,18 @@ TEST_P(HashTableTest, toStringMultipleKeys) { store(*table->rows(), data); - table->prepareJoinTable({}, BaseHashTable::kNoSpillInputStartPartitionBit); + table->prepareJoinTable( + {}, BaseHashTable::kNoSpillInputStartPartitionBit, 1'000'000); ASSERT_NO_THROW(table->toString()); } TEST(HashTableTest, tableInsertPartitionInfo) { std::vector overflows; + std::vector overflowHashes; const auto testFn = [&](PartitionBoundIndexType start, PartitionBoundIndexType end) { - TableInsertPartitionInfo info{start, end, overflows}; + TableInsertPartitionInfo info{start, end, overflows, overflowHashes}; }; struct { PartitionBoundIndexType start; @@ -1290,8 +1316,9 @@ TEST(HashTableTest, tableInsertPartitionInfo) { VELOX_ASSERT_THROW(testFn(badData.start, badData.end), ""); } ASSERT_TRUE(overflows.empty()); + ASSERT_TRUE(overflowHashes.empty()); - TableInsertPartitionInfo info{1, 1000, overflows}; + TableInsertPartitionInfo info{1, 1000, overflows, overflowHashes}; ASSERT_TRUE(info.inRange(1)); ASSERT_FALSE(info.inRange(0)); ASSERT_FALSE(info.inRange(-1)); @@ -1299,17 +1326,23 @@ TEST(HashTableTest, tableInsertPartitionInfo) { ASSERT_FALSE(info.inRange(1'000)); ASSERT_FALSE(info.inRange(12'000)); ASSERT_TRUE(overflows.empty()); + ASSERT_TRUE(overflowHashes.empty()); const std::vector insertBuffers{100, 200, 300, 500}; - for (const auto insertBuffer : insertBuffers) { - info.addOverflow(reinterpret_cast(insertBuffer)); + const std::vector insertHashes{0xAA, 0xBB, 0xCC, 0xDD}; + for (int i = 0; i < insertBuffers.size(); ++i) { + info.addOverflow( + reinterpret_cast(insertBuffers[i]), insertHashes[i]); } ASSERT_EQ(overflows.size(), insertBuffers.size()); + ASSERT_EQ(overflowHashes.size(), insertHashes.size()); for (int i = 0; i < insertBuffers.size(); ++i) { ASSERT_EQ(insertBuffers[i], reinterpret_cast(info.overflows[i])); + ASSERT_EQ(insertHashes[i], info.overflowHashes[i]); } for (int i = 0; i < overflows.size(); ++i) { ASSERT_EQ(overflows[i], info.overflows[i]); + ASSERT_EQ(overflowHashes[i], info.overflowHashes[i]); } } } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/HilbertIndexTest.cpp b/velox/exec/tests/HilbertIndexTest.cpp new file mode 100644 index 00000000000..85cc634f253 --- /dev/null +++ b/velox/exec/tests/HilbertIndexTest.cpp @@ -0,0 +1,174 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/HilbertIndex.h" +#include +#include + +using namespace ::testing; +using namespace facebook::velox::exec; + +namespace facebook::velox::exec::test { + +class HilbertIndexTest : public virtual testing::Test {}; + +TEST_F(HilbertIndexTest, testOrder) { + HilbertIndex hilbert(0, 0, 4, 4); + + uint32_t h0 = hilbert.indexOf(0.0, 0.0); + uint32_t h1 = hilbert.indexOf(1.0, 1.0); + uint32_t h2 = hilbert.indexOf(1.0, 3.0); + uint32_t h3 = hilbert.indexOf(3.0, 3.0); + uint32_t h4 = hilbert.indexOf(3.0, 1.0); + + ASSERT_LT(h0, h1); + ASSERT_LT(h1, h2); + ASSERT_LT(h2, h3); + ASSERT_LT(h3, h4); +} + +TEST_F(HilbertIndexTest, testOutOfBounds) { + HilbertIndex hilbert(0, 0, 1, 1); + + ASSERT_EQ(hilbert.indexOf(2.0, 2.0), std::numeric_limits::max()); +} + +TEST_F(HilbertIndexTest, testDegenerateRectangle) { + HilbertIndex hilbert(0, 0, 0, 0); + + ASSERT_EQ(hilbert.indexOf(0.0, 0.0), 0); + ASSERT_EQ(hilbert.indexOf(2.0, 2.0), std::numeric_limits::max()); +} + +TEST_F(HilbertIndexTest, testDegenerateHorizontalRectangle) { + HilbertIndex hilbert(0, 0, 4, 0); + + ASSERT_EQ(hilbert.indexOf(0.0, 0.0), 0); + ASSERT_LT(hilbert.indexOf(1.0, 0.0), hilbert.indexOf(2.0, 0.0)); + ASSERT_EQ(hilbert.indexOf(0.0, 2.0), std::numeric_limits::max()); + ASSERT_EQ(hilbert.indexOf(2.0, 2.0), std::numeric_limits::max()); +} + +TEST_F(HilbertIndexTest, testDegenerateVerticalRectangle) { + HilbertIndex hilbert(0, 0, 0, 4); + + ASSERT_EQ(hilbert.indexOf(0.0, 0.0), 0); + ASSERT_LT(hilbert.indexOf(0.0, 1.0), hilbert.indexOf(0.0, 2.0)); + ASSERT_EQ(hilbert.indexOf(2.0, 0.0), std::numeric_limits::max()); + ASSERT_EQ(hilbert.indexOf(2.0, 2.0), std::numeric_limits::max()); +} + +TEST_F(HilbertIndexTest, testNegativeCoordinates) { + HilbertIndex hilbert(-10, -10, 10, 10); + + uint32_t h0 = hilbert.indexOf(-5.0, -5.0); + uint32_t h1 = hilbert.indexOf(0.0, 0.0); + uint32_t h2 = hilbert.indexOf(5.0, 5.0); + + ASSERT_LT(h0, h1); + ASSERT_LT(h1, h2); + + ASSERT_EQ( + hilbert.indexOf(-15.0, -15.0), std::numeric_limits::max()); + ASSERT_EQ(hilbert.indexOf(15.0, 15.0), std::numeric_limits::max()); +} + +TEST_F(HilbertIndexTest, testFloatingPointPrecision) { + HilbertIndex hilbert(0, 0, 1, 1); + + uint32_t h1 = hilbert.indexOf(0.1, 0.1); + uint32_t h2 = hilbert.indexOf(0.2, 0.2); + uint32_t h3 = hilbert.indexOf(0.9, 0.9); + + ASSERT_LT(h1, h2); + ASSERT_LT(h2, h3); +} + +TEST_F(HilbertIndexTest, testBoundaryPoints) { + HilbertIndex hilbert(0, 0, 10, 10); + + uint32_t h0 = hilbert.indexOf(0.0, 0.0); + uint32_t h1 = hilbert.indexOf(10.0, 10.0); + uint32_t h2 = hilbert.indexOf(0.0, 10.0); + // Bottom-right corner is at the end of the range, so may be MAX + + ASSERT_NE(h0, std::numeric_limits::max()); + ASSERT_NE(h1, std::numeric_limits::max()); + ASSERT_NE(h2, std::numeric_limits::max()); +} + +TEST_F(HilbertIndexTest, testLargeCoordinates) { + HilbertIndex hilbert(0, 0, 1000000, 1000000); + + uint32_t h1 = hilbert.indexOf(100000, 100000); + uint32_t h2 = hilbert.indexOf(500000, 500000); + uint32_t h3 = hilbert.indexOf(900000, 900000); + + ASSERT_LT(h1, h2); + ASSERT_LT(h2, h3); +} + +TEST_F(HilbertIndexTest, testDensityInSmallRegion) { + HilbertIndex hilbert(0, 0, 100, 100); + + std::vector indices; + for (int i = 0; i < 10; ++i) { + for (int j = 0; j < 10; ++j) { + indices.push_back(hilbert.indexOf(i * 10.0f + 5.0f, j * 10.0f + 5.0f)); + } + } + + std::set uniqueIndices(indices.begin(), indices.end()); + ASSERT_EQ(indices.size(), 100); + ASSERT_GT(uniqueIndices.size(), 90); +} + +TEST_F(HilbertIndexTest, testSpatialLocality) { + HilbertIndex hilbert(0, 0, 100, 100); + + uint32_t h1 = hilbert.indexOf(50.0, 50.0); + uint32_t h2 = hilbert.indexOf(50.1, 50.1); + uint32_t h3 = hilbert.indexOf(50.2, 50.2); + uint32_t h4 = hilbert.indexOf(90.0, 90.0); + + uint32_t diff12 = std::abs(static_cast(h1 - h2)); + uint32_t diff23 = std::abs(static_cast(h2 - h3)); + uint32_t diff14 = std::abs(static_cast(h1 - h4)); + + ASSERT_LT(diff12, diff14); + ASSERT_LT(diff23, diff14); +} + +TEST_F(HilbertIndexTest, testIdenticalPoints) { + HilbertIndex hilbert(0, 0, 10, 10); + + uint32_t h1 = hilbert.indexOf(5.0, 5.0); + uint32_t h2 = hilbert.indexOf(5.0, 5.0); + + ASSERT_EQ(h1, h2); +} + +TEST_F(HilbertIndexTest, testExtremelySmallBounds) { + HilbertIndex hilbert(0, 0, 0.001, 0.001); + + uint32_t h1 = hilbert.indexOf(0.0, 0.0); + uint32_t h2 = hilbert.indexOf(0.0005, 0.0005); + + ASSERT_NE(h1, std::numeric_limits::max()); + ASSERT_NE(h2, std::numeric_limits::max()); +} + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/IndexLookupJoinBridgeTest.cpp b/velox/exec/tests/IndexLookupJoinBridgeTest.cpp new file mode 100644 index 00000000000..63ca4823748 --- /dev/null +++ b/velox/exec/tests/IndexLookupJoinBridgeTest.cpp @@ -0,0 +1,264 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/exec/IndexLookupJoinBridge.h" + +#include +#include +#include + +#include "velox/common/base/tests/GTestUtils.h" + +using namespace facebook::velox; +using namespace facebook::velox::exec; +using namespace facebook::velox::connector; + +namespace { + +std::shared_ptr makeFakeSplit(const std::string& id) { + return std::make_shared(id); +} + +std::vector> makeFakeSplits(int count) { + std::vector> splits; + splits.reserve(count); + for (int i = 0; i < count; ++i) { + splits.push_back(makeFakeSplit("connector_" + std::to_string(i))); + } + return splits; +} + +class IndexLookupJoinBridgeTest : public testing::Test { + protected: + std::shared_ptr createBridge() { + auto bridge = std::make_shared(); + bridge->start(); + return bridge; + } +}; + +TEST_F(IndexLookupJoinBridgeTest, setAndGetSplits) { + auto bridge = createBridge(); + auto splits = makeFakeSplits(3); + const auto expectedSize = splits.size(); + + bridge->setIndexSplits(std::move(splits)); + + ContinueFuture future = ContinueFuture::makeEmpty(); + auto result = bridge->splitsOrFuture(&future); + ASSERT_FALSE(future.valid()); + ASSERT_EQ(result.size(), expectedSize); +} + +TEST_F(IndexLookupJoinBridgeTest, splitsOrFutureBlocks) { + auto bridge = createBridge(); + + ContinueFuture future = ContinueFuture::makeEmpty(); + auto result = bridge->splitsOrFuture(&future); + ASSERT_TRUE(future.valid()); + ASSERT_TRUE(result.empty()); + + // Set splits and verify the future is fulfilled. + auto splits = makeFakeSplits(2); + bridge->setIndexSplits(std::move(splits)); + + std::move(future).via(&folly::InlineExecutor::instance()).get(); + + // Now splitsOrFuture should return immediately on repeated calls. + for (int i = 0; i < 3; ++i) { + ContinueFuture future2 = ContinueFuture::makeEmpty(); + auto result2 = bridge->splitsOrFuture(&future2); + ASSERT_FALSE(future2.valid()); + ASSERT_EQ(result2.size(), 2); + } +} + +TEST_F(IndexLookupJoinBridgeTest, multipleFollowers) { + auto bridge = createBridge(); + + // Multiple followers wait for splits. + const int numFollowers = 4; + std::vector futures; + for (int i = 0; i < numFollowers; ++i) { + ContinueFuture future = ContinueFuture::makeEmpty(); + auto result = bridge->splitsOrFuture(&future); + ASSERT_TRUE(future.valid()); + ASSERT_TRUE(result.empty()); + futures.push_back(std::move(future)); + } + + // Leader sets splits. + bridge->setIndexSplits(makeFakeSplits(5)); + + // All futures should be fulfilled. + for (auto& future : futures) { + std::move(future).via(&folly::InlineExecutor::instance()).get(); + } + + // All followers can now get splits. + for (int i = 0; i < numFollowers; ++i) { + ContinueFuture future = ContinueFuture::makeEmpty(); + auto result = bridge->splitsOrFuture(&future); + ASSERT_FALSE(future.valid()); + ASSERT_EQ(result.size(), 5); + } +} + +TEST_F(IndexLookupJoinBridgeTest, setBeforeGet) { + auto bridge = createBridge(); + + // Leader sets splits before any follower calls splitsOrFuture. + bridge->setIndexSplits(makeFakeSplits(3)); + + // Followers get splits immediately without blocking. + for (int i = 0; i < 3; ++i) { + ContinueFuture future = ContinueFuture::makeEmpty(); + auto result = bridge->splitsOrFuture(&future); + ASSERT_FALSE(future.valid()); + ASSERT_EQ(result.size(), 3); + } +} + +TEST_F(IndexLookupJoinBridgeTest, setTwiceThrows) { + auto bridge = createBridge(); + bridge->setIndexSplits(makeFakeSplits(1)); + + VELOX_ASSERT_THROW( + bridge->setIndexSplits(makeFakeSplits(1)), + "setIndexSplits must be called only once"); +} + +TEST_F(IndexLookupJoinBridgeTest, setEmptyTwiceThrows) { + auto bridge = createBridge(); + bridge->setIndexSplits({}); + + VELOX_ASSERT_THROW( + bridge->setIndexSplits({}), "setIndexSplits must be called only once"); +} + +TEST_F(IndexLookupJoinBridgeTest, setEmptySplits) { + auto bridge = createBridge(); + bridge->setIndexSplits({}); + + ContinueFuture future = ContinueFuture::makeEmpty(); + auto result = bridge->splitsOrFuture(&future); + ASSERT_FALSE(future.valid()); + ASSERT_TRUE(result.empty()); +} + +TEST_F(IndexLookupJoinBridgeTest, emptySplitsUnblockFollowers) { + auto bridge = createBridge(); + + ContinueFuture future = ContinueFuture::makeEmpty(); + auto result = bridge->splitsOrFuture(&future); + ASSERT_TRUE(future.valid()); + ASSERT_TRUE(result.empty()); + + bridge->setIndexSplits({}); + + std::move(future).via(&folly::InlineExecutor::instance()).get(); + + ContinueFuture future2 = ContinueFuture::makeEmpty(); + auto result2 = bridge->splitsOrFuture(&future2); + ASSERT_FALSE(future2.valid()); + ASSERT_TRUE(result2.empty()); +} + +TEST_F(IndexLookupJoinBridgeTest, setBeforeStartThrows) { + auto bridge = std::make_shared(); + // Do not call start(). + + VELOX_ASSERT_THROW( + bridge->setIndexSplits(makeFakeSplits(1)), + "Bridge must be started before setting index splits"); +} + +TEST_F(IndexLookupJoinBridgeTest, cancelUnblocksFollowers) { + auto bridge = createBridge(); + + ContinueFuture future = ContinueFuture::makeEmpty(); + auto result = bridge->splitsOrFuture(&future); + ASSERT_TRUE(future.valid()); + + bridge->cancel(); + + // Future should be fulfilled by cancel. + std::move(future).via(&folly::InlineExecutor::instance()).get(); +} + +TEST_F(IndexLookupJoinBridgeTest, getAfterCancelThrows) { + auto bridge = createBridge(); + bridge->cancel(); + + ContinueFuture future = ContinueFuture::makeEmpty(); + VELOX_ASSERT_THROW( + bridge->splitsOrFuture(&future), + "Getting index splits after the bridge is cancelled"); +} + +TEST_F(IndexLookupJoinBridgeTest, setAfterCancelThrows) { + auto bridge = createBridge(); + bridge->cancel(); + + VELOX_ASSERT_THROW( + bridge->setIndexSplits(makeFakeSplits(1)), + "Setting index splits after the bridge is cancelled"); +} + +TEST_F(IndexLookupJoinBridgeTest, concurrentSetAndGet) { + auto bridge = createBridge(); + const int numFollowers = 8; + const int numSplits = 5; + + std::vector threads; + std::atomic_int successCount{0}; + std::atomic_bool stop{false}; + + // Start followers that keep getting splits until stopped. + threads.reserve(numFollowers); + for (int i = 0; i < numFollowers; ++i) { + threads.emplace_back([&bridge, &successCount, &stop, numSplits]() { + ContinueFuture future = ContinueFuture::makeEmpty(); + auto result = bridge->splitsOrFuture(&future); + if (future.valid()) { + std::move(future).via(&folly::InlineExecutor::instance()).get(); + } + while (!stop.load()) { + ContinueFuture f = ContinueFuture::makeEmpty(); + auto r = bridge->splitsOrFuture(&f); + ASSERT_FALSE(f.valid()); + ASSERT_EQ(r.size(), numSplits); + ++successCount; + } + }); + } + + // Give followers time to register. + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + + // Leader sets splits. + bridge->setIndexSplits(makeFakeSplits(numSplits)); + + // Let followers run for a while. + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + stop.store(true); + + for (auto& t : threads) { + t.join(); + } + ASSERT_GE(successCount.load(), numFollowers); +} + +} // namespace diff --git a/velox/exec/tests/IndexLookupJoinTest.cpp b/velox/exec/tests/IndexLookupJoinTest.cpp index f1e056f117e..a63af98d62b 100644 --- a/velox/exec/tests/IndexLookupJoinTest.cpp +++ b/velox/exec/tests/IndexLookupJoinTest.cpp @@ -15,12 +15,14 @@ */ #include "velox/exec/IndexLookupJoin.h" -#include "folly/experimental/EventCount.h" +#include "fmt/format.h" +#include "folly/synchronization/EventCount.h" #include "gmock/gmock.h" #include "gtest/gtest-matchers.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/testutil/TestValue.h" #include "velox/connectors/Connector.h" +#include "velox/connectors/ConnectorRegistry.h" #include "velox/core/PlanNode.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" @@ -28,37 +30,42 @@ #include "velox/exec/tests/utils/IndexLookupJoinTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/exec/tests/utils/TestIndexStorageConnector.h" +#include "velox/vector/LazyVector.h" using namespace facebook::velox; using namespace facebook::velox::exec; using namespace facebook::velox::exec::test; using namespace facebook::velox::common::testutil; -namespace fecebook::velox::exec::test { +namespace facebook::velox::exec::test { namespace { struct TestParam { bool asyncLookup; int32_t numPrefetches; bool serialExecution; bool hasNullKeys; + bool needsIndexSplit; TestParam( bool _asyncLookup, int32_t _numPrefetches, bool _serialExecution, - bool _hasNullKeys) + bool _hasNullKeys, + bool _needsIndexSplit = false) : asyncLookup(_asyncLookup), numPrefetches(_numPrefetches), serialExecution(_serialExecution), - hasNullKeys(_hasNullKeys) {} + hasNullKeys(_hasNullKeys), + needsIndexSplit(_needsIndexSplit) {} std::string toString() const { return fmt::format( - "asyncLookup={}, numPrefetches={}, serialExecution={}, hasNullKeys={}", + "asyncLookup={}, numPrefetches={}, serialExecution={}, hasNullKeys={}, needsIndexSplit={}", asyncLookup, numPrefetches, serialExecution, - hasNullKeys); + hasNullKeys, + needsIndexSplit); } }; @@ -71,8 +78,20 @@ class IndexLookupJoinTest : public IndexLookupJoinTestBase, for (int numPrefetches : {0, 3}) { for (bool serialExecution : {false, true}) { for (bool hasNullKeys : {false, true}) { - testParams.emplace_back( - asyncLookup, numPrefetches, serialExecution, hasNullKeys); + for (bool needsIndexSplit : {false, true}) { + // Serial execution doesn't support index split as it requires + // single-threaded execution which is incompatible with the + // split-based parallelism used by index lookup join. + if (serialExecution && needsIndexSplit) { + continue; + } + testParams.emplace_back( + asyncLookup, + numPrefetches, + serialExecution, + hasNullKeys, + needsIndexSplit); + } } } } @@ -98,7 +117,7 @@ class IndexLookupJoinTest : public IndexLookupJoinTestBase, } void TearDown() override { - connector::unregisterConnector(kTestIndexConnectorName); + connector::ConnectorRegistry::global().erase(kTestIndexConnectorName); HiveConnectorTestBase::TearDown(); } @@ -112,9 +131,10 @@ class IndexLookupJoinTest : public IndexLookupJoinTestBase, // flag. static std::shared_ptr makeIndexTableHandle( const std::shared_ptr& indexTable, - bool asyncLookup) { + bool asyncLookup, + bool needsIndexSplit = false) { return std::make_shared( - kTestIndexConnectorName, indexTable, asyncLookup); + kTestIndexConnectorName, indexTable, asyncLookup, needsIndexSplit); } static connector::ColumnHandleMap makeIndexColumnHandles( @@ -176,12 +196,23 @@ TEST_F(IndexLookupJoinTest, joinCondition) { PlanBuilder::parseIndexJoinCondition("c0=1", rowType, pool_.get()); ASSERT_TRUE(equalFilterCondition->isFilter()); ASSERT_EQ(equalFilterCondition->toString(), "ROW[\"c0\"] = 1"); + + auto equalJoinCondition = + PlanBuilder::parseIndexJoinCondition("c0=c1", rowType, pool_.get()); + ASSERT_FALSE(equalJoinCondition->isFilter()); + ASSERT_EQ(equalJoinCondition->toString(), "ROW[\"c0\"] = ROW[\"c1\"]"); + + auto equalJoinConditionAsFilter = + PlanBuilder::parseIndexJoinCondition("c0=1", rowType, pool_.get()); + ASSERT_TRUE(equalJoinConditionAsFilter->isFilter()); + ASSERT_EQ(equalJoinConditionAsFilter->toString(), "ROW[\"c0\"] = 1"); } TEST_P(IndexLookupJoinTest, planNodeAndSerde) { TestIndexTableHandle::registerSerDe(); - auto indexConnectorHandle = makeIndexTableHandle(nullptr, true); + auto indexConnectorHandle = + makeIndexTableHandle(nullptr, true, GetParam().needsIndexSplit); auto left = makeRowVector( {"t0", "t1", "t2", "t3", "t4"}, @@ -224,17 +255,18 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { for (const auto joinType : {core::JoinType::kLeft, core::JoinType::kInner}) { auto plan = PlanBuilder(planNodeIdGenerator) .values({left}) - .indexLookupJoin( - {"t0"}, - {"u0"}, - indexTableScan, - {}, - /*includeMatchColumn=*/false, - {"t0", "u1", "t2", "t1"}, - joinType) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(indexTableScan) + .outputLayout({"t0", "u1", "t2", "t1"}) + .joinType(joinType) + .endIndexLookupJoin() .planNode(); auto indexLookupJoinNode = std::dynamic_pointer_cast(plan); + ASSERT_EQ( + indexLookupJoinNode->needsIndexSplit(), GetParam().needsIndexSplit); ASSERT_TRUE(indexLookupJoinNode->joinConditions().empty()); ASSERT_EQ( indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), @@ -246,17 +278,19 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { for (const auto joinType : {core::JoinType::kLeft, core::JoinType::kInner}) { auto plan = PlanBuilder(planNodeIdGenerator, pool_.get()) .values({left}) - .indexLookupJoin( - {"t0"}, - {"u0"}, - indexTableScan, - {"contains(t3, u0)", "contains(t4, u1)"}, - /*includeMatchColumn=*/false, - {"t0", "u1", "t2", "t1"}, - joinType) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(indexTableScan) + .joinConditions({"contains(t3, u0)", "contains(t4, u1)"}) + .outputLayout({"t0", "u1", "t2", "t1"}) + .joinType(joinType) + .endIndexLookupJoin() .planNode(); auto indexLookupJoinNode = std::dynamic_pointer_cast(plan); + ASSERT_EQ( + indexLookupJoinNode->needsIndexSplit(), GetParam().needsIndexSplit); ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 2); ASSERT_EQ( indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), @@ -265,6 +299,85 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { } // with between join conditions. + for (const auto joinType : {core::JoinType::kLeft, core::JoinType::kInner}) { + auto plan = PlanBuilder(planNodeIdGenerator, pool_.get()) + .values({left}) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(indexTableScan) + .joinConditions( + {"u0 between t0 AND t1", + "u1 between t1 AND 10", + "u1 between 10 AND t1"}) + .outputLayout({"t0", "u1", "t2", "t1"}) + .joinType(joinType) + .endIndexLookupJoin() + .planNode(); + auto indexLookupJoinNode = + std::dynamic_pointer_cast(plan); + ASSERT_EQ( + indexLookupJoinNode->needsIndexSplit(), GetParam().needsIndexSplit); + ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 3); + ASSERT_EQ( + indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), + kTestIndexConnectorName); + testSerde(plan); + } + + // with mix join conditions. + for (const auto joinType : {core::JoinType::kLeft, core::JoinType::kInner}) { + auto plan = + PlanBuilder(planNodeIdGenerator, pool_.get()) + .values({left}) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(indexTableScan) + .joinConditions({"contains(t3, u0)", "u1 between 10 AND t1"}) + .outputLayout({"t0", "u1", "t2", "t1"}) + .joinType(joinType) + .endIndexLookupJoin() + .planNode(); + auto indexLookupJoinNode = + std::dynamic_pointer_cast(plan); + ASSERT_EQ( + indexLookupJoinNode->needsIndexSplit(), GetParam().needsIndexSplit); + ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 2); + ASSERT_EQ( + indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), + kTestIndexConnectorName); + testSerde(plan); + } + + // with has match column. + { + auto plan = + PlanBuilder(planNodeIdGenerator, pool_.get()) + .values({left}) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(indexTableScan) + .joinConditions({"contains(t3, u0)", "u1 between 10 AND t1"}) + .hasMarker(true) + .outputLayout({"t0", "u1", "t2", "t1", "match"}) + .joinType(core::JoinType::kLeft) + .endIndexLookupJoin() + .planNode(); + auto indexLookupJoinNode = + std::dynamic_pointer_cast(plan); + ASSERT_EQ( + indexLookupJoinNode->needsIndexSplit(), GetParam().needsIndexSplit); + ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 2); + ASSERT_EQ(indexLookupJoinNode->filter(), nullptr); + ASSERT_EQ( + indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), + kTestIndexConnectorName); + testSerde(plan); + } + + // with filter. for (const auto joinType : {core::JoinType::kLeft, core::JoinType::kInner}) { auto plan = PlanBuilder(planNodeIdGenerator, pool_.get()) .values({left}) @@ -272,23 +385,27 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { {"t0"}, {"u0"}, indexTableScan, - {"u0 between t0 AND t1", - "u1 between t1 AND 10", - "u1 between 10 AND t1"}, - /*includeMatchColumn=*/false, + {}, + /*filter=*/"t1 % 2 = 0", + /*hasMarker=*/false, {"t0", "u1", "t2", "t1"}, joinType) .planNode(); auto indexLookupJoinNode = std::dynamic_pointer_cast(plan); - ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 3); + ASSERT_TRUE(indexLookupJoinNode->joinConditions().empty()); + ASSERT_NE(indexLookupJoinNode->filter(), nullptr); + ASSERT_EQ( + indexLookupJoinNode->needsIndexSplit(), GetParam().needsIndexSplit); + ASSERT_EQ( + indexLookupJoinNode->filter()->toString(), "eq(mod(ROW[\"t1\"],2),0)"); ASSERT_EQ( indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), kTestIndexConnectorName); testSerde(plan); } - // with mix join conditions. + // with join conditions and filter. for (const auto joinType : {core::JoinType::kLeft, core::JoinType::kInner}) { auto plan = PlanBuilder(planNodeIdGenerator, pool_.get()) .values({left}) @@ -296,21 +413,28 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { {"t0"}, {"u0"}, indexTableScan, - {"contains(t3, u0)", "u1 between 10 AND t1"}, - /*includeMatchColumn=*/false, + {"contains(t3, u0)"}, + /*filter=*/"u1 % 2 = 0 AND t2 > 5", + /*hasMarker=*/false, {"t0", "u1", "t2", "t1"}, joinType) .planNode(); auto indexLookupJoinNode = std::dynamic_pointer_cast(plan); - ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 2); + ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 1); + ASSERT_EQ( + indexLookupJoinNode->needsIndexSplit(), GetParam().needsIndexSplit); + ASSERT_NE(indexLookupJoinNode->filter(), nullptr); + ASSERT_EQ( + indexLookupJoinNode->filter()->toString(), + "and(eq(mod(ROW[\"u1\"],2),0),gt(ROW[\"t2\"],5))"); ASSERT_EQ( indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), kTestIndexConnectorName); testSerde(plan); } - // with has match column. + // with filter and marker for left join. { auto plan = PlanBuilder(planNodeIdGenerator, pool_.get()) .values({left}) @@ -318,14 +442,50 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { {"t0"}, {"u0"}, indexTableScan, - {"contains(t3, u0)", "u1 between 10 AND t1"}, - /*includeMatchColumn=*/true, + {"u1 between 10 AND t1"}, + /*filter=*/"t2 < u2", + /*hasMarker=*/true, {"t0", "u1", "t2", "t1", "match"}, core::JoinType::kLeft) .planNode(); auto indexLookupJoinNode = std::dynamic_pointer_cast(plan); - ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 2); + ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 1); + ASSERT_EQ( + indexLookupJoinNode->needsIndexSplit(), GetParam().needsIndexSplit); + ASSERT_NE(indexLookupJoinNode->filter(), nullptr); + ASSERT_EQ( + indexLookupJoinNode->filter()->toString(), + "lt(ROW[\"t2\"],ROW[\"u2\"])"); + ASSERT_TRUE(indexLookupJoinNode->hasMarker()); + ASSERT_EQ( + indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), + kTestIndexConnectorName); + testSerde(plan); + } + + // with complex filter expression. + { + auto plan = PlanBuilder(planNodeIdGenerator, pool_.get()) + .values({left}) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(indexTableScan) + .filter("(t1 + u1) * 2 > 100 OR t2 = u2") + .outputLayout({"t0", "u1", "t2", "t1"}) + .joinType(core::JoinType::kInner) + .endIndexLookupJoin() + .planNode(); + auto indexLookupJoinNode = + std::dynamic_pointer_cast(plan); + ASSERT_TRUE(indexLookupJoinNode->joinConditions().empty()); + ASSERT_EQ( + indexLookupJoinNode->needsIndexSplit(), GetParam().needsIndexSplit); + ASSERT_NE(indexLookupJoinNode->filter(), nullptr); + ASSERT_EQ( + indexLookupJoinNode->filter()->toString(), + "or(gt(multiply(plus(ROW[\"t1\"],ROW[\"u1\"]),2),100),eq(ROW[\"t2\"],ROW[\"u2\"]))"); ASSERT_EQ( indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), kTestIndexConnectorName); @@ -337,14 +497,13 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { VELOX_ASSERT_USER_THROW( PlanBuilder(planNodeIdGenerator) .values({left}) - .indexLookupJoin( - {"t0"}, - {"u0"}, - indexTableScan, - {}, - /*includeMatchColumn=*/false, - {"t0", "u1", "t2", "t1"}, - core::JoinType::kFull) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(indexTableScan) + .outputLayout({"t0", "u1", "t2", "t1"}) + .joinType(core::JoinType::kFull) + .endIndexLookupJoin() .planNode(), "Unsupported index lookup join type FULL"); } @@ -354,13 +513,12 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { VELOX_ASSERT_USER_THROW( PlanBuilder(planNodeIdGenerator) .values({left}) - .indexLookupJoin( - {"t0"}, - {"u0"}, - nonIndexTableScan, - {}, - /*includeMatchColumn=*/false, - {"t0", "u1", "t2", "t1"}) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(nonIndexTableScan) + .outputLayout({"t0", "u1", "t2", "t1"}) + .endIndexLookupJoin() .planNode(), "The lookup table handle hive_table from connector test-hive doesn't support index lookup"); } @@ -370,13 +528,13 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { VELOX_ASSERT_THROW( PlanBuilder(planNodeIdGenerator) .values({left}) - .indexLookupJoin( - {"t0", "t1"}, - {"u0"}, - indexTableScan, - {"contains(t4, u0)"}, - /*includeMatchColumn=*/false, - {"t0", "u1", "t2", "t1"}) + .startIndexLookupJoin() + .leftKeys({"t0", "t1"}) + .rightKeys({"u0"}) + .indexSource(indexTableScan) + .joinConditions({"contains(t4, u0)"}) + .outputLayout({"t0", "u1", "t2", "t1"}) + .endIndexLookupJoin() .planNode(), "The index lookup join node requires same number of join keys on left and right sides"); } @@ -386,19 +544,19 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { VELOX_ASSERT_THROW( PlanBuilder(planNodeIdGenerator) .values({left}) - .indexLookupJoin( - {}, - {}, - indexTableScan, - {"contains(t4, u0)"}, - /*includeMatchColumn=*/false, - {"t0", "u1", "t2", "t1"}) + .startIndexLookupJoin() + .leftKeys({}) + .rightKeys({}) + .indexSource(indexTableScan) + .joinConditions({"contains(t4, u0)"}) + .outputLayout({"t0", "u1", "t2", "t1"}) + .endIndexLookupJoin() .planNode(), "The index lookup join node requires at least one join key"); } } -TEST_P(IndexLookupJoinTest, equalJoin) { +TEST_P(IndexLookupJoinTest, DISABLED_equalJoin) { struct { std::vector keyCardinalities; int numProbeBatches; @@ -750,7 +908,7 @@ TEST_P(IndexLookupJoinTest, equalJoin) { for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); - SequenceTableData tableData; + IndexTableData tableData; generateIndexTableData(testData.keyCardinalities, tableData, pool_); auto probeVectors = generateProbeInput( testData.numProbeBatches, @@ -767,15 +925,15 @@ TEST_P(IndexLookupJoinTest, equalJoin) { createProbeFiles(probeVectors); createDuckDbTable("t", probeVectors); - createDuckDbTable("u", {tableData.tableData}); + createDuckDbTable("u", {tableData.tableVectors}); const auto indexTable = TestIndexTable::create( /*numEqualJoinKeys=*/3, - tableData.keyData, - tableData.valueData, + tableData.keyVectors, + tableData.valueVectors, *pool()); - const auto indexTableHandle = - makeIndexTableHandle(indexTable, GetParam().asyncLookup); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); auto planNodeIdGenerator = std::make_shared(); const auto indexScanNode = makeIndexScanNode( planNodeIdGenerator, @@ -789,7 +947,8 @@ TEST_P(IndexLookupJoinTest, equalJoin) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, testData.joinType, testData.outputColumns); runLookupQuery( @@ -799,6 +958,7 @@ TEST_P(IndexLookupJoinTest, equalJoin) { GetParam().serialExecution, 32, GetParam().numPrefetches, + GetParam().needsIndexSplit, testData.duckDbVerifySql); if (testData.joinType != core::JoinType::kLeft) { continue; @@ -809,12 +969,18 @@ TEST_P(IndexLookupJoinTest, equalJoin) { indexScanNode, {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, - {}, - /*includeMatchColumn=*/true, + /*joinConditions=*/{}, + /*filter=*/"", + /*hasMarker=*/true, testData.joinType, testData.outputColumns); verifyResultWithMatchColumn( - plan, probeScanId, planWithMatchColumn, probeScanNodeId_, probeFiles); + plan, + probeScanId, + planWithMatchColumn, + probeScanNodeId_, + probeFiles, + GetParam().needsIndexSplit); } } @@ -1221,7 +1387,7 @@ TEST_P(IndexLookupJoinTest, betweenJoinCondition) { for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); - SequenceTableData tableData; + IndexTableData tableData; generateIndexTableData(testData.keyCardinalities, tableData, pool_); auto probeVectors = generateProbeInput( testData.numProbeBatches, @@ -1240,15 +1406,15 @@ TEST_P(IndexLookupJoinTest, betweenJoinCondition) { createProbeFiles(probeVectors); createDuckDbTable("t", probeVectors); - createDuckDbTable("u", {tableData.tableData}); + createDuckDbTable("u", {tableData.tableVectors}); const auto indexTable = TestIndexTable::create( /*numEqualJoinKeys=*/2, - tableData.keyData, - tableData.valueData, + tableData.keyVectors, + tableData.valueVectors, *pool()); - const auto indexTableHandle = - makeIndexTableHandle(indexTable, GetParam().asyncLookup); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); auto planNodeIdGenerator = std::make_shared(); const auto indexScanNode = makeIndexScanNode( planNodeIdGenerator, @@ -1262,7 +1428,8 @@ TEST_P(IndexLookupJoinTest, betweenJoinCondition) { {"t0", "t1"}, {"u0", "u1"}, {testData.betweenCondition}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, testData.joinType, testData.outputColumns); runLookupQuery( @@ -1272,6 +1439,7 @@ TEST_P(IndexLookupJoinTest, betweenJoinCondition) { GetParam().serialExecution, 32, GetParam().numPrefetches, + GetParam().needsIndexSplit, testData.duckDbVerifySql); if (testData.joinType != core::JoinType::kLeft) { continue; @@ -1283,11 +1451,17 @@ TEST_P(IndexLookupJoinTest, betweenJoinCondition) { {"t0", "t1"}, {"u0", "u1"}, {testData.betweenCondition}, - /*includeMatchColumn=*/true, + /*filter=*/"", + /*hasMarker=*/true, testData.joinType, testData.outputColumns); verifyResultWithMatchColumn( - plan, probeScanId, planWithMatchColumn, probeScanNodeId_, probeFiles); + plan, + probeScanId, + planWithMatchColumn, + probeScanNodeId_, + probeFiles, + GetParam().needsIndexSplit); } } @@ -1561,7 +1735,7 @@ TEST_P(IndexLookupJoinTest, inJoinCondition) { for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); - SequenceTableData tableData; + IndexTableData tableData; generateIndexTableData(testData.keyCardinalities, tableData, pool_); auto probeVectors = generateProbeInput( testData.numProbeBatches, @@ -1579,15 +1753,15 @@ TEST_P(IndexLookupJoinTest, inJoinCondition) { createProbeFiles(probeVectors); createDuckDbTable("t", probeVectors); - createDuckDbTable("u", {tableData.tableData}); + createDuckDbTable("u", {tableData.tableVectors}); const auto indexTable = TestIndexTable::create( /*numEqualJoinKeys=*/2, - tableData.keyData, - tableData.valueData, + tableData.keyVectors, + tableData.valueVectors, *pool()); - const auto indexTableHandle = - makeIndexTableHandle(indexTable, GetParam().asyncLookup); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); auto planNodeIdGenerator = std::make_shared(); const auto indexScanNode = makeIndexScanNode( planNodeIdGenerator, @@ -1601,7 +1775,8 @@ TEST_P(IndexLookupJoinTest, inJoinCondition) { {"t0", "t1"}, {"u0", "u1"}, {testData.inCondition}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, testData.joinType, testData.outputColumns); runLookupQuery( @@ -1611,6 +1786,7 @@ TEST_P(IndexLookupJoinTest, inJoinCondition) { GetParam().serialExecution, 32, GetParam().numPrefetches, + GetParam().needsIndexSplit, testData.duckDbVerifySql); if (testData.joinType != core::JoinType::kLeft) { continue; @@ -1622,11 +1798,17 @@ TEST_P(IndexLookupJoinTest, inJoinCondition) { {"t0", "t1"}, {"u0", "u1"}, {testData.inCondition}, - /*includeMatchColumn=*/true, + /*filter=*/"", + /*hasMarker=*/true, testData.joinType, testData.outputColumns); verifyResultWithMatchColumn( - plan, probeScanId, planWithMatchColumn, probeScanNodeId_, probeFiles); + plan, + probeScanId, + planWithMatchColumn, + probeScanNodeId_, + probeFiles, + GetParam().needsIndexSplit); } } @@ -1702,7 +1884,7 @@ TEST_P(IndexLookupJoinTest, prefixKeysEqualJoin) { for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); - SequenceTableData tableData; + IndexTableData tableData; generateIndexTableData(testData.keyCardinalities, tableData, pool_); // Generate probe vectors with only the prefix of keys @@ -1726,15 +1908,15 @@ TEST_P(IndexLookupJoinTest, prefixKeysEqualJoin) { createProbeFiles(probeVectors); createDuckDbTable("t", probeVectors); - createDuckDbTable("u", {tableData.tableData}); + createDuckDbTable("u", {tableData.tableVectors}); const auto indexTable = TestIndexTable::create( /*numEqualJoinKeys=*/testData.numKeysToUse, - tableData.keyData, - tableData.valueData, + tableData.keyVectors, + tableData.valueVectors, *pool()); - const auto indexTableHandle = - makeIndexTableHandle(indexTable, GetParam().asyncLookup); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); auto planNodeIdGenerator = std::make_shared(); const auto indexScanNode = makeIndexScanNode( planNodeIdGenerator, @@ -1756,7 +1938,8 @@ TEST_P(IndexLookupJoinTest, prefixKeysEqualJoin) { leftKeys, rightKeys, {}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, testData.joinType, testData.outputColumns); runLookupQuery( @@ -1766,6 +1949,7 @@ TEST_P(IndexLookupJoinTest, prefixKeysEqualJoin) { GetParam().serialExecution, 32, GetParam().numPrefetches, + GetParam().needsIndexSplit, testData.duckDbVerifySql); if (testData.joinType != core::JoinType::kLeft) { continue; @@ -1777,11 +1961,17 @@ TEST_P(IndexLookupJoinTest, prefixKeysEqualJoin) { leftKeys, rightKeys, {}, - /*includeMatchColumn=*/true, + /*filter=*/"", + /*hasMarker=*/true, testData.joinType, testData.outputColumns); verifyResultWithMatchColumn( - plan, probeScanId, planWithMatchColumn, probeScanNodeId_, probeFiles); + plan, + probeScanId, + planWithMatchColumn, + probeScanNodeId_, + probeFiles, + GetParam().needsIndexSplit); } } @@ -1852,7 +2042,7 @@ TEST_P(IndexLookupJoinTest, prefixKeysbetweenJoinCondition) { for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); - SequenceTableData tableData; + IndexTableData tableData; generateIndexTableData(testData.keyCardinalities, tableData, pool_); auto probeVectors = generateProbeInput( @@ -1872,15 +2062,15 @@ TEST_P(IndexLookupJoinTest, prefixKeysbetweenJoinCondition) { createProbeFiles(probeVectors); createDuckDbTable("t", probeVectors); - createDuckDbTable("u", {tableData.tableData}); + createDuckDbTable("u", {tableData.tableVectors}); const auto indexTable = TestIndexTable::create( /*numEqualJoinKeys=*/1, - tableData.keyData, - tableData.valueData, + tableData.keyVectors, + tableData.valueVectors, *pool()); - const auto indexTableHandle = - makeIndexTableHandle(indexTable, GetParam().asyncLookup); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); auto planNodeIdGenerator = std::make_shared(); const auto indexScanNode = makeIndexScanNode( planNodeIdGenerator, @@ -1894,7 +2084,8 @@ TEST_P(IndexLookupJoinTest, prefixKeysbetweenJoinCondition) { {"t0"}, {"u0"}, {testData.betweenCondition}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, testData.joinType, testData.outputColumns); runLookupQuery( @@ -1904,6 +2095,7 @@ TEST_P(IndexLookupJoinTest, prefixKeysbetweenJoinCondition) { GetParam().serialExecution, 32, GetParam().numPrefetches, + GetParam().needsIndexSplit, testData.duckDbVerifySql); if (testData.joinType != core::JoinType::kLeft) { continue; @@ -1915,11 +2107,17 @@ TEST_P(IndexLookupJoinTest, prefixKeysbetweenJoinCondition) { {"t0"}, {"u0"}, {testData.betweenCondition}, - /*includeMatchColumn=*/true, + /*filter=*/"", + /*hasMarker=*/true, testData.joinType, testData.outputColumns); verifyResultWithMatchColumn( - plan, probeScanId, planWithMatchColumn, probeScanNodeId_, probeFiles); + plan, + probeScanId, + planWithMatchColumn, + probeScanNodeId_, + probeFiles, + GetParam().needsIndexSplit); } } @@ -1993,7 +2191,7 @@ TEST_P(IndexLookupJoinTest, prefixInJoinCondition) { for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); - SequenceTableData tableData; + IndexTableData tableData; generateIndexTableData(testData.keyCardinalities, tableData, pool_); auto probeVectors = generateProbeInput( testData.numProbeBatches, @@ -2011,15 +2209,15 @@ TEST_P(IndexLookupJoinTest, prefixInJoinCondition) { createProbeFiles(probeVectors); createDuckDbTable("t", probeVectors); - createDuckDbTable("u", {tableData.tableData}); + createDuckDbTable("u", {tableData.tableVectors}); const auto indexTable = TestIndexTable::create( /*numEqualJoinKeys=*/1, - tableData.keyData, - tableData.valueData, + tableData.keyVectors, + tableData.valueVectors, *pool()); - const auto indexTableHandle = - makeIndexTableHandle(indexTable, GetParam().asyncLookup); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); auto planNodeIdGenerator = std::make_shared(); const auto indexScanNode = makeIndexScanNode( planNodeIdGenerator, @@ -2033,7 +2231,8 @@ TEST_P(IndexLookupJoinTest, prefixInJoinCondition) { {"t0"}, {"u0"}, {testData.inCondition}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, testData.joinType, testData.outputColumns); runLookupQuery( @@ -2043,6 +2242,7 @@ TEST_P(IndexLookupJoinTest, prefixInJoinCondition) { GetParam().serialExecution, 32, GetParam().numPrefetches, + GetParam().needsIndexSplit, testData.duckDbVerifySql); if (testData.joinType != core::JoinType::kLeft) { continue; @@ -2054,16 +2254,22 @@ TEST_P(IndexLookupJoinTest, prefixInJoinCondition) { {"t0"}, {"u0"}, {testData.inCondition}, - /*includeMatchColumn=*/true, + /*filter=*/"", + /*hasMarker=*/true, testData.joinType, testData.outputColumns); verifyResultWithMatchColumn( - plan, probeScanId, planWithMatchColumn, probeScanNodeId_, probeFiles); + plan, + probeScanId, + planWithMatchColumn, + probeScanNodeId_, + probeFiles, + GetParam().needsIndexSplit); } } DEBUG_ONLY_TEST_P(IndexLookupJoinTest, connectorError) { - SequenceTableData tableData; + IndexTableData tableData; generateIndexTableData({100, 1, 1}, tableData, pool_); const std::vector probeVectors = generateProbeInput( 20, @@ -2091,9 +2297,12 @@ DEBUG_ONLY_TEST_P(IndexLookupJoinTest, connectorError) { })); const auto indexTable = TestIndexTable::create( - /*numEqualJoinKeys=*/3, tableData.keyData, tableData.valueData, *pool()); - const auto indexTableHandle = - makeIndexTableHandle(indexTable, GetParam().asyncLookup); + /*numEqualJoinKeys=*/3, + tableData.keyVectors, + tableData.valueVectors, + *pool()); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); auto planNodeIdGenerator = std::make_shared(); const auto indexScanNode = makeIndexScanNode( planNodeIdGenerator, @@ -2107,7 +2316,8 @@ DEBUG_ONLY_TEST_P(IndexLookupJoinTest, connectorError) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, core::JoinType::kInner, {"u0", "u1", "u2", "t5"}); VELOX_ASSERT_THROW( @@ -2118,6 +2328,7 @@ DEBUG_ONLY_TEST_P(IndexLookupJoinTest, connectorError) { GetParam().serialExecution, 100, GetParam().numPrefetches, + GetParam().needsIndexSplit, "SELECT u.c0, u.c1, t.c2, t.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"), errorMsg); } @@ -2127,7 +2338,7 @@ DEBUG_ONLY_TEST_P(IndexLookupJoinTest, prefetch) { // This test only works for async lookup. return; } - SequenceTableData tableData; + IndexTableData tableData; generateIndexTableData({100, 1, 1}, tableData, pool_); const int numProbeBatches{20}; ASSERT_GT(numProbeBatches, GetParam().numPrefetches); @@ -2145,7 +2356,7 @@ DEBUG_ONLY_TEST_P(IndexLookupJoinTest, prefetch) { std::vector> probeFiles = createProbeFiles(probeVectors); createDuckDbTable("t", probeVectors); - createDuckDbTable("u", {tableData.tableData}); + createDuckDbTable("u", {tableData.tableVectors}); std::atomic_int lookupCount{0}; folly::EventCount asyncLookupWait; @@ -2160,9 +2371,12 @@ DEBUG_ONLY_TEST_P(IndexLookupJoinTest, prefetch) { })); const auto indexTable = TestIndexTable::create( - /*numEqualJoinKeys=*/3, tableData.keyData, tableData.valueData, *pool()); - const auto indexTableHandle = - makeIndexTableHandle(indexTable, GetParam().asyncLookup); + /*numEqualJoinKeys=*/3, + tableData.keyVectors, + tableData.valueVectors, + *pool()); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); auto planNodeIdGenerator = std::make_shared(); const auto indexScanNode = makeIndexScanNode( planNodeIdGenerator, @@ -2176,7 +2390,8 @@ DEBUG_ONLY_TEST_P(IndexLookupJoinTest, prefetch) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, core::JoinType::kInner, {"u3", "t5"}); std::thread queryThread([&] { @@ -2187,6 +2402,7 @@ DEBUG_ONLY_TEST_P(IndexLookupJoinTest, prefetch) { GetParam().serialExecution, 100, GetParam().numPrefetches, + GetParam().needsIndexSplit, "SELECT u.c3, t.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"); }); while (lookupCount < 1 + GetParam().numPrefetches) { @@ -2200,7 +2416,7 @@ DEBUG_ONLY_TEST_P(IndexLookupJoinTest, prefetch) { } TEST_P(IndexLookupJoinTest, outputBatchSizeWithInnerJoin) { - SequenceTableData tableData; + IndexTableData tableData; generateIndexTableData({3'000, 1, 1}, tableData, pool_); struct { @@ -2256,15 +2472,15 @@ TEST_P(IndexLookupJoinTest, outputBatchSizeWithInnerJoin) { createProbeFiles(probeVectors); createDuckDbTable("t", probeVectors); - createDuckDbTable("u", {tableData.tableData}); + createDuckDbTable("u", {tableData.tableVectors}); const auto indexTable = TestIndexTable::create( /*numEqualJoinKeys=*/3, - tableData.keyData, - tableData.valueData, + tableData.keyVectors, + tableData.valueVectors, *pool()); - const auto indexTableHandle = - makeIndexTableHandle(indexTable, GetParam().asyncLookup); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); auto planNodeIdGenerator = std::make_shared(); const auto indexScanNode = makeIndexScanNode( @@ -2279,29 +2495,36 @@ TEST_P(IndexLookupJoinTest, outputBatchSizeWithInnerJoin) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, core::JoinType::kInner, {"t4", "u5"}); - const auto task = - AssertQueryBuilder(duckDbQueryRunner_) - .plan(plan) - .config( - core::QueryConfig::kIndexLookupJoinMaxPrefetchBatches, - std::to_string(GetParam().numPrefetches)) - .config( - core::QueryConfig::kPreferredOutputBatchRows, - std::to_string(testData.maxBatchRows)) - .config( - core::QueryConfig::kPreferredOutputBatchBytes, - std::to_string(1ULL << 30)) - .config( - core::QueryConfig::kIndexLookupJoinSplitOutput, - testData.splitOutput ? "true" : "false") - .splits(probeScanNodeId_, makeHiveConnectorSplits(probeFiles)) - .serialExecution(GetParam().serialExecution) - .barrierExecution(GetParam().serialExecution) - .assertResults( - "SELECT t.c4, u.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"); + AssertQueryBuilder queryBuilder(duckDbQueryRunner_); + queryBuilder.plan(plan) + .config( + core::QueryConfig::kIndexLookupJoinMaxPrefetchBatches, + std::to_string(GetParam().numPrefetches)) + .config( + core::QueryConfig::kPreferredOutputBatchRows, + std::to_string(testData.maxBatchRows)) + .config( + core::QueryConfig::kPreferredOutputBatchBytes, + std::to_string(1ULL << 30)) + .config( + core::QueryConfig::kIndexLookupJoinSplitOutput, + testData.splitOutput ? "true" : "false") + .splits(probeScanNodeId_, makeHiveConnectorSplits(probeFiles)) + .serialExecution(GetParam().serialExecution) + .barrierExecution(GetParam().serialExecution); + if (GetParam().needsIndexSplit) { + queryBuilder.split( + indexScanNodeId_, + Split( + std::make_shared( + kTestIndexConnectorName))); + } + const auto task = queryBuilder.assertResults( + "SELECT t.c4, u.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"); ASSERT_EQ( toPlanStats(task->taskStats()).at(joinNodeId_).outputVectors, testData.numExpectedOutputBatch); @@ -2309,7 +2532,7 @@ TEST_P(IndexLookupJoinTest, outputBatchSizeWithInnerJoin) { } TEST_P(IndexLookupJoinTest, outputBatchSizeWithLeftJoin) { - SequenceTableData tableData; + IndexTableData tableData; generateIndexTableData({3'000, 1, 1}, tableData, pool_); struct { @@ -2365,15 +2588,15 @@ TEST_P(IndexLookupJoinTest, outputBatchSizeWithLeftJoin) { createProbeFiles(probeVectors); createDuckDbTable("t", probeVectors); - createDuckDbTable("u", {tableData.tableData}); + createDuckDbTable("u", {tableData.tableVectors}); const auto indexTable = TestIndexTable::create( /*numEqualJoinKeys=*/3, - tableData.keyData, - tableData.valueData, + tableData.keyVectors, + tableData.valueVectors, *pool()); - const auto indexTableHandle = - makeIndexTableHandle(indexTable, GetParam().asyncLookup); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); auto planNodeIdGenerator = std::make_shared(); const auto indexScanNode = makeIndexScanNode( planNodeIdGenerator, @@ -2387,29 +2610,36 @@ TEST_P(IndexLookupJoinTest, outputBatchSizeWithLeftJoin) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, core::JoinType::kLeft, {"t4", "u5"}); - const auto task = - AssertQueryBuilder(duckDbQueryRunner_) - .plan(plan) - .config( - core::QueryConfig::kIndexLookupJoinMaxPrefetchBatches, - std::to_string(GetParam().numPrefetches)) - .config( - core::QueryConfig::kPreferredOutputBatchRows, - std::to_string(testData.maxBatchRows)) - .config( - core::QueryConfig::kPreferredOutputBatchBytes, - std::to_string(1ULL << 30)) - .config( - core::QueryConfig::kIndexLookupJoinSplitOutput, - testData.splitOutput ? "true" : "false") - .splits(probeScanNodeId_, makeHiveConnectorSplits(probeFiles)) - .serialExecution(GetParam().serialExecution) - .barrierExecution(GetParam().serialExecution) - .assertResults( - "SELECT t.c4, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"); + AssertQueryBuilder queryBuilder(duckDbQueryRunner_); + queryBuilder.plan(plan) + .config( + core::QueryConfig::kIndexLookupJoinMaxPrefetchBatches, + std::to_string(GetParam().numPrefetches)) + .config( + core::QueryConfig::kPreferredOutputBatchRows, + std::to_string(testData.maxBatchRows)) + .config( + core::QueryConfig::kPreferredOutputBatchBytes, + std::to_string(1ULL << 30)) + .config( + core::QueryConfig::kIndexLookupJoinSplitOutput, + testData.splitOutput ? "true" : "false") + .splits(probeScanNodeId_, makeHiveConnectorSplits(probeFiles)) + .serialExecution(GetParam().serialExecution) + .barrierExecution(GetParam().serialExecution); + if (GetParam().needsIndexSplit) { + queryBuilder.split( + indexScanNodeId_, + Split( + std::make_shared( + kTestIndexConnectorName))); + } + const auto task = queryBuilder.assertResults( + "SELECT t.c4, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"); ASSERT_EQ( toPlanStats(task->taskStats()).at(joinNodeId_).outputVectors, testData.numExpectedOutputBatch); @@ -2420,21 +2650,28 @@ TEST_P(IndexLookupJoinTest, outputBatchSizeWithLeftJoin) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, - /*includeMatchColumn=*/true, + /*filter=*/"", + /*hasMarker=*/true, core::JoinType::kLeft, {"t4", "u5"}); verifyResultWithMatchColumn( - plan, probeScanId, planWithMatchColumn, probeScanNodeId_, probeFiles); + plan, + probeScanId, + planWithMatchColumn, + probeScanNodeId_, + probeFiles, + GetParam().needsIndexSplit); } } -DEBUG_ONLY_TEST_P(IndexLookupJoinTest, runtimeStats) { - SequenceTableData tableData; +DEBUG_ONLY_TEST_P(IndexLookupJoinTest, statsSplitter) { + IndexTableData tableData; generateIndexTableData({100, 1, 1}, tableData, pool_); const int numProbeBatches{2}; + const int batchSize{100}; const std::vector probeVectors = generateProbeInput( numProbeBatches, - 100, + batchSize, 1, tableData, pool_, @@ -2446,18 +2683,22 @@ DEBUG_ONLY_TEST_P(IndexLookupJoinTest, runtimeStats) { std::vector> probeFiles = createProbeFiles(probeVectors); createDuckDbTable("t", probeVectors); - createDuckDbTable("u", {tableData.tableData}); + createDuckDbTable("u", {tableData.tableVectors}); + // Add a small delay in async lookup to ensure timing stats are captured. SCOPED_TESTVALUE_SET( "facebook::velox::exec::test::TestIndexSource::ResultIterator::asyncLookup", std::function([&](void*) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); // NOLINT + std::this_thread::sleep_for(std::chrono::milliseconds(10)); // NOLINT })); const auto indexTable = TestIndexTable::create( - /*numEqualJoinKeys=*/3, tableData.keyData, tableData.valueData, *pool()); - const auto indexTableHandle = - makeIndexTableHandle(indexTable, GetParam().asyncLookup); + /*numEqualJoinKeys=*/3, + tableData.keyVectors, + tableData.valueVectors, + *pool()); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); auto planNodeIdGenerator = std::make_shared(); const auto indexScanNode = makeIndexScanNode( planNodeIdGenerator, @@ -2471,7 +2712,8 @@ DEBUG_ONLY_TEST_P(IndexLookupJoinTest, runtimeStats) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, core::JoinType::kInner, {"u3", "t5"}); auto task = runLookupQuery( @@ -2481,50 +2723,58 @@ DEBUG_ONLY_TEST_P(IndexLookupJoinTest, runtimeStats) { GetParam().serialExecution, 100, 0, + GetParam().needsIndexSplit, "SELECT u.c3, t.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"); auto taskStats = toPlanStats(task->taskStats()); - auto& operatorStats = taskStats.at(joinNodeId_); - ASSERT_EQ(operatorStats.backgroundTiming.count, numProbeBatches); - ASSERT_GT(operatorStats.backgroundTiming.cpuNanos, 0); - ASSERT_GT(operatorStats.backgroundTiming.wallNanos, 0); - auto runtimeStats = operatorStats.customStats; - ASSERT_EQ( - runtimeStats.at(IndexLookupJoin::kConnectorLookupWallTime).count, - numProbeBatches); - ASSERT_GT(runtimeStats.at(IndexLookupJoin::kConnectorLookupWallTime).sum, 0); - ASSERT_EQ( - runtimeStats.at(IndexLookupJoin::kClientLookupWaitWallTime).count, - numProbeBatches); - ASSERT_GT(runtimeStats.at(IndexLookupJoin::kClientLookupWaitWallTime).sum, 0); - ASSERT_EQ( - runtimeStats.at(IndexLookupJoin::kConnectorResultPrepareTime).count, - numProbeBatches); - ASSERT_GT( - runtimeStats.at(IndexLookupJoin::kConnectorResultPrepareTime).sum, 0); - ASSERT_EQ(runtimeStats.count(IndexLookupJoin::kClientRequestProcessTime), 0); - ASSERT_EQ(runtimeStats.count(IndexLookupJoin::kClientResultProcessTime), 0); - ASSERT_EQ(runtimeStats.count(IndexLookupJoin::kClientLookupResultSize), 0); - ASSERT_EQ(runtimeStats.count(IndexLookupJoin::kClientLookupResultRawSize), 0); - ASSERT_THAT( - operatorStats.toString(true, true), - testing::MatchesRegex(".*Runtime stats.*connectorLookupWallNanos:.*")); - ASSERT_THAT( - operatorStats.toString(true, true), - testing::MatchesRegex(".*Runtime stats.*clientlookupWaitWallNanos.*")); - ASSERT_THAT( - operatorStats.toString(true, true), - testing::MatchesRegex( - ".*Runtime stats.*connectorResultPrepareCpuNanos.*")); + + // Verify that both the IndexLookupJoin node and IndexSource node have stats. + ASSERT_GT(taskStats.count(joinNodeId_), 0); + ASSERT_GT(taskStats.count(indexScanNodeId_), 0); + + const auto& joinStats = taskStats.at(joinNodeId_); + const auto& indexSourceStats = taskStats.at(indexScanNodeId_); + + EXPECT_GT(joinStats.inputRows, 0); + EXPECT_GT(joinStats.outputRows, 0); + + EXPECT_GT(indexSourceStats.outputRows, 0); + EXPECT_EQ(indexSourceStats.outputRows, joinStats.outputRows); + + EXPECT_GT(indexSourceStats.inputRows, 0); + EXPECT_GT(indexSourceStats.inputBytes, 0); + EXPECT_EQ(indexSourceStats.inputRows, joinStats.inputRows); + + EXPECT_GT(indexSourceStats.addInputTiming.count, 0); + + EXPECT_GT( + indexSourceStats.customStats.count( + std::string(IndexLookupJoin::kConnectorLookupWallTime)), + 0); + EXPECT_GT( + indexSourceStats.customStats + .at(std::string(IndexLookupJoin::kConnectorLookupWallTime)) + .sum, + 0); + + // backgroundTiming should be cleared from join stats (moved to IndexSource). + EXPECT_EQ(joinStats.backgroundTiming.count, 0); + EXPECT_EQ(joinStats.backgroundTiming.cpuNanos, 0); + EXPECT_EQ(joinStats.backgroundTiming.wallNanos, 0); } -TEST_P(IndexLookupJoinTest, barrier) { - SequenceTableData tableData; +/// Verifies that the probe-side scan operator's inputBytes/outputBytes are not +/// inflated with index lookup bytes. IndexLookupJoin uses a custom stat writer +/// in getOutput() to redirect index-side lazy loading stats to separate names, +/// so Driver::processLazyIoStats() only transfers probe-side stats to the scan. +TEST_P(IndexLookupJoinTest, scanStatsNotInflated) { + IndexTableData tableData; generateIndexTableData({100, 1, 1}, tableData, pool_); - const int numProbeSplits{5}; - const auto probeVectors = generateProbeInput( - numProbeSplits, - 256, + const int numProbeBatches{2}; + const int batchSize{100}; + const std::vector probeVectors = generateProbeInput( + numProbeBatches, + batchSize, 1, tableData, pool_, @@ -2536,12 +2786,15 @@ TEST_P(IndexLookupJoinTest, barrier) { std::vector> probeFiles = createProbeFiles(probeVectors); createDuckDbTable("t", probeVectors); - createDuckDbTable("u", {tableData.tableData}); + createDuckDbTable("u", {tableData.tableVectors}); const auto indexTable = TestIndexTable::create( - /*numEqualJoinKeys=*/3, tableData.keyData, tableData.valueData, *pool()); - const auto indexTableHandle = - makeIndexTableHandle(indexTable, GetParam().asyncLookup); + /*numEqualJoinKeys=*/3, + tableData.keyVectors, + tableData.valueVectors, + *pool()); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); auto planNodeIdGenerator = std::make_shared(); const auto indexScanNode = makeIndexScanNode( planNodeIdGenerator, @@ -2555,236 +2808,130 @@ TEST_P(IndexLookupJoinTest, barrier) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, core::JoinType::kInner, {"u3", "t5"}); + auto task = runLookupQuery( + plan, + probeFiles, + GetParam().serialExecution, + GetParam().serialExecution, + 100, + 0, + GetParam().needsIndexSplit, + "SELECT u.c3, t.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"); - struct { - int numPrefetches; - bool barrierExecution; - - std::string debugString() const { - return fmt::format( - "numPrefetches {}, barrierExecution {}", - numPrefetches, - barrierExecution); - } - } testSettings[] = { - {0, true}, - {0, false}, - {1, true}, - {1, false}, - {4, true}, - {4, false}, - {256, true}, - {256, false}}; + auto taskStats = toPlanStats(task->taskStats()); - for (const auto& testData : testSettings) { - SCOPED_TRACE(testData.debugString()); - auto task = runLookupQuery( - plan, - probeFiles, - true, - testData.barrierExecution, - 32, - testData.numPrefetches, - "SELECT u.c3, t.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"); + ASSERT_GT(taskStats.count(probeScanNodeId_), 0); + const auto& scanStats = taskStats.at(probeScanNodeId_); + + EXPECT_GT(scanStats.inputRows, 0); + EXPECT_GT(scanStats.inputBytes, 0); + EXPECT_EQ(scanStats.inputBytes, scanStats.outputBytes); + + // Verify that lazy loading stats are NOT present on the join node. + // Probe-side lazy stats accumulate in every non-scan operator's + // runtimeStats when lazy vectors are loaded. processLazyIoStats transfers + // timing deltas to the scan but doesn't erase the accumulated stats. + // For most operators this is harmless — the stats just sit in runtimeStats. + // But IndexLookupJoin's splitStats copies combinedStats to create the join + // node stats, so we must explicitly erase them to avoid exposing residual + // lazy stats on the join node. + const auto& joinStats = taskStats.at(joinNodeId_); + EXPECT_EQ( + joinStats.customStats.count(std::string(LazyVector::kInputBytes)), 0); +} - const auto taskStats = task->taskStats(); - ASSERT_EQ( - taskStats.numBarriers, testData.barrierExecution ? numProbeSplits : 0); - ASSERT_EQ(taskStats.numFinishedSplits, numProbeSplits); +/// Verifies that IndexSource stats report rows BEFORE the join filter is +/// applied, while IndexLookupJoin stats report rows AFTER the filter. +DEBUG_ONLY_TEST_P(IndexLookupJoinTest, statsSplitterWithFilter) { + // Skip serial execution tests for simplicity - the stats behavior is the + // same. + if (GetParam().serialExecution || GetParam().hasNullKeys) { + return; } -} -TEST_P(IndexLookupJoinTest, nullKeys) { - SequenceTableData tableData; + IndexTableData tableData; generateIndexTableData({100, 1, 1}, tableData, pool_); - const int numProbeSplits{5}; - const int probeBatchSize{256}; - const auto probeVectors = generateProbeInput( - numProbeSplits, - probeBatchSize, + const int numProbeBatches{2}; + const int batchSize{100}; + const std::vector probeVectors = generateProbeInput( + numProbeBatches, + batchSize, 1, tableData, pool_, {"t0", "t1", "t2"}, - /*hasNullKeys=*/true, + /*hasNullKeys=*/false, {}, {}, - /*equalMatchPct=*/100); - // Set some probe key vector to all nulls to trigger the case that entire - // probe input is skipped. - for (int i = 0; i < numProbeSplits; i += 2) { - for (int row = 0; row < probeBatchSize; ++row) { - probeVectors[i]->childAt(i % keyType_->size())->setNull(row, true); - } - } + 100); std::vector> probeFiles = createProbeFiles(probeVectors); createDuckDbTable("t", probeVectors); - createDuckDbTable("u", {tableData.tableData}); + createDuckDbTable("u", {tableData.tableVectors}); + + // Add a small delay in async lookup to ensure timing stats are captured. + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::test::TestIndexSource::ResultIterator::asyncLookup", + std::function([&](void*) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); // NOLINT + })); const auto indexTable = TestIndexTable::create( - /*numEqualJoinKeys=*/3, tableData.keyData, tableData.valueData, *pool()); - const auto indexTableHandle = - makeIndexTableHandle(indexTable, GetParam().asyncLookup); + /*numEqualJoinKeys=*/3, + tableData.keyVectors, + tableData.valueVectors, + *pool()); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); auto planNodeIdGenerator = std::make_shared(); - std::unordered_map> - columnHandles; const auto indexScanNode = makeIndexScanNode( planNodeIdGenerator, indexTableHandle, makeScanOutputType({"u0", "u1", "u2", "u3", "u5"}), makeIndexColumnHandles({"u0", "u1", "u2", "u3", "u5"})); - const auto innerPlan = makeLookupPlan( + // Add a filter that should filter out approximately half the rows. + // The filter "u3 % 2 = 0" will keep only even values of u3. + auto plan = makeLookupPlan( planNodeIdGenerator, indexScanNode, {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, - /*includeMatchColumn=*/false, + /*filter=*/"u3 % 2 = 0", + /*hasMarker=*/false, core::JoinType::kInner, {"u3", "t5"}); - - runLookupQuery( - innerPlan, - probeFiles, - true, - true, - 32, - GetParam().numPrefetches, - "SELECT u.c3, t.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"); - - const auto leftPlan = makeLookupPlan( - planNodeIdGenerator, - indexScanNode, - {"t0", "t1", "t2"}, - {"u0", "u1", "u2"}, - {}, - /*includeMatchColumn=*/false, - core::JoinType::kLeft, - {"u3", "t5"}); - - runLookupQuery( - leftPlan, - probeFiles, - true, - true, - 32, - GetParam().numPrefetches, - "SELECT u.c3, t.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"); - - const auto probeScanId = probeScanNodeId_; - auto planWithMatchColumn = makeLookupPlan( - planNodeIdGenerator, - indexScanNode, - {"t0", "t1", "t2"}, - {"u0", "u1", "u2"}, - {}, - /*includeMatchColumn=*/true, - core::JoinType::kLeft, - {"u3", "t5"}); - verifyResultWithMatchColumn( - leftPlan, probeScanId, planWithMatchColumn, probeScanNodeId_, probeFiles); -} - -TEST_P(IndexLookupJoinTest, joinFuzzer) { - SequenceTableData tableData; - generateIndexTableData({1024, 1, 1}, tableData, pool_); - const auto probeVectors = generateProbeInput( - 50, 256, 1, tableData, pool_, {"t0", "t1", "t2"}, GetParam().hasNullKeys); - std::vector> probeFiles = - createProbeFiles(probeVectors); - - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", {tableData.tableData}); - - const auto indexTable = TestIndexTable::create( - /*numEqualJoinKeys=*/1, tableData.keyData, tableData.valueData, *pool()); - const auto indexTableHandle = - makeIndexTableHandle(indexTable, GetParam().asyncLookup); - auto planNodeIdGenerator = std::make_shared(); - auto scanOutput = tableType_->names(); - std::random_device rd; - std::mt19937 g(rd()); - std::shuffle(scanOutput.begin(), scanOutput.end(), g); - const auto indexScanNode = makeIndexScanNode( - planNodeIdGenerator, - indexTableHandle, - makeScanOutputType(scanOutput), - makeIndexColumnHandles(scanOutput)); - - auto plan = makeLookupPlan( - planNodeIdGenerator, - indexScanNode, - {"t0"}, - {"u0"}, - {"contains(t4, u1)", "u2 between t1 and t2"}, - /*includeMatchColumn=*/false, - core::JoinType::kInner, - {"u0", "u4", "t0", "t1", "t4"}); - runLookupQuery( + auto task = runLookupQuery( plan, probeFiles, GetParam().serialExecution, GetParam().serialExecution, - 32, - GetParam().numPrefetches, - "SELECT u.c0, u.c1, u.c2, u.c3, u.c4, u.c5, t.c0, t.c1, t.c2, t.c3, t.c4, t.c5 FROM t, u WHERE t.c0 = u.c0 AND array_contains(t.c4, u.c1) AND u.c2 BETWEEN t.c1 AND t.c2"); -} + 100, + 0, + GetParam().needsIndexSplit, + "SELECT u.c3, t.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2 AND u.c3 % 2 = 0"); -TEST_P(IndexLookupJoinTest, tableRowsWithDuplicateKeys) { - SequenceTableData tableData; - generateIndexTableData({10, 1, 1}, tableData, pool_); - for (int i = 0; i < keyType_->size(); ++i) { - tableData.keyData->childAt(i) = makeFlatVector( - tableData.keyData->childAt(i)->size(), - [](auto /*unused*/) { return 1; }); - tableData.tableData->childAt(i) = makeFlatVector( - tableData.keyData->childAt(i)->size(), - [](auto /*unused*/) { return 1; }); - } + auto taskStats = toPlanStats(task->taskStats()); - auto probeVectors = generateProbeInput( - 4, 32, 1, tableData, pool_, {"t0", "t1", "t2"}, false, {}, {}, 100); - std::vector> probeFiles = - createProbeFiles(probeVectors); + ASSERT_GT(taskStats.count(joinNodeId_), 0); + ASSERT_GT(taskStats.count(indexScanNodeId_), 0); - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", {tableData.tableData}); + const auto& joinStats = taskStats.at(joinNodeId_); + const auto& indexSourceStats = taskStats.at(indexScanNodeId_); - const auto indexTable = TestIndexTable::create( - /*numEqualJoinKeys=*/3, tableData.keyData, tableData.valueData, *pool()); - const auto indexTableHandle = - makeIndexTableHandle(indexTable, GetParam().asyncLookup); - auto planNodeIdGenerator = std::make_shared(); - auto scanOutput = tableType_->names(); - const auto indexScanNode = makeIndexScanNode( - planNodeIdGenerator, - indexTableHandle, - makeScanOutputType(scanOutput), - makeIndexColumnHandles(scanOutput)); + EXPECT_GT(joinStats.inputRows, 0); + EXPECT_GT(joinStats.outputRows, 0); + EXPECT_GT(indexSourceStats.outputRows, 0); - auto plan = makeLookupPlan( - planNodeIdGenerator, - indexScanNode, - {"t0", "t1", "t2"}, - {"u0", "u1", "u2"}, - {}, - /*includeMatchColumn=*/false, - core::JoinType::kInner, - scanOutput); - runLookupQuery( - plan, - probeFiles, - GetParam().serialExecution, - GetParam().serialExecution, - 32, - GetParam().numPrefetches, - "SELECT u.c0, u.c1, u.c2, u.c3, u.c4, u.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1 AND u.c2 = t.c2"); + // IndexSource reports rows before the filter; IndexLookupJoin reports after. + EXPECT_GT(indexSourceStats.outputRows, joinStats.outputRows); } + } // namespace VELOX_INSTANTIATE_TEST_SUITE_P( @@ -2793,10 +2940,11 @@ VELOX_INSTANTIATE_TEST_SUITE_P( testing::ValuesIn(IndexLookupJoinTest::getTestParams()), [](const testing::TestParamInfo& info) { return fmt::format( - "{}_{}prefetches_{}_{}", + "{}_{}prefetches_{}_{}_{}", info.param.asyncLookup ? "async" : "sync", info.param.numPrefetches, info.param.serialExecution ? "serial" : "parallel", - info.param.hasNullKeys ? "nullKeys" : "noNullKeys"); + info.param.hasNullKeys ? "nullKeys" : "noNullKeys", + info.param.needsIndexSplit ? "needsIndexSplit" : "noSplit"); }); -} // namespace fecebook::velox::exec::test +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/IndexLookupJoinTestExtra.cpp b/velox/exec/tests/IndexLookupJoinTestExtra.cpp new file mode 100644 index 00000000000..6ab7347dfa4 --- /dev/null +++ b/velox/exec/tests/IndexLookupJoinTestExtra.cpp @@ -0,0 +1,2186 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "fmt/format.h" +#include "folly/synchronization/EventCount.h" +#include "gmock/gmock.h" +#include "gtest/gtest-matchers.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/testutil/TestValue.h" +#include "velox/connectors/Connector.h" +#include "velox/connectors/ConnectorRegistry.h" +#include "velox/core/PlanNode.h" +#include "velox/exec/IndexLookupJoin.h" +#include "velox/exec/OutputBufferManager.h" +#include "velox/exec/PlanNodeStats.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/tests/utils/IndexLookupJoinTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/tests/utils/TestIndexStorageConnector.h" + +using namespace facebook::velox; +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; +using namespace facebook::velox::common::testutil; + +namespace facebook::velox::exec::test { +namespace { +struct TestParam { + bool asyncLookup; + int32_t numPrefetches; + bool serialExecution; + bool hasNullKeys; + bool needsIndexSplit; + + TestParam( + bool _asyncLookup, + int32_t _numPrefetches, + bool _serialExecution, + bool _hasNullKeys, + bool _needsIndexSplit = false) + : asyncLookup(_asyncLookup), + numPrefetches(_numPrefetches), + serialExecution(_serialExecution), + hasNullKeys(_hasNullKeys), + needsIndexSplit(_needsIndexSplit) {} + + std::string toString() const { + return fmt::format( + "asyncLookup={}, numPrefetches={}, serialExecution={}, hasNullKeys={}, needsIndexSplit={}", + asyncLookup, + numPrefetches, + serialExecution, + hasNullKeys, + needsIndexSplit); + } +}; + +class IndexLookupJoinTest : public IndexLookupJoinTestBase, + public testing::WithParamInterface { + public: + static std::vector getTestParams() { + std::vector testParams; + for (bool asyncLookup : {false, true}) { + for (int numPrefetches : {0, 3}) { + for (bool serialExecution : {false, true}) { + for (bool hasNullKeys : {false, true}) { + for (bool needsIndexSplit : {false, true}) { + // Serial execution doesn't support index split as it requires + // single-threaded execution which is incompatible with the + // split-based parallelism used by index lookup join. + if (serialExecution && needsIndexSplit) { + continue; + } + testParams.emplace_back( + asyncLookup, + numPrefetches, + serialExecution, + hasNullKeys, + needsIndexSplit); + } + } + } + } + } + return testParams; + } + + protected: + void SetUp() override { + HiveConnectorTestBase::SetUp(); + core::PlanNode::registerSerDe(); + connector::hive::HiveColumnHandle::registerSerDe(); + Type::registerSerDe(); + core::ITypedExpr::registerSerDe(); + TestIndexConnectorFactory::registerConnector(connectorCpuExecutor_.get()); + + keyType_ = ROW({"u0", "u1", "u2"}, {BIGINT(), BIGINT(), BIGINT()}); + valueType_ = ROW({"u3", "u4", "u5"}, {BIGINT(), BIGINT(), VARCHAR()}); + tableType_ = concat(keyType_, valueType_); + probeType_ = ROW( + {"t0", "t1", "t2", "t3", "t4", "t5"}, + {BIGINT(), BIGINT(), BIGINT(), BIGINT(), ARRAY(BIGINT()), VARCHAR()}); + } + + void TearDown() override { + connector::ConnectorRegistry::global().erase(kTestIndexConnectorName); + HiveConnectorTestBase::TearDown(); + } + + void testSerde(const core::PlanNodePtr& plan) { + auto serialized = plan->serialize(); + auto copy = ISerializable::deserialize(serialized, pool()); + ASSERT_EQ(plan->toString(true, true), copy->toString(true, true)); + } + + // Makes index table handle with the specified index table and async lookup + // flag. + static std::shared_ptr makeIndexTableHandle( + const std::shared_ptr& indexTable, + bool asyncLookup, + bool needsIndexSplit = false) { + return std::make_shared( + kTestIndexConnectorName, indexTable, asyncLookup, needsIndexSplit); + } + + static connector::ColumnHandleMap makeIndexColumnHandles( + const std::vector& names) { + connector::ColumnHandleMap handles; + for (const auto& name : names) { + handles.emplace(name, std::make_shared(name)); + } + + return handles; + } + + const std::unique_ptr connectorCpuExecutor_{ + std::make_unique(128)}; +}; + +// Verifies that when splitOutput_ is false, trailing input rows that have no +// lookup matches are included in the current output batch rather than being +// emitted in a separate batch via produceRemainingOutputForLeftJoin. +TEST_P(IndexLookupJoinTest, leftJoinTrailingMissesWithNoSplitOutput) { + IndexTableData tableData; + generateIndexTableData({500, 1, 1}, tableData, pool_); + + struct { + int numProbeBatches; + int numRowsPerProbeBatch; + int maxBatchRows; + int equalMatchPct; + bool splitOutput; + + std::string debugString() const { + return fmt::format( + "numProbeBatches: {}, numRowsPerProbeBatch: {}, maxBatchRows: {}, equalMatchPct: {}, splitOutput: {}", + numProbeBatches, + numRowsPerProbeBatch, + maxBatchRows, + equalMatchPct, + splitOutput); + } + } testSettings[] = { + // With splitOutput=false, trailing misses should be folded into the + // current batch. With splitOutput=true, they are emitted separately. + {10, 100, 200, 10, false}, + {10, 100, 200, 50, false}, + {10, 100, 200, 2, false}, + {1, 500, 1000, 10, false}, + {1, 500, 1000, 50, false}, + {10, 50, 200, 10, false}, + // With no matches, all rows are misses. + {10, 100, 200, 0, false}, + // 100% matches - no trailing misses exist. + {10, 100, 200, 100, false}, + // splitOutput=true as comparison. + {10, 100, 200, 10, true}, + {10, 100, 200, 50, true}, + {10, 100, 200, 0, true}, + }; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + const auto probeVectors = generateProbeInput( + testData.numProbeBatches, + testData.numRowsPerProbeBatch, + 1, + tableData, + pool_, + {"t0", "t1", "t2"}, + GetParam().hasNullKeys, + {}, + {}, + testData.equalMatchPct); + std::vector> probeFiles = + createProbeFiles(probeVectors); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableVectors}); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, + tableData.keyVectors, + tableData.valueVectors, + *pool()); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); + auto planNodeIdGenerator = std::make_shared(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType({"u0", "u1", "u2", "u5"}), + makeIndexColumnHandles({"u0", "u1", "u2", "u5"})); + + auto plan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + /*filter=*/"", + /*hasMarker=*/false, + core::JoinType::kLeft, + {"t4", "u5"}); + AssertQueryBuilder queryBuilder(duckDbQueryRunner_); + queryBuilder.plan(plan) + .config( + core::QueryConfig::kIndexLookupJoinMaxPrefetchBatches, + std::to_string(GetParam().numPrefetches)) + .config( + core::QueryConfig::kPreferredOutputBatchRows, + std::to_string(testData.maxBatchRows)) + .config( + core::QueryConfig::kPreferredOutputBatchBytes, + std::to_string(1ULL << 30)) + .config( + core::QueryConfig::kIndexLookupJoinSplitOutput, + testData.splitOutput ? "true" : "false") + .splits(probeScanNodeId_, makeHiveConnectorSplits(probeFiles)) + .serialExecution(GetParam().serialExecution) + .barrierExecution(GetParam().serialExecution); + if (GetParam().needsIndexSplit) { + queryBuilder.split( + indexScanNodeId_, + Split( + std::make_shared( + kTestIndexConnectorName))); + } + const auto task = queryBuilder.assertResults( + "SELECT t.c4, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"); + + // Verify match column correctness for all cases. + const auto probeScanId = probeScanNodeId_; + auto planWithMatchColumn = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + /*filter=*/"", + /*hasMarker=*/true, + core::JoinType::kLeft, + {"t4", "u5"}); + verifyResultWithMatchColumn( + plan, + probeScanId, + planWithMatchColumn, + probeScanNodeId_, + probeFiles, + GetParam().needsIndexSplit); + } +} + +DEBUG_ONLY_TEST_P(IndexLookupJoinTest, runtimeStats) { + IndexTableData tableData; + generateIndexTableData({100, 1, 1}, tableData, pool_); + const int numProbeBatches{2}; + const std::vector probeVectors = generateProbeInput( + numProbeBatches, + 100, + 1, + tableData, + pool_, + {"t0", "t1", "t2"}, + GetParam().hasNullKeys, + {}, + {}, + 100); + std::vector> probeFiles = + createProbeFiles(probeVectors); + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableVectors}); + + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::test::TestIndexSource::ResultIterator::asyncLookup", + std::function([&](void*) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); // NOLINT + })); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, + tableData.keyVectors, + tableData.valueVectors, + *pool()); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); + auto planNodeIdGenerator = std::make_shared(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType({"u0", "u1", "u2", "u3", "u5"}), + makeIndexColumnHandles({"u0", "u1", "u2", "u3", "u5"})); + + auto plan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + /*filter=*/"", + /*hasMarker=*/false, + core::JoinType::kInner, + {"u3", "t5"}); + auto task = runLookupQuery( + plan, + probeFiles, + GetParam().serialExecution, + GetParam().serialExecution, + 100, + 0, + GetParam().needsIndexSplit, + "SELECT u.c3, t.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"); + + auto taskStats = toPlanStats(task->taskStats()); + + // Check IndexSource stats - lookup timing should be here, not on join stats. + auto& indexSourceStats = taskStats.at(indexScanNodeId_); + ASSERT_EQ(indexSourceStats.addInputTiming.count, numProbeBatches); + ASSERT_GT(indexSourceStats.addInputTiming.cpuNanos, 0); + ASSERT_GT(indexSourceStats.addInputTiming.wallNanos, 0); + + // Verify that backgroundTiming was cleared from join stats (moved to + // IndexSource to avoid double counting). + auto& operatorStats = taskStats.at(joinNodeId_); + ASSERT_EQ(operatorStats.backgroundTiming.count, 0); + ASSERT_EQ(operatorStats.backgroundTiming.cpuNanos, 0); + ASSERT_EQ(operatorStats.backgroundTiming.wallNanos, 0); + + // Check runtime stats are present on IndexSource. + auto runtimeStats = indexSourceStats.customStats; + ASSERT_EQ( + runtimeStats.at(std::string(IndexLookupJoin::kConnectorLookupWallTime)) + .count, + numProbeBatches); + ASSERT_GT( + runtimeStats.at(std::string(IndexLookupJoin::kConnectorLookupWallTime)) + .sum, + 0); + ASSERT_EQ( + runtimeStats.at(std::string(IndexLookupJoin::kClientLookupWaitWallTime)) + .count, + numProbeBatches); + ASSERT_GT( + runtimeStats.at(std::string(IndexLookupJoin::kClientLookupWaitWallTime)) + .sum, + 0); + ASSERT_EQ( + runtimeStats.at(std::string(IndexLookupJoin::kConnectorResultPrepareTime)) + .count, + numProbeBatches); + ASSERT_GT( + runtimeStats.at(std::string(IndexLookupJoin::kConnectorResultPrepareTime)) + .sum, + 0); + ASSERT_EQ( + runtimeStats.count( + std::string(IndexLookupJoin::kClientRequestProcessTime)), + 0); + ASSERT_EQ( + runtimeStats.count( + std::string(IndexLookupJoin::kClientResultProcessTime)), + 0); + ASSERT_EQ( + runtimeStats.count(std::string(IndexLookupJoin::kClientLookupResultSize)), + 0); + ASSERT_EQ( + runtimeStats.count( + std::string(IndexLookupJoin::kClientLookupResultRawSize)), + 0); + ASSERT_THAT( + indexSourceStats.toString(true, true), + testing::MatchesRegex(".*Runtime stats.*connectorLookupWallNanos:.*")); + ASSERT_THAT( + indexSourceStats.toString(true, true), + testing::MatchesRegex(".*Runtime stats.*clientlookupWaitWallNanos.*")); + ASSERT_THAT( + indexSourceStats.toString(true, true), + testing::MatchesRegex( + ".*Runtime stats.*connectorResultPrepareCpuNanos.*")); +} + +/// Verifies that IndexLookupJoin's StatsSplitter correctly reports separate +/// operator stats for both the IndexLookupJoin node and the IndexSource node. +/// This ensures IndexSource appears with its own CPU/Scheduled/Output stats +/// in the query plan visualization. +DEBUG_ONLY_TEST_P(IndexLookupJoinTest, statsSplitter) { + IndexTableData tableData; + generateIndexTableData({100, 1, 1}, tableData, pool_); + const int numProbeBatches{2}; + const int batchSize{100}; + const std::vector probeVectors = generateProbeInput( + numProbeBatches, + batchSize, + 1, + tableData, + pool_, + {"t0", "t1", "t2"}, + GetParam().hasNullKeys, + {}, + {}, + 100); + std::vector> probeFiles = + createProbeFiles(probeVectors); + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableVectors}); + + // Add a small delay in async lookup to ensure timing stats are captured. + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::test::TestIndexSource::ResultIterator::asyncLookup", + std::function([&](void*) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); // NOLINT + })); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, + tableData.keyVectors, + tableData.valueVectors, + *pool()); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); + auto planNodeIdGenerator = std::make_shared(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType({"u0", "u1", "u2", "u3", "u5"}), + makeIndexColumnHandles({"u0", "u1", "u2", "u3", "u5"})); + + auto plan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + /*filter=*/"", + /*hasMarker=*/false, + core::JoinType::kInner, + {"u3", "t5"}); + auto task = runLookupQuery( + plan, + probeFiles, + GetParam().serialExecution, + GetParam().serialExecution, + 100, + 0, + GetParam().needsIndexSplit, + "SELECT u.c3, t.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"); + + auto taskStats = toPlanStats(task->taskStats()); + + // Verify that both the IndexLookupJoin node and IndexSource node have stats. + ASSERT_TRUE(taskStats.count(joinNodeId_) > 0) + << "IndexLookupJoin node stats missing"; + ASSERT_TRUE(taskStats.count(indexScanNodeId_) > 0) + << "IndexSource node stats missing"; + + const auto& joinStats = taskStats.at(joinNodeId_); + const auto& indexSourceStats = taskStats.at(indexScanNodeId_); + + // Verify join stats have input from probe side. + EXPECT_GT(joinStats.inputRows, 0); + EXPECT_GT(joinStats.outputRows, 0); + + // Verify IndexSource stats have output positions (the lookup results). + EXPECT_GT(indexSourceStats.outputRows, 0); + EXPECT_EQ(indexSourceStats.outputRows, joinStats.outputRows); + + // Verify IndexSource stats have input positions (lookup keys sent to + // connector). + EXPECT_GT(indexSourceStats.inputRows, 0); + EXPECT_GT(indexSourceStats.inputBytes, 0); + // For inner join without filter, input rows should match join input rows + // (all probe rows are sent as lookup keys). + EXPECT_EQ(indexSourceStats.inputRows, joinStats.inputRows); + + // Verify IndexSource stats have timing from backgroundTiming (lookup time). + // The addInputTiming should contain the lookup wall/cpu time. + EXPECT_GT(indexSourceStats.addInputTiming.count, 0); + + // Verify runtime stats are present on IndexSource (connector metrics). + // These include connector lookup wall time, etc. + EXPECT_TRUE( + indexSourceStats.customStats.count( + std::string(IndexLookupJoin::kConnectorLookupWallTime)) > 0) + << "IndexSource should have connector lookup wall time"; + EXPECT_GT( + indexSourceStats.customStats + .at(std::string(IndexLookupJoin::kConnectorLookupWallTime)) + .sum, + 0); + + // Verify that backgroundTiming was cleared from join stats (moved to + // IndexSource to avoid double counting). + EXPECT_EQ(joinStats.backgroundTiming.count, 0); + EXPECT_EQ(joinStats.backgroundTiming.cpuNanos, 0); + EXPECT_EQ(joinStats.backgroundTiming.wallNanos, 0); + + // Verify that index source stats are present in the raw operator-level + // Verify index source stats are copied into the operator's runtimeStats + // with the "indexSource." prefix for task-level stats visibility. + const auto rawTaskStats = task->taskStats(); + bool foundJoinOp = false; + for (const auto& pipeline : rawTaskStats.pipelineStats) { + for (const auto& opStats : pipeline.operatorStats) { + if (opStats.operatorType == "IndexLookupJoin") { + foundJoinOp = true; + EXPECT_TRUE(opStats.runtimeStats.count( + fmt::format( + "indexSource.{}", IndexLookupJoin::kConnectorLookupWallTime))) + << "IndexLookupJoin operator runtimeStats should contain index " + "source stats for task-level stats reporting"; + break; + } + } + if (foundJoinOp) { + break; + } + } + EXPECT_TRUE(foundJoinOp) << "IndexLookupJoin operator not found in raw stats"; + + // After splitStats (via toPlanStats above), the join node's customStats + // should NOT contain index source stats — they belong to IndexSource. + EXPECT_EQ( + joinStats.customStats.count( + std::string(IndexLookupJoin::kConnectorLookupWallTime)), + 0) + << "Join node should not have index source stats after splitStats"; +} + +/// Verifies that IndexSource stats report rows BEFORE the join filter is +/// applied, while IndexLookupJoin stats report rows AFTER the filter. +DEBUG_ONLY_TEST_P(IndexLookupJoinTest, statsSplitterWithFilter) { + // Skip serial execution tests for simplicity - the stats behavior is the + // same. + if (GetParam().serialExecution || GetParam().hasNullKeys) { + return; + } + + IndexTableData tableData; + generateIndexTableData({100, 1, 1}, tableData, pool_); + const int numProbeBatches{2}; + const int batchSize{100}; + const std::vector probeVectors = generateProbeInput( + numProbeBatches, + batchSize, + 1, + tableData, + pool_, + {"t0", "t1", "t2"}, + /*hasNullKeys=*/false, + {}, + {}, + 100); + std::vector> probeFiles = + createProbeFiles(probeVectors); + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableVectors}); + + // Add a small delay in async lookup to ensure timing stats are captured. + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::test::TestIndexSource::ResultIterator::asyncLookup", + std::function([&](void*) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); // NOLINT + })); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, + tableData.keyVectors, + tableData.valueVectors, + *pool()); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); + auto planNodeIdGenerator = std::make_shared(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType({"u0", "u1", "u2", "u3", "u5"}), + makeIndexColumnHandles({"u0", "u1", "u2", "u3", "u5"})); + + // Add a filter that should filter out approximately half the rows. + // The filter "u3 % 2 = 0" will keep only even values of u3. + auto plan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + /*filter=*/"u3 % 2 = 0", + /*hasMarker=*/false, + core::JoinType::kInner, + {"u3", "t5"}); + auto task = runLookupQuery( + plan, + probeFiles, + GetParam().serialExecution, + GetParam().serialExecution, + 100, + 0, + GetParam().needsIndexSplit, + "SELECT u.c3, t.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2 AND u.c3 % 2 = 0"); + + auto taskStats = toPlanStats(task->taskStats()); + + // Verify that both the IndexLookupJoin node and IndexSource node have stats. + ASSERT_TRUE(taskStats.count(joinNodeId_) > 0) + << "IndexLookupJoin node stats missing"; + ASSERT_TRUE(taskStats.count(indexScanNodeId_) > 0) + << "IndexSource node stats missing"; + + const auto& joinStats = taskStats.at(joinNodeId_); + const auto& indexSourceStats = taskStats.at(indexScanNodeId_); + + // Verify join stats have input from probe side. + EXPECT_GT(joinStats.inputRows, 0); + EXPECT_GT(joinStats.outputRows, 0); + + // Verify IndexSource stats have output positions (rows before filter). + EXPECT_GT(indexSourceStats.outputRows, 0); + + // KEY ASSERTION: IndexSource should have MORE rows than IndexLookupJoin + // because IndexSource reports rows BEFORE the filter ("u3 % 2 = 0") is + // applied. + EXPECT_GT(indexSourceStats.outputRows, joinStats.outputRows) + << "IndexSource should report more rows (before filter) than " + << "IndexLookupJoin (after filter)"; +} + +TEST_P(IndexLookupJoinTest, DISABLED_barrier) { + if (GetParam().needsIndexSplit) { + return; + } + IndexTableData tableData; + generateIndexTableData({100, 1, 1}, tableData, pool_); + const int numProbeSplits{5}; + const auto probeVectors = generateProbeInput( + numProbeSplits, + 256, + 1, + tableData, + pool_, + {"t0", "t1", "t2"}, + GetParam().hasNullKeys, + {}, + {}, + 100); + std::vector> probeFiles = + createProbeFiles(probeVectors); + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableVectors}); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, + tableData.keyVectors, + tableData.valueVectors, + *pool()); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); + auto planNodeIdGenerator = std::make_shared(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType({"u0", "u1", "u2", "u3", "u5"}), + makeIndexColumnHandles({"u0", "u1", "u2", "u3", "u5"})); + + auto plan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + /*filter=*/"", + /*hasMarker=*/false, + core::JoinType::kInner, + {"u3", "t5"}); + + struct { + int numPrefetches; + bool barrierExecution; + bool serialExecution; + + std::string debugString() const { + return fmt::format( + "numPrefetches {}, barrierExecution {}, serialExecution {}", + numPrefetches, + barrierExecution, + serialExecution); + } + } testSettings[] = { + {0, false, false}, + {0, false, true}, + {1, true, true}, + {1, false, true}, + {4, true, true}, + {4, false, true}, + {256, true, true}, + {256, false, true}, + {0, true, false}, + {0, false, false}, + {1, true, false}, + {1, false, false}, + {4, true, false}, + {4, false, false}, + {256, true, false}, + {256, false, false}, + }; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + auto task = runLookupQuery( + plan, + probeFiles, + testData.serialExecution, + testData.barrierExecution, + 32, + testData.numPrefetches, + GetParam().needsIndexSplit, + "SELECT u.c3, t.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"); + + const auto taskStats = task->taskStats(); + ASSERT_EQ( + taskStats.numBarriers, testData.barrierExecution ? numProbeSplits : 0); + ASSERT_EQ(taskStats.numFinishedSplits, numProbeSplits); + } +} + +TEST_P(IndexLookupJoinTest, nullKeys) { + IndexTableData tableData; + generateIndexTableData({100, 1, 1}, tableData, pool_); + const int numProbeSplits{5}; + const int probeBatchSize{256}; + const auto probeVectors = generateProbeInput( + numProbeSplits, + probeBatchSize, + 1, + tableData, + pool_, + {"t0", "t1", "t2"}, + /*hasNullKeys=*/true, + {}, + {}, + /*equalMatchPct=*/100); + // Set some probe key vector to all nulls to trigger the case that entire + // probe input is skipped. + for (int i = 0; i < numProbeSplits; i += 2) { + for (int row = 0; row < probeBatchSize; ++row) { + probeVectors[i]->childAt(i % keyType_->size())->setNull(row, true); + } + } + std::vector> probeFiles = + createProbeFiles(probeVectors); + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableVectors}); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, + tableData.keyVectors, + tableData.valueVectors, + *pool()); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); + auto planNodeIdGenerator = std::make_shared(); + std::unordered_map> + columnHandles; + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType({"u0", "u1", "u2", "u3", "u5"}), + makeIndexColumnHandles({"u0", "u1", "u2", "u3", "u5"})); + + const auto innerPlan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + /*filter=*/"", + /*hasMarker=*/false, + core::JoinType::kInner, + {"u3", "t5"}); + + runLookupQuery( + innerPlan, + probeFiles, + /*serialExecution=*/GetParam().serialExecution, + /*barrierExecution=*/GetParam().serialExecution, + 32, + GetParam().numPrefetches, + GetParam().needsIndexSplit, + "SELECT u.c3, t.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"); + + const auto leftPlan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + /*filter=*/"", + /*hasMarker=*/false, + core::JoinType::kLeft, + {"u3", "t5"}); + + runLookupQuery( + leftPlan, + probeFiles, + /*serialExecution=*/GetParam().serialExecution, + /*barrierExecution=*/GetParam().serialExecution, + 32, + GetParam().numPrefetches, + GetParam().needsIndexSplit, + "SELECT u.c3, t.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"); + + const auto probeScanId = probeScanNodeId_; + auto planWithMatchColumn = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + /*filter=*/"", + /*hasMarker=*/true, + core::JoinType::kLeft, + {"u3", "t5"}); + verifyResultWithMatchColumn( + leftPlan, + probeScanId, + planWithMatchColumn, + probeScanNodeId_, + probeFiles, + GetParam().needsIndexSplit); +} + +TEST_P(IndexLookupJoinTest, joinFuzzer) { + IndexTableData tableData; + generateIndexTableData({1024, 1, 1}, tableData, pool_); + const auto probeVectors = generateProbeInput( + 50, 256, 1, tableData, pool_, {"t0", "t1", "t2"}, GetParam().hasNullKeys); + std::vector> probeFiles = + createProbeFiles(probeVectors); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableVectors}); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/1, + tableData.keyVectors, + tableData.valueVectors, + *pool()); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); + auto planNodeIdGenerator = std::make_shared(); + auto scanOutput = tableType_->names(); + std::random_device rd; + std::mt19937 g(rd()); + std::shuffle(scanOutput.begin(), scanOutput.end(), g); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType(scanOutput), + makeIndexColumnHandles(scanOutput)); + + auto plan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0"}, + {"u0"}, + {"contains(t4, u1)", "u2 between t1 and t2"}, + /*filter=*/"", + /*hasMarker=*/false, + core::JoinType::kInner, + {"u0", "u4", "t0", "t1", "t4"}); + runLookupQuery( + plan, + probeFiles, + GetParam().serialExecution, + GetParam().serialExecution, + 32, + GetParam().numPrefetches, + GetParam().needsIndexSplit, + "SELECT u.c0, u.c1, u.c2, u.c3, u.c4, u.c5, t.c0, t.c1, t.c2, t.c3, t.c4, t.c5 FROM t, u WHERE t.c0 = u.c0 AND array_contains(t.c4, u.c1) AND u.c2 BETWEEN t.c1 AND t.c2"); +} + +TEST_P(IndexLookupJoinTest, tableRowsWithDuplicateKeys) { + IndexTableData tableData; + generateIndexTableData({10, 1, 1}, tableData, pool_); + for (int i = 0; i < keyType_->size(); ++i) { + tableData.keyVectors->childAt(i) = makeFlatVector( + tableData.keyVectors->childAt(i)->size(), + [](auto /*unused*/) { return 1; }); + tableData.tableVectors->childAt(i) = makeFlatVector( + tableData.keyVectors->childAt(i)->size(), + [](auto /*unused*/) { return 1; }); + } + + auto probeVectors = generateProbeInput( + 4, 32, 1, tableData, pool_, {"t0", "t1", "t2"}, false, {}, {}, 100); + std::vector> probeFiles = + createProbeFiles(probeVectors); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableVectors}); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, + tableData.keyVectors, + tableData.valueVectors, + *pool()); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); + auto planNodeIdGenerator = std::make_shared(); + auto scanOutput = tableType_->names(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType(scanOutput), + makeIndexColumnHandles(scanOutput)); + + auto plan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + /*filter=*/"", + /*hasMarker=*/false, + core::JoinType::kInner, + scanOutput); + runLookupQuery( + plan, + probeFiles, + GetParam().serialExecution, + GetParam().serialExecution, + 32, + GetParam().numPrefetches, + GetParam().needsIndexSplit, + "SELECT u.c0, u.c1, u.c2, u.c3, u.c4, u.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1 AND u.c2 = t.c2"); +} + +TEST_P(IndexLookupJoinTest, withFilter) { + struct { + std::vector keyCardinalities; + int numProbeBatches; + int numRowsPerProbeBatch; + int matchPct; + std::vector scanOutputColumns; + std::vector outputColumns; + core::JoinType joinType; + std::string filter; + std::string duckDbVerifySql; + + std::string debugString() const { + return fmt::format( + "keyCardinalities: {}, numProbeBatches: {}, numRowsPerProbeBatch: {}, matchPct: {}, " + "scanOutputColumns: {}, outputColumns: {}, joinType: {}, filter: {}, " + "duckDbVerifySql: {}", + folly::join(",", keyCardinalities), + numProbeBatches, + numRowsPerProbeBatch, + matchPct, + folly::join(",", scanOutputColumns), + folly::join(",", outputColumns), + core::JoinTypeName::toName(joinType), + filter, + duckDbVerifySql); + } + } testSettings[] = { + // Inner join with filter on probe side + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "t3 % 2 = 0", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c3 % 2 = 0"}, + // Inner join with filter always be true. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "t3 = t3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c3 = t.c3"}, + // Inner join with filter always be false. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "t3 != t3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c3 != t.c3"}, + + // Inner join with filter on lookup side + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "u3 % 2 = 0", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND u.c3 % 2 = 0"}, + // Inner join with filter always be true. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "u3 = u3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND u.c3 = u.c3"}, + // Inner join with filter always be false. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "u3 != u3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND u.c3 != u.c3"}, + + // Inner join with filter on both side + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "u3 % 2 = 0 AND t3 % 2 = 0", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND u.c3 % 2 = 0 AND t.c3 % 2 = 0"}, + // Inner join with filter always be true. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "u3 = u3 AND t3 = t3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND u.c3 = u.c3 AND t.c3 = t.c3"}, + // Inner join with filter always be false. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "u3 != u3 AND t3 != t3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND u.c3 != u.c3 AND t.c3 != t.c3"}, + + // Left join with filter on probe side + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "t3 % 2 = 0", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND t.c3 % 2 = 0"}, + // Left join with filter always be true. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "t3 = t3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND t.c3 = t.c3"}, + // Inner join with filter always be false. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "t3 != t3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND t.c3 != t.c3"}, + + // Left join with filter on lookup side + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "u3 % 2 = 0", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND u.c3 % 2 = 0"}, + // Inner join with filter always be true. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "u3 = u3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND u.c3 = u.c3"}, + // Left join with filter always be false. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "u3 != u3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND u.c3 != u.c3"}, + + // Left join with filter on both side + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "u3 % 2 = 0 AND t3 % 2 = 0", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND u.c3 % 2 = 0 AND t.c3 % 2 = 0"}, + // Left join with filter always be true. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "u3 = u3 AND t3 = t3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND u.c3 = u.c3 AND t.c3 = t.c3"}, + // Left join with filter always be false. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "u3 != u3 AND t3 != t3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND u.c3 != u.c3 AND t.c3 != t.c3"}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + IndexTableData tableData; + generateIndexTableData(testData.keyCardinalities, tableData, pool_); + auto probeVectors = generateProbeInput( + testData.numProbeBatches, + testData.numRowsPerProbeBatch, + 1, + tableData, + pool_, + {"t0", "t1", "t2"}, + GetParam().hasNullKeys, + {}, + {}, + testData.matchPct); + std::vector> probeFiles = + createProbeFiles(probeVectors); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableVectors}); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, + tableData.keyVectors, + tableData.valueVectors, + *pool()); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); + auto planNodeIdGenerator = std::make_shared(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType(testData.scanOutputColumns), + makeIndexColumnHandles(testData.scanOutputColumns)); + + // Create a plan with filter + auto plan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + testData.filter, + /*hasMarker=*/false, + testData.joinType, + testData.outputColumns); + + runLookupQuery( + plan, + probeFiles, + GetParam().serialExecution, + GetParam().serialExecution, + 32, + GetParam().numPrefetches, + GetParam().needsIndexSplit, + testData.duckDbVerifySql); + + if (testData.joinType != core::JoinType::kLeft) { + continue; + } + const auto probeScanId = probeScanNodeId_; + auto planWithMatchColumn = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + testData.filter, + /*hasMarker=*/true, + testData.joinType, + testData.outputColumns); + verifyResultWithMatchColumn( + plan, + probeScanId, + planWithMatchColumn, + probeScanNodeId_, + probeFiles, + GetParam().needsIndexSplit); + } +} + +TEST_P(IndexLookupJoinTest, mixedFilterBatches) { + // Create IndexTableData using VectorTestBase utilities + IndexTableData tableData; + + const std::string dummyString("test"); + StringView dummyStringView(dummyString); + // Create table key data (u0, u1, u2) using makeFlatVector + auto u0 = makeFlatVector(64, [&](auto row) { return row % 8; }); + auto u1 = makeFlatVector(64, [&](auto row) { return row % 8; }); + auto u2 = makeFlatVector(64, [&](auto row) { return row % 8; }); + tableData.keyVectors = makeRowVector({"u0", "u1", "u2"}, {u0, u1, u2}); + + // Create table value data (u3, u4, u5) using makeFlatVector + auto u3 = makeFlatVector(64, [&](auto row) { return row; }); + auto u4 = makeFlatVector(64, [&](auto row) { return row; }); + auto u5 = makeFlatVector( + 64, [&](auto /*unused*/) { return dummyStringView; }); + tableData.valueVectors = makeRowVector({"u3", "u4", "u5"}, {u3, u4, u5}); + + // Create complete table data by combining key and value data + tableData.tableVectors = makeRowVector( + {"u0", "u1", "u2", "u3", "u4", "u5"}, {u0, u1, u2, u3, u4, u5}); + + // Create probe vectors using makeArrayVectorFromJson in a loop + std::vector probeVectors; + probeVectors.reserve(5); + for (int i = 0; i < 5; ++i) { + probeVectors.push_back(makeRowVector( + {"t0", "t1", "t2", "t3", "t4", "t5"}, + {makeFlatVector(128, [&](auto row) { return row; }), + makeFlatVector(128, [&](auto row) { return row; }), + makeFlatVector(128, [&](auto row) { return row; }), + makeFlatVector(128, [&](auto row) { return row; }), + makeArrayVector( + 128, + [](vector_size_t /*unused*/) { return 1; }, + [](vector_size_t, vector_size_t) { return 1; }), + makeFlatVector( + 128, [&](auto /*unused*/) { return dummyStringView; })})); + } + + std::vector> probeFiles = + createProbeFiles(probeVectors); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableVectors}); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, + tableData.keyVectors, + tableData.valueVectors, + *pool()); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); + auto planNodeIdGenerator = std::make_shared(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType({"u0", "u1", "u2", "u3", "u5"}), + makeIndexColumnHandles({"u0", "u1", "u2", "u3", "u5"})); + + auto plan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + "t3 > 4", + /*hasMarker=*/false, + core::JoinType::kLeft, + {"t1", "u1", "u2", "u3", "u5"}); + + AssertQueryBuilder queryBuilder(duckDbQueryRunner_); + queryBuilder.plan(plan) + .config( + core::QueryConfig::kIndexLookupJoinMaxPrefetchBatches, + std::to_string(GetParam().numPrefetches)) + .config(core::QueryConfig::kPreferredOutputBatchRows, "4") + .config(core::QueryConfig::kIndexLookupJoinSplitOutput, "true") + .splits(probeScanNodeId_, makeHiveConnectorSplits(probeFiles)) + .serialExecution(GetParam().serialExecution) + .barrierExecution(GetParam().serialExecution); + if (GetParam().needsIndexSplit) { + queryBuilder.split( + indexScanNodeId_, + Split( + std::make_shared( + kTestIndexConnectorName))); + } + queryBuilder.assertResults( + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2 AND t.c3 > 4"); +} + +// Tests the index split handling behavior of the IndexLookupJoin operator. +// When needsIndexSplit is true: +// - The operator blocks waiting for splits until it receives them +// - Works correctly with various split counts (0, 1, 2, 3 splits) +// - With 0 splits (partition pruning), LEFT JOIN emits probe rows with nulls +// This test only runs when GetParam().needsIndexSplit is true. +DEBUG_ONLY_TEST_P(IndexLookupJoinTest, needsIndexSplit) { + if (!GetParam().needsIndexSplit) { + return; + } + keyType_ = ROW({"u0"}, {BIGINT()}); + valueType_ = ROW({"u1", "u2"}, {BIGINT(), VARCHAR()}); + tableType_ = concat(keyType_, valueType_); + probeType_ = ROW({"t0", "t1", "t2"}, {BIGINT(), BIGINT(), VARCHAR()}); + + IndexTableData tableData; + generateIndexTableData({100}, tableData, pool_); + const auto probeVectors = + generateProbeInput(3, 100, 1, tableData, pool_, {"t0"}); + const auto probeFiles = createProbeFiles(probeVectors); + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableVectors}); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/1, + tableData.keyVectors, + tableData.valueVectors, + *pool()); + + const auto duckDbVerifySql = + "SELECT t.c0, t.c1, u.c1, u.c2 FROM t LEFT JOIN u ON t.c0 = u.c0"; + + struct { + int numIndexSplits; + + std::string debugString() const { + return fmt::format("numIndexSplits: {}", numIndexSplits); + } + } testSettings[] = { + {1}, + {2}, + {3}, + {0}, + }; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + // Create index table handle that requires splits. + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, /*needsIndexSplit=*/true); + auto planNodeIdGenerator = std::make_shared(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType({"u0", "u1", "u2"}), + makeIndexColumnHandles({"u0", "u1", "u2"})); + auto plan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0"}, + {"u0"}, + {}, + "", + /*hasMarker=*/false, + core::JoinType::kLeft, + {"t0", "t1", "u1", "u2"}); + + // Track the number of times collectIndexSplits is called. + std::atomic_int collectSplitCallCount{0}; + folly::EventCount waitCollectSplit; + std::atomic_bool waitCollectSplitFlag{true}; + + std::mutex mutex; + std::shared_ptr task; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::IndexLookupJoin::collectIndexSplits", + std::function( + [&](const IndexLookupJoin* op) { + { + std::lock_guard lock(mutex); + if (task == nullptr) { + task = op->operatorCtx()->task(); + // Signal that we've entered collectIndexSplits for the first + // time. + waitCollectSplitFlag = false; + waitCollectSplit.notifyAll(); + } + } + ++collectSplitCallCount; + })); + + // For 0 splits, the LEFT JOIN should produce all probe rows with nulls + // on the lookup side. + const auto expectedSql = testData.numIndexSplits == 0 + ? "SELECT t.c0, t.c1, NULL, NULL FROM t" + : duckDbVerifySql; + + // Run the query in a separate thread without providing index splits + // upfront. The main thread will provide splits after the task starts. + std::thread queryThread([&] { + AssertQueryBuilder queryBuilder(duckDbQueryRunner_); + queryBuilder.plan(plan) + .splits(probeScanNodeId_, makeHiveConnectorSplits(probeFiles)) + .assertResults(expectedSql); + }); + + // Wait for collectIndexSplits to be called. + waitCollectSplit.await([&] { return !waitCollectSplitFlag.load(); }); + // Wait for 1 second and expect the task to NOT finish since it's + // waiting for splits. + std::this_thread::sleep_for(std::chrono::seconds(1)); // NOLINT + { + std::lock_guard lock(mutex); + ASSERT_NE(task, nullptr); + ASSERT_EQ(task->state(), TaskState::kRunning) + << "Task should still be running while waiting for splits"; + + for (int i = 0; i < testData.numIndexSplits; ++i) { + task->addSplit( + indexScanNodeId_, + Split( + std::make_shared( + kTestIndexConnectorName))); + } + task->noMoreSplits(indexScanNodeId_); + } + + queryThread.join(); + + // Verify collectIndexSplits was called the expected number of times: + // once initially when blocked, plus once after splits are available. + ASSERT_EQ(collectSplitCallCount.load(), 2) + << "collectIndexSplits should be called once initially (blocked), " + "then once when splits are available"; + } +} + +// Tests that when needsIndexSplit is false, the operator does NOT call +// collectIndexSplits. The query should complete successfully without waiting +// for any index splits. +// This test only runs when GetParam().needsIndexSplit is false. +DEBUG_ONLY_TEST_P(IndexLookupJoinTest, noNeedsIndexSplitNoCollect) { + if (GetParam().needsIndexSplit) { + return; + } + + keyType_ = ROW({"u0"}, {BIGINT()}); + valueType_ = ROW({"u1", "u2"}, {BIGINT(), VARCHAR()}); + tableType_ = concat(keyType_, valueType_); + probeType_ = ROW({"t0", "t1", "t2"}, {BIGINT(), BIGINT(), VARCHAR()}); + + IndexTableData tableData; + generateIndexTableData({100}, tableData, pool_); + const auto probeVectors = + generateProbeInput(3, 100, 1, tableData, pool_, {"t0"}); + const auto probeFiles = createProbeFiles(probeVectors); + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableVectors}); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/1, + tableData.keyVectors, + tableData.valueVectors, + *pool()); + + const auto duckDbVerifySql = + "SELECT t.c0, t.c1, u.c1, u.c2 FROM t LEFT JOIN u ON t.c0 = u.c0"; + + // Create index table handle with needsIndexSplit=false. + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, /*needsIndexSplit=*/false); + auto planNodeIdGenerator = std::make_shared(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType({"u0", "u1", "u2"}), + makeIndexColumnHandles({"u0", "u1", "u2"})); + auto plan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0"}, + {"u0"}, + {}, + "", + /*hasMarker=*/false, + core::JoinType::kLeft, + {"t0", "t1", "u1", "u2"}); + + // Track if collectIndexSplits is ever called (it should NOT be). + std::atomic_bool collectSplitsCalled{false}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::IndexLookupJoin::collectIndexSplits", + std::function( + [&](const IndexLookupJoin* /*op*/) { collectSplitsCalled = true; })); + + // Run the query - it should complete without waiting for splits. + AssertQueryBuilder(duckDbQueryRunner_) + .plan(plan) + .splits(probeScanNodeId_, makeHiveConnectorSplits(probeFiles)) + .assertResults(duckDbVerifySql); + + // Verify collectIndexSplits was never called since needsIndexSplit=false. + ASSERT_FALSE(collectSplitsCalled.load()) + << "collectIndexSplits should not be called when needsIndexSplit is " + "false"; +} + +// Tests that when needsIndexSplit is false, adding splits or signaling +// no-more-splits should fail because the index scan node is not registered +// for split collection. +// This test only runs when GetParam().needsIndexSplit is false. +DEBUG_ONLY_TEST_P(IndexLookupJoinTest, noNeedsIndexSplitSplitOperationFails) { + if (GetParam().needsIndexSplit) { + return; + } + + keyType_ = ROW({"u0"}, {BIGINT()}); + valueType_ = ROW({"u1", "u2"}, {BIGINT(), VARCHAR()}); + tableType_ = concat(keyType_, valueType_); + probeType_ = ROW({"t0", "t1", "t2"}, {BIGINT(), BIGINT(), VARCHAR()}); + + IndexTableData tableData; + generateIndexTableData({100}, tableData, pool_); + const auto probeVectors = + generateProbeInput(3, 100, 1, tableData, pool_, {"t0"}); + const auto probeFiles = createProbeFiles(probeVectors); + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableVectors}); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/1, + tableData.keyVectors, + tableData.valueVectors, + *pool()); + + const auto duckDbVerifySql = + "SELECT t.c0, t.c1, u.c1, u.c2 FROM t LEFT JOIN u ON t.c0 = u.c0"; + + // Create index table handle with needsIndexSplit=false. + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, /*needsIndexSplit=*/false); + auto planNodeIdGenerator = std::make_shared(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType({"u0", "u1", "u2"}), + makeIndexColumnHandles({"u0", "u1", "u2"})); + auto plan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0"}, + {"u0"}, + {}, + "", + /*hasMarker=*/false, + core::JoinType::kLeft, + {"t0", "t1", "u1", "u2"}); + + // Test both addSplit and noMoreSplits operations. + for (bool testAddSplit : {true, false}) { + SCOPED_TRACE(fmt::format("testAddSplit: {}", testAddSplit)); + + // Use TestValue to block the Task from starting to allow us to verify + // that split operations fail. + folly::EventCount taskEnterWait; + std::atomic_bool taskEnterWaitFlag{true}; + std::shared_ptr task; + std::mutex mutex; + + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Task::enter", + std::function([&](Task* taskPtr) { + { + std::lock_guard lock(mutex); + if (task == nullptr) { + task = taskPtr->shared_from_this(); + } + } + // Block until the test thread signals to continue. + taskEnterWait.await([&] { return !taskEnterWaitFlag.load(); }); + })); + + // Run the query in a separate thread. + std::thread queryThread([&] { + AssertQueryBuilder(duckDbQueryRunner_) + .plan(plan) + .splits(probeScanNodeId_, makeHiveConnectorSplits(probeFiles)) + .assertResults(duckDbVerifySql); + }); + + // Wait a bit for the task to start and hit the TestValue. + std::this_thread::sleep_for(std::chrono::seconds(1)); // NOLINT + + { + std::lock_guard lock(mutex); + if (task != nullptr) { + if (testAddSplit) { + // Try to add a split - this should fail because the index scan node + // is not registered for split collection. + VELOX_ASSERT_THROW( + task->addSplit( + indexScanNodeId_, + Split( + std::make_shared( + kTestIndexConnectorName))), + "Splits can be associated only with leaf plan nodes which require splits. Plan node ID 0 doesn't refer to such plan node"); + } else { + // Try to signal no more splits - this should fail because the index + // scan node is not registered for split collection. + VELOX_ASSERT_THROW( + task->noMoreSplits(indexScanNodeId_), + "Splits can be associated only with leaf plan nodes which require splits. Plan node ID 0 doesn't refer to such plan node."); + } + } + } + + // Allow the query to complete. + taskEnterWaitFlag = false; + taskEnterWait.notifyAll(); + + queryThread.join(); + } +} +// Verifies that when multiple drivers are waiting for index splits (the +// collector waiting for splits from the task, followers waiting on the bridge), +// aborting the task unblocks all drivers and the task terminates cleanly. +DEBUG_ONLY_TEST_P(IndexLookupJoinTest, abortDuringIndexSplitWait) { + if (GetParam().serialExecution || !GetParam().needsIndexSplit) { + return; + } + + keyType_ = ROW({"u0"}, {BIGINT()}); + valueType_ = ROW({"u1", "u2"}, {BIGINT(), VARCHAR()}); + tableType_ = concat(keyType_, valueType_); + probeType_ = ROW({"t0", "t1", "t2"}, {BIGINT(), BIGINT(), VARCHAR()}); + + IndexTableData tableData; + generateIndexTableData({100}, tableData, pool_); + const auto probeVectors = + generateProbeInput(3, 100, 1, tableData, pool_, {"t0"}); + const auto probeFiles = createProbeFiles(probeVectors); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/1, + tableData.keyVectors, + tableData.valueVectors, + *pool()); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, /*needsIndexSplit=*/true); + auto planNodeIdGenerator = std::make_shared(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType({"u0", "u1", "u2"}), + makeIndexColumnHandles({"u0", "u1", "u2"})); + auto plan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0"}, + {"u0"}, + {}, + /*filter=*/"", + /*hasMarker=*/false, + core::JoinType::kLeft, + {"t0", "t1", "u1", "u2"}); + + const int numDrivers = 4; + + // Use TestValue to capture the task when the collector enters + // collectIndexSplits, and hold it there until the test signals. + folly::EventCount collectSplitWait; + std::atomic_bool collectSplitReached{false}; + folly::EventCount collectUnblockWait; + std::atomic_bool collectUnblockFlag{false}; + std::mutex mutex; + std::shared_ptr task; + + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::IndexLookupJoin::collectIndexSplits", + std::function( + [&](const IndexLookupJoin* op) { + { + std::lock_guard lock(mutex); + if (task == nullptr) { + task = op->operatorCtx()->task(); + collectSplitReached = true; + collectSplitWait.notifyAll(); + } + } + // Hold the collector here until the test signals. + collectUnblockWait.await([&] { return collectUnblockFlag.load(); }); + })); + + // Run the query in a separate thread. Don't provide index splits so the + // collector blocks. + std::thread queryThread([&] { + VELOX_ASSERT_THROW( + AssertQueryBuilder(duckDbQueryRunner_) + .plan(plan) + .splits(probeScanNodeId_, makeHiveConnectorSplits(probeFiles)) + .maxDrivers(numDrivers) + .copyResults(pool()), + "Aborted"); + }); + + // Wait for the collector to enter collectIndexSplits. + collectSplitWait.await([&] { return collectSplitReached.load(); }); + + // Wait until all non-collector drivers are off-thread (blocked on the + // bridge). The collector is held by the TestValue callback above. + while (true) { + int offThreadCount = 0; + task->testingVisitDrivers([&](Driver* driver) { + if (!driver->isOnThread()) { + ++offThreadCount; + } + }); + // All followers (numDrivers - 1) should be off-thread. + if (offThreadCount >= numDrivers - 1) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(10)); // NOLINT + } + + // Abort the task while all drivers are blocked. + { + std::lock_guard lock(mutex); + ASSERT_NE(task, nullptr); + ASSERT_EQ(task->state(), TaskState::kRunning); + task->requestAbort(); + } + + // Wait for all follower drivers to finish. The collector is still held by + // the TestValue callback, so only 1 driver (the collector) should remain. + while (true) { + int aliveCount = 0; + task->testingVisitDrivers([&](Driver*) { ++aliveCount; }); + if (aliveCount == 1) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(10)); // NOLINT + } + + // Release the collector from the TestValue callback. + collectUnblockFlag = true; + collectUnblockWait.notifyAll(); + + queryThread.join(); + + // Verify the task terminated properly. + ASSERT_TRUE(waitForTaskAborted(task.get())); +} + +TEST_P(IndexLookupJoinTest, multiDriverWithIndexSplits) { + if (GetParam().serialExecution || !GetParam().needsIndexSplit) { + return; + } + IndexTableData tableData; + generateIndexTableData({100, 1, 1}, tableData, pool_); + const int numProbeSplits{5}; + const auto probeVectors = generateProbeInput( + numProbeSplits, + 256, + 1, + tableData, + pool_, + {"t0", "t1", "t2"}, + GetParam().hasNullKeys, + {}, + {}, + 100); + const auto probeFiles = createProbeFiles(probeVectors); + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableVectors}); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, + tableData.keyVectors, + tableData.valueVectors, + *pool()); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); + auto planNodeIdGenerator = std::make_shared(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType({"u0", "u1", "u2", "u3", "u5"}), + makeIndexColumnHandles({"u0", "u1", "u2", "u3", "u5"})); + + for (const auto joinType : {core::JoinType::kInner, core::JoinType::kLeft}) { + SCOPED_TRACE( + fmt::format("joinType: {}", core::JoinTypeName::toName(joinType))); + const auto duckDbSql = joinType == core::JoinType::kInner + ? "SELECT u.c3, t.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2" + : "SELECT u.c3, t.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"; + auto plan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + /*filter=*/"", + /*hasMarker=*/false, + joinType, + {"u3", "t5"}); + + for (int numDrivers : {2, 4}) { + SCOPED_TRACE(fmt::format("numDrivers: {}", numDrivers)); + auto task = runLookupQuery( + plan, + probeFiles, + /*serialExecution=*/false, + /*barrierExecution=*/false, + 100, + GetParam().numPrefetches, + GetParam().needsIndexSplit, + duckDbSql, + numDrivers); + auto taskStats = toPlanStats(task->taskStats()); + ASSERT_EQ(taskStats.at(joinNodeId_).numDrivers, numDrivers); + } + } +} + +// Verifies IndexLookupJoin works correctly with grouped execution where the +// probe side is grouped. Each split group creates its own set of drivers with +// partitionId 0..numDrivers-1, and the partitionId==0 driver in each group +// acts as the split collector for that group's index splits. +TEST_P(IndexLookupJoinTest, groupedExecution) { + if (GetParam().serialExecution || !GetParam().needsIndexSplit) { + return; + } + IndexTableData tableData; + generateIndexTableData({100, 1, 1}, tableData, pool_); + const int numProbeSplitsPerGroup{2}; + const int numSplitGroups{3}; + const int numDrivers{2}; + const auto probeVectors = generateProbeInput( + numProbeSplitsPerGroup * numSplitGroups, + 256, + 1, + tableData, + pool_, + {"t0", "t1", "t2"}, + GetParam().hasNullKeys, + {}, + {}, + 100); + const auto probeFiles = createProbeFiles(probeVectors); + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableVectors}); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, + tableData.keyVectors, + tableData.valueVectors, + *pool()); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); + auto planNodeIdGenerator = std::make_shared(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType({"u0", "u1", "u2", "u3", "u5"}), + makeIndexColumnHandles({"u0", "u1", "u2", "u3", "u5"})); + + for (const auto joinType : {core::JoinType::kInner, core::JoinType::kLeft}) { + SCOPED_TRACE( + fmt::format("joinType: {}", core::JoinTypeName::toName(joinType))); + auto outputLayout = std::vector{"u3", "t5"}; + auto plan = PlanBuilder(planNodeIdGenerator, pool()) + .tableScan(probeType_) + .capturePlanNodeId(probeScanNodeId_) + .startIndexLookupJoin() + .leftKeys({"t0", "t1", "t2"}) + .rightKeys({"u0", "u1", "u2"}) + .indexSource(indexScanNode) + .outputLayout(outputLayout) + .joinType(joinType) + .endIndexLookupJoin() + .capturePlanNodeId(joinNodeId_) + .partitionedOutput({}, 1, outputLayout) + .planFragment(); + + plan.executionStrategy = core::ExecutionStrategy::kGrouped; + plan.groupedExecutionLeafNodeIds.emplace(probeScanNodeId_); + plan.groupedExecutionLeafNodeIds.emplace(indexScanNodeId_); + plan.numSplitGroups = numSplitGroups; + + auto queryCtx = core::QueryCtx::create(executor_.get()); + std::unordered_map configs; + configs[QueryConfig::kIndexLookupJoinMaxPrefetchBatches] = + std::to_string(GetParam().numPrefetches); + queryCtx->testingOverrideConfigUnsafe(std::move(configs)); + + auto task = exec::Task::create( + fmt::format("grouped-index-lookup-join-{}", joinType), + std::move(plan), + 0, + std::move(queryCtx), + Task::ExecutionMode::kParallel); + task->start(numDrivers, /*concurrentSplitGroups=*/2); + + // Add probe splits with group IDs. + int fileIdx = 0; + for (int group = 0; group < numSplitGroups; ++group) { + for (int j = 0; j < numProbeSplitsPerGroup; ++j) { + task->addSplit( + probeScanNodeId_, + Split( + makeHiveConnectorSplit(probeFiles[fileIdx++]->getPath()), + group)); + } + task->noMoreSplitsForGroup(probeScanNodeId_, group); + } + task->noMoreSplits(probeScanNodeId_); + + // Add one index split per group so each group's collector finds its split. + for (int group = 0; group < numSplitGroups; ++group) { + task->addSplit( + indexScanNodeId_, + Split( + std::make_shared( + kTestIndexConnectorName), + group)); + task->noMoreSplitsForGroup(indexScanNodeId_, group); + } + + // Consume output via OutputBufferManager. + auto outputBufferManager = exec::OutputBufferManager::getInstanceRef(); + outputBufferManager->deleteResults(task->taskId(), 0); + + waitForTaskCompletion(task.get()); + ASSERT_EQ(task->state(), TaskState::kFinished); + + auto taskStats = toPlanStats(task->taskStats()); + ASSERT_EQ( + taskStats.at(joinNodeId_).numDrivers, numDrivers * numSplitGroups); + } +} + +// Verifies that grouped execution completes correctly when some split groups +// receive no index splits (e.g., partition pruning eliminates all index +// partitions for a group). For INNER JOIN, the pruned group produces no +// output. For LEFT JOIN, the pruned group emits probe rows with nulls. +TEST_P(IndexLookupJoinTest, groupedExecutionWithEmptyIndexSplits) { + if (GetParam().serialExecution || !GetParam().needsIndexSplit) { + return; + } + IndexTableData tableData; + generateIndexTableData({100, 1, 1}, tableData, pool_); + const int numProbeSplitsPerGroup{1}; + const int numSplitGroups{3}; + const int numDrivers{2}; + const auto probeVectors = generateProbeInput( + numProbeSplitsPerGroup * numSplitGroups, + 256, + 1, + tableData, + pool_, + {"t0", "t1", "t2"}, + GetParam().hasNullKeys, + {}, + {}, + 100); + const auto probeFiles = createProbeFiles(probeVectors); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, + tableData.keyVectors, + tableData.valueVectors, + *pool()); + const auto indexTableHandle = makeIndexTableHandle( + indexTable, GetParam().asyncLookup, GetParam().needsIndexSplit); + auto planNodeIdGenerator = std::make_shared(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType({"u0", "u1", "u2", "u3", "u5"}), + makeIndexColumnHandles({"u0", "u1", "u2", "u3", "u5"})); + + for (const auto joinType : {core::JoinType::kInner, core::JoinType::kLeft}) { + SCOPED_TRACE( + fmt::format("joinType: {}", core::JoinTypeName::toName(joinType))); + auto outputLayout = std::vector{"u3", "t5"}; + auto plan = PlanBuilder(planNodeIdGenerator, pool()) + .tableScan(probeType_) + .capturePlanNodeId(probeScanNodeId_) + .startIndexLookupJoin() + .leftKeys({"t0", "t1", "t2"}) + .rightKeys({"u0", "u1", "u2"}) + .indexSource(indexScanNode) + .outputLayout(outputLayout) + .joinType(joinType) + .endIndexLookupJoin() + .capturePlanNodeId(joinNodeId_) + .partitionedOutput({}, 1, outputLayout) + .planFragment(); + + plan.executionStrategy = core::ExecutionStrategy::kGrouped; + plan.groupedExecutionLeafNodeIds.emplace(probeScanNodeId_); + plan.groupedExecutionLeafNodeIds.emplace(indexScanNodeId_); + plan.numSplitGroups = numSplitGroups; + + auto queryCtx = core::QueryCtx::create(executor_.get()); + std::unordered_map configs; + configs[QueryConfig::kIndexLookupJoinMaxPrefetchBatches] = + std::to_string(GetParam().numPrefetches); + queryCtx->testingOverrideConfigUnsafe(std::move(configs)); + + auto task = exec::Task::create( + fmt::format("grouped-empty-index-splits-{}", joinType), + std::move(plan), + 0, + std::move(queryCtx), + Task::ExecutionMode::kParallel); + task->start(numDrivers, /*concurrentSplitGroups=*/2); + + // Add probe splits for all groups. + int fileIdx = 0; + for (int group = 0; group < numSplitGroups; ++group) { + for (int j = 0; j < numProbeSplitsPerGroup; ++j) { + task->addSplit( + probeScanNodeId_, + Split( + makeHiveConnectorSplit(probeFiles[fileIdx++]->getPath()), + group)); + } + task->noMoreSplitsForGroup(probeScanNodeId_, group); + } + task->noMoreSplits(probeScanNodeId_); + + // Add index splits only for group 0. Groups 1 and 2 get no index splits, + // simulating partition pruning. + task->addSplit( + indexScanNodeId_, + Split( + std::make_shared(kTestIndexConnectorName), + 0)); + task->noMoreSplitsForGroup(indexScanNodeId_, 0); + task->noMoreSplitsForGroup(indexScanNodeId_, 1); + task->noMoreSplitsForGroup(indexScanNodeId_, 2); + + auto outputBufferManager = exec::OutputBufferManager::getInstanceRef(); + outputBufferManager->deleteResults(task->taskId(), 0); + + waitForTaskCompletion(task.get()); + ASSERT_EQ(task->state(), TaskState::kFinished); + + auto taskStats = toPlanStats(task->taskStats()); + ASSERT_EQ( + taskStats.at(joinNodeId_).numDrivers, numDrivers * numSplitGroups); + } +} + +// Verifies that grouped execution leaf validation passes when the index +// source node ID is in groupedExecutionLeafNodeIds. Without the validation +// exclusion in collectIndexLookupSourceIds, Task::start() would reject the +// plan because it cannot find the index source node in any driver factory. +TEST_P(IndexLookupJoinTest, groupedExecutionLeafValidation) { + if (GetParam().serialExecution) { + return; + } + + const auto indexTableHandle = makeIndexTableHandle( + nullptr, GetParam().asyncLookup, GetParam().needsIndexSplit); + auto planNodeIdGenerator = std::make_shared(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType({"u0", "u1", "u2", "u3", "u5"}), + makeIndexColumnHandles({"u0", "u1", "u2", "u3", "u5"})); + + auto outputLayout = std::vector{"u3", "t5"}; + auto plan = PlanBuilder(planNodeIdGenerator, pool()) + .tableScan(probeType_) + .capturePlanNodeId(probeScanNodeId_) + .startIndexLookupJoin() + .leftKeys({"t0", "t1", "t2"}) + .rightKeys({"u0", "u1", "u2"}) + .indexSource(indexScanNode) + .outputLayout(outputLayout) + .joinType(core::JoinType::kInner) + .endIndexLookupJoin() + .capturePlanNodeId(joinNodeId_) + .partitionedOutput({}, 1, outputLayout) + .planFragment(); + + plan.executionStrategy = core::ExecutionStrategy::kGrouped; + plan.groupedExecutionLeafNodeIds.emplace(probeScanNodeId_); + plan.groupedExecutionLeafNodeIds.emplace(indexScanNodeId_); + plan.numSplitGroups = 1; + + auto queryCtx = core::QueryCtx::create(executor_.get()); + auto task = exec::Task::create( + "grouped-execution-leaf-validation", + std::move(plan), + 0, + std::move(queryCtx), + Task::ExecutionMode::kParallel); + + // Task::start() runs validateGroupedExecutionLeafNodes which would throw + // if the index source node ID is not excluded from validation. + ASSERT_NO_THROW(task->start(1)); + + task->requestAbort().wait(); + ASSERT_TRUE(waitForTaskAborted(task.get())); +} + +} // namespace + +VELOX_INSTANTIATE_TEST_SUITE_P( + IndexLookupJoinTest, + IndexLookupJoinTest, + testing::ValuesIn(IndexLookupJoinTest::getTestParams()), + [](const testing::TestParamInfo& info) { + return fmt::format( + "{}_{}prefetches_{}_{}_{}", + info.param.asyncLookup ? "async" : "sync", + info.param.numPrefetches, + info.param.serialExecution ? "serial" : "parallel", + info.param.hasNullKeys ? "nullKeys" : "noNullKeys", + info.param.needsIndexSplit ? "needsIndexSplit" : "noSplit"); + }); +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/LimitTest.cpp b/velox/exec/tests/LimitTest.cpp index b83eb191a39..fb0badafb83 100644 --- a/velox/exec/tests/LimitTest.cpp +++ b/velox/exec/tests/LimitTest.cpp @@ -22,6 +22,7 @@ using namespace facebook::velox; using namespace facebook::velox::exec; using namespace facebook::velox::exec::test; +using namespace facebook::velox::common::testutil; namespace { class LimitTest : public HiveConnectorTestBase {}; @@ -146,7 +147,7 @@ TEST_F(LimitTest, partialLimitEagerFlush) { test(false); } -TEST_F(LimitTest, barrier) { +TEST_F(LimitTest, DISABLED_barrier) { std::vector vectors; std::vector> tempFiles; const int numSplits{5}; @@ -160,6 +161,7 @@ TEST_F(LimitTest, barrier) { createDuckDbTable(vectors); struct { + bool serialExecution; bool barrierExecution; int offset; int limit; @@ -170,7 +172,8 @@ TEST_F(LimitTest, barrier) { std::string toString() const { return fmt::format( - "barrierExecution {}, offset: {}, limit: {}, numExpectedBarriers: {}, numExpectedOutputRows: {}, numExpectedFinishedSplits: {}, numExpectedOutputBatches: {}", + "serialExecution {}, barrierExecution {}, offset: {}, limit: {}, numExpectedBarriers: {}, numExpectedOutputRows: {}, numExpectedFinishedSplits: {}, numExpectedOutputBatches: {}", + serialExecution, barrierExecution, offset, limit, @@ -179,130 +182,314 @@ TEST_F(LimitTest, barrier) { numExpectedFinishedSplits, numExpectedOutputBatches); } - } testSettings[] = {// Test the case where the limit covers all the input rows - // with barrier and not. - {true, - 0, - numRowsPerSplit * numSplits, - numSplits, - numRowsPerSplit * numSplits, - numSplits - 1, - numSplits}, - {false, - 0, - numRowsPerSplit * numSplits, - 0, - numRowsPerSplit * numSplits, - numSplits - 1, - numSplits}, - // Test the cases where the limit covers the first entire - // split rows with barrier or not. with barrier and not. - {true, 0, numRowsPerSplit, 1, numRowsPerSplit, 0, 1}, - {false, 0, numRowsPerSplit, 0, numRowsPerSplit, 0, 1}, - // Test the case where the limit covers the one and half - // split rows with barrier or not. - {true, - 0, - numRowsPerSplit + numRowsPerSplit / 2, - 2, - numRowsPerSplit + numRowsPerSplit / 2, - 1, - 2}, - {false, - 0, - numRowsPerSplit + numRowsPerSplit / 2, - 0, - numRowsPerSplit + numRowsPerSplit / 2, - 1, - 2}, - {true, - 0, - numRowsPerSplit + numRowsPerSplit - 1, - 2, - numRowsPerSplit + numRowsPerSplit - 1, - 1, - 2}, - {false, - 0, - numRowsPerSplit + numRowsPerSplit - 1, - 0, - numRowsPerSplit + numRowsPerSplit - 1, - 1, - 2}, - // Test the case where the limit set to cover more than - // all the input rows with barrier or not. - {true, - 0, - numRowsPerSplit * (numSplits + 1), - numSplits, - numRowsPerSplit * numSplits, - numSplits, - numSplits}, - {false, - 0, - numRowsPerSplit * (numSplits + 1), - 0, - numRowsPerSplit * numSplits, - numSplits, - numSplits}, - // Test the cases where the limit set to cover partial - // input rows in the middle with barrier or not. - {true, - numRowsPerSplit, - numRowsPerSplit + numRowsPerSplit - 1, - 3, - numRowsPerSplit + numRowsPerSplit - 1, - 2, - 2}, - {false, - numRowsPerSplit, - numRowsPerSplit + numRowsPerSplit - 1, - 0, - numRowsPerSplit + numRowsPerSplit - 1, - 2, - 2}, - {true, - numRowsPerSplit, - numRowsPerSplit * (numSplits - 1), - numSplits, - numRowsPerSplit * (numSplits - 1), - numSplits - 1, - numSplits - 1}, - {false, - numRowsPerSplit, - numRowsPerSplit * (numSplits - 1), - 0, - numRowsPerSplit * (numSplits - 1), - numSplits - 1, - numSplits - 1}, - {true, - numRowsPerSplit, - numRowsPerSplit / 2, - 2, - numRowsPerSplit / 2, - 1, - 1}, - {false, - numRowsPerSplit, - numRowsPerSplit / 2, - 0, - numRowsPerSplit / 2, - 1, - 1}, - {true, - numRowsPerSplit / 2, - numRowsPerSplit * numSplits, - numSplits, - numRowsPerSplit * numSplits - numRowsPerSplit / 2, - numSplits, - numSplits}, - {false, - numRowsPerSplit / 2, - numRowsPerSplit * numSplits, - 0, - numRowsPerSplit * numSplits - numRowsPerSplit / 2, - numSplits, - numSplits}}; + } testSettings[] = { + // Test the case where the limit covers all the input rows + // with barrier and not. + {false, + false, + 0, + numRowsPerSplit * numSplits, + numSplits, + numRowsPerSplit * numSplits, + numSplits - 1, + numSplits}, + {true, + false, + 0, + numRowsPerSplit * numSplits, + 0, + numRowsPerSplit * numSplits, + numSplits - 1, + numSplits}, + {false, + true, + 0, + numRowsPerSplit * numSplits, + numSplits, + numRowsPerSplit * numSplits, + numSplits - 1, + numSplits}, + {false, + false, + 0, + numRowsPerSplit * numSplits, + 0, + numRowsPerSplit * numSplits, + numSplits - 1, + numSplits}, + // Test the cases where the limit covers the first entire + // split rows with barrier or not. with barrier and not. + { + true, + true, + 0, + numRowsPerSplit, + 1, + numRowsPerSplit, + 0, + 1, + }, + { + true, + false, + 0, + numRowsPerSplit, + 0, + numRowsPerSplit, + 0, + 1, + }, + { + false, + true, + 0, + numRowsPerSplit, + 1, + numRowsPerSplit, + 0, + 1, + }, + { + false, + false, + 0, + numRowsPerSplit, + 0, + numRowsPerSplit, + 0, + 1, + }, + // Test the case where the limit covers the one and half + // split rows with barrier or not. + {true, + true, + 0, + numRowsPerSplit + numRowsPerSplit / 2, + 2, + numRowsPerSplit + numRowsPerSplit / 2, + 1, + 2}, + {true, + false, + 0, + numRowsPerSplit + numRowsPerSplit / 2, + 0, + numRowsPerSplit + numRowsPerSplit / 2, + 1, + 2}, + {false, + true, + 0, + numRowsPerSplit + numRowsPerSplit / 2, + 2, + numRowsPerSplit + numRowsPerSplit / 2, + 1, + 2}, + {false, + false, + 0, + numRowsPerSplit + numRowsPerSplit / 2, + 0, + numRowsPerSplit + numRowsPerSplit / 2, + 1, + 2}, + {true, + true, + 0, + numRowsPerSplit + numRowsPerSplit - 1, + 2, + numRowsPerSplit + numRowsPerSplit - 1, + 1, + 2}, + {true, + false, + 0, + numRowsPerSplit + numRowsPerSplit - 1, + 0, + numRowsPerSplit + numRowsPerSplit - 1, + 1, + 2}, + {false, + true, + 0, + numRowsPerSplit + numRowsPerSplit - 1, + 2, + numRowsPerSplit + numRowsPerSplit - 1, + 1, + 2}, + {false, + false, + 0, + numRowsPerSplit + numRowsPerSplit - 1, + 0, + numRowsPerSplit + numRowsPerSplit - 1, + 1, + 2}, + // Test the case where the limit set to cover more than + // all the input rows with barrier or not. + {true, + true, + 0, + numRowsPerSplit * (numSplits + 1), + numSplits, + numRowsPerSplit * numSplits, + numSplits, + numSplits}, + {true, + false, + 0, + numRowsPerSplit * (numSplits + 1), + 0, + numRowsPerSplit * numSplits, + numSplits, + numSplits}, + {false, + true, + 0, + numRowsPerSplit * (numSplits + 1), + numSplits, + numRowsPerSplit * numSplits, + numSplits, + numSplits}, + {false, + false, + 0, + numRowsPerSplit * (numSplits + 1), + 0, + numRowsPerSplit * numSplits, + numSplits, + numSplits}, + // Test the cases where the limit set to cover partial + // input rows in the middle with barrier or not. + {true, + true, + numRowsPerSplit, + numRowsPerSplit + numRowsPerSplit - 1, + 3, + numRowsPerSplit + numRowsPerSplit - 1, + 2, + 2}, + {true, + false, + numRowsPerSplit, + numRowsPerSplit + numRowsPerSplit - 1, + 0, + numRowsPerSplit + numRowsPerSplit - 1, + 2, + 2}, + {false, + true, + numRowsPerSplit, + numRowsPerSplit + numRowsPerSplit - 1, + 3, + numRowsPerSplit + numRowsPerSplit - 1, + 2, + 2}, + {false, + false, + numRowsPerSplit, + numRowsPerSplit + numRowsPerSplit - 1, + 0, + numRowsPerSplit + numRowsPerSplit - 1, + 2, + 2}, + {true, + true, + numRowsPerSplit, + numRowsPerSplit * (numSplits - 1), + numSplits, + numRowsPerSplit * (numSplits - 1), + numSplits - 1, + numSplits - 1}, + {true, + false, + numRowsPerSplit, + numRowsPerSplit * (numSplits - 1), + 0, + numRowsPerSplit * (numSplits - 1), + numSplits - 1, + numSplits - 1}, + {false, + true, + numRowsPerSplit, + numRowsPerSplit * (numSplits - 1), + numSplits, + numRowsPerSplit * (numSplits - 1), + numSplits - 1, + numSplits - 1}, + {false, + false, + numRowsPerSplit, + numRowsPerSplit * (numSplits - 1), + 0, + numRowsPerSplit * (numSplits - 1), + numSplits - 1, + numSplits - 1}, + {true, + true, + numRowsPerSplit, + numRowsPerSplit / 2, + 2, + numRowsPerSplit / 2, + 1, + 1}, + {true, + false, + numRowsPerSplit, + numRowsPerSplit / 2, + 0, + numRowsPerSplit / 2, + 1, + 1}, + {false, + true, + numRowsPerSplit, + numRowsPerSplit / 2, + 2, + numRowsPerSplit / 2, + 1, + 1}, + {false, + false, + numRowsPerSplit, + numRowsPerSplit / 2, + 0, + numRowsPerSplit / 2, + 1, + 1}, + {true, + true, + numRowsPerSplit / 2, + numRowsPerSplit * numSplits, + numSplits, + numRowsPerSplit * numSplits - numRowsPerSplit / 2, + numSplits, + numSplits}, + {true, + false, + numRowsPerSplit / 2, + numRowsPerSplit * numSplits, + 0, + numRowsPerSplit * numSplits - numRowsPerSplit / 2, + numSplits, + numSplits}, + {false, + true, + numRowsPerSplit / 2, + numRowsPerSplit * numSplits, + numSplits, + numRowsPerSplit * numSplits - numRowsPerSplit / 2, + numSplits, + numSplits}, + {false, + false, + numRowsPerSplit / 2, + numRowsPerSplit * numSplits, + 0, + numRowsPerSplit * numSplits - numRowsPerSplit / 2, + numSplits, + numSplits}, + }; for (const auto& testData : testSettings) { SCOPED_TRACE(testData.toString()); core::PlanNodeId limitPlanNodeId; @@ -313,15 +500,15 @@ TEST_F(LimitTest, barrier) { .planNode(); auto task = AssertQueryBuilder(plan, duckDbQueryRunner_) .splits(makeHiveConnectorSplits(tempFiles)) - .serialExecution(true) + .serialExecution(testData.serialExecution) .barrierExecution(testData.barrierExecution) - .assertResults(fmt::format( - "SELECT * FROM tmp LIMIT {} OFFSET {}", - testData.limit, - testData.offset)); + .assertResults( + fmt::format( + "SELECT * FROM tmp LIMIT {} OFFSET {}", + testData.limit, + testData.offset)); const auto taskStats = task->taskStats(); ASSERT_EQ(taskStats.numBarriers, testData.numExpectedBarriers); - ASSERT_EQ(taskStats.numFinishedSplits, testData.numExpectedFinishedSplits); ASSERT_EQ( exec::toPlanStats(taskStats).at(limitPlanNodeId).outputRows, testData.numExpectedOutputRows); diff --git a/velox/exec/tests/LocalPartitionTest.cpp b/velox/exec/tests/LocalPartitionTest.cpp index 7dbf886fda6..ee59473fb88 100644 --- a/velox/exec/tests/LocalPartitionTest.cpp +++ b/velox/exec/tests/LocalPartitionTest.cpp @@ -24,6 +24,7 @@ using facebook::velox::test::BatchMaker; namespace facebook::velox::exec::test { +using namespace facebook::velox::common::testutil; namespace { class LocalPartitionTest : public HiveConnectorTestBase { @@ -329,6 +330,80 @@ TEST_F(LocalPartitionTest, partitionBuffering) { queryBuilder.assertResults(query), 2200, 2, 2); } +TEST_F(LocalPartitionTest, partitionBufferingWithStringBuffers) { + // Test string buffer memory accounting with partition buffering. + // String buffers are multiply-referenced across partitions. The fix + // amortizes string buffer sizes across partitions to avoid over-counting, + // which allows more efficient buffering. + std::vector vectors; + for (auto i = 0; i < 4; ++i) { + vectors.emplace_back(makeRowVector( + {"c0", "c1"}, + {makeFlatVector(100, [](auto row) { return row % 2; }), + makeFlatVector( + 100, [](auto /*row*/) { return std::string(100, 'a'); })})); + } + + auto runQuery = [&](const std::vector& input, + int maxDrivers, + int maxPartitionBufferSize) { + std::string query{"SELECT c0, arbitrary(c1) FROM tmp GROUP BY c0"}; + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .localPartition( + {"c0"}, + {PlanBuilder(planNodeIdGenerator).values(input).planNode()}) + .partialAggregation({"c0"}, {"arbitrary(c1)"}) + .planNode(); + createDuckDbTable(vectors); + + AssertQueryBuilder queryBuilder(plan, duckDbQueryRunner_); + queryBuilder.maxDrivers(maxDrivers); + + std::unordered_map configs; + configs[core::QueryConfig:: + kMinLocalExchangePartitionCountToUsePartitionBuffer] = + std::to_string(2); + configs[core::QueryConfig::kMaxLocalExchangePartitionBufferSize] = + std::to_string(maxPartitionBufferSize); + queryBuilder.configs(configs); + + return queryBuilder.assertResults(query); + }; + + // With amortized string buffer accounting, the buffer can accumulate rows + // from all input vectors. So only 2 vectors are flushed to + // LocalExchangeQueues, compared to 8 if not amortizing string buffer sizes. + verifyExchangeSourceOperatorStats(runQuery(vectors, 2, 50000), 400, 2, 2); + + // Test case with 99% of data belonging to one partition. We expect all + // partition buffers getting flushed together when the total size of all + // partition buffers exceeds the limit, instead of flushing each partiton + // buffer individually. + vectors.clear(); + for (auto i = 0; i < 4; ++i) { + vectors.emplace_back(makeRowVector( + {"c0", "c1"}, + {makeFlatVector( + 100, + [](auto row) { + if (row == 0) { + return 0; + } else { + return 1; + } + }), + makeFlatVector( + 100, [](auto /*row*/) { return std::string(100, 'a'); })})); + } + + // The total size of all partition buffers should exceed the limit and trigger + // the flush of all partition buffers at every batch, hence flushing a total + // of 8 vectors. + verifyExchangeSourceOperatorStats(runQuery(vectors, 2, 20000), 400, 8, 2); +} + TEST_F(LocalPartitionTest, partitionBufferingPreserveEncoding) { std::vector vectors = { makeRowVector({"c0"}, {makeConstant(0, 100)}), @@ -388,8 +463,9 @@ TEST_F(LocalPartitionTest, maxBufferSizeGather) { auto valuesNode = [&](int start, int end) { return PlanBuilder(planNodeIdGenerator) - .values(std::vector( - vectors.begin() + start, vectors.begin() + end)) + .values( + std::vector( + vectors.begin() + start, vectors.begin() + end)) .planNode(); }; @@ -961,8 +1037,9 @@ TEST_F(LocalPartitionTest, unionAllLocalExchangeWithInterDependency) { } }; - Operator::registerOperator(std::make_unique( - std::move(blockingCallback), std::move(finishCallback))); + Operator::registerOperator( + std::make_unique( + std::move(blockingCallback), std::move(finishCallback))); auto planNodeIdGenerator = std::make_shared(); auto plan = PlanBuilder(planNodeIdGenerator) @@ -1030,8 +1107,9 @@ TEST_F( auto finishCallback = [&](bool /*unused*/) {}; - Operator::registerOperator(std::make_unique( - std::move(blockingCallback), std::move(finishCallback))); + Operator::registerOperator( + std::make_unique( + std::move(blockingCallback), std::move(finishCallback))); auto planNodeIdGenerator = std::make_shared(); auto plan = PlanBuilder(planNodeIdGenerator) @@ -1143,11 +1221,23 @@ TEST_F(LocalPartitionTest, barrier) { tableScanNode(), }) .planNode(); + struct { + bool hasBarrier; + bool serialExecution; - for (const auto hasBarrier : {false, true}) { - SCOPED_TRACE(fmt::format("hasBarrier {}", hasBarrier)); + std::string toString() const { + return fmt::format( + "hasBarrier: {}, serialExecution: {}", hasBarrier, serialExecution); + } + } testSettings[] = { + {false, false}, {false, true}, {true, false}, {true, true}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.toString()); AssertQueryBuilder queryBuilder(plan, duckDbQueryRunner_); - queryBuilder.barrierExecution(hasBarrier).serialExecution(true); + queryBuilder.barrierExecution(testData.hasBarrier) + .serialExecution(testData.serialExecution) + .maxDrivers(testData.serialExecution ? 1 : 3); for (auto i = 0; i < numSources; ++i) { for (auto j = 0; j < numSplits; ++j) { queryBuilder.split( @@ -1156,7 +1246,8 @@ TEST_F(LocalPartitionTest, barrier) { } const auto task = queryBuilder.assertResults("SELECT * FROM tmp"); - ASSERT_EQ(task->taskStats().numBarriers, hasBarrier ? numSplits : 0); + ASSERT_EQ( + task->taskStats().numBarriers, testData.hasBarrier ? numSplits : 0); } } diff --git a/velox/exec/tests/Main.cpp b/velox/exec/tests/Main.cpp index 164b6422fe8..39c009ebdd7 100644 --- a/velox/exec/tests/Main.cpp +++ b/velox/exec/tests/Main.cpp @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "velox/common/memory/Memory.h" #include "velox/common/process/ThreadDebugInfo.h" #include @@ -25,5 +26,6 @@ int main(int argc, char** argv) { // Signal handler required for ThreadDebugInfoTest facebook::velox::process::addDefaultFatalSignalHandler(); folly::Init init(&argc, &argv, false); + facebook::velox::memory::MemoryManager::initialize({}); return RUN_ALL_TESTS(); } diff --git a/velox/exec/tests/MarkDistinctTest.cpp b/velox/exec/tests/MarkDistinctTest.cpp index c7f04b41517..80faa68925b 100644 --- a/velox/exec/tests/MarkDistinctTest.cpp +++ b/velox/exec/tests/MarkDistinctTest.cpp @@ -14,13 +14,20 @@ * limitations under the License. */ +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/file/FileSystems.h" +#include "velox/exec/PlanNodeStats.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/OperatorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/tests/utils/TempDirectoryPath.h" using namespace facebook::velox; -using namespace facebook::velox::test; using namespace facebook::velox::exec::test; +using namespace facebook::velox::test; +using facebook::velox::exec::Operator; +using facebook::velox::exec::TestScopedSpillInjection; +using facebook::velox::memory::testingRunArbitration; class MarkDistinctTest : public OperatorTestBase { public: @@ -42,6 +49,18 @@ class MarkDistinctTest : public OperatorTestBase { auto results = AssertQueryBuilder(plan).copyResults(pool()); assertEqualVectors(expectedResults, results); } + + protected: + MarkDistinctTest() { + filesystems::registerLocalFileSystem(); + } + + RowTypePtr rowType_{ROW({"c0", "c1"}, {BIGINT(), BIGINT()})}; + + VectorFuzzer::Options fuzzerOpts_{ + .vectorSize = 1024, + .nullRatio = 0, + .allowLazyVector = false}; }; template @@ -141,3 +160,385 @@ TEST_F(MarkDistinctTest, aggregation) { .assertResults( "SELECT c0, sum(distinct c1), sum(distinct c2) FROM tmp GROUP BY 1"); } + +TEST_F(MarkDistinctTest, spill) { + auto vectors = createVectors(8, rowType_, fuzzerOpts_); + createDuckDbTable(vectors); + + struct { + uint32_t spillPartitionBits; + uint32_t numSpills; + uint32_t cpuTimeSliceLimitMs; + + std::string debugString() const { + return fmt::format( + "spillPartitionBits {}, numSpills {}, cpuTimeSliceLimitMs {}", + spillPartitionBits, + numSpills, + cpuTimeSliceLimitMs); + } + } testSettings[] = {{2, 1, 0}, {3, 1, 0}, {2, 2, 0}, {2, 2, 10}, {2, 3, 10}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + const auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto queryCtx = core::QueryCtx::create(executor_.get()); + TestScopedSpillInjection scopedSpillInjection( + 100, ".*", testData.numSpills); + + core::PlanNodeId markDistinctId; + auto task = + AssertQueryBuilder(duckDbQueryRunner_) + .spillDirectory(spillDirectory->getPath()) + .config(core::QueryConfig::kSpillEnabled, true) + .config(core::QueryConfig::kMarkDistinctSpillEnabled, true) + .config( + core::QueryConfig::kSpillNumPartitionBits, + testData.spillPartitionBits) + .config( + core::QueryConfig::kDriverCpuTimeSliceLimitMs, + testData.cpuTimeSliceLimitMs) + .config(core::QueryConfig::kAggregationSpillEnabled, false) + .queryCtx(queryCtx) + .plan( + PlanBuilder() + .values(vectors) + .markDistinct("c1_distinct", {"c0", "c1"}) + .capturePlanNodeId(markDistinctId) + .singleAggregation({"c0"}, {"count(c1)"}, {"c1_distinct"}) + .planNode()) + .assertResults("SELECT c0, count(distinct c1) FROM tmp GROUP BY 1"); + + auto taskStats = exec::toPlanStats(task->taskStats()); + auto& planStats = taskStats.at(markDistinctId); + ASSERT_GT(planStats.spilledBytes, 0); + ASSERT_GT(planStats.spilledFiles, 0); + ASSERT_GT(planStats.spilledRows, 0); + ASSERT_GT(planStats.spilledPartitions, 0); + + task.reset(); + waitForAllTasksToBeDeleted(); + } +} + +DEBUG_ONLY_TEST_F(MarkDistinctTest, reclaimDuringInputOrOutput) { + auto vectors = createVectors(8, rowType_, fuzzerOpts_); + createDuckDbTable(vectors); + + struct { + std::string spillInjectionPoint; + uint32_t spillPartitionBits; + + std::string debugString() const { + return fmt::format( + "spillInjectionPoint {}, spillPartitionBits {}", + spillInjectionPoint, + spillPartitionBits); + } + } testSettings[] = { + {"facebook::velox::exec::Driver::runInternal::addInput", 2}, + {"facebook::velox::exec::Driver::runInternal::getOutput", 2}, + {"facebook::velox::exec::Driver::runInternal::addInput", 3}, + {"facebook::velox::exec::Driver::runInternal::getOutput", 3}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + const auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto queryCtx = core::QueryCtx::create(executor_.get()); + + std::atomic_int numRound{0}; + SCOPED_TESTVALUE_SET( + testData.spillInjectionPoint, + std::function(([&](Operator* op) { + if (op->operatorType() != "MarkDistinct") { + return; + } + if (++numRound != 5) { + return; + } + testingRunArbitration(op->pool(), 0); + }))); + + core::PlanNodeId markDistinctId; + auto task = + AssertQueryBuilder(duckDbQueryRunner_) + .spillDirectory(spillDirectory->getPath()) + .config(core::QueryConfig::kSpillEnabled, true) + .config(core::QueryConfig::kMarkDistinctSpillEnabled, true) + .config( + core::QueryConfig::kSpillNumPartitionBits, + testData.spillPartitionBits) + .config(core::QueryConfig::kAggregationSpillEnabled, false) + .queryCtx(queryCtx) + .plan( + PlanBuilder() + .values(vectors) + .markDistinct("c1_distinct", {"c0", "c1"}) + .capturePlanNodeId(markDistinctId) + .singleAggregation({"c0"}, {"count(c1)"}, {"c1_distinct"}) + .planNode()) + .assertResults("SELECT c0, count(distinct c1) FROM tmp GROUP BY 1"); + + auto taskStats = exec::toPlanStats(task->taskStats()); + auto& planStats = taskStats.at(markDistinctId); + ASSERT_GT(planStats.spilledBytes, 0); + ASSERT_GT(planStats.spilledFiles, 0); + ASSERT_GT(planStats.spilledRows, 0); + ASSERT_GT(planStats.spilledPartitions, 0); + + task.reset(); + waitForAllTasksToBeDeleted(); + } +} + +DEBUG_ONLY_TEST_F(MarkDistinctTest, recursiveSpill) { + auto vectors = createVectors(32, rowType_, fuzzerOpts_); + createDuckDbTable(vectors); + + struct { + int32_t numSpills; + int32_t maxSpillLevel; + + std::string debugString() const { + return fmt::format( + "numSpills {}, maxSpillLevel {}", numSpills, maxSpillLevel); + } + } testSettings[] = {{2, 3}, {8, 4}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + const auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto queryCtx = core::QueryCtx::create(executor_.get()); + + std::atomic_int numSpills{0}; + std::atomic_int numInputs{0}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::addInput", + std::function(([&](Operator* op) { + if (op->operatorType() != "MarkDistinct") { + return; + } + if (++numInputs != 5) { + return; + } + ++numSpills; + testingRunArbitration(op->pool(), 0); + }))); + + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::getOutput", + std::function(([&](Operator* op) { + if (op->operatorType() != "MarkDistinct") { + return; + } + if (!op->testingNoMoreInput()) { + return; + } + if (numSpills++ >= testData.numSpills) { + return; + } + testingRunArbitration(op->pool(), 0); + }))); + + core::PlanNodeId markDistinctId; + auto task = + AssertQueryBuilder(duckDbQueryRunner_) + .spillDirectory(spillDirectory->getPath()) + .config(core::QueryConfig::kSpillEnabled, true) + .config(core::QueryConfig::kMarkDistinctSpillEnabled, true) + .config( + core::QueryConfig::kMaxSpillLevel, testData.maxSpillLevel - 1) + .config(core::QueryConfig::kAggregationSpillEnabled, false) + .queryCtx(queryCtx) + .plan( + PlanBuilder() + .values(vectors) + .markDistinct("c1_distinct", {"c0", "c1"}) + .capturePlanNodeId(markDistinctId) + .singleAggregation({"c0"}, {"count(c1)"}, {"c1_distinct"}) + .planNode()) + .assertResults("SELECT c0, count(distinct c1) FROM tmp GROUP BY 1"); + + auto taskStats = exec::toPlanStats(task->taskStats()); + auto& planStats = taskStats.at(markDistinctId); + ASSERT_GT(planStats.spilledBytes, 0); + ASSERT_GT(planStats.spilledFiles, 0); + ASSERT_GT(planStats.spilledRows, 0); + + auto runTimeStats = + task->taskStats().pipelineStats.back().operatorStats.at(1).runtimeStats; + if (testData.numSpills > testData.maxSpillLevel) { + ASSERT_GT(runTimeStats["exceededMaxSpillLevel"].sum, 0); + } else { + ASSERT_EQ(runTimeStats.count("exceededMaxSpillLevel"), 0); + } + + task.reset(); + waitForAllTasksToBeDeleted(); + } +} + +TEST_F(MarkDistinctTest, spillWithDuplicateKeys) { + // Verifies correctness when the same key appears in both pre-spill and + // post-spill input batches. The hash table state must be preserved through + // spill/restore so that keys already marked as distinct before spill are not + // re-marked during restore. + auto vectors = { + makeRowVector({ + makeFlatVector({1, 2, 3, 4, 5}), + makeFlatVector({10, 20, 30, 40, 50}), + }), + makeRowVector({ + makeFlatVector({1, 2, 3, 4, 5}), + makeFlatVector({10, 20, 30, 40, 50}), + }), + makeRowVector({ + makeFlatVector({1, 2, 6, 7, 8}), + makeFlatVector({10, 20, 60, 70, 80}), + }), + makeRowVector({ + makeFlatVector({1, 2, 3, 9, 10}), + makeFlatVector({10, 20, 30, 90, 100}), + }), + }; + createDuckDbTable(vectors); + + const auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto queryCtx = core::QueryCtx::create(executor_.get()); + // Trigger spill after the second batch so pre-spill output includes keys + // that also appear in post-spill batches. + TestScopedSpillInjection scopedSpillInjection(100, ".*", 1); + + core::PlanNodeId markDistinctId; + auto task = + AssertQueryBuilder(duckDbQueryRunner_) + .spillDirectory(spillDirectory->getPath()) + .config(core::QueryConfig::kSpillEnabled, true) + .config(core::QueryConfig::kMarkDistinctSpillEnabled, true) + .config(core::QueryConfig::kAggregationSpillEnabled, false) + .queryCtx(queryCtx) + .plan( + PlanBuilder() + .values(vectors) + .markDistinct("c1_distinct", {"c0", "c1"}) + .capturePlanNodeId(markDistinctId) + .singleAggregation({"c0"}, {"count(c1)"}, {"c1_distinct"}) + .planNode()) + .assertResults("SELECT c0, count(distinct c1) FROM tmp GROUP BY 1"); + + auto taskStats = exec::toPlanStats(task->taskStats()); + auto& planStats = taskStats.at(markDistinctId); + ASSERT_GT(planStats.spilledBytes, 0); + ASSERT_GT(planStats.spilledRows, 0); + + task.reset(); + waitForAllTasksToBeDeleted(); +} + +TEST_F(MarkDistinctTest, memoryUsage) { + const auto rowType = + ROW({"c0", "c1", "c2"}, {INTEGER(), INTEGER(), VARCHAR()}); + const auto vectors = createVectors(rowType, 1024, 100 << 20); + + core::PlanNodeId markDistinctId; + auto plan = PlanBuilder() + .values(vectors) + .markDistinct("c1_distinct", {"c0", "c1"}) + .capturePlanNodeId(markDistinctId) + .singleAggregation({"c0"}, {"count(c1)"}, {"c1_distinct"}) + .planNode(); + + struct { + uint8_t numSpills; + + std::string debugString() const { + return fmt::format("numSpills {}", numSpills); + } + } testSettings[] = {{1}, {3}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + int64_t peakBytesWithSpilling = 0; + int64_t peakBytesWithOutSpilling = 0; + + for (const auto& spillEnable : {false, true}) { + auto queryCtx = core::QueryCtx::create(executor_.get()); + auto spillDirectory = exec::test::TempDirectoryPath::create(); + const std::string spillEnableConfig = std::to_string(spillEnable); + + std::shared_ptr task; + TestScopedSpillInjection scopedSpillInjection( + 100, ".*", testData.numSpills); + AssertQueryBuilder(plan) + .spillDirectory(spillDirectory->getPath()) + .queryCtx(queryCtx) + .config(core::QueryConfig::kSpillEnabled, spillEnableConfig) + .config( + core::QueryConfig::kMarkDistinctSpillEnabled, spillEnableConfig) + .config(core::QueryConfig::kAggregationSpillEnabled, "false") + .copyResults(pool_.get(), task); + + if (spillEnable) { + peakBytesWithSpilling = queryCtx->pool()->peakBytes(); + auto taskStats = exec::toPlanStats(task->taskStats()); + const auto& stats = taskStats.at(markDistinctId); + + ASSERT_GT(stats.spilledBytes, 0); + ASSERT_GT(stats.spilledRows, 0); + ASSERT_GT(stats.spilledFiles, 0); + ASSERT_GT(stats.spilledPartitions, 0); + } else { + peakBytesWithOutSpilling = queryCtx->pool()->peakBytes(); + } + } + + ASSERT_GT(peakBytesWithOutSpilling, peakBytesWithSpilling); + } +} + +TEST_F(MarkDistinctTest, maxSpillBytes) { + auto rowType = ROW({"c0", "c1", "c2"}, {INTEGER(), INTEGER(), VARCHAR()}); + auto vectors = createVectors(rowType, 1024, 15 << 20); + + auto plan = PlanBuilder() + .values(vectors) + .markDistinct("c1_distinct", {"c0", "c1"}) + .singleAggregation({"c0"}, {"count(c1)"}, {"c1_distinct"}) + .planNode(); + + struct { + int32_t maxSpilledBytes; + bool expectedExceedLimit; + + std::string debugString() const { + return fmt::format("maxSpilledBytes {}", maxSpilledBytes); + } + } testSettings[] = {{1 << 30, false}, {1 << 20, true}, {0, false}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto queryCtx = core::QueryCtx::create(executor_.get()); + try { + TestScopedSpillInjection scopedSpillInjection(100, ".*", 1); + AssertQueryBuilder(plan) + .spillDirectory(spillDirectory->getPath()) + .queryCtx(queryCtx) + .config(core::QueryConfig::kSpillEnabled, true) + .config(core::QueryConfig::kMarkDistinctSpillEnabled, true) + .config(core::QueryConfig::kAggregationSpillEnabled, false) + .config(core::QueryConfig::kMaxSpillBytes, testData.maxSpilledBytes) + .copyResults(pool_.get()); + ASSERT_FALSE(testData.expectedExceedLimit); + } catch (const VeloxRuntimeError& e) { + ASSERT_TRUE(testData.expectedExceedLimit); + ASSERT_NE( + e.message().find( + "Query exceeded per-query local spill limit of 1.00MB"), + std::string::npos); + ASSERT_EQ( + e.errorCode(), facebook::velox::error_code::kSpillLimitExceeded); + } + } +} diff --git a/velox/exec/tests/MarkSortedTest.cpp b/velox/exec/tests/MarkSortedTest.cpp new file mode 100644 index 00000000000..1f79ea64c6d --- /dev/null +++ b/velox/exec/tests/MarkSortedTest.cpp @@ -0,0 +1,482 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" + +using namespace facebook::velox; +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; +using namespace facebook::velox::core; + +class MarkSortedTest : public OperatorTestBase { + protected: + void assertMarkSortedResults( + const std::vector& input, + const std::vector& sortingKeys, + const std::vector& sortingOrders, + const std::string& markerName, + const std::vector& expectedMarkers) { + auto plan = PlanBuilder() + .values(input) + .markSorted(markerName, sortingKeys, sortingOrders) + .planNode(); + + auto results = AssertQueryBuilder(plan).copyResults(pool()); + + // Verify marker column exists and has correct values. + auto markerIdx = static_cast(results->childrenSize() - 1); + auto markerVector = results->childAt(markerIdx); + ASSERT_EQ(markerVector->size(), expectedMarkers.size()); + + auto flatMarker = markerVector->as>(); + for (vector_size_t i = 0; i < expectedMarkers.size(); ++i) { + ASSERT_EQ(flatMarker->valueAt(i), expectedMarkers[i]) + << "Mismatch at row " << i; + } + } + + void assertMarkSortedResultsWithConfig( + const std::vector& input, + const std::vector& sortingKeys, + const std::vector& sortingOrders, + const std::string& markerName, + const std::vector& expectedMarkers, + const std::unordered_map& configs) { + auto plan = PlanBuilder() + .values(input) + .markSorted(markerName, sortingKeys, sortingOrders) + .planNode(); + + auto builder = AssertQueryBuilder(plan); + for (const auto& [key, value] : configs) { + builder.config(key, value); + } + auto results = builder.copyResults(pool()); + + auto markerIdx = static_cast(results->childrenSize() - 1); + auto markerVector = results->childAt(markerIdx); + ASSERT_EQ(markerVector->size(), expectedMarkers.size()); + + auto flatMarker = markerVector->as>(); + for (vector_size_t i = 0; i < expectedMarkers.size(); ++i) { + ASSERT_EQ(flatMarker->valueAt(i), expectedMarkers[i]) + << "Mismatch at row " << i; + } + } +}; + +TEST_F(MarkSortedTest, singleKeyAscSorted) { + auto data = makeRowVector({ + makeFlatVector({1, 2, 3, 4, 5}), + }); + + assertMarkSortedResults( + {data}, + {"c0"}, + {core::kAscNullsLast}, + "is_sorted", + {true, true, true, true, true}); +} + +TEST_F(MarkSortedTest, singleKeyDescSorted) { + auto data = makeRowVector({ + makeFlatVector({5, 4, 3, 2, 1}), + }); + + assertMarkSortedResults( + {data}, + {"c0"}, + {core::kDescNullsLast}, + "is_sorted", + {true, true, true, true, true}); +} + +TEST_F(MarkSortedTest, singleKeyAscUnsorted) { + auto data = makeRowVector({ + makeFlatVector({1, 3, 2, 4, 5}), + }); + + assertMarkSortedResults( + {data}, + {"c0"}, + {core::kAscNullsLast}, + "is_sorted", + {true, true, false, true, true}); +} + +TEST_F(MarkSortedTest, singleKeyDescUnsorted) { + auto data = makeRowVector({ + makeFlatVector({5, 4, 6, 2, 1}), + }); + + assertMarkSortedResults( + {data}, + {"c0"}, + {core::kDescNullsLast}, + "is_sorted", + {true, true, false, true, true}); +} + +TEST_F(MarkSortedTest, multipleKeysSorted) { + auto data = makeRowVector({ + makeFlatVector({1, 1, 2, 2, 3}), + makeFlatVector({1, 2, 1, 2, 1}), + }); + + assertMarkSortedResults( + {data}, + {"c0", "c1"}, + {core::kAscNullsLast, core::kAscNullsLast}, + "is_sorted", + {true, true, true, true, true}); +} + +TEST_F(MarkSortedTest, multipleKeysUnsorted) { + auto data = makeRowVector({ + makeFlatVector({1, 1, 2, 2, 3}), + makeFlatVector({1, 3, 2, 1, 1}), + }); + + assertMarkSortedResults( + {data}, + {"c0", "c1"}, + {core::kAscNullsLast, core::kAscNullsLast}, + "is_sorted", + {true, true, true, false, true}); +} + +TEST_F(MarkSortedTest, nullsFirstSorted) { + auto data = makeRowVector({ + makeNullableFlatVector({std::nullopt, std::nullopt, 1, 2, 3}), + }); + + assertMarkSortedResults( + {data}, + {"c0"}, + {core::kAscNullsFirst}, + "is_sorted", + {true, true, true, true, true}); +} + +TEST_F(MarkSortedTest, nullsLastSorted) { + auto data = makeRowVector({ + makeNullableFlatVector({1, 2, 3, std::nullopt, std::nullopt}), + }); + + assertMarkSortedResults( + {data}, + {"c0"}, + {core::kAscNullsLast}, + "is_sorted", + {true, true, true, true, true}); +} + +TEST_F(MarkSortedTest, nullsFirstUnsorted) { + auto data = makeRowVector({ + makeNullableFlatVector({1, std::nullopt, 2, 3, 4}), + }); + + assertMarkSortedResults( + {data}, + {"c0"}, + {core::kAscNullsFirst}, + "is_sorted", + {true, false, true, true, true}); +} + +TEST_F(MarkSortedTest, crossBatchSorted) { + auto batch1 = makeRowVector({ + makeFlatVector({1, 2, 3}), + }); + auto batch2 = makeRowVector({ + makeFlatVector({4, 5, 6}), + }); + + assertMarkSortedResults( + {batch1, batch2}, + {"c0"}, + {core::kAscNullsLast}, + "is_sorted", + {true, true, true, true, true, true}); +} + +TEST_F(MarkSortedTest, crossBatchUnsorted) { + auto batch1 = makeRowVector({ + makeFlatVector({1, 2, 5}), + }); + auto batch2 = makeRowVector({ + makeFlatVector({3, 4, 6}), + }); + + assertMarkSortedResults( + {batch1, batch2}, + {"c0"}, + {core::kAscNullsLast}, + "is_sorted", + {true, true, true, false, true, true}); +} + +TEST_F(MarkSortedTest, emptyBatch) { + auto data = makeRowVector({ + makeFlatVector({}), + }); + + assertMarkSortedResults( + {data}, {"c0"}, {core::kAscNullsLast}, "is_sorted", {}); +} + +TEST_F(MarkSortedTest, allNullValues) { + auto data = makeRowVector({ + makeNullableFlatVector( + {std::nullopt, std::nullopt, std::nullopt}), + }); + + assertMarkSortedResults( + {data}, {"c0"}, {core::kAscNullsLast}, "is_sorted", {true, true, true}); +} + +TEST_F(MarkSortedTest, firstRowAlwaysTrue) { + auto data = makeRowVector({ + makeFlatVector({100, 1, 2, 3}), + }); + + assertMarkSortedResults( + {data}, + {"c0"}, + {core::kAscNullsLast}, + "is_sorted", + {true, false, true, true}); +} + +TEST_F(MarkSortedTest, singleRow) { + auto data = makeRowVector({ + makeFlatVector({42}), + }); + + assertMarkSortedResults( + {data}, {"c0"}, {core::kAscNullsLast}, "is_sorted", {true}); +} + +TEST_F(MarkSortedTest, stringKey) { + auto data = makeRowVector({ + makeFlatVector({"apple", "banana", "cherry", "date"}), + }); + + assertMarkSortedResults( + {data}, + {"c0"}, + {core::kAscNullsLast}, + "is_sorted", + {true, true, true, true}); +} + +TEST_F(MarkSortedTest, stringKeyUnsorted) { + auto data = makeRowVector({ + makeFlatVector({"apple", "cherry", "banana", "date"}), + }); + + assertMarkSortedResults( + {data}, + {"c0"}, + {core::kAscNullsLast}, + "is_sorted", + {true, true, false, true}); +} + +// --- Optimization tests --- + +TEST_F(MarkSortedTest, constantVectorSorted) { + // All key columns are constant non-null — trivially sorted. + auto data = makeRowVector({ + BaseVector::createConstant(INTEGER(), 42, 5, pool()), + }); + + assertMarkSortedResults( + {data}, + {"c0"}, + {core::kAscNullsLast}, + "is_sorted", + {true, true, true, true, true}); +} + +TEST_F(MarkSortedTest, constantVectorWithNull) { + // Constant null key — falls back to generic path. + auto data = makeRowVector({ + BaseVector::createNullConstant(INTEGER(), 5, pool()), + }); + + assertMarkSortedResults( + {data}, + {"c0"}, + {core::kAscNullsLast}, + "is_sorted", + {true, true, true, true, true}); +} + +TEST_F(MarkSortedTest, simdPathInteger) { + // Single flat non-null INTEGER key — uses SIMD path (ascending). + auto data = makeRowVector({ + makeFlatVector({1, 2, 3, 4, 5, 6, 7, 8}), + }); + + assertMarkSortedResults( + {data}, + {"c0"}, + {core::kAscNullsLast}, + "is_sorted", + {true, true, true, true, true, true, true, true}); +} + +TEST_F(MarkSortedTest, simdPathBigint) { + // SIMD path with BIGINT key (descending), with unsorted violations. + auto data = makeRowVector({ + makeFlatVector({100, 90, 80, 85, 70, 60, 50, 40}), + }); + + assertMarkSortedResults( + {data}, + {"c0"}, + {core::kDescNullsLast}, + "is_sorted", + {true, true, true, false, true, true, true, true}); +} + +TEST_F(MarkSortedTest, genericPathDouble) { + // DOUBLE key — SIMD not applicable for floating point, uses generic path. + auto data = makeRowVector({ + makeFlatVector({1.0, 2.0, 3.0, 2.5, 4.0}), + }); + + assertMarkSortedResults( + {data}, + {"c0"}, + {core::kAscNullsLast}, + "is_sorted", + {true, true, true, false, true}); +} + +TEST_F(MarkSortedTest, simdFallbackMultiKey) { + // Multi-key: SIMD not applicable, falls back to generic. + auto data = makeRowVector({ + makeFlatVector({1, 1, 2, 2, 3}), + makeFlatVector({1, 2, 1, 2, 1}), + }); + + assertMarkSortedResults( + {data}, + {"c0", "c1"}, + {core::kAscNullsLast, core::kAscNullsLast}, + "is_sorted", + {true, true, true, true, true}); +} + +TEST_F(MarkSortedTest, simdFallbackVarchar) { + // VARCHAR key: SIMD not applicable, falls back to generic. + auto data = makeRowVector({ + makeFlatVector({"a", "b", "c", "d", "e"}), + }); + + assertMarkSortedResults( + {data}, + {"c0"}, + {core::kAscNullsLast}, + "is_sorted", + {true, true, true, true, true}); +} + +TEST_F(MarkSortedTest, simdFallbackWithNulls) { + // Flat vector with nulls — SIMD not applicable, falls back to generic. + // With kAscNullsLast, null > any value, so row 3 (4 after null) is unsorted. + auto data = makeRowVector({ + makeNullableFlatVector({1, 2, std::nullopt, 4, 5}), + }); + + assertMarkSortedResults( + {data}, + {"c0"}, + {core::kAscNullsLast}, + "is_sorted", + {true, true, true, false, true}); +} + +TEST_F(MarkSortedTest, zeroCopySmallBatch) { + // Small batches (< threshold) use zero-copy cross-batch. Verify correctness. + auto batch1 = makeRowVector({ + makeFlatVector({1, 2, 3}), + }); + auto batch2 = makeRowVector({ + makeFlatVector({4, 5, 6}), + }); + + // Default threshold is 1000, so 3-row batches use zero-copy. + assertMarkSortedResults( + {batch1, batch2}, + {"c0"}, + {core::kAscNullsLast}, + "is_sorted", + {true, true, true, true, true, true}); +} + +TEST_F(MarkSortedTest, copyLargeBatch) { + // Set threshold very low so even small batches use copy mode. + auto batch1 = makeRowVector({ + makeFlatVector({1, 2, 5}), + }); + auto batch2 = makeRowVector({ + makeFlatVector({3, 4, 6}), + }); + + assertMarkSortedResultsWithConfig( + {batch1, batch2}, + {"c0"}, + {core::kAscNullsLast}, + "is_sorted", + {true, true, true, false, true, true}, + {{QueryConfig::kMarkSortedZeroCopyThreshold, "1"}}); +} + +TEST_F(MarkSortedTest, zeroCopyThresholdConfig) { + // Verify threshold is configurable: set to 2 so batch of 3 uses copy mode + // and batch of 1 uses zero-copy. + auto batch1 = makeRowVector({ + makeFlatVector({10}), + }); + auto batch2 = makeRowVector({ + makeFlatVector({5, 6, 7}), + }); + + // batch1 (1 row) < threshold(2) → zero-copy + // batch2 first row (5) < prev (10) → unsorted for ascending + assertMarkSortedResultsWithConfig( + {batch1, batch2}, + {"c0"}, + {core::kAscNullsLast}, + "is_sorted", + {true, false, true, true}, + {{QueryConfig::kMarkSortedZeroCopyThreshold, "2"}}); +} + +TEST_F(MarkSortedTest, simdSmallBatch) { + // SIMD works correctly with very small batches (< 16 rows). + auto data = makeRowVector({ + makeFlatVector({3, 1, 2}), + }); + + assertMarkSortedResults( + {data}, {"c0"}, {core::kAscNullsLast}, "is_sorted", {true, false, true}); +} diff --git a/velox/exec/tests/MergeJoinTest.cpp b/velox/exec/tests/MergeJoinTest.cpp index 4019cc9b80b..cbaae09a0bf 100644 --- a/velox/exec/tests/MergeJoinTest.cpp +++ b/velox/exec/tests/MergeJoinTest.cpp @@ -16,11 +16,12 @@ #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/testutil/TestValue.h" +#include "velox/exec/PlanNodeStats.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "folly/experimental/EventCount.h" +#include "folly/synchronization/EventCount.h" using namespace facebook::velox; using namespace facebook::velox::common::testutil; @@ -112,12 +113,13 @@ class MergeJoinTest : public HiveConnectorTestBase { for (const auto& row : input) { std::vector children; for (const auto& child : row->children()) { - children.push_back(std::make_shared( - pool(), - child->type(), - child->size(), - std::make_unique( - batchId, counter, [=, this](RowSet) { return child; }))); + children.push_back( + std::make_shared( + pool(), + child->type(), + child->size(), + std::make_unique( + batchId, counter, [=, this](RowSet) { return child; }))); } data.push_back(makeRowVector(children)); @@ -369,6 +371,87 @@ class MergeJoinTest : public HiveConnectorTestBase { std::bind( &MergeJoinTest::generateLazyInput, this, std::placeholders::_1)); } + + void testJoinTwoKeysWithNulls( + RowVectorPtr& leftVectors, + RowVectorPtr& rightVectors) { + auto leftFile = TempFilePath::create(); + writeToFile(leftFile->getPath(), leftVectors); + createDuckDbTable("t", {leftVectors}); + auto rightFile = TempFilePath::create(); + writeToFile(rightFile->getPath(), rightVectors); + createDuckDbTable("u", {rightVectors}); + + auto joinTypes = { + core::JoinType::kInner, + core::JoinType::kLeft, + core::JoinType::kRight, + core::JoinType::kFull, + }; + + for (auto joinType : joinTypes) { + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId leftScanId; + core::PlanNodeId rightScanId; + auto op = PlanBuilder(planNodeIdGenerator) + .tableScan( + ROW({"c0", "c1", "c2", "c3"}, + {VARCHAR(), VARCHAR(), VARCHAR(), VARCHAR()})) + .capturePlanNodeId(leftScanId) + .mergeJoin( + {"c0", "c1"}, + {"rc0", "rc1"}, + PlanBuilder(planNodeIdGenerator) + .tableScan( + ROW({"rc0", "rc1", "rc2"}, + {VARCHAR(), VARCHAR(), VARCHAR()})) + .capturePlanNodeId(rightScanId) + .planNode(), + "", + {"c0", "c1", "c2", "c3", "rc0", "rc1", "rc2"}, + joinType) + .planNode(); + AssertQueryBuilder(op, duckDbQueryRunner_) + .split(rightScanId, makeHiveConnectorSplit(rightFile->getPath())) + .split(leftScanId, makeHiveConnectorSplit(leftFile->getPath())) + .assertResults( + fmt::format( + "SELECT * FROM t {} JOIN u " + "ON t.c0 = u.rc0 AND t.c1 = u.rc1", + core::JoinTypeName::toName(joinType))); + } + + { + // anti join + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId leftScanId; + core::PlanNodeId rightScanId; + auto op = PlanBuilder(planNodeIdGenerator) + .tableScan( + ROW({"c0", "c1", "c2", "c3"}, + {VARCHAR(), VARCHAR(), VARCHAR(), VARCHAR()})) + .capturePlanNodeId(leftScanId) + .mergeJoin( + {"c0", "c1"}, + {"rc0", "rc1"}, + PlanBuilder(planNodeIdGenerator) + .tableScan( + ROW({"rc0", "rc1", "rc2"}, + {VARCHAR(), VARCHAR(), VARCHAR()})) + .capturePlanNodeId(rightScanId) + .planNode(), + "", + {"c0", "c1", "c2", "c3"}, + core::JoinType::kAnti) + .planNode(); + AssertQueryBuilder(op, duckDbQueryRunner_) + .split(rightScanId, makeHiveConnectorSplit(rightFile->getPath())) + .split(leftScanId, makeHiveConnectorSplit(leftFile->getPath())) + .assertResults( + "SELECT * FROM t WHERE NOT exists (select * from u " + "where t.c0 = u.rc0 AND t.c1 = u.rc1)"); + } + } }; TEST_F(MergeJoinTest, oneToOneAllMatch) { @@ -870,10 +953,11 @@ TEST_F(MergeJoinTest, lazyVectors) { AssertQueryBuilder(op, duckDbQueryRunner_) .split(rightScanId, makeHiveConnectorSplit(rightFile->getPath())) .split(leftScanId, makeHiveConnectorSplit(leftFile->getPath())) - .assertResults(fmt::format( - "SELECT c0, rc0, c1, rc1, c2, c3 FROM t {} JOIN u " - "ON t.c0 = u.rc0 AND c1 + rc1 < 30", - core::JoinTypeName::toName(joinType))); + .assertResults( + fmt::format( + "SELECT c0, rc0, c1, rc1, c2, c3 FROM t {} JOIN u " + "ON t.c0 = u.rc0 AND c1 + rc1 < 30", + core::JoinTypeName::toName(joinType))); } } @@ -1310,6 +1394,160 @@ TEST_F(MergeJoinTest, antiJoinWithTwoJoinKeys) { "SELECT * FROM t WHERE NOT exists (select * from u where t.a = u.c and t.b < u.d)"); } +TEST_F(MergeJoinTest, matchRatioStats) { + // Test match ratio statistics for different join scenarios. + + // Inner join with full match (all rows match). + { + auto left = makeRowVector( + {"t0"}, {makeNullableFlatVector({1, 2, 3, 4, 5})}); + auto right = makeRowVector( + {"u0"}, {makeNullableFlatVector({1, 2, 3, 4, 5})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId mergeJoinNodeId; + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + "", + {"t0", "u0"}, + core::JoinType::kInner) + .capturePlanNodeId(mergeJoinNodeId) + .planNode(); + + auto task = AssertQueryBuilder(plan, duckDbQueryRunner_) + .assertResults("SELECT t0, u0 FROM t, u WHERE t0 = u0"); + + auto stats = toPlanStats(task->taskStats()); + ASSERT_EQ(stats.at(mergeJoinNodeId).outputRows, 5); + + auto runtimeStats = stats.at(mergeJoinNodeId).customStats; + ASSERT_EQ(runtimeStats.at("matchedLeftRows").sum, 5); + ASSERT_EQ(runtimeStats.at("matchedRightRows").sum, 5); + } + + // Inner join with partial match. + { + auto left = makeRowVector( + {"t0"}, {makeNullableFlatVector({1, 2, 3, 4, 5, 6, 7, 8})}); + auto right = makeRowVector( + {"u0"}, {makeNullableFlatVector({2, 4, 6, 10, 12})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId mergeJoinNodeId; + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + "", + {"t0", "u0"}, + core::JoinType::kInner) + .capturePlanNodeId(mergeJoinNodeId) + .planNode(); + + auto task = AssertQueryBuilder(plan, duckDbQueryRunner_) + .assertResults("SELECT t0, u0 FROM t, u WHERE t0 = u0"); + + auto stats = toPlanStats(task->taskStats()); + ASSERT_EQ(stats.at(mergeJoinNodeId).outputRows, 3); + + auto runtimeStats = stats.at(mergeJoinNodeId).customStats; + // Only 3 left rows match (2, 4, 6). + ASSERT_EQ(runtimeStats.at("matchedLeftRows").sum, 3); + // Only 3 right rows match (2, 4, 6). + ASSERT_EQ(runtimeStats.at("matchedRightRows").sum, 3); + } + + // Left join - all left rows appear in output. + { + auto left = makeRowVector( + {"t0"}, {makeNullableFlatVector({1, 2, 3, 4, 5})}); + auto right = + makeRowVector({"u0"}, {makeNullableFlatVector({2, 4})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId mergeJoinNodeId; + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + "", + {"t0", "u0"}, + core::JoinType::kLeft) + .capturePlanNodeId(mergeJoinNodeId) + .planNode(); + + auto task = + AssertQueryBuilder(plan, duckDbQueryRunner_) + .assertResults("SELECT t0, u0 FROM t LEFT JOIN u ON t0 = u0"); + + auto stats = toPlanStats(task->taskStats()); + ASSERT_EQ(stats.at(mergeJoinNodeId).outputRows, 5); + + auto runtimeStats = stats.at(mergeJoinNodeId).customStats; + // Only 2 left rows match (2, 4). + ASSERT_EQ(runtimeStats.at("matchedLeftRows").sum, 2); + ASSERT_EQ(runtimeStats.at("matchedRightRows").sum, 2); + } + + // Join with duplicate keys (cartesian product). + { + auto left = makeRowVector( + {"t0"}, {makeNullableFlatVector({1, 1, 1, 2, 2})}); + auto right = makeRowVector( + {"u0"}, {makeNullableFlatVector({1, 1, 2, 2, 2})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId mergeJoinNodeId; + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + "", + {"t0", "u0"}, + core::JoinType::kInner) + .capturePlanNodeId(mergeJoinNodeId) + .planNode(); + + auto task = AssertQueryBuilder(plan, duckDbQueryRunner_) + .assertResults("SELECT t0, u0 FROM t, u WHERE t0 = u0"); + + auto stats = toPlanStats(task->taskStats()); + ASSERT_EQ(stats.at(mergeJoinNodeId).outputRows, 12); + + auto runtimeStats = stats.at(mergeJoinNodeId).customStats; + // 3 left rows with key=1 and 2 left rows with key=2. + ASSERT_EQ(runtimeStats.at("matchedLeftRows").sum, 5); + // 2 right rows with key=1 and 3 right rows with key=2. + ASSERT_EQ(runtimeStats.at("matchedRightRows").sum, 5); + } +} + TEST_F(MergeJoinTest, antiJoinWithUniqueJoinKeys) { auto left = makeRowVector( {"a", "b"}, @@ -1704,6 +1942,20 @@ TEST_F(MergeJoinTest, barrier) { createDuckDbTable("t", {left}); createDuckDbTable("u", {right}); + struct { + bool hasBarrier; + bool serialExecution; + + std::string toString() const { + return fmt::format( + "hasBarrier: {}, serialExecution: {}", hasBarrier, serialExecution); + } + } testSettings[] = { + {false, false}, + {false, true}, + {true, false}, + {true, true}, + }; { // Inner join. @@ -1730,10 +1982,12 @@ TEST_F(MergeJoinTest, barrier) { {"t0", "t1", "u0", "u1"}, core::JoinType::kInner) .planNode(); - for (const auto hasBarrier : {false, true}) { - SCOPED_TRACE(fmt::format("hasBarrier {}", hasBarrier)); + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.toString()); AssertQueryBuilder queryBuilder(plan, duckDbQueryRunner_); - queryBuilder.barrierExecution(hasBarrier).serialExecution(true); + queryBuilder.barrierExecution(testData.hasBarrier) + .serialExecution(testData.serialExecution) + .maxDrivers(testData.serialExecution ? 1 : 3); queryBuilder.split( leftNodeId, makeHiveConnectorSplit(leftFile->getPath())); queryBuilder.split( @@ -1742,8 +1996,10 @@ TEST_F(MergeJoinTest, barrier) { const auto task = queryBuilder.assertResults( "SELECT t0, t1, u0, u1 FROM t INNER JOIN u ON t.t0 = u.u0"); - ASSERT_EQ(task->taskStats().numBarriers, hasBarrier ? 1 : 0); - ASSERT_EQ(task->taskStats().numFinishedSplits, hasBarrier ? 2 : 1); + ASSERT_EQ(task->taskStats().numBarriers, testData.hasBarrier ? 1 : 0); + ASSERT_EQ( + task->taskStats().numFinishedSplits, + (testData.hasBarrier || !testData.serialExecution) ? 2 : 1); } } @@ -1772,10 +2028,12 @@ TEST_F(MergeJoinTest, barrier) { {"t0", "t1", "u0", "u1"}, core::JoinType::kFull) .planNode(); - for (const auto hasBarrier : {false, true}) { - SCOPED_TRACE(fmt::format("hasBarrier {}", hasBarrier)); + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.toString()); AssertQueryBuilder queryBuilder(plan, duckDbQueryRunner_); - queryBuilder.barrierExecution(hasBarrier).serialExecution(true); + queryBuilder.barrierExecution(testData.hasBarrier) + .serialExecution(testData.serialExecution) + .maxDrivers(testData.serialExecution ? 1 : 3); queryBuilder.split( leftNodeId, makeHiveConnectorSplit(leftFile->getPath())); queryBuilder.split( @@ -1784,7 +2042,7 @@ TEST_F(MergeJoinTest, barrier) { const auto task = queryBuilder.assertResults( "SELECT t0, t1, u0, u1 FROM t FULL OUTER JOIN u ON t.t0 = u.u0"); - ASSERT_EQ(task->taskStats().numBarriers, hasBarrier ? 1 : 0); + ASSERT_EQ(task->taskStats().numBarriers, testData.hasBarrier ? 1 : 0); ASSERT_EQ(task->taskStats().numFinishedSplits, 2); } } @@ -1814,10 +2072,12 @@ TEST_F(MergeJoinTest, barrier) { {"t0", "t1", "u0", "u1"}, core::JoinType::kRight) .planNode(); - for (const auto hasBarrier : {false, true}) { - SCOPED_TRACE(fmt::format("hasBarrier {}", hasBarrier)); + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.toString()); AssertQueryBuilder queryBuilder(plan, duckDbQueryRunner_); - queryBuilder.barrierExecution(hasBarrier).serialExecution(true); + queryBuilder.barrierExecution(testData.hasBarrier) + .serialExecution(testData.serialExecution) + .maxDrivers(testData.serialExecution ? 1 : 3); queryBuilder.split( leftNodeId, makeHiveConnectorSplit(leftFile->getPath())); queryBuilder.split( @@ -1826,7 +2086,7 @@ TEST_F(MergeJoinTest, barrier) { const auto task = queryBuilder.assertResults( "SELECT t0, t1, u0, u1 FROM t RIGHT JOIN u ON t.t0 = u.u0"); - ASSERT_EQ(task->taskStats().numBarriers, hasBarrier ? 1 : 0); + ASSERT_EQ(task->taskStats().numBarriers, testData.hasBarrier ? 1 : 0); ASSERT_EQ(task->taskStats().numFinishedSplits, 2); } } @@ -1856,10 +2116,12 @@ TEST_F(MergeJoinTest, barrier) { {"t0", "t1", "u0", "u1"}, core::JoinType::kLeft) .planNode(); - for (const auto hasBarrier : {true}) { - SCOPED_TRACE(fmt::format("hasBarrier {}", hasBarrier)); + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.toString()); AssertQueryBuilder queryBuilder(plan, duckDbQueryRunner_); - queryBuilder.barrierExecution(hasBarrier).serialExecution(true); + queryBuilder.barrierExecution(testData.hasBarrier) + .serialExecution(testData.serialExecution) + .maxDrivers(testData.serialExecution ? 1 : 3); queryBuilder.split( leftNodeId, makeHiveConnectorSplit(leftFile->getPath())); queryBuilder.split( @@ -1868,8 +2130,10 @@ TEST_F(MergeJoinTest, barrier) { const auto task = queryBuilder.assertResults( "SELECT t0, t1, u0, u1 FROM t LEFT JOIN u ON t.t0 = u.u0"); - ASSERT_EQ(task->taskStats().numBarriers, hasBarrier ? 1 : 0); - ASSERT_EQ(task->taskStats().numFinishedSplits, hasBarrier ? 2 : 1); + ASSERT_EQ(task->taskStats().numBarriers, testData.hasBarrier ? 1 : 0); + ASSERT_EQ( + task->taskStats().numFinishedSplits, + (testData.hasBarrier || !testData.serialExecution) ? 2 : 1); } } @@ -1898,10 +2162,12 @@ TEST_F(MergeJoinTest, barrier) { {"t0", "t1"}, core::JoinType::kAnti) .planNode(); - for (const auto hasBarrier : {true}) { - SCOPED_TRACE(fmt::format("hasBarrier {}", hasBarrier)); + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.toString()); AssertQueryBuilder queryBuilder(plan, duckDbQueryRunner_); - queryBuilder.barrierExecution(hasBarrier).serialExecution(true); + queryBuilder.barrierExecution(testData.hasBarrier) + .serialExecution(testData.serialExecution) + .maxDrivers(testData.serialExecution ? 1 : 3); queryBuilder.split( leftNodeId, makeHiveConnectorSplit(leftFile->getPath())); queryBuilder.split( @@ -1910,10 +2176,13 @@ TEST_F(MergeJoinTest, barrier) { const auto task = queryBuilder.assertResults( "SELECT t0, t1 FROM t WHERE NOT exists (select u0, u1 from u where t0 = u0)"); - ASSERT_EQ(task->taskStats().numBarriers, hasBarrier ? 1 : 0); - ASSERT_EQ(task->taskStats().numFinishedSplits, hasBarrier ? 2 : 1); + ASSERT_EQ(task->taskStats().numBarriers, testData.hasBarrier ? 1 : 0); + ASSERT_EQ( + task->taskStats().numFinishedSplits, + (testData.hasBarrier || !testData.serialExecution) ? 2 : 1); } } + waitForAllTasksToBeDeleted(); } TEST_F(MergeJoinTest, antiJoinWithFilterWithMultiMatchedRows) { @@ -1977,3 +2246,409 @@ TEST_F(MergeJoinTest, antiJoinWithTwoJoinKeysInDifferentBatch) { .assertResults( "SELECT * FROM t WHERE NOT exists (select * from u where t.a = u.c and t.b < u.d)"); } + +TEST_F(MergeJoinTest, testJoinWithTwoKeysAndSecondColumnHasNulls) { + auto left = makeRowVector( + {"c0", "c1", "c2", "c3"}, + { + makeNullableFlatVector( + {"202408", "202409", "202409", "202410"}), + makeNullableFlatVector({"1", std::nullopt, "2", "3"}), + makeNullableFlatVector({"1", "2", "2", "3"}), + makeNullableFlatVector({"1", "2", "2", "3"}), + }); + auto right = makeRowVector( + {"rc0", "rc1", "rc2"}, + {makeNullableFlatVector( + {"202408", "202409", "202409", "202410"}), + makeNullableFlatVector({"1", std::nullopt, "2", "3"}), + makeNullableFlatVector({"1", std::nullopt, "2", "3"})}); + + testJoinTwoKeysWithNulls(left, right); +} + +// Test that the dynamic output batch sizing follows the expected growth pattern +// Verifies that when dynamic batch sizing is enabled +// (mergeJoinOutputBatchStartSize > 0), MergeJoin adjusts output batch size +// based on average row size and preferred bytes. +TEST_F(MergeJoinTest, dynamicOutputBatchSizing) { + // Create simple two-column BIGINT data for both left and right sides. + // Each row is approximately 16 bytes (2 x 8 bytes for BIGINT). + // We create enough rows to see multiple output batches. + const vector_size_t numRows = 1000; + + auto left = makeRowVector( + {"t_c0", "t_c1"}, + { + makeFlatVector(numRows, [](auto row) { return row; }), + makeFlatVector(numRows, [](auto row) { return row * 10; }), + }); + + auto right = makeRowVector( + {"u_c0", "u_c1"}, + { + makeFlatVector(numRows, [](auto row) { return row; }), + makeFlatVector(numRows, [](auto row) { return row * 100; }), + }); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId mergeJoinNodeId; + + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t_c0"}, + {"u_c0"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + "", + {"t_c0", "t_c1", "u_c1"}, + core::JoinType::kInner) + .capturePlanNodeId(mergeJoinNodeId) + .planNode(); + + // Set a very high preferred row count (10000) but a relatively small byte + // limit. The dynamic sizing should compute batch size based on + // preferredBytes / avgRowSize, rather than immediately producing 10000 rows. + const uint32_t highPreferredRows = 10000; + // Set preferred bytes to allow roughly 64 rows per batch (each output row is + // ~24 bytes: 3 x 8 bytes for 3 BIGINT columns). + const uint64_t preferredBytes = 24 * 64; + + auto queryCtx = core::QueryCtx::create(executor_.get()); + queryCtx->testingOverrideConfigUnsafe({ + {core::QueryConfig::kPreferredOutputBatchRows, + std::to_string(highPreferredRows)}, + {core::QueryConfig::kPreferredOutputBatchBytes, + std::to_string(preferredBytes)}, + // Enable dynamic batch sizing by setting a non-zero start size. + {core::QueryConfig::kMergeJoinOutputBatchStartSize, "1"}, + }); + + CursorParameters params; + params.planNode = plan; + params.queryCtx = queryCtx; + + auto cursor = TaskCursor::create(params); + cursor->start(); + + std::vector outputBatchSizes; + while (cursor->moveNext()) { + auto result = cursor->current(); + if (result && result->size() > 0) { + outputBatchSizes.push_back(result->size()); + } + } + + // Verify that we got multiple output batches. + ASSERT_GT(outputBatchSizes.size(), 1); + + // Verify the first batch starts small (should be 1 row due to initial + // outputBatchSize_ = 1). + EXPECT_EQ(outputBatchSizes[0], 1); + + // After the first batch, the batch size should immediately jump to the + // computed size based on preferredBytes / avgRowSize (roughly 64 rows). + // All subsequent batches (except the last which may have fewer remaining + // rows) should be approximately this computed size. + for (size_t i = 1; i < outputBatchSizes.size() - 1; ++i) { + EXPECT_GT(outputBatchSizes[i], 8) + << "Batch " << i + << " should be computed based on preferredBytes / avgRowSize"; + } + + // Verify we never exceed the preferred row count. + for (size_t i = 0; i < outputBatchSizes.size(); ++i) { + EXPECT_LE(outputBatchSizes[i], highPreferredRows); + } +} + +// Test that filterInput_ properly handles the case where outputBatchSize_ +// increases after initial creation. +// +// This test uses variable-length string data where: +// - First batch of rows have LARGE strings (causing small batch size when +// filterInput_ is created) +// - Later batches have SMALL strings (causing batch size to increase) +// +// CRITICAL: The filter expression references columns that are NOT in the +// output projection. This creates non-shared child vectors in filterInput_ +// that are allocated with the initial outputBatchSize_. When outputBatchSize_ +// increases, copyRow() would write beyond the buffer capacity if not handled +// correctly. +TEST_F(MergeJoinTest, dynamicOutputBatchSizingWithFilter) { + const vector_size_t numRows = 1000; + + // Left side has 4 columns: t_c0 (join key), t_c1 (in output), t_c2 (in + // output), t_c3 (ONLY in filter, NOT in output). + auto left = makeRowVector( + {"t_c0", "t_c1", "t_c2", "t_c3"}, + { + makeFlatVector(numRows, [](auto row) { return row; }), + makeFlatVector(numRows, [](auto row) { return row * 10; }), + makeFlatVector( + numRows, + [](auto row) { + if (row < 100) { + return std::string(1000, 'A' + (row % 26)); + } else { + return std::to_string(row); + } + }), + makeFlatVector(numRows, [](auto row) { return row % 10; }), + }); + + auto right = makeRowVector( + {"u_c0", "u_c1", "u_c2", "u_c3"}, + { + makeFlatVector(numRows, [](auto row) { return row; }), + makeFlatVector(numRows, [](auto row) { return row * 100; }), + makeFlatVector( + numRows, + [](auto row) { + if (row < 100) { + return std::string(1000, 'X' + (row % 3)); + } else { + return std::to_string(row * 2); + } + }), + makeFlatVector(numRows, [](auto row) { return row % 5; }), + }); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId mergeJoinNodeId; + + // The filter references t_c3 and u_c3 which are NOT in the output projection. + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t_c0"}, + {"u_c0"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + "(t_c3 + u_c3) % 3 = 0", + {"t_c0", "t_c1", "t_c2", "u_c1", "u_c2"}, + core::JoinType::kInner) + .capturePlanNodeId(mergeJoinNodeId) + .planNode(); + + const uint32_t highPreferredRows = 10000; + const uint64_t preferredBytes = 10000; + + auto queryCtx = core::QueryCtx::create(executor_.get()); + queryCtx->testingOverrideConfigUnsafe({ + {core::QueryConfig::kPreferredOutputBatchRows, + std::to_string(highPreferredRows)}, + {core::QueryConfig::kPreferredOutputBatchBytes, + std::to_string(preferredBytes)}, + {core::QueryConfig::kMergeJoinOutputBatchStartSize, "1"}, + }); + + // If filterInput_ buffer overflow occurs, ASAN will catch it. + AssertQueryBuilder(plan, duckDbQueryRunner_) + .queryCtx(queryCtx) + .assertResults( + "SELECT t_c0, t_c1, t_c2, u_c1, u_c2 FROM t, u " + "WHERE t_c0 = u_c0 AND (t.t_c3 + u.u_c3) % 3 = 0"); +} + +// Verifies that when mergeJoinOutputBatchStartSize is 0 (default), dynamic +// batch sizing is disabled and the batch size is fixed at +// preferredOutputBatchRows. +TEST_F(MergeJoinTest, dynamicOutputBatchSizingDisabledByDefault) { + const vector_size_t numRows = 1000; + + auto left = makeRowVector( + {"t_c0", "t_c1"}, + { + makeFlatVector(numRows, [](auto row) { return row; }), + makeFlatVector(numRows, [](auto row) { return row * 10; }), + }); + + auto right = makeRowVector( + {"u_c0", "u_c1"}, + { + makeFlatVector(numRows, [](auto row) { return row; }), + makeFlatVector(numRows, [](auto row) { return row * 100; }), + }); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId mergeJoinNodeId; + + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t_c0"}, + {"u_c0"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + "", + {"t_c0", "t_c1", "u_c1"}, + core::JoinType::kInner) + .capturePlanNodeId(mergeJoinNodeId) + .planNode(); + + // Use the default config (mergeJoinOutputBatchStartSize = 0), which disables + // dynamic batch sizing. The batch size should be fixed at + // preferredOutputBatchRows. + const uint32_t preferredRows = 100; + + auto queryCtx = core::QueryCtx::create(executor_.get()); + queryCtx->testingOverrideConfigUnsafe({ + {core::QueryConfig::kPreferredOutputBatchRows, + std::to_string(preferredRows)}, + }); + + CursorParameters params; + params.planNode = plan; + params.queryCtx = queryCtx; + + auto cursor = TaskCursor::create(params); + cursor->start(); + + std::vector outputBatchSizes; + while (cursor->moveNext()) { + auto result = cursor->current(); + if (result && result->size() > 0) { + outputBatchSizes.push_back(result->size()); + } + } + + // Verify that we got multiple output batches. + ASSERT_GT(outputBatchSizes.size(), 1); + + // Since dynamic batch sizing is disabled, all batches (except possibly the + // last one) should be exactly preferredRows in size. + for (size_t i = 0; i < outputBatchSizes.size() - 1; ++i) { + EXPECT_EQ(outputBatchSizes[i], preferredRows) + << "Batch " << i << " should be exactly " << preferredRows + << " rows when dynamic batch sizing is disabled"; + } + + // The last batch should be <= preferredRows (may have fewer remaining rows). + EXPECT_LE(outputBatchSizes.back(), preferredRows); +} + +TEST_F(MergeJoinTest, flatMapVectorInnerJoin) { + auto left = makeRowVector( + {"t_c0", "t_c1"}, + { + makeFlatVector({1, 2, 3}), + makeFlatMapVector({ + {{1, 10}, {2, 20}}, + {{1, 30}, {3, 40}}, + {{2, 50}}, + }), + }); + + auto right = makeRowVector( + {"u_c0", "u_c1"}, + { + makeFlatVector({1, 2, 4}), + makeFlatMapVector({ + {{1, 100}, {2, 200}}, + {{3, 300}}, + {{1, 400}, {2, 500}, {3, 600}}, + }), + }); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t_c0"}, + {"u_c0"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + "", + {"t_c0", "t_c1", "u_c1"}, + core::JoinType::kInner) + .planNode(); + + auto expected = makeRowVector({ + makeFlatVector({1, 2}), + makeFlatMapVector({ + {{1, 10}, {2, 20}}, + {{1, 30}, {3, 40}}, + }), + makeFlatMapVector({ + {{1, 100}, {2, 200}}, + {{3, 300}}, + }), + }); + + CursorParameters params; + params.planNode = plan; + auto [cursor, results] = readCursor(params); + facebook::velox::test::assertEqualVectors(expected, results[0]); + ASSERT_EQ( + results[0]->childAt(1)->wrappedVector()->encoding(), + VectorEncoding::Simple::FLAT_MAP); + ASSERT_EQ( + results[0]->childAt(2)->wrappedVector()->encoding(), + VectorEncoding::Simple::FLAT_MAP); +} + +TEST_F(MergeJoinTest, flatMapVectorLeftJoin) { + auto left = makeRowVector( + {"t_c0", "t_c1"}, + { + makeFlatVector({1, 2, 3}), + makeFlatMapVector({ + {{1, 10}, {2, 20}}, + {{1, 30}, {3, 40}}, + {{2, 50}}, + }), + }); + + auto right = makeRowVector( + {"u_c0", "u_c1"}, + { + makeFlatVector({1, 2}), + makeFlatMapVector({ + {{1, 100}, {2, 200}}, + {{3, 300}}, + }), + }); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t_c0"}, + {"u_c0"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + "", + {"t_c0", "t_c1", "u_c1"}, + core::JoinType::kLeft) + .planNode(); + + auto expected = makeRowVector({ + makeFlatVector({1, 2, 3}), + makeFlatMapVector({ + {{1, 10}, {2, 20}}, + {{1, 30}, {3, 40}}, + {{2, 50}}, + }), + makeNullableFlatMapVector({ + {{{1, 100}, {2, 200}}}, + {{{3, 300}}}, + std::nullopt, + }), + }); + + CursorParameters params; + params.planNode = plan; + auto [cursor, results] = readCursor(params); + facebook::velox::test::assertEqualVectors(expected, results[0]); + ASSERT_EQ( + results[0]->childAt(1)->wrappedVector()->encoding(), + VectorEncoding::Simple::FLAT_MAP); + ASSERT_EQ( + results[0]->childAt(2)->wrappedVector()->encoding(), + VectorEncoding::Simple::FLAT_MAP); +} diff --git a/velox/exec/tests/MergeTest.cpp b/velox/exec/tests/MergeTest.cpp index 02f80fcd140..cf0c6de444f 100644 --- a/velox/exec/tests/MergeTest.cpp +++ b/velox/exec/tests/MergeTest.cpp @@ -15,14 +15,14 @@ */ #include "velox/exec/Merge.h" -#include "folly/experimental/EventCount.h" +#include "folly/synchronization/EventCount.h" #include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/common/testutil/TestValue.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/OperatorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" using namespace facebook::velox; using namespace facebook::velox::exec; @@ -403,9 +403,13 @@ class MergeTest : public OperatorTestBase { ASSERT_EQ(planStats.spilledFiles, 0); ASSERT_EQ(planStats.spilledRows, 0); ASSERT_EQ( - planStats.customStats.count(Merge::kSpilledSourceReadWallNanos), 0); + planStats.customStats.count( + std::string(Merge::kSpilledSourceReadWallNanos)), + 0); ASSERT_GE( - planStats.customStats.count(Merge::kStreamingSourceReadWallNanos), 0); + planStats.customStats.count( + std::string(Merge::kStreamingSourceReadWallNanos)), + 0); } else { const auto expectedFiles = (inputVectors.size() + numMaxMergeSources - 1) / numMaxMergeSources; @@ -415,9 +419,13 @@ class MergeTest : public OperatorTestBase { ASSERT_EQ(planStats.spilledFiles, expectedFiles); ASSERT_EQ(planStats.spilledRows, expectedSpillRows); ASSERT_GE( - planStats.customStats.count(Merge::kSpilledSourceReadWallNanos), 0); + planStats.customStats.count( + std::string(Merge::kSpilledSourceReadWallNanos)), + 0); ASSERT_GE( - planStats.customStats.count(Merge::kStreamingSourceReadWallNanos), 0); + planStats.customStats.count( + std::string(Merge::kStreamingSourceReadWallNanos)), + 0); } ASSERT_EQ( planStats.outputRows, @@ -601,6 +609,38 @@ TEST_F(MergeTest, localMergeSpillPartialEmpty) { ASSERT_EQ(planStats.spilledRows, 120); } +DEBUG_ONLY_TEST_F(MergeTest, localMergeSpillWithException) { + std::vector vectors; + for (int32_t i = 0; i < 9; ++i) { + constexpr vector_size_t batchSize = 137; + auto c0 = makeFlatVector( + batchSize, [&](auto row) { return batchSize * i + row; }, nullEvery(5)); + auto c1 = makeFlatVector( + batchSize, [&](auto row) { return row; }, nullEvery(5)); + auto c2 = makeFlatVector( + batchSize, [](auto row) { return row * 0.1; }, nullEvery(11)); + auto c3 = makeFlatVector(batchSize, [](auto row) { + return StringView::makeInline(std::to_string(row)); + }); + vectors.push_back(makeRowVector({c0, c1, c2, c3})); + } + createDuckDbTable(vectors); + + for (auto i = 0; i < 11; ++i) { + std::atomic_int cnt{0}; + const auto errorMessage = "ConcatFilesSpillBatchStream::nextBatch fail"; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::ConcatFilesSpillBatchStream::nextBatch", + std::function([&](void* /*unused*/) { + if (cnt++ == i) { + VELOX_FAIL("ConcatFilesSpillBatchStream::nextBatch fail"); + } + })); + + VELOX_ASSERT_THROW(testSingleKeyWithSpill(vectors, "c0"), errorMessage); + } +} + DEBUG_ONLY_TEST_F(MergeTest, localMergeSmallBatch) { std::vector vectors; for (int32_t i = 0; i < 9; ++i) { @@ -735,7 +775,7 @@ DEBUG_ONLY_TEST_F(MergeTest, localMergeAbort) { })); SCOPED_TESTVALUE_SET( - "facebook::velox::exec::SpillMerger::asyncReadFromSpillFileStream", + "facebook::velox::exec::SpillMerger::readFromSpillFileStream", std::function([&](void* /*unused*/) { if (cnt++ == 2) { blocked = true; @@ -915,3 +955,275 @@ TEST_F(MergeTest, localMergeOutputSizeWithoutSpill) { testData.numExpectedOutputBatches); } } + +/// Tests that MultiThreadedTaskCursor correctly preserves FlatMapVector +/// encoding when reading data directly without a merge step. +TEST_F(MergeTest, preserveVectorEncoding) { + auto data = makeRowVector({ + makeFlatVector({1, 2, 3}), + vectorMaker_.flatMapVector({ + {{1, 10}, {2, 20}}, + {{1, 30}}, + {{2, 40}, {3, 50}}, + }), + }); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator).values({data}).planNode(); + + CursorParameters params; + params.planNode = plan; + params.queryCtx = core::QueryCtx::create(executor_.get()); + params.maxDrivers = 2; + + auto result = readCursor(params); + ASSERT_EQ(result.second.size(), 1); + + auto output = result.second[0]; + ASSERT_EQ(output->size(), 3); + + auto mapColumn = output->childAt(1); + EXPECT_EQ(mapColumn->encoding(), VectorEncoding::Simple::FLAT_MAP); + auto flatMapVector = mapColumn->as(); + EXPECT_EQ(flatMapVector->distinctKeys()->size(), 3); + + auto verifier = vectorMaker_.flatMapVector({ + {{1, 10}, {2, 20}}, + {{1, 30}}, + {{2, 40}, {3, 50}}, + }); + facebook::velox::test::assertEqualVectors(mapColumn, verifier); +} + +/// Tests that LocalMerge correctly preserves FlatMapVector encoding +/// when merging data with FlatMapVector columns. +TEST_F(MergeTest, flatMapVectorEncoding) { + // Create input data with FlatMapVector columns. + // Data is already sorted by c0. + auto data1 = makeRowVector({ + makeFlatVector({0, 2, 4}), + vectorMaker_.flatMapVector({ + {{1, 10}, {2, 20}}, + {{1, 30}, {3, 40}}, + {{2, 50}}, + }), + }); + + auto data2 = makeRowVector({ + makeFlatVector({1, 3, 5}), + vectorMaker_.flatMapVector({ + {{1, 100}, {2, 200}}, + {{3, 300}}, + {{1, 400}, {2, 500}, {3, 600}}, + }), + }); + + auto verifyResult = [this](const std::vector& results) { + ASSERT_EQ(results.size(), 1); + + auto output = results[0]; + ASSERT_EQ(output->size(), 6); + + auto sortKey = output->childAt(0)->asFlatVector(); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(sortKey->valueAt(i), i); + } + + auto mapColumn = output->childAt(1); + EXPECT_EQ(mapColumn->encoding(), VectorEncoding::Simple::FLAT_MAP); + auto flatMapVector = mapColumn->as(); + EXPECT_EQ(flatMapVector->distinctKeys()->size(), 3); + auto verifier = vectorMaker_.flatMapVector({ + {{1, 10}, {2, 20}}, + {{1, 100}, {2, 200}}, + {{1, 30}, {3, 40}}, + {{3, 300}}, + {{2, 50}}, + {{1, 400}, {2, 500}, {3, 600}}, + }); + // Merge does not resize child vectors. + verifier->resize(mapColumn->size()); + facebook::velox::test::assertEqualVectors(mapColumn, verifier); + }; + + // Test multi-threaded execution. + { + SCOPED_TRACE("Multi-threaded execution"); + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .localMerge( + {"c0"}, + { + PlanBuilder(planNodeIdGenerator).values({data1}).planNode(), + PlanBuilder(planNodeIdGenerator).values({data2}).planNode(), + }) + .planNode(); + + CursorParameters params; + params.planNode = plan; + params.queryCtx = core::QueryCtx::create(executor_.get()); + params.maxDrivers = 2; + + auto result = readCursor(params); + verifyResult(result.second); + } + + // Test serial execution. + { + SCOPED_TRACE("Serial execution"); + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .localMerge( + {"c0"}, + { + PlanBuilder(planNodeIdGenerator).values({data1}).planNode(), + PlanBuilder(planNodeIdGenerator).values({data2}).planNode(), + }) + .planNode(); + + CursorParameters params; + params.planNode = plan; + params.queryCtx = core::QueryCtx::create(); + params.serialExecution = true; + + auto result = readCursor(params); + verifyResult(result.second); + } +} + +/// Tests that LocalMerge correctly handles FlatMapVector encoding +/// when the first source is empty but subsequent sources have data. +TEST_F(MergeTest, flatMapVectorEncodingWithEmptyFirstSource) { + // Create an empty first source. + auto emptyData = makeRowVector( + {"c0", "c1"}, + { + makeFlatVector({}), + vectorMaker_.flatMapVector({}), + }); + + // Create second source with FlatMapVector data. + auto data = makeRowVector({ + makeFlatVector({1, 2, 3}), + vectorMaker_.flatMapVector({ + {{1, 10}, {2, 20}}, + {{1, 30}}, + {{2, 40}, {3, 50}}, + }), + }); + + auto planNodeIdGenerator = std::make_shared(); + + auto plan = + PlanBuilder(planNodeIdGenerator) + .localMerge( + {"c0"}, + { + PlanBuilder(planNodeIdGenerator) + .values({emptyData}) + .planNode(), + PlanBuilder(planNodeIdGenerator).values({data}).planNode(), + }) + .planNode(); + + CursorParameters params; + params.planNode = plan; + params.queryCtx = core::QueryCtx::create(executor_.get()); + params.maxDrivers = 2; + + auto result = readCursor(params); + ASSERT_EQ(result.second.size(), 1); + + auto output = result.second[0]; + ASSERT_EQ(output->size(), 3); + + // Verify the FlatMapVector encoding is preserved even when first source is + // empty. + auto mapColumn = output->childAt(1); + EXPECT_EQ(mapColumn->encoding(), VectorEncoding::Simple::FLAT_MAP); + auto flatMapVector = mapColumn->as(); + EXPECT_EQ(flatMapVector->distinctKeys()->size(), 3); + auto verifier = vectorMaker_.flatMapVector({ + {{1, 10}, {2, 20}}, + {{1, 30}}, + {{2, 40}, {3, 50}}, + }); + verifier->resize(mapColumn->size()); + facebook::velox::test::assertEqualVectors(mapColumn, verifier); +} + +/// Tests that LocalMerge correctly merges multiple sources with FlatMapVector +/// columns and different key sets. +TEST_F(MergeTest, flatMapVectorEncodingMultipleSources) { + auto data1 = makeRowVector({ + makeFlatVector({0, 3}), + vectorMaker_.flatMapVector({ + {{1, 10}}, + {{1, 40}, {2, 50}}, + }), + }); + + auto data2 = makeRowVector({ + makeFlatVector({1, 4}), + vectorMaker_.flatMapVector({ + {{2, 20}}, + {{3, 60}}, + }), + }); + + auto data3 = makeRowVector({ + makeFlatVector({2, 5}), + vectorMaker_.flatMapVector({ + {{1, 30}, {3, 35}}, + {{1, 70}, {2, 80}, {3, 90}}, + }), + }); + + auto planNodeIdGenerator = std::make_shared(); + + auto plan = + PlanBuilder(planNodeIdGenerator) + .localMerge( + {"c0"}, + { + PlanBuilder(planNodeIdGenerator).values({data1}).planNode(), + PlanBuilder(planNodeIdGenerator).values({data2}).planNode(), + PlanBuilder(planNodeIdGenerator).values({data3}).planNode(), + }) + .planNode(); + + CursorParameters params; + params.planNode = plan; + params.queryCtx = core::QueryCtx::create(executor_.get()); + params.maxDrivers = 3; + + auto result = readCursor(params); + ASSERT_EQ(result.second.size(), 1); + + auto output = result.second[0]; + ASSERT_EQ(output->size(), 6); + + // Verify the output is sorted correctly. + auto sortKey = output->childAt(0)->asFlatVector(); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(sortKey->valueAt(i), i); + } + + // Verify the FlatMapVector encoding is preserved. + auto mapColumn = output->childAt(1); + EXPECT_EQ(mapColumn->encoding(), VectorEncoding::Simple::FLAT_MAP); + auto flatMapVector = mapColumn->as(); + EXPECT_EQ(flatMapVector->distinctKeys()->size(), 3); + auto verifier = vectorMaker_.flatMapVector({ + {{1, 10}}, + {{2, 20}}, + {{1, 30}, {3, 35}}, + {{1, 40}, {2, 50}}, + {{3, 60}}, + {{1, 70}, {2, 80}, {3, 90}}, + }); + verifier->resize(mapColumn->size()); + facebook::velox::test::assertEqualVectors(mapColumn, verifier); +} diff --git a/velox/exec/tests/MergerTest.cpp b/velox/exec/tests/MergerTest.cpp index e1845fc552e..2a7a21f4e15 100644 --- a/velox/exec/tests/MergerTest.cpp +++ b/velox/exec/tests/MergerTest.cpp @@ -14,13 +14,15 @@ * limitations under the License. */ +#include +#include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/FileSystems.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/exec/Merge.h" #include "velox/exec/MergeSource.h" #include "velox/exec/SortBuffer.h" #include "velox/exec/Spill.h" #include "velox/exec/tests/utils/OperatorTestBase.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/type/Type.h" #include "velox/vector/fuzzer/VectorFuzzer.h" #include "velox/vector/tests/utils/VectorTestBase.h" @@ -33,6 +35,7 @@ using namespace facebook::velox; using namespace facebook::velox::memory; namespace facebook::velox::exec::test { +using namespace facebook::velox::common::testutil; class MergerTest : public OperatorTestBase { protected: @@ -100,11 +103,12 @@ class MergerTest : public OperatorTestBase { std::vector> spillReadFiles; spillReadFiles.reserve(spillFiles.size()); for (const auto& spillFile : spillFiles) { - spillReadFiles.emplace_back(SpillReadFile::create( - spillFile, - spillConfig_.readBufferSize, - pool_.get(), - spillStats_.get())); + spillReadFiles.emplace_back( + SpillReadFile::create( + spillFile, + spillConfig_.readBufferSize, + pool_.get(), + spillStats_.get())); } spillReadFilesGroups.emplace_back(std::move(spillReadFiles)); } @@ -148,8 +152,9 @@ class MergerTest : public OperatorTestBase { uint64_t outputBatchBytes) const { std::vector> sourceStreams; for (const auto& source : sources) { - sourceStreams.push_back(std::make_unique( - source.get(), sortingKeys_, outputBatchRows)); + sourceStreams.push_back( + std::make_unique( + source.get(), sortingKeys_, outputBatchRows)); } return std::make_unique( inputType_, @@ -291,7 +296,7 @@ class MergerTest : public OperatorTestBase { const RowTypePtr inputType_ = ROW({{"c0", BIGINT()}, {"c1", SMALLINT()}}); const std::shared_ptr executor_{ std::make_shared( - std::thread::hardware_concurrency())}; + folly::available_concurrency())}; const std::vector sortColumnIndices_{0, 1}; const std::vector sortCompareFlags_{ CompareFlags{.ascending = true}, @@ -299,7 +304,7 @@ class MergerTest : public OperatorTestBase { const std::vector sortingKeys_ = SpillState::makeSortingKeys(sortColumnIndices_, sortCompareFlags_); const std::shared_ptr spillDirectory_ = - exec::test::TempDirectoryPath::create(); + TempDirectoryPath::create(); const common::SpillConfig spillConfig_{ [&]() -> const std::string& { return spillDirectory_->getPath(); }, [&](uint64_t) {}, @@ -316,10 +321,11 @@ class MergerTest : public OperatorTestBase { 0, 0, "none", + 0, std::nullopt}; - std::shared_ptr> spillStats_ = - std::make_shared>(); + std::shared_ptr spillStats_ = + std::make_shared(); tsan_atomic nonReclaimableSection_{false}; }; } // namespace facebook::velox::exec::test @@ -420,3 +426,37 @@ TEST_F(MergerTest, spillMerger) { checkResults(expectedResults, results); } } + +DEBUG_ONLY_TEST_F(MergerTest, spillMergerException) { + struct TestSetting { + size_t maxOutputRows; + size_t numSources; + size_t queueSize; + + std::string debugString() const { + return fmt::format( + "maxOutputRows:{}, numStreams:{}, queueSize:{}", + maxOutputRows, + numSources, + queueSize); + } + }; + + std::atomic_int cnt{0}; + const auto errorMessage = "ConcatFilesSpillBatchStream::nextBatch fail"; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::ConcatFilesSpillBatchStream::nextBatch", + std::function([&](void* /*unused*/) { + if (cnt++ == 11) { + VELOX_FAIL("ConcatFilesSpillBatchStream::nextBatch fail"); + } + })); + const auto numSources = 5; + const auto queueSize = 2; + const auto sources = createMergeSources(numSources, queueSize); + auto [inputs, filesGroup] = generateInputs(numSources, 16); + const auto spillMerger = + createSpillMerger(std::move(filesGroup), 100, queueSize); + spillMerger->start(); + VELOX_ASSERT_THROW(getOutputFromSpillMerger(spillMerger.get()), errorMessage); +} diff --git a/velox/exec/tests/MixedUnionTest.cpp b/velox/exec/tests/MixedUnionTest.cpp new file mode 100644 index 00000000000..0940515666a --- /dev/null +++ b/velox/exec/tests/MixedUnionTest.cpp @@ -0,0 +1,1521 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/tests/utils/TempFilePath.h" + +using namespace facebook::velox; +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; + +class MixedUnionTest : public OperatorTestBase { + protected: + std::shared_ptr makeUnionNode( + std::vector sources, + std::shared_ptr planNodeIdGenerator) { + return std::make_shared( + planNodeIdGenerator->next(), std::move(sources)); + } +}; + +TEST_F(MixedUnionTest, basicUnion) { + auto data1 = makeRowVector({ + makeFlatVector({1, 2, 3}), + makeFlatVector({"a", "b", "c"}), + }); + auto data2 = makeRowVector({ + makeFlatVector({4, 5, 6}), + makeFlatVector({"d", "e", "f"}), + }); + + auto planNodeIdGenerator = std::make_shared(); + auto unionNode = makeUnionNode( + {PlanBuilder(planNodeIdGenerator).values({data1}).planNode(), + PlanBuilder(planNodeIdGenerator).values({data2}).planNode()}, + planNodeIdGenerator); + + assertQuery( + CursorParameters{ + .planNode = unionNode, + .maxDrivers = 1, + }, + "VALUES (1, 'a'), (2, 'b'), (3, 'c'), (4, 'd'), (5, 'e'), (6, 'f')"); +} + +TEST_F(MixedUnionTest, singleSource) { + auto data = makeRowVector({ + makeFlatVector({1, 2, 3}), + makeFlatVector({"a", "b", "c"}), + }); + + auto planNodeIdGenerator = std::make_shared(); + auto unionNode = makeUnionNode( + {PlanBuilder(planNodeIdGenerator).values({data}).planNode()}, + planNodeIdGenerator); + + assertQuery( + CursorParameters{ + .planNode = unionNode, + .maxDrivers = 1, + }, + "VALUES (1, 'a'), (2, 'b'), (3, 'c')"); +} + +TEST_F(MixedUnionTest, emptyFirstSource) { + auto emptyData = makeRowVector( + {makeFlatVector({}), makeFlatVector({})}); + auto data = makeRowVector({ + makeFlatVector({1, 2}), + makeFlatVector({"a", "b"}), + }); + + auto planNodeIdGenerator = std::make_shared(); + auto unionNode = makeUnionNode( + {PlanBuilder(planNodeIdGenerator).values({emptyData}).planNode(), + PlanBuilder(planNodeIdGenerator).values({data}).planNode()}, + planNodeIdGenerator); + + assertQuery( + CursorParameters{ + .planNode = unionNode, + .maxDrivers = 1, + }, + "VALUES (1, 'a'), (2, 'b')"); +} + +TEST_F(MixedUnionTest, emptyLastSource) { + auto data = makeRowVector({ + makeFlatVector({1, 2}), + makeFlatVector({"a", "b"}), + }); + auto emptyData = makeRowVector( + {makeFlatVector({}), makeFlatVector({})}); + + auto planNodeIdGenerator = std::make_shared(); + auto unionNode = makeUnionNode( + {PlanBuilder(planNodeIdGenerator).values({data}).planNode(), + PlanBuilder(planNodeIdGenerator).values({emptyData}).planNode()}, + planNodeIdGenerator); + + assertQuery( + CursorParameters{ + .planNode = unionNode, + .maxDrivers = 1, + }, + "VALUES (1, 'a'), (2, 'b')"); +} + +TEST_F(MixedUnionTest, allEmptySources) { + auto emptyData = makeRowVector( + {makeFlatVector({}), makeFlatVector({})}); + + auto planNodeIdGenerator = std::make_shared(); + auto unionNode = makeUnionNode( + {PlanBuilder(planNodeIdGenerator).values({emptyData}).planNode(), + PlanBuilder(planNodeIdGenerator).values({emptyData}).planNode()}, + planNodeIdGenerator); + + auto result = AssertQueryBuilder(unionNode).copyResults(pool()); + ASSERT_EQ(result->size(), 0); +} + +TEST_F(MixedUnionTest, multipleInputs) { + auto data1 = makeRowVector({ + makeFlatVector({1}), + makeFlatVector({"a"}), + }); + auto data2 = makeRowVector({ + makeFlatVector({2}), + makeFlatVector({"b"}), + }); + auto data3 = makeRowVector({ + makeFlatVector({3}), + makeFlatVector({"c"}), + }); + + auto planNodeIdGenerator = std::make_shared(); + auto unionNode = makeUnionNode( + {PlanBuilder(planNodeIdGenerator).values({data1}).planNode(), + PlanBuilder(planNodeIdGenerator).values({data2}).planNode(), + PlanBuilder(planNodeIdGenerator).values({data3}).planNode()}, + planNodeIdGenerator); + + assertQuery( + CursorParameters{ + .planNode = unionNode, + .maxDrivers = 1, + }, + "VALUES (1, 'a'), (2, 'b'), (3, 'c')"); +} + +TEST_F(MixedUnionTest, manySources) { + constexpr int kNumSources = 10; + auto planNodeIdGenerator = std::make_shared(); + + std::vector sources; + sources.reserve(kNumSources); + for (int i = 0; i < kNumSources; ++i) { + auto data = makeRowVector({ + makeFlatVector({i}), + }); + sources.push_back( + PlanBuilder(planNodeIdGenerator).values({data}).planNode()); + } + + auto unionNode = makeUnionNode(std::move(sources), planNodeIdGenerator); + auto result = AssertQueryBuilder(unionNode).copyResults(pool()); + ASSERT_EQ(result->size(), kNumSources); +} + +TEST_F(MixedUnionTest, unequalSourceSizes) { + auto data1 = makeRowVector({ + makeFlatVector({1, 2, 3, 4, 5}), + }); + auto data2 = makeRowVector({ + makeFlatVector({6, 7}), + }); + auto data3 = makeRowVector({ + makeFlatVector({8, 9, 10, 11, 12, 13, 14}), + }); + + auto planNodeIdGenerator = std::make_shared(); + auto unionNode = makeUnionNode( + {PlanBuilder(planNodeIdGenerator).values({data1}).planNode(), + PlanBuilder(planNodeIdGenerator).values({data2}).planNode(), + PlanBuilder(planNodeIdGenerator).values({data3}).planNode()}, + planNodeIdGenerator); + + auto result = AssertQueryBuilder(unionNode).copyResults(pool()); + ASSERT_EQ(result->size(), 14); +} + +TEST_F(MixedUnionTest, multipleBatchesPerSource) { + // Test that union correctly handles the case where sources provide data + // in a single batch each + auto data1 = makeRowVector({ + makeFlatVector({1, 2, 3, 4, 5, 6}), + }); + auto data2 = makeRowVector({ + makeFlatVector({7, 8, 9, 10, 11, 12}), + }); + + auto planNodeIdGenerator = std::make_shared(); + auto unionNode = makeUnionNode( + {PlanBuilder(planNodeIdGenerator).values({data1}).planNode(), + PlanBuilder(planNodeIdGenerator).values({data2}).planNode()}, + planNodeIdGenerator); + + auto result = AssertQueryBuilder(unionNode).copyResults(pool()); + ASSERT_EQ(result->size(), 12); +} + +TEST_F(MixedUnionTest, withNulls) { + auto data1 = makeRowVector({ + makeNullableFlatVector({1, std::nullopt, 3}), + makeNullableFlatVector({"a", "b", std::nullopt}), + }); + auto data2 = makeRowVector({ + makeNullableFlatVector({std::nullopt, 5, 6}), + makeNullableFlatVector({std::nullopt, "e", "f"}), + }); + + auto planNodeIdGenerator = std::make_shared(); + auto unionNode = makeUnionNode( + {PlanBuilder(planNodeIdGenerator).values({data1}).planNode(), + PlanBuilder(planNodeIdGenerator).values({data2}).planNode()}, + planNodeIdGenerator); + + auto result = AssertQueryBuilder(unionNode).copyResults(pool()); + ASSERT_EQ(result->size(), 6); + + auto* col0 = result->childAt(0)->asFlatVector(); + EXPECT_TRUE(col0->isNullAt(1)); + EXPECT_TRUE(col0->isNullAt(3)); + + auto* col1 = result->childAt(1)->asFlatVector(); + EXPECT_TRUE(col1->isNullAt(2)); + EXPECT_TRUE(col1->isNullAt(3)); +} + +TEST_F(MixedUnionTest, variousTypes) { + auto data1 = makeRowVector({ + makeFlatVector({1, 2}), + makeFlatVector({1.1, 2.2}), + makeFlatVector({true, false}), + makeFlatVector({"hello", "world"}), + }); + auto data2 = makeRowVector({ + makeFlatVector({3, 4}), + makeFlatVector({3.3, 4.4}), + makeFlatVector({false, true}), + makeFlatVector({"foo", "bar"}), + }); + + auto planNodeIdGenerator = std::make_shared(); + auto unionNode = makeUnionNode( + {PlanBuilder(planNodeIdGenerator).values({data1}).planNode(), + PlanBuilder(planNodeIdGenerator).values({data2}).planNode()}, + planNodeIdGenerator); + + auto result = AssertQueryBuilder(unionNode).copyResults(pool()); + ASSERT_EQ(result->size(), 4); + + auto* int64Col = result->childAt(0)->asFlatVector(); + EXPECT_EQ(int64Col->valueAt(0), 1); + EXPECT_EQ(int64Col->valueAt(3), 4); + + auto* doubleCol = result->childAt(1)->asFlatVector(); + EXPECT_DOUBLE_EQ(doubleCol->valueAt(0), 1.1); + EXPECT_DOUBLE_EQ(doubleCol->valueAt(3), 4.4); +} + +TEST_F(MixedUnionTest, largeDataVolume) { + constexpr int kRowsPerSource = 1000; + constexpr int kNumSources = 3; + + auto planNodeIdGenerator = std::make_shared(); + std::vector sources; + sources.reserve(kNumSources); + + for (int s = 0; s < kNumSources; ++s) { + auto data = makeRowVector({ + makeFlatVector( + kRowsPerSource, [s](auto row) { return s * 10000 + row; }), + makeFlatVector( + kRowsPerSource, [](auto row) { return row * 0.1; }), + }); + sources.push_back( + PlanBuilder(planNodeIdGenerator).values({data}).planNode()); + } + + auto unionNode = makeUnionNode(std::move(sources), planNodeIdGenerator); + auto result = AssertQueryBuilder(unionNode).copyResults(pool()); + ASSERT_EQ(result->size(), kRowsPerSource * kNumSources); +} + +TEST_F(MixedUnionTest, stats) { + constexpr int kRowsPerSource = 100; + auto data1 = makeRowVector({ + makeFlatVector(kRowsPerSource, [](auto row) { return row; }), + }); + auto data2 = makeRowVector({ + makeFlatVector( + kRowsPerSource, [](auto row) { return row + 100; }), + }); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId unionNodeId; + auto unionNode = std::make_shared( + planNodeIdGenerator->next(), + std::vector{ + PlanBuilder(planNodeIdGenerator).values({data1}).planNode(), + PlanBuilder(planNodeIdGenerator).values({data2}).planNode()}); + unionNodeId = unionNode->id(); + + auto task = AssertQueryBuilder(unionNode).copyResults(pool()); + ASSERT_EQ(task->size(), kRowsPerSource * 2); +} + +TEST_F(MixedUnionTest, withFilter) { + auto data1 = makeRowVector({ + makeFlatVector({1, 2, 3, 4, 5}), + }); + auto data2 = makeRowVector({ + makeFlatVector({6, 7, 8, 9, 10}), + }); + + auto planNodeIdGenerator = std::make_shared(); + auto unionNode = makeUnionNode( + {PlanBuilder(planNodeIdGenerator).values({data1}).planNode(), + PlanBuilder(planNodeIdGenerator).values({data2}).planNode()}, + planNodeIdGenerator); + + auto plan = PlanBuilder(planNodeIdGenerator) + .addNode([&](auto, auto) { return unionNode; }) + .filter("c0 > 5") + .planNode(); + + assertQuery(plan, "VALUES (6), (7), (8), (9), (10)"); +} + +TEST_F(MixedUnionTest, withProject) { + auto data1 = makeRowVector({ + makeFlatVector({1, 2}), + makeFlatVector({10, 20}), + }); + auto data2 = makeRowVector({ + makeFlatVector({3, 4}), + makeFlatVector({30, 40}), + }); + + auto planNodeIdGenerator = std::make_shared(); + auto unionNode = makeUnionNode( + {PlanBuilder(planNodeIdGenerator).values({data1}).planNode(), + PlanBuilder(planNodeIdGenerator).values({data2}).planNode()}, + planNodeIdGenerator); + + auto plan = PlanBuilder(planNodeIdGenerator) + .addNode([&](auto, auto) { return unionNode; }) + .project({"c0 + c1 as sum"}) + .planNode(); + + assertQuery(plan, "VALUES (11), (22), (33), (44)"); +} + +TEST_F(MixedUnionTest, withAggregation) { + auto data1 = makeRowVector({ + makeFlatVector({1, 1, 2}), + makeFlatVector({10, 20, 30}), + }); + auto data2 = makeRowVector({ + makeFlatVector({1, 2, 2}), + makeFlatVector({40, 50, 60}), + }); + + auto planNodeIdGenerator = std::make_shared(); + auto unionNode = makeUnionNode( + {PlanBuilder(planNodeIdGenerator).values({data1}).planNode(), + PlanBuilder(planNodeIdGenerator).values({data2}).planNode()}, + planNodeIdGenerator); + + auto plan = PlanBuilder(planNodeIdGenerator) + .addNode([&](auto, auto) { return unionNode; }) + .singleAggregation({"c0"}, {"sum(c1)"}) + .planNode(); + + assertQuery(plan, "VALUES (1, 70), (2, 140)"); +} + +/// Test fixture for MixedUnion barrier execution tests using HiveConnector. +class MixedUnionBarrierTest : public HiveConnectorTestBase { + protected: + void SetUp() override { + HiveConnectorTestBase::SetUp(); + } + + std::shared_ptr makeUnionNode( + std::vector sources, + std::shared_ptr planNodeIdGenerator) { + return std::make_shared( + planNodeIdGenerator->next(), std::move(sources)); + } + + std::vector> writeDataToFiles( + const std::vector& data) { + std::vector> files; + files.reserve(data.size()); + for (const auto& vector : data) { + auto file = TempFilePath::create(); + writeToFile(file->getPath(), vector); + files.push_back(file); + } + return files; + } +}; + +TEST_F(MixedUnionBarrierTest, supportsBarrier) { + auto planNodeIdGenerator = std::make_shared(); + const auto rowType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + + core::PlanNodeId scanNodeId; + core::PlanNodePtr unionNode; + const auto plan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(scanNodeId) + .addNode([&](const auto& id, auto source) { + return std::make_shared( + id, + std::vector{ + source, + PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .planNode()}); + }) + .capturePlanNode(unionNode) + .planNode(); + + auto mixedUnionNode = + std::dynamic_pointer_cast(unionNode); + ASSERT_TRUE(mixedUnionNode != nullptr); + ASSERT_TRUE(mixedUnionNode->supportsBarrier()); +} + +TEST_F(MixedUnionBarrierTest, barrierExecution) { + const auto rowType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + + // Create test data for two sources + auto leftData = makeRowVector({ + makeFlatVector(100, [](auto row) { return row; }), + makeFlatVector(100, [](auto row) { return row * 10; }), + }); + auto rightData = makeRowVector({ + makeFlatVector(100, [](auto row) { return row + 100; }), + makeFlatVector(100, [](auto row) { return row * 10 + 1000; }), + }); + + // Write data to files + auto leftFile = TempFilePath::create(); + writeToFile(leftFile->getPath(), leftData); + auto rightFile = TempFilePath::create(); + writeToFile(rightFile->getPath(), rightData); + + // Build plan with MixedUnion + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId leftScanNodeId; + core::PlanNodeId rightScanNodeId; + + auto leftScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(leftScanNodeId) + .planNode(); + auto rightScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(rightScanNodeId) + .planNode(); + + auto plan = std::make_shared( + planNodeIdGenerator->next(), + std::vector{leftScan, rightScan}); + + for (const auto hasBarrier : {false, true}) { + SCOPED_TRACE(fmt::format("hasBarrier {}", hasBarrier)); + AssertQueryBuilder queryBuilder(plan); + queryBuilder.barrierExecution(hasBarrier).serialExecution(true); + queryBuilder.split( + leftScanNodeId, makeHiveConnectorSplit(leftFile->getPath())); + queryBuilder.split( + rightScanNodeId, makeHiveConnectorSplit(rightFile->getPath())); + + std::shared_ptr task; + auto result = queryBuilder.copyResults(pool(), task); + ASSERT_EQ(result->size(), 200); + + const auto taskStats = task->taskStats(); + ASSERT_EQ(taskStats.numBarriers, hasBarrier ? 1 : 0); + } +} + +TEST_F(MixedUnionBarrierTest, barrierWithMultipleSplits) { + const auto rowType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + + // Create test data for multiple splits on each side + std::vector> leftFiles; + std::vector> rightFiles; + constexpr int kSplitsPerSide = 3; + constexpr int kRowsPerSplit = 50; + + for (int i = 0; i < kSplitsPerSide; ++i) { + auto leftData = makeRowVector({ + makeFlatVector( + kRowsPerSplit, [i](auto row) { return i * 1000 + row; }), + makeFlatVector( + kRowsPerSplit, [i](auto row) { return i * 10000 + row * 10; }), + }); + auto rightData = makeRowVector({ + makeFlatVector( + kRowsPerSplit, [i](auto row) { return i * 1000 + row + 500; }), + makeFlatVector( + kRowsPerSplit, + [i](auto row) { return i * 10000 + row * 10 + 5000; }), + }); + + auto leftFile = TempFilePath::create(); + writeToFile(leftFile->getPath(), leftData); + leftFiles.push_back(leftFile); + + auto rightFile = TempFilePath::create(); + writeToFile(rightFile->getPath(), rightData); + rightFiles.push_back(rightFile); + } + + // Build plan + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId leftScanNodeId; + core::PlanNodeId rightScanNodeId; + + auto leftScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(leftScanNodeId) + .planNode(); + auto rightScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(rightScanNodeId) + .planNode(); + + auto plan = std::make_shared( + planNodeIdGenerator->next(), + std::vector{leftScan, rightScan}); + + for (const auto hasBarrier : {false, true}) { + SCOPED_TRACE(fmt::format("hasBarrier {}", hasBarrier)); + AssertQueryBuilder queryBuilder(plan); + queryBuilder.barrierExecution(hasBarrier).serialExecution(true); + + for (const auto& file : leftFiles) { + queryBuilder.split( + leftScanNodeId, makeHiveConnectorSplit(file->getPath())); + } + for (const auto& file : rightFiles) { + queryBuilder.split( + rightScanNodeId, makeHiveConnectorSplit(file->getPath())); + } + + auto result = queryBuilder.copyResults(pool()); + ASSERT_EQ(result->size(), kSplitsPerSide * kRowsPerSplit * 2); + } +} + +TEST_F(MixedUnionBarrierTest, barrierWithProject) { + const auto rowType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + + auto leftData = makeRowVector({ + makeFlatVector(50, [](auto row) { return row; }), + makeFlatVector(50, [](auto row) { return row * 2; }), + }); + auto rightData = makeRowVector({ + makeFlatVector(50, [](auto row) { return row + 50; }), + makeFlatVector(50, [](auto row) { return (row + 50) * 2; }), + }); + + auto leftFile = TempFilePath::create(); + writeToFile(leftFile->getPath(), leftData); + auto rightFile = TempFilePath::create(); + writeToFile(rightFile->getPath(), rightData); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId leftScanNodeId; + core::PlanNodeId rightScanNodeId; + + auto leftScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(leftScanNodeId) + .planNode(); + auto rightScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(rightScanNodeId) + .planNode(); + + auto unionNode = std::make_shared( + planNodeIdGenerator->next(), + std::vector{leftScan, rightScan}); + + // Add a project on top of the union + auto plan = PlanBuilder(planNodeIdGenerator) + .addNode([&](const auto&, auto) { return unionNode; }) + .project({"c0", "c1", "c0 + c1 as sum"}) + .planNode(); + + for (const auto hasBarrier : {false, true}) { + SCOPED_TRACE(fmt::format("hasBarrier {}", hasBarrier)); + AssertQueryBuilder queryBuilder(plan); + queryBuilder.barrierExecution(hasBarrier).serialExecution(true); + queryBuilder.split( + leftScanNodeId, makeHiveConnectorSplit(leftFile->getPath())); + queryBuilder.split( + rightScanNodeId, makeHiveConnectorSplit(rightFile->getPath())); + + auto result = queryBuilder.copyResults(pool()); + ASSERT_EQ(result->size(), 100); + ASSERT_EQ(result->type()->size(), 3); + } +} + +TEST_F(MixedUnionBarrierTest, barrierWithFilter) { + const auto rowType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + + auto leftData = makeRowVector({ + makeFlatVector(100, [](auto row) { return row; }), + makeFlatVector(100, [](auto row) { return row * 10; }), + }); + auto rightData = makeRowVector({ + makeFlatVector(100, [](auto row) { return row + 100; }), + makeFlatVector(100, [](auto row) { return row * 10 + 1000; }), + }); + + auto leftFile = TempFilePath::create(); + writeToFile(leftFile->getPath(), leftData); + auto rightFile = TempFilePath::create(); + writeToFile(rightFile->getPath(), rightData); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId leftScanNodeId; + core::PlanNodeId rightScanNodeId; + + auto leftScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(leftScanNodeId) + .planNode(); + auto rightScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(rightScanNodeId) + .planNode(); + + auto unionNode = std::make_shared( + planNodeIdGenerator->next(), + std::vector{leftScan, rightScan}); + + // Add a filter on top of the union + auto plan = PlanBuilder(planNodeIdGenerator) + .addNode([&](const auto&, auto) { return unionNode; }) + .filter("c0 >= 50 AND c0 < 150") + .planNode(); + + for (const auto hasBarrier : {false, true}) { + SCOPED_TRACE(fmt::format("hasBarrier {}", hasBarrier)); + AssertQueryBuilder queryBuilder(plan); + queryBuilder.barrierExecution(hasBarrier).serialExecution(true); + queryBuilder.split( + leftScanNodeId, makeHiveConnectorSplit(leftFile->getPath())); + queryBuilder.split( + rightScanNodeId, makeHiveConnectorSplit(rightFile->getPath())); + + auto result = queryBuilder.copyResults(pool()); + // Left source: rows 50-99 (50 rows) + // Right source: rows 100-149 (50 rows from 100-199 range, but capped at + // 149) + ASSERT_EQ(result->size(), 100); + } +} + +TEST_F(MixedUnionBarrierTest, barrierWithThreeSources) { + const auto rowType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + + auto data1 = makeRowVector({ + makeFlatVector(100, [](auto row) { return row; }), + makeFlatVector(100, [](auto row) { return row * 10; }), + }); + auto data2 = makeRowVector({ + makeFlatVector(100, [](auto row) { return row + 100; }), + makeFlatVector(100, [](auto row) { return row * 10 + 1000; }), + }); + auto data3 = makeRowVector({ + makeFlatVector(100, [](auto row) { return row + 200; }), + makeFlatVector(100, [](auto row) { return row * 10 + 2000; }), + }); + + auto file1 = TempFilePath::create(); + writeToFile(file1->getPath(), data1); + auto file2 = TempFilePath::create(); + writeToFile(file2->getPath(), data2); + auto file3 = TempFilePath::create(); + writeToFile(file3->getPath(), data3); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId scanNodeId1; + core::PlanNodeId scanNodeId2; + core::PlanNodeId scanNodeId3; + + auto scan1 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(scanNodeId1) + .planNode(); + auto scan2 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(scanNodeId2) + .planNode(); + auto scan3 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(scanNodeId3) + .planNode(); + + auto plan = std::make_shared( + planNodeIdGenerator->next(), + std::vector{scan1, scan2, scan3}); + + for (const auto hasBarrier : {false, true}) { + SCOPED_TRACE(fmt::format("hasBarrier {}", hasBarrier)); + AssertQueryBuilder queryBuilder(plan); + queryBuilder.barrierExecution(hasBarrier).serialExecution(true); + queryBuilder.split(scanNodeId1, makeHiveConnectorSplit(file1->getPath())); + queryBuilder.split(scanNodeId2, makeHiveConnectorSplit(file2->getPath())); + queryBuilder.split(scanNodeId3, makeHiveConnectorSplit(file3->getPath())); + + auto result = queryBuilder.copyResults(pool()); + ASSERT_EQ(result->size(), 300); + } +} + +TEST_F(MixedUnionBarrierTest, barrierWithEmptySource) { + const auto rowType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + + auto leftData = makeRowVector({ + makeFlatVector(100, [](auto row) { return row; }), + makeFlatVector(100, [](auto row) { return row * 10; }), + }); + // Empty right source + auto rightData = makeRowVector({ + makeFlatVector({}), + makeFlatVector({}), + }); + + auto leftFile = TempFilePath::create(); + writeToFile(leftFile->getPath(), leftData); + auto rightFile = TempFilePath::create(); + writeToFile(rightFile->getPath(), rightData); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId leftScanNodeId; + core::PlanNodeId rightScanNodeId; + + auto leftScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(leftScanNodeId) + .planNode(); + auto rightScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(rightScanNodeId) + .planNode(); + + auto plan = std::make_shared( + planNodeIdGenerator->next(), + std::vector{leftScan, rightScan}); + + for (const auto hasBarrier : {false, true}) { + SCOPED_TRACE(fmt::format("hasBarrier {}", hasBarrier)); + AssertQueryBuilder queryBuilder(plan); + queryBuilder.barrierExecution(hasBarrier).serialExecution(true); + queryBuilder.split( + leftScanNodeId, makeHiveConnectorSplit(leftFile->getPath())); + queryBuilder.split( + rightScanNodeId, makeHiveConnectorSplit(rightFile->getPath())); + + auto result = queryBuilder.copyResults(pool()); + ASSERT_EQ(result->size(), 100); + } +} + +TEST_F(MixedUnionBarrierTest, barrierWithUnequalSourceSizes) { + const auto rowType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + + // Left source has significantly more data + auto leftData = makeRowVector({ + makeFlatVector(1000, [](auto row) { return row; }), + makeFlatVector(1000, [](auto row) { return row * 10; }), + }); + // Right source has much less data + auto rightData = makeRowVector({ + makeFlatVector(10, [](auto row) { return row + 1000; }), + makeFlatVector(10, [](auto row) { return row * 10 + 10000; }), + }); + + auto leftFile = TempFilePath::create(); + writeToFile(leftFile->getPath(), leftData); + auto rightFile = TempFilePath::create(); + writeToFile(rightFile->getPath(), rightData); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId leftScanNodeId; + core::PlanNodeId rightScanNodeId; + + auto leftScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(leftScanNodeId) + .planNode(); + auto rightScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(rightScanNodeId) + .planNode(); + + auto plan = std::make_shared( + planNodeIdGenerator->next(), + std::vector{leftScan, rightScan}); + + for (const auto hasBarrier : {false, true}) { + SCOPED_TRACE(fmt::format("hasBarrier {}", hasBarrier)); + AssertQueryBuilder queryBuilder(plan); + queryBuilder.barrierExecution(hasBarrier).serialExecution(true); + queryBuilder.split( + leftScanNodeId, makeHiveConnectorSplit(leftFile->getPath())); + queryBuilder.split( + rightScanNodeId, makeHiveConnectorSplit(rightFile->getPath())); + + auto result = queryBuilder.copyResults(pool()); + ASSERT_EQ(result->size(), 1010); + } +} + +TEST_F(MixedUnionBarrierTest, barrierTaskStats) { + const auto rowType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + + auto leftData = makeRowVector({ + makeFlatVector(100, [](auto row) { return row; }), + makeFlatVector(100, [](auto row) { return row * 10; }), + }); + auto rightData = makeRowVector({ + makeFlatVector(100, [](auto row) { return row + 100; }), + makeFlatVector(100, [](auto row) { return row * 10 + 1000; }), + }); + + auto leftFile = TempFilePath::create(); + writeToFile(leftFile->getPath(), leftData); + auto rightFile = TempFilePath::create(); + writeToFile(rightFile->getPath(), rightData); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId leftScanNodeId; + core::PlanNodeId rightScanNodeId; + + auto leftScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(leftScanNodeId) + .planNode(); + auto rightScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(rightScanNodeId) + .planNode(); + + auto plan = std::make_shared( + planNodeIdGenerator->next(), + std::vector{leftScan, rightScan}); + + for (const auto hasBarrier : {false, true}) { + SCOPED_TRACE(fmt::format("hasBarrier {}", hasBarrier)); + AssertQueryBuilder queryBuilder(plan); + queryBuilder.barrierExecution(hasBarrier).serialExecution(true); + queryBuilder.split( + leftScanNodeId, makeHiveConnectorSplit(leftFile->getPath())); + queryBuilder.split( + rightScanNodeId, makeHiveConnectorSplit(rightFile->getPath())); + + std::shared_ptr task; + auto result = queryBuilder.copyResults(pool(), task); + ASSERT_EQ(result->size(), 200); + + const auto taskStats = task->taskStats(); + ASSERT_EQ(taskStats.numBarriers, hasBarrier ? 1 : 0); + ASSERT_EQ(taskStats.numFinishedSplits, 2); + } +} + +// Tests barrier with multiple splits per source. Each split triggers a separate +// barrier cycle, exercising the drain-reset logic in finishDrain(). After each +// cycle, sourcesDrained_ is reset so the next split group can proceed. +TEST_F(MixedUnionBarrierTest, barrierMultipleBarrierCycles) { + const auto rowType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + + constexpr int kSplitsPerSide = 3; + constexpr int kRowsPerSplit = 40; + + std::vector> leftFiles; + std::vector> rightFiles; + + for (int i = 0; i < kSplitsPerSide; ++i) { + auto leftData = makeRowVector({ + makeFlatVector( + kRowsPerSplit, [i](auto row) { return i * 1000 + row; }), + makeFlatVector( + kRowsPerSplit, [i](auto row) { return i * 10000 + row * 10; }), + }); + auto rightData = makeRowVector({ + makeFlatVector( + kRowsPerSplit, [i](auto row) { return i * 1000 + row + 500; }), + makeFlatVector( + kRowsPerSplit, + [i](auto row) { return i * 10000 + row * 10 + 5000; }), + }); + + auto leftFile = TempFilePath::create(); + writeToFile(leftFile->getPath(), leftData); + leftFiles.push_back(leftFile); + + auto rightFile = TempFilePath::create(); + writeToFile(rightFile->getPath(), rightData); + rightFiles.push_back(rightFile); + } + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId leftScanNodeId; + core::PlanNodeId rightScanNodeId; + + auto leftScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(leftScanNodeId) + .planNode(); + auto rightScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(rightScanNodeId) + .planNode(); + + auto plan = std::make_shared( + planNodeIdGenerator->next(), + std::vector{leftScan, rightScan}); + + AssertQueryBuilder queryBuilder(plan); + queryBuilder.barrierExecution(true).serialExecution(true); + + for (const auto& file : leftFiles) { + queryBuilder.split(leftScanNodeId, makeHiveConnectorSplit(file->getPath())); + } + for (const auto& file : rightFiles) { + queryBuilder.split( + rightScanNodeId, makeHiveConnectorSplit(file->getPath())); + } + + std::shared_ptr task; + auto result = queryBuilder.copyResults(pool(), task); + ASSERT_EQ(result->size(), kSplitsPerSide * kRowsPerSplit * 2); + + const auto taskStats = task->taskStats(); + ASSERT_EQ(taskStats.numBarriers, kSplitsPerSide); +} + +// Tests barrier where one source has very little data and likely finishes +// before the drain signal arrives. This exercises the code path in startDrain() +// and maybeFinishDrain() that sets sourcesDrained_[i] = true when +// sourcesFinished_[i] is already true. +TEST_F(MixedUnionBarrierTest, barrierWithSourceFinishedBeforeDrain) { + const auto rowType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + + // Left source has a lot of data. + auto leftData = makeRowVector({ + makeFlatVector(1000, [](auto row) { return row; }), + makeFlatVector(1000, [](auto row) { return row * 10; }), + }); + // Right source has very little data — likely finishes before drain starts. + auto rightData = makeRowVector({ + makeFlatVector(1, [](auto /*row*/) { return 9999; }), + makeFlatVector(1, [](auto /*row*/) { return 99990; }), + }); + + auto leftFile = TempFilePath::create(); + writeToFile(leftFile->getPath(), leftData); + auto rightFile = TempFilePath::create(); + writeToFile(rightFile->getPath(), rightData); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId leftScanNodeId; + core::PlanNodeId rightScanNodeId; + + auto leftScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(leftScanNodeId) + .planNode(); + auto rightScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(rightScanNodeId) + .planNode(); + + auto plan = std::make_shared( + planNodeIdGenerator->next(), + std::vector{leftScan, rightScan}); + + for (const auto hasBarrier : {false, true}) { + SCOPED_TRACE(fmt::format("hasBarrier {}", hasBarrier)); + AssertQueryBuilder queryBuilder(plan); + queryBuilder.barrierExecution(hasBarrier).serialExecution(true); + queryBuilder.split( + leftScanNodeId, makeHiveConnectorSplit(leftFile->getPath())); + queryBuilder.split( + rightScanNodeId, makeHiveConnectorSplit(rightFile->getPath())); + + auto result = queryBuilder.copyResults(pool()); + ASSERT_EQ(result->size(), 1001); + } +} + +// Tests barrier where all sources are empty. The drain cycle should complete +// immediately since all sources finish without producing data. +TEST_F(MixedUnionBarrierTest, barrierWithAllEmptySources) { + const auto rowType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + + auto emptyData = makeRowVector({ + makeFlatVector({}), + makeFlatVector({}), + }); + + auto leftFile = TempFilePath::create(); + writeToFile(leftFile->getPath(), emptyData); + auto rightFile = TempFilePath::create(); + writeToFile(rightFile->getPath(), emptyData); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId leftScanNodeId; + core::PlanNodeId rightScanNodeId; + + auto leftScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(leftScanNodeId) + .planNode(); + auto rightScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(rightScanNodeId) + .planNode(); + + auto plan = std::make_shared( + planNodeIdGenerator->next(), + std::vector{leftScan, rightScan}); + + for (const auto hasBarrier : {false, true}) { + SCOPED_TRACE(fmt::format("hasBarrier {}", hasBarrier)); + AssertQueryBuilder queryBuilder(plan); + queryBuilder.barrierExecution(hasBarrier).serialExecution(true); + queryBuilder.split( + leftScanNodeId, makeHiveConnectorSplit(leftFile->getPath())); + queryBuilder.split( + rightScanNodeId, makeHiveConnectorSplit(rightFile->getPath())); + + auto result = queryBuilder.copyResults(pool()); + ASSERT_EQ(result->size(), 0); + } +} + +// Tests barrier with three sources where one source is empty. Validates that +// the drain cycle correctly handles a mix of finished and active sources. +TEST_F(MixedUnionBarrierTest, barrierWithThreeSourcesOneEmpty) { + const auto rowType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + + auto data1 = makeRowVector({ + makeFlatVector(100, [](auto row) { return row; }), + makeFlatVector(100, [](auto row) { return row * 10; }), + }); + auto emptyData = makeRowVector({ + makeFlatVector({}), + makeFlatVector({}), + }); + auto data3 = makeRowVector({ + makeFlatVector(50, [](auto row) { return row + 200; }), + makeFlatVector(50, [](auto row) { return row * 10 + 2000; }), + }); + + auto file1 = TempFilePath::create(); + writeToFile(file1->getPath(), data1); + auto file2 = TempFilePath::create(); + writeToFile(file2->getPath(), emptyData); + auto file3 = TempFilePath::create(); + writeToFile(file3->getPath(), data3); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId scanNodeId1; + core::PlanNodeId scanNodeId2; + core::PlanNodeId scanNodeId3; + + auto scan1 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(scanNodeId1) + .planNode(); + auto scan2 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(scanNodeId2) + .planNode(); + auto scan3 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(scanNodeId3) + .planNode(); + + auto plan = std::make_shared( + planNodeIdGenerator->next(), + std::vector{scan1, scan2, scan3}); + + for (const auto hasBarrier : {false, true}) { + SCOPED_TRACE(fmt::format("hasBarrier {}", hasBarrier)); + AssertQueryBuilder queryBuilder(plan); + queryBuilder.barrierExecution(hasBarrier).serialExecution(true); + queryBuilder.split(scanNodeId1, makeHiveConnectorSplit(file1->getPath())); + queryBuilder.split(scanNodeId2, makeHiveConnectorSplit(file2->getPath())); + queryBuilder.split(scanNodeId3, makeHiveConnectorSplit(file3->getPath())); + + auto result = queryBuilder.copyResults(pool()); + ASSERT_EQ(result->size(), 150); + } +} + +// Tests drain where one source finishes producing data very quickly (and is +// marked finished) while the other source is still producing large amounts of +// data. The finished source should be marked as drained immediately in +// startDrain() and maybeFinishDrain(). +TEST_F(MixedUnionBarrierTest, drainOneSideFinishedOtherProducing) { + const auto rowType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + + // Left source produces a lot of data in small batches. + constexpr int kLeftBatches = 10; + constexpr int kRowsPerBatch = 50; + auto leftFile = TempFilePath::create(); + { + auto leftData = makeRowVector({ + makeFlatVector( + kLeftBatches * kRowsPerBatch, [](auto row) { return row; }), + makeFlatVector( + kLeftBatches * kRowsPerBatch, [](auto row) { return row * 10; }), + }); + writeToFile(leftFile->getPath(), leftData); + } + + // Right source is empty — finishes immediately. + auto rightFile = TempFilePath::create(); + { + auto rightData = makeRowVector({ + makeFlatVector({}), + makeFlatVector({}), + }); + writeToFile(rightFile->getPath(), rightData); + } + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId leftScanNodeId; + core::PlanNodeId rightScanNodeId; + + auto leftScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(leftScanNodeId) + .planNode(); + auto rightScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(rightScanNodeId) + .planNode(); + + auto plan = std::make_shared( + planNodeIdGenerator->next(), + std::vector{leftScan, rightScan}); + + AssertQueryBuilder queryBuilder(plan); + queryBuilder.barrierExecution(true).serialExecution(true); + queryBuilder.split( + leftScanNodeId, makeHiveConnectorSplit(leftFile->getPath())); + queryBuilder.split( + rightScanNodeId, makeHiveConnectorSplit(rightFile->getPath())); + + std::shared_ptr task; + auto result = queryBuilder.copyResults(pool(), task); + ASSERT_EQ(result->size(), kLeftBatches * kRowsPerBatch); + + const auto taskStats = task->taskStats(); + ASSERT_EQ(taskStats.numBarriers, 1); +} + +// Tests drain where both sources have produced some data but one is drained +// (signals no more data) while the other still has pending data buffered. +// This exercises the code path where pendingData_ is drained while one source +// has already signaled drained. +TEST_F(MixedUnionBarrierTest, drainWithPendingDataOnOneSide) { + const auto rowType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + + // Both sources have data, but different amounts to create asymmetric + // completion. + auto leftData = makeRowVector({ + makeFlatVector(200, [](auto row) { return row; }), + makeFlatVector(200, [](auto row) { return row * 10; }), + }); + auto rightData = makeRowVector({ + makeFlatVector(50, [](auto row) { return row + 1000; }), + makeFlatVector(50, [](auto row) { return row * 10 + 10000; }), + }); + + auto leftFile = TempFilePath::create(); + writeToFile(leftFile->getPath(), leftData); + auto rightFile = TempFilePath::create(); + writeToFile(rightFile->getPath(), rightData); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId leftScanNodeId; + core::PlanNodeId rightScanNodeId; + + auto leftScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(leftScanNodeId) + .planNode(); + auto rightScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(rightScanNodeId) + .planNode(); + + auto plan = std::make_shared( + planNodeIdGenerator->next(), + std::vector{leftScan, rightScan}); + + AssertQueryBuilder queryBuilder(plan); + queryBuilder.barrierExecution(true).serialExecution(true); + queryBuilder.split( + leftScanNodeId, makeHiveConnectorSplit(leftFile->getPath())); + queryBuilder.split( + rightScanNodeId, makeHiveConnectorSplit(rightFile->getPath())); + + std::shared_ptr task; + auto result = queryBuilder.copyResults(pool(), task); + ASSERT_EQ(result->size(), 250); + + const auto taskStats = task->taskStats(); + ASSERT_EQ(taskStats.numBarriers, 1); +} + +// Tests multiple barrier cycles where sources complete in different orders +// each cycle. This exercises the reset logic in finishDrain() ensuring +// sourcesDrained_ is properly cleared between cycles. +TEST_F(MixedUnionBarrierTest, multipleBarrierCyclesAlternatingCompletion) { + const auto rowType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + + constexpr int kSplitsPerSide = 4; + + std::vector> leftFiles; + std::vector> rightFiles; + + // Create splits with alternating sizes so completion order varies. + for (int i = 0; i < kSplitsPerSide; ++i) { + // Alternate which side has more data. + int leftRows = (i % 2 == 0) ? 100 : 10; + int rightRows = (i % 2 == 0) ? 10 : 100; + + auto leftData = makeRowVector({ + makeFlatVector( + leftRows, [i](auto row) { return i * 1000 + row; }), + makeFlatVector( + leftRows, [i](auto row) { return i * 10000 + row * 10; }), + }); + auto rightData = makeRowVector({ + makeFlatVector( + rightRows, [i](auto row) { return i * 1000 + row + 500; }), + makeFlatVector( + rightRows, [i](auto row) { return i * 10000 + row * 10 + 5000; }), + }); + + auto leftFile = TempFilePath::create(); + writeToFile(leftFile->getPath(), leftData); + leftFiles.push_back(leftFile); + + auto rightFile = TempFilePath::create(); + writeToFile(rightFile->getPath(), rightData); + rightFiles.push_back(rightFile); + } + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId leftScanNodeId; + core::PlanNodeId rightScanNodeId; + + auto leftScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(leftScanNodeId) + .planNode(); + auto rightScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(rightScanNodeId) + .planNode(); + + auto plan = std::make_shared( + planNodeIdGenerator->next(), + std::vector{leftScan, rightScan}); + + AssertQueryBuilder queryBuilder(plan); + queryBuilder.barrierExecution(true).serialExecution(true); + + for (const auto& file : leftFiles) { + queryBuilder.split(leftScanNodeId, makeHiveConnectorSplit(file->getPath())); + } + for (const auto& file : rightFiles) { + queryBuilder.split( + rightScanNodeId, makeHiveConnectorSplit(file->getPath())); + } + + std::shared_ptr task; + auto result = queryBuilder.copyResults(pool(), task); + // Each cycle: 2 splits alternate 100+10 and 10+100 = 110 rows per cycle. + // Total: 4 cycles * 110 = 440 rows. + ASSERT_EQ(result->size(), 440); + + const auto taskStats = task->taskStats(); + ASSERT_EQ(taskStats.numBarriers, kSplitsPerSide); +} + +// Tests drain with three sources where they complete in different orders: +// first source finishes, second is drained, third still has pending data. +TEST_F(MixedUnionBarrierTest, drainThreeSourcesMixedStates) { + const auto rowType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + + // Source 1: Empty (finishes immediately). + auto data1 = makeRowVector({ + makeFlatVector({}), + makeFlatVector({}), + }); + // Source 2: Small amount (drains quickly). + auto data2 = makeRowVector({ + makeFlatVector(20, [](auto row) { return row + 100; }), + makeFlatVector(20, [](auto row) { return row * 10 + 1000; }), + }); + // Source 3: Large amount (still producing when others complete). + auto data3 = makeRowVector({ + makeFlatVector(500, [](auto row) { return row + 200; }), + makeFlatVector(500, [](auto row) { return row * 10 + 2000; }), + }); + + auto file1 = TempFilePath::create(); + writeToFile(file1->getPath(), data1); + auto file2 = TempFilePath::create(); + writeToFile(file2->getPath(), data2); + auto file3 = TempFilePath::create(); + writeToFile(file3->getPath(), data3); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId scanNodeId1; + core::PlanNodeId scanNodeId2; + core::PlanNodeId scanNodeId3; + + auto scan1 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(scanNodeId1) + .planNode(); + auto scan2 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(scanNodeId2) + .planNode(); + auto scan3 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(scanNodeId3) + .planNode(); + + auto plan = std::make_shared( + planNodeIdGenerator->next(), + std::vector{scan1, scan2, scan3}); + + AssertQueryBuilder queryBuilder(plan); + queryBuilder.barrierExecution(true).serialExecution(true); + queryBuilder.split(scanNodeId1, makeHiveConnectorSplit(file1->getPath())); + queryBuilder.split(scanNodeId2, makeHiveConnectorSplit(file2->getPath())); + queryBuilder.split(scanNodeId3, makeHiveConnectorSplit(file3->getPath())); + + std::shared_ptr task; + auto result = queryBuilder.copyResults(pool(), task); + ASSERT_EQ(result->size(), 520); + + const auto taskStats = task->taskStats(); + ASSERT_EQ(taskStats.numBarriers, 1); +} + +// Tests that the operator correctly handles rapid barrier cycles with minimal +// data. This ensures that state reset between cycles is complete and doesn't +// leave stale flags. +TEST_F(MixedUnionBarrierTest, rapidBarrierCyclesMinimalData) { + const auto rowType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + + constexpr int kSplitsPerSide = 5; + constexpr int kRowsPerSplit = 5; + + std::vector> leftFiles; + std::vector> rightFiles; + + for (int i = 0; i < kSplitsPerSide; ++i) { + auto leftData = makeRowVector({ + makeFlatVector( + kRowsPerSplit, [i](auto row) { return i * 100 + row; }), + makeFlatVector( + kRowsPerSplit, [i](auto row) { return i * 1000 + row * 10; }), + }); + auto rightData = makeRowVector({ + makeFlatVector( + kRowsPerSplit, [i](auto row) { return i * 100 + row + 50; }), + makeFlatVector( + kRowsPerSplit, [i](auto row) { return i * 1000 + row * 10 + 500; }), + }); + + auto leftFile = TempFilePath::create(); + writeToFile(leftFile->getPath(), leftData); + leftFiles.push_back(leftFile); + + auto rightFile = TempFilePath::create(); + writeToFile(rightFile->getPath(), rightData); + rightFiles.push_back(rightFile); + } + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId leftScanNodeId; + core::PlanNodeId rightScanNodeId; + + auto leftScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(leftScanNodeId) + .planNode(); + auto rightScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(rightScanNodeId) + .planNode(); + + auto plan = std::make_shared( + planNodeIdGenerator->next(), + std::vector{leftScan, rightScan}); + + AssertQueryBuilder queryBuilder(plan); + queryBuilder.barrierExecution(true).serialExecution(true); + + for (const auto& file : leftFiles) { + queryBuilder.split(leftScanNodeId, makeHiveConnectorSplit(file->getPath())); + } + for (const auto& file : rightFiles) { + queryBuilder.split( + rightScanNodeId, makeHiveConnectorSplit(file->getPath())); + } + + std::shared_ptr task; + auto result = queryBuilder.copyResults(pool(), task); + ASSERT_EQ(result->size(), kSplitsPerSide * kRowsPerSplit * 2); + + const auto taskStats = task->taskStats(); + ASSERT_EQ(taskStats.numBarriers, kSplitsPerSide); +} + +TEST_F(MixedUnionBarrierTest, barrierWithAggregation) { + const auto rowType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + + auto leftData = makeRowVector({ + makeFlatVector(100, [](auto row) { return row % 10; }), + makeFlatVector(100, [](auto row) { return row; }), + }); + auto rightData = makeRowVector({ + makeFlatVector(100, [](auto row) { return row % 10; }), + makeFlatVector(100, [](auto row) { return row + 100; }), + }); + + auto leftFile = TempFilePath::create(); + writeToFile(leftFile->getPath(), leftData); + auto rightFile = TempFilePath::create(); + writeToFile(rightFile->getPath(), rightData); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId leftScanNodeId; + core::PlanNodeId rightScanNodeId; + + auto leftScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(leftScanNodeId) + .planNode(); + auto rightScan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(rightScanNodeId) + .planNode(); + + auto unionNode = std::make_shared( + planNodeIdGenerator->next(), + std::vector{leftScan, rightScan}); + + // Add a single aggregation on top of the union. + // Note: Only non-barrier mode works with singleAggregation since hash + // aggregation doesn't support barriers (only streaming aggregation does). + auto plan = PlanBuilder(planNodeIdGenerator) + .addNode([&](const auto&, auto) { return unionNode; }) + .singleAggregation({"c0"}, {"sum(c1)"}) + .planNode(); + + // Aggregation doesn't support barriers, so only test non-barrier mode. + AssertQueryBuilder queryBuilder(plan); + queryBuilder.serialExecution(true); + queryBuilder.split( + leftScanNodeId, makeHiveConnectorSplit(leftFile->getPath())); + queryBuilder.split( + rightScanNodeId, makeHiveConnectorSplit(rightFile->getPath())); + + auto result = queryBuilder.copyResults(pool()); + // 10 distinct groups + ASSERT_EQ(result->size(), 10); +} diff --git a/velox/exec/tests/MixedUnionWithTableScanTest.cpp b/velox/exec/tests/MixedUnionWithTableScanTest.cpp new file mode 100644 index 00000000000..fb6f8011191 --- /dev/null +++ b/velox/exec/tests/MixedUnionWithTableScanTest.cpp @@ -0,0 +1,688 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +using namespace facebook::velox; +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; + +/// Tests for MixedUnion operator with TableScan sources. +/// Similar to koski's TestLocalWarehouseReadSplit.cpp but exercises +/// velox execution directly. +class MixedUnionWithTableScanTest : public HiveConnectorTestBase { + protected: + void SetUp() override { + HiveConnectorTestBase::SetUp(); + } + + void TearDown() override { + HiveConnectorTestBase::TearDown(); + } + + /// Helper to create a MixedUnionNode from multiple plan node sources. + std::shared_ptr makeMixedUnionNode( + std::vector sources, + std::shared_ptr planNodeIdGenerator) { + return std::make_shared( + planNodeIdGenerator->next(), std::move(sources)); + } + + /// Helper to create a table scan plan node. + core::PlanNodePtr makeTableScanNode( + const RowTypePtr& rowType, + std::shared_ptr planNodeIdGenerator) { + return PlanBuilder(planNodeIdGenerator).tableScan(rowType).planNode(); + } + + /// Helper to write test data to files and return the file paths. + std::vector> writeTestData( + const std::vector& vectors) { + auto filePaths = makeFilePaths(vectors.size()); + for (size_t i = 0; i < vectors.size(); ++i) { + writeToFile(filePaths[i]->getPath(), vectors[i]); + } + return filePaths; + } + + RowTypePtr testSchema_{ROW({"c0", "c1"}, {BIGINT(), VARCHAR()})}; +}; + +/// Basic test: Union of two table scans. +TEST_F(MixedUnionWithTableScanTest, basicUnionOfTableScans) { + // Create test data for two tables. + auto data1 = makeRowVector({ + makeFlatVector({1, 2, 3}), + makeFlatVector({"a", "b", "c"}), + }); + auto data2 = makeRowVector({ + makeFlatVector({4, 5, 6}), + makeFlatVector({"d", "e", "f"}), + }); + + // Write data to files. + auto filePath1 = TempFilePath::create(); + auto filePath2 = TempFilePath::create(); + writeToFile(filePath1->getPath(), data1); + writeToFile(filePath2->getPath(), data2); + + // Build plan with MixedUnion of two table scans. + auto planNodeIdGenerator = std::make_shared(); + auto rowType = asRowType(data1->type()); + + core::PlanNodeId tableScan1Id; + auto source1 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(tableScan1Id) + .planNode(); + + core::PlanNodeId tableScan2Id; + auto source2 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(tableScan2Id) + .planNode(); + + auto unionNode = makeMixedUnionNode({source1, source2}, planNodeIdGenerator); + + // Execute query with splits. + auto result = + AssertQueryBuilder(unionNode) + .split(tableScan1Id, makeHiveConnectorSplit(filePath1->getPath())) + .split(tableScan2Id, makeHiveConnectorSplit(filePath2->getPath())) + .copyResults(pool()); + + ASSERT_EQ(result->size(), 6); + + // Validate result content and row order. + // MixedUnion preserves row order: data1 rows come first, then data2 rows. + auto expected = makeRowVector({ + makeFlatVector({1, 2, 3, 4, 5, 6}), + makeFlatVector({"a", "b", "c", "d", "e", "f"}), + }); + facebook::velox::test::assertEqualVectors(expected, result); +} + +/// Test union of a single table scan (edge case). +TEST_F(MixedUnionWithTableScanTest, singleTableScanUnion) { + auto data = makeRowVector({ + makeFlatVector({1, 2, 3}), + makeFlatVector({"a", "b", "c"}), + }); + + auto filePath = TempFilePath::create(); + writeToFile(filePath->getPath(), data); + + auto planNodeIdGenerator = std::make_shared(); + auto rowType = asRowType(data->type()); + + core::PlanNodeId tableScanId; + auto source = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(tableScanId) + .planNode(); + + auto unionNode = makeMixedUnionNode({source}, planNodeIdGenerator); + + auto result = + AssertQueryBuilder(unionNode) + .split(tableScanId, makeHiveConnectorSplit(filePath->getPath())) + .copyResults(pool()); + + ASSERT_EQ(result->size(), 3); + + // Validate result content and row order. + facebook::velox::test::assertEqualVectors(data, result); +} + +/// Test union with multiple table scans (more than 2). +TEST_F(MixedUnionWithTableScanTest, multipleTableScanUnion) { + constexpr int kNumSources = 5; + + auto planNodeIdGenerator = std::make_shared(); + std::vector> filePaths; + std::vector sources; + std::vector tableScanIds; + + auto rowType = ROW({"c0"}, {BIGINT()}); + + for (int i = 0; i < kNumSources; ++i) { + auto data = makeRowVector({ + makeFlatVector({i * 10, i * 10 + 1}), + }); + + auto filePath = TempFilePath::create(); + writeToFile(filePath->getPath(), data); + filePaths.push_back(filePath); + + core::PlanNodeId tableScanId; + auto source = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(tableScanId) + .planNode(); + sources.push_back(source); + tableScanIds.push_back(tableScanId); + } + + auto unionNode = makeMixedUnionNode(std::move(sources), planNodeIdGenerator); + + auto queryBuilder = AssertQueryBuilder(unionNode); + for (int i = 0; i < kNumSources; ++i) { + queryBuilder.split( + tableScanIds[i], makeHiveConnectorSplit(filePaths[i]->getPath())); + } + + auto result = queryBuilder.copyResults(pool()); + ASSERT_EQ(result->size(), kNumSources * 2); + + // Validate result content and row order. + // MixedUnion preserves row order: source 0, 1, 2, 3, 4. + auto expected = makeRowVector({ + makeFlatVector({0, 1, 10, 11, 20, 21, 30, 31, 40, 41}), + }); + facebook::velox::test::assertEqualVectors(expected, result); +} + +/// Test union with empty table scan source. +TEST_F(MixedUnionWithTableScanTest, emptyTableScanSource) { + auto data = makeRowVector({ + makeFlatVector({1, 2, 3}), + makeFlatVector({"a", "b", "c"}), + }); + auto emptyData = makeRowVector({ + makeFlatVector({}), + makeFlatVector({}), + }); + + auto filePath1 = TempFilePath::create(); + auto filePath2 = TempFilePath::create(); + writeToFile(filePath1->getPath(), data); + writeToFile(filePath2->getPath(), emptyData); + + auto planNodeIdGenerator = std::make_shared(); + auto rowType = asRowType(data->type()); + + core::PlanNodeId tableScan1Id; + auto source1 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(tableScan1Id) + .planNode(); + + core::PlanNodeId tableScan2Id; + auto source2 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(tableScan2Id) + .planNode(); + + auto unionNode = makeMixedUnionNode({source1, source2}, planNodeIdGenerator); + + auto result = + AssertQueryBuilder(unionNode) + .split(tableScan1Id, makeHiveConnectorSplit(filePath1->getPath())) + .split(tableScan2Id, makeHiveConnectorSplit(filePath2->getPath())) + .copyResults(pool()); + + // Should only get data from non-empty source. + ASSERT_EQ(result->size(), 3); + + // Validate result content and row order. + facebook::velox::test::assertEqualVectors(data, result); +} + +/// Test union with filter pushed down to table scan. +TEST_F(MixedUnionWithTableScanTest, unionWithFilteredTableScans) { + auto data1 = makeRowVector({ + makeFlatVector({1, 2, 3, 4, 5}), + makeFlatVector({"a", "b", "c", "d", "e"}), + }); + auto data2 = makeRowVector({ + makeFlatVector({6, 7, 8, 9, 10}), + makeFlatVector({"f", "g", "h", "i", "j"}), + }); + + auto filePath1 = TempFilePath::create(); + auto filePath2 = TempFilePath::create(); + writeToFile(filePath1->getPath(), data1); + writeToFile(filePath2->getPath(), data2); + + auto planNodeIdGenerator = std::make_shared(); + auto rowType = asRowType(data1->type()); + + // Build table scans with filters. + core::PlanNodeId tableScan1Id; + auto source1 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(tableScan1Id) + .filter("c0 > 2") + .planNode(); + + core::PlanNodeId tableScan2Id; + auto source2 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(tableScan2Id) + .filter("c0 < 9") + .planNode(); + + auto unionNode = makeMixedUnionNode({source1, source2}, planNodeIdGenerator); + + auto result = + AssertQueryBuilder(unionNode) + .split(tableScan1Id, makeHiveConnectorSplit(filePath1->getPath())) + .split(tableScan2Id, makeHiveConnectorSplit(filePath2->getPath())) + .copyResults(pool()); + + // source1: 3,4,5 (3 rows), source2: 6,7,8 (3 rows). + ASSERT_EQ(result->size(), 6); + + // Validate result content and row order. + auto expected = makeRowVector({ + makeFlatVector({3, 4, 5, 6, 7, 8}), + makeFlatVector({"c", "d", "e", "f", "g", "h"}), + }); + facebook::velox::test::assertEqualVectors(expected, result); +} + +/// Test union with projection after table scan. +TEST_F(MixedUnionWithTableScanTest, unionWithProjectedTableScans) { + auto data1 = makeRowVector({ + makeFlatVector({1, 2, 3}), + makeFlatVector({10, 20, 30}), + }); + auto data2 = makeRowVector({ + makeFlatVector({4, 5, 6}), + makeFlatVector({40, 50, 60}), + }); + + auto filePath1 = TempFilePath::create(); + auto filePath2 = TempFilePath::create(); + writeToFile(filePath1->getPath(), data1); + writeToFile(filePath2->getPath(), data2); + + auto planNodeIdGenerator = std::make_shared(); + auto rowType = asRowType(data1->type()); + + // Build table scans with projections (sum of columns). + core::PlanNodeId tableScan1Id; + auto source1 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(tableScan1Id) + .project({"c0 + c1 as sum"}) + .planNode(); + + core::PlanNodeId tableScan2Id; + auto source2 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(tableScan2Id) + .project({"c0 + c1 as sum"}) + .planNode(); + + auto unionNode = makeMixedUnionNode({source1, source2}, planNodeIdGenerator); + + auto result = + AssertQueryBuilder(unionNode) + .split(tableScan1Id, makeHiveConnectorSplit(filePath1->getPath())) + .split(tableScan2Id, makeHiveConnectorSplit(filePath2->getPath())) + .copyResults(pool()); + + ASSERT_EQ(result->size(), 6); + + // Validate result content and row order. + auto expected = makeRowVector({ + makeFlatVector({11, 22, 33, 44, 55, 66}), + }); + facebook::velox::test::assertEqualVectors(expected, result); +} + +/// Test union followed by aggregation. +TEST_F(MixedUnionWithTableScanTest, unionWithAggregation) { + auto data1 = makeRowVector({ + makeFlatVector({1, 1, 2}), + makeFlatVector({10, 20, 30}), + }); + auto data2 = makeRowVector({ + makeFlatVector({1, 2, 2}), + makeFlatVector({40, 50, 60}), + }); + + auto filePath1 = TempFilePath::create(); + auto filePath2 = TempFilePath::create(); + writeToFile(filePath1->getPath(), data1); + writeToFile(filePath2->getPath(), data2); + + auto planNodeIdGenerator = std::make_shared(); + auto rowType = asRowType(data1->type()); + + core::PlanNodeId tableScan1Id; + auto source1 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(tableScan1Id) + .planNode(); + + core::PlanNodeId tableScan2Id; + auto source2 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(tableScan2Id) + .planNode(); + + auto unionNode = makeMixedUnionNode({source1, source2}, planNodeIdGenerator); + + // Add aggregation after union. + auto plan = PlanBuilder(planNodeIdGenerator) + .addNode([&](auto, auto) { return unionNode; }) + .singleAggregation({"c0"}, {"sum(c1)"}) + .planNode(); + + auto result = + AssertQueryBuilder(plan) + .split(tableScan1Id, makeHiveConnectorSplit(filePath1->getPath())) + .split(tableScan2Id, makeHiveConnectorSplit(filePath2->getPath())) + .copyResults(pool()); + + // Group 1: 10 + 20 + 40 = 70. + // Group 2: 30 + 50 + 60 = 140. + ASSERT_EQ(result->size(), 2); +} + +/// Test union with multiple splits per table scan. +TEST_F(MixedUnionWithTableScanTest, unionWithMultipleSplitsPerSource) { + auto data1a = makeRowVector({ + makeFlatVector({1, 2}), + }); + auto data1b = makeRowVector({ + makeFlatVector({3, 4}), + }); + auto data2a = makeRowVector({ + makeFlatVector({5, 6}), + }); + auto data2b = makeRowVector({ + makeFlatVector({7, 8}), + }); + + auto filePath1a = TempFilePath::create(); + auto filePath1b = TempFilePath::create(); + auto filePath2a = TempFilePath::create(); + auto filePath2b = TempFilePath::create(); + writeToFile(filePath1a->getPath(), data1a); + writeToFile(filePath1b->getPath(), data1b); + writeToFile(filePath2a->getPath(), data2a); + writeToFile(filePath2b->getPath(), data2b); + + auto planNodeIdGenerator = std::make_shared(); + auto rowType = asRowType(data1a->type()); + + core::PlanNodeId tableScan1Id; + auto source1 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(tableScan1Id) + .planNode(); + + core::PlanNodeId tableScan2Id; + auto source2 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(tableScan2Id) + .planNode(); + + auto unionNode = makeMixedUnionNode({source1, source2}, planNodeIdGenerator); + + // Add multiple splits per table scan. + auto result = + AssertQueryBuilder(unionNode) + .split(tableScan1Id, makeHiveConnectorSplit(filePath1a->getPath())) + .split(tableScan1Id, makeHiveConnectorSplit(filePath1b->getPath())) + .split(tableScan2Id, makeHiveConnectorSplit(filePath2a->getPath())) + .split(tableScan2Id, makeHiveConnectorSplit(filePath2b->getPath())) + .copyResults(pool()); + + ASSERT_EQ(result->size(), 8); + + // Validate result content and row order. + // MixedUnion preserves row order: data1a, data1b, data2a, data2b. + auto expected = makeRowVector({ + makeFlatVector({1, 2, 5, 6, 3, 4, 7, 8}), + }); + facebook::velox::test::assertEqualVectors(expected, result); +} + +/// Test union with null values in data. +TEST_F(MixedUnionWithTableScanTest, unionWithNullValues) { + auto data1 = makeRowVector({ + makeNullableFlatVector({1, std::nullopt, 3}), + makeNullableFlatVector({"a", "b", std::nullopt}), + }); + auto data2 = makeRowVector({ + makeNullableFlatVector({std::nullopt, 5, 6}), + makeNullableFlatVector({std::nullopt, "e", "f"}), + }); + + auto filePath1 = TempFilePath::create(); + auto filePath2 = TempFilePath::create(); + writeToFile(filePath1->getPath(), data1); + writeToFile(filePath2->getPath(), data2); + + auto planNodeIdGenerator = std::make_shared(); + auto rowType = asRowType(data1->type()); + + core::PlanNodeId tableScan1Id; + auto source1 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(tableScan1Id) + .planNode(); + + core::PlanNodeId tableScan2Id; + auto source2 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(tableScan2Id) + .planNode(); + + auto unionNode = makeMixedUnionNode({source1, source2}, planNodeIdGenerator); + + auto result = + AssertQueryBuilder(unionNode) + .split(tableScan1Id, makeHiveConnectorSplit(filePath1->getPath())) + .split(tableScan2Id, makeHiveConnectorSplit(filePath2->getPath())) + .copyResults(pool()); + + ASSERT_EQ(result->size(), 6); + + // Validate result content and row order, including nulls. + auto expected = makeRowVector({ + makeNullableFlatVector({1, std::nullopt, 3, std::nullopt, 5, 6}), + makeNullableFlatVector( + {"a", "b", std::nullopt, std::nullopt, "e", "f"}), + }); + facebook::velox::test::assertEqualVectors(expected, result); +} + +/// Test union with various data types. +TEST_F(MixedUnionWithTableScanTest, unionWithVariousTypes) { + auto data1 = makeRowVector({ + makeFlatVector({1, 2}), + makeFlatVector({1.1, 2.2}), + makeFlatVector({true, false}), + makeFlatVector({"hello", "world"}), + }); + auto data2 = makeRowVector({ + makeFlatVector({3, 4}), + makeFlatVector({3.3, 4.4}), + makeFlatVector({false, true}), + makeFlatVector({"foo", "bar"}), + }); + + auto filePath1 = TempFilePath::create(); + auto filePath2 = TempFilePath::create(); + writeToFile(filePath1->getPath(), data1); + writeToFile(filePath2->getPath(), data2); + + auto planNodeIdGenerator = std::make_shared(); + auto rowType = asRowType(data1->type()); + + core::PlanNodeId tableScan1Id; + auto source1 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(tableScan1Id) + .planNode(); + + core::PlanNodeId tableScan2Id; + auto source2 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(tableScan2Id) + .planNode(); + + auto unionNode = makeMixedUnionNode({source1, source2}, planNodeIdGenerator); + + auto result = + AssertQueryBuilder(unionNode) + .split(tableScan1Id, makeHiveConnectorSplit(filePath1->getPath())) + .split(tableScan2Id, makeHiveConnectorSplit(filePath2->getPath())) + .copyResults(pool()); + + ASSERT_EQ(result->size(), 4); + + // Validate result content and row order. + auto expected = makeRowVector({ + makeFlatVector({1, 2, 3, 4}), + makeFlatVector({1.1, 2.2, 3.3, 4.4}), + makeFlatVector({true, false, false, true}), + makeFlatVector({"hello", "world", "foo", "bar"}), + }); + facebook::velox::test::assertEqualVectors(expected, result); +} + +/// Test union with unequal source sizes. +TEST_F(MixedUnionWithTableScanTest, unionWithUnequalSourceSizes) { + auto data1 = makeRowVector({ + makeFlatVector({1, 2, 3, 4, 5}), + }); + auto data2 = makeRowVector({ + makeFlatVector({6, 7}), + }); + auto data3 = makeRowVector({ + makeFlatVector({8, 9, 10, 11, 12, 13, 14}), + }); + + auto filePath1 = TempFilePath::create(); + auto filePath2 = TempFilePath::create(); + auto filePath3 = TempFilePath::create(); + writeToFile(filePath1->getPath(), data1); + writeToFile(filePath2->getPath(), data2); + writeToFile(filePath3->getPath(), data3); + + auto planNodeIdGenerator = std::make_shared(); + auto rowType = asRowType(data1->type()); + + core::PlanNodeId tableScan1Id; + auto source1 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(tableScan1Id) + .planNode(); + + core::PlanNodeId tableScan2Id; + auto source2 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(tableScan2Id) + .planNode(); + + core::PlanNodeId tableScan3Id; + auto source3 = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(tableScan3Id) + .planNode(); + + auto unionNode = + makeMixedUnionNode({source1, source2, source3}, planNodeIdGenerator); + + auto result = + AssertQueryBuilder(unionNode) + .split(tableScan1Id, makeHiveConnectorSplit(filePath1->getPath())) + .split(tableScan2Id, makeHiveConnectorSplit(filePath2->getPath())) + .split(tableScan3Id, makeHiveConnectorSplit(filePath3->getPath())) + .copyResults(pool()); + + ASSERT_EQ(result->size(), 14); + + // Validate result content and row order. + auto expected = makeRowVector({ + makeFlatVector({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}), + }); + facebook::velox::test::assertEqualVectors(expected, result); +} + +/// Test union with large data volume. +TEST_F(MixedUnionWithTableScanTest, unionWithLargeDataVolume) { + constexpr int kRowsPerSource = 1000; + constexpr int kNumSources = 3; + + auto planNodeIdGenerator = std::make_shared(); + std::vector> filePaths; + std::vector sources; + std::vector tableScanIds; + + auto rowType = ROW({"c0", "c1"}, {BIGINT(), DOUBLE()}); + + for (int s = 0; s < kNumSources; ++s) { + auto data = makeRowVector({ + makeFlatVector( + kRowsPerSource, [s](auto row) { return s * 10000 + row; }), + makeFlatVector( + kRowsPerSource, [](auto row) { return row * 0.1; }), + }); + + auto filePath = TempFilePath::create(); + writeToFile(filePath->getPath(), data); + filePaths.push_back(filePath); + + core::PlanNodeId tableScanId; + auto source = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType) + .capturePlanNodeId(tableScanId) + .planNode(); + sources.push_back(source); + tableScanIds.push_back(tableScanId); + } + + auto unionNode = makeMixedUnionNode(std::move(sources), planNodeIdGenerator); + + auto queryBuilder = AssertQueryBuilder(unionNode); + for (int i = 0; i < kNumSources; ++i) { + queryBuilder.split( + tableScanIds[i], makeHiveConnectorSplit(filePaths[i]->getPath())); + } + + auto result = queryBuilder.copyResults(pool()); + ASSERT_EQ(result->size(), kRowsPerSource * kNumSources); + + // Validate result content and row order. + // MixedUnion preserves row order: source 0 rows, then source 1 rows, etc. + auto expected = makeRowVector({ + makeFlatVector( + kRowsPerSource * kNumSources, + [](auto row) { + int source = row / 1000; + int rowInSource = row % 1000; + return source * 10000 + rowInSource; + }), + makeFlatVector( + kRowsPerSource * kNumSources, + [](auto row) { + int rowInSource = row % 1000; + return rowInSource * 0.1; + }), + }); + facebook::velox::test::assertEqualVectors(expected, result); +} diff --git a/velox/exec/tests/MultiFragmentTest.cpp b/velox/exec/tests/MultiFragmentTest.cpp index f58cab54e27..2606802afea 100644 --- a/velox/exec/tests/MultiFragmentTest.cpp +++ b/velox/exec/tests/MultiFragmentTest.cpp @@ -13,8 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "folly/experimental/EventCount.h" +#include "folly/synchronization/EventCount.h" #include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/common/testutil/TestValue.h" #include "velox/connectors/hive/HiveConnector.h" #include "velox/connectors/hive/HiveConnectorSplit.h" @@ -30,7 +31,6 @@ #include "velox/exec/tests/utils/LocalExchangeSource.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/exec/tests/utils/SerializedPageUtil.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" using namespace facebook::velox::exec::test; @@ -38,15 +38,14 @@ using facebook::velox::common::testutil::TestValue; using facebook::velox::test::BatchMaker; namespace facebook::velox::exec { +using namespace facebook::velox::common::testutil; namespace { struct TestParam { - VectorSerde::Kind serdeKind; + std::string serdeKind; common::CompressionKind compressionKind; - TestParam( - VectorSerde::Kind _serdeKind, - common::CompressionKind _compressionKind) + TestParam(std::string _serdeKind, common::CompressionKind _compressionKind) : serdeKind(_serdeKind), compressionKind(_compressionKind) {} }; @@ -55,18 +54,12 @@ class MultiFragmentTest : public HiveConnectorTestBase, public: static std::vector getTestParams() { std::vector params; - params.emplace_back( - VectorSerde::Kind::kPresto, common::CompressionKind_NONE); - params.emplace_back( - VectorSerde::Kind::kCompactRow, common::CompressionKind_NONE); - params.emplace_back( - VectorSerde::Kind::kUnsafeRow, common::CompressionKind_NONE); - params.emplace_back( - VectorSerde::Kind::kPresto, common::CompressionKind_LZ4); - params.emplace_back( - VectorSerde::Kind::kCompactRow, common::CompressionKind_LZ4); - params.emplace_back( - VectorSerde::Kind::kUnsafeRow, common::CompressionKind_LZ4); + params.emplace_back("Presto", common::CompressionKind_NONE); + params.emplace_back("CompactRow", common::CompressionKind_NONE); + params.emplace_back("UnsafeRow", common::CompressionKind_NONE); + params.emplace_back("Presto", common::CompressionKind_LZ4); + params.emplace_back("CompactRow", common::CompressionKind_LZ4); + params.emplace_back("UnsafeRow", common::CompressionKind_LZ4); return params; } @@ -109,8 +102,9 @@ class MultiFragmentTest : public HiveConnectorTestBase, auto queryCtx = core::QueryCtx::create( executor ? executor : executor_.get(), core::QueryConfig(std::move(configCopy))); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), maxMemory, MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), maxMemory, MemoryReclaimer::create())); core::PlanFragment planFragment{planNode}; return Task::create( taskId, @@ -127,7 +121,9 @@ class MultiFragmentTest : public HiveConnectorTestBase, std::unordered_map& extraQueryConfigs, int destination = 0, Consumer consumer = nullptr, - int64_t maxMemory = memory::kMaxMemory) const { + int64_t maxMemory = memory::kMaxMemory, + const std::optional& diskSpillOpts = + std::nullopt) const { auto configCopy = configSettings_; for (const auto& [k, v] : extraQueryConfigs) { configCopy[k] = v; @@ -139,8 +135,9 @@ class MultiFragmentTest : public HiveConnectorTestBase, nullptr, nullptr, executor_.get()); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), maxMemory, MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), maxMemory, MemoryReclaimer::create())); core::PlanFragment planFragment{planNode}; return Task::create( taskId, @@ -148,7 +145,9 @@ class MultiFragmentTest : public HiveConnectorTestBase, destination, std::move(queryCtx), Task::ExecutionMode::kParallel, - std::move(consumer)); + std::move(consumer), + /*memoryArbitrationPriority=*/0, + diskSpillOpts); } std::vector makeVectors(int count, int rowsPerVector) { @@ -252,7 +251,7 @@ class MultiFragmentTest : public HiveConnectorTestBase, exchangeStats.at("localExchangeSource.numPages").count); ASSERT_EQ( expectedBackgroundCpuCount, - exchangeStats.at(ExchangeClient::kBackgroundCpuTimeMs).count); + exchangeStats.at(std::string(Operator::kBackgroundCpuTimeNanos)).count); ASSERT_EQ( expectedBackgroundCpuCount, taskStats.at("0").backgroundTiming.count); } @@ -376,23 +375,27 @@ TEST_P(MultiFragmentTest, aggregationSingleKey) { auto leafPlanStats = toPlanStats(leafTask->taskStats()); const auto serdeKindRuntimsStats = leafPlanStats.at(partitionNodeId) - .customStats.at(Operator::kShuffleSerdeKind); + .customStats.at(std::string(Operator::kShuffleSerdeKind)); ASSERT_EQ(serdeKindRuntimsStats.count, 4); ASSERT_EQ( - serdeKindRuntimsStats.min, static_cast(GetParam().serdeKind)); + serdeKindRuntimsStats.min, + static_cast(VectorSerde::kindByName(GetParam().serdeKind))); ASSERT_EQ( - serdeKindRuntimsStats.max, static_cast(GetParam().serdeKind)); + serdeKindRuntimsStats.max, + static_cast(VectorSerde::kindByName(GetParam().serdeKind))); for (const auto& finalTask : finalTasks) { auto finalPlanStats = toPlanStats(finalTask->taskStats()); const auto serdeKindRuntimsStats = finalPlanStats.at(exchangeNodeId) - .customStats.at(Operator::kShuffleSerdeKind); + .customStats.at(std::string(Operator::kShuffleSerdeKind)); ASSERT_EQ(serdeKindRuntimsStats.count, 1); ASSERT_EQ( - serdeKindRuntimsStats.min, static_cast(GetParam().serdeKind)); + serdeKindRuntimsStats.min, + static_cast(VectorSerde::kindByName(GetParam().serdeKind))); ASSERT_EQ( - serdeKindRuntimsStats.max, static_cast(GetParam().serdeKind)); + serdeKindRuntimsStats.max, + static_cast(VectorSerde::kindByName(GetParam().serdeKind))); } } @@ -675,13 +678,15 @@ TEST_P(MultiFragmentTest, mergeExchange) { EXPECT_LT(0, mergeExchangeStats.inputBytes); EXPECT_LT(0, mergeExchangeStats.rawInputBytes); - const auto serdeKindRuntimsStats = - mergeExchangeStats.customStats.at(Operator::kShuffleSerdeKind); + const auto serdeKindRuntimsStats = mergeExchangeStats.customStats.at( + std::string(Operator::kShuffleSerdeKind)); ASSERT_EQ(serdeKindRuntimsStats.count, 1); ASSERT_EQ( - serdeKindRuntimsStats.min, static_cast(GetParam().serdeKind)); + serdeKindRuntimsStats.min, + static_cast(VectorSerde::kindByName(GetParam().serdeKind))); ASSERT_EQ( - serdeKindRuntimsStats.max, static_cast(GetParam().serdeKind)); + serdeKindRuntimsStats.max, + static_cast(VectorSerde::kindByName(GetParam().serdeKind))); } // Test reordering and dropping columns in PartitionedOutput operator. @@ -917,11 +922,17 @@ TEST_P(MultiFragmentTest, mergeExchangeWithSpill) { .capturePlanNodeId(partitionNodeId) .planNode(); localMergeNodeIds.push_back(localMergeNodeId); - auto sortTask = - makeTask(sortTaskId, partialSortPlan, spillMergeConfigs, tasks.size()); spillDirectories.push_back(TempDirectoryPath::create()); - sortTask->setSpillDirectory( - spillDirectories[numPartialSortTasks]->getPath()); + common::SpillDiskOptions spillOpts; + spillOpts.spillDirPath = spillDirectories[numPartialSortTasks]->getPath(); + auto sortTask = makeTask( + sortTaskId, + partialSortPlan, + spillMergeConfigs, + tasks.size(), + /*consumer=*/nullptr, + memory::kMaxMemory, + spillOpts); tasks.push_back(sortTask); sortTask->start(4); @@ -986,13 +997,15 @@ TEST_P(MultiFragmentTest, mergeExchangeWithSpill) { EXPECT_LT(0, mergeExchangeStats.inputBytes); EXPECT_LT(0, mergeExchangeStats.rawInputBytes); - const auto serdeKindRuntimsStats = - mergeExchangeStats.customStats.at(Operator::kShuffleSerdeKind); + const auto serdeKindRuntimsStats = mergeExchangeStats.customStats.at( + std::string(Operator::kShuffleSerdeKind)); ASSERT_EQ(serdeKindRuntimsStats.count, 1); ASSERT_EQ( - serdeKindRuntimsStats.min, static_cast(GetParam().serdeKind)); + serdeKindRuntimsStats.min, + static_cast(VectorSerde::kindByName(GetParam().serdeKind))); ASSERT_EQ( - serdeKindRuntimsStats.max, static_cast(GetParam().serdeKind)); + serdeKindRuntimsStats.max, + static_cast(VectorSerde::kindByName(GetParam().serdeKind))); } TEST_P(MultiFragmentTest, noHashPartitionSkew) { @@ -1665,7 +1678,7 @@ namespace { core::PlanNodePtr makeJoinOverExchangePlan( const RowTypePtr& exchangeType, const RowVectorPtr& buildData, - VectorSerde::Kind serdeKind) { + std::string serdeKind) { auto planNodeIdGenerator = std::make_shared(); return PlanBuilder(planNodeIdGenerator) .exchange(exchangeType, serdeKind) @@ -2083,14 +2096,14 @@ class TestCustomExchangeNode : public core::PlanNode { TestCustomExchangeNode( const core::PlanNodeId& id, const RowTypePtr type, - VectorSerde::Kind serdeKind) + std::string serdeKind) : PlanNode(id), outputType_(type), serdeKind_(serdeKind) {} const RowTypePtr& outputType() const override { return outputType_; } - VectorSerde::Kind serdeKind() const { + const std::string& serdeKind() const { return serdeKind_; } @@ -2117,7 +2130,7 @@ class TestCustomExchangeNode : public core::PlanNode { } const RowTypePtr outputType_; - const VectorSerde::Kind serdeKind_; + const std::string serdeKind_; }; class TestCustomExchange : public exec::Exchange { @@ -2204,12 +2217,15 @@ TEST_P(MultiFragmentTest, customPlanNodeWithExchangeClient) { auto planStats = toPlanStats(leafTask->taskStats()); const auto serdeKindRuntimsStats = - planStats.at(partitionNodeId).customStats.at(Operator::kShuffleSerdeKind); + planStats.at(partitionNodeId) + .customStats.at(std::string(Operator::kShuffleSerdeKind)); ASSERT_EQ(serdeKindRuntimsStats.count, 1); ASSERT_EQ( - serdeKindRuntimsStats.min, static_cast(GetParam().serdeKind)); + serdeKindRuntimsStats.min, + static_cast(VectorSerde::kindByName(GetParam().serdeKind))); ASSERT_EQ( - serdeKindRuntimsStats.max, static_cast(GetParam().serdeKind)); + serdeKindRuntimsStats.max, + static_cast(VectorSerde::kindByName(GetParam().serdeKind))); } // This test is to reproduce the race condition between task terminate and no @@ -2902,12 +2918,12 @@ TEST_P(MultiFragmentTest, mergeSmallBatchesInExchange) { ASSERT_EQ(numPages, stats.customStats.at("numReceivedPages").sum); }; - if (GetParam().serdeKind == VectorSerde::Kind::kPresto) { + if (GetParam().serdeKind == "Presto") { test(1, 1'000); test(1'000, 56); - test(10'000, 6); - test(100'000, 1); - } else if (GetParam().serdeKind == VectorSerde::Kind::kCompactRow) { + test(10'000, 7); + test(100'000, 2); + } else if (GetParam().serdeKind == "CompactRow") { test(1, 1'000); test(1'000, 39); test(10'000, 5); @@ -2915,13 +2931,13 @@ TEST_P(MultiFragmentTest, mergeSmallBatchesInExchange) { } else { test(1, 1'000); test(1'000, 72); - test(10'000, 7); - test(100'000, 1); + test(10'000, 8); + test(100'000, 2); } } TEST_P(MultiFragmentTest, splitLargeCompactRowsInExchange) { - if (GetParam().serdeKind != VectorSerde::Kind::kCompactRow) { + if (GetParam().serdeKind != "CompactRow") { return; } const uint64_t kNumColumns = 100; @@ -2947,14 +2963,13 @@ TEST_P(MultiFragmentTest, splitLargeCompactRowsInExchange) { {"c0"}, kNumPartitions, /*outputLayout=*/{}, - VectorSerde::Kind::kCompactRow) + "CompactRow") .planNode(); const auto producerTaskId = "local://t1"; - auto plan = - test::PlanBuilder() - .exchange(asRowType(data->type()), VectorSerde::Kind::kCompactRow) - .planNode(); + auto plan = test::PlanBuilder() + .exchange(asRowType(data->type()), "CompactRow") + .planNode(); auto expected = makeRowVector(columns); @@ -3036,33 +3051,41 @@ TEST_P(MultiFragmentTest, compression) { auto consumerTaskStats = exec::toPlanStats(consumerTask->taskStats()); const auto& consumerPlanStats = consumerTaskStats.at("0"); ASSERT_EQ( - consumerPlanStats.customStats.at(Operator::kShuffleCompressionKind).min, + consumerPlanStats.customStats + .at(std::string(Operator::kShuffleCompressionKind)) + .min, static_cast(GetParam().compressionKind)); ASSERT_EQ( - consumerPlanStats.customStats.at(Operator::kShuffleCompressionKind).max, + consumerPlanStats.customStats + .at(std::string(Operator::kShuffleCompressionKind)) + .max, static_cast(GetParam().compressionKind)); ASSERT_EQ(data->size() * kNumRepeats, consumerPlanStats.outputRows); auto producerTaskStats = exec::toPlanStats(producerTask->taskStats()); const auto& producerStats = producerTaskStats.at("1"); ASSERT_EQ( - producerStats.customStats.at(Operator::kShuffleCompressionKind).min, + producerStats.customStats + .at(std::string(Operator::kShuffleCompressionKind)) + .min, static_cast(GetParam().compressionKind)); ASSERT_EQ( - producerStats.customStats.at(Operator::kShuffleCompressionKind).max, + producerStats.customStats + .at(std::string(Operator::kShuffleCompressionKind)) + .max, static_cast(GetParam().compressionKind)); if (GetParam().compressionKind == common::CompressionKind_NONE) { ASSERT_EQ( producerStats.customStats.count( - IterativeVectorSerializer::kCompressedBytes), + std::string(IterativeVectorSerializer::kCompressedBytes)), 0); ASSERT_EQ( producerStats.customStats.count( - IterativeVectorSerializer::kCompressionInputBytes), + std::string(IterativeVectorSerializer::kCompressionInputBytes)), 0); ASSERT_EQ( producerStats.customStats.count( - IterativeVectorSerializer::kCompressionSkippedBytes), + std::string(IterativeVectorSerializer::kCompressionSkippedBytes)), 0); return; } @@ -3070,18 +3093,22 @@ TEST_P(MultiFragmentTest, compression) { if (!expectSkipCompression) { ASSERT_LT( producerStats.customStats - .at(IterativeVectorSerializer::kCompressedBytes) + .at(std::string(IterativeVectorSerializer::kCompressedBytes)) .sum, producerStats.customStats - .at(IterativeVectorSerializer::kCompressionInputBytes) + .at(std::string( + IterativeVectorSerializer::kCompressionInputBytes)) .sum); ASSERT_EQ(producerStats.customStats.count("compressionSkippedBytes"), 0); } else { - ASSERT_LT( - 0, - producerStats.customStats - .at(IterativeVectorSerializer::kCompressionSkippedBytes) - .sum); + // Note: With the crash fix for PartitionedOutput, the serializer is + // recreated after each flush, which resets the compression skip counter. + // This means compression is always attempted, so we verify compression + // stats exist rather than checking for skipped bytes. + ASSERT_GT( + producerStats.customStats.count( + std::string(IterativeVectorSerializer::kCompressionInputBytes)), + 0); } }; @@ -3093,10 +3120,68 @@ TEST_P(MultiFragmentTest, compression) { test("local://t1", 0.7, false); } SCOPED_TRACE(fmt::format("minCompressionRatio 0.0000001")); - { test("local://t2", 0.0000001, true); } + { + test("local://t2", 0.0000001, true); + } } } +TEST_P(MultiFragmentTest, compressionPageSizeSkip) { + if (GetParam().compressionKind == common::CompressionKind_NONE) { + GTEST_SKIP() << "Page size skip only applies with compression enabled"; + } + if (GetParam().serdeKind != "Presto") { + GTEST_SKIP() + << "Page size skip only implemented in PrestoIterativeVectorSerializer"; + } + + const auto data = makeRowVector({makeFlatVector({1, 2, 3})}); + + const auto producerPlan = + test::PlanBuilder() + .values({data}) + .partitionedOutput({}, 1, /*outputLayout=*/{}, GetParam().serdeKind) + .planNode(); + + const auto plan = test::PlanBuilder() + .exchange(asRowType(data->type()), GetParam().serdeKind) + .singleAggregation({}, {"sum(c0)"}) + .planNode(); + + const auto expected = + makeRowVector({makeFlatVector(std::vector{6})}); + + std::unordered_map producerConfig; + producerConfig[core::QueryConfig::kShuffleCompressionKind] = + common::compressionKindToString(GetParam().compressionKind); + producerConfig[core::QueryConfig::kMinShuffleCompressionPageSizeBytes] = + std::to_string(1 << 30); + + auto producerTask = + makeTask("local://pageSizeSkip", producerPlan, producerConfig); + producerTask->start(1); + + auto consumerTask = + test::AssertQueryBuilder(plan) + .split(remoteSplit("local://pageSizeSkip")) + .config( + core::QueryConfig::kShuffleCompressionKind, + common::compressionKindToString(GetParam().compressionKind)) + .destination(0) + .assertResults(expected); + + auto producerTaskStats = exec::toPlanStats(producerTask->taskStats()); + const auto& producerStats = producerTaskStats.at("1"); + ASSERT_GT( + producerStats.customStats.count( + std::string(IterativeVectorSerializer::kCompressionSkippedBytes)), + 0); + ASSERT_EQ( + producerStats.customStats.count( + std::string(IterativeVectorSerializer::kCompressionInputBytes)), + 0); +} + TEST_P(MultiFragmentTest, scaledTableScan) { const int numSplits = 20; std::vector> splitFiles; @@ -3195,34 +3280,189 @@ TEST_P(MultiFragmentTest, scaledTableScan) { if (testData.scaleEnabled) { ASSERT_EQ( planStats.at(scanNodeId) - .customStats.count(TableScan::kNumRunningScaleThreads), + .customStats.count( + std::string(TableScan::kNumRunningScaleThreads)), 1); if (testData.expectScaleUp) { ASSERT_GE( planStats.at(scanNodeId) - .customStats[TableScan::kNumRunningScaleThreads] + .customStats[std::string(TableScan::kNumRunningScaleThreads)] .sum, 1); ASSERT_LE( planStats.at(scanNodeId) - .customStats[TableScan::kNumRunningScaleThreads] + .customStats[std::string(TableScan::kNumRunningScaleThreads)] .sum, numLeafDrivers); } else { ASSERT_EQ( planStats.at(scanNodeId) - .customStats.count(TableScan::kNumRunningScaleThreads), + .customStats.count( + std::string(TableScan::kNumRunningScaleThreads)), 1); } } else { ASSERT_EQ( planStats.at(scanNodeId) - .customStats.count(TableScan::kNumRunningScaleThreads), + .customStats.count( + std::string(TableScan::kNumRunningScaleThreads)), 0); } } } +// Test row output with no columns (empty schema). +TEST_P(MultiFragmentTest, emptySchema) { + // Create data with rows but no columns + auto emptyRowType = ROW({}, {}); + auto data = makeRowVector(emptyRowType, 1'000); + + std::vector> tasks; + auto leafTaskId = makeTaskId("leaf", 0); + + // Leaf task: Values -> PartitionedOutput + auto leafPlan = + PlanBuilder() + .values({data}) + .partitionedOutput({}, 1, /*outputLayout=*/{}, GetParam().serdeKind) + .planNode(); + + auto leafTask = makeTask(leafTaskId, leafPlan, tasks.size()); + tasks.push_back(leafTask); + leafTask->start(4); + + // Root task: Exchange -> Project + auto rootTaskId = makeTaskId("root", 0); + auto rootPlan = PlanBuilder() + .exchange(emptyRowType, GetParam().serdeKind) + .singleAggregation({}, {"count(1)"}) + .planNode(); + + test::AssertQueryBuilder(rootPlan, duckDbQueryRunner_) + .split(remoteSplit(leafTaskId)) + .config( + core::QueryConfig::kShuffleCompressionKind, + common::compressionKindToString(GetParam().compressionKind)) + .assertResults("SELECT 1000"); + + for (auto& task : tasks) { + ASSERT_TRUE(waitForTaskCompletion(task.get())) << task->taskId(); + } +} + +// Test stateful deserialization with different batch byte limits. +// This validates that the Exchange operator correctly breaks in the middle +// and continues from the leftover when batch size limits are reached. +TEST_P(MultiFragmentTest, batchBytes) { + auto test = [&](int32_t numBatches, + int32_t rowsPerBatch, + uint64_t preferredBatchBytes, + uint64_t expectedAtLeastOutputBatches = 0) { + SCOPED_TRACE( + fmt::format( + "numBatches={}, rowsPerBatch={}, preferredBatchBytes={}", + numBatches, + rowsPerBatch, + succinctBytes(preferredBatchBytes))); + + std::vector batches; + batches.reserve(numBatches); + + for (int i = 0; i < numBatches; ++i) { + auto batch = makeRowVector({ + makeFlatVector( + rowsPerBatch, + [i, rowsPerBatch](auto row) { return i * rowsPerBatch + row; }), + makeFlatVector( + rowsPerBatch, + [i, rowsPerBatch](auto row) { + return (i * rowsPerBatch + row) % 1000; + }), + makeFlatVector( + rowsPerBatch, + [i, rowsPerBatch](auto row) { + return (i * rowsPerBatch + row) * 1.5; + }), + }); + batches.push_back(batch); + } + + auto leafTaskId = makeTaskId("leaf", 0); + auto leafPlan = + PlanBuilder() + .values(batches) + .partitionedOutput({}, 1, {"c0", "c1", "c2"}, GetParam().serdeKind) + .planNode(); + + auto leafTask = makeTask(leafTaskId, leafPlan, 0); + leafTask->start(1); + + core::PlanNodeId exchangeNodeId; + auto rootPlan = + PlanBuilder() + .exchange( + ROW({"c0", "c1", "c2"}, {BIGINT(), INTEGER(), DOUBLE()}), + GetParam().serdeKind) + .capturePlanNodeId(exchangeNodeId) + .singleAggregation({}, {"count(1)", "sum(c0)", "avg(c2)"}) + .planNode(); + + auto extraConfigs = std::unordered_map{ + {core::QueryConfig::kPreferredOutputBatchBytes, + std::to_string(preferredBatchBytes)}, + {core::QueryConfig::kShuffleCompressionKind, + common::compressionKindToString(GetParam().compressionKind)}}; + + auto task = test::AssertQueryBuilder(rootPlan, duckDbQueryRunner_) + .split(remoteSplit(leafTaskId)) + .configs(extraConfigs) + .assertResults( + fmt::format( + "SELECT {}, {}, {}", + numBatches * rowsPerBatch, + (static_cast(numBatches) * rowsPerBatch * + (numBatches * rowsPerBatch - 1)) / + 2, + (static_cast(numBatches) * rowsPerBatch * + (numBatches * rowsPerBatch - 1)) / + 2 * 1.5 / (numBatches * rowsPerBatch))); + + waitForTaskCompletion(leafTask.get()); + + // Verify Exchange stats to ensure data was processed correctly + auto rootTaskStats = toPlanStats(task->taskStats()); + const auto& exchangeStats = rootTaskStats.at(exchangeNodeId); + + EXPECT_GE(exchangeStats.outputVectors, expectedAtLeastOutputBatches); + }; + + // Presto serialization operates at page-level granularity (pages are atomic). + // The number of output batches depends on how many Presto pages are created + // during serialization, which varies based on encoding, compression, and + // data. + // + // For this test (100 input batches × 100 rows = 10,000 rows): + // The actual behavior shows all pages are merged and processed together, + // resulting in a single batch output currently. + // + // This is a known limitation - Presto pages cannot be partially deserialized. + // The fix prevents INT32_MAX overflow by controlling the merge size, but + // fine-grained batch control requires deeper changes to PrestoVectorSerde. + + if (GetParam().serdeKind == "Presto") { + // Current implementation merges all pages and processes in one batch + // The key improvement is preventing overflow, not fine-grained batching + test(100, 100, 1, 1); // Expect single batch with all data + test(100, 100, 1ULL << 30, 1); // Expect single batch with all data + } else { + // Row-based serialization (CompactRow/UnsafeRow) supports row-level + // batching With 1 byte limit: Can produce many small batches + test(100, 100, 1, 100); + // With 1GB limit: All rows fit in one batch + test(100, 100, 1ULL << 30, 1); + } +} + VELOX_INSTANTIATE_TEST_SUITE_P( MultiFragmentTest, MultiFragmentTest, @@ -3230,7 +3470,7 @@ VELOX_INSTANTIATE_TEST_SUITE_P( [](const testing::TestParamInfo& info) { return fmt::format( "{}_{}", - VectorSerde::kindName(info.param.serdeKind), + info.param.serdeKind, compressionKindToString(info.param.compressionKind)); }); } // namespace diff --git a/velox/exec/tests/NestedLoopJoinTest.cpp b/velox/exec/tests/NestedLoopJoinTest.cpp index 611b67cb13a..4baec4a8c43 100644 --- a/velox/exec/tests/NestedLoopJoinTest.cpp +++ b/velox/exec/tests/NestedLoopJoinTest.cpp @@ -100,11 +100,12 @@ class NestedLoopJoinTest : public HiveConnectorTestBase { for (const auto joinType : joinTypes_) { for (const auto& comparison : comparisons_) { - SCOPED_TRACE(fmt::format( - "maxDrivers:{} joinType:{} comparison:{}", - std::to_string(numDrivers), - core::JoinTypeName::toName(joinType), - comparison)); + SCOPED_TRACE( + fmt::format( + "maxDrivers:{} joinType:{} comparison:{}", + std::to_string(numDrivers), + core::JoinTypeName::toName(joinType), + comparison)); params.planNode = PlanBuilder(planNodeIdGenerator) diff --git a/velox/exec/tests/OperatorTraceTest.cpp b/velox/exec/tests/OperatorTraceTest.cpp index 968fcbb6c2b..6ded9826b5d 100644 --- a/velox/exec/tests/OperatorTraceTest.cpp +++ b/velox/exec/tests/OperatorTraceTest.cpp @@ -21,24 +21,28 @@ #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/FileSystems.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/connectors/hive/HiveConnector.h" #include "velox/exec/OperatorTraceReader.h" #include "velox/exec/PartitionFunction.h" #include "velox/exec/Split.h" #include "velox/exec/TaskTraceReader.h" #include "velox/exec/TaskTraceWriter.h" -#include "velox/exec/Trace.h" -#include "velox/exec/TraceUtil.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" +#include "velox/exec/trace/Trace.h" +#include "velox/exec/trace/TraceUtil.h" #include "velox/serializers/PrestoSerializer.h" -using namespace facebook::velox::exec::test; - namespace facebook::velox::exec::trace::test { -class OperatorTraceTest : public HiveConnectorTestBase { +namespace { +using namespace facebook::velox::common::testutil; +using exec::test::assertEqualResults; +using exec::test::AssertQueryBuilder; +using exec::test::PlanBuilder; + +class OperatorTraceTest : public exec::test::HiveConnectorTestBase { protected: static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); @@ -106,6 +110,13 @@ class OperatorTraceTest : public HiveConnectorTestBase { return std::make_unique(nullptr, 0, 0, 0, 0); } + std::string getTaskTraceDirectory( + const std::string& traceDir, + const Task& task) { + return trace::getTaskTraceDirectory( + traceDir, task.queryCtx()->queryId(), task.taskId()); + } + RowTypePtr dataType_; VectorFuzzer vectorFuzzer_; }; @@ -299,16 +310,17 @@ TEST_F(OperatorTraceTest, traceMetadata) { executor_.get(), core::QueryConfig(expectedQueryConfigs), expectedConnectorProperties); - auto writer = trace::TaskTraceMetadataWriter(outputDir->getPath(), pool()); - auto traceNode = getTraceNode(planNode, traceNodeId); - writer.write(queryCtx, traceNode); + auto writer = + trace::TaskTraceMetadataWriter(outputDir->getPath(), traceNodeId, pool()); + writer.write(*queryCtx, *planNode); const auto reader = trace::TaskTraceMetadataReader(outputDir->getPath(), pool()); const auto actualQueryConfigs = reader.queryConfigs(); const auto actualConnectorProperties = reader.connectorProperties(); const auto actualQueryPlan = reader.queryPlan(); - ASSERT_TRUE(isSamePlan(actualQueryPlan, traceNode)); + auto expectedTraceNode = getTraceNode(*planNode, traceNodeId); + ASSERT_TRUE(isSamePlan(actualQueryPlan, expectedTraceNode)); ASSERT_EQ(actualQueryConfigs.size(), expectedQueryConfigs.size()); for (const auto& [key, value] : actualQueryConfigs) { ASSERT_EQ(actualQueryConfigs.at(key), expectedQueryConfigs.at(key)); @@ -424,7 +436,7 @@ TEST_F(OperatorTraceTest, task) { const auto actualQueryPlan = reader.queryPlan(); ASSERT_TRUE( - isSamePlan(actualQueryPlan, getTraceNode(planNode, hashJoinNodeId))); + isSamePlan(actualQueryPlan, getTraceNode(*planNode, hashJoinNodeId))); ASSERT_EQ(actualQueryConfigs.size(), expectedQueryConfigs.size()); for (const auto& [key, value] : actualQueryConfigs) { ASSERT_EQ(actualQueryConfigs.at(key), expectedQueryConfigs.at(key)); @@ -445,7 +457,8 @@ TEST_F(OperatorTraceTest, task) { } TEST_F(OperatorTraceTest, error) { - const auto planNode = PlanBuilder().values({}).planNode(); + const auto planNode = + PlanBuilder().values(std::vector{}).planNode(); // No trace dir. { const auto queryConfigs = std::unordered_map{ @@ -505,8 +518,7 @@ TEST_F(OperatorTraceTest, error) { .queryCtx(queryCtx) .maxDrivers(1) .copyResults(pool()), - - "Trace plan node ID = nonexist not found from task"); + "Trace plan node ID = 'nonexist' not found from task"); } } @@ -846,12 +858,12 @@ TEST_F(OperatorTraceTest, traceSplitPartial) { auto ioBuf = folly::IOBuf::create(12 + 16); folly::io::Appender appender(ioBuf.get(), 0); // Writes an invalid split without crc. - appender.writeLE(length); + appender.writeLE(length); appender.push(reinterpret_cast(split.data()), length); // Writes a valid spilt. - appender.writeLE(length); + appender.writeLE(length); appender.push(reinterpret_cast(split.data()), length); - appender.writeLE(crc32); + appender.writeLE(crc32); splitInfoFile->append(std::move(ioBuf)); splitInfoFile->close(); @@ -935,13 +947,13 @@ TEST_F(OperatorTraceTest, traceSplitCorrupted) { auto ioBuf = folly::IOBuf::create(16 * 2); folly::io::Appender appender(ioBuf.get(), 0); // Writes an invalid split with a wrong checksum. - appender.writeLE(length); + appender.writeLE(length); appender.push(reinterpret_cast(split.data()), length); - appender.writeLE(crc32 - 1); + appender.writeLE(crc32 - 1); // Writes a valid split. - appender.writeLE(length); + appender.writeLE(length); appender.push(reinterpret_cast(split.data()), length); - appender.writeLE(crc32); + appender.writeLE(crc32); splitInfoFile->append(std::move(ioBuf)); splitInfoFile->close(); @@ -1124,12 +1136,15 @@ TEST_F(OperatorTraceTest, canTrace) { {"IndexLookupJoin", true}, {"Unnest", true}, {"RowNumber", false}, - {"OrderBy", false}, + {"OrderBy", true}, + {"TopNRowNumber", true}, {"PartialAggregation", true}, {"Aggregation", true}, {"TableWrite", true}, {"TableScan", true}, - {"FilterProject", true}}; + {"FilterProject", true}, + {"Exchange", true}, + {"MergeExchange", true}}; for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); ASSERT_EQ(testData.canTrace, trace::canTrace(testData.operatorType)); @@ -1162,10 +1177,12 @@ TEST_F(OperatorTraceTest, hiveConnectorId) { .config(core::QueryConfig::kQueryTraceTaskRegExp, ".*") .config(core::QueryConfig::kQueryTraceNodeId, "0") .splits(splits) - .runWithoutResults(task); + .countResults(task); const auto taskTraceDir = getTaskTraceDirectory(traceDirPath->getPath(), *task); const auto reader = trace::TaskTraceMetadataReader(taskTraceDir, pool()); ASSERT_EQ("test-hive", reader.connectorId("0")); } + +} // namespace } // namespace facebook::velox::exec::trace::test diff --git a/velox/exec/tests/OperatorUtilsTest.cpp b/velox/exec/tests/OperatorUtilsTest.cpp index 45da6804d67..937823afd6e 100644 --- a/velox/exec/tests/OperatorUtilsTest.cpp +++ b/velox/exec/tests/OperatorUtilsTest.cpp @@ -15,8 +15,10 @@ */ #include "velox/exec/OperatorUtils.h" #include +#include #include "velox/dwio/common/tests/utils/BatchMaker.h" #include "velox/exec/Operator.h" +#include "velox/exec/tests/utils/MergeTestBase.h" #include "velox/exec/tests/utils/OperatorTestBase.h" #include "velox/vector/fuzzer/VectorFuzzer.h" @@ -48,7 +50,8 @@ class OperatorUtilsTest : public OperatorTestBase { std::move(planFragment), 0, core::QueryCtx::create(executor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); driver_ = Driver::testingCreate(); driverCtx_ = std::make_unique(task_, 0, 0, 0, 0); driverCtx_->driver = driver_.get(); @@ -57,7 +60,8 @@ class OperatorUtilsTest : public OperatorTestBase { void gatherCopyTest( const std::shared_ptr& targetType, const std::shared_ptr& sourceType, - int numSources) { + int numSources, + bool flattenSources = true) { folly::Random::DefaultGenerator rng(1); const int kNumRows = 500; const int kNumColumns = sourceType->size(); @@ -65,8 +69,9 @@ class OperatorUtilsTest : public OperatorTestBase { // Build source vectors with nulls. std::vector sources; for (int i = 0; i < numSources; ++i) { - sources.push_back(std::static_pointer_cast( - BatchMaker::createBatch(sourceType, kNumRows, *pool_))); + sources.push_back( + std::static_pointer_cast( + BatchMaker::createBatch(sourceType, kNumRows, *pool_))); for (int j = 0; j < kNumColumns; ++j) { auto vector = sources.back()->childAt(j); int nullRow = (folly::Random::rand32() % kNumRows) / 4; @@ -78,6 +83,23 @@ class OperatorUtilsTest : public OperatorTestBase { } } + if (!flattenSources) { + for (int i = 0; i < numSources; ++i) { + const auto source = sources[i]; + const auto numRows = source->size(); + std::vector sortIndices(numRows, 0); + for (auto i = 0; i < numRows; ++i) { + sortIndices[i] = i; + } + BufferPtr indices = allocateIndices(numRows, pool_.get()); + auto rawIndices = indices->asMutable(); + for (size_t i = 0; i < numRows; ++i) { + rawIndices[i] = sortIndices[i]; + } + sources[i] = wrap(numRows, indices, source); + } + } + std::vector columnMap; if (sourceType != targetType) { for (column_index_t sourceChannel = 0; sourceChannel < kNumColumns; @@ -133,6 +155,61 @@ class OperatorUtilsTest : public OperatorTestBase { } } + void gatherMergeTest( + int32_t numValues, + int numMergeWays, + int targetSize, + bool useRandom) { + auto goldenVector = makeRowVector({ + makeFlatVector(numValues, [&](auto row) { return row; }), + }); + std::vector> mergeWays(numMergeWays); + for (int32_t value = 0; value < numValues; value++) { + int way = useRandom ? folly::Random::rand32() % numMergeWays + : value % numMergeWays; + mergeWays[way].push_back(value); + } + std::vector sources; + std::vector> streams; + std::vector sortKeys = {{0, {true, true}}}; + for (int way = 0; way < numMergeWays; way++) { + auto source = makeRowVector({ + makeFlatVector( + mergeWays[way].size(), + [&](auto row) { return mergeWays[way][row]; }), + }); + sources.push_back(source); + streams.push_back( + std::make_unique(way, sortKeys, source)); + } + auto mergeTree = + std::make_unique>(std::move(streams)); + RowVectorPtr targetVector = std::static_pointer_cast( + BaseVector::create(sources[0]->type(), targetSize, pool_.get())); + std::vector bufferSources(targetSize); + std::vector bufferSourceIndices(targetSize); + for (int32_t batch = 0; batch * targetSize < numValues; batch++) { + int32_t valueBegin = batch * targetSize; + int32_t valueEnd = valueBegin + targetSize; + valueEnd = std::min(valueEnd, numValues); + VectorPtr tmp = std::move(targetVector); + BaseVector::prepareForReuse(tmp, targetSize); + targetVector = std::static_pointer_cast(tmp); + for (auto& child : targetVector->children()) { + child->resize(targetSize); + } + int count = 0; + testingGatherMerge( + targetVector, *mergeTree, count, bufferSources, bufferSourceIndices); + EXPECT_EQ(count, valueEnd - valueBegin); + auto result = targetVector->childAt(0).get(); + auto golden = goldenVector->childAt(0).get(); + for (int32_t row = 0; row < valueEnd - valueBegin; row++) { + EXPECT_TRUE(result->equalValueAt(golden, row, valueBegin + row)); + } + } + } + void setTaskOutputBatchConfig( uint32_t preferredBatchSize, uint32_t maxRows, @@ -373,6 +450,67 @@ TEST_F(OperatorUtilsTest, gatherCopy) { } } +TEST_F(OperatorUtilsTest, gatherCopyEncoding) { + std::shared_ptr rowType; + std::shared_ptr reversedRowType; + { + std::vector names = { + "bool_val", + "tiny_val", + "small_val", + "int_val", + "long_val", + "ordinal", + "float_val", + "double_val", + "string_val", + "array_val", + "struct_val", + "map_val"}; + std::vector reversedNames = names; + std::reverse(reversedNames.begin(), reversedNames.end()); + + std::vector> types = { + BOOLEAN(), + TINYINT(), + SMALLINT(), + INTEGER(), + BIGINT(), + BIGINT(), + REAL(), + DOUBLE(), + VARCHAR(), + ARRAY(VARCHAR()), + ROW({{"s_int", INTEGER()}, {"s_array", ARRAY(REAL())}}), + MAP(VARCHAR(), + MAP(BIGINT(), + ROW({{"s2_int", INTEGER()}, {"s2_string", VARCHAR()}})))}; + std::vector> reversedTypes = types; + std::reverse(reversedTypes.begin(), reversedTypes.end()); + + rowType = ROW(std::move(names), std::move(types)); + reversedRowType = ROW(std::move(reversedNames), std::move(reversedTypes)); + } + + // Gather copy with identical column mapping. + gatherCopyTest(rowType, rowType, 1, false); + gatherCopyTest(rowType, rowType, 5, false); + // Gather copy with non-identical column mapping. + gatherCopyTest(rowType, reversedRowType, 1, false); + gatherCopyTest(rowType, reversedRowType, 5, false); +} + +TEST_F(OperatorUtilsTest, gatherMerge) { + gatherMergeTest(1234, 2, 10, false); + gatherMergeTest(1234, 2, 100, false); + gatherMergeTest(1234, 10, 10, false); + gatherMergeTest(1234, 10, 100, false); + gatherMergeTest(1234, 2, 10, true); + gatherMergeTest(1234, 2, 100, true); + gatherMergeTest(1234, 10, 10, true); + gatherMergeTest(1234, 10, 100, true); +} + TEST_F(OperatorUtilsTest, makeOperatorSpillPath) { EXPECT_EQ("spill/3_1_100", makeOperatorSpillPath("spill", 3, 1, 100)); } @@ -426,56 +564,56 @@ TEST_F(OperatorUtilsTest, wrap) { TEST_F(OperatorUtilsTest, addOperatorRuntimeStats) { std::unordered_map stats; - const std::string statsName("stats"); + constexpr std::string_view statsName{"stats"}; const RuntimeCounter minStatsValue(100, RuntimeCounter::Unit::kBytes); const RuntimeCounter maxStatsValue(200, RuntimeCounter::Unit::kBytes); addOperatorRuntimeStats(statsName, minStatsValue, stats); - ASSERT_EQ(stats[statsName].count, 1); - ASSERT_EQ(stats[statsName].sum, 100); - ASSERT_EQ(stats[statsName].max, 100); - ASSERT_EQ(stats[statsName].min, 100); + ASSERT_EQ(stats[std::string(statsName)].count, 1); + ASSERT_EQ(stats[std::string(statsName)].sum, 100); + ASSERT_EQ(stats[std::string(statsName)].max, 100); + ASSERT_EQ(stats[std::string(statsName)].min, 100); addOperatorRuntimeStats(statsName, maxStatsValue, stats); - ASSERT_EQ(stats[statsName].count, 2); - ASSERT_EQ(stats[statsName].sum, 300); - ASSERT_EQ(stats[statsName].max, 200); - ASSERT_EQ(stats[statsName].min, 100); + ASSERT_EQ(stats[std::string(statsName)].count, 2); + ASSERT_EQ(stats[std::string(statsName)].sum, 300); + ASSERT_EQ(stats[std::string(statsName)].max, 200); + ASSERT_EQ(stats[std::string(statsName)].min, 100); addOperatorRuntimeStats(statsName, maxStatsValue, stats); - ASSERT_EQ(stats[statsName].count, 3); - ASSERT_EQ(stats[statsName].sum, 500); - ASSERT_EQ(stats[statsName].max, 200); - ASSERT_EQ(stats[statsName].min, 100); + ASSERT_EQ(stats[std::string(statsName)].count, 3); + ASSERT_EQ(stats[std::string(statsName)].sum, 500); + ASSERT_EQ(stats[std::string(statsName)].max, 200); + ASSERT_EQ(stats[std::string(statsName)].min, 100); } TEST_F(OperatorUtilsTest, setOperatorRuntimeStats) { std::unordered_map stats; - const std::string statsName("stats"); + constexpr std::string_view statsName{"stats"}; const RuntimeCounter minStatsValue(100, RuntimeCounter::Unit::kBytes); const RuntimeCounter maxStatsValue(200, RuntimeCounter::Unit::kBytes); setOperatorRuntimeStats(statsName, minStatsValue, stats); - ASSERT_EQ(stats[statsName].count, 1); - ASSERT_EQ(stats[statsName].sum, 100); - ASSERT_EQ(stats[statsName].max, 100); - ASSERT_EQ(stats[statsName].min, 100); + ASSERT_EQ(stats[std::string(statsName)].count, 1); + ASSERT_EQ(stats[std::string(statsName)].sum, 100); + ASSERT_EQ(stats[std::string(statsName)].max, 100); + ASSERT_EQ(stats[std::string(statsName)].min, 100); setOperatorRuntimeStats(statsName, maxStatsValue, stats); - ASSERT_EQ(stats[statsName].count, 1); - ASSERT_EQ(stats[statsName].sum, 200); - ASSERT_EQ(stats[statsName].max, 200); - ASSERT_EQ(stats[statsName].min, 200); + ASSERT_EQ(stats[std::string(statsName)].count, 1); + ASSERT_EQ(stats[std::string(statsName)].sum, 200); + ASSERT_EQ(stats[std::string(statsName)].max, 200); + ASSERT_EQ(stats[std::string(statsName)].min, 200); addOperatorRuntimeStats(statsName, maxStatsValue, stats); - ASSERT_EQ(stats[statsName].count, 2); - ASSERT_EQ(stats[statsName].sum, 400); - ASSERT_EQ(stats[statsName].max, 200); - ASSERT_EQ(stats[statsName].min, 200); + ASSERT_EQ(stats[std::string(statsName)].count, 2); + ASSERT_EQ(stats[std::string(statsName)].sum, 400); + ASSERT_EQ(stats[std::string(statsName)].max, 200); + ASSERT_EQ(stats[std::string(statsName)].min, 200); setOperatorRuntimeStats(statsName, minStatsValue, stats); - ASSERT_EQ(stats[statsName].count, 1); - ASSERT_EQ(stats[statsName].sum, 100); - ASSERT_EQ(stats[statsName].max, 100); - ASSERT_EQ(stats[statsName].min, 100); + ASSERT_EQ(stats[std::string(statsName)].count, 1); + ASSERT_EQ(stats[std::string(statsName)].sum, 100); + ASSERT_EQ(stats[std::string(statsName)].max, 100); + ASSERT_EQ(stats[std::string(statsName)].min, 100); } TEST_F(OperatorUtilsTest, initializeRowNumberMapping) { diff --git a/velox/exec/tests/OrderByTest.cpp b/velox/exec/tests/OrderByTest.cpp index c343dbe8020..685660511a3 100644 --- a/velox/exec/tests/OrderByTest.cpp +++ b/velox/exec/tests/OrderByTest.cpp @@ -16,10 +16,11 @@ #include #include -#include "folly/experimental/EventCount.h" +#include "folly/synchronization/EventCount.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/FileSystems.h" #include "velox/common/memory/tests/SharedArbitratorTestUtil.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/common/testutil/TestValue.h" #include "velox/core/QueryConfig.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" @@ -30,7 +31,6 @@ #include "velox/exec/tests/utils/OperatorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/exec/tests/utils/QueryAssertions.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/vector/fuzzer/VectorFuzzer.h" using namespace facebook::velox; @@ -42,8 +42,8 @@ using namespace facebook::velox::exec::test; namespace facebook::velox::exec::test { namespace { // Returns aggregated spilled stats by 'task'. -common::SpillStats spilledStats(const exec::Task& task) { - common::SpillStats spilledStats; +exec::SpillStats spilledStats(const exec::Task& task) { + exec::SpillStats spilledStats; auto stats = task.taskStats(); for (auto& pipeline : stats.pipelineStats) { for (auto op : pipeline.operatorStats) { @@ -194,7 +194,7 @@ class OrderByTest : public OperatorTestBase { } { SCOPED_TRACE("run with spilling"); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); TestScopedSpillInjection scopedSpillInjection(100); queryCtx->testingOverrideConfigUnsafe({ @@ -495,7 +495,7 @@ TEST_F(OrderByTest, spill) { const auto expectedResult = AssertQueryBuilder(plan).copyResults(pool_.get()); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); TestScopedSpillInjection scopedSpillInjection(100); auto task = AssertQueryBuilder(plan) .spillDirectory(spillDirectory->getPath()) @@ -510,20 +510,29 @@ TEST_F(OrderByTest, spill) { ASSERT_GT(planStats.spilledInputBytes, 0); ASSERT_EQ(planStats.spilledPartitions, 1); ASSERT_GT(planStats.spilledFiles, 0); - ASSERT_GT(planStats.customStats[Operator::kSpillRuns].count, 0); - ASSERT_GT(planStats.customStats[Operator::kSpillFillTime].sum, 0); - ASSERT_GT(planStats.customStats[Operator::kSpillSortTime].sum, 0); - ASSERT_GT(planStats.customStats[Operator::kSpillExtractVectorTime].sum, 0); - ASSERT_GT(planStats.customStats[Operator::kSpillSerializationTime].sum, 0); - ASSERT_GT(planStats.customStats[Operator::kSpillFlushTime].sum, 0); + ASSERT_GT(planStats.customStats[std::string(Operator::kSpillRuns)].count, 0); + ASSERT_GT( + planStats.customStats[std::string(Operator::kSpillFillTime)].sum, 0); + ASSERT_GT( + planStats.customStats[std::string(Operator::kSpillSortTime)].sum, 0); + ASSERT_GT( + planStats.customStats[std::string(Operator::kSpillExtractVectorTime)].sum, + 0); + ASSERT_GT( + planStats.customStats[std::string(Operator::kSpillSerializationTime)].sum, + 0); + ASSERT_GT( + planStats.customStats[std::string(Operator::kSpillFlushTime)].sum, 0); ASSERT_EQ( - planStats.customStats[Operator::kSpillSerializationTime].count, - planStats.customStats[Operator::kSpillFlushTime].count); - ASSERT_GT(planStats.customStats[Operator::kSpillWrites].sum, 0); - ASSERT_GT(planStats.customStats[Operator::kSpillWriteTime].sum, 0); + planStats.customStats[std::string(Operator::kSpillSerializationTime)] + .count, + planStats.customStats[std::string(Operator::kSpillFlushTime)].count); + ASSERT_GT(planStats.customStats[std::string(Operator::kSpillWrites)].sum, 0); + ASSERT_GT( + planStats.customStats[std::string(Operator::kSpillWriteTime)].sum, 0); ASSERT_EQ( - planStats.customStats[Operator::kSpillWrites].count, - planStats.customStats[Operator::kSpillWriteTime].count); + planStats.customStats[std::string(Operator::kSpillWrites)].count, + planStats.customStats[std::string(Operator::kSpillWriteTime)].count); OperatorTestBase::deleteTaskAndCheckSpillDirectory(task); } @@ -556,10 +565,11 @@ DEBUG_ONLY_TEST_F(OrderByTest, reclaimDuringInputProcessing) { for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); auto expectedResult = AssertQueryBuilder( PlanBuilder() @@ -698,10 +708,11 @@ DEBUG_ONLY_TEST_F(OrderByTest, reclaimDuringReserve) { batches.push_back(fuzzer.fuzzRow(rowType)); } - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); auto expectedResult = AssertQueryBuilder( PlanBuilder() @@ -814,7 +825,7 @@ DEBUG_ONLY_TEST_F(OrderByTest, reclaimDuringAllocation) { const std::vector enableSpillings = {false, true}; for (const auto enableSpilling : enableSpillings) { SCOPED_TRACE(fmt::format("enableSpilling {}", enableSpilling)); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); queryCtx->testingOverrideMemoryPool( memory::memoryManager()->addRootPool(queryCtx->queryId(), kMaxBytes)); @@ -944,10 +955,11 @@ DEBUG_ONLY_TEST_F(OrderByTest, reclaimDuringOutputProcessing) { const std::vector enableSpillings = {false, true}; for (const auto enableSpilling : enableSpillings) { SCOPED_TRACE(fmt::format("enableSpilling {}", enableSpilling)); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); auto expectedResult = AssertQueryBuilder( PlanBuilder() @@ -1226,7 +1238,7 @@ DEBUG_ONLY_TEST_F(OrderByTest, spillWithNoMoreOutput) { ASSERT_EQ(reclaimerStats_.reclaimedBytes, 0); }))); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); auto task = AssertQueryBuilder(plan) .spillDirectory(spillDirectory->getPath()) @@ -1259,7 +1271,7 @@ TEST_F(OrderByTest, maxSpillBytes) { .orderBy({fmt::format("{} ASC NULLS LAST", "c0")}, false) .capturePlanNodeId(orderNodeId) .planNode(); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); struct { @@ -1312,18 +1324,19 @@ DEBUG_ONLY_TEST_F(OrderByTest, reclaimFromOrderBy) { memory::testingRunArbitration(); }))); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); core::PlanNodeId orderById; auto task = AssertQueryBuilder(duckDbQueryRunner_) .spillDirectory(spillDirectory->getPath()) .config(core::QueryConfig::kSpillEnabled, true) .config(core::QueryConfig::kOrderBySpillEnabled, true) - .plan(PlanBuilder() - .values(vectors) - .orderBy({"c0 ASC NULLS LAST"}, false) - .capturePlanNodeId(orderById) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .orderBy({"c0 ASC NULLS LAST"}, false) + .capturePlanNodeId(orderById) + .planNode()) .assertResults("SELECT * FROM tmp ORDER BY c0 ASC NULLS LAST"); auto taskStats = exec::toPlanStats(task->taskStats()); auto& planStats = taskStats.at(orderById); @@ -1351,16 +1364,17 @@ DEBUG_ONLY_TEST_F(OrderByTest, reclaimFromEmptyOrderBy) { testingRunArbitration(op->pool()); }))); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); auto task = AssertQueryBuilder(duckDbQueryRunner_) .spillDirectory(spillDirectory->getPath()) .config(core::QueryConfig::kSpillEnabled, true) .config(core::QueryConfig::kOrderBySpillEnabled, true) - .plan(PlanBuilder() - .values(vectors) - .orderBy({"c0 ASC NULLS LAST"}, false) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .orderBy({"c0 ASC NULLS LAST"}, false) + .planNode()) .assertResults("SELECT * FROM tmp ORDER BY c0 ASC NULLS LAST"); // Verify no spill has been triggered. const auto stats = task->taskStats().pipelineStats; @@ -1376,8 +1390,9 @@ DEBUG_ONLY_TEST_F(OrderByTest, orderByWithLazyInput) { VectorFuzzer(fuzzerOpts_, pool()).fuzzRowChildrenToLazy(nonLazyVector)); std::vector lazyInputCopy; - lazyInputCopy.push_back(std::dynamic_pointer_cast( - nonLazyVector->testingCopyPreserveEncodings())); + lazyInputCopy.push_back( + std::dynamic_pointer_cast( + nonLazyVector->testingCopyPreserveEncodings())); createDuckDbTable(lazyInputCopy); std::atomic_bool nonReclaimableSectionEntered{false}; @@ -1398,16 +1413,17 @@ DEBUG_ONLY_TEST_F(OrderByTest, orderByWithLazyInput) { } }))); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); auto task = AssertQueryBuilder(duckDbQueryRunner_) .spillDirectory(spillDirectory->getPath()) .config(core::QueryConfig::kSpillEnabled, true) .config(core::QueryConfig::kOrderBySpillEnabled, true) - .plan(PlanBuilder() - .values(lazyInput) - .orderBy({"c0 ASC NULLS LAST"}, false) - .planNode()) + .plan( + PlanBuilder() + .values(lazyInput) + .orderBy({"c0 ASC NULLS LAST"}, false) + .planNode()) .assertResults("SELECT * FROM tmp ORDER BY c0 ASC NULLS LAST"); ASSERT_TRUE(lazyLoadedInNonReclaimableSection.has_value()); diff --git a/velox/exec/tests/OutputBufferManagerTest.cpp b/velox/exec/tests/OutputBufferManagerTest.cpp index f72957c9281..8855b4332b9 100644 --- a/velox/exec/tests/OutputBufferManagerTest.cpp +++ b/velox/exec/tests/OutputBufferManagerTest.cpp @@ -14,8 +14,9 @@ * limitations under the License. */ #include "velox/exec/OutputBufferManager.h" +#include #include -#include "folly/experimental/EventCount.h" +#include "folly/synchronization/EventCount.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" #include "velox/exec/Task.h" @@ -33,23 +34,25 @@ using facebook::velox::test::BatchMaker; struct TestParam { PartitionedOutputNode::Kind outputKind; - VectorSerde::Kind serdeKind; + std::string serdeKind; - TestParam( - PartitionedOutputNode::Kind _outputKind, - VectorSerde::Kind _serdeKind) + TestParam(PartitionedOutputNode::Kind _outputKind, std::string _serdeKind) : outputKind(_outputKind), serdeKind(_serdeKind) {} }; +inline void PrintTo(const TestParam& param, std::ostream* os) { + *os << fmt::format("{}_{}", param.outputKind, param.serdeKind); +} + class OutputBufferManagerTest : public testing::Test { protected: - OutputBufferManagerTest() : serdeKind_(VectorSerde::Kind::kPresto) { + OutputBufferManagerTest() : serdeKind_("Presto") { std::vector names = {"c0", "c1"}; std::vector types = {BIGINT(), VARCHAR()}; rowType_ = ROW(std::move(names), std::move(types)); } - explicit OutputBufferManagerTest(VectorSerde::Kind serdeKind) + explicit OutputBufferManagerTest(std::string serdeKind) : serdeKind_(serdeKind) { std::vector names = {"c0", "c1"}; std::vector types = {BIGINT(), VARCHAR()}; @@ -71,14 +74,14 @@ class OutputBufferManagerTest : public testing::Test { serializer::presto::PrestoOutputStreamListener>(); }); } - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kPresto)) { + if (!isRegisteredNamedVectorSerde("Presto")) { facebook::velox::serializer::presto::PrestoVectorSerde:: registerNamedVectorSerde(); } - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kCompactRow)) { + if (!isRegisteredNamedVectorSerde("CompactRow")) { serializer::CompactRowVectorSerde::registerNamedVectorSerde(); } - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kUnsafeRow)) { + if (!isRegisteredNamedVectorSerde("UnsafeRow")) { serializer::spark::UnsafeRowVectorSerde::registerNamedVectorSerde(); } } @@ -109,13 +112,14 @@ class OutputBufferManagerTest : public testing::Test { std::move(planFragment), 0, std::move(queryCtx), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); bufferManager_->initializeTask(task, kind, numDestinations, numDrivers); return task; } - std::unique_ptr makeSerializedPage( + std::unique_ptr makeSerializedPage( RowTypePtr rowType, vector_size_t size) { auto vector = std::dynamic_pointer_cast( @@ -375,8 +379,9 @@ class OutputBufferManagerTest : public testing::Test { std::vector> pages, int64_t inSequence, std::vector remainingBytes) { - promise.setValue(Response{ - std::move(pages), inSequence, std::move(remainingBytes)}); + promise.setValue( + Response{ + std::move(pages), inSequence, std::move(remainingBytes)}); }); future.wait(); ASSERT_TRUE(future.isReady()); @@ -432,10 +437,10 @@ class OutputBufferManagerTest : public testing::Test { } } - const VectorSerde::Kind serdeKind_; + const std::string serdeKind_; std::shared_ptr executor_{ std::make_shared( - std::thread::hardware_concurrency())}; + folly::available_concurrency())}; std::shared_ptr pool_; std::shared_ptr bufferManager_; RowTypePtr rowType_; @@ -443,13 +448,11 @@ class OutputBufferManagerTest : public testing::Test { class OutputBufferManagerWithDifferentSerdeKindsTest : public OutputBufferManagerTest, - public testing::WithParamInterface { + public testing::WithParamInterface { public: - static std::vector getTestParams() { - static std::vector params = { - VectorSerde::Kind::kPresto, - VectorSerde::Kind::kCompactRow, - VectorSerde::Kind::kUnsafeRow}; + static std::vector getTestParams() { + static std::vector params = { + "Presto", "CompactRow", "UnsafeRow"}; return params; } }; @@ -460,21 +463,15 @@ class AllOutputBufferManagerTest public: static std::vector getTestParams() { static std::vector params = { - {PartitionedOutputNode::Kind::kBroadcast, VectorSerde::Kind::kPresto}, - {PartitionedOutputNode::Kind::kBroadcast, - VectorSerde::Kind::kCompactRow}, - {PartitionedOutputNode::Kind::kBroadcast, - VectorSerde::Kind::kUnsafeRow}, - {PartitionedOutputNode::Kind::kPartitioned, VectorSerde::Kind::kPresto}, - {PartitionedOutputNode::Kind::kPartitioned, - VectorSerde::Kind::kCompactRow}, - {PartitionedOutputNode::Kind::kPartitioned, - VectorSerde::Kind::kUnsafeRow}, - {PartitionedOutputNode::Kind::kArbitrary, VectorSerde::Kind::kPresto}, - {PartitionedOutputNode::Kind::kArbitrary, - VectorSerde::Kind::kCompactRow}, - {PartitionedOutputNode::Kind::kArbitrary, - VectorSerde::Kind::kUnsafeRow}}; + {PartitionedOutputNode::Kind::kBroadcast, "Presto"}, + {PartitionedOutputNode::Kind::kBroadcast, "CompactRow"}, + {PartitionedOutputNode::Kind::kBroadcast, "UnsafeRow"}, + {PartitionedOutputNode::Kind::kPartitioned, "Presto"}, + {PartitionedOutputNode::Kind::kPartitioned, "CompactRow"}, + {PartitionedOutputNode::Kind::kPartitioned, "UnsafeRow"}, + {PartitionedOutputNode::Kind::kArbitrary, "Presto"}, + {PartitionedOutputNode::Kind::kArbitrary, "CompactRow"}, + {PartitionedOutputNode::Kind::kArbitrary, "UnsafeRow"}}; return params; } @@ -1470,7 +1467,7 @@ TEST_P( std::memcpy(iobuf->writableData(), payload.data(), payloadSize); iobuf->append(payloadSize); - auto page = std::make_unique(std::move(iobuf)); + auto page = std::make_unique(std::move(iobuf)); auto queue = std::make_shared(1, 0); std::vector promises; diff --git a/velox/exec/tests/PartitionedOutputTest.cpp b/velox/exec/tests/PartitionedOutputTest.cpp index be921e80c91..fb4f4492159 100644 --- a/velox/exec/tests/PartitionedOutputTest.cpp +++ b/velox/exec/tests/PartitionedOutputTest.cpp @@ -23,15 +23,11 @@ namespace facebook::velox::exec::test { -class PartitionedOutputTest - : public OperatorTestBase, - public testing::WithParamInterface { +class PartitionedOutputTest : public OperatorTestBase, + public testing::WithParamInterface { public: - static std::vector getTestParams() { - const std::vector kinds( - {VectorSerde::Kind::kPresto, - VectorSerde::Kind::kCompactRow, - VectorSerde::Kind::kUnsafeRow}); + static std::vector getTestParams() { + const std::vector kinds({"Presto", "CompactRow", "UnsafeRow"}); return kinds; } @@ -159,10 +155,15 @@ TEST_P(PartitionedOutputTest, flush) { auto planStats = toPlanStats(task->taskStats()); const auto serdeKindRuntimsStats = - planStats.at(partitionNodeId).customStats.at(Operator::kShuffleSerdeKind); + planStats.at(partitionNodeId) + .customStats.at(std::string(Operator::kShuffleSerdeKind)); ASSERT_EQ(serdeKindRuntimsStats.count, 1); - ASSERT_EQ(serdeKindRuntimsStats.min, static_cast(GetParam())); - ASSERT_EQ(serdeKindRuntimsStats.max, static_cast(GetParam())); + ASSERT_EQ( + serdeKindRuntimsStats.min, + static_cast(VectorSerde::kindByName(GetParam()))); + ASSERT_EQ( + serdeKindRuntimsStats.max, + static_cast(VectorSerde::kindByName(GetParam()))); } TEST_P(PartitionedOutputTest, keyChannelNotAtBeginningWithNulls) { @@ -206,6 +207,81 @@ TEST_P(PartitionedOutputTest, keyChannelNotAtBeginningWithNulls) { .count())); } +// This test verifies that the Destination properly handles multiple +// flush-then-append cycles. After flush(), the VectorStreamGroup must be +// properly reset so that subsequent advance() calls create a fresh serializer +// with proper initialization via createStreamTree(). This test exercises +// the fix for T254261397 where crashes occurred due to improper state after +// flush when current_->clear() was called instead of current_.reset(). +TEST_P(PartitionedOutputTest, multipleFlushCycles) { + // Create input data where each row is large enough to trigger a flush + // (exceeds kMinDestinationSize), but we have many batches to ensure + // multiple flush-then-advance cycles occur for the same destination. + const auto largeString = + std::string(PartitionedOutput::kMinDestinationSize * 2, 'x'); + + auto input = makeRowVector( + {"p1", "v1"}, + {// All rows go to partition 0 to ensure multiple flushes on same dest. + makeFlatVector({0, 0, 0, 0}), + makeFlatVector( + {largeString, largeString, largeString, largeString})}); + + core::PlanNodeId partitionNodeId; + // Use 20 batches to ensure many flush cycles (each row triggers a flush). + auto plan = PlanBuilder() + .values({input}, false, 20) + .partitionedOutput( + {"p1"}, 2, std::vector{"v1"}, GetParam()) + .capturePlanNodeId(partitionNodeId) + .planNode(); + + auto taskId = "local://test-partitioned-output-multiple-flush-cycles-0"; + auto task = Task::create( + taskId, + core::PlanFragment{plan}, + 0, + createQueryContext( + {{core::QueryConfig::kMaxPartitionedOutputBufferSize, + std::to_string(PartitionedOutput::kMinDestinationSize * 2)}}), + Task::ExecutionMode::kParallel); + task->start(1); + + const auto partition0 = getAllData(taskId, 0); + const auto partition1 = getAllData(taskId, 1); + + const auto taskWaitUs = std::chrono::duration_cast( + std::chrono::seconds{10}) + .count(); + auto future = task->taskCompletionFuture() + .within(std::chrono::microseconds(taskWaitUs)) + .via(executor_.get()); + future.wait(); + + ASSERT_TRUE(waitForTaskDriversToFinish(task.get(), taskWaitUs)); + + // With 20 batches * 4 rows per batch = 80 rows going to partition 0. + // Each row exceeds the flush threshold, so we expect many pages (~80). + // The exact count may vary due to targetSizePct randomization, but we + // should have at least 40 pages (assuming some batching). + ASSERT_GE(partition0.size(), 40); + + // Partition 1 should have no data (or just the final flush marker). + ASSERT_LE(partition1.size(), 1); + + auto planStats = toPlanStats(task->taskStats()); + const auto serdeKindRuntimsStats = + planStats.at(partitionNodeId) + .customStats.at(std::string(Operator::kShuffleSerdeKind)); + ASSERT_EQ(serdeKindRuntimsStats.count, 1); + ASSERT_EQ( + serdeKindRuntimsStats.min, + static_cast(VectorSerde::kindByName(GetParam()))); + ASSERT_EQ( + serdeKindRuntimsStats.max, + static_cast(VectorSerde::kindByName(GetParam()))); +} + VELOX_INSTANTIATE_TEST_SUITE_P( PartitionedOutputTest, PartitionedOutputTest, diff --git a/velox/exec/tests/PlanBuilderTest.cpp b/velox/exec/tests/PlanBuilderTest.cpp index ca2a2c073f7..3366cb71a8a 100644 --- a/velox/exec/tests/PlanBuilderTest.cpp +++ b/velox/exec/tests/PlanBuilderTest.cpp @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/core/Expressions.h" #include "velox/exec/WindowFunction.h" +#include "velox/exec/tests/utils/ExpressionBuilder.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/TestIndexStorageConnector.h" #include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" @@ -255,51 +258,75 @@ TEST_F(PlanBuilderTest, missingOutputType) { } TEST_F(PlanBuilderTest, projectExpressions) { + using namespace velox::expr_builder; + // Non-typed Expressions. // Simple field access. auto data = ROW({"c0"}, {BIGINT()}); VELOX_CHECK_EQ( PlanBuilder() .tableScan("tmp", data) - .projectExpressions( - {std::make_shared("c0", std::nullopt)}) + .projectExpressions({col("c0")}) .planNode() ->toString(true, false), "-- Project[1][expressions: (c0:BIGINT, ROW[\"c0\"])] -> c0:BIGINT\n"); + // Dereference test using field access query. data = ROW({"c0"}, {ROW({"field0"}, {BIGINT()})}); VELOX_CHECK_EQ( PlanBuilder() .tableScan("tmp", data) - .projectExpressions({std::make_shared( - "field0", - std::nullopt, - std::vector{ - std::make_shared( - "c0", std::nullopt)})}) + .projectExpressions({col("c0").subfield("field0")}) .planNode() ->toString(true, false), "-- Project[1][expressions: (field0:BIGINT, ROW[\"c0\"][field0])] -> field0:BIGINT\n"); // Test Typed Expressions + auto rowType = ROW({"c0"}, {VARCHAR()}); VELOX_CHECK_EQ( PlanBuilder() - .tableScan("tmp", ROW({"c0"}, {ROW({VARCHAR()})})) + .tableScan("tmp", rowType) .projectExpressions( - {std::make_shared(VARCHAR(), "c0")}) + {core::Expressions::inferTypes(col("c0"), rowType, pool_.get())}) + .planNode() + ->toString(true, false), + "-- Project[1][expressions: (p0:VARCHAR, ROW[\"c0\"])] -> p0:VARCHAR\n"); + + rowType = ROW({"c0"}, {ROW({"field0"}, {VARCHAR()})}); + VELOX_CHECK_EQ( + PlanBuilder() + .tableScan("tmp", rowType) + .projectExpressions({core::Expressions::inferTypes( + col("c0").subfield("field0"), rowType, pool_.get())}) + .planNode() + ->toString(true, false), + "-- Project[1][expressions: (p0:VARCHAR, ROW[\"c0\"][field0])] -> p0:VARCHAR\n"); +} + +TEST_F(PlanBuilderTest, filter) { + auto data = ROW({"c0"}, {BIGINT()}); + constexpr std::string_view expectation = + "-- Filter[1][expression: gt(plus(ROW[\"c0\"],10),100)] -> c0:BIGINT\n"; + + // Filter with SQL snippet. + VELOX_CHECK_EQ( + PlanBuilder() + .tableScan("tmp", data) + .filter("c0 + 10 > 100") .planNode() ->toString(true, false), - "-- Project[1][expressions: (p0:VARCHAR, \"c0\")] -> p0:VARCHAR\n"); + expectation); + + using namespace velox::expr_builder; + + // Filter with untyped expression (same expression as above). VELOX_CHECK_EQ( PlanBuilder() - .tableScan("tmp", ROW({"c0"}, {ROW({VARCHAR()})})) - .projectExpressions({std::make_shared( - VARCHAR(), - std::make_shared(VARCHAR(), "c0"), - "field0")}) + .tableScan("tmp", data) + .filter(col("c0") + 10L > 100L) .planNode() ->toString(true, false), - "-- Project[1][expressions: (p0:VARCHAR, \"c0\"[\"field0\"])] -> p0:VARCHAR\n"); + expectation); } TEST_F(PlanBuilderTest, commitStrategyParameter) { @@ -374,9 +401,10 @@ TEST_F(PlanBuilderTest, indexLookupJoinBuilder) { .rightKeys({"u0"}) .indexSource(rightScan) .joinConditions({"contains(t1, u1)"}) - .includeMatchColumn(false) + .hasMarker(false) .outputLayout({"t0", "u1"}) .joinType(core::JoinType::kInner) + .filter("t0 > 0") .endIndexLookupJoin() .planNode(); @@ -389,7 +417,61 @@ TEST_F(PlanBuilderTest, indexLookupJoinBuilder) { ASSERT_EQ(indexJoinNode->leftKeys()[0]->name(), "t0"); ASSERT_EQ(indexJoinNode->rightKeys()[0]->name(), "u0"); ASSERT_EQ(indexJoinNode->joinConditions().size(), 1); - ASSERT_FALSE(indexJoinNode->includeMatchColumn()); + ASSERT_FALSE(indexJoinNode->hasMarker()); + ASSERT_EQ(indexJoinNode->outputType()->names().size(), 2); + ASSERT_EQ(indexJoinNode->outputType()->names()[0], "t0"); + ASSERT_EQ(indexJoinNode->outputType()->names()[1], "u1"); + ASSERT_EQ(indexJoinNode->filter()->toString(), "gt(ROW[\"t0\"],0)"); +} + +TEST_F(PlanBuilderTest, insertTableHandleParameter) { + auto data = makeRowVector({makeFlatVector(10, folly::identity)}); + auto directory = "/some/test/directory"; + + // Lambda to create a plan with given insertableHandle and verify it + auto testInsertTableHandle = + [&](std::shared_ptr insertTableHandle) { + // Create a plan with insertTableHandle + auto planBuilder = PlanBuilder().values({data}).tableWrite( + directory, + {}, + 0, + {}, + {}, + dwio::common::FileFormat::DWRF, + {}, + PlanBuilder::kHiveDefaultConnectorId, + {}, + nullptr, + "", + common::CompressionKind_NONE, + nullptr, + false, + connector::CommitStrategy::kNoCommit, + insertTableHandle); + + // Verify the plan node has the correct insert Table Handle. + auto tableWriteNode = + std::dynamic_pointer_cast( + planBuilder.planNode()); + ASSERT_NE(tableWriteNode, nullptr); + ASSERT_EQ(tableWriteNode->insertTableHandle(), insertTableHandle); + }; + + auto rowType = ROW({"c0", "c1", "c2"}, {BIGINT(), INTEGER(), SMALLINT()}); + auto hiveHandle = HiveConnectorTestBase::makeHiveInsertTableHandle( + rowType->names(), + rowType->children(), + {rowType->names()[0]}, // partitionedBy + nullptr, // bucketProperty + HiveConnectorTestBase::makeLocationHandle( + "/path/to/test", + std::nullopt, + connector::hive::LocationHandle::TableType::kNew)); + + auto insertHandle = std::make_shared( + std::string(PlanBuilder::kHiveDefaultConnectorId), hiveHandle); + testInsertTableHandle(insertHandle); } } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/PlanNodeSerdeTest.cpp b/velox/exec/tests/PlanNodeSerdeTest.cpp index 6802958dd20..92652b0ad4b 100644 --- a/velox/exec/tests/PlanNodeSerdeTest.cpp +++ b/velox/exec/tests/PlanNodeSerdeTest.cpp @@ -16,10 +16,12 @@ #include "velox/common/base/tests/GTestUtils.h" #include "velox/connectors/hive/HiveConnector.h" #include "velox/exec/PartitionFunction.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" #include "velox/parse/TypeResolver.h" +#include "velox/type/TypeSerde.h" #include "velox/vector/tests/utils/VectorTestBase.h" #include @@ -176,6 +178,14 @@ TEST_F(PlanNodeSerdeTest, markDistinct) { testSerde(plan); } +TEST_F(PlanNodeSerdeTest, enforceDistinct) { + auto plan = PlanBuilder() + .values({data_}) + .enforceDistinct({"c0", "c1", "c2"}, "Test error message") + .planNode(); + testSerde(plan); +} + TEST_F(PlanNodeSerdeTest, nestedLoopJoin) { auto left = makeRowVector( {"t0", "t1", "t2"}, @@ -223,10 +233,8 @@ TEST_F(PlanNodeSerdeTest, enforceSingleRow) { } TEST_F(PlanNodeSerdeTest, exchange) { - for (auto serdeKind : std::vector{ - VectorSerde::Kind::kPresto, - VectorSerde::Kind::kCompactRow, - VectorSerde::Kind::kUnsafeRow}) { + for (auto serdeKind : + std::vector{"Presto", "CompactRow", "UnsafeRow"}) { SCOPED_TRACE(fmt::format("serdeKind: {}", serdeKind)); auto plan = PlanBuilder() .exchange( @@ -310,10 +318,8 @@ TEST_F(PlanNodeSerdeTest, limit) { } TEST_F(PlanNodeSerdeTest, mergeExchange) { - for (auto serdeKind : std::vector{ - VectorSerde::Kind::kPresto, - VectorSerde::Kind::kCompactRow, - VectorSerde::Kind::kUnsafeRow}) { + for (auto serdeKind : + std::vector{"Presto", "CompactRow", "UnsafeRow"}) { auto plan = PlanBuilder() .mergeExchange( ROW({"a", "b", "c"}, {BIGINT(), DOUBLE(), VARCHAR()}), @@ -420,10 +426,8 @@ TEST_F(PlanNodeSerdeTest, orderBy) { } TEST_F(PlanNodeSerdeTest, partitionedOutput) { - for (auto serdeKind : std::vector{ - VectorSerde::Kind::kPresto, - VectorSerde::Kind::kCompactRow, - VectorSerde::Kind::kUnsafeRow}) { + for (auto serdeKind : + std::vector{"Presto", "CompactRow", "UnsafeRow"}) { SCOPED_TRACE(fmt::format("serdeKind: {}", serdeKind)); auto plan = PlanBuilder() @@ -534,6 +538,22 @@ TEST_F(PlanNodeSerdeTest, hashJoin) { .planNode(); testSerde(plan); + + // nullAsValue join (used for EXCEPT/INTERSECT). + plan = PlanBuilder(planNodeIdGenerator) + .values({probe}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({build}).planNode(), + "", + {"t0", "t1"}, + core::JoinType::kAnti, + /*nullAware=*/false, + /*nullAsValue=*/true) + .planNode(); + + testSerde(plan); } TEST_F(PlanNodeSerdeTest, topN) { @@ -656,14 +676,44 @@ TEST_F(PlanNodeSerdeTest, rowNumber) { } TEST_F(PlanNodeSerdeTest, scan) { - auto plan = PlanBuilder(pool_.get()) - .tableScan( - ROW({"a", "b", "c", "d"}, - {BIGINT(), BIGINT(), BOOLEAN(), DOUBLE()}), - {"a < 5", "b = 7", "c = true", "d > 0.01"}, - "a + b < 100") - .planNode(); - testSerde(plan); + { + auto plan = PlanBuilder(pool_.get()) + .tableScan( + ROW({"a", "b", "c", "d"}, + {BIGINT(), BIGINT(), BOOLEAN(), DOUBLE()}), + {"a < 5", "b = 7", "c = true", "d > 0.01"}, + "a + b < 100") + .planNode(); + testSerde(plan); + } + + { + auto plan = + PlanBuilder() + .startTableScan() + .outputType(ROW({"x"}, {BIGINT()})) + .assignments( + {{"x", HiveConnectorTestBase::regularColumn("a", BIGINT())}}) + .dataColumns(ROW({"a", "b"}, {BIGINT(), BIGINT()})) + .filterColumnHandles({ + HiveConnectorTestBase::partitionKey("ds", VARCHAR()), + HiveConnectorTestBase::regularColumn("a", BIGINT()), + }) + .remainingFilter("length(ds) + a % 2 > 0") + .endTableScan() + .planNode(); + testSerde(plan); + } + + { + auto plan = PlanBuilder() + .startTableScan() + .outputType(ROW({"x"}, {BIGINT()})) + .sampleRate(0.5) + .endTableScan() + .planNode(); + testSerde(plan); + } } TEST_F(PlanNodeSerdeTest, topNRowNumber) { @@ -964,4 +1014,27 @@ TEST_F(PlanNodeSerdeTest, columnStatsSpec) { } } +TEST_F(PlanNodeSerdeTest, countingJoin) { + for (auto joinType : + {core::JoinType::kCountingAnti, + core::JoinType::kCountingLeftSemiFilter}) { + SCOPED_TRACE(core::JoinTypeName::toName(joinType)); + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values({data_}) + .hashJoin( + {"c0"}, + {"u_c0"}, + PlanBuilder(planNodeIdGenerator) + .values({data_}) + .project({"c0 as u_c0"}) + .planNode(), + "", + {"c0", "c1"}, + joinType) + .planNode(); + testSerde(plan); + } +} + } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/PlanNodeStatsTest.cpp b/velox/exec/tests/PlanNodeStatsTest.cpp new file mode 100644 index 00000000000..2a59537eb43 --- /dev/null +++ b/velox/exec/tests/PlanNodeStatsTest.cpp @@ -0,0 +1,34 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/PlanNodeStats.h" +#include + +namespace facebook::velox::exec::test { + +TEST(PlanNodeStatsTest, exprStatsTotal) { + PlanNodeStats stats; + stats.expressionStats["foo"] = ExprStats{ + .timing = {.wallNanos = 1, .cpuNanos = 2}, + .numProcessedRows = 3, + .numProcessedVectors = 4}; + + PlanNodeStats total; + total += stats; + EXPECT_EQ(total.expressionStats["foo"], stats.expressionStats["foo"]); +} + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/PlanNodeToStringTest.cpp b/velox/exec/tests/PlanNodeToStringTest.cpp index 334c35127db..d734e1f3a5a 100644 --- a/velox/exec/tests/PlanNodeToStringTest.cpp +++ b/velox/exec/tests/PlanNodeToStringTest.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/connectors/hive/HiveConnector.h" #include "velox/exec/WindowFunction.h" #include "velox/exec/tests/utils/PlanBuilder.h" @@ -412,6 +413,28 @@ TEST_F(PlanNodeToStringTest, hashJoin) { ASSERT_EQ( "-- HashJoin[2][ANTI t_c0=u_c0] -> t_c0:SMALLINT, t_c1:INTEGER\n", plan->toString(true, false)); + + plan = PlanBuilder() + .values({data_}) + .project({"c0 as t_c0", "c1 as t_c1"}) + .hashJoin( + {"t_c0"}, + {"u_c0"}, + PlanBuilder() + .values({data_}) + .project({"c0 as u_c0", "c1 as u_c1"}) + .planNode(), + "", + {"t_c0", "t_c1"}, + core::JoinType::kAnti, + false /*nullAware*/, + true /*nullAsValue*/) + .planNode(); + + ASSERT_EQ("-- HashJoin[2]\n", plan->toString()); + ASSERT_EQ( + "-- HashJoin[2][ANTI t_c0=u_c0, null as value] -> t_c0:SMALLINT, t_c1:INTEGER\n", + plan->toString(true, false)); } TEST_F(PlanNodeToStringTest, mergeJoin) { @@ -592,10 +615,8 @@ TEST_F(PlanNodeToStringTest, localPartition) { } TEST_F(PlanNodeToStringTest, partitionedOutput) { - for (auto serdeKind : std::vector{ - VectorSerde::Kind::kPresto, - VectorSerde::Kind::kCompactRow, - VectorSerde::Kind::kUnsafeRow}) { + for (auto serdeKind : + std::vector{"Presto", "CompactRow", "UnsafeRow"}) { SCOPED_TRACE(fmt::format("serdeKind: {}", serdeKind)); auto plan = PlanBuilder() @@ -698,10 +719,8 @@ TEST_F(PlanNodeToStringTest, localMerge) { } TEST_F(PlanNodeToStringTest, exchange) { - for (auto serdeKind : std::vector{ - VectorSerde::Kind::kPresto, - VectorSerde::Kind::kCompactRow, - VectorSerde::Kind::kUnsafeRow}) { + for (auto serdeKind : + std::vector{"Presto", "CompactRow", "UnsafeRow"}) { SCOPED_TRACE(fmt::format("serdeKind: {}", serdeKind)); auto plan = PlanBuilder() @@ -716,10 +735,8 @@ TEST_F(PlanNodeToStringTest, exchange) { } TEST_F(PlanNodeToStringTest, mergeExchange) { - for (auto serdeKind : std::vector{ - VectorSerde::Kind::kPresto, - VectorSerde::Kind::kCompactRow, - VectorSerde::Kind::kUnsafeRow}) { + for (auto serdeKind : + std::vector{"Presto", "CompactRow", "UnsafeRow"}) { SCOPED_TRACE(fmt::format("serdeKind: {}", serdeKind)); auto plan = @@ -994,5 +1011,103 @@ TEST_F(PlanNodeToStringTest, markDistinct) { op->toString(true, false)); } +TEST_F(PlanNodeToStringTest, tableWrite) { + auto outputDir = common::testutil::TempDirectoryPath::create(); + + // TableWrite without stats. + { + auto plan = PlanBuilder() + .values({data_}) + .tableWrite(outputDir->getPath()) + .planNode(); + ASSERT_EQ("-- TableWrite[1]\n", plan->toString()); + ASSERT_EQ( + "-- TableWrite[1][test-hive, c0, c1, c2] -> rows:BIGINT, fragments:VARBINARY, commitcontext:VARBINARY\n", + plan->toString(true, false)); + } + + // TableWrite with stats (no grouping keys) and TableWriteMerge. + { + core::TableWriteNodePtr writeNode; + auto plan = PlanBuilder() + .values({data_}) + .tableWrite( + outputDir->getPath(), + dwio::common::FileFormat::DWRF, + {"min(c0)"}) + .capturePlanNode(writeNode) + .localGather() + .tableWriteMerge() + .planNode(); + + ASSERT_EQ("-- TableWrite[1]\n", writeNode->toString()); + ASSERT_EQ( + "-- TableWrite[1][test-hive, c0, c1, c2, stats[PARTIAL: min(ROW[\"c0\"])]] -> rows:BIGINT, fragments:VARBINARY, commitcontext:VARBINARY, a0:SMALLINT\n", + writeNode->toString(true, false)); + + ASSERT_EQ("-- TableWriteMerge[3]\n", plan->toString()); + ASSERT_EQ( + "-- TableWriteMerge[3][stats[INTERMEDIATE: min(\"a0\")]] -> rows:BIGINT, fragments:VARBINARY, commitcontext:VARBINARY, a0:SMALLINT\n", + plan->toString(true, false)); + } + + // TableWrite with stats and grouping keys (partitioned table). + { + core::TableWriteNodePtr writeNode; + auto plan = PlanBuilder() + .values({data_}) + .tableWrite( + outputDir->getPath(), + {"c2"}, + dwio::common::FileFormat::DWRF, + {"min(c0)", "max(c1)"}) + .capturePlanNode(writeNode) + .localGather() + .tableWriteMerge() + .planNode(); + + ASSERT_EQ("-- TableWrite[1]\n", writeNode->toString()); + ASSERT_EQ( + "-- TableWrite[1][test-hive, c0, c1, c2, stats[PARTIAL [c2]: min(ROW[\"c0\"]), max(ROW[\"c1\"])]] -> rows:BIGINT, fragments:VARBINARY, commitcontext:VARBINARY, c2:BIGINT, a0:SMALLINT, a1:INTEGER\n", + writeNode->toString(true, false)); + + ASSERT_EQ("-- TableWriteMerge[3]\n", plan->toString()); + ASSERT_EQ( + "-- TableWriteMerge[3][stats[INTERMEDIATE [c2]: min(\"a0\"), max(\"a1\")]] -> rows:BIGINT, fragments:VARBINARY, commitcontext:VARBINARY, c2:BIGINT, a0:SMALLINT, a1:INTEGER\n", + plan->toString(true, false)); + } +} + +TEST_F(PlanNodeToStringTest, countingJoin) { + auto makePlan = [&](core::JoinType joinType) { + auto planNodeIdGenerator = std::make_shared(); + return PlanBuilder(planNodeIdGenerator) + .values({data_}) + .hashJoin( + {"c0"}, + {"u_c0"}, + PlanBuilder(planNodeIdGenerator) + .values({data_}) + .project({"c0 as u_c0"}) + .planNode(), + "", + {"c0", "c1"}, + joinType) + .planNode(); + }; + + auto plan = makePlan(core::JoinType::kCountingAnti); + ASSERT_EQ("-- HashJoin[3]\n", plan->toString()); + ASSERT_EQ( + "-- HashJoin[3][COUNTING ANTI c0=u_c0] -> c0:SMALLINT, c1:INTEGER\n", + plan->toString(true, false)); + + plan = makePlan(core::JoinType::kCountingLeftSemiFilter); + ASSERT_EQ("-- HashJoin[3]\n", plan->toString()); + ASSERT_EQ( + "-- HashJoin[3][COUNTING LEFT SEMI (FILTER) c0=u_c0] -> c0:SMALLINT, c1:INTEGER\n", + plan->toString(true, false)); +} + } // namespace } // namespace facebook::velox::exec diff --git a/velox/exec/tests/PrefixSortTest.cpp b/velox/exec/tests/PrefixSortTest.cpp index f404fa05ac4..732d590dc05 100644 --- a/velox/exec/tests/PrefixSortTest.cpp +++ b/velox/exec/tests/PrefixSortTest.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include "velox/exec/PrefixSort.h" diff --git a/velox/exec/tests/PrestoQueryRunnerIntermediateTypeTransformTestBase.cpp b/velox/exec/tests/PrestoQueryRunnerIntermediateTypeTransformTestBase.cpp index 5bc0a66bc91..a52fc815068 100644 --- a/velox/exec/tests/PrestoQueryRunnerIntermediateTypeTransformTestBase.cpp +++ b/velox/exec/tests/PrestoQueryRunnerIntermediateTypeTransformTestBase.cpp @@ -148,20 +148,22 @@ void PrestoQueryRunnerIntermediateTypeTransformTestBase::testRow( test(vectorMaker_.rowVector({field1, field2, field3})); // Test row vector some nulls. - test(std::make_shared( - pool_.get(), - rowType, - makeNulls(size, [](vector_size_t row) { return row % 10 == 0; }), - size, - std::vector{field1, field2, field3})); + test( + std::make_shared( + pool_.get(), + rowType, + makeNulls(size, [](vector_size_t row) { return row % 10 == 0; }), + size, + std::vector{field1, field2, field3})); // Test row vector all nulls. - test(std::make_shared( - pool_.get(), - rowType, - makeNulls(size, [](vector_size_t) { return true; }), - size, - std::vector{field1, field2, field3})); + test( + std::make_shared( + pool_.get(), + rowType, + makeNulls(size, [](vector_size_t) { return true; }), + size, + std::vector{field1, field2, field3})); const auto base = vectorMaker_.rowVector({field1, field2, field3}); testDictionary(base); diff --git a/velox/exec/tests/PrestoQueryRunnerKHyperLogLogTransformTest.cpp b/velox/exec/tests/PrestoQueryRunnerKHyperLogLogTransformTest.cpp new file mode 100644 index 00000000000..80a4f5a7cd1 --- /dev/null +++ b/velox/exec/tests/PrestoQueryRunnerKHyperLogLogTransformTest.cpp @@ -0,0 +1,56 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.h" +#include "velox/exec/tests/PrestoQueryRunnerIntermediateTypeTransformTestBase.h" +#include "velox/functions/prestosql/types/KHyperLogLogType.h" + +namespace facebook::velox::exec::test { +namespace { + +class PrestoQueryRunnerKHyperLogLogTransformTest + : public PrestoQueryRunnerIntermediateTypeTransformTestBase {}; + +TEST_F(PrestoQueryRunnerKHyperLogLogTransformTest, isIntermediateOnlyType) { + ASSERT_TRUE(isIntermediateOnlyType(KHYPERLOGLOG())); + ASSERT_TRUE(isIntermediateOnlyType(ARRAY(KHYPERLOGLOG()))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(KHYPERLOGLOG(), SMALLINT()))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(VARBINARY(), KHYPERLOGLOG()))); + ASSERT_TRUE(isIntermediateOnlyType(ROW({KHYPERLOGLOG(), SMALLINT()}))); + ASSERT_TRUE(isIntermediateOnlyType(ROW( + {SMALLINT(), + TIMESTAMP(), + ARRAY(ROW({MAP(VARCHAR(), KHYPERLOGLOG())}))}))); +} + +TEST_F(PrestoQueryRunnerKHyperLogLogTransformTest, transform) { + test(KHYPERLOGLOG()); +} + +TEST_F(PrestoQueryRunnerKHyperLogLogTransformTest, transformArray) { + testArray(KHYPERLOGLOG()); +} + +TEST_F(PrestoQueryRunnerKHyperLogLogTransformTest, transformMap) { + testMap(KHYPERLOGLOG()); +} + +TEST_F(PrestoQueryRunnerKHyperLogLogTransformTest, transformRow) { + testRow(KHYPERLOGLOG()); +} + +} // namespace +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/PrestoQueryRunnerSetDigestTransformTest.cpp b/velox/exec/tests/PrestoQueryRunnerSetDigestTransformTest.cpp new file mode 100644 index 00000000000..af15b0dec71 --- /dev/null +++ b/velox/exec/tests/PrestoQueryRunnerSetDigestTransformTest.cpp @@ -0,0 +1,54 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.h" +#include "velox/exec/tests/PrestoQueryRunnerIntermediateTypeTransformTestBase.h" +#include "velox/functions/prestosql/types/SetDigestType.h" + +namespace facebook::velox::exec::test { +namespace { + +class PrestoQueryRunnerSetDigestTransformTest + : public PrestoQueryRunnerIntermediateTypeTransformTestBase {}; + +TEST_F(PrestoQueryRunnerSetDigestTransformTest, isIntermediateOnlyType) { + ASSERT_TRUE(isIntermediateOnlyType(SETDIGEST())); + ASSERT_TRUE(isIntermediateOnlyType(ARRAY(SETDIGEST()))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(SETDIGEST(), SMALLINT()))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(VARBINARY(), SETDIGEST()))); + ASSERT_TRUE(isIntermediateOnlyType(ROW({SETDIGEST(), SMALLINT()}))); + ASSERT_TRUE(isIntermediateOnlyType(ROW( + {SMALLINT(), TIMESTAMP(), ARRAY(ROW({MAP(VARCHAR(), SETDIGEST())}))}))); +} + +TEST_F(PrestoQueryRunnerSetDigestTransformTest, transform) { + test(SETDIGEST()); +} + +TEST_F(PrestoQueryRunnerSetDigestTransformTest, transformArray) { + testArray(SETDIGEST()); +} + +TEST_F(PrestoQueryRunnerSetDigestTransformTest, transformMap) { + testMap(SETDIGEST()); +} + +TEST_F(PrestoQueryRunnerSetDigestTransformTest, transformRow) { + testRow(SETDIGEST()); +} + +} // namespace +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/PrestoQueryRunnerTimeTransformTest.cpp b/velox/exec/tests/PrestoQueryRunnerTimeTransformTest.cpp new file mode 100644 index 00000000000..0e72861ff20 --- /dev/null +++ b/velox/exec/tests/PrestoQueryRunnerTimeTransformTest.cpp @@ -0,0 +1,113 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/tests/PrestoQueryRunnerIntermediateTypeTransformTestBase.h" + +namespace facebook::velox::exec::test { +namespace { + +class PrestoQueryRunnerTimeTransformTest + : public PrestoQueryRunnerIntermediateTypeTransformTestBase {}; + +// Test that TIME is recognized as an intermediate type that needs +// transformation +TEST_F(PrestoQueryRunnerTimeTransformTest, isIntermediateOnlyType) { + // Core test: TIME should be an intermediate type + ASSERT_TRUE(isIntermediateOnlyType(TIME())); + + // Complex types containing TIME should also be intermediate types + ASSERT_TRUE(isIntermediateOnlyType(ARRAY(TIME()))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(VARCHAR(), TIME()))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(TIME(), VARCHAR()))); + ASSERT_TRUE(isIntermediateOnlyType(ROW({TIME(), BIGINT()}))); +} + +TEST_F(PrestoQueryRunnerTimeTransformTest, roundTrip) { + // Test basic TIME values (no nulls, some nulls, all nulls) + std::vector> no_nulls{0, 3661000, 43200000, 86399999}; + test(makeNullableFlatVector(no_nulls, TIME())); + + std::vector> some_nulls{ + 0, 3661000, std::nullopt, 86399999}; + test(makeNullableFlatVector(some_nulls, TIME())); + + std::vector> all_nulls{ + std::nullopt, std::nullopt, std::nullopt}; + test(makeNullableFlatVector(all_nulls, TIME())); +} + +TEST_F(PrestoQueryRunnerTimeTransformTest, transformArray) { + auto input = makeNullableFlatVector( + std::vector>{ + 0, // 00:00:00.000 + 1000, // 00:00:01.000 + 3661000, // 01:01:01.000 + 43200000, // 12:00:00.000 (noon) + 86399999, // 23:59:59.999 + 3723456, // 01:02:03.456 + 45678901, // 12:41:18.901 + std::nullopt, + 72000000, // 20:00:00.000 + 36000000 // 10:00:00.000 + }, + TIME()); + testArray(input); +} + +TEST_F(PrestoQueryRunnerTimeTransformTest, transformMap) { + // keys can't be null for maps + auto keys = makeNullableFlatVector( + std::vector>{ + 0, // 00:00:00.000 + 3661000, // 01:01:01.000 + 43200000, // 12:00:00.000 + 86399999, // 23:59:59.999 + 36000000, // 10:00:00.000 + 72000000, // 20:00:00.000 + 1800000, // 00:30:00.000 + 7200000, // 02:00:00.000 + 64800000, // 18:00:00.000 + 32400000 // 09:00:00.000 + }, + TIME()); + + auto values = makeNullableFlatVector( + {100, 200, std::nullopt, 400, 500, std::nullopt, 700, 800, 900, 1000}, + BIGINT()); + + testMap(keys, values); +} + +TEST_F(PrestoQueryRunnerTimeTransformTest, transformRow) { + auto input = makeNullableFlatVector( + std::vector>{ + 0, // 00:00:00.000 + 3661000, // 01:01:01.000 + 43200000, // 12:00:00.000 + 86399999, // 23:59:59.999 + std::nullopt, + 36000000, // 10:00:00.000 + 72000000, // 20:00:00.000 + 1800000, // 00:30:00.000 + 7200000, // 02:00:00.000 + 64800000 // 18:00:00.000 + }, + TIME()); + testRow({input}, {"time_col"}); +} + +} // namespace +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/PrestoQueryRunnerTimestampWithTimeZoneTransformTest.cpp b/velox/exec/tests/PrestoQueryRunnerTimestampWithTimeZoneTransformTest.cpp index afcd8d33239..4b60208686e 100644 --- a/velox/exec/tests/PrestoQueryRunnerTimestampWithTimeZoneTransformTest.cpp +++ b/velox/exec/tests/PrestoQueryRunnerTimestampWithTimeZoneTransformTest.cpp @@ -169,20 +169,22 @@ TEST_F( test(vectorMaker_.rowVector({field1, field2, field3})); // Test row vector some nulls. - test(std::make_shared( - pool_.get(), - rowType, - makeNulls(size, [](vector_size_t row) { return row % 10 == 0; }), - size, - std::vector{field1, field2, field3})); + test( + std::make_shared( + pool_.get(), + rowType, + makeNulls(size, [](vector_size_t row) { return row % 10 == 0; }), + size, + std::vector{field1, field2, field3})); // Test row vector all nulls. - test(std::make_shared( - pool_.get(), - rowType, - makeNulls(size, [](vector_size_t) { return true; }), - size, - std::vector{field1, field2, field3})); + test( + std::make_shared( + pool_.get(), + rowType, + makeNulls(size, [](vector_size_t) { return true; }), + size, + std::vector{field1, field2, field3})); const auto base = vectorMaker_.rowVector({field1, field2, field3}); testDictionary(base); diff --git a/velox/exec/tests/PrintPlanWithStatsTest.cpp b/velox/exec/tests/PrintPlanWithStatsTest.cpp index 73928fce6a5..b718c0c71b9 100644 --- a/velox/exec/tests/PrintPlanWithStatsTest.cpp +++ b/velox/exec/tests/PrintPlanWithStatsTest.cpp @@ -15,17 +15,18 @@ */ #include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include #include using namespace facebook::velox; using namespace facebook::velox::exec::test; +using namespace facebook::velox::common::testutil; using facebook::velox::exec::test::PlanBuilder; @@ -48,6 +49,12 @@ void compareOutputs( for (; std::getline(iss, line);) { lineCount++; std::vector potentialLines; + if (expectedLineIndex >= expectedRegex.size()) { + ASSERT_FALSE(true) << "Output has more lines than expected." + << "\n Source: " << testName + << "\n Line number: " << lineCount + << "\n Unexpected Line: " << line; + } auto expectedLine = expectedRegex.at(expectedLineIndex++); while (!RE2::FullMatch(line, expectedLine.line)) { potentialLines.push_back(expectedLine.line); @@ -59,11 +66,18 @@ void compareOutputs( << "\n Expected Line one of: " << folly::join(",", potentialLines); } + if (expectedLineIndex >= expectedRegex.size()) { + ASSERT_FALSE(true) + << "Output did not match and no more patterns to check." + << "\n Source: " << testName << "\n Line number: " << lineCount + << "\n Line: " << line + << "\n Expected Line one of: " << folly::join(",", potentialLines); + } expectedLine = expectedRegex.at(expectedLineIndex++); } } for (int i = expectedLineIndex; i < expectedRegex.size(); i++) { - ASSERT_TRUE(expectedRegex[expectedLineIndex].optional); + ASSERT_TRUE(expectedRegex[i].optional); } } @@ -157,22 +171,28 @@ TEST_F(PrintPlanWithStatsTest, innerJoinWithTableScan) { {" dataSourceLazyCpuNanos[ ]* sum: .+, count: .+, min: .+, max: .+"}, {" dataSourceLazyInputBytes[ ]* sum: .+, count: .+, min: .+, max: .+"}, {" dataSourceLazyWallNanos[ ]* sum: .+, count: 1, min: .+, max: .+"}, + {" driverCpuTimeNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningAddInputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningFinishWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" runningIsBlockedWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" -- HashJoin\\[3\\]\\[INNER c0=u_c0\\] -> c0:INTEGER, c1:BIGINT, u_c1:BIGINT"}, {" Output: 2000 rows \\(.+\\), Cpu time: .+, Blocked wall time: .+, Peak memory: .+, Memory allocations: .+, CPU breakdown: B/I/O/F (.+/.+/.+/.+)"}, {" HashBuild: Input: 100 rows \\(.+\\), Output: 0 rows \\(.+\\), Cpu time: .+, Blocked wall time: .+, Peak memory: .+, Memory allocations: .+, Threads: 1, CPU breakdown: B/I/O/F (.+/.+/.+/.+)"}, {" distinctKey0\\s+sum: 101, count: 1, min: 101, max: 101, avg: 101"}, + {" driverCpuTimeNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" hashtable.buildWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" hashtable.capacity\\s+sum: 200, count: 1, min: 200, max: 200, avg: 200"}, + {" hashtable.hashMode\\s+sum: \\d+, count: 1, min: \\d+, max: \\d+, avg: \\d+"}, {" hashtable.numDistinct\\s+sum: 100, count: 1, min: 100, max: 100, avg: 100"}, {" hashtable.numRehashes\\s+sum: 1, count: 1, min: 1, max: 1, avg: 1"}, + {" hashtable.vectorHasherMergeCpuNanos\\s+sum: .*, count: 1, min: .*, max: .*, avg: .*"}, {" queuedWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" rangeKey0\\s+sum: 200, count: 1, min: 200, max: 200, avg: 200"}, {" runningAddInputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningFinishWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" runningIsBlockedWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" HashProbe: Input: 2000 rows \\(.+\\), Output: 2000 rows \\(.+\\), Cpu time: .+, Blocked wall time: .+, Peak memory: .+, Memory allocations: .+, Threads: 1, CPU breakdown: B/I/O/F (.+/.+/.+/.+)"}, // These lines may or may not appear depending on whether the operator // gets blocked during a run. @@ -180,6 +200,8 @@ TEST_F(PrintPlanWithStatsTest, innerJoinWithTableScan) { true}, {" blockedWaitForJoinBuildWallNanos\\s+sum: .+, count: 1, min: .+, max: .+", true}, + {" driverCpuTimeNanos\\s+sum: .+, count: 1, min: .+, max: .+", + true}, {" dynamicFiltersProduced\\s+sum: 1, count: 1, min: 1, max: 1, avg: 1", true}, {" queuedWallNanos\\s+sum: .+, count: 1, min: .+, max: .+", true}, @@ -189,46 +211,78 @@ TEST_F(PrintPlanWithStatsTest, innerJoinWithTableScan) { {" runningAddInputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningFinishWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" runningIsBlockedWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" -- TableScan\\[2\\]\\[table: hive_table\\] -> c0:INTEGER, c1:BIGINT"}, {" Input: 2000 rows \\(.+\\), Raw Input: 20480 rows \\(.+\\), Output: 2000 rows \\(.+\\), Cpu time: .+, Blocked wall time: .+, Peak memory: .+, Memory allocations: .+, Threads: 1, Splits: 20, DynamicFilter producer plan nodes: 3, CPU breakdown: B/I/O/F (.+/.+/.+/.+)"}, + // These lines may or may not appear depending on whether the operator + // gets blocked waiting for a split during a run. + {" blockedWaitForSplitTimes\\s+sum: .+, count: .+, min: 1, max: 1, avg: 1", + true}, + {" blockedWaitForSplitWallNanos\\s+sum: .+, count: .+, min: .+, max: .+", + true}, + {" cacheWaitWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" coalescedSsdLoadWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" coalescedStorageLoadWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+", + true}, {" connectorSplitSize[ ]* sum: .+, count: .+, min: .+, max: .+"}, - {" dataSourceAddSplitWallNanos[ ]* sum: .+, count: 1, min: .+, max: .+"}, {" dataSourceReadWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" driverCpuTimeNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" dynamicFiltersAccepted[ ]* sum: 1, count: 1, min: 1, max: 1, avg: 1"}, + {" fileFormat\\.dwrf[ ]* sum: .+, count: 1, min: .+, max: .+, avg: .+", + true}, {" footerBufferOverread[ ]* sum: .+, count: 1, min: .+, max: .+"}, - {" ioWaitWallNanos [ ]* sum: .+, count: .+ min: .+, max: .+"}, - {" maxSingleIoWaitWallNanos[ ]*sum: .+, count: 1, min: .+, max: .+"}, - {" numPrefetch [ ]* sum: .+, count: 1, min: .+, max: .+"}, - {" numRamRead [ ]* sum: 60, count: 1, min: 60, max: 60, avg: 60"}, - {" numStorageRead [ ]* sum: .+, count: 1, min: .+, max: .+"}, + {" ioWaitWallNanos [ ]* sum: .+, count: .+ min: .+, max: .+", + true}, + {" numPrefetch [ ]* sum: .+, count: 1, min: .+, max: .+", + true}, + {" numRamRead [ ]* sum: .+, count: 1, min: .+, max: .+, avg: .+", + true}, {" numStripes[ ]* sum: .+, count: 1, min: .+, max: .+"}, - {" overreadBytes[ ]* sum: 0B, count: 1, min: 0B, max: 0B, avg: 0B"}, - {" prefetchBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, + {" overreadBytes[ ]* sum: .+, count: 1, min: .+, max: .+", + true}, + {" prefetchBytes [ ]* sum: .+, count: .+, min: .+, max: .+", + true}, {" preloadSplitPrepareTimeNanos[ ]* sum: .+, count: .+, min: .+, max: .+, avg: .+"}, {" preloadedSplits[ ]+sum: .+, count: .+, min: .+, max: .+", true}, {" processedSplits[ ]+sum: 20, count: 1, min: 20, max: 20, avg: 20"}, {" processedStrides[ ]+sum: 20, count: 1, min: 20, max: 20, avg: 20"}, - {" ramReadBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, + {" processedUnits [ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" ramReadBytes [ ]* sum: .+, count: .+, min: .+, max: .+", + true}, {" readyPreloadedSplits[ ]+sum: .+, count: .+, min: .+, max: .+", true}, {" runningAddInputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningFinishWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, - {" storageReadBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, + {" runningIsBlockedWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" ssdCacheReadWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" storageReadBytes [ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" storageReadWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" totalRemainingFilterCpuNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, {" totalRemainingFilterWallNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, - {" totalScanTime [ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" totalScanTime [ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" unitLoadNanos[ ]* sum: .+, count: .+, min: .+, max: .+, avg: .+"}, {" waitForPreloadSplitNanos[ ]* sum: .+, count: .+, min: .+, max: .+, avg: .+"}, {" -- Project\\[1\\]\\[expressions: \\(u_c0:INTEGER, ROW\\[\"c0\"\\]\\), \\(u_c1:BIGINT, ROW\\[\"c1\"\\]\\)\\] -> u_c0:INTEGER, u_c1:BIGINT"}, {" Output: 100 rows \\(.+\\), Cpu time: .+, Blocked wall time: .+, Peak memory: 0B, Memory allocations: .+, Threads: 1, CPU breakdown: B/I/O/F (.+/.+/.+/.+)"}, + {" driverCpuTimeNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningAddInputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningFinishWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" runningIsBlockedWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" -- Values\\[0\\]\\[100 rows in 1 vectors\\] -> c0:INTEGER, c1:BIGINT"}, {" Input: 0 rows \\(.+\\), Output: 100 rows \\(.+\\), Cpu time: .+, Blocked wall time: .+, Peak memory: 0B, Memory allocations: .+, Threads: 1, CPU breakdown: B/I/O/F (.+/.+/.+/.+)"}, + {" driverCpuTimeNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningAddInputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningFinishWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, - {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}}); + {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" runningIsBlockedWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}}); } TEST_F(PrintPlanWithStatsTest, partialAggregateWithTableScan) { @@ -269,51 +323,142 @@ TEST_F(PrintPlanWithStatsTest, partialAggregateWithTableScan) { {" -- TableScan\\[0\\]\\[table: hive_table\\] -> c0:BIGINT, c1:INTEGER, c2:SMALLINT, c3:REAL, c4:DOUBLE, c5:VARCHAR"}, {" Input: 10000 rows \\(.+\\), Output: 10000 rows \\(.+\\), Cpu time: .+, Blocked wall time: .+, Peak memory: .+, Memory allocations: .+, Threads: 1, Splits: 1, CPU breakdown: B/I/O/F (.+/.+/.+/.+)"}}); + std::vector metrics = { + {"-- Aggregation\\[1\\]\\[PARTIAL \\[c5\\] a0 := max\\(ROW\\[\"c0\"\\]\\), a1 := sum\\(ROW\\[\"c1\"\\]\\), a2 := sum\\(ROW\\[\"c2\"\\]\\), a3 := sum\\(ROW\\[\"c3\"\\]\\), a4 := sum\\(ROW\\[\"c4\"\\]\\)\\] -> c5:VARCHAR, a0:BIGINT, a1:BIGINT, a2:BIGINT, a3:DOUBLE, a4:DOUBLE"}, + {" Output: .+, Cpu time: .+, Blocked wall time: .+, Peak memory: .+, Memory allocations: .+, Threads: 1, CPU breakdown: B/I/O/F (.+/.+/.+/.+)"}, + {" dataSourceLazyCpuNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, + {" dataSourceLazyInputBytes\\s+sum: .+, count: .+, min: .+, max: .+"}, + {" dataSourceLazyWallNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, + {" distinctKey0\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" driverCpuTimeNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" hashtable.capacity\\s+sum: (?:1273|1252), count: 1, min: (?:1273|1252), max: (?:1273|1252), avg: (?:1273|1252)"}, + {" hashtable.hashMode\\s+sum: \\d+, count: 1, min: \\d+, max: \\d+, avg: \\d+"}, + {" hashtable.numDistinct\\s+sum: (?:849|835), count: 1, min: (?:849|835), max: (?:849|835), avg: (?:849|835)"}, + {" hashtable.numRehashes\\s+sum: 1, count: 1, min: 1, max: 1, avg: 1"}, + {" loadedToValueHook\\s+sum: 50000, count: 5, min: 10000, max: 10000, avg: 10000"}, + {" runningAddInputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" runningFinishWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" runningIsBlockedWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}}; + + std::vector scanDisablePrefetchMetrics = { + {" -- TableScan\\[0\\]\\[table: hive_table\\] -> c0:BIGINT, c1:INTEGER, c2:SMALLINT, c3:REAL, c4:DOUBLE, c5:VARCHAR"}, + {" Input: 10000 rows \\(.+\\), Output: 10000 rows \\(.+\\), Cpu time: .+, Blocked wall time: .+, Peak memory: .+, Memory allocations: .+, Threads: 1, Splits: 1, CPU breakdown: B/I/O/F (.+/.+/.+/.+)"}, + {" blockedWaitForSplitTimes\\s+sum: 1, count: 1, min: 1, max: 1, avg: 1", + true}, + {" blockedWaitForSplitWallNanos\\s+sum: .+, count: 1, min: .+, max: .+", + true}, + {" cacheWaitWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" coalescedSsdLoadWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" coalescedStorageLoadWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" connectorSplitSize[ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" dataSourceAddSplitWallNanos[ ]* sum: .+, count: 1, min: .+, max: .+"}, + {" dataSourceReadWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" driverCpuTimeNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" fileFormat\\.dwrf[ ]* sum: .+, count: 1, min: .+, max: .+, avg: .+", + true}, + {" footerBufferOverread[ ]* sum: .+, count: 1, min: .+, max: .+", + true}, + {" ioWaitWallNanos [ ]* sum: .+, count: .+ min: .+, max: .+", + true}, + {" numPrefetch [ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" numRamRead [ ]* sum: .+, count: 1, min: .+, max: .+, avg: .+", + true}, + {" numStripes[ ]* sum: .+, count: 1, min: .+, max: .+"}, + {" overreadBytes[ ]* sum: .+, count: 1, min: .+, max: .+", true}, + {" prefetchBytes [ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" processedSplits [ ]* sum: 1, count: 1, min: 1, max: 1, avg: 1"}, + {" processedStrides [ ]* sum: 1, count: 1, min: 1, max: 1, avg: 1"}, + {" processedUnits [ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" preloadedSplits[ ]+sum: .+, count: .+, min: .+, max: .+", + true}, + {" ramReadBytes [ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" readyPreloadedSplits[ ]+sum: .+, count: .+, min: .+, max: .+", + true}, + {" runningAddInputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" runningFinishWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" runningIsBlockedWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" ssdCacheReadWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" storageReadBytes [ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" storageReadWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" totalRemainingFilterCpuNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, + {" totalRemainingFilterWallNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, + {" totalScanTime [ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" unitLoadNanos[ ]* sum: .+, count: .+, min: .+, max: .+, avg: .+"}}; + + std::vector scanPrefetchMetrics = { + {" -- TableScan\\[0\\]\\[table: hive_table\\] -> c0:BIGINT, c1:INTEGER, c2:SMALLINT, c3:REAL, c4:DOUBLE, c5:VARCHAR"}, + {" Input: 10000 rows \\(.+\\), Output: 10000 rows \\(.+\\), Cpu time: .+, Blocked wall time: .+, Peak memory: .+, Memory allocations: .+, Threads: 1, Splits: 1, CPU breakdown: B/I/O/F (.+/.+/.+/.+)"}, + {" blockedWaitForSplitTimes\\s+sum: 1, count: 1, min: 1, max: 1, avg: 1", + true}, + {" blockedWaitForSplitWallNanos\\s+sum: .+, count: 1, min: .+, max: .+", + true}, + {" cacheWaitWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" coalescedSsdLoadWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" coalescedStorageLoadWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" connectorSplitSize[ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" dataSourceReadWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" driverCpuTimeNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" fileFormat\\.dwrf[ ]* sum: .+, count: 1, min: .+, max: .+, avg: .+", + true}, + {" footerBufferOverread[ ]* sum: .+, count: 1, min: .+, max: .+"}, + {" ioWaitWallNanos [ ]* sum: .+, count: .+ min: .+, max: .+", + true}, + {" numPrefetch [ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" numRamRead [ ]* sum: .+, count: 1, min: .+, max: .+, avg: .+", + true}, + {" numStripes[ ]* sum: .+, count: 1, min: .+, max: .+"}, + {" overreadBytes[ ]* sum: .+, count: 1, min: .+, max: .+", true}, + {" prefetchBytes [ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" preloadSplitPrepareTimeNanos [ ]* sum: .+, count: 1, min: .+, max: .+, avg: .+"}, + {" preloadedSplits [ ]* sum: 1, count: 1, min: 1, max: 1, avg: 1"}, + {" processedSplits [ ]* sum: 1, count: 1, min: 1, max: 1, avg: 1"}, + {" processedStrides [ ]* sum: 1, count: 1, min: 1, max: 1, avg: 1"}, + {" processedUnits [ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" preloadedSplits[ ]+sum: .+, count: .+, min: .+, max: .+", + true}, + {" ramReadBytes [ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" readyPreloadedSplits[ ]+sum: .+, count: .+, min: .+, max: .+", + true}, + {" runningAddInputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" runningFinishWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" runningIsBlockedWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" storageReadBytes [ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" storageReadWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" totalRemainingFilterCpuNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, + {" totalRemainingFilterWallNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, + {" totalScanTime [ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" unitLoadNanos[ ]* sum: .+, count: .+, min: .+, max: .+, avg: .+"}, + {" waitForPreloadSplitNanos [ ]* sum: .+, count: 1, min: .+, max: .+, avg: .+"}}; + + auto scanMetrics = numPrefetchSplit == 0 ? scanDisablePrefetchMetrics + : scanPrefetchMetrics; + metrics.reserve(metrics.size() + scanMetrics.size()); + metrics.insert(metrics.end(), scanMetrics.begin(), scanMetrics.end()); + compareOutputs( ::testing::UnitTest::GetInstance()->current_test_info()->name(), printPlanWithStats(*op, task->taskStats(), true), - {{"-- Aggregation\\[1\\]\\[PARTIAL \\[c5\\] a0 := max\\(ROW\\[\"c0\"\\]\\), a1 := sum\\(ROW\\[\"c1\"\\]\\), a2 := sum\\(ROW\\[\"c2\"\\]\\), a3 := sum\\(ROW\\[\"c3\"\\]\\), a4 := sum\\(ROW\\[\"c4\"\\]\\)\\] -> c5:VARCHAR, a0:BIGINT, a1:BIGINT, a2:BIGINT, a3:DOUBLE, a4:DOUBLE"}, - {" Output: .+, Cpu time: .+, Blocked wall time: .+, Peak memory: .+, Memory allocations: .+, Threads: 1, CPU breakdown: B/I/O/F (.+/.+/.+/.+)"}, - {" dataSourceLazyCpuNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, - {" dataSourceLazyInputBytes\\s+sum: .+, count: .+, min: .+, max: .+"}, - {" dataSourceLazyWallNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, - {" distinctKey0\\s+sum: .+, count: 1, min: .+, max: .+"}, - {" hashtable.capacity\\s+sum: (?:1273|1252), count: 1, min: (?:1273|1252), max: (?:1273|1252), avg: (?:1273|1252)"}, - {" hashtable.numDistinct\\s+sum: (?:849|835), count: 1, min: (?:849|835), max: (?:849|835), avg: (?:849|835)"}, - {" hashtable.numRehashes\\s+sum: 1, count: 1, min: 1, max: 1, avg: 1"}, - {" hashtable.numTombstones\\s+sum: 0, count: 1, min: 0, max: 0, avg: 0"}, - {" loadedToValueHook\\s+sum: 50000, count: 5, min: 10000, max: 10000, avg: 10000"}, - {" runningAddInputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, - {" runningFinishWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, - {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, - {" -- TableScan\\[0\\]\\[table: hive_table\\] -> c0:BIGINT, c1:INTEGER, c2:SMALLINT, c3:REAL, c4:DOUBLE, c5:VARCHAR"}, - {" Input: 10000 rows \\(.+\\), Output: 10000 rows \\(.+\\), Cpu time: .+, Blocked wall time: .+, Peak memory: .+, Memory allocations: .+, Threads: 1, Splits: 1, CPU breakdown: B/I/O/F (.+/.+/.+/.+)"}, - {" connectorSplitSize[ ]* sum: .+, count: .+, min: .+, max: .+"}, - {" dataSourceAddSplitWallNanos[ ]* sum: .+, count: 1, min: .+, max: .+"}, - {" dataSourceReadWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+"}, - {" footerBufferOverread[ ]* sum: .+, count: 1, min: .+, max: .+"}, - {" ioWaitWallNanos [ ]* sum: .+, count: .+ min: .+, max: .+"}, - {" maxSingleIoWaitWallNanos[ ]*sum: .+, count: 1, min: .+, max: .+"}, - {" numPrefetch [ ]* sum: .+, count: .+, min: .+, max: .+"}, - {" numRamRead [ ]* sum: 7, count: 1, min: 7, max: 7, avg: 7"}, - {" numStorageRead [ ]* sum: .+, count: 1, min: .+, max: .+"}, - {" numStripes[ ]* sum: .+, count: 1, min: .+, max: .+"}, - {" overreadBytes[ ]* sum: 0B, count: 1, min: 0B, max: 0B, avg: 0B"}, - - {" prefetchBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, - {" processedSplits [ ]* sum: 1, count: 1, min: 1, max: 1, avg: 1"}, - {" processedStrides [ ]* sum: 1, count: 1, min: 1, max: 1, avg: 1"}, - {" preloadedSplits[ ]+sum: .+, count: .+, min: .+, max: .+", - true}, - {" ramReadBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, - {" readyPreloadedSplits[ ]+sum: .+, count: .+, min: .+, max: .+", - true}, - {" runningAddInputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, - {" runningFinishWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, - {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, - {" storageReadBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, - {" totalRemainingFilterWallNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, - {" totalScanTime [ ]* sum: .+, count: .+, min: .+, max: .+"}}); + metrics); } } @@ -340,7 +485,7 @@ TEST_F(PrintPlanWithStatsTest, tableWriterWithTableScan) { compareOutputs( ::testing::UnitTest::GetInstance()->current_test_info()->name(), printPlanWithStats(*writePlan, task->taskStats()), - {{R"(-- TableWrite\[1\]\[.+InsertTableHandle .+)"}, + {{R"(-- TableWrite\[1\]\[test-hive, c0, c1, c2, c3, c4, c5\] -> rows:BIGINT, fragments:VARBINARY, commitcontext:VARBINARY)"}, {" Output: .+, Physical written output: .+, Cpu time: .+, Blocked wall time: .+, Peak memory: .+, Memory allocations: .+, Threads: 1, CPU breakdown: B/I/O/F (.+/.+/.+/.+)"}, {R"( -- TableScan\[0\]\[table: hive_table\] -> c0:BIGINT, c1:INTEGER, c2:SMALLINT, c3:REAL, c4:DOUBLE, c5:VARCHAR)"}, {R"( Input: 100 rows \(.+\), Output: 100 rows \(.+\), Cpu time: .+, Blocked wall time: .+, Peak memory: .+, Memory allocations: .+, Threads: 1, Splits: 1, CPU breakdown: B/I/O/F (.+/.+/.+/.+))"}}); @@ -348,46 +493,78 @@ TEST_F(PrintPlanWithStatsTest, tableWriterWithTableScan) { compareOutputs( ::testing::UnitTest::GetInstance()->current_test_info()->name(), printPlanWithStats(*writePlan, task->taskStats(), true), - {{R"(-- TableWrite\[1\]\[.+InsertTableHandle .+)"}, - {" Output: .+, Physical written output: .+, Cpu time: .+, Blocked wall time: .+, Peak memory: .+, Memory allocations: .+, Threads: 1, CPU breakdown: B/I/O/F (.+/.+/.+/.+)"}, - {" dataSourceLazyCpuNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, - {" dataSourceLazyInputBytes\\s+sum: .+, count: .+, min: .+, max: .+"}, - {" dataSourceLazyWallNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, - {" numWrittenFiles\\s+sum: .+, count: 1, min: .+, max: .+"}, - {" runningAddInputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, - {" runningFinishWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, - {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, - {" runningWallNanos\\s+sum: .+, count: 1, min: .+, max: .+, avg: .+"}, - {" stripeSize\\s+sum: .+, count: 1, min: .+, max: .+"}, - {" writeIOWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, - {R"( -- TableScan\[0\]\[table: hive_table\] -> c0:BIGINT, c1:INTEGER, c2:SMALLINT, c3:REAL, c4:DOUBLE, c5:VARCHAR)"}, - {R"( Input: 100 rows \(.+\), Output: 100 rows \(.+\), Cpu time: .+, Blocked wall time: .+, Peak memory: .+, Memory allocations: .+, Threads: 1, Splits: 1, CPU breakdown: B/I/O/F (.+/.+/.+/.+))"}, - {" connectorSplitSize[ ]* sum: .+, count: .+, min: .+, max: .+"}, - {" dataSourceAddSplitWallNanos[ ]* sum: .+, count: 1, min: .+, max: .+"}, - {" dataSourceReadWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+"}, - {" footerBufferOverread[ ]* sum: .+, count: 1, min: .+, max: .+"}, - {" ioWaitWallNanos [ ]* sum: .+, count: .+ min: .+, max: .+"}, - {" maxSingleIoWaitWallNanos[ ]*sum: .+, count: 1, min: .+, max: .+"}, - {" numPrefetch [ ]* sum: .+, count: .+, min: .+, max: .+"}, - {" numRamRead [ ]* sum: 7, count: 1, min: 7, max: 7, avg: 7"}, - {" numStorageRead [ ]* sum: .+, count: 1, min: .+, max: .+"}, - {" numStripes[ ]* sum: .+, count: 1, min: .+, max: .+"}, - {" overreadBytes[ ]* sum: 0B, count: 1, min: 0B, max: 0B, avg: 0B"}, - - {" prefetchBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, - {" processedSplits [ ]* sum: 1, count: 1, min: 1, max: 1, avg: 1"}, - {" processedStrides [ ]* sum: 1, count: 1, min: 1, max: 1, avg: 1"}, - {" preloadedSplits[ ]+sum: .+, count: .+, min: .+, max: .+", - true}, - {" ramReadBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, - {" readyPreloadedSplits[ ]+sum: .+, count: .+, min: .+, max: .+", - true}, - {" runningAddInputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, - {" runningFinishWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, - {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, - {" storageReadBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, - {" totalRemainingFilterWallNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, - {" totalScanTime [ ]* sum: .+, count: .+, min: .+, max: .+"}}); + { + {R"(-- TableWrite\[1\]\[test-hive, c0, c1, c2, c3, c4, c5\] -> rows:BIGINT, fragments:VARBINARY, commitcontext:VARBINARY)"}, + {" Output: .+, Physical written output: .+, Cpu time: .+, Blocked wall time: .+, Peak memory: .+, Memory allocations: .+, Threads: 1, CPU breakdown: B/I/O/F (.+/.+/.+/.+)"}, + {" dataSourceLazyCpuNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, + {" dataSourceLazyInputBytes\\s+sum: .+, count: .+, min: .+, max: .+"}, + {" dataSourceLazyWallNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, + {" driverCpuTimeNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" dwrfWriterCount\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" numWrittenFiles\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" runningAddInputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" runningFinishWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" runningIsBlockedWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" runningWallNanos\\s+sum: .+, count: 1, min: .+, max: .+, avg: .+"}, + {" stripeSize\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" writeIOWallNanos\\s+sum: .+, count: 1, min: .+, max: .+, avg: .+"}, + {R"( -- TableScan\[0\]\[table: hive_table\] -> c0:BIGINT, c1:INTEGER, c2:SMALLINT, c3:REAL, c4:DOUBLE, c5:VARCHAR)"}, + {R"( Input: 100 rows \(.+\), Output: 100 rows \(.+\), Cpu time: .+, Blocked wall time: .+, Peak memory: .+, Memory allocations: .+, Threads: 1, Splits: 1, CPU breakdown: B/I/O/F (.+/.+/.+/.+))"}, + // These lines may or may not appear depending on whether the operator + // gets blocked waiting for a split during a run. + {" blockedWaitForSplitTimes\\s+sum: 1, count: 1, min: 1, max: 1, avg: 1", + true}, + {" blockedWaitForSplitWallNanos\\s+sum: .+, count: 1, min: .+, max: .+", + true}, + {" cacheWaitWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" coalescedSsdLoadWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" coalescedStorageLoadWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" connectorSplitSize[ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" dataSourceReadWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" driverCpuTimeNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" fileFormat\\.dwrf[ ]* sum: .+, count: 1, min: .+, max: .+, avg: .+", + true}, + {" footerBufferOverread[ ]* sum: .+, count: 1, min: .+, max: .+"}, + {" ioWaitWallNanos [ ]* sum: .+, count: .+ min: .+, max: .+", + true}, + {" numPrefetch [ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" numRamRead [ ]* sum: .+, count: 1, min: .+, max: .+, avg: .+", + true}, + {" numStripes[ ]* sum: .+, count: 1, min: .+, max: .+"}, + {" overreadBytes[ ]* sum: .+, count: 1, min: .+, max: .+", + true}, + {" prefetchBytes [ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" preloadSplitPrepareTimeNanos [ ]* sum: .+, count: 1, min: .+, max: .+, avg: .+"}, + {" preloadedSplits [ ]* sum: 1, count: 1, min: 1, max: 1, avg: 1"}, + {" processedSplits [ ]* sum: 1, count: 1, min: 1, max: 1, avg: 1"}, + {" processedStrides [ ]* sum: 1, count: 1, min: 1, max: 1, avg: 1"}, + {" processedUnits [ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" preloadedSplits[ ]+sum: .+, count: .+, min: .+, max: .+", + true}, + {" ramReadBytes [ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" readyPreloadedSplits[ ]+sum: .+, count: .+, min: .+, max: .+", + true}, + {" runningAddInputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" runningFinishWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" runningIsBlockedWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" storageReadBytes [ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" storageReadWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" totalRemainingFilterCpuNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, + {" totalRemainingFilterWallNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, + {" totalScanTime [ ]* sum: .+, count: .+, min: .+, max: .+", + true}, + {" unitLoadNanos[ ]* sum: .+, count: .+, min: .+, max: .+, avg: .+"}, + {" waitForPreloadSplitNanos [ ]* sum: .+, count: 1, min: .+, max: .+, avg: .+"}, + }); } TEST_F(PrintPlanWithStatsTest, taskAPI) { diff --git a/velox/exec/tests/QueryAssertionsTest.cpp b/velox/exec/tests/QueryAssertionsTest.cpp index 21b5d60099a..eb08c7b14b7 100644 --- a/velox/exec/tests/QueryAssertionsTest.cpp +++ b/velox/exec/tests/QueryAssertionsTest.cpp @@ -182,7 +182,7 @@ TEST_F(QueryAssertionsTest, singleFloatColumn) { size, [&](auto row) { return row == 302 - ? 2.01 + std::max(kEpsilon, double(6 * FLT_EPSILON)) + ? 2.01 + std::max(Variant::kEpsilon, double(6 * FLT_EPSILON)) : row % 6 + 0.01; }, nullEvery(7)), @@ -282,15 +282,17 @@ TEST_F(QueryAssertionsTest, multiFloatColumnWithUniqueKeys) { makeFlatVector( size, [&](auto row) { - return row == 6 ? 2 + std::max(float(kEpsilon), 6 * FLT_EPSILON) - : row % 4; + return row == 6 + ? 2 + std::max(float(Variant::kEpsilon), 6 * FLT_EPSILON) + : row % 4; }, nullEvery(5)), makeFlatVector( size, [&](auto row) { - return row == 1 ? 1.01 + std::max(kEpsilon, double(3 * FLT_EPSILON)) - : row % 6 + 0.01; + return row == 1 + ? 1.01 + std::max(Variant::kEpsilon, double(3 * FLT_EPSILON)) + : row % 6 + 0.01; }, nullEvery(7)), }); diff --git a/velox/exec/tests/RowContainerTest.cpp b/velox/exec/tests/RowContainerTest.cpp index 94884361d0a..05777919e2e 100644 --- a/velox/exec/tests/RowContainerTest.cpp +++ b/velox/exec/tests/RowContainerTest.cpp @@ -108,7 +108,8 @@ class RowContainerTestHelper { RowContainer* const rowContainer_; }; -class RowContainerTest : public exec::test::RowContainerTestBase { +class RowContainerTest : public exec::test::RowContainerTestBase, + public testing::WithParamInterface { protected: static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); @@ -344,7 +345,9 @@ class RowContainerTest : public exec::test::RowContainerTestBase { sum += data.rowSize(row) - data.fixedRowSize(); } auto usage = data.stringAllocator().currentBytes(); - EXPECT_EQ(usage, sum); + if (data.testingRowPointers().empty()) { + EXPECT_EQ(usage, sum); + } } std::vector store( @@ -484,11 +487,12 @@ class RowContainerTest : public exec::test::RowContainerTestBase { const VectorPtr& expected, std::optional flags) { // If no flags provided then it must be the default of {true, true}. - SCOPED_TRACE(fmt::format( - "{}, ascending = {}, nullsFirst = {}", - type->toString(), - flags.has_value() ? flags.value().ascending : true, - flags.has_value() ? flags.value().nullsFirst : true)); + SCOPED_TRACE( + fmt::format( + "{}, ascending = {}, nullsFirst = {}", + type->toString(), + flags.has_value() ? flags.value().ascending : true, + flags.has_value() ? flags.value().nullsFirst : true)); // Set 'isJoinBuild' to true to enable nullable sort key in test. auto rowContainer = makeRowContainer({type}, {type}, false); @@ -570,8 +574,8 @@ class RowContainerTest : public exec::test::RowContainerTestBase { for (auto row : rows) { ASSERT_EQ( expected[index], - rowContainer->equals( - row, rowContainer->columnAt(0), rhsDecoded, index)) + rowContainer->compare( + row, rowContainer->columnAt(0), rhsDecoded, index) == 0) << fmt::format( "Mismatch at index {} with canHandleNulls {}", index, @@ -831,8 +835,8 @@ class RowContainerTest : public exec::test::RowContainerTestBase { kNumRows, rawValues, rawExpected); // Remove nulls in data because keys cannot be null. auto numNulls = countLeadingNulls(rawValues); - rawValues.erase(rawValues.begin(), rawValues.begin() + numNulls); - rawExpected.erase(rawExpected.begin(), rawExpected.begin() + numNulls); + rawValues.erase(rawValues.cbegin(), rawValues.cbegin() + numNulls); + rawExpected.erase(rawExpected.cbegin(), rawExpected.cbegin() + numNulls); std::optional second = 0; for (auto value : rawValues) { std::vector>> temp{ @@ -935,7 +939,7 @@ static int32_t sign(int32_t n) { } } // namespace -TEST_F(RowContainerTest, extractWithNullsAndTargetOffset) { +TEST_P(RowContainerTest, extractWithNullsAndTargetOffset) { constexpr int32_t kNumRows = 100; // The second column must have no nulls in the first batch. auto rowVector1 = makeRowVector({ @@ -961,8 +965,8 @@ TEST_F(RowContainerTest, extractWithNullsAndTargetOffset) { // Create and fill up two row containers from two row vectors. std::vector vecTypes = {BOOLEAN(), VARCHAR(), TINYINT()}; RowTypePtr rowType = VectorMaker::rowType({BOOLEAN(), VARCHAR(), TINYINT()}); - auto data1 = makeRowContainer({}, vecTypes); - auto data2 = makeRowContainer({}, vecTypes); + auto data1 = makeRowContainer({}, vecTypes, true, GetParam()); + auto data2 = makeRowContainer({}, vecTypes, true, GetParam()); for (auto i = 0; i < kNumRows; i++) { data1->newRow(); data2->newRow(); @@ -1045,7 +1049,7 @@ TEST_F(RowContainerTest, storeExtractArrayOfVarchar) { roundTrip(input); } -TEST_F(RowContainerTest, types) { +TEST_P(RowContainerTest, types) { constexpr int32_t kNumRows = 100; auto batch = makeDataset( ROW( @@ -1105,7 +1109,7 @@ TEST_F(RowContainerTest, types) { std::vector dependents; dependents.insert( dependents.begin(), types.begin() + types.size() / 2, types.end()); - auto data = makeRowContainer(keys, dependents); + auto data = makeRowContainer(keys, dependents, true, GetParam()); EXPECT_GT(data->nextOffset(), 0); EXPECT_GT(data->probedFlagOffset(), 0); @@ -1177,11 +1181,13 @@ TEST_F(RowContainerTest, types) { EXPECT_EQ(source->hashValueAt(i), hashes[i]); // Test non-null and nullable variants of equals. if (column < keys.size()) { - EXPECT_TRUE( - data->equals(rows[i], data->columnAt(column), decoded, i)); + EXPECT_EQ( + data->compare(rows[i], data->columnAt(column), decoded, i), + 0); } else if (!columnType->isMap()) { - EXPECT_TRUE( - data->equals(rows[i], data->columnAt(column), decoded, i)); + EXPECT_EQ( + data->compare(rows[i], data->columnAt(column), decoded, i), + 0); } // Non-key map columns are not comparable, as the map keys are not sorted. if (columnType->isMap() && column >= keys.size()) { @@ -1215,7 +1221,7 @@ TEST_F(RowContainerTest, types) { EXPECT_LT(0, free.second); } -TEST_F(RowContainerTest, extractNulls) { +TEST_P(RowContainerTest, extractNulls) { constexpr int32_t kNumRows = 100; auto batch = makeRowVector({ makeFlatVector( @@ -1266,7 +1272,7 @@ TEST_F(RowContainerTest, extractNulls) { ARRAY(INTEGER()), MAP(INTEGER(), INTEGER()), ROW({INTEGER(), INTEGER()})}; - auto data = makeRowContainer({}, rowType); + auto data = makeRowContainer({}, rowType, true, GetParam()); for (int i = 0; i < kNumRows; i++) { data->newRow(); } @@ -1347,11 +1353,11 @@ TEST_F(RowContainerTest, erase) { RowContainerTestHelper(data.get()).checkConsistency(); } -TEST_F(RowContainerTest, initialNulls) { +TEST_P(RowContainerTest, initialNulls) { std::vector keys{INTEGER()}; std::vector dependent{INTEGER()}; // Join build. - auto data = makeRowContainer(keys, dependent, true); + auto data = makeRowContainer(keys, dependent, true, GetParam()); auto row = data->newRow(); auto isNullAt = [](const RowContainer& data, const char* row, int32_t i) { auto column = data.columnAt(i); @@ -1361,7 +1367,7 @@ TEST_F(RowContainerTest, initialNulls) { EXPECT_FALSE(isNullAt(*data, row, 0)); EXPECT_FALSE(isNullAt(*data, row, 1)); // Non-join build. - data = makeRowContainer(keys, dependent, false); + data = makeRowContainer(keys, dependent, false, GetParam()); row = data->newRow(); EXPECT_FALSE(isNullAt(*data, row, 0)); EXPECT_FALSE(isNullAt(*data, row, 1)); @@ -1369,7 +1375,7 @@ TEST_F(RowContainerTest, initialNulls) { TEST_F(RowContainerTest, rowSize) { constexpr int32_t kNumRows = 100; - auto data = makeRowContainer({SMALLINT()}, {VARCHAR()}); + auto data = makeRowContainer({SMALLINT()}, {VARCHAR()}, true); // The layout is expected to be smallint - 6 bytes of padding - 1 byte of bits // - StringView - rowSize - next pointer. The bits are a null flag for the @@ -1405,10 +1411,10 @@ TEST_F(RowContainerTest, rowSize) { EXPECT_EQ(rows, rowsFromContainer); } -TEST_F(RowContainerTest, columnSize) { +TEST_P(RowContainerTest, columnSize) { const uint64_t kNumRows = 1000; - auto rowContainer = - makeRowContainer({BIGINT(), VARCHAR()}, {BIGINT(), VARCHAR()}); + auto rowContainer = makeRowContainer( + {BIGINT(), VARCHAR()}, {BIGINT(), VARCHAR()}, true, GetParam()); VectorFuzzer fuzzer( { @@ -1452,8 +1458,8 @@ TEST_F(RowContainerTest, columnSize) { } } -TEST_F(RowContainerTest, rowSizeWithNormalizedKey) { - auto data = makeRowContainer({SMALLINT()}, {VARCHAR()}); +TEST_P(RowContainerTest, rowSizeWithNormalizedKey) { + auto data = makeRowContainer({SMALLINT()}, {VARCHAR()}, true, GetParam()); data->newRow(); data->disableNormalizedKeys(); data->newRow(); @@ -1469,7 +1475,7 @@ TEST_F(RowContainerTest, estimateRowSize) { // Make a RowContainer with a fixed-length key column and a variable-length // dependent column. - auto rowContainer = makeRowContainer({BIGINT()}, {VARCHAR()}); + auto rowContainer = makeRowContainer({BIGINT()}, {VARCHAR()}, true); EXPECT_FALSE(rowContainer->estimateRowSize().has_value()); // Store rows to the container. @@ -1511,7 +1517,9 @@ TEST_F(RowContainerTest, alignment) { false, false, true, + false, // hasCountFlag true, + false, pool_.get()); constexpr int kNumRows = 100; char* rows[kNumRows]; @@ -1678,7 +1686,9 @@ TEST_F(RowContainerTest, probedFlag) { true, // hasNext true, // isJoinBuild true, // hasProbedFlag + false, // hasCountFlag false, // hasNormalizedKey + false, // useListRowIndex pool_.get()); auto input = makeRowVector({ @@ -1782,6 +1792,32 @@ TEST_F(RowContainerTest, probedFlag) { result); } +TEST_F(RowContainerTest, countFlag) { + RowContainer rows( + {INTEGER()}, + true, // nullableKeys + std::vector{}, + {}, // dependentTypes + false, // hasNext + true, // isJoinBuild + false, // hasProbedFlag + true, // hasCountFlag + false, // hasNormalizedKey + false, // useListRowIndex + pool_.get()); + + ASSERT_NE(rows.countOffset(), 0); + + auto* row = rows.newRow(); + EXPECT_EQ(rows.count(row), 1); + rows.incrementCount(row); + EXPECT_EQ(rows.count(row), 2); + rows.decrementCount(row); + EXPECT_EQ(rows.count(row), 1); + rows.decrementCount(row); + EXPECT_EQ(rows.count(row), 0); +} + TEST_F(RowContainerTest, mixedFree) { constexpr int32_t kNumRows = 100'000; constexpr int32_t kNumBad = 100; @@ -1868,8 +1904,10 @@ TEST_F(RowContainerTest, unknown) { } for (size_t row = 0; row < size; ++row) { - ASSERT_TRUE(rowContainer->equals( - rows[row], rowContainer->columnAt(0), decoded, row)); + ASSERT_EQ( + rowContainer->compare( + rows[row], rowContainer->columnAt(0), decoded, row), + 0); } { @@ -1892,24 +1930,29 @@ TEST_F(RowContainerTest, unknown) { // Verify compare method with row and decoded vector as input // Sorting a NULL constant Vector doesn't change the Vector, so we just // validate that it runs without throwing an exception. - EXPECT_NO_THROW(std::sort( - indexedRows.begin(), - indexedRows.end(), - [&](const std::pair& l, const std::pair& r) { - return rowContainer->compare( - l.second, rowContainer->columnAt(0), decoded, r.first, {}) < - 0; - })); + EXPECT_NO_THROW( + std::sort( + indexedRows.begin(), + indexedRows.end(), + [&](const std::pair& l, const std::pair& r) { + return rowContainer->compare( + l.second, + rowContainer->columnAt(0), + decoded, + r.first, + {}) < 0; + })); // Verify compareRows method with row as input. // Sorting a NULL constant Vector doesn't change the Vector, so we just // validate that it runs without throwing an exception. - EXPECT_NO_THROW(std::sort( - indexedRows.begin(), - indexedRows.end(), - [&](const std::pair& l, const std::pair& r) { - return rowContainer->compareRows(l.second, r.second) < 0; - })); + EXPECT_NO_THROW( + std::sort( + indexedRows.begin(), + indexedRows.end(), + [&](const std::pair& l, const std::pair& r) { + return rowContainer->compareRows(l.second, r.second) < 0; + })); } TEST_F(RowContainerTest, nans) { @@ -1945,8 +1988,10 @@ TEST_F(RowContainerTest, nans) { // Verify that they are considered equal. for (size_t row = 0; row < size; ++row) { - ASSERT_TRUE(rowContainer->equals( - rows[row], rowContainer->columnAt(0), decoded, row)); + ASSERT_EQ( + rowContainer->compare( + rows[row], rowContainer->columnAt(0), decoded, row), + 0); } ASSERT_EQ(rowContainer->compare(rows[0], rows[1], 0, {}), 0); } @@ -2133,7 +2178,7 @@ DEBUG_ONLY_TEST_F(RowContainerTest, eraseAfterOomStoringString) { rowContainer->eraseRows(folly::Range(rows.data(), numRows)); } -TEST_F(RowContainerTest, hugeIntStoreWithNulls) { +TEST_P(RowContainerTest, hugeIntStoreWithNulls) { constexpr int32_t kNumRows = 100; constexpr int32_t kColumnIndex = 0; @@ -2166,7 +2211,7 @@ TEST_F(RowContainerTest, hugeIntStoreWithNulls) { dictNulls, dictIndices, kNumRows, hugeIntVector); std::vector keys; - auto data = makeRowContainer({HUGEINT()}, {}, false); + auto data = makeRowContainer({HUGEINT()}, {}, false, GetParam()); std::vector rows(kNumRows); for (auto i = 0; i < kNumRows; ++i) { rows[i] = data->newRow(); @@ -2181,9 +2226,9 @@ TEST_F(RowContainerTest, hugeIntStoreWithNulls) { assertEqualVectors(source, extracted); } -TEST_F(RowContainerTest, columnHasNulls) { - auto rowContainer = - makeRowContainer({BIGINT(), BIGINT()}, {BIGINT(), BIGINT()}, false); +TEST_P(RowContainerTest, columnHasNulls) { + auto rowContainer = makeRowContainer( + {BIGINT(), BIGINT()}, {BIGINT(), BIGINT()}, false, GetParam()); for (int i = 0; i < rowContainer->columnTypes().size(); ++i) { ASSERT_FALSE(rowContainer->columnHasNulls(i)); } @@ -2234,7 +2279,7 @@ TEST_F(RowContainerTest, columnHasNulls) { } } -TEST_F(RowContainerTest, store) { +TEST_P(RowContainerTest, store) { const uint64_t kNumRows = 1000; auto rowVectorWithNulls = makeRowVector({ makeFlatVector( @@ -2264,7 +2309,7 @@ TEST_F(RowContainerTest, store) { }); for (auto& rowVector : {rowVectorWithNulls, rowVectorNoNulls}) { auto rowContainer = makeRowContainer( - {BIGINT(), VARCHAR()}, {BIGINT(), ARRAY(BIGINT())}, false); + {BIGINT(), VARCHAR()}, {BIGINT(), ARRAY(BIGINT())}, false, GetParam()); std::vector rows; rows.reserve(kNumRows); @@ -2408,7 +2453,7 @@ TEST_F(RowContainerTest, customComparisonRow) { }); } -TEST_F(RowContainerTest, isNanAt) { +TEST_P(RowContainerTest, isNanAt) { const auto kNan = std::numeric_limits::quiet_NaN(); const auto kNanF = std::numeric_limits::quiet_NaN(); auto rowVector = makeRowVector({ @@ -2419,8 +2464,8 @@ TEST_F(RowContainerTest, isNanAt) { }); const auto kNumRows = rowVector->size(); - auto rowContainer = - makeRowContainer({REAL(), DOUBLE()}, {REAL(), DOUBLE()}, false); + auto rowContainer = makeRowContainer( + {REAL(), DOUBLE()}, {REAL(), DOUBLE()}, false, GetParam()); std::vector rows; rows.reserve(kNumRows); @@ -2653,7 +2698,7 @@ TEST_F(RowContainerTest, rowColumnStats) { EXPECT_EQ(stats.nullCount(), 3); } -TEST_F(RowContainerTest, storeAndCollectColumnStats) { +TEST_P(RowContainerTest, storeAndCollectColumnStats) { const uint64_t kNumRows = 1000; auto rowVector = makeRowVector({ makeFlatVector( @@ -2664,7 +2709,8 @@ TEST_F(RowContainerTest, storeAndCollectColumnStats) { nullEvery(7)), }); - auto rowContainer = makeRowContainer({BIGINT(), VARCHAR()}, {}, false); + auto rowContainer = + makeRowContainer({BIGINT(), VARCHAR()}, {}, false, GetParam()); std::vector rows; rows.reserve(kNumRows); @@ -2722,6 +2768,8 @@ TEST_F(RowContainerTest, setAllNull) { false, true, false, + false, // hasCountFlag + false, false, pool_.get()); @@ -2746,4 +2794,9 @@ TEST_F(RowContainerTest, setAllNull) { (row[accColumn.initializedByte()] & accColumn.initializedMask()), 0); } +VELOX_INSTANTIATE_TEST_SUITE_P( + RowContainerTest, + RowContainerTest, + testing::ValuesIn({false, true})); + } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/RowNumberTest.cpp b/velox/exec/tests/RowNumberTest.cpp index 84e272d4273..ba832c12daf 100644 --- a/velox/exec/tests/RowNumberTest.cpp +++ b/velox/exec/tests/RowNumberTest.cpp @@ -16,13 +16,14 @@ #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/FileSystems.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/OperatorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" namespace facebook::velox::exec::test { +using namespace facebook::velox::common::testutil; class RowNumberTest : public OperatorTestBase { protected: @@ -204,7 +205,7 @@ TEST_F(RowNumberTest, spill) { for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); TestScopedSpillInjection scopedSpillInjection(100, ".*", 1); @@ -218,11 +219,12 @@ TEST_F(RowNumberTest, spill) { core::QueryConfig::kSpillNumPartitionBits, testData.spillPartitionBits) .queryCtx(queryCtx) - .plan(PlanBuilder() - .values(vectors) - .rowNumber({"c0"}) - .capturePlanNodeId(rowNumberPlanNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .rowNumber({"c0"}) + .capturePlanNodeId(rowNumberPlanNodeId) + .planNode()) .assertResults( "SELECT *, row_number() over (partition by c0) FROM tmp"); auto taskStats = toPlanStats(task->taskStats()); @@ -237,11 +239,13 @@ TEST_F(RowNumberTest, spill) { task->taskStats().pipelineStats.back().operatorStats.at(1); auto runtimeStats = operatorStats.runtimeStats; ASSERT_EQ( - runtimeStats.at(Operator::kSpillReadBytes).sum, + runtimeStats.at(std::string(Operator::kSpillReadBytes)).sum, operatorStats.spilledBytes); - ASSERT_GT(runtimeStats.at(Operator::kSpillReads).sum, 0); - ASSERT_GT(runtimeStats.at(Operator::kSpillReadTime).sum, 0); - ASSERT_GT(runtimeStats.at(Operator::kSpillDeserializationTime).sum, 0); + ASSERT_GT(runtimeStats.at(std::string(Operator::kSpillReads)).sum, 0); + ASSERT_GT(runtimeStats.at(std::string(Operator::kSpillReadTime)).sum, 0); + ASSERT_GT( + runtimeStats.at(std::string(Operator::kSpillDeserializationTime)).sum, + 0); task.reset(); waitForAllTasksToBeDeleted(); @@ -268,7 +272,7 @@ TEST_F(RowNumberTest, maxSpillBytes) { for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); try { TestScopedSpillInjection scopedSpillInjection(100, ".*", 1); @@ -320,7 +324,7 @@ TEST_F(RowNumberTest, memoryUsage) { for (const auto& spillEnable : {false, true}) { auto queryCtx = core::QueryCtx::create(executor_.get()); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); const std::string spillEnableConfig = std::to_string(spillEnable); std::shared_ptr task; @@ -374,7 +378,7 @@ DEBUG_ONLY_TEST_F(RowNumberTest, spillOnlyDuringInputOrOutput) { for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); std::atomic_int numRound{0}; @@ -404,11 +408,12 @@ DEBUG_ONLY_TEST_F(RowNumberTest, spillOnlyDuringInputOrOutput) { core::QueryConfig::kSpillNumPartitionBits, testData.spillPartitionBits) .queryCtx(queryCtx) - .plan(PlanBuilder() - .values(vectors) - .rowNumber({"c0"}) - .capturePlanNodeId(rowNumberPlanNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .rowNumber({"c0"}) + .capturePlanNodeId(rowNumberPlanNodeId) + .planNode()) .assertResults( "SELECT *, row_number() over (partition by c0) FROM tmp"); auto taskStats = toPlanStats(task->taskStats()); @@ -441,7 +446,7 @@ DEBUG_ONLY_TEST_F(RowNumberTest, recursiveSpill) { for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); std::atomic_int numSpills{0}; @@ -488,11 +493,12 @@ DEBUG_ONLY_TEST_F(RowNumberTest, recursiveSpill) { .config( core::QueryConfig::kMaxSpillLevel, testData.maxSpillLevel - 1) .queryCtx(queryCtx) - .plan(PlanBuilder() - .values(vectors) - .rowNumber({"c0"}) - .capturePlanNodeId(rowNumberPlanNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .rowNumber({"c0"}) + .capturePlanNodeId(rowNumberPlanNodeId) + .planNode()) .assertResults( "SELECT *, row_number() over (partition by c0) FROM tmp"); auto taskStats = toPlanStats(task->taskStats()); @@ -537,7 +543,7 @@ TEST_F(RowNumberTest, spillWithYield) { SCOPED_TRACE(testData.debugString()); TestScopedSpillInjection scopedSpillInjection( 100, ".*", testData.numSpills); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); core::PlanNodeId rowNumberPlanNodeId; @@ -550,11 +556,12 @@ TEST_F(RowNumberTest, spillWithYield) { core::QueryConfig::kDriverCpuTimeSliceLimitMs, testData.cpuTimeSliceLimitMs) .queryCtx(queryCtx) - .plan(PlanBuilder() - .values(vectors) - .rowNumber({"c0"}) - .capturePlanNodeId(rowNumberPlanNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .rowNumber({"c0"}) + .capturePlanNodeId(rowNumberPlanNodeId) + .planNode()) .assertResults( "SELECT *, row_number() over (partition by c0) FROM tmp"); auto taskStats = toPlanStats(task->taskStats()); @@ -568,4 +575,41 @@ TEST_F(RowNumberTest, spillWithYield) { } } +DEBUG_ONLY_TEST_F(RowNumberTest, rowNumberSpillFileCreateConfig) { + auto vectors = createVectors(8, rowType_, fuzzerOpts_); + createDuckDbTable(vectors); + + auto tempDirectory = TempDirectoryPath::create(); + + std::atomic_bool rowNumberConfigVerified{false}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::isBlocked", + std::function([&](exec::Operator* op) { + const auto* spillConfig = op->testingSpillConfig(); + if (spillConfig == nullptr) { + return; + } + const auto& opType = op->operatorType(); + if (opType == "RowNumber") { + ASSERT_EQ(spillConfig->fileCreateConfig, "test_row_number_config") + << "Operator: " << opType; + rowNumberConfigVerified = true; + } + })); + + TestScopedSpillInjection scopedSpillInjection(100); + AssertQueryBuilder(duckDbQueryRunner_) + .spillDirectory(tempDirectory->getPath()) + .config(core::QueryConfig::kSpillEnabled, true) + .config(core::QueryConfig::kRowNumberSpillEnabled, true) + .config(core::QueryConfig::kSpillFileCreateConfig, "test_default_config") + .config( + core::QueryConfig::kRowNumberSpillFileCreateConfig, + "test_row_number_config") + .plan(PlanBuilder().values(vectors).rowNumber({"c0"}).planNode()) + .assertResults("SELECT *, row_number() over (partition by c0) FROM tmp"); + + ASSERT_TRUE(rowNumberConfigVerified.load()); +} + } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/ScaleWriterLocalPartitionTest.cpp b/velox/exec/tests/ScaleWriterLocalPartitionTest.cpp index 2dda1e2e868..8a6ee39a705 100644 --- a/velox/exec/tests/ScaleWriterLocalPartitionTest.cpp +++ b/velox/exec/tests/ScaleWriterLocalPartitionTest.cpp @@ -81,6 +81,10 @@ class TestExchangeController { if (holdBufferBytes_ == 0) { return; } + if (holdBuffer_ != nullptr) { + return; + } + holdPool_ = pool; holdBuffer_ = holdPool_->allocate(holdBufferBytes_); } @@ -227,11 +231,6 @@ class FakeSourceOperator : public SourceOperator { private: void initialize() override { Operator::initialize(); - - if (operatorCtx_->driverCtx()->driverId != 0) { - return; - } - testController_->maybeHoldBuffer(pool()); } @@ -495,10 +494,11 @@ class ScaleWriterLocalPartitionTest : public HiveConnectorTestBase { for (const auto& name : rowType_->names()) { orderByKeys.push_back(fmt::format("{} ASC NULLS FIRST", name)); } - AssertQueryBuilder queryBuilder(PlanBuilder() - .values(inputVectors) - .orderBy(orderByKeys, false) - .planNode()); + AssertQueryBuilder queryBuilder( + PlanBuilder() + .values(inputVectors) + .orderBy(orderByKeys, false) + .planNode()); return queryBuilder.copyResults(pool_.get()); } @@ -680,22 +680,26 @@ TEST_F(ScaleWriterLocalPartitionTest, unpartitionBasic) { if (testData.expectedRebalance) { ASSERT_GT( planStats.at(exchnangeNodeId) - .customStats.at(ScaleWriterLocalPartition::kScaledWriters) + .customStats + .at(std::string(ScaleWriterLocalPartition::kScaledWriters)) .sum, 0); ASSERT_LE( planStats.at(exchnangeNodeId) - .customStats.at(ScaleWriterLocalPartition::kScaledWriters) + .customStats + .at(std::string(ScaleWriterLocalPartition::kScaledWriters)) .sum, planStats.at(exchnangeNodeId) - .customStats.at(ScaleWriterLocalPartition::kScaledWriters) + .customStats + .at(std::string(ScaleWriterLocalPartition::kScaledWriters)) .count * (testData.numConsumers - 1)); ASSERT_GT(nonEmptyConsumers, 1); } else { ASSERT_EQ( planStats.at(exchnangeNodeId) - .customStats.count(ScaleWriterLocalPartition::kScaledWriters), + .customStats.count( + std::string(ScaleWriterLocalPartition::kScaledWriters)), 0); ASSERT_EQ(nonEmptyConsumers, 1); } @@ -893,40 +897,51 @@ TEST_F(ScaleWriterLocalPartitionTest, partitionBasic) { ASSERT_LE( planStats.at(exchnangeNodeId) .customStats.count( - ScaleWriterPartitioningLocalPartition::kScaledPartitions), + std::string( + ScaleWriterPartitioningLocalPartition:: + kScaledPartitions)), 1); if (planStats.at(exchnangeNodeId) .customStats.count( - ScaleWriterPartitioningLocalPartition::kScaledPartitions) == - 1) { + std::string( + ScaleWriterPartitioningLocalPartition:: + kScaledPartitions)) == 1) { ASSERT_GT( planStats.at(exchnangeNodeId) .customStats - .at(ScaleWriterPartitioningLocalPartition::kScaledPartitions) + .at(std::string( + ScaleWriterPartitioningLocalPartition::kScaledPartitions)) .sum, 0); } ASSERT_EQ( planStats.at(exchnangeNodeId) .customStats.count( - ScaleWriterPartitioningLocalPartition::kRebalanceTriggers), + std::string( + ScaleWriterPartitioningLocalPartition:: + kRebalanceTriggers)), 1); ASSERT_GT( planStats.at(exchnangeNodeId) .customStats - .at(ScaleWriterPartitioningLocalPartition::kRebalanceTriggers) + .at(std::string( + ScaleWriterPartitioningLocalPartition::kRebalanceTriggers)) .sum, 0); } else { ASSERT_EQ( planStats.at(exchnangeNodeId) .customStats.count( - ScaleWriterPartitioningLocalPartition::kScaledPartitions), + std::string( + ScaleWriterPartitioningLocalPartition:: + kScaledPartitions)), 0); ASSERT_EQ( planStats.at(exchnangeNodeId) .customStats.count( - ScaleWriterPartitioningLocalPartition::kRebalanceTriggers), + std::string( + ScaleWriterPartitioningLocalPartition:: + kRebalanceTriggers)), 0); verifyDisjointPartitionKeys(testController.get()); } diff --git a/velox/exec/tests/SimpleAggregateAdapterTest.cpp b/velox/exec/tests/SimpleAggregateAdapterTest.cpp index 52ed8b1e17a..d104443453b 100644 --- a/velox/exec/tests/SimpleAggregateAdapterTest.cpp +++ b/velox/exec/tests/SimpleAggregateAdapterTest.cpp @@ -29,6 +29,8 @@ namespace { const char* const kSimpleAvg = "simple_avg"; const char* const kSimpleArrayAgg = "simple_array_agg"; const char* const kSimpleCountNulls = "simple_count_nulls"; +const char* const kSimpleVariadicSum = "simple_variadic_sum"; +const char* const kSimpleVariadicArrayAgg = "simple_variadic_array_agg"; class SimpleAverageAggregationTest : public AggregationTestBase { protected: @@ -602,5 +604,340 @@ TEST_F(SimpleFuncLevelVariableAggregationTest, simpleAggregateVariables) { {}); } +class SimpleVariadicSumAggregationTest : public AggregationTestBase { + protected: + void SetUp() override { + AggregationTestBase::SetUp(); + registerSimpleVariadicSumAggregate(kSimpleVariadicSum); + } +}; + +TEST_F(SimpleVariadicSumAggregationTest, basicVariadicSum) { + // Test global with 3 variadic arguments: sum each column across rows. + // Input: + // Row 1: count=3, a=1, b=2, c=3 + // Row 2: count=3, a=4, b=5, c=6 + // Expected output: [1+4, 2+5, 3+6] = [5, 7, 9] + auto inputVectors = makeRowVector({ + makeFlatVector({3, 3}), + makeFlatVector({1, 4}), + makeFlatVector({2, 5}), + makeFlatVector({3, 6}), + }); + + auto expected = makeRowVector({makeArrayVector({{5, 7, 9}})}); + + testAggregations( + {inputVectors}, {}, {"simple_variadic_sum(c0, c1, c2, c3)"}, {expected}); + + // Test with grouping. + // Group true: [1+5, 2+6] = [6, 8] + // Group false: [3+7, 4+8] = [10, 12] + inputVectors = makeRowVector({ + makeFlatVector({true, false, true, false}), + makeFlatVector({2, 2, 2, 2}), + makeFlatVector({1, 3, 5, 7}), + makeFlatVector({2, 4, 6, 8}), + }); + + expected = makeRowVector({ + makeFlatVector({true, false}), + makeArrayVector({{6, 8}, {10, 12}}), + }); + + testAggregations( + {inputVectors}, {"c0"}, {"simple_variadic_sum(c1, c2, c3)"}, {expected}); +} + +TEST_F(SimpleVariadicSumAggregationTest, variadicSumWithNulls) { + // Test global handling of null values in variadic arguments. + // With default null behavior, rows with any null variadic element are + // skipped entirely. + // Row 0: variadic=[1, null, 3] -> SKIPPED (null in variadic) + // Row 1: variadic=[4, 5, 6] -> processed + // Row 2: variadic=[7, 8, 9] -> processed + // Expected: [4+7, 5+8, 6+9] = [11, 13, 15] + auto inputVectors = makeRowVector({ + makeFlatVector({3, 3, 3}), + makeNullableFlatVector({1, 4, 7}), + makeNullableFlatVector({std::nullopt, 5, 8}), + makeNullableFlatVector({3, 6, 9}), + }); + + auto expected = makeRowVector({makeArrayVector({{11, 13, 15}})}); + + testAggregations( + {inputVectors}, {}, {"simple_variadic_sum(c0, c1, c2, c3)"}, {expected}); + + // Test with grouping and null values. + // Row 0 (true): variadic=[1, 2, 3] -> processed + // Row 1 (false): variadic=[4, null, 6] -> SKIPPED + // Row 2 (true): variadic=[null, 8, 9] -> SKIPPED + // Row 3 (false): variadic=[10, 11, 12] -> processed + // Group true: only row 0 -> [1, 2, 3] + // Group false: only row 3 -> [10, 11, 12] + inputVectors = makeRowVector({ + makeFlatVector({true, false, true, false}), + makeFlatVector({3, 3, 3, 3}), + makeNullableFlatVector({1, 4, std::nullopt, 10}), + makeNullableFlatVector({2, std::nullopt, 8, 11}), + makeNullableFlatVector({3, 6, 9, 12}), + }); + + expected = makeRowVector({ + makeFlatVector({true, false}), + makeArrayVector({{1, 2, 3}, {10, 11, 12}}), + }); + + testAggregations( + {inputVectors}, + {"c0"}, + {"simple_variadic_sum(c1, c2, c3, c4)"}, + {expected}); +} + +TEST_F(SimpleVariadicSumAggregationTest, singleVariadicArg) { + // Test global with only 1 variadic argument. + // Input: + // Row 1: dummy=1, a=10 + // Row 2: dummy=1, a=20 + // Expected output: [10+20] = [30] + auto inputVectors = makeRowVector({ + makeFlatVector({1, 1}), + makeFlatVector({10, 20}), + }); + + auto expected = makeRowVector({makeArrayVector({{30}})}); + + testAggregations( + {inputVectors}, {}, {"simple_variadic_sum(c0, c1)"}, {expected}); + + // Test with only 1 variadic argument with grouping. + // Group true: [10+30] = [40] + // Group false: [20+40] = [60] + inputVectors = makeRowVector({ + makeFlatVector({true, false, true, false}), + makeFlatVector({1, 1, 1, 1}), + makeFlatVector({10, 20, 30, 40}), + }); + + expected = makeRowVector({ + makeFlatVector({true, false}), + makeArrayVector({{40}, {60}}), + }); + + testAggregations( + {inputVectors}, {"c0"}, {"simple_variadic_sum(c1, c2)"}, {expected}); +} + +TEST_F(SimpleVariadicSumAggregationTest, noVariadicArg) { + // Test global with no variadic argument. + // Expected output: [] + auto inputVectors = makeRowVector({ + makeFlatVector({1, 1}), + }); + + auto expected = makeRowVector({makeArrayVector({{}})}); + + testAggregations({inputVectors}, {}, {"simple_variadic_sum(c0)"}, {expected}); + + // Test with no variadic argument with grouping. + // Group true: [] + // Group false: [] + inputVectors = makeRowVector({ + makeFlatVector({true, false, true, false}), + makeFlatVector({1, 1, 1, 1}), + }); + + expected = makeRowVector({ + makeFlatVector({true, false}), + makeArrayVector({{}, {}}), + }); + + testAggregations( + {inputVectors}, {"c0"}, {"simple_variadic_sum(c1)"}, {expected}); +} + +class SimpleVariadicArrayAggAggregationTest : public AggregationTestBase { + protected: + void SetUp() override { + AggregationTestBase::SetUp(); + registerSimpleVariadicArrayAggAggregate(kSimpleVariadicArrayAgg); + } +}; + +TEST_F(SimpleVariadicArrayAggAggregationTest, basicVariadicArrayAgg) { + // Test global with 3 variadic arguments: collect all values into a single + // array. Input: + // Row 1: a=1, b=2, c=3 + // Row 2: a=4, b=5, c=6 + // Expected output: [1, 2, 3, 4, 5, 6] + auto inputVectors = makeRowVector({ + makeFlatVector({1, 4}), + makeFlatVector({2, 5}), + makeFlatVector({3, 6}), + }); + + auto expected = + makeRowVector({makeArrayVector({{1, 2, 3, 4, 5, 6}})}); + + testAggregations( + {inputVectors}, + {}, + {"simple_variadic_array_agg(c0, c1, c2)"}, + {"array_sort(a0)"}, + {expected}); + + // Test with grouping. + // Group true: rows 1 and 3 -> [1, 2, 5, 6] + // Group false: rows 2 and 4 -> [3, 4, 7, 8] + inputVectors = makeRowVector({ + makeFlatVector({true, false, true, false}), + makeFlatVector({1, 3, 5, 7}), + makeFlatVector({2, 4, 6, 8}), + }); + + expected = makeRowVector({ + makeFlatVector({false, true}), + makeArrayVector({{3, 4, 7, 8}, {1, 2, 5, 6}}), + }); + + testAggregations( + {inputVectors}, + {"c0"}, + {"simple_variadic_array_agg(c1, c2)"}, + {"c0", "array_sort(a0)"}, + {expected}); +} + +TEST_F(SimpleVariadicArrayAggAggregationTest, variadicArrayAggWithNulls) { + // Test global handling of null values in variadic arguments. + // Nulls should be included in the output array (non-default null behavior). + // Row 1: 1, null, 3 + // Row 2: 4, 5, null + // Expected: [1, 3, 4, 5, null, null] + auto inputVectors = makeRowVector({ + makeNullableFlatVector({1, 4}), + makeNullableFlatVector({std::nullopt, 5}), + makeNullableFlatVector({3, std::nullopt}), + }); + + auto expected = makeRowVector({vectorMaker_.arrayVectorNullable( + {{{1, 3, 4, 5, std::nullopt, std::nullopt}}})}); + + testAggregations( + {inputVectors}, + {}, + {"simple_variadic_array_agg(c0, c1, c2)"}, + {"array_sort(a0)"}, + {expected}); + + // Test with grouping and null values. + // Group true: rows 1, 3 -> [1, null, 3, 5, 7, null] + // Group false: rows 2, 4 -> [2, 4, null, 6, null, 8] + inputVectors = makeRowVector({ + makeFlatVector({true, false, true, false}), + makeNullableFlatVector({1, 2, 5, 6}), + makeNullableFlatVector({std::nullopt, 4, 7, std::nullopt}), + makeNullableFlatVector({3, std::nullopt, std::nullopt, 8}), + }); + + expected = makeRowVector({ + makeFlatVector({false, true}), + vectorMaker_.arrayVectorNullable( + {{{2, 4, 6, 8, std::nullopt, std::nullopt}}, + {{1, 3, 5, 7, std::nullopt, std::nullopt}}}), + }); + + testAggregations( + {inputVectors}, + {"c0"}, + {"simple_variadic_array_agg(c1, c2, c3)"}, + {"c0", "array_sort(a0)"}, + {expected}); +} + +TEST_F(SimpleVariadicArrayAggAggregationTest, variadicArrayAggStrings) { + // Test global with string type to verify Generic works with different + // types. + auto inputVectors = makeRowVector({ + makeFlatVector({"a", "d"}), + makeFlatVector({"b", "e"}), + makeFlatVector({"c", "f"}), + }); + + auto expected = makeRowVector( + {makeArrayVector({{"a", "b", "c", "d", "e", "f"}})}); + + testAggregations( + {inputVectors}, + {}, + {"simple_variadic_array_agg(c0, c1, c2)"}, + {"array_sort(a0)"}, + {expected}); + + // Test with grouping and string type. + // Group true: rows 1, 3 -> ["a", "b", "e", "f"] + // Group false: rows 2, 4 -> ["c", "d", "g", "h"] + inputVectors = makeRowVector({ + makeFlatVector({true, false, true, false}), + makeFlatVector({"a", "c", "e", "g"}), + makeFlatVector({"b", "d", "f", "h"}), + }); + + expected = makeRowVector({ + makeFlatVector({false, true}), + makeArrayVector({{"c", "d", "g", "h"}, {"a", "b", "e", "f"}}), + }); + + testAggregations( + {inputVectors}, + {"c0"}, + {"simple_variadic_array_agg(c1, c2)"}, + {"c0", "array_sort(a0)"}, + {expected}); +} + +TEST_F(SimpleVariadicArrayAggAggregationTest, singleVariadicArg) { + // Test global with only 1 variadic argument. + // Input: + // Row 1: a=10 + // Row 2: a=20 + // Row 3: a=30 + // Expected output: [10, 20, 30] + auto inputVectors = makeRowVector({ + makeFlatVector({10, 20, 30}), + }); + + auto expected = makeRowVector({makeArrayVector({{10, 20, 30}})}); + + testAggregations( + {inputVectors}, + {}, + {"simple_variadic_array_agg(c0)"}, + {"array_sort(a0)"}, + {expected}); + + // Test with only 1 variadic argument with grouping. + // Group true: [10, 30] + // Group false: [20, 40] + inputVectors = makeRowVector({ + makeFlatVector({true, false, true, false}), + makeFlatVector({10, 20, 30, 40}), + }); + + expected = makeRowVector({ + makeFlatVector({false, true}), + makeArrayVector({{20, 40}, {10, 30}}), + }); + + testAggregations( + {inputVectors}, + {"c0"}, + {"simple_variadic_array_agg(c1)"}, + {"c0", "array_sort(a0)"}, + {expected}); +} + } // namespace } // namespace facebook::velox::aggregate::test diff --git a/velox/exec/tests/SimpleAggregateFunctionsRegistration.h b/velox/exec/tests/SimpleAggregateFunctionsRegistration.h index 06095ab92e6..760615bf9d3 100644 --- a/velox/exec/tests/SimpleAggregateFunctionsRegistration.h +++ b/velox/exec/tests/SimpleAggregateFunctionsRegistration.h @@ -28,4 +28,10 @@ exec::AggregateRegistrationResult registerSimpleAverageAggregate( exec::AggregateRegistrationResult registerSimpleArrayAggAggregate( const std::string& name); +exec::AggregateRegistrationResult registerSimpleVariadicSumAggregate( + const std::string& name); + +exec::AggregateRegistrationResult registerSimpleVariadicArrayAggAggregate( + const std::string& name); + } // namespace facebook::velox::aggregate diff --git a/velox/exec/tests/SimpleArrayAggAggregate.cpp b/velox/exec/tests/SimpleArrayAggAggregate.cpp index 40a4f4e583e..ed0c793e0be 100644 --- a/velox/exec/tests/SimpleArrayAggAggregate.cpp +++ b/velox/exec/tests/SimpleArrayAggAggregate.cpp @@ -127,7 +127,7 @@ exec::AggregateRegistrationResult registerSimpleArrayAggAggregate( return exec::registerAggregateFunction( name, - std::move(signatures), + signatures, [name]( core::AggregationNode::Step step, const std::vector& argTypes, diff --git a/velox/exec/tests/SimpleAverageAggregate.cpp b/velox/exec/tests/SimpleAverageAggregate.cpp index fea9254cb1c..ea710380b3b 100644 --- a/velox/exec/tests/SimpleAverageAggregate.cpp +++ b/velox/exec/tests/SimpleAverageAggregate.cpp @@ -34,9 +34,8 @@ class AverageAggregate { using InputType = Row; // Type of intermediate result vector wrapped in Row. - using IntermediateType = - Row; + using IntermediateType = Row; // Type of output vector. using OutputType = @@ -102,18 +101,20 @@ exec::AggregateRegistrationResult registerSimpleAverageAggregate( std::vector> signatures; for (const auto& inputType : {"smallint", "integer", "bigint", "double"}) { - signatures.push_back(exec::AggregateFunctionSignatureBuilder() - .returnType("double") - .intermediateType("row(double,bigint)") - .argumentType(inputType) - .build()); + signatures.push_back( + exec::AggregateFunctionSignatureBuilder() + .returnType("double") + .intermediateType("row(double,bigint)") + .argumentType(inputType) + .build()); } - signatures.push_back(exec::AggregateFunctionSignatureBuilder() - .returnType("real") - .intermediateType("row(double,bigint)") - .argumentType("real") - .build()); + signatures.push_back( + exec::AggregateFunctionSignatureBuilder() + .returnType("real") + .intermediateType("row(double,bigint)") + .argumentType("real") + .build()); return exec::registerAggregateFunction( name, diff --git a/velox/exec/tests/SimpleVariadicArrayAggAggregate.cpp b/velox/exec/tests/SimpleVariadicArrayAggAggregate.cpp new file mode 100644 index 00000000000..32886393536 --- /dev/null +++ b/velox/exec/tests/SimpleVariadicArrayAggAggregate.cpp @@ -0,0 +1,160 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/Aggregate.h" +#include "velox/exec/SimpleAggregateAdapter.h" +#include "velox/expression/FunctionSignature.h" +#include "velox/functions/lib/aggregates/ValueList.h" + +using namespace facebook::velox::exec; + +namespace facebook::velox::aggregate { + +namespace { + +// An aggregate function that demonstrates variadic argument support with +// Generic types and non-default null behavior in SimpleAggregateAdapter. This +// function takes a variadic list of values of the same type and aggregates all +// of them into a single array, including nulls. +// +// Example: +// SELECT variadic_array_agg(a, b, c) FROM ( +// VALUES (1, 2, 3), (4, 5, 6) +// ) AS t(a, b, c) +// => [1, 2, 3, 4, 5, 6] +class VariadicArrayAggAggregate { + public: + using InputType = AggregateInputType>>; + + using IntermediateType = Array>; + + using OutputType = Array>; + + static constexpr bool default_null_behavior_ = false; + + static bool toIntermediate( + exec::out_type>>& out, + exec::optional_arg_type>> variadicArgs) { + if (!variadicArgs.has_value()) { + VELOX_UNREACHABLE( + "simple_variadic_array_agg requires at least one variadic argument."); + } + for (auto i = 0; i < variadicArgs.value().size(); ++i) { + if (variadicArgs->at(i).has_value()) { + out.add_item().copy_from(variadicArgs->at(i).value()); + } else { + out.add_null(); + } + } + return true; + } + + struct AccumulatorType { + ValueList elements_; + + AccumulatorType() = delete; + + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + VariadicArrayAggAggregate* /*fn*/) + : elements_{} {} + + static constexpr bool is_fixed_size_ = false; + + bool addInput( + HashStringAllocator* allocator, + exec::optional_arg_type>> variadicArgs) { + if (!variadicArgs.has_value()) { + VELOX_UNREACHABLE( + "simple_variadic_array_agg requires at least one variadic argument."); + } + for (auto i = 0; i < variadicArgs.value().size(); ++i) { + elements_.appendValue(variadicArgs->at(i), allocator); + } + return true; + } + + bool combine( + HashStringAllocator* allocator, + exec::optional_arg_type>> other) { + if (!other.has_value()) { + return false; + } + for (auto i = 0; i < other.value().size(); ++i) { + elements_.appendValue(other->at(i), allocator); + } + return true; + } + + bool writeFinalResult( + bool nonNullGroup, + exec::out_type>>& out) { + if (!nonNullGroup) { + return false; + } + copyValueListToArrayWriter(out, elements_); + return true; + } + + bool writeIntermediateResult( + bool nonNullGroup, + exec::out_type>>& out) { + if (!nonNullGroup) { + return false; + } + copyValueListToArrayWriter(out, elements_); + return true; + } + + void destroy(HashStringAllocator* allocator) { + elements_.free(allocator); + } + }; +}; + +} // namespace + +exec::AggregateRegistrationResult registerSimpleVariadicArrayAggAggregate( + const std::string& name) { + std::vector> signatures{ + exec::AggregateFunctionSignatureBuilder() + .typeVariable("E") + .returnType("array(E)") + .intermediateType("array(E)") + .argumentType("E") + .variableArity() + .build()}; + + return exec::registerAggregateFunction( + name, + signatures, + [name]( + core::AggregationNode::Step step, + const std::vector& argTypes, + const TypePtr& resultType, + const core::QueryConfig& /*config*/) + -> std::unique_ptr { + VELOX_CHECK_GE( + argTypes.size(), 1, "{} requires at least 1 argument", name); + return std::make_unique< + SimpleAggregateAdapter>( + step, argTypes, resultType); + }, + true /*registerCompanionFunctions*/, + true /*overwrite*/); +} + +} // namespace facebook::velox::aggregate diff --git a/velox/exec/tests/SimpleVariadicSumAggregate.cpp b/velox/exec/tests/SimpleVariadicSumAggregate.cpp new file mode 100644 index 00000000000..b1703e20fa3 --- /dev/null +++ b/velox/exec/tests/SimpleVariadicSumAggregate.cpp @@ -0,0 +1,138 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/Aggregate.h" +#include "velox/exec/SimpleAggregateAdapter.h" +#include "velox/expression/FunctionSignature.h" + +using namespace facebook::velox::exec; + +namespace facebook::velox::aggregate { + +namespace { + +// An aggregate function that demonstrates variadic argument support in +// SimpleAggregateAdapter. This function takes a dummy integer argument followed +// by a variadic list of integers. It returns an array where the i-th element is +// the sum of all i-th variadic arguments across all rows. +// +// Example: +// SELECT variadic_sum_agg(3, a, b, c) FROM ( +// VALUES (1, 2, 3), (4, 5, 6) +// ) AS t(a, b, c) +// => [5, 7, 9] (i.e., [1+4, 2+5, 3+6]) +class VariadicSumAggregate { + public: + // Force the function to take a dummy integer before the variadic arguments to + // test variadic list not starting from the beginning. + using InputType = AggregateInputType>; + + using IntermediateType = Array; + + using OutputType = Array; + + struct AccumulatorType { + std::vector sums_; + + AccumulatorType() = delete; + + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + VariadicSumAggregate* /*fn*/) + : sums_{} {} + + static constexpr bool is_fixed_size_ = false; + static constexpr bool use_external_memory_ = true; + + void addInput( + HashStringAllocator* /*allocator*/, + exec::arg_type /*dummy*/, + exec::arg_type> variadicArgs) { + // Initialize sums_ based on the actual number of variadic arguments. + if (sums_.empty()) { + sums_.resize(variadicArgs.size(), 0); + } + + for (auto i = 0; i < variadicArgs.size(); ++i) { + sums_[i] += variadicArgs.at(i).value(); + } + } + + void combine( + HashStringAllocator* /*allocator*/, + exec::arg_type> other) { + // Initialize sums_ based on the incoming array size if not yet + // initialized. + if (sums_.empty()) { + sums_.resize(other.size(), 0); + } + + // Add element-wise. + for (auto i = 0; i < other.size(); ++i) { + if (other.at(i).has_value()) { + sums_[i] += other.at(i).value(); + } + } + } + + bool writeFinalResult(exec::out_type>& out) { + for (const auto& sum : sums_) { + out.add_item() = sum; + } + return true; + } + + bool writeIntermediateResult(exec::out_type>& out) { + for (const auto& sum : sums_) { + out.add_item() = sum; + } + return true; + } + }; +}; + +} // namespace + +exec::AggregateRegistrationResult registerSimpleVariadicSumAggregate( + const std::string& name) { + std::vector> signatures{ + exec::AggregateFunctionSignatureBuilder() + .returnType("array(bigint)") + .intermediateType("array(bigint)") + .argumentType("bigint") + .argumentType("bigint") + .variableArity() + .build()}; + + return exec::registerAggregateFunction( + name, + signatures, + [name]( + core::AggregationNode::Step step, + const std::vector& argTypes, + const TypePtr& resultType, + const core::QueryConfig& /*config*/) + -> std::unique_ptr { + VELOX_CHECK_GE( + argTypes.size(), 1, "{} requires at least 1 argument", name); + return std::make_unique>( + step, argTypes, resultType); + }, + true /*registerCompanionFunctions*/, + true /*overwrite*/); +} + +} // namespace facebook::velox::aggregate diff --git a/velox/exec/tests/SortBufferTest.cpp b/velox/exec/tests/SortBufferTest.cpp index 93bf32edf4d..684aae78051 100644 --- a/velox/exec/tests/SortBufferTest.cpp +++ b/velox/exec/tests/SortBufferTest.cpp @@ -15,12 +15,13 @@ */ #include "velox/exec/SortBuffer.h" +#include #include #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/FileSystems.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/exec/tests/utils/OperatorTestBase.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/type/Type.h" #include "velox/vector/fuzzer/VectorFuzzer.h" #include "velox/vector/tests/utils/VectorTestBase.h" @@ -31,6 +32,7 @@ using namespace facebook::velox; using namespace facebook::velox::memory; namespace facebook::velox::functions::test { +using namespace facebook::velox::common::testutil; namespace { // Class to write runtime stats in the tests to the stats container. class TestRuntimeStatWriter : public BaseRuntimeStatWriter { @@ -39,7 +41,7 @@ class TestRuntimeStatWriter : public BaseRuntimeStatWriter { std::unordered_map& stats) : stats_{stats} {} - void addRuntimeStat(const std::string& name, const RuntimeCounter& value) + void addRuntimeStat(std::string_view name, const RuntimeCounter& value) override { addOperatorRuntimeStats(name, value, stats_); } @@ -89,6 +91,7 @@ class SortBufferTest : public OperatorTestBase, 0, 0, "none", + 0, spillPrefixSortConfig); } @@ -118,7 +121,9 @@ class SortBufferTest : public OperatorTestBase, const std::shared_ptr executor_{ std::make_shared( - std::thread::hardware_concurrency())}; + folly::available_concurrency())}; + const std::shared_ptr fuzzerPool_ = + memory::memoryManager()->addLeafPool("SortBufferTest"); tsan_atomic nonReclaimableSection_{false}; folly::Random::DefaultGenerator rng_; @@ -199,16 +204,16 @@ TEST_P(SortBufferTest, singleKey) { } if (GetParam()) { ASSERT_EQ( - stats_.at(PrefixSort::kNumPrefixSortKeys).sum, + stats_.at(std::string(PrefixSort::kNumPrefixSortKeys)).sum, sortColumnIndices_.size()); ASSERT_EQ( - stats_.at(PrefixSort::kNumPrefixSortKeys).max, + stats_.at(std::string(PrefixSort::kNumPrefixSortKeys)).max, sortColumnIndices_.size()); ASSERT_EQ( - stats_.at(PrefixSort::kNumPrefixSortKeys).min, + stats_.at(std::string(PrefixSort::kNumPrefixSortKeys)).min, sortColumnIndices_.size()); } else { - ASSERT_EQ(stats_.count(PrefixSort::kNumPrefixSortKeys), 0); + ASSERT_EQ(stats_.count(std::string(PrefixSort::kNumPrefixSortKeys)), 0); } stats_.clear(); } @@ -261,16 +266,16 @@ TEST_P(SortBufferTest, multipleKeys) { ASSERT_EQ(output->childAt(1)->asFlatVector()->valueAt(9), 5); if (GetParam()) { ASSERT_EQ( - stats_.at(PrefixSort::kNumPrefixSortKeys).sum, + stats_.at(std::string(PrefixSort::kNumPrefixSortKeys)).sum, sortColumnIndices_.size()); ASSERT_EQ( - stats_.at(PrefixSort::kNumPrefixSortKeys).max, + stats_.at(std::string(PrefixSort::kNumPrefixSortKeys)).max, sortColumnIndices_.size()); ASSERT_EQ( - stats_.at(PrefixSort::kNumPrefixSortKeys).min, + stats_.at(std::string(PrefixSort::kNumPrefixSortKeys)).min, sortColumnIndices_.size()); } else { - ASSERT_EQ(stats_.count(PrefixSort::kNumPrefixSortKeys), 0); + ASSERT_EQ(stats_.count(std::string(PrefixSort::kNumPrefixSortKeys)), 0); } } @@ -335,13 +340,10 @@ TEST_P(SortBufferTest, DISABLED_randomData) { &nonReclaimableSection_, prefixSortConfig_); - const std::shared_ptr fuzzerPool = - memory::memoryManager()->addLeafPool("VectorFuzzer"); - std::vector inputVectors; inputVectors.reserve(3); for (size_t inputRows : {1000, 1000, 1000}) { - VectorFuzzer fuzzer({.vectorSize = inputRows}, fuzzerPool.get()); + VectorFuzzer fuzzer({.vectorSize = inputRows}, fuzzerPool_.get()); RowVectorPtr input = fuzzer.fuzzRow(inputType_); sortBuffer->addInput(input); inputVectors.push_back(input); @@ -384,7 +386,7 @@ TEST_P(SortBufferTest, batchOutput) { TestScopedSpillInjection scopedSpillInjection(100); for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); auto spillConfig = common::SpillConfig( [&]() -> const std::string& { return spillDirectory->getPath(); }, [&](uint64_t) {}, @@ -401,8 +403,9 @@ TEST_P(SortBufferTest, batchOutput) { 0, 0, "none", + 0, prefixSortConfig_); - folly::Synchronized spillStats; + exec::SpillStats spillStats; auto sortBuffer = std::make_unique( inputType_, sortColumnIndices_, @@ -413,15 +416,11 @@ TEST_P(SortBufferTest, batchOutput) { testData.triggerSpill ? &spillConfig : nullptr, &spillStats); ASSERT_EQ(sortBuffer->canSpill(), testData.triggerSpill); - - const std::shared_ptr fuzzerPool = - memory::memoryManager()->addLeafPool("VectorFuzzer"); - std::vector inputVectors; inputVectors.reserve(testData.numInputRows.size()); uint64_t totalNumInput = 0; for (size_t inputRows : testData.numInputRows) { - VectorFuzzer fuzzer({.vectorSize = inputRows}, fuzzerPool.get()); + VectorFuzzer fuzzer({.vectorSize = inputRows}, fuzzerPool_.get()); RowVectorPtr input = fuzzer.fuzzRow(inputType_); sortBuffer->addInput(input); inputVectors.push_back(input); @@ -438,14 +437,14 @@ TEST_P(SortBufferTest, batchOutput) { } if (!testData.triggerSpill) { - ASSERT_TRUE(spillStats.rlock()->empty()); + ASSERT_TRUE(spillStats.empty()); } else { - ASSERT_FALSE(spillStats.rlock()->empty()); - ASSERT_GT(spillStats.rlock()->spilledRows, 0); - ASSERT_LE(spillStats.rlock()->spilledRows, totalNumInput); - ASSERT_GT(spillStats.rlock()->spilledBytes, 0); - ASSERT_EQ(spillStats.rlock()->spilledPartitions, 1); - ASSERT_GT(spillStats.rlock()->spilledFiles, 0); + ASSERT_FALSE(spillStats.empty()); + ASSERT_GT(spillStats.spilledRows, 0); + ASSERT_LE(spillStats.spilledRows, totalNumInput); + ASSERT_GT(spillStats.spilledBytes, 0); + ASSERT_EQ(spillStats.spilledPartitions, 1); + ASSERT_GT(spillStats.spilledFiles, 0); } } } @@ -476,7 +475,7 @@ TEST_P(SortBufferTest, spill) { for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); // memory pool limit is 20M // Set 'kSpillableReservationGrowthPct' to an extreme large value to trigger // memory reservation failure and thus trigger disk spilling. @@ -498,8 +497,9 @@ TEST_P(SortBufferTest, spill) { 0, 0, "none", + 0, prefixSortConfig_); - folly::Synchronized spillStats; + exec::SpillStats spillStats; auto sortBuffer = std::make_unique( inputType_, sortColumnIndices_, @@ -510,9 +510,7 @@ TEST_P(SortBufferTest, spill) { testData.spillEnabled ? &spillConfig : nullptr, &spillStats); - const std::shared_ptr fuzzerPool = - memory::memoryManager()->addLeafPool("spillSource"); - VectorFuzzer fuzzer({.vectorSize = 1024}, fuzzerPool.get()); + VectorFuzzer fuzzer({.vectorSize = 1024}, fuzzerPool_.get()); uint64_t totalNumInput = 0; ASSERT_EQ(memory::spillMemoryPool()->stats().usedBytes, 0); @@ -528,20 +526,20 @@ TEST_P(SortBufferTest, spill) { sortBuffer->noMoreInput(); if (!testData.spillTriggered) { - ASSERT_TRUE(spillStats.rlock()->empty()); + ASSERT_TRUE(spillStats.empty()); if (!testData.spillEnabled) { VELOX_ASSERT_THROW(sortBuffer->spill(), "spill config is null"); } } else { - ASSERT_FALSE(spillStats.rlock()->empty()); - ASSERT_GT(spillStats.rlock()->spilledRows, 0); - ASSERT_LE(spillStats.rlock()->spilledRows, totalNumInput); - ASSERT_GT(spillStats.rlock()->spilledBytes, 0); - ASSERT_EQ(spillStats.rlock()->spilledPartitions, 1); + ASSERT_FALSE(spillStats.empty()); + ASSERT_GT(spillStats.spilledRows, 0); + ASSERT_LE(spillStats.spilledRows, totalNumInput); + ASSERT_GT(spillStats.spilledBytes, 0); + ASSERT_EQ(spillStats.spilledPartitions, 1); // SortBuffer shall not respect maxFileSize. Total files should be num // addInput() calls minus one which is the first one that has nothing to // spill. - ASSERT_EQ(spillStats.rlock()->spilledFiles, 3); + ASSERT_EQ(spillStats.spilledFiles, 3); sortBuffer.reset(); ASSERT_EQ(memory::spillMemoryPool()->stats().usedBytes, 0); if (memory::spillMemoryPool()->trackUsage()) { @@ -552,25 +550,25 @@ TEST_P(SortBufferTest, spill) { } if (GetParam()) { ASSERT_GE( - stats_.at(PrefixSort::kNumPrefixSortKeys).sum, + stats_.at(std::string(PrefixSort::kNumPrefixSortKeys)).sum, sortColumnIndices_.size()); ASSERT_EQ( - stats_.at(PrefixSort::kNumPrefixSortKeys).max, + stats_.at(std::string(PrefixSort::kNumPrefixSortKeys)).max, sortColumnIndices_.size()); ASSERT_EQ( - stats_.at(PrefixSort::kNumPrefixSortKeys).min, + stats_.at(std::string(PrefixSort::kNumPrefixSortKeys)).min, sortColumnIndices_.size()); } else { - ASSERT_EQ(stats_.count(PrefixSort::kNumPrefixSortKeys), 0); + ASSERT_EQ(stats_.count(std::string(PrefixSort::kNumPrefixSortKeys)), 0); } stats_.clear(); } } DEBUG_ONLY_TEST_P(SortBufferTest, spillDuringInput) { - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); const auto spillConfig = getSpillConfig(spillDirectory->getPath()); - folly::Synchronized spillStats; + exec::SpillStats spillStats; auto sortBuffer = std::make_unique( inputType_, sortColumnIndices_, @@ -595,9 +593,7 @@ DEBUG_ONLY_TEST_P(SortBufferTest, spillDuringInput) { ASSERT_EQ(sortBuffer->pool()->usedBytes(), 0); }))); - const std::shared_ptr fuzzerPool = - memory::memoryManager()->addLeafPool("spillDuringInput"); - VectorFuzzer fuzzer({.vectorSize = 1024}, fuzzerPool.get()); + VectorFuzzer fuzzer({.vectorSize = 1024}, fuzzerPool_.get()); ASSERT_EQ(memory::spillMemoryPool()->stats().usedBytes, 0); const auto peakSpillMemoryUsage = @@ -608,12 +604,12 @@ DEBUG_ONLY_TEST_P(SortBufferTest, spillDuringInput) { } sortBuffer->noMoreInput(); - ASSERT_FALSE(spillStats.rlock()->empty()); - ASSERT_GT(spillStats.rlock()->spilledRows, 0); - ASSERT_EQ(spillStats.rlock()->spilledRows, numInputs * 1024); - ASSERT_GT(spillStats.rlock()->spilledBytes, 0); - ASSERT_EQ(spillStats.rlock()->spilledPartitions, 1); - ASSERT_EQ(spillStats.rlock()->spilledFiles, 2); + ASSERT_FALSE(spillStats.empty()); + ASSERT_GT(spillStats.spilledRows, 0); + ASSERT_EQ(spillStats.spilledRows, numInputs * 1024); + ASSERT_GT(spillStats.spilledBytes, 0); + ASSERT_EQ(spillStats.spilledPartitions, 1); + ASSERT_EQ(spillStats.spilledFiles, 2); ASSERT_EQ(memory::spillMemoryPool()->stats().usedBytes, 0); if (memory::spillMemoryPool()->trackUsage()) { @@ -621,12 +617,13 @@ DEBUG_ONLY_TEST_P(SortBufferTest, spillDuringInput) { ASSERT_GE( memory::spillMemoryPool()->stats().peakBytes, peakSpillMemoryUsage); } + ASSERT_EQ(pool_->usedBytes(), 0); } DEBUG_ONLY_TEST_P(SortBufferTest, spillDuringOutput) { - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); const auto spillConfig = getSpillConfig(spillDirectory->getPath()); - folly::Synchronized spillStats; + exec::SpillStats spillStats; auto sortBuffer = std::make_unique( inputType_, sortColumnIndices_, @@ -645,10 +642,7 @@ DEBUG_ONLY_TEST_P(SortBufferTest, spillDuringOutput) { sortBuffer->spill(); ASSERT_EQ(sortBuffer->pool()->usedBytes(), 0); }))); - - const std::shared_ptr fuzzerPool = - memory::memoryManager()->addLeafPool("spillDuringOutput"); - VectorFuzzer fuzzer({.vectorSize = 1024}, fuzzerPool.get()); + VectorFuzzer fuzzer({.vectorSize = 1024}, fuzzerPool_.get()); ASSERT_EQ(memory::spillMemoryPool()->stats().usedBytes, 0); const auto peakSpillMemoryUsage = @@ -659,12 +653,12 @@ DEBUG_ONLY_TEST_P(SortBufferTest, spillDuringOutput) { } sortBuffer->noMoreInput(); - ASSERT_FALSE(spillStats.rlock()->empty()); - ASSERT_GT(spillStats.rlock()->spilledRows, 0); - ASSERT_EQ(spillStats.rlock()->spilledRows, numInputs * 1024); - ASSERT_GT(spillStats.rlock()->spilledBytes, 0); - ASSERT_EQ(spillStats.rlock()->spilledPartitions, 1); - ASSERT_EQ(spillStats.rlock()->spilledFiles, 1); + ASSERT_FALSE(spillStats.empty()); + ASSERT_GT(spillStats.spilledRows, 0); + ASSERT_EQ(spillStats.spilledRows, numInputs * 1024); + ASSERT_GT(spillStats.spilledBytes, 0); + ASSERT_EQ(spillStats.spilledPartitions, 1); + ASSERT_EQ(spillStats.spilledFiles, 1); ASSERT_EQ(memory::spillMemoryPool()->stats().usedBytes, 0); if (memory::spillMemoryPool()->trackUsage()) { @@ -678,9 +672,9 @@ DEBUG_ONLY_TEST_P(SortBufferTest, reserveMemorySortGetOutput) { for (bool spillEnabled : {false, true}) { SCOPED_TRACE(fmt::format("spillEnabled {}", spillEnabled)); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); const auto spillConfig = getSpillConfig(spillDirectory->getPath()); - folly::Synchronized spillStats; + exec::SpillStats spillStats; auto sortBuffer = std::make_unique( inputType_, sortColumnIndices_, @@ -690,10 +684,7 @@ DEBUG_ONLY_TEST_P(SortBufferTest, reserveMemorySortGetOutput) { prefixSortConfig_, spillEnabled ? &spillConfig : nullptr, &spillStats); - - const std::shared_ptr fuzzerPool = - memory::memoryManager()->addLeafPool("reserveMemoryGetOutput"); - VectorFuzzer fuzzer({.vectorSize = 1024}, fuzzerPool.get()); + VectorFuzzer fuzzer({.vectorSize = 1024}, fuzzerPool_.get()); const int numInputs{10}; for (int i = 0; i < numInputs; ++i) { @@ -735,11 +726,14 @@ DEBUG_ONLY_TEST_P(SortBufferTest, reserveMemorySort) { } testSettings[] = {{false, true}, {true, false}, {true, true}}; for (const auto [usePrefixSort, spillEnabled] : testSettings) { - SCOPED_TRACE(fmt::format( - "usePrefixSort: {}, spillEnabled: {}, ", usePrefixSort, spillEnabled)); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + SCOPED_TRACE( + fmt::format( + "usePrefixSort: {}, spillEnabled: {}, ", + usePrefixSort, + spillEnabled)); + auto spillDirectory = TempDirectoryPath::create(); auto spillConfig = getSpillConfig(spillDirectory->getPath(), usePrefixSort); - folly::Synchronized spillStats; + exec::SpillStats spillStats; auto sortBuffer = std::make_unique( inputType_, sortColumnIndices_, @@ -777,14 +771,11 @@ DEBUG_ONLY_TEST_P(SortBufferTest, reserveMemorySort) { } TEST_P(SortBufferTest, emptySpill) { - const std::shared_ptr fuzzerPool = - memory::memoryManager()->addLeafPool("emptySpillSource"); - for (bool hasPostSpillData : {false, true}) { SCOPED_TRACE(fmt::format("hasPostSpillData {}", hasPostSpillData)); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); auto spillConfig = getSpillConfig(spillDirectory->getPath()); - folly::Synchronized spillStats; + exec::SpillStats spillStats; auto sortBuffer = std::make_unique( inputType_, sortColumnIndices_, @@ -797,11 +788,11 @@ TEST_P(SortBufferTest, emptySpill) { sortBuffer->spill(); if (hasPostSpillData) { - VectorFuzzer fuzzer({.vectorSize = 1024}, fuzzerPool.get()); + VectorFuzzer fuzzer({.vectorSize = 1024}, fuzzerPool_.get()); sortBuffer->addInput(fuzzer.fuzzRow(inputType_)); } sortBuffer->noMoreInput(); - ASSERT_TRUE(spillStats.rlock()->empty()); + ASSERT_TRUE(spillStats.empty()); } } diff --git a/velox/exec/tests/SpatialIndexTest.cpp b/velox/exec/tests/SpatialIndexTest.cpp new file mode 100644 index 00000000000..e2a7778cd52 --- /dev/null +++ b/velox/exec/tests/SpatialIndexTest.cpp @@ -0,0 +1,550 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/SpatialIndex.h" +#include +#include + +using namespace ::testing; +using namespace facebook::velox::exec; + +namespace facebook::velox::exec::test { + +class SpatialIndexTest : public virtual testing::Test { + protected: + void makeIndex( + std::vector envelopes, + uint32_t branchSize = SpatialIndex::kDefaultRTreeBranchSize) { + branchSize_ = branchSize; + Envelope bounds = Envelope::of(envelopes); + index_ = SpatialIndex(std::move(bounds), std::move(envelopes), branchSize); + } + + Envelope indexBounds() const { + return index_.bounds(); + } + + void assertQuery( + double minX, + double minY, + double maxX, + double maxY, + std::vector expected) const { + std::vector actual = + index_.query(Envelope::from(minX, minY, maxX, maxY)); + std::sort(actual.begin(), actual.end()); + std::sort(expected.begin(), expected.end()); + ASSERT_EQ(actual, expected); + } + + SpatialIndex index_; + uint32_t branchSize_ = SpatialIndex::kDefaultRTreeBranchSize; +}; + +TEST_F(SpatialIndexTest, testEnvelope) { + Envelope empty = Envelope::empty(); + ASSERT_TRUE(empty.isEmpty()); + ASSERT_FALSE(Envelope::intersects(empty, empty)); + + Envelope point = + Envelope{.minX = 0, .minY = 0, .maxX = 0, .maxY = 0, .rowIndex = -1}; + ASSERT_FALSE(point.isEmpty()); + ASSERT_FALSE(Envelope::intersects(empty, point)); + ASSERT_TRUE(Envelope::intersects(point, point)); +} + +TEST_F(SpatialIndexTest, testNaNHandling) { + float nan = std::numeric_limits::quiet_NaN(); + + Envelope envWithNaN{ + .minX = nan, .minY = 0, .maxX = 1, .maxY = 1, .rowIndex = 0}; + ASSERT_TRUE(envWithNaN.isEmpty()); + + Envelope envWithNaN2{ + .minX = 0, .minY = 0, .maxX = nan, .maxY = 1, .rowIndex = 0}; + ASSERT_TRUE(envWithNaN2.isEmpty()); + + Envelope envWithNaN3{ + .minX = 0, .minY = nan, .maxX = 1, .maxY = 1, .rowIndex = 0}; + ASSERT_TRUE(envWithNaN3.isEmpty()); + + Envelope envWithNaN4{ + .minX = 0, .minY = 0, .maxX = 1, .maxY = nan, .rowIndex = 0}; + ASSERT_TRUE(envWithNaN4.isEmpty()); + + Envelope envWithNaN5{ + .minX = nan, .minY = nan, .maxX = nan, .maxY = nan, .rowIndex = 0}; + ASSERT_TRUE(envWithNaN5.isEmpty()); +} + +TEST_F(SpatialIndexTest, testEmptyIndex) { + makeIndex(std::vector{}); + Envelope bounds = indexBounds(); + ASSERT_EQ(bounds.minX, std::numeric_limits::infinity()); + ASSERT_EQ(bounds.minY, std::numeric_limits::infinity()); + ASSERT_EQ(bounds.maxX, -std::numeric_limits::infinity()); + ASSERT_EQ(bounds.maxY, -std::numeric_limits::infinity()); + ASSERT_EQ(bounds.rowIndex, -1); + + assertQuery(0, 0, 1, 1, {}); +} + +TEST_F(SpatialIndexTest, testSingleEnvelope) { + makeIndex( + std::vector{Envelope{ + .minX = 1, .minY = 11, .maxX = 2, .maxY = 12, .rowIndex = 0}}); + + Envelope bounds = indexBounds(); + ASSERT_EQ(bounds.minX, 1); + ASSERT_EQ(bounds.minY, 11); + ASSERT_EQ(bounds.maxX, 2); + ASSERT_EQ(bounds.maxY, 12); + + assertQuery(1.5, 11.5, 1.5, 11.5, {0}); + assertQuery(0.5, 10.5, 1.5, 11.5, {0}); + assertQuery(0, 10, 0.5, 10.5, {}); + assertQuery(3, 13, 4, 14, {}); +} + +TEST_F(SpatialIndexTest, testPointProbe) { + makeIndex( + std::vector{ + Envelope{.minX = 1, .minY = 0, .maxX = 1, .maxY = 0, .rowIndex = 6}, + Envelope{.minX = 0, .minY = 0, .maxX = 0, .maxY = 0, .rowIndex = 5}, + Envelope{.minX = 0, .minY = 0, .maxX = 1, .maxY = 1, .rowIndex = 4}, + Envelope{.minX = -1, .minY = -1, .maxX = 0, .maxY = 0, .rowIndex = 3}, + Envelope{.minX = -1, .minY = -1, .maxX = 1, .maxY = 1, .rowIndex = 2}, + Envelope{ + .minX = 0.5, .minY = 0.5, .maxX = 1, .maxY = 1, .rowIndex = 1}, + }); + Envelope bounds = indexBounds(); + ASSERT_EQ(bounds.minX, -1); + ASSERT_EQ(bounds.minY, -1); + ASSERT_EQ(bounds.maxX, 1); + ASSERT_EQ(bounds.maxY, 1); + ASSERT_EQ(bounds.rowIndex, -1); + + assertQuery(0, 0, 0, 0, {2, 3, 4, 5}); + assertQuery(0, 1, 0, 1, {2, 4}); +} + +TEST_F(SpatialIndexTest, testFloatImprecision) { + // Since the index casts doubles to floats then nudges the result, + // we should make sure that the index gives the right results on + // cases where the double doesn't have an exact float representation. + float float1 = 1.0f; + float float1Down = + std::nextafterf(float1, -std::numeric_limits::infinity()); + float float2 = 2.0f; + float float2Up = + std::nextafterf(float2, std::numeric_limits::infinity()); + + double baseMax = static_cast(float2); + double baseMaxUp = + std::nextafter(baseMax, std::numeric_limits::infinity()); + double baseMaxDown = + std::nextafter(baseMax, -std::numeric_limits::infinity()); + double baseMin = static_cast(float1); + double baseMinUp = + std::nextafter(baseMin, std::numeric_limits::infinity()); + double baseMinDown = + std::nextafter(baseMin, -std::numeric_limits::infinity()); + + makeIndex( + std::vector{ + Envelope::from(baseMin, baseMin, baseMax, baseMax, 1), + Envelope::from(baseMinUp, baseMinUp, baseMaxUp, baseMaxUp, 2), + Envelope::from(baseMinDown, baseMinDown, baseMaxDown, baseMaxDown, 3), + }); + + Envelope bounds = indexBounds(); + ASSERT_EQ(bounds.minX, float1Down); + ASSERT_EQ(bounds.minY, float1Down); + ASSERT_EQ(bounds.maxX, float2Up); + ASSERT_EQ(bounds.maxY, float2Up); + + assertQuery(2.1, 2.1, 2.1, 2.1, {}); + assertQuery(baseMin, baseMin, baseMin, baseMin, {1, 2, 3}); + assertQuery(baseMinDown, baseMinDown, baseMinDown, baseMinDown, {1, 2, 3}); + assertQuery(baseMinUp, baseMinUp, baseMinUp, baseMinUp, {1, 2, 3}); + assertQuery(baseMax, baseMax, baseMax, baseMax, {1, 2, 3}); + assertQuery(baseMaxDown, baseMaxDown, baseMaxDown, baseMaxDown, {1, 2, 3}); + assertQuery(baseMaxUp, baseMaxUp, baseMaxUp, baseMaxUp, {1, 2, 3}); +} + +TEST_F(SpatialIndexTest, testFloatImprecisionSubnormal) { + // Check that our bumping rules work for subnormal floats as well. + float subnormalFloatDown = + std::nextafterf(0.0, -std::numeric_limits::infinity()); + float subnormalFloatUp = + std::nextafterf(0.0, std::numeric_limits::infinity()); + + double subnormalDoubleDown = + std::nextafter(0.0, -std::numeric_limits::infinity()); + double subnormalDoubleUp = + std::nextafter(0.0, std::numeric_limits::infinity()); + + makeIndex( + std::vector{ + Envelope::from(0.0, 0.0, 0.0, 0.0, 1), + Envelope::from( + subnormalDoubleDown, + subnormalDoubleDown, + subnormalDoubleDown, + subnormalDoubleDown, + 2), + Envelope::from( + subnormalDoubleUp, + subnormalDoubleUp, + subnormalDoubleUp, + subnormalDoubleUp, + 3), + Envelope::from( + subnormalDoubleDown, + subnormalDoubleDown, + subnormalDoubleUp, + subnormalDoubleUp, + 4), + }); + + Envelope bounds = indexBounds(); + ASSERT_EQ(bounds.minX, subnormalFloatDown); + ASSERT_EQ(bounds.minY, subnormalFloatDown); + ASSERT_EQ(bounds.maxX, subnormalFloatUp); + ASSERT_EQ(bounds.maxY, subnormalFloatUp); + + assertQuery(0.1, 0.1, 0.1, 0.1, {}); + assertQuery(0.0, 0.0, 0.0, 0.0, {1, 2, 3, 4}); + assertQuery( + subnormalDoubleDown, + subnormalDoubleDown, + subnormalDoubleDown, + subnormalDoubleDown, + {1, 2, 3, 4}); + assertQuery( + subnormalDoubleUp, + subnormalDoubleUp, + subnormalDoubleUp, + subnormalDoubleUp, + {1, 2, 3, 4}); + assertQuery( + subnormalDoubleDown, + subnormalDoubleDown, + subnormalDoubleUp, + subnormalDoubleUp, + {1, 2, 3, 4}); +} + +TEST_F(SpatialIndexTest, testNegativeCoordinates) { + makeIndex( + std::vector{ + Envelope{ + .minX = -5, .minY = -5, .maxX = -1, .maxY = -1, .rowIndex = 0}, + Envelope{ + .minX = -10, .minY = -10, .maxX = -6, .maxY = -6, .rowIndex = 1}, + Envelope{ + .minX = -3, .minY = -8, .maxX = 2, .maxY = -4, .rowIndex = 2}}); + + Envelope bounds = indexBounds(); + ASSERT_EQ(bounds.minX, -10); + ASSERT_EQ(bounds.minY, -10); + ASSERT_EQ(bounds.maxX, 2); + ASSERT_EQ(bounds.maxY, -1); + + assertQuery(-7, -7, -7, -7, {1}); + assertQuery(-2, -5, -2, -5, {0, 2}); + assertQuery(0, 0, 1, 1, {}); +} + +TEST_F(SpatialIndexTest, testOverlappingEnvelopes) { + makeIndex( + std::vector{ + Envelope{.minX = 0, .minY = 0, .maxX = 10, .maxY = 10, .rowIndex = 0}, + Envelope{.minX = 5, .minY = 5, .maxX = 15, .maxY = 15, .rowIndex = 1}, + Envelope{.minX = 2, .minY = 2, .maxX = 8, .maxY = 8, .rowIndex = 2}, + Envelope{ + .minX = 7, .minY = 7, .maxX = 12, .maxY = 12, .rowIndex = 3}}); + + assertQuery(6, 6, 6, 6, {0, 1, 2}); + assertQuery(8, 8, 8, 8, {0, 1, 2, 3}); + assertQuery(9, 9, 9, 9, {0, 1, 3}); + assertQuery(3, 3, 3, 3, {0, 2}); + assertQuery(13, 13, 13, 13, {1}); +} + +TEST_F(SpatialIndexTest, testNonOverlappingEnvelopes) { + makeIndex( + std::vector{ + Envelope{.minX = 0, .minY = 0, .maxX = 1, .maxY = 1, .rowIndex = 0}, + Envelope{.minX = 2, .minY = 2, .maxX = 3, .maxY = 3, .rowIndex = 1}, + Envelope{.minX = 4, .minY = 4, .maxX = 5, .maxY = 5, .rowIndex = 2}, + Envelope{.minX = 6, .minY = 6, .maxX = 7, .maxY = 7, .rowIndex = 3}}); + + assertQuery(0.5, 0.5, 0.5, 0.5, {0}); + assertQuery(2.5, 2.5, 2.5, 2.5, {1}); + assertQuery(4.5, 4.5, 4.5, 4.5, {2}); + assertQuery(6.5, 6.5, 6.5, 6.5, {3}); + assertQuery(1.5, 1.5, 1.5, 1.5, {}); +} + +TEST_F(SpatialIndexTest, testLargeQueryEnvelope) { + makeIndex( + std::vector{ + Envelope{.minX = 1, .minY = 1, .maxX = 2, .maxY = 2, .rowIndex = 0}, + Envelope{.minX = 3, .minY = 3, .maxX = 4, .maxY = 4, .rowIndex = 1}, + Envelope{.minX = 5, .minY = 5, .maxX = 6, .maxY = 6, .rowIndex = 2}}); + + assertQuery(0, 0, 10, 10, {0, 1, 2}); + assertQuery(-100, -100, 100, 100, {0, 1, 2}); +} + +TEST_F(SpatialIndexTest, testSmallQueryEnvelope) { + makeIndex( + std::vector{ + Envelope{ + .minX = 0, .minY = 0, .maxX = 100, .maxY = 100, .rowIndex = 0}, + Envelope{ + .minX = 50, + .minY = 50, + .maxX = 150, + .maxY = 150, + .rowIndex = 1}}); + + assertQuery(25, 25, 26, 26, {0}); + assertQuery(75, 75, 76, 76, {0, 1}); + assertQuery(125, 125, 126, 126, {1}); + assertQuery(0.1, 0.1, 0.2, 0.2, {0}); +} + +TEST_F(SpatialIndexTest, testEdgeTouching) { + makeIndex( + std::vector{ + Envelope{.minX = 0, .minY = 0, .maxX = 5, .maxY = 5, .rowIndex = 0}, + Envelope{.minX = 5, .minY = 0, .maxX = 10, .maxY = 5, .rowIndex = 1}, + Envelope{.minX = 0, .minY = 5, .maxX = 5, .maxY = 10, .rowIndex = 2}, + Envelope{ + .minX = 5, .minY = 5, .maxX = 10, .maxY = 10, .rowIndex = 3}}); + + assertQuery(5, 5, 5, 5, {0, 1, 2, 3}); + assertQuery(5, 2, 5, 2, {0, 1}); + assertQuery(2, 5, 2, 5, {0, 2}); +} + +TEST_F(SpatialIndexTest, testCornerTouching) { + makeIndex( + std::vector{ + Envelope{.minX = 0, .minY = 0, .maxX = 5, .maxY = 5, .rowIndex = 0}, + Envelope{ + .minX = 5, .minY = 5, .maxX = 10, .maxY = 10, .rowIndex = 1}}); + + assertQuery(5, 5, 5, 5, {0, 1}); + assertQuery(4.9, 4.9, 5.1, 5.1, {0, 1}); +} + +TEST_F(SpatialIndexTest, testInfiniteValues) { + float inf = std::numeric_limits::infinity(); + float negInf = -std::numeric_limits::infinity(); + + makeIndex( + std::vector{ + Envelope{.minX = 0, .minY = 0, .maxX = 1, .maxY = 1, .rowIndex = 0}}); + + assertQuery(inf, inf, inf, inf, {}); + assertQuery(negInf, negInf, negInf, negInf, {}); + assertQuery(negInf, negInf, inf, inf, {0}); +} + +TEST_F(SpatialIndexTest, testLargeDataset) { + std::vector envelopes; + envelopes.reserve(1000); + for (int i = 0; i < 1000; ++i) { + envelopes.push_back( + Envelope{ + .minX = static_cast(i), + .minY = static_cast(i), + .maxX = static_cast(i + 1), + .maxY = static_cast(i + 1), + .rowIndex = i}); + } + makeIndex(std::move(envelopes)); + + assertQuery(500.5, 500.5, 500.5, 500.5, {500}); + assertQuery(100.5, 100.5, 104.5, 104.5, {100, 101, 102, 103, 104}); + assertQuery(-1, -1, -1, -1, {}); + assertQuery(1001, 1001, 1001, 1001, {}); + + Envelope bounds = indexBounds(); + ASSERT_EQ(bounds.minX, 0); + ASSERT_EQ(bounds.minY, 0); + ASSERT_EQ(bounds.maxX, 1000); + ASSERT_EQ(bounds.maxY, 1000); +} + +TEST_F(SpatialIndexTest, testVeryLargeCoordinates) { + float largeVal = 1e20f; + makeIndex( + std::vector{ + Envelope{ + .minX = -largeVal, + .minY = -largeVal, + .maxX = -largeVal + 1, + .maxY = -largeVal + 1, + .rowIndex = 0}, + Envelope{ + .minX = largeVal - 1, + .minY = largeVal - 1, + .maxX = largeVal, + .maxY = largeVal, + .rowIndex = 1}}); + + assertQuery( + -largeVal + 0.5, -largeVal + 0.5, -largeVal + 0.5, -largeVal + 0.5, {0}); + assertQuery( + largeVal - 0.5, largeVal - 0.5, largeVal - 0.5, largeVal - 0.5, {1}); + assertQuery(0, 0, 0, 0, {}); +} + +TEST_F(SpatialIndexTest, testQueryOutsideBounds) { + makeIndex( + std::vector{ + Envelope{.minX = 0, .minY = 0, .maxX = 10, .maxY = 10, .rowIndex = 0}, + Envelope{ + .minX = 5, .minY = 5, .maxX = 15, .maxY = 15, .rowIndex = 1}}); + + assertQuery(-10, -10, -5, -5, {}); + assertQuery(20, 20, 25, 25, {}); + assertQuery(-10, 5, -5, 10, {}); + assertQuery(5, 20, 10, 25, {}); +} + +TEST_F(SpatialIndexTest, testPartialOverlap) { + makeIndex( + std::vector{Envelope{ + .minX = 0, .minY = 0, .maxX = 10, .maxY = 10, .rowIndex = 0}}); + + assertQuery(-5, -5, 5, 5, {0}); + assertQuery(5, -5, 15, 5, {0}); + assertQuery(-5, 5, 5, 15, {0}); + assertQuery(5, 5, 15, 15, {0}); +} + +TEST_F(SpatialIndexTest, testMixedSizeEnvelopes) { + makeIndex( + std::vector{ + Envelope{ + .minX = 0, .minY = 0, .maxX = 0.1, .maxY = 0.1, .rowIndex = 0}, + Envelope{ + .minX = 1, .minY = 1, .maxX = 100, .maxY = 100, .rowIndex = 1}, + Envelope{ + .minX = 50, .minY = 50, .maxX = 51, .maxY = 51, .rowIndex = 2}}); + + assertQuery(0.05, 0.05, 0.05, 0.05, {0}); + assertQuery(50, 50, 100, 100, {1, 2}); + assertQuery(25, 25, 25, 25, {1}); +} + +TEST_F(SpatialIndexTest, testZeroAreaEnvelopes) { + makeIndex( + std::vector{ + Envelope{.minX = 0, .minY = 0, .maxX = 0, .maxY = 0, .rowIndex = 0}, + Envelope{.minX = 1, .minY = 1, .maxX = 1, .maxY = 1, .rowIndex = 1}, + Envelope{.minX = 2, .minY = 2, .maxX = 2, .maxY = 2, .rowIndex = 2}}); + + assertQuery(0, 0, 0, 0, {0}); + assertQuery(1, 1, 1, 1, {1}); + assertQuery(2, 2, 2, 2, {2}); + assertQuery(0.5, 0.5, 0.5, 0.5, {}); + assertQuery(0, 0, 2, 2, {0, 1, 2}); +} + +TEST_F(SpatialIndexTest, testIdenticalEnvelopes) { + makeIndex( + std::vector{ + Envelope{.minX = 5, .minY = 5, .maxX = 10, .maxY = 10, .rowIndex = 0}, + Envelope{.minX = 5, .minY = 5, .maxX = 10, .maxY = 10, .rowIndex = 1}, + Envelope{ + .minX = 5, .minY = 5, .maxX = 10, .maxY = 10, .rowIndex = 2}}); + + assertQuery(7, 7, 7, 7, {0, 1, 2}); + assertQuery(5, 5, 10, 10, {0, 1, 2}); + assertQuery(4, 4, 4, 4, {}); +} + +TEST_F(SpatialIndexTest, testDifferentBranchSizes) { + std::vector branchSizes = {2, 3, 4, 8, 16, 32, 64, 128, 256}; + + for (uint32_t branchSize : branchSizes) { + std::vector envelopes; + envelopes.reserve(100); + for (int i = 0; i < 100; ++i) { + envelopes.push_back( + Envelope{ + .minX = static_cast(i), + .minY = static_cast(i), + .maxX = static_cast(i + 1), + .maxY = static_cast(i + 1), + .rowIndex = i}); + } + makeIndex(std::move(envelopes), branchSize); + + assertQuery(50.5, 50.5, 50.5, 50.5, {50}); + assertQuery(10.5, 10.5, 14.5, 14.5, {10, 11, 12, 13, 14}); + assertQuery(-1, -1, -1, -1, {}); + assertQuery(101, 101, 101, 101, {}); + + Envelope bounds = indexBounds(); + ASSERT_EQ(bounds.minX, 0); + ASSERT_EQ(bounds.minY, 0); + ASSERT_EQ(bounds.maxX, 100); + ASSERT_EQ(bounds.maxY, 100); + } +} + +TEST_F(SpatialIndexTest, testSmallBranchSize) { + makeIndex( + std::vector{ + Envelope{.minX = 0, .minY = 0, .maxX = 10, .maxY = 10, .rowIndex = 0}, + Envelope{.minX = 5, .minY = 5, .maxX = 15, .maxY = 15, .rowIndex = 1}, + Envelope{.minX = 2, .minY = 2, .maxX = 8, .maxY = 8, .rowIndex = 2}, + Envelope{ + .minX = 7, .minY = 7, .maxX = 12, .maxY = 12, .rowIndex = 3}}, + 2); + + assertQuery(6, 6, 6, 6, {0, 1, 2}); + assertQuery(8, 8, 8, 8, {0, 1, 2, 3}); + assertQuery(9, 9, 9, 9, {0, 1, 3}); + assertQuery(3, 3, 3, 3, {0, 2}); + assertQuery(13, 13, 13, 13, {1}); +} + +TEST_F(SpatialIndexTest, testLargeBranchSize) { + makeIndex( + std::vector{ + Envelope{.minX = 0, .minY = 0, .maxX = 10, .maxY = 10, .rowIndex = 0}, + Envelope{.minX = 5, .minY = 5, .maxX = 15, .maxY = 15, .rowIndex = 1}, + Envelope{.minX = 2, .minY = 2, .maxX = 8, .maxY = 8, .rowIndex = 2}, + Envelope{ + .minX = 7, .minY = 7, .maxX = 12, .maxY = 12, .rowIndex = 3}}, + 512); + + assertQuery(6, 6, 6, 6, {0, 1, 2}); + assertQuery(8, 8, 8, 8, {0, 1, 2, 3}); + assertQuery(9, 9, 9, 9, {0, 1, 3}); + assertQuery(3, 3, 3, 3, {0, 2}); + assertQuery(13, 13, 13, 13, {1}); +} + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/SpatialJoinTest.cpp b/velox/exec/tests/SpatialJoinTest.cpp index 539f132c6e6..01606121a40 100644 --- a/velox/exec/tests/SpatialJoinTest.cpp +++ b/velox/exec/tests/SpatialJoinTest.cpp @@ -16,6 +16,7 @@ #include "velox/common/base/tests/GTestUtils.h" #include "velox/core/PlanFragment.h" #include "velox/core/PlanNode.h" +#include "velox/core/QueryConfig.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/OperatorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" @@ -67,56 +68,40 @@ class SpatialJoinTest : public OperatorTestBase { void runTest( const std::vector>& probeWkts, const std::vector>& buildWkts, + const std::optional>>& radiiOpt, const std::string& predicate, core::JoinType joinType, const std::vector>& expectedLeftWkts, const std::vector>& expectedRightWkts) { - runTestWithDrivers( - probeWkts, - buildWkts, - predicate, - joinType, - expectedLeftWkts, - expectedRightWkts, - 1, - false); - runTestWithDrivers( - probeWkts, - buildWkts, - predicate, - joinType, - expectedLeftWkts, - expectedRightWkts, - 1, - true); - runTestWithDrivers( - probeWkts, - buildWkts, - predicate, - joinType, - expectedLeftWkts, - expectedRightWkts, - 4, - false); - runTestWithDrivers( - probeWkts, - buildWkts, - predicate, - joinType, - expectedLeftWkts, - expectedRightWkts, - 4, - true); + for (bool separateProbeBatches : {false, true}) { + for (size_t maxBatchSize : {128, 3, 2, 1}) { + for (int32_t maxDrivers : {1, 4}) { + runTestWithConfig( + probeWkts, + buildWkts, + radiiOpt, + predicate, + joinType, + expectedLeftWkts, + expectedRightWkts, + maxDrivers, + maxBatchSize, + separateProbeBatches); + } + } + } } - void runTestWithDrivers( + void runTestWithConfig( const std::vector>& probeWkts, const std::vector>& buildWkts, + const std::optional>>& radiiOpt, const std::string& predicate, core::JoinType joinType, const std::vector>& expectedLeftWkts, const std::vector>& expectedRightWkts, int32_t maxDrivers, + size_t maxBatchSize, bool separateBatches) { std::vector> probeWktsStr( probeWkts.begin(), probeWkts.end()); @@ -126,6 +111,13 @@ class SpatialJoinTest : public OperatorTestBase { expectedLeftWkts.begin(), expectedLeftWkts.end()); std::vector> expectedRightWktsStr( expectedRightWkts.begin(), expectedRightWkts.end()); + auto radii = radiiOpt.value_or( + std::vector>(buildWkts.size(), std::nullopt)); + VELOX_CHECK_EQ(radii.size(), buildWkts.size()); + std::optional radiusVariable = std::nullopt; + if (radiiOpt.has_value()) { + radiusVariable = "radius"; + } std::vector probeBatches; std::vector buildBatches; @@ -134,15 +126,31 @@ class SpatialJoinTest : public OperatorTestBase { probeBatches.push_back(makeRowVector( {"left_g"}, {makeNullableFlatVector({wkt})})); } - for (const auto& wkt : buildWktsStr) { + if (probeBatches.empty()) { + probeBatches.push_back(makeRowVector( + {"left_g"}, {makeNullableFlatVector({})})); + } + + for (size_t idx = 0; idx < buildWktsStr.size(); ++idx) { + auto& wkt = buildWktsStr[idx]; + buildBatches.push_back(makeRowVector( + {"right_g", "radius"}, + {makeNullableFlatVector({wkt}), + makeNullableFlatVector({radii[idx]})})); + } + if (buildBatches.empty()) { buildBatches.push_back(makeRowVector( - {"right_g"}, {makeNullableFlatVector({wkt})})); + {"right_g", "radius"}, + {makeNullableFlatVector({}), + makeNullableFlatVector({})})); } } else { probeBatches.push_back(makeRowVector( {"left_g"}, {makeNullableFlatVector(probeWktsStr)})); buildBatches.push_back(makeRowVector( - {"right_g"}, {makeNullableFlatVector(buildWktsStr)})); + {"right_g", "radius"}, + {makeNullableFlatVector(buildWktsStr), + makeNullableFlatVector(radii)})); } auto expectedRows = makeRowVector( {"left_g", "right_g"}, @@ -158,10 +166,14 @@ class SpatialJoinTest : public OperatorTestBase { .spatialJoin( PlanBuilder(planNodeIdGenerator) .values(buildBatches) - .project({"ST_GeometryFromText(right_g) AS right_g"}) + .project( + {"ST_GeometryFromText(right_g) AS right_g", "radius"}) .localPartition({}) .planNode(), predicate, + "left_g", + "right_g", + radiusVariable, {"left_g", "right_g"}, joinType) .project( @@ -169,27 +181,79 @@ class SpatialJoinTest : public OperatorTestBase { "ST_AsText(right_g) AS right_g"}) .planNode(); AssertQueryBuilder builder{plan}; - builder.maxDrivers(maxDrivers).assertResults({expectedRows}); + builder.maxDrivers(maxDrivers) + .config(core::QueryConfig::kPreferredOutputBatchRows, maxBatchSize) + .config(core::QueryConfig::kMaxOutputBatchRows, maxBatchSize) + .assertResults({expectedRows}); } }; -TEST_F(SpatialJoinTest, simpleSpatialJoin) { + +TEST_F(SpatialJoinTest, testTrivialSpatialJoin) { + runTest( + {"POINT (1 1)"}, + {"POINT (1 1)"}, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kInner, + {"POINT (1 1)"}, + {"POINT (1 1)"}); +} + +TEST_F(SpatialJoinTest, testSimpleSpatialInnerJoin) { runTest( {"POINT (1 1)", "POINT (1 2)"}, {"POINT (1 1)", "POINT (2 1)"}, + std::nullopt, "ST_Intersects(left_g, right_g)", core::JoinType::kInner, {"POINT (1 1)"}, {"POINT (1 1)"}); +} + +TEST_F(SpatialJoinTest, testSimpleSpatialLeftJoin) { runTest( {"POINT (1 1)", "POINT (1 2)"}, {"POINT (1 1)", "POINT (2 1)"}, + std::nullopt, "ST_Intersects(left_g, right_g)", core::JoinType::kLeft, {"POINT (1 1)", "POINT (1 2)"}, {"POINT (1 1)", std::nullopt}); } -TEST_F(SpatialJoinTest, selfSpatialJoin) { +TEST_F(SpatialJoinTest, testSpatialJoinNullRows) { + runTest( + {"POINT (0 0)", std::nullopt, "POINT (1 1)", std::nullopt}, + {"POINT (0 0)", "POINT (1 1)", std::nullopt, std::nullopt}, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kInner, + {"POINT (0 0)", "POINT (1 1)"}, + {"POINT (0 0)", "POINT (1 1)"}); + runTest( + {"POINT (0 0)", std::nullopt, "POINT (2 2)", std::nullopt}, + {"POINT (0 0)", "POINT (1 1)", std::nullopt, std::nullopt}, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kLeft, + {"POINT (0 0)", "POINT (2 2)", std::nullopt, std::nullopt}, + {"POINT (0 0)", std::nullopt, std::nullopt, std::nullopt}); +} + +// Test geometries that don't intersect but their envelopes do. +// Important to test spatial index +TEST_F(SpatialJoinTest, simpleSpatialJoinEnvelopes) { + runTest( + {"POINT (0.5 0.6)", "POINT (0.5 0.5)", "LINESTRING (0 0.1, 0.9 1)"}, + {"POLYGON ((0 0, 1 1, 1 0, 0 0))"}, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kInner, + {"POINT (0.5 0.5)"}, + {"POLYGON ((0 0, 1 1, 1 0, 0 0))"}); +} + +TEST_F(SpatialJoinTest, testSelfSpatialJoin) { std::vector> inputWkts = { kPolygonA, kPolygonB, kPolygonC, kPolygonD}; std::vector> leftOutputWkts = { @@ -200,6 +264,7 @@ TEST_F(SpatialJoinTest, selfSpatialJoin) { runTest( inputWkts, inputWkts, + std::nullopt, "ST_Intersects(left_g, right_g)", core::JoinType::kInner, leftOutputWkts, @@ -208,6 +273,7 @@ TEST_F(SpatialJoinTest, selfSpatialJoin) { runTest( inputWkts, inputWkts, + std::nullopt, "ST_Intersects(left_g, right_g)", core::JoinType::kLeft, leftOutputWkts, @@ -216,6 +282,7 @@ TEST_F(SpatialJoinTest, selfSpatialJoin) { runTest( inputWkts, inputWkts, + std::nullopt, "ST_Overlaps(left_g, right_g)", core::JoinType::kInner, {kPolygonA, kPolygonB}, @@ -224,6 +291,7 @@ TEST_F(SpatialJoinTest, selfSpatialJoin) { runTest( inputWkts, inputWkts, + std::nullopt, "ST_Intersects(left_g, right_g) AND ST_Overlaps(left_g, right_g)", core::JoinType::kInner, {kPolygonA, kPolygonB}, @@ -232,6 +300,7 @@ TEST_F(SpatialJoinTest, selfSpatialJoin) { runTest( inputWkts, inputWkts, + std::nullopt, "ST_Overlaps(left_g, right_g)", core::JoinType::kLeft, {kPolygonA, kPolygonB, kPolygonC, kPolygonD}, @@ -240,6 +309,7 @@ TEST_F(SpatialJoinTest, selfSpatialJoin) { runTest( inputWkts, inputWkts, + std::nullopt, "ST_Equals(left_g, right_g)", core::JoinType::kInner, inputWkts, @@ -248,6 +318,7 @@ TEST_F(SpatialJoinTest, selfSpatialJoin) { runTest( inputWkts, inputWkts, + std::nullopt, "ST_Equals(left_g, right_g)", core::JoinType::kLeft, inputWkts, @@ -294,12 +365,314 @@ TEST_F(SpatialJoinTest, pointPolygonSpatialJoin) { runTest( pointWkts, polygonWkts, + std::nullopt, "ST_Intersects(left_g, right_g)", core::JoinType::kInner, pointOutputWkts, polygonOutputWkts); } +TEST_F(SpatialJoinTest, testSimpleNullRowsJoin) { + runTest( + {"POINT (1 1)", std::nullopt, "POINT (1 2)"}, + {"POINT (1 1)", "POINT (2 1)", std::nullopt}, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kInner, + {"POINT (1 1)"}, + {"POINT (1 1)"}); +} + +TEST_F(SpatialJoinTest, testGeometryCollection) { + runTest( + {"GEOMETRYCOLLECTION (POINT (1 1))", + "GEOMETRYCOLLECTION EMPTY", + "POINT (1 1)"}, + {"GEOMETRYCOLLECTION (POINT (1 1))", + "GEOMETRYCOLLECTION EMPTY", + "POINT (1 1)"}, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kInner, + {"GEOMETRYCOLLECTION (POINT (1 1))", + "GEOMETRYCOLLECTION (POINT (1 1))", + "POINT (1 1)", + "POINT (1 1)"}, + {"GEOMETRYCOLLECTION (POINT (1 1))", + "POINT (1 1)", + "GEOMETRYCOLLECTION (POINT (1 1))", + "POINT (1 1)"}); + + runTest( + {"GEOMETRYCOLLECTION (POINT (1 1))", + "GEOMETRYCOLLECTION EMPTY", + "POINT (1 1)"}, + {"GEOMETRYCOLLECTION (POINT (1 2))", + "GEOMETRYCOLLECTION EMPTY", + "POINT (1 2)"}, + std::vector>{1.0, 1.0, 1.0}, + "ST_Distance(left_g, right_g) <= radius", + core::JoinType::kInner, + {"GEOMETRYCOLLECTION (POINT (1 1))", + "GEOMETRYCOLLECTION (POINT (1 1))", + "POINT (1 1)", + "POINT (1 1)"}, + {"GEOMETRYCOLLECTION (POINT (1 2))", + "POINT (1 2)", + "GEOMETRYCOLLECTION (POINT (1 2))", + "POINT (1 2)"}); +} + +TEST_F(SpatialJoinTest, testDistanceJoin) { + runTest( + {"POINT (1 2)", "POLYGON ((1 2, 2 2, 2 3, 1 3, 1 2))", std::nullopt}, + {"POINT (2 2)", + "POINT (1 1)", + std::nullopt, + "POINT (1 2)", + "POLYGON ((1 1, 1 0, 0 0, 0 1, 1 1))"}, + std::vector>{std::nullopt, 1.0, 0.0, 0.0, 1.0}, + "ST_Distance(left_g, right_g) <= radius", + core::JoinType::kInner, + {"POINT (1 2)", + "POLYGON ((1 2, 1 3, 2 3, 2 2, 1 2))", + "POINT (1 2)", + "POLYGON ((1 2, 1 3, 2 3, 2 2, 1 2))", + "POINT (1 2)", + "POLYGON ((1 2, 1 3, 2 3, 2 2, 1 2))"}, + {"POINT (1 1)", + "POINT (1 1)", + "POINT (1 2)", + "POINT (1 2)", + "POLYGON ((1 1, 1 0, 0 0, 0 1, 1 1))", + "POLYGON ((1 1, 1 0, 0 0, 0 1, 1 1))"}); +} + +TEST_F(SpatialJoinTest, testContainsPointsInPolygons) { + // Tests ST_Contains(polygon, point) - which polygons contain which points + std::vector> pointWkts = { + kPointX, kPointY, kPointZ, kPointW}; + std::vector> polygonWkts = { + kPolygonA, kPolygonB, kPolygonC, kPolygonD}; + + // Expected: A contains X, B contains Y, B contains Z, A contains Y, D + // contains nothing from our test set Note: Y is in both A and B since they + // overlap + std::vector> pointOutputWkts = { + kPointX, kPointY, kPointY, kPointZ}; + std::vector> polygonOutputWkts = { + kPolygonA, kPolygonA, kPolygonB, kPolygonB}; + + runTest( + pointWkts, + polygonWkts, + std::nullopt, + "ST_Contains(right_g, left_g)", + core::JoinType::kInner, + pointOutputWkts, + polygonOutputWkts); +} + +TEST_F(SpatialJoinTest, testContainsPolygonsInPolygons) { + // Tests ST_Contains(polygon, polygon) - which polygons contain which polygons + // From the Java test, polygon C contains polygon B (C is larger and covers B) + std::vector> polygonWkts = { + kPolygonA, kPolygonB, kPolygonC, kPolygonD}; + + // Each polygon contains itself, plus any additional containments + // Based on the spatial relations, we need to check which polygons actually + // contain others For now, test self-containment which should always work + std::vector> leftOutputWkts = { + kPolygonA, kPolygonB, kPolygonC, kPolygonD}; + std::vector> rightOutputWkts = { + kPolygonA, kPolygonB, kPolygonC, kPolygonD}; + + runTest( + polygonWkts, + polygonWkts, + std::nullopt, + "ST_Contains(right_g, left_g)", + core::JoinType::kInner, + leftOutputWkts, + rightOutputWkts); +} + +TEST_F(SpatialJoinTest, testContainsLeftJoin) { + // Tests ST_Contains with LEFT join - all probe rows should appear + std::vector> pointWkts = { + kPointX, kPointY, kPointZ, kPointW}; + std::vector> polygonWkts = { + kPolygonA, kPolygonB}; + + // W is outside both polygons, so it should have null for the right side + std::vector> pointOutputWkts = { + kPointX, kPointY, kPointY, kPointZ, kPointW}; + std::vector> polygonOutputWkts = { + kPolygonA, kPolygonA, kPolygonB, kPolygonB, std::nullopt}; + + runTest( + pointWkts, + polygonWkts, + std::nullopt, + "ST_Contains(right_g, left_g)", + core::JoinType::kLeft, + pointOutputWkts, + polygonOutputWkts); +} + +TEST_F(SpatialJoinTest, testTouches) { + // Test ST_Touches - geometries that touch at boundary but don't overlap + // Polygon and a point on its boundary + std::vector> probeWkts = { + "POINT (1 2)", "POINT (3 2)", "LINESTRING (0 0, 1 1)"}; + std::vector> buildWkts = { + "POLYGON ((1 1, 1 4, 4 4, 4 1, 1 1))", + "POLYGON ((1 1, 1 3, 3 3, 3 1, 1 1))"}; + + // Point (1,2) touches both polygons (on their boundaries) + // Point (3,2) touches second polygon (on its boundary) + // LineString (0 0, 1 1) touches both polygons (endpoint at (1,1)) + std::vector> probeOutputWkts = { + "POINT (1 2)", + "POINT (1 2)", + "POINT (3 2)", + "LINESTRING (0 0, 1 1)", + "LINESTRING (0 0, 1 1)"}; + std::vector> buildOutputWkts = { + "POLYGON ((1 1, 1 4, 4 4, 4 1, 1 1))", + "POLYGON ((1 1, 1 3, 3 3, 3 1, 1 1))", + "POLYGON ((1 1, 1 3, 3 3, 3 1, 1 1))", + "POLYGON ((1 1, 1 4, 4 4, 4 1, 1 1))", + "POLYGON ((1 1, 1 3, 3 3, 3 1, 1 1))"}; + + runTest( + probeWkts, + buildWkts, + std::nullopt, + "ST_Touches(left_g, right_g)", + core::JoinType::kInner, + probeOutputWkts, + buildOutputWkts); +} + +TEST_F(SpatialJoinTest, testTouchesPolygons) { + // Test ST_Touches with two polygons that touch at a corner + std::vector> probeWkts = { + "POLYGON ((1 1, 1 3, 3 3, 3 1, 1 1))", + "POLYGON ((5 5, 5 6, 6 6, 6 5, 5 5))"}; + std::vector> buildWkts = { + "POLYGON ((3 3, 3 5, 5 5, 5 3, 3 3))"}; + + // Both polygons touch the build polygon at corners: + // - First polygon touches at (3,3) + // - Second polygon touches at (5,5) + std::vector> probeOutputWkts = { + "POLYGON ((1 1, 1 3, 3 3, 3 1, 1 1))", + "POLYGON ((5 5, 5 6, 6 6, 6 5, 5 5))"}; + std::vector> buildOutputWkts = { + "POLYGON ((3 3, 3 5, 5 5, 5 3, 3 3))", + "POLYGON ((3 3, 3 5, 5 5, 5 3, 3 3))"}; + + runTest( + probeWkts, + buildWkts, + std::nullopt, + "ST_Touches(left_g, right_g)", + core::JoinType::kInner, + probeOutputWkts, + buildOutputWkts); +} + +TEST_F(SpatialJoinTest, testCrosses) { + // Test ST_Crosses - geometries that cross each other + // A linestring crossing a polygon + std::vector> probeWkts = { + "LINESTRING (0 0, 4 4)", // Crosses both polygons + "LINESTRING (5 0, 5 4)", // Outside both + "LINESTRING (1 1, 2 2)" // Contained in polygon, doesn't cross + }; + std::vector> buildWkts = { + "POLYGON ((1 1, 1 3, 3 3, 3 1, 1 1))"}; + + // Only the first linestring crosses the polygon + std::vector> probeOutputWkts = { + "LINESTRING (0 0, 4 4)"}; + std::vector> buildOutputWkts = { + "POLYGON ((1 1, 1 3, 3 3, 3 1, 1 1))"}; + + runTest( + probeWkts, + buildWkts, + std::nullopt, + "ST_Crosses(left_g, right_g)", + core::JoinType::kInner, + probeOutputWkts, + buildOutputWkts); +} + +TEST_F(SpatialJoinTest, testCrossesLineStrings) { + // Test ST_Crosses with two linestrings that cross each other + std::vector> probeWkts = { + "LINESTRING (0 0, 1 1)", // Crosses first build linestring + "LINESTRING (2 2, 3 3)" // Parallel, doesn't cross + }; + std::vector> buildWkts = { + "LINESTRING (1 0, 0 1)"}; + + // Only the first linestring crosses + std::vector> probeOutputWkts = { + "LINESTRING (0 0, 1 1)"}; + std::vector> buildOutputWkts = { + "LINESTRING (1 0, 0 1)"}; + + runTest( + probeWkts, + buildWkts, + std::nullopt, + "ST_Crosses(left_g, right_g)", + core::JoinType::kInner, + probeOutputWkts, + buildOutputWkts); +} + +TEST_F(SpatialJoinTest, testEmptyBuild) { + runTest( + {kPointX, std::nullopt, kPointY, kPointZ, kPointW}, + {}, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kInner, + {}, + {}); + runTest( + {kPointX, std::nullopt, kPointY, kPointZ, kPointW}, + {}, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kLeft, + {kPointX, std::nullopt, kPointY, kPointZ, kPointW}, + {std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt}); +} + +TEST_F(SpatialJoinTest, testEmptyProbe) { + runTest( + {}, + {kPointX, std::nullopt, kPointY, kPointZ, kPointW}, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kInner, + {}, + {}); + runTest( + {}, + {kPointX, std::nullopt, kPointY, kPointZ, kPointW}, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kLeft, + {}, + {}); +} + TEST_F(SpatialJoinTest, failOnGroupedExecution) { std::vector batches{ makeRowVector({"wkt"}, {makeFlatVector({"POINT(0 0)"})})}; @@ -317,6 +690,9 @@ TEST_F(SpatialJoinTest, failOnGroupedExecution) { .localPartition({}) .planNode(), "ST_Intersects(left_g, right_g)", + "left_g", + "right_g", + std::nullopt, {"left_g", "right_g"}, core::JoinType::kInner) .project( @@ -335,4 +711,57 @@ TEST_F(SpatialJoinTest, failOnGroupedExecution) { task->start(1), "Spatial joins do not support grouped execution."); } +TEST_F(SpatialJoinTest, testLargeJoinSize) { + size_t numRows = 64; + size_t maxCoord = 17; + std::vector buildWkts; + buildWkts.reserve(numRows); + std::vector probeWkts; + probeWkts.reserve(numRows); + for (size_t i = 0; i < numRows; ++i) { + buildWkts.push_back( + fmt::format("POINT ({} {})", (i + 1) % maxCoord, (i + 2) % maxCoord)); + probeWkts.push_back( + fmt::format("POINT ({} {})", i % maxCoord, (i + 1) % maxCoord)); + } + + std::vector> buildWktsView; + buildWktsView.reserve(numRows); + std::vector> probeWktsView; + probeWktsView.reserve(numRows); + for (size_t i = 0; i < numRows; ++i) { + buildWktsView.push_back(buildWkts[i]); + probeWktsView.push_back(probeWkts[i]); + } + + std::vector> expectedLeftWkts; + expectedLeftWkts.reserve(numRows * numRows / maxCoord); + std::vector> expectedRightWkts; + expectedRightWkts.reserve(numRows * numRows / maxCoord); + for (size_t innerIdx = 0; innerIdx < numRows; ++innerIdx) { + for (size_t outerIdx = 0; outerIdx < numRows; ++outerIdx) { + if (probeWkts[outerIdx] == buildWkts[innerIdx]) { + expectedLeftWkts.push_back(probeWkts[outerIdx]); + expectedRightWkts.push_back(buildWkts[innerIdx]); + } + } + } + + for (bool separateProbeBatches : {false, true}) { + for (size_t maxBatchSize : {64, 13, 7, 5, 3, 2, 1}) { + runTestWithConfig( + buildWktsView, + probeWktsView, + std::nullopt, + "ST_Equals(left_g, right_g)", + core::JoinType::kInner, + expectedLeftWkts, + expectedRightWkts, + 1, + maxBatchSize, + separateProbeBatches); + } + } +} + } // namespace facebook::velox::exec::test diff --git a/velox/common/base/tests/SpillStatsTest.cpp b/velox/exec/tests/SpillStatsTest.cpp similarity index 98% rename from velox/common/base/tests/SpillStatsTest.cpp rename to velox/exec/tests/SpillStatsTest.cpp index 96564011482..7f881ed0c4d 100644 --- a/velox/common/base/tests/SpillStatsTest.cpp +++ b/velox/exec/tests/SpillStatsTest.cpp @@ -14,12 +14,12 @@ * limitations under the License. */ -#include "velox/common/base/SpillStats.h" +#include "velox/exec/SpillStats.h" #include #include "velox/common/base/VeloxException.h" #include "velox/common/base/tests/GTestUtils.h" -using namespace facebook::velox::common; +using namespace facebook::velox::exec; TEST(SpillStatsTest, spillStats) { SpillStats stats1; diff --git a/velox/exec/tests/SpillTest.cpp b/velox/exec/tests/SpillTest.cpp index 9097abd7f66..eb633428d9c 100644 --- a/velox/exec/tests/SpillTest.cpp +++ b/velox/exec/tests/SpillTest.cpp @@ -21,9 +21,10 @@ #include "velox/common/base/RuntimeMetrics.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/FileSystems.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/exec/OperatorUtils.h" #include "velox/exec/Spill.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" +#include "velox/exec/tests/utils/MergeTestBase.h" #include "velox/serializers/PrestoSerializer.h" #include "velox/type/Timestamp.h" #include "velox/vector/tests/utils/VectorTestBase.h" @@ -32,7 +33,7 @@ using namespace facebook; using namespace facebook::velox; using namespace facebook::velox::exec; using namespace facebook::velox::filesystems; -using facebook::velox::exec::test::TempDirectoryPath; +using namespace facebook::velox::common::testutil; namespace { static const int64_t kGB = 1'000'000'000; @@ -44,7 +45,7 @@ class TestRuntimeStatWriter : public BaseRuntimeStatWriter { std::unordered_map& stats) : stats_{stats} {} - void addRuntimeStat(const std::string& name, const RuntimeCounter& value) + void addRuntimeStat(std::string_view name, const RuntimeCounter& value) override { addOperatorRuntimeStats(name, value, stats_); } @@ -78,6 +79,10 @@ struct TestParam { } }; +inline void PrintTo(const TestParam& param, std::ostream* os) { + *os << param.toString(); +} + class SpillTest : public ::testing::TestWithParam, public facebook::velox::test::VectorTestBase { public: @@ -128,7 +133,7 @@ class SpillTest : public ::testing::TestWithParam, facebook::velox::serializer::presto::PrestoVectorSerde:: registerVectorSerde(); } - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kPresto)) { + if (!isRegisteredNamedVectorSerde("Presto")) { facebook::velox::serializer::presto::PrestoVectorSerde:: registerNamedVectorSerde(); } @@ -136,7 +141,7 @@ class SpillTest : public ::testing::TestWithParam, void SetUp() override { allocator_ = memory::memoryManager()->allocator(); - tempDir_ = exec::test::TempDirectoryPath::create(); + tempDir_ = TempDirectoryPath::create(); filesystems::registerLocalFileSystem(); rng_.seed(1); compressionKind_ = TestParam{GetParam()}.compressionKind; @@ -236,7 +241,7 @@ class SpillTest : public ::testing::TestWithParam, batchesByPartition_.clear(); values_.clear(); runtimeStats_.clear(); - spillStats_.wlock()->reset(); + spillStats_.reset(); fileNamePrefix_ = "test"; values_.resize(numBatches * numRowsPerBatch); @@ -282,7 +287,7 @@ class SpillTest : public ::testing::TestWithParam, // vectors have the ith element = i * 'numBatches' + batch, where batch is // the batch number of the vector in the partition. When read back, both // partitions produce an ascending sequence of integers without gaps. - spillStats_.wlock()->reset(); + spillStats_.reset(); const std::optional prefixSortConfig = enablePrefixSort_ ? std::optional(common::PrefixSortConfig()) @@ -302,8 +307,8 @@ class SpillTest : public ::testing::TestWithParam, pool(), &spillStats_); ASSERT_EQ(targetFileSize, state_->targetFileSize()); - ASSERT_EQ(spillStats_.rlock()->spilledPartitions, 0); - ASSERT_EQ(spillStats_.rlock()->spilledPartitions, 0); + ASSERT_EQ(spillStats_.spilledPartitions, 0); + ASSERT_EQ(spillStats_.spilledPartitions, 0); ASSERT_TRUE(state_->spilledPartitionIdSet().empty()); ASSERT_EQ(compressionKind_, state_->compressionKind()); ASSERT_EQ(state_->sortingKeys().size(), numSortKeys); @@ -369,7 +374,7 @@ class SpillTest : public ::testing::TestWithParam, partitionId)); } } - ASSERT_EQ(spillStats_.rlock()->spilledPartitions, partitionIds.size()); + ASSERT_EQ(spillStats_.spilledPartitions, partitionIds.size()); for (const auto& partitionId : partitionIds) { ASSERT_TRUE(state_->spilledPartitionIdSet().contains(partitionId)); } @@ -383,10 +388,9 @@ class SpillTest : public ::testing::TestWithParam, if (targetFileSize > 1) { expectedFiles /= 2; } - ASSERT_EQ(spillStats_.rlock()->spilledFiles, expectedFiles); + ASSERT_EQ(spillStats_.spilledFiles, expectedFiles); ASSERT_GT( - spillStats_.rlock()->spilledBytes, - numPartitions * numBatches * sizeof(int64_t)); + spillStats_.spilledBytes, numPartitions * numBatches * sizeof(int64_t)); int numFinishedFiles{0}; for (const auto& partitionId : partitionIds) { numFinishedFiles += state_->numFinishedFiles(partitionId); @@ -394,6 +398,31 @@ class SpillTest : public ::testing::TestWithParam, ASSERT_EQ(numFinishedFiles, expectedFiles); } + void hybridSpillStateTest( + int64_t targetFileSize, + int numPartitions, + int numBatches, + int numDuplicates, + const std::vector& compareFlags, + uint64_t expectedNumSpilledFiles) { + int mergeWayThresholdBegin = 0; + int mergeWayThresholdEnd = 4; + for (int i = mergeWayThresholdBegin; i < mergeWayThresholdEnd; i++) { + if (i == 1) { + // Skip invalid value. + continue; + } + spillStateTest( + targetFileSize, + numPartitions, + numBatches, + numDuplicates, + compareFlags, + expectedNumSpilledFiles, + i); + } + } + // 'numDuplicates' specifies the number of duplicates generated for each // distinct sort key value in test. void spillStateTest( @@ -402,18 +431,20 @@ class SpillTest : public ::testing::TestWithParam, int numBatches, int numDuplicates, const std::vector& compareFlags, - uint64_t expectedNumSpilledFiles) { + uint64_t expectedNumSpilledFiles, + int mergeWayThreshold) { const int numRowsPerBatch = 1'000; - SCOPED_TRACE(fmt::format( - "targetFileSize: {}, numPartitions: {}, numBatches: {}, numDuplicates: {}, nullsFirst: {}, ascending: {}", - targetFileSize, - numPartitions, - numBatches, - numDuplicates, - compareFlags.empty() ? true : compareFlags[0].nullsFirst, - compareFlags.empty() ? true : compareFlags[0].ascending)); + SCOPED_TRACE( + fmt::format( + "targetFileSize: {}, numPartitions: {}, numBatches: {}, numDuplicates: {}, nullsFirst: {}, ascending: {}", + targetFileSize, + numPartitions, + numBatches, + numDuplicates, + compareFlags.empty() ? true : compareFlags[0].nullsFirst, + compareFlags.empty() ? true : compareFlags[0].ascending)); - const auto prevGStats = common::globalSpillStats(); + const auto prevGStats = globalSpillStats(); SpillPartitionIdSet partitionIds = genPartitionIdSet(numPartitions); @@ -425,7 +456,7 @@ class SpillTest : public ::testing::TestWithParam, numRowsPerBatch, numDuplicates, compareFlags); - const auto stats = spillStats_.copy(); + const auto& stats = spillStats_; ASSERT_EQ(stats.spilledPartitions, numPartitions); ASSERT_EQ(stats.spilledFiles, expectedNumSpilledFiles); ASSERT_GT(stats.spilledBytes, 0); @@ -438,7 +469,7 @@ class SpillTest : public ::testing::TestWithParam, // NOTE: the following stats are not collected by spill state. ASSERT_EQ(stats.spillFillTimeNanos, 0); ASSERT_EQ(stats.spillSortTimeNanos, 0); - const auto newGStats = common::globalSpillStats(); + const auto newGStats = globalSpillStats(); ASSERT_EQ( prevGStats.spilledPartitions + stats.spilledPartitions, newGStats.spilledPartitions); @@ -488,13 +519,20 @@ class SpillTest : public ::testing::TestWithParam, ASSERT_EQ(stats.spilledBytes, totalFileBytes); ASSERT_EQ(prevGStats.spilledBytes + totalFileBytes, newGStats.spilledBytes); + bool usePreMerge = mergeWayThreshold >= 2; for (const auto& partitionId : partitionIds) { auto spillFiles = state_->finish(partitionId); ASSERT_EQ(state_->numFinishedFiles(partitionId), 0); auto spillPartition = SpillPartition(SpillPartitionId(partitionId), std::move(spillFiles)); - auto merge = - spillPartition.createOrderedReader(1 << 20, pool(), &spillStats_); + auto spillConfig = common::SpillConfig(); + spillConfig.numMaxMergeFiles = mergeWayThreshold; + spillConfig.readBufferSize = 1 << 20; + spillConfig.writeBufferSize = 1 << 20; + spillConfig.updateAndCheckSpillLimitCb = [](int64_t) {}; + spillConfig.fileCreateConfig = ""; + std::unique_ptr> merge = + spillPartition.createOrderedReader(spillConfig, pool(), &spillStats_); int numReadBatches = 0; // We expect all the rows in dense increasing order. for (auto i = 0; i < numBatches * numRowsPerBatch; ++i) { @@ -527,11 +565,13 @@ class SpillTest : public ::testing::TestWithParam, } } ASSERT_EQ(nullptr, merge->next()); - // We do two append writes per each input batch. - ASSERT_EQ(numBatches, numReadBatches); + if (!usePreMerge) { + // We do two append writes per each input batch. + ASSERT_EQ(numBatches, numReadBatches); + } } - const auto finalStats = spillStats_.copy(); + const auto& finalStats = spillStats_; ASSERT_EQ(finalStats.spillReadBytes, finalStats.spilledBytes); ASSERT_GT(finalStats.spillReads, 0); ASSERT_GT(finalStats.spillReadTimeNanos, 0); @@ -567,7 +607,67 @@ class SpillTest : public ::testing::TestWithParam, ASSERT_TRUE(fs->exists(spilledFile)); } // Verify stats. - ASSERT_EQ(runtimeStats_["spillFileSize"].count, spilledFiles.size()); + if (!usePreMerge) { + ASSERT_EQ(runtimeStats_["spillFileSize"].count, spilledFiles.size()); + } else { + ASSERT_GE(runtimeStats_["spillFileSize"].count, spilledFiles.size()); + } + } + + void gatherMergeTest( + int32_t numValues, + int numMergeWays, + int targetSize, + bool useRandom) { + auto goldenVector = makeRowVector({ + makeFlatVector(numValues, [&](auto row) { return row; }), + }); + std::vector> mergeWays(numMergeWays); + for (int32_t value = 0; value < numValues; value++) { + int way = useRandom ? folly::Random::rand32() % numMergeWays + : value % numMergeWays; + mergeWays[way].push_back(value); + } + std::vector sources; + std::vector> streams; + std::vector sortKeys = {{0, {true, true}}}; + for (int way = 0; way < numMergeWays; way++) { + auto source = makeRowVector({ + makeFlatVector( + mergeWays[way].size(), + [&](auto row) { return mergeWays[way][row]; }), + }); + sources.push_back(source); + streams.push_back( + std::make_unique( + way, sortKeys, source)); + } + auto mergeTree = + std::make_unique>(std::move(streams)); + RowVectorPtr targetVector = std::static_pointer_cast( + BaseVector::create(sources[0]->type(), targetSize, pool_.get())); + std::vector bufferSources(targetSize); + std::vector bufferSourceIndices(targetSize); + for (int32_t batch = 0; batch * targetSize < numValues; batch++) { + int32_t valueBegin = batch * targetSize; + int32_t valueEnd = valueBegin + targetSize; + valueEnd = std::min(valueEnd, numValues); + VectorPtr tmp = std::move(targetVector); + BaseVector::prepareForReuse(tmp, targetSize); + targetVector = std::static_pointer_cast(tmp); + for (auto& child : targetVector->children()) { + child->resize(targetSize); + } + int count = 0; + testingGatherMerge( + targetVector, *mergeTree, count, bufferSources, bufferSourceIndices); + EXPECT_EQ(count, valueEnd - valueBegin); + auto result = targetVector->childAt(0).get(); + auto golden = goldenVector->childAt(0).get(); + for (int32_t row = 0; row < valueEnd - valueBegin; row++) { + EXPECT_TRUE(result->equalValueAt(golden, row, valueBegin + row)); + } + } } folly::Random::DefaultGenerator rng_; @@ -579,7 +679,7 @@ class SpillTest : public ::testing::TestWithParam, folly::F14FastMap> batchesByPartition_; std::string fileNamePrefix_; - folly::Synchronized spillStats_; + exec::SpillStats spillStats_; std::unique_ptr state_; std::unordered_map runtimeStats_; std::unique_ptr statWriter_; @@ -591,24 +691,24 @@ TEST_P(SpillTest, spillState) { // triggered by batch write. // Test with distinct sort keys. - spillStateTest(kGB, 2, 8, 1, {CompareFlags{true, true}}, 8); - spillStateTest(kGB, 2, 8, 1, {CompareFlags{true, false}}, 8); - spillStateTest(kGB, 2, 8, 1, {CompareFlags{false, true}}, 8); - spillStateTest(kGB, 2, 8, 1, {CompareFlags{false, false}}, 8); - spillStateTest(kGB, 2, 8, 1, {}, 8); + hybridSpillStateTest(kGB, 2, 8, 1, {CompareFlags{true, true}}, 8); + hybridSpillStateTest(kGB, 2, 8, 1, {CompareFlags{true, false}}, 8); + hybridSpillStateTest(kGB, 2, 8, 1, {CompareFlags{false, true}}, 8); + hybridSpillStateTest(kGB, 2, 8, 1, {CompareFlags{false, false}}, 8); + hybridSpillStateTest(kGB, 2, 8, 1, {}, 8); // Test with duplicate sort keys. - spillStateTest(kGB, 2, 8, 8, {CompareFlags{true, true}}, 8); - spillStateTest(kGB, 2, 8, 8, {CompareFlags{true, false}}, 8); - spillStateTest(kGB, 2, 8, 8, {CompareFlags{false, true}}, 8); - spillStateTest(kGB, 2, 8, 8, {CompareFlags{false, false}}, 8); - spillStateTest(kGB, 2, 8, 8, {}, 8); + hybridSpillStateTest(kGB, 2, 8, 8, {CompareFlags{true, true}}, 8); + hybridSpillStateTest(kGB, 2, 8, 8, {CompareFlags{true, false}}, 8); + hybridSpillStateTest(kGB, 2, 8, 8, {CompareFlags{false, true}}, 8); + hybridSpillStateTest(kGB, 2, 8, 8, {CompareFlags{false, false}}, 8); + hybridSpillStateTest(kGB, 2, 8, 8, {}, 8); } TEST_P(SpillTest, spillTimestamp) { // Verify that timestamp type retains it nanosecond precision when spilled and // read back. - auto tempDirectory = exec::test::TempDirectoryPath::create(); + auto tempDirectory = TempDirectoryPath::create(); std::vector emptyCompareFlags; const std::string spillPath = tempDirectory->getPath() + "/test"; std::vector timeValues = { @@ -646,11 +746,17 @@ TEST_P(SpillTest, spillTimestamp) { state.testingNonEmptySpilledPartitionIdSet().contains(partitionId)); SpillPartition spillPartition(SpillPartitionId{0}, state.finish(partitionId)); + auto spillConfig = common::SpillConfig(); + spillConfig.numMaxMergeFiles = 2; + spillConfig.readBufferSize = 1 << 20; + spillConfig.writeBufferSize = 1 << 20; + spillConfig.updateAndCheckSpillLimitCb = [](int64_t) {}; + spillConfig.fileCreateConfig = ""; auto merge = - spillPartition.createOrderedReader(1 << 20, pool(), &spillStats_); + spillPartition.createOrderedReader(spillConfig, pool(), &spillStats_); ASSERT_TRUE(merge != nullptr); ASSERT_TRUE( - spillPartition.createOrderedReader(1 << 20, pool(), &spillStats_) == + spillPartition.createOrderedReader(spillConfig, pool(), &spillStats_) == nullptr); for (auto i = 0; i < timeValues.size(); ++i) { auto* stream = merge->next(); @@ -668,18 +774,18 @@ TEST_P(SpillTest, spillStateWithSmallTargetFileSize) { // write. // Test with distinct sort keys. - spillStateTest(1, 2, 8, 1, {CompareFlags{true, true}}, 8 * 2); - spillStateTest(1, 2, 8, 1, {CompareFlags{true, false}}, 8 * 2); - spillStateTest(1, 2, 8, 1, {CompareFlags{false, true}}, 8 * 2); - spillStateTest(1, 2, 8, 1, {CompareFlags{false, false}}, 8 * 2); - spillStateTest(1, 2, 8, 1, {}, 8 * 2); + hybridSpillStateTest(1, 2, 8, 1, {CompareFlags{true, true}}, 8 * 2); + hybridSpillStateTest(1, 2, 8, 1, {CompareFlags{true, false}}, 8 * 2); + hybridSpillStateTest(1, 2, 8, 1, {CompareFlags{false, true}}, 8 * 2); + hybridSpillStateTest(1, 2, 8, 1, {CompareFlags{false, false}}, 8 * 2); + hybridSpillStateTest(1, 2, 8, 1, {}, 8 * 2); // Test with duplicated sort keys. - spillStateTest(1, 2, 8, 8, {CompareFlags{true, false}}, 8 * 2); - spillStateTest(1, 2, 8, 8, {CompareFlags{true, true}}, 8 * 2); - spillStateTest(1, 2, 8, 8, {CompareFlags{false, false}}, 8 * 2); - spillStateTest(1, 2, 8, 8, {CompareFlags{false, true}}, 8 * 2); - spillStateTest(1, 2, 8, 8, {}, 8 * 2); + hybridSpillStateTest(1, 2, 8, 8, {CompareFlags{true, false}}, 8 * 2); + hybridSpillStateTest(1, 2, 8, 8, {CompareFlags{true, true}}, 8 * 2); + hybridSpillStateTest(1, 2, 8, 8, {CompareFlags{false, false}}, 8 * 2); + hybridSpillStateTest(1, 2, 8, 8, {CompareFlags{false, true}}, 8 * 2); + hybridSpillStateTest(1, 2, 8, 8, {}, 8 * 2); } TEST_P(SpillTest, spillPartitionId) { @@ -958,12 +1064,15 @@ TEST_P(SpillTest, spillPartitionFunctionBasic) { std::vector columns; columns.push_back( makeFlatVector(numRows, [](auto row) { return row; })); - columns.push_back(makeFlatVector( - numRows, [](auto row) { return fmt::format("key_{}", row); })); - columns.push_back(makeFlatVector( - numRows, [](auto row) { return fmt::format("key_{}_{}", row, row); })); - columns.push_back(makeFlatVector( - numRows, [](auto row) { return fmt::format("val_{}", row); })); + columns.push_back(makeFlatVector(numRows, [](auto row) { + return fmt::format("key_{}", row); + })); + columns.push_back(makeFlatVector(numRows, [](auto row) { + return fmt::format("key_{}_{}", row, row); + })); + columns.push_back(makeFlatVector(numRows, [](auto row) { + return fmt::format("val_{}", row); + })); inputVectors.push_back(makeRowVector(columns)); } @@ -1404,7 +1513,7 @@ TEST_P(SpillTest, validatePerSpillWriteSize) { } }; - auto tempDirectory = exec::test::TempDirectoryPath::create(); + auto tempDirectory = TempDirectoryPath::create(); SpillState state( [&]() -> const std::string& { return tempDirectory->getPath(); }, updateSpilledBytesCb_, @@ -1428,7 +1537,7 @@ TEST_P(SpillTest, validatePerSpillWriteSize) { namespace { SpillFiles makeFakeSpillFiles(int32_t numFiles) { - auto tempDir = exec::test::TempDirectoryPath::create(); + auto tempDir = TempDirectoryPath::create(); static uint32_t fakeFileId{0}; SpillFiles files; files.reserve(numFiles); @@ -1522,6 +1631,17 @@ TEST(SpillTest, scopedSpillInjectionRegex) { } } +TEST_P(SpillTest, gatherMerge) { + gatherMergeTest(1234, 2, 10, false); + gatherMergeTest(1234, 2, 100, false); + gatherMergeTest(1234, 10, 10, false); + gatherMergeTest(1234, 10, 100, false); + gatherMergeTest(1234, 2, 10, true); + gatherMergeTest(1234, 2, 100, true); + gatherMergeTest(1234, 10, 10, true); + gatherMergeTest(1234, 10, 100, true); +} + VELOX_INSTANTIATE_TEST_SUITE_P( SpillTestSuite, SpillTest, diff --git a/velox/exec/tests/SpillerAggregateBenchmarkTest.cpp b/velox/exec/tests/SpillerAggregateBenchmarkTest.cpp index d0b34a244f7..92398385ec5 100644 --- a/velox/exec/tests/SpillerAggregateBenchmarkTest.cpp +++ b/velox/exec/tests/SpillerAggregateBenchmarkTest.cpp @@ -37,7 +37,8 @@ int main(int argc, char* argv[]) { "AggregationOutputSpiller], the aggregate spiller dose not support it.", spillerType); } - auto test = std::make_unique(spillerType); + auto test = + std::make_unique(spillerType); test->setUp(); test->run(); test->printStats(); diff --git a/velox/exec/tests/SpillerBenchmarkBase.cpp b/velox/exec/tests/SpillerBenchmarkBase.cpp index ab94f79c62e..640cb7ea88f 100644 --- a/velox/exec/tests/SpillerBenchmarkBase.cpp +++ b/velox/exec/tests/SpillerBenchmarkBase.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include #include @@ -23,9 +24,9 @@ #include "velox/common/compression/Compression.h" #include "velox/common/file/FileSystems.h" #include "velox/common/memory/MmapAllocator.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/exec/Spiller.h" #include "velox/exec/tests/SpillerBenchmarkBase.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/vector/fuzzer/VectorFuzzer.h" DEFINE_string( @@ -56,7 +57,7 @@ DEFINE_uint32( "The number of key columns"); DEFINE_uint32( spiller_benchmark_spill_executor_size, - std::thread::hardware_concurrency(), + folly::available_concurrency(), "The spiller executor size in number of threads"); DEFINE_uint32( spiller_benchmark_spill_vector_size, @@ -78,6 +79,7 @@ DEFINE_uint64( using namespace facebook::velox::memory; namespace facebook::velox::exec::test { +using namespace facebook::velox::common::testutil; void SpillerBenchmarkBase::setUp() { rootPool_ = @@ -108,7 +110,7 @@ void SpillerBenchmarkBase::setUp() { } if (FLAGS_spiller_benchmark_path.empty()) { - tempDir_ = exec::test::TempDirectoryPath::create(); + tempDir_ = TempDirectoryPath::create(); spillDir_ = tempDir_->getPath(); } else { spillDir_ = FLAGS_spiller_benchmark_path; diff --git a/velox/exec/tests/SpillerBenchmarkBase.h b/velox/exec/tests/SpillerBenchmarkBase.h index a4a4fcc228e..5268823d5a0 100644 --- a/velox/exec/tests/SpillerBenchmarkBase.h +++ b/velox/exec/tests/SpillerBenchmarkBase.h @@ -19,8 +19,8 @@ #include #include "velox/common/file/FileSystems.h" #include "velox/common/memory/MmapAllocator.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/exec/Spiller.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/type/Type.h" #include "velox/vector/fuzzer/VectorFuzzer.h" @@ -37,6 +37,8 @@ DECLARE_uint64(spiller_benchmark_min_spill_run_size); DECLARE_uint64(spiller_benchmark_write_buffer_size); namespace facebook::velox::exec::test { + +/// This test measures /// This test measures the spill input overhead in spill join & probe. class SpillerBenchmarkBase { public: @@ -65,12 +67,12 @@ class SpillerBenchmarkBase { std::unique_ptr vectorFuzzer_; std::vector rowVectors_; std::unique_ptr executor_; - std::shared_ptr tempDir_; + std::shared_ptr tempDir_; std::string spillDir_; std::shared_ptr fs_; std::unique_ptr spiller_; // Stats. uint64_t executionTimeUs_{0}; - folly::Synchronized spillStats_; + exec::SpillStats spillStats_; }; } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/SpillerTest.cpp b/velox/exec/tests/SpillerTest.cpp index 807025162a1..ee8dc60aa36 100644 --- a/velox/exec/tests/SpillerTest.cpp +++ b/velox/exec/tests/SpillerTest.cpp @@ -57,7 +57,7 @@ class TestRuntimeStatWriter : public BaseRuntimeStatWriter { std::unordered_map& stats) : stats_{stats} {} - void addRuntimeStat(const std::string& name, const RuntimeCounter& value) + void addRuntimeStat(std::string_view name, const RuntimeCounter& value) override { addOperatorRuntimeStats(name, value, stats_); } @@ -119,6 +119,10 @@ struct TestParam { } }; +inline void PrintTo(const TestParam& param, std::ostream* os) { + *os << param.toString(); +} + struct TestParamsBuilder { std::vector getTestParams() { std::vector params; @@ -269,7 +273,7 @@ class SpillerTest : public exec::test::RowContainerTestBase { rng_.seed(1); const bool asyncRead = folly::Random::oneIn(2); LOG(INFO) << "Async read " << asyncRead; - tempDirPath_ = exec::test::TempDirectoryPath::create(true); + tempDirPath_ = TempDirectoryPath::create(true); fs_ = filesystems::getFileSystem(tempDirPath_->getPath(), nullptr); faultyFs_ = static_cast(fs_.get()); fsExecutor_ = std::make_unique(32); @@ -331,15 +335,16 @@ class SpillerTest : public exec::test::RowContainerTestBase { bool ascending = true, bool makeError = false, uint64_t readBufferSize = 1 << 20) { - SCOPED_TRACE(fmt::format( - "spillType: {} numDuplicates: {} outputBatchSize: {} ascending: {} makeError: {}", - typeName(type_), - numDuplicates, - outputBatchSize, - ascending, - makeError)); + SCOPED_TRACE( + fmt::format( + "spillType: {} numDuplicates: {} outputBatchSize: {} ascending: {} makeError: {}", + typeName(type_), + numDuplicates, + outputBatchSize, + ascending, + makeError)); constexpr int32_t kNumRows = 5'000; - const auto prevGStats = common::globalSpillStats(); + const auto prevGStats = globalSpillStats(); setupSpillData(numKeys_, kNumRows, numDuplicates, [&](RowVectorPtr rows) { // Set ordinal so that the sorted order is unambiguous. @@ -355,8 +360,8 @@ class SpillerTest : public exec::test::RowContainerTestBase { return; } // Verify the spilled file exist on file system. - auto stats = spiller_->stats(); - const auto numSpilledFiles = stats.spilledFiles; + const auto& stats = spiller_->stats(); + const auto numSpilledFiles = stats.spilledFiles.load(); if (type_ == SpillerType::AGGREGATION_OUTPUT) { ASSERT_EQ(numSpilledFiles, 1); } else { @@ -388,45 +393,50 @@ class SpillerTest : public exec::test::RowContainerTestBase { verifySortedSpillData(spillPartitionSet, outputBatchSize); - stats = spiller_->stats(); - ASSERT_EQ(stats.spilledFiles, spilledFileSet.size()); - ASSERT_EQ(stats.spilledPartitions, numPartitions_); - ASSERT_EQ(stats.spilledRows, kNumRows); + const auto& updatedStats = spiller_->stats(); + ASSERT_EQ(updatedStats.spilledFiles.load(), spilledFileSet.size()); + ASSERT_EQ(updatedStats.spilledPartitions.load(), numPartitions_); + ASSERT_EQ(updatedStats.spilledRows.load(), kNumRows); - ASSERT_EQ(stats.spilledBytes, totalSpilledBytes); - ASSERT_EQ(stats.spillReadBytes, totalSpilledBytes); - ASSERT_GT(stats.spillWriteTimeNanos, 0); + ASSERT_EQ(updatedStats.spilledBytes.load(), totalSpilledBytes); + ASSERT_EQ(updatedStats.spillReadBytes.load(), totalSpilledBytes); + ASSERT_GT(updatedStats.spillWriteTimeNanos.load(), 0); if (type_ == SpillerType::AGGREGATION_OUTPUT) { - ASSERT_EQ(stats.spillSortTimeNanos, 0); + ASSERT_EQ(updatedStats.spillSortTimeNanos.load(), 0); } else { - ASSERT_GT(stats.spillSortTimeNanos, 0); + ASSERT_GT(updatedStats.spillSortTimeNanos.load(), 0); } - ASSERT_GT(stats.spillExtractVectorTimeNanos, 0); - ASSERT_GT(stats.spillFlushTimeNanos, 0); - ASSERT_GT(stats.spillFillTimeNanos, 0); - ASSERT_GT(stats.spillSerializationTimeNanos, 0); - ASSERT_GT(stats.spillWrites, 0); + ASSERT_GT(updatedStats.spillExtractVectorTimeNanos.load(), 0); + ASSERT_GT(updatedStats.spillFlushTimeNanos.load(), 0); + ASSERT_GT(updatedStats.spillFillTimeNanos.load(), 0); + ASSERT_GT(updatedStats.spillSerializationTimeNanos.load(), 0); + ASSERT_GT(updatedStats.spillWrites.load(), 0); - const auto newGStats = common::globalSpillStats(); + const auto newGStats = globalSpillStats(); ASSERT_EQ( - prevGStats.spilledFiles + stats.spilledFiles, newGStats.spilledFiles); + prevGStats.spilledFiles + updatedStats.spilledFiles.load(), + newGStats.spilledFiles); ASSERT_EQ( - prevGStats.spilledRows + stats.spilledRows, newGStats.spilledRows); + prevGStats.spilledRows + updatedStats.spilledRows.load(), + newGStats.spilledRows); ASSERT_EQ( - prevGStats.spilledPartitions + stats.spilledPartitions, + prevGStats.spilledPartitions + updatedStats.spilledPartitions.load(), newGStats.spilledPartitions); ASSERT_EQ( - prevGStats.spilledBytes + stats.spilledBytes, newGStats.spilledBytes); + prevGStats.spilledBytes + updatedStats.spilledBytes.load(), + newGStats.spilledBytes); ASSERT_EQ( - prevGStats.spillReadBytes + stats.spillReadBytes, + prevGStats.spillReadBytes + updatedStats.spillReadBytes.load(), newGStats.spillReadBytes); - ASSERT_EQ(prevGStats.spillReads + stats.spillReads, newGStats.spillReads); ASSERT_EQ( - prevGStats.spillReadTimeNanos + stats.spillReadTimeNanos, + prevGStats.spillReads + updatedStats.spillReads.load(), + newGStats.spillReads); + ASSERT_EQ( + prevGStats.spillReadTimeNanos + updatedStats.spillReadTimeNanos.load(), newGStats.spillReadTimeNanos); ASSERT_EQ( prevGStats.spillDeserializationTimeNanos + - stats.spillDeserializationTimeNanos, + updatedStats.spillDeserializationTimeNanos.load(), newGStats.spillDeserializationTimeNanos); ASSERT_EQ( prevGStats.spillWriteTimeNanos + stats.spillWriteTimeNanos, @@ -614,7 +624,8 @@ class SpillerTest : public exec::test::RowContainerTestBase { return tempDirPath_->getPath(); }; stats_.clear(); - spillStats_ = folly::Synchronized(); + spillStats_ = folly::Synchronized(); + spillIoStats_ = IoStats(); spillConfig_.startPartitionBit = hashBits_.begin(); spillConfig_.numPartitionBits = hashBits_.numBits(); @@ -635,7 +646,11 @@ class SpillerTest : public exec::test::RowContainerTestBase { if (type_ == SpillerType::NO_ROW_CONTAINER) { spiller_ = std::make_unique( - rowType_, std::nullopt, hashBits_, &spillConfig_, &spillStats_); + rowType_, + std::nullopt, + hashBits_, + &spillConfig_, + spillStats_.wlock().operator->()); } else if (type_ == SpillerType::SORT_INPUT) { const auto sortingKeys = SpillState::makeSortingKeys( compareFlags_.empty() @@ -646,10 +661,13 @@ class SpillerTest : public exec::test::RowContainerTestBase { rowType_, sortingKeys, &spillConfig_, - &spillStats_); + spillStats_.wlock().operator->()); } else if (type_ == SpillerType::SORT_OUTPUT) { spiller_ = std::make_unique( - rowContainer_.get(), rowType_, &spillConfig_, &spillStats_); + rowContainer_.get(), + rowType_, + &spillConfig_, + spillStats_.wlock().operator->()); } else if (type_ == SpillerType::HASH_BUILD) { spiller_ = std::make_unique( joinType_, @@ -658,7 +676,7 @@ class SpillerTest : public exec::test::RowContainerTestBase { rowType_, hashBits_, &spillConfig_, - &spillStats_); + spillStats_.wlock().operator->()); } else if (type_ == SpillerType::AGGREGATION_INPUT) { const auto sortingKeys = SpillState::makeSortingKeys( compareFlags_.empty() @@ -670,10 +688,13 @@ class SpillerTest : public exec::test::RowContainerTestBase { hashBits_, sortingKeys, &spillConfig_, - &spillStats_); + spillStats_.wlock().operator->()); } else if (type_ == SpillerType::AGGREGATION_OUTPUT) { spiller_ = std::make_unique( - rowContainer_.get(), rowType_, &spillConfig_, &spillStats_); + rowContainer_.get(), + rowType_, + &spillConfig_, + spillStats_.wlock().operator->()); } else if (type_ == SpillerType::ROW_NUMBER_HASH_TABLE) { spiller_ = std::make_unique( rowContainer_.get(), @@ -681,7 +702,7 @@ class SpillerTest : public exec::test::RowContainerTestBase { rowType_, hashBits_, &spillConfig_, - &spillStats_); + spillStats_.wlock().operator->()); } else { VELOX_UNREACHABLE("Unknown spiller type"); } @@ -714,11 +735,12 @@ class SpillerTest : public exec::test::RowContainerTestBase { // We make a merge reader that merges the spill files and the rows that // are still in the RowContainer. auto merge = spillPartition->createOrderedReader( - spillConfig_.readBufferSize, pool(), &spillStats_); + spillConfig_, pool(), spillStats_.wlock().operator->()); ASSERT_TRUE(merge != nullptr); ASSERT_TRUE( spillPartition->createOrderedReader( - spillConfig_.readBufferSize, pool(), &spillStats_) == nullptr); + spillConfig_, pool(), spillStats_.wlock().operator->()) == + nullptr); // We read the spilled data back and check that it matches the sorted // order of the partition. @@ -865,14 +887,15 @@ class SpillerTest : public exec::test::RowContainerTestBase { ss << partitionId.toString() << " "; } ss << "]"; - SCOPED_TRACE(fmt::format( - "Param: {}, numSpillers: {}, numBatchRows: {}, numAppendBatches: {}, targetFileSize: {}, spillPartitionIdSet: {}", - param_.toString(), - numSpillers, - numBatchRows, - numAppendBatches, - targetFileSize, - ss.str())); + SCOPED_TRACE( + fmt::format( + "Param: {}, numSpillers: {}, numBatchRows: {}, numAppendBatches: {}, targetFileSize: {}, spillPartitionIdSet: {}", + param_.toString(), + numSpillers, + numBatchRows, + numAppendBatches, + targetFileSize, + ss.str())); std::vector> inputsByPartition(numPartitions_); @@ -882,7 +905,7 @@ class SpillerTest : public exec::test::RowContainerTestBase { // them by partition. std::vector> spillers; for (int iter = 0; iter < numSpillers; ++iter) { - const auto prevGStats = common::globalSpillStats(); + const auto prevGStats = globalSpillStats(); setupSpillData( numKeys_, (type_ != SpillerType::NO_ROW_CONTAINER) ? numBatchRows * 10 : 0, @@ -963,7 +986,7 @@ class SpillerTest : public exec::test::RowContainerTestBase { ASSERT_EQ(stats.spillFillTimeNanos, 0); } - const auto newGStats = common::globalSpillStats(); + const auto newGStats = globalSpillStats(); ASSERT_EQ( prevGStats.spilledFiles + stats.spilledFiles, newGStats.spilledFiles); ASSERT_EQ( @@ -1007,7 +1030,7 @@ class SpillerTest : public exec::test::RowContainerTestBase { // Spilled file stats should be updated after finalizing spiller. if (numAppendBatches > 0) { - ASSERT_GT(common::globalSpillStats().spilledFiles, 0); + ASSERT_GT(globalSpillStats().spilledFiles, 0); } } @@ -1047,7 +1070,9 @@ class SpillerTest : public exec::test::RowContainerTestBase { spillConfig_.startPartitionBit, spillConfig_.numPartitionBits)); auto reader = spillPartitionEntry.second->createUnorderedReader( - spillConfig_.readBufferSize, pool(), &spillStats_); + spillConfig_.readBufferSize, + pool(), + spillStats_.wlock().operator->()); if (type_ == SpillerType::NO_ROW_CONTAINER) { // For hash probe type, we append each input vector as one batch in // spill file so that we can do one-to-one comparison. @@ -1129,7 +1154,9 @@ class SpillerTest : public exec::test::RowContainerTestBase { spillConfig_.startPartitionBit, spillConfig_.numPartitionBits)); auto reader = spillPartitionEntry.second->createUnorderedReader( - spillConfig_.readBufferSize, pool(), &spillStats_); + spillConfig_.readBufferSize, + pool(), + spillStats_.wlock().operator->()); if (type_ == SpillerType::NO_ROW_CONTAINER) { // For hash probe type, we append each input vector as one batch in // spill file so that we can do one-to-one comparison. @@ -1231,7 +1258,9 @@ class SpillerTest : public exec::test::RowContainerTestBase { std::vector compareFlags_; std::unique_ptr spiller_; common::SpillConfig spillConfig_; - folly::Synchronized spillStats_; + folly::Synchronized spillStats_; + // Filesystem I/O stats for spill operations. + IoStats spillIoStats_; }; struct AllTypesTestParam { @@ -1573,7 +1602,7 @@ TEST_P(AggregationOutputOnly, basic) { ASSERT_EQ(spillPartitionSet.size(), 1); auto spillPartition = std::move(spillPartitionSet.begin()->second); auto merge = spillPartition->createOrderedReader( - spillConfig_.readBufferSize, pool(), &spillStats_); + spillConfig_, pool(), spillStats_.wlock().operator->()); for (auto i = 0; i < expectedNumSpilledRows; ++i) { auto* stream = merge->next(); @@ -1687,7 +1716,7 @@ TEST_P(SortOutputOnly, basic) { const int expectedNumSpilledRows = numListedRows; auto merge = spillPartition->createOrderedReader( - spillConfig_.readBufferSize, pool(), &spillStats_); + spillConfig_, pool(), spillStats_.wlock().operator->()); if (expectedNumSpilledRows == 0) { ASSERT_TRUE(merge == nullptr); } else { diff --git a/velox/exec/tests/SplitListenerTest.cpp b/velox/exec/tests/SplitListenerTest.cpp index 15c8ebc613a..17421713494 100644 --- a/velox/exec/tests/SplitListenerTest.cpp +++ b/velox/exec/tests/SplitListenerTest.cpp @@ -14,15 +14,16 @@ * limitations under the License. */ +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/exec/Task.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" using namespace facebook::velox::exec; namespace facebook::velox::exec::test { +using namespace facebook::velox::common::testutil; namespace { std::unordered_map> @@ -131,10 +132,11 @@ class SplitListenerTest : public HiveConnectorTestBase { std::vector> splits; splits.reserve(filePaths_.size()); for (const auto& filePath : filePaths_) { - splits.emplace_back(connector::hive::HiveConnectorSplitBuilder(filePath) - .connectorId(kHiveConnectorId) - .fileFormat(dwio::common::FileFormat::DWRF) - .build()); + splits.emplace_back( + connector::hive::HiveConnectorSplitBuilder(filePath) + .connectorId(kHiveConnectorId) + .fileFormat(dwio::common::FileFormat::DWRF) + .build()); } return splits; } diff --git a/velox/exec/tests/StreamingAggregationTest.cpp b/velox/exec/tests/StreamingAggregationTest.cpp index 493163413b3..f8c1d265152 100644 --- a/velox/exec/tests/StreamingAggregationTest.cpp +++ b/velox/exec/tests/StreamingAggregationTest.cpp @@ -13,22 +13,33 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "velox/exec/StreamingAggregation.h" #include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/testutil/TempFilePath.h" +#include "velox/common/testutil/TestValue.h" #include "velox/core/Expressions.h" + #include "velox/exec/PlanNodeStats.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/exec/tests/utils/SumNonPODAggregate.h" -#include "velox/exec/tests/utils/TempFilePath.h" + +using namespace facebook::velox::common::testutil; namespace facebook::velox::exec { namespace { using namespace facebook::velox::exec::test; -class StreamingAggregationTest : public HiveConnectorTestBase, - public testing::WithParamInterface { +struct TestParams { + int32_t streamingMinOutputBatchSize; + uint64_t preferredOutputBatchBytes; +}; + +class StreamingAggregationTest + : public HiveConnectorTestBase, + public testing::WithParamInterface { protected: void SetUp() override { HiveConnectorTestBase::SetUp(); @@ -36,7 +47,11 @@ class StreamingAggregationTest : public HiveConnectorTestBase, } int32_t flushRows() { - return GetParam(); + return GetParam().streamingMinOutputBatchSize; + } + + uint64_t preferredOutputBatchBytes() { + return GetParam().preferredOutputBatchBytes; } AssertQueryBuilder& config( @@ -48,7 +63,10 @@ class StreamingAggregationTest : public HiveConnectorTestBase, std::to_string(outputBatchSize)) .config( core::QueryConfig::kStreamingAggregationMinOutputBatchRows, - std::to_string(flushRows())); + std::to_string(flushRows())) + .config( + core::QueryConfig::kPreferredOutputBatchBytes, + std::to_string(preferredOutputBatchBytes())); } void testAggregation( @@ -156,8 +174,9 @@ class StreamingAggregationTest : public HiveConnectorTestBase, core::PlanNodeId aggregationNodeId; auto plan = PlanBuilder(planNodeIdGenerator) .startTableScan() - .outputType(std::dynamic_pointer_cast( - inputVectors[0]->type())) + .outputType( + std::dynamic_pointer_cast( + inputVectors[0]->type())) .endTableScan() .streamingAggregation( {"c0"}, @@ -170,20 +189,34 @@ class StreamingAggregationTest : public HiveConnectorTestBase, .capturePlanNodeId(aggregationNodeId) .planNode(); - for (const auto barrierExecution : {false, true}) { - SCOPED_TRACE(fmt::format("barrierExecution {}", barrierExecution)); + struct { + bool hasBarrier; + bool serialExecution; + + std::string toString() const { + return fmt::format( + "hasBarrier: {}, serialExecution: {}", hasBarrier, serialExecution); + } + } testSettings[] = { + {false, false}, + {false, true}, + {true, false}, + {true, true}, + }; + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.toString()); auto task = AssertQueryBuilder(plan, duckDbQueryRunner_) .splits(makeHiveConnectorSplits(tempFiles)) - .serialExecution(true) - .barrierExecution(barrierExecution) + .serialExecution(testData.serialExecution) + .barrierExecution(testData.hasBarrier) .config( core::QueryConfig::kPreferredOutputBatchRows, std::to_string(outputBatchSize)) .assertResults( "SELECT c0, max(c1 order by c2), max(c1 order by c2 desc), array_agg(c1 order by c2) FROM tmp GROUP BY c0"); const auto taskStats = task->taskStats(); - ASSERT_EQ(taskStats.numBarriers, barrierExecution ? numSplits : 0); + ASSERT_EQ(taskStats.numBarriers, testData.hasBarrier ? numSplits : 0); ASSERT_EQ(taskStats.numFinishedSplits, numSplits); ASSERT_EQ( velox::exec::toPlanStats(taskStats) @@ -251,8 +284,9 @@ class StreamingAggregationTest : public HiveConnectorTestBase, core::PlanNodeId aggregationNodeId; auto plan = PlanBuilder(planNodeIdGenerator) .startTableScan() - .outputType(std::dynamic_pointer_cast( - inputVectors[0]->type())) + .outputType( + std::dynamic_pointer_cast( + inputVectors[0]->type())) .endTableScan() .streamingAggregation( {"c0"}, @@ -265,13 +299,30 @@ class StreamingAggregationTest : public HiveConnectorTestBase, false) .capturePlanNodeId(aggregationNodeId) .planNode(); - for (const auto barrierExecution : {false, true}) { - SCOPED_TRACE(fmt::format("barrierExecution {}", barrierExecution)); + struct { + bool hasBarrier; + bool serialExecution; + + std::string toString() const { + return fmt::format( + "hasBarrier: {}, serialExecution: {}", + hasBarrier, + serialExecution); + } + } testSettings[] = { + {false, true}, + {false, false}, + {true, true}, + {true, false}, + }; + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.toString()); auto task = AssertQueryBuilder(plan, duckDbQueryRunner_) .splits(makeHiveConnectorSplits(tempFiles)) - .serialExecution(true) - .barrierExecution(barrierExecution) + .serialExecution(testData.serialExecution) + .barrierExecution(testData.hasBarrier) + .maxDrivers(testData.serialExecution ? 1 : 3) .config( core::QueryConfig::kPreferredOutputBatchRows, std::to_string(outputBatchSize)) @@ -279,7 +330,7 @@ class StreamingAggregationTest : public HiveConnectorTestBase, "SELECT c0, array_agg(distinct c1), array_agg(c1 order by c2), " "count(distinct c1), array_agg(c2) FROM tmp GROUP BY c0"); const auto taskStats = task->taskStats(); - ASSERT_EQ(taskStats.numBarriers, barrierExecution ? numSplits : 0); + ASSERT_EQ(taskStats.numBarriers, testData.hasBarrier ? numSplits : 0); ASSERT_EQ(taskStats.numFinishedSplits, numSplits); ASSERT_EQ( velox::exec::toPlanStats(taskStats) @@ -294,8 +345,9 @@ class StreamingAggregationTest : public HiveConnectorTestBase, auto plan = PlanBuilder(planNodeIdGenerator) .startTableScan() - .outputType(std::dynamic_pointer_cast( - inputVectors[0]->type())) + .outputType( + std::dynamic_pointer_cast( + inputVectors[0]->type())) .endTableScan() .streamingAggregation( {"c0"}, {}, {}, core::AggregationNode::Step::kSingle, false) @@ -500,8 +552,9 @@ class StreamingAggregationTest : public HiveConnectorTestBase, auto plan = PlanBuilder(planNodeIdGenerator) .startTableScan() - .outputType(std::dynamic_pointer_cast( - inputVectors[0]->type())) + .outputType( + std::dynamic_pointer_cast( + inputVectors[0]->type())) .endTableScan() .streamingAggregation( keys[0]->type()->asRow().names(), @@ -524,18 +577,34 @@ class StreamingAggregationTest : public HiveConnectorTestBase, keySql.str(), keySql.str()); - for (const auto barrierExecution : {false, true}) { - SCOPED_TRACE(fmt::format("barrierExecution {}", barrierExecution)); + struct { + bool hasBarrier; + bool serialExecution; + + std::string toString() const { + return fmt::format( + "hasBarrier: {}, serialExecution: {}", + hasBarrier, + serialExecution); + } + } testSettings[] = { + {false, false}, + {false, true}, + {true, false}, + {true, true}, + }; + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.toString()); auto task = AssertQueryBuilder(plan, duckDbQueryRunner_) .splits(makeHiveConnectorSplits(tempFiles)) - .serialExecution(true) - .barrierExecution(barrierExecution) + .serialExecution(testData.serialExecution) + .barrierExecution(testData.hasBarrier) .config( core::QueryConfig::kPreferredOutputBatchRows, std::to_string(outputBatchSize)) .assertResults(sql); const auto taskStats = task->taskStats(); - ASSERT_EQ(taskStats.numBarriers, barrierExecution ? numSplits : 0); + ASSERT_EQ(taskStats.numBarriers, testData.hasBarrier ? numSplits : 0); ASSERT_EQ(taskStats.numFinishedSplits, numSplits); EXPECT_EQ(NonPODInt64::constructed, NonPODInt64::destructed); } @@ -545,8 +614,9 @@ class StreamingAggregationTest : public HiveConnectorTestBase, core::PlanNodeId aggregationNodeId; auto plan = PlanBuilder(planNodeIdGenerator) .startTableScan() - .outputType(std::dynamic_pointer_cast( - inputVectors[0]->type())) + .outputType( + std::dynamic_pointer_cast( + inputVectors[0]->type())) .endTableScan() .streamingAggregation( keys[0]->type()->asRow().names(), @@ -566,18 +636,34 @@ class StreamingAggregationTest : public HiveConnectorTestBase, const auto sql = fmt::format("SELECT distinct {} FROM tmp", keySql.str()); - for (const auto barrierExecution : {false, true}) { - SCOPED_TRACE(fmt::format("barrierExecution {}", barrierExecution)); + struct { + bool hasBarrier; + bool serialExecution; + + std::string toString() const { + return fmt::format( + "hasBarrier: {}, serialExecution: {}", + hasBarrier, + serialExecution); + } + } testSettings[] = { + {false, false}, + {false, true}, + {true, false}, + {true, true}, + }; + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.toString()); auto task = AssertQueryBuilder(plan, duckDbQueryRunner_) .splits(makeHiveConnectorSplits(tempFiles)) - .serialExecution(true) - .barrierExecution(barrierExecution) + .serialExecution(testData.serialExecution) + .barrierExecution(testData.hasBarrier) .config( core::QueryConfig::kPreferredOutputBatchRows, std::to_string(outputBatchSize)) .assertResults(sql); const auto taskStats = task->taskStats(); - ASSERT_EQ(taskStats.numBarriers, barrierExecution ? numSplits : 0); + ASSERT_EQ(taskStats.numBarriers, testData.hasBarrier ? numSplits : 0); ASSERT_EQ(taskStats.numFinishedSplits, numSplits); EXPECT_EQ(NonPODInt64::constructed, NonPODInt64::destructed); } @@ -588,13 +674,32 @@ class StreamingAggregationTest : public HiveConnectorTestBase, VELOX_INSTANTIATE_TEST_SUITE_P( StreamingAggregationTest, StreamingAggregationTest, - testing::ValuesIn({0, 1, 64, std::numeric_limits::max()}), - [](const testing::TestParamInfo& info) { + testing::Values( + TestParams{0, 1}, + TestParams{0, 1024}, + TestParams{0, std::numeric_limits::max()}, + TestParams{1, 1}, + TestParams{1, 1024}, + TestParams{1, std::numeric_limits::max()}, + TestParams{64, 1}, + TestParams{64, 1024}, + TestParams{64, std::numeric_limits::max()}, + TestParams{std::numeric_limits::max(), 1}, + TestParams{std::numeric_limits::max(), 1024}, + TestParams{ + std::numeric_limits::max(), + std::numeric_limits::max()}), + [](const testing::TestParamInfo& info) { return fmt::format( - "streamingMinOutputBatchSize_{}", - info.param == std::numeric_limits::max() + "streamingMinOutputBatchSize_{}_preferredOutputBatchBytes_{}", + info.param.streamingMinOutputBatchSize == + std::numeric_limits::max() + ? "inf" + : std::to_string(info.param.streamingMinOutputBatchSize), + info.param.preferredOutputBatchBytes == + std::numeric_limits::max() ? "inf" - : std::to_string(info.param)); + : std::to_string(info.param.preferredOutputBatchBytes)); }); TEST_P(StreamingAggregationTest, smallInputBatches) { @@ -1036,16 +1141,17 @@ TEST_P(StreamingAggregationTest, clusteredInputWithBarrier) { auto planNodeIdGenerator = std::make_shared(); core::PlanNodeId streamingAggregationNodeId; - auto plan = PlanBuilder(planNodeIdGenerator) - .startTableScan() - .outputType(std::dynamic_pointer_cast( - inputVectors[0]->type())) - .endTableScan() - .partialStreamingAggregation( - {"c0"}, {"count(c1)", "arbitrary(c1)", "array_agg(c1)"}) - .capturePlanNodeId(streamingAggregationNodeId) - .finalAggregation() - .planNode(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .startTableScan() + .outputType( + std::dynamic_pointer_cast(inputVectors[0]->type())) + .endTableScan() + .partialStreamingAggregation( + {"c0"}, {"count(c1)", "arbitrary(c1)", "array_agg(c1)"}) + .capturePlanNodeId(streamingAggregationNodeId) + .finalAggregation() + .planNode(); const auto expected = makeRowVector( {makeNullableFlatVector( {1, 2, std::nullopt, 3, 4, 9, 10, 11, 12, 17, 18, 19}), @@ -1137,6 +1243,166 @@ TEST_P(StreamingAggregationTest, constantInput) { config(AssertQueryBuilder(plan), 10).assertResults(expected); } +TEST_P(StreamingAggregationTest, preferredOutputBatchBytes) { + // Use grouping keys that span one or more batches. + std::vector keys = { + makeNullableFlatVector({1, 1, std::nullopt, 2, 2}), + makeFlatVector({2, 3, 3, 4}), + makeFlatVector({5, 6, 6, 6}), + makeFlatVector({6, 6, 6, 6}), + makeFlatVector({6, 7, 8}), + }; + + auto data = addPayload(keys, 1); + + auto plan = PlanBuilder() + .values(data) + .partialStreamingAggregation( + {"c0"}, + {"count(1)", + "min(c1)", + "max(c1)", + "sum(c1)", + "sumnonpod(1)", + "sum(cast(NULL as INT))"}) + .finalAggregation() + .planNode(); + + auto results = + config(AssertQueryBuilder(plan), 1024).copyResultBatches(pool_.get()); + + // If streamingMinOutputBatchSize is set to 1, we expect an output batch for: + // {1, NULL}, {2}, {3, 4}, {5}, {6}, {7, 8}. + // Otherwise, we expect the output batches to be determined by + // preferredOutputBatchBytes. + size_t expectedOutputBatches; + if (GetParam().streamingMinOutputBatchSize == 1) { + expectedOutputBatches = 6; + } else if (GetParam().preferredOutputBatchBytes == 1) { + expectedOutputBatches = 5; + } else if (GetParam().preferredOutputBatchBytes == 1024) { + expectedOutputBatches = 2; + } else { + ASSERT_EQ( + GetParam().preferredOutputBatchBytes, + std::numeric_limits::max()); + expectedOutputBatches = 1; + } + + ASSERT_EQ(results.size(), expectedOutputBatches); +} + +TEST_F(StreamingAggregationTest, noGroupsSpanBatchesSingleGroup) { + // Create input batches where each batch has exactly one unique group. + // This tests the corner case where numGroups_ == 1 and noGroupsSpanBatches_ + // is true. + std::vector keys = { + makeFlatVector({1, 1, 1, 1}), + makeFlatVector({2, 2, 2, 2}), + makeFlatVector({3, 3, 3, 3}), + }; + + auto data = addPayload(keys, 1); + createDuckDbTable(data); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId aggregationNodeId; + auto plan = PlanBuilder(planNodeIdGenerator) + .values(data) + .streamingAggregation( + {"c0"}, + {"count(1)", "sum(c1)"}, + {}, + core::AggregationNode::Step::kSingle, + /*ignoreNullKeys=*/false, + /*noGroupsSpanBatches=*/true) + .capturePlanNodeId(aggregationNodeId) + .planNode(); + + auto task = + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config( + core::QueryConfig::kStreamingAggregationMinOutputBatchRows, + std::to_string(1)) + .assertResults("SELECT c0, count(1), sum(c1) FROM tmp GROUP BY c0"); + + // Verify the number of output batches matches the number of input batches + // because each batch contains a single group that should be output + // immediately. + const auto taskStats = task->taskStats(); + ASSERT_EQ( + velox::exec::toPlanStats(taskStats).at(aggregationNodeId).outputVectors, + keys.size()); +} + +// Tests that when noGroupsSpanBatches is set, the number of output batches +// matches the number of input batches when minOutputBatchRows is set to 1. +// When minOutputBatchRows is set to an extremely large value, we expect a +// single output batch. +TEST_F(StreamingAggregationTest, noGroupsSpanBatches) { + // Create input batches where no group spans across batches. + // Each batch has unique grouping keys that don't appear in other batches. + std::vector keys = { + makeFlatVector({1, 1, 2, 2}), + makeFlatVector({3, 3, 4, 4}), + makeFlatVector({5, 5, 6, 6}), + makeFlatVector({7, 7, 8, 8}), + makeFlatVector({9, 9, 10, 10}), + }; + + auto data = addPayload(keys, 1); + createDuckDbTable(data); + + struct { + int32_t minOutputBatchRows; + size_t expectedOutputBatches; + + std::string debugString() const { + return fmt::format( + "minOutputBatchRows={}, expectedOutputBatches={}", + minOutputBatchRows, + expectedOutputBatches); + } + } testSettings[] = { + // Regardless of the value of minOutputBatchRows, each input batch + // produces an output batch. + // We do not respect minOutputBatchRows when noGroupsSpanBatches is set. + {1, keys.size()}, + {std::numeric_limits::max(), keys.size()}, + }; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId aggregationNodeId; + auto plan = PlanBuilder(planNodeIdGenerator) + .values(data) + .streamingAggregation( + {"c0"}, + {"count(1)", "sum(c1)"}, + {}, + core::AggregationNode::Step::kSingle, + /*ignoreNullKeys=*/false, + /*noGroupsSpanBatches=*/true) + .capturePlanNodeId(aggregationNodeId) + .planNode(); + + auto task = + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config( + core::QueryConfig::kStreamingAggregationMinOutputBatchRows, + std::to_string(testData.minOutputBatchRows)) + .assertResults("SELECT c0, count(1), sum(c1) FROM tmp GROUP BY c0"); + + // Verify the number of output batches. + const auto taskStats = task->taskStats(); + ASSERT_EQ( + velox::exec::toPlanStats(taskStats).at(aggregationNodeId).outputVectors, + testData.expectedOutputBatches); + } +} + namespace { class InputSourceNode : public core::PlanNode { public: @@ -1286,5 +1552,41 @@ TEST_P(StreamingAggregationTest, needsInputWhenSplitOutput) { velox::exec::toPlanStats(taskStats).at(aggregationNodeId).outputVectors, 9); } +// Verify that during createOutput, the aggregation state is destroyed. +DEBUG_ONLY_TEST_P(StreamingAggregationTest, singleAggregationCleansState) { + auto size = 1'000; + + VectorPtr keys = + makeFlatVector(size, [](auto row) { return row / 10; }); + + auto data = addPayload({keys, keys, keys, keys}, 1); + + bool checkedAtLeastOnce = false; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::StreamingAggregation::createOutput", + std::function([&](StreamingAggregation*) { + checkedAtLeastOnce = true; + EXPECT_GT(NonPODInt64::destructed, 0); + })); + + auto plan = PlanBuilder() + .values(data) + .streamingAggregation( + {"c0"}, + {"sumnonpod(1)"}, + {}, + core::AggregationNode::Step::kSingle, + false, + true /* noGroupsSpanBatches */) + .planNode(); + + config(AssertQueryBuilder(plan), 100).copyResults(pool()); + + EXPECT_TRUE(checkedAtLeastOnce) + << "TestValue callback was never invoked; createOutput may not have " + "been called"; + EXPECT_EQ(NonPODInt64::constructed, NonPODInt64::destructed); +} + } // namespace } // namespace facebook::velox::exec diff --git a/velox/exec/tests/StreamingEnforceDistinctTest.cpp b/velox/exec/tests/StreamingEnforceDistinctTest.cpp new file mode 100644 index 00000000000..df70a816bae --- /dev/null +++ b/velox/exec/tests/StreamingEnforceDistinctTest.cpp @@ -0,0 +1,149 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" + +using namespace facebook::velox::exec::test; + +namespace facebook::velox::exec { +namespace { + +class StreamingEnforceDistinctTest : public OperatorTestBase { + protected: + core::PlanNodePtr makePlan( + const std::vector& input, + const std::vector& keys, + const std::string& errorMessage) { + return PlanBuilder() + .values(input) + .streamingEnforceDistinct(keys, errorMessage) + .planNode(); + } + + core::PlanNodePtr makePlan( + const RowVectorPtr& input, + const std::string& key, + const std::string& errorMessage) { + return PlanBuilder() + .values({input}) + .streamingEnforceDistinct({key}, errorMessage) + .planNode(); + } + + void assertDistinct( + const std::vector& input, + const std::vector& keys) { + auto plan = makePlan(input, keys, "Duplicate key found"); + AssertQueryBuilder(plan).assertResults(input); + } +}; + +TEST_F(StreamingEnforceDistinctTest, uniqueRowsSingleKey) { + auto data = makeRowVector({ + makeNullableFlatVector({1, 2, 3, std::nullopt, 5, 6, 7, 8, 9}), + makeFlatVector( + {"a", "a", "b", "b", "a", "a", "b", "b", "a"}), + }); + + assertDistinct(split(data, 3), {"c0"}); +} + +TEST_F(StreamingEnforceDistinctTest, uniqueRowsMultipleKeys) { + auto data = makeRowVector({ + makeFlatVector({1, 1, 2, 2, 3, 3, 4, 4, 5}), + makeFlatVector( + {"x", "x", "y", "y", "x", "x", "y", "y", "x"}), + makeFlatVector({10, 20, 10, 20, 10, 20, 10, 20, 10}), + }); + + assertDistinct(split(data, 3), {"c0", "c2"}); +} + +TEST_F(StreamingEnforceDistinctTest, duplicateWithinBatch) { + auto data = makeRowVector({ + makeFlatVector({1, 2, 2, 3, 4}), + }); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(makePlan(data, "c0", "Duplicate key found")) + .countResults(), + "Duplicate key found"); +} + +TEST_F(StreamingEnforceDistinctTest, duplicateAcrossBatches) { + auto batch1 = makeRowVector({ + makeFlatVector({1, 2, 3}), + }); + + auto batch2 = makeRowVector({ + makeFlatVector({3, 4, 5}), + }); + + VELOX_ASSERT_THROW( + AssertQueryBuilder( + makePlan({batch1, batch2}, {"c0"}, "Duplicate key found")) + .countResults(), + "Duplicate key found"); +} + +TEST_F(StreamingEnforceDistinctTest, emptyInput) { + auto data = makeRowVector({ + makeFlatVector({}), + }); + + assertDistinct({data}, {"c0"}); +} + +TEST_F(StreamingEnforceDistinctTest, singleRow) { + auto data = makeRowVector({ + makeFlatVector({42}), + }); + + assertDistinct({data}, {"c0"}); +} + +TEST_F(StreamingEnforceDistinctTest, duplicateNulls) { + auto data = makeRowVector({ + makeNullableFlatVector({1, std::nullopt, std::nullopt, 4, 5}), + }); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(makePlan(data, "c0", "Duplicate key found")) + .countResults(), + "Duplicate key found"); +} + +TEST_F(StreamingEnforceDistinctTest, duplicateNullsAcrossBatches) { + auto batch1 = makeRowVector({ + makeNullableFlatVector({1, 2, std::nullopt}), + }); + + auto batch2 = makeRowVector({ + makeNullableFlatVector({std::nullopt, 4, 5}), + }); + + VELOX_ASSERT_THROW( + AssertQueryBuilder( + makePlan({batch1, batch2}, {"c0"}, "Duplicate key found")) + .countResults(), + "Duplicate key found"); +} + +} // namespace +} // namespace facebook::velox::exec diff --git a/velox/exec/tests/TableEvolutionFuzzer.cpp b/velox/exec/tests/TableEvolutionFuzzer.cpp index a353cbe5057..f9bb3814549 100644 --- a/velox/exec/tests/TableEvolutionFuzzer.cpp +++ b/velox/exec/tests/TableEvolutionFuzzer.cpp @@ -15,13 +15,13 @@ */ #include "velox/exec/tests/TableEvolutionFuzzer.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/connectors/hive/HiveConnectorSplit.h" #include "velox/dwio/common/tests/utils/FilterGenerator.h" #include "velox/dwio/dwrf/common/Config.h" #include "velox/exec/Cursor.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/exec/tests/utils/QueryAssertions.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/expression/fuzzer/ExpressionFuzzer.h" #include "velox/functions/FunctionRegistry.h" #include "velox/vector/fuzzer/VectorFuzzer.h" @@ -48,7 +48,15 @@ DEFINE_bool( "up after failures. Therefore, results are not compared when this is " "enabled. Note that this option only works in debug builds."); +DEFINE_int32( + aggregation_pushdown_frequency, + 5, + "Controls the frequency of aggregation pushdown. The aggregation pushdown " + "is enabled with probability 1/N where N is this value. For example, " + "N=5 means 20% chance, N=2 means 50% chance."); + namespace facebook::velox::exec::test { +using namespace facebook::velox::common::testutil; std::ostream& operator<<( std::ostream& out, @@ -70,6 +78,14 @@ VectorFuzzer::Options makeVectorFuzzerOptions() { return options; } +template +void removeFromVector(std::vector& vec, const T& value) { + auto it = std::find(vec.begin(), vec.end(), value); + if (it != vec.end()) { + vec.erase(it); + } +} + bool hasUnsupportedMapKey(const TypePtr& type) { switch (type->kind()) { case TypeKind::MAP: { @@ -172,6 +188,43 @@ TableEvolutionFuzzer::parseFileFormats(std::string input) { namespace { +// Helper function to randomly select aggregates from available columns +// without replacement. Returns a list of aggregate expressions. +void generateAggregatesForColumns( + const std::vector& availableColumns, + const std::vector& supportedAggFuncs, + const RowTypePtr& schema, + FuzzerGenerator& rng, + std::vector& aggregates) { + if (availableColumns.empty()) { + return; + } + + int numAggregates = std::min( + static_cast(availableColumns.size()), + std::min( + static_cast(5), + static_cast( + folly::Random::rand32(1, availableColumns.size() + 1, rng)))); + + std::unordered_set selectedIndices; + for (int i = 0; i < numAggregates; ++i) { + if (folly::Random::oneIn(2, rng)) { + int randomIdx; + do { + randomIdx = folly::Random::rand32(availableColumns.size(), rng); + } while (selectedIndices.count(randomIdx) > 0); + selectedIndices.insert(randomIdx); + + int colIdx = availableColumns[randomIdx]; + std::string aggFunc = supportedAggFuncs[folly::Random::rand32( + supportedAggFuncs.size(), rng)]; + aggregates.push_back( + fmt::format("{}({})", aggFunc, schema->nameOf(colIdx))); + } + } +} + std::vector> runTaskCursors( const std::vector>& cursors, folly::Executor& executor) { @@ -218,9 +271,9 @@ std::vector> runTaskCursors( }); } std::vector> results; - constexpr std::chrono::seconds kTaskTimeout(10); + results.reserve(futures.size()); for (auto& future : futures) { - results.push_back(std::move(future).get(kTaskTimeout)); + results.push_back(std::move(future).get()); } return results; } @@ -243,7 +296,7 @@ void buildScanSplitFromTableWriteResult( auto* fragments = writeResult[0]->childAt(1)->asChecked>(); for (int i = 1; i < writeResult[0]->size(); ++i) { - auto fragment = folly::parseJson(fragments->valueAt(i)); + auto fragment = folly::parseJson(std::string_view(fragments->valueAt(i))); auto fileName = fragment["fileWriteInfos"][0]["writeFileName"].asString(); auto hiveSplit = std::make_shared( TableEvolutionFuzzer::connectorId(), @@ -268,7 +321,7 @@ void buildScanSplitFromTableWriteResult( for (auto bucketColumnIndex : bucketColumnIndices) { auto handle = std::make_unique( tableSchema->nameOf(bucketColumnIndex), - connector::hive::HiveColumnHandle::ColumnType::kRegular, + connector::hive::FileColumnHandle::ColumnType::kRegular, tableSchema->childAt(bucketColumnIndex), tableSchema->childAt(bucketColumnIndex)); bucketConversion.bucketColumnHandles.push_back(std::move(handle)); @@ -392,6 +445,102 @@ fuzzer::ExpressionFuzzer::FuzzedExpressionData generateRemainingFilters( return expressionFuzzer.fuzzExpressions(1); } +// Generate random aggregation configuration for pushdown testing. +// Only generates aggregations that are eligible for pushdown: +// - Supported aggregate functions: min, max, bool_and, bool_or +// - Each column can only be used by at most one aggregate +// - Grouping keys are optional (can be empty for global aggregation) +// - Columns with filters (subfield or remaining) are excluded to enable +// pushdown +std::optional generateAggregationConfig( + const RowTypePtr& schema, + FuzzerGenerator& rng, + const std::unordered_set& filteredColumns) { + // List of aggregate functions that support pushdown + // Note: Excluding 'sum' to avoid integer overflow in fuzzer with random data + static const std::vector supportedNumericAggs = {"min", "max"}; + static const std::vector supportedBooleanAggs = { + "bool_and", "bool_or"}; + static const std::vector supportedIntegerAggs = { + "bitwise_and_agg", "bitwise_or_agg", "bitwise_xor_agg"}; + + // Randomly decide number of grouping keys (0 to 2) + int numGroupingKeys = folly::Random::rand32(3, rng); + std::vector groupingKeys; + std::unordered_set usedColumnIndices; + + // Select random columns for grouping keys + for (int i = 0; i < numGroupingKeys && i < schema->size(); ++i) { + int colIdx = folly::Random::rand32(schema->size(), rng); + if (usedColumnIndices.count(colIdx) == 0) { + groupingKeys.push_back(schema->nameOf(colIdx)); + usedColumnIndices.insert(colIdx); + } + } + + // Generate aggregates on remaining columns + // For aggregation pushdown to work, each column should only be used once + // and columns with filters should be excluded + std::vector aggregates; + std::vector availableNumericColumns; + std::vector availableIntegerColumns; + std::vector availableBooleanColumns; + for (int i = 0; i < schema->size(); ++i) { + if (usedColumnIndices.count(i) == 0) { + auto columnName = schema->nameOf(i); + // Skip columns that have filters (subfield or remaining) + if (filteredColumns.count(columnName) > 0) { + continue; + } + + auto type = schema->childAt(i); + // Integer types: randomly choose between min/max or bitwise aggregations + // Note: Exclude DATE/Interval type as it doesn't support bitwise + // aggregations + if ((type->isInteger() || type->isBigint() || type->isSmallint() || + type->isTinyint()) && + !type->isDate() && !type->isIntervalDayTime() && + !type->isIntervalYearMonth()) { + if (folly::Random::oneIn(2, rng)) { + availableIntegerColumns.push_back(i); + } else { + availableNumericColumns.push_back(i); + } + } + // Float types support min/max only + else if ((type->isReal() || type->isDouble()) && !type->isDecimal()) { + availableNumericColumns.push_back(i); + } + // Boolean types support bool_and/bool_or + else if (type->isBoolean()) { + availableBooleanColumns.push_back(i); + } + } + } + + // Need at least one column to aggregate + if (availableNumericColumns.empty() && availableBooleanColumns.empty() && + availableIntegerColumns.empty()) { + return std::nullopt; + } + + // Randomly pick columns for aggregates without replacement + generateAggregatesForColumns( + availableNumericColumns, supportedNumericAggs, schema, rng, aggregates); + generateAggregatesForColumns( + availableBooleanColumns, supportedBooleanAggs, schema, rng, aggregates); + generateAggregatesForColumns( + availableIntegerColumns, supportedIntegerAggs, schema, rng, aggregates); + + if (aggregates.empty()) { + return std::nullopt; + } + + return AggregationConfig{ + .groupingKeys = std::move(groupingKeys), + .aggregates = std::move(aggregates)}; +} + } // namespace VectorPtr TableEvolutionFuzzer::liftToType( @@ -484,8 +633,9 @@ VectorPtr TableEvolutionFuzzer::liftToType( if (i < children.size()) { children[i] = liftToType(children[i], childType); } else { - children.push_back(BaseVector::createNullConstant( - childType, row->size(), config_.pool)); + children.push_back( + BaseVector::createNullConstant( + childType, row->size(), config_.pool)); } } return std::make_shared( @@ -559,12 +709,17 @@ void TableEvolutionFuzzer::run() { 2 * config_.evolutionCount - 1); RowVectorPtr finalExpectedData; + folly::F14FastMap> globalMapColumnKeys; + std::vector globallyConsistentColumnIndexVector; + createWriteTasks( testSetups, bucketColumnIndices, tableOutputRootDir->getPath(), writeTasks, - finalExpectedData); + finalExpectedData, + globalMapColumnKeys, + globallyConsistentColumnIndexVector); auto executor = folly::getGlobalCPUExecutor(); auto writeResults = runTaskCursors(writeTasks, *executor); @@ -584,17 +739,17 @@ void TableEvolutionFuzzer::run() { selectedBucket, finalExpectedData); - // Step 5: Setup scan tasks with filters + // Step 5: Setup scan tasks with filters and optional aggregation pushdown auto rowType = testSetups.back().schema; - PushdownConfig pushownConfig; + PushdownConfig pushdownConfig; // Generate subfield filters first - pushownConfig.subfieldFiltersMap = + pushdownConfig.subfieldFiltersMap = generateSubfieldFilters(rowType, finalExpectedData); // Extract field names used by subfield filters to avoid conflicts std::unordered_set subfieldFilteredFields; - for (const auto& [subfield, filter] : pushownConfig.subfieldFiltersMap) { + for (const auto& [subfield, filter] : pushdownConfig.subfieldFiltersMap) { auto fieldName = subfield.toString(); VLOG(1) << "Raw subfield: " << fieldName << ' ' << filter->toString(); // Extract the root field name (before any nested access) @@ -613,15 +768,63 @@ void TableEvolutionFuzzer::run() { applyRemainingFilters( generatedRemainingFilters, columnNameMapping, - pushownConfig, + pushdownConfig, subfieldFilteredFields); } + // Collect all filtered columns (both subfield and remaining filters) + std::unordered_set allFilteredColumns = subfieldFilteredFields; + + // Extract columns from remaining filter if present + if (!pushdownConfig.remainingFilter.empty()) { + for (const auto& name : rowType->names()) { + // Check if the column name appears in the remaining filter + if (pushdownConfig.remainingFilter.find(name) != std::string::npos) { + allFilteredColumns.insert(name); + } + } + } + + // Enable aggregation testing + std::optional aggConfig; + bool shouldTestAggregation = + folly::Random::oneIn(FLAGS_aggregation_pushdown_frequency, rng_); + if (shouldTestAggregation) { + aggConfig = generateAggregationConfig(rowType, rng_, allFilteredColumns); + if (aggConfig.has_value()) { + VLOG(1) << "Testing aggregation pushdown with grouping keys: [" + << folly::join(", ", aggConfig->groupingKeys) + << "] and aggregates: [" + << folly::join(", ", aggConfig->aggregates) << "]"; + } else { + VLOG(1) << "Could not generate valid aggregation configuration"; + aggConfig = std::nullopt; + } + } + std::vector> scanTasks(2); - scanTasks[0] = - makeScanTask(rowType, std::move(actualSplits), pushownConfig, false); - scanTasks[1] = - makeScanTask(rowType, std::move(expectedSplits), pushownConfig, true); + // actual: TableScan -> Aggregation (allows pushdown) + pushdownConfig.aggregationConfig = aggConfig; + scanTasks[0] = makeScanTask( + rowType, + std::move(actualSplits), + pushdownConfig, + false, + false, // insertProjectToBlockPushdown + globalMapColumnKeys, + globallyConsistentColumnIndexVector); + + // expected: TableScan -> Project -> Aggregation (blocks pushdown) + // Insert a Project node to prevent aggregation pushdown + pushdownConfig.aggregationConfig = aggConfig; + scanTasks[1] = makeScanTask( + rowType, + std::move(expectedSplits), + pushdownConfig, + true, + true, // insertProjectToBlockPushdown + globalMapColumnKeys, + globallyConsistentColumnIndexVector); ScopedOOMInjector oomInjectorReadPath( [this]() -> bool { return folly::Random::oneIn(10, rng_); }, @@ -799,7 +1002,9 @@ std::unique_ptr TableEvolutionFuzzer::makeWriteTask( const std::string& outputDir, const std::vector& bucketColumnIndices, FuzzerGenerator& rng, - bool enableFlatMap) { + bool enableFlatMap, + folly::F14FastMap>& globalMapColumnKeys, + std::vector& globallyCompatibleFlatmapColumns) { auto builder = PlanBuilder().values({data}); // Create serdeParameters using proper dwrf::Config for flatmap configuration @@ -813,6 +1018,7 @@ std::unique_ptr TableEvolutionFuzzer::makeWriteTask( if (setup.schema->childAt(i)->isMap()) { // Check if this specific map column has any empty elements if (hasEmptyElement(data, i)) { + removeFromVector(globallyCompatibleFlatmapColumns, i); continue; } @@ -822,7 +1028,76 @@ std::unique_ptr TableEvolutionFuzzer::makeWriteTask( supportedMapColumnIndices.push_back(static_cast(i)); VLOG(1) << "Write column " << setup.schema->nameOf(i) << " as flatmap"; + + // Extract actual keys from the map data and collect directly into + // global set + SelectivityVector allRows(data->childAt(i)->size()); + DecodedVector decodedMap(*data->childAt(i), allRows); + auto* mapVector = decodedMap.base()->asChecked(); + if (mapVector->size() > 0) { + auto keys = mapVector->mapKeys(); + + if (keys) { + // Collect keys directly into the global set + auto& uniqueKeys = globalMapColumnKeys[static_cast(i)]; + + // Iterate through the decoded rows, not the raw mapVector + // indices + for (vector_size_t row = 0; row < data->childAt(i)->size(); + ++row) { + auto decodedIndex = decodedMap.index(row); + if (!decodedMap.isNullAt(row) && + !mapVector->isNullAt(decodedIndex)) { + // Get the map entry for this decoded row + auto mapOffset = mapVector->offsetAt(decodedIndex); + auto mapSize = mapVector->sizeAt(decodedIndex); + + // Process all keys in this map entry + for (vector_size_t keyIdx = 0; keyIdx < mapSize; ++keyIdx) { + auto keyPosition = mapOffset + keyIdx; + if (!keys->isNullAt(keyPosition)) { + std::string keyStr; + if (keys->type()->isVarchar() || + keys->type()->isVarbinary()) { + auto* keyVector = keys->asFlatVector(); + auto keyView = keyVector->valueAt(keyPosition); + keyStr = std::string(keyView); + } else if (keys->type()->isInteger()) { + auto* keyVector = keys->asFlatVector(); + auto keyVal = keyVector->valueAt(keyPosition); + keyStr = std::to_string(keyVal); + } else if (keys->type()->isBigint()) { + auto* keyVector = keys->asFlatVector(); + auto keyVal = keyVector->valueAt(keyPosition); + keyStr = std::to_string(keyVal); + } else if (keys->type()->isSmallint()) { + auto* keyVector = keys->asFlatVector(); + auto keyVal = keyVector->valueAt(keyPosition); + keyStr = std::to_string(keyVal); + } else if (keys->type()->isTinyint()) { + auto* keyVector = keys->asFlatVector(); + auto keyVal = keyVector->valueAt(keyPosition); + keyStr = std::to_string(keyVal); + } else { + // This should not be reached since + // hasUnsupportedMapKey filters out unsupported types + VELOX_UNREACHABLE( + "Unsupported map key type: {}", + keys->type()->toString()); + } + uniqueKeys.insert(keyStr); + } + } + } + } + } + } + } else { + // Remove this column from globallyCompatibleFlatmapColumns + removeFromVector(globallyCompatibleFlatmapColumns, i); } + } else { + removeFromVector(globallyCompatibleFlatmapColumns, i); } } } @@ -906,22 +1181,101 @@ VectorPtr TableEvolutionFuzzer::liftToPrimitiveType( std::vector({})); } +RowTypePtr TableEvolutionFuzzer::buildFlatmapAsStructSchema( + const RowTypePtr& tableSchema, + const folly::F14FastMap>& + globalMapColumnKeys, + const std::vector& globallyCompatibleFlatmapColumns) { + if (globallyCompatibleFlatmapColumns.empty()) { + return tableSchema; + } + + VLOG(1) << "Setting up struct reading for " + << globallyCompatibleFlatmapColumns.size() + << " flatmap columns with real keys"; + + auto names = tableSchema->names(); + auto types = tableSchema->children(); + + // Filter globalMapColumnKeys to only include globally compatible columns + std::unordered_map> filteredMapColumnKeys; + for (int mapColumnIndex : globallyCompatibleFlatmapColumns) { + if (globalMapColumnKeys.find(mapColumnIndex) != globalMapColumnKeys.end()) { + // Add 50% probability to include this column in filteredMapColumnKeys + if (folly::Random::oneIn(2, rng_)) { + filteredMapColumnKeys[mapColumnIndex] = + globalMapColumnKeys.at(mapColumnIndex); + } + } + } + + // Use the filteredMapColumnKeys for struct reading + for (const auto& [mapColumnIndex, keysSet] : filteredMapColumnKeys) { + // Convert map type to struct type for struct reading + auto finalMapType = types[mapColumnIndex]->asMap(); + auto finalValueType = finalMapType.valueType(); + // Convert F14FastSet to vector for ROW constructor + std::vector keys(keysSet.begin(), keysSet.end()); + // Construct struct schema with real keys from write time + final value + // type + std::vector finalStructFieldTypes(keys.size(), finalValueType); + auto finalStructSchema = ROW(keys, finalStructFieldTypes); + + // Replace the map type with struct type in the schema + types[mapColumnIndex] = finalStructSchema; + } + + // Build new schema using struct reading for flatmap columns + return ROW(names, types); +} + std::unique_ptr TableEvolutionFuzzer::makeScanTask( const RowTypePtr& tableSchema, std::vector splits, const PushdownConfig& pushdownConfig, - bool useFiltersAsNode) { + bool useFiltersAsNode, + bool insertProjectToBlockPushdown, + const folly::F14FastMap>& + globalMapColumnKeys, + const std::vector& globallyCompatibleFlatmapColumns) { + // Build schema for flatmap as struct reading + RowTypePtr newSchemaUsingStructReadingFlatMap = buildFlatmapAsStructSchema( + tableSchema, globalMapColumnKeys, globallyCompatibleFlatmapColumns); + CursorParameters params; params.serialExecution = true; - // TODO: Mix in filter and aggregate pushdowns. - params.planNode = PlanBuilder() - .filtersAsNode(useFiltersAsNode) - .tableScanWithPushDown( - tableSchema, - /*pushdownConfig=*/pushdownConfig, - tableSchema, - {}) - .planNode(); + + auto builder = PlanBuilder() + .filtersAsNode(useFiltersAsNode) + .tableScanWithPushDown( + newSchemaUsingStructReadingFlatMap, // Use struct + // schema for + // flatmap reading + /*pushdownConfig=*/pushdownConfig, + tableSchema, // Original schema as dataColumns + {}); + + // If insertProjectToBlockPushdown is set, insert an identity Project node + // to prevent Driver::mayPushdownAggregation() from allowing pushdown + if (insertProjectToBlockPushdown && + pushdownConfig.aggregationConfig.has_value()) { + // Create identity projection: simply pass through all columns + std::vector projectExprs; + for (const auto& name : newSchemaUsingStructReadingFlatMap->names()) { + projectExprs.push_back(name); + } + builder.project(projectExprs); + } + + // Add aggregation if enabled in pushdown config + if (pushdownConfig.aggregationConfig.has_value()) { + builder.singleAggregation( + pushdownConfig.aggregationConfig->groupingKeys, + pushdownConfig.aggregationConfig->aggregates); + } + + params.planNode = builder.planNode(); + auto cursor = TaskCursor::create(params); for (auto& split : splits) { cursor->task()->addSplit("0", std::move(split)); @@ -986,16 +1340,41 @@ void TableEvolutionFuzzer::createWriteTasks( const std::vector& bucketColumnIndices, const std::string& tableOutputRootDirPath, std::vector>& writeTasks, - RowVectorPtr& finalExpectedData) { + RowVectorPtr& finalExpectedData, + folly::F14FastMap>& globalMapColumnKeys, + std::vector& globallyConsistentColumnIndexVector) { + // Initialize globallyConsistentColumnIndexVector with all map column indices + // from the first schema, then filter out incompatible ones during processing + if (hasMapColumns(testSetups[0].schema)) { + for (int j = 0; j < testSetups[0].schema->size(); ++j) { + if (testSetups[0].schema->childAt(j)->isMap() && + !hasUnsupportedMapKey(testSetups[0].schema->childAt(j))) { + globallyConsistentColumnIndexVector.push_back(j); + } + } + } + + // Generate data and create write tasks in a single loop for (int i = 0; i < config_.evolutionCount; ++i) { + // Generate fresh data for each evolution step independently auto data = vectorFuzzer_.fuzzRow(testSetups[i].schema, kVectorSize, false); for (auto& child : data->children()) { BaseVector::flattenVector(child); } + auto actualDir = fmt::format("{}/actual_{}", tableOutputRootDirPath, i); VELOX_CHECK(std::filesystem::create_directory(actualDir)); + + // Pass globally consistent columns to restrict flatmap usage writeTasks[2 * i] = makeWriteTask( - testSetups[i], data, actualDir, bucketColumnIndices, rng_, true); + testSetups[i], + data, + actualDir, + bucketColumnIndices, + rng_, + true, + globalMapColumnKeys, + globallyConsistentColumnIndexVector); if (i == config_.evolutionCount - 1) { finalExpectedData = std::move(data); @@ -1012,7 +1391,9 @@ void TableEvolutionFuzzer::createWriteTasks( expectedDir, bucketColumnIndices, rng_, - true); + true, + globalMapColumnKeys, + globallyConsistentColumnIndexVector); } } diff --git a/velox/exec/tests/TableEvolutionFuzzer.h b/velox/exec/tests/TableEvolutionFuzzer.h index f00a25923a2..80f0ff95538 100644 --- a/velox/exec/tests/TableEvolutionFuzzer.h +++ b/velox/exec/tests/TableEvolutionFuzzer.h @@ -105,7 +105,10 @@ class TableEvolutionFuzzer { const std::string& outputDir, const std::vector& bucketColumnIndices, FuzzerGenerator& rng, - bool enableFlatMap); + bool enableFlatMap, + folly::F14FastMap>& + globalMapColumnKeys, + std::vector& globallyCompatibleFlatmapColumns); template VectorPtr liftToPrimitiveType( @@ -118,7 +121,19 @@ class TableEvolutionFuzzer { const RowTypePtr& tableSchema, std::vector splits, const PushdownConfig& pushdownConfig, - bool useFiltersAsNode); + bool useFiltersAsNode, + bool insertProjectToBlockPushdown, + const folly::F14FastMap>& + globalMapColumnKeys = {}, + const std::vector& globallyCompatibleFlatmapColumns = {}); + + /// Builds schema for flatmap as struct reading by converting selected map + /// columns to struct types. + RowTypePtr buildFlatmapAsStructSchema( + const RowTypePtr& tableSchema, + const folly::F14FastMap>& + globalMapColumnKeys, + const std::vector& globallyCompatibleFlatmapColumns); /// Randomly generates bucket column indices for partitioning data. /// Returns a vector of column indices that will be used for bucketing, @@ -134,7 +149,10 @@ class TableEvolutionFuzzer { const std::vector& bucketColumnIndices, const std::string& tableOutputRootDirPath, std::vector>& writeTasks, - RowVectorPtr& finalExpectedData); + RowVectorPtr& finalExpectedData, + folly::F14FastMap>& + globalMapColumnKeys, + std::vector& globallyConsistentColumnIndexVector); /// Creates scan splits from write results. /// Converts the output of write tasks into scan splits that can be used diff --git a/velox/exec/tests/TableEvolutionFuzzerTest.cpp b/velox/exec/tests/TableEvolutionFuzzerTest.cpp index 45039496844..41fa19997af 100644 --- a/velox/exec/tests/TableEvolutionFuzzerTest.cpp +++ b/velox/exec/tests/TableEvolutionFuzzerTest.cpp @@ -15,9 +15,12 @@ */ #include "velox/exec/tests/TableEvolutionFuzzer.h" +#include "velox/connectors/ConnectorRegistry.h" #include "velox/connectors/hive/HiveConnector.h" +#include "velox/dwio/common/FileSink.h" #include "velox/dwio/dwrf/RegisterDwrfReader.h" #include "velox/dwio/dwrf/RegisterDwrfWriter.h" +#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" #include @@ -40,9 +43,11 @@ void registerFactories(folly::Executor* ioExecutor) { auto hiveConnector = factory.newConnector( TableEvolutionFuzzer::connectorId(), std::make_shared( - std::unordered_map()), + std::unordered_map{ + {connector::hive::HiveConfig::kEnableFileHandleCache, "false"}}), ioExecutor); - connector::registerConnector(hiveConnector); + connector::ConnectorRegistry::global().insert( + hiveConnector->connectorId(), hiveConnector); dwio::common::registerFileSinks(); dwrf::registerDwrfReaderFactory(); dwrf::registerDwrfWriterFactory(); @@ -84,6 +89,7 @@ int main(int argc, char** argv) { auto ioExecutor = folly::getGlobalIOExecutor(); facebook::velox::exec::test::registerFactories(ioExecutor.get()); facebook::velox::functions::prestosql::registerAllScalarFunctions(); + facebook::velox::aggregate::prestosql::registerAllAggregateFunctions(); facebook::velox::parse::registerTypeResolver(); return RUN_ALL_TESTS(); } diff --git a/velox/exec/tests/TableScanTest.cpp b/velox/exec/tests/TableScanTest.cpp index 2bc53f8fcf6..4e580276971 100644 --- a/velox/exec/tests/TableScanTest.cpp +++ b/velox/exec/tests/TableScanTest.cpp @@ -17,8 +17,8 @@ #include #include -#include #include +#include #include #include @@ -30,12 +30,16 @@ #include "velox/common/file/tests/FaultyFile.h" #include "velox/common/file/tests/FaultyFileSystem.h" #include "velox/common/memory/MemoryArbitrator.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/common/testutil/TestValue.h" +#include "velox/connectors/ConnectorRegistry.h" +#include "velox/connectors/hive/ExtractionUtils.h" #include "velox/connectors/hive/HiveConfig.h" #include "velox/connectors/hive/HiveConnector.h" #include "velox/connectors/hive/HiveDataSource.h" #include "velox/connectors/hive/HivePartitionFunction.h" #include "velox/dwio/common/tests/utils/DataFiles.h" +#include "velox/dwio/orc/reader/OrcReader.h" #include "velox/exec/Cursor.h" #include "velox/exec/Exchange.h" #include "velox/exec/PlanNodeStats.h" @@ -44,8 +48,8 @@ #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/exec/tests/utils/TableScanTestBase.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/expression/ExprToSubfieldFilter.h" +#include "velox/functions/lib/IsNull.h" #include "velox/type/Timestamp.h" #include "velox/type/Type.h" #include "velox/type/tests/SubfieldFiltersBuilder.h" @@ -61,6 +65,7 @@ using namespace facebook::velox::tests::utils; DECLARE_int32(cache_prefetch_min_pct); namespace facebook::velox::exec { +using namespace facebook::velox::common::testutil; namespace { void verifyCacheStats( const FileHandleCacheStats& cacheStats, @@ -72,7 +77,12 @@ void verifyCacheStats( EXPECT_EQ(cacheStats.numLookups, numLookups); } -class TableScanTest : public TableScanTestBase {}; +class TableScanTest : public TableScanTestBase { + void SetUp() override { + TableScanTestBase::SetUp(); + orc::registerOrcReaderFactory(); + } +}; TEST_F(TableScanTest, allColumns) { auto vectors = makeVectors(10, 1'000); @@ -112,6 +122,13 @@ TEST_F(TableScanTest, directBufferInputRawInputBytes) { .endTableScan() .planNode(); + // Disable file preloading to ensure individual stream reads are tracked + // for overreadBytes verification. + resetHiveConnector( + std::make_shared( + std::unordered_map{ + {connector::hive::HiveConfig::kFilePreloadThreshold, "0"}})); + std::unordered_map config; std::unordered_map> connectorConfigs = {}; @@ -137,7 +154,8 @@ TEST_F(TableScanTest, directBufferInputRawInputBytes) { auto overreadBytes = getTableScanRuntimeStats(task).at("overreadBytes").sum; ASSERT_GE(rawInputBytes, 500); ASSERT_EQ(overreadBytes, 13); - ASSERT_EQ( + // Without preloading, storageReadBytes is the sum of individual stream reads. + ASSERT_LE( getTableScanRuntimeStats(task).at("storageReadBytes").sum, rawInputBytes + overreadBytes); ASSERT_GT(getTableScanRuntimeStats(task)["totalScanTime"].sum, 0); @@ -181,9 +199,9 @@ DEBUG_ONLY_TEST_F(TableScanTest, pendingCoalescedIoWhenTaskFailed) { // on-demand load. const std::string errMsg{"injectedError"}; SCOPED_TESTVALUE_SET( - "facebook::velox::connector::hive::HiveDataSource::next", - std::function( - [&](connector::hive::HiveDataSource* /*unused*/) { + "facebook::velox::connector::hive::FileDataSource::next", + std::function( + [&](connector::hive::FileDataSource* /*unused*/) { VELOX_FAIL(errMsg); })); SCOPED_TESTVALUE_SET( @@ -207,7 +225,7 @@ DEBUG_ONLY_TEST_F(TableScanTest, pendingCoalescedIoWhenTaskFailed) { TEST_F(TableScanTest, connectorStats) { auto hiveConnector = std::dynamic_pointer_cast( - connector::getConnector(kHiveConnectorId)); + connector::ConnectorRegistry::tryGet(kHiveConnectorId)); EXPECT_NE(nullptr, hiveConnector); verifyCacheStats(hiveConnector->fileHandleCacheStats(), 0, 0, 0); @@ -426,6 +444,7 @@ TEST_F(TableScanTest, timestampPrecisionDefaultMillisecond) { makeFlatVector( 1, [](auto) { return Timestamp(1, 1'000'000); }), }); + split = makeHiveConnectorSplit(file->getPath()); AssertQueryBuilder(plan).split(split).assertResults(expected); } @@ -503,6 +522,30 @@ DEBUG_ONLY_TEST_F(TableScanTest, timeLimitInGetOutput) { EXPECT_GE(numBailed, 12); } +TEST_F(TableScanTest, outputBatchRowsOverride) { + auto vectors = makeVectors(1, 500); + auto filePath = TempFilePath::create(); + writeToFile(filePath->getPath(), vectors); + + constexpr uint32_t kBatchRowsOverride{100}; + auto plan = tableScanNode(); + auto batches = AssertQueryBuilder(duckDbQueryRunner_) + .plan(plan) + .splits(makeHiveConnectorSplits({filePath})) + .config( + QueryConfig::kTableScanOutputBatchRowsOverride, + folly::to(kBatchRowsOverride)) + .copyResultBatches(pool_.get()); + + ASSERT_FALSE(batches.empty()); + for (auto i = 0; i + 1 < batches.size(); ++i) { + EXPECT_EQ(batches[i]->size(), kBatchRowsOverride); + } + EXPECT_LE(batches.back()->size(), kBatchRowsOverride); + + assertEqualResults(vectors, batches); +} + TEST_F(TableScanTest, subfieldPruningRowType) { // rowType: ROW // └── "e": ROW @@ -521,7 +564,7 @@ TEST_F(TableScanTest, subfieldPruningRowType) { connector::ColumnHandleMap assignments; assignments["e"] = std::make_shared( "e", - HiveColumnHandle::ColumnType::kRegular, + FileColumnHandle::ColumnType::kRegular, columnType, columnType, std::move(requiredSubfields)); @@ -577,7 +620,7 @@ TEST_F(TableScanTest, subfieldPruningRemainingFilterSubfieldsMissing) { connector::ColumnHandleMap assignments; assignments["e"] = std::make_shared( "e", - HiveColumnHandle::ColumnType::kRegular, + FileColumnHandle::ColumnType::kRegular, columnType, columnType, std::move(requiredSubfields)); @@ -610,6 +653,7 @@ TEST_F(TableScanTest, subfieldPruningRemainingFilterSubfieldsMissing) { .assignments(assignments) .endTableScan() .planNode(); + split = makeHiveConnectorSplit(filePath->getPath()); result = AssertQueryBuilder(op).split(split).copyResults(pool()); rows = result->as(); e = rows->childAt(0)->as(); @@ -631,7 +675,7 @@ TEST_F(TableScanTest, subfieldPruningRemainingFilterRootFieldMissing) { writeToFile(filePath->getPath(), vectors); connector::ColumnHandleMap assignments; assignments["d"] = std::make_shared( - "d", HiveColumnHandle::ColumnType::kRegular, BIGINT(), BIGINT()); + "d", FileColumnHandle::ColumnType::kRegular, BIGINT(), BIGINT()); auto op = PlanBuilder() .startTableScan() .outputType(ROW({{"d", BIGINT()}})) @@ -673,7 +717,7 @@ TEST_F(TableScanTest, subfieldPruningRemainingFilterStruct) { SCOPED_TRACE(fmt::format("{} {}", outputColumn, filterColumn)); connector::ColumnHandleMap assignments; assignments["d"] = std::make_shared( - "d", HiveColumnHandle::ColumnType::kRegular, BIGINT(), BIGINT()); + "d", FileColumnHandle::ColumnType::kRegular, BIGINT(), BIGINT()); if (outputColumn > kNoOutput) { std::vector subfields; if (outputColumn == kSubfieldOnly) { @@ -681,7 +725,7 @@ TEST_F(TableScanTest, subfieldPruningRemainingFilterStruct) { } assignments["c"] = std::make_shared( "c", - HiveColumnHandle::ColumnType::kRegular, + FileColumnHandle::ColumnType::kRegular, structType, structType, std::move(subfields)); @@ -758,7 +802,7 @@ TEST_F(TableScanTest, subfieldPruningRemainingFilterMap) { SCOPED_TRACE(fmt::format("{} {}", outputColumn, filterColumn)); connector::ColumnHandleMap assignments; assignments["a"] = std::make_shared( - "a", HiveColumnHandle::ColumnType::kRegular, BIGINT(), BIGINT()); + "a", FileColumnHandle::ColumnType::kRegular, BIGINT(), BIGINT()); if (outputColumn > kNoOutput) { std::vector subfields; if (outputColumn == kSubfieldOnly) { @@ -766,15 +810,14 @@ TEST_F(TableScanTest, subfieldPruningRemainingFilterMap) { } assignments["b"] = std::make_shared( "b", - HiveColumnHandle::ColumnType::kRegular, + FileColumnHandle::ColumnType::kRegular, mapType, mapType, std::move(subfields)); } std::string remainingFilter; if (filterColumn == kWholeColumn) { - remainingFilter = - "coalesce(b, cast(null AS MAP(BIGINT, BIGINT)))[0] == 0"; + remainingFilter = "coalesce(b, map_concat(b, b))[0] == 0"; } else { remainingFilter = "b[0] == 0"; } @@ -858,7 +901,7 @@ TEST_F(TableScanTest, subfieldPruningMapType) { connector::ColumnHandleMap assignments; assignments["c"] = std::make_shared( "c", - HiveColumnHandle::ColumnType::kRegular, + FileColumnHandle::ColumnType::kRegular, mapType, mapType, std::move(requiredSubfields)); @@ -941,7 +984,7 @@ TEST_F(TableScanTest, subfieldPruningArrayType) { connector::ColumnHandleMap assignments; assignments["c"] = std::make_shared( "c", - HiveColumnHandle::ColumnType::kRegular, + FileColumnHandle::ColumnType::kRegular, arrayType, arrayType, std::move(requiredSubfields)); @@ -1099,7 +1142,7 @@ TEST_F(TableScanTest, missingColumns) { common::SubfieldFilters filters; filters[common::Subfield("c1")] = lessThanOrEqualDouble(1050.0, true); auto tableHandle = std::make_shared( - kHiveConnectorId, "tmp", true, std::move(filters), nullptr, dataColumns); + kHiveConnectorId, "tmp", std::move(filters), nullptr, dataColumns); connector::ColumnHandleMap assignments; assignments["c0"] = regularColumn("c0", BIGINT()); op = PlanBuilder(pool_.get()) @@ -1352,6 +1395,80 @@ TEST_F(TableScanTest, batchSize) { } } +DEBUG_ONLY_TEST_F(TableScanTest, batchSizeFileEstimateFallback) { + const auto rowSize = 1024; + const auto columnSize = sizeof(int64_t); + const auto numColumns = 2 * rowSize / columnSize; + const auto numRowsSplit1 = 100; + const auto numRowsSplit2 = 2000; + const auto kDefaultBatchRows = 1024; + + std::vector names; + names.reserve(numColumns); + for (size_t i = 0; i < numColumns; i++) { + names.push_back(fmt::format("c{}", i)); + } + auto rowType = + ROW(std::move(names), std::vector(numColumns, BIGINT())); + + auto vector1 = makeVectors(1, numRowsSplit1, rowType); + auto vector2 = makeVectors(1, numRowsSplit2, rowType); + + auto filePath1 = TempFilePath::create(); + auto filePath2 = TempFilePath::create(); + writeToFile(filePath1->getPath(), vector1); + writeToFile(filePath2->getPath(), vector2); + + std::vector allVectors; + allVectors.reserve(2); + allVectors.push_back(vector1[0]); + allVectors.push_back(vector2[0]); + createDuckDbTable(allVectors); + + auto plan = PlanBuilder().tableScan(rowType).planNode(); + + std::atomic_int splitCount{0}; + std::vector batchSizesUsed; + std::mutex mutex; + + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::TableScan::getOutput::gotSplit", + std::function([&](const TableScan* tableScan) { + ++splitCount; + std::lock_guard lock(mutex); + batchSizesUsed.push_back(tableScan->testingReadBatchSize()); + })); + + SCOPED_TESTVALUE_SET( + "facebook::velox::connector::hive::FileDataSource::estimatedRowSize", + std::function([&](int64_t* estimatedRowSize) { + if (splitCount.load() >= 2) { + *estimatedRowSize = connector::DataSource::kUnknownRowSize; + } + })); + + auto task = AssertQueryBuilder(duckDbQueryRunner_) + .plan(plan) + .splits(makeHiveConnectorSplits({filePath1, filePath2})) + .config( + QueryConfig::kPreferredOutputBatchBytes, + folly::to(rowSize * 100)) + .config( + QueryConfig::kPreferredOutputBatchRows, + folly::to(kDefaultBatchRows)) + .assertResults("SELECT * FROM tmp"); + + const auto opStats = task->taskStats().pipelineStats[0].operatorStats[0]; + + EXPECT_EQ(opStats.outputPositions, numRowsSplit1 + numRowsSplit2); + EXPECT_GE(splitCount.load(), 2); + ASSERT_GE(batchSizesUsed.size(), 2); + + EXPECT_LT(batchSizesUsed[1], kDefaultBatchRows) + << "Second split should use the last known estimated row size, not default"; + EXPECT_GT(opStats.outputVectors, 3); +} + // Test that adding the same split with the same sequence id does not cause // double read and the 2nd split is ignored. TEST_F(TableScanTest, sequentialSplitNoDoubleRead) { @@ -1463,6 +1580,20 @@ TEST_F(TableScanTest, multipleSplits) { } } +TEST_F(TableScanTest, preloadSplits) { + auto filePaths = makeFilePaths(10); + auto vectors = makeVectors(10, 10); + for (int32_t i = 0; i < vectors.size(); i++) { + writeToFile(filePaths[i]->getPath(), vectors[i]); + } + createDuckDbTable(vectors); + + auto task = assertQuery( + tableScanNode(), filePaths, "SELECT * FROM tmp", /*numPrefetchSplit=*/10); + auto stats = getTableScanRuntimeStats(task); + ASSERT_EQ(stats.at("preloadedSplits").sum, 10); +} + TEST_F(TableScanTest, preloadingSplitClose) { auto filePaths = makeFilePaths(100); auto vectors = makeVectors(100, 100); @@ -1590,8 +1721,7 @@ DEBUG_ONLY_TEST_F(TableScanTest, tableScanSplitsAndWeights) { auto leafTaskId = "local://leaf-0"; auto leafPlan = PlanBuilder() .values(vectors) - .partitionedOutput( - {}, 1, {"c0", "c1", "c2"}, VectorSerde::Kind::kPresto) + .partitionedOutput({}, 1, {"c0", "c1", "c2"}, "Presto") .planNode(); std::unordered_map config; auto queryCtx = core::QueryCtx::create( @@ -1610,23 +1740,22 @@ DEBUG_ONLY_TEST_F(TableScanTest, tableScanSplitsAndWeights) { // Main task plan with table scan and remote exchange. auto planNodeIdGenerator = std::make_shared(); core::PlanNodeId scanNodeId, exchangeNodeId; - auto planNode = - PlanBuilder(planNodeIdGenerator, pool_.get()) - .tableScan(rowType_) - .capturePlanNodeId(scanNodeId) - .project({"c0 AS t0", "c1 AS t1", "c2 AS t2"}) - .hashJoin( - {"t0"}, - {"u0"}, - PlanBuilder(planNodeIdGenerator, pool_.get()) - .exchange(leafPlan->outputType(), VectorSerde::Kind::kPresto) - .capturePlanNodeId(exchangeNodeId) - .project({"c0 AS u0", "c1 AS u1", "c2 AS u2"}) - .planNode(), - "", - {"t1"}, - core::JoinType::kAnti) - .planNode(); + auto planNode = PlanBuilder(planNodeIdGenerator, pool_.get()) + .tableScan(rowType_) + .capturePlanNodeId(scanNodeId) + .project({"c0 AS t0", "c1 AS t1", "c2 AS t2"}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator, pool_.get()) + .exchange(leafPlan->outputType(), "Presto") + .capturePlanNodeId(exchangeNodeId) + .project({"c0 AS u0", "c1 AS u1", "c2 AS u2"}) + .planNode(), + "", + {"t1"}, + core::JoinType::kAnti) + .planNode(); // Create task, cursor, start the task and supply the table scan splits. const int32_t numDrivers = 6; @@ -1719,9 +1848,9 @@ TEST_F(TableScanTest, splitOffsetAndLength) { } TEST_F(TableScanTest, fileNotFound) { - auto split = - exec::test::HiveConnectorSplitBuilder("/path/to/nowhere.orc").build(); auto assertMissingFile = [&](bool ignoreMissingFiles) { + auto split = + exec::test::HiveConnectorSplitBuilder("/path/to/nowhere.orc").build(); AssertQueryBuilder(tableScanNode()) .connectorSessionProperty( kHiveConnectorId, @@ -1896,7 +2025,7 @@ TEST_F(TableScanTest, partitionedTableDateKey) { 18506, std::numeric_limits::max(), false); auto tableHandle = std::make_shared( - "test-hive", "hive_table", true, std::move(filters), nullptr, nullptr); + "test-hive", "hive_table", std::move(filters), nullptr, nullptr); auto op = std::make_shared( "0", std::move(outputType), @@ -1921,10 +2050,6 @@ TEST_F(TableScanTest, partitionedTableTimestampKey) { // Test partition value is null. testPartitionedTable(filePath->getPath(), partitionType, std::nullopt); - auto split = exec::test::HiveConnectorSplitBuilder(filePath->getPath()) - .partitionKey("pkey", partitionValue) - .build(); - connector::ColumnHandleMap assignments = { {"pkey", partitionKey("pkey", TIMESTAMP())}, {"c0", regularColumn("c0", BIGINT())}, @@ -1956,6 +2081,9 @@ TEST_F(TableScanTest, partitionedTableTimestampKey) { .planNode(); auto expect = [&](bool asLocalTime) { + auto split = exec::test::HiveConnectorSplitBuilder(filePath->getPath()) + .partitionKey("pkey", partitionValue) + .build(); AssertQueryBuilder(plan, duckDbQueryRunner_) .connectorSessionProperty( kHiveConnectorId, @@ -1963,8 +2091,10 @@ TEST_F(TableScanTest, partitionedTableTimestampKey) { kReadTimestampPartitionValueAsLocalTimeSession, asLocalTime ? "true" : "false") .splits({split}) - .assertResults(fmt::format( - "SELECT {}, * FROM tmp", asLocalTime ? tsValueAsLocal : tsValue)); + .assertResults( + fmt::format( + "SELECT {}, * FROM tmp", + asLocalTime ? tsValueAsLocal : tsValue)); }; expect(true); @@ -1983,6 +2113,9 @@ TEST_F(TableScanTest, partitionedTableTimestampKey) { .planNode(); auto expect = [&](bool asLocalTime) { + auto split = exec::test::HiveConnectorSplitBuilder(filePath->getPath()) + .partitionKey("pkey", partitionValue) + .build(); AssertQueryBuilder(plan, duckDbQueryRunner_) .connectorSessionProperty( kHiveConnectorId, @@ -1990,9 +2123,10 @@ TEST_F(TableScanTest, partitionedTableTimestampKey) { kReadTimestampPartitionValueAsLocalTimeSession, asLocalTime ? "true" : "false") .splits({split}) - .assertResults(fmt::format( - "SELECT c0, {}, c1 FROM tmp", - asLocalTime ? tsValueAsLocal : tsValue)); + .assertResults( + fmt::format( + "SELECT c0, {}, c1 FROM tmp", + asLocalTime ? tsValueAsLocal : tsValue)); }; expect(true); expect(false); @@ -2010,6 +2144,9 @@ TEST_F(TableScanTest, partitionedTableTimestampKey) { .planNode(); auto expect = [&](bool asLocalTime) { + auto split = exec::test::HiveConnectorSplitBuilder(filePath->getPath()) + .partitionKey("pkey", partitionValue) + .build(); AssertQueryBuilder(plan, duckDbQueryRunner_) .connectorSessionProperty( kHiveConnectorId, @@ -2017,9 +2154,10 @@ TEST_F(TableScanTest, partitionedTableTimestampKey) { kReadTimestampPartitionValueAsLocalTimeSession, asLocalTime ? "true" : "false") .splits({split}) - .assertResults(fmt::format( - "SELECT c0, c1, {} FROM tmp", - asLocalTime ? tsValueAsLocal : tsValue)); + .assertResults( + fmt::format( + "SELECT c0, c1, {} FROM tmp", + asLocalTime ? tsValueAsLocal : tsValue)); }; expect(true); expect(false); @@ -2037,6 +2175,9 @@ TEST_F(TableScanTest, partitionedTableTimestampKey) { .planNode(); auto expect = [&](bool asLocalTime) { + auto split = exec::test::HiveConnectorSplitBuilder(filePath->getPath()) + .partitionKey("pkey", partitionValue) + .build(); AssertQueryBuilder(plan, duckDbQueryRunner_) .connectorSessionProperty( kHiveConnectorId, @@ -2044,8 +2185,10 @@ TEST_F(TableScanTest, partitionedTableTimestampKey) { kReadTimestampPartitionValueAsLocalTimeSession, asLocalTime ? "true" : "false") .splits({split}) - .assertResults(fmt::format( - "SELECT {} FROM tmp", asLocalTime ? tsValueAsLocal : tsValue)); + .assertResults( + fmt::format( + "SELECT {} FROM tmp", + asLocalTime ? tsValueAsLocal : tsValue)); }; expect(true); expect(false); @@ -2068,12 +2211,7 @@ TEST_F(TableScanTest, partitionedTableTimestampKey) { filters[common::Subfield("pkey")] = std::make_unique(lower, lower, false); auto tableHandle = std::make_shared( - "test-hive", - "hive_table", - true, - std::move(filters), - nullptr, - nullptr); + "test-hive", "hive_table", std::move(filters), nullptr, nullptr); return PlanBuilder() .startTableScan() @@ -2085,6 +2223,9 @@ TEST_F(TableScanTest, partitionedTableTimestampKey) { }; auto expect = [&](bool asLocalTime) { + auto split = exec::test::HiveConnectorSplitBuilder(filePath->getPath()) + .partitionKey("pkey", partitionValue) + .build(); AssertQueryBuilder(planWithSubfilter(asLocalTime), duckDbQueryRunner_) .connectorSessionProperty( kHiveConnectorId, @@ -2092,8 +2233,10 @@ TEST_F(TableScanTest, partitionedTableTimestampKey) { kReadTimestampPartitionValueAsLocalTimeSession, asLocalTime ? "true" : "false") .splits({split}) - .assertResults(fmt::format( - "SELECT {}, * FROM tmp", asLocalTime ? tsValueAsLocal : tsValue)); + .assertResults( + fmt::format( + "SELECT {}, * FROM tmp", + asLocalTime ? tsValueAsLocal : tsValue)); }; expect(true); expect(false); @@ -2501,7 +2644,7 @@ TEST_F(TableScanTest, statsBasedSkippingWithoutDecompression) { auto assertQuery = [&](const std::string& filter) { auto rowType = asRowType(rowVector->type()); return TableScanTest::assertQuery( - PlanBuilder(pool_.get()).tableScan(rowType, {filter}).planNode(), + PlanBuilder(pool_.get()).tableScan(rowType, {}, {filter}).planNode(), filePaths, "SELECT * FROM tmp WHERE " + filter); }; @@ -2952,6 +3095,81 @@ TEST_F(TableScanTest, fileSizeAndModifiedTime) { filterTest(fmt::format("\"{}\" = {}", kModifiedTime, fileTimeValue)); } +// Test that synthesized column filter validation throws an error when the +// filter doesn't match the split's value. +TEST_F(TableScanTest, synthesizedColumnFilterValidation) { + auto rowType = ROW({"a"}, {BIGINT()}); + auto filePath = makeFilePaths(1)[0]; + auto vector = makeVectors(1, 10, rowType)[0]; + writeToFile(filePath->getPath(), vector); + + static const char* kPath = "$path"; + auto assignments = allRegularColumns(rowType); + assignments[kPath] = synthesizedColumn(kPath, VARCHAR()); + + auto typeWithPath = ROW({kPath, "a"}, {VARCHAR(), BIGINT()}); + + // Create a filter that doesn't match the actual $path value. + auto tableHandle = makeTableHandle( + common::SubfieldFilters{}, + parseExpr( + fmt::format("\"{}\" = '/nonexistent/path'", kPath), typeWithPath)); + + auto op = PlanBuilder() + .startTableScan() + .outputType(typeWithPath) + .tableHandle(tableHandle) + .assignments(assignments) + .endTableScan() + .planNode(); + + auto split = + exec::test::HiveConnectorSplitBuilder(filePath->getPath()).build(); + + // The query should throw an exception because the filter doesn't match. + VELOX_ASSERT_THROW( + AssertQueryBuilder(op).splits({split}).copyResults(pool()), + "Synthesized column '$path' failed filter validation"); +} + +// Test that synthesized column filter validation throws an error for +// $file_modified_time when the filter doesn't match the split's value. +TEST_F(TableScanTest, synthesizedColumnFilterValidationModifiedTime) { + auto rowType = ROW({"a"}, {BIGINT()}); + auto filePath = makeFilePaths(1)[0]; + auto vector = makeVectors(1, 10, rowType)[0]; + writeToFile(filePath->getPath(), vector); + + static const char* kModifiedTime = "$file_modified_time"; + auto assignments = allRegularColumns(rowType); + assignments[kModifiedTime] = synthesizedColumn(kModifiedTime, BIGINT()); + + auto typeWithModifiedTime = ROW({kModifiedTime, "a"}, {BIGINT(), BIGINT()}); + + // Create a filter that doesn't match the actual $file_modified_time value. + // Use 0 which won't match any real file's modification time. + auto tableHandle = makeTableHandle( + common::SubfieldFilters{}, + parseExpr( + fmt::format("\"{}\" = 0", kModifiedTime), typeWithModifiedTime)); + + auto op = PlanBuilder() + .startTableScan() + .outputType(typeWithModifiedTime) + .tableHandle(tableHandle) + .assignments(assignments) + .endTableScan() + .planNode(); + + // Use makeHiveConnectorSplits to ensure infoColumns are set properly. + auto splits = makeHiveConnectorSplits({filePath}); + + // The query should throw an exception because the filter doesn't match. + VELOX_ASSERT_THROW( + AssertQueryBuilder(op).splits(splits).copyResults(pool()), + "Synthesized column '$file_modified_time' failed filter validation"); +} + TEST_F(TableScanTest, bucket) { vector_size_t size = 1'000; int numBatches = 5; @@ -3102,7 +3320,7 @@ TEST_F(TableScanTest, bucketConversion) { {"c2", std::make_shared( "c2", - HiveColumnHandle::ColumnType::kRowIndex, + FileColumnHandle::ColumnType::kRowIndex, BIGINT(), BIGINT())}, {"c1", makeColumnHandle("c1", BIGINT(), {})}, @@ -4310,6 +4528,53 @@ TEST_F(TableScanTest, parallelPrepare) { .copyResults(pool_.get()); } +TEST_F(TableScanTest, parallelPrepareWithSubfieldFilters) { + // Test metadataFilter is correctly transferred during split prefetch. + constexpr int32_t kNumParallel = 100; + auto data = makeRowVector({ + makeFlatVector(100, [](auto row) { return row; }), + makeFlatVector(100, [](auto row) { return row * 2; }), + }); + + auto filePath = TempFilePath::create(); + writeToFile(filePath->getPath(), {data}); + + auto subfieldFilters = SubfieldFiltersBuilder() + .add("c0", greaterThanOrEqual(10)) + .add("c1", lessThan(150)) + .build(); + + auto outputType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + auto remainingFilter = parseExpr("c0 % 2 = 0", outputType); + auto tableHandle = + makeTableHandle(std::move(subfieldFilters), remainingFilter); + auto assignments = allRegularColumns(outputType); + + auto plan = exec::test::PlanBuilder(pool_.get()) + .startTableScan() + .outputType(outputType) + .tableHandle(tableHandle) + .assignments(assignments) + .endTableScan() + .planNode(); + + std::vector splits; + for (auto i = 0; i < kNumParallel; ++i) { + splits.push_back(makeHiveSplit(filePath->getPath())); + } + + auto result = AssertQueryBuilder(plan) + .config( + core::QueryConfig::kMaxSplitPreloadPerDriver, + std::to_string(kNumParallel)) + .splits(splits) + .copyResults(pool_.get()); + + // Verify results: c0 >= 10 AND c1 < 150 AND c0 % 2 = 0 + // So rows [10,12,14,...,74] = 33 rows per split + ASSERT_EQ(result->size(), 33 * kNumParallel); +} + TEST_F(TableScanTest, dictionaryMemo) { constexpr int kSize = 100; const char* baseStrings[] = { @@ -4364,10 +4629,15 @@ TEST_F(TableScanTest, reuseRowVector) { .tableScan(rowType, {}, "c0 < 5") .project({"c1.c0"}) .planNode(); - auto split = exec::test::HiveConnectorSplitBuilder(file->getPath()).build(); + auto firstSplit = + exec::test::HiveConnectorSplitBuilder(file->getPath()).build(); + auto secondSplit = + exec::test::HiveConnectorSplitBuilder(file->getPath()).build(); auto expected = makeRowVector( {makeFlatVector(10, [](auto i) { return i % 5; })}); - AssertQueryBuilder(plan).splits({split, split}).assertResults(expected); + AssertQueryBuilder(plan) + .splits({firstSplit, secondSplit}) + .assertResults(expected); } // Tests queries that read more row fields than exist in the data. @@ -4623,6 +4893,7 @@ TEST_F(TableScanTest, readMissingFieldsInMap) { // Now run query with column mapping using names - we should not be able to // find any names. + split = makeHiveConnectorSplit(filePath->getPath()); result = AssertQueryBuilder(op) .connectorSessionProperty( kHiveConnectorId, @@ -4679,6 +4950,7 @@ TEST_F(TableScanTest, readMissingFieldsInMap) { .endTableScan() .project({"i1"}) .planNode(); + split = makeHiveConnectorSplit(filePath->getPath()); EXPECT_THROW( AssertQueryBuilder(op).split(split).copyResults(pool()), VeloxUserError); @@ -4883,6 +5155,7 @@ TEST_F(TableScanTest, readMissingFieldsWithMoreColumns) { // Now run query with column mapping using names - we should not be able to // find any names, except for the last string column. + split = makeHiveConnectorSplit(filePath->getPath()); result = AssertQueryBuilder(op) .connectorSessionProperty( kHiveConnectorId, @@ -5086,6 +5359,42 @@ TEST_F(TableScanTest, readFlatMapAsStruct) { AssertQueryBuilder(plan).split(split).assertResults(expected); } +// Test reading flatmap as struct when none of the requested keys exist in the +// file. All projected struct fields should be null. +TEST_F(TableScanTest, readFlatMapAsStructNoMatchingKeys) { + constexpr int kSize = 10; + std::vector keys = {"1", "2", "3"}; + auto c0 = makeRowVector( + keys, + { + makeFlatVector(kSize, folly::identity), + makeFlatVector(kSize, folly::identity), + makeFlatVector(kSize, folly::identity), + }); + auto vector = makeRowVector({c0}); + auto config = std::make_shared(); + config->set(dwrf::Config::FLATTEN_MAP, true); + config->set>(dwrf::Config::MAP_FLAT_COLS, {0}); + config->set>>( + dwrf::Config::MAP_FLAT_COLS_STRUCT_KEYS, {keys}); + auto file = TempFilePath::create(); + auto writeSchema = ROW({"c0"}, {MAP(INTEGER(), BIGINT())}); + writeToFile(file->getPath(), {vector}, config, writeSchema); + + // Request keys "4" and "5" which don't exist in the file. + auto readSchema = ROW({"c0"}, {ROW({"4", "5"}, {BIGINT(), BIGINT()})}); + auto plan = + PlanBuilder().tableScan(readSchema, {}, "", writeSchema).planNode(); + auto split = makeHiveConnectorSplit(file->getPath()); + auto expected = makeRowVector({makeRowVector( + {"4", "5"}, + { + makeNullConstant(TypeKind::BIGINT, kSize), + makeNullConstant(TypeKind::BIGINT, kSize), + })}); + AssertQueryBuilder(plan).split(split).assertResults(expected); +} + TEST_F(TableScanTest, flatMapReadOffset) { auto vector = makeRowVector( {makeNullableMapVector({std::nullopt, {{{1, 2}}}})}); @@ -5112,13 +5421,13 @@ TEST_F(TableScanTest, flatMapKeyTypeEvolution) { config->set>(dwrf::Config::MAP_FLAT_COLS, {0}); auto file = TempFilePath::create(); writeToFile(file->getPath(), {vector}, config); - auto split = makeHiveConnectorSplit(file->getPath()); auto schema = ROW({"c0"}, {MAP(BIGINT(), BIGINT())}); { SCOPED_TRACE("Read as map"); auto plan = PlanBuilder().tableScan(schema).planNode(); auto expected = makeRowVector({makeMapVector({{{1, 2}, {3, 4}}})}); + auto split = makeHiveConnectorSplit(file->getPath()); AssertQueryBuilder(plan).split(split).assertResults(expected); } { @@ -5128,6 +5437,7 @@ TEST_F(TableScanTest, flatMapKeyTypeEvolution) { auto expected = makeRowVector({makeRowVector( {"1", "3"}, {makeConstant(2, 1), makeConstant(4, 1)})}); + auto split = makeHiveConnectorSplit(file->getPath()); AssertQueryBuilder(plan).split(split).assertResults(expected); } } @@ -5146,7 +5456,6 @@ TEST_F(TableScanTest, flatMapLazyRowValue) { config->set>(dwrf::Config::MAP_FLAT_COLS, {0}); auto file = TempFilePath::create(); writeToFile(file->getPath(), {vector}, config); - auto split = makeHiveConnectorSplit(file->getPath()); { SCOPED_TRACE("Read as map"); auto plan = PlanBuilder() @@ -5157,6 +5466,7 @@ TEST_F(TableScanTest, flatMapLazyRowValue) { vector->rowType()) .planNode(); auto expected = makeRowVector({wrapInDictionary(makeIndices({1}), c0)}); + auto split = makeHiveConnectorSplit(file->getPath()); AssertQueryBuilder(plan).split(split).assertResults(expected); } { @@ -5182,6 +5492,7 @@ TEST_F(TableScanTest, flatMapLazyRowValue) { makeConstant(10, 1), }), })}); + auto split = makeHiveConnectorSplit(file->getPath()); AssertQueryBuilder(plan).split(split).assertResults(expected); } } @@ -5255,13 +5566,13 @@ TEST_F(TableScanTest, dynamicFilterWithRowIndexColumn) { connector::ColumnHandleMap assignments; assignments["a"] = std::make_shared( "a", - connector::hive::HiveColumnHandle::ColumnType::kRegular, + connector::hive::FileColumnHandle::ColumnType::kRegular, BIGINT(), BIGINT()); assignments["row_index"] = std::make_shared( "row_index", - connector::hive::HiveColumnHandle::ColumnType::kRowIndex, + connector::hive::FileColumnHandle::ColumnType::kRowIndex, BIGINT(), BIGINT()); std::shared_ptr files[2]; @@ -5296,6 +5607,63 @@ TEST_F(TableScanTest, dynamicFilterWithRowIndexColumn) { .assertResults(resVector); } +TEST_F(TableScanTest, bloomFilterPushdown) { + auto build = makeRowVector( + {"b"}, + { + makeFlatVector( + 10'001 + VectorHasher::kMaxDistinct, + [](auto i) { return 1000 * i; }), + }); + auto probe = makeRowVector( + {"a"}, + { + makeFlatVector( + 2 * build->size(), [](auto i) { return 500 * i; }), + }); + std::shared_ptr files[2]; + files[0] = TempFilePath::create(); + writeToFile(files[0]->getPath(), {probe}); + files[1] = TempFilePath::create(); + writeToFile(files[1]->getPath(), {build}); + auto idGenerator = std::make_shared(); + core::PlanNodeId probeScanId, buildScanId, joinId; + auto plan = PlanBuilder(idGenerator) + .tableScan(ROW({"a"}, {BIGINT()})) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"a"}, + {"b"}, + PlanBuilder(idGenerator) + .tableScan(ROW({"b"}, {BIGINT()})) + .capturePlanNodeId(buildScanId) + .planNode(), + /*filter=*/"", + {"a"}) + .capturePlanNodeId(joinId) + .planNode(); + for (bool parallelBuild : {false, true}) { + SCOPED_TRACE(fmt::format("parallelBuild={}", parallelBuild)); + AssertQueryBuilder builder(plan); + builder + .config( + core::QueryConfig::kHashProbeBloomFilterPushdownMaxSize, + std::to_string(4 * build->size())) + .split(probeScanId, makeHiveConnectorSplit(files[0]->getPath())) + .split(buildScanId, makeHiveConnectorSplit(files[1]->getPath())); + if (parallelBuild) { + builder.serialExecution(false).maxDrivers(2).config( + core::QueryConfig::kMinTableRowsForParallelJoinBuild, "1"); + } + auto task = builder.assertResults(build); + auto planStats = toPlanStats(task->taskStats()); + ASSERT_EQ( + planStats.at(joinId).customStats.at("dynamicFiltersProduced").sum, + parallelBuild ? 2 : 1); + ASSERT_GT(planStats.at(joinId).customStats.at("bloomFilterSize").sum, 0); + } +} + // TODO: re-enable this test once we add back driver suspension support for // table scan. TEST_F(TableScanTest, DISABLED_memoryArbitrationWithSlowTableScan) { @@ -5356,7 +5724,7 @@ TEST_F(TableScanTest, DISABLED_memoryArbitrationWithSlowTableScan) { .planNode(); std::thread queryThread([&]() { - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); auto task = assertQuery( op, filePaths, @@ -5482,9 +5850,9 @@ DEBUG_ONLY_TEST_F(TableScanTest, cancellationToken) { std::atomic_bool cancelled{false}; SCOPED_TESTVALUE_SET( - "facebook::velox::connector::hive::HiveDataSource::next", - std::function( - [&](connector::hive::HiveDataSource* source) { + "facebook::velox::connector::hive::FileDataSource::next", + std::function( + [&](connector::hive::FileDataSource* source) { auto cancellationToken = source->testingConnectorQueryCtx()->cancellationToken(); while (true) { @@ -5546,7 +5914,7 @@ TEST_F(TableScanTest, rowNumberInRemainingFilter) { {"r1", std::make_shared( "r1", - HiveColumnHandle::ColumnType::kRowIndex, + FileColumnHandle::ColumnType::kRowIndex, BIGINT(), BIGINT())}, }) @@ -5596,7 +5964,7 @@ TEST_F(TableScanTest, rowId) { writeToFile(file->getPath(), {vector}); auto makeRowIdColumnHandle = [&](auto& name) { return std::make_shared( - name, HiveColumnHandle::ColumnType::kRowId, rowIdType, rowIdType); + name, FileColumnHandle::ColumnType::kRowId, rowIdType, rowIdType); }; { SCOPED_TRACE("Preload"); @@ -5703,11 +6071,11 @@ TEST_F(TableScanTest, footerIOCount) { .assertResults( BaseVector::create(vector->type(), 0, pool())); auto stats = getTableScanRuntimeStats(task); - ASSERT_EQ(stats.at("numStorageRead").sum, 1); + ASSERT_EQ(stats.at("storageReadBytes").count, 1); ASSERT_GT(stats.at("footerBufferOverread").sum, 0); } -TEST_F(TableScanTest, statsBasedFilterReorderDisabled) { +TEST_F(TableScanTest, statsBasedFilterReorderBothEnabledAndDisabled) { gflags::FlagSaver gflagSaver; // Disable prefetch to avoid test flakiness. FLAGS_cache_prefetch_min_pct = 200; @@ -5736,8 +6104,8 @@ TEST_F(TableScanTest, statsBasedFilterReorderDisabled) { } createDuckDbTable(vectors); - for (auto disableReoder : {false}) { - SCOPED_TRACE(fmt::format("disableReoder {}", disableReoder)); + for (auto disableReorder : {false, true}) { + SCOPED_TRACE(fmt::format("disableReorder {}", disableReorder)); auto* cache = cache::AsyncDataCache::getInstance(); cache->clear(); @@ -5770,7 +6138,7 @@ TEST_F(TableScanTest, statsBasedFilterReorderDisabled) { kHiveConnectorId, connector::hive::HiveConfig:: kReadStatsBasedFilterReorderDisabledSession, - disableReoder ? "true" : "false") + disableReorder ? "true" : "false") // Disable coalesce so that each column stream has a separate read // per split at least. .connectorSessionProperty( @@ -5794,7 +6162,7 @@ TEST_F(TableScanTest, statsBasedFilterReorderDisabled) { auto tableScanStats = getTableScanStats(task); ASSERT_EQ(tableScanStats.customStats.count("storageReadBytes"), 1); ASSERT_GT(tableScanStats.customStats["storageReadBytes"].sum, 0); - ASSERT_EQ(tableScanStats.customStats["storageReadBytes"].count, 1); + ASSERT_GT(tableScanStats.customStats["storageReadBytes"].count, 0); ASSERT_EQ(tableScanStats.numSplits, numSplits); } @@ -5806,7 +6174,7 @@ TEST_F(TableScanTest, statsBasedFilterReorderDisabled) { kHiveConnectorId, connector::hive::HiveConfig:: kReadStatsBasedFilterReorderDisabledSession, - disableReoder ? "true" : "false") + disableReorder ? "true" : "false") .connectorSessionProperty( kHiveConnectorId, connector::hive::HiveConfig::kMaxCoalescedBytesSession, @@ -5824,15 +6192,17 @@ TEST_F(TableScanTest, statsBasedFilterReorderDisabled) { "SELECT c0 FROM tmp WHERE (c1 IN (1,7,11) OR c1 IS NULL) AND (c3 IN (1,7,11) OR c3 IS NULL)"); auto tableScanStats = getTableScanStats(task); - if (disableReoder) { + if (disableReorder) { ASSERT_EQ(tableScanStats.customStats.count("storageReadBytes"), 0); } else { + // Cache hit if (tableScanStats.customStats.count("storageReadBytes") == 0) { continue; } + // Cache miss, should behave like first time run ASSERT_EQ(tableScanStats.customStats.count("storageReadBytes"), 1); ASSERT_GT(tableScanStats.customStats["storageReadBytes"].sum, 0); - ASSERT_EQ(tableScanStats.customStats["storageReadBytes"].count, 1); + ASSERT_GT(tableScanStats.customStats["storageReadBytes"].count, 0); } ASSERT_EQ(tableScanStats.numSplits, numSplits); } @@ -5932,7 +6302,10 @@ TEST_F(TableScanTest, textfileEscape) { auto it = planStats.find(scanNodeId); ASSERT_TRUE(it != planStats.end()); auto rawInputBytes = it->second.rawInputBytes; - auto overreadBytes = getTableScanRuntimeStats(task).at("overreadBytes").sum; + auto runtimeStats = getTableScanRuntimeStats(task); + auto overreadIt = runtimeStats.find("overreadBytes"); + const int64_t overreadBytes = + overreadIt != runtimeStats.end() ? overreadIt->second.sum : 0; ASSERT_EQ(rawInputBytes, 11); ASSERT_EQ(overreadBytes, 0); @@ -6112,5 +6485,1273 @@ TEST_F(TableScanTest, duplicateFieldProject) { .assertResults("SELECT id, id FROM tmp WHERE name = 'John'"); } +TEST_F(TableScanTest, parallelUnitLoader) { + auto vectors = makeVectors(10, 1'000); + auto filePath = TempFilePath::create(); + writeToFile( + filePath->getPath(), + vectors, + std::make_shared(), + []() { return std::make_unique(1000, 0); }); + createDuckDbTable(vectors); + auto plan = tableScanNode(); + auto task = + AssertQueryBuilder(plan) + .splits(makeHiveConnectorSplits({filePath})) + .connectorSessionProperty( + kHiveConnectorId, + connector::hive::HiveConfig::kParallelUnitLoadCountSession, + std::to_string(3)) + .assertTypeAndNumRows(rowType_, 10'000); + auto stats = getTableScanRuntimeStats(task); + // Verify that parallel unit loader is enabled. + ASSERT_GT(stats.count("waitForUnitReadyNanos"), 0); +} + +TEST_F(TableScanTest, filterColumnHandles) { + auto data = makeVectors(1, 10, ROW({"a", "b"}, BIGINT())); + auto filePath = TempFilePath::create(); + writeToFile(filePath->getPath(), data); + auto split = exec::test::HiveConnectorSplitBuilder(filePath->getPath()) + .partitionKey("ds", "2025-10-23") + .build(); + auto plan = PlanBuilder() + .startTableScan() + .outputType(ROW({"x"}, {BIGINT()})) + .assignments({{"x", regularColumn("a", BIGINT())}}) + .dataColumns(asRowType(data[0]->type())) + .filterColumnHandles({ + partitionKey("ds", VARCHAR()), + regularColumn("a", BIGINT()), + }) + .remainingFilter("length(ds) + a % 2 > 0") + .endTableScan() + .planNode(); + AssertQueryBuilder(plan).split(split).assertResults( + makeRowVector({data[0]->childAt(0)})); +} + +TEST_F(TableScanTest, columnPostProcessorWithSubfieldFilters) { + auto data = makeFlatVector(10, folly::identity); + auto vector = makeRowVector({data, data, data}); + auto file = TempFilePath::create(); + writeToFile(file->getPath(), {vector}); + ChainedVectorLoader::PostVectorLoadProcessor postProc; + postProc = [&](auto& column) { + if (column->isLazy()) { + auto* lazy = column->template asUnchecked(); + if (lazy->isLoaded()) { + column = lazy->loadedVectorShared(); + postProc(column); + } else { + lazy->chain(postProc); + } + return; + } + if (column->encoding() == VectorEncoding::Simple::DICTIONARY) { + auto alphabet = column->valueVector(); + postProc(alphabet); + column->setValueVector(std::move(alphabet)); + return; + } + auto* values = + column->template asChecked>()->mutableRawValues(); + for (vector_size_t i = 0; i < column->size(); ++i) { + ++values[i]; + } + }; + auto c0Handle = std::make_shared( + "c0", + FileColumnHandle::ColumnType::kRegular, + BIGINT(), + BIGINT(), + std::vector{}, + HiveColumnHandle::ColumnParseParameters{}, + postProc); + auto c1Handle = regularColumn("c1", BIGINT()); + auto c2Handle = std::make_shared( + "c2", + FileColumnHandle::ColumnType::kRegular, + BIGINT(), + BIGINT(), + std::vector{}, + HiveColumnHandle::ColumnParseParameters{}, + postProc); + auto expected = makeRowVector({ + makeFlatVector({1, 3, 5, 7, 9}), + makeFlatVector({0, 2, 4, 6, 8}), + makeFlatVector({1, 3, 5, 7, 9}), + }); + { + SCOPED_TRACE("Subfield filters"); + auto plan = + PlanBuilder(pool()) + .startTableScan() + .outputType(asRowType(vector->type())) + .subfieldFilter("c0 in (0, 2, 4, 6, 8)") + .assignments({{"c0", c0Handle}, {"c1", c1Handle}, {"c2", c2Handle}}) + .endTableScan() + .planNode(); + auto split = makeHiveConnectorSplit(file->getPath()); + AssertQueryBuilder(plan).split(split).assertResults(expected); + } + { + SCOPED_TRACE("Remaining filter"); + auto plan = + PlanBuilder() + .startTableScan() + .outputType(asRowType(vector->type())) + .remainingFilter("c0 % 2 = 0") + .assignments({{"c0", c0Handle}, {"c1", c1Handle}, {"c2", c2Handle}}) + .endTableScan() + .planNode(); + auto split = makeHiveConnectorSplit(file->getPath()); + AssertQueryBuilder(plan).split(split).assertResults(expected); + } +} + +TEST_F(TableScanTest, shortDecimalFilter) { + functions::registerIsNotNullFunction("isnotnull"); + + std::vector> values = { + 123456789123456789L, + 987654321123456L, + std::nullopt, + 2000000000000000L, + 5000000000000000L, + 987654321987654321L, + 100000000000000L, + 1230000000123456L, + 120000000123456L, + std::nullopt}; + auto rowVector = makeRowVector( + {"a"}, + { + makeNullableFlatVector(values, DECIMAL(18, 6)), + }); + createDuckDbTable({rowVector}); + + auto filePath = facebook::velox::test::getDataFilePath( + "velox/exec/tests", "data/decimal.orc"); + auto createSplit = [&]() { + return exec::test::HiveConnectorSplitBuilder(filePath) + .start(0) + .length(fs::file_size(filePath)) + .fileFormat(dwio::common::FileFormat::ORC) + .build(); + }; + + auto rowType = rowVector->rowType(); + // Is not null. + auto op = + PlanBuilder().tableScan(rowType, {}, "isnotnull(a)", rowType).planNode(); + assertQuery(op, createSplit(), "SELECT a FROM tmp where a is not null"); + + // Is null. + op = PlanBuilder().tableScan(rowType, {}, "is_null(a)", rowType).planNode(); + assertQuery(op, createSplit(), "SELECT a FROM tmp where a is null"); + + // BigintRange. + op = + PlanBuilder() + .tableScan( + rowType, + {}, + "a > 2000000000.0::DECIMAL(18, 6) and a < 6000000000.0::DECIMAL(18, 6)", + rowType) + .planNode(); + assertQuery( + op, + createSplit(), + "SELECT a FROM tmp where a > 2000000000.0 and a < 6000000000.0"); + + // NegatedBigintRange. + op = + PlanBuilder() + .tableScan( + rowType, + {}, + "not(a between 2000000000.0::DECIMAL(18, 6) and 6000000000.0::DECIMAL(18, 6))", + rowType) + .planNode(); + assertQuery( + op, + createSplit(), + "SELECT a FROM tmp where a < 2000000000.0 or a > 6000000000.0"); +} + +TEST_F(TableScanTest, longDecimalFilter) { + functions::registerIsNotNullFunction("isnotnull"); + + std::vector> shortValues = { + 123456789123456789L, + 987654321123456L, + std::nullopt, + 2000000000000000L, + 5000000000000000L, + 987654321987654321L, + 100000000000000L, + 1230000000123456L, + 120000000123456L, + std::nullopt}; + + std::vector> longValues = { + HugeInt::parse("123456789123456789123456789" + std::string(9, '0')), + HugeInt::parse("987654321123456789" + std::string(9, '0')), + std::nullopt, + HugeInt::parse("2" + std::string(37, '0')), + HugeInt::parse("5" + std::string(37, '0')), + HugeInt::parse("987654321987654321987654321" + std::string(9, '0')), + HugeInt::parse("1" + std::string(26, '0')), + HugeInt::parse("123000000012345678" + std::string(10, '0')), + HugeInt::parse("120000000123456789" + std::string(9, '0')), + HugeInt::parse("9" + std::string(37, '0'))}; + + auto rowVector = makeRowVector( + {"a", "b"}, + { + makeNullableFlatVector(shortValues, DECIMAL(18, 6)), + makeNullableFlatVector(longValues, DECIMAL(38, 18)), + }); + createDuckDbTable({rowVector}); + + auto filePath = facebook::velox::test::getDataFilePath( + "velox/exec/tests", "data/decimal.orc"); + auto createSplit = [&]() { + return exec::test::HiveConnectorSplitBuilder(filePath) + .start(0) + .length(fs::file_size(filePath)) + .fileFormat(dwio::common::FileFormat::ORC) + .build(); + }; + + auto outputType = ROW({"b"}, {DECIMAL(38, 18)}); + auto dataColumns = rowVector->rowType(); + + auto op = PlanBuilder() + .tableScan(outputType, {}, "isnotnull(b)", dataColumns) + .planNode(); + assertQuery(op, createSplit(), "SELECT b FROM tmp where b is not null"); + + // Is null. + op = PlanBuilder() + .tableScan(outputType, {}, "is_null(b)", dataColumns) + .planNode(); + assertQuery(op, createSplit(), "SELECT b FROM tmp where b is null"); + + // HugeintRange. + op = + PlanBuilder() + .tableScan( + outputType, + {}, + "b > 2000000000.0::DECIMAL(38, 18) and b < 6000000000.0::DECIMAL(38, 18)", + dataColumns) + .planNode(); + assertQuery( + op, + createSplit(), + "SELECT b FROM tmp where b > 2000000000.0 and b < 6000000000.0"); + + // Test filter column not being projected out. + op = PlanBuilder() + .tableScan(outputType, {}, "a is null", dataColumns) + .planNode(); + assertQuery(op, createSplit(), "SELECT b FROM tmp WHERE a is null"); +} + +TEST_F(TableScanTest, fileFormatRuntimeStats) { + auto vectors = makeVectors(3, 1'000); + + // Write 3 DWRF files. + auto filePaths = makeFilePaths(3); + for (const auto& filePath : filePaths) { + writeToFile(filePath->getPath(), vectors); + } + + // DuckDB reference table needs all data from all 3 files. + std::vector allVectors; + for (int i = 0; i < 3; ++i) { + allVectors.insert(allVectors.end(), vectors.begin(), vectors.end()); + } + createDuckDbTable(allVectors); + + auto task = assertQuery(tableScanNode(), filePaths, "SELECT * FROM tmp"); + auto stats = getTableScanRuntimeStats(task); + ASSERT_EQ(stats.count("fileFormat.dwrf"), 1); + ASSERT_EQ(stats.at("fileFormat.dwrf").sum, 3); +} + +TEST_F(TableScanTest, scanBatchCallback) { + auto vectors = makeVectors(3, 1'000); + auto filePath = TempFilePath::create(); + writeToFile(filePath->getPath(), vectors); + + uint64_t totalRows{0}; + uint64_t callbackCount{0}; + std::string receivedTableName; + auto queryCtx = core::QueryCtx::create(executor_.get()); + queryCtx->setScanBatchCallback([&](const core::ScanBatchEvent& event) { + totalRows += event.numRows; + if (const auto* fileEvent = + dynamic_cast(&event)) { + receivedTableName = std::string(fileEvent->tableName); + } + ++callbackCount; + }); + + auto plan = tableScanNode(); + auto task = AssertQueryBuilder(plan) + .splits(makeHiveConnectorSplits({filePath})) + .queryCtx(queryCtx) + .copyResults(pool_.get()); + + EXPECT_GT(totalRows, 0); + EXPECT_GT(callbackCount, 0); + EXPECT_FALSE(receivedTableName.empty()); +} + +TEST_F(TableScanTest, scanBatchCallbackPartitionKeys) { + auto vectors = makeVectors(1, 1'000); + auto filePath = TempFilePath::create(); + writeToFile(filePath->getPath(), vectors); + + std::unordered_map> + receivedPartitionKeys; + bool callbackFired{false}; + auto queryCtx = core::QueryCtx::create(executor_.get()); + queryCtx->setScanBatchCallback([&](const core::ScanBatchEvent& event) { + if (const auto* fileEvent = + dynamic_cast(&event)) { + if (fileEvent->partitionKeys) { + receivedPartitionKeys = *fileEvent->partitionKeys; + } + } + callbackFired = true; + }); + + connector::ColumnHandleMap assignments = { + {"c0", regularColumn("c0", BIGINT())}, + {"ds", partitionKey("ds", VARCHAR())}, + }; + + auto outputType = ROW({"c0", "ds"}, {BIGINT(), VARCHAR()}); + auto plan = PlanBuilder() + .startTableScan() + .outputType(outputType) + .assignments(assignments) + .endTableScan() + .planNode(); + + auto split = exec::test::HiveConnectorSplitBuilder(filePath->getPath()) + .partitionKey("ds", "2026-05-12") + .build(); + + AssertQueryBuilder(plan).queryCtx(queryCtx).split(split).copyResults( + pool_.get()); + + ASSERT_TRUE(callbackFired); + const std::unordered_map> expected{ + {"ds", "2026-05-12"}}; + ASSERT_EQ(receivedPartitionKeys, expected); +} + +TEST_F(TableScanTest, scanBatchCallbackNotSetIsNoOp) { + auto vectors = makeVectors(3, 1'000); + auto filePath = TempFilePath::create(); + writeToFile(filePath->getPath(), vectors); + + auto plan = tableScanNode(); + auto result = AssertQueryBuilder(plan) + .splits(makeHiveConnectorSplits({filePath})) + .copyResults(pool_.get()); + EXPECT_GT(result->size(), 0); +} + +// --- Column extraction pushdown table scan tests --- + +TEST_F(TableScanTest, extractionMapKeys) { + // Write a MAP(VARCHAR, BIGINT) column, read with MapKeys extraction. + auto mapVector = makeMapVector( + {{{"a", 1}, {"b", 2}}, {{"c", 3}, {"d", 4}, {"e", 5}}}); + auto vector = makeRowVector({"col"}, {mapVector}); + + auto file = TempFilePath::create(); + writeToFile(file->getPath(), {vector}); + + auto hiveType = MAP(VARCHAR(), BIGINT()); + auto outputType = ROW({"col"}, {ARRAY(VARCHAR())}); + std::vector extractions = { + {"col", + {ExtractionPathElement::simple(ExtractionStep::kMapKeys)}, + ARRAY(VARCHAR())}}; + + auto handle = std::make_shared( + "col", + HiveColumnHandle::ColumnType::kRegular, + ARRAY(VARCHAR()), + hiveType, + std::vector{}, + std::move(extractions)); + + auto plan = PlanBuilder() + .startTableScan() + .outputType(outputType) + .assignments({{"col", handle}}) + .endTableScan() + .planNode(); + + auto result = AssertQueryBuilder(plan) + .split(makeHiveConnectorSplit(file->getPath())) + .copyResults(pool_.get()); + + ASSERT_EQ(result->size(), 2); + auto* resultArray = result->childAt(0)->as(); + ASSERT_EQ(resultArray->sizeAt(0), 2); + ASSERT_EQ(resultArray->sizeAt(1), 3); +} + +TEST_F(TableScanTest, extractionMapValues) { + // Write a MAP(VARCHAR, BIGINT) column, read with MapValues extraction. + auto mapVector = + makeMapVector({{{"a", 10}}, {{"b", 20}, {"c", 30}}}); + auto vector = makeRowVector({"col"}, {mapVector}); + + auto file = TempFilePath::create(); + writeToFile(file->getPath(), {vector}); + + auto hiveType = MAP(VARCHAR(), BIGINT()); + auto outputType = ROW({"col"}, {ARRAY(BIGINT())}); + std::vector extractions = { + {"col", + {ExtractionPathElement::simple(ExtractionStep::kMapValues)}, + ARRAY(BIGINT())}}; + + auto handle = std::make_shared( + "col", + HiveColumnHandle::ColumnType::kRegular, + ARRAY(BIGINT()), + hiveType, + std::vector{}, + std::move(extractions)); + + auto plan = PlanBuilder() + .startTableScan() + .outputType(outputType) + .assignments({{"col", handle}}) + .endTableScan() + .planNode(); + + auto result = AssertQueryBuilder(plan) + .split(makeHiveConnectorSplit(file->getPath())) + .copyResults(pool_.get()); + + ASSERT_EQ(result->size(), 2); + auto* resultArray = result->childAt(0)->as(); + ASSERT_EQ(resultArray->sizeAt(0), 1); + ASSERT_EQ(resultArray->sizeAt(1), 2); +} + +TEST_F(TableScanTest, extractionSize) { + // Write a MAP(VARCHAR, BIGINT) column, read with Size extraction. + auto mapVector = makeMapVector( + {{{"a", 1}, {"b", 2}, {"c", 3}}, {{"d", 4}}}); + auto vector = makeRowVector({"col"}, {mapVector}); + + auto file = TempFilePath::create(); + writeToFile(file->getPath(), {vector}); + + auto hiveType = MAP(VARCHAR(), BIGINT()); + auto outputType = ROW({"col"}, {BIGINT()}); + std::vector extractions = { + {"col", + {ExtractionPathElement::simple(ExtractionStep::kSize)}, + BIGINT()}}; + + auto handle = std::make_shared( + "col", + HiveColumnHandle::ColumnType::kRegular, + BIGINT(), + hiveType, + std::vector{}, + std::move(extractions)); + + auto plan = PlanBuilder() + .startTableScan() + .outputType(outputType) + .assignments({{"col", handle}}) + .endTableScan() + .planNode(); + + auto result = AssertQueryBuilder(plan) + .split(makeHiveConnectorSplit(file->getPath())) + .copyResults(pool_.get()); + + ASSERT_EQ(result->size(), 2); + auto expected = makeRowVector({"col"}, {makeFlatVector({3, 1})}); + facebook::velox::test::assertEqualVectors(expected, result); +} + +TEST_F(TableScanTest, extractionMapKeyFilter) { + // Write a MAP(VARCHAR, BIGINT) column with string keys, read with + // MapKeyFilter extraction to keep only selected keys. + auto mapVector = makeMapVector( + {{{"a", 1}, {"b", 2}, {"c", 3}}, {{"a", 10}, {"d", 40}, {"b", 50}}}); + auto vector = makeRowVector({"col"}, {mapVector}); + + auto file = TempFilePath::create(); + writeToFile(file->getPath(), {vector}); + + auto hiveType = MAP(VARCHAR(), BIGINT()); + auto outputType = ROW({"col"}, {MAP(VARCHAR(), BIGINT())}); + std::vector extractions = { + {"col", + {ExtractionPathElement::mapKeyFilter( + std::vector{"a", "b"})}, + MAP(VARCHAR(), BIGINT())}}; + + auto handle = std::make_shared( + "col", + HiveColumnHandle::ColumnType::kRegular, + MAP(VARCHAR(), BIGINT()), + hiveType, + std::vector{}, + std::move(extractions)); + + auto plan = PlanBuilder() + .startTableScan() + .outputType(outputType) + .assignments({{"col", handle}}) + .endTableScan() + .planNode(); + + auto result = AssertQueryBuilder(plan) + .split(makeHiveConnectorSplit(file->getPath())) + .copyResults(pool_.get()); + + ASSERT_EQ(result->size(), 2); + auto* filteredMap = result->childAt(0)->as(); + // Row 0: {"a":1, "b":2} kept, "c" filtered out. + ASSERT_EQ(filteredMap->sizeAt(0), 2); + // Row 1: {"a":10, "b":50} kept, "d" filtered out. + ASSERT_EQ(filteredMap->sizeAt(1), 2); +} + +TEST_F(TableScanTest, extractionMapKeyFilterIntegerKeys) { + // Write a MAP(BIGINT, VARCHAR) column, read with MapKeyFilter using + // integer filter keys. + auto keys = makeFlatVector({10, 20, 30, 10, 40}); + auto values = makeFlatVector({"aa", "bb", "cc", "dd", "ee"}); + auto mapVector = makeMapVector({0, 3}, keys, values); + auto vector = makeRowVector({"col"}, {mapVector}); + + auto file = TempFilePath::create(); + writeToFile(file->getPath(), {vector}); + + auto hiveType = MAP(BIGINT(), VARCHAR()); + auto outputType = ROW({"col"}, {MAP(BIGINT(), VARCHAR())}); + std::vector extractions = { + {"col", + {ExtractionPathElement::mapKeyFilter(std::vector{10, 30})}, + MAP(BIGINT(), VARCHAR())}}; + + auto handle = std::make_shared( + "col", + HiveColumnHandle::ColumnType::kRegular, + MAP(BIGINT(), VARCHAR()), + hiveType, + std::vector{}, + std::move(extractions)); + + auto plan = PlanBuilder() + .startTableScan() + .outputType(outputType) + .assignments({{"col", handle}}) + .endTableScan() + .planNode(); + + auto result = AssertQueryBuilder(plan) + .split(makeHiveConnectorSplit(file->getPath())) + .copyResults(pool_.get()); + + ASSERT_EQ(result->size(), 2); + auto* filteredMap = result->childAt(0)->as(); + // Row 0: keys {10, 30} kept, 20 filtered out. + ASSERT_EQ(filteredMap->sizeAt(0), 2); + // Row 1: key {10} kept, 40 filtered out. + ASSERT_EQ(filteredMap->sizeAt(1), 1); +} + +TEST_F(TableScanTest, extractionStructField) { + // Write a ROW(x: INT, y: VARCHAR) column, extract just field "x". + auto structVector = makeRowVector( + {"x", "y"}, + {makeFlatVector({10, 20, 30}), + makeFlatVector({"aa", "bb", "cc"})}); + auto vector = makeRowVector({"col"}, {structVector}); + + auto file = TempFilePath::create(); + writeToFile(file->getPath(), {vector}); + + auto hiveType = ROW({{"x", INTEGER()}, {"y", VARCHAR()}}); + auto outputType = ROW({"col"}, {INTEGER()}); + std::vector extractions = { + {"col", {ExtractionPathElement::structField("x")}, INTEGER()}}; + + auto handle = std::make_shared( + "col", + HiveColumnHandle::ColumnType::kRegular, + INTEGER(), + hiveType, + std::vector{}, + std::move(extractions)); + + auto plan = PlanBuilder() + .startTableScan() + .outputType(outputType) + .assignments({{"col", handle}}) + .endTableScan() + .planNode(); + + auto result = AssertQueryBuilder(plan) + .split(makeHiveConnectorSplit(file->getPath())) + .copyResults(pool_.get()); + + ASSERT_EQ(result->size(), 3); + auto expected = + makeRowVector({"col"}, {makeFlatVector({10, 20, 30})}); + facebook::velox::test::assertEqualVectors(expected, result); +} + +TEST_F(TableScanTest, extractionArrayElementsStructField) { + // Write an ARRAY(ROW(x: INT, y: INT)) column, extract the "x" field from + // each array element -> ARRAY(INT). + auto innerStruct = makeRowVector( + {"x", "y"}, + {makeFlatVector({1, 2, 3, 4}), + makeFlatVector({10, 20, 30, 40})}); + auto arrayVector = makeArrayVector({0, 2}, innerStruct); + auto vector = makeRowVector({"col"}, {arrayVector}); + + auto file = TempFilePath::create(); + writeToFile(file->getPath(), {vector}); + + auto hiveType = ARRAY(ROW({{"x", INTEGER()}, {"y", INTEGER()}})); + auto outputType = ROW({"col"}, {ARRAY(INTEGER())}); + std::vector extractions = { + {"col", + {ExtractionPathElement::simple(ExtractionStep::kArrayElements), + ExtractionPathElement::structField("x")}, + ARRAY(INTEGER())}}; + + auto handle = std::make_shared( + "col", + HiveColumnHandle::ColumnType::kRegular, + ARRAY(INTEGER()), + hiveType, + std::vector{}, + std::move(extractions)); + + auto plan = PlanBuilder() + .startTableScan() + .outputType(outputType) + .assignments({{"col", handle}}) + .endTableScan() + .planNode(); + + auto result = AssertQueryBuilder(plan) + .split(makeHiveConnectorSplit(file->getPath())) + .copyResults(pool_.get()); + + ASSERT_EQ(result->size(), 2); + auto* resultArray = result->childAt(0)->as(); + ASSERT_EQ(resultArray->sizeAt(0), 2); + ASSERT_EQ(resultArray->sizeAt(1), 2); + auto* elements = resultArray->elements()->as>(); + ASSERT_EQ(elements->valueAt(0), 1); + ASSERT_EQ(elements->valueAt(1), 2); + ASSERT_EQ(elements->valueAt(2), 3); + ASSERT_EQ(elements->valueAt(3), 4); +} + +TEST_F(TableScanTest, extractionArraySize) { + // Write an ARRAY(BIGINT) column, read with Size extraction. + auto arrayVector = makeArrayVector({{1, 2, 3}, {4}, {5, 6}}); + auto vector = makeRowVector({"col"}, {arrayVector}); + + auto file = TempFilePath::create(); + writeToFile(file->getPath(), {vector}); + + auto hiveType = ARRAY(BIGINT()); + auto outputType = ROW({"col"}, {BIGINT()}); + std::vector extractions = { + {"col", + {ExtractionPathElement::simple(ExtractionStep::kSize)}, + BIGINT()}}; + + auto handle = std::make_shared( + "col", + HiveColumnHandle::ColumnType::kRegular, + BIGINT(), + hiveType, + std::vector{}, + std::move(extractions)); + + auto plan = PlanBuilder() + .startTableScan() + .outputType(outputType) + .assignments({{"col", handle}}) + .endTableScan() + .planNode(); + + auto result = AssertQueryBuilder(plan) + .split(makeHiveConnectorSplit(file->getPath())) + .copyResults(pool_.get()); + + ASSERT_EQ(result->size(), 3); + auto expected = makeRowVector({"col"}, {makeFlatVector({3, 1, 2})}); + facebook::velox::test::assertEqualVectors(expected, result); +} + +TEST_F(TableScanTest, extractionMapValuesStructField) { + // Write MAP(VARCHAR, ROW(x: INT, y: INT)), extract values.x -> ARRAY(INT). + auto keys = makeFlatVector({"a", "b", "c"}); + auto structValues = makeRowVector( + {"x", "y"}, + {makeFlatVector({10, 20, 30}), + makeFlatVector({100, 200, 300})}); + auto mapVector = makeMapVector({0, 2}, keys, structValues); + auto vector = makeRowVector({"col"}, {mapVector}); + + auto file = TempFilePath::create(); + writeToFile(file->getPath(), {vector}); + + auto hiveType = MAP(VARCHAR(), ROW({{"x", INTEGER()}, {"y", INTEGER()}})); + auto outputType = ROW({"col"}, {ARRAY(INTEGER())}); + std::vector extractions = { + {"col", + {ExtractionPathElement::simple(ExtractionStep::kMapValues), + ExtractionPathElement::simple(ExtractionStep::kArrayElements), + ExtractionPathElement::structField("x")}, + ARRAY(INTEGER())}}; + + auto handle = std::make_shared( + "col", + HiveColumnHandle::ColumnType::kRegular, + ARRAY(INTEGER()), + hiveType, + std::vector{}, + std::move(extractions)); + + auto plan = PlanBuilder() + .startTableScan() + .outputType(outputType) + .assignments({{"col", handle}}) + .endTableScan() + .planNode(); + + auto result = AssertQueryBuilder(plan) + .split(makeHiveConnectorSplit(file->getPath())) + .copyResults(pool_.get()); + + ASSERT_EQ(result->size(), 2); + auto* resultArray = result->childAt(0)->as(); + ASSERT_EQ(resultArray->sizeAt(0), 2); + ASSERT_EQ(resultArray->sizeAt(1), 1); + auto* elements = resultArray->elements()->as>(); + ASSERT_EQ(elements->valueAt(0), 10); + ASSERT_EQ(elements->valueAt(1), 20); + ASSERT_EQ(elements->valueAt(2), 30); +} + +TEST_F(TableScanTest, extractionMultipleFromSameColumn) { + // Write MAP(VARCHAR, BIGINT), extract both keys and size. + auto mapVector = + makeMapVector({{{"a", 1}, {"b", 2}}, {{"c", 3}}}); + auto vector = makeRowVector({"col"}, {mapVector}); + + auto file = TempFilePath::create(); + writeToFile(file->getPath(), {vector}); + + auto hiveType = MAP(VARCHAR(), BIGINT()); + auto keysType = ARRAY(VARCHAR()); + auto sizeType = BIGINT(); + auto rowOutputType = ROW({{"keys", keysType}, {"sz", sizeType}}); + auto outputType = ROW({"col"}, {rowOutputType}); + + std::vector extractions = { + {"keys", + {ExtractionPathElement::simple(ExtractionStep::kMapKeys)}, + keysType}, + {"sz", {ExtractionPathElement::simple(ExtractionStep::kSize)}, sizeType}}; + + auto handle = std::make_shared( + "col", + HiveColumnHandle::ColumnType::kRegular, + rowOutputType, + hiveType, + std::vector{}, + std::move(extractions)); + + auto plan = PlanBuilder() + .startTableScan() + .outputType(outputType) + .assignments({{"col", handle}}) + .endTableScan() + .planNode(); + + auto result = AssertQueryBuilder(plan) + .split(makeHiveConnectorSplit(file->getPath())) + .copyResults(pool_.get()); + + ASSERT_EQ(result->size(), 2); + auto* outputRow = result->childAt(0)->as(); + // Check sizes. + auto* sizes = outputRow->childAt(1)->as>(); + ASSERT_EQ(sizes->valueAt(0), 2); + ASSERT_EQ(sizes->valueAt(1), 1); + // Check keys. + auto* keysArray = outputRow->childAt(0)->as(); + ASSERT_EQ(keysArray->sizeAt(0), 2); + ASSERT_EQ(keysArray->sizeAt(1), 1); +} + +TEST_F(TableScanTest, extractionWithRegularColumn) { + // Write two columns: a regular BIGINT and a MAP with extraction. + auto idColumn = makeFlatVector({100, 200, 300}); + auto mapVector = makeMapVector( + {{{"x", 1}}, {{"y", 2}, {"z", 3}}, {{"w", 4}}}); + auto vector = makeRowVector({"id", "m"}, {idColumn, mapVector}); + + auto file = TempFilePath::create(); + writeToFile(file->getPath(), {vector}); + + auto hiveType = MAP(VARCHAR(), BIGINT()); + auto outputType = ROW({"id", "m"}, {BIGINT(), ARRAY(VARCHAR())}); + std::vector extractions = { + {"m", + {ExtractionPathElement::simple(ExtractionStep::kMapKeys)}, + ARRAY(VARCHAR())}}; + + auto extractionHandle = std::make_shared( + "m", + HiveColumnHandle::ColumnType::kRegular, + ARRAY(VARCHAR()), + hiveType, + std::vector{}, + std::move(extractions)); + + auto plan = + PlanBuilder() + .startTableScan() + .outputType(outputType) + .assignments( + {{"id", regularColumn("id", BIGINT())}, {"m", extractionHandle}}) + .endTableScan() + .planNode(); + + auto result = AssertQueryBuilder(plan) + .split(makeHiveConnectorSplit(file->getPath())) + .copyResults(pool_.get()); + + ASSERT_EQ(result->size(), 3); + // Verify regular column. + auto* ids = result->childAt(0)->as>(); + ASSERT_EQ(ids->valueAt(0), 100); + ASSERT_EQ(ids->valueAt(1), 200); + ASSERT_EQ(ids->valueAt(2), 300); + // Verify extracted keys. + auto* keysArray = result->childAt(1)->as(); + ASSERT_EQ(keysArray->sizeAt(0), 1); + ASSERT_EQ(keysArray->sizeAt(1), 2); + ASSERT_EQ(keysArray->sizeAt(2), 1); +} + +TEST_F(TableScanTest, extractionSizeAndMapKeyFilter) { + // Extract both kSize (full map length) and kMapKeyFilter (selected keys) + // from the same MAP column. Size should reflect all entries while the + // filtered map should contain only the selected keys. + auto mapVector = makeMapVector( + {{{"a", 1}, {"b", 2}, {"c", 3}}, {{"a", 10}, {"d", 40}}}); + auto vector = makeRowVector({"col"}, {mapVector}); + + auto file = TempFilePath::create(); + writeToFile(file->getPath(), {vector}); + + auto hiveType = MAP(VARCHAR(), BIGINT()); + auto sizeType = BIGINT(); + auto filteredMapType = MAP(VARCHAR(), BIGINT()); + auto rowOutputType = ROW({{"sz", sizeType}, {"filtered", filteredMapType}}); + auto outputType = ROW({"col"}, {rowOutputType}); + + std::vector extractions = { + {"sz", {ExtractionPathElement::simple(ExtractionStep::kSize)}, sizeType}, + {"filtered", + {ExtractionPathElement::mapKeyFilter( + std::vector{"a", "b"})}, + filteredMapType}}; + + auto handle = std::make_shared( + "col", + HiveColumnHandle::ColumnType::kRegular, + rowOutputType, + hiveType, + std::vector{}, + std::move(extractions)); + + auto plan = PlanBuilder() + .startTableScan() + .outputType(outputType) + .assignments({{"col", handle}}) + .endTableScan() + .planNode(); + + auto result = AssertQueryBuilder(plan) + .split(makeHiveConnectorSplit(file->getPath())) + .copyResults(pool_.get()); + + ASSERT_EQ(result->size(), 2); + auto* outputRow = result->childAt(0)->as(); + // Check sizes — should be the full map lengths. + auto* sizes = outputRow->childAt(0)->as>(); + ASSERT_EQ(sizes->valueAt(0), 3); // {"a","b","c"} + ASSERT_EQ(sizes->valueAt(1), 2); // {"a","d"} + // Check filtered map — only keys "a" and "b". + auto* filteredMap = outputRow->childAt(1)->as(); + ASSERT_EQ(filteredMap->sizeAt(0), 2); // "a" and "b" + ASSERT_EQ(filteredMap->sizeAt(1), 1); // "a" only ("d" filtered out) +} + +TEST_F(TableScanTest, extractionSizeLargeDataMultipleBatches) { + // Write enough data to produce multiple output batches. Verify + // correctness across batches (implicitly tests result vector reuse + // since the reader reuses the FlatVector across batches). + constexpr int kNumRows = 10'000; + // Build map data: each row has (i%5)+1 entries. + vector_size_t totalEntries = 0; + for (int i = 0; i < kNumRows; ++i) { + totalEntries += (i % 5) + 1; + } + std::vector keyStrs(totalEntries); + for (int i = 0; i < totalEntries; ++i) { + keyStrs[i] = std::to_string(i); + } + auto keys = makeFlatVector( + totalEntries, [&](auto i) { return StringView(keyStrs[i]); }); + auto values = makeFlatVector(totalEntries, folly::identity); + std::vector offsets(kNumRows); + vector_size_t offset = 0; + for (int i = 0; i < kNumRows; ++i) { + offsets[i] = offset; + offset += (i % 5) + 1; + } + auto mapVector = makeMapVector(offsets, keys, values); + auto vector = makeRowVector({"col"}, {mapVector}); + + auto file = TempFilePath::create(); + writeToFile(file->getPath(), {vector}); + + auto hiveType = MAP(VARCHAR(), BIGINT()); + auto outputType = ROW({"col"}, {BIGINT()}); + std::vector extractions = { + {"col", + {ExtractionPathElement::simple(ExtractionStep::kSize)}, + BIGINT()}}; + + auto handle = std::make_shared( + "col", + HiveColumnHandle::ColumnType::kRegular, + BIGINT(), + hiveType, + std::vector{}, + std::move(extractions)); + + auto plan = PlanBuilder() + .startTableScan() + .outputType(outputType) + .assignments({{"col", handle}}) + .endTableScan() + .planNode(); + + auto result = AssertQueryBuilder(plan) + .split(makeHiveConnectorSplit(file->getPath())) + .copyResults(pool_.get()); + + // Verify all rows are present with correct sizes. + ASSERT_EQ(result->size(), kNumRows); + auto* sizes = result->childAt(0)->as>(); + for (int i = 0; i < kNumRows; ++i) { + ASSERT_EQ(sizes->valueAt(i), (i % 5) + 1) << "Incorrect size at row " << i; + } +} + +TEST_F(TableScanTest, extractionDeeplyNestedChain) { + // Test deeply nested extraction on a MAP column: + // MAP(VARCHAR, ROW(x: INT, y: ARRAY(BIGINT))) + // Chain: MapValues -> ArrayElements -> StructField("y") -> Size + // Output: ARRAY(BIGINT) + // + // With recursive ScanSpec configuration: + // - MAP: kValues (reader skips decoding keys) + // - ROW values: "x" pruned (reader skips decoding x) + // - "y" ARRAY: kSize (reader skips decoding array elements) + // The effective remaining transform is just [StructField("y")] to extract + // the y field from the ROW. Size is handled by kSize on the ScanSpec. + + // MAP(VARCHAR, ROW(x: INT, y: ARRAY(BIGINT))) + // Row 0: {"k1" -> {x:1, y:[10,20]}, "k2" -> {x:2, y:[30]}} + // Row 1: {"k3" -> {x:3, y:[40,50,60]}} + auto mapKeys = makeFlatVector({"k1", "k2", "k3"}); + auto allInnerStructs = makeRowVector( + {"x", "y"}, + {makeFlatVector({1, 2, 3}), + makeArrayVector({{10, 20}, {30}, {40, 50, 60}})}); + auto mapVector = makeMapVector({0, 2}, mapKeys, allInnerStructs); + auto tableVector = makeRowVector({"col"}, {mapVector}); + + auto file = TempFilePath::create(); + writeToFile(file->getPath(), {tableVector}); + + auto innerStructType = ROW({{"x", INTEGER()}, {"y", ARRAY(BIGINT())}}); + auto hiveType = MAP(VARCHAR(), innerStructType); + + // Extraction chain: MapValues -> ArrayElements -> StructField("y") -> Size + // Output type: ARRAY(BIGINT) — MapValues produces ARRAY(ROW(...)), + // ArrayElements enters each element, StructField("y") gets the array field, + // Size produces BIGINT. The result is wrapped back into ARRAY(BIGINT). + auto outputColType = ARRAY(BIGINT()); + + std::vector extractions = { + {"col", + {ExtractionPathElement::simple(ExtractionStep::kMapValues), + ExtractionPathElement::simple(ExtractionStep::kArrayElements), + ExtractionPathElement::structField("y"), + ExtractionPathElement::simple(ExtractionStep::kSize)}, + outputColType}}; + + auto handle = std::make_shared( + "col", + HiveColumnHandle::ColumnType::kRegular, + outputColType, + hiveType, + std::vector{}, + std::move(extractions)); + + auto outputType = ROW({"col"}, {outputColType}); + auto plan = PlanBuilder() + .startTableScan() + .outputType(outputType) + .assignments({{"col", handle}}) + .endTableScan() + .planNode(); + + auto extractionTask = + AssertQueryBuilder(plan) + .config(core::QueryConfig::kMaxSplitPreloadPerDriver, "0") + .split(makeHiveConnectorSplit(file->getPath())) + .copyResults(pool_.get()); + + ASSERT_EQ(extractionTask->size(), 2); + auto* resultArray = extractionTask->childAt(0)->as(); + + // Row 0: map values are [{x:1,y:[10,20]}, {x:2,y:[30]}] + // -> y sizes = [2, 1] + ASSERT_EQ(resultArray->sizeAt(0), 2); + auto* elements = resultArray->elements()->as>(); + ASSERT_EQ(elements->valueAt(resultArray->offsetAt(0)), 2); // size of [10,20] + ASSERT_EQ(elements->valueAt(resultArray->offsetAt(0) + 1), 1); // size of [30] + + // Row 1: map values are [{x:3,y:[40,50,60]}] + // -> y sizes = [3] + ASSERT_EQ(resultArray->sizeAt(1), 1); + ASSERT_EQ( + elements->valueAt(resultArray->offsetAt(1)), 3); // size of [40,50,60] + + // I/O reduction is validated at the reader level in + // TestReader.extractionMapKeysIoReduction (ReaderTest.cpp). + // Recursive ScanSpec pushdown skips decoding map keys, struct field "x", + // and array element data — only map/array lengths and "y" lengths are read. +} + +TEST_F(TableScanTest, extractionMultipleFormatsMultipleSplits) { + // Test processing DWRF, TEXT, DWRF splits with MapKeys extraction in a + // single table scan. Exercises the extraction code path across format + // switches within a single split reader. + + // Data: MAP(VARCHAR, BIGINT) with 2 rows. + auto mapVector = + makeMapVector({{{"a", 1}, {"b", 2}}, {{"c", 3}}}); + auto vector = makeRowVector({"col"}, {mapVector}); + + // Write DWRF files for splits 1 and 3. + auto dwrfFile1 = TempFilePath::create(); + auto dwrfFile2 = TempFilePath::create(); + writeToFile(dwrfFile1->getPath(), {vector}); + writeToFile(dwrfFile2->getPath(), {vector}); + + // Write TEXT file for split 2. Text format uses field/collection/map-key + // delimiters. The same 2 rows: {"a":1,"b":2} and {"c":3}. + auto textFile = TempFilePath::create(); + { + // TempFilePath already created the file; open for truncating write. + auto writeFile = + std::make_unique(textFile->getPath(), true, false); + // Row format: map entries separated by \x02, key-value by \x03. + writeFile->append( + "a\x03" + "1\x02" + "b\x03" + "2\nc\x03" + "3\n"); + writeFile->close(); + } + + auto hiveType = MAP(VARCHAR(), BIGINT()); + auto outputType = ROW({"col"}, {ARRAY(VARCHAR())}); + auto dataColumns = ROW({"col"}, {hiveType}); + std::vector extractions = { + {"col", + {ExtractionPathElement::simple(ExtractionStep::kMapKeys)}, + ARRAY(VARCHAR())}}; + + auto handle = std::make_shared( + "col", + HiveColumnHandle::ColumnType::kRegular, + ARRAY(VARCHAR()), + hiveType, + std::vector{}, + std::move(extractions)); + + auto plan = PlanBuilder() + .startTableScan() + .dataColumns(dataColumns) + .outputType(outputType) + .assignments({{"col", handle}}) + .endTableScan() + .planNode(); + + // TEXT split with serde parameters for delimiters. + std::unordered_map serdeParameters{ + {dwio::common::SerDeOptions::kFieldDelim, "\x01"}, + {dwio::common::SerDeOptions::kCollectionDelim, "\x02"}, + {dwio::common::SerDeOptions::kMapKeyDelim, "\x03"}}; + auto textSplit = std::make_shared( + kHiveConnectorId, + textFile->getPath(), + dwio::common::FileFormat::TEXT, + 0, + std::numeric_limits::max(), + std::unordered_map>{}, + std::nullopt, + std::unordered_map{}, + nullptr, + serdeParameters); + + // Process 3 splits: DWRF, TEXT, DWRF. Disable split preloading to ensure + // all splits are processed by a single split reader. + auto result = AssertQueryBuilder(plan) + .config(core::QueryConfig::kMaxSplitPreloadPerDriver, "0") + .split(makeHiveConnectorSplit(dwrfFile1->getPath())) + .split(exec::Split(textSplit)) + .split(makeHiveConnectorSplit(dwrfFile2->getPath())) + .copyResults(pool_.get()); + + // 2 rows per split × 3 splits = 6 rows total. + ASSERT_EQ(result->size(), 6); + auto* resultArray = result->childAt(0)->as(); + for (int split = 0; split < 3; ++split) { + ASSERT_EQ(resultArray->sizeAt(split * 2), 2); + ASSERT_EQ(resultArray->sizeAt(split * 2 + 1), 1); + } +} + +TEST_F(TableScanTest, extractionMultipleExtractionsTextFormat) { + // Test multiple extractions (MapKeys + Size) on a MAP column read from a + // TEXT format split. With multiple extractions, ExtractionType stays as + // kNone and full chains are used in the transform. + + // Data: MAP(VARCHAR, BIGINT) with 2 rows. + // Row 0: {"a":1, "b":2} + // Row 1: {"c":3} + auto textFile = TempFilePath::create(); + { + auto writeFile = + std::make_unique(textFile->getPath(), true, false); + // Row format: map entries separated by \x02, key-value by \x03. + writeFile->append( + "a\x03" + "1\x02" + "b\x03" + "2\nc\x03" + "3\n"); + writeFile->close(); + } + + auto hiveType = MAP(VARCHAR(), BIGINT()); + // Output: ROW(keys: ARRAY(VARCHAR), size: BIGINT) + auto outputColType = ROW({"keys", "size"}, {ARRAY(VARCHAR()), BIGINT()}); + auto dataColumns = ROW({"col"}, {hiveType}); + std::vector extractions = { + {"keys", + {ExtractionPathElement::simple(ExtractionStep::kMapKeys)}, + ARRAY(VARCHAR())}, + {"size", + {ExtractionPathElement::simple(ExtractionStep::kSize)}, + BIGINT()}}; + + auto handle = std::make_shared( + "col", + HiveColumnHandle::ColumnType::kRegular, + outputColType, + hiveType, + std::vector{}, + std::move(extractions)); + + auto outputType = ROW({"col"}, {outputColType}); + auto plan = PlanBuilder() + .startTableScan() + .dataColumns(dataColumns) + .outputType(outputType) + .assignments({{"col", handle}}) + .endTableScan() + .planNode(); + + std::unordered_map serdeParameters{ + {dwio::common::SerDeOptions::kFieldDelim, "\x01"}, + {dwio::common::SerDeOptions::kCollectionDelim, "\x02"}, + {dwio::common::SerDeOptions::kMapKeyDelim, "\x03"}}; + auto textSplit = std::make_shared( + kHiveConnectorId, + textFile->getPath(), + dwio::common::FileFormat::TEXT, + 0, + std::numeric_limits::max(), + std::unordered_map>{}, + std::nullopt, + std::unordered_map{}, + nullptr, + serdeParameters); + + auto result = AssertQueryBuilder(plan) + .config(core::QueryConfig::kMaxSplitPreloadPerDriver, "0") + .split(exec::Split(textSplit)) + .copyResults(pool_.get()); + + ASSERT_EQ(result->size(), 2); + // Output is a single column "col" of type ROW{keys, size}. + auto* outerRow = result->childAt(0)->as(); + ASSERT_NE(outerRow, nullptr); + + // Verify keys (ARRAY(VARCHAR)). + auto* keysArray = outerRow->childAt(0)->as(); + ASSERT_EQ(keysArray->sizeAt(0), 2); // Row 0: {"a", "b"} + ASSERT_EQ(keysArray->sizeAt(1), 1); // Row 1: {"c"} + + // Verify size (BIGINT). + auto* sizeVector = outerRow->childAt(1)->as>(); + ASSERT_EQ(sizeVector->valueAt(0), 2); // Row 0: 2 entries + ASSERT_EQ(sizeVector->valueAt(1), 1); // Row 1: 1 entry +} + } // namespace } // namespace facebook::velox::exec diff --git a/velox/exec/tests/TableWriterTest.cpp b/velox/exec/tests/TableWriterTest.cpp index 2b91f729c37..fd8c79865f9 100644 --- a/velox/exec/tests/TableWriterTest.cpp +++ b/velox/exec/tests/TableWriterTest.cpp @@ -19,33 +19,33 @@ #include "folly/dynamic.h" #include "velox/common/base/Fs.h" #include "velox/common/base/tests/GTestUtils.h" -#include "velox/common/hyperloglog/SparseHll.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/common/testutil/TestValue.h" #include "velox/connectors/hive/HiveConfig.h" -#include "velox/connectors/hive/HivePartitionFunction.h" #include "velox/dwio/common/WriterFactory.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/TableWriter.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/vector/fuzzer/VectorFuzzer.h" #include #include -#include "folly/experimental/EventCount.h" +#include "folly/synchronization/EventCount.h" #include "velox/common/memory/MemoryArbitrator.h" #include "velox/dwio/common/Options.h" #include "velox/dwio/dwrf/writer/Writer.h" #include "velox/exec/tests/utils/ArbitratorTestUtil.h" namespace velox::exec::test { +using namespace facebook::velox::common::testutil; + constexpr uint64_t kQueryMemoryCapacity = 512 * MB; -class BasicTableWriterTestBase : public HiveConnectorTestBase {}; +class BasicTableWriterTest : public HiveConnectorTestBase {}; -TEST_F(BasicTableWriterTestBase, roundTrip) { +TEST_F(BasicTableWriterTest, roundTrip) { vector_size_t size = 1'000; auto data = makeRowVector({ makeFlatVector(size, [](auto row) { return row; }), @@ -82,7 +82,7 @@ TEST_F(BasicTableWriterTestBase, roundTrip) { ->as>(); ASSERT_TRUE(details->isNullAt(0)); ASSERT_FALSE(details->isNullAt(1)); - folly::dynamic obj = folly::parseJson(details->valueAt(1)); + folly::dynamic obj = folly::parseJson(std::string_view(details->valueAt(1))); ASSERT_EQ(size, obj["rowCount"].asInt()); auto fileWriteInfos = obj["fileWriteInfos"]; @@ -93,16 +93,18 @@ TEST_F(BasicTableWriterTestBase, roundTrip) { // Read from 'writeFileName' and verify the data matches the original. plan = PlanBuilder().tableScan(rowType).planNode(); - auto copy = AssertQueryBuilder(plan) - .split(makeHiveConnectorSplit(fmt::format( - "{}/{}", targetDirectoryPath->getPath(), writeFileName))) - .copyResults(pool()); + auto copy = + AssertQueryBuilder(plan) + .split(makeHiveConnectorSplit( + fmt::format( + "{}/{}", targetDirectoryPath->getPath(), writeFileName))) + .copyResults(pool()); assertEqualResults({data}, {copy}); } // Generates a struct (row), write it as a flap map, and check that it is read // back as a map. -TEST_F(BasicTableWriterTestBase, structAsMap) { +TEST_F(BasicTableWriterTest, structAsMap) { // Input struct type. vector_size_t size = 1'000; auto data = makeRowVector( @@ -162,7 +164,7 @@ TEST_F(BasicTableWriterTestBase, structAsMap) { .assertResults(expected); } -TEST_F(BasicTableWriterTestBase, targetFileName) { +TEST_F(BasicTableWriterTest, targetFileName) { constexpr const char* kFileName = "test.dwrf"; auto data = makeRowVector({makeFlatVector(10, folly::identity)}); auto directory = TempDirectoryPath::create(); @@ -178,7 +180,7 @@ TEST_F(BasicTableWriterTestBase, targetFileName) { auto results = AssertQueryBuilder(plan).copyResults(pool()); auto* details = results->childAt(TableWriteTraits::kFragmentChannel) ->asUnchecked>(); - auto detail = folly::parseJson(details->valueAt(1)); + auto detail = folly::parseJson(std::string_view(details->valueAt(1))); auto fileWriteInfos = detail["fileWriteInfos"]; ASSERT_EQ(1, fileWriteInfos.size()); ASSERT_EQ(fileWriteInfos[0]["writeFileName"].asString(), kFileName); @@ -205,66 +207,72 @@ class PartitionedTableWriterTest for (bool multiDrivers : multiDriverOptions) { for (FileFormat fileFormat : fileFormats) { for (bool scaleWriter : {false, true}) { - testParams.push_back(TestParam{ - fileFormat, - TestMode::kPartitioned, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kPartitioned, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kBucketed, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kBucketed, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kBucketed, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kPrestoNative, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kBucketed, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kPrestoNative, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kPartitioned, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kPartitioned, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kBucketed, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kBucketed, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kBucketed, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kPrestoNative, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kBucketed, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kPrestoNative, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); } } } @@ -288,26 +296,28 @@ class UnpartitionedTableWriterTest for (bool multiDrivers : multiDriverOptions) { for (FileFormat fileFormat : fileFormats) { for (bool scaleWriter : {false, true}) { - testParams.push_back(TestParam{ - fileFormat, - TestMode::kUnpartitioned, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_NONE, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kUnpartitioned, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_NONE, - scaleWriter} - .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kUnpartitioned, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_NONE, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kUnpartitioned, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_NONE, + scaleWriter} + .value); } } } @@ -331,26 +341,28 @@ class BucketedUnpartitionedTableWriterTest const std::vector bucketModes = {TestMode::kOnlyBucketed}; for (bool multiDrivers : multiDriverOptions) { for (FileFormat fileFormat : fileFormats) { - testParams.push_back(TestParam{ - fileFormat, - TestMode::kOnlyBucketed, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kHiveCompatible, - true, - multiDrivers, - facebook::velox::common::CompressionKind_ZSTD, - /*scaleWriter=*/false} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kOnlyBucketed, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kHiveCompatible, - true, - multiDrivers, - facebook::velox::common::CompressionKind_NONE, - /*scaleWriter=*/false} - .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kOnlyBucketed, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kHiveCompatible, + true, + multiDrivers, + facebook::velox::common::CompressionKind_ZSTD, + /*scaleWriter=*/false} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kOnlyBucketed, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kHiveCompatible, + true, + multiDrivers, + facebook::velox::common::CompressionKind_NONE, + /*scaleWriter=*/false} + .value); } } return testParams; @@ -375,76 +387,83 @@ class BucketedTableOnlyWriteTest for (bool multiDrivers : multiDriverOptions) { for (FileFormat fileFormat : fileFormats) { for (auto bucketMode : bucketModes) { - testParams.push_back(TestParam{ - fileFormat, - bucketMode, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - /*scaleWriter=*/false} - .value); - testParams.push_back(TestParam{ - fileFormat, - bucketMode, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kHiveCompatible, - true, - multiDrivers, - CompressionKind_ZSTD, - /*scaleWriter=*/false} - .value); - testParams.push_back(TestParam{ - fileFormat, - bucketMode, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - /*scaleWriter=*/false} - .value); - testParams.push_back(TestParam{ - fileFormat, - bucketMode, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kHiveCompatible, - true, - multiDrivers, - CompressionKind_ZSTD, - /*scaleWriter=*/false} - .value); - testParams.push_back(TestParam{ - fileFormat, - bucketMode, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kPrestoNative, - false, - multiDrivers, - CompressionKind_ZSTD, - /*scaleWriter=*/false} - .value); - testParams.push_back(TestParam{ - fileFormat, - bucketMode, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kPrestoNative, - true, - multiDrivers, - CompressionKind_ZSTD, - /*scaleWriter=*/false} - .value); - testParams.push_back(TestParam{ - fileFormat, - bucketMode, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kPrestoNative, - false, - multiDrivers, - CompressionKind_ZSTD, - /*scaleWriter=*/false} - .value); + testParams.push_back( + TestParam{ + fileFormat, + bucketMode, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + /*scaleWriter=*/false} + .value); + testParams.push_back( + TestParam{ + fileFormat, + bucketMode, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kHiveCompatible, + true, + multiDrivers, + CompressionKind_ZSTD, + /*scaleWriter=*/false} + .value); + testParams.push_back( + TestParam{ + fileFormat, + bucketMode, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + /*scaleWriter=*/false} + .value); + testParams.push_back( + TestParam{ + fileFormat, + bucketMode, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kHiveCompatible, + true, + multiDrivers, + CompressionKind_ZSTD, + /*scaleWriter=*/false} + .value); + testParams.push_back( + TestParam{ + fileFormat, + bucketMode, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kPrestoNative, + false, + multiDrivers, + CompressionKind_ZSTD, + /*scaleWriter=*/false} + .value); + testParams.push_back( + TestParam{ + fileFormat, + bucketMode, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kPrestoNative, + true, + multiDrivers, + CompressionKind_ZSTD, + /*scaleWriter=*/false} + .value); + testParams.push_back( + TestParam{ + fileFormat, + bucketMode, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kPrestoNative, + false, + multiDrivers, + CompressionKind_ZSTD, + /*scaleWriter=*/false} + .value); } } } @@ -470,26 +489,28 @@ class BucketSortOnlyTableWriterTest for (bool multiDrivers : multiDriverOptions) { for (FileFormat fileFormat : fileFormats) { for (auto bucketMode : bucketModes) { - testParams.push_back(TestParam{ - fileFormat, - bucketMode, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kHiveCompatible, - true, - multiDrivers, - facebook::velox::common::CompressionKind_ZSTD, - /*scaleWriter=*/false} - .value); - testParams.push_back(TestParam{ - fileFormat, - bucketMode, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kHiveCompatible, - true, - multiDrivers, - facebook::velox::common::CompressionKind_NONE, - /*scaleWriter=*/false} - .value); + testParams.push_back( + TestParam{ + fileFormat, + bucketMode, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kHiveCompatible, + true, + multiDrivers, + facebook::velox::common::CompressionKind_ZSTD, + /*scaleWriter=*/false} + .value); + testParams.push_back( + TestParam{ + fileFormat, + bucketMode, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kHiveCompatible, + true, + multiDrivers, + facebook::velox::common::CompressionKind_NONE, + /*scaleWriter=*/false} + .value); } } } @@ -513,26 +534,28 @@ class PartitionedWithoutBucketTableWriterTest for (bool multiDrivers : multiDriverOptions) { for (FileFormat fileFormat : fileFormats) { for (bool scaleWriter : {false, true}) { - testParams.push_back(TestParam{ - fileFormat, - TestMode::kPartitioned, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kPartitioned, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kPartitioned, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kPartitioned, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); } } } @@ -555,126 +578,138 @@ class AllTableWriterTest : public TableWriterTestBase, for (bool multiDrivers : multiDriverOptions) { for (FileFormat fileFormat : fileFormats) { for (bool scaleWriter : {false, true}) { - testParams.push_back(TestParam{ - fileFormat, - TestMode::kUnpartitioned, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kUnpartitioned, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kPartitioned, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kPartitioned, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kBucketed, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kBucketed, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kBucketed, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kPrestoNative, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kBucketed, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kPrestoNative, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kOnlyBucketed, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kOnlyBucketed, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kOnlyBucketed, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kPrestoNative, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kOnlyBucketed, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kPrestoNative, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kUnpartitioned, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kUnpartitioned, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kPartitioned, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kPartitioned, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kBucketed, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kBucketed, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kBucketed, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kPrestoNative, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kBucketed, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kPrestoNative, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kOnlyBucketed, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kOnlyBucketed, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kOnlyBucketed, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kPrestoNative, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kOnlyBucketed, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kPrestoNative, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); } } } @@ -1523,10 +1558,16 @@ TEST_P(UnpartitionedTableWriterTest, runtimeStatsCheck) { } ASSERT_EQ( stats[1].runtimeStats["stripeSize"].count, testData.expectedNumStripes); - ASSERT_EQ(stats[1].runtimeStats[TableWriter::kNumWrittenFiles].sum, 1); - ASSERT_EQ(stats[1].runtimeStats[TableWriter::kNumWrittenFiles].count, 1); - ASSERT_GE(stats[1].runtimeStats[TableWriter::kWriteIOTime].sum, 0); - ASSERT_EQ(stats[1].runtimeStats[TableWriter::kWriteIOTime].count, 1); + ASSERT_EQ( + stats[1].runtimeStats[std::string(TableWriter::kNumWrittenFiles)].sum, + 1); + ASSERT_EQ( + stats[1].runtimeStats[std::string(TableWriter::kNumWrittenFiles)].count, + 1); + ASSERT_GE( + stats[1].runtimeStats[std::string(TableWriter::kWriteIOTime)].sum, 0); + ASSERT_EQ( + stats[1].runtimeStats[std::string(TableWriter::kWriteIOTime)].count, 1); } } @@ -1782,7 +1823,8 @@ TEST_P(AllTableWriterTest, tableWriteOutputCheck) { } if (!fragmentVector->isNullAt(i)) { ASSERT_FALSE(fragmentVector->isNullAt(i)); - folly::dynamic obj = folly::parseJson(fragmentVector->valueAt(i)); + folly::dynamic obj = + folly::parseJson(std::string_view(fragmentVector->valueAt(i))); if (testMode_ == TestMode::kUnpartitioned) { ASSERT_EQ(obj["targetPath"], outputDirectory->getPath()); ASSERT_EQ(obj["writePath"], outputDirectory->getPath()); @@ -1791,13 +1833,17 @@ TEST_P(AllTableWriterTest, tableWriteOutputCheck) { for (const auto& partitionBy : partitionedBy_) { partitionDirRe += fmt::format("/{}=.+", partitionBy); } - ASSERT_TRUE(RE2::FullMatch( - obj["targetPath"].asString(), - fmt::format("{}{}", outputDirectory->getPath(), partitionDirRe))) + ASSERT_TRUE( + RE2::FullMatch( + obj["targetPath"].asString(), + fmt::format( + "{}{}", outputDirectory->getPath(), partitionDirRe))) << obj["targetPath"].asString(); - ASSERT_TRUE(RE2::FullMatch( - obj["writePath"].asString(), - fmt::format("{}{}", outputDirectory->getPath(), partitionDirRe))) + ASSERT_TRUE( + RE2::FullMatch( + obj["writePath"].asString(), + fmt::format( + "{}{}", outputDirectory->getPath(), partitionDirRe))) << obj["writePath"].asString(); } numRows += obj["rowCount"].asInt(); @@ -1822,7 +1868,7 @@ TEST_P(AllTableWriterTest, tableWriteOutputCheck) { ASSERT_EQ(writeFileName, targetFileName); } else { const std::string kParquetSuffix = ".parquet"; - if (folly::StringPiece(targetFileName).endsWith(kParquetSuffix)) { + if (targetFileName.ends_with(kParquetSuffix)) { // Remove the .parquet suffix. auto trimmedFilename = targetFileName.substr( 0, targetFileName.size() - kParquetSuffix.size()); @@ -1833,9 +1879,11 @@ TEST_P(AllTableWriterTest, tableWriteOutputCheck) { } } if (!commitContextVector->isNullAt(i)) { - ASSERT_TRUE(RE2::FullMatch( - commitContextVector->valueAt(i).getString(), - fmt::format(".*{}.*", CommitStrategyName::toName(commitStrategy_)))) + ASSERT_TRUE( + RE2::FullMatch( + commitContextVector->valueAt(i).getString(), + fmt::format( + ".*{}.*", CommitStrategyName::toName(commitStrategy_)))) << commitContextVector->valueAt(i); } } @@ -1904,8 +1952,9 @@ TEST_P(AllTableWriterTest, columnStatsDataTypes) { std::vector groupingKeyFields; groupingKeyFields.reserve(partitionedBy_.size()); for (int i = 0; i < partitionedBy_.size(); ++i) { - groupingKeyFields.emplace_back(std::make_shared( - partitionTypes_.at(i), partitionedBy_.at(i))); + groupingKeyFields.emplace_back( + std::make_shared( + partitionTypes_.at(i), partitionedBy_.at(i))); } // aggregation node @@ -2025,7 +2074,7 @@ TEST_P(AllTableWriterTest, columnStatsDataTypes) { const auto distinctCountStatsVector = result->childAt(nextColumnStatsIndex++)->asFlatVector(); HashStringAllocator allocator{pool_.get()}; - DenseHll denseHll{ + DenseHll<> denseHll{ std::string(distinctCountStatsVector->valueAt(0)).c_str(), &allocator}; ASSERT_EQ(denseHll.cardinality(), 1000); const auto maxDataSizeStatsVector = @@ -2245,6 +2294,301 @@ TEST_P(AllTableWriterTest, columnStatsWithTableWriteMerge) { } } +// Verifies TableWriteMerge with kFinal step and multiple drivers. +// Each driver runs TableWrite(kPartial), LocalGather collects, and +// TableWriteMerge(kFinal) produces final statistics. +TEST_F(BasicTableWriterTest, columnStatsWithTableWriteMergeFinal) { + auto outputDirectory = TempDirectoryPath::create(); + auto input = makeRowVector({ + makeFlatVector(100, folly::identity), + }); + + // Stats columns: min(c0), max(c0) at channels 3 and 4. + auto plan = PlanBuilder() + .values({input}) + .tableWrite( + outputDirectory->getPath(), + dwio::common::FileFormat::DWRF, + {"min(c0)", "max(c0)"}) + .localGather() + .tableWriteMerge(core::AggregationNode::Step::kFinal) + .planNode(); + + auto result = AssertQueryBuilder(plan).maxDrivers(4).copyResults(pool()); + ASSERT_GT(result->size(), 0); + + // Verify the output: expect fragment rows, one stats row with values, + // and one summary row with the total row count. + auto* rowCountVector = result->childAt(TableWriteTraits::kRowCountChannel) + ->as>(); + // Stats columns start at kStatsChannel: min(c0) at +0, max(c0) at +1. + auto* minVector = result->childAt(TableWriteTraits::kStatsChannel) + ->as>(); + auto* maxVector = result->childAt(TableWriteTraits::kStatsChannel + 1) + ->as>(); + + int32_t numFragments = 0; + int32_t numStatsWithValues = 0; + int32_t numSummary = 0; + for (vector_size_t i = 0; i < result->size(); ++i) { + if (!result->childAt(TableWriteTraits::kFragmentChannel)->isNullAt(i)) { + ++numFragments; + } else if (!minVector->isNullAt(i)) { + ++numStatsWithValues; + EXPECT_EQ(0, minVector->valueAt(i)); + EXPECT_EQ(99, maxVector->valueAt(i)); + } else if (!rowCountVector->isNullAt(i)) { + ++numSummary; + EXPECT_EQ(100, rowCountVector->valueAt(i)); + } else { + FAIL() << "Unexpected row " << i << ": " << result->toString(i); + } + } + + EXPECT_EQ(1, numFragments); + EXPECT_EQ(1, numStatsWithValues); + EXPECT_EQ(1, numSummary); +} + +// Same as above but with a partition key as grouping key for stats. +TEST_F(BasicTableWriterTest, columnStatsWithTableWriteMergeFinalPartitioned) { + auto outputDirectory = TempDirectoryPath::create(); + auto input = makeRowVector({ + makeFlatVector(100, folly::identity), + makeFlatVector(100, [](auto row) { return row % 3; }), + }); + + // Partition by c1, collect min/max on c0 grouped by c1. + auto plan = PlanBuilder() + .values({input}) + .tableWrite( + outputDirectory->getPath(), + {"c1"}, + dwio::common::FileFormat::DWRF, + {"min(c0)", "max(c0)"}) + .localGather() + .tableWriteMerge(core::AggregationNode::Step::kFinal) + .planNode(); + + auto result = AssertQueryBuilder(plan).maxDrivers(4).copyResults(pool()); + ASSERT_GT(result->size(), 0); + + // With 3 partitions, expect 3 stats rows (one per partition key value). + auto* rowCountVector = result->childAt(TableWriteTraits::kRowCountChannel) + ->as>(); + // Partition key is at kStatsChannel, stats start at kStatsChannel + 1. + auto* partKeyVector = result->childAt(TableWriteTraits::kStatsChannel) + ->as>(); + auto* minVector = result->childAt(TableWriteTraits::kStatsChannel + 1) + ->as>(); + auto* maxVector = result->childAt(TableWriteTraits::kStatsChannel + 2) + ->as>(); + + // Input: c0 = 0..99, c1 = c0 % 3. Per-partition min/max of c0: + // partition 0: min=0, max=99 + // partition 1: min=1, max=97 + // partition 2: min=2, max=98 + std::map> expectedStats = { + {0, {0, 99}}, {1, {1, 97}}, {2, {2, 98}}}; + + int32_t numFragments = 0; + int32_t numSummary = 0; + for (vector_size_t i = 0; i < result->size(); ++i) { + if (!result->childAt(TableWriteTraits::kFragmentChannel)->isNullAt(i)) { + ++numFragments; + } else if (!minVector->isNullAt(i)) { + ASSERT_FALSE(partKeyVector->isNullAt(i)); + auto partKey = partKeyVector->valueAt(i); + auto it = expectedStats.find(partKey); + ASSERT_NE(it, expectedStats.end()) + << "Unexpected or duplicate stats key: " << partKey; + EXPECT_EQ(it->second.first, minVector->valueAt(i)); + EXPECT_EQ(it->second.second, maxVector->valueAt(i)); + expectedStats.erase(it); + } else if (!rowCountVector->isNullAt(i)) { + ++numSummary; + EXPECT_EQ(100, rowCountVector->valueAt(i)); + } else { + FAIL() << "Unexpected row " << i << ": " << result->toString(i); + } + } + + EXPECT_EQ(3, numFragments); + EXPECT_TRUE(expectedStats.empty()) + << "Missing stats: " << expectedStats.size(); + EXPECT_EQ(1, numSummary); +} + +// Extracts file paths from a fragment JSON object. +std::vector extractFragmentFiles(const folly::dynamic& fragment) { + std::vector files; + auto targetPath = fragment["targetPath"].asString(); + for (const auto& fileInfo : fragment["fileWriteInfos"]) { + files.push_back( + fmt::format( + "{}/{}", targetPath, fileInfo["targetFileName"].asString())); + } + return files; +} + +// Builds a kFinal ColumnStatsSpec from a worker plan's kIntermediate spec. +// The input type is used to resolve aggregate input column references. +core::ColumnStatsSpec makeFinalStatsSpec( + const core::PlanNodePtr& workerPlan, + const RowType& inputType) { + auto mergeNode = + std::dynamic_pointer_cast(workerPlan); + VELOX_CHECK_NOT_NULL(mergeNode); + VELOX_CHECK(mergeNode->hasColumnStatsSpec()); + + auto spec = mergeNode->columnStatsSpec().value(); + spec.aggregationStep = core::AggregationNode::Step::kFinal; + for (size_t i = 0; i < spec.aggregates.size(); ++i) { + auto& aggregate = spec.aggregates[i]; + const auto& name = spec.aggregateNames[i]; + aggregate.call = std::make_shared( + aggregate.call->type(), + aggregate.call->name(), + std::make_shared( + inputType.findChild(name), name)); + } + return spec; +} + +// Simulates multi-node table write stats merging: +// 1. Runs worker plan twice on different inputs (simulating 2 workers): +// Values → RoundRobinPartition → TableWrite(kPartial) → LocalGather → +// TableWriteMerge(kIntermediate) +// 2. Collects all worker outputs and feeds them to a coordinator plan: +// Values(allWorkerOutputs) → TableWriteMerge(kFinal) +// This tests mixed stats/data batches, cross-worker merge with different +// taskIds, and dictionary-encoded input handling. +TEST_F(BasicTableWriterTest, columnStatsWithTwoStageMerge) { + // Run worker plan on 2 different inputs (simulating 2 workers). + // Worker 1: c0 = 0..49, Worker 2: c0 = 50..99. + std::vector workerInputs = { + makeRowVector({makeFlatVector(50, folly::identity)}), + makeRowVector( + {makeFlatVector(50, [](auto row) { return row + 50; })}), + }; + + core::PlanNodePtr workerPlan; + std::vector allWorkerOutputs; + std::vector> outputDirectories; + for (const auto& input : workerInputs) { + // Each worker writes to a separate directory. + auto outputDirectory = TempDirectoryPath::create(); + outputDirectories.push_back(outputDirectory); + workerPlan = PlanBuilder() + .values(split(input, 4)) + .localPartitionRoundRobin() + .tableWrite( + outputDirectory->getPath(), + dwio::common::FileFormat::DWRF, + {"min(c0)", "max(c0)"}) + .localGather() + .tableWriteMerge() + .planNode(); + + auto workerResult = + AssertQueryBuilder(workerPlan).maxDrivers(4).copyResults(pool()); + ASSERT_GT(workerResult->size(), 0); + allWorkerOutputs.push_back(workerResult); + } + + // Count fragment rows from worker outputs to know expected total. + int32_t expectedFragments = 0; + for (const auto& output : allWorkerOutputs) { + for (vector_size_t i = 0; i < output->size(); ++i) { + if (!output->childAt(TableWriteTraits::kFragmentChannel)->isNullAt(i)) { + ++expectedFragments; + } + } + } + ASSERT_GE(expectedFragments, 2); + + auto verifyCoordinatorResult = [&](const RowVectorPtr& result) { + ASSERT_GT(result->size(), 0); + + auto* rowCountVector = result->childAt(TableWriteTraits::kRowCountChannel) + ->as>(); + auto* minVector = result->childAt(TableWriteTraits::kStatsChannel) + ->as>(); + auto* maxVector = result->childAt(TableWriteTraits::kStatsChannel + 1) + ->as>(); + + auto* fragmentVector = result->childAt(TableWriteTraits::kFragmentChannel) + ->as>(); + + int32_t numStatsWithValues = 0; + int32_t numSummary = 0; + int64_t totalFragmentRows = 0; + std::set fragmentFiles; + for (vector_size_t i = 0; i < result->size(); ++i) { + if (!fragmentVector->isNullAt(i)) { + auto fragment = + folly::parseJson(std::string(fragmentVector->valueAt(i))); + totalFragmentRows += fragment["rowCount"].asInt(); + for (const auto& file : extractFragmentFiles(fragment)) { + EXPECT_TRUE(fragmentFiles.insert(file).second) + << "Duplicate fragment file: " << file; + } + } else if (!minVector->isNullAt(i)) { + ++numStatsWithValues; + EXPECT_EQ(0, minVector->valueAt(i)); + EXPECT_EQ(99, maxVector->valueAt(i)); + } else if (!rowCountVector->isNullAt(i)) { + ++numSummary; + EXPECT_EQ(100, rowCountVector->valueAt(i)); + } else { + FAIL() << "Unexpected row " << i << ": " << result->toString(i); + } + } + + // Verify fragment files match what's on disk. + std::set diskFiles; + for (const auto& dir : outputDirectories) { + for (const auto& entry : + std::filesystem::directory_iterator(dir->getPath())) { + if (entry.is_regular_file()) { + diskFiles.insert(entry.path().string()); + } + } + } + EXPECT_EQ(fragmentFiles, diskFiles); + EXPECT_EQ(100, totalFragmentRows); + EXPECT_EQ(1, numStatsWithValues); + EXPECT_EQ(1, numSummary); + }; + + auto spec = makeFinalStatsSpec(workerPlan, *allWorkerOutputs[0]->rowType()); + + auto runCoordinator = [&](const std::vector& input) { + auto plan = PlanBuilder().values(input).tableWriteMerge(spec).planNode(); + verifyCoordinatorResult( + AssertQueryBuilder(plan).maxDrivers(1).copyResults(pool())); + }; + + // Run coordinator on worker outputs in order. + runCoordinator(allWorkerOutputs); + + // Run coordinator on reversed worker outputs (worker 2 first). + { + auto reversed = allWorkerOutputs; + std::reverse(reversed.begin(), reversed.end()); + runCoordinator(reversed); + } + + // Run coordinator on split and interleaved worker outputs. Splits each + // worker output into 2 vectors, creating mixed batches with both stats + // and data rows. + { + auto parts1 = split(allWorkerOutputs[0], 2); + auto parts2 = split(allWorkerOutputs[1], 2); + runCoordinator({parts2[0], parts1[0], parts1[1], parts2[1]}); + } +} + // TODO: add partitioned table write update mode tests and more failure tests. TEST_P(AllTableWriterTest, tableWriterStats) { @@ -2307,17 +2651,17 @@ TEST_P(AllTableWriterTest, tableWriterStats) { fixedWrittenBytes); ASSERT_EQ( stats.operatorStats.at("TableWrite") - ->customStats.at(TableWriter::kNumWrittenFiles) + ->customStats.at(std::string(TableWriter::kNumWrittenFiles)) .sum, numWrittenFiles); ASSERT_GE( stats.operatorStats.at("TableWrite") - ->customStats.at(TableWriter::kWriteIOTime) + ->customStats.at(std::string(TableWriter::kWriteIOTime)) .sum, 0); ASSERT_GE( stats.operatorStats.at("TableWrite") - ->customStats.at(TableWriter::kRunningWallNanos) + ->customStats.at(std::string(TableWriter::kRunningWallNanos)) .sum, 0); } @@ -2408,9 +2752,9 @@ DEBUG_ONLY_TEST_P(UnpartitionedTableWriterTest, dataSinkAbortError) { std::atomic triggerAbortErrorOnce{true}; SCOPED_TESTVALUE_SET( - "facebook::velox::connector::hive::HiveDataSink::closeInternal", - std::function( - [&](const HiveDataSink* /*unused*/) { + "facebook::velox::connector::hive::FileDataSink::closeInternal", + std::function( + [&](const FileDataSink* /*unused*/) { if (!triggerAbortErrorOnce.exchange(false)) { return; } @@ -2466,14 +2810,16 @@ TEST_P(BucketSortOnlyTableWriterTest, sortWriterSpill) { // One spilled partition per each written files. const int numWrittenFiles = stats.customStats["numWrittenFiles"].sum; ASSERT_GE(stats.spilledPartitions, numWrittenFiles); - ASSERT_GT(stats.customStats[Operator::kSpillRuns].sum, 0); - ASSERT_GT(stats.customStats[Operator::kSpillFillTime].sum, 0); - ASSERT_GT(stats.customStats[Operator::kSpillSortTime].sum, 0); - ASSERT_GT(stats.customStats[Operator::kSpillExtractVectorTime].sum, 0); - ASSERT_GT(stats.customStats[Operator::kSpillSerializationTime].sum, 0); - ASSERT_GT(stats.customStats[Operator::kSpillFlushTime].sum, 0); - ASSERT_GT(stats.customStats[Operator::kSpillWrites].sum, 0); - ASSERT_GT(stats.customStats[Operator::kSpillWriteTime].sum, 0); + ASSERT_GT(stats.customStats[std::string(Operator::kSpillRuns)].sum, 0); + ASSERT_GT(stats.customStats[std::string(Operator::kSpillFillTime)].sum, 0); + ASSERT_GT(stats.customStats[std::string(Operator::kSpillSortTime)].sum, 0); + ASSERT_GT( + stats.customStats[std::string(Operator::kSpillExtractVectorTime)].sum, 0); + ASSERT_GT( + stats.customStats[std::string(Operator::kSpillSerializationTime)].sum, 0); + ASSERT_GT(stats.customStats[std::string(Operator::kSpillFlushTime)].sum, 0); + ASSERT_GT(stats.customStats[std::string(Operator::kSpillWrites)].sum, 0); + ASSERT_GT(stats.customStats[std::string(Operator::kSpillWriteTime)].sum, 0); } DEBUG_ONLY_TEST_P(BucketSortOnlyTableWriterTest, outputBatchRows) { @@ -2490,13 +2836,14 @@ DEBUG_ONLY_TEST_P(BucketSortOnlyTableWriterTest, outputBatchRows) { maxOutputBytes, expectedOutputCount); } - } testSettings[] = {// we have 4 buckets thus 4 writers. - {10000, "1000kB", 4}, - // when maxOutputRows = 1, 1000 rows triggers 1000 writes - {1, "1kB", 1000}, - // estimatedRowSize is ~62bytes, when maxOutputSize = 62 * - // 100, 1000 rows triggers ~10 writes - {10000, "6200B", 12}}; + } testSettings[] = { + // we have 4 buckets thus 4 writers. + {10000, "1000kB", 4}, + // when maxOutputRows = 1, 1000 rows triggers 1000 writes + {1, "1kB", 1000}, + // estimatedRowSize is ~62bytes, when maxOutputSize = 62 * + // 100, 1000 rows triggers ~10 writes + {10000, "6200B", 12}}; for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); @@ -2772,8 +3119,10 @@ DEBUG_ONLY_TEST_F(TableWriterArbitrationTest, reclaimFromTableWriter) { const int numPrevArbitrationFailures = arbitrator->stats().numFailures; const int numPrevNonReclaimableAttempts = arbitrator->stats().numNonReclaimableAttempts; - auto queryCtx = core::QueryCtx::create( - executor_.get(), QueryConfig{{}}, {}, nullptr, std::move(queryPool)); + auto queryCtx = core::QueryCtx::Builder() + .executor(executor_.get()) + .pool(std::move(queryPool)) + .build(); ASSERT_EQ(queryCtx->pool()->capacity(), kQueryMemoryCapacity); std::atomic_int numInputs{0}; @@ -2891,11 +3240,13 @@ DEBUG_ONLY_TEST_F(TableWriterArbitrationTest, reclaimFromSortTableWriter) { const int numPrevArbitrationFailures = arbitrator->stats().numFailures; const int numPrevNonReclaimableAttempts = arbitrator->stats().numNonReclaimableAttempts; - auto queryCtx = core::QueryCtx::create( - executor_.get(), QueryConfig{{}}, {}, nullptr, std::move(queryPool)); + auto queryCtx = core::QueryCtx::Builder() + .executor(executor_.get()) + .pool(std::move(queryPool)) + .build(); ASSERT_EQ(queryCtx->pool()->capacity(), kQueryMemoryCapacity); - const auto spillStats = common::globalSpillStats(); + const auto spillStats = globalSpillStats(); std::atomic numInputs{0}; SCOPED_TESTVALUE_SET( "facebook::velox::exec::Driver::runInternal::addInput", @@ -2963,7 +3314,7 @@ DEBUG_ONLY_TEST_F(TableWriterArbitrationTest, reclaimFromSortTableWriter) { numPrevNonReclaimableAttempts); waitForAllTasksToBeDeleted(3'000'000); - const auto updatedSpillStats = common::globalSpillStats(); + const auto updatedSpillStats = globalSpillStats(); if (writerSpillEnabled) { ASSERT_GT(updatedSpillStats.spilledBytes, spillStats.spilledBytes); ASSERT_GT( @@ -2994,10 +3345,11 @@ DEBUG_ONLY_TEST_F(TableWriterArbitrationTest, writerFlushThreshold) { const std::vector testParams{ {0, 0}, {0, 1UL << 30}, {64UL << 20, 1UL << 30}}; for (const auto& testParam : testParams) { - SCOPED_TRACE(fmt::format( - "bytesToReserve: {}, writerFlushThreshold: {}", - succinctBytes(testParam.bytesToReserve), - succinctBytes(testParam.writerFlushThreshold))); + SCOPED_TRACE( + fmt::format( + "bytesToReserve: {}, writerFlushThreshold: {}", + succinctBytes(testParam.bytesToReserve), + succinctBytes(testParam.writerFlushThreshold))); auto queryPool = memory::memoryManager()->addRootPool( "writerFlushThreshold", kQueryMemoryCapacity); @@ -3005,8 +3357,11 @@ DEBUG_ONLY_TEST_F(TableWriterArbitrationTest, writerFlushThreshold) { const int numPrevArbitrationFailures = arbitrator->stats().numFailures; const int numPrevNonReclaimableAttempts = arbitrator->stats().numNonReclaimableAttempts; - auto queryCtx = core::QueryCtx::create( - executor_.get(), QueryConfig{{}}, {}, nullptr, std::move(queryPool)); + auto queryCtx = core::QueryCtx::Builder() + .executor(executor_.get()) + .pool(std::move(queryPool)) + .build(); + ASSERT_EQ(queryCtx->pool()->capacity(), kQueryMemoryCapacity); memory::MemoryPool* compressionPool{nullptr}; @@ -3113,8 +3468,11 @@ DEBUG_ONLY_TEST_F( const int numPrevArbitrationFailures = arbitrator->stats().numFailures; const int numPrevNonReclaimableAttempts = arbitrator->stats().numNonReclaimableAttempts; - auto queryCtx = core::QueryCtx::create( - executor_.get(), QueryConfig{{}}, {}, nullptr, std::move(queryPool)); + auto queryCtx = core::QueryCtx::Builder() + .executor(executor_.get()) + .pool(std::move(queryPool)) + .build(); + ASSERT_EQ(queryCtx->pool()->capacity(), kQueryMemoryCapacity); std::atomic injectFakeAllocationOnce{true}; @@ -3196,8 +3554,10 @@ DEBUG_ONLY_TEST_F( const int numPrevNonReclaimableAttempts = arbitrator->stats().numNonReclaimableAttempts; const int numPrevReclaimedBytes = arbitrator->stats().reclaimedUsedBytes; - auto queryCtx = core::QueryCtx::create( - executor_.get(), QueryConfig{{}}, {}, nullptr, std::move(queryPool)); + auto queryCtx = core::QueryCtx::Builder() + .executor(executor_.get()) + .pool(std::move(queryPool)) + .build(); ASSERT_EQ(queryCtx->pool()->capacity(), kQueryMemoryCapacity); std::atomic writerNoMoreInput{false}; @@ -3296,8 +3656,10 @@ DEBUG_ONLY_TEST_F( const int numPrevArbitrationFailures = arbitrator->stats().numFailures; const int numPrevNonReclaimableAttempts = arbitrator->stats().numNonReclaimableAttempts; - auto queryCtx = core::QueryCtx::create( - executor_.get(), QueryConfig{{}}, {}, nullptr, std::move(queryPool)); + auto queryCtx = core::QueryCtx::Builder() + .executor(executor_.get()) + .pool(std::move(queryPool)) + .build(); ASSERT_EQ(queryCtx->pool()->capacity(), kQueryMemoryCapacity); std::atomic injectFakeAllocationOnce{true}; @@ -3340,7 +3702,7 @@ DEBUG_ONLY_TEST_F( {fmt::format("sum({})", TableWriteTraits::rowCountColumnName())}) .planNode(); - const auto spillStats = common::globalSpillStats(); + const auto spillStats = globalSpillStats(); const auto spillDirectory = TempDirectoryPath::create(); AssertQueryBuilder(duckDbQueryRunner_) .queryCtx(queryCtx) @@ -3366,7 +3728,7 @@ DEBUG_ONLY_TEST_F( ASSERT_EQ( arbitrator->stats().numNonReclaimableAttempts, numPrevNonReclaimableAttempts + 1); - const auto updatedSpillStats = common::globalSpillStats(); + const auto updatedSpillStats = globalSpillStats(); ASSERT_EQ(updatedSpillStats, spillStats); waitForAllTasksToBeDeleted(); } @@ -3388,8 +3750,11 @@ DEBUG_ONLY_TEST_F(TableWriterArbitrationTest, tableFileWriteError) { auto queryPool = memory::memoryManager()->addRootPool( "tableFileWriteError", kQueryMemoryCapacity); - auto queryCtx = core::QueryCtx::create( - executor_.get(), QueryConfig{{}}, {}, nullptr, std::move(queryPool)); + auto queryCtx = core::QueryCtx::Builder() + .executor(executor_.get()) + .pool(std::move(queryPool)) + .build(); + ASSERT_EQ(queryCtx->pool()->capacity(), kQueryMemoryCapacity); std::atomic_bool injectWriterErrorOnce{true}; @@ -3456,8 +3821,10 @@ DEBUG_ONLY_TEST_F(TableWriterArbitrationTest, tableWriteSpillUseMoreMemory) { auto queryPool = memory::memoryManager()->addRootPool( "tableWriteSpillUseMoreMemory", kQueryMemoryCapacity / 4); - auto queryCtx = core::QueryCtx::create( - executor_.get(), QueryConfig{{}}, {}, nullptr, std::move(queryPool)); + auto queryCtx = core::QueryCtx::Builder() + .executor(executor_.get()) + .pool(std::move(queryPool)) + .build(); ASSERT_EQ(queryCtx->pool()->capacity(), kQueryMemoryCapacity / 4); auto fakeLeafPool = queryCtx->pool()->addLeafChild( @@ -3543,14 +3910,18 @@ DEBUG_ONLY_TEST_F(TableWriterArbitrationTest, tableWriteReclaimOnClose) { auto queryPool = memory::memoryManager()->addRootPool( "tableWriteSpillUseMoreMemory", kQueryMemoryCapacity); - auto queryCtx = core::QueryCtx::create( - executor_.get(), QueryConfig{{}}, {}, nullptr, std::move(queryPool)); + auto queryCtx = core::QueryCtx::Builder() + .executor(executor_.get()) + .pool(std::move(queryPool)) + .build(); ASSERT_EQ(queryCtx->pool()->capacity(), kQueryMemoryCapacity); auto fakeQueryPool = memory::memoryManager()->addRootPool("fake", kQueryMemoryCapacity); - auto fakeQueryCtx = core::QueryCtx::create( - executor_.get(), QueryConfig{{}}, {}, nullptr, std::move(fakeQueryPool)); + auto fakeQueryCtx = core::QueryCtx::Builder() + .executor(executor_.get()) + .pool(std::move(fakeQueryPool)) + .build(); ASSERT_EQ(fakeQueryCtx->pool()->capacity(), kQueryMemoryCapacity); auto fakeLeafPool = fakeQueryCtx->pool()->addLeafChild( @@ -3636,8 +4007,10 @@ DEBUG_ONLY_TEST_F( .data; auto queryPool = memory::memoryManager()->addRootPool( "tableWriteSpillUseMoreMemory", kQueryMemoryCapacity); - auto queryCtx = core::QueryCtx::create( - executor_.get(), QueryConfig{{}}, {}, nullptr, std::move(queryPool)); + auto queryCtx = core::QueryCtx::Builder() + .executor(executor_.get()) + .pool(std::move(queryPool)) + .build(); ASSERT_EQ(queryCtx->pool()->capacity(), kQueryMemoryCapacity); std::atomic_bool writerCloseWaitFlag{true}; diff --git a/velox/exec/tests/TaskTest.cpp b/velox/exec/tests/TaskTest.cpp index e230dfbbd47..2bfcda8f94d 100644 --- a/velox/exec/tests/TaskTest.cpp +++ b/velox/exec/tests/TaskTest.cpp @@ -15,17 +15,17 @@ */ #include "velox/exec/Task.h" -#include "folly/experimental/EventCount.h" +#include "folly/synchronization/EventCount.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/tests/FaultyFileSystem.h" #include "velox/common/future/VeloxPromise.h" #include "velox/common/memory/MemoryArbitrator.h" #include "velox/common/memory/SharedArbitrator.h" #include "velox/common/memory/tests/SharedArbitratorTestUtil.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/common/testutil/TestValue.h" #include "velox/connectors/hive/HiveConnectorSplit.h" #include "velox/exec/Cursor.h" -#include "velox/exec/HashAggregation.h" #include "velox/exec/OutputBufferManager.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/Values.h" @@ -33,13 +33,15 @@ #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/exec/tests/utils/QueryAssertions.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/vector/fuzzer/VectorFuzzer.h" using namespace facebook::velox; using namespace facebook::velox::common::testutil; namespace facebook::velox::exec::test { + +using TempDirectoryPath = facebook::velox::common::testutil::TempDirectoryPath; + namespace { // A test join node whose build is skewed in terms of process time. The driver // id 0 processes slower than other drivers if paralelism greater than 1 @@ -256,7 +258,8 @@ class ExternalBlocker { public: folly::SemiFuture continueFuture() { if (isBlocked_) { - auto [promise, future] = makeVeloxContinuePromiseContract(); + auto [promise, future] = + makeVeloxContinuePromiseContract("ExternalBlocker::continueFuture"); continuePromise_ = std::move(promise); return std::move(future); } @@ -524,7 +527,8 @@ class TaskTest : public HiveConnectorTestBase { plan, 0, core::QueryCtx::create(), - Task::ExecutionMode::kSerial); + Task::ExecutionMode::kSerial, + exec::Consumer{}); for (const auto& [nodeId, paths] : filePaths) { for (const auto& path : paths) { @@ -573,7 +577,8 @@ TEST_F(TaskTest, toJson) { std::move(plan), 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); ASSERT_EQ( task->toString(), @@ -638,7 +643,8 @@ TEST_F(TaskTest, wrongPlanNodeForSplit) { std::move(plan), 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); // Add split for the source node. task->addSplit("0", exec::Split(folly::copy(connectorSplit))); @@ -694,7 +700,8 @@ TEST_F(TaskTest, wrongPlanNodeForSplit) { std::move(plan), 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); errorMessage = "Splits can be associated only with leaf plan nodes which require splits. Plan node ID 0 doesn't refer to such plan node."; VELOX_ASSERT_THROW( @@ -721,7 +728,8 @@ TEST_F(TaskTest, duplicatePlanNodeIds) { std::move(plan), 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel), + Task::ExecutionMode::kParallel, + exec::Consumer{}), "Plan node IDs must be unique. Found duplicate ID: 0.") } @@ -897,8 +905,8 @@ TEST_F(TaskTest, hasMixedExecutionGroupJoin) { task->start(1); - ASSERT_FALSE( - task->hasMixedExecutionGroupJoin(dynamic_cast( + ASSERT_FALSE(task->hasMixedExecutionGroupJoin( + dynamic_cast( nonMixedGroupedModeJoinNode.get()))); ASSERT_TRUE(task->hasMixedExecutionGroupJoin( dynamic_cast(mixedGroupedModeJoinNode.get()))); @@ -1228,7 +1236,8 @@ TEST_F(TaskTest, serialExecutionExternalBlockable) { plan, 0, core::QueryCtx::create(), - Task::ExecutionMode::kSerial); + Task::ExecutionMode::kSerial, + exec::Consumer{}); std::vector results; for (;;) { auto result = nonBlockingTask->next(&continueFuture); @@ -1254,7 +1263,8 @@ TEST_F(TaskTest, serialExecutionExternalBlockable) { plan, 0, core::QueryCtx::create(), - Task::ExecutionMode::kSerial); + Task::ExecutionMode::kSerial, + exec::Consumer{}); // Before we block, we expect `next` to get data normally. results.push_back(blockingTask->next(&continueFuture)); EXPECT_TRUE(results.back() != nullptr); @@ -1291,16 +1301,17 @@ TEST_F(TaskTest, supportSerialExecutionMode) { .project({"c0 % 10"}) .partitionedOutput({}, 1, std::vector{"p0"}) .planFragment(); - auto task = Task::create( - "single.execution.task.0", - plan, - 0, - core::QueryCtx::create(), - Task::ExecutionMode::kSerial); - // PartitionedOutput does not support serial execution mode, therefore the // task doesn't support it either. - ASSERT_FALSE(task->supportSerialExecutionMode()); + VELOX_ASSERT_THROW( + Task::create( + "single.execution.task.0", + plan, + 0, + core::QueryCtx::create(), + Task::ExecutionMode::kSerial, + exec::Consumer{}), + ""); } TEST_F(TaskTest, updateBroadCastOutputBuffers) { @@ -1316,7 +1327,8 @@ TEST_F(TaskTest, updateBroadCastOutputBuffers) { plan, 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); task->start(1, 1); @@ -1334,7 +1346,8 @@ TEST_F(TaskTest, updateBroadCastOutputBuffers) { plan, 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); task->start(1, 1); @@ -1606,8 +1619,7 @@ DEBUG_ONLY_TEST_F(TaskTest, inconsistentExecutionMode) { VELOX_ASSERT_THROW(task->next(), "Inconsistent task execution mode."); getOutputWaitFlag = true; getOutputWait.notify(); - while (cursor->hasNext()) { - cursor->moveNext(); + while (cursor->moveNext()) { } waitForTaskCompletion(task); } @@ -1622,8 +1634,13 @@ DEBUG_ONLY_TEST_F(TaskTest, inconsistentExecutionMode) { auto plan = PlanBuilder().values({data, data, data}).project({"c0"}).planFragment(); auto queryCtx = core::QueryCtx::create(driverExecutor_.get()); - auto task = - Task::create("task.0", plan, 0, queryCtx, Task::ExecutionMode::kSerial); + auto task = Task::create( + "task.0", + plan, + 0, + queryCtx, + Task::ExecutionMode::kSerial, + exec::Consumer{}); task->next(); VELOX_ASSERT_THROW(task->start(4, 1), "Inconsistent task execution mode."); @@ -1868,7 +1885,7 @@ DEBUG_ONLY_TEST_F(TaskTest, driverCounters) { pauseDriverRunInternal.lock(); SCOPED_TESTVALUE_SET( "facebook::velox::exec::Driver::runInternal", - std::function([&](Driver*) { + std::function([&](Driver* /*unused*/) { pauseDriverRunInternal.lock_shared(); pauseDriverRunInternal.unlock_shared(); })); @@ -1957,17 +1974,24 @@ TEST_F(TaskTest, driverCreationMemoryAllocationCheck) { .planFragment(); for (bool singleThreadExecution : {false, true}) { SCOPED_TRACE(fmt::format("singleThreadExecution: ", singleThreadExecution)); - auto badTask = Task::create( - "driverCreationMemoryAllocationCheck", - plan, - 0, - core::QueryCtx::create( - singleThreadExecution ? nullptr : driverExecutor_.get()), - singleThreadExecution ? Task::ExecutionMode::kSerial - : Task::ExecutionMode::kParallel); if (singleThreadExecution) { - VELOX_ASSERT_THROW(badTask->next(), "Unexpected memory pool allocations"); + VELOX_ASSERT_THROW( + Task::create( + "driverCreationMemoryAllocationCheck", + plan, + 0, + core::QueryCtx::create(nullptr), + Task::ExecutionMode::kSerial, + exec::Consumer{}), + "Unexpected memory pool allocations"); } else { + auto badTask = Task::create( + "driverCreationMemoryAllocationCheck", + plan, + 0, + core::QueryCtx::create(driverExecutor_.get()), + Task::ExecutionMode::kParallel, + exec::Consumer{}); VELOX_ASSERT_THROW( badTask->start(1), "Unexpected memory pool allocations"); } @@ -1994,36 +2018,34 @@ TEST_F(TaskTest, spillDirectoryCallback) { {{core::QueryConfig::kSpillEnabled, "true"}, {core::QueryConfig::kAggregationSpillEnabled, "true"}}); params.maxDrivers = 1; - - auto cursor = TaskCursor::create(params); - - std::shared_ptr task = cursor->task(); - auto tmpRootDir = exec::test::TempDirectoryPath::create(); - auto tmpParentSpillDir = fmt::format( + auto spillRootDir = TempDirectoryPath::create(); + auto spillParentDir = fmt::format( "{}{}/parent_spill/", tests::utils::FaultyFileSystem::scheme(), - tmpRootDir->getPath()); - auto tmpSpillDir = fmt::format( + spillRootDir->getPath()); + auto spillDir = fmt::format( "{}{}/parent_spill/spill/", tests::utils::FaultyFileSystem::scheme(), - tmpRootDir->getPath()); - - EXPECT_FALSE(task->hasCreateSpillDirectoryCb()); + spillRootDir->getPath()); - task->setCreateSpillDirectoryCb([tmpParentSpillDir, tmpSpillDir]() { - auto filesystem = filesystems::getFileSystem(tmpParentSpillDir, nullptr); + params.spillDirectory = spillDir; + params.spillDirectoryCallback = [spillParentDir, spillDir]() { + auto filesystem = filesystems::getFileSystem(spillParentDir, nullptr); filesystems::DirectoryOptions options; options.values.emplace( filesystems::DirectoryOptions::kMakeDirectoryConfig.toString(), "dummy.config=123"); - filesystem->mkdir(tmpParentSpillDir, options); - filesystem->mkdir(tmpSpillDir); - return tmpSpillDir; - }); + filesystem->mkdir(spillParentDir, options); + filesystem->mkdir(spillDir); + return spillDir; + }; + auto cursor = TaskCursor::create(params); + std::shared_ptr task = cursor->task(); EXPECT_TRUE(task->hasCreateSpillDirectoryCb()); + auto fs = std::dynamic_pointer_cast( - filesystems::getFileSystem(tmpParentSpillDir, nullptr)); + filesystems::getFileSystem(spillParentDir, nullptr)); fs->setFileSystemInjectionError( std::make_exception_ptr(std::runtime_error("test exception")), @@ -2039,7 +2061,7 @@ TEST_F(TaskTest, spillDirectoryCallback) { auto mkdirOp = static_cast(op); if (mkdirOp->path == - fmt::format("{}/parent_spill/", tmpRootDir->getPath())) { + fmt::format("{}/parent_spill/", spillRootDir->getPath())) { parentDirectoryCreated = true; auto it = mkdirOp->options.values.find( filesystems::DirectoryOptions::kMakeDirectoryConfig.toString()); @@ -2047,7 +2069,7 @@ TEST_F(TaskTest, spillDirectoryCallback) { EXPECT_EQ(it->second, "dummy.config=123"); } if (mkdirOp->path == - fmt::format("{}/parent_spill/spill/", tmpRootDir->getPath())) { + fmt::format("{}/parent_spill/spill/", spillRootDir->getPath())) { spillDirectoryCreated = true; } return; @@ -2092,13 +2114,13 @@ TEST_F(TaskTest, spillDirectoryLifecycleManagement) { {{core::QueryConfig::kSpillEnabled, "true"}, {core::QueryConfig::kAggregationSpillEnabled, "true"}}); params.maxDrivers = 1; + const auto rootTempDir = TempDirectoryPath::create(); + const auto tmpDirectoryPath = + rootTempDir->getPath() + "/spillDirectoryLifecycleManagement"; + params.spillDirectory = tmpDirectoryPath; auto cursor = TaskCursor::create(params); std::shared_ptr task = cursor->task(); - auto rootTempDir = exec::test::TempDirectoryPath::create(); - auto tmpDirectoryPath = - rootTempDir->getPath() + "/spillDirectoryLifecycleManagement"; - task->setSpillDirectory(tmpDirectoryPath, false); TestScopedSpillInjection scopedSpillInjection(100); while (cursor->moveNext()) { @@ -2152,9 +2174,8 @@ TEST_F(TaskTest, spillDirNotCreated) { auto cursor = TaskCursor::create(params); auto* task = cursor->task().get(); - auto rootTempDir = exec::test::TempDirectoryPath::create(); + auto rootTempDir = TempDirectoryPath::create(); auto tmpDirectoryPath = rootTempDir->getPath() + "/spillDirNotCreated"; - task->setSpillDirectory(tmpDirectoryPath, false); while (cursor->moveNext()) { } @@ -2191,7 +2212,7 @@ DEBUG_ONLY_TEST_F(TaskTest, resumeAfterTaskFinish) { SCOPED_TESTVALUE_SET( "facebook::velox::exec::Values::getOutput", std::function( - ([&](const velox::exec::Values* values) { + ([&](const velox::exec::Values* /*unused*/) { valuesWait.await([&]() { return !valuesWaitFlag.load(); }); }))); @@ -2200,7 +2221,8 @@ DEBUG_ONLY_TEST_F(TaskTest, resumeAfterTaskFinish) { std::move(plan), 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); task->start(4, 1); // Request pause and then unblock operators to proceed. @@ -2231,7 +2253,12 @@ DEBUG_ONLY_TEST_F(TaskTest, serialLongRunningOperatorInTaskReclaimerAbort) { auto queryCtx = core::QueryCtx::create(driverExecutor_.get()); auto blockingTask = Task::create( - "blocking.task.0", plan, 0, queryCtx, Task::ExecutionMode::kSerial); + "blocking.task.0", + plan, + 0, + queryCtx, + Task::ExecutionMode::kSerial, + exec::Consumer{}); // Before we block, we expect `next` to get data normally. EXPECT_NE(nullptr, blockingTask->next()); @@ -2306,7 +2333,12 @@ DEBUG_ONLY_TEST_F(TaskTest, longRunningOperatorInTaskReclaimerAbort) { auto queryCtx = core::QueryCtx::create(driverExecutor_.get()); auto blockingTask = Task::create( - "blocking.task.0", plan, 0, queryCtx, Task::ExecutionMode::kParallel); + "blocking.task.0", + plan, + 0, + queryCtx, + Task::ExecutionMode::kParallel, + exec::Consumer{}); blockingTask->start(4, 1); const std::string abortErrorMessage("Synthetic Exception"); @@ -2372,7 +2404,8 @@ DEBUG_ONLY_TEST_F(TaskTest, taskReclaimStats) { std::move(plan), 0, std::move(queryCtx), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); task->start(4, 1); const int numReclaims{10}; @@ -2410,6 +2443,7 @@ DEBUG_ONLY_TEST_F(TaskTest, taskPauseTime) { opts.vectorSize = 32; VectorFuzzer fuzzer(opts, pool_.get()); std::vector valueInputs; + valueInputs.reserve(4); for (int32_t i = 0; i < 4; ++i) { valueInputs.push_back(fuzzer.fuzzRow(rowType)); } @@ -2424,13 +2458,15 @@ DEBUG_ONLY_TEST_F(TaskTest, taskPauseTime) { folly::EventCount taskPauseWait; SCOPED_TESTVALUE_SET( "facebook::velox::exec::Values::getOutput", - std::function([&](const exec::Values* values) { - if (taskPauseWaitFlag.exchange(false)) { - taskPauseWait.notifyAll(); - } - // Inject some delay for task pause stats verification. - std::this_thread::sleep_for(std::chrono::milliseconds(10)); // NOLINT - })); + std::function( + [&](const exec::Values* /*unused*/) { + if (taskPauseWaitFlag.exchange(false)) { + taskPauseWait.notifyAll(); + } + // Inject some delay for task pause stats verification. + std::this_thread::sleep_for( + std::chrono::milliseconds(10)); // NOLINT + })); auto queryPool = memory::memoryManager()->addRootPool( "taskPauseTime", 1UL << 30, exec::MemoryReclaimer::create()); @@ -2446,7 +2482,8 @@ DEBUG_ONLY_TEST_F(TaskTest, taskPauseTime) { std::move(plan), 0, std::move(queryCtx), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); task->start(4, 1); // Wait for the task driver starts to run. @@ -2468,11 +2505,11 @@ DEBUG_ONLY_TEST_F(TaskTest, taskPauseTime) { ASSERT_EQ(taskStats.pipelineStats[0].driverStats.size(), 1); const auto& driverStats = taskStats.pipelineStats[0].driverStats[0]; const auto& totalPauseTime = - driverStats.runtimeStats.at(DriverStats::kTotalPauseTime); + driverStats.runtimeStats.at(std::string(DriverStats::kTotalPauseTime)); ASSERT_EQ(totalPauseTime.count, 1); ASSERT_GE(totalPauseTime.sum, 0); - const auto& totalOffThreadTime = - driverStats.runtimeStats.at(DriverStats::kTotalOffThreadTime); + const auto& totalOffThreadTime = driverStats.runtimeStats.at( + std::string(DriverStats::kTotalOffThreadTime)); ASSERT_EQ(totalOffThreadTime.count, 1); ASSERT_GE(totalOffThreadTime.sum, 0); @@ -2495,7 +2532,8 @@ TEST_F(TaskTest, updateStatsWhileCloseOffThreadDriver) { std::move(plan), 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); task->start(4, 1); std::this_thread::sleep_for(std::chrono::milliseconds(100)); task->testingVisitDrivers( @@ -2508,13 +2546,119 @@ TEST_F(TaskTest, updateStatsWhileCloseOffThreadDriver) { ASSERT_EQ(taskStats.pipelineStats.size(), 1); ASSERT_EQ(taskStats.pipelineStats[0].driverStats.size(), 4); const auto& driverStats = taskStats.pipelineStats[0].driverStats[0]; - ASSERT_EQ(driverStats.runtimeStats.count(DriverStats::kTotalPauseTime), 0); - const auto& totalOffThreadTime = - driverStats.runtimeStats.at(DriverStats::kTotalOffThreadTime); + ASSERT_EQ( + driverStats.runtimeStats.count(std::string(DriverStats::kTotalPauseTime)), + 0); + const auto& totalOffThreadTime = driverStats.runtimeStats.at( + std::string(DriverStats::kTotalOffThreadTime)); ASSERT_EQ(totalOffThreadTime.count, 1); ASSERT_GE(totalOffThreadTime.sum, 0); } +// Verifies that driver-level lifecycle timing metrics +// (driverQueuedWallNanos, driverOnThreadWallNanos, driverBlockedWallNanos) +// are reported correctly for each pipeline. +TEST_F(TaskTest, driverLifecycleTimingStats) { + auto data = makeRowVector({ + makeFlatVector(1'000, [](auto row) { return row % 10; }), + makeFlatVector(1'000, [](auto row) { return row; }), + }); + + auto buildData = makeRowVector( + {"u0", "u1"}, + { + makeFlatVector(100, [](auto row) { return row % 10; }), + makeFlatVector(100, [](auto row) { return row; }), + }); + + auto planNodeIdGenerator = std::make_shared(); + + // Build a hash join plan to get multiple pipelines: + // Pipeline 0 (probe): Values -> HashProbe -> sink + // Pipeline 1 (build): Values -> HashBuild + core::PlanNodeId probeScanId; + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({data}) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"c0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({buildData}).planNode(), + "", + {"c0", "c1"}) + .planNode(); + + auto result = AssertQueryBuilder(plan).copyResults(pool_.get()); + ASSERT_GT(result->size(), 0); + + // Get task stats from the most recently completed task. + auto tasks = Task::getRunningTasks(); + // Tasks may already be cleaned up. Use a cursor-based approach instead. + + // Re-run using CursorParameters to access the task directly. + CursorParameters params; + params.planNode = plan; + params.queryCtx = core::QueryCtx::create(driverExecutor_.get()); + params.maxDrivers = 2; + + auto cursor = TaskCursor::create(params); + while (cursor->moveNext()) { + } + auto task = cursor->task(); + ASSERT_TRUE(waitForTaskCompletion(task.get())); + + auto taskStats = task->taskStats(); + // We should have at least 2 pipelines (probe and build). + ASSERT_GE(taskStats.pipelineStats.size(), 2); + + // Helper to find a driver stat by substring match across all pipelines. + auto findDriverStat = + [&taskStats]( + const std::string& statSubstr) -> std::optional { + for (const auto& pipeline : taskStats.pipelineStats) { + for (const auto& ds : pipeline.driverStats) { + for (const auto& [name, metric] : ds.runtimeStats) { + if (name.find(statSubstr) != std::string::npos) { + return metric; + } + } + } + } + return std::nullopt; + }; + + // Verify driver lifecycle metrics exist and have sensible values. + auto queued = findDriverStat("driverQueuedWallNanos"); + auto onThread = findDriverStat("driverOnThreadWallNanos"); + auto blocked = findDriverStat("driverBlockedWallNanos"); + + ASSERT_TRUE(queued.has_value()) << "driverQueuedWallNanos not found"; + ASSERT_TRUE(onThread.has_value()) << "driverOnThreadWallNanos not found"; + ASSERT_TRUE(blocked.has_value()) << "driverBlockedWallNanos not found"; + + // Each metric should have count >= 1 (at least one driver reported). + ASSERT_GE(queued->count, 1); + ASSERT_GE(onThread->count, 1); + ASSERT_GE(blocked->count, 1); + + // On-thread time must be positive (drivers did real work). + ASSERT_GT(onThread->sum, 0); + ASSERT_GT(onThread->max, 0); + + // Queued and blocked times must be non-negative. + ASSERT_GE(queued->sum, 0); + ASSERT_GE(blocked->sum, 0); + + // Verify the stat names contain the pipeline prefix and source operator. + // Look for probe pipeline stats (source is Values with probeScanId). + auto probePrefix = fmt::format("P0-Values.{}", probeScanId); + auto probeOnThread = findDriverStat(probePrefix + ".driverOnThreadWallNanos"); + ASSERT_TRUE(probeOnThread.has_value()) + << "Probe pipeline stat not found with prefix: " << probePrefix; + ASSERT_GT(probeOnThread->max, 0); +} + DEBUG_ONLY_TEST_F(TaskTest, driverEnqueAfterFailedAndPausedTask) { const auto data = makeRowVector({ makeFlatVector(50, [](auto row) { return row; }), @@ -2530,8 +2674,8 @@ DEBUG_ONLY_TEST_F(TaskTest, driverEnqueAfterFailedAndPausedTask) { folly::EventCount driverWait; SCOPED_TESTVALUE_SET( "facebook::velox::exec::Task::enter", - std::function( - ([&](const velox::exec::ThreadState* /*unused*/) { + std::function( + ([&](const velox::exec::Task* /*unused*/) { driverWait.await([&]() { return !driverWaitFlag.load(); }); }))); @@ -2540,7 +2684,8 @@ DEBUG_ONLY_TEST_F(TaskTest, driverEnqueAfterFailedAndPausedTask) { std::move(plan), 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); task->start(4, 1); // Request pause. @@ -2576,17 +2721,18 @@ DEBUG_ONLY_TEST_F(TaskTest, taskReclaimFailure) { [&](SpillerBase* /*unused*/) { VELOX_FAIL(spillTableError); })); TestScopedSpillInjection injection(100); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); VELOX_ASSERT_THROW( AssertQueryBuilder(duckDbQueryRunner_) .spillDirectory(spillDirectory->getPath()) .config(core::QueryConfig::kSpillEnabled, true) .config(core::QueryConfig::kAggregationSpillEnabled, true) .maxDrivers(1) - .plan(PlanBuilder() - .values(inputVectors) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .planNode()) + .plan( + PlanBuilder() + .values(inputVectors) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .planNode()) .assertResults( "SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"), spillTableError); @@ -2611,10 +2757,11 @@ DEBUG_ONLY_TEST_F(TaskTest, taskDeletionPromise) { std::thread queryThread([&]() { AssertQueryBuilder(duckDbQueryRunner_) .maxDrivers(1) - .plan(PlanBuilder() - .values(inputVectors) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .planNode()) + .plan( + PlanBuilder() + .values(inputVectors) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .planNode()) .assertResults("SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"); }); @@ -2650,8 +2797,8 @@ DEBUG_ONLY_TEST_F(TaskTest, taskCancellation) { folly::EventCount driverWait; SCOPED_TESTVALUE_SET( "facebook::velox::exec::Task::enter", - std::function( - ([&](const velox::exec::ThreadState* /*unused*/) { + std::function( + ([&](const velox::exec::Task* /*unused*/) { driverWait.await([&]() { return !driverWaitFlag.load(); }); }))); @@ -2660,7 +2807,8 @@ DEBUG_ONLY_TEST_F(TaskTest, taskCancellation) { std::move(plan), 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); task->start(4, 1); auto cancellationToken = task->getCancellationToken(); ASSERT_FALSE(cancellationToken.isCancellationRequested()); @@ -2724,17 +2872,42 @@ TEST_F(TaskTest, invalidPlanNodeForBarrier) { .project({"c0 + 5"}) .planFragment(); ASSERT_TRUE(plan.firstNodeNotSupportingBarrier()); + { + const auto task = Task::create( + "invalidPlanNodeForBarrier", + plan, + 0, + core::QueryCtx::create(), + Task::ExecutionMode::kSerial); + ASSERT_TRUE(!task->underBarrier()); + VELOX_ASSERT_THROW( + task->requestBarrier(), + "Name of the first node that doesn't support barriered execution:"); + while (auto next = task->next()) { + } + ASSERT_TRUE(task->isFinished()); + } - const auto task = Task::create( - "invalidPlanNodeForBarrier", - plan, - 0, - core::QueryCtx::create(), - Task::ExecutionMode::kSerial); - ASSERT_TRUE(!task->underBarrier()); - VELOX_ASSERT_THROW( - task->requestBarrier(), - "Name of the first node that doesn't support barriered execution:"); + { + const auto task = Task::create( + "invalidPlanNodeForBarrier", + plan, + 0, + core::QueryCtx::create(executor_.get()), + Task::ExecutionMode::kParallel, + [](const RowVectorPtr& vector, + bool drained, + velox::ContinueFuture* future) { + return BlockingReason::kNotBlocked; + }); + task->start(2); + ASSERT_TRUE(!task->underBarrier()); + VELOX_ASSERT_THROW( + task->requestBarrier(), + "Name of the first node that doesn't support barriered execution:"); + task->taskCompletionFuture().wait(); + ASSERT_TRUE(task->isFinished()); + } } TEST_F(TaskTest, barrierAfterNoMoreSplits) { @@ -2752,50 +2925,54 @@ TEST_F(TaskTest, barrierAfterNoMoreSplits) { .filter("c0 < 100") .project({"c0 + 5"}) .planFragment(); + { + const auto task = Task::create( + "barrierAfterNoMoreSplits", + plan, + 0, + core::QueryCtx::create(), + Task::ExecutionMode::kSerial); + ASSERT_TRUE(!task->underBarrier()); - const auto task = Task::create( - "barrierAfterNoMoreSplits", - plan, - 0, - core::QueryCtx::create(), - Task::ExecutionMode::kSerial); - ASSERT_TRUE(!task->underBarrier()); - - task->addSplit( - scanId, exec::Split(makeHiveConnectorSplit(filePath->getPath()))); - task->noMoreSplits(scanId); - ASSERT_TRUE(!task->underBarrier()); - - VELOX_ASSERT_THROW( - task->requestBarrier(), - "Can't start barrier on task which has already received no more splits"); -} + task->addSplit( + scanId, exec::Split(makeHiveConnectorSplit(filePath->getPath()))); + task->noMoreSplits(scanId); + ASSERT_TRUE(!task->underBarrier()); -TEST_F(TaskTest, invalidTaskModeForBarrier) { - auto data = makeRowVector({ - makeFlatVector(1'000, [](auto row) { return row; }), - }); + VELOX_ASSERT_THROW( + task->requestBarrier(), + "Can't start barrier on task which has already received no more splits"); + task->requestAbort().wait(); + ASSERT_TRUE(!task->isRunning()); + } - // Filter + Project. - core::PlanNodeId scanId; - const auto plan = PlanBuilder() - .tableScan(asRowType(data->type())) - .capturePlanNodeId(scanId) - .filter("c0 < 100") - .project({"c0 + 5"}) - .planFragment(); - ASSERT_TRUE(plan.firstNodeNotSupportingBarrier() == nullptr); + { + const auto task = Task::create( + "invalidPlanNodeForBarrier", + plan, + 0, + core::QueryCtx::create(executor_.get()), + Task::ExecutionMode::kParallel, + [](const RowVectorPtr& vector, + bool drained, + velox::ContinueFuture* future) { + return BlockingReason::kNotBlocked; + }); + + ASSERT_TRUE(!task->underBarrier()); + task->start(2); + task->addSplit( + scanId, exec::Split(makeHiveConnectorSplit(filePath->getPath()))); + task->noMoreSplits(scanId); + ASSERT_TRUE(!task->underBarrier()); - const auto task = Task::create( - "invalidTaskModeForBarrier", - plan, - 0, - core::QueryCtx::create(), - Task::ExecutionMode::kParallel); - ASSERT_TRUE(!task->underBarrier()); - VELOX_ASSERT_THROW( - task->requestBarrier(), - "(Parallel vs. Serial) Task doesn't support barriered execution."); + VELOX_ASSERT_THROW( + task->requestBarrier(), + "Can't start barrier on task which has already received no more splits"); + task->requestAbort().wait(); + ASSERT_TRUE(!task->isRunning()); + } + waitForAllTasksToBeDeleted(); } TEST_F(TaskTest, addSplitAfterBarrier) { @@ -2814,30 +2991,55 @@ TEST_F(TaskTest, addSplitAfterBarrier) { .project({"c0 + 5"}) .planFragment(); ASSERT_TRUE(plan.firstNodeNotSupportingBarrier() == nullptr); + { + const auto task = Task::create( + "barrierAfterNoMoreSplits", + plan, + 0, + core::QueryCtx::create(), + Task::ExecutionMode::kSerial); + ASSERT_TRUE(!task->underBarrier()); - const auto task = Task::create( - "barrierAfterNoMoreSplits", - plan, - 0, - core::QueryCtx::create(), - Task::ExecutionMode::kSerial); - ASSERT_TRUE(!task->underBarrier()); + task->addSplit( + scanId, exec::Split(makeHiveConnectorSplit(filePath->getPath()))); + auto future = task->requestBarrier(); + ASSERT_TRUE(task->underBarrier()); + VELOX_ASSERT_THROW( + task->addSplit( + scanId, exec::Split(makeHiveConnectorSplit(filePath->getPath()))), + "Can't add new split under barrier processing"); + std::this_thread::sleep_for(std::chrono::seconds(1)); // NOLINT + ASSERT_TRUE(task->isRunning()); + ASSERT_FALSE(future.isReady()); + task->requestAbort().wait(); + ASSERT_TRUE(!task->isRunning()); + ASSERT_TRUE(future.isReady()); + future.wait(); + } - task->addSplit( - scanId, exec::Split(makeHiveConnectorSplit(filePath->getPath()))); - auto future = task->requestBarrier(); - ASSERT_TRUE(task->underBarrier()); - VELOX_ASSERT_THROW( - task->addSplit( - scanId, exec::Split(makeHiveConnectorSplit(filePath->getPath()))), - "Can't add new split under barrier processing"); - std::this_thread::sleep_for(std::chrono::seconds(1)); // NOLINT - ASSERT_TRUE(task->isRunning()); - ASSERT_FALSE(future.isReady()); - task->requestAbort().wait(); - ASSERT_TRUE(!task->isRunning()); - ASSERT_TRUE(future.isReady()); - future.wait(); + { + const auto task = Task::create( + "barrierAfterNoMoreSplits", + plan, + 0, + core::QueryCtx::create(executor_.get()), + Task::ExecutionMode::kParallel, + [](const RowVectorPtr& vector, + bool drained, + velox::ContinueFuture* future) { + return BlockingReason::kNotBlocked; + }); + task->start(2); + ASSERT_TRUE(!task->underBarrier()); + task->addSplit( + scanId, exec::Split(makeHiveConnectorSplit(filePath->getPath()))); + auto future = task->requestBarrier(); + future.wait(); + ASSERT_FALSE(task->underBarrier()); + task->requestAbort().wait(); + ASSERT_TRUE(!task->isRunning()); + } + waitForAllTasksToBeDeleted(); } TEST_F(TaskTest, testTerminateDuringBarrier) { @@ -2854,36 +3056,119 @@ TEST_F(TaskTest, testTerminateDuringBarrier) { .capturePlanNodeId(scanId) .project({"c0"}) .planFragment(); + { + auto queryCtx = core::QueryCtx::create(); + queryCtx->testingOverrideConfigUnsafe( + {{core::QueryConfig::kMaxOutputBatchRows, "1"}}); + const auto task = Task::create( + "testTerminateDuringBarrier", + plan, + 0, + std::move(queryCtx), + Task::ExecutionMode::kSerial); + ASSERT_TRUE(!task->underBarrier()); + task->addSplit( + scanId, exec::Split(makeHiveConnectorSplit(filePath->getPath()))); + auto barrierFuture = task->requestBarrier(); + ASSERT_TRUE(task->underBarrier()); + for (int i = 0; i < numRows / 2; ++i) { + ContinueFuture future{ContinueFuture::makeEmpty()}; + auto result = task->next(&future); + ASSERT_TRUE(result != nullptr); + ASSERT_FALSE(future.valid()); + } + task->requestAbort(); + ASSERT_FALSE(task->isRunning()); + VELOX_ASSERT_THROW( + task->next(nullptr), "Task has already finished processing"); + ASSERT_TRUE(barrierFuture.isReady()); + barrierFuture.wait(); + const auto taskStats = task->taskStats(); + ASSERT_EQ(taskStats.numFinishedSplits, 0); + ASSERT_EQ(taskStats.numBarriers, 1); + } - auto queryCtx = core::QueryCtx::create(); - queryCtx->testingOverrideConfigUnsafe( - {{core::QueryConfig::kMaxOutputBatchRows, "1"}}); - const auto task = Task::create( - "testTerminateDuringBarrier", - plan, - 0, - std::move(queryCtx), - Task::ExecutionMode::kSerial); - ASSERT_TRUE(!task->underBarrier()); - task->addSplit( - scanId, exec::Split(makeHiveConnectorSplit(filePath->getPath()))); - auto barrierFuture = task->requestBarrier(); - ASSERT_TRUE(task->underBarrier()); - for (int i = 0; i < numRows / 2; ++i) { - ContinueFuture future{ContinueFuture::makeEmpty()}; - auto result = task->next(&future); - ASSERT_TRUE(result != nullptr); - ASSERT_FALSE(future.valid()); + { + const auto task = Task::create( + "invalidPlanNodeForBarrier", + plan, + 0, + core::QueryCtx::create(executor_.get()), + Task::ExecutionMode::kParallel, + [](const RowVectorPtr& vector, + bool drained, + velox::ContinueFuture* future) { + return BlockingReason::kNotBlocked; + }); + task->start(2); + ASSERT_TRUE(!task->underBarrier()); + task->addSplit( + scanId, exec::Split(makeHiveConnectorSplit(filePath->getPath()))); + auto barrierFuture = task->requestBarrier(); + task->requestAbort().wait(); + ASSERT_FALSE(task->isRunning()); + ASSERT_TRUE(barrierFuture.isReady()); + ASSERT_EQ(task->taskStats().numBarriers, 1); } - task->requestAbort(); - ASSERT_FALSE(task->isRunning()); - VELOX_ASSERT_THROW( - task->next(nullptr), "Task has already finished processing"); - ASSERT_TRUE(barrierFuture.isReady()); - barrierFuture.wait(); - const auto taskStats = task->taskStats(); - ASSERT_EQ(taskStats.numFinishedSplits, 0); - ASSERT_EQ(taskStats.numBarriers, 1); +} + +TEST_F(TaskTest, testBarrierClearedOnTerminate) { + auto data = makeRowVector({ + makeFlatVector(100, [](auto row) { return row; }), + }); + auto filePath = TempFilePath::create(); + writeToFile(filePath->getPath(), {data}); + + core::PlanNodeId scanId; + auto plan = PlanBuilder() + .tableScan(asRowType(data->type())) + .capturePlanNodeId(scanId) + .project({"c0"}) + .planFragment(); + + // Verify that barrierRequested_ is cleared when task is aborted while under + // barrier in serial execution mode. + { + const auto task = Task::create( + "barrierClearedOnTerminate.serial", + plan, + 0, + core::QueryCtx::create(), + Task::ExecutionMode::kSerial); + task->addSplit( + scanId, exec::Split(makeHiveConnectorSplit(filePath->getPath()))); + auto barrierFuture = task->requestBarrier(); + ASSERT_TRUE(task->underBarrier()); + task->requestAbort().wait(); + ASSERT_FALSE(task->isRunning()); + ASSERT_FALSE(task->underBarrier()); + ASSERT_TRUE(barrierFuture.isReady()); + } + + // Verify that barrierRequested_ is cleared when task is aborted while under + // barrier in parallel execution mode. + { + const auto task = Task::create( + "barrierClearedOnTerminate.parallel", + plan, + 0, + core::QueryCtx::create(executor_.get()), + Task::ExecutionMode::kParallel, + [](const RowVectorPtr& /*vector*/, + bool /*drained*/, + velox::ContinueFuture* /*future*/) { + return BlockingReason::kNotBlocked; + }); + task->start(2); + task->addSplit( + scanId, exec::Split(makeHiveConnectorSplit(filePath->getPath()))); + auto barrierFuture = task->requestBarrier(); + task->requestAbort().wait(); + ASSERT_FALSE(task->isRunning()); + ASSERT_FALSE(task->underBarrier()); + ASSERT_TRUE(barrierFuture.isReady()); + } + waitForAllTasksToBeDeleted(); } namespace { @@ -2964,7 +3249,7 @@ class TestBarrierOperator : public exec::Operator { return nullptr; } auto output = inputs_.front(); - inputs_.erase(inputs_.begin()); + inputs_.erase(inputs_.cbegin()); return output; } @@ -3378,4 +3663,74 @@ DEBUG_ONLY_TEST_F(TaskTest, operatorShouldYieldMethod) { } } +// Verifies that blocked wait time is recorded even when an operator is +// terminated while blocked. +DEBUG_ONLY_TEST_F(TaskTest, blockedWaitTimeOnAbort) { + // Test that blocked time is recorded even when a task is aborted while + // an operator is blocked. + // + // We use a simple table scan that blocks waiting for splits. By starting + // the task without adding splits, the scan operator will block on + // kWaitForSplit. We then abort the task while blocked and verify that + // the blocked time is recorded via closeByTask(). + constexpr int kBlockTimeMs = 100; + + // Build a simple plan with a table scan. + core::PlanNodeId scanNodeId; + auto plan = PlanBuilder() + .tableScan(ROW({"c0"}, {BIGINT()})) + .capturePlanNodeId(scanNodeId) + .planFragment(); + + auto task = Task::create( + "blockedWaitTimeOnAbort", + plan, + 0, + core::QueryCtx::create(driverExecutor_.get()), + Task::ExecutionMode::kParallel); + + task->start(1, 1); + + // Wait for the driver to become blocked waiting for splits. + auto startTime = std::chrono::steady_clock::now(); + while (BlockingState::numBlockedDrivers() == 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + auto elapsed = std::chrono::duration_cast( + std::chrono::steady_clock::now() - startTime) + .count(); + if (elapsed > 5) { + task->requestAbort().wait(); + GTEST_SKIP() << "Operator did not block in time"; + } + } + + // Let some time pass while blocked to have measurable blocked time. + std::this_thread::sleep_for(std::chrono::milliseconds(kBlockTimeMs)); + + // Abort the task while the operator is still blocked waiting for splits. + task->requestAbort().wait(); + + // Wait for the task to be fully aborted. + ASSERT_TRUE(waitForTaskAborted(task.get())); + + // Verify that blocked wait time was recorded despite the abort. + const auto stats = task->taskStats().pipelineStats; + ASSERT_FALSE(stats.empty()); + ASSERT_FALSE(stats[0].operatorStats.empty()); + + // Find operator stats with blocked time recorded. + bool foundBlockedTime = false; + for (const auto& opStats : stats[0].operatorStats) { + if (opStats.blockedWallNanos > 0) { + foundBlockedTime = true; + // Verify the blocked time is at least what we waited (with tolerance). + EXPECT_GE(opStats.blockedWallNanos, (kBlockTimeMs - 30) * 1'000'000) + << "Blocked time should be at least " << (kBlockTimeMs - 30) << "ms"; + break; + } + } + EXPECT_TRUE(foundBlockedTime) + << "Operator should have recorded blocked time despite task abort"; +} + } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/ThreadDebugInfoTest.cpp b/velox/exec/tests/ThreadDebugInfoTest.cpp index 75bab5b3759..6bcb1801e14 100644 --- a/velox/exec/tests/ThreadDebugInfoTest.cpp +++ b/velox/exec/tests/ThreadDebugInfoTest.cpp @@ -63,6 +63,7 @@ template struct InduceSegFaultFunction { template void call(TResult& out, const TInput& in) { + LOG(ERROR) << "error"; int* nullpointer = nullptr; *nullpointer = 6; } @@ -117,9 +118,10 @@ DEBUG_ONLY_TEST_F(ThreadDebugInfoDeathTest, withinTheCallingThread) { #ifndef IS_BUILDING_WITH_SAN ASSERT_DEATH( - (task->next()), + task->next(), ".*Fatal signal handler. Query Id= TaskCursorQuery_0 Task Id= single.execution.task.0.*"); #endif + task->requestCancel(); } DEBUG_ONLY_TEST_F(ThreadDebugInfoDeathTest, noThreadContextSet) { diff --git a/velox/exec/tests/TopNRowNumberTest.cpp b/velox/exec/tests/TopNRowNumberTest.cpp index fac57658c2c..7f79ab80de7 100644 --- a/velox/exec/tests/TopNRowNumberTest.cpp +++ b/velox/exec/tests/TopNRowNumberTest.cpp @@ -15,26 +15,40 @@ */ #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/FileSystems.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/OperatorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" using namespace facebook::velox::exec::test; namespace facebook::velox::exec { +using namespace facebook::velox::common::testutil; namespace { class TopNRowNumberTest : public OperatorTestBase { protected: - TopNRowNumberTest() { + explicit TopNRowNumberTest(core::TopNRowNumberNode::RankFunction function) + : functionName_(core::TopNRowNumberNode::rankFunctionName(function)) {} + + void SetUp() override { + exec::test::OperatorTestBase::SetUp(); filesystems::registerLocalFileSystem(); } + + const std::string functionName_; +}; + +class MultiTopNRowNumberTest : public TopNRowNumberTest, + public testing::WithParamInterface< + core::TopNRowNumberNode::RankFunction> { + public: + MultiTopNRowNumberTest() : TopNRowNumberTest(GetParam()) {} }; -TEST_F(TopNRowNumberTest, basic) { +TEST_P(MultiTopNRowNumberTest, basic) { auto data = makeRowVector({ // Partitioning key. makeFlatVector({1, 1, 2, 2, 1, 2, 1}), @@ -50,38 +64,101 @@ TEST_F(TopNRowNumberTest, basic) { // Emit row numbers. auto plan = PlanBuilder() .values({data}) - .topNRowNumber({"c0"}, {"c1"}, limit, true) + .topNRank(functionName_, {"c0"}, {"c1"}, limit, true) + .planNode(); + assertQuery( + plan, + fmt::format( + "SELECT * FROM (SELECT *, {}() over (partition by c0 order by c1) as rn FROM tmp) " + " WHERE rn <= {}", + functionName_, + limit)); + + // Do not emit row numbers. + plan = PlanBuilder() + .values({data}) + .topNRank(functionName_, {"c0"}, {"c1"}, limit, false) + .planNode(); + + assertQuery( + plan, + fmt::format( + "SELECT c0, c1, c2 FROM (SELECT *, {}() over (partition by c0 order by c1) as rn FROM tmp) " + " WHERE rn <= {}", + functionName_, + limit)); + + // No partitioning keys. + plan = PlanBuilder() + .values({data}) + .topNRank(functionName_, {}, {"c1"}, limit, true) + .planNode(); + assertQuery( + plan, + fmt::format( + "SELECT * FROM (SELECT *, {}() over (order by c1) as rn FROM tmp) " + " WHERE rn <= {}", + functionName_, + limit)); + }; + + testLimit(1); + testLimit(2); + testLimit(3); + testLimit(5); +} + +TEST_P(MultiTopNRowNumberTest, basicWithPeers) { + auto data = makeRowVector({ + // Partitioning key. + makeFlatVector({1, 1, 2, 2, 1, 2, 1, 1, 1, 1, 1}), + // Sorting key. + makeFlatVector({33, 11, 55, 44, 11, 22, 11, 11, 11, 33, 33}), + // Data. Mapping data to matching sorting keys to avoid ordering issues. + makeFlatVector({10, 50, 30, 40, 50, 60, 50, 50, 50, 10, 10}), + }); + + createDuckDbTable({data}); + + auto testLimit = [&](auto limit) { + // Emit row numbers. + auto plan = PlanBuilder() + .values({data}) + .topNRank(functionName_, {"c0"}, {"c1"}, limit, true) .planNode(); assertQuery( plan, fmt::format( - "SELECT * FROM (SELECT *, row_number() over (partition by c0 order by c1) as rn FROM tmp) " + "SELECT * FROM (SELECT *, {}() over (partition by c0 order by c1) as rn FROM tmp) " " WHERE rn <= {}", + functionName_, limit)); // Do not emit row numbers. plan = PlanBuilder() .values({data}) - .topNRowNumber({"c0"}, {"c1"}, limit, false) + .topNRank(functionName_, {"c0"}, {"c1"}, limit, false) .planNode(); assertQuery( plan, fmt::format( - "SELECT c0, c1, c2 FROM (SELECT *, row_number() over (partition by c0 order by c1) as rn FROM tmp) " + "SELECT c0, c1, c2 FROM (SELECT *, {}() over (partition by c0 order by c1) as rn FROM tmp) " " WHERE rn <= {}", + functionName_, limit)); // No partitioning keys. plan = PlanBuilder() .values({data}) - .topNRowNumber({}, {"c1"}, limit, true) + .topNRank(functionName_, {}, {"c1"}, limit, true) .planNode(); assertQuery( plan, fmt::format( - "SELECT * FROM (SELECT *, row_number() over (order by c1) as rn FROM tmp) " + "SELECT * FROM (SELECT *, {}() over (order by c1) as rn FROM tmp) " " WHERE rn <= {}", + functionName_, limit)); }; @@ -91,7 +168,7 @@ TEST_F(TopNRowNumberTest, basic) { testLimit(5); } -TEST_F(TopNRowNumberTest, largeOutput) { +TEST_P(MultiTopNRowNumberTest, largeOutput) { // Make 10 vectors. Use different types for partitioning key, sorting key and // data. Use order of columns different from partitioning keys, followed by // sorting keys, followed by data. @@ -112,20 +189,21 @@ TEST_F(TopNRowNumberTest, largeOutput) { createDuckDbTable(data); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); auto testLimit = [&](auto limit) { SCOPED_TRACE(fmt::format("Limit: {}", limit)); core::PlanNodeId topNRowNumberId; auto plan = PlanBuilder() .values(data) - .topNRowNumber({"p"}, {"s"}, limit, true) + .topNRank(functionName_, {"p"}, {"s"}, limit, true) .capturePlanNodeId(topNRowNumberId) .planNode(); auto sql = fmt::format( - "SELECT * FROM (SELECT *, row_number() over (partition by p order by s) as rn FROM tmp) " + "SELECT * FROM (SELECT *, {}() over (partition by p order by s) as rn FROM tmp) " " WHERE rn <= {}", + functionName_, limit); AssertQueryBuilder(plan, duckDbQueryRunner_) .config(core::QueryConfig::kPreferredOutputBatchBytes, "1024") @@ -154,15 +232,17 @@ TEST_F(TopNRowNumberTest, largeOutput) { // No partitioning keys. plan = PlanBuilder() .values(data) - .topNRowNumber({}, {"s"}, limit, true) + .topNRank(functionName_, {}, {"s"}, limit, true) .planNode(); AssertQueryBuilder(plan, duckDbQueryRunner_) .config(core::QueryConfig::kPreferredOutputBatchBytes, "1024") - .assertResults(fmt::format( - "SELECT * FROM (SELECT *, row_number() over (order by s) as rn FROM tmp) " - " WHERE rn <= {}", - limit)); + .assertResults( + fmt::format( + "SELECT * FROM (SELECT *, {}() over (order by s) as rn FROM tmp) " + " WHERE rn <= {}", + functionName_, + limit)); }; testLimit(1); @@ -172,7 +252,7 @@ TEST_F(TopNRowNumberTest, largeOutput) { testLimit(2000); } -TEST_F(TopNRowNumberTest, manyPartitions) { +TEST_P(MultiTopNRowNumberTest, manyPartitions) { const vector_size_t size = 10'000; auto data = split( makeRowVector( @@ -196,20 +276,21 @@ TEST_F(TopNRowNumberTest, manyPartitions) { createDuckDbTable(data); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); auto testLimit = [&](auto limit, size_t outputBatchBytes = 1024) { SCOPED_TRACE(fmt::format("Limit: {}", limit)); core::PlanNodeId topNRowNumberId; auto plan = PlanBuilder() .values(data) - .topNRowNumber({"p"}, {"s"}, limit, true) + .topNRank(functionName_, {"p"}, {"s"}, limit, true) .capturePlanNodeId(topNRowNumberId) .planNode(); auto sql = fmt::format( - "SELECT * FROM (SELECT *, row_number() over (partition by p order by s) as rn FROM tmp) " + "SELECT * FROM (SELECT *, {}() over (partition by p order by s) as rn FROM tmp) " " WHERE rn <= {}", + functionName_, limit); assertQuery(plan, sql); @@ -243,7 +324,7 @@ TEST_F(TopNRowNumberTest, manyPartitions) { testLimit(1, 1); } -TEST_F(TopNRowNumberTest, fewPartitions) { +TEST_P(MultiTopNRowNumberTest, fewPartitions) { const vector_size_t size = 10'000; auto data = split( makeRowVector( @@ -267,20 +348,21 @@ TEST_F(TopNRowNumberTest, fewPartitions) { createDuckDbTable(data); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); auto testLimit = [&](auto limit, size_t outputBatchBytes = 1024) { SCOPED_TRACE(fmt::format("Limit: {}", limit)); core::PlanNodeId topNRowNumberId; auto plan = PlanBuilder() .values(data) - .topNRowNumber({"p"}, {"s"}, limit, true) + .topNRank(functionName_, {"p"}, {"s"}, limit, true) .capturePlanNodeId(topNRowNumberId) .planNode(); auto sql = fmt::format( - "SELECT * FROM (SELECT *, row_number() over (partition by p order by s) as rn FROM tmp) " + "SELECT * FROM (SELECT *, {}() over (partition by p order by s) as rn FROM tmp) " " WHERE rn <= {}", + functionName_, limit); assertQuery(plan, sql); @@ -312,7 +394,7 @@ TEST_F(TopNRowNumberTest, fewPartitions) { testLimit(100); } -TEST_F(TopNRowNumberTest, abandonPartialEarly) { +TEST_P(MultiTopNRowNumberTest, abandonPartialEarly) { auto data = makeRowVector( {"p", "s"}, { @@ -326,9 +408,9 @@ TEST_F(TopNRowNumberTest, abandonPartialEarly) { auto runPlan = [&](int32_t minRows) { auto plan = PlanBuilder() .values(split(data, 10)) - .topNRowNumber({"p"}, {"s"}, 99, false) + .topNRank(functionName_, {"p"}, {"s"}, 99, false) .capturePlanNodeId(topNRowNumberId) - .topNRowNumber({"p"}, {"s"}, 99, true) + .topNRank(functionName_, {"p"}, {"s"}, 99, true) .planNode(); auto task = AssertQueryBuilder(plan, duckDbQueryRunner_) @@ -337,8 +419,10 @@ TEST_F(TopNRowNumberTest, abandonPartialEarly) { fmt::format("{}", minRows)) .config(core::QueryConfig::kAbandonPartialTopNRowNumberMinPct, "80") .assertResults( - "SELECT * FROM (SELECT *, row_number() over (partition by p order by s) as rn FROM tmp) " - "WHERE rn <= 99"); + fmt::format( + "SELECT * FROM (SELECT *, {}() over (partition by p order by s) as rn FROM tmp) " + "WHERE rn <= 99", + functionName_)); return exec::toPlanStats(task->taskStats()); }; @@ -360,7 +444,7 @@ TEST_F(TopNRowNumberTest, abandonPartialEarly) { } } -TEST_F(TopNRowNumberTest, planNodeValidation) { +TEST_P(MultiTopNRowNumberTest, planNodeValidation) { auto data = makeRowVector( ROW({"a", "b", "c", "d", "e"}, { @@ -377,7 +461,7 @@ TEST_F(TopNRowNumberTest, planNodeValidation) { int32_t limit = 10) { PlanBuilder() .values({data}) - .topNRowNumber(partitionKeys, sortingKeys, limit, true) + .topNRank(functionName_, partitionKeys, sortingKeys, limit, true) .planNode(); }; @@ -403,15 +487,16 @@ TEST_F(TopNRowNumberTest, planNodeValidation) { plan({"a", "b"}, {"c"}, 0), "Limit must be greater than zero"); } -TEST_F(TopNRowNumberTest, maxSpillBytes) { +TEST_P(MultiTopNRowNumberTest, maxSpillBytes) { const auto rowType = ROW({"c0", "c1", "c2"}, {INTEGER(), INTEGER(), VARCHAR()}); const auto vectors = createVectors(rowType, 1024, 15 << 20); auto planNodeIdGenerator = std::make_shared(); auto plan = PlanBuilder(planNodeIdGenerator) .values(vectors) - .topNRowNumber({"c0"}, {"c1"}, 100, true) + .topNRank(functionName_, {"c0"}, {"c1"}, 100, true) .planNode(); + struct { int32_t maxSpilledBytes; bool expectedExceedLimit; @@ -420,7 +505,7 @@ TEST_F(TopNRowNumberTest, maxSpillBytes) { } } testSettings[] = {{1 << 30, false}, {13 << 20, true}, {0, false}}; - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); for (const auto& testData : testSettings) { @@ -451,7 +536,7 @@ TEST_F(TopNRowNumberTest, maxSpillBytes) { // This test verifies that TopNRowNumber operator reclaim all the memory after // spill. -DEBUG_ONLY_TEST_F(TopNRowNumberTest, memoryUsageCheckAfterReclaim) { +DEBUG_ONLY_TEST_P(MultiTopNRowNumberTest, memoryUsageCheckAfterReclaim) { std::atomic_int inputCount{0}; SCOPED_TESTVALUE_SET( "facebook::velox::exec::Driver::runInternal::addInput", @@ -491,18 +576,19 @@ DEBUG_ONLY_TEST_F(TopNRowNumberTest, memoryUsageCheckAfterReclaim) { createDuckDbTable(data); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + auto spillDirectory = TempDirectoryPath::create(); core::PlanNodeId topNRowNumberId; auto plan = PlanBuilder() .values(data) - .topNRowNumber({"p"}, {"s"}, 1'000, true) + .topNRank(functionName_, {"p"}, {"s"}, 1'000, true) .capturePlanNodeId(topNRowNumberId) .planNode(); - const auto sql = - "SELECT * FROM (SELECT *, row_number() over (partition by p order by s) as rn FROM tmp) " - " WHERE rn <= 1000"; + const auto sql = fmt::format( + "SELECT * FROM (SELECT *, {}() over (partition by p order by s) as rn FROM tmp) " + " WHERE rn <= 1000", + functionName_); auto task = AssertQueryBuilder(plan, duckDbQueryRunner_) .config(core::QueryConfig::kSpillEnabled, "true") .config(core::QueryConfig::kTopNRowNumberSpillEnabled, "true") @@ -520,7 +606,7 @@ DEBUG_ONLY_TEST_F(TopNRowNumberTest, memoryUsageCheckAfterReclaim) { // This test verifies that TopNRowNumber operator can be closed twice which // might be triggered by memory pool abort. -DEBUG_ONLY_TEST_F(TopNRowNumberTest, doubleClose) { +DEBUG_ONLY_TEST_P(MultiTopNRowNumberTest, doubleClose) { const std::string errorMessage("doubleClose"); SCOPED_TESTVALUE_SET( "facebook::velox::exec::Driver::runInternal::noMoreInput", @@ -556,15 +642,72 @@ DEBUG_ONLY_TEST_F(TopNRowNumberTest, doubleClose) { core::PlanNodeId topNRowNumberId; auto plan = PlanBuilder() .values(data) - .topNRowNumber({"p"}, {"s"}, 1'000, true) + .topNRank(functionName_, {"p"}, {"s"}, 1'000, true) .capturePlanNodeId(topNRowNumberId) .planNode(); - const auto sql = - "SELECT * FROM (SELECT *, row_number() over (partition by p order by s) as rn FROM tmp) " - " WHERE rn <= 1000"; + const auto sql = fmt::format( + "SELECT * FROM (SELECT *, {}() over (partition by p order by s) as rn FROM tmp) " + " WHERE rn <= 1000", + functionName_); VELOX_ASSERT_THROW(assertQuery(plan, sql), errorMessage); } + +// This test verifies that TopNRowNumber operator handles OOM that occurs in the +// middle of groupProbe, after inserting some new rows into the row container. +DEBUG_ONLY_TEST_P(MultiTopNRowNumberTest, oomInGroupProbe) { + const std::string errorMessage("Simulated OOM in groupProbe"); + std::atomic_int insertCount{0}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::HashTable::insertEntry", + std::function( + ([&](memory::MemoryPool* /*pool*/) { + // Trigger OOM after inserting some rows to simulate failure in the + // middle of groupProbe insertion. + if (++insertCount == 100) { + VELOX_FAIL(errorMessage); + } + }))); + + const vector_size_t size = 10'000; + auto data = split( + makeRowVector( + {"d", "s", "p"}, + { + // Data. + makeFlatVector( + size, [](auto row) { return row; }, nullEvery(11)), + // Sorting key. + makeFlatVector( + size, + [](auto row) { return (size - row) * 10; }, + [](auto row) { return row == 123; }), + // Partitioning key. Make sure to spread rows from the same + // partition across multiple batches. + makeFlatVector( + size, [](auto row) { return row % 5'000; }, nullEvery(7)), + }), + 10); + + core::PlanNodeId topNRowNumberId; + auto plan = PlanBuilder() + .values(data) + .topNRank(functionName_, {"p"}, {"s"}, 1'000, true) + .capturePlanNodeId(topNRowNumberId) + .planNode(); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan).copyResults(pool_.get()), errorMessage); +} + +VELOX_INSTANTIATE_TEST_SUITE_P( + TopNRowNumberTest, + MultiTopNRowNumberTest, + testing::ValuesIn( + std::vector( + {core::TopNRowNumberNode::RankFunction::kRowNumber, + core::TopNRowNumberNode::RankFunction::kRank, + core::TopNRowNumberNode::RankFunction::kDenseRank}))); } // namespace } // namespace facebook::velox::exec diff --git a/velox/exec/tests/TraceUtilTest.cpp b/velox/exec/tests/TraceUtilTest.cpp index acb94280ba5..c5fdab5ff65 100644 --- a/velox/exec/tests/TraceUtilTest.cpp +++ b/velox/exec/tests/TraceUtilTest.cpp @@ -21,13 +21,13 @@ #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/FileSystems.h" #include "velox/common/file/tests/FaultyFileSystem.h" -#include "velox/exec/Trace.h" -#include "velox/exec/TraceUtil.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" - -using namespace facebook::velox::exec::test; +#include "velox/common/testutil/TempDirectoryPath.h" +#include "velox/exec/trace/Trace.h" +#include "velox/exec/trace/TraceUtil.h" namespace facebook::velox::exec::trace::test { +using namespace facebook::velox::common::testutil; + class TraceUtilTest : public testing::Test { protected: static void SetUpTestCase() { @@ -253,7 +253,7 @@ TEST_F(TraceUtilTest, getDriverIds) { } TEST_F(TraceUtilTest, createTraceDirectoryTest) { - auto tmpRootDir = exec::test::TempDirectoryPath::create(); + auto tmpRootDir = TempDirectoryPath::create(); auto tmpTraceDir = fmt::format( "{}{}/trace", tests::utils::FaultyFileSystem::scheme(), diff --git a/velox/exec/tests/TreeOfLosersTest.cpp b/velox/exec/tests/TreeOfLosersTest.cpp index 8e193e40fcd..0a56a4f4970 100644 --- a/velox/exec/tests/TreeOfLosersTest.cpp +++ b/velox/exec/tests/TreeOfLosersTest.cpp @@ -181,8 +181,9 @@ TEST_F(TreeOfLosersTest, allSorted) { TEST_F(TreeOfLosersTest, allEmpty) { for (bool testNextEqual : {false, true}) { for (int numStreams : {0, 1, 5, 100}) { - SCOPED_TRACE(fmt::format( - "numStreams: {}, testNextEqual", numStreams, testNextEqual)); + SCOPED_TRACE( + fmt::format( + "numStreams: {}, testNextEqual", numStreams, testNextEqual)); std::vector> mergeStreams; for (int i = 0; i < numStreams; ++i) { mergeStreams.push_back( @@ -211,12 +212,13 @@ TEST_F(TreeOfLosersTest, randomWithDuplicates) { for (int iter = 0; iter < 10; ++iter) { const int numCount = std::max(1, folly::Random::rand32(1000'000)); const int numStreams = std::max(3, folly::Random::rand32(100)); - SCOPED_TRACE(fmt::format( - "iter: {}, testNextEqual: {}, numCount: {}, numStreams: {}", - iter, - testNextEqual, - numCount, - numStreams)); + SCOPED_TRACE( + fmt::format( + "iter: {}, testNextEqual: {}, numCount: {}, numStreams: {}", + iter, + testNextEqual, + numCount, + numStreams)); std::vector> streamNumVectors(numStreams); for (int i = 0; i < numCount; ++i) { const int streamIndex = folly::Random::rand32(numStreams); diff --git a/velox/exec/tests/UnnestTest.cpp b/velox/exec/tests/UnnestTest.cpp index 53979988f05..d35a2deaa71 100644 --- a/velox/exec/tests/UnnestTest.cpp +++ b/velox/exec/tests/UnnestTest.cpp @@ -15,15 +15,16 @@ */ #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/testutil/OptionalEmpty.h" +#include "velox/common/testutil/TempFilePath.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempFilePath.h" using namespace facebook::velox; using namespace facebook::velox::exec; using namespace facebook::velox::exec::test; +using namespace facebook::velox::common::testutil; class UnnestTest : public HiveConnectorTestBase, public testing::WithParamInterface { @@ -285,7 +286,7 @@ TEST_P(UnnestTest, arrayWithOrdinality) { assertQuery(params, expectedInDict); } -TEST_P(UnnestTest, arrayWithEmptyUnnestValue) { +TEST_P(UnnestTest, arrayWithMarker) { const auto array = makeArrayVectorFromJson( {"[1, 2, null, 4]", "null", "[5, 6]", "[]", "[null]", "[7, 8, 9]"}); const auto input = makeRowVector( @@ -312,7 +313,7 @@ TEST_P(UnnestTest, arrayWithEmptyUnnestValue) { expected->childAt(1), makeNullableFlatVector({1, 2, 3, 4, 1, 2, 1, 1, 2, 3})}); - const auto expectedWithEmptyUnnestValue = makeRowVector( + const auto expectedWithMarker = makeRowVector( {makeNullableFlatVector( {1.1, 1.1, @@ -340,23 +341,23 @@ TEST_P(UnnestTest, arrayWithEmptyUnnestValue) { 8, 9}), makeNullableFlatVector( - {false, - false, - false, - false, + {true, + true, true, - false, - false, true, false, + true, + true, false, - false, - false})}); + true, + true, + true, + true})}); const auto expectedWithBoth = makeRowVector( - {expectedWithEmptyUnnestValue->childAt(0), - expectedWithEmptyUnnestValue->childAt(1), + {expectedWithMarker->childAt(0), + expectedWithMarker->childAt(1), makeNullableFlatVector({1, 2, 3, 4, 0, 1, 2, 0, 1, 1, 2, 3}), - expectedWithEmptyUnnestValue->childAt(2)}); + expectedWithMarker->childAt(2)}); struct { bool hasOrdinality; @@ -375,7 +376,7 @@ TEST_P(UnnestTest, arrayWithEmptyUnnestValue) { } testSettings[] = { {false, false, input, expected}, {true, false, input, expectedWithOrdinality}, - {false, true, input, expectedWithEmptyUnnestValue}, + {false, true, input, expectedWithMarker}, {true, true, input, expectedWithBoth}}; for (const auto& testData : testSettings) { @@ -385,13 +386,13 @@ TEST_P(UnnestTest, arrayWithEmptyUnnestValue) { if (testData.hasOrdinality) { ordinalityName = "ordinal"; } - std::optional emptyUnnestValueName; + std::optional markerName; if (testData.hasEmptyUnnestValue) { - emptyUnnestValueName = "emptyUnnestValue"; + markerName = "emptyUnnestValue"; } auto op = PlanBuilder() .values({testData.input}) - .unnest({"c0"}, {"c1"}, ordinalityName, emptyUnnestValueName) + .unnest({"c0"}, {"c1"}, ordinalityName, markerName) .planNode(); auto params = makeCursorParameters(op); assertQuery(params, testData.expected); @@ -466,7 +467,7 @@ TEST_P(UnnestTest, mapWithOrdinality) { assertQuery(params, expectedInDict); } -TEST_P(UnnestTest, mapWithEmptyUnnestValue) { +TEST_P(UnnestTest, mapWithMarker) { const auto map = makeNullableMapVector( {{{{1, 1.1}, {2, std::nullopt}}}, common::testutil::optionalEmpty, @@ -489,7 +490,7 @@ TEST_P(UnnestTest, mapWithEmptyUnnestValue) { expected->childAt(2), makeNullableFlatVector({1, 2, 1, 2, 3, 1})}); - const auto expectedWithEmptyUnnestValue = makeRowVector( + const auto expectedWithMarker = makeRowVector( {makeNullableFlatVector({1, 1, 2, 3, 3, 3, 4, 5}), makeNullableFlatVector( {1, 2, std::nullopt, 3, 4, 5, std::nullopt, 6}), @@ -503,14 +504,14 @@ TEST_P(UnnestTest, mapWithEmptyUnnestValue) { std::nullopt, std::nullopt}), makeNullableFlatVector( - {false, false, true, false, false, false, true, false})}); + {true, true, false, true, true, true, false, true})}); const auto expectedWithBoth = makeRowVector( - {expectedWithEmptyUnnestValue->childAt(0), - expectedWithEmptyUnnestValue->childAt(1), - expectedWithEmptyUnnestValue->childAt(2), + {expectedWithMarker->childAt(0), + expectedWithMarker->childAt(1), + expectedWithMarker->childAt(2), makeNullableFlatVector({1, 2, 0, 1, 2, 3, 0, 1}), - expectedWithEmptyUnnestValue->childAt(3)}); + expectedWithMarker->childAt(3)}); struct { bool hasOrdinality; @@ -529,7 +530,7 @@ TEST_P(UnnestTest, mapWithEmptyUnnestValue) { } testSettings[] = { {false, false, input, expected}, {true, false, input, expectedWithOrdinality}, - {false, true, input, expectedWithEmptyUnnestValue}, + {false, true, input, expectedWithMarker}, {true, true, input, expectedWithBoth}}; for (const auto& testData : testSettings) { @@ -539,13 +540,13 @@ TEST_P(UnnestTest, mapWithEmptyUnnestValue) { if (testData.hasOrdinality) { ordinalityName = "ordinal"; } - std::optional emptyUnnestValueName; + std::optional markerName; if (testData.hasEmptyUnnestValue) { - emptyUnnestValueName = "emptyUnnestValue"; + markerName = "emptyUnnestValue"; } auto op = PlanBuilder() .values({testData.input}) - .unnest({"c0"}, {"c1"}, ordinalityName, emptyUnnestValueName) + .unnest({"c0"}, {"c1"}, ordinalityName, markerName) .planNode(); auto params = makeCursorParameters(op); assertQuery(params, testData.expected); @@ -848,15 +849,16 @@ TEST_P(UnnestTest, barrier) { // Unnest 1K rows into 3K rows. auto planNodeIdGenerator = std::make_shared(); core::PlanNodeId unnestPlanNodeId; - const auto plan = PlanBuilder(planNodeIdGenerator) - .startTableScan() - .outputType(std::dynamic_pointer_cast( - vectors[0]->type())) - .endTableScan() - .project({"sequence(1, 3) as s"}) - .unnest({}, {"s"}) - .capturePlanNodeId(unnestPlanNodeId) - .planNode(); + const auto plan = + PlanBuilder(planNodeIdGenerator) + .startTableScan() + .outputType( + std::dynamic_pointer_cast(vectors[0]->type())) + .endTableScan() + .project({"sequence(1, 3) as s"}) + .unnest({}, {"s"}) + .capturePlanNodeId(unnestPlanNodeId) + .planNode(); const auto expectedResult = makeRowVector({ makeFlatVector( @@ -866,15 +868,22 @@ TEST_P(UnnestTest, barrier) { struct { bool barrierExecution; + bool serialExecution; int numOutputRows; std::string toString() const { return fmt::format( - "barrierExecution {}, numOutputRows {}", + "barrierExecution {}, serialExecution {}, numOutputRows {}", barrierExecution, + serialExecution, numOutputRows); } - } testSettings[] = {{true, 23}, {false, 23}, {true, 200}, {false, 200}}; + } testSettings[] = { + {true, true, 23}, + {true, false, 23}, + {false, true, 200}, + {false, false, 200}, + }; for (const auto& testData : testSettings) { SCOPED_TRACE(testData.toString()); const int numExpectedOutputVectors = @@ -886,7 +895,8 @@ TEST_P(UnnestTest, barrier) { core::QueryConfig::kMaxSplitPreloadPerDriver, std::to_string(tempFiles.size())) .splits(makeHiveConnectorSplits(tempFiles)) - .serialExecution(true) + .serialExecution(testData.serialExecution) + .maxDrivers(testData.serialExecution ? 1 : 3) .barrierExecution(testData.barrierExecution) .config( core::QueryConfig::kPreferredOutputBatchRows, @@ -936,13 +946,13 @@ TEST_P(UnnestTest, spiltOutput) { struct { bool produceSingleOutput; - int expectedNumOutputExectors; + int expectedNumOutputVectors; std::string toString() const { return fmt::format( - "produceSingleOutput {}, expectedNumOutputExectors {}", + "produceSingleOutput {}, expectedNumOutputVectors {}", produceSingleOutput, - expectedNumOutputExectors); + expectedNumOutputVectors); } } testSettings[] = { {true, numBatches}, @@ -960,7 +970,100 @@ TEST_P(UnnestTest, spiltOutput) { const auto taskStats = task->taskStats(); ASSERT_EQ( exec::toPlanStats(taskStats).at(unnestPlanNodeId).outputVectors, - testData.expectedNumOutputExectors); + testData.expectedNumOutputVectors); + } +} + +// Test that UnnestNode::splitOutput overrides query config. +TEST_P(UnnestTest, splitOutputNodeOverride) { + const auto numBatches = 3; + const auto inputBatchSize = 256; + std::vector vectors; + vectors.reserve(numBatches); + for (int32_t i = 0; i < numBatches; ++i) { + vectors.push_back(makeRowVector({ + makeFlatVector(inputBatchSize, [](auto row) { return row; }), + })); + } + + const auto expectedResult = makeRowVector({ + makeFlatVector( + numBatches * 3 * inputBatchSize, + [](auto row) { return 1 + row % 3; }), + }); + + auto planNodeIdGenerator = std::make_shared(); + + // Create a plan with project node to generate the sequence. + auto projectPlan = PlanBuilder(planNodeIdGenerator) + .values(vectors) + .project({"sequence(1, 3) as s"}) + .planNode(); + + // Get the output type from the project node. + auto projectOutput = projectPlan->outputType(); + std::vector> unnestFields; + unnestFields.emplace_back( + std::make_shared( + projectOutput->childAt(0), "s")); + + struct { + std::optional nodeSplitOutput; + bool configSplitOutput{}; + bool expectSplit{}; + + std::string toString() const { + return fmt::format( + "nodeSplitOutput: {}, configSplitOutput: {}, expectSplit: {}", + nodeSplitOutput.has_value() + ? (nodeSplitOutput.value() ? "true" : "false") + : "nullopt", + configSplitOutput, + expectSplit); + } + } testSettings[] = { + // Node splitOutput not set, use config. + {std::nullopt, true, true}, + {std::nullopt, false, false}, + // Node splitOutput=true overrides config. + {true, true, true}, + {true, false, true}, + // Node splitOutput=false overrides config. + {false, true, false}, + {false, false, false}, + }; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.toString()); + + core::PlanNodeId unnestPlanNodeId; + auto unnestNode = std::make_shared( + planNodeIdGenerator->next(), + std::vector>{}, + unnestFields, + std::vector{"s_e"}, + std::nullopt, + std::nullopt, + testData.nodeSplitOutput, + projectPlan); + unnestPlanNodeId = unnestNode->id(); + + const int expectedNumOutputVectors = testData.expectSplit + ? bits::divRoundUp(inputBatchSize * 3, GetParam()) * numBatches + : numBatches; + + auto task = AssertQueryBuilder(unnestNode) + .config( + core::QueryConfig::kPreferredOutputBatchRows, + std::to_string(GetParam())) + .config( + core::QueryConfig::kUnnestSplitOutput, + testData.configSplitOutput ? "true" : "false") + .assertResults(expectedResult); + const auto taskStats = task->taskStats(); + ASSERT_EQ( + exec::toPlanStats(taskStats).at(unnestPlanNodeId).outputVectors, + expectedNumOutputVectors); } } diff --git a/velox/exec/tests/ValuesTest.cpp b/velox/exec/tests/ValuesTest.cpp index 368fb17bfb5..ed8218e3c97 100644 --- a/velox/exec/tests/ValuesTest.cpp +++ b/velox/exec/tests/ValuesTest.cpp @@ -47,7 +47,9 @@ class ValuesTest : public OperatorTestBase { TEST_F(ValuesTest, empty) { // Base case: no vectors. - AssertQueryBuilder(PlanBuilder().values({}).planNode()).assertEmptyResults(); + AssertQueryBuilder( + PlanBuilder().values(std::vector{}).planNode()) + .assertEmptyResults(); // Empty input vector. auto emptyInput = makeRowVector({}); diff --git a/velox/exec/tests/VectorHasherTest.cpp b/velox/exec/tests/VectorHasherTest.cpp index 2a3cb8ef9f7..deca96c94ac 100644 --- a/velox/exec/tests/VectorHasherTest.cpp +++ b/velox/exec/tests/VectorHasherTest.cpp @@ -462,14 +462,15 @@ TEST_F(VectorHasherTest, stringDistinctOverflow) { for (auto i = 0; i < 7; ++i) { auto& stringVec = strings[i]; stringVec.resize(numRows); - batches.emplace_back(makeFlatVector( - numRows, [&i, &stringVec, numRows](vector_size_t row) { - const auto num = numRows * i + row; - stringVec[row] = (row != 0) - ? fmt::format("abcdefghijabcdefghij{}", num) - : fmt::format("s{}", num); - return StringView(stringVec[row]); - })); + batches.emplace_back( + makeFlatVector( + numRows, [&i, &stringVec, numRows](vector_size_t row) { + const auto num = numRows * i + row; + stringVec[row] = (row != 0) + ? fmt::format("abcdefghijabcdefghij{}", num) + : fmt::format("s{}", num); + return StringView(stringVec[row]); + })); } SelectivityVector rows(numRows, true); @@ -680,9 +681,9 @@ TEST_F(VectorHasherTest, merge) { VectorHasher emptyHasher(BIGINT(), 0); VectorHasher otherEmptyHasher(BIGINT(), 0); EXPECT_TRUE(emptyHasher.empty()); - emptyHasher.merge(otherHasher); - hasher.merge(emptyHasher); - hasher.merge(otherEmptyHasher); + emptyHasher.merge(otherHasher, 1'000'000); + hasher.merge(emptyHasher, 1'000'000); + hasher.merge(otherEmptyHasher, 1'000'000); uint64_t numRange; uint64_t numDistinct; hasher.cardinality(0, numRange, numDistinct); @@ -720,6 +721,45 @@ TEST_F(VectorHasherTest, merge) { EXPECT_EQ(numDistinct - 1, ids.size()); } +TEST_F(VectorHasherTest, mergeMaxNumDistinct) { + constexpr vector_size_t kSize = 100; + SelectivityVector rows(kSize); + raw_vector hashes(kSize); + + auto vector1 = + makeFlatVector(kSize, [](vector_size_t row) { return row; }); + auto vector2 = makeFlatVector( + kSize, [](vector_size_t row) { return 1000 + row; }); + auto vector3 = makeFlatVector( + kSize, [](vector_size_t row) { return 2000 + row; }); + + VectorHasher hasher1(BIGINT(), 0); + hasher1.decode(*vector1, rows); + hasher1.computeValueIds(rows, hashes); + + VectorHasher hasher2(BIGINT(), 0); + hasher2.decode(*vector2, rows); + hasher2.computeValueIds(rows, hashes); + + VectorHasher hasher3(BIGINT(), 0); + hasher3.decode(*vector3, rows); + hasher3.computeValueIds(rows, hashes); + + hasher1.merge(hasher2, kSize * 2); + uint64_t numRange; + uint64_t numDistinct; + hasher1.cardinality(0, numRange, numDistinct); + EXPECT_EQ(numDistinct, kSize * 2 + 1); + + hasher1.merge(hasher3, kSize * 2); + hasher1.cardinality(0, numRange, numDistinct); + EXPECT_EQ(numDistinct, VectorHasher::kRangeTooLarge); + + hasher1.merge(hasher3, kSize * 10); + hasher1.cardinality(0, numRange, numDistinct); + EXPECT_EQ(numDistinct, VectorHasher::kRangeTooLarge); +} + TEST_F(VectorHasherTest, computeValueIdsBigint) { testComputeValueIds(false); testComputeValueIds(true); @@ -1061,6 +1101,52 @@ TEST_F(VectorHasherTest, simdRange) { } } +TEST_F(VectorHasherTest, lookupValueIdsDictionaryWithLargerBase) { + auto base = makeNullableFlatVector({10, 11, 12, 13, 14, 15}); + + auto hasher = VectorHasher::create(INTEGER(), 0); + raw_vector result(8, pool()); + std::fill(result.begin(), result.end(), 0); + SelectivityVector allRows(base->size()); + hasher->decode(*base, allRows); + hasher->computeValueIds(allRows, result); + + auto probeBase = + makeNullableFlatVector({15, 100, 11, 14, std::nullopt, 10}); + auto indices = makeIndices(8, [](vector_size_t row) { + static constexpr vector_size_t kIndices[] = {0, 1, 2, 3, 4, 5, 0, 2}; + return kIndices[row]; + }); + auto probe = + BaseVector::wrapInDictionary(BufferPtr(nullptr), indices, 8, probeBase); + + SelectivityVector rows(8); + rows.clearAll(); + rows.setValid(0, true); + rows.setValid(1, true); + rows.setValid(2, true); + rows.setValid(4, true); + rows.setValid(6, true); + rows.updateBounds(); + std::fill(result.begin(), result.end(), 0); + + VectorHasher::ScratchMemory scratch; + hasher->lookupValueIds(*probe, rows, scratch, result); + + EXPECT_TRUE(rows.isValid(0)); + EXPECT_FALSE(rows.isValid(1)); + EXPECT_TRUE(rows.isValid(2)); + EXPECT_TRUE(rows.isValid(4)); + EXPECT_TRUE(rows.isValid(6)); + EXPECT_EQ(rows.countSelected(), 4); + + EXPECT_EQ(result[0], 6); + EXPECT_EQ(result[1], 0); + EXPECT_EQ(result[2], 2); + EXPECT_EQ(result[4], 0); + EXPECT_EQ(result[6], 6); +} + TEST_F(VectorHasherTest, typeMismatch) { auto hasher = VectorHasher::create(BIGINT(), 0); @@ -1197,3 +1283,200 @@ TEST_F(VectorHasherTest, customComparisonRow) { {std::nullopt, 0, 1, 0, 1, 0, 1}, velox::test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON())})); } + +TEST_F(VectorHasherTest, customComparisonValueIds) { + // Test that VectorHasher created with custom comparison type + // has value IDs disabled (distinctOverflow_ and rangeOverflow_ set). + auto vectorHasher = exec::VectorHasher::create( + velox::test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON(), 1); + + // Test that types with custom comparison do not support value IDs. + EXPECT_FALSE(vectorHasher->typeSupportsValueIds()); + + // Verify that mayUseValueIds() returns false for custom comparison types. + EXPECT_FALSE(vectorHasher->mayUseValueIds()); +} + +TEST_F(VectorHasherTest, stringFilter) { + // Test that getFilter() returns a BytesValues filter for string types. + auto vector = makeFlatVector( + {"apple", "banana", "cherry", "apple", "banana", "date"}); + + auto hasher = exec::VectorHasher::create(VARCHAR(), 0); + SelectivityVector rows(vector->size()); + raw_vector hashes(vector->size()); + + // Analyze the sample data. + hasher->decode(*vector, rows); + hasher->computeValueIds(rows, hashes); + + // Test nullAllowed = false. + auto filter = hasher->getFilter(false); + ASSERT_TRUE(filter != nullptr); + + auto bytesValues = dynamic_cast(filter.get()); + ASSERT_TRUE(bytesValues != nullptr); + ASSERT_FALSE(bytesValues->testNull()); + + ASSERT_TRUE(bytesValues->testBytes("apple", 5)); + ASSERT_TRUE(bytesValues->testBytes("banana", 6)); + ASSERT_TRUE(bytesValues->testBytes("cherry", 6)); + ASSERT_TRUE(bytesValues->testBytes("date", 4)); + ASSERT_FALSE(bytesValues->testBytes("elderberry", 10)); + ASSERT_FALSE(bytesValues->testBytes("fig", 3)); + ASSERT_FALSE(bytesValues->testBytes("grape", 5)); + + // Test nullAllowed = true. + auto filterWithNull = hasher->getFilter(true); + ASSERT_TRUE(filterWithNull != nullptr); + + auto bytesValuesWithNull = + dynamic_cast(filterWithNull.get()); + ASSERT_TRUE(bytesValuesWithNull != nullptr); + ASSERT_TRUE(bytesValuesWithNull->testNull()); + ASSERT_TRUE(bytesValuesWithNull->testBytes("apple", 5)); + ASSERT_FALSE(bytesValuesWithNull->testBytes("unknown", 7)); +} + +TEST_F(VectorHasherTest, stringFilterWithLongStrings) { + // Test string filter with strings longer than 8 bytes (stored as pointers). + // Keep strings in a vector to ensure they persist throughout the test. + std::vector strings = { + "short", + "this is a very very very very very long string", + "this is a very long string", + "short", + "medium length", + }; + + auto vector = makeFlatVector( + strings.size(), + [&strings](vector_size_t row) { return StringView(strings[row]); }); + + auto hasher = exec::VectorHasher::create(VARCHAR(), 0); + SelectivityVector rows(vector->size()); + raw_vector hashes(vector->size()); + + // Analyze the data. + hasher->decode(*vector, rows); + hasher->computeValueIds(rows, hashes); + + // Get the filter. + auto filter = hasher->getFilter(false); + ASSERT_TRUE(filter != nullptr); + + auto bytesValues = dynamic_cast(filter.get()); + ASSERT_TRUE(bytesValues != nullptr); + + // Test that both short and long strings are handled correctly. + ASSERT_TRUE(bytesValues->testBytes(strings[0].data(), strings[0].size())); + ASSERT_TRUE(bytesValues->testBytes(strings[1].data(), strings[1].size())); + ASSERT_TRUE(bytesValues->testBytes(strings[2].data(), strings[2].size())); + ASSERT_TRUE(bytesValues->testBytes(strings[4].data(), strings[4].size())); + + // Test rejection of non-existent strings. + ASSERT_FALSE(bytesValues->testBytes("not in the set", 14)); + ASSERT_FALSE(bytesValues->testBytes("different", 9)); +} + +TEST_F(VectorHasherTest, stringFilterDistinctOverflow) { + // Test that getFilter() returns nullptr when distinctOverflow_ is true. + auto hasher = exec::VectorHasher::create(VARCHAR(), 0); + + constexpr uint32_t numRows = 10000; + constexpr int numBatches = 15; + + SelectivityVector rows(numRows); + raw_vector hashes(numRows); + + // Process enough batches with distinct strings to trigger overflow. + // kMaxDistinct is 100,000, so 15 batches * 10,000 = 150,000 distinct strings. + for (int batch = 0; batch < numBatches; ++batch) { + std::vector strings; + strings.resize(numRows); + + // Create distinct strings for each batch. + for (auto i = 0; i < numRows; ++i) { + strings[i] = fmt::format("batch_{}_string_{}", batch, i); + } + + auto vector = makeFlatVector( + numRows, + [&strings](vector_size_t row) { return StringView(strings[row]); }); + + hasher->decode(*vector, rows); + hasher->computeValueIds(rows, hashes); + } + + // After overflow, getFilter should return nullptr. + auto filter = hasher->getFilter(false); + ASSERT_TRUE(filter == nullptr); +} + +DEBUG_ONLY_TEST_F(VectorHasherTest, customComparisonNoValueIds) { + // Test that custom comparison types cannot use value IDs for optimization. + auto data = makeRowVector({makeNullableFlatVector( + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, + velox::test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON())}); + + auto hasher = exec::VectorHasher::create( + velox::test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON(), 0); + + SelectivityVector allRows(data->size()); + raw_vector result(data->size()); + std::fill(result.begin(), result.end(), 0); + + hasher->decode(*data->childAt(0), allRows); + + VELOX_ASSERT_THROW( + hasher->computeValueIds(allRows, result), "Value IDs cannot be used"); + VectorHasher::ScratchMemory scratchMemory; + VELOX_ASSERT_THROW( + hasher->lookupValueIds(*data->childAt(0), allRows, scratchMemory, result), + "Value IDs cannot be used"); + EXPECT_EQ(nullptr, hasher->getFilter(true)); + VELOX_ASSERT_THROW( + hasher->enableValueRange(1, 50), "Value IDs cannot be used"); + VELOX_ASSERT_THROW( + hasher->enableValueRange(1, 50), "Value IDs cannot be used"); +} + +DEBUG_ONLY_TEST_F(VectorHasherTest, computeValueIdsForRowsCustomComparison) { + // Test that computeValueIdsForRows throws an exception for types with custom + // comparison. + auto hasher = exec::VectorHasher::create( + velox::test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON(), 0); + + constexpr int32_t kNumGroups = 5; + constexpr int32_t kRowSize = 16; + constexpr int32_t kValueOffset = 0; + constexpr int32_t kNullByte = 8; + constexpr uint8_t kNullMask = 1; + + // Allocate memory for row-wise data. + std::vector> rowData(kNumGroups); + std::vector groups(kNumGroups); + + for (int i = 0; i < kNumGroups; ++i) { + rowData[i].resize(kRowSize, 0); + groups[i] = rowData[i].data(); + + // Set values for all rows (no nulls for simplicity). + *reinterpret_cast(groups[i] + kValueOffset) = i * 256; + } + + raw_vector result(kNumGroups); + std::fill(result.begin(), result.end(), 0); + + // computeValueIdsForRows should throw an exception for types with custom + // comparison. + VELOX_ASSERT_THROW( + hasher->computeValueIdsForRows( + groups.data(), + kNumGroups, + kValueOffset, + kNullByte, + kNullMask, + result), + "Value IDs cannot be used"); +} diff --git a/velox/exec/tests/VeloxIn10MinDemo.cpp b/velox/exec/tests/VeloxIn10MinDemo.cpp index 87d571a1788..23a57b6bd2c 100644 --- a/velox/exec/tests/VeloxIn10MinDemo.cpp +++ b/velox/exec/tests/VeloxIn10MinDemo.cpp @@ -14,10 +14,11 @@ * limitations under the License. */ #include +#include #include "velox/common/memory/Memory.h" +#include "velox/connectors/ConnectorRegistry.h" #include "velox/connectors/tpch/TpchConnector.h" #include "velox/connectors/tpch/TpchConnectorSplit.h" -#include "velox/core/Expressions.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/expression/Expr.h" @@ -47,33 +48,25 @@ class VeloxIn10MinDemo : public VectorTestBase { // Register type resolver with DuckDB SQL parser. parse::registerTypeResolver(); - // Register the TPC-H Connector Factory. - connector::registerConnectorFactory( - std::make_shared()); - // Create and register a TPC-H connector. - auto tpchConnector = - connector::getConnectorFactory( - connector::tpch::TpchConnectorFactory::kTpchConnectorName) - ->newConnector( - kTpchConnectorId, - std::make_shared( - std::unordered_map())); - connector::registerConnector(tpchConnector); + connector::tpch::TpchConnectorFactory factory; + auto tpchConnector = factory.newConnector( + kTpchConnectorId, + std::make_shared( + std::unordered_map())); + connector::ConnectorRegistry::global().insert( + tpchConnector->connectorId(), tpchConnector); } ~VeloxIn10MinDemo() { - connector::unregisterConnector(kTpchConnectorId); - connector::unregisterConnectorFactory( - connector::tpch::TpchConnectorFactory::kTpchConnectorName); + connector::ConnectorRegistry::global().erase(kTpchConnectorId); } /// Parse SQL expression into a typed expression tree using DuckDB SQL parser. core::TypedExprPtr parseExpression( const std::string& text, const RowTypePtr& rowType) { - parse::ParseOptions options; - auto untyped = parse::parseExpr(text, options); + auto untyped = parse::DuckSqlExpressionsParser().parseExpr(text); return core::Expressions::inferTypes(untyped, rowType, execCtx_->pool()); } @@ -99,8 +92,9 @@ class VeloxIn10MinDemo : public VectorTestBase { /// Make TPC-H split to add to TableScan node. exec::Split makeTpchSplit() const { - return exec::Split(std::make_shared( - kTpchConnectorId, /*cacheable=*/true, 1, 0)); + return exec::Split( + std::make_shared( + kTpchConnectorId, /*cacheable=*/true, 1, 0)); } /// Run the demo. @@ -108,7 +102,7 @@ class VeloxIn10MinDemo : public VectorTestBase { std::shared_ptr executor_{ std::make_shared( - std::thread::hardware_concurrency())}; + folly::available_concurrency())}; std::shared_ptr queryCtx_{ core::QueryCtx::create(executor_.get())}; std::unique_ptr execCtx_{ diff --git a/velox/exec/tests/WindowFunctionRegistryTest.cpp b/velox/exec/tests/WindowFunctionRegistryTest.cpp index f78b18a0502..5fab67f5de0 100644 --- a/velox/exec/tests/WindowFunctionRegistryTest.cpp +++ b/velox/exec/tests/WindowFunctionRegistryTest.cpp @@ -15,13 +15,19 @@ */ #include +#include "velox/common/base/tests/GTestUtils.h" #include "velox/exec/WindowFunction.h" +#include "velox/exec/tests/AggregateRegistryTestUtil.h" #include "velox/expression/SignatureBinder.h" #include "velox/functions/prestosql/window/WindowFunctionsRegistration.h" namespace facebook::velox::exec::test { namespace { +std::vector nullTypes(size_t n) { + return std::vector(n, nullptr); +} + void registerWindowFunction(const std::string& name) { std::vector signatures{ exec::FunctionSignatureBuilder() @@ -50,6 +56,7 @@ class WindowFunctionRegistryTest : public testing::Test { WindowFunctionRegistryTest() { registerWindowFunction("window_func"); registerWindowFunction("window_Func_Alias"); + registerAggregateFunc("agg_func"); } TypePtr resolveWindowFunction( @@ -57,7 +64,7 @@ class WindowFunctionRegistryTest : public testing::Test { const std::vector& argTypes) { if (auto windowFunctionSignatures = getWindowFunctionSignatures(name)) { for (const auto& signature : windowFunctionSignatures.value()) { - SignatureBinder binder(*signature, argTypes); + SignatureBinder binder(*signature, argTypes, TypeCoercer::defaults()); if (binder.tryBind()) { return binder.tryResolveReturnType(); } @@ -78,6 +85,35 @@ class WindowFunctionRegistryTest : public testing::Test { EXPECT_EQ(actualType, nullptr); } } + + void testNoCoercions( + const std::string& name, + const std::vector& argTypes, + const TypePtr& expectedReturnType) { + VELOX_EXPECT_EQ_TYPES( + resolveWindowResultType(name, argTypes), expectedReturnType); + + std::vector coercions; + auto type = resolveWindowResultTypeWithCoercions( + name, argTypes, coercions, TypeCoercer::defaults()); + VELOX_EXPECT_EQ_TYPES(type, expectedReturnType); + EXPECT_EQ(coercions, nullTypes(argTypes.size())); + } + + void testCoercions( + const std::string& name, + const std::vector& argTypes, + const TypePtr& expectedReturnType, + const std::vector& expectedCoercions) { + std::vector coercions; + auto type = resolveWindowResultTypeWithCoercions( + name, argTypes, coercions, TypeCoercer::defaults()); + VELOX_EXPECT_EQ_TYPES(type, expectedReturnType); + EXPECT_EQ(coercions.size(), expectedCoercions.size()); + for (auto i = 0; i < coercions.size(); ++i) { + VELOX_EXPECT_EQ_TYPES(coercions[i], expectedCoercions[i]); + } + } }; TEST_F(WindowFunctionRegistryTest, basic) { @@ -134,4 +170,81 @@ TEST_F(WindowFunctionRegistryTest, prefix) { } } +TEST_F(WindowFunctionRegistryTest, resolveResultType) { + testNoCoercions("window_func", {BIGINT(), DOUBLE()}, BIGINT()); + testNoCoercions("window_func", {DOUBLE(), DOUBLE()}, DOUBLE()); + testNoCoercions("window_func", {}, DATE()); + + // Aggregate function registered as a window function. + testNoCoercions("agg_func", {BIGINT(), DOUBLE()}, BIGINT()); + testNoCoercions("agg_func", {}, DATE()); +} + +TEST_F(WindowFunctionRegistryTest, resolveResultTypeErrors) { + // Wrong function name. + VELOX_ASSERT_THROW( + resolveWindowResultType("nonexistent_func", {BIGINT(), DOUBLE()}), + "Window function not registered"); + + // Wrong signature for a window function. + VELOX_ASSERT_THROW( + resolveWindowResultType("window_func", {BIGINT()}), + "Window function signature is not supported"); + + // resolveWindowResultTypeWithCoercions: wrong function name. + { + std::vector coercions; + VELOX_ASSERT_THROW( + resolveWindowResultTypeWithCoercions( + "nonexistent_func", + {BIGINT(), DOUBLE()}, + coercions, + TypeCoercer::defaults()), + "Window function not registered"); + } + + // resolveWindowResultTypeWithCoercions: wrong signature. + { + std::vector coercions; + VELOX_ASSERT_THROW( + resolveWindowResultTypeWithCoercions( + "window_func", {VARCHAR()}, coercions, TypeCoercer::defaults()), + "Window function signature is not supported"); + } + + // resolveWindowResultTypeWithCoercions: correct name and arg count, but + // incompatible types (VARCHAR cannot be coerced to match any signature). + { + std::vector coercions; + VELOX_ASSERT_THROW( + resolveWindowResultTypeWithCoercions( + "window_func", + {VARCHAR(), BIGINT()}, + coercions, + TypeCoercer::defaults()), + "Window function signature is not supported"); + } +} + +TEST_F(WindowFunctionRegistryTest, resolveResultTypeWithCoercions) { + // Exact match: no coercions needed. + testNoCoercions("window_func", {BIGINT(), DOUBLE()}, BIGINT()); + + // Window function signature with coercion: (double, bigint) doesn't + // exactly match (bigint, double) or (T, T), but (T, T) matches with + // coercion. + testCoercions( + "window_func", {DOUBLE(), BIGINT()}, DOUBLE(), {nullptr, DOUBLE()}); + + // Coercion with smaller integer types. + testCoercions( + "window_func", {TINYINT(), BIGINT()}, BIGINT(), {BIGINT(), nullptr}); + + // Aggregate function registered as a window function. + testNoCoercions("agg_func", {BIGINT(), DOUBLE()}, BIGINT()); + + testCoercions( + "agg_func", {DOUBLE(), BIGINT()}, DOUBLE(), {nullptr, DOUBLE()}); +} + } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/WindowTest.cpp b/velox/exec/tests/WindowTest.cpp index f1399a7af17..64e4ae543ef 100644 --- a/velox/exec/tests/WindowTest.cpp +++ b/velox/exec/tests/WindowTest.cpp @@ -14,9 +14,11 @@ * limitations under the License. */ #include "velox/exec/Window.h" +#include #include "velox/common/base/Exceptions.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/FileSystems.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/exec/OrderBy.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/RowsStreamingWindowBuild.h" @@ -24,13 +26,12 @@ #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/OperatorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/functions/prestosql/window/WindowFunctionsRegistration.h" using namespace facebook::velox::exec::test; namespace facebook::velox::exec { - +using namespace facebook::velox::common::testutil; namespace { class WindowTest : public OperatorTestBase { @@ -63,12 +64,13 @@ class WindowTest : public OperatorTestBase { 0, 0, "none", + 0, prefixSortConfig); } const std::shared_ptr executor_{ std::make_shared( - std::thread::hardware_concurrency())}; + folly::available_concurrency())}; tsan_atomic nonReclaimableSection_{false}; }; @@ -115,6 +117,112 @@ TEST_F(WindowTest, spill) { ASSERT_GT(stats.spilledPartitions, 0); } +TEST_F(WindowTest, spillBatchReadTinyPartitions) { + const vector_size_t size = 1'000; + const uint32_t minReadBatchRows = 100; + // Each tiny partition has 1 row. + const uint32_t partitionRows = 1; + auto data = makeRowVector( + {"d", "p", "s"}, + { + // Payload. + makeFlatVector(size, [](auto row) { return row; }), + // Partition key. + makeFlatVector( + size, [](auto row) { return row / partitionRows; }), + // Sorting key. + makeFlatVector(size, [](auto row) { return row; }), + }); + + createDuckDbTable({data}); + + core::PlanNodeId windowId; + auto plan = PlanBuilder() + .values(split(data, 10)) + .window({"row_number() over (partition by p order by s)"}) + .capturePlanNodeId(windowId) + .planNode(); + + auto spillDirectory = TempDirectoryPath::create(); + TestScopedSpillInjection scopedSpillInjection(100); + auto task = + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kPreferredOutputBatchBytes, "1024") + .config(core::QueryConfig::kSpillEnabled, "true") + .config(core::QueryConfig::kWindowSpillEnabled, "true") + .config( + core::QueryConfig::kWindowSpillMinReadBatchRows, minReadBatchRows) + .spillDirectory(spillDirectory->getPath()) + .assertResults( + "SELECT *, row_number() over (partition by p order by s) FROM tmp"); + + auto taskStats = exec::toPlanStats(task->taskStats()); + const auto& stats = taskStats.at(windowId); + + ASSERT_GT(stats.spilledBytes, 0); + ASSERT_GT(stats.spilledRows, 0); + ASSERT_GT(stats.spilledFiles, 0); + ASSERT_GT(stats.spilledPartitions, 0); + ASSERT_EQ( + stats.operatorStats.at("Window") + ->customStats[std::string(Window::kWindowSpillReadNumBatches)] + .sum, + size / minReadBatchRows); +} + +TEST_F(WindowTest, spillBatchReadHugePartitions) { + const vector_size_t size = 1'000; + const uint32_t minReadBatchRows = 100; + // Each huge partition has 200 rows, which is larger than minReadBatchRows. + const uint32_t partitionRows = 200; + auto data = makeRowVector( + {"d", "p", "s"}, + { + // Payload. + makeFlatVector(size, [](auto row) { return row; }), + // Partition key. + makeFlatVector( + size, [](auto row) { return row / partitionRows; }), + // Sorting key. + makeFlatVector(size, [](auto row) { return row; }), + }); + + createDuckDbTable({data}); + + core::PlanNodeId windowId; + auto plan = PlanBuilder() + .values(split(data, 10)) + .window({"row_number() over (partition by p order by s)"}) + .capturePlanNodeId(windowId) + .planNode(); + + auto spillDirectory = TempDirectoryPath::create(); + TestScopedSpillInjection scopedSpillInjection(100); + auto task = + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kPreferredOutputBatchBytes, "1024") + .config(core::QueryConfig::kSpillEnabled, "true") + .config(core::QueryConfig::kWindowSpillEnabled, "true") + .config( + core::QueryConfig::kWindowSpillMinReadBatchRows, minReadBatchRows) + .spillDirectory(spillDirectory->getPath()) + .assertResults( + "SELECT *, row_number() over (partition by p order by s) FROM tmp"); + + auto taskStats = exec::toPlanStats(task->taskStats()); + const auto& stats = taskStats.at(windowId); + + ASSERT_GT(stats.spilledBytes, 0); + ASSERT_GT(stats.spilledRows, 0); + ASSERT_GT(stats.spilledFiles, 0); + ASSERT_GT(stats.spilledPartitions, 0); + ASSERT_EQ( + stats.operatorStats.at("Window") + ->customStats[std::string(Window::kWindowSpillReadNumBatches)] + .sum, + size / partitionRows); +} + TEST_F(WindowTest, spillUnsupported) { const vector_size_t size = 1'000; auto data = makeRowVector( @@ -156,7 +264,10 @@ TEST_F(WindowTest, spillUnsupported) { ASSERT_EQ(stats.spilledPartitions, 0); auto opStats = toOperatorStats(task->taskStats()); ASSERT_GT( - opStats.at("Window").runtimeStats[Operator::kSpillNotSupported].sum, 1); + opStats.at("Window") + .runtimeStats[std::string(Operator::kSpillNotSupported)] + .sum, + 1); } TEST_F(WindowTest, rowBasedStreamingWindowOOM) { @@ -180,10 +291,11 @@ TEST_F(WindowTest, rowBasedStreamingWindowOOM) { auto planNodeIdGenerator = std::make_shared(); CursorParameters params; auto queryCtx = core::QueryCtx::create(executor_.get()); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), - 8'388'608 /* 8MB */, - exec::MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), + 8'388'608 /* 8MB */, + exec::MemoryReclaimer::create())); params.queryCtx = queryCtx; @@ -325,6 +437,122 @@ DEBUG_ONLY_TEST_F(WindowTest, valuesRowsStreamingWindowBuild) { ASSERT_TRUE(isStreamCreated.load()); } +TEST_F(WindowTest, prePartitionedSortBuild) { + const vector_size_t size = 1'000; + const int numPartitions = 37; + const int numSubPartitions = 4; + auto data = makeRowVector( + {"p", "s"}, + { + // Partition key. + makeFlatVector( + size, [](auto row) { return row % numPartitions; }), + // Sorting key. + makeFlatVector(size, [](auto row) { return row; }), + }); + + createDuckDbTable({data}); + + core::PlanNodeId windowId; + auto plan = + PlanBuilder() + .values(split(data, 10)) + .window({"row_number() over (partition by p order by s desc)"}) + .capturePlanNodeId(windowId) + .planNode(); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kPreferredOutputBatchBytes, "1024") + .config( + core::QueryConfig::kWindowNumSubPartitions, + std::to_string(numSubPartitions)) + .assertResults( + "SELECT *, row_number() over (partition by p order by s desc) FROM tmp ORDER BY s"); +} + +TEST_F(WindowTest, prePartitionedSortBuildSkewed) { + const vector_size_t size = 1'000; + const int numPartitions = 4; + const int numSubPartitions = 16; + auto data = makeRowVector( + {"p", "s"}, + { + // Partition key. + makeFlatVector( + size, [](auto row) { return row % numPartitions; }), + // Sorting key. + makeFlatVector(size, [](auto row) { return row; }), + }); + + createDuckDbTable({data}); + + core::PlanNodeId windowId; + auto plan = + PlanBuilder() + .values(split(data, 10)) + .window({"row_number() over (partition by p order by s desc)"}) + .capturePlanNodeId(windowId) + .planNode(); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kPreferredOutputBatchBytes, "1024") + .config( + core::QueryConfig::kWindowNumSubPartitions, + std::to_string(numSubPartitions)) + .assertResults( + "SELECT *, row_number() over (partition by p order by s desc) FROM tmp ORDER BY s"); +} + +TEST_F(WindowTest, prePartitionedBuildWithSpill) { + const vector_size_t size = 1'000; + const int numPartitions = 37; + const int numSubPartitions = 4; + auto data = makeRowVector( + {"d", "p", "s"}, + { + // Payload. + makeFlatVector(size, [](auto row) { return row; }), + // Partition key. + makeFlatVector( + size, [](auto row) { return row % numPartitions; }), + // Sorting key. + makeFlatVector(size, [](auto row) { return row; }), + }); + + createDuckDbTable({data}); + + core::PlanNodeId windowId; + auto plan = + PlanBuilder() + .values(split(data, 10)) + .window({"row_number() over (partition by p order by s desc)"}) + .capturePlanNodeId(windowId) + .planNode(); + + auto spillDirectory = TempDirectoryPath::create(); + TestScopedSpillInjection scopedSpillInjection(100); + auto task = + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kPreferredOutputBatchBytes, "1024") + .config( + core::QueryConfig::kWindowNumSubPartitions, + std::to_string(numSubPartitions)) + .config(core::QueryConfig::kSpillEnabled, "true") + .config(core::QueryConfig::kWindowSpillEnabled, "true") + .config(core::QueryConfig::kOrderBySpillEnabled, "false") + .spillDirectory(spillDirectory->getPath()) + .assertResults( + "SELECT *, row_number() over (partition by p order by s desc) FROM tmp ORDER BY s"); + + auto taskStats = exec::toPlanStats(task->taskStats()); + const auto& stats = taskStats.at(windowId); + + ASSERT_GT(stats.spilledBytes, 0); + ASSERT_GT(stats.spilledRows, 0); + ASSERT_GT(stats.spilledFiles, 0); + ASSERT_GT(stats.spilledPartitions, 0); +} + DEBUG_ONLY_TEST_F(WindowTest, aggregationWithNonDefaultFrame) { const vector_size_t size = 1'00; @@ -552,10 +780,11 @@ TEST_F(WindowTest, nagativeFrameArg) { .planNode(); VELOX_ASSERT_USER_THROW( AssertQueryBuilder(plan, duckDbQueryRunner_) - .assertResults(fmt::format( - "SELECT *, regr_count(c0, c1) over (partition by p0, p1 order by row_number ROWS between {} PRECEDING and {} FOLLOWING) from tmp", - startOffset, - endOffset)), + .assertResults( + fmt::format( + "SELECT *, regr_count(c0, c1) over (partition by p0, p1 order by row_number ROWS between {} PRECEDING and {} FOLLOWING) from tmp", + startOffset, + endOffset)), testData.debugString()); } } @@ -645,25 +874,28 @@ DEBUG_ONLY_TEST_F(WindowTest, reserveMemorySort) { for (const auto [usePrefixSort, spillEnabled, enableSpillPrefixSort] : testSettings) { - SCOPED_TRACE(fmt::format( - "usePrefixSort: {}, spillEnabled: {}, enableSpillPrefixSort: {}", - usePrefixSort, - spillEnabled, - enableSpillPrefixSort)); - auto spillDirectory = exec::test::TempDirectoryPath::create(); + SCOPED_TRACE( + fmt::format( + "usePrefixSort: {}, spillEnabled: {}, enableSpillPrefixSort: {}", + usePrefixSort, + spillEnabled, + enableSpillPrefixSort)); + auto spillDirectory = TempDirectoryPath::create(); auto spillConfig = getSpillConfig(spillDirectory->getPath(), enableSpillPrefixSort); - folly::Synchronized spillStats; + exec::SpillStats spillStats; const auto plan = usePrefixSort ? prefixSortPlan : nonPrefixSortPlan; velox::common::PrefixSortConfig prefixSortConfig = velox::common::PrefixSortConfig{ std::numeric_limits::max(), 130, 12}; + folly::Synchronized opStats; auto sortWindowBuild = std::make_unique( plan, pool_.get(), std::move(prefixSortConfig), spillEnabled ? &spillConfig : nullptr, &nonReclaimableSection_, + &opStats, &spillStats); TestScopedSpillInjection scopedSpillInjection(0); @@ -713,18 +945,20 @@ TEST_F(WindowTest, NaNFrameBound) { if (startBound == "following" && endBound == "preceding") { continue; } - frames.push_back(fmt::format( - "{} over (order by s0 {} range between off0 {} and off1 {})", - call, - order, - startBound, - endBound)); - frames.push_back(fmt::format( - "{} over (order by s0 {} range between off1 {} and off0 {})", - call, - order, - startBound, - endBound)); + frames.push_back( + fmt::format( + "{} over (order by s0 {} range between off0 {} and off1 {})", + call, + order, + startBound, + endBound)); + frames.push_back( + fmt::format( + "{} over (order by s0 {} range between off1 {} and off0 {})", + call, + order, + startBound, + endBound)); } } } diff --git a/velox/exec/tests/WriterFuzzerUtilTest.cpp b/velox/exec/tests/WriterFuzzerUtilTest.cpp index d0b765dd813..93ebe8a5c29 100644 --- a/velox/exec/tests/WriterFuzzerUtilTest.cpp +++ b/velox/exec/tests/WriterFuzzerUtilTest.cpp @@ -16,15 +16,16 @@ #include #include "velox/common/file/FileSystems.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/exec/fuzzer/WriterFuzzer.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" using namespace facebook::velox; using namespace facebook::velox::exec::test; +using namespace facebook::velox::common::testutil; TEST(WriterFuzzerUtilTest, listFolders) { facebook::velox::filesystems::registerLocalFileSystem(); - const auto tempFolder = exec::test::TempDirectoryPath::create(); + const auto tempFolder = TempDirectoryPath::create(); // Directory layout: // First layer: dir1/ dir2/ dir3/ a // Second layer: b dir2_1/ diff --git a/velox/exec/tests/data/decimal.orc b/velox/exec/tests/data/decimal.orc new file mode 100644 index 00000000000..b89662b65f6 Binary files /dev/null and b/velox/exec/tests/data/decimal.orc differ diff --git a/velox/exec/tests/utils/AggregationResolver.cpp b/velox/exec/tests/utils/AggregationResolver.cpp index 44cc0058b14..89ab458d199 100644 --- a/velox/exec/tests/utils/AggregationResolver.cpp +++ b/velox/exec/tests/utils/AggregationResolver.cpp @@ -22,15 +22,13 @@ namespace facebook::velox::exec::test { namespace { std::string throwAggregateFunctionDoesntExist(const std::string& name) { - std::stringstream error; - error << "Aggregate function doesn't exist: " << name << "."; exec::aggregateFunctions().withRLock([&](const auto& functionsMap) { if (functionsMap.empty()) { - error << " Registry of aggregate functions is empty. " - "Make sure to register some aggregate functions."; + VELOX_USER_FAIL( + "Registry of aggregate functions is empty. Make sure to register some aggregate functions."); } }); - VELOX_USER_FAIL(error.str()); + VELOX_USER_FAIL("Aggregate function doesn't exist: {}.", name); } std::string throwAggregateFunctionSignatureNotSupported( @@ -38,11 +36,10 @@ std::string throwAggregateFunctionSignatureNotSupported( const std::vector& types, const std::vector>& signatures) { - std::stringstream error; - error << "Aggregate function signature is not supported: " - << toString(name, types) - << ". Supported signatures: " << toString(signatures) << "."; - VELOX_USER_FAIL(error.str()); + VELOX_USER_FAIL( + "Aggregate function signature is not supported: {}. Supported signatures: {}.", + toString(name, types), + toString(signatures)); } } // namespace @@ -53,7 +50,8 @@ TypePtr resolveAggregateType( bool nullOnFailure) { if (auto signatures = exec::getAggregateFunctionSignatures(aggregateName)) { for (const auto& signature : signatures.value()) { - exec::SignatureBinder binder(*signature, rawInputTypes); + exec::SignatureBinder binder( + *signature, rawInputTypes, TypeCoercer::defaults()); if (binder.tryBind()) { return binder.tryResolveType( exec::isPartialOutput(step) ? signature->intermediateType() diff --git a/velox/exec/tests/utils/ArbitratorTestUtil.cpp b/velox/exec/tests/utils/ArbitratorTestUtil.cpp index 8c3bb9dc8ac..f6fbc9862ec 100644 --- a/velox/exec/tests/utils/ArbitratorTestUtil.cpp +++ b/velox/exec/tests/utils/ArbitratorTestUtil.cpp @@ -19,10 +19,8 @@ #include "velox/dwio/dwrf/common/Config.h" #include "velox/exec/TableWriter.h" -using namespace facebook::velox; -using namespace facebook::velox::exec; -using namespace facebook::velox::exec::test; using namespace facebook::velox::memory; +using namespace facebook::velox::common::testutil; namespace facebook::velox::exec::test { @@ -31,18 +29,11 @@ std::shared_ptr newQueryCtx( folly::Executor* executor, int64_t memoryCapacity, const std::string& queryId) { - std::unordered_map> configs; - std::shared_ptr pool = - memoryManager->addRootPool("", memoryCapacity); - auto queryCtx = core::QueryCtx::create( - executor, - core::QueryConfig({}), - configs, - cache::AsyncDataCache::getInstance(), - std::move(pool), - nullptr, - queryId); - return queryCtx; + return core::QueryCtx::Builder() + .executor(executor) + .pool(memoryManager->addRootPool("", memoryCapacity)) + .queryId(queryId) + .build(); } std::unique_ptr createMemoryManager( @@ -104,7 +95,7 @@ QueryTestResult runHashJoinTask( QueryTestResult result; const auto plan = hashJoinPlan(vectors, result.planNodeId); if (enableSpilling) { - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); result.data = AssertQueryBuilder(plan) .serialExecution(serialExecution) .spillDirectory(spillDirectory->getPath()) @@ -149,7 +140,7 @@ QueryTestResult runAggregateTask( QueryTestResult result; const auto plan = aggregationPlan(vectors, result.planNodeId); if (enableSpilling) { - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); result.data = AssertQueryBuilder(plan) .serialExecution(serialExecution) @@ -195,7 +186,7 @@ QueryTestResult runOrderByTask( QueryTestResult result; const auto plan = orderByPlan(vectors, result.planNodeId); if (enableSpilling) { - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); result.data = AssertQueryBuilder(plan) .serialExecution(serialExecution) .spillDirectory(spillDirectory->getPath()) @@ -240,7 +231,7 @@ QueryTestResult runRowNumberTask( QueryTestResult result; const auto plan = rowNumberPlan(vectors, result.planNodeId); if (enableSpilling) { - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); result.data = AssertQueryBuilder(plan) .serialExecution(serialExecution) .spillDirectory(spillDirectory->getPath()) @@ -285,7 +276,7 @@ QueryTestResult runTopNTask( QueryTestResult result; const auto plan = topNPlan(vectors, result.planNodeId); if (enableSpilling) { - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); result.data = AssertQueryBuilder(plan) .serialExecution(serialExecution) @@ -335,7 +326,7 @@ QueryTestResult runWriteTask( const auto outputDirectory = TempDirectoryPath::create(); auto plan = writePlan(vectors, outputDirectory->getPath(), result.planNodeId); if (enableSpilling) { - const auto spillDirectory = exec::test::TempDirectoryPath::create(); + const auto spillDirectory = TempDirectoryPath::create(); result.data = AssertQueryBuilder(plan) .serialExecution(serialExecution) diff --git a/velox/exec/tests/utils/ArbitratorTestUtil.h b/velox/exec/tests/utils/ArbitratorTestUtil.h index 8021db65608..85024deb800 100644 --- a/velox/exec/tests/utils/ArbitratorTestUtil.h +++ b/velox/exec/tests/utils/ArbitratorTestUtil.h @@ -19,15 +19,17 @@ #include #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/memory/MemoryPool.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/exec/Driver.h" #include "velox/exec/MemoryReclaimer.h" #include "velox/exec/Task.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" namespace facebook::velox::exec::test { +using TempDirectoryPath = common::testutil::TempDirectoryPath; + constexpr int64_t KB = 1024L; constexpr int64_t MB = 1024L * KB; @@ -112,7 +114,7 @@ std::shared_ptr newQueryCtx( std::unique_ptr createMemoryManager( int64_t arbitratorCapacity = kMemoryCapacity, uint64_t memoryPoolInitCapacity = kMemoryPoolInitCapacity, - uint64_t maxReclaimWaitMs = 5 * 60 * 1'000, + uint64_t maxReclaimWaitMs = 60 * 1'000, uint64_t fastExponentialGrowthCapacityLimit = 0, double slowCapacityGrowPct = 0); diff --git a/velox/exec/tests/utils/AssertQueryBuilder.cpp b/velox/exec/tests/utils/AssertQueryBuilder.cpp index 1b68b77cf71..ff443ae1b40 100644 --- a/velox/exec/tests/utils/AssertQueryBuilder.cpp +++ b/velox/exec/tests/utils/AssertQueryBuilder.cpp @@ -266,7 +266,26 @@ RowVectorPtr AssertQueryBuilder::copyResults( return copy; } -uint64_t AssertQueryBuilder::runWithoutResults(std::shared_ptr& task) { +std::vector AssertQueryBuilder::copyResultBatches( + memory::MemoryPool* pool) { + auto [cursor, results] = readCursor(); + + if (results.empty()) { + return results; + } + + std::vector copies; + copies.reserve(results.size()); + for (const auto& result : results) { + copies.push_back( + BaseVector::create(result->type(), result->size(), pool)); + copies.back()->copy(result.get(), 0, 0, result->size()); + } + + return copies; +} + +uint64_t AssertQueryBuilder::countResults(std::shared_ptr& task) { auto [cursor, results] = readCursor(); uint64_t count = 0; for (const auto& result : results) { @@ -276,6 +295,11 @@ uint64_t AssertQueryBuilder::runWithoutResults(std::shared_ptr& task) { return count; } +uint64_t AssertQueryBuilder::countResults() { + std::shared_ptr task; + return countResults(task); +} + std::pair, std::vector> AssertQueryBuilder::readCursor() { VELOX_CHECK_NOT_NULL(params_.planNode); @@ -287,17 +311,14 @@ AssertQueryBuilder::readCursor() { static std::atomic cursorQueryId{0}; const std::string queryId = fmt::format("TaskCursorQuery_{}", cursorQueryId++); - auto queryPool = memory::memoryManager()->addRootPool( - queryId, params_.maxQueryCapacity); - params_.queryCtx = core::QueryCtx::create( - executor_.get(), - core::QueryConfig({}), - std:: - unordered_map>{}, - cache::AsyncDataCache::getInstance(), - std::move(queryPool), - nullptr, - queryId); + + params_.queryCtx = core::QueryCtx::Builder() + .executor(executor_.get()) + .pool( + memory::memoryManager()->addRootPool( + queryId, params_.maxQueryCapacity)) + .queryId(queryId) + .build(); } } if (!configs_.empty()) { @@ -310,52 +331,55 @@ AssertQueryBuilder::readCursor() { } } - return test::readCursor(params_, [&](exec::TaskCursor* taskCursor) { - if (taskCursor->noMoreSplits()) { - return; - } - auto& task = taskCursor->task(); - VELOX_CHECK(!params_.barrierExecution || params_.serialExecution); - if (params_.barrierExecution) { - int numSplits{0}; - for (auto& [nodeId, nodeSplits] : splits_) { - if (nodeSplits.empty()) { - task->noMoreSplits(nodeId); - continue; + return test::readCursorAsync( + params_, + [&](exec::TaskCursor* taskCursor) { + if (taskCursor->noMoreSplits()) { + return ContinueFuture::makeEmpty(); } - ++numSplits; - if (addSplitWithSequence_) { - task->addSplitWithSequence( - nodeId, std::move(nodeSplits[0]), ++sequenceId_); - task->setMaxSplitSequenceId(nodeId, sequenceId_); + auto& task = taskCursor->task(); + if (params_.barrierExecution) { + int numSplits{0}; + for (auto& [nodeId, nodeSplits] : splits_) { + if (nodeSplits.empty()) { + task->noMoreSplits(nodeId); + continue; + } + ++numSplits; + if (addSplitWithSequence_) { + task->addSplitWithSequence( + nodeId, std::move(nodeSplits[0]), ++sequenceId_); + task->setMaxSplitSequenceId(nodeId, sequenceId_); + } else { + task->addSplit(nodeId, std::move(nodeSplits[0])); + } + nodeSplits.erase(nodeSplits.cbegin()); + } + if (numSplits > 0) { + VELOX_CHECK_EQ( + numSplits, + splits_.size(), + "Barrier task execution mode requires all the sources have the same number of splits"); + return task->requestBarrier(); + } + taskCursor->setNoMoreSplits(); } else { - task->addSplit(nodeId, std::move(nodeSplits[0])); - } - nodeSplits.erase(nodeSplits.begin()); - } - if (numSplits > 0) { - VELOX_CHECK_EQ( - numSplits, - splits_.size(), - "Barrier task execution mode requires all the sources have the same number of splits"); - task->requestBarrier(); - } else { - taskCursor->setNoMoreSplits(); - } - } else { - for (auto& [nodeId, nodeSplits] : splits_) { - for (auto& split : nodeSplits) { - if (addSplitWithSequence_) { - task->addSplitWithSequence(nodeId, std::move(split), ++sequenceId_); - task->setMaxSplitSequenceId(nodeId, sequenceId_); - } else { - task->addSplit(nodeId, std::move(split)); + for (auto& [nodeId, nodeSplits] : splits_) { + for (auto& split : nodeSplits) { + if (addSplitWithSequence_) { + task->addSplitWithSequence( + nodeId, std::move(split), ++sequenceId_); + task->setMaxSplitSequenceId(nodeId, sequenceId_); + } else { + task->addSplit(nodeId, std::move(split)); + } + } + task->noMoreSplits(nodeId); } + taskCursor->setNoMoreSplits(); } - task->noMoreSplits(nodeId); - } - taskCursor->setNoMoreSplits(); - } - }); + return ContinueFuture::makeEmpty(); + }, + maxWaitMicros_); } } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/utils/AssertQueryBuilder.h b/velox/exec/tests/utils/AssertQueryBuilder.h index 8e9f15d9831..b96b0fda74f 100644 --- a/velox/exec/tests/utils/AssertQueryBuilder.h +++ b/velox/exec/tests/utils/AssertQueryBuilder.h @@ -15,6 +15,7 @@ */ #pragma once +#include #include "velox/exec/tests/utils/QueryAssertions.h" namespace facebook::velox::exec::test { @@ -127,6 +128,13 @@ class AssertQueryBuilder { return *this; } + /// Set the maximum time to wait for task completion after all results have + /// been consumed. Default is 5 seconds. + AssertQueryBuilder& maxWaitMicros(uint64_t maxWaitMicros) { + maxWaitMicros_ = maxWaitMicros; + return *this; + } + /// Spilling directory, if not empty, then the task's spilling directory would /// be built from it. AssertQueryBuilder& spillDirectory(const std::string& dir) { @@ -182,8 +190,7 @@ class AssertQueryBuilder { const TypePtr& expectedType, vector_size_t expectedNumRows); - /// Run the query and collect all results into a single vector. Throws if - /// query returns empty result. + /// Run the query and collect all results into a single vector. RowVectorPtr copyResults(memory::MemoryPool* pool); /// Similar to above method and also returns the task. @@ -191,8 +198,15 @@ class AssertQueryBuilder { memory::MemoryPool* pool, std::shared_ptr& task); + /// Run the query and copy the result Vectors as their original batches. + std::vector copyResultBatches(memory::MemoryPool* pool); + /// Run the query and return the number of result rows. - uint64_t runWithoutResults(std::shared_ptr& task); + uint64_t countResults(std::shared_ptr& task); + + /// Run the query and return the number of result rows without requiring a + /// task parameter. + uint64_t countResults(); private: std::pair, std::vector> @@ -200,7 +214,7 @@ class AssertQueryBuilder { static std::unique_ptr newExecutor() { return std::make_unique( - std::thread::hardware_concurrency()); + folly::available_concurrency()); } // Used by the created task as the default driver executor. @@ -214,6 +228,9 @@ class AssertQueryBuilder { bool addSplitWithSequence_{false}; // The sequence Id to be used when addSplitWithSequence_ is true. int32_t sequenceId_{0}; + // Maximum time in microseconds to wait for task completion after all results + // have been consumed. + uint64_t maxWaitMicros_{5'000'000}; }; } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/utils/CMakeLists.txt b/velox/exec/tests/utils/CMakeLists.txt index cf3ef9dbb10..9174c761628 100644 --- a/velox/exec/tests/utils/CMakeLists.txt +++ b/velox/exec/tests/utils/CMakeLists.txt @@ -12,10 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_library(velox_temp_path TempFilePath.cpp TempDirectoryPath.cpp) - -target_link_libraries(velox_temp_path velox_exception) - add_library( velox_exec_test_lib AggregationResolver.cpp @@ -33,16 +29,44 @@ add_library( TableScanTestBase.cpp TestIndexStorageConnector.cpp TpchQueryBuilder.cpp + TpcdsQueryBuilder.cpp + VeloxPlanLoader.cpp VectorTestUtil.cpp PortUtil.cpp SerializedPageUtil.cpp ) +velox_add_test_headers( + velox_exec_test_lib + AggregationResolver.h + ArbitratorTestUtil.h + AssertQueryBuilder.h + ExpressionBuilder.h + FilterToExpression.h + HashJoinTestBase.h + HiveConnectorTestBase.h + IndexLookupJoinTestBase.h + LocalExchangeSource.h + MergeTestBase.h + OperatorTestBase.h + PlanBuilder.h + PortUtil.h + QueryAssertions.h + RowContainerTestBase.h + SerializedPageUtil.h + SumNonPODAggregate.h + TableScanTestBase.h + TableWriterTestBase.h + TestIndexStorageConnector.h + TpchQueryBuilder.h + TpcdsQueryBuilder.h + VectorTestUtil.h + VeloxPlanLoader.h +) target_link_libraries( velox_exec_test_lib velox_vector_test_lib velox_vector_fuzzer - velox_temp_path velox_cursor velox_core velox_exception @@ -63,3 +87,5 @@ target_link_libraries( velox_functions_prestosql velox_aggregates ) + +velox_add_library(velox_temp_path INTERFACE HEADERS TempDirectoryPath.h TempFilePath.h) diff --git a/velox/exec/tests/utils/ExpressionBuilder.h b/velox/exec/tests/utils/ExpressionBuilder.h new file mode 100644 index 00000000000..d2d208b2964 --- /dev/null +++ b/velox/exec/tests/utils/ExpressionBuilder.h @@ -0,0 +1,478 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/parse/Expressions.h" + +namespace facebook::velox::expr_builder { + +/// Fluent Expression Builder. +/// +/// This file contains fluent methods that make it convenient to create +/// (untyped) expression trees. This provides similar functionality to a SQL +/// parser, without bringing dependency on external libraries or bringing leaked +/// semantics from other systems. +/// +/// The untyped expressions can then be turned into typed expressions ready for +/// execution using type binding from `core::Expressions::inferTypes()`. +/// +/// The API provided is as close to the actual expression trees as possible. +/// Comparisons, arithmetics, conjuncts, function calls, literals, aliases, and +/// more are supported with this API. +/// +/// For example, to create a column reference, you can: +/// +/// > using namespace expr_builder; +/// > core::ExprPtr e = col("c0"); +/// +/// You can also use the "_c" C++ literal provided: +/// +/// > core::ExprPtr e = "c0"_c; +/// +/// Expressions created using ExpressionBuilder functions can be used in any +/// places that accept a ExprPtr. In practice, they create a ExprWrapper object, +/// but ExprWrappers are implicitly convertible to ExprSet. +/// +/// Nested column references can be specified as either: +/// +/// > col("parent", "child"); +/// > col("parent").subfield("child"); +/// +/// To debug the expression generated, you can simply: +/// +/// > LOG(INFO) << *e; +/// +/// Comparisons and other expressions can be fluently created using C++ +/// overloaded operators: +/// +/// > col("c") > 10; // "c > 10" +/// > col("c") != "bar"; // "c != 'bar'" +/// > col("c") == nullptr; // "c = null" +/// +/// C++ literals are automatically converted into ConstantExpr (expression +/// literals) when part of an expression. To explicitly create a literal you can +/// use: +/// +/// > lit(10.3); +/// > lit("str"); +/// +/// Casts can be done using one of the two formats: +/// +/// > lit(3).cast(TINYINT()); +/// > cast("str", VARBINARY()); +/// +/// Null checking filters: +/// +/// > isNull(col("c")) // "c is null" +/// > !isNull(col("c")) // "c is not null" +/// +/// Conjuncts and "between": +/// +/// > (col("a") && col("b")) || col("c"); // "(a AND b) OR c" +/// > between(col("a"), 0, 10); // "a between 0 and 10" +/// +/// You can also use fluent version of these APIs: +/// +/// > col("a").between(0, 10); // "a between 0 and 10" +/// +/// Arithmetic operators are also overloaded: +/// +/// > col("c") * 100 + col("b"); // "c * 100 + b" +/// +/// In any expression, as long as one of the sides is an expression node, the +/// correct expression will be created. For example, both version work as +/// expected: +/// +/// > col("c") * 100; // "c * 100" +/// > 100 * col("c"); // "100 * c" +/// +/// When building long expressions, be careful about C++ constant folding and +/// operator precedence: +/// +/// > col("c") + 5 * 100; +/// +/// C++ will fold "5 * 100" and generate the expression "c + 500". To force the +/// expected behavior, you can explicitly spell out the literal: +/// +/// > col("c") + 5 * lit(100); +/// > col("c") + lit(5) * 100; +/// +/// Both will generate "col + 5 * 100", which is "plus(col, multiply(5, 100))". +/// +/// Generic function calls can be created using `call()`: +/// +/// > call("func", 10); // "func(10)" +/// +/// `call()` supports arbitrary parameters, which can be other expressions or +/// (C++) literals. +/// +/// Lambdas can be created using the following syntax: +/// +/// > lambda({"x", "y"}, col("x") * col("y") + 1) +/// +/// Where the first parameter is a vector of the lambda arguments, and the +/// second the lambda body. +/// +/// All functions above can be nested and combined in arbitrary ways. +/// +/// > 10L * col("c1") > call("func", 3.4, col("g") / col("h"), call("j")); +/// +/// is the same as "10 * c1 > func(3.4, g / h, j())". +/// +/// Comparisons, arithmetics, and other operators are mapped to function names +/// according to the table below. It is the user's responsibility to make sure +/// that there names map to their appropriate implementation: +/// +/// ------------------------------- +/// | C++ | Function Name | +/// ------------------------------- +/// | operator== | eq | +/// | operator!= | neq | +/// | operator< | lt | +/// | operator<= | lte | +/// | operator> | gt | +/// | operator>= | gte | +/// | operator! | not | +/// | operator&& | and | +/// | operator|| | or | +/// | operator+ | plus | +/// | operator- | minus | +/// | operator* | multiply | +/// | operator/ | divide | +/// | operator% | mode | +/// | operator== | eq | +/// ------------------------------- + +namespace detail { + +class ExprWrapper; + +/// Either builds a ConstantExpr (literal) based on a scalar value, or passes +/// through an ExprWrapper already constructed. +template +inline ExprWrapper toExprWrapper(T value); + +// Specialization for long to avoid ambiguity. +inline ExprWrapper toExprWrapper(long value); + +template <> +inline ExprWrapper toExprWrapper(ExprWrapper expr); + +/// Wrapper library used so we can safely overload operators. +class ExprWrapper { + public: + ExprWrapper(const core::ExprPtr& expr) : expr_(expr) {} + + std::string toString() const { + return expr_->toString(); + } + + core::ExprPtr expr() const { + return expr_; + } + + /// Add an alias to the current expression: + /// + /// > col("c0").alias("my_column"); + ExprWrapper& alias(const std::string& newAlias) { + expr_ = expr_->withAlias(newAlias); + return *this; + } + + /// Add a "subfield" expression to enable access of subfields in + /// rows/structs: + /// + /// > col("parent_col").subfield("child_name"); + ExprWrapper& subfield(std::string childName) { + expr_ = std::make_shared( + std::move(childName), std::nullopt, std::vector{expr_}); + return *this; + } + + /// Add a "cast" to the current expression: + /// + /// > col("c0").cast(VARBINARY()); + /// > lit(10).cast(TINYINT()); + ExprWrapper& cast(const TypePtr& castType) { + expr_ = + std::make_shared(castType, expr_, false, std::nullopt); + return *this; + } + + /// Add a "try_cast" to the current expression: + /// + /// > col("c0").tryCast(VARBINARY()); + /// > lit(10).tryCast(TINYINT()); + ExprWrapper& tryCast(const TypePtr& castType) { + expr_ = + std::make_shared(castType, expr_, true, std::nullopt); + return *this; + } + + /// Add a "is_null" to the current expression: + /// + /// > col("c0").isNull(); + ExprWrapper& isNull() { + expr_ = std::make_shared( + "is_null", std::vector{expr_}, std::nullopt); + return *this; + } + + /// Add a "between" clause to the current expression wrapper: + /// + /// > col("a").between(1, 10); + template + ExprWrapper& between(const T1& value1, const T2& value2) { + expr_ = std::make_shared( + "between", + std::vector{ + expr_, + detail::toExprWrapper(value1), + detail::toExprWrapper(value2)}, + std::nullopt); + return *this; + } + + /// If equality is used against an actual ExpPtr (not the wrapper), this will + /// compare the expressions themselves. + /// + /// It won't assume this is generating a eq() Velox expression. + bool operator==(const core::ExprPtr& other) const { + return *expr_ == *other; + } + + /// Provide better gtest failure messages. + friend std::ostream& operator<<(std::ostream& os, const ExprWrapper& expr) { + return os << expr.expr_->toString(); + } + + /// For convenience, enable implicit conversions to ExprPtr. + operator core::ExprPtr() const { + return expr_; + } + + private: + core::ExprPtr expr_; +}; + +/// Unpacks a list of variadic template parameters in a +/// std::vector. The elements could be ExprWrapper or C++ +/// literals, which will get converted to ConstantExpr. +/// +/// Base of recursion. +inline std::vector unpackList() { + return {}; +} + +template +inline std::vector unpackList(TFirst first, TArgs&&... args) { + std::vector head = {toExprWrapper(first)}; + auto tail = unpackList(std::forward(args)...); + head.insert(head.end(), tail.begin(), tail.end()); + return head; +} + +} // namespace detail + +/// Column references. +inline detail::ExprWrapper col(std::string name) { + return {std::make_shared( + std::move(name), std::nullopt)}; +} + +/// Enable users to use a custom C++ literal to add a column reference. +/// For example: "col"_c +inline detail::ExprWrapper operator"" _c(const char* str, size_t len) { + return col(std::string(str, len)); +} + +/// Nested column names. Ror rows/struct member references. +inline detail::ExprWrapper col(std::string parentName, std::string childName) { + return col(std::move(parentName)).subfield(std::move(childName)); +} + +/// Literals. +inline detail::ExprWrapper lit(int64_t value) { + return {std::make_shared(BIGINT(), value, std::nullopt)}; +} + +inline detail::ExprWrapper lit(int32_t value) { + return {std::make_shared(INTEGER(), value, std::nullopt)}; +} + +inline detail::ExprWrapper lit(int16_t value) { + return { + std::make_shared(SMALLINT(), value, std::nullopt)}; +} + +inline detail::ExprWrapper lit(int8_t value) { + return {std::make_shared(TINYINT(), value, std::nullopt)}; +} + +inline detail::ExprWrapper lit(bool value) { + return {std::make_shared(BOOLEAN(), value, std::nullopt)}; +} + +inline detail::ExprWrapper lit(double value) { + return {std::make_shared(DOUBLE(), value, std::nullopt)}; +} + +inline detail::ExprWrapper lit(float value) { + return {std::make_shared(REAL(), value, std::nullopt)}; +} + +/// Different string flavors. +inline detail::ExprWrapper lit(const char* value) { + return {std::make_shared(VARCHAR(), value, std::nullopt)}; +} + +inline detail::ExprWrapper lit(const std::string_view& value) { + return {std::make_shared( + VARCHAR(), std::string(value), std::nullopt)}; +} + +inline detail::ExprWrapper lit(const std::string& value) { + return {std::make_shared(VARCHAR(), value, std::nullopt)}; +} + +/// lit(nullptr). +inline detail::ExprWrapper lit(std::nullptr_t) { + return {std::make_shared( + UNKNOWN(), variant::null(TypeKind::UNKNOWN), std::nullopt)}; +} + +/// Macro to reduce builerplate when overloading C++ operators. The template +/// magic basically means that the overload is matched if either left or right +/// operands are an ExprWrapper. This is so that both "c"_f + 10 and 10 + "c"_f +/// are supported, for example. +/// +/// If either left or right side are ExprWrapper, we either convert the other +/// side as a constant/literal, or use it as-is if it is already an ExprWrapper. +#define VELOX_EXPR_BUILDER_OPERATOR(__op, __name) \ + template \ + inline std::enable_if_t< \ + std::is_same_v || \ + std::is_same_v, \ + detail::ExprWrapper> \ + __op(T1 lhs, T2 rhs) { \ + return {std::make_shared( \ + __name, \ + std::vector{ \ + detail::toExprWrapper(lhs), detail::toExprWrapper(rhs)}, \ + std::nullopt)}; \ + } + +/// Define C++ operator overload for comparisons. +VELOX_EXPR_BUILDER_OPERATOR(operator==, "eq"); +VELOX_EXPR_BUILDER_OPERATOR(operator!=, "neq"); +VELOX_EXPR_BUILDER_OPERATOR(operator<, "lt"); +VELOX_EXPR_BUILDER_OPERATOR(operator<=, "lte"); +VELOX_EXPR_BUILDER_OPERATOR(operator>, "gt"); +VELOX_EXPR_BUILDER_OPERATOR(operator>=, "gte"); + +/// Define C++ operator overload for arithmetics. +VELOX_EXPR_BUILDER_OPERATOR(operator+, "plus"); +VELOX_EXPR_BUILDER_OPERATOR(operator-, "minus"); +VELOX_EXPR_BUILDER_OPERATOR(operator*, "multiply"); +VELOX_EXPR_BUILDER_OPERATOR(operator/, "divide"); +VELOX_EXPR_BUILDER_OPERATOR(operator%, "mod"); + +VELOX_EXPR_BUILDER_OPERATOR(operator&&, "and"); +VELOX_EXPR_BUILDER_OPERATOR(operator||, "or"); + +/// "not" is an unary operator. +inline detail::ExprWrapper operator!(detail::ExprWrapper expr) { + return {std::make_shared( + "not", std::vector{expr.expr()}, std::nullopt)}; +} + +/// "is_null" is also unary. +template +inline detail::ExprWrapper isNull(const T& expr) { + return detail::toExprWrapper(expr).isNull(); +} + +/// "alias" as a free function. +template +inline detail::ExprWrapper alias(TInput lhs, const std::string& newAlias) { + return detail::toExprWrapper(lhs).alias(newAlias); +} + +/// "cast" as a free function. +template +inline detail::ExprWrapper cast(TInput lhs, const TypePtr& castType) { + return detail::toExprWrapper(lhs).cast(castType); +} + +/// "tryCast" as a free function. +template +inline detail::ExprWrapper tryCast(TInput lhs, const TypePtr& castType) { + return detail::toExprWrapper(lhs).tryCast(castType); +} + +/// "between" as a free function. +template +inline detail::ExprWrapper +between(detail::ExprWrapper lhs, const T1& value1, const T2& value2) { + return lhs.between(value1, value2); +} + +/// Creates a lambda expressions, given the function parameters and an +/// expression for the function body. +template +inline detail::ExprWrapper lambda( + std::initializer_list args, + const TInput& body) { + return {std::make_shared( + std::move(args), detail::toExprWrapper(body))}; +} + +/// Convenience lambda builder for single argument lambdas. +template +inline detail::ExprWrapper lambda(std::string arg, const TInput& body) { + return lambda({std::move(arg)}, body); +} + +/// Regular function calls. First parameter is the function name, followed by +/// their parameters. Parameters can be other expression nodes or literals. +template +inline detail::ExprWrapper call(std::string name, TArgs&&... args) { + return {std::make_shared( + std::move(name), + detail::unpackList(std::forward(args)...), + std::nullopt)}; +} + +namespace detail { + +template +inline ExprWrapper toExprWrapper(T value) { + return lit(value); +} + +inline ExprWrapper toExprWrapper(long value) { + return lit(static_cast(value)); +} + +template <> +inline ExprWrapper toExprWrapper(ExprWrapper expr) { + return expr; +} + +} // namespace detail + +} // namespace facebook::velox::expr_builder diff --git a/velox/exec/tests/utils/FilterToExpression.cpp b/velox/exec/tests/utils/FilterToExpression.cpp index 4a73f9c9def..7546694935d 100644 --- a/velox/exec/tests/utils/FilterToExpression.cpp +++ b/velox/exec/tests/utils/FilterToExpression.cpp @@ -146,8 +146,9 @@ core::TypedExprPtr filterToExpr( subfieldType, variant(static_cast(rangeFilter->lower()))); - conditions.push_back(std::make_shared( - BOOLEAN(), "gte", subfieldExpr, lower)); + conditions.push_back( + std::make_shared( + BOOLEAN(), "gte", subfieldExpr, lower)); } if (rangeFilter->upper() < kMaxInt64) { @@ -155,8 +156,9 @@ core::TypedExprPtr filterToExpr( subfieldType, variant(static_cast(rangeFilter->upper()))); - conditions.push_back(std::make_shared( - BOOLEAN(), "lte", subfieldExpr, upper)); + conditions.push_back( + std::make_shared( + BOOLEAN(), "lte", subfieldExpr, upper)); } auto rangeExpr = createBooleanExpr(conditions); @@ -258,11 +260,13 @@ core::TypedExprPtr filterToExpr( } if (doubleFilter->lowerExclusive()) { - conditions.push_back(std::make_shared( - BOOLEAN(), "gt", subfieldExpr, lower)); + conditions.push_back( + std::make_shared( + BOOLEAN(), "gt", subfieldExpr, lower)); } else { - conditions.push_back(std::make_shared( - BOOLEAN(), "gte", subfieldExpr, lower)); + conditions.push_back( + std::make_shared( + BOOLEAN(), "gte", subfieldExpr, lower)); } } @@ -283,11 +287,13 @@ core::TypedExprPtr filterToExpr( } if (doubleFilter->upperExclusive()) { - conditions.push_back(std::make_shared( - BOOLEAN(), "lt", subfieldExpr, upper)); + conditions.push_back( + std::make_shared( + BOOLEAN(), "lt", subfieldExpr, upper)); } else { - conditions.push_back(std::make_shared( - BOOLEAN(), "lte", subfieldExpr, upper)); + conditions.push_back( + std::make_shared( + BOOLEAN(), "lte", subfieldExpr, upper)); } } @@ -308,11 +314,13 @@ core::TypedExprPtr filterToExpr( subfieldType, variant(static_cast(lowerValue))); if (floatFilter->lowerExclusive()) { - conditions.push_back(std::make_shared( - BOOLEAN(), "gt", subfieldExpr, lower)); + conditions.push_back( + std::make_shared( + BOOLEAN(), "gt", subfieldExpr, lower)); } else { - conditions.push_back(std::make_shared( - BOOLEAN(), "gte", subfieldExpr, lower)); + conditions.push_back( + std::make_shared( + BOOLEAN(), "gte", subfieldExpr, lower)); } } @@ -322,11 +330,13 @@ core::TypedExprPtr filterToExpr( subfieldType, variant(static_cast(upperValue))); if (floatFilter->upperExclusive()) { - conditions.push_back(std::make_shared( - BOOLEAN(), "lt", subfieldExpr, upper)); + conditions.push_back( + std::make_shared( + BOOLEAN(), "lt", subfieldExpr, upper)); } else { - conditions.push_back(std::make_shared( - BOOLEAN(), "lte", subfieldExpr, upper)); + conditions.push_back( + std::make_shared( + BOOLEAN(), "lte", subfieldExpr, upper)); } } @@ -346,11 +356,13 @@ core::TypedExprPtr filterToExpr( subfieldType, variant(bytesFilter->lower())); if (bytesFilter->isLowerExclusive()) { - conditions.push_back(std::make_shared( - BOOLEAN(), "gt", subfieldExpr, lower)); + conditions.push_back( + std::make_shared( + BOOLEAN(), "gt", subfieldExpr, lower)); } else { - conditions.push_back(std::make_shared( - BOOLEAN(), "gte", subfieldExpr, lower)); + conditions.push_back( + std::make_shared( + BOOLEAN(), "gte", subfieldExpr, lower)); } } @@ -359,11 +371,13 @@ core::TypedExprPtr filterToExpr( subfieldType, variant(bytesFilter->upper())); if (bytesFilter->isUpperExclusive()) { - conditions.push_back(std::make_shared( - BOOLEAN(), "lt", subfieldExpr, upper)); + conditions.push_back( + std::make_shared( + BOOLEAN(), "lt", subfieldExpr, upper)); } else { - conditions.push_back(std::make_shared( - BOOLEAN(), "lte", subfieldExpr, upper)); + conditions.push_back( + std::make_shared( + BOOLEAN(), "lte", subfieldExpr, upper)); } } @@ -401,21 +415,25 @@ core::TypedExprPtr filterToExpr( for (const auto& value : values) { switch (subfieldType->kind()) { case TypeKind::TINYINT: - valueExprs.push_back(std::make_shared( - subfieldType, variant(static_cast(value)))); + valueExprs.push_back( + std::make_shared( + subfieldType, variant(static_cast(value)))); break; case TypeKind::SMALLINT: - valueExprs.push_back(std::make_shared( - subfieldType, variant(static_cast(value)))); + valueExprs.push_back( + std::make_shared( + subfieldType, variant(static_cast(value)))); break; case TypeKind::INTEGER: - valueExprs.push_back(std::make_shared( - subfieldType, variant(static_cast(value)))); + valueExprs.push_back( + std::make_shared( + subfieldType, variant(static_cast(value)))); break; case TypeKind::BIGINT: default: - valueExprs.push_back(std::make_shared( - subfieldType, variant(value))); + valueExprs.push_back( + std::make_shared( + subfieldType, variant(value))); break; } } diff --git a/velox/exec/tests/utils/HashJoinTestBase.h b/velox/exec/tests/utils/HashJoinTestBase.h index 1f01b7a2e28..6d9ee6b7854 100644 --- a/velox/exec/tests/utils/HashJoinTestBase.h +++ b/velox/exec/tests/utils/HashJoinTestBase.h @@ -17,20 +17,21 @@ #include #include -#include "folly/experimental/EventCount.h" +#include "folly/synchronization/EventCount.h" #include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/common/testutil/TestValue.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" #include "velox/exec/Cursor.h" #include "velox/exec/HashBuild.h" #include "velox/exec/HashJoinBridge.h" +#include "velox/exec/OperatorType.h" #include "velox/exec/OperatorUtils.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/tests/utils/ArbitratorTestUtil.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/exec/tests/utils/VectorTestUtil.h" #include "velox/vector/fuzzer/VectorFuzzer.h" @@ -42,28 +43,45 @@ using facebook::velox::test::BatchMaker; struct TestParam { int64_t numDrivers{1}; + bool parallelBuildSideRowsEnabled; - explicit TestParam(int _numDrivers) : numDrivers(_numDrivers) {} + explicit TestParam(int _numDrivers) + : numDrivers(_numDrivers), parallelBuildSideRowsEnabled(false) {} + + TestParam(int _numDrivers, bool _parallelBuildSideRowsEnabled) + : numDrivers(_numDrivers), + parallelBuildSideRowsEnabled(_parallelBuildSideRowsEnabled) {} }; +// Required for GTest to generate unique parameterized test names. +inline std::string TestParamToName(const TestParam& param) { + return fmt::format( + "{}_drivers_{}_parallelBuildSideRowsEnabled", + param.numDrivers, + param.parallelBuildSideRowsEnabled ? "with" : "without"); +} + +using SplitPath = + std::unordered_map>; + using SplitInput = std::unordered_map>; // Returns aggregated spilled stats by build and probe operators from 'task'. -std::pair taskSpilledStats( +std::pair taskSpilledStats( const exec::Task& task) { - common::SpillStats buildStats; - common::SpillStats probeStats; + exec::SpillStats buildStats; + exec::SpillStats probeStats; auto stats = task.taskStats(); for (auto& pipeline : stats.pipelineStats) { for (auto op : pipeline.operatorStats) { - if (op.operatorType == "HashBuild") { + if (op.operatorType == OperatorType::kHashBuild) { buildStats.spilledInputBytes += op.spilledInputBytes; buildStats.spilledBytes += op.spilledBytes; buildStats.spilledRows += op.spilledRows; buildStats.spilledPartitions += op.spilledPartitions; buildStats.spilledFiles += op.spilledFiles; - } else if (op.operatorType == "HashProbe") { + } else if (op.operatorType == OperatorType::kHashProbe) { probeStats.spilledInputBytes += op.spilledInputBytes; probeStats.spilledBytes += op.spilledBytes; probeStats.spilledRows += op.spilledRows; @@ -81,55 +99,90 @@ void verifyTaskSpilledRuntimeStats(const exec::Task& task, bool expectedSpill) { auto stats = task.taskStats(); for (auto& pipeline : stats.pipelineStats) { for (auto op : pipeline.operatorStats) { - if ((op.operatorType == "HashBuild") || - (op.operatorType == "HashProbe")) { + if ((op.operatorType == OperatorType::kHashBuild) || + (op.operatorType == OperatorType::kHashProbe)) { if (!expectedSpill) { - ASSERT_EQ(op.runtimeStats[Operator::kSpillRuns].count, 0); - ASSERT_EQ(op.runtimeStats[Operator::kSpillFillTime].count, 0); - ASSERT_EQ(op.runtimeStats[Operator::kSpillSortTime].count, 0); ASSERT_EQ( - op.runtimeStats[Operator::kSpillExtractVectorTime].count, 0); + op.runtimeStats[std::string(Operator::kSpillRuns)].count, 0); + ASSERT_EQ( + op.runtimeStats[std::string(Operator::kSpillFillTime)].count, 0); + ASSERT_EQ( + op.runtimeStats[std::string(Operator::kSpillSortTime)].count, 0); + ASSERT_EQ( + op.runtimeStats[std::string(Operator::kSpillExtractVectorTime)] + .count, + 0); + ASSERT_EQ( + op.runtimeStats[std::string(Operator::kSpillSerializationTime)] + .count, + 0); + ASSERT_EQ( + op.runtimeStats[std::string(Operator::kSpillFlushTime)].count, 0); ASSERT_EQ( - op.runtimeStats[Operator::kSpillSerializationTime].count, 0); - ASSERT_EQ(op.runtimeStats[Operator::kSpillFlushTime].count, 0); - ASSERT_EQ(op.runtimeStats[Operator::kSpillWrites].count, 0); - ASSERT_EQ(op.runtimeStats[Operator::kSpillWriteTime].count, 0); - ASSERT_EQ(op.runtimeStats[Operator::kSpillReadBytes].count, 0); - ASSERT_EQ(op.runtimeStats[Operator::kSpillReads].count, 0); - ASSERT_EQ(op.runtimeStats[Operator::kSpillReadTime].count, 0); + op.runtimeStats[std::string(Operator::kSpillWrites)].count, 0); ASSERT_EQ( - op.runtimeStats[Operator::kSpillDeserializationTime].count, 0); + op.runtimeStats[std::string(Operator::kSpillWriteTime)].count, 0); + ASSERT_EQ( + op.runtimeStats[std::string(Operator::kSpillReadBytes)].count, 0); + ASSERT_EQ( + op.runtimeStats[std::string(Operator::kSpillReads)].count, 0); + ASSERT_EQ( + op.runtimeStats[std::string(Operator::kSpillReadTime)].count, 0); + ASSERT_EQ( + op.runtimeStats[std::string(Operator::kSpillDeserializationTime)] + .count, + 0); } else { - if (op.operatorType == "HashBuild") { - ASSERT_GT(op.runtimeStats[Operator::kSpillRuns].count, 0); - ASSERT_GT(op.runtimeStats[Operator::kSpillFillTime].sum, 0); + if (op.operatorType == OperatorType::kHashBuild) { ASSERT_GT( - op.runtimeStats[Operator::kSpillExtractVectorTime].sum, 0); + op.runtimeStats[std::string(Operator::kSpillRuns)].count, 0); + ASSERT_GT( + op.runtimeStats[std::string(Operator::kSpillFillTime)].sum, 0); + ASSERT_GT( + op.runtimeStats[std::string(Operator::kSpillExtractVectorTime)] + .sum, + 0); } else { // The table spilling might also be triggered from hash probe side. - ASSERT_GE(op.runtimeStats[Operator::kSpillRuns].count, 0); - ASSERT_GE(op.runtimeStats[Operator::kSpillFillTime].sum, 0); ASSERT_GE( - op.runtimeStats[Operator::kSpillExtractVectorTime].sum, 0); + op.runtimeStats[std::string(Operator::kSpillRuns)].count, 0); + ASSERT_GE( + op.runtimeStats[std::string(Operator::kSpillFillTime)].sum, 0); + ASSERT_GE( + op.runtimeStats[std::string(Operator::kSpillExtractVectorTime)] + .sum, + 0); } - ASSERT_EQ(op.runtimeStats[Operator::kSpillSortTime].sum, 0); - ASSERT_GT(op.runtimeStats[Operator::kSpillSerializationTime].sum, 0); - ASSERT_GE(op.runtimeStats[Operator::kSpillFlushTime].sum, 0); + ASSERT_EQ( + op.runtimeStats[std::string(Operator::kSpillSortTime)].sum, 0); + ASSERT_GT( + op.runtimeStats[std::string(Operator::kSpillSerializationTime)] + .sum, + 0); + ASSERT_GE( + op.runtimeStats[std::string(Operator::kSpillFlushTime)].sum, 0); // NOTE: spill flush might take less than one microsecond. ASSERT_GE( - op.runtimeStats[Operator::kSpillSerializationTime].count, - op.runtimeStats[Operator::kSpillFlushTime].count); - ASSERT_GT(op.runtimeStats[Operator::kSpillWrites].sum, 0); - ASSERT_GE(op.runtimeStats[Operator::kSpillWriteTime].sum, 0); + op.runtimeStats[std::string(Operator::kSpillSerializationTime)] + .count, + op.runtimeStats[std::string(Operator::kSpillFlushTime)].count); + ASSERT_GT( + op.runtimeStats[std::string(Operator::kSpillWrites)].sum, 0); + ASSERT_GE( + op.runtimeStats[std::string(Operator::kSpillWriteTime)].sum, 0); // NOTE: spill flush might take less than one microsecond. ASSERT_GE( - op.runtimeStats[Operator::kSpillWrites].count, - op.runtimeStats[Operator::kSpillWriteTime].count); - ASSERT_GT(op.runtimeStats[Operator::kSpillReadBytes].sum, 0); - ASSERT_GT(op.runtimeStats[Operator::kSpillReads].sum, 0); - ASSERT_GT(op.runtimeStats[Operator::kSpillReadTime].sum, 0); + op.runtimeStats[std::string(Operator::kSpillWrites)].count, + op.runtimeStats[std::string(Operator::kSpillWriteTime)].count); ASSERT_GT( - op.runtimeStats[Operator::kSpillDeserializationTime].sum, 0); + op.runtimeStats[std::string(Operator::kSpillReadBytes)].sum, 0); + ASSERT_GT(op.runtimeStats[std::string(Operator::kSpillReads)].sum, 0); + ASSERT_GT( + op.runtimeStats[std::string(Operator::kSpillReadTime)].sum, 0); + ASSERT_GT( + op.runtimeStats[std::string(Operator::kSpillDeserializationTime)] + .sum, + 0); } } } @@ -155,12 +208,15 @@ int32_t maxHashBuildSpillLevel(const exec::Task& task) { int32_t maxSpillLevel = -1; for (auto& pipelineStat : task.taskStats().pipelineStats) { for (auto& operatorStat : pipelineStat.operatorStats) { - if (operatorStat.operatorType == "HashBuild") { - if (operatorStat.runtimeStats.count("maxSpillLevel") == 0) { + if (operatorStat.operatorType == OperatorType::kHashBuild) { + if (operatorStat.runtimeStats.count( + std::string(HashBuild::kMaxSpillLevel)) == 0) { continue; } maxSpillLevel = std::max( - maxSpillLevel, operatorStat.runtimeStats["maxSpillLevel"].max); + maxSpillLevel, + operatorStat.runtimeStats[std::string(HashBuild::kMaxSpillLevel)] + .max); } } } @@ -175,11 +231,11 @@ std::pair numTaskSpillFiles(const exec::Task& task) { if (operatorStat.runtimeStats.count("spillFileSize") == 0) { continue; } - if (operatorStat.operatorType == "HashBuild") { + if (operatorStat.operatorType == OperatorType::kHashBuild) { numBuildFiles += operatorStat.runtimeStats["spillFileSize"].count; continue; } - if (operatorStat.operatorType == "HashProbe") { + if (operatorStat.operatorType == OperatorType::kHashProbe) { numProbeFiles += operatorStat.runtimeStats["spillFileSize"].count; } } @@ -344,6 +400,11 @@ class HashJoinBuilder { return *this; } + HashJoinBuilder& nullAsValue(bool nullAsValue) { + nullAsValue_ = nullAsValue; + return *this; + } + HashJoinBuilder& joinFilter(const std::string& joinFilter) { joinFilter_ = joinFilter; return *this; @@ -361,8 +422,20 @@ class HashJoinBuilder { return *this; } - HashJoinBuilder& inputSplits(const SplitInput& inputSplits) { - makeInputSplits_ = [inputSplits] { return inputSplits; }; + HashJoinBuilder& inputSplits(const SplitPath& splitPaths) { + makeInputSplits_ = [splitPaths] { + SplitInput inputSplits; + for (const auto& [nodeId, paths] : splitPaths) { + std::vector splits; + splits.reserve(paths.size()); + for (const auto& path : paths) { + splits.emplace_back( + exec::Split(HiveConnectorSplitBuilder(path).build())); + } + inputSplits[nodeId] = std::move(splits); + } + return inputSplits; + }; return *this; } @@ -407,6 +480,11 @@ class HashJoinBuilder { return *this; } + HashJoinBuilder& parallelizeJoinBuildRows(bool value) { + parallelJoinBuildRowsEnabled_ = value; + return *this; + } + HashJoinBuilder& spillDirectory(const std::string& spillDirectory) { spillDirectory_ = spillDirectory; return *this; @@ -467,8 +545,9 @@ class HashJoinBuilder { } for (const auto& testData : testSettings) { - SCOPED_TRACE(fmt::format( - "{} numDrivers: {}", testData.debugString(), numDrivers_)); + SCOPED_TRACE( + fmt::format( + "{} numDrivers: {}", testData.debugString(), numDrivers_)); auto planNodeIdGenerator = std::make_shared(); std::shared_ptr joinNode; auto planNode = @@ -492,7 +571,8 @@ class HashJoinBuilder { joinFilter_, joinOutputLayout_, joinType_, - nullAware_) + nullAware_, + nullAsValue_) .capturePlanNode(joinNode) .optionalProject(outputProjections_) .planNode(); @@ -596,19 +676,18 @@ class HashJoinBuilder { builder.splits(splitEntry.first, splitEntry.second); } } - auto queryCtx = core::QueryCtx::create( - executor_, - core::QueryConfig{{}}, - std::unordered_map>{}, - cache::AsyncDataCache::getInstance(), - memory::MemoryManager::getInstance()->addRootPool( - "query_pool", - memory::kMaxMemory, - memory::MemoryReclaimer::create())); + auto queryCtx = core::QueryCtx::Builder() + .executor(executor_) + .pool( + memory::MemoryManager::getInstance()->addRootPool( + "query_pool", + memory::kMaxMemory, + memory::MemoryReclaimer::create())) + .build(); std::shared_ptr spillDirectory; int32_t spillPct{0}; if (injectSpill) { - spillDirectory = exec::test::TempDirectoryPath::create(); + spillDirectory = TempDirectoryPath::create(); builder.spillDirectory(spillDirectory->getPath()); config(core::QueryConfig::kSpillEnabled, "true"); config(core::QueryConfig::kMaxSpillLevel, std::to_string(maxSpillLevel)); @@ -628,6 +707,9 @@ class HashJoinBuilder { config( core::QueryConfig::kHashProbeFinishEarlyOnEmptyBuild, hashProbeFinishEarlyOnEmptyBuild_ ? "true" : "false"); + config( + core::QueryConfig::kParallelOutputJoinBuildRowsEnabled, + parallelJoinBuildRowsEnabled_ ? "true" : "false"); if (maxDriverYieldTimeMs != 0) { config( core::QueryConfig::kDriverCpuTimeSliceLimitMs, @@ -728,6 +810,7 @@ class HashJoinBuilder { int32_t numDrivers_{1}; core::JoinType joinType_{core::JoinType::kInner}; bool nullAware_{false}; + bool nullAsValue_{false}; std::string referenceQuery_; RowTypePtr probeType_; @@ -759,8 +842,8 @@ class HashJoinBuilder { std::shared_ptr queryPool_; std::string spillDirectory_; bool hashProbeFinishEarlyOnEmptyBuild_{true}; + bool parallelJoinBuildRowsEnabled_{false}; - SplitInput inputSplits_; std::function makeInputSplits_; core::PlanNodePtr planNode_; std::unordered_map configs_; @@ -773,7 +856,8 @@ class HashJoinTestBase : public HiveConnectorTestBase { HashJoinTestBase() : HashJoinTestBase(TestParam(1)) {} explicit HashJoinTestBase(const TestParam& param) - : numDrivers_(param.numDrivers) {} + : numDrivers_(param.numDrivers), + parallelBuildSideRowsEnabled_(param.parallelBuildSideRowsEnabled) {} void SetUp() override { HiveConnectorTestBase::SetUp(); @@ -855,15 +939,13 @@ class HashJoinTestBase : public HiveConnectorTestBase { outputLayout, joinType) .planNode(); - SplitInput splitInput = { - {probeScanId, - {exec::Split(makeHiveConnectorSplit(probeFile->getPath()))}}, - {buildScanId, - {exec::Split(makeHiveConnectorSplit(buildFile->getPath()))}}, + SplitPath splitPaths = { + {probeScanId, {probeFile->getPath()}}, + {buildScanId, {buildFile->getPath()}}, }; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .planNode(std::move(op)) - .inputSplits(splitInput) + .inputSplits(splitPaths) .checkSpillStats(false) .referenceQuery(referenceQuery) .run(); @@ -966,6 +1048,7 @@ class HashJoinTestBase : public HiveConnectorTestBase { } const int32_t numDrivers_; + const bool parallelBuildSideRowsEnabled_; // The default left and right table types used for test. RowTypePtr probeType_; diff --git a/velox/exec/tests/utils/HiveConnectorTestBase.cpp b/velox/exec/tests/utils/HiveConnectorTestBase.cpp index 6e12c0d98aa..a8a0cd09a0c 100644 --- a/velox/exec/tests/utils/HiveConnectorTestBase.cpp +++ b/velox/exec/tests/utils/HiveConnectorTestBase.cpp @@ -16,9 +16,12 @@ #include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/common/caching/AsyncDataCache.h" #include "velox/common/file/FileSystems.h" #include "velox/common/file/tests/FaultyFileSystem.h" +#include "velox/connectors/ConnectorRegistry.h" #include "velox/connectors/hive/HiveConnector.h" +#include "velox/dwio/common/FileSink.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" #include "velox/dwio/dwrf/RegisterDwrfReader.h" #include "velox/dwio/dwrf/RegisterDwrfWriter.h" @@ -28,6 +31,7 @@ #include "velox/exec/tests/utils/AssertQueryBuilder.h" namespace facebook::velox::exec::test { +using namespace facebook::velox::common::testutil; HiveConnectorTestBase::HiveConnectorTestBase() { filesystems::registerLocalFileSystem(); @@ -36,13 +40,21 @@ HiveConnectorTestBase::HiveConnectorTestBase() { void HiveConnectorTestBase::SetUp() { OperatorTestBase::SetUp(); + + // Clear any stale cache entries from previous tests to avoid reading + // corrupted/stale data when temp files are reused (same inode/fileNum). + if (auto* cache = cache::AsyncDataCache::getInstance()) { + cache->clear(); + } + connector::hive::HiveConnectorFactory factory; auto hiveConnector = factory.newConnector( kHiveConnectorId, std::make_shared( std::unordered_map()), ioExecutor_.get()); - connector::registerConnector(hiveConnector); + connector::ConnectorRegistry::global().insert( + hiveConnector->connectorId(), hiveConnector); dwio::common::registerFileSinks(); dwrf::registerDwrfReaderFactory(); dwrf::registerDwrfWriterFactory(); @@ -55,19 +67,20 @@ void HiveConnectorTestBase::TearDown() { ioExecutor_.reset(); dwrf::unregisterDwrfReaderFactory(); dwrf::unregisterDwrfWriterFactory(); - connector::unregisterConnector(kHiveConnectorId); + connector::ConnectorRegistry::global().erase(kHiveConnectorId); text::unregisterTextReaderFactory(); OperatorTestBase::TearDown(); } void HiveConnectorTestBase::resetHiveConnector( const std::shared_ptr& config) { - connector::unregisterConnector(kHiveConnectorId); + connector::ConnectorRegistry::global().erase(kHiveConnectorId); connector::hive::HiveConnectorFactory factory; auto hiveConnector = factory.newConnector(kHiveConnectorId, config, ioExecutor_.get()); - connector::registerConnector(hiveConnector); + connector::ConnectorRegistry::global().insert( + hiveConnector->connectorId(), hiveConnector); } void HiveConnectorTestBase::writeToFiles( @@ -244,7 +257,7 @@ HiveConnectorTestBase::makeColumnHandle( const TypePtr& dataType, const TypePtr& hiveType, const std::vector& requiredSubfields, - connector::hive::HiveColumnHandle::ColumnType columnType) { + connector::hive::FileColumnHandle::ColumnType columnType) { std::vector subfields; subfields.reserve(requiredSubfields.size()); for (auto& path : requiredSubfields) { @@ -320,7 +333,7 @@ HiveConnectorTestBase::makeHiveInsertTableHandle( std::move(locationHandle), tableStorageFormat, compressionKind, - {}, + {}, // serdeParameters writerOptions, ensureFiles); } @@ -337,7 +350,8 @@ HiveConnectorTestBase::makeHiveInsertTableHandle( const std::optional compressionKind, const std::unordered_map& serdeParameters, const std::shared_ptr& writerOptions, - const bool ensureFiles) { + const bool ensureFiles, + const std::unordered_map& storageParameters) { std::vector> columnHandles; std::vector bucketedBy; @@ -371,14 +385,14 @@ HiveConnectorTestBase::makeHiveInsertTableHandle( columnHandles.push_back( std::make_shared( tableColumnNames.at(i), - connector::hive::HiveColumnHandle::ColumnType::kPartitionKey, + connector::hive::FileColumnHandle::ColumnType::kPartitionKey, tableColumnTypes.at(i), tableColumnTypes.at(i))); } else { columnHandles.push_back( std::make_shared( tableColumnNames.at(i), - connector::hive::HiveColumnHandle::ColumnType::kRegular, + connector::hive::FileColumnHandle::ColumnType::kRegular, tableColumnTypes.at(i), tableColumnTypes.at(i))); } @@ -395,7 +409,9 @@ HiveConnectorTestBase::makeHiveInsertTableHandle( compressionKind, serdeParameters, writerOptions, - ensureFiles); + ensureFiles, + std::make_shared(), + storageParameters); } std::shared_ptr @@ -404,7 +420,7 @@ HiveConnectorTestBase::regularColumn( const TypePtr& type) { return std::make_shared( name, - connector::hive::HiveColumnHandle::ColumnType::kRegular, + connector::hive::FileColumnHandle::ColumnType::kRegular, type, type); } @@ -415,7 +431,7 @@ HiveConnectorTestBase::synthesizedColumn( const TypePtr& type) { return std::make_shared( name, - connector::hive::HiveColumnHandle::ColumnType::kSynthesized, + connector::hive::FileColumnHandle::ColumnType::kSynthesized, type, type); } @@ -426,7 +442,7 @@ HiveConnectorTestBase::partitionKey( const TypePtr& type) { return std::make_shared( name, - connector::hive::HiveColumnHandle::ColumnType::kPartitionKey, + connector::hive::FileColumnHandle::ColumnType::kPartitionKey, type, type); } diff --git a/velox/exec/tests/utils/HiveConnectorTestBase.h b/velox/exec/tests/utils/HiveConnectorTestBase.h index b3660714f50..a402fb9de91 100644 --- a/velox/exec/tests/utils/HiveConnectorTestBase.h +++ b/velox/exec/tests/utils/HiveConnectorTestBase.h @@ -15,16 +15,18 @@ */ #pragma once +#include "velox/common/testutil/TempFilePath.h" #include "velox/connectors/hive/HiveConnectorSplit.h" #include "velox/connectors/hive/HiveDataSink.h" #include "velox/connectors/hive/TableHandle.h" #include "velox/dwio/dwrf/common/Config.h" #include "velox/dwio/dwrf/writer/FlushPolicy.h" #include "velox/exec/tests/utils/OperatorTestBase.h" -#include "velox/exec/tests/utils/TempFilePath.h" namespace facebook::velox::exec::test { +using TempFilePath = common::testutil::TempFilePath; + static const std::string kHiveConnectorId = "test-hive"; class HiveConnectorTestBase : public OperatorTestBase { @@ -35,7 +37,9 @@ class HiveConnectorTestBase : public OperatorTestBase { void TearDown() override; void resetHiveConnector( - const std::shared_ptr& config); + const std::shared_ptr& config = + std::make_shared( + std::unordered_map())); void writeToFiles( const std::vector& filePaths, @@ -130,17 +134,17 @@ class HiveConnectorTestBase : public OperatorTestBase { const core::TypedExprPtr& remainingFilter = nullptr, const std::string& tableName = "hive_table", const RowTypePtr& dataColumns = nullptr, - bool filterPushdownEnabled = true, - const std::unordered_map& tableParameters = + const std::vector& indexColumns = {}, + const std::unordered_map& storageParameters = {}) { return std::make_shared( kHiveConnectorId, tableName, - filterPushdownEnabled, std::move(subfieldFilters), remainingFilter, dataColumns, - tableParameters); + indexColumns, + storageParameters); } /// @param name Column name. @@ -160,8 +164,8 @@ class HiveConnectorTestBase : public OperatorTestBase { const TypePtr& dataType, const TypePtr& hiveType, const std::vector& requiredSubfields, - connector::hive::HiveColumnHandle::ColumnType columnType = - connector::hive::HiveColumnHandle::ColumnType::kRegular); + connector::hive::FileColumnHandle::ColumnType columnType = + connector::hive::FileColumnHandle::ColumnType::kRegular); /// @param targetDirectory Final directory of the target table after commit. /// @param writeDirectory Write directory of the target table before commit. @@ -203,7 +207,9 @@ class HiveConnectorTestBase : public OperatorTestBase { const std::unordered_map& serdeParameters = {}, const std::shared_ptr& writerOptions = nullptr, - const bool ensureFiles = false); + const bool ensureFiles = false, + const std::unordered_map& storageParameters = + {}); static std::shared_ptr makeHiveInsertTableHandle( diff --git a/velox/exec/tests/utils/IndexLookupJoinTestBase.cpp b/velox/exec/tests/utils/IndexLookupJoinTestBase.cpp index 9b4e85eb931..0f518b1ad6c 100644 --- a/velox/exec/tests/utils/IndexLookupJoinTestBase.cpp +++ b/velox/exec/tests/utils/IndexLookupJoinTestBase.cpp @@ -17,13 +17,14 @@ #include "velox/exec/tests/utils/IndexLookupJoinTestBase.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/tests/utils/TestIndexStorageConnector.h" -namespace fecebook::velox::exec::test { -using namespace facebook::velox::test; +namespace facebook::velox::exec::test { +using namespace facebook::velox::common::testutil; +using velox::test::assertEqualVectors; namespace { -std::vector appendMatchColumn( - const std::vector columns) { +std::vector appendMarker(const std::vector columns) { std::vector resultColumns; resultColumns.reserve(columns.size() + 1); for (const auto& column : columns) { @@ -56,7 +57,7 @@ std::vector IndexLookupJoinTestBase::generateProbeInput( size_t numBatches, size_t batchSize, size_t numDuplicateProbeRows, - SequenceTableData& tableData, + IndexTableData& tableData, std::shared_ptr& pool, const std::vector& probeJoinKeys, bool hasNullKeys, @@ -80,7 +81,7 @@ std::vector IndexLookupJoinTestBase::generateProbeInput( probeInputs.push_back(fuzzer.fuzzInputRow(probeType_)); // NOTE: index connector doesn't expect in condition column rray elements to // be null. - if ((!inMatchPct.has_value() || tableData.keyData->size() == 0) && + if ((!inMatchPct.has_value() || tableData.keyVectors->size() == 0) && hasNullKeys) { for (int i = 0; i < probeType_->size(); ++i) { const auto columnType = probeType_->childAt(i); @@ -97,14 +98,14 @@ std::vector IndexLookupJoinTestBase::generateProbeInput( } } - if (tableData.keyData->size() == 0) { + if (tableData.keyVectors->size() == 0) { return probeInputs; } - const auto numTableRows = tableData.keyData->size(); + const auto numTableRows = tableData.keyVectors->size(); std::vector> tableKeyVectors; for (int i = 0; i < probeJoinKeys.size(); ++i) { - auto keyVector = tableData.keyData->childAt(i); + auto keyVector = tableData.keyVectors->childAt(i); keyVector->loadedVector(); BaseVector::flattenVector(keyVector); tableKeyVectors.push_back( @@ -117,10 +118,11 @@ std::vector IndexLookupJoinTestBase::generateProbeInput( for (int i = 0; i < numBatches; ++i) { std::vector> probeKeyVectors; for (int j = 0; j < probeJoinKeys.size(); ++j) { - probeKeyVectors.push_back(BaseVector::create>( - probeType_->findChild(probeJoinKeys[j]), - probeInputs[i]->size(), - pool.get())); + probeKeyVectors.push_back( + BaseVector::create>( + probeType_->findChild(probeJoinKeys[j]), + probeInputs[i]->size(), + pool.get())); } for (int row = 0; row < probeInputs[i]->size(); row += numDuplicateProbeRows) { @@ -259,7 +261,7 @@ PlanNodePtr IndexLookupJoinTestBase::makeLookupPlan( const std::vector& leftKeys, const std::vector& rightKeys, const std::vector& joinConditions, - bool includeMatchColumn, + bool hasMarker, JoinType joinType, const std::vector& outputColumns, PlanNodeId& joinNodeId) { @@ -267,14 +269,15 @@ PlanNodePtr IndexLookupJoinTestBase::makeLookupPlan( VELOX_CHECK_LE(leftKeys.size(), keyType_->size()); return PlanBuilder(planNodeIdGenerator, pool_.get()) .values(probeVectors) - .indexLookupJoin( - leftKeys, - rightKeys, - indexScanNode, - joinConditions, - includeMatchColumn, - includeMatchColumn ? appendMatchColumn(outputColumns) : outputColumns, - joinType) + .startIndexLookupJoin() + .leftKeys(leftKeys) + .rightKeys(rightKeys) + .indexSource(indexScanNode) + .joinConditions(joinConditions) + .hasMarker(hasMarker) + .outputLayout(hasMarker ? appendMarker(outputColumns) : outputColumns) + .joinType(joinType) + .endIndexLookupJoin() .capturePlanNodeId(joinNodeId) .planNode(); } @@ -285,7 +288,8 @@ PlanNodePtr IndexLookupJoinTestBase::makeLookupPlan( const std::vector& leftKeys, const std::vector& rightKeys, const std::vector& joinConditions, - bool includeMatchColumn, + const std::string& filter, + bool hasMarker, JoinType joinType, const std::vector& outputColumns) { VELOX_CHECK_EQ(leftKeys.size(), rightKeys.size()); @@ -295,14 +299,16 @@ PlanNodePtr IndexLookupJoinTestBase::makeLookupPlan( .outputType(probeType_) .endTableScan() .captureScanNodeId(probeScanNodeId_) - .indexLookupJoin( - leftKeys, - rightKeys, - indexScanNode, - joinConditions, - includeMatchColumn, - includeMatchColumn ? appendMatchColumn(outputColumns) : outputColumns, - joinType) + .startIndexLookupJoin() + .leftKeys(leftKeys) + .rightKeys(rightKeys) + .indexSource(indexScanNode) + .joinConditions(joinConditions) + .filter(filter) + .hasMarker(hasMarker) + .outputLayout(hasMarker ? appendMarker(outputColumns) : outputColumns) + .joinType(joinType) + .endIndexLookupJoin() .capturePlanNodeId(joinNodeId_) .planNode(); } @@ -346,7 +352,7 @@ TableScanNodePtr IndexLookupJoinTestBase::makeIndexScanNode( void IndexLookupJoinTestBase::generateIndexTableData( const std::vector& keyCardinalities, - SequenceTableData& tableData, + IndexTableData& tableData, std::shared_ptr& pool) { VELOX_CHECK_EQ(keyCardinalities.size(), keyType_->size()); const auto numRows = getNumRows(keyCardinalities); @@ -356,10 +362,10 @@ void IndexLookupJoinTestBase::generateIndexTableData( opts.allowSlice = false; VectorFuzzer fuzzer(opts, pool.get()); - tableData.keyData = fuzzer.fuzzInputFlatRow(keyType_); - tableData.valueData = fuzzer.fuzzInputFlatRow(valueType_); + tableData.keyVectors = fuzzer.fuzzInputFlatRow(keyType_); + tableData.valueVectors = fuzzer.fuzzInputFlatRow(valueType_); - VELOX_CHECK_EQ(numRows, tableData.keyData->size()); + VELOX_CHECK_EQ(numRows, tableData.keyVectors->size()); tableData.maxKeys.resize(keyType_->size()); tableData.minKeys.resize(keyType_->size()); // Set the key column vector to the same value to easy testing with @@ -369,8 +375,8 @@ void IndexLookupJoinTestBase::generateIndexTableData( int64_t minKey = std::numeric_limits::max(); int64_t maxKey = std::numeric_limits::min(); int numKeys = keyCardinalities[i]; - tableData.keyData->childAt(i) = - makeFlatVector(tableData.keyData->size(), [&](auto row) { + tableData.keyVectors->childAt(i) = + makeFlatVector(tableData.keyVectors->size(), [&](auto row) { const int64_t keyValue = 1 + (row / numRepeats) % numKeys; minKey = std::min(minKey, keyValue); maxKey = std::max(maxKey, keyValue); @@ -384,12 +390,12 @@ void IndexLookupJoinTestBase::generateIndexTableData( VELOX_CHECK_EQ(tableType_->size(), keyType_->size() + valueType_->size()); tableColumns.reserve(tableType_->size()); for (auto i = 0; i < keyType_->size(); ++i) { - tableColumns.push_back(tableData.keyData->childAt(i)); + tableColumns.push_back(tableData.keyVectors->childAt(i)); } for (auto i = 0; i < valueType_->size(); ++i) { - tableColumns.push_back(tableData.valueData->childAt(i)); + tableColumns.push_back(tableData.valueVectors->childAt(i)); } - tableData.tableData = makeRowVector(tableType_->names(), tableColumns); + tableData.tableVectors = makeRowVector(tableType_->names(), tableColumns); } RowTypePtr IndexLookupJoinTestBase::makeScanOutputType( @@ -415,13 +421,13 @@ bool IndexLookupJoinTestBase::isFilter(const std::string& conditionSql) const { std::shared_ptr IndexLookupJoinTestBase::runLookupQuery( const PlanNodePtr& plan, int numPrefetchBatches, - const std::string& duckDbVefifySql) { + const std::string& duckDbVerifySql) { return AssertQueryBuilder(duckDbQueryRunner_) .plan(plan) .config( QueryConfig::kIndexLookupJoinMaxPrefetchBatches, std::to_string(numPrefetchBatches)) - .assertResults(duckDbVefifySql); + .assertResults(duckDbVerifySql); } std::shared_ptr IndexLookupJoinTestBase::runLookupQuery( @@ -431,17 +437,30 @@ std::shared_ptr IndexLookupJoinTestBase::runLookupQuery( bool barrierExecution, int maxOutputRows, int numPrefetchBatches, - const std::string& duckDbVefifySql) { - return AssertQueryBuilder(duckDbQueryRunner_) - .plan(plan) + bool needsIndexSplit, + const std::string& duckDbVerifySql, + int maxDrivers) { + AssertQueryBuilder queryBuilder(duckDbQueryRunner_); + queryBuilder.plan(plan) .splits(probeScanNodeId_, makeHiveConnectorSplits(probeFiles)) .serialExecution(serialExecution) .barrierExecution(barrierExecution) + .maxDrivers(maxDrivers) .config(QueryConfig::kMaxOutputBatchRows, std::to_string(maxOutputRows)) .config( QueryConfig::kIndexLookupJoinMaxPrefetchBatches, - std::to_string(numPrefetchBatches)) - .assertResults(duckDbVefifySql); + std::to_string(numPrefetchBatches)); + if (needsIndexSplit) { + // Add a fake split for the index source. The test index source doesn't + // actually use splits, but this is used to verify the split passing + // mechanism works correctly. + queryBuilder.split( + indexScanNodeId_, + Split( + std::make_shared( + kTestIndexConnectorName))); + } + return queryBuilder.assertResults(duckDbVerifySql); } void IndexLookupJoinTestBase::verifyResultWithMatchColumn( @@ -449,21 +468,36 @@ void IndexLookupJoinTestBase::verifyResultWithMatchColumn( const PlanNodeId& probeScanNodeIdWithoutMatchColumn, const PlanNodePtr& planWithMatchColumn, const PlanNodeId& probeScanNodeIdWithMatchColumn, - const std::vector>& probeFiles) { - VectorPtr expectedResult = AssertQueryBuilder(duckDbQueryRunner_) - .plan(planWithoutMatchColumn) - .splits( - probeScanNodeIdWithoutMatchColumn, - makeHiveConnectorSplits(probeFiles)) - .copyResults(pool()); + const std::vector>& probeFiles, + bool needsIndexSplit) { + AssertQueryBuilder expectedResultBuilder(duckDbQueryRunner_); + expectedResultBuilder.plan(planWithoutMatchColumn) + .splits( + probeScanNodeIdWithoutMatchColumn, + makeHiveConnectorSplits(probeFiles)); + if (needsIndexSplit) { + expectedResultBuilder.split( + indexScanNodeId_, + Split( + std::make_shared( + kTestIndexConnectorName))); + } + VectorPtr expectedResult = expectedResultBuilder.copyResults(pool()); BaseVector::flattenVector(expectedResult); - VectorPtr resultWithMatchColumn = AssertQueryBuilder(duckDbQueryRunner_) - .plan(planWithMatchColumn) - .splits( - probeScanNodeIdWithMatchColumn, - makeHiveConnectorSplits(probeFiles)) - .copyResults(pool()); + AssertQueryBuilder resultWithMatchColumnBuilder(duckDbQueryRunner_); + resultWithMatchColumnBuilder.plan(planWithMatchColumn) + .splits( + probeScanNodeIdWithMatchColumn, makeHiveConnectorSplits(probeFiles)); + if (needsIndexSplit) { + resultWithMatchColumnBuilder.split( + indexScanNodeId_, + Split( + std::make_shared( + kTestIndexConnectorName))); + } + VectorPtr resultWithMatchColumn = + resultWithMatchColumnBuilder.copyResults(pool()); BaseVector::flattenVector(resultWithMatchColumn); auto rowResultWithMatchMatchColumn = std::dynamic_pointer_cast(resultWithMatchColumn); @@ -512,4 +546,4 @@ IndexLookupJoinTestBase::createProbeFiles( writeToFiles(toFilePaths(probeFiles), probeVectors); return probeFiles; } -} // namespace fecebook::velox::exec::test +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/utils/IndexLookupJoinTestBase.h b/velox/exec/tests/utils/IndexLookupJoinTestBase.h index 561b6dfe9fe..d27284cd92c 100644 --- a/velox/exec/tests/utils/IndexLookupJoinTestBase.h +++ b/velox/exec/tests/utils/IndexLookupJoinTestBase.h @@ -19,12 +19,13 @@ #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/parse/PlanNodeIdGenerator.h" -namespace fecebook::velox::exec::test { +namespace facebook::velox::exec::test { using namespace facebook::velox; using namespace facebook::velox::core; using namespace facebook::velox::exec; using namespace facebook::velox::exec::test; +using namespace facebook::velox::common::testutil; class IndexLookupJoinTestBase : public HiveConnectorTestBase { protected: @@ -35,10 +36,10 @@ class IndexLookupJoinTestBase : public HiveConnectorTestBase { rng_.seed(123); } - struct SequenceTableData { - RowVectorPtr keyData; - RowVectorPtr valueData; - RowVectorPtr tableData; + struct IndexTableData { + RowVectorPtr keyVectors; + RowVectorPtr valueVectors; + RowVectorPtr tableVectors; std::vector minKeys; std::vector maxKeys; }; @@ -70,7 +71,7 @@ class IndexLookupJoinTestBase : public HiveConnectorTestBase { size_t numBatches, size_t batchSize, size_t numDuplicateProbeRows, - SequenceTableData& tableData, + IndexTableData& tableData, std::shared_ptr& pool, const std::vector& probeJoinKeys, bool hasNullKeys = false, @@ -86,7 +87,7 @@ class IndexLookupJoinTestBase : public HiveConnectorTestBase { /// @param probeVectors: the probe input vectors. /// @param leftKeys: the left join keys of index lookup join. /// @param rightKeys: the right join keys of index lookup join. - /// @param includeMatchColumn: whether the index join output includes a match + /// @param hasMarker: whether the index join output includes a match /// column at the end. /// @param joinType: the join type of index lookup join. /// @param outputColumns: the output column names of index lookup join. @@ -99,30 +100,32 @@ class IndexLookupJoinTestBase : public HiveConnectorTestBase { const std::vector& leftKeys, const std::vector& rightKeys, const std::vector& joinConditions, - bool includeMatchColumn, + bool hasMarker, core::JoinType joinType, const std::vector& outputColumns, core::PlanNodeId& joinNodeId); /// Makes lookup join plan with the following parameters: + /// @param planNodeIdGenerator: generator for creating unique plan node IDs. /// @param indexScanNode: the index table scan node. - /// @param probeVectors: the probe input vectors. /// @param leftKeys: the left join keys of index lookup join. /// @param rightKeys: the right join keys of index lookup join. - /// @param includeMatchColumn: whether the index join output includes a match + /// @param joinConditions: the join conditions for index lookup join that + /// can't be converted into simple equality join conditions. + /// @param filter: additional filter condition SQL string to apply on join + /// results. Can be empty string if no additional filter is needed. + /// @param hasMarker: whether the index join output includes a match /// column at the end. /// @param joinType: the join type of index lookup join. /// @param outputColumns: the output column names of index lookup join. - /// @param joinNodeId: returns the plan node id of the index lookup join - /// node. - /// @param probeScanNodeId: returns the plan node id of the probe table scan PlanNodePtr makeLookupPlan( const std::shared_ptr& planNodeIdGenerator, TableScanNodePtr indexScanNode, const std::vector& leftKeys, const std::vector& rightKeys, const std::vector& joinConditions, - bool includeMatchColumn, + const std::string& filter, + bool hasMarker, JoinType joinType, const std::vector& outputColumns); @@ -148,7 +151,7 @@ class IndexLookupJoinTestBase : public HiveConnectorTestBase { /// each index column. void generateIndexTableData( const std::vector& keyCardinalities, - SequenceTableData& tableData, + IndexTableData& tableData, std::shared_ptr& pool); /// Write 'probeVectors' to a number of files with one per each file. @@ -162,7 +165,7 @@ class IndexLookupJoinTestBase : public HiveConnectorTestBase { std::shared_ptr runLookupQuery( const PlanNodePtr& plan, int numPrefetchBatches, - const std::string& duckDbVefifySql); + const std::string& duckDbVerifySql); std::shared_ptr runLookupQuery( const PlanNodePtr& plan, @@ -171,7 +174,9 @@ class IndexLookupJoinTestBase : public HiveConnectorTestBase { bool barrierExecution, int maxBatchRows, int numPrefetchBatches, - const std::string& duckDbVefifySql); + bool needsIndexSplit, + const std::string& duckDbVerifySql, + int maxDrivers = 1); /// Verifies the results of the index lookup join query with and without match /// column. @@ -180,7 +185,8 @@ class IndexLookupJoinTestBase : public HiveConnectorTestBase { const PlanNodeId& probeScanNodeIdWithoutMatchColumn, const PlanNodePtr& planWithMatchColumn, const PlanNodeId& probeScanNodeIdWithMatchColumn, - const std::vector>& probeFiles); + const std::vector>& probeFiles, + bool needsIndexSplit); RowTypePtr keyType_; std::optional partitionType_; @@ -192,4 +198,4 @@ class IndexLookupJoinTestBase : public HiveConnectorTestBase { PlanNodeId probeScanNodeId_; folly::Random::DefaultGenerator rng_; }; -} // namespace fecebook::velox::exec::test +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/utils/LocalExchangeSource.cpp b/velox/exec/tests/utils/LocalExchangeSource.cpp index d64dad29387..08963d226ed 100644 --- a/velox/exec/tests/utils/LocalExchangeSource.cpp +++ b/velox/exec/tests/utils/LocalExchangeSource.cpp @@ -17,7 +17,7 @@ #include #include #include "velox/common/testutil/TestValue.h" -#include "velox/exec/ExchangeClient.h" +#include "velox/exec/Operator.h" #include "velox/exec/OutputBufferManager.h" namespace facebook::velox::exec::test { @@ -84,13 +84,13 @@ class LocalExchangeSource : public exec::ExchangeSource { << requestedSequence; int64_t nExtra = requestedSequence - sequence; VELOX_CHECK(nExtra < data.size()); - data.erase(data.begin(), data.begin() + nExtra); + data.erase(data.cbegin(), data.cbegin() + nExtra); sequence = requestedSequence; } if (data.empty()) { sequence = requestedSequence; } - std::vector> pages; + std::vector> pages; bool atEnd = false; int64_t totalBytes = 0; for (auto& inputPage : data) { @@ -101,7 +101,8 @@ class LocalExchangeSource : public exec::ExchangeSource { } totalBytes += inputPage->length(); inputPage->unshare(); - pages.push_back(std::make_unique(std::move(inputPage))); + pages.push_back( + std::make_unique(std::move(inputPage))); inputPage = nullptr; } numPages_ += pages.size(); @@ -190,7 +191,7 @@ class LocalExchangeSource : public exec::ExchangeSource { {"localExchangeSource.numPages", RuntimeMetric(numPages_)}, {"localExchangeSource.totalBytes", RuntimeMetric(totalBytes_, RuntimeCounter::Unit::kBytes)}, - {ExchangeClient::kBackgroundCpuTimeMs, + {std::string(Operator::kBackgroundCpuTimeNanos), RuntimeMetric(123 * 1000000, RuntimeCounter::Unit::kNanos)}, }; } @@ -271,7 +272,7 @@ class LocalExchangeSource : public exec::ExchangeSource { } bool checkSetRequestPromise() { - VeloxPromise promise; + VeloxPromise promise{VeloxPromise::makeEmpty()}; { std::lock_guard l(queue_->mutex()); promise = std::move(promise_); diff --git a/velox/exec/tests/utils/MergeTestBase.h b/velox/exec/tests/utils/MergeTestBase.h index 54d4bdbac70..e20118f8df1 100644 --- a/velox/exec/tests/utils/MergeTestBase.h +++ b/velox/exec/tests/utils/MergeTestBase.h @@ -15,8 +15,9 @@ */ #include "velox/common/base/Exceptions.h" +#include "velox/common/base/TreeOfLosers.h" #include "velox/common/time/Timer.h" -#include "velox/exec/TreeOfLosers.h" +#include "velox/exec/Spill.h" #include @@ -100,6 +101,44 @@ class TestingStream final : public MergeStream { std::vector numbers_; }; +class TestingSpillMergeStream : public SpillMergeStream { + public: + TestingSpillMergeStream( + uint32_t id, + const std::vector& sortingKeys, + RowVectorPtr rowVector) + : id_(id), sortingKeys_(sortingKeys) { + rowVector_ = rowVector; + size_ = rowVector_->size(); + } + + uint32_t id() const override { + return id_; + } + + private: + const std::vector& sortingKeys() const override { + VELOX_CHECK(!closed_); + return sortingKeys_; + } + + void nextBatch() override { + VELOX_CHECK(!closed_); + index_ = 0; + size_ = 0; + close(); + rowVector_.reset(); + } + + void close() override { + VELOX_CHECK(!closed_); + SpillMergeStream::close(); + } + + uint32_t id_; + const std::vector sortingKeys_; +}; + // Test data for merging. struct TestData { // Globally sorted sequence of test keys. diff --git a/velox/exec/tests/utils/OperatorTestBase.cpp b/velox/exec/tests/utils/OperatorTestBase.cpp index 276cfc1735a..6951bdb206b 100644 --- a/velox/exec/tests/utils/OperatorTestBase.cpp +++ b/velox/exec/tests/utils/OperatorTestBase.cpp @@ -126,13 +126,13 @@ void OperatorTestBase::SetUp() { if (!isRegisteredVectorSerde()) { this->registerVectorSerde(); } - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kPresto)) { + if (!isRegisteredNamedVectorSerde("Presto")) { serializer::presto::PrestoVectorSerde::registerNamedVectorSerde(); } - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kCompactRow)) { + if (!isRegisteredNamedVectorSerde("CompactRow")) { serializer::CompactRowVectorSerde::registerNamedVectorSerde(); } - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kUnsafeRow)) { + if (!isRegisteredNamedVectorSerde("UnsafeRow")) { serializer::spark::UnsafeRowVectorSerde::registerNamedVectorSerde(); } driverExecutor_ = std::make_unique(3); @@ -153,6 +153,7 @@ void OperatorTestBase::SetUp() { void OperatorTestBase::TearDown() { waitForAllTasksToBeDeleted(); stopPeriodicStatsReporter(); + executor_.reset(); // There might be lingering exchange source on executor even after all tasks // are deleted. This can cause memory leak because exchange source holds // reference to memory pool. We need to make sure they are properly cleaned. @@ -243,7 +244,7 @@ core::TypedExprPtr OperatorTestBase::parseExpr( const std::string& text, RowTypePtr rowType, const parse::ParseOptions& options) { - auto untyped = parse::parseExpr(text, options); + auto untyped = parse::DuckSqlExpressionsParser(options).parseExpr(text); return core::Expressions::inferTypes(untyped, rowType, pool_.get()); } diff --git a/velox/exec/tests/utils/PlanBuilder.cpp b/velox/exec/tests/utils/PlanBuilder.cpp index d48914fd168..32fbbaad617 100644 --- a/velox/exec/tests/utils/PlanBuilder.cpp +++ b/velox/exec/tests/utils/PlanBuilder.cpp @@ -46,7 +46,7 @@ core::TypedExprPtr parseExpr( const RowTypePtr& rowType, const parse::ParseOptions& options, memory::MemoryPool* pool) { - auto untyped = parse::parseExpr(text, options); + auto untyped = parse::DuckSqlExpressionsParser(options).parseExpr(text); return core::Expressions::inferTypes(untyped, rowType, pool); } @@ -143,7 +143,8 @@ PlanBuilder& PlanBuilder::tpchTableScan( core::TypedExprPtr filterExpression; if (!filter.empty()) { - auto expression = parse::parseExpr(filter, options_); + auto expression = + parse::DuckSqlExpressionsParser(options_).parseExpr(filter); filterExpression = core::Expressions::inferTypes(expression, rowType, pool_); } @@ -184,12 +185,65 @@ PlanBuilder& PlanBuilder::tpcdsTableScan( auto rowType = ROW(std::move(columnNames), std::move(outputTypes)); return TableScanBuilder(*this) .outputType(rowType) - .tableHandle(std::make_shared( - std::string(connectorId), table, scaleFactor)) + .tableHandle( + std::make_shared( + std::string(connectorId), table, scaleFactor)) .assignments(assignmentsMap) .endTableScan(); } +namespace { + +// Analyzes 'expr' to determine if it can be expressed as a subfield filter. +// Returns a pair of subfield and filter if so. Otherwise, throws. +// +// Supports all expressions supported by +// exec::ExprToSubfieldFilterParser::leafCallToSubfieldFilter + negations and +// disjunctions over same subfield. +// +// Examples: +// a = 1 +// a = 1 OR a > 10 +// not (a = 1) +std::pair> toSubfieldFilter( + const core::TypedExprPtr& expr, + core::ExpressionEvaluator* evaluator) { + if (expr->isCallKind(); + auto* call = expr->asUnchecked()) { + if (call->name() == "or") { + VELOX_CHECK_EQ(call->inputs().size(), 2); + auto left = toSubfieldFilter(call->inputs()[0], evaluator); + auto right = toSubfieldFilter(call->inputs()[1], evaluator); + VELOX_CHECK(left.first == right.first); + auto filter = exec::ExprToSubfieldFilterParser::makeOrFilter( + std::move(left.second), std::move(right.second)); + VELOX_CHECK_NOT_NULL(filter); + return {std::move(left.first), std::move(filter)}; + } + + if (call->name() == "not") { + const auto& input = call->inputs()[0]; + if (input->isCallKind(); + auto* inner = input->asUnchecked()) { + if (auto result = + exec::ExprToSubfieldFilterParser::getInstance() + ->leafCallToSubfieldFilter(*inner, evaluator, true)) { + return std::move(result).value(); + } + } + } else { + if (auto result = + exec::ExprToSubfieldFilterParser::getInstance() + ->leafCallToSubfieldFilter(*call, evaluator, false)) { + return std::move(result).value(); + } + } + } + VELOX_UNSUPPORTED( + "Unsupported expression for range filter: {}", expr->toString()); +} +} // namespace + PlanBuilder::TableScanBuilder& PlanBuilder::TableScanBuilder::subfieldFilters( std::vector subfieldFilters) { VELOX_CHECK(subfieldFiltersMap_.empty()); @@ -204,13 +258,13 @@ PlanBuilder::TableScanBuilder& PlanBuilder::TableScanBuilder::subfieldFilters( const RowTypePtr& parseType = dataColumns_ ? dataColumns_ : outputType_; for (const auto& filter : subfieldFilters) { - auto untypedExpr = parse::parseExpr(filter, planBuilder_.options_); + auto untypedExpr = parse::DuckSqlExpressionsParser(planBuilder_.options_) + .parseExpr(filter); // Parse directly to subfieldFiltersMap_ auto filterExpr = core::Expressions::inferTypes( untypedExpr, parseType, planBuilder_.pool_); - auto [subfield, subfieldFilter] = - exec::toSubfieldFilter(filterExpr, &evaluator); + auto [subfield, subfieldFilter] = toSubfieldFilter(filterExpr, &evaluator); auto it = columnAliases_.find(subfield.toString()); if (it != columnAliases_.end()) { @@ -239,11 +293,17 @@ PlanBuilder::TableScanBuilder::subfieldFiltersMap( PlanBuilder::TableScanBuilder& PlanBuilder::TableScanBuilder::remainingFilter( std::string remainingFilter) { if (!remainingFilter.empty()) { - remainingFilter_ = parse::parseExpr(remainingFilter, planBuilder_.options_); + remainingFilter_ = parse::DuckSqlExpressionsParser(planBuilder_.options_) + .parseExpr(remainingFilter); } return *this; } +PlanBuilder::TableScanBuilder& PlanBuilder::TableScanBuilder::sampleRate( + double sampleRate) { + sampleRate_ = sampleRate; + return *this; +} namespace { void addConjunct( const core::TypedExprPtr& conjunct, @@ -279,13 +339,24 @@ core::PlanNodePtr PlanBuilder::TableScanBuilder::build(core::PlanNodeId id) { {name, std::make_shared( hiveColumnName, - HiveColumnHandle::ColumnType::kRegular, + FileColumnHandle::ColumnType::kRegular, type, type)}); } } - const RowTypePtr& parseType = dataColumns_ ? dataColumns_ : outputType_; + RowTypePtr parseType = dataColumns_ ? dataColumns_ : outputType_; + if (!filterColumnHandles_.empty()) { + auto names = parseType->names(); + auto types = parseType->children(); + for (auto& handle : filterColumnHandles_) { + if (!parseType->containsChild(handle->name())) { + names.push_back(handle->name()); + types.push_back(handle->hiveType()); + } + } + parseType = ROW(std::move(names), std::move(types)); + } core::TypedExprPtr filterNodeExpr; @@ -315,10 +386,12 @@ core::PlanNodePtr PlanBuilder::TableScanBuilder::build(core::PlanNodeId id) { tableHandle_ = std::make_shared( connectorId_, tableName_, - true, std::move(subfieldFiltersMap_), remainingFilterExpr, - dataColumns_); + dataColumns_, + indexColumns_, + /*tableParameters=*/std::unordered_map{}, + filterColumnHandles_); } core::PlanNodePtr result = std::make_shared( id, outputType_, tableHandle_, assignments_); @@ -354,8 +427,8 @@ core::PlanNodePtr PlanBuilder::TableWriterBuilder::build(core::PlanNodeId id) { std::make_shared( column, isPartitionKey - ? connector::hive::HiveColumnHandle::ColumnType::kPartitionKey - : connector::hive::HiveColumnHandle::ColumnType::kRegular, + ? connector::hive::FileColumnHandle::ColumnType::kPartitionKey + : connector::hive::FileColumnHandle::ColumnType::kRegular, outputType->childAt(i), outputType->childAt(i))); } @@ -393,8 +466,9 @@ core::PlanNodePtr PlanBuilder::TableWriterBuilder::build(core::PlanNodeId id) { std::vector groupingKeys; groupingKeys.reserve(partitionBy_.size()); for (const auto& partitionBy : partitionBy_) { - groupingKeys.push_back(std::make_shared( - outputType->findChild(partitionBy), partitionBy)); + groupingKeys.push_back( + std::make_shared( + outputType->findChild(partitionBy), partitionBy)); } columnStatsSpec = core::ColumnStatsSpec( std::move(groupingKeys), @@ -445,7 +519,7 @@ PlanBuilder& PlanBuilder::traceScan( PlanBuilder& PlanBuilder::exchange( const RowTypePtr& outputType, - VectorSerde::Kind serdeKind) { + std::string serdeKind) { VELOX_CHECK_NULL(planNode_, "Exchange must be the source node"); planNode_ = std::make_shared( nextPlanNodeId(), outputType, serdeKind); @@ -464,7 +538,7 @@ parseOrderByClauses( std::vector> sortingKeys; std::vector sortingOrders; for (const auto& key : keys) { - auto orderBy = parse::parseOrderByExpr(key); + auto orderBy = parse::DuckSqlExpressionsParser().parseOrderByExpr(key); auto typedExpr = core::Expressions::inferTypes(orderBy.expr, inputType, pool); @@ -485,7 +559,7 @@ parseOrderByClauses( PlanBuilder& PlanBuilder::mergeExchange( const RowTypePtr& outputType, const std::vector& keys, - VectorSerde::Kind serdeKind) { + std::string serdeKind) { VELOX_CHECK_NULL(planNode_, "MergeExchange must be the source node"); auto [sortingKeys, sortingOrders] = parseOrderByClauses(keys, outputType, pool_); @@ -556,7 +630,8 @@ PlanBuilder& PlanBuilder::project(const std::vector& projections) { std::vector> expressions; expressions.reserve(projections.size()); for (auto i = 0; i < projections.size(); ++i) { - expressions.push_back(parse::parseExpr(projections[i], options_)); + expressions.push_back( + parse::DuckSqlExpressionsParser(options_).parseExpr(projections[i])); } return projectExpressions(expressions); } @@ -578,7 +653,8 @@ PlanBuilder& PlanBuilder::parallelProject( typedExprs.reserve(group.size()); for (const auto& expr : group) { - const auto typedExpr = inferTypes(parse::parseExpr(expr, options_)); + const auto typedExpr = + inferTypes(parse::DuckSqlExpressionsParser(options_).parseExpr(expr)); typedExprs.push_back(typedExpr); if (auto fieldExpr = @@ -609,7 +685,8 @@ PlanBuilder& PlanBuilder::lazyDereference( std::vector expressions; std::vector projectNames; for (auto i = 0; i < projections.size(); ++i) { - auto expr = inferTypes(parse::parseExpr(projections[i], options_)); + auto expr = inferTypes( + parse::DuckSqlExpressionsParser(options_).parseExpr(projections[i])); expressions.push_back(expr); if (auto* fieldExpr = dynamic_cast(expr.get())) { @@ -645,15 +722,21 @@ PlanBuilder& PlanBuilder::optionalFilter(const std::string& optionalFilter) { return filter(optionalFilter); } -PlanBuilder& PlanBuilder::filter(const std::string& filter) { +PlanBuilder& PlanBuilder::filter(const core::ExprPtr& filterExpr) { VELOX_CHECK_NOT_NULL(planNode_, "Filter cannot be the source node"); - auto expr = parseExpr(filter, planNode_->outputType(), options_, pool_); - planNode_ = - std::make_shared(nextPlanNodeId(), expr, planNode_); + auto typedExpr = + core::Expressions::inferTypes(filterExpr, planNode_->outputType(), pool_); + planNode_ = std::make_shared( + nextPlanNodeId(), typedExpr, planNode_); VELOX_CHECK(planNode_->supportsBarrier()); return *this; } +PlanBuilder& PlanBuilder::filter(const std::string& filterExpr) { + return filter( + parse::DuckSqlExpressionsParser(options_).parseExpr(filterExpr)); +} + PlanBuilder& PlanBuilder::tableWrite( const std::string& outputDirectoryPath, const dwio::common::FileFormat fileFormat, @@ -718,7 +801,8 @@ PlanBuilder& PlanBuilder::tableWrite( const common::CompressionKind compressionKind, const RowTypePtr& schema, const bool ensureFiles, - const connector::CommitStrategy commitStrategy) { + const connector::CommitStrategy commitStrategy, + std::shared_ptr insertTableHandle) { return TableWriterBuilder(*this) .outputDirectoryPath(outputDirectoryPath) .outputFileName(outputFileName) @@ -735,6 +819,7 @@ PlanBuilder& PlanBuilder::tableWrite( .compressionKind(compressionKind) .ensureFiles(ensureFiles) .commitStrategy(commitStrategy) + .insertHandle(insertTableHandle) .endTableWriter(); } @@ -754,51 +839,56 @@ const core::TableWriteNodePtr findTableWrite(const core::PlanNodePtr planNode) { } } // namespace -PlanBuilder& PlanBuilder::tableWriteMerge() { +PlanBuilder& PlanBuilder::tableWriteMerge(core::AggregationNode::Step step) { VELOX_CHECK_NOT_NULL(planNode_, "TableWriteMerge cannot be the source node"); auto writer = findTableWrite(planNode_); VELOX_CHECK_NOT_NULL( - writer, "TableWriteMerge can only be added after TableWrite node"); + writer, "TableWriteMerge requires a TableWrite node in the plan tree"); std::optional columnStatsSpec; if (writer->hasColumnStatsSpec()) { - const auto writerSpec = writer->columnStatsSpec().value(); - VELOX_CHECK_EQ( - writerSpec.aggregationStep, core::AggregationNode::Step::kPartial); - std::vector> aggregateRawInputs; - const auto numAggregates = writerSpec.aggregates.size(); - aggregateRawInputs.reserve(numAggregates); - for (const auto& aggregate : writerSpec.aggregates) { - aggregateRawInputs.push_back(aggregate.rawInputTypes); - } + const auto& writerSpec = writer->columnStatsSpec().value(); const auto& inputType = planNode_->outputType(); + const auto numAggregates = writerSpec.aggregates.size(); std::vector aggregateNames; aggregateNames.reserve(numAggregates); std::vector aggregates; aggregates.reserve(numAggregates); - for (int i = 0; i < numAggregates; ++i) { - core::AggregationNode::Aggregate aggregate = writerSpec.aggregates[i]; + for (size_t i = 0; i < numAggregates; ++i) { + auto aggregate = writerSpec.aggregates[i]; aggregate.call = std::make_shared( aggregate.call->type(), aggregate.call->name(), field(inputType, writerSpec.aggregateNames[i])); aggregates.push_back(std::move(aggregate)); - aggregateNames.push_back(fmt::format("a{}", i)); + aggregateNames.push_back(writerSpec.aggregateNames[i]); } columnStatsSpec = core::ColumnStatsSpec{ writerSpec.groupingKeys, - core::AggregationNode::Step::kIntermediate, + step, std::move(aggregateNames), std::move(aggregates)}; } + auto outputType = TableWriteTraits::outputType(columnStatsSpec); planNode_ = std::make_shared( nextPlanNodeId(), - TableWriteTraits::outputType(columnStatsSpec), - columnStatsSpec, + std::move(outputType), + std::move(columnStatsSpec), + planNode_); + return *this; +} + +PlanBuilder& PlanBuilder::tableWriteMerge( + core::ColumnStatsSpec columnStatsSpec) { + VELOX_CHECK_NOT_NULL(planNode_, "TableWriteMerge cannot be the source node"); + auto outputType = TableWriteTraits::outputType(columnStatsSpec); + planNode_ = std::make_shared( + nextPlanNodeId(), + std::move(outputType), + std::move(columnStatsSpec), planNode_); - VELOX_CHECK(!planNode_->supportsBarrier()); return *this; } @@ -850,6 +940,7 @@ core::PlanNodePtr PlanBuilder::createIntermediateOrFinalAggregation( partialAggNode->aggregateNames(), aggregates, partialAggNode->ignoreNullKeys(), + partialAggNode->noGroupsSpanBatches(), planNode_); VELOX_CHECK_EQ( aggregationNode->supportsBarrier(), aggregationNode->isPreGrouped()); @@ -943,12 +1034,17 @@ PlanBuilder::AggregatesAndNames PlanBuilder::createAggregateExpressionsAndNames( resolver.setRawInputTypes(rawInputTypes[i]); } - auto untypedExpr = duckdb::parseAggregateExpr(aggregate, options); + auto aggCall = duckdb::parseAggregateExpr(aggregate, options); + + // Build a plain CallExpr for type resolution (AggregateCallExpr carries + // options that CallTypedExpr doesn't need). + auto plainCall = std::make_shared( + aggCall->name(), aggCall->inputs(), std::nullopt); core::AggregationNode::Aggregate agg; agg.call = std::dynamic_pointer_cast( - inferTypes(untypedExpr.expr)); + inferTypes(plainCall)); if (step == core::AggregationNode::Step::kPartial || step == core::AggregationNode::Step::kSingle) { @@ -959,10 +1055,10 @@ PlanBuilder::AggregatesAndNames PlanBuilder::createAggregateExpressionsAndNames( agg.rawInputTypes = rawInputTypes[i]; } - if (untypedExpr.maskExpr != nullptr) { + if (aggCall->filter() != nullptr) { auto maskExpr = std::dynamic_pointer_cast( - inferTypes(untypedExpr.maskExpr)); + inferTypes(aggCall->filter())); VELOX_CHECK_NOT_NULL( maskExpr, "FILTER clause must use a column name, not an expression: {}", @@ -977,9 +1073,9 @@ PlanBuilder::AggregatesAndNames PlanBuilder::createAggregateExpressionsAndNames( agg.mask = field(masks[i]); } - agg.distinct = untypedExpr.distinct; + agg.distinct = aggCall->isDistinct(); - if (!untypedExpr.orderBy.empty()) { + if (!aggCall->orderBy().empty()) { auto* entry = exec::getAggregateFunctionEntry(agg.call->name()); const auto& metadata = entry->metadata; if (metadata.orderSensitive) { @@ -989,25 +1085,25 @@ PlanBuilder::AggregatesAndNames PlanBuilder::createAggregateExpressionsAndNames( "into partial and final: {}.", aggregate); } - } - for (const auto& orderBy : untypedExpr.orderBy) { - auto sortingKey = - std::dynamic_pointer_cast( - inferTypes(orderBy.expr)); - VELOX_CHECK_NOT_NULL( - sortingKey, - "ORDER BY clause must use a column name, not an expression: {}", - aggregate); + for (const auto& orderBy : aggCall->orderBy()) { + auto sortingKey = + std::dynamic_pointer_cast( + inferTypes(orderBy.expr)); + VELOX_CHECK_NOT_NULL( + sortingKey, + "ORDER BY clause must use a column name, not an expression: {}", + aggregate); - agg.sortingKeys.push_back(sortingKey); - agg.sortingOrders.emplace_back(orderBy.ascending, orderBy.nullsFirst); + agg.sortingKeys.push_back(sortingKey); + agg.sortingOrders.emplace_back(orderBy.ascending, orderBy.nullsFirst); + } } aggs.emplace_back(agg); - if (untypedExpr.expr->alias().has_value()) { - names.push_back(untypedExpr.expr->alias().value()); + if (aggCall->alias().has_value()) { + names.push_back(aggCall->alias().value()); } else { names.push_back(fmt::format("a{}", i)); } @@ -1055,6 +1151,7 @@ PlanBuilder& PlanBuilder::aggregation( globalGroupingSets, groupId, ignoreNullKeys, + /*noGroupsSpanBatches=*/false, planNode_); VELOX_CHECK_EQ( aggregationNode->supportsBarrier(), aggregationNode->isPreGrouped()); @@ -1067,7 +1164,8 @@ PlanBuilder& PlanBuilder::streamingAggregation( const std::vector& aggregates, const std::vector& masks, core::AggregationNode::Step step, - bool ignoreNullKeys) { + bool ignoreNullKeys, + bool noGroupsSpanBatches) { auto aggregatesAndNames = createAggregateExpressionsAndNames(aggregates, masks, step); auto aggregationNode = std::make_shared( @@ -1078,6 +1176,7 @@ PlanBuilder& PlanBuilder::streamingAggregation( aggregatesAndNames.names, aggregatesAndNames.aggregates, ignoreNullKeys, + noGroupsSpanBatches, planNode_); VELOX_CHECK_EQ( aggregationNode->supportsBarrier(), aggregationNode->isPreGrouped()); @@ -1093,7 +1192,8 @@ PlanBuilder& PlanBuilder::groupId( std::vector groupingKeyInfos; groupingKeyInfos.reserve(groupingKeys.size()); for (const auto& groupingKey : groupingKeys) { - auto untypedExpr = parse::parseExpr(groupingKey, options_); + auto untypedExpr = + parse::DuckSqlExpressionsParser(options_).parseExpr(groupingKey); const auto* fieldAccessExpr = dynamic_cast(untypedExpr.get()); VELOX_USER_CHECK( @@ -1162,7 +1262,9 @@ PlanBuilder& PlanBuilder::expand( std::vector projectExpr; VELOX_CHECK_EQ(numColumns, projections[i].size()); for (auto j = 0; j < numColumns; j++) { - auto untypedExpression = parse::parseExpr(projections[i][j], options_); + auto untypedExpression = + parse::DuckSqlExpressionsParser(options_).parseExpr( + projections[i][j]); auto typedExpression = inferTypes(untypedExpression); if (i == 0) { @@ -1186,8 +1288,9 @@ PlanBuilder& PlanBuilder::expand( dynamic_cast(untypedExpression.get()); VELOX_CHECK_NOT_NULL(constantExpr); VELOX_CHECK(constantExpr->value().isNull()); - projectExpr.push_back(std::make_shared( - expectedType, variant::null(expectedType->kind()))); + projectExpr.push_back( + std::make_shared( + expectedType, variant::null(expectedType->kind()))); } } } @@ -1354,7 +1457,7 @@ PlanBuilder& PlanBuilder::partitionedOutput( const std::vector& keys, int numPartitions, const std::vector& outputLayout, - VectorSerde::Kind serdeKind) { + std::string serdeKind) { return partitionedOutput(keys, numPartitions, false, outputLayout, serdeKind); } @@ -1363,7 +1466,7 @@ PlanBuilder& PlanBuilder::partitionedOutput( int numPartitions, bool replicateNullsAndAny, const std::vector& outputLayout, - VectorSerde::Kind serdeKind) { + std::string serdeKind) { VELOX_CHECK_NOT_NULL( planNode_, "PartitionedOutput cannot be the source node"); @@ -1383,7 +1486,7 @@ PlanBuilder& PlanBuilder::partitionedOutput( bool replicateNullsAndAny, core::PartitionFunctionSpecPtr partitionFunctionSpec, const std::vector& outputLayout, - VectorSerde::Kind serdeKind) { + std::string serdeKind) { VELOX_CHECK_NOT_NULL( planNode_, "PartitionedOutput cannot be the source node"); auto outputType = outputLayout.empty() @@ -1405,7 +1508,7 @@ PlanBuilder& PlanBuilder::partitionedOutput( PlanBuilder& PlanBuilder::partitionedOutputBroadcast( const std::vector& outputLayout, - VectorSerde::Kind serdeKind) { + std::string serdeKind) { VELOX_CHECK_NOT_NULL( planNode_, "PartitionedOutput cannot be the source node"); auto outputType = outputLayout.empty() @@ -1419,7 +1522,7 @@ PlanBuilder& PlanBuilder::partitionedOutputBroadcast( PlanBuilder& PlanBuilder::partitionedOutputArbitrary( const std::vector& outputLayout, - VectorSerde::Kind serdeKind) { + std::string serdeKind) { VELOX_CHECK_NOT_NULL( planNode_, "PartitionedOutput cannot be the source node"); auto outputType = outputLayout.empty() @@ -1456,6 +1559,10 @@ PlanBuilder& PlanBuilder::localPartition(const std::vector& keys) { return *this; } +PlanBuilder& PlanBuilder::localGather() { + return localPartition(std::vector{}); +} + PlanBuilder& PlanBuilder::scaleWriterlocalPartition( const std::vector& keys) { std::vector keyIndices; @@ -1622,7 +1729,8 @@ PlanBuilder& PlanBuilder::hashJoin( const std::string& filter, const std::vector& outputLayout, core::JoinType joinType, - bool nullAware) { + bool nullAware, + bool nullAsValue) { VELOX_CHECK_NOT_NULL(planNode_, "HashJoin cannot be the source node"); VELOX_CHECK_EQ(leftKeys.size(), rightKeys.size()); @@ -1663,7 +1771,9 @@ PlanBuilder& PlanBuilder::hashJoin( std::move(filterExpr), std::move(planNode_), build, - outputType); + outputType, + /*useHashTableCache=*/false, + nullAsValue); VELOX_CHECK(!planNode_->supportsBarrier()); return *this; } @@ -1741,20 +1851,35 @@ PlanBuilder& PlanBuilder::nestedLoopJoin( PlanBuilder& PlanBuilder::spatialJoin( const core::PlanNodePtr& right, const std::string& joinCondition, + const std::string& probeGeometry, + const std::string& buildGeometry, + const std::optional& radius, const std::vector& outputLayout, core::JoinType joinType) { VELOX_CHECK_NOT_NULL(planNode_, "SpatialJoin cannot be the source node"); - auto resultType = concat(planNode_->outputType(), right->outputType()); + auto probeType = planNode_->outputType(); + auto buildType = right->outputType(); + auto resultType = concat(probeType, buildType); auto outputType = extract(resultType, outputLayout); VELOX_CHECK(!joinCondition.empty(), "SpatialJoin condition cannot be empty"); core::TypedExprPtr joinConditionExpr = parseExpr(joinCondition, resultType, options_, pool_); + auto probeGeometryField = field(probeType, probeGeometry); + auto buildGeometryField = field(buildType, buildGeometry); + std::optional radiusField; + if (radius.has_value()) { + radiusField = field(buildType, radius.value()); + } + planNode_ = std::make_shared( nextPlanNodeId(), joinType, std::move(joinConditionExpr), + std::move(probeGeometryField), + std::move(buildGeometryField), + std::move(radiusField), std::move(planNode_), right, outputType); @@ -1965,12 +2090,13 @@ PlanBuilder& PlanBuilder::indexLookupJoin( const std::vector& rightKeys, const core::TableScanNodePtr& right, const std::vector& joinConditions, - bool includeMatchColumn, + const std::string& filter, + bool hasMarker, const std::vector& outputLayout, core::JoinType joinType) { VELOX_CHECK_NOT_NULL(planNode_, "indexLookupJoin cannot be the source node"); auto inputType = concat(planNode_->outputType(), right->outputType()); - if (includeMatchColumn) { + if (hasMarker) { auto names = inputType->names(); names.push_back(outputLayout.back()); auto types = inputType->children(); @@ -1988,13 +2114,20 @@ PlanBuilder& PlanBuilder::indexLookupJoin( parseIndexJoinCondition(joinCondition, inputType, pool_)); } + // Parse filter expression if provided + core::TypedExprPtr filterExpr; + if (!filter.empty()) { + filterExpr = parseExpr(filter, inputType, options_, pool_); + } + planNode_ = std::make_shared( nextPlanNodeId(), joinType, std::move(leftKeyFields), std::move(rightKeyFields), std::move(joinConditionPtrs), - includeMatchColumn, + filterExpr, + hasMarker, std::move(planNode_), right, std::move(outputType)); @@ -2006,7 +2139,7 @@ PlanBuilder& PlanBuilder::unnest( const std::vector& replicateColumns, const std::vector& unnestColumns, const std::optional& ordinalColumn, - const std::optional& emptyUnnestValueName) { + const std::optional& markerName) { VELOX_CHECK_NOT_NULL(planNode_, "Unnest cannot be the source node"); std::vector> replicateFields; @@ -2042,7 +2175,7 @@ PlanBuilder& PlanBuilder::unnest( unnestFields, unnestNames, ordinalColumn, - emptyUnnestValueName, + markerName, planNode_); VELOX_CHECK(planNode_->supportsBarrier()); return *this; @@ -2050,24 +2183,21 @@ PlanBuilder& PlanBuilder::unnest( namespace { std::string throwWindowFunctionDoesntExist(const std::string& name) { - std::stringstream error; - error << "Window function doesn't exist: " << name << "."; if (exec::windowFunctions().empty()) { - error << " Registry of window functions is empty. " - "Make sure to register some window functions."; + VELOX_USER_FAIL( + "Registry of window functions is empty. Make sure to register some window functions."); } - VELOX_USER_FAIL(error.str()); + VELOX_USER_FAIL("Window function doesn't exist: {}.", name); } std::string throwWindowFunctionSignatureNotSupported( const std::string& name, const std::vector& types, const std::vector& signatures) { - std::stringstream error; - error << "Window function signature is not supported: " - << toString(name, types) - << ". Supported signatures: " << toString(signatures) << "."; - VELOX_USER_FAIL(error.str()); + VELOX_USER_FAIL( + "Window function signature is not supported: {}. Supported signatures: {}.", + toString(name, types), + toString(signatures)); } TypePtr resolveWindowType( @@ -2076,7 +2206,8 @@ TypePtr resolveWindowType( bool nullOnFailure) { if (auto signatures = exec::getWindowFunctionSignatures(windowFunctionName)) { for (const auto& signature : signatures.value()) { - exec::SignatureBinder binder(*signature, inputTypes); + exec::SignatureBinder binder( + *signature, inputTypes, TypeCoercer::defaults()); if (binder.tryBind()) { return binder.tryResolveType(signature->returnType()); } @@ -2129,48 +2260,51 @@ class WindowTypeResolver { }; const core::WindowNode::Frame createWindowFrame( - const duckdb::IExprWindowFrame& windowFrame, + const core::WindowCallExpr& windowCall, const TypePtr& inputRow, memory::MemoryPool* pool) { - core::WindowNode::Frame frame; - frame.type = (windowFrame.type == duckdb::WindowType::kRows) - ? core::WindowNode::WindowType::kRows - : core::WindowNode::WindowType::kRange; - - auto boundTypeConversion = - [](duckdb::BoundType boundType) -> core::WindowNode::BoundType { + auto boundTypeConversion = [](core::WindowCallExpr::BoundType boundType) + -> core::WindowNode::BoundType { switch (boundType) { - case duckdb::BoundType::kCurrentRow: + case core::WindowCallExpr::BoundType::kCurrentRow: return core::WindowNode::BoundType::kCurrentRow; - case duckdb::BoundType::kFollowing: + case core::WindowCallExpr::BoundType::kFollowing: return core::WindowNode::BoundType::kFollowing; - case duckdb::BoundType::kPreceding: + case core::WindowCallExpr::BoundType::kPreceding: return core::WindowNode::BoundType::kPreceding; - case duckdb::BoundType::kUnboundedFollowing: + case core::WindowCallExpr::BoundType::kUnboundedFollowing: return core::WindowNode::BoundType::kUnboundedFollowing; - case duckdb::BoundType::kUnboundedPreceding: + case core::WindowCallExpr::BoundType::kUnboundedPreceding: return core::WindowNode::BoundType::kUnboundedPreceding; } VELOX_UNREACHABLE(); }; - frame.startType = boundTypeConversion(windowFrame.startType); - frame.startValue = windowFrame.startValue - ? core::Expressions::inferTypes(windowFrame.startValue, inputRow, pool) + + core::WindowNode::Frame frame; + const auto& windowFrame = windowCall.frame(); + VELOX_CHECK(windowFrame.has_value(), "Window frame must be specified"); + + frame.type = (windowFrame->type == core::WindowCallExpr::WindowType::kRows) + ? core::WindowNode::WindowType::kRows + : core::WindowNode::WindowType::kRange; + frame.startType = boundTypeConversion(windowFrame->startType); + frame.startValue = windowFrame->startValue + ? core::Expressions::inferTypes(windowFrame->startValue, inputRow, pool) : nullptr; - frame.endType = boundTypeConversion(windowFrame.endType); - frame.endValue = windowFrame.endValue - ? core::Expressions::inferTypes(windowFrame.endValue, inputRow, pool) + frame.endType = boundTypeConversion(windowFrame->endType); + frame.endValue = windowFrame->endValue + ? core::Expressions::inferTypes(windowFrame->endValue, inputRow, pool) : nullptr; return frame; } std::vector parsePartitionKeys( - const duckdb::IExprWindowFunction& windowExpr, + const core::WindowCallExpr& windowCall, const std::string& windowString, const TypePtr& inputRow, memory::MemoryPool* pool) { std::vector partitionKeys; - for (const auto& partitionKey : windowExpr.partitionBy) { + for (const auto& partitionKey : windowCall.partitionKeys()) { auto typedExpr = core::Expressions::inferTypes(partitionKey, inputRow, pool); auto typedPartitionKey = @@ -2188,14 +2322,14 @@ std::pair< std::vector, std::vector> parseOrderByKeys( - const duckdb::IExprWindowFunction& windowExpr, + const core::WindowCallExpr& windowCall, const std::string& windowString, const TypePtr& inputRow, memory::MemoryPool* pool) { std::vector sortingKeys; std::vector sortingOrders; - for (const auto& orderBy : windowExpr.orderBy) { + for (const auto& orderBy : windowCall.orderByKeys()) { auto typedExpr = core::Expressions::inferTypes(orderBy.expr, inputRow, pool); auto sortingKey = @@ -2259,33 +2393,34 @@ PlanBuilder& PlanBuilder::window( auto errorOnMismatch = [&](const std::string& windowString, const std::string& mismatchTypeString) -> void { - std::stringstream error; - error << "Window function invocations " << windowString << " and " - << windowFunctions[0] << " do not match " << mismatchTypeString - << " clauses."; - VELOX_USER_FAIL(error.str()); + VELOX_USER_FAIL( + "Window function invocations {} and {} do not match {} clauses.", + windowString, + windowFunctions[0], + mismatchTypeString); }; WindowTypeResolver windowResolver; facebook::velox::duckdb::ParseOptions options; options.parseIntegerAsBigint = options_.parseIntegerAsBigint; for (const auto& windowString : windowFunctions) { - const auto& windowExpr = duckdb::parseWindowExpr(windowString, options); + auto windowExprPtr = duckdb::parseWindowExpr(windowString, options); + auto* windowCall = windowExprPtr->as(); // All window function SQL strings in the list are expected to have the same // PARTITION BY and ORDER BY clauses. Validate this assumption. if (first) { partitionKeys = - parsePartitionKeys(windowExpr, windowString, inputType, pool_); + parsePartitionKeys(*windowCall, windowString, inputType, pool_); auto sortPair = - parseOrderByKeys(windowExpr, windowString, inputType, pool_); + parseOrderByKeys(*windowCall, windowString, inputType, pool_); sortingKeys = sortPair.first; sortingOrders = sortPair.second; first = false; } else { auto latestPartitionKeys = - parsePartitionKeys(windowExpr, windowString, inputType, pool_); + parsePartitionKeys(*windowCall, windowString, inputType, pool_); auto [latestSortingKeys, latestSortingOrders] = - parseOrderByKeys(windowExpr, windowString, inputType, pool_); + parseOrderByKeys(*windowCall, windowString, inputType, pool_); if (!equalFieldAccessTypedExprPtrList( partitionKeys, latestPartitionKeys)) { @@ -2301,15 +2436,19 @@ PlanBuilder& PlanBuilder::window( } } - auto windowCall = std::dynamic_pointer_cast( + // Build a plain CallExpr for type resolution (WindowCallExpr carries + // partition/order metadata that CallTypedExpr doesn't need). + auto plainCall = std::make_shared( + windowCall->name(), windowCall->inputs(), std::nullopt); + auto typedCall = std::dynamic_pointer_cast( core::Expressions::inferTypes( - windowExpr.functionCall, planNode_->outputType(), pool_)); + plainCall, planNode_->outputType(), pool_)); windowNodeFunctions.push_back( - {std::move(windowCall), - createWindowFrame(windowExpr.frame, planNode_->outputType(), pool_), - windowExpr.ignoreNulls}); - if (windowExpr.functionCall->alias().has_value()) { - windowNames.push_back(windowExpr.functionCall->alias().value()); + {std::move(typedCall), + createWindowFrame(*windowCall, planNode_->outputType(), pool_), + windowCall->isIgnoreNulls()}); + if (windowExprPtr->alias().has_value()) { + windowNames.push_back(windowExprPtr->alias().value()); } else { windowNames.push_back(fmt::format("w{}", i++)); } @@ -2404,6 +2543,49 @@ PlanBuilder& PlanBuilder::markDistinct( return *this; } +PlanBuilder& PlanBuilder::enforceDistinct( + const std::vector& distinctKeys, + std::string errorMessage, + const std::vector& preGroupedKeys) { + VELOX_CHECK_NOT_NULL(planNode_, "EnforceDistinct cannot be the source node"); + planNode_ = std::make_shared( + nextPlanNodeId(), + fields(planNode_->outputType(), distinctKeys), + fields(planNode_->outputType(), preGroupedKeys), + std::move(errorMessage), + planNode_); + VELOX_CHECK(!planNode_->supportsBarrier()); + return *this; +} + +PlanBuilder& PlanBuilder::streamingEnforceDistinct( + const std::vector& distinctKeys, + std::string errorMessage) { + return enforceDistinct(distinctKeys, std::move(errorMessage), distinctKeys); +} + +PlanBuilder& PlanBuilder::markSorted( + const std::string& markerKey, + const std::vector& sortingKeys, + const std::vector& sortingOrders) { + VELOX_CHECK_NOT_NULL(planNode_, "MarkSorted cannot be the source node"); + VELOX_CHECK_EQ(sortingKeys.size(), sortingOrders.size()); + + std::vector keyExprs; + for (const auto& key : sortingKeys) { + keyExprs.push_back(field(planNode_->outputType(), key)); + } + + planNode_ = core::MarkSortedNode::Builder() + .id(nextPlanNodeId()) + .markerName(markerKey) + .sortingKeys(keyExprs) + .sortingOrders(sortingOrders) + .source(planNode_) + .build(); + return *this; +} + core::PlanNodeId PlanBuilder::nextPlanNodeId() { return planNodeIdGenerator_->next(); } @@ -2479,7 +2661,9 @@ std::vector PlanBuilder::exprs( std::vector typedExpressions; for (auto& expr : expressions) { auto typedExpression = core::Expressions::inferTypes( - parse::parseExpr(expr, options_), inputType, pool_); + parse::DuckSqlExpressionsParser(options_).parseExpr(expr), + inputType, + pool_); if (dynamic_cast( typedExpression.get())) { @@ -2508,7 +2692,7 @@ core::PlanNodePtr PlanBuilder::IndexLookupJoinBuilder::build( planBuilder_.planNode_, "IndexLookupJoin cannot be the source node"); auto inputType = concat(planBuilder_.planNode_->outputType(), indexSource_->outputType()); - if (includeMatchColumn_) { + if (hasMarker_) { auto names = inputType->names(); names.push_back(outputLayout_.back()); auto types = inputType->children(); @@ -2524,8 +2708,16 @@ core::PlanNodePtr PlanBuilder::IndexLookupJoinBuilder::build( std::vector joinConditionPtrs{}; joinConditionPtrs.reserve(joinConditions_.size()); for (const auto& joinCondition : joinConditions_) { - joinConditionPtrs.push_back(PlanBuilder::parseIndexJoinCondition( - joinCondition, inputType, planBuilder_.pool_)); + joinConditionPtrs.push_back( + PlanBuilder::parseIndexJoinCondition( + joinCondition, inputType, planBuilder_.pool_)); + } + + // Parse filter expression if provided + core::TypedExprPtr filterExpr; + if (!filter_.empty()) { + filterExpr = parseExpr( + filter_, inputType, planBuilder_.options_, planBuilder_.pool_); } return std::make_shared( @@ -2534,7 +2726,8 @@ core::PlanNodePtr PlanBuilder::IndexLookupJoinBuilder::build( std::move(leftKeyFields), std::move(rightKeyFields), std::move(joinConditionPtrs), - includeMatchColumn_, + filterExpr, + hasMarker_, std::move(planBuilder_.planNode_), indexSource_, std::move(outputType)); diff --git a/velox/exec/tests/utils/PlanBuilder.h b/velox/exec/tests/utils/PlanBuilder.h index 96b65390a86..f8896cdcbca 100644 --- a/velox/exec/tests/utils/PlanBuilder.h +++ b/velox/exec/tests/utils/PlanBuilder.h @@ -35,9 +35,16 @@ enum class Table : uint8_t; namespace facebook::velox::exec::test { +struct AggregationConfig { + std::vector groupingKeys; + std::vector aggregates; +}; + struct PushdownConfig { common::SubfieldFilters subfieldFiltersMap; std::string remainingFilter; + // Aggregation pushdown configuration + std::optional aggregationConfig; }; /// A builder class with fluent API for building query plans. Plans are built @@ -289,6 +296,8 @@ class PlanBuilder { /// AND'ed with all the subfieldFilters. TableScanBuilder& remainingFilter(std::string remainingFilter); + TableScanBuilder& sampleRate(double sampleRate); + /// @param dataColumns can be different from 'outputType' for the purposes /// of testing queries using missing columns. It is used, if specified, for /// parseExpr call and as 'dataColumns' for the TableHandle. You supply more @@ -316,6 +325,12 @@ class PlanBuilder { return *this; } + TableScanBuilder& filterColumnHandles( + std::vector filterColumnHandles) { + filterColumnHandles_ = std::move(filterColumnHandles); + return *this; + } + /// @param assignments Optional ColumnHandles. /// outputType names should match the keys in the 'assignments' map. The /// 'assignments' map may contain more columns than 'outputType' if some @@ -325,6 +340,13 @@ class PlanBuilder { return *this; } + /// @param indexColumns Names of columns that form the index key for the + /// table. When set, enables index-based lookups. + TableScanBuilder& indexColumns(std::vector indexColumns) { + indexColumns_ = std::move(indexColumns); + return *this; + } + /// Stop the TableScanBuilder. PlanBuilder& endTableScan() { planBuilder_.planNode_ = build(planBuilder_.nextPlanNodeId()); @@ -340,7 +362,10 @@ class PlanBuilder { std::string connectorId_{kHiveDefaultConnectorId}; RowTypePtr outputType_; core::ExprPtr remainingFilter_; + double sampleRate_{1.0}; RowTypePtr dataColumns_; + std::vector indexColumns_; + std::vector filterColumnHandles_; std::unordered_map columnAliases_; connector::ConnectorTableHandlePtr tableHandle_; connector::ColumnHandleMap assignments_; @@ -395,8 +420,8 @@ class PlanBuilder { return *this; } - IndexLookupJoinBuilder& includeMatchColumn(bool includeMatchColumn) { - includeMatchColumn_ = includeMatchColumn; + IndexLookupJoinBuilder& hasMarker(bool hasMarker) { + hasMarker_ = hasMarker; return *this; } @@ -406,6 +431,13 @@ class PlanBuilder { return *this; } + /// @param filter SQL expression for the additional join filter. Can + /// use columns from both probe and build sides of the join. + IndexLookupJoinBuilder& filter(std::string filter) { + filter_ = std::move(filter); + return *this; + } + /// @param joinType Type of the join supported: inner, left. IndexLookupJoinBuilder& joinType(core::JoinType joinType) { joinType_ = joinType; @@ -427,7 +459,8 @@ class PlanBuilder { std::vector rightKeys_; core::TableScanNodePtr indexSource_; std::vector joinConditions_; - bool includeMatchColumn_{false}; + std::string filter_; + bool hasMarker_{false}; std::vector outputLayout_; core::JoinType joinType_{core::JoinType::kInner}; }; @@ -622,6 +655,11 @@ class PlanBuilder { bool parallelizable = false, size_t repeatTimes = 1); + /// Convenience overload that wraps a single RowVectorPtr in a vector. + PlanBuilder& values(const RowVectorPtr& value) { + return values(std::vector{value}); + } + PlanBuilder& filtersAsNode(bool filtersAsNode) { filtersAsNode_ = filtersAsNode; return *this; @@ -649,9 +687,7 @@ class PlanBuilder { /// /// @param outputType The type of the data coming in and out of the exchange. /// @param serdekind The kind of seralized data format. - PlanBuilder& exchange( - const RowTypePtr& outputType, - VectorSerde::Kind serdekind); + PlanBuilder& exchange(const RowTypePtr& outputType, std::string serdekind); /// Add a MergeExchangeNode using specified ORDER BY clauses. /// @@ -664,7 +700,7 @@ class PlanBuilder { PlanBuilder& mergeExchange( const RowTypePtr& outputType, const std::vector& keys, - VectorSerde::Kind serdekind); + std::string serdekind); /// Add a ProjectNode using specified SQL expressions. /// @@ -726,7 +762,10 @@ class PlanBuilder { /// Add a FilterNode using specified SQL expression. /// /// @param filter SQL expression of type boolean. - PlanBuilder& filter(const std::string& filter); + PlanBuilder& filter(const std::string& filterExpr); + + /// Same as above, but takes an untyped expression. + PlanBuilder& filter(const core::ExprPtr& filterExpr); /// Similar to filter() except 'optionalFilter' could be empty and the /// function will skip creating a FilterNode in that case. @@ -817,6 +856,12 @@ class PlanBuilder { /// output of the previous operator. /// @param ensureFiles When this option is set the HiveDataSink will always /// create a file even if there is no data. + /// @param commitStrategy The commit strategy to use for the table write + /// operation, default is kNoCommit. + /// @param insertTableHandle Encapsulates information needed to write data + /// to a table through a connector. If not specified, tableWrite will build + /// a HiveInsertTableHandle with columnHandles, bucketProperty and + /// locationHandle. PlanBuilder& tableWrite( const std::string& outputDirectoryPath, const std::vector& partitionBy, @@ -835,10 +880,21 @@ class PlanBuilder { const RowTypePtr& schema = nullptr, const bool ensureFiles = false, const connector::CommitStrategy commitStrategy = - connector::CommitStrategy::kNoCommit); - - /// Add a TableWriteMergeNode. - PlanBuilder& tableWriteMerge(); + connector::CommitStrategy::kNoCommit, + std::shared_ptr insertTableHandle = nullptr); + + /// Add a TableWriteMergeNode. Derives the ColumnStatsSpec from the + /// TableWriteNode in the plan tree and applies the given step. + /// Finds the TableWriteNode through LocalPartitionNode if present. + /// @param step Must be kIntermediate or kFinal. Defaults to kIntermediate. + PlanBuilder& tableWriteMerge( + core::AggregationNode::Step step = + core::AggregationNode::Step::kIntermediate); + + /// Add a TableWriteMergeNode with an explicit ColumnStatsSpec. Use for + /// coordinator-side merge where the TableWriteNode is in a different + /// fragment (e.g. after an Exchange). + PlanBuilder& tableWriteMerge(core::ColumnStatsSpec columnStatsSpec); /// Add an AggregationNode representing partial aggregation with the /// specified grouping keys, aggregates and optional masks. @@ -1014,7 +1070,8 @@ class PlanBuilder { const std::vector& aggregates, const std::vector& masks, core::AggregationNode::Step step, - bool ignoreNullKeys); + bool ignoreNullKeys, + bool noGroupsSpanBatches = false); /// Add a GroupIdNode using the specified grouping keys, grouping sets, /// aggregation inputs and a groupId column name. @@ -1134,14 +1191,14 @@ class PlanBuilder { int numPartitions, bool replicateNullsAndAny, const std::vector& outputLayout = {}, - VectorSerde::Kind serdeKind = VectorSerde::Kind::kPresto); + std::string serdeKind = "Presto"); /// Same as above, but assumes 'replicateNullsAndAny' is false. PlanBuilder& partitionedOutput( const std::vector& keys, int numPartitions, const std::vector& outputLayout = {}, - VectorSerde::Kind serdeKind = VectorSerde::Kind::kPresto); + std::string serdeKind = "Presto"); /// Same as above, but allows to provide custom partition function. PlanBuilder& partitionedOutput( @@ -1150,7 +1207,7 @@ class PlanBuilder { bool replicateNullsAndAny, core::PartitionFunctionSpecPtr partitionFunctionSpec, const std::vector& outputLayout = {}, - VectorSerde::Kind serdeKind = VectorSerde::Kind::kPresto); + std::string serdeKind = "Presto"); /// Adds a PartitionedOutputNode to broadcast the input data. /// @@ -1160,12 +1217,12 @@ class PlanBuilder { /// duplicated in the output. PlanBuilder& partitionedOutputBroadcast( const std::vector& outputLayout = {}, - VectorSerde::Kind serdeKind = VectorSerde::Kind::kPresto); + std::string serdeKind = "Presto"); /// Adds a PartitionedOutputNode to put data into arbitrary buffer. PlanBuilder& partitionedOutputArbitrary( const std::vector& outputLayout = {}, - VectorSerde::Kind serdeKind = VectorSerde::Kind::kPresto); + std::string serdeKind = "Presto"); /// Adds a LocalPartitionNode to hash-partition the input on the specified /// keys using exec::HashPartitionFunction. Number of partitions is determined @@ -1182,6 +1239,9 @@ class PlanBuilder { /// current plan node). PlanBuilder& localPartition(const std::vector& keys); + /// Add a LocalPartitionNode with gather type (N-to-1, empty partition keys). + PlanBuilder& localGather(); + /// A convenience method to add a LocalPartitionNode with hive partition /// function. PlanBuilder& localPartition( @@ -1244,7 +1304,8 @@ class PlanBuilder { const std::string& filter, const std::vector& outputLayout, core::JoinType joinType = core::JoinType::kInner, - bool nullAware = false); + bool nullAware = false, + bool nullAsValue = false); /// Add a MergeJoinNode to join two inputs using one or more join keys and an /// optional filter. The caller is responsible to ensure that inputs are @@ -1302,6 +1363,9 @@ class PlanBuilder { PlanBuilder& spatialJoin( const core::PlanNodePtr& right, const std::string& joinCondition, + const std::string& probeGeometry, + const std::string& buildGeometry, + const std::optional& radius, const std::vector& outputLayout, core::JoinType joinType = core::JoinType::kInner); @@ -1315,6 +1379,11 @@ class PlanBuilder { /// node. Second input is specified in 'right' parameter and must be a /// table source with the connector table handle with index lookup support. /// + /// @param leftKeys Join keys from the probe side, the preceding plan node. + /// Cannot be empty. + /// @param rightKeys Join keys from the index lookup side, the plan node + /// specified in 'right' parameter. The number and types of left and right + /// keys must be the same. /// @param right The right input source with index lookup support. /// @param joinConditions SQL expressions as the join conditions. Each join /// condition must use columns from both sides. For the right side, it can @@ -1327,18 +1396,23 @@ class PlanBuilder { /// where "a" is the index column from right side and "b", "c" are either /// condition column from left side or a constant but at least one of them /// must not be constant. They all have the same type. - /// @param joinType Type of the join supported: inner, left. - /// @param includeMatchColumn if true, 'outputLayout' should include a boolean + /// @param filter SQL expression for the additional join filter to apply on + /// join results. This supports filters that can't be converted into join + /// conditions or lookup conditions. Can be an empty string if no additional + /// filter is needed. + /// @param hasMarker if true, 'outputLayout' should include a boolean /// column at the end to indicate if a join output row has a match or not. /// This only applies for left join. - /// - /// See hashJoin method for the description of the other parameters. + /// @param outputLayout Output layout consisting of columns from probe and + /// build sides. + /// @param joinType Type of the join supported: inner, left. PlanBuilder& indexLookupJoin( const std::vector& leftKeys, const std::vector& rightKeys, const core::TableScanNodePtr& right, const std::vector& joinConditions, - bool includeMatchColumn, + const std::string& filter, + bool hasMarker, const std::vector& outputLayout, core::JoinType joinType = core::JoinType::kInner); @@ -1360,16 +1434,16 @@ class PlanBuilder { /// @param ordinalColumn An optional name for the 'ordinal' column to produce. /// This column contains the index of the element of the unnested array or /// map. If not specified, the output will not contain this column. - /// @param emptyUnnestValueName An optional name for the - /// 'emptyUnnestValue' column to produce. This column contains a boolean - /// indicating if the output row has empty unnest value or not. If not - /// specified, the output will not contain this column and the unnest operator - /// also skips producing output rows with empty unnest value. + /// @param markerName An optional name for the marker column to produce. + /// This column contains a boolean indicating whether the output row has + /// non-empty unnested value. If not specified, the output will not contain + /// this column and the unnest operator also skips producing output rows + /// with empty unnest value. PlanBuilder& unnest( const std::vector& replicateColumns, const std::vector& unnestColumns, const std::optional& ordinalColumn = std::nullopt, - const std::optional& emptyUnnestValueName = std::nullopt); + const std::optional& markerName = std::nullopt); /// Add a WindowNode to compute one or more windowFunctions. /// @param windowFunctions A list of one or more window function SQL like @@ -1429,6 +1503,41 @@ class PlanBuilder { std::string markerKey, const std::vector& distinctKeys); + /// Add an EnforceDistinctNode to ensure input has unique values for the + /// specified keys at runtime. Throws with the specified error message if + /// duplicates are found. + /// @param distinctKeys List of columns that must have unique values. + /// @param errorMessage Error message to include in the exception when + /// duplicates are found. + /// @param preGroupedKeys Optional subset of distinctKeys that input is + /// already clustered on. When equal to distinctKeys, uses streaming + /// enforcement. + PlanBuilder& enforceDistinct( + const std::vector& distinctKeys, + std::string errorMessage, + const std::vector& preGroupedKeys = {}); + + /// Add an EnforceDistinctNode to ensure pre-grouped input has unique + /// values for the specified keys. Equivalent to calling enforceDistinct with + /// preGroupedKeys equal to distinctKeys, which uses the streaming + /// implementation. + /// @param distinctKeys List of columns that must have unique values. Input + /// must be clustered on these keys. + /// @param errorMessage Error message to include in the exception when + /// duplicates are found. + PlanBuilder& streamingEnforceDistinct( + const std::vector& distinctKeys, + std::string errorMessage); + + /// Add a MarkSortedNode to mark rows indicating sortedness. + /// @param markerKey Name of output marker column (boolean). + /// @param sortingKeys List of columns used for sorting. + /// @param sortingOrders Sort orders for each sorting key. + PlanBuilder& markSorted( + const std::string& markerKey, + const std::vector& sortingKeys, + const std::vector& sortingOrders); + /// Stores the latest plan node ID into the specified variable. Useful for /// capturing IDs of the leaf plan nodes (table scans, exchanges, etc.) to use /// when adding splits at runtime. diff --git a/velox/exec/tests/utils/QueryAssertions.cpp b/velox/exec/tests/utils/QueryAssertions.cpp index 58adaf6f997..abb184b333b 100644 --- a/velox/exec/tests/utils/QueryAssertions.cpp +++ b/velox/exec/tests/utils/QueryAssertions.cpp @@ -21,6 +21,8 @@ #include "velox/duckdb/conversion/DuckConversion.h" #include "velox/exec/Cursor.h" #include "velox/exec/tests/utils/QueryAssertions.h" +#include "velox/type/Type.h" +#include "velox/vector/VariantToVector.h" #include "velox/vector/VectorTypeUtils.h" using facebook::velox::duckdb::duckdbTimestampToVelox; @@ -100,8 +102,9 @@ ::duckdb::Value duckValueAt( vector_size_t index) { auto type = vector->type(); if (type->isDate()) { - return ::duckdb::Value::DATE(::duckdb::Date::EpochDaysToDate( - vector->as>()->valueAt(index))); + return ::duckdb::Value::DATE( + ::duckdb::Date::EpochDaysToDate( + vector->as>()->valueAt(index))); } return ::duckdb::Value(vector->as>()->valueAt(index)); } @@ -129,6 +132,14 @@ ::duckdb::Value duckValueAt( return ::duckdb::Value::INTERVAL(0, days, microseconds); } + if (type->isTime()) { + VELOX_DCHECK(type->equivalent(*TIME())); + // TIME is stored as milliseconds since midnight in Velox. + // DuckDB TIME is stored as microseconds since midnight. + const auto timeMillis = vector->as>()->valueAt(index); + return ::duckdb::Value::TIME(::duckdb::dtime_t(timeMillis * 1000L)); + } + return ::duckdb::Value(vector->as>()->valueAt(index)); } @@ -216,9 +227,10 @@ ::duckdb::Value duckValueAt( const auto& mapValues = mapVector->mapValues(); auto offset = mapVector->offsetAt(mapRow); auto size = mapVector->sizeAt(mapRow); - auto mapType = ::duckdb::ListType::GetChildType(::duckdb::LogicalType::MAP( - duckdb::fromVeloxType(mapKeys->type()), - duckdb::fromVeloxType(mapValues->type()))); + auto mapType = ::duckdb::ListType::GetChildType( + ::duckdb::LogicalType::MAP( + duckdb::fromVeloxType(mapKeys->type()), + duckdb::fromVeloxType(mapValues->type()))); if (size == 0) { return ::duckdb::Value::MAP(mapType, ::duckdb::vector<::duckdb::Value>()); } @@ -274,7 +286,8 @@ variant variantAt( int32_t row, int32_t column) { return variant::binary( - StringView(::duckdb::StringValue::Get(dataChunk->GetValue(column, row)))); + std::string( + ::duckdb::StringValue::Get(dataChunk->GetValue(column, row)))); } template <> @@ -319,7 +332,7 @@ variant variantAt(const ::duckdb::Value& value) { template <> variant variantAt(const ::duckdb::Value& value) { - return variant::binary(StringView(::duckdb::StringValue::Get(value))); + return variant::binary(std::string(::duckdb::StringValue::Get(value))); } variant nullVariant(const TypePtr& type) { @@ -435,12 +448,21 @@ std::vector materialize( } else if (type->isDecimal()) { row.push_back(duckdb::decimalVariant(dataChunk->GetValue(j, i))); } else if (type->isIntervalDayTime()) { - auto value = variant(::duckdb::Interval::GetMicro( - dataChunk->GetValue(j, i).GetValue<::duckdb::interval_t>())); + auto value = variant( + ::duckdb::Interval::GetMicro( + dataChunk->GetValue(j, i).GetValue<::duckdb::interval_t>())); row.push_back(value); } else if (type->isDate()) { - auto value = variant(::duckdb::Date::EpochDays( - dataChunk->GetValue(j, i).GetValue<::duckdb::date_t>())); + auto value = variant( + ::duckdb::Date::EpochDays( + dataChunk->GetValue(j, i).GetValue<::duckdb::date_t>())); + row.push_back(value); + } else if (type->isTime()) { + VELOX_DCHECK(type->equivalent(*TIME())); + // DuckDB TIME is in microseconds, Velox TIME is in milliseconds. + auto value = variant( + dataChunk->GetValue(j, i).GetValue<::duckdb::dtime_t>().micros / + 1000L); row.push_back(value); } else { auto value = VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( @@ -708,7 +730,7 @@ std::string toTypeString(const MaterializedRow& row) { if (i > 0) { out << ", "; } - out << mapTypeKindToName(row[i].kind()); + out << TypeKindName::toName(row[i].kind()); } out << ")"; return out.str(); @@ -1340,12 +1362,31 @@ std::pair, std::vector> readCursor( const CursorParameters& params, std::function addSplits, uint64_t maxWaitMicros) { + return readCursorAsync( + params, + [addSplitsVoid = std::move(addSplits)](TaskCursor* cursor) { + addSplitsVoid(cursor); + return ContinueFuture::makeEmpty(); + }, + maxWaitMicros); +} + +std::pair, std::vector> +readCursorAsync( + const CursorParameters& params, + std::function addSplits, + uint64_t maxWaitMicros) { auto cursor = TaskCursor::create(params); // 'result' borrows memory from cursor so the life cycle must be shorter. std::vector result; auto* task = cursor->task().get(); + cursor->start(); + auto future = ContinueFuture::makeEmpty(); while (!cursor->noMoreSplits()) { - addSplits(cursor.get()); + if (future.valid()) { + future.wait(); + } + future = addSplits(cursor.get()); while (cursor->moveNext()) { auto vector = cursor->current(); vector->loadedVector(); @@ -1441,6 +1482,15 @@ void waitForAllTasksToBeDeleted(uint64_t maxWaitUs) { folly::join("\n", pendingTaskStats)); } +void cancelAllTasks() { + std::vector> pendingTasks = Task::getRunningTasks(); + for (const auto& task : pendingTasks) { + if (task->isRunning()) { + task->requestCancel(); + } + } +} + std::shared_ptr assertQuery( const core::PlanNodePtr& plan, std::function addSplits, diff --git a/velox/exec/tests/utils/QueryAssertions.h b/velox/exec/tests/utils/QueryAssertions.h index 3acfa88885d..276d58257c2 100644 --- a/velox/exec/tests/utils/QueryAssertions.h +++ b/velox/exec/tests/utils/QueryAssertions.h @@ -22,6 +22,9 @@ #include "velox/exec/Operator.h" #include "velox/vector/ComplexVector.h" +#ifdef BLOCK_SIZE +#undef BLOCK_SIZE +#endif #include // @manual namespace facebook::velox::exec::test { @@ -185,6 +188,19 @@ std::pair, std::vector> readCursor( }, uint64_t maxWaitMicros = 5'000'000); +std::pair, std::vector> +readCursorAsync( + const CursorParameters& params, + std::function addSplits = + [](TaskCursor* taskCursor) { + if (taskCursor->noMoreSplits()) { + return ContinueFuture::makeEmpty(); + } + taskCursor->setNoMoreSplits(); + return ContinueFuture::makeEmpty(); + }, + uint64_t maxWaitMicros = 5'000'000); + /// The Task can return results before the Driver is finished executing. /// Wait upto maxWaitMicros for the Task to finish as 'expectedState' before /// returning to ensure it's stable e.g. the Driver isn't updating it anymore. @@ -221,6 +237,11 @@ bool waitForTaskStateChange( /// during this wait call. This is for testing purpose for now. void waitForAllTasksToBeDeleted(uint64_t maxWaitUs = 3'000'000); +/// Cancels all currently running tasks across all available task managers. +/// This is primarily used in testing scenarios to clean up active tasks +/// and ensure test isolation between test cases. +void cancelAllTasks(); + std::shared_ptr assertQuery( const core::PlanNodePtr& plan, const std::string& duckDbSql, @@ -307,6 +328,13 @@ bool assertEqualResults( const core::PlanNodePtr& plan1, const core::PlanNodePtr& plan2); +bool assertEqualResults( + const MaterializedRowMultiset& expectedRows, + const TypePtr& expectedType, + const MaterializedRowMultiset& actualRows, + const TypePtr& actualType, + const std::string& message); + /// Ensure both datasets have the same type and number of rows. void assertEqualTypeAndNumRows( const TypePtr& expectedType, diff --git a/velox/exec/tests/utils/RowContainerTestBase.h b/velox/exec/tests/utils/RowContainerTestBase.h index c9d126fa869..9c04dcca7b9 100644 --- a/velox/exec/tests/utils/RowContainerTestBase.h +++ b/velox/exec/tests/utils/RowContainerTestBase.h @@ -16,9 +16,9 @@ #include #include "velox/common/file/FileSystems.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" #include "velox/exec/RowContainer.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/serializers/CompactRowSerializer.h" #include "velox/serializers/PrestoSerializer.h" #include "velox/serializers/UnsafeRowSerializer.h" @@ -26,6 +26,8 @@ namespace facebook::velox::exec::test { +using TempDirectoryPath = common::testutil::TempDirectoryPath; + class RowContainerTestBase : public testing::Test, public velox::test::VectorTestBase { protected: @@ -46,15 +48,15 @@ class RowContainerTestBase : public testing::Test, facebook::velox::serializer::presto::PrestoVectorSerde:: registerVectorSerde(); } - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kPresto)) { + if (!isRegisteredNamedVectorSerde("Presto")) { facebook::velox::serializer::presto::PrestoVectorSerde:: registerNamedVectorSerde(); } - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kCompactRow)) { + if (!isRegisteredNamedVectorSerde("CompactRow")) { facebook::velox::serializer::CompactRowVectorSerde:: registerNamedVectorSerde(); } - if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kUnsafeRow)) { + if (!isRegisteredNamedVectorSerde("UnsafeRow")) { facebook::velox::serializer::spark::UnsafeRowVectorSerde:: registerNamedVectorSerde(); } @@ -76,7 +78,8 @@ class RowContainerTestBase : public testing::Test, std::unique_ptr makeRowContainer( const std::vector& keyTypes, const std::vector& dependentTypes, - bool isJoinBuild = true) { + bool isJoinBuild = true, + bool useListRowIndex = false) { auto container = std::make_unique( keyTypes, !isJoinBuild, @@ -85,7 +88,9 @@ class RowContainerTestBase : public testing::Test, isJoinBuild, isJoinBuild, true, + false, // hasCountFlag true, + useListRowIndex, pool_.get()); VELOX_CHECK(container->testingMutable()); return container; diff --git a/velox/exec/tests/utils/SerializedPageUtil.cpp b/velox/exec/tests/utils/SerializedPageUtil.cpp index 4316f1a783f..86150055cea 100644 --- a/velox/exec/tests/utils/SerializedPageUtil.cpp +++ b/velox/exec/tests/utils/SerializedPageUtil.cpp @@ -20,9 +20,9 @@ using namespace facebook::velox; namespace facebook::velox::exec::test { -std::unique_ptr toSerializedPage( +std::unique_ptr toSerializedPage( const RowVectorPtr& vector, - VectorSerde::Kind serdeKind, + std::string serdeKind, const std::shared_ptr& bufferManager, memory::MemoryPool* pool) { auto data = @@ -34,7 +34,8 @@ std::unique_ptr toSerializedPage( auto listener = bufferManager->newListener(); IOBufOutputStream stream(*pool, listener.get(), data->size()); data->flush(&stream); - return std::make_unique(stream.getIOBuf(), nullptr, size); + return std::make_unique( + stream.getIOBuf(), nullptr, size); } } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/utils/SerializedPageUtil.h b/velox/exec/tests/utils/SerializedPageUtil.h index 6890693d634..c19387c9bb9 100644 --- a/velox/exec/tests/utils/SerializedPageUtil.h +++ b/velox/exec/tests/utils/SerializedPageUtil.h @@ -23,9 +23,9 @@ namespace facebook::velox::exec::test { /// Helper function for serializing RowVector to PrestoPage format. -std::unique_ptr toSerializedPage( +std::unique_ptr toSerializedPage( const RowVectorPtr& vector, - VectorSerde::Kind serdeKind, + std::string serdeKind, const std::shared_ptr& bufferManager, memory::MemoryPool* pool); diff --git a/velox/exec/tests/utils/TableScanTestBase.cpp b/velox/exec/tests/utils/TableScanTestBase.cpp index 394020fe1b5..c3b10d466fe 100644 --- a/velox/exec/tests/utils/TableScanTestBase.cpp +++ b/velox/exec/tests/utils/TableScanTestBase.cpp @@ -19,6 +19,8 @@ #include "velox/exec/tests/utils/LocalExchangeSource.h" #include "velox/exec/tests/utils/PlanBuilder.h" +using namespace facebook::velox::common::testutil; + namespace facebook::velox::exec::test { void TableScanTestBase::verifyCacheStats( @@ -183,6 +185,9 @@ void TableScanTestBase::testPartitionedTableImpl( .assignments(assignments) .endTableScan() .planNode(); + split = exec::test::HiveConnectorSplitBuilder(filePath) + .partitionKey("pkey", partitionValue) + .build(); assertQuery( op, split, fmt::format("SELECT c0, {}, c1 FROM tmp", partitionValueStr)); outputType = ROW({"c0", "c1", "pkey"}, {BIGINT(), DOUBLE(), partitionType}); @@ -192,6 +197,9 @@ void TableScanTestBase::testPartitionedTableImpl( .assignments(assignments) .endTableScan() .planNode(); + split = exec::test::HiveConnectorSplitBuilder(filePath) + .partitionKey("pkey", partitionValue) + .build(); assertQuery( op, split, fmt::format("SELECT c0, c1, {} FROM tmp", partitionValueStr)); @@ -204,6 +212,9 @@ void TableScanTestBase::testPartitionedTableImpl( .assignments(assignments) .endTableScan() .planNode(); + split = exec::test::HiveConnectorSplitBuilder(filePath) + .partitionKey("pkey", partitionValue) + .build(); assertQuery(op, split, fmt::format("SELECT {} FROM tmp", partitionValueStr)); } diff --git a/velox/exec/tests/utils/TableScanTestBase.h b/velox/exec/tests/utils/TableScanTestBase.h index f7e2929f284..16c6178ce0f 100644 --- a/velox/exec/tests/utils/TableScanTestBase.h +++ b/velox/exec/tests/utils/TableScanTestBase.h @@ -17,8 +17,8 @@ #include -#include #include +#include #include #include "velox/connectors/hive/FileHandle.h" #include "velox/exec/PlanNodeStats.h" @@ -27,6 +27,8 @@ namespace facebook::velox::exec::test { +using TempFilePath = common::testutil::TempFilePath; + class TableScanTestBase : public HiveConnectorTestBase { protected: void SetUp() override; diff --git a/velox/exec/tests/utils/TableWriterTestBase.cpp b/velox/exec/tests/utils/TableWriterTestBase.cpp index 34304d94ab2..bf349d8b2ca 100644 --- a/velox/exec/tests/utils/TableWriterTestBase.cpp +++ b/velox/exec/tests/utils/TableWriterTestBase.cpp @@ -17,6 +17,7 @@ #include "velox/exec/tests/utils/TableWriterTestBase.h" namespace velox::exec::test { +using namespace facebook::velox::common::testutil; TableWriterTestBase::TestParam::TestParam( FileFormat fileFormat, @@ -399,8 +400,9 @@ TableWriterTestBase::makeHiveConnectorSplits(const std::string& directoryPath) { std::vector> splits; for (auto& path : fs::recursive_directory_iterator(directoryPath)) { if (path.is_regular_file()) { - splits.push_back(HiveConnectorTestBase::makeHiveConnectorSplits( - path.path().string(), 1, fileFormat_)[0]); + splits.push_back( + HiveConnectorTestBase::makeHiveConnectorSplits( + path.path().string(), 1, fileFormat_)[0]); } } return splits; @@ -426,8 +428,9 @@ TableWriterTestBase::makeHiveConnectorSplits( const std::vector& filePaths) { std::vector> splits; for (const auto& filePath : filePaths) { - splits.push_back(HiveConnectorTestBase::makeHiveConnectorSplits( - filePath.string(), 1, fileFormat_)[0]); + splits.push_back( + HiveConnectorTestBase::makeHiveConnectorSplits( + filePath.string(), 1, fileFormat_)[0]); } return splits; } @@ -781,9 +784,10 @@ std::string TableWriterTestBase::partitionNameToPredicate( for (auto i = 0; i < partitionKeyValues.size(); ++i) { if (partitionTypes[i]->isVarchar() || partitionTypes[i]->isVarbinary() || partitionTypes[i]->isDate()) { - conjuncts.push_back(partitionKeyValues[i] - .replace(partitionKeyValues[i].find("="), 1, "='") - .append("'")); + conjuncts.push_back( + partitionKeyValues[i] + .replace(partitionKeyValues[i].find("="), 1, "='") + .append("'")); } else { conjuncts.push_back(partitionKeyValues[i]); } @@ -799,9 +803,10 @@ std::string TableWriterTestBase::partitionNameToPredicate( for (auto i = 0; i < partitionDirNames.size(); ++i) { if (partitionTypes_[i]->isVarchar() || partitionTypes_[i]->isVarbinary() || partitionTypes_[i]->isDate()) { - conjuncts.push_back(partitionKeyValues[i] - .replace(partitionKeyValues[i].find("="), 1, "='") - .append("'")); + conjuncts.push_back( + partitionKeyValues[i] + .replace(partitionKeyValues[i].find("="), 1, "='") + .append("'")); } else { conjuncts.push_back(partitionDirNames[i]); } @@ -814,20 +819,22 @@ void TableWriterTestBase::verifyUnbucketedFilePath( const std::string& targetDir) { ASSERT_EQ(filePath.parent_path().string(), targetDir); if (commitStrategy_ == CommitStrategy::kNoCommit) { - ASSERT_TRUE(RE2::FullMatch( - filePath.filename().string(), - fmt::format( - "test_cursor.+_[0-{}]_{}_.+", - numTableWriterCount_ - 1, - tableWriteNodeId_))) + ASSERT_TRUE( + RE2::FullMatch( + filePath.filename().string(), + fmt::format( + "test_cursor.+_[0-{}]_{}_.+", + numTableWriterCount_ - 1, + tableWriteNodeId_))) << filePath.filename().string(); } else { - ASSERT_TRUE(RE2::FullMatch( - filePath.filename().string(), - fmt::format( - ".tmp.velox.test_cursor.+_[0-{}]_{}_.+", - numTableWriterCount_ - 1, - tableWriteNodeId_))) + ASSERT_TRUE( + RE2::FullMatch( + filePath.filename().string(), + fmt::format( + ".tmp.velox.test_cursor.+_[0-{}]_{}_.+", + numTableWriterCount_ - 1, + tableWriteNodeId_))) << filePath.filename().string(); } } @@ -843,25 +850,29 @@ void TableWriterTestBase::verifyBucketedFileName( const std::filesystem::path& filePath) { if (commitStrategy_ == CommitStrategy::kNoCommit) { if (fileFormat_ == FileFormat::PARQUET) { - ASSERT_TRUE(RE2::FullMatch( - filePath.filename().string(), - "0[0-9]+_0_TaskCursorQuery_[0-9]+\\.parquet$")) + ASSERT_TRUE( + RE2::FullMatch( + filePath.filename().string(), + "0[0-9]+_0_TaskCursorQuery_[0-9]+\\.parquet$")) << filePath.filename().string(); } else { - ASSERT_TRUE(RE2::FullMatch( - filePath.filename().string(), "0[0-9]+_0_TaskCursorQuery_[0-9]+")) + ASSERT_TRUE( + RE2::FullMatch( + filePath.filename().string(), "0[0-9]+_0_TaskCursorQuery_[0-9]+")) << filePath.filename().string(); } } else { if (fileFormat_ == FileFormat::PARQUET) { - ASSERT_TRUE(RE2::FullMatch( - filePath.filename().string(), - ".tmp.velox.0[0-9]+_0_TaskCursorQuery_[0-9]+_.+\\.parquet$")) + ASSERT_TRUE( + RE2::FullMatch( + filePath.filename().string(), + ".tmp.velox.0[0-9]+_0_TaskCursorQuery_[0-9]+_.+\\.parquet$")) << filePath.filename().string(); } else { - ASSERT_TRUE(RE2::FullMatch( - filePath.filename().string(), - ".tmp.velox.0[0-9]+_0_TaskCursorQuery_[0-9]+_.+")) + ASSERT_TRUE( + RE2::FullMatch( + filePath.filename().string(), + ".tmp.velox.0[0-9]+_0_TaskCursorQuery_[0-9]+_.+")) << filePath.filename().string(); } } diff --git a/velox/exec/tests/utils/TableWriterTestBase.h b/velox/exec/tests/utils/TableWriterTestBase.h index f0c71f92bee..83c21b3075c 100644 --- a/velox/exec/tests/utils/TableWriterTestBase.h +++ b/velox/exec/tests/utils/TableWriterTestBase.h @@ -17,6 +17,7 @@ #include "velox/common/base/Fs.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/hyperloglog/SparseHll.h" +#include "velox/common/testutil/TempDirectoryPath.h" #include "velox/common/testutil/TestValue.h" #include "velox/connectors/hive/HiveConfig.h" #include "velox/connectors/hive/HivePartitionFunction.h" @@ -26,12 +27,11 @@ #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/vector/fuzzer/VectorFuzzer.h" #include #include -#include "folly/experimental/EventCount.h" +#include "folly/synchronization/EventCount.h" #include "velox/common/memory/MemoryArbitrator.h" #include "velox/dwio/common/Options.h" #include "velox/dwio/dwrf/writer/Writer.h" diff --git a/velox/exec/tests/utils/TempDirectoryPath.h b/velox/exec/tests/utils/TempDirectoryPath.h index e719d0949d8..0e236e1bcc5 100644 --- a/velox/exec/tests/utils/TempDirectoryPath.h +++ b/velox/exec/tests/utils/TempDirectoryPath.h @@ -15,53 +15,9 @@ */ #pragma once -#include -#include -#include -#include - -#include "velox/common/base/Exceptions.h" +// DEPRECATED: Use velox/common/testutil/TempDirectoryPath.h instead. +#include "velox/common/testutil/TempDirectoryPath.h" namespace facebook::velox::exec::test { - -/// Manages the lifetime of a temporary directory. -class TempDirectoryPath { - public: - /// If 'enableFaultInjection' is true, we enable fault injection on the - /// created file directory. - static std::shared_ptr create( - bool enableFaultInjection = false); - - virtual ~TempDirectoryPath(); - - TempDirectoryPath(const TempDirectoryPath&) = delete; - TempDirectoryPath& operator=(const TempDirectoryPath&) = delete; - - /// If fault injection is enabled, the returned file path will have the faulty - /// file system prefix scheme. The velox fs then opens the directory through - /// the faulty file system. The file operation will then either fail or be - /// delegated to the actual file. - const std::string& getPath() const { - return path_; - } - - /// The actual file path if fault injection is enabled. - const std::string& getDelegatePath() const { - return tempPath_; - } - - private: - static std::string createTempDirectory(); - - explicit TempDirectoryPath(bool enableFaultInjection) - : enableFaultInjection_(enableFaultInjection), - tempPath_(createTempDirectory()), - path_( - enableFaultInjection_ ? fmt::format("faulty:{}", tempPath_) - : tempPath_) {} - - const bool enableFaultInjection_{false}; - const std::string tempPath_; - const std::string path_; -}; +using facebook::velox::common::testutil::TempDirectoryPath; } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/utils/TempFilePath.h b/velox/exec/tests/utils/TempFilePath.h index 4a4f95c042e..0a34011ddb0 100644 --- a/velox/exec/tests/utils/TempFilePath.h +++ b/velox/exec/tests/utils/TempFilePath.h @@ -15,82 +15,10 @@ */ #pragma once -#include -#include -#include -#include -#include -#include - -#include "velox/common/base/Exceptions.h" +// DEPRECATED: Use velox/common/testutil/TempFilePath.h instead. +#include "velox/common/testutil/TempFilePath.h" namespace facebook::velox::exec::test { - -/// Manages the lifetime of a temporary file. -class TempFilePath { - public: - /// If 'enableFaultInjection' is true, we enable fault injection on the - /// created file. - static std::shared_ptr create( - bool enableFaultInjection = false); - - ~TempFilePath(); - - TempFilePath(const TempFilePath&) = delete; - TempFilePath& operator=(const TempFilePath&) = delete; - - void append(std::string data) { - std::ofstream file(tempPath_, std::ios_base::app); - file << data; - file.flush(); - file.close(); - } - - const int64_t fileSize() { - struct stat st; - ::stat(tempPath_.data(), &st); - return st.st_size; - } - - int64_t fileModifiedTime() { - struct stat st; - ::stat(tempPath_.data(), &st); - return st.st_mtime; - } - - /// If fault injection is enabled, the returned the file path has the faulty - /// file system prefix scheme. The velox fs then opens the file through the - /// faulty file system. The actual file operation might either fails or - /// delegate to the actual file. - const std::string& getPath() const { - return path_; - } - - // Returns the delegated file path if fault injection is enabled. - const std::string& tempFilePath() const { - return tempPath_; - } - - private: - static std::string createTempFile(TempFilePath* tempFilePath); - - TempFilePath(bool enableFaultInjection) - : enableFaultInjection_(enableFaultInjection), - tempPath_(createTempFile(this)), - path_( - enableFaultInjection_ ? fmt::format("faulty:{}", tempPath_) - : tempPath_) { - VELOX_CHECK_NE(fd_, -1); - } - - const bool enableFaultInjection_; - const std::string tempPath_; - const std::string path_; - - int fd_; -}; - -std::vector toFilePaths( - const std::vector>& tempFiles); - +using facebook::velox::common::testutil::TempFilePath; +using facebook::velox::common::testutil::toFilePaths; } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/utils/TestIndexStorageConnector.cpp b/velox/exec/tests/utils/TestIndexStorageConnector.cpp index b2dc3a81bdd..232c6e4903f 100644 --- a/velox/exec/tests/utils/TestIndexStorageConnector.cpp +++ b/velox/exec/tests/utils/TestIndexStorageConnector.cpp @@ -72,6 +72,7 @@ std::shared_ptr TestIndexTable::create( /*dependentTypes=*/dependentTypes, /*allowDuplicates=*/true, /*hasProbedFlag=*/false, + /*hasCountFlag=*/false, /*minTableSizeForParallelJoinBuild=*/1, &pool); @@ -96,7 +97,8 @@ std::shared_ptr TestIndexTable::create( } // Build the table index. - table->prepareJoinTable({}, BaseHashTable::kNoSpillInputStartPartitionBit); + table->prepareJoinTable( + {}, BaseHashTable::kNoSpillInputStartPartitionBit, 1'000'000); return std::make_shared( std::move(keyType), std::move(valueType), std::move(table)); } @@ -115,40 +117,82 @@ core::TypedExprPtr toJoinConditionExpr( std::vector conditionExprs; conditionExprs.reserve(joinConditions.size()); for (const auto& condition : joinConditions) { + // Check for EqualIndexLookupCondition first to skip equi-join conditions + // before creating indexColumnExpr, since equi-join key names may not exist + // in keyType (they are handled by hash table lookup, not filter + // evaluation). + if (auto equalCondition = + std::dynamic_pointer_cast( + condition)) { + // Skip equi-join conditions (non-constant values) as they are handled + // by the hash table lookup, not by filter evaluation. + if (!equalCondition->isFilter()) { + continue; + } + // Filter conditions (constant values) need to be evaluated. + auto indexColumnExpr = std::make_shared( + keyType->findChild(condition->key->name()), condition->key->name()); + conditionExprs.push_back( + std::make_shared( + BOOLEAN(), + "eq", + std::move(indexColumnExpr), + equalCondition->value)); + continue; + } auto indexColumnExpr = std::make_shared( keyType->findChild(condition->key->name()), condition->key->name()); if (auto inCondition = std::dynamic_pointer_cast( condition)) { - conditionExprs.push_back(std::make_shared( - BOOLEAN(), - "contains", - inCondition->list, - std::move(indexColumnExpr))); + conditionExprs.push_back( + std::make_shared( + BOOLEAN(), + "contains", + inCondition->list, + std::move(indexColumnExpr))); continue; } if (auto betweenCondition = std::dynamic_pointer_cast( condition)) { - conditionExprs.push_back(std::make_shared( - BOOLEAN(), - "between", - std::move(indexColumnExpr), - betweenCondition->lower, - betweenCondition->upper)); + conditionExprs.push_back( + std::make_shared( + BOOLEAN(), + "between", + std::move(indexColumnExpr), + betweenCondition->lower, + betweenCondition->upper)); continue; } + VELOX_FAIL("Invalid index join condition: {}", condition->toString()); + } + if (conditionExprs.empty()) { + return nullptr; + } + if (conditionExprs.size() == 1) { + return conditionExprs[0]; + } + return std::make_shared( + BOOLEAN(), conditionExprs, "and"); +} + +// Counts the number of equi-join keys in the given join conditions. +// An equi-join key is an EqualIndexLookupCondition where isFilter() +// returns false (i.e., the value references a probe column, not a constant). +size_t countEqualJoinKeys( + const std::vector& joinConditions) { + size_t count = 0; + for (const auto& condition : joinConditions) { if (auto equalCondition = std::dynamic_pointer_cast( condition)) { - conditionExprs.push_back(std::make_shared( - BOOLEAN(), "eq", std::move(indexColumnExpr), equalCondition->value)); - continue; + if (!equalCondition->isFilter()) { + ++count; + } } - VELOX_FAIL("Invalid index join condition: {}", condition->toString()); } - return std::make_shared( - BOOLEAN(), conditionExprs, "and"); + return count; } // Copy values from 'rows' of 'table' according to 'projections' in @@ -222,11 +266,16 @@ void TestIndexSource::checkNotFailed() { } } -std::shared_ptr -TestIndexSource::lookup(const LookupRequest& request) { +std::shared_ptr TestIndexSource::lookup( + const Request& request) { checkNotFailed(); + VELOX_CHECK(!tableHandle_->needsIndexSplit() || !splits_.empty()); const auto numInputRows = request.input->size(); - auto& hashTable = tableHandle_->indexTable()->table; + auto* indexTable = tableHandle_->indexTable().get(); + auto& hashTable = indexTable->table; + // Serialize access to the shared hash table. The hashers contain stateful + // DecodedVectors that are not thread-safe for concurrent access. + std::lock_guard l(indexTable->mutex); auto lookup = std::make_unique(hashTable->hashers(), pool_.get()); SelectivityVector activeRows(numInputRows); VELOX_CHECK(activeRows.isAllSelected()); @@ -298,13 +347,13 @@ void TestIndexSource::recordCpuTiming(const CpuWallTiming& timing) { std::lock_guard l(mutex_); if (timing.wallNanos != 0) { addOperatorRuntimeStats( - IndexLookupJoin::kConnectorLookupWallTime, + std::string(IndexLookupJoin::kConnectorLookupWallTime), RuntimeCounter(timing.wallNanos, RuntimeCounter::Unit::kNanos), runtimeStats_); // This is just for testing purpose to check if the runtime stats has been // set properly. addOperatorRuntimeStats( - IndexLookupJoin::kClientLookupWaitWallTime, + std::string(IndexLookupJoin::kClientLookupWaitWallTime), RuntimeCounter(timing.wallNanos, RuntimeCounter::Unit::kNanos), runtimeStats_); } @@ -312,7 +361,7 @@ void TestIndexSource::recordCpuTiming(const CpuWallTiming& timing) { // This is just for testing purpose to check if the runtime stats has been // set properly. addOperatorRuntimeStats( - IndexLookupJoin::kConnectorResultPrepareTime, + std::string(IndexLookupJoin::kConnectorResultPrepareTime), RuntimeCounter(timing.cpuNanos, RuntimeCounter::Unit::kNanos), runtimeStats_); } @@ -325,7 +374,7 @@ std::unordered_map TestIndexSource::runtimeStats() { TestIndexSource::ResultIterator::ResultIterator( std::shared_ptr source, - const LookupRequest& request, + const Request& request, std::unique_ptr lookupResult, folly::Executor* executor) : source_(std::move(source)), @@ -338,7 +387,17 @@ TestIndexSource::ResultIterator::ResultIterator( lookupResultIter_->reset(*lookupResult_); } -std::optional> +bool TestIndexSource::ResultIterator::hasNext() { + // If we have an async result ready, we have more to return. + if (asyncResult_.has_value()) { + return true; + } + + // If the iterator is not at end, there are more results to fetch. + return !lookupResultIter_->atEnd(); +} + +std::optional> TestIndexSource::ResultIterator::next( vector_size_t size, ContinueFuture& future) { @@ -417,7 +476,7 @@ void TestIndexSource::ResultIterator::asyncLookup( }); } -std::unique_ptr +std::unique_ptr TestIndexSource::ResultIterator::syncLookup(vector_size_t size) { VELOX_CHECK(hasPendingRequest_); if (lookupResultIter_->atEnd()) { @@ -476,7 +535,8 @@ TestIndexSource::ResultIterator::syncLookup(vector_size_t size) { lookupOutput_); VELOX_CHECK_EQ(lookupOutput_->size(), numHits); VELOX_CHECK_EQ(inputHitIndices_->size() / sizeof(vector_size_t), numHits); - return std::make_unique(inputHitIndices_, lookupOutput_); + return std::make_unique( + inputHitIndices_, lookupOutput_); } catch (const std::exception& e) { VELOX_CHECK(source_->error_.empty()); source_->error_ = e.what(); @@ -546,13 +606,14 @@ TestIndexConnector::TestIndexConnector( std::shared_ptr TestIndexConnector::createIndexSource( const RowTypePtr& inputType, - size_t numJoinKeys, const std::vector& joinConditions, const RowTypePtr& outputType, const connector::ConnectorTableHandlePtr& tableHandle, const connector::ColumnHandleMap& columnHandles, connector::ConnectorQueryCtx* connectorQueryCtx) { - VELOX_CHECK_GE(inputType->size(), numJoinKeys + joinConditions.size()); + const size_t numEqualJoinKeys = countEqualJoinKeys(joinConditions); + VELOX_CHECK_GE(inputType->size(), numEqualJoinKeys); + auto testIndexTableHandle = std::dynamic_pointer_cast(tableHandle); VELOX_CHECK_NOT_NULL(testIndexTableHandle); @@ -572,7 +633,7 @@ std::shared_ptr TestIndexConnector::createIndexSource( return std::make_shared( inputType, outputType, - numJoinKeys, + numEqualJoinKeys, joinConditionExpr, testIndexTableHandle, testColumnHandles, diff --git a/velox/exec/tests/utils/TestIndexStorageConnector.h b/velox/exec/tests/utils/TestIndexStorageConnector.h index d9fbbe315e8..a4f3029031f 100644 --- a/velox/exec/tests/utils/TestIndexStorageConnector.h +++ b/velox/exec/tests/utils/TestIndexStorageConnector.h @@ -15,6 +15,7 @@ */ #pragma once +#include "velox/connectors/ConnectorRegistry.h" #include "velox/exec/HashTable.h" #include "velox/type/Type.h" @@ -30,6 +31,11 @@ struct TestIndexTable { RowTypePtr dataType; std::shared_ptr table; + // Mutex to serialize concurrent access to the hash table. The hash table's + // hashers contain stateful DecodedVectors that are not thread-safe for + // concurrent joinProbe calls from multiple index sources. + mutable std::mutex mutex; + TestIndexTable( RowTypePtr _keyType, RowTypePtr _dataType, @@ -52,18 +58,21 @@ class TestIndexTableHandle : public connector::ConnectorTableHandle { explicit TestIndexTableHandle( std::string connectorId, std::shared_ptr indexTable, - bool asyncLookup) + bool asyncLookup, + bool needsIndexSplit = false) : ConnectorTableHandle(std::move(connectorId)), indexTable_(std::move(indexTable)), - asyncLookup_(asyncLookup) {} + asyncLookup_(asyncLookup), + needsIndexSplit_(needsIndexSplit) {} ~TestIndexTableHandle() override = default; std::string toString() const override { return fmt::format( - "IndexTableHandle: num of rows: {}, asyncLookup: {}", + "IndexTableHandle: num of rows: {}, asyncLookup: {}, needsIndexSplit: {}", indexTable_ ? indexTable_->table->rows()->numRows() : 0, - asyncLookup_); + asyncLookup_, + needsIndexSplit_); } const std::string& name() const override { @@ -80,6 +89,7 @@ class TestIndexTableHandle : public connector::ConnectorTableHandle { obj["name"] = name(); obj["connectorId"] = connectorId(); obj["asyncLookup"] = asyncLookup_; + obj["needsIndexSplit"] = needsIndexSplit_; // For testing purpose only, we serialize the index table pointer as an // long integer. obj["indexTable"] = reinterpret_cast(indexTable_.get()); @@ -96,7 +106,8 @@ class TestIndexTableHandle : public connector::ConnectorTableHandle { return std::make_shared( obj["connectorId"].getString(), std::shared_ptr(indexTablePtr, [](TestIndexTable*) {}), - obj["asyncLookup"].asBool()); + obj["asyncLookup"].asBool(), + obj["needsIndexSplit"].asBool()); } static void registerSerDe() { @@ -113,13 +124,53 @@ class TestIndexTableHandle : public connector::ConnectorTableHandle { return asyncLookup_; } + /// If true, the index source requires a split to perform lookup. This is for + /// testing the split collection logic in the IndexLookupJoin operator. + bool needsIndexSplit() const override { + return needsIndexSplit_; + } + private: const std::shared_ptr indexTable_; const bool asyncLookup_; + const bool needsIndexSplit_; }; using TestIndexTableHandlePtr = std::shared_ptr; +/// A fake split class for testing the split collection logic in the +/// IndexLookupJoin operator. The test index source doesn't actually use splits, +/// but this is used to verify the split passing mechanism works correctly. +class TestIndexConnectorSplit : public connector::ConnectorSplit { + public: + explicit TestIndexConnectorSplit(std::string connectorId) + : ConnectorSplit(std::move(connectorId)) {} + + std::string toString() const override { + return "TestIndexConnectorSplit"; + } + + folly::dynamic serialize() const override { + folly::dynamic obj = folly::dynamic::object; + obj["name"] = "TestIndexConnectorSplit"; + obj["connectorId"] = connectorId; + return obj; + } + + static std::shared_ptr create( + const folly::dynamic& obj) { + return std::make_shared( + obj["connectorId"].getString()); + } + + static void registerSerDe() { + auto& registry = DeserializationWithContextRegistryForSharedPtr(); + registry.Register( + "TestIndexConnectorSplit", + [](const folly::dynamic& obj, void*) { return create(obj); }); + } +}; + class TestIndexColumnHandle : public connector::ColumnHandle { public: explicit TestIndexColumnHandle(std::string name) : name_{std::move(name)} {} @@ -165,8 +216,15 @@ class TestIndexSource : public connector::IndexSource, connector::ConnectorQueryCtx* connectorQueryCtx, folly::Executor* executor); - std::shared_ptr lookup( - const LookupRequest& request) override; + void addSplits( + std::vector> splits) override { + VELOX_CHECK(tableHandle_->needsIndexSplit()); + VELOX_CHECK(!splits.empty()); + VELOX_CHECK(splits_.empty()); + splits_ = std::move(splits); + } + + std::shared_ptr lookup(const Request& request) override; std::unordered_map runtimeStats() override; @@ -186,15 +244,17 @@ class TestIndexSource : public connector::IndexSource, return lookupOutputProjections_; } - class ResultIterator : public LookupResultIterator { + class ResultIterator : public connector::IndexSource::ResultIterator { public: ResultIterator( std::shared_ptr source, - const LookupRequest& request, + const Request& request, std::unique_ptr lookupResult, folly::Executor* executor); - std::optional> next( + bool hasNext() override; + + std::optional> next( vector_size_t size, ContinueFuture& future) override; @@ -236,16 +296,17 @@ class TestIndexSource : public connector::IndexSource, // Synchronously lookup the index table and return up to 'size' number of // output rows in result. - std::unique_ptr syncLookup(vector_size_t size); + std::unique_ptr syncLookup( + vector_size_t size); const std::shared_ptr source_; - const LookupRequest request_; + const Request request_; const std::unique_ptr lookupResult_; folly::Executor* const executor_{nullptr}; std::atomic_bool hasPendingRequest_{false}; std::unique_ptr lookupResultIter_; - std::optional> asyncResult_; + std::optional> asyncResult_; // The reusable buffers for lookup result processing. // The input row number in lookup request for each matched result which is @@ -294,7 +355,7 @@ class TestIndexSource : public connector::IndexSource, const std::shared_ptr pool_; folly::Executor* const executor_; - mutable std::mutex mutex_; + std::mutex mutex_; // Join condition filter input type. RowTypePtr conditionInputType_; @@ -311,6 +372,11 @@ class TestIndexSource : public connector::IndexSource, std::vector conditionTableProjections_; std::vector lookupOutputProjections_; std::unordered_map runtimeStats_; + + // Collected splits for the index source (if tableHandle_->needsIndexSplit() + // returns true). This is only used for testing the split interface but not + // actually used in the lookup. + std::vector> splits_; }; class TestIndexConnector : public connector::Connector { @@ -334,7 +400,6 @@ class TestIndexConnector : public connector::Connector { std::shared_ptr createIndexSource( const RowTypePtr& inputType, - size_t numJoinKeys, const std::vector& joinConditions, const RowTypePtr& outputType, const connector::ConnectorTableHandlePtr& tableHandle, @@ -370,7 +435,8 @@ class TestIndexConnectorFactory : public connector::ConnectorFactory { TestIndexConnectorFactory factory; std::shared_ptr connector = factory.newConnector(kTestIndexConnectorName, {}, nullptr, cpuExecutor); - connector::registerConnector(connector); + connector::ConnectorRegistry::global().insert( + connector->connectorId(), connector); } }; } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/utils/TpcdsQueryBuilder.cpp b/velox/exec/tests/utils/TpcdsQueryBuilder.cpp new file mode 100644 index 00000000000..8c474697210 --- /dev/null +++ b/velox/exec/tests/utils/TpcdsQueryBuilder.cpp @@ -0,0 +1,170 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/exec/tests/utils/TpcdsQueryBuilder.h" + +#include "velox/common/base/Exceptions.h" +#include "velox/connectors/ConnectorRegistry.h" +#include "velox/connectors/hive/HiveConnector.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" + +#include +namespace fs = std::filesystem; + +namespace { + +/// Try to find data files for a table name. First tries exact match, +/// then strips any schema prefix (e.g. "tpcds.store_sales" -> +/// "store_sales"). +const std::vector* findDataFiles( + const std::unordered_map>& + tableDataFiles, + const std::string& tableName) { + auto it = tableDataFiles.find(tableName); + if (it != tableDataFiles.end() && !it->second.empty()) { + return &it->second; + } + + auto dotPos = tableName.rfind('.'); + if (dotPos != std::string::npos) { + it = tableDataFiles.find(tableName.substr(dotPos + 1)); + if (it != tableDataFiles.end() && !it->second.empty()) { + return &it->second; + } + } + + return nullptr; +} + +} // namespace + +namespace facebook::velox::exec::test { + +TpcdsQueryBuilder::TpcdsQueryBuilder(dwio::common::FileFormat format) + : format_(format) {} + +void TpcdsQueryBuilder::initialize(const std::string& dataPath) { + tableDataFiles_.clear(); + + std::error_code ec; + for (auto const& tableEntry : + fs::directory_iterator{dataPath, fs::directory_options(), ec}) { + if (!tableEntry.is_directory()) { + continue; + } + const auto tableName = tableEntry.path().filename().string(); + + // Skip hidden directories. + if (tableName.empty() || tableName[0] == '.') { + continue; + } + + auto& files = tableDataFiles_[tableName]; + + std::error_code fileEc; + for (auto const& fileEntry : fs::directory_iterator{ + tableEntry.path(), fs::directory_options(), fileEc}) { + if (!fileEntry.is_regular_file()) { + continue; + } + // Skip hidden files. + if (fileEntry.path().filename().c_str()[0] == '.') { + continue; + } + files.push_back(fileEntry.path().string()); + } + + // Sort for deterministic ordering. + std::sort(files.begin(), files.end()); + } + + VELOX_USER_CHECK( + !tableDataFiles_.empty(), + "No table subdirectories found in data path: {}", + dataPath); +} + +void TpcdsQueryBuilder::registerHiveConnector( + const std::string& connectorId, + folly::Executor* /*ioExecutor*/) { + auto hiveConnector = connector::hive::HiveConnectorFactory().newConnector( + connectorId, + std::make_shared( + std::unordered_map())); + connector::ConnectorRegistry::global().insert( + hiveConnector->connectorId(), hiveConnector); +} + +VeloxPlan TpcdsQueryBuilder::getQueryPlan( + int queryId, + const std::string& planDir, + memory::MemoryPool* pool) { + VeloxPlanLoader loader(planDir, pool); + auto veloxPlan = loader.loadPlanByQueryId(queryId); + + // Collect all TableScan nodes from the plan tree. + auto scanNodes = VeloxPlanLoader::collectTableScanNodes(veloxPlan.plan); + + // Detect connector ID from the plan's TableScan nodes. + // Presto plans typically use "hive" while the test fixture registers + // "test-hive". We need to register a connector under the plan's ID. + if (!scanNodes.empty() && connectorId_.empty()) { + connectorId_ = scanNodes[0]->tableHandle()->connectorId(); + LOG(INFO) << "TpcdsQueryBuilder: detected connector ID '" << connectorId_ + << "' from plan"; + + if (connector::ConnectorRegistry::tryGet(connectorId_) == nullptr) { + registerHiveConnector(connectorId_); + ownedConnector_ = true; + LOG(INFO) << "TpcdsQueryBuilder: registered connector under ID '" + << connectorId_ << "'"; + } + } + + // For each TableScan, look up the table name and populate data files. + for (const auto& scanNode : scanNodes) { + const auto& tableName = scanNode->tableHandle()->name(); + + const auto* files = findDataFiles(tableDataFiles_, tableName); + if (files) { + veloxPlan.dataFiles[scanNode->id()] = *files; + } else { + LOG(WARNING) << "TpcdsQueryBuilder: no data files found for table '" + << tableName << "' (scan node " << scanNode->id() << ")"; + } + } + + veloxPlan.dataFileFormat = format_; + return veloxPlan; +} + +std::shared_ptr TpcdsQueryBuilder::makeSplit( + const std::string& filePath) const { + const auto& id = connectorId_.empty() ? kHiveConnectorId : connectorId_; + return connector::hive::HiveConnectorSplitBuilder(filePath) + .connectorId(id) + .fileFormat(format_) + .build(); +} + +void TpcdsQueryBuilder::shutdown() { + if (ownedConnector_ && !connectorId_.empty()) { + connector::ConnectorRegistry::global().erase(connectorId_); + ownedConnector_ = false; + } + connectorId_.clear(); +} + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/utils/TpcdsQueryBuilder.h b/velox/exec/tests/utils/TpcdsQueryBuilder.h new file mode 100644 index 00000000000..516cc7e1125 --- /dev/null +++ b/velox/exec/tests/utils/TpcdsQueryBuilder.h @@ -0,0 +1,108 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/connectors/hive/HiveConnectorSplit.h" +#include "velox/dwio/common/Options.h" +#include "velox/exec/tests/utils/VeloxPlanLoader.h" + +#include +#include +#include + +namespace facebook::velox::exec::test { + +/// Builds TPC-DS query plans from pre-dumped JSON plan files and data files +/// located in a given data directory. Each table's data must be placed in +/// hive-style partitioning: a sub-directory per table name under the data path. +/// +/// Example: +/// ls /gds/datasets/tpcds/sf100/ +/// store_sales/ customer/ item/ ... +/// +/// ls /gds/datasets/tpcds/sf100/store_sales/ +/// store_sales_part-00000.parquet store_sales_part-00001.parquet ... +/// +/// This class is CuDF-free. To use CudfHiveConnector instead of HiveConnector, +/// subclass and override registerConnector(). +class TpcdsQueryBuilder { + public: + explicit TpcdsQueryBuilder( + dwio::common::FileFormat format = dwio::common::FileFormat::PARQUET); + + virtual ~TpcdsQueryBuilder() = default; + + /// Scan dataPath for table subdirectories and collect all data file paths + /// per table. + void initialize(const std::string& dataPath); + + /// Load the deserialized plan from JSON (via VeloxPlanLoader) and populate + /// dataFiles for each TableScan node by matching table names to the + /// discovered data files in the data path. + /// + /// The connector ID is auto-detected from the plan's TableScan nodes (e.g. + /// Presto plans use "hive"). If no connector is registered under that ID, + /// registerConnector() is called to register one. + /// + /// Table name matching is flexible: if the exact name from the plan (e.g. + /// "tpcds.store_sales") doesn't match a data directory, the part after the + /// last '.' is tried (e.g. "store_sales"). + /// + /// @param queryId TPC-DS query number (1..99) + /// @param planDir Directory containing plan JSON files (Q1.json, ...) + /// @param pool Memory pool for plan deserialization + VeloxPlan getQueryPlan( + int queryId, + const std::string& planDir, + memory::MemoryPool* pool); + + /// Create a ConnectorSplit for a given file path using HiveConnectorSplit. + /// The connector ID used matches the one discovered from the plan. + std::shared_ptr makeSplit( + const std::string& filePath) const; + + /// Clean up: unregisters any connector that was auto-registered by + /// getQueryPlan. + virtual void shutdown(); + + /// Returns the connector ID discovered from the plan. + const std::string& connectorId() const { + return connectorId_; + } + + protected: + /// Called when a connector needs to be registered for the given ID. + /// Default implementation registers a HiveConnector. + /// Override in subclasses to register CudfHiveConnector or others. + /// @param connectorId The connector ID from the plan. + /// @param ioExecutor Optional IO executor. + virtual void registerHiveConnector( + const std::string& connectorId, + folly::Executor* ioExecutor = nullptr); + + dwio::common::FileFormat format_; + /// Connector ID auto-detected from the plan's TableScan nodes. + std::string connectorId_; + /// Whether we registered a connector ourselves (need to unregister on + /// shutdown). + bool ownedConnector_{false}; + + private: + /// tableName -> list of data file paths + std::unordered_map> tableDataFiles_; +}; + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/utils/TpchQueryBuilder.cpp b/velox/exec/tests/utils/TpchQueryBuilder.cpp index 75df732e2a2..d7bfd6bfa00 100644 --- a/velox/exec/tests/utils/TpchQueryBuilder.cpp +++ b/velox/exec/tests/utils/TpchQueryBuilder.cpp @@ -70,7 +70,9 @@ void TpchQueryBuilder::readFileSchema( const std::string& tableName, const std::string& filePath, const std::vector& columns) { - dwio::common::ReaderOptions readerOptions{pool_.get()}; + dwio::common::ReaderOptions readerOptions(pool_.get()); + readerOptions.setDataIoStats(dataIoStats_); + readerOptions.setMetadataIoStats(metadataIoStats_); readerOptions.setFileFormat(format_); auto uniqueReadFile = filesystems::getFileSystem(filePath, nullptr)->openFileForRead(filePath); diff --git a/velox/exec/tests/utils/TpchQueryBuilder.h b/velox/exec/tests/utils/TpchQueryBuilder.h index 30a1ef050bd..a444c1d8951 100644 --- a/velox/exec/tests/utils/TpchQueryBuilder.h +++ b/velox/exec/tests/utils/TpchQueryBuilder.h @@ -15,18 +15,14 @@ */ #pragma once +#include "velox/common/io/IoStatistics.h" #include "velox/dwio/common/Options.h" #include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/tests/utils/VeloxPlanLoader.h" namespace facebook::velox::exec::test { -/// Contains the query plan and input data files keyed on source plan node ID. -/// All data files use the same file format specified in 'dataFileFormat'. -struct TpchPlan { - core::PlanNodePtr plan; - std::unordered_map> dataFiles; - dwio::common::FileFormat dataFileFormat; -}; +using TpchPlan = VeloxPlan; /// Contains type information, data files, and file column names for a table. /// This information is inferred from the input data files. @@ -149,6 +145,10 @@ class TpchQueryBuilder { static constexpr const char* kPartsupp = "partsupp"; std::shared_ptr pool_ = memory::memoryManager()->addLeafPool(); + std::shared_ptr dataIoStats_ = + std::make_shared(); + std::shared_ptr metadataIoStats_ = + std::make_shared(); const bool filtersAsNode_; }; diff --git a/velox/exec/tests/utils/VeloxPlanLoader.cpp b/velox/exec/tests/utils/VeloxPlanLoader.cpp new file mode 100644 index 00000000000..f50d9136e63 --- /dev/null +++ b/velox/exec/tests/utils/VeloxPlanLoader.cpp @@ -0,0 +1,149 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/exec/tests/utils/VeloxPlanLoader.h" +#include +#include +#include +#include +#include "velox/common/base/Exceptions.h" +#include "velox/common/serialization/Serializable.h" +#include "velox/core/PlanNode.h" + +#if __has_include("filesystem") +#include +namespace fs = std::filesystem; +#else +#include +namespace fs = std::experimental::filesystem; +#endif + +namespace facebook::velox::exec::test { + +namespace { + +void collectTableScanNodesRecursive( + const core::PlanNodePtr& node, + std::vector& out) { + if (auto scan = std::dynamic_pointer_cast(node)) { + out.push_back(scan); + } + for (const auto& source : node->sources()) { + collectTableScanNodesRecursive(source, out); + } +} + +std::string readFileToString(const std::string& path) { + std::ifstream in(path); + if (!in) { + VELOX_FAIL("Failed to open plan file: {}", path); + } + std::stringstream buffer; + buffer << in.rdbuf(); + return buffer.str(); +} + +} // namespace + +// static +std::string VeloxPlanLoader::resolvePlanDirectory( + const std::string& defaultDir, + const std::string& envVarName) { + if (!envVarName.empty()) { + const char* env = std::getenv(envVarName.c_str()); + if (env && *env != '\0') { + return std::string(env); + } + } + return defaultDir; +} + +VeloxPlanLoader::VeloxPlanLoader( + const std::string& planDirectory, + memory::MemoryPool* pool, + bool stripPartitionedOutput) + : planDir_(planDirectory), + pool_(pool), + stripPartitionedOutput_(stripPartitionedOutput) { + if (pool_ == nullptr) { + ownedPool_ = memory::memoryManager()->addLeafPool("VeloxPlanLoader"); + pool_ = ownedPool_.get(); + } +} + +std::string VeloxPlanLoader::pathForQuery(int queryId) const { + const std::string path = planDir_ + "/Q" + std::to_string(queryId) + ".json"; + if (fs::exists(path)) { + return path; + } + VELOX_USER_FAIL("Plan file does not exist: {}", path); +} + +void VeloxPlanLoader::maybeStripPartitionedOutput( + core::PlanNodePtr& plan) const { + if (!stripPartitionedOutput_ || !plan) { + return; + } + if (auto partitionedOutput = + std::dynamic_pointer_cast(plan)) { + VELOX_CHECK_EQ( + partitionedOutput->sources().size(), + 1, + "PartitionedOutput should have exactly one source"); + plan = partitionedOutput->sources()[0]; + LOG(INFO) << "VeloxPlanLoader: stripped PartitionedOutput root from plan"; + } +} + +VeloxPlan VeloxPlanLoader::loadPlan(const std::string& path) const { + const std::string contents = readFileToString(path); + folly::dynamic planJson; + try { + planJson = folly::parseJson(contents); + } catch (const std::exception& e) { + VELOX_FAIL("Failed to parse plan JSON from {}: {}", path, e.what()); + } + core::PlanNodePtr plan; + try { + plan = velox::ISerializable::deserialize(planJson, pool_); + } catch (const std::exception& e) { + VELOX_FAIL("Failed to deserialize plan from {}: {}", path, e.what()); + } + maybeStripPartitionedOutput(plan); + VeloxPlan result; + result.plan = std::move(plan); + result.dataFiles = {}; + result.dataFileFormat = dwio::common::FileFormat::PARQUET; + return result; +} + +VeloxPlan VeloxPlanLoader::loadPlanByQueryId(int queryId) const { + VELOX_USER_CHECK( + queryId >= 1 && queryId <= 99, "queryId must be 1..99, got {}", queryId); + std::string path = pathForQuery(queryId); + return loadPlan(path); +} + +// static +std::vector VeloxPlanLoader::collectTableScanNodes( + const core::PlanNodePtr& plan) { + std::vector out; + if (plan) { + collectTableScanNodesRecursive(plan, out); + } + return out; +} + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/utils/VeloxPlanLoader.h b/velox/exec/tests/utils/VeloxPlanLoader.h new file mode 100644 index 00000000000..a4e9694c7fc --- /dev/null +++ b/velox/exec/tests/utils/VeloxPlanLoader.h @@ -0,0 +1,87 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include "velox/core/PlanNode.h" +#include "velox/dwio/common/Options.h" + +namespace facebook::velox::exec::test { + +/// Contains a deserialized Velox query plan and optional data files keyed on +/// source plan node ID. TpchPlan (in TpchQueryBuilder.h) is a type alias for +/// this struct. +struct VeloxPlan { + core::PlanNodePtr plan; + std::unordered_map> dataFiles; + dwio::common::FileFormat dataFileFormat{dwio::common::FileFormat::PARQUET}; +}; + +/// Generic loader for Velox plan JSON files (e.g. dumped from a Presto +/// worker's plan-dump-dir). Not specific to any benchmark suite -- works for +/// TPC-DS, TPC-H, or any arbitrary single-node query plan. +/// +/// Plan directory layout (for loadPlanByQueryId): +/// - Canonical names Q1.json, Q2.json, ... Q99.json. +/// +/// When stripPartitionedOutput is true (default), the loader strips the +/// PartitionedOutput root node that Presto always adds for distributed +/// shuffle. This makes the plan runnable locally via AssertQueryBuilder / +/// TaskCursor. +class VeloxPlanLoader { + public: + /// @param planDirectory Directory containing plan JSON files. + /// @param pool Memory pool for plan deserialization. If nullptr, + /// a leaf pool is created from the global + /// MemoryManager (which must already be initialized). + /// @param stripPartitionedOutput If true, strip PartitionedOutput root. + VeloxPlanLoader( + const std::string& planDirectory, + memory::MemoryPool* pool = nullptr, + bool stripPartitionedOutput = true); + + /// Resolve a plan directory from either an environment variable or a + /// supplied default. If envVarName is non-empty and the env var is set, its + /// value is used; otherwise defaultDir is returned. + static std::string resolvePlanDirectory( + const std::string& defaultDir, + const std::string& envVarName = ""); + + /// Load a plan from a specific file path. + VeloxPlan loadPlan(const std::string& path) const; + + /// Load plan by query ID (1..99). Looks for planDir/Q{id}.json. + VeloxPlan loadPlanByQueryId(int queryId) const; + + /// Collects all TableScan plan nodes in the plan tree. + static std::vector collectTableScanNodes( + const core::PlanNodePtr& plan); + + private: + std::string planDir_; + memory::MemoryPool* pool_; + std::shared_ptr ownedPool_; + bool stripPartitionedOutput_; + + std::string pathForQuery(int queryId) const; + + /// If stripPartitionedOutput_ is true and the plan root is a + /// PartitionedOutputNode, replace it with its single child. + void maybeStripPartitionedOutput(core::PlanNodePtr& plan) const; +}; + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/trace/CMakeLists.txt b/velox/exec/trace/CMakeLists.txt new file mode 100644 index 00000000000..8a4803ef14b --- /dev/null +++ b/velox/exec/trace/CMakeLists.txt @@ -0,0 +1,34 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +velox_add_library( + velox_trace + Trace.cpp + TraceUtil.cpp + HEADERS + Trace.h + TraceCtx.h + TraceUtil.h + TraceWriter.h +) + +velox_link_libraries( + velox_trace + velox_common_base + velox_exception + velox_file + velox_core + velox_type + velox_vector +) diff --git a/velox/exec/Trace.cpp b/velox/exec/trace/Trace.cpp similarity index 97% rename from velox/exec/Trace.cpp rename to velox/exec/trace/Trace.cpp index 6ff55787fc5..f50f5948a73 100644 --- a/velox/exec/Trace.cpp +++ b/velox/exec/trace/Trace.cpp @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "velox/exec/Trace.h" +#include "velox/exec/trace/Trace.h" #include diff --git a/velox/exec/Trace.h b/velox/exec/trace/Trace.h similarity index 58% rename from velox/exec/Trace.h rename to velox/exec/trace/Trace.h index 699de5253cd..cd473680d1e 100644 --- a/velox/exec/Trace.h +++ b/velox/exec/trace/Trace.h @@ -19,31 +19,33 @@ #include #include #include +#include namespace facebook::velox::exec::trace { + /// Defines the shared constants used by query trace implementation. struct TraceTraits { - static inline const std::string kPlanNodeKey = "planNode"; - static inline const std::string kQueryConfigKey = "queryConfig"; - static inline const std::string kConnectorPropertiesKey = + static constexpr std::string_view kPlanNodeKey = "planNode"; + static constexpr std::string_view kQueryConfigKey = "queryConfig"; + static constexpr std::string_view kConnectorPropertiesKey = "connectorProperties"; - static inline const std::string kTaskMetaFileName = "task_trace_meta.json"; + static constexpr std::string_view kTaskMetaFileName = "task_trace_meta.json"; }; struct OperatorTraceTraits { - static inline const std::string kSummaryFileName = "op_trace_summary.json"; - static inline const std::string kInputFileName = "op_input_trace.data"; - static inline const std::string kSplitFileName = "op_split_trace.split"; + static constexpr std::string_view kSummaryFileName = "op_trace_summary.json"; + static constexpr std::string_view kInputFileName = "op_input_trace.data"; + static constexpr std::string_view kSplitFileName = "op_split_trace.split"; /// Keys for operator trace summary file. - static inline const std::string kOpTypeKey = "opType"; - static inline const std::string kPeakMemoryKey = "peakMemory"; - static inline const std::string kInputRowsKey = "inputRows"; - static inline const std::string kInputBytesKey = "inputBytes"; - static inline const std::string kRawInputRowsKey = "rawInputRows"; - static inline const std::string kRawInputBytesKey = "rawInputBytes"; - static inline const std::string kNumSplitsKey = "numSplits"; + static constexpr std::string_view kOpTypeKey = "opType"; + static constexpr std::string_view kPeakMemoryKey = "peakMemory"; + static constexpr std::string_view kInputRowsKey = "inputRows"; + static constexpr std::string_view kInputBytesKey = "inputBytes"; + static constexpr std::string_view kRawInputRowsKey = "rawInputRows"; + static constexpr std::string_view kRawInputBytesKey = "rawInputBytes"; + static constexpr std::string_view kNumSplitsKey = "numSplits"; }; /// Contains the summary of an operator trace. diff --git a/velox/exec/trace/TraceCtx.h b/velox/exec/trace/TraceCtx.h new file mode 100644 index 00000000000..2c171c2b917 --- /dev/null +++ b/velox/exec/trace/TraceCtx.h @@ -0,0 +1,88 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include "velox/exec/trace/TraceWriter.h" + +namespace facebook::velox::exec { +class Operator; +} + +namespace facebook::velox::exec::trace { + +class TraceCtx { + public: + TraceCtx(bool dryRun) : dryRun_(dryRun) {} + + virtual ~TraceCtx() = default; + + /// Overwrite the methods below to provide a concrete trace writer + /// implementation for input, split, and task metadata. + virtual std::unique_ptr createInputTracer( + Operator&) const { + return nullptr; + } + + virtual std::unique_ptr createSplitTracer( + Operator&) const { + return nullptr; + } + + virtual std::unique_ptr createMetadataTracer() + const { + return nullptr; + } + + /// Returns whether a particular operator should be traced. Called before the + /// task starts execution, when operators are instantiated. + virtual bool shouldTrace(const Operator&) const { + return false; + } + + /// Returns whether a particular expression function should be traced. Called + /// during operator initialization in Expr::maybeSetupTracer(). + virtual bool shouldTraceExpr(std::string_view /*functionName*/) const { + return false; + } + + /// Creates a writer for capturing output batches from a traced expression. + /// The instanceIndex distinguishes multiple Expr nodes with the same + /// function name within a single operator. Ownership is transferred to the + /// caller. + virtual std::unique_ptr createExprOutputTracer( + const Operator& /*op*/, + std::string_view /*functionName*/, + int /*instanceIndex*/) const { + return nullptr; + } + + bool dryRun() const { + return dryRun_; + } + + private: + /// If true, we only collect operator input trace without the actual + /// execution. This is used by crash debugging so that we can collect the + /// input that triggers the crash. + bool dryRun_{false}; +}; + +} // namespace facebook::velox::exec::trace diff --git a/velox/exec/TraceUtil.cpp b/velox/exec/trace/TraceUtil.cpp similarity index 78% rename from velox/exec/TraceUtil.cpp rename to velox/exec/trace/TraceUtil.cpp index f2f9c49d8aa..58c721e7c3c 100644 --- a/velox/exec/TraceUtil.cpp +++ b/velox/exec/trace/TraceUtil.cpp @@ -14,20 +14,21 @@ * limitations under the License. */ -#include "velox/exec/TraceUtil.h" +#include "velox/exec/trace/TraceUtil.h" #include - #include #include "velox/common/base/Exceptions.h" #include "velox/common/file/File.h" #include "velox/common/file/FileSystems.h" -#include "velox/exec/TableWriter.h" -#include "velox/exec/Trace.h" +#include "velox/core/TableWriteTraits.h" +#include "velox/exec/OperatorType.h" +#include "velox/exec/trace/Trace.h" namespace facebook::velox::exec::trace { namespace { + std::string findLastPathNode(const std::string& path) { std::vector pathNodes; folly::split("/", path, pathNodes); @@ -38,51 +39,11 @@ std::string findLastPathNode(const std::string& path) { return pathNodes.back(); } -const std::vector kEmptySources; - -class DummySourceNode final : public core::PlanNode { - public: - explicit DummySourceNode(RowTypePtr outputType) - : PlanNode(""), outputType_(std::move(outputType)) {} - - const RowTypePtr& outputType() const override { - return outputType_; - } - - const std::vector& sources() const override { - return kEmptySources; - } - - std::string_view name() const override { - return "DummySource"; - } - - folly::dynamic serialize() const override { - folly::dynamic obj = folly::dynamic::object; - obj["name"] = "DummySource"; - obj["outputType"] = outputType_->serialize(); - return obj; - } - - static core::PlanNodePtr create(const folly::dynamic& obj, void* context) { - return std::make_shared( - ISerializable::deserialize(obj["outputType"])); - } - - private: - void addDetails(std::stringstream& stream) const override { - // Nothing to add. - } - - const RowTypePtr outputType_; -}; - -void registerDummySourceSerDe(); - std::unordered_map& traceNodeRegistry() { static std::unordered_map registry; return registry; } + } // namespace void createTraceDirectory( @@ -123,13 +84,6 @@ std::string getQueryTraceDirectory( return fmt::format("{}/{}", normalizedTraceDir, queryId); } -std::string getTaskTraceDirectory( - const std::string& traceDir, - const Task& task) { - return getTaskTraceDirectory( - traceDir, task.queryCtx()->queryId(), task.taskId()); -} - std::string getTaskTraceDirectory( const std::string& traceDir, const std::string& queryId, @@ -221,10 +175,8 @@ RowTypePtr getDataType( const core::PlanNodePtr& tracedPlan, const std::string& tracedNodeId, size_t sourceIndex) { - const auto* traceNode = core::PlanNode::findFirstNode( - tracedPlan.get(), [&tracedNodeId](const core::PlanNode* node) { - return node->id() == tracedNodeId; - }); + const auto* traceNode = + core::PlanNode::findNodeById(tracedPlan.get(), tracedNodeId); VELOX_CHECK_NOT_NULL( traceNode, "traced node id {} not found in the traced plan", @@ -285,17 +237,23 @@ std::vector extractDriverIds(const std::string& driverIds) { } bool canTrace(const std::string& operatorType) { - static const std::unordered_set kSupportedOperatorTypes{ - "Aggregation", - "FilterProject", - "HashBuild", - "HashProbe", - "IndexLookupJoin", - "Unnest", - "PartialAggregation", - "PartitionedOutput", - "TableScan", - "TableWrite"}; + static const std::unordered_set kSupportedOperatorTypes{ + OperatorType::kAggregation, + OperatorType::kCallbackSink, + OperatorType::kExchange, + OperatorType::kFilterProject, + OperatorType::kHashBuild, + OperatorType::kHashProbe, + OperatorType::kIndexLookupJoin, + OperatorType::kMergeExchange, + OperatorType::kMergeJoin, + OperatorType::kOrderBy, + OperatorType::kPartialAggregation, + OperatorType::kPartitionedOutput, + OperatorType::kTableScan, + OperatorType::kTableWrite, + OperatorType::kTopNRowNumber, + OperatorType::kUnnest}; if (kSupportedOperatorTypes.count(operatorType) > 0 || traceNodeRegistry().count(operatorType) > 0) { return true; @@ -304,11 +262,9 @@ bool canTrace(const std::string& operatorType) { } core::PlanNodePtr getTraceNode( - const core::PlanNodePtr& plan, + const core::PlanNode& plan, core::PlanNodeId nodeId) { - const auto* traceNode = core::PlanNode::findFirstNode( - plan.get(), - [&nodeId](const core::PlanNode* node) { return node->id() == nodeId; }); + const auto* traceNode = core::PlanNode::findNodeById(&plan, nodeId); VELOX_CHECK_NOT_NULL(traceNode, "Failed to find node with id {}", nodeId); if (const auto* hashJoinNode = dynamic_cast(traceNode)) { @@ -326,6 +282,21 @@ core::PlanNodePtr getTraceNode( hashJoinNode->outputType()); } + if (const auto* mergeJoinNode = + dynamic_cast(traceNode)) { + return std::make_shared( + nodeId, + mergeJoinNode->joinType(), + mergeJoinNode->leftKeys(), + mergeJoinNode->rightKeys(), + mergeJoinNode->filter(), + std::make_shared( + mergeJoinNode->sources()[0]->outputType()), + std::make_shared( + mergeJoinNode->sources()[1]->outputType()), + mergeJoinNode->outputType()); + } + if (const auto* filterNode = dynamic_cast(traceNode)) { // Single FilterNode. @@ -380,6 +351,7 @@ core::PlanNodePtr getTraceNode( aggregationNode->globalGroupingSets(), aggregationNode->groupId(), aggregationNode->ignoreNullKeys(), + aggregationNode->noGroupsSpanBatches(), std::make_shared( aggregationNode->sources().front()->outputType())); } @@ -394,7 +366,7 @@ core::PlanNodePtr getTraceNode( partitionedOutputNode->isReplicateNullsAndAny(), partitionedOutputNode->partitionFunctionSpecPtr(), partitionedOutputNode->outputType(), - VectorSerde::Kind::kPresto, + "Presto", std::make_shared( partitionedOutputNode->sources().front()->outputType())); } @@ -407,7 +379,8 @@ core::PlanNodePtr getTraceNode( indexLookupJoinNode->leftKeys(), indexLookupJoinNode->rightKeys(), indexLookupJoinNode->joinConditions(), - indexLookupJoinNode->includeMatchColumn(), + indexLookupJoinNode->filter(), + indexLookupJoinNode->hasMarker(), std::make_shared( indexLookupJoinNode->sources().front()->outputType()), // Probe side indexLookupJoinNode->lookupSource(), // Index side @@ -432,7 +405,7 @@ core::PlanNodePtr getTraceNode( tableWriteNode->columnStatsSpec(), tableWriteNode->insertTableHandle(), tableWriteNode->hasPartitioningScheme(), - TableWriteTraits::outputType(tableWriteNode->columnStatsSpec()), + core::TableWriteTraits::outputType(tableWriteNode->columnStatsSpec()), tableWriteNode->commitStrategy(), std::make_shared( tableWriteNode->sources().front()->outputType())); @@ -446,11 +419,56 @@ core::PlanNodePtr getTraceNode( unnestNode->unnestVariables(), unnestNode->unnestNames(), unnestNode->ordinalityName(), - unnestNode->emptyUnnestValueName(), + unnestNode->markerName(), std::make_shared( unnestNode->sources().front()->outputType())); } + if (const auto* orderByNode = + dynamic_cast(traceNode)) { + return std::make_shared( + nodeId, + orderByNode->sortingKeys(), + orderByNode->sortingOrders(), + orderByNode->isPartial(), + std::make_shared( + orderByNode->sources().front()->outputType())); + } + + if (const auto* topNRowNumberNode = + dynamic_cast(traceNode)) { + const auto generateRowNumber = topNRowNumberNode->generateRowNumber(); + return std::make_shared( + nodeId, + topNRowNumberNode->rankFunction(), + topNRowNumberNode->partitionKeys(), + topNRowNumberNode->sortingKeys(), + topNRowNumberNode->sortingOrders(), + generateRowNumber ? std::make_optional( + topNRowNumberNode->outputType()->names().back()) + : std::nullopt, + topNRowNumberNode->limit(), + std::make_shared( + topNRowNumberNode->sources().front()->outputType())); + } + + if (const auto* exchangeNode = + dynamic_cast(traceNode)) { + // Check if it's a MergeExchangeNode + if (const auto* mergeExchangeNode = + dynamic_cast(traceNode)) { + return std::make_shared( + nodeId, + mergeExchangeNode->outputType(), + mergeExchangeNode->sortingKeys(), + mergeExchangeNode->sortingOrders(), + mergeExchangeNode->serdeKind()); + } + // Regular ExchangeNode + return std::make_shared( + nodeId, exchangeNode->outputType(), exchangeNode->serdeKind()); + } + for (const auto& factory : traceNodeRegistry()) { if (auto node = factory.second(traceNode, nodeId)) { return node; diff --git a/velox/exec/trace/TraceUtil.h b/velox/exec/trace/TraceUtil.h new file mode 100644 index 00000000000..1d7a86d2ab5 --- /dev/null +++ b/velox/exec/trace/TraceUtil.h @@ -0,0 +1,188 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include "velox/common/file/FileSystems.h" +#include "velox/core/PlanNode.h" +#include "velox/type/Type.h" + +#include + +namespace facebook::velox::exec::trace { + +static const std::vector kEmptySources; + +class DummySourceNode final : public core::PlanNode { + public: + explicit DummySourceNode(RowTypePtr outputType) + : PlanNode(""), outputType_(std::move(outputType)) {} + + const RowTypePtr& outputType() const override { + return outputType_; + } + + const std::vector& sources() const override { + return kEmptySources; + } + + std::string_view name() const override { + return "DummySource"; + } + + folly::dynamic serialize() const override { + folly::dynamic obj = folly::dynamic::object; + obj["name"] = "DummySource"; + obj["outputType"] = outputType_->serialize(); + return obj; + } + + static core::PlanNodePtr create(const folly::dynamic& obj, void* context) { + return std::make_shared( + ISerializable::deserialize(obj["outputType"])); + } + + private: + void addDetails(std::stringstream& stream) const override { + // Nothing to add. + } + + const RowTypePtr outputType_; +}; + +/// Creates a directory to store the query trace metdata and data. +void createTraceDirectory( + const std::string& traceDir, + const std::string& directoryConfig = ""); + +/// Returns the trace directory for a given query. +std::string getQueryTraceDirectory( + const std::string& traceDir, + const std::string& queryId); + +/// Returns the trace directory for a given query task. +std::string getTaskTraceDirectory( + const std::string& traceDir, + const std::string& queryId, + const std::string& taskId); + +/// Returns the file path for a given task's metadata trace file. +std::string getTaskTraceMetaFilePath(const std::string& taskTraceDir); + +/// Returns the trace directory for a given traced plan node. +std::string getNodeTraceDirectory( + const std::string& taskTraceDir, + const std::string& nodeId); + +/// Returns the trace directory for a given traced pipeline. +std::string getPipelineTraceDirectory( + const std::string& nodeTraceDir, + uint32_t pipelineId); + +/// Returns the trace directory for a given traced operator. +std::string getOpTraceDirectory( + const std::string& taskTraceDir, + const std::string& nodeId, + uint32_t pipelineId, + uint32_t driverId); + +std::string getOpTraceDirectory( + const std::string& nodeTraceDir, + uint32_t pipelineId, + uint32_t driverId); + +/// Returns the file path for a given operator's traced input file. +std::string getOpTraceInputFilePath(const std::string& opTraceDir); + +/// Returns the file path for a given operator's traced split file. +std::string getOpTraceSplitFilePath(const std::string& opTraceDir); + +/// Returns the file path for a given operator's traced input file. +std::string getOpTraceSummaryFilePath(const std::string& opTraceDir); + +/// Extracts the input data type for the trace scan operator. The function first +/// uses the traced node id to find traced operator's plan node from the traced +/// plan fragment. Then it uses the specified source node index to find the +/// output data type from its source node plans as the input data type of the +/// traced plan node. +/// +/// For hash join plan node, there are two source nodes, the output data type +/// of the first node is the input data type of the 'HashProbe' operator, and +/// the output data type of the second one is the input data type of the +/// 'HashBuild' operator. +/// +/// @param tracedPlan The root node of the trace plan fragment. +/// @param tracedNodeId The node id of the trace node. +/// @param sourceIndex The source index of the specific traced operator. +RowTypePtr getDataType( + const core::PlanNodePtr& tracedPlan, + const std::string& tracedNodeId, + size_t sourceIndex = 0); + +/// Extracts pipeline IDs in ascending order by listing the trace directory, +/// then decoding the names of the subdirectories to obtain the pipeline IDs, +/// and finally sorting them. 'nodeTraceDir' corresponds to the trace directory +/// of the plan node. +std::vector listPipelineIds( + const std::string& nodeTraceDir, + const std::shared_ptr& fs); + +/// Extracts driver IDs in ascending order by listing the trace directory for a +/// given pipeline then decoding the names of the subdirectories to obtain the +/// driver IDs, and finally sorting them. 'nodeTraceDir' corresponds to the +/// trace directory of the plan node. +std::vector listDriverIds( + const std::string& nodeTraceDir, + uint32_t pipelineId, + const std::shared_ptr& fs); + +/// Extracts the driver IDs from the comma-separated list of driver IDs string. +std::vector extractDriverIds(const std::string& driverIds); + +/// Extracts task ids of the query tracing by listing the query trace directory. +/// 'traceDir' is the root trace directory. 'queryId' is the query id. +std::vector getTaskIds( + const std::string& traceDir, + const std::string& queryId, + const std::shared_ptr& fs); + +/// Gets the metadata from a given task metadata file which includes query plan, +/// configs and connector properties. +folly::dynamic getTaskMetadata( + const std::string& taskMetaFilePath, + const std::shared_ptr& fs); + +/// Checks whether the operator can be traced. +bool canTrace(const std::string& operatorType); + +/// Gets the specified the trace node from 'plan'. In the returned trace node, +/// we replace its source nodes with DummySourceNode for replay. +core::PlanNodePtr getTraceNode( + const core::PlanNode& plan, + core::PlanNodeId nodeId); + +using TraceNodeFactory = std::function< + core::PlanNodePtr(const core::PlanNode*, const core::PlanNodeId&)>; + +void registerTraceNodeFactory( + const std::string& operatorType, + TraceNodeFactory&& factory); + +void registerDummySourceSerDe(); + +} // namespace facebook::velox::exec::trace diff --git a/velox/exec/trace/TraceWriter.h b/velox/exec/trace/TraceWriter.h new file mode 100644 index 00000000000..4565c49c756 --- /dev/null +++ b/velox/exec/trace/TraceWriter.h @@ -0,0 +1,86 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/vector/ComplexVector.h" + +namespace facebook::velox::exec { +struct Split; +} + +namespace facebook::velox::core { +class PlanNode; +class QueryCtx; +} // namespace facebook::velox::core + +namespace facebook::velox::exec::trace { + +/// Abstract interface for capturing traced input. Implementations are +/// responsible for serializing and/or storing row batches during query +/// execution tracing, along with associated metadata and summaries. +class TraceInputWriter { + public: + virtual ~TraceInputWriter() = default; + + /// Serializes rows and writes out each batch. Return whether the driver + /// should block the pipeline. If it returns true, a future needs to be set + /// (returned) to signal the driver when it can resume execution. + virtual bool write(const RowVectorPtr& rows, ContinueFuture* future) = 0; + + /// Closes the data file and writes out the data summary. + virtual void finish() = 0; +}; + +/// Abstract interface for capturing traced expression output. Implementations +/// are responsible for processing individual expression result vectors during +/// query execution tracing. +class TraceExprWriter { + public: + virtual ~TraceExprWriter() = default; + + /// Writes a single expression output vector. Expression tracing runs inside + /// Expr::apply(), which has no mechanism to block the driver pipeline. + virtual void write(const VectorPtr& result) = 0; + + /// Closes the data file and writes out the data summary. + virtual void finish() = 0; +}; + +/// Abstract interface for capturing traced split information. Implementations +/// are responsible for processing and/or recording the splits found during +/// query execution tracing, enabling replay and analysis of query input +/// patterns. +class TraceSplitWriter { + public: + virtual ~TraceSplitWriter() = default; + + virtual void write(const exec::Split& split) const = 0; + + virtual void finish() = 0; +}; + +/// Abstract interface for capturing task metadata. +class TraceMetadataWriter { + public: + virtual ~TraceMetadataWriter() = default; + + virtual void write( + const core::QueryCtx& queryCtx, + const core::PlanNode& planNode) = 0; +}; + +} // namespace facebook::velox::exec::trace diff --git a/velox/experimental/breeze/CMakeLists.txt b/velox/experimental/breeze/CMakeLists.txt index 90001307ee8..922624a5946 100644 --- a/velox/experimental/breeze/CMakeLists.txt +++ b/velox/experimental/breeze/CMakeLists.txt @@ -118,7 +118,7 @@ if(BUILD_TRACING) include_directories(external/perfetto/sdk) add_library(perfetto STATIC ../../external/perfetto/sdk/perfetto.cc) target_compile_features(perfetto PRIVATE cxx_std_17) - target_compile_options(perfetto PRIVATE -fPIC) + target_compile_options(perfetto PRIVATE -fPIC -Wno-array-bounds -Wno-stringop-overflow) target_include_directories(perfetto INTERFACE ../../external/perfetto/sdk) target_compile_definitions(perfetto INTERFACE TRACING=1) endif() diff --git a/velox/experimental/breeze/breeze/platforms/cuda.cuh b/velox/experimental/breeze/breeze/platforms/cuda.cuh index 661e0743b46..02378f9af45 100644 --- a/velox/experimental/breeze/breeze/platforms/cuda.cuh +++ b/velox/experimental/breeze/breeze/platforms/cuda.cuh @@ -371,12 +371,12 @@ CudaSpecialization::atomic_add(address.data()); + *reinterpret_cast(address.data()); unsigned long long assumed; do { assumed = old; old = atomicCAS( - reinterpret_cast(address.data()), assumed, + reinterpret_cast(address.data()), assumed, __double_as_longlong(value + __longlong_as_double(assumed))); } while (assumed != old); @@ -394,9 +394,9 @@ CudaSpecialization::atomic_add(address.data()), - *reinterpret_cast(&value)); - return *reinterpret_cast(&result); + atomicAdd(reinterpret_cast(address.data()), + *reinterpret_cast(&value)); + return *reinterpret_cast(&result); } // specialization for T=Slice @@ -409,10 +409,10 @@ __device__ __forceinline__ void CudaSpecialization::atomic_min< static_assert(sizeof(float) == sizeof(unsigned), "unexpected type sizes"); float current = atomic_load(address); while (current > value) { - unsigned old = atomicCAS(reinterpret_cast(address.data()), - *reinterpret_cast(¤t), - *reinterpret_cast(&value)); - current = *reinterpret_cast(&old); + unsigned old = atomicCAS(reinterpret_cast(address.data()), + *reinterpret_cast(¤t), + *reinterpret_cast(&value)); + current = *reinterpret_cast(&old); if (current == value) { break; } @@ -429,10 +429,10 @@ __device__ __forceinline__ void CudaSpecialization::atomic_max< static_assert(sizeof(float) == sizeof(unsigned), "unexpected type sizes"); float current = atomic_load(address); while (current < value) { - unsigned old = atomicCAS(reinterpret_cast(address.data()), - *reinterpret_cast(¤t), - *reinterpret_cast(&value)); - current = *reinterpret_cast(&old); + unsigned old = atomicCAS(reinterpret_cast(address.data()), + *reinterpret_cast(¤t), + *reinterpret_cast(&value)); + current = *reinterpret_cast(&old); if (current == value) { break; } @@ -452,10 +452,10 @@ CudaSpecialization::atomic_min value) { unsigned long long old = - atomicCAS(reinterpret_cast(address.data()), - *reinterpret_cast(¤t), - *reinterpret_cast(&value)); - current = *reinterpret_cast(&old); + atomicCAS(reinterpret_cast(address.data()), + *reinterpret_cast(¤t), + *reinterpret_cast(&value)); + current = *reinterpret_cast(&old); if (current == value) { break; } @@ -475,10 +475,10 @@ CudaSpecialization::atomic_max(address.data()), - *reinterpret_cast(¤t), - *reinterpret_cast(&value)); - current = *reinterpret_cast(&old); + atomicCAS(reinterpret_cast(address.data()), + *reinterpret_cast(¤t), + *reinterpret_cast(&value)); + current = *reinterpret_cast(&old); if (current == value) { break; } @@ -496,10 +496,10 @@ __device__ __forceinline__ long long CudaSpecialization::atomic_cas< address, long long compare, long long value) { unsigned long long old = - atomicCAS(reinterpret_cast(address.data()), - *reinterpret_cast(&compare), - *reinterpret_cast(&value)); - return *reinterpret_cast(&old); + atomicCAS(reinterpret_cast(address.data()), + *reinterpret_cast(&compare), + *reinterpret_cast(&value)); + return *reinterpret_cast(&old); } // specialization for T=Slice @@ -518,9 +518,9 @@ __device__ __forceinline__ long long CudaSpecialization::atomic_cas< unsigned long long>::pointer_type; unsigned long long old = atomicCAS(reinterpret_cast(address.data()), - *reinterpret_cast(&compare), - *reinterpret_cast(&value)); - return *reinterpret_cast(&old); + *reinterpret_cast(&compare), + *reinterpret_cast(&value)); + return *reinterpret_cast(&old); } // specialization for T=Slice @@ -532,10 +532,10 @@ __device__ __forceinline__ float CudaSpecialization::atomic_cas< address, float compare, float value) { static_assert(sizeof(float) == sizeof(unsigned), "unexpected type sizes"); - unsigned old = atomicCAS(reinterpret_cast(address.data()), - *reinterpret_cast(&compare), - *reinterpret_cast(&value)); - return *reinterpret_cast(&old); + unsigned old = atomicCAS(reinterpret_cast(address.data()), + *reinterpret_cast(&compare), + *reinterpret_cast(&value)); + return *reinterpret_cast(&old); } #if __CUDA_ARCH__ >= 800 diff --git a/velox/experimental/breeze/breeze/platforms/hip.hpp b/velox/experimental/breeze/breeze/platforms/hip.hpp index 2ddef322d5e..9f9b38d19f7 100644 --- a/velox/experimental/breeze/breeze/platforms/hip.hpp +++ b/velox/experimental/breeze/breeze/platforms/hip.hpp @@ -70,11 +70,11 @@ struct HipPlatform { } template __device__ __forceinline__ T atomic_load(SliceT address) { - return *reinterpret_cast(address.data()); + return *reinterpret_cast(address.data()); } template __device__ __forceinline__ void atomic_store(SliceT address, T value) { - *reinterpret_cast(address.data()) = value; + *reinterpret_cast(address.data()) = value; } template __device__ __forceinline__ T atomic_cas(SliceT address, T compare, T value) { diff --git a/velox/experimental/breeze/breeze/platforms/metal.h b/velox/experimental/breeze/breeze/platforms/metal.h index c1a3e28882f..498fa31858c 100644 --- a/velox/experimental/breeze/breeze/platforms/metal.h +++ b/velox/experimental/breeze/breeze/platforms/metal.h @@ -78,16 +78,15 @@ struct MetalPlatform { MetalSpecialization::atomic_store(address, value); } template